77using CUDA
88using LinearAlgebra
99using CUDA: GPUArrays
10+ using FFTW
1011import cuTile as ct
1112
1213#= ============================================================================
@@ -375,6 +376,14 @@ const BATCHMATMUL_TM = 128
375376const BATCHMATMUL_TN = 256
376377const BATCHMATMUL_TK = 64
377378
# FFT sizes
# Tile size is (D, BS, N2D), limited by tileiras compiler.
# Current kernel loads all batches per block, limiting scalability.
const FFT_BATCH = 64
const FFT_FACTORS = (8, 8, 8)
# Transform length is derived from the three stage factors so the two can never
# drift apart (8 * 8 * 8 == 512).
const FFT_SIZE = prod(FFT_FACTORS)
# Size of the leading (real/imag) axis of the packed Float32 layout.
const FFT_ATOM_PACKING_DIM = 2
386+
378387# SIMT naive kernel (2-pass: compute mean/var, then normalize)
379388function layernorm_simt_kernel! (X, W, B, Y, Mean, Rstd, N, eps)
380389 m = blockIdx (). x
@@ -603,6 +612,228 @@ function benchmark_batchmatmul()
603612 return results
604613end
605614
#=============================================================================
 FFT (3-stage Cooley-Tukey) - Column-Major Version
