Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 80 additions & 17 deletions pytensor/compile/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,17 +338,36 @@ def __init__(

self.is_inline = inline

self.fgraph, self.shared_inputs, _, _ = construct_nominal_fgraph(
inner_fgraph, self.shared_inputs, _, _ = construct_nominal_fgraph(
inputs, outputs
)
self._frozen_fgraph = self.fgraph.freeze()
# Keep only the immutable (frozen) inner graph as ``op.fgraph``; the
# mutable copy is transient, so the canonical inner graph can never be
# mutated in place. ``dedup_nodes=False``: inner graphs may carry inplace
# ops whose destroyed buffers must stay distinct; structural folding
# would alias them. See ``FunctionGraph.freeze``.
self.fgraph = inner_fgraph.freeze(dedup_nodes=False)

if strict and self.shared_inputs:
raise ValueError(
"All variables needed to compute inner-graph must be provided as inputs under strict=True. "
f"The inner-graph implicitly depends on the following shared variables {self.shared_inputs}"
)

# `compile_kwargs` used to control how the inner graph was compiled.
# That is now the job of the `optimize_inner_graphs` rewrite (which
# inherits the outer compilation), so they are deprecated AND ignored:
# the inner function is compiled with default settings (see `fn`).
# `on_unused_input` is exempt: tolerating unused inputs is now the
# default behavior, so passing it is a harmless no-op (not warned).
deprecated_kwargs = {k for k in kwargs if k != "on_unused_input"}
if deprecated_kwargs:
warnings.warn(
"Passing `compile_kwargs` to `OpFromGraph` is deprecated and "
"now ignored: the inner graph inherits the outer compilation. "
f"Ignored: {sorted(deprecated_kwargs)}.",
FutureWarning,
)
self.kwargs = kwargs
self.input_types = [inp.type for inp in inputs]
self.output_types = [out.type for out in outputs]
Expand Down Expand Up @@ -415,7 +434,7 @@ def _freeze_override(self, override, make_dummy_args):
if override is None:
return None
if isinstance(override, OpFromGraph):
return override._frozen_fgraph
return override.fgraph

all_inputs, callable_args = make_dummy_args()

Expand Down Expand Up @@ -477,7 +496,7 @@ def __eq__(self, other):
if type(self) is not type(other):
return False
if (
self._frozen_fgraph != other._frozen_fgraph
self.fgraph != other.fgraph
or self.is_inline != other.is_inline
or self.destroy_map != other.destroy_map
or len(self.shared_inputs) != len(other.shared_inputs)
Expand All @@ -501,7 +520,7 @@ def __eq__(self, other):
)

def __hash__(self):
return hash((type(self), self._frozen_fgraph, self.is_inline))
return hash((type(self), self.fgraph, self.is_inline))

def __str__(self):
name = self.__class__.__name__ if self.name is None else self.name
Expand Down Expand Up @@ -560,8 +579,13 @@ def _build_and_cache_lop_op(
except KeyError:
pass

inner_inputs = self.inner_inputs
inner_outputs = self.inner_outputs
# Differentiate a thawed copy of the inner graph so ``grad`` walks
# mutable ``Apply`` nodes rather than the immutable ``FrozenApply`` nodes
# of ``self.fgraph`` (whose tuple inputs/outputs break Ops that
# concatenate them, e.g. ``Blockwise.pullback``).
unfrozen_fgraph = self.fgraph.unfreeze()
inner_inputs = list(unfrozen_fgraph.inputs)
inner_outputs = list(unfrozen_fgraph.outputs)
nin = len(inner_inputs)
nout = len(inner_outputs)
pullback_overrides = self.pullback_overrides
Expand Down Expand Up @@ -684,8 +708,10 @@ def _build_and_cache_rop_op(self):
if self._rop_op_cache is not None:
return self._rop_op_cache

inner_inputs = self.inner_inputs
inner_outputs = self.inner_outputs
# Thaw the inner graph before differentiating (see ``_build_and_cache_lop_op``).
unfrozen_fgraph = self.fgraph.unfreeze()
inner_inputs = list(unfrozen_fgraph.inputs)
inner_outputs = list(unfrozen_fgraph.outputs)
nout = len(inner_outputs)
pushforward_overrides = self.pushforward_overrides

Expand Down Expand Up @@ -919,25 +945,62 @@ def fn(self):
if getattr(self, "_fn", None) is not None:
return self._fn

kwargs = self.kwargs.copy()
mode = get_mode(kwargs.pop("mode", None)).excluding("symbolic_op_recognition")
self._fn = function(self.inner_inputs, self.inner_outputs, mode=mode, **kwargs)
# `compile_kwargs` are deprecated and ignored; compile the (already
# optimized) inner graph with default settings. Inner graphs commonly
# have unused inputs (e.g. rng, size), which are tolerated by default.
# They may also carry inplace ops baked in by ``optimize_inner_graphs``
# (or built inplace, e.g. ``FusedElemwise``); those only ever destroy
# internal buffers (inputs stay protected), so we accept them.
mode = get_mode(None).excluding("symbolic_op_recognition")
self._fn = function(
self.inner_inputs,
self.inner_outputs,
mode=mode,
on_unused_input="ignore",
accept_inplace=True,
)
self._fn.trust_input = True

return self._fn

@property
def inner_inputs(self):
return self.fgraph.inputs
# A list (not the frozen tuple) so callers that concatenate inner
# inputs/outputs keep list semantics. Read-only views of the immutable
# graph; manipulating them requires a fresh/unfrozen graph.
return list(self.fgraph.inputs)

@property
def inner_outputs(self):
return self.fgraph.outputs
return list(self.fgraph.outputs)

def clone(self):
res = copy(self)
res.fgraph = res.fgraph.clone(clone_inner_graphs=True)
return res
# The inner graph is immutable (a frozen ``FunctionGraph``), so there is
# nothing to deep-clone -- mirror ``Composite.clone``.
return self

def clone_with_inner_graph(self, inner_fgraph: FunctionGraph) -> OpFromGraph:
"""Return a copy of this op whose inner graph is ``inner_fgraph``.

Used by the ``optimize_inner_graphs`` rewrite to build a new op
carrying an already-optimized inner graph without ever mutating
``self``. The subclass and all properties/overrides are preserved
(via ``copy``); only the inner graph and the state derived from it are
rebuilt.
"""
new = copy(self)
new_fgraph, new.shared_inputs, _, _ = construct_nominal_fgraph(
list(inner_fgraph.inputs), list(inner_fgraph.outputs)
)
new.fgraph = new_fgraph.freeze(dedup_nodes=False)
new.input_types = [inp.type for inp in new.fgraph.inputs]
new.output_types = [out.type for out in new.fgraph.outputs]
# Drop caches tied to the previous inner graph.
new._lop_op_cache = {}
new._rop_op_cache = None
new._frozen_lop = None
new._frozen_rop = None
return new

def perform(self, node, inputs, outputs):
variables = self.fn(*inputs)
Expand Down
183 changes: 183 additions & 0 deletions pytensor/compile/inner_graph_rewriting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
"""Rewrite that optimizes the inner graphs of inner-graph ops once.

Historically each linker/``make_thunk`` optimized the inner ``FunctionGraph``
of ``Scan``/``OpFromGraph``/``ScipyWrapperOp`` lazily and *in place* on the
canonical (shared) op. That mutated graphs that are supposed to be immutable
and leaked backend-specific rewrites onto ops reused elsewhere.

Instead, :class:`OptimizeInnerGraphs` runs once near the end of outer
optimization. For each unique inner-graph op it unfreezes the canonical inner
graph (a throwaway copy), optimizes that copy with the *same* outer query
(transferred via :func:`get_active_rewriter`), and builds a NEW immutable op
via ``clone_with_inner_graph`` -- the canonical op is never touched. Linkers
then simply funcify the already-optimized inner graph.
"""

# Most imports are deferred into ``apply`` to avoid an import cycle: this module
# is imported by ``pytensor.compile.mode`` at startup, while the ops below pull
# in ``mode`` transitively.
from collections import defaultdict

from pytensor.graph.rewriting.basic import GraphRewriter, get_active_mode


def _scan_destroyable_signature(op, node):
"""Hashable signature of which taps are destroyable for this outer node.

Two ``Scan`` nodes sharing an op can share a baked inner graph iff they agree
on this (see ``Scan.inner_destroyable_inputs``). Untraced sit_sot are always
destroyable and mit_mot never, so only the sit_sot / mit_sot masks vary.
"""
sitsot_mask = tuple(
outer.type.shape[0] == 1 for outer in op.outer_sitsot(node.inputs)
)
mitsot_mask = tuple(
outer.type.shape[0] == abs(min(taps))
for outer, taps in zip(
op.outer_mitsot(node.inputs), op.info.mit_sot_in_slices, strict=True
)
)
return (sitsot_mask, mitsot_mask)


class OptimizeInnerGraphs(GraphRewriter):
"""Optimize inner graphs of inner-graph ops, producing new immutable ops."""

def add_requirements(self, fgraph):
pass

def apply(self, fgraph):
from pytensor.compile.aliasing import add_supervisor_to_fgraph
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.io import In
from pytensor.compile.mode import get_mode
from pytensor.configdefaults import config
from pytensor.graph.features import NoOutputFromInplace
from pytensor.link.basic import JITLinker
from pytensor.scan.op import Scan
from pytensor.tensor.optimize import ScipyWrapperOp
from pytensor.tensor.random.op import OpWithCoreShape

inner_graph_types: tuple = (OpFromGraph, Scan, ScipyWrapperOp)

# Group apply NODES by an optimization key, so each distinct inner graph
# is optimized once and shared across its nodes (an op reused in many
# nodes -> one optimization; cache reuse holds). The key is the op,
# except for ``Scan``: its tap-inplace depends on the *outer* node's
# buffer shapes, so two nodes sharing a ``Scan`` op but with different
# destroyable taps need different baked inner graphs -- the key then also
# carries the destroyable signature.
groups: dict = defaultdict(list)
for node in fgraph.apply_nodes:
op = node.op
if not isinstance(op, inner_graph_types) or isinstance(op, OpWithCoreShape):
# ``*WithCoreShape`` are leaf backend ops with dedicated dispatch;
# re-optimizing/re-wrapping them would loop.
continue
if isinstance(op, Scan):
groups[(op, _scan_destroyable_signature(op, node))].append(node)
else:
groups[op].append(node)
if not groups:
return

# Optimize each inner graph with the *active outer compilation* (mode +
# linker), recovered reliably from the fgraph (``config.mode`` gets
# coerced across nested compilations). The query includes
# ``optimize_inner_graphs``, so nested inner-graph ops recurse -- we
# propagate the mode onto each inner graph so the recursion recovers it.
mode = get_active_mode(fgraph)
# Inplace is baked into the frozen op (frozen ``dedup_nodes=False`` so
# distinct inplace buffers survive interning). ``OpFromGraph`` must not
# mutate its inputs, so all are protected; ``Scan`` protects all but the
# destroyable taps. ``ScipyWrapperOp`` uses the no-inplace optimizer:
# its inner graph is not backend-funcified but recompiled by
# ``build_fn`` (which applies inplace itself on the scipy-wrapped graph),
# so baking inplace into the frozen op here would be pointless.
# Exclude ``symbolic_op_recognition``: those rewrites fold a pattern into
# an inner-graph op (e.g. ``exp(x) / sum(exp(x))`` -> ``Softmax``, itself
# an ``OpFromGraph``). Running them on the inner graph -- which *is* that
# pattern -- would re-create the op, and optimizing its inner graph would
# recurse without end. The old link-time inner optimization excluded it
# too.
inplace_optimizer = mode.excluding("symbolic_op_recognition").optimizer
noinplace_optimizer = mode.excluding(
"inplace", "symbolic_op_recognition"
).optimizer

node_to_new_op: dict = {}
for nodes in groups.values():
rep_node = nodes[0]
op = rep_node.op
inner = op.fgraph.unfreeze()
inner._compile_mode = mode

custom_mode = getattr(op, "mode", None)
if isinstance(op, OpFromGraph):
input_specs = [In(x, borrow=True, mutable=False) for x in inner.inputs]
add_supervisor_to_fgraph(
fgraph=inner, input_specs=input_specs, accept_inplace=True
)
optimizer = inplace_optimizer
elif isinstance(op, Scan):
destroyable = op.inner_destroyable_inputs(rep_node.inputs, inner.inputs)
input_specs = [
In(x, borrow=True, mutable=x in destroyable) for x in inner.inputs
]
add_supervisor_to_fgraph(
fgraph=inner, input_specs=input_specs, accept_inplace=True
)
# Protect the same tap outputs ``Scan.prepare_fgraph`` does, so the
# inplace baked here never makes a protected output the result of a
# destroy-map node. Otherwise the CVM/Python ``Scan.fn`` path, which
# re-attaches ``NoOutputFromInplace`` and re-optimizes, would refuse
# the already-baked graph (``InconsistencyError`` under
# ``on_opt_error=raise``). This protection is specific to the
# non-jit linkers: jit backends (numba/jax) rely on inner tap
# inplace and copy outputs themselves (``insert_deepcopy``), and
# ``prepare_fgraph`` only installs it when output preallocation is
# enabled -- match both conditions so the baked graph agrees with
# whatever ``Scan.fn`` would produce.
if (
not isinstance(mode.linker, JITLinker)
and config.scan__allow_output_prealloc
):
inner.attach_feature(
NoOutputFromInplace(op.protected_inner_out_idxs())
)
if custom_mode is not None:
# A custom Scan ``mode`` is still honored (deprecated; warned
# at op creation), combined with the active linker's
# required/incompatible rewrites so backend must-have ops
# still apply.
linker = mode.linker
optimizer = (
get_mode(custom_mode)
.including(*linker.required_rewrites)
.excluding(
*linker.incompatible_rewrites, "symbolic_op_recognition"
)
.optimizer
)
else:
optimizer = inplace_optimizer
else:
optimizer = noinplace_optimizer

optimizer.rewrite(inner)
new_op = op.clone_with_inner_graph(inner)
if new_op != op:
for node in nodes:
node_to_new_op[node] = new_op

if not node_to_new_op:
return

for node in fgraph.toposort():
new_op = node_to_new_op.get(node)
if new_op is not None:
new_node = new_op.make_node(*node.inputs)
fgraph.replace_all(
list(zip(node.outputs, new_node.outputs, strict=True)),
reason="optimize_inner_graphs",
)
9 changes: 8 additions & 1 deletion pytensor/compile/maker.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,14 @@ def prepare_fgraph(
mode=mode,
traceback__limit=config.traceback__compile_limit,
):
rewriter_profile = rewriter(fgraph)
# Expose the compile mode so inner-graph rewrites can recover
# the active linker's required/incompatible rewrites reliably
# (``config.mode`` is unreliable across nested compilations).
fgraph._compile_mode = mode
try:
rewriter_profile = rewriter(fgraph)
finally:
fgraph._compile_mode = None

end_rewriter = time.perf_counter()
rewrite_time = end_rewriter - start_rewriter
Expand Down
16 changes: 16 additions & 0 deletions pytensor/compile/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,22 @@ def apply(self, fgraph):
"add_destroy_handler", AddDestroyHandler(), "fast_run", "inplace", position=49.5
)

# Optimize the inner graphs of inner-graph ops (Scan/OpFromGraph/...) once,
# producing new immutable ops. Runs for every backend (so the C/VM and Python
# linkers get inner-graph optimization too) and is required by the JIT linkers
# (see their ``required_rewrites``). The registration name doubles as its tag.
# Imported here (after `optdb` is defined) to avoid an import cycle.
from pytensor.compile.inner_graph_rewriting import OptimizeInnerGraphs # noqa: E402


optdb.register(
"optimize_inner_graphs",
OptimizeInnerGraphs(),
"fast_run",
"fast_compile",
position=49.6,
)

# final pass just to make sure
optdb.register("merge3", MergeOptimizer(), "fast_run", "merge", position=100)

Expand Down
Loading
Loading