From 5da7826795d5bf0afa2a67777a41fe84d6926af0 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Thu, 18 Jun 2026 14:50:53 +0300 Subject: [PATCH 1/3] Fix MLX IncSubtensor with slice index (gradient of basic slicing) `inc_subtensor`/`set_subtensor` with a slice index failed on the MLX backend with `ValueError: Slice indices must be integers or None.`. This broke the gradient of any basic slicing (`x[a:b]`, `x[::2]`, ...), since the gradient of "read a slice" is an `IncSubtensor` that increments that slice of a zero tensor, and the slice bounds arrived as `mx.array` scalars which MLX slices reject. The forward `mlx_funcify_Subtensor` already coerces index inputs with `[int(element) for element in ilists]`; the `IncSubtensor` path did not. Coerce integer index inputs to Python ints in `mlx_funcify_IncSubtensor`, mirroring the forward op (and matching C/Numba/JAX/PyTorch semantics). `AdvancedIncSubtensor1` (a vector index that must not be int-coerced) is moved off the shared `IncSubtensor` dispatcher onto `mlx_funcify_AdvancedIncSubtensor`, mirroring PyTorch's canonical basic-vs-advanced grouping. Disclosure: implemented with AI assistance. Co-authored-by: Cursor --- pytensor/link/mlx/dispatch/subtensor.py | 10 +++++---- tests/link/mlx/test_subtensor.py | 28 ++++++++++++++++++++++++- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/pytensor/link/mlx/dispatch/subtensor.py b/pytensor/link/mlx/dispatch/subtensor.py index 42a7bfdd80..2e6e2a3af1 100644 --- a/pytensor/link/mlx/dispatch/subtensor.py +++ b/pytensor/link/mlx/dispatch/subtensor.py @@ -40,9 +40,8 @@ def advanced_subtensor(x, *ilists): @mlx_funcify.register(IncSubtensor) -@mlx_funcify.register(AdvancedIncSubtensor1) def mlx_funcify_IncSubtensor(op, node, **kwargs): - if getattr(op, "set_instead_of_inc", False): + if op.set_instead_of_inc: def mlx_fn(x, indices, y): if not op.inplace: @@ -59,7 +58,9 @@ def mlx_fn(x, indices, y): return x def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=op.idx_list): - indices = indices_from_subtensor(ilist, idx_list) + # Coerce integer index inputs to Python ints (e.g. slice bounds), as + # MLX slices reject array-typed bounds. Mirrors mlx_funcify_Subtensor. + indices = indices_from_subtensor([int(element) for element in ilist], idx_list) if len(indices) == 1: indices = indices[0] @@ -69,8 +70,9 @@ def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=op.idx_list): @mlx_funcify.register(AdvancedIncSubtensor) +@mlx_funcify.register(AdvancedIncSubtensor1) def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs): - if getattr(op, "set_instead_of_inc", False): + if op.set_instead_of_inc: def mlx_fn(x, indices, y): if not op.inplace: diff --git a/tests/link/mlx/test_subtensor.py b/tests/link/mlx/test_subtensor.py index e24607b08e..89e67216f6 100644 --- a/tests/link/mlx/test_subtensor.py +++ b/tests/link/mlx/test_subtensor.py @@ -211,7 +211,6 @@ def test_mlx_subtensor_edge_cases(): compare_mlx_and_py([], [out_pt], []) -@pytest.mark.xfail(reason="MLX indexing with tuples not yet supported") def test_mlx_subtensor_with_variables(): """Test subtensor operations with PyTensor variables as inputs.""" # Test with variable arrays (not constants) @@ -224,3 +223,30 @@ def test_mlx_subtensor_with_variables(): # Set operation with variables out_pt = pt_subtensor.set_subtensor(x_pt[0, :2], y_pt) compare_mlx_and_py([x_pt, y_pt], [out_pt], [x_np, y_np]) + + +def test_mlx_IncSubtensor_slice_grad(): + """Gradient of a basic slice lowers to an ``IncSubtensor`` with slice bounds + passed as (array) inputs; these must be coerced to Python ints for MLX.""" + x_pt = pt.vector("x", dtype="float32") + x_np = np.arange(6, dtype=np.float32) + + # Contiguous and strided (RoPE-style) slices both exercise the slice path. + for sl in (x_pt[0:3], x_pt[0::2]): + g = pt.grad((sl**2).sum(), x_pt) + assert isinstance(g.owner.op, pt_subtensor.IncSubtensor) + compare_mlx_and_py([x_pt], [g], [x_np]) + + +@pytest.mark.xfail( + reason="Upstream mx.compile bug: a negative-strided in-place scatter whose " + "update derives from a negative-strided view returns wrong values " + "(correct when eager / use_compile=False).", + strict=True, +) +def test_mlx_IncSubtensor_negative_step_slice_grad(): + x_pt = pt.vector("x", dtype="float32") + x_np = np.arange(6, dtype=np.float32) + g = pt.grad((x_pt[::-1] ** 2).sum(), x_pt) + assert isinstance(g.owner.op, pt_subtensor.IncSubtensor) + compare_mlx_and_py([x_pt], [g], [x_np]) From 7cf890439b30c512d5e356df620f9a5bcb368d90 Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Thu, 18 Jun 2026 14:57:49 +0300 Subject: [PATCH 2/3] Reference upstream MLX issue in negative-step xfail Link the documented negative-strided slice gradient limitation to the upstream report ml-explore/mlx#3716. Co-authored-by: Cursor --- tests/link/mlx/test_subtensor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/link/mlx/test_subtensor.py b/tests/link/mlx/test_subtensor.py index 89e67216f6..2ca0374961 100644 --- a/tests/link/mlx/test_subtensor.py +++ b/tests/link/mlx/test_subtensor.py @@ -239,9 +239,9 @@ def test_mlx_IncSubtensor_slice_grad(): @pytest.mark.xfail( - reason="Upstream mx.compile bug: a negative-strided in-place scatter whose " - "update derives from a negative-strided view returns wrong values " - "(correct when eager / use_compile=False).", + reason="Upstream mx.compile bug (ml-explore/mlx#3716): assigning an " + "elementwise expression to a negative-strided slice returns wrong values " + "under mx.compile (correct when eager / use_compile=False).", strict=True, ) def test_mlx_IncSubtensor_negative_step_slice_grad(): From 91d034abc6f323f28c52802c7229b2881cb6923f Mon Sep 17 00:00:00 2001 From: Carlos Trujillo <59846724+cetagostini@users.noreply.github.com> Date: Thu, 18 Jun 2026 16:15:52 +0300 Subject: [PATCH 3/3] Accumulate duplicate indices in MLX AdvancedIncSubtensor MLX increment with duplicate integer indices dropped repeated contributions: `x[indices] += y` desugars to gather-add-scatter and writes each destination once. Gradients of advanced indexing (e.g. embedding lookups with repeated token ids) are inc with duplicate indices, so MLX silently computed wrong values. Use MLX's functional scatter-add `x.at[indices].add(y)` for the inc path, mirroring JAX and matching the reference `np.add.at`-based `perform`. The `ignore_duplicates=True` mode keeps numpy write-once `x[idx] += y` semantics (matching PyTorch/Numba/reference), and the set path is unchanged. Co-authored-by: Cursor --- pytensor/link/mlx/dispatch/subtensor.py | 16 +++++- tests/link/mlx/test_subtensor.py | 66 +++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 2 deletions(-) diff --git a/pytensor/link/mlx/dispatch/subtensor.py b/pytensor/link/mlx/dispatch/subtensor.py index 2e6e2a3af1..785917f77e 100644 --- a/pytensor/link/mlx/dispatch/subtensor.py +++ b/pytensor/link/mlx/dispatch/subtensor.py @@ -80,14 +80,26 @@ def mlx_fn(x, indices, y): x[indices] = y return x - else: - + elif getattr(op, "ignore_duplicates", False): + # `ignore_duplicates` requests numpy's write-once `x[idx] += y` + # semantics (duplicate indices are not accumulated), matching the + # reference `perform` and the PyTorch/Numba backends. def mlx_fn(x, indices, y): if not op.inplace: x = deepcopy(x) x[indices] += y return x + else: + # Accumulate duplicate indices (`np.add.at` semantics) via MLX's + # functional scatter-add, mirroring JAX's `x.at[indices].add(y)`. + # Plain `x[indices] += y` writes each destination once, dropping + # repeated-index contributions (e.g. gradients of embedding lookups). + # `AdvancedIncSubtensor1` has no `ignore_duplicates` flag and always + # accumulates, like its `np.add.at`-based `perform`. + def mlx_fn(x, indices, y): + return x.at[indices].add(y) + def advancedincsubtensor(x, y, *ilist, mlx_fn=mlx_fn): return mlx_fn(x, ilist, y) diff --git a/tests/link/mlx/test_subtensor.py b/tests/link/mlx/test_subtensor.py index 2ca0374961..c23c8baea8 100644 --- a/tests/link/mlx/test_subtensor.py +++ b/tests/link/mlx/test_subtensor.py @@ -250,3 +250,69 @@ def test_mlx_IncSubtensor_negative_step_slice_grad(): g = pt.grad((x_pt[::-1] ** 2).sum(), x_pt) assert isinstance(g.owner.op, pt_subtensor.IncSubtensor) compare_mlx_and_py([x_pt], [g], [x_np]) + + +@pytest.mark.parametrize( + "func", + (pt_subtensor.advanced_inc_subtensor1, pt_subtensor.advanced_set_subtensor1), + ids=("inc", "set"), +) +def test_mlx_AdvancedIncSubtensor1_duplicate_indices(func): + """Duplicate indices must accumulate for inc (``np.add.at`` semantics). + + Gradients of advanced indexing (e.g. embedding lookups with repeated token + ids) produce inc with duplicate indices; MLX must sum all contributions + rather than writing each destination once. + """ + x = pt.vector("x", dtype="float32") + y = pt.vector("y", dtype="float32") + idxs = np.array([0, 0, 0, 1], dtype=np.int64) + out = func(x, y, idxs) + assert isinstance(out.owner.op, pt_subtensor.AdvancedIncSubtensor1) + + x_np = np.zeros(3, dtype=np.float32) + y_np = np.ones(4, dtype=np.float32) + compare_mlx_and_py([x, y], [out], [x_np, y_np]) + + +def test_mlx_AdvancedIncSubtensor1_duplicate_indices_edge_cases(): + """Duplicate accumulation with negative indices and a scalar (broadcast) ``y``.""" + x = pt.vector("x", dtype="int32") + y = pt.scalar("y", dtype="int32") + idxs = np.array([-1, -1, 0, -1], dtype=np.int64) + out = pt_subtensor.advanced_inc_subtensor1(x, y, idxs) + assert isinstance(out.owner.op, pt_subtensor.AdvancedIncSubtensor1) + + compare_mlx_and_py([x, y], [out], [np.zeros(3, dtype=np.int32), np.int32(2)]) + + +def test_mlx_AdvancedIncSubtensor_duplicate_indices(): + """``AdvancedIncSubtensor`` with duplicate indices accumulates like ``np.add.at``.""" + x = pt.matrix("x", dtype="float32") + y = pt.vector("y", dtype="float32") + rows = np.array([0, 0, 1], dtype=np.int64) + cols = np.array([1, 1, 2], dtype=np.int64) + out = pt_subtensor.inc_subtensor(x[rows, cols], y) + assert isinstance(out.owner.op, pt_subtensor.AdvancedIncSubtensor) + assert not out.owner.op.set_instead_of_inc + assert not out.owner.op.ignore_duplicates + + x_np = np.zeros((3, 3), dtype=np.float32) + y_np = np.ones(3, dtype=np.float32) + compare_mlx_and_py([x, y], [out], [x_np, y_np]) + + +def test_mlx_AdvancedIncSubtensor_ignore_duplicates(): + """``ignore_duplicates=True`` requests write-once (numpy ``x[idx] += y``). + + Duplicate indices must NOT be accumulated in this mode, matching the + reference ``perform`` and the PyTorch/Numba backends. + """ + x = pt.vector("x", dtype="float32") + out = pt_subtensor.inc_subtensor( + x[[0, 1, 0]], np.float32(5.0), ignore_duplicates=True + ) + assert isinstance(out.owner.op, pt_subtensor.AdvancedIncSubtensor) + assert out.owner.op.ignore_duplicates + + compare_mlx_and_py([x], [out], [np.zeros(3, dtype=np.float32)])