Skip to content

Commit 45b9df6

Browse files
authored
Merge branch 'main' into scanops
2 parents 2769c85 + b419422 commit 45b9df6

10 files changed

Lines changed: 592 additions & 107 deletions

File tree

.github/workflows/CI.yml

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Continuous integration: build and test the package across a matrix of
# Julia versions, operating systems, and CPU architectures.
name: CI

# Run on every push to main and on every pull request.
on:
  push:
    branches:
      - main
  pull_request:

jobs:
  test:
    name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }}
    # The concrete runner label is supplied by the matrix `include` entries below.
    runs-on: ${{ matrix.runner }}
    timeout-minutes: 15
    strategy:
      # Keep the remaining matrix jobs running even if one configuration fails.
      fail-fast: false
      matrix:
        version:
          - '1.11'
          - '1.12'
        os:
          - Linux
          - Windows
          - macOS
        arch:
          - x64
          - aarch64
        # Drop (os, arch) pairs with no hosted runner available.
        exclude:
          - os: Windows
            arch: aarch64
          - os: macOS
            arch: x64
        # Map each remaining (os, arch) pair to a concrete GitHub-hosted runner.
        include:
          - os: Linux
            arch: x64
            runner: ubuntu-latest
          - os: Linux
            arch: aarch64
            runner: ubuntu-24.04-arm
          - os: Windows
            arch: x64
            runner: windows-latest
          - os: macOS
            arch: aarch64
            runner: macos-latest
    steps:
      - uses: actions/checkout@v4
      - uses: julia-actions/setup-julia@v2
        with:
          version: ${{ matrix.version }}
          arch: ${{ matrix.arch }}
      # Cache Julia depot artifacts between runs to speed up CI.
      - uses: julia-actions/cache@v2
      - uses: julia-actions/julia-buildpkg@v1
      - uses: julia-actions/julia-runtest@v1
        with:
          # Abort the test suite on the first failure.
          test_args: '--quickfail'

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ IRStructurizer = {path = "IRStructurizer"}
1818
CUDAExt = "CUDA"
1919

2020
[compat]
21+
julia = "1.11"
2122
CUDA_Compiler_jll = "0.4"
2223
CUDA_Tile_jll = "13.1"
2324

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ Benchmarks comparing cuTile.jl against cuTile Python on an RTX 5080:
101101
| Matrix Multiplication | 48.3 TFLOPS | 48.6 TFLOPS | OK (=) |
102102
| Layer Normalization | 254 GB/s | 683 GB/s | https://github.com/JuliaGPU/cuTile.jl/issues/1 (-63%) |
103103
| Batch Matrix Multiply | 31.7 TFLOPS | 31.6 TFLOPS | OK (=) |
104+
| FFT (3-stage Cooley-Tukey) | 508 μs | 230 μs | (-55%) |
104105

105106
Compute-intensive kernels (matmul, batch matmul) perform identically to Python. Memory-bound
106107
kernels (vadd, transpose) are within ~3% of Python. The layernorm kernel is slower due to

examples/benchmarks.jl

Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using CUDA
88
using LinearAlgebra
99
using CUDA: GPUArrays
10+
using FFTW
1011
import cuTile as ct
1112

1213
#=============================================================================
@@ -375,6 +376,14 @@ const BATCHMATMUL_TM = 128
375376
const BATCHMATMUL_TN = 256
376377
const BATCHMATMUL_TK = 64
377378

379+
# FFT benchmark configuration.
# Tile size is (D, BS, N2D), limited by the tileiras compiler.
# Current kernel loads all batches per block, limiting scalability.
const FFT_BATCH = 64              # BS: number of independent FFTs per launch
const FFT_SIZE = 512              # N: FFT length; equals prod(FFT_FACTORS) = 8*8*8
const FFT_FACTORS = (8, 8, 8)     # (F0, F1, F2): Cooley-Tukey radix factors
const FFT_ATOM_PACKING_DIM = 2    # D: real/imag components packed along dim 1
378387
# SIMT naive kernel (2-pass: compute mean/var, then normalize)
379388
function layernorm_simt_kernel!(X, W, B, Y, Mean, Rstd, N, eps)
380389
m = blockIdx().x
@@ -603,6 +612,228 @@ function benchmark_batchmatmul()
603612
return results
604613
end
605614

