Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions pytensor/link/mlx/dispatch/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@ def advanced_subtensor(x, *ilists):


@mlx_funcify.register(IncSubtensor)
@mlx_funcify.register(AdvancedIncSubtensor1)
def mlx_funcify_IncSubtensor(op, node, **kwargs):
if getattr(op, "set_instead_of_inc", False):
if op.set_instead_of_inc:

def mlx_fn(x, indices, y):
if not op.inplace:
Expand All @@ -59,7 +58,9 @@ def mlx_fn(x, indices, y):
return x

def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=op.idx_list):
indices = indices_from_subtensor(ilist, idx_list)
# Coerce integer index inputs to Python ints (e.g. slice bounds), as
# MLX slices reject array-typed bounds. Mirrors mlx_funcify_Subtensor.
indices = indices_from_subtensor([int(element) for element in ilist], idx_list)
if len(indices) == 1:
indices = indices[0]

Expand All @@ -69,23 +70,36 @@ def incsubtensor(x, y, *ilist, mlx_fn=mlx_fn, idx_list=op.idx_list):


@mlx_funcify.register(AdvancedIncSubtensor)
@mlx_funcify.register(AdvancedIncSubtensor1)
def mlx_funcify_AdvancedIncSubtensor(op, node, **kwargs):
if getattr(op, "set_instead_of_inc", False):
if op.set_instead_of_inc:

def mlx_fn(x, indices, y):
if not op.inplace:
x = deepcopy(x)
x[indices] = y
return x

else:

elif getattr(op, "ignore_duplicates", False):
# `ignore_duplicates` requests numpy's write-once `x[idx] += y`
# semantics (duplicate indices are not accumulated), matching the
# reference `perform` and the PyTorch/Numba backends.
def mlx_fn(x, indices, y):
if not op.inplace:
x = deepcopy(x)
x[indices] += y
return x

else:
# Accumulate duplicate indices (`np.add.at` semantics) via MLX's
# functional scatter-add, mirroring JAX's `x.at[indices].add(y)`.
# Plain `x[indices] += y` writes each destination once, dropping
# repeated-index contributions (e.g. gradients of embedding lookups).
# `AdvancedIncSubtensor1` has no `ignore_duplicates` flag and always
# accumulates, like its `np.add.at`-based `perform`.
def mlx_fn(x, indices, y):
return x.at[indices].add(y)

def advancedincsubtensor(x, y, *ilist, mlx_fn=mlx_fn):
return mlx_fn(x, ilist, y)

Expand Down
94 changes: 93 additions & 1 deletion tests/link/mlx/test_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,6 @@ def test_mlx_subtensor_edge_cases():
compare_mlx_and_py([], [out_pt], [])


@pytest.mark.xfail(reason="MLX indexing with tuples not yet supported")
def test_mlx_subtensor_with_variables():
"""Test subtensor operations with PyTensor variables as inputs."""
# Test with variable arrays (not constants)
Expand All @@ -224,3 +223,96 @@ def test_mlx_subtensor_with_variables():
# Set operation with variables
out_pt = pt_subtensor.set_subtensor(x_pt[0, :2], y_pt)
compare_mlx_and_py([x_pt, y_pt], [out_pt], [x_np, y_np])


def test_mlx_IncSubtensor_slice_grad():
    """Compare MLX and Python results for gradients taken through basic slices.

    Differentiating through a basic slice yields an ``IncSubtensor`` whose
    slice bounds arrive as (array) inputs; MLX needs those coerced to Python
    ints. The gradient's owning op is asserted to be ``IncSubtensor`` before
    the backend comparison runs.
    """
    vec = pt.vector("x", dtype="float32")
    vec_val = np.arange(6, dtype=np.float32)

    # One contiguous and one strided (RoPE-style) view, so both slice forms
    # go through the same path.
    sliced_views = (vec[0:3], vec[0::2])
    for view in sliced_views:
        cost = (view**2).sum()
        grad = pt.grad(cost, vec)
        assert isinstance(grad.owner.op, pt_subtensor.IncSubtensor)
        compare_mlx_and_py([vec], [grad], [vec_val])


