Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
7c22756
Add monoid module (#653)
jfeser May 6, 2026
1d38f0d
wip
jfeser May 7, 2026
4ece287
Merge branch 'staging-weighted' into jf-weighted-lifting
jfeser May 7, 2026
92586c4
cleanup
jfeser May 7, 2026
f7d43e5
wip
jfeser May 7, 2026
bd404eb
Merge branch 'staging-weighted' into jf-weighted-jax
jfeser May 7, 2026
8479ed5
fix rule
jfeser May 12, 2026
bd1025e
wip
jfeser May 12, 2026
39d8bb0
fix bug
jfeser May 12, 2026
5283db1
cleanup
jfeser May 12, 2026
9fd88af
lin
jfeser May 12, 2026
43ca58d
wip
jfeser May 12, 2026
ed8cf13
fix tests
jfeser May 12, 2026
908f580
format
jfeser May 12, 2026
5c504a3
lint
jfeser May 12, 2026
c0472a8
wip
jfeser May 12, 2026
3b5ef0c
Merge branch 'staging-weighted' into jf-weighted-class-refactoring
jfeser May 12, 2026
40a1a7c
Merge branch 'jf-weighted-class-refactoring' into jf-weighted-jax
jfeser May 12, 2026
e99d34a
wip
jfeser May 12, 2026
2fd6bab
Merge branch 'staging-weighted' into jf-weighted-jax
jfeser May 12, 2026
45cac7d
wip
jfeser May 13, 2026
8d0a4e8
wip
jfeser May 13, 2026
dd836bd
wip
jfeser May 13, 2026
fb92486
wip
jfeser May 13, 2026
622d4ac
wip
jfeser May 13, 2026
9db88d7
wip
jfeser May 13, 2026
12c91da
drop runtime typed dict lifting
jfeser May 13, 2026
11fa13a
wip
jfeser May 13, 2026
8711a76
format
jfeser May 13, 2026
0ccbefb
reorganize
jfeser May 13, 2026
76f7002
stop using string dicts to avoid unification issue
jfeser May 14, 2026
6fdd23d
wip
jfeser May 14, 2026
6996a8c
wip
jfeser May 14, 2026
afea06c
wip
jfeser May 14, 2026
08d92ae
wip
jfeser May 14, 2026
b50c059
wip
jfeser May 14, 2026
7e204d0
use check_rewrite in jax tests
jfeser May 14, 2026
77241f9
lint
jfeser May 14, 2026
5e6d333
fix bugs
jfeser May 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions effectful/handlers/jax/_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
deffn,
defop,
syntactic_eq,
syntactic_hash,
)
from effectful.ops.types import Expr, NotHandled, Operation, Term

Expand Down Expand Up @@ -277,3 +278,9 @@ def _(x: jax.Array, other) -> bool:
and x.shape == other.shape
and bool((jnp.asarray(x) == jnp.asarray(other)).all())
)


@syntactic_hash.register(jax.Array)
def _(x: jax.Array) -> int:
# Concrete arrays aren't hashable; hash by shape, dtype, and bytes.
return hash(("jax.Array", x.shape, str(x.dtype), bytes(jax.numpy.asarray(x))))
162 changes: 162 additions & 0 deletions effectful/handlers/jax/monoid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
import functools

import jax

import effectful.handlers.jax.numpy as jnp
from effectful.handlers.jax import bind_dims, unbind_dims
from effectful.handlers.jax.scipy.special import logsumexp
from effectful.ops.monoid import (
CartesianProduct,
Max,
Min,
Monoid,
NormalizeIntp,
Product,
Sum,
outer_stream,
)
from effectful.ops.semantics import evaluate, fvsof, fwd, handler, typeof
from effectful.ops.syntax import ObjectInterpretation, deffn, implements
from effectful.ops.types import Operation