615+
#=============================================================================
616+
FFT (3-stage Cooley-Tukey) - Column-Major Version
617+
=============================================================================#
618+
619+
# FFT kernel - 3-stage Cooley-Tukey decomposition (column-major)
# Uses swapped dimensions and right-multiply for column-major compatibility.
# Input/output layout: (D, BS, N2D) where D=2 for real/imag interleaving.
#
# Arguments:
#   x_packed_in, y_packed_out — packed input/output arrays, layout (D, BS, N2D)
#   W0, W1, W2 — per-stage DFT matrices, layout (F, F, 2): [:, :, 1] real, [:, :, 2] imag
#   T0, T1     — inter-stage twiddle tables, layout (rows, cols, 2), same split
#   *_const    — compile-time constants: N and its factor products, BS, D, N2D
#
# NOTE(review): each block loads the full (D, BS, N2D) tile at batch offset
# `bid` and stores it back at the same offset; with grid = (BS, 1, 1) in the
# benchmark this looks like every block processes all batches — confirm
# against ct.load/ct.store offset semantics.
function fft_kernel(
    x_packed_in::ct.TileArray{Float32, 3},
    y_packed_out::ct.TileArray{Float32, 3},
    W0::ct.TileArray{Float32, 3},
    W1::ct.TileArray{Float32, 3},
    W2::ct.TileArray{Float32, 3},
    T0::ct.TileArray{Float32, 3},
    T1::ct.TileArray{Float32, 3},
    n_const::ct.Constant{Int},
    f0_const::ct.Constant{Int},
    f1_const::ct.Constant{Int},
    f2_const::ct.Constant{Int},
    f0f1_const::ct.Constant{Int},
    f1f2_const::ct.Constant{Int},
    f0f2_const::ct.Constant{Int},
    bs_const::ct.Constant{Int},
    d_const::ct.Constant{Int},
    n2d_const::ct.Constant{Int}
)
    # Unpack the compile-time constants into plain integers.
    N = n_const[]
    F0 = f0_const[]
    F1 = f1_const[]
    F2 = f2_const[]
    F0F1 = f0f1_const[]
    F1F2 = f1f2_const[]
    F0F2 = f0f2_const[]
    BS = bs_const[]
    D = d_const[]
    N2D = n2d_const[]

    # Block index along grid dimension 1 (ranges over BS in the benchmark launch).
    bid = ct.bid(1)

    # Load input (D, BS, N2D) and reshape to (2, BS, N)
    X_ri = ct.reshape(ct.load(x_packed_in, (1, bid, 1), (D, BS, N2D)), (2, BS, N))
    # Split into real/imag planes and view each as (BS, F1F2, F0) for stage 0.
    X_r = ct.reshape(ct.extract(X_ri, (1, 1, 1), (1, BS, N)), (BS, F1F2, F0))
    X_i = ct.reshape(ct.extract(X_ri, (2, 1, 1), (1, BS, N)), (BS, F1F2, F0))

    # Load DFT matrices; each is stored as (F, F, 2) with real/imag in the
    # last dimension, then broadcast along the batch dimension to (BS, F, F).
    W0_ri = ct.reshape(ct.load(W0, (1, 1, 1), (F0, F0, 2)), (F0, F0, 2))
    W0_r = ct.broadcast_to(ct.reshape(ct.extract(W0_ri, (1, 1, 1), (F0, F0, 1)), (1, F0, F0)), (BS, F0, F0))
    W0_i = ct.broadcast_to(ct.reshape(ct.extract(W0_ri, (1, 1, 2), (F0, F0, 1)), (1, F0, F0)), (BS, F0, F0))

    W1_ri = ct.reshape(ct.load(W1, (1, 1, 1), (F1, F1, 2)), (F1, F1, 2))
    W1_r = ct.broadcast_to(ct.reshape(ct.extract(W1_ri, (1, 1, 1), (F1, F1, 1)), (1, F1, F1)), (BS, F1, F1))
    W1_i = ct.broadcast_to(ct.reshape(ct.extract(W1_ri, (1, 1, 2), (F1, F1, 1)), (1, F1, F1)), (BS, F1, F1))

    W2_ri = ct.reshape(ct.load(W2, (1, 1, 1), (F2, F2, 2)), (F2, F2, 2))
    W2_r = ct.broadcast_to(ct.reshape(ct.extract(W2_ri, (1, 1, 1), (F2, F2, 1)), (1, F2, F2)), (BS, F2, F2))
    W2_i = ct.broadcast_to(ct.reshape(ct.extract(W2_ri, (1, 1, 2), (F2, F2, 1)), (1, F2, F2)), (BS, F2, F2))

    # Load twiddle factors (column-major layout); flattened to row vectors
    # (1, N) so they broadcast across the batch dimension when applied.
    T0_ri = ct.reshape(ct.load(T0, (1, 1, 1), (F1F2, F0, 2)), (F1F2, F0, 2))
    T0_r = ct.reshape(ct.extract(T0_ri, (1, 1, 1), (F1F2, F0, 1)), (1, N))
    T0_i = ct.reshape(ct.extract(T0_ri, (1, 1, 2), (F1F2, F0, 1)), (1, N))

    T1_ri = ct.reshape(ct.load(T1, (1, 1, 1), (F0F2, F1, 2)), (F0F2, F1, 2))
    T1_r = ct.reshape(ct.extract(T1_ri, (1, 1, 1), (F0F2, F1, 1)), (1, F0F2 * F1))
    T1_i = ct.reshape(ct.extract(T1_ri, (1, 1, 2), (F0F2, F1, 1)), (1, F0F2 * F1))

    # Stage 0: F0-point DFT via right-multiply.
    # Complex product split into real/imag parts: (r + i·im)(Wr + Wi·im).
    X_r_ = X_r * W0_r - X_i * W0_i
    X_i_ = X_r * W0_i + X_i * W0_r

    # Twiddle & Permute 0: apply T0 element-wise on the flattened (BS, N)
    # view, then swap the F1/F0 axes so stage 1 multiplies along F1.
    X_r_flat = ct.reshape(X_r_, (BS, N))
    X_i_flat = ct.reshape(X_i_, (BS, N))
    X_r2 = T0_r .* X_r_flat .- T0_i .* X_i_flat
    X_i2 = T0_i .* X_r_flat .+ T0_r .* X_i_flat

    # (BS, F2, F1, F0) -> permute -> (BS, F2, F0, F1) -> (BS, F0F2, F1)
    X_r3 = ct.reshape(X_r2, (BS, F2, F1, F0))
    X_i3 = ct.reshape(X_i2, (BS, F2, F1, F0))
    X_r4 = ct.permute(X_r3, (1, 2, 4, 3))
    X_i4 = ct.permute(X_i3, (1, 2, 4, 3))
    X_r5 = ct.reshape(X_r4, (BS, F0F2, F1))
    X_i5 = ct.reshape(X_i4, (BS, F0F2, F1))

    # Stage 1: F1-point DFT
    X_r6 = X_r5 * W1_r - X_i5 * W1_i
    X_i6 = X_r5 * W1_i + X_i5 * W1_r

    # Twiddle & Permute 1: apply T1, then reorder so stage 2 multiplies along F2.
    X_r_flat2 = ct.reshape(X_r6, (BS, N))
    X_i_flat2 = ct.reshape(X_i6, (BS, N))
    X_r7 = T1_r .* X_r_flat2 .- T1_i .* X_i_flat2
    X_i7 = T1_i .* X_r_flat2 .+ T1_r .* X_i_flat2

    # (BS, F2, F0, F1) -> permute -> (BS, F0, F1, F2) -> (BS, F0F1, F2)
    X_r8 = ct.reshape(X_r7, (BS, F2, F0, F1))
    X_i8 = ct.reshape(X_i7, (BS, F2, F0, F1))
    X_r9 = ct.permute(X_r8, (1, 3, 4, 2))
    X_i9 = ct.permute(X_i8, (1, 3, 4, 2))
    X_r10 = ct.reshape(X_r9, (BS, F0F1, F2))
    X_i10 = ct.reshape(X_i9, (BS, F0F1, F2))

    # Stage 2: F2-point DFT
    X_r11 = X_r10 * W2_r - X_i10 * W2_i
    X_i11 = X_r10 * W2_i + X_i10 * W2_r

    # Final output: bring real and imag back to (1, BS, N) planes.
    X_r_final = ct.reshape(X_r11, (1, BS, N))
    X_i_final = ct.reshape(X_i11, (1, BS, N))

    # Concatenate along dim 1 to re-interleave real/imag, restore the packed
    # (D, BS, N2D) layout, and store at this block's batch offset.
    Y_ri = ct.reshape(ct.cat((X_r_final, X_i_final), 1), (D, BS, N2D))
    ct.store(y_packed_out, (1, bid, 1), Y_ri)

    return
