-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheval.py
More file actions
81 lines (64 loc) · 2.79 KB
/
eval.py
File metadata and controls
81 lines (64 loc) · 2.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import argparse
import etl
import helpers
import torch
from attention_decoder import AttentionDecoderRNN
from encoder import EncoderRNN
from language import Language
from torch.autograd import Variable
# ---- Command-line interface ----
# Two positional arguments: the language pair to load and the sentence to translate.
parser = argparse.ArgumentParser()
parser.add_argument('language')
parser.add_argument('input')
args = parser.parse_args()

# Check that trained parameters exist for the requested language
# (validation behavior lives in helpers — presumably exits/raises on failure).
helpers.validate_language_params(args.language)

# Rebuild the vocabularies (and sentence pairs) used at training time so the
# embedding sizes below (input_lang.n_words / output_lang.n_words) match the
# shapes stored in the checkpoints.
input_lang, output_lang, pairs = etl.prepare_data(args.language)

# Hyperparameters — these must match the values used when the models were
# trained, otherwise load_state_dict below fails on mismatched tensor shapes.
attn_model = 'general'
hidden_size = 500
n_layers = 2
dropout_p = 0.05

# Initialize models
encoder = EncoderRNN(input_lang.n_words, hidden_size, n_layers)
decoder = AttentionDecoderRNN(attn_model, hidden_size, output_lang.n_words, n_layers, dropout_p=dropout_p)

# Load model parameters saved by the training run; paths are relative to the
# current working directory, so this script expects to be run from its own dir.
encoder.load_state_dict(torch.load('../data/encoder_params_{}'.format(args.language)))
decoder.load_state_dict(torch.load('../data/decoder_params_{}'.format(args.language)))
decoder.attention.load_state_dict(torch.load('../data/attention_params_{}'.format(args.language)))

# Move models to GPU — evaluate() also moves its inputs with .cuda(), so a
# CUDA device is required to run this script.
encoder.cuda()
decoder.cuda()
def evaluate(sentence, max_length=10):
    """Translate ``sentence`` using the module-level encoder/decoder.

    Greedy decoding: at every step the single highest-scoring output word is
    appended and fed back as the next decoder input, stopping at the EOS
    token or after ``max_length`` steps.

    NOTE(review): written against the pre-0.4 PyTorch API (``Variable``,
    ``.data``, ``topi[0][0]`` yielding a plain int) — confirm the installed
    torch version before modernizing.

    :param sentence: source-language sentence; assumed already normalized
        (the caller runs helpers.normalize_string first).
    :param max_length: maximum number of decoder steps.
    :return: ``(decoded_words, attentions)`` — the list of output-language
        tokens (ending with '<EOS>' if the decoder stopped early) and a
        (steps x source_length) tensor of attention weights.
    """
    # Encode the input sentence to a variable of word indices.
    input_variable = etl.variable_from_sentence(input_lang, sentence)
    input_length = input_variable.size()[0]
    # Run through encoder
    encoder_hidden = encoder.init_hidden()
    encoder_outputs, encoder_hidden = encoder(input_variable, encoder_hidden)
    # Create starting vectors for decoder: SOS token as first input, a zero
    # context vector, and the encoder's final hidden state.
    decoder_input = Variable(torch.LongTensor([[Language.sos_token]]))  # SOS
    decoder_context = Variable(torch.zeros(1, decoder.hidden_size))
    decoder_input = decoder_input.cuda()
    decoder_context = decoder_context.cuda()
    decoder_hidden = encoder_hidden
    decoded_words = []
    # Rows = decoder steps, cols = source positions (only the first
    # source_length columns are ever written; the rest stay zero).
    decoder_attentions = torch.zeros(max_length, max_length)
    # Run through decoder, one greedy step at a time.
    for di in range(max_length):
        decoder_output, decoder_context, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_context,
                                                                                     decoder_hidden, encoder_outputs)
        # Record this step's attention over the source positions.
        decoder_attentions[di, :decoder_attention.size(2)] += decoder_attention.squeeze(0).squeeze(0).cpu().data
        # Choose top word from output (greedy argmax over the vocabulary).
        topv, topi = decoder_output.data.topk(1)
        ni = topi[0][0]
        if ni == Language.eos_token:
            decoded_words.append('<EOS>')
            break
        else:
            decoded_words.append(output_lang.index2word[ni])
        # Next input is chosen word (fed back for the next step).
        decoder_input = Variable(torch.LongTensor([[ni]]))
        decoder_input = decoder_input.cuda()
    # Trim to the steps actually taken and the true source length.
    return decoded_words, decoder_attentions[:di + 1, :len(encoder_outputs)]
# Normalize the raw CLI sentence, translate it, and print the result.
cleaned = helpers.normalize_string(args.input)
words, _attention = evaluate(cleaned)
print(' '.join(words))