|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import hashlib |
| 4 | +import json |
3 | 5 | import sys |
4 | 6 | import warnings |
5 | 7 | from pathlib import Path |
@@ -103,6 +105,52 @@ def _compile_graph( |
103 | 105 | return wrapper.Graph |
104 | 106 |
|
105 | 107 |
|
| 108 | +def _hash_args(containers: tuple[Any, ...]) -> str: |
| 109 | + """Hash a bunch of mutable arg container objects in a reproducible way. |
| 110 | +
|
| 111 | + This is for stuff like extra_compile_args, extra_link_args, and extension_kwargs. |
| 112 | + """ |
| 113 | + hash_obj = hashlib.md5() |
| 114 | + for container in containers: |
| 115 | + # sort dict keys for reproducibility |
| 116 | + serialized = json.dumps(container, sort_keys=True) |
| 117 | + # Update the hash object with the serialized container |
| 118 | + hash_obj.update(serialized.encode()) |
| 119 | + return hash_obj.hexdigest() |
| 120 | + |
| 121 | + |
| 122 | +# This caches the outputs of _get_cached_graph_subtype |
| 123 | +# to avoid recompiling the same graph type multiple times. |
| 124 | +# it allows `type(Graph('uint64')) is type(Graph('uint64'))` |
| 125 | +_CLS_CACHE: dict[str, tuple[type, type[GraphBase]]] = {} |
| 126 | + |
| 127 | + |
| 128 | +def _get_cached_graph_subtype( |
| 129 | + cls: type[GraphBase], |
| 130 | + node_dtype: str, |
| 131 | + node_attr_dtypes: Mapping[str, str] | None = None, |
| 132 | + edge_attr_dtypes: Mapping[str, str] | None = None, |
| 133 | +) -> tuple[type[CGraph], type[GraphBase]]: |
| 134 | + _hash = _hash_args((node_dtype, node_attr_dtypes or {}, edge_attr_dtypes or {})) |
| 135 | + _hash += str(hash(cls)) |
| 136 | + if _hash not in _CLS_CACHE: |
| 137 | + directed = issubclass(cls, DiGraph) |
| 138 | + # dynamically compile a specialized C++ implementation of the graph |
| 139 | + # tailored to the user's specific type requirements. |
| 140 | + CGraph = _compile_graph( |
| 141 | + node_dtype=node_dtype, |
| 142 | + node_attr_dtypes=node_attr_dtypes, |
| 143 | + edge_attr_dtypes=edge_attr_dtypes, |
| 144 | + directed=directed, |
| 145 | + ) |
| 146 | + SubClass = type(cls.__name__, (cls, CGraph), {}) |
| 147 | + # create a new class that inherits from both the base class |
| 148 | + # and the compiled c++ implementation |
| 149 | + _CLS_CACHE[_hash] = (CGraph, SubClass) |
| 150 | + |
| 151 | + return _CLS_CACHE[_hash] |
| 152 | + |
| 153 | + |
106 | 154 | if TYPE_CHECKING: |
107 | 155 |
|
108 | 156 | class GraphBase(CGraph): |
@@ -153,22 +201,14 @@ def __new__( |
153 | 201 | ) |
154 | 202 | cls = DiGraph if directed else Graph |
155 | 203 |
|
156 | | - directed = issubclass(cls, DiGraph) |
157 | | - |
158 | | - print("Compiling graph with directed =", directed) |
159 | | - # dynamically compile a specialized C++ implementation of the graph |
160 | | - # tailored to the user's specific type requirements. |
161 | | - CGraph = _compile_graph( |
| 204 | + # determine the class to use based on the base class |
| 205 | + CGraph, GraphType = _get_cached_graph_subtype( |
| 206 | + cls=cls, |
162 | 207 | node_dtype=node_dtype, |
163 | 208 | node_attr_dtypes=node_attr_dtypes, |
164 | 209 | edge_attr_dtypes=edge_attr_dtypes, |
165 | | - directed=directed, |
166 | 210 | ) |
167 | 211 |
|
168 | | - # create a new class that inherits from both the base class |
169 | | - # and the compiled c++ implementation |
170 | | - GraphType = type(cls.__name__, (cls, CGraph), {}) |
171 | | - |
172 | 212 | # create and initialize the instance |
173 | 213 | return CGraph.__new__(GraphType) |
174 | 214 |
|
|
0 commit comments