Skip to content

Commit 3305071

Browse files
committed
Call EqualBranchesReuser after dw deduplication
1 parent 644eacb commit 3305071

2 files changed

Lines changed: 37 additions & 2 deletions

File tree

pytato/transform/__init__.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
.. autoclass:: TopoSortMapper
7575
.. autoclass:: CachedMapAndCopyMapper
7676
.. autoclass:: EdgeCachedMapper
77+
.. autoclass:: EqualBranchesReuser
7778
.. autofunction:: copy_dict_of_named_arrays
7879
.. autofunction:: get_dependencies
7980
.. autofunction:: map_and_copy
@@ -210,6 +211,10 @@ class CopyMapper(CachedMapper[ArrayOrNames]):
210211
This does not copy the data of a :class:`pytato.array.DataWrapper`.
211212
"""
212213

214+
# type-ignore reason: incompatible type with Mapper.rec
215+
def __call__(self, expr: T) -> T: # type: ignore[override]
216+
return self.rec(expr) # type: ignore[no-any-return]
217+
213218
def rec_idx_or_size_tuple(self, situp: Tuple[IndexOrShapeExpr, ...]
214219
) -> Tuple[IndexOrShapeExpr, ...]:
215220
return tuple(self.rec(s) if isinstance(s, Array) else s for s in situp)
@@ -1706,6 +1711,33 @@ def map_distributed_recv(self, expr: DistributedRecv, *args: Any) \
17061711
# }}}
17071712

17081713

1714+
# {{{ EqualBranchesReuser
1715+
1716+
class EqualBranchesReuser(CopyMapper):
1717+
"""
1718+
A mapper that replaces equal segments of graphs with identical objects.
1719+
"""
1720+
def __init__(self) -> None:
1721+
super().__init__()
1722+
self.result_cache: Dict[ArrayOrNames, ArrayOrNames] = {}
1723+
1724+
def cache_key(self, expr: CachedMapperT) -> Any:
1725+
return (id(expr), expr)
1726+
1727+
# type-ignore reason: incompatible with Mapper.rec
1728+
def rec(self, expr: T) -> T: # type: ignore[override]
1729+
rec_expr = super().rec(expr)
1730+
try:
1731+
# type-ignored because 'result_cache' maps to ArrayOrNames
1732+
return self.result_cache[rec_expr] # type: ignore[return-value]
1733+
except KeyError:
1734+
self.result_cache[rec_expr] = rec_expr
1735+
# type-ignored because of super-class' relaxed types
1736+
return rec_expr # type: ignore[no-any-return]
1737+
1738+
# }}}
1739+
1740+
17091741
# {{{ deduplicate_data_wrappers
17101742

17111743
def _get_data_dedup_cache_key(ary: DataInterface) -> Hashable:
@@ -1782,8 +1814,11 @@ def cached_data_wrapper_if_present(ary: ArrayOrNames) -> ArrayOrNames:
17821814
len(data_wrapper_cache),
17831815
data_wrappers_encountered - len(data_wrapper_cache))
17841816

1785-
return array_or_names
1817+
# many paths in the DAG might be semantically equivalent after DWs are
1818+
# deduplicated => morph them
1819+
return EqualBranchesReuser()(array_or_names)
17861820

17871821
# }}}
17881822

1823+
17891824
# vim: foldmethod=marker

test/test_codegen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1607,7 +1607,7 @@ def test_zero_size_cl_array_dedup(ctx_factory):
16071607
x4 = pt.make_data_wrapper(x_cl2)
16081608

16091609
out = pt.make_dict_of_named_arrays({"out1": 2*x1,
1610-
"out2": 2*x2,
1610+
"out2": 3*x2,
16111611
"out3": x3 + x4
16121612
})
16131613

0 commit comments

Comments
 (0)