Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
843 changes: 843 additions & 0 deletions crypto/math/src/fft/bowers_fft_batch.rs

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions crypto/math/src/fft/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
pub mod bit_reversing;
#[cfg(feature = "alloc")]
pub mod bowers_fft;
#[cfg(feature = "alloc")]
pub mod bowers_fft_batch;
pub mod errors;
#[cfg(feature = "alloc")]
pub mod roots_of_unity;
#[cfg(feature = "alloc")]
pub mod two_half_fft;

#[cfg(all(test, feature = "alloc"))]
pub(crate) mod test_helpers;
246 changes: 246 additions & 0 deletions crypto/math/src/fft/two_half_fft.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
//! Cache-blocked, transpose-free batched FFT (port of Plonky3's two-half
//! `Radix2DitParallel::dft_batch`).
//!
//! The flat Bowers DIF streams the whole `n·m` buffer with large strides at the
//! early layers, thrashing cache for large `n`. This kernel keeps every layer
//! cache-resident by interleaving bit-reversals: bit-reverse → first `mid` DIT
//! layers within `2^mid`-row chunks → bit-reverse → remaining layers within
//! `2^(log_n−mid)`-row chunks → bit-reverse. The bit-reversals turn the
//! large-stride butterflies into chunk-local ones — the cache win the flat
//! Bowers misses. Output is natural order, identical to
//! `bowers_fft_batch_row_major` followed by `in_place_bit_reverse_permute_row_major`.
//!
//! Twiddles are precomputed once per size in [`TwoHalfTwiddles`] and reused
//! across calls (the trace LDE invokes this once per direction per domain, and
//! the same domain recurs across tables and rounds).

#[cfg(feature = "alloc")]
use crate::fft::bit_reversing::reverse_index;
#[cfg(feature = "alloc")]
use crate::fft::bowers_fft_batch::in_place_bit_reverse_permute_row_major;
#[cfg(feature = "alloc")]
use crate::fft::errors::FFTError;
#[cfg(feature = "alloc")]
use crate::field::{
element::FieldElement,
traits::{IsFFTField, IsField, IsSubFieldOf},
};
#[cfg(all(feature = "alloc", feature = "parallel"))]
use rayon::prelude::*;

/// In-place bit-reversal permutation of a flat slice (length a power of two).
#[cfg(feature = "alloc")]
fn bit_reverse_vec<F: IsField>(v: &mut [FieldElement<F>]) {
let n = v.len();
for i in 0..n {
let j = reverse_index(i, n as u64);
if j > i {
v.swap(i, j);
}
}
}

/// Precomputed twiddles for a size-`2^log_n` two-half FFT in one direction.
///
/// `tw` is the flat geometric array `[ω⁰, ω¹, …, ω^(n/2−1)]` (`ω` the forward
/// root for the forward transform, its inverse for the inverse transform);
/// `bitrev_tw` is its bit-reversal permutation, used by the second-half layers.
/// Build once and share across calls of the same size and direction.
#[cfg(feature = "alloc")]
pub struct TwoHalfTwiddles<F: IsField> {
log_n: usize,
tw: Vec<FieldElement<F>>,
bitrev_tw: Vec<FieldElement<F>>,
}

#[cfg(feature = "alloc")]
impl<F: IsFFTField> TwoHalfTwiddles<F> {
/// Precompute twiddles for a size-`2^log_n` transform. `inverse = true`
/// selects the (unscaled) inverse transform (uses `ω⁻¹`); the `1/n`
/// normalization is the caller's responsibility.
pub fn new(log_n: usize, inverse: bool) -> Result<Self, FFTError> {
let n = 1usize << log_n;
let half = n / 2;
// `omega` is unused when half == 0 (log_n == 0), so skip the lookup.
let omega = if half == 0 {
FieldElement::<F>::one()
} else {
let fwd = F::get_primitive_root_of_unity(log_n as u64)
.map_err(|_| FFTError::InputError(n))?;
if inverse {
fwd.inv().map_err(|_| FFTError::InputError(n))?
} else {
fwd
}
};

let mut tw: Vec<FieldElement<F>> = Vec::with_capacity(half);
let mut cur = FieldElement::<F>::one();
for _ in 0..half {
tw.push(cur.clone());
cur = &cur * &omega;
}
let mut bitrev_tw = tw.clone();
bit_reverse_vec(&mut bitrev_tw);

Ok(Self {
log_n,
tw,
bitrev_tw,
})
}
}

