diff --git a/Cargo.lock b/Cargo.lock index a2c805220cc..bf8c74d7058 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11251,6 +11251,28 @@ dependencies = [ "web-sys", ] +[[package]] +name = "vortex-turboquant" +version = "0.1.0" +dependencies = [ + "codspeed-divan-compat", + "half", + "num-traits", + "prost 0.14.3", + "rand 0.10.1", + "rstest", + "vortex-array", + "vortex-buffer", + "vortex-error", + "vortex-file", + "vortex-io", + "vortex-layout", + "vortex-mask", + "vortex-session", + "vortex-tensor", + "vortex-utils", +] + [[package]] name = "vortex-utils" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index f8726a6f265..58149ae7cb5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ members = [ "vortex-proto", "vortex-array", "vortex-tensor", + "vortex-turboquant", "vortex-compressor", "vortex-btrblocks", "vortex-layout", @@ -296,6 +297,7 @@ vortex-sequence = { version = "0.1.0", path = "encodings/sequence", default-feat vortex-session = { version = "0.1.0", path = "./vortex-session", default-features = false } vortex-sparse = { version = "0.1.0", path = "./encodings/sparse", default-features = false } vortex-tensor = { version = "0.1.0", path = "./vortex-tensor", default-features = false } +vortex-turboquant = { version = "0.1.0", path = "./vortex-turboquant", default-features = false } vortex-utils = { version = "0.1.0", path = "./vortex-utils", default-features = false } vortex-zigzag = { version = "0.1.0", path = "./encodings/zigzag", default-features = false } vortex-zstd = { version = "0.1.0", path = "./encodings/zstd", default-features = false } diff --git a/vortex-tensor/public-api.lock b/vortex-tensor/public-api.lock index 1ebb5962a90..6fbc7b26fa3 100644 --- a/vortex-tensor/public-api.lock +++ b/vortex-tensor/public-api.lock @@ -528,6 +528,10 @@ pub fn vortex_tensor::vector::AnyVector::try_match<'a>(&'a vortex_array::dtype:: pub struct vortex_tensor::vector::Vector +impl vortex_tensor::vector::Vector + +pub fn 
vortex_tensor::vector::Vector::try_new_vector_array(vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult + impl core::clone::Clone for vortex_tensor::vector::Vector pub fn vortex_tensor::vector::Vector::clone(&self) -> vortex_tensor::vector::Vector diff --git a/vortex-tensor/src/encodings/turboquant/compress.rs b/vortex-tensor/src/encodings/turboquant/compress.rs index ca32faa6ec9..e656ba18822 100644 --- a/vortex-tensor/src/encodings/turboquant/compress.rs +++ b/vortex-tensor/src/encodings/turboquant/compress.rs @@ -205,8 +205,8 @@ fn turboquant_quantize_core( let dimension = fsl.list_size() as usize; let num_rows = fsl.len(); - let rotation = SorfMatrix::try_new(seed, dimension, num_rounds as usize)?; - let padded_dim = rotation.padded_dim(); + let padded_dim = dimension.next_power_of_two(); + let rotation = SorfMatrix::try_new_padded(padded_dim, num_rounds as usize, seed)?; let padded_dim_u32 = u32::try_from(padded_dim).vortex_expect("padded_dim stays representable as u32"); diff --git a/vortex-tensor/src/encodings/turboquant/tests/structural.rs b/vortex-tensor/src/encodings/turboquant/tests/structural.rs index 3913cf3d8fe..a9ec822cb2a 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/structural.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/structural.rs @@ -259,8 +259,8 @@ fn sorf_transform_roundtrip_isolation() -> VortexResult<()> { } // Forward transform + quantize (mimicking what turboquant_quantize_core does). 
- let rotation = SorfMatrix::try_new(seed, dim, num_rounds as usize)?; - let padded_dim = rotation.padded_dim(); + let padded_dim = dim.next_power_of_two(); + let rotation = SorfMatrix::try_new_padded(padded_dim, num_rounds as usize, seed)?; let centroids = compute_or_get_centroids(padded_dim as u32, 8)?; let boundaries = compute_centroid_boundaries(¢roids); diff --git a/vortex-tensor/src/scalar_fns/inner_product.rs b/vortex-tensor/src/scalar_fns/inner_product.rs index f34cb257dbd..2386dc03e53 100644 --- a/vortex-tensor/src/scalar_fns/inner_product.rs +++ b/vortex-tensor/src/scalar_fns/inner_product.rs @@ -366,7 +366,7 @@ impl InnerProduct { let mut padded_query = vec![0.0f32; padded_dim]; padded_query[..dim].copy_from_slice(flat.as_slice::()); - let rotation = SorfMatrix::try_new(seed, dim, num_rounds)?; + let rotation = SorfMatrix::try_new_padded(padded_dim, num_rounds, seed)?; let mut rotated_query = vec![0.0f32; padded_dim]; rotation.rotate(&padded_query, &mut rotated_query); @@ -930,7 +930,7 @@ mod tests { seed: u64, num_rounds: u8, ) -> VortexResult> { - let rotation = SorfMatrix::try_new(seed, dim, num_rounds as usize)?; + let rotation = SorfMatrix::try_new_padded(padded_dim, num_rounds as usize, seed)?; let mut padded = vec![0.0f32; padded_dim]; let mut rotated = vec![0.0f32; padded_dim]; let mut out = Vec::with_capacity(num_rows * dim); diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/rotation.rs b/vortex-tensor/src/scalar_fns/sorf_transform/rotation.rs index ff8aebd0f11..b416eb4a7e6 100644 --- a/vortex-tensor/src/scalar_fns/sorf_transform/rotation.rs +++ b/vortex-tensor/src/scalar_fns/sorf_transform/rotation.rs @@ -77,9 +77,25 @@ impl SorfMatrix { /// round-major, block-major order, with each `u64` contributing 64 sign bits in /// least-significant-bit-first order. 
pub fn try_new(seed: u64, dimensions: usize, num_rounds: usize) -> VortexResult { + Self::try_new_padded(dimensions.next_power_of_two(), num_rounds, seed) + } + + /// Create a new structured Walsh-Hadamard-based orthogonal transform for a padded dimension. + /// + /// `padded_dimensions` must already be a power of two. Callers that start from an unpadded + /// logical dimension should call [`Self::try_new`] instead. + pub(crate) fn try_new_padded( + padded_dimensions: usize, + num_rounds: usize, + seed: u64, + ) -> VortexResult { vortex_ensure!(num_rounds >= 1, "num_rounds must be >= 1, got {num_rounds}"); + vortex_ensure!( + padded_dimensions.is_power_of_two(), + "padded_dimensions must be a power of two, got {padded_dimensions}" + ); - let padded_dim = dimensions.next_power_of_two(); + let padded_dim = padded_dimensions; let sign_masks = gen_sign_masks_from_seed(seed, padded_dim, num_rounds); // Compute in f64 for precision, then store as f32 since the WHT operates on f32 buffers. @@ -132,8 +148,7 @@ impl SorfMatrix { /// Apply the forward structured transform: `norm · H · D_k · ... · H · D₁ · x`. fn apply_srht(&self, buf: &mut [f32]) { for round in 0..self.num_rounds { - let offset = round * self.padded_dim; - apply_signs_xor(buf, &self.sign_masks[offset..offset + self.padded_dim]); + self.apply_signs_xor(buf, round); walsh_hadamard_transform(buf); } @@ -148,14 +163,24 @@ impl SorfMatrix { fn apply_inverse_srht(&self, buf: &mut [f32]) { for round in (0..self.num_rounds).rev() { walsh_hadamard_transform(buf); - let offset = round * self.padded_dim; - apply_signs_xor(buf, &self.sign_masks[offset..offset + self.padded_dim]); + self.apply_signs_xor(buf, round); } let norm = self.norm_factor; buf.iter_mut().for_each(|val| *val *= norm); } + /// Apply one round's sign masks via XOR on the IEEE 754 sign bit. + /// + /// This is branchless and auto-vectorizes into `vpxor` (x86) / `veor` (ARM). 
Equivalent to + /// multiplying each element by +/-1.0, but avoids FP dependency chains. + fn apply_signs_xor(&self, buf: &mut [f32], round: usize) { + let masks = &self.sign_masks[round * self.padded_dim..][..self.padded_dim]; + for (val, &mask) in buf.iter_mut().zip(masks.iter()) { + *val = f32::from_bits(val.to_bits() ^ mask); + } + } + /// Export the sign vectors as a flat `Vec` of 0/1 values in inverse application order /// `[D_k | ... | D₁]`. /// @@ -263,16 +288,6 @@ fn sign_mask_from_word(word: u64, bit_idx: usize) -> u32 { } } -/// Apply sign masks via XOR on the IEEE 754 sign bit. -/// -/// This is branchless and auto-vectorizes into `vpxor` (x86) / `veor` (ARM). Equivalent to -/// multiplying each element by +/-1.0, but avoids FP dependency chains. -fn apply_signs_xor(buf: &mut [f32], masks: &[u32]) { - for (val, &mask) in buf.iter_mut().zip(masks.iter()) { - *val = f32::from_bits(val.to_bits() ^ mask); - } -} - /// In-place Fast Walsh-Hadamard Transform (FWHT), unnormalized and iterative. /// /// Input length must be a power of 2. 
Runs in O(n log n) via `log2(n)` stages of `n / 2` @@ -327,14 +342,24 @@ mod tests { .collect() } + fn dim_to_usize(dim: u32) -> usize { + usize::try_from(dim).unwrap() + } + + fn rounds_to_usize(num_rounds: u8) -> usize { + usize::from(num_rounds) + } + #[test] fn deterministic_from_seed() -> VortexResult<()> { - let r1 = SorfMatrix::try_new(42, 64, 3)?; - let r2 = SorfMatrix::try_new(42, 64, 3)?; + let dim = dim_to_usize(64u32); + let num_rounds = rounds_to_usize(3u8); + let r1 = SorfMatrix::try_new(42u64, dim, num_rounds)?; + let r2 = SorfMatrix::try_new(42u64, dim, num_rounds)?; let pd = r1.padded_dim(); let mut input = vec![0.0f32; pd]; - for i in 0..64 { + for i in 0..dim { input[i] = i as f32; } let mut out1 = vec![0.0f32; pd]; @@ -349,15 +374,19 @@ mod tests { #[test] fn export_inverse_signs_matches_golden_words() -> VortexResult<()> { - let rot = SorfMatrix::try_new(42, 64, 2)?; + let dim = dim_to_usize(64u32); + let num_rounds = rounds_to_usize(2u8); + let seed = 42u64; + let rot = SorfMatrix::try_new(seed, dim, num_rounds)?; + let padded_dim = rot.padded_dim(); let actual = rot.export_inverse_signs_u8(); - let mut rng = SplitMix64::new(42); + let mut rng = SplitMix64::new(seed); let round0_word = rng.next_u64(); let round1_word = rng.next_u64(); - let mut expected = Vec::with_capacity(128); - expected.extend(unpack_sign_bits(round1_word, 64)); - expected.extend(unpack_sign_bits(round0_word, 64)); + let mut expected = Vec::with_capacity(num_rounds * padded_dim); + expected.extend(unpack_sign_bits(round1_word, padded_dim)); + expected.extend(unpack_sign_bits(round0_word, padded_dim)); assert_eq!(actual, expected); Ok(()) @@ -365,25 +394,38 @@ mod tests { #[test] fn one_word_generates_64_signs_lsb_first() { - let masks = gen_sign_masks_from_seed(42, 64, 1); - assert_eq!(masks.len(), 64); + let seed = 42u64; + let padded_dim = dim_to_usize(64u32); + let num_rounds = rounds_to_usize(1u8); + let masks = gen_sign_masks_from_seed(seed, padded_dim, num_rounds); + 
assert_eq!(masks.len(), padded_dim); - let mut rng = SplitMix64::new(42); + let mut rng = SplitMix64::new(seed); let word = rng.next_u64(); - let expected: Vec<_> = (0..64) + let expected: Vec<_> = (0..padded_dim) .map(|bit_idx| sign_mask_from_word(word, bit_idx)) .collect(); assert_eq!(masks, expected); } + #[test] + fn accepts_non_power_of_two_dimensions() -> VortexResult<()> { + let rot = SorfMatrix::try_new(42u64, dim_to_usize(100u32), rounds_to_usize(3u8))?; + assert_eq!(rot.padded_dim(), 128); + Ok(()) + } + #[test] fn tail_block_uses_only_required_bits() { - let masks = gen_sign_masks_from_seed(42, 32, 1); - assert_eq!(masks.len(), 32); + let seed = 42u64; + let padded_dim = dim_to_usize(32u32); + let num_rounds = rounds_to_usize(1u8); + let masks = gen_sign_masks_from_seed(seed, padded_dim, num_rounds); + assert_eq!(masks.len(), padded_dim); - let mut rng = SplitMix64::new(42); + let mut rng = SplitMix64::new(seed); let word = rng.next_u64(); - let expected: Vec<_> = (0..32) + let expected: Vec<_> = (0..padded_dim) .map(|bit_idx| sign_mask_from_word(word, bit_idx)) .collect(); assert_eq!(masks, expected); @@ -392,19 +434,21 @@ mod tests { /// Verify roundtrip is exact to f32 precision across many dimensions and round counts, /// including non-power-of-two dimensions that require padding. 
#[rstest] - #[case(32, 3)] - #[case(64, 3)] - #[case(100, 3)] - #[case(128, 1)] - #[case(128, 2)] - #[case(128, 3)] - #[case(128, 5)] - #[case(256, 3)] - #[case(512, 3)] - #[case(768, 3)] - #[case(1024, 3)] - fn roundtrip_exact(#[case] dim: usize, #[case] num_rounds: usize) -> VortexResult<()> { - let rot = SorfMatrix::try_new(42, dim, num_rounds)?; + #[case(32u32, 3u8)] + #[case(64u32, 3u8)] + #[case(100u32, 3u8)] + #[case(128u32, 1u8)] + #[case(128u32, 2u8)] + #[case(128u32, 3u8)] + #[case(128u32, 5u8)] + #[case(256u32, 3u8)] + #[case(512u32, 3u8)] + #[case(768u32, 3u8)] + #[case(1024u32, 3u8)] + fn roundtrip_exact(#[case] dim: u32, #[case] num_rounds: u8) -> VortexResult<()> { + let dim = dim_to_usize(dim); + let num_rounds = rounds_to_usize(num_rounds); + let rot = SorfMatrix::try_new(42u64, dim, num_rounds)?; let padded_dim = rot.padded_dim(); let mut input = vec![0.0f32; padded_dim]; @@ -435,12 +479,14 @@ mod tests { /// Verify norm preservation across dimensions and round counts. #[rstest] - #[case(128, 1)] - #[case(128, 3)] - #[case(128, 5)] - #[case(768, 3)] - fn preserves_norm(#[case] dim: usize, #[case] num_rounds: usize) -> VortexResult<()> { - let rot = SorfMatrix::try_new(7, dim, num_rounds)?; + #[case(128u32, 1u8)] + #[case(128u32, 3u8)] + #[case(128u32, 5u8)] + #[case(768u32, 3u8)] + fn preserves_norm(#[case] dim: u32, #[case] num_rounds: u8) -> VortexResult<()> { + let dim = dim_to_usize(dim); + let num_rounds = rounds_to_usize(num_rounds); + let rot = SorfMatrix::try_new(42u64, dim, num_rounds)?; let padded_dim = rot.padded_dim(); let mut input = vec![0.0f32; padded_dim]; @@ -465,16 +511,15 @@ mod tests { /// Verify that export -> [`from_u8_slice`] produces identical transform output. 
#[rstest] - #[case(64, 3)] - #[case(128, 1)] - #[case(128, 3)] - #[case(128, 5)] - #[case(768, 3)] - fn sign_export_import_roundtrip( - #[case] dim: usize, - #[case] num_rounds: usize, - ) -> VortexResult<()> { - let rot = SorfMatrix::try_new(42, dim, num_rounds)?; + #[case(64u32, 3u8)] + #[case(128u32, 1u8)] + #[case(128u32, 3u8)] + #[case(128u32, 5u8)] + #[case(768u32, 3u8)] + fn sign_export_import_roundtrip(#[case] dim: u32, #[case] num_rounds: u8) -> VortexResult<()> { + let dim = dim_to_usize(dim); + let num_rounds = rounds_to_usize(num_rounds); + let rot = SorfMatrix::try_new(42u64, dim, num_rounds)?; let padded_dim = rot.padded_dim(); let signs_u8 = rot.export_inverse_signs_u8(); diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs b/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs index 46abc66db71..c3aaf8d3cd5 100644 --- a/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs +++ b/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs @@ -65,8 +65,8 @@ fn forward_rotate_and_quantize( } } - let rotation = SorfMatrix::try_new(seed, dim, num_rounds)?; - let padded_dim = rotation.padded_dim(); + let padded_dim = dim.next_power_of_two(); + let rotation = SorfMatrix::try_new_padded(padded_dim, num_rounds, seed)?; let centroids = compute_or_get_centroids(padded_dim as u32, bit_width)?; let boundaries = compute_centroid_boundaries(¢roids); diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs b/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs index 827f8e6a796..0932c4cc498 100644 --- a/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs +++ b/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs @@ -164,7 +164,8 @@ impl ScalarFnVTable for SorfTransform { let f32_elements = elements_prim.into_buffer::(); // Reconstruct the orthogonal transform matrix from the seed. 
- let rotation = SorfMatrix::try_new(options.seed, dim, options.num_rounds as usize)?; + let rotation = + SorfMatrix::try_new_padded(padded_dim, options.num_rounds as usize, options.seed)?; // Inverse transform each row, truncate to original dimension, cast to target type. match_each_float_ptype!(options.element_ptype, |T| { diff --git a/vortex-tensor/src/types/vector/mod.rs b/vortex-tensor/src/types/vector/mod.rs index 3763220fc82..a66aaf76982 100644 --- a/vortex-tensor/src/types/vector/mod.rs +++ b/vortex-tensor/src/types/vector/mod.rs @@ -49,7 +49,7 @@ impl Vector { /// # Errors /// /// Returns an error if the [`Vector`] extension dtype rejects the storage array. - pub(crate) fn try_new_vector_array(storage: ArrayRef) -> VortexResult { + pub fn try_new_vector_array(storage: ArrayRef) -> VortexResult { ExtensionArray::try_new_from_vtable(Vector, EmptyMetadata, storage) .map(|ext| ext.into_array()) } diff --git a/vortex-turboquant/Cargo.toml b/vortex-turboquant/Cargo.toml new file mode 100644 index 00000000000..ab3f63583d3 --- /dev/null +++ b/vortex-turboquant/Cargo.toml @@ -0,0 +1,41 @@ +[package] +name = "vortex-turboquant" +authors = { workspace = true } +categories = { workspace = true } +description = "TurboQuant vector extension type" +edition = { workspace = true } +homepage = { workspace = true } +include = { workspace = true } +keywords = { workspace = true } +license = { workspace = true } +readme = { workspace = true } +repository = { workspace = true } +rust-version = { workspace = true } +version = { workspace = true } + +[lints] +workspace = true + +[dependencies] +half = { workspace = true } +num-traits = { workspace = true } +prost = { workspace = true } +vortex-array = { workspace = true } +vortex-buffer = { workspace = true } +vortex-error = { workspace = true } +vortex-mask = { workspace = true } +vortex-session = { workspace = true } +vortex-tensor = { workspace = true } +vortex-utils = { workspace = true, features = ["dashmap"] } + 
+[dev-dependencies] +divan = { workspace = true } +rand = { workspace = true } +rstest = { workspace = true } +vortex-file = { workspace = true } +vortex-io = { workspace = true } +vortex-layout = { workspace = true } + +[[bench]] +name = "encode_decode" +harness = false diff --git a/vortex-turboquant/benches/encode_decode.rs b/vortex-turboquant/benches/encode_decode.rs new file mode 100644 index 00000000000..f88c37c347a --- /dev/null +++ b/vortex-turboquant/benches/encode_decode.rs @@ -0,0 +1,147 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Benchmarks for `turboquant_encode` and `turboquant_decode` across different validity-mask +//! shapes. +//! +//! The four mask shapes (`AllTrue`, `AllFalse`, dense `Values`, sparse `Values`) exercise the +//! variant-specialized paths added in the mask refactor in `vector/normalize.rs`, +//! `vector/quantize.rs`, and `scalar_fns/decode.rs`. + +#![expect(clippy::unwrap_used)] + +use std::sync::LazyLock; + +use divan::Bencher; +use rand::RngExt; +use rand::SeedableRng as _; +use rand::rngs::StdRng; +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::VortexSessionExecute; +use vortex_array::arrays::ExtensionArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::extension::EmptyMetadata; +use vortex_array::session::ArraySession; +use vortex_array::validity::Validity; +use vortex_buffer::Buffer; +use vortex_session::VortexSession; +use vortex_tensor::vector::Vector; +use vortex_turboquant::TQDecode; +use vortex_turboquant::TQEncode; +use vortex_turboquant::TurboQuantConfig; + +fn main() { + divan::main(); +} + +static SESSION: LazyLock = LazyLock::new(|| { + let session = VortexSession::empty().with::(); + vortex_turboquant::initialize(&session); + session +}); + +/// Shape of the validity mask used to drive the variant-specialized paths. 
+#[derive(Copy, Clone)] +enum MaskShape { + AllValid, + AllInvalid, + DenseValues, + SparseValues, +} + +impl std::fmt::Debug for MaskShape { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(match self { + MaskShape::AllValid => "all_valid", + MaskShape::AllInvalid => "all_invalid", + MaskShape::DenseValues => "dense_95pct", + MaskShape::SparseValues => "sparse_5pct", + }) + } +} + +impl MaskShape { + fn build(self, rows: usize, rng: &mut StdRng) -> Validity { + match self { + MaskShape::AllValid => Validity::NonNullable, + MaskShape::AllInvalid => Validity::AllInvalid, + MaskShape::DenseValues => Validity::from_iter((0..rows).map(|_| rng.random_bool(0.95))), + MaskShape::SparseValues => { + Validity::from_iter((0..rows).map(|_| rng.random_bool(0.05))) + } + } + } +} + +const MASK_SHAPES: &[MaskShape] = &[ + MaskShape::AllValid, + MaskShape::AllInvalid, + MaskShape::DenseValues, + MaskShape::SparseValues, +]; + +const ROWS: usize = 4096; +const DIMENSIONS: u32 = 128; + +fn build_vector_array(shape: MaskShape) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(0xC0FFEE); + let dim = DIMENSIONS as usize; + let values: Buffer = (0..ROWS * dim).map(|_| rng.random::()).collect(); + let elements = PrimitiveArray::new::(values, Validity::NonNullable); + let validity = shape.build(ROWS, &mut rng); + let fsl = + FixedSizeListArray::try_new(elements.into_array(), DIMENSIONS, validity, ROWS).unwrap(); + + ExtensionArray::try_new_from_vtable(Vector, EmptyMetadata, fsl.into_array()) + .unwrap() + .into_array() +} + +fn encode(vec: ArrayRef, config: &TurboQuantConfig, ctx: &mut ExecutionCtx) -> ArrayRef { + TQEncode::try_new_array(vec, config) + .unwrap() + .into_array() + .execute(ctx) + .unwrap() +} + +fn decode(encoded: ArrayRef, ctx: &mut ExecutionCtx) -> ArrayRef { + TQDecode::try_new_array(encoded) + .unwrap() + .into_array() + .execute(ctx) + .unwrap() +} + +fn config() -> TurboQuantConfig { + // 4 bits, 4 SORF rounds, fixed seed: 
representative defaults from the test fixtures. + TurboQuantConfig::try_new(4, 0xDEADBEEF, 4).unwrap() +} + +#[divan::bench(args = MASK_SHAPES)] +fn turboquant_encode(bencher: Bencher, shape: &MaskShape) { + let shape = *shape; + let cfg = config(); + bencher + .with_inputs(|| (build_vector_array(shape), SESSION.create_execution_ctx())) + .input_counter(|_| divan::counter::ItemsCount::new(ROWS)) + .bench_values(|(arr, mut ctx)| encode(arr, &cfg, &mut ctx)) +} + +#[divan::bench(args = MASK_SHAPES)] +fn turboquant_decode(bencher: Bencher, shape: &MaskShape) { + let shape = *shape; + let cfg = config(); + bencher + .with_inputs(|| { + let arr = build_vector_array(shape); + let mut ctx = SESSION.create_execution_ctx(); + let encoded = encode(arr, &cfg, &mut ctx); + (encoded, SESSION.create_execution_ctx()) + }) + .input_counter(|_| divan::counter::ItemsCount::new(ROWS)) + .bench_values(|(encoded, mut ctx)| decode(encoded, &mut ctx)) +} diff --git a/vortex-turboquant/public-api.lock b/vortex-turboquant/public-api.lock new file mode 100644 index 00000000000..ac6b8eeac5b --- /dev/null +++ b/vortex-turboquant/public-api.lock @@ -0,0 +1,199 @@ +pub mod vortex_turboquant + +pub struct vortex_turboquant::TQDecode + +impl vortex_turboquant::TQDecode + +pub fn vortex_turboquant::TQDecode::new() -> vortex_array::scalar_fn::typed::TypedScalarFnInstance + +pub fn vortex_turboquant::TQDecode::try_new_array(vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult + +impl core::clone::Clone for vortex_turboquant::TQDecode + +pub fn vortex_turboquant::TQDecode::clone(&self) -> vortex_turboquant::TQDecode + +impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_turboquant::TQDecode + +pub type vortex_turboquant::TQDecode::Options = vortex_array::extension::EmptyMetadata + +pub fn vortex_turboquant::TQDecode::arity(&self, &Self::Options) -> vortex_array::scalar_fn::vtable::Arity + +pub fn vortex_turboquant::TQDecode::child_name(&self, &Self::Options, usize) -> 
vortex_array::scalar_fn::vtable::ChildName + +pub fn vortex_turboquant::TQDecode::deserialize(&self, &[u8], &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_turboquant::TQDecode::execute(&self, &Self::Options, &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult + +pub fn vortex_turboquant::TQDecode::fmt_sql(&self, &Self::Options, &vortex_array::expr::expression::Expression, &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +pub fn vortex_turboquant::TQDecode::id(&self) -> vortex_array::scalar_fn::ScalarFnId + +pub fn vortex_turboquant::TQDecode::is_fallible(&self, &Self::Options) -> bool + +pub fn vortex_turboquant::TQDecode::is_null_sensitive(&self, &Self::Options) -> bool + +pub fn vortex_turboquant::TQDecode::return_dtype(&self, &Self::Options, &[vortex_array::dtype::DType]) -> vortex_error::VortexResult + +pub fn vortex_turboquant::TQDecode::serialize(&self, &Self::Options) -> vortex_error::VortexResult>> + +pub fn vortex_turboquant::TQDecode::validity(&self, &Self::Options, &vortex_array::expr::expression::Expression) -> vortex_error::VortexResult> + +pub struct vortex_turboquant::TQEncode + +impl vortex_turboquant::TQEncode + +pub fn vortex_turboquant::TQEncode::new(&vortex_turboquant::TurboQuantConfig) -> vortex_array::scalar_fn::typed::TypedScalarFnInstance + +pub fn vortex_turboquant::TQEncode::try_new_array(vortex_array::array::erased::ArrayRef, &vortex_turboquant::TurboQuantConfig) -> vortex_error::VortexResult + +impl core::clone::Clone for vortex_turboquant::TQEncode + +pub fn vortex_turboquant::TQEncode::clone(&self) -> vortex_turboquant::TQEncode + +impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_turboquant::TQEncode + +pub type vortex_turboquant::TQEncode::Options = vortex_turboquant::TurboQuantConfig + +pub fn vortex_turboquant::TQEncode::arity(&self, &Self::Options) -> vortex_array::scalar_fn::vtable::Arity + +pub fn 
vortex_turboquant::TQEncode::child_name(&self, &Self::Options, usize) -> vortex_array::scalar_fn::vtable::ChildName + +pub fn vortex_turboquant::TQEncode::deserialize(&self, &[u8], &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_turboquant::TQEncode::execute(&self, &Self::Options, &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult + +pub fn vortex_turboquant::TQEncode::fmt_sql(&self, &Self::Options, &vortex_array::expr::expression::Expression, &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +pub fn vortex_turboquant::TQEncode::id(&self) -> vortex_array::scalar_fn::ScalarFnId + +pub fn vortex_turboquant::TQEncode::is_fallible(&self, &Self::Options) -> bool + +pub fn vortex_turboquant::TQEncode::is_null_sensitive(&self, &Self::Options) -> bool + +pub fn vortex_turboquant::TQEncode::return_dtype(&self, &Self::Options, &[vortex_array::dtype::DType]) -> vortex_error::VortexResult + +pub fn vortex_turboquant::TQEncode::serialize(&self, &Self::Options) -> vortex_error::VortexResult>> + +pub fn vortex_turboquant::TQEncode::validity(&self, &Self::Options, &vortex_array::expr::expression::Expression) -> vortex_error::VortexResult> + +pub struct vortex_turboquant::TurboQuant + +impl core::clone::Clone for vortex_turboquant::TurboQuant + +pub fn vortex_turboquant::TurboQuant::clone(&self) -> vortex_turboquant::TurboQuant + +impl core::cmp::Eq for vortex_turboquant::TurboQuant + +impl core::cmp::PartialEq for vortex_turboquant::TurboQuant + +pub fn vortex_turboquant::TurboQuant::eq(&self, &vortex_turboquant::TurboQuant) -> bool + +impl core::default::Default for vortex_turboquant::TurboQuant + +pub fn vortex_turboquant::TurboQuant::default() -> vortex_turboquant::TurboQuant + +impl core::fmt::Debug for vortex_turboquant::TurboQuant + +pub fn vortex_turboquant::TurboQuant::fmt(&self, &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::hash::Hash for 
vortex_turboquant::TurboQuant + +pub fn vortex_turboquant::TurboQuant::hash<__H: core::hash::Hasher>(&self, &mut __H) + +impl core::marker::StructuralPartialEq for vortex_turboquant::TurboQuant + +impl vortex_array::dtype::extension::vtable::ExtVTable for vortex_turboquant::TurboQuant + +pub type vortex_turboquant::TurboQuant::Metadata = vortex_turboquant::TurboQuantMetadata + +pub type vortex_turboquant::TurboQuant::NativeValue<'a> = &'a vortex_array::scalar::scalar_value::ScalarValue + +pub fn vortex_turboquant::TurboQuant::deserialize_metadata(&self, &[u8]) -> vortex_error::VortexResult + +pub fn vortex_turboquant::TurboQuant::id(&self) -> vortex_array::dtype::extension::ExtId + +pub fn vortex_turboquant::TurboQuant::serialize_metadata(&self, &Self::Metadata) -> vortex_error::VortexResult> + +pub fn vortex_turboquant::TurboQuant::unpack_native<'a>(&'a vortex_array::dtype::extension::typed::ExtDType, &'a vortex_array::scalar::scalar_value::ScalarValue) -> vortex_error::VortexResult + +pub fn vortex_turboquant::TurboQuant::validate_dtype(&vortex_array::dtype::extension::typed::ExtDType) -> vortex_error::VortexResult<()> + +pub struct vortex_turboquant::TurboQuantConfig + +impl vortex_turboquant::TurboQuantConfig + +pub fn vortex_turboquant::TurboQuantConfig::bit_width(&self) -> u8 + +pub fn vortex_turboquant::TurboQuantConfig::num_rounds(&self) -> u8 + +pub fn vortex_turboquant::TurboQuantConfig::seed(&self) -> u64 + +pub fn vortex_turboquant::TurboQuantConfig::try_new(u8, u64, u8) -> vortex_error::VortexResult + +impl core::clone::Clone for vortex_turboquant::TurboQuantConfig + +pub fn vortex_turboquant::TurboQuantConfig::clone(&self) -> vortex_turboquant::TurboQuantConfig + +impl core::cmp::Eq for vortex_turboquant::TurboQuantConfig + +impl core::cmp::PartialEq for vortex_turboquant::TurboQuantConfig + +pub fn vortex_turboquant::TurboQuantConfig::eq(&self, &vortex_turboquant::TurboQuantConfig) -> bool + +impl core::default::Default for 
vortex_turboquant::TurboQuantConfig + +pub fn vortex_turboquant::TurboQuantConfig::default() -> Self + +impl core::fmt::Debug for vortex_turboquant::TurboQuantConfig + +pub fn vortex_turboquant::TurboQuantConfig::fmt(&self, &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::fmt::Display for vortex_turboquant::TurboQuantConfig + +pub fn vortex_turboquant::TurboQuantConfig::fmt(&self, &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::hash::Hash for vortex_turboquant::TurboQuantConfig + +pub fn vortex_turboquant::TurboQuantConfig::hash<__H: core::hash::Hasher>(&self, &mut __H) + +impl core::marker::StructuralPartialEq for vortex_turboquant::TurboQuantConfig + +pub struct vortex_turboquant::TurboQuantMetadata + +pub vortex_turboquant::TurboQuantMetadata::bit_width: u8 + +pub vortex_turboquant::TurboQuantMetadata::dimensions: u32 + +pub vortex_turboquant::TurboQuantMetadata::element_ptype: vortex_array::dtype::ptype::PType + +pub vortex_turboquant::TurboQuantMetadata::num_rounds: u8 + +pub vortex_turboquant::TurboQuantMetadata::seed: u64 + +impl core::clone::Clone for vortex_turboquant::TurboQuantMetadata + +pub fn vortex_turboquant::TurboQuantMetadata::clone(&self) -> vortex_turboquant::TurboQuantMetadata + +impl core::cmp::Eq for vortex_turboquant::TurboQuantMetadata + +impl core::cmp::PartialEq for vortex_turboquant::TurboQuantMetadata + +pub fn vortex_turboquant::TurboQuantMetadata::eq(&self, &vortex_turboquant::TurboQuantMetadata) -> bool + +impl core::fmt::Debug for vortex_turboquant::TurboQuantMetadata + +pub fn vortex_turboquant::TurboQuantMetadata::fmt(&self, &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::fmt::Display for vortex_turboquant::TurboQuantMetadata + +pub fn vortex_turboquant::TurboQuantMetadata::fmt(&self, &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::hash::Hash for vortex_turboquant::TurboQuantMetadata + +pub fn vortex_turboquant::TurboQuantMetadata::hash<__H: 
core::hash::Hasher>(&self, &mut __H) + +impl core::marker::Copy for vortex_turboquant::TurboQuantMetadata + +impl core::marker::StructuralPartialEq for vortex_turboquant::TurboQuantMetadata + +pub fn vortex_turboquant::initialize(&vortex_session::VortexSession) diff --git a/vortex-turboquant/src/centroids.rs b/vortex-turboquant/src/centroids.rs new file mode 100644 index 00000000000..8499dfc397a --- /dev/null +++ b/vortex-turboquant/src/centroids.rs @@ -0,0 +1,346 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Max-Lloyd centroid computation for TurboQuant scalar quantizers. +//! +//! Pre-computes and caches optimal scalar quantizer centroids for the marginal distribution of +//! coordinates after a random orthogonal transform of a unit-norm vector. +//! +//! In high dimensions, each coordinate of a randomly transformed unit vector follows a +//! distribution proportional to `(1 - x^2)^((d-3)/2)` on `[-1, 1]`, which converges to +//! `N(0, 1/d)`. +//! +//! The Max-Lloyd algorithm finds optimal quantization centroids that minimize MSE for this +//! distribution. +//! +//! Centroids are not stored in TurboQuant arrays. They are deterministically derived from +//! `(padded_dim, bit_width)` and cached process-locally. +//! +//! The centroid model follows the random orthogonal transform marginal used by the TurboQuant +//! paper. This encoder applies a SORF-style structured transform instead of a dense random Gaussian +//! or orthogonal matrix, so paper-level error bounds should not be treated as verified for this +//! implementation without separate empirical validation. + +use std::sync::LazyLock; + +use vortex_buffer::Buffer; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; +use vortex_utils::aliases::dash_map::DashMap; + +use crate::config::MAX_BIT_WIDTH; +use crate::config::MIN_DIMENSION; + +// NB: Some of these numbers were arbitrarily chosen... 
+ +/// The maximum iterations for Max-Lloyd algorithm when computing centroids. +const MAX_ITERATIONS: usize = 200; + +/// The Max-Lloyd convergence threshold for stopping early when computing centroids. +const CONVERGENCE_EPSILON: f64 = 1e-12; + +/// Number of numerical integration points for computing conditional expectations. +const INTEGRATION_POINTS: usize = 1000; + +/// Global centroid cache keyed by (dimension, bit_width). +static CENTROID_CACHE: LazyLock>> = LazyLock::new(DashMap::default); + +/// Get or compute cached centroids for the given dimension and bit width. +/// +/// Returns `2^bit_width` centroids sorted in ascending order, representing optimal scalar +/// quantization levels for the coordinate distribution after a random orthogonal transform in +/// `dimension`-dimensional space. +pub(crate) fn compute_or_get_centroids(dimension: u32, bit_width: u8) -> VortexResult> { + vortex_ensure!( + (1..=MAX_BIT_WIDTH).contains(&bit_width), + "TurboQuant bit_width must be 1-{}, got {bit_width}", + MAX_BIT_WIDTH + ); + vortex_ensure!( + dimension >= MIN_DIMENSION, + "TurboQuant dimension must be >= {}, got {dimension}", + MIN_DIMENSION + ); + + if let Some(centroids) = CENTROID_CACHE.get(&(dimension, bit_width)) { + return Ok(centroids.clone()); + } + + let centroids = max_lloyd_centroids(dimension, bit_width); + CENTROID_CACHE.insert((dimension, bit_width), centroids.clone()); + + Ok(centroids) +} + +// TODO(connor): It would potentially be more performant if this was modelled as const generic +// parameters to functions. +/// Half-integer exponent: represents `int_part + (if has_half { 0.5 } else { 0.0 })`. +/// +/// The marginal distribution exponent `(d-3)/2` is always an integer (when `d` is odd) or a +/// half-integer (when `d` is even). +/// +/// This type makes that invariant explicit and avoids floating-point comparison in the hot path. 
+#[derive(Clone, Copy, Debug)] +struct HalfIntExponent { + int_part: i32, + has_half: bool, +} + +impl HalfIntExponent { + /// Compute `(numerator) / 2` as a half-integer exponent. + /// + /// `numerator` is `d - 3` where `d` is the dimension (>= 2), so it can be negative. + fn from_numerator(numerator: i32) -> Self { + // Use Euclidean division to get floor division toward negative infinity. + let int_part = numerator.div_euclid(2); + let has_half = numerator.rem_euclid(2) != 0; + Self { int_part, has_half } + } +} + +/// Compute optimal centroids via the Max-Lloyd (Lloyd-Max) algorithm. +/// +/// Operates on the marginal distribution of a single coordinate of a randomly transformed unit +/// vector in d dimensions. +/// +/// The probability distribution function is: +/// `f(x) = C_d * (1 - x^2)^((d-3)/2)` on `[-1, 1]` +/// where `C_d` is the normalizing constant. +fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Buffer { + debug_assert!((1..=MAX_BIT_WIDTH).contains(&bit_width)); + let num_centroids = 1usize << bit_width; + + // For the marginal distribution on [-1, 1], we use the exponent (d-3)/2. + let exponent = HalfIntExponent::from_numerator(dimension as i32 - 3); + + // Initialize centroids uniformly on [-1, 1]. + let mut centroids: Vec = (0..num_centroids) + .map(|idx| -1.0 + (2.0 * (idx as f64) + 1.0) / (num_centroids as f64)) + .collect(); + + let mut boundaries: Vec = vec![0.0; num_centroids + 1]; + for _ in 0..MAX_ITERATIONS { + // Compute decision boundaries (midpoints between adjacent centroids). + boundaries[0] = -1.0; + for idx in 0..num_centroids - 1 { + boundaries[idx + 1] = (centroids[idx] + centroids[idx + 1]) / 2.0; + } + boundaries[num_centroids] = 1.0; + + // Update each centroid to the conditional mean within its Voronoi cell. 
+ let mut max_change = 0.0f64; + for idx in 0..num_centroids { + let lo = boundaries[idx]; + let hi = boundaries[idx + 1]; + let new_centroid = mean_between_centroids(lo, hi, exponent); + max_change = max_change.max((new_centroid - centroids[idx]).abs()); + centroids[idx] = new_centroid; + } + + if max_change < CONVERGENCE_EPSILON { + break; + } + } + + #[expect( + clippy::cast_possible_truncation, + reason = "all values are in [-1, 1] so this just loses precision" + )] + centroids.into_iter().map(|val| val as f32).collect() +} + +/// Compute the conditional mean of the coordinate distribution on interval [lo, hi]. +/// +/// Returns `E[X | lo <= X <= hi]` where X has PDF proportional to `(1 - x^2)^exponent` on [-1, 1]. +/// +/// Since there is no closed form for the integrals, we compute this numerically. +fn mean_between_centroids(lo: f64, hi: f64, exponent: HalfIntExponent) -> f64 { + if (hi - lo).abs() < 1e-15 { + return (lo + hi) / 2.0; + } + + let dx = (hi - lo) / INTEGRATION_POINTS as f64; + + let mut numerator = 0.0; + let mut denominator = 0.0; + + for step in 0..=INTEGRATION_POINTS { + let x_val = lo + (step as f64) * dx; + let weight = pdf_unnormalized(x_val, exponent); + + let trap_weight = if step == 0 || step == INTEGRATION_POINTS { + 0.5 + } else { + 1.0 + }; + + numerator += trap_weight * x_val * weight; + denominator += trap_weight * weight; + } + + if denominator.abs() < 1e-30 { + (lo + hi) / 2.0 + } else { + numerator / denominator + } +} + +/// Unnormalized PDF of the coordinate distribution: `(1 - x^2)^exponent`. +/// +/// Uses `powi` + `sqrt` instead of `powf` for the half-integer exponents that arise from `(d-3)/2`. +/// This is significantly faster than the general `powf` which goes through +/// `exp(exponent * ln(base))`. +fn pdf_unnormalized(x_val: f64, exponent: HalfIntExponent) -> f64 { + let base = (1.0 - x_val * x_val).max(0.0); + + if exponent.has_half { + // Half-integer exponent: base^(int_part) * sqrt(base). 
+ base.powi(exponent.int_part) * base.sqrt() + } else { + // Integer exponent: use powi directly. + base.powi(exponent.int_part) + } +} + +/// Precompute decision boundaries (midpoints between adjacent centroids). +/// +/// For `k` centroids, returns `k-1` boundaries. A value below `boundaries[0]` maps to centroid 0, a +/// value in `[boundaries[i-1], boundaries[i])` maps to centroid `i`, and a +/// value `>= boundaries[k-2]` maps to centroid `k-1`. +pub(crate) fn compute_centroid_boundaries(centroids: &[f32]) -> Vec { + centroids.windows(2).map(|w| (w[0] + w[1]) * 0.5).collect() +} + +/// Find the index of the nearest centroid using precomputed decision boundaries. +/// +/// `boundaries` must be the output of [`compute_centroid_boundaries`] for the corresponding +/// centroids. Uses binary search on the midpoints, avoiding distance comparisons +/// in the inner loop. +#[inline] +pub(crate) fn find_nearest_centroid(value: f32, boundaries: &[f32]) -> u8 { + debug_assert!( + boundaries.windows(2).all(|w| w[0] <= w[1]), + "boundaries must be sorted" + ); + debug_assert!( + boundaries.len() <= 256, // 1 << 8 + "too many boundaries" + ); + + #[expect( + clippy::cast_possible_truncation, + reason = "num_centroids <= 256 and partition_point will return at most 255" + )] + (boundaries.partition_point(|&b| b < value) as u8) +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + use vortex_error::VortexResult; + + use super::*; + + #[rstest] + #[case(128, 1, 2)] + #[case(128, 2, 4)] + #[case(128, 3, 8)] + #[case(128, 4, 16)] + #[case(768, 2, 4)] + #[case(1536, 3, 8)] + fn centroids_have_correct_count( + #[case] dim: u32, + #[case] bits: u8, + #[case] expected: usize, + ) -> VortexResult<()> { + let centroids = compute_or_get_centroids(dim, bits)?; + assert_eq!(centroids.len(), expected); + Ok(()) + } + + #[rstest] + #[case(128, 1)] + #[case(128, 2)] + #[case(128, 3)] + #[case(128, 4)] + #[case(768, 2)] + fn centroids_are_sorted(#[case] dim: u32, #[case] bits: u8) -> 
VortexResult<()> { + let centroids = compute_or_get_centroids(dim, bits)?; + for window in centroids.windows(2) { + assert!( + window[0] < window[1], + "centroids not sorted: {:?}", + centroids + ); + } + Ok(()) + } + + #[rstest] + #[case(128, 1)] + #[case(128, 2)] + #[case(256, 2)] + #[case(768, 2)] + fn centroids_are_symmetric(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> { + let centroids = compute_or_get_centroids(dim, bits)?; + let count = centroids.len(); + for idx in 0..count / 2 { + let diff = (centroids[idx] + centroids[count - 1 - idx]).abs(); + assert!( + diff < 1e-5, + "centroids not symmetric: c[{idx}]={}, c[{}]={}", + centroids[idx], + count - 1 - idx, + centroids[count - 1 - idx] + ); + } + Ok(()) + } + + #[rstest] + #[case(128, 1)] + #[case(128, 4)] + fn centroids_within_bounds(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> { + let centroids = compute_or_get_centroids(dim, bits)?; + for &val in centroids.iter() { + assert!( + (-1.0..=1.0).contains(&val), + "centroid out of [-1, 1]: {val}", + ); + } + Ok(()) + } + + #[test] + fn centroids_cached() -> VortexResult<()> { + let c1 = compute_or_get_centroids(128, 2)?; + let c2 = compute_or_get_centroids(128, 2)?; + assert_eq!(c1, c2); + Ok(()) + } + + #[test] + fn find_nearest_basic() -> VortexResult<()> { + let centroids = compute_or_get_centroids(128, 2)?; + let boundaries = compute_centroid_boundaries(¢roids); + assert_eq!(find_nearest_centroid(-1.0, &boundaries), 0); + + #[expect(clippy::cast_possible_truncation)] + let last_idx = (centroids.len() - 1) as u8; + assert_eq!(find_nearest_centroid(1.0, &boundaries), last_idx); + for (idx, &cv) in centroids.iter().enumerate() { + #[expect(clippy::cast_possible_truncation)] + let expected = idx as u8; + assert_eq!(find_nearest_centroid(cv, &boundaries), expected); + } + Ok(()) + } + + #[test] + fn rejects_invalid_params() { + assert!(compute_or_get_centroids(128, 0).is_err()); + assert!(compute_or_get_centroids(128, 9).is_err()); + 
assert!(compute_or_get_centroids(1, 2).is_err()); + assert!(compute_or_get_centroids(127, 2).is_err()); + } +} diff --git a/vortex-turboquant/src/config.rs b/vortex-turboquant/src/config.rs new file mode 100644 index 00000000000..57cd8b1e94b --- /dev/null +++ b/vortex-turboquant/src/config.rs @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::fmt; + +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; + +/// Minimum vector dimension for TurboQuant encoding. +/// +/// Note that this is not a theoretical minimum, it is mostly a practical one to limit the total +/// amount of distortion. +pub(crate) const MIN_DIMENSION: u32 = 128; + +/// Maximum supported number of bits per quantized coordinate. +pub(crate) const MAX_BIT_WIDTH: u8 = 8; + +/// Configuration for lossy TurboQuant encoding. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct TurboQuantConfig { + bit_width: u8, + seed: u64, + num_rounds: u8, +} + +impl TurboQuantConfig { + /// Build a TurboQuant configuration. + /// + /// # Errors + /// + /// Returns an error if `bit_width` is outside `1..=8` or `num_rounds` is zero. + pub fn try_new(bit_width: u8, seed: u64, num_rounds: u8) -> VortexResult { + vortex_ensure!( + (1..=MAX_BIT_WIDTH).contains(&bit_width), + "TurboQuant bit_width must be 1-{MAX_BIT_WIDTH}, got {bit_width}", + ); + vortex_ensure!( + num_rounds > 0, + "TurboQuant num_rounds must be > 0, got {num_rounds}" + ); + + Ok(Self { + bit_width, + seed, + num_rounds, + }) + } + + /// Bits per coordinate in the scalar quantizer codebook. + pub fn bit_width(&self) -> u8 { + self.bit_width + } + + /// Seed used to derive the deterministic SORF transform. + pub fn seed(&self) -> u64 { + self.seed + } + + /// Number of sign-diagonal plus Walsh-Hadamard rounds in the SORF transform. 
+ pub fn num_rounds(&self) -> u8 { + self.num_rounds + } +} + +impl Default for TurboQuantConfig { + /// Defaults to 8 bits per coordinate, seed 42, and 3 SORF rounds. + fn default() -> Self { + Self { + bit_width: MAX_BIT_WIDTH, + seed: 42, + num_rounds: 3, + } + } +} + +impl fmt::Display for TurboQuantConfig { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "bit_width: {}, seed: {}, num_rounds: {}", + self.bit_width, self.seed, self.num_rounds + ) + } +} diff --git a/vortex-turboquant/src/lib.rs b/vortex-turboquant/src/lib.rs new file mode 100644 index 00000000000..7aeb60368dd --- /dev/null +++ b/vortex-turboquant/src/lib.rs @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant vector quantization extension type for Vortex. +//! +//! Implements a Stage 1 TurboQuant encoding ([arXiv:2504.19874], [RFC 0033]) for lossy compression +//! of high-dimensional vector data. The extension operates on +//! [`Vector`](vortex_tensor::vector::Vector) extension arrays, encoding their `FixedSizeList` +//! storage into quantized codes after a structured orthogonal surrogate transform. +//! +//! [arXiv:2504.19874]: https://arxiv.org/abs/2504.19874 +//! [RFC 0033]: https://vortex-data.github.io/rfcs/rfc/0033.html +//! +//! # Overview +//! +//! TurboQuant minimizes mean-squared reconstruction error (1-8 bits per coordinate) +//! using MSE-optimal scalar quantization on coordinates of a transformed unit vector. +//! +//! The [`TQEncode`] scalar function first computes and stores the original L2 norm for each vector +//! row, then normalizes each valid nonzero row internally before SORF transform and scalar +//! quantization. The [`TQDecode`] scalar function dequantizes through deterministic centroids, +//! applies the inverse SORF transform, truncates back to the original dimension, and re-applies the +//! stored norm. +//! +//! 
The encoded storage is a row-aligned extension tree: +//! +//! ```text +//! Extension( +//! Struct { +//! norms: Primitive, +//! codes: FixedSizeList, padded_dim, vector_validity>, +//! } +//! ) +//! ``` +//! +//! Stored norms are authoritative for future TurboQuant-aware scalar functions. Decoded quantized +//! directions are not guaranteed to have unit norm after scalar quantization and inverse transform. +//! +//! # Source map +//! +//! Implementation details are documented next to the code that owns them: +//! +//! - `vector/storage.rs`: physical storage shape, full-length child arrays, and field-level +//! validity for null vectors. +//! - `vector/normalize.rs`: TurboQuant-local normalization and how it differs from the tensor +//! crate's null-row zeroing helper. +//! - `vector/quantize.rs`: SORF transform, centroid lookup, and why invalid rows are skipped rather +//! than quantized. +//! - `centroids.rs`: deterministic Max-Lloyd centroid computation and process-local caching. +//! - `sorf/`: the Walsh-Hadamard-based structured transform and the stable SplitMix64 sign stream. +//! +//! The current encoding is intentionally MSE-only. It does not yet implement the paper's QJL +//! residual correction for unbiased inner-product estimation, and it still uses internal +//! power-of-2 padding rather than the block decomposition proposed in RFC 0033. + +mod centroids; +mod config; +mod scalar_fns; +mod sorf; +mod vector; +mod vtable; + +pub use config::TurboQuantConfig; +pub use scalar_fns::TQDecode; +pub use scalar_fns::TQEncode; +pub use vtable::TurboQuant; +pub use vtable::TurboQuantMetadata; + +// TODO(connor): We need to somehow make sure that callers call `vortex_tensor::initialize` first. +/// Register the TurboQuant extension type with a Vortex session. 
+pub fn initialize(session: &vortex_session::VortexSession) { + use vortex_array::dtype::session::DTypeSessionExt; + use vortex_array::scalar_fn::session::ScalarFnSessionExt; + session.dtypes().register(TurboQuant); + + session.scalar_fns().register(TQEncode); + session.scalar_fns().register(TQDecode); +} + +#[cfg(test)] +mod tests; diff --git a/vortex-turboquant/src/scalar_fns/decode.rs b/vortex-turboquant/src/scalar_fns/decode.rs new file mode 100644 index 00000000000..22e4f2560b8 --- /dev/null +++ b/vortex-turboquant/src/scalar_fns/decode.rs @@ -0,0 +1,307 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant decode scalar function. + +use std::fmt; +use std::fmt::Formatter; +use std::sync::Arc; + +use num_traits::Float; +use num_traits::FromPrimitive; +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::ScalarFnArray; +use vortex_array::dtype::DType; +use vortex_array::dtype::NativePType; +use vortex_array::dtype::Nullability; +use vortex_array::dtype::extension::ExtDType; +use vortex_array::expr::Expression; +use vortex_array::extension::EmptyMetadata; +use vortex_array::match_each_float_ptype; +use vortex_array::scalar_fn::Arity; +use vortex_array::scalar_fn::ChildName; +use vortex_array::scalar_fn::ExecutionArgs; +use vortex_array::scalar_fn::ScalarFnId; +use vortex_array::scalar_fn::ScalarFnVTable; +use vortex_array::scalar_fn::TypedScalarFnInstance; +use vortex_array::validity::Validity; +use vortex_buffer::BufferMut; +use vortex_error::VortexExpect; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; +use vortex_error::vortex_err; +use vortex_mask::Mask; +use vortex_session::VortexSession; +use vortex_tensor::vector::Vector; + +use crate::centroids::compute_or_get_centroids; +use crate::sorf::SorfMatrix; +use 
crate::vector::storage::parse_storage; +use crate::vector::tq_padded_dim; +use crate::vtable::TurboQuantMetadata; +use crate::vtable::tq_metadata; + +/// Lazy TurboQuant vector decode scalar function. +#[derive(Clone)] +pub struct TQDecode; + +impl TQDecode { + /// Creates a new [`TypedScalarFnInstance`] wrapping TurboQuant decoding. + pub fn new() -> TypedScalarFnInstance { + TypedScalarFnInstance::new(TQDecode, EmptyMetadata) + } + + /// Constructs a [`ScalarFnArray`] that lazily decodes a `TurboQuant` child into a `Vector`. + pub fn try_new_array(child: ArrayRef) -> VortexResult { + let len = child.len(); + ScalarFnArray::try_new(TQDecode::new().erased(), vec![child], len) + } +} + +impl ScalarFnVTable for TQDecode { + type Options = EmptyMetadata; + + fn id(&self) -> ScalarFnId { + ScalarFnId::new("vortex.turboquant.decode") + } + + fn serialize(&self, _options: &Self::Options) -> VortexResult>> { + Ok(Some(vec![])) + } + + fn deserialize( + &self, + metadata: &[u8], + _session: &VortexSession, + ) -> VortexResult { + vortex_ensure!( + metadata.is_empty(), + "TQDecode options metadata must be empty" + ); + + Ok(EmptyMetadata) + } + + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(1) + } + + fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName { + match child_idx { + 0 => ChildName::from("turboquant"), + _ => unreachable!("TQDecode must have exactly one child"), + } + } + + fn fmt_sql( + &self, + _options: &Self::Options, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> fmt::Result { + write!(f, "tq_decode(")?; + expr.child(0).fmt_sql(f)?; + write!(f, ")") + } + + fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { + let child_dtype = &arg_dtypes[0]; + let metadata = tq_metadata(child_dtype)?; + + let storage_dtype = DType::FixedSizeList( + Arc::new(DType::Primitive( + metadata.element_ptype, + Nullability::NonNullable, + )), + metadata.dimensions, + child_dtype.nullability(), 
+ ); + let ext_dtype = ExtDType::::try_new(EmptyMetadata, storage_dtype)?.erased(); + + Ok(DType::Extension(ext_dtype)) + } + + fn execute( + &self, + _options: &Self::Options, + args: &dyn ExecutionArgs, + ctx: &mut ExecutionCtx, + ) -> VortexResult { + decode_vector(args.get(0)?, ctx) + } + + fn validity( + &self, + _options: &Self::Options, + expression: &Expression, + ) -> VortexResult> { + Ok(Some(expression.child(0).validity()?)) + } + + fn is_null_sensitive(&self, _options: &Self::Options) -> bool { + false + } + + fn is_fallible(&self, _options: &Self::Options) -> bool { + false + } +} + +/// Decode a `TurboQuant` extension array back into a `Vector` extension array. +/// +/// The decoded directions are inverse-transformed, truncated to the original dimension, and +/// multiplied by the stored row norms. The conversion is lossy and does not roundtrip with +/// [`TQEncode`](crate::TQEncode). +pub(crate) fn decode_vector(input: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { + let parsed = parse_storage(input, ctx)?; + let metadata = parsed.metadata; + if parsed.len == 0 { + return build_empty_vector(metadata, parsed.vector_validity); + } + + let padded_dim = tq_padded_dim(metadata.dimensions)?; + let transform = SorfMatrix::try_new(padded_dim, metadata.num_rounds as usize, metadata.seed)?; + let padded_dim = u32::try_from(padded_dim) + .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit u32"))?; + + let centroids = compute_or_get_centroids(padded_dim, metadata.bit_width)?; + + match_each_float_ptype!(metadata.element_ptype, |T| { + decode_typed::( + DecodeInputs { + metadata: &metadata, + sorf_matrix: &transform, + centroids: ¢roids, + norms: &parsed.norms, + codes: &parsed.codes, + }, + parsed.vector_validity, + parsed.len, + ctx, + ) + }) +} + +fn build_empty_vector( + metadata: TurboQuantMetadata, + vector_validity: Validity, +) -> VortexResult { + match_each_float_ptype!(metadata.element_ptype, |T| { + let elements = 
PrimitiveArray::empty::(Nullability::NonNullable); + let fsl = FixedSizeListArray::try_new( + elements.into_array(), + metadata.dimensions, + vector_validity, + 0, + )?; + + Vector::try_new_vector_array(fsl.into_array()) + }) +} + +struct DecodeInputs<'a> { + metadata: &'a TurboQuantMetadata, + sorf_matrix: &'a SorfMatrix, + centroids: &'a [f32], + norms: &'a PrimitiveArray, + codes: &'a PrimitiveArray, +} + +fn decode_typed( + decode: DecodeInputs<'_>, + vector_validity: Validity, + num_vectors: usize, + ctx: &mut ExecutionCtx, +) -> VortexResult +where + T: NativePType + Float + FromPrimitive, +{ + let metadata = decode.metadata; + let dimensions = usize::try_from(metadata.dimensions) + .vortex_expect("dimensions stays representable as usize"); + let padded_dim = decode.sorf_matrix.padded_dim(); + let centroids = decode.centroids; + let norms = decode.norms.as_slice::(); + let codes = decode.codes.as_slice::(); + let mask = vector_validity.execute_mask(num_vectors, ctx)?; + + let output_len = num_vectors + .checked_mul(dimensions) + .ok_or_else(|| vortex_err!("TurboQuant decoded vector length overflow"))?; + let mut output = BufferMut::::with_capacity(output_len); + + let mut decoded = vec![0.0f32; padded_dim]; + let mut inverse = vec![0.0f32; padded_dim]; + + let mut decode_row = |output: &mut BufferMut, i: usize| { + let code_row = &codes[i * padded_dim..][..padded_dim]; + + for (dst, &code) in decoded.iter_mut().zip(code_row.iter()) { + *dst = *centroids + .get(usize::from(code)) + .vortex_expect("TurboQuant code exceeds centroid count"); + } + + decode.sorf_matrix.inverse_transform(&decoded, &mut inverse); + + let norm = norms[i]; + for &value in inverse.iter().take(dimensions) { + // `T::from_f32` is infallible for the supported float ptypes (`f16`, `f32`, + // `f64`): values outside `f16` range saturate to `±inf` rather than returning + // `None`. 
+ let value = T::from_f32(value) + .vortex_expect("from_f32 is infallible for supported float types"); + + // SAFETY: total pushes across all match arms equal `output_len`. + unsafe { output.push_unchecked(value * norm) }; + } + }; + + match &mask { + Mask::AllFalse(_) => { + // SAFETY: `output` was allocated with capacity `output_len`, and this push writes + // exactly `output_len` zero placeholders. + unsafe { output.push_n_unchecked(T::zero(), output_len) }; + } + Mask::AllTrue(_) => { + for i in 0..num_vectors { + decode_row(&mut output, i); + } + } + Mask::Values(values_mask) => { + let mut cursor = 0; + + for &(start, end) in values_mask.slices() { + if start > cursor { + // SAFETY: total pushes across all arms equal `output_len`. + unsafe { output.push_n_unchecked(T::zero(), (start - cursor) * dimensions) }; + } + + for i in start..end { + decode_row(&mut output, i); + } + + cursor = end; + } + + if cursor < num_vectors { + // SAFETY: total pushes across all arms equal `output_len`. + unsafe { output.push_n_unchecked(T::zero(), (num_vectors - cursor) * dimensions) }; + } + } + } + + let elements = PrimitiveArray::new::(output.freeze(), Validity::NonNullable); + let fsl = FixedSizeListArray::try_new( + elements.into_array(), + metadata.dimensions, + vector_validity, + num_vectors, + )?; + + Vector::try_new_vector_array(fsl.into_array()) +} diff --git a/vortex-turboquant/src/scalar_fns/encode.rs b/vortex-turboquant/src/scalar_fns/encode.rs new file mode 100644 index 00000000000..29ce7cc580a --- /dev/null +++ b/vortex-turboquant/src/scalar_fns/encode.rs @@ -0,0 +1,224 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant encode scalar function. 
+ +use std::fmt; +use std::fmt::Formatter; + +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::Extension; +use vortex_array::arrays::ExtensionArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::ScalarFnArray; +use vortex_array::arrays::extension::ExtensionArrayExt; +use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; +use vortex_array::dtype::DType; +use vortex_array::dtype::extension::ExtDType; +use vortex_array::expr::Expression; +use vortex_array::scalar_fn::Arity; +use vortex_array::scalar_fn::ChildName; +use vortex_array::scalar_fn::ExecutionArgs; +use vortex_array::scalar_fn::ScalarFnId; +use vortex_array::scalar_fn::ScalarFnVTable; +use vortex_array::scalar_fn::TypedScalarFnInstance; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; +use vortex_error::vortex_err; +use vortex_session::VortexSession; +use vortex_tensor::vector::AnyVector; + +use super::metadata::deserialize_config; +use super::metadata::serialize_config; +use crate::TurboQuantConfig; +use crate::config::MIN_DIMENSION; +use crate::vector::normalize::tq_normalize_as_l2_denorm; +use crate::vector::quantize::empty_quantization; +use crate::vector::quantize::turboquant_quantize_core; +use crate::vector::storage::build_codes_child; +use crate::vector::storage::build_storage; +use crate::vector::tq_padded_dim; +use crate::vtable::TurboQuant; +use crate::vtable::TurboQuantMetadata; +use crate::vtable::tq_storage_dtype; + +/// TurboQuant vector encode scalar function. +/// +/// `TQEncode` itself is a `ScalarFnVTable` and so its options round-trip through expression +/// serialization. +/// +/// Unlike `TQDecode`, it deliberately does **not** implement `ScalarFnArrayVTable` since the +/// persisted artifact would be the original vector array, not the TurboQuant-quantized array. 
+#[derive(Clone)] +pub struct TQEncode; + +impl TQEncode { + /// Creates a new [`TypedScalarFnInstance`] wrapping TurboQuant encoding. + pub fn new(config: &TurboQuantConfig) -> TypedScalarFnInstance { + TypedScalarFnInstance::new(TQEncode, config.clone()) + } + + /// Constructs a [`ScalarFnArray`] that lazily encodes a `Vector` child into `TurboQuant`. + pub fn try_new_array( + child: ArrayRef, + config: &TurboQuantConfig, + ) -> VortexResult { + let len = child.len(); + ScalarFnArray::try_new(TQEncode::new(config).erased(), vec![child], len) + } +} + +impl ScalarFnVTable for TQEncode { + type Options = TurboQuantConfig; + + fn id(&self) -> ScalarFnId { + ScalarFnId::new("vortex.turboquant.encode") + } + + fn serialize(&self, options: &Self::Options) -> VortexResult>> { + Ok(Some(serialize_config(options))) + } + + fn deserialize( + &self, + metadata: &[u8], + _session: &VortexSession, + ) -> VortexResult { + deserialize_config(metadata) + } + + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(1) + } + + fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName { + match child_idx { + 0 => ChildName::from("vector"), + _ => unreachable!("TQEncode must have exactly one child"), + } + } + + fn fmt_sql( + &self, + options: &Self::Options, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> fmt::Result { + write!(f, "tq_encode(")?; + expr.child(0).fmt_sql(f)?; + write!(f, ", {options})") + } + + fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { + let input_dtype = &arg_dtypes[0]; + let vector_metadata = input_dtype + .as_extension_opt() + .and_then(|ext_dtype| ext_dtype.metadata_opt::()) + .ok_or_else(|| { + vortex_err!("TQEncode expects a Vector extension array, got {input_dtype}") + })?; + + let dimensions = vector_metadata.dimensions(); + vortex_ensure!( + dimensions >= MIN_DIMENSION, + "TurboQuant requires dimension >= {MIN_DIMENSION}, got {dimensions}", + ); + tq_padded_dim(dimensions)?; 
+ + let metadata = TurboQuantMetadata { + element_ptype: vector_metadata.element_ptype(), + dimensions, + bit_width: options.bit_width(), + seed: options.seed(), + num_rounds: options.num_rounds(), + }; + let storage_dtype = tq_storage_dtype(&metadata, input_dtype.nullability())?; + let ext_dtype = ExtDType::::try_new(metadata, storage_dtype)?.erased(); + + Ok(DType::Extension(ext_dtype)) + } + + fn execute( + &self, + options: &Self::Options, + args: &dyn ExecutionArgs, + ctx: &mut ExecutionCtx, + ) -> VortexResult { + encode_vector(args.get(0)?, options, ctx) + } + + fn validity( + &self, + _options: &Self::Options, + expression: &Expression, + ) -> VortexResult> { + Ok(Some(expression.child(0).validity()?)) + } + + fn is_null_sensitive(&self, _options: &Self::Options) -> bool { + false + } + + fn is_fallible(&self, _options: &Self::Options) -> bool { + false + } +} + +/// Lossily encode a `Vector` extension array into a `TurboQuant` extension array. +/// +/// Valid rows are normalized internally before SORF transform and scalar quantization. The original +/// row norms are stored explicitly, and original vector nulls are preserved on the storage struct +/// and both row-aligned child arrays. 
+pub(crate) fn encode_vector( + input: ArrayRef, + config: &TurboQuantConfig, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let num_vectors = input.len(); + let vector_metadata = input + .dtype() + .as_extension_opt() + .and_then(|ext_dtype| ext_dtype.metadata_opt::()) + .ok_or_else(|| vortex_err!("TurboQuant encode expects a Vector extension array"))?; + + let element_ptype = vector_metadata.element_ptype(); + + let dimensions = vector_metadata.dimensions(); + vortex_ensure!( + dimensions >= MIN_DIMENSION, + "TurboQuant requires dimension >= {MIN_DIMENSION}, got {dimensions}", + ); + let padded_dim = tq_padded_dim(dimensions)?; + + let vector_validity = input.validity()?; + + let l2_denorm = tq_normalize_as_l2_denorm(input, ctx)?; + let normalized = l2_denorm.child_at(0).clone(); + let norms = l2_denorm.child_at(1).clone(); + + let normalized_ext = normalized + .as_opt::() + .ok_or_else(|| vortex_err!("normalized TurboQuant input must be a Vector extension"))?; + let normalized_fsl: FixedSizeListArray = normalized_ext.storage_array().clone().execute(ctx)?; + + let core = if normalized_fsl.is_empty() { + empty_quantization(padded_dim) + } else { + // SAFETY: `tq_normalize_as_l2_denorm` returned this normalized Vector child. + unsafe { turboquant_quantize_core(&normalized_fsl, config, ctx)? 
} + }; + let codes = build_codes_child(num_vectors, core, vector_validity.clone())?; + + let metadata = TurboQuantMetadata { + element_ptype, + dimensions, + bit_width: config.bit_width(), + seed: config.seed(), + num_rounds: config.num_rounds(), + }; + let storage = build_storage(norms, codes, num_vectors, vector_validity)?; + + Ok(ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage)?.into_array()) +} diff --git a/vortex-turboquant/src/scalar_fns/metadata.rs b/vortex-turboquant/src/scalar_fns/metadata.rs new file mode 100644 index 00000000000..f5eddfe51d9 --- /dev/null +++ b/vortex-turboquant/src/scalar_fns/metadata.rs @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use prost::Message; +use vortex_error::VortexResult; +use vortex_error::vortex_err; + +use crate::TurboQuantConfig; + +#[derive(Clone, PartialEq, Message)] +pub(super) struct TQScalarFnMetadata { + #[prost(uint32, tag = "1")] + bit_width: u32, + #[prost(uint64, tag = "2")] + seed: u64, + #[prost(uint32, tag = "3")] + num_rounds: u32, +} + +impl TQScalarFnMetadata { + pub(super) fn from_config(config: &TurboQuantConfig) -> Self { + Self { + bit_width: u32::from(config.bit_width()), + seed: config.seed(), + num_rounds: u32::from(config.num_rounds()), + } + } + + pub(super) fn to_config(&self) -> VortexResult { + let bit_width = u8::try_from(self.bit_width) + .map_err(|_| vortex_err!("TurboQuant bit_width does not fit u8"))?; + let num_rounds = u8::try_from(self.num_rounds) + .map_err(|_| vortex_err!("TurboQuant num_rounds does not fit u8"))?; + + TurboQuantConfig::try_new(bit_width, self.seed, num_rounds) + } +} + +pub(super) fn serialize_config(config: &TurboQuantConfig) -> Vec { + TQScalarFnMetadata::from_config(config).encode_to_vec() +} + +pub(super) fn deserialize_config(metadata: &[u8]) -> VortexResult { + TQScalarFnMetadata::decode(metadata) + .map_err(|e| vortex_err!("Failed to decode TurboQuant scalar 
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

//! Frozen local SplitMix64 stream used to define SORF sign diagonals.
//!
//! Direct translation of the `splitmix64.c` [reference implementation][impl]: the state is a
//! single `u64`; each `next_u64()` adds [`SPLITMIX64_INCREMENT`] with wrapping arithmetic,
//! applies the two reference mixing multiplies, and finishes with a xor-shift.
//!
//! [impl]: https://prng.di.unimi.it/splitmix64.c

/// SplitMix64 additive constant from the reference implementation.
const SPLITMIX64_INCREMENT: u64 = 0x9E37_79B9_7F4A_7C15;

/// First SplitMix64 mixing multiplier from the reference implementation.
const SPLITMIX64_MUL1: u64 = 0xBF58_476D_1CE4_E5B9;

/// Second SplitMix64 mixing multiplier from the reference implementation.
const SPLITMIX64_MUL2: u64 = 0x94D0_49BB_1331_11EB;

/// Frozen local SplitMix64 stream used to define SORF sign diagonals.
pub(crate) struct SplitMix64 {
    state: u64,
}

impl SplitMix64 {
    /// Seed the stream; equal seeds always reproduce the same output sequence.
    pub(crate) fn new(seed: u64) -> Self {
        SplitMix64 { state: seed }
    }

    /// Advance the state and return the next 64-bit output word.
    pub(crate) fn next_u64(&mut self) -> u64 {
        self.state = self.state.wrapping_add(SPLITMIX64_INCREMENT);
        let mut mixed = self.state;
        mixed = (mixed ^ (mixed >> 30)).wrapping_mul(SPLITMIX64_MUL1);
        mixed = (mixed ^ (mixed >> 27)).wrapping_mul(SPLITMIX64_MUL2);
        mixed ^ (mixed >> 31)
    }
}

#[cfg(test)]
mod tests {
    use super::SplitMix64;

    /// `(seed, first four reference outputs)` pairs from the upstream implementation.
    const GOLDEN: [(u64, [u64; 4]); 2] = [
        (
            0,
            [
                0xE220_A839_7B1D_CDAF,
                0x6E78_9E6A_A1B9_65F4,
                0x06C4_5D18_8009_454F,
                0xF88B_B8A8_724C_81EC,
            ],
        ),
        (
            42,
            [
                0xBDD7_3226_2FEB_6E95,
                0x28EF_E333_B266_F103,
                0x4752_6757_130F_9F52,
                0x581C_E1FF_0E4A_E394,
            ],
        ),
    ];

    #[test]
    fn matches_reference_outputs() {
        for (seed, expected) in GOLDEN {
            let mut rng = SplitMix64::new(seed);
            for want in expected {
                assert_eq!(rng.next_u64(), want, "seed {seed} diverged from golden stream");
            }
        }
    }
}
2016][sorf-paper]: a fast structured +//! approximation to a random orthogonal matrix using random sign diagonals interleaved with the +//! Fast Walsh-Hadamard Transform (FWHT). +//! +//! [sorf-paper]: https://proceedings.neurips.cc/paper_files/paper/2016/file/53adaf494dc89ef7196d73636eb2451b-Paper.pdf +//! +//! For `k` rounds, the transform is `norm * H * D_k * ... * H * D_1 * x`, where `D_1` is the +//! first sign diagonal applied. The number of rounds is configurable (typically 3). Each round +//! applies a random sign diagonal `D_i` and then the Hadamard matrix `H`, giving O(d log d) cost +//! per matrix-vector product instead of the O(d^2) cost of a dense orthogonal matrix. +//! +//! This implementation defines those sign diagonals using a frozen local SplitMix64 stream rather +//! than an +//! external RNG crate. The contract is: +//! +//! - state is a single `u64` seed, +//! - each `next_u64()` call uses the SplitMix64 reference algorithm with wrapping `u64` +//! arithmetic, +//! - signs are generated in round-major, block-major order, +//! - each generated `u64` contributes 64 signs in least-significant-bit-first order, +//! - bit `1` means `+1` and bit `0` means `-1`. +//! +//! This makes SORF sign generation stable as an extension format contract even if external RNG +//! implementations change. +//! +//! This transform is the crate's practical structured transform choice for TurboQuant. It is not +//! the dense random Gaussian or orthogonal matrix used by some theoretical analyses, so theoretical +//! bounds from those models need separate validation before being presented as implementation +//! guarantees. +//! +//! The FWHT exploits the Kronecker product structure of the Hadamard matrix (`H_n = H_2 (x) H_2 +//! (x) ... (x) H_2`, with `log2(n)` factors) to compute the matrix-vector product in O(n log n) +//! time using only in-place 2-element butterfly operations. No row of the full n x n Hadamard +//! matrix is ever materialized. +//! +//! 
For dimensions that are not powers of 2, the input is zero-padded to the next power of 2 before +//! the transform and truncated afterward. +//! +//! # Sign representation +//! +//! Signs are stored internally as `u32` XOR masks: `0x00000000` for +1 (no-op) and `0x80000000` for +//! -1 (flip IEEE 754 sign bit). The sign application function uses integer XOR instead of +//! floating-point multiply, which avoids FP dependency chains and auto-vectorizes into +//! `vpxor`/`veor`. + +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; + +use super::splitmix64::SplitMix64; + +/// IEEE 754 sign bit mask for f32. +const F32_SIGN_BIT: u32 = 0x8000_0000; + +/// A Walsh-Hadamard-based structured orthogonal transform matrix. +/// +/// All computation is done in f32. The sign diagonals are stored as IEEE 754 XOR masks on +/// f32 bit patterns, and the Walsh-Hadamard butterfly operates on `&mut [f32]` slices. +pub(crate) struct SorfMatrix { + /// Flat XOR masks for all `num_rounds` diagonal matrices, total length + /// `num_rounds * padded_dim`. + /// + /// Indexed as `round * padded_dim + i`. `0x00000000` = multiply by +1 (no-op), `0x80000000` = + /// multiply by -1 (flip sign bit). + sign_masks: Vec, + + /// The number of sign-diagonal + WHT rounds. + num_rounds: usize, + + /// The padded dimension (next power of 2 >= dimension). + padded_dim: usize, + + /// Normalization factor: `padded_dim^(-num_rounds/2)`, applied once at the end. + /// + /// This is stored for convenience. + norm_factor: f32, +} + +impl SorfMatrix { + /// Create a new structured Walsh-Hadamard-based orthogonal transform for a padded dimension. + /// + /// `padded_dimensions` must already be a power of two. Callers that start from an unpadded + /// logical dimension are responsible for padding it before constructing the matrix. 
+ pub(crate) fn try_new( + padded_dimensions: usize, + num_rounds: usize, + seed: u64, + ) -> VortexResult { + vortex_ensure!(num_rounds >= 1, "num_rounds must be >= 1, got {num_rounds}"); + vortex_ensure!( + padded_dimensions.is_power_of_two(), + "padded_dimensions must be a power of two, got {padded_dimensions}" + ); + + let padded_dim = padded_dimensions; + let sign_masks = gen_sign_masks_from_seed(seed, padded_dim, num_rounds); + + // Compute in f64 for precision, then store as f32 since the WHT operates on f32 buffers. + // The result is always in (0, 1] for any valid padded_dim >= 2 and num_rounds >= 1, so + // the f64 -> f32 cast is a precision loss only (it cannot overflow to infinity). + #[expect( + clippy::cast_possible_truncation, + reason = "the norm factor is in (0, 1] so the f64 -> f32 cast cannot overflow" + )] + let norm_factor = (padded_dim as f64).powf(-(num_rounds as f64) / 2.0) as f32; + + Ok(Self { + sign_masks, + num_rounds, + padded_dim, + norm_factor, + }) + } + + /// Returns the padded dimension (next power of 2 >= dim). + /// + /// All `transform`/`inverse_transform` buffers must be this length. + pub(crate) fn padded_dim(&self) -> usize { + self.padded_dim + } + + /// Apply the forward orthogonal transform: `output = R(input)`. + /// + /// Both `input` and `output` must have length [`padded_dim()`](Self::padded_dim). The caller is + /// responsible for zero-padding input beyond `dim` positions. + pub(crate) fn transform(&self, input: &[f32], output: &mut [f32]) { + debug_assert_eq!(input.len(), self.padded_dim); + debug_assert_eq!(output.len(), self.padded_dim); + + output.copy_from_slice(input); + self.apply_srht(output); + } + + /// Apply the inverse orthogonal transform: `output = R⁻¹(input)`. + /// + /// Both `input` and `output` must have length `padded_dim()`. 
+ pub(crate) fn inverse_transform(&self, input: &[f32], output: &mut [f32]) { + debug_assert_eq!(input.len(), self.padded_dim); + debug_assert_eq!(output.len(), self.padded_dim); + + output.copy_from_slice(input); + self.apply_inverse_srht(output); + } + + /// Apply the forward structured transform: `norm · H · D_k · ... · H · D₁ · x`. + fn apply_srht(&self, buf: &mut [f32]) { + for round in 0..self.num_rounds { + self.apply_signs_xor(buf, round); + walsh_hadamard_transform(buf); + } + + buf.iter_mut().for_each(|val| *val *= self.norm_factor); + } + + /// Apply the inverse structured transform. + /// + /// Forward is: `norm · H · D_k · ... · H · D₁`. + /// Inverse is: `norm · D₁ · H · ... · D_k · H`. + fn apply_inverse_srht(&self, buf: &mut [f32]) { + for round in (0..self.num_rounds).rev() { + walsh_hadamard_transform(buf); + self.apply_signs_xor(buf, round); + } + + buf.iter_mut().for_each(|val| *val *= self.norm_factor); + } + + /// Apply one round's sign masks via XOR on the IEEE 754 sign bit. + /// + /// This is branchless and auto-vectorizes into `vpxor` (x86) / `veor` (ARM). Equivalent to + /// multiplying each element by +/-1.0, but avoids FP dependency chains. + fn apply_signs_xor(&self, buf: &mut [f32], round: usize) { + let masks = &self.sign_masks[round * self.padded_dim..][..self.padded_dim]; + for (val, &mask) in buf.iter_mut().zip(masks.iter()) { + *val = f32::from_bits(val.to_bits() ^ mask); + } + } +} + +/// Generate XOR sign masks from the frozen local SplitMix64 stream. +/// +/// Signs are produced in round-major, block-major order. For each block we call +/// [`SplitMix64::next_u64`] exactly once and unpack its bits from least significant to most +/// significant. Bit `1` means positive sign / `0x00000000`; bit `0` means negative sign / +/// [`F32_SIGN_BIT`]. 
+fn gen_sign_masks_from_seed(seed: u64, padded_dim: usize, num_rounds: usize) -> Vec { + let mut rng = SplitMix64::new(seed); + let mut sign_masks = Vec::with_capacity(num_rounds * padded_dim); + + for _round in 0..num_rounds { + for base_idx in (0..padded_dim).step_by(64) { + let word = rng.next_u64(); + let bits_in_block = (padded_dim - base_idx).min(64); + sign_masks.extend((0..bits_in_block).map(|bit_idx| sign_mask_from_word(word, bit_idx))); + } + } + + sign_masks +} + +/// Convert one bit from a SplitMix64 output word into an XOR sign mask. +fn sign_mask_from_word(word: u64, bit_idx: usize) -> u32 { + if ((word >> bit_idx) & 1) != 0 { + 0u32 + } else { + F32_SIGN_BIT + } +} + +/// In-place Fast Walsh-Hadamard Transform (FWHT), unnormalized and iterative. +/// +/// Input length must be a power of 2. Runs in O(n log n) via `log2(n)` stages of `n / 2` +/// [`butterfly`] operations each. See the [module-level docs](self) for why this avoids +/// materializing the full Hadamard matrix. +/// +/// The chunk-based iteration gives LLVM enough structure to auto-vectorize each butterfly call +/// into NEON/AVX SIMD instructions. +fn walsh_hadamard_transform(buf: &mut [f32]) { + let len = buf.len(); + debug_assert!(len.is_power_of_two()); + + let mut half = 1; + while half < len { + let stride = half * 2; + // Process in chunks of `stride` elements. Within each chunk, + // split into non-overlapping (lo, hi) halves for the butterfly. + for chunk in buf.chunks_exact_mut(stride) { + let (lo, hi) = chunk.split_at_mut(half); + butterfly(lo, hi); + } + half *= 2; + } +} + +/// Butterfly: `(lo[i], hi[i]) -> (lo[i] + hi[i], lo[i] - hi[i])`. +/// +/// This is multiplication by the 2x2 Hadamard kernel `H_2 = [[1, 1], [1, -1]]` on each element +/// pair. Factored into a separate function so LLVM can see the slice lengths match and +/// auto-vectorize. 
+fn butterfly(lo: &mut [f32], hi: &mut [f32]) { + debug_assert_eq!(lo.len(), hi.len()); + for (a, b) in lo.iter_mut().zip(hi.iter_mut()) { + let sum = *a + *b; + let diff = *a - *b; + *a = sum; + *b = diff; + } +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + use vortex_error::VortexResult; + + use super::*; + + fn dim_to_usize(dim: u32) -> usize { + usize::try_from(dim).unwrap() + } + + fn rounds_to_usize(num_rounds: u8) -> usize { + usize::from(num_rounds) + } + + #[test] + fn deterministic_from_seed() -> VortexResult<()> { + let padded_dim = dim_to_usize(64u32); + let num_rounds = rounds_to_usize(3u8); + let seed = 42u64; + let transform1 = SorfMatrix::try_new(padded_dim, num_rounds, seed)?; + let transform2 = SorfMatrix::try_new(padded_dim, num_rounds, seed)?; + let pd = transform1.padded_dim(); + + let mut input = vec![0.0f32; pd]; + for i in 0..padded_dim { + input[i] = i as f32; + } + let mut out1 = vec![0.0f32; pd]; + let mut out2 = vec![0.0f32; pd]; + + transform1.transform(&input, &mut out1); + transform2.transform(&input, &mut out2); + + assert_eq!(out1, out2); + Ok(()) + } + + #[test] + fn one_word_generates_64_signs_lsb_first() { + let seed = 42u64; + let padded_dim = dim_to_usize(64u32); + let num_rounds = rounds_to_usize(1u8); + let masks = gen_sign_masks_from_seed(seed, padded_dim, num_rounds); + assert_eq!(masks.len(), padded_dim); + + let mut rng = SplitMix64::new(seed); + let word = rng.next_u64(); + let expected: Vec<_> = (0..padded_dim) + .map(|bit_idx| sign_mask_from_word(word, bit_idx)) + .collect(); + assert_eq!(masks, expected); + } + + #[test] + fn rejects_non_power_of_two_padded_dimensions() { + assert!(SorfMatrix::try_new(dim_to_usize(100u32), rounds_to_usize(3u8), 42u64).is_err()); + } + + #[test] + fn tail_block_uses_only_required_bits() { + let seed = 42u64; + let padded_dim = dim_to_usize(32u32); + let num_rounds = rounds_to_usize(1u8); + let masks = gen_sign_masks_from_seed(seed, padded_dim, num_rounds); + 
assert_eq!(masks.len(), padded_dim); + + let mut rng = SplitMix64::new(seed); + let word = rng.next_u64(); + let expected: Vec<_> = (0..padded_dim) + .map(|bit_idx| sign_mask_from_word(word, bit_idx)) + .collect(); + assert_eq!(masks, expected); + } + + /// Verify roundtrip is exact to f32 precision across many dimensions and round counts, + /// including non-power-of-two dimensions that require padding. + #[rstest] + #[case(32u32, 3u8)] + #[case(64u32, 3u8)] + #[case(100u32, 3u8)] + #[case(128u32, 1u8)] + #[case(128u32, 2u8)] + #[case(128u32, 3u8)] + #[case(128u32, 5u8)] + #[case(256u32, 3u8)] + #[case(512u32, 3u8)] + #[case(768u32, 3u8)] + #[case(1024u32, 3u8)] + fn roundtrip_exact(#[case] dim: u32, #[case] num_rounds: u8) -> VortexResult<()> { + let dim = dim_to_usize(dim); + let num_rounds = rounds_to_usize(num_rounds); + let transform = SorfMatrix::try_new(dim.next_power_of_two(), num_rounds, 42u64)?; + let padded_dim = transform.padded_dim(); + + let mut input = vec![0.0f32; padded_dim]; + for i in 0..dim { + input[i] = (i as f32 + 1.0) * 0.01; + } + let mut transformed = vec![0.0f32; padded_dim]; + let mut recovered = vec![0.0f32; padded_dim]; + + transform.transform(&input, &mut transformed); + transform.inverse_transform(&transformed, &mut recovered); + + let max_err: f32 = input + .iter() + .zip(recovered.iter()) + .map(|(a, b)| (a - b).abs()) + .fold(0.0f32, f32::max); + let max_val: f32 = input.iter().map(|x| x.abs()).fold(0.0f32, f32::max); + let rel_err = max_err / max_val; + + // SRHT roundtrip should be exact up to f32 precision (~1e-6). + assert!( + rel_err < 1e-5, + "roundtrip relative error too large for dim={dim}, rounds={num_rounds}: {rel_err:.2e}" + ); + Ok(()) + } + + /// Verify norm preservation across dimensions and round counts. 
+ #[rstest] + #[case(128u32, 1u8)] + #[case(128u32, 3u8)] + #[case(128u32, 5u8)] + #[case(768u32, 3u8)] + fn preserves_norm(#[case] dim: u32, #[case] num_rounds: u8) -> VortexResult<()> { + let dim = dim_to_usize(dim); + let num_rounds = rounds_to_usize(num_rounds); + let transform = SorfMatrix::try_new(dim.next_power_of_two(), num_rounds, 7u64)?; + let padded_dim = transform.padded_dim(); + + let mut input = vec![0.0f32; padded_dim]; + for i in 0..dim { + input[i] = (i as f32) * 0.01; + } + let input_norm: f32 = input.iter().map(|x| x * x).sum::().sqrt(); + + let mut transformed = vec![0.0f32; padded_dim]; + transform.transform(&input, &mut transformed); + let transformed_norm: f32 = transformed.iter().map(|x| x * x).sum::().sqrt(); + + assert!( + (input_norm - transformed_norm).abs() / input_norm < 1e-5, + "norm not preserved for dim={dim}: {} vs {} (rel err: {:.2e})", + input_norm, + transformed_norm, + (input_norm - transformed_norm).abs() / input_norm + ); + Ok(()) + } + + #[test] + fn wht_basic() { + // WHT of [1, 0, 0, 0] should be [1, 1, 1, 1] + let mut buf = vec![1.0f32, 0.0, 0.0, 0.0]; + walsh_hadamard_transform(&mut buf); + assert_eq!(buf, vec![1.0, 1.0, 1.0, 1.0]); + + // WHT is self-inverse (up to scaling by n) + walsh_hadamard_transform(&mut buf); + // After two WHTs: each element multiplied by n=4 + assert_eq!(buf, vec![4.0, 0.0, 0.0, 0.0]); + } +} diff --git a/vortex-turboquant/src/tests/encode_decode.rs b/vortex-turboquant/src/tests/encode_decode.rs new file mode 100644 index 00000000000..ed5aab190aa --- /dev/null +++ b/vortex-turboquant/src/tests/encode_decode.rs @@ -0,0 +1,254 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use rstest::rstest; +use vortex_array::IntoArray; +use vortex_array::VortexSessionExecute; +use vortex_array::arrays::ExtensionArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use 
vortex_array::arrays::extension::ExtensionArrayExt; +use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; +use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; +use vortex_array::arrays::struct_::StructArrayExt; +use vortex_array::dtype::Nullability; +use vortex_array::dtype::PType; +use vortex_array::validity::Validity; +use vortex_buffer::Buffer; +use vortex_error::VortexResult; + +use super::execute_tq_decode; +use super::execute_tq_encode; +use super::f32_vector_array; +use super::test_session; +use super::turboquant_storage; +use super::vector_array; +use super::vector_element_ptype; +use super::vector_validity; +use super::vector_values_f32; +use crate::TurboQuantConfig; +use crate::centroids::compute_or_get_centroids; +use crate::vector::normalize::tq_normalize_as_l2_denorm; + +#[rstest] +#[case::zero_bits(0, 42, 3)] +#[case::too_many_bits(9, 42, 3)] +#[case::zero_rounds(2, 42, 0)] +fn config_rejects_invalid_values(#[case] bit_width: u8, #[case] seed: u64, #[case] num_rounds: u8) { + assert!(TurboQuantConfig::try_new(bit_width, seed, num_rounds).is_err()); +} + +#[test] +fn encode_rejects_non_vector_input() { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let input = PrimitiveArray::new::(Buffer::copy_from([1.0, 2.0]), Validity::NonNullable) + .into_array(); + + assert!(execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx).is_err()); +} + +#[test] +fn encode_rejects_small_dimensions() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let input = f32_vector_array(127, 1, 1.0, Validity::NonNullable)?; + + assert!(execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx).is_err()); + Ok(()) +} + +#[test] +fn encode_rejects_padded_dimension_overflow() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let input = vector_array::(2_147_483_649, &[], Validity::NonNullable)?; + + 
assert!(execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx).is_err()); + Ok(()) +} + +#[test] +fn centroid_cache_is_deterministic() -> VortexResult<()> { + let first = compute_or_get_centroids(128, 3)?; + let second = compute_or_get_centroids(128, 3)?; + + assert_eq!(first.as_slice(), second.as_slice()); + Ok(()) +} + +#[test] +fn encode_decode_empty_vectors() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let input = vector_array::(128, &[], Validity::NonNullable)?; + + let encoded = execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx)?; + let decoded = execute_tq_decode(encoded, &mut ctx)?; + + assert!(decoded.is_empty()); + Ok(()) +} + +#[test] +fn encode_stores_norms_and_struct_validity() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let validity = Validity::from_iter([true, false, true]); + let input = f32_vector_array(128, 3, 0.25, validity)?; + + let config = TurboQuantConfig::try_new(3, 1, 2)?; + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + let storage = turboquant_storage(encoded, &mut ctx)?; + let mask = storage.struct_validity().execute_mask(3, &mut ctx)?; + let norms: PrimitiveArray = storage + .unmasked_field_by_name("norms")? + .clone() + .execute(&mut ctx)?; + let codes: FixedSizeListArray = storage + .unmasked_field_by_name("codes")? 
+ .clone() + .execute(&mut ctx)?; + + assert!(mask.value(0)); + assert!(!mask.value(1)); + assert!(mask.value(2)); + assert_eq!(norms.validity()?.nullability(), Nullability::Nullable); + assert_eq!(codes.validity()?.nullability(), Nullability::Nullable); + + let norms_validity = norms.validity()?.execute_mask(3, &mut ctx)?; + let codes_validity = codes.validity()?.execute_mask(3, &mut ctx)?; + assert!(norms_validity.value(0)); + assert!(!norms_validity.value(1)); + assert!(norms_validity.value(2)); + assert!(codes_validity.value(0)); + assert!(!codes_validity.value(1)); + assert!(codes_validity.value(2)); + + let codes_values: PrimitiveArray = codes.elements().clone().execute(&mut ctx)?; + assert!( + codes_values.as_slice::()[128..256] + .iter() + .all(|&code| code == 0) + ); + Ok(()) +} + +#[test] +fn normalize_as_l2_denorm_preserves_child_validity() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let mut values = vec![0.0f32; 3 * 128]; + values[0] = 3.0; + values[1] = 4.0; + values[128..256].fill(13.0); + values[256] = 1.0; + let input = vector_array(128, &values, Validity::from_iter([true, false, true]))?; + + let l2_denorm = tq_normalize_as_l2_denorm(input, &mut ctx)?; + let normalized = l2_denorm.child_at(0).clone(); + let norms = l2_denorm.child_at(1).clone(); + + let normalized_ext: ExtensionArray = normalized.execute(&mut ctx)?; + let normalized_fsl: FixedSizeListArray = + normalized_ext.storage_array().clone().execute(&mut ctx)?; + let normalized_values: PrimitiveArray = normalized_fsl.elements().clone().execute(&mut ctx)?; + let norms: PrimitiveArray = norms.execute(&mut ctx)?; + let normalized_validity = normalized_fsl.validity()?.execute_mask(3, &mut ctx)?; + let norms_validity = norms.validity()?.execute_mask(3, &mut ctx)?; + + assert!(normalized_validity.value(0)); + assert!(!normalized_validity.value(1)); + assert!(normalized_validity.value(2)); + assert!(norms_validity.value(0)); + 
assert!(!norms_validity.value(1)); + assert!(norms_validity.value(2)); + assert_eq!(norms.validity()?.nullability(), Nullability::Nullable); + assert_eq!(norms.as_slice::()[0], 5.0); + assert!( + normalized_values.as_slice::()[128..256] + .iter() + .all(|&value| value == 0.0) + ); + Ok(()) +} + +#[test] +fn encode_decode_preserves_nulls_and_zero_norm_rows() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let mut values = vec![0.0f32; 3 * 128]; + values[0] = 3.0; + values[1] = 4.0; + values[256] = 1.0; + values[257] = -1.0; + let input = vector_array(128, &values, Validity::from_iter([true, true, false]))?; + + let encoded = execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx)?; + let decoded = execute_tq_decode(encoded, &mut ctx)?; + let output = vector_values_f32(decoded.clone(), &mut ctx)?; + let validity = vector_validity(decoded, &mut ctx)?.execute_mask(3, &mut ctx)?; + + assert!(validity.value(0)); + assert!(validity.value(1)); + assert!(!validity.value(2)); + assert!(output[128..256].iter().all(|&v| v == 0.0)); + Ok(()) +} + +#[rstest] +#[case::f16(PType::F16)] +#[case::f64(PType::F64)] +fn encode_decode_supports_non_f32_inputs(#[case] ptype: PType) -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let config = TurboQuantConfig::try_new(3, 42, 3)?; + + match ptype { + PType::F16 => { + let values = (0..2 * 128) + .map(|i| half::f16::from_f32(((i % 17) as f32 - 8.0) * 0.25)) + .collect::>(); + let input = vector_array(128, &values, Validity::NonNullable)?; + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + let decoded = execute_tq_decode(encoded, &mut ctx)?; + let ext: ExtensionArray = decoded.execute(&mut ctx)?; + assert_eq!(vector_element_ptype(&ext)?, PType::F16); + } + PType::F64 => { + let values = (0..2 * 128) + .map(|i| ((i % 17) as f64 - 8.0) * 0.25) + .collect::>(); + let input = vector_array(128, &values, 
Validity::NonNullable)?; + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + let decoded = execute_tq_decode(encoded, &mut ctx)?; + let ext: ExtensionArray = decoded.execute(&mut ctx)?; + assert_eq!(vector_element_ptype(&ext)?, PType::F64); + } + _ => unreachable!("test only passes f16/f64"), + } + Ok(()) +} + +#[test] +fn decode_scales_by_stored_norms() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let base = f32_vector_array(128, 1, 0.5, Validity::NonNullable)?; + let scaled = f32_vector_array(128, 1, 1.0, Validity::NonNullable)?; + let config = TurboQuantConfig::try_new(2, 99, 3)?; + + let base_values = vector_values_f32( + execute_tq_decode(execute_tq_encode(base, &config, &mut ctx)?, &mut ctx)?, + &mut ctx, + )?; + let scaled_values = vector_values_f32( + execute_tq_decode(execute_tq_encode(scaled, &config, &mut ctx)?, &mut ctx)?, + &mut ctx, + )?; + + for (base, scaled) in base_values.iter().zip(scaled_values.iter()) { + assert!((*scaled - 2.0 * *base).abs() < 1e-5); + } + Ok(()) +} diff --git a/vortex-turboquant/src/tests/file.rs b/vortex-turboquant/src/tests/file.rs new file mode 100644 index 00000000000..e59b7a95c75 --- /dev/null +++ b/vortex-turboquant/src/tests/file.rs @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::IntoArray; +use vortex_array::VortexSessionExecute; +use vortex_array::stream::ArrayStreamExt; +use vortex_array::validity::Validity; +use vortex_error::VortexResult; +use vortex_file::OpenOptionsSessionExt; +use vortex_file::VortexWriteOptions; +use vortex_io::runtime::BlockingRuntime; +use vortex_io::runtime::single::SingleThreadRuntime; +use vortex_tensor::vector::Vector; + +use super::execute_tq_decode_from_metadata; +use super::execute_tq_encode; +use super::f32_vector_array; +use super::file_session; +use super::vector_validity; +use crate::TQDecode; +use 
crate::TurboQuantConfig; +use crate::vtable::tq_metadata; + +#[test] +fn file_roundtrip_with_initialize_session() -> VortexResult<()> { + let runtime = SingleThreadRuntime::default(); + let session = file_session(&runtime); + let mut ctx = session.create_execution_ctx(); + let input = f32_vector_array(128, 2, 0.25, Validity::from_iter([true, false]))?; + let encoded = execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx)?; + + let mut file_bytes = Vec::new(); + VortexWriteOptions::new(session.clone()) + .blocking(&runtime) + .write(&mut file_bytes, encoded.to_array_iterator())?; + + let file = session.open_options().open_buffer(file_bytes)?; + let read = runtime.block_on(async { file.scan()?.into_array_stream()?.read_all().await })?; + + let metadata = tq_metadata(read.dtype())?; + assert_eq!(metadata.dimensions, 128); + let decoded = execute_tq_decode_from_metadata(read, &mut ctx)?; + let validity = vector_validity(decoded, &mut ctx)?.execute_mask(2, &mut ctx)?; + assert!(validity.value(0)); + assert!(!validity.value(1)); + Ok(()) +} + +#[test] +fn file_roundtrip_lazy_decode_scalar_fn_with_initialize_session() -> VortexResult<()> { + let runtime = SingleThreadRuntime::default(); + let session = file_session(&runtime); + let mut ctx = session.create_execution_ctx(); + let input = f32_vector_array(128, 2, 0.25, Validity::from_iter([true, false]))?; + let config = TurboQuantConfig::try_new(3, 42, 3)?; + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + let decoded = TQDecode::try_new_array(encoded)?.into_array(); + + let mut file_bytes = Vec::new(); + VortexWriteOptions::new(session.clone()) + .blocking(&runtime) + .write(&mut file_bytes, decoded.to_array_iterator())?; + + let file = session.open_options().open_buffer(file_bytes)?; + let read = runtime.block_on(async { file.scan()?.into_array_stream()?.read_all().await })?; + + assert!(read.dtype().as_extension().is::()); + + let validity = vector_validity(read, &mut ctx)?.execute_mask(2, &mut 
ctx)?; + assert!(validity.value(0)); + assert!(!validity.value(1)); + Ok(()) +} diff --git a/vortex-turboquant/src/tests/malformed.rs b/vortex-turboquant/src/tests/malformed.rs new file mode 100644 index 00000000000..f99f0ee5105 --- /dev/null +++ b/vortex-turboquant/src/tests/malformed.rs @@ -0,0 +1,189 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use rstest::rstest; +use vortex_array::IntoArray; +use vortex_array::VortexSessionExecute; +use vortex_array::arrays::ExtensionArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::StructArray; +use vortex_array::dtype::FieldNames; +use vortex_array::dtype::Nullability; +use vortex_array::dtype::PType; +use vortex_array::validity::Validity; +use vortex_buffer::Buffer; +use vortex_error::VortexResult; + +use super::execute_tq_decode_from_metadata; +use super::test_session; +use super::vector_validity; +use crate::TurboQuant; +use crate::TurboQuantMetadata; + +#[rstest] +#[case::nullable_norms_under_nonnullable_struct( + Nullability::NonNullable, + Nullability::Nullable, + Nullability::NonNullable +)] +#[case::nullable_codes_under_nonnullable_struct( + Nullability::NonNullable, + Nullability::NonNullable, + Nullability::Nullable +)] +#[case::nonnullable_norms_under_nullable_struct( + Nullability::Nullable, + Nullability::NonNullable, + Nullability::Nullable +)] +#[case::nonnullable_codes_under_nullable_struct( + Nullability::Nullable, + Nullability::Nullable, + Nullability::NonNullable +)] +fn decode_accepts_child_nullability_that_covers_struct_validity( + #[case] struct_nullability: Nullability, + #[case] norms_nullability: Nullability, + #[case] codes_nullability: Nullability, +) -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let metadata = TurboQuantMetadata { + element_ptype: PType::F32, + dimensions: 128, + bit_width: 1, + seed: 42, 
+ num_rounds: 3, + }; + let norms = + PrimitiveArray::new::(Buffer::copy_from([1.0]), Validity::from(norms_nullability)) + .into_array(); + let codes = PrimitiveArray::new::(vec![0u8; 128], Validity::NonNullable); + let codes = FixedSizeListArray::try_new( + codes.into_array(), + 128, + Validity::from(codes_nullability), + 1, + ) + .unwrap() + .into_array(); + let storage = StructArray::try_new( + FieldNames::from(["norms", "codes"]), + vec![norms, codes], + 1, + Validity::from(struct_nullability), + ) + .unwrap(); + let tq = ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage.into_array()) + .unwrap() + .into_array(); + + execute_tq_decode_from_metadata(tq, &mut ctx)?; + Ok(()) +} + +#[test] +fn decode_accepts_struct_mask_with_all_valid_children() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let metadata = TurboQuantMetadata { + element_ptype: PType::F32, + dimensions: 128, + bit_width: 1, + seed: 42, + num_rounds: 3, + }; + let norms = + PrimitiveArray::new::(Buffer::copy_from([1.0, 1.0, 1.0]), Validity::NonNullable) + .into_array(); + let codes = PrimitiveArray::new::(vec![0u8; 3 * 128], Validity::NonNullable); + let codes = FixedSizeListArray::try_new(codes.into_array(), 128, Validity::NonNullable, 3)? + .into_array(); + let storage = StructArray::try_new( + FieldNames::from(["norms", "codes"]), + vec![norms, codes], + 3, + Validity::from_iter([true, false, true]), + )?; + let tq = ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage.into_array())? 
+ .into_array(); + + let decoded = execute_tq_decode_from_metadata(tq, &mut ctx)?; + let validity = vector_validity(decoded, &mut ctx)?.execute_mask(3, &mut ctx)?; + assert!(validity.value(0)); + assert!(!validity.value(1)); + assert!(validity.value(2)); + Ok(()) +} + +#[test] +fn decode_rejects_child_masks_that_disagree_with_struct_validity() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let metadata = TurboQuantMetadata { + element_ptype: PType::F32, + dimensions: 128, + bit_width: 1, + seed: 42, + num_rounds: 3, + }; + let norms = PrimitiveArray::new::( + Buffer::copy_from([1.0, 1.0, 1.0]), + Validity::from_iter([true, true, false]), + ) + .into_array(); + let codes = PrimitiveArray::new::(vec![0u8; 3 * 128], Validity::NonNullable); + let codes = FixedSizeListArray::try_new( + codes.into_array(), + 128, + Validity::from_iter([true, false, true]), + 3, + )? + .into_array(); + let storage = StructArray::try_new( + FieldNames::from(["norms", "codes"]), + vec![norms, codes], + 3, + Validity::from_iter([true, false, true]), + )?; + let tq = ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage.into_array())? 
+ .into_array(); + + assert!(execute_tq_decode_from_metadata(tq, &mut ctx).is_err()); + Ok(()) +} + +#[test] +#[should_panic(expected = "TurboQuant code exceeds centroid count")] +fn decode_panics_on_codes_outside_centroid_table() { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let metadata = TurboQuantMetadata { + element_ptype: PType::F32, + dimensions: 128, + bit_width: 1, + seed: 42, + num_rounds: 3, + }; + let norms = + PrimitiveArray::new::(Buffer::copy_from([1.0]), Validity::NonNullable).into_array(); + let mut codes = vec![0u8; 128]; + codes[0] = 2; + let codes = PrimitiveArray::new::(codes, Validity::NonNullable); + let codes = FixedSizeListArray::try_new(codes.into_array(), 128, Validity::NonNullable, 1) + .unwrap() + .into_array(); + let storage = StructArray::try_new( + FieldNames::from(["norms", "codes"]), + vec![norms, codes], + 1, + Validity::NonNullable, + ) + .unwrap(); + let tq = ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage.into_array()) + .unwrap() + .into_array(); + + drop(execute_tq_decode_from_metadata(tq, &mut ctx)); +} diff --git a/vortex-turboquant/src/tests/metadata.rs b/vortex-turboquant/src/tests/metadata.rs new file mode 100644 index 00000000000..e0d1042f02f --- /dev/null +++ b/vortex-turboquant/src/tests/metadata.rs @@ -0,0 +1,173 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::sync::Arc; + +use prost::Message; +use rstest::rstest; +use vortex_array::dtype::DType; +use vortex_array::dtype::FieldNames; +use vortex_array::dtype::Nullability; +use vortex_array::dtype::PType; +use vortex_array::dtype::StructFields; +use vortex_array::dtype::extension::ExtDType; +use vortex_array::dtype::extension::ExtVTable; +use vortex_error::VortexResult; +use vortex_error::vortex_err; + +use crate::TurboQuant; +use crate::TurboQuantMetadata; +use crate::vector::storage::CODES_FIELD; +use crate::vector::storage::NORMS_FIELD; +use 
crate::vector::tq_padded_dim; + +#[derive(Clone, PartialEq, Message)] +struct MetadataWire { + #[prost(enumeration = "PType", tag = "1")] + element_ptype: i32, + #[prost(uint32, tag = "2")] + dimensions: u32, + #[prost(uint32, tag = "3")] + bit_width: u32, + #[prost(uint64, tag = "4")] + seed: u64, + #[prost(uint32, tag = "5")] + num_rounds: u32, +} + +fn tq_storage_dtype( + metadata: &TurboQuantMetadata, + row_nullability: Nullability, +) -> VortexResult { + let padded_dim = u32::try_from(tq_padded_dim(metadata.dimensions)?) + .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit u32"))?; + Ok(DType::Struct( + StructFields::new( + FieldNames::from([NORMS_FIELD, CODES_FIELD]), + vec![ + DType::Primitive(metadata.element_ptype, row_nullability), + DType::FixedSizeList( + Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)), + padded_dim, + row_nullability, + ), + ], + ), + row_nullability, + )) +} + +#[rstest] +#[case::f16(PType::F16)] +#[case::f32(PType::F32)] +#[case::f64(PType::F64)] +fn metadata_serialization_roundtrips(#[case] element_ptype: PType) -> VortexResult<()> { + let metadata = TurboQuantMetadata { + element_ptype, + dimensions: 128, + bit_width: 4, + seed: 7, + num_rounds: 3, + }; + + let encoded = TurboQuant.serialize_metadata(&metadata)?; + let decoded = TurboQuant.deserialize_metadata(&encoded)?; + + assert_eq!(decoded, metadata); + Ok(()) +} + +#[test] +fn metadata_serialization_uses_ptype_discriminants() -> VortexResult<()> { + let metadata = TurboQuantMetadata { + element_ptype: PType::F32, + dimensions: 128, + bit_width: 4, + seed: 7, + num_rounds: 3, + }; + + let encoded = TurboQuant.serialize_metadata(&metadata)?; + let wire = MetadataWire::decode(encoded.as_slice())?; + + assert_eq!(wire.element_ptype, PType::F32 as i32); + assert_eq!(wire.dimensions, 128); + Ok(()) +} + +#[test] +fn metadata_display_matches_field_order() { + let metadata = TurboQuantMetadata { + element_ptype: PType::F32, + dimensions: 128, + 
bit_width: 4, + seed: 7, + num_rounds: 3, + }; + + assert_eq!( + metadata.to_string(), + "element_ptype: f32, dimensions: 128, bit_width: 4, seed: 7, num_rounds: 3" + ); +} + +#[test] +fn dtype_validation_accepts_expected_storage() -> VortexResult<()> { + let metadata = TurboQuantMetadata { + element_ptype: PType::F32, + dimensions: 129, + bit_width: 2, + seed: 42, + num_rounds: 3, + }; + + ExtDType::::try_new( + metadata, + tq_storage_dtype(&metadata, Nullability::Nullable)?, + )?; + Ok(()) +} + +#[test] +fn dtype_validation_accepts_nonnullable_storage() -> VortexResult<()> { + let metadata = TurboQuantMetadata { + element_ptype: PType::F32, + dimensions: 129, + bit_width: 2, + seed: 42, + num_rounds: 3, + }; + + ExtDType::::try_new( + metadata, + tq_storage_dtype(&metadata, Nullability::NonNullable)?, + )?; + Ok(()) +} + +#[test] +fn dtype_validation_rejects_malformed_storage() { + let metadata = TurboQuantMetadata { + element_ptype: PType::F32, + dimensions: 128, + bit_width: 2, + seed: 42, + num_rounds: 3, + }; + let storage = DType::Struct( + StructFields::new( + FieldNames::from(["norms", "codes"]), + vec![ + DType::Primitive(PType::F32, Nullability::Nullable), + DType::FixedSizeList( + DType::Primitive(PType::U8, Nullability::Nullable).into(), + 128, + Nullability::NonNullable, + ), + ], + ), + Nullability::Nullable, + ); + + assert!(ExtDType::::try_new(metadata, storage).is_err()); +} diff --git a/vortex-turboquant/src/tests/mod.rs b/vortex-turboquant/src/tests/mod.rs new file mode 100644 index 00000000000..ffa1db175a7 --- /dev/null +++ b/vortex-turboquant/src/tests/mod.rs @@ -0,0 +1,141 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +#![cfg_attr( + test, + allow(clippy::unwrap_used, clippy::expect_used, clippy::unwrap_in_result) +)] + +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::ExtensionArray; +use 
vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::StructArray; +use vortex_array::arrays::extension::ExtensionArrayExt; +use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; +use vortex_array::dtype::NativePType; +use vortex_array::dtype::PType; +use vortex_array::extension::EmptyMetadata; +use vortex_array::memory::MemorySession; +use vortex_array::session::ArraySession; +use vortex_array::validity::Validity; +use vortex_buffer::Buffer; +use vortex_error::VortexResult; +use vortex_error::vortex_err; +use vortex_io::runtime::BlockingRuntime; +use vortex_io::runtime::single::SingleThreadRuntime; +use vortex_io::session::RuntimeSession; +use vortex_io::session::RuntimeSessionExt; +use vortex_layout::session::LayoutSession; +use vortex_session::VortexSession; +use vortex_tensor::vector::Vector; + +use crate::TQDecode; +use crate::TQEncode; +use crate::TurboQuantConfig; +use crate::initialize; + +mod encode_decode; +mod file; +mod malformed; +mod metadata; +mod parity; +mod scalar_fns; + +fn test_session() -> VortexSession { + let session = VortexSession::empty().with::(); + initialize(&session); + session +} + +fn file_session(runtime: &SingleThreadRuntime) -> VortexSession { + let session = VortexSession::empty() + .with::() + .with::() + .with::() + .with::() + .with_handle(runtime.handle()); + vortex_file::register_default_encodings(&session); + vortex_tensor::initialize(&session); + initialize(&session); + session +} + +fn vector_array( + dimensions: u32, + values: &[T], + validity: Validity, +) -> VortexResult { + assert!(dimensions > 0, "dimensions must be > 0"); + let row_count = values.len() / dimensions as usize; + + let elements = PrimitiveArray::new::( + values.iter().copied().collect::>(), + Validity::NonNullable, + ); + let fsl = FixedSizeListArray::try_new(elements.into_array(), dimensions, validity, row_count)?; + + Ok(ExtensionArray::try_new_from_vtable(Vector, EmptyMetadata, 
fsl.into_array())?.into_array()) +} + +fn f32_vector_array( + dimensions: u32, + rows: usize, + scale: f32, + validity: Validity, +) -> VortexResult { + let values = (0..rows * dimensions as usize) + .map(|i| ((i % 17) as f32 - 8.0) * scale) + .collect::>(); + vector_array(dimensions, &values, validity) +} + +fn vector_values_f32(array: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult> { + let ext: ExtensionArray = array.execute(ctx)?; + let fsl: FixedSizeListArray = ext.storage_array().clone().execute(ctx)?; + let elements: PrimitiveArray = fsl.elements().clone().execute(ctx)?; + Ok(elements.as_slice::().to_vec()) +} + +fn vector_validity(array: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { + let ext: ExtensionArray = array.execute(ctx)?; + let fsl: FixedSizeListArray = ext.storage_array().clone().execute(ctx)?; + fsl.validity() +} + +fn vector_element_ptype(array: &ExtensionArray) -> VortexResult { + Ok(array + .storage_array() + .dtype() + .as_fixed_size_list_element_opt() + .ok_or_else(|| vortex_err!("expected FixedSizeList vector storage"))? + .as_ptype()) +} + +fn turboquant_storage(array: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { + let ext: ExtensionArray = array.execute(ctx)?; + ext.storage_array().clone().execute(ctx) +} + +fn execute_tq_encode( + input: ArrayRef, + config: &TurboQuantConfig, + ctx: &mut ExecutionCtx, +) -> VortexResult { + TQEncode::try_new_array(input, config)? 
+ .into_array() + .execute(ctx) +} + +fn execute_tq_decode(input: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { + TQDecode::try_new_array(input)?.into_array().execute(ctx) +} + +fn execute_tq_decode_from_metadata( + input: ArrayRef, + ctx: &mut ExecutionCtx, +) -> VortexResult { + execute_tq_decode(input, ctx) +} diff --git a/vortex-turboquant/src/tests/parity.rs b/vortex-turboquant/src/tests/parity.rs new file mode 100644 index 00000000000..2995fb8b296 --- /dev/null +++ b/vortex-turboquant/src/tests/parity.rs @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::VortexSessionExecute; +use vortex_array::validity::Validity; +use vortex_error::VortexResult; + +use super::execute_tq_decode; +use super::execute_tq_encode; +use super::f32_vector_array; +use super::test_session; +use super::vector_values_f32; +use crate::TurboQuantConfig; + +#[test] +fn encode_decode_matches_old_turboquant_decode() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let input = f32_vector_array(128, 2, 0.125, Validity::NonNullable)?; + let config = TurboQuantConfig::try_new(3, 42, 3)?; + + let new_encoded = execute_tq_encode(input.clone(), &config, &mut ctx)?; + let new_decoded = execute_tq_decode(new_encoded, &mut ctx)?; + let old_config = vortex_tensor::encodings::turboquant::TurboQuantConfig { + bit_width: config.bit_width(), + seed: config.seed(), + num_rounds: config.num_rounds(), + }; + let old_decoded = + vortex_tensor::encodings::turboquant::turboquant_encode(input, &old_config, &mut ctx)? 
+ .execute(&mut ctx)?; + + let new_values = vector_values_f32(new_decoded, &mut ctx)?; + let old_values = vector_values_f32(old_decoded, &mut ctx)?; + + assert_eq!(new_values, old_values); + Ok(()) +} diff --git a/vortex-turboquant/src/tests/scalar_fns.rs b/vortex-turboquant/src/tests/scalar_fns.rs new file mode 100644 index 00000000000..de125e8a5f1 --- /dev/null +++ b/vortex-turboquant/src/tests/scalar_fns.rs @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::IntoArray; +use vortex_array::VortexSessionExecute; +use vortex_array::extension::EmptyMetadata; +use vortex_array::scalar_fn::ScalarFnVTable; +use vortex_array::validity::Validity; +use vortex_error::VortexResult; + +use super::f32_vector_array; +use super::test_session; +use super::vector_validity; +use crate::TQDecode; +use crate::TQEncode; +use crate::TurboQuant; +use crate::TurboQuantConfig; +use crate::vtable::tq_metadata; + +#[test] +fn scalar_fn_ids_and_options_roundtrip() -> VortexResult<()> { + let session = test_session(); + let config = TurboQuantConfig::try_new(4, 7, 2)?; + + assert_eq!(TQEncode.id().as_ref(), "vortex.turboquant.encode"); + assert_eq!(TQDecode.id().as_ref(), "vortex.turboquant.decode"); + + let encode_metadata = TQEncode.serialize(&config)?.unwrap(); + let decode_metadata = TQDecode.serialize(&EmptyMetadata)?.unwrap(); + + assert_eq!(TQEncode.deserialize(&encode_metadata, &session)?, config); + assert!(decode_metadata.is_empty()); + assert_eq!( + TQDecode.deserialize(&decode_metadata, &session)?, + EmptyMetadata + ); + Ok(()) +} + +#[test] +fn scalar_fn_arrays_encode_and_decode_vectors() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let input = f32_vector_array(128, 2, 0.25, Validity::from_iter([true, false]))?; + let config = TurboQuantConfig::try_new(3, 42, 3)?; + + let encoded_lazy = TQEncode::try_new_array(input, &config)?; + let 
encoded_metadata = tq_metadata(encoded_lazy.dtype())?; + assert_eq!(encoded_metadata.dimensions, 128); + assert_eq!(encoded_metadata.bit_width, config.bit_width()); + assert!(encoded_lazy.dtype().as_extension().is::()); + + let encoded = encoded_lazy.into_array().execute(&mut ctx)?; + let decoded_lazy = TQDecode::try_new_array(encoded)?; + let decoded = decoded_lazy.into_array().execute(&mut ctx)?; + let validity = vector_validity(decoded, &mut ctx)?.execute_mask(2, &mut ctx)?; + + assert!(validity.value(0)); + assert!(!validity.value(1)); + Ok(()) +} diff --git a/vortex-turboquant/src/vector/mod.rs b/vortex-turboquant/src/vector/mod.rs new file mode 100644 index 00000000000..58c4271a398 --- /dev/null +++ b/vortex-turboquant/src/vector/mod.rs @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +pub(crate) mod normalize; +pub(crate) mod quantize; +pub(crate) mod storage; + +use vortex_error::VortexResult; +use vortex_error::vortex_err; + +/// Compute the padded SORF dimension for an original vector dimension. +pub(crate) fn tq_padded_dim(dimensions: u32) -> VortexResult { + let padded_dim = dimensions + .checked_next_power_of_two() + .ok_or_else(|| vortex_err!("TurboQuant padded dimension overflow for {dimensions}"))?; + + usize::try_from(padded_dim) + .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit usize")) +} diff --git a/vortex-turboquant/src/vector/normalize.rs b/vortex-turboquant/src/vector/normalize.rs new file mode 100644 index 00000000000..642949eecf6 --- /dev/null +++ b/vortex-turboquant/src/vector/normalize.rs @@ -0,0 +1,236 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant-local vector normalization. + +// TODO(connor): Remove this comment once we delete the other version in `vortex-tensor`. 
+// The tensor crate also has a `normalize_as_l2_denorm` helper, but TurboQuant needs different +// validity semantics: a null vector is not a zero vector, so invalid rows keep their row validity +// on both `L2Denorm` children and downstream quantization skips them. + +use num_traits::Float; +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::ExtensionArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::ScalarFnArray; +use vortex_array::arrays::extension::ExtensionArrayExt; +use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; +use vortex_array::dtype::NativePType; +use vortex_array::extension::EmptyMetadata; +use vortex_array::match_each_float_ptype; +use vortex_array::validity::Validity; +use vortex_buffer::BufferMut; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure_eq; +use vortex_error::vortex_err; +use vortex_mask::Mask; +use vortex_mask::MaskValues; +use vortex_tensor::scalar_fns::l2_denorm::L2Denorm; +use vortex_tensor::scalar_fns::l2_norm::L2Norm; +use vortex_tensor::vector::AnyVector; +use vortex_tensor::vector::Vector; + +/// Normalize a `Vector` array and wrap it with its original row norms with [`L2Denorm`]. +/// +/// This preserves input row validity on both [`L2Denorm`] children. Or in other words, validity is +/// propagated down to the children so that TurboQuant can skip quantizing those vectors (as it does +/// not have a good way to represent 0 vectors in its quantized domain). 
+pub(crate) fn tq_normalize_as_l2_denorm( + input: ArrayRef, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let row_count = input.len(); + let vector_metadata = input + .dtype() + .as_extension_opt() + .and_then(|ext_dtype| ext_dtype.metadata_opt::()) + .ok_or_else(|| vortex_err!("TurboQuant normalization expects a Vector extension array"))?; + let dimensions = vector_metadata.dimensions() as usize; + let vector_validity = input.validity()?; + + // Use `L2Norm` to calculate the normals for each vector. + let norms: ArrayRef = L2Norm::try_new_array(input.clone(), row_count)? + .into_array() + .execute(ctx)?; + let primitive_norms: PrimitiveArray = norms.clone().execute(ctx)?; + + let input: ExtensionArray = input.execute(ctx)?; + let storage: FixedSizeListArray = input.storage_array().clone().execute(ctx)?; + vortex_ensure_eq!( + storage.list_size() as usize, + dimensions, + "Vector storage dimension must be {dimensions}, got {}", + storage.list_size() + ); + let elements: PrimitiveArray = storage.elements().clone().execute(ctx)?; + + let mask = vector_validity.execute_mask(row_count, ctx)?; + + let normalized = match_each_float_ptype!(elements.ptype(), |T| { + normalize_vectors::( + &elements, + &primitive_norms, + &mask, + dimensions, + vector_validity.clone(), + ) + })?; + + // SAFETY: matches the lossy-encoding relaxation documented on + // [`L2Denorm::new_array_unchecked`]. Norms come from `L2Norm` over the same input, so they + // match the vector element type and row count. Valid nonzero rows are divided by their stored + // norm and are unit-norm. Valid zero-norm rows and invalid rows use physical zero placeholders; + // invalid rows remain guarded by row-level invalid validity. 
+ unsafe { L2Denorm::new_array_unchecked(normalized, norms, row_count) } +} + +fn normalize_vectors( + elements: &PrimitiveArray, + norms: &PrimitiveArray, + mask: &Mask, + dimensions: usize, + vector_validity: Validity, +) -> VortexResult +where + T: Float + NativePType, +{ + let num_vectors = norms.len(); + + let values = elements.as_slice::(); + let norm_values = norms.as_slice::(); + + let output_len = num_vectors + .checked_mul(dimensions) + .ok_or_else(|| vortex_err!("TurboQuant normalized vector length overflow"))?; + let mut output = BufferMut::::with_capacity(output_len); + + // The total number of pushes is always exactly `num_vectors * dimensions == output_len` + // across every arm below, which is the invariant the per-row `unsafe` blocks rely on. + match mask { + Mask::AllFalse(_) => { + // Every row is invalid: bulk-fill the output with zero placeholders. + // + // SAFETY: `output` was allocated with capacity `output_len`, and this push writes + // exactly `output_len` zero placeholders. + unsafe { output.push_n_unchecked(T::zero(), output_len) }; + } + Mask::AllTrue(_) => { + for i in 0..num_vectors { + // SAFETY: `output` was allocated with capacity `output_len = num_vectors * + // dimensions`. This loop runs `num_vectors` times and each call pushes exactly + // `dimensions` elements, so capacity for `dimensions` more elements always + // remains. + unsafe { normalize_one_row::(&mut output, values, norm_values, dimensions, i) }; + } + } + Mask::Values(values_mask) => { + // SAFETY: `output` was allocated with capacity `output_len = num_vectors * + // dimensions`, which is the bound the helper requires. + unsafe { + normalize_vectors_with_mask::( + &mut output, + values, + norm_values, + dimensions, + num_vectors, + values_mask, + ) + }; + } + } + + // Vector elements are always non-nullable. 
+ let elements = PrimitiveArray::new::(output.freeze(), Validity::NonNullable); + + #[expect( + clippy::cast_possible_truncation, + reason = "this initially came from a u32" + )] + let storage = FixedSizeListArray::try_new( + elements.into_array(), + dimensions as u32, + vector_validity, + num_vectors, + )?; + + Ok( + ExtensionArray::try_new_from_vtable(Vector, EmptyMetadata, storage.into_array())? + .into_array(), + ) +} + +/// Normalize a single valid row, or push `dimensions` zero placeholders if the row's L2 norm +/// is zero. +/// +/// A valid vector with L2 norm zero is all zeros, so dividing through it would be undefined. +/// Treating it the same as an invalid row preserves the original semantics. +/// +/// # Safety +/// +/// `output` must have capacity for at least `dimensions` more elements before this call. +unsafe fn normalize_one_row( + output: &mut BufferMut, + values: &[T], + norm_values: &[T], + dimensions: usize, + i: usize, +) where + T: Float + NativePType, +{ + let norm = norm_values[i]; + + if norm == T::zero() { + // SAFETY: caller guarantees capacity for `dimensions` more elements. + unsafe { output.push_n_unchecked(T::zero(), dimensions) }; + } else { + let row_values = &values[i * dimensions..][..dimensions]; + + for &value in row_values { + // SAFETY: caller guarantees capacity for `dimensions` more elements. + unsafe { output.push_unchecked(value / norm) }; + } + } +} + +/// Walk the pre-cached run boundaries of a `Values` mask, bulk-pushing zero placeholders for +/// invalid runs and normalizing valid runs row by row. +/// +/// # Safety +/// +/// `output` must have capacity for at least `num_vectors * dimensions` more elements before +/// this call. 
+unsafe fn normalize_vectors_with_mask( + output: &mut BufferMut, + values: &[T], + norm_values: &[T], + dimensions: usize, + num_vectors: usize, + values_mask: &MaskValues, +) where + T: Float + NativePType, +{ + let mut cursor = 0; + + for &(start, end) in values_mask.slices() { + if start > cursor { + // SAFETY: capacity invariant from caller. + unsafe { output.push_n_unchecked(T::zero(), (start - cursor) * dimensions) }; + } + + for i in start..end { + // SAFETY: capacity invariant from caller — each call pushes `dimensions` and the + // total number of valid rows in the mask is bounded by `num_vectors`. + unsafe { normalize_one_row::(output, values, norm_values, dimensions, i) }; + } + + cursor = end; + } + + if cursor < num_vectors { + // SAFETY: capacity invariant from caller. + unsafe { output.push_n_unchecked(T::zero(), (num_vectors - cursor) * dimensions) }; + } +} diff --git a/vortex-turboquant/src/vector/quantize.rs b/vortex-turboquant/src/vector/quantize.rs new file mode 100644 index 00000000000..0861b9f6805 --- /dev/null +++ b/vortex-turboquant/src/vector/quantize.rs @@ -0,0 +1,181 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Core TurboQuant quantization helpers. +//! +//! Quantization consumes the TurboQuant-local normalized `Vector` child. Valid rows are transformed +//! and mapped to scalar centroid indices. Invalid rows remain in the full-length output but are +//! skipped: their physical code bytes are placeholders guarded by the `codes` row validity. +//! +//! This matters because TurboQuant's scalar codebook is optimized for coordinates of transformed +//! unit-norm vectors. The codebook does not generally contain an exact zero centroid, and a +//! physical code byte of `0` means "centroid 0", not "zero coordinate". Null vectors therefore +//! should not be converted to zero vectors and fed through the quantizer. 
+ +use half::f16; +use vortex_array::ExecutionCtx; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; +use vortex_array::dtype::PType; +use vortex_buffer::Buffer; +use vortex_buffer::BufferMut; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_err; +use vortex_mask::Mask; + +use super::tq_padded_dim; +use crate::TurboQuantConfig; +use crate::centroids::compute_centroid_boundaries; +use crate::centroids::compute_or_get_centroids; +use crate::centroids::find_nearest_centroid; +use crate::sorf::SorfMatrix; + +/// Shared intermediate results from the quantization loop. +pub(crate) struct QuantizationResult { + pub(crate) all_indices: Buffer, + pub(crate) padded_dim: usize, +} + +pub(crate) fn empty_quantization(padded_dim: usize) -> QuantizationResult { + QuantizationResult { + all_indices: Buffer::empty(), + padded_dim, + } +} + +/// Core quantization: transform and quantize already-normalized rows. +/// +/// # Safety +/// +/// The input `fsl` must contain unit-norm vectors (already L2-normalized) for every valid row. +/// Invalid rows are left row-aligned in the output but are not transformed or quantized. The +/// transform and centroid lookup happen in f32. 
+pub(crate) unsafe fn turboquant_quantize_core( + fsl: &FixedSizeListArray, + config: &TurboQuantConfig, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let dimension = fsl.list_size(); + let num_vectors = fsl.len(); + let padded_dim = tq_padded_dim(dimension)?; + + let sorf_transform = + SorfMatrix::try_new(padded_dim, config.num_rounds() as usize, config.seed())?; + debug_assert_eq!(sorf_transform.padded_dim(), padded_dim); + let padded_dim_u32 = u32::try_from(padded_dim) + .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit u32"))?; + + let elements_prim: PrimitiveArray = fsl.elements().clone().execute(ctx)?; + let f32_elements = cast_to_f32(elements_prim)?; + let validity = fsl.validity()?; + let mask = validity.execute_mask(num_vectors, ctx)?; + + let centroids = compute_or_get_centroids(padded_dim_u32, config.bit_width())?; + let boundaries = compute_centroid_boundaries(¢roids); + + let codes_len = num_vectors + .checked_mul(padded_dim) + .ok_or_else(|| vortex_err!("TurboQuant codes length overflow"))?; + let mut all_indices = BufferMut::::with_capacity(codes_len); + + let mut padded = vec![0.0f32; padded_dim]; + let mut transformed = vec![0.0f32; padded_dim]; + + // Pad, SORF-transform, and quantize a single row, pushing `padded_dim` codes into + // `all_indices`. Captures the read-only inputs and the scratch buffers so each call site + // only needs to pass `all_indices` and the row index. + // + // NB: `all_indices` cannot be captured here: the `Values` arm interleaves the closure call + // with direct `all_indices.push_n_unchecked` calls. + let f32_slice = f32_elements.as_slice(); + let dimension = dimension as usize; + let mut quantize_row = |all_indices: &mut BufferMut, row: usize| { + // Reuse `padded` and `transformed` from the outer scope. 
+ padded[..dimension].copy_from_slice(&f32_slice[row * dimension..][..dimension]); + padded[dimension..].fill(0.0); + sorf_transform.transform(&padded, &mut transformed); + + for &value in &transformed { + // SAFETY: total pushes across all match arms equal `codes_len`. + unsafe { all_indices.push_unchecked(find_nearest_centroid(value, &boundaries)) }; + } + }; + + // The total number of pushes is always exactly `num_vectors * padded_dim == codes_len` + // across every arm below, which is the invariant the per-row `unsafe` blocks rely on. + match &mask { + Mask::AllFalse(_) => { + // Every row is invalid: bulk-fill placeholder zero codes. + // + // SAFETY: `all_indices` was allocated with capacity `codes_len`, and this push + // writes exactly `codes_len` zero codes. + unsafe { all_indices.push_n_unchecked(0, codes_len) }; + } + Mask::AllTrue(_) => { + for row in 0..num_vectors { + quantize_row(&mut all_indices, row); + } + } + Mask::Values(values_mask) => { + let mut cursor = 0; + + for &(start, end) in values_mask.slices() { + if start > cursor { + // SAFETY: total pushes across all arms equal `codes_len`. + unsafe { all_indices.push_n_unchecked(0, (start - cursor) * padded_dim) }; + } + + for row in start..end { + quantize_row(&mut all_indices, row); + } + + cursor = end; + } + + if cursor < num_vectors { + // SAFETY: total pushes across all arms equal `codes_len`. + unsafe { all_indices.push_n_unchecked(0, (num_vectors - cursor) * padded_dim) }; + } + } + } + + Ok(QuantizationResult { + all_indices: all_indices.freeze(), + padded_dim, + }) +} + +/// Cast a float [`PrimitiveArray`] to a `Buffer`. +/// +/// Several operations in this crate (SORF transform, TurboQuant quantization) work exclusively +/// in f32. This function handles the cast from any float ptype: +/// +/// - f16: losslessly widened to f32. +/// - f32: zero-copy buffer extraction. +/// - f64: truncated to f32 precision. Values outside f32 range become +/- infinity. 
This is +/// acceptable because callers of this function operate in f32 and document this constraint. +fn cast_to_f32(prim: PrimitiveArray) -> VortexResult> { + match prim.ptype() { + PType::F16 => Ok(prim + .as_slice::() + .iter() + .map(|&v| f32::from(v)) + .collect()), + PType::F32 => Ok(prim.into_buffer()), + PType::F64 => Ok(prim + .as_slice::() + .iter() + .map(|&v| { + #[expect( + clippy::cast_possible_truncation, + reason = "f64 values outside f32 range become infinity, matching tensor TQ" + )] + let v = v as f32; + v + }) + .collect()), + other => vortex_bail!("expected float elements, got {other:?}"), + } +} diff --git a/vortex-turboquant/src/vector/storage.rs b/vortex-turboquant/src/vector/storage.rs new file mode 100644 index 00000000000..016eff3b0dc --- /dev/null +++ b/vortex-turboquant/src/vector/storage.rs @@ -0,0 +1,159 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant physical storage helpers. +//! +//! TurboQuant storage is row-aligned and full length: +//! +//! ```text +//! Struct { +//! norms: Primitive, +//! codes: FixedSizeList, padded_dim, vector_validity>, +//! } +//! ``` +//! +//! Row nullability is carried on the outer struct and on the `norms` and `codes` field arrays. +//! This is deliberate duplication: null vectors remain null throughout encode/decode instead of being +//! converted into zero vectors. The code bytes for invalid rows are physical placeholders only; the +//! field-level validity records that those rows were not quantized. +//! +//! Parsing treats the outer struct validity as authoritative. Child validity may be wider than the +//! struct validity, for example after a generic mask only updates the struct validity, but each +//! child must be valid wherever the struct row is valid. 
use vortex_array::ArrayRef;
use vortex_array::ExecutionCtx;
use vortex_array::IntoArray;
use vortex_array::arrays::ExtensionArray;
use vortex_array::arrays::FixedSizeListArray;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::arrays::StructArray;
use vortex_array::arrays::extension::ExtensionArrayExt;
use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt;
use vortex_array::arrays::struct_::StructArrayExt;
use vortex_array::dtype::FieldNames;
use vortex_array::validity::Validity;
use vortex_error::VortexResult;
use vortex_error::vortex_ensure;
use vortex_error::vortex_err;
use vortex_mask::Mask;

use super::quantize::QuantizationResult;
use crate::vtable::TurboQuantMetadata;
use crate::vtable::tq_metadata;

/// Name of the stored row-norm child.
pub(crate) const NORMS_FIELD: &str = "norms";

/// Name of the stored quantized-code child.
pub(crate) const CODES_FIELD: &str = "codes";

/// Parsed TurboQuant storage arrays.
///
/// We use this as a helper struct for working with a TurboQuant extension array.
pub(crate) struct TurboQuantParsedStorage {
    /// Metadata decoded from the extension dtype.
    pub(crate) metadata: TurboQuantMetadata,
    /// Row validity taken from the outer storage struct (authoritative per the module docs).
    pub(crate) vector_validity: Validity,
    /// Executed `norms` child as a primitive array.
    pub(crate) norms: PrimitiveArray,
    /// Executed flat element array of the `codes` fixed-size list.
    pub(crate) codes: PrimitiveArray,
    /// Row count of the storage struct.
    pub(crate) len: usize,
}

/// Build the `codes: FixedSizeList<u8, padded_dim>` storage child.
///
/// Each row of `padded_dim` u8 codes indexes into the deterministic centroid codebook derived
/// from `(padded_dim, bit_width)`. The centroid values are intentionally not stored in the array.
+pub(crate) fn build_codes_child( + num_vectors: usize, + quantization: QuantizationResult, + vector_validity: Validity, +) -> VortexResult { + let codes = PrimitiveArray::new::(quantization.all_indices, Validity::NonNullable); + let padded_dim_u32 = u32::try_from(quantization.padded_dim) + .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit u32"))?; + + Ok(FixedSizeListArray::try_new( + codes.into_array(), + padded_dim_u32, + vector_validity, + num_vectors, + )? + .into_array()) +} + +/// Build the TurboQuant `Struct { norms, codes }` storage array. +pub(crate) fn build_storage( + norms: ArrayRef, + codes: ArrayRef, + len: usize, + vector_validity: Validity, +) -> VortexResult { + Ok(StructArray::try_new( + FieldNames::from([NORMS_FIELD, CODES_FIELD]), + vec![norms, codes], + len, + vector_validity, + )? + .into_array()) +} + +/// Parse a TurboQuant extension array into executed storage children. +pub(crate) fn parse_storage( + input: ArrayRef, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let metadata = tq_metadata(input.dtype())?; + let ext: ExtensionArray = input.execute(ctx)?; + let storage: StructArray = ext.storage_array().clone().execute(ctx)?; + + let norms: PrimitiveArray = storage + .unmasked_field_by_name(NORMS_FIELD)? + .clone() + .execute(ctx)?; + + let codes_fsl: FixedSizeListArray = storage + .unmasked_field_by_name(CODES_FIELD)? 
+ .clone() + .execute(ctx)?; + let codes: PrimitiveArray = codes_fsl.elements().clone().execute(ctx)?; + + let len = storage.len(); + let struct_validity = storage.struct_validity(); + let norms_validity = norms.validity()?; + let codes_validity = codes_fsl.validity()?; + + let struct_mask = struct_validity.execute_mask(len, ctx)?; + let norms_mask = norms_validity.execute_mask(len, ctx)?; + let codes_mask = codes_validity.execute_mask(len, ctx)?; + validate_child_validity_covers_struct(&struct_mask, &norms_mask, &codes_mask)?; + + Ok(TurboQuantParsedStorage { + metadata, + vector_validity: struct_validity, + norms, + codes, + len, + }) +} + +/// Validate that both child masks cover the struct mask: every row that the struct considers +/// valid must also be valid in the `norms` and `codes` children. +/// +/// `struct_mask & !child_mask` selects rows where the struct is valid but the child is not. If +/// no such row exists, the child covers the struct. [`Mask::bitand_not`] is variant-specialized, +/// so this short-circuits in `O(1)` when either mask is `AllTrue` or `AllFalse`. 
+fn validate_child_validity_covers_struct( + struct_mask: &Mask, + norms_mask: &Mask, + codes_mask: &Mask, +) -> VortexResult<()> { + vortex_ensure!( + struct_mask.clone().bitand_not(norms_mask).all_false(), + "TurboQuant {NORMS_FIELD} row validity must cover storage validity" + ); + vortex_ensure!( + struct_mask.clone().bitand_not(codes_mask).all_false(), + "TurboQuant {CODES_FIELD} row validity must cover storage validity" + ); + Ok(()) +} diff --git a/vortex-turboquant/src/vtable.rs b/vortex-turboquant/src/vtable.rs new file mode 100644 index 00000000000..51dd0933ca6 --- /dev/null +++ b/vortex-turboquant/src/vtable.rs @@ -0,0 +1,233 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::fmt; +use std::sync::Arc; + +use prost::Message; +use vortex_array::dtype::DType; +use vortex_array::dtype::FieldNames; +use vortex_array::dtype::Nullability; +use vortex_array::dtype::PType; +use vortex_array::dtype::StructFields; +use vortex_array::dtype::extension::ExtDType; +use vortex_array::dtype::extension::ExtId; +use vortex_array::dtype::extension::ExtVTable; +use vortex_array::scalar::ScalarValue; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_ensure; +use vortex_error::vortex_ensure_eq; +use vortex_error::vortex_err; + +use crate::TurboQuantConfig; +use crate::config::MIN_DIMENSION; +use crate::vector::storage::CODES_FIELD; +use crate::vector::storage::NORMS_FIELD; +use crate::vector::tq_padded_dim; + +/// TurboQuant logical extension type. +#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] +pub struct TurboQuant; + +/// Serialized metadata for a TurboQuant extension array. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct TurboQuantMetadata { + /// Original vector element type and stored norm type. + pub element_ptype: PType, + /// Original vector dimension before SORF padding. 
+ pub dimensions: u32, + /// Bits per coordinate in the scalar quantizer codebook. + pub bit_width: u8, + /// Seed used to derive the deterministic SORF transform. + pub seed: u64, + /// Number of sign-diagonal plus Walsh-Hadamard rounds in the SORF transform. + pub num_rounds: u8, +} + +impl ExtVTable for TurboQuant { + type Metadata = TurboQuantMetadata; + type NativeValue<'a> = &'a ScalarValue; + + fn id(&self) -> ExtId { + ExtId::new("vortex.turboquant") + } + + fn serialize_metadata(&self, metadata: &Self::Metadata) -> VortexResult> { + validate_tq_metadata(metadata)?; + + let proto = TurboQuantMetadataProto { + element_ptype: metadata.element_ptype as i32, + dimensions: metadata.dimensions, + bit_width: u32::from(metadata.bit_width), + seed: metadata.seed, + num_rounds: u32::from(metadata.num_rounds), + }; + + Ok(proto.encode_to_vec()) + } + + fn deserialize_metadata(&self, metadata: &[u8]) -> VortexResult { + let proto = TurboQuantMetadataProto::decode(metadata) + .map_err(|e| vortex_err!("Failed to decode TurboQuantMetadata: {e}"))?; + let bit_width = u8::try_from(proto.bit_width) + .map_err(|_| vortex_err!("TurboQuant bit_width does not fit u8"))?; + let num_rounds = u8::try_from(proto.num_rounds) + .map_err(|_| vortex_err!("TurboQuant num_rounds does not fit u8"))?; + let element_ptype = PType::try_from(proto.element_ptype).map_err(|e| { + vortex_err!( + "invalid TurboQuant metadata element_ptype code {}: {e}", + proto.element_ptype + ) + })?; + + let metadata = TurboQuantMetadata { + element_ptype, + dimensions: proto.dimensions, + bit_width, + seed: proto.seed, + num_rounds, + }; + validate_tq_metadata(&metadata)?; + + Ok(metadata) + } + + fn validate_dtype(ext_dtype: &ExtDType) -> VortexResult<()> { + validate_tq_metadata(ext_dtype.metadata())?; + validate_tq_storage_dtype(ext_dtype.metadata(), ext_dtype.storage_dtype()) + } + + fn unpack_native<'a>( + _ext_dtype: &'a ExtDType, + storage_value: &'a ScalarValue, + ) -> VortexResult> { + 
Ok(storage_value) + } +} + +#[derive(Clone, PartialEq, Message)] +struct TurboQuantMetadataProto { + #[prost(enumeration = "PType", tag = "1")] + element_ptype: i32, + #[prost(uint32, tag = "2")] + dimensions: u32, + #[prost(uint32, tag = "3")] + bit_width: u32, + #[prost(uint64, tag = "4")] + seed: u64, + #[prost(uint32, tag = "5")] + num_rounds: u32, +} + +/// Extract TurboQuant metadata from a dtype. +/// +/// Returns an error when the dtype is not the TurboQuant extension type. +pub(crate) fn tq_metadata(dtype: &DType) -> VortexResult { + let ext_dtype = dtype + .as_extension_opt() + .ok_or_else(|| vortex_err!("expected a TurboQuant extension array, got {dtype}"))?; + + let metadata = ext_dtype + .metadata_opt::() + .ok_or_else(|| vortex_err!("expected a TurboQuant extension array, got {dtype}"))?; + + Ok(*metadata) +} + +pub(crate) fn tq_storage_dtype( + metadata: &TurboQuantMetadata, + row_nullability: Nullability, +) -> VortexResult { + let padded_dim = u32::try_from(tq_padded_dim(metadata.dimensions)?) 
+ .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit u32"))?; + + Ok(DType::Struct( + StructFields::new( + FieldNames::from([NORMS_FIELD, CODES_FIELD]), + vec![ + DType::Primitive(metadata.element_ptype, row_nullability), + DType::FixedSizeList( + Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)), + padded_dim, + row_nullability, + ), + ], + ), + row_nullability, + )) +} + +fn validate_tq_metadata(metadata: &TurboQuantMetadata) -> VortexResult<()> { + vortex_ensure!( + metadata.dimensions >= MIN_DIMENSION, + "TurboQuant dimensions must be >= {MIN_DIMENSION}, got {}", + metadata.dimensions + ); + vortex_ensure!( + metadata.element_ptype.is_float(), + "TurboQuant element_ptype must be a float, got {:?}", + metadata.element_ptype + ); + + tq_padded_dim(metadata.dimensions)?; + + TurboQuantConfig::try_new(metadata.bit_width, metadata.seed, metadata.num_rounds).map(|_| ()) +} + +fn validate_tq_storage_dtype(metadata: &TurboQuantMetadata, dtype: &DType) -> VortexResult<()> { + let DType::Struct(fields, _) = dtype else { + vortex_bail!("TurboQuant storage dtype must be a Struct, got {dtype}"); + }; + let expected_names = FieldNames::from([NORMS_FIELD, CODES_FIELD]); + vortex_ensure_eq!( + fields.names(), + &expected_names, + "TurboQuant storage fields must be {expected_names}, got {}", + fields.names() + ); + + let Some(norms_dtype) = fields.field(NORMS_FIELD) else { + vortex_bail!("TurboQuant storage missing {NORMS_FIELD} field"); + }; + let DType::Primitive(norms_ptype, _) = norms_dtype else { + vortex_bail!("TurboQuant {NORMS_FIELD} field must be primitive, got {norms_dtype}"); + }; + vortex_ensure_eq!( + norms_ptype, + metadata.element_ptype, + "TurboQuant {NORMS_FIELD} ptype must be {}, got {norms_ptype}", + metadata.element_ptype + ); + + let Some(codes_dtype) = fields.field(CODES_FIELD) else { + vortex_bail!("TurboQuant storage missing {CODES_FIELD} field"); + }; + let DType::FixedSizeList(element_dtype, list_size, _) = codes_dtype 
else { + vortex_bail!("TurboQuant {CODES_FIELD} field must be fixed-size-list, got {codes_dtype}"); + }; + let padded_dim = u32::try_from(tq_padded_dim(metadata.dimensions)?) + .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit u32"))?; + vortex_ensure_eq!( + list_size, + padded_dim, + "TurboQuant {CODES_FIELD} list size must be {padded_dim}, got {list_size}" + ); + vortex_ensure_eq!( + element_dtype.as_ref(), + &DType::Primitive(PType::U8, Nullability::NonNullable), + "TurboQuant {CODES_FIELD} elements must be non-nullable u8, got {element_dtype}" + ); + + Ok(()) +} + +impl fmt::Display for TurboQuantMetadata { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "element_ptype: {}, dimensions: {}, bit_width: {}, seed: {}, num_rounds: {}", + self.element_ptype, self.dimensions, self.bit_width, self.seed, self.num_rounds + ) + } +}