Convert StableHLO models into Apple Core ML format.
StableHLO is the portability layer used by ML frameworks like JAX and PyTorch. This library converts StableHLO programs into Apple's Core ML format via coremltools, enabling deployment on Apple hardware (iOS, macOS, etc.).
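To get a feel for what this library consumes, the StableHLO for a JAX function can be inspected directly via `jax.jit(...).lower(...).as_text()`. This is plain JAX functionality, shown here only as an illustration; the function name is made up for the example:

```python
import jax
import jax.numpy as jnp

def matmul(a, b):
    return jnp.einsum("ij,jk->ik", a, b)

# Lower the jitted function and inspect the StableHLO it produces.
lowered = jax.jit(matmul).lower(jnp.zeros((2, 4)), jnp.zeros((4, 3)))
print(lowered.as_text())  # MLIR module containing stablehlo ops
```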
```shell
pip install stablehlo-coreml
```

Requires Python 3.9+ and targets iOS/macOS 18+.
Models can be exported from any framework that produces StableHLO:
- JAX / Flax / Equinox — via `jax.export`
- PyTorch — via torchax to trace the model into JAX, then `jax.export` to StableHLO
The test suite validates against a broad set of models, including full HuggingFace Transformers such as TinyLlama, T5, DistilBERT, GPT-2, BERT, and Whisper, as well as vision models like ResNet, EfficientNet, ViT, ConvNeXt, and more.
For a real-world example, see gemma-coreml-chat, which exports Google's Gemma 4 model to Core ML using this library.
To convert a StableHLO module:

```python
import coremltools as ct
from stablehlo_coreml.converter import convert
from stablehlo_coreml import DEFAULT_HLO_PIPELINE

mil_program = convert(hlo_module, minimum_deployment_target=ct.target.iOS18)
cml_model = ct.convert(
    mil_program,
    source="milinternal",
    minimum_deployment_target=ct.target.iOS18,
    pass_pipeline=DEFAULT_HLO_PIPELINE,
)
```

To obtain `hlo_module` from JAX:

```python
import jax
from jax._src.lib.mlir import ir
from jax._src.interpreters import mlir as jax_mlir
from jax.export import export
import jax.numpy as jnp

def jax_function(a, b):
    return jnp.einsum("ij,jk -> ik", a, b)

context = jax_mlir.make_ir_context()
input_shapes = (jnp.zeros((2, 4)), jnp.zeros((4, 3)))
jax_exported = export(jax.jit(jax_function))(*input_shapes)
hlo_module = ir.Module.parse(jax_exported.mlir_module(), context=context)
```

For the JAX example to work, you will additionally need to install `absl-py` and `flatbuffers` as dependencies.
JAX models exported with symbolic dimensions are supported. Symbolic dims flow
through `GetDimensionSizeOp`, `DynamicBroadcastInDimOp`, `DynamicIotaOp`, and
shape-assertion `CustomCallOp`s automatically, producing Core ML models with
flexible inputs.
```python
import jax
import jax.numpy as jnp
from jax.export import export, symbolic_shape

jax_exported = export(jax.jit(jax_function))(
    jax.ShapeDtypeStruct(symbolic_shape("batch, 4"), jnp.float32),
    jax.ShapeDtypeStruct((4, 3), jnp.float32),
)
```

When converting to a Core ML model, specify `ct.RangeDim` for each symbolic
dimension so the model accepts a range of sizes at inference time:
```python
cml_model = ct.convert(
    mil_program,
    source="milinternal",
    minimum_deployment_target=ct.target.iOS18,
    pass_pipeline=DEFAULT_HLO_PIPELINE,
    inputs=[
        ct.TensorType(name="_arg0", shape=(ct.RangeDim(1, 2048, 1), 4)),
        ct.TensorType(name="_arg1", shape=(4, 3)),
    ],
)
```

See `tests/test_symbolic_shapes.py` for symbolic matmul, batched einsum, and
multi-axis patterns (for example transformer-style projections).
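As a rough sketch of a multi-axis pattern along those lines (the function and dimension names here are illustrative, not taken from the test suite), two symbolic axes can be exported together:

```python
import jax
import jax.numpy as jnp
from jax.export import export, symbolic_shape

# Transformer-style projection over symbolic (batch, seq) axes.
def project(x, w):
    return jnp.einsum("bsd,dh->bsh", x, w)

jax_exported = export(jax.jit(project))(
    jax.ShapeDtypeStruct(symbolic_shape("batch, seq, 8"), jnp.float32),
    jax.ShapeDtypeStruct((8, 16), jnp.float32),
)
print(jax_exported.mlir_module()[:80])
```

When converting the result, each symbolic axis would get its own `ct.RangeDim` in the corresponding `ct.TensorType` shape.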
The `tests/` directory has end-to-end export and conversion examples:

- PyTorch (torchax) — `tests/pytorch/test_pytorch.py:export_to_stablehlo_module`, HuggingFace Transformers, and torchvision models.
- JAX — `tests/test_jax.py`
- Flax / Equinox — `tests/test_flax.py`, `tests/test_equinox.py`
`coremltools` supports Python up to 3.13, so do not run hatch with a newer interpreter. The interpreter can be pinned with, e.g., `export HATCH_PYTHON=python3.13`. Run the tests using:

```shell
hatch run test:pytest tests
```