end
729+
730+
"""
    fft_dft_matrix(n::Int)

Return the `n`-point DFT matrix as an `n × n × 2` `Float32` array, with the
real parts in `[:, :, 1]` and the imaginary parts in `[:, :, 2]`.
Entry `(i+1, j+1)` holds `exp(-2πi·i·j/n)` for 0-based `i`, `j`.
"""
function fft_dft_matrix(n::Int)
    # Parameter renamed from `size` to avoid shadowing `Base.size`.
    # Fill the Float32 output directly instead of building a throwaway
    # ComplexF32 matrix and converting it (one fewer allocation; the
    # per-component Float64 -> Float32 rounding is identical).
    result = zeros(Float32, n, n, 2)
    # Column-major friendly order: inner loop runs over the first index.
    for j in 0:n-1, i in 0:n-1
        w = exp(-2π * im * i * j / n)
        result[i+1, j+1, 1] = real(w)
        result[i+1, j+1, 2] = imag(w)
    end
    return result
end
741+
742+
# Twiddle factors T0 for column-major layout (F1F2, F0).
# Entry (j+1, i+1) holds exp(-2πi·i·j/N), split into real ([:, :, 1]) and
# imaginary ([:, :, 2]) planes as Float32.
function fft_make_twiddles_T0(F0::Int, F1F2::Int, N::Int)
    twiddles = zeros(Float32, F1F2, F0, 2)
    for col in 1:F0
        for row in 1:F1F2
            phase = exp(-2π * im * (col - 1) * (row - 1) / N)
            twiddles[row, col, 1] = Float32(real(phase))
            twiddles[row, col, 2] = Float32(imag(phase))
        end
    end
    return twiddles
