
Commit 83af3f0

Impaviditypnpnpn authored and committed
cleanup the codebase
1 parent c90d4e0 commit 83af3f0

13 files changed

Lines changed: 2148 additions & 146 deletions

Lines changed: 101 additions & 75 deletions
@@ -26,27 +26,27 @@
 from dataclasses import dataclass, field
 from typing import Optional
 import torch
-
 from transformers import (
     MODEL_WITH_LM_HEAD_MAPPING,
     AutoTokenizer,
     HfArgumentParser,
     PreTrainedTokenizer,
     set_seed,
 )
-from relogic.pretrainkit.trainer import Trainer
+from relogic.pretrainkit.multitask_trainer import Trainer
 from relogic.pretrainkit.datasets.semparse.tabart import DataCollatorForTaBART, TaBARTDataset
+from relogic.pretrainkit.datasets.semparse.text2sql import DataCollatorForQuerySchema2SQL, QuerySchema2SQLDataset
 from relogic.pretrainkit.scorers.match_sequence import MatchSequenceScorer
-from relogic.pretrainkit.models.semparse.tabart import TaBARTModel
+from relogic.pretrainkit.models.semparse.logical_tabart import LogicalTaBARTModel
 from relogic.pretrainkit.training_args import TrainingArguments
-import relogic.utils.crash_on_ipy

 logger = logging.getLogger(__name__)


 MODEL_CONFIG_CLASSES = list(MODEL_WITH_LM_HEAD_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

+is_sagemaker = 'SM_MODEL_DIR' in os.environ

 @dataclass
 class ModelArguments:
@@ -73,17 +73,26 @@ class ModelArguments:
     cache_dir: Optional[str] = field(
         default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
     )
-    task: Optional[str] = field(
-        default="mlm", metadata={"help": "Learning target. mlm, col_pred, mlm+col_pred"}
+    pretraining_model: Optional[str] = field(
+        default=None, metadata={"help": "What is the model to use for pretraining."}
+    )
+    load_from_pretrained_ckpt: Optional[str] = field(
+        default=None, metadata={"help": "Initialize the model with pretrained checkpoint"}
+    )
+    pretrained_ckpt_dir: Optional[str] = field(
+        default="pretrained_checkpoint", metadata={"help": "Pretrained Checkpoint"}
     )


+
 @dataclass
 class DataTrainingArguments:
     """
     Arguments pertaining to what data we are going to input our model for training and eval.
     """
-
+    task_names: Optional[str] = field(
+        default=None, metadata={"help": "The name of tasks which are separated by ,"}
+    )
     train_data_file: Optional[str] = field(
         default=None, metadata={"help": "The input training data file (a text file)."}
     )
@@ -114,11 +123,47 @@ class DataTrainingArguments:
     overwrite_cache: bool = field(
         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
     )
+    not_use_text: bool = field(
+        default=False, metadata={"help": "To use text in pretraining or not"}
+    )
+    only_use_text: bool = field(
+        default=False, metadata={"help": "To only use text in pretraining or not"}
+    )
+    cross_lingual: bool = field(
+        default=False,
+        metadata={"help": "Whether to use Cross-lingual Tabart Training"},
+    )
+    dump_file_name: str = field(
+        default="eval_dump.json",
+        metadata={"help": "The file name of evaluation dumping."}
+    )

+def get_dataset_by_name(pretraining_model, task_name, cross_lingual, tokenizer, file_path, use_text, only_use_text):
+    if task_name != "text2sql":
+        return TaBARTDataset(tokenizer=tokenizer, file_path=file_path, col_token="<col>",
+                             task_name=task_name, use_text=use_text, only_use_text=only_use_text)
+    if task_name == "text2sql":
+        return QuerySchema2SQLDataset(tokenizer=tokenizer, file_path=file_path, task_name=task_name)

-def get_dataset(args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, evaluate=False):
-    file_path = args.eval_data_file if evaluate else args.train_data_file
-    return TaBARTDataset(tokenizer=tokenizer, file_path=file_path, col_token="<col>")
+
+def get_datasets(pretraining_model, args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, evaluate=False):
+    file_paths = args.eval_data_file.split(",") if evaluate else args.train_data_file.split(",")
+    task_names = args.task_names.split(",")
+    datasets = [get_dataset_by_name(pretraining_model, task_name, args.cross_lingual, tokenizer, file_path, not args.not_use_text, args.only_use_text)
+                for task_name, file_path in zip(task_names, file_paths)]
+    return datasets
+
+def get_data_collator_by_name(pretraining_model, task_name, cross_lingual, tokenizer):
+    if task_name != "text2sql":
+        return DataCollatorForTaBART(tokenizer=tokenizer, task=task_name, col_token="<col>")
+    if task_name == "text2sql":
+        return DataCollatorForQuerySchema2SQL(tokenizer=tokenizer)
+
+
+def get_data_collators(pretraining_model, args: DataTrainingArguments, tokenizer: PreTrainedTokenizer):
+    task_names = args.task_names.split(",")
+    collators = [get_data_collator_by_name(pretraining_model, task_name, args.cross_lingual, tokenizer) for task_name in task_names]
+    return collators


 def main():
@@ -129,21 +174,34 @@ def main():
     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
     model_args, data_args, training_args = parser.parse_args_into_dataclasses()

+    if is_sagemaker:
+        training_args.do_train = training_args.do_train_str == "True"
+        training_args.do_eval = training_args.do_eval_str == "True"
+        training_args.evaluate_during_training = training_args.evaluate_during_training_str == "True"
+        data_args.train_data_file = ",".join([os.path.join(os.environ['SM_CHANNEL_TRAIN'], item) for item in data_args.train_data_file.split(",")])
+        data_args.eval_data_file = ",".join([os.path.join(os.environ['SM_CHANNEL_TRAIN'], item) for item in data_args.eval_data_file.split(",")])
+        training_args.output_dir = os.environ['SM_MODEL_DIR']
+        model_args.pretrained_ckpt_dir = os.environ.get("SM_CHANNEL_PRETRAINED_CKPT_DIR", None)
+
+    if model_args.pretrained_ckpt_dir is not None and model_args.load_from_pretrained_ckpt is not None:
+        model_args.load_from_pretrained_ckpt = os.path.join(model_args.pretrained_ckpt_dir, model_args.load_from_pretrained_ckpt)
+
     if data_args.eval_data_file is None and training_args.do_eval:
         raise ValueError(
             "Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
             "or remove the --do_eval argument."
         )

-    if (
-        os.path.exists(training_args.output_dir)
-        and os.listdir(training_args.output_dir)
-        and training_args.do_train
-        and not training_args.overwrite_output_dir
-    ):
-        raise ValueError(
-            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
-        )
+    if not is_sagemaker:
+        if (
+            os.path.exists(training_args.output_dir)
+            and os.listdir(training_args.output_dir)
+            and training_args.do_train
+            and not training_args.overwrite_output_dir
+        ):
+            raise ValueError(
+                f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
+            )

     # Setup logging
     logging.basicConfig(
@@ -164,91 +222,59 @@ def main():
     # Set seed
     set_seed(training_args.seed)

-    # Load pretrained model and tokenizer
-    #
-    # Distributed training:
-    # The .from_pretrained methods guarantee that only one local process can concurrently
-    # download model & vocab.
-
-    # if model_args.config_name:
-    #     config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
-    # elif model_args.model_name_or_path:
-    #     config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
-    # else:
-    #     config = CONFIG_MAPPING[model_args.model_type]()
-    #     logger.warning("You are instantiating a new config instance from scratch.")
-    #
-    # if model_args.tokenizer_name:
-    #     tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, cache_dir=model_args.cache_dir)
-    # elif model_args.model_name_or_path:
-    #     tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
-    # else:
-    #     raise ValueError(
-    #         "You are instantiating a new tokenizer from scratch. This is not supported, but you can do it from another script, save it,"
-    #         "and load it from here, using --tokenizer_name"
-    #     )
-    #
-    # if model_args.model_name_or_path:
-    #     model = AutoModelWithLMHead.from_pretrained(
-    #         model_args.model_name_or_path,
-    #         from_tf=bool(".ckpt" in model_args.model_name_or_path),
-    #         config=config,
-    #         cache_dir=model_args.cache_dir,
-    #     )
-    # else:
-    #     logger.info("Training new model from scratch")
-    #     model = AutoModelWithLMHead.from_config(config)
-    #
-    # model.resize_token_embeddings(len(tokenizer))
-    #
-    # if config.model_type in ["bert", "roberta", "distilbert", "camembert"] and not data_args.mlm:
-    #     raise ValueError(
-    #         "BERT and RoBERTa-like models do not have LM heads but masked LM heads. They must be run using the --mlm "
-    #         "flag (masked language modeling)."
-    #     )
-
     """Initialize models and tokenizer"""
     if model_args.tokenizer_name:
-        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, cache_dir=model_args.cache_dir)
+        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=False)
     elif model_args.model_name_or_path:
-        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
+        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=False)
     else:
         raise ValueError(
             "You are instantiating a new tokenizer from scratch. This is not supported, but you can do it from another script, save it,"
             "and load it from here, using --tokenizer_name"
         )
     tokenizer.add_special_tokens({"additional_special_tokens": ["<col>"]})
-    model = TaBARTModel()
+
+    model = LogicalTaBARTModel(data_args.task_names)
     model.bert.resize_token_embeddings(len(tokenizer))
+    model.bert_for_texttosql.resize_token_embeddings(len(tokenizer))
+    model.bert.model.shared.weight = model.bert_for_texttosql.model.shared.weight
+    model.bert.model.encoder.embed_tokens.weight = model.bert_for_texttosql.model.encoder.embed_tokens.weight

     if training_args.do_eval and not training_args.do_train:
         model_param = torch.load(os.path.join(model_args.model_name_or_path, "pytorch_model.bin"))
         model.load_state_dict(model_param)
         print("All key matched and load successfully.")

     if data_args.block_size <= 0:
-        data_args.block_size = tokenizer.max_len
+        data_args.block_size = tokenizer.model_max_length
         # Our input block size will be the max possible for the model
     else:
-        data_args.block_size = min(data_args.block_size, tokenizer.max_len)
+        data_args.block_size = min(data_args.block_size, tokenizer.model_max_length)

     # Get datasets

-    train_dataset = get_dataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
-    eval_dataset = get_dataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
+    train_datasets = get_datasets(model_args.pretraining_model, data_args, tokenizer=tokenizer) if training_args.do_train else None
+    eval_datasets = get_datasets(model_args.pretraining_model, data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
     # data_collator = DataCollatorForLanguageModeling(
     #     tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability
     # )
-    data_collator = DataCollatorForTaBART(tokenizer=tokenizer, task=model_args.task)
-
-    match_sequence_scorer = MatchSequenceScorer(bos_id=data_collator.label_bos_id, eos_id=data_collator.label_eos_id, output_path=os.path.join(training_args.output_dir, "eval_dump.json"))
+    data_collators = get_data_collators(model_args.pretraining_model, data_args, tokenizer=tokenizer)
+
+    eos_id = None
+    for data_collator in data_collators:
+        if eos_id is None:
+            eos_id = data_collator.label_eos_id
+        else:
+            assert eos_id == data_collator.label_eos_id
+    match_sequence_scorer = MatchSequenceScorer(
+        eos_id=eos_id, output_path=os.path.join(training_args.output_dir, "eval_dump.json"))
     # Initialize our Trainer
     trainer = Trainer(
         model=model,
         args=training_args,
-        data_collator=data_collator,
-        train_dataset=train_dataset,
-        eval_dataset=eval_dataset,
+        data_collators=data_collators,
+        train_datasets=train_datasets,
+        eval_datasets=eval_datasets,
         compute_metrics=match_sequence_scorer
     )
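A note on how the new multi-task arguments line up: get_datasets() and get_data_collators() split --task_names and the data-file arguments on commas and pair them positionally, with "text2sql" routed to QuerySchema2SQLDataset and every other task name to TaBARTDataset. A minimal sketch of that pairing follows; the script name and data-file names are hypothetical placeholders, not taken from the commit.

# Hypothetical invocation of the updated training script (names are placeholders):
#   python run_pretraining.py \
#     --task_names mlm,text2sql \
#     --train_data_file tables.jsonl,spider_train.json \
#     --eval_data_file tables_dev.jsonl,spider_dev.json

# Mirrors the zip() inside get_datasets(): the i-th task reads the i-th file.
task_names = "mlm,text2sql".split(",")
file_paths = "tables.jsonl,spider_train.json".split(",")
for task_name, file_path in zip(task_names, file_paths):
    # "text2sql" would get a QuerySchema2SQLDataset; other task names get a TaBARTDataset.
    print(f"{task_name} -> {file_path}")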

relogic/logickit/__init__.py

Whitespace-only changes.

relogic/logickit/base/__init__.py

Whitespace-only changes.

relogic/logickit/base/utils.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+import sys
+import os
+try:
+  import cPickle as pickle
+except ImportError:
+  import pickle
+
+
+class Memoize(object):
+  def __init__(self, f):
+    self.f = f
+    self.cache = {}
+
+  def __call__(self, *args):
+    if args not in self.cache:
+      self.cache[args] = self.f(*args)
+    return self.cache[args]
+
+def load_pickle(path, memoized=True):
+  return _load_pickle_memoize(path) if memoized else _load_pickle(path)
+
+def _load_pickle(path):
+  with open(path, 'rb') as f:
+    return pickle.load(f)
+
+@Memoize
+def _load_pickle_memoize(path):
+  return _load_pickle(path)
+
+
+def write_pickle(o, path):
+  dir = path.rsplit('/', 1)[0]
+  if not os.path.exists(dir):
+    os.mkdir(dir)
+  with open(path, 'wb') as f:
+    pickle.dump(o, f, -1)
+
+def log(*args):
+  msg = ' '.join(map(str, args))
+  sys.stdout.write(msg + '\n')
+  sys.stdout.flush()
+
+
+def heading(*args):
+  log()
+  log(80 * '=')
+  log(*args)
+  log(80 * '=')
+
+
+import torch
+def print_rank_0(message, **kwargs):
+  """If distributed is initialized print only on rank 0."""
+  # if torch.distributed.is_initialized():
+  #   if torch.distributed.get_rank() == 0:
+  #     print(message, flush=True, **kwargs)
+  # else:
+  #   print(message, flush=True, **kwargs)
+  print(message, flush=True, **kwargs)
+def is_rank_0():
+  # if torch.distributed.is_initialized():
+  #   if torch.distributed.get_rank() == 0:
+  #     return True
+  #   else:
+  #     return True
+  #   return False
+  return True
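For the new relogic/logickit/base/utils.py above, a short usage sketch of the pickle helpers; the path and payload below are made up for illustration. load_pickle() is memoized on the path argument, so repeated loads return the same cached object.

from relogic.logickit.base.utils import write_pickle, load_pickle, heading

# Hypothetical path and payload; write_pickle creates the parent directory if it is missing.
write_pickle({"vocab_size": 50265}, "/tmp/relogic_demo/config.pkl")

cfg = load_pickle("/tmp/relogic_demo/config.pkl")        # first call reads from disk
cfg_again = load_pickle("/tmp/relogic_demo/config.pkl")  # served from the Memoize cache

heading("loaded config")   # prints the message between 80-character '=' banners
print(cfg is cfg_again)    # True: both calls return the same cached object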

relogic/logickit/modules/__init__.py

Whitespace-only changes.

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+from relogic.logickit.modules.span_extractors.endpoint_span_extractor import EndpointSpanExtractor
+from relogic.logickit.modules.span_extractors.self_attentive_span_extractor import SelfAttentiveSpanExtractor
+from relogic.logickit.modules.span_extractors.attentive_span_extractor import AttentiveSpanExtractor
