From 9c93b505906f4335071df3c3716e8955594e3c54 Mon Sep 17 00:00:00 2001 From: Tong Lu Date: Tue, 24 Mar 2026 14:05:03 +0800 Subject: [PATCH] Fix action_loss_per_sample shape for has_real_action --- .../model/dreamzero/action_head/wan_flow_matching_action_tf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/groot/vla/model/dreamzero/action_head/wan_flow_matching_action_tf.py b/groot/vla/model/dreamzero/action_head/wan_flow_matching_action_tf.py index b7cd806f..1f46ea1b 100644 --- a/groot/vla/model/dreamzero/action_head/wan_flow_matching_action_tf.py +++ b/groot/vla/model/dreamzero/action_head/wan_flow_matching_action_tf.py @@ -792,7 +792,7 @@ def forward(self, backbone_output: BatchFeature, action_input: BatchFeature) -> action_loss_per_sample = torch.nn.functional.mse_loss( action_noise_pred.float(), training_target_action.float(), reduction='none' ) * action_mask # shape: [B, ...] - action_loss_per_sample = has_real_action[:, None].float() * action_loss_per_sample # apply has_real_action + action_loss_per_sample = has_real_action[:, None, None].float() * action_loss_per_sample # apply has_real_action weight_action = action_loss_per_sample.mean(dim=2) * self.scheduler.training_weight( timestep_action.flatten(0, 1), ).unflatten(0, (noise_action.shape[0], noise_action.shape[1])).to(self._device)