Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 4 additions & 68 deletions transformer_lens/benchmarks/main_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,7 +658,6 @@ def run_benchmark_suite(
test_weight_processing_individually: bool = False,
phases: list[int] | None = None,
trust_remote_code: bool = False,
conserve_memory: bool = False,
scoring_model: PreTrainedModel | None = None,
scoring_tokenizer: PreTrainedTokenizerBase | None = None,
) -> List[BenchmarkResult]:
Expand Down Expand Up @@ -691,11 +690,6 @@ def run_benchmark_suite(
tests that check each processing flag individually (default: False)
phases: Optional list of phase numbers to run (e.g., [1, 2, 3]). If None, runs all phases.
trust_remote_code: Whether to trust remote code for custom architectures.
conserve_memory: When True, Phase 1 avoids loading a separate HF model
and instead uses bridge.original_model for component benchmarks and
forward pass comparison. This halves Phase 1 peak memory (1.0x vs 2.0x)
at the cost of losing the independent HF loading cross-check (~5%
weakening). Default is False (full dual-load for maximum test coverage).
scoring_model: Optional pre-loaded GPT-2 scoring model for Phase 4. When
provided with scoring_tokenizer, avoids reloading for each model in batch.
scoring_tokenizer: Optional pre-loaded tokenizer for Phase 4 scoring model.
Expand Down Expand Up @@ -1024,24 +1018,10 @@ def cleanup_model(model, model_name_str: str):
if verbose:
print(f"⚠ Could not apply architecture patches: {patch_err}")

# ----------------------------------------------------------------
# Phase 1 memory strategy (controlled by conserve_memory flag):
#
# conserve_memory=False (default):
# Load separate HF model, capture logits to CPU, load Bridge,
# run component benchmark with both models (brief 2.0x), delete
# HF immediately after, forward pass uses saved logits (1.0x).
#
# conserve_memory=True:
# Skip separate HF model entirely. Load Bridge only (1.0x
# throughout). Component benchmark uses bridge.original_model
# as the HF reference. Forward pass compares bridge output
# against bridge.original_model logits.
# ----------------------------------------------------------------
hf_saved_logits = None
hf_saved_loss = None

if use_hf_reference and not conserve_memory and should_run_phase(1):
if use_hf_reference and should_run_phase(1):
try:
if verbose:
print("Loading HuggingFace reference model...")
Expand Down Expand Up @@ -1146,28 +1126,12 @@ def cleanup_model(model, model_name_str: str):
# Run Phase 1 benchmarks
if should_run_phase(1) and bridge_unprocessed:
if verbose:
mode_label = " [conserve-memory]" if conserve_memory else ""
print(f"Running Phase 1 benchmarks{mode_label}...\n")
print("Running Phase 1 benchmarks...\n")

# Component-level benchmarks
if verbose:
print("1. Component-Level Benchmarks")
if conserve_memory:
# conserve_memory mode: use bridge.original_model as the HF
# reference (no separate HF load, 1.0x peak throughout).
try:
component_result = benchmark_all_components(
bridge_unprocessed, bridge_unprocessed.original_model
)
add_result(component_result)
if verbose:
status = "✓" if component_result.passed else "✗"
print(f"{status} {component_result.message}")
print(" (reference: bridge.original_model)\n")
except Exception as e:
if verbose:
print(f"✗ Component benchmark failed: {e}\n")
elif hf_model is not None:
if hf_model is not None:
# Full mode: component benchmark with independent HF model (brief 2.0x)
try:
component_result = benchmark_all_components(bridge_unprocessed, hf_model)
Expand Down Expand Up @@ -1242,27 +1206,7 @@ def cleanup_model(model, model_name_str: str):
# matmul non-determinism can exceed the float32 default of 1e-3
p1_atol = 1e-3 if dtype == torch.float32 else 5e-3

if conserve_memory:
# conserve_memory mode: capture reference logits from
# bridge.original_model (same tokenization as bridge).
try:
tokens = bridge_unprocessed.to_tokens(test_text)
with torch.no_grad():
hf_out = bridge_unprocessed.original_model(tokens)
ref_logits = hf_out.logits.detach()
add_result(
benchmark_forward_pass(
bridge_unprocessed,
test_text,
reference_logits=ref_logits,
atol=p1_atol,
)
)
del ref_logits
except Exception as e:
if verbose:
print(f"✗ Forward pass benchmark failed: {e}\n")
elif hf_saved_logits is not None:
if hf_saved_logits is not None:
# Full mode: use pre-captured HF logits (bridge only, 1.0x)
try:
add_result(
Expand Down Expand Up @@ -2028,13 +1972,6 @@ def main():
action="store_true",
help="Trust remote code for custom architectures (e.g., OpenELM)",
)
parser.add_argument(
"--conserve-memory",
action="store_true",
help="Reduce Phase 1 peak memory from 2.0x to 1.0x by using "
"bridge.original_model instead of loading a separate HF model",
)

args = parser.parse_args()

results = run_benchmark_suite(
Expand All @@ -2045,7 +1982,6 @@ def main():
enable_compatibility_mode=not args.no_compat,
verbose=not args.quiet,
trust_remote_code=args.trust_remote_code,
conserve_memory=args.conserve_memory,
)

if args.update_registry:
Expand Down
Loading
Loading