Skip to content

Commit d59945a

Browse files
authored
Merge pull request #36 from tlambert03/digraph
Add Graph DiGraph and SpatialGraph SpatialDiGraph variants
2 parents 76643e5 + 2342843 commit d59945a

14 files changed

Lines changed: 142 additions & 89 deletions

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@ graph = sg.SpatialGraph(
5050
node_attr_dtypes={"position": "double[3]"},
5151
edge_attr_dtypes={"score": "float32"},
5252
position_attr="position",
53-
directed=False,
5453
)
5554
```
5655

docs/index.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ graph = sg.SpatialGraph(
3333
node_attr_dtypes={"position": "double[3]"},
3434
edge_attr_dtypes={"score": "float32"},
3535
position_attr="position",
36-
directed=False,
3736
)
3837
```
3938

examples/basic_usage.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ def main():
2323
node_attr_dtypes={"position": "double[3]"},
2424
edge_attr_dtypes={"score": "float32"},
2525
position_attr="position",
26-
directed=False,
2726
)
2827
print(f" Created graph with {graph.ndims} dimensions")
2928
print(f" Node dtype: {graph.node_dtype}")

examples/query_nearest_vispy.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
node_attr_dtypes={"position": "double[3]"},
3131
edge_attr_dtypes={"score": "float32"},
3232
position_attr="position",
33-
directed=False,
3433
)
3534
nodes = np.arange(100_000, dtype="uint64")
3635
graph.add_nodes(nodes, position=np.random.random((100_000, 3)))

src/spatial_graph/__init__.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,17 @@
66
__version__ = "unknown"
77

88

9-
from .graph import Graph
9+
from .graph import DiGraph, Graph, GraphBase
1010
from .rtree import LineRTree, PointRTree
11-
from .spatial_graph import SpatialGraph
11+
from .spatial_graph import SpatialDiGraph, SpatialGraph, SpatialGraphBase
1212

13-
__all__ = ["Graph", "LineRTree", "PointRTree", "SpatialGraph"]
13+
__all__ = [
14+
"DiGraph",
15+
"Graph",
16+
"GraphBase",
17+
"LineRTree",
18+
"PointRTree",
19+
"SpatialDiGraph",
20+
"SpatialGraph",
21+
"SpatialGraphBase",
22+
]
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from .graph import Graph
1+
from .graph import DiGraph, Graph, GraphBase
22

3-
__all__ = ["Graph"]
3+
__all__ = ["DiGraph", "Graph", "GraphBase"]

src/spatial_graph/graph/cgraph.pyi

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
1+
"""Stubs for classes generated by wrapper_template.pyx.
2+
3+
These classes are dynamically generated based on the user's
4+
specific type requirements for nodes and edges. That generation happens inside
5+
of graph.py, and these stubs indicate what is generated.
6+
"""
7+
18
from collections.abc import Iterator
29
from typing import Any, Literal, overload
310

411
import numpy as np
512

6-
class CGraph:
13+
class CGraphBase:
714
def add_node(self, node: Any, *data: Any, **kwargs: Any) -> int:
815
"""Add a single node to the graph.
916
@@ -192,7 +199,7 @@ class CGraph:
192199
The number of nodes in the graph.
193200
"""
194201

195-
class UnDirectedCGraph(CGraph):
202+
class CGraph(CGraphBase):
196203
def num_neighbors(self, nodes: np.ndarray) -> np.ndarray:
197204
"""Return the number of neighbors for each node.
198205
@@ -256,7 +263,7 @@ class UnDirectedCGraph(CGraph):
256263
node1 <= node2.
257264
"""
258265

