@@ -658,7 +658,6 @@ def run_benchmark_suite(
658658 test_weight_processing_individually : bool = False ,
659659 phases : list [int ] | None = None ,
660660 trust_remote_code : bool = False ,
661- conserve_memory : bool = False ,
662661 scoring_model : PreTrainedModel | None = None ,
663662 scoring_tokenizer : PreTrainedTokenizerBase | None = None ,
664663) -> List [BenchmarkResult ]:
@@ -691,11 +690,6 @@ def run_benchmark_suite(
691690 tests that check each processing flag individually (default: False)
692691 phases: Optional list of phase numbers to run (e.g., [1, 2, 3]). If None, runs all phases.
693692 trust_remote_code: Whether to trust remote code for custom architectures.
694- conserve_memory: When True, Phase 1 avoids loading a separate HF model
695- and instead uses bridge.original_model for component benchmarks and
696- forward pass comparison. This halves Phase 1 peak memory (1.0x vs 2.0x)
697- at the cost of losing the independent HF loading cross-check (~5%
698- weakening). Default is False (full dual-load for maximum test coverage).
699693 scoring_model: Optional pre-loaded GPT-2 scoring model for Phase 4. When
700694 provided with scoring_tokenizer, avoids reloading for each model in batch.
701695 scoring_tokenizer: Optional pre-loaded tokenizer for Phase 4 scoring model.
@@ -1024,24 +1018,10 @@ def cleanup_model(model, model_name_str: str):
10241018 if verbose :
10251019 print (f"⚠ Could not apply architecture patches: { patch_err } " )
10261020
1027- # ----------------------------------------------------------------
1028- # Phase 1 memory strategy (controlled by conserve_memory flag):
1029- #
1030- # conserve_memory=False (default):
1031- # Load separate HF model, capture logits to CPU, load Bridge,
1032- # run component benchmark with both models (brief 2.0x), delete
1033- # HF immediately after, forward pass uses saved logits (1.0x).
1034- #
1035- # conserve_memory=True:
1036- # Skip separate HF model entirely. Load Bridge only (1.0x
1037- # throughout). Component benchmark uses bridge.original_model
1038- # as the HF reference. Forward pass compares bridge output
1039- # against bridge.original_model logits.
1040- # ----------------------------------------------------------------
10411021 hf_saved_logits = None
10421022 hf_saved_loss = None
10431023
1044- if use_hf_reference and not conserve_memory and should_run_phase (1 ):
1024+ if use_hf_reference and should_run_phase (1 ):
10451025 try :
10461026 if verbose :
10471027 print ("Loading HuggingFace reference model..." )
@@ -1146,28 +1126,12 @@ def cleanup_model(model, model_name_str: str):
11461126 # Run Phase 1 benchmarks
11471127 if should_run_phase (1 ) and bridge_unprocessed :
11481128 if verbose :
1149- mode_label = " [conserve-memory]" if conserve_memory else ""
1150- print (f"Running Phase 1 benchmarks{ mode_label } ...\n " )
1129+ print ("Running Phase 1 benchmarks...\n " )
11511130
11521131 # Component-level benchmarks
11531132 if verbose :
11541133 print ("1. Component-Level Benchmarks" )
1155- if conserve_memory :
1156- # conserve_memory mode: use bridge.original_model as the HF
1157- # reference (no separate HF load, 1.0x peak throughout).
1158- try :
1159- component_result = benchmark_all_components (
1160- bridge_unprocessed , bridge_unprocessed .original_model
1161- )
1162- add_result (component_result )
1163- if verbose :
1164- status = "✓" if component_result .passed else "✗"
1165- print (f"{ status } { component_result .message } " )
1166- print (" (reference: bridge.original_model)\n " )
1167- except Exception as e :
1168- if verbose :
1169- print (f"✗ Component benchmark failed: { e } \n " )
1170- elif hf_model is not None :
1134+ if hf_model is not None :
11711135 # Full mode: component benchmark with independent HF model (brief 2.0x)
11721136 try :
11731137 component_result = benchmark_all_components (bridge_unprocessed , hf_model )
@@ -1242,27 +1206,7 @@ def cleanup_model(model, model_name_str: str):
12421206 # matmul non-determinism can exceed the float32 default of 1e-3
12431207 p1_atol = 1e-3 if dtype == torch .float32 else 5e-3
12441208
1245- if conserve_memory :
1246- # conserve_memory mode: capture reference logits from
1247- # bridge.original_model (same tokenization as bridge).
1248- try :
1249- tokens = bridge_unprocessed .to_tokens (test_text )
1250- with torch .no_grad ():
1251- hf_out = bridge_unprocessed .original_model (tokens )
1252- ref_logits = hf_out .logits .detach ()
1253- add_result (
1254- benchmark_forward_pass (
1255- bridge_unprocessed ,
1256- test_text ,
1257- reference_logits = ref_logits ,
1258- atol = p1_atol ,
1259- )
1260- )
1261- del ref_logits
1262- except Exception as e :
1263- if verbose :
1264- print (f"✗ Forward pass benchmark failed: { e } \n " )
1265- elif hf_saved_logits is not None :
1209+ if hf_saved_logits is not None :
12661210 # Full mode: use pre-captured HF logits (bridge only, 1.0x)
12671211 try :
12681212 add_result (
@@ -2028,13 +1972,6 @@ def main():
20281972 action = "store_true" ,
20291973 help = "Trust remote code for custom architectures (e.g., OpenELM)" ,
20301974 )
2031- parser .add_argument (
2032- "--conserve-memory" ,
2033- action = "store_true" ,
2034- help = "Reduce Phase 1 peak memory from 2.0x to 1.0x by using "
2035- "bridge.original_model instead of loading a separate HF model" ,
2036- )
2037-
20381975 args = parser .parse_args ()
20391976
20401977 results = run_benchmark_suite (
@@ -2045,7 +1982,6 @@ def main():
20451982 enable_compatibility_mode = not args .no_compat ,
20461983 verbose = not args .quiet ,
20471984 trust_remote_code = args .trust_remote_code ,
2048- conserve_memory = args .conserve_memory ,
20491985 )
20501986
20511987 if args .update_registry :
0 commit comments