Skip to content

Commit c101f28

Browse files
committed
Overhaul dft2 backend
* Support taking the dft over arbitrary axes of a 3d array * Rework _dftcore to compute the matrix triple product more efficiently for both the numpy and jax implementations
1 parent 1e09ae2 commit c101f28

5 files changed

Lines changed: 169 additions & 83 deletions

File tree

prtools/backend/_jax.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
class Numpy(BackendLibrary):
77
def __init__(self):
88
super().__init__(importlib.import_module('jax.numpy'))
9+
self.jax = importlib.import_module('jax')
910

1011
def broadcast_to(self, array, shape):
1112
# jax numpy.broadcast_to expects an array input
@@ -43,6 +44,43 @@ def sum(self, a, *args, **kwargs):
4344
a = self.module.asarray(a)
4445
return self.module.sum(a, *args, **kwargs)
4546

47+
def take(self, a, indices, *args, **kwargs):
    """Take elements from ``a`` at ``indices`` using the backend module.

    Both ``a`` and ``indices`` are converted with the backend's ``asarray``
    first, because jax numpy.take expects an array input for both of them.
    Any extra positional and keyword arguments are forwarded unchanged.
    """
    xp = self.module
    return xp.take(xp.asarray(a), xp.asarray(indices), *args, **kwargs)
52+
53+
def _multi_dot_three(self, a, b, c, axes, out):
    """Compute the matrix triple product ``a @ b @ c`` for the JAX backend.

    Parameters
    ----------
    a, c : array
        2-D matrices applied on the left and right of ``b``.
    b : array
        2-D matrix, or a 3-D stack of matrices.
    axes : sequence of two ints
        ``(row_axis, col_axis)`` of ``b`` holding each matrix. The result
        places its rows and columns on the same two axes.
    out : None
        Unused here; present so the signature matches the numpy-based
        implementation of this method.

    Returns
    -------
    array
        The triple product, laid out along ``axes`` like the input.

    Raises
    ------
    ValueError
        If ``b.ndim`` is not 2 or 3.

    Notes
    -----
    * While the numpy-based implementation of this method is based on
      np.matmul, jax.numpy.matmul doesn't implement the axes argument so
      jax.numpy.linalg.multi_dot is used instead.
    * ``b`` is first rearranged so that the two requested axes are last
      (and in the requested order). This makes the order of ``axes``
      significant, matching what the ``axes`` argument of np.matmul does
      in the numpy-based implementation.
    * jax.vmap handles the case when b.ndim == 3, mapping over the one
      remaining (leading) axis.
    """
    xp = self.module
    if b.ndim not in (2, 3):
        raise ValueError("Array must have ndim == 2 or 3")

    core = (-2, -1)
    src = tuple(int(ax) for ax in axes)

    # bring (row_axis, col_axis) to the last two positions, in that order
    planes = xp.moveaxis(b, src, core)

    if planes.ndim == 2:
        prod = xp.linalg.multi_dot((a, planes, c))
    else:
        # after moveaxis the axis not named in ``axes`` is always axis 0
        prod = self.jax.vmap(
            lambda plane: xp.linalg.multi_dot((a, plane, c)),
            in_axes=0, out_axes=0)(planes)

    # put the result's rows/columns back where the caller asked for them
    return xp.moveaxis(prod, core, src)
70+
71+
def _multi_dot(self, a, b, c):
    """Return the backend's ``linalg.multi_dot`` of ``a``, ``b`` and ``c``.

    Exists so the three operands can be passed as separate positional
    arguments, which is the call signature vmap works with.
    """
    chain = (a, b, c)
    return self.module.linalg.multi_dot(chain)
74+
75+
76+
def _iter_axis(axes):
    """Return the axis of a 3-d array that is *not* named in ``axes``.

    Parameters
    ----------
    axes : sequence of two ints
        The two axes in use; negative values count from the end.

    Returns
    -------
    int
        The remaining axis, as a non-negative index in ``(0, 1, 2)``.

    Raises
    ------
    IndexError
        If an axis is outside ``[-3, 3)``.
    ValueError
        If ``axes`` does not name exactly two distinct axes (previously a
        repeated axis such as ``(0, -3)`` silently returned a wrong answer).
    """
    # NOTE: this function is purposely written in pure Python to avoid
    # dealing with mutability issues when __backend__ is JAX
    ndim = 3
    used = set()
    for ax in axes:
        if not -ndim <= ax < ndim:
            raise IndexError(
                f"axis {ax} is out of bounds for array of dimension {ndim}")
        used.add(ax % ndim)
    if len(used) != 2:
        raise ValueError("axes must name two distinct axes of a 3-d array")
    return next(ax for ax in range(ndim) if ax not in used)
