38 lines
1.3 KiB
Matlab
38 lines
1.3 KiB
Matlab
function predict_next_word(word1, word2, word3, model, k)
% PREDICT_NEXT_WORD  Print the k most probable next words for a 3-word context.
%
% Inputs:
%   word1: The first word as a string.
%   word2: The second word as a string.
%   word3: The third word as a string.
%   model: Model returned by the training script. This function reads the
%          fields vocab, word_embedding_weights, embed_to_hid_weights,
%          hid_to_output_weights, hid_bias and output_bias.
%   k:     The k most probable predictions are shown. If k exceeds the
%          number of model outputs, all outputs are shown.
%
% Output:
%   Nothing is returned. One line per prediction is printed to stdout in
%   the form "<word1> <word2> <word3> <predicted> Prob: <p>". If any context
%   word is missing from model.vocab, a message naming the first missing
%   word is printed and the function returns without predicting.
%
% Example usage:
%   predict_next_word('john', 'might', 'be', model, 3);
%   predict_next_word('life', 'in', 'new', model, 3);

vocab = model.vocab;

% Map each context word to its vocabulary index, stopping at the first
% word that is not in the vocabulary.
context_words = {word1, word2, word3};
context_ids = zeros(3, 1);
for w = 1:3
  % strcmp is an exact, case-sensitive match; the trailing 1 keeps only the
  % first hit so a duplicated vocabulary entry cannot lengthen the input.
  id = find(strcmp(context_words{w}, vocab), 1);
  if isempty(id)
    % A literal quote inside a single-quoted string is written as ''.
    % The backslash-quote escape is not a supported fprintf escape in MATLAB.
    fprintf(1, 'Word ''%s'' not in vocabulary.\n', context_words{w});
    return;
  end
  context_ids(w) = id;
end

% Forward pass; only the output layer activations are needed here.
% NOTE(review): assumes fprop returns one score per vocabulary entry for
% this single 3-word input column -- confirm against fprop.
[~, ~, output_layer_state] = ...
  fprop(context_ids, model.word_embedding_weights, model.embed_to_hid_weights, ...
        model.hid_to_output_weights, model.hid_bias, model.output_bias);

[prob, indices] = sort(output_layer_state, 'descend');

% Clamp so that asking for more predictions than exist does not index
% past the end of prob.
num_shown = min(k, numel(prob));
for i = 1:num_shown
  fprintf(1, '%s %s %s %s Prob: %.5f\n', word1, word2, word3, vocab{indices(i)}, prob(i));
end