Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Removed
* Dropped support for Python 3.9 [gh-243](https://github.com/IntelPython/mkl_fft/pull/243)

### Fixed
* Fix `TypeError` exception raised with empty axes [gh-288](https://github.com/IntelPython/mkl_fft/pull/288)

## [2.1.2] - 2025-12-02

### Added
Expand Down
21 changes: 17 additions & 4 deletions mkl_fft/_fft_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,10 @@ def _iter_complementary(x, axes, func, kwargs, result):
m_ind = _flat_to_multi(ind, sub_shape)
for k1, k2 in zip(dual_ind, m_ind):
sl[k1] = k2
tsl = tuple(sl)

if np.issubdtype(x.dtype, np.complexfloating):
func(x[tuple(sl)], **kwargs, out=result[tuple(sl)])
func(x[tsl], **kwargs, out=result[tsl])
else:
# For c2c FFT, if the input is real, half of the output is the
# complex conjugate of the other half. Instead of upcasting the
Expand All @@ -247,7 +249,7 @@ def _iter_complementary(x, axes, func, kwargs, result):
# array appeared in the second half of the NumPy output array,
# while the equivalent element in the NumPy array was the conjugate
# of the mkl_fft output array.
np.copyto(result[tuple(sl)], func(x[tuple(sl)], **kwargs))
np.copyto(result[tsl], func(x[tsl], **kwargs))

return result

Expand All @@ -260,7 +262,6 @@ def _iter_fftnd(
direction=+1,
scale_function=lambda ind: 1.0,
):
a = np.asarray(a)
Comment thread
antonwolfy marked this conversation as resolved.
s, axes = _init_nd_shape_and_axes(a, s, axes)

# Combine the two, but in reverse, to end with the first axis given.
Expand Down Expand Up @@ -412,8 +413,20 @@ def _c2c_fftnd_impl(
out=out,
)
else:
x = np.asarray(x)

# Fast path: FFT over no axes is complete identity (preserve dtype)
_, xa = _cook_nd_args(x, s, axes)
if len(xa) == 0:
if out is None:
out = x.copy()
else:
_validate_out_array(out, x, x.dtype)
np.copyto(out, x)
# No scaling applied - identity transform has no normalization
return out

if _complementary and x.dtype in valid_dtypes:
x = np.asarray(x)
if out is None:
res = np.empty_like(x, dtype=_output_dtype(x.dtype))
else:
Expand Down
45 changes: 45 additions & 0 deletions mkl_fft/tests/test_fftnd.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,3 +317,48 @@ def test_out_strided(axes, func):
expected = getattr(np.fft, func)(x, axes=axes, out=out)

assert_allclose(result, expected, strict=True)


@pytest.mark.parametrize(
"dtype", [np.float32, np.float64, np.complex64, np.complex128]
)
@pytest.mark.parametrize("shape", [(3, 4), (5,), (2, 3, 4), (10, 20)])
@pytest.mark.parametrize("norm", [None, "ortho", "forward", "backward"])
@pytest.mark.parametrize("func", ["fftn", "ifftn", "fft2", "ifft2"])
def test_empty_axes(dtype, shape, norm, func):
if np.issubdtype(dtype, np.complexfloating):
x = rnd.random(shape).astype(dtype) + 1j * rnd.random(shape).astype(
dtype
)
else:
x = rnd.random(shape).astype(dtype)

# Test fftn with axes=()
result = getattr(mkl_fft, func)(x, axes=(), norm=norm)
expected = getattr(np.fft, func)(x, axes=(), norm=norm)

rtol, atol = _get_rtol_atol(result)
assert_allclose(result, expected, rtol=rtol, atol=atol, strict=True)


@pytest.mark.parametrize(
"dtype", [np.float32, np.float64, np.complex64, np.complex128]
)
@pytest.mark.parametrize("func", ["fftn", "ifftn", "fft2", "ifft2"])
def test_empty_axes_with_out(dtype, func):
if np.issubdtype(dtype, np.complexfloating):
x = rnd.random((3, 4)).astype(dtype) + 1j * rnd.random((3, 4)).astype(
dtype
)
else:
x = rnd.random((3, 4)).astype(dtype)

# For axes=(), output dtype should match input dtype (identity transform)
out = np.empty_like(x, dtype=dtype)
result = getattr(mkl_fft, func)(x, axes=(), out=out)
expected = getattr(np.fft, func)(x, axes=())

# Result should be written to out
assert result is out
rtol, atol = _get_rtol_atol(result)
assert_allclose(result, expected, rtol=rtol, atol=atol, strict=True)
Loading