|
74 | 74 | .. autoclass:: TopoSortMapper |
75 | 75 | .. autoclass:: CachedMapAndCopyMapper |
76 | 76 | .. autoclass:: EdgeCachedMapper |
| 77 | +.. autoclass:: PostMapEqualNodeReuser |
77 | 78 | .. autofunction:: copy_dict_of_named_arrays |
78 | 79 | .. autofunction:: get_dependencies |
79 | 80 | .. autofunction:: map_and_copy |
@@ -210,6 +211,10 @@ class CopyMapper(CachedMapper[ArrayOrNames]): |
210 | 211 | This does not copy the data of a :class:`pytato.array.DataWrapper`. |
211 | 212 | """ |
212 | 213 |
|
| 214 | + # type-ignore reason: incompatible type with Mapper.rec |
| 215 | + def __call__(self, expr: T) -> T: # type: ignore[override] |
| 216 | + return self.rec(expr) # type: ignore[no-any-return] |
| 217 | + |
213 | 218 | def rec_idx_or_size_tuple(self, situp: Tuple[IndexOrShapeExpr, ...] |
214 | 219 | ) -> Tuple[IndexOrShapeExpr, ...]: |
215 | 220 | return tuple(self.rec(s) if isinstance(s, Array) else s for s in situp) |
@@ -1706,6 +1711,48 @@ def map_distributed_recv(self, expr: DistributedRecv, *args: Any) \ |
1706 | 1711 | # }}} |
1707 | 1712 |
|
1708 | 1713 |
|
| 1714 | +# {{{ PostMapEqualNodeReuser |
| 1715 | + |
| 1716 | +class PostMapEqualNodeReuser(CopyMapper): |
| 1717 | + """ |
| 1718 | + A mapper that reuses the same object instances for equal segments of |
| 1719 | + graphs. |
| 1720 | +
|
| 1721 | + .. note:: |
| 1722 | +
|
| 1723 | + The operation performed here is equivalent to that of a |
| 1724 | + :class:`CopyMapper`, in that both return a single instance for equal |
| 1725 | + :class:`pytato.Array` nodes. However, they differ at the point where |
| 1726 | + two array expressions are compared. :class:`CopyMapper` compares array |
| 1727 | + expressions before the expressions are mapped, i.e. it repeatedly compares |
| 1728 | + array expressions that are equal but distinct instances, and because of this it |
| 1729 | + spends super-linear time in comparing array expressions. On the other |
| 1730 | + hand, :class:`PostMapEqualNodeReuser` has linear complexity in the |
| 1731 | + number of nodes of the array expressions, as the larger mapped |
| 1732 | + expressions already contain the same instances for the predecessors, |
| 1733 | + resulting in a cheaper equality comparison overall. |
| 1734 | + """ |
| 1735 | + def __init__(self) -> None: |
| 1736 | + super().__init__() |
| 1737 | + self.result_cache: Dict[ArrayOrNames, ArrayOrNames] = {} |
| 1738 | + |
| 1739 | + def cache_key(self, expr: CachedMapperT) -> Any: |
| 1740 | + return (id(expr), expr) |
| 1741 | + |
| 1742 | + # type-ignore reason: incompatible with Mapper.rec |
| 1743 | + def rec(self, expr: T) -> T: # type: ignore[override] |
| 1744 | + rec_expr = super().rec(expr) |
| 1745 | + try: |
| 1746 | + # type-ignored because 'result_cache' maps to ArrayOrNames |
| 1747 | + return self.result_cache[rec_expr] # type: ignore[return-value] |
| 1748 | + except KeyError: |
| 1749 | + self.result_cache[rec_expr] = rec_expr |
| 1750 | + # type-ignored because of super-class' relaxed types |
| 1751 | + return rec_expr # type: ignore[no-any-return] |
| 1752 | + |
| 1753 | +# }}} |
| 1754 | + |
| 1755 | + |
1709 | 1756 | # {{{ deduplicate_data_wrappers |
1710 | 1757 |
|
1711 | 1758 | def _get_data_dedup_cache_key(ary: DataInterface) -> Hashable: |
@@ -1782,8 +1829,11 @@ def cached_data_wrapper_if_present(ary: ArrayOrNames) -> ArrayOrNames: |
1782 | 1829 | len(data_wrapper_cache), |
1783 | 1830 | data_wrappers_encountered - len(data_wrapper_cache)) |
1784 | 1831 |
|
1785 | | - return array_or_names |
| 1832 | + # many paths in the DAG might be semantically equivalent after the data |
| 1833 | + # wrappers are deduplicated => reuse a single instance for equal nodes |
| 1834 | + return PostMapEqualNodeReuser()(array_or_names) |
1786 | 1835 |
|
1787 | 1836 | # }}} |
1788 | 1837 |
|
| 1838 | + |
1789 | 1839 | # vim: foldmethod=marker |
0 commit comments