Skip to content

Make inner graph Ops immutable#2243

Open
ricardoV94 wants to merge 10 commits into
pymc-devs:mainfrom
ricardoV94:real_frozen_graphs
Open

Make inner graph Ops immutable#2243
ricardoV94 wants to merge 10 commits into
pymc-devs:mainfrom
ricardoV94:real_frozen_graphs

Conversation

@ricardoV94

@ricardoV94 ricardoV94 commented Jun 18, 2026

Copy link
Copy Markdown
Member

Closes #2232
Closes #2028
Closes #2033
Closes #2194 by adding a method (should it be the default) where theres's no node deduplication, as the toposort index becomes part of the hash

No more inner graph "cloning" of inner graphs

@ricardoV94 ricardoV94 force-pushed the real_frozen_graphs branch 3 times, most recently from 2c3b605 to 8070b5d Compare June 20, 2026 15:39
…inplace regression)

Move inner-graph optimization of Scan/OpFromGraph out of the linker dispatch
and make_thunk (where it mutated the canonical, shared inner FunctionGraph in
place) into a single `optimize_inner_graphs` rewrite that runs near the end of
optimization and produces NEW immutable ops. This fixes the bug where compiling
a graph (e.g. a scan-based pymc CustomDist under numba) corrupted the canonical
inner graph and broke later logp derivation.

Key pieces:
- `GraphRewriter.rewrite` advertises the active rewriter on the fgraph and
  `FunctionMaker` attaches the compile mode; `get_active_rewriter`/
  `get_active_mode` let the rewrite reuse the exact outer compilation (mode,
  fast_run/fast_compile, linker required/incompatible rewrites). `config.mode`
  is unreliable across nested compilations.
- `optimize_inner_graphs` (per-op, recursive): unfreezes each inner graph,
  optimizes it with the transferred outer query, and rebuilds via
  `clone_with_inner_graph` (OpFromGraph/Scan). Skips `*WithCoreShape`.
- Registered at optdb position 49.6 (fast_run/fast_compile) and required for the
  numba/jax/pytorch/mlx linkers.
- All linker dispatches + Scan.fn now operate on a throwaway `unfreeze()` copy,
  never the canonical op.
- Deprecations: Scan `mode` (still honored/reconciled with the active linker)
  and OpFromGraph `compile_kwargs` (now ignored; `on_unused_input` subsumed as
  the default, governed by `strict`).
- Regression test: compiling must not mutate the canonical inner graph.