=============================================================================#
618+
# FFT kernel - 3-stage Cooley-Tukey decomposition (column-major)
# Uses swapped dimensions and right-multiply for column-major compatibility.
# Input/output layout: (D, BS, N2D) where D=2 for real/imag interleaving.
#
# Arguments:
#   x_packed_in, y_packed_out : packed Float32 input / output arrays
#   W0, W1, W2                : DFT matrices, loaded as (F, F, 2) tiles with
#                               real in slice 1 and imaginary in slice 2
#   T0, T1                    : twiddle tables, loaded as (F1F2, F0, 2) and
#                               (F0F2, F1, 2) tiles
#   n_const .. n2d_const      : compile-time sizes (N, F0, F1, F2, the pairwise
#                               products F0F1 / F1F2 / F0F2, BS, D, N2D)
#
# Complex arithmetic is carried out on separate real / imaginary tiles:
#   (a + ib)(c + id) = (ac - bd) + i(ad + bc)
function fft_kernel(
    x_packed_in::ct.TileArray{Float32, 3},
    y_packed_out::ct.TileArray{Float32, 3},
    W0::ct.TileArray{Float32, 3},
    W1::ct.TileArray{Float32, 3},
    W2::ct.TileArray{Float32, 3},
    T0::ct.TileArray{Float32, 3},
    T1::ct.TileArray{Float32, 3},
    n_const::ct.Constant{Int},
    f0_const::ct.Constant{Int},
    f1_const::ct.Constant{Int},
    f2_const::ct.Constant{Int},
    f0f1_const::ct.Constant{Int},
    f1f2_const::ct.Constant{Int},
    f0f2_const::ct.Constant{Int},
    bs_const::ct.Constant{Int},
    d_const::ct.Constant{Int},
    n2d_const::ct.Constant{Int}
)
    # Unwrap the compile-time constants.
    N = n_const[]
    F0 = f0_const[]
    F1 = f1_const[]
    F2 = f2_const[]
    F0F1 = f0f1_const[]
    F1F2 = f1f2_const[]
    F0F2 = f0f2_const[]
    BS = bs_const[]
    D = d_const[]
    N2D = n2d_const[]

    bid = ct.bid(1)

    # Load input (D, BS, N2D) and reshape to (2, BS, N)
    X_ri = ct.reshape(ct.load(x_packed_in, (1, bid, 1), (D, BS, N2D)), (2, BS, N))
    # Slice 1 of the leading axis is the real part, slice 2 the imaginary part;
    # each is viewed as (BS, F1F2, F0) for the first stage.
    X_r = ct.reshape(ct.extract(X_ri, (1, 1, 1), (1, BS, N)), (BS, F1F2, F0))
    X_i = ct.reshape(ct.extract(X_ri, (2, 1, 1), (1, BS, N)), (BS, F1F2, F0))

    # Load DFT matrices; each is broadcast along a leading BS axis so it has the
    # same leading extent as the data tiles it multiplies.
    W0_ri = ct.reshape(ct.load(W0, (1, 1, 1), (F0, F0, 2)), (F0, F0, 2))
    W0_r = ct.broadcast_to(ct.reshape(ct.extract(W0_ri, (1, 1, 1), (F0, F0, 1)), (1, F0, F0)), (BS, F0, F0))
    W0_i = ct.broadcast_to(ct.reshape(ct.extract(W0_ri, (1, 1, 2), (F0, F0, 1)), (1, F0, F0)), (BS, F0, F0))

    W1_ri = ct.reshape(ct.load(W1, (1, 1, 1), (F1, F1, 2)), (F1, F1, 2))
    W1_r = ct.broadcast_to(ct.reshape(ct.extract(W1_ri, (1, 1, 1), (F1, F1, 1)), (1, F1, F1)), (BS, F1, F1))
    W1_i = ct.broadcast_to(ct.reshape(ct.extract(W1_ri, (1, 1, 2), (F1, F1, 1)), (1, F1, F1)), (BS, F1, F1))

    W2_ri = ct.reshape(ct.load(W2, (1, 1, 1), (F2, F2, 2)), (F2, F2, 2))
    W2_r = ct.broadcast_to(ct.reshape(ct.extract(W2_ri, (1, 1, 1), (F2, F2, 1)), (1, F2, F2)), (BS, F2, F2))
    W2_i = ct.broadcast_to(ct.reshape(ct.extract(W2_ri, (1, 1, 2), (F2, F2, 1)), (1, F2, F2)), (BS, F2, F2))

    # Load twiddle factors (column-major layout), flattened to a single row so
    # they broadcast against the (BS, N) data below.
    T0_ri = ct.reshape(ct.load(T0, (1, 1, 1), (F1F2, F0, 2)), (F1F2, F0, 2))
    T0_r = ct.reshape(ct.extract(T0_ri, (1, 1, 1), (F1F2, F0, 1)), (1, N))
    T0_i = ct.reshape(ct.extract(T0_ri, (1, 1, 2), (F1F2, F0, 1)), (1, N))

    T1_ri = ct.reshape(ct.load(T1, (1, 1, 1), (F0F2, F1, 2)), (F0F2, F1, 2))
    T1_r = ct.reshape(ct.extract(T1_ri, (1, 1, 1), (F0F2, F1, 1)), (1, F0F2 * F1))
    T1_i = ct.reshape(ct.extract(T1_ri, (1, 1, 2), (F0F2, F1, 1)), (1, F0F2 * F1))

    # Stage 0: F0-point DFT via right-multiply
    # NOTE(review): `*` on 3-D tiles is presumably a matmul batched over the
    # leading BS axis - confirm against the cuTile documentation.
    X_r_ = X_r * W0_r - X_i * W0_i
    X_i_ = X_r * W0_i + X_i * W0_r

    # Twiddle & Permute 0
    X_r_flat = ct.reshape(X_r_, (BS, N))
    X_i_flat = ct.reshape(X_i_, (BS, N))
    X_r2 = T0_r .* X_r_flat .- T0_i .* X_i_flat
    X_i2 = T0_i .* X_r_flat .+ T0_r .* X_i_flat

    # Swap the last two axes so F1 becomes the trailing (contracted) axis.
    X_r3 = ct.reshape(X_r2, (BS, F2, F1, F0))
    X_i3 = ct.reshape(X_i2, (BS, F2, F1, F0))
    X_r4 = ct.permute(X_r3, (1, 2, 4, 3))
    X_i4 = ct.permute(X_i3, (1, 2, 4, 3))
    X_r5 = ct.reshape(X_r4, (BS, F0F2, F1))
    X_i5 = ct.reshape(X_i4, (BS, F0F2, F1))

    # Stage 1: F1-point DFT
    X_r6 = X_r5 * W1_r - X_i5 * W1_i
    X_i6 = X_r5 * W1_i + X_i5 * W1_r

    # Twiddle & Permute 1
    X_r_flat2 = ct.reshape(X_r6, (BS, N))
    X_i_flat2 = ct.reshape(X_i6, (BS, N))
    X_r7 = T1_r .* X_r_flat2 .- T1_i .* X_i_flat2
    X_i7 = T1_i .* X_r_flat2 .+ T1_r .* X_i_flat2

    # Rotate the axes so F2 becomes the trailing (contracted) axis.
    X_r8 = ct.reshape(X_r7, (BS, F2, F0, F1))
    X_i8 = ct.reshape(X_i7, (BS, F2, F0, F1))
    X_r9 = ct.permute(X_r8, (1, 3, 4, 2))
    X_i9 = ct.permute(X_i8, (1, 3, 4, 2))
    X_r10 = ct.reshape(X_r9, (BS, F0F1, F2))
    X_i10 = ct.reshape(X_i9, (BS, F0F1, F2))

    # Stage 2: F2-point DFT
    X_r11 = X_r10 * W2_r - X_i10 * W2_i
    X_i11 = X_r10 * W2_i + X_i10 * W2_r

    # Final output
    X_r_final = ct.reshape(X_r11, (1, BS, N))
    X_i_final = ct.reshape(X_i11, (1, BS, N))

    # Concatenate real/imag along axis 1, restore the packed shape, and Store
    Y_ri = ct.reshape(ct.cat((X_r_final, X_i_final), 1), (D, BS, N2D))
    ct.store(y_packed_out, (1, bid, 1), Y_ri)

    return