end
752+
753+
# Twiddle factors T1 for column-major layout (F0F2, F1).
# Only the F2-component of the row index (row - 1 mod F2) enters the phase;
# entry (k+1, j+1) holds exp(-2πi·j·(k mod F2)/(F1·F2)), split into real and
# imaginary Float32 planes.
function fft_make_twiddles_T1(F0::Int, F1::Int, F2::Int)
    nrows = F0 * F2
    denom = F1 * F2
    out = zeros(Float32, nrows, F1, 2)
    for row in 1:nrows, col in 1:F1
        rem2 = (row - 1) % F2
        w = exp(-2π * im * (col - 1) * rem2 / denom)
        out[row, col, 1] = Float32(real(w))
        out[row, col, 2] = Float32(imag(w))
    end
    return out
end
766+
767+
# Build all stage tables for a 3-factor Cooley-Tukey FFT of length
# prod(factors): one DFT matrix per radix plus the two inter-stage twiddle
# tables. Returns (W0, W1, W2, T0, T1).
function fft_make_twiddles(factors::NTuple{3, Int})
    f0, f1, f2 = factors
    total = f0 * f1 * f2
    dfts = map(fft_dft_matrix, (f0, f1, f2))
    t0 = fft_make_twiddles_T0(f0, f1 * f2, total)
    t1 = fft_make_twiddles_T1(f0, f1, f2)
    return (dfts[1], dfts[2], dfts[3], t0, t1)
