Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions coremltools/converters/mil/frontend/torch/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6457,15 +6457,30 @@ def _check_args(tensor_inputs, indexing) -> None:
assert isinstance(tensor_inputs, (list, tuple))
if len(tensor_inputs) < 2:
raise ValueError("Requires >= 2 tensor inputs.")
if any(tensor_input.rank > 1 for tensor_input in tensor_inputs):
raise ValueError("meshgrid received non-1d tensor.")

if indexing not in ("ij", "xy"):
raise ValueError(f"indexing mode {indexing} not supported")

def _flatten_inputs(tensor_inputs):
    """Reshape any input with rank > 1 down to a 1D tensor.

    PyTorch JIT tracing can hand meshgrid tensors shaped (N, 1) instead
    of (N,) (e.g. the output of torch.linspace). meshgrid expects 1D
    inputs, so higher-rank tensors are flattened with an mb.reshape.
    Rank-0 and rank-1 inputs pass through untouched.
    """
    def _to_1d(index, tensor):
        # Only tensors above rank 1 need flattening.
        if tensor.rank <= 1:
            return tensor
        return mb.reshape(
            x=tensor, shape=[-1], name=node.name + "_flatten_" + str(index)
        )

    return [_to_1d(i, t) for i, t in enumerate(tensor_inputs)]

tensor_inputs, indexing = _parse_positional_args(context, node)
indexing = _parse_keyword_args(context, node, indexing)
_check_args(tensor_inputs, indexing)
tensor_inputs = _flatten_inputs(tensor_inputs)

result_symbolic_shape = [tensor_input.shape[0] for tensor_input in tensor_inputs]
result_shape = _utils.maybe_replace_symbols_with_source_tensor_shape_variables(
Expand Down
48 changes: 48 additions & 0 deletions coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -11312,6 +11312,54 @@ def forward(self, x, y, z):
compute_unit=compute_unit,
)

@pytest.mark.parametrize(
    "compute_unit, backend, indexing",
    itertools.product(compute_units, backends, ["ij", "xy"]),
)
def test_meshgrid_non_1d_inputs(self, compute_unit, backend, indexing):
    """Test meshgrid where inputs become non-1D during conversion.

    When a 1D tensor is divided by a 0D scalar (e.g. from a dynamic shape),
    the MIL converter may produce a higher-rank result. The meshgrid converter
    should flatten these to 1D before processing rather than raising an error.

    This pattern occurs in deformable attention (DINO/Deformable-DETR/RF-DETR)
    where coordinate grids are created via:
        grid_y = torch.linspace(..., steps=h) / h
        grid_x = torch.linspace(..., steps=w) / w
        grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing='ij')
    """

    class MeshgridModel(nn.Module):
        def forward(self, feat):
            # Spatial extent comes from the (dynamic) feature-map shape.
            height, width = feat.shape[2], feat.shape[3]
            ys = torch.linspace(
                0.5, height - 0.5, steps=height, dtype=feat.dtype, device=feat.device
            )
            xs = torch.linspace(
                0.5, width - 0.5, steps=width, dtype=feat.dtype, device=feat.device
            )
            # Dividing by the 0D scalar is what can bump the rank during tracing.
            grid_y, grid_x = torch.meshgrid(
                ys / height, xs / width, indexing=indexing
            )
            return torch.stack([grid_x, grid_y], dim=-1)

    model = MeshgridModel().eval()
    inputs = (torch.randn(1, 64, 4, 6),)
    expected_results = model(*inputs)

    self.run_compare_torch(
        inputs,
        model,
        expected_results,
        input_as_shape=False,
        frontend=TorchFrontend.TORCHSCRIPT,
        backend=backend,
        compute_unit=compute_unit,
    )


class TestAddmm(TorchBaseTest):
@pytest.mark.parametrize(
Expand Down