Skip to content

Commit 207493e

Browse files
committed
Add CUDA backend for FFT and IFFT
1 parent e2e2c99 commit 207493e

11 files changed

Lines changed: 271 additions & 22 deletions

File tree

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ if(BUILD_METAL)
2222
add_compile_options(-DMODMESH_METAL)
2323
endif()
2424

25+
option(BUILD_CUDA "build with CUDA" OFF)
26+
message(STATUS "BUILD_CUDA: ${BUILD_CUDA}")
27+
2528
option(USE_CLANG_TIDY "use clang-tidy" OFF)
2629
option(LINT_AS_ERRORS "clang-tidy warnings as errors" OFF)
2730

Makefile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ DEBUG_SYMBOL ?= ON
2424
MODMESH_PROFILE ?= OFF
2525
BUILD_METAL ?= OFF
2626
BUILD_QT ?= ON
27+
BUILD_CUDA ?= OFF
2728
USE_CLANG_TIDY ?= OFF
2829
CMAKE_BUILD_TYPE ?= Release
2930
MAKE_PARALLEL ?= -j
@@ -89,6 +90,7 @@ CMAKE_CMD = cmake $(MODMESH_ROOT) \
8990
-DDEBUG_SYMBOL=$(DEBUG_SYMBOL) \
9091
-DBUILD_METAL=$(BUILD_METAL) \
9192
-DBUILD_QT=$(BUILD_QT) \
93+
-DBUILD_CUDA=$(BUILD_CUDA) \
9294
-DUSE_CLANG_TIDY=$(USE_CLANG_TIDY) \
9395
-DLINT_AS_ERRORS=ON \
9496
-DMODMESH_PROFILE=$(MODMESH_PROFILE) \

cpp/modmesh/CMakeLists.txt

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,17 @@ if (BUILD_QT)
3030
find_package(Qt6 REQUIRED COMPONENTS 3DExtras)
3131
endif () # BUILD_QT
3232

33+
if (BUILD_CUDA)
34+
find_package(CUDA REQUIRED)
35+
find_package(CUDAToolkit REQUIRED)
36+
enable_language(CUDA)
37+
if (TARGET CUDA::cufft)
38+
message(STATUS "CUDA cuFFT available")
39+
else ()
40+
message(FATAL_ERROR "CUDA cuFFT not found")
41+
endif ()
42+
endif () # BUILD_CUDA
43+
3344
add_subdirectory(buffer)
3445
add_subdirectory(mesh)
3546
add_subdirectory(toggle)
@@ -135,6 +146,12 @@ else () # BUILD_QT
135146
)
136147
endif () # BUILD_QT
137148

149+
if (BUILD_CUDA)
150+
set_target_properties(modmesh_primary PROPERTIES LINKER_LANGUAGE CUDA)
151+
target_compile_definitions(modmesh_primary PRIVATE BUILD_CUDA)
152+
target_link_libraries(modmesh_primary PRIVATE CUDA::cudart CUDA::cufft)
153+
endif () # BUILD_CUDA
154+
138155
set_target_properties(modmesh_primary PROPERTIES POSITION_INDEPENDENT_CODE ON)
139156

