diff --git a/pytensor/link/mlx/dispatch/subtensor.py b/pytensor/link/mlx/dispatch/subtensor.py index 42a7bfdd80..2e6e2a3af1 100644 --- a/pytensor/link/mlx/dispatch/subtensor.py +++ b/pytensor/link/mlx/dispatch/subtensor.py @@ -40,9 +40,8 @@ def advanced_subtensor(x, *ilists): @mlx_funcify.register(IncSubtensor) -@mlx_funcify.register(AdvancedIncSubtensor1) def mlx_funcify_IncSubtensor(op, node, **kwargs): - if getattr(op, "set_instead_of_inc", False): + if op.set_instead_of_inc: def mlx_fn(x, indices, y): if not op.inplace: @@ -59,7 +58,9 @@ def mlx_fn(x, indices, y): return x def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=op.idx_list): - indices = indices_from_subtensor(ilist, idx_list) + # Coerce integer index inputs to Python ints (e.g. slice bounds), as + # MLX slices reject array-typed bounds. Mirrors mlx_funcify_Subtensor. + indices = indices_from_subtensor([int(element) for element in ilist], idx_list) if len(indices) == 1: indices = indices[0] @@ -69,8 +70,9 @@ def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=op.idx_list): @mlx_funcify.register(AdvancedIncSubtensor) +@mlx_funcify.register(AdvancedIncSubtensor1) def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs): - if getattr(op, "set_instead_of_inc", False): + if op.set_instead_of_inc: def mlx_fn(x, indices, y): if not op.inplace: diff --git a/tests/link/mlx/test_subtensor.py b/tests/link/mlx/test_subtensor.py index e24607b08e..2ca0374961 100644 --- a/tests/link/mlx/test_subtensor.py +++ b/tests/link/mlx/test_subtensor.py @@ -211,7 +211,6 @@ def test_mlx_subtensor_edge_cases(): compare_mlx_and_py([], [out_pt], []) -@pytest.mark.xfail(reason="MLX indexing with tuples not yet supported") def test_mlx_subtensor_with_variables(): """Test subtensor operations with PyTensor variables as inputs.""" # Test with variable arrays (not constants) @@ -224,3 +223,30 @@ def test_mlx_subtensor_with_variables(): # Set operation with variables out_pt = pt_subtensor.set_subtensor(x_pt[0, :2], y_pt) compare_mlx_and_py([x_pt, y_pt], [out_pt], [x_np, y_np]) + + +def test_mlx_IncSubtensor_slice_grad(): + """Gradient of a basic slice lowers to an ``IncSubtensor`` with slice bounds + passed as (array) inputs; these must be coerced to Python ints for MLX.""" + x_pt = pt.vector("x", dtype="float32") + x_np = np.arange(6, dtype=np.float32) + + # Contiguous and strided (RoPE-style) slices both exercise the slice path. + for sl in (x_pt[0:3], x_pt[0::2]): + g = pt.grad((sl**2).sum(), x_pt) + assert isinstance(g.owner.op, pt_subtensor.IncSubtensor) + compare_mlx_and_py([x_pt], [g], [x_np]) + + +@pytest.mark.xfail( + reason="Upstream mx.compile bug (ml-explore/mlx#3716): assigning an " + "elementwise expression to a negative-strided slice returns wrong values " + "under mx.compile (correct when eager / use_compile=False).", + strict=True, +) +def test_mlx_IncSubtensor_negative_step_slice_grad(): + x_pt = pt.vector("x", dtype="float32") + x_np = np.arange(6, dtype=np.float32) + g = pt.grad((x_pt[::-1] ** 2).sum(), x_pt) + assert isinstance(g.owner.op, pt_subtensor.IncSubtensor) + compare_mlx_and_py([x_pt], [g], [x_np])