-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_model.py
More file actions
48 lines (39 loc) · 1.34 KB
/
test_model.py
File metadata and controls
48 lines (39 loc) · 1.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import sys
import numpy as np
import torch
import time
import torch.nn as nn
from utils.utils import *
from torch.utils.data import Dataset, DataLoader
from model.model import *
from model.dataset import *
from model.attentionSecondOrderLSTM import AttentionSecondOrderLSTM
def test(dataset):
    """Push a single batch through the language model and print its NLL loss.

    An ``LSTMLanguageModel`` is built from the dataset's vocabulary, its
    ``lstm`` attribute is replaced by an ``AttentionSecondOrderLSTM``, the
    first batch yielded by a non-shuffling ``DataLoader`` is fed through the
    model, and the resulting loss is printed.

    Args:
        dataset: Object exposing ``get_vocab()`` and accepted by
            ``torch.utils.data.DataLoader``; each batch must unpack into an
            ``(x, y)`` pair.

    Returns:
        None. The loss is only printed.
    """
    samples_per_batch = 1
    # The embedding width feeds the replacement LSTM's input, so both sizes
    # are named once and shared by the two constructors below.
    embed_dim = 12
    hidden_dim = 30

    vocabulary = dataset.get_vocab()
    # Targets equal to the vocabulary's pad id do not contribute to the loss.
    loss_fn = nn.NLLLoss(ignore_index=vocabulary.pad_id)

    model = LSTMLanguageModel(
        vocabulary,
        hidden_size=hidden_dim,
        batch_size=samples_per_batch,
        embed_size=embed_dim,
    )
    loader = DataLoader(
        dataset,
        batch_size=samples_per_batch,
        shuffle=False,
        num_workers=4,
    )
    # Swap the model's recurrent layer for the attention/second-order variant.
    model.lstm = AttentionSecondOrderLSTM(
        second_order_size=2,
        input_size=embed_dim,
        hidden_size=hidden_dim,
    )

    inputs, targets = next(iter(loader))
    # Flatten the targets to 1-D before handing them to the criterion.
    flat_targets = targets.view(-1)
    # NOTE(review): the meaning of the second positional argument (0.0001) is
    # defined by the model's forward(); confirm against model/model.py.
    predictions, _ = model(inputs, 0.0001)
    print(loss_fn(predictions, flat_targets))

    # NOTE: a disabled (commented-out) second forward pass used to follow
    # here, ending in an np.allclose comparison against ``y_pred1`` -- a name
    # that is not defined anywhere in this function.
if __name__ == "__main__":
    # Script entry point: build the dataset from the training file and run
    # the single-batch check on it.
    train_file = "./data/mbounded-dyck-k/m4/train.formal.txt"
    dataset = CustomDataset(train_file)
    test(dataset)