Allow freezing of FunctionGraph for hashing#1908
Allow freezing of FunctionGraph for hashing#1908jessegrabowski wants to merge 7 commits intopymc-devs:mainfrom
Conversation
ricardoV94
left a comment
There was a problem hiding this comment.
Why did you not go all out?
If you already deduplicate and do internal hash-cons you are one step away from getting hashing for free across different FunctionGraphs. Just do the hash-cons globally. Then FrozenFunctionGrahp([x, y], [foo(x, y)] is equal to another functiongraph if and only if fgraph.outputs == other_fgraph.outputs. No need for recursive hashing or expensive equal_computations.
As it stands you are not doing much better sneaking a default MergeOptimizer at __init__ and adding a FunctionGraph class that has no replace mode.
And cheap hashing/ equality is not just a nice to have, it's really valuable to not slow down compilation. In some of my benchmarks on previous work, some graphs could spend inordinate time on equality checks.
Comments regardless of whether we go:
- Don't create
FrozenFunctionGraphas a subclass ofFrozenGraph, let's push the general principle, shared abstract classes, no-subclass of actually realized objects. Then you don't needcheck_frozen, the methods just don't exist for the frozen subclass. - You could create a frozenApply that uses
tuplefor input/outputs instead oflist. That will help ensuring the immutability because all our current rewrite machinery works on the idea of overriding entries in those lists. Accidentally trying to mutate a graph would 99% fail there.
305e26e to
08609fe
Compare
There was a problem hiding this comment.
This is starting to look good, how are you feeling about it?
Notes:
- Add a
FrozenFunctionGraph.unfreeze(), that yields aFunctionGraph? - Really try to avoid the FrozenConstant stuff
- Ops with inner graph (at least the ones you touched now) should only have a FrozenFunctionGraph internally (not a mutable one as well). Maybe that's already the case.
We need some follow-up issues open:
- Optimizing OpFromGraph: There should be an explicit rewrite that creates a new OpFromGraph with its updated frozen graph, (so it is also reflected immediately in dprint). We should never do any further rewrites of the internal fgraph during compilation.
- Scan/Minimize/Root: Use the new FrozenFunctionGraph as well. This should immediately address #1601
- When compiling OpFromGraph in jitted contexts we should try to avoid recreating inner numba/jax functions when the same OFG is compiled multiple times in a function, this will likely speedup compilation. In the C-backend that already happens due to the caching of
_fn. That's how we can deliver on the promised compilations speedups and it's specially relevant for a library likepytensor-mlthat may want to chains hundreds of the same "LayerOp"s in sequence
| @@ -4140,38 +4116,17 @@ def prepare_node(self, node, storage_map, compute_map, impl): | |||
| def __eq__(self, other): | |||
| if self is other: | |||
There was a problem hiding this comment.
can't we have regular __props__ based equality/hashing now?
There was a problem hiding this comment.
@jessegrabowski this still stands. With proper fgraph equality, we could have these inner graph Ops behave like other Ops based on __props__ (simpler mental model for devs). __props__ = ("fgraph",) (and whatever else not in the fgraph that influences behavior) ?
78ee1a9 to
eda51d2
Compare
|
I left some comments as I checked the changes. I need to think/discuss a bit about the spec thing, and the desire to have a consistent hashing across runtimes. If you remove that the complexity of this PR drops quite a bit, but maybe this is also fine. Can you confirm this was only needed for the C-backend, and that it would also work if whatever relies on that called something like Besides that this PR look amazing, and it's a game changer to working with inner graph ops. We really need those to work well |
eda51d2 to
7202ca3
Compare
|
I removed the spec stuff and simplified the PR down somewhat. |
4a7bea8 to
445731f
Compare
| out.__reduce_ex__ = _make_frozen_output_reduce(out) # type: ignore[method-assign] | ||
| instance.tag = Scratchpad() | ||
| cls._cache[cache_key] = instance | ||
| return instance |
There was a problem hiding this comment.
Do we need the frozenapply to be hash-consed? Isn't it enough if the input/output variables are? Wondering if we can remove some extra code that way. The Apply doesn't do much anyway
| assert op1 != op_different | ||
|
|
||
| # inline flag participates in equality | ||
| op_inline = OpFromGraph([x, y], [e], inline=True) |
There was a problem hiding this comment.
this is probably correct, but it's something that we have to think about. Some Ops have properties that affect their behavior in pytensor but not the computational meaning.
Those in theory should not be part of __props__ and affect equality. This would allow MergeOptimizer to merge the nodes with the same inputs but different OFG, which I guess is what we would want for the final compiled graph (e.g., it doesn't matter if node1 has a different gradient than node2 when you compile it at the end).
But on the other hand we don't want to confuse the two Ops, because when we "freeze -> unfreeze" for example, we don't want to lose those attributes.
We need to work on this in the future. There are different degrees of "equality" we need for different things. And maybe custom __eq__ even in the presence of __props__ is what we need, but I don't thing MergeOptimizer looks at __props__ specifically.
| # OFG is hashable, and different OFGs have different hashes | ||
| assert hash(op1) != hash(op_inline) | ||
|
|
||
| def test_equality_shared_variables(self): |
There was a problem hiding this comment.
This special behavior of shared variables is something I want to get rid of already for v3, but fine to test here as it's still a thing
| ofg_nodes = [n for n in fg.toposort() if isinstance(n.op, OpFromGraph)] | ||
| assert len(ofg_nodes) == 1 | ||
|
|
||
| # Different inputs are different graphs, so both nodes survive |
There was a problem hiding this comment.
Wondering if in the rebuild_collect_shared that's used at the beginning of the function compilation, we will still merge the Op (if not the node ofc). Because that would be nice, only one compilation instead of 2. Just a curiosity, not something that needs to be tested here
|
I think with this PR we'll stop seeing what was the actual optimized inner graph in the compiled function? Something we should follow up with, optimizing inner graph should be an explicit rewrite, not something that happens at make_node / dispatch time |
445731f to
00776d0
Compare
00776d0 to
c7bce17
Compare
c7bce17 to
f404421
Compare
Closes #1606
LLM disclosure: this PR made heavy use of Claude in the planning and first cut stages, though I was heavily involved. Still, the code should be subject to extra scrutiny as a result.
The purpose of the PR is to refactor Ops with inner graphs to allow comparison. The linked issue has an exhaustive discussion of the factors at play. There was an attempt in the aesara days to attack this, but it was perhaps too aggressive: it cons-hashed all Apply nodes, which necessitated changes across the codebase. @ricardoV94 suggested a weakref dict approach for subgraphs. This is implemented at the Op level. The plan is for Ops that have inner graphs (
Composite,ScalarLoop,Scan,OpFromGraph, etc) to have a_cacheclass attribute, and implement the op-specific logic for caching, pickling, unpickling, etc. It didn't look super generalizable to me at first blush, but we can argue about it maybe.Changes to
FunctionGraph:FunctionGraphnow has a methodfreezethat returns aFrozenFunctionGraph.FrozenFunctionGraphdoes cons-hashing of Apply nodes within its scope onlyFrozenFunctionGraphswith the same inner graph with evaluate to equal, but theirApplynodes won't be references to the same objects (this is the "conservatism" of my approach)Specific implementation details:
structural_hashof aFrozenFunctionGraphis built from a list of 3-tuples:(name, type, inputs), plus the outputs. For constants,inputsis replaced with the hash of the input data.FrozenFunctionGraphsis done by comparing hashes, then falling back toequal_computationif the hash misses.A consequence of the cons-hashing in this approach is that the inner graph is de-duplicated when we call
fg.freeze(). So aMergeOptimizerpass is no longer required. Usage is demonstrated on theCompositeOp. If we like the approach I can move forward with refactoring other Ops, but I wanted to stop here and discuss the approach.Code example:
Result: