diff --git a/pytensor/link/mlx/dispatch/subtensor.py b/pytensor/link/mlx/dispatch/subtensor.py index 42a7bfdd80..785917f77e 100644 --- a/pytensor/link/mlx/dispatch/subtensor.py +++ b/pytensor/link/mlx/dispatch/subtensor.py @@ -40,9 +40,8 @@ def advanced_subtensor(x, *ilists): @mlx_funcify.register(IncSubtensor) -@mlx_funcify.register(AdvancedIncSubtensor1) def mlx_funcify_IncSubtensor(op, node, **kwargs): - if getattr(op, "set_instead_of_inc", False): + if op.set_instead_of_inc: def mlx_fn(x, indices, y): if not op.inplace: @@ -59,7 +58,9 @@ def mlx_fn(x, indices, y): return x def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=op.idx_list): - indices = indices_from_subtensor(ilist, idx_list) + # Coerce integer index inputs to Python ints (e.g. slice bounds), as + # MLX slices reject array-typed bounds. Mirrors mlx_funcify_Subtensor. + indices = indices_from_subtensor([int(element) for element in ilist], idx_list) if len(indices) == 1: indices = indices[0] @@ -69,8 +70,9 @@ def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=op.idx_list): @mlx_funcify.register(AdvancedIncSubtensor) +@mlx_funcify.register(AdvancedIncSubtensor1) def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs): - if getattr(op, "set_instead_of_inc", False): + if op.set_instead_of_inc: def mlx_fn(x, indices, y): if not op.inplace: @@ -78,14 +80,26 @@ def mlx_fn(x, indices, y): x[indices] = y return x - else: - + elif getattr(op, "ignore_duplicates", False): + # `ignore_duplicates` requests numpy's write-once `x[idx] += y` + # semantics (duplicate indices are not accumulated), matching the + # reference `perform` and the PyTorch/Numba backends. def mlx_fn(x, indices, y): if not op.inplace: x = deepcopy(x) x[indices] += y return x + else: + # Accumulate duplicate indices (`np.add.at` semantics) via MLX's + # functional scatter-add, mirroring JAX's `x.at[indices].add(y)`. + # Plain `x[indices] += y` writes each destination once, dropping + # repeated-index contributions (e.g. gradients of embedding lookups). + # `AdvancedIncSubtensor1` has no `ignore_duplicates` flag and always + # accumulates, like its `np.add.at`-based `perform`. + def mlx_fn(x, indices, y): + return x.at[indices].add(y) + def advancedincsubtensor(x, y, *ilist, mlx_fn=mlx_fn): return mlx_fn(x, ilist, y) diff --git a/tests/link/mlx/test_subtensor.py b/tests/link/mlx/test_subtensor.py index e24607b08e..c23c8baea8 100644 --- a/tests/link/mlx/test_subtensor.py +++ b/tests/link/mlx/test_subtensor.py @@ -211,7 +211,6 @@ def test_mlx_subtensor_edge_cases(): compare_mlx_and_py([], [out_pt], []) -@pytest.mark.xfail(reason="MLX indexing with tuples not yet supported") def test_mlx_subtensor_with_variables(): """Test subtensor operations with PyTensor variables as inputs.""" # Test with variable arrays (not constants) @@ -224,3 +223,96 @@ def test_mlx_subtensor_with_variables(): # Set operation with variables out_pt = pt_subtensor.set_subtensor(x_pt[0, :2], y_pt) compare_mlx_and_py([x_pt, y_pt], [out_pt], [x_np, y_np]) + + +def test_mlx_IncSubtensor_slice_grad(): + """Gradient of a basic slice lowers to an ``IncSubtensor`` with slice bounds + passed as (array) inputs; these must be coerced to Python ints for MLX.""" + x_pt = pt.vector("x", dtype="float32") + x_np = np.arange(6, dtype=np.float32) + + # Contiguous and strided (RoPE-style) slices both exercise the slice path. + for sl in (x_pt[0:3], x_pt[0::2]): + g = pt.grad((sl**2).sum(), x_pt) + assert isinstance(g.owner.op, pt_subtensor.IncSubtensor) + compare_mlx_and_py([x_pt], [g], [x_np]) + + +@pytest.mark.xfail( + reason="Upstream mx.compile bug (ml-explore/mlx#3716): assigning an " + "elementwise expression to a negative-strided slice returns wrong values " + "under mx.compile (correct when eager / use_compile=False).", + strict=True, +) +def test_mlx_IncSubtensor_negative_step_slice_grad(): + x_pt = pt.vector("x", dtype="float32") + x_np = np.arange(6, dtype=np.float32) + g = pt.grad((x_pt[::-1] ** 2).sum(), x_pt) + assert isinstance(g.owner.op, pt_subtensor.IncSubtensor) + compare_mlx_and_py([x_pt], [g], [x_np]) + + +@pytest.mark.parametrize( + "func", + (pt_subtensor.advanced_inc_subtensor1, pt_subtensor.advanced_set_subtensor1), + ids=("inc", "set"), +) +def test_mlx_AdvancedIncSubtensor1_duplicate_indices(func): + """Duplicate indices must accumulate for inc (``np.add.at`` semantics). + + Gradients of advanced indexing (e.g. embedding lookups with repeated token + ids) produce inc with duplicate indices; MLX must sum all contributions + rather than writing each destination once. + """ + x = pt.vector("x", dtype="float32") + y = pt.vector("y", dtype="float32") + idxs = np.array([0, 0, 0, 1], dtype=np.int64) + out = func(x, y, idxs) + assert isinstance(out.owner.op, pt_subtensor.AdvancedIncSubtensor1) + + x_np = np.zeros(3, dtype=np.float32) + y_np = np.ones(4, dtype=np.float32) + compare_mlx_and_py([x, y], [out], [x_np, y_np]) + + +def test_mlx_AdvancedIncSubtensor1_duplicate_indices_edge_cases(): + """Duplicate accumulation with negative indices and a scalar (broadcast) ``y``.""" + x = pt.vector("x", dtype="int32") + y = pt.scalar("y", dtype="int32") + idxs = np.array([-1, -1, 0, -1], dtype=np.int64) + out = pt_subtensor.advanced_inc_subtensor1(x, y, idxs) + assert isinstance(out.owner.op, pt_subtensor.AdvancedIncSubtensor1) + + compare_mlx_and_py([x, y], [out], [np.zeros(3, dtype=np.int32), np.int32(2)]) + + +def test_mlx_AdvancedIncSubtensor_duplicate_indices(): + """``AdvancedIncSubtensor`` with duplicate indices accumulates like ``np.add.at``.""" + x = pt.matrix("x", dtype="float32") + y = pt.vector("y", dtype="float32") + rows = np.array([0, 0, 1], dtype=np.int64) + cols = np.array([1, 1, 2], dtype=np.int64) + out = pt_subtensor.inc_subtensor(x[rows, cols], y) + assert isinstance(out.owner.op, pt_subtensor.AdvancedIncSubtensor) + assert not out.owner.op.set_instead_of_inc + assert not out.owner.op.ignore_duplicates + + x_np = np.zeros((3, 3), dtype=np.float32) + y_np = np.ones(3, dtype=np.float32) + compare_mlx_and_py([x, y], [out], [x_np, y_np]) + + +def test_mlx_AdvancedIncSubtensor_ignore_duplicates(): + """``ignore_duplicates=True`` requests write-once (numpy ``x[idx] += y``). + + Duplicate indices must NOT be accumulated in this mode, matching the + reference ``perform`` and the PyTorch/Numba backends. + """ + x = pt.vector("x", dtype="float32") + out = pt_subtensor.inc_subtensor( + x[[0, 1, 0]], np.float32(5.0), ignore_duplicates=True + ) + assert isinstance(out.owner.op, pt_subtensor.AdvancedIncSubtensor) + assert out.owner.op.ignore_duplicates + + compare_mlx_and_py([x], [out], [np.zeros(3, dtype=np.float32)])