Skip to content

Commit ef353f8

Browse files
authored
Merge pull request #40 from tlambert03/fix-type-id
fix: type identity
2 parents aad86ee + b4ed4ae commit ef353f8

2 files changed

Lines changed: 54 additions & 12 deletions

File tree

src/spatial_graph/_graph/graph.py

Lines changed: 51 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import hashlib
4+
import json
35
import sys
46
import warnings
57
from pathlib import Path
@@ -103,6 +105,52 @@ def _compile_graph(
103105
return wrapper.Graph
104106

105107

108+
def _hash_args(containers: tuple[Any, ...]) -> str:
109+
"""Hash a bunch of mutable arg container objects in a reproducible way.
110+
111+
This is for stuff like extra_compile_args, extra_link_args, and extension_kwargs.
112+
"""
113+
hash_obj = hashlib.md5()
114+
for container in containers:
115+
# sort dict keys for reproducibility
116+
serialized = json.dumps(container, sort_keys=True)
117+
# Update the hash object with the serialized container
118+
hash_obj.update(serialized.encode())
119+
return hash_obj.hexdigest()
120+
121+
122+
# This caches the outputs of _get_cached_graph_subtype
123+
# to avoid recompiling the same graph type multiple times.
124+
# it allows `type(Graph('uint64')) is type(Graph('uint64'))`
125+
_CLS_CACHE: dict[str, tuple[type, type[GraphBase]]] = {}
126+
127+
128+
def _get_cached_graph_subtype(
129+
cls: type[GraphBase],
130+
node_dtype: str,
131+
node_attr_dtypes: Mapping[str, str] | None = None,
132+
edge_attr_dtypes: Mapping[str, str] | None = None,
133+
) -> tuple[type[CGraph], type[GraphBase]]:
134+
_hash = _hash_args((node_dtype, node_attr_dtypes or {}, edge_attr_dtypes or {}))
135+
_hash += str(hash(cls))
136+
if _hash not in _CLS_CACHE:
137+
directed = issubclass(cls, DiGraph)
138+
# dynamically compile a specialized C++ implementation of the graph
139+
# tailored to the user's specific type requirements.
140+
CGraph = _compile_graph(
141+
node_dtype=node_dtype,
142+
node_attr_dtypes=node_attr_dtypes,
143+
edge_attr_dtypes=edge_attr_dtypes,
144+
directed=directed,
145+
)
146+
SubClass = type(cls.__name__, (cls, CGraph), {})
147+
# create a new class that inherits from both the base class
148+
# and the compiled c++ implementation
149+
_CLS_CACHE[_hash] = (CGraph, SubClass)
150+
151+
return _CLS_CACHE[_hash]
152+
153+
106154
if TYPE_CHECKING:
107155

108156
class GraphBase(CGraph):
@@ -153,22 +201,14 @@ def __new__(
153201
)
154202
cls = DiGraph if directed else Graph
155203

156-
directed = issubclass(cls, DiGraph)
157-
158-
print("Compiling graph with directed =", directed)
159-
# dynamically compile a specialized C++ implementation of the graph
160-
# tailored to the user's specific type requirements.
161-
CGraph = _compile_graph(
204+
# determine the class to use based on the base class
205+
CGraph, GraphType = _get_cached_graph_subtype(
206+
cls=cls,
162207
node_dtype=node_dtype,
163208
node_attr_dtypes=node_attr_dtypes,
164209
edge_attr_dtypes=edge_attr_dtypes,
165-
directed=directed,
166210
)
167211

168-
# create a new class that inherits from both the base class
169-
# and the compiled c++ implementation
170-
GraphType = type(cls.__name__, (cls, CGraph), {})
171-
172212
# create and initialize the instance
173213
return CGraph.__new__(GraphType)
174214

tests/test_graph.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
@pytest.mark.parametrize("edge_attr_dtypes", edge_attr_dtypes)
1414
@pytest.mark.parametrize("cls", [sg.Graph, sg.DiGraph])
1515
def test_construction(node_dtype, node_attr_dtypes, edge_attr_dtypes, cls):
16-
cls(node_dtype, node_attr_dtypes, edge_attr_dtypes)
16+
obj1 = cls(node_dtype, node_attr_dtypes, edge_attr_dtypes)
17+
obj2 = cls(node_dtype, node_attr_dtypes, edge_attr_dtypes)
18+
assert type(obj1) is type(obj2)
1719

1820

1921
@pytest.mark.parametrize("cls", [sg.Graph, sg.DiGraph])

0 commit comments

Comments
 (0)