-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathmeasure_gt_sparsity.py
More file actions
287 lines (229 loc) · 9.26 KB
/
measure_gt_sparsity.py
File metadata and controls
287 lines (229 loc) · 9.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
#!/usr/bin/env python3
"""
Calculate ground-truth activation sparsities for various base models on C4.
This script takes a list of HuggingFace-compatible models and runs each model
on a number of samples from the C4 "realnewslike" split. Activation statistics
are captured from the models' forward passes and used to determine the average
ground-truth sparsity of each layer for each model, where an MLP activation is
counted as inactive when its value is <= 0.
This data can then be plotted or saved in a JSON file to be used as thresholds
for the topk or statistical-topk sparsity methods using trained predictors.
Usage examples:
# Capture ground truth sparsity values for a particular model or models
python measure_gt_sparsity.py \
--models meta-llama/Llama-3.2-3B-Instruct \
--num_samples 2048 \
--max_length 512 \
--output_dir sparsities \
--device cuda
# Generate a plot of ground truth sparsity values by layer and model
python measure_gt_sparsity.py \
--models meta-llama/Llama-3.2-3B-Instruct Qwen/Qwen2-1.5B google/gemma-3n-E2B \
--num_samples 2048 \
--max_length 512 \
--output_dir sparsities \
--device cuda \
--make_plots
"""
import argparse
from collections import defaultdict
import json
import logging
import os
from typing import Dict
from datasets import load_dataset
import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.trainer_utils import set_seed
import matplotlib.pyplot as plt
from src.activation_capture import Hook, capture_model
# Module-level logger used for progress reporting throughout this script.
logger = logging.getLogger(__name__)
# Configure the root handler at import time so INFO messages are printed.
logging.basicConfig(level=logging.INFO)
class ContextualSparsityAnalyzer:
    """Measure per-layer MLP activation sparsity of a causal language model.

    On construction, an activation-capture object (from ``capture_model``) is
    attached to ``model`` as ``model.activation_capture`` and a ``Hook.ACT``
    hook is registered. Each call to :meth:`process_batch` runs one forward
    pass and appends, for every layer, the fraction of captured activations
    that are ``<= 0`` (counted as "inactive").

    Attributes:
        mlp_sparsity: maps layer index -> list of per-batch inactive fractions.
        num_seqs: number of sequences processed since the last reset.
        num_layers: number of layers reported by the capture object.
    """

    def __init__(self, model, tokenizer, device):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        model.activation_capture = capture_model(model)
        model.activation_capture.register_hooks(hooks=[Hook.ACT])
        self.num_layers = len(self.model.activation_capture.get_layers())
        self.reset_buffers()

    def reset_buffers(self):
        """Discard all accumulated sparsity statistics."""
        self.mlp_sparsity = defaultdict(list)
        self.num_seqs = 0

    @staticmethod
    def _inactive_fraction(activations, attention_mask=None):
        """Return the fraction of ``activations`` that are <= 0, ignoring padding.

        If ``attention_mask``'s shape equals the leading dimensions of
        ``activations`` (e.g. a ``(batch, seq)`` mask against activations
        assumed to be ``(batch, seq, hidden)``), positions whose mask is 0 are
        excluded so pad tokens do not bias the statistic. Otherwise every
        element is counted.

        Returns:
            The inactive fraction as a float, or ``None`` if the mask excludes
            every position.
        """
        inactive = activations <= 0
        if attention_mask is not None:
            mask_dims = attention_mask.dim()
            leading_shape = tuple(inactive.shape[:mask_dims])
            if inactive.dim() >= mask_dims and leading_shape == tuple(attention_mask.shape):
                valid = attention_mask.to(device=inactive.device, dtype=torch.bool)
                # Add trailing singleton dims so the mask broadcasts over feature axes.
                valid = valid.reshape(
                    tuple(valid.shape) + (1,) * (inactive.dim() - mask_dims)
                ).expand_as(inactive)
                num_valid = int(valid.sum().item())
                if num_valid == 0:
                    return None
                return int((inactive & valid).sum().item()) / num_valid
        # Shapes do not line up (or no mask): fall back to counting everything.
        return inactive.float().mean().item()

    def process_batch(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
        """Run one forward pass and record each layer's inactive fraction.

        Args:
            input_ids: token ids for the batch; ``size(0)`` is the batch size.
            attention_mask: mask passed to the model and used to exclude
                padded positions from the sparsity statistic.
        """
        batch_size = input_ids.size(0)

        # Clear previous captures and GPU cache
        self.model.activation_capture.clear_captures()
        if self.device.type == "cuda":
            torch.cuda.empty_cache()

        # Forward pass; only the hooked activations are needed, not the logits.
        with torch.no_grad():
            _ = self.model(input_ids=input_ids, attention_mask=attention_mask)

        captured = self.model.activation_capture.mlp_activations[Hook.ACT]
        for layer_idx in range(self.num_layers):
            fraction = self._inactive_fraction(captured[layer_idx], attention_mask)
            if fraction is not None:
                self.mlp_sparsity[layer_idx].append(fraction)

        # TODO: Add HNSW sparsity computation for both attn heads and mlp neurons
        # TODO: Compute union sparsity over multiple different batch sizes

        # Clear GPU tensors from capture to free memory
        self.model.activation_capture.clear_captures()
        if self.device.type == "cuda":
            torch.cuda.empty_cache()

        self.num_seqs += batch_size
def analyze_sparsity(args, model_name, device):
    """Measure the average per-layer MLP activation sparsity of one model on C4.

    Args:
        args: parsed CLI namespace; ``max_length``, ``num_samples`` and
            ``batch_size`` are read from it.
        model_name: HuggingFace model id or local path.
        device: ``torch.device`` to run on. On CUDA the model is loaded in
            float16 with ``device_map="auto"``; otherwise in float32 and moved
            to ``device``.

    Returns:
        A list indexed by layer, each entry the mean of the per-batch inactive
        fractions recorded for that layer (``nan`` if none were recorded).
    """
    import gc  # only needed for the post-run cleanup below

    # Load model and tokenizer
    logger.info(f"Loading model: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # The dataset pads to a fixed length, so a pad token is required.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
        device_map="auto" if device.type == "cuda" else None,
        trust_remote_code=True,
    )
    if device.type != "cuda":
        # device_map is only used on CUDA; place the model explicitly otherwise.
        model = model.to(device)

    # Load C4 dataset
    dataset = C4Dataset(tokenizer, args.max_length, args.num_samples)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

    analyzer = ContextualSparsityAnalyzer(model, tokenizer, device)
    try:
        logger.info("Starting contextual sparsity analysis...")
        for batch_idx, batch in enumerate(tqdm(dataloader, desc="Analyzing sequences")):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            analyzer.process_batch(input_ids, attention_mask)

            # Log progress (len(dataloader) counts batches, not sequences).
            if (batch_idx + 1) % 100 == 0:
                logger.info(
                    f"Processed {batch_idx + 1}/{len(dataloader)} batches "
                    f"({analyzer.num_seqs} sequences)"
                )

        per_layer = analyzer.mlp_sparsity
        layer_means = []
        for layer_idx in range(len(per_layer)):
            # .get avoids inserting empty lists into the defaultdict.
            values = per_layer.get(layer_idx, [])
            layer_means.append(sum(values) / len(values) if values else float("nan"))
    finally:
        analyzer.model.activation_capture.remove_hooks()

    # The capture object is stored on the model (model.activation_capture),
    # which may form a reference cycle; collect explicitly so GPU memory is
    # released before the caller loads the next model.
    del analyzer, model
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()

    return layer_means
def format_model_label(model):
    """Return a short legend label for a HuggingFace model id or local path.

    Uses the last non-empty path component and upper-cases only its first
    character, so the rest of the name keeps its case
    (``"meta-llama/Llama-3.2-3B-Instruct"`` -> ``"Llama-3.2-3B-Instruct"``).
    Names without a ``/`` (local directories) are handled too.
    """
    name = model.rstrip("/").split("/")[-1]
    return name[:1].upper() + name[1:]


def plot_sparsities(sparsities, output_dir=None):
    """Plot the percentage of inactive neurons per layer for each model.

    Args:
        sparsities: mapping of model name -> list of per-layer inactive
            fractions (0..1). The x axis is normalized to the layer position
            as a percentage so models with different depths are comparable.
        output_dir: if given, the figure is saved there as
            ``sparsity_analysis.png`` and then closed.

    Returns:
        The matplotlib figure, which stays usable after being closed.
    """
    fig = plt.figure(figsize=(10, 6))
    for model, model_sparsities in sparsities.items():
        num_layers = len(model_sparsities)
        layer_pct = [i * 100 / num_layers for i in range(num_layers)]
        inactive_pct = [x * 100 for x in model_sparsities]
        plt.plot(layer_pct, inactive_pct, label=format_model_label(model))

    plt.xlabel("Layer Index Percentage (layer_idx/num_layers)")
    plt.ylabel("% of Neurons Inactive")
    plt.title("Activation Sparsity By Layer")
    plt.legend()
    plt.minorticks_on()

    if output_dir:
        fig.savefig(
            os.path.join(output_dir, "sparsity_analysis.png"),
            dpi=300,
            bbox_inches="tight",
        )
        # Release the pyplot-managed figure so repeated calls don't accumulate.
        plt.close(fig)
    return fig
class C4Dataset(Dataset):
    """Tokenized C4 ("realnewslike", train split) samples for sparsity analysis.

    Texts are streamed from ``allenai/c4`` and each is tokenized with
    truncation and padding to exactly ``max_length`` tokens. A text is skipped
    if it has 50 or fewer characters after stripping whitespace, or if it
    produces 10 or fewer real (non-padding) tokens. Streaming stops once
    ``num_samples`` samples have been kept or the stream is exhausted.

    Each item is a dict with 1-D ``input_ids`` and ``attention_mask`` tensors
    of length ``max_length``, plus a ``text`` preview of at most 200 characters
    (with ``"..."`` appended when truncated).
    """

    def __init__(self, tokenizer, max_length: int = 512, num_samples: int = 1000):
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Load C4 dataset
        logger.info("Loading C4 dataset...")
        dataset = load_dataset(
            "allenai/c4", "realnewslike", split="train", streaming=True
        )

        self.samples = []
        for sample in dataset:
            # Count kept samples (not raw stream items) so that filtering does
            # not silently shrink the dataset below num_samples.
            if len(self.samples) >= num_samples:
                break

            text = sample["text"]
            if len(text.strip()) <= 50:  # Filter out very short texts
                continue

            encoding = tokenizer(
                text,
                truncation=True,
                padding="max_length",
                max_length=max_length,
                return_tensors="pt",
            )
            # With padding="max_length" the tensor width is always max_length,
            # so measure the real length through the attention mask.
            if int(encoding["attention_mask"].sum()) <= 10:  # Ensure minimum sequence length
                continue

            self.samples.append(
                {
                    # [0] drops the batch dim without collapsing length-1 sequences.
                    "input_ids": encoding["input_ids"][0],
                    "attention_mask": encoding["attention_mask"][0],
                    "text": text[:200] + "..." if len(text) > 200 else text,
                }
            )

        logger.info(f"Loaded {len(self.samples)} C4 samples")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]
