% This function trains a neural network language model.
function [model] = train(epochs)
% TRAIN  Train a neural network language model with mini-batch gradient
% descent and momentum.
%
% Inputs:
%   epochs: Number of epochs (full passes over the training batches) to run.
% Output:
%   model: A struct containing the learned weights and biases and vocabulary
%          (fields word_embedding_weights, embed_to_hid_weights,
%          hid_to_output_weights, hid_bias, output_bias, vocab).
%
% Relies on two functions defined outside this file: load_data (returns the
% training/validation/test sets and the vocabulary) and fprop (forward pass).
% Progress, the final validation/test cross-entropy (CE) and the elapsed
% time are printed to standard output.

% The code treats a non-empty ver('Octave') result as "running under Octave".
% Octave and MATLAB differ in how wall-clock time is read and in the need to
% flush stdout explicitly.
if size(ver('Octave'),1)
  OctaveMode = 1;
  % Turn the automatic-broadcasting warning into an error so that an
  % accidental shape mismatch fails instead of silently broadcasting.
  warning('error', 'Octave:broadcast');
  start_time = time;
else
  OctaveMode = 0;
  start_time = clock;
end

% SET HYPERPARAMETERS HERE.
batchsize = 100;  % Mini-batch size.
learning_rate = 0.1;  % Learning rate; default = 0.1.
momentum = 0.9;  % Momentum; default = 0.9.
numhid1 = 50;  % Dimensionality of embedding space; default = 50.
numhid2 = 200;  % Number of units in hidden layer; default = 200.
init_wt = 0.01;  % Standard deviation of the normal distribution
                 % which is sampled to get the initial weights; default = 0.01

% VARIABLES FOR TRACKING TRAINING PROGRESS.
show_training_CE_after = 100;     % Start a new progress line every this many batches.
show_validation_CE_after = 1000;  % Run validation every this many batches.

% LOAD DATA.
[train_input, train_target, valid_input, valid_target, ...
  test_input, test_target, vocab] = load_data(batchsize);
% train_input is indexed as (word position, case within batch, batch).
% batchsize is re-read from the data actually returned.
[numwords, batchsize, numbatches] = size(train_input);
vocab_size = size(vocab, 2);

% INITIALIZE WEIGHTS AND BIASES.
word_embedding_weights = init_wt * randn(vocab_size, numhid1);
embed_to_hid_weights = init_wt * randn(numwords * numhid1, numhid2);
hid_to_output_weights = init_wt * randn(numhid2, vocab_size);
hid_bias = zeros(numhid2, 1);
output_bias = zeros(vocab_size, 1);

% Momentum-smoothed update directions, one per parameter array.
word_embedding_weights_delta = zeros(vocab_size, numhid1);
word_embedding_weights_gradient = zeros(vocab_size, numhid1);
embed_to_hid_weights_delta = zeros(numwords * numhid1, numhid2);
hid_to_output_weights_delta = zeros(numhid2, vocab_size);
hid_bias_delta = zeros(numhid2, 1);
output_bias_delta = zeros(vocab_size, 1);
% Column k of the identity matrix is the 1-of-K encoding of word index k.
expansion_matrix = eye(vocab_size);
% Added inside log() so that a zero probability does not produce -Inf.
tiny = exp(-30);
% Initialised here so the final report is defined even when epochs == 0.
trainset_CE = 0;

