Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
482 changes: 482 additions & 0 deletions transformer_lens/benchmarks/audio.py

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions transformer_lens/benchmarks/component_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ def benchmark_all_components(
skip_components = []
if getattr(bridge.cfg, "is_multimodal", False):
skip_components = ["vision_encoder", "vision_projector"]
if getattr(bridge.cfg, "is_audio_model", False):
# Audio preprocessing needs waveform input; validated in Phase 8
skip_components.extend(["audio_feature_extractor", "feat_proj", "conv_pos_embed"])

# Run comprehensive benchmark
report = benchmarker.benchmark_all_components(skip_components=skip_components)
Expand Down
42 changes: 33 additions & 9 deletions transformer_lens/benchmarks/forward_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _get_decoder_input_ids(model: torch.nn.Module, batch_size: int = 1) -> torch

def benchmark_forward_pass(
bridge: TransformerBridge,
test_text: str,
test_input: Union[str, torch.Tensor],
reference_model: Optional[Union[HookedTransformer, torch.nn.Module]] = None,
reference_logits: Optional[torch.Tensor] = None,
atol: float = 1e-3,
Expand All @@ -49,24 +49,26 @@ def benchmark_forward_pass(

Args:
bridge: TransformerBridge model to test
test_text: Input text for testing
test_input: Input text string or audio waveform tensor for testing
reference_model: Optional reference model (HookedTransformer or HF model)
reference_logits: Optional pre-computed reference logits tensor (e.g., saved
from a prior HF forward pass to avoid needing both models in memory)
reference_logits: Optional pre-computed reference logits/hidden states tensor
(e.g., saved from a prior HF forward pass to avoid needing both models in memory)
atol: Absolute tolerance for comparison
rtol: Relative tolerance for comparison

Returns:
BenchmarkResult with comparison details
"""
try:
_is_audio = getattr(bridge.cfg, "is_audio_model", False)

# Check if this is an encoder-decoder model
is_enc_dec = _is_encoder_decoder(bridge.original_model)

# Prepare extra kwargs for encoder-decoder models
extra_kwargs = {}
if is_enc_dec:
tokens = bridge.to_tokens(test_text)
if is_enc_dec and isinstance(test_input, str):
tokens = bridge.to_tokens(test_input)
batch_size = tokens.shape[0]
decoder_input_ids = _get_decoder_input_ids(bridge.original_model, batch_size)
decoder_input_ids = decoder_input_ids.to(tokens.device)
Expand All @@ -75,7 +77,19 @@ def benchmark_forward_pass(
# Run bridge forward pass (use no_grad to match HF reference context —
# MPS SDPA can produce different results with vs without gradient tracking)
with torch.no_grad():
bridge_output = bridge(test_text, return_type="logits", **extra_kwargs)
if _is_audio and isinstance(test_input, torch.Tensor):
# Audio models: pass waveform, extract tensor from output
bridge_output_raw = bridge(test_input, return_type="logits")
if isinstance(bridge_output_raw, torch.Tensor):
bridge_output = bridge_output_raw
elif hasattr(bridge_output_raw, "logits") and bridge_output_raw.logits is not None:
bridge_output = bridge_output_raw.logits
elif hasattr(bridge_output_raw, "last_hidden_state"):
bridge_output = bridge_output_raw.last_hidden_state
else:
bridge_output = bridge_output_raw
else:
bridge_output = bridge(test_input, return_type="logits", **extra_kwargs)

if reference_model is None and reference_logits is None:
# No reference model or logits - just verify output shape and validity
Expand Down Expand Up @@ -106,12 +120,22 @@ def benchmark_forward_pass(
if reference_logits is not None:
reference_output = reference_logits.to(bridge_output.device)
elif isinstance(reference_model, HookedTransformer):
reference_output = reference_model(test_text, return_type="logits")
reference_output = reference_model(test_input, return_type="logits")
elif _is_audio and isinstance(test_input, torch.Tensor):
# Audio HF reference model: pass waveform directly
assert reference_model is not None
with torch.no_grad():
hf_output = reference_model(input_values=test_input)
if hasattr(hf_output, "logits") and hf_output.logits is not None:
reference_output = hf_output.logits
else:
reference_output = hf_output.last_hidden_state
else:
# HuggingFace model (reference_model is guaranteed non-None here
# because we returned early above when both references are None)
assert reference_model is not None
tokens = bridge.to_tokens(test_text)
assert isinstance(test_input, str), "Text model requires string input"
tokens = bridge.to_tokens(test_input)
with torch.no_grad():
if is_enc_dec:
# Encoder-decoder models need decoder_input_ids
Expand Down
155 changes: 117 additions & 38 deletions transformer_lens/benchmarks/main_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
from transformer_lens.utilities.architectures import (
NO_HT_COMPARISON_ARCHITECTURES,
get_architectures_for_config,
is_audio_model,
is_encoder_decoder_model,
is_masked_lm_model,
)
Expand All @@ -98,10 +99,7 @@ def should_skip_ht_comparison(model_name: str, trust_remote_code: bool = False)


def get_auto_model_class(model_name: str, trust_remote_code: bool = False):
"""Determine the correct AutoModel class for a given model.

Delegates to the bridge's architecture detection for consistency.
"""
"""Delegates to the bridge's architecture detection for consistency."""
from transformer_lens.model_bridge.sources.transformers import (
determine_architecture_from_hf_config,
get_hf_model_class_for_architecture,
Expand Down Expand Up @@ -1014,6 +1012,13 @@ def cleanup_model(model, model_name_str: str):
print(f"\nStack trace:\n{error_trace}")
return results

# Detect audio model once for use across all phases
_is_audio = bridge_unprocessed is not None and getattr(
bridge_unprocessed.cfg, "is_audio_model", False
)
# Shared waveform for audio model benchmarks (consistent across HF capture and bridge forward)
_test_audio = torch.randn(1, 16000, device=device, dtype=dtype) if _is_audio else None

# Run Phase 1 benchmarks
if should_run_phase(1) and bridge_unprocessed:
if verbose:
Expand All @@ -1040,38 +1045,52 @@ def cleanup_model(model, model_name_str: str):
if verbose:
print(f"✗ Component benchmark failed: {e}\n")

# Capture HF reference logits using bridge.to_tokens() for
# consistent tokenization (BOS prepending, etc.). Both models
# are still in memory so this is still within the 2.0x window.
# Capture HF reference outputs. Both models are still in memory (2.0x window).
if verbose:
print("Capturing HF reference outputs to CPU...")
try:
hf_tokens = bridge_unprocessed.to_tokens(test_text)
is_enc_dec = is_encoder_decoder_model(
model_name, trust_remote_code=trust_remote_code
)
with torch.no_grad():
if is_enc_dec:
decoder_start_id = getattr(
getattr(hf_model, "config", None),
"decoder_start_token_id",
0,
if _is_audio:
# Audio models: use the shared waveform for HF vs bridge comparison
with torch.no_grad():
hf_out = hf_model(input_values=_test_audio)
# Prefer logits when the model exposes them; otherwise fall back to last_hidden_state
if hasattr(hf_out, "logits") and hf_out.logits is not None:
hf_saved_logits = hf_out.logits.detach().cpu().clone()
else:
hf_saved_logits = hf_out.last_hidden_state.detach().cpu().clone()
# No loss computation for audio — CTC requires aligned labels
if verbose:
print(
f"✓ Captured HF audio output {hf_saved_logits.shape}, "
f"loss=N/A (CTC requires labels)\n"
)
dec_ids = torch.tensor([[decoder_start_id]]).to(hf_tokens.device)
hf_out = hf_model(hf_tokens, decoder_input_ids=dec_ids)
else:
hf_out = hf_model(hf_tokens)
hf_saved_logits = hf_out.logits.detach().cpu().clone()

# Compute causal LM loss (shift logits and labels)
if not is_enc_dec and hf_saved_logits.shape[1] > 1:
shift_logits = hf_out.logits[..., :-1, :].contiguous()
shift_labels = hf_tokens[..., 1:].contiguous()
loss_fn = torch.nn.CrossEntropyLoss()
hf_saved_loss = loss_fn(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
).item()
else:
hf_tokens = bridge_unprocessed.to_tokens(test_text)
is_enc_dec = is_encoder_decoder_model(
model_name, trust_remote_code=trust_remote_code
)
with torch.no_grad():
if is_enc_dec:
decoder_start_id = getattr(
getattr(hf_model, "config", None),
"decoder_start_token_id",
0,
)
dec_ids = torch.tensor([[decoder_start_id]]).to(hf_tokens.device)
hf_out = hf_model(hf_tokens, decoder_input_ids=dec_ids)
else:
hf_out = hf_model(hf_tokens)
hf_saved_logits = hf_out.logits.detach().cpu().clone()

# Compute causal LM loss (shift logits and labels)
if not is_enc_dec and hf_saved_logits.shape[1] > 1:
shift_logits = hf_out.logits[..., :-1, :].contiguous()
shift_labels = hf_tokens[..., 1:].contiguous()
loss_fn = torch.nn.CrossEntropyLoss()
hf_saved_loss = loss_fn(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
).item()

if verbose:
loss_str = f"{hf_saved_loss:.4f}" if hf_saved_loss is not None else "N/A"
Expand All @@ -1097,13 +1116,18 @@ def cleanup_model(model, model_name_str: str):
# matmul non-determinism can exceed the float32 default of 1e-3
p1_atol = 1e-3 if dtype == torch.float32 else 5e-3

# For audio models, reuse the waveform from HF reference capture
_p1_input: Union[str, torch.Tensor] = test_text
if _is_audio and _test_audio is not None:
_p1_input = _test_audio

if hf_saved_logits is not None:
# Full mode: use pre-captured HF logits (bridge only, 1.0x)
try:
add_result(
benchmark_forward_pass(
bridge_unprocessed,
test_text,
_p1_input,
reference_logits=hf_saved_logits.to(device),
atol=p1_atol,
)
Expand All @@ -1113,17 +1137,18 @@ def cleanup_model(model, model_name_str: str):
print(f"✗ Forward pass benchmark failed: {e}\n")
else:
try:
add_result(benchmark_forward_pass(bridge_unprocessed, test_text, atol=p1_atol))
add_result(benchmark_forward_pass(bridge_unprocessed, _p1_input, atol=p1_atol))
except Exception as e:
if verbose:
print(f"✗ Forward pass benchmark failed: {e}\n")

# Capture Phase 1 reference for Phase 3 equivalence comparison.
# Skip for audio models (Phase 3 won't run — no HookedTransformer support).
# When dtype==float32 (default) and the model natively uses reduced
# precision, upcast for maximum accuracy. When the user explicitly
# requested a non-float32 dtype, run the reference pass in that dtype
# so the entire pipeline honours the requested precision.
if bridge_unprocessed is not None:
if bridge_unprocessed is not None and not _is_audio:
try:
original_dtype = bridge_unprocessed.cfg.dtype
needs_upcast = dtype == torch.float32 and original_dtype not in (
Expand Down Expand Up @@ -1192,11 +1217,13 @@ def cleanup_model(model, model_name_str: str):
print("Running Phase 2 benchmarks...\n")

# Generation benchmarks (unprocessed only) - RUN FIRST
# Skip for encoder-decoder models (T5, etc.) which require different generation API
is_enc_dec = is_encoder_decoder_model(model_name)
# Skip for encoder-decoder models (require a different generation API) and
# audio models (no text generation capability)
_skip_generation = is_encoder_decoder_model(model_name) or getattr(
bridge_unprocessed.cfg, "is_audio_model", False
)
if verbose:
print("1. Generation Benchmarks (unprocessed)")
if is_enc_dec:
if _skip_generation:
if verbose:
print("⏭️ Skipped (encoder-decoder model - requires decoder_input_ids)\n")
add_result(
Expand Down Expand Up @@ -1342,6 +1369,7 @@ def cleanup_model(model, model_name_str: str):
should_run_phase(4)
and bridge_unprocessed is not None
and not is_masked_lm_model(model_name, trust_remote_code=trust_remote_code)
and not is_audio_model(model_name, trust_remote_code=trust_remote_code)
):
if verbose:
print(f"\n{'='*80}")
Expand Down Expand Up @@ -1419,6 +1447,57 @@ def cleanup_model(model, model_name_str: str):
)
)

# ========================================================================
# Phase 8: Audio Tests (only for audio encoder models)
# Runs before Phase 3 so we can reuse bridge_unprocessed before cleanup.
# ========================================================================
if (
bridge_unprocessed is not None
and getattr(bridge_unprocessed.cfg, "is_audio_model", False)
and should_run_phase(8)
):
current_phase[0] = 8
if verbose:
print("\n" + "=" * 80)
print("PHASE 8: AUDIO TESTS")
print("=" * 80)
print("Testing audio forward pass, caching, representation stability, and features.")
print("=" * 80 + "\n")

try:
from transformer_lens.benchmarks.audio import run_audio_benchmarks

test_audio = torch.randn(1, 16000, device=device, dtype=dtype)
audio_results = run_audio_benchmarks(
bridge_unprocessed,
test_audio=test_audio,
verbose=verbose,
)
for result in audio_results:
result.phase = 8
results.append(result)
if verbose:
print(result)

if verbose:
print("\n" + "=" * 80)
print("PHASE 8 COMPLETE")
print("=" * 80)

except Exception as e:
if verbose:
print(f"\n⚠ Audio tests failed: {e}\n")
results.append(
BenchmarkResult(
name="audio_suite",
passed=False,
severity=BenchmarkSeverity.ERROR,
message=f"Failed to run audio tests: {str(e)}",
details={"error": str(e)},
phase=8,
)
)

# ========================================================================
# PHASE 3: Bridge (processed) + HookedTransformer (processed)
# ========================================================================
Expand Down
4 changes: 4 additions & 0 deletions transformer_lens/config/TransformerBridgeConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ def __init__(
eps_attr: str = "eps",
rmsnorm_uses_offset: bool = False,
attn_implementation: Optional[str] = None,
# Audio model configuration
is_audio_model: bool = False,
# Multimodal configuration
is_multimodal: bool = False,
vision_hidden_size: Optional[int] = None,
Expand Down Expand Up @@ -174,6 +176,8 @@ def __init__(
self.eps_attr = eps_attr
self.rmsnorm_uses_offset = rmsnorm_uses_offset
self.attn_implementation = attn_implementation
# Audio model configuration
self.is_audio_model = is_audio_model
# Multimodal configuration
self.is_multimodal = is_multimodal
self.vision_hidden_size = vision_hidden_size
Expand Down
3 changes: 3 additions & 0 deletions transformer_lens/factories/architecture_adapter_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
GraniteArchitectureAdapter,
GraniteMoeArchitectureAdapter,
GraniteMoeHybridArchitectureAdapter,
HubertArchitectureAdapter,
LlamaArchitectureAdapter,
LlavaArchitectureAdapter,
LlavaNextArchitectureAdapter,
Expand Down Expand Up @@ -63,6 +64,8 @@
"GptOssForCausalLM": GPTOSSArchitectureAdapter,
"GPT2LMHeadCustomModel": Gpt2LmHeadCustomArchitectureAdapter,
"GPTJForCausalLM": GptjArchitectureAdapter,
"HubertForCTC": HubertArchitectureAdapter,
"HubertModel": HubertArchitectureAdapter,
"LlamaForCausalLM": LlamaArchitectureAdapter,
"LlavaForConditionalGeneration": LlavaArchitectureAdapter,
"LlavaNextForConditionalGeneration": LlavaNextArchitectureAdapter,
Expand Down
Loading
Loading