140157
if (CLANG_TIDY_EXE AND USE_CLANG_TIDY)
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
#pragma once
2+
3+
/*
4+
* Copyright (c) 2025, Alex Chiang <jyxemperor@gmail.com>
5+
*
6+
* Redistribution and use in source and binary forms, with or without
7+
* modification, are permitted provided that the following conditions are met:
8+
*
9+
* - Redistributions of source code must retain the above copyright notice,
10+
* this list of conditions and the following disclaimer.
11+
* - Redistributions in binary form must reproduce the above copyright notice,
12+
* this list of conditions and the following disclaimer in the documentation
13+
* and/or other materials provided with the distribution.
14+
* - Neither the name of the copyright holder nor the names of its contributors
15+
* may be used to endorse or promote products derived from this software
16+
* without specific prior written permission.
17+
*
18+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21+
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
22+
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23+
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24+
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25+
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26+
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27+
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
28+
* POSSIBILITY OF SUCH DAMAGE.
29+
*/
30+
31+
#include <cuda.h>
32+
#include <cuda_runtime.h>
33+
#include <cufft.h>
34+
#include <stdio.h>
35+
36+
#define CUDA_SAFE_CALL(err) __cudaSafeCall(err, __FILE__, __LINE__)
37+
#define CUFFT_SAFE_CALL(err) __cufftSafeCall(err, __FILE__, __LINE__)
38+
#define CUDA_GET_LAST_ERROR() __cudaCheckError(__FILE__, __LINE__)
39+
40+
void inline __cudaSafeCall(cudaError_t err, const char * file, const int line)
41+
{
42+
if (err != cudaSuccess)
43+
{
44+
printf("CUDA Error %d: %s.\n%s(%d)\n", (int)err, cudaGetErrorString(err), file, line);
45+
}
46+
}
47+
48+
void inline __cudaCheckError(const char * file, const int line)
49+
{
50+
cudaError_t err = cudaDeviceSynchronize();
51+
if (err != cudaSuccess)
52+
{
53+
printf("CUDA Error %d: %s.\n%s(%d)\n", (int)err, cudaGetErrorString(err), file, line);
54+
}
55+
56+
err = cudaGetLastError();
57+
if (err != cudaSuccess)
58+
{
59+
printf("CUDA Error %d: %s.\n%s(%d)\n", (int)err, cudaGetErrorString(err), file, line);
60+
}
61+
}
62+
63+
const inline char * __cufftResultToString(cufftResult err)
64+
{
65+
switch (err)
66+
{
67+
case CUFFT_SUCCESS: return "CUFFT_SUCCESS.";
68+
case CUFFT_INVALID_PLAN: return "CUFFT_INVALID_PLAN.";
69+
case CUFFT_ALLOC_FAILED: return "CUFFT_ALLOC_FAILED.";
70+
case CUFFT_INVALID_TYPE: return "CUFFT_INVALID_TYPE.";
71+
case CUFFT_INVALID_VALUE: return "CUFFT_INVALID_VALUE.";
72+
case CUFFT_INTERNAL_ERROR: return "CUFFT_INTERNAL_ERROR.";
73+
case CUFFT_EXEC_FAILED: return "CUFFT_EXEC_FAILED.";
74+
case CUFFT_SETUP_FAILED: return "CUFFT_SETUP_FAILED.";
75+
case CUFFT_INVALID_SIZE: return "CUFFT_INVALID_SIZE.";
76+
case CUFFT_UNALIGNED_DATA: return "CUFFT_UNALIGNED_DATA.";
77+
default: return "CUFFT Unknown error code.";
78+
}
79+
}
80+
81+
void inline __cufftSafeCall(cufftResult err, const char * file, const int line)
82+
{
83+
if (CUFFT_SUCCESS != err)
84+
{
85+
printf("CUFFT error %d: %s\n%s(%d)\n", (int)err, __cufftResultToString(err), file, line);
86+
}
87+
}
88+
89+
// vim: set ff=unix fenc=utf8 et sw=4 ts=4 sts=4:

cpp/modmesh/transform/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@ set(MODMESH_TRANSFORM_HEADERS
55
${CMAKE_CURRENT_SOURCE_DIR}/fourier.hpp
66
CACHE FILEPATH "" FORCE)
77

8+
if (BUILD_CUDA)
9+
list(APPEND MODMESH_TRANSFORM_HEADERS
10+
${CMAKE_CURRENT_SOURCE_DIR}/fourier.cuh)
11+
endif()
12+
813
set(MODMESH_TRANSFORM_SOURCES
914
${CMAKE_CURRENT_SOURCE_DIR}/fourier.cpp
1015
CACHE FILEPATH "" FORCE)

