# config.py (forked from rui-ye/OpenFedLLM)
import json
import os
from dataclasses import dataclass, field, asdict
from typing import Optional

import torch
from accelerate import Accelerator
from peft import LoraConfig
from transformers import HfArgumentParser, BitsAndBytesConfig
from trl import DPOConfig


# Define and parse arguments.
@dataclass
class FedArguments:
    fed_alg: Optional[str] = field(default="fedavg", metadata={"help": "the algorithm to use"})
    num_rounds: Optional[int] = field(default=500, metadata={"help": "the number of rounds"})
    num_clients: Optional[int] = field(default=10, metadata={"help": "the number of clients"})
    sample_clients: Optional[int] = field(default=2, metadata={"help": "the number of clients to sample"})
    split_strategy: Optional[str] = field(default="iid", metadata={"help": "the split strategy"})
    prox_mu: Optional[float] = field(default=0.01, metadata={"help": "the mu parameter of FedProx"})
    fedopt_tau: Optional[float] = field(default=1e-3, metadata={"help": "the tau parameter of FedAdagrad, FedYogi and FedAdam"})
    fedopt_eta: Optional[float] = field(default=1e-3, metadata={"help": "the global learning rate parameter of FedAdagrad, FedYogi and FedAdam"})
    fedopt_beta1: Optional[float] = field(default=0.9, metadata={"help": "the beta1 parameter of FedYogi and FedAdam"})
    fedopt_beta2: Optional[float] = field(default=0.99, metadata={"help": "the beta2 parameter of FedYogi and FedAdam"})
    save_model_freq: Optional[int] = field(default=10, metadata={"help": "the frequency, in rounds, at which to save the model; e.g., 10 means save every 10 rounds"})
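
# Example (a sketch): these options map to CLI flags via HfArgumentParser, e.g.
#   --fed_alg fedprox --prox_mu 0.1
#   --fed_alg fedadam --fedopt_eta 1e-3 --fedopt_beta1 0.9 --fedopt_beta2 0.99
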
@dataclass
class ScriptArguments:
    model_name_or_path: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the model name"})
    dataset_name: Optional[str] = field(
        default="lucasmccabe-lmi/CodeAlpaca-20k", metadata={"help": "the dataset name"}
    )
    log_with: Optional[str] = field(default="none", metadata={"help": "use 'wandb' to log with wandb"})
    learning_rate: Optional[float] = field(default=2e-5, metadata={"help": "the learning rate"})  # vicuna and alpaca use 2e-5
    batch_size: Optional[int] = field(default=16, metadata={"help": "the batch size"})
    seq_length: Optional[int] = field(default=512, metadata={"help": "Input sequence length"})
    gradient_accumulation_steps: Optional[int] = field(
        default=1, metadata={"help": "the number of gradient accumulation steps"}
    )
    load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8 bits precision"})
    load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4 bits precision"})
    use_peft: Optional[bool] = field(default=False, metadata={"help": "Whether to use PEFT or not to train adapters"})
    trust_remote_code: Optional[bool] = field(default=False, metadata={"help": "Enable `trust_remote_code`"})
    output_dir: Optional[str] = field(default="output", metadata={"help": "the output directory"})
    peft_lora_r: Optional[int] = field(default=8, metadata={"help": "the r parameter of the LoRA adapters"})
    peft_lora_alpha: Optional[int] = field(default=16, metadata={"help": "the alpha parameter of the LoRA adapters"})
    logging_steps: Optional[int] = field(default=100, metadata={"help": "the number of logging steps"})
    use_auth_token: Optional[bool] = field(default=False, metadata={"help": "Use HF auth token to access the model"})  # token and use_auth_token cannot be used together
    num_train_epochs: Optional[int] = field(default=5, metadata={"help": "the number of training epochs"})  # overridden by max_steps: if max_steps is set, each round trains for one epoch and the number of sampled examples is max_steps * batch_size * gradient_accumulation_steps
    max_steps: Optional[int] = field(default=-1, metadata={"help": "the number of training steps"})
    save_strategy: Optional[str] = field(default="epoch")
    save_steps: Optional[int] = field(
        default=1000, metadata={"help": "Number of update steps between two checkpoint saves"}
    )
    save_total_limit: Optional[int] = field(default=3, metadata={"help": "Limits total number of checkpoints."})
    push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the model to HF Hub"})
    hub_model_id: Optional[str] = field(default=None, metadata={"help": "The name of the model on HF Hub"})
    gradient_checkpointing: Optional[bool] = field(default=True, metadata={"help": "Enable gradient checkpointing"})
    template: Optional[str] = field(default="alpaca", metadata={"help": "the template to use"})
    seed: Optional[int] = field(default=2023, metadata={"help": "the seed to use"})
    dpo_beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter of DPO"})
    dataset_sample: Optional[int] = field(default=20000, metadata={"help": "the number of samples to use from the dataset"})
    local_data_dir: Optional[str] = field(default=None, metadata={"help": "the local data directory if you want to use downloaded data"})
    unsloth: Optional[int] = field(default=1)
    bf16: Optional[int] = field(default=0)
    fp16: Optional[int] = field(default=0)
    online_dataset: Optional[int] = field(default=0)
    full_data: Optional[int] = field(default=0)
    still_contain_base_data: Optional[int] = field(default=0)
    prompt_num: Optional[int] = field(default=2)
    response_num: Optional[int] = field(default=4)
    generate_data_path: Optional[str] = field(default="/mnt/bn/merlin-datavolume-tsy/leon/datasets/self-rewarding/")
    use_vllm: Optional[int] = field(default=0)
    rank_net: Optional[int] = field(default=0)
    self_reward_epochs: Optional[int] = field(default=1)
    unsloth_vllm: Optional[int] = field(default=0)


parser = HfArgumentParser((ScriptArguments, FedArguments))
script_args, fed_args = parser.parse_args_into_dataclasses()
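
# Example invocation (a sketch; the entry-point script name is an assumption,
# the flags mirror the dataclass fields above):
#
#   python main_dpo.py --model_name_or_path meta-llama/Llama-2-7b-hf \
#       --fed_alg fedavg --num_rounds 200 --num_clients 10 --sample_clients 2 \
#       --batch_size 16 --gradient_accumulation_steps 1 --max_steps 10 \
#       --use_peft True --load_in_4bit True --bf16 1
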
# script_args, fed_args = parser.parse_yaml_file("/opt/tiger/OpenFedLLM/config/train.yaml") #For debug

# ===== Define the LoraConfig =====
if script_args.use_peft:
    peft_config = LoraConfig(
        r=script_args.peft_lora_r,
        lora_alpha=script_args.peft_lora_alpha,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
else:
    peft_config = None
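
# Usage sketch (assumption: the training script attaches the adapter itself,
# e.g. with peft.get_peft_model, or passes peft_config straight to a trl trainer):
#
#   from peft import get_peft_model
#   model = get_peft_model(model, peft_config)   # `model` is a loaded causal LM
#   model.print_trainable_parameters()
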
def get_config():
    return script_args, fed_args, peft_config


# ===== Define the training arguments =====
def get_training_args(script_args, new_lr):
    print(f"is bf16: {script_args.bf16}")
    print(f"is fp16: {script_args.fp16}")
    training_args = DPOConfig(
        output_dir=script_args.output_dir,
        per_device_train_batch_size=script_args.batch_size,
        gradient_accumulation_steps=script_args.gradient_accumulation_steps,
        learning_rate=new_lr,
        logging_steps=script_args.logging_steps,
        num_train_epochs=script_args.num_train_epochs,
        max_steps=script_args.max_steps,
        report_to=script_args.log_with,
        save_strategy=script_args.save_strategy,
        save_steps=script_args.save_steps,
        save_total_limit=script_args.save_total_limit,
        push_to_hub=script_args.push_to_hub,
        hub_model_id=script_args.hub_model_id,
        gradient_checkpointing=script_args.gradient_checkpointing,
        lr_scheduler_type="constant",
        # the 0/1 int flags are cast to the booleans the config expects
        bf16=bool(script_args.bf16),
        fp16=bool(script_args.fp16),
    )
    return training_args
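
# Usage sketch (assumption: the per-round loop elsewhere in the repo feeds this
# config to a trl DPOTrainer; `model`, `ref_model`, `tokenizer`, and
# `train_dataset` are hypothetical names, and depending on the trl version the
# DPO beta is read from DPOConfig or passed to the trainer directly):
#
#   training_args = get_training_args(script_args, new_lr=script_args.learning_rate)
#   trainer = DPOTrainer(model=model, ref_model=ref_model, args=training_args,
#                        train_dataset=train_dataset, tokenizer=tokenizer)
#   trainer.train()
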
def get_model_config(script_args):
    if script_args.load_in_8bit and script_args.load_in_4bit:
        raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")
    elif script_args.load_in_8bit:
        quantization_config = BitsAndBytesConfig(
            load_in_8bit=script_args.load_in_8bit
        )
        # Copy the model to each device
        device_map = {"": Accelerator().local_process_index}
        torch_dtype = torch.bfloat16
    elif script_args.load_in_4bit:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=script_args.load_in_4bit,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        # Copy the model to each device
        device_map = {"": Accelerator().local_process_index}
        torch_dtype = torch.bfloat16
    else:
        device_map = None
        quantization_config = None
        torch_dtype = None
    return device_map, quantization_config, torch_dtype
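
# Usage sketch (assuming a standard transformers causal-LM load; variable names
# are illustrative only):
#
#   from transformers import AutoModelForCausalLM
#   device_map, quantization_config, torch_dtype = get_model_config(script_args)
#   model = AutoModelForCausalLM.from_pretrained(
#       script_args.model_name_or_path,
#       quantization_config=quantization_config,
#       device_map=device_map,
#       torch_dtype=torch_dtype,
#       trust_remote_code=script_args.trust_remote_code,
#   )
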
def save_config(script_args, fed_args):
    dataset_name_split = os.path.basename(script_args.dataset_name)
    output_dir = (
        f"{script_args.output_dir}/{dataset_name_split}_{fed_args.fed_alg}"
        f"_c{fed_args.num_clients}s{fed_args.sample_clients}"
        f"_i{script_args.max_steps}_b{script_args.batch_size}a{script_args.gradient_accumulation_steps}"
        f"_l{script_args.seq_length}_r{script_args.peft_lora_r}a{script_args.peft_lora_alpha}"
        f"_ranknet{script_args.rank_net}_unslothvllm{script_args.unsloth_vllm}"
    )
    os.makedirs(output_dir, exist_ok=True)
    script_args.output_dir = output_dir
    with open(os.path.join(script_args.output_dir, "args.json"), "w") as f:
        combined_dict = {
            "script_args": asdict(script_args),
            "fed_args": asdict(fed_args),
        }
        json.dump(combined_dict, f, indent=4)
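

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original training flow):
    # parse the default CLI arguments and write the combined config to disk.
    s_args, f_args, _peft_cfg = get_config()
    save_config(s_args, f_args)
    print(f"Config written to {os.path.join(s_args.output_dir, 'args.json')}")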