From 5da7826795d5bf0afa2a67777a41fe84d6926af0 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Thu, 18 Jun 2026 14:50:53 +0300 Subject: [PATCH 1/2] Fix MLX IncSubtensor with slice index (gradient of basic slicing) `inc_subtensor`/`set_subtensor` with a slice index failed on the MLX backend, raising `ValueError: Slice indices must be integers or None.` This broke the gradient of any basic slice (`x[a:b]`, `x[::2]`, ...), since the gradient of "read a slice" is an `IncSubtensor` that increments that slice of a zero tensor, and the slice bounds arrived as `mx.array` scalars, which MLX slices reject. The forward `mlx_funcify_Subtensor` already coerces index inputs with `[int(element) for element in ilists]`; the `IncSubtensor` path did not. Coerce integer index inputs to Python ints in `mlx_funcify_IncSubtensor`, mirroring the forward op (and matching C/Numba/JAX/PyTorch semantics). `AdvancedIncSubtensor1` (whose index is a vector and must not be int-coerced) is moved off the shared `IncSubtensor` dispatcher onto `mlx_funcify_AdvancedIncSubtensor`, mirroring PyTorch's canonical basic-vs-advanced grouping. Disclosure: implemented with AI assistance. 
Co-authored-by: Cursor --- pytensor/link/mlx/dispatch/subtensor.py | 10 +++++---- tests/link/mlx/test_subtensor.py | 28 ++++++++++++++++++++++++- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/pytensor/link/mlx/dispatch/subtensor.py b/pytensor/link/mlx/dispatch/subtensor.py index 42a7bfdd80..2e6e2a3af1 100644 --- a/pytensor/link/mlx/dispatch/subtensor.py +++ b/pytensor/link/mlx/dispatch/subtensor.py @@ -40,9 +40,8 @@ def advanced_subtensor(x, *ilists): @mlx_funcify.register(IncSubtensor) -@mlx_funcify.register(AdvancedIncSubtensor1) def mlx_funcify_IncSubtensor(op, node, **kwargs): - if getattr(op, "set_instead_of_inc", False): + if op.set_instead_of_inc: def mlx_fn(x, indices, y): if not op.inplace: @@ -59,7 +58,9 @@ def mlx_fn(x, indices, y): return x def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=op.idx_list): - indices = indices_from_subtensor(ilist, idx_list) + # Coerce integer index inputs to Python ints (e.g. slice bounds), as + # MLX slices reject array-typed bounds. Mirrors mlx_funcify_Subtensor. 
+ indices = indices_from_subtensor([int(element) for element in ilist], idx_list) if len(indices) == 1: indices = indices[0] @@ -69,8 +70,9 @@ def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=op.idx_list): @mlx_funcify.register(AdvancedIncSubtensor) +@mlx_funcify.register(AdvancedIncSubtensor1) def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs): - if getattr(op, "set_instead_of_inc", False): + if op.set_instead_of_inc: def mlx_fn(x, indices, y): if not op.inplace: diff --git a/tests/link/mlx/test_subtensor.py b/tests/link/mlx/test_subtensor.py index e24607b08e..89e67216f6 100644 --- a/tests/link/mlx/test_subtensor.py +++ b/tests/link/mlx/test_subtensor.py @@ -211,7 +211,6 @@ def test_mlx_subtensor_edge_cases(): compare_mlx_and_py([], [out_pt], []) -@pytest.mark.xfail(reason="MLX indexing with tuples not yet supported") def test_mlx_subtensor_with_variables(): """Test subtensor operations with PyTensor variables as inputs.""" # Test with variable arrays (not constants) @@ -224,3 +223,30 @@ def test_mlx_subtensor_with_variables(): # Set operation with variables out_pt = pt_subtensor.set_subtensor(x_pt[0, :2], y_pt) compare_mlx_and_py([x_pt, y_pt], [out_pt], [x_np, y_np]) + + +def test_mlx_IncSubtensor_slice_grad(): + """Gradient of a basic slice lowers to an ``IncSubtensor`` with slice bounds + passed as (array) inputs; these must be coerced to Python ints for MLX.""" + x_pt = pt.vector("x", dtype="float32") + x_np = np.arange(6, dtype=np.float32) + + # Contiguous and strided (RoPE-style) slices both exercise the slice path. 
+ for sl in (x_pt[0:3], x_pt[0::2]): + g = pt.grad((sl**2).sum(), x_pt) + assert isinstance(g.owner.op, pt_subtensor.IncSubtensor) + compare_mlx_and_py([x_pt], [g], [x_np]) + + +@pytest.mark.xfail( + reason="Upstream mx.compile bug: a negative-strided in-place scatter whose " + "update derives from a negative-strided view returns wrong values " + "(correct when eager / use_compile=False).", + strict=True, +) +def test_mlx_IncSubtensor_negative_step_slice_grad(): + x_pt = pt.vector("x", dtype="float32") + x_np = np.arange(6, dtype=np.float32) + g = pt.grad((x_pt[::-1] ** 2).sum(), x_pt) + assert isinstance(g.owner.op, pt_subtensor.IncSubtensor) + compare_mlx_and_py([x_pt], [g], [x_np]) From 7cf890439b30c512d5e356df620f9a5bcb368d90 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Thu, 18 Jun 2026 14:57:49 +0300 Subject: [PATCH 2/2] Reference upstream MLX issue in negative-step xfail Link the documented negative-strided slice gradient limitation to the upstream report ml-explore/mlx#3716. Co-authored-by: Cursor --- tests/link/mlx/test_subtensor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/link/mlx/test_subtensor.py b/tests/link/mlx/test_subtensor.py index 89e67216f6..2ca0374961 100644 --- a/tests/link/mlx/test_subtensor.py +++ b/tests/link/mlx/test_subtensor.py @@ -239,9 +239,9 @@ def test_mlx_IncSubtensor_slice_grad(): @pytest.mark.xfail( - reason="Upstream mx.compile bug: a negative-strided in-place scatter whose " - "update derives from a negative-strided view returns wrong values " - "(correct when eager / use_compile=False).", + reason="Upstream mx.compile bug (ml-explore/mlx#3716): assigning an " + "elementwise expression to a negative-strided slice returns wrong values " + "under mx.compile (correct when eager / use_compile=False).", strict=True, ) def test_mlx_IncSubtensor_negative_step_slice_grad():