# validate_vocab_robustness.py
import torch
import pandas as pd
import gc
import math
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- Configuration ---
MODEL_ID = "swiss-ai/Apertus-8B-Instruct-2509"
SCRIPT_FILE = "token_unicode_scripts.csv"
DEVICE = "cuda:0"

# Categories to test the model's limits
STRESS_TEST_PROMPTS = {
    "Standard English": "The scientific method involves observation and",
    "Math & Logic": "If x = 5 and y = 10, then x + y equals",
    "Coding (Python)": "def fibonacci(n):",
    "Conversational": "Hey, what's up? I was just thinking about",
    "French (Latin Script Check)": "Bonjour, comment allez-vous? Je suis",
    "Mixed (Edge Case)": "The price is $50.00 (USD) & 100% guaranteed!",
}

# "Common" + "Latin" + "Inherited" (the robust filter)
FIXED_SCRIPTS = {"Common", "Inherited", "Latin"}
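
# Illustrative layout of SCRIPT_FILE (example rows only; the real column names
# are taken from the rename inside validate_robustness() below):
#   ,scripts_common_combined
#   0,Common
#   1,Latin
#   2,Cyrillic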

def calculate_perplexity(model, tokenizer, text):
    """Calculates the PPL of a specific text sequence."""
    encodings = tokenizer(text, return_tensors="pt").to(DEVICE)
    input_ids = encodings.input_ids
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
    loss = outputs.loss
    # Perplexity = exp(mean cross-entropy); HF causal LMs shift labels internally.
    ppl = math.exp(loss.item())
    return ppl
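
# A minimal sketch (added, not in the original script): perplexity for the
# *sliced* model. After slicing, the logits cover only the reduced vocab, so
# labels (original token ids) must be remapped to reduced indices; tokens the
# filter removed get label -100, which HF's loss ignores. `keep_indices` is
# the sorted keep list built in validate_robustness() below.
def calculate_mapped_perplexity(model, tokenizer, text, keep_indices):
    # Invert the keep list: original token id -> reduced index.
    orig_to_reduced = {orig: red for red, orig in enumerate(keep_indices)}
    input_ids = tokenizer(text, return_tensors="pt").input_ids.to(DEVICE)
    labels = torch.tensor(
        [[orig_to_reduced.get(t, -100) for t in input_ids[0].tolist()]],
        device=DEVICE,
    )
    with torch.no_grad():
        outputs = model(input_ids, labels=labels)
    return math.exp(outputs.loss.item())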

def validate_robustness():
    print("--- Setting up Robustness Validator ---")

    # 1. Load the Filter Data
    df = pd.read_csv(SCRIPT_FILE)
    if 'Unnamed: 0' in df.columns:
        df = df.rename(columns={'Unnamed: 0': 'token_id', 'scripts_common_combined': 'script'})

    # 2. Build the Keep List
    keep_indices = df[df['script'].isin(FIXED_SCRIPTS)]['token_id'].tolist()
    # Keep any leading ids below the CSV's first id (not covered by the filter data).
    min_csv_id = df['token_id'].min()
    if min_csv_id > 0:
        keep_indices.extend(range(0, min_csv_id))
    keep_indices = sorted(set(keep_indices))
    new_vocab_size = len(keep_indices)
    print(f"Target Reduced Vocab: {new_vocab_size}")

    # 3. Load Model & Tokenizer
    print(f"Loading Model: {MODEL_ID}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        dtype=torch.bfloat16,
        device_map=DEVICE,
        trust_remote_code=True,
    )

    # --- BASELINE BENCHMARK ---
    print("\n--- Phase 1: Baseline (Full Model) Evaluation ---")
    # Measure PPL on a standard paragraph
    validation_text = (
        "The quick brown fox jumps over the lazy dog. "
        "Programming is essential for modern science."
    )
    baseline_ppl = calculate_perplexity(model, tokenizer, validation_text)
    print(f"Baseline Perplexity: {baseline_ppl:.4f}")

    # --- APPLY FILTER ---
    print("\n--- Phase 2: Applying Filter (Slicing) ---")
    original_head = model.lm_head
    full_weights = original_head.weight.data
    # Fancy indexing copies, so the slice is independent of the old head.
    reduced_weights = full_weights[keep_indices, :]
    hidden_size = original_head.in_features
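
    # Illustrative (added, not in the original script): report roughly how much
    # the output head shrinks in bf16 (2 bytes per parameter).
    orig_rows = full_weights.shape[0]
    saved_mb = (orig_rows - new_vocab_size) * hidden_size * 2 / 1e6
    print(f"lm_head rows: {orig_rows} -> {new_vocab_size} (~{saved_mb:.0f} MB freed)")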

    # Create the mapper: reduced index -> original token id
    index_map = torch.tensor(keep_indices, device=DEVICE)

    # Slice: drop every reference to the old head so its memory can actually be freed
    del model.lm_head, original_head, full_weights
    gc.collect()
    torch.cuda.empty_cache()
    model.lm_head = torch.nn.Linear(
        hidden_size, new_vocab_size, bias=False, device=DEVICE, dtype=torch.bfloat16
    )
    model.lm_head.weight.data = reduced_weights
    model.config.vocab_size = new_vocab_size

    # --- SLICED BENCHMARK ---
    print("\n--- Phase 3: Sliced Model Evaluation ---")

    # 1. STRESS TEST GENERATION
    print(f"{'Category':<30} | {'Output (first 12 tokens)'}")
    print("-" * 80)
    for category, prompt in STRESS_TEST_PROMPTS.items():
        inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
        curr_ids = inputs.input_ids

        # Custom greedy generation loop: argmax over the reduced logits, then
        # map back to the original token id (the embedding layer still expects
        # full-vocab ids).
        generated_tokens = []
        for _ in range(12):  # generate 12 tokens
            with torch.no_grad():
                outputs = model(curr_ids)
            next_token_logits = outputs.logits[0, -1, :]
            new_id = torch.argmax(next_token_logits).item()
            original_id = index_map[new_id].item()
            word = tokenizer.decode([original_id])
            generated_tokens.append(word)
            curr_ids = torch.cat([curr_ids, torch.tensor([[original_id]], device=DEVICE)], dim=1)
        output_str = "".join(generated_tokens).replace("\n", "\\n")
        print(f"{category:<30} | {output_str}")

    # 2. PERPLEXITY CHECK (The Critical Metric)
    # Note: we cannot run a standard PPL check here because the label ids no
    # longer match the sliced logits (see the calculate_mapped_perplexity
    # sketch above for a workaround). As a coarser signal, we check whether
    # the model is 'broken' by seeing if it refuses to generate valid English.
    # The true PPL check requires the 'LogitsWrapper' to be used with a
    # library like 'lm-eval'.
    print("\n--- Analysis ---")
    print("Check the 'Coding' and 'Math' outputs above.")
    print("1. If Code/Math is broken (gibberish), the filter removed keys like '{', '}', '+', '='.")
    print("2. If French works, the 'Latin' script filter is correctly inclusive.")
    print("3. If 'Mixed' works, you kept the '$' and '&' symbols (Common script).")

if __name__ == "__main__":
    validate_robustness()