-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_sft.py
More file actions
118 lines (101 loc) · 3.45 KB
/
train_sft.py
File metadata and controls
118 lines (101 loc) · 3.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""Performs SFT on a specified model"""
from trl import SFTTrainer
import utils
from transformers import (
BitsAndBytesConfig,
TrainingArguments,
)
import yaml
import getpass
import wandb
import torch as t
from typing import Dict, Any
# Select GPU if available.
# NOTE(review): `device` is not referenced anywhere else in this file —
# confirm it is used by importers before removing.
device = t.device("cuda" if t.cuda.is_available() else "cpu")
def setup_logging(hps: Dict[str, Any]):
    """Prepare the run's log directory, persist hyperparameters, and start wandb.

    Mutates *hps* in place (adds "logdir", "user", a derived run name, and
    extra tags), dumps the resulting dict to ``<logdir>/hps.yaml``, and — when
    not in debug mode — initializes a wandb run. Returns the log directory.

    NOTE(review): the directory path reads ``hps['dataset_name']`` while the
    tags read ``hps['dataset']['name']`` — confirm both keys exist in the
    hyperparameter YAML.
    """
    # Choose logging and checkpoint saving directory.
    base_dir = f"{utils.run_dir}/{hps['dataset_name']}/training/{hps['training_algorithm']}"
    logdir = utils.choose_log_dir(base_dir, debug=hps["debug"])

    # Record run metadata on the hps object before saving it to disk.
    run_name = "/".join(logdir.split("/")[-2:])
    hps["logdir"] = logdir
    hps["training_kwargs"]["run_name"] = run_name
    hps["user"] = getpass.getuser()
    hps["tags"] += [hps["dataset"]["name"], "training", hps["training_algorithm"]]

    with open(f"{logdir}/hps.yaml", "w") as f:
        yaml.dump(hps, f)

    # Debug runs skip wandb entirely.
    if not hps["debug"]:
        wandb.init(
            project="dpo_rlhf_generalization",
            dir=logdir,
            name=run_name,
            config=utils.wandb_configify(hps),
            tags=hps["tags"],
            save_code=True,
            settings=wandb.Settings(code_dir="."),
        )

    print(f"Hyperparameters:\n{hps}\n")
    return logdir
def main():
    """Run supervised fine-tuning (SFT) on the model named in the hyperparameter file."""
    # Load hyperparameters from the YAML file given on the command line.
    cli_args = utils.argparser().parse_args()
    with open(cli_args.hyperparam_file) as f:
        hps = yaml.load(f, Loader=yaml.FullLoader)

    # Mutates hps (adds logdir/run_name/etc.); the returned logdir is unused here.
    setup_logging(hps)

    # 4-bit NF4 quantization (QLoRA-style) so the base model fits in memory.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=t.bfloat16,
    )

    # Load the quantized base model and its tokenizer.
    tokenizer, model = utils.load_model(
        hps["model"],
        reward_model=False,
        eval=False,
        quantized=True,
        bnb_config=bnb_config,
    )

    # NOTE(review): debug is hard-coded to False here even though hps["debug"]
    # exists and drives logging — confirm this is intentional.
    dataset = utils.load_dataset(tokenizer, **hps["dataset"], debug=False, sft=True)

    # Bug fix: the original rebound `args` (the parsed CLI namespace) to the
    # TrainingArguments object; a distinct name avoids the shadowing hazard.
    training_args = TrainingArguments(
        output_dir="sft_model_instruct",  # directory to save and repository id
        num_train_epochs=2,  # number of training epochs
        per_device_train_batch_size=4,  # batch size per device during training
        gradient_accumulation_steps=2,  # steps before a backward/update pass
        gradient_checkpointing=True,  # use gradient checkpointing to save memory
        optim="adamw_torch_fused",  # use fused adamw optimizer
        logging_steps=10,  # log every 10 steps
        save_strategy="epoch",  # save checkpoint every epoch
        learning_rate=2e-4,  # learning rate, based on QLoRA paper
        bf16=True,  # use bfloat16 precision
        tf32=True,  # use tf32 precision
        max_grad_norm=0.3,  # max gradient norm based on QLoRA paper
        warmup_ratio=0.03,  # warmup ratio based on QLoRA paper
    )

    trainer = SFTTrainer(
        model,
        tokenizer=tokenizer,
        train_dataset=dataset["train"].select(range(20000)),
        args=training_args,
        dataset_text_field="prompt",
        dataset_batch_size=1000,
        eval_dataset=dataset["test"],
        peft_config=hps["peft_config_class"](**hps["peft_config_kwargs"]),
        packing=True,
        max_seq_length=3072,
    )
    trainer.train()
    trainer.save_model()
    # wandb.finish() is a no-op when wandb.init was skipped (debug mode).
    wandb.finish()

    # Release GPU memory so subsequent work in this process starts clean.
    del model
    del trainer
    t.cuda.empty_cache()


if __name__ == "__main__":
    main()