259-
class DirectedCGraph(CGraph):
266+
class CDiGraph(CGraphBase):
260267
def num_in_neighbors(self, nodes: np.ndarray) -> np.ndarray:
261268
"""Return the number of incoming neighbors for each node.
262269

src/spatial_graph/graph/graph.py

Lines changed: 69 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
11
from __future__ import annotations
22

33
import sys
4+
import warnings
45
from pathlib import Path
5-
from typing import TYPE_CHECKING, Literal, overload
6+
from typing import TYPE_CHECKING, Any, overload
67

78
import numpy as np
89
import witty
910
from Cheetah.Template import Template
10-
from typing_extensions import Self
1111

1212
from spatial_graph.dtypes import DType
1313

1414
if TYPE_CHECKING:
15-
from collections.abc import Mapping
15+
from collections.abc import Iterable, Mapping
1616

17-
from .cgraph import CGraph, DirectedCGraph, UnDirectedCGraph
17+
from typing_extensions import Literal, Self
18+
19+
from .cgraph import CDiGraph, CGraph, CGraphBase
1820

1921

2022
# Set platform-specific compile arguments
@@ -70,20 +72,20 @@ def _compile_graph(
7072
node_attr_dtypes: Mapping[str, str] | None = None,
7173
edge_attr_dtypes: Mapping[str, str] | None = None,
7274
directed: Literal[True] = ...,
73-
) -> type[DirectedCGraph]: ...
75+
) -> type[CDiGraph]: ...
7476
@overload
7577
def _compile_graph(
7678
node_dtype: str,
7779
node_attr_dtypes: Mapping[str, str] | None = None,
7880
edge_attr_dtypes: Mapping[str, str] | None = None,
7981
directed: bool = ...,
80-
) -> type[UnDirectedCGraph]: ...
82+
) -> type[CGraph]: ...
8183
def _compile_graph(
8284
node_dtype: str,
8385
node_attr_dtypes: Mapping[str, str] | None = None,
8486
edge_attr_dtypes: Mapping[str, str] | None = None,
8587
directed: bool = False,
86-
) -> type[CGraph]:
88+
) -> type[CGraphBase]:
8789
wrapper_template = _build_wrapper(
8890
node_dtype=node_dtype,
8991
node_attr_dtypes=node_attr_dtypes,
@@ -103,13 +105,15 @@ def _compile_graph(
103105

104106
if TYPE_CHECKING:
105107

106-
class Graph(CGraph):
107-
node_dtype: str
108-
node_attr_dtypes: Mapping[str, str] | None
109-
edge_attr_dtypes: Mapping[str, str] | None
110-
directed: bool
108+
class GraphBase(CGraph):
109+
"""Base class for undirected graph instances."""
110+
111111
node_attrs: NodeAttrs
112112
edge_attrs: EdgeAttrs
113+
node_dtype: str
114+
node_attr_dtypes: Mapping[str, str]
115+
edge_attr_dtypes: Mapping[str, str]
116+
directed: bool
113117

114118
def __init__(
115119
self,
@@ -118,18 +122,40 @@ def __init__(
118122
edge_attr_dtypes: Mapping[str, str] | None = None,
119123
directed: bool = False,
120124
): ...
125+
126+
class Graph(GraphBase, CGraph):
127+
"""Base class for undirected graph instances."""
128+
129+
class DiGraph(GraphBase, CDiGraph):
130+
"""Base class for directed graph instances."""
131+
132+
121133
else:
122134

123-
class Graph:
135+
class GraphBase:
136+
"""Base class for compiled graph instances."""
137+
124138
def __new__(
125139
cls,
126140
node_dtype: str,
127141
node_attr_dtypes: Mapping[str, str] | None = None,
128142
edge_attr_dtypes: Mapping[str, str] | None = None,
129-
directed: bool = False,
130-
*args,
131-
**kwargs,
143+
directed: bool | None = None,
144+
*args: Any,
145+
**kwargs: Any,
132146
) -> Self:
147+
if directed is not None:
148+
warnings.warn(
149+
"The 'directed' argument is deprecated and will be removed in "
150+
"future versions. Use the 'DiGraph' class for directed graphs.",
151+
DeprecationWarning,
152+
stacklevel=2,
153+
)
154+
cls = DiGraph if directed else Graph
155+
156+
directed = issubclass(cls, DiGraph)
157+
158+
print("Compiling graph with directed =", directed)
133159
# dynamically compile a specialized C++ implementation of the graph
134160
# tailored to the user's specific type requirements.
135161
CGraph = _compile_graph(
@@ -138,13 +164,12 @@ def __new__(
138164
edge_attr_dtypes=edge_attr_dtypes,
139165
directed=directed,
140166
)
141-
# create a new class that inherits from both this class
142-
# and the compiled c++ implementation `wrapper.Graph`
167+
168+
# create a new class that inherits from both the base class
169+
# and the compiled c++ implementation
143170
GraphType = type(cls.__name__, (cls, CGraph), {})
144-
# call the __new__ method of the native C++ class, but pass the dynamically
145-
# created class as the type. This ensures the object will be an instance
146-
# of the dynamically created class, but using the C++ allocation logic and
147-
# initialization code.
171+
172+
# create and initialize the instance
148173
return CGraph.__new__(GraphType)
149174

150175
def __init__(
@@ -154,19 +179,21 @@ def __init__(
154179
edge_attr_dtypes: Mapping[str, str] | None = None,
155180
directed: bool = False,
156181
):
157-
if node_attr_dtypes is None:
158-
node_attr_dtypes = {}
159-
if edge_attr_dtypes is None:
160-
edge_attr_dtypes = {}
161182
super().__init__()
162183
self.node_dtype = node_dtype
163-
self.node_attr_dtypes = node_attr_dtypes
164-
self.edge_attr_dtypes = edge_attr_dtypes
184+
self.node_attr_dtypes = node_attr_dtypes or {}
185+
self.edge_attr_dtypes = edge_attr_dtypes or {}
165186
self.directed = directed
166187

167188
self.node_attrs = NodeAttrs(self)
168189
self.edge_attrs = EdgeAttrs(self)
169190

191+
class Graph(GraphBase):
192+
"""Base class for undirected graph instances."""
193+
194+
class DiGraph(GraphBase):
195+
"""Base class for directed graph instances."""
196+
170197

171198
class NodeAttrsView:
172199
def __init__(self, graph, nodes):
@@ -202,7 +229,7 @@ def __init__(self, graph, nodes):
202229
# 3. None
203230
super().__setattr__("nodes", nodes)
204231

205-
def __getattr__(self, name):
232+
def __getattr__(self, name: str) -> np.ndarray:
206233
if name in self.graph.node_attr_dtypes:
207234
return getattr(self, f"get_attr_{name}")(self.nodes)
208235
else:
@@ -220,9 +247,12 @@ def __iter__(self):
220247

221248

222249
class EdgeAttrsView:
223-
def __init__(self, graph, edges):
250+
graph: GraphBase
251+
edges: np.ndarray | tuple[float, float] | None
252+
253+
def __init__(self, graph: GraphBase, edges: np.ndarray | Iterable | None) -> None:
224254
super().__setattr__("graph", graph)
225-
for name in graph.edge_attr_dtypes.keys():
255+
for name in graph.edge_attr_dtypes:
226256
super().__setattr__(
227257
f"get_attr_{name}", getattr(graph, f"get_edges_data_{name}")
228258
)
@@ -254,7 +284,7 @@ def __init__(self, graph, edges):
254284
# edges should be an iteratable
255285
try:
256286
# does it have a length?
257-
len(edges)
287+
len(edges) # type: ignore
258288
# case 2 and 3
259289
edges = np.array(edges, dtype=graph.node_dtype)
260290
except Exception as e: # pragma: no cover
@@ -265,8 +295,8 @@ def __init__(self, graph, edges):
265295
if isinstance(edges, np.ndarray):
266296
if len(edges) == 0:
267297
edges = edges.reshape((0, 2))
268-
assert edges.shape[1] == 2, "Edge arrays should have shape (n, 2)"
269-
edges = np.ascontiguousarray(edges.T)
298+
assert edges.shape[1] == 2, "Edge arrays should have shape (n, 2)" # type: ignore
299+
edges = np.ascontiguousarray(edges.T) # type: ignore
270300
elif isinstance(edges, tuple):
271301
# a single edge
272302
for name in graph.edge_attr_dtypes.keys():
@@ -283,7 +313,7 @@ def __init__(self, graph, edges):
283313
# 3. None
284314
super().__setattr__("edges", edges)
285315

286-
def __getattr__(self, name):
316+
def __getattr__(self, name: str) -> np.ndarray:
287317
if name in self.graph.edge_attr_dtypes:
288318
if self.edges is not None:
289319
return getattr(self, f"get_attr_{name}")(self.edges[0], self.edges[1])
@@ -306,16 +336,16 @@ def __iter__(self):
306336

307337

308338
class NodeAttrs(NodeAttrsView):
309-
def __init__(self, graph):
339+
def __init__(self, graph: GraphBase) -> None:
310340
super().__init__(graph, nodes=None)
311341

312-
def __getitem__(self, nodes):
342+
def __getitem__(self, nodes) -> NodeAttrsView:
313343
return NodeAttrsView(self.graph, nodes)
314344

315345

316346
class EdgeAttrs(EdgeAttrsView):
317-
def __init__(self, graph):
347+
def __init__(self, graph: GraphBase) -> None:
318348
super().__init__(graph, edges=None)
319349

320-
def __getitem__(self, edges):
350+
def __getitem__(self, edges) -> EdgeAttrsView:
321351
return EdgeAttrsView(self.graph, edges)

src/spatial_graph/spatial_graph.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,18 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Any, ClassVar, cast
3+
from typing import TYPE_CHECKING, Any, ClassVar
44

55
import numpy as np
66

77
from .dtypes import DType
8-
from .graph.graph import Graph
8+
from .graph.graph import DiGraph, Graph, GraphBase
99
from .rtree import LineRTree, PointRTree
1010

1111
if TYPE_CHECKING:
1212
from collections.abc import Mapping
1313

14-
from .graph.cgraph import DirectedCGraph, UnDirectedCGraph
1514

16-
17-
class SpatialGraph(Graph):
15+
class SpatialGraphBase(GraphBase):
1816
edge_inclusion_values: ClassVar[list[str]] = ["incident", "leaving", "entering"]
1917

2018
def __init__(
@@ -92,13 +90,12 @@ def edges(self):
9290
def remove_nodes(self, nodes: np.ndarray) -> None:
9391
positions = getattr(self.node_attrs[nodes], self.position_attr)
9492
self._node_rtree.delete_items(nodes, positions)
95-
if self.directed:
96-
obj = cast("DirectedCGraph", self)
93+
if isinstance(self, DiGraph):
9794
edges = np.concatenate(
98-
obj.in_edges_by_nodes(nodes), obj.out_edges_by_nodes(nodes)
95+
self.in_edges_by_nodes(nodes), self.out_edges_by_nodes(nodes)
9996
)
100-
else:
101-
edges = cast("UnDirectedCGraph", self).edges_by_nodes(nodes)
97+
elif isinstance(self, Graph):
98+
edges = self.edges_by_nodes(nodes)
10299
positions_u = getattr(self.node_attrs[edges[:, 0]], self.position_attr)
103100
positions_v = getattr(self.node_attrs[edges[:, 1]], self.position_attr)
104101
self._edge_rtree.delete_items(edges, positions_u, positions_v)
@@ -108,3 +105,11 @@ def _get_position(self, kwargs):
108105
if self.position_attr in kwargs:
109106
return kwargs[self.position_attr]
110107
raise RuntimeError(f"position attribute '{self.position_attr}' not given")
108+
109+
110+
class SpatialGraph(SpatialGraphBase, Graph):
111+
"""Base class for undirected spatial graph instances."""
112+
113+
114+
class SpatialDiGraph(SpatialGraphBase, DiGraph):
115+
"""Base class for directed spatial graph instances."""

0 commit comments

Comments
 (0)