diff --git a/effectful/handlers/jax/_handlers.py b/effectful/handlers/jax/_handlers.py index 308cdb76e..c5d104233 100644 --- a/effectful/handlers/jax/_handlers.py +++ b/effectful/handlers/jax/_handlers.py @@ -19,6 +19,7 @@ deffn, defop, syntactic_eq, + syntactic_hash, ) from effectful.ops.types import Expr, NotHandled, Operation, Term @@ -277,3 +278,9 @@ def _(x: jax.Array, other) -> bool: and x.shape == other.shape and bool((jnp.asarray(x) == jnp.asarray(other)).all()) ) + + +@syntactic_hash.register(jax.Array) +def _(x: jax.Array) -> int: + # Concrete arrays aren't hashable; hash by shape, dtype, and bytes. + return hash(("jax.Array", x.shape, str(x.dtype), bytes(jax.numpy.asarray(x)))) diff --git a/effectful/handlers/jax/monoid.py b/effectful/handlers/jax/monoid.py new file mode 100644 index 000000000..a406cda5b --- /dev/null +++ b/effectful/handlers/jax/monoid.py @@ -0,0 +1,162 @@ +import functools + +import jax + +import effectful.handlers.jax.numpy as jnp +from effectful.handlers.jax import bind_dims, unbind_dims +from effectful.handlers.jax.scipy.special import logsumexp +from effectful.ops.monoid import ( + CartesianProduct, + Max, + Min, + Monoid, + NormalizeIntp, + Product, + Sum, + outer_stream, +) +from effectful.ops.semantics import evaluate, fvsof, fwd, handler, typeof +from effectful.ops.syntax import ObjectInterpretation, deffn, implements +from effectful.ops.types import Operation + + +def cartesian_prod(x, y): + if x.ndim == 1: + x = x[:, None] + if y.ndim == 1: + y = y[:, None] + nx, dx = x.shape + ny, dy = y.shape + # Broadcast into (nx, ny, dx+dy), then flatten the first two axes + x_b = jnp.broadcast_to(x[:, None, :], (nx, ny, dx)) + y_b = jnp.broadcast_to(y[None, :, :], (nx, ny, dy)) + return jnp.concatenate([x_b, y_b], axis=-1).reshape(nx * ny, dx + dy) + + +LogSumExp = Monoid(name="LogSumExp", identity=jnp.asarray(float("-inf"))) + + +def _jax_args(args): + """True iff ``args`` is non-empty and every arg is a concrete + :class:`jax.Array` (no Terms). + """ + typs = (typeof(a) for a in args) + return ( + bool(args) + and any(issubclass(t, jax.Array) for t in typs) + and all(issubclass(t, jax.typing.ArrayLike) for t in typs) + ) + + +class SumPlusJax(ObjectInterpretation): + @implements(Sum.plus) + def plus(self, *args): + if not _jax_args(args): + return fwd() + return functools.reduce(jnp.add, args) + + +class ProductPlusJax(ObjectInterpretation): + @implements(Product.plus) + def plus(self, *args): + if not _jax_args(args): + return fwd() + return functools.reduce(jnp.multiply, args) + + +class MinPlusJax(ObjectInterpretation): + @implements(Min.plus) + def plus(self, *args): + if not _jax_args(args): + return fwd() + return functools.reduce(jnp.minimum, args) + + +class MaxPlusJax(ObjectInterpretation): + @implements(Max.plus) + def plus(self, *args): + if not _jax_args(args): + return fwd() + return functools.reduce(jnp.maximum, args) + + +class LogSumExpPlusJax(ObjectInterpretation): + @implements(LogSumExp.plus) + def plus(self, *args): + if not _jax_args(args): + return fwd() + return functools.reduce(jnp.logaddexp, args) + + +class CartesianProductPlusJax(ObjectInterpretation): + @implements(CartesianProduct.plus) + def plus(self, *args): + # Skip identity ``[()]`` args; short-circuit on zero ``[]``. Both + # sentinels arrive as Python lists alongside jax-array factors, so + # check for them explicitly before composing. + if not any(isinstance(a, jax.Array) for a in args): + return fwd() + result = None + for a in args: + if a is CartesianProduct.zero: + return CartesianProduct.zero + if a is CartesianProduct.identity: + continue + if not isinstance(a, jax.Array): + return fwd() + result = a if result is None else cartesian_prod(result, a) + return result if result is not None else CartesianProduct.identity + + +ARRAY_REDUCTORS = { + Sum: jnp.sum, + Product: jnp.prod, + Min: jnp.min, + Max: jnp.max, + LogSumExp: logsumexp, +} + + +class ArrayReduce(ObjectInterpretation): + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if monoid not in ARRAY_REDUCTORS or typeof(body) is not jax.Array: + return fwd() + if not streams: + return monoid.identity + + reductor = ARRAY_REDUCTORS[monoid] + index = Operation.define(jax.Array) + for stream_key, stream_body, streams_tail in outer_stream(streams): + if not issubclass(typeof(stream_body), jax.Array): + continue + + if stream_key in fvsof(body): + with handler({stream_key: deffn(unbind_dims(stream_body, index))}): + eval_body = evaluate(body) + eval_streams_tail = evaluate(streams_tail) + assert isinstance(eval_streams_tail, dict) + reduce_tail = ( + monoid.reduce(eval_body, eval_streams_tail) + if len(eval_streams_tail) > 0 + else eval_body + ) + return reductor(bind_dims(reduce_tail, index), axis=0) + else: + # TODO: In this case, the stream is unused in the body. The body + # should be multiplied by the length of the stream. The current + # behavior is not efficient. + return fwd() + + return fwd() + + +NormalizeIntp.extend( + ArrayReduce(), + SumPlusJax(), + ProductPlusJax(), + MinPlusJax(), + MaxPlusJax(), + LogSumExpPlusJax(), + CartesianProductPlusJax(), +) diff --git a/effectful/ops/monoid.py b/effectful/ops/monoid.py index 0d6e230c0..70bb50022 100644 --- a/effectful/ops/monoid.py +++ b/effectful/ops/monoid.py @@ -1,10 +1,10 @@ import collections.abc import functools import itertools -import numbers +import operator import typing -from collections import Counter, defaultdict -from collections.abc import Callable, Generator, Iterable, Iterator, Mapping +from collections import Counter, UserDict, defaultdict +from collections.abc import Callable, Generator, Iterable, Mapping from dataclasses import dataclass from graphlib import TopologicalSorter from typing import Annotated, Any @@ -14,21 +14,13 @@ from effectful.ops.syntax import ( ObjectInterpretation, Scoped, - _NumberTerm, deffn, implements, iter_, syntactic_eq, syntactic_hash, ) -from effectful.ops.types import ( - Expr, - Interpretation, - NotHandled, - Operation, - Term, - _CustomSingleDispatchMethod, -) +from effectful.ops.types import Expr, Interpretation, NotHandled, Operation, Term # Note: The streams value type should be something like Iterable[T], but some of # our target stream types (e.g. jax.Array) are not subtypes of Iterable @@ -43,31 +35,35 @@ ) -def order_streams[T](streams: Streams[T]) -> Iterable[tuple[Operation[[], T], Any]]: - """Determine an order to evaluate the streams based on their dependencies""" +def outer_stream( + streams: Streams, +) -> Iterable[tuple[Operation, Expr, dict[Operation, Expr]]]: + """Returns the streams that can be ordered outermost in the loop nest as + well as the remaining streams in the nest. + + """ stream_vars = set(streams.keys()) - dependencies = {k: fvsof(v) & stream_vars for k, v in streams.items()} - topo = TopologicalSorter(dependencies) + pred = {k: fvsof(v) & stream_vars for k, v in streams.items()} + topo = TopologicalSorter(pred) topo.prepare() - while topo.is_active(): - node_group = topo.get_ready() - for op in sorted(node_group): - yield (op, streams[op]) - topo.done(*node_group) + return ( + (op, streams[op], {k: v for (k, v) in streams.items() if k != op}) + for op in topo.get_ready() + ) class Monoid[T]: - kernel: Operation[[T, T], T] + """A monoid with ``plus`` and ``reduce`` :class:`Operation` s.""" + + _name: str identity: T - def __init__(self, kernel: Callable[[T, T], T], identity: T): + def __init__(self, identity: T, name: str): + self._name = name self.identity = identity - self.kernel = ( - kernel if isinstance(kernel, Operation) else Operation.define(kernel) - ) def __repr__(self): - return f"{type(self)}({self.kernel}, {self.identity})" + return f"Monoid({self._name!r})" def __eq__(self, other): return id(self) == id(other) @@ -75,166 +71,63 @@ def __eq__(self, other): def __hash__(self): return hash(id(self)) + # the weak typing allows us to write monoid.plus(monoid.identity, ) + # and monoid.plus(monoid.identity, ) @Operation.define - @_CustomSingleDispatchMethod - def plus[S](self, dispatch, *args: S) -> S: - """Monoid addition with broadcasting over common collection types, - callables, and interpretations. + def plus(self, *args: Any) -> Any: + """Monoid addition. Handlers supply per-monoid and broadcasting + behavior; the default rule only handles empty / Term cases. """ if not args: - return typing.cast(S, self.identity) - return dispatch(type(args[0]))(self, *args) - - @plus.register(object) # type: ignore[attr-defined] - def _(self, *args): - if any(isinstance(x, Term) for x in args): - raise NotHandled - return functools.reduce(self.kernel, args, self.identity) - - @plus.register(tuple) # type: ignore[attr-defined] - def _(self, *args): - return tuple(self.plus(*vs) for vs in zip(*args, strict=True)) - - @plus.register(Generator) # type: ignore[attr-defined] - def _(self, *args): - return (self.plus(*vs) for vs in zip(*args, strict=True)) - - @plus.register(Mapping) # type: ignore[attr-defined] - def _(self, *args): - if isinstance(args[0], Interpretation): - keys = args[0].keys() - for b in args[1:]: - if not isinstance(b, Interpretation): - raise TypeError(f"Expected interpretation but got {b}") - if not keys == b.keys(): - raise ValueError( - f"Expected interpretation of {keys} but got {b.keys()}" - ) - return {k: self.plus(*(handler(b)(b[k]) for b in args)) for k in keys} - - for b in args[1:]: - if not isinstance(b, Mapping): - raise TypeError(f"Expected mapping but got {b}") - all_values = collections.defaultdict(list) - for d in args: - for k, v in d.items(): - all_values[k].append(v) - return {k: self.plus(*vs) for (k, vs) in all_values.items()} + return self.identity + raise NotHandled @Operation.define - @functools.singledispatchmethod def reduce[A, B, U: Body]( self, body: Annotated[U, Scoped[A | B]], streams: Annotated[Streams, Scoped[A]], ) -> Annotated[U, Scoped[B]]: - if callable(body): - return typing.cast(U, lambda *a, **k: self.reduce(body(*a, **k), streams)) - - def generator(loop_order) -> Iterator[Interpretation]: - if len(loop_order) == 0: - return - - stream_key = loop_order[0][0] - stream_values = evaluate(streams[stream_key]) - stream_values_iter = iter(stream_values) # type: ignore[arg-type] - - # If we try to iterate and get a term instead of a real - # iterator, give up + """Reduce ``body`` over ``streams``. Handlers supply per-monoid and + broadcasting behavior; the default rule only handles the empty-stream + case. + """ + for stream_key, stream_body, streams_tail in outer_stream(streams): + if isinstance(stream_body, Term): + continue + stream_values_iter = iter(stream_body) if isinstance(stream_values_iter, Term) and stream_values_iter.op is iter_: - raise NotHandled - - if len(loop_order) == 1: - for val in stream_values_iter: - yield {stream_key: functools.partial(lambda v: v, val)} - else: - for val in stream_values_iter: - intp: Interpretation = { - stream_key: functools.partial(lambda v: v, val) - } - with handler(intp): - for intp2 in generator(loop_order[1:]): - yield coproduct(intp, intp2) - - loop_order = list(order_streams(streams)) - return self.plus( - *(handler(intp)(evaluate)(body) for intp in generator(loop_order)) - ) - - @reduce.register # type: ignore[attr-defined] - def _(self, body: Mapping, streams): - return {k: self.reduce(v, streams) for (k, v) in body.items()} - - @reduce.register # type: ignore[attr-defined] - def _(self, body: tuple, streams): - return tuple(self.reduce(x, streams) for x in body) - - @reduce.register # type: ignore[attr-defined] - def _(self, body: Generator, streams): - return (self.reduce(x, streams) for x in body) - - -def _is_monoid_plus(op: Operation) -> bool: - """True if ``op`` is the ``plus`` operation of some :class:`Monoid`.""" - owner = getattr(op, "__self__", None) - return isinstance(owner, Monoid) and op is owner.plus - - -def _is_monoid_reduce(op: Operation) -> bool: - """True if ``op`` is the ``reduce`` operation of some :class:`Monoid`.""" - owner = getattr(op, "__self__", None) - return isinstance(owner, Monoid) and op is owner.reduce + continue + new_reduces = [] + for stream_val in stream_values_iter: + with handler({stream_key: deffn(stream_val)}): + eval_args = evaluate((body, streams_tail)) + assert isinstance(eval_args, tuple) + new_reduces.append( + self.reduce(*eval_args) if streams_tail else eval_args[0] + ) + return self.plus(*new_reduces) + raise NotHandled class MonoidWithZero[T](Monoid[T]): zero: T - def __init__(self, kernel: Callable[[T, T], T], identity: T, zero: T): - super().__init__(kernel, identity) + def __init__(self, name: str, identity: T, zero: T): + super().__init__(name=name, identity=identity) self.zero = zero - def __repr__(self): - return f"{type(self)}({self.kernel}, {self.identity}, {self.zero})" - - -@Operation.define -def _arg_min[T]( - a: tuple[numbers.Number, T | None], b: tuple[numbers.Number, T | None] -) -> tuple[numbers.Number, T | None]: - if isinstance(a[0], Term) or isinstance(b[0], Term): - raise NotHandled - return b if b[0] < a[0] else a # type: ignore - - -@Operation.define -def _arg_max[T]( - a: tuple[numbers.Number, T | None], b: tuple[numbers.Number, T | None] -) -> tuple[numbers.Number, T | None]: - if isinstance(a[0], Term) or isinstance(b[0], Term): - raise NotHandled - return b if b[0] > a[0] else a # type: ignore - - -@Operation.define -def product[T]( - a: Iterable[tuple[T, ...] | T], b: Iterable[tuple[T, ...] | T] -) -> Iterable[tuple[T, ...]]: - if isinstance(a, Term) or isinstance(b, Term): - raise NotHandled - - def to_tuple(x): - return x if isinstance(x, tuple) else (x,) - return [to_tuple(x) + to_tuple(y) for (x, y) in itertools.product(a, b)] - - -Min = Monoid(kernel=min, identity=float("inf")) -Max = Monoid(kernel=max, identity=float("-inf")) -ArgMin = Monoid(kernel=_arg_min, identity=(float("inf"), None)) -ArgMax = Monoid(kernel=_arg_max, identity=(float("-inf"), None)) -Sum = Monoid(kernel=_NumberTerm.__add__, identity=0) -Product = MonoidWithZero(kernel=_NumberTerm.__mul__, identity=1, zero=0) -CartesianProduct = Monoid(kernel=product, identity=[()]) +Min = Monoid(name="Min", identity=float("inf")) +Max = Monoid(name="Max", identity=-float("inf")) +ArgMin = Monoid(name="ArgMin", identity=(Min.identity, None)) +ArgMax = Monoid(name="ArgMax", identity=(Max.identity, None)) +Sum = Monoid(name="Sum", identity=0) +Product = MonoidWithZero(name="Product", identity=1, zero=0) +# CartesianProduct values are "two-level indexable" (rows × positions). The +# identity ``[()]`` is one row of zero positions (composing with it preserves +# shape); the zero ``[]`` is no rows (absorbs under product). +CartesianProduct = MonoidWithZero(name="CartesianProduct", identity=[()], zero=[]) @dataclass @@ -268,6 +161,18 @@ def __call__(self, s: S, t: T) -> bool: ) +def _is_monoid_plus(op: Operation) -> bool: + """True if ``op`` is the ``plus`` operation of some :class:`Monoid`.""" + owner = getattr(op, "__self__", None) + return isinstance(owner, Monoid) and op is owner.plus + + +def _is_monoid_reduce(op: Operation) -> bool: + """True if ``op`` is the ``reduce`` operation of some :class:`Monoid`.""" + owner = getattr(op, "__self__", None) + return isinstance(owner, Monoid) and op is owner.reduce + + class PlusEmpty(ObjectInterpretation): """plus() = 0""" @@ -319,7 +224,7 @@ class PlusDistr(ObjectInterpretation): """x + (y * z) = x * y + x * z""" @implements(Monoid.plus) - def plus(self, monoid, *args): + def plus(self, monoid: Monoid, *args): if any( isinstance(x, Term) and _is_monoid_plus(x.op) @@ -417,24 +322,6 @@ def plus(self, monoid, *args): return fwd() -NormalizePlusIntp = functools.reduce( - coproduct, - typing.cast( - list[Interpretation], - [ - PlusEmpty(), - PlusSingle(), - PlusIdentity(), - PlusAssoc(), - PlusDistr(), - PlusZero(), - PlusConsecutiveDups(), - PlusDups(), - ], - ), -) - - class ReduceNoStreams(ObjectInterpretation): """Implements the identity reduce(R, ∅, body) = 0 @@ -668,18 +555,217 @@ def reduce(self, sum_monoid: Monoid, sum_body, sum_streams): return fwd() -NormalizeReduceIntp = functools.reduce( - coproduct, - typing.cast( - list[Interpretation], - [ - ReduceNoStreams(), - ReduceFusion(), - ReduceSplit(), - ReduceFactorization(), - ReduceDistributeCartesianProduct(), - ], - ), +class MonoidOverCallable(ObjectInterpretation): + """``monoid.reduce(f, streams) = lambda *a: monoid.reduce(f(*a), streams)``.""" + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if isinstance(body, Term) or not isinstance(body, Callable): + return fwd() + return lambda *a, **k: monoid.reduce(body(*a, **k), streams) + + @implements(Monoid.plus) + def plus(self, monoid, *args): + if not args or any( + isinstance(arg, Term) or not isinstance(arg, Callable) for arg in args + ): + return fwd() + return lambda *a, **k: monoid.plus(*(arg(*a, **k) for arg in args)) + + +class MonoidOverMapping(ObjectInterpretation): + """``monoid.reduce({k: v_k}, streams) = {k: monoid.reduce(v_k, streams)}``.""" + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if isinstance(body, Term) or not isinstance(body, Mapping): + return fwd() + return {k: monoid.reduce(v, streams) for (k, v) in body.items()} + + @implements(Monoid.plus) + def plus(self, monoid, *args): + if not args or not isinstance(args[0], Mapping): + return fwd() + + if isinstance(args[0], Interpretation): + keys = args[0].keys() + for b in args[1:]: + if not isinstance(b, Interpretation): + raise TypeError(f"Expected interpretation but got {b}") + if not keys == b.keys(): + raise ValueError( + f"Expected interpretation of {keys} but got {b.keys()}" + ) + return {k: monoid.plus(*(handler(b)(b[k]) for b in args)) for k in keys} + + for b in args[1:]: + if not isinstance(b, Mapping): + raise TypeError(f"Expected mapping but got {b}") + all_values = collections.defaultdict(list) + for d in args: + for k, v in d.items(): + all_values[k].append(v) + return {k: monoid.plus(*vs) for (k, vs) in all_values.items()} + + +def _scalar_args(args): + """True iff ``args`` is non-empty and every arg is a concrete int/float.""" + return ( + bool(args) + and not any(isinstance(x, Term) for x in args) + and all(isinstance(x, int | float) for x in args) + ) + + +class SumPlus(ObjectInterpretation): + """Scalar implementation of :data:`Sum`.""" + + @implements(Sum.plus) + def plus(self, *args): + if not _scalar_args(args): + return fwd() + return sum(args) + + +class MinPlus(ObjectInterpretation): + """Scalar implementation of :data:`Min`.""" + + @implements(Min.plus) + def plus(self, *args): + if not _scalar_args(args): + return fwd() + return min(args) + + +class MaxPlus(ObjectInterpretation): + """Scalar implementation of :data:`Max`.""" + + @implements(Max.plus) + def plus(self, *args): + if not _scalar_args(args): + return fwd() + return max(args) + + +class ProductPlus(ObjectInterpretation): + """Scalar implementation of :data:`Product`.""" + + @implements(Product.plus) + def plus(self, *args): + if not _scalar_args(args): + return fwd() + return functools.reduce(operator.mul, args) + + +class ArgMinPlus(ObjectInterpretation): + """Scalar score implementation of :data:`ArgMin`.""" + + @implements(ArgMin.plus) + def plus(self, *args): + if not args or not all(isinstance(a, tuple) for a in args): + return fwd() + if any(isinstance(a[0], Term) for a in args): + return fwd() + if not all(isinstance(a[0], int | float) for a in args): + return fwd() + return min(args, key=lambda a: a[0]) + + +class ArgMaxPlus(ObjectInterpretation): + """Scalar score implementation of :data:`ArgMax`.""" + + @implements(ArgMax.plus) + def plus(self, *args): + if not args or not all(isinstance(a, tuple) for a in args): + return fwd() + if any(isinstance(a[0], Term) for a in args): + return fwd() + if not all(isinstance(a[0], int | float) for a in args): + return fwd() + return max(args, key=lambda a: a[0]) + + +class CartesianProductPlus(ObjectInterpretation): + """Pure-Python implementation of :data:`CartesianProduct`.""" + + @implements(CartesianProduct.plus) + def plus(self, *args): + if not args: + return fwd() + if any(isinstance(x, Term) for x in args): + return fwd() + if not all(isinstance(x, Iterable) for x in args): + return fwd() + + def to_tuple(x): + return x if isinstance(x, tuple) else (x,) + + return [ + sum((to_tuple(v) for v in vals), ()) for vals in itertools.product(*args) + ] + + +is_scalar = _ExtensiblePredicate({Min, Max, Sum, Product}) + + +class MonoidOverSequence(ObjectInterpretation): + @implements(Monoid.plus) + def plus(self, monoid, *args): + if ( + not is_scalar(monoid) + or not args + or not isinstance(args[0], tuple | list | Generator) + ): + return fwd() + zipped = zip(*args, strict=True) + result = (monoid.plus(*vs) for vs in zipped) + if isinstance(args[0], tuple | list): + return type(args[0])(result) + return result + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if not is_scalar(monoid) or not isinstance(body, tuple | list | Generator): + return fwd() + result = (monoid.reduce(x, streams) for x in body) + if isinstance(body, tuple | list): + return type(body)(result) + return result + + +class _ExtensibleInterpretation(UserDict, Interpretation): + def extend(self, *intps: Interpretation) -> typing.Self: + for intp in intps: + self.data = coproduct(self.data, intp) # type: ignore[assignment] + return self + + +NormalizeIntp = _ExtensibleInterpretation().extend( + MonoidOverSequence(), + MonoidOverMapping(), + MonoidOverCallable(), + ReduceNoStreams(), + ReduceFusion(), + ReduceSplit(), + ReduceFactorization(), + ReduceDistributeCartesianProduct(), + PlusEmpty(), + PlusSingle(), + PlusIdentity(), + PlusAssoc(), + PlusDistr(), + PlusZero(), + PlusConsecutiveDups(), + PlusDups(), + SumPlus(), + MinPlus(), + MaxPlus(), + ProductPlus(), + ArgMinPlus(), + ArgMaxPlus(), + CartesianProductPlus(), ) +"""``NormalizeIntp``applies pure-Term rewrites (associativity, distributivity, +identity elimination, fusion, factorization, etc.). -NormalizeIntp = coproduct(NormalizePlusIntp, NormalizeReduceIntp) +""" diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index 8fb12598f..5ea04fcb8 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -849,6 +849,8 @@ def _(x: collections.abc.Sequence, other) -> bool: @syntactic_eq.register(object) @syntactic_eq.register(str | bytes) def _(x: object, other) -> bool: + if isinstance(other, Term): # Terms often override __eq__ + return False return x == other diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py index 9b311b257..f15103e30 100644 --- a/tests/_monoid_helpers.py +++ b/tests/_monoid_helpers.py @@ -1,23 +1,60 @@ +import itertools from collections.abc import Callable, Mapping, Sequence +from dataclasses import dataclass from typing import Any, get_args, get_origin +import jax +from hypothesis import given, settings from hypothesis import strategies as st -from effectful.ops.syntax import deffn -from effectful.ops.types import Operation +import effectful.handlers.jax.numpy as _jnp +from effectful.internals.runtime import interpreter +from effectful.ops.monoid import NormalizeIntp +from effectful.ops.semantics import apply, evaluate, handler +from effectful.ops.syntax import _BaseTerm, defdata, deffn, syntactic_eq +from effectful.ops.types import NotHandled, Operation, Term + +_JAX_ARRAY_SHAPE = (2,) + + +def _jax_array_value_strategy() -> st.SearchStrategy[jax.Array]: + return st.lists( + st.integers(min_value=-5, max_value=5), + min_size=_JAX_ARRAY_SHAPE[0], + max_size=_JAX_ARRAY_SHAPE[0], + ).map(lambda xs: jax.numpy.asarray(xs, dtype=jax.numpy.float32)) + + +# Unary jax fns map a scalar to a 1-D array (analogous to ``_UNARY_LIST_FNS`` +# for ints). Uses the effectful-wrapped jnp so named-dim broadcasting works. +_UNARY_JAX_FNS: list[Callable[[jax.Array], jax.Array]] = [ + lambda a: _jnp.stack([a, a + 1]), + lambda a: _jnp.stack([a, -a]), + lambda a: _jnp.stack([a, a + 1, 2 * a]), +] + +_BINARY_JAX_FNS: list[Callable[[jax.Array, jax.Array], jax.Array]] = [ + lambda a, b: a + b, + lambda a, b: a - b, + lambda a, b: a * b, +] def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: """Strategy for the value an *0-arg* Operation should return.""" if annotation is int: - return st.integers() + return st.integers(min_value=-100, max_value=100) if annotation is float: return st.floats(allow_nan=False) if get_origin(annotation) is list and get_args(annotation) == (int,): - return st.lists(st.integers(), max_size=2) + return st.lists(st.integers(min_value=-100, max_value=100), max_size=2) + if annotation is jax.Array: + return _jax_array_value_strategy() + if get_origin(annotation) is list and get_args(annotation) == (jax.Array,): + return st.lists(_jax_array_value_strategy(), max_size=2) raise NotImplementedError( f"No value strategy for return annotation {annotation!r}; " - "supported: int, list[int]" + "supported: int, list[int], jax.Array, list[jax.Array]" ) @@ -46,6 +83,13 @@ def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]: lambda x: [0, x, x + 1], ] +_UNARY_JAX_LIST_FNS: list[Callable[[jax.Array], list[jax.Array]]] = [ + lambda _x: [], + lambda x: [x], + lambda x: [x, x + 1], + lambda x: [x, -x], +] + def _strategy_for_op(op: Operation) -> st.SearchStrategy[Callable[..., Any]]: """Pick a strategy producing a callable suitable for binding `op` in an @@ -64,8 +108,18 @@ def _strategy_for_op(op: Operation) -> st.SearchStrategy[Callable[..., Any]]: return st.sampled_from(_BINARY_NUM_FNS) if get_origin(ret) is list and get_args(ret) == (int,) and param_types == (int,): return st.sampled_from(_UNARY_LIST_FNS) + if ret is jax.Array and param_types == (jax.Array,): + return st.sampled_from(_UNARY_JAX_FNS) + if ret is jax.Array and param_types == (jax.Array, jax.Array): + return st.sampled_from(_BINARY_JAX_FNS) + if ( + get_origin(ret) is list + and get_args(ret) == (jax.Array,) + and param_types == (jax.Array,) + ): + return st.sampled_from(_UNARY_JAX_LIST_FNS) raise NotImplementedError( - f"Function-typed free var must return int or list[int]; got {ret!r} for {op}" + f"No callable strategy for free var with return {ret!r}, params {param_types!r}" ) @@ -82,4 +136,220 @@ def random_interpretation( return intp -__all__ = ["random_interpretation"] +def define_vars(*names, typ=int): + if len(names) == 1: + return Operation.define(typ, name=names[0]) + return tuple(Operation.define(typ, name=n) for n in names) + + +def syntactic_eq_alpha(x, y) -> bool: + """Alpha-equivalence-respecting variant of ``syntactic_eq``. + + Walks each expression bottom-up with :func:`evaluate` and renames + every bound variable to a deterministic canonical Operation. The + canonical names are assigned by a counter that increments in + ``evaluate``'s natural traversal order, so two alpha-equivalent + expressions canonicalize to syntactically identical results. + """ + + _op_cache: dict[int, Operation] = {} + + def _canonical_op(idx: int, op: Operation) -> Operation: + """Cached canonical Operation, keyed by encounter index. + + Cached so that two independent canonicalize runs return the same + Operation object for the same index — letting ``syntactic_eq`` + compare canonical forms by Operation identity. + """ + if idx in _op_cache: + return _op_cache[idx] + + op = Operation.define(op, name=f"__cv_{idx}") + _op_cache[idx] = op + return op + + cx = _canonicalize(x, _canonical_op) + cy = _canonicalize(y, _canonical_op) + return syntactic_eq(cx, cy) + + +def _canonicalize(expr, _canonical_op): + counter = itertools.count() + + def _substitute(arg, renaming): + """Apply a bound-variable renaming using ``evaluate`` for traversal.""" + if not renaming: + return arg + with interpreter({apply: _BaseTerm, **renaming}): + return evaluate(arg) + + def _bound_var_order(args, kwargs, bound_set: set[Operation]) -> list[Operation]: + """Return bound variables in deterministic encounter order.""" + seen: list[Operation] = [] + seen_set: set[Operation] = set() + + def _capture(op, *a, **kw): + if op in bound_set and op not in seen_set: + seen.append(op) + seen_set.add(op) + return defdata(op, *a, **kw) + + # ``evaluate`` walks Terms, lists, tuples, mappings, dataclasses, + # etc. for free; the apply handler captures bound vars used as + # ``x()`` anywhere in the body. + with interpreter({apply: _capture}): + evaluate((args, kwargs)) + + # Binders bypass the apply handler. Pick them up with a small structural + # walk that visits dict keys too. + def _walk_bare(obj): + if isinstance(obj, Operation): + if obj in bound_set and obj not in seen_set: + seen.append(obj) + seen_set.add(obj) + elif isinstance(obj, dict): + for k, v in obj.items(): + _walk_bare(k) + _walk_bare(v) + elif isinstance(obj, list | set | frozenset | tuple): + for v in obj: + _walk_bare(v) + + _walk_bare((args, kwargs)) + return seen + + def _apply_canonical(op, *args, **kwargs) -> Term: + bindings = op.__fvs_rule__(*args, **kwargs) + all_bound: set[Operation] = set().union( + *bindings.args, *bindings.kwargs.values() + ) + if not all_bound: + return _BaseTerm(op, *args, **kwargs) + + order = _bound_var_order(args, kwargs, all_bound) + canonical = {var: _canonical_op(next(counter), var) for var in order} + assert all_bound <= set(order) + + new_args = tuple( + _substitute( + arg, {v: canonical[v] for v in bindings.args[i] if v in canonical} + ) + for i, arg in enumerate(args) + ) + new_kwargs = { + k: _substitute( + v, + {var: canonical[var] for var in bindings.kwargs[k] if var in canonical}, + ) + for k, v in kwargs.items() + } + + # avoid the renaming from defdata + return _BaseTerm(op, *new_args, **new_kwargs) + + with interpreter({apply: _apply_canonical}): + return evaluate(expr) + + +@dataclass(frozen=True) +class Backend: + """A value-domain spec used to share monoid tests across int and jax.Array + backends. Provides the concrete value type, the hypothesis strategy for + drawing scalars in property tests, and an equality predicate that works + for that domain. + """ + + name: str + scalar_typ: Any + stream_typ: Any + scalar_strategy: st.SearchStrategy[Any] + eq: Callable[[Any, Any], bool] + + def fresh_op(self, name: str, n_args: int = 1, ret: str = "scalar") -> Operation: + """Build a fresh, unhandled Operation whose parameter and return + annotations are derived from this backend. + + ``ret`` is ``"scalar"`` for a scalar return or ``"stream"`` for a + stream-of-scalar return. The operation has ``n_args`` parameters, + each of type ``scalar_typ``. + """ + scalar = self.scalar_typ + out = self.stream_typ if ret == "stream" else scalar + params = ", ".join(f"_a{i}" for i in range(n_args)) + ns: dict[str, Any] = {"NotHandled": NotHandled} + exec(f"def _fn({params}):\n raise NotHandled\n", ns) + fn = ns["_fn"] + fn.__annotations__ = { + **{f"_a{i}": scalar for i in range(n_args)}, + "return": out, + } + return Operation.define(fn, name=name) + + +def _int_eq(a: Any, b: Any) -> bool: + return not isinstance(a, Term) and not isinstance(b, Term) and a == b + + +def _jax_eq(a: Any, b: Any) -> bool: + def _leaf_eq(x: Any, y: Any) -> bool: + return bool(jax.numpy.all(jax.numpy.isclose(x, y, equal_nan=True))) + + try: + leaves = jax.tree.leaves(jax.tree.map(_leaf_eq, a, b)) + except (ValueError, TypeError): + return False + return all(leaves) + + +def check_rewrite( + lhs, + rhs, + rule, + *, + backend: Backend, + free_vars=[], + max_examples: int = 25, + deadline=None, +) -> None: + with handler(rule): + norm = evaluate(lhs) + assert syntactic_eq_alpha(norm, rhs) + + @given(intp=random_interpretation(free_vars)) + @settings(max_examples=max_examples, deadline=deadline) + def _check_semantics(intp): + with handler(NormalizeIntp), handler(intp): + lhs_val = evaluate(lhs) + rhs_val = evaluate(rhs) + assert backend.eq(lhs_val, rhs_val) + + _check_semantics() + + +INT_BACKEND = Backend( + name="int", + scalar_typ=int, + stream_typ=list[int], + scalar_strategy=st.integers(min_value=-100, max_value=100), + eq=_int_eq, +) + + +JAX_BACKEND = Backend( + name="jax", + scalar_typ=jax.Array, + stream_typ=jax.Array, + scalar_strategy=_jax_array_value_strategy(), + eq=_jax_eq, +) + + +__all__ = [ + "Backend", + "INT_BACKEND", + "JAX_BACKEND", + "random_interpretation", + "define_vars", + "syntactic_eq_alpha", + "check_rewrite", +] diff --git a/tests/test_handlers_jax_monoid.py b/tests/test_handlers_jax_monoid.py new file mode 100644 index 000000000..35d041fe2 --- /dev/null +++ b/tests/test_handlers_jax_monoid.py @@ -0,0 +1,96 @@ +import jax +import pytest + +import effectful.handlers.jax.numpy as jnp +from effectful.handlers.jax import bind_dims, unbind_dims +from effectful.handlers.jax.monoid import ArrayReduce, LogSumExp +from effectful.handlers.jax.scipy.special import logsumexp +from effectful.ops.monoid import Max, Min, Product, Sum +from tests._monoid_helpers import JAX_BACKEND, Backend, check_rewrite, define_vars + +MONOIDS = [ + pytest.param(Sum, jnp.sum, id="Sum"), + pytest.param(Product, jnp.prod, id="Product"), + pytest.param(Min, jnp.min, id="Min"), + pytest.param(Max, jnp.max, id="Max"), + pytest.param(LogSumExp, logsumexp, id="LogSumExp"), +] + + +@pytest.fixture +def backend() -> Backend: + return JAX_BACKEND + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_array_1(monoid, reductor, backend: Backend): + (x, k) = define_vars("x", "k", typ=jax.Array) + X = define_vars("X", typ=backend.stream_typ) + + lhs = monoid.reduce(x(), {x: X()}) + rhs = reductor(bind_dims(unbind_dims(X(), k), k), axis=0) + + check_rewrite( + lhs=lhs, rhs=rhs, rule=ArrayReduce(), backend=backend, free_vars=[x, X, k] + ) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_array_2(monoid, reductor, backend: Backend): + (x, y, k1, k2) = define_vars("x", "y", "k1", "k2", typ=backend.scalar_typ) + (X, Y) = define_vars("X", "Y", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=2, ret="scalar") + + lhs = monoid.reduce(f(x(), y()), {x: X(), y: Y()}) + rhs = reductor( + bind_dims( + reductor( + bind_dims(f(unbind_dims(X(), k1), unbind_dims(Y(), k2)), k2), + axis=0, + ), + k1, + ), + axis=0, + ) + + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=ArrayReduce(), + backend=backend, + free_vars=[x, y, k1, k2, X, Y, f], + ) + + +@pytest.mark.parametrize("monoid,reductor", MONOIDS) +def test_reduce_array_3(monoid, reductor, backend: Backend): + """Stream `y` is `g(x())` — depends on the bound element of X. The reducer + must inline ``g`` along the same named dim used to unbind `x`.""" + (x, y, k1, k2) = define_vars("x", "y", "k1", "k2", typ=backend.scalar_typ) + X = define_vars("X", typ=backend.stream_typ) + + f = backend.fresh_op("f", n_args=2, ret="scalar") + g = backend.fresh_op("g", n_args=1, ret="stream") + + lhs = monoid.reduce(f(x(), y()), {x: X(), y: g(x())}) + rhs = reductor( + bind_dims( + reductor( + bind_dims( + f(unbind_dims(X(), k1), unbind_dims(g(unbind_dims(X(), k1)), k2)), + k2, + ), + axis=0, + ), + k1, + ), + axis=0, + ) + + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=ArrayReduce(), + backend=backend, + free_vars=[x, y, k1, k2, X, f, g], + ) diff --git a/tests/test_ops_monoid.py b/tests/test_ops_monoid.py index d881869ac..c7ee7567c 100644 --- a/tests/test_ops_monoid.py +++ b/tests/test_ops_monoid.py @@ -1,29 +1,51 @@ -import functools -import itertools import typing import pytest -from hypothesis import given, settings +from hypothesis import HealthCheck, given, settings from hypothesis import strategies as st -from effectful.internals.runtime import interpreter +import effectful.handlers.jax.monoid # noqa: F401 from effectful.ops.monoid import ( CartesianProduct, Max, Min, Monoid, + MonoidOverMapping, + MonoidOverSequence, NormalizeIntp, + PlusAssoc, + PlusConsecutiveDups, + PlusDistr, + PlusDups, + PlusEmpty, + PlusIdentity, + PlusSingle, + PlusZero, Product, + ReduceDistributeCartesianProduct, + ReduceFactorization, + ReduceFusion, + ReduceNoStreams, + ReduceSplit, Sum, distributes_over, - is_commutative, ) -from effectful.ops.semantics import apply, evaluate, fvsof, handler -from effectful.ops.syntax import _BaseTerm, defdata, syntactic_eq -from effectful.ops.types import NotHandled, Operation -from tests._monoid_helpers import random_interpretation +from effectful.ops.semantics import fvsof, handler +from effectful.ops.types import Operation +from tests._monoid_helpers import ( + INT_BACKEND, + JAX_BACKEND, + Backend, + check_rewrite, + define_vars, + syntactic_eq_alpha, +) + + +@pytest.fixture(params=[INT_BACKEND, JAX_BACKEND], ids=["int", "jax"]) +def backend(request) -> Backend: + return request.param -_INT = st.integers(min_value=-100, max_value=100) ALL_MONOIDS = [ pytest.param(Sum, id="Sum"), @@ -61,247 +83,183 @@ ] -def define_vars(*names, typ=int): - if len(names) == 1: - return Operation.define(typ, name=names[0]) - return tuple(Operation.define(typ, name=n) for n in names) - - -@functools.cache -def _canonical_op(idx: int) -> Operation: - """Globally cached canonical Operation, keyed by encounter index. - - Cached so that two independent canonicalize runs return the same - Operation object for the same index — letting ``syntactic_eq`` - compare canonical forms by Operation identity. - """ - return Operation.define(int, name=f"__cv_{idx}") - - -def syntactic_eq_alpha(x, y) -> bool: - """Alpha-equivalence-respecting variant of ``syntactic_eq``. - - Walks each expression bottom-up with :func:`evaluate` and renames - every bound variable to a deterministic canonical Operation. The - canonical names are assigned by a counter that increments in - ``evaluate``'s natural traversal order, so two alpha-equivalent - expressions canonicalize to syntactically identical results. - """ - return syntactic_eq(_canonicalize(x), _canonicalize(y)) - - -def _canonicalize(expr): - counter = itertools.count() - - def _substitute(arg, renaming): - """Apply a bound-variable renaming using ``evaluate`` for traversal.""" - if not renaming: - return arg - with interpreter({apply: _BaseTerm, **renaming}): - return evaluate(arg) - - def _bound_var_order(args, kwargs, bound_set): - """Return bound variables in deterministic encounter order.""" - seen: list[Operation] = [] - seen_set: set[Operation] = set() - - def _capture(op, *a, **kw): - if op in bound_set and op not in seen_set: - seen.append(op) - seen_set.add(op) - return defdata(op, *a, **kw) - - # ``evaluate`` walks Terms, lists, tuples, mappings, dataclasses, - # etc. for free; the apply handler captures bound vars used as - # ``x()`` anywhere in the body. - with interpreter({apply: _capture}): - evaluate((args, kwargs)) - - # Binders bypass the apply handler. Pick them up with a small structural - # walk that visits dict keys too. - def _walk_bare(obj): - if isinstance(obj, Operation): - if obj in bound_set and obj not in seen_set: - seen.append(obj) - seen_set.add(obj) - elif isinstance(obj, dict): - for k, v in obj.items(): - _walk_bare(k) - _walk_bare(v) - elif isinstance(obj, list | set | frozenset | tuple): - for v in obj: - _walk_bare(v) - - _walk_bare((args, kwargs)) - return seen - - def _apply_canonical(op, *args, **kwargs): - bindings = op.__fvs_rule__(*args, **kwargs) - all_bound: set[Operation] = set().union( - *bindings.args, *bindings.kwargs.values() - ) - if not all_bound: - return _BaseTerm(op, *args, **kwargs) - - order = _bound_var_order(args, kwargs, all_bound) - canonical = {var: _canonical_op(next(counter)) for var in order} - assert all_bound <= set(order) - - new_args = tuple( - _substitute( - arg, {v: canonical[v] for v in bindings.args[i] if v in canonical} - ) - for i, arg in enumerate(args) - ) - new_kwargs = { - k: _substitute( - v, - {var: canonical[var] for var in bindings.kwargs[k] if var in canonical}, - ) - for k, v in kwargs.items() - } - - # avoid the renaming from defdata - return _BaseTerm(op, *new_args, **new_kwargs) - - with interpreter({apply: _apply_canonical}): - return evaluate(expr) - - @pytest.mark.parametrize("monoid", ALL_MONOIDS) -@given(a=_INT, b=_INT, c=_INT) -@settings(max_examples=50, deadline=None) -def test_associativity(monoid, a, b, c): - left = monoid.plus(monoid.plus(a, b), c) - right = monoid.plus(a, monoid.plus(b, c)) - assert left == right +@given(data=st.data()) +@settings( + max_examples=50, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], +) +def test_associativity(monoid, backend, data): + a = data.draw(backend.scalar_strategy) + b = data.draw(backend.scalar_strategy) + c = data.draw(backend.scalar_strategy) + with handler(NormalizeIntp): + left = monoid.plus(monoid.plus(a, b), c) + right = monoid.plus(a, monoid.plus(b, c)) + assert backend.eq(left, right) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -@given(a=_INT) -@settings(max_examples=50, deadline=None) -def test_identity(monoid, a): - assert monoid.plus(monoid.identity, a) == a - assert monoid.plus(a, monoid.identity) == a +@given(data=st.data()) +@settings( + max_examples=50, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], +) +def test_identity(monoid, backend, data): + a = data.draw(backend.scalar_strategy) + with handler(NormalizeIntp): + assert backend.eq(monoid.plus(monoid.identity, a), a) + assert backend.eq(monoid.plus(a, monoid.identity), a) @pytest.mark.parametrize("monoid", COMMUTATIVE) -@given(a=_INT, b=_INT) -@settings(max_examples=50, deadline=None) -def test_commutativity(monoid, a, b): - assert monoid.plus(a, b) == monoid.plus(b, a) +@given(data=st.data()) +@settings( + max_examples=50, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], +) +def test_commutativity(monoid, backend, data): + a = data.draw(backend.scalar_strategy) + b = data.draw(backend.scalar_strategy) + with handler(NormalizeIntp): + assert backend.eq(monoid.plus(a, b), monoid.plus(b, a)) @pytest.mark.parametrize("monoid", IDEMPOTENT) -@given(a=_INT) -@settings(max_examples=50, deadline=None) -def test_idempotence(monoid, a): - assert monoid.plus(a, a) == a +@given(data=st.data()) +@settings( + max_examples=50, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], +) +def test_idempotence(monoid, backend, data): + a = data.draw(backend.scalar_strategy) + with handler(NormalizeIntp): + assert backend.eq(monoid.plus(a, a), a) @pytest.mark.parametrize("monoid", WITH_ZERO) -@given(a=_INT) -@settings(max_examples=50, deadline=None) -def test_zero_absorbs(monoid, a): - assert monoid.plus(monoid.zero, a) == monoid.zero - assert monoid.plus(a, monoid.zero) == monoid.zero - - -def _check_pair(lhs, rhs, *, free_vars=[], max_examples: int = 25) -> None: - """Run structural + semantic checks on a TermPair.""" +@given(data=st.data()) +@settings( + max_examples=50, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], +) +def test_zero_absorbs(monoid, backend, data): + a = data.draw(backend.scalar_strategy) with handler(NormalizeIntp): - norm = evaluate(lhs) - - assert syntactic_eq_alpha(norm, rhs) + assert backend.eq(monoid.plus(monoid.zero, a), monoid.zero) + assert backend.eq(monoid.plus(a, monoid.zero), monoid.zero) - @given(intp=random_interpretation(free_vars)) - @settings(max_examples=max_examples, deadline=None) - def _check_semantics(intp): - with handler(intp): - lhs_val = evaluate(lhs) - rhs_val = evaluate(rhs) - assert lhs_val == rhs_val - _check_semantics() +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_empty(monoid, backend): + check_rewrite( + lhs=monoid.plus(), rhs=monoid.identity, rule=PlusEmpty(), backend=backend + ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_empty(monoid): - _check_pair(lhs=monoid.plus(), rhs=monoid.identity) +def test_plus_single(monoid, backend): + x = define_vars("x", typ=backend.scalar_typ) + check_rewrite( + lhs=monoid.plus(x()), rhs=x(), rule=PlusSingle(), backend=backend, free_vars=[x] + ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_single(monoid): - x = define_vars("x", typ=type(monoid.identity)) - _check_pair(lhs=monoid.plus(x()), rhs=x(), free_vars=[x]) +def test_plus_identity_right(monoid, backend): + x = define_vars("x", typ=backend.scalar_typ) + lhs = monoid.plus(x(), monoid.identity) + rhs = monoid.plus(x()) -@pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_identity_right(monoid): - x = define_vars("x", typ=type(monoid.identity)) - _check_pair(lhs=monoid.plus(x(), monoid.identity), rhs=x(), free_vars=[x]) + check_rewrite(lhs=lhs, rhs=rhs, rule=PlusIdentity(), backend=backend, free_vars=[x]) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_identity_left(monoid): - x = define_vars("x", typ=type(monoid.identity)) - _check_pair(lhs=monoid.plus(monoid.identity, x()), rhs=x(), free_vars=[x]) +def test_plus_identity_left(monoid, backend): + x = define_vars("x", typ=backend.scalar_typ) + + lhs = monoid.plus(monoid.identity, x()) + rhs = monoid.plus(x()) + + check_rewrite(lhs=lhs, rhs=rhs, rule=PlusIdentity(), backend=backend, free_vars=[x]) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_assoc_right(monoid): - x, y, z = define_vars("x", "y", "z", typ=type(monoid.identity)) - _check_pair( +def test_plus_assoc_right(monoid, backend): + x, y, z = define_vars("x", "y", "z", typ=backend.scalar_typ) + check_rewrite( lhs=monoid.plus(x(), monoid.plus(y(), z())), rhs=monoid.plus(x(), y(), z()), + rule=PlusAssoc(), + backend=backend, free_vars=[x, y, z], ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_assoc_left(monoid): - x, y, z = define_vars("x", "y", "z", typ=type(monoid.identity)) - _check_pair( +def test_plus_assoc_left(monoid, backend): + x, y, z = define_vars("x", "y", "z", typ=backend.scalar_typ) + check_rewrite( lhs=monoid.plus(monoid.plus(x(), y()), z()), rhs=monoid.plus(x(), y(), z()), + rule=PlusAssoc(), + backend=backend, free_vars=[x, y, z], ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_sequence(monoid): - a, b, c, d = define_vars("a", "b", "c", "d", typ=type(monoid.identity)) - _check_pair( +def test_plus_sequence(monoid, backend): + a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) + check_rewrite( lhs=monoid.plus((a(), b()), (c(), d())), rhs=(monoid.plus(a(), c()), monoid.plus(b(), d())), + rule=MonoidOverSequence(), + backend=backend, free_vars=[a, b, c, d], ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_plus_mapping(monoid): - a, b, c, d = define_vars("a", "b", "c", "d", typ=type(monoid.identity)) - _check_pair( - lhs=monoid.plus({"x": a(), "y": b()}, {"x": c(), "z": d()}), - rhs={"x": monoid.plus(a(), c()), "y": b(), "z": d()}, +def test_plus_mapping(monoid, backend): + a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) + + lhs = monoid.plus({0: a(), 1: b()}, {0: c(), 2: d()}) + rhs = {0: monoid.plus(a(), c()), 1: monoid.plus(b()), 2: monoid.plus(d())} + + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=MonoidOverMapping(), + backend=backend, free_vars=[a, b, c, d], ) -def test_plus_distributes(): - a, b, c, d = define_vars("a", "b", "c", "d") +def test_plus_distributes(backend): + a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) lhs = Product.plus(Sum.plus(a(), b()), Sum.plus(c(), d())) - rhs = Sum.plus( - Product.plus(a(), c()), - Product.plus(a(), d()), - Product.plus(b(), c()), - Product.plus(b(), d()), + rhs = Product.plus( + Sum.plus( + Product.plus(a(), c()), + Product.plus(a(), d()), + Product.plus(b(), c()), + Product.plus(b(), d()), + ) + ) + check_rewrite( + lhs=lhs, rhs=rhs, rule=PlusDistr(), backend=backend, free_vars=[a, b, c, d] ) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[a, b, c, d]) -def test_plus_distributes_constant(): - a, b, c, d = define_vars("a", "b", "c", "d") +def test_plus_distributes_constant(backend): + a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) lhs = Product.plus(Sum.plus(a(), b()), Sum.plus(c(), d()), 5) rhs = Product.plus( 5, @@ -312,11 +270,13 @@ def test_plus_distributes_constant(): Product.plus(b(), d()), ), ) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[a, b, c, d]) + check_rewrite( + lhs=lhs, rhs=rhs, rule=PlusDistr(), backend=backend, free_vars=[a, b, c, d] + ) -def test_plus_distributes_multiple(): - a, b, c, d = define_vars("a", "b", "c", "d") +def test_plus_distributes_multiple(backend): + a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) lhs = Sum.plus( Min.plus(a(), b()), Min.plus(c(), d()), @@ -337,72 +297,123 @@ def test_plus_distributes_multiple(): Sum.plus(b(), d()), ), ) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[a, b, c, d]) + check_rewrite( + lhs=lhs, rhs=rhs, rule=PlusDistr(), backend=backend, free_vars=[a, b, c, d] + ) @pytest.mark.parametrize("monoid", IDEMPOTENT) -def test_plus_idempotent_consecutive(monoid): +def test_plus_idempotent_consecutive(monoid, backend): """``a, a, b → a, b`` — only consecutive duplicates collapse.""" - a, b = define_vars("a", "b") + a, b = define_vars("a", "b", typ=backend.scalar_typ) lhs = monoid.plus(a(), a(), b()) - return _check_pair(lhs=lhs, rhs=monoid.plus(a(), b()), free_vars=[a, b]) + return check_rewrite( + lhs=lhs, + rhs=monoid.plus(a(), b()), + rule=PlusConsecutiveDups(), + backend=backend, + free_vars=[a, b], + ) @pytest.mark.parametrize("monoid", IDEMPOTENT) -def test_plus_idempotent_non_consecutive(monoid): +def test_plus_idempotent_non_consecutive(monoid, backend): """``a, b, a`` — Semilattice (Min/Max) collapses via commutative - PlusDups; plain IdempotentMonoid leaves it as-is (consecutive-only).""" - a, b = define_vars("a", "b") + PlusDups.""" + a, b = define_vars("a", "b", typ=backend.scalar_typ) lhs = monoid.plus(a(), b(), a()) - if is_commutative(monoid): - rhs = monoid.plus(a(), b()) - else: - rhs = monoid.plus(a(), b(), a()) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[a, b]) + rhs = monoid.plus(a(), b()) + check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDups(), backend=backend, free_vars=[a, b]) -def test_plus_commutative_idempotent_long(): +@pytest.mark.parametrize("monoid", [Min, Max]) +def test_plus_commutative_idempotent_long(monoid, backend): """Long alternation collapses via commutative dedup (Min/Max only).""" - a, b = define_vars("a", "b") - lhs = Min.plus(a(), b(), a(), b(), b(), a(), a()) - _check_pair(lhs=lhs, rhs=Min.plus(a(), b()), free_vars=[a, b]) + a, b = define_vars("a", "b", typ=backend.scalar_typ) + lhs = monoid.plus(a(), b(), a(), b(), b(), a(), a()) + rhs = monoid.plus(a(), b()) + check_rewrite(lhs=lhs, rhs=rhs, rule=PlusDups(), backend=backend, free_vars=[a, b]) @pytest.mark.parametrize("monoid", WITH_ZERO) -def test_plus_zero(monoid): - a = define_vars("a") +def test_plus_zero(monoid, backend): + a = define_vars("a", typ=backend.scalar_typ) lhs_right = monoid.plus(a(), monoid.zero) lhs_left = monoid.plus(monoid.zero, a()) - _check_pair(lhs=lhs_right, rhs=monoid.zero, free_vars=[a]) - _check_pair(lhs=lhs_left, rhs=monoid.zero, free_vars=[a]) + rhs = monoid.zero + check_rewrite( + lhs=lhs_right, rhs=rhs, rule=PlusZero(), backend=backend, free_vars=[a] + ) + check_rewrite( + lhs=lhs_left, rhs=rhs, rule=PlusZero(), backend=backend, free_vars=[a] + ) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_partial_1(monoid, backend): + x, y = define_vars("x", "y", typ=backend.scalar_typ) + lhs = monoid.reduce(x(), {x: []}) + rhs = monoid.identity + check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_partial_2(monoid, backend): + x, y = define_vars("x", "y", typ=backend.scalar_typ) + Y = define_vars("Y", typ=backend.stream_typ) + + lhs = monoid.reduce(x(), {y: Y(), x: []}) + rhs = monoid.identity + + check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y, Y]) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_body_sequence(monoid): - x = Operation.define(int, name="x") - X = Operation.define(list[int], name="X") +def test_partial_3(monoid, backend): + x, y, a, b = define_vars("x", "y", "a", "b", typ=backend.scalar_typ) + Y = define_vars("Y", typ=backend.stream_typ) + + lhs = monoid.reduce(x(), {y: Y(), x: [a(), b()]}) + rhs = monoid.plus(monoid.reduce(a(), {y: Y()}), monoid.reduce(b(), {y: Y()})) + + check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y, a, b, Y]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_partial_4(monoid, backend): + x, y, a, b = define_vars("x", "y", "a", "b", typ=backend.scalar_typ) + f = backend.fresh_op("f", n_args=1, ret="stream") + + lhs = monoid.reduce(x(), {y: f(x()), x: [a(), b()]}) + rhs = monoid.plus(monoid.reduce(a(), {y: f(a())}), monoid.reduce(b(), {y: f(b())})) - @Operation.define - def f(_x: int) -> int: - raise NotHandled + check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y, a, b, f]) + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_reduce_body_sequence(monoid, backend): + x = Operation.define(backend.scalar_typ, name="x") + X = Operation.define(backend.stream_typ, name="X") + f = backend.fresh_op("f", n_args=1, ret="scalar") g = Operation.define(f, name="g") lhs = monoid.reduce((f(x()), g(x())), {x: X()}) rhs = (monoid.reduce(f(x()), {x: X()}), monoid.reduce(g(x()), {x: X()})) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, f, g]) + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=MonoidOverSequence(), + backend=backend, + free_vars=[X, f, g], + ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_body_sequence_2(monoid): - x, y = define_vars("x", "y") - X, Y = define_vars("X", "Y", typ=list[int]) - - @Operation.define - def f(_x: int) -> int: - raise NotHandled - +def test_reduce_body_sequence_2(monoid, backend): + x, y = define_vars("x", "y", typ=backend.scalar_typ) + X, Y = define_vars("X", "Y", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=1, ret="scalar") g = Operation.define(f, name="g") lhs = monoid.reduce((f(x()), g(y())), {x: X(), y: Y()}) @@ -411,103 +422,115 @@ def f(_x: int) -> int: monoid.reduce(g(y()), {x: X(), y: Y()}), ) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, Y, f, g]) + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=MonoidOverSequence(), + backend=backend, + free_vars=[X, Y, f, g], + ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_body_mapping(monoid): - x = Operation.define(int, name="x") - X = Operation.define(list[int], name="X") - - @Operation.define - def f(_x: int) -> int: - raise NotHandled - +def test_reduce_body_mapping(monoid, backend): + x = Operation.define(backend.scalar_typ, name="x") + X = Operation.define(backend.stream_typ, name="X") + f = backend.fresh_op("f", n_args=1, ret="scalar") g = Operation.define(f, name="g") - lhs = monoid.reduce({"a": f(x()), "b": g(x())}, {x: X()}) + lhs = monoid.reduce({0: f(x()), 1: g(x())}, {x: X()}) rhs = { - "a": monoid.reduce(f(x()), {x: X()}), - "b": monoid.reduce(g(x()), {x: X()}), + 0: monoid.reduce(f(x()), {x: X()}), + 1: monoid.reduce(g(x()), {x: X()}), } - _check_pair(lhs=lhs, rhs=rhs, free_vars=[X, f, g]) + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=MonoidOverMapping(), + backend=backend, + free_vars=[X, f, g], + ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_no_streams(monoid): - a = define_vars("a") +def test_reduce_no_streams(monoid, backend): + a = define_vars("a", typ=backend.scalar_typ) lhs = monoid.reduce(a(), {}) rhs = monoid.identity - _check_pair(lhs=lhs, rhs=rhs, free_vars=[a]) + check_rewrite( + lhs=lhs, rhs=rhs, rule=ReduceNoStreams(), backend=backend, free_vars=[a] + ) @pytest.mark.parametrize("monoid", ALL_MONOIDS) -def test_reduce_reduce(monoid): - a, b = define_vars("a", "b") - A, B = define_vars("A", "B", typ=list[int]) - - @Operation.define - def f(_x: int, _y: int) -> int: - raise NotHandled +def test_reduce_reduce(monoid, backend): + a, b = define_vars("a", "b", typ=backend.scalar_typ) + A, B = define_vars("A", "B", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=2, ret="scalar") lhs = monoid.reduce(monoid.reduce(f(a(), b()), {a: A()}), {b: B()}) rhs = monoid.reduce(f(a(), b()), {a: A(), b: B()}) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B, f]) + check_rewrite( + lhs=lhs, rhs=rhs, rule=ReduceFusion(), backend=backend, free_vars=[A, B, f] + ) @pytest.mark.parametrize("monoid", COMMUTATIVE) -def test_reduce_plus(monoid): - a, b = define_vars("a", "b") - A, B = define_vars("A", "B", typ=list[int]) +def test_reduce_plus(monoid, backend): + a, b = define_vars("a", "b", typ=backend.scalar_typ) + A, B = define_vars("A", "B", typ=backend.stream_typ) lhs = monoid.reduce(monoid.plus(a(), b()), {a: A(), b: B()}) rhs = monoid.plus( monoid.reduce(a(), {a: A(), b: B()}), monoid.reduce(b(), {a: A(), b: B()}), ) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B]) + check_rewrite( + lhs=lhs, rhs=rhs, rule=ReduceSplit(), backend=backend, free_vars=[A, B] + ) -def test_reduce_independent_1(): - a, b = define_vars("a", "b") - A, B = define_vars("A", "B", typ=list[int]) +def test_reduce_independent_1(backend): + a, b = define_vars("a", "b", typ=backend.scalar_typ) + A, B = define_vars("A", "B", typ=backend.stream_typ) lhs = Sum.reduce(Product.plus(a(), b()), {a: A(), b: B()}) - rhs = Product.plus(Sum.reduce(a(), {a: A()}), Sum.reduce(b(), {b: B()})) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B]) - + rhs = Product.plus( + Sum.reduce(Product.plus(a()), {a: A()}), Sum.reduce(Product.plus(b()), {b: B()}) + ) + check_rewrite( + lhs=lhs, rhs=rhs, rule=ReduceFactorization(), backend=backend, free_vars=[A, B] + ) -def test_reduce_independent_2(): - a, b, c = define_vars("a", "b", "c") - A, B, C = define_vars("A", "B", "C", typ=list[int]) - @Operation.define - def f(_x: int, _y: int) -> int: - raise NotHandled +def test_reduce_independent_2(backend): + a, b, c = define_vars("a", "b", "c", typ=backend.scalar_typ) + A, B, C = define_vars("A", "B", "C", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=2, ret="scalar") lhs = Sum.reduce(Product.plus(a(), b(), f(b(), c())), {a: A(), b: B(), c: C()}) rhs = Product.plus( - Sum.reduce(a(), {a: A()}), + Sum.reduce(Product.plus(a()), {a: A()}), Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), ) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B, C, f]) + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=ReduceFactorization(), + backend=backend, + free_vars=[A, B, C, f], + ) -def test_reduce_independent_3_negative(): +def test_reduce_independent_3_negative(backend): """Stream `b` depends on `a` (b: g(a())), so the proposed factorization is unsound — the normalizer must NOT apply it.""" - a, b, c = define_vars("a", "b", "c") - A, C = define_vars("A", "C", typ=list[int]) - - @Operation.define - def f(_x: int, _y: int) -> int: - raise NotHandled + a, b, c = define_vars("a", "b", "c", typ=backend.scalar_typ) + A, C = define_vars("A", "C", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=2, ret="scalar") + g = backend.fresh_op("g", n_args=1, ret="stream") - @Operation.define - def g(_x: int) -> list[int]: - raise NotHandled - - with handler(NormalizeIntp): + with handler(ReduceFactorization()): # ty:ignore[invalid-argument-type] lhs = Sum.reduce( Product.plus(a(), b(), f(b(), c())), {a: A(), b: g(a()), c: C()} ) @@ -519,104 +542,107 @@ def g(_x: int) -> list[int]: assert not syntactic_eq_alpha(lhs, bogus_rhs) -def test_reduce_independent_4(): - a, b, c = define_vars("a", "b", "c") - A, B, C = define_vars("A", "B", "C", typ=list[int]) - - @Operation.define - def f(_x: int, _y: int) -> int: - raise NotHandled +def test_reduce_independent_4(backend): + a, b, c = define_vars("a", "b", "c", typ=backend.scalar_typ) + A, B, C = define_vars("A", "B", "C", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=2, ret="scalar") lhs = Sum.reduce(Product.plus(a(), b(), f(b(), c()), 7), {a: A(), b: B(), c: C()}) rhs = Product.plus( 7, - Sum.reduce(a(), {a: A()}), + Sum.reduce(Product.plus(a()), {a: A()}), Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}), ) - _check_pair(lhs=lhs, rhs=rhs, free_vars=[A, B, C, f]) + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=ReduceFactorization(), + backend=backend, + free_vars=[A, B, C, f], + ) @pytest.mark.parametrize("outer,inner", MONOID_PAIRS) -def test_reduce_lifted_1(outer, inner): - a, i = define_vars("a", "i") - A, N, A_domain = define_vars("A", "N", "A_domain", typ=list[int]) - - @Operation.define - def f(_: int) -> float: - raise NotHandled +def test_reduce_lifted_1(outer, inner, backend): + a, i = define_vars("a", "i", typ=backend.scalar_typ) + A, N, A_domain = define_vars("A", "N", "A_domain", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=1, ret="scalar") term1 = outer.reduce( inner.reduce(f(a()), {a: A()}), {A: CartesianProduct.reduce(A_domain(), {i: N()})}, ) - term2 = inner.reduce(outer.reduce(f(a()), {a: A_domain()}), {i: N()}) - _check_pair(lhs=term1, rhs=term2, free_vars=[N, A_domain, f]) + term2 = inner.reduce(outer.reduce(inner.plus(f(a())), {a: A_domain()}), {i: N()}) + + check_rewrite( + lhs=term1, + rhs=term2, + rule=ReduceDistributeCartesianProduct(), + backend=backend, + free_vars=[N, A_domain, f], + ) def test_reduce_cartesian_1(): - a, i = define_vars("a", "i") - A = define_vars("A", typ=list[int]) + a, i = define_vars("a", "i", typ=int) + A = define_vars("A", typ=tuple[int]) - term1 = Sum.reduce( - Product.reduce(a(), {a: []}), - {A: CartesianProduct.reduce([], {i: []})}, - ) - term2 = Product.reduce(Sum.reduce(a(), {a: []}), {i: []}) + with handler(NormalizeIntp): + term1 = Sum.reduce( + Product.reduce(a(), {a: []}), + {A: CartesianProduct.reduce([], {i: []})}, + ) + term2 = Product.reduce(Sum.reduce(a(), {a: []}), {i: []}) assert term1 == term2 def test_reduce_cartesian_2(): - a, i = define_vars("a", "i") - A = define_vars("A", typ=list[int]) + a, i = define_vars("a", "i", typ=int) + A = define_vars("A", typ=tuple[int]) - term1 = Sum.reduce( - Product.reduce(a(), {a: A()}), - {A: CartesianProduct.reduce([(0,)], {i: [0]})}, - ) - term2 = Product.reduce(Sum.reduce(a(), {a: [0]}), {i: [0]}) + with handler(NormalizeIntp): + term1 = Sum.reduce( + Product.reduce(a(), {a: A()}), + {A: CartesianProduct.reduce([(0,)], {i: [0]})}, + ) + term2 = Product.reduce(Sum.reduce(a(), {a: [0]}), {i: [0]}) assert term1 == term2 @pytest.mark.parametrize("outer,inner", MONOID_PAIRS) -def test_reduce_lifted_multi_index(outer, inner): - a, i, j = define_vars("a", "i", "j") - A, N, M, A_domain = define_vars("A", "N", "M", "A_domain", typ=list[int]) - - @Operation.define - def f(_: int) -> float: - raise NotHandled +def test_reduce_lifted_multi_index(outer, inner, backend): + a, i, j = define_vars("a", "i", "j", typ=backend.scalar_typ) + A, N, M, A_domain = define_vars("A", "N", "M", "A_domain", typ=backend.stream_typ) + f = backend.fresh_op("f", n_args=1, ret="scalar") term1 = outer.reduce( inner.reduce(f(a()), {a: A()}), {A: CartesianProduct.reduce(A_domain(), {i: N(), j: M()})}, ) term2 = inner.reduce( - outer.reduce(f(a()), {a: A_domain()}), + outer.reduce(inner.plus(f(a())), {a: A_domain()}), {i: N(), j: M()}, ) - _check_pair(lhs=term1, rhs=term2, free_vars=[N, M, A_domain, f]) + check_rewrite( + lhs=term1, + rhs=term2, + rule=ReduceDistributeCartesianProduct(), + backend=backend, + free_vars=[N, M, A_domain, f], + ) @pytest.mark.parametrize("outer,inner", MONOID_PAIRS) -def test_reduce_lifted_2(outer, inner): +def test_reduce_lifted_2(outer, inner, backend): """The worked example on page 396 of 'Lifted Variable Elimination: Decoupling the Operators from the Constraint Language'. """ - a, i, s, t = define_vars("a", "i", "s", "t") - A, N, T = define_vars("A", "N", "T", typ=list[int]) - - @Operation.define - def A_domain(_i: int) -> list[int]: - raise NotHandled - - @Operation.define - def f1(_a: int, _s: int) -> float: - raise NotHandled - - @Operation.define - def f2(_t: int, _a: int) -> float: - raise NotHandled + a, i, s, t = define_vars("a", "i", "s", "t", typ=backend.scalar_typ) + A, N, T = define_vars("A", "N", "T", typ=backend.stream_typ) + A_domain = backend.fresh_op("A_domain", n_args=1, ret="stream") + f1 = backend.fresh_op("f1", n_args=2, ret="scalar") + f2 = backend.fresh_op("f2", n_args=2, ret="scalar") term1 = outer.reduce( inner.reduce(inner.plus(f1(a(), s()), f2(t(), a())), {a: A()}), @@ -625,10 +651,18 @@ def f2(_t: int, _a: int) -> float: term2 = outer.reduce( inner.reduce( - outer.reduce(inner.plus(f1(a(), s()), f2(t(), a())), {a: A_domain(i())}), + outer.reduce( + inner.plus(inner.plus(f1(a(), s()), f2(t(), a()))), {a: A_domain(i())} + ), {i: N()}, ), {t: T()}, ) - _check_pair(lhs=term1, rhs=term2, free_vars=[a, i, s, t, A, N, T, A_domain, f1, f2]) + check_rewrite( + lhs=term1, + rhs=term2, + rule=ReduceDistributeCartesianProduct(), + backend=backend, + free_vars=[a, i, s, t, A, N, T, A_domain, f1, f2], + )