-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
119 lines (93 loc) · 3.08 KB
/
train.py
File metadata and controls
119 lines (93 loc) · 3.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import tensorflow as tf
from model import Seq2Seq
from val import run_validation
from data import (
load_jsonl,
load_sp_model,
tokenize_text_pairs,
train_sentencepiece_model
)
def train_model(
    train_path,
    val_path,
    epochs=20,
    sp_src_prefix="sp_en",
    sp_tgt_prefix="sp_de",
    model_export_dir="checkpoints/seq2seq",
):
    """Train a Seq2Seq translation model end-to-end and return it.

    Pipeline: load the training corpus, fit SentencePiece tokenizers on the
    source/target sides, tokenize the pairs, train the model, export the
    weights, and finish with a validation run plus a sample translation.

    Args:
        train_path: Path to the JSONL training file of text pairs.
        val_path: Path to the JSONL validation file.
        epochs: Number of training epochs.
        sp_src_prefix: File prefix for the source-language SentencePiece model.
        sp_tgt_prefix: File prefix for the target-language SentencePiece model.
        model_export_dir: Directory the trained weights are exported to.

    Returns:
        The trained Seq2Seq model instance.
    """
    print("Loading training data")
    text_pairs = load_jsonl(train_path, only_text_pairs=True)
    # Split (source, target) pairs into two parallel corpora.
    src_text = [pair[0] for pair in text_pairs]
    tgt_text = [pair[1] for pair in text_pairs]

    print("Training SentencePiece models")
    train_sentencepiece_model(src_text, sp_src_prefix)
    train_sentencepiece_model(tgt_text, sp_tgt_prefix)

    print("Loading SentencePiece models")
    sp_src = load_sp_model(sp_src_prefix)
    sp_tgt = load_sp_model(sp_tgt_prefix)

    print("Tokenizing training data")
    encoded_pairs = tokenize_text_pairs(text_pairs, sp_src, sp_tgt)

    print("Initializing Seq2Seq model")
    seq2seq = Seq2Seq(
        src_tokenizer=sp_src,
        tgt_tokenizer=sp_tgt,
    )

    print("Training Seq2Seq model")
    seq2seq.train(
        encoded_pairs,
        epochs=epochs,
        val_path=val_path,
    )

    print("Saving trained model")
    seq2seq.export(model_export_dir, prefix="seq2seq")
    print(f"\tModel saved to: {model_export_dir}")

    print("Running evaluation")
    run_validation(val_path, seq2seq)

    print("Training complete!")
    print("Example translation:")
    print(seq2seq.translate("Hello, how are you?"))
    return seq2seq
def load_and_evaluate(
    val_path,
    sp_src_prefix="sp_en",
    sp_tgt_prefix="sp_de",
    model_export_dir="checkpoints/seq2seq",
):
    """Restore a previously exported Seq2Seq model and run validation.

    Rebuilds the model from the saved SentencePiece tokenizers, runs one
    dummy forward pass so Keras materializes the layer weights, loads the
    exported checkpoint, and evaluates on the validation set.

    Args:
        val_path: Path to the JSONL validation file.
        sp_src_prefix: File prefix of the source-language SentencePiece model.
        sp_tgt_prefix: File prefix of the target-language SentencePiece model.
        model_export_dir: Directory holding the exported model weights.
    """
    print("Loading SentencePiece tokenizers")
    sp_src = load_sp_model(sp_src_prefix)
    sp_tgt = load_sp_model(sp_tgt_prefix)

    print("Initializing Seq2Seq model")
    seq2seq = Seq2Seq(
        src_tokenizer=sp_src,
        tgt_tokenizer=sp_tgt,
    )

    # Weights don't exist until the layers are called once, so push a tiny
    # dummy batch through the encoder and decoder before loading.
    warmup_src = tf.constant([[1, 2, 3]], dtype=tf.int32)
    warmup_tgt = tf.constant([[1, 2, 3]], dtype=tf.int32)
    _, hidden, cell = seq2seq.encoder(warmup_src)
    _ = seq2seq.decoder(warmup_tgt, (hidden, cell))

    print("Loading saved model weights")
    seq2seq.load(model_export_dir, prefix="seq2seq")

    print("Running validation")
    run_validation(val_path, seq2seq)
def main(
    train_path=r"C:\Users\daane\Desktop\GitHub Repos\NLP_Final_Project\dataset\train\de\train.jsonl",
    val_path=r"C:\Users\daane\Desktop\GitHub Repos\NLP_Final_Project\dataset\validation\de_DE.jsonl",
    model_export_dir=r"C:\Users\daane\Desktop\GitHub Repos\NLP_Final_Project\checkpoints\seq2seq",
    *,
    do_train=False,
    epochs=200,
):
    """Entry point: optionally train a Seq2Seq model, then evaluate it.

    The original body hard-coded the paths and kept the training call as
    commented-out dead code; the paths are now parameters (defaults preserve
    the original values) and training is toggled with ``do_train`` (default
    False, matching the original evaluate-only behavior).

    Args:
        train_path: JSONL training file (used only when ``do_train`` is True).
        val_path: JSONL validation file.
        model_export_dir: Directory for exported model weights.
        do_train: When True, run the full training pipeline before evaluating.
        epochs: Training epochs, forwarded to ``train_model`` when training.

    NOTE(review): the default paths are machine-specific absolute Windows
    paths — consider moving them to CLI arguments or a config file.
    """
    if do_train:
        train_model(
            train_path=train_path,
            val_path=val_path,
            epochs=epochs,
            sp_src_prefix="sp_en",
            sp_tgt_prefix="sp_de",
            model_export_dir=model_export_dir,
        )
    load_and_evaluate(
        val_path=val_path,
        sp_src_prefix="sp_en",
        sp_tgt_prefix="sp_de",
        model_export_dir=model_export_dir,
    )


if __name__ == "__main__":
    main()