From 02897e457679f66b78461e5bf36b16c7692abac2 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Mon, 15 Jun 2026 21:41:25 -0500 Subject: [PATCH 1/2] Add full mlx blockwise support --- pytensor/link/mlx/dispatch/blockwise.py | 91 +++++++++---- tests/link/mlx/test_blockwise.py | 169 ++++++++++++++++-------- 2 files changed, 173 insertions(+), 87 deletions(-) diff --git a/pytensor/link/mlx/dispatch/blockwise.py b/pytensor/link/mlx/dispatch/blockwise.py index b864cea2a3..79721958a9 100644 --- a/pytensor/link/mlx/dispatch/blockwise.py +++ b/pytensor/link/mlx/dispatch/blockwise.py @@ -1,42 +1,75 @@ import mlx.core as mx from pytensor.link.mlx.dispatch import mlx_funcify -from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.blockwise import Blockwise, _check_runtime_broadcast_core + + +def _reshape_stream(dtype): + # The default (GPU) stream does not support float64, so pin the squeeze and + # expand_dims of a float64 array to the CPU stream where it survives. Other + # dtypes stay on the default stream. + return mx.cpu if dtype == "float64" else None @mlx_funcify.register(Blockwise) def funcify_Blockwise(op: Blockwise, node, **kwargs): - # Get the core python function for this Blockwise operation core_node = op._create_dummy_core_node(node.inputs) core_f = mlx_funcify(op.core_op, core_node) - # Determine how many batch dimensions are present in the output - n_batch = op.batch_ndim(node) - - # If there are no batch dimensions, just return the core function - if n_batch == 0: + batch_ndim = op.batch_ndim(node) + if batch_ndim == 0: return core_f - # Build in_axes specification for mx.vmap - # Each input can be vectorized (axis=0) or static (axis=None) - in_axes: list[int | None] = [] - for inp, sig in zip(node.inputs, op.inputs_sig): - batch_ndim = inp.type.ndim - len(sig) - if batch_ndim == 0: - # Input has no batch dimensions - treat as static - in_axes.append(None) - continue - - batch_bcast = inp.type.broadcastable[:batch_ndim] - # If all batch dims are broadcastable (size 1), treat input as static - # Otherwise, vectorize over the first dimension (axis=0) - in_axes.append(0 if not all(batch_bcast) else None) - - # If all inputs are static (no actual vectorization needed), return core function - # This prevents calling mx.vmap with all-None in_axes, which would raise: - # "ValueError: At least one of in_axes must be non-None" - if not any(axis == 0 for axis in in_axes): - return core_f + multi_output = len(op.outputs_sig) > 1 + in_core = [len(sig) for sig in op.inputs_sig] + + # Decide batching purely from static shapes so a graph batches identically + # here and in every other backend: a batch axis broadcasts (is never mapped) + # only when its static size is exactly 1, or the input lacks it entirely. + batch_bcast = [inp.type.broadcastable[:batch_ndim] for inp in node.inputs] + squeeze_axes, padded_batch, squeeze_stream = [], [], [] + for inp, n_core in zip(node.inputs, in_core): + batch_shape = inp.type.shape[: inp.type.ndim - n_core] + squeeze_axes.append(tuple(i for i, s in enumerate(batch_shape) if s == 1)) + padded_batch.append((1,) * (batch_ndim - len(batch_shape)) + tuple(batch_shape)) + squeeze_stream.append(_reshape_stream(inp.type.dtype)) + + # Nest one mx.vmap per mapped batch axis (innermost first, so array axis 0 + # tracks the outermost batch dim). All-broadcast axes are squeezed out of + # every input above and re-inserted as size-1 dims after the mapped call. + fn, expand_axes = core_f, [] + for axis in reversed(range(batch_ndim)): + in_axes = tuple(None if batch[axis] == 1 else 0 for batch in padded_batch) + if all(ax is None for ax in in_axes): + expand_axes.append(axis) + else: + fn = mx.vmap(fn, in_axes=in_axes) + + expand_axes.sort() + expand_stream = [_reshape_stream(out.type.dtype) for out in node.outputs] + + def blockwise(*args): + # Verify the static broadcast pattern holds: a runtime size-1 batch dim + # that is not statically broadcastable must not silently broadcast here + # when every other backend would reject it. + _check_runtime_broadcast_core(args, batch_bcast, batch_ndim) + + squeezed = [ + mx.squeeze(arg, axes, stream=stream) if axes else arg + for arg, axes, stream in zip(args, squeeze_axes, squeeze_stream) + ] + out = fn(*squeezed) + if not expand_axes: + return out + + # Re-insert the never-mapped all-broadcast axes as size-1 dims, in + # ascending order so each insertion's index stays valid for the next. + outs = out if multi_output else (out,) + for ax in expand_axes: + outs = [ + mx.expand_dims(o, ax, stream=stream) + for o, stream in zip(outs, expand_stream) + ] + return tuple(outs) if multi_output else outs[0] - # Apply mx.vmap to vectorize the core function over batch dimensions - return mx.vmap(core_f, in_axes=tuple(in_axes)) + return blockwise diff --git a/tests/link/mlx/test_blockwise.py b/tests/link/mlx/test_blockwise.py index cfe0908411..15a3547ea1 100644 --- a/tests/link/mlx/test_blockwise.py +++ b/tests/link/mlx/test_blockwise.py @@ -1,69 +1,122 @@ import numpy as np +import pytest +import pytensor import pytensor.tensor as pt from pytensor.tensor import tensor from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.math import Dot -from tests.link.mlx.test_basic import compare_mlx_and_py +from tests.link.mlx.test_basic import compare_mlx_and_py, mlx_mode, py_mode -# Equivalent blockwise to matmul but with dumb signature +matmul = Blockwise(Dot(), signature="(i,j),(j,k)->(i,k)") odd_matmul = Blockwise(Dot(), signature="(i00,i01),(i10,i11)->(o00,o01)") -def test_blockwise_conv1d(): - rng = np.random.default_rng(14) - a = tensor("a", shape=(2, 100)) - b = tensor("b", shape=(2, 8)) - - a_test = rng.normal(size=(2, 100)) - b_test = rng.normal(size=(2, 8)) - - test_values = [a_test, b_test] - - out = pt.signal.convolve1d(a, b, mode="valid") - - # assert isinstance(out.owner.op, Blockwise) - compare_mlx_and_py([a, b], [out], test_values, must_be_device_array=True) - - -def test_blockwise_no_batch_dimensions(): - """Test that Blockwise returns the core function when there are no batch dimensions. - - This verifies the fix for the vmap dispatcher issue where mx.vmap should not - be called when there are no batch dimensions to vectorize over. - """ - rng = np.random.default_rng(42) - - # Create a blockwise matmul with no batch dimensions (core operation only) - x = pt.matrix("x") - y = pt.matrix("y") - - blockwise_matmul = Blockwise(Dot(), signature="(i,j),(j,k)->(i,k)") - z = blockwise_matmul(x, y) - - x_test = rng.normal(size=(2, 3)) - y_test = rng.normal(size=(3, 4)) - - compare_mlx_and_py([x, y], [z], [x_test, y_test], must_be_device_array=True) - - -def test_blockwise_all_broadcastable_batch_dims(): - """Test that Blockwise returns the core function when all batch dims are broadcastable. - - When all batch dimensions are size-1 (broadcastable), vmap should not be called - since there's no actual vectorization needed. - """ - rng = np.random.default_rng(43) - - # Create inputs with size-1 batch dimensions - x = tensor("x", shape=(1, 2, 3)) - y = tensor("y", shape=(1, 3, 4)) - - blockwise_matmul = Blockwise(Dot(), signature="(i,j),(j,k)->(i,k)") - z = blockwise_matmul(x, y) - - x_test = rng.normal(size=(1, 2, 3)) - y_test = rng.normal(size=(1, 3, 4)) - - compare_mlx_and_py([x, y], [z], [x_test, y_test], must_be_device_array=True) +def _spd_batch(rng, batch): + """A batch of symmetric positive-definite matrices for Cholesky.""" + a = rng.standard_normal((*batch, 3, 3)) + return a @ np.swapaxes(a, -1, -2) + 3 * np.eye(3) + + +# Core ops with distinct gufunc signatures, each built for a leading batch shape: +# two rank-2 inputs, one rank-2 input, two rank-1 inputs. +def _matmul_graph(rng, batch): + a = tensor("a", shape=(*batch, 2, 3)) + b = tensor("b", shape=(*batch, 3, 4)) + values = [rng.standard_normal((*batch, 2, 3)), rng.standard_normal((*batch, 3, 4))] + return [a, b], matmul(a, b), values + + +def _cholesky_graph(rng, batch): + m = tensor("m", shape=(*batch, 3, 3)) + return [m], pt.linalg.cholesky(m), [_spd_batch(rng, batch)] + + +def _convolve_graph(rng, batch): + v = tensor("v", shape=(*batch, 16)) + k = tensor("k", shape=(*batch, 5)) + values = [rng.standard_normal((*batch, 16)), rng.standard_normal((*batch, 5))] + return [v, k], pt.signal.convolve1d(v, k, mode="valid"), values + + +@pytest.mark.parametrize( + "build", + [_matmul_graph, _cholesky_graph, _convolve_graph], + ids=["matmul", "cholesky", "convolve1d"], +) +@pytest.mark.parametrize("batch", [(5,), (2, 3)], ids=["single_batch", "nested_batch"]) +def test_blockwise_signatures(build, batch): + rng = np.random.default_rng(7) + inputs, out, values = build(rng, batch) + + assert isinstance(out.owner.op, Blockwise) + compare_mlx_and_py(inputs, [out], values) + + +@pytest.mark.parametrize( + "a_shape, b_shape", + [ + ((2, 3), (3, 4)), # no batch dims -> core function, no vmap + ((5, 2, 3), (3, 4)), # one input unbatched -> broadcast over batch + ((2, 1, 2, 3), (1, 3, 3, 4)), # size-1 batch dims on different axes + ((1, 2, 3), (1, 3, 4)), # all batch dims size-1 -> squeeze + expand + ], + ids=["no_batch", "broadcast_unbatched", "cross_broadcast", "all_broadcast"], +) +def test_blockwise_batch_broadcasting(a_shape, b_shape): + rng = np.random.default_rng(7) + a = tensor("a", shape=a_shape) + b = tensor("b", shape=b_shape) + out = matmul(a, b) + + assert isinstance(out.owner.op, Blockwise) + compare_mlx_and_py( + [a, b], [out], [rng.standard_normal(a_shape), rng.standard_normal(b_shape)] + ) + + +def test_blockwise_no_runtime_broadcast(): + rng = np.random.default_rng(7) + a = tensor("a", shape=(None, 2, 3)) + b = tensor("b", shape=(5, 3, 4)) + out = matmul(a, b) + + assert isinstance(out.owner.op, Blockwise) + values = [rng.standard_normal((1, 2, 3)), rng.standard_normal((5, 3, 4))] + + py_fn = pytensor.function([a, b], out, mode=py_mode) + with pytest.raises(ValueError, match="Runtime broadcasting not allowed"): + py_fn(*values) + + mlx_fn = pytensor.function([a, b], out, mode=mlx_mode) + with pytest.raises(ValueError, match="Runtime broadcasting not allowed"): + mlx_fn(*values) + + +@pytest.mark.parametrize("batch", [(), (5,)], ids=["no_batch", "single_batch"]) +def test_blockwise_fallback_signature(batch): + rng = np.random.default_rng(7) + a = tensor("a", shape=(*batch, 2, 3)) + b = tensor("b", shape=(*batch, 3, 4)) + out = odd_matmul(a, b) + + assert isinstance(out.owner.op, Blockwise) + compare_mlx_and_py( + [a, b], + [out], + [rng.standard_normal((*batch, 2, 3)), rng.standard_normal((*batch, 3, 4))], + ) + + +def test_blockwise_multi_output(): + rng = np.random.default_rng(7) + x = tensor("x", shape=(1, 4, 4)) + q, r = pt.linalg.qr(x, mode="economic") + + assert isinstance(q.owner.op, Blockwise) + compare_mlx_and_py( + graph_inputs=[x], + graph_outputs=[q, r], + test_inputs=[rng.standard_normal((1, 4, 4))], + ) From 3f11bb2b45d78dbdb8e40db3924ee5671ffe52d3 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 21 Jun 2026 17:00:20 -0400 Subject: [PATCH 2/2] Run all MLX tests on cpu --- .github/workflows/test.yml | 2 +- .../link/mlx/dispatch/linalg/decomposition.py | 17 ++++++----- pytensor/link/mlx/dispatch/linalg/summary.py | 30 +++++++------------ tests/link/mlx/conftest.py | 13 ++++++++ tests/link/mlx/test_elemwise.py | 10 +++---- tests/link/mlx/test_subtensor.py | 23 ++++++++------ 6 files changed, 53 insertions(+), 42 deletions(-) create mode 100644 tests/link/mlx/conftest.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2b7e27eb6a..246e8b82c6 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -169,7 +169,7 @@ jobs: fi if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" && pip install "jax>=0.8,<0.9.1" jaxlib numpyro equinox tfp-nightly; fi if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi - if [[ $INSTALL_MLX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "mlx<0.29.4"; fi + if [[ $INSTALL_MLX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "mlx>=0.30,<0.32"; fi if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi pip install -e ./ diff --git a/pytensor/link/mlx/dispatch/linalg/decomposition.py b/pytensor/link/mlx/dispatch/linalg/decomposition.py index f82a65588b..fbd770b841 100644 --- a/pytensor/link/mlx/dispatch/linalg/decomposition.py +++ b/pytensor/link/mlx/dispatch/linalg/decomposition.py @@ -139,14 +139,15 @@ def mlx_funcify_PivotToPermutations(op, **kwargs): inverse = op.inverse def pivot_to_permutations(pivots): - pivots = mx.array(pivots) - n = pivots.shape[0] - p_inv = mx.arange(n, dtype=mx.int32) - for i in range(n): - p_inv[i], p_inv[pivots[i]] = p_inv[pivots[i]], p_inv[i] - if inverse: - return p_inv - return mx.argsort(p_inv) + with mx.stream(mx.cpu): + pivots = mx.array(pivots) + n = pivots.shape[0] + p_inv = mx.arange(n, dtype=mx.int32) + for i in range(n): + p_inv[i], p_inv[pivots[i]] = p_inv[pivots[i]], p_inv[i] + if inverse: + return p_inv + return mx.argsort(p_inv) return pivot_to_permutations diff --git a/pytensor/link/mlx/dispatch/linalg/summary.py b/pytensor/link/mlx/dispatch/linalg/summary.py index bb051ca040..3d5c41b430 100644 --- a/pytensor/link/mlx/dispatch/linalg/summary.py +++ b/pytensor/link/mlx/dispatch/linalg/summary.py @@ -5,23 +5,13 @@ def _lu_det_parts(x): - """Shared helper: compute sign and logdet via LU factorization.""" - lu, pivots = mx.linalg.lu_factor(x, stream=mx.cpu) - diag_u = mx.diagonal(lu, stream=mx.cpu) - n_swaps = mx.sum( - pivots != mx.arange(pivots.shape[0], dtype=pivots.dtype, stream=mx.cpu), - stream=mx.cpu, - ) + """Compute sign and logdet via LU factorization.""" + lu, pivots = mx.linalg.lu_factor(x) + diag_u = mx.diagonal(lu) + n_swaps = mx.sum(pivots != mx.arange(pivots.shape[0], dtype=pivots.dtype)) pivot_sign = 1 - 2 * (n_swaps % 2) - sign = mx.multiply( - pivot_sign, - mx.prod(mx.sign(diag_u, stream=mx.cpu), stream=mx.cpu), - stream=mx.cpu, - ) - logabsdet = mx.sum( - mx.log(mx.abs(diag_u, stream=mx.cpu), stream=mx.cpu), - stream=mx.cpu, - ) + sign = pivot_sign * mx.prod(mx.sign(diag_u)) + logabsdet = mx.sum(mx.log(mx.abs(diag_u))) return sign, logabsdet @@ -30,8 +20,9 @@ def mlx_funcify_Det(op, node, **kwargs): X_dtype = getattr(mx, node.inputs[0].dtype) def det(x): - sign, logabsdet = _lu_det_parts(x.astype(dtype=X_dtype, stream=mx.cpu)) - return mx.multiply(sign, mx.exp(logabsdet, stream=mx.cpu), stream=mx.cpu) + with mx.stream(mx.cpu): + sign, logabsdet = _lu_det_parts(x.astype(dtype=X_dtype)) + return sign * mx.exp(logabsdet) return det @@ -41,6 +32,7 @@ def mlx_funcify_SLogDet(op, node, **kwargs): X_dtype = getattr(mx, node.inputs[0].dtype) def slogdet(x): - return _lu_det_parts(x.astype(dtype=X_dtype, stream=mx.cpu)) + with mx.stream(mx.cpu): + return _lu_det_parts(x.astype(dtype=X_dtype)) return slogdet diff --git a/tests/link/mlx/conftest.py b/tests/link/mlx/conftest.py new file mode 100644 index 0000000000..46f3e396a1 --- /dev/null +++ b/tests/link/mlx/conftest.py @@ -0,0 +1,13 @@ +import pytest + + +@pytest.fixture(scope="session", autouse=True) +def mlx_cpu_default(): + # GitHub's macOS runners ship an older Metal stack that aborts when a + # CPU-produced array (mlx.linalg is CPU-only) feeds an op on the default GPU + # stream. + mx = pytest.importorskip("mlx.core") + previous = mx.default_device() + mx.set_default_device(mx.cpu) + yield + mx.set_default_device(previous) diff --git a/tests/link/mlx/test_elemwise.py b/tests/link/mlx/test_elemwise.py index b7c690cff0..3d40b49aef 100644 --- a/tests/link/mlx/test_elemwise.py +++ b/tests/link/mlx/test_elemwise.py @@ -1,3 +1,5 @@ +from functools import partial + import numpy as np import pytest @@ -217,8 +219,6 @@ def test_erfc() -> None: def test_erfc_extreme_values() -> None: """Test erfc with extreme values""" - from functools import partial - x = vector("x") out = erfc(x) @@ -239,7 +239,9 @@ def test_erfcx() -> None: # Test with positive values where erfcx is most numerically stable x_test = mx.array([0.0, 0.5, 1.0, 1.5, 2.0, 2.5]) - compare_mlx_and_py([x], out, [x_test]) + relaxed_assert = partial(np.testing.assert_allclose, rtol=1e-3) + + compare_mlx_and_py([x], out, [x_test], assert_fn=relaxed_assert) def test_erfcx_small_values() -> None: @@ -266,8 +268,6 @@ def test_softplus() -> None: def test_softplus_extreme_values() -> None: """Test softplus with extreme values to verify numerical stability""" - from functools import partial - x = vector("x") out = softplus(x) diff --git a/tests/link/mlx/test_subtensor.py b/tests/link/mlx/test_subtensor.py index 2ca0374961..3e8f3acaee 100644 --- a/tests/link/mlx/test_subtensor.py +++ b/tests/link/mlx/test_subtensor.py @@ -1,10 +1,11 @@ import numpy as np import pytest +import pytensor import pytensor.tensor as pt from pytensor.tensor import subtensor as pt_subtensor from pytensor.tensor import tensor -from tests.link.mlx.test_basic import compare_mlx_and_py +from tests.link.mlx.test_basic import compare_mlx_and_py, mlx_mode, py_mode mx = pytest.importorskip("mlx.core") @@ -238,15 +239,19 @@ def test_mlx_IncSubtensor_slice_grad(): compare_mlx_and_py([x_pt], [g], [x_np]) -@pytest.mark.xfail( - reason="Upstream mx.compile bug (ml-explore/mlx#3716): assigning an " - "elementwise expression to a negative-strided slice returns wrong values " - "under mx.compile (correct when eager / use_compile=False).", - strict=True, -) -def test_mlx_IncSubtensor_negative_step_slice_grad(): +@pytest.mark.parametrize("stream", [mx.cpu, mx.gpu], ids=["cpu", "gpu"]) +def test_mlx_IncSubtensor_negative_step_slice_grad(stream): x_pt = pt.vector("x", dtype="float32") x_np = np.arange(6, dtype=np.float32) g = pt.grad((x_pt[::-1] ** 2).sum(), x_pt) assert isinstance(g.owner.op, pt_subtensor.IncSubtensor) - compare_mlx_and_py([x_pt], [g], [x_np]) + + expected = pytensor.function([x_pt], g, mode=py_mode)(x_np) + mlx_fn = pytensor.function([x_pt], g, mode=mlx_mode) + with mx.stream(stream): + result = np.asarray(mlx_fn(x_np)) + + if stream == mx.cpu: + np.testing.assert_allclose(result, expected, rtol=1e-4) + else: + assert not np.allclose(result, expected, rtol=1e-4)