def cartesian_prod(x, y):
if x.ndim == 1:
x = x[:, None]
if y.ndim == 1:
y = y[:, None]
nx, dx = x.shape
ny, dy = y.shape
# Broadcast into (nx, ny, dx+dy), then flatten the first two axes
x_b = jnp.broadcast_to(x[:, None, :], (nx, ny, dx))
y_b = jnp.broadcast_to(y[None, :, :], (nx, ny, dy))
return jnp.concatenate([x_b, y_b], axis=-1).reshape(nx * ny, dx + dy)


LogSumExp = Monoid(name="LogSumExp", identity=jnp.asarray(float("-inf")))


def _jax_args(args):
"""True iff ``args`` is non-empty and every arg is a concrete
:class:`jax.Array` (no Terms).
"""
typs = (typeof(a) for a in args)
return (
bool(args)
and any(issubclass(t, jax.Array) for t in typs)
and all(issubclass(t, jax.typing.ArrayLike) for t in typs)
)


class SumPlusJax(ObjectInterpretation):
@implements(Sum.plus)
def plus(self, *args):
if not _jax_args(args):
return fwd()
return functools.reduce(jnp.add, args)


class ProductPlusJax(ObjectInterpretation):
@implements(Product.plus)
def plus(self, *args):
if not _jax_args(args):
return fwd()
return functools.reduce(jnp.multiply, args)


class MinPlusJax(ObjectInterpretation):
@implements(Min.plus)
def plus(self, *args):
if not _jax_args(args):
return fwd()
return functools.reduce(jnp.minimum, args)


class MaxPlusJax(ObjectInterpretation):
@implements(Max.plus)
def plus(self, *args):
if not _jax_args(args):
return fwd()
return functools.reduce(jnp.maximum, args)


class LogSumExpPlusJax(ObjectInterpretation):
@implements(LogSumExp.plus)
def plus(self, *args):
if not _jax_args(args):
return fwd()
return functools.reduce(jnp.logaddexp, args)


class CartesianProductPlusJax(ObjectInterpretation):
@implements(CartesianProduct.plus)
def plus(self, *args):
# Skip identity ``[()]`` args; short-circuit on zero ``[]``. Both
# sentinels arrive as Python lists alongside jax-array factors, so
# check for them explicitly before composing.
if not any(isinstance(a, jax.Array) for a in args):
return fwd()
result = None
for a in args:
if a is CartesianProduct.zero:
return CartesianProduct.zero
if a is CartesianProduct.identity:
continue
if not isinstance(a, jax.Array):
return fwd()
result = a if result is None else cartesian_prod(result, a)
return result if result is not None else CartesianProduct.identity


ARRAY_REDUCTORS = {
Sum: jnp.sum,
Product: jnp.prod,
Min: jnp.min,
Max: jnp.max,
LogSumExp: logsumexp,
}


class ArrayReduce(ObjectInterpretation):
@implements(Monoid.reduce)
def reduce(self, monoid, body, streams):
if monoid not in ARRAY_REDUCTORS or typeof(body) is not jax.Array:
return fwd()
if not streams:
return monoid.identity

reductor = ARRAY_REDUCTORS[monoid]
index = Operation.define(jax.Array)
for stream_key, stream_body, streams_tail in outer_stream(streams):
if not issubclass(typeof(stream_body), jax.Array):
continue

if stream_key in fvsof(body):
with handler({stream_key: deffn(unbind_dims(stream_body, index))}):
eval_body = evaluate(body)
eval_streams_tail = evaluate(streams_tail)
assert isinstance(eval_streams_tail, dict)
reduce_tail = (
monoid.reduce(eval_body, eval_streams_tail)
if len(eval_streams_tail) > 0
else eval_body
)
return reductor(bind_dims(reduce_tail, index), axis=0)
else:
# TODO: In this case, the stream is unused in the body. The body
# should be multiplied by the length of the stream. The current
# behavior is not efficient.
return fwd()

return fwd()


NormalizeIntp.extend(
ArrayReduce(),
SumPlusJax(),
ProductPlusJax(),
MinPlusJax(),
MaxPlusJax(),
LogSumExpPlusJax(),
CartesianProductPlusJax(),
)
Loading
Loading