7474from pytensor .graph .basic import (
7575 Apply ,
7676 Variable ,
77- equal_computations ,
7877)
7978from pytensor .graph .features import NoOutputFromInplace
8079from pytensor .graph .op import HasInnerGraph , Op , io_connection_pattern
8180from pytensor .graph .replace import clone_replace
8281from pytensor .graph .traversal import graph_inputs
8382from pytensor .graph .type import HasShape
8483from pytensor .graph .utils import InconsistencyError , MissingInputError
85- from pytensor .link .c .basic import CLinker
8684from pytensor .link .vm import VMLinker
8785from pytensor .printing import op_debug_information
8886from pytensor .scan .utils import ScanProfileStats , Validator , forced_replace , safe_new
@@ -939,13 +937,12 @@ def tensorConstructor(shape, dtype):
939937 "Inner-graphs must not contain in-place operations."
940938 )
941939
942- self ._cmodule_key = CLinker ().cmodule_key_variables (
943- self .inner_inputs , self .inner_outputs , []
944- )
945- self ._hash_inner_graph = hash (self ._cmodule_key )
940+ self ._frozen_fgraph = self .fgraph .freeze ()
946941
947942 def __setstate__ (self , d ):
948943 self .__dict__ .update (d )
944+ if not hasattr (self , "_frozen_fgraph" ):
945+ self ._frozen_fgraph = self .fgraph .freeze ()
949946 # Ensure that the graph associated with the inner function is valid.
950947 self .validate_inner_graph ()
951948
@@ -1324,27 +1321,7 @@ def __eq__(self, other):
13241321 if self .allow_gc != other .allow_gc :
13251322 return False
13261323
1327- # Compare inner graphs
1328- # TODO: Use `self.inner_fgraph == other.inner_fgraph`
1329- if len (self .inner_inputs ) != len (other .inner_inputs ):
1330- return False
1331-
1332- if len (self .inner_outputs ) != len (other .inner_outputs ):
1333- return False
1334-
1335- # strict=False because length already compared above
1336- for self_in , other_in in zip (
1337- self .inner_inputs , other .inner_inputs , strict = False
1338- ):
1339- if self_in .type != other_in .type :
1340- return False
1341-
1342- return equal_computations (
1343- self .inner_outputs ,
1344- other .inner_outputs ,
1345- self .inner_inputs ,
1346- other .inner_inputs ,
1347- )
1324+ return self ._frozen_fgraph == other ._frozen_fgraph
13481325
13491326 def __str__ (self ):
13501327 inplace = "none"
@@ -1364,7 +1341,7 @@ def __hash__(self):
13641341 return hash (
13651342 (
13661343 type (self ),
1367- self ._hash_inner_graph ,
1344+ self ._frozen_fgraph ,
13681345 self .info ,
13691346 self .profile ,
13701347 self .truncate_gradient ,
0 commit comments