Skip to content

Commit ad45705

Browse files
author
jngaravitoc
committed
adding r_bins parameter to Grid3D
1 parent 1a2dd21 commit ad45705

File tree

4 files changed

+434
-72
lines changed

4 files changed

+434
-72
lines changed

CHANGELOG.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,17 @@
33
All notable changes to this project will be documented in this file.
44
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
55

6+
7+
## [0.2.3] – 2026-01-06
8+
9+
### Added
10+
- Added option to pass `r_bins` to `Grid3D`
11+
12+
### Fixed
13+
- Grid3D consistent with *physics* convention
14+
- Grid3D tests and tutorials are updated
15+
16+
617
## [0.2.2] – 2026-01-05
718

819
### Added

EXPtools/visuals/grid.py

Lines changed: 201 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,131 @@
11
import numpy as np
22

3+
34
class Grid3D:
45
"""
56
Generate and transform 3D coordinate grids in Cartesian, Spherical, or Cylindrical systems.
67
7-
Conventions
8-
-----------
9-
- Spherical coordinates: (r, theta, phi)
10-
r >= 0
11-
theta ∈ [0, 2pi) (azimuthal angle in the xy-plane)
12-
phi ∈ [0, pi] (polar angle from +z axis)
8+
This class constructs structured 3D grids in one of three coordinate systems and
9+
provides utilities to transform between them. Grids are stored internally as a
10+
flattened array of shape ``(N, 3)``, where ``N`` is the total number of grid points.
1311
14-
- Cylindrical coordinates: (rho, theta, z)
15-
rho >= 0
16-
theta ∈ [0, 2pi)
17-
z is unbounded
12+
Coordinate Conventions (Physics convention)
13+
--------------------------------------------
14+
- Cartesian: (x, y, z)
15+
- Spherical: (r, theta, phi)
16+
* r >= 0
17+
* theta ∈ [0, π) (polar angle in the xy-plane)
18+
* phi ∈ [0, 2π] (azimuthal angle from +z axis)
19+
- Cylindrical: (rho, phi, z)
20+
* rho >= 0
21+
* phi ∈ [0, 2π)
22+
* z ∈ ℝ
1823
19-
- Cartesian coordinates: (x, y, z)
24+
Notes
25+
-----
26+
- Spherical grids use volume-uniform radial sampling by default
27+
(uniform in :math:`r^3`).
28+
- All grids are generated using ``numpy.meshgrid`` with ``indexing='ij'``.
29+
- Angular coordinates always follow standard *physics* conventions.
2030
2131
Parameters
2232
----------
2333
system : str
24-
One of {'cartesian', 'spherical', 'cylindrical'}.
25-
ranges : list of tuple
34+
Coordinate system for the grid. Must be one of
35+
``{'cartesian', 'spherical', 'cylindrical'}``.
36+
ranges : list of tuple or None
2637
Axis limits. Interpretation depends on `system`:
27-
- cartesian: [(x_min, x_max), (y_min, y_max), (z_min, z_max)]
28-
e.g. [(-1, 1), (-1, 1), (-1, 1)]
29-
- cylindrical: [(rho_min, rho_max), None, (z_min, z_max)]
30-
e.g. [(0, 5), None, (-2, 2)]
31-
(theta is always [0, 2π))
32-
- spherical: [(r_min, r_max), None, None]
33-
e.g. [(0, 10), None, None]
34-
(theta ∈ [0, 2π), phi ∈ [0, π] by default)
38+
39+
* cartesian:
40+
``[(x_min, x_max), (y_min, y_max), (z_min, z_max)]``
41+
42+
* cylindrical:
43+
``[(rho_min, rho_max), None, (z_min, z_max)]``
44+
(phi is always [0, 2π))
45+
46+
* spherical:
47+
``[(r_min, r_max), None, None]``
48+
(theta ∈ [0, π), phi ∈ [0, 2π])
3549
num_points : int or list of int, optional
36-
Number of grid points per dimension. If int, applied to all.
37-
For example:
38-
- cartesian: [nx, ny, nz]
39-
- cylindrical: [nrho, ntheta, nz]
40-
- spherical: [nr, ntheta, nphi]
50+
Number of grid points per dimension. If an integer is provided,
51+
the same value is used for all dimensions. Default is 10.
52+
53+
Attributes
54+
----------
55+
system : str
56+
Coordinate system of the grid.
57+
ranges : list
58+
Axis limits used to construct the grid.
59+
num_points : list of int
60+
Number of grid points per axis.
61+
grid : ndarray, shape (N, 3)
62+
Flattened grid of coordinates in the native system.
4163
"""
4264

