diff --git a/groot/vla/model/dreamzero/action_head/wan_flow_matching_action_tf.py b/groot/vla/model/dreamzero/action_head/wan_flow_matching_action_tf.py index b7cd806f..1f46ea1b 100644 --- a/groot/vla/model/dreamzero/action_head/wan_flow_matching_action_tf.py +++ b/groot/vla/model/dreamzero/action_head/wan_flow_matching_action_tf.py @@ -792,7 +792,7 @@ def forward(self, backbone_output: BatchFeature, action_input: BatchFeature) -> action_loss_per_sample = torch.nn.functional.mse_loss( action_noise_pred.float(), training_target_action.float(), reduction='none' ) * action_mask # shape: [B, ...] - action_loss_per_sample = has_real_action[:, None].float() * action_loss_per_sample # apply has_real_action + action_loss_per_sample = has_real_action[:, None, None].float() * action_loss_per_sample # apply has_real_action weight_action = action_loss_per_sample.mean(dim=2) * self.scheduler.training_weight( timestep_action.flatten(0, 1), ).unflatten(0, (noise_action.shape[0], noise_action.shape[1])).to(self._device)