
ValueError: Shapes cannot be broadcast during Qwen3-VL fine-tuning due to attention_mask/input_ids shape mismatch #652

@harrylawson37

Description


Environment:

  • MLX-VLM version: 0.3.9
  • Model: Qwen3-VL
  • Platform: macOS (Apple Silicon)
  • Python version: 3.12

Description:
When fine-tuning Qwen3-VL models with mlx_vlm.lora, training fails with a shape mismatch between the attention_mask and input_ids tensors. The error occurs because input_ids is sliced to drop the last token (standard for causal LM training) while attention_mask keeps its original length, causing a broadcast error in the model's RoPE position calculation.

Error Message:

ValueError: [broadcast_shapes] Shapes (2196) and (2195) cannot be broadcast.

Full Stack Trace:

File "/opt/homebrew/Caskroom/miniconda/base/envs/image-quality/lib/python3.12/site-packages/mlx_vlm/lora.py", line 110, in main
  loss = trainer.train_step(
File "/opt/homebrew/Caskroom/miniconda/base/envs/image-quality/lib/python3.12/site-packages/mlx_vlm/trainer/trainer.py", line 269, in train_step
  loss, grads = loss_and_grad_fn(self.model, batch)
File "/opt/homebrew/Caskroom/miniconda/base/envs/image-quality/lib/python3.12/site-packages/mlx_vlm/trainer/trainer.py", line 234, in loss_fn
  outputs = model(input_ids, pixel_values, attention_mask, **kwargs)
File "/opt/homebrew/Caskroom/miniconda/base/envs/image-quality/lib/python3.12/site-packages/mlx_vlm/models/qwen3_vl/qwen3_vl.py", line 133, in __call__
  logits = self.language_model(
File "/opt/homebrew/Caskroom/miniconda/base/envs/image-quality/lib/python3.12/site-packages/mlx_vlm/models/qwen3_vl/language.py", line 571, in __call__
  position_ids, rope_deltas = self.get_rope_index(
File "/opt/homebrew/Caskroom/miniconda/base/envs/image-quality/lib/python3.12/site-packages/mlx_vlm/models/qwen3_vl/language.py", line 385, in get_rope_index
  input_ids = mx.where(

Root Cause:
In trainer.py line 224, input_ids is sliced with input_ids = input_ids[:, :-1] (dropping the last token for causal LM training), but attention_mask is not sliced to match, so the two tensors have different sequence lengths by the time they reach the model's forward pass.
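The mismatch can be reproduced outside the model. A minimal NumPy sketch (NumPy's broadcasting rules are analogous to MLX's here; the shapes are the illustrative ones from the report):

```python
import numpy as np

# Illustrative shapes from the report: sequence length 2196 before slicing.
input_ids = np.zeros((1, 2196), dtype=np.int64)
attention_mask = np.ones((1, 2196), dtype=np.int64)

# trainer.py slices input_ids for causal LM training...
input_ids = input_ids[:, :-1]  # shape (1, 2195)
# ...but attention_mask keeps its full length, (1, 2196).

try:
    # get_rope_index combines the two tensors roughly like this:
    np.where(attention_mask == 1, input_ids, 0)
    broadcast_failed = False
except ValueError as e:
    broadcast_failed = True
    print("broadcast error:", e)
```

Any elementwise combination of the two tensors (where, multiply, add) fails the same way, which is why the error surfaces deep inside get_rope_index rather than at the slicing site.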

Current Workaround:
Adding this line after input_ids slicing resolves the issue:

attention_mask = attention_mask[:, :-1]  # Slice attention_mask to match input_ids
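For context, the workaround amounts to slicing both tensors together. A hypothetical helper showing the intended alignment (variable names mirror the stack trace, not the exact mlx-vlm source):

```python
import numpy as np

def align_for_causal_lm(input_ids, attention_mask):
    """Slice both tensors consistently for next-token prediction.

    Hypothetical sketch: whenever input_ids loses its last token,
    attention_mask must lose it too, so downstream broadcasts
    (e.g. in get_rope_index) see matching sequence lengths.
    """
    targets = input_ids[:, 1:]               # shifted next-token targets
    input_ids = input_ids[:, :-1]            # drop last token (causal LM)
    attention_mask = attention_mask[:, :-1]  # keep the mask aligned
    return input_ids, attention_mask, targets

# Usage with toy shapes:
ids = np.arange(10).reshape(1, 10)
mask = np.ones((1, 10), dtype=np.int64)
inputs, mask, targets = align_for_causal_lm(ids, mask)
```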

Expected Behavior:
The trainer should automatically handle attention_mask alignment with input_ids, or the prepare_inputs function should ensure consistent tensor shapes.

Suggested Fix:
Either:

  1. Update trainer.py to automatically align attention_mask with input_ids after slicing
  2. Update the Qwen3-VL model's prepare_inputs to handle this alignment
  3. Add proper shape validation with informative error messages
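Option 3 could be a small guard near the model call. A hypothetical sketch (not the library's API; check_mask_alignment is an invented name):

```python
import numpy as np

def check_mask_alignment(input_ids, attention_mask):
    """Fail fast with an informative message instead of a cryptic
    broadcast error deep inside get_rope_index."""
    if attention_mask is not None and attention_mask.shape[-1] != input_ids.shape[-1]:
        raise ValueError(
            f"attention_mask length {attention_mask.shape[-1]} does not match "
            f"input_ids length {input_ids.shape[-1]}; if input_ids was sliced "
            "for causal LM training, slice attention_mask the same way."
        )

# Matching lengths pass silently; mismatched lengths raise immediately:
check_mask_alignment(np.zeros((1, 5)), np.ones((1, 5)))
```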

Additional Context:
Unsure whether this issue is specific to the Qwen3-VL model. Let me know if this is a user error or if you need more info; this is my first bug report.
