Skip to content

Add full mlx blockwise support#2233

Open
jessegrabowski wants to merge 2 commits into
pymc-devs:mainfrom
jessegrabowski:mlx-blockwise-upgrade
Open

Add full mlx blockwise support#2233
jessegrabowski wants to merge 2 commits into
pymc-devs:mainfrom
jessegrabowski:mlx-blockwise-upgrade

Conversation

@jessegrabowski

Copy link
Copy Markdown
Member

MLX blockwise operations fail in non-trivial cases because our implementation isn't a full vectorize, it's just a single naive application of mlx.vmap. We need to align and broadcast all input shapes, then apply vmap once per non-core dimension. This is what jnp.vectorize does under the hood. I just went ahead and re-implemented jnp.vectorize in MLX, giving us full vectorize support.

Closes #2092 , and also address additional unreported cases (e.g. when data has batch dim).

Comment thread tests/link/mlx/test_blockwise.py
Comment thread tests/link/mlx/test_blockwise.py
Comment thread pytensor/link/mlx/dispatch/blockwise.py Outdated
@ricardoV94

Copy link
Copy Markdown
Member

failing test

@jessegrabowski jessegrabowski force-pushed the mlx-blockwise-upgrade branch 4 times, most recently from da138df to e95b8c7 Compare June 21, 2026 23:59
@jessegrabowski jessegrabowski force-pushed the mlx-blockwise-upgrade branch from e95b8c7 to d53a158 Compare June 22, 2026 00:31
@jessegrabowski jessegrabowski force-pushed the mlx-blockwise-upgrade branch from d53a158 to 3f11bb2 Compare June 22, 2026 01:50
@ricardoV94

Copy link
Copy Markdown
Member

that ci looks like fun

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

BUG: MLX Convolve1d dispatch crashes when Blockwise broadcasts the kernel

2 participants