from dataclasses import dataclass, field
from typing import Optional
import torch
-
from transformers import (
    MODEL_WITH_LM_HEAD_MAPPING,
    AutoTokenizer,
    HfArgumentParser,
    PreTrainedTokenizer,
    set_seed,
)
- from relogic.pretrainkit.trainer import Trainer
+ from relogic.pretrainkit.multitask_trainer import Trainer
from relogic.pretrainkit.datasets.semparse.tabart import DataCollatorForTaBART, TaBARTDataset
+ from relogic.pretrainkit.datasets.semparse.text2sql import DataCollatorForQuerySchema2SQL, QuerySchema2SQLDataset
from relogic.pretrainkit.scorers.match_sequence import MatchSequenceScorer
- from relogic.pretrainkit.models.semparse.tabart import TaBARTModel
+ from relogic.pretrainkit.models.semparse.logical_tabart import LogicalTaBARTModel
from relogic.pretrainkit.training_args import TrainingArguments
- import relogic.utils.crash_on_ipy

logger = logging.getLogger(__name__)


MODEL_CONFIG_CLASSES = list(MODEL_WITH_LM_HEAD_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

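+ # Detect an Amazon SageMaker training job: SageMaker sets SM_MODEL_DIR inside its training containers.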
+ is_sagemaker = 'SM_MODEL_DIR' in os.environ

@dataclass
class ModelArguments:
@@ -73,17 +73,26 @@ class ModelArguments:
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
-     task: Optional[str] = field(
-         default="mlm", metadata={"help": "Learning target. mlm, col_pred, mlm+col_pred"}
+     pretraining_model: Optional[str] = field(
+         default=None, metadata={"help": "Which model to use for pretraining."}
+     )
+     load_from_pretrained_ckpt: Optional[str] = field(
+         default=None, metadata={"help": "Initialize the model from a pretrained checkpoint."}
+     )
+     pretrained_ckpt_dir: Optional[str] = field(
+         default="pretrained_checkpoint", metadata={"help": "Directory containing the pretrained checkpoint."}
    )


+
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """
-
+     task_names: Optional[str] = field(
+         default=None, metadata={"help": "The names of the tasks, separated by commas."}
+     )
    train_data_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a text file)."}
    )
@@ -114,11 +123,47 @@ class DataTrainingArguments:
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
+     not_use_text: bool = field(
+         default=False, metadata={"help": "Do not use text during pretraining."}
+     )
+     only_use_text: bool = field(
+         default=False, metadata={"help": "Use only text during pretraining."}
+     )
+     cross_lingual: bool = field(
+         default=False,
+         metadata={"help": "Whether to use cross-lingual TaBART training."},
+     )
+     dump_file_name: str = field(
+         default="eval_dump.json",
+         metadata={"help": "The file name for dumping evaluation results."}
+     )

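+ # Route each task to its dataset class: TaBART-style table/text tasks vs. the text-to-SQL (query + schema -> SQL) task.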
+ def get_dataset_by_name(pretraining_model, task_name, cross_lingual, tokenizer, file_path, use_text, only_use_text):
+     if task_name != "text2sql":
+         return TaBARTDataset(tokenizer=tokenizer, file_path=file_path, col_token="<col>",
+                              task_name=task_name, use_text=use_text, only_use_text=only_use_text)
+     if task_name == "text2sql":
+         return QuerySchema2SQLDataset(tokenizer=tokenizer, file_path=file_path, task_name=task_name)

- def get_dataset(args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, evaluate=False):
-     file_path = args.eval_data_file if evaluate else args.train_data_file
-     return TaBARTDataset(tokenizer=tokenizer, file_path=file_path, col_token="<col>")
+
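+ # Task names and data files are parallel comma-separated lists; zip pairs them one-to-one.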
+ def get_datasets(pretraining_model, args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, evaluate=False):
+     file_paths = args.eval_data_file.split(",") if evaluate else args.train_data_file.split(",")
+     task_names = args.task_names.split(",")
+     datasets = [get_dataset_by_name(pretraining_model, task_name, args.cross_lingual, tokenizer, file_path, not args.not_use_text, args.only_use_text)
+                 for task_name, file_path in zip(task_names, file_paths)]
+     return datasets
+
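+ # Collators mirror the dataset routing: the TaBART collator for table tasks, a dedicated one for text-to-SQL.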
+ def get_data_collator_by_name(pretraining_model, task_name, cross_lingual, tokenizer):
+     if task_name != "text2sql":
+         return DataCollatorForTaBART(tokenizer=tokenizer, task=task_name, col_token="<col>")
+     if task_name == "text2sql":
+         return DataCollatorForQuerySchema2SQL(tokenizer=tokenizer)
+
+
+ def get_data_collators(pretraining_model, args: DataTrainingArguments, tokenizer: PreTrainedTokenizer):
+     task_names = args.task_names.split(",")
+     collators = [get_data_collator_by_name(pretraining_model, task_name, args.cross_lingual, tokenizer) for task_name in task_names]
+     return collators


def main():
@@ -129,21 +174,34 @@ def main():
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

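+     # SageMaker passes hyperparameters as strings and mounts data under SM_CHANNEL_* paths,
+     # so rebuild the boolean flags and absolute file paths here.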
+     if is_sagemaker:
+         training_args.do_train = training_args.do_train_str == "True"
+         training_args.do_eval = training_args.do_eval_str == "True"
+         training_args.evaluate_during_training = training_args.evaluate_during_training_str == "True"
+         data_args.train_data_file = ",".join([os.path.join(os.environ['SM_CHANNEL_TRAIN'], item) for item in data_args.train_data_file.split(",")])
+         data_args.eval_data_file = ",".join([os.path.join(os.environ['SM_CHANNEL_TRAIN'], item) for item in data_args.eval_data_file.split(",")])
+         training_args.output_dir = os.environ['SM_MODEL_DIR']
+         model_args.pretrained_ckpt_dir = os.environ.get("SM_CHANNEL_PRETRAINED_CKPT_DIR", None)
+
+     if model_args.pretrained_ckpt_dir is not None and model_args.load_from_pretrained_ckpt is not None:
+         model_args.load_from_pretrained_ckpt = os.path.join(model_args.pretrained_ckpt_dir, model_args.load_from_pretrained_ckpt)
+
    if data_args.eval_data_file is None and training_args.do_eval:
        raise ValueError(
            "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
            "or remove the --do_eval argument."
        )

-     if (
-         os.path.exists(training_args.output_dir)
-         and os.listdir(training_args.output_dir)
-         and training_args.do_train
-         and not training_args.overwrite_output_dir
-     ):
-         raise ValueError(
-             f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
-         )
+     if not is_sagemaker:
+         if (
+             os.path.exists(training_args.output_dir)
+             and os.listdir(training_args.output_dir)
+             and training_args.do_train
+             and not training_args.overwrite_output_dir
+         ):
+             raise ValueError(
+                 f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
+             )

    # Setup logging
    logging.basicConfig(
@@ -164,91 +222,59 @@ def main():
    # Set seed
    set_seed(training_args.seed)

-     # Load pretrained model and tokenizer
-     #
-     # Distributed training:
-     # The .from_pretrained methods guarantee that only one local process can concurrently
-     # download model & vocab.
-
-     # if model_args.config_name:
-     #     config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
-     # elif model_args.model_name_or_path:
-     #     config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
-     # else:
-     #     config = CONFIG_MAPPING[model_args.model_type]()
-     #     logger.warning("You are instantiating a new config instance from scratch.")
-     #
-     # if model_args.tokenizer_name:
-     #     tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, cache_dir=model_args.cache_dir)
-     # elif model_args.model_name_or_path:
-     #     tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
-     # else:
-     #     raise ValueError(
-     #         "You are instantiating a new tokenizer from scratch. This is not supported, but you can do it from another script, save it,"
-     #         "and load it from here, using --tokenizer_name"
-     #     )
-     #
-     # if model_args.model_name_or_path:
-     #     model = AutoModelWithLMHead.from_pretrained(
-     #         model_args.model_name_or_path,
-     #         from_tf=bool(".ckpt" in model_args.model_name_or_path),
-     #         config=config,
-     #         cache_dir=model_args.cache_dir,
-     #     )
-     # else:
-     #     logger.info("Training new model from scratch")
-     #     model = AutoModelWithLMHead.from_config(config)
-     #
-     # model.resize_token_embeddings(len(tokenizer))
-     #
-     # if config.model_type in ["bert", "roberta", "distilbert", "camembert"] and not data_args.mlm:
-     #     raise ValueError(
-     #         "BERT and RoBERTa-like models do not have LM heads but masked LM heads. They must be run using the --mlm "
-     #         "flag (masked language modeling)."
-     #     )
-
210225 """Initialize models and tokenizer"""
211226 if model_args .tokenizer_name :
212- tokenizer = AutoTokenizer .from_pretrained (model_args .tokenizer_name , cache_dir = model_args .cache_dir )
227+ tokenizer = AutoTokenizer .from_pretrained (model_args .tokenizer_name , cache_dir = model_args .cache_dir , use_fast = False )
213228 elif model_args .model_name_or_path :
214- tokenizer = AutoTokenizer .from_pretrained (model_args .model_name_or_path , cache_dir = model_args .cache_dir )
229+ tokenizer = AutoTokenizer .from_pretrained (model_args .model_name_or_path , cache_dir = model_args .cache_dir , use_fast = False )
215230 else :
216231 raise ValueError (
217232 "You are instantiating a new tokenizer from scratch. This is not supported, but you can do it from another script, save it,"
218233 "and load it from here, using --tokenizer_name"
219234 )
220235 tokenizer .add_special_tokens ({"additional_special_tokens" : ["<col>" ]})
-     model = TaBARTModel()
+
+     model = LogicalTaBARTModel(data_args.task_names)
    model.bert.resize_token_embeddings(len(tokenizer))
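+     # Resize the second (text-to-SQL) BART as well, then tie its shared and encoder input
+     # embedding weights to the main model so both tasks use a single embedding table.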
+     model.bert_for_texttosql.resize_token_embeddings(len(tokenizer))
+     model.bert.model.shared.weight = model.bert_for_texttosql.model.shared.weight
+     model.bert.model.encoder.embed_tokens.weight = model.bert_for_texttosql.model.encoder.embed_tokens.weight

    if training_args.do_eval and not training_args.do_train:
        model_param = torch.load(os.path.join(model_args.model_name_or_path, "pytorch_model.bin"))
        model.load_state_dict(model_param)
        print("All keys matched and loaded successfully.")

    if data_args.block_size <= 0:
-         data_args.block_size = tokenizer.max_len
+         data_args.block_size = tokenizer.model_max_length
        # Our input block size will be the max possible for the model
    else:
-         data_args.block_size = min(data_args.block_size, tokenizer.max_len)
+         data_args.block_size = min(data_args.block_size, tokenizer.model_max_length)

    # Get datasets

-     train_dataset = get_dataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
-     eval_dataset = get_dataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
+     train_datasets = get_datasets(model_args.pretraining_model, data_args, tokenizer=tokenizer) if training_args.do_train else None
+     eval_datasets = get_datasets(model_args.pretraining_model, data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
    # data_collator = DataCollatorForLanguageModeling(
    #     tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability
    # )
-     data_collator = DataCollatorForTaBART(tokenizer=tokenizer, task=model_args.task)
-
-     match_sequence_scorer = MatchSequenceScorer(bos_id=data_collator.label_bos_id, eos_id=data_collator.label_eos_id, output_path=os.path.join(training_args.output_dir, "eval_dump.json"))
+     data_collators = get_data_collators(model_args.pretraining_model, data_args, tokenizer=tokenizer)
+
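+     # All task collators must agree on the label EOS id, since a single scorer is shared across tasks.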
+     eos_id = None
+     for data_collator in data_collators:
+         if eos_id is None:
+             eos_id = data_collator.label_eos_id
+         else:
+             assert eos_id == data_collator.label_eos_id
+     match_sequence_scorer = MatchSequenceScorer(
+         eos_id=eos_id, output_path=os.path.join(training_args.output_dir, "eval_dump.json"))
    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
-         data_collator=data_collator,
-         train_dataset=train_dataset,
-         eval_dataset=eval_dataset,
+         data_collators=data_collators,
+         train_datasets=train_datasets,
+         eval_datasets=eval_datasets,
        compute_metrics=match_sequence_scorer
    )
