Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ jobs:
fi
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" && pip install "jax>=0.8,<0.9.1" jaxlib numpyro equinox tfp-nightly; fi
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
if [[ $INSTALL_MLX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "mlx<0.29.4"; fi
if [[ $INSTALL_MLX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "mlx>=0.30,<0.32"; fi
if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi

pip install -e ./
Expand Down
91 changes: 62 additions & 29 deletions pytensor/link/mlx/dispatch/blockwise.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,75 @@
import mlx.core as mx

from pytensor.link.mlx.dispatch import mlx_funcify
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.blockwise import Blockwise, _check_runtime_broadcast_core


def _reshape_stream(dtype):
# The default (GPU) stream does not support float64, so pin the squeeze and
# expand_dims of a float64 array to the CPU stream where it survives. Other
# dtypes stay on the default stream.
return mx.cpu if dtype == "float64" else None


@mlx_funcify.register(Blockwise)
def funcify_Blockwise(op: Blockwise, node, **kwargs):
# Get the core python function for this Blockwise operation
core_node = op._create_dummy_core_node(node.inputs)
core_f = mlx_funcify(op.core_op, core_node)

# Determine how many batch dimensions are present in the output
n_batch = op.batch_ndim(node)

# If there are no batch dimensions, just return the core function
if n_batch == 0:
batch_ndim = op.batch_ndim(node)
if batch_ndim == 0:
return core_f

# Build in_axes specification for mx.vmap
# Each input can be vectorized (axis=0) or static (axis=None)
in_axes: list[int | None] = []
for inp, sig in zip(node.inputs, op.inputs_sig):
batch_ndim = inp.type.ndim - len(sig)
if batch_ndim == 0:
# Input has no batch dimensions - treat as static
in_axes.append(None)
continue

batch_bcast = inp.type.broadcastable[:batch_ndim]
# If all batch dims are broadcastable (size 1), treat input as static
# Otherwise, vectorize over the first dimension (axis=0)
in_axes.append(0 if not all(batch_bcast) else None)

# If all inputs are static (no actual vectorization needed), return core function
# This prevents calling mx.vmap with all-None in_axes, which would raise:
# "ValueError: At least one of in_axes must be non-None"
if not any(axis == 0 for axis in in_axes):
return core_f
multi_output = len(op.outputs_sig) > 1
in_core = [len(sig) for sig in op.inputs_sig]

# Decide batching purely from static shapes so a graph batches identically
# here and in every other backend: a batch axis broadcasts (is never mapped)
# only when its static size is exactly 1, or the input lacks it entirely.
batch_bcast = [inp.type.broadcastable[:batch_ndim] for inp in node.inputs]
squeeze_axes, padded_batch, squeeze_stream = [], [], []
for inp, n_core in zip(node.inputs, in_core):
batch_shape = inp.type.shape[: inp.type.ndim - n_core]
squeeze_axes.append(tuple(i for i, s in enumerate(batch_shape) if s == 1))
padded_batch.append((1,) * (batch_ndim - len(batch_shape)) + tuple(batch_shape))
squeeze_stream.append(_reshape_stream(inp.type.dtype))

# Nest one mx.vmap per mapped batch axis (innermost first, so array axis 0
# tracks the outermost batch dim). All-broadcast axes are squeezed out of
# every input above and re-inserted as size-1 dims after the mapped call.
fn, expand_axes = core_f, []
for axis in reversed(range(batch_ndim)):
in_axes = tuple(None if batch[axis] == 1 else 0 for batch in padded_batch)
if all(ax is None for ax in in_axes):
expand_axes.append(axis)
else:
fn = mx.vmap(fn, in_axes=in_axes)

expand_axes.sort()
expand_stream = [_reshape_stream(out.type.dtype) for out in node.outputs]

def blockwise(*args):
# Verify the static broadcast pattern holds: a runtime size-1 batch dim
# that is not statically broadcastable must not silently broadcast here
# when every other backend would reject it.
_check_runtime_broadcast_core(args, batch_bcast, batch_ndim)

squeezed = [
mx.squeeze(arg, axes, stream=stream) if axes else arg
for arg, axes, stream in zip(args, squeeze_axes, squeeze_stream)
]
out = fn(*squeezed)
if not expand_axes:
return out

# Re-insert the never-mapped all-broadcast axes as size-1 dims, in
# ascending order so each insertion's index stays valid for the next.
outs = out if multi_output else (out,)
for ax in expand_axes:
outs = [
mx.expand_dims(o, ax, stream=stream)
for o, stream in zip(outs, expand_stream)
]
return tuple(outs) if multi_output else outs[0]

# Apply mx.vmap to vectorize the core function over batch dimensions
return mx.vmap(core_f, in_axes=tuple(in_axes))
return blockwise
17 changes: 9 additions & 8 deletions pytensor/link/mlx/dispatch/linalg/decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,14 +139,15 @@ def mlx_funcify_PivotToPermutations(op, **kwargs):
inverse = op.inverse

def pivot_to_permutations(pivots):
pivots = mx.array(pivots)
n = pivots.shape[0]
p_inv = mx.arange(n, dtype=mx.int32)
for i in range(n):
p_inv[i], p_inv[pivots[i]] = p_inv[pivots[i]], p_inv[i]
if inverse:
return p_inv
return mx.argsort(p_inv)
with mx.stream(mx.cpu):
pivots = mx.array(pivots)
n = pivots.shape[0]
p_inv = mx.arange(n, dtype=mx.int32)
for i in range(n):
p_inv[i], p_inv[pivots[i]] = p_inv[pivots[i]], p_inv[i]
if inverse:
return p_inv
return mx.argsort(p_inv)

return pivot_to_permutations

Expand Down
30 changes: 11 additions & 19 deletions pytensor/link/mlx/dispatch/linalg/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,13 @@


def _lu_det_parts(x):
"""Shared helper: compute sign and logdet via LU factorization."""
lu, pivots = mx.linalg.lu_factor(x, stream=mx.cpu)
diag_u = mx.diagonal(lu, stream=mx.cpu)
n_swaps = mx.sum(
pivots != mx.arange(pivots.shape[0], dtype=pivots.dtype, stream=mx.cpu),
stream=mx.cpu,
)
"""Compute sign and logdet via LU factorization."""
lu, pivots = mx.linalg.lu_factor(x)
diag_u = mx.diagonal(lu)
n_swaps = mx.sum(pivots != mx.arange(pivots.shape[0], dtype=pivots.dtype))
pivot_sign = 1 - 2 * (n_swaps % 2)
sign = mx.multiply(
pivot_sign,
mx.prod(mx.sign(diag_u, stream=mx.cpu), stream=mx.cpu),
stream=mx.cpu,
)
logabsdet = mx.sum(
mx.log(mx.abs(diag_u, stream=mx.cpu), stream=mx.cpu),
stream=mx.cpu,
)
sign = pivot_sign * mx.prod(mx.sign(diag_u))
logabsdet = mx.sum(mx.log(mx.abs(diag_u)))
return sign, logabsdet


Expand All @@ -30,8 +20,9 @@ def mlx_funcify_Det(op, node, **kwargs):
X_dtype = getattr(mx, node.inputs[0].dtype)

def det(x):
sign, logabsdet = _lu_det_parts(x.astype(dtype=X_dtype, stream=mx.cpu))
return mx.multiply(sign, mx.exp(logabsdet, stream=mx.cpu), stream=mx.cpu)
with mx.stream(mx.cpu):
sign, logabsdet = _lu_det_parts(x.astype(dtype=X_dtype))
return sign * mx.exp(logabsdet)

return det

Expand All @@ -41,6 +32,7 @@ def mlx_funcify_SLogDet(op, node, **kwargs):
X_dtype = getattr(mx, node.inputs[0].dtype)

def slogdet(x):
return _lu_det_parts(x.astype(dtype=X_dtype, stream=mx.cpu))
with mx.stream(mx.cpu):
return _lu_det_parts(x.astype(dtype=X_dtype))

return slogdet
13 changes: 13 additions & 0 deletions tests/link/mlx/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import pytest


@pytest.fixture(scope="session", autouse=True)
def mlx_cpu_default():
# GitHub's macOS runners ship an older Metal stack that aborts when a
# CPU-produced array (mlx.linalg is CPU-only) feeds an op on the default GPU
# stream.
mx = pytest.importorskip("mlx.core")
previous = mx.default_device()
mx.set_default_device(mx.cpu)
yield
mx.set_default_device(previous)
169 changes: 111 additions & 58 deletions tests/link/mlx/test_blockwise.py
Original file line number Diff line number Diff line change
@@ -1,69 +1,122 @@
import numpy as np
import pytest

import pytensor
import pytensor.tensor as pt
from pytensor.tensor import tensor
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import Dot
from tests.link.mlx.test_basic import compare_mlx_and_py
from tests.link.mlx.test_basic import compare_mlx_and_py, mlx_mode, py_mode


# Equivalent blockwise to matmul but with dumb signature
matmul = Blockwise(Dot(), signature="(i,j),(j,k)->(i,k)")
odd_matmul = Blockwise(Dot(), signature="(i00,i01),(i10,i11)->(o00,o01)")
Comment thread
jessegrabowski marked this conversation as resolved.


def test_blockwise_conv1d():
rng = np.random.default_rng(14)
a = tensor("a", shape=(2, 100))
b = tensor("b", shape=(2, 8))

a_test = rng.normal(size=(2, 100))
b_test = rng.normal(size=(2, 8))

test_values = [a_test, b_test]

out = pt.signal.convolve1d(a, b, mode="valid")

# assert isinstance(out.owner.op, Blockwise)
compare_mlx_and_py([a, b], [out], test_values, must_be_device_array=True)


def test_blockwise_no_batch_dimensions():
"""Test that Blockwise returns the core function when there are no batch dimensions.

This verifies the fix for the vmap dispatcher issue where mx.vmap should not
be called when there are no batch dimensions to vectorize over.
"""
rng = np.random.default_rng(42)

# Create a blockwise matmul with no batch dimensions (core operation only)
x = pt.matrix("x")
Comment thread
ricardoV94 marked this conversation as resolved.
y = pt.matrix("y")

blockwise_matmul = Blockwise(Dot(), signature="(i,j),(j,k)->(i,k)")
z = blockwise_matmul(x, y)

x_test = rng.normal(size=(2, 3))
y_test = rng.normal(size=(3, 4))

compare_mlx_and_py([x, y], [z], [x_test, y_test], must_be_device_array=True)


def test_blockwise_all_broadcastable_batch_dims():
"""Test that Blockwise returns the core function when all batch dims are broadcastable.

When all batch dimensions are size-1 (broadcastable), vmap should not be called
since there's no actual vectorization needed.
"""
rng = np.random.default_rng(43)

# Create inputs with size-1 batch dimensions
x = tensor("x", shape=(1, 2, 3))
y = tensor("y", shape=(1, 3, 4))

blockwise_matmul = Blockwise(Dot(), signature="(i,j),(j,k)->(i,k)")
z = blockwise_matmul(x, y)

x_test = rng.normal(size=(1, 2, 3))
y_test = rng.normal(size=(1, 3, 4))

compare_mlx_and_py([x, y], [z], [x_test, y_test], must_be_device_array=True)
def _spd_batch(rng, batch):
"""A batch of symmetric positive-definite matrices for Cholesky."""
a = rng.standard_normal((*batch, 3, 3))
return a @ np.swapaxes(a, -1, -2) + 3 * np.eye(3)


# Core ops with distinct gufunc signatures, each built for a leading batch shape:
# two rank-2 inputs, one rank-2 input, two rank-1 inputs.
def _matmul_graph(rng, batch):
a = tensor("a", shape=(*batch, 2, 3))
b = tensor("b", shape=(*batch, 3, 4))
values = [rng.standard_normal((*batch, 2, 3)), rng.standard_normal((*batch, 3, 4))]
return [a, b], matmul(a, b), values


def _cholesky_graph(rng, batch):
m = tensor("m", shape=(*batch, 3, 3))
return [m], pt.linalg.cholesky(m), [_spd_batch(rng, batch)]


def _convolve_graph(rng, batch):
v = tensor("v", shape=(*batch, 16))
k = tensor("k", shape=(*batch, 5))
values = [rng.standard_normal((*batch, 16)), rng.standard_normal((*batch, 5))]
return [v, k], pt.signal.convolve1d(v, k, mode="valid"), values


@pytest.mark.parametrize(
"build",
[_matmul_graph, _cholesky_graph, _convolve_graph],
ids=["matmul", "cholesky", "convolve1d"],
)
@pytest.mark.parametrize("batch", [(5,), (2, 3)], ids=["single_batch", "nested_batch"])
def test_blockwise_signatures(build, batch):
rng = np.random.default_rng(7)
inputs, out, values = build(rng, batch)

assert isinstance(out.owner.op, Blockwise)
compare_mlx_and_py(inputs, [out], values)


@pytest.mark.parametrize(
"a_shape, b_shape",
[
((2, 3), (3, 4)), # no batch dims -> core function, no vmap
((5, 2, 3), (3, 4)), # one input unbatched -> broadcast over batch
((2, 1, 2, 3), (1, 3, 3, 4)), # size-1 batch dims on different axes
((1, 2, 3), (1, 3, 4)), # all batch dims size-1 -> squeeze + expand
],
ids=["no_batch", "broadcast_unbatched", "cross_broadcast", "all_broadcast"],
)
def test_blockwise_batch_broadcasting(a_shape, b_shape):
rng = np.random.default_rng(7)
a = tensor("a", shape=a_shape)
b = tensor("b", shape=b_shape)
out = matmul(a, b)

assert isinstance(out.owner.op, Blockwise)
compare_mlx_and_py(
[a, b], [out], [rng.standard_normal(a_shape), rng.standard_normal(b_shape)]
)


def test_blockwise_no_runtime_broadcast():
rng = np.random.default_rng(7)
a = tensor("a", shape=(None, 2, 3))
b = tensor("b", shape=(5, 3, 4))
out = matmul(a, b)

assert isinstance(out.owner.op, Blockwise)
values = [rng.standard_normal((1, 2, 3)), rng.standard_normal((5, 3, 4))]

py_fn = pytensor.function([a, b], out, mode=py_mode)
with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
py_fn(*values)

mlx_fn = pytensor.function([a, b], out, mode=mlx_mode)
with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
mlx_fn(*values)


@pytest.mark.parametrize("batch", [(), (5,)], ids=["no_batch", "single_batch"])
def test_blockwise_fallback_signature(batch):
rng = np.random.default_rng(7)
a = tensor("a", shape=(*batch, 2, 3))
b = tensor("b", shape=(*batch, 3, 4))
out = odd_matmul(a, b)

assert isinstance(out.owner.op, Blockwise)
compare_mlx_and_py(
[a, b],
[out],
[rng.standard_normal((*batch, 2, 3)), rng.standard_normal((*batch, 3, 4))],
)


def test_blockwise_multi_output():
rng = np.random.default_rng(7)
x = tensor("x", shape=(1, 4, 4))
q, r = pt.linalg.qr(x, mode="economic")

assert isinstance(q.owner.op, Blockwise)
compare_mlx_and_py(
graph_inputs=[x],
graph_outputs=[q, r],
test_inputs=[rng.standard_normal((1, 4, 4))],
)
Loading
Loading