83+
4684

4785
class Scipy(BackendLibrary):
4886
def __init__(self):

prtools/backend/_numpy.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,28 @@ def __init__(self):
66
import numpy
77
super().__init__(numpy)
88

9+
def _multi_dot_three(self, a, b, c, axes, out):
    """Compute the matrix triple product ``a @ b @ c`` with np.matmul.

    Parameters
    ----------
    a, c : ndarray
        2-D matrices applied on the left and right of ``b``.
    b : ndarray
        2-D matrix, or a 3-D stack of matrices.
    axes : sequence of two ints
        ``(row_axis, col_axis)`` of ``b`` holding each matrix; the result
        uses the same two axes. Any sequence is accepted (tuple, list, ...).
    out : ndarray or None
        Passed to the final matmul as its ``out`` argument.

    Returns
    -------
    ndarray
        The triple product (``out`` itself when ``out`` is given).

    Notes
    -----
    * this method is similar to np.linalg.multi_dot although it is less
      general - here we only consider the matrix triple product used as
      a part of prtools.dft2
    * because we use np.matmul instead of np.linalg.multi_dot, we can
      take advantage of broadcasting a and c when b.ndim = 3. This
      eliminates a for loop in the code
    * the implementation used here supports b with ndim in (2, 3)
      iterating over any of the 3 axes when b.ndim == 3
    * the implementation used here is actually slightly faster than
      an equivalent call to np.linalg.multi_dot when b.ndim == 2
    * np.linalg.multi_dot chooses the fastest multiplication order from
      [(ab)c, a(bc)] depending on the shapes of a, b, and c. There is no
      difference when computing the dft because both a and c are square
      matrices
    """
    # np.matmul requires each entry of its ``axes`` list to be a tuple and
    # raises TypeError for a list entry, so normalise whatever the caller
    # passed (e.g. ``[-2, -1]``) before using it.
    core = tuple(axes)
    plain = (0, 1)  # a and c are always ordinary 2-D matrices
    ab = self.module.matmul(a, b, axes=[plain, core, core])
    return self.module.matmul(ab, c, axes=[core, plain, core], out=out)
30+
931

1032
class Scipy(BackendLibrary):
1133
def __init__(self):

prtools/fourier.py

Lines changed: 82 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from prtools.backend import numpy as np
33

44

5-
def dft2(f, alpha, shape=None, shift=(0, 0), offset=(0, 0), unitary=True,
6-
out=None):
5+
def dft2(f, alpha, shape=None, shift=(0, 0), offset=(0, 0), axes=(-2, -1),
6+
unitary=True, out=None):
77
r"""Compute the 2-dimensional discrete Fourier Transform.
88
99
The DFT is defined in one dimension as
@@ -31,19 +31,22 @@ def dft2(f, alpha, shape=None, shift=(0, 0), offset=(0, 0), unitary=True,
3131
``F.shape = (shape[0], shape[1])``. Default is ``f.shape``.
3232
shift : array_like, optional
3333
Number of pixels in (r,c) to shift the DC pixel in the output plane
34-
with the origin centrally located in the plane. Default is ``(0,0)``.
34+
with the origin centrally located in the plane. Default is ``(0, 0)``.
3535
offset : array_like, optional
3636
Number of pixels in (r,c) that the input plane is shifted relative to
37-
the origin. Default is ``(0,0)``.
37+
the origin. Default is ``(0, 0)``.
38+
axes : (2,) array_like of ints, optional
39+
Axes over which to compute the DFT. If not given, the last two axes are
40+
used.
3841
unitary : bool, optional
3942
Normalization flag. If ``True``, a normalization is performed on the
4043
output such that the DFT operation is unitary and energy is conserved
4144
through the Fourier transform operation (Parseval's theorem). In this
4245
way, the energy in a limited-area DFT is a fraction of the total
4346
energy corresponding to the limited area. Default is ``True``.
4447
out : ndarray or None
45-
A location into which the result is stored. If provided, out.shape ==
46-
shape and out.dtype == np.complex. If not provided or None, a
48+
A location into which the result is stored. If provided, ``out.shape ==
49+
shape`` and ``out.dtype == np.complex``. If not provided or None, a
4750
freshly-allocated array is returned.
4851
4952
Returns
@@ -76,36 +79,42 @@ def dft2(f, alpha, shape=None, shift=(0, 0), offset=(0, 0), unitary=True,
7679
[1] Soummer, et. al. Fast computation of Lyot-style coronagraph
7780
propagation (2007)
7881
"""
79-
return _dftcore(f, alpha, shape, shift, offset, unitary, out, forward=True)
8082