cpp/modmesh/transform/fourier.cuh

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
#pragma once
2+
3+
/*
4+
* Copyright (c) 2025, Alex Chiang <jyxemperor@gmail.com>
5+
*
6+
* Redistribution and use in source and binary forms, with or without
7+
* modification, are permitted provided that the following conditions are met:
8+
*
9+
* - Redistributions of source code must retain the above copyright notice,
10+
* this list of conditions and the following disclaimer.
11+
* - Redistributions in binary form must reproduce the above copyright notice,
12+
* this list of conditions and the following disclaimer in the documentation
13+
* and/or other materials provided with the distribution.
14+
* - Neither the name of the copyright holder nor the names of its contributors
15+
* may be used to endorse or promote products derived from this software
16+
* without specific prior written permission.
17+
*
18+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21+
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
22+
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23+
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24+
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25+
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26+
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27+
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
28+
* POSSIBILITY OF SUCH DAMAGE.
29+
*/
30+
31+
#include <modmesh/modmesh.hpp>
32+
#include <modmesh/buffer/buffer.hpp>
33+
#include <modmesh/device/cuda/cuda_error_handle.hpp>
34+
35+
#define FFT_CUDA_IMPL(CUFFT_DATA_TYPE, CUFFT_EXEC_TYPE) \
36+
{ \
37+
cufftHandle plan; \
38+
CUFFT_DATA_TYPE * host_in = nullptr; \
39+
CUFFT_DATA_TYPE * host_out = nullptr; \
40+
CUFFT_DATA_TYPE * device_in = nullptr; \
41+
CUFFT_DATA_TYPE * device_out = nullptr; \
42+
host_in = (CUFFT_DATA_TYPE*)malloc(sizeof(CUFFT_DATA_TYPE) * N); \
43+
host_out = (CUFFT_DATA_TYPE*)malloc(sizeof(CUFFT_DATA_TYPE) * N); \
44+
for (size_t i = 0; i < N; ++i) \
45+
{ \
46+
host_in[i].x = in[i].real(); \
47+
host_in[i].y = in[i].imag(); \
48+
} \
49+
CUDA_SAFE_CALL(cudaMalloc((void**)&device_in, sizeof(CUFFT_DATA_TYPE) * N)); \
50+
CUDA_SAFE_CALL(cudaMalloc((void**)&device_out, sizeof(CUFFT_DATA_TYPE) * N)); \
51+
CUDA_SAFE_CALL(cudaMemcpy(device_in, host_in, sizeof(CUFFT_DATA_TYPE) * N, cudaMemcpyHostToDevice)); \
52+
CUFFT_SAFE_CALL(cufftPlan1d(&plan, N, CUFFT_##CUFFT_EXEC_TYPE, 1)); \
53+
CUFFT_SAFE_CALL(cufftExec##CUFFT_EXEC_TYPE(plan, device_in, device_out, CUFFT_FORWARD)); \
54+
CUDA_SAFE_CALL(cudaMemcpy(host_out, device_out, sizeof(CUFFT_DATA_TYPE) * N, cudaMemcpyDeviceToHost)); \
55+
for (size_t i = 0; i < N; ++i) \
56+
{ \
57+
out[i] = T1<T2>{ host_out[i].x, host_out[i].y }; \
58+
} \
59+
CUFFT_SAFE_CALL(cufftDestroy(plan)); \
60+
CUDA_SAFE_CALL(cudaFree(device_in)); \
61+
CUDA_SAFE_CALL(cudaFree(device_out)); \
62+
free(host_in); \
63+
free(host_out); \
64+
}
65+
66+
namespace modmesh
67+
{
68+
69+
template <template <typename> class T1, typename T2>
70+
void fft_cuda(SimpleArray<T1<T2>> const & in, SimpleArray<T1<T2>> & out)
71+
{
72+
size_t N = in.size();
73+
if constexpr (std::is_same_v<T2, float>)
74+
{
75+
FFT_CUDA_IMPL(cufftComplex, C2C)
76+
}
77+
else if constexpr (std::is_same_v<T2, double>)
78+
{
79+
FFT_CUDA_IMPL(cufftDoubleComplex, Z2Z)
80+
}
81+
}
82+
83+
} /* namespace modmesh */
84+
85+
// vim: set ff=unix fenc=utf8 et sw=4 ts=4 sts=4:

cpp/modmesh/transform/fourier.hpp

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
#include <modmesh/math/math.hpp>
44
#include <modmesh/buffer/buffer.hpp>
55

6+
#if defined(BUILD_CUDA)
7+
#include <modmesh/transform/fourier.cuh>
8+
#endif
9+
610
namespace modmesh
711
{
812

@@ -63,22 +67,37 @@ class FourierTransform
6367
FourierTransform & operator=(FourierTransform && other) = delete;
6468

6569
template <template <typename> class T1, typename T2>
66-
static void fft(SimpleArray<T1<T2>> const & in, SimpleArray<T1<T2>> & out)
70+
static void fft(SimpleArray<T1<T2>> const & in, SimpleArray<T1<T2>> & out, std::string const && backend)
6771
{
6872
const size_t N = in.size();
6973

70-
if ((N & (N - 1)) == 0)
74+
if (backend == "cpu")
75+
{
76+
if ((N & (N - 1)) == 0)
77+
{
78+
detail::fft_radix_2<T1, T2>(in, out);
79+
}
80+
else
81+
{
82+
detail::fft_bluestein<T1, T2>(in, out);
83+
}
84+
}
85+
else if (backend == "cuda")
7186
{
72-
detail::fft_radix_2<T1, T2>(in, out);
87+
#if defined(BUILD_CUDA)
88+
modmesh::fft_cuda<T1, T2>(in, out);
89+
#else
90+
throw std::runtime_error("CUDA is not available.");
91+
#endif
7392
}
7493
else
7594
{
76-
detail::fft_bluestein<T1, T2>(in, out);
95+
throw std::runtime_error("unsupported backend.");
7796
}
7897
}
7998

8099
template <template <typename> class T1, typename T2>
81-
static void ifft(SimpleArray<T1<T2>> const & in, SimpleArray<T1<T2>> & out)
100+
static void ifft(SimpleArray<T1<T2>> const & in, SimpleArray<T1<T2>> & out, std::string const && backend)
82101
{
83102
size_t N = in.size();
84103
SimpleArray<T1<T2>> in_conj{modmesh::small_vector<size_t>{N}, T1<T2>{0.0, 0.0}};
@@ -88,7 +107,7 @@ class FourierTransform
88107
in_conj[i] = in[i].conj();
89108
}
90109

91-
fft<T1, T2>(in_conj, out);
110+
fft<T1, T2>(in_conj, out, std::move(backend));
92111

93112
for (size_t i = 0; i < N; ++i)
94113
{
@@ -152,7 +171,7 @@ void fft_bluestein(SimpleArray<T1<T2>> const & in, SimpleArray<T1<T2>> & out)
152171
A[i] *= B[i];
153172
}
154173

155-
FourierTransform::ifft<T1, T2>(A, a);
174+
FourierTransform::ifft<T1, T2>(A, a, "cpu");
156175

157176
for (size_t i = 0; i < N; ++i)
158177
{

cpp/modmesh/transform/pymod/wrap_fourier.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,10 @@ class MODMESH_PYTHON_WRAPPER_VISIBILITY WrapFourierTransform
5353
namespace py = pybind11; // NOLINT(misc-unused-alias-decls)
5454

5555
(*this)
56-
.def_static("fft", &wrapped_type::fft<modmesh::Complex, double>, py::arg("input"), py::arg("output"))
57-
.def_static("fft", &wrapped_type::fft<modmesh::Complex, float>, py::arg("input"), py::arg("output"))
58-
.def_static("ifft", &wrapped_type::ifft<modmesh::Complex, double>, py::arg("input"), py::arg("output"))
59-
.def_static("ifft", &wrapped_type::ifft<modmesh::Complex, float>, py::arg("input"), py::arg("output"))
56+
.def_static("fft", &wrapped_type::fft<modmesh::Complex, double>, py::arg("input"), py::arg("output"), py::arg("backend"))
57+
.def_static("fft", &wrapped_type::fft<modmesh::Complex, float>, py::arg("input"), py::arg("output"), py::arg("backend"))
58+
.def_static("ifft", &wrapped_type::ifft<modmesh::Complex, double>, py::arg("input"), py::arg("output"), py::arg("backend"))
59+
.def_static("ifft", &wrapped_type::ifft<modmesh::Complex, float>, py::arg("input"), py::arg("output"), py::arg("backend"))
6060
.def_static("dft", &wrapped_type::dft<modmesh::Complex, double>, py::arg("input"), py::arg("output"))
6161
.def_static("dft", &wrapped_type::dft<modmesh::Complex, float>, py::arg("input"), py::arg("output"));
6262
}

gtests/CMakeLists.txt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,18 @@ target_link_libraries(
3939
GTest::gmock_main
4040
)
4141

42+
if (BUILD_CUDA)
43+
find_package(CUDA REQUIRED)
44+
find_package(CUDAToolkit REQUIRED)
45+
enable_language(CUDA)
46+
target_link_libraries(
47+
test_nopython
48+
CUDA::cudart
49+
CUDA::cufft
50+
)
51+
target_compile_definitions(test_nopython PRIVATE BUILD_CUDA)
52+
endif()
53+
4254
include(GoogleTest)
4355
gtest_discover_tests(test_nopython)
4456

0 commit comments

Comments
 (0)