Make inner graph Ops immutable #2243
Open
ricardoV94 wants to merge 10 commits into
…inplace regression)

Move inner-graph optimization of Scan/OpFromGraph out of the linker dispatch and make_thunk (where it mutated the canonical, shared inner FunctionGraph in place) into a single `optimize_inner_graphs` rewrite that runs near the end of optimization and produces NEW immutable ops. This fixes the bug where compiling a graph (e.g. a scan-based pymc CustomDist under numba) corrupted the canonical inner graph and broke later logp derivation.

Key pieces:
- `GraphRewriter.rewrite` advertises the active rewriter on the fgraph and `FunctionMaker` attaches the compile mode; `get_active_rewriter`/`get_active_mode` let the rewrite reuse the exact outer compilation (mode, fast_run/fast_compile, linker required/incompatible rewrites). `config.mode` is unreliable across nested compilations.
- `optimize_inner_graphs` (per-op, recursive): unfreezes each inner graph, optimizes it with the transferred outer query, and rebuilds via `clone_with_inner_graph` (OpFromGraph/Scan). Skips `*WithCoreShape`.
- Registered at optdb position 49.6 (fast_run/fast_compile) and required for the numba/jax/pytorch/mlx linkers.
- All linker dispatches + Scan.fn now operate on a throwaway `unfreeze()` copy, never the canonical op.
- Deprecations: Scan `mode` (still honored/reconciled with the active linker) and OpFromGraph `compile_kwargs` (now ignored; `on_unused_input` subsumed as the default, governed by `strict`).
- Regression test: compiling must not mutate the canonical inner graph.

TEMPORARY REGRESSION (to be removed before this is one complete unit): The rewrite currently EXCLUDES `inplace`, so the canonical/frozen inner graph stays inplace-free and is not exposed to FrozenApply/MergeOptimizer collapsing distinct-but-equal inplace buffers (gh pymc-devs#2194). Inplace is therefore still applied per-node at link time on the transient (so runtime is correct), which means dispatch is not yet fully rewrite-free.

Completing the unit requires the pymc-devs#2194 dedup guard (uniqueness token for inplace-distinct buffers) so inplace can be baked into the frozen op. Two numba inplace introspection tests are xfailed in the meantime; MinimizeOp/ScipyWrapperOp and removal of the clone_inner_graph machinery remain follow-ups.
OpFromGraph now deprecates and ignores compile_kwargs (the inner graph inherits the outer compilation). FusedElemwise still passed accept_inplace=True, so it triggered that FutureWarning itself whenever elemwise fusion ran, which turns into an error under filterwarnings=error. The kwarg was already a no-op, so drop it.
FrozenApply interns structurally-identical nodes globally, which is destroy-unaware: it folds two distinct-but-equal inplace buffers (that MergeOptimizer deliberately kept apart) into one, aliasing them. With dedup_nodes=False, FrozenFunctionGraph folds each node's toposort position into the FrozenApply intern key, so structurally-identical nodes within a graph stay distinct (buffers don't alias) while two structurally-identical graphs still compare equal (positions line up), preserving cross-graph equality / op-merge / funcify-cache reuse. The default (True) keeps the existing global interning behavior.
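The effect of folding the toposort position into the intern key can be sketched in plain Python. This is a toy model, not PyTensor code: `Node`, `_INTERN`, and this `freeze` function are hypothetical stand-ins for FrozenApply, its intern table, and FrozenFunctionGraph, and only the `dedup_nodes` name is taken from the commit message.

```python
from dataclasses import dataclass
from typing import Optional

# Global intern table: key -> the single shared node for that key.
_INTERN: dict = {}


@dataclass(frozen=True, eq=False)  # eq=False: nodes compare and hash by identity
class Node:
    op: str
    inputs: tuple
    position: Optional[int]


def freeze(specs, dedup_nodes=True):
    """Intern a toposorted list of (op, inputs) specs.

    Each input is either a leaf name (str) or the index (int) of an
    earlier node in `specs`.
    """
    nodes = []
    for pos, (op, ins) in enumerate(specs):
        resolved = tuple(nodes[i] if isinstance(i, int) else i for i in ins)
        # With dedup_nodes=False the toposort position joins the key, so
        # equal nodes at different positions no longer collapse into one.
        key = (op, resolved, None if dedup_nodes else pos)
        nodes.append(_INTERN.setdefault(key, Node(op, resolved, key[2])))
    return tuple(nodes)


# Two structurally identical buffers feeding one consumer.
specs = [("zeros", ("n",)), ("zeros", ("n",)), ("add", (0, 1))]

merged = freeze(specs, dedup_nodes=True)
assert merged[0] is merged[1]  # buffers alias: unsafe if one is destroyed

first = freeze(specs, dedup_nodes=False)
second = freeze(specs, dedup_nodes=False)
assert first[0] is not first[1]  # buffers stay distinct within a graph
assert all(a is b for a, b in zip(first, second))  # equal graphs still line up
```

In this toy model the position only needs to be consistent between two graphs with the same structure, which is what keeps cross-graph equality intact.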
Previously inner-graph inplace was excluded from the rewrite and applied to a link-time transient, so the canonical frozen op stayed inplace-free (a temporary regression). Now the rewrite bakes inplace into the frozen op, and dispatch is rewrite-free.

- optimize_inner_graphs groups apply NODES (OpFromGraph by op; Scan by destroyable-tap signature, since tap-inplace depends on outer buffer shapes), attaches a Supervisor (OFG protects all inputs; Scan protects all but destroyable taps), runs the optimizer including inplace, and bakes the result into each new frozen op (frozen dedup_nodes=False so distinct inplace buffers survive).
- Scan.inner_destroyable_inputs factors out the destroyable-tap logic shared by the rewrite and numba_funcify_Scan.
- OFG/Scan inner graphs freeze with dedup_nodes=False; OpFromGraph.fn compiles with accept_inplace=True (inner graphs may carry inplace that only destroys internal buffers; inputs stay protected).
- numba_funcify_Scan/OpFromGraph drop the redundant optimizer pass and just funcify the baked graph (supervisor/insert_deepcopy stay as link-time aliasing setup).
- Un-xfail test_ofg_inner_inplace and test_inplace_taps.
…nputs When the input types are unchanged, clone_with_new_inputs cloned the node and reassigned its inputs. For immutable nodes whose clone() returns self (FrozenApply), this mutated the shared, supposedly-immutable node in place, silently corrupting any frozen inner graph that was clone_replace'd. Build a fresh, mutable Apply in that case instead -- the per-node form of bind.
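The failure mode described above can be reproduced with two toy classes. This is a minimal sketch under stated assumptions: the `Apply` and `FrozenApply` classes below are simplified stand-ins that share only their names with PyTensor's, and both `clone_with_new_inputs_*` functions are hypothetical illustrations of the before and after behavior.

```python
class Apply:
    """Toy mutable node: clone() returns a fresh copy."""

    def __init__(self, op, inputs):
        self.op = op
        self.inputs = list(inputs)

    def clone(self):
        return Apply(self.op, self.inputs)


class FrozenApply(Apply):
    """Toy immutable node: clone() returns self because the node is shared."""

    def __init__(self, op, inputs):
        self.op = op
        self.inputs = tuple(inputs)

    def clone(self):
        return self


def clone_with_new_inputs_buggy(node, new_inputs):
    # Clone-then-reassign: silently mutates `node` when clone() returned self.
    new_node = node.clone()
    new_node.inputs = list(new_inputs)
    return new_node


def clone_with_new_inputs_fixed(node, new_inputs):
    # Always build a fresh, mutable node instead of reassigning inputs.
    return Apply(node.op, new_inputs)


shared = FrozenApply("add", ("x", "y"))
clone_with_new_inputs_buggy(shared, ["a", "b"])
assert shared.inputs == ["a", "b"]  # the shared node was corrupted

shared = FrozenApply("add", ("x", "y"))
fresh = clone_with_new_inputs_fixed(shared, ["a", "b"])
assert shared.inputs == ("x", "y")  # the shared node is untouched
assert fresh is not shared and fresh.inputs == ["a", "b"]
```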
Scan and OpFromGraph kept both a mutable self.fgraph and a separate frozen _frozen_fgraph. Collapse to one immutable graph exposed as op.fgraph (a FrozenFunctionGraph): the mutable copy is gone, so the canonical inner graph can never be mutated in place. clone() returns self (like Composite), and eq/hash/dispatch/inner_inputs all read the frozen graph. Rewrites that recombine inner variables to build a new Scan (push_out, trace, io/utils) read the frozen inner_inputs/inner_outputs directly and pass them to the Scan constructor; construct_nominal_fgraph rebuilds the body as fresh mutable Apply nodes (relying on clone_with_new_inputs no longer mutating frozen nodes). Rewrites that mutate the inner graph in place (the linalg scan-solve split) unfreeze() a transient first. scan_make_inplace shallow-copies the op (clone() is now self). inner_inputs/inner_outputs return lists so callers keep list semantics. Updates the OFG/Scan clone tests to the immutable contract.
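The resulting contract (one frozen graph, `clone()` returning self, rewrites working on an `unfreeze()` copy and producing a new op) can be sketched with toy classes. The names `unfreeze`, `clone_with_inner_graph`, and `op.fgraph` come from the commit messages, but the classes, signatures, and the string-based "graph" below are assumptions made for illustration only.

```python
class FrozenGraph:
    """Toy immutable inner graph: a tuple of node names."""

    def __init__(self, nodes):
        self._nodes = tuple(nodes)

    @property
    def nodes(self):
        return self._nodes

    def unfreeze(self):
        # Hand out a throwaway mutable copy; the frozen graph is untouched.
        return list(self._nodes)


class InnerGraphOp:
    """Toy op holding a single immutable inner graph as `fgraph`."""

    def __init__(self, nodes):
        self.fgraph = FrozenGraph(nodes)

    def clone(self):
        return self  # immutable, so sharing is safe

    def clone_with_inner_graph(self, nodes):
        return InnerGraphOp(nodes)


def optimize_inner_graph(op):
    # Mutate only the transient copy, then bake it into a NEW op.
    work = op.fgraph.unfreeze()
    work[:] = [n for n in work if n != "identity"]  # stand-in for real rewrites
    return op.clone_with_inner_graph(work)


canonical = InnerGraphOp(["exp", "identity", "sum"])
optimized = optimize_inner_graph(canonical)

assert canonical.clone() is canonical
assert canonical.fgraph.nodes == ("exp", "identity", "sum")  # never mutated
assert optimized.fgraph.nodes == ("exp", "sum")
```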
ScipyWrapperOp (MinimizeOp/RootOp) still keeps a *mutable* inner fgraph (it has no frozen graph yet), so the rewrite's op.fgraph.unfreeze() crashed on it -- it was matched before being made frozen. Its inner graph is already optimized lazily at link time by build_fn, so exclude it from the matcher until it is frozen, which also un-breaks tests/tensor/test_optimize.py.
Cloning an outer graph used to optionally deep-clone the inner graphs of HasInnerGraph ops (clone_inner_graph(s)=True). Now that every inner-graph op (Scan/OpFromGraph/Composite) is immutable -- clone() returns self -- that deep-clone is always a no-op, so drop the kwarg and the branch throughout the clone machinery: Apply.clone, Apply.clone_with_new_inputs, FrozenApply.clone, clone, clone_node_and_cache, clone_get_equiv, FunctionGraph.clone, and rebuild_collect_shared. clone_node_and_cache's op-clone caching (which only existed to reuse cloned inner-graph ops) is removed too. ScipyWrapperOp.clone drops the now-removed kwarg. Updates test_clone_inner_graph to the new shared-Op contract.
symbolic_op_recognition rewrites fold a pattern into an inner-graph op -- e.g. local_softmax_stabilize turns exp(x)/sum(exp(x)) into Softmax, which is itself an OpFromGraph whose inner graph is that very pattern. optimize_inner_graphs ran the full optimizer (including those rewrites) on each inner graph, so it re-created Softmax inside Softmax and recursed on the new op's inner graph without end (RecursionError compiling any softmax under FAST_RUN or NUMBA). The old link-time inner-graph optimization excluded symbolic_op_recognition; restore that exclusion across all three optimizer paths.
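The non-termination is easy to see in a toy rewriter. In this sketch expressions are plain strings, an inner-graph op is a `(name, inner_expression)` tuple, and `recognize_softmax` and `optimize` are hypothetical helpers; the depth limit exists only so the demonstration fails fast instead of exhausting the interpreter stack.

```python
PATTERN = "exp(x)/sum(exp(x))"


def recognize_softmax(expr):
    # Fold the pattern into an op whose inner graph is that same pattern.
    if expr == PATTERN:
        return ("Softmax", PATTERN)
    return expr


def optimize(expr, rewrites, inner_rewrites, depth=0):
    if depth > 20:
        raise RecursionError("inner-graph optimization did not terminate")
    for rewrite in rewrites:
        expr = rewrite(expr)
    if isinstance(expr, tuple):
        # Recurse into the inner graph of an inner-graph op.
        name, inner = expr
        inner = optimize(inner, inner_rewrites, inner_rewrites, depth + 1)
        expr = (name, inner)
    return expr


# Recognition enabled on inner graphs: Softmax is re-created inside Softmax.
try:
    optimize(PATTERN, [recognize_softmax], [recognize_softmax])
    terminated = True
except RecursionError:
    terminated = False
assert not terminated

# Recognition excluded on inner graphs: a single Softmax, pattern kept inside.
result = optimize(PATTERN, [recognize_softmax], [])
assert result == ("Softmax", PATTERN)
```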
MinimizeScalarOp/MinimizeOp/RootScalarOp/RootOp built a mutable inner FunctionGraph and kept it as self.fgraph; that left them as the last inner-graph ops not collapsed to a single frozen graph, so they were excluded from optimize_inner_graphs. Build the inner graph (plus the jac/hess outputs) mutably in __init__ as before, then freeze(dedup_nodes=False) it -- the canonical op.fgraph is now an immutable FrozenFunctionGraph that can never be mutated in place. clone() returns self, inner_inputs/outputs return lists, and eq/hash key on the frozen graph plus the per-subclass _scipy_props (so e.g. RootOps of different equations stay distinct without a hand-maintained __props__). clone_with_inner_graph bakes an optimized inner graph into a new frozen op for the rewrite. Graph manipulation (build_fn, compute_implicit_gradients, the connection pattern) runs on a fresh unfreeze()d copy: graph_replace on frozen nodes would otherwise leak them into the outer graph. Re-add ScipyWrapperOp to the optimize_inner_graphs matcher (reverting the interim exclusion). It uses the no-inplace optimizer because its inner graph is not backend-funcified but recompiled by build_fn on the scipy-wrapped graph.
Closes #2232
Closes #2028
Closes #2033
Closes #2194 by adding a method (should it be the default?) where there's no node deduplication, as the toposort index becomes part of the hash
No more "cloning" of inner graphs