diff --git a/.github/environment.yml b/.github/environment.yml index e53af9e88..a60569499 100644 --- a/.github/environment.yml +++ b/.github/environment.yml @@ -13,7 +13,6 @@ dependencies: - numba>=0.57 - xarray>=2022.03 - verde>=1.9.0 - - xrft>=1.0 - choclo>=0.1 - boule>=0.6.0 # Build diff --git a/README.md b/README.md index 278388f94..727dba81a 100644 --- a/README.md +++ b/README.md @@ -52,9 +52,8 @@ Things that will *not* be covered in Harmonica: - Multi-physics partial differential equation solvers. Use [SimPEG](http://www.simpeg.xyz/) or [PyGIMLi](https://www.pygimli.org/) instead. -- Generic grid processing methods (like FFT and standard interpolation). - We'll rely on [Verde](https://www.fatiando.org/verde), - [xrft](https://xrft.readthedocs.io/en/latest/) and +- Generic grid processing methods (like standard interpolation). + We'll rely on [Verde](https://www.fatiando.org/verde) and [xarray](https://xarray.dev) for those. - Data visualization. - GUI applications. diff --git a/doc/api/index.rst b/doc/api/index.rst index 2e99f1927..4477ff25b 100644 --- a/doc/api/index.rst +++ b/doc/api/index.rst @@ -51,8 +51,6 @@ Define filters in the frequency domain. filters.gaussian_highpass_kernel filters.reduction_to_pole_kernel -Use :func:`xrft.xrft.fft` and :func:`xrft.xrft.ifft` to apply Fast-Fourier -Transforms and its inverse on :class:`xarray.DataArray`. Equivalent Sources ------------------ diff --git a/doc/conf.py b/doc/conf.py index 68fc5486e..509ec96b4 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -49,7 +49,6 @@ "scipy": ("https://docs.scipy.org/doc/scipy/reference", None), "pandas": ("http://pandas.pydata.org/pandas-docs/stable/", None), "xarray": ("http://xarray.pydata.org/en/stable/", None), - "xrft": ("https://xrft.readthedocs.io/en/stable/", None), "pooch": ("https://www.fatiando.org/pooch/latest/", None), "ensaio": ("https://www.fatiando.org/ensaio/latest/", None), "verde": ("https://www.fatiando.org/verde/latest/", None), diff --git a/doc/install.rst b/doc/install.rst index c91c9ff8f..d286971f2 100644 --- a/doc/install.rst +++ b/doc/install.rst @@ -74,7 +74,6 @@ Required: * `scikit-learn `__ * `pooch `__ * `verde `__ -* `xrft `__ Optional: diff --git a/doc/overview.rst b/doc/overview.rst index fdd39857c..efd0e6eca 100644 --- a/doc/overview.rst +++ b/doc/overview.rst @@ -19,8 +19,7 @@ Harmonica *will not* provide: instead. - Generic processing methods like grid transformations (use `Verde `__ or `Xarray `__ - instead) or multidimensional FFT calculations (use `xrft - `__ instead). + instead). - Reference ellipsoid representations and computations like normal gravity. Use `Boule `__ instead. - Data visualization functions. Use `matplotlib `__ diff --git a/environment.yml b/environment.yml index 956cf3305..b7fea32fd 100644 --- a/environment.yml +++ b/environment.yml @@ -15,7 +15,6 @@ dependencies: - scikit-learn - verde>=1.8.1 - xarray - - xrft>=1.0 - choclo>=0.1 - boule>=0.6 # Optional requirements diff --git a/pyproject.toml b/pyproject.toml index cad85d5de..2d24b0e25 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,6 @@ dependencies = [ "numba >= 0.57", "xarray >= 2022.03", "verde >= 1.8.1", - "xrft >= 1.0", "choclo >= 0.1", "boule >= 0.6.0" ] diff --git a/src/harmonica/filters/_fft.py b/src/harmonica/filters/_fft.py index d65d2de32..6c6f8454e 100644 --- a/src/harmonica/filters/_fft.py +++ b/src/harmonica/filters/_fft.py @@ -5,14 +5,19 @@ # This code is part of the Fatiando a Terra project (https://www.fatiando.org) # """ -Wrap xrft functions to compute FFTs and inverse FFTs. +Custom FFT and inverse FFT functions that work with :class:`xarray.DataArray`. + +These functions are inspired in the ``fft`` and ``ifft`` functions provided by ``xrft``, +which are released under the MIT license. """ -import xrft +import numpy as np +import numpy.typing as npt +import xarray as xr -def fft(grid, true_phase=True, true_amplitude=True, **kwargs): - """ +def fft(grid, *, prefix="freq_"): + r""" Compute Fast Fourier Transform of a 2D regular grid. Parameters @@ -22,60 +27,305 @@ def fft(grid, true_phase=True, true_amplitude=True, **kwargs): evenly spaced (regular grid). Its dimensions should be in the following order: *northing*, *easting*. Its coordinates should be defined in the same units. - true_phase : bool (optional) - Take the coordinates into consideration, keeping the original phase of - the coordinates in the spatial domain (``direct_lag``) and multiplies - the FFT with an exponential function corresponding to this phase. - Defaults to True. - true_amplitude : bool (optional) - If True, the FFT is multiplied by the spacing of the transformed - variables to match theoretical FT amplitude. - Defaults to True. - **kwargs - Any extra keyword arguments will be passed the :func:`xrft.fft` function. + prefix : str, optional + Prefix used for the name of the frequency coordinates and dimensions. + Returns ------- fourier_transform : :class:`xarray.DataArray` Array with the Fourier transform of the original grid. + + Notes + ----- + This function implements the discrete Fourier Transform of 2D regular grids. It's + based on the following definition of the Fourier Transform of + a :math:`g:\mathbb{R}^2 \rightarrow \mathbb{R}` function in the spatial domain: + + .. math:: + + \mathcal{F}[g](f_x, f_y) = + \int\limits_{-\infty}^{\infty} + \int\limits_{-\infty}^{\infty} + g(x, y) e^{-2 \pi i f_x x} e^{-2 \pi i f_y y} + \text{d}x + \text{d}y + + + If we consider two discretized spaces for :math:`x` and :math:`y`, both evenly + spaced with steps equal to :math:`\Delta x` and :math:`\Delta y`, respectively, then + the :math:`(k, j)` element of the discrete Fourier Transform of :math:`g` can be + defined as follows: + + .. math:: + + \mathcal{F}[g]_{(j, k)} = + \sum\limits_{n=0}^{N-1} + \sum\limits_{m=0}^{M-1} + g(x_n, y_m) e^{-2 \pi i f_x_j x_n} e^{-2 \pi i f_y_k y_m} + \Delta x \Delta y + + This function differs from the plain :func:`numpy.fft.fftn` function since it + implements the *true amplitude* and the *true phase* corrections. + """ - return xrft.fft( - grid, true_phase=true_phase, true_amplitude=true_amplitude, **kwargs + if not isinstance(grid, xr.DataArray): + msg = ( + f"Invalid 'grid' of type '{type(grid).__name__}'. " + "It must be an xarray.DataArray." + ) + raise TypeError(msg) + if grid.ndim != 2: + msg = ( + f"Invalid grid array with '{grid.ndim}' dimensions. It must be a 2D array." + ) + raise ValueError(msg) + + # Get dimensional coordinates, spacings, and coordinates' shifts + dimensional_coords = tuple( + _get_dimensional_coordinate(grid, dim) for dim in grid.dims ) + spacings = tuple(_get_spacing(grid.coords[coord]) for coord in dimensional_coords) + shifts = tuple(grid.coords[coord].values.min() for coord in dimensional_coords) + # Generate new coordinates + freqs = tuple( + _fftfreq(grid.coords[coord], spacing) + for coord, spacing in zip(dimensional_coords, spacings, strict=True) + ) -def ifft(fourier_transform, true_phase=True, true_amplitude=True, **kwargs): - """ - Compute Inverse Fast Fourier Transform of a 2D regular grid. + # Compute FFT + fft = np.fft.fftshift(np.fft.fftn(grid.values)) + + # Account for true amplitude and true phase + freqs_2d = np.meshgrid(freqs[1], freqs[0]) # invert order to create meshgrid + for freq, shift, spacing in zip(freqs_2d, shifts, spacings, strict=True): + fft *= np.exp(-2 * 1j * np.pi * freq * shift) * spacing + + # Build the FFT xr.DataArray + dims = tuple(f"{prefix}{dim}" for dim in grid.dims) + coords = { + f"{prefix}{coord}": (dim, freq) + for coord, dim, freq in zip(dimensional_coords, dims, freqs, strict=True) + } + da_fft = xr.DataArray(fft, dims=dims, coords=coords) + + # Add shifts to frequency coordinates + for coord, shift in zip(coords, shifts, strict=True): + da_fft.coords[coord].attrs.update({"shift": shift}) + return da_fft + + +def ifft(fft_grid, *, prefix="freq_"): + r""" + Compute the inverse Fast Fourier Transform of a 2D regular grid. + + If the frequency coordinates have a *shift* attribute, it will be used to shift the + coordinates in the spatial domain to such value. + + .. important:: + + Assumes that the ``fft_grid`` is *shifted*: it was passed to + :func:`numpy.fft.fftshift`. The outputs of the ``fft`` function satisfy this + condition. Parameters ---------- - fourier_transform : :class:`xarray.DataArray` + fft_grid : :class:`xarray.DataArray` Array with a regular grid defined in the frequency domain. Its dimensions should be in the following order: *freq_northing*, *freq_easting*. - true_phase : bool (optional) - Take the coordinates into consideration, recovering the original - coordinates in the spatial domain returning to the the original phase - (``direct_lag``), and multiplies the iFFT with an exponential function - corresponding to this phase. - Defaults to True. - true_amplitude : bool (optional) - If True, output is divided by the spacing of the transformed variables - to match theoretical IFT amplitude. - Defaults to True. - **kwargs - Any extra keyword arguments will be passed the :func:`xrft.ifft` function. + prefix : str, optional + Prefix used for the name of the frequency coordinates and dimensions. Returns ------- grid : :class:`xarray.DataArray` Array with the inverse Fourier transform of the passed grid. + + Notes + ----- + This function implements the discrete inverse Fourier Transform of 2D regular grids. + It's based on the following definition of the inverse Fourier Transform of + a :math:`G:\mathbb{R}^2 \rightarrow \mathbb{R}` function in the frequency + domain: + + .. math:: + + \mathcal{F}^{-1}[G](f_x, f_y) = + \int\limits_{-\infty}^{\infty} + \int\limits_{-\infty}^{\infty} + G(f_x, f_y) e^{2 \pi i f_x x} e^{2 \pi i f_y y} + \text{d}f_x + \text{d}f_y + + + If we consider two discretized spaces for the :math:`f_x` and :math:`f_y` + frequencies, both evenly spaced with steps equal to :math:`\Delta f_x` and + :math:`\Delta f_y`, respectively, then the :math:`(n, m)` element of the discrete + inverse Fourier Transform of :math:`G` can be defined as follows: + + .. math:: + + \mathcal{F}^{-1}[G]_{(n, m)} = + \sum\limits_{j=0}^{N-1} + \sum\limits_{k=0}^{M-1} + G(f_x_j, f_y_k) e^{2 \pi i f_x_j x_n} e^{2 \pi i f_y_k y_m} + \Delta f_x \Delta f_y + + This function differs from the plain :func:`numpy.fft.ifftn` function since it + implements the *true amplitude* and the *true phase* corrections. + """ - return xrft.ifft( - fourier_transform, - true_phase=true_phase, - true_amplitude=true_amplitude, - lag=(None, None), # Mutes an annoying FutureWarning from xrft - **kwargs, + if not isinstance(fft_grid, xr.DataArray): + msg = ( + f"Invalid 'grid' of type '{type(fft_grid).__name__}'. " + "It must be an xarray.DataArray." + ) + raise TypeError(msg) + if fft_grid.ndim != 2: + msg = ( + f"Invalid grid array with '{fft_grid.ndim}' dimensions. " + "It must be a 2D array." + ) + raise ValueError(msg) + + for dim in fft_grid.dims: + if not dim.startswith(prefix): + msg = ( + f"Invalid frequency dimension '{dim}'. " + f"It doesn't start with prefix '{prefix}'." + ) + raise ValueError(msg) + + # Get dimensional frequency coordinates and spacings + dimensional_fft_coords = tuple( + _get_dimensional_coordinate(fft_grid, dim) for dim in fft_grid.dims + ) + for coord in dimensional_fft_coords: + if not coord.startswith(prefix): + msg = ( + f"Invalid dimensional coordinate '{coord}'. " + f"It doesn't start with prefix '{prefix}'." + ) + raise ValueError(msg) + + # Generate new coordinates + coords = tuple( + _ifftfreq(fft_grid.coords[coord]) for coord in dimensional_fft_coords ) + + # Account for true amplitude and true phase + freqs = tuple(fft_grid.coords[coord].values for coord in dimensional_fft_coords) + freqs_2d = np.meshgrid(freqs[1], freqs[0]) # invert order to create meshgrid + fft_grid = fft_grid.copy() + for coord, freq in zip(coords, freqs_2d, strict=True): + shift = coord[0] + spacing = coord[1] - coord[0] + fft_grid *= np.exp(2 * 1j * np.pi * freq * shift) / spacing + + # Compute iFFT + ifft = np.fft.ifftn(np.fft.ifftshift(fft_grid.values)) + + # Build new xr.DataArray + dims = tuple(dim.removeprefix(prefix) for dim in fft_grid.dims) + coords = { + coord.removeprefix(prefix): (dim, coordinate_array) + for coord, dim, coordinate_array in zip( + dimensional_fft_coords, dims, coords, strict=True + ) + } + da = xr.DataArray(ifft, dims=dims, coords=coords) + + return da + + +def _get_spacing(coordinate: xr.DataArray) -> float: + """ + Return spacing of a grid coordinate. + + Parameters + ---------- + coordinate : xarray.DataArray + DataArray containing the coordinate. + coordinate : str + Coordinate name. + + Returns + ------- + spacing : float + """ + spacing = coordinate.values[1] - coordinate.values[0] + if not np.allclose(spacing, coordinate.values[1:] - coordinate.values[:-1]): + msg = f"Invalid '{coordinate.name}' coordinates: they must be evenly spaced." + raise ValueError(msg) + if spacing <= 0: + msg = ( + f"Invalid coordinate '{coordinate.name}': it must be increasingly ordered." + ) + raise ValueError(msg) + return spacing + + +def _get_dimensional_coordinate(grid: xr.DataArray, dim: str) -> str: + """ + Get dimensional coordinate in the grid for a particular dimension. + + Parameters + ---------- + grid : xarray.DataArray + DataArray containing the coordinate. + dim : str + Dimension name. + + Returns + ------- + dimensional_coordinate : str + """ + potential_coords = [ + coord for coord in grid.coords if grid.coords[coord].dims == (dim,) + ] + if not potential_coords: + msg = f"Couldn't find dimensional coordinate for dimension '{dim}'." + raise ValueError(msg) + if len(potential_coords) > 1: + bad_coords = ", ".join(potential_coords) + msg = ( + f"Multiple dimensional coordinates ({bad_coords}) found " + f"for the '{dim}' dimension. " + "Leave only one dimensional coordinate per dimension." + ) + raise ValueError(msg) + (dimensional_coordinate,) = potential_coords + return dimensional_coordinate + + +def _fftfreq(coordinate: xr.DataArray, spacing: float | None = None) -> npt.NDArray: + """ + Get coordinate into the frequency domain. + """ + if coordinate.ndim != 1: + raise ValueError() + if spacing is None: + spacing = _get_spacing(coordinate) + return np.fft.fftshift(np.fft.fftfreq(coordinate.size, spacing)) + + +def _ifftfreq(freq: xr.DataArray, spacing: float | None = None) -> npt.NDArray: + """ + Recover coordinate in the space domain from the frequency domain. + + Shifts the coordinates in the spatial domain if the ``freq`` has a *shift* + attribute. + """ + if freq.ndim != 1: + raise ValueError() + if spacing is None: + spacing = _get_spacing(freq) + coordinate = np.fft.fftshift(np.fft.fftfreq(freq.size, spacing)) + + # Apply static shift if any + if "shift" in freq.attrs: + coordinate += freq.attrs["shift"] - coordinate.min() + + return coordinate diff --git a/test/test_filters.py b/test/test_filters.py index 07ff4b9ce..6d2345ea5 100644 --- a/test/test_filters.py +++ b/test/test_filters.py @@ -142,7 +142,7 @@ def fixture_invalid_grid_with_nans(sample_grid): def test_fft_round_trip(sample_grid): """ - Test if the wrapped fft and ifft functions satisfy a round trip. + Test if the fft and ifft functions satisfy a round trip. """ xrt.assert_allclose(sample_grid, ifft(fft(sample_grid)))