Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions pytensor/graph/replace.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import warnings
from collections.abc import Iterable, Mapping, Sequence
from functools import partial, singledispatch
from functools import singledispatch
from typing import cast, overload

from pytensor.graph.basic import (
Expand Down Expand Up @@ -169,6 +169,9 @@ def graph_replace(
fg_replace = {equiv[c]: c for c in conditions}
# add the replacements on top of input mappings
fg_replace.update({equiv[r]: v for r, v in replace_dict.items() if r in equiv})
# Filter out replacements whose keys are not in the FunctionGraph
# This can happen when a replacement makes an ancestor replacement redundant
fg_replace = {k: v for k, v in fg_replace.items() if k in fg.variables}
# replacements have to be done in reverse topological order so that nested
# expressions get recursively replaced correctly

Expand All @@ -183,11 +186,13 @@ def graph_replace(
toposort = fg.toposort()

def toposort_key(
fg: FunctionGraph, ts: list[Apply], pair: tuple[Variable, Variable]
pair: tuple[Variable, Variable],
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any reason to re-order the inputs?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got rid of the functools.partial

toposort=toposort,
fg=fg,
) -> int:
key, _ = pair
if key.owner is not None:
return ts.index(key.owner)
if (node := key.owner) is not None:
return toposort.index(node) # type: ignore[no-any-return]
else:
if key in fg.variables:
return -1
Expand All @@ -197,7 +202,7 @@ def toposort_key(
sorted_replacements = sorted(
fg_replace.items(),
# sort based on the fg toposort, if a variable has no owner, it goes first
key=partial(toposort_key, fg, toposort),
key=toposort_key,
reverse=True,
)
fg.replace_all(sorted_replacements, import_missing=True)
Expand Down
127 changes: 74 additions & 53 deletions tests/graph/test_replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from pytensor.tensor import dvector, fvector, vector
from tests import unittest_tools as utt
from tests.graph.utils import MyOp, MyVariable, op_multiple_outputs
from tests.unittest_tools import assert_equal_computations


class TestCloneReplace:
Expand Down Expand Up @@ -144,92 +145,112 @@ def test(x, y, mention_y):

class TestGraphReplace:
def test_graph_replace(self):
op = MyOp("op")
x = MyVariable("x")
y = MyVariable("y")
z = MyVariable("z")
w = MyVariable("w")
MyOp("zop")(z)
x2 = MyOp("xop")(x, w)
x2.name = "x2"
y2 = MyOp("yop")(y)
y2.name = "y2"

yc = graph_replace([x2], {x: y2})[0]
assert yc.owner.inputs[0] is y2
z = MyVariable("w")
out = op(x, z)

new_x = op(y)
new_out = graph_replace([out], {x: new_x})[0]
assert new_out.owner.inputs[0] is new_x
# the old reference is kept
assert yc.owner.inputs[1] is w
assert new_out.owner.inputs[1] is z

# test replace itself
yc = graph_replace([x2], {x2: y2})[0]
assert yc is y2
assert yc.owner.inputs[0] is y
assert len(yc.owner.inputs) == 1
new_out = graph_replace([out], {out: new_x})[0]
assert new_out is new_x
assert new_out.owner.inputs[0] is y
assert len(new_out.owner.inputs) == 1

# the case where inputs have to be replaced in reverse topological order
o = MyOp("xyop")(x2, y2)
new_x = x.clone(name="x_new")
new_y2 = y2.clone(name="y2_new")
out2 = op(out, new_x)

oc = graph_replace([o], {x: new_x, y2: new_y2})[0]
assert oc.owner.inputs[1] is new_y2
assert oc.owner.inputs[0].owner.inputs[0] is new_x
new_x2 = x.clone(name="new_x")
new_x22 = new_x.clone(name="new_x2")
new_out2 = graph_replace([out2], {x: new_x2, new_x: new_x22})[0]
assert new_out2.owner.inputs[1] is new_x22
assert new_out2.owner.inputs[0].owner.inputs[0] is new_x2
# the old reference is still kept
assert oc.owner.inputs[0].owner.inputs[1] is w
assert new_out2.owner.inputs[0].owner.inputs[1] is z

def test_non_list_input(self):
op = MyOp("op")
x = MyVariable("x")
y = MyVariable("y")
o = MyOp("xyop")(x, y)
new_x = x.clone(name="x_new")
new_y = y.clone(name="y2_new")
out = op(x, y)

new_x = x.clone(name="new_x")
new_y = y.clone(name="new_y")
# test non list inputs as well
oc = graph_replace(o, {x: new_x, y: new_y})
oc = graph_replace(out, {x: new_x, y: new_y})
assert oc.owner.inputs[1] is new_y
assert oc.owner.inputs[0] is new_x

def test_graph_replace_advanced(self):
op = MyOp("op")
x = MyVariable("x")
y = MyVariable("y")
z = MyVariable("z")
w = MyVariable("w")
z2 = MyOp("zop")(z)
x2 = MyOp("xop")(x, w)
x2.name = "x2"
y2 = MyOp("yop")(y)
y2.name = "y2"
o = MyOp("xyop")(x2, y2)
new_x = x.clone(name="x_new")
new_y2 = y2.clone(name="y2_new")
new_y21 = MyOp("ny2op")(new_y2)

z_op = op(z)
xw_op = op(x, w)
y_op = op(y)
out = op(xw_op, y_op)

new_x = x.clone(name="new_x")
new_yop = y_op.clone(name="new_yop")

# now yet another replacement that could only appear after new_y2: z
# show we can do that after the prev clone
# the case where new variable is referenced during the replacements
new_y21 = MyOp("ny2op")(new_y2)
# the reference new_y2: z2 is not a part of the original graph so the replacement is unsafe
oc = graph_replace([o], {x: new_x, y2: new_y21})
oc = graph_replace(oc, {new_y2: z2})[0]
assert oc.owner.inputs[1].owner.inputs[0] is z2
assert oc.owner.inputs[0].owner.inputs[0] is new_x
new_yop_op = op(new_yop)
# the reference new_yop: z_op is not a part of the original graph so the replacement is unsafe
new_out = graph_replace([out], {x: new_x, y_op: new_yop_op})
new_out = graph_replace(new_out, {new_yop: z_op})[0]
assert new_out.owner.inputs[1].owner.inputs[0] is z_op
assert new_out.owner.inputs[0].owner.inputs[0] is new_x
# the old reference is still kept
assert oc.owner.inputs[0].owner.inputs[1] is w
assert new_out.owner.inputs[0].owner.inputs[1] is w

new_z = z.clone(name="z_new")
oc = graph_replace([oc], {z: new_z})[0]
new_z = z.clone(name="new_z")
new_out = graph_replace([new_out], {z: new_z})[0]
# new reference appear
assert oc.owner.inputs[1].owner.inputs[0] is not z2
assert oc.owner.inputs[1].owner.inputs[0].owner.inputs[0] is new_z
assert new_out.owner.inputs[1].owner.inputs[0] is not z_op
assert new_out.owner.inputs[1].owner.inputs[0].owner.inputs[0] is new_z
# the old reference is still kept
assert oc.owner.inputs[0].owner.inputs[0] is new_x
assert oc.owner.inputs[0].owner.inputs[1] is w
assert new_out.owner.inputs[0].owner.inputs[0] is new_x
assert new_out.owner.inputs[0].owner.inputs[1] is w

def test_graph_replace_disconnected(self):
op = MyOp("op")
fake_op = MyOp("fake_op")
x = MyVariable("x")
fake = MyOp("fake")(x)
o = MyOp("o")(x)
oc = graph_replace([o], {fake: x.clone()}, strict=False)
assert oc[0] is o
fake = fake_op(x)
out = op(x)
[new_out] = graph_replace([out], {fake: x.clone()}, strict=False)
assert new_out is out
with pytest.raises(ValueError, match="Some replacements were not used"):
oc = graph_replace([o], {fake: x.clone()}, strict=True)
graph_replace([out], {fake: x.clone()}, strict=True)

def test_replace_var_and_ancestor(self):
"""Replacing both a variable and its ancestor should not crash.

When x depends on a and y only depends on a through x,
replacing both x and a should work: x->xx makes a->aa a no-op.
"""
op = MyOp("op")
a = MyVariable("a")
x = op(a) # x depends on a
y = op(x) # y depends on x (and transitively on a)

new_a = MyVariable("new_a")
new_x = MyVariable("new_x")

[new_y] = graph_replace([y], {a: new_a, x: new_x})
assert new_y.owner.inputs[0] is new_x
assert_equal_computations([new_y], [op(new_x)])


class TestVectorizeGraph:
Expand Down
Loading