Skip to content
Draft
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"""

from typing import (Mapping, Dict, Union, Set, Tuple, Any, FrozenSet,
TYPE_CHECKING)
Type, TYPE_CHECKING)
from pytato.array import (Array, IndexLambda, Stack, Concatenate, Einsum,
DictOfNamedArrays, NamedArray,
IndexBase, IndexRemappingBase, InputArgumentBase,
Expand Down Expand Up @@ -413,6 +413,49 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int:
# }}}


# {{{ NodeTypeCountMapper

@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
class NodeTypeCountMapper(CachedWalkMapper):
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could maybe just modify NodeCountMapper to do this and then have get_num_nodes return sum(ncm.counts.values())? Saves some duplication.

"""
Counts the number of nodes of a given type in a DAG.

.. attribute:: counts

Dictionary mapping node types to number of nodes of that type.
"""

def __init__(self) -> None:
from collections import defaultdict
super().__init__()
self.counts = defaultdict(int)

def get_cache_key(self, expr: ArrayOrNames) -> int:
return id(expr)

def post_visit(self, expr: Any) -> None:
if type(expr) not in self.counts:
self.counts[type(expr)] = 0
Comment thread
MTCam marked this conversation as resolved.
Outdated
self.counts[type(expr)] += 1


def get_num_node_types(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type, int]:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_num_node_types is a little confusing (sounds like it would return the number of distinct types in the DAG). I used get_node_counts in mine, which isn't much better... get_node_type_counts? 🤷‍♂️

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. get_node_type_counts seems like the clear choice.

"""
Returns a dictionary mapping node types to node count for that type
in DAG *outputs*.
"""

from pytato.codegen import normalize_outputs
outputs = normalize_outputs(outputs)

ncm = NodeTypeCountMapper()
ncm(outputs)

return ncm.counts

# }}}


# {{{ CallSiteCountMapper

@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
Expand Down