def main():
    """CLI entry point: measure sparsity for each model and save the results.

    Writes ``<output_dir>/sparsity.json`` (model name -> list of per-layer
    inactive fractions), rewriting it after each model finishes. With
    ``--make_plots``, also saves ``<output_dir>/sparsity_analysis.png``.
    """
    parser = argparse.ArgumentParser(
        description="Measure contextual sparsity in LLaMA models"
    )
    parser.add_argument(
        "--models",
        type=str,
        nargs="+",
        default=[
            "meta-llama/Llama-3.2-3B-Instruct",
            "Qwen/Qwen2-1.5B",
        ],
        help="HuggingFace model names or paths",
    )
    parser.add_argument(
        "--output_dir", type=str, required=True, help="Output directory for results"
    )
    parser.add_argument(
        "--num_samples", type=int, default=1000, help="Number of C4 samples to analyze"
    )
    parser.add_argument(
        "--max_length", type=int, default=512, help="Maximum sequence length"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Batch size (recommend 1 for token-by-token analysis)",
    )
    parser.add_argument(
        "--device", type=str, default="auto", help="Device to use (auto, cpu, cuda)"
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument(
        "--make_plots", action="store_true", help="Generate and save analysis plots"
    )
    args = parser.parse_args()

    # Set seed
    set_seed(args.seed)

    # Setup device
    if args.device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)
    logger.info(f"Using device: {device}")

    # Setup output directory
    os.makedirs(args.output_dir, exist_ok=True)
    output_path = os.path.join(args.output_dir, "sparsity.json")

    outs = {}
    for model in args.models:
        outs[model] = analyze_sparsity(args, model, device)
        # Rewrite after every model so finished results survive a later failure;
        # the context manager guarantees the file is flushed and closed.
        with open(output_path, "w") as f:
            json.dump(outs, f)

    if args.make_plots:
        # Pass output_dir so the plot is actually saved to disk.
        plot_sparsities(outs, output_dir=args.output_dir)


if __name__ == "__main__":
    main()