|
74 | 74 | .. autoclass:: TopoSortMapper |
75 | 75 | .. autoclass:: CachedMapAndCopyMapper |
76 | 76 | .. autoclass:: EdgeCachedMapper |
| 77 | +.. autoclass:: EqualBranchesReuser |
77 | 78 | .. autofunction:: copy_dict_of_named_arrays |
78 | 79 | .. autofunction:: get_dependencies |
79 | 80 | .. autofunction:: map_and_copy |
@@ -210,6 +211,10 @@ class CopyMapper(CachedMapper[ArrayOrNames]): |
210 | 211 | This does not copy the data of a :class:`pytato.array.DataWrapper`. |
211 | 212 | """ |
212 | 213 |
|
| 214 | + # type-ignore reason: incompatible type with Mapper.rec |
| 215 | + def __call__(self, expr: T) -> T: # type: ignore[override] |
| 216 | + return self.rec(expr) # type: ignore[no-any-return] |
| 217 | + |
213 | 218 | def rec_idx_or_size_tuple(self, situp: Tuple[IndexOrShapeExpr, ...] |
214 | 219 | ) -> Tuple[IndexOrShapeExpr, ...]: |
215 | 220 | return tuple(self.rec(s) if isinstance(s, Array) else s for s in situp) |
@@ -1706,6 +1711,33 @@ def map_distributed_recv(self, expr: DistributedRecv, *args: Any) \ |
1706 | 1711 | # }}} |
1707 | 1712 |
|
1708 | 1713 |
|
| 1714 | +# {{{ EqualBranchesReuser |
| 1715 | + |
| 1716 | +class EqualBranchesReuser(CopyMapper): |
| 1717 | + """ |
| 1718 | + A mapper that replaces equal segments of graphs with identical objects. |
| 1719 | + """ |
| 1720 | + def __init__(self) -> None: |
| 1721 | + super().__init__() |
| 1722 | + self.result_cache: Dict[ArrayOrNames, ArrayOrNames] = {} |
| 1723 | + |
| 1724 | + def cache_key(self, expr: CachedMapperT) -> Any: |
| 1725 | + return (id(expr), expr) |
| 1726 | + |
| 1727 | + # type-ignore reason: incompatible with Mapper.rec |
| 1728 | + def rec(self, expr: T) -> T: # type: ignore[override] |
| 1729 | + rec_expr = super().rec(expr) |
| 1730 | + try: |
| 1731 | + # type-ignored because 'result_cache' maps to ArrayOrNames |
| 1732 | + return self.result_cache[rec_expr] # type: ignore[return-value] |
| 1733 | + except KeyError: |
| 1734 | + self.result_cache[rec_expr] = rec_expr |
| 1735 | + # type-ignored because of super-class' relaxed types |
| 1736 | + return rec_expr # type: ignore[no-any-return] |
| 1737 | + |
| 1738 | +# }}} |
| 1739 | + |
| 1740 | + |
1709 | 1741 | # {{{ deduplicate_data_wrappers |
1710 | 1742 |
|
1711 | 1743 | def _get_data_dedup_cache_key(ary: DataInterface) -> Hashable: |
@@ -1782,8 +1814,11 @@ def cached_data_wrapper_if_present(ary: ArrayOrNames) -> ArrayOrNames: |
1782 | 1814 | len(data_wrapper_cache), |
1783 | 1815 | data_wrappers_encountered - len(data_wrapper_cache)) |
1784 | 1816 |
|
1785 | | - return array_or_names |
| 1817 | + # many paths in the DAG might be semantically equivalent after DWs are |
| 1818 | + # deduplicated => morph them |
| 1819 | + return EqualBranchesReuser()(array_or_names) |
1786 | 1820 |
|
1787 | 1821 | # }}} |
1788 | 1822 |
|
| 1823 | + |
1789 | 1824 | # vim: foldmethod=marker |
0 commit comments