22from prtools .backend import numpy as np
33
44
5- def dft2 (f , alpha , shape = None , shift = (0 , 0 ), offset = (0 , 0 ), unitary = True ,
6- out = None ):
5+ def dft2 (f , alpha , shape = None , shift = (0 , 0 ), offset = (0 , 0 ), axes = ( - 2 , - 1 ) ,
6+ unitary = True , out = None ):
77 r"""Compute the 2-dimensional discrete Fourier Transform.
88
99 The DFT is defined in one dimension as
@@ -31,19 +31,22 @@ def dft2(f, alpha, shape=None, shift=(0, 0), offset=(0, 0), unitary=True,
3131 ``F.shape = (shape[0], shape[1])``. Default is ``f.shape``.
3232 shift : array_like, optional
3333 Number of pixels in (r,c) to shift the DC pixel in the output plane
34- with the origin centrally located in the plane. Default is ``(0,0)``.
34+ with the origin centrally located in the plane. Default is ``(0, 0)``.
3535 offset : array_like, optional
3636 Number of pixels in (r,c) that the input plane is shifted relative to
37- the origin. Default is ``(0,0)``.
37+ the origin. Default is ``(0, 0)``.
38+ axes : (2,) array_like of ints, optional
39+ Axes over which to compute the DFT. If not given, the last two axes are
40+ used.
3841 unitary : bool, optional
3942 Normalization flag. If ``True``, a normalization is performed on the
4043 output such that the DFT operation is unitary and energy is conserved
4144 through the Fourier transform operation (Parseval's theorem). In this
4245 way, the energy in in a limited-area DFT is a fraction of the total
4346 energy corresponding to the limited area. Default is ``True``.
4447 out : ndarray or None
45- A location into which the result is stored. If provided, out.shape ==
46- shape and out.dtype == np.complex. If not provided or None, a
48+ A location into which the result is stored. If provided, `` out.shape ==
49+ shape`` and `` out.dtype == np.complex`` . If not provided or None, a
4750 freshly-allocated array is returned.
4851
4952 Returns
@@ -76,36 +79,42 @@ def dft2(f, alpha, shape=None, shift=(0, 0), offset=(0, 0), unitary=True,
7679 [1] Soummer, et. al. Fast computation of Lyot-style coronagraph
7780 propagation (2007)
7881 """
79- return _dftcore (f , alpha , shape , shift , offset , unitary , out , forward = True )
8082
83+ return _dftcore (f , alpha , shape , shift , offset , axes , unitary ,
84+ forward = True , out = out )
8185
82- def _dftcore (f , alpha , shape , shift , offset , unitary , out , forward ):
8386
84- #__backend__ = prtools.__backend__
87+ def _dftcore ( f , alpha , shape , shift , offset , axes , unitary , forward , out ):
8588
86- if out is not None :
87- if __backend__ == 'numpy' :
89+ if __backend__ == 'numpy' :
90+ if out is not None :
8891 if not np .can_cast (complex , out .dtype ):
8992 raise TypeError (f"Cannot cast complex output to dtype('{ out .dtype } ')" )
90- elif __backend__ == 'jax' :
91- raise ValueError ('JAX backend does not support the out parameter' )
9293
93- alpha_row , alpha_col = np .broadcast_to (alpha , (2 ,))
94+ elif __backend__ == 'jax' :
95+ if out is not None :
96+ raise ValueError ('JAX backend does not support the out parameter' )
9497
9598 f = np .asarray (f )
96- m , n = f .shape
9799
98- if shape is None :
99- shape = (m , n )
100- M , N = shape
100+ out_shape , axes = _cook_nd_args (f , shape , axes )
101+ in_shape = np .take (np .asarray (f .shape ), axes )
101102
103+ m , n = in_shape
104+ M , N = out_shape
105+
106+ alpha_row , alpha_col = np .broadcast_to (alpha , (2 ,))
102107 shift_row , shift_col = np .broadcast_to (shift , (2 ,))
103108 offset_row , offset_col = np .broadcast_to (offset , (2 ,))
104109
105- E1 , E2 = _dft2_matrices (m , n , M , N , alpha_row , alpha_col , shift_row , shift_col ,
106- offset_row , offset_col , forward )
107-
108- F = np .dot (E1 .dot (f ), E2 , out = out )
110+ E1 , E2 = _dft2_matrices (m , n , M , N , alpha_row , alpha_col , shift_row ,
111+ shift_col , offset_row , offset_col , forward )
112+ # note there's no function _multi_dot_three in the base numpy namespace
113+ # (although this function does exist in numpy.linalg). What's really being
114+ # called here is prtools.backend.numpy._multi_dot_three, which provides
115+ # different highly optimized implementations of the matrix triple product
116+ # depending on which backend is active.
117+ F = np ._multi_dot_three (E1 , f , E2 , axes , out )
109118
110119 if unitary :
111120 F = np .multiply (F , np .sqrt (np .abs (alpha_row * alpha_col )), out = F )
@@ -117,28 +126,21 @@ def _dftcore(f, alpha, shape, shift, offset, unitary, out, forward):
117126 return F
118127
119128
120- def _dft2_matrices (m , n , M , N , alphar , alphac , shiftr , shiftc , offsetr , offsetc , forward ):
129+ def _dft2_matrices (m , n , M , N , alphar , alphac , shiftr , shiftc , offsetr ,
130+ offsetc , forward ):
121131 if forward :
122- sign = - 1
132+ c = - 1j
123133 else :
124- sign = 1
134+ c = 1j
125135 R , S , U , V = _dft2_coords (m , n , M , N )
126- E1 = np .exp (sign * 2.0 * 1j * np .pi * alphar * np .outer (R + offsetr , U - shiftr )).T
127- E2 = np .exp (sign * 2.0 * 1j * np .pi * alphac * np .outer (S + offsetc , V - shiftc ))
128- return E1 , E2
129-
130-
131- def _idft2_matrices (m , n , M , N , alphar , alphac , shiftr , shiftc , offsetr , offsetc ):
132- R , S , U , V = _dft2_coords (m , n , M , N )
133- E1 = np .exp (2.0 * 1j * np .pi * alphar * np .outer (R + offsetr , U - shiftr )).T
134- E2 = np .exp (2.0 * 1j * np .pi * alphac * np .outer (S + offsetc , V - shiftc ))
136+ E1 = np .exp (2.0 * c * np .pi * alphar * np .outer (R + offsetr , U - shiftr )).T
137+ E2 = np .exp (2.0 * c * np .pi * alphac * np .outer (S + offsetc , V - shiftc ))
135138 return E1 , E2
136139
137140
138141def _dft2_coords (m , n , M , N ):
139142 # R and S are (r,c) coordinates in the (m x n) input plane f
140143 # V and U are (r,c) coordinates in the (M x N) output plane F
141-
142144 R = np .arange (m ) - np .floor (m / 2.0 )
143145 S = np .arange (n ) - np .floor (n / 2.0 )
144146 U = np .arange (M ) - np .floor (M / 2.0 )
@@ -147,7 +149,28 @@ def _dft2_coords(m, n, M, N):
147149 return R , S , U , V
148150
149151
150- def idft2 (F , alpha , shape = None , shift = (0 ,0 ), offset = (0 ,0 ), unitary = True , out = None ):
152+ def _cook_nd_args (a , s = None , axes = None ):
153+ # slightly modified version of numpy's function of the same name
154+ if s is None :
155+ if axes is None :
156+ if a .ndim == 2 :
157+ s = list (a .shape )
158+ elif a .ndim == 3 :
159+ s = list (a .shape [1 :3 ])
160+ else :
161+ raise ValueError ("Array must have ndim == 2 or 3" )
162+ else :
163+ s = np .take (a .shape , axes )
164+ s = list (s )
165+ if axes is None :
166+ axes = list (range (- len (s ), 0 ))
167+ if len (s ) != len (axes ):
168+ raise ValueError ("Shape and axes have different lengths." )
169+ return s , axes
170+
171+
172+ def idft2 (F , alpha , shape = None , shift = (0 , 0 ), offset = (0 , 0 ), axes = (- 2 , - 1 ),
173+ unitary = True , out = None ):
151174 r"""Compute the 2-dimensional inverse discrete Fourier Transform.
152175
153176 The IDFT is defined in one dimension as
@@ -165,24 +188,34 @@ def idft2(F, alpha, shape=None, shift=(0,0), offset=(0,0), unitary=True, out=Non
165188 F : array_like
166189 2D array to Fourier Transform
167190 alpha : float or array_like
168- Input plane sampling interval (frequency). If :attr:`alpha` is an array,
169- ``alpha[1]`` represents row-wise sampling and ``alpha[2]`` represents
170- column-wise sampling. If :attr:`alpha` is a scalar,
191+ Input plane sampling interval (frequency). If :attr:`alpha` is an
192+ array, ``alpha[1]`` represents row-wise sampling and ``alpha[2]``
193+ represents column-wise sampling. If :attr:`alpha` is a scalar,
171194 ``alpha[1] = alpha[2] = alpha`` represents uniform sampling across the
172195 rows and columns of the input plane.
173196 shape : int or array_like, optional
174197 Size of the output array :attr:`F`. If :attr:`npshapeix` is an array,
175198 ``F.shape = (shape[0], shape[1])``. If :attr:`shape` is a scalar,
176199 ``F.shape = (shape, shape)``. Default is ``F.shape``
177200 shift : array_like, optional
178- Number of pixels in (x,y) to shift the DC pixel in the output plane with
179- the origin centrally located in the plane. Default is `[0,0]`.
201+ Number of pixels in (x,y) to shift the DC pixel in the output plane
202+ with the origin centrally located in the plane. Default is `[0,0]`.
203+ offset : array_like, optional
204+ Number of pixels in (r,c) that the input plane is shifted relative to
205+ the origin. Default is ``(0, 0)``.
206+ axes : (2,) array_like of ints, optional
207+ Axes over which to compute the DFT. If not given, the last two axes are
208+ used.
180209 unitary : bool, optional
181210 Normalization flag. If ``True``, a normalization is performed on the
182211 output such that the DFT operation is unitary and energy is conserved
183212 through the Fourier transform operation (Parseval's theorem). In this
184213 way, the energy in in a limited-area DFT is a fraction of the total
185214 energy corresponding to the limited area. Default is ``True``.
215+ out : ndarray or None
216+ A location into which the result is stored. If provided, ``out.shape ==
217+ shape`` and ``out.dtype == np.complex``. If not provided or None, a
218+ freshly-allocated array is returned.
186219
187220 Returns
188221 -------
@@ -205,13 +238,15 @@ def idft2(F, alpha, shape=None, shift=(0,0), offset=(0,0), unitary=True, out=Non
205238
206239 * If the y-axis shift behavior is not what you are expecting, you most
207240 likely have your plotting axes flipped (matplotlib's default behavior is
208- to place [0,0] in the upper left corner of the axes). This may be resolved
209- by either flipping the sign of the y component of ``shift`` or by passing
210- ``origin = 'lower'`` to ``imshow()``.
241+ to place [0,0] in the upper left corner of the axes). This may be
242+ resolved by either flipping the sign of the y component of ``shift`` or
243+ by passing ``origin = 'lower'`` to ``imshow()``.
211244
212245 References
213246 ----------
214- [1] Soummer, et. al. Fast computation of Lyot-style coronagraph propagation (2007)
215-
247+ [1] Soummer, et. al. Fast computation of Lyot-style coronagraph
248+ propagation (2007)
249+
216250 """
217- return _dftcore (F , alpha , shape , shift , offset , unitary , out , forward = False )
251+ return _dftcore (F , alpha , shape , shift , offset , axes , unitary ,
252+ forward = False , out = out )
0 commit comments