11import numpy as np
22
3+
34class Grid3D :
45 """
56 Generate and transform 3D coordinate grids in Cartesian, Spherical, or Cylindrical systems.
67
7- Conventions
8- -----------
9- - Spherical coordinates: (r, theta, phi)
10- r >= 0
11- theta ∈ [0, 2pi) (azimuthal angle in the xy-plane)
12- phi ∈ [0, pi] (polar angle from +z axis)
8+ This class constructs structured 3D grids in one of three coordinate systems and
9+ provides utilities to transform between them. Grids are stored internally as a
10+ flattened array of shape ``(N, 3)``, where ``N`` is the total number of grid points.
1311
14- - Cylindrical coordinates: (rho, theta, z)
15- rho >= 0
16- theta ∈ [0, 2pi)
17- z is unbounded
12+ Coordinate Conventions (Physics convention)
13+ --------------------------------------------
14+ - Cartesian: (x, y, z)
15+ - Spherical: (r, theta, phi)
16+ * r >= 0
17+ * theta ∈ [0, π) (polar angle in the xy-plane)
18+ * phi ∈ [0, 2π] (azimuthal angle from +z axis)
19+ - Cylindrical: (rho, phi, z)
20+ * rho >= 0
21+ * phi ∈ [0, 2π)
22+ * z ∈ ℝ
1823
19- - Cartesian coordinates: (x, y, z)
24+ Notes
25+ -----
26+ - Spherical grids use volume-uniform radial sampling by default
27+ (uniform in :math:`r^3`).
28+ - All grids are generated using ``numpy.meshgrid`` with ``indexing='ij'``.
29+ - Angular coordinates always follow standard *physics* conventions.
2030
2131 Parameters
2232 ----------
2333 system : str
24- One of {'cartesian', 'spherical', 'cylindrical'}.
25- ranges : list of tuple
34+ Coordinate system for the grid. Must be one of
35+ ``{'cartesian', 'spherical', 'cylindrical'}``.
36+ ranges : list of tuple or None
2637 Axis limits. Interpretation depends on `system`:
27- - cartesian: [(x_min, x_max), (y_min, y_max), (z_min, z_max)]
28- e.g. [(-1, 1), (-1, 1), (-1, 1)]
29- - cylindrical: [(rho_min, rho_max), None, (z_min, z_max)]
30- e.g. [(0, 5), None, (-2, 2)]
31- (theta is always [0, 2π))
32- - spherical: [(r_min, r_max), None, None]
33- e.g. [(0, 10), None, None]
34- (theta ∈ [0, 2π), phi ∈ [0, π] by default)
38+
39+ * cartesian:
40+ ``[(x_min, x_max), (y_min, y_max), (z_min, z_max)]``
41+
42+ * cylindrical:
43+ ``[(rho_min, rho_max), None, (z_min, z_max)]``
44+ (phi is always [0, 2π))
45+
46+ * spherical:
47+ ``[(r_min, r_max), None, None]``
48+ (theta ∈ [0, π), phi ∈ [0, 2π])
3549 num_points : int or list of int, optional
36- Number of grid points per dimension. If int, applied to all.
37- For example:
38- - cartesian: [nx, ny, nz]
39- - cylindrical: [nrho, ntheta, nz]
40- - spherical: [nr, ntheta, nphi]
50+ Number of grid points per dimension. If an integer is provided,
51+ the same value is used for all dimensions. Default is 10.
52+
53+ Attributes
54+ ----------
55+ system : str
56+ Coordinate system of the grid.
57+ ranges : list
58+ Axis limits used to construct the grid.
59+ num_points : list of int
60+ Number of grid points per axis.
61+ grid : ndarray, shape (N, 3)
62+ Flattened grid of coordinates in the native system.
4163 """
4264
43- def __init__ (self , system , ranges , num_points = 10 ):
65+ def __init__ (self , system , ranges , num_points = 10 , r_bins = None ):
66+ """
67+ Initialize a 3D coordinate grid.
68+
69+ Parameters
70+ ----------
71+ system : str
72+ Coordinate system of the grid. Must be one of
73+ ``{'cartesian', 'spherical', 'cylindrical'}``.
74+
75+ ranges : list of tuple or None
76+ Axis limits defining the extent of the grid. Interpretation
77+ depends on the coordinate system:
78+
79+ * cartesian:
80+ ``[(x_min, x_max), (y_min, y_max), (z_min, z_max)]``
81+
82+ * cylindrical:
83+ ``[(rho_min, rho_max), None, (z_min, z_max)]``
84+ (phi is always [0, 2π))
85+
86+ * spherical:
87+ ``[(r_min, r_max), None, None]``
88+ (theta ∈ [0, π), phi ∈ [0, 2π])
89+
90+ When ``r_bins`` is provided for spherical grids, ``ranges[0]`` is
91+ ignored.
92+
93+ num_points : int or list of int, optional
94+ Number of grid points per axis. If an integer is provided, the
95+ same value is used for all dimensions. Default is 10.
96+
97+ For spherical grids, ``num_points[0]`` is ignored when ``r_bins``
98+ is supplied.
99+
100+ r_bins : array_like, optional
101+ Explicit radial bin centers or edges for spherical grids.
102+ If provided, the radial coordinate ``r`` is taken directly from
103+ this array instead of being generated internally.
104+
105+ Notes:
106+ * Only applicable when ``system='spherical'``.
107+ * Must be one-dimensional and strictly non-negative.
108+ * Enables non-uniform, user-controlled radial sampling
109+ (e.g. logarithmic bins, adaptive bins, simulation outputs).
110+ * Useful when matching observational or simulation-derived
111+ radial grids.
112+
113+ Raises
114+ ------
115+ ValueError
116+ If ``r_bins`` is provided for a non-spherical coordinate system.
117+ """
118+
119+
44120 self .system = system .lower ()
45121 self .supported_systems = ['cartesian' , 'spherical' , 'cylindrical' ]
46122
47123 if self .system not in self .supported_systems :
48- raise ValueError (f"Unsupported system: { self .system } . "
49- f"Supported systems are { self .supported_systems } " )
124+ raise ValueError (
125+ f"Unsupported system: { self .system } . "
126+ f"Supported systems are { self .supported_systems } "
127+ )
128+
50129
51130 # Normalize num_points
52131 if isinstance (num_points , int ):
@@ -57,56 +136,85 @@ def __init__(self, system, ranges, num_points=10):
57136 # Validate ranges
58137 if self .system == 'cartesian' :
59138 if len (ranges ) != 3 :
60- raise ValueError ("Cartesian requires 3 ranges [(x_min, x_max), (y_min, y_max), (z_min, z_max)] " )
139+ raise ValueError ("Cartesian requires 3 ranges" )
61140 elif self .system == 'cylindrical' :
62141 if len (ranges ) != 3 :
63- raise ValueError ("Cylindrical requires 3 entries [(rho_min, rho_max), None, (z_min, z_max)] " )
142+ raise ValueError ("Cylindrical requires 3 entries" )
64143 elif self .system == 'spherical' :
65144 if len (ranges ) != 3 or not isinstance (ranges [0 ], tuple ):
66145 raise ValueError ("Spherical requires [(r_min, r_max), None, None]" )
67146
147+ if r_bins is not None :
148+ if self .system != 'spherical' :
149+ raise ValueError ("r_bins can only be used with spherical grids" )
150+
151+ r_bins = np .asarray (r_bins , dtype = float )
152+ if r_bins .ndim != 1 or np .any (r_bins <= 0 ):
153+ raise ValueError ("r_bins must be a 1D array of positive values" )
154+ if not np .all (np .diff (r_bins ) > 0 ):
155+ raise ValueError ("r_bins must be strictly increasing" )
156+
157+ self .r_bins = r_bins
68158 self .ranges = ranges
69159 self .num_points = num_points
70160 self .grid = self ._generate_grid ()
71161
72162 def _generate_grid (self ):
163+ """
164+ Generate the grid in the native coordinate system.
165+
166+ Returns
167+ -------
168+ ndarray, shape (N, 3)
169+ Flattened coordinate grid.
170+ """
73171 if self .system == 'spherical' :
74- r_min , r_max = self .ranges [0 ]
75- r_points = np .linspace (r_min ** 3 , r_max ** 3 , self .num_points [0 ]) # uniform in volume
76- r = r_points ** (1 / 3 )
77172
78- theta = np .linspace (0 , 2 * np .pi , self .num_points [1 ], endpoint = False )
79- u = np .linspace (- 1 , 1 , self .num_points [2 ])
80- phi = np .arccos (u )
173+ if self .r_bins is not None :
174+ r = self .r_bins
175+ else :
176+ # Evenly distributed by volume
177+ r_min , r_max = self .ranges [0 ]
178+ r_points = np .linspace (r_min ** 3 , r_max ** 3 , self .num_points [0 ])
179+ r = r_points ** (1 / 3 )
81180
181+ u = np .linspace (- 1 , 1 , self .num_points [1 ])
182+ theta = np .arccos (u )
183+ phi = np .linspace (0 , 2 * np .pi , self .num_points [2 ], endpoint = False )
184+
82185 R , Theta , Phi = np .meshgrid (r , theta , phi , indexing = 'ij' )
83186 return np .stack ([R , Theta , Phi ], axis = - 1 ).reshape (- 1 , 3 )
84187
85188 elif self .system == 'cylindrical' :
86189 rho = np .linspace (* self .ranges [0 ], self .num_points [0 ])
87- theta = np .linspace (0 , 2 * np .pi , self .num_points [1 ], endpoint = False )
190+ phi = np .linspace (0 , 2 * np .pi , self .num_points [1 ], endpoint = False )
88191 z = np .linspace (* self .ranges [2 ], self .num_points [2 ])
89- Rho , Theta , Z = np .meshgrid (rho , theta , z , indexing = 'ij' )
90- return np .stack ([Rho , Theta , Z ], axis = - 1 ).reshape (- 1 , 3 )
192+ Rho , Phi , Z = np .meshgrid (rho , phi , z , indexing = 'ij' )
193+ return np .stack ([Rho , Phi , Z ], axis = - 1 ).reshape (- 1 , 3 )
91194
92195 elif self .system == 'cartesian' :
93- axes = [np .linspace (start , stop , n ) for (start , stop ), n in zip (self .ranges , self .num_points )]
196+ axes = [
197+ np .linspace (start , stop , n )
198+ for (start , stop ), n in zip (self .ranges , self .num_points )
199+ ]
94200 mesh = np .meshgrid (* axes , indexing = 'ij' )
95201 return np .stack (mesh , axis = - 1 ).reshape (- 1 , 3 )
96202
203+
97204 def to (self , target_system ):
98205 """
99- Transform the current grid to a new coordinate system.
206+ Transform the grid to a different coordinate system.
100207
101208 Parameters
102209 ----------
103210 target_system : str
104- One of {'cartesian', 'spherical', 'cylindrical'}
211+ Target coordinate system. Must be one of
212+ ``{'cartesian', 'spherical', 'cylindrical'}``.
105213
106214 Returns
107215 -------
108- np. ndarray
109- Transformed (N, 3) grid .
216+ ndarray, shape (N, 3)
217+ Grid transformed to the target coordinate system .
110218 """
111219 target_system = target_system .lower ()
112220 if target_system == self .system :
@@ -123,40 +231,77 @@ def to(self, target_system):
123231 return self ._from_cartesian (cartesian , target_system )
124232
125233 def _to_cartesian (self , grid , system ):
234+ """
235+ Convert a grid from spherical or cylindrical to Cartesian coordinates.
236+
237+ Parameters
238+ ----------
239+ grid : ndarray, shape (N, 3)
240+ Input grid.
241+ system : {'spherical', 'cylindrical'}
242+ Original coordinate system.
243+
244+ Returns
245+ -------
246+ ndarray, shape (N, 3)
247+ Cartesian coordinates.
248+ """
126249 if system == 'spherical' :
127250 r , theta , phi = grid [:, 0 ], grid [:, 1 ], grid [:, 2 ]
128- x = r * np .sin (phi ) * np .cos (theta )
129- y = r * np .sin (phi ) * np .sin (theta )
130- z = r * np .cos (phi )
251+ x = r * np .sin (theta ) * np .cos (phi )
252+ y = r * np .sin (theta ) * np .sin (phi )
253+ z = r * np .cos (theta )
131254 return np .column_stack ((x , y , z ))
132255 elif system == 'cylindrical' :
133- rho , theta , z = grid [:, 0 ], grid [:, 1 ], grid [:, 2 ]
134- x = rho * np .cos (theta )
135- y = rho * np .sin (theta )
256+ rho , phi , z = grid [:, 0 ], grid [:, 1 ], grid [:, 2 ]
257+ x = rho * np .cos (phi )
258+ y = rho * np .sin (phi )
136259 return np .column_stack ((x , y , z ))
137260 else :
138261 raise ValueError ("Invalid system for Cartesian conversion" )
139262
140263 def _from_cartesian (self , cartesian , target_system ):
264+ """
265+ Convert Cartesian coordinates to another coordinate system.
266+
267+ Parameters
268+ ----------
269+ cartesian : ndarray, shape (N, 3)
270+ Cartesian coordinates.
271+ target_system : {'spherical', 'cylindrical'}
272+ Desired coordinate system.
273+
274+ Returns
275+ -------
276+ ndarray, shape (N, 3)
277+ Transformed coordinates.
278+ """
141279 x , y , z = cartesian [:, 0 ], cartesian [:, 1 ], cartesian [:, 2 ]
142280
143281 if target_system == 'spherical' :
144282 r = np .sqrt (x ** 2 + y ** 2 + z ** 2 )
145283 # Avoid division by zero
146284 with np .errstate (invalid = 'ignore' , divide = 'ignore' ):
147- theta = np .arctan2 (y , x )
148- phi = np .arccos (np .clip (z / np .where (r == 0 , 1 , r ), - 1.0 , 1.0 ))
285+ phi = np .arctan2 (y , x )
286+ theta = np .arccos (np .clip (z / np .where (r == 0 , 1 , r ), - 1.0 , 1.0 ))
149287 return np .column_stack ((r , theta , phi ))
150288
151289 elif target_system == 'cylindrical' :
152290 rho = np .sqrt (x ** 2 + y ** 2 )
153- theta = np .arctan2 (y , x )
154- return np .column_stack ((rho , theta , z ))
291+ phi = np .arctan2 (y , x )
292+ return np .column_stack ((rho , phi , z ))
155293
156294 else :
157295 raise ValueError ("Invalid target system" )
158296
159297 def get (self ):
160- """Return the raw grid in its original coordinate system."""
298+ """
299+ Return the raw grid in its native coordinate system.
300+
301+ Returns
302+ -------
303+ ndarray, shape (N, 3)
304+ Grid coordinates.
305+ """
161306 return self .grid
162307
0 commit comments