end
729+
# Helper: Generate DFT matrix
#
# Returns a `size × size × 2` Float32 array whose slice `[:, :, 1]` holds the
# real part and `[:, :, 2]` the imaginary part of W[i, j] = exp(-2πi * i * j / size)
# for zero-based i, j.
#
# The exponent is reduced modulo `size` and evaluated with `cispi`, so entries
# that fall on the real/imaginary axes are exactly 0 or ±1 rather than carrying
# ~1e-16 rounding residue from multiplying by a floating-point π.
#
# Note: the parameter name `size` shadows `Base.size` inside this function.
function fft_dft_matrix(size::Int)
    result = zeros(Float32, size, size, 2)
    # Column index in the outer loop: Julia arrays are column-major.
    for j in 0:size-1, i in 0:size-1
        w = cispi(-2 * mod(i * j, size) / size)
        result[i + 1, j + 1, 1] = Float32(real(w))
        result[i + 1, j + 1, 2] = Float32(imag(w))
    end
    return result
end
741+
# Twiddle factors T0 for column-major layout (F1F2, F0)
#
# Returns an `F1F2 × F0 × 2` Float32 array with
#   T0[j+1, i+1, 1] = real(exp(-2πi * i * j / N))
#   T0[j+1, i+1, 2] = imag(exp(-2πi * i * j / N))
# for zero-based i in 0:F0-1 and j in 0:F1F2-1.
#
# `cispi` is used instead of `exp(-2π * im * ...)` so quarter-turn entries are
# exactly 0 / ±1.
function fft_make_twiddles_T0(F0::Int, F1F2::Int, N::Int)
    T0 = zeros(Float32, F1F2, F0, 2)
    # First index (j) varies fastest to match column-major storage.
    for i in 0:F0-1, j in 0:F1F2-1
        val = cispi(-2 * i * j / N)
        T0[j + 1, i + 1, 1] = Float32(real(val))
        T0[j + 1, i + 1, 2] = Float32(imag(val))
    end
    return T0
end
752+
# Twiddle factors T1 for column-major layout (F0F2, F1)
#
# Returns an `(F0*F2) × F1 × 2` Float32 array with
#   T1[k+1, j+1, :] = (real, imag) of exp(-2πi * j * (k % F2) / (F1*F2))
# for zero-based k in 0:F0*F2-1 and j in 0:F1-1. Only the `k % F2` component of
# the row index enters the phase, so the table repeats with period F2 along its
# first axis.
#
# `cispi` is used instead of `exp(-2π * im * ...)` so quarter-turn entries are
# exactly 0 / ±1.
function fft_make_twiddles_T1(F0::Int, F1::Int, F2::Int)
    F0F2 = F0 * F2
    F1F2 = F1 * F2
    T1 = zeros(Float32, F0F2, F1, 2)
    # First index (k) varies fastest to match column-major storage.
    for j in 0:F1-1, k in 0:F0F2-1
        f2 = k % F2
        val = cispi(-2 * j * f2 / F1F2)
        T1[k + 1, j + 1, 1] = Float32(real(val))
        T1[k + 1, j + 1, 2] = Float32(imag(val))
    end
    return T1
end
766+
# Build every constant table needed for a 3-stage FFT with the given factors.
#
# `factors` is `(F0, F1, F2)`; the transform length is N = F0 * F1 * F2.
# Returns the tuple `(W0, W1, W2, T0, T1)` where
#   W0, W1, W2 : DFT matrices of size F0, F1, F2 (from `fft_dft_matrix`)
#   T0         : twiddles from `fft_make_twiddles_T0(F0, F1 * F2, N)`
#   T1         : twiddles from `fft_make_twiddles_T1(F0, F1, F2)`
function fft_make_twiddles(factors::NTuple{3, Int})
    F0, F1, F2 = factors
    N = F0 * F1 * F2
    F1F2 = F1 * F2
    W0 = fft_dft_matrix(F0)
    W1 = fft_dft_matrix(F1)
    W2 = fft_dft_matrix(F2)
    T0 = fft_make_twiddles_T0(F0, F1F2, N)
    T1 = fft_make_twiddles_T1(F0, F1, F2)
    return (W0, W1, W2, T0, T1)
