Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions pytensor/link/mlx/dispatch/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@ def advanced_subtensor(x, *ilists):


@mlx_funcify.register(IncSubtensor)
@mlx_funcify.register(AdvancedIncSubtensor1)
def mlx_funcify_IncSubtensor(op, node, **kwargs):
if getattr(op, "set_instead_of_inc", False):
if op.set_instead_of_inc:

def mlx_fn(x, indices, y):
if not op.inplace:
Expand All @@ -59,7 +58,9 @@ def mlx_fn(x, indices, y):
return x

def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=op.idx_list):
indices = indices_from_subtensor(ilist, idx_list)
# Coerce integer index inputs to Python ints (e.g. slice bounds), as
# MLX slices reject array-typed bounds. Mirrors mlx_funcify_Subtensor.
indices = indices_from_subtensor([int(element) for element in ilist], idx_list)
if len(indices) == 1:
indices = indices[0]

Expand All @@ -69,8 +70,9 @@ def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=op.idx_list):


@mlx_funcify.register(AdvancedIncSubtensor)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

out of scope if you want to merge this PR already but the dispatch for the general AdvancedIncSubtensor is wrong, as it can have slices as well. Maybe also worth checking AdvancedSubtensor

@mlx_funcify.register(AdvancedIncSubtensor1)
def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs):
if getattr(op, "set_instead_of_inc", False):
if op.set_instead_of_inc:

def mlx_fn(x, indices, y):
if not op.inplace:
Expand Down
28 changes: 27 additions & 1 deletion tests/link/mlx/test_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,6 @@ def test_mlx_subtensor_edge_cases():
compare_mlx_and_py([], [out_pt], [])


@pytest.mark.xfail(reason="MLX indexing with tuples not yet supported")
def test_mlx_subtensor_with_variables():
"""Test subtensor operations with PyTensor variables as inputs."""
# Test with variable arrays (not constants)
Expand All @@ -224,3 +223,30 @@ def test_mlx_subtensor_with_variables():
# Set operation with variables
out_pt = pt_subtensor.set_subtensor(x_pt[0, :2], y_pt)
compare_mlx_and_py([x_pt, y_pt], [out_pt], [x_np, y_np])


def test_mlx_IncSubtensor_slice_grad():
"""Gradient of a basic slice lowers to an ``IncSubtensor`` with slice bounds
passed as (array) inputs; these must be coerced to Python ints for MLX."""
x_pt = pt.vector("x", dtype="float32")
x_np = np.arange(6, dtype=np.float32)

# Contiguous and strided (RoPE-style) slices both exercise the slice path.
for sl in (x_pt[0:3], x_pt[0::2]):
g = pt.grad((sl**2).sum(), x_pt)
assert isinstance(g.owner.op, pt_subtensor.IncSubtensor)
compare_mlx_and_py([x_pt], [g], [x_np])


@pytest.mark.xfail(
reason="Upstream mx.compile bug (ml-explore/mlx#3716): assigning an "
"elementwise expression to a negative-strided slice returns wrong values "
"under mx.compile (correct when eager / use_compile=False).",
strict=True,
)
def test_mlx_IncSubtensor_negative_step_slice_grad():
x_pt = pt.vector("x", dtype="float32")
x_np = np.arange(6, dtype=np.float32)
g = pt.grad((x_pt[::-1] ** 2).sum(), x_pt)
assert isinstance(g.owner.op, pt_subtensor.IncSubtensor)
compare_mlx_and_py([x_pt], [g], [x_np])
Loading