43-
def __init__(self, system, ranges, num_points=10):
65+
def __init__(self, system, ranges, num_points=10, r_bins=None):
66+
"""
67+
Initialize a 3D coordinate grid.
68+
69+
Parameters
70+
----------
71+
system : str
72+
Coordinate system of the grid. Must be one of
73+
``{'cartesian', 'spherical', 'cylindrical'}``.
74+
75+
ranges : list of tuple or None
76+
Axis limits defining the extent of the grid. Interpretation
77+
depends on the coordinate system:
78+
79+
* cartesian:
80+
``[(x_min, x_max), (y_min, y_max), (z_min, z_max)]``
81+
82+
* cylindrical:
83+
``[(rho_min, rho_max), None, (z_min, z_max)]``
84+
(phi is always [0, 2π))
85+
86+
* spherical:
87+
``[(r_min, r_max), None, None]``
88+
(theta ∈ [0, π), phi ∈ [0, 2π])
89+
90+
When ``r_bins`` is provided for spherical grids, ``ranges[0]`` is
91+
ignored.
92+
93+
num_points : int or list of int, optional
94+
Number of grid points per axis. If an integer is provided, the
95+
same value is used for all dimensions. Default is 10.
96+
97+
For spherical grids, ``num_points[0]`` is ignored when ``r_bins``
98+
is supplied.
99+
100+
r_bins : array_like, optional
101+
Explicit radial bin centers or edges for spherical grids.
102+
If provided, the radial coordinate ``r`` is taken directly from
103+
this array instead of being generated internally.
104+
105+
Notes:
106+
* Only applicable when ``system='spherical'``.
107+
* Must be one-dimensional and strictly non-negative.
108+
* Enables non-uniform, user-controlled radial sampling
109+
(e.g. logarithmic bins, adaptive bins, simulation outputs).
110+
* Useful when matching observational or simulation-derived
111+
radial grids.
112+
113+
Raises
114+
------
115+
ValueError
116+
If ``r_bins`` is provided for a non-spherical coordinate system.
117+
"""
118+
119+
44120
self.system = system.lower()
45121
self.supported_systems = ['cartesian', 'spherical', 'cylindrical']
46122

47123
if self.system not in self.supported_systems:
48-
raise ValueError(f"Unsupported system: {self.system}. "
49-
f"Supported systems are {self.supported_systems}")
124+
raise ValueError(
125+
f"Unsupported system: {self.system}. "
126+
f"Supported systems are {self.supported_systems}"
127+
)
128+
50129

51130
# Normalize num_points
52131
if isinstance(num_points, int):
@@ -57,56 +136,85 @@ def __init__(self, system, ranges, num_points=10):
57136
# Validate ranges
58137
if self.system == 'cartesian':
59138
if len(ranges) != 3:
60-
raise ValueError("Cartesian requires 3 ranges [(x_min, x_max), (y_min, y_max), (z_min, z_max)]")
139+
raise ValueError("Cartesian requires 3 ranges")
61140
elif self.system == 'cylindrical':
62141
if len(ranges) != 3:
63-
raise ValueError("Cylindrical requires 3 entries [(rho_min, rho_max), None, (z_min, z_max)]")
142+
raise ValueError("Cylindrical requires 3 entries")
64143
elif self.system == 'spherical':
65144
if len(ranges) != 3 or not isinstance(ranges[0], tuple):
66145
raise ValueError("Spherical requires [(r_min, r_max), None, None]")
67146

147+
if r_bins is not None:
148+
if self.system != 'spherical':
149+
raise ValueError("r_bins can only be used with spherical grids")
150+
151+
r_bins = np.asarray(r_bins, dtype=float)
152+
if r_bins.ndim != 1 or np.any(r_bins <= 0):
153+
raise ValueError("r_bins must be a 1D array of positive values")
154+
if not np.all(np.diff(r_bins) > 0):
155+
raise ValueError("r_bins must be strictly increasing")
156+
157+
self.r_bins = r_bins
68158
self.ranges = ranges
69159
self.num_points = num_points
70160
self.grid = self._generate_grid()
71161

72162
def _generate_grid(self):
163+
"""
164+
Generate the grid in the native coordinate system.
165+
166+
Returns
167+
-------
168+
ndarray, shape (N, 3)
169+
Flattened coordinate grid.
170+
"""
73171
if self.system == 'spherical':
74-
r_min, r_max = self.ranges[0]
75-
r_points = np.linspace(r_min**3, r_max**3, self.num_points[0]) # uniform in volume
76-
r = r_points**(1/3)
77172

78-
theta = np.linspace(0, 2 * np.pi, self.num_points[1], endpoint=False)
79-
u = np.linspace(-1, 1, self.num_points[2])
80-
phi = np.arccos(u)
173+
if self.r_bins is not None:
174+
r = self.r_bins
175+
else:
176+
# Evenly distributed by volume
177+
r_min, r_max = self.ranges[0]
178+
r_points = np.linspace(r_min**3, r_max**3, self.num_points[0])
179+
r = r_points ** (1/3)
81180

181+
u = np.linspace(-1, 1, self.num_points[1])
182+
theta = np.arccos(u)
183+
phi = np.linspace(0, 2 * np.pi, self.num_points[2], endpoint=False)
184+
82185
R, Theta, Phi = np.meshgrid(r, theta, phi, indexing='ij')
83186
return np.stack([R, Theta, Phi], axis=-1).reshape(-1, 3)
84187

85188
elif self.system == 'cylindrical':
86189
rho = np.linspace(*self.ranges[0], self.num_points[0])
87-
theta = np.linspace(0, 2 * np.pi, self.num_points[1], endpoint=False)
190+
phi = np.linspace(0, 2 * np.pi, self.num_points[1], endpoint=False)
88191
z = np.linspace(*self.ranges[2], self.num_points[2])
89-
Rho, Theta, Z = np.meshgrid(rho, theta, z, indexing='ij')
90-
return np.stack([Rho, Theta, Z], axis=-1).reshape(-1, 3)
192+
Rho, Phi, Z = np.meshgrid(rho, phi, z, indexing='ij')
193+
return np.stack([Rho, Phi, Z], axis=-1).reshape(-1, 3)
91194

92195
elif self.system == 'cartesian':
93-
axes = [np.linspace(start, stop, n) for (start, stop), n in zip(self.ranges, self.num_points)]
196+
axes = [
197+
np.linspace(start, stop, n)
198+
for (start, stop), n in zip(self.ranges, self.num_points)
199+
]
94200
mesh = np.meshgrid(*axes, indexing='ij')
95201
return np.stack(mesh, axis=-1).reshape(-1, 3)
96202

203+
97204
def to(self, target_system):
98205
"""
99-
Transform the current grid to a new coordinate system.
206+
Transform the grid to a different coordinate system.
100207
101208
Parameters
102209
----------
103210
target_system : str
104-
One of {'cartesian', 'spherical', 'cylindrical'}
211+
Target coordinate system. Must be one of
212+
``{'cartesian', 'spherical', 'cylindrical'}``.
105213
106214
Returns
107215
-------
108-
np.ndarray
109-
Transformed (N, 3) grid.
216+
ndarray, shape (N, 3)
217+
Grid transformed to the target coordinate system.
110218
"""
111219
target_system = target_system.lower()
112220
if target_system == self.system:
@@ -123,40 +231,77 @@ def to(self, target_system):
123231
return self._from_cartesian(cartesian, target_system)
124232

125233
def _to_cartesian(self, grid, system):
234+
"""
235+
Convert a grid from spherical or cylindrical to Cartesian coordinates.
236+
237+
Parameters
238+
----------
239+
grid : ndarray, shape (N, 3)
240+
Input grid.
241+
system : {'spherical', 'cylindrical'}
242+
Original coordinate system.
243+
244+
Returns
245+
-------
246+
ndarray, shape (N, 3)
247+
Cartesian coordinates.
248+
"""
126249
if system == 'spherical':
127250
r, theta, phi = grid[:, 0], grid[:, 1], grid[:, 2]
128-
x = r * np.sin(phi) * np.cos(theta)
129-
y = r * np.sin(phi) * np.sin(theta)
130-
z = r * np.cos(phi)
251+
x = r * np.sin(theta) * np.cos(phi)
252+
y = r * np.sin(theta) * np.sin(phi)
253+
z = r * np.cos(theta)
131254
return np.column_stack((x, y, z))
132255
elif system == 'cylindrical':
133-
rho, theta, z = grid[:, 0], grid[:, 1], grid[:, 2]
134-
x = rho * np.cos(theta)
135-
y = rho * np.sin(theta)
256+
rho, phi, z = grid[:, 0], grid[:, 1], grid[:, 2]
257+
x = rho * np.cos(phi)
258+
y = rho * np.sin(phi)
136259
return np.column_stack((x, y, z))
137260
else:
138261
raise ValueError("Invalid system for Cartesian conversion")
139262

140263
def _from_cartesian(self, cartesian, target_system):
264+
"""
265+
Convert Cartesian coordinates to another coordinate system.
266+
267+
Parameters
268+
----------
269+
cartesian : ndarray, shape (N, 3)
270+
Cartesian coordinates.
271+
target_system : {'spherical', 'cylindrical'}
272+
Desired coordinate system.
273+
274+
Returns
275+
-------
276+
ndarray, shape (N, 3)
277+
Transformed coordinates.
278+
"""
141279
x, y, z = cartesian[:, 0], cartesian[:, 1], cartesian[:, 2]
142280

143281
if target_system == 'spherical':
144282
r = np.sqrt(x**2 + y**2 + z**2)
145283
# Avoid division by zero
146284
with np.errstate(invalid='ignore', divide='ignore'):
147-
theta = np.arctan2(y, x)
148-
phi = np.arccos(np.clip(z / np.where(r == 0, 1, r), -1.0, 1.0))
285+
phi = np.arctan2(y, x)
286+
theta = np.arccos(np.clip(z / np.where(r == 0, 1, r), -1.0, 1.0))
149287
return np.column_stack((r, theta, phi))
150288

151289
elif target_system == 'cylindrical':
152290
rho = np.sqrt(x**2 + y**2)
153-
theta = np.arctan2(y, x)
154-
return np.column_stack((rho, theta, z))
291+
phi = np.arctan2(y, x)
292+
return np.column_stack((rho, phi, z))
155293

156294
else:
157295
raise ValueError("Invalid target system")
158296

159297
def get(self):
160-
"""Return the raw grid in its original coordinate system."""
298+
"""
299+
Return the raw grid in its native coordinate system.
300+
301+
Returns
302+
-------
303+
ndarray, shape (N, 3)
304+
Grid coordinates.
305+
"""
161306
return self.grid
162307

0 commit comments

Comments
 (0)