def _jax_args(args):
    """Decide whether ``args`` can be combined directly with jax primitives.

    Returns ``True`` iff all of the following hold:

    * ``args`` is non-empty;
    * the :func:`typeof` of at least one argument is a subclass of
      :class:`jax.Array`;
    * the :func:`typeof` of every argument is a subclass of
      ``jax.typing.ArrayLike``.

    Otherwise returns ``False``.
    """
    if not args:
        return False
    # Materialize the types once. The sequence is scanned twice below; with a
    # one-shot generator ``any`` would consume items up to the first jax array
    # and ``all`` would then only inspect whatever came after it, silently
    # skipping the array-like check for the leading arguments.
    typs = [typeof(a) for a in args]
    return any(issubclass(t, jax.Array) for t in typs) and all(
        issubclass(t, jax.typing.ArrayLike) for t in typs
    )
+ if not any(isinstance(a, jax.Array) for a in args): + return fwd() + result = None + for a in args: + if a is CartesianProduct.zero: + return CartesianProduct.zero + if a is CartesianProduct.identity: + continue + if not isinstance(a, jax.Array): + return fwd() + result = a if result is None else cartesian_prod(result, a) + return result if result is not None else CartesianProduct.identity + + +ARRAY_REDUCTORS = { + Sum: jnp.sum, + Product: jnp.prod, + Min: jnp.min, + Max: jnp.max, + LogSumExp: logsumexp, +} + + +class ArrayReduce(ObjectInterpretation): + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if monoid not in ARRAY_REDUCTORS or typeof(body) is not jax.Array: + return fwd() + if not streams: + return monoid.identity + + reductor = ARRAY_REDUCTORS[monoid] + index = Operation.define(jax.Array) + for stream_key, stream_body, streams_tail in outer_stream(streams): + if not issubclass(typeof(stream_body), jax.Array): + continue + + if stream_key in fvsof(body): + with handler({stream_key: deffn(unbind_dims(stream_body, index))}): + eval_body = evaluate(body) + eval_streams_tail = evaluate(streams_tail) + assert isinstance(eval_streams_tail, dict) + reduce_tail = ( + monoid.reduce(eval_body, eval_streams_tail) + if len(eval_streams_tail) > 0 + else eval_body + ) + return reductor(bind_dims(reduce_tail, index), axis=0) + else: + # TODO: In this case, the stream is unused in the body. The body + # should be multiplied by the length of the stream. The current + # behavior is not efficient. 
class DisjointSet:
    """Union-find structure over the integers ``0..n-1``.

    Every element starts in its own singleton set. :meth:`union` merges sets
    and :meth:`find` reports a set's representative. Lookups shorten the
    paths they walk, and merges attach the lower-ranked root beneath the
    higher-ranked one (union by rank).

    Example:
        >>> dsu = DisjointSet(5)
        >>> dsu.union(0, 1)
        True
        >>> dsu.union(1, 2)
        True
        >>> dsu.find(0) == dsu.find(2)
        True
        >>> dsu.find(0) == dsu.find(3)
        False
    """

    def __init__(self, n):
        """Create ``n`` singleton sets labelled ``0..n-1``.

        Args:
            n: Number of elements.
        """
        # parent[i] == i marks i as the root of its own tree.
        self.parent = [i for i in range(n)]
        self.rank = [0 for _ in self.parent]

    def _validate(self, x):
        """Raise ``IndexError`` unless ``x`` is a valid element label."""
        size = len(self.parent)
        if x >= size or x < 0:
            raise IndexError(f"Element {x} out of bounds")

    def find(self, x):
        """Return the representative (root) of the set containing ``x``.

        Two elements are in the same set exactly when their representatives
        are equal. While walking up, each visited node is re-pointed at its
        grandparent, which shortens the path for later lookups.

        Args:
            x: The element to look up.

        Returns:
            The root element of ``x``'s set.

        Raises:
            IndexError: If ``x`` is outside ``0..n-1``.
        """
        self._validate(x)
        links = self.parent
        node = x
        while links[node] != node:
            grandparent = links[links[node]]
            links[node] = grandparent  # skip one level on the way up
            node = grandparent
        return node

    def union(self, *elements):
        """Merge the sets containing all of ``elements`` into a single set.

        Every element after the first is merged with the first one.

        Args:
            *elements: Elements to merge. With fewer than two elements the
                call does nothing.

        Returns:
            ``True`` if at least one merge joined two previously distinct
            sets; ``False`` if everything was already together or fewer than
            two elements were given.
        """
        if len(elements) < 2:
            return False

        anchor, *others = elements
        # Build the full list first so every pair is processed even after a
        # successful merge (``any`` on a generator would stop early).
        outcomes = [self._union_pair(anchor, other) for other in others]
        return any(outcomes)

    def _union_pair(self, x, y):
        """Merge the sets of ``x`` and ``y``; return whether they differed."""
        root_x, root_y = self.find(x), self.find(y)
        if root_x == root_y:
            return False

        # The root with the larger rank stays on top; ties keep x's root.
        if self.rank[root_x] >= self.rank[root_y]:
            winner, loser = root_x, root_y
        else:
            winner, loser = root_y, root_x

        self.parent[loser] = winner
        if self.rank[winner] == self.rank[loser]:
            self.rank[winner] += 1

        return True
def outer_stream(
    streams: Streams,
) -> Iterable[tuple[Operation, Expr, dict[Operation, Expr]]]:
    """Enumerate the streams that may be placed outermost in the loop nest.

    A stream qualifies when ``fvsof`` of its body contains none of the other
    stream keys. For each such stream this yields a triple
    ``(key, body, rest)`` where ``rest`` maps every other stream key to its
    body.

    ``TopologicalSorter.prepare`` is called on the dependency graph, so a
    cyclic dependency between streams raises ``graphlib.CycleError``.
    """
    names = set(streams.keys())

    # Node = stream key, predecessors = the stream keys its body refers to.
    sorter: TopologicalSorter = TopologicalSorter()
    for key, body in streams.items():
        sorter.add(key, *(fvsof(body) & names))
    sorter.prepare()

    def _split(ready):
        # Lazily pair each ready stream with the streams that remain.
        for op in ready:
            body = streams[op]
            rest = {k: v for (k, v) in streams.items() if k != op}
            yield (op, body, rest)

    # Only the initially ready nodes (those with no predecessors) are used.
    return _split(sorter.get_ready())
