From 0a7042026b6734a0f18d903d8d24c25884d85749 Mon Sep 17 00:00:00 2001 From: Suryansh Gupta Date: Thu, 26 Mar 2026 00:16:52 +0530 Subject: [PATCH] Add Cache aware multi-vector distance functions --- .../src/multi_vector/block_transposed.rs | 47 ++ .../distance/cache_aware/f32_kernel.rs | 424 ++++++++++++++++++ .../distance/cache_aware/kernel.rs | 245 ++++++++++ .../multi_vector/distance/cache_aware/mod.rs | 112 +++++ .../src/multi_vector/distance/mod.rs | 2 + .../src/multi_vector/distance/simple.rs | 2 +- diskann-quantization/src/multi_vector/mod.rs | 2 +- 7 files changed, 832 insertions(+), 2 deletions(-) create mode 100644 diskann-quantization/src/multi_vector/distance/cache_aware/f32_kernel.rs create mode 100644 diskann-quantization/src/multi_vector/distance/cache_aware/kernel.rs create mode 100644 diskann-quantization/src/multi_vector/distance/cache_aware/mod.rs diff --git a/diskann-quantization/src/multi_vector/block_transposed.rs b/diskann-quantization/src/multi_vector/block_transposed.rs index 10339ea7b..fbca4fbfa 100644 --- a/diskann-quantization/src/multi_vector/block_transposed.rs +++ b/diskann-quantization/src/multi_vector/block_transposed.rs @@ -231,6 +231,15 @@ impl BlockTransposedRepr usize { + self.num_blocks() * GROUP + } + /// The stride (in elements) between the start of consecutive blocks. #[inline] fn block_stride(&self) -> usize { @@ -743,6 +752,15 @@ impl<'a, T: Copy, const GROUP: usize, const PACK: usize> BlockTransposedRef<'a, self.data.repr().remainder() } + /// Total number of logical rows rounded up to the next multiple of `GROUP`. + /// + /// This is the number of "available" row slots in the backing allocation, + /// including zero-padded rows in the last (possibly partial) block. + #[inline] + pub fn available_rows(&self) -> usize { + self.data.repr().available_rows() + } + /// Return a raw typed pointer to the start of the backing data. 
#[inline] pub fn as_ptr(&self) -> *const T { @@ -870,6 +888,7 @@ impl<'a, T: Copy, const GROUP: usize, const PACK: usize> BlockTransposedMut<'a, delegate_to_ref!(pub fn full_blocks(&self) -> usize); delegate_to_ref!(pub fn num_blocks(&self) -> usize); delegate_to_ref!(pub fn remainder(&self) -> usize); + delegate_to_ref!(pub fn available_rows(&self) -> usize); delegate_to_ref!(pub fn as_ptr(&self) -> *const T); delegate_to_ref!(pub fn as_slice(&self) -> &[T]); delegate_to_ref!(#[allow(clippy::missing_safety_doc)] unsafe pub fn block_ptr_unchecked(&self, block: usize) -> *const T); @@ -1017,6 +1036,7 @@ impl BlockTransposed usize); delegate_to_ref!(pub fn num_blocks(&self) -> usize); delegate_to_ref!(pub fn remainder(&self) -> usize); + delegate_to_ref!(pub fn available_rows(&self) -> usize); delegate_to_ref!(pub fn as_ptr(&self) -> *const T); delegate_to_ref!(pub fn as_slice(&self) -> &[T]); delegate_to_ref!(#[allow(clippy::missing_safety_doc)] unsafe pub fn block_ptr_unchecked(&self, block: usize) -> *const T); @@ -1164,6 +1184,33 @@ impl BlockTransposed>> for BlockTransposed +// ════════════════════════════════════════════════════════════════════ + +impl + From>> for BlockTransposed +{ + /// Convert a row-major [`MatRef`] into a block-transposed matrix. + /// + /// This copies the data from the dense row-major layout into the + /// block-transposed layout suitable for SIMD distance computation. + fn from(v: MatRef<'_, super::matrix::Standard>) -> Self { + let nrows = v.num_vectors(); + let ncols = v.vector_dim(); + // SAFETY: `MatRef>` stores `nrows * ncols` contiguous `T` elements + // starting at the raw pointer returned by `as_raw_ptr()`. + let slice = + unsafe { std::slice::from_raw_parts(v.as_raw_ptr().cast::(), nrows * ncols) }; + // The dimensions are guaranteed valid because `MatRef>` was + // constructed from the same `nrows * ncols` layout. 
+ #[allow(clippy::expect_used)] + let view = MatrixView::try_from(slice, nrows, ncols) + .expect("MatRef> has valid dimensions"); + Self::from_matrix_view(view) + } +} + // ════════════════════════════════════════════════════════════════════ // Index<(usize, usize)> for BlockTransposed // ════════════════════════════════════════════════════════════════════ diff --git a/diskann-quantization/src/multi_vector/distance/cache_aware/f32_kernel.rs b/diskann-quantization/src/multi_vector/distance/cache_aware/f32_kernel.rs new file mode 100644 index 000000000..3ec3f28df --- /dev/null +++ b/diskann-quantization/src/multi_vector/distance/cache_aware/f32_kernel.rs @@ -0,0 +1,424 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +//! f32 × f32 cache-aware micro-kernel using FMA and max-reduce. +//! +//! Provides: +//! +//! - `F32Kernel` — SIMD micro-kernel (16×4 via FMA + `max_simd`). +//! - [`cache_aware_chamfer`] — public entry point for the reducing max-IP GEMM. +//! - [`QueryBlockTransposedRef`] — newtype wrapper for f32 block-transposed queries. +//! - [`MaxSim`](crate::multi_vector::distance::MaxSim) / +//! [`Chamfer`](crate::multi_vector::distance::Chamfer) trait implementations for +//! `QueryBlockTransposedRef` × `MatRef>`. + +use std::ops::Deref; + +use diskann_vector::{DistanceFunctionMut, PureDistanceFunction}; +use diskann_wide::{SIMDMinMax, SIMDMulAdd, SIMDVector}; + +use super::CacheAwareKernel; +use super::kernel::{Reduce, tiled_reduce}; +use crate::multi_vector::distance::{Chamfer, MaxSim}; +use crate::multi_vector::{BlockTransposedRef, MatRef, Standard}; + +diskann_wide::alias!(f32s = f32x8); + +// ── QueryBlockTransposedRef ────────────────────────────────────── + +/// A query wrapper for block-transposed multi-vector views. 
+/// +/// This wrapper distinguishes query matrices from document matrices +/// at compile time, preventing accidental argument swapping in asymmetric +/// distance computations like [`MaxSim`](crate::multi_vector::distance::MaxSim) and +/// [`Chamfer`](crate::multi_vector::distance::Chamfer). +/// +/// Analogous to [`QueryMatRef`](crate::multi_vector::distance::QueryMatRef) but for +/// [`BlockTransposedRef`] queries rather than row-major +/// [`MatRef`](crate::multi_vector::MatRef) queries. +/// +/// # Example +/// +/// ``` +/// use diskann_quantization::multi_vector::{BlockTransposed, MatRef, Standard}; +/// use diskann_quantization::multi_vector::distance::QueryBlockTransposedRef; +/// +/// let data = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; +/// let mat_ref = MatRef::new(Standard::new(2, 3).unwrap(), &data).unwrap(); +/// let bt = BlockTransposed::::from(mat_ref); +/// let query = QueryBlockTransposedRef::from(bt.as_view()); +/// ``` +#[derive(Debug, Clone, Copy)] +pub struct QueryBlockTransposedRef<'a>(pub BlockTransposedRef<'a, f32, 16>); + +impl<'a> From> for QueryBlockTransposedRef<'a> { + fn from(view: BlockTransposedRef<'a, f32, 16>) -> Self { + Self(view) + } +} + +impl<'a> Deref for QueryBlockTransposedRef<'a> { + type Target = BlockTransposedRef<'a, f32, 16>; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +// ── F32 kernel ─────────────────────────────────────────────────── + +/// Cache-aware micro-kernel for f32 queries and f32 documents. +/// +/// Uses FMA (`mul_add_simd`) to accumulate inner products and `max_simd` to +/// reduce across document vectors. The micro-panel geometry is 16 × 4 +/// (2 × f32x8 lanes × 4 broadcast unrolls). +pub(crate) struct F32Kernel; + +// SAFETY: F32Kernel's `full_panel` and `remainder_dispatch` only access +// A_PANEL(16) * k query elements, UNROLL * k doc elements, and A_PANEL(16) +// scratch elements — all within the bounds guaranteed by `tiled_reduce`. 
+unsafe impl CacheAwareKernel for F32Kernel { + type QueryElem = f32; + type DocElem = f32; + const A_PANEL: usize = 16; + const B_PANEL: usize = 4; + + #[inline(always)] + unsafe fn full_panel( + arch: diskann_wide::arch::Current, + a: *const f32, + b: *const f32, + k: usize, + r: *mut f32, + ) { + // SAFETY: Caller guarantees pointer validity per CacheAwareKernel contract. + unsafe { f32_microkernel::<{ Self::B_PANEL }>(arch, a, b, k, r) } + } + + #[inline(always)] + unsafe fn remainder_dispatch( + arch: diskann_wide::arch::Current, + remainder: usize, + a: *const f32, + b: *const f32, + k: usize, + r: *mut f32, + ) { + // SAFETY: Caller guarantees pointer validity per CacheAwareKernel contract. + unsafe { + match remainder { + 1 => f32_microkernel::<1>(arch, a, b, k, r), + 2 => f32_microkernel::<2>(arch, a, b, k, r), + 3 => f32_microkernel::<3>(arch, a, b, k, r), + _ => { + debug_assert!( + false, + "unexpected remainder {remainder} for B_PANEL={}", + Self::B_PANEL + ) + } + } + } + } +} + +// ── f32 micro-kernel ───────────────────────────────────────────── + +/// SIMD micro-kernel: processes 16 query rows × `UNROLL` document rows. +/// +/// Accumulates inner products via FMA (`mul_add_simd`) into two `f32x8` register +/// tiles (covering 16 query rows), then reduces across the `UNROLL` document +/// lanes with `max_simd` and merges into the scratch buffer `r`. +/// +/// # Safety +/// +/// * `a_packed` must point to `A_PANEL(16) × k` contiguous `f32` values. +/// * `b` must point to `UNROLL` rows of `k` contiguous `f32` values. +/// * `r` must point to at least `A_PANEL(16)` writable `f32` values. 
+#[inline(always)] +unsafe fn f32_microkernel( + arch: diskann_wide::arch::Current, + a_packed: *const f32, + b: *const f32, + k: usize, + r: *mut f32, +) where + [f32s; UNROLL]: Reduce, +{ + let op = |x: f32s, y: f32s| x.max_simd(y); + + let mut p0 = [f32s::default(arch); UNROLL]; + let mut p1 = [f32s::default(arch); UNROLL]; + let offsets: [usize; UNROLL] = core::array::from_fn(|i| k * i); + + let a_stride = 2 * f32s::LANES; + let a_stride_half = f32s::LANES; + + for i in 0..k { + // SAFETY: a_packed points to A_PANEL * k contiguous f32s (one micro-panel). + // b points to UNROLL rows of k contiguous f32s each. All reads are in-bounds. + unsafe { + let a0 = f32s::load_simd(arch, a_packed.add(a_stride * i)); + let a1 = f32s::load_simd(arch, a_packed.add(a_stride * i + a_stride_half)); + + for j in 0..UNROLL { + let bj = f32s::splat(arch, b.add(i + offsets[j]).read_unaligned()); + p0[j] = a0.mul_add_simd(bj, p0[j]); + p1[j] = a1.mul_add_simd(bj, p1[j]); + } + } + } + + // SAFETY: r points to at least A_PANEL = 16 writable f32s (2 × f32x8). + let mut r0 = unsafe { f32s::load_simd(arch, r) }; + // SAFETY: r + f32s::LANES is within the same A_PANEL-sized scratch region. + let mut r1 = unsafe { f32s::load_simd(arch, r.add(f32s::LANES)) }; + + r0 = op(r0, p0.reduce(&op)); + r1 = op(r1, p1.reduce(&op)); + + // SAFETY: r points to at least A_PANEL = 16 writable f32s (2 × f32x8). + unsafe { r0.store_simd(r) }; + // SAFETY: r + f32s::LANES is within the same A_PANEL-sized scratch region. 
+ unsafe { r1.store_simd(r.add(f32s::LANES)) }; +} + +// ── Public f32 entry point ─────────────────────────────────────── + +#[inline(never)] +#[cold] +#[allow(clippy::panic)] +fn cache_aware_chamfer_panic() { + panic!( + "cache_aware_chamfer: precondition failed (scratch.len != available_rows or dimension mismatch)" + ); +} + +/// Compute the reducing max-IP GEMM between a block-transposed query (`a`) and +/// a row-major document matrix (`b`), writing per-query max similarities into `scratch`. +/// +/// This is a thin wrapper over the generic [`tiled_reduce`] loop using the +/// [`F32Kernel`] micro-kernel. +/// +/// # Arguments +/// +/// * `arch` - The SIMD architecture to use (from [`diskann_wide::ARCH`]). +/// * `a` - Block-transposed query matrix view (GROUP=16, PACK=1). +/// * `b` - Row-major document matrix view. +/// * `scratch` - Mutable buffer of length [`BlockTransposedRef::available_rows()`]. +/// Must be initialized to `f32::MIN` before the first call. On return, `scratch[i]` +/// contains the maximum inner product between query vector `i` and any document vector. +/// +/// # Panics +/// +/// Panics if `scratch.len() != a.available_rows()` or `a.ncols() != b.vector_dim()`. +pub fn cache_aware_chamfer( + arch: diskann_wide::arch::Current, + a: BlockTransposedRef<'_, f32, 16>, + b: MatRef<'_, Standard>, + scratch: &mut [f32], +) { + if scratch.len() != a.available_rows() || a.ncols() != b.vector_dim() { + cache_aware_chamfer_panic(); + } + + let k = a.ncols(); + let b_nrows = b.num_vectors(); + + // SAFETY: + // - a.as_ptr() is valid for a.available_rows() * k elements of f32. + // - MatRef> stores nrows * ncols contiguous f32 elements at as_raw_ptr(). + // - scratch.len() == a.available_rows() (checked above). + // - a.available_rows() is always a multiple of F32Kernel::A_PANEL (= 16 = GROUP). 
+ unsafe { + tiled_reduce::( + arch, + a.as_ptr(), + a.available_rows(), + b.as_raw_ptr().cast::(), + b_nrows, + k, + scratch, + ); + } +} + +// ── MaxSim / Chamfer trait implementations ─────────────────────── + +impl DistanceFunctionMut, MatRef<'_, Standard>> for MaxSim<'_> { + #[inline(always)] + fn evaluate(&mut self, query: QueryBlockTransposedRef<'_>, doc: MatRef<'_, Standard>) { + assert!( + self.size() == query.nrows(), + "scores buffer not right size: {} != {}", + self.size(), + query.nrows() + ); + + if doc.num_vectors() == 0 { + // No document vectors — fill with MAX (no similarity found). + self.scores_mut().fill(f32::MAX); + return; + } + + let scratch = self.scores_mut(); + scratch.fill(f32::MIN); + + // Extend scratch to available_rows if needed (scratch may be smaller + // than available_rows due to padding). + let available = query.available_rows(); + let nq = query.nrows(); + + if available == nq { + // No padding — scratch is exactly the right size. + cache_aware_chamfer(diskann_wide::ARCH, *query, doc, scratch); + } else { + // Padding rows exist — need a larger scratch buffer. + let mut padded_scratch = vec![f32::MIN; available]; + cache_aware_chamfer(diskann_wide::ARCH, *query, doc, &mut padded_scratch); + scratch.copy_from_slice(&padded_scratch[..nq]); + } + + // The kernel wrote max inner products (positive = more similar). + // DiskANN convention: negate so that lower = better (distance semantics). + for s in scratch.iter_mut() { + *s = -*s; + } + } +} + +impl PureDistanceFunction, MatRef<'_, Standard>, f32> for Chamfer { + #[inline(always)] + fn evaluate(query: QueryBlockTransposedRef<'_>, doc: MatRef<'_, Standard>) -> f32 { + if doc.num_vectors() == 0 { + return 0.0; + } + + let available = query.available_rows(); + let nq = query.nrows(); + + let mut scratch = vec![f32::MIN; available]; + cache_aware_chamfer(diskann_wide::ARCH, *query, doc, &mut scratch); + + // Sum negated max similarities to get Chamfer distance. 
+ scratch.iter().take(nq).map(|&s| -s).sum() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::multi_vector::BlockTransposed; + use crate::multi_vector::distance::QueryMatRef; + + /// Helper to create a MatRef from raw data. + fn make_query_mat(data: &[f32], nrows: usize, ncols: usize) -> MatRef<'_, Standard> { + MatRef::new(Standard::new(nrows, ncols).unwrap(), data).unwrap() + } + + /// Generate deterministic test data. + fn make_test_data(len: usize, ceil: usize, shift: usize) -> Vec { + (0..len).map(|v| ((v + shift) % ceil) as f32).collect() + } + + /// Test cases: (num_queries, num_docs, dim). + const TEST_CASES: &[(usize, usize, usize)] = &[ + (1, 1, 4), // Single query, single doc + (1, 5, 8), // Single query, multiple docs + (5, 1, 8), // Multiple queries, single doc + (3, 4, 16), // General case + (7, 7, 32), // Square case + (2, 3, 128), // Larger dimension + (16, 4, 64), // Exact A_PANEL width + (17, 4, 64), // One more than A_PANEL (remainder) + (32, 5, 16), // Multiple full A-panels, remainder B-rows (5 % 4 = 1) + (16, 6, 32), // Remainder B-rows (6 % 4 = 2) + (16, 7, 32), // Remainder B-rows (7 % 4 = 3) + (16, 8, 32), // No remainder B-rows (8 % 4 = 0) + ]; + + #[test] + fn chamfer_matches_simple_kernel() { + for &(nq, nd, dim) in TEST_CASES { + let query_data = make_test_data(nq * dim, dim, dim / 2); + let doc_data = make_test_data(nd * dim, dim, dim); + + let query_mat = make_query_mat(&query_data, nq, dim); + let doc = make_query_mat(&doc_data, nd, dim); + + // Reference: simple kernel Chamfer + let simple_query: QueryMatRef<_> = query_mat.into(); + let expected = Chamfer::evaluate(simple_query, doc); + + // Cache-aware: block-transposed query + let bt = BlockTransposed::::from(query_mat); + let bt_query = QueryBlockTransposedRef::from(bt.as_view()); + let actual = Chamfer::evaluate(bt_query, doc); + + assert!( + (actual - expected).abs() < 1e-2, + "Chamfer mismatch for ({nq},{nd},{dim}): actual={actual}, expected={expected}" + ); + 
} + } + + #[test] + fn max_sim_matches_simple_kernel() { + for &(nq, nd, dim) in TEST_CASES { + let query_data = make_test_data(nq * dim, dim, dim / 2); + let doc_data = make_test_data(nd * dim, dim, dim); + + let query_mat = make_query_mat(&query_data, nq, dim); + let doc = make_query_mat(&doc_data, nd, dim); + + // Reference: simple kernel MaxSim + let mut expected_scores = vec![0.0f32; nq]; + let simple_query: QueryMatRef<_> = query_mat.into(); + let _ = MaxSim::new(&mut expected_scores) + .unwrap() + .evaluate(simple_query, doc); + + // Cache-aware: block-transposed query + let bt = BlockTransposed::::from(query_mat); + let bt_query = QueryBlockTransposedRef::from(bt.as_view()); + let mut actual_scores = vec![0.0f32; nq]; + MaxSim::new(&mut actual_scores) + .unwrap() + .evaluate(bt_query, doc); + + for i in 0..nq { + assert!( + (actual_scores[i] - expected_scores[i]).abs() < 1e-2, + "MaxSim[{i}] mismatch for ({nq},{nd},{dim}): actual={}, expected={}", + actual_scores[i], + expected_scores[i] + ); + } + } + } + + #[test] + fn chamfer_with_zero_docs_returns_zero() { + let query_data = [1.0f32, 0.0, 0.0, 1.0]; + let query_mat = make_query_mat(&query_data, 2, 2); + let bt = BlockTransposed::::from(query_mat); + let bt_query = QueryBlockTransposedRef::from(bt.as_view()); + + let doc = make_query_mat(&[], 0, 2); + let result = Chamfer::evaluate(bt_query, doc); + assert_eq!(result, 0.0); + } + + #[test] + #[should_panic(expected = "scores buffer not right size")] + fn max_sim_panics_on_size_mismatch() { + let query_data = [1.0f32, 2.0, 3.0, 4.0]; + let query_mat = make_query_mat(&query_data, 2, 2); + let bt = BlockTransposed::::from(query_mat); + let bt_query = QueryBlockTransposedRef::from(bt.as_view()); + + let doc = make_query_mat(&[1.0, 1.0], 1, 2); + let mut scores = vec![0.0f32; 3]; // Wrong size + MaxSim::new(&mut scores).unwrap().evaluate(bt_query, doc); + } +} diff --git a/diskann-quantization/src/multi_vector/distance/cache_aware/kernel.rs 
b/diskann-quantization/src/multi_vector/distance/cache_aware/kernel.rs new file mode 100644 index 000000000..921241a15 --- /dev/null +++ b/diskann-quantization/src/multi_vector/distance/cache_aware/kernel.rs @@ -0,0 +1,245 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +//! Generic cache-aware tiling loop and shared reduction utilities. +//! +//! This module contains the type-agnostic parts of the cache-aware implementation: +//! +//! - `FullReduce` — tile planner that computes A/B panel counts from cache budgets. +//! - `tiled_reduce` — the 5-level loop nest that drives any `CacheAwareKernel`. +//! - `Reduce` — compile-time unroll reduction trait for fixed-size accumulator arrays. + +use super::{CacheAwareKernel, L1_B_TILE_BUDGET, L2_A_TILE_BUDGET}; + +// ── Tile planner ───────────────────────────────────────────────── + +/// A plan for performing the reducing GEMM when the contraction dimension `k` +/// is small enough that a micro-panel of `A` fits comfortably in L1 cache with room for +/// multiple micro-panels of `B`. +#[derive(Debug, Clone, Copy)] +struct FullReduce { + /// The number of micro panels of `A` that make up a tile. + a_panels: usize, + + /// The number of micro panels of `B` that make up a tile. + b_panels: usize, +} + +impl FullReduce { + /// Compute A-tile and B-tile panel counts from cache budgets. + /// + /// * `a_row_bytes` — bytes per query row (`k * size_of::()`). + /// * `b_row_bytes` — bytes per document row (`k * size_of::()`). + /// * `a_panel` — micro-kernel query panel height (`K::A_PANEL`). + /// * `b_panel` — micro-kernel document panel width (`K::B_PANEL`). + /// * `l2_budget` — L2 cache budget in bytes for the A tile. + /// * `l1_budget` — L1 cache budget in bytes for the B tile. 
+ fn new( + a_row_bytes: usize, + b_row_bytes: usize, + a_panel: usize, + b_panel: usize, + l2_budget: usize, + l1_budget: usize, + ) -> Self { + let a_row_bytes = a_row_bytes.max(1); + let b_row_bytes = b_row_bytes.max(1); + + let a_panels = (l2_budget / (a_row_bytes * a_panel)).max(1); + + let a_panel_bytes = a_panel * a_row_bytes; + let b_tile_budget = l1_budget.saturating_sub(a_panel_bytes); + let b_panels = (b_tile_budget / (b_row_bytes * b_panel)).max(1); + + Self { a_panels, b_panels } + } +} + +// ── Generic tiled reduce ───────────────────────────────────────── + +/// Execute the 5-level cache-aware tiling loop with a pluggable SIMD micro-kernel. +/// +/// This is the core scheduling primitive. The loop nest is: +/// ```text +/// Loop 1: A tiles (query tiles, sized to L2) +/// Loop 2: B tiles (doc tiles, sized to L1) +/// Loop 3: A panels (micro-panels within A tile) +/// Loop 4: B panels (micro-panels within B tile) +/// Loop 5: K::full_panel / K::remainder_dispatch (micro-kernel, contraction over `k`) +/// ``` +/// +/// # Safety +/// +/// * `a_ptr` must be valid for `a_available_rows * k` elements of `K::QueryElem`. +/// * `b_ptr` must be valid for `b_nrows * k` elements of `K::DocElem`. +/// * `scratch` must have length `a_available_rows` and be initialized by caller. +/// * `a_available_rows` must be a multiple of `K::A_PANEL`.
+pub(crate) unsafe fn tiled_reduce( + arch: diskann_wide::arch::Current, + a_ptr: *const K::QueryElem, + a_available_rows: usize, + b_ptr: *const K::DocElem, + b_nrows: usize, + k: usize, + scratch: &mut [f32], +) { + debug_assert_eq!( + a_available_rows % K::A_PANEL, + 0, + "a_available_rows ({a_available_rows}) must be a multiple of A_PANEL ({})", + K::A_PANEL, + ); + + let a_row_bytes = k * std::mem::size_of::(); + let b_row_bytes = k * std::mem::size_of::(); + let plan = FullReduce::new( + a_row_bytes, + b_row_bytes, + K::A_PANEL, + K::B_PANEL, + L2_A_TILE_BUDGET, + L1_B_TILE_BUDGET, + ); + + let a_panel_stride = K::A_PANEL * k; + let a_tile_stride = a_panel_stride * plan.a_panels; + let b_panel_stride = K::B_PANEL * k; + let b_tile_stride = b_panel_stride * plan.b_panels; + + let remainder = b_nrows % K::B_PANEL; + + // SAFETY: Caller guarantees a_ptr is valid for a_available_rows * k elements. + let pa_end = unsafe { a_ptr.add(a_available_rows * k) }; + // SAFETY: Caller guarantees b_ptr is valid for b_nrows * k elements. + let pb_end = unsafe { b_ptr.add(b_nrows * k) }; + // SAFETY: remainder < B_PANEL, so pb_end - remainder * k is within allocation. + let pb_full_end = unsafe { pb_end.sub(remainder * k) }; + + // SAFETY: All pointer arithmetic stays within the respective allocations. + unsafe { + let mut pa_tile = a_ptr; + let mut pr_tile = scratch.as_mut_ptr(); + + // Loop 1: Tiles of `A`. + while pa_tile < pa_end { + let remaining_a = + (pa_end as usize - pa_tile as usize) / std::mem::size_of::(); + let pa_tile_end = pa_tile.add(a_tile_stride.min(remaining_a)); + + let mut pb_tile = b_ptr; + + // Loop 2: Full B-tiles (every panel in the tile is complete). + while pb_tile.wrapping_add(b_tile_stride) <= pb_full_end { + let pb_tile_end = pb_tile.add(b_tile_stride); + + let mut pa_panel = pa_tile; + let mut pr_panel = pr_tile; + + // Loop 3: Micro-panels of `A`. 
+ while pa_panel < pa_tile_end { + let mut pb_panel = pb_tile; + + // Loop 4: Micro-panels of `B` (all full, no remainder check). + while pb_panel < pb_tile_end { + K::full_panel(arch, pa_panel, pb_panel, k, pr_panel); + pb_panel = pb_panel.add(b_panel_stride); + } + + pa_panel = pa_panel.add(a_panel_stride); + pr_panel = pr_panel.add(K::A_PANEL); + } + pb_tile = pb_tile.add(b_tile_stride); + } + + // Peeled last B-tile: contains remaining full panels + remainder rows. + if pb_tile < pb_end { + let mut pa_panel = pa_tile; + let mut pr_panel = pr_tile; + + // Loop 3 (peeled): Micro-panels of `A`. + while pa_panel < pa_tile_end { + let mut pb_panel = pb_tile; + + // Loop 4 (peeled): Full B-panels in the last tile. + while pb_panel < pb_full_end { + K::full_panel(arch, pa_panel, pb_panel, k, pr_panel); + pb_panel = pb_panel.add(b_panel_stride); + } + + // Remainder dispatch: 1..(B_PANEL-1) leftover B-rows. + if remainder > 0 { + K::remainder_dispatch(arch, remainder, pa_panel, pb_panel, k, pr_panel); + } + + pa_panel = pa_panel.add(a_panel_stride); + pr_panel = pr_panel.add(K::A_PANEL); + } + } + + // NOTE: Use `wrapping_add` so we can still do this on the last iteration. + pa_tile = pa_tile.wrapping_add(a_tile_stride); + pr_tile = pr_tile.wrapping_add(K::A_PANEL * plan.a_panels); + } + } +} + +// ── Reduce trait for compile-time unroll reduction ─────────────── + +/// Compile-time unroll reduction over fixed-size arrays. +/// +/// Used by the micro-kernel to reduce `UNROLL` accumulators into a single value +/// using a caller-supplied binary operator (e.g. `max_simd`). 
+pub(super) trait Reduce { + type Element; + fn reduce(&self, f: &F) -> Self::Element + where + F: Fn(Self::Element, Self::Element) -> Self::Element; +} + +impl Reduce for [T; 1] { + type Element = T; + + #[inline(always)] + fn reduce(&self, _f: &F) -> T + where + F: Fn(T, T) -> T, + { + self[0] + } +} + +impl Reduce for [T; 2] { + type Element = T; + + #[inline(always)] + fn reduce(&self, f: &F) -> T + where + F: Fn(T, T) -> T, + { + f(self[0], self[1]) + } +} + +impl Reduce for [T; 3] { + type Element = T; + + #[inline(always)] + fn reduce(&self, f: &F) -> T + where + F: Fn(T, T) -> T, + { + f(f(self[0], self[1]), self[2]) + } +} + +impl Reduce for [T; 4] { + type Element = T; + + #[inline(always)] + fn reduce(&self, f: &F) -> T + where + F: Fn(T, T) -> T, + { + f(f(self[0], self[1]), f(self[2], self[3])) + } +} diff --git a/diskann-quantization/src/multi_vector/distance/cache_aware/mod.rs b/diskann-quantization/src/multi_vector/distance/cache_aware/mod.rs new file mode 100644 index 000000000..f93764c0e --- /dev/null +++ b/diskann-quantization/src/multi_vector/distance/cache_aware/mod.rs @@ -0,0 +1,112 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +//! Cache-aware block-transposed SIMD implementation for multi-vector distance computation. +//! +//! This module provides a SIMD-accelerated implementation that uses block-transposed +//! memory layout for **query** vectors (instead of documents), with documents remaining +//! in row-major format. +//! +//! # Module Organization +//! +//! - [`CacheAwareKernel`] — unsafe trait that each element type implements (this file). +//! - `kernel` — generic 5-level cache-aware tiling loop (`tiled_reduce`). +//! - `f32_kernel` — f32 SIMD micro-kernel, entry point ([`cache_aware_chamfer`]), +//! query wrapper ([`QueryBlockTransposedRef`]), and `MaxSim` / `Chamfer` trait impls. +//! +//! # Cache-Aware Tiling Strategy +//! +//! 
This approach uses a reducing-GEMM pattern modeled after high-performance BLAS +//! implementations: +//! +//! - **L2 cache**: Tiles of the transposed query ("A") are sized to fit in L2. +//! - **L1 cache**: Tiles of the document ("B") plus one micro-panel of A are sized +//! to fit in L1. +//! - **Micro-kernel**: An `A_PANEL × B_PANEL` micro-kernel (e.g. 16×4 for f32) +//! processes a panel of query vectors against a panel of document vectors per +//! invocation, accumulating max-IP into a scratch buffer. The panel sizes are +//! determined by the [`CacheAwareKernel`] implementation for each element type. +//! +//! # Memory Layout +//! +//! - **Query**: Block-transposed (`GROUP` vectors per block; within a block, the `GROUP` +//! values of each dimension are contiguous). The block size is determined by the kernel's `A_PANEL`. +//! - **Document**: Row-major (standard [`MatRef`](crate::multi_vector::MatRef) format). + +mod f32_kernel; +mod kernel; + +pub use f32_kernel::{QueryBlockTransposedRef, cache_aware_chamfer}; + +// ── Cache budget constants ─────────────────────────────────────── + +/// Approximate usable L1 data cache in bytes (conservative estimate). +const L1_CACHE: usize = 48_000; + +/// Approximate usable L2 cache in bytes (conservative estimate). +const L2_CACHE: usize = 1_250_000; + +/// Byte budget for the A tile: half of L2. The remainder accommodates B streaming +/// traffic and incidental cache pollution. +const L2_A_TILE_BUDGET: usize = L2_CACHE / 2; + +/// L1 byte budget (three quarters of L1) for the B tile. The A micro-panel is subtracted at +/// runtime since it depends on K; this is the total L1 budget before that subtraction. +const L1_B_TILE_BUDGET: usize = L1_CACHE * 3 / 4; + +// ── CacheAwareKernel trait ─────────────────────────────────────── + +/// Trait abstracting a SIMD micro-kernel for the cache-aware tiling loop. +/// +/// Each implementation provides the element types, panel geometry, and the actual +/// SIMD micro-kernel body.
The generic [`tiled_reduce`](kernel::tiled_reduce) function +/// handles the 5-level cache-aware loop nest and calls into the kernel via this trait. +/// +/// # Safety +/// +/// Implementors must ensure that `full_panel` and `remainder_dispatch` only +/// read/write within the bounds described by their pointer arguments and the +/// `k` / panel-size contracts. +pub(crate) unsafe trait CacheAwareKernel { + /// Element type stored in the block-transposed query ("A" side). + type QueryElem: Copy; + /// Element type stored in the row-major document ("B" side). + type DocElem: Copy; + + /// Number of query rows processed per micro-kernel invocation. + /// Determined by SIMD register width for the element type. + const A_PANEL: usize; + /// Number of document rows processed per micro-kernel invocation + /// (broadcast unroll factor). + const B_PANEL: usize; + + /// Process one full `A_PANEL × B_PANEL` micro-panel pair. + /// + /// # Safety + /// + /// * `a` must point to `A_PANEL * k` contiguous `QueryElem` values. + /// * `b` must point to `B_PANEL` rows of `k` contiguous `DocElem` values. + /// * `r` must point to at least `A_PANEL` writable `f32` values. + unsafe fn full_panel( + arch: diskann_wide::arch::Current, + a: *const Self::QueryElem, + b: *const Self::DocElem, + k: usize, + r: *mut f32, + ); + + /// Dispatch for `1..=(B_PANEL - 1)` remainder document rows. + /// + /// # Safety + /// + /// Same pointer contracts as `full_panel`, but `b` points to `remainder` + /// rows instead of `B_PANEL` rows.
+ unsafe fn remainder_dispatch( + arch: diskann_wide::arch::Current, + remainder: usize, + a: *const Self::QueryElem, + b: *const Self::DocElem, + k: usize, + r: *mut f32, + ); +} diff --git a/diskann-quantization/src/multi_vector/distance/mod.rs b/diskann-quantization/src/multi_vector/distance/mod.rs index e0b106548..84755ce0c 100644 --- a/diskann-quantization/src/multi_vector/distance/mod.rs +++ b/diskann-quantization/src/multi_vector/distance/mod.rs @@ -45,8 +45,10 @@ //! // scores[1] = 0.0 (query[1] has no good match) //! ``` +pub mod cache_aware; mod max_sim; mod simple; +pub use cache_aware::QueryBlockTransposedRef; pub use max_sim::{Chamfer, MaxSim, MaxSimError}; pub use simple::QueryMatRef; diff --git a/diskann-quantization/src/multi_vector/distance/simple.rs b/diskann-quantization/src/multi_vector/distance/simple.rs index b92f9fa7e..395610fb5 100644 --- a/diskann-quantization/src/multi_vector/distance/simple.rs +++ b/diskann-quantization/src/multi_vector/distance/simple.rs @@ -299,7 +299,7 @@ mod tests { // No query vectors means sum is 0 assert_eq!(result, 0.0); - let result = Chamfer::evaluate(doc.into(), query.deref().reborrow()); + let result = Chamfer::evaluate(QueryMatRef::from(doc), query.deref().reborrow()); assert_eq!(result, 0.0); } diff --git a/diskann-quantization/src/multi_vector/mod.rs b/diskann-quantization/src/multi_vector/mod.rs index f8598eba0..d0ce77203 100644 --- a/diskann-quantization/src/multi_vector/mod.rs +++ b/diskann-quantization/src/multi_vector/mod.rs @@ -71,7 +71,7 @@ pub mod distance; pub(crate) mod matrix; pub use block_transposed::{BlockTransposed, BlockTransposedMut, BlockTransposedRef}; -pub use distance::{Chamfer, MaxSim, MaxSimError, QueryMatRef}; +pub use distance::{Chamfer, MaxSim, MaxSimError, QueryBlockTransposedRef, QueryMatRef}; pub use matrix::{ Defaulted, LayoutError, Mat, MatMut, MatRef, NewCloned, NewMut, NewOwned, NewRef, Overflow, Repr, ReprMut, ReprOwned, SliceError, Standard,