-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
49 lines (36 loc) · 1.79 KB
/
train.py
File metadata and controls
49 lines (36 loc) · 1.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import argparse
from model.decoder_network import DecoderNetwork as Model
from model.decoder_network import NetworkConfigs as ModelConfig
from trainer.trainer import Trainer
from trainer.config import TrainConfig
from data.dataset import get_datapipe
def main(pickle_name, use_test_data):
    """Build the model and datapipes for ``pickle_name`` and start training.

    Args:
        pickle_name: Name or path of the data pickle, with or without a
            trailing ``.pkl`` extension. Backslashes are converted to forward
            slashes. After normalization it must equal the ``pickle_name``
            attribute of a freshly constructed ``ModelConfig``.
        use_test_data: When true, a ``'test'`` datapipe is built in addition
            to the ``'train'`` one and handed to the trainer; otherwise the
            trainer receives ``test_dataset=None``.

    Raises:
        ValueError: If the normalized name differs from the model config's
            ``pickle_name``.
    """
    # Remove only a literal '.pkl' extension. str.rstrip('.pkl') treats its
    # argument as a set of characters and would also eat trailing '.', 'p',
    # 'k' and 'l' characters of the stem (e.g. 'book.pkl' -> 'boo').
    suffix = '.pkl'
    if pickle_name.endswith(suffix):
        pickle_name = pickle_name[:-len(suffix)]
    pickle_name = pickle_name.replace('\\', '/')

    model_config = ModelConfig()
    # Explicit check instead of `assert`, which is skipped under `python -O`.
    if pickle_name != model_config.pickle_name:
        raise ValueError(
            'Dataset %r does not match the model config pickle_name %r.'
            % (pickle_name, model_config.pickle_name)
        )
    model = Model(model_config)

    # Both splits are built with the same model-derived settings.
    datapipe_kwargs = dict(
        frozen_embeddings=model_config.frozen_embeddings,
        sbert_model=model_config.sbert_model,
        hugface_model=model_config.hugface_model,
        max_length=model_config.max_length,
    )
    train_dataset = get_datapipe(pickle_name, 'train', **datapipe_kwargs)
    if use_test_data:
        test_dataset = get_datapipe(pickle_name, 'test', **datapipe_kwargs)
    else:
        test_dataset = None

    train_config = TrainConfig()
    trainer = Trainer(model=model,
                      train_dataset=train_dataset,
                      test_dataset=test_dataset,
                      train_config=train_config)
    trainer.distributed_train()
if __name__ == '__main__':
    # Command-line entry point: one positional dataset argument plus an
    # optional --test switch, both forwarded to main().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        'dataset',
        type=str,
        help='Data pickle file to train with.',
    )
    cli.add_argument(
        '--test',
        action='store_true',
        help='Use a test/validation data subset.',
    )
    parsed = cli.parse_args()
    main(pickle_name=parsed.dataset, use_test_data=parsed.test)