Skip to content

Commit f473c6e

Browse files
authored
Documentation/more model verification (#1207)
* Setup architecture adapters for the 3 Granite Architectures * Additional Model verification * CI checks
1 parent e46d47b commit f473c6e

3 files changed

Lines changed: 409 additions & 427 deletions

File tree

transformer_lens/benchmarks/main_benchmark.py

Lines changed: 4 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -658,7 +658,6 @@ def run_benchmark_suite(
658658
test_weight_processing_individually: bool = False,
659659
phases: list[int] | None = None,
660660
trust_remote_code: bool = False,
661-
conserve_memory: bool = False,
662661
scoring_model: PreTrainedModel | None = None,
663662
scoring_tokenizer: PreTrainedTokenizerBase | None = None,
664663
) -> List[BenchmarkResult]:
@@ -691,11 +690,6 @@ def run_benchmark_suite(
691690
tests that check each processing flag individually (default: False)
692691
phases: Optional list of phase numbers to run (e.g., [1, 2, 3]). If None, runs all phases.
693692
trust_remote_code: Whether to trust remote code for custom architectures.
694-
conserve_memory: When True, Phase 1 avoids loading a separate HF model
695-
and instead uses bridge.original_model for component benchmarks and
696-
forward pass comparison. This halves Phase 1 peak memory (1.0x vs 2.0x)
697-
at the cost of losing the independent HF loading cross-check (~5%
698-
weakening). Default is False (full dual-load for maximum test coverage).
699693
scoring_model: Optional pre-loaded GPT-2 scoring model for Phase 4. When
700694
provided with scoring_tokenizer, avoids reloading for each model in batch.
701695
scoring_tokenizer: Optional pre-loaded tokenizer for Phase 4 scoring model.
@@ -1024,24 +1018,10 @@ def cleanup_model(model, model_name_str: str):
10241018
if verbose:
10251019
print(f"⚠ Could not apply architecture patches: {patch_err}")
10261020

1027-
# ----------------------------------------------------------------
1028-
# Phase 1 memory strategy (controlled by conserve_memory flag):
1029-
#
1030-
# conserve_memory=False (default):
1031-
# Load separate HF model, capture logits to CPU, load Bridge,
1032-
# run component benchmark with both models (brief 2.0x), delete
1033-
# HF immediately after, forward pass uses saved logits (1.0x).
1034-
#
1035-
# conserve_memory=True:
1036-
# Skip separate HF model entirely. Load Bridge only (1.0x
1037-
# throughout). Component benchmark uses bridge.original_model
1038-
# as the HF reference. Forward pass compares bridge output
1039-
# against bridge.original_model logits.
1040-
# ----------------------------------------------------------------
10411021
hf_saved_logits = None
10421022
hf_saved_loss = None
10431023

1044-
if use_hf_reference and not conserve_memory and should_run_phase(1):
1024+
if use_hf_reference and should_run_phase(1):
10451025
try:
10461026
if verbose:
10471027
print("Loading HuggingFace reference model...")
@@ -1146,28 +1126,12 @@ def cleanup_model(model, model_name_str: str):
11461126
# Run Phase 1 benchmarks
11471127
if should_run_phase(1) and bridge_unprocessed:
11481128
if verbose:
1149-
mode_label = " [conserve-memory]" if conserve_memory else ""
1150-
print(f"Running Phase 1 benchmarks{mode_label}...\n")
1129+
print("Running Phase 1 benchmarks...\n")
11511130

11521131
# Component-level benchmarks
11531132
if verbose:
11541133
print("1. Component-Level Benchmarks")
1155-
if conserve_memory:
1156-
# conserve_memory mode: use bridge.original_model as the HF
1157-
# reference (no separate HF load, 1.0x peak throughout).
1158-
try:
1159-
component_result = benchmark_all_components(
1160-
bridge_unprocessed, bridge_unprocessed.original_model
1161-
)
1162-
add_result(component_result)
1163-
if verbose:
1164-
status = "✓" if component_result.passed else "✗"
1165-
print(f"{status} {component_result.message}")
1166-
print(" (reference: bridge.original_model)\n")
1167-
except Exception as e:
1168-
if verbose:
1169-
print(f"✗ Component benchmark failed: {e}\n")
1170-
elif hf_model is not None:
1134+
if hf_model is not None:
11711135
# Full mode: component benchmark with independent HF model (brief 2.0x)
11721136
try:
11731137
component_result = benchmark_all_components(bridge_unprocessed, hf_model)
@@ -1242,27 +1206,7 @@ def cleanup_model(model, model_name_str: str):
12421206
# matmul non-determinism can exceed the float32 default of 1e-3
12431207
p1_atol = 1e-3 if dtype == torch.float32 else 5e-3
12441208

1245-
if conserve_memory:
1246-
# conserve_memory mode: capture reference logits from
1247-
# bridge.original_model (same tokenization as bridge).
1248-
try:
1249-
tokens = bridge_unprocessed.to_tokens(test_text)
1250-
with torch.no_grad():
1251-
hf_out = bridge_unprocessed.original_model(tokens)
1252-
ref_logits = hf_out.logits.detach()
1253-
add_result(
1254-
benchmark_forward_pass(
1255-
bridge_unprocessed,
1256-
test_text,
1257-
reference_logits=ref_logits,
1258-
atol=p1_atol,
1259-
)
1260-
)
1261-
del ref_logits
1262-
except Exception as e:
1263-
if verbose:
1264-
print(f"✗ Forward pass benchmark failed: {e}\n")
1265-
elif hf_saved_logits is not None:
1209+
if hf_saved_logits is not None:
12661210
# Full mode: use pre-captured HF logits (bridge only, 1.0x)
12671211
try:
12681212
add_result(
@@ -2028,13 +1972,6 @@ def main():
20281972
action="store_true",
20291973
help="Trust remote code for custom architectures (e.g., OpenELM)",
20301974
)
2031-
parser.add_argument(
2032-
"--conserve-memory",
2033-
action="store_true",
2034-
help="Reduce Phase 1 peak memory from 2.0x to 1.0x by using "
2035-
"bridge.original_model instead of loading a separate HF model",
2036-
)
2037-
20381975
args = parser.parse_args()
20391976

20401977
results = run_benchmark_suite(
@@ -2045,7 +1982,6 @@ def main():
20451982
enable_compatibility_mode=not args.no_compat,
20461983
verbose=not args.quiet,
20471984
trust_remote_code=args.trust_remote_code,
2048-
conserve_memory=args.conserve_memory,
20491985
)
20501986

20511987
if args.update_registry:

0 commit comments

Comments (0)