-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgradient_difference.py
More file actions
91 lines (72 loc) · 3.01 KB
/
gradient_difference.py
File metadata and controls
91 lines (72 loc) · 3.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch
import wandb
import numpy as np
from transformers import AutoTokenizer,AutoModelForCausalLM
def gradient_difference(current_model, batch, device):
    """Compute the gradient-difference unlearning loss for one batch.

    Performs gradient *descent* on retain samples (``batch["split"] == 0``)
    and gradient *ascent* (negated loss) on forget samples
    (``batch["split"] == 1``).

    :param current_model: HF-style causal LM — returns an output object with
        a ``.loss`` attribute when ``labels`` are supplied
    :param batch: dict with ``input_ids``, ``attention_mask``, ``labels``
        (all shape ``(batch, seq)``) and a per-sample ``split`` flag tensor
        of shape ``(batch,)`` (0 = retain, 1 = forget)
    :param device: torch device to run on
    :returns: scalar loss tensor (positive sign for retain, negative for forget)
    """
    outputs = current_model(
        batch["input_ids"].to(device),
        attention_mask=batch["attention_mask"].to(device),
        labels=batch["labels"].to(device),
    )
    # Use the model-computed loss: HF causal LMs apply the one-token causal
    # shift (logits[..., :-1, :] vs labels[..., 1:]) and honor the default
    # ignore_index=-100. The previous manual CrossEntropyLoss on UNSHIFTED
    # logits/labels aligned position i's logits with position i's label,
    # i.e. trained the model to copy the current token rather than predict
    # the next one.
    loss_ce = outputs.loss
    # split: (batch,) -> (batch, 1, 1) so it broadcasts against the loss.
    split = batch["split"].to(device).view(-1, 1, 1)
    # (1 - split) keeps the descent term for retain rows; -split flips the
    # sign for forget rows (gradient ascent).
    loss = (1 - split) * loss_ce - split * loss_ce
    return loss.sum(-1).mean()
def GDTrainingLoop(unlearnmodel, train_set, val_set, epoch, device, optimizer, project_name, config):
    """Training loop using the gradient-difference unlearning algorithm.

    Each train batch carries a per-sample ``split`` flag; ``gradient_difference``
    descends on retain samples and ascends on forget samples. Checkpoints
    (model + tokenizer) are saved after every epoch.

    :param unlearnmodel: causal LM to unlearn (moved to ``device``, trained in place)
    :param train_set: iterable of training batches (mixed retain/forget)
    :param val_set: iterable of validation batches
    :param epoch: number of epochs to run
    :param device: device for the training
    :param optimizer: optimizer used for training
    :param project_name: wandb project this run is logged under
    :param config: run config dict; must contain ``"model_type"``
        (``"1B"`` or ``"7B"``) and ``"file_name"`` (checkpoint path prefix)
    :returns: the trained model
    :raises ValueError: if ``config["model_type"]`` is not ``"1B"`` or ``"7B"``
    """
    wandb.init(
        # set the wandb project where this run will be logged
        project=project_name,
        # track hyperparameters and run metadata
        config=config,
    )
    if config["model_type"] == "1B":
        tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1B-0724-hf")
    elif config["model_type"] == "7B":
        tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-7B-0724-Instruct-hf")
    else:
        # Fail fast: previously an unknown model_type left `tokenizer`
        # undefined and crashed with NameError only at checkpoint time,
        # after a full epoch of training.
        raise ValueError(f"unknown model_type: {config['model_type']!r} (expected '1B' or '7B')")
    print(config)
    unlearnmodel.to(device)  # student model being unlearned
    unlearnmodel.train()
    for forget_epoch in range(epoch):
        epoch_loss = 0
        for batch_no, batch in enumerate(train_set, start=1):
            optimizer.zero_grad()
            loss = gradient_difference(unlearnmodel, batch, device)
            wandb.log({"Loss": loss.item()})
            epoch_loss += loss.item()
            loss.backward()
            optimizer.step()
            print(f"Batch {batch_no} Batch Type:{batch['split']} Loss:{loss.item()}")
        total_val_loss = 0
        with torch.no_grad():
            for batch in val_set:
                val_loss = gradient_difference(unlearnmodel, batch, device)
                wandb.log({"Val Loss": val_loss.item()})
                print(f"Batch Val Loss : {val_loss.item()}")
                total_val_loss += val_loss.item()
        print(f"Epoch {forget_epoch+1}, Train Loss: {epoch_loss/len(train_set)} Validation Loss: {total_val_loss/len(val_set):.4f}")
        # Single quotes for the inner subscript: nesting double quotes inside
        # a double-quoted f-string is a SyntaxError before Python 3.12.
        checkpoint_dir = f"{config['file_name']}_epoch_{forget_epoch+1}"
        unlearnmodel.save_pretrained(checkpoint_dir)
        tokenizer.save_pretrained(checkpoint_dir)
    return unlearnmodel