From 7c2275653cc9f98630f9d429e6000b093df6088e Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Wed, 6 May 2026 16:21:46 -0400 Subject: [PATCH 01/34] Add monoid module (#653) * add monoid module * clean up * fix doctest * fix * wip * remove incorrect rule * add disjoint set tests and fix bug * lint * drop jax monoid defs * drop incorrect comment * add assert * reduce nondeterminism and add assertions * fix inconsistent stream numbering and missing constant factors --- effectful/internals/disjoint_set.py | 99 +++++ effectful/ops/monoid.py | 556 +++++++++++++++++++++++++++ effectful/ops/syntax.py | 78 ++++ pyproject.toml | 1 + tests/_monoid_helpers.py | 85 ++++ tests/test_internals_disjoint_set.py | 124 ++++++ tests/test_ops_monoid.py | 518 +++++++++++++++++++++++++ 7 files changed, 1461 insertions(+) create mode 100644 effectful/internals/disjoint_set.py create mode 100644 effectful/ops/monoid.py create mode 100644 tests/_monoid_helpers.py create mode 100644 tests/test_internals_disjoint_set.py create mode 100644 tests/test_ops_monoid.py diff --git a/effectful/internals/disjoint_set.py b/effectful/internals/disjoint_set.py new file mode 100644 index 000000000..73b5c5c52 --- /dev/null +++ b/effectful/internals/disjoint_set.py @@ -0,0 +1,99 @@ +class DisjointSet: + """Disjoint Set Union (Union-Find) data structure. + + Maintains a collection of disjoint sets over the integers 0..n-1, + supporting near-constant-time union and find operations via + path compression and union by rank. + + The amortized time complexity per operation is O(α(n)), where α + is the inverse Ackermann function (effectively constant for any + practical n). + + Example: + >>> dsu = DisjointSet(5) + >>> dsu.union(0, 1) + True + >>> dsu.union(1, 2) + True + >>> dsu.find(0) == dsu.find(2) + True + >>> dsu.find(0) == dsu.find(3) + False + """ + + def __init__(self, n): + """Initialize n singleton sets: {0}, {1}, ..., {n-1}. + + Args: + n: The number of elements. Elements are labeled 0..n-1. + """ + self.parent = list(range(n)) + self.rank = [0] * n + + def _validate(self, x): + if x < 0 or x >= len(self.parent): + raise IndexError(f"Element {x} out of bounds") + + def find(self, x): + """Return the representative (root) of the set containing x. + + Two elements belong to the same set if and only if they have + the same representative. Applies path compression: every node + traversed is re-parented directly to its grandparent, flattening + the tree to speed up future queries. + + Args: + x: The element to look up. + + Returns: + The root element of x's set. + """ + self._validate(x) + while self.parent[x] != x: + self.parent[x] = self.parent[self.parent[x]] # path compression + x = self.parent[x] + return x + + def union(self, *elements): + """Merge the sets containing all given elements into one. + + Accepts any number of elements and unions them all together. + Uses union by rank: shallower trees are attached under the root + of the deeper one, keeping the combined tree shallow. + + Args: + *elements: Two or more elements to merge into a single set. + Calling with 0 or 1 elements is a no-op and returns False. + + Returns: + True if any merging occurred (i.e., at least two of the + elements were in different sets); False if all elements + were already in the same set or fewer than 2 were given. + """ + if len(elements) < 2: + return False + + merged = False + first = elements[0] + + for y in elements[1:]: + if self._union_pair(first, y): + merged = True + + return merged + + def _union_pair(self, x, y): + rx = self.find(x) + ry = self.find(y) + + if rx == ry: + return False + + if self.rank[rx] < self.rank[ry]: + rx, ry = ry, rx + + self.parent[ry] = rx + if self.rank[rx] == self.rank[ry]: + self.rank[rx] += 1 + + return True diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py new file mode 100644 index 000000000..58a10ba3d --- /dev/null +++ b/effectful/ops/monoid.py @@ -0,0 +1,556 @@ +import collections.abc +import functools +import itertools +import numbers +import typing +from collections import Counter, defaultdict +from collections.abc import Callable, Generator, Iterable, Iterator, Mapping, Sequence +from dataclasses import dataclass +from graphlib import TopologicalSorter +from typing import Annotated, Any + +from effectful.internals.disjoint_set import DisjointSet +from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler +from effectful.ops.syntax import ( + ObjectInterpretation, + Scoped, + _NumberTerm, + defdata, + implements, + iter_, + syntactic_eq, + syntactic_hash, +) +from effectful.ops.types import Interpretation, NotHandled, Operation, Term + +# Note: The streams value type should be something like Iterable[T], but some of +# our target stream types (e.g. jax.Array) are not subtypes of Iterable +type Streams[T] = Mapping[Operation[[], T], Any] + +type Body[T] = ( + Iterable[T] + | Callable[..., Body[T]] + | T + | Mapping[Any, Body[T]] + | Interpretation[T, Body[T]] +) + + +def order_streams[T](streams: Streams[T]) -> Iterable[tuple[Operation[[], T], Any]]: + """Determine an order to evaluate the streams based on their dependencies""" + stream_vars = set(streams.keys()) + dependencies = {k: fvsof(v) & stream_vars for k, v in streams.items()} + topo = TopologicalSorter(dependencies) + topo.prepare() + while topo.is_active(): + node_group = topo.get_ready() + for op in sorted(node_group): + yield (op, streams[op]) + topo.done(*node_group) + + +class Monoid[T]: + kernel: Operation[[T, T], T] + identity: T + + def __init__(self, kernel: Callable[[T, T], T], identity: T): + self.identity = identity + self.kernel = ( + kernel if isinstance(kernel, Operation) else Operation.define(kernel) + ) + + def __repr__(self): + return f"{type(self)}({self.kernel}, {self.identity})" + + @Operation.define + def plus[S: Body[T]](self, *args: S) -> S: + """Monoid addition with broadcasting over common collection types, + callables, and interpretations. + + """ + if not args: + return typing.cast(S, self.identity) + + if any(isinstance(x, Term) for x in args): + return typing.cast(S, defdata(self.plus, *args)) + + return self._plus(*args) + + @functools.singledispatchmethod + def _plus[S](self, *args: S) -> S: + return typing.cast(S, functools.reduce(self.kernel, args, self.identity)) + + @_plus.register(Sequence) + def _(self, *args): + return type(args[0])(self.plus(*vs) for vs in zip(*args, strict=True)) + + @_plus.register(Mapping) + def _(self, *args): + if isinstance(args[0], Interpretation): + keys = args[0].keys() + + for b in args[1:]: + if not isinstance(b, Interpretation): + raise TypeError(f"Expected interpretation but got {b}") + + b_keys = b.keys() + if not keys == b_keys: + raise ValueError( + f"Expected interpretation of {keys} but got {b_keys}" + ) + + result = {k: self.plus(*(handler(b)(b[k]) for b in args)) for k in keys} + return result + + for b in args[1:]: + if not isinstance(b, Mapping): + raise TypeError(f"Expected mapping but got {b}") + + all_values = collections.defaultdict(list) + for d in args: + for k, v in d.items(): + all_values[k].append(v) + result = {k: self.plus(*vs) for (k, vs) in all_values.items()} + return result + + @Operation.define + @functools.singledispatchmethod + def reduce[A, B, U: Body]( + self, + body: Annotated[U, Scoped[A | B]], + streams: Annotated[Streams, Scoped[A]], + ) -> Annotated[U, Scoped[B]]: + if callable(body): + return typing.cast(U, lambda *a, **k: self.reduce(body(*a, **k), streams)) + + def generator(loop_order) -> Iterator[Interpretation]: + if len(loop_order) == 0: + return + + stream_key = loop_order[0][0] + stream_values = evaluate(streams[stream_key]) + stream_values_iter = iter(stream_values) # type: ignore[arg-type] + + # If we try to iterate and get a term instead of a real + # iterator, give up + if isinstance(stream_values_iter, Term) and stream_values_iter.op is iter_: + raise NotHandled + + if len(loop_order) == 1: + for val in stream_values_iter: + yield {stream_key: functools.partial(lambda v: v, val)} + else: + for val in stream_values_iter: + intp: Interpretation = { + stream_key: functools.partial(lambda v: v, val) + } + with handler(intp): + for intp2 in generator(loop_order[1:]): + yield coproduct(intp, intp2) + + loop_order = list(order_streams(streams)) + try: + return self.plus( + *(handler(intp)(evaluate)(body) for intp in generator(loop_order)) + ) + except NotHandled: + return typing.cast(U, defdata(self.reduce, body, streams)) + + @reduce.register # type: ignore[attr-defined] + def _(self, body: Mapping, streams): + return {k: self.reduce(v, streams) for (k, v) in body.items()} + + @reduce.register # type: ignore[attr-defined] + def _(self, body: Sequence, streams): + return type(body)(self.reduce(x, streams) for x in body) # type:ignore[call-arg] + + @reduce.register # type: ignore[attr-defined] + def _(self, body: Generator, streams): + return (self.reduce(x, streams) for x in body) + + +class IdempotentMonoid[T](Monoid[T]): + @Operation.define + def plus[S: Body[T]](self, *args: S) -> S: + return super().plus(*args) + + @Operation.define + def reduce[A, B, U: Body]( + self, + body: Annotated[U, Scoped[A | B]], + streams: Annotated[Streams, Scoped[A]], + ) -> Annotated[U, Scoped[B]]: + return super().reduce(body, streams) + + +class CommutativeMonoid[T](Monoid[T]): + @Operation.define + def plus[S: Body[T]](self, *args: S) -> S: + return super().plus(*args) + + @Operation.define + def reduce[A, B, U: Body]( + self, + body: Annotated[U, Scoped[A | B]], + streams: Annotated[Streams, Scoped[A]], + ) -> Annotated[U, Scoped[B]]: + return super().reduce(body, streams) + + +class CommutativeMonoidWithZero[T](CommutativeMonoid[T]): + zero: T + + def __init__(self, kernel: Callable[[T, T], T], identity: T, zero: T): + super().__init__(kernel, identity) + self.zero = zero + + def __repr__(self): + return f"{type(self)}({self.kernel}, {self.identity}, {self.zero})" + + @Operation.define + def plus[S: Body[T]](self, *args: S) -> S: + return super().plus(*args) + + @Operation.define + def reduce[A, B, U: Body]( + self, + body: Annotated[U, Scoped[A | B]], + streams: Annotated[Streams, Scoped[A]], + ) -> Annotated[U, Scoped[B]]: + return super().reduce(body, streams) + + +class Semilattice[T](IdempotentMonoid[T], CommutativeMonoid[T]): + @Operation.define + def plus[S: Body[T]](self, *args: S) -> S: + return super().plus(*args) + + @Operation.define + def reduce[A, B, U: Body]( + self, + body: Annotated[U, Scoped[A | B]], + streams: Annotated[Streams, Scoped[A]], + ) -> Annotated[U, Scoped[B]]: + return super().reduce(body, streams) + + +@Operation.define +def _arg_min[T]( + a: tuple[numbers.Number, T | None], b: tuple[numbers.Number, T | None] +) -> tuple[numbers.Number, T | None]: + if isinstance(a[0], Term) or isinstance(b[0], Term): + raise NotHandled + return b if b[0] < a[0] else a # type: ignore + + +@Operation.define +def _arg_max[T]( + a: tuple[numbers.Number, T | None], b: tuple[numbers.Number, T | None] +) -> tuple[numbers.Number, T | None]: + if isinstance(a[0], Term) or isinstance(b[0], Term): + raise NotHandled + return b if b[0] > a[0] else a # type: ignore + + +Min = Semilattice(kernel=min, identity=float("inf")) +Max = Semilattice(kernel=max, identity=float("-inf")) +ArgMin = Monoid(kernel=_arg_min, identity=(float("inf"), None)) +ArgMax = Monoid(kernel=_arg_max, identity=(float("-inf"), None)) +Sum = CommutativeMonoid(kernel=_NumberTerm.__add__, identity=0) +Product = CommutativeMonoidWithZero(kernel=_NumberTerm.__mul__, identity=1, zero=0) + + +@dataclass +class _ExtensibleBinaryRelation[S, T]: + tuples: set[tuple[S, T]] + + def register(self, s: S, t: T) -> None: + self.tuples.add((s, t)) + + def __call__(self, s: S, t: T) -> bool: + return (s, t) in self.tuples + + +distributes_over = _ExtensibleBinaryRelation( + { + (Max.plus, Min.plus), + (Min.plus, Max.plus), + (Sum.plus, Min.plus), + (Sum.plus, Max.plus), + (Product.plus, Sum.plus), + } +) + + +class PlusEmpty(ObjectInterpretation): + """plus() = 0""" + + @implements(Monoid.plus) + def plus(self, monoid, *args): + if not args: + return monoid.identity + return fwd() + + +class PlusSingle(ObjectInterpretation): + """plus(x) = x""" + + @implements(Monoid.plus) + def plus(self, _, *args): + if len(args) == 1: + return args[0] + return fwd() + + +class PlusIdentity(ObjectInterpretation): + """x₁ + ... + 0 + ... + xₙ = x₁ + ... + xₙ""" + + @implements(Monoid.plus) + def plus(self, monoid, *args): + if any(x is monoid.identity for x in args): + return monoid.plus(*(x for x in args if x is not monoid.identity)) + return fwd() + + +class PlusAssoc(ObjectInterpretation): + """x + (y + z) = (x + y) + z = x + y + z""" + + @implements(Monoid.plus) + def plus(self, monoid, *args): + if any(isinstance(x, Term) and x.op is monoid.plus for x in args): + flat_args = itertools.chain.from_iterable( + t.args if isinstance(t, Term) and t.op is monoid.plus else (t,) + for t in args + ) + assert len(args) > 0 + return monoid.plus(*flat_args) + return fwd() + + +class PlusDistr(ObjectInterpretation): + """x + (y * z) = x * y + x * z""" + + @implements(Monoid.plus) + def plus(self, monoid, *args): + if any( + isinstance(x, Term) and distributes_over(monoid.plus, x.op) for x in args + ): + non_terms = [] + + # group terms by head operation + by_head_op = defaultdict(list) + for t in args: + if isinstance(t, Term): + by_head_op[t.op].append(t) + else: + non_terms.append(t) + + # distribute over each group + progress = False + final_sum = [] + for op, terms in by_head_op.items(): + if ( + len(terms) > 1 + and distributes_over(monoid.plus, op) + and not distributes_over(op, monoid.plus) + ): + progress = True + term_args = (t.args for t in terms) + dist_terms = ( + monoid.plus(*args) for args in itertools.product(*term_args) + ) + final_sum.append(op(*dist_terms)) + else: + final_sum += terms + if progress: + return monoid.plus(*non_terms, *final_sum) + return fwd() + + +class PlusZero(ObjectInterpretation): + """x₁ * ... * 0 * ... * xₙ = 0""" + + @implements(CommutativeMonoidWithZero.plus) + def plus(self, monoid, *args): + if any(x is monoid.zero for x in args): + return monoid.zero + return fwd() + + +class PlusConsecutiveDups(ObjectInterpretation): + """x ⊕ x ⊕ y = x ⊕ y""" + + @implements(IdempotentMonoid.plus) + def plus(self, monoid, *args): + dedup_args = ( + args[i] + for i in range(len(args)) + if i == 0 or not syntactic_eq(args[i - 1], args[i]) + ) + return fwd(monoid, *dedup_args) + + +class PlusDups(ObjectInterpretation): + """x ⊕ y ⊕ x = x ⊕ y""" + + @dataclass + class _HashableTerm: + term: Term + + def __eq__(self, other): + return syntactic_eq(self, other) + + def __hash__(self): + return syntactic_hash(self) + + @implements(Semilattice.plus) + def plus(self, monoid, *args): + # elim dups + args_count = Counter(self._HashableTerm(t) for t in args) + if len(args_count) < len(args): + dedup_args = [] + for t in args: + ht = self._HashableTerm(t) + if ht in args_count: + dedup_args.append(t) + del args_count[ht] + return fwd(monoid, *dedup_args) + return fwd() + + +NormalizePlusIntp = functools.reduce( + coproduct, + typing.cast( + list[Interpretation], + [ + PlusEmpty(), + PlusSingle(), + PlusIdentity(), + PlusAssoc(), + PlusDistr(), + PlusZero(), + PlusConsecutiveDups(), + PlusDups(), + ], + ), +) + + +class ReduceNoStreams(ObjectInterpretation): + """Implements the identity + reduce(R, ∅, body) = 0 + """ + + @implements(Monoid.reduce) + def reduce(self, monoid, _, streams): + if len(streams) == 0: + return monoid.identity + return fwd() + + +class ReduceFusion(ObjectInterpretation): + """Implements the identity + reduce(R, S1, reduce(R, S2, body)) = reduce(R, S1 ∪ S2, body) + """ + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if isinstance(body, Term) and body.op == monoid.reduce: + return monoid.reduce(body.args[0], streams | body.args[1]) + return fwd() + + +class ReduceSplit(ObjectInterpretation): + """Implements the identity + reduce(R, S, b1 + ... + bn) = reduce(R, S, b1) + ... + reduce(R, S, bn) + """ + + @implements(CommutativeMonoid.reduce) + def reduce(self, monoid, body, streams): + if isinstance(body, Term) and body.op == monoid.plus: + return monoid.plus(*(monoid.reduce(x, streams) for x in body.args)) + return fwd() + + +class ReduceFactorization(ObjectInterpretation): + """ + Implements factorization of independent terms. + For example, when having two independent distributions, + we can rewrite their marginalization as: + ∫p(x)⋅q(y)dxdy => ∫p(x)dx ⋅ ∫q(y)dy + + More specifically, in terms of reduces we are performing: + reduce(R, (S₁ × ... × Sₖ) , A₁ * ... * Aₖ) + => reduce(R, S₁, A₁) * ... * reduce(R, Sₖ, Aₖ) + where free(Aᵢ) ∩ free(Aⱼ) ∩ S = ∅ + and free(Aᵢ) ∩ S ⊆ Sᵢ + """ + + @implements(CommutativeMonoid.reduce) + def reduce(self, monoid, body, streams): + if isinstance(body, Term) and distributes_over(body.op, monoid.plus): + stream_vars = set(streams.keys()) + factors = [(arg, fvsof(arg)) for arg in body.args] + stream_ids = {v: i for (i, v) in enumerate(stream_vars)} + ds = DisjointSet(len(streams)) + + # streams are in the same partition as their dependencies + for stream_var, stream_id in stream_ids.items(): + stream_body = streams[stream_var] + deps = sorted([stream_ids[v] for v in fvsof(stream_body) & stream_vars]) + ds.union(stream_id, *deps) + + # factors are in the same partition as their dependencies + for factor, factor_fvs in factors: + factor_streams = sorted( + [stream_ids[v] for v in (factor_fvs & stream_vars)] + ) + ds.union(*factor_streams) + + placed_streams = set() + new_reduces = [] + for stream_key in streams: + if stream_key in placed_streams: + continue + + partition = ds.find(stream_ids[stream_key]) + partition_streams = { + k: v + for (k, v) in streams.items() + if ds.find(stream_ids[k]) == partition + } + partition_stream_keys = set(partition_streams.keys()) + + partition_factors = [ + t for t in factors if (t[1] & partition_stream_keys) + ] + + assert all( + (t[1] & stream_vars) <= partition_stream_keys + for t in partition_factors + ), "partition contains all streams required by factor" + + partition_term = body.op(*(t[0] for t in partition_factors)) + new_reduces.append((partition_term, partition_streams)) + placed_streams |= partition_stream_keys + + constant_factors = [t for (t, fvs) in factors if not (fvs & stream_vars)] + + if len(new_reduces) > 1: + result = body.op( + *constant_factors, *(monoid.reduce(*args) for args in new_reduces) + ) + return result + + return fwd() + + +NormalizeReduceIntp = functools.reduce( + coproduct, + typing.cast( + list[Interpretation], + [ReduceNoStreams(), ReduceFusion(), ReduceSplit(), ReduceFactorization()], + ), +) + +NormalizeIntp = coproduct(NormalizePlusIntp, NormalizeReduceIntp) diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index 764016752..8fb12598f 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -852,6 +852,84 @@ def _(x: object, other) -> bool: return x == other +@_CustomSingleDispatchCallable +def syntactic_hash(__dispatch: Callable[[type], Callable[[Any], int]], x) -> int: + """Structural hash compatible with :func:`syntactic_eq`. + + Guarantees that ``syntactic_eq(x, y)`` implies + ``syntactic_hash(x) == syntactic_hash(y)``. + + :param x: A term. + :returns: An integer hash. + """ + if dataclasses.is_dataclass(x) and not isinstance(x, type): + return hash( + ( + "dataclass", + type(x), + syntactic_hash( + { + field.name: getattr(x, field.name) + for field in dataclasses.fields(x) + } + ), + ) + ) + else: + return __dispatch(type(x))(x) + + +@syntactic_hash.register +def _(x: Term) -> int: + return hash( + ( + "term", + x.op, + len(x.args), + tuple(syntactic_hash(a) for a in x.args), + # sort kwargs so order doesn't affect the hash + tuple((k, syntactic_hash(x.kwargs[k])) for k in sorted(x.kwargs)), + ) + ) + + +@syntactic_hash.register +def _(x: collections.abc.Mapping) -> int: + # XOR over (key_hash, value_hash) pairs — order-independent, + # matching the set-based comparison in syntactic_eq's Mapping branch. + acc = 0 + for k in x: + acc ^= hash((hash(k), syntactic_hash(x[k]))) + return hash(("mapping", acc)) + + +@syntactic_hash.register +def _(x: collections.abc.Sequence) -> int: + if ( + isinstance(x, tuple) + and hasattr(x, "_fields") + and all(hasattr(x, f) for f in x._fields) + ): + return hash( + ( + "namedtuple", + type(x), + tuple(syntactic_hash(getattr(x, f)) for f in x._fields), + ) + ) + else: + # Use the abstract Sequence tag (not type(x)) because syntactic_eq + # treats any two Sequences of equal length and elementwise-equal + # contents as equal — e.g. [1,2] and (1,2) compare equal. + return hash(("sequence", len(x), tuple(syntactic_hash(a) for a in x))) + + +@syntactic_hash.register(object) +@syntactic_hash.register(str | bytes) +def _(x: object) -> int: + return hash(x) + + class ObjectInterpretation[T, V](collections.abc.Mapping): """A helper superclass for defining an ``Interpretation`` of many :class:`~effectful.ops.types.Operation` instances with shared state or behavior. diff --git a/pyproject.toml b/pyproject.toml index d565403f2..685aaf55f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ test = [ "pytest-cov", "pytest-xdist", "pytest-benchmark", + "hypothesis", "mypy", "ruff", "nbval", diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py new file mode 100644 index 000000000..4532ae72d --- /dev/null +++ b/tests/_monoid_helpers.py @@ -0,0 +1,85 @@ +from collections.abc import Callable, Mapping, Sequence +from typing import Any, get_args, get_origin + +from hypothesis import strategies as st + +from effectful.ops.syntax import deffn +from effectful.ops.types import Operation + + +def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: + """Strategy for the value an *0-arg* Operation should return.""" + if annotation is int: + return st.integers() + if annotation is float: + return st.floats(allow_nan=False) + if get_origin(annotation) is list and get_args(annotation) == (int,): + return st.lists(st.integers()) + raise NotImplementedError( + f"No value strategy for return annotation {annotation!r}; " + "supported: int, list[int]" + ) + + +_UNARY_INT_FNS: list[Callable[[int], int]] = [ + lambda x: x, + lambda x: x + 1, + lambda x: x - 1, + lambda x: -x, + lambda x: 2 * x, + lambda x: 3 * x + 1, +] + +_BINARY_INT_FNS: list[Callable[[int, int], int]] = [ + lambda x, y: x + y, + lambda x, y: x - y, + lambda x, y: x * y, + lambda x, y: x + 2 * y, + lambda x, y: 2 * x - y, +] + +_UNARY_LIST_FNS: list[Callable[[int], list[int]]] = [ + lambda _x: [], + lambda x: [x], + lambda x: [x, x + 1], + lambda x: [x, -x], + lambda x: [0, x, x + 1], +] + + +def _strategy_for_op(op: Operation) -> st.SearchStrategy[Callable[..., Any]]: + """Pick a strategy producing a callable suitable for binding `op` in an + interpretation. Inspects the operation's signature. + """ + sig = op.__signature__ + params = list(sig.parameters.values()) + ret = sig.return_annotation + param_types = tuple(p.annotation for p in params) + + if not params: + return _value_strategy_for(ret).map(deffn) + if ret is int and param_types == (int,): + return st.sampled_from(_UNARY_INT_FNS) + if ret is int and param_types == (int, int): + return st.sampled_from(_BINARY_INT_FNS) + if get_origin(ret) is list and get_args(ret) == (int,) and param_types == (int,): + return st.sampled_from(_UNARY_LIST_FNS) + raise NotImplementedError( + f"Function-typed free var must return int or list[int]; got {ret!r} for {op}" + ) + + +@st.composite +def random_interpretation( + draw: st.DrawFn, free_vars: Sequence[Operation] +) -> Mapping[Operation, Callable[..., Any]]: + """Draw an Interpretation binding every Operation in `case.free_vars` to + a randomly chosen value/callable. Keys are Operation identities. + """ + intp: dict[Operation, Callable[..., Any]] = {} + for op in free_vars: + intp[op] = draw(_strategy_for_op(op)) + return intp + + +__all__ = ["random_interpretation"] diff --git a/tests/test_internals_disjoint_set.py b/tests/test_internals_disjoint_set.py new file mode 100644 index 000000000..808b8d25d --- /dev/null +++ b/tests/test_internals_disjoint_set.py @@ -0,0 +1,124 @@ +import random + +import pytest + +from effectful.internals.disjoint_set import DisjointSet + + +@pytest.fixture +def dsu(): + return DisjointSet(10) + + +def test_initial_state(dsu): + for i in range(10): + assert dsu.find(i) == i + + +def test_simple_union(dsu): + assert dsu.union(1, 2) is True + assert dsu.find(1) == dsu.find(2) + + +def test_union_idempotent(dsu): + dsu.union(1, 2) + assert dsu.union(1, 2) is False + + +def test_union_chain(dsu): + dsu.union(1, 2) + dsu.union(2, 3) + assert dsu.find(1) == dsu.find(3) + + +def test_union_multiple_elements_all_connected(dsu): + dsu.union(1, 2, 3, 4, 5) + roots = {dsu.find(i) for i in [1, 2, 3, 4, 5]} + assert len(roots) == 1 + + +def test_union_multiple_elements_partial_overlap(dsu): + dsu.union(1, 2) + dsu.union(3, 4) + dsu.union(2, 3, 5) + + roots = {dsu.find(i) for i in [1, 2, 3, 4, 5]} + assert len(roots) == 1 + + +def test_union_multiple_elements_with_existing_connections(dsu): + dsu.union(1, 2) + dsu.union(2, 3) + dsu.union(3, 4, 5, 6) + + roots = {dsu.find(i) for i in [1, 2, 3, 4, 5, 6]} + assert len(roots) == 1 + + +def test_union_single_element(dsu): + assert dsu.union(1) is False + + +def test_union_no_elements(dsu): + assert dsu.union() is False + + +def test_union_self(dsu): + assert dsu.union(3, 3) is False + assert dsu.find(3) == 3 + + +def test_transitivity(dsu): + dsu.union(1, 2) + dsu.union(2, 3) + dsu.union(3, 4) + assert dsu.find(1) == dsu.find(4) + + +def test_disjoint_sets_remain_separate(dsu): + dsu.union(1, 2) + dsu.union(3, 4) + assert dsu.find(1) != dsu.find(3) + + +def test_randomized_unions(): + n = 50 + dsu = DisjointSet(n) + + groups = [{i} for i in range(n)] + + def find_group(x): + for g in groups: + if x in g: + return g + + for _ in range(100): + elems = random.sample(range(n), random.randint(2, 5)) + dsu.union(*elems) + + # merge ground-truth groups + merged = set() + for e in elems: + merged |= find_group(e) + + groups = [g for g in groups if g.isdisjoint(merged)] + groups.append(merged) + + # verify structure matches ground truth + for g in groups: + roots = {dsu.find(x) for x in g} + assert len(roots) == 1 + + +def test_path_compression_effect(): + dsu = DisjointSet(6) + dsu.union(0, 1) + dsu.union(1, 2) + dsu.union(2, 3) + dsu.union(3, 4) + + # Trigger compression + root_before = dsu.find(4) + root_after = dsu.find(4) + + assert root_before == root_after diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py new file mode 100644 index 000000000..a22928cca --- /dev/null +++ b/tests/test_ops_monoid.py @@ -0,0 +1,518 @@ +import functools +import itertools + +import pytest +from hypothesis import given, settings +from hypothesis import strategies as st + +from effectful.internals.runtime import interpreter +from effectful.ops.monoid import Max, Min, NormalizeIntp, Product, Semilattice, Sum +from effectful.ops.semantics import apply, evaluate, fvsof, handler +from effectful.ops.syntax import _BaseTerm, defdata, syntactic_eq +from effectful.ops.types import NotHandled, Operation +from tests._monoid_helpers import random_interpretation + +_INT = st.integers(min_value=-100, max_value=100) + +ALL_MONOIDS = [ + pytest.param(Sum, id="Sum"), + pytest.param(Product, id="Product"), + pytest.param(Min, id="Min"), + pytest.param(Max, id="Max"), +] + +COMMUTATIVE = [ + pytest.param(Sum, id="Sum"), + pytest.param(Product, id="Product"), + pytest.param(Min, id="Min"), + pytest.param(Max, id="Max"), +] + +IDEMPOTENT = [ + pytest.param(Min, id="Min"), + pytest.param(Max, id="Max"), +] + +WITH_ZERO = [ + pytest.param(Product, id="Product"), +] + + +def define_vars(*names, typ=int): + if len(names) == 1: + return Operation.define(typ, name=names[0]) + return tuple(Operation.define(typ, name=n) for n in names) + + +@functools.cache +def _canonical_op(idx: int) -> Operation: + """Globally cached canonical Operation, keyed by encounter index. + + Cached so that two independent canonicalize runs return the same + Operation object for the same index — letting ``syntactic_eq`` + compare canonical forms by Operation identity. + """ + return Operation.define(int, name=f"__cv_{idx}") + + +def syntactic_eq_alpha(x, y) -> bool: + """Alpha-equivalence-respecting variant of ``syntactic_eq``. + + Walks each expression bottom-up with :func:`evaluate` and renames + every bound variable to a deterministic canonical Operation. The + canonical names are assigned by a counter that increments in + ``evaluate``'s natural traversal order, so two alpha-equivalent + expressions canonicalize to syntactically identical results. + """ + return syntactic_eq(_canonicalize(x), _canonicalize(y)) + + +def _canonicalize(expr): + counter = itertools.count() + + def _passthrough(op, *args, **kwargs): + return defdata(op, *args, **kwargs) + + def _substitute(arg, renaming): + """Apply a bound-variable renaming using ``evaluate`` for traversal.""" + if not renaming: + return arg + with interpreter({apply: _passthrough, **renaming}): + return evaluate(arg) + + def _bound_var_order(args, kwargs, bound_set): + """Return bound variables in deterministic encounter order.""" + seen: list[Operation] = [] + seen_set: set[Operation] = set() + + def _capture(op, *a, **kw): + if op in bound_set and op not in seen_set: + seen.append(op) + seen_set.add(op) + return defdata(op, *a, **kw) + + # ``evaluate`` walks Terms, lists, tuples, mappings, dataclasses, + # etc. for free; the apply handler captures bound vars used as + # ``x()`` anywhere in the body. + with interpreter({apply: _capture}): + evaluate((args, kwargs)) + + # Binders bypass the apply handler. Pick them up with a small structural + # walk that visits dict keys too. + def _walk_bare(obj): + if isinstance(obj, Operation): + if obj in bound_set and obj not in seen_set: + seen.append(obj) + seen_set.add(obj) + elif isinstance(obj, dict): + for k, v in obj.items(): + _walk_bare(k) + _walk_bare(v) + elif isinstance(obj, list | set | frozenset | tuple): + for v in obj: + _walk_bare(v) + + _walk_bare((args, kwargs)) + return seen + + def _apply_canonical(op, *args, **kwargs): + bindings = op.__fvs_rule__(*args, **kwargs) + all_bound: set[Operation] = set().union( + *bindings.args, *bindings.kwargs.values() + ) + if not all_bound: + return defdata(op, *args, **kwargs) + + order = _bound_var_order(args, kwargs, all_bound) + canonical = {var: _canonical_op(next(counter)) for var in order} + assert all_bound <= set(order) + + new_args = tuple( + _substitute( + arg, {v: canonical[v] for v in bindings.args[i] if v in canonical} + ) + for i, arg in enumerate(args) + ) + new_kwargs = { + k: _substitute( + v, + {var: canonical[var] for var in bindings.kwargs[k] if var in canonical}, + ) + for k, v in kwargs.items() + } + + # avoid the renaming from defdata + return _BaseTerm(op, *new_args, **new_kwargs) + + with interpreter({apply: _apply_canonical}): + return evaluate(expr) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +@given(a=_INT, b=_INT, c=_INT) +@settings(max_examples=50, deadline=None) +def test_associativity(monoid, a, b, c): + left = monoid.plus(monoid.plus(a, b), c) + right = monoid.plus(a, monoid.plus(b, c)) + assert left == right + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +@given(a=_INT) +@settings(max_examples=50, deadline=None) +def test_identity(monoid, a): + assert monoid.plus(monoid.identity, a) == a + assert monoid.plus(a, monoid.identity) == a + + +@pytest.mark.parametrize("monoid", COMMUTATIVE) +@given(a=_INT, b=_INT) +@settings(max_examples=50, deadline=None) +def test_commutativity(monoid, a, b): + assert monoid.plus(a, b) == monoid.plus(b, a) + + +@pytest.mark.parametrize("monoid", IDEMPOTENT) +@given(a=_INT) +@settings(max_examples=50, deadline=None) +def test_idempotence(monoid, a): + assert monoid.plus(a, a) == a + + +@pytest.mark.parametrize("monoid", WITH_ZERO) +@given(a=_INT) +@settings(max_examples=50, deadline=None) +def test_zero_absorbs(monoid, a): + assert monoid.plus(monoid.zero, a) == monoid.zero + assert monoid.plus(a, monoid.zero) == monoid.zero + + +def _check_pair(lhs, rhs, *, free_vars=[], max_examples: int = 25) -> None: + """Run structural + semantic checks on a TermPair.""" + with handler(NormalizeIntp): + norm = evaluate(lhs) + + assert syntactic_eq_alpha(norm, rhs) + + @given(intp=random_interpretation(free_vars)) + @settings(max_examples=max_examples, deadline=None) + def _check_semantics(intp): + with handler(intp): + lhs_val = evaluate(lhs) + rhs_val = evaluate(rhs) + assert lhs_val == rhs_val + + _check_semantics() + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_empty(monoid): + _check_pair(lhs=monoid.plus(), rhs=monoid.identity) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_single(monoid): + x = define_vars("x", typ=type(monoid.identity)) + _check_pair(lhs=monoid.plus(x()), rhs=x(), free_vars=[x]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_identity_right(monoid): + x = define_vars("x", typ=type(monoid.identity)) + _check_pair(lhs=monoid.plus(x(), monoid.identity), rhs=x(), free_vars=[x]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_identity_left(monoid): + x = define_vars("x", typ=type(monoid.identity)) + _check_pair(lhs=monoid.plus(monoid.identity, x()), rhs=x(), free_vars=[x]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_assoc_right(monoid): + x, y, z = define_vars("x", "y", "z", typ=type(monoid.identity)) + _check_pair( + lhs=monoid.plus(x(), monoid.plus(y(), z())), + rhs=monoid.plus(x(), y(), z()), + free_vars=[x, y, z], + ) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_assoc_left(monoid): + x, y, z = define_vars("x", "y", "z", typ=type(monoid.identity)) + _check_pair( + lhs=monoid.plus(monoid.plus(x(), y()), z()), + rhs=monoid.plus(x(), y(), z()), + free_vars=[x, y, z], + ) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_sequence(monoid): + a, b, c, d = define_vars("a", "b", "c", "d", typ=type(monoid.identity)) + _check_pair( + lhs=monoid.plus([a(), b()], [c(), d()]), + rhs=[monoid.plus(a(), c()), monoid.plus(b(), d())], + free_vars=[a, b, c, d], + ) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_mapping(monoid): + a, b, c, d = define_vars("a", "b", "c", "d", typ=type(monoid.identity)) + _check_pair( + lhs=monoid.plus({"x": a(), "y": b()}, {"x": c(), "z": d()}), + rhs={"x": monoid.plus(a(), c()), "y": b(), "z": d()}, + free_vars=[a, b, c, d], + ) + + +def test_plus_distributes(): + a, b, c, d = define_vars("a", "b", "c", "d") + lhs = Product.plus(Sum.plus(a(), b()), Sum.plus(c(), d())) + rhs = Sum.plus( + Product.plus(a(), c()), + Product.plus(a(), d()), + Product.plus(b(), c()), + Product.plus(b(), d()), + ) + _check_pair(lhs=lhs, rhs=rhs, free_vars=[a, b, c, d]) + + +def test_plus_distributes_constant(): + a, b, c, d = define_vars("a", "b", "c", "d") + lhs = Product.plus(Sum.plus(a(), b()), Sum.plus(c(), d()), 5) + rhs = Product.plus( + 5, + Sum.plus( + Product.plus(a(), c()), + Product.plus(a(), d()), + Product.plus(b(), c()), + Product.plus(b(), d()), + ), + ) + _check_pair(lhs=lhs, rhs=rhs, free_vars=[a, b, c, d]) + + +def test_plus_distributes_multiple(): + a, b, c, d = define_vars("a", "b", "c", "d") + lhs = Sum.plus( + Min.plus(a(), b()), + Min.plus(c(), d()), + Max.plus(a(), b()), + Max.plus(c(), d()), + ) + rhs = Sum.plus( + Min.plus( + Sum.plus(a(), c()), + Sum.plus(a(), d()), + Sum.plus(b(), c()), + Sum.plus(b(), d()), + ), + Max.plus( + Sum.plus(a(), c()), + Sum.plus(a(), d()), + Sum.plus(b(), c()), + Sum.plus(b(), d()), + ), + ) + _check_pair(lhs=lhs, rhs=rhs, free_vars=[a, b, c, d]) + + +@pytest.mark.parametrize("monoid", IDEMPOTENT) +def test_plus_idempotent_consecutive(monoid): + """``a, a, b → a, b`` — only consecutive duplicates collapse.""" + a, b = define_vars("a", "b") + lhs = monoid.plus(a(), a(), b()) + return _check_pair(lhs=lhs, rhs=monoid.plus(a(), b()), free_vars=[a, b]) + + +@pytest.mark.parametrize("monoid", IDEMPOTENT) +def test_plus_idempotent_non_consecutive(monoid): + """``a, b, a`` — Semilattice (Min/Max) collapses via commutative + PlusDups; plain IdempotentMonoid leaves it as-is (consecutive-only).""" + a, b = define_vars("a", "b") + lhs = monoid.plus(a(), b(), a()) + if isinstance(monoid, Semilattice): + rhs = monoid.plus(a(), b()) + else: + rhs = monoid.plus(a(), b(), a()) + _check_pair(lhs=lhs, rhs=rhs, free_vars=[a, b]) + + +def test_plus_commutative_idempotent_long(): + """Long alternation collapses via commutative dedup (Min/Max only).""" + a, b = define_vars("a", "b") + lhs = Min.plus(a(), b(), a(), b(), b(), a(), a()) + _check_pair(lhs=lhs, rhs=Min.plus(a(), b()), free_vars=[a, b]) + + +@pytest.mark.parametrize("monoid", WITH_ZERO) +def test_plus_zero(monoid): + a = define_vars("a") + lhs_right = monoid.plus(a(), monoid.zero) + lhs_left = monoid.plus(monoid.zero, a()) + _check_pair(lhs=lhs_right, rhs=monoid.zero, free_vars=[a]) + _check_pair(lhs=lhs_left, rhs=monoid.zero, free_vars=[a]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_reduce_body_sequence(monoid): + x = Operation.define(int, name="x") + X = Operation.define(list[int], name="X") + + @Operation.define + def f(_x: int) -> int: + raise NotHandled + + g = Operation.define(f, name="g") + + lhs = monoid.reduce([f(x()), g(x())], {x: X()}) + rhs = [monoid.reduce(f(x()), {x: X()}), monoid.reduce(g(x()), {x: X()})] + + _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, f, g]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_reduce_body_sequence_2(monoid): + x, y = define_vars("x", "y") + X, Y = define_vars("X", "Y", typ=list[int]) + + @Operation.define + def f(_x: int) -> int: + raise NotHandled + + g = Operation.define(f, name="g") + + lhs = monoid.reduce([f(x()), g(y())], {x: X(), y: Y()}) + rhs = [ + monoid.reduce(f(x()), {x: X(), y: Y()}), + monoid.reduce(g(y()), {x: X(), y: Y()}), + ] + + _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, Y, f, g]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_reduce_body_mapping(monoid): + x = Operation.define(int, name="x") + X = Operation.define(list[int], name="X") + + @Operation.define + def f(_x: int) -> int: + raise NotHandled + + g = Operation.define(f, name="g") + + lhs = monoid.reduce({"a": f(x()), "b": g(x())}, {x: X()}) + rhs = { + "a": monoid.reduce(f(x()), {x: X()}), + "b": monoid.reduce(g(x()), {x: X()}), + } + _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, f, g]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_reduce_no_streams(monoid): + a = define_vars("a") + lhs = monoid.reduce(a(), {}) + rhs = monoid.identity + + _check_pair(lhs=lhs, rhs=rhs, free_vars=[a]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_reduce_reduce(monoid): + a, b = define_vars("a", "b") + A, B = define_vars("A", "B", typ=list[int]) + + @Operation.define + def f(_x: int, _y: int) -> int: + raise NotHandled + + lhs = monoid.reduce(monoid.reduce(f(a(), b()), {a: A()}), {b: B()}) + rhs = monoid.reduce(f(a(), b()), {a: A(), b: B()}) + + _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B, f]) + + +@pytest.mark.parametrize("monoid", COMMUTATIVE) +def test_reduce_plus(monoid): + a, b = define_vars("a", "b") + A, B = define_vars("A", "B", typ=list[int]) + lhs = monoid.reduce(monoid.plus(a(), b()), {a: A(), b: B()}) + rhs = monoid.plus( + monoid.reduce(a(), {a: A(), b: B()}), + monoid.reduce(b(), {a: A(), b: B()}), + ) + _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B]) + + +def test_reduce_independent_1(): + a, b = define_vars("a", "b") + A, B = define_vars("A", "B", typ=list[int]) + lhs = Sum.reduce(Product.plus(a(), b()), {a: A(), b: B()}) + rhs = Product.plus(Sum.reduce(a(), {a: A()}), Sum.reduce(b(), {b: B()})) + _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B]) + + +def test_reduce_independent_2(): + a, b, c = define_vars("a", "b", "c") + A, B, C = define_vars("A", "B", "C", typ=list[int]) + + @Operation.define + def f(_x: int, _y: int) -> int: + raise NotHandled + + lhs = Sum.reduce(Product.plus(a(), b(), f(b(), c())), {a: A(), b: B(), c: C()}) + rhs = Product.plus( + Sum.reduce(a(), {a: A()}), + Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), + ) + _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B, C, f]) + + +def test_reduce_independent_3_negative(): + """Stream `b` depends on `a` (b: g(a())), so the proposed factorization + is unsound — the normalizer must NOT apply it.""" + a, b, c = define_vars("a", "b", "c") + A, C = define_vars("A", "C", typ=list[int]) + + @Operation.define + def f(_x: int, _y: int) -> int: + raise NotHandled + + @Operation.define + def g(_x: int) -> list[int]: + raise NotHandled + + with handler(NormalizeIntp): + lhs = Sum.reduce( + Product.plus(a(), b(), f(b(), c())), {a: A(), b: g(a()), c: C()} + ) + bogus_rhs = Product.plus( + Sum.reduce(a(), {a: A()}), + Sum.reduce(Product.plus(b(), f(b(), c())), {b: g(a()), c: C()}), + ) + assert fvsof(bogus_rhs) != fvsof(lhs) + # Structural-only negative check: the normalizer correctly refused to apply + # the bogus factorization. + assert not syntactic_eq_alpha(lhs, bogus_rhs) + + +def test_reduce_independent_4(): + a, b, c = define_vars("a", "b", "c") + A, B, C = define_vars("A", "B", "C", typ=list[int]) + + @Operation.define + def f(_x: int, _y: int) -> int: + raise NotHandled + + lhs = Sum.reduce(Product.plus(a(), b(), f(b(), c()), 7), {a: A(), b: B(), c: C()}) + rhs = Product.plus( + 7, + Sum.reduce(a(), {a: A()}), + Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), + ) + _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B, C, f]) From 1d38f0dbbf322d4f262f8c95e366f4318f731beb Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Thu, 7 May 2026 15:27:57 -0400 Subject: [PATCH 02/34] wip --- effectful/internals/product_n.py | 2 +- effectful/ops/monoid.py | 174 +++++++++++++++++++++++++++- effectful/ops/semantics.py | 1 + effectful/ops/types.py | 5 +- tests/_monoid_helpers.py | 12 +- tests/test_handlers_llm_provider.py | 2 +- tests/test_ops_monoid.py | 95 +++++++++++++-- tests/test_ops_syntax.py | 2 +- 8 files changed, 269 insertions(+), 24 deletions(-) diff --git a/effectful/internals/product_n.py b/effectful/internals/product_n.py index 4b8bd2a81..87a9c6a42 100644 --- a/effectful/internals/product_n.py +++ b/effectful/internals/product_n.py @@ -69,7 +69,7 @@ def map_structure(func, expr): else: return type(expr)(map_structure(func, tuple(expr.items()))) elif isinstance(expr, collections.abc.Sequence): - if isinstance(expr, str | bytes): + if isinstance(expr, str | bytes | range): return expr elif ( isinstance(expr, tuple) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 58a10ba3d..748eb9cf3 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -10,18 +10,20 @@ from typing import Annotated, Any from effectful.internals.disjoint_set import DisjointSet +from effectful.internals.runtime import interpreter from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler from effectful.ops.syntax import ( ObjectInterpretation, Scoped, _NumberTerm, defdata, + deffn, implements, iter_, syntactic_eq, syntactic_hash, ) -from effectful.ops.types import Interpretation, NotHandled, Operation, Term +from effectful.ops.types import Expr, Interpretation, NotHandled, Operation, Term # Note: The streams value type should be something like Iterable[T], but some of # our target stream types (e.g. jax.Array) are not subtypes of Iterable @@ -80,9 +82,13 @@ def plus[S: Body[T]](self, *args: S) -> S: def _plus[S](self, *args: S) -> S: return typing.cast(S, functools.reduce(self.kernel, args, self.identity)) - @_plus.register(Sequence) + @_plus.register(tuple) def _(self, *args): - return type(args[0])(self.plus(*vs) for vs in zip(*args, strict=True)) + return tuple(self.plus(*vs) for vs in zip(*args, strict=True)) + + @_plus.register(Generator) + def _(self, *args): + return (self.plus(*vs) for vs in zip(*args, strict=True)) @_plus.register(Mapping) def _(self, *args): @@ -161,8 +167,8 @@ def _(self, body: Mapping, streams): return {k: self.reduce(v, streams) for (k, v) in body.items()} @reduce.register # type: ignore[attr-defined] - def _(self, body: Sequence, streams): - return type(body)(self.reduce(x, streams) for x in body) # type:ignore[call-arg] + def _(self, body: tuple, streams): + return tuple(self.reduce(x, streams) for x in body) @reduce.register # type: ignore[attr-defined] def _(self, body: Generator, streams): @@ -252,12 +258,26 @@ def _arg_max[T]( return b if b[0] > a[0] else a # type: ignore +@Operation.define +def product[T]( + a: Iterable[tuple[T, ...] | T], b: Iterable[tuple[T, ...] | T] +) -> Iterable[tuple[T, ...]]: + if isinstance(a, Term) or isinstance(b, Term): + raise NotHandled + + def to_tuple(x): + return x if isinstance(x, tuple) else (x,) + + return [to_tuple(x) + to_tuple(y) for (x, y) in itertools.product(a, b)] + + Min = Semilattice(kernel=min, identity=float("inf")) Max = Semilattice(kernel=max, identity=float("-inf")) ArgMin = Monoid(kernel=_arg_min, identity=(float("inf"), None)) ArgMax = Monoid(kernel=_arg_max, identity=(float("-inf"), None)) Sum = CommutativeMonoid(kernel=_NumberTerm.__add__, identity=0) Product = CommutativeMonoidWithZero(kernel=_NumberTerm.__mul__, identity=1, zero=0) +CartesianProduct = Monoid(kernel=product, identity=[()]) @dataclass @@ -545,11 +565,153 @@ def reduce(self, monoid, body, streams): return fwd() +def outer_stream( + streams: dict[Operation, Expr], +) -> Iterable[tuple[Operation, Expr, dict[Operation, Expr]]]: + """Returns the streams that can be ordered outermost in the loop nest as + well as the remaining streams in the nest. + + """ + stream_vars = set(streams.keys()) + pred = {k: fvsof(v) & stream_vars for k, v in streams.items()} + topo = TopologicalSorter(pred) + topo.prepare() + return ( + (op, streams[op], {k: v for (k, v) in streams.items() if k != op}) + for op in topo.get_ready() + ) + + +def inner_stream( + streams: dict[Operation, Expr], +) -> Iterable[tuple[dict[Operation, Expr], Operation, Expr]]: + """Returns the streams that can be ordered innermost in the loop nest as + well as the remaining streams in the nest. + + """ + stream_vars = set(streams.keys()) + + no_dependents = set() + succ = defaultdict(set) + for k, v in streams.items(): + for pred in fvsof(v) & stream_vars: + succ[pred].add(k) + else: + no_dependents.add(k) + + topo = TopologicalSorter(succ) + topo.prepare() + return ( + ({k: v for (k, v) in streams.items() if k != op}, op, streams[op]) + for op in set(topo.get_ready()) | no_dependents + ) + + +def match_reduce(term: Term) -> tuple | None: + reduce_args = None + + def set_reduce_args(*args, **kwargs): + nonlocal reduce_args + reduce_args = args + + with interpreter({Monoid.reduce: set_reduce_args}): + term.op(*term.args, **term.kwargs) + return reduce_args + + +class ReduceDistributeCartesianProduct(ObjectInterpretation): + """Eliminates a reduce over a cartesian product. + ∑_x₁ ∑_x₂ ... ∑_xₙ ∏_i f(xᵢ) = ∏_i ∑_xᵢ f(xᵢ) + This transform is also called inversion in the lifting + literature (e.g. [1]). + + More specifically, this transform implements the identity + reduce(⨁, reduce(⨂, body2, {vv: v()}), {v: reduce(×, body1, S1)} ∪ S2) + = reduce(⨁, reduce(⨂, reduce(⨁, body2, {vv: v()}), S1), S2) + where × is the cartesian product and ⨂ distributes over ⨁. + + Note: This could be generalized to grouped inversion [2]. + + [1] Braz, Rd, Eyal Amir, and Dan Roth. "Lifted first-order + probabilistic inference." IJCAI. 2005. + [2] Taghipour, Nima, et al. "Completeness results for lifted + variable elimination." AISTATS. 2013. + """ + + @implements(CommutativeMonoid.reduce) + def reduce(self, sum_monoid: Monoid, sum_body, sum_streams): + if not (isinstance(sum_body, Term)): + return fwd() + + # body is a product or multiplication of products + if distributes_over(sum_body.op, sum_monoid.plus): + prod_reduces = sum_body.args + else: + prod_reduces = [sum_body] + + products: list[tuple[Monoid, Callable, Operation, Term]] = [] + for prod_reduce in prod_reduces: + prod_args = match_reduce(prod_reduce) + if prod_args is None: + return fwd() + (prod_monoid, prod_body, prod_streams) = prod_args + if not ( + distributes_over(prod_monoid.plus, sum_monoid.plus) + and (len(products) == 0 or products[-1][0] == prod_monoid) + ): + return fwd() + + if len(prod_streams) > 1 or len(prod_streams) == 0: + return fwd() + (prod_op, prod_stream) = next(iter(prod_streams.items())) + products.append( + (prod_monoid, deffn(prod_body, prod_op), prod_op, prod_stream) + ) + + assert len(products) > 0 + + for outer_sum_streams, cprod_op, cprod_term in inner_stream(sum_streams): + if not ( + isinstance(cprod_term, Term) + and cprod_term.op == CartesianProduct.reduce + ): + continue + (cprod_body, cprod_streams) = cprod_term.args + + if not all( + prod_stream.op == cprod_op for (_, _, _, prod_stream) in products + ): + continue + + prod_op = Operation.define(products[0][2]) + inner_sum = sum_monoid.reduce( + Product.plus( + *(prod_body(prod_op()) for (_, prod_body, _, _) in products) + ), + {prod_op: cprod_body}, + ) + prod = prod_monoid.reduce(inner_sum, cprod_streams) + outer_sum = ( + sum_monoid.reduce(prod, outer_sum_streams) + if outer_sum_streams + else prod + ) + return outer_sum + + return fwd() + + NormalizeReduceIntp = functools.reduce( coproduct, typing.cast( list[Interpretation], - [ReduceNoStreams(), ReduceFusion(), ReduceSplit(), ReduceFactorization()], + [ + ReduceNoStreams(), + ReduceFusion(), + ReduceSplit(), + ReduceFactorization(), + ReduceDistributeCartesianProduct(), + ], ), ) diff --git a/effectful/ops/semantics.py b/effectful/ops/semantics.py index f7678fd24..8fd62bcd5 100644 --- a/effectful/ops/semantics.py +++ b/effectful/ops/semantics.py @@ -209,6 +209,7 @@ def evaluate[T]( @evaluate.register(object) @evaluate.register(str) @evaluate.register(bytes) +@evaluate.register(range) def _evaluate_object[T](expr: T, **kwargs) -> T: if dataclasses.is_dataclass(expr) and not isinstance(expr, type): return typing.cast( diff --git a/effectful/ops/types.py b/effectful/ops/types.py index 40c1f4af5..d24be9745 100644 --- a/effectful/ops/types.py +++ b/effectful/ops/types.py @@ -488,7 +488,10 @@ def _instance_op(instance, *args, **kwargs): else: return default_result - instance_op = self.define(types.MethodType(_instance_op, instance)) + name = ("" if owner is None else f"{owner.__name__}_") + self.__name__ + instance_op = self.define( + types.MethodType(_instance_op, instance), name=name + ) instance.__dict__[self._name_on_instance] = instance_op return instance_op elif instance is not None: diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py index 4532ae72d..f6397053b 100644 --- a/tests/_monoid_helpers.py +++ b/tests/_monoid_helpers.py @@ -21,7 +21,7 @@ def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: ) -_UNARY_INT_FNS: list[Callable[[int], int]] = [ +_UNARY_NUM_FNS: list[Callable[[int], int]] = [ lambda x: x, lambda x: x + 1, lambda x: x - 1, @@ -30,7 +30,7 @@ def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: lambda x: 3 * x + 1, ] -_BINARY_INT_FNS: list[Callable[[int, int], int]] = [ +_BINARY_NUM_FNS: list[Callable[[int, int], int]] = [ lambda x, y: x + y, lambda x, y: x - y, lambda x, y: x * y, @@ -58,10 +58,10 @@ def _strategy_for_op(op: Operation) -> st.SearchStrategy[Callable[..., Any]]: if not params: return _value_strategy_for(ret).map(deffn) - if ret is int and param_types == (int,): - return st.sampled_from(_UNARY_INT_FNS) - if ret is int and param_types == (int, int): - return st.sampled_from(_BINARY_INT_FNS) + if ret in (int, float) and param_types == (int,): + return st.sampled_from(_UNARY_NUM_FNS) + if ret in (int, float) and param_types == (int, int): + return st.sampled_from(_BINARY_NUM_FNS) if get_origin(ret) is list and get_args(ret) == (int,) and param_types == (int,): return st.sampled_from(_UNARY_LIST_FNS) raise NotImplementedError( diff --git a/tests/test_handlers_llm_provider.py b/tests/test_handlers_llm_provider.py index b56fd7bbd..9a2983901 100644 --- a/tests/test_handlers_llm_provider.py +++ b/tests/test_handlers_llm_provider.py @@ -240,7 +240,7 @@ def test_agent_tool_names_are_valid_integration(): agent = _ToolNameAgent() template = agent.ask tools = template.tools - expected_helper_tool_name = f"self__{agent.helper.__name__}" + expected_helper_tool_name = "self__helper" assert tools assert expected_helper_tool_name in tools assert all(re.fullmatch(r"[a-zA-Z0-9_-]+", name) for name in tools) diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index a22928cca..e1c1400ff 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -6,7 +6,15 @@ from hypothesis import strategies as st from effectful.internals.runtime import interpreter -from effectful.ops.monoid import Max, Min, NormalizeIntp, Product, Semilattice, Sum +from effectful.ops.monoid import ( + CartesianProduct, + Max, + Min, + NormalizeIntp, + Product, + Semilattice, + Sum, +) from effectful.ops.semantics import apply, evaluate, fvsof, handler from effectful.ops.syntax import _BaseTerm, defdata, syntactic_eq from effectful.ops.types import NotHandled, Operation @@ -70,14 +78,11 @@ def syntactic_eq_alpha(x, y) -> bool: def _canonicalize(expr): counter = itertools.count() - def _passthrough(op, *args, **kwargs): - return defdata(op, *args, **kwargs) - def _substitute(arg, renaming): """Apply a bound-variable renaming using ``evaluate`` for traversal.""" if not renaming: return arg - with interpreter({apply: _passthrough, **renaming}): + with interpreter({apply: _BaseTerm, **renaming}): return evaluate(arg) def _bound_var_order(args, kwargs, bound_set): @@ -121,7 +126,7 @@ def _apply_canonical(op, *args, **kwargs): *bindings.args, *bindings.kwargs.values() ) if not all_bound: - return defdata(op, *args, **kwargs) + return _BaseTerm(op, *args, **kwargs) order = _bound_var_order(args, kwargs, all_bound) canonical = {var: _canonical_op(next(counter)) for var in order} @@ -496,8 +501,6 @@ def g(_x: int) -> list[int]: Sum.reduce(Product.plus(b(), f(b(), c())), {b: g(a()), c: C()}), ) assert fvsof(bogus_rhs) != fvsof(lhs) - # Structural-only negative check: the normalizer correctly refused to apply - # the bogus factorization. assert not syntactic_eq_alpha(lhs, bogus_rhs) @@ -516,3 +519,79 @@ def f(_x: int, _y: int) -> int: Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), ) _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B, C, f]) + + +def test_reduce_lifted_1(): + a, i = define_vars("a", "i") + A, N, A_domain = define_vars("A", "N", "A_domain", typ=list[int]) + + @Operation.define + def f(_: int) -> float: + raise NotHandled + + term1 = Sum.reduce( + Product.reduce(f(a()), {a: A()}), + {A: CartesianProduct.reduce(A_domain(), {i: N()})}, + ) + term2 = Product.reduce(Sum.reduce(f(a()), {a: A_domain()}), {i: N()}) + _check_pair(lhs=term1, rhs=term2, free_vars=[N, A_domain, f]) + + +def test_reduce_cartesian_1(): + a, i = define_vars("a", "i") + A = define_vars("A", typ=list[int]) + + term1 = Sum.reduce( + Product.reduce(a(), {a: []}), + {A: CartesianProduct.reduce([], {i: []})}, + ) + term2 = Product.reduce(Sum.reduce(a(), {a: []}), {i: []}) + assert term1 == term2 + + +def test_reduce_cartesian_2(): + a, i = define_vars("a", "i") + A = define_vars("A", typ=list[int]) + + term1 = Sum.reduce( + Product.reduce(a(), {a: A()}), + {A: CartesianProduct.reduce([(0,)], {i: [0]})}, + ) + term2 = Product.reduce(Sum.reduce(a(), {a: [0]}), {i: [0]}) + assert term1 == term2 + + +def test_reduce_lifted_2(): + """The worked example on page 396 of 'Lifted Variable Elimination: + Decoupling the Operators from the Constraint Language'. + + """ + a, i, s, t = define_vars("a", "i", "s", "t") + A, N, T = define_vars("A", "N", "T", typ=list[int]) + + @Operation.define + def A_domain(_i: int) -> list[int]: + raise NotHandled + + @Operation.define + def f1(_a: int, _s: int) -> float: + raise NotHandled + + @Operation.define + def f2(_t: int, _a: int) -> float: + raise NotHandled + + term1 = Sum.reduce( + Product.reduce(Product.plus(f1(a(), s()), f2(t(), a())), {a: A()}), + {A: CartesianProduct.reduce(A_domain(i()), {i: N()}), t: T()}, + ) + + term2 = Sum.reduce( + Product.reduce( + Sum.reduce(Product.plus(f1(a(), s()), f2(t(), a())), {a: A_domain(i())}), + {i: N()}, + ), + {t: T()}, + ) + + _check_pair(lhs=term1, rhs=term2, free_vars=[a, i, s, t, A, N, T, A_domain, f1, f2]) diff --git a/tests/test_ops_syntax.py b/tests/test_ops_syntax.py index 185b6132e..af8935eca 100644 --- a/tests/test_ops_syntax.py +++ b/tests/test_ops_syntax.py @@ -489,7 +489,7 @@ def _(self, x: bool) -> bool: ) assert isinstance(term_float, Term) - assert term_float.op.__name__ == "my_singledispatch" + assert term_float.op.__name__ == "MyClass_my_singledispatch" assert term_float.args == (1.5,) assert term_float.kwargs == {} From 92586c43e3276cc3a883764012c0b65b9b7377eb Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Thu, 7 May 2026 15:41:39 -0400 Subject: [PATCH 03/34] cleanup --- effectful/ops/monoid.py | 19 +------------------ tests/_monoid_helpers.py | 2 +- tests/test_ops_monoid.py | 14 +++++++------- tests/test_ops_syntax.py | 1 - 4 files changed, 9 insertions(+), 27 deletions(-) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 748eb9cf3..575838c29 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -4,7 +4,7 @@ import numbers import typing from collections import Counter, defaultdict -from collections.abc import Callable, Generator, Iterable, Iterator, Mapping, Sequence +from collections.abc import Callable, Generator, Iterable, Iterator, Mapping from dataclasses import dataclass from graphlib import TopologicalSorter from typing import Annotated, Any @@ -565,23 +565,6 @@ def reduce(self, monoid, body, streams): return fwd() -def outer_stream( - streams: dict[Operation, Expr], -) -> Iterable[tuple[Operation, Expr, dict[Operation, Expr]]]: - """Returns the streams that can be ordered outermost in the loop nest as - well as the remaining streams in the nest. - - """ - stream_vars = set(streams.keys()) - pred = {k: fvsof(v) & stream_vars for k, v in streams.items()} - topo = TopologicalSorter(pred) - topo.prepare() - return ( - (op, streams[op], {k: v for (k, v) in streams.items() if k != op}) - for op in topo.get_ready() - ) - - def inner_stream( streams: dict[Operation, Expr], ) -> Iterable[tuple[dict[Operation, Expr], Operation, Expr]]: diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py index f6397053b..772f91ecf 100644 --- a/tests/_monoid_helpers.py +++ b/tests/_monoid_helpers.py @@ -14,7 +14,7 @@ def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: if annotation is float: return st.floats(allow_nan=False) if get_origin(annotation) is list and get_args(annotation) == (int,): - return st.lists(st.integers()) + return st.lists(st.integers(), max_size=3) raise NotImplementedError( f"No value strategy for return annotation {annotation!r}; " "supported: int, list[int]" diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index e1c1400ff..8dbd32b6e 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -257,8 +257,8 @@ def test_plus_assoc_left(monoid): def test_plus_sequence(monoid): a, b, c, d = define_vars("a", "b", "c", "d", typ=type(monoid.identity)) _check_pair( - lhs=monoid.plus([a(), b()], [c(), d()]), - rhs=[monoid.plus(a(), c()), monoid.plus(b(), d())], + lhs=monoid.plus((a(), b()), (c(), d())), + rhs=(monoid.plus(a(), c()), monoid.plus(b(), d())), free_vars=[a, b, c, d], ) @@ -373,8 +373,8 @@ def f(_x: int) -> int: g = Operation.define(f, name="g") - lhs = monoid.reduce([f(x()), g(x())], {x: X()}) - rhs = [monoid.reduce(f(x()), {x: X()}), monoid.reduce(g(x()), {x: X()})] + lhs = monoid.reduce((f(x()), g(x())), {x: X()}) + rhs = (monoid.reduce(f(x()), {x: X()}), monoid.reduce(g(x()), {x: X()})) _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, f, g]) @@ -390,11 +390,11 @@ def f(_x: int) -> int: g = Operation.define(f, name="g") - lhs = monoid.reduce([f(x()), g(y())], {x: X(), y: Y()}) - rhs = [ + lhs = monoid.reduce((f(x()), g(y())), {x: X(), y: Y()}) + rhs = ( monoid.reduce(f(x()), {x: X(), y: Y()}), monoid.reduce(g(y()), {x: X(), y: Y()}), - ] + ) _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, Y, f, g]) diff --git a/tests/test_ops_syntax.py b/tests/test_ops_syntax.py index af8935eca..1f5c47763 100644 --- a/tests/test_ops_syntax.py +++ b/tests/test_ops_syntax.py @@ -489,7 +489,6 @@ def _(self, x: bool) -> bool: ) assert isinstance(term_float, Term) - assert term_float.op.__name__ == "MyClass_my_singledispatch" assert term_float.args == (1.5,) assert term_float.kwargs == {} From f7d43e564eecd243a13a43d9afb7735e005515d4 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Thu, 7 May 2026 15:42:34 -0400 Subject: [PATCH 04/34] wip --- effectful/handlers/jax/monoid.py | 79 ++++++++++++++ effectful/ops/monoid.py | 119 ++++++++++++--------- tests/_monoid_helpers.py | 159 +++++++++++++++++++++++++++- tests/test_handlers_jax_monoid.py | 82 +++++++++++++++ tests/test_ops_monoid.py | 165 +++++++++--------------------- 5 files changed, 433 insertions(+), 171 deletions(-) create mode 100644 effectful/handlers/jax/monoid.py create mode 100644 tests/test_handlers_jax_monoid.py diff --git a/effectful/handlers/jax/monoid.py b/effectful/handlers/jax/monoid.py new file mode 100644 index 000000000..4b55674f9 --- /dev/null +++ b/effectful/handlers/jax/monoid.py @@ -0,0 +1,79 @@ +import jax + +import effectful.handlers.jax.numpy as jnp +from effectful.handlers.jax import bind_dims, unbind_dims +from effectful.handlers.jax.scipy.special import logsumexp +from effectful.ops.monoid import ( + CommutativeMonoid, + CommutativeMonoidWithZero, + Monoid, + Semilattice, + Streams, + distributes_over, + outer_stream, +) +from effectful.ops.semantics import evaluate, handler, typeof +from effectful.ops.syntax import deffn +from effectful.ops.types import Operation + + +@Operation.define +def cartesian_prod(x, y): + if x.ndim == 1: + x = x[:, None] + if y.ndim == 1: + y = y[:, None] + x, y = jnp.repeat(x, y.shape[0], axis=0), jnp.tile(y, (x.shape[0], 1)) + return jnp.hstack([x, y]) + + +Sum = CommutativeMonoid(kernel=jnp.add, identity=jnp.asarray(0)) +Product = CommutativeMonoidWithZero( + kernel=jnp.multiply, identity=jnp.asarray(1), zero=jnp.asarray(0) +) +Min = Semilattice(kernel=jnp.minimum, identity=jnp.asarray(float("-inf"))) +Max = Semilattice(kernel=jnp.maximum, identity=jnp.asarray(float("inf"))) +LogSumExp = CommutativeMonoid(kernel=jnp.logaddexp, identity=jnp.asarray(float("-inf"))) +CartesianProd = Monoid(kernel=cartesian_prod, identity=jnp.array([])) + +distributes_over.register(Max.plus, Min.plus) +distributes_over.register(Min.plus, Max.plus) +distributes_over.register(Sum.plus, Min.plus) +distributes_over.register(Sum.plus, Max.plus) +distributes_over.register(Product.plus, Sum.plus) +distributes_over.register(Sum.plus, LogSumExp.plus) + +ARRAY_REDUCE = { + Sum.plus: jnp.sum, + Product.plus: jnp.prod, + Min.plus: jnp.min, + Max.plus: jnp.max, + LogSumExp.plus: logsumexp, +} + + +@Monoid.reduce.register(jax.Array) +def _reduce_array(self, body: jax.Array, streams: Streams): + reductor = ARRAY_REDUCE[self.plus] + index = Operation.define(jax.Array) + + if not streams: + return self.identity + + # find and reduce an array stream + for stream_key, stream_body, streams_tail in outer_stream(streams): + if typeof(stream_body) != jax.Array: + continue + + with handler({stream_key: deffn(unbind_dims(stream_body, index))}): + (eval_body, eval_streams_tail) = evaluate(body), evaluate(streams_tail) + assert isinstance(eval_streams_tail, dict) + + reduce_tail = ( + self.reduce(eval_body, eval_streams_tail) + if len(eval_streams_tail) > 0 + else eval_body + ) + return reductor(bind_dims(reduce_tail, index), axis=0) + + return self._reduce_object(body, streams) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 58a10ba3d..5efebd22a 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -4,24 +4,32 @@ import numbers import typing from collections import Counter, defaultdict -from collections.abc import Callable, Generator, Iterable, Iterator, Mapping, Sequence +from collections.abc import Callable, Generator, Iterable, Mapping, Sequence from dataclasses import dataclass from graphlib import TopologicalSorter from typing import Annotated, Any from effectful.internals.disjoint_set import DisjointSet -from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler +from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler, typeof from effectful.ops.syntax import ( ObjectInterpretation, Scoped, _NumberTerm, defdata, + deffn, implements, iter_, syntactic_eq, syntactic_hash, ) -from effectful.ops.types import Interpretation, NotHandled, Operation, Term +from effectful.ops.types import ( + Expr, + Interpretation, + NotHandled, + Operation, + Term, + _CustomSingleDispatchCallable, +) # Note: The streams value type should be something like Iterable[T], but some of # our target stream types (e.g. jax.Array) are not subtypes of Iterable @@ -36,17 +44,21 @@ ) -def order_streams[T](streams: Streams[T]) -> Iterable[tuple[Operation[[], T], Any]]: - """Determine an order to evaluate the streams based on their dependencies""" +def outer_stream( + streams: Streams, +) -> Iterable[tuple[Operation, Expr, dict[Operation, Expr]]]: + """Returns the streams that can be ordered outermost in the loop nest as + well as the remaining streams in the nest. + + """ stream_vars = set(streams.keys()) - dependencies = {k: fvsof(v) & stream_vars for k, v in streams.items()} - topo = TopologicalSorter(dependencies) + pred = {k: fvsof(v) & stream_vars for k, v in streams.items()} + topo = TopologicalSorter(pred) topo.prepare() - while topo.is_active(): - node_group = topo.get_ready() - for op in sorted(node_group): - yield (op, streams[op]) - topo.done(*node_group) + return ( + (op, streams[op], {k: v for (k, v) in streams.items() if k != op}) + for op in topo.get_ready() + ) class Monoid[T]: @@ -114,58 +126,65 @@ def _(self, *args): return result @Operation.define - @functools.singledispatchmethod + @_CustomSingleDispatchCallable # type: ignore[arg-type] def reduce[A, B, U: Body]( + dispatch, self, + /, body: Annotated[U, Scoped[A | B]], streams: Annotated[Streams, Scoped[A]], ) -> Annotated[U, Scoped[B]]: - if callable(body): - return typing.cast(U, lambda *a, **k: self.reduce(body(*a, **k), streams)) + return dispatch(typeof(body))(self, body, streams) # type: ignore[operator] - def generator(loop_order) -> Iterator[Interpretation]: - if len(loop_order) == 0: - return + @reduce.register(object) # type: ignore[attr-defined] + def _reduce_object(self, body: object, streams: Streams): + if not streams: + return self.identity - stream_key = loop_order[0][0] - stream_values = evaluate(streams[stream_key]) - stream_values_iter = iter(stream_values) # type: ignore[arg-type] + # find and reduce a ground stream + for stream_key, stream_body, streams_tail in outer_stream(streams): + if isinstance(stream_body, Term): + continue - # If we try to iterate and get a term instead of a real - # iterator, give up + stream_values_iter = iter(stream_body) + + # if we iterate and get a term instead of a real iterator, skip if isinstance(stream_values_iter, Term) and stream_values_iter.op is iter_: - raise NotHandled - - if len(loop_order) == 1: - for val in stream_values_iter: - yield {stream_key: functools.partial(lambda v: v, val)} - else: - for val in stream_values_iter: - intp: Interpretation = { - stream_key: functools.partial(lambda v: v, val) - } - with handler(intp): - for intp2 in generator(loop_order[1:]): - yield coproduct(intp, intp2) - - loop_order = list(order_streams(streams)) - try: - return self.plus( - *(handler(intp)(evaluate)(body) for intp in generator(loop_order)) - ) - except NotHandled: - return typing.cast(U, defdata(self.reduce, body, streams)) + continue - @reduce.register # type: ignore[attr-defined] - def _(self, body: Mapping, streams): + new_reduces = [] + for stream_val in stream_values_iter: + with handler({stream_key: deffn(stream_val)}): + eval_args = evaluate((body, streams_tail)) + assert isinstance(eval_args, tuple) + new_reduces.append(self.reduce(*eval_args)) + + return self.plus(*new_reduces) + + return defdata(self.reduce, body, streams) + + @reduce.register(Callable) # type: ignore[attr-defined] + def _reduce_callable(self, body: Callable, streams): + if isinstance(body, Term): + return defdata(self.reduce, body, streams) + return lambda *a, **k: self.reduce(body(*a, **k), streams) + + @reduce.register(Mapping) # type: ignore[attr-defined] + def _reduce_mapping(self, body: Mapping, streams): + if isinstance(body, Term): + return defdata(self.reduce, body, streams) return {k: self.reduce(v, streams) for (k, v) in body.items()} - @reduce.register # type: ignore[attr-defined] - def _(self, body: Sequence, streams): + @reduce.register(list | tuple) # type: ignore[attr-defined] + def _reduce_sequence(self, body: Sequence, streams): + if isinstance(body, Term): + return defdata(self.reduce, body, streams) return type(body)(self.reduce(x, streams) for x in body) # type:ignore[call-arg] - @reduce.register # type: ignore[attr-defined] - def _(self, body: Generator, streams): + @reduce.register(Generator) # type: ignore[attr-defined] + def _reduce_generator(self, body: Generator, streams): + if isinstance(body, Term): + return defdata(self.reduce, body, streams) return (self.reduce(x, streams) for x in body) diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py index 4532ae72d..634aacea6 100644 --- a/tests/_monoid_helpers.py +++ b/tests/_monoid_helpers.py @@ -1,11 +1,41 @@ +import itertools from collections.abc import Callable, Mapping, Sequence from typing import Any, get_args, get_origin +import jax from hypothesis import strategies as st -from effectful.ops.syntax import deffn +import effectful.handlers.jax.numpy as _jnp +from effectful.internals.runtime import interpreter +from effectful.ops.semantics import apply, evaluate +from effectful.ops.syntax import _BaseTerm, defdata, deffn, syntactic_eq from effectful.ops.types import Operation +_JAX_ARRAY_SHAPE = (3,) + + +def _jax_array_value_strategy() -> st.SearchStrategy[jax.Array]: + return st.integers(min_value=0, max_value=2**31 - 1).map( + lambda seed: jax.random.uniform( + jax.random.PRNGKey(seed), _JAX_ARRAY_SHAPE, minval=0.5, maxval=1.5 + ) + ) + + +# Unary jax fns map a scalar to a 1-D array (analogous to ``_UNARY_LIST_FNS`` +# for ints). Uses the effectful-wrapped jnp so named-dim broadcasting works. +_UNARY_JAX_FNS: list[Callable[[jax.Array], jax.Array]] = [ + lambda a: _jnp.stack([a, a + 1.0]), + lambda a: _jnp.stack([a, -a]), + lambda a: _jnp.stack([a, a + 1.0, 2.0 * a]), +] + +_BINARY_JAX_FNS: list[Callable[[jax.Array, jax.Array], jax.Array]] = [ + lambda a, b: a + b, + lambda a, b: a - b, + lambda a, b: a * b, +] + def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: """Strategy for the value an *0-arg* Operation should return.""" @@ -15,9 +45,11 @@ def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: return st.floats(allow_nan=False) if get_origin(annotation) is list and get_args(annotation) == (int,): return st.lists(st.integers()) + if annotation is jax.Array: + return _jax_array_value_strategy() raise NotImplementedError( f"No value strategy for return annotation {annotation!r}; " - "supported: int, list[int]" + "supported: int, list[int], jax.Array" ) @@ -64,8 +96,12 @@ def _strategy_for_op(op: Operation) -> st.SearchStrategy[Callable[..., Any]]: return st.sampled_from(_BINARY_INT_FNS) if get_origin(ret) is list and get_args(ret) == (int,) and param_types == (int,): return st.sampled_from(_UNARY_LIST_FNS) + if ret is jax.Array and param_types == (jax.Array,): + return st.sampled_from(_UNARY_JAX_FNS) + if ret is jax.Array and param_types == (jax.Array, jax.Array): + return st.sampled_from(_BINARY_JAX_FNS) raise NotImplementedError( - f"Function-typed free var must return int or list[int]; got {ret!r} for {op}" + f"No callable strategy for free var with return {ret!r}, params {param_types!r}" ) @@ -82,4 +118,119 @@ def random_interpretation( return intp -__all__ = ["random_interpretation"] +def define_vars(*names, typ=int): + if len(names) == 1: + return Operation.define(typ, name=names[0]) + return tuple(Operation.define(typ, name=n) for n in names) + + +def syntactic_eq_alpha(x, y) -> bool: + """Alpha-equivalence-respecting variant of ``syntactic_eq``. + + Walks each expression bottom-up with :func:`evaluate` and renames + every bound variable to a deterministic canonical Operation. The + canonical names are assigned by a counter that increments in + ``evaluate``'s natural traversal order, so two alpha-equivalent + expressions canonicalize to syntactically identical results. + """ + + _op_cache: dict[int, Operation] = {} + + def _canonical_op(idx: int, op: Operation) -> Operation: + """Cached canonical Operation, keyed by encounter index. + + Cached so that two independent canonicalize runs return the same + Operation object for the same index — letting ``syntactic_eq`` + compare canonical forms by Operation identity. + """ + if idx in _op_cache: + return _op_cache[idx] + + op = Operation.define(op, name=f"__cv_{idx}") + _op_cache[idx] = op + return op + + cx = _canonicalize(x, _canonical_op) + cy = _canonicalize(y, _canonical_op) + return syntactic_eq(cx, cy) + + +def _canonicalize(expr, _canonical_op): + counter = itertools.count() + + def _substitute(arg, renaming): + """Apply a bound-variable renaming using ``evaluate`` for traversal.""" + if not renaming: + return arg + with interpreter({apply: _BaseTerm, **renaming}): + return evaluate(arg) + + def _bound_var_order(args, kwargs, bound_set): + """Return bound variables in deterministic encounter order.""" + seen: list[Operation] = [] + seen_set: set[Operation] = set() + + def _capture(op, *a, **kw): + if op in bound_set and op not in seen_set: + seen.append(op) + seen_set.add(op) + return defdata(op, *a, **kw) + + # ``evaluate`` walks Terms, lists, tuples, mappings, dataclasses, + # etc. for free; the apply handler captures bound vars used as + # ``x()`` anywhere in the body. + with interpreter({apply: _capture}): + evaluate((args, kwargs)) + + # Binders bypass the apply handler. Pick them up with a small structural + # walk that visits dict keys too. + def _walk_bare(obj): + if isinstance(obj, Operation): + if obj in bound_set and obj not in seen_set: + seen.append(obj) + seen_set.add(obj) + elif isinstance(obj, dict): + for k, v in obj.items(): + _walk_bare(k) + _walk_bare(v) + elif isinstance(obj, list | set | frozenset | tuple): + for v in obj: + _walk_bare(v) + + _walk_bare((args, kwargs)) + return seen + + def _apply_canonical(op, *args, **kwargs): + bindings = op.__fvs_rule__(*args, **kwargs) + all_bound: set[Operation] = set().union( + *bindings.args, *bindings.kwargs.values() + ) + if not all_bound: + return _BaseTerm(op, *args, **kwargs) + + order = _bound_var_order(args, kwargs, all_bound) + canonical = {var: _canonical_op(next(counter), var) for var in order} + assert all_bound <= set(order) + + new_args = tuple( + _substitute( + arg, {v: canonical[v] for v in bindings.args[i] if v in canonical} + ) + for i, arg in enumerate(args) + ) + new_kwargs = { + k: _substitute( + v, + {var: canonical[var] for var in bindings.kwargs[k] if var in canonical}, + ) + for k, v in kwargs.items() + } + + # avoid the renaming from defdata + return _BaseTerm(op, *new_args, **new_kwargs) + + with interpreter({apply: _apply_canonical}): + return evaluate(expr) + + +__all__ = ["random_interpretation", "define_vars", "syntactic_eq_alpha"] diff --git a/tests/test_handlers_jax_monoid.py b/tests/test_handlers_jax_monoid.py new file mode 100644 index 000000000..4efd0eb21 --- /dev/null +++ b/tests/test_handlers_jax_monoid.py @@ -0,0 +1,82 @@ +import jax +import pytest + +import effectful.handlers.jax.numpy as jnp +from effectful.handlers.jax import bind_dims, unbind_dims +from effectful.handlers.jax.monoid import LogSumExp, Max, Min, Product, Sum +from effectful.handlers.jax.scipy.special import logsumexp +from effectful.ops.types import NotHandled, Operation +from tests._monoid_helpers import define_vars, syntactic_eq_alpha + +MONOIDS = [ + pytest.param(Sum, jnp.sum, id="Sum"), + pytest.param(Product, jnp.prod, id="Product"), + pytest.param(Min, jnp.min, id="Min"), + pytest.param(Max, jnp.max, id="Max"), + pytest.param(LogSumExp, logsumexp, id="LogSumExp"), +] + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_array_1(monoid, reductor): + (x, X, k) = define_vars("x", "X", "k", typ=jax.Array) + + lhs = monoid.reduce(x(), {x: X()}) + rhs = reductor(bind_dims(unbind_dims(X(), k), k), axis=0) + + assert syntactic_eq_alpha(lhs, rhs) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_array_2(monoid, reductor): + (x, y, X, Y, k1, k2) = define_vars("x", "y", "X", "Y", "k1", "k2", typ=jax.Array) + + @Operation.define + def f(_a: jax.Array, _b: jax.Array) -> jax.Array: + raise NotHandled + + lhs = monoid.reduce(f(x(), y()), {x: X(), y: Y()}) + rhs = reductor( + bind_dims( + reductor( + bind_dims(f(unbind_dims(X(), k1), unbind_dims(Y(), k2)), k2), + axis=0, + ), + k1, + ), + axis=0, + ) + + assert syntactic_eq_alpha(lhs, rhs) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_array_3(monoid, reductor): + """Stream `y` is `g(x())` — depends on the bound element of X. The reducer + must inline ``g`` along the same named dim used to unbind `x`.""" + (x, y, X, k1, k2) = define_vars("x", "y", "X", "k1", "k2", typ=jax.Array) + + @Operation.define + def f(_a: jax.Array, _b: jax.Array) -> jax.Array: + raise NotHandled + + @Operation.define + def g(_a: jax.Array) -> jax.Array: + raise NotHandled + + lhs = monoid.reduce(f(x(), y()), {x: X(), y: g(x())}) + rhs = reductor( + bind_dims( + reductor( + bind_dims( + f(unbind_dims(X(), k1), unbind_dims(g(unbind_dims(X(), k1)), k2)), + k2, + ), + axis=0, + ), + k1, + ), + axis=0, + ) + + assert syntactic_eq_alpha(lhs, rhs) diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index a22928cca..d073827c9 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -1,16 +1,11 @@ -import functools -import itertools - import pytest from hypothesis import given, settings from hypothesis import strategies as st -from effectful.internals.runtime import interpreter from effectful.ops.monoid import Max, Min, NormalizeIntp, Product, Semilattice, Sum -from effectful.ops.semantics import apply, evaluate, fvsof, handler -from effectful.ops.syntax import _BaseTerm, defdata, syntactic_eq +from effectful.ops.semantics import evaluate, fvsof, handler from effectful.ops.types import NotHandled, Operation -from tests._monoid_helpers import random_interpretation +from tests._monoid_helpers import define_vars, random_interpretation, syntactic_eq_alpha _INT = st.integers(min_value=-100, max_value=100) @@ -38,116 +33,6 @@ ] -def define_vars(*names, typ=int): - if len(names) == 1: - return Operation.define(typ, name=names[0]) - return tuple(Operation.define(typ, name=n) for n in names) - - -@functools.cache -def _canonical_op(idx: int) -> Operation: - """Globally cached canonical Operation, keyed by encounter index. - - Cached so that two independent canonicalize runs return the same - Operation object for the same index — letting ``syntactic_eq`` - compare canonical forms by Operation identity. - """ - return Operation.define(int, name=f"__cv_{idx}") - - -def syntactic_eq_alpha(x, y) -> bool: - """Alpha-equivalence-respecting variant of ``syntactic_eq``. - - Walks each expression bottom-up with :func:`evaluate` and renames - every bound variable to a deterministic canonical Operation. The - canonical names are assigned by a counter that increments in - ``evaluate``'s natural traversal order, so two alpha-equivalent - expressions canonicalize to syntactically identical results. - """ - return syntactic_eq(_canonicalize(x), _canonicalize(y)) - - -def _canonicalize(expr): - counter = itertools.count() - - def _passthrough(op, *args, **kwargs): - return defdata(op, *args, **kwargs) - - def _substitute(arg, renaming): - """Apply a bound-variable renaming using ``evaluate`` for traversal.""" - if not renaming: - return arg - with interpreter({apply: _passthrough, **renaming}): - return evaluate(arg) - - def _bound_var_order(args, kwargs, bound_set): - """Return bound variables in deterministic encounter order.""" - seen: list[Operation] = [] - seen_set: set[Operation] = set() - - def _capture(op, *a, **kw): - if op in bound_set and op not in seen_set: - seen.append(op) - seen_set.add(op) - return defdata(op, *a, **kw) - - # ``evaluate`` walks Terms, lists, tuples, mappings, dataclasses, - # etc. for free; the apply handler captures bound vars used as - # ``x()`` anywhere in the body. - with interpreter({apply: _capture}): - evaluate((args, kwargs)) - - # Binders bypass the apply handler. Pick them up with a small structural - # walk that visits dict keys too. - def _walk_bare(obj): - if isinstance(obj, Operation): - if obj in bound_set and obj not in seen_set: - seen.append(obj) - seen_set.add(obj) - elif isinstance(obj, dict): - for k, v in obj.items(): - _walk_bare(k) - _walk_bare(v) - elif isinstance(obj, list | set | frozenset | tuple): - for v in obj: - _walk_bare(v) - - _walk_bare((args, kwargs)) - return seen - - def _apply_canonical(op, *args, **kwargs): - bindings = op.__fvs_rule__(*args, **kwargs) - all_bound: set[Operation] = set().union( - *bindings.args, *bindings.kwargs.values() - ) - if not all_bound: - return defdata(op, *args, **kwargs) - - order = _bound_var_order(args, kwargs, all_bound) - canonical = {var: _canonical_op(next(counter)) for var in order} - assert all_bound <= set(order) - - new_args = tuple( - _substitute( - arg, {v: canonical[v] for v in bindings.args[i] if v in canonical} - ) - for i, arg in enumerate(args) - ) - new_kwargs = { - k: _substitute( - v, - {var: canonical[var] for var in bindings.kwargs[k] if var in canonical}, - ) - for k, v in kwargs.items() - } - - # avoid the renaming from defdata - return _BaseTerm(op, *new_args, **new_kwargs) - - with interpreter({apply: _apply_canonical}): - return evaluate(expr) - - @pytest.mark.parametrize("monoid", ALL_MONOIDS) @given(a=_INT, b=_INT, c=_INT) @settings(max_examples=50, deadline=None) @@ -357,6 +242,52 @@ def test_plus_zero(monoid): _check_pair(lhs=lhs_left, rhs=monoid.zero, free_vars=[a]) +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_partial_1(monoid): + x, y = define_vars("x", "y") + + lhs = monoid.reduce(x(), {x: []}) + rhs = monoid.identity + + _check_pair(lhs=lhs, rhs=rhs, free_vars=[x, y]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_partial_2(monoid): + x, y = define_vars("x", "y") + Y = define_vars("Y", typ=list[int]) + + lhs = monoid.reduce(x(), {y: Y(), x: []}) + rhs = monoid.identity + + _check_pair(lhs=lhs, rhs=rhs, free_vars=[x, y, Y]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_partial_3(monoid): + x, y, a, b = define_vars("x", "y", "a", "b") + Y = define_vars("Y", typ=list[int]) + + lhs = monoid.reduce(x(), {y: Y(), x: [a(), b()]}) + rhs = monoid.plus(monoid.reduce(a(), {y: Y()}), monoid.reduce(b(), {y: Y()})) + + _check_pair(lhs=lhs, rhs=rhs, free_vars=[x, y, a, b, Y]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_partial_4(monoid): + x, y, a, b = define_vars("x", "y", "a", "b") + + @Operation.define + def f(_x: int) -> list[int]: + raise NotHandled + + lhs = monoid.reduce(x(), {y: f(x()), x: [a(), b()]}) + rhs = monoid.plus(monoid.reduce(a(), {y: f(a())}), monoid.reduce(b(), {y: f(b())})) + + _check_pair(lhs=lhs, rhs=rhs, free_vars=[x, y, a, b, f]) + + @pytest.mark.parametrize("monoid", ALL_MONOIDS) def test_reduce_body_sequence(monoid): x = Operation.define(int, name="x") From 8479ed5ec57e915937593f33b0d7699d577b6ead Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Tue, 12 May 2026 09:37:11 -0400 Subject: [PATCH 05/34] fix rule --- effectful/ops/monoid.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 575838c29..e15c5c306 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -610,8 +610,8 @@ class ReduceDistributeCartesianProduct(ObjectInterpretation): More specifically, this transform implements the identity reduce(⨁, reduce(⨂, body2, {vv: v()}), {v: reduce(×, body1, S1)} ∪ S2) - = reduce(⨁, reduce(⨂, reduce(⨁, body2, {vv: v()}), S1), S2) - where × is the cartesian product and ⨂ distributes over ⨁. + = reduce(⨁, reduce(⨂, reduce(⨁, body2, {vv: body1}), S1), S2) + where × is the cartesian product and ⨂ distributes over ⨁. Note: This could be generalized to grouped inversion [2]. From bd1025ea027f4b94c775d7762bac8da3047e368c Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Tue, 12 May 2026 10:06:09 -0400 Subject: [PATCH 06/34] wip --- tests/_monoid_helpers.py | 2 +- tests/test_ops_monoid.py | 53 ++++++++++++++++++++++++++++++++-------- 2 files changed, 44 insertions(+), 11 deletions(-) diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py index 772f91ecf..9b311b257 100644 --- a/tests/_monoid_helpers.py +++ b/tests/_monoid_helpers.py @@ -14,7 +14,7 @@ def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: if annotation is float: return st.floats(allow_nan=False) if get_origin(annotation) is list and get_args(annotation) == (int,): - return st.lists(st.integers(), max_size=3) + return st.lists(st.integers(), max_size=2) raise NotImplementedError( f"No value strategy for return annotation {annotation!r}; " "supported: int, list[int]" diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index 8dbd32b6e..506dbe980 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -14,6 +14,7 @@ Product, Semilattice, Sum, + distributes_over, ) from effectful.ops.semantics import apply, evaluate, fvsof, handler from effectful.ops.syntax import _BaseTerm, defdata, syntactic_eq @@ -45,6 +46,16 @@ pytest.param(Product, id="Product"), ] +# Pairs (outer, inner) such that inner distributes over outer — i.e. the lifting +# identity ``outer(inner(body, A), CartesianProduct...) == inner(outer(body, D), ...)`` +# is valid for that semiring pair. +MONOID_PAIRS = [ + pytest.param(o.values[0], i.values[0], id=f"{o.id}-{i.id}") + for o in ALL_MONOIDS + for i in ALL_MONOIDS + if distributes_over(i.values[0].plus, o.values[0].plus) +] + def define_vars(*names, typ=int): if len(names) == 1: @@ -521,7 +532,8 @@ def f(_x: int, _y: int) -> int: _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B, C, f]) -def test_reduce_lifted_1(): +@pytest.mark.parametrize("outer,inner", MONOID_PAIRS) +def test_reduce_lifted_1(outer, inner): a, i = define_vars("a", "i") A, N, A_domain = define_vars("A", "N", "A_domain", typ=list[int]) @@ -529,11 +541,11 @@ def test_reduce_lifted_1(): def f(_: int) -> float: raise NotHandled - term1 = Sum.reduce( - Product.reduce(f(a()), {a: A()}), + term1 = outer.reduce( + inner.reduce(f(a()), {a: A()}), {A: CartesianProduct.reduce(A_domain(), {i: N()})}, ) - term2 = Product.reduce(Sum.reduce(f(a()), {a: A_domain()}), {i: N()}) + term2 = inner.reduce(outer.reduce(f(a()), {a: A_domain()}), {i: N()}) _check_pair(lhs=term1, rhs=term2, free_vars=[N, A_domain, f]) @@ -561,7 +573,28 @@ def test_reduce_cartesian_2(): assert term1 == term2 -def test_reduce_lifted_2(): +@pytest.mark.parametrize("outer,inner", MONOID_PAIRS) +def test_reduce_lifted_multi_index(outer, inner): + a, i, j = define_vars("a", "i", "j") + A, N, M, A_domain = define_vars("A", "N", "M", "A_domain", typ=list[int]) + + @Operation.define + def f(_: int) -> float: + raise NotHandled + + term1 = outer.reduce( + inner.reduce(f(a()), {a: A()}), + {A: CartesianProduct.reduce(A_domain(), {i: N(), j: M()})}, + ) + term2 = inner.reduce( + outer.reduce(f(a()), {a: A_domain()}), + {i: N(), j: M()}, + ) + _check_pair(lhs=term1, rhs=term2, free_vars=[N, M, A_domain, f]) + + +@pytest.mark.parametrize("outer,inner", MONOID_PAIRS) +def test_reduce_lifted_2(outer, inner): """The worked example on page 396 of 'Lifted Variable Elimination: Decoupling the Operators from the Constraint Language'. @@ -581,14 +614,14 @@ def f1(_a: int, _s: int) -> float: def f2(_t: int, _a: int) -> float: raise NotHandled - term1 = Sum.reduce( - Product.reduce(Product.plus(f1(a(), s()), f2(t(), a())), {a: A()}), + term1 = outer.reduce( + inner.reduce(inner.plus(f1(a(), s()), f2(t(), a())), {a: A()}), {A: CartesianProduct.reduce(A_domain(i()), {i: N()}), t: T()}, ) - term2 = Sum.reduce( - Product.reduce( - Sum.reduce(Product.plus(f1(a(), s()), f2(t(), a())), {a: A_domain(i())}), + term2 = outer.reduce( + inner.reduce( + outer.reduce(inner.plus(f1(a(), s()), f2(t(), a())), {a: A_domain(i())}), {i: N()}, ), {t: T()}, From 39d8bb026de43a66556faeeec2b3de38a30a35d9 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Tue, 12 May 2026 10:54:16 -0400 Subject: [PATCH 07/34] fix bug --- effectful/ops/monoid.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index e15c5c306..1a0cb293a 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -667,8 +667,9 @@ def reduce(self, sum_monoid: Monoid, sum_body, sum_streams): continue prod_op = Operation.define(products[0][2]) + prod_monoid = products[0][0] inner_sum = sum_monoid.reduce( - Product.plus( + prod_monoid.plus( *(prod_body(prod_op()) for (_, prod_body, _, _) in products) ), {prod_op: cprod_body}, From 5283db1b8b73f9eae7cc6c1254e924465790abb8 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Tue, 12 May 2026 11:01:52 -0400 Subject: [PATCH 08/34] cleanup --- effectful/ops/monoid.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 1a0cb293a..ad83de47b 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -577,8 +577,10 @@ def inner_stream( no_dependents = set() succ = defaultdict(set) for k, v in streams.items(): - for pred in fvsof(v) & stream_vars: - succ[pred].add(k) + preds = fvsof(v) & stream_vars + if preds: + for pred in preds: + succ[pred].add(k) else: no_dependents.add(k) From 9fd88af4fb91c82fcb08757c9e193922c600e917 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Tue, 12 May 2026 11:03:02 -0400 Subject: [PATCH 09/34] lin --- tests/test_ops_monoid.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index 506dbe980..e73a9a7b2 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -1,5 +1,6 @@ import functools import itertools +import typing import pytest from hypothesis import given, settings @@ -10,6 +11,7 @@ CartesianProduct, Max, Min, + Monoid, NormalizeIntp, Product, Semilattice, @@ -53,7 +55,9 @@ pytest.param(o.values[0], i.values[0], id=f"{o.id}-{i.id}") for o in ALL_MONOIDS for i in ALL_MONOIDS - if distributes_over(i.values[0].plus, o.values[0].plus) + if distributes_over( + typing.cast(Monoid, i.values[0]).plus, typing.cast(Monoid, o.values[0]).plus + ) ] From 43ca58dd3b5a22d268286509a30cf44f86dcce9f Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Tue, 12 May 2026 13:18:56 -0400 Subject: [PATCH 10/34] wip --- effectful/ops/monoid.py | 303 ++++++++++++++++++--------------------- effectful/ops/syntax.py | 1 + effectful/ops/types.py | 57 ++++++++ tests/test_ops_monoid.py | 4 +- 4 files changed, 201 insertions(+), 164 deletions(-) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index ad83de47b..7571d6b65 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -10,7 +10,6 @@ from typing import Annotated, Any from effectful.internals.disjoint_set import DisjointSet -from effectful.internals.runtime import interpreter from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler from effectful.ops.syntax import ( ObjectInterpretation, @@ -23,7 +22,14 @@ syntactic_eq, syntactic_hash, ) -from effectful.ops.types import Expr, Interpretation, NotHandled, Operation, Term +from effectful.ops.types import ( + Expr, + Interpretation, + NotHandled, + Operation, + Term, + _CustomSingleDispatchMethod, +) # Note: The streams value type should be something like Iterable[T], but some of # our target stream types (e.g. jax.Array) are not subtypes of Iterable @@ -51,6 +57,56 @@ def order_streams[T](streams: Streams[T]) -> Iterable[tuple[Operation[[], T], An topo.done(*node_group) +@Operation.define +def reduce[A, B, U: Body]( + monoid: "Monoid", + body: Annotated[U, Scoped[A | B]], + streams: Annotated[Streams, Scoped[A]], +) -> Annotated[U, Scoped[B]]: + if callable(body): + return typing.cast(U, lambda *a, **k: monoid.reduce(body(*a, **k), streams)) + + def generator(loop_order) -> Iterator[Interpretation]: + if len(loop_order) == 0: + return + + stream_key = loop_order[0][0] + stream_values = evaluate(streams[stream_key]) + stream_values_iter = iter(stream_values) # type: ignore[arg-type] + + # If we try to iterate and get a term instead of a real + # iterator, give up + if isinstance(stream_values_iter, Term) and stream_values_iter.op is iter_: + raise NotHandled + + if len(loop_order) == 1: + for val in stream_values_iter: + yield {stream_key: functools.partial(lambda v: v, val)} + else: + for val in stream_values_iter: + intp: Interpretation = {stream_key: functools.partial(lambda v: v, val)} + with handler(intp): + for intp2 in generator(loop_order[1:]): + yield coproduct(intp, intp2) + + loop_order = list(order_streams(streams)) + return monoid.plus( + *(handler(intp)(evaluate)(body) for intp in generator(loop_order)) + ) + + +@Operation.define +def plus[S: Body](monoid: "Monoid", *args: S) -> S: + """Monoid addition with broadcasting over common collection types, + callables, and interpretations. + + """ + if any(isinstance(x, Term) for x in args): + raise NotHandled + + return typing.cast(S, functools.reduce(monoid.kernel, args, monoid.identity)) + + class Monoid[T]: kernel: Operation[[T, T], T] identity: T @@ -64,33 +120,31 @@ def __init__(self, kernel: Callable[[T, T], T], identity: T): def __repr__(self): return f"{type(self)}({self.kernel}, {self.identity})" - @Operation.define - def plus[S: Body[T]](self, *args: S) -> S: - """Monoid addition with broadcasting over common collection types, - callables, and interpretations. + def __eq__(self, other): + return id(self) == id(other) - """ - if not args: - return typing.cast(S, self.identity) + def __hash__(self): + return hash(id(self)) - if any(isinstance(x, Term) for x in args): - return typing.cast(S, defdata(self.plus, *args)) - - return self._plus(*args) + @_CustomSingleDispatchMethod + def plus[S](self, dispatch, *args: S) -> S: + if not args: + return self.identity + return dispatch(type(args[0]))(self, *args) - @functools.singledispatchmethod - def _plus[S](self, *args: S) -> S: - return typing.cast(S, functools.reduce(self.kernel, args, self.identity)) + @plus.register(object) + def _(self, *args): + return plus(self, *args) - @_plus.register(tuple) + @plus.register(tuple) def _(self, *args): return tuple(self.plus(*vs) for vs in zip(*args, strict=True)) - @_plus.register(Generator) + @plus.register(Generator) def _(self, *args): return (self.plus(*vs) for vs in zip(*args, strict=True)) - @_plus.register(Mapping) + @plus.register(Mapping) def _(self, *args): if isinstance(args[0], Interpretation): keys = args[0].keys() @@ -119,7 +173,6 @@ def _(self, *args): result = {k: self.plus(*vs) for (k, vs) in all_values.items()} return result - @Operation.define @functools.singledispatchmethod def reduce[A, B, U: Body]( self, @@ -129,38 +182,7 @@ def reduce[A, B, U: Body]( if callable(body): return typing.cast(U, lambda *a, **k: self.reduce(body(*a, **k), streams)) - def generator(loop_order) -> Iterator[Interpretation]: - if len(loop_order) == 0: - return - - stream_key = loop_order[0][0] - stream_values = evaluate(streams[stream_key]) - stream_values_iter = iter(stream_values) # type: ignore[arg-type] - - # If we try to iterate and get a term instead of a real - # iterator, give up - if isinstance(stream_values_iter, Term) and stream_values_iter.op is iter_: - raise NotHandled - - if len(loop_order) == 1: - for val in stream_values_iter: - yield {stream_key: functools.partial(lambda v: v, val)} - else: - for val in stream_values_iter: - intp: Interpretation = { - stream_key: functools.partial(lambda v: v, val) - } - with handler(intp): - for intp2 in generator(loop_order[1:]): - yield coproduct(intp, intp2) - - loop_order = list(order_streams(streams)) - try: - return self.plus( - *(handler(intp)(evaluate)(body) for intp in generator(loop_order)) - ) - except NotHandled: - return typing.cast(U, defdata(self.reduce, body, streams)) + return reduce(self, body, streams) @reduce.register # type: ignore[attr-defined] def _(self, body: Mapping, streams): @@ -175,35 +197,7 @@ def _(self, body: Generator, streams): return (self.reduce(x, streams) for x in body) -class IdempotentMonoid[T](Monoid[T]): - @Operation.define - def plus[S: Body[T]](self, *args: S) -> S: - return super().plus(*args) - - @Operation.define - def reduce[A, B, U: Body]( - self, - body: Annotated[U, Scoped[A | B]], - streams: Annotated[Streams, Scoped[A]], - ) -> Annotated[U, Scoped[B]]: - return super().reduce(body, streams) - - -class CommutativeMonoid[T](Monoid[T]): - @Operation.define - def plus[S: Body[T]](self, *args: S) -> S: - return super().plus(*args) - - @Operation.define - def reduce[A, B, U: Body]( - self, - body: Annotated[U, Scoped[A | B]], - streams: Annotated[Streams, Scoped[A]], - ) -> Annotated[U, Scoped[B]]: - return super().reduce(body, streams) - - -class CommutativeMonoidWithZero[T](CommutativeMonoid[T]): +class MonoidWithZero[T](Monoid[T]): zero: T def __init__(self, kernel: Callable[[T, T], T], identity: T, zero: T): @@ -213,32 +207,6 @@ def __init__(self, kernel: Callable[[T, T], T], identity: T, zero: T): def __repr__(self): return f"{type(self)}({self.kernel}, {self.identity}, {self.zero})" - @Operation.define - def plus[S: Body[T]](self, *args: S) -> S: - return super().plus(*args) - - @Operation.define - def reduce[A, B, U: Body]( - self, - body: Annotated[U, Scoped[A | B]], - streams: Annotated[Streams, Scoped[A]], - ) -> Annotated[U, Scoped[B]]: - return super().reduce(body, streams) - - -class Semilattice[T](IdempotentMonoid[T], CommutativeMonoid[T]): - @Operation.define - def plus[S: Body[T]](self, *args: S) -> S: - return super().plus(*args) - - @Operation.define - def reduce[A, B, U: Body]( - self, - body: Annotated[U, Scoped[A | B]], - streams: Annotated[Streams, Scoped[A]], - ) -> Annotated[U, Scoped[B]]: - return super().reduce(body, streams) - @Operation.define def _arg_min[T]( @@ -271,15 +239,30 @@ def to_tuple(x): return [to_tuple(x) + to_tuple(y) for (x, y) in itertools.product(a, b)] -Min = Semilattice(kernel=min, identity=float("inf")) -Max = Semilattice(kernel=max, identity=float("-inf")) +Min = Monoid(kernel=min, identity=float("inf")) +Max = Monoid(kernel=max, identity=float("-inf")) ArgMin = Monoid(kernel=_arg_min, identity=(float("inf"), None)) ArgMax = Monoid(kernel=_arg_max, identity=(float("-inf"), None)) -Sum = CommutativeMonoid(kernel=_NumberTerm.__add__, identity=0) -Product = CommutativeMonoidWithZero(kernel=_NumberTerm.__mul__, identity=1, zero=0) +Sum = Monoid(kernel=_NumberTerm.__add__, identity=0) +Product = MonoidWithZero(kernel=_NumberTerm.__mul__, identity=1, zero=0) CartesianProduct = Monoid(kernel=product, identity=[()]) +@dataclass +class _ExtensiblePredicate[T]: + elems: set[T] + + def register(self, t: T) -> None: + self.elems.add(t) + + def __call__(self, t: T) -> bool: + return t in self.elems + + +is_commutative = _ExtensiblePredicate({Max, Min, Sum, Product}) +is_idempotent = _ExtensiblePredicate({Max, Min}) + + @dataclass class _ExtensibleBinaryRelation[S, T]: tuples: set[tuple[S, T]] @@ -292,20 +275,14 @@ def __call__(self, s: S, t: T) -> bool: distributes_over = _ExtensibleBinaryRelation( - { - (Max.plus, Min.plus), - (Min.plus, Max.plus), - (Sum.plus, Min.plus), - (Sum.plus, Max.plus), - (Product.plus, Sum.plus), - } + {(Max, Min), (Min, Max), (Sum, Min), (Sum, Max), (Product, Sum)} ) class PlusEmpty(ObjectInterpretation): """plus() = 0""" - @implements(Monoid.plus) + @implements(plus) def plus(self, monoid, *args): if not args: return monoid.identity @@ -315,7 +292,7 @@ def plus(self, monoid, *args): class PlusSingle(ObjectInterpretation): """plus(x) = x""" - @implements(Monoid.plus) + @implements(plus) def plus(self, _, *args): if len(args) == 1: return args[0] @@ -325,7 +302,7 @@ def plus(self, _, *args): class PlusIdentity(ObjectInterpretation): """x₁ + ... + 0 + ... + xₙ = x₁ + ... + xₙ""" - @implements(Monoid.plus) + @implements(plus) def plus(self, monoid, *args): if any(x is monoid.identity for x in args): return monoid.plus(*(x for x in args if x is not monoid.identity)) @@ -335,12 +312,14 @@ def plus(self, monoid, *args): class PlusAssoc(ObjectInterpretation): """x + (y + z) = (x + y) + z = x + y + z""" - @implements(Monoid.plus) + @implements(plus) def plus(self, monoid, *args): - if any(isinstance(x, Term) and x.op is monoid.plus for x in args): + def is_nested_plus(x): + return isinstance(x, Term) and x.op == plus and x.args[0] is monoid + + if any(is_nested_plus(x) for x in args): flat_args = itertools.chain.from_iterable( - t.args if isinstance(t, Term) and t.op is monoid.plus else (t,) - for t in args + t.args[1:] if is_nested_plus(t) else (t,) for t in args ) assert len(args) > 0 return monoid.plus(*flat_args) @@ -350,36 +329,37 @@ def plus(self, monoid, *args): class PlusDistr(ObjectInterpretation): """x + (y * z) = x * y + x * z""" - @implements(Monoid.plus) + @implements(plus) def plus(self, monoid, *args): if any( - isinstance(x, Term) and distributes_over(monoid.plus, x.op) for x in args + isinstance(x, Term) and x.op == plus and distributes_over(monoid, x.args[0]) + for x in args ): non_terms = [] - # group terms by head operation - by_head_op = defaultdict(list) + # group terms by monoid + by_monoid = defaultdict(list) for t in args: - if isinstance(t, Term): - by_head_op[t.op].append(t) + if isinstance(t, Term) and t.op == plus: + by_monoid[t.args[0]].append(t) else: non_terms.append(t) # distribute over each group progress = False final_sum = [] - for op, terms in by_head_op.items(): + for m, terms in by_monoid.items(): if ( len(terms) > 1 - and distributes_over(monoid.plus, op) - and not distributes_over(op, monoid.plus) + and distributes_over(monoid.plus, m) + and not distributes_over(m, monoid.plus) ): progress = True - term_args = (t.args for t in terms) + term_args = (t.args[1:] for t in terms) dist_terms = ( monoid.plus(*args) for args in itertools.product(*term_args) ) - final_sum.append(op(*dist_terms)) + final_sum.append(monoid.plus(*dist_terms)) else: final_sum += terms if progress: @@ -390,8 +370,10 @@ def plus(self, monoid, *args): class PlusZero(ObjectInterpretation): """x₁ * ... * 0 * ... * xₙ = 0""" - @implements(CommutativeMonoidWithZero.plus) + @implements(plus) def plus(self, monoid, *args): + if not (isinstance(monoid, MonoidWithZero)): + return fwd() if any(x is monoid.zero for x in args): return monoid.zero return fwd() @@ -400,8 +382,11 @@ def plus(self, monoid, *args): class PlusConsecutiveDups(ObjectInterpretation): """x ⊕ x ⊕ y = x ⊕ y""" - @implements(IdempotentMonoid.plus) + @implements(plus) def plus(self, monoid, *args): + if not is_idempotent(monoid): + return fwd() + dedup_args = ( args[i] for i in range(len(args)) @@ -423,8 +408,11 @@ def __eq__(self, other): def __hash__(self): return syntactic_hash(self) - @implements(Semilattice.plus) + @implements(plus) def plus(self, monoid, *args): + if not (is_idempotent(monoid) and is_commutative(monoid)): + return fwd() + # elim dups args_count = Counter(self._HashableTerm(t) for t in args) if len(args_count) < len(args): @@ -461,7 +449,7 @@ class ReduceNoStreams(ObjectInterpretation): reduce(R, ∅, body) = 0 """ - @implements(Monoid.reduce) + @implements(reduce) def reduce(self, monoid, _, streams): if len(streams) == 0: return monoid.identity @@ -473,7 +461,7 @@ class ReduceFusion(ObjectInterpretation): reduce(R, S1, reduce(R, S2, body)) = reduce(R, S1 ∪ S2, body) """ - @implements(Monoid.reduce) + @implements(reduce) def reduce(self, monoid, body, streams): if isinstance(body, Term) and body.op == monoid.reduce: return monoid.reduce(body.args[0], streams | body.args[1]) @@ -485,8 +473,10 @@ class ReduceSplit(ObjectInterpretation): reduce(R, S, b1 + ... + bn) = reduce(R, S, b1) + ... + reduce(R, S, bn) """ - @implements(CommutativeMonoid.reduce) + @implements(reduce) def reduce(self, monoid, body, streams): + if not is_commutative(monoid): + return fwd() if isinstance(body, Term) and body.op == monoid.plus: return monoid.plus(*(monoid.reduce(x, streams) for x in body.args)) return fwd() @@ -506,8 +496,10 @@ class ReduceFactorization(ObjectInterpretation): and free(Aᵢ) ∩ S ⊆ Sᵢ """ - @implements(CommutativeMonoid.reduce) + @implements(reduce) def reduce(self, monoid, body, streams): + if not is_commutative(monoid): + return fwd() if isinstance(body, Term) and distributes_over(body.op, monoid.plus): stream_vars = set(streams.keys()) factors = [(arg, fvsof(arg)) for arg in body.args] @@ -521,7 +513,7 @@ def reduce(self, monoid, body, streams): ds.union(stream_id, *deps) # factors are in the same partition as their dependencies - for factor, factor_fvs in factors: + for _, factor_fvs in factors: factor_streams = sorted( [stream_ids[v] for v in (factor_fvs & stream_vars)] ) @@ -592,18 +584,6 @@ def inner_stream( ) -def match_reduce(term: Term) -> tuple | None: - reduce_args = None - - def set_reduce_args(*args, **kwargs): - nonlocal reduce_args - reduce_args = args - - with interpreter({Monoid.reduce: set_reduce_args}): - term.op(*term.args, **term.kwargs) - return reduce_args - - class ReduceDistributeCartesianProduct(ObjectInterpretation): """Eliminates a reduce over a cartesian product. ∑_x₁ ∑_x₂ ... ∑_xₙ ∏_i f(xᵢ) = ∏_i ∑_xᵢ f(xᵢ) @@ -623,9 +603,9 @@ class ReduceDistributeCartesianProduct(ObjectInterpretation): variable elimination." AISTATS. 2013. """ - @implements(CommutativeMonoid.reduce) + @implements(reduce) def reduce(self, sum_monoid: Monoid, sum_body, sum_streams): - if not (isinstance(sum_body, Term)): + if not (is_commutative(sum_monoid) and isinstance(sum_body, Term)): return fwd() # body is a product or multiplication of products @@ -636,12 +616,11 @@ def reduce(self, sum_monoid: Monoid, sum_body, sum_streams): products: list[tuple[Monoid, Callable, Operation, Term]] = [] for prod_reduce in prod_reduces: - prod_args = match_reduce(prod_reduce) - if prod_args is None: + if not (isinstance(prod_reduce, Term) and prod_reduce.op == reduce): return fwd() - (prod_monoid, prod_body, prod_streams) = prod_args + (prod_monoid, prod_body, prod_streams) = prod_reduce.args if not ( - distributes_over(prod_monoid.plus, sum_monoid.plus) + distributes_over(prod_monoid, sum_monoid) and (len(products) == 0 or products[-1][0] == prod_monoid) ): return fwd() diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index 8fb12598f..501fa7dff 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -15,6 +15,7 @@ Operation, Term, _CustomSingleDispatchCallable, + _CustomSingleDispatchMethod, ) diff --git a/effectful/ops/types.py b/effectful/ops/types.py index d24be9745..cc03f86f1 100644 --- a/effectful/ops/types.py +++ b/effectful/ops/types.py @@ -42,6 +42,63 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: return self.func(self.dispatch, *args, **kwargs) +class _CustomSingleDispatchMethod[**P, **Q, S, T]: + """Method analog of :class:`_CustomSingleDispatchCallable`. + + The wrapped function has signature ``(self, dispatch, *args, **kwargs)``, + where ``dispatch`` is :meth:`functools.singledispatch.dispatch`. As a + descriptor, it binds ``self`` on attribute access, so callers invoke it + as ``instance.method(*args, **kwargs)``. + """ + + def __init__( + self, + func: Callable[ + Concatenate[Any, Callable[[type], Callable[Q, S]], P], T + ], + ): + self.func = func + self._registry = functools.singledispatch(func) + self.__signature__ = inspect.signature( + functools.partial(func, None, None) # type: ignore[arg-type] + ) + functools.update_wrapper(self, func) + + @property + def dispatch(self): + return self._registry.dispatch + + @property + def register(self): + return self._registry.register + + def __get__(self, instance, owner=None): + if instance is None: + return self + return _BoundCustomSingleDispatchMethod(self, instance) + + +class _BoundCustomSingleDispatchMethod: + __slots__ = ("_method", "_instance") + + def __init__(self, method: _CustomSingleDispatchMethod, instance: Any): + self._method = method + self._instance = instance + + @property + def dispatch(self): + return self._method.dispatch + + @property + def register(self): + return self._method.register + + def __call__(self, *args, **kwargs): + return self._method.func( + self._instance, self._method.dispatch, *args, **kwargs + ) + + class _ClassMethodOpDescriptor(classmethod): def __init__(self, define, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index e73a9a7b2..5eacd9e20 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -14,9 +14,9 @@ Monoid, NormalizeIntp, Product, - Semilattice, Sum, distributes_over, + is_commutative, ) from effectful.ops.semantics import apply, evaluate, fvsof, handler from effectful.ops.syntax import _BaseTerm, defdata, syntactic_eq @@ -354,7 +354,7 @@ def test_plus_idempotent_non_consecutive(monoid): PlusDups; plain IdempotentMonoid leaves it as-is (consecutive-only).""" a, b = define_vars("a", "b") lhs = monoid.plus(a(), b(), a()) - if isinstance(monoid, Semilattice): + if is_commutative(monoid): rhs = monoid.plus(a(), b()) else: rhs = monoid.plus(a(), b(), a()) From ed8cf13b915634872f040066644ffa1cc0d42d98 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Tue, 12 May 2026 13:26:29 -0400 Subject: [PATCH 11/34] fix tests --- effectful/ops/monoid.py | 44 ++++++++++++++++++++++++++-------------- tests/test_ops_monoid.py | 2 +- 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 7571d6b65..2526168c9 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -351,15 +351,15 @@ def plus(self, monoid, *args): for m, terms in by_monoid.items(): if ( len(terms) > 1 - and distributes_over(monoid.plus, m) - and not distributes_over(m, monoid.plus) + and distributes_over(monoid, m) + and not distributes_over(m, monoid) ): progress = True term_args = (t.args[1:] for t in terms) dist_terms = ( monoid.plus(*args) for args in itertools.product(*term_args) ) - final_sum.append(monoid.plus(*dist_terms)) + final_sum.append(m.plus(*dist_terms)) else: final_sum += terms if progress: @@ -463,8 +463,12 @@ class ReduceFusion(ObjectInterpretation): @implements(reduce) def reduce(self, monoid, body, streams): - if isinstance(body, Term) and body.op == monoid.reduce: - return monoid.reduce(body.args[0], streams | body.args[1]) + if ( + isinstance(body, Term) + and body.op == reduce + and body.args[0] is monoid + ): + return monoid.reduce(body.args[1], streams | body.args[2]) return fwd() @@ -477,8 +481,12 @@ class ReduceSplit(ObjectInterpretation): def reduce(self, monoid, body, streams): if not is_commutative(monoid): return fwd() - if isinstance(body, Term) and body.op == monoid.plus: - return monoid.plus(*(monoid.reduce(x, streams) for x in body.args)) + if ( + isinstance(body, Term) + and body.op == plus + and body.args[0] is monoid + ): + return monoid.plus(*(monoid.reduce(x, streams) for x in body.args[1:])) return fwd() @@ -500,9 +508,14 @@ class ReduceFactorization(ObjectInterpretation): def reduce(self, monoid, body, streams): if not is_commutative(monoid): return fwd() - if isinstance(body, Term) and distributes_over(body.op, monoid.plus): + if ( + isinstance(body, Term) + and body.op == plus + and distributes_over(body.args[0], monoid) + ): + inner_monoid = body.args[0] stream_vars = set(streams.keys()) - factors = [(arg, fvsof(arg)) for arg in body.args] + factors = [(arg, fvsof(arg)) for arg in body.args[1:]] stream_ids = {v: i for (i, v) in enumerate(stream_vars)} ds = DisjointSet(len(streams)) @@ -542,14 +555,14 @@ def reduce(self, monoid, body, streams): for t in partition_factors ), "partition contains all streams required by factor" - partition_term = body.op(*(t[0] for t in partition_factors)) + partition_term = inner_monoid.plus(*(t[0] for t in partition_factors)) new_reduces.append((partition_term, partition_streams)) placed_streams |= partition_stream_keys constant_factors = [t for (t, fvs) in factors if not (fvs & stream_vars)] if len(new_reduces) > 1: - result = body.op( + result = inner_monoid.plus( *constant_factors, *(monoid.reduce(*args) for args in new_reduces) ) return result @@ -609,8 +622,8 @@ def reduce(self, sum_monoid: Monoid, sum_body, sum_streams): return fwd() # body is a product or multiplication of products - if distributes_over(sum_body.op, sum_monoid.plus): - prod_reduces = sum_body.args + if sum_body.op == plus and distributes_over(sum_body.args[0], sum_monoid): + prod_reduces = sum_body.args[1:] else: prod_reduces = [sum_body] @@ -637,10 +650,11 @@ def reduce(self, sum_monoid: Monoid, sum_body, sum_streams): for outer_sum_streams, cprod_op, cprod_term in inner_stream(sum_streams): if not ( isinstance(cprod_term, Term) - and cprod_term.op == CartesianProduct.reduce + and cprod_term.op == reduce + and cprod_term.args[0] is CartesianProduct ): continue - (cprod_body, cprod_streams) = cprod_term.args + (_, cprod_body, cprod_streams) = cprod_term.args if not all( prod_stream.op == cprod_op for (_, _, _, prod_stream) in products diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index 5eacd9e20..d881869ac 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -56,7 +56,7 @@ for o in ALL_MONOIDS for i in ALL_MONOIDS if distributes_over( - typing.cast(Monoid, i.values[0]).plus, typing.cast(Monoid, o.values[0]).plus + typing.cast(Monoid, i.values[0]), typing.cast(Monoid, o.values[0]) ) ] From 908f580eaef44360fbc0364296c321080e40a317 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Tue, 12 May 2026 13:27:08 -0400 Subject: [PATCH 12/34] format --- effectful/ops/monoid.py | 13 ++----------- effectful/ops/syntax.py | 1 - effectful/ops/types.py | 8 ++------ 3 files changed, 4 insertions(+), 18 deletions(-) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 2526168c9..e1f4e18c3 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -15,7 +15,6 @@ ObjectInterpretation, Scoped, _NumberTerm, - defdata, deffn, implements, iter_, @@ -463,11 +462,7 @@ class ReduceFusion(ObjectInterpretation): @implements(reduce) def reduce(self, monoid, body, streams): - if ( - isinstance(body, Term) - and body.op == reduce - and body.args[0] is monoid - ): + if isinstance(body, Term) and body.op == reduce and body.args[0] is monoid: return monoid.reduce(body.args[1], streams | body.args[2]) return fwd() @@ -481,11 +476,7 @@ class ReduceSplit(ObjectInterpretation): def reduce(self, monoid, body, streams): if not is_commutative(monoid): return fwd() - if ( - isinstance(body, Term) - and body.op == plus - and body.args[0] is monoid - ): + if isinstance(body, Term) and body.op == plus and body.args[0] is monoid: return monoid.plus(*(monoid.reduce(x, streams) for x in body.args[1:])) return fwd() diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index 501fa7dff..8fb12598f 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -15,7 +15,6 @@ Operation, Term, _CustomSingleDispatchCallable, - _CustomSingleDispatchMethod, ) diff --git a/effectful/ops/types.py b/effectful/ops/types.py index cc03f86f1..febffd99a 100644 --- a/effectful/ops/types.py +++ b/effectful/ops/types.py @@ -53,9 +53,7 @@ class _CustomSingleDispatchMethod[**P, **Q, S, T]: def __init__( self, - func: Callable[ - Concatenate[Any, Callable[[type], Callable[Q, S]], P], T - ], + func: Callable[Concatenate[Any, Callable[[type], Callable[Q, S]], P], T], ): self.func = func self._registry = functools.singledispatch(func) @@ -94,9 +92,7 @@ def register(self): return self._method.register def __call__(self, *args, **kwargs): - return self._method.func( - self._instance, self._method.dispatch, *args, **kwargs - ) + return self._method.func(self._instance, self._method.dispatch, *args, **kwargs) class _ClassMethodOpDescriptor(classmethod): From 5c504a3d60c9bd56b6fa93ffc727885f40c0466b Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Tue, 12 May 2026 13:30:36 -0400 Subject: [PATCH 13/34] lint --- effectful/ops/monoid.py | 12 +++++++----- effectful/ops/types.py | 2 +- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index e1f4e18c3..bac02921a 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -128,7 +128,7 @@ def __hash__(self): @_CustomSingleDispatchMethod def plus[S](self, dispatch, *args: S) -> S: if not args: - return self.identity + return typing.cast(S, self.identity) return dispatch(type(args[0]))(self, *args) @plus.register(object) @@ -183,15 +183,15 @@ def reduce[A, B, U: Body]( return reduce(self, body, streams) - @reduce.register # type: ignore[attr-defined] + @reduce.register def _(self, body: Mapping, streams): return {k: self.reduce(v, streams) for (k, v) in body.items()} - @reduce.register # type: ignore[attr-defined] + @reduce.register def _(self, body: tuple, streams): return tuple(self.reduce(x, streams) for x in body) - @reduce.register # type: ignore[attr-defined] + @reduce.register def _(self, body: Generator, streams): return (self.reduce(x, streams) for x in body) @@ -622,7 +622,9 @@ def reduce(self, sum_monoid: Monoid, sum_body, sum_streams): for prod_reduce in prod_reduces: if not (isinstance(prod_reduce, Term) and prod_reduce.op == reduce): return fwd() - (prod_monoid, prod_body, prod_streams) = prod_reduce.args + prod_monoid = typing.cast(Monoid, prod_reduce.args[0]) + prod_body = prod_reduce.args[1] + prod_streams = typing.cast(Mapping, prod_reduce.args[2]) if not ( distributes_over(prod_monoid, sum_monoid) and (len(products) == 0 or products[-1][0] == prod_monoid) diff --git a/effectful/ops/types.py b/effectful/ops/types.py index febffd99a..7e427942f 100644 --- a/effectful/ops/types.py +++ b/effectful/ops/types.py @@ -60,7 +60,7 @@ def __init__( self.__signature__ = inspect.signature( functools.partial(func, None, None) # type: ignore[arg-type] ) - functools.update_wrapper(self, func) + functools.update_wrapper(self, func) # type: ignore[arg-type] @property def dispatch(self): From c0472a8e6bf9a04b195e551eca3e8906a3179333 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Tue, 12 May 2026 16:16:05 -0400 Subject: [PATCH 14/34] wip --- effectful/ops/monoid.py | 211 ++++++++++++++++++++-------------------- effectful/ops/types.py | 23 +++++ 2 files changed, 126 insertions(+), 108 deletions(-) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index bac02921a..0d6e230c0 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -56,56 +56,6 @@ def order_streams[T](streams: Streams[T]) -> Iterable[tuple[Operation[[], T], An topo.done(*node_group) -@Operation.define -def reduce[A, B, U: Body]( - monoid: "Monoid", - body: Annotated[U, Scoped[A | B]], - streams: Annotated[Streams, Scoped[A]], -) -> Annotated[U, Scoped[B]]: - if callable(body): - return typing.cast(U, lambda *a, **k: monoid.reduce(body(*a, **k), streams)) - - def generator(loop_order) -> Iterator[Interpretation]: - if len(loop_order) == 0: - return - - stream_key = loop_order[0][0] - stream_values = evaluate(streams[stream_key]) - stream_values_iter = iter(stream_values) # type: ignore[arg-type] - - # If we try to iterate and get a term instead of a real - # iterator, give up - if isinstance(stream_values_iter, Term) and stream_values_iter.op is iter_: - raise NotHandled - - if len(loop_order) == 1: - for val in stream_values_iter: - yield {stream_key: functools.partial(lambda v: v, val)} - else: - for val in stream_values_iter: - intp: Interpretation = {stream_key: functools.partial(lambda v: v, val)} - with handler(intp): - for intp2 in generator(loop_order[1:]): - yield coproduct(intp, intp2) - - loop_order = list(order_streams(streams)) - return monoid.plus( - *(handler(intp)(evaluate)(body) for intp in generator(loop_order)) - ) - - -@Operation.define -def plus[S: Body](monoid: "Monoid", *args: S) -> S: - """Monoid addition with broadcasting over common collection types, - callables, and interpretations. - - """ - if any(isinstance(x, Term) for x in args): - raise NotHandled - - return typing.cast(S, functools.reduce(monoid.kernel, args, monoid.identity)) - - class Monoid[T]: kernel: Operation[[T, T], T] identity: T @@ -125,53 +75,53 @@ def __eq__(self, other): def __hash__(self): return hash(id(self)) + @Operation.define @_CustomSingleDispatchMethod def plus[S](self, dispatch, *args: S) -> S: + """Monoid addition with broadcasting over common collection types, + callables, and interpretations. + """ if not args: return typing.cast(S, self.identity) return dispatch(type(args[0]))(self, *args) - @plus.register(object) + @plus.register(object) # type: ignore[attr-defined] def _(self, *args): - return plus(self, *args) + if any(isinstance(x, Term) for x in args): + raise NotHandled + return functools.reduce(self.kernel, args, self.identity) - @plus.register(tuple) + @plus.register(tuple) # type: ignore[attr-defined] def _(self, *args): return tuple(self.plus(*vs) for vs in zip(*args, strict=True)) - @plus.register(Generator) + @plus.register(Generator) # type: ignore[attr-defined] def _(self, *args): return (self.plus(*vs) for vs in zip(*args, strict=True)) - @plus.register(Mapping) + @plus.register(Mapping) # type: ignore[attr-defined] def _(self, *args): if isinstance(args[0], Interpretation): keys = args[0].keys() - for b in args[1:]: if not isinstance(b, Interpretation): raise TypeError(f"Expected interpretation but got {b}") - - b_keys = b.keys() - if not keys == b_keys: + if not keys == b.keys(): raise ValueError( - f"Expected interpretation of {keys} but got {b_keys}" + f"Expected interpretation of {keys} but got {b.keys()}" ) - - result = {k: self.plus(*(handler(b)(b[k]) for b in args)) for k in keys} - return result + return {k: self.plus(*(handler(b)(b[k]) for b in args)) for k in keys} for b in args[1:]: if not isinstance(b, Mapping): raise TypeError(f"Expected mapping but got {b}") - all_values = collections.defaultdict(list) for d in args: for k, v in d.items(): all_values[k].append(v) - result = {k: self.plus(*vs) for (k, vs) in all_values.items()} - return result + return {k: self.plus(*vs) for (k, vs) in all_values.items()} + @Operation.define @functools.singledispatchmethod def reduce[A, B, U: Body]( self, @@ -181,21 +131,61 @@ def reduce[A, B, U: Body]( if callable(body): return typing.cast(U, lambda *a, **k: self.reduce(body(*a, **k), streams)) - return reduce(self, body, streams) + def generator(loop_order) -> Iterator[Interpretation]: + if len(loop_order) == 0: + return + + stream_key = loop_order[0][0] + stream_values = evaluate(streams[stream_key]) + stream_values_iter = iter(stream_values) # type: ignore[arg-type] + + # If we try to iterate and get a term instead of a real + # iterator, give up + if isinstance(stream_values_iter, Term) and stream_values_iter.op is iter_: + raise NotHandled + + if len(loop_order) == 1: + for val in stream_values_iter: + yield {stream_key: functools.partial(lambda v: v, val)} + else: + for val in stream_values_iter: + intp: Interpretation = { + stream_key: functools.partial(lambda v: v, val) + } + with handler(intp): + for intp2 in generator(loop_order[1:]): + yield coproduct(intp, intp2) + + loop_order = list(order_streams(streams)) + return self.plus( + *(handler(intp)(evaluate)(body) for intp in generator(loop_order)) + ) - @reduce.register + @reduce.register # type: ignore[attr-defined] def _(self, body: Mapping, streams): return {k: self.reduce(v, streams) for (k, v) in body.items()} - @reduce.register + @reduce.register # type: ignore[attr-defined] def _(self, body: tuple, streams): return tuple(self.reduce(x, streams) for x in body) - @reduce.register + @reduce.register # type: ignore[attr-defined] def _(self, body: Generator, streams): return (self.reduce(x, streams) for x in body) +def _is_monoid_plus(op: Operation) -> bool: + """True if ``op`` is the ``plus`` operation of some :class:`Monoid`.""" + owner = getattr(op, "__self__", None) + return isinstance(owner, Monoid) and op is owner.plus + + +def _is_monoid_reduce(op: Operation) -> bool: + """True if ``op`` is the ``reduce`` operation of some :class:`Monoid`.""" + owner = getattr(op, "__self__", None) + return isinstance(owner, Monoid) and op is owner.reduce + + class MonoidWithZero[T](Monoid[T]): zero: T @@ -281,7 +271,7 @@ def __call__(self, s: S, t: T) -> bool: class PlusEmpty(ObjectInterpretation): """plus() = 0""" - @implements(plus) + @implements(Monoid.plus) def plus(self, monoid, *args): if not args: return monoid.identity @@ -291,7 +281,7 @@ def plus(self, monoid, *args): class PlusSingle(ObjectInterpretation): """plus(x) = x""" - @implements(plus) + @implements(Monoid.plus) def plus(self, _, *args): if len(args) == 1: return args[0] @@ -301,7 +291,7 @@ def plus(self, _, *args): class PlusIdentity(ObjectInterpretation): """x₁ + ... + 0 + ... + xₙ = x₁ + ... + xₙ""" - @implements(plus) + @implements(Monoid.plus) def plus(self, monoid, *args): if any(x is monoid.identity for x in args): return monoid.plus(*(x for x in args if x is not monoid.identity)) @@ -311,14 +301,14 @@ def plus(self, monoid, *args): class PlusAssoc(ObjectInterpretation): """x + (y + z) = (x + y) + z = x + y + z""" - @implements(plus) + @implements(Monoid.plus) def plus(self, monoid, *args): def is_nested_plus(x): - return isinstance(x, Term) and x.op == plus and x.args[0] is monoid + return isinstance(x, Term) and x.op is monoid.plus if any(is_nested_plus(x) for x in args): flat_args = itertools.chain.from_iterable( - t.args[1:] if is_nested_plus(t) else (t,) for t in args + t.args if is_nested_plus(t) else (t,) for t in args ) assert len(args) > 0 return monoid.plus(*flat_args) @@ -328,19 +318,21 @@ def is_nested_plus(x): class PlusDistr(ObjectInterpretation): """x + (y * z) = x * y + x * z""" - @implements(plus) + @implements(Monoid.plus) def plus(self, monoid, *args): if any( - isinstance(x, Term) and x.op == plus and distributes_over(monoid, x.args[0]) + isinstance(x, Term) + and _is_monoid_plus(x.op) + and distributes_over(monoid, x.op.__self__) for x in args ): non_terms = [] - # group terms by monoid - by_monoid = defaultdict(list) + # group terms by their monoid + by_monoid: dict[Monoid, list[Term]] = defaultdict(list) for t in args: - if isinstance(t, Term) and t.op == plus: - by_monoid[t.args[0]].append(t) + if isinstance(t, Term) and _is_monoid_plus(t.op): + by_monoid[t.op.__self__].append(t) else: non_terms.append(t) @@ -354,7 +346,7 @@ def plus(self, monoid, *args): and not distributes_over(m, monoid) ): progress = True - term_args = (t.args[1:] for t in terms) + term_args = (t.args for t in terms) dist_terms = ( monoid.plus(*args) for args in itertools.product(*term_args) ) @@ -369,7 +361,7 @@ def plus(self, monoid, *args): class PlusZero(ObjectInterpretation): """x₁ * ... * 0 * ... * xₙ = 0""" - @implements(plus) + @implements(Monoid.plus) def plus(self, monoid, *args): if not (isinstance(monoid, MonoidWithZero)): return fwd() @@ -381,7 +373,7 @@ def plus(self, monoid, *args): class PlusConsecutiveDups(ObjectInterpretation): """x ⊕ x ⊕ y = x ⊕ y""" - @implements(plus) + @implements(Monoid.plus) def plus(self, monoid, *args): if not is_idempotent(monoid): return fwd() @@ -407,7 +399,7 @@ def __eq__(self, other): def __hash__(self): return syntactic_hash(self) - @implements(plus) + @implements(Monoid.plus) def plus(self, monoid, *args): if not (is_idempotent(monoid) and is_commutative(monoid)): return fwd() @@ -448,7 +440,7 @@ class ReduceNoStreams(ObjectInterpretation): reduce(R, ∅, body) = 0 """ - @implements(reduce) + @implements(Monoid.reduce) def reduce(self, monoid, _, streams): if len(streams) == 0: return monoid.identity @@ -460,10 +452,10 @@ class ReduceFusion(ObjectInterpretation): reduce(R, S1, reduce(R, S2, body)) = reduce(R, S1 ∪ S2, body) """ - @implements(reduce) + @implements(Monoid.reduce) def reduce(self, monoid, body, streams): - if isinstance(body, Term) and body.op == reduce and body.args[0] is monoid: - return monoid.reduce(body.args[1], streams | body.args[2]) + if isinstance(body, Term) and body.op is monoid.reduce: + return monoid.reduce(body.args[0], streams | body.args[1]) return fwd() @@ -472,12 +464,12 @@ class ReduceSplit(ObjectInterpretation): reduce(R, S, b1 + ... + bn) = reduce(R, S, b1) + ... + reduce(R, S, bn) """ - @implements(reduce) + @implements(Monoid.reduce) def reduce(self, monoid, body, streams): if not is_commutative(monoid): return fwd() - if isinstance(body, Term) and body.op == plus and body.args[0] is monoid: - return monoid.plus(*(monoid.reduce(x, streams) for x in body.args[1:])) + if isinstance(body, Term) and body.op is monoid.plus: + return monoid.plus(*(monoid.reduce(x, streams) for x in body.args)) return fwd() @@ -495,18 +487,18 @@ class ReduceFactorization(ObjectInterpretation): and free(Aᵢ) ∩ S ⊆ Sᵢ """ - @implements(reduce) + @implements(Monoid.reduce) def reduce(self, monoid, body, streams): if not is_commutative(monoid): return fwd() if ( isinstance(body, Term) - and body.op == plus - and distributes_over(body.args[0], monoid) + and _is_monoid_plus(body.op) + and distributes_over(body.op.__self__, monoid) ): - inner_monoid = body.args[0] + inner_monoid: Monoid = body.op.__self__ stream_vars = set(streams.keys()) - factors = [(arg, fvsof(arg)) for arg in body.args[1:]] + factors = [(arg, fvsof(arg)) for arg in body.args] stream_ids = {v: i for (i, v) in enumerate(stream_vars)} ds = DisjointSet(len(streams)) @@ -607,24 +599,28 @@ class ReduceDistributeCartesianProduct(ObjectInterpretation): variable elimination." AISTATS. 2013. """ - @implements(reduce) + @implements(Monoid.reduce) def reduce(self, sum_monoid: Monoid, sum_body, sum_streams): if not (is_commutative(sum_monoid) and isinstance(sum_body, Term)): return fwd() # body is a product or multiplication of products - if sum_body.op == plus and distributes_over(sum_body.args[0], sum_monoid): - prod_reduces = sum_body.args[1:] + if _is_monoid_plus(sum_body.op) and distributes_over( + sum_body.op.__self__, sum_monoid + ): + prod_reduces = sum_body.args else: prod_reduces = [sum_body] products: list[tuple[Monoid, Callable, Operation, Term]] = [] for prod_reduce in prod_reduces: - if not (isinstance(prod_reduce, Term) and prod_reduce.op == reduce): + if not ( + isinstance(prod_reduce, Term) and _is_monoid_reduce(prod_reduce.op) + ): return fwd() - prod_monoid = typing.cast(Monoid, prod_reduce.args[0]) - prod_body = prod_reduce.args[1] - prod_streams = typing.cast(Mapping, prod_reduce.args[2]) + prod_monoid: Monoid = prod_reduce.op.__self__ + prod_body = prod_reduce.args[0] + prod_streams = typing.cast(Mapping, prod_reduce.args[1]) if not ( distributes_over(prod_monoid, sum_monoid) and (len(products) == 0 or products[-1][0] == prod_monoid) @@ -643,11 +639,10 @@ def reduce(self, sum_monoid: Monoid, sum_body, sum_streams): for outer_sum_streams, cprod_op, cprod_term in inner_stream(sum_streams): if not ( isinstance(cprod_term, Term) - and cprod_term.op == reduce - and cprod_term.args[0] is CartesianProduct + and cprod_term.op is CartesianProduct.reduce ): continue - (_, cprod_body, cprod_streams) = cprod_term.args + (cprod_body, cprod_streams) = cprod_term.args if not all( prod_stream.op == cprod_op for (_, _, _, prod_stream) in products diff --git a/effectful/ops/types.py b/effectful/ops/types.py index 7e427942f..c68e0d46c 100644 --- a/effectful/ops/types.py +++ b/effectful/ops/types.py @@ -364,6 +364,15 @@ def func(*args, **kwargs): return typing.cast(Operation[P, T], cls.define(func, **kwargs)) + @define.register(types.MethodType) + @classmethod + def _define_methodtype[**P, T]( + cls, t: Callable[P, T], *, name: str | None = None + ) -> "Operation[P, T]": + op = cls._define_callable(t, name=name) + op.__self__ = t.__self__ # type: ignore[attr-defined] + return typing.cast("Operation[P, T]", op) + @define.register(staticmethod) @classmethod def _define_staticmethod[**P, T](cls, t: "staticmethod[P, T]", **kwargs): @@ -403,6 +412,20 @@ def func(*args, **kwargs): op.register = default._registry.register # type: ignore[attr-defined] return op + @define.register(_CustomSingleDispatchMethod) + @classmethod + def _define_customsingledispatchmethod( + cls, default: _CustomSingleDispatchMethod, **kwargs + ): + @functools.wraps(default.func) + def _wrapper(obj, *args, **kwargs): + return default.__get__(obj)(*args, **kwargs) + + op = cls.define(_wrapper, **kwargs) + op.register = default.register # type: ignore[attr-defined] + op.dispatch = default.dispatch # type: ignore[attr-defined] + return op + @typing.final def __default_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> "Expr[V]": """The default rule is used when the operation is not handled. From e99d34ad93040da69c152e7125e012b3dbdb0446 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Tue, 12 May 2026 17:33:40 -0400 Subject: [PATCH 15/34] wip --- effectful/handlers/jax/monoid.py | 67 ++++++++----- effectful/ops/monoid.py | 162 ++++++++++++++++++++++--------- 2 files changed, 162 insertions(+), 67 deletions(-) diff --git a/effectful/handlers/jax/monoid.py b/effectful/handlers/jax/monoid.py index 4b55674f9..ff33a1360 100644 --- a/effectful/handlers/jax/monoid.py +++ b/effectful/handlers/jax/monoid.py @@ -4,20 +4,20 @@ from effectful.handlers.jax import bind_dims, unbind_dims from effectful.handlers.jax.scipy.special import logsumexp from effectful.ops.monoid import ( - CommutativeMonoid, - CommutativeMonoidWithZero, + CartesianProduct, + Max, + Min, Monoid, - Semilattice, + Product, Streams, - distributes_over, + Sum, outer_stream, ) -from effectful.ops.semantics import evaluate, handler, typeof -from effectful.ops.syntax import deffn +from effectful.ops.semantics import evaluate, fwd, handler, typeof +from effectful.ops.syntax import ObjectInterpretation, deffn, implements from effectful.ops.types import Operation -@Operation.define def cartesian_prod(x, y): if x.ndim == 1: x = x[:, None] @@ -27,21 +27,44 @@ def cartesian_prod(x, y): return jnp.hstack([x, y]) -Sum = CommutativeMonoid(kernel=jnp.add, identity=jnp.asarray(0)) -Product = CommutativeMonoidWithZero( - kernel=jnp.multiply, identity=jnp.asarray(1), zero=jnp.asarray(0) -) -Min = Semilattice(kernel=jnp.minimum, identity=jnp.asarray(float("-inf"))) -Max = Semilattice(kernel=jnp.maximum, identity=jnp.asarray(float("inf"))) -LogSumExp = CommutativeMonoid(kernel=jnp.logaddexp, identity=jnp.asarray(float("-inf"))) -CartesianProd = Monoid(kernel=cartesian_prod, identity=jnp.array([])) - -distributes_over.register(Max.plus, Min.plus) -distributes_over.register(Min.plus, Max.plus) -distributes_over.register(Sum.plus, Min.plus) -distributes_over.register(Sum.plus, Max.plus) -distributes_over.register(Product.plus, Sum.plus) -distributes_over.register(Sum.plus, LogSumExp.plus) +LogSumExp = Monoid("LogSumExp") + + +class MinKernelJax(ObjectInterpretation): + @implements(Min.kernel) + def kernel(self, x, y): + if isinstance(x, jax.Array) and isinstance(y, jax.Array): + return jnp.minimum(x, y) + return fwd() + + +class MaxKernelJax(ObjectInterpretation): + @implements(Max.kernel) + def kernel(self, x, y): + if isinstance(x, jax.Array) and isinstance(y, jax.Array): + return jnp.maximum(x, y) + return fwd() + + +class CartesianProductKernelJax(ObjectInterpretation): + @implements(CartesianProduct.kernel) + def kernel(self, x, y): + if isinstance(x, jax.Array) and isinstance(y, jax.Array): + return cartesian_prod(x, y) + return fwd() + + +class LogSumExpKernelJax(ObjectInterpretation): + @implements(LogSumExp.identity) + def identity(self): + return float("-inf") + + @implements(LogSumExp.kernel) + def kernel(self, x, y): + if isinstance(x, jax.Array) and isinstance(y, jax.Array): + return jnp.logaddexp(x, y) + return fwd() + ARRAY_REDUCE = { Sum.plus: jnp.sum, diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 9bdf81aad..e3f47c4b2 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -1,7 +1,6 @@ import collections.abc import functools import itertools -import numbers import typing from collections import Counter, defaultdict from collections.abc import Callable, Generator, Iterable, Mapping @@ -14,7 +13,6 @@ from effectful.ops.syntax import ( ObjectInterpretation, Scoped, - _NumberTerm, deffn, implements, iter_, @@ -61,17 +59,13 @@ def outer_stream( class Monoid[T]: - kernel: Operation[[T, T], T] - identity: T + _name: str - def __init__(self, kernel: Callable[[T, T], T], identity: T): - self.identity = identity - self.kernel = ( - kernel if isinstance(kernel, Operation) else Operation.define(kernel) - ) + def __init__(self, name): + self._name = name def __repr__(self): - return f"{type(self)}({self.kernel}, {self.identity})" + return f"Monoid({self._name!r})" def __eq__(self, other): return id(self) == id(other) @@ -79,6 +73,14 @@ def __eq__(self, other): def __hash__(self): return hash(id(self)) + @Operation.define + def kernel(self, _x: T, _y: T) -> T: + raise NotHandled + + @Operation.define + def identity(self) -> T: + raise NotHandled + @Operation.define @_CustomSingleDispatchMethod def plus[S](self, dispatch, *args: S) -> S: @@ -86,14 +88,14 @@ def plus[S](self, dispatch, *args: S) -> S: callables, and interpretations. """ if not args: - return typing.cast(S, self.identity) + return typing.cast(S, self.identity()) return dispatch(type(args[0]))(self, *args) @plus.register(object) # type: ignore[attr-defined] def _(self, *args): if any(isinstance(x, Term) for x in args): raise NotHandled - return functools.reduce(self.kernel, args, self.identity) + return functools.reduce(self.kernel, args, self.identity()) @plus.register(tuple) # type: ignore[attr-defined] def _(self, *args): @@ -131,7 +133,7 @@ def reduce[A, B, U: Body]( self, body: Annotated[U, Scoped[A | B]], streams: Annotated[Streams, Scoped[A]] ) -> Annotated[U, Scoped[B]]: if not streams: - return self.identity + return self.identity() # find and reduce a ground stream for stream_key, stream_body, streams_tail in outer_stream(streams): @@ -185,32 +187,10 @@ def _is_monoid_reduce(op: Operation) -> bool: class MonoidWithZero[T](Monoid[T]): - zero: T - - def __init__(self, kernel: Callable[[T, T], T], identity: T, zero: T): - super().__init__(kernel, identity) - self.zero = zero - - def __repr__(self): - return f"{type(self)}({self.kernel}, {self.identity}, {self.zero})" - - -@Operation.define -def _arg_min[T]( - a: tuple[numbers.Number, T | None], b: tuple[numbers.Number, T | None] -) -> tuple[numbers.Number, T | None]: - if isinstance(a[0], Term) or isinstance(b[0], Term): - raise NotHandled - return b if b[0] < a[0] else a # type: ignore - - -@Operation.define -def _arg_max[T]( - a: tuple[numbers.Number, T | None], b: tuple[numbers.Number, T | None] -) -> tuple[numbers.Number, T | None]: - if isinstance(a[0], Term) or isinstance(b[0], Term): + @Operation.define + @staticmethod + def zero() -> T: raise NotHandled - return b if b[0] > a[0] else a # type: ignore @Operation.define @@ -226,13 +206,13 @@ def to_tuple(x): return [to_tuple(x) + to_tuple(y) for (x, y) in itertools.product(a, b)] -Min = Monoid(kernel=min, identity=float("inf")) -Max = Monoid(kernel=max, identity=float("-inf")) -ArgMin = Monoid(kernel=_arg_min, identity=(float("inf"), None)) -ArgMax = Monoid(kernel=_arg_max, identity=(float("-inf"), None)) -Sum = Monoid(kernel=_NumberTerm.__add__, identity=0) -Product = MonoidWithZero(kernel=_NumberTerm.__mul__, identity=1, zero=0) -CartesianProduct = Monoid(kernel=product, identity=[()]) +Min = Monoid("Min") +Max = Monoid("Max") +ArgMin = Monoid("ArgMin") +ArgMax = Monoid("ArgMax") +Sum = Monoid("Sum") +Product = MonoidWithZero("Product") +CartesianProduct = Monoid("CartesianProduct") @dataclass @@ -266,6 +246,98 @@ def __call__(self, s: S, t: T) -> bool: ) +class SumKernel(ObjectInterpretation): + @implements(Sum.identity) + def identity(self): + return 0 + + @implements(Sum.kernel) + def kernel(self, x, y): + return x + y + + +class ProductKernel(ObjectInterpretation): + @implements(Product.identity) + def identity(self): + return 1 + + @implements(Product.zero) + def zero(self): + return 0 + + @implements(Product.kernel) + def kernel(self, x, y): + return x * y + + +class MinKernel(ObjectInterpretation): + @implements(Min.identity) + def identity(self): + return float("inf") + + @implements(Min.kernel) + def kernel(self, x, y): + if isinstance(x, int | float) and isinstance(y, int | float): + return min(x, y) + return fwd() + + +class MaxKernel(ObjectInterpretation): + @implements(Min.identity) + def identity(self): + return -float("inf") + + @implements(Min.kernel) + def kernel(self, x, y): + if isinstance(x, int | float) and isinstance(y, int | float): + return max(x, y) + return fwd() + + +class ArgMinKernel(ObjectInterpretation): + @implements(ArgMin.identity) + def identity(self): + return (float("inf"), None) + + @implements(ArgMin.kernel) + def kernel(self, a, b): + if isinstance(a[0], Term) or isinstance(b[0], Term): + return fwd() + if isinstance(a[0], int | float) and isinstance(b[0], int | float): + return b if b[0] < a[0] else a + return fwd() + + +class ArgMaxKernel(ObjectInterpretation): + @implements(ArgMax.identity) + def identity(self): + return (-float("inf"), None) + + @implements(ArgMax.kernel) + def kernel(self, a, b): + if isinstance(a[0], Term) or isinstance(b[0], Term): + return fwd() + if isinstance(a[0], int | float) and isinstance(b[0], int | float): + return b if b[0] < a[0] else a + return fwd() + + +class CartesianProductKernel(ObjectInterpretation): + @implements(CartesianProduct.kernel) + def kernel(self, a, b): + if isinstance(a, Term) or isinstance(b, Term): + raise NotHandled + + if isinstance(a, Iterable) and isinstance(b, Iterable): + + def to_tuple(x): + return x if isinstance(x, tuple) else (x,) + + return [to_tuple(x) + to_tuple(y) for (x, y) in itertools.product(a, b)] + + return fwd() + + class PlusEmpty(ObjectInterpretation): """plus() = 0""" From 45cac7d9850864c99e45053eacf55d5db62d9139 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Wed, 13 May 2026 16:00:01 -0400 Subject: [PATCH 16/34] wip --- effectful/ops/monoid.py | 105 ++++++++++++++++++++++++++-------------- 1 file changed, 69 insertions(+), 36 deletions(-) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 189d193c2..3fd6cf1d4 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -1,6 +1,7 @@ import collections.abc import functools import itertools +import operator import typing from collections import Counter, defaultdict from collections.abc import Callable, Generator, Iterable, Mapping @@ -61,9 +62,11 @@ def outer_stream( class Monoid[T]: _name: str + identity: T - def __init__(self, name): + def __init__(self, identity: T, name: str): self._name = name + self.identity = identity def __repr__(self): return f"Monoid({self._name!r})" @@ -74,20 +77,6 @@ def __eq__(self, other): def __hash__(self): return hash(id(self)) - def __eq__(self, other): - return id(self) == id(other) - - def __hash__(self): - return hash(id(self)) - - @Operation.define - def kernel(self, _x: T, _y: T) -> T: - raise NotHandled - - @Operation.define - def identity(self) -> T: - raise NotHandled - @Operation.define @_CustomSingleDispatchMethod def plus[S](self, dispatch, *args: S) -> S: @@ -95,14 +84,14 @@ def plus[S](self, dispatch, *args: S) -> S: callables, and interpretations. """ if not args: - return typing.cast(S, self.identity()) + return typing.cast(S, self.identity) return dispatch(type(args[0]))(self, *args) @plus.register(object) # type: ignore[attr-defined] def _(self, *args): if any(isinstance(x, Term) for x in args): raise NotHandled - return functools.reduce(self.kernel, args, self.identity()) + raise TypeError("Unexpected arguments to plus") @plus.register(tuple) # type: ignore[attr-defined] def _(self, *args): @@ -140,7 +129,7 @@ def reduce[A, B, U: Body]( self, body: Annotated[U, Scoped[A | B]], streams: Annotated[Streams, Scoped[A]] ) -> Annotated[U, Scoped[B]]: if not streams: - return self.identity() + return self.identity # find and reduce a ground stream for stream_key, stream_body, streams_tail in outer_stream(streams): @@ -194,32 +183,76 @@ def _is_monoid_reduce(op: Operation) -> bool: class MonoidWithZero[T](Monoid[T]): - @Operation.define - @staticmethod - def zero() -> T: + zero: T + + def __init__(self, name: str, identity: T, zero: T): + super().__init__(name=name, identity=identity) + self.zero = zero + + +Min = Monoid(name="Min", identity=float("inf")) +Max = Monoid(name="Max", identity=-float("inf")) +ArgMin = Monoid(name="ArgMin", identity=(Min.identity, None)) +ArgMax = Monoid(name="ArgMax", identity=(Max.identity, None)) +Sum = Monoid(name="Sum", identity=0) +Product = MonoidWithZero(name="Product", identity=1, zero=0) +CartesianProduct = Monoid( + name="CartesianProduct", + identity=Operation.define(object, name="CartesianProductId"), +) + + +@Min.plus.register(int | float) +def _(self, *args): + if any(isinstance(x, Term) for x in args): raise NotHandled + return min(*args) -@Operation.define -def product[T]( - a: Iterable[tuple[T, ...] | T], b: Iterable[tuple[T, ...] | T] -) -> Iterable[tuple[T, ...]]: - if isinstance(a, Term) or isinstance(b, Term): +@Max.plus.register(int | float) +def _(self, *args): + if any(isinstance(x, Term) for x in args): raise NotHandled + return max(*args) - def to_tuple(x): - return x if isinstance(x, tuple) else (x,) - return [to_tuple(x) + to_tuple(y) for (x, y) in itertools.product(a, b)] +@Min.plus.register(int | float) +def _(self, *args): + if any(isinstance(x, Term) for x in args): + raise NotHandled + return min(*args) + + +@Max.plus.register(int | float) +def _(self, *args): + if any(isinstance(x, Term) for x in args): + raise NotHandled + return max(*args) + + +@Sum.plus.register(int | float) +def _(self, *args): + if any(isinstance(x, Term) for x in args): + raise NotHandled + return sum(*args) + + +@Product.plus.register(int | float) +def _(self, *args): + if any(isinstance(x, Term) for x in args): + raise NotHandled + return functools.reduce(operator.mul, args) + +@CartesianProduct.plus.register(Iterable) +def _(self, *args): + if any(isinstance(x, Term) for x in args): + raise NotHandled + + def to_tuple(x): + return x if isinstance(x, tuple) else (x,) -Min = Monoid("Min") -Max = Monoid("Max") -ArgMin = Monoid("ArgMin") -ArgMax = Monoid("ArgMax") -Sum = Monoid("Sum") -Product = MonoidWithZero("Product") -CartesianProduct = Monoid("CartesianProduct") + return [sum(to_tuple(x) for x in vals) for vals in itertools.product(*args)] @dataclass From 8d0a4e8e3e2c8413d1f314b657850e7a6b47cabe Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Wed, 13 May 2026 16:29:50 -0400 Subject: [PATCH 17/34] wip --- effectful/handlers/jax/monoid.py | 107 ++++++------ effectful/ops/monoid.py | 271 ++++++++++++------------------- 2 files changed, 162 insertions(+), 216 deletions(-) diff --git a/effectful/handlers/jax/monoid.py b/effectful/handlers/jax/monoid.py index ff33a1360..f36c16714 100644 --- a/effectful/handlers/jax/monoid.py +++ b/effectful/handlers/jax/monoid.py @@ -1,3 +1,5 @@ +import functools + import jax import effectful.handlers.jax.numpy as jnp @@ -13,9 +15,9 @@ Sum, outer_stream, ) -from effectful.ops.semantics import evaluate, fwd, handler, typeof -from effectful.ops.syntax import ObjectInterpretation, deffn, implements -from effectful.ops.types import Operation +from effectful.ops.semantics import evaluate, handler, typeof +from effectful.ops.syntax import deffn +from effectful.ops.types import NotHandled, Operation def cartesian_prod(x, y): @@ -27,43 +29,37 @@ def cartesian_prod(x, y): return jnp.hstack([x, y]) -LogSumExp = Monoid("LogSumExp") +LogSumExp = Monoid(name="LogSumExp", identity=jnp.asarray(float("-inf"))) + + +@Sum.plus.register(jax.Array) +def _(*args): + return functools.reduce(jnp.add, args) + +@Product.plus.register(jax.Array) +def _(*args): + return functools.reduce(jnp.multiply, args) -class MinKernelJax(ObjectInterpretation): - @implements(Min.kernel) - def kernel(self, x, y): - if isinstance(x, jax.Array) and isinstance(y, jax.Array): - return jnp.minimum(x, y) - return fwd() +@Min.plus.register(jax.Array) +def _(*args): + return functools.reduce(jnp.minimum, args) -class MaxKernelJax(ObjectInterpretation): - @implements(Max.kernel) - def kernel(self, x, y): - if isinstance(x, jax.Array) and isinstance(y, jax.Array): - return jnp.maximum(x, y) - return fwd() +@Max.plus.register(jax.Array) +def _(*args): + return functools.reduce(jnp.maximum, args) -class CartesianProductKernelJax(ObjectInterpretation): - @implements(CartesianProduct.kernel) - def kernel(self, x, y): - if isinstance(x, jax.Array) and isinstance(y, jax.Array): - return cartesian_prod(x, y) - return fwd() +@LogSumExp.plus.register(jax.Array) +def _(*args): + return functools.reduce(jnp.logaddexp, args) -class LogSumExpKernelJax(ObjectInterpretation): - @implements(LogSumExp.identity) - def identity(self): - return float("-inf") - @implements(LogSumExp.kernel) - def kernel(self, x, y): - if isinstance(x, jax.Array) and isinstance(y, jax.Array): - return jnp.logaddexp(x, y) - return fwd() +@CartesianProduct.plus.register(jax.Array) +def _(*args): + return functools.reduce(cartesian_prod, args) ARRAY_REDUCE = { @@ -75,28 +71,37 @@ def kernel(self, x, y): } -@Monoid.reduce.register(jax.Array) -def _reduce_array(self, body: jax.Array, streams: Streams): - reductor = ARRAY_REDUCE[self.plus] - index = Operation.define(jax.Array) +def _reduce_array_for(monoid: Monoid): + def _reduce_array(body: jax.Array, streams: Streams): + reductor = ARRAY_REDUCE[monoid.plus] + index = Operation.define(jax.Array) + + if not streams: + return monoid.identity + + # find and reduce an array stream + for stream_key, stream_body, streams_tail in outer_stream(streams): + if typeof(stream_body) != jax.Array: + continue + + with handler({stream_key: deffn(unbind_dims(stream_body, index))}): + (eval_body, eval_streams_tail) = ( + evaluate(body), + evaluate(streams_tail), + ) + assert isinstance(eval_streams_tail, dict) - if not streams: - return self.identity + reduce_tail = ( + monoid.reduce(eval_body, eval_streams_tail) + if len(eval_streams_tail) > 0 + else eval_body + ) + return reductor(bind_dims(reduce_tail, index), axis=0) - # find and reduce an array stream - for stream_key, stream_body, streams_tail in outer_stream(streams): - if typeof(stream_body) != jax.Array: - continue + raise NotHandled - with handler({stream_key: deffn(unbind_dims(stream_body, index))}): - (eval_body, eval_streams_tail) = evaluate(body), evaluate(streams_tail) - assert isinstance(eval_streams_tail, dict) + return _reduce_array - reduce_tail = ( - self.reduce(eval_body, eval_streams_tail) - if len(eval_streams_tail) > 0 - else eval_body - ) - return reductor(bind_dims(reduce_tail, index), axis=0) - return self._reduce_object(body, streams) +for _m in (Sum, Product, Min, Max, LogSumExp): + _m.reduce.register(jax.Array)(_reduce_array_for(_m)) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 3fd6cf1d4..551729e06 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -14,7 +14,6 @@ from effectful.ops.syntax import ( ObjectInterpretation, Scoped, - _NumberTerm, deffn, implements, iter_, @@ -27,7 +26,6 @@ NotHandled, Operation, Term, - _CustomSingleDispatchMethod, ) # Note: The streams value type should be something like Iterable[T], but some of @@ -61,6 +59,15 @@ def outer_stream( class Monoid[T]: + """A monoid with per-instance dispatch tables for ``plus`` and ``reduce``. + + Each instance owns its own :func:`functools.singledispatch` registries. + Backends and call sites extend behavior via ``instance.plus.register(type)`` + and ``instance.reduce.register(type)``. Rewrites that key on the class-level + ``Monoid.plus`` / ``Monoid.reduce`` operations still fire, because + per-instance operations delegate to them. + """ + _name: str identity: T @@ -68,6 +75,19 @@ def __init__(self, identity: T, name: str): self._name = name self.identity = identity + # per-instance dispatch tables + self._plus_dispatch = functools.singledispatch(self._plus_default) + self._reduce_dispatch = functools.singledispatch(self._reduce_default) + + # expose register/dispatch on the per-instance cached operations + plus_op = type(self).plus.__get__(self, type(self)) + plus_op.register = self._plus_dispatch.register + plus_op.dispatch = self._plus_dispatch.dispatch + + reduce_op = type(self).reduce.__get__(self, type(self)) + reduce_op.register = self._reduce_dispatch.register + reduce_op.dispatch = self._reduce_dispatch.dispatch + def __repr__(self): return f"Monoid({self._name!r})" @@ -77,32 +97,20 @@ def __eq__(self, other): def __hash__(self): return hash(id(self)) - @Operation.define - @_CustomSingleDispatchMethod - def plus[S](self, dispatch, *args: S) -> S: - """Monoid addition with broadcasting over common collection types, - callables, and interpretations. - """ - if not args: - return typing.cast(S, self.identity) - return dispatch(type(args[0]))(self, *args) + # --- plus default registrations ---------------------------------------- - @plus.register(object) # type: ignore[attr-defined] - def _(self, *args): + def _plus_default(self, *args): if any(isinstance(x, Term) for x in args): raise NotHandled - raise TypeError("Unexpected arguments to plus") + raise TypeError(f"Unexpected arguments to {self._name}.plus") - @plus.register(tuple) # type: ignore[attr-defined] - def _(self, *args): + def _plus_tuple(self, *args): return tuple(self.plus(*vs) for vs in zip(*args, strict=True)) - @plus.register(Generator) # type: ignore[attr-defined] - def _(self, *args): + def _plus_generator(self, *args): return (self.plus(*vs) for vs in zip(*args, strict=True)) - @plus.register(Mapping) # type: ignore[attr-defined] - def _(self, *args): + def _plus_mapping(self, *args): if isinstance(args[0], Interpretation): keys = args[0].keys() for b in args[1:]: @@ -123,11 +131,8 @@ def _(self, *args): all_values[k].append(v) return {k: self.plus(*vs) for (k, vs) in all_values.items()} - @Operation.define - @functools.singledispatchmethod - def reduce[A, B, U: Body]( - self, body: Annotated[U, Scoped[A | B]], streams: Annotated[Streams, Scoped[A]] - ) -> Annotated[U, Scoped[B]]: + # --- reduce default registrations -------------------------------------- + def _reduce_default(self, body, streams): if not streams: return self.identity @@ -153,22 +158,38 @@ def reduce[A, B, U: Body]( raise NotHandled - @reduce.register(Callable) # type: ignore[attr-defined] - def _reduce_callable(self, body: Callable, streams): + def _reduce_callable(self, body, streams): return lambda *a, **k: self.reduce(body(*a, **k), streams) - @reduce.register(Mapping) # type: ignore[attr-defined] - def _reduce_mapping(self, body: Mapping, streams): + def _reduce_mapping(self, body, streams): return {k: self.reduce(v, streams) for (k, v) in body.items()} - @reduce.register(tuple) # type: ignore[attr-defined] - def _reduce_sequence(self, body: tuple, streams): - return tuple(self.reduce(x, streams) for x in body) # type:ignore[call-arg] + def _reduce_sequence(self, body, streams): + return type(body)(self.reduce(x, streams) for x in body) - @reduce.register(Generator) # type: ignore[attr-defined] - def _reduce_generator(self, body: Generator, streams): + def _reduce_generator(self, body, streams): return (self.reduce(x, streams) for x in body) + # --- public operations ------------------------------------------------- + + @Operation.define + def plus[S](self, *args: S) -> S: + """Monoid addition with broadcasting over common collection types, + callables, and interpretations. + """ + if not args: + return typing.cast(S, self.identity) + return self._plus_dispatch.dispatch(type(args[0]))(*args) + + @Operation.define + def reduce[A, B, U: Body]( + self, + body: Annotated[U, Scoped[A | B]], + streams: Annotated[Streams, Scoped[A]], + ) -> Annotated[U, Scoped[B]]: + """Reduce ``body`` over ``streams``.""" + return self._reduce_dispatch.dispatch(type(body))(body, streams) + def _is_monoid_plus(op: Operation) -> bool: """True if ``op`` is the ``plus`` operation of some :class:`Monoid`.""" @@ -190,84 +211,96 @@ def __init__(self, name: str, identity: T, zero: T): self.zero = zero +def _register_broadcasting(monoid: "Monoid") -> None: + """Register elementwise broadcasting over common collection types. + + Standard monoids treat tuples/lists/generators/mappings as containers to + broadcast over. Monoids whose values *are* collections (e.g. + :data:`CartesianProduct`) opt out. + """ + monoid.plus.register(tuple)(monoid._plus_tuple) + monoid.plus.register(Generator)(monoid._plus_generator) + monoid.plus.register(Mapping)(monoid._plus_mapping) + monoid.reduce.register(Callable)(monoid._reduce_callable) + monoid.reduce.register(Mapping)(monoid._reduce_mapping) + monoid.reduce.register(tuple)(monoid._reduce_sequence) + monoid.reduce.register(list)(monoid._reduce_sequence) + monoid.reduce.register(Generator)(monoid._reduce_generator) + + +_cartesian_product_id_op = Operation.define(object, name="CartesianProductId") + Min = Monoid(name="Min", identity=float("inf")) Max = Monoid(name="Max", identity=-float("inf")) ArgMin = Monoid(name="ArgMin", identity=(Min.identity, None)) ArgMax = Monoid(name="ArgMax", identity=(Max.identity, None)) Sum = Monoid(name="Sum", identity=0) Product = MonoidWithZero(name="Product", identity=1, zero=0) -CartesianProduct = Monoid( - name="CartesianProduct", - identity=Operation.define(object, name="CartesianProductId"), -) +CartesianProduct = Monoid(name="CartesianProduct", identity=_cartesian_product_id_op()) + +for _m in (Min, Max, ArgMin, ArgMax, Sum, Product): + _register_broadcasting(_m) @Min.plus.register(int | float) -def _(self, *args): +def _(*args): if any(isinstance(x, Term) for x in args): raise NotHandled - return min(*args) + return min(args) @Max.plus.register(int | float) -def _(self, *args): +def _(*args): if any(isinstance(x, Term) for x in args): raise NotHandled - return max(*args) + return max(args) -@Min.plus.register(int | float) -def _(self, *args): +@Sum.plus.register(int | float) +def _(*args): if any(isinstance(x, Term) for x in args): raise NotHandled - return min(*args) + return sum(args) -@Max.plus.register(int | float) -def _(self, *args): +@Product.plus.register(int | float) +def _(*args): if any(isinstance(x, Term) for x in args): raise NotHandled - return max(*args) + return functools.reduce(operator.mul, args) -@Sum.plus.register(int | float) -def _(self, *args): - if any(isinstance(x, Term) for x in args): +# ArgMin / ArgMax: override the default tuple broadcasting with score comparison. +@ArgMin.plus.register(tuple) +def _(*args): + if any(isinstance(a[0], Term) for a in args): + raise NotHandled + if not all(isinstance(a[0], int | float) for a in args): raise NotHandled - return sum(*args) + return min(args, key=lambda a: a[0]) -@Product.plus.register(int | float) -def _(self, *args): - if any(isinstance(x, Term) for x in args): +@ArgMax.plus.register(tuple) +def _(*args): + if any(isinstance(a[0], Term) for a in args): raise NotHandled - return functools.reduce(operator.mul, args) + if not all(isinstance(a[0], int | float) for a in args): + raise NotHandled + return max(args, key=lambda a: a[0]) +# CartesianProduct skips ``_register_broadcasting`` because tuples and lists are +# CP values, not containers to broadcast over. Its plus is registered against +# Iterable; its reduce falls through to the default ground-stream rule. @CartesianProduct.plus.register(Iterable) -def _(self, *args): +def _(*args): if any(isinstance(x, Term) for x in args): raise NotHandled def to_tuple(x): return x if isinstance(x, tuple) else (x,) - return [sum(to_tuple(x) for x in vals) for vals in itertools.product(*args)] - - -@dataclass -class _ExtensiblePredicate[T]: - elems: set[T] - - def register(self, t: T) -> None: - self.elems.add(t) - - def __call__(self, t: T) -> bool: - return t in self.elems - - -is_commutative = _ExtensiblePredicate({Max, Min, Sum, Product}) -is_idempotent = _ExtensiblePredicate({Max, Min}) + return [sum((to_tuple(v) for v in vals), ()) for vals in itertools.product(*args)] @dataclass @@ -301,98 +334,6 @@ def __call__(self, s: S, t: T) -> bool: ) -class SumKernel(ObjectInterpretation): - @implements(Sum.identity) - def identity(self): - return 0 - - @implements(Sum.kernel) - def kernel(self, x, y): - return x + y - - -class ProductKernel(ObjectInterpretation): - @implements(Product.identity) - def identity(self): - return 1 - - @implements(Product.zero) - def zero(self): - return 0 - - @implements(Product.kernel) - def kernel(self, x, y): - return x * y - - -class MinKernel(ObjectInterpretation): - @implements(Min.identity) - def identity(self): - return float("inf") - - @implements(Min.kernel) - def kernel(self, x, y): - if isinstance(x, int | float) and isinstance(y, int | float): - return min(x, y) - return fwd() - - -class MaxKernel(ObjectInterpretation): - @implements(Min.identity) - def identity(self): - return -float("inf") - - @implements(Min.kernel) - def kernel(self, x, y): - if isinstance(x, int | float) and isinstance(y, int | float): - return max(x, y) - return fwd() - - -class ArgMinKernel(ObjectInterpretation): - @implements(ArgMin.identity) - def identity(self): - return (float("inf"), None) - - @implements(ArgMin.kernel) - def kernel(self, a, b): - if isinstance(a[0], Term) or isinstance(b[0], Term): - return fwd() - if isinstance(a[0], int | float) and isinstance(b[0], int | float): - return b if b[0] < a[0] else a - return fwd() - - -class ArgMaxKernel(ObjectInterpretation): - @implements(ArgMax.identity) - def identity(self): - return (-float("inf"), None) - - @implements(ArgMax.kernel) - def kernel(self, a, b): - if isinstance(a[0], Term) or isinstance(b[0], Term): - return fwd() - if isinstance(a[0], int | float) and isinstance(b[0], int | float): - return b if b[0] < a[0] else a - return fwd() - - -class CartesianProductKernel(ObjectInterpretation): - @implements(CartesianProduct.kernel) - def kernel(self, a, b): - if isinstance(a, Term) or isinstance(b, Term): - raise NotHandled - - if isinstance(a, Iterable) and isinstance(b, Iterable): - - def to_tuple(x): - return x if isinstance(x, tuple) else (x,) - - return [to_tuple(x) + to_tuple(y) for (x, y) in itertools.product(a, b)] - - return fwd() - - class PlusEmpty(ObjectInterpretation): """plus() = 0""" From dd836bdd84cec6171b8bc7ba57760b32b80952c6 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Wed, 13 May 2026 16:38:48 -0400 Subject: [PATCH 18/34] wip --- effectful/ops/monoid.py | 104 +++++++++++++++++++++++++--------------- 1 file changed, 66 insertions(+), 38 deletions(-) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 551729e06..69be4228d 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -10,7 +10,7 @@ from typing import Annotated, Any from effectful.internals.disjoint_set import DisjointSet -from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler +from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler, typeof from effectful.ops.syntax import ( ObjectInterpretation, Scoped, @@ -104,11 +104,15 @@ def _plus_default(self, *args): raise NotHandled raise TypeError(f"Unexpected arguments to {self._name}.plus") - def _plus_tuple(self, *args): - return tuple(self.plus(*vs) for vs in zip(*args, strict=True)) - - def _plus_generator(self, *args): - return (self.plus(*vs) for vs in zip(*args, strict=True)) + def _plus_iterable(self, *args): + """Broadcast plus elementwise over an iterable, preserving input type + for tuples and lists; generators stay generators. + """ + zipped = zip(*args, strict=True) + result = (self.plus(*vs) for vs in zipped) + if isinstance(args[0], tuple | list): + return type(args[0])(result) + return result def _plus_mapping(self, *args): if isinstance(args[0], Interpretation): @@ -164,11 +168,14 @@ def _reduce_callable(self, body, streams): def _reduce_mapping(self, body, streams): return {k: self.reduce(v, streams) for (k, v) in body.items()} - def _reduce_sequence(self, body, streams): - return type(body)(self.reduce(x, streams) for x in body) - - def _reduce_generator(self, body, streams): - return (self.reduce(x, streams) for x in body) + def _reduce_iterable(self, body, streams): + """Broadcast reduce elementwise over an iterable, preserving input type + for tuples and lists; generators stay generators. + """ + result = (self.reduce(x, streams) for x in body) + if isinstance(body, tuple | list): + return type(body)(result) + return result # --- public operations ------------------------------------------------- @@ -188,19 +195,7 @@ def reduce[A, B, U: Body]( streams: Annotated[Streams, Scoped[A]], ) -> Annotated[U, Scoped[B]]: """Reduce ``body`` over ``streams``.""" - return self._reduce_dispatch.dispatch(type(body))(body, streams) - - -def _is_monoid_plus(op: Operation) -> bool: - """True if ``op`` is the ``plus`` operation of some :class:`Monoid`.""" - owner = getattr(op, "__self__", None) - return isinstance(owner, Monoid) and op is owner.plus - - -def _is_monoid_reduce(op: Operation) -> bool: - """True if ``op`` is the ``reduce`` operation of some :class:`Monoid`.""" - owner = getattr(op, "__self__", None) - return isinstance(owner, Monoid) and op is owner.reduce + return self._reduce_dispatch.dispatch(typeof(body))(body, streams) class MonoidWithZero[T](Monoid[T]): @@ -211,21 +206,36 @@ def __init__(self, name: str, identity: T, zero: T): self.zero = zero -def _register_broadcasting(monoid: "Monoid") -> None: - """Register elementwise broadcasting over common collection types. +def _register_sequence_broadcasting(monoid: "Monoid") -> None: + """Register elementwise broadcasting over tuples, lists, and generators. + + Appropriate when the monoid's values are scalars and these collections are + containers to broadcast over. Skipped by monoids whose values *are* + sequences (e.g. :data:`CartesianProduct`) or whose tuples carry meaning + other than "container" (e.g. :data:`ArgMin` / :data:`ArgMax`, where a + tuple is a (score, value) pair). + """ + monoid.plus.register(tuple | list | Generator)(monoid._plus_iterable) + monoid.reduce.register(tuple | list | Generator)(monoid._reduce_iterable) + + +def _register_mapping_broadcasting(monoid: "Monoid") -> None: + """Register broadcasting over dict-like containers and interpretations. - Standard monoids treat tuples/lists/generators/mappings as containers to - broadcast over. Monoids whose values *are* collections (e.g. - :data:`CartesianProduct`) opt out. + Safe for any monoid: mappings carry one value per key, and broadcasting + merges per-key values via the monoid. """ - monoid.plus.register(tuple)(monoid._plus_tuple) - monoid.plus.register(Generator)(monoid._plus_generator) monoid.plus.register(Mapping)(monoid._plus_mapping) - monoid.reduce.register(Callable)(monoid._reduce_callable) monoid.reduce.register(Mapping)(monoid._reduce_mapping) - monoid.reduce.register(tuple)(monoid._reduce_sequence) - monoid.reduce.register(list)(monoid._reduce_sequence) - monoid.reduce.register(Generator)(monoid._reduce_generator) + + +def _register_callable_broadcasting(monoid: "Monoid") -> None: + """Register lifting of :meth:`reduce` under callables. + + ``monoid.reduce(f, streams)`` becomes ``lambda *a: monoid.reduce(f(*a), + streams)``. Safe for any monoid. + """ + monoid.reduce.register(Callable)(monoid._reduce_callable) _cartesian_product_id_op = Operation.define(object, name="CartesianProductId") @@ -238,8 +248,14 @@ def _register_broadcasting(monoid: "Monoid") -> None: Product = MonoidWithZero(name="Product", identity=1, zero=0) CartesianProduct = Monoid(name="CartesianProduct", identity=_cartesian_product_id_op()) -for _m in (Min, Max, ArgMin, ArgMax, Sum, Product): - _register_broadcasting(_m) +# Scalar-valued monoids: tuples/lists/generators are containers to broadcast over. +for _m in (Min, Max, Sum, Product): + _register_sequence_broadcasting(_m) + +# Mapping and callable broadcasting are safe for every monoid. +for _m in (Min, Max, ArgMin, ArgMax, Sum, Product, CartesianProduct): + _register_mapping_broadcasting(_m) + _register_callable_broadcasting(_m) @Min.plus.register(int | float) @@ -270,7 +286,7 @@ def _(*args): return functools.reduce(operator.mul, args) -# ArgMin / ArgMax: override the default tuple broadcasting with score comparison. +# ArgMin / ArgMax: tuples are (score, value) pairs; plus compares by score. @ArgMin.plus.register(tuple) def _(*args): if any(isinstance(a[0], Term) for a in args): @@ -334,6 +350,18 @@ def __call__(self, s: S, t: T) -> bool: ) +def _is_monoid_plus(op: Operation) -> bool: + """True if ``op`` is the ``plus`` operation of some :class:`Monoid`.""" + owner = getattr(op, "__self__", None) + return isinstance(owner, Monoid) and op is owner.plus + + +def _is_monoid_reduce(op: Operation) -> bool: + """True if ``op`` is the ``reduce`` operation of some :class:`Monoid`.""" + owner = getattr(op, "__self__", None) + return isinstance(owner, Monoid) and op is owner.reduce + + class PlusEmpty(ObjectInterpretation): """plus() = 0""" From fb92486ccad721a46de38fe3587c5fc41e587830 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Wed, 13 May 2026 17:10:23 -0400 Subject: [PATCH 19/34] wip --- effectful/handlers/jax/monoid.py | 57 +++++++++++++++++++++++++------- effectful/ops/monoid.py | 31 ++++------------- 2 files changed, 52 insertions(+), 36 deletions(-) diff --git a/effectful/handlers/jax/monoid.py b/effectful/handlers/jax/monoid.py index f36c16714..c9d088a15 100644 --- a/effectful/handlers/jax/monoid.py +++ b/effectful/handlers/jax/monoid.py @@ -6,6 +6,8 @@ from effectful.handlers.jax import bind_dims, unbind_dims from effectful.handlers.jax.scipy.special import logsumexp from effectful.ops.monoid import ( + ArgMax, + ArgMin, CartesianProduct, Max, Min, @@ -62,18 +64,46 @@ def _(*args): return functools.reduce(cartesian_prod, args) -ARRAY_REDUCE = { - Sum.plus: jnp.sum, - Product.plus: jnp.prod, - Min.plus: jnp.min, - Max.plus: jnp.max, - LogSumExp.plus: logsumexp, -} +# Chain a JAX-typed case onto the tuple plus handler: handle (jax.Array, +# jax.Array) pairs here, fall through to the prior handler otherwise. +_argmin_tuple_prior = ArgMin.plus.dispatch(tuple) +_argmax_tuple_prior = ArgMax.plus.dispatch(tuple) -def _reduce_array_for(monoid: Monoid): +@ArgMin.plus.register(tuple) +def _(*args): + if all(isinstance(a[0], jax.Array) and isinstance(a[1], jax.Array) for a in args): + best_score, best_value = args[0] + for score, value in args[1:]: + is_new = score < best_score + best_score = jnp.where(is_new, score, best_score) + best_value = jnp.where(is_new, value, best_value) + return (best_score, best_value) + return _argmin_tuple_prior(*args) + + +@ArgMax.plus.register(tuple) +def _(*args): + if all(isinstance(a[0], jax.Array) and isinstance(a[1], jax.Array) for a in args): + best_score, best_value = args[0] + for score, value in args[1:]: + is_new = score > best_score + best_score = jnp.where(is_new, score, best_score) + best_value = jnp.where(is_new, value, best_value) + return (best_score, best_value) + return _argmax_tuple_prior(*args) + + +def register_array_reduce(monoid: Monoid, reductor) -> None: + """Register ``reductor`` as the JAX array reduction for ``monoid``. + + A backend-specific reducer (e.g. :func:`jnp.sum`) is paired with a monoid so + that ``monoid.reduce`` over an array-valued body and array-valued streams + unbinds the indices, evaluates, and applies the reducer along those axes. + + """ + def _reduce_array(body: jax.Array, streams: Streams): - reductor = ARRAY_REDUCE[monoid.plus] index = Operation.define(jax.Array) if not streams: @@ -100,8 +130,11 @@ def _reduce_array(body: jax.Array, streams: Streams): raise NotHandled - return _reduce_array + monoid.reduce.register(jax.Array)(_reduce_array) -for _m in (Sum, Product, Min, Max, LogSumExp): - _m.reduce.register(jax.Array)(_reduce_array_for(_m)) +register_array_reduce(Sum, jnp.sum) +register_array_reduce(Product, jnp.prod) +register_array_reduce(Min, jnp.min) +register_array_reduce(Max, jnp.max) +register_array_reduce(LogSumExp, logsumexp) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 69be4228d..f43664b8b 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -97,11 +97,7 @@ def __eq__(self, other): def __hash__(self): return hash(id(self)) - # --- plus default registrations ---------------------------------------- - def _plus_default(self, *args): - if any(isinstance(x, Term) for x in args): - raise NotHandled raise TypeError(f"Unexpected arguments to {self._name}.plus") def _plus_iterable(self, *args): @@ -135,7 +131,6 @@ def _plus_mapping(self, *args): all_values[k].append(v) return {k: self.plus(*vs) for (k, vs) in all_values.items()} - # --- reduce default registrations -------------------------------------- def _reduce_default(self, body, streams): if not streams: return self.identity @@ -183,9 +178,16 @@ def _reduce_iterable(self, body, streams): def plus[S](self, *args: S) -> S: """Monoid addition with broadcasting over common collection types, callables, and interpretations. + + Any :class:`Term` arg routes to symbolic evaluation. Registered + handlers can therefore assume their args are concrete values. + Composite handlers (tuple/list, Mapping) recurse through + :meth:`plus` so interior Terms are caught at the next call. """ if not args: return typing.cast(S, self.identity) + if any(isinstance(x, Term) for x in args): + raise NotHandled return self._plus_dispatch.dispatch(type(args[0]))(*args) @Operation.define @@ -260,37 +262,26 @@ def _register_callable_broadcasting(monoid: "Monoid") -> None: @Min.plus.register(int | float) def _(*args): - if any(isinstance(x, Term) for x in args): - raise NotHandled return min(args) @Max.plus.register(int | float) def _(*args): - if any(isinstance(x, Term) for x in args): - raise NotHandled return max(args) @Sum.plus.register(int | float) def _(*args): - if any(isinstance(x, Term) for x in args): - raise NotHandled return sum(args) @Product.plus.register(int | float) def _(*args): - if any(isinstance(x, Term) for x in args): - raise NotHandled return functools.reduce(operator.mul, args) -# ArgMin / ArgMax: tuples are (score, value) pairs; plus compares by score. @ArgMin.plus.register(tuple) def _(*args): - if any(isinstance(a[0], Term) for a in args): - raise NotHandled if not all(isinstance(a[0], int | float) for a in args): raise NotHandled return min(args, key=lambda a: a[0]) @@ -298,21 +289,13 @@ def _(*args): @ArgMax.plus.register(tuple) def _(*args): - if any(isinstance(a[0], Term) for a in args): - raise NotHandled if not all(isinstance(a[0], int | float) for a in args): raise NotHandled return max(args, key=lambda a: a[0]) -# CartesianProduct skips ``_register_broadcasting`` because tuples and lists are -# CP values, not containers to broadcast over. Its plus is registered against -# Iterable; its reduce falls through to the default ground-stream rule. @CartesianProduct.plus.register(Iterable) def _(*args): - if any(isinstance(x, Term) for x in args): - raise NotHandled - def to_tuple(x): return x if isinstance(x, tuple) else (x,) From 622d4ac7e652c0416b9cb9910e42128c58702bbc Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Wed, 13 May 2026 17:30:10 -0400 Subject: [PATCH 20/34] wip --- effectful/handlers/jax/monoid.py | 12 +++++++++++- effectful/ops/monoid.py | 9 ++++++--- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/effectful/handlers/jax/monoid.py b/effectful/handlers/jax/monoid.py index c9d088a15..ef4d1cecd 100644 --- a/effectful/handlers/jax/monoid.py +++ b/effectful/handlers/jax/monoid.py @@ -61,7 +61,17 @@ def _(*args): @CartesianProduct.plus.register(jax.Array) def _(*args): - return functools.reduce(cartesian_prod, args) + # Skip identity ``[()]`` args; short-circuit on zero ``[]``. Both sentinels + # arrive as Python lists alongside jax-array factors, so check for them + # explicitly before composing under :func:`cartesian_prod`. + result = None + for a in args: + if a is CartesianProduct.zero: + return CartesianProduct.zero + if a is CartesianProduct.identity: + continue + result = a if result is None else cartesian_prod(result, a) + return result if result is not None else CartesianProduct.identity # Chain a JAX-typed case onto the tuple plus handler: handle (jax.Array, diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index f43664b8b..9d7463b4b 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -240,15 +240,18 @@ def _register_callable_broadcasting(monoid: "Monoid") -> None: monoid.reduce.register(Callable)(monoid._reduce_callable) -_cartesian_product_id_op = Operation.define(object, name="CartesianProductId") - Min = Monoid(name="Min", identity=float("inf")) Max = Monoid(name="Max", identity=-float("inf")) ArgMin = Monoid(name="ArgMin", identity=(Min.identity, None)) ArgMax = Monoid(name="ArgMax", identity=(Max.identity, None)) Sum = Monoid(name="Sum", identity=0) Product = MonoidWithZero(name="Product", identity=1, zero=0) -CartesianProduct = Monoid(name="CartesianProduct", identity=_cartesian_product_id_op()) +# CartesianProduct values are "two-level indexable" (rows × positions). The +# identity ``[()]`` is one row of zero positions (composing with it preserves +# shape); the zero ``[]`` is no rows (absorbs under product). +CartesianProduct = MonoidWithZero( + name="CartesianProduct", identity=[()], zero=[] +) # Scalar-valued monoids: tuples/lists/generators are containers to broadcast over. for _m in (Min, Max, Sum, Product): From 9db88d7c84960a04f10aac905e9773d90a1cb4d2 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Wed, 13 May 2026 17:54:47 -0400 Subject: [PATCH 21/34] wip --- effectful/handlers/jax/monoid.py | 240 +++++++++-------- effectful/ops/monoid.py | 435 +++++++++++++++++------------- tests/test_handlers_jax_monoid.py | 18 +- tests/test_ops_monoid.py | 15 +- 4 files changed, 400 insertions(+), 308 deletions(-) diff --git a/effectful/handlers/jax/monoid.py b/effectful/handlers/jax/monoid.py index ef4d1cecd..54cd700cd 100644 --- a/effectful/handlers/jax/monoid.py +++ b/effectful/handlers/jax/monoid.py @@ -1,4 +1,5 @@ import functools +import typing import jax @@ -6,8 +7,6 @@ from effectful.handlers.jax import bind_dims, unbind_dims from effectful.handlers.jax.scipy.special import logsumexp from effectful.ops.monoid import ( - ArgMax, - ArgMin, CartesianProduct, Max, Min, @@ -17,9 +16,9 @@ Sum, outer_stream, ) -from effectful.ops.semantics import evaluate, handler, typeof -from effectful.ops.syntax import deffn -from effectful.ops.types import NotHandled, Operation +from effectful.ops.semantics import coproduct, evaluate, fwd, handler, typeof +from effectful.ops.syntax import ObjectInterpretation, deffn, implements +from effectful.ops.types import Interpretation, NotHandled, Operation, Term def cartesian_prod(x, y): @@ -34,117 +33,128 @@ def cartesian_prod(x, y): LogSumExp = Monoid(name="LogSumExp", identity=jnp.asarray(float("-inf"))) -@Sum.plus.register(jax.Array) -def _(*args): - return functools.reduce(jnp.add, args) - - -@Product.plus.register(jax.Array) -def _(*args): - return functools.reduce(jnp.multiply, args) - - -@Min.plus.register(jax.Array) -def _(*args): - return functools.reduce(jnp.minimum, args) - - -@Max.plus.register(jax.Array) -def _(*args): - return functools.reduce(jnp.maximum, args) - - -@LogSumExp.plus.register(jax.Array) -def _(*args): - return functools.reduce(jnp.logaddexp, args) - - -@CartesianProduct.plus.register(jax.Array) -def _(*args): - # Skip identity ``[()]`` args; short-circuit on zero ``[]``. Both sentinels - # arrive as Python lists alongside jax-array factors, so check for them - # explicitly before composing under :func:`cartesian_prod`. - result = None - for a in args: - if a is CartesianProduct.zero: - return CartesianProduct.zero - if a is CartesianProduct.identity: - continue - result = a if result is None else cartesian_prod(result, a) - return result if result is not None else CartesianProduct.identity - - -# Chain a JAX-typed case onto the tuple plus handler: handle (jax.Array, -# jax.Array) pairs here, fall through to the prior handler otherwise. -_argmin_tuple_prior = ArgMin.plus.dispatch(tuple) -_argmax_tuple_prior = ArgMax.plus.dispatch(tuple) - - -@ArgMin.plus.register(tuple) -def _(*args): - if all(isinstance(a[0], jax.Array) and isinstance(a[1], jax.Array) for a in args): - best_score, best_value = args[0] - for score, value in args[1:]: - is_new = score < best_score - best_score = jnp.where(is_new, score, best_score) - best_value = jnp.where(is_new, value, best_value) - return (best_score, best_value) - return _argmin_tuple_prior(*args) - - -@ArgMax.plus.register(tuple) -def _(*args): - if all(isinstance(a[0], jax.Array) and isinstance(a[1], jax.Array) for a in args): - best_score, best_value = args[0] - for score, value in args[1:]: - is_new = score > best_score - best_score = jnp.where(is_new, score, best_score) - best_value = jnp.where(is_new, value, best_value) - return (best_score, best_value) - return _argmax_tuple_prior(*args) - - -def register_array_reduce(monoid: Monoid, reductor) -> None: - """Register ``reductor`` as the JAX array reduction for ``monoid``. - - A backend-specific reducer (e.g. :func:`jnp.sum`) is paired with a monoid so - that ``monoid.reduce`` over an array-valued body and array-valued streams - unbinds the indices, evaluates, and applies the reducer along those axes. - - """ - - def _reduce_array(body: jax.Array, streams: Streams): - index = Operation.define(jax.Array) - - if not streams: - return monoid.identity - - # find and reduce an array stream - for stream_key, stream_body, streams_tail in outer_stream(streams): - if typeof(stream_body) != jax.Array: +def _jax_args(args): + """True iff every arg is a concrete :class:`jax.Array` (no Terms).""" + return all(isinstance(a, jax.Array) and not isinstance(a, Term) for a in args) + + +class SumKernelJax(ObjectInterpretation): + @implements(Sum.plus) + def plus(self, *args): + if not _jax_args(args): + return fwd() + return functools.reduce(jnp.add, args) + + +class ProductKernelJax(ObjectInterpretation): + @implements(Product.plus) + def plus(self, *args): + if not _jax_args(args): + return fwd() + return functools.reduce(jnp.multiply, args) + + +class MinKernelJax(ObjectInterpretation): + @implements(Min.plus) + def plus(self, *args): + if not _jax_args(args): + return fwd() + return functools.reduce(jnp.minimum, args) + + +class MaxKernelJax(ObjectInterpretation): + @implements(Max.plus) + def plus(self, *args): + if not _jax_args(args): + return fwd() + return functools.reduce(jnp.maximum, args) + + +class LogSumExpKernelJax(ObjectInterpretation): + @implements(LogSumExp.plus) + def plus(self, *args): + if not _jax_args(args): + return fwd() + return functools.reduce(jnp.logaddexp, args) + + +class CartesianProductKernelJax(ObjectInterpretation): + @implements(CartesianProduct.plus) + def plus(self, *args): + # Skip identity ``[()]`` args; short-circuit on zero ``[]``. Both + # sentinels arrive as Python lists alongside jax-array factors, so + # check for them explicitly before composing. + if not any(isinstance(a, jax.Array) for a in args): + return fwd() + result = None + for a in args: + if a is CartesianProduct.zero: + return CartesianProduct.zero + if a is CartesianProduct.identity: continue + if not isinstance(a, jax.Array): + return fwd() + result = a if result is None else cartesian_prod(result, a) + return result if result is not None else CartesianProduct.identity - with handler({stream_key: deffn(unbind_dims(stream_body, index))}): - (eval_body, eval_streams_tail) = ( - evaluate(body), - evaluate(streams_tail), - ) - assert isinstance(eval_streams_tail, dict) - - reduce_tail = ( - monoid.reduce(eval_body, eval_streams_tail) - if len(eval_streams_tail) > 0 - else eval_body - ) - return reductor(bind_dims(reduce_tail, index), axis=0) - - raise NotHandled - - monoid.reduce.register(jax.Array)(_reduce_array) +def _make_array_reduce_class(monoid: Monoid, reductor): + """Build an :class:`ObjectInterpretation` that implements + ``monoid.reduce`` for ``jax.Array`` bodies using ``reductor``. + """ -register_array_reduce(Sum, jnp.sum) -register_array_reduce(Product, jnp.prod) -register_array_reduce(Min, jnp.min) -register_array_reduce(Max, jnp.max) -register_array_reduce(LogSumExp, logsumexp) + class _ArrayReduce(ObjectInterpretation): + @implements(monoid.reduce) + def reduce(self, body, streams): + if typeof(body) is not jax.Array: + return fwd() + if not streams: + return monoid.identity + + index = Operation.define(jax.Array) + for stream_key, stream_body, streams_tail in outer_stream(streams): + if typeof(stream_body) is not jax.Array: + continue + with handler({stream_key: deffn(unbind_dims(stream_body, index))}): + eval_body = evaluate(body) + eval_streams_tail = evaluate(streams_tail) + assert isinstance(eval_streams_tail, dict) + reduce_tail = ( + monoid.reduce(eval_body, eval_streams_tail) + if len(eval_streams_tail) > 0 + else eval_body + ) + return reductor(bind_dims(reduce_tail, index), axis=0) + return fwd() + + _ArrayReduce.__name__ = f"{monoid._name}ArrayReduceJax" + return _ArrayReduce + + +_ARRAY_REDUCE_CLASSES = [ + _make_array_reduce_class(Sum, jnp.sum), + _make_array_reduce_class(Product, jnp.prod), + _make_array_reduce_class(Min, jnp.min), + _make_array_reduce_class(Max, jnp.max), + _make_array_reduce_class(LogSumExp, logsumexp), +] + + +JaxEvaluateIntp = functools.reduce( + coproduct, + typing.cast( + list[Interpretation], + [ + SumKernelJax(), + ProductKernelJax(), + MinKernelJax(), + MaxKernelJax(), + LogSumExpKernelJax(), + CartesianProductKernelJax(), + *[cls() for cls in _ARRAY_REDUCE_CLASSES], + ], + ), +) +"""JAX kernels for plus and reduce. Composes with +:data:`effectful.ops.monoid.EvaluateIntp` to extend evaluation to JAX arrays. +""" diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 9d7463b4b..ff678dbcc 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -59,13 +59,12 @@ def outer_stream( class Monoid[T]: - """A monoid with per-instance dispatch tables for ``plus`` and ``reduce``. + """A monoid with ``plus`` and ``reduce`` :class:`Operation` s. - Each instance owns its own :func:`functools.singledispatch` registries. - Backends and call sites extend behavior via ``instance.plus.register(type)`` - and ``instance.reduce.register(type)``. Rewrites that key on the class-level - ``Monoid.plus`` / ``Monoid.reduce`` operations still fire, because - per-instance operations delegate to them. + Behavior is supplied by handlers installed via an interpretation + (see :data:`NormalizeIntp`). The default rules handle only the trivial + cases (empty plus, empty streams, Term args) — everything else stays + symbolic until a handler interprets it. """ _name: str @@ -75,19 +74,6 @@ def __init__(self, identity: T, name: str): self._name = name self.identity = identity - # per-instance dispatch tables - self._plus_dispatch = functools.singledispatch(self._plus_default) - self._reduce_dispatch = functools.singledispatch(self._reduce_default) - - # expose register/dispatch on the per-instance cached operations - plus_op = type(self).plus.__get__(self, type(self)) - plus_op.register = self._plus_dispatch.register - plus_op.dispatch = self._plus_dispatch.dispatch - - reduce_op = type(self).reduce.__get__(self, type(self)) - reduce_op.register = self._reduce_dispatch.register - reduce_op.dispatch = self._reduce_dispatch.dispatch - def __repr__(self): return f"Monoid({self._name!r})" @@ -97,108 +83,42 @@ def __eq__(self, other): def __hash__(self): return hash(id(self)) - def _plus_default(self, *args): - raise TypeError(f"Unexpected arguments to {self._name}.plus") - - def _plus_iterable(self, *args): - """Broadcast plus elementwise over an iterable, preserving input type - for tuples and lists; generators stay generators. + @Operation.define + def plus[S](self, *args: S) -> S: + """Monoid addition. Handlers supply per-monoid and broadcasting + behavior; the default rule only handles empty / Term cases. """ - zipped = zip(*args, strict=True) - result = (self.plus(*vs) for vs in zipped) - if isinstance(args[0], tuple | list): - return type(args[0])(result) - return result - - def _plus_mapping(self, *args): - if isinstance(args[0], Interpretation): - keys = args[0].keys() - for b in args[1:]: - if not isinstance(b, Interpretation): - raise TypeError(f"Expected interpretation but got {b}") - if not keys == b.keys(): - raise ValueError( - f"Expected interpretation of {keys} but got {b.keys()}" - ) - return {k: self.plus(*(handler(b)(b[k]) for b in args)) for k in keys} - - for b in args[1:]: - if not isinstance(b, Mapping): - raise TypeError(f"Expected mapping but got {b}") - all_values = collections.defaultdict(list) - for d in args: - for k, v in d.items(): - all_values[k].append(v) - return {k: self.plus(*vs) for (k, vs) in all_values.items()} - - def _reduce_default(self, body, streams): - if not streams: - return self.identity + if not args: + return typing.cast(S, self.identity) + if any(isinstance(x, Term) for x in args): + raise NotHandled + raise NotHandled - # find and reduce a ground stream + @Operation.define + def reduce[A, B, U: Body]( + self, + body: Annotated[U, Scoped[A | B]], + streams: Annotated[Streams, Scoped[A]], + ) -> Annotated[U, Scoped[B]]: + """Reduce ``body`` over ``streams``. Handlers supply per-monoid and + broadcasting behavior; the default rule only handles the empty-stream + case. + """ for stream_key, stream_body, streams_tail in outer_stream(streams): if isinstance(stream_body, Term): continue - stream_values_iter = iter(stream_body) - - # if we iterate and get a term instead of a real iterator, skip if isinstance(stream_values_iter, Term) and stream_values_iter.op is iter_: continue - new_reduces = [] for stream_val in stream_values_iter: with handler({stream_key: deffn(stream_val)}): eval_args = evaluate((body, streams_tail)) assert isinstance(eval_args, tuple) new_reduces.append(self.reduce(*eval_args)) - return self.plus(*new_reduces) - raise NotHandled - def _reduce_callable(self, body, streams): - return lambda *a, **k: self.reduce(body(*a, **k), streams) - - def _reduce_mapping(self, body, streams): - return {k: self.reduce(v, streams) for (k, v) in body.items()} - - def _reduce_iterable(self, body, streams): - """Broadcast reduce elementwise over an iterable, preserving input type - for tuples and lists; generators stay generators. - """ - result = (self.reduce(x, streams) for x in body) - if isinstance(body, tuple | list): - return type(body)(result) - return result - - # --- public operations ------------------------------------------------- - - @Operation.define - def plus[S](self, *args: S) -> S: - """Monoid addition with broadcasting over common collection types, - callables, and interpretations. - - Any :class:`Term` arg routes to symbolic evaluation. Registered - handlers can therefore assume their args are concrete values. - Composite handlers (tuple/list, Mapping) recurse through - :meth:`plus` so interior Terms are caught at the next call. - """ - if not args: - return typing.cast(S, self.identity) - if any(isinstance(x, Term) for x in args): - raise NotHandled - return self._plus_dispatch.dispatch(type(args[0]))(*args) - - @Operation.define - def reduce[A, B, U: Body]( - self, - body: Annotated[U, Scoped[A | B]], - streams: Annotated[Streams, Scoped[A]], - ) -> Annotated[U, Scoped[B]]: - """Reduce ``body`` over ``streams``.""" - return self._reduce_dispatch.dispatch(typeof(body))(body, streams) - class MonoidWithZero[T](Monoid[T]): zero: T @@ -208,38 +128,6 @@ def __init__(self, name: str, identity: T, zero: T): self.zero = zero -def _register_sequence_broadcasting(monoid: "Monoid") -> None: - """Register elementwise broadcasting over tuples, lists, and generators. - - Appropriate when the monoid's values are scalars and these collections are - containers to broadcast over. Skipped by monoids whose values *are* - sequences (e.g. :data:`CartesianProduct`) or whose tuples carry meaning - other than "container" (e.g. :data:`ArgMin` / :data:`ArgMax`, where a - tuple is a (score, value) pair). - """ - monoid.plus.register(tuple | list | Generator)(monoid._plus_iterable) - monoid.reduce.register(tuple | list | Generator)(monoid._reduce_iterable) - - -def _register_mapping_broadcasting(monoid: "Monoid") -> None: - """Register broadcasting over dict-like containers and interpretations. - - Safe for any monoid: mappings carry one value per key, and broadcasting - merges per-key values via the monoid. - """ - monoid.plus.register(Mapping)(monoid._plus_mapping) - monoid.reduce.register(Mapping)(monoid._reduce_mapping) - - -def _register_callable_broadcasting(monoid: "Monoid") -> None: - """Register lifting of :meth:`reduce` under callables. - - ``monoid.reduce(f, streams)`` becomes ``lambda *a: monoid.reduce(f(*a), - streams)``. Safe for any monoid. - """ - monoid.reduce.register(Callable)(monoid._reduce_callable) - - Min = Monoid(name="Min", identity=float("inf")) Max = Monoid(name="Max", identity=-float("inf")) ArgMin = Monoid(name="ArgMin", identity=(Min.identity, None)) @@ -249,60 +137,7 @@ def _register_callable_broadcasting(monoid: "Monoid") -> None: # CartesianProduct values are "two-level indexable" (rows × positions). The # identity ``[()]`` is one row of zero positions (composing with it preserves # shape); the zero ``[]`` is no rows (absorbs under product). -CartesianProduct = MonoidWithZero( - name="CartesianProduct", identity=[()], zero=[] -) - -# Scalar-valued monoids: tuples/lists/generators are containers to broadcast over. -for _m in (Min, Max, Sum, Product): - _register_sequence_broadcasting(_m) - -# Mapping and callable broadcasting are safe for every monoid. -for _m in (Min, Max, ArgMin, ArgMax, Sum, Product, CartesianProduct): - _register_mapping_broadcasting(_m) - _register_callable_broadcasting(_m) - - -@Min.plus.register(int | float) -def _(*args): - return min(args) - - -@Max.plus.register(int | float) -def _(*args): - return max(args) - - -@Sum.plus.register(int | float) -def _(*args): - return sum(args) - - -@Product.plus.register(int | float) -def _(*args): - return functools.reduce(operator.mul, args) - - -@ArgMin.plus.register(tuple) -def _(*args): - if not all(isinstance(a[0], int | float) for a in args): - raise NotHandled - return min(args, key=lambda a: a[0]) - - -@ArgMax.plus.register(tuple) -def _(*args): - if not all(isinstance(a[0], int | float) for a in args): - raise NotHandled - return max(args, key=lambda a: a[0]) - - -@CartesianProduct.plus.register(Iterable) -def _(*args): - def to_tuple(x): - return x if isinstance(x, tuple) else (x,) - - return [sum((to_tuple(v) for v in vals), ()) for vals in itertools.product(*args)] +CartesianProduct = MonoidWithZero(name="CartesianProduct", identity=[()], zero=[]) @dataclass @@ -569,6 +404,8 @@ class ReduceFactorization(ObjectInterpretation): @implements(Monoid.reduce) def reduce(self, monoid, body, streams): + import sys + print(f"RF CALLED monoid={monoid} body={body} streams={streams}", file=sys.stderr) if not is_commutative(monoid): return fwd() if ( @@ -748,6 +585,186 @@ def reduce(self, sum_monoid: Monoid, sum_body, sum_streams): return fwd() +class ReduceOverCallable(ObjectInterpretation): + """``monoid.reduce(f, streams) = lambda *a: monoid.reduce(f(*a), streams)``.""" + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if isinstance(body, Term) or not isinstance(body, Callable): + return fwd() + return lambda *a, **k: monoid.reduce(body(*a, **k), streams) + + +class ReduceOverMapping(ObjectInterpretation): + """``monoid.reduce({k: v_k}, streams) = {k: monoid.reduce(v_k, streams)}``.""" + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if isinstance(body, Term) or not isinstance(body, Mapping): + return fwd() + return {k: monoid.reduce(v, streams) for (k, v) in body.items()} + + +class PlusOverMapping(ObjectInterpretation): + """Broadcast ``plus`` over dict-like containers and interpretations.""" + + @implements(Monoid.plus) + def plus(self, monoid, *args): + if not args or not isinstance(args[0], Mapping): + return fwd() + + if isinstance(args[0], Interpretation): + keys = args[0].keys() + for b in args[1:]: + if not isinstance(b, Interpretation): + raise TypeError(f"Expected interpretation but got {b}") + if not keys == b.keys(): + raise ValueError( + f"Expected interpretation of {keys} but got {b.keys()}" + ) + return {k: monoid.plus(*(handler(b)(b[k]) for b in args)) for k in keys} + + for b in args[1:]: + if not isinstance(b, Mapping): + raise TypeError(f"Expected mapping but got {b}") + all_values = collections.defaultdict(list) + for d in args: + for k, v in d.items(): + all_values[k].append(v) + return {k: monoid.plus(*vs) for (k, vs) in all_values.items()} + + +def _scalar_args(args): + """True iff ``args`` is non-empty and every arg is a concrete int/float.""" + return ( + bool(args) + and not any(isinstance(x, Term) for x in args) + and all(isinstance(x, int | float) for x in args) + ) + + +class SumKernel(ObjectInterpretation): + """Scalar implementation of :data:`Sum`.""" + + @implements(Sum.plus) + def plus(self, *args): + if not _scalar_args(args): + return fwd() + return sum(args) + + +class MinKernel(ObjectInterpretation): + """Scalar implementation of :data:`Min`.""" + + @implements(Min.plus) + def plus(self, *args): + if not _scalar_args(args): + return fwd() + return min(args) + + +class MaxKernel(ObjectInterpretation): + """Scalar implementation of :data:`Max`.""" + + @implements(Max.plus) + def plus(self, *args): + if not _scalar_args(args): + return fwd() + return max(args) + + +class ProductKernel(ObjectInterpretation): + """Scalar implementation of :data:`Product`.""" + + @implements(Product.plus) + def plus(self, *args): + if not _scalar_args(args): + return fwd() + return functools.reduce(operator.mul, args) + + +class ArgMinKernel(ObjectInterpretation): + """Scalar score implementation of :data:`ArgMin`.""" + + @implements(ArgMin.plus) + def plus(self, *args): + if not args or not all(isinstance(a, tuple) for a in args): + return fwd() + if any(isinstance(a[0], Term) for a in args): + return fwd() + if not all(isinstance(a[0], int | float) for a in args): + return fwd() + return min(args, key=lambda a: a[0]) + + +class ArgMaxKernel(ObjectInterpretation): + """Scalar score implementation of :data:`ArgMax`.""" + + @implements(ArgMax.plus) + def plus(self, *args): + if not args or not all(isinstance(a, tuple) for a in args): + return fwd() + if any(isinstance(a[0], Term) for a in args): + return fwd() + if not all(isinstance(a[0], int | float) for a in args): + return fwd() + return max(args, key=lambda a: a[0]) + + +class CartesianProductKernel(ObjectInterpretation): + """Pure-Python implementation of :data:`CartesianProduct`.""" + + @implements(CartesianProduct.plus) + def plus(self, *args): + if not args: + return fwd() + if any(isinstance(x, Term) for x in args): + return fwd() + if not all(isinstance(x, Iterable) for x in args): + return fwd() + + def to_tuple(x): + return x if isinstance(x, tuple) else (x,) + + return [ + sum((to_tuple(v) for v in vals), ()) for vals in itertools.product(*args) + ] + + +def sequence_broadcasting(monoid: "Monoid") -> ObjectInterpretation: + """Return an :class:`ObjectInterpretation` that broadcasts ``monoid.plus`` + and ``monoid.reduce`` elementwise over tuples, lists, and generators. + + Tuples and lists are reconstructed as their input type; generators stay + generators. Appropriate for monoids whose values are scalars; *not* for + monoids whose values *are* sequences (:data:`CartesianProduct`) or whose + tuples carry meaning (:data:`ArgMin` / :data:`ArgMax`). + """ + + class _SequenceBroadcasting(ObjectInterpretation): + @implements(monoid.plus) + def plus(self, *args): + if not args or not isinstance(args[0], tuple | list | Generator): + return fwd() + zipped = zip(*args, strict=True) + result = (monoid.plus(*vs) for vs in zipped) + if isinstance(args[0], tuple | list): + return type(args[0])(result) + return result + + @implements(monoid.reduce) + def reduce(self, body, streams): + if not isinstance(body, tuple | list | Generator): + return fwd() + result = (monoid.reduce(x, streams) for x in body) + if isinstance(body, tuple | list): + return type(body)(result) + return result + + _SequenceBroadcasting.__name__ = f"{monoid._name}SequenceBroadcasting" + return _SequenceBroadcasting() + + NormalizeReduceIntp = functools.reduce( coproduct, typing.cast( @@ -762,4 +779,40 @@ def reduce(self, sum_monoid: Monoid, sum_body, sum_streams): ), ) -NormalizeIntp = coproduct(NormalizePlusIntp, NormalizeReduceIntp) + +EvaluateIntp = functools.reduce( + coproduct, + typing.cast( + list[Interpretation], + [ + # universal broadcasting + PlusOverMapping(), + ReduceOverCallable(), + ReduceOverMapping(), + # per-monoid concrete kernels + SumKernel(), + MinKernel(), + MaxKernel(), + ProductKernel(), + ArgMinKernel(), + ArgMaxKernel(), + CartesianProductKernel(), + # tuple/list/Generator broadcasting, only for monoids whose values + # are scalars (not CartesianProduct, ArgMin, or ArgMax). + *(sequence_broadcasting(m) for m in (Sum, Min, Max, Product)), + ], + ), +) +"""Concrete-value kernels and universal broadcasting. Install to evaluate +plus/reduce on concrete (non-Term) values without applying any rewrites. +""" + + +NormalizeIntp = coproduct( + coproduct(NormalizePlusIntp, NormalizeReduceIntp), EvaluateIntp +) +"""Rewrites *plus* evaluation. ``NormalizeIntp`` is a superset of +:data:`EvaluateIntp`: it applies pure-Term rewrites (associativity, +distributivity, identity elimination, fusion, factorization, etc.) and also +evaluates concrete arguments through the kernels. +""" diff --git a/tests/test_handlers_jax_monoid.py b/tests/test_handlers_jax_monoid.py index 4efd0eb21..fc43caed1 100644 --- a/tests/test_handlers_jax_monoid.py +++ b/tests/test_handlers_jax_monoid.py @@ -3,11 +3,27 @@ import effectful.handlers.jax.numpy as jnp from effectful.handlers.jax import bind_dims, unbind_dims -from effectful.handlers.jax.monoid import LogSumExp, Max, Min, Product, Sum +from effectful.handlers.jax.monoid import ( + JaxEvaluateIntp, + LogSumExp, + Max, + Min, + Product, + Sum, +) from effectful.handlers.jax.scipy.special import logsumexp +from effectful.ops.monoid import EvaluateIntp +from effectful.ops.semantics import coproduct, handler from effectful.ops.types import NotHandled, Operation from tests._monoid_helpers import define_vars, syntactic_eq_alpha + +@pytest.fixture(autouse=True) +def _install_evaluate(): + """Install scalar + JAX evaluation kernels for every test in this module.""" + with handler(coproduct(EvaluateIntp, JaxEvaluateIntp)): + yield + MONOIDS = [ pytest.param(Sum, jnp.sum, id="Sum"), pytest.param(Product, jnp.prod, id="Product"), diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index 7ba739afa..bd68a2647 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -6,6 +6,7 @@ from effectful.ops.monoid import ( CartesianProduct, + EvaluateIntp, Max, Min, Monoid, @@ -15,10 +16,22 @@ distributes_over, is_commutative, ) -from effectful.ops.semantics import evaluate, fvsof, handler +from effectful.ops.semantics import coproduct, evaluate, fvsof, handler from effectful.ops.types import NotHandled, Operation from tests._monoid_helpers import define_vars, random_interpretation, syntactic_eq_alpha + +@pytest.fixture(autouse=True) +def _install_normalize_intp(): + """Install :data:`NormalizeIntp` for every test in this module. + + :data:`NormalizeIntp` is a superset of :data:`EvaluateIntp` — direct + monoid calls evaluate, rewrites also fire. Will be replaced by a global + interpretation once that lands. + """ + with handler(NormalizeIntp): + yield + _INT = st.integers(min_value=-100, max_value=100) ALL_MONOIDS = [ From 12c91dafc801ed5e465c284515c5cc308302e2c4 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Wed, 13 May 2026 18:47:38 -0400 Subject: [PATCH 22/34] drop runtime typed dict lifting --- effectful/internals/unification.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index e425bba6c..efb5a27f8 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -1015,10 +1015,6 @@ def _(value: collections.abc.Mapping): ktyp = functools.reduce( operator.or_, [nested_type(x).value for x in value.keys()] ) - if ktyp is str: - # str-keyed multi-entry dicts → always TypedDict - fields = {key: nested_type(vl).value for key, vl in value.items()} - return Box(typing.TypedDict("RuntimeTypeDict", fields)) # type: ignore vtyp = functools.reduce( operator.or_, [nested_type(x).value for x in value.values()] ) From 11fa13a2173cc9dcc4d89b88055defed882c5051 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Wed, 13 May 2026 19:12:18 -0400 Subject: [PATCH 23/34] wip --- effectful/ops/monoid.py | 18 +++++++----------- tests/test_ops_monoid.py | 1 + 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index ff678dbcc..652bad977 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -88,10 +88,6 @@ def plus[S](self, *args: S) -> S: """Monoid addition. Handlers supply per-monoid and broadcasting behavior; the default rule only handles empty / Term cases. """ - if not args: - return typing.cast(S, self.identity) - if any(isinstance(x, Term) for x in args): - raise NotHandled raise NotHandled @Operation.define @@ -115,7 +111,9 @@ def reduce[A, B, U: Body]( with handler({stream_key: deffn(stream_val)}): eval_args = evaluate((body, streams_tail)) assert isinstance(eval_args, tuple) - new_reduces.append(self.reduce(*eval_args)) + new_reduces.append( + self.reduce(*eval_args) if streams_tail else eval_args[0] + ) return self.plus(*new_reduces) raise NotHandled @@ -404,8 +402,6 @@ class ReduceFactorization(ObjectInterpretation): @implements(Monoid.reduce) def reduce(self, monoid, body, streams): - import sys - print(f"RF CALLED monoid={monoid} body={body} streams={streams}", file=sys.stderr) if not is_commutative(monoid): return fwd() if ( @@ -785,6 +781,9 @@ def reduce(self, body, streams): typing.cast( list[Interpretation], [ + # tuple/list/Generator broadcasting, only for monoids whose values + # are scalars (not CartesianProduct, ArgMin, or ArgMax). + *(sequence_broadcasting(m) for m in (Sum, Min, Max, Product)), # universal broadcasting PlusOverMapping(), ReduceOverCallable(), @@ -797,9 +796,6 @@ def reduce(self, body, streams): ArgMinKernel(), ArgMaxKernel(), CartesianProductKernel(), - # tuple/list/Generator broadcasting, only for monoids whose values - # are scalars (not CartesianProduct, ArgMin, or ArgMax). - *(sequence_broadcasting(m) for m in (Sum, Min, Max, Product)), ], ), ) @@ -809,7 +805,7 @@ def reduce(self, body, streams): NormalizeIntp = coproduct( - coproduct(NormalizePlusIntp, NormalizeReduceIntp), EvaluateIntp + coproduct(EvaluateIntp, NormalizeReduceIntp), NormalizePlusIntp ) """Rewrites *plus* evaluation. ``NormalizeIntp`` is a superset of :data:`EvaluateIntp`: it applies pure-Term rewrites (associativity, diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index bd68a2647..1a03ff606 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -32,6 +32,7 @@ def _install_normalize_intp(): with handler(NormalizeIntp): yield + _INT = st.integers(min_value=-100, max_value=100) ALL_MONOIDS = [ From 8711a765466af5b92ccc1fbd53489557bf982304 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Wed, 13 May 2026 19:12:37 -0400 Subject: [PATCH 24/34] format --- effectful/handlers/jax/monoid.py | 3 +-- effectful/ops/monoid.py | 10 ++-------- tests/test_handlers_jax_monoid.py | 1 + tests/test_ops_monoid.py | 3 +-- 4 files changed, 5 insertions(+), 12 deletions(-) diff --git a/effectful/handlers/jax/monoid.py b/effectful/handlers/jax/monoid.py index 54cd700cd..295ae2caf 100644 --- a/effectful/handlers/jax/monoid.py +++ b/effectful/handlers/jax/monoid.py @@ -12,13 +12,12 @@ Min, Monoid, Product, - Streams, Sum, outer_stream, ) from effectful.ops.semantics import coproduct, evaluate, fwd, handler, typeof from effectful.ops.syntax import ObjectInterpretation, deffn, implements -from effectful.ops.types import Interpretation, NotHandled, Operation, Term +from effectful.ops.types import Interpretation, Operation, Term def cartesian_prod(x, y): diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 652bad977..072408ac7 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -10,7 +10,7 @@ from typing import Annotated, Any from effectful.internals.disjoint_set import DisjointSet -from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler, typeof +from effectful.ops.semantics import coproduct, evaluate, fvsof, fwd, handler from effectful.ops.syntax import ( ObjectInterpretation, Scoped, @@ -20,13 +20,7 @@ syntactic_eq, syntactic_hash, ) -from effectful.ops.types import ( - Expr, - Interpretation, - NotHandled, - Operation, - Term, -) +from effectful.ops.types import Expr, Interpretation, NotHandled, Operation, Term # Note: The streams value type should be something like Iterable[T], but some of # our target stream types (e.g. jax.Array) are not subtypes of Iterable diff --git a/tests/test_handlers_jax_monoid.py b/tests/test_handlers_jax_monoid.py index fc43caed1..9120acd07 100644 --- a/tests/test_handlers_jax_monoid.py +++ b/tests/test_handlers_jax_monoid.py @@ -24,6 +24,7 @@ def _install_evaluate(): with handler(coproduct(EvaluateIntp, JaxEvaluateIntp)): yield + MONOIDS = [ pytest.param(Sum, jnp.sum, id="Sum"), pytest.param(Product, jnp.prod, id="Product"), diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index 1a03ff606..95e0bb2f9 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -6,7 +6,6 @@ from effectful.ops.monoid import ( CartesianProduct, - EvaluateIntp, Max, Min, Monoid, @@ -16,7 +15,7 @@ distributes_over, is_commutative, ) -from effectful.ops.semantics import coproduct, evaluate, fvsof, handler +from effectful.ops.semantics import evaluate, fvsof, handler from effectful.ops.types import NotHandled, Operation from tests._monoid_helpers import define_vars, random_interpretation, syntactic_eq_alpha From 0ccbefb24cce4d125dac3171a090e4fa67fe15ba Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Wed, 13 May 2026 19:19:31 -0400 Subject: [PATCH 25/34] reorganize --- effectful/handlers/jax/monoid.py | 93 +++++++++++++++----------------- effectful/ops/monoid.py | 88 +++++++++++++++--------------- 2 files changed, 86 insertions(+), 95 deletions(-) diff --git a/effectful/handlers/jax/monoid.py b/effectful/handlers/jax/monoid.py index 295ae2caf..151bec530 100644 --- a/effectful/handlers/jax/monoid.py +++ b/effectful/handlers/jax/monoid.py @@ -37,7 +37,7 @@ def _jax_args(args): return all(isinstance(a, jax.Array) and not isinstance(a, Term) for a in args) -class SumKernelJax(ObjectInterpretation): +class SumPlusJax(ObjectInterpretation): @implements(Sum.plus) def plus(self, *args): if not _jax_args(args): @@ -45,7 +45,7 @@ def plus(self, *args): return functools.reduce(jnp.add, args) -class ProductKernelJax(ObjectInterpretation): +class ProductPlusJax(ObjectInterpretation): @implements(Product.plus) def plus(self, *args): if not _jax_args(args): @@ -53,7 +53,7 @@ def plus(self, *args): return functools.reduce(jnp.multiply, args) -class MinKernelJax(ObjectInterpretation): +class MinPlusJax(ObjectInterpretation): @implements(Min.plus) def plus(self, *args): if not _jax_args(args): @@ -61,7 +61,7 @@ def plus(self, *args): return functools.reduce(jnp.minimum, args) -class MaxKernelJax(ObjectInterpretation): +class MaxPlusJax(ObjectInterpretation): @implements(Max.plus) def plus(self, *args): if not _jax_args(args): @@ -69,7 +69,7 @@ def plus(self, *args): return functools.reduce(jnp.maximum, args) -class LogSumExpKernelJax(ObjectInterpretation): +class LogSumExpPlusJax(ObjectInterpretation): @implements(LogSumExp.plus) def plus(self, *args): if not _jax_args(args): @@ -77,7 +77,7 @@ def plus(self, *args): return functools.reduce(jnp.logaddexp, args) -class CartesianProductKernelJax(ObjectInterpretation): +class CartesianProductPlusJax(ObjectInterpretation): @implements(CartesianProduct.plus) def plus(self, *args): # Skip identity ``[()]`` args; short-circuit on zero ``[]``. Both @@ -97,46 +97,39 @@ def plus(self, *args): return result if result is not None else CartesianProduct.identity -def _make_array_reduce_class(monoid: Monoid, reductor): - """Build an :class:`ObjectInterpretation` that implements - ``monoid.reduce`` for ``jax.Array`` bodies using ``reductor``. - """ +ARRAY_REDUCTORS = { + Sum: jnp.sum, + Product: jnp.prod, + Min: jnp.min, + Max: jnp.max, + LogSumExp: logsumexp, +} - class _ArrayReduce(ObjectInterpretation): - @implements(monoid.reduce) - def reduce(self, body, streams): - if typeof(body) is not jax.Array: - return fwd() - if not streams: - return monoid.identity - - index = Operation.define(jax.Array) - for stream_key, stream_body, streams_tail in outer_stream(streams): - if typeof(stream_body) is not jax.Array: - continue - with handler({stream_key: deffn(unbind_dims(stream_body, index))}): - eval_body = evaluate(body) - eval_streams_tail = evaluate(streams_tail) - assert isinstance(eval_streams_tail, dict) - reduce_tail = ( - monoid.reduce(eval_body, eval_streams_tail) - if len(eval_streams_tail) > 0 - else eval_body - ) - return reductor(bind_dims(reduce_tail, index), axis=0) - return fwd() - - _ArrayReduce.__name__ = f"{monoid._name}ArrayReduceJax" - return _ArrayReduce +class ArrayReduce(ObjectInterpretation): + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if monoid not in ARRAY_REDUCTORS or typeof(body) is not jax.Array: + return fwd() + if not streams: + return monoid.identity -_ARRAY_REDUCE_CLASSES = [ - _make_array_reduce_class(Sum, jnp.sum), - _make_array_reduce_class(Product, jnp.prod), - _make_array_reduce_class(Min, jnp.min), - _make_array_reduce_class(Max, jnp.max), - _make_array_reduce_class(LogSumExp, logsumexp), -] + reductor = ARRAY_REDUCTORS[monoid] + index = Operation.define(jax.Array) + for stream_key, stream_body, streams_tail in outer_stream(streams): + if typeof(stream_body) is not jax.Array: + continue + with handler({stream_key: deffn(unbind_dims(stream_body, index))}): + eval_body = evaluate(body) + eval_streams_tail = evaluate(streams_tail) + assert isinstance(eval_streams_tail, dict) + reduce_tail = ( + monoid.reduce(eval_body, eval_streams_tail) + if len(eval_streams_tail) > 0 + else eval_body + ) + return reductor(bind_dims(reduce_tail, index), axis=0) + return fwd() JaxEvaluateIntp = functools.reduce( @@ -144,13 +137,13 @@ def reduce(self, body, streams): typing.cast( list[Interpretation], [ - SumKernelJax(), - ProductKernelJax(), - MinKernelJax(), - MaxKernelJax(), - LogSumExpKernelJax(), - CartesianProductKernelJax(), - *[cls() for cls in _ARRAY_REDUCE_CLASSES], + SumPlusJax(), + ProductPlusJax(), + MinPlusJax(), + MaxPlusJax(), + LogSumExpPlusJax(), + CartesianProductPlusJax(), + ArrayReduce(), ], ), ) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 072408ac7..ac037c3aa 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -324,24 +324,6 @@ def plus(self, monoid, *args): return fwd() -NormalizePlusIntp = functools.reduce( - coproduct, - typing.cast( - list[Interpretation], - [ - PlusEmpty(), - PlusSingle(), - PlusIdentity(), - PlusAssoc(), - PlusDistr(), - PlusZero(), - PlusConsecutiveDups(), - PlusDups(), - ], - ), -) - - class ReduceNoStreams(ObjectInterpretation): """Implements the identity reduce(R, ∅, body) = 0 @@ -633,7 +615,7 @@ def _scalar_args(args): ) -class SumKernel(ObjectInterpretation): +class SumPlus(ObjectInterpretation): """Scalar implementation of :data:`Sum`.""" @implements(Sum.plus) @@ -643,7 +625,7 @@ def plus(self, *args): return sum(args) -class MinKernel(ObjectInterpretation): +class MinPlus(ObjectInterpretation): """Scalar implementation of :data:`Min`.""" @implements(Min.plus) @@ -653,7 +635,7 @@ def plus(self, *args): return min(args) -class MaxKernel(ObjectInterpretation): +class MaxPlus(ObjectInterpretation): """Scalar implementation of :data:`Max`.""" @implements(Max.plus) @@ -663,7 +645,7 @@ def plus(self, *args): return max(args) -class ProductKernel(ObjectInterpretation): +class ProductPlus(ObjectInterpretation): """Scalar implementation of :data:`Product`.""" @implements(Product.plus) @@ -673,7 +655,7 @@ def plus(self, *args): return functools.reduce(operator.mul, args) -class ArgMinKernel(ObjectInterpretation): +class ArgMinPlus(ObjectInterpretation): """Scalar score implementation of :data:`ArgMin`.""" @implements(ArgMin.plus) @@ -687,7 +669,7 @@ def plus(self, *args): return min(args, key=lambda a: a[0]) -class ArgMaxKernel(ObjectInterpretation): +class ArgMaxPlus(ObjectInterpretation): """Scalar score implementation of :data:`ArgMax`.""" @implements(ArgMax.plus) @@ -701,7 +683,7 @@ def plus(self, *args): return max(args, key=lambda a: a[0]) -class CartesianProductKernel(ObjectInterpretation): +class CartesianProductPlus(ObjectInterpretation): """Pure-Python implementation of :data:`CartesianProduct`.""" @implements(CartesianProduct.plus) @@ -755,6 +737,33 @@ def reduce(self, body, streams): return _SequenceBroadcasting() +EvaluateIntp = functools.reduce( + coproduct, + typing.cast( + list[Interpretation], + [ + # tuple/list/Generator broadcasting, only for monoids whose values + # are scalars (not CartesianProduct, ArgMin, or ArgMax). + *(sequence_broadcasting(m) for m in (Sum, Min, Max, Product)), + # universal broadcasting + PlusOverMapping(), + ReduceOverCallable(), + ReduceOverMapping(), + # per-monoid concrete kernels + SumPlus(), + MinPlus(), + MaxPlus(), + ProductPlus(), + ArgMinPlus(), + ArgMaxPlus(), + CartesianProductPlus(), + ], + ), +) +"""Concrete-value kernels and universal broadcasting. Install to evaluate +plus/reduce on concrete (non-Term) values without applying any rewrites. +""" + NormalizeReduceIntp = functools.reduce( coproduct, typing.cast( @@ -769,33 +778,22 @@ def reduce(self, body, streams): ), ) - -EvaluateIntp = functools.reduce( +NormalizePlusIntp = functools.reduce( coproduct, typing.cast( list[Interpretation], [ - # tuple/list/Generator broadcasting, only for monoids whose values - # are scalars (not CartesianProduct, ArgMin, or ArgMax). - *(sequence_broadcasting(m) for m in (Sum, Min, Max, Product)), - # universal broadcasting - PlusOverMapping(), - ReduceOverCallable(), - ReduceOverMapping(), - # per-monoid concrete kernels - SumKernel(), - MinKernel(), - MaxKernel(), - ProductKernel(), - ArgMinKernel(), - ArgMaxKernel(), - CartesianProductKernel(), + PlusEmpty(), + PlusSingle(), + PlusIdentity(), + PlusAssoc(), + PlusDistr(), + PlusZero(), + PlusConsecutiveDups(), + PlusDups(), ], ), ) -"""Concrete-value kernels and universal broadcasting. Install to evaluate -plus/reduce on concrete (non-Term) values without applying any rewrites. -""" NormalizeIntp = coproduct( From 76f7002fe61404fc5d9c9a98b3f0a309514f1cf1 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Thu, 14 May 2026 09:51:58 -0400 Subject: [PATCH 26/34] stop using string dicts to avoid unification issue --- effectful/internals/unification.py | 4 ++++ tests/test_ops_monoid.py | 10 +++++----- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index efb5a27f8..e425bba6c 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -1015,6 +1015,10 @@ def _(value: collections.abc.Mapping): ktyp = functools.reduce( operator.or_, [nested_type(x).value for x in value.keys()] ) + if ktyp is str: + # str-keyed multi-entry dicts → always TypedDict + fields = {key: nested_type(vl).value for key, vl in value.items()} + return Box(typing.TypedDict("RuntimeTypeDict", fields)) # type: ignore vtyp = functools.reduce( operator.or_, [nested_type(x).value for x in value.values()] ) diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index 95e0bb2f9..16748a9cd 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -184,8 +184,8 @@ def test_plus_sequence(monoid): def test_plus_mapping(monoid): a, b, c, d = define_vars("a", "b", "c", "d", typ=type(monoid.identity)) _check_pair( - lhs=monoid.plus({"x": a(), "y": b()}, {"x": c(), "z": d()}), - rhs={"x": monoid.plus(a(), c()), "y": b(), "z": d()}, + lhs=monoid.plus({0: a(), 1: b()}, {0: c(), 2: d()}), + rhs={0: monoid.plus(a(), c()), 1: b(), 2: d()}, free_vars=[a, b, c, d], ) @@ -373,10 +373,10 @@ def f(_x: int) -> int: g = Operation.define(f, name="g") - lhs = monoid.reduce({"a": f(x()), "b": g(x())}, {x: X()}) + lhs = monoid.reduce({0: f(x()), 1: g(x())}, {x: X()}) rhs = { - "a": monoid.reduce(f(x()), {x: X()}), - "b": monoid.reduce(g(x()), {x: X()}), + 0: monoid.reduce(f(x()), {x: X()}), + 1: monoid.reduce(g(x()), {x: X()}), } _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, f, g]) From 6fdd23d4291da59dbc9296edcd683d52013eba4c Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Thu, 14 May 2026 10:20:48 -0400 Subject: [PATCH 27/34] wip --- effectful/handlers/jax/monoid.py | 16 +- effectful/ops/syntax.py | 2 + tests/_monoid_helpers.py | 105 +++++++- tests/test_ops_monoid.py | 429 ++++++++++++++++--------------- 4 files changed, 336 insertions(+), 216 deletions(-) diff --git a/effectful/handlers/jax/monoid.py b/effectful/handlers/jax/monoid.py index 151bec530..14be90f36 100644 --- a/effectful/handlers/jax/monoid.py +++ b/effectful/handlers/jax/monoid.py @@ -16,10 +16,16 @@ outer_stream, ) from effectful.ops.semantics import coproduct, evaluate, fwd, handler, typeof -from effectful.ops.syntax import ObjectInterpretation, deffn, implements +from effectful.ops.syntax import ObjectInterpretation, deffn, implements, syntactic_hash from effectful.ops.types import Interpretation, Operation, Term +@syntactic_hash.register(jax.Array) +def _(x: jax.Array) -> int: + # Concrete arrays aren't hashable; hash by shape, dtype, and bytes. + return hash(("jax.Array", x.shape, str(x.dtype), bytes(jax.numpy.asarray(x)))) + + def cartesian_prod(x, y): if x.ndim == 1: x = x[:, None] @@ -33,8 +39,12 @@ def cartesian_prod(x, y): def _jax_args(args): - """True iff every arg is a concrete :class:`jax.Array` (no Terms).""" - return all(isinstance(a, jax.Array) and not isinstance(a, Term) for a in args) + """True iff ``args`` is non-empty and every arg is a concrete + :class:`jax.Array` (no Terms). + """ + return bool(args) and all( + isinstance(a, jax.Array) and not isinstance(a, Term) for a in args + ) class SumPlusJax(ObjectInterpretation): diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index 8fb12598f..5ea04fcb8 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -849,6 +849,8 @@ def _(x: collections.abc.Sequence, other) -> bool: @syntactic_eq.register(object) @syntactic_eq.register(str | bytes) def _(x: object, other) -> bool: + if isinstance(other, Term): # Terms often override __eq__ + return False return x == other diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py index 7643bb202..fd785c0af 100644 --- a/tests/_monoid_helpers.py +++ b/tests/_monoid_helpers.py @@ -1,5 +1,6 @@ import itertools from collections.abc import Callable, Mapping, Sequence +from dataclasses import dataclass from typing import Any, get_args, get_origin import jax @@ -9,7 +10,7 @@ from effectful.internals.runtime import interpreter from effectful.ops.semantics import apply, evaluate from effectful.ops.syntax import _BaseTerm, defdata, deffn, syntactic_eq -from effectful.ops.types import Operation +from effectful.ops.types import NotHandled, Operation _JAX_ARRAY_SHAPE = (3,) @@ -47,9 +48,11 @@ def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: return st.lists(st.integers(), max_size=2) if annotation is jax.Array: return _jax_array_value_strategy() + if get_origin(annotation) is list and get_args(annotation) == (jax.Array,): + return st.lists(_jax_array_value_strategy(), max_size=2) raise NotImplementedError( f"No value strategy for return annotation {annotation!r}; " - "supported: int, list[int], jax.Array" + "supported: int, list[int], jax.Array, list[jax.Array]" ) @@ -78,6 +81,13 @@ def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: lambda x: [0, x, x + 1], ] +_UNARY_JAX_LIST_FNS: list[Callable[[jax.Array], list[jax.Array]]] = [ + lambda _x: [], + lambda x: [x], + lambda x: [x, x + 1.0], + lambda x: [x, -x], +] + def _strategy_for_op(op: Operation) -> st.SearchStrategy[Callable[..., Any]]: """Pick a strategy producing a callable suitable for binding `op` in an @@ -100,6 +110,12 @@ def _strategy_for_op(op: Operation) -> st.SearchStrategy[Callable[..., Any]]: return st.sampled_from(_UNARY_JAX_FNS) if ret is jax.Array and param_types == (jax.Array, jax.Array): return st.sampled_from(_BINARY_JAX_FNS) + if ( + get_origin(ret) is list + and get_args(ret) == (jax.Array,) + and param_types == (jax.Array,) + ): + return st.sampled_from(_UNARY_JAX_LIST_FNS) raise NotImplementedError( f"No callable strategy for free var with return {ret!r}, params {param_types!r}" ) @@ -233,4 +249,87 @@ def _apply_canonical(op, *args, **kwargs): return evaluate(expr) -__all__ = ["random_interpretation", "define_vars", "syntactic_eq_alpha"] +@dataclass(frozen=True) +class Backend: + """A value-domain spec used to share monoid tests across int and jax.Array + backends. Provides the concrete value type, the hypothesis strategy for + drawing scalars in property tests, and an equality predicate that works + for that domain. + """ + + name: str + scalar_typ: Any + stream_typ: Any + scalar_strategy: st.SearchStrategy[Any] + eq: Callable[[Any, Any], bool] + lift: Callable[[Any], Any] + + def fresh_op(self, name: str, n_args: int = 1, ret: str = "scalar") -> Operation: + """Build a fresh, unhandled Operation whose parameter and return + annotations are derived from this backend. + + ``ret`` is ``"scalar"`` for a scalar return or ``"stream"`` for a + stream-of-scalar return. The operation has ``n_args`` parameters, + each of type ``scalar_typ``. + """ + scalar = self.scalar_typ + out = self.stream_typ if ret == "stream" else scalar + params = ", ".join(f"_a{i}" for i in range(n_args)) + ns: dict[str, Any] = {"NotHandled": NotHandled} + exec(f"def _fn({params}):\n raise NotHandled\n", ns) + fn = ns["_fn"] + fn.__annotations__ = { + **{f"_a{i}": scalar for i in range(n_args)}, + "return": out, + } + return Operation.define(fn, name=name) + + +def _int_eq(a: Any, b: Any) -> bool: + return a == b + + +def _jax_eq(a: Any, b: Any) -> bool: + if isinstance(a, dict) and isinstance(b, dict): + if set(a.keys()) != set(b.keys()): + return False + return all(_jax_eq(a[k], b[k]) for k in a) + if isinstance(a, tuple | list) and isinstance(b, tuple | list): + if len(a) != len(b): + return False + return all(_jax_eq(x, y) for x, y in zip(a, b, strict=True)) + if isinstance(a, jax.Array) or isinstance(b, jax.Array): + aa, bb = jax.numpy.asarray(a), jax.numpy.asarray(b) + aa, bb = jax.numpy.broadcast_arrays(aa, bb) + return bool(jax.numpy.all(jax.numpy.isclose(aa, bb, equal_nan=True))) + return a == b + + +INT_BACKEND = Backend( + name="int", + scalar_typ=int, + stream_typ=list[int], + scalar_strategy=st.integers(min_value=-100, max_value=100), + eq=_int_eq, + lift=lambda x: x, +) + + +JAX_BACKEND = Backend( + name="jax", + scalar_typ=jax.Array, + stream_typ=list[jax.Array], + scalar_strategy=_jax_array_value_strategy(), + eq=_jax_eq, + lift=lambda x: jax.numpy.asarray(x, dtype=jax.numpy.float32), +) + + +__all__ = [ + "Backend", + "INT_BACKEND", + "JAX_BACKEND", + "random_interpretation", + "define_vars", + "syntactic_eq_alpha", +] diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index 16748a9cd..1c1931fbb 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -1,9 +1,10 @@ import typing import pytest -from hypothesis import given, settings +from hypothesis import HealthCheck, given, settings from hypothesis import strategies as st +from effectful.handlers.jax.monoid import JaxEvaluateIntp from effectful.ops.monoid import ( CartesianProduct, Max, @@ -15,25 +16,35 @@ distributes_over, is_commutative, ) -from effectful.ops.semantics import evaluate, fvsof, handler -from effectful.ops.types import NotHandled, Operation -from tests._monoid_helpers import define_vars, random_interpretation, syntactic_eq_alpha +from effectful.ops.semantics import coproduct, evaluate, fvsof, handler +from effectful.ops.types import Operation +from tests._monoid_helpers import ( + INT_BACKEND, + JAX_BACKEND, + Backend, + define_vars, + random_interpretation, + syntactic_eq_alpha, +) -@pytest.fixture(autouse=True) -def _install_normalize_intp(): - """Install :data:`NormalizeIntp` for every test in this module. +@pytest.fixture(params=[INT_BACKEND, JAX_BACKEND], ids=["int", "jax"]) +def backend(request) -> Backend: + return request.param + - :data:`NormalizeIntp` is a superset of :data:`EvaluateIntp` — direct - monoid calls evaluate, rewrites also fire. Will be replaced by a global - interpretation once that lands. +@pytest.fixture(autouse=True) +def _install_normalize_intp(backend): + """Install :data:`NormalizeIntp` (plus JAX kernels when the backend is + jax) for every test in this module. """ - with handler(NormalizeIntp): + intp = NormalizeIntp + if backend.scalar_typ is not int: + intp = coproduct(intp, JaxEvaluateIntp) + with handler(intp): yield -_INT = st.integers(min_value=-100, max_value=100) - ALL_MONOIDS = [ pytest.param(Sum, id="Sum"), pytest.param(Product, id="Product"), @@ -71,45 +82,57 @@ def _install_normalize_intp(): @pytest.mark.parametrize("monoid", ALL_MONOIDS) -@given(a=_INT, b=_INT, c=_INT) -@settings(max_examples=50, deadline=None) -def test_associativity(monoid, a, b, c): +@given(data=st.data()) +@settings(max_examples=50, deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture]) +def test_associativity(monoid, backend, data): + a = data.draw(backend.scalar_strategy) + b = data.draw(backend.scalar_strategy) + c = data.draw(backend.scalar_strategy) left = monoid.plus(monoid.plus(a, b), c) right = monoid.plus(a, monoid.plus(b, c)) - assert left == right + assert backend.eq(left, right) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -@given(a=_INT) -@settings(max_examples=50, deadline=None) -def test_identity(monoid, a): - assert monoid.plus(monoid.identity, a) == a - assert monoid.plus(a, monoid.identity) == a +@given(data=st.data()) +@settings(max_examples=50, deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture]) +def test_identity(monoid, backend, data): + a = data.draw(backend.scalar_strategy) + ident = backend.lift(monoid.identity) + assert backend.eq(monoid.plus(ident, a), a) + assert backend.eq(monoid.plus(a, ident), a) @pytest.mark.parametrize("monoid", COMMUTATIVE) -@given(a=_INT, b=_INT) -@settings(max_examples=50, deadline=None) -def test_commutativity(monoid, a, b): - assert monoid.plus(a, b) == monoid.plus(b, a) +@given(data=st.data()) +@settings(max_examples=50, deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture]) +def test_commutativity(monoid, backend, data): + a = data.draw(backend.scalar_strategy) + b = data.draw(backend.scalar_strategy) + assert backend.eq(monoid.plus(a, b), monoid.plus(b, a)) @pytest.mark.parametrize("monoid", IDEMPOTENT) -@given(a=_INT) -@settings(max_examples=50, deadline=None) -def test_idempotence(monoid, a): - assert monoid.plus(a, a) == a +@given(data=st.data()) +@settings(max_examples=50, deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture]) +def test_idempotence(monoid, backend, data): + a = data.draw(backend.scalar_strategy) + assert backend.eq(monoid.plus(a, a), a) @pytest.mark.parametrize("monoid", WITH_ZERO) -@given(a=_INT) -@settings(max_examples=50, deadline=None) -def test_zero_absorbs(monoid, a): - assert monoid.plus(monoid.zero, a) == monoid.zero - assert monoid.plus(a, monoid.zero) == monoid.zero - - -def _check_pair(lhs, rhs, *, free_vars=[], max_examples: int = 25) -> None: +@given(data=st.data()) +@settings(max_examples=50, deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture]) +def test_zero_absorbs(monoid, backend, data): + a = data.draw(backend.scalar_strategy) + zero = backend.lift(monoid.zero) + assert backend.eq(monoid.plus(zero, a), monoid.zero) + assert backend.eq(monoid.plus(a, zero), monoid.zero) + + +def _check_pair( + lhs, rhs, *, backend: Backend, free_vars=[], max_examples: int = 25 +) -> None: """Run structural + semantic checks on a TermPair.""" with handler(NormalizeIntp): norm = evaluate(lhs) @@ -117,81 +140,99 @@ def _check_pair(lhs, rhs, *, free_vars=[], max_examples: int = 25) -> None: assert syntactic_eq_alpha(norm, rhs) @given(intp=random_interpretation(free_vars)) - @settings(max_examples=max_examples, deadline=None) + @settings( + max_examples=max_examples, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], + ) def _check_semantics(intp): with handler(intp): lhs_val = evaluate(lhs) rhs_val = evaluate(rhs) - assert lhs_val == rhs_val + assert backend.eq(lhs_val, rhs_val) _check_semantics() @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_empty(monoid): - _check_pair(lhs=monoid.plus(), rhs=monoid.identity) +def test_plus_empty(monoid, backend): + _check_pair(lhs=monoid.plus(), rhs=monoid.identity, backend=backend) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_single(monoid): - x = define_vars("x", typ=type(monoid.identity)) - _check_pair(lhs=monoid.plus(x()), rhs=x(), free_vars=[x]) +def test_plus_single(monoid, backend): + x = define_vars("x", typ=backend.scalar_typ) + _check_pair(lhs=monoid.plus(x()), rhs=x(), backend=backend, free_vars=[x]) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_identity_right(monoid): - x = define_vars("x", typ=type(monoid.identity)) - _check_pair(lhs=monoid.plus(x(), monoid.identity), rhs=x(), free_vars=[x]) +def test_plus_identity_right(monoid, backend): + x = define_vars("x", typ=backend.scalar_typ) + _check_pair( + lhs=monoid.plus(x(), monoid.identity), + rhs=x(), + backend=backend, + free_vars=[x], + ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_identity_left(monoid): - x = define_vars("x", typ=type(monoid.identity)) - _check_pair(lhs=monoid.plus(monoid.identity, x()), rhs=x(), free_vars=[x]) +def test_plus_identity_left(monoid, backend): + x = define_vars("x", typ=backend.scalar_typ) + _check_pair( + lhs=monoid.plus(monoid.identity, x()), + rhs=x(), + backend=backend, + free_vars=[x], + ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_assoc_right(monoid): - x, y, z = define_vars("x", "y", "z", typ=type(monoid.identity)) +def test_plus_assoc_right(monoid, backend): + x, y, z = define_vars("x", "y", "z", typ=backend.scalar_typ) _check_pair( lhs=monoid.plus(x(), monoid.plus(y(), z())), rhs=monoid.plus(x(), y(), z()), + backend=backend, free_vars=[x, y, z], ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_assoc_left(monoid): - x, y, z = define_vars("x", "y", "z", typ=type(monoid.identity)) +def test_plus_assoc_left(monoid, backend): + x, y, z = define_vars("x", "y", "z", typ=backend.scalar_typ) _check_pair( lhs=monoid.plus(monoid.plus(x(), y()), z()), rhs=monoid.plus(x(), y(), z()), + backend=backend, free_vars=[x, y, z], ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_sequence(monoid): - a, b, c, d = define_vars("a", "b", "c", "d", typ=type(monoid.identity)) +def test_plus_sequence(monoid, backend): + a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) _check_pair( lhs=monoid.plus((a(), b()), (c(), d())), rhs=(monoid.plus(a(), c()), monoid.plus(b(), d())), + backend=backend, free_vars=[a, b, c, d], ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_mapping(monoid): - a, b, c, d = define_vars("a", "b", "c", "d", typ=type(monoid.identity)) +def test_plus_mapping(monoid, backend): + a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) _check_pair( lhs=monoid.plus({0: a(), 1: b()}, {0: c(), 2: d()}), rhs={0: monoid.plus(a(), c()), 1: b(), 2: d()}, + backend=backend, free_vars=[a, b, c, d], ) -def test_plus_distributes(): - a, b, c, d = define_vars("a", "b", "c", "d") +def test_plus_distributes(backend): + a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) lhs = Product.plus(Sum.plus(a(), b()), Sum.plus(c(), d())) rhs = Sum.plus( Product.plus(a(), c()), @@ -199,11 +240,11 @@ def test_plus_distributes(): Product.plus(b(), c()), Product.plus(b(), d()), ) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[a, b, c, d]) + _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[a, b, c, d]) -def test_plus_distributes_constant(): - a, b, c, d = define_vars("a", "b", "c", "d") +def test_plus_distributes_constant(backend): + a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) lhs = Product.plus(Sum.plus(a(), b()), Sum.plus(c(), d()), 5) rhs = Product.plus( 5, @@ -214,11 +255,11 @@ def test_plus_distributes_constant(): Product.plus(b(), d()), ), ) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[a, b, c, d]) + _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[a, b, c, d]) -def test_plus_distributes_multiple(): - a, b, c, d = define_vars("a", "b", "c", "d") +def test_plus_distributes_multiple(backend): + a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) lhs = Sum.plus( Min.plus(a(), b()), Min.plus(c(), d()), @@ -239,118 +280,111 @@ def test_plus_distributes_multiple(): Sum.plus(b(), d()), ), ) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[a, b, c, d]) + _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[a, b, c, d]) @pytest.mark.parametrize("monoid", IDEMPOTENT) -def test_plus_idempotent_consecutive(monoid): +def test_plus_idempotent_consecutive(monoid, backend): """``a, a, b → a, b`` — only consecutive duplicates collapse.""" - a, b = define_vars("a", "b") + a, b = define_vars("a", "b", typ=backend.scalar_typ) lhs = monoid.plus(a(), a(), b()) - return _check_pair(lhs=lhs, rhs=monoid.plus(a(), b()), free_vars=[a, b]) + return _check_pair( + lhs=lhs, rhs=monoid.plus(a(), b()), backend=backend, free_vars=[a, b] + ) @pytest.mark.parametrize("monoid", IDEMPOTENT) -def test_plus_idempotent_non_consecutive(monoid): +def test_plus_idempotent_non_consecutive(monoid, backend): """``a, b, a`` — Semilattice (Min/Max) collapses via commutative PlusDups; plain IdempotentMonoid leaves it as-is (consecutive-only).""" - a, b = define_vars("a", "b") + a, b = define_vars("a", "b", typ=backend.scalar_typ) lhs = monoid.plus(a(), b(), a()) if is_commutative(monoid): rhs = monoid.plus(a(), b()) else: rhs = monoid.plus(a(), b(), a()) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[a, b]) + _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[a, b]) -def test_plus_commutative_idempotent_long(): +def test_plus_commutative_idempotent_long(backend): """Long alternation collapses via commutative dedup (Min/Max only).""" - a, b = define_vars("a", "b") + a, b = define_vars("a", "b", typ=backend.scalar_typ) lhs = Min.plus(a(), b(), a(), b(), b(), a(), a()) - _check_pair(lhs=lhs, rhs=Min.plus(a(), b()), free_vars=[a, b]) + _check_pair( + lhs=lhs, rhs=Min.plus(a(), b()), backend=backend, free_vars=[a, b] + ) @pytest.mark.parametrize("monoid", WITH_ZERO) -def test_plus_zero(monoid): - a = define_vars("a") +def test_plus_zero(monoid, backend): + a = define_vars("a", typ=backend.scalar_typ) lhs_right = monoid.plus(a(), monoid.zero) lhs_left = monoid.plus(monoid.zero, a()) - _check_pair(lhs=lhs_right, rhs=monoid.zero, free_vars=[a]) - _check_pair(lhs=lhs_left, rhs=monoid.zero, free_vars=[a]) + _check_pair(lhs=lhs_right, rhs=monoid.zero, backend=backend, free_vars=[a]) + _check_pair(lhs=lhs_left, rhs=monoid.zero, backend=backend, free_vars=[a]) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_partial_1(monoid): - x, y = define_vars("x", "y") +def test_partial_1(monoid, backend): + x, y = define_vars("x", "y", typ=backend.scalar_typ) lhs = monoid.reduce(x(), {x: []}) rhs = monoid.identity - _check_pair(lhs=lhs, rhs=rhs, free_vars=[x, y]) + _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[x, y]) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_partial_2(monoid): - x, y = define_vars("x", "y") - Y = define_vars("Y", typ=list[int]) +def test_partial_2(monoid, backend): + x, y = define_vars("x", "y", typ=backend.scalar_typ) + Y = define_vars("Y", typ=backend.stream_typ) lhs = monoid.reduce(x(), {y: Y(), x: []}) rhs = monoid.identity - _check_pair(lhs=lhs, rhs=rhs, free_vars=[x, y, Y]) + _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[x, y, Y]) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_partial_3(monoid): - x, y, a, b = define_vars("x", "y", "a", "b") - Y = define_vars("Y", typ=list[int]) +def test_partial_3(monoid, backend): + x, y, a, b = define_vars("x", "y", "a", "b", typ=backend.scalar_typ) + Y = define_vars("Y", typ=backend.stream_typ) lhs = monoid.reduce(x(), {y: Y(), x: [a(), b()]}) rhs = monoid.plus(monoid.reduce(a(), {y: Y()}), monoid.reduce(b(), {y: Y()})) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[x, y, a, b, Y]) + _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[x, y, a, b, Y]) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_partial_4(monoid): - x, y, a, b = define_vars("x", "y", "a", "b") - - @Operation.define - def f(_x: int) -> list[int]: - raise NotHandled +def test_partial_4(monoid, backend): + x, y, a, b = define_vars("x", "y", "a", "b", typ=backend.scalar_typ) + f = backend.fresh_op("f", n_args=1, ret="stream") lhs = monoid.reduce(x(), {y: f(x()), x: [a(), b()]}) rhs = monoid.plus(monoid.reduce(a(), {y: f(a())}), monoid.reduce(b(), {y: f(b())})) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[x, y, a, b, f]) + _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[x, y, a, b, f]) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_body_sequence(monoid): - x = Operation.define(int, name="x") - X = Operation.define(list[int], name="X") - - @Operation.define - def f(_x: int) -> int: - raise NotHandled - +def test_reduce_body_sequence(monoid, backend): + x = Operation.define(backend.scalar_typ, name="x") + X = Operation.define(backend.stream_typ, name="X") + f = backend.fresh_op("f", n_args=1, ret="scalar") g = Operation.define(f, name="g") lhs = monoid.reduce((f(x()), g(x())), {x: X()}) rhs = (monoid.reduce(f(x()), {x: X()}), monoid.reduce(g(x()), {x: X()})) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, f, g]) + _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[X, f, g]) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_body_sequence_2(monoid): - x, y = define_vars("x", "y") - X, Y = define_vars("X", "Y", typ=list[int]) - - @Operation.define - def f(_x: int) -> int: - raise NotHandled - +def test_reduce_body_sequence_2(monoid, backend): + x, y = define_vars("x", "y", typ=backend.scalar_typ) + X, Y = define_vars("X", "Y", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=1, ret="scalar") g = Operation.define(f, name="g") lhs = monoid.reduce((f(x()), g(y())), {x: X(), y: Y()}) @@ -359,18 +393,14 @@ def f(_x: int) -> int: monoid.reduce(g(y()), {x: X(), y: Y()}), ) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, Y, f, g]) + _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[X, Y, f, g]) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_body_mapping(monoid): - x = Operation.define(int, name="x") - X = Operation.define(list[int], name="X") - - @Operation.define - def f(_x: int) -> int: - raise NotHandled - +def test_reduce_body_mapping(monoid, backend): + x = Operation.define(backend.scalar_typ, name="x") + X = Operation.define(backend.stream_typ, name="X") + f = backend.fresh_op("f", n_args=1, ret="scalar") g = Operation.define(f, name="g") lhs = monoid.reduce({0: f(x()), 1: g(x())}, {x: X()}) @@ -378,82 +408,70 @@ def f(_x: int) -> int: 0: monoid.reduce(f(x()), {x: X()}), 1: monoid.reduce(g(x()), {x: X()}), } - _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, f, g]) + _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[X, f, g]) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_no_streams(monoid): - a = define_vars("a") +def test_reduce_no_streams(monoid, backend): + a = define_vars("a", typ=backend.scalar_typ) lhs = monoid.reduce(a(), {}) rhs = monoid.identity - _check_pair(lhs=lhs, rhs=rhs, free_vars=[a]) + _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[a]) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_reduce(monoid): - a, b = define_vars("a", "b") - A, B = define_vars("A", "B", typ=list[int]) - - @Operation.define - def f(_x: int, _y: int) -> int: - raise NotHandled +def test_reduce_reduce(monoid, backend): + a, b = define_vars("a", "b", typ=backend.scalar_typ) + A, B = define_vars("A", "B", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=2, ret="scalar") lhs = monoid.reduce(monoid.reduce(f(a(), b()), {a: A()}), {b: B()}) rhs = monoid.reduce(f(a(), b()), {a: A(), b: B()}) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B, f]) + _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[A, B, f]) @pytest.mark.parametrize("monoid", COMMUTATIVE) -def test_reduce_plus(monoid): - a, b = define_vars("a", "b") - A, B = define_vars("A", "B", typ=list[int]) +def test_reduce_plus(monoid, backend): + a, b = define_vars("a", "b", typ=backend.scalar_typ) + A, B = define_vars("A", "B", typ=backend.stream_typ) lhs = monoid.reduce(monoid.plus(a(), b()), {a: A(), b: B()}) rhs = monoid.plus( monoid.reduce(a(), {a: A(), b: B()}), monoid.reduce(b(), {a: A(), b: B()}), ) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B]) + _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[A, B]) -def test_reduce_independent_1(): - a, b = define_vars("a", "b") - A, B = define_vars("A", "B", typ=list[int]) +def test_reduce_independent_1(backend): + a, b = define_vars("a", "b", typ=backend.scalar_typ) + A, B = define_vars("A", "B", typ=backend.stream_typ) lhs = Sum.reduce(Product.plus(a(), b()), {a: A(), b: B()}) rhs = Product.plus(Sum.reduce(a(), {a: A()}), Sum.reduce(b(), {b: B()})) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B]) + _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[A, B]) -def test_reduce_independent_2(): - a, b, c = define_vars("a", "b", "c") - A, B, C = define_vars("A", "B", "C", typ=list[int]) - - @Operation.define - def f(_x: int, _y: int) -> int: - raise NotHandled +def test_reduce_independent_2(backend): + a, b, c = define_vars("a", "b", "c", typ=backend.scalar_typ) + A, B, C = define_vars("A", "B", "C", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=2, ret="scalar") lhs = Sum.reduce(Product.plus(a(), b(), f(b(), c())), {a: A(), b: B(), c: C()}) rhs = Product.plus( Sum.reduce(a(), {a: A()}), Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), ) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B, C, f]) + _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[A, B, C, f]) -def test_reduce_independent_3_negative(): +def test_reduce_independent_3_negative(backend): """Stream `b` depends on `a` (b: g(a())), so the proposed factorization is unsound — the normalizer must NOT apply it.""" - a, b, c = define_vars("a", "b", "c") - A, C = define_vars("A", "C", typ=list[int]) - - @Operation.define - def f(_x: int, _y: int) -> int: - raise NotHandled - - @Operation.define - def g(_x: int) -> list[int]: - raise NotHandled + a, b, c = define_vars("a", "b", "c", typ=backend.scalar_typ) + A, C = define_vars("A", "C", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=2, ret="scalar") + g = backend.fresh_op("g", n_args=1, ret="stream") with handler(NormalizeIntp): lhs = Sum.reduce( @@ -467,13 +485,10 @@ def g(_x: int) -> list[int]: assert not syntactic_eq_alpha(lhs, bogus_rhs) -def test_reduce_independent_4(): - a, b, c = define_vars("a", "b", "c") - A, B, C = define_vars("A", "B", "C", typ=list[int]) - - @Operation.define - def f(_x: int, _y: int) -> int: - raise NotHandled +def test_reduce_independent_4(backend): + a, b, c = define_vars("a", "b", "c", typ=backend.scalar_typ) + A, B, C = define_vars("A", "B", "C", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=2, ret="scalar") lhs = Sum.reduce(Product.plus(a(), b(), f(b(), c()), 7), {a: A(), b: B(), c: C()}) rhs = Product.plus( @@ -481,29 +496,26 @@ def f(_x: int, _y: int) -> int: Sum.reduce(a(), {a: A()}), Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), ) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B, C, f]) + _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[A, B, C, f]) @pytest.mark.parametrize("outer,inner", MONOID_PAIRS) -def test_reduce_lifted_1(outer, inner): - a, i = define_vars("a", "i") - A, N, A_domain = define_vars("A", "N", "A_domain", typ=list[int]) - - @Operation.define - def f(_: int) -> float: - raise NotHandled +def test_reduce_lifted_1(outer, inner, backend): + a, i = define_vars("a", "i", typ=backend.scalar_typ) + A, N, A_domain = define_vars("A", "N", "A_domain", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=1, ret="scalar") term1 = outer.reduce( inner.reduce(f(a()), {a: A()}), {A: CartesianProduct.reduce(A_domain(), {i: N()})}, ) term2 = inner.reduce(outer.reduce(f(a()), {a: A_domain()}), {i: N()}) - _check_pair(lhs=term1, rhs=term2, free_vars=[N, A_domain, f]) + _check_pair(lhs=term1, rhs=term2, backend=backend, free_vars=[N, A_domain, f]) -def test_reduce_cartesian_1(): - a, i = define_vars("a", "i") - A = define_vars("A", typ=list[int]) +def test_reduce_cartesian_1(backend): + a, i = define_vars("a", "i", typ=backend.scalar_typ) + A = define_vars("A", typ=backend.stream_typ) term1 = Sum.reduce( Product.reduce(a(), {a: []}), @@ -513,9 +525,9 @@ def test_reduce_cartesian_1(): assert term1 == term2 -def test_reduce_cartesian_2(): - a, i = define_vars("a", "i") - A = define_vars("A", typ=list[int]) +def test_reduce_cartesian_2(backend): + a, i = define_vars("a", "i", typ=backend.scalar_typ) + A = define_vars("A", typ=backend.stream_typ) term1 = Sum.reduce( Product.reduce(a(), {a: A()}), @@ -526,13 +538,12 @@ def test_reduce_cartesian_2(): @pytest.mark.parametrize("outer,inner", MONOID_PAIRS) -def test_reduce_lifted_multi_index(outer, inner): - a, i, j = define_vars("a", "i", "j") - A, N, M, A_domain = define_vars("A", "N", "M", "A_domain", typ=list[int]) - - @Operation.define - def f(_: int) -> float: - raise NotHandled +def test_reduce_lifted_multi_index(outer, inner, backend): + a, i, j = define_vars("a", "i", "j", typ=backend.scalar_typ) + A, N, M, A_domain = define_vars( + "A", "N", "M", "A_domain", typ=backend.stream_typ + ) + f = backend.fresh_op("f", n_args=1, ret="scalar") term1 = outer.reduce( inner.reduce(f(a()), {a: A()}), @@ -542,29 +553,22 @@ def f(_: int) -> float: outer.reduce(f(a()), {a: A_domain()}), {i: N(), j: M()}, ) - _check_pair(lhs=term1, rhs=term2, free_vars=[N, M, A_domain, f]) + _check_pair( + lhs=term1, rhs=term2, backend=backend, free_vars=[N, M, A_domain, f] + ) @pytest.mark.parametrize("outer,inner", MONOID_PAIRS) -def test_reduce_lifted_2(outer, inner): +def test_reduce_lifted_2(outer, inner, backend): """The worked example on page 396 of 'Lifted Variable Elimination: Decoupling the Operators from the Constraint Language'. """ - a, i, s, t = define_vars("a", "i", "s", "t") - A, N, T = define_vars("A", "N", "T", typ=list[int]) - - @Operation.define - def A_domain(_i: int) -> list[int]: - raise NotHandled - - @Operation.define - def f1(_a: int, _s: int) -> float: - raise NotHandled - - @Operation.define - def f2(_t: int, _a: int) -> float: - raise NotHandled + a, i, s, t = define_vars("a", "i", "s", "t", typ=backend.scalar_typ) + A, N, T = define_vars("A", "N", "T", typ=backend.stream_typ) + A_domain = backend.fresh_op("A_domain", n_args=1, ret="stream") + f1 = backend.fresh_op("f1", n_args=2, ret="scalar") + f2 = backend.fresh_op("f2", n_args=2, ret="scalar") term1 = outer.reduce( inner.reduce(inner.plus(f1(a(), s()), f2(t(), a())), {a: A()}), @@ -579,4 +583,9 @@ def f2(_t: int, _a: int) -> float: {t: T()}, ) - _check_pair(lhs=term1, rhs=term2, free_vars=[a, i, s, t, A, N, T, A_domain, f1, f2]) + _check_pair( + lhs=term1, + rhs=term2, + backend=backend, + free_vars=[a, i, s, t, A, N, T, A_domain, f1, f2], + ) From 6996a8cf2d547fac11934dea3d4fb8943c3ce927 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Thu, 14 May 2026 12:04:10 -0400 Subject: [PATCH 28/34] wip --- effectful/handlers/jax/_handlers.py | 7 ++ effectful/handlers/jax/monoid.py | 39 +++----- effectful/ops/monoid.py | 132 ++++++++++++---------------- tests/_monoid_helpers.py | 27 ++---- tests/test_ops_monoid.py | 86 +++++++++--------- 5 files changed, 128 insertions(+), 163 deletions(-) diff --git a/effectful/handlers/jax/_handlers.py b/effectful/handlers/jax/_handlers.py index 308cdb76e..c5d104233 100644 --- a/effectful/handlers/jax/_handlers.py +++ b/effectful/handlers/jax/_handlers.py @@ -19,6 +19,7 @@ deffn, defop, syntactic_eq, + syntactic_hash, ) from effectful.ops.types import Expr, NotHandled, Operation, Term @@ -277,3 +278,9 @@ def _(x: jax.Array, other) -> bool: and x.shape == other.shape and bool((jnp.asarray(x) == jnp.asarray(other)).all()) ) + + +@syntactic_hash.register(jax.Array) +def _(x: jax.Array) -> int: + # Concrete arrays aren't hashable; hash by shape, dtype, and bytes. + return hash(("jax.Array", x.shape, str(x.dtype), bytes(jax.numpy.asarray(x)))) diff --git a/effectful/handlers/jax/monoid.py b/effectful/handlers/jax/monoid.py index 14be90f36..378207e60 100644 --- a/effectful/handlers/jax/monoid.py +++ b/effectful/handlers/jax/monoid.py @@ -8,24 +8,20 @@ from effectful.handlers.jax.scipy.special import logsumexp from effectful.ops.monoid import ( CartesianProduct, + EvaluateIntp, Max, Min, Monoid, + NormalizeIntp, Product, Sum, outer_stream, ) from effectful.ops.semantics import coproduct, evaluate, fwd, handler, typeof -from effectful.ops.syntax import ObjectInterpretation, deffn, implements, syntactic_hash +from effectful.ops.syntax import ObjectInterpretation, deffn, implements from effectful.ops.types import Interpretation, Operation, Term -@syntactic_hash.register(jax.Array) -def _(x: jax.Array) -> int: - # Concrete arrays aren't hashable; hash by shape, dtype, and bytes. - return hash(("jax.Array", x.shape, str(x.dtype), bytes(jax.numpy.asarray(x)))) - - def cartesian_prod(x, y): if x.ndim == 1: x = x[:, None] @@ -42,9 +38,7 @@ def _jax_args(args): """True iff ``args`` is non-empty and every arg is a concrete :class:`jax.Array` (no Terms). """ - return bool(args) and all( - isinstance(a, jax.Array) and not isinstance(a, Term) for a in args - ) + return bool(args) and all(not isinstance(a, Term) for a in args) class SumPlusJax(ObjectInterpretation): @@ -142,21 +136,12 @@ def reduce(self, monoid, body, streams): return fwd() -JaxEvaluateIntp = functools.reduce( - coproduct, - typing.cast( - list[Interpretation], - [ - SumPlusJax(), - ProductPlusJax(), - MinPlusJax(), - MaxPlusJax(), - LogSumExpPlusJax(), - CartesianProductPlusJax(), - ArrayReduce(), - ], - ), +NormalizeIntp.extend(ArrayReduce()) +EvaluateIntp.extend( + SumPlusJax(), + ProductPlusJax(), + MinPlusJax(), + MaxPlusJax(), + LogSumExpPlusJax(), + CartesianProductPlusJax(), ) -"""JAX kernels for plus and reduce. Composes with -:data:`effectful.ops.monoid.EvaluateIntp` to extend evaluation to JAX arrays. -""" diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index ac037c3aa..541d4b063 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -3,7 +3,7 @@ import itertools import operator import typing -from collections import Counter, defaultdict +from collections import Counter, UserDict, defaultdict from collections.abc import Callable, Generator, Iterable, Mapping from dataclasses import dataclass from graphlib import TopologicalSorter @@ -53,13 +53,7 @@ def outer_stream( class Monoid[T]: - """A monoid with ``plus`` and ``reduce`` :class:`Operation` s. - - Behavior is supplied by handlers installed via an interpretation - (see :data:`NormalizeIntp`). The default rules handle only the trivial - cases (empty plus, empty streams, Term args) — everything else stays - symbolic until a handler interprets it. - """ + """A monoid with ``plus`` and ``reduce`` :class:`Operation` s.""" _name: str identity: T @@ -77,11 +71,15 @@ def __eq__(self, other): def __hash__(self): return hash(id(self)) + # the weak typing allows us to write monoid.plus(monoid.identity, ) + # and monoid.plus(monoid.identity, ) @Operation.define - def plus[S](self, *args: S) -> S: + def plus(self, *args: Any) -> Any: """Monoid addition. Handlers supply per-monoid and broadcasting behavior; the default rule only handles empty / Term cases. """ + if not args: + return self.identity raise NotHandled @Operation.define @@ -557,7 +555,7 @@ def reduce(self, sum_monoid: Monoid, sum_body, sum_streams): return fwd() -class ReduceOverCallable(ObjectInterpretation): +class MonoidOverCallable(ObjectInterpretation): """``monoid.reduce(f, streams) = lambda *a: monoid.reduce(f(*a), streams)``.""" @implements(Monoid.reduce) @@ -566,8 +564,16 @@ def reduce(self, monoid, body, streams): return fwd() return lambda *a, **k: monoid.reduce(body(*a, **k), streams) + @implements(Monoid.plus) + def plus(self, monoid, *args): + if not args or any( + isinstance(arg, Term) or not isinstance(arg, Callable) for arg in args + ): + return fwd() + return lambda *a, **k: monoid.plus(*(arg(*a, **k) for arg in args)) + -class ReduceOverMapping(ObjectInterpretation): +class MonoidOverMapping(ObjectInterpretation): """``monoid.reduce({k: v_k}, streams) = {k: monoid.reduce(v_k, streams)}``.""" @implements(Monoid.reduce) @@ -576,10 +582,6 @@ def reduce(self, monoid, body, streams): return fwd() return {k: monoid.reduce(v, streams) for (k, v) in body.items()} - -class PlusOverMapping(ObjectInterpretation): - """Broadcast ``plus`` over dict-like containers and interpretations.""" - @implements(Monoid.plus) def plus(self, monoid, *args): if not args or not isinstance(args[0], Mapping): @@ -733,74 +735,54 @@ def reduce(self, body, streams): return type(body)(result) return result - _SequenceBroadcasting.__name__ = f"{monoid._name}SequenceBroadcasting" + _SequenceBroadcasting.__name__ = f"{monoid._name}OverSequence" return _SequenceBroadcasting() -EvaluateIntp = functools.reduce( - coproduct, - typing.cast( - list[Interpretation], - [ - # tuple/list/Generator broadcasting, only for monoids whose values - # are scalars (not CartesianProduct, ArgMin, or ArgMax). - *(sequence_broadcasting(m) for m in (Sum, Min, Max, Product)), - # universal broadcasting - PlusOverMapping(), - ReduceOverCallable(), - ReduceOverMapping(), - # per-monoid concrete kernels - SumPlus(), - MinPlus(), - MaxPlus(), - ProductPlus(), - ArgMinPlus(), - ArgMaxPlus(), - CartesianProductPlus(), - ], - ), +class _ExtensibleInterpretation(UserDict, Interpretation): + def extend(self, *intps: Interpretation) -> typing.Self: + for intp in intps: + self.data = coproduct(self.data, intp) + return self + + +EvaluateIntp = _ExtensibleInterpretation().extend( + # tuple/list/Generator broadcasting, only for monoids whose values + # are scalars (not CartesianProduct, ArgMin, or ArgMax). + *(sequence_broadcasting(m) for m in (Sum, Min, Max, Product)), + # universal broadcasting + MonoidOverMapping(), + MonoidOverCallable(), + # per-monoid concrete kernels + SumPlus(), + MinPlus(), + MaxPlus(), + ProductPlus(), + ArgMinPlus(), + ArgMaxPlus(), + CartesianProductPlus(), ) """Concrete-value kernels and universal broadcasting. Install to evaluate plus/reduce on concrete (non-Term) values without applying any rewrites. -""" -NormalizeReduceIntp = functools.reduce( - coproduct, - typing.cast( - list[Interpretation], - [ - ReduceNoStreams(), - ReduceFusion(), - ReduceSplit(), - ReduceFactorization(), - ReduceDistributeCartesianProduct(), - ], - ), -) +""" -NormalizePlusIntp = functools.reduce( - coproduct, - typing.cast( - list[Interpretation], - [ - PlusEmpty(), - PlusSingle(), - PlusIdentity(), - PlusAssoc(), - PlusDistr(), - PlusZero(), - PlusConsecutiveDups(), - PlusDups(), - ], - ), +NormalizeIntp = _ExtensibleInterpretation().extend( + ReduceNoStreams(), + ReduceFusion(), + ReduceSplit(), + ReduceFactorization(), + ReduceDistributeCartesianProduct(), + PlusEmpty(), + PlusSingle(), + PlusIdentity(), + PlusAssoc(), + PlusDistr(), + PlusZero(), + PlusConsecutiveDups(), + PlusDups(), ) +"""``NormalizeIntp``applies pure-Term rewrites (associativity, distributivity, +identity elimination, fusion, factorization, etc.). - -NormalizeIntp = coproduct( - coproduct(EvaluateIntp, NormalizeReduceIntp), NormalizePlusIntp -) -"""Rewrites *plus* evaluation. ``NormalizeIntp`` is a superset of -:data:`EvaluateIntp`: it applies pure-Term rewrites (associativity, -distributivity, identity elimination, fusion, factorization, etc.) and also -evaluates concrete arguments through the kernels. """ diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py index fd785c0af..b51a74dfe 100644 --- a/tests/_monoid_helpers.py +++ b/tests/_monoid_helpers.py @@ -10,7 +10,7 @@ from effectful.internals.runtime import interpreter from effectful.ops.semantics import apply, evaluate from effectful.ops.syntax import _BaseTerm, defdata, deffn, syntactic_eq -from effectful.ops.types import NotHandled, Operation +from effectful.ops.types import NotHandled, Operation, Term _JAX_ARRAY_SHAPE = (3,) @@ -18,7 +18,7 @@ def _jax_array_value_strategy() -> st.SearchStrategy[jax.Array]: return st.integers(min_value=0, max_value=2**31 - 1).map( lambda seed: jax.random.uniform( - jax.random.PRNGKey(seed), _JAX_ARRAY_SHAPE, minval=0.5, maxval=1.5 + jax.random.PRNGKey(seed), _JAX_ARRAY_SHAPE, minval=-1.5, maxval=1.5 ) ) @@ -41,11 +41,11 @@ def _jax_array_value_strategy() -> st.SearchStrategy[jax.Array]: def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: """Strategy for the value an *0-arg* Operation should return.""" if annotation is int: - return st.integers() + return st.integers(min_value=-100, max_value=100) if annotation is float: return st.floats(allow_nan=False) if get_origin(annotation) is list and get_args(annotation) == (int,): - return st.lists(st.integers(), max_size=2) + return st.lists(st.integers(min_value=-100, max_value=100), max_size=2) if annotation is jax.Array: return _jax_array_value_strategy() if get_origin(annotation) is list and get_args(annotation) == (jax.Array,): @@ -262,7 +262,6 @@ class Backend: stream_typ: Any scalar_strategy: st.SearchStrategy[Any] eq: Callable[[Any, Any], bool] - lift: Callable[[Any], Any] def fresh_op(self, name: str, n_args: int = 1, ret: str = "scalar") -> Operation: """Build a fresh, unhandled Operation whose parameter and return @@ -286,23 +285,11 @@ def fresh_op(self, name: str, n_args: int = 1, ret: str = "scalar") -> Operation def _int_eq(a: Any, b: Any) -> bool: - return a == b + return not isinstance(a, Term) and not isinstance(b, Term) and a == b def _jax_eq(a: Any, b: Any) -> bool: - if isinstance(a, dict) and isinstance(b, dict): - if set(a.keys()) != set(b.keys()): - return False - return all(_jax_eq(a[k], b[k]) for k in a) - if isinstance(a, tuple | list) and isinstance(b, tuple | list): - if len(a) != len(b): - return False - return all(_jax_eq(x, y) for x, y in zip(a, b, strict=True)) - if isinstance(a, jax.Array) or isinstance(b, jax.Array): - aa, bb = jax.numpy.asarray(a), jax.numpy.asarray(b) - aa, bb = jax.numpy.broadcast_arrays(aa, bb) - return bool(jax.numpy.all(jax.numpy.isclose(aa, bb, equal_nan=True))) - return a == b + return bool(jax.numpy.allclose(a, b)) INT_BACKEND = Backend( @@ -311,7 +298,6 @@ def _jax_eq(a: Any, b: Any) -> bool: stream_typ=list[int], scalar_strategy=st.integers(min_value=-100, max_value=100), eq=_int_eq, - lift=lambda x: x, ) @@ -321,7 +307,6 @@ def _jax_eq(a: Any, b: Any) -> bool: stream_typ=list[jax.Array], scalar_strategy=_jax_array_value_strategy(), eq=_jax_eq, - lift=lambda x: jax.numpy.asarray(x, dtype=jax.numpy.float32), ) diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index 1c1931fbb..1e37c0b2b 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -4,9 +4,10 @@ from hypothesis import HealthCheck, given, settings from hypothesis import strategies as st -from effectful.handlers.jax.monoid import JaxEvaluateIntp +import effectful.handlers.jax.monoid # noqa: F401 from effectful.ops.monoid import ( CartesianProduct, + EvaluateIntp, Max, Min, Monoid, @@ -16,7 +17,7 @@ distributes_over, is_commutative, ) -from effectful.ops.semantics import coproduct, evaluate, fvsof, handler +from effectful.ops.semantics import evaluate, fvsof, handler from effectful.ops.types import Operation from tests._monoid_helpers import ( INT_BACKEND, @@ -33,18 +34,6 @@ def backend(request) -> Backend: return request.param -@pytest.fixture(autouse=True) -def _install_normalize_intp(backend): - """Install :data:`NormalizeIntp` (plus JAX kernels when the backend is - jax) for every test in this module. - """ - intp = NormalizeIntp - if backend.scalar_typ is not int: - intp = coproduct(intp, JaxEvaluateIntp) - with handler(intp): - yield - - ALL_MONOIDS = [ pytest.param(Sum, id="Sum"), pytest.param(Product, id="Product"), @@ -83,51 +72,74 @@ def _install_normalize_intp(backend): @pytest.mark.parametrize("monoid", ALL_MONOIDS) @given(data=st.data()) -@settings(max_examples=50, deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture]) +@settings( + max_examples=50, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], +) def test_associativity(monoid, backend, data): a = data.draw(backend.scalar_strategy) b = data.draw(backend.scalar_strategy) c = data.draw(backend.scalar_strategy) - left = monoid.plus(monoid.plus(a, b), c) - right = monoid.plus(a, monoid.plus(b, c)) - assert backend.eq(left, right) + with handler(EvaluateIntp): + left = monoid.plus(monoid.plus(a, b), c) + right = monoid.plus(a, monoid.plus(b, c)) + assert backend.eq(left, right) @pytest.mark.parametrize("monoid", ALL_MONOIDS) @given(data=st.data()) -@settings(max_examples=50, deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture]) +@settings( + max_examples=50, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], +) def test_identity(monoid, backend, data): a = data.draw(backend.scalar_strategy) - ident = backend.lift(monoid.identity) - assert backend.eq(monoid.plus(ident, a), a) - assert backend.eq(monoid.plus(a, ident), a) + with handler(EvaluateIntp): + assert backend.eq(monoid.plus(monoid.identity, a), a) + assert backend.eq(monoid.plus(a, monoid.identity), a) @pytest.mark.parametrize("monoid", COMMUTATIVE) @given(data=st.data()) -@settings(max_examples=50, deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture]) +@settings( + max_examples=50, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], +) def test_commutativity(monoid, backend, data): a = data.draw(backend.scalar_strategy) b = data.draw(backend.scalar_strategy) - assert backend.eq(monoid.plus(a, b), monoid.plus(b, a)) + with handler(EvaluateIntp): + assert backend.eq(monoid.plus(a, b), monoid.plus(b, a)) @pytest.mark.parametrize("monoid", IDEMPOTENT) @given(data=st.data()) -@settings(max_examples=50, deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture]) +@settings( + max_examples=50, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], +) def test_idempotence(monoid, backend, data): a = data.draw(backend.scalar_strategy) - assert backend.eq(monoid.plus(a, a), a) + with handler(EvaluateIntp): + assert backend.eq(monoid.plus(a, a), a) @pytest.mark.parametrize("monoid", WITH_ZERO) @given(data=st.data()) -@settings(max_examples=50, deadline=None, suppress_health_check=[HealthCheck.function_scoped_fixture]) +@settings( + max_examples=50, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], +) def test_zero_absorbs(monoid, backend, data): a = data.draw(backend.scalar_strategy) - zero = backend.lift(monoid.zero) - assert backend.eq(monoid.plus(zero, a), monoid.zero) - assert backend.eq(monoid.plus(a, zero), monoid.zero) + with handler(EvaluateIntp): + assert backend.eq(monoid.plus(monoid.zero, a), monoid.zero) + assert backend.eq(monoid.plus(a, monoid.zero), monoid.zero) def _check_pair( @@ -146,7 +158,7 @@ def _check_pair( suppress_health_check=[HealthCheck.function_scoped_fixture], ) def _check_semantics(intp): - with handler(intp): + with handler(EvaluateIntp), handler(intp): lhs_val = evaluate(lhs) rhs_val = evaluate(rhs) assert backend.eq(lhs_val, rhs_val) @@ -310,9 +322,7 @@ def test_plus_commutative_idempotent_long(backend): """Long alternation collapses via commutative dedup (Min/Max only).""" a, b = define_vars("a", "b", typ=backend.scalar_typ) lhs = Min.plus(a(), b(), a(), b(), b(), a(), a()) - _check_pair( - lhs=lhs, rhs=Min.plus(a(), b()), backend=backend, free_vars=[a, b] - ) + _check_pair(lhs=lhs, rhs=Min.plus(a(), b()), backend=backend, free_vars=[a, b]) @pytest.mark.parametrize("monoid", WITH_ZERO) @@ -540,9 +550,7 @@ def test_reduce_cartesian_2(backend): @pytest.mark.parametrize("outer,inner", MONOID_PAIRS) def test_reduce_lifted_multi_index(outer, inner, backend): a, i, j = define_vars("a", "i", "j", typ=backend.scalar_typ) - A, N, M, A_domain = define_vars( - "A", "N", "M", "A_domain", typ=backend.stream_typ - ) + A, N, M, A_domain = define_vars("A", "N", "M", "A_domain", typ=backend.stream_typ) f = backend.fresh_op("f", n_args=1, ret="scalar") term1 = outer.reduce( @@ -553,9 +561,7 @@ def test_reduce_lifted_multi_index(outer, inner, backend): outer.reduce(f(a()), {a: A_domain()}), {i: N(), j: M()}, ) - _check_pair( - lhs=term1, rhs=term2, backend=backend, free_vars=[N, M, A_domain, f] - ) + _check_pair(lhs=term1, rhs=term2, backend=backend, free_vars=[N, M, A_domain, f]) @pytest.mark.parametrize("outer,inner", MONOID_PAIRS) From afea06c6fb0d0906032807721c8dd9b857a6d358 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Thu, 14 May 2026 12:37:00 -0400 Subject: [PATCH 29/34] wip --- effectful/handlers/jax/monoid.py | 16 +- effectful/ops/monoid.py | 83 ++++------ tests/test_ops_monoid.py | 266 +++++++++++++++++++++---------- 3 files changed, 222 insertions(+), 143 deletions(-) diff --git a/effectful/handlers/jax/monoid.py b/effectful/handlers/jax/monoid.py index 378207e60..622283eb4 100644 --- a/effectful/handlers/jax/monoid.py +++ b/effectful/handlers/jax/monoid.py @@ -1,5 +1,4 @@ import functools -import typing import jax @@ -8,7 +7,6 @@ from effectful.handlers.jax.scipy.special import logsumexp from effectful.ops.monoid import ( CartesianProduct, - EvaluateIntp, Max, Min, Monoid, @@ -17,9 +15,9 @@ Sum, outer_stream, ) -from effectful.ops.semantics import coproduct, evaluate, fwd, handler, typeof +from effectful.ops.semantics import evaluate, fwd, handler, typeof from effectful.ops.syntax import ObjectInterpretation, deffn, implements -from effectful.ops.types import Interpretation, Operation, Term +from effectful.ops.types import Operation def cartesian_prod(x, y): @@ -38,7 +36,11 @@ def _jax_args(args): """True iff ``args`` is non-empty and every arg is a concrete :class:`jax.Array` (no Terms). """ - return bool(args) and all(not isinstance(a, Term) for a in args) + return ( + bool(args) + and any(isinstance(a, jax.Array) for a in args) + and all(isinstance(a, jax.typing.ArrayLike) for a in args) + ) class SumPlusJax(ObjectInterpretation): @@ -136,8 +138,8 @@ def reduce(self, monoid, body, streams): return fwd() -NormalizeIntp.extend(ArrayReduce()) -EvaluateIntp.extend( +NormalizeIntp.extend( + ArrayReduce(), SumPlusJax(), ProductPlusJax(), MinPlusJax(), diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 541d4b063..6aa7453fc 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -705,38 +705,32 @@ def to_tuple(x): ] -def sequence_broadcasting(monoid: "Monoid") -> ObjectInterpretation: - """Return an :class:`ObjectInterpretation` that broadcasts ``monoid.plus`` - and ``monoid.reduce`` elementwise over tuples, lists, and generators. - - Tuples and lists are reconstructed as their input type; generators stay - generators. Appropriate for monoids whose values are scalars; *not* for - monoids whose values *are* sequences (:data:`CartesianProduct`) or whose - tuples carry meaning (:data:`ArgMin` / :data:`ArgMax`). - """ +is_scalar = _ExtensiblePredicate({Min, Max, Sum, Product}) - class _SequenceBroadcasting(ObjectInterpretation): - @implements(monoid.plus) - def plus(self, *args): - if not args or not isinstance(args[0], tuple | list | Generator): - return fwd() - zipped = zip(*args, strict=True) - result = (monoid.plus(*vs) for vs in zipped) - if isinstance(args[0], tuple | list): - return type(args[0])(result) - return result - - @implements(monoid.reduce) - def reduce(self, body, streams): - if not isinstance(body, tuple | list | Generator): - return fwd() - result = (monoid.reduce(x, streams) for x in body) - if isinstance(body, tuple | list): - return type(body)(result) - return result - _SequenceBroadcasting.__name__ = f"{monoid._name}OverSequence" - return _SequenceBroadcasting() +class MonoidOverSequence(ObjectInterpretation): + @implements(Monoid.plus) + def plus(self, monoid, *args): + if ( + not is_scalar(monoid) + or not args + or not isinstance(args[0], tuple | list | Generator) + ): + return fwd() + zipped = zip(*args, strict=True) + result = (monoid.plus(*vs) for vs in zipped) + if isinstance(args[0], tuple | list): + return type(args[0])(result) + return result + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if not is_scalar(monoid) or not isinstance(body, tuple | list | Generator): + return fwd() + result = (monoid.reduce(x, streams) for x in body) + if isinstance(body, tuple | list): + return type(body)(result) + return result class _ExtensibleInterpretation(UserDict, Interpretation): @@ -746,28 +740,10 @@ def extend(self, *intps: Interpretation) -> typing.Self: return self -EvaluateIntp = _ExtensibleInterpretation().extend( - # tuple/list/Generator broadcasting, only for monoids whose values - # are scalars (not CartesianProduct, ArgMin, or ArgMax). - *(sequence_broadcasting(m) for m in (Sum, Min, Max, Product)), - # universal broadcasting +NormalizeIntp = _ExtensibleInterpretation().extend( + MonoidOverSequence(), MonoidOverMapping(), MonoidOverCallable(), - # per-monoid concrete kernels - SumPlus(), - MinPlus(), - MaxPlus(), - ProductPlus(), - ArgMinPlus(), - ArgMaxPlus(), - CartesianProductPlus(), -) -"""Concrete-value kernels and universal broadcasting. Install to evaluate -plus/reduce on concrete (non-Term) values without applying any rewrites. - -""" - -NormalizeIntp = _ExtensibleInterpretation().extend( ReduceNoStreams(), ReduceFusion(), ReduceSplit(), @@ -781,6 +757,13 @@ def extend(self, *intps: Interpretation) -> typing.Self: PlusZero(), PlusConsecutiveDups(), PlusDups(), + SumPlus(), + MinPlus(), + MaxPlus(), + ProductPlus(), + ArgMinPlus(), + ArgMaxPlus(), + CartesianProductPlus(), ) """``NormalizeIntp``applies pure-Term rewrites (associativity, distributivity, identity elimination, fusion, factorization, etc.). diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index 1e37c0b2b..ce0a3b22d 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -7,12 +7,26 @@ import effectful.handlers.jax.monoid # noqa: F401 from effectful.ops.monoid import ( CartesianProduct, - EvaluateIntp, Max, Min, Monoid, + MonoidOverMapping, + MonoidOverSequence, NormalizeIntp, + PlusAssoc, + PlusConsecutiveDups, + PlusDistr, + PlusDups, + PlusEmpty, + PlusIdentity, + PlusSingle, + PlusZero, Product, + ReduceDistributeCartesianProduct, + ReduceFactorization, + ReduceFusion, + ReduceNoStreams, + ReduceSplit, Sum, distributes_over, is_commutative, @@ -70,6 +84,28 @@ def backend(request) -> Backend: ] +def _check_rewrite( + lhs, rhs, rule, *, backend: Backend, free_vars=[], max_examples: int = 25 +) -> None: + with handler(rule): + norm = evaluate(lhs) + assert syntactic_eq_alpha(norm, rhs) + + @given(intp=random_interpretation(free_vars)) + @settings( + max_examples=max_examples, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], + ) + def _check_semantics(intp): + with handler(NormalizeIntp), handler(intp): + lhs_val = evaluate(lhs) + rhs_val = evaluate(rhs) + assert backend.eq(lhs_val, rhs_val) + + _check_semantics() + + @pytest.mark.parametrize("monoid", ALL_MONOIDS) @given(data=st.data()) @settings( @@ -81,7 +117,7 @@ def test_associativity(monoid, backend, data): a = data.draw(backend.scalar_strategy) b = data.draw(backend.scalar_strategy) c = data.draw(backend.scalar_strategy) - with handler(EvaluateIntp): + with handler(NormalizeIntp): left = monoid.plus(monoid.plus(a, b), c) right = monoid.plus(a, monoid.plus(b, c)) assert backend.eq(left, right) @@ -96,7 +132,7 @@ def test_associativity(monoid, backend, data): ) def test_identity(monoid, backend, data): a = data.draw(backend.scalar_strategy) - with handler(EvaluateIntp): + with handler(NormalizeIntp): assert backend.eq(monoid.plus(monoid.identity, a), a) assert backend.eq(monoid.plus(a, monoid.identity), a) @@ -111,7 +147,7 @@ def test_identity(monoid, backend, data): def test_commutativity(monoid, backend, data): a = data.draw(backend.scalar_strategy) b = data.draw(backend.scalar_strategy) - with handler(EvaluateIntp): + with handler(NormalizeIntp): assert backend.eq(monoid.plus(a, b), monoid.plus(b, a)) @@ -124,7 +160,7 @@ def test_commutativity(monoid, backend, data): ) def test_idempotence(monoid, backend, data): a = data.draw(backend.scalar_strategy) - with handler(EvaluateIntp): + with handler(NormalizeIntp): assert backend.eq(monoid.plus(a, a), a) @@ -137,74 +173,57 @@ def test_idempotence(monoid, backend, data): ) def test_zero_absorbs(monoid, backend, data): a = data.draw(backend.scalar_strategy) - with handler(EvaluateIntp): + with handler(NormalizeIntp): assert backend.eq(monoid.plus(monoid.zero, a), monoid.zero) assert backend.eq(monoid.plus(a, monoid.zero), monoid.zero) -def _check_pair( - lhs, rhs, *, backend: Backend, free_vars=[], max_examples: int = 25 -) -> None: - """Run structural + semantic checks on a TermPair.""" - with handler(NormalizeIntp): - norm = evaluate(lhs) - - assert syntactic_eq_alpha(norm, rhs) - - @given(intp=random_interpretation(free_vars)) - @settings( - max_examples=max_examples, - deadline=None, - suppress_health_check=[HealthCheck.function_scoped_fixture], - ) - def _check_semantics(intp): - with handler(EvaluateIntp), handler(intp): - lhs_val = evaluate(lhs) - rhs_val = evaluate(rhs) - assert backend.eq(lhs_val, rhs_val) - - _check_semantics() - - @pytest.mark.parametrize("monoid", ALL_MONOIDS) def test_plus_empty(monoid, backend): - _check_pair(lhs=monoid.plus(), rhs=monoid.identity, backend=backend) + _check_rewrite( + lhs=monoid.plus(), rhs=monoid.identity, rule=PlusEmpty(), backend=backend + ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) def test_plus_single(monoid, backend): x = define_vars("x", typ=backend.scalar_typ) - _check_pair(lhs=monoid.plus(x()), rhs=x(), backend=backend, free_vars=[x]) + _check_rewrite( + lhs=monoid.plus(x()), rhs=x(), rule=PlusSingle(), backend=backend, free_vars=[x] + ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) def test_plus_identity_right(monoid, backend): x = define_vars("x", typ=backend.scalar_typ) - _check_pair( - lhs=monoid.plus(x(), monoid.identity), - rhs=x(), - backend=backend, - free_vars=[x], + + lhs = monoid.plus(x(), monoid.identity) + rhs = monoid.plus(x()) + + _check_rewrite( + lhs=lhs, rhs=rhs, rule=PlusIdentity(), backend=backend, free_vars=[x] ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) def test_plus_identity_left(monoid, backend): x = define_vars("x", typ=backend.scalar_typ) - _check_pair( - lhs=monoid.plus(monoid.identity, x()), - rhs=x(), - backend=backend, - free_vars=[x], + + lhs = monoid.plus(monoid.identity, x()) + rhs = monoid.plus(x()) + + _check_rewrite( + lhs=lhs, rhs=rhs, rule=PlusIdentity(), backend=backend, free_vars=[x] ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) def test_plus_assoc_right(monoid, backend): x, y, z = define_vars("x", "y", "z", typ=backend.scalar_typ) - _check_pair( + _check_rewrite( lhs=monoid.plus(x(), monoid.plus(y(), z())), rhs=monoid.plus(x(), y(), z()), + rule=PlusAssoc(), backend=backend, free_vars=[x, y, z], ) @@ -213,9 +232,10 @@ def test_plus_assoc_right(monoid, backend): @pytest.mark.parametrize("monoid", ALL_MONOIDS) def test_plus_assoc_left(monoid, backend): x, y, z = define_vars("x", "y", "z", typ=backend.scalar_typ) - _check_pair( + _check_rewrite( lhs=monoid.plus(monoid.plus(x(), y()), z()), rhs=monoid.plus(x(), y(), z()), + rule=PlusAssoc(), backend=backend, free_vars=[x, y, z], ) @@ -224,9 +244,10 @@ def test_plus_assoc_left(monoid, backend): @pytest.mark.parametrize("monoid", ALL_MONOIDS) def test_plus_sequence(monoid, backend): a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) - _check_pair( + _check_rewrite( lhs=monoid.plus((a(), b()), (c(), d())), rhs=(monoid.plus(a(), c()), monoid.plus(b(), d())), + rule=MonoidOverSequence(), backend=backend, free_vars=[a, b, c, d], ) @@ -235,9 +256,10 @@ def test_plus_sequence(monoid, backend): @pytest.mark.parametrize("monoid", ALL_MONOIDS) def test_plus_mapping(monoid, backend): a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) - _check_pair( + _check_rewrite( lhs=monoid.plus({0: a(), 1: b()}, {0: c(), 2: d()}), rhs={0: monoid.plus(a(), c()), 1: b(), 2: d()}, + rule=MonoidOverMapping(), backend=backend, free_vars=[a, b, c, d], ) @@ -252,7 +274,9 @@ def test_plus_distributes(backend): Product.plus(b(), c()), Product.plus(b(), d()), ) - _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[a, b, c, d]) + _check_rewrite( + lhs=lhs, rhs=rhs, rule=PlusDistr(), backend=backend, free_vars=[a, b, c, d] + ) def test_plus_distributes_constant(backend): @@ -267,7 +291,9 @@ def test_plus_distributes_constant(backend): Product.plus(b(), d()), ), ) - _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[a, b, c, d]) + _check_rewrite( + lhs=lhs, rhs=rhs, rule=PlusDistr(), backend=backend, free_vars=[a, b, c, d] + ) def test_plus_distributes_multiple(backend): @@ -292,7 +318,9 @@ def test_plus_distributes_multiple(backend): Sum.plus(b(), d()), ), ) - _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[a, b, c, d]) + _check_rewrite( + lhs=lhs, rhs=rhs, rule=PlusDistr(), backend=backend, free_vars=[a, b, c, d] + ) @pytest.mark.parametrize("monoid", IDEMPOTENT) @@ -300,8 +328,12 @@ def test_plus_idempotent_consecutive(monoid, backend): """``a, a, b → a, b`` — only consecutive duplicates collapse.""" a, b = define_vars("a", "b", typ=backend.scalar_typ) lhs = monoid.plus(a(), a(), b()) - return _check_pair( - lhs=lhs, rhs=monoid.plus(a(), b()), backend=backend, free_vars=[a, b] + return _check_rewrite( + lhs=lhs, + rhs=monoid.plus(a(), b()), + rule=PlusConsecutiveDups(), + backend=backend, + free_vars=[a, b], ) @@ -315,14 +347,16 @@ def test_plus_idempotent_non_consecutive(monoid, backend): rhs = monoid.plus(a(), b()) else: rhs = monoid.plus(a(), b(), a()) - _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[a, b]) + _check_rewrite(lhs=lhs, rhs=rhs, backend=backend, free_vars=[a, b]) -def test_plus_commutative_idempotent_long(backend): +@pytest.mark.parametrize("monoid", [Min, Max]) +def test_plus_commutative_idempotent_long(monoid, backend): """Long alternation collapses via commutative dedup (Min/Max only).""" a, b = define_vars("a", "b", typ=backend.scalar_typ) - lhs = Min.plus(a(), b(), a(), b(), b(), a(), a()) - _check_pair(lhs=lhs, rhs=Min.plus(a(), b()), backend=backend, free_vars=[a, b]) + lhs = monoid.plus(a(), b(), a(), b(), b(), a(), a()) + rhs = monoid.plus(a(), b()) + _check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDups(), backend=backend, free_vars=[a, b]) @pytest.mark.parametrize("monoid", WITH_ZERO) @@ -330,18 +364,21 @@ def test_plus_zero(monoid, backend): a = define_vars("a", typ=backend.scalar_typ) lhs_right = monoid.plus(a(), monoid.zero) lhs_left = monoid.plus(monoid.zero, a()) - _check_pair(lhs=lhs_right, rhs=monoid.zero, backend=backend, free_vars=[a]) - _check_pair(lhs=lhs_left, rhs=monoid.zero, backend=backend, free_vars=[a]) + rhs = monoid.zero + _check_rewrite( + lhs=lhs_right, rhs=rhs, rule=PlusZero(), backend=backend, free_vars=[a] + ) + _check_rewrite( + lhs=lhs_left, rhs=rhs, rule=PlusZero(), backend=backend, free_vars=[a] + ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) def test_partial_1(monoid, backend): x, y = define_vars("x", "y", typ=backend.scalar_typ) - lhs = monoid.reduce(x(), {x: []}) rhs = monoid.identity - - _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[x, y]) + _check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y]) @pytest.mark.parametrize("monoid", ALL_MONOIDS) @@ -352,7 +389,7 @@ def test_partial_2(monoid, backend): lhs = monoid.reduce(x(), {y: Y(), x: []}) rhs = monoid.identity - _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[x, y, Y]) + _check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y, Y]) @pytest.mark.parametrize("monoid", ALL_MONOIDS) @@ -363,7 +400,9 @@ def test_partial_3(monoid, backend): lhs = monoid.reduce(x(), {y: Y(), x: [a(), b()]}) rhs = monoid.plus(monoid.reduce(a(), {y: Y()}), monoid.reduce(b(), {y: Y()})) - _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[x, y, a, b, Y]) + _check_rewrite( + lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y, a, b, Y] + ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) @@ -374,7 +413,9 @@ def test_partial_4(monoid, backend): lhs = monoid.reduce(x(), {y: f(x()), x: [a(), b()]}) rhs = monoid.plus(monoid.reduce(a(), {y: f(a())}), monoid.reduce(b(), {y: f(b())})) - _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[x, y, a, b, f]) + _check_rewrite( + lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y, a, b, f] + ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) @@ -387,7 +428,13 @@ def test_reduce_body_sequence(monoid, backend): lhs = monoid.reduce((f(x()), g(x())), {x: X()}) rhs = (monoid.reduce(f(x()), {x: X()}), monoid.reduce(g(x()), {x: X()})) - _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[X, f, g]) + _check_rewrite( + lhs=lhs, + rhs=rhs, + rule=MonoidOverSequence(), + backend=backend, + free_vars=[X, f, g], + ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) @@ -403,7 +450,13 @@ def test_reduce_body_sequence_2(monoid, backend): monoid.reduce(g(y()), {x: X(), y: Y()}), ) - _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[X, Y, f, g]) + _check_rewrite( + lhs=lhs, + rhs=rhs, + rule=MonoidOverSequence(), + backend=backend, + free_vars=[X, Y, f, g], + ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) @@ -418,7 +471,13 @@ def test_reduce_body_mapping(monoid, backend): 0: monoid.reduce(f(x()), {x: X()}), 1: monoid.reduce(g(x()), {x: X()}), } - _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[X, f, g]) + _check_rewrite( + lhs=lhs, + rhs=rhs, + rule=MonoidOverMapping(), + backend=backend, + free_vars=[X, f, g], + ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) @@ -427,7 +486,9 @@ def test_reduce_no_streams(monoid, backend): lhs = monoid.reduce(a(), {}) rhs = monoid.identity - _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[a]) + _check_rewrite( + lhs=lhs, rhs=rhs, rule=ReduceNoStreams(), backend=backend, free_vars=[a] + ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) @@ -439,7 +500,9 @@ def test_reduce_reduce(monoid, backend): lhs = monoid.reduce(monoid.reduce(f(a(), b()), {a: A()}), {b: B()}) rhs = monoid.reduce(f(a(), b()), {a: A(), b: B()}) - _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[A, B, f]) + _check_rewrite( + lhs=lhs, rhs=rhs, rule=ReduceFusion(), backend=backend, free_vars=[A, B, f] + ) @pytest.mark.parametrize("monoid", COMMUTATIVE) @@ -451,7 +514,9 @@ def test_reduce_plus(monoid, backend): monoid.reduce(a(), {a: A(), b: B()}), monoid.reduce(b(), {a: A(), b: B()}), ) - _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[A, B]) + _check_rewrite( + lhs=lhs, rhs=rhs, rule=ReduceSplit(), backend=backend, free_vars=[A, B] + ) def test_reduce_independent_1(backend): @@ -459,7 +524,9 @@ def test_reduce_independent_1(backend): A, B = define_vars("A", "B", typ=backend.stream_typ) lhs = Sum.reduce(Product.plus(a(), b()), {a: A(), b: B()}) rhs = Product.plus(Sum.reduce(a(), {a: A()}), Sum.reduce(b(), {b: B()})) - _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[A, B]) + _check_rewrite( + lhs=lhs, rhs=rhs, rule=ReduceFactorization(), backend=backend, free_vars=[A, B] + ) def test_reduce_independent_2(backend): @@ -472,7 +539,13 @@ def test_reduce_independent_2(backend): Sum.reduce(a(), {a: A()}), Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), ) - _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[A, B, C, f]) + _check_rewrite( + lhs=lhs, + rhs=rhs, + rule=ReduceFactorization(), + backend=backend, + free_vars=[A, B, C, f], + ) def test_reduce_independent_3_negative(backend): @@ -483,7 +556,7 @@ def test_reduce_independent_3_negative(backend): f = backend.fresh_op("f", n_args=2, ret="scalar") g = backend.fresh_op("g", n_args=1, ret="stream") - with handler(NormalizeIntp): + with handler(ReduceFactorization()): # ty:ignore[invalid-argument-type] lhs = Sum.reduce( Product.plus(a(), b(), f(b(), c())), {a: A(), b: g(a()), c: C()} ) @@ -506,7 +579,13 @@ def test_reduce_independent_4(backend): Sum.reduce(a(), {a: A()}), Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), ) - _check_pair(lhs=lhs, rhs=rhs, backend=backend, free_vars=[A, B, C, f]) + _check_rewrite( + lhs=lhs, + rhs=rhs, + rule=ReduceFactorization(), + backend=backend, + free_vars=[A, B, C, f], + ) @pytest.mark.parametrize("outer,inner", MONOID_PAIRS) @@ -520,18 +599,25 @@ def test_reduce_lifted_1(outer, inner, backend): {A: CartesianProduct.reduce(A_domain(), {i: N()})}, ) term2 = inner.reduce(outer.reduce(f(a()), {a: A_domain()}), {i: N()}) - _check_pair(lhs=term1, rhs=term2, backend=backend, free_vars=[N, A_domain, f]) + _check_rewrite( + lhs=term1, + rhs=term2, + rule=ReduceDistributeCartesianProduct(), + backend=backend, + free_vars=[N, A_domain, f], + ) def test_reduce_cartesian_1(backend): a, i = define_vars("a", "i", typ=backend.scalar_typ) A = define_vars("A", typ=backend.stream_typ) - term1 = Sum.reduce( - Product.reduce(a(), {a: []}), - {A: CartesianProduct.reduce([], {i: []})}, - ) - term2 = Product.reduce(Sum.reduce(a(), {a: []}), {i: []}) + with handler(NormalizeIntp): + term1 = Sum.reduce( + Product.reduce(a(), {a: []}), + {A: CartesianProduct.reduce([], {i: []})}, + ) + term2 = Product.reduce(Sum.reduce(a(), {a: []}), {i: []}) assert term1 == term2 @@ -539,11 +625,12 @@ def test_reduce_cartesian_2(backend): a, i = define_vars("a", "i", typ=backend.scalar_typ) A = define_vars("A", typ=backend.stream_typ) - term1 = Sum.reduce( - Product.reduce(a(), {a: A()}), - {A: CartesianProduct.reduce([(0,)], {i: [0]})}, - ) - term2 = Product.reduce(Sum.reduce(a(), {a: [0]}), {i: [0]}) + with handler(NormalizeIntp): + term1 = Sum.reduce( + Product.reduce(a(), {a: A()}), + {A: CartesianProduct.reduce([(0,)], {i: [0]})}, + ) + term2 = Product.reduce(Sum.reduce(a(), {a: [0]}), {i: [0]}) assert term1 == term2 @@ -561,7 +648,13 @@ def test_reduce_lifted_multi_index(outer, inner, backend): outer.reduce(f(a()), {a: A_domain()}), {i: N(), j: M()}, ) - _check_pair(lhs=term1, rhs=term2, backend=backend, free_vars=[N, M, A_domain, f]) + _check_rewrite( + lhs=term1, + rhs=term2, + rule=ReduceDistributeCartesianProduct(), + backend=backend, + free_vars=[N, M, A_domain, f], + ) @pytest.mark.parametrize("outer,inner", MONOID_PAIRS) @@ -589,9 +682,10 @@ def test_reduce_lifted_2(outer, inner, backend): {t: T()}, ) - _check_pair( + _check_rewrite( lhs=term1, rhs=term2, + rule=ReduceDistributeCartesianProduct(), backend=backend, free_vars=[a, i, s, t, A, N, T, A_domain, f1, f2], ) From 08d92ae3e33e515220818c9726e6016b2a0c1075 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Thu, 14 May 2026 12:42:30 -0400 Subject: [PATCH 30/34] wip --- tests/_monoid_helpers.py | 9 ++++++++- tests/test_ops_monoid.py | 8 ++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py index b51a74dfe..c4d08b605 100644 --- a/tests/_monoid_helpers.py +++ b/tests/_monoid_helpers.py @@ -289,7 +289,14 @@ def _int_eq(a: Any, b: Any) -> bool: def _jax_eq(a: Any, b: Any) -> bool: - return bool(jax.numpy.allclose(a, b)) + def _leaf_eq(x: Any, y: Any) -> bool: + return bool(jax.numpy.all(jax.numpy.isclose(x, y, equal_nan=True))) + + try: + leaves = jax.tree.leaves(jax.tree.map(_leaf_eq, a, b)) + except (ValueError, TypeError): + return False + return all(leaves) INT_BACKEND = Backend( diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index ce0a3b22d..e483facf6 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -256,9 +256,13 @@ def test_plus_sequence(monoid, backend): @pytest.mark.parametrize("monoid", ALL_MONOIDS) def test_plus_mapping(monoid, backend): a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) + + lhs = monoid.plus({0: a(), 1: b()}, {0: c(), 2: d()}) + rhs = {0: monoid.plus(a(), c()), 1: monoid.plus(b()), 2: monoid.plus(d())} + _check_rewrite( - lhs=monoid.plus({0: a(), 1: b()}, {0: c(), 2: d()}), - rhs={0: monoid.plus(a(), c()), 1: b(), 2: d()}, + lhs=lhs, + rhs=rhs, rule=MonoidOverMapping(), backend=backend, free_vars=[a, b, c, d], From b50c0594f0fbc63dcb7a87222d118f117944b52d Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Thu, 14 May 2026 12:57:05 -0400 Subject: [PATCH 31/34] wip --- tests/_monoid_helpers.py | 16 ++++++++-------- tests/test_ops_monoid.py | 38 +++++++++++++++++++++----------------- 2 files changed, 29 insertions(+), 25 deletions(-) diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py index c4d08b605..8d6693c44 100644 --- a/tests/_monoid_helpers.py +++ b/tests/_monoid_helpers.py @@ -16,19 +16,19 @@ def _jax_array_value_strategy() -> st.SearchStrategy[jax.Array]: - return st.integers(min_value=0, max_value=2**31 - 1).map( - lambda seed: jax.random.uniform( - jax.random.PRNGKey(seed), _JAX_ARRAY_SHAPE, minval=-1.5, maxval=1.5 - ) - ) + return st.lists( + st.integers(min_value=-5, max_value=5), + min_size=_JAX_ARRAY_SHAPE[0], + max_size=_JAX_ARRAY_SHAPE[0], + ).map(lambda xs: jax.numpy.asarray(xs, dtype=jax.numpy.float32)) # Unary jax fns map a scalar to a 1-D array (analogous to ``_UNARY_LIST_FNS`` # for ints). Uses the effectful-wrapped jnp so named-dim broadcasting works. _UNARY_JAX_FNS: list[Callable[[jax.Array], jax.Array]] = [ - lambda a: _jnp.stack([a, a + 1.0]), + lambda a: _jnp.stack([a, a + 1]), lambda a: _jnp.stack([a, -a]), - lambda a: _jnp.stack([a, a + 1.0, 2.0 * a]), + lambda a: _jnp.stack([a, a + 1, 2 * a]), ] _BINARY_JAX_FNS: list[Callable[[jax.Array, jax.Array], jax.Array]] = [ @@ -84,7 +84,7 @@ def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: _UNARY_JAX_LIST_FNS: list[Callable[[jax.Array], list[jax.Array]]] = [ lambda _x: [], lambda x: [x], - lambda x: [x, x + 1.0], + lambda x: [x, x + 1], lambda x: [x, -x], ] diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index e483facf6..09bfd47de 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -272,11 +272,13 @@ def test_plus_mapping(monoid, backend): def test_plus_distributes(backend): a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) lhs = Product.plus(Sum.plus(a(), b()), Sum.plus(c(), d())) - rhs = Sum.plus( - Product.plus(a(), c()), - Product.plus(a(), d()), - Product.plus(b(), c()), - Product.plus(b(), d()), + rhs = Product.plus( + Sum.plus( + Product.plus(a(), c()), + Product.plus(a(), d()), + Product.plus(b(), c()), + Product.plus(b(), d()), + ) ) _check_rewrite( lhs=lhs, rhs=rhs, rule=PlusDistr(), backend=backend, free_vars=[a, b, c, d] @@ -344,14 +346,11 @@ def test_plus_idempotent_consecutive(monoid, backend): @pytest.mark.parametrize("monoid", IDEMPOTENT) def test_plus_idempotent_non_consecutive(monoid, backend): """``a, b, a`` — Semilattice (Min/Max) collapses via commutative - PlusDups; plain IdempotentMonoid leaves it as-is (consecutive-only).""" + PlusDups.""" a, b = define_vars("a", "b", typ=backend.scalar_typ) lhs = monoid.plus(a(), b(), a()) - if is_commutative(monoid): - rhs = monoid.plus(a(), b()) - else: - rhs = monoid.plus(a(), b(), a()) - _check_rewrite(lhs=lhs, rhs=rhs, backend=backend, free_vars=[a, b]) + rhs = monoid.plus(a(), b()) + _check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDups(), backend=backend, free_vars=[a, b]) @pytest.mark.parametrize("monoid", [Min, Max]) @@ -527,7 +526,9 @@ def test_reduce_independent_1(backend): a, b = define_vars("a", "b", typ=backend.scalar_typ) A, B = define_vars("A", "B", typ=backend.stream_typ) lhs = Sum.reduce(Product.plus(a(), b()), {a: A(), b: B()}) - rhs = Product.plus(Sum.reduce(a(), {a: A()}), Sum.reduce(b(), {b: B()})) + rhs = Product.plus( + Sum.reduce(Product.plus(a()), {a: A()}), Sum.reduce(Product.plus(b()), {b: B()}) + ) _check_rewrite( lhs=lhs, rhs=rhs, rule=ReduceFactorization(), backend=backend, free_vars=[A, B] ) @@ -540,7 +541,7 @@ def test_reduce_independent_2(backend): lhs = Sum.reduce(Product.plus(a(), b(), f(b(), c())), {a: A(), b: B(), c: C()}) rhs = Product.plus( - Sum.reduce(a(), {a: A()}), + Sum.reduce(Product.plus(a()), {a: A()}), Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), ) _check_rewrite( @@ -580,7 +581,7 @@ def test_reduce_independent_4(backend): lhs = Sum.reduce(Product.plus(a(), b(), f(b(), c()), 7), {a: A(), b: B(), c: C()}) rhs = Product.plus( 7, - Sum.reduce(a(), {a: A()}), + Sum.reduce(Product.plus(a()), {a: A()}), Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), ) _check_rewrite( @@ -602,7 +603,8 @@ def test_reduce_lifted_1(outer, inner, backend): inner.reduce(f(a()), {a: A()}), {A: CartesianProduct.reduce(A_domain(), {i: N()})}, ) - term2 = inner.reduce(outer.reduce(f(a()), {a: A_domain()}), {i: N()}) + term2 = inner.reduce(outer.reduce(inner.plus(f(a())), {a: A_domain()}), {i: N()}) + _check_rewrite( lhs=term1, rhs=term2, @@ -649,7 +651,7 @@ def test_reduce_lifted_multi_index(outer, inner, backend): {A: CartesianProduct.reduce(A_domain(), {i: N(), j: M()})}, ) term2 = inner.reduce( - outer.reduce(f(a()), {a: A_domain()}), + outer.reduce(inner.plus(f(a())), {a: A_domain()}), {i: N(), j: M()}, ) _check_rewrite( @@ -680,7 +682,9 @@ def test_reduce_lifted_2(outer, inner, backend): term2 = outer.reduce( inner.reduce( - outer.reduce(inner.plus(f1(a(), s()), f2(t(), a())), {a: A_domain(i())}), + outer.reduce( + inner.plus(inner.plus(f1(a(), s()), f2(t(), a()))), {a: A_domain(i())} + ), {i: N()}, ), {t: T()}, From 7e204d08d6d5f7f210ec62e90c3db09939d74232 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Thu, 14 May 2026 13:47:31 -0400 Subject: [PATCH 32/34] use check_rewrite in jax tests --- effectful/handlers/jax/monoid.py | 10 +++- tests/_monoid_helpers.py | 34 ++++++++++- tests/test_handlers_jax_monoid.py | 77 ++++++++++++------------ tests/test_ops_monoid.py | 99 +++++++++++-------------------- 4 files changed, 109 insertions(+), 111 deletions(-) diff --git a/effectful/handlers/jax/monoid.py b/effectful/handlers/jax/monoid.py index 622283eb4..e3d40e34f 100644 --- a/effectful/handlers/jax/monoid.py +++ b/effectful/handlers/jax/monoid.py @@ -25,8 +25,12 @@ def cartesian_prod(x, y): x = x[:, None] if y.ndim == 1: y = y[:, None] - x, y = jnp.repeat(x, y.shape[0], axis=0), jnp.tile(y, (x.shape[0], 1)) - return jnp.hstack([x, y]) + nx, dx = x.shape + ny, dy = y.shape + # Broadcast into (nx, ny, dx+dy), then flatten the first two axes + x_b = jnp.broadcast_to(x[:, None, :], (nx, ny, dx)) + y_b = jnp.broadcast_to(y[None, :, :], (nx, ny, dy)) + return jnp.concatenate([x_b, y_b], axis=-1).reshape(nx * ny, dx + dy) LogSumExp = Monoid(name="LogSumExp", identity=jnp.asarray(float("-inf"))) @@ -123,7 +127,7 @@ def reduce(self, monoid, body, streams): reductor = ARRAY_REDUCTORS[monoid] index = Operation.define(jax.Array) for stream_key, stream_body, streams_tail in outer_stream(streams): - if typeof(stream_body) is not jax.Array: + if not issubclass(typeof(stream_body), jax.Array): continue with handler({stream_key: deffn(unbind_dims(stream_body, index))}): eval_body = evaluate(body) diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py index 8d6693c44..2444e895f 100644 --- a/tests/_monoid_helpers.py +++ b/tests/_monoid_helpers.py @@ -4,15 +4,17 @@ from typing import Any, get_args, get_origin import jax +from hypothesis import given, settings from hypothesis import strategies as st import effectful.handlers.jax.numpy as _jnp from effectful.internals.runtime import interpreter -from effectful.ops.semantics import apply, evaluate +from effectful.ops.monoid import NormalizeIntp +from effectful.ops.semantics import apply, evaluate, handler from effectful.ops.syntax import _BaseTerm, defdata, deffn, syntactic_eq from effectful.ops.types import NotHandled, Operation, Term -_JAX_ARRAY_SHAPE = (3,) +_JAX_ARRAY_SHAPE = (2,) def _jax_array_value_strategy() -> st.SearchStrategy[jax.Array]: @@ -299,6 +301,31 @@ def _leaf_eq(x: Any, y: Any) -> bool: return all(leaves) +def check_rewrite( + lhs, + rhs, + rule, + *, + backend: Backend, + free_vars=[], + max_examples: int = 25, + deadline=None, +) -> None: + with handler(rule): + norm = evaluate(lhs) + assert syntactic_eq_alpha(norm, rhs) + + @given(intp=random_interpretation(free_vars)) + @settings(max_examples=max_examples, deadline=deadline) + def _check_semantics(intp): + with handler(NormalizeIntp), handler(intp): + lhs_val = evaluate(lhs) + rhs_val = evaluate(rhs) + assert backend.eq(lhs_val, rhs_val) + + _check_semantics() + + INT_BACKEND = Backend( name="int", scalar_typ=int, @@ -311,7 +338,7 @@ def _leaf_eq(x: Any, y: Any) -> bool: JAX_BACKEND = Backend( name="jax", scalar_typ=jax.Array, - stream_typ=list[jax.Array], + stream_typ=jax.Array, scalar_strategy=_jax_array_value_strategy(), eq=_jax_eq, ) @@ -324,4 +351,5 @@ def _leaf_eq(x: Any, y: Any) -> bool: "random_interpretation", "define_vars", "syntactic_eq_alpha", + "check_rewrite", ] diff --git a/tests/test_handlers_jax_monoid.py b/tests/test_handlers_jax_monoid.py index 9120acd07..35d041fe2 100644 --- a/tests/test_handlers_jax_monoid.py +++ b/tests/test_handlers_jax_monoid.py @@ -3,27 +3,10 @@ import effectful.handlers.jax.numpy as jnp from effectful.handlers.jax import bind_dims, unbind_dims -from effectful.handlers.jax.monoid import ( - JaxEvaluateIntp, - LogSumExp, - Max, - Min, - Product, - Sum, -) +from effectful.handlers.jax.monoid import ArrayReduce, LogSumExp from effectful.handlers.jax.scipy.special import logsumexp -from effectful.ops.monoid import EvaluateIntp -from effectful.ops.semantics import coproduct, handler -from effectful.ops.types import NotHandled, Operation -from tests._monoid_helpers import define_vars, syntactic_eq_alpha - - -@pytest.fixture(autouse=True) -def _install_evaluate(): - """Install scalar + JAX evaluation kernels for every test in this module.""" - with handler(coproduct(EvaluateIntp, JaxEvaluateIntp)): - yield - +from effectful.ops.monoid import Max, Min, Product, Sum +from tests._monoid_helpers import JAX_BACKEND, Backend, check_rewrite, define_vars MONOIDS = [ pytest.param(Sum, jnp.sum, id="Sum"), @@ -34,23 +17,29 @@ def _install_evaluate(): ] +@pytest.fixture +def backend() -> Backend: + return JAX_BACKEND + + @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_array_1(monoid, reductor): - (x, X, k) = define_vars("x", "X", "k", typ=jax.Array) +def test_reduce_array_1(monoid, reductor, backend: Backend): + (x, k) = define_vars("x", "k", typ=jax.Array) + X = define_vars("X", typ=backend.stream_typ) lhs = monoid.reduce(x(), {x: X()}) rhs = reductor(bind_dims(unbind_dims(X(), k), k), axis=0) - assert syntactic_eq_alpha(lhs, rhs) + check_rewrite( + lhs=lhs, rhs=rhs, rule=ArrayReduce(), backend=backend, free_vars=[x, X, k] + ) @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_array_2(monoid, reductor): - (x, y, X, Y, k1, k2) = define_vars("x", "y", "X", "Y", "k1", "k2", typ=jax.Array) - - @Operation.define - def f(_a: jax.Array, _b: jax.Array) -> jax.Array: - raise NotHandled +def test_reduce_array_2(monoid, reductor, backend: Backend): + (x, y, k1, k2) = define_vars("x", "y", "k1", "k2", typ=backend.scalar_typ) + (X, Y) = define_vars("X", "Y", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=2, ret="scalar") lhs = monoid.reduce(f(x(), y()), {x: X(), y: Y()}) rhs = reductor( @@ -64,22 +53,24 @@ def f(_a: jax.Array, _b: jax.Array) -> jax.Array: axis=0, ) - assert syntactic_eq_alpha(lhs, rhs) + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=ArrayReduce(), + backend=backend, + free_vars=[x, y, k1, k2, X, Y, f], + ) @pytest.mark.parametrize("monoid,reductor", MONOIDS) -def test_reduce_array_3(monoid, reductor): +def test_reduce_array_3(monoid, reductor, backend: Backend): """Stream `y` is `g(x())` — depends on the bound element of X. The reducer must inline ``g`` along the same named dim used to unbind `x`.""" - (x, y, X, k1, k2) = define_vars("x", "y", "X", "k1", "k2", typ=jax.Array) + (x, y, k1, k2) = define_vars("x", "y", "k1", "k2", typ=backend.scalar_typ) + X = define_vars("X", typ=backend.stream_typ) - @Operation.define - def f(_a: jax.Array, _b: jax.Array) -> jax.Array: - raise NotHandled - - @Operation.define - def g(_a: jax.Array) -> jax.Array: - raise NotHandled + f = backend.fresh_op("f", n_args=2, ret="scalar") + g = backend.fresh_op("g", n_args=1, ret="stream") lhs = monoid.reduce(f(x(), y()), {x: X(), y: g(x())}) rhs = reductor( @@ -96,4 +87,10 @@ def g(_a: jax.Array) -> jax.Array: axis=0, ) - assert syntactic_eq_alpha(lhs, rhs) + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=ArrayReduce(), + backend=backend, + free_vars=[x, y, k1, k2, X, f, g], + ) diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index 09bfd47de..62ee82b07 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -29,16 +29,15 @@ ReduceSplit, Sum, distributes_over, - is_commutative, ) -from effectful.ops.semantics import evaluate, fvsof, handler +from effectful.ops.semantics import fvsof, handler from effectful.ops.types import Operation from tests._monoid_helpers import ( INT_BACKEND, JAX_BACKEND, Backend, + check_rewrite, define_vars, - random_interpretation, syntactic_eq_alpha, ) @@ -84,28 +83,6 @@ def backend(request) -> Backend: ] -def _check_rewrite( - lhs, rhs, rule, *, backend: Backend, free_vars=[], max_examples: int = 25 -) -> None: - with handler(rule): - norm = evaluate(lhs) - assert syntactic_eq_alpha(norm, rhs) - - @given(intp=random_interpretation(free_vars)) - @settings( - max_examples=max_examples, - deadline=None, - suppress_health_check=[HealthCheck.function_scoped_fixture], - ) - def _check_semantics(intp): - with handler(NormalizeIntp), handler(intp): - lhs_val = evaluate(lhs) - rhs_val = evaluate(rhs) - assert backend.eq(lhs_val, rhs_val) - - _check_semantics() - - @pytest.mark.parametrize("monoid", ALL_MONOIDS) @given(data=st.data()) @settings( @@ -180,7 +157,7 @@ def test_zero_absorbs(monoid, backend, data): @pytest.mark.parametrize("monoid", ALL_MONOIDS) def test_plus_empty(monoid, backend): - _check_rewrite( + check_rewrite( lhs=monoid.plus(), rhs=monoid.identity, rule=PlusEmpty(), backend=backend ) @@ -188,7 +165,7 @@ def test_plus_empty(monoid, backend): @pytest.mark.parametrize("monoid", ALL_MONOIDS) def test_plus_single(monoid, backend): x = define_vars("x", typ=backend.scalar_typ) - _check_rewrite( + check_rewrite( lhs=monoid.plus(x()), rhs=x(), rule=PlusSingle(), backend=backend, free_vars=[x] ) @@ -200,9 +177,7 @@ def test_plus_identity_right(monoid, backend): lhs = monoid.plus(x(), monoid.identity) rhs = monoid.plus(x()) - _check_rewrite( - lhs=lhs, rhs=rhs, rule=PlusIdentity(), backend=backend, free_vars=[x] - ) + check_rewrite(lhs=lhs, rhs=rhs, rule=PlusIdentity(), backend=backend, free_vars=[x]) @pytest.mark.parametrize("monoid", ALL_MONOIDS) @@ -212,15 +187,13 @@ def test_plus_identity_left(monoid, backend): lhs = monoid.plus(monoid.identity, x()) rhs = monoid.plus(x()) - _check_rewrite( - lhs=lhs, rhs=rhs, rule=PlusIdentity(), backend=backend, free_vars=[x] - ) + check_rewrite(lhs=lhs, rhs=rhs, rule=PlusIdentity(), backend=backend, free_vars=[x]) @pytest.mark.parametrize("monoid", ALL_MONOIDS) def test_plus_assoc_right(monoid, backend): x, y, z = define_vars("x", "y", "z", typ=backend.scalar_typ) - _check_rewrite( + check_rewrite( lhs=monoid.plus(x(), monoid.plus(y(), z())), rhs=monoid.plus(x(), y(), z()), rule=PlusAssoc(), @@ -232,7 +205,7 @@ def test_plus_assoc_right(monoid, backend): @pytest.mark.parametrize("monoid", ALL_MONOIDS) def test_plus_assoc_left(monoid, backend): x, y, z = define_vars("x", "y", "z", typ=backend.scalar_typ) - _check_rewrite( + check_rewrite( lhs=monoid.plus(monoid.plus(x(), y()), z()), rhs=monoid.plus(x(), y(), z()), rule=PlusAssoc(), @@ -244,7 +217,7 @@ def test_plus_assoc_left(monoid, backend): @pytest.mark.parametrize("monoid", ALL_MONOIDS) def test_plus_sequence(monoid, backend): a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) - _check_rewrite( + check_rewrite( lhs=monoid.plus((a(), b()), (c(), d())), rhs=(monoid.plus(a(), c()), monoid.plus(b(), d())), rule=MonoidOverSequence(), @@ -260,7 +233,7 @@ def test_plus_mapping(monoid, backend): lhs = monoid.plus({0: a(), 1: b()}, {0: c(), 2: d()}) rhs = {0: monoid.plus(a(), c()), 1: monoid.plus(b()), 2: monoid.plus(d())} - _check_rewrite( + check_rewrite( lhs=lhs, rhs=rhs, rule=MonoidOverMapping(), @@ -280,7 +253,7 @@ def test_plus_distributes(backend): Product.plus(b(), d()), ) ) - _check_rewrite( + check_rewrite( lhs=lhs, rhs=rhs, rule=PlusDistr(), backend=backend, free_vars=[a, b, c, d] ) @@ -297,7 +270,7 @@ def test_plus_distributes_constant(backend): Product.plus(b(), d()), ), ) - _check_rewrite( + check_rewrite( lhs=lhs, rhs=rhs, rule=PlusDistr(), backend=backend, free_vars=[a, b, c, d] ) @@ -324,7 +297,7 @@ def test_plus_distributes_multiple(backend): Sum.plus(b(), d()), ), ) - _check_rewrite( + check_rewrite( lhs=lhs, rhs=rhs, rule=PlusDistr(), backend=backend, free_vars=[a, b, c, d] ) @@ -334,7 +307,7 @@ def test_plus_idempotent_consecutive(monoid, backend): """``a, a, b → a, b`` — only consecutive duplicates collapse.""" a, b = define_vars("a", "b", typ=backend.scalar_typ) lhs = monoid.plus(a(), a(), b()) - return _check_rewrite( + return check_rewrite( lhs=lhs, rhs=monoid.plus(a(), b()), rule=PlusConsecutiveDups(), @@ -350,7 +323,7 @@ def test_plus_idempotent_non_consecutive(monoid, backend): a, b = define_vars("a", "b", typ=backend.scalar_typ) lhs = monoid.plus(a(), b(), a()) rhs = monoid.plus(a(), b()) - _check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDups(), backend=backend, free_vars=[a, b]) + check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDups(), backend=backend, free_vars=[a, b]) @pytest.mark.parametrize("monoid", [Min, Max]) @@ -359,7 +332,7 @@ def test_plus_commutative_idempotent_long(monoid, backend): a, b = define_vars("a", "b", typ=backend.scalar_typ) lhs = monoid.plus(a(), b(), a(), b(), b(), a(), a()) rhs = monoid.plus(a(), b()) - _check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDups(), backend=backend, free_vars=[a, b]) + check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDups(), backend=backend, free_vars=[a, b]) @pytest.mark.parametrize("monoid", WITH_ZERO) @@ -368,10 +341,10 @@ def test_plus_zero(monoid, backend): lhs_right = monoid.plus(a(), monoid.zero) lhs_left = monoid.plus(monoid.zero, a()) rhs = monoid.zero - _check_rewrite( + check_rewrite( lhs=lhs_right, rhs=rhs, rule=PlusZero(), backend=backend, free_vars=[a] ) - _check_rewrite( + check_rewrite( lhs=lhs_left, rhs=rhs, rule=PlusZero(), backend=backend, free_vars=[a] ) @@ -381,7 +354,7 @@ def test_partial_1(monoid, backend): x, y = define_vars("x", "y", typ=backend.scalar_typ) lhs = monoid.reduce(x(), {x: []}) rhs = monoid.identity - _check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y]) + check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y]) @pytest.mark.parametrize("monoid", ALL_MONOIDS) @@ -392,7 +365,7 @@ def test_partial_2(monoid, backend): lhs = monoid.reduce(x(), {y: Y(), x: []}) rhs = monoid.identity - _check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y, Y]) + check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y, Y]) @pytest.mark.parametrize("monoid", ALL_MONOIDS) @@ -403,9 +376,7 @@ def test_partial_3(monoid, backend): lhs = monoid.reduce(x(), {y: Y(), x: [a(), b()]}) rhs = monoid.plus(monoid.reduce(a(), {y: Y()}), monoid.reduce(b(), {y: Y()})) - _check_rewrite( - lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y, a, b, Y] - ) + check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y, a, b, Y]) @pytest.mark.parametrize("monoid", ALL_MONOIDS) @@ -416,9 +387,7 @@ def test_partial_4(monoid, backend): lhs = monoid.reduce(x(), {y: f(x()), x: [a(), b()]}) rhs = monoid.plus(monoid.reduce(a(), {y: f(a())}), monoid.reduce(b(), {y: f(b())})) - _check_rewrite( - lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y, a, b, f] - ) + check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y, a, b, f]) @pytest.mark.parametrize("monoid", ALL_MONOIDS) @@ -431,7 +400,7 @@ def test_reduce_body_sequence(monoid, backend): lhs = monoid.reduce((f(x()), g(x())), {x: X()}) rhs = (monoid.reduce(f(x()), {x: X()}), monoid.reduce(g(x()), {x: X()})) - _check_rewrite( + check_rewrite( lhs=lhs, rhs=rhs, rule=MonoidOverSequence(), @@ -453,7 +422,7 @@ def test_reduce_body_sequence_2(monoid, backend): monoid.reduce(g(y()), {x: X(), y: Y()}), ) - _check_rewrite( + check_rewrite( lhs=lhs, rhs=rhs, rule=MonoidOverSequence(), @@ -474,7 +443,7 @@ def test_reduce_body_mapping(monoid, backend): 0: monoid.reduce(f(x()), {x: X()}), 1: monoid.reduce(g(x()), {x: X()}), } - _check_rewrite( + check_rewrite( lhs=lhs, rhs=rhs, rule=MonoidOverMapping(), @@ -489,7 +458,7 @@ def test_reduce_no_streams(monoid, backend): lhs = monoid.reduce(a(), {}) rhs = monoid.identity - _check_rewrite( + check_rewrite( lhs=lhs, rhs=rhs, rule=ReduceNoStreams(), backend=backend, free_vars=[a] ) @@ -503,7 +472,7 @@ def test_reduce_reduce(monoid, backend): lhs = monoid.reduce(monoid.reduce(f(a(), b()), {a: A()}), {b: B()}) rhs = monoid.reduce(f(a(), b()), {a: A(), b: B()}) - _check_rewrite( + check_rewrite( lhs=lhs, rhs=rhs, rule=ReduceFusion(), backend=backend, free_vars=[A, B, f] ) @@ -517,7 +486,7 @@ def test_reduce_plus(monoid, backend): monoid.reduce(a(), {a: A(), b: B()}), monoid.reduce(b(), {a: A(), b: B()}), ) - _check_rewrite( + check_rewrite( lhs=lhs, rhs=rhs, rule=ReduceSplit(), backend=backend, free_vars=[A, B] ) @@ -529,7 +498,7 @@ def test_reduce_independent_1(backend): rhs = Product.plus( Sum.reduce(Product.plus(a()), {a: A()}), Sum.reduce(Product.plus(b()), {b: B()}) ) - _check_rewrite( + check_rewrite( lhs=lhs, rhs=rhs, rule=ReduceFactorization(), backend=backend, free_vars=[A, B] ) @@ -544,7 +513,7 @@ def test_reduce_independent_2(backend): Sum.reduce(Product.plus(a()), {a: A()}), Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), ) - _check_rewrite( + check_rewrite( lhs=lhs, rhs=rhs, rule=ReduceFactorization(), @@ -584,7 +553,7 @@ def test_reduce_independent_4(backend): Sum.reduce(Product.plus(a()), {a: A()}), Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), ) - _check_rewrite( + check_rewrite( lhs=lhs, rhs=rhs, rule=ReduceFactorization(), @@ -605,7 +574,7 @@ def test_reduce_lifted_1(outer, inner, backend): ) term2 = inner.reduce(outer.reduce(inner.plus(f(a())), {a: A_domain()}), {i: N()}) - _check_rewrite( + check_rewrite( lhs=term1, rhs=term2, rule=ReduceDistributeCartesianProduct(), @@ -654,7 +623,7 @@ def test_reduce_lifted_multi_index(outer, inner, backend): outer.reduce(inner.plus(f(a())), {a: A_domain()}), {i: N(), j: M()}, ) - _check_rewrite( + check_rewrite( lhs=term1, rhs=term2, rule=ReduceDistributeCartesianProduct(), @@ -690,7 +659,7 @@ def test_reduce_lifted_2(outer, inner, backend): {t: T()}, ) - _check_rewrite( + check_rewrite( lhs=term1, rhs=term2, rule=ReduceDistributeCartesianProduct(), From 77241f9d80f71b783a46a9ffd09f61a50519f3fe Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Thu, 14 May 2026 13:51:06 -0400 Subject: [PATCH 33/34] lint --- effectful/ops/monoid.py | 4 ++-- tests/_monoid_helpers.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 6aa7453fc..70bb50022 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -224,7 +224,7 @@ class PlusDistr(ObjectInterpretation): """x + (y * z) = x * y + x * z""" @implements(Monoid.plus) - def plus(self, monoid, *args): + def plus(self, monoid: Monoid, *args): if any( isinstance(x, Term) and _is_monoid_plus(x.op) @@ -736,7 +736,7 @@ def reduce(self, monoid, body, streams): class _ExtensibleInterpretation(UserDict, Interpretation): def extend(self, *intps: Interpretation) -> typing.Self: for intp in intps: - self.data = coproduct(self.data, intp) + self.data = coproduct(self.data, intp) # type: ignore[assignment] return self diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py index 2444e895f..f15103e30 100644 --- a/tests/_monoid_helpers.py +++ b/tests/_monoid_helpers.py @@ -183,7 +183,7 @@ def _substitute(arg, renaming): with interpreter({apply: _BaseTerm, **renaming}): return evaluate(arg) - def _bound_var_order(args, kwargs, bound_set): + def _bound_var_order(args, kwargs, bound_set: set[Operation]) -> list[Operation]: """Return bound variables in deterministic encounter order.""" seen: list[Operation] = [] seen_set: set[Operation] = set() @@ -218,7 +218,7 @@ def _walk_bare(obj): _walk_bare((args, kwargs)) return seen - def _apply_canonical(op, *args, **kwargs): + def _apply_canonical(op, *args, **kwargs) -> Term: bindings = op.__fvs_rule__(*args, **kwargs) all_bound: set[Operation] = set().union( *bindings.args, *bindings.kwargs.values() From 5e6d333fbc179dec6b2dab53413807342ce00467 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Thu, 21 May 2026 13:31:12 -0400 Subject: [PATCH 34/34] fix bugs --- effectful/handlers/jax/monoid.py | 35 ++++++++++++++++++++------------ tests/test_ops_monoid.py | 12 +++++------ 2 files changed, 28 insertions(+), 19 deletions(-) diff --git a/effectful/handlers/jax/monoid.py b/effectful/handlers/jax/monoid.py index e3d40e34f..a406cda5b 100644 --- a/effectful/handlers/jax/monoid.py +++ b/effectful/handlers/jax/monoid.py @@ -15,7 +15,7 @@ Sum, outer_stream, ) -from effectful.ops.semantics import evaluate, fwd, handler, typeof +from effectful.ops.semantics import evaluate, fvsof, fwd, handler, typeof from effectful.ops.syntax import ObjectInterpretation, deffn, implements from effectful.ops.types import Operation @@ -40,10 +40,11 @@ def _jax_args(args): """True iff ``args`` is non-empty and every arg is a concrete :class:`jax.Array` (no Terms). """ + typs = (typeof(a) for a in args) return ( bool(args) - and any(isinstance(a, jax.Array) for a in args) - and all(isinstance(a, jax.typing.ArrayLike) for a in args) + and any(issubclass(t, jax.Array) for t in typs) + and all(issubclass(t, jax.typing.ArrayLike) for t in typs) ) @@ -129,16 +130,24 @@ def reduce(self, monoid, body, streams): for stream_key, stream_body, streams_tail in outer_stream(streams): if not issubclass(typeof(stream_body), jax.Array): continue - with handler({stream_key: deffn(unbind_dims(stream_body, index))}): - eval_body = evaluate(body) - eval_streams_tail = evaluate(streams_tail) - assert isinstance(eval_streams_tail, dict) - reduce_tail = ( - monoid.reduce(eval_body, eval_streams_tail) - if len(eval_streams_tail) > 0 - else eval_body - ) - return reductor(bind_dims(reduce_tail, index), axis=0) + + if stream_key in fvsof(body): + with handler({stream_key: deffn(unbind_dims(stream_body, index))}): + eval_body = evaluate(body) + eval_streams_tail = evaluate(streams_tail) + assert isinstance(eval_streams_tail, dict) + reduce_tail = ( + monoid.reduce(eval_body, eval_streams_tail) + if len(eval_streams_tail) > 0 + else eval_body + ) + return reductor(bind_dims(reduce_tail, index), axis=0) + else: + # TODO: In this case, the stream is unused in the body. The body + # should be multiplied by the length of the stream. The current + # behavior is not efficient. + return fwd() + return fwd() diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index 62ee82b07..c7ee7567c 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -583,9 +583,9 @@ def test_reduce_lifted_1(outer, inner, backend): ) -def test_reduce_cartesian_1(backend): - a, i = define_vars("a", "i", typ=backend.scalar_typ) - A = define_vars("A", typ=backend.stream_typ) +def test_reduce_cartesian_1(): + a, i = define_vars("a", "i", typ=int) + A = define_vars("A", typ=tuple[int]) with handler(NormalizeIntp): term1 = Sum.reduce( @@ -596,9 +596,9 @@ def test_reduce_cartesian_1(backend): assert term1 == term2 -def test_reduce_cartesian_2(backend): - a, i = define_vars("a", "i", typ=backend.scalar_typ) - A = define_vars("A", typ=backend.stream_typ) +def test_reduce_cartesian_2(): + a, i = define_vars("a", "i", typ=int) + A = define_vars("A", typ=tuple[int]) with handler(NormalizeIntp): term1 = Sum.reduce(