@pytest.mark.xfail(
    strict=True,
    reason="Upstream mx.compile bug (ml-explore/mlx#3716): assigning an "
    "elementwise expression to a negative-strided slice returns wrong values "
    "under mx.compile (correct when eager / use_compile=False).",
)
def test_mlx_IncSubtensor_negative_step_slice_grad():
    """Gradient through a reversed (``[::-1]``) slice, expected to fail.

    The gradient's owning op is asserted to be ``IncSubtensor`` and the result
    is then compared between MLX and Python. The test is a strict xfail for
    the reason given in the marker.
    """
    vec = pt.vector("x", dtype="float32")
    reversed_cost = (vec[::-1] ** 2).sum()
    grad = pt.grad(reversed_cost, vec)
    assert isinstance(grad.owner.op, pt_subtensor.IncSubtensor)

    vec_val = np.arange(6, dtype=np.float32)
    compare_mlx_and_py([vec], [grad], [vec_val])


@pytest.mark.parametrize(
    "func",
    (pt_subtensor.advanced_inc_subtensor1, pt_subtensor.advanced_set_subtensor1),
    ids=("inc", "set"),
)
def test_mlx_AdvancedIncSubtensor1_duplicate_indices(func):
    """Repeated indices with ``AdvancedIncSubtensor1``, for both inc and set.

    For the inc variant, duplicates must accumulate (``np.add.at`` semantics):
    gradients of advanced indexing, e.g. embedding lookups with repeated token
    ids, produce exactly this pattern, and MLX has to sum every contribution
    instead of writing each destination once.
    """
    target = pt.vector("x", dtype="float32")
    update = pt.vector("y", dtype="float32")
    # Index 0 appears three times, index 1 once.
    repeated_idx = np.array([0, 0, 0, 1], dtype=np.int64)

    result = func(target, update, repeated_idx)
    assert isinstance(result.owner.op, pt_subtensor.AdvancedIncSubtensor1)

    target_val = np.zeros(3, dtype=np.float32)
    update_val = np.ones(4, dtype=np.float32)
    compare_mlx_and_py([target, update], [result], [target_val, update_val])


def test_mlx_AdvancedIncSubtensor1_duplicate_indices_edge_cases():
    """Repeated negative indices combined with a scalar ``y`` that is broadcast."""
    target = pt.vector("x", dtype="int32")
    increment = pt.scalar("y", dtype="int32")
    # -1 appears three times, alongside a single 0.
    repeated_idx = np.array([-1, -1, 0, -1], dtype=np.int64)

    result = pt_subtensor.advanced_inc_subtensor1(target, increment, repeated_idx)
    assert isinstance(result.owner.op, pt_subtensor.AdvancedIncSubtensor1)

    target_val = np.zeros(3, dtype=np.int32)
    increment_val = np.int32(2)
    compare_mlx_and_py([target, increment], [result], [target_val, increment_val])


def test_mlx_AdvancedIncSubtensor_duplicate_indices():
    """Repeated (row, col) pairs in ``AdvancedIncSubtensor`` add up like ``np.add.at``."""
    mat = pt.matrix("x", dtype="float32")
    update = pt.vector("y", dtype="float32")
    # The coordinate (0, 1) is addressed twice; (1, 2) once.
    row_idx = np.array([0, 0, 1], dtype=np.int64)
    col_idx = np.array([1, 1, 2], dtype=np.int64)

    result = pt_subtensor.inc_subtensor(mat[row_idx, col_idx], update)
    op = result.owner.op
    assert isinstance(op, pt_subtensor.AdvancedIncSubtensor)
    # Must be a plain increment with duplicate handling left enabled.
    assert not op.set_instead_of_inc
    assert not op.ignore_duplicates

    mat_val = np.zeros((3, 3), dtype=np.float32)
    update_val = np.ones(3, dtype=np.float32)
    compare_mlx_and_py([mat, update], [result], [mat_val, update_val])


def test_mlx_AdvancedIncSubtensor_ignore_duplicates():
    """With ``ignore_duplicates=True`` the increment is write-once.

    This is numpy's ``x[idx] += y`` behaviour: repeated indices must NOT be
    summed, in line with the reference ``perform`` and the PyTorch/Numba
    backends.
    """
    vec = pt.vector("x", dtype="float32")
    # Index 0 is listed twice.
    indexed = vec[[0, 1, 0]]
    increment = np.float32(5.0)

    result = pt_subtensor.inc_subtensor(indexed, increment, ignore_duplicates=True)
    op = result.owner.op
    assert isinstance(op, pt_subtensor.AdvancedIncSubtensor)
    assert op.ignore_duplicates

    vec_val = np.zeros(3, dtype=np.float32)
    compare_mlx_and_py([vec], [result], [vec_val])
Loading