% TRAIN.
for epoch = 1:epochs
  fprintf(1, 'Epoch %d\n', epoch);
  % Reset the running averages. count must be reset together with
  % this_chunk_CE: if a remainder from the previous epoch's last (partial)
  % chunk were kept, the incremental mean below would average in phantom
  % zero samples and under-report the training CE.
  count = 0;
  this_chunk_CE = 0;
  trainset_CE = 0;
  % LOOP OVER MINI-BATCHES.
  for m = 1:numbatches
    input_batch = train_input(:, :, m);
    target_batch = train_target(:, :, m);

    % FORWARD PROPAGATE.
    % Compute the state of each layer in the network given the input batch
    % and all weights and biases
    [embedding_layer_state, hidden_layer_state, output_layer_state] = ...
      fprop(input_batch, ...
            word_embedding_weights, embed_to_hid_weights, ...
            hid_to_output_weights, hid_bias, output_bias);

    % COMPUTE DERIVATIVE.
    %% Expand the target to a sparse 1-of-K vector.
    expanded_target_batch = expansion_matrix(:, target_batch);
    %% Compute derivative of cross-entropy loss function.
    %%% vocab_size X batchsize
    % NOTE(review): "output minus target" is the CE derivative with respect
    % to the inputs of a softmax layer; presumably fprop's output layer is a
    % softmax - confirm against fprop.
    error_deriv = output_layer_state - expanded_target_batch;

    % MEASURE LOSS FUNCTION.
    CE = -sum(sum(...
      expanded_target_batch .* log(output_layer_state + tiny))) / batchsize;
    % Incremental means: over the current display chunk and over the epoch.
    count =  count + 1;
    this_chunk_CE = this_chunk_CE + (CE - this_chunk_CE) / count;
    trainset_CE = trainset_CE + (CE - trainset_CE) / m;
    % '\r' rewrites the same console line until the chunk is complete.
    fprintf(1, '\rBatch %d Train CE %.3f', m, this_chunk_CE);
    if mod(m, show_training_CE_after) == 0
      fprintf(1, '\n');
      count = 0;
      this_chunk_CE = 0;
    end
    if OctaveMode
      fflush(1);
    end

    % BACK PROPAGATE.
    %% OUTPUT LAYER.
    %%% numhid2 X vocab_size
    hid_to_output_weights_gradient =  hidden_layer_state * error_deriv';
    %%% vocab_size
    output_bias_gradient = sum(error_deriv, 2);
    %%% numhid2 X batchsize
    % NOTE(review): h .* (1 - h) is the derivative of a logistic unit;
    % presumably fprop uses logistic hidden units - confirm against fprop.
    back_propagated_deriv_1 = (hid_to_output_weights * error_deriv) ...
      .* hidden_layer_state .* (1 - hidden_layer_state);

    %% HIDDEN LAYER.
    %%% (numhid1 * numwords) X numhid2
    embed_to_hid_weights_gradient = embedding_layer_state * back_propagated_deriv_1';

    %%% numhid2 X 1 (summed over the cases in the batch)
    hid_bias_gradient = sum(back_propagated_deriv_1, 2);

    %%% (numhid1 * numwords) X batchsize
    back_propagated_deriv_2 = embed_to_hid_weights * back_propagated_deriv_1;

    word_embedding_weights_gradient(:) = 0;
    %% EMBEDDING LAYER.
    % The embedding matrix is shared by all word positions, so the gradient
    % is accumulated over positions. Rows 1+(w-1)*numhid1 : w*numhid1 of
    % back_propagated_deriv_2 belong to position w.
    for w = 1:numwords
       word_embedding_weights_gradient = word_embedding_weights_gradient + ...
         expansion_matrix(:, input_batch(w, :)) * ...
         (back_propagated_deriv_2(1 + (w - 1) * numhid1 : w * numhid1, :)');
    end

    % UPDATE WEIGHTS AND BIASES.
    % For each parameter: delta = momentum * delta + gradient / batchsize,
    % then parameter = parameter - learning_rate * delta.
    word_embedding_weights_delta = ...
      momentum .* word_embedding_weights_delta + ...
      word_embedding_weights_gradient ./ batchsize;
    word_embedding_weights = word_embedding_weights...
      - learning_rate * word_embedding_weights_delta;

    embed_to_hid_weights_delta = ...
      momentum .* embed_to_hid_weights_delta + ...
      embed_to_hid_weights_gradient ./ batchsize;
    embed_to_hid_weights = embed_to_hid_weights...
      - learning_rate * embed_to_hid_weights_delta;

    hid_to_output_weights_delta = ...
      momentum .* hid_to_output_weights_delta + ...
      hid_to_output_weights_gradient ./ batchsize;
    hid_to_output_weights = hid_to_output_weights...
      - learning_rate * hid_to_output_weights_delta;

    hid_bias_delta = momentum .* hid_bias_delta + ...
      hid_bias_gradient ./ batchsize;
    hid_bias = hid_bias - learning_rate * hid_bias_delta;

    output_bias_delta = momentum .* output_bias_delta + ...
      output_bias_gradient ./ batchsize;
    output_bias = output_bias - learning_rate * output_bias_delta;

    % VALIDATE.
    if mod(m, show_validation_CE_after) == 0
      fprintf(1, '\rRunning validation ...');
      if OctaveMode
        fflush(1);
      end
      valid_CE = evaluate_CE(valid_input, valid_target, ...
        word_embedding_weights, embed_to_hid_weights, ...
        hid_to_output_weights, hid_bias, output_bias, ...
        expansion_matrix, tiny);
      fprintf(1, ' Validation CE %.3f\n', valid_CE);
      if OctaveMode
        fflush(1);
      end
    end
  end
  fprintf(1, '\rAverage Training CE %.3f\n', trainset_CE);
end
fprintf(1, 'Finished Training.\n');
if OctaveMode
  fflush(1);
end
fprintf(1, 'Final Training CE %.3f\n', trainset_CE);

% EVALUATE ON VALIDATION SET.
fprintf(1, '\rRunning validation ...');
if OctaveMode
  fflush(1);
end
valid_CE = evaluate_CE(valid_input, valid_target, ...
  word_embedding_weights, embed_to_hid_weights, ...
  hid_to_output_weights, hid_bias, output_bias, ...
  expansion_matrix, tiny);
fprintf(1, '\rFinal Validation CE %.3f\n', valid_CE);
if OctaveMode
  fflush(1);
end

% EVALUATE ON TEST SET.
fprintf(1, '\rRunning test ...');
if OctaveMode
  fflush(1);
end
test_CE = evaluate_CE(test_input, test_target, ...
  word_embedding_weights, embed_to_hid_weights, ...
  hid_to_output_weights, hid_bias, output_bias, ...
  expansion_matrix, tiny);
fprintf(1, '\rFinal Test CE %.3f\n', test_CE);
if OctaveMode
  fflush(1);
end

model.word_embedding_weights = word_embedding_weights;
model.embed_to_hid_weights = embed_to_hid_weights;
model.hid_to_output_weights = hid_to_output_weights;
model.hid_bias = hid_bias;
model.output_bias = output_bias;
model.vocab = vocab;

% Elapsed wall-clock time. Octave's time values are subtracted directly;
% MATLAB's clock values are date vectors and are differenced with etime.
% The result is named 'elapsed' so the built-in diff function is not shadowed.
if OctaveMode
  end_time = time;
  elapsed = end_time - start_time;
else
  end_time = clock;
  elapsed = etime(end_time, start_time);
end
fprintf(1, 'Training took %.2f seconds\n', elapsed);
end

function CE = evaluate_CE(input, target, word_embedding_weights, ...
                          embed_to_hid_weights, hid_to_output_weights, ...
                          hid_bias, output_bias, expansion_matrix, tiny)
% Forward-propagate a whole data set and return its average cross-entropy.
%
% Inputs:
%   input: Data set passed to fprop; its second dimension is the number of
%          cases.
%   target: Target word indices used to select columns of expansion_matrix.
%   word_embedding_weights, embed_to_hid_weights, hid_to_output_weights,
%   hid_bias, output_bias: Current model parameters, passed through to fprop.
%   expansion_matrix: Matrix whose column k is the 1-of-K encoding of index k.
%   tiny: Small constant added inside log() to avoid log(0).
% Output:
%   CE: Cross-entropy summed over all cases, divided by the number of cases.
[embedding_layer_state, hidden_layer_state, output_layer_state] = ...
  fprop(input, word_embedding_weights, embed_to_hid_weights, ...
        hid_to_output_weights, hid_bias, output_bias);
datasetsize = size(input, 2);
expanded_target = expansion_matrix(:, target);
CE = -sum(sum(...
  expanded_target .* log(output_layer_state + tiny))) / datasetsize;
end