A toy specializing compiler for NumPy expressions that uses MLIR as a target and can use equality saturation (e-graphs) to do term rewriting on the intermediate representation, enabling extremely precise and composable optimizations of mathematical expressions before lowering to MLIR.
We use the embedded Datalog DSL egglog to express and compose rewrite rules in pure Python and the egg library to extract optimized syntax trees from the e-graph.
The whole project is just under 1300 lines of code and is designed to be a simple, easy-to-understand example of how to integrate e-graphs into a compiler pipeline.
Think of an e-graph as this magical data structure that's like a super-powered hash table of program expressions. Instead of just storing one way to write a program, it efficiently stores ALL equivalent ways to write it.
Equality saturation is the process of filling this e-graph with all possible equivalent programs by applying rewrite rules until we can't find any more rewrites (that's the "saturation" part). We can explore tons of different optimizations simultaneously, rather than having to pick a specific sequence of transformations. Then you can apply a cost function over the entire e-graph to extract the solution that minimizes some user-defined objective function.
Traditionally you'd have to muddle through with a fixed-point iteration system and tons of top-down/bottom-up rewrite rules contingent on application order, but e-graphs make term rewriting much more efficient, declarative and compositional.
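To make the saturate-then-extract idea concrete, here is a deliberately naive pure-Python sketch. It is not how egglog works internally: a real e-graph shares sub-terms between equivalent programs, while this toy simply enumerates whole terms in a set. The tuple encoding, the three rules, and the `COST` table are all made up for illustration.

```python
# Terms are nested tuples such as ("+", a, b) or ("sinh", a); leaves are "x" or numbers.
S, C = ("sinh", "x"), ("cosh", "x")
start = ("+", ("*", S, C), ("*", C, S))  # sinh(x)*cosh(x) + cosh(x)*sinh(x)

# Hypothetical per-operation costs: transcendental calls are expensive.
COST = {"+": 1, "*": 1, "sinh": 10, "cosh": 10}


def root_rewrites(term):
    """Yield terms equal to `term` by applying one rule at its root."""
    if not isinstance(term, tuple):
        return
    op, *args = term
    if op == "*":
        a, b = args
        yield ("*", b, a)  # a * b  =>  b * a
        if a == 2 and isinstance(b, tuple) and b[0] == "*":
            s, c = b[1], b[2]
            if (isinstance(s, tuple) and isinstance(c, tuple)
                    and s[0] == "sinh" and c[0] == "cosh" and s[1] == c[1]):
                yield ("sinh", ("*", 2, s[1]))  # 2 * (sinh(a) * cosh(a))  =>  sinh(2 * a)
    if op == "+":
        a, b = args
        if a == b:
            yield ("*", 2, a)  # a + a  =>  2 * a


def step(term):
    """Yield every term reachable by applying one rule anywhere inside `term`."""
    yield from root_rewrites(term)
    if isinstance(term, tuple):
        op, *args = term
        for i, arg in enumerate(args):
            for new_arg in step(arg):
                yield (op, *args[:i], new_arg, *args[i + 1:])


def saturate(term):
    """Collect equivalent terms until no rule produces anything new."""
    seen, frontier = {term}, [term]
    while frontier:
        for new in step(frontier.pop()):
            if new not in seen:
                seen.add(new)
                frontier.append(new)
    return seen


def cost(term):
    """Sum the cost of every operation in the term; leaves are free."""
    if not isinstance(term, tuple):
        return 0
    op, *args = term
    return COST[op] + sum(cost(a) for a in args)


equivalents = saturate(start)
best = min(equivalents, key=lambda t: (cost(t), repr(t)))  # repr breaks ties deterministically
print(cost(start))  # 43
print(cost(best))   # 11
```

Notice that no rule ordering was specified: commutativity, `a + a => 2 * a`, and the double-angle rule all fire wherever they match, and the cost function alone decides which equivalent term wins.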
On macOS, install LLVM 20, which includes MLIR:
brew install llvm@20
export PATH="/opt/homebrew/opt/llvm@20/bin:$PATH"

On Linux, install the dependencies (setup instructions here):
sudo apt-get install -y llvm-20 llvm-20-dev llvm-20-tools mlir-20-tools

Then, to use the library, build it with uv:
git clone https://github.com/sdiehl/mlir-egglog.git
cd mlir-egglog
uv sync
uv run python example_basic.py

import numpy as np
from mlir_egglog import kernel

@kernel("float32(float32)")
def fn(x: float) -> float:
    # sinh(2x) = 2 * sinh(x) * cosh(x)
    return np.sinh(x) * np.cosh(x) + np.cosh(x) * np.sinh(x)

out = fn(np.array([1, 2, 3], dtype=np.float32))
print(out)

You can create your own optimization rules using the ruleset decorator. Here's a complete example that optimizes away addition with zero:
import numpy as np
from egglog import rewrite, ruleset
from mlir_egglog import kernel
from mlir_egglog.term_ir import Term
from mlir_egglog.optimization_rules import basic_math

@ruleset
def float_rules(x: Term):
    yield rewrite(x + Term.lit_f32(0.0)).to(x)
    yield rewrite(Term.lit_f32(0.0) + x).to(x)

@kernel("float32(float32)", rewrites=(basic_math, float_rules))
def custom_fn(x):
    return x + 0.0  # This addition will be optimized away!

test_input = np.array([1.0, 2.0, 3.0], dtype=np.float32)
result = custom_fn(test_input)
print(result)

The rewrite rules are applied during compilation, so there's no runtime overhead. The generated MLIR code will be as if you just wrote return x. You can combine multiple rulesets to build up more complex program optimizations.
For a full example see example_rewrite.py.
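If it helps to see why compile-time rewriting costs nothing at call time, here is a rough standalone sketch of the same shape: a decorator traces the function once with a symbolic input, simplifies the traced tree with whatever rules you pass in, and only then builds the callable. Everything here (`Node`, `toy_kernel`, `add_zero`, `mul_one`) is hypothetical and is not the mlir_egglog API. It also uses plain greedy bottom-up rewriting and a Python interpreter as stand-ins for equality saturation and MLIR code generation.

```python
class Node:
    """A traced expression: an input variable, a constant, or a binary operation."""

    def __init__(self, op, *args):
        self.op, self.args = op, args

    def __add__(self, other):
        return Node("add", self, lift(other))

    def __radd__(self, other):
        return Node("add", lift(other), self)

    def __mul__(self, other):
        return Node("mul", self, lift(other))

    def __rmul__(self, other):
        return Node("mul", lift(other), self)


def lift(value):
    """Wrap plain Python numbers as constant nodes."""
    return value if isinstance(value, Node) else Node("const", value)


def is_const(node, value):
    return node.op == "const" and node.args[0] == value


def add_zero(node):
    """x + 0 => x and 0 + x => x; returns None when the rule does not match."""
    if node.op == "add":
        lhs, rhs = node.args
        if is_const(rhs, 0):
            return lhs
        if is_const(lhs, 0):
            return rhs
    return None


def mul_one(node):
    """x * 1 => x and 1 * x => x; returns None when the rule does not match."""
    if node.op == "mul":
        lhs, rhs = node.args
        if is_const(rhs, 1):
            return lhs
        if is_const(lhs, 1):
            return rhs
    return None


def simplify(node, rules):
    """Rewrite children first, then try each rule at this node until none fires."""
    if node.op in ("var", "const"):
        return node
    node = Node(node.op, *(simplify(arg, rules) for arg in node.args))
    for rule in rules:
        result = rule(node)
        if result is not None:
            return simplify(result, rules)
    return node


def evaluate(node, x):
    """Interpret the tree; the real project emits MLIR here instead."""
    if node.op == "var":
        return x
    if node.op == "const":
        return node.args[0]
    lhs, rhs = (evaluate(arg, x) for arg in node.args)
    return lhs + rhs if node.op == "add" else lhs * rhs


def toy_kernel(rewrites=()):
    def decorate(fn):
        traced = lift(fn(Node("var")))           # run fn once on a symbolic input
        optimized = simplify(traced, rewrites)   # rewriting happens here, exactly once

        def compiled(x):
            return evaluate(optimized, x)

        compiled.ir = optimized
        return compiled

    return decorate


@toy_kernel(rewrites=(add_zero, mul_one))
def custom_fn(x):
    return (x + 0.0) * 1.0


print(custom_fn.ir.op)  # var
print(custom_fn(3.0))   # 3.0
```

The optimized tree is a bare variable, so each call does no arithmetic at all. Passing a tuple of rule functions mirrors how rulesets are combined in the `rewrites=` argument above.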
Here's the recommended order to understand the codebase:
Foundation Layer - Expression representation and manipulation
  memory_descriptors.py - Memory management utilities for NumPy arrays and MLIR memref descriptors
  term_ir.py - Intermediate representation for the e-graph system with cost models for each operation

Transformation Layer - Code transformation and lowering
  python_to_ir.py - Converts Python functions to the internal IR representation
  optimization_rules.py - Built-in algebraic and trigonometric rewrite rules

Optimization and Code Generation Layer
  egglog_optimizer.py - Runs the e-graph saturation and extracts the lowest-cost term
  mlir_gen.py - Lowers the extracted term tree to MLIR source text
  mlir_backend.py - Shells out to mlir-opt and mlir-translate to produce LLVM IR

Execution Layer - Runtime execution
  llvm_runtime.py - Initialises llvmlite to query the host target triple and data layout
  jit_engine.py - Compiles LLVM IR to a shared library via llc and cc, then loads it with ctypes
  dispatcher.py - @kernel decorator: drives compilation and dispatches NumPy array calls

Educational
  tutorial.py - Walks through each stage of the compilation pipeline (AST -> IR -> e-graph -> MLIR -> LLVM IR -> machine code), used by example_tutorial.py
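For orientation, the back half of that pipeline (MLIR -> LLVM IR -> machine code) can be pictured as four external tool invocations. The sketch below only builds the command lines. The function names, file naming scheme, and the particular list of `mlir-opt` conversion passes are assumptions for illustration; the passes the project actually runs may differ, for example by adding memref or control-flow lowering.

```python
import ctypes
import subprocess
from pathlib import Path


def toolchain_commands(mlir_file, workdir):
    """Return the four commands that turn an MLIR file into a shared library."""
    workdir = Path(workdir)
    stem = Path(mlir_file).stem
    lowered = workdir / f"{stem}.llvm.mlir"  # MLIR in the LLVM dialect
    llvm_ir = workdir / f"{stem}.ll"         # textual LLVM IR
    obj = workdir / f"{stem}.o"              # object file
    lib = workdir / f"lib{stem}.so"          # shared library to load
    return [
        # Lower the math/arith/func dialects to the LLVM dialect (assumed pass list).
        ["mlir-opt", str(mlir_file),
         "--convert-math-to-llvm", "--convert-arith-to-llvm",
         "--convert-func-to-llvm", "--reconcile-unrealized-casts",
         "-o", str(lowered)],
        # Translate LLVM-dialect MLIR to LLVM IR.
        ["mlir-translate", "--mlir-to-llvmir", str(lowered), "-o", str(llvm_ir)],
        # Compile LLVM IR to a position-independent object file.
        ["llc", "-filetype=obj", "-relocation-model=pic", str(llvm_ir), "-o", str(obj)],
        # Link the object file into a shared library.
        ["cc", "-shared", str(obj), "-o", str(lib)],
    ]


def build_and_load(mlir_file, workdir):
    """Run every step, then load the result; requires the LLVM 20 tools on PATH."""
    commands = toolchain_commands(mlir_file, workdir)
    for command in commands:
        subprocess.run(command, check=True)
    return ctypes.CDLL(commands[-1][-1])


commands = toolchain_commands("kernel.mlir", "build")
for command in commands:
    print(" ".join(command))
```

Keeping command construction separate from execution makes it easy to print the exact invocations when a lowering step fails.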
This project is licensed under the MIT License. See the LICENSE file for details.
