Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
326 changes: 195 additions & 131 deletions pytensor/compile/builders.py

Large diffs are not rendered by default.

95 changes: 92 additions & 3 deletions pytensor/graph/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import abc
import warnings
import weakref
from collections.abc import (
Hashable,
Iterable,
Expand All @@ -14,6 +15,7 @@
Any,
Generic,
Optional,
Self,
TypeVar,
Union,
cast,
Expand Down Expand Up @@ -838,6 +840,93 @@ def value(self):
return self.data


def _get_frozen_output(apply_node: "FrozenApply", index: int) -> Variable:
    """Look up the ``index``-th output of an interned ``FrozenApply`` node.

    Used as the reconstruction callable in the ``__reduce_ex__`` override
    installed on frozen output Variables, so that unpickling resolves back
    to the canonical interned object instead of creating a fresh Variable.
    """
    outputs = apply_node.outputs
    return outputs[index]


def _make_frozen_output_reduce(out: Variable):
    """Build a ``__reduce_ex__`` replacement for one ``FrozenApply`` output.

    The closure captures the owning node and the output's position rather
    than the Variable itself, so pickling round-trips through
    ``_get_frozen_output`` and lands on the interned canonical object.
    """
    node, position = out.owner, out.index

    def __reduce_ex__(protocol):
        # The protocol argument is irrelevant here: reconstruction is
        # always the same (callable, args) pair.
        return (_get_frozen_output, (node, position))

    return __reduce_ex__


class FrozenApply(Apply):
    """An immutable, globally-interned Apply node for frozen graphs.

    ``inputs`` and ``outputs`` are stored as tuples, so in-place mutation
    raises ``TypeError`` at the language level. Instances are hash-consed
    on ``(op, cache_key(inputs))`` — constructing a ``FrozenApply`` with
    the same op and equivalent input variables returns the cached node.

    Constants participate in the key by value (``(type, data_bytes)``), so
    two independently created Constants holding the same data resolve to
    the same cached node.
    """

    # Weak values: cache entries disappear once no graph references them.
    _cache: weakref.WeakValueDictionary = weakref.WeakValueDictionary()

    @staticmethod
    def _input_to_key(inp: Variable):
        """Turn one input Variable into a hashable, value-based key element.

        Non-Constant inputs (NominalVariables, FrozenApply outputs) are
        themselves globally interned, so object identity is a sound key.
        Constants are keyed on their raw bytes so that independently-created
        equal constants (NaN included) produce the same key. Object-dtype
        data (e.g. slices) stores pointers rather than values in its byte
        form, so those fall back to ``signature()``.
        """
        if not isinstance(inp, Constant):
            return inp
        arr = np.asarray(inp.data)
        if arr.dtype.kind == "O":
            return inp.signature()
        return (inp.type, arr.tobytes(), arr.dtype.str, arr.shape)

    def __new__(
        cls,
        op: "Op",
        inputs: tuple[Variable, ...],
        output_types: tuple["Type", ...],
    ):
        key = (op, tuple(map(cls._input_to_key, inputs)))
        existing = cls._cache.get(key)
        if existing is not None:
            return existing

        node = object.__new__(cls)
        node.op = op
        node.inputs = inputs  # type: ignore[assignment]
        node.outputs = tuple(  # type: ignore[assignment]
            typ.variable_type(type=typ, owner=node, index=idx)
            for idx, typ in enumerate(output_types)
        )
        # Patch each output so that unpickling resolves to the canonical
        # interned Variable rather than materializing a fresh object.
        for out in node.outputs:
            out.__reduce_ex__ = _make_frozen_output_reduce(out)  # type: ignore[method-assign]
        node.tag = Scratchpad()
        cls._cache[key] = node
        return node

    def __init__(self, op, inputs, output_types):
        # Everything happens in __new__; a cache hit must not be
        # re-initialized here.
        pass

    def clone(self, clone_inner_graph: bool = False) -> Self:
        """Return self: frozen nodes are immutable, so no copy is needed."""
        return self

    def __reduce__(self):
        # Unpickling re-runs __new__, which preserves global interning.
        return (
            type(self),
            (self.op, self.inputs, tuple(o.type for o in self.outputs)),
        )


def clone(
inputs: Sequence[Variable],
outputs: Sequence[Variable],
Expand Down Expand Up @@ -1154,14 +1243,14 @@ def equal_computations(

for x, y in zip(xs, ys, strict=True):
if not isinstance(x, Variable) and not isinstance(y, Variable):
return np.array_equal(x, y)
return np.array_equal(x, y, equal_nan=True)
if not isinstance(x, Variable):
if isinstance(y, Constant):
return np.array_equal(y.data, x)
return np.array_equal(y.data, x, equal_nan=True)
return False
if not isinstance(y, Variable):
if isinstance(x, Constant):
return np.array_equal(x.data, y)
return np.array_equal(x.data, y, equal_nan=True)
return False
x_is_owned, y_is_owned = (x.owner is not None, y.owner is not None)
if x_is_owned != y_is_owned:
Expand Down
184 changes: 182 additions & 2 deletions pytensor/graph/fg.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""A container for specifying and manipulating a graph with distinct inputs and outputs."""

import time
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Iterable, Sequence
from typing import Any, Union, cast
Expand All @@ -10,6 +11,8 @@
from pytensor.graph.basic import (
Apply,
AtomicVariable,
Constant,
NominalVariable,
Variable,
clone_get_equiv,
)
Expand All @@ -23,12 +26,25 @@
toposort_with_orderings,
vars_between,
)
from pytensor.graph.utils import MetaObject, MissingInputError, TestValueError
from pytensor.graph.utils import MissingInputError, TestValueError


ClientType = tuple[Apply, int]


class AbstractFunctionGraph(ABC):
    """Read-only interface shared by FunctionGraph and FrozenFunctionGraph."""

    # Structural attributes that every function-graph implementation exposes.
    inputs: Sequence[Variable]
    outputs: Sequence[Variable]
    apply_nodes: set[Apply]
    variables: set[Variable]
    clients: dict[Variable, list[ClientType]]

    @abstractmethod
    def toposort(self) -> list[Apply]:
        """Return the graph's Apply nodes in a dependency-respecting order."""


class Output(Op):
"""A dummy `Op` that represents an output variable in a `FunctionGraph`."""

Expand All @@ -47,7 +63,7 @@ def __str__(self):
return f"output[{self.idx}]"


class FunctionGraph(MetaObject):
class FunctionGraph(AbstractFunctionGraph):
r"""
A `FunctionGraph` represents a subgraph bound by a set of input variables and
a set of output variables, ie a subgraph that specifies an PyTensor function.
Expand Down Expand Up @@ -928,3 +944,167 @@ def dprint(self, **kwargs):
from pytensor.printing import debugprint

return debugprint(self, **kwargs)

def freeze(self) -> "FrozenFunctionGraph":
    """Build an immutable, hashable snapshot of this FunctionGraph.

    Equivalent to ``FrozenFunctionGraph(self.inputs, self.outputs)``.
    """
    frozen = FrozenFunctionGraph(self.inputs, self.outputs)
    return frozen


class FrozenFunctionGraph(AbstractFunctionGraph):
    """Immutable, hashable function graph for inner graphs of Ops.

    All internal nodes are globally interned via ``FrozenApply``. Two
    ``FrozenFunctionGraph`` instances built from structurally identical source
    graphs share the same interned output objects, so equality reduces to
    identity comparison on the outputs tuple.

    Use ``FunctionGraph.freeze()`` or ``FrozenFunctionGraph(inputs, outputs)``
    to create instances.

    .. code-block:: python

        from pytensor.scalar.basic import float64, add
        from pytensor.graph.fg import FunctionGraph

        x, y = float64("x"), float64("y")
        frozen = FunctionGraph([x, y], [add(x, y)]).freeze()
        frozen2 = FunctionGraph([x, y], [add(x, y)]).freeze()

        assert frozen == frozen2
        assert {frozen: "value"}[frozen2] == "value"
    """

    def __init__(
        self,
        inputs: Sequence[Variable],
        outputs: Sequence[Variable],
    ):
        # Imported here to avoid an import cycle with pytensor.graph.basic.
        from pytensor.graph.basic import FrozenApply

        # Inputs are replaced by positional NominalVariables so that two
        # structurally identical graphs intern to identical nodes regardless
        # of the original input objects' identities or names.
        nominal_inputs = tuple(
            NominalVariable(i, inp.type, name=inp.name) for i, inp in enumerate(inputs)
        )

        # Maps original variables to their frozen counterparts.
        memo: dict[Variable, Variable] = dict(zip(inputs, nominal_inputs, strict=True))

        for node in toposort(outputs, blockers=inputs):
            for inp in node.inputs:
                if inp not in memo:
                    # Orphans are only allowed when they carry their own
                    # value (Constants) or are themselves globally interned
                    # (AtomicVariables); both map through unchanged.
                    if isinstance(inp, Constant | AtomicVariable):
                        memo[inp] = inp
                    else:
                        raise ValueError(
                            f"Non-Constant, non-AtomicVariable orphan {inp} found "
                            "in the graph. All variables must be graph inputs, "
                            "Constants, AtomicVariables, or produced by Apply "
                            "nodes reachable from the inputs."
                        )

            new_inputs = tuple(memo[i] for i in node.inputs)
            output_types = tuple(out.type for out in node.outputs)
            new_node = FrozenApply(node.op, new_inputs, output_types)

            memo.update(zip(node.outputs, new_node.outputs, strict=True))

        # Handle outputs that are Constants or AtomicVariables not
        # encountered during toposort (e.g. a graph with no Apply nodes)
        for o in outputs:
            if o not in memo and isinstance(o, Constant | AtomicVariable):
                memo[o] = o

        try:
            frozen_outputs = tuple(memo[o] for o in outputs)
        except KeyError:
            unmapped = [o for o in outputs if o not in memo]
            # ``from None``: the internal KeyError is an implementation
            # detail and would only obscure the user-facing message.
            raise ValueError(
                f"Output variable {unmapped[0]} could not be mapped to a frozen "
                "graph variable. All outputs must be graph inputs, "
                "constants, or produced by Apply nodes reachable from "
                "the inputs."
            ) from None

        self.inputs: tuple[Variable, ...] = nominal_inputs
        self.outputs: tuple[Variable, ...] = frozen_outputs
        # Lazily-computed caches; safe to memoize because the graph is
        # immutable after construction.
        self._variables: set[Variable] | None = None
        self._apply_nodes: set[Apply] | None = None
        self._clients: dict[Variable, list[ClientType]] | None = None
        self._toposort: list[Apply] | None = None

    def __reduce__(self):
        # Rebuilding from (inputs, outputs) re-interns every node, so
        # unpickled graphs share objects with equivalent live graphs.
        return FrozenFunctionGraph, (self.inputs, self.outputs)

    def __hash__(self):
        # Interning guarantees structurally equal graphs share output
        # objects, so hashing the outputs tuple is consistent with __eq__.
        return hash(self.outputs)

    def __eq__(self, other):
        if self is other:
            return True
        if not isinstance(other, FrozenFunctionGraph):
            return False
        return self.inputs == other.inputs and self.outputs == other.outputs

    def __repr__(self):
        return f"FrozenFunctionGraph(inputs={list(self.inputs)}, outputs={list(self.outputs)})"

    def __copy__(self):
        # Immutable, so copying returns the same instance.
        return self

    def __deepcopy__(self, memo):
        # Immutable, so deep-copying also returns the same instance.
        return self

    @property
    def apply_nodes(self) -> set[Apply]:  # type: ignore[override]
        """All Apply nodes between inputs and outputs (lazily computed)."""
        if self._apply_nodes is None:
            self._apply_nodes = set(applys_between(self.inputs, self.outputs))
        return self._apply_nodes

    def toposort(self) -> list[Apply]:
        """Apply nodes in dependency order (lazily computed and cached)."""
        if self._toposort is None:
            self._toposort = list(toposort(self.outputs, blockers=self.inputs))
        return self._toposort

    @property
    def variables(self) -> set[Variable]:  # type: ignore[override]
        """All variables between inputs and outputs (lazily computed)."""
        if self._variables is None:
            self._variables = set(vars_between(self.inputs, self.outputs))
        return self._variables

    @property
    def clients(self) -> dict[Variable, list[ClientType]]:  # type: ignore[override]
        """Map each variable to the ``(node, input_index)`` pairs using it."""
        if self._clients is None:
            clients: dict[Variable, list[ClientType]] = {v: [] for v in self.inputs}
            for node in self.toposort():
                for i, inp in enumerate(node.inputs):
                    clients.setdefault(inp, []).append((node, i))
                for out in node.outputs:
                    clients.setdefault(out, [])
            self._clients = clients
        return self._clients

    def unfreeze(self) -> "FunctionGraph":
        """Return a mutable FunctionGraph with fresh mutable Apply nodes."""
        memo: dict[Variable, Variable] = {inp: inp.type() for inp in self.inputs}

        for node in self.toposort():
            for i in node.inputs:
                if i not in memo:
                    # AtomicVariables are shareable as-is; anything else gets
                    # a fresh clone so the result is fully mutable.
                    memo[i] = i if isinstance(i, AtomicVariable) else i.clone()
            new_node = Apply(
                node.op,
                [memo[i] for i in node.inputs],
                [o.type() for o in node.outputs],
            )
            memo.update(zip(node.outputs, new_node.outputs, strict=True))

        # Outputs not produced by any Apply node (e.g. a Constant passed
        # straight through a graph with no Apply nodes) never entered the
        # memo above; map them through unchanged so indexing below cannot
        # raise KeyError. Mirrors the passthrough handling in __init__.
        for o in self.outputs:
            if o not in memo:
                memo[o] = o

        new_inputs = [memo[i] for i in self.inputs]
        new_outputs = [memo[o] for o in self.outputs]
        return FunctionGraph(new_inputs, new_outputs, clone=False)
14 changes: 8 additions & 6 deletions pytensor/graph/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

if TYPE_CHECKING:
from pytensor.compile.function.types import Function
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.fg import AbstractFunctionGraph, FunctionGraph
from pytensor.graph.type import Type

StorageCellType = list[Any | None]
Expand Down Expand Up @@ -459,16 +459,18 @@ def perform(
"""

def do_constant_folding(self, fgraph: "FunctionGraph", node: Apply) -> bool:
"""Determine whether or not constant folding should be performed for the given node.
"""Determine whether constant folding should be performed for the given node.

This allows each `Op` to determine if it wants to be constant
folded when all its inputs are constant. This allows it to choose where
it puts its memory/speed trade-off. Also, it could make things faster
as constants can't be used for in-place operations (see
``*IncSubtensor``).
as constants can't be used for in-place operations (see ``*IncSubtensor``).

Parameters
----------
fgraph : FunctionGraph
Function graph to which `node` belongs. This is passed in case the `Op` needs to inspect the graph to make
its decision.
node : Apply
The node for which the constant folding determination is made.

Expand Down Expand Up @@ -633,8 +635,8 @@ def perform(self, node, inputs, output_storage):
class HasInnerGraph(ABC):
r"""A mixin for an `Op` that contain an inner graph."""

fgraph: "FunctionGraph"
"""A `FunctionGraph` of the inner function."""
fgraph: "AbstractFunctionGraph"
"""The inner function graph (FunctionGraph or FrozenFunctionGraph)."""

@property
@abstractmethod
Expand Down
Loading
Loading