Skip to content

Commit eda51d2

Browse files
Update Scan to use FrozenFunctionGraph
1 parent 3dabd86 commit eda51d2

1 file changed

Lines changed: 5 additions & 28 deletions

File tree

pytensor/scan/op.py

Lines changed: 5 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,13 @@
7474
from pytensor.graph.basic import (
7575
Apply,
7676
Variable,
77-
equal_computations,
7877
)
7978
from pytensor.graph.features import NoOutputFromInplace
8079
from pytensor.graph.op import HasInnerGraph, Op, io_connection_pattern
8180
from pytensor.graph.replace import clone_replace
8281
from pytensor.graph.traversal import graph_inputs
8382
from pytensor.graph.type import HasShape
8483
from pytensor.graph.utils import InconsistencyError, MissingInputError
85-
from pytensor.link.c.basic import CLinker
8684
from pytensor.link.vm import VMLinker
8785
from pytensor.printing import op_debug_information
8886
from pytensor.scan.utils import ScanProfileStats, Validator, forced_replace, safe_new
@@ -939,13 +937,12 @@ def tensorConstructor(shape, dtype):
939937
"Inner-graphs must not contain in-place operations."
940938
)
941939

942-
self._cmodule_key = CLinker().cmodule_key_variables(
943-
self.inner_inputs, self.inner_outputs, []
944-
)
945-
self._hash_inner_graph = hash(self._cmodule_key)
940+
self._frozen_fgraph = self.fgraph.freeze()
946941

947942
def __setstate__(self, d):
948943
self.__dict__.update(d)
944+
if not hasattr(self, "_frozen_fgraph"):
945+
self._frozen_fgraph = self.fgraph.freeze()
949946
# Ensure that the graph associated with the inner function is valid.
950947
self.validate_inner_graph()
951948

@@ -1324,27 +1321,7 @@ def __eq__(self, other):
13241321
if self.allow_gc != other.allow_gc:
13251322
return False
13261323

1327-
# Compare inner graphs
1328-
# TODO: Use `self.inner_fgraph == other.inner_fgraph`
1329-
if len(self.inner_inputs) != len(other.inner_inputs):
1330-
return False
1331-
1332-
if len(self.inner_outputs) != len(other.inner_outputs):
1333-
return False
1334-
1335-
# strict=False because length already compared above
1336-
for self_in, other_in in zip(
1337-
self.inner_inputs, other.inner_inputs, strict=False
1338-
):
1339-
if self_in.type != other_in.type:
1340-
return False
1341-
1342-
return equal_computations(
1343-
self.inner_outputs,
1344-
other.inner_outputs,
1345-
self.inner_inputs,
1346-
other.inner_inputs,
1347-
)
1324+
return self._frozen_fgraph == other._frozen_fgraph
13481325

13491326
def __str__(self):
13501327
inplace = "none"
@@ -1364,7 +1341,7 @@ def __hash__(self):
13641341
return hash(
13651342
(
13661343
type(self),
1367-
self._hash_inner_graph,
1344+
self._frozen_fgraph,
13681345
self.info,
13691346
self.profile,
13701347
self.truncate_gradient,

0 commit comments

Comments
 (0)