TEMPORARY REGRESSION (to be removed before this is one complete unit):
The rewrite currently EXCLUDES `inplace`, so the canonical/frozen inner graph
stays inplace-free and is not exposed to FrozenApply/MergeOptimizer collapsing
distinct-but-equal inplace buffers (gh pymc-devs#2194). Inplace is therefore still
applied per-node at link time on the transient (so runtime is correct), which
means dispatch is not yet fully rewrite-free. Completing the unit requires the
pymc-devs#2194 dedup guard (uniqueness token for inplace-distinct buffers) so inplace can
be baked into the frozen op. Two numba inplace introspection tests are xfailed
in the meantime; MinimizeOp/ScipyWrapperOp and removal of the clone_inner_graph
machinery remain follow-ups.
OpFromGraph now deprecates and ignores compile_kwargs (the inner graph
inherits the outer compilation). FusedElemwise still passed
accept_inplace=True, so it warned at itself with a FutureWarning whenever
elemwise fusion ran under filterwarnings=error. The kwarg was already a
no-op, so drop it.
FrozenApply interns structurally-identical nodes globally, which is
destroy-unaware: it folds two distinct-but-equal inplace buffers (that
MergeOptimizer deliberately kept apart) into one, aliasing them.

With dedup_nodes=False, FrozenFunctionGraph folds each node's toposort
position into the FrozenApply intern key, so structurally-identical nodes
within a graph stay distinct (buffers don't alias) while two
structurally-identical graphs still compare equal (positions line up),
preserving cross-graph equality / op-merge / funcify-cache reuse. The
default (True) keeps the existing global interning behavior.
@ricardoV94 ricardoV94 force-pushed the real_frozen_graphs branch from 8070b5d to fe15202 Compare June 20, 2026 15:39
Previously inner-graph inplace was excluded from the rewrite and applied
to a link-time transient, so the canonical frozen op stayed inplace-free
(a temporary regression). Now the rewrite bakes inplace into the frozen
op, and dispatch is rewrite-free.

- optimize_inner_graphs groups apply NODES (OpFromGraph by op; Scan by
  destroyable-tap signature, since tap-inplace depends on outer buffer
  shapes), attaches a Supervisor (OFG protects all inputs; Scan protects
  all but destroyable taps), runs the optimizer including inplace, and
  bakes the result into each new frozen op (frozen dedup_nodes=False so
  distinct inplace buffers survive).
- Scan.inner_destroyable_inputs factors out the destroyable-tap logic
  shared by the rewrite and numba_funcify_Scan.
- OFG/Scan inner graphs freeze with dedup_nodes=False; OpFromGraph.fn
  compiles with accept_inplace=True (inner graphs may carry inplace that
  only destroys internal buffers; inputs stay protected).
- numba_funcify_Scan/OpFromGraph drop the redundant optimizer pass and
  just funcify the baked graph (supervisor/insert_deepcopy stay as
  link-time aliasing setup).
- Un-xfail test_ofg_inner_inplace and test_inplace_taps.
…nputs

When the input types are unchanged, clone_with_new_inputs cloned the node
and reassigned its inputs. For immutable nodes whose clone() returns self
(FrozenApply), this mutated the shared, supposedly-immutable node in place,
silently corrupting any frozen inner graph that was clone_replace'd. Build a
fresh, mutable Apply in that case instead -- the per-node form of bind.
Scan and OpFromGraph kept both a mutable self.fgraph and a separate frozen
_frozen_fgraph. Collapse to one immutable graph exposed as op.fgraph (a
FrozenFunctionGraph): the mutable copy is gone, so the canonical inner graph
can never be mutated in place. clone() returns self (like Composite), and
eq/hash/dispatch/inner_inputs all read the frozen graph.

Rewrites that recombine inner variables to build a new Scan (push_out, trace,
io/utils) read the frozen inner_inputs/inner_outputs directly and pass them to
the Scan constructor; construct_nominal_fgraph rebuilds the body as fresh
mutable Apply nodes (relying on clone_with_new_inputs no longer mutating frozen
nodes). Rewrites that mutate the inner graph in place (the linalg scan-solve
split) unfreeze() a transient first. scan_make_inplace shallow-copies the op
(clone() is now self). inner_inputs/inner_outputs return lists so callers keep
list semantics. Updates the OFG/Scan clone tests to the immutable contract.
ScipyWrapperOp (MinimizeOp/RootOp) still keeps a *mutable* inner fgraph (it has
no frozen graph yet), so the rewrite's op.fgraph.unfreeze() crashed on it -- it
was matched before being made frozen. Its inner graph is already optimized
lazily at link time by build_fn, so exclude it from the matcher until it is
frozen, which also un-breaks tests/tensor/test_optimize.py.
Cloning an outer graph used to optionally deep-clone the inner graphs of
HasInnerGraph ops (clone_inner_graph(s)=True). Now that every inner-graph op
(Scan/OpFromGraph/Composite) is immutable -- clone() returns self -- that
deep-clone is always a no-op, so drop the kwarg and the branch throughout the
clone machinery: Apply.clone, Apply.clone_with_new_inputs, FrozenApply.clone,
clone, clone_node_and_cache, clone_get_equiv, FunctionGraph.clone, and
rebuild_collect_shared. clone_node_and_cache's op-clone caching (which only
existed to reuse cloned inner-graph ops) is removed too. ScipyWrapperOp.clone
drops the now-removed kwarg. Updates test_clone_inner_graph to the new
shared-Op contract.
symbolic_op_recognition rewrites fold a pattern into an inner-graph op -- e.g.
local_softmax_stabilize turns exp(x)/sum(exp(x)) into Softmax, which is itself
an OpFromGraph whose inner graph is that very pattern. optimize_inner_graphs
ran the full optimizer (including those rewrites) on each inner graph, so it
re-created Softmax inside Softmax and recursed on the new op's inner graph
without end (RecursionError compiling any softmax under FAST_RUN or NUMBA).

The old link-time inner-graph optimization excluded symbolic_op_recognition;
restore that exclusion across all three optimizer paths.
MinimizeScalarOp/MinimizeOp/RootScalarOp/RootOp built a mutable inner
FunctionGraph and kept it as self.fgraph; that left them as the last inner-graph
ops not collapsed to a single frozen graph, so they were excluded from
optimize_inner_graphs.

Build the inner graph (plus the jac/hess outputs) mutably in __init__ as before,
then freeze(dedup_nodes=False) it -- the canonical op.fgraph is now an immutable
FrozenFunctionGraph that can never be mutated in place. clone() returns self,
inner_inputs/outputs return lists, and eq/hash key on the frozen graph plus the
per-subclass _scipy_props (so e.g. RootOps of different equations stay distinct
without a hand-maintained __props__). clone_with_inner_graph bakes an
optimized inner graph into a new frozen op for the rewrite.

Graph manipulation (build_fn, compute_implicit_gradients, the connection
pattern) runs on a fresh unfreeze()d copy: graph_replace on frozen nodes would
otherwise leak them into the outer graph.

Re-add ScipyWrapperOp to the optimize_inner_graphs matcher (reverting the
interim exclusion). It uses the no-inplace optimizer because its inner graph is
not backend-funcified but recompiled by build_fn on the scipy-wrapped graph.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

1 participant