end
778+
# Benchmark the cuTile FFT kernel on a batch of random ComplexF32 signals.
#
# Steps: generate the input on the GPU, compute a reference with FFTW on the
# host, upload the precomputed DFT/twiddle tables, launch `fft_kernel` once and
# compare against the reference, then time the kernel alone and print a table
# with a GFLOPS column.
#
# Returns the vector of `BenchmarkResult` entries that was printed.
# Throws `ArgumentError` if the FFT constants are mutually inconsistent.
function benchmark_fft()
    println(" \n Benchmarking FFT...")
    BS, N = FFT_BATCH, FFT_SIZE
    F0, F1, F2 = FFT_FACTORS
    D = FFT_ATOM_PACKING_DIM

    # The kernel reshapes N samples into (F1*F2, F0) etc., so the factors must
    # multiply to the transform length.
    F0 * F1 * F2 == N || throw(ArgumentError(
        "FFT_FACTORS $(FFT_FACTORS) must multiply to FFT_SIZE ($N)"))
    # `reinterpret(reshape, ...)` below maps ComplexF32 <-> a leading axis of
    # exactly two Float32 values, so any other packing dim cannot work here.
    D == 2 || throw(ArgumentError(
        "FFT_ATOM_PACKING_DIM must be 2 for ComplexF32 packing, got $D"))

    println(" Size: $BS batches × $N FFT ($(BS * N * 8 / 1e6) MB)")

    # Create complex input
    CUDA.seed!(42)
    input = CUDA.randn(ComplexF32, BS, N)

    # Reference result (FFTW), transforming along dimension 2
    reference = FFTW.fft(Array(input), 2)

    results = BenchmarkResult[]

    # Pre-compute twiddles (one-time CPU cost)
    W0, W1, W2, T0, T1 = fft_make_twiddles(FFT_FACTORS)
    W0_gpu, W1_gpu, W2_gpu = CuArray(W0), CuArray(W1), CuArray(W2)
    T0_gpu, T1_gpu = CuArray(T0), CuArray(T1)

    # Pre-pack input (zero-copy): view the complex data as Float32 with a new
    # leading real/imag axis.
    N2D = N * 2 ÷ D
    x_packed = reinterpret(reshape, Float32, input)
    y_packed = CUDA.zeros(Float32, D, BS, N2D)

    # Kernel launch parameters
    F0F1, F1F2, F0F2 = F0 * F1, F1 * F2, F0 * F2
    grid = (BS, 1, 1)

    # Kernel-only timing function
    cutile_kernel_f = () -> ct.launch(fft_kernel, grid,
        x_packed, y_packed,
        W0_gpu, W1_gpu, W2_gpu, T0_gpu, T1_gpu,
        ct.Constant(N), ct.Constant(F0), ct.Constant(F1), ct.Constant(F2),
        ct.Constant(F0F1), ct.Constant(F1F2), ct.Constant(F0F2),
        ct.Constant(BS), ct.Constant(D), ct.Constant(N2D))

    # Verify correctness
    cutile_kernel_f()
    CUDA.synchronize()
    y_complex = reinterpret(reshape, ComplexF32, y_packed)
    output = copy(y_complex)
    @assert isapprox(Array(output), reference, rtol=1e-3) " cuTile FFT incorrect!"

    # Benchmark kernel only
    min_t, mean_t = benchmark_kernel(cutile_kernel_f)
    push!(results, BenchmarkResult(" cuTile.jl", min_t, mean_t))

    # Performance metric: GFLOPS (5 * N * log2(N) per complex FFT)
    flops_per_fft = 5.0 * N * log2(N)
    total_flops = BS * flops_per_fft
    gflops = [string(round(total_flops / (r.min_ms * 1e-3) / 1e9, digits=1), " GFLOPS") for r in results]

    print_table(" FFT (ComplexF32)", results; extra_col=(" Performance", gflops))
    return results
end
836+
606837#= ============================================================================
607838 Main
608839=============================================================================#
@@ -622,6 +853,7 @@ function main()
622853 matmul_results = benchmark_matmul ()
623854 layernorm_results = benchmark_layernorm ()
624855 batchmatmul_results = benchmark_batchmatmul ()
856+ fft_results = benchmark_fft ()
625857
626858 println ()
627859 println (" =" ^ 60 )
0 commit comments