+ """ + if not args: + return self.identity + raise NotHandled + + @Operation.define + def reduce[A, B, U: Body]( + self, + body: Annotated[U, Scoped[A | B]], + streams: Annotated[Streams, Scoped[A]], + ) -> Annotated[U, Scoped[B]]: + """Reduce ``body`` over ``streams``. Handlers supply per-monoid and + broadcasting behavior; the default rule only handles the empty-stream + case. + """ + for stream_key, stream_body, streams_tail in outer_stream(streams): + if isinstance(stream_body, Term): + continue + stream_values_iter = iter(stream_body) + if isinstance(stream_values_iter, Term) and stream_values_iter.op is iter_: + continue + new_reduces = [] + for stream_val in stream_values_iter: + with handler({stream_key: deffn(stream_val)}): + eval_args = evaluate((body, streams_tail)) + assert isinstance(eval_args, tuple) + new_reduces.append( + self.reduce(*eval_args) if streams_tail else eval_args[0] + ) + return self.plus(*new_reduces) + raise NotHandled + + +class MonoidWithZero[T](Monoid[T]): + zero: T + + def __init__(self, name: str, identity: T, zero: T): + super().__init__(name=name, identity=identity) + self.zero = zero + + +Min = Monoid(name="Min", identity=float("inf")) +Max = Monoid(name="Max", identity=-float("inf")) +ArgMin = Monoid(name="ArgMin", identity=(Min.identity, None)) +ArgMax = Monoid(name="ArgMax", identity=(Max.identity, None)) +Sum = Monoid(name="Sum", identity=0) +Product = MonoidWithZero(name="Product", identity=1, zero=0) +# CartesianProduct values are "two-level indexable" (rows × positions). The +# identity ``[()]`` is one row of zero positions (composing with it preserves +# shape); the zero ``[]`` is no rows (absorbs under product). 
+CartesianProduct = MonoidWithZero(name="CartesianProduct", identity=[()], zero=[]) + + +@dataclass +class _ExtensiblePredicate[T]: + elems: set[T] + + def register(self, t: T) -> None: + self.elems.add(t) + + def __call__(self, t: T) -> bool: + return t in self.elems + + +is_commutative = _ExtensiblePredicate({Max, Min, Sum, Product}) +is_idempotent = _ExtensiblePredicate({Max, Min}) + + +@dataclass +class _ExtensibleBinaryRelation[S, T]: + tuples: set[tuple[S, T]] + + def register(self, s: S, t: T) -> None: + self.tuples.add((s, t)) + + def __call__(self, s: S, t: T) -> bool: + return (s, t) in self.tuples + + +distributes_over = _ExtensibleBinaryRelation( + {(Max, Min), (Min, Max), (Sum, Min), (Sum, Max), (Product, Sum)} +) + + +def _is_monoid_plus(op: Operation) -> bool: + """True if ``op`` is the ``plus`` operation of some :class:`Monoid`.""" + owner = getattr(op, "__self__", None) + return isinstance(owner, Monoid) and op is owner.plus + + +def _is_monoid_reduce(op: Operation) -> bool: + """True if ``op`` is the ``reduce`` operation of some :class:`Monoid`.""" + owner = getattr(op, "__self__", None) + return isinstance(owner, Monoid) and op is owner.reduce + + +class PlusEmpty(ObjectInterpretation): + """plus() = 0""" + + @implements(Monoid.plus) + def plus(self, monoid, *args): + if not args: + return monoid.identity + return fwd() + + +class PlusSingle(ObjectInterpretation): + """plus(x) = x""" + + @implements(Monoid.plus) + def plus(self, _, *args): + if len(args) == 1: + return args[0] + return fwd() + + +class PlusIdentity(ObjectInterpretation): + """x₁ + ... + 0 + ... + xₙ = x₁ + ... 
class PlusDistr(ObjectInterpretation):
    """(a ⊕ b) ⊗ (c ⊕ d) = (a ⊗ c) ⊕ (a ⊗ d) ⊕ (b ⊗ c) ⊕ (b ⊗ d)

    Here ⊗ is the monoid whose ``plus`` is being handled and ⊕ is a monoid
    that ⊗ is registered to distribute over (and that does not distribute
    back over ⊗). The rewrite fires only when at least two arguments are
    ⊕-terms; all other arguments are kept and moved to the front.
    """

    @implements(Monoid.plus)
    def plus(self, monoid: Monoid, *args):
        def is_plus_term(x):
            return isinstance(x, Term) and _is_monoid_plus(x.op)

        candidate_present = any(
            is_plus_term(x) and distributes_over(monoid, x.op.__self__) for x in args
        )
        if not candidate_present:
            return fwd()

        # Split the arguments into ``plus`` terms, grouped by owning monoid
        # (in first-seen order), and everything else.
        others = []
        grouped: dict[Monoid, list[Term]] = {}
        for arg in args:
            if is_plus_term(arg):
                grouped.setdefault(arg.op.__self__, []).append(arg)
            else:
                others.append(arg)

        rewritten = []
        changed = False
        for inner, members in grouped.items():
            expandable = (
                len(members) > 1
                and distributes_over(monoid, inner)
                and not distributes_over(inner, monoid)
            )
            if not expandable:
                rewritten.extend(members)
                continue

            changed = True
            # One outer-monoid term per way of choosing a single operand
            # from each grouped inner term.
            choices = itertools.product(*(member.args for member in members))
            expanded = [monoid.plus(*choice) for choice in choices]
            rewritten.append(inner.plus(*expanded))

        if not changed:
            return fwd()
        return monoid.plus(*others, *rewritten)
class PlusDups(ObjectInterpretation):
    """x ⊕ y ⊕ x = x ⊕ y

    Drops repeated arguments, keeping the first occurrence of each, for
    monoids registered as both idempotent and commutative.
    """

    @dataclass
    class _HashableTerm:
        """Wrapper that makes an argument usable as a set member.

        Equality and hashing are delegated to ``syntactic_eq`` and
        ``syntactic_hash``, applied to the wrapper itself.
        NOTE(review): this relies on both functions handling dataclass
        instances field-by-field -- confirm for ``syntactic_eq``.
        """

        term: Term

        def __hash__(self):
            return syntactic_hash(self)

        def __eq__(self, other):
            return syntactic_eq(self, other)

    @implements(Monoid.plus)
    def plus(self, monoid, *args):
        if not is_idempotent(monoid) or not is_commutative(monoid):
            return fwd()

        # Single pass: remember what has been seen, keep first occurrences.
        seen = set()
        unique = []
        for arg in args:
            key = self._HashableTerm(arg)
            if key in seen:
                continue
            seen.add(key)
            unique.append(arg)

        if len(unique) == len(args):
            # Nothing was removed; forward the call unchanged.
            return fwd()
        return fwd(monoid, *unique)
