As commented on LinkedIn, I'm working on an MLX port of this project - mainly to speed up the local models that I'm running.
My prior work regarding the MLX/ANE platform spans:
- Custom Bonsai inference server for Qwen3-8B with just a 1.1 GB memory footprint + model-specific KV quantization, including support for custom speculative decoding with a working TurboQuant integration; 40-50 t/s on a MacBook Air M4 24GB while running in the background alongside many other processes
- contributions to mlx-audio fixing Qwen3-TTS implementation issues
- several custom ANE FluidAudio ASR/TTS inference server implementations (to be released)
- my typical coding/working style (C++): https://github.com/kyr0/libsharedmemory
I'm opening this issue for:
a) reporting/tracking progress before opening a PR (keeping the noise level low)
b) asking the community for feature requests
c) checking if anyone would be interested to collaborate
d) asking the maintainer(s) if you have any concrete ideas on what cross-backend support should look like
Broader Motivation/Context:
I'm engaging in this in my spare time (I'm currently on vacation). My mid-term goal is to build a solid local voice agent that can be used as a generic, multilingual, intelligent realtime audio (TTS + ASR) / video interface on Mac, using a multi-backend inference approach (clever management of resources/scheduling to limit memory bandwidth + smart timing). Because of this, I need to speed up the MLX models that are suitable for the task -- so my interest is somewhat narrowly focused, but since I'm targeting MLX kernels here, the outcome should still generalize. In parallel, I'm also working on improving model quantization (separate project; a research paper on optimal Gemma 4 quantization is pending). To get the maximum out of any new model until the goal is reached, I'm also checking every corner of the inference pipeline where cache quantization and throughput maximization could be improved (see the Bonsai server above).
Implementation plan:
- I froze my MLX fork, including 1-bit inference support for PrinsmML/Bonsai-8B, commit ref
- I'm adding an `mlx` backend to AutoKernel instead of trying to force the current Triton/CUDA flow onto MLX
- Making the optimization unit a target manifest, not a single editable kernel file
- I created calibration files with prompt/response pairs and use only three frozen replay/calibration packs:
  - `calibration_gemma4_e4b.json`
  - `calibration_gemma4_e2b.json`
  - `calibration_bonsai_8b.json`
- I'm running real inference for the matching target model against its fixed pack
- Treating the stored `response`, `tokens`, `time_s`, and `tok_per_sec` fields as the frozen baseline for replay validation and throughput comparison (imo AutoKernel should not optimize for theoretical results, but for real-world results)
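To make the frozen-baseline idea concrete, here is a minimal sketch of how a replay comparison against one stored pack entry could look. The field names match the ones above; the entry contents, helper name, and keep criterion are my own assumptions, not the project's actual code:

```python
# Hypothetical shape of one entry in a frozen calibration pack.
BASELINE_ENTRY = {
    "prompt": "Explain KV-cache quantization in one sentence.",
    "response": "KV-cache quantization stores attention keys/values at reduced precision.",
    "tokens": 14,
    "time_s": 0.35,
    "tok_per_sec": 40.0,
}

def compare_to_baseline(entry, new_response, new_tokens, new_time_s,
                        min_speedup=1.0):
    """Replay validation sketch: the decoded response must match the
    frozen baseline, and throughput must not fall below
    `min_speedup` x the baseline tok/s."""
    new_tps = new_tokens / new_time_s
    return {
        "response_match": new_response == entry["response"],
        "tok_per_sec": new_tps,
        "speedup": new_tps / entry["tok_per_sec"],
        "keep": new_response == entry["response"]
                and new_tps >= min_speedup * entry["tok_per_sec"],
    }
```

In a real pack the exact-match criterion would likely be relaxed to a drift metric, but the point stands: the stored fields, not a theoretical model, are the yardstick.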
- As almost all models that users run in practice are quantized, I started with explicit pipeline support for:
  - `quantized.cpp`
  - `quantized.h` / `quantized.metal`
  - `quantized_nax.h` / `quantized_nax.metal`
  - `quantized_utils.h`
  - `scaled_dot_product_attention.cpp`
  - `sdpa_vector.h`
  - `scaled_dot_product_attention.metal`
- I'm currently building `bench_mlx.py` on top of public MLX ops only:
  - `mlx.core.quantized_matmul`
  - `mlx.core.fast.scaled_dot_product_attention`
  - `mlx.core.fast.rms_norm`
  - `mlx.core.fast.rope`
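As a sketch of the microbench loop such a script could use: the helper below is backend-agnostic, and my assumption is that on MLX the `sync` argument would be `mx.eval` on the op's output, since MLX ops are lazy and return immediately otherwise (this is the harness shape I have in mind, not the project's actual API):

```python
import time

def microbench(op, sync=lambda out: None, warmup=5, iters=50):
    """Time a single op: run `warmup` untimed iterations, then `iters`
    timed ones, forcing completion via `sync` before reading the clock."""
    for _ in range(warmup):
        sync(op())
    start = time.perf_counter()
    for _ in range(iters):
        sync(op())
    elapsed = time.perf_counter() - start
    return elapsed / iters  # mean seconds per call

# Intended MLX usage (sketch, assumes decode-shaped q/k/v already exist):
#   mean_s = microbench(
#       lambda: mx.fast.scaled_dot_product_attention(q, k, v, scale=s),
#       sync=mx.eval)
```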
- I'll add `verify_mlx.py` for fixed-pack replay; it must run prompt-by-prompt inference and record:
- decoded continuation
- generated token count
- wall time
- tok/s
- memory / TTFT if available
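One way the per-prompt record could be shaped (a sketch; the dataclass, field names beyond the ones listed above, and the optional-field handling are my assumptions):

```python
from dataclasses import dataclass, asdict
from typing import Optional

@dataclass
class ReplayRecord:
    prompt_id: str
    response: str            # decoded continuation
    tokens: int              # generated token count
    time_s: float            # wall time for the generation
    tok_per_sec: float
    peak_mem_mb: Optional[float] = None  # only if the backend exposes it
    ttft_s: Optional[float] = None       # time to first token, if measured

def make_record(prompt_id, response, tokens, time_s, **extra):
    """Derive tok/s from the measured fields so it can never disagree
    with them; optional metrics pass through untouched."""
    return ReplayRecord(prompt_id, response, tokens, time_s,
                        tok_per_sec=tokens / time_s, **extra)
```

Keeping memory/TTFT optional means the replay format does not break on backends that cannot report them.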
- I'll then add rebuild automation so every experiment can:
- patch
- rebuild
- microbench
- keep or revert
- Then record the actual dispatch path for every run
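The patch / rebuild / microbench / keep-or-revert loop could be driven by something like this (a sketch with callables standing in for the real patch, build, and bench steps; all names and the regression threshold are my own):

```python
def run_experiment(apply_patch, revert_patch, rebuild, microbench_fn,
                   baseline_s, max_regression=1.0):
    """Apply a candidate kernel patch, rebuild, microbench, and keep the
    patch only if it does not regress past `max_regression` x the baseline
    latency; otherwise revert. Returns (kept, measured_seconds)."""
    apply_patch()
    if not rebuild():                  # compile gate: a failed build reverts
        revert_patch()
        return False, None
    candidate_s = microbench_fn()
    if candidate_s <= max_regression * baseline_s:
        return True, candidate_s       # kept: the point to log the dispatch path
    revert_patch()
    return False, candidate_s
```

The real version would of course shell out to the build system and the bench script; the shape above is just the keep-or-revert contract.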
- Optimization, imo, needs to go in this order:
  - quantized decode dispatch (`qmv_fast` / `qmv` / `qmv_quad`)
  - quantized prefill dispatch (`qmm_t` / `qmm_n` / NAX-gated paths)
  - `qdot` / `qdot_safe`
  - QMM loader + GEMM loop pipeline
  - `sdpa_vector` / 2-pass SDPA
  - RMSNorm
  - single-token RoPE
  - `qvm_split_k`
- I'm treating standalone softmax, LayerNorm, GEMV-masked, and Hadamard as secondary until traces prove otherwise
- AutoKernel-MLX "Agentic DoD" will require every kept change to pass:
- compile
- op-level correctness
- determinism
- fixed-pack replay on the relevant model
- no unacceptable response drift
- throughput improvement or trace-backed hot-path win
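The "Agentic DoD" gates above can be expressed as an all-or-nothing check (a sketch; the gate names mirror the list, but the result-dict shape and the drift budget are placeholders I chose):

```python
def passes_dod(result):
    """A kept change must clear every gate; any single failure rejects it."""
    gates = [
        result["compiled"],
        result["op_correct"],        # op-level correctness vs a reference impl
        result["deterministic"],     # identical output across repeat runs
        result["replay_ok"],         # fixed-pack replay on the relevant model
        result["drift"] <= result["max_drift"],        # response drift budget
        result["speedup"] > 1.0 or result["hot_path_win"],  # trace-backed win
    ]
    return all(gates)
```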
- It will capture Metal traces only on kept/promoted candidates, using the same fixed prompt packs
- Future research (V2): After a single-kernel path is solid, it will move to the higher-ROI work: reducing Metal dispatch overhead and fusing decode-path ops where it actually pays off
WDYT?