diff --git a/groot/vla/model/dreamzero/base_vla.py b/groot/vla/model/dreamzero/base_vla.py
index 274b8640..7784f9b8 100644
--- a/groot/vla/model/dreamzero/base_vla.py
+++ b/groot/vla/model/dreamzero/base_vla.py
@@ -324,6 +324,10 @@ def from_pretrained_for_tuning(
         )
 
         if lora_weights_path is not None:
+            # Inject LoRA adapters first (base weights are already loaded with correct key paths)
+            if hasattr(model, 'action_head') and hasattr(model.action_head, 'inject_lora_after_loading'):
+                print("Injecting LoRA adapters before loading LoRA weights")
+                model.action_head.inject_lora_after_loading()
             print(f"Loading LoRA weights from: {lora_weights_path}")
             model.load_lora_weight(lora_weights_path)
         else: