From 7da9ac94c85450e2f543b1751740fa20b3777920 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 18 Jun 2026 16:22:23 +0200 Subject: [PATCH 1/8] Add dedup_nodes option to FunctionGraph.freeze, defaulting to no-dedup FrozenApply interns structurally-identical nodes globally, which is destroy-unaware: it folds two distinct-but-equal inplace buffers (that MergeOptimizer deliberately kept apart) into one, aliasing them. freeze() therefore defaults to dedup_nodes=False: FrozenFunctionGraph folds each node's toposort position into the FrozenApply intern key, so structurally-identical nodes within a graph stay distinct (buffers don't alias) while two structurally-identical graphs still compare equal (positions line up), preserving cross-graph equality / op-merge / funcify-cache reuse. dedup_nodes=True is opt-in for inner graphs known to be inplace-free (Composite, ScalarLoop) and for from_structural_inputs, whose structural matching needs the global interning. --- pytensor/graph/basic.py | 33 ++++++++++++++++++++++++----- pytensor/graph/fg.py | 41 +++++++++++++++++++++++++++++-------- pytensor/graph/traversal.py | 3 +-- pytensor/scalar/basic.py | 4 +++- pytensor/scalar/loop.py | 4 +++- 5 files changed, 68 insertions(+), 17 deletions(-) diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 92d7881234..cc05fb4f82 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -837,24 +837,38 @@ class FrozenApply(Apply): """ _cache: weakref.WeakValueDictionary = weakref.WeakValueDictionary() + _unique_idx: int | None def __new__( cls, op: "Op", inputs: tuple[Variable, ...], output_types: tuple["Type", ...], + unique_idx: int | None = None, ): - cache_key = ( + cache_key: tuple | None = ( op, tuple(i.signature() if isinstance(i, Constant) else id(i) for i in inputs), output_types, + unique_idx, ) - cached = cls._cache.get(cache_key) + try: + cached = cls._cache.get(cache_key) + except (AttributeError, TypeError): + # Hashing ``cache_key`` hashes 
``op``. During a cyclic unpickle (e.g. + # numba's on-disk function cache reloading a closure that holds both + # ``op`` and this node) ``op`` can still be a half-built stub whose + # ``__props__`` aren't restored yet, so it isn't hashable. Skip + # interning for this node -- the unpickled graph is self-contained and + # internally deduped by pickle's memo. + cache_key = None + cached = None if cached is not None: return cached instance = object.__new__(cls) instance.op = op + instance._unique_idx = unique_idx instance.inputs = inputs # type: ignore[assignment] instance.outputs = tuple( # type: ignore[assignment] t.variable_type(type=t, owner=instance, index=i) @@ -865,10 +879,11 @@ def __new__( for out in instance.outputs: out.__reduce_ex__ = _make_frozen_output_reduce(out) # type: ignore[method-assign] instance.tag = Scratchpad() - cls._cache[cache_key] = instance + if cache_key is not None: + cls._cache[cache_key] = instance return instance - def __init__(self, op, inputs, output_types): + def __init__(self, op, inputs, output_types, unique_idx=None): # All initialization is done in __new__ pass @@ -877,7 +892,15 @@ def clone(self, clone_inner_graph: bool = False) -> Self: return self def __reduce__(self): - return (type(self), (self.op, self.inputs, tuple(o.type for o in self.outputs))) + return ( + type(self), + ( + self.op, + self.inputs, + tuple(o.type for o in self.outputs), + self._unique_idx, + ), + ) def clone( diff --git a/pytensor/graph/fg.py b/pytensor/graph/fg.py index b43c3cdd8b..8314bb5b5e 100644 --- a/pytensor/graph/fg.py +++ b/pytensor/graph/fg.py @@ -936,9 +936,19 @@ def dprint(self, **kwargs): return debugprint(self, **kwargs) - def freeze(self) -> "FrozenFunctionGraph": - """Return a frozen, hashable version of this FunctionGraph.""" - return FrozenFunctionGraph(self.inputs, self.outputs) + def freeze(self, dedup_nodes: bool = False) -> "FrozenFunctionGraph": + """Return a frozen, hashable version of this FunctionGraph. 
+ + By default (``dedup_nodes=False``) structurally-identical nodes are kept + as distinct interned nodes (keyed by toposort position) instead of being + folded into one. This preserves distinct buffers that downstream + inplace/``destroy_map`` logic relies on (which a blind structural fold + would alias), while still keeping two structurally-identical *graphs* + equal because positions line up. Pass ``dedup_nodes=True`` only for inner + graphs known to be free of inplace ops (e.g. ``Composite``/``ScalarLoop``) + to additionally fold structurally-identical nodes within the graph. + """ + return FrozenFunctionGraph(self.inputs, self.outputs, dedup_nodes=dedup_nodes) class FrozenFunctionGraph(AbstractFunctionGraph): @@ -969,7 +979,9 @@ def __init__( self, inputs: Sequence[Variable], outputs: Sequence[Variable], + dedup_nodes: bool = False, ): + self._dedup_nodes = dedup_nodes nominal_inputs = tuple( NominalVariable(i, inp.type) for i, inp in enumerate(inputs) ) @@ -990,10 +1002,15 @@ def _resolve_input(inp, memo=memo): "or produced by Apply nodes reachable from the inputs." 
) - for node in toposort(outputs, blockers=inputs): + for node_idx, node in enumerate(toposort(outputs, blockers=inputs)): new_inputs = tuple(_resolve_input(inp) for inp in node.inputs) output_types = tuple(out.type for out in node.outputs) - new_node = FrozenApply(node.op, new_inputs, output_types) + new_node = FrozenApply( + node.op, + new_inputs, + output_types, + unique_idx=None if dedup_nodes else node_idx, + ) sorted_apply_nodes.append(new_node) memo.update(zip(node.outputs, new_node.outputs, strict=True)) @@ -1058,12 +1075,20 @@ def from_structural_inputs( roots = [ v for v in graph_inputs([*inputs, *outputs]) if not isinstance(v, Constant) ] - interned = cls(roots, [*inputs, *outputs]) + # Structural matching requires deduplication: each intermediate input + # expression must intern onto the same node as its occurrences in the + # outputs so they can be rewired (position-keyed interning would keep + # them distinct). + interned = cls(roots, [*inputs, *outputs], dedup_nodes=True) n_inputs = len(inputs) - return cls(interned.outputs[:n_inputs], interned.outputs[n_inputs:]) + return cls( + interned.outputs[:n_inputs], + interned.outputs[n_inputs:], + dedup_nodes=True, + ) def __reduce__(self): - return FrozenFunctionGraph, (self.inputs, self.outputs) + return FrozenFunctionGraph, (self.inputs, self.outputs, self._dedup_nodes) def __hash__(self): return hash(self._output_nodes) diff --git a/pytensor/graph/traversal.py b/pytensor/graph/traversal.py index 50b3359ff8..2b92fdd6ca 100644 --- a/pytensor/graph/traversal.py +++ b/pytensor/graph/traversal.py @@ -705,8 +705,7 @@ def compute_deps(obj, orderings=orderings): # type: ignore[misc] yield from ( apply for apply in walk_toposort(graphs, deps=compute_deps) - # mypy doesn't understand that our generator will return both Apply and Variables - if isinstance(apply, Apply) # type: ignore + if isinstance(apply, Apply) ) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 196124d72e..33581346f2 100644 
--- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -4238,7 +4238,9 @@ def __init__( for i in inputs: assert i not in outputs # This isn't supported, use identity - self.fgraph = FrozenFunctionGraph(inputs, outputs) + # Composite inner graphs have no inplace ops, so structurally-identical + # nodes can be safely deduplicated. + self.fgraph = FrozenFunctionGraph(inputs, outputs, dedup_nodes=True) self._validate_inner_graph(self.fgraph) self.inputs = self.fgraph.inputs diff --git a/pytensor/scalar/loop.py b/pytensor/scalar/loop.py index 1a4b30a008..86e0958ef4 100644 --- a/pytensor/scalar/loop.py +++ b/pytensor/scalar/loop.py @@ -61,7 +61,9 @@ def __init__( self.is_while = until is not None - self.fgraph = FrozenFunctionGraph(inputs, outputs) + # ScalarLoop inner graphs have no inplace ops, so structurally-identical + # nodes can be safely deduplicated. + self.fgraph = FrozenFunctionGraph(inputs, outputs, dedup_nodes=True) self._validate_inner_graph(self.fgraph) self.inputs = self.fgraph.inputs self.outputs = self.fgraph.outputs From 2912186a96ba42093ad6d832ab14bd52ab9ec2b6 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 18 Jun 2026 16:22:05 +0200 Subject: [PATCH 2/8] Stop FusedElemwise passing deprecated accept_inplace compile_kwarg OpFromGraph now deprecates and ignores compile_kwargs (the inner graph inherits the outer compilation). FusedElemwise still passed accept_inplace=True, so it triggered that FutureWarning itself whenever elemwise fusion ran, which is an error under filterwarnings=error. The kwarg was already a no-op, so drop it. 
--- pytensor/tensor/rewriting/indexed_elemwise.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/indexed_elemwise.py b/pytensor/tensor/rewriting/indexed_elemwise.py index b3ff3828b3..38d816d5cf 100644 --- a/pytensor/tensor/rewriting/indexed_elemwise.py +++ b/pytensor/tensor/rewriting/indexed_elemwise.py @@ -261,7 +261,7 @@ def __init__(self, *args, indexed_inputs=(), indexed_outputs=(), **kwargs): # safe because reads don't destroy. Write targets always get their own # fresh inner input (see FuseIndexedElemwise) so a destroyed buffer is # never deduped onto a read source. - super().__init__(*args, on_unused_input="ignore", accept_inplace=True, **kwargs) + super().__init__(*args, on_unused_input="ignore", **kwargs) def __str__(self): for node in self.fgraph.apply_nodes: From 42f8def437ead25b9ae610211a5e6782a5cb2bc0 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 16 Jun 2026 19:33:13 +0200 Subject: [PATCH 3/8] Optimize Scan/OpFromGraph inner graphs via a rewrite Move inner-graph optimization of Scan/OpFromGraph out of the linker dispatch and make_thunk (where it mutated the canonical, shared inner FunctionGraph in place) into a single `optimize_inner_graphs` rewrite that runs near the end of optimization and produces NEW immutable ops. This fixes the bug where compiling a graph (e.g. a scan-based pymc CustomDist under numba) corrupted the canonical inner graph and broke later logp derivation. Key pieces: - `GraphRewriter.rewrite` advertises the active rewriter on the fgraph and `FunctionMaker` attaches the compile mode; `get_active_rewriter`/ `get_active_mode` let the rewrite reuse the exact outer compilation (mode, fast_run/fast_compile, linker required/incompatible rewrites). `config.mode` is unreliable across nested compilations. - `optimize_inner_graphs` (per-op, recursive): unfreezes each inner graph, optimizes it with the transferred outer query, and rebuilds via `clone_with_inner_graph` (OpFromGraph/Scan). 
Skips `*WithCoreShape` and ScipyWrapperOp (MinimizeOp/RootOp still keep a mutable inner graph; they are made frozen and folded in a later commit). - The inner optimizer query excludes `symbolic_op_recognition`: those rewrites fold a pattern into an inner-graph op (e.g. `exp(x) / sum(exp(x))` -> `Softmax`, itself an OpFromGraph whose inner graph is that very pattern), so running them on the inner graph would re-create the op and recurse without end (RecursionError compiling any softmax under FAST_RUN or NUMBA). The old link-time inner-graph optimization excluded it too. - Registered at optdb position 49.6 (fast_run/fast_compile) and required for the numba/jax/pytorch/mlx linkers. - All linker dispatches + Scan.fn now operate on a throwaway `unfreeze()` copy, never the canonical op. - Deprecations: Scan `mode` (still honored/reconciled with the active linker) and OpFromGraph `compile_kwargs` (now ignored; `on_unused_input` subsumed as the default, governed by `strict`). Inner-graph inplace is baked into the frozen op, so dispatch is rewrite-free: - optimize_inner_graphs groups apply NODES (OpFromGraph by op; Scan by destroyable-tap signature, since tap-inplace depends on outer buffer shapes), attaches a Supervisor (OFG protects all inputs; Scan protects all but destroyable taps), runs the optimizer including inplace, and bakes the result into each new frozen op (frozen dedup_nodes=False so distinct inplace buffers survive; see gh #2194). - Scan.inner_destroyable_inputs factors out the destroyable-tap logic shared by the rewrite and numba_funcify_Scan. - OFG/Scan inner graphs freeze with dedup_nodes=False; OpFromGraph.fn compiles with accept_inplace=True (inner graphs may carry inplace that only destroys internal buffers; inputs stay protected). - numba_funcify_Scan/OpFromGraph drop the redundant optimizer pass and just funcify the baked graph (supervisor/insert_deepcopy stay as link-time aliasing setup). 
Regression test: compiling must not mutate the canonical inner graph. --- pytensor/compile/builders.py | 59 +++++- pytensor/compile/inner_graph_rewriting.py | 184 ++++++++++++++++++ pytensor/compile/maker.py | 9 +- pytensor/compile/mode.py | 16 ++ pytensor/graph/rewriting/basic.py | 48 ++++- pytensor/link/jax/dispatch/basic.py | 3 - pytensor/link/jax/dispatch/scan.py | 9 - pytensor/link/jax/dispatch/tensor_basic.py | 29 ++- pytensor/link/jax/linker.py | 1 + pytensor/link/mlx/dispatch/basic.py | 2 - pytensor/link/mlx/linker.py | 4 + pytensor/link/numba/dispatch/compile_ops.py | 13 +- pytensor/link/numba/dispatch/scan.py | 61 +----- pytensor/link/numba/linker.py | 1 + pytensor/link/pytorch/dispatch/basic.py | 12 -- pytensor/link/pytorch/linker.py | 4 + pytensor/scan/basic.py | 9 + pytensor/scan/op.py | 203 ++++++++++++++++++-- pytensor/scan/scan_perform.pyx | 14 +- tests/compile/test_builders.py | 21 +- tests/link/numba/test_compile_ops.py | 34 +++- 21 files changed, 609 insertions(+), 127 deletions(-) create mode 100644 pytensor/compile/inner_graph_rewriting.py diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py index 123f64bf0d..fa6b8e8c83 100644 --- a/pytensor/compile/builders.py +++ b/pytensor/compile/builders.py @@ -341,7 +341,10 @@ def __init__( self.fgraph, self.shared_inputs, _, _ = construct_nominal_fgraph( inputs, outputs ) - self._frozen_fgraph = self.fgraph.freeze() + # ``dedup_nodes=False``: inner graphs may carry inplace ops whose + # destroyed buffers must stay distinct; structural folding would alias + # them. See ``FunctionGraph.freeze``. + self._frozen_fgraph = self.fgraph.freeze(dedup_nodes=False) if strict and self.shared_inputs: raise ValueError( @@ -349,6 +352,20 @@ def __init__( f"The inner-graph implicitly depends on the following shared variables {self.shared_inputs}" ) + # `compile_kwargs` used to control how the inner graph was compiled. 
+ # That is now the job of the `optimize_inner_graphs` rewrite (which + # inherits the outer compilation), so they are deprecated AND ignored: + # the inner function is compiled with default settings (see `fn`). + # `on_unused_input` is exempt: tolerating unused inputs is now the + # default behavior, so passing it is a harmless no-op (not warned). + deprecated_kwargs = {k for k in kwargs if k != "on_unused_input"} + if deprecated_kwargs: + warnings.warn( + "Passing `compile_kwargs` to `OpFromGraph` is deprecated and " + "now ignored: the inner graph inherits the outer compilation. " + f"Ignored: {sorted(deprecated_kwargs)}.", + FutureWarning, + ) self.kwargs = kwargs self.input_types = [inp.type for inp in inputs] self.output_types = [out.type for out in outputs] @@ -919,9 +936,20 @@ def fn(self): if getattr(self, "_fn", None) is not None: return self._fn - kwargs = self.kwargs.copy() - mode = get_mode(kwargs.pop("mode", None)).excluding("symbolic_op_recognition") - self._fn = function(self.inner_inputs, self.inner_outputs, mode=mode, **kwargs) + # `compile_kwargs` are deprecated and ignored; compile the (already + # optimized) inner graph with default settings. Inner graphs commonly + # have unused inputs (e.g. rng, size), which are tolerated by default. + # They may also carry inplace ops baked in by ``optimize_inner_graphs`` + # (or built inplace, e.g. ``FusedElemwise``); those only ever destroy + # internal buffers (inputs stay protected), so we accept them. + mode = get_mode(None).excluding("symbolic_op_recognition") + self._fn = function( + self.inner_inputs, + self.inner_outputs, + mode=mode, + on_unused_input="ignore", + accept_inplace=True, + ) self._fn.trust_input = True return self._fn @@ -939,6 +967,29 @@ def clone(self): res.fgraph = res.fgraph.clone(clone_inner_graphs=True) return res + def clone_with_inner_graph(self, inner_fgraph: FunctionGraph) -> OpFromGraph: + """Return a copy of this op whose inner graph is ``inner_fgraph``. 
+ + Used by the ``optimize_inner_graphs`` rewrite to build a new op + carrying an already-optimized inner graph without ever mutating + ``self``. The subclass and all properties/overrides are preserved + (via ``copy``); only the inner graph and the state derived from it are + rebuilt. + """ + new = copy(self) + new.fgraph, new.shared_inputs, _, _ = construct_nominal_fgraph( + list(inner_fgraph.inputs), list(inner_fgraph.outputs) + ) + new._frozen_fgraph = new.fgraph.freeze(dedup_nodes=False) + new.input_types = [inp.type for inp in new.fgraph.inputs] + new.output_types = [out.type for out in new.fgraph.outputs] + # Drop caches tied to the previous inner graph. + new._lop_op_cache = {} + new._rop_op_cache = None + new._frozen_lop = None + new._frozen_rop = None + return new + def perform(self, node, inputs, outputs): variables = self.fn(*inputs) # zip strict not specified because we are in a hot loop diff --git a/pytensor/compile/inner_graph_rewriting.py b/pytensor/compile/inner_graph_rewriting.py new file mode 100644 index 0000000000..7d7f65bcc7 --- /dev/null +++ b/pytensor/compile/inner_graph_rewriting.py @@ -0,0 +1,184 @@ +"""Rewrite that optimizes the inner graphs of inner-graph ops once. + +Historically each linker/``make_thunk`` optimized the inner ``FunctionGraph`` +of ``Scan``/``OpFromGraph``/``ScipyWrapperOp`` lazily and *in place* on the +canonical (shared) op. That mutated graphs that are supposed to be immutable +and leaked backend-specific rewrites onto ops reused elsewhere. + +Instead, :class:`OptimizeInnerGraphs` runs once near the end of outer +optimization. For each unique inner-graph op it unfreezes the canonical inner +graph (a throwaway copy), optimizes that copy with the *same* outer query +(transferred via :func:`get_active_rewriter`), and builds a NEW immutable op +via ``clone_with_inner_graph`` -- the canonical op is never touched. Linkers +then simply funcify the already-optimized inner graph. 
+""" + +# Most imports are deferred into ``apply`` to avoid an import cycle: this module +# is imported by ``pytensor.compile.mode`` at startup, while the ops below pull +# in ``mode`` transitively. +from collections import defaultdict + +from pytensor.graph.rewriting.basic import GraphRewriter, get_active_mode + + +def _scan_destroyable_signature(op, node): + """Hashable signature of which taps are destroyable for this outer node. + + Two ``Scan`` nodes sharing an op can share a baked inner graph iff they agree + on this (see ``Scan.inner_destroyable_inputs``). Untraced sit_sot are always + destroyable and mit_mot never, so only the sit_sot / mit_sot masks vary. + """ + sitsot_mask = tuple( + outer.type.shape[0] == 1 for outer in op.outer_sitsot(node.inputs) + ) + mitsot_mask = tuple( + outer.type.shape[0] == abs(min(taps)) + for outer, taps in zip( + op.outer_mitsot(node.inputs), op.info.mit_sot_in_slices, strict=True + ) + ) + return (sitsot_mask, mitsot_mask) + + +class OptimizeInnerGraphs(GraphRewriter): + """Optimize inner graphs of inner-graph ops, producing new immutable ops.""" + + def add_requirements(self, fgraph): + pass + + def apply(self, fgraph): + from pytensor.compile.aliasing import add_supervisor_to_fgraph + from pytensor.compile.builders import OpFromGraph + from pytensor.compile.io import In + from pytensor.compile.mode import get_mode + from pytensor.configdefaults import config + from pytensor.graph.features import NoOutputFromInplace + from pytensor.link.basic import JITLinker + from pytensor.scan.op import Scan + from pytensor.tensor.random.op import OpWithCoreShape + + # NOTE: ``ScipyWrapperOp`` (MinimizeOp/RootOp) still keeps a *mutable* + # inner ``fgraph`` (it has no frozen graph yet), so it cannot be unfrozen + # here. Its inner graph is optimized lazily at link time by ``build_fn``. + # Add it back once it is made frozen. 
+ inner_graph_types: tuple = (OpFromGraph, Scan) + + # Group apply NODES by an optimization key, so each distinct inner graph + # is optimized once and shared across its nodes (an op reused in many + # nodes -> one optimization; cache reuse holds). The key is the op, + # except for ``Scan``: its tap-inplace depends on the *outer* node's + # buffer shapes, so two nodes sharing a ``Scan`` op but with different + # destroyable taps need different baked inner graphs -- the key then also + # carries the destroyable signature. + groups: dict = defaultdict(list) + for node in fgraph.apply_nodes: + op = node.op + if not isinstance(op, inner_graph_types) or isinstance(op, OpWithCoreShape): + # ``*WithCoreShape`` are leaf backend ops with dedicated dispatch; + # re-optimizing/re-wrapping them would loop. + continue + if isinstance(op, Scan): + groups[(op, _scan_destroyable_signature(op, node))].append(node) + else: + groups[op].append(node) + if not groups: + return + + # Optimize each inner graph with the *active outer compilation* (mode + + # linker), recovered reliably from the fgraph (``config.mode`` gets + # coerced across nested compilations). The query includes + # ``optimize_inner_graphs``, so nested inner-graph ops recurse -- we + # propagate the mode onto each inner graph so the recursion recovers it. + mode = get_active_mode(fgraph) + # Inplace is baked into the frozen op (frozen ``dedup_nodes=False`` so + # distinct inplace buffers survive interning). ``OpFromGraph`` must not + # mutate its inputs, so all are protected; ``Scan`` protects all but the + # destroyable taps. ``ScipyWrapperOp`` keeps inplace excluded (separate + # work). + # Exclude ``symbolic_op_recognition``: those rewrites fold a pattern into + # an inner-graph op (e.g. ``exp(x) / sum(exp(x))`` -> ``Softmax``, itself + # an ``OpFromGraph``). Running them on the inner graph -- which *is* that + # pattern -- would re-create the op, and optimizing its inner graph would + # recurse without end. 
The old link-time inner optimization excluded it + # too. + inplace_optimizer = mode.excluding("symbolic_op_recognition").optimizer + noinplace_optimizer = mode.excluding( + "inplace", "symbolic_op_recognition" + ).optimizer + + node_to_new_op: dict = {} + for nodes in groups.values(): + rep_node = nodes[0] + op = rep_node.op + inner = op._frozen_fgraph.unfreeze() + inner._compile_mode = mode + + custom_mode = getattr(op, "mode", None) + if isinstance(op, OpFromGraph): + input_specs = [In(x, borrow=True, mutable=False) for x in inner.inputs] + add_supervisor_to_fgraph( + fgraph=inner, input_specs=input_specs, accept_inplace=True + ) + optimizer = inplace_optimizer + elif isinstance(op, Scan): + destroyable = op.inner_destroyable_inputs(rep_node.inputs, inner.inputs) + input_specs = [ + In(x, borrow=True, mutable=x in destroyable) for x in inner.inputs + ] + add_supervisor_to_fgraph( + fgraph=inner, input_specs=input_specs, accept_inplace=True + ) + # Protect the same tap outputs ``Scan.prepare_fgraph`` does, so the + # inplace baked here never makes a protected output the result of a + # destroy-map node. Otherwise the CVM/Python ``Scan.fn`` path, which + # re-attaches ``NoOutputFromInplace`` and re-optimizes, would refuse + # the already-baked graph (``InconsistencyError`` under + # ``on_opt_error=raise``). This protection is specific to the + # non-jit linkers: jit backends (numba/jax) rely on inner tap + # inplace and copy outputs themselves (``insert_deepcopy``), and + # ``prepare_fgraph`` only installs it when output preallocation is + # enabled -- match both conditions so the baked graph agrees with + # whatever ``Scan.fn`` would produce. 
+ if ( + not isinstance(mode.linker, JITLinker) + and config.scan__allow_output_prealloc + ): + inner.attach_feature( + NoOutputFromInplace(op.protected_inner_out_idxs()) + ) + if custom_mode is not None: + # A custom Scan ``mode`` is still honored (deprecated; warned + # at op creation), combined with the active linker's + # required/incompatible rewrites so backend must-have ops + # still apply. + linker = mode.linker + optimizer = ( + get_mode(custom_mode) + .including(*linker.required_rewrites) + .excluding( + *linker.incompatible_rewrites, "symbolic_op_recognition" + ) + .optimizer + ) + else: + optimizer = inplace_optimizer + else: + optimizer = noinplace_optimizer + + optimizer.rewrite(inner) + new_op = op.clone_with_inner_graph(inner) + if new_op != op: + for node in nodes: + node_to_new_op[node] = new_op + + if not node_to_new_op: + return + + for node in fgraph.toposort(): + new_op = node_to_new_op.get(node) + if new_op is not None: + new_node = new_op.make_node(*node.inputs) + fgraph.replace_all( + list(zip(node.outputs, new_node.outputs, strict=True)), + reason="optimize_inner_graphs", + ) diff --git a/pytensor/compile/maker.py b/pytensor/compile/maker.py index 0073294012..fbe79e7585 100644 --- a/pytensor/compile/maker.py +++ b/pytensor/compile/maker.py @@ -469,7 +469,14 @@ def prepare_fgraph( mode=mode, traceback__limit=config.traceback__compile_limit, ): - rewriter_profile = rewriter(fgraph) + # Expose the compile mode so inner-graph rewrites can recover + # the active linker's required/incompatible rewrites reliably + # (``config.mode`` is unreliable across nested compilations). 
+ fgraph._compile_mode = mode + try: + rewriter_profile = rewriter(fgraph) + finally: + fgraph._compile_mode = None end_rewriter = time.perf_counter() rewrite_time = end_rewriter - start_rewriter diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index 69e0b1cb85..e4eeca0f8e 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -264,6 +264,22 @@ def apply(self, fgraph): "add_destroy_handler", AddDestroyHandler(), "fast_run", "inplace", position=49.5 ) +# Optimize the inner graphs of inner-graph ops (Scan/OpFromGraph/...) once, +# producing new immutable ops. Runs for every backend (so the C/VM and Python +# linkers get inner-graph optimization too) and is required by the JIT linkers +# (see their ``required_rewrites``). The registration name doubles as its tag. +# Imported here (after `optdb` is defined) to avoid an import cycle. +from pytensor.compile.inner_graph_rewriting import OptimizeInnerGraphs # noqa: E402 + + +optdb.register( + "optimize_inner_graphs", + OptimizeInnerGraphs(), + "fast_run", + "fast_compile", + position=49.6, +) + # final pass just to make sure optdb.register("merge3", MergeOptimizer(), "fast_run", "merge", position=100) diff --git a/pytensor/graph/rewriting/basic.py b/pytensor/graph/rewriting/basic.py index e39465f416..5f3111c117 100644 --- a/pytensor/graph/rewriting/basic.py +++ b/pytensor/graph/rewriting/basic.py @@ -107,7 +107,19 @@ def rewrite(self, fgraph, *args, **kwargs): """ self.add_requirements(fgraph) - return self.apply(fgraph, *args, **kwargs) + # Advertise the active (outermost) rewriter on the `fgraph` so that + # rewrites which optimize inner graphs (`optimize_inner_graphs`) can + # reuse the *same* outer query on those inner graphs. Only the + # outermost `rewrite` owns the slot; nested calls (e.g. applying this + # rewriter to an inner graph) re-establish it on that inner graph. 
+ owns_active = getattr(fgraph, "_active_rewriter", None) is None + if owns_active: + fgraph._active_rewriter = self + try: + return self.apply(fgraph, *args, **kwargs) + finally: + if owns_active: + fgraph._active_rewriter = None def __call__(self, fgraph): """Rewrite a `FunctionGraph`.""" @@ -131,6 +143,40 @@ def print_profile(cls, stream, prof, level=0): ) +def get_active_rewriter(fgraph) -> "GraphRewriter": + """Return the outermost `GraphRewriter` currently rewriting `fgraph`. + + Set by `GraphRewriter.rewrite`. Used by rewrites that need to optimize an + inner graph with the *same* outer query (transferring the outer + compilation -- mode, ``fast_run``/``fast_compile``, user + ``including``/``excluding`` -- to the inner graph). Falls back to the + default mode's optimizer when called outside an active rewrite. + """ + active = getattr(fgraph, "_active_rewriter", None) + if active is not None: + return active + from pytensor.compile.mode import get_mode + from pytensor.configdefaults import config + + return get_mode(config.mode).optimizer + + +def get_active_mode(fgraph): + """Return the compile `Mode` currently being used to rewrite `fgraph`. + + Set by `FunctionMaker` around the optimization pass. Used by inner-graph + rewrites to recover the active linker's required/incompatible rewrites. + Falls back to the default mode when called outside a compilation. 
+ """ + from pytensor.compile.mode import get_mode + from pytensor.configdefaults import config + + active = getattr(fgraph, "_compile_mode", None) + if active is not None: + return get_mode(active) + return get_mode(config.mode) + + class NodeRewriter(Rewriter): """A `Rewriter` that is applied to an `Apply` node.""" diff --git a/pytensor/link/jax/dispatch/basic.py b/pytensor/link/jax/dispatch/basic.py index c1240bba31..51037e499b 100644 --- a/pytensor/link/jax/dispatch/basic.py +++ b/pytensor/link/jax/dispatch/basic.py @@ -7,7 +7,6 @@ import numpy as np from pytensor.compile.builders import OpFromGraph -from pytensor.compile.mode import JAX from pytensor.compile.ops import DeepCopyOp, TypeCastingOp from pytensor.configdefaults import config from pytensor.graph import Constant @@ -128,8 +127,6 @@ def type_cast(x): def jax_funcify_OpFromGraph(ofg: OpFromGraph, node=None, **kwargs) -> Callable: _ = kwargs.pop("storage_map", None) - # Apply inner rewrites - JAX.optimizer(ofg.fgraph) fgraph_fn = jax_funcify(ofg.fgraph, **kwargs) if len(ofg.fgraph.outputs) == 1: diff --git a/pytensor/link/jax/dispatch/scan.py b/pytensor/link/jax/dispatch/scan.py index c4c24f0000..defce26a96 100644 --- a/pytensor/link/jax/dispatch/scan.py +++ b/pytensor/link/jax/dispatch/scan.py @@ -4,7 +4,6 @@ import numpy as np from jax._src.lax.control_flow import scan as jax_scan -from pytensor.compile.mode import JAX, get_mode from pytensor.link.jax.dispatch.basic import jax_funcify from pytensor.scan.op import Scan @@ -27,14 +26,6 @@ def jax_funcify_Scan(op: Scan, node, **kwargs): if info.as_while: raise NotImplementedError("While Scan cannot yet be converted to JAX") - # Optimize inner graph (exclude any defalut rewrites that are incompatible with JAX mode) - rewriter = ( - get_mode(op.mode) - .including("jax") - .excluding("numba", *JAX._optimizer.exclude) - .optimizer - ) - rewriter(op.fgraph) scan_inner_func = jax_funcify(op.fgraph, **kwargs) def scan(*outer_inputs): diff --git 
a/pytensor/link/jax/dispatch/tensor_basic.py b/pytensor/link/jax/dispatch/tensor_basic.py index e70fd67b72..b0112ea979 100644 --- a/pytensor/link/jax/dispatch/tensor_basic.py +++ b/pytensor/link/jax/dispatch/tensor_basic.py @@ -20,7 +20,7 @@ get_scalar_constant_value, ) from pytensor.tensor.exceptions import NotScalarConstantError -from pytensor.tensor.shape import Shape_i +from pytensor.tensor.shape import Shape, Shape_i ARANGE_CONCRETE_VALUE_ERROR = """JAX requires the arguments of `jax.numpy.arange` to be constants. @@ -31,6 +31,24 @@ """ +def _is_shape_derived(var): + """Whether ``var``'s value is fixed by input *shapes* rather than input *data*. + + JAX makes array shapes concrete at trace time, so such a value is usable + where ``jax.numpy.arange`` needs a concrete argument -- even when it isn't a + graph ``Constant`` (e.g. ``x.shape[-1]``, which may appear as + ``Subtensor(Shape(x), -1)`` rather than the canonical ``Shape_i``). + """ + if isinstance(var, Constant): + return True + op, *inputs = var.owner_op_and_inputs + if isinstance(op, Shape | Shape_i): + return True + if op is None: + return False + return all(_is_shape_derived(inp) for inp in inputs) + + @jax_funcify.register(AllocEmpty) def jax_funcify_AllocEmpty(op, **kwargs): def allocempty(*shape): @@ -62,12 +80,13 @@ def jax_funcify_ARange(op, node, **kwargs): arange_args = node.inputs constant_args = [] for arg in arange_args: - if arg.owner and isinstance(arg.owner.op, Shape_i): - constant_args.append(None) - elif isinstance(arg, Constant): + if isinstance(arg, Constant): constant_args.append(arg.value) + elif _is_shape_derived(arg): + # Fixed by input shapes (e.g. ``shape(x)[-1]``, ``shape(x)[-1] + 1``), + # which JAX resolves to a concrete value at trace time. + constant_args.append(None) else: - # TODO: This might be failing without need (e.g., if arg = shape(x)[-1] + 1)! 
raise NotImplementedError(ARANGE_CONCRETE_VALUE_ERROR) constant_start, constant_stop, constant_step = constant_args diff --git a/pytensor/link/jax/linker.py b/pytensor/link/jax/linker.py index 4ab0224881..3c3bdb5265 100644 --- a/pytensor/link/jax/linker.py +++ b/pytensor/link/jax/linker.py @@ -12,6 +12,7 @@ class JAXLinker(JITLinker): required_rewrites = ( "minimum_compile", "jax", + "optimize_inner_graphs", ) # TODO: Distinguish between optional "jax" and "minimum_compile_jax" incompatible_rewrites = ( "cxx_only", diff --git a/pytensor/link/mlx/dispatch/basic.py b/pytensor/link/mlx/dispatch/basic.py index e205fd519d..983ef65918 100644 --- a/pytensor/link/mlx/dispatch/basic.py +++ b/pytensor/link/mlx/dispatch/basic.py @@ -7,7 +7,6 @@ import numpy as np from pytensor.compile.builders import OpFromGraph -from pytensor.compile.mode import MLX from pytensor.compile.ops import DeepCopyOp, TypeCastingOp from pytensor.graph import Constant from pytensor.graph.fg import AbstractFunctionGraph @@ -204,7 +203,6 @@ def assert_fn(x, *inputs): def mlx_funcify_OpFromGraph(ofg: OpFromGraph, node=None, **kwargs): _ = kwargs.pop("storage_map", None) - MLX.optimizer(ofg.fgraph) fgraph_fn = mlx_funcify(ofg.fgraph, squeeze_output=True, **kwargs) return fgraph_fn diff --git a/pytensor/link/mlx/linker.py b/pytensor/link/mlx/linker.py index 6c9db9afd6..6384ffce26 100644 --- a/pytensor/link/mlx/linker.py +++ b/pytensor/link/mlx/linker.py @@ -4,6 +4,10 @@ class MLXLinker(JITLinker): """A `Linker` that JIT-compiles NumPy-based operations using Apple's MLX.""" + required_rewrites = ( + "minimum_compile", + "optimize_inner_graphs", + ) incompatible_rewrites = ( "cxx_only", "BlasOpt", diff --git a/pytensor/link/numba/dispatch/compile_ops.py b/pytensor/link/numba/dispatch/compile_ops.py index 74f8dda91a..724d9de6c1 100644 --- a/pytensor/link/numba/dispatch/compile_ops.py +++ b/pytensor/link/numba/dispatch/compile_ops.py @@ -8,7 +8,6 @@ from pytensor.compile.aliasing import 
add_supervisor_to_fgraph, insert_deepcopy from pytensor.compile.builders import OpFromGraph from pytensor.compile.io import In, Out -from pytensor.compile.mode import NUMBA from pytensor.compile.ops import DeepCopyOp, TypeCastingOp from pytensor.ifelse import IfElse from pytensor.link.numba.cache import compile_numba_function_src @@ -53,7 +52,6 @@ def string_deepcopy(x): def numba_funcify_OpFromGraph( op, node=None, - mode=NUMBA.excluding("symbolic_op_recognition"), ofg_memo=None, **kwargs, ): @@ -62,18 +60,17 @@ def numba_funcify_OpFromGraph( if ofg_memo is not None and op in ofg_memo: return ofg_memo[op] - # Apply inner rewrites - # TODO: Not sure this is the right place to do this, should we have a rewrite that - # explicitly triggers the optimization of the inner graphs of OpFromGraph? - # The C-code defers it to the make_thunk phase - fgraph = op.fgraph + # The inner graph is already optimized and inplace-resolved by the + # ``optimize_inner_graphs`` rewrite. We unfreeze a throwaway copy and only + # set up the link-time aliasing machinery (supervisor/DestroyHandler for + # ``insert_deepcopy``). 
+ fgraph = op._frozen_fgraph.unfreeze() input_specs = [In(x, borrow=True, mutable=False) for x in fgraph.inputs] add_supervisor_to_fgraph( fgraph=fgraph, input_specs=input_specs, accept_inplace=True, ) - mode.optimizer(fgraph) output_specs = [Out(o, borrow=False) for o in fgraph.outputs] insert_deepcopy(fgraph, wrapped_inputs=input_specs, wrapped_outputs=output_specs) fgraph_fn, fgraph_cache_key = numba_funcify_and_cache_key( diff --git a/pytensor/link/numba/dispatch/scan.py b/pytensor/link/numba/dispatch/scan.py index cf177f6183..051bad0a70 100644 --- a/pytensor/link/numba/dispatch/scan.py +++ b/pytensor/link/numba/dispatch/scan.py @@ -7,7 +7,6 @@ from pytensor.compile.aliasing import add_supervisor_to_fgraph, insert_deepcopy from pytensor.compile.io import In, Out -from pytensor.compile.mode import NUMBA, get_mode from pytensor.link.numba.cache import compile_numba_function_src from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import ( @@ -56,54 +55,19 @@ def range_arr(x): @register_funcify_and_cache_key(Scan) def numba_funcify_Scan(op: Scan, node, **kwargs): - # Apply inner rewrites - # TODO: Not sure this is the right place to do this, should we have a rewrite that - # explicitly triggers the optimization of the inner graphs of Scan? - # The C-code defers it to the make_thunk phase - rewriter = ( - get_mode(op.mode) - .including("numba") - .excluding(*NUMBA._optimizer.exclude) - .optimizer - ) - fgraph = op.fgraph - # When the buffer can only hold one SITSOT or as as many MITSOT as there are taps, - # We must always discard the oldest tap, so it's safe to destroy it in the inner function. 
- # TODO: Allow inplace for MITMOT - destroyable_sitsot = [ - inner_sitsot - for outer_sitsot, inner_sitsot in zip( - op.outer_sitsot(node.inputs), op.inner_sitsot(fgraph.inputs), strict=True - ) - if outer_sitsot.type.shape[0] == 1 - ] - destroyable_mitsot = [ - oldest_inner_mitmot - for outer_mitsot, oldest_inner_mitmot, taps in zip( - op.outer_mitsot(node.inputs), - op.oldest_inner_mitsot(fgraph.inputs), - op.info.mit_sot_in_slices, - strict=True, - ) - if outer_mitsot.type.shape[0] == abs(min(taps)) - ] - # Always allow the inner function to destroy untraced_sit_sot inputs. - # After the first iteration, these come from the previous output so - # destroying is always safe. For the first iteration, the codegen - # copies the outer input if the Scan's destroy_map doesn't allow it. - destroyable_untraced_sit_sot = list(op.inner_untraced_sit_sot(fgraph.inputs)) - destroyable = { - *destroyable_sitsot, - *destroyable_mitsot, - *destroyable_untraced_sit_sot, - } + # The inner graph is already optimized and inplace-resolved by the + # ``optimize_inner_graphs`` rewrite. We unfreeze a throwaway copy and only + # set up the link-time aliasing machinery: the supervisor/DestroyHandler + # (so ``destroyers``/``insert_deepcopy`` can reason about the baked inplace) + # and the output deepcopies. + fgraph = op._frozen_fgraph.unfreeze() + destroyable = op.inner_destroyable_inputs(node.inputs, fgraph.inputs) input_specs = [In(x, borrow=True, mutable=x in destroyable) for x in fgraph.inputs] add_supervisor_to_fgraph( fgraph=fgraph, input_specs=input_specs, accept_inplace=True, ) - rewriter(fgraph) untraced_sit_sot_inner_outputs = set(op.inner_untraced_sit_sot_outs(fgraph.outputs)) output_specs = [ Out(x, borrow=x in untraced_sit_sot_inner_outputs) for x in fgraph.outputs @@ -112,17 +76,10 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): # Track which untraced_sit_sot outputs have their inner input destroyed # by the optimized inner function (transitively, via DestroyHandler). 
- untraced_start = ( - op.info.n_mit_mot + op.info.n_mit_sot + op.info.n_sit_sot + op.info.n_nit_sot - ) - inner_destroyed_untraced_out_idxs = set() - if hasattr(fgraph, "destroyers"): - for j, inner_inp in enumerate(op.inner_untraced_sit_sot(fgraph.inputs)): - if fgraph.destroyers(inner_inp): - inner_destroyed_untraced_out_idxs.add(untraced_start + j) + inner_destroyed_untraced_out_idxs = op.inner_destroyed_untraced_out_idxs(fgraph) scan_inner_func, inner_func_cache_key = numba_funcify_and_cache_key( - op.fgraph, fgraph_name="numba_scan", ofg_memo=kwargs.get("ofg_memo") + fgraph, fgraph_name="numba_scan", ofg_memo=kwargs.get("ofg_memo") ) outer_in_names_to_vars = { diff --git a/pytensor/link/numba/linker.py b/pytensor/link/numba/linker.py index c8c676dd38..af74015a2b 100644 --- a/pytensor/link/numba/linker.py +++ b/pytensor/link/numba/linker.py @@ -5,6 +5,7 @@ class NumbaLinker(JITLinker): required_rewrites = ( "minimum_compile", "numba", + "optimize_inner_graphs", ) # TODO: Distinguish between optional "numba" and "minimum_compile_numba" incompatible_rewrites = ( "cxx_only", diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index c66f3210b6..5b24859754 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -4,10 +4,7 @@ import numpy as np import torch -from pytensor import In -from pytensor.compile.aliasing import add_supervisor_to_fgraph from pytensor.compile.builders import OpFromGraph -from pytensor.compile.mode import PYTORCH from pytensor.compile.ops import DeepCopyOp, TypeCastingOp from pytensor.graph.basic import Constant from pytensor.graph.fg import AbstractFunctionGraph @@ -201,15 +198,6 @@ def ifelse(cond, *true_and_false, n_outs=n_outs): @pytorch_funcify.register(OpFromGraph) def pytorch_funcify_OpFromGraph(op, node, **kwargs): kwargs.pop("storage_map", None) - # Apply inner rewrites - PYTORCH.optimizer(op.fgraph) - fgraph = op.fgraph - 
add_supervisor_to_fgraph( - fgraph=fgraph, - input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs], - accept_inplace=True, - ) - PYTORCH.optimizer(fgraph) fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True) return fgraph_fn diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py index 88fb8d7407..8bfe6ffc7f 100644 --- a/pytensor/link/pytorch/linker.py +++ b/pytensor/link/pytorch/linker.py @@ -5,6 +5,10 @@ class PytorchLinker(JITLinker): """A `Linker` that compiles NumPy-based operations using torch.compile.""" + required_rewrites = ( + "minimum_compile", + "optimize_inner_graphs", + ) incompatible_rewrites = ( "cxx_only", "BlasOpt", diff --git a/pytensor/scan/basic.py b/pytensor/scan/basic.py index b4ba5d6608..11c584a724 100644 --- a/pytensor/scan/basic.py +++ b/pytensor/scan/basic.py @@ -425,6 +425,15 @@ def f(x): Pass this to `pytensor.function` when compiling your function. """ + if mode is not None: + warnings.warn( + "The `mode` argument of `scan` is deprecated: the inner graph now " + "inherits the outer compilation (see the `optimize_inner_graphs` " + "rewrite). 
It is still honored for now.", + FutureWarning, + stacklevel=2, + ) + # General observation : this code is executed only once, at creation # of the computational graph, so we don't yet need to be smart about # anything (to speed things up) diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index a0cb073f54..99a817a7ec 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -48,7 +48,7 @@ import time import warnings from collections.abc import Callable, Iterable -from copy import copy +from copy import copy, deepcopy from itertools import chain, product import numpy as np @@ -80,7 +80,7 @@ from pytensor.graph.replace import clone_replace from pytensor.graph.traversal import graph_inputs from pytensor.graph.type import HasShape -from pytensor.graph.utils import InconsistencyError, MissingInputError +from pytensor.graph.utils import MissingInputError from pytensor.link.vm import VMLinker from pytensor.printing import op_debug_information from pytensor.scan.utils import ScanProfileStats, Validator, forced_replace, safe_new @@ -502,6 +502,114 @@ def outer_untraced_sit_sot_outs(self, list_outputs, with_idx=False): else: return res + def inner_destroyable_inputs(self, outer_inputs, inner_inputs): + """Inner inputs the step function may safely destroy in place. + + Destroyability depends on the *outer* node's buffer shapes, so this is a + per-node property (two nodes sharing a `Scan` op but with different outer + buffers can differ): + + - A sit_sot tap whose outer buffer holds a single state (``shape[0] == 1``): + the buffer always discards the oldest state, so destroying it is safe. + - The oldest mit_sot tap when the outer buffer holds exactly the taps + (``shape[0] == abs(min(taps))``): same reasoning. 
+ - Every untraced sit_sot: after the first iteration it comes from the + previous output (always safe to destroy); the first iteration is copied + by the codegen (`perform`) when the Scan's ``destroy_map`` doesn't grant + ownership of the outer input, so destroying the inner buffer can't reach + back to it. + + ``mit_mot`` taps are never destroyed (TODO: allow inplace for MITMOT). + """ + destroyable_sitsot = [ + inner_sitsot + for outer_sitsot, inner_sitsot in zip( + self.outer_sitsot(outer_inputs), + self.inner_sitsot(inner_inputs), + strict=True, + ) + if outer_sitsot.type.shape[0] == 1 + ] + destroyable_mitsot = [ + oldest_inner_mitsot + for outer_mitsot, oldest_inner_mitsot, taps in zip( + self.outer_mitsot(outer_inputs), + self.oldest_inner_mitsot(inner_inputs), + self.info.mit_sot_in_slices, + strict=True, + ) + if outer_mitsot.type.shape[0] == abs(min(taps)) + ] + destroyable_untraced_sit_sot = self.inner_untraced_sit_sot(inner_inputs) + return { + *destroyable_sitsot, + *destroyable_mitsot, + *destroyable_untraced_sit_sot, + } + + def _preallocated_mitmot_out_idxs(self): + """Inner-output indices of mit_mot taps that are also inputs. + + With output preallocation these are wrapped as updates that write back + (possibly in place) into the corresponding input buffer, so -- unlike the + other tap outputs -- they are *allowed* to be the result of an in-place + operation. Mirrors the loop in `prepare_fgraph`. + """ + info = self.info + preallocated = [] + for mitmot_idx in range(info.n_mit_mot): + for inp_tap in info.mit_mot_in_slices[mitmot_idx]: + if inp_tap in info.mit_mot_out_slices[mitmot_idx]: + output_idx = sum( + len(m) for m in info.mit_mot_out_slices[:mitmot_idx] + ) + output_idx += info.mit_mot_out_slices[mitmot_idx].index(inp_tap) + preallocated.append(output_idx) + return preallocated + + def protected_inner_out_idxs(self, preallocated_mitmot_outs=None): + """Inner-output indices that must not be the result of an in-place op. 
+ + These are the tap outputs (mit_mot / mit_sot / sit_sot / nit_sot) whose + buffers the VM reuses across iterations; a protected output computed by a + destroy-map node would alias a value still needed elsewhere. Preallocated + mit_mot updates are excluded -- they are *meant* to write back into their + input buffer. This is the protection installed as `NoOutputFromInplace` + both at link time (`prepare_fgraph`) and when baking inplace into the + frozen inner graph (`optimize_inner_graphs`), so the two agree. + """ + if preallocated_mitmot_outs is None: + preallocated_mitmot_outs = ( + self._preallocated_mitmot_out_idxs() + if config.scan__allow_output_prealloc + else [] + ) + info = self.info + n_taps = info.n_mit_mot_outs + info.n_mit_sot + info.n_sit_sot + info.n_nit_sot + prealloc = set(preallocated_mitmot_outs) + return tuple(i for i in range(n_taps) if i not in prealloc) + + def inner_destroyed_untraced_out_idxs(self, fgraph): + """Output indices of untraced sit_sot whose inner input the step fn destroys. + + ``fgraph`` is the (DestroyHandler-carrying) inner graph a backend is about + to link -- numba's prepared copy or the C/py ``Scan.fn`` graph. When such + an untraced output is *not* owned (its index isn't in ``destroy_map``) the + codegen must hand the step function a copy of the outer input on the first + iteration, else the in-place destruction reaches back to the caller's + input. Shared by the numba codegen and the C/py ``perform`` so both decide + identically; when the inner doesn't destroy the input the ``view_map`` + alias is preserved (no copy). 
+ """ + if not hasattr(fgraph, "destroyers"): + return set() + untraced_start = self.n_tap_outs + self.info.n_nit_sot + return { + untraced_start + j + for j, inner_inp in enumerate(self.inner_untraced_sit_sot(fgraph.inputs)) + if fgraph.destroyers(inner_inp) + } + def inner_non_seqs(self, list_inputs): n_taps_upto_sit_sot = sum( len(x) @@ -945,17 +1053,21 @@ def tensorConstructor(shape, dtype): self.n_outer_inputs = info.n_outer_inputs self.n_outer_outputs = info.n_outer_outputs - if any(node.op.destroy_map for node in self.fgraph.apply_nodes): - raise InconsistencyError( - "Inner-graphs must not contain in-place operations." - ) + # NOTE: Inner graphs are allowed to contain in-place operations. They + # are introduced by the `optimize_inner_graphs` rewrite (which has the + # outer-node context needed to decide what is safe to destroy) and + # baked into the frozen inner graph. The destroy-handler / inplace + # machinery validates correctness when the inner graph is built. - self._frozen_fgraph = self.fgraph.freeze() + # ``dedup_nodes=False``: inner graphs may carry inplace ops (baked by + # ``optimize_inner_graphs``) whose destroyed buffers must stay distinct; + # structural folding would alias them. See ``FunctionGraph.freeze``. + self._frozen_fgraph = self.fgraph.freeze(dedup_nodes=False) def __setstate__(self, d): self.__dict__.update(d) if not hasattr(self, "_frozen_fgraph"): - self._frozen_fgraph = self.fgraph.freeze() + self._frozen_fgraph = self.fgraph.freeze(dedup_nodes=False) # Ensure that the graph associated with the inner function is valid. 
self.validate_inner_graph() @@ -1439,16 +1551,7 @@ def prepare_fgraph(self, fgraph): for x in fgraph.outputs[slices:] ] - protected_outs = tuple( - i - for i in range( - info.n_mit_mot_outs - + info.n_mit_sot - + info.n_sit_sot - + info.n_nit_sot - ) - if i not in preallocated_mitmot_outs - ) + protected_outs = self.protected_inner_out_idxs(preallocated_mitmot_outs) fgraph.attach_feature(NoOutputFromInplace(protected_outs)) else: @@ -1476,7 +1579,11 @@ def fn(self): if getattr(self, "_fn", None) is not None: return self._fn - wrapped_inputs, wrapped_outputs = self.prepare_fgraph(self.fgraph) + # Compile a throwaway copy of the (already math-optimized) inner graph. + # The canonical inner graph is immutable; linking setup (MIT-MOT update + # wrapping, supervisor) and any inplace happen on this transient. + inner_fgraph = self._frozen_fgraph.unfreeze() + wrapped_inputs, wrapped_outputs = self.prepare_fgraph(inner_fgraph) profile = None if config.profile or ( @@ -1518,7 +1625,7 @@ def fn(self): accept_inplace=False, profile=profile, on_unused_input="ignore", - fgraph=self.fgraph, + fgraph=inner_fgraph, ).create() return self._fn @@ -1536,6 +1643,27 @@ def clone(self) -> "Scan": res.fgraph = res.fgraph.clone(clone_inner_graphs=True) # type: ignore[attr-defined] return res + def clone_with_inner_graph(self, inner_fgraph) -> "Scan": + """Return a new `Scan` whose inner graph is ``inner_fgraph``. + + Used by the ``optimize_inner_graphs`` rewrite to build a new op + carrying an already-optimized inner graph without mutating ``self``. + The constructor recomputes all inner-graph-derived state + (``output_types``/``mintaps``/``view_map``/``mitmots_preallocated``) + from ``info`` + the new inner output types. 
+ """ + return type(self)( + inputs=list(inner_fgraph.inputs), + outputs=list(inner_fgraph.outputs), + info=self.info, + mode=self.mode, + truncate_gradient=self.truncate_gradient, + name=self.name, + profile=self.profile, + allow_gc=self.allow_gc, + strict=getattr(self, "strict", True), + ) + def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None): """ @@ -1618,6 +1746,24 @@ def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None): cython_destroy_map = np.asarray(cython_destroy_map, dtype=bool) + # Untraced sit_sot inputs the step fn destroys but the Scan doesn't own + # must be handed a copy on the first iteration, else the in-place + # destruction reaches the caller's input (see + # ``inner_destroyed_untraced_out_idxs``). When the inner doesn't destroy + # the input no copy is made, preserving the output-views-input alias. + inner_destroyed = self.inner_destroyed_untraced_out_idxs( + self.fn.maker.fgraph + ) + untraced_out_start = self.n_tap_outs + self.info.n_nit_sot + cython_untraced_copy = np.asarray( + [ + (untraced_out_start + j) in inner_destroyed + and (untraced_out_start + j) not in self.destroy_map + for j in range(self.info.n_untraced_sit_sot) + ], + dtype=bool, + ) + inner_input_storage = [s.storage for s in self.fn.input_storage] inner_output_storage = [s.storage for s in self.fn.output_storage] @@ -1676,6 +1822,7 @@ def p(node, inputs, outputs): inner_input_storage, inner_output_storage, cython_destroy_map, + cython_untraced_copy, inputs, outputs, outer_output_dtypes, @@ -1839,6 +1986,9 @@ def perform(self, node, inputs, output_storage): offset = self.nit_sot_arg_offset + info.n_nit_sot other_args = inputs[offset:] inner_input_storage = self.fn.input_storage + inner_destroyed_untraced = self.inner_destroyed_untraced_out_idxs( + self.fn.maker.fgraph + ) nb_mitmot_in = sum(map(len, info.mit_mot_in_slices)) old_mitmot_input_storage = [None] * nb_mitmot_in old_mitmot_input_data = [None] * nb_mitmot_in @@ 
-1903,7 +2053,18 @@ def perform(self, node, inputs, output_storage): o_offset = self.n_tap_outs + info.n_nit_sot if i == 0: for j in range(info.n_untraced_sit_sot): - inner_input_storage[offset].storage[0] = inputs[a_offset + j] + # Copy when the step fn destroys this input in place but the + # Scan doesn't own it, so the destruction can't reach back to + # the caller's input. ``deepcopy`` so this works for RNG inputs + # too (a numpy ``Generator``/``RandomState`` has no ``.copy``); + # for arrays it is as cheap as ``.copy()``. Otherwise pass the + # input directly, preserving the output-views-input alias. + inp = inputs[a_offset + j] + if (o_offset + j) in inner_destroyed_untraced and ( + o_offset + j + ) not in self.destroy_map: + inp = deepcopy(inp) + inner_input_storage[offset].storage[0] = inp offset += 1 else: for j in range(info.n_untraced_sit_sot): diff --git a/pytensor/scan/scan_perform.pyx b/pytensor/scan/scan_perform.pyx index 20ca6858a6..844d3d2fc4 100644 --- a/pytensor/scan/scan_perform.pyx +++ b/pytensor/scan/scan_perform.pyx @@ -47,6 +47,7 @@ by describing the arguments of this function) """ import sys +from copy import deepcopy from libc.time cimport time, time_t @@ -96,6 +97,7 @@ def perform( list inner_input_storage not None, list inner_output_storage not None, const numpy.npy_bool[:] destroy_map not None, + const numpy.npy_bool[:] untraced_copy not None, list outer_inputs not None, list outer_outputs not None, tuple outer_output_dtypes not None, @@ -342,7 +344,17 @@ def perform( o_offset = n_outs + n_nit_sot if i == 0: for j in range(n_untraced_sit_sot): - inner_input_storage[offset][0] = outer_inputs[(a_offset+j)] + # ``untraced_copy[j]`` is set when the step fn destroys this input + # in place but the Scan doesn't own it: hand over a copy so the + # destruction can't reach back to the caller's input. 
``deepcopy`` + # so this works for RNG inputs too (a numpy ``Generator``/ + # ``RandomState`` has no ``.copy``); for arrays it is as cheap as + # ``.copy()``. Otherwise pass the input directly, preserving the + # output-views-input alias. + if untraced_copy[j] != 0: + inner_input_storage[offset][0] = deepcopy(outer_inputs[(a_offset+j)]) + else: + inner_input_storage[offset][0] = outer_inputs[(a_offset+j)] offset += 1 else: for j in range(n_untraced_sit_sot): diff --git a/tests/compile/test_builders.py b/tests/compile/test_builders.py index c6f1f1e74c..1a63f926b8 100644 --- a/tests/compile/test_builders.py +++ b/tests/compile/test_builders.py @@ -534,17 +534,24 @@ def test_shared_to_nonshared_input(self): assert np.array_equal(res_2, 1.0) def test_outputs_consistency(self): - """Make sure that `OpFromGraph.fn` doesn't change the value of `OpFromGraph.inner_outputs`.""" + """Compiling the inner function must not mutate `OpFromGraph.inner_outputs`.""" x = scalar("x") - op = OpFromGraph([x], [x**2 / x], mode="FAST_RUN") + op = OpFromGraph([x], [x**2 / x]) # Confirm that the inner-graph is as expected assert equal_computations(op.inner_outputs, [x**2 / x], op.inner_inputs, [x]) - # These outputs of the compiled `op.fgraph` should differ from the - # original, uncompiled `op.fgraph` outputs - fn = op.fn + # Optimizing a copy of the inner graph (here FAST_RUN, which rewrites + # ``x**2 / x`` to ``x``) must not leak back into the canonical, shared + # inner graph -- the compiled `FunctionGraph` is a separate clone. 
+ fn = function( + op.inner_inputs, + op.inner_outputs, + mode="FAST_RUN", + on_unused_input="ignore", + accept_inplace=True, + ) new_inputs = fn.maker.fgraph.inputs new_outputs = fn.maker.fgraph.outputs assert not equal_computations(new_outputs, [x**2 / x], new_inputs, [x]) @@ -552,6 +559,10 @@ def test_outputs_consistency(self): # The original `op.fgraph` outputs should stay the same, though assert equal_computations(op.inner_outputs, [x**2 / x], op.inner_inputs, [x]) + # `op.fn` (compiled under the active mode) must likewise leave it intact. + op.fn + assert equal_computations(op.inner_outputs, [x**2 / x], op.inner_inputs, [x]) + def test_explicit_input_from_constant(self): x = pt.dscalar("x") y = constant(1.0, dtype=x.type.dtype, name="y") diff --git a/tests/link/numba/test_compile_ops.py b/tests/link/numba/test_compile_ops.py index b8b8f1b56b..c776915e99 100644 --- a/tests/link/numba/test_compile_ops.py +++ b/tests/link/numba/test_compile_ops.py @@ -235,13 +235,14 @@ def test_check_and_raise(): def test_ofg_with_inner_scan_rewrite(): - # Regression test where inner scan would be mutated when compiling outer OFG - ys = pt.tensor("ys", shape=(5, 3, 3)) + # Regression test where inner scan would be mutated when compiling outer OFG. + # The inner cholesky is *batched* (over the size-4 axis) so the Blockwise + # survives optimization and is wrapped as BlockwiseWithCoreShape for numba. + ys = pt.tensor("ys", shape=(5, 4, 3, 3)) xs = scan( lambda y: cholesky(y), sequences=[ys], return_updates=False, - mode=Mode(optimizer=None), ) xs_ofg = OpFromGraph([ys], [xs])(ys) fn = function([ys], xs_ofg, mode="NUMBA") @@ -265,6 +266,33 @@ def test_ofg_with_inner_scan_rewrite(): assert isinstance(cholesky_op.core_op, Cholesky) +def test_compiling_does_not_mutate_canonical_inner_graph(): + # Regression test: compiling an op with an inner graph must NOT mutate the + # canonical (shared) inner FunctionGraph. Backend specialization (e.g. 
the + # numba ``BlockwiseWithCoreShape`` wrapping) must happen on per-compilation + # copies, never on the op the user holds -- otherwise a second use of the + # same op (here: deriving something from it after an ``.eval()``) sees a + # corrupted inner graph. This is what tripped scan-based pymc CustomDists. + ys = pt.tensor("ys", shape=(5, 4, 3, 3)) + xs = scan(lambda y: cholesky(y), sequences=[ys], return_updates=False) + ofg_out = OpFromGraph([ys], [xs])(ys) + + scan_op = ofg_out.owner.op.fgraph.outputs[0].owner.op + inner_before = [type(o.owner.op).__name__ for o in scan_op.fgraph.outputs] + + # Compile (and run) under numba: this used to mutate the inner graph above. + fn = function([ys], ofg_out, mode="NUMBA") + fn(np.eye(3)[None, None].repeat(5, 0).repeat(4, 1)) + + inner_after = [type(o.owner.op).__name__ for o in scan_op.fgraph.outputs] + assert inner_before == inner_after, ( + f"canonical inner graph was mutated by compilation: " + f"{inner_before} -> {inner_after}" + ) + # And specifically: no backend wrapper leaked onto the canonical op. + assert not any("CoreShape" in name for name in inner_after) + + @pytest.mark.parametrize("as_view", [True, False]) def test_ifelse_single_output(as_view, single_out=True): x = pt.vector("x") From 42bf4b5fc1a92b2a1bdc040dd561fa735a1dd595 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 18 Jun 2026 19:01:00 +0200 Subject: [PATCH 4/8] Rebuild immutable Apply nodes instead of mutating in clone_with_new_inputs When the input types are unchanged, clone_with_new_inputs cloned the node and reassigned its inputs. For immutable nodes whose clone() returns self (FrozenApply), this mutated the shared, supposedly-immutable node in place, silently corrupting any frozen inner graph that was clone_replace'd. Build a fresh, mutable Apply in that case instead -- the per-node form of bind. 
--- pytensor/graph/basic.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index cc05fb4f82..1c6a33990a 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -319,7 +319,16 @@ def clone_with_new_inputs( new_node.tag = copy(self.tag).__update__(new_node.tag) else: new_node = self.clone(clone_inner_graph=clone_inner_graph) - new_node.inputs = new_inputs + if new_node is self: + # Immutable nodes (e.g. ``FrozenApply``) return ``self`` from + # ``clone()``; mutating ``inputs`` would corrupt the shared node. + # Build a fresh, mutable ``Apply`` instead. + new_node = Apply( + self.op, new_inputs, [out.type() for out in self.outputs] + ) + new_node.tag = copy(self.tag).__update__(new_node.tag) + else: + new_node.inputs = new_inputs return new_node def get_parents(self): From 9652d40f5341555ed6a30421889143d981b88a88 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 18 Jun 2026 19:01:27 +0200 Subject: [PATCH 5/8] Collapse Scan and OpFromGraph to a single frozen inner graph Scan and OpFromGraph kept both a mutable self.fgraph and a separate frozen _frozen_fgraph. Collapse to one immutable graph exposed as op.fgraph (a FrozenFunctionGraph): the mutable copy is gone, so the canonical inner graph can never be mutated in place. clone() returns self (like Composite), and eq/hash/dispatch/inner_inputs all read the frozen graph. Rewrites that recombine inner variables to build a new Scan (push_out, trace, io/utils) read the frozen inner_inputs/inner_outputs directly and pass them to the Scan constructor; construct_nominal_fgraph rebuilds the body as fresh mutable Apply nodes (relying on clone_with_new_inputs no longer mutating frozen nodes). Rewrites that mutate the inner graph in place (the linalg scan-solve split) unfreeze() a transient first. scan_make_inplace shallow-copies the op (clone() is now self). inner_inputs/inner_outputs return lists so callers keep list semantics. 
Updates the OFG/Scan clone tests to the immutable contract. --- pytensor/compile/builders.py | 50 +++++++----- pytensor/compile/inner_graph_rewriting.py | 2 +- pytensor/link/numba/dispatch/compile_ops.py | 2 +- pytensor/link/numba/dispatch/scan.py | 2 +- pytensor/printing.py | 38 ++++------ pytensor/scan/op.py | 84 +++++++++++---------- pytensor/scan/rewriting/inplace.py | 6 +- pytensor/scan/rewriting/utils.py | 14 ++-- pytensor/tensor/rewriting/linalg/solvers.py | 22 ++---- pytensor/tensor/rewriting/ofg.py | 2 +- tests/compile/test_builders.py | 8 +- tests/scan/test_basic.py | 10 +-- tests/test_printing.py | 23 +++--- 13 files changed, 132 insertions(+), 131 deletions(-) diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py index fa6b8e8c83..43a2ef6551 100644 --- a/pytensor/compile/builders.py +++ b/pytensor/compile/builders.py @@ -338,13 +338,15 @@ def __init__( self.is_inline = inline - self.fgraph, self.shared_inputs, _, _ = construct_nominal_fgraph( + inner_fgraph, self.shared_inputs, _, _ = construct_nominal_fgraph( inputs, outputs ) - # ``dedup_nodes=False``: inner graphs may carry inplace ops whose - # destroyed buffers must stay distinct; structural folding would alias - # them. See ``FunctionGraph.freeze``. - self._frozen_fgraph = self.fgraph.freeze(dedup_nodes=False) + # Keep only the immutable (frozen) inner graph as ``op.fgraph``; the + # mutable copy is transient, so the canonical inner graph can never be + # mutated in place. ``dedup_nodes=False``: inner graphs may carry inplace + # ops whose destroyed buffers must stay distinct; structural folding + # would alias them. See ``FunctionGraph.freeze``. 
+ self.fgraph = inner_fgraph.freeze(dedup_nodes=False) if strict and self.shared_inputs: raise ValueError( @@ -432,7 +434,7 @@ def _freeze_override(self, override, make_dummy_args): if override is None: return None if isinstance(override, OpFromGraph): - return override._frozen_fgraph + return override.fgraph all_inputs, callable_args = make_dummy_args() @@ -494,7 +496,7 @@ def __eq__(self, other): if type(self) is not type(other): return False if ( - self._frozen_fgraph != other._frozen_fgraph + self.fgraph != other.fgraph or self.is_inline != other.is_inline or self.destroy_map != other.destroy_map or len(self.shared_inputs) != len(other.shared_inputs) @@ -518,7 +520,7 @@ def __eq__(self, other): ) def __hash__(self): - return hash((type(self), self._frozen_fgraph, self.is_inline)) + return hash((type(self), self.fgraph, self.is_inline)) def __str__(self): name = self.__class__.__name__ if self.name is None else self.name @@ -577,8 +579,13 @@ def _build_and_cache_lop_op( except KeyError: pass - inner_inputs = self.inner_inputs - inner_outputs = self.inner_outputs + # Differentiate a thawed copy of the inner graph so ``grad`` walks + # mutable ``Apply`` nodes rather than the immutable ``FrozenApply`` nodes + # of ``self.fgraph`` (whose tuple inputs/outputs break Ops that + # concatenate them, e.g. ``Blockwise.pullback``). + unfrozen_fgraph = self.fgraph.unfreeze() + inner_inputs = list(unfrozen_fgraph.inputs) + inner_outputs = list(unfrozen_fgraph.outputs) nin = len(inner_inputs) nout = len(inner_outputs) pullback_overrides = self.pullback_overrides @@ -701,8 +708,10 @@ def _build_and_cache_rop_op(self): if self._rop_op_cache is not None: return self._rop_op_cache - inner_inputs = self.inner_inputs - inner_outputs = self.inner_outputs + # Thaw the inner graph before differentiating (see ``_build_and_cache_lop_op``). 
+ unfrozen_fgraph = self.fgraph.unfreeze() + inner_inputs = list(unfrozen_fgraph.inputs) + inner_outputs = list(unfrozen_fgraph.outputs) nout = len(inner_outputs) pushforward_overrides = self.pushforward_overrides @@ -956,16 +965,19 @@ def fn(self): @property def inner_inputs(self): - return self.fgraph.inputs + # A list (not the frozen tuple) so callers that concatenate inner + # inputs/outputs keep list semantics. Read-only views of the immutable + # graph; manipulating them requires a fresh/unfrozen graph. + return list(self.fgraph.inputs) @property def inner_outputs(self): - return self.fgraph.outputs + return list(self.fgraph.outputs) def clone(self): - res = copy(self) - res.fgraph = res.fgraph.clone(clone_inner_graphs=True) - return res + # The inner graph is immutable (a frozen ``FunctionGraph``), so there is + # nothing to deep-clone -- mirror ``Composite.clone``. + return self def clone_with_inner_graph(self, inner_fgraph: FunctionGraph) -> OpFromGraph: """Return a copy of this op whose inner graph is ``inner_fgraph``. @@ -977,10 +989,10 @@ def clone_with_inner_graph(self, inner_fgraph: FunctionGraph) -> OpFromGraph: rebuilt. """ new = copy(self) - new.fgraph, new.shared_inputs, _, _ = construct_nominal_fgraph( + new_fgraph, new.shared_inputs, _, _ = construct_nominal_fgraph( list(inner_fgraph.inputs), list(inner_fgraph.outputs) ) - new._frozen_fgraph = new.fgraph.freeze(dedup_nodes=False) + new.fgraph = new_fgraph.freeze(dedup_nodes=False) new.input_types = [inp.type for inp in new.fgraph.inputs] new.output_types = [out.type for out in new.fgraph.outputs] # Drop caches tied to the previous inner graph. 
diff --git a/pytensor/compile/inner_graph_rewriting.py b/pytensor/compile/inner_graph_rewriting.py index 7d7f65bcc7..fd774f2d8c 100644 --- a/pytensor/compile/inner_graph_rewriting.py +++ b/pytensor/compile/inner_graph_rewriting.py @@ -110,7 +110,7 @@ def apply(self, fgraph): for nodes in groups.values(): rep_node = nodes[0] op = rep_node.op - inner = op._frozen_fgraph.unfreeze() + inner = op.fgraph.unfreeze() inner._compile_mode = mode custom_mode = getattr(op, "mode", None) diff --git a/pytensor/link/numba/dispatch/compile_ops.py b/pytensor/link/numba/dispatch/compile_ops.py index 724d9de6c1..5036ecc189 100644 --- a/pytensor/link/numba/dispatch/compile_ops.py +++ b/pytensor/link/numba/dispatch/compile_ops.py @@ -64,7 +64,7 @@ def numba_funcify_OpFromGraph( # ``optimize_inner_graphs`` rewrite. We unfreeze a throwaway copy and only # set up the link-time aliasing machinery (supervisor/DestroyHandler for # ``insert_deepcopy``). - fgraph = op._frozen_fgraph.unfreeze() + fgraph = op.fgraph.unfreeze() input_specs = [In(x, borrow=True, mutable=False) for x in fgraph.inputs] add_supervisor_to_fgraph( fgraph=fgraph, diff --git a/pytensor/link/numba/dispatch/scan.py b/pytensor/link/numba/dispatch/scan.py index 051bad0a70..f83bfdea22 100644 --- a/pytensor/link/numba/dispatch/scan.py +++ b/pytensor/link/numba/dispatch/scan.py @@ -60,7 +60,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): # set up the link-time aliasing machinery: the supervisor/DestroyHandler # (so ``destroyers``/``insert_deepcopy`` can reason about the baked inplace) # and the output deepcopies. 
- fgraph = op._frozen_fgraph.unfreeze() + fgraph = op.fgraph.unfreeze() destroyable = op.inner_destroyable_inputs(node.inputs, fgraph.inputs) input_specs = [In(x, borrow=True, mutable=x in destroyable) for x in fgraph.inputs] add_supervisor_to_fgraph( diff --git a/pytensor/printing.py b/pytensor/printing.py index cd9b9bab1f..3c27fc7bde 100644 --- a/pytensor/printing.py +++ b/pytensor/printing.py @@ -853,22 +853,15 @@ def _show_inner_graph(op): continue else: printed_inner_graph_ops.add(ig_var.owner.op) - # This is a work-around to maintain backward compatibility - # (e.g. to only print inner graphs that have been compiled through - # a call to `Op.prepare_node`) - inner_fn = getattr(ig_var.owner.op, "_fn", None) - - if inner_fn: - # If the op was compiled, print the optimized version. - inner_inputs = inner_fn.maker.fgraph.inputs - inner_outputs = inner_fn.maker.fgraph.outputs + # The canonical op already carries its optimized inner graph + # (``optimize_inner_graphs`` bakes it in), rooted on NominalVariables; + # print that directly. 
+ if hasattr(ig_var.owner.op, "scalar_op"): + inner_inputs = ig_var.owner.op.scalar_op.inner_inputs + inner_outputs = ig_var.owner.op.scalar_op.inner_outputs else: - if hasattr(ig_var.owner.op, "scalar_op"): - inner_inputs = ig_var.owner.op.scalar_op.inner_inputs - inner_outputs = ig_var.owner.op.scalar_op.inner_outputs - else: - inner_inputs = ig_var.owner.op.inner_inputs - inner_outputs = ig_var.owner.op.inner_outputs + inner_inputs = ig_var.owner.op.inner_inputs + inner_outputs = ig_var.owner.op.inner_outputs outer_inputs = ig_var.owner.inputs @@ -1319,17 +1312,12 @@ def _assign_color(node_key) -> str: continue printed.add(ig_var.owner) - inner_fn = getattr(ig_var.owner.op, "_fn", None) - if inner_fn: - inner_inputs = inner_fn.maker.fgraph.inputs - inner_outputs = inner_fn.maker.fgraph.outputs + if hasattr(ig_var.owner.op, "scalar_op"): + inner_inputs = ig_var.owner.op.scalar_op.inner_inputs + inner_outputs = ig_var.owner.op.scalar_op.inner_outputs else: - if hasattr(ig_var.owner.op, "scalar_op"): - inner_inputs = ig_var.owner.op.scalar_op.inner_inputs - inner_outputs = ig_var.owner.op.scalar_op.inner_outputs - else: - inner_inputs = ig_var.owner.op.inner_inputs - inner_outputs = ig_var.owner.op.inner_outputs + inner_inputs = ig_var.owner.op.inner_inputs + inner_outputs = ig_var.owner.op.inner_outputs outer_inputs = ig_var.owner.inputs inner_to_outer: dict[Variable, Variable] | None diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index 99a817a7ec..1a82282612 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -48,7 +48,7 @@ import time import warnings from collections.abc import Callable, Iterable -from copy import copy, deepcopy +from copy import deepcopy from itertools import chain, product import numpy as np @@ -75,7 +75,7 @@ Variable, ) from pytensor.graph.features import NoOutputFromInplace -from pytensor.graph.fg import FunctionGraph +from pytensor.graph.fg import FrozenFunctionGraph, FunctionGraph from pytensor.graph.op import 
HasInnerGraph, Op, io_connection_pattern from pytensor.graph.replace import clone_replace from pytensor.graph.traversal import graph_inputs @@ -950,13 +950,21 @@ def __init__( If ``True``, all the shared variables used in the inner-graph must be provided. """ - self.fgraph, shared_inputs, _, _ = construct_nominal_fgraph(inputs, outputs) + inner_fgraph, shared_inputs, _, _ = construct_nominal_fgraph(inputs, outputs) # The shared variables should have been removed, so, if there are # any, it's because the user didn't specify an input. if shared_inputs: raise MissingInputError(f"Scan is missing inputs: {shared_inputs}") + # Keep only the immutable (frozen) inner graph as ``op.fgraph``; the + # mutable copy is transient, so the canonical inner graph can never be + # mutated in place. Inner graphs may carry inplace ops (baked by + # ``optimize_inner_graphs``) whose destroyed buffers must stay distinct; + # ``dedup_nodes=False`` keeps structural folding from aliasing them. + # See ``FunctionGraph.freeze``. + self.fgraph = inner_fgraph.freeze(dedup_nodes=False) + self.info = info self.truncate_gradient = truncate_gradient self.name = name @@ -1053,21 +1061,13 @@ def tensorConstructor(shape, dtype): self.n_outer_inputs = info.n_outer_inputs self.n_outer_outputs = info.n_outer_outputs - # NOTE: Inner graphs are allowed to contain in-place operations. They - # are introduced by the `optimize_inner_graphs` rewrite (which has the - # outer-node context needed to decide what is safe to destroy) and - # baked into the frozen inner graph. The destroy-handler / inplace - # machinery validates correctness when the inner graph is built. - - # ``dedup_nodes=False``: inner graphs may carry inplace ops (baked by - # ``optimize_inner_graphs``) whose destroyed buffers must stay distinct; - # structural folding would alias them. See ``FunctionGraph.freeze``. 
- self._frozen_fgraph = self.fgraph.freeze(dedup_nodes=False) - def __setstate__(self, d): self.__dict__.update(d) - if not hasattr(self, "_frozen_fgraph"): - self._frozen_fgraph = self.fgraph.freeze(dedup_nodes=False) + # Back-compat: older pickles stored a mutable inner ``fgraph`` (plus a + # separate ``_frozen_fgraph``). Collapse to the single frozen graph. + if not isinstance(self.fgraph, FrozenFunctionGraph): + self.fgraph = self.fgraph.freeze(dedup_nodes=False) + self.__dict__.pop("_frozen_fgraph", None) # Ensure that the graph associated with the inner function is valid. self.validate_inner_graph() @@ -1446,7 +1446,7 @@ def __eq__(self, other): if self.allow_gc != other.allow_gc: return False - return self._frozen_fgraph == other._frozen_fgraph + return self.fgraph == other.fgraph def __str__(self): inplace = "none" @@ -1466,7 +1466,7 @@ def __hash__(self): return hash( ( type(self), - self._frozen_fgraph, + self.fgraph, self.info, self.profile, self.truncate_gradient, @@ -1582,7 +1582,7 @@ def fn(self): # Compile a throwaway copy of the (already math-optimized) inner graph. # The canonical inner graph is immutable; linking setup (MIT-MOT update # wrapping, supervisor) and any inplace happen on this transient. - inner_fgraph = self._frozen_fgraph.unfreeze() + inner_fgraph = self.fgraph.unfreeze() wrapped_inputs, wrapped_outputs = self.prepare_fgraph(inner_fgraph) profile = None @@ -1622,7 +1622,10 @@ def fn(self): wrapped_inputs, wrapped_outputs, mode=mode_instance, - accept_inplace=False, + # The (already-optimized) inner graph may carry inplace ops baked in + # by optimize_inner_graphs; prepare_fgraph has already attached the + # DestroyHandler + Supervisor, so accept them here. 
+ accept_inplace=True, profile=profile, on_unused_input="ignore", fgraph=inner_fgraph, @@ -1632,16 +1635,19 @@ def fn(self): @property def inner_inputs(self): - return self.fgraph.inputs + # A list (not the frozen tuple) so the many ``inner_*`` slicing helpers + # and their callers keep list semantics. These are read-only views of the + # immutable graph; rewrites that rebuild a Scan must ``unfreeze`` first. + return list(self.fgraph.inputs) @property def inner_outputs(self): - return self.fgraph.outputs + return list(self.fgraph.outputs) def clone(self) -> "Scan": - res = copy(self) - res.fgraph = res.fgraph.clone(clone_inner_graphs=True) # type: ignore[attr-defined] - return res + # The inner graph is immutable (a frozen ``FunctionGraph``), so there is + # nothing to deep-clone -- mirror ``Composite.clone``. + return self def clone_with_inner_graph(self, inner_fgraph) -> "Scan": """Return a new `Scan` whose inner graph is ``inner_fgraph``. @@ -2645,8 +2651,13 @@ def pullback(self, inputs, outs, dC_douts): if self.truncate_gradient != -1: grad_steps = minimum(grad_steps, self.truncate_gradient) - self_inputs = self.inner_inputs - self_outputs = self.inner_outputs + # Differentiate a thawed copy of the inner graph so ``grad`` walks + # mutable ``Apply`` nodes rather than the immutable ``FrozenApply`` nodes + # of ``self.fgraph`` (whose tuple inputs/outputs break Ops that + # concatenate them). + unfrozen_fgraph = self.fgraph.unfreeze() + self_inputs = list(unfrozen_fgraph.inputs) + self_outputs = list(unfrozen_fgraph.outputs) # differentiable inputs diff_inputs = ( self.inner_seqs(self_inputs) @@ -3402,12 +3413,14 @@ def compute_all_gradients(known_grads): def pushforward(self, inputs, outputs, eval_points): # Step 0. Prepare some shortcut variable info = self.info - self_inputs = self.inner_inputs + # Thaw the inner graph before differentiating (see ``L_op``). 
+ unfrozen_fgraph = self.fgraph.unfreeze() + self_inputs = list(unfrozen_fgraph.inputs) + self_outputs = list(unfrozen_fgraph.outputs) rop_of_inputs = ( self_inputs[: info.n_seqs + self.n_tap_outs] + self_inputs[info.n_seqs + self.n_tap_outs + info.n_untraced_sit_sot :] ) - self_outputs = self.inner_outputs # Step 1. Compute the R_op of the inner function inner_eval_points = [safe_new(x, "_evalpoint") for x in rop_of_inputs] @@ -3682,20 +3695,11 @@ def _op_debug_information_Scan(op: Scan, node: Apply): extra_information = {} - inner_fn = getattr(op, "_fn", None) - - if inner_fn: - inner_inputs = inner_fn.maker.fgraph.inputs - inner_outputs = inner_fn.maker.fgraph.outputs - else: - inner_inputs = op.inner_inputs - inner_outputs = op.inner_outputs - scan_args = ScanArgs( node.inputs, node.outputs, - inner_inputs, - inner_outputs, + op.inner_inputs, + op.inner_outputs, node.op.info, clone=False, ) diff --git a/pytensor/scan/rewriting/inplace.py b/pytensor/scan/rewriting/inplace.py index 4839d4348f..ee44eb2ba4 100644 --- a/pytensor/scan/rewriting/inplace.py +++ b/pytensor/scan/rewriting/inplace.py @@ -6,6 +6,7 @@ have if it understood Scan's input categories. """ +from copy import copy from itertools import chain from pytensor.compile.ops import deep_copy_op @@ -85,7 +86,10 @@ def attempt_scan_inplace( inputs = ls_begin + ls + ls_end - new_op = op.clone() + # Shallow-copy the op so we can give it its own ``destroy_map`` without + # mutating the canonical op; the frozen inner graph is immutable and + # safely shared. (``op.clone()`` returns ``self`` for immutable ops.) 
+ new_op = copy(op) destroy_map = op.destroy_map.copy() for out_idx in output_indices: diff --git a/pytensor/scan/rewriting/utils.py b/pytensor/scan/rewriting/utils.py index 288f55e607..f96dda5554 100644 --- a/pytensor/scan/rewriting/utils.py +++ b/pytensor/scan/rewriting/utils.py @@ -57,12 +57,14 @@ def _rebuild_scan_with_new_signature( n_non_seqs=len(keep_non_seqs), ) - inner_seqs = op.inner_seqs(op.inner_inputs) - inner_mm_groups = op.inner_mitmot_grouped(op.inner_inputs) - inner_ms_groups = op.inner_mitsot_grouped(op.inner_inputs) - inner_ss = op.inner_sitsot(op.inner_inputs) - inner_us = op.inner_untraced_sit_sot(op.inner_inputs) - inner_non_seqs = op.inner_non_seqs(op.inner_inputs) + inner_inputs = op.inner_inputs + + inner_seqs = op.inner_seqs(inner_inputs) + inner_mm_groups = op.inner_mitmot_grouped(inner_inputs) + inner_ms_groups = op.inner_mitsot_grouped(inner_inputs) + inner_ss = op.inner_sitsot(inner_inputs) + inner_us = op.inner_untraced_sit_sot(inner_inputs) + inner_non_seqs = op.inner_non_seqs(inner_inputs) new_inner_inputs = ( [inner_seqs[k] for k in keep_seqs] diff --git a/pytensor/tensor/rewriting/linalg/solvers.py b/pytensor/tensor/rewriting/linalg/solvers.py index 33a907b1da..4b2447f111 100644 --- a/pytensor/tensor/rewriting/linalg/solvers.py +++ b/pytensor/tensor/rewriting/linalg/solvers.py @@ -1,11 +1,11 @@ from collections.abc import Container -from copy import copy from pytensor import tensor as pt from pytensor.assumptions import DIAGONAL, ORTHOGONAL, check_assumption from pytensor.assumptions.positive_definite import POSITIVE_DEFINITE from pytensor.compile import optdb from pytensor.graph import Constant, graph_inputs +from pytensor.graph.fg import FrozenFunctionGraph from pytensor.graph.rewriting.basic import ( copy_stack_trace, dfs_rewriter, @@ -564,8 +564,11 @@ def _scan_split_non_sequence_decomposition_and_solve( The LU decomposition step can then be pushed out of the inner loop by the `scan_pushout_non_sequences` rewrite. 
""" scan_op: Scan = node.op - non_sequences = set(scan_op.inner_non_seqs(scan_op.inner_inputs)) - new_scan_fgraph = scan_op.fgraph + # ``scan_op.fgraph`` is immutable; work on a mutable transient and recompute + # non-sequences against its inputs so the toposort match below lines up. + assert isinstance(scan_op.fgraph, FrozenFunctionGraph) + new_scan_fgraph = scan_op.fgraph.unfreeze() + non_sequences = set(scan_op.inner_non_seqs(new_scan_fgraph.inputs)) changed = False while True: @@ -578,14 +581,6 @@ def _scan_split_non_sequence_decomposition_and_solve( (isinstance(root_inp, Constant) or (root_inp in non_sequences)) for root_inp in graph_inputs([A]) ): - if new_scan_fgraph is scan_op.fgraph: - # Clone the first time to avoid mutating the original fgraph - new_scan_fgraph, equiv = new_scan_fgraph.clone_get_equiv() # type: ignore[attr-defined] - non_sequences = { - equiv[non_seq] for non_seq in non_sequences - } - inner_node = equiv[inner_node] - replace_dict = _split_decomp_and_solve_steps( new_scan_fgraph, inner_node, @@ -595,7 +590,7 @@ def _scan_split_non_sequence_decomposition_and_solve( assert ( isinstance(replace_dict, dict) and len(replace_dict) > 0 ), "Rewrite failed" - new_scan_fgraph.replace_all(replace_dict.items()) # type: ignore[attr-defined] + new_scan_fgraph.replace_all(replace_dict.items()) changed = True break # Break to start over with a fresh toposort else: # no_break @@ -605,8 +600,7 @@ def _scan_split_non_sequence_decomposition_and_solve( return # Return a new scan to indicate that a rewrite was done - new_scan_op = copy(scan_op) - new_scan_op.fgraph = new_scan_fgraph + new_scan_op = scan_op.clone_with_inner_graph(new_scan_fgraph) new_outs = new_scan_op.make_node(*node.inputs).outputs copy_stack_trace(node.outputs, new_outs) return new_outs diff --git a/pytensor/tensor/rewriting/ofg.py b/pytensor/tensor/rewriting/ofg.py index 1cd914d343..b806aefcd0 100644 --- a/pytensor/tensor/rewriting/ofg.py +++ b/pytensor/tensor/rewriting/ofg.py @@ -9,7 +9,7 
@@ def inline_ofg_node(node: Apply) -> list[Variable]: - frozen_fg: FrozenFunctionGraph = node.op._frozen_fgraph + frozen_fg: FrozenFunctionGraph = node.op.fgraph replacements = dict(zip(frozen_fg.inputs, node.inputs)) inlined_outs = frozen_fg.bind(replacements) copy_stack_trace(frozen_fg.outputs, inlined_outs) diff --git a/tests/compile/test_builders.py b/tests/compile/test_builders.py index 1a63f926b8..8cc467ba62 100644 --- a/tests/compile/test_builders.py +++ b/tests/compile/test_builders.py @@ -65,11 +65,9 @@ def test_clone(self): ofg = OpFromGraph([x], [2 * x]) - ofg_clone = ofg.clone() - - assert ofg_clone.fgraph is not ofg.fgraph - assert ofg_clone.fgraph.outputs != ofg.fgraph.outputs - assert equal_computations(ofg_clone.fgraph.outputs, ofg.fgraph.outputs) + # OpFromGraph is immutable (single frozen inner graph), so cloning + # returns self -- mirroring Composite. + assert ofg.clone() is ofg @pytest.mark.parametrize( "cls_ofg", [OpFromGraph, partial(OpFromGraph, inline=True)] diff --git a/tests/scan/test_basic.py b/tests/scan/test_basic.py index 1bbd183953..1ca2cbebff 100644 --- a/tests/scan/test_basic.py +++ b/tests/scan/test_basic.py @@ -27,7 +27,7 @@ from pytensor.compile.sharedvalue import shared from pytensor.configdefaults import config from pytensor.gradient import NullTypeGradError, disconnected_grad, grad, pushforward -from pytensor.graph.basic import Apply, Variable, equal_computations +from pytensor.graph.basic import Apply, Variable from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import Op from pytensor.graph.replace import vectorize_graph @@ -299,11 +299,9 @@ def test_clone(self): scan_op = output.owner.op assert isinstance(scan_op, Scan) - scan_op_clone = scan_op.clone() - assert scan_op_clone is not scan_op - assert scan_op_clone.fgraph is not scan_op.fgraph - assert scan_op_clone.fgraph.outputs != scan_op.fgraph.outputs - assert equal_computations(scan_op_clone.fgraph.outputs, scan_op.fgraph.outputs) + # Scan ops are 
immutable (single frozen inner graph), so cloning returns + # self -- mirroring Composite. + assert scan_op.clone() is scan_op @pytest.mark.skipif( isinstance(get_default_mode(), DebugMode), diff --git a/tests/test_printing.py b/tests/test_printing.py index 973c7c5888..55476ecb34 100644 --- a/tests/test_printing.py +++ b/tests/test_printing.py @@ -567,9 +567,10 @@ def add_one_composite(): for exp_line, res_line in zip(exp_res.split("\n"), lines, strict=True): assert exp_line.strip() == res_line.strip() - # An Op that only appears nested inside other inner graphs: its nodes are - # only discovered while the parent bodies are printed, and the shared - # header must still list every node id + # An Op that only appears nested inside other inner graphs is still + # discovered and printed in the "Inner graphs" section. Here both A and B + # apply the same Relu to their (nominal) input, so global FrozenApply + # interning collapses it to a single shared node, printed once. i1 = dvector("i") a_op = OpFromGraph([i1], [relu_ofg()(i1) + 1], inline=False, name="A") i2 = dvector("i") @@ -594,16 +595,16 @@ def add_one_composite(): B{inline=False} [id D] ← Mul [id J] - ├─ Relu{inline=False} [id K] - │ └─ i0 [id G] - └─ ExpandDims{axis=0} [id L] - └─ 2 [id M] + ├─ Relu{inline=False} [id F] + │ └─ ··· + └─ ExpandDims{axis=0} [id K] + └─ 2 [id L] -Relu{inline=False} [id F, K] - ← Maximum [id N] +Relu{inline=False} [id F] + ← Maximum [id M] ├─ i0 [id G] - └─ ExpandDims{axis=0} [id O] - └─ 0 [id P] + └─ ExpandDims{axis=0} [id N] + └─ 0 [id O] """ for exp_line, res_line in zip(exp_res.split("\n"), lines, strict=True): From 94162459f346db83ffcafd2eda8ce6f512421f48 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 19 Jun 2026 12:53:01 +0200 Subject: [PATCH 6/8] Freeze ScipyWrapperOp inner graphs and optimize them via the rewrite MinimizeScalarOp/MinimizeOp/RootScalarOp/RootOp built a mutable inner FunctionGraph and kept it as self.fgraph; that left them as the last inner-graph ops 
not collapsed to a single frozen graph, so they were excluded from optimize_inner_graphs. Build the inner graph (plus the jac/hess outputs) mutably in __init__ as before, then freeze(dedup_nodes=False) it -- the canonical op.fgraph is now an immutable FrozenFunctionGraph that can never be mutated in place. clone() returns self, inner_inputs/outputs return lists, and eq/hash key on the frozen graph plus the per-subclass _scipy_props (so e.g. RootOps of different equations stay distinct without a hand-maintained __props__). clone_with_inner_graph bakes an optimized inner graph into a new frozen op for the rewrite. Graph manipulation (build_fn, compute_implicit_gradients, the connection pattern) runs on a fresh unfreeze()d copy: graph_replace on frozen nodes would otherwise leak them into the outer graph. Re-add ScipyWrapperOp to the optimize_inner_graphs matcher (reverting the interim exclusion). It uses the no-inplace optimizer because its inner graph is not backend-funcified but recompiled by build_fn on the scipy-wrapped graph. --- pytensor/compile/inner_graph_rewriting.py | 13 +-- pytensor/tensor/optimize.py | 127 +++++++++++++++++----- 2 files changed, 106 insertions(+), 34 deletions(-) diff --git a/pytensor/compile/inner_graph_rewriting.py b/pytensor/compile/inner_graph_rewriting.py index fd774f2d8c..f956f462ce 100644 --- a/pytensor/compile/inner_graph_rewriting.py +++ b/pytensor/compile/inner_graph_rewriting.py @@ -55,13 +55,10 @@ def apply(self, fgraph): from pytensor.graph.features import NoOutputFromInplace from pytensor.link.basic import JITLinker from pytensor.scan.op import Scan + from pytensor.tensor.optimize import ScipyWrapperOp from pytensor.tensor.random.op import OpWithCoreShape - # NOTE: ``ScipyWrapperOp`` (MinimizeOp/RootOp) still keeps a *mutable* - # inner ``fgraph`` (it has no frozen graph yet), so it cannot be unfrozen - # here. Its inner graph is optimized lazily at link time by ``build_fn``. - # Add it back once it is made frozen. 
- inner_graph_types: tuple = (OpFromGraph, Scan) + inner_graph_types: tuple = (OpFromGraph, Scan, ScipyWrapperOp) # Group apply NODES by an optimization key, so each distinct inner graph # is optimized once and shared across its nodes (an op reused in many @@ -93,8 +90,10 @@ def apply(self, fgraph): # Inplace is baked into the frozen op (frozen ``dedup_nodes=False`` so # distinct inplace buffers survive interning). ``OpFromGraph`` must not # mutate its inputs, so all are protected; ``Scan`` protects all but the - # destroyable taps. ``ScipyWrapperOp`` keeps inplace excluded (separate - # work). + # destroyable taps. ``ScipyWrapperOp`` uses the no-inplace optimizer: + # its inner graph is not backend-funcified but recompiled by + # ``build_fn`` (which applies inplace itself on the scipy-wrapped graph), + # so baking inplace into the frozen op here would be pointless. # Exclude ``symbolic_op_recognition``: those rewrites fold a pattern into # an inner-graph op (e.g. ``exp(x) / sum(exp(x))`` -> ``Softmax``, itself # an ``OpFromGraph``). Running them on the inner graph -- which *is* that diff --git a/pytensor/tensor/optimize.py b/pytensor/tensor/optimize.py index 130cbf9ba2..639f0950fc 100644 --- a/pytensor/tensor/optimize.py +++ b/pytensor/tensor/optimize.py @@ -8,7 +8,7 @@ from pytensor.compile.maker import function from pytensor.gradient import DisconnectedType, grad, jacobian from pytensor.graph.basic import Apply, Constant -from pytensor.graph.fg import FunctionGraph +from pytensor.graph.fg import FrozenFunctionGraph, FunctionGraph from pytensor.graph.null_type import NullType from pytensor.graph.op import ( ComputeMapType, @@ -163,15 +163,26 @@ def _depends_only_on_constants(var: Variable) -> bool: class ScipyWrapperOp(Op, HasInnerGraph): - """Shared logic for scipy optimization ops""" + """Shared logic for scipy optimization ops. + + The inner graph is held frozen (immutable) as ``self.fgraph``, so the + canonical op can never be mutated in place. 
Graph manipulation (``grad`` / + ``graph_replace`` / ``function``) runs on a fresh mutable copy obtained via + ``self.fgraph.unfreeze()`` -- ``graph_replace`` on frozen nodes would + otherwise leak them into the outer graph. + """ + + # Attribute names (besides the frozen inner graph) that distinguish two ops + # of the same type for eq/hash. Subclasses override. + _scipy_props: tuple[str, ...] = () def build_fn(self): """ This is overloaded because scipy converts scalar inputs to lists, changing the return type. The wrapper function logic is there to handle this. """ - outputs = self.inner_outputs - self._fn = fn = function(self.inner_inputs, outputs, trust_input=True) + fgraph = self.fgraph.unfreeze() + self._fn = fn = function(fgraph.inputs, fgraph.outputs, trust_input=True) # Do this reassignment to see the compiled graph in the dprint # self.fgraph = fn.maker.fgraph @@ -192,22 +203,60 @@ def fn_wrapped(self): @property def inner_inputs(self): - return self.fgraph.inputs + # A list (not the frozen tuple) so callers that concatenate inner + # inputs/outputs keep list semantics. + return list(self.fgraph.inputs) @property def inner_outputs(self): - return self.fgraph.outputs + return list(self.fgraph.outputs) + + def _prop_values(self): + values = [] + for name in self._scipy_props: + value = getattr(self, name) + if isinstance(value, dict): + value = tuple(sorted(value.items())) + values.append(value) + return tuple(values) + + def __eq__(self, other): + if self is other: + return True + if type(self) is not type(other): + return False + return ( + self.fgraph == other.fgraph and self._prop_values() == other._prop_values() + ) + + def __hash__(self): + return hash((type(self), self.fgraph, self._prop_values())) + + def clone_with_inner_graph(self, inner_fgraph): + """Return a copy of this op whose inner graph is ``inner_fgraph``. 
- def clone_with_new_fgraph(self, fgraph): + Used by the ``optimize_inner_graphs`` rewrite to bake an + already-optimized inner graph into a NEW immutable op without touching + ``self``. ``inner_fgraph`` may be a mutable ``FunctionGraph`` (it is + frozen here) or an already-frozen graph. + """ clone_op = copy(self) clone_op._fn = None clone_op._fn_wrapped = None - clone_op.fgraph = fgraph + clone_op.fgraph = ( + inner_fgraph + if isinstance(inner_fgraph, FrozenFunctionGraph) + else inner_fgraph.freeze(dedup_nodes=False) + ) return clone_op + # Name used by the canonicalization rewrite that rebuilds the inner graph. + clone_with_new_fgraph = clone_with_inner_graph + def clone(self): - clone_fgraph = self.fgraph.clone(clone_inner_graphs=True) - return self.clone_with_new_fgraph(clone_fgraph) + # The inner graph is immutable (a frozen ``FunctionGraph``), so there is + # nothing to deep-clone -- mirror ``Composite``/``OpFromGraph``. + return self def prepare_node( self, @@ -233,11 +282,12 @@ class ScipyScalarWrapperOp(ScipyWrapperOp): def build_fn(self): # We need to adjust the graph to work with what scipy will be passing into the inner function -- # always scalar array of float64 type - x, *args = self.inner_inputs + fgraph = self.fgraph.unfreeze() + x, *args = fgraph.inputs new_root_x = ps.float64(name="x_scalar") new_x = tensor_from_scalar(new_root_x.astype(x.type.dtype)) - new_outputs = graph_replace(self.inner_outputs, {x: new_x}) + new_outputs = graph_replace(fgraph.outputs, {x: new_x}) self._fn = fn = function([new_root_x, *args], new_outputs, trust_input=True) @@ -270,9 +320,9 @@ def compute_implicit_gradients( Whether the optimization problem is a minimization problem. If False, it is assumed to be a root-finding problem. 
""" - fgraph = self.fgraph - inner_x, *inner_args = self.inner_inputs - inner_fx = self.inner_outputs[0] + fgraph = self.fgraph.unfreeze() + inner_x, *inner_args = fgraph.inputs + inner_fx = fgraph.outputs[0] if is_minimization: # The implicit function in minimization is grad(x, theta) == 0 @@ -323,13 +373,14 @@ class ScipyVectorWrapperOp(ScipyWrapperOp): def build_fn(self): # We need to adjust the graph to work with what scipy will be passing into the inner function -- # always a vector array with size of at least 1 - x, *args = self.inner_inputs - if x.type.shape != (): + if self.inner_inputs[0].type.shape != (): return super().build_fn() + fgraph = self.fgraph.unfreeze() + x, *args = fgraph.inputs new_root_x = x[None].type() new_x = new_root_x.squeeze() - new_outputs = graph_replace(self.inner_outputs, {x: new_x}) + new_outputs = graph_replace(fgraph.outputs, {x: new_x}) self._fn = fn = function([new_root_x, *args], new_outputs, trust_input=True) # Do this reassignment to see the compiled graph in the dprint @@ -383,9 +434,9 @@ def compute_implicit_gradients( problem, where `f` is the objective function. In this case, we instead take `f` to be the gradient of the objective function, which *is* indeed zero at the minimum. """ - fgraph = self.fgraph - inner_x, *inner_args = self.inner_inputs - implicit_f = self.inner_outputs[0] + fgraph = self.fgraph.unfreeze() + inner_x, *inner_args = fgraph.inputs + implicit_f = fgraph.outputs[0] if is_minimization: # The implicit function in minimization is grad(x, theta) == 0 implicit_f = grad(implicit_f, inner_x) @@ -484,6 +535,7 @@ def _optimizer_connection_pattern(fgraph, is_minimization): An input may be connected to the objective but disconnected from its gradient (e.g. an additive constant), so the connection pattern must reflect the actual implicit function. 
""" + fgraph = fgraph.unfreeze() inner_x = fgraph.inputs[0] fx = fgraph.outputs[0] if is_minimization: @@ -494,6 +546,8 @@ def _optimizer_connection_pattern(fgraph, is_minimization): class MinimizeScalarOp(ScipyScalarWrapperOp): + _scipy_props = ("method", "optimizer_kwargs") + def __init__( self, x: TensorVariable, @@ -514,7 +568,7 @@ def __init__( raise ValueError( "The variable `x` must be an input to the computational graph of the objective function." ) - self.fgraph = FunctionGraph([x, *args], [objective]) + self.fgraph = FunctionGraph([x, *args], [objective]).freeze(dedup_nodes=False) self.method = method self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {} @@ -611,6 +665,15 @@ def minimize_scalar( class MinimizeOp(ScipyVectorWrapperOp): + _scipy_props = ( + "method", + "jac", + "hess", + "hessp", + "use_vectorized_jac", + "optimizer_kwargs", + ) + def __init__( self, x: TensorVariable, @@ -651,6 +714,8 @@ def __init__( ) self.fgraph.add_output(hess_wrt_x) + self.fgraph = self.fgraph.freeze(dedup_nodes=False) + self.jac = jac self.hess = hess self.hessp = hessp @@ -813,6 +878,8 @@ def minimize( class RootScalarOp(ScipyScalarWrapperOp): + _scipy_props = ("method", "jac", "hess", "optimizer_kwargs") + def __init__( self, variables: TensorVariable, @@ -851,6 +918,8 @@ def __init__( f_double_prime = grad(self.fgraph.outputs[-1], self.fgraph.inputs[0]) self.fgraph.add_output(f_double_prime) + self.fgraph = self.fgraph.freeze(dedup_nodes=False) + self.method = method self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {} self.jac = jac @@ -965,9 +1034,10 @@ def root_scalar( class RootOp(ScipyVectorWrapperOp): - # These __props__ were wrong: they ignore the inner graph, - # making RootOps of different equations compare equal (and get merged) - # __props__ = ("method", "jac") + # eq/hash key on the (frozen) inner graph plus these props -- keying on the + # inner graph is what keeps RootOps of different equations 
distinct (an + # earlier ``__props__ = ("method", "jac")`` ignored it and merged them). + _scipy_props = ("method", "jac", "use_vectorized_jac", "optimizer_kwargs") def __init__( self, @@ -1003,6 +1073,8 @@ def __init__( ) self.fgraph.add_output(atleast_2d(jac_wrt_x)) + self.fgraph = self.fgraph.freeze(dedup_nodes=False) + self.jac = jac self.method = method @@ -1020,8 +1092,9 @@ def __str__(self): return f"{self.__class__.__name__}({str_args})" def build_fn(self): - outputs = self.inner_outputs - variables, *args = self.inner_inputs + fgraph = self.fgraph.unfreeze() + variables, *args = fgraph.inputs + outputs = fgraph.outputs if variables.ndim > 0: new_root_variables = variables From 639853333c2ad65969ec839e0127b294648df87b Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 18 Jun 2026 19:53:56 +0200 Subject: [PATCH 7/8] Remove the clone_inner_graph(s) machinery Cloning an outer graph used to optionally deep-clone the inner graphs of HasInnerGraph ops (clone_inner_graph(s)=True). Now that every inner-graph op (Scan/OpFromGraph/Composite/ScipyWrapperOp) is immutable -- clone() returns self -- that deep-clone is always a no-op, so drop the kwarg and the branch throughout the clone machinery: Apply.clone, Apply.clone_with_new_inputs, FrozenApply.clone, clone, clone_node_and_cache, clone_get_equiv, FunctionGraph.clone, and rebuild_collect_shared. clone_node_and_cache's op-clone caching (which only existed to reuse cloned inner-graph ops) is removed too. Updates test_clone_inner_graph to the new shared-Op contract. 
--- pytensor/compile/rebuild.py | 10 ------ pytensor/graph/basic.py | 71 ++++++------------------------------- pytensor/graph/fg.py | 8 ++--- tests/graph/test_basic.py | 9 +++-- tests/scalar/test_loop.py | 19 +++------- 5 files changed, 20 insertions(+), 97 deletions(-) diff --git a/pytensor/compile/rebuild.py b/pytensor/compile/rebuild.py index 29933a51e0..6de79468bc 100644 --- a/pytensor/compile/rebuild.py +++ b/pytensor/compile/rebuild.py @@ -29,7 +29,6 @@ def rebuild_collect_shared( rebuild_strict=True, copy_inputs_over=True, no_default_updates=False, - clone_inner_graphs=False, ) -> tuple[ list[Variable], Variable, @@ -51,7 +50,6 @@ def rebuild_collect_shared( rebuild_strict=True, copy_inputs_over=True, no_default_updates=False, - clone_inner_graphs=False, ) -> tuple[ list[Variable], list[Variable], @@ -73,7 +71,6 @@ def rebuild_collect_shared( rebuild_strict=True, copy_inputs_over=True, no_default_updates=False, - clone_inner_graphs=False, ) -> tuple[ list[Variable], Out, @@ -95,7 +92,6 @@ def rebuild_collect_shared( rebuild_strict=True, copy_inputs_over=True, no_default_updates=False, - clone_inner_graphs=False, ) -> tuple[ list[Variable], list[Out], @@ -116,7 +112,6 @@ def rebuild_collect_shared( rebuild_strict=True, copy_inputs_over=True, no_default_updates=False, - clone_inner_graphs=False, ) -> tuple[ list[Variable], list[Variable] | Variable | Out | list[Out], @@ -156,9 +151,6 @@ def rebuild_collect_shared( If False (default), perform them all. Else, perform automatic updates on all Variables that are neither in "updates" nor in "no_default_updates". - clone_inner_graphs : bool - If ``True``, clone `Op`\s that are subclasses of `HasInnerGraph` and their - inner-graphs. 
""" @@ -201,7 +193,6 @@ def clone_v_get_shared_updates(v, copy_inputs_over): owner, clone_d, strict=rebuild_strict, - clone_inner_graphs=clone_inner_graphs, ) clone_d.setdefault(var, var) continue @@ -481,7 +472,6 @@ def param_to_in(param, allow_downcast=None): rebuild_strict=rebuild_strict, copy_inputs_over=True, no_default_updates=no_default_updates, - clone_inner_graphs=True, ) input_variables, cloned_extended_outputs, other_stuff = output_vars clone_d, update_d, _update_expr, shared_inputs = other_stuff diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 1c6a33990a..cb61859d7a 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -226,38 +226,27 @@ def __str__(self): def __repr__(self): return str(self) - def clone(self, clone_inner_graph: bool = False) -> "Apply[OpType]": + def clone(self) -> "Apply[OpType]": r"""Clone this `Apply` instance. - Parameters - ---------- - clone_inner_graph - If ``True``, clone `HasInnerGraph` `Op`\s and their inner-graphs. - Returns ------- A new `Apply` instance with new outputs. Notes ----- - Tags are copied from `self` to the returned instance. + Tags are copied from `self` to the returned instance. Inner-graph `Op`\s + are immutable, so the `Op` is shared rather than deep-cloned. """ - from pytensor.graph.op import HasInnerGraph - - new_op = self.op - - if isinstance(new_op, HasInnerGraph) and clone_inner_graph: # type: ignore - new_op = new_op.clone() # type: ignore - cp = self.__class__( - new_op, self.inputs, [output.clone() for output in self.outputs] + self.op, self.inputs, [output.clone() for output in self.outputs] ) cp.tag = copy(self.tag) return cp def clone_with_new_inputs( - self, inputs: Sequence["Variable"], strict=True, clone_inner_graph=False + self, inputs: Sequence["Variable"], strict=True ) -> "Apply[OpType]": r"""Duplicate this `Apply` instance in a new graph. @@ -274,8 +263,6 @@ def clone_with_new_inputs( ``self.outputs``. 
If ``False``, then there's no guarantee that the clone's outputs will have the same types as ``self.outputs``, and cloning may not even be possible (it depends on the `Op`). - clone_inner_graph : bool - If ``True``, clone `HasInnerGraph` `Op`\s and their inner-graphs. Returns ------- @@ -283,8 +270,6 @@ def clone_with_new_inputs( An `Apply` instance with the same `Op` but different outputs. """ - from pytensor.graph.op import HasInnerGraph - assert isinstance(inputs, list | tuple) remake_node = False new_inputs: list[Variable] = list(inputs) @@ -310,15 +295,10 @@ def clone_with_new_inputs( remake_node = True if remake_node: - new_op = self.op - - if isinstance(new_op, HasInnerGraph) and clone_inner_graph: # type: ignore - new_op = new_op.clone() # type: ignore - - new_node = new_op.make_node(*new_inputs) + new_node = self.op.make_node(*new_inputs) new_node.tag = copy(self.tag).__update__(new_node.tag) else: - new_node = self.clone(clone_inner_graph=clone_inner_graph) + new_node = self.clone() if new_node is self: # Immutable nodes (e.g. ``FrozenApply``) return ``self`` from # ``clone()``; mutating ``inputs`` would corrupt the shared node. @@ -896,7 +876,7 @@ def __init__(self, op, inputs, output_types, unique_idx=None): # All initialization is done in __new__ pass - def clone(self, clone_inner_graph: bool = False) -> Self: + def clone(self) -> Self: """Frozen nodes are immutable — cloning returns self.""" return self @@ -917,7 +897,6 @@ def clone( outputs: Sequence[Variable], copy_inputs: bool = True, copy_orphans: bool | None = None, - clone_inner_graphs: bool = False, ) -> tuple[list[Variable], list[Variable]]: r"""Copies the sub-graph contained between inputs and outputs. @@ -933,9 +912,6 @@ def clone( When ``None``, use the `copy_inputs` value. When ``True``, new orphans nodes are created. When ``False``, original orphans nodes are reused in the new graph. - clone_inner_graphs : bool - If ``True``, clone `HasInnerGraph` `Op`\s and their inner-graphs. 
- Returns ------- The inputs and outputs of that copy. @@ -955,7 +931,6 @@ def clone( outputs, copy_inputs=copy_inputs, copy_orphans=copy_orphans, - clone_inner_graphs=clone_inner_graphs, ) return [cast(Variable, equiv[input]) for input in inputs], [ cast(Variable, equiv[output]) for output in outputs @@ -965,14 +940,10 @@ def clone( def clone_node_and_cache( node: Apply, clone_d: dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]], - clone_inner_graphs=False, **kwargs, ) -> Apply | None: """Clone an `Apply` node and cache the results in `clone_d`. - This function handles `Op` clones that are generated by inner-graph - cloning. - Returns ------- ``None`` if all of `node`'s outputs are already in `clone_d`; otherwise, @@ -984,29 +955,12 @@ def clone_node_and_cache( # `clone_d`, then there's likely no need to clone it return None - # Use a cached `Op` clone when available - new_op: Op | None = cast(Optional["Op"], clone_d.get(node.op)) - cloned_inputs: list[Variable] = [cast(Variable, clone_d[i]) for i in node.inputs] - new_node = node.clone_with_new_inputs( - cloned_inputs, - # Only clone inner-graph `Op`s when there isn't a cached clone (and - # when `clone_inner_graphs` is enabled) - clone_inner_graph=clone_inner_graphs if new_op is None else False, - **kwargs, - ) - - if new_op: - # If we didn't clone the inner-graph `Op` above, because - # there was a cached version, set the cloned `Apply` to use - # the cached clone `Op` - new_node.op = new_op + new_node = node.clone_with_new_inputs(cloned_inputs, **kwargs) clone_d[node] = new_node - clone_d.setdefault(node.op, new_node.op) - for old_o, new_o in zip(node.outputs, new_node.outputs, strict=True): clone_d.setdefault(old_o, new_o) @@ -1020,7 +974,6 @@ def clone_get_equiv( copy_orphans: bool = True, memo: dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]] | None = None, - clone_inner_graphs: bool = False, **kwargs, ) -> dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]]: 
r"""Clone the graph between `inputs` and `outputs` and return a map of the cloned objects. @@ -1050,8 +1003,6 @@ def clone_get_equiv( Optionally start with a partly-filled dictionary for the return value. If a dictionary is passed, this function will work in-place on that dictionary and return it. - clone_inner_graphs - If ``True``, clone `HasInnerGraph` `Op`\s and their inner-graphs. kwargs Keywords passed to `Apply.clone_with_new_inputs`. @@ -1081,9 +1032,7 @@ def clone_get_equiv( else: memo[input] = input - clone_node_and_cache( - apply, memo, clone_inner_graphs=clone_inner_graphs, **kwargs - ) + clone_node_and_cache(apply, memo, **kwargs) # finish up by cloning any remaining outputs (it can happen) for output in outputs: diff --git a/pytensor/graph/fg.py b/pytensor/graph/fg.py index 8314bb5b5e..b40d7231fc 100644 --- a/pytensor/graph/fg.py +++ b/pytensor/graph/fg.py @@ -848,13 +848,9 @@ def check_integrity(self) -> None: def __repr__(self): return f"FunctionGraph({', '.join(graph_as_string(self.inputs, self.outputs))})" - def clone( - self, check_integrity=True, clone_inner_graphs: bool = False - ) -> "FunctionGraph": + def clone(self, check_integrity=True) -> "FunctionGraph": """Clone the graph.""" - return self.clone_get_equiv( - check_integrity, clone_inner_graphs=clone_inner_graphs - )[0] + return self.clone_get_equiv(check_integrity)[0] def clone_get_equiv( self, check_integrity: bool = True, attach_feature: bool = True, **kwargs diff --git a/tests/graph/test_basic.py b/tests/graph/test_basic.py index ce2a5d5429..95c97dd2ec 100644 --- a/tests/graph/test_basic.py +++ b/tests/graph/test_basic.py @@ -186,13 +186,12 @@ def test_clone_inner_graph(self): o2.name = "o1" o2_node = o2.owner - o2_node_clone = o2_node.clone(clone_inner_graph=True) + o2_node_clone = o2_node.clone() + # Inner-graph Ops are immutable, so cloning a node shares the Op (and its + # inner graph) rather than deep-cloning it. 
assert o2_node_clone is not o2_node - assert o2_node_clone.op.fgraph is not o2_node.op.fgraph - assert equal_computations( - o2_node_clone.op.fgraph.outputs, o2_node.op.fgraph.outputs - ) + assert o2_node_clone.op is o2_node.op class TestEval: diff --git a/tests/scalar/test_loop.py b/tests/scalar/test_loop.py index 9fabfda6b8..a7e68bc513 100644 --- a/tests/scalar/test_loop.py +++ b/tests/scalar/test_loop.py @@ -313,18 +313,7 @@ def test_identical_loops_share_inner_graph(): assert hash(op1) == hash(op2) assert op1.fgraph == op2.fgraph - # Two loops with the same structure but different outer inputs. - # MergeOptimizer can't collapse the Apply nodes (different inputs), - # but both should reference the same inner Op after merging. - n = int64("n") - a, b, c_val, d = float64("a"), float64("b"), float64("c_val"), float64("d") - y1 = op1(n, a, b) - y2 = op2(n, c_val, d) - - fn = function( - [n, a, b, c_val, d], [y1, y2], mode=Mode(optimizer="merge", linker="py") - ) - nodes = fn.maker.fgraph.toposort() - loop_nodes = [nd for nd in nodes if isinstance(nd.op, ScalarLoop)] - assert len(loop_nodes) == 2 - assert loop_nodes[0].op is loop_nodes[1].op + # Structurally identical inner graphs are globally interned via FrozenApply, + # so the two distinct op wrappers share the very same inner-graph nodes (the + # heavy state) at construction -- no compilation or canonicalization needed. + assert op1.fgraph.outputs[0].owner is op2.fgraph.outputs[0].owner From de06c057a01a25e08374be73914c6b0fab388f52 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 25 Jun 2026 16:19:54 +0200 Subject: [PATCH 8/8] Decouple FrozenApply from Apply; rebuild frozen inner graphs explicitly FrozenApply subclassed Apply and relied on a clone_with_new_inputs band-aid that silently rebuilt frozen nodes as mutable Applys -- error-prone, and it let frozen nodes masquerade as mutable ones. 
Introduce AbstractApply, a never-instantiated ancestor holding the read-only structural API shared by the mutable Apply and the immutable, interned FrozenApply. FrozenApply no longer subclasses Apply and has no clone/clone_with_new_inputs, so it can never be cloned by accident. The isinstance(Apply) sites that walk graphs generically (traversal, printing, FunctionGraph membership, the Variable.owner check) accept AbstractApply; outer- and op-specific sites stay Apply. Inner-graph manipulation that needs a mutable copy now does so explicitly, dropping the band-aid in Apply.clone_with_new_inputs: - rebuild_mutable reconstructs (rather than clones) a graph's nodes, treating the replace keys as the boundary, so it works whether the graph is frozen or mutable. - construct_nominal_fgraph uses it, so op construction rebuilds a frozen inner graph in a single pass; inner-graph rewrites pass frozen slices straight to Scan/OpFromGraph instead of unfreezing just to have construction re-freeze. - The scan rewrites (reconstruct_graph, push_out, trace, io _rebuild_scan), ScipyWrapperOp's de-dup rewrite, and OpFromGraph's shared-input rebuild apply their substitutions via rebuild_mutable; Scan.infer_shape's validator and Composite/ScalarLoop.make_node rebuild via op.make_node / unfreeze. - OpFromGraph.fn compiles an unfreeze()d copy (the canonical inner graph is immutable). 
--- pytensor/compile/builders.py | 53 ++++--- pytensor/compile/inner_graph_rewriting.py | 4 +- pytensor/d3viz/formatting.py | 4 +- pytensor/graph/basic.py | 160 ++++++++++++---------- pytensor/graph/fg.py | 11 +- pytensor/graph/replace.py | 35 +++++ pytensor/graph/traversal.py | 11 +- pytensor/printing.py | 6 +- pytensor/scalar/basic.py | 6 +- pytensor/scalar/loop.py | 6 +- pytensor/scan/op.py | 10 +- pytensor/scan/rewriting/push_out.py | 10 +- pytensor/scan/rewriting/trace.py | 4 +- pytensor/scan/rewriting/utils.py | 13 +- pytensor/scan/utils.py | 14 +- pytensor/tensor/optimize.py | 10 +- pytensor/tensor/rewriting/optimize.py | 4 +- tests/compile/test_builders.py | 10 +- 18 files changed, 216 insertions(+), 155 deletions(-) diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py index 43a2ef6551..17943418ee 100644 --- a/pytensor/compile/builders.py +++ b/pytensor/compile/builders.py @@ -7,11 +7,9 @@ from collections.abc import Callable, Sequence from copy import copy from functools import partial -from typing import cast from pytensor.compile.maker import function from pytensor.compile.mode import get_mode -from pytensor.compile.rebuild import rebuild_collect_shared from pytensor.compile.sharedvalue import SharedVariable from pytensor.gradient import DisconnectedType, disconnected_type, grad, pushforward from pytensor.graph.basic import ( @@ -23,7 +21,7 @@ from pytensor.graph.fg import FrozenFunctionGraph, FunctionGraph from pytensor.graph.null_type import NullType from pytensor.graph.op import HasInnerGraph, Op, io_connection_pattern -from pytensor.graph.replace import clone_replace, graph_replace +from pytensor.graph.replace import graph_replace, rebuild_mutable from pytensor.graph.traversal import graph_inputs from pytensor.graph.utils import MissingInputError from pytensor.tensor.shape import Shape_i @@ -101,23 +99,17 @@ def construct_nominal_fgraph( ) ) - new = rebuild_collect_shared( - cast(Sequence[Variable], outputs), - inputs=inputs + 
implicit_shared_inputs, - replace=replacements, - copy_inputs_over=False, - ) - ( - local_inputs, - local_outputs, - (_clone_d, update_d, update_expr, new_shared_inputs), - ) = new + # Rebuild ``outputs`` rooted at the dummy (nominal) inputs. ``rebuild_mutable`` + # reconstructs every node, so this works whether the inner graph is mutable or + # still references immutable ``FrozenApply`` nodes (when an inner-graph rewrite + # assembled this op from a frozen inner graph) -- the rewrite need not unfreeze + # just to have construction re-freeze. Shared inputs were gathered above and + # are part of ``replacements``. + local_inputs = dummy_inputs + dummy_implicit_shared_inputs + local_outputs = rebuild_mutable(outputs, replacements) assert len(local_inputs) == len(inputs) + len(implicit_shared_inputs) assert len(local_outputs) == len(outputs) - assert not update_d - assert not update_expr - assert not new_shared_inputs fgraph = FunctionGraph(local_inputs, local_outputs, clone=False) @@ -135,7 +127,9 @@ def construct_nominal_fgraph( fgraph.clients.pop(inp, None) fgraph.add_input(nom_inp) - return fgraph, implicit_shared_inputs, update_d, update_expr + # Inner graphs never carry shared-variable updates (asserted previously via + # rebuild_collect_shared); the update maps are always empty. + return fgraph, implicit_shared_inputs, {}, {} class OpFromGraph(Op, HasInnerGraph): @@ -343,10 +337,10 @@ def __init__( ) # Keep only the immutable (frozen) inner graph as ``op.fgraph``; the # mutable copy is transient, so the canonical inner graph can never be - # mutated in place. ``dedup_nodes=False``: inner graphs may carry inplace - # ops whose destroyed buffers must stay distinct; structural folding - # would alias them. See ``FunctionGraph.freeze``. - self.fgraph = inner_fgraph.freeze(dedup_nodes=False) + # mutated in place. 
Freeze with the default (no dedup): inner graphs may + # carry inplace ops whose destroyed buffers must stay distinct, and + # structural folding would alias them. See ``FunctionGraph.freeze``. + self.fgraph = inner_fgraph.freeze() if strict and self.shared_inputs: raise ValueError( @@ -842,9 +836,7 @@ def make_node(self, *inputs): # If the new shared variables are inconsistent with the inner-graph, # such errors should arise in this step - new_inner_outputs = clone_replace( - self.inner_outputs, replace=replace, copy_inputs_over=True - ) + new_inner_outputs = rebuild_mutable(self.inner_outputs, replace) # It's possible that the new shared variable inputs aren't actually # shared variables. When they aren't we need to add them as new @@ -952,9 +944,14 @@ def fn(self): # (or built inplace, e.g. ``FusedElemwise``); those only ever destroy # internal buffers (inputs stay protected), so we accept them. mode = get_mode(None).excluding("symbolic_op_recognition") + # The canonical inner graph is frozen (immutable); compile an + # ``unfreeze()``d mutable copy. ``function`` re-clones it while wrapping + # the inputs/outputs as In/Out specs; that extra clone is one-time since + # ``_fn`` is cached. + unfrozen_fgraph = self.fgraph.unfreeze() self._fn = function( - self.inner_inputs, - self.inner_outputs, + unfrozen_fgraph.inputs, + unfrozen_fgraph.outputs, mode=mode, on_unused_input="ignore", accept_inplace=True, @@ -992,7 +989,7 @@ def clone_with_inner_graph(self, inner_fgraph: FunctionGraph) -> OpFromGraph: new_fgraph, new.shared_inputs, _, _ = construct_nominal_fgraph( list(inner_fgraph.inputs), list(inner_fgraph.outputs) ) - new.fgraph = new_fgraph.freeze(dedup_nodes=False) + new.fgraph = new_fgraph.freeze() new.input_types = [inp.type for inp in new.fgraph.inputs] new.output_types = [out.type for out in new.fgraph.outputs] # Drop caches tied to the previous inner graph. 
diff --git a/pytensor/compile/inner_graph_rewriting.py b/pytensor/compile/inner_graph_rewriting.py index f956f462ce..688ea9dc86 100644 --- a/pytensor/compile/inner_graph_rewriting.py +++ b/pytensor/compile/inner_graph_rewriting.py @@ -87,8 +87,8 @@ def apply(self, fgraph): # ``optimize_inner_graphs``, so nested inner-graph ops recurse -- we # propagate the mode onto each inner graph so the recursion recovers it. mode = get_active_mode(fgraph) - # Inplace is baked into the frozen op (frozen ``dedup_nodes=False`` so - # distinct inplace buffers survive interning). ``OpFromGraph`` must not + # Inplace is baked into the frozen op (the default freeze does not dedup, + # so distinct inplace buffers survive interning). ``OpFromGraph`` must not # mutate its inputs, so all are protected; ``Scan`` protects all but the # destroyable taps. ``ScipyWrapperOp`` uses the no-inplace optimizer: # its inner graph is not backend-funcified but recompiled by diff --git a/pytensor/d3viz/formatting.py b/pytensor/d3viz/formatting.py index 5c879258cf..8dced65a5d 100644 --- a/pytensor/d3viz/formatting.py +++ b/pytensor/d3viz/formatting.py @@ -11,7 +11,7 @@ import pytensor from pytensor.compile import builders from pytensor.compile.executor import Function -from pytensor.graph.basic import Apply, Constant, Variable +from pytensor.graph.basic import AbstractApply, Constant, Variable from pytensor.graph.fg import FunctionGraph from pytensor.graph.traversal import graph_inputs from pytensor.printing import _try_pydot_import @@ -127,7 +127,7 @@ def __call__(self, fct, graph=None): else: if isinstance(fct, Variable): fct = [fct] - elif isinstance(fct, Apply): + elif isinstance(fct, AbstractApply): fct = fct.outputs assert isinstance(fct, list | tuple) assert all(isinstance(v, Variable) for v in fct) diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index cb61859d7a..bf38d18c6d 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -15,7 +15,6 @@ Any, Generic, Optional, 
- Self, TypeVar, Union, cast, @@ -108,7 +107,78 @@ def dprint(self, **kwargs): return debugprint(self, **kwargs) -class Apply(Node, Generic[OpType]): # noqa: UP046 +class AbstractApply(Node): + r"""Common, immutability-agnostic base for `Apply` and `FrozenApply`. + + Never instantiated directly. It holds the read-only structural API shared by + the mutable `Apply` and the immutable, interned `FrozenApply`: the `op`, the + `inputs`/`outputs` sequences, and the queries derived from them. Mutation and + cloning live on `Apply` alone, so code that must reject frozen nodes can test + ``isinstance(x, Apply)`` while code that only reads structure can accept + `AbstractApply`. + """ + + op: "Op" + inputs: Sequence["Variable"] + outputs: Sequence["Variable"] + tag: Scratchpad + + def default_output(self): + """ + Returns the default output for this node. + + Returns + ------- + Variable instance + An element of self.outputs, typically self.outputs[0]. + + Notes + ----- + May raise IndexError if self.op.default_output is out of range, or ValueError + if there are multiple outputs and self.op.default_output is not specified. + + """ + do = getattr(self.op, "default_output", None) + if do is None: + if len(self.outputs) == 1: + return self.outputs[0] + else: + raise ValueError( + f"Multi-output Op {self.op} default_output not specified" + ) + return self.outputs[do] + + def __str__(self): + # FIXME: The called function is too complicated for this simple use case. 
+ return op_as_string(self.inputs, self) + + def __repr__(self): + return str(self) + + def get_parents(self): + return list(self.inputs) + + @property + def out(self): + """An alias for `self.default_output`""" + return self.default_output() + + @property + def nin(self): + """The number of inputs.""" + return len(self.inputs) + + @property + def nout(self): + """The number of outputs.""" + return len(self.outputs) + + @property + def params_type(self): + return self.op.params_type + + +class Apply(AbstractApply, Generic[OpType]): # noqa: UP046 """A `Node` representing the application of an operation to inputs. Basically, an `Apply` instance is an object that represents the @@ -143,6 +213,8 @@ class Apply(Node, Generic[OpType]): # noqa: UP046 """ + op: OpType + def __init__( self, op: OpType, @@ -194,38 +266,6 @@ def __getstate__(self): d["tag"] = t return d - def default_output(self): - """ - Returns the default output for this node. - - Returns - ------- - Variable instance - An element of self.outputs, typically self.outputs[0]. - - Notes - ----- - May raise AttributeError self.op.default_output is out of range, or if - there are multiple outputs and self.op.default_output does not exist. - - """ - do = getattr(self.op, "default_output", None) - if do is None: - if len(self.outputs) == 1: - return self.outputs[0] - else: - raise ValueError( - f"Multi-output Op {self.op} default_output not specified" - ) - return self.outputs[do] - - def __str__(self): - # FIXME: The called function is too complicated for this simple use case. - return op_as_string(self.inputs, self) - - def __repr__(self): - return str(self) - def clone(self) -> "Apply[OpType]": r"""Clone this `Apply` instance. @@ -299,40 +339,9 @@ def clone_with_new_inputs( new_node.tag = copy(self.tag).__update__(new_node.tag) else: new_node = self.clone() - if new_node is self: - # Immutable nodes (e.g. ``FrozenApply``) return ``self`` from - # ``clone()``; mutating ``inputs`` would corrupt the shared node. 
- # Build a fresh, mutable ``Apply`` instead. - new_node = Apply( - self.op, new_inputs, [out.type() for out in self.outputs] - ) - new_node.tag = copy(self.tag).__update__(new_node.tag) - else: - new_node.inputs = new_inputs + new_node.inputs = new_inputs return new_node - def get_parents(self): - return list(self.inputs) - - @property - def out(self): - """An alias for `self.default_output`""" - return self.default_output() - - @property - def nin(self): - """The number of inputs.""" - return len(self.inputs) - - @property - def nout(self): - """The number of outputs.""" - return len(self.outputs) - - @property - def params_type(self): - return self.op.params_type - class Variable(Node, Generic[_TypeType, OptionalApplyType]): # noqa: UP046 r""" @@ -453,7 +462,7 @@ def __init__( self.owner = owner - if owner is not None and not isinstance(owner, Apply): + if owner is not None and not isinstance(owner, AbstractApply): raise TypeError("owner must be an Apply instance") if index is not None and not isinstance(index, int): @@ -806,10 +815,13 @@ def __reduce_ex__(protocol): return __reduce_ex__ -class FrozenApply(Apply): - """An immutable, globally-interned Apply node for frozen graphs. +class FrozenApply(AbstractApply): + """An immutable, globally-interned application node for frozen graphs. - ``inputs`` and ``outputs`` are tuples, so mutating them raises ``TypeError``. + It deliberately does *not* subclass `Apply`: it has no `clone` / + `clone_with_new_inputs` and its `inputs` / `outputs` are tuples, so it cannot + be mutated or rebuilt in place. Code wanting a mutable copy must unfreeze the + owning `FrozenFunctionGraph` (``op.fgraph.unfreeze()`` / ``bind``). Instances are interned on ``(op, inputs, output_types)``: constructing one with a matching key returns the cached instance. 
Constant inputs are keyed @@ -858,8 +870,8 @@ def __new__( instance = object.__new__(cls) instance.op = op instance._unique_idx = unique_idx - instance.inputs = inputs # type: ignore[assignment] - instance.outputs = tuple( # type: ignore[assignment] + instance.inputs = inputs + instance.outputs = tuple( t.variable_type(type=t, owner=instance, index=i) for i, t in enumerate(output_types) ) @@ -876,10 +888,6 @@ def __init__(self, op, inputs, output_types, unique_idx=None): # All initialization is done in __new__ pass - def clone(self) -> Self: - """Frozen nodes are immutable — cloning returns self.""" - return self - def __reduce__(self): return ( type(self), diff --git a/pytensor/graph/fg.py b/pytensor/graph/fg.py index b40d7231fc..811b6274da 100644 --- a/pytensor/graph/fg.py +++ b/pytensor/graph/fg.py @@ -9,6 +9,7 @@ from pytensor.configdefaults import config from pytensor.graph.basic import ( + AbstractApply, Apply, AtomicVariable, Constant, @@ -912,10 +913,10 @@ def __getstate__(self): d.pop("_execute_callbacks_times_dict", None) return d - def __contains__(self, item: Variable | Apply) -> bool: + def __contains__(self, item: Variable | AbstractApply) -> bool: if isinstance(item, Variable): return item in self.variables - elif isinstance(item, Apply): + elif isinstance(item, AbstractApply): return item in self.apply_nodes else: raise TypeError() @@ -1007,7 +1008,7 @@ def _resolve_input(inp, memo=memo): output_types, unique_idx=None if dedup_nodes else node_idx, ) - sorted_apply_nodes.append(new_node) + sorted_apply_nodes.append(cast(Apply, new_node)) memo.update(zip(node.outputs, new_node.outputs, strict=True)) @@ -1040,7 +1041,9 @@ def _resolve_input(inp, memo=memo): ) self.apply_nodes: frozenset[Apply] = frozenset(sorted_apply_nodes) self._toposort: tuple[Apply, ...] = tuple(sorted_apply_nodes) - self._output_nodes: tuple[Apply, ...] = tuple(output_nodes) + self._output_nodes: tuple[Apply, ...] 
= cast( + "tuple[Apply, ...]", tuple(output_nodes) + ) self._variables: frozenset[Variable] | None = None self._clients: dict[Variable, list[ClientType]] | None = None diff --git a/pytensor/graph/replace.py b/pytensor/graph/replace.py index ad309161aa..e4987d8ba8 100644 --- a/pytensor/graph/replace.py +++ b/pytensor/graph/replace.py @@ -1,5 +1,6 @@ import warnings from collections.abc import Iterable, Mapping, Sequence +from copy import copy from functools import singledispatch from typing import cast, overload @@ -11,6 +12,7 @@ from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import Op from pytensor.graph.traversal import ( + io_toposort, toposort, truncated_graph_inputs, ) @@ -37,6 +39,39 @@ def _format_replace(replace: ReplaceTypes | None = None) -> dict[Variable, Varia return items +def rebuild_mutable( + outputs: Sequence[Variable], + replace: ReplaceTypes | None = None, +) -> list[Variable]: + """Rebuild ``outputs`` as fresh, mutable `Apply` nodes, applying ``replace``. + + Like `clone_replace`, but every reachable node is *reconstructed* rather than + cloned, so it works on graphs containing immutable `FrozenApply` nodes (which + cannot be cloned). Inner-graph rewrites do their analysis on the cheap frozen + graph and call this only once they commit to firing, to obtain a mutable copy + of the inner graph (optionally substituting some variables) for the new op. + Unreplaced root inputs are kept as-is; the caller re-roots them. + """ + memo = _format_replace(replace) + # The replaced variables are the graph boundary: stop the walk there so we + # rebuild only the graph between them and ``outputs`` (not past them into a + # surrounding graph), and keep any unreplaced roots as-is. 
+ for node in io_toposort(list(memo), outputs): + if all(out in memo for out in node.outputs): + continue + new_node = Apply( + node.op, + [memo.get(i, i) for i in node.inputs], + [out.type() for out in node.outputs], + ) + new_node.tag = copy(node.tag) + for old_out, new_out in zip(node.outputs, new_node.outputs, strict=True): + new_out.name = old_out.name + new_out.tag = copy(old_out.tag) + memo.update(zip(node.outputs, new_node.outputs, strict=True)) + return [memo.get(out, out) for out in outputs] + + @overload def clone_replace( output: Sequence[Variable], diff --git a/pytensor/graph/traversal.py b/pytensor/graph/traversal.py index 2b92fdd6ca..eacf58a0f3 100644 --- a/pytensor/graph/traversal.py +++ b/pytensor/graph/traversal.py @@ -12,7 +12,7 @@ overload, ) -from pytensor.graph.basic import Apply, Constant, Node, Variable +from pytensor.graph.basic import AbstractApply, Apply, Constant, Node, Variable T = TypeVar("T", bound=Node) @@ -340,7 +340,7 @@ def apply_depends_on(apply: Apply, depends_on: Apply | Iterable[Apply]) -> bool: bool """ - if isinstance(depends_on, Apply): + if isinstance(depends_on, AbstractApply): depends_on = frozenset((depends_on,)) else: depends_on = frozenset(depends_on) @@ -683,7 +683,7 @@ def toposort_with_orderings( def compute_deps(obj, blocker_set=frozenset(blockers), orderings=orderings): if obj in blocker_set: return None - if isinstance(obj, Apply): + if isinstance(obj, AbstractApply): return [*obj.inputs, *orderings.get(obj, [])] else: if (apply := obj.owner) is not None: @@ -694,7 +694,7 @@ def compute_deps(obj, blocker_set=frozenset(blockers), orderings=orderings): # mypy doesn't like conditional functions with different signatures, # but passing the globals as optional is faster def compute_deps(obj, orderings=orderings): # type: ignore[misc] - if isinstance(obj, Apply): + if isinstance(obj, AbstractApply): return [*obj.inputs, *orderings.get(obj, [])] else: if (apply := obj.owner) is not None: @@ -705,7 +705,8 @@ def 
compute_deps(obj, orderings=orderings): # type: ignore[misc] yield from ( apply for apply in walk_toposort(graphs, deps=compute_deps) - if isinstance(apply, Apply) + # mypy doesn't understand that our generator will return both Apply and Variables + if isinstance(apply, AbstractApply) # type: ignore ) diff --git a/pytensor/printing.py b/pytensor/printing.py index 3c27fc7bde..8593d7b0ed 100644 --- a/pytensor/printing.py +++ b/pytensor/printing.py @@ -20,7 +20,7 @@ from pytensor.compile.executor import Function from pytensor.compile.io import In, Out from pytensor.configdefaults import config -from pytensor.graph.basic import Apply, Constant, Variable +from pytensor.graph.basic import AbstractApply, Apply, Constant, Variable from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import HasInnerGraph, Op, StorageMapType from pytensor.graph.traversal import graph_inputs, toposort @@ -681,7 +681,7 @@ def _show_inner_graph(op): profile_list.append(None) storage_maps.append(None) topo_orders.append(None) - elif isinstance(obj, Apply): + elif isinstance(obj, AbstractApply): outputs_to_print.extend(obj.outputs) profile_list.extend(None for item in obj.outputs) storage_maps.extend(None for item in obj.outputs) @@ -2064,7 +2064,7 @@ def pydotprint( else: if isinstance(fct, Variable): fct = [fct] - elif isinstance(fct, Apply): + elif isinstance(fct, AbstractApply): fct = fct.outputs assert isinstance(fct, list | tuple) assert all(isinstance(v, Variable) for v in fct) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 33581346f2..36f2e4e307 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -4277,9 +4277,11 @@ def make_node(self, *inputs): if self.inputs_type == tuple(i.type for i in inputs): return super().make_node(*inputs) else: - # Make a new op with the right input types. + # Make a new op with the right input types. 
Work on a mutable copy: + # this re-infers output types (rebuild_strict=False), which + # rebuild_collect_shared can't do on the immutable frozen inner graph. assert len(inputs) == self.nin - fg = self.fgraph + fg = self.fgraph.unfreeze() res = pytensor.compile.rebuild_collect_shared( fg.outputs, replace=dict(zip(fg.inputs, inputs, strict=True)), diff --git a/pytensor/scalar/loop.py b/pytensor/scalar/loop.py index 86e0958ef4..432478e770 100644 --- a/pytensor/scalar/loop.py +++ b/pytensor/scalar/loop.py @@ -125,8 +125,10 @@ def make_node(self, n_steps, *inputs): if self.inputs_type == tuple(i.type for i in inputs): return super().make_node(n_steps, *inputs) else: - # Make a new op with the right input types. - fg = self.fgraph + # Make a new op with the right input types. Work on a mutable copy: + # this re-infers output types (rebuild_strict=False), which + # rebuild_collect_shared can't do on the immutable frozen inner graph. + fg = self.fgraph.unfreeze() res = rebuild_collect_shared( fg.outputs, replace=dict(zip(fg.inputs, inputs, strict=True)), diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index 1a82282612..65f04d13b3 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -960,10 +960,10 @@ def __init__( # Keep only the immutable (frozen) inner graph as ``op.fgraph``; the # mutable copy is transient, so the canonical inner graph can never be # mutated in place. Inner graphs may carry inplace ops (baked by - # ``optimize_inner_graphs``) whose destroyed buffers must stay distinct; - # ``dedup_nodes=False`` keeps structural folding from aliasing them. - # See ``FunctionGraph.freeze``. - self.fgraph = inner_fgraph.freeze(dedup_nodes=False) + # ``optimize_inner_graphs``) whose destroyed buffers must stay distinct, + # so freeze with the default (no dedup) to avoid aliasing them. See + # ``FunctionGraph.freeze``. 
+ self.fgraph = inner_fgraph.freeze() self.info = info self.truncate_gradient = truncate_gradient @@ -1066,7 +1066,7 @@ def __setstate__(self, d): # Back-compat: older pickles stored a mutable inner ``fgraph`` (plus a # separate ``_frozen_fgraph``). Collapse to the single frozen graph. if not isinstance(self.fgraph, FrozenFunctionGraph): - self.fgraph = self.fgraph.freeze(dedup_nodes=False) + self.fgraph = self.fgraph.freeze() self.__dict__.pop("_frozen_fgraph", None) # Ensure that the graph associated with the inner function is valid. self.validate_inner_graph() diff --git a/pytensor/scan/rewriting/push_out.py b/pytensor/scan/rewriting/push_out.py index 949dfd11a2..78cb93bbd4 100644 --- a/pytensor/scan/rewriting/push_out.py +++ b/pytensor/scan/rewriting/push_out.py @@ -21,7 +21,7 @@ from pytensor.compile.ops import DeepCopyOp, ViewOp from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.fg import FunctionGraph, Output -from pytensor.graph.replace import clone_replace +from pytensor.graph.replace import rebuild_mutable from pytensor.graph.rewriting.basic import node_rewriter from pytensor.graph.type import HasShape from pytensor.scan.op import Scan @@ -176,7 +176,7 @@ def add_to_replace(y): nw_outer.append(repl_out) givens[to_repl] = repl_in - op_outs = clone_replace(node_outputs, replace=givens) + op_outs = rebuild_mutable(node_outputs, givens) op_ins = node_inputs + nw_inner new_info = dataclasses.replace( @@ -399,7 +399,7 @@ def add_to_replace(y): givens[to_repl] = repl_in - op_outs = clone_replace(node_outputs, replace=givens) + op_outs = rebuild_mutable(node_outputs, givens) op_ins = nw_inner + node_inputs # Reconstruct node @@ -604,7 +604,9 @@ def add_nitsot_outputs( assert isinstance(old_scan_node.op, Scan) - # Create the `Scan` `Op` from the `ScanArgs` + # Create the `Scan` `Op` from the `ScanArgs`. 
``args`` was parsed with + # ``clone=False``, so the inner graph is still frozen; ``Scan`` construction + # rebuilds (and re-freezes) it. new_scan_op = Scan( new_scan_args.inner_inputs, new_scan_args.inner_outputs, diff --git a/pytensor/scan/rewriting/trace.py b/pytensor/scan/rewriting/trace.py index 2c681340a8..04abdceb5d 100644 --- a/pytensor/scan/rewriting/trace.py +++ b/pytensor/scan/rewriting/trace.py @@ -922,7 +922,9 @@ def scan_sit_sot_to_untraced(fgraph, node): convertible_set = set(convertible) - # Gather current inner inputs/outputs by category + # Gather current inner inputs/outputs by category. These are read-only views + # of the frozen inner graph; the reassembled inputs/outputs are handed to a + # new ``Scan``, whose construction rebuilds (and re-freezes) them. inner_inputs = list(op.inner_inputs) inner_outputs = list(op.inner_outputs) diff --git a/pytensor/scan/rewriting/utils.py b/pytensor/scan/rewriting/utils.py index f96dda5554..e6efefc618 100644 --- a/pytensor/scan/rewriting/utils.py +++ b/pytensor/scan/rewriting/utils.py @@ -3,7 +3,7 @@ from typing import cast from pytensor.graph.basic import Apply, Variable -from pytensor.graph.replace import clone_replace +from pytensor.graph.replace import rebuild_mutable from pytensor.scan.op import Scan @@ -24,9 +24,9 @@ def _rebuild_scan_with_new_signature( Each ``drop_*`` argument is a set of indices into its category; the rebuilt op retains only the entries whose index is not listed. - ``inner_substitutions``, when provided, is applied via ``clone_replace`` - on the inner outputs before the rebuild -- use it to inline constants - or rewire duplicate inner inputs. + ``inner_substitutions``, when provided, is applied while rebuilding the + inner outputs into fresh mutable nodes -- use it to inline constants or + rewire duplicate inner inputs. 
Returns a ``replacements`` dict: kept outer outputs map to their counterparts on the new op, dropped outputs carry no mapping (they @@ -77,7 +77,10 @@ def _rebuild_scan_with_new_signature( inner_outputs = op.inner_outputs if inner_substitutions: - inner_outputs = clone_replace(inner_outputs, replace=inner_substitutions) + # Apply the substitutions, rebuilding the frozen inner outputs into + # mutable nodes. Without substitutions they stay frozen and ``Scan`` + # construction rebuilds them, avoiding an unfreeze-just-to-freeze pass. + inner_outputs = rebuild_mutable(inner_outputs, inner_substitutions) inner_mm_out_groups = op.inner_mitmot_outs_grouped(inner_outputs) inner_ms_outs = op.inner_mitsot_outs(inner_outputs) inner_ss_outs = op.inner_sitsot_outs(inner_outputs) diff --git a/pytensor/scan/utils.py b/pytensor/scan/utils.py index a999979f12..6cd6ba3709 100644 --- a/pytensor/scan/utils.py +++ b/pytensor/scan/utils.py @@ -14,7 +14,7 @@ from pytensor import tensor as pt from pytensor.compile.debug.profiling import ProfileStats from pytensor.graph.basic import Constant, Variable, equal_computations -from pytensor.graph.replace import clone_replace +from pytensor.graph.replace import clone_replace, rebuild_mutable from pytensor.graph.traversal import graph_inputs from pytensor.graph.type import HasDataType from pytensor.tensor.basic import AllocEmpty, cast @@ -310,7 +310,11 @@ def get_value(out): all_inputs = [inp for (inp, is_valid) in inputs] equiv_inputs = [inp for (inp, is_valid) in inputs if not is_valid] if equiv_inputs: - cloned_node = out.owner.clone_with_new_inputs(all_inputs) + # Rebuild the node with the equivalent inputs. ``make_node`` works + # whether ``out.owner`` is a mutable ``Apply`` or an immutable + # ``FrozenApply`` (the op is shared); the latter has no + # ``clone_with_new_inputs``. 
+ cloned_node = out.owner.op.make_node(*all_inputs) cloned_out = cloned_node.outputs[out.index] self.invalid.add(out) self.valid.add(cloned_out) @@ -334,9 +338,9 @@ def reconstruct_graph(inputs, outputs, tag=None): if tag is None: tag = "" nw_inputs = [safe_new(x, tag) for x in inputs] - - givens = {x: nw_x for nw_x, x in zip(nw_inputs, inputs, strict=True)} - nw_outputs = clone_replace(outputs, replace=givens) + # ``outputs`` may come from a frozen inner graph, whose ``FrozenApply`` nodes + # are immutable and cannot be cloned; ``rebuild_mutable`` reconstructs them. + nw_outputs = rebuild_mutable(outputs, dict(zip(inputs, nw_inputs, strict=True))) return (nw_inputs, nw_outputs) diff --git a/pytensor/tensor/optimize.py b/pytensor/tensor/optimize.py index 639f0950fc..d8ec4e99b5 100644 --- a/pytensor/tensor/optimize.py +++ b/pytensor/tensor/optimize.py @@ -246,7 +246,7 @@ def clone_with_inner_graph(self, inner_fgraph): clone_op.fgraph = ( inner_fgraph if isinstance(inner_fgraph, FrozenFunctionGraph) - else inner_fgraph.freeze(dedup_nodes=False) + else inner_fgraph.freeze() ) return clone_op @@ -568,7 +568,7 @@ def __init__( raise ValueError( "The variable `x` must be an input to the computational graph of the objective function." 
) - self.fgraph = FunctionGraph([x, *args], [objective]).freeze(dedup_nodes=False) + self.fgraph = FunctionGraph([x, *args], [objective]).freeze() self.method = method self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {} @@ -714,7 +714,7 @@ def __init__( ) self.fgraph.add_output(hess_wrt_x) - self.fgraph = self.fgraph.freeze(dedup_nodes=False) + self.fgraph = self.fgraph.freeze() self.jac = jac self.hess = hess @@ -918,7 +918,7 @@ def __init__( f_double_prime = grad(self.fgraph.outputs[-1], self.fgraph.inputs[0]) self.fgraph.add_output(f_double_prime) - self.fgraph = self.fgraph.freeze(dedup_nodes=False) + self.fgraph = self.fgraph.freeze() self.method = method self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {} @@ -1073,7 +1073,7 @@ def __init__( ) self.fgraph.add_output(atleast_2d(jac_wrt_x)) - self.fgraph = self.fgraph.freeze(dedup_nodes=False) + self.fgraph = self.fgraph.freeze() self.jac = jac diff --git a/pytensor/tensor/rewriting/optimize.py b/pytensor/tensor/rewriting/optimize.py index 41de8666f8..e6b3f59819 100644 --- a/pytensor/tensor/rewriting/optimize.py +++ b/pytensor/tensor/rewriting/optimize.py @@ -1,6 +1,6 @@ from pytensor.graph.basic import Constant from pytensor.graph.fg import FunctionGraph -from pytensor.graph.replace import clone_replace +from pytensor.graph.replace import rebuild_mutable from pytensor.graph.rewriting.basic import node_rewriter from pytensor.tensor.optimize import ScipyWrapperOp from pytensor.tensor.rewriting.basic import register_canonicalize @@ -40,7 +40,7 @@ def remove_constants_and_duplicate_inputs_scipy(fgraph, node): if not givens: return None - new_inner_outputs = clone_replace(op.inner_outputs, replace=givens) + new_inner_outputs = rebuild_mutable(op.inner_outputs, givens) new_inner_inputs = (inner_x, *new_inner_args) new_fgraph = FunctionGraph(new_inner_inputs, new_inner_outputs, clone=False) new_op = op.clone_with_new_fgraph(new_fgraph) diff --git 
a/tests/compile/test_builders.py b/tests/compile/test_builders.py index 8cc467ba62..84617fad0c 100644 --- a/tests/compile/test_builders.py +++ b/tests/compile/test_builders.py @@ -541,11 +541,13 @@ def test_outputs_consistency(self): assert equal_computations(op.inner_outputs, [x**2 / x], op.inner_inputs, [x]) # Optimizing a copy of the inner graph (here FAST_RUN, which rewrites - # ``x**2 / x`` to ``x``) must not leak back into the canonical, shared - # inner graph -- the compiled `FunctionGraph` is a separate clone. + # ``x**2 / x`` to ``x``) must not leak back into the canonical, frozen + # inner graph. The canonical graph is immutable and is never handed to + # ``function`` directly; compile an ``unfreeze()``d mutable copy instead. + unfrozen = op.fgraph.unfreeze() fn = function( - op.inner_inputs, - op.inner_outputs, + unfrozen.inputs, + unfrozen.outputs, mode="FAST_RUN", on_unused_input="ignore", accept_inplace=True,