+ reduce(R, S, bn) + """ + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if not is_commutative(monoid): + return fwd() + if isinstance(body, Term) and body.op is monoid.plus: + return monoid.plus(*(monoid.reduce(x, streams) for x in body.args)) + return fwd() + + +class ReduceFactorization(ObjectInterpretation): + """ + Implements factorization of independent terms. + For example, when having two independent distributions, + we can rewrite their marginalization as: + ∫p(x)⋅q(y)dxdy => ∫p(x)dx ⋅ ∫q(y)dy + + More specifically, in terms of reduces we are performing: + reduce(R, (S₁ × ... × Sₖ) , A₁ * ... * Aₖ) + => reduce(R, S₁, A₁) * ... * reduce(R, Sₖ, Aₖ) + where free(Aᵢ) ∩ free(Aⱼ) ∩ S = ∅ + and free(Aᵢ) ∩ S ⊆ Sᵢ + """ + + @implements(Monoid.reduce) + def reduce(self, monoid, body, streams): + if not is_commutative(monoid): + return fwd() + if ( + isinstance(body, Term) + and _is_monoid_plus(body.op) + and distributes_over(body.op.__self__, monoid) + ): + inner_monoid: Monoid = body.op.__self__ + stream_vars = set(streams.keys()) + factors = [(arg, fvsof(arg)) for arg in body.args] + stream_ids = {v: i for (i, v) in enumerate(stream_vars)} + ds = DisjointSet(len(streams)) + + # streams are in the same partition as their dependencies + for stream_var, stream_id in stream_ids.items(): + stream_body = streams[stream_var] + deps = sorted([stream_ids[v] for v in fvsof(stream_body) & stream_vars]) + ds.union(stream_id, *deps) + + # factors are in the same partition as their dependencies + for _, factor_fvs in factors: + factor_streams = sorted( + [stream_ids[v] for v in (factor_fvs & stream_vars)] + ) + ds.union(*factor_streams) + + placed_streams = set() + new_reduces = [] + for stream_key in streams: + if stream_key in placed_streams: + continue + + partition = ds.find(stream_ids[stream_key]) + partition_streams = { + k: v + for (k, v) in streams.items() + if ds.find(stream_ids[k]) == partition + } + partition_stream_keys = 
def inner_stream(
    streams: dict[Operation, Expr],
) -> Iterable[tuple[dict[Operation, Expr], Operation, Expr]]:
    """Return the streams that can be ordered innermost in the loop nest,
    each paired with the remaining streams in the nest.

    A stream can be innermost only if no other stream's body refers to it
    (as reported by ``fvsof``). For every such stream this yields
    ``(rest, key, body)``, where ``rest`` maps all other stream keys to their
    bodies. Candidates are produced in the order given by
    ``TopologicalSorter.get_ready``.

    Raises:
        graphlib.CycleError: If the streams depend on each other cyclically.
    """
    stream_vars = set(streams.keys())

    # Map each stream to the streams whose bodies mention it. Every stream
    # gets an entry up front, so a stream that nothing depends on is still a
    # node of the graph, while a stream that others depend on can never be
    # reported as innermost just because it has no dependencies of its own.
    dependents: dict[Operation, set[Operation]] = {k: set() for k in streams}
    for k, v in streams.items():
        for dep in fvsof(v) & stream_vars:
            dependents[dep].add(k)

    # TopologicalSorter treats the mapped values as predecessors, so the
    # initially ready nodes are exactly the streams without dependents.
    topo = TopologicalSorter(dependents)
    topo.prepare()
    return (
        ({k: v for (k, v) in streams.items() if k != op}, op, streams[op])
        for op in topo.get_ready()
    )
"Lifted first-order + probabilistic inference." IJCAI. 2005. + [2] Taghipour, Nima, et al. "Completeness results for lifted + variable elimination." AISTATS. 2013. + """ + + @implements(Monoid.reduce) + def reduce(self, sum_monoid: Monoid, sum_body, sum_streams): + if not (is_commutative(sum_monoid) and isinstance(sum_body, Term)): + return fwd() + + # body is a product or multiplication of products + if _is_monoid_plus(sum_body.op) and distributes_over( + sum_body.op.__self__, sum_monoid + ): + prod_reduces = sum_body.args + else: + prod_reduces = [sum_body] + + products: list[tuple[Monoid, Callable, Operation, Term]] = [] + for prod_reduce in prod_reduces: + if not ( + isinstance(prod_reduce, Term) and _is_monoid_reduce(prod_reduce.op) + ): + return fwd() + prod_monoid: Monoid = prod_reduce.op.__self__ + prod_body = prod_reduce.args[0] + prod_streams = typing.cast(Mapping, prod_reduce.args[1]) + if not ( + distributes_over(prod_monoid, sum_monoid) + and (len(products) == 0 or products[-1][0] == prod_monoid) + ): + return fwd() + + if len(prod_streams) > 1 or len(prod_streams) == 0: + return fwd() + (prod_op, prod_stream) = next(iter(prod_streams.items())) + products.append( + (prod_monoid, deffn(prod_body, prod_op), prod_op, prod_stream) + ) + + assert len(products) > 0 + + for outer_sum_streams, cprod_op, cprod_term in inner_stream(sum_streams): + if not ( + isinstance(cprod_term, Term) + and cprod_term.op is CartesianProduct.reduce + ): + continue + (cprod_body, cprod_streams) = cprod_term.args + + if not all( + prod_stream.op == cprod_op for (_, _, _, prod_stream) in products + ): + continue + + prod_op = Operation.define(products[0][2]) + prod_monoid = products[0][0] + inner_sum = sum_monoid.reduce( + prod_monoid.plus( + *(prod_body(prod_op()) for (_, prod_body, _, _) in products) + ), + {prod_op: cprod_body}, + ) + prod = prod_monoid.reduce(inner_sum, cprod_streams) + outer_sum = ( + sum_monoid.reduce(prod, outer_sum_streams) + if outer_sum_streams + else 
class MonoidOverMapping(ObjectInterpretation):
    """``monoid.reduce({k: v_k}, streams) = {k: monoid.reduce(v_k, streams)}``.

    ``plus`` is lifted key-wise in the same spirit: the values found under
    each key are combined with ``monoid.plus``.
    """

    @implements(Monoid.reduce)
    def reduce(self, monoid, body, streams):
        if isinstance(body, Mapping) and not isinstance(body, Term):
            return {key: monoid.reduce(value, streams) for key, value in body.items()}
        return fwd()

    @implements(Monoid.plus)
    def plus(self, monoid, *args):
        if not args:
            return fwd()
        head, tail = args[0], args[1:]
        if not isinstance(head, Mapping):
            return fwd()

        if isinstance(head, Interpretation):
            # All interpretations must define exactly the same keys.
            keys = head.keys()
            for b in tail:
                if not isinstance(b, Interpretation):
                    raise TypeError(f"Expected interpretation but got {b}")
                if not keys == b.keys():
                    raise ValueError(
                        f"Expected interpretation of {keys} but got {b.keys()}"
                    )
            combined = {}
            for k in keys:
                # Each entry is wrapped with ``handler(b)`` before combining
                # -- presumably so it runs under the interpretation it came
                # from; confirm against ``handler``'s decorator semantics.
                combined[k] = monoid.plus(*(handler(b)(b[k]) for b in args))
            return combined

        for b in tail:
            if not isinstance(b, Mapping):
                raise TypeError(f"Expected mapping but got {b}")

        # Plain mappings may have different key sets: gather the values per
        # key in first-seen order and combine whatever is present.
        collected: dict = {}
        for d in args:
            for k, v in d.items():
                collected.setdefault(k, []).append(v)
        return {k: monoid.plus(*vs) for (k, vs) in collected.items()}
class ArgMinPlus(ObjectInterpretation):
    """Scalar score implementation of :data:`ArgMin`.

    Arguments are tuples whose first item is the score; the tuple with the
    smallest score is returned. Anything else is forwarded.
    """

    @implements(ArgMin.plus)
    def plus(self, *args):
        only_tuples = bool(args) and all(isinstance(a, tuple) for a in args)
        if not only_tuples:
            return fwd()
        # Symbolic scores cannot be compared yet.
        if any(isinstance(a[0], Term) for a in args):
            return fwd()
        if any(not isinstance(a[0], int | float) for a in args):
            return fwd()
        return min(args, key=operator.itemgetter(0))
class MonoidOverSequence(ObjectInterpretation):
    """Lift monoids registered in ``is_scalar`` elementwise over tuples,
    lists and generators.

    Tuple and list inputs produce a result of the same type as the first
    argument (``plus``) or the body (``reduce``); generator inputs produce a
    lazy generator.
    """

    @implements(Monoid.plus)
    def plus(self, monoid, *args):
        if not is_scalar(monoid):
            return fwd()
        if not args:
            return fwd()
        head = args[0]
        if not isinstance(head, tuple | list | Generator):
            return fwd()

        def combine(columns):
            for vs in columns:
                yield monoid.plus(*vs)

        # strict=True: arguments of unequal length raise ValueError once the
        # result is consumed.
        result = combine(zip(*args, strict=True))
        if isinstance(head, tuple | list):
            return type(head)(result)
        return result

    @implements(Monoid.reduce)
    def reduce(self, monoid, body, streams):
        if not (is_scalar(monoid) and isinstance(body, tuple | list | Generator)):
            return fwd()

        def reduce_each(items):
            for x in items:
                yield monoid.reduce(x, streams)

        result = reduce_each(iter(body))
        if isinstance(body, tuple | list):
            return type(body)(result)
        return result
