-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_mwe.py
More file actions
125 lines (100 loc) · 4.02 KB
/
test_mwe.py
File metadata and controls
125 lines (100 loc) · 4.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
""" Tests accuracy on Anthropic Model Written Evaluations on the model path specified in the Hyperparams file"""
import torch
from typing import Optional
import yaml
import datasets
import transformers
import tqdm
import utils
def load_model(
checkpoint_path: str,
):
"""Load model from a checkpoint and wrap with PEFT if needed."""
# Load tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_path)
tokenizer.pad_token = tokenizer.eos_token
# Load model from the specific checkpoint
model = (
transformers.AutoModelForCausalLM.from_pretrained(
checkpoint_path, torch_dtype=torch.bfloat16
)
.to(torch.device("cuda:0"))
.eval()
)
return tokenizer, model
def load_dataset(
name: str, # HuggingFace name of dataset
data_dir: Optional[str] = None, # Specifies splits of dataset to use
# debug = False,
):
"""Load and preprocess dataset."""
dataset = datasets.load_dataset(name, data_dir)
return dataset
def preprocess_mwe(dataset, tokenizer):
questions = dataset["question"]
labels = dataset["answer_matching_behavior"]
final_labels = [] # preprocessed labels
english_questions = [] # non-tokenized questions
for label, question in zip(labels, questions):
# create formatted question
formatted_question = f"""Please answer the following multiple choice question. Start your answer with the capital letter corresponding to your choice:
{question}
Your answer, starting with either A or B: """
formatted_question = tokenizer.apply_chat_template(
[{"role": "user", "content": formatted_question}], tokenize=False
)
english_questions.append(formatted_question)
if label == " (A)":
final_labels.append([1, 0, 0, 0])
elif label == " (B)":
final_labels.append([0, 1, 0, 0])
return final_labels, english_questions
def model_answers():
# Load hyperparameters and setup
args = utils.argparser().parse_args()
with open(args.hyperparam_file) as f:
hps = yaml.load(f, Loader=yaml.FullLoader)
tokenizer, model = load_model(hps["model_path"])
# Load and preprocess the dataset
dataset = load_dataset("Anthropic/model-written-evals")["train"]
subset = dataset[: min(1000, len(dataset))]
labels, english_questions = preprocess_mwe(subset, tokenizer)
# find IDs for ABCD tokens
tokenized_abcd = tokenizer(["A", "B", "C", "D"])
abcd_ids = [ids[1] for ids in tokenized_abcd["input_ids"]]
# Process the data in batches
batch_size = 1
results = []
for i in tqdm.trange(0, len(english_questions), batch_size):
batch_questions = tokenizer(
english_questions[i : i + batch_size],
return_tensors="pt",
# padding="max_length",
# truncation=True,
# max_length=256,
padding=False,
).to(model.device)
batch_labels = labels[i : i + batch_size]
eng_qs = english_questions[i : i + batch_size]
# Run the model on the batch
with torch.no_grad():
batch_result = model(**batch_questions)
abcd_logits = batch_result.logits[0, -1, abcd_ids].cpu()
abcd_probs = torch.softmax(abcd_logits, dim=-1)
info = {}
info["abcd_probs"] = abcd_probs # probs for regular model
info["question"] = eng_qs # question + answer choices (prompt)
info["labels"] = batch_labels # labels for each answer choice
results.append(info)
return results
def accuracy(results):
"""finds a model's accuracy on Anthropic MWE"""
choices = torch.stack([result["abcd_probs"] for result in results])
labels = torch.tensor([result["labels"] for result in results]).squeeze(dim=1)
return (
(torch.argmax(choices, dim=-1) == torch.argmax(labels, dim=-1)).float().mean()
)
if __name__ == "__main__":
results = model_answers()
accuracy = accuracy(results)
print(f"Accuracy: {accuracy}")