/// DIT butterfly over two equal-length row-slices, one twiddle for all pairs:
/// `a' = a + tw·b`, `b' = a − tw·b` (element-wise; `tw·b` is the F×E multiply).
#[cfg(feature = "alloc")]
#[inline]
fn dit_butterfly_rows<F, E>(
lo: &mut [FieldElement<E>],
hi: &mut [FieldElement<E>],
tw: &FieldElement<F>,
) where
F: IsSubFieldOf<E>,
E: IsField,
{
for (a, b) in lo.iter_mut().zip(hi.iter_mut()) {
let t = tw * &*b; // F × E → E
let new_a = &*a + &t;
*b = &*a - &t;
*a = new_a;
}
}

/// First-half DIT layer (per-pair twiddle), applied within one cache-resident
/// row-chunk. `tw` is the flat `[ω^0..ω^(n/2−1)]` array; pair `j` of layer
/// `layer` uses `tw[j · 2^(log_n−1−layer)]`.
#[cfg(feature = "alloc")]
fn dit_first_half_layer<F, E>(
chunk: &mut [FieldElement<E>],
m: usize,
layer: usize,
log_n: usize,
tw: &[FieldElement<F>],
) where
F: IsSubFieldOf<E>,
E: IsField,
{
let half = 1usize << layer;
let block_rows = half * 2;
let step = 1usize << (log_n - 1 - layer);
for block in chunk.chunks_mut(block_rows * m) {
let (lows, highs) = block.split_at_mut(half * m);
for j in 0..half {
let twj = &tw[j * step];
dit_butterfly_rows(
&mut lows[j * m..j * m + m],
&mut highs[j * m..j * m + m],
twj,
);
}
}
}

/// Second-half DIT layer (one twiddle per block, bit-reversed twiddle order),
/// applied within one cache-resident row-chunk owned by `thread`.
#[cfg(feature = "alloc")]
fn dit_second_half_layer<F, E>(
chunk: &mut [FieldElement<E>],
m: usize,
layer: usize,
log_n: usize,
mid: usize,
thread: usize,
bitrev_tw: &[FieldElement<F>],
) where
F: IsSubFieldOf<E>,
E: IsField,
{
let half_block = 1usize << (log_n - 1 - layer);
let block_rows = half_block * 2;
let first_block = thread << (layer - mid);
for (b, block) in chunk.chunks_mut(block_rows * m).enumerate() {
let twb = &bitrev_tw[first_block + b];
let (lows, highs) = block.split_at_mut(half_block * m);
dit_butterfly_rows(lows, highs, twb);
}
}