@_CustomSingleDispatchCallable
def syntactic_hash(__dispatch: Callable[[type], Callable[[Any], int]], x) -> int:
    """Structural hash intended to agree with :func:`syntactic_eq`.

    The contract is that ``syntactic_eq(x, y)`` implies
    ``syntactic_hash(x) == syntactic_hash(y)``.

    Dataclass instances are hashed from their type together with the
    syntactic hash of a ``{field name: value}`` mapping; every other value is
    handed to the implementation registered for ``type(x)``.

    :param x: A term.
    :returns: An integer hash.
    """
    # Dataclass *types* (as opposed to instances) go through normal dispatch.
    if not dataclasses.is_dataclass(x) or isinstance(x, type):
        return __dispatch(type(x))(x)

    field_values = {
        field.name: getattr(x, field.name) for field in dataclasses.fields(x)
    }
    return hash(("dataclass", type(x), syntactic_hash(field_values)))
+ acc = 0 + for k in x: + acc ^= hash((hash(k), syntactic_hash(x[k]))) + return hash(("mapping", acc)) + + +@syntactic_hash.register +def _(x: collections.abc.Sequence) -> int: + if ( + isinstance(x, tuple) + and hasattr(x, "_fields") + and all(hasattr(x, f) for f in x._fields) + ): + return hash( + ( + "namedtuple", + type(x), + tuple(syntactic_hash(getattr(x, f)) for f in x._fields), + ) + ) + else: + # Use the abstract Sequence tag (not type(x)) because syntactic_eq + # treats any two Sequences of equal length and elementwise-equal + # contents as equal — e.g. [1,2] and (1,2) compare equal. + return hash(("sequence", len(x), tuple(syntactic_hash(a) for a in x))) + + +@syntactic_hash.register(object) +@syntactic_hash.register(str | bytes) +def _(x: object) -> int: + return hash(x) + + class ObjectInterpretation[T, V](collections.abc.Mapping): """A helper superclass for defining an ``Interpretation`` of many :class:`~effectful.ops.types.Operation` instances with shared state or behavior. diff --git a/effectful/ops/types.py b/effectful/ops/types.py index 40c1f4af5..c68e0d46c 100644 --- a/effectful/ops/types.py +++ b/effectful/ops/types.py @@ -42,6 +42,59 @@ def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: return self.func(self.dispatch, *args, **kwargs) +class _CustomSingleDispatchMethod[**P, **Q, S, T]: + """Method analog of :class:`_CustomSingleDispatchCallable`. + + The wrapped function has signature ``(self, dispatch, *args, **kwargs)``, + where ``dispatch`` is :meth:`functools.singledispatch.dispatch`. As a + descriptor, it binds ``self`` on attribute access, so callers invoke it + as ``instance.method(*args, **kwargs)``. 
class _BoundCustomSingleDispatchMethod:
    """A :class:`_CustomSingleDispatchMethod` paired with one instance.

    Calling the object invokes the wrapped function with the stored instance
    and the method's ``dispatch`` prepended to the caller's arguments.
    ``dispatch`` and ``register`` are forwarded from the unbound method.
    """

    __slots__ = ("_method", "_instance")

    def __init__(self, method: _CustomSingleDispatchMethod, instance: Any):
        self._instance = instance
        self._method = method

    @property
    def register(self):
        """Registration hook shared with the unbound method."""
        return self._method.register

    @property
    def dispatch(self):
        """Dispatch lookup shared with the unbound method."""
        return self._method.dispatch

    def __call__(self, *args, **kwargs):
        unbound = self._method
        return unbound.func(self._instance, unbound.dispatch, *args, **kwargs)
_define_customsingledispatchmethod( + cls, default: _CustomSingleDispatchMethod, **kwargs + ): + @functools.wraps(default.func) + def _wrapper(obj, *args, **kwargs): + return default.__get__(obj)(*args, **kwargs) + + op = cls.define(_wrapper, **kwargs) + op.register = default.register # type: ignore[attr-defined] + op.dispatch = default.dispatch # type: ignore[attr-defined] + return op + @typing.final def __default_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> "Expr[V]": """The default rule is used when the operation is not handled. @@ -488,7 +564,10 @@ def _instance_op(instance, *args, **kwargs): else: return default_result - instance_op = self.define(types.MethodType(_instance_op, instance)) + name = ("" if owner is None else f"{owner.__name__}_") + self.__name__ + instance_op = self.define( + types.MethodType(_instance_op, instance), name=name + ) instance.__dict__[self._name_on_instance] = instance_op return instance_op elif instance is not None: diff --git a/pyproject.toml b/pyproject.toml index d565403f2..685aaf55f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ test = [ "pytest-cov", "pytest-xdist", "pytest-benchmark", + "hypothesis", "mypy", "ruff", "nbval", diff --git a/tests/_monoid_helpers.py b/tests/_monoid_helpers.py new file mode 100644 index 000000000..f15103e30 --- /dev/null +++ b/tests/_monoid_helpers.py @@ -0,0 +1,355 @@ +import itertools +from collections.abc import Callable, Mapping, Sequence +from dataclasses import dataclass +from typing import Any, get_args, get_origin + +import jax +from hypothesis import given, settings +from hypothesis import strategies as st + +import effectful.handlers.jax.numpy as _jnp +from effectful.internals.runtime import interpreter +from effectful.ops.monoid import NormalizeIntp +from effectful.ops.semantics import apply, evaluate, handler +from effectful.ops.syntax import _BaseTerm, defdata, deffn, syntactic_eq +from effectful.ops.types import NotHandled, Operation, Term + 
def _value_strategy_for(annotation: Any) -> st.SearchStrategy[Any]:
    """Strategy for the value a *0-arg* Operation should return.

    :param annotation: The return annotation of the operation.
    :returns: A hypothesis strategy producing values for that annotation.
    :raises NotImplementedError: If ``annotation`` is not one of ``int``,
        ``float``, ``list[int]``, ``jax.Array`` or ``list[jax.Array]``.
    """
    if annotation is int:
        return st.integers(min_value=-100, max_value=100)
    if annotation is float:
        # NaN is excluded (allow_nan=False).
        return st.floats(allow_nan=False)
    if get_origin(annotation) is list and get_args(annotation) == (int,):
        return st.lists(st.integers(min_value=-100, max_value=100), max_size=2)
    if annotation is jax.Array:
        return _jax_array_value_strategy()
    if get_origin(annotation) is list and get_args(annotation) == (jax.Array,):
        return st.lists(_jax_array_value_strategy(), max_size=2)
    # The message lists every annotation handled above, including ``float``.
    raise NotImplementedError(
        f"No value strategy for return annotation {annotation!r}; "
        "supported: int, float, list[int], jax.Array, list[jax.Array]"
    )
@st.composite
def random_interpretation(
    draw: st.DrawFn, free_vars: Sequence[Operation]
) -> Mapping[Operation, Callable[..., Any]]:
    """Draw an interpretation that binds every Operation in ``free_vars``.

    Each operation is mapped to a value/callable drawn from
    :func:`_strategy_for_op`, in the order the operations appear in
    ``free_vars``. The mapping is keyed by the Operation objects themselves.
    """
    return {op: draw(_strategy_for_op(op)) for op in free_vars}
