Skip to content

Commit 6fd5162

Browse files
committed
Call PostMapEqualNodesReuser after dw deduplication
1 parent 644eacb commit 6fd5162

2 files changed

Lines changed: 52 additions & 2 deletions

File tree

pytato/transform/__init__.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
.. autoclass:: TopoSortMapper
7575
.. autoclass:: CachedMapAndCopyMapper
7676
.. autoclass:: EdgeCachedMapper
77+
.. autoclass:: PostMapEqualNodeReuser
7778
.. autofunction:: copy_dict_of_named_arrays
7879
.. autofunction:: get_dependencies
7980
.. autofunction:: map_and_copy
@@ -210,6 +211,10 @@ class CopyMapper(CachedMapper[ArrayOrNames]):
210211
This does not copy the data of a :class:`pytato.array.DataWrapper`.
211212
"""
212213

214+
# type-ignore reason: incompatible type with Mapper.rec
215+
def __call__(self, expr: T) -> T: # type: ignore[override]
216+
return self.rec(expr) # type: ignore[no-any-return]
217+
213218
def rec_idx_or_size_tuple(self, situp: Tuple[IndexOrShapeExpr, ...]
214219
) -> Tuple[IndexOrShapeExpr, ...]:
215220
return tuple(self.rec(s) if isinstance(s, Array) else s for s in situp)
@@ -1706,6 +1711,48 @@ def map_distributed_recv(self, expr: DistributedRecv, *args: Any) \
17061711
# }}}
17071712

17081713

1714+
# {{{ PostMapEqualNodeReuser
1715+
1716+
class PostMapEqualNodeReuser(CopyMapper):
1717+
"""
1718+
A mapper that reuses the same object instances for equal segments of
1719+
graphs.
1720+
1721+
.. note::
1722+
1723+
The operation performed here is equivalent to that of a
1724+
:class:`CopyMapper`, in that both return a single instance for equal
1725+
:class:`pytato.Array` nodes. However, they differ at the point where
1726+
two array expressions are compared. :class:`CopyMapper` compares array
1727+
expressions before the expressions are mapped i.e. repeatedly comparing
1728+
equal array expressions but unequal instances, and because of this it
1729+
spends super-linear time in comparing array expressions. On the other
1730+
hand, :class:`PostMapEqualNodeReuser` has linear complexity in the
1731+
number of nodes in the number of array expressions as the larger mapped
1732+
expressions already contain same instances for the predecessors,
1733+
resulting in a cheaper equality comparison overall.
1734+
"""
1735+
def __init__(self) -> None:
1736+
super().__init__()
1737+
self.result_cache: Dict[ArrayOrNames, ArrayOrNames] = {}
1738+
1739+
def cache_key(self, expr: CachedMapperT) -> Any:
1740+
return (id(expr), expr)
1741+
1742+
# type-ignore reason: incompatible with Mapper.rec
1743+
def rec(self, expr: T) -> T: # type: ignore[override]
1744+
rec_expr = super().rec(expr)
1745+
try:
1746+
# type-ignored because 'result_cache' maps to ArrayOrNames
1747+
return self.result_cache[rec_expr] # type: ignore[return-value]
1748+
except KeyError:
1749+
self.result_cache[rec_expr] = rec_expr
1750+
# type-ignored because of super-class' relaxed types
1751+
return rec_expr # type: ignore[no-any-return]
1752+
1753+
# }}}
1754+
1755+
17091756
# {{{ deduplicate_data_wrappers
17101757

17111758
def _get_data_dedup_cache_key(ary: DataInterface) -> Hashable:
@@ -1782,8 +1829,11 @@ def cached_data_wrapper_if_present(ary: ArrayOrNames) -> ArrayOrNames:
17821829
len(data_wrapper_cache),
17831830
data_wrappers_encountered - len(data_wrapper_cache))
17841831

1785-
return array_or_names
1832+
# many paths in the DAG might be semantically equivalent after DWs are
1833+
# deduplicated => morph them
1834+
return PostMapEqualNodeReuser()(array_or_names)
17861835

17871836
# }}}
17881837

1838+
17891839
# vim: foldmethod=marker

test/test_codegen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1607,7 +1607,7 @@ def test_zero_size_cl_array_dedup(ctx_factory):
16071607
x4 = pt.make_data_wrapper(x_cl2)
16081608

16091609
out = pt.make_dict_of_named_arrays({"out1": 2*x1,
1610-
"out2": 2*x2,
1610+
"out2": 3*x2,
16111611
"out3": x3 + x4
16121612
})
16131613

0 commit comments

Comments
 (0)