Skip to content

Commit 1de7c2c

Browse files
authored
Merge pull request #31 from tlambert03/typing2
Improve type hints for SpatialGraph
2 parents 1a465ed + 24d6f8c commit 1de7c2c

7 files changed

Lines changed: 565 additions & 89 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ repository = "https://github.com/funkelab/spatial_graph"
5555
[tool.ruff]
5656
target-version = "py39"
5757
line-length = 88
58+
fix = true
59+
unsafe-fixes = true
5860

5961
[tool.ruff.lint]
6062
select = [

src/spatial_graph/graph/cgraph.pyi

Lines changed: 381 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,381 @@
1+
from collections.abc import Iterator
2+
from typing import Any, Literal, overload
3+
4+
import numpy as np
5+
6+
class CGraph:
7+
def add_node(self, node: Any, *data: Any, **kwargs: Any) -> int:
8+
"""Add a single node to the graph.
9+
10+
The node attributes provided via *data and **kwargs must match the
11+
data types and names specified in `node_attr_dtypes` when the graph
12+
was created.
13+
14+
Parameters
15+
----------
16+
node : Any
17+
The node identifier to add to the graph.
18+
*data : Any
19+
Positional arguments for node attributes. Names/number of args
20+
must match the `node_attr_dtypes`.
21+
**kwargs : Any
22+
Keyword arguments for node attributes. Names/number of kwargs
23+
must match the `node_attr_dtypes`.
24+
25+
Returns
26+
-------
27+
int
28+
Number of nodes added (1 if successful, 0 if node already exists).
29+
"""
30+
def add_nodes(self, nodes: np.ndarray, *data: Any, **kwargs: Any) -> int:
31+
"""Add multiple nodes to the graph.
32+
33+
Node attributes provided via *data and **kwargs must match the
34+
data types and names specified in `node_attr_dtypes`. Each attribute
35+
array must have the same length as the `nodes` array.
36+
37+
Parameters
38+
----------
39+
nodes : np.ndarray
40+
Array of node identifiers to add to the graph.
41+
*data : Any
42+
Positional arguments for node attributes. Each argument should be
43+
an array with length matching `nodes`. Names/number of args must
44+
match the `node_attr_dtypes`.
45+
**kwargs : Any
46+
Keyword arguments for node attributes. Each argument should be
47+
an array with length matching `nodes`. Names/number of kwargs
48+
must match the `node_attr_dtypes`.
49+
50+
Returns
51+
-------
52+
int
53+
Number of nodes successfully added.
54+
"""
55+
def add_edge(self, edge: np.ndarray, *args: Any, **kwargs: Any) -> int:
56+
"""Add an edge to the graph.
57+
58+
The edge attributes provided via *args and **kwargs must match the
59+
data types and names specified in `edge_attr_dtypes` when the graph
60+
was created.
61+
62+
Parameters
63+
----------
64+
edge : np.ndarray
65+
Array of length 2 containing [source_node, target_node].
66+
*args : Any
67+
Positional arguments for edge attributes. Names/number of args
68+
must match the `edge_attr_dtypes`.
69+
**kwargs : Any
70+
Keyword arguments for edge attributes. Names/number of kwargs
71+
must match the `edge_attr_dtypes`.
72+
73+
Returns
74+
-------
75+
int
76+
Number of edges added (1 if successful, 0 if edge already exists).
77+
"""
78+
79+
def add_edges(
80+
self, edges: np.ndarray, *args: np.ndarray, **kwargs: np.ndarray
81+
) -> int:
82+
"""Add multiple edges to the graph.
83+
84+
Edge attributes provided via *args and **kwargs must match the
85+
data types and names specified in `edge_attr_dtypes`. Each attribute
86+
array must have the same length as the number of edges.
87+
88+
Parameters
89+
----------
90+
edges : np.ndarray
91+
2D array of shape (n_edges, 2) where each row contains
92+
[source_node, target_node].
93+
*args : np.ndarray
94+
Positional arguments for edge attributes. Each argument should be
95+
an array with length matching the number of edges. Names/number
96+
of args must match the `edge_attr_dtypes`.
97+
**kwargs : np.ndarray
98+
Keyword arguments for edge attributes. Each argument should be
99+
an array with length matching the number of edges. Names/number
100+
of kwargs must match the `edge_attr_dtypes`.
101+
102+
Returns
103+
-------
104+
int
105+
Number of edges successfully added.
106+
"""
107+
108+
def nodes(self) -> np.ndarray:
109+
"""Get all node IDs in the graph.
110+
111+
The returned array is a copy and modifications will not affect
112+
the graph structure.
113+
114+
Returns
115+
-------
116+
np.ndarray
117+
Array containing all node identifiers in the graph, ordered
118+
by insertion order (earliest added first).
119+
"""
120+
def remove_node(self, node: Any) -> None:
121+
"""Remove a single node from the graph.
122+
123+
Removing a node will also remove all edges incident to that node.
124+
125+
Parameters
126+
----------
127+
node : Any
128+
The node identifier to remove from the graph.
129+
"""
130+
def remove_nodes(self, nodes: np.ndarray) -> None:
131+
"""Remove multiple nodes from the graph.
132+
133+
Removing nodes will also remove all edges incident to those nodes.
134+
135+
Parameters
136+
----------
137+
nodes : np.ndarray
138+
Array of node identifiers to remove from the graph.
139+
"""
140+
def nodes_data(self, nodes: np.ndarray | None = None) -> Iterator[tuple[Any, Any]]:
141+
"""Iterate over nodes and their associated data.
142+
143+
The node_data object provides access to node attributes as defined
144+
by the `node_attr_dtypes` when the graph was created.
145+
146+
Parameters
147+
----------
148+
nodes : np.ndarray, optional
149+
Array of specific node identifiers to iterate over. If None,
150+
iterates over all nodes in the graph.
151+
152+
Yields
153+
------
154+
tuple[Any, Any]
155+
Tuples of (node_id, node_data) where node_data is a view object
156+
providing access to the node's attributes.
157+
"""
158+
def edges_data(self, us: np.ndarray, vs: np.ndarray) -> Iterator:
159+
"""Iterate over edge data for specified edges.
160+
161+
The arrays `us` and `vs` must have the same length. The edge data
162+
objects provide access to edge attributes as defined by the
163+
`edge_attr_dtypes` when the graph was created.
164+
165+
Parameters
166+
----------
167+
us : np.ndarray
168+
Array of source node identifiers.
169+
vs : np.ndarray
170+
Array of target node identifiers.
171+
172+
Yields
173+
------
174+
Any
175+
Edge data view objects providing access to edge attributes
176+
for each edge (us[i], vs[i]).
177+
"""
178+
def num_edges(self) -> int:
179+
"""Get the total number of edges in the graph.
180+
181+
Returns
182+
-------
183+
int
184+
The number of edges in the graph.
185+
"""
186+
def __len__(self) -> int:
187+
"""Return the number of nodes in the graph.
188+
189+
Returns
190+
-------
191+
int
192+
The number of nodes in the graph.
193+
"""
194+
195+
class UnDirectedCGraph(CGraph):
196+
def num_neighbors(self, nodes: np.ndarray) -> np.ndarray:
197+
"""Return the number of neighbors for each node.
198+
199+
For undirected graphs, this counts all adjacent nodes regardless
200+
of edge direction since edges are bidirectional.
201+
202+
Parameters
203+
----------
204+
nodes : np.ndarray
205+
Array of node identifiers to count neighbors for.
206+
207+
Returns
208+
-------
209+
np.ndarray
210+
Array of neighbor counts for each node in the input array.
211+
"""
212+
@overload
213+
def edges(
214+
self, node: Any = ..., data: Literal[True] = ...
215+
) -> Iterator[tuple[tuple, Any]]: ...
216+
@overload
217+
def edges(self, node: Any = ..., data: Literal[False] = ...) -> Iterator[tuple]:
218+
"""Iterate over edges in the graph.
219+
220+
For undirected graphs, each edge is yielded only once with nodes
221+
ordered such that node1 < node2 to avoid duplicates.
222+
223+
Parameters
224+
----------
225+
node : Any, optional
226+
If provided, only iterate over edges incident to this node.
227+
If None, iterate over all edges in the graph.
228+
data : bool, default False
229+
If True, yield (edge, edge_data) tuples. If False, yield
230+
only edge tuples.
231+
232+
Yields
233+
------
234+
tuple or tuple[tuple, Any]
235+
If `data=False`: tuples of (node1, node2) representing edges.
236+
If `data=True`: tuples of ((node1, node2), edge_data) where
237+
edge_data provides access to edge attributes.
238+
"""
239+
def edges_by_nodes(self, nodes: np.ndarray) -> np.ndarray:
240+
"""Get all edges incident to the specified nodes.
241+
242+
This method provides fast access to edges incident to an array
243+
of nodes. Note that edges between nodes in the input array will
244+
be reported multiple times (once for each incident node).
245+
246+
Parameters
247+
----------
248+
nodes : np.ndarray
249+
Array of node identifiers to find incident edges for.
250+
251+
Returns
252+
-------
253+
np.ndarray
254+
2D array of shape (n_edges, 2) where each row contains
255+
[node1, node2] representing an edge. For undirected graphs,
256+
node1 <= node2.
257+
"""
258+
259+
class DirectedCGraph(CGraph):
260+
def num_in_neighbors(self, nodes: np.ndarray) -> np.ndarray:
261+
"""Return the number of incoming neighbors for each node.
262+
263+
This counts only nodes that have edges pointing to the specified nodes
264+
(i.e., predecessors).
265+
266+
Parameters
267+
----------
268+
nodes : np.ndarray
269+
Array of node identifiers to count incoming neighbors for.
270+
271+
Returns
272+
-------
273+
np.ndarray
274+
Array of incoming neighbor counts for each node in the input array.
275+
"""
276+
def num_out_neighbors(self, nodes: np.ndarray) -> np.ndarray:
277+
"""Return the number of outgoing neighbors for each node.
278+
279+
This counts only nodes that the specified nodes
280+
have edges pointing to (i.e., successors).
281+
282+
Parameters
283+
----------
284+
nodes : np.ndarray
285+
Array of node identifiers to count outgoing neighbors for.
286+
287+
Returns
288+
-------
289+
np.ndarray
290+
Array of outgoing neighbor counts for each node in the input array.
291+
"""
292+
@overload
293+
def in_edges(
294+
self, node: Any = None, data: Literal[True] = ...
295+
) -> Iterator[tuple[tuple, Any]]: ...
296+
@overload
297+
def in_edges(self, node: Any, data: Literal[False] = ...) -> Iterator[tuple]:
298+
"""Iterate over incoming edges to a node.
299+
300+
Only edges directed toward the specified node are yielded.
301+
302+
Parameters
303+
----------
304+
node : Any
305+
The target node to find incoming edges for.
306+
data : bool
307+
If True, yield (edge, edge_data) tuples. If False, yield
308+
only edge tuples.
309+
310+
Yields
311+
------
312+
tuple or tuple[tuple, Any]
313+
If `data=False`: tuples of (source_node, target_node) representing
314+
incoming edges where target_node is the specified node.
315+
If `data=True`: tuples of ((source_node, target_node), edge_data)
316+
where edge_data provides access to edge attributes.
317+
"""
318+
def in_edges_by_nodes(self, nodes: np.ndarray) -> np.ndarray:
319+
"""Get all incoming edges to the specified nodes.
320+
321+
This method provides fast access to incoming edges for an array
322+
of nodes. Edges between nodes in the input array will be reported
323+
multiple times if both source and target are in the array.
324+
325+
Parameters
326+
----------
327+
nodes : np.ndarray
328+
Array of node identifiers to find incoming edges for.
329+
330+
Returns
331+
-------
332+
np.ndarray
333+
2D array of shape (n_edges, 2) where each row contains
334+
[source_node, target_node] representing an incoming edge
335+
to one of the specified nodes.
336+
"""
337+
@overload
338+
def out_edges(
339+
self, node: Any = None, data: Literal[True] = ...
340+
) -> Iterator[tuple[tuple, Any]]: ...
341+
@overload
342+
def out_edges(self, node: Any, data: Literal[False] = ...) -> Iterator[tuple]:
343+
"""Iterate over outgoing edges from a node.
344+
345+
Only edges directed away from the specified node are yielded.
346+
347+
Parameters
348+
----------
349+
node : Any
350+
The source node to find outgoing edges for.
351+
data : bool
352+
If True, yield (edge, edge_data) tuples. If False, yield
353+
only edge tuples.
354+
355+
Yields
356+
------
357+
tuple or tuple[tuple, Any]
358+
If `data=False`: tuples of (source_node, target_node) representing
359+
outgoing edges where source_node is the specified node.
360+
If `data=True`: tuples of ((source_node, target_node), edge_data)
361+
where edge_data provides access to edge attributes.
362+
"""
363+
def out_edges_by_nodes(self, nodes: np.ndarray) -> np.ndarray:
364+
"""Get all outgoing edges from the specified nodes.
365+
366+
This method provides fast access to outgoing edges for an array
367+
of nodes. Edges between nodes in the input array will be reported
368+
multiple times if both source and target are in the array.
369+
370+
Parameters
371+
----------
372+
nodes : np.ndarray
373+
Array of node identifiers to find outgoing edges for.
374+
375+
Returns
376+
-------
377+
np.ndarray
378+
2D array of shape (n_edges, 2) where each row contains
379+
[source_node, target_node] representing an outgoing edge
380+
from one of the specified nodes.
381+
"""

0 commit comments

Comments
 (0)