def _canonicalize(expr, _canonical_op):
    """Rebuild ``expr`` with every bound variable replaced by a canonical one.

    ``expr`` is evaluated with ``apply`` interpreted as ``_apply_canonical``.
    For each applied operation that binds variables (per ``__fvs_rule__``),
    the bound variables are renamed to ``_canonical_op(index, var)``, where
    ``index`` comes from one counter shared by the whole traversal.

    :param expr: The expression to canonicalize.
    :param _canonical_op: Callable ``(index, op) -> Operation`` supplying the
        canonical replacement for the ``index``-th bound variable encountered.
    :returns: The result of evaluating ``expr`` under the canonicalizing
        interpretation.
    """
    # One counter for the whole expression, so indices are unique per binder.
    counter = itertools.count()

    def _substitute(arg, renaming):
        """Apply a bound-variable renaming using ``evaluate`` for traversal."""
        if not renaming:
            return arg
        # presumably each renamed Operation is interpreted as its canonical
        # counterpart while all other applications are rebuilt as plain terms
        # -- confirm against ``interpreter`` semantics.
        with interpreter({apply: _BaseTerm, **renaming}):
            return evaluate(arg)

    def _bound_var_order(args, kwargs, bound_set: set[Operation]) -> list[Operation]:
        """Return bound variables in deterministic encounter order."""
        # ``seen`` keeps first-encounter order; ``seen_set`` gives O(1)
        # membership tests for the same elements.
        seen: list[Operation] = []
        seen_set: set[Operation] = set()

        def _capture(op, *a, **kw):
            # Record the first application of each bound variable, then build
            # the term as usual.
            if op in bound_set and op not in seen_set:
                seen.append(op)
                seen_set.add(op)
            return defdata(op, *a, **kw)

        # ``evaluate`` walks Terms, lists, tuples, mappings, dataclasses,
        # etc. for free; the apply handler captures bound vars used as
        # ``x()`` anywhere in the body.
        with interpreter({apply: _capture}):
            evaluate((args, kwargs))

        # Binders bypass the apply handler. Pick them up with a small structural
        # walk that visits dict keys too.
        def _walk_bare(obj):
            if isinstance(obj, Operation):
                if obj in bound_set and obj not in seen_set:
                    seen.append(obj)
                    seen_set.add(obj)
            elif isinstance(obj, dict):
                for k, v in obj.items():
                    _walk_bare(k)
                    _walk_bare(v)
            elif isinstance(obj, list | set | frozenset | tuple):
                for v in obj:
                    _walk_bare(v)

        _walk_bare((args, kwargs))
        return seen

    def _apply_canonical(op, *args, **kwargs) -> Term:
        # ``bindings`` reports, per positional/keyword argument, which
        # variables ``op`` binds in that argument.
        bindings = op.__fvs_rule__(*args, **kwargs)
        all_bound: set[Operation] = set().union(
            *bindings.args, *bindings.kwargs.values()
        )
        if not all_bound:
            # Nothing is bound here: rebuild the term unchanged.
            return _BaseTerm(op, *args, **kwargs)

        order = _bound_var_order(args, kwargs, all_bound)
        canonical = {var: _canonical_op(next(counter), var) for var in order}
        # Every bound variable must have been encountered, otherwise it would
        # be left without a canonical replacement.
        assert all_bound <= set(order)

        # Rename in each argument only the variables bound in that argument.
        new_args = tuple(
            _substitute(
                arg, {v: canonical[v] for v in bindings.args[i] if v in canonical}
            )
            for i, arg in enumerate(args)
        )
        new_kwargs = {
            k: _substitute(
                v,
                {var: canonical[var] for var in bindings.kwargs[k] if var in canonical},
            )
            for k, v in kwargs.items()
        }

        # avoid the renaming from defdata
        return _BaseTerm(op, *new_args, **new_kwargs)

    with interpreter({apply: _apply_canonical}):
        return evaluate(expr)
def check_rewrite(
    lhs,
    rhs,
    rule,
    *,
    backend: Backend,
    free_vars: Sequence[Operation] = (),
    max_examples: int = 25,
    deadline=None,
) -> None:
    """Check a rewrite rule syntactically and semantically.

    First evaluates ``lhs`` under ``handler(rule)`` and asserts that the
    result is alpha-equivalent to ``rhs``. Then, for random interpretations of
    ``free_vars``, evaluates both sides under ``NormalizeIntp`` plus that
    interpretation and asserts that ``backend.eq`` accepts the two values.

    :param lhs: Expression the rule is applied to.
    :param rhs: Expected normal form / semantically equal expression.
    :param rule: Interpretation installed with ``handler`` for the rewrite.
    :param backend: Supplies the equality predicate for evaluated values.
    :param free_vars: Operations to bind in each random interpretation.
        Defaults to an immutable empty tuple rather than a shared list.
    :param max_examples: Number of hypothesis examples to run.
    :param deadline: Per-example hypothesis deadline (``None`` disables it).
    :raises AssertionError: If either the syntactic or semantic check fails.
    """
    with handler(rule):
        norm = evaluate(lhs)
    assert syntactic_eq_alpha(norm, rhs)

    @given(intp=random_interpretation(free_vars))
    @settings(max_examples=max_examples, deadline=deadline)
    def _check_semantics(intp):
        with handler(NormalizeIntp), handler(intp):
            lhs_val = evaluate(lhs)
            rhs_val = evaluate(rhs)
        assert backend.eq(lhs_val, rhs_val)

    _check_semantics()
@pytest.mark.parametrize("monoid,reductor", MONOIDS)
def test_reduce_array_3(monoid, reductor, backend: Backend):
    """Reduce where stream ``y`` is ``g(x())``, i.e. depends on the element
    bound from ``X``. The expected form applies ``g`` to ``X`` unbound along
    ``k1``, the same named dim used for ``x``."""
    (x, y, k1, k2) = define_vars("x", "y", "k1", "k2", typ=backend.scalar_typ)
    X = define_vars("X", typ=backend.stream_typ)

    f = backend.fresh_op("f", n_args=2, ret="scalar")
    g = backend.fresh_op("g", n_args=1, ret="stream")

    lhs = monoid.reduce(f(x(), y()), {x: X(), y: g(x())})

    # Build the expected term inside-out, in the same construction order as a
    # single nested expression would use.
    outer_elem = unbind_dims(X(), k1)
    inner_elem = unbind_dims(g(unbind_dims(X(), k1)), k2)
    body = f(outer_elem, inner_elem)
    inner_reduced = reductor(bind_dims(body, k2), axis=0)
    rhs = reductor(bind_dims(inner_reduced, k1), axis=0)

    check_rewrite(
        lhs=lhs,
        rhs=rhs,
        rule=ArrayReduce(),
        backend=backend,
        free_vars=[x, y, k1, k2, X, f, g],
    )
def test_randomized_unions():
    """Compare ``DisjointSet`` against a naive ground-truth partition.

    A fixed-seed generator keeps any failure reproducible. After every union
    the test checks both that each ground-truth group has a single root and
    that different groups have different roots.
    """
    n = 50
    dsu = DisjointSet(n)
    rng = random.Random(0)

    # Ground truth: a list of disjoint sets that always partitions range(n).
    groups = [{i} for i in range(n)]

    for _ in range(100):
        elems = rng.sample(range(n), rng.randint(2, 5))
        dsu.union(*elems)

        # merge ground-truth groups
        merged = set().union(*(g for g in groups if not g.isdisjoint(elems)))
        groups = [g for g in groups if g.isdisjoint(merged)]
        groups.append(merged)

        # verify structure matches ground truth
        group_roots = []
        for g in groups:
            roots = {dsu.find(x) for x in g}
            assert len(roots) == 1
            group_roots.append(roots.pop())
        # Separate groups must not have been merged by the DSU.
        assert len(set(group_roots)) == len(groups)
