11from __future__ import annotations
22
33import sys
4+ import warnings
45from pathlib import Path
5- from typing import TYPE_CHECKING , Literal , overload
6+ from typing import TYPE_CHECKING , Any , overload
67
78import numpy as np
89import witty
910from Cheetah .Template import Template
10- from typing_extensions import Self
1111
1212from spatial_graph .dtypes import DType
1313
1414if TYPE_CHECKING :
15- from collections .abc import Mapping
15+ from collections .abc import Iterable , Mapping
1616
17- from .cgraph import CGraph , DirectedCGraph , UnDirectedCGraph
17+ from typing_extensions import Literal , Self
18+
19+ from .cgraph import CDiGraph , CGraph , CGraphBase
1820
1921
2022# Set platform-specific compile arguments
@@ -70,20 +72,20 @@ def _compile_graph(
7072 node_attr_dtypes : Mapping [str , str ] | None = None ,
7173 edge_attr_dtypes : Mapping [str , str ] | None = None ,
7274 directed : Literal [True ] = ...,
73- ) -> type [DirectedCGraph ]: ...
75+ ) -> type [CDiGraph ]: ...
7476@overload
7577def _compile_graph (
7678 node_dtype : str ,
7779 node_attr_dtypes : Mapping [str , str ] | None = None ,
7880 edge_attr_dtypes : Mapping [str , str ] | None = None ,
7981 directed : bool = ...,
80- ) -> type [UnDirectedCGraph ]: ...
82+ ) -> type [CGraph ]: ...
8183def _compile_graph (
8284 node_dtype : str ,
8385 node_attr_dtypes : Mapping [str , str ] | None = None ,
8486 edge_attr_dtypes : Mapping [str , str ] | None = None ,
8587 directed : bool = False ,
86- ) -> type [CGraph ]:
88+ ) -> type [CGraphBase ]:
8789 wrapper_template = _build_wrapper (
8890 node_dtype = node_dtype ,
8991 node_attr_dtypes = node_attr_dtypes ,
@@ -103,13 +105,15 @@ def _compile_graph(
103105
104106if TYPE_CHECKING :
105107
106- class Graph (CGraph ):
107- node_dtype : str
108- node_attr_dtypes : Mapping [str , str ] | None
109- edge_attr_dtypes : Mapping [str , str ] | None
110- directed : bool
108+ class GraphBase (CGraph ):
109+ """Base class for undirected graph instances."""
110+
111111 node_attrs : NodeAttrs
112112 edge_attrs : EdgeAttrs
113+ node_dtype : str
114+ node_attr_dtypes : Mapping [str , str ]
115+ edge_attr_dtypes : Mapping [str , str ]
116+ directed : bool
113117
114118 def __init__ (
115119 self ,
@@ -118,18 +122,40 @@ def __init__(
118122 edge_attr_dtypes : Mapping [str , str ] | None = None ,
119123 directed : bool = False ,
120124 ): ...
125+
126+ class Graph (GraphBase , CGraph ):
127+ """Base class for undirected graph instances."""
128+
129+ class DiGraph (GraphBase , CDiGraph ):
130+ """Base class for directed graph instances."""
131+
132+
121133else :
122134
123- class Graph :
135+ class GraphBase :
136+ """Base class for compiled graph instances."""
137+
124138 def __new__ (
125139 cls ,
126140 node_dtype : str ,
127141 node_attr_dtypes : Mapping [str , str ] | None = None ,
128142 edge_attr_dtypes : Mapping [str , str ] | None = None ,
129- directed : bool = False ,
130- * args ,
131- ** kwargs ,
143+ directed : bool | None = None ,
144+ * args : Any ,
145+ ** kwargs : Any ,
132146 ) -> Self :
147+ if directed is not None :
148+ warnings .warn (
149+ "The 'directed' argument is deprecated and will be removed in "
150+ "future versions. Use the 'DiGraph' class for directed graphs." ,
151+ DeprecationWarning ,
152+ stacklevel = 2 ,
153+ )
154+ cls = DiGraph if directed else Graph
155+
156+ directed = issubclass (cls , DiGraph )
157+
158+ print ("Compiling graph with directed =" , directed )
133159 # dynamically compile a specialized C++ implementation of the graph
134160 # tailored to the user's specific type requirements.
135161 CGraph = _compile_graph (
@@ -138,13 +164,12 @@ def __new__(
138164 edge_attr_dtypes = edge_attr_dtypes ,
139165 directed = directed ,
140166 )
141- # create a new class that inherits from both this class
142- # and the compiled c++ implementation `wrapper.Graph`
167+
168+ # create a new class that inherits from both the base class
169+ # and the compiled c++ implementation
143170 GraphType = type (cls .__name__ , (cls , CGraph ), {})
144- # call the __new__ method of the native C++ class, but pass the dynamically
145- # created class as the type. This ensures the object will be an instance
146- # of the dynamically created class, but using the C++ allocation logic and
147- # initialization code.
171+
172+ # create and initialize the instance
148173 return CGraph .__new__ (GraphType )
149174
150175 def __init__ (
@@ -154,19 +179,21 @@ def __init__(
154179 edge_attr_dtypes : Mapping [str , str ] | None = None ,
155180 directed : bool = False ,
156181 ):
157- if node_attr_dtypes is None :
158- node_attr_dtypes = {}
159- if edge_attr_dtypes is None :
160- edge_attr_dtypes = {}
161182 super ().__init__ ()
162183 self .node_dtype = node_dtype
163- self .node_attr_dtypes = node_attr_dtypes
164- self .edge_attr_dtypes = edge_attr_dtypes
184+ self .node_attr_dtypes = node_attr_dtypes or {}
185+ self .edge_attr_dtypes = edge_attr_dtypes or {}
165186 self .directed = directed
166187
167188 self .node_attrs = NodeAttrs (self )
168189 self .edge_attrs = EdgeAttrs (self )
169190
191+ class Graph (GraphBase ):
192+ """Base class for undirected graph instances."""
193+
194+ class DiGraph (GraphBase ):
195+ """Base class for directed graph instances."""
196+
170197
171198class NodeAttrsView :
172199 def __init__ (self , graph , nodes ):
@@ -202,7 +229,7 @@ def __init__(self, graph, nodes):
202229 # 3. None
203230 super ().__setattr__ ("nodes" , nodes )
204231
205- def __getattr__ (self , name ) :
232+ def __getattr__ (self , name : str ) -> np . ndarray :
206233 if name in self .graph .node_attr_dtypes :
207234 return getattr (self , f"get_attr_{ name } " )(self .nodes )
208235 else :
@@ -220,9 +247,12 @@ def __iter__(self):
220247
221248
222249class EdgeAttrsView :
223- def __init__ (self , graph , edges ):
250+ graph : GraphBase
251+ edges : np .ndarray | tuple [float , float ] | None
252+
253+ def __init__ (self , graph : GraphBase , edges : np .ndarray | Iterable | None ) -> None :
224254 super ().__setattr__ ("graph" , graph )
225- for name in graph .edge_attr_dtypes . keys () :
255+ for name in graph .edge_attr_dtypes :
226256 super ().__setattr__ (
227257 f"get_attr_{ name } " , getattr (graph , f"get_edges_data_{ name } " )
228258 )
@@ -254,7 +284,7 @@ def __init__(self, graph, edges):
254284 # edges should be an iteratable
255285 try :
256286 # does it have a length?
257- len (edges )
287+ len (edges ) # type: ignore
258288 # case 2 and 3
259289 edges = np .array (edges , dtype = graph .node_dtype )
260290 except Exception as e : # pragma: no cover
@@ -265,8 +295,8 @@ def __init__(self, graph, edges):
265295 if isinstance (edges , np .ndarray ):
266296 if len (edges ) == 0 :
267297 edges = edges .reshape ((0 , 2 ))
268- assert edges .shape [1 ] == 2 , "Edge arrays should have shape (n, 2)"
269- edges = np .ascontiguousarray (edges .T )
298+ assert edges .shape [1 ] == 2 , "Edge arrays should have shape (n, 2)" # type: ignore
299+ edges = np .ascontiguousarray (edges .T ) # type: ignore
270300 elif isinstance (edges , tuple ):
271301 # a single edge
272302 for name in graph .edge_attr_dtypes .keys ():
@@ -283,7 +313,7 @@ def __init__(self, graph, edges):
283313 # 3. None
284314 super ().__setattr__ ("edges" , edges )
285315
286- def __getattr__ (self , name ) :
316+ def __getattr__ (self , name : str ) -> np . ndarray :
287317 if name in self .graph .edge_attr_dtypes :
288318 if self .edges is not None :
289319 return getattr (self , f"get_attr_{ name } " )(self .edges [0 ], self .edges [1 ])
@@ -306,16 +336,16 @@ def __iter__(self):
306336
307337
308338class NodeAttrs (NodeAttrsView ):
309- def __init__ (self , graph ) :
339+ def __init__ (self , graph : GraphBase ) -> None :
310340 super ().__init__ (graph , nodes = None )
311341
312- def __getitem__ (self , nodes ):
342+ def __getitem__ (self , nodes ) -> NodeAttrsView :
313343 return NodeAttrsView (self .graph , nodes )
314344
315345
316346class EdgeAttrs (EdgeAttrsView ):
317- def __init__ (self , graph ) :
347+ def __init__ (self , graph : GraphBase ) -> None :
318348 super ().__init__ (graph , edges = None )
319349
320- def __getitem__ (self , edges ):
350+ def __getitem__ (self , edges ) -> EdgeAttrsView :
321351 return EdgeAttrsView (self .graph , edges )
0 commit comments