/// Cache-blocked, transpose-free batched FFT. `buf` is `n * num_cols` row-major
/// (`n` rows of `num_cols` consecutive elements); `tw` are the precomputed
/// twiddles for size `n` in the desired direction (forward or inverse).
/// Output is the natural-order DFT (matches `bowers_fft_batch_row_major`
/// followed by `in_place_bit_reverse_permute_row_major`). Inverse transforms
/// are NOT scaled by `1/n` — that is the caller's responsibility (e.g. folded
/// into the coset-weight pass of the LDE).
#[cfg(feature = "alloc")]
pub fn fft_batch_two_half<F, E>(
buf: &mut [FieldElement<E>],
num_cols: usize,
tw: &TwoHalfTwiddles<F>,
) -> Result<(), FFTError>
where
F: IsFFTField + IsSubFieldOf<E>,
E: IsField,
FieldElement<F>: Sync,
FieldElement<E>: Send + Sync,
{
let m = num_cols;
if m == 0 || buf.is_empty() {
return Ok(());
}
let total = buf.len();
if !total.is_multiple_of(m) {
return Err(FFTError::InputError(total));
}
let n = total / m;
if !n.is_power_of_two() {
return Err(FFTError::InputError(n));
}
let log_n = n.trailing_zeros() as usize;
if log_n != tw.log_n {
return Err(FFTError::InputError(n));
}
if log_n == 0 {
return Ok(());
}

let flat_tw = &tw.tw;
let bitrev_tw = &tw.bitrev_tw;
let mid = log_n.div_ceil(2);

// Step 1: bit-reverse rows.
in_place_bit_reverse_permute_row_major(buf, m);

// Step 2: first half — layers 0..mid within 2^mid-row chunks (all identical).
let first_chunk = (1usize << mid) * m;
#[cfg(feature = "parallel")]
let it = buf.par_chunks_mut(first_chunk);
#[cfg(not(feature = "parallel"))]
let it = buf.chunks_mut(first_chunk);
it.for_each(|chunk| {
for layer in 0..mid {
dit_first_half_layer::<F, E>(chunk, m, layer, log_n, flat_tw);
}
});

// Step 3: bit-reverse rows.
in_place_bit_reverse_permute_row_major(buf, m);

// Step 4: second half — layers mid..log_n within 2^(log_n-mid)-row chunks.
let second_chunk = (1usize << (log_n - mid)) * m;
#[cfg(feature = "parallel")]
let it2 = buf.par_chunks_mut(second_chunk).enumerate();
#[cfg(not(feature = "parallel"))]
let it2 = buf.chunks_mut(second_chunk).enumerate();
it2.for_each(|(thread, chunk)| {
for layer in mid..log_n {
dit_second_half_layer::<F, E>(chunk, m, layer, log_n, mid, thread, bitrev_tw);
}
});

// Step 5: final bit-reverse to natural order.
in_place_bit_reverse_permute_row_major(buf, m);

Ok(())
}
93 changes: 93 additions & 0 deletions crypto/math/src/polynomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::fft::bowers_fft::{LayerTwiddles, bowers_fft_opt_fused, bowers_ifft_op
#[cfg(feature = "parallel")]
use crate::fft::bowers_fft::{bowers_fft_opt_fused_parallel, bowers_ifft_opt_parallel};
use crate::fft::errors::FFTError;
use crate::fft::two_half_fft::{TwoHalfTwiddles, fft_batch_two_half};
use crate::field::traits::{IsFFTField, IsField, IsSubFieldOf};
use alloc::{borrow::ToOwned, vec, vec::Vec};

Expand Down Expand Up @@ -502,6 +503,98 @@ impl<E: IsField> Polynomial<FieldElement<E>> {

Ok(())
}

/// Batched row-major coset LDE expansion.
///
/// `buffer` is the row-major flat layout of `n * num_cols` elements
/// (input trace evaluations on the natural-order domain, all M columns
/// interleaved per row). It is expanded in place to length
/// `n * blowup_factor * num_cols`, also row-major, holding the LDE
/// evaluations on the coset.
///
/// Pipeline mirrors [`coset_lde_full_expand`] cell-for-cell, just with
/// the row-major batched FFT primitives so the M columns share twiddle
/// loads inside each butterfly:
/// 1. batched iFFT (DIT) over rows[..n]
/// 2. scale rows[..n] by coset weights (one weight per row, applied to
/// all M elements of that row)
/// 3. zero-pad rows to `n * blowup_factor`
/// 4. batched forward FFT (DIF)
///
/// `weights` must be `n` base-field elements in natural row order.
/// `inv_twiddles` are the size-`n` inverse two-half twiddles; `fwd_twiddles`
/// the size-`n·blowup_factor` forward ones.
pub fn coset_lde_full_expand_row_major<F: IsFFTField + IsSubFieldOf<E> + Send + Sync>(
buffer: &mut Vec<FieldElement<E>>,
num_cols: usize,
blowup_factor: usize,
weights: &[FieldElement<F>],
inv_twiddles: &TwoHalfTwiddles<F>,
fwd_twiddles: &TwoHalfTwiddles<F>,
) -> Result<(), FFTError>
where
E: Send + Sync,
{
if num_cols == 0 || buffer.is_empty() {
return Ok(());
}
let total = buffer.len();
if !total.is_multiple_of(num_cols) {
return Err(FFTError::InputError(total));
}
let n = total / num_cols;
if !n.is_power_of_two() {
return Err(FFTError::InputError(n));
}
let lde_n = n * blowup_factor;
if (lde_n.trailing_zeros() as u64) > F::TWO_ADICITY {
return Err(FFTError::DomainSizeError(lde_n.trailing_zeros() as usize));
}
if weights.len() < n {
return Err(FFTError::InputError(weights.len()));
}

// 1. iFFT on rows[..n] (cache-blocked two-half; natural→natural, no 1/n
// — the 1/n is folded into the coset-weight pass below). Replaces the
// flat-Bowers iFFT, which cache-thrashes at large n.
let prefix_len = n * num_cols;
fft_batch_two_half::<F, E>(&mut buffer[..prefix_len], num_cols, inv_twiddles)?;

// 2. Scale by coset weights — one weight per row, multiply M elements
// of that row by it. Each row is independent → parallelizable.
#[cfg(feature = "parallel")]
{
use rayon::prelude::{IndexedParallelIterator, ParallelIterator, ParallelSliceMut};
buffer[..prefix_len]
.par_chunks_exact_mut(num_cols)
.enumerate()
.for_each(|(r, row)| {
let w = &weights[r];
for x in row.iter_mut() {
*x = w * &*x;
}
});
}
#[cfg(not(feature = "parallel"))]
{
for r in 0..n {
let w = &weights[r];
let row = &mut buffer[r * num_cols..(r + 1) * num_cols];
for x in row.iter_mut() {
*x = w * &*x;
}
}
}

// 3. Zero-pad rows to lde_n.
buffer.resize(lde_n * num_cols, FieldElement::zero());

// 4. Forward FFT (cache-blocked two-half; natural-order output, replaces
// the flat Bowers fwd-FFT(2n) + bit-reverse — the cache-bound step).
fft_batch_two_half::<F, E>(buffer, num_cols, fwd_twiddles)?;

Ok(())
}
}

#[cfg(test)]
Expand Down
Loading
Loading