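"""Train a causal LM (Llama) for lexical relation classification (LRC).

The tuning mode is selected with --peft: full fine-tuning (FT), LoRA (LORA),
prefix tuning (PT2), or the graph-based GET model.
"""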
import torch
import argparse
from transformers import AutoTokenizer, set_seed, LlamaForCausalLM
from model.pt2 import PT2PEFTModel
from model.get import GET
from utils.data_utils import load_causal_lm_language_lrc, load_kg_causal_lm_language_lrc
from utils.graph_utils import create_heterogeneous_entity_graph, create_knowledge_embeddings
from utils.utils import Trainer
from peft import LoraConfig, get_peft_model
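# model.* and utils.* are local modules in this repository, not pip packages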
# prefer the first GPU; fall back to CPU when CUDA is unavailable
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
ap = argparse.ArgumentParser(description='LRC learning')
ap.add_argument('--dataset_path', type=str)
ap.add_argument('--dataset', type=str, default="EVALution")
ap.add_argument('--plm_path', type=str)
ap.add_argument('--peft', type=str, default="FT") # FT, LORA, PT2, GET
ap.add_argument('--lora_alpha', type=int, default=16)
ap.add_argument('--lora_dropout', type=float, default=0.1)
ap.add_argument('--lora_r', type=int, default=64)
ap.add_argument("--batch_size", type=int, default=32)
ap.add_argument("--epoch", type=int, default=10)
ap.add_argument("--warm_up_rate", type=float, default=0.2)
ap.add_argument("--lr", type=float, default=2e-5)
ap.add_argument("--lr_min", type=float, default=5e-6)
ap.add_argument("--pre_seq_len", type=int, default=20)
ap.add_argument("--prefix_hidden_size", type=int, default=0)
ap.add_argument("--inference_batch", type=int, default=1024)
ap.add_argument("--gnn_dim", type=int, default=128)
ap.add_argument("--num_gnn_layers", type=int, default=3)
ap.add_argument("--num_gnn_heads", type=int, default=4)
ap.add_argument("--attn_drop", type=float, default=0.)
ap.add_argument("--block_size", type=int, default=10)
args = ap.parse_args()
dataset_name = args.dataset
plm_path = args.plm_path
peft = args.peft
# echo the parsed configuration for logging/reproducibility
for k, v in sorted(vars(args).items()):
    print(k, "=", v)
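# Example invocation (paths are hypothetical; adjust for your environment):
#   python train_causal.py --dataset_path ./data --dataset EVALution \
#       --plm_path /path/to/llama-2-7b --peft LORA --batch_size 32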
if __name__ == "__main__":
    set_seed(42)
    tokenizer = AutoTokenizer.from_pretrained(plm_path, trust_remote_code=True)
    # load dataset
    if "llama" in plm_path:
        # Llama ships without a pad token: reuse EOS and left-pad so generation batches align
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"
        if peft == "GET":
            # GET additionally needs the knowledge-graph view of the data
            train_loader, test_loader, val_loader, words, relations = load_kg_causal_lm_language_lrc(
                dataset_path=args.dataset_path, tokenizer=tokenizer, dataset_name=dataset_name,
                batch_size=args.batch_size, block_size=args.block_size)
        else:
            train_loader, test_loader, val_loader, words, relations = load_causal_lm_language_lrc(
                dataset_path=args.dataset_path, tokenizer=tokenizer, dataset_name=dataset_name,
                batch_size=args.batch_size, block_size=args.block_size)
    else:
        raise RuntimeError("causal LM only supports Llama.")
    # load plm; num_labels is stored on the config (the LM head itself decodes over the vocabulary)
    plm = LlamaForCausalLM.from_pretrained(
        plm_path, num_labels=len(relations),
        pad_token_id=tokenizer.pad_token_id).to(device)
    if peft == "LORA":
        peft_config = LoraConfig(
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            r=args.lora_r,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(plm, peft_config)
    elif peft == "PT2":
        model = PT2PEFTModel(plm, plm_path, args.pre_seq_len, args.prefix_hidden_size)
    elif peft == "GET":
        # build the heterogeneous entity graph and PLM-derived node/relation features
        graph, new_edges, importance = create_heterogeneous_entity_graph(train_loader, words, relations)
        nft, rft = create_knowledge_embeddings(words, relations, plm=plm, tokenizer=tokenizer,
                                               device=device, inference_batch=args.inference_batch)
        if new_edges:
            # give newly added edge types the mean of the existing relation features
            rft = [torch.cat([each, torch.mean(each, dim=0, keepdim=True)], dim=0) for each in rft]
        model = GET(plm, plm_path, graph, nft, rft,
                    args.gnn_dim, args.num_gnn_layers, args.num_gnn_heads, dropout=args.attn_drop,
                    pre_seq_len=args.pre_seq_len, nodes_importance=importance)
    elif peft == "FT":
        # full fine-tuning: train the base PLM as-is
        model = plm
    else:
        raise ValueError(f"unknown --peft option: {peft!r}")
    trainer = Trainer(
        relations,
        tokenizer,
        epoch=args.epoch,
        batch_size=args.batch_size,
        warm_up_rate=args.warm_up_rate,
        lr=args.lr,
        lr_min=args.lr_min,
        device=device,
        model=model,
    )
    trainer.train_generation(train_loader, test_loader, val_loader)