end
778+
779+
# Benchmark the cuTile FFT kernel against an FFTW reference result.
# Verifies correctness first (rtol 1e-3), then times the kernel alone and
# reports GFLOPS using the conventional 5·N·log2(N) FFT flop count.
# Returns the vector of BenchmarkResult records.
function benchmark_fft()
    println("\nBenchmarking FFT...")
    BS, N = FFT_BATCH, FFT_SIZE
    F0, F1, F2 = FFT_FACTORS
    D = FFT_ATOM_PACKING_DIM
    # 8 bytes per ComplexF32 element.
    println(" Size: $BS batches × $N FFT ($(BS * N * 8 / 1e6) MB)")

    # Create complex input (seeded for reproducibility).
    CUDA.seed!(42)
    input = CUDA.randn(ComplexF32, BS, N)

    # Reference result (FFTW), computed on the CPU along the FFT dimension (2).
    reference = FFTW.fft(Array(input), 2)

    results = BenchmarkResult[]

    # Pre-compute twiddles (one-time CPU cost), then upload to the GPU.
    W0, W1, W2, T0, T1 = fft_make_twiddles(FFT_FACTORS)
    W0_gpu, W1_gpu, W2_gpu = CuArray(W0), CuArray(W1), CuArray(W2)
    T0_gpu, T1_gpu = CuArray(T0), CuArray(T1)

    # Pre-pack input (zero-copy): reinterpret ComplexF32 (BS, N) as
    # Float32 with a leading real/imag dimension, matching (D, BS, N2D).
    N2D = N * 2 ÷ D
    x_packed = reinterpret(reshape, Float32, input)
    y_packed = CUDA.zeros(Float32, D, BS, N2D)

    # Kernel launch parameters.
    # NOTE(review): grid = (BS, 1, 1) while the kernel loads all BS batches
    # per block — confirm whether one block would suffice.
    F0F1, F1F2, F0F2 = F0 * F1, F1 * F2, F0 * F2
    grid = (BS, 1, 1)

    # Kernel-only timing function (closure captures all launch arguments).
    cutile_kernel_f = () -> ct.launch(fft_kernel, grid,
        x_packed, y_packed,
        W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu,
        ct.Constant(N), ct.Constant(F0), ct.Constant(F1), ct.Constant(F2),
        ct.Constant(F0F1), ct.Constant(F1F2), ct.Constant(F0F2),
        ct.Constant(BS), ct.Constant(D), ct.Constant(N2D))

    # Verify correctness against the FFTW reference before timing.
    cutile_kernel_f()
    CUDA.synchronize()
    # Reinterpret the packed Float32 output back to ComplexF32.
    y_complex = reinterpret(reshape, ComplexF32, y_packed)
    output = copy(y_complex)
    @assert isapprox(Array(output), reference, rtol=1e-3) "cuTile FFT incorrect!"

    # Benchmark kernel only (no host-side packing or transfers in the loop).
    min_t, mean_t = benchmark_kernel(cutile_kernel_f)
    push!(results, BenchmarkResult("cuTile.jl", min_t, mean_t))

    # Performance metric: GFLOPS (5 * N * log2(N) per complex FFT)
    flops_per_fft = 5.0 * N * log2(N)
    total_flops = BS * flops_per_fft
    # r.min_ms is in milliseconds; convert to seconds before dividing.
    gflops = [string(round(total_flops / (r.min_ms * 1e-3) / 1e9, digits=1), " GFLOPS") for r in results]

    print_table("FFT (ComplexF32)", results; extra_col=("Performance", gflops))
    return results
end
836+
606837
#=============================================================================
607838
Main
608839
=============================================================================#
@@ -622,6 +853,7 @@ function main()
622853
matmul_results = benchmark_matmul()
623854
layernorm_results = benchmark_layernorm()
624855
batchmatmul_results = benchmark_batchmatmul()
856+
fft_results = benchmark_fft()
625857

626858
println()
627859
println("=" ^ 60)

0 commit comments

Comments
 (0)