diff --git a/tests/test_shadow_peft.py b/tests/test_shadow_peft.py index ffae772..7ae5ece 100644 --- a/tests/test_shadow_peft.py +++ b/tests/test_shadow_peft.py @@ -87,6 +87,7 @@ def test_save_load_roundtrip_matches_outputs(tmp_path: Path): torch.manual_seed(0) base1 = _tiny_llama(num_layers=4) base1.eval() + base1_state = {k: v.detach().clone() for k, v in base1.state_dict().items()} cfg = ShadowConfig( num_shadow_layers=1, @@ -107,7 +108,7 @@ def test_save_load_roundtrip_matches_outputs(tmp_path: Path): # Create a new base with identical weights. base2 = _tiny_llama(num_layers=4) - base2.load_state_dict(base1.state_dict()) + base2.load_state_dict(base1_state) base2.eval() m2 = ShadowPeftModel.from_pretrained(base2, save_dir, is_trainable=False) @@ -217,5 +218,3 @@ def test_implicit_shadow_model_uses_shadow_intermediate_size(): ) m = get_shadow_model(base, cfg) assert m.shadow_model.config.intermediate_size == 12 - -