b/tests/test_ops_monoid.py @@ -0,0 +1,668 @@ +import typing + +import pytest +from hypothesis import HealthCheck, given, settings +from hypothesis import strategies as st + +import effectful.handlers.jax.monoid # noqa: F401 +from effectful.ops.monoid import ( + CartesianProduct, + Max, + Min, + Monoid, + MonoidOverMapping, + MonoidOverSequence, + NormalizeIntp, + PlusAssoc, + PlusConsecutiveDups, + PlusDistr, + PlusDups, + PlusEmpty, + PlusIdentity, + PlusSingle, + PlusZero, + Product, + ReduceDistributeCartesianProduct, + ReduceFactorization, + ReduceFusion, + ReduceNoStreams, + ReduceSplit, + Sum, + distributes_over, +) +from effectful.ops.semantics import fvsof, handler +from effectful.ops.types import Operation +from tests._monoid_helpers import ( + INT_BACKEND, + JAX_BACKEND, + Backend, + check_rewrite, + define_vars, + syntactic_eq_alpha, +) + + +@pytest.fixture(params=[INT_BACKEND, JAX_BACKEND], ids=["int", "jax"]) +def backend(request) -> Backend: + return request.param + + +ALL_MONOIDS = [ + pytest.param(Sum, id="Sum"), + pytest.param(Product, id="Product"), + pytest.param(Min, id="Min"), + pytest.param(Max, id="Max"), +] + +COMMUTATIVE = [ + pytest.param(Sum, id="Sum"), + pytest.param(Product, id="Product"), + pytest.param(Min, id="Min"), + pytest.param(Max, id="Max"), +] + +IDEMPOTENT = [ + pytest.param(Min, id="Min"), + pytest.param(Max, id="Max"), +] + +WITH_ZERO = [ + pytest.param(Product, id="Product"), +] + +# Pairs (outer, inner) such that inner distributes over outer — i.e. the lifting +# identity ``outer(inner(body, A), CartesianProduct...) == inner(outer(body, D), ...)`` +# is valid for that semiring pair. 
+MONOID_PAIRS = [ + pytest.param(o.values[0], i.values[0], id=f"{o.id}-{i.id}") + for o in ALL_MONOIDS + for i in ALL_MONOIDS + if distributes_over( + typing.cast(Monoid, i.values[0]), typing.cast(Monoid, o.values[0]) + ) +] + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +@given(data=st.data()) +@settings( + max_examples=50, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], +) +def test_associativity(monoid, backend, data): + a = data.draw(backend.scalar_strategy) + b = data.draw(backend.scalar_strategy) + c = data.draw(backend.scalar_strategy) + with handler(NormalizeIntp): + left = monoid.plus(monoid.plus(a, b), c) + right = monoid.plus(a, monoid.plus(b, c)) + assert backend.eq(left, right) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +@given(data=st.data()) +@settings( + max_examples=50, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], +) +def test_identity(monoid, backend, data): + a = data.draw(backend.scalar_strategy) + with handler(NormalizeIntp): + assert backend.eq(monoid.plus(monoid.identity, a), a) + assert backend.eq(monoid.plus(a, monoid.identity), a) + + +@pytest.mark.parametrize("monoid", COMMUTATIVE) +@given(data=st.data()) +@settings( + max_examples=50, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], +) +def test_commutativity(monoid, backend, data): + a = data.draw(backend.scalar_strategy) + b = data.draw(backend.scalar_strategy) + with handler(NormalizeIntp): + assert backend.eq(monoid.plus(a, b), monoid.plus(b, a)) + + +@pytest.mark.parametrize("monoid", IDEMPOTENT) +@given(data=st.data()) +@settings( + max_examples=50, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], +) +def test_idempotence(monoid, backend, data): + a = data.draw(backend.scalar_strategy) + with handler(NormalizeIntp): + assert backend.eq(monoid.plus(a, a), a) + + +@pytest.mark.parametrize("monoid", WITH_ZERO) +@given(data=st.data()) 
+@settings( + max_examples=50, + deadline=None, + suppress_health_check=[HealthCheck.function_scoped_fixture], +) +def test_zero_absorbs(monoid, backend, data): + a = data.draw(backend.scalar_strategy) + with handler(NormalizeIntp): + assert backend.eq(monoid.plus(monoid.zero, a), monoid.zero) + assert backend.eq(monoid.plus(a, monoid.zero), monoid.zero) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_empty(monoid, backend): + check_rewrite( + lhs=monoid.plus(), rhs=monoid.identity, rule=PlusEmpty(), backend=backend + ) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_single(monoid, backend): + x = define_vars("x", typ=backend.scalar_typ) + check_rewrite( + lhs=monoid.plus(x()), rhs=x(), rule=PlusSingle(), backend=backend, free_vars=[x] + ) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_identity_right(monoid, backend): + x = define_vars("x", typ=backend.scalar_typ) + + lhs = monoid.plus(x(), monoid.identity) + rhs = monoid.plus(x()) + + check_rewrite(lhs=lhs, rhs=rhs, rule=PlusIdentity(), backend=backend, free_vars=[x]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_identity_left(monoid, backend): + x = define_vars("x", typ=backend.scalar_typ) + + lhs = monoid.plus(monoid.identity, x()) + rhs = monoid.plus(x()) + + check_rewrite(lhs=lhs, rhs=rhs, rule=PlusIdentity(), backend=backend, free_vars=[x]) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_assoc_right(monoid, backend): + x, y, z = define_vars("x", "y", "z", typ=backend.scalar_typ) + check_rewrite( + lhs=monoid.plus(x(), monoid.plus(y(), z())), + rhs=monoid.plus(x(), y(), z()), + rule=PlusAssoc(), + backend=backend, + free_vars=[x, y, z], + ) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_assoc_left(monoid, backend): + x, y, z = define_vars("x", "y", "z", typ=backend.scalar_typ) + check_rewrite( + lhs=monoid.plus(monoid.plus(x(), y()), z()), + rhs=monoid.plus(x(), y(), z()), + 
rule=PlusAssoc(), + backend=backend, + free_vars=[x, y, z], + ) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_sequence(monoid, backend): + a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) + check_rewrite( + lhs=monoid.plus((a(), b()), (c(), d())), + rhs=(monoid.plus(a(), c()), monoid.plus(b(), d())), + rule=MonoidOverSequence(), + backend=backend, + free_vars=[a, b, c, d], + ) + + +@pytest.mark.parametrize("monoid", ALL_MONOIDS) +def test_plus_mapping(monoid, backend): + a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) + + lhs = monoid.plus({0: a(), 1: b()}, {0: c(), 2: d()}) + rhs = {0: monoid.plus(a(), c()), 1: monoid.plus(b()), 2: monoid.plus(d())} + + check_rewrite( + lhs=lhs, + rhs=rhs, + rule=MonoidOverMapping(), + backend=backend, + free_vars=[a, b, c, d], + ) + + +def test_plus_distributes(backend): + a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) + lhs = Product.plus(Sum.plus(a(), b()), Sum.plus(c(), d())) + rhs = Product.plus( + Sum.plus( + Product.plus(a(), c()), + Product.plus(a(), d()), + Product.plus(b(), c()), + Product.plus(b(), d()), + ) + ) + check_rewrite( + lhs=lhs, rhs=rhs, rule=PlusDistr(), backend=backend, free_vars=[a, b, c, d] + ) + + +def test_plus_distributes_constant(backend): + a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) + lhs = Product.plus(Sum.plus(a(), b()), Sum.plus(c(), d()), 5) + rhs = Product.plus( + 5, + Sum.plus( + Product.plus(a(), c()), + Product.plus(a(), d()), + Product.plus(b(), c()), + Product.plus(b(), d()), + ), + ) + check_rewrite( + lhs=lhs, rhs=rhs, rule=PlusDistr(), backend=backend, free_vars=[a, b, c, d] + ) + + +def test_plus_distributes_multiple(backend): + a, b, c, d = define_vars("a", "b", "c", "d", typ=backend.scalar_typ) + lhs = Sum.plus( + Min.plus(a(), b()), + Min.plus(c(), d()), + Max.plus(a(), b()), + Max.plus(c(), d()), + ) + rhs = Sum.plus( + Min.plus( + Sum.plus(a(), c()), + Sum.plus(a(), d()), 
@pytest.mark.parametrize("monoid", IDEMPOTENT)
def test_plus_idempotent_consecutive(monoid, backend):
    """``a, a, b → a, b`` — only consecutive duplicates collapse."""
    a, b = define_vars("a", "b", typ=backend.scalar_typ)
    lhs = monoid.plus(a(), a(), b())
    rhs = monoid.plus(a(), b())
    # No ``return``: test functions should return None, as the sibling tests do.
    check_rewrite(
        lhs=lhs,
        rhs=rhs,
        rule=PlusConsecutiveDups(),
        backend=backend,
        free_vars=[a, b],
    )
y = define_vars("x", "y", typ=backend.scalar_typ)
+    lhs = monoid.reduce(x(), {x: []})
+    rhs = monoid.identity
+    check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y])
+
+
+@pytest.mark.parametrize("monoid", ALL_MONOIDS)
+def test_partial_2(monoid, backend):
+    x, y = define_vars("x", "y", typ=backend.scalar_typ)
+    Y = define_vars("Y", typ=backend.stream_typ)
+
+    lhs = monoid.reduce(x(), {y: Y(), x: []})
+    rhs = monoid.identity
+
+    check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y, Y])
+
+
+@pytest.mark.parametrize("monoid", ALL_MONOIDS)
+def test_partial_3(monoid, backend):
+    x, y, a, b = define_vars("x", "y", "a", "b", typ=backend.scalar_typ)
+    Y = define_vars("Y", typ=backend.stream_typ)
+
+    lhs = monoid.reduce(x(), {y: Y(), x: [a(), b()]})
+    rhs = monoid.plus(monoid.reduce(a(), {y: Y()}), monoid.reduce(b(), {y: Y()}))
+
+    check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y, a, b, Y])
+
+
+@pytest.mark.parametrize("monoid", ALL_MONOIDS)
+def test_partial_4(monoid, backend):
+    x, y, a, b = define_vars("x", "y", "a", "b", typ=backend.scalar_typ)
+    f = backend.fresh_op("f", n_args=1, ret="stream")
+
+    lhs = monoid.reduce(x(), {y: f(x()), x: [a(), b()]})
+    rhs = monoid.plus(monoid.reduce(a(), {y: f(a())}), monoid.reduce(b(), {y: f(b())}))
+
+    check_rewrite(lhs=lhs, rhs=rhs, rule={}, backend=backend, free_vars=[x, y, a, b, f])
+
+
+@pytest.mark.parametrize("monoid", ALL_MONOIDS)
+def test_reduce_body_sequence(monoid, backend):
+    x = Operation.define(backend.scalar_typ, name="x")
+    X = Operation.define(backend.stream_typ, name="X")
+    f = backend.fresh_op("f", n_args=1, ret="scalar")
+    g = Operation.define(f, name="g")
+
+    lhs = monoid.reduce((f(x()), g(x())), {x: X()})
+    rhs = (monoid.reduce(f(x()), {x: X()}), monoid.reduce(g(x()), {x: X()}))
+
+    check_rewrite(
+        lhs=lhs,
+        rhs=rhs,
+        rule=MonoidOverSequence(),
+        backend=backend,
+        free_vars=[X, f, g],
+    )
+
+
+@pytest.mark.parametrize("monoid", ALL_MONOIDS)
+def test_reduce_body_sequence_2(monoid, backend):
+    x, y = define_vars("x", "y", typ=backend.scalar_typ)
+    X, Y = define_vars("X", "Y", typ=backend.stream_typ)
+    f = backend.fresh_op("f", n_args=1, ret="scalar")
+    g = Operation.define(f, name="g")
+
+    lhs = monoid.reduce((f(x()), g(y())), {x: X(), y: Y()})
+    rhs = (
+        monoid.reduce(f(x()), {x: X(), y: Y()}),
+        monoid.reduce(g(y()), {x: X(), y: Y()}),
+    )
+
+    check_rewrite(
+        lhs=lhs,
+        rhs=rhs,
+        rule=MonoidOverSequence(),
+        backend=backend,
+        free_vars=[X, Y, f, g],
+    )
+
+
+@pytest.mark.parametrize("monoid", ALL_MONOIDS)
+def test_reduce_body_mapping(monoid, backend):
+    x = Operation.define(backend.scalar_typ, name="x")
+    X = Operation.define(backend.stream_typ, name="X")
+    f = backend.fresh_op("f", n_args=1, ret="scalar")
+    g = Operation.define(f, name="g")
+
+    lhs = monoid.reduce({0: f(x()), 1: g(x())}, {x: X()})
+    rhs = {
+        0: monoid.reduce(f(x()), {x: X()}),
+        1: monoid.reduce(g(x()), {x: X()}),
+    }
+    check_rewrite(
+        lhs=lhs,
+        rhs=rhs,
+        rule=MonoidOverMapping(),
+        backend=backend,
+        free_vars=[X, f, g],
+    )
+
+
+@pytest.mark.parametrize("monoid", ALL_MONOIDS)
+def test_reduce_no_streams(monoid, backend):
+    a = define_vars("a", typ=backend.scalar_typ)
+    lhs = monoid.reduce(a(), {})
+    rhs = monoid.identity
+
+    check_rewrite(
+        lhs=lhs, rhs=rhs, rule=ReduceNoStreams(), backend=backend, free_vars=[a]
+    )
+
+
+@pytest.mark.parametrize("monoid", ALL_MONOIDS)
+def test_reduce_reduce(monoid, backend):
+    a, b = define_vars("a", "b", typ=backend.scalar_typ)
+    A, B = define_vars("A", "B", typ=backend.stream_typ)
+    f = backend.fresh_op("f", n_args=2, ret="scalar")
+
+    lhs = monoid.reduce(monoid.reduce(f(a(), b()), {a: A()}), {b: B()})
+    rhs = monoid.reduce(f(a(), b()), {a: A(), b: B()})
+
+    check_rewrite(
+        lhs=lhs, rhs=rhs, rule=ReduceFusion(), backend=backend, free_vars=[A, B, f]
+    )
+
+
+@pytest.mark.parametrize("monoid", COMMUTATIVE)
+def test_reduce_plus(monoid, backend):
+    a, b = define_vars("a", "b", typ=backend.scalar_typ)
+    A, B = define_vars("A", "B", typ=backend.stream_typ)
+    lhs = monoid.reduce(monoid.plus(a(), b()), {a: A(), b: B()})
+    rhs = monoid.plus(
+        monoid.reduce(a(), {a: A(), b: B()}),
+        monoid.reduce(b(), {a: A(), b: B()}),
+    )
+    check_rewrite(
+        lhs=lhs, rhs=rhs, rule=ReduceSplit(), backend=backend, free_vars=[A, B]
+    )
+
+
+def test_reduce_independent_1(backend):
+    a, b = define_vars("a", "b", typ=backend.scalar_typ)
+    A, B = define_vars("A", "B", typ=backend.stream_typ)
+    lhs = Sum.reduce(Product.plus(a(), b()), {a: A(), b: B()})
+    rhs = Product.plus(
+        Sum.reduce(Product.plus(a()), {a: A()}), Sum.reduce(Product.plus(b()), {b: B()})
+    )
+    check_rewrite(
+        lhs=lhs, rhs=rhs, rule=ReduceFactorization(), backend=backend, free_vars=[A, B]
+    )
+
+
+def test_reduce_independent_2(backend):
+    a, b, c = define_vars("a", "b", "c", typ=backend.scalar_typ)
+    A, B, C = define_vars("A", "B", "C", typ=backend.stream_typ)
+    f = backend.fresh_op("f", n_args=2, ret="scalar")
+
+    lhs = Sum.reduce(Product.plus(a(), b(), f(b(), c())), {a: A(), b: B(), c: C()})
+    rhs = Product.plus(
+        Sum.reduce(Product.plus(a()), {a: A()}),
+        Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}),
+    )
+    check_rewrite(
+        lhs=lhs,
+        rhs=rhs,
+        rule=ReduceFactorization(),
+        backend=backend,
+        free_vars=[A, B, C, f],
+    )
+
+
+def test_reduce_independent_3_negative(backend):
+    """Stream `b` depends on `a` (b: g(a())), so the proposed factorization
+    is unsound — the normalizer must NOT apply it."""
+    a, b, c = define_vars("a", "b", "c", typ=backend.scalar_typ)
+    A, C = define_vars("A", "C", typ=backend.stream_typ)
+    f = backend.fresh_op("f", n_args=2, ret="scalar")
+    g = backend.fresh_op("g", n_args=1, ret="stream")
+
+    with handler(ReduceFactorization()):  # ty:ignore[invalid-argument-type]
+        lhs = Sum.reduce(
+            Product.plus(a(), b(), f(b(), c())), {a: A(), b: g(a()), c: C()}
+        )
+        bogus_rhs = Product.plus(
+            Sum.reduce(a(), {a: A()}),
+            Sum.reduce(Product.plus(b(), f(b(), c())), {b: g(a()), c: C()}),
+        )
+        assert fvsof(bogus_rhs) != fvsof(lhs)
+        assert not syntactic_eq_alpha(lhs, bogus_rhs)
+
+
+def test_reduce_independent_4(backend):
+    a, b, c = define_vars("a", "b", "c", typ=backend.scalar_typ)
+    A, B, C = define_vars("A", "B", "C", typ=backend.stream_typ)
+    f = backend.fresh_op("f", n_args=2, ret="scalar")
+
+    lhs = Sum.reduce(Product.plus(a(), b(), f(b(), c()), 7), {a: A(), b: B(), c: C()})
+    rhs = Product.plus(
+        7,
+        Sum.reduce(Product.plus(a()), {a: A()}),
+        Sum.reduce(Product.plus(b(), f(b(), c())), {b: B(), c: C()}),
+    )
+    check_rewrite(
+        lhs=lhs,
+        rhs=rhs,
+        rule=ReduceFactorization(),
+        backend=backend,
+        free_vars=[A, B, C, f],
+    )
+
+
+@pytest.mark.parametrize("outer,inner", MONOID_PAIRS)
+def test_reduce_lifted_1(outer, inner, backend):
+    a, i = define_vars("a", "i", typ=backend.scalar_typ)
+    A, N, A_domain = define_vars("A", "N", "A_domain", typ=backend.stream_typ)
+    f = backend.fresh_op("f", n_args=1, ret="scalar")
+
+    term1 = outer.reduce(
+        inner.reduce(f(a()), {a: A()}),
+        {A: CartesianProduct.reduce(A_domain(), {i: N()})},
+    )
+    term2 = inner.reduce(outer.reduce(inner.plus(f(a())), {a: A_domain()}), {i: N()})
+
+    check_rewrite(
+        lhs=term1,
+        rhs=term2,
+        rule=ReduceDistributeCartesianProduct(),
+        backend=backend,
+        free_vars=[N, A_domain, f],
+    )
+
+
+def test_reduce_cartesian_1():
+    a, i = define_vars("a", "i", typ=int)
+    A = define_vars("A", typ=tuple[int])
+
+    with handler(NormalizeIntp):
+        term1 = Sum.reduce(
+            Product.reduce(a(), {a: []}),
+            {A: CartesianProduct.reduce([], {i: []})},
+        )
+        term2 = Product.reduce(Sum.reduce(a(), {a: []}), {i: []})
+        assert term1 == term2
+
+
+def test_reduce_cartesian_2():
+    a, i = define_vars("a", "i", typ=int)
+    A = define_vars("A", typ=tuple[int])
+
+    with handler(NormalizeIntp):
+        term1 = Sum.reduce(
+            Product.reduce(a(), {a: A()}),
+            {A: CartesianProduct.reduce([(0,)], {i: [0]})},
+        )
+        term2 = Product.reduce(Sum.reduce(a(), {a: [0]}), {i: [0]})
+        assert term1 == term2
+
+
+@pytest.mark.parametrize("outer,inner", MONOID_PAIRS)
+def test_reduce_lifted_multi_index(outer, inner, backend):
+    a, i, j = define_vars("a", "i", "j", typ=backend.scalar_typ)
+    A, N, M, A_domain = define_vars("A", "N", "M", "A_domain", typ=backend.stream_typ)
+    f = backend.fresh_op("f", n_args=1, ret="scalar")
+
+    term1 = outer.reduce(
+        inner.reduce(f(a()), {a: A()}),
+        {A: CartesianProduct.reduce(A_domain(), {i: N(), j: M()})},
+    )
+    term2 = inner.reduce(
+        outer.reduce(inner.plus(f(a())), {a: A_domain()}),
+        {i: N(), j: M()},
+    )
+    check_rewrite(
+        lhs=term1,
+        rhs=term2,
+        rule=ReduceDistributeCartesianProduct(),
+        backend=backend,
+        free_vars=[N, M, A_domain, f],
+    )
+
+
+@pytest.mark.parametrize("outer,inner", MONOID_PAIRS)
+def test_reduce_lifted_2(outer, inner, backend):
+    """The worked example on page 396 of 'Lifted Variable Elimination:
+    Decoupling the Operators from the Constraint Language'.
+
+    """
+    a, i, s, t = define_vars("a", "i", "s", "t", typ=backend.scalar_typ)
+    A, N, T = define_vars("A", "N", "T", typ=backend.stream_typ)
+    A_domain = backend.fresh_op("A_domain", n_args=1, ret="stream")
+    f1 = backend.fresh_op("f1", n_args=2, ret="scalar")
+    f2 = backend.fresh_op("f2", n_args=2, ret="scalar")
+
+    term1 = outer.reduce(
+        inner.reduce(inner.plus(f1(a(), s()), f2(t(), a())), {a: A()}),
+        {A: CartesianProduct.reduce(A_domain(i()), {i: N()}), t: T()},
+    )
+
+    term2 = outer.reduce(
+        inner.reduce(
+            outer.reduce(
+                inner.plus(inner.plus(f1(a(), s()), f2(t(), a()))), {a: A_domain(i())}
+            ),
+            {i: N()},
+        ),
+        {t: T()},
+    )
+
+    check_rewrite(
+        lhs=term1,
+        rhs=term2,
+        rule=ReduceDistributeCartesianProduct(),
+        backend=backend,
+        free_vars=[a, i, s, t, A, N, T, A_domain, f1, f2],
+    )
diff --git a/tests/test_ops_syntax.py b/tests/test_ops_syntax.py
index 185b6132e..1f5c47763 100644
--- a/tests/test_ops_syntax.py
+++ b/tests/test_ops_syntax.py
@@ -489,7 +489,6 @@ def _(self, x: bool) -> bool:
 ) assert isinstance(term_float, Term) - assert
term_float.op.__name__ == "my_singledispatch" assert term_float.args == (1.5,) assert term_float.kwargs == {}