83+
return _dftcore(f, alpha, shape, shift, offset, axes, unitary,
84+
forward=True, out=out)
8185

82-
def _dftcore(f, alpha, shape, shift, offset, unitary, out, forward):
8386

84-
#__backend__ = prtools.__backend__
87+
def _dftcore(f, alpha, shape, shift, offset, axes, unitary, forward, out):
8588

86-
if out is not None:
87-
if __backend__ == 'numpy':
89+
if __backend__ == 'numpy':
90+
if out is not None:
8891
if not np.can_cast(complex, out.dtype):
8992
raise TypeError(f"Cannot cast complex output to dtype('{out.dtype}')")
90-
elif __backend__ == 'jax':
91-
raise ValueError('JAX backend does not support the out parameter')
9293

93-
alpha_row, alpha_col = np.broadcast_to(alpha, (2,))
94+
elif __backend__ == 'jax':
95+
if out is not None:
96+
raise ValueError('JAX backend does not support the out parameter')
9497

9598
f = np.asarray(f)
96-
m, n = f.shape
9799

98-
if shape is None:
99-
shape = (m, n)
100-
M, N = shape
100+
out_shape, axes = _cook_nd_args(f, shape, axes)
101+
in_shape = np.take(np.asarray(f.shape), axes)
101102

103+
m, n = in_shape
104+
M, N = out_shape
105+
106+
alpha_row, alpha_col = np.broadcast_to(alpha, (2,))
102107
shift_row, shift_col = np.broadcast_to(shift, (2,))
103108
offset_row, offset_col = np.broadcast_to(offset, (2,))
104109

105-
E1, E2 = _dft2_matrices(m, n, M, N, alpha_row, alpha_col, shift_row, shift_col,
106-
offset_row, offset_col, forward)
107-
108-
F = np.dot(E1.dot(f), E2, out=out)
110+
E1, E2 = _dft2_matrices(m, n, M, N, alpha_row, alpha_col, shift_row,
111+
shift_col, offset_row, offset_col, forward)
112+
# note there's no function _multi_dot_three in the base numpy namespace
113+
# (although this function does exist in numpy.linalg). What's really being
114+
# called here is prtools.backend.numpy._multi_dot_three, which provides
115+
# different highly optimized implementations of the matrix triple product
116+
# depending on which backend is active.
117+
F = np._multi_dot_three(E1, f, E2, axes, out)
109118

110119
if unitary:
111120
F = np.multiply(F, np.sqrt(np.abs(alpha_row * alpha_col)), out=F)
@@ -117,28 +126,21 @@ def _dftcore(f, alpha, shape, shift, offset, unitary, out, forward):
117126
return F
118127

119128

120-
def _dft2_matrices(m, n, M, N, alphar, alphac, shiftr, shiftc, offsetr, offsetc, forward):
129+
def _dft2_matrices(m, n, M, N, alphar, alphac, shiftr, shiftc, offsetr,
130+
offsetc, forward):
121131
if forward:
122-
sign = -1
132+
c = -1j
123133
else:
124-
sign = 1
134+
c = 1j
125135
R, S, U, V = _dft2_coords(m, n, M, N)
126-
E1 = np.exp(sign*2.0 * 1j * np.pi * alphar * np.outer(R+offsetr, U-shiftr)).T
127-
E2 = np.exp(sign*2.0 * 1j * np.pi * alphac * np.outer(S+offsetc, V-shiftc))
128-
return E1, E2
129-
130-
131-
def _idft2_matrices(m, n, M, N, alphar, alphac, shiftr, shiftc, offsetr, offsetc):
132-
R, S, U, V = _dft2_coords(m, n, M, N)
133-
E1 = np.exp(2.0 * 1j * np.pi * alphar * np.outer(R+offsetr, U-shiftr)).T
134-
E2 = np.exp(2.0 * 1j * np.pi * alphac * np.outer(S+offsetc, V-shiftc))
136+
E1 = np.exp(2.0 * c * np.pi * alphar * np.outer(R+offsetr, U-shiftr)).T
137+
E2 = np.exp(2.0 * c * np.pi * alphac * np.outer(S+offsetc, V-shiftc))
135138
return E1, E2
136139

137140

138141
def _dft2_coords(m, n, M, N):
139142
# R and S are (r,c) coordinates in the (m x n) input plane f
140143
# V and U are (r,c) coordinates in the (M x N) output plane F
141-
142144
R = np.arange(m) - np.floor(m/2.0)
143145
S = np.arange(n) - np.floor(n/2.0)
144146
U = np.arange(M) - np.floor(M/2.0)
@@ -147,7 +149,28 @@ def _dft2_coords(m, n, M, N):
147149
return R, S, U, V
148150

149151

150-
def idft2(F, alpha, shape=None, shift=(0,0), offset=(0,0), unitary=True, out=None):
152+
def _cook_nd_args(a, s=None, axes=None):
    """Resolve the transform size ``s`` and the transform ``axes`` for ``a``.

    Slightly modified version of numpy's function of the same name.

    Parameters
    ----------
    a : array
        Input array; only ``a.shape`` and ``a.ndim`` are read.
    s : int, sequence of ints or None
        Requested output size. A scalar is repeated once per transform axis
        (two if ``axes`` is None). If None, the size of ``a`` along ``axes``
        is used, or along its last two axes when ``axes`` is also None.
    axes : sequence of ints or None
        Axes to transform. If None, the last ``len(s)`` axes are used.

    Returns
    -------
    s : list
        Output size, one entry per transform axis.
    axes : sequence of ints
        The axes as given, or a list of trailing axes when ``axes`` was None.

    Raises
    ------
    ValueError
        If both ``s`` and ``axes`` are None and ``a.ndim`` is not 2 or 3, or
        if ``s`` and ``axes`` end up with different lengths.
    """
    if s is None:
        if axes is None:
            if a.ndim not in (2, 3):
                raise ValueError("Array must have ndim == 2 or 3")
            # whole shape for 2-d input, shape[1:3] for 3-d input
            s = list(a.shape[-2:])
        else:
            # plain Python indexing keeps the sizes as the ints stored in
            # a.shape instead of routing them through a backend array.
            # NOTE(review): presumably this also keeps them static when the
            # JAX backend is traced - confirm
            s = [a.shape[ax] for ax in axes]
    else:
        try:
            s = list(s)
        except TypeError:
            # scalar size: use the same value along every transform axis
            s = [s] * (2 if axes is None else len(axes))
    if axes is None:
        axes = list(range(-len(s), 0))
    if len(s) != len(axes):
        raise ValueError("Shape and axes have different lengths.")
    return s, axes
170+
171+
172+
def idft2(F, alpha, shape=None, shift=(0, 0), offset=(0, 0), axes=(-2, -1),
173+
unitary=True, out=None):
151174
r"""Compute the 2-dimensional inverse discrete Fourier Transform.
152175
153176
The IDFT is defined in one dimension as
@@ -165,24 +188,34 @@ def idft2(F, alpha, shape=None, shift=(0,0), offset=(0,0), unitary=True, out=Non
165188
F : array_like
166189
2D array to Fourier Transform
167190
alpha : float or array_like
168-
Input plane sampling interval (frequency). If :attr:`alpha` is an array,
169-
``alpha[1]`` represents row-wise sampling and ``alpha[2]`` represents
170-
column-wise sampling. If :attr:`alpha` is a scalar,
191+
Input plane sampling interval (frequency). If :attr:`alpha` is an
192+
array, ``alpha[0]`` represents row-wise sampling and ``alpha[1]``
193+
represents column-wise sampling. If :attr:`alpha` is a scalar,
171194
``alpha[0] = alpha[1] = alpha`` represents uniform sampling across the
172195
rows and columns of the input plane.
173196
shape : int or array_like, optional
174197
Size of the output array :attr:`F`. If :attr:`shape` is an array,
175198
``F.shape = (shape[0], shape[1])``. If :attr:`shape` is a scalar,
176199
``F.shape = (shape, shape)``. Default is ``F.shape``
177200
shift : array_like, optional
178-
Number of pixels in (x,y) to shift the DC pixel in the output plane with
179-
the origin centrally located in the plane. Default is `[0,0]`.
201+
Number of pixels in (x,y) to shift the DC pixel in the output plane
202+
with the origin centrally located in the plane. Default is `[0,0]`.
203+
offset : array_like, optional
204+
Number of pixels in (r,c) that the input plane is shifted relative to
205+
the origin. Default is ``(0, 0)``.
206+
axes : (2,) array_like of ints, optional
207+
Axes over which to compute the DFT. If not given, the last two axes are
208+
used.
180209
unitary : bool, optional
181210
Normalization flag. If ``True``, a normalization is performed on the
182211
output such that the DFT operation is unitary and energy is conserved
183212
through the Fourier transform operation (Parseval's theorem). In this
184213
way, the energy in in a limited-area DFT is a fraction of the total
185214
energy corresponding to the limited area. Default is ``True``.
215+
out : ndarray or None
216+
A location into which the result is stored. If provided, ``out.shape ==
217+
shape`` and ``out.dtype == np.complex``. If not provided or None, a
218+
freshly-allocated array is returned.
186219
187220
Returns
188221
-------
@@ -205,13 +238,15 @@ def idft2(F, alpha, shape=None, shift=(0,0), offset=(0,0), unitary=True, out=Non
205238
206239
* If the y-axis shift behavior is not what you are expecting, you most
207240
likely have your plotting axes flipped (matplotlib's default behavior is
208-
to place [0,0] in the upper left corner of the axes). This may be resolved
209-
by either flipping the sign of the y component of ``shift`` or by passing
210-
``origin = 'lower'`` to ``imshow()``.
241+
to place [0,0] in the upper left corner of the axes). This may be
242+
resolved by either flipping the sign of the y component of ``shift`` or
243+
by passing ``origin = 'lower'`` to ``imshow()``.
211244
212245
References
213246
----------
214-
[1] Soummer, et. al. Fast computation of Lyot-style coronagraph propagation (2007)
215-
247+
[1] Soummer, et. al. Fast computation of Lyot-style coronagraph
248+
propagation (2007)
249+
216250
"""
217-
return _dftcore(F, alpha, shape, shift, offset, unitary, out, forward=False)
251+
return _dftcore(F, alpha, shape, shift, offset, axes, unitary,
252+
forward=False, out=out)

prtools/jax.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class JaxOptimizeResult:
3131
state: Any #: Optimizer state
3232

3333

34-
def lbfgs(fn, x0, gtol=None, maxiter=None, callback=None, fn_args=None,
34+
def lbfgs(fn, x0, gtol=None, maxiter=None, callback=None, fn_args=None,
3535
fn_kwargs=None):
3636
"""Minimize a scalar function of one or more variables using the L-BFGS
3737
algorithm
@@ -44,7 +44,7 @@ def lbfgs(fn, x0, gtol=None, maxiter=None, callback=None, fn_args=None,
4444
.. code:: python
4545
4646
fn(x, *fn_args, **fn_kwargs)
47-
47+
4848
where ``x`` is a 1-D array with shape (n,) and ``fn_args`` and
4949
``fn_kwargs`` are optional positional and keyword arguments.
5050
x0 : jax.Array
@@ -88,14 +88,14 @@ def lbfgs(fn, x0, gtol=None, maxiter=None, callback=None, fn_args=None,
8888

8989
opt = optax.lbfgs()
9090
value_and_grad_fn = optax.value_and_grad_from_state(fn)
91-
91+
9292
def step(carry):
9393
params, state = carry
9494
# NOTE: passing *args and **kwargs to value_and_grad_fun is very
9595
# poorly documented in optax (as of v0.2.6 - 10/2025) but this
9696
# seems to work for now
97-
value, grad = value_and_grad_fn(params, *fn_args, state=state,
98-
**fn_kwargs)
97+
value, grad = value_and_grad_fn(params, *fn_args, state=state,
98+
**fn_kwargs)
9999
updates, state = opt.update(
100100
grad, state, params, value=value, grad=grad, value_fn=fn)
101101
if callback:
@@ -126,8 +126,9 @@ def continuing_criterion(carry):
126126
grad=otu.tree_get(final_state, 'grad'),
127127
value=otu.tree_get(final_state, 'value'),
128128
state=final_state)
129-
129+
130130
if callback:
131131
jax.debug.callback(callback, res)
132132

133133
return res
134+

0 commit comments

Comments
 (0)