From 46f33e31150a182a9f567e29762384506cda767c Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 8 May 2026 14:14:53 -0400 Subject: [PATCH 1/5] refactor(tensor): prepare SORF and vector APIs for TurboQuant Signed-off-by: Connor Tsui --- vortex-tensor/public-api.lock | 4 + .../src/encodings/turboquant/compress.rs | 4 +- .../encodings/turboquant/tests/structural.rs | 4 +- vortex-tensor/src/scalar_fns/inner_product.rs | 4 +- .../src/scalar_fns/sorf_transform/rotation.rs | 165 +++++++++++------- .../src/scalar_fns/sorf_transform/tests.rs | 4 +- .../src/scalar_fns/sorf_transform/vtable.rs | 3 +- vortex-tensor/src/types/vector/mod.rs | 2 +- 8 files changed, 120 insertions(+), 70 deletions(-) diff --git a/vortex-tensor/public-api.lock b/vortex-tensor/public-api.lock index 1ebb5962a90..6fbc7b26fa3 100644 --- a/vortex-tensor/public-api.lock +++ b/vortex-tensor/public-api.lock @@ -528,6 +528,10 @@ pub fn vortex_tensor::vector::AnyVector::try_match<'a>(&'a vortex_array::dtype:: pub struct vortex_tensor::vector::Vector +impl vortex_tensor::vector::Vector + +pub fn vortex_tensor::vector::Vector::try_new_vector_array(vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult + impl core::clone::Clone for vortex_tensor::vector::Vector pub fn vortex_tensor::vector::Vector::clone(&self) -> vortex_tensor::vector::Vector diff --git a/vortex-tensor/src/encodings/turboquant/compress.rs b/vortex-tensor/src/encodings/turboquant/compress.rs index ca32faa6ec9..e656ba18822 100644 --- a/vortex-tensor/src/encodings/turboquant/compress.rs +++ b/vortex-tensor/src/encodings/turboquant/compress.rs @@ -205,8 +205,8 @@ fn turboquant_quantize_core( let dimension = fsl.list_size() as usize; let num_rows = fsl.len(); - let rotation = SorfMatrix::try_new(seed, dimension, num_rounds as usize)?; - let padded_dim = rotation.padded_dim(); + let padded_dim = dimension.next_power_of_two(); + let rotation = SorfMatrix::try_new_padded(padded_dim, num_rounds as usize, seed)?; let padded_dim_u32 = u32::try_from(padded_dim).vortex_expect("padded_dim stays representable as u32"); diff --git a/vortex-tensor/src/encodings/turboquant/tests/structural.rs b/vortex-tensor/src/encodings/turboquant/tests/structural.rs index 3913cf3d8fe..a9ec822cb2a 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/structural.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/structural.rs @@ -259,8 +259,8 @@ fn sorf_transform_roundtrip_isolation() -> VortexResult<()> { } // Forward transform + quantize (mimicking what turboquant_quantize_core does). 
- let rotation = SorfMatrix::try_new(seed, dim, num_rounds as usize)?; - let padded_dim = rotation.padded_dim(); + let padded_dim = dim.next_power_of_two(); + let rotation = SorfMatrix::try_new_padded(padded_dim, num_rounds as usize, seed)?; let centroids = compute_or_get_centroids(padded_dim as u32, 8)?; let boundaries = compute_centroid_boundaries(&centroids); diff --git a/vortex-tensor/src/scalar_fns/inner_product.rs b/vortex-tensor/src/scalar_fns/inner_product.rs index f34cb257dbd..2386dc03e53 100644 --- a/vortex-tensor/src/scalar_fns/inner_product.rs +++ b/vortex-tensor/src/scalar_fns/inner_product.rs @@ -366,7 +366,7 @@ impl InnerProduct { let mut padded_query = vec![0.0f32; padded_dim]; padded_query[..dim].copy_from_slice(flat.as_slice::<f32>()); - let rotation = SorfMatrix::try_new(seed, dim, num_rounds)?; + let rotation = SorfMatrix::try_new_padded(padded_dim, num_rounds, seed)?; let mut rotated_query = vec![0.0f32; padded_dim]; rotation.rotate(&padded_query, &mut rotated_query); @@ -930,7 +930,7 @@ mod tests { seed: u64, num_rounds: u8, ) -> VortexResult<Vec<f32>> { - let rotation = SorfMatrix::try_new(seed, dim, num_rounds as usize)?; + let rotation = SorfMatrix::try_new_padded(padded_dim, num_rounds as usize, seed)?; let mut padded = vec![0.0f32; padded_dim]; let mut rotated = vec![0.0f32; padded_dim]; let mut out = Vec::with_capacity(num_rows * dim); diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/rotation.rs b/vortex-tensor/src/scalar_fns/sorf_transform/rotation.rs index ff8aebd0f11..b416eb4a7e6 100644 --- a/vortex-tensor/src/scalar_fns/sorf_transform/rotation.rs +++ b/vortex-tensor/src/scalar_fns/sorf_transform/rotation.rs @@ -77,9 +77,25 @@ impl SorfMatrix { /// round-major, block-major order, with each `u64` contributing 64 sign bits in /// least-significant-bit-first order. pub fn try_new(seed: u64, dimensions: usize, num_rounds: usize) -> VortexResult<Self> { + Self::try_new_padded(dimensions.next_power_of_two(), num_rounds, seed) + } + + /// Create a new structured Walsh-Hadamard-based orthogonal transform for a padded dimension. + /// + /// `padded_dimensions` must already be a power of two. Callers that start from an unpadded + /// logical dimension should call [`Self::try_new`] instead. + pub(crate) fn try_new_padded( + padded_dimensions: usize, + num_rounds: usize, + seed: u64, + ) -> VortexResult<Self> { vortex_ensure!(num_rounds >= 1, "num_rounds must be >= 1, got {num_rounds}"); + vortex_ensure!( + padded_dimensions.is_power_of_two(), + "padded_dimensions must be a power of two, got {padded_dimensions}" + ); - let padded_dim = dimensions.next_power_of_two(); + let padded_dim = padded_dimensions; let sign_masks = gen_sign_masks_from_seed(seed, padded_dim, num_rounds); // Compute in f64 for precision, then store as f32 since the WHT operates on f32 buffers. @@ -132,8 +148,7 @@ impl SorfMatrix { /// Apply the forward structured transform: `norm · H · D_k · ... · H · D₁ · x`.
fn apply_srht(&self, buf: &mut [f32]) { for round in 0..self.num_rounds { - let offset = round * self.padded_dim; - apply_signs_xor(buf, &self.sign_masks[offset..offset + self.padded_dim]); + self.apply_signs_xor(buf, round); walsh_hadamard_transform(buf); } @@ -148,14 +163,24 @@ impl SorfMatrix { fn apply_inverse_srht(&self, buf: &mut [f32]) { for round in (0..self.num_rounds).rev() { walsh_hadamard_transform(buf); - let offset = round * self.padded_dim; - apply_signs_xor(buf, &self.sign_masks[offset..offset + self.padded_dim]); + self.apply_signs_xor(buf, round); } let norm = self.norm_factor; buf.iter_mut().for_each(|val| *val *= norm); } + /// Apply one round's sign masks via XOR on the IEEE 754 sign bit. + /// + /// This is branchless and auto-vectorizes into `vpxor` (x86) / `veor` (ARM). Equivalent to + /// multiplying each element by +/-1.0, but avoids FP dependency chains. + fn apply_signs_xor(&self, buf: &mut [f32], round: usize) { + let masks = &self.sign_masks[round * self.padded_dim..][..self.padded_dim]; + for (val, &mask) in buf.iter_mut().zip(masks.iter()) { + *val = f32::from_bits(val.to_bits() ^ mask); + } + } + /// Export the sign vectors as a flat `Vec` of 0/1 values in inverse application order /// `[D_k | ... | D₁]`. /// @@ -263,16 +288,6 @@ fn sign_mask_from_word(word: u64, bit_idx: usize) -> u32 { } } -/// Apply sign masks via XOR on the IEEE 754 sign bit. -/// -/// This is branchless and auto-vectorizes into `vpxor` (x86) / `veor` (ARM). Equivalent to -/// multiplying each element by +/-1.0, but avoids FP dependency chains. -fn apply_signs_xor(buf: &mut [f32], masks: &[u32]) { - for (val, &mask) in buf.iter_mut().zip(masks.iter()) { - *val = f32::from_bits(val.to_bits() ^ mask); - } -} - /// In-place Fast Walsh-Hadamard Transform (FWHT), unnormalized and iterative. /// /// Input length must be a power of 2. 
Runs in O(n log n) via `log2(n)` stages of `n / 2` @@ -327,14 +342,24 @@ mod tests { .collect() } + fn dim_to_usize(dim: u32) -> usize { + usize::try_from(dim).unwrap() + } + + fn rounds_to_usize(num_rounds: u8) -> usize { + usize::from(num_rounds) + } + #[test] fn deterministic_from_seed() -> VortexResult<()> { - let r1 = SorfMatrix::try_new(42, 64, 3)?; - let r2 = SorfMatrix::try_new(42, 64, 3)?; + let dim = dim_to_usize(64u32); + let num_rounds = rounds_to_usize(3u8); + let r1 = SorfMatrix::try_new(42u64, dim, num_rounds)?; + let r2 = SorfMatrix::try_new(42u64, dim, num_rounds)?; let pd = r1.padded_dim(); let mut input = vec![0.0f32; pd]; - for i in 0..64 { + for i in 0..dim { input[i] = i as f32; } let mut out1 = vec![0.0f32; pd]; @@ -349,15 +374,19 @@ mod tests { #[test] fn export_inverse_signs_matches_golden_words() -> VortexResult<()> { - let rot = SorfMatrix::try_new(42, 64, 2)?; + let dim = dim_to_usize(64u32); + let num_rounds = rounds_to_usize(2u8); + let seed = 42u64; + let rot = SorfMatrix::try_new(seed, dim, num_rounds)?; + let padded_dim = rot.padded_dim(); let actual = rot.export_inverse_signs_u8(); - let mut rng = SplitMix64::new(42); + let mut rng = SplitMix64::new(seed); let round0_word = rng.next_u64(); let round1_word = rng.next_u64(); - let mut expected = Vec::with_capacity(128); - expected.extend(unpack_sign_bits(round1_word, 64)); - expected.extend(unpack_sign_bits(round0_word, 64)); + let mut expected = Vec::with_capacity(num_rounds * padded_dim); + expected.extend(unpack_sign_bits(round1_word, padded_dim)); + expected.extend(unpack_sign_bits(round0_word, padded_dim)); assert_eq!(actual, expected); Ok(()) @@ -365,25 +394,38 @@ mod tests { #[test] fn one_word_generates_64_signs_lsb_first() { - let masks = gen_sign_masks_from_seed(42, 64, 1); - assert_eq!(masks.len(), 64); + let seed = 42u64; + let padded_dim = dim_to_usize(64u32); + let num_rounds = rounds_to_usize(1u8); + let masks = gen_sign_masks_from_seed(seed, padded_dim, num_rounds); + assert_eq!(masks.len(), padded_dim); - let mut rng = SplitMix64::new(42); + let mut rng = SplitMix64::new(seed); let word = rng.next_u64(); - let expected: Vec<_> = (0..64) + let expected: Vec<_> = (0..padded_dim) .map(|bit_idx| sign_mask_from_word(word, bit_idx)) .collect(); assert_eq!(masks, expected); } + #[test] + fn accepts_non_power_of_two_dimensions() -> VortexResult<()> { + let rot = SorfMatrix::try_new(42u64, dim_to_usize(100u32), rounds_to_usize(3u8))?; + assert_eq!(rot.padded_dim(), 128); + Ok(()) + } + #[test] fn tail_block_uses_only_required_bits() { - let masks = gen_sign_masks_from_seed(42, 32, 1); - assert_eq!(masks.len(), 32); + let seed = 42u64; + let padded_dim = dim_to_usize(32u32); + let num_rounds = rounds_to_usize(1u8); + let masks = gen_sign_masks_from_seed(seed, padded_dim, num_rounds); + assert_eq!(masks.len(), padded_dim); - let mut rng = SplitMix64::new(42); + let mut rng = SplitMix64::new(seed); let word = rng.next_u64(); - let expected: Vec<_> = (0..32) + let expected: Vec<_> = (0..padded_dim) .map(|bit_idx| sign_mask_from_word(word, bit_idx)) .collect(); assert_eq!(masks, expected); @@ -392,19 +434,21 @@ mod tests { /// Verify roundtrip is exact to f32 precision across many dimensions and round counts, /// including non-power-of-two dimensions that require padding. 
#[rstest] - #[case(32, 3)] - #[case(64, 3)] - #[case(100, 3)] - #[case(128, 1)] - #[case(128, 2)] - #[case(128, 3)] - #[case(128, 5)] - #[case(256, 3)] - #[case(512, 3)] - #[case(768, 3)] - #[case(1024, 3)] - fn roundtrip_exact(#[case] dim: usize, #[case] num_rounds: usize) -> VortexResult<()> { - let rot = SorfMatrix::try_new(42, dim, num_rounds)?; + #[case(32u32, 3u8)] + #[case(64u32, 3u8)] + #[case(100u32, 3u8)] + #[case(128u32, 1u8)] + #[case(128u32, 2u8)] + #[case(128u32, 3u8)] + #[case(128u32, 5u8)] + #[case(256u32, 3u8)] + #[case(512u32, 3u8)] + #[case(768u32, 3u8)] + #[case(1024u32, 3u8)] + fn roundtrip_exact(#[case] dim: u32, #[case] num_rounds: u8) -> VortexResult<()> { + let dim = dim_to_usize(dim); + let num_rounds = rounds_to_usize(num_rounds); + let rot = SorfMatrix::try_new(42u64, dim, num_rounds)?; let padded_dim = rot.padded_dim(); let mut input = vec![0.0f32; padded_dim]; @@ -435,12 +479,14 @@ mod tests { /// Verify norm preservation across dimensions and round counts. #[rstest] - #[case(128, 1)] - #[case(128, 3)] - #[case(128, 5)] - #[case(768, 3)] - fn preserves_norm(#[case] dim: usize, #[case] num_rounds: usize) -> VortexResult<()> { - let rot = SorfMatrix::try_new(7, dim, num_rounds)?; + #[case(128u32, 1u8)] + #[case(128u32, 3u8)] + #[case(128u32, 5u8)] + #[case(768u32, 3u8)] + fn preserves_norm(#[case] dim: u32, #[case] num_rounds: u8) -> VortexResult<()> { + let dim = dim_to_usize(dim); + let num_rounds = rounds_to_usize(num_rounds); + let rot = SorfMatrix::try_new(42u64, dim, num_rounds)?; let padded_dim = rot.padded_dim(); let mut input = vec![0.0f32; padded_dim]; @@ -465,16 +511,15 @@ mod tests { /// Verify that export -> [`from_u8_slice`] produces identical transform output. #[rstest] - #[case(64, 3)] - #[case(128, 1)] - #[case(128, 3)] - #[case(128, 5)] - #[case(768, 3)] - fn sign_export_import_roundtrip( - #[case] dim: usize, - #[case] num_rounds: usize, - ) -> VortexResult<()> { - let rot = SorfMatrix::try_new(42, dim, num_rounds)?; + #[case(64u32, 3u8)] + #[case(128u32, 1u8)] + #[case(128u32, 3u8)] + #[case(128u32, 5u8)] + #[case(768u32, 3u8)] + fn sign_export_import_roundtrip(#[case] dim: u32, #[case] num_rounds: u8) -> VortexResult<()> { + let dim = dim_to_usize(dim); + let num_rounds = rounds_to_usize(num_rounds); + let rot = SorfMatrix::try_new(42u64, dim, num_rounds)?; let padded_dim = rot.padded_dim(); let signs_u8 = rot.export_inverse_signs_u8(); diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs b/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs index 46abc66db71..c3aaf8d3cd5 100644 --- a/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs +++ b/vortex-tensor/src/scalar_fns/sorf_transform/tests.rs @@ -65,8 +65,8 @@ fn forward_rotate_and_quantize( } } - let rotation = SorfMatrix::try_new(seed, dim, num_rounds)?; - let padded_dim = rotation.padded_dim(); + let padded_dim = dim.next_power_of_two(); + let rotation = SorfMatrix::try_new_padded(padded_dim, num_rounds, seed)?; let centroids = compute_or_get_centroids(padded_dim as u32, bit_width)?; let boundaries = compute_centroid_boundaries(¢roids); diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs b/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs index 827f8e6a796..0932c4cc498 100644 --- a/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs +++ b/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs @@ -164,7 +164,8 @@ impl ScalarFnVTable for SorfTransform { let f32_elements = elements_prim.into_buffer::(); // Reconstruct the orthogonal transform matrix from 
the seed. - let rotation = SorfMatrix::try_new(options.seed, dim, options.num_rounds as usize)?; + let rotation = + SorfMatrix::try_new_padded(padded_dim, options.num_rounds as usize, options.seed)?; // Inverse transform each row, truncate to original dimension, cast to target type. match_each_float_ptype!(options.element_ptype, |T| { diff --git a/vortex-tensor/src/types/vector/mod.rs b/vortex-tensor/src/types/vector/mod.rs index 3763220fc82..a66aaf76982 100644 --- a/vortex-tensor/src/types/vector/mod.rs +++ b/vortex-tensor/src/types/vector/mod.rs @@ -49,7 +49,7 @@ impl Vector { /// # Errors /// /// Returns an error if the [`Vector`] extension dtype rejects the storage array. - pub(crate) fn try_new_vector_array(storage: ArrayRef) -> VortexResult { + pub fn try_new_vector_array(storage: ArrayRef) -> VortexResult { ExtensionArray::try_new_from_vtable(Vector, EmptyMetadata, storage) .map(|ext| ext.into_array()) } From d61d4a3a70b0100d395a1cf38af51b096b50f837 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 8 May 2026 14:16:06 -0400 Subject: [PATCH 2/5] feat: add vortex-turboquant crate Signed-off-by: Connor Tsui --- Cargo.lock | 20 + Cargo.toml | 2 + vortex-turboquant/Cargo.toml | 35 ++ vortex-turboquant/public-api.lock | 211 ++++++++++ vortex-turboquant/src/centroids.rs | 346 +++++++++++++++ vortex-turboquant/src/config.rs | 84 ++++ vortex-turboquant/src/lib.rs | 87 ++++ vortex-turboquant/src/scalar_fns/metadata.rs | 47 +++ vortex-turboquant/src/scalar_fns/mod.rs | 11 + vortex-turboquant/src/scalar_fns/pack.rs | 177 ++++++++ vortex-turboquant/src/scalar_fns/unpack.rs | 213 ++++++++++ vortex-turboquant/src/sorf/mod.rs | 7 + vortex-turboquant/src/sorf/splitmix64.rs | 76 ++++ vortex-turboquant/src/sorf/transform.rs | 419 +++++++++++++++++++ vortex-turboquant/src/tests/file.rs | 73 ++++ vortex-turboquant/src/tests/malformed.rs | 189 +++++++++ vortex-turboquant/src/tests/metadata.rs | 173 ++++++++ vortex-turboquant/src/tests/mod.rs | 153 +++++++ vortex-turboquant/src/tests/pack_unpack.rs | 258 ++++++++++++ vortex-turboquant/src/tests/parity.rs | 38 ++ vortex-turboquant/src/tests/scalar_fns.rs | 105 +++++ vortex-turboquant/src/vector/mod.rs | 22 + vortex-turboquant/src/vector/normalize.rs | 146 +++++++ vortex-turboquant/src/vector/pack.rs | 94 +++++ vortex-turboquant/src/vector/quantize.rs | 150 +++++++ vortex-turboquant/src/vector/storage.rs | 167 ++++++++ vortex-turboquant/src/vector/unpack.rs | 166 ++++++++ vortex-turboquant/src/vtable.rs | 233 +++++++++++ 28 files changed, 3702 insertions(+) create mode 100644 vortex-turboquant/Cargo.toml create mode 100644 vortex-turboquant/public-api.lock create mode 100644 vortex-turboquant/src/centroids.rs create mode 100644 vortex-turboquant/src/config.rs create mode 100644 vortex-turboquant/src/lib.rs create mode 100644 vortex-turboquant/src/scalar_fns/metadata.rs create mode 100644 vortex-turboquant/src/scalar_fns/mod.rs create mode 100644 vortex-turboquant/src/scalar_fns/pack.rs create mode 100644 vortex-turboquant/src/scalar_fns/unpack.rs create mode 100644 vortex-turboquant/src/sorf/mod.rs create mode 100644 vortex-turboquant/src/sorf/splitmix64.rs create mode 100644 vortex-turboquant/src/sorf/transform.rs create mode 100644 vortex-turboquant/src/tests/file.rs create mode 100644 vortex-turboquant/src/tests/malformed.rs create mode 100644 vortex-turboquant/src/tests/metadata.rs create mode 100644 vortex-turboquant/src/tests/mod.rs create mode 100644 vortex-turboquant/src/tests/pack_unpack.rs create mode 100644 
vortex-turboquant/src/tests/parity.rs create mode 100644 vortex-turboquant/src/tests/scalar_fns.rs create mode 100644 vortex-turboquant/src/vector/mod.rs create mode 100644 vortex-turboquant/src/vector/normalize.rs create mode 100644 vortex-turboquant/src/vector/pack.rs create mode 100644 vortex-turboquant/src/vector/quantize.rs create mode 100644 vortex-turboquant/src/vector/storage.rs create mode 100644 vortex-turboquant/src/vector/unpack.rs create mode 100644 vortex-turboquant/src/vtable.rs diff --git a/Cargo.lock b/Cargo.lock index a2c805220cc..8f4d775e035 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11251,6 +11251,26 @@ dependencies = [ "web-sys", ] +[[package]] +name = "vortex-turboquant" +version = "0.1.0" +dependencies = [ + "half", + "num-traits", + "prost 0.14.3", + "rstest", + "vortex-array", + "vortex-buffer", + "vortex-error", + "vortex-file", + "vortex-io", + "vortex-layout", + "vortex-mask", + "vortex-session", + "vortex-tensor", + "vortex-utils", +] + [[package]] name = "vortex-utils" version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml index f8726a6f265..58149ae7cb5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ members = [ "vortex-proto", "vortex-array", "vortex-tensor", + "vortex-turboquant", "vortex-compressor", "vortex-btrblocks", "vortex-layout", @@ -296,6 +297,7 @@ vortex-sequence = { version = "0.1.0", path = "encodings/sequence", default-feat vortex-session = { version = "0.1.0", path = "./vortex-session", default-features = false } vortex-sparse = { version = "0.1.0", path = "./encodings/sparse", default-features = false } vortex-tensor = { version = "0.1.0", path = "./vortex-tensor", default-features = false } +vortex-turboquant = { version = "0.1.0", path = "./vortex-turboquant", default-features = false } vortex-utils = { version = "0.1.0", path = "./vortex-utils", default-features = false } vortex-zigzag = { version = "0.1.0", path = "./encodings/zigzag", default-features = false } vortex-zstd = { version = "0.1.0", path = "./encodings/zstd", default-features = false } diff --git a/vortex-turboquant/Cargo.toml b/vortex-turboquant/Cargo.toml new file mode 100644 index 00000000000..f8148d8ff8c --- /dev/null +++ b/vortex-turboquant/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "vortex-turboquant" +authors = { workspace = true } +categories = { workspace = true } +description = "TurboQuant vector extension type" +edition = { workspace = true } +homepage = { workspace = true } +include = { workspace = true } +keywords = { workspace = true } +license = { workspace = true } +readme = { workspace = true } +repository = { workspace = true } +rust-version = { workspace = true } +version = { workspace = true } + +[lints] +workspace = true + +[dependencies] +half = { workspace = true } +num-traits = { workspace = true } +prost = { workspace = true } +vortex-array = { workspace = true } +vortex-buffer = { workspace = true } +vortex-error = { workspace = true } +vortex-mask = { workspace = true } +vortex-session = { workspace = true } +vortex-tensor = { workspace = true } +vortex-utils = { workspace = true, features = ["dashmap"] } + +[dev-dependencies] +rstest = { workspace = true } +vortex-file = { workspace = true } +vortex-io = { workspace = true } +vortex-layout = { workspace = true } diff --git a/vortex-turboquant/public-api.lock b/vortex-turboquant/public-api.lock new file mode 100644 index 00000000000..e18dcffe235 --- /dev/null +++ b/vortex-turboquant/public-api.lock @@ -0,0 +1,211 @@ +pub mod vortex_turboquant + +pub struct vortex_turboquant::TQPack 
+ +impl vortex_turboquant::TQPack + +pub fn vortex_turboquant::TQPack::new(&vortex_turboquant::TurboQuantConfig) -> vortex_array::scalar_fn::typed::TypedScalarFnInstance + +pub fn vortex_turboquant::TQPack::try_new_array(vortex_array::array::erased::ArrayRef, &vortex_turboquant::TurboQuantConfig, usize) -> vortex_error::VortexResult + +impl core::clone::Clone for vortex_turboquant::TQPack + +pub fn vortex_turboquant::TQPack::clone(&self) -> vortex_turboquant::TQPack + +impl vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable for vortex_turboquant::TQPack + +pub fn vortex_turboquant::TQPack::deserialize(&self, &vortex_array::dtype::DType, usize, &[u8], &dyn vortex_array::serde::ArrayChildren, &vortex_session::VortexSession) -> vortex_error::VortexResult> + +pub fn vortex_turboquant::TQPack::serialize(&self, &vortex_array::arrays::scalar_fn::vtable::ScalarFnArrayView<'_, Self>, &vortex_session::VortexSession) -> vortex_error::VortexResult>> + +impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_turboquant::TQPack + +pub type vortex_turboquant::TQPack::Options = vortex_turboquant::TurboQuantConfig + +pub fn vortex_turboquant::TQPack::arity(&self, &Self::Options) -> vortex_array::scalar_fn::vtable::Arity + +pub fn vortex_turboquant::TQPack::child_name(&self, &Self::Options, usize) -> vortex_array::scalar_fn::vtable::ChildName + +pub fn vortex_turboquant::TQPack::deserialize(&self, &[u8], &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_turboquant::TQPack::execute(&self, &Self::Options, &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult + +pub fn vortex_turboquant::TQPack::fmt_sql(&self, &Self::Options, &vortex_array::expr::expression::Expression, &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +pub fn vortex_turboquant::TQPack::id(&self) -> vortex_array::scalar_fn::ScalarFnId + +pub fn vortex_turboquant::TQPack::is_fallible(&self, &Self::Options) -> bool + +pub fn vortex_turboquant::TQPack::is_null_sensitive(&self, &Self::Options) -> bool + +pub fn vortex_turboquant::TQPack::return_dtype(&self, &Self::Options, &[vortex_array::dtype::DType]) -> vortex_error::VortexResult + +pub fn vortex_turboquant::TQPack::serialize(&self, &Self::Options) -> vortex_error::VortexResult>> + +pub fn vortex_turboquant::TQPack::validity(&self, &Self::Options, &vortex_array::expr::expression::Expression) -> vortex_error::VortexResult> + +pub struct vortex_turboquant::TQUnpack + +impl vortex_turboquant::TQUnpack + +pub fn vortex_turboquant::TQUnpack::new(&vortex_turboquant::TurboQuantConfig) -> vortex_array::scalar_fn::typed::TypedScalarFnInstance + +pub fn vortex_turboquant::TQUnpack::try_new_array(vortex_array::array::erased::ArrayRef, &vortex_turboquant::TurboQuantConfig, usize) -> vortex_error::VortexResult + +impl core::clone::Clone for vortex_turboquant::TQUnpack + +pub fn vortex_turboquant::TQUnpack::clone(&self) -> vortex_turboquant::TQUnpack + +impl vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable for vortex_turboquant::TQUnpack + +pub fn vortex_turboquant::TQUnpack::deserialize(&self, &vortex_array::dtype::DType, usize, &[u8], &dyn vortex_array::serde::ArrayChildren, &vortex_session::VortexSession) -> vortex_error::VortexResult> + +pub fn vortex_turboquant::TQUnpack::serialize(&self, &vortex_array::arrays::scalar_fn::vtable::ScalarFnArrayView<'_, Self>, &vortex_session::VortexSession) -> vortex_error::VortexResult>> + +impl 
vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_turboquant::TQUnpack + +pub type vortex_turboquant::TQUnpack::Options = vortex_turboquant::TurboQuantConfig + +pub fn vortex_turboquant::TQUnpack::arity(&self, &Self::Options) -> vortex_array::scalar_fn::vtable::Arity + +pub fn vortex_turboquant::TQUnpack::child_name(&self, &Self::Options, usize) -> vortex_array::scalar_fn::vtable::ChildName + +pub fn vortex_turboquant::TQUnpack::deserialize(&self, &[u8], &vortex_session::VortexSession) -> vortex_error::VortexResult + +pub fn vortex_turboquant::TQUnpack::execute(&self, &Self::Options, &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult + +pub fn vortex_turboquant::TQUnpack::fmt_sql(&self, &Self::Options, &vortex_array::expr::expression::Expression, &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +pub fn vortex_turboquant::TQUnpack::id(&self) -> vortex_array::scalar_fn::ScalarFnId + +pub fn vortex_turboquant::TQUnpack::is_fallible(&self, &Self::Options) -> bool + +pub fn vortex_turboquant::TQUnpack::is_null_sensitive(&self, &Self::Options) -> bool + +pub fn vortex_turboquant::TQUnpack::return_dtype(&self, &Self::Options, &[vortex_array::dtype::DType]) -> vortex_error::VortexResult + +pub fn vortex_turboquant::TQUnpack::serialize(&self, &Self::Options) -> vortex_error::VortexResult>> + +pub fn vortex_turboquant::TQUnpack::validity(&self, &Self::Options, &vortex_array::expr::expression::Expression) -> vortex_error::VortexResult> + +pub struct vortex_turboquant::TurboQuant + +impl core::clone::Clone for vortex_turboquant::TurboQuant + +pub fn vortex_turboquant::TurboQuant::clone(&self) -> vortex_turboquant::TurboQuant + +impl core::cmp::Eq for vortex_turboquant::TurboQuant + +impl core::cmp::PartialEq for vortex_turboquant::TurboQuant + +pub fn vortex_turboquant::TurboQuant::eq(&self, &vortex_turboquant::TurboQuant) -> bool + +impl core::default::Default for vortex_turboquant::TurboQuant + +pub fn vortex_turboquant::TurboQuant::default() -> vortex_turboquant::TurboQuant + +impl core::fmt::Debug for vortex_turboquant::TurboQuant + +pub fn vortex_turboquant::TurboQuant::fmt(&self, &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::hash::Hash for vortex_turboquant::TurboQuant + +pub fn vortex_turboquant::TurboQuant::hash<__H: core::hash::Hasher>(&self, &mut __H) + +impl core::marker::StructuralPartialEq for vortex_turboquant::TurboQuant + +impl vortex_array::dtype::extension::vtable::ExtVTable for vortex_turboquant::TurboQuant + +pub type vortex_turboquant::TurboQuant::Metadata = vortex_turboquant::TurboQuantMetadata + +pub type vortex_turboquant::TurboQuant::NativeValue<'a> = &'a vortex_array::scalar::scalar_value::ScalarValue + +pub fn vortex_turboquant::TurboQuant::deserialize_metadata(&self, &[u8]) -> vortex_error::VortexResult + +pub fn vortex_turboquant::TurboQuant::id(&self) -> vortex_array::dtype::extension::ExtId + +pub fn vortex_turboquant::TurboQuant::serialize_metadata(&self, &Self::Metadata) -> vortex_error::VortexResult> + +pub fn vortex_turboquant::TurboQuant::unpack_native<'a>(&'a vortex_array::dtype::extension::typed::ExtDType, &'a vortex_array::scalar::scalar_value::ScalarValue) -> vortex_error::VortexResult + +pub fn vortex_turboquant::TurboQuant::validate_dtype(&vortex_array::dtype::extension::typed::ExtDType) -> vortex_error::VortexResult<()> + +pub struct vortex_turboquant::TurboQuantConfig + +impl vortex_turboquant::TurboQuantConfig + +pub fn 
vortex_turboquant::TurboQuantConfig::bit_width(&self) -> u8 + +pub fn vortex_turboquant::TurboQuantConfig::num_rounds(&self) -> u8 + +pub fn vortex_turboquant::TurboQuantConfig::seed(&self) -> u64 + +pub fn vortex_turboquant::TurboQuantConfig::try_new(u8, u64, u8) -> vortex_error::VortexResult + +impl core::clone::Clone for vortex_turboquant::TurboQuantConfig + +pub fn vortex_turboquant::TurboQuantConfig::clone(&self) -> vortex_turboquant::TurboQuantConfig + +impl core::cmp::Eq for vortex_turboquant::TurboQuantConfig + +impl core::cmp::PartialEq for vortex_turboquant::TurboQuantConfig + +pub fn vortex_turboquant::TurboQuantConfig::eq(&self, &vortex_turboquant::TurboQuantConfig) -> bool + +impl core::default::Default for vortex_turboquant::TurboQuantConfig + +pub fn vortex_turboquant::TurboQuantConfig::default() -> Self + +impl core::fmt::Debug for vortex_turboquant::TurboQuantConfig + +pub fn vortex_turboquant::TurboQuantConfig::fmt(&self, &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::fmt::Display for vortex_turboquant::TurboQuantConfig + +pub fn vortex_turboquant::TurboQuantConfig::fmt(&self, &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::hash::Hash for vortex_turboquant::TurboQuantConfig + +pub fn vortex_turboquant::TurboQuantConfig::hash<__H: core::hash::Hasher>(&self, &mut __H) + +impl core::marker::StructuralPartialEq for vortex_turboquant::TurboQuantConfig + +pub struct vortex_turboquant::TurboQuantMetadata + +pub vortex_turboquant::TurboQuantMetadata::bit_width: u8 + +pub vortex_turboquant::TurboQuantMetadata::dimensions: u32 + +pub vortex_turboquant::TurboQuantMetadata::element_ptype: vortex_array::dtype::ptype::PType + +pub vortex_turboquant::TurboQuantMetadata::num_rounds: u8 + +pub vortex_turboquant::TurboQuantMetadata::seed: u64 + +impl core::clone::Clone for vortex_turboquant::TurboQuantMetadata + +pub fn vortex_turboquant::TurboQuantMetadata::clone(&self) -> vortex_turboquant::TurboQuantMetadata + +impl core::cmp::Eq for vortex_turboquant::TurboQuantMetadata + +impl core::cmp::PartialEq for vortex_turboquant::TurboQuantMetadata + +pub fn vortex_turboquant::TurboQuantMetadata::eq(&self, &vortex_turboquant::TurboQuantMetadata) -> bool + +impl core::fmt::Debug for vortex_turboquant::TurboQuantMetadata + +pub fn vortex_turboquant::TurboQuantMetadata::fmt(&self, &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::fmt::Display for vortex_turboquant::TurboQuantMetadata + +pub fn vortex_turboquant::TurboQuantMetadata::fmt(&self, &mut core::fmt::Formatter<'_>) -> core::fmt::Result + +impl core::hash::Hash for vortex_turboquant::TurboQuantMetadata + +pub fn vortex_turboquant::TurboQuantMetadata::hash<__H: core::hash::Hasher>(&self, &mut __H) + +impl core::marker::Copy for vortex_turboquant::TurboQuantMetadata + +impl core::marker::StructuralPartialEq for vortex_turboquant::TurboQuantMetadata + +pub fn vortex_turboquant::initialize(&vortex_session::VortexSession) diff --git a/vortex-turboquant/src/centroids.rs b/vortex-turboquant/src/centroids.rs new file mode 100644 index 00000000000..8499dfc397a --- /dev/null +++ b/vortex-turboquant/src/centroids.rs @@ -0,0 +1,346 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Max-Lloyd centroid computation for TurboQuant scalar quantizers. +//! +//! Pre-computes and caches optimal scalar quantizer centroids for the marginal distribution of +//! coordinates after a random orthogonal transform of a unit-norm vector. +//! +//! 
In high dimensions, each coordinate of a randomly transformed unit vector follows a +//! distribution proportional to `(1 - x^2)^((d-3)/2)` on `[-1, 1]`, which converges to +//! `N(0, 1/d)`. +//! +//! The Max-Lloyd algorithm finds optimal quantization centroids that minimize MSE for this +//! distribution. +//! +//! Centroids are not stored in TurboQuant arrays. They are deterministically derived from +//! `(padded_dim, bit_width)` and cached process-locally. +//! +//! The centroid model follows the random orthogonal transform marginal used by the TurboQuant +//! paper. This encoder applies a SORF-style structured transform instead of a dense random Gaussian +//! or orthogonal matrix, so paper-level error bounds should not be treated as verified for this +//! implementation without separate empirical validation. + +use std::sync::LazyLock; + +use vortex_buffer::Buffer; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; +use vortex_utils::aliases::dash_map::DashMap; + +use crate::config::MAX_BIT_WIDTH; +use crate::config::MIN_DIMENSION; + +// NB: Some of these numbers were arbitrarily chosen... + +/// The maximum iterations for Max-Lloyd algorithm when computing centroids. +const MAX_ITERATIONS: usize = 200; + +/// The Max-Lloyd convergence threshold for stopping early when computing centroids. +const CONVERGENCE_EPSILON: f64 = 1e-12; + +/// Number of numerical integration points for computing conditional expectations. +const INTEGRATION_POINTS: usize = 1000; + +/// Global centroid cache keyed by (dimension, bit_width). +static CENTROID_CACHE: LazyLock<DashMap<(u32, u8), Buffer<f32>>> = LazyLock::new(DashMap::default); + +/// Get or compute cached centroids for the given dimension and bit width. +/// +/// Returns `2^bit_width` centroids sorted in ascending order, representing optimal scalar +/// quantization levels for the coordinate distribution after a random orthogonal transform in +/// `dimension`-dimensional space. +pub(crate) fn compute_or_get_centroids(dimension: u32, bit_width: u8) -> VortexResult<Buffer<f32>> { + vortex_ensure!( + (1..=MAX_BIT_WIDTH).contains(&bit_width), + "TurboQuant bit_width must be 1-{}, got {bit_width}", + MAX_BIT_WIDTH + ); + vortex_ensure!( + dimension >= MIN_DIMENSION, + "TurboQuant dimension must be >= {}, got {dimension}", + MIN_DIMENSION + ); + + if let Some(centroids) = CENTROID_CACHE.get(&(dimension, bit_width)) { + return Ok(centroids.clone()); + } + + let centroids = max_lloyd_centroids(dimension, bit_width); + CENTROID_CACHE.insert((dimension, bit_width), centroids.clone()); + + Ok(centroids) +} + +// TODO(connor): It would potentially be more performant if this was modelled as const generic +// parameters to functions. +/// Half-integer exponent: represents `int_part + (if has_half { 0.5 } else { 0.0 })`. +/// +/// The marginal distribution exponent `(d-3)/2` is always an integer (when `d` is odd) or a +/// half-integer (when `d` is even). +/// +/// This type makes that invariant explicit and avoids floating-point comparison in the hot path. +#[derive(Clone, Copy, Debug)] +struct HalfIntExponent { + int_part: i32, + has_half: bool, +} + +impl HalfIntExponent { + /// Compute `(numerator) / 2` as a half-integer exponent. + /// + /// `numerator` is `d - 3` where `d` is the dimension (>= 2), so it can be negative. + fn from_numerator(numerator: i32) -> Self { + // Use Euclidean division to get floor division toward negative infinity.
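+ // For example: numerator = 125 (dimension 128) gives int_part = 62 and has_half = true, i.e. an exponent of 62.5; numerator = -1 (dimension 2) gives int_part = -1 and has_half = true, i.e. an exponent of -0.5.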
+ let int_part = numerator.div_euclid(2); + let has_half = numerator.rem_euclid(2) != 0; + Self { int_part, has_half } + } +} + +/// Compute optimal centroids via the Max-Lloyd (Lloyd-Max) algorithm. +/// +/// Operates on the marginal distribution of a single coordinate of a randomly transformed unit +/// vector in d dimensions. +/// +/// The probability distribution function is: +/// `f(x) = C_d * (1 - x^2)^((d-3)/2)` on `[-1, 1]` +/// where `C_d` is the normalizing constant. +fn max_lloyd_centroids(dimension: u32, bit_width: u8) -> Buffer { + debug_assert!((1..=MAX_BIT_WIDTH).contains(&bit_width)); + let num_centroids = 1usize << bit_width; + + // For the marginal distribution on [-1, 1], we use the exponent (d-3)/2. + let exponent = HalfIntExponent::from_numerator(dimension as i32 - 3); + + // Initialize centroids uniformly on [-1, 1]. + let mut centroids: Vec = (0..num_centroids) + .map(|idx| -1.0 + (2.0 * (idx as f64) + 1.0) / (num_centroids as f64)) + .collect(); + + let mut boundaries: Vec = vec![0.0; num_centroids + 1]; + for _ in 0..MAX_ITERATIONS { + // Compute decision boundaries (midpoints between adjacent centroids). + boundaries[0] = -1.0; + for idx in 0..num_centroids - 1 { + boundaries[idx + 1] = (centroids[idx] + centroids[idx + 1]) / 2.0; + } + boundaries[num_centroids] = 1.0; + + // Update each centroid to the conditional mean within its Voronoi cell. + let mut max_change = 0.0f64; + for idx in 0..num_centroids { + let lo = boundaries[idx]; + let hi = boundaries[idx + 1]; + let new_centroid = mean_between_centroids(lo, hi, exponent); + max_change = max_change.max((new_centroid - centroids[idx]).abs()); + centroids[idx] = new_centroid; + } + + if max_change < CONVERGENCE_EPSILON { + break; + } + } + + #[expect( + clippy::cast_possible_truncation, + reason = "all values are in [-1, 1] so this just loses precision" + )] + centroids.into_iter().map(|val| val as f32).collect() +} + +/// Compute the conditional mean of the coordinate distribution on interval [lo, hi]. +/// +/// Returns `E[X | lo <= X <= hi]` where X has PDF proportional to `(1 - x^2)^exponent` on [-1, 1]. +/// +/// Since there is no closed form for the integrals, we compute this numerically. +fn mean_between_centroids(lo: f64, hi: f64, exponent: HalfIntExponent) -> f64 { + if (hi - lo).abs() < 1e-15 { + return (lo + hi) / 2.0; + } + + let dx = (hi - lo) / INTEGRATION_POINTS as f64; + + let mut numerator = 0.0; + let mut denominator = 0.0; + + for step in 0..=INTEGRATION_POINTS { + let x_val = lo + (step as f64) * dx; + let weight = pdf_unnormalized(x_val, exponent); + + let trap_weight = if step == 0 || step == INTEGRATION_POINTS { + 0.5 + } else { + 1.0 + }; + + numerator += trap_weight * x_val * weight; + denominator += trap_weight * weight; + } + + if denominator.abs() < 1e-30 { + (lo + hi) / 2.0 + } else { + numerator / denominator + } +} + +/// Unnormalized PDF of the coordinate distribution: `(1 - x^2)^exponent`. +/// +/// Uses `powi` + `sqrt` instead of `powf` for the half-integer exponents that arise from `(d-3)/2`. +/// This is significantly faster than the general `powf` which goes through +/// `exp(exponent * ln(base))`. +fn pdf_unnormalized(x_val: f64, exponent: HalfIntExponent) -> f64 { + let base = (1.0 - x_val * x_val).max(0.0); + + if exponent.has_half { + // Half-integer exponent: base^(int_part) * sqrt(base). + base.powi(exponent.int_part) * base.sqrt() + } else { + // Integer exponent: use powi directly. 
+ base.powi(exponent.int_part) + } +} + +/// Precompute decision boundaries (midpoints between adjacent centroids). +/// +/// For `k` centroids, returns `k-1` boundaries. A value below `boundaries[0]` maps to centroid 0, a +/// value in `[boundaries[i-1], boundaries[i])` maps to centroid `i`, and a +/// value `>= boundaries[k-2]` maps to centroid `k-1`. +pub(crate) fn compute_centroid_boundaries(centroids: &[f32]) -> Vec { + centroids.windows(2).map(|w| (w[0] + w[1]) * 0.5).collect() +} + +/// Find the index of the nearest centroid using precomputed decision boundaries. +/// +/// `boundaries` must be the output of [`compute_centroid_boundaries`] for the corresponding +/// centroids. Uses binary search on the midpoints, avoiding distance comparisons +/// in the inner loop. +#[inline] +pub(crate) fn find_nearest_centroid(value: f32, boundaries: &[f32]) -> u8 { + debug_assert!( + boundaries.windows(2).all(|w| w[0] <= w[1]), + "boundaries must be sorted" + ); + debug_assert!( + boundaries.len() <= 256, // 1 << 8 + "too many boundaries" + ); + + #[expect( + clippy::cast_possible_truncation, + reason = "num_centroids <= 256 and partition_point will return at most 255" + )] + (boundaries.partition_point(|&b| b < value) as u8) +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + use vortex_error::VortexResult; + + use super::*; + + #[rstest] + #[case(128, 1, 2)] + #[case(128, 2, 4)] + #[case(128, 3, 8)] + #[case(128, 4, 16)] + #[case(768, 2, 4)] + #[case(1536, 3, 8)] + fn centroids_have_correct_count( + #[case] dim: u32, + #[case] bits: u8, + #[case] expected: usize, + ) -> VortexResult<()> { + let centroids = compute_or_get_centroids(dim, bits)?; + assert_eq!(centroids.len(), expected); + Ok(()) + } + + #[rstest] + #[case(128, 1)] + #[case(128, 2)] + #[case(128, 3)] + #[case(128, 4)] + #[case(768, 2)] + fn centroids_are_sorted(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> { + let centroids = compute_or_get_centroids(dim, bits)?; + for window in centroids.windows(2) { + assert!( + window[0] < window[1], + "centroids not sorted: {:?}", + centroids + ); + } + Ok(()) + } + + #[rstest] + #[case(128, 1)] + #[case(128, 2)] + #[case(256, 2)] + #[case(768, 2)] + fn centroids_are_symmetric(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> { + let centroids = compute_or_get_centroids(dim, bits)?; + let count = centroids.len(); + for idx in 0..count / 2 { + let diff = (centroids[idx] + centroids[count - 1 - idx]).abs(); + assert!( + diff < 1e-5, + "centroids not symmetric: c[{idx}]={}, c[{}]={}", + centroids[idx], + count - 1 - idx, + centroids[count - 1 - idx] + ); + } + Ok(()) + } + + #[rstest] + #[case(128, 1)] + #[case(128, 4)] + fn centroids_within_bounds(#[case] dim: u32, #[case] bits: u8) -> VortexResult<()> { + let centroids = compute_or_get_centroids(dim, bits)?; + for &val in centroids.iter() { + assert!( + (-1.0..=1.0).contains(&val), + "centroid out of [-1, 1]: {val}", + ); + } + Ok(()) + } + + #[test] + fn centroids_cached() -> VortexResult<()> { + let c1 = compute_or_get_centroids(128, 2)?; + let c2 = compute_or_get_centroids(128, 2)?; + assert_eq!(c1, c2); + Ok(()) + } + + #[test] + fn find_nearest_basic() -> VortexResult<()> { + let centroids = compute_or_get_centroids(128, 2)?; + let boundaries = compute_centroid_boundaries(¢roids); + assert_eq!(find_nearest_centroid(-1.0, &boundaries), 0); + + #[expect(clippy::cast_possible_truncation)] + let last_idx = (centroids.len() - 1) as u8; + assert_eq!(find_nearest_centroid(1.0, &boundaries), last_idx); + for (idx, &cv) in 
centroids.iter().enumerate() { + #[expect(clippy::cast_possible_truncation)] + let expected = idx as u8; + assert_eq!(find_nearest_centroid(cv, &boundaries), expected); + } + Ok(()) + } + + #[test] + fn rejects_invalid_params() { + assert!(compute_or_get_centroids(128, 0).is_err()); + assert!(compute_or_get_centroids(128, 9).is_err()); + assert!(compute_or_get_centroids(1, 2).is_err()); + assert!(compute_or_get_centroids(127, 2).is_err()); + } +} diff --git a/vortex-turboquant/src/config.rs b/vortex-turboquant/src/config.rs new file mode 100644 index 00000000000..8b5ed03e93c --- /dev/null +++ b/vortex-turboquant/src/config.rs @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::fmt; + +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; + +/// Minimum vector dimension for TurboQuant encoding. +/// +/// Note that this is not a theoretical minimum, it is mostly a practical one to limit the total +/// amount of distortion. +pub(crate) const MIN_DIMENSION: u32 = 128; + +/// Maximum supported number of bits per quantized coordinate. +pub(crate) const MAX_BIT_WIDTH: u8 = 8; + +/// Configuration for lossy TurboQuant packing. +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub struct TurboQuantConfig { + bit_width: u8, + seed: u64, + num_rounds: u8, +} + +impl TurboQuantConfig { + /// Build a TurboQuant configuration. + /// + /// # Errors + /// + /// Returns an error if `bit_width` is outside `1..=8` or `num_rounds` is zero. + pub fn try_new(bit_width: u8, seed: u64, num_rounds: u8) -> VortexResult { + vortex_ensure!( + (1..=MAX_BIT_WIDTH).contains(&bit_width), + "TurboQuant bit_width must be 1-{MAX_BIT_WIDTH}, got {bit_width}", + ); + vortex_ensure!( + num_rounds > 0, + "TurboQuant num_rounds must be > 0, got {num_rounds}" + ); + + Ok(Self { + bit_width, + seed, + num_rounds, + }) + } + + /// Bits per coordinate in the scalar quantizer codebook. + pub fn bit_width(&self) -> u8 { + self.bit_width + } + + /// Seed used to derive the deterministic SORF transform. + pub fn seed(&self) -> u64 { + self.seed + } + + /// Number of sign-diagonal plus Walsh-Hadamard rounds in the SORF transform. + pub fn num_rounds(&self) -> u8 { + self.num_rounds + } +} + +impl Default for TurboQuantConfig { + /// Defaults to 8 bits per coordinate, seed 42, and 3 SORF rounds. + fn default() -> Self { + Self { + bit_width: MAX_BIT_WIDTH, + seed: 42, + num_rounds: 3, + } + } +} + +impl fmt::Display for TurboQuantConfig { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "bit_width: {}, seed: {}, num_rounds: {}", + self.bit_width, self.seed, self.num_rounds + ) + } +} diff --git a/vortex-turboquant/src/lib.rs b/vortex-turboquant/src/lib.rs new file mode 100644 index 00000000000..3db71ceab6a --- /dev/null +++ b/vortex-turboquant/src/lib.rs @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant vector quantization extension type for Vortex. +//! +//! Implements a Stage 1 TurboQuant encoding ([arXiv:2504.19874], [RFC 0033]) for lossy compression +//! of high-dimensional vector data. The extension operates on +//! [`Vector`](vortex_tensor::vector::Vector) extension arrays, packing their `FixedSizeList` +//! storage into quantized codes after a structured orthogonal surrogate transform. +//! +//! [arXiv:2504.19874]: https://arxiv.org/abs/2504.19874 +//! [RFC 0033]: https://vortex-data.github.io/rfcs/rfc/0033.html +//! +//! 
# Overview +//! +//! TurboQuant minimizes mean-squared reconstruction error (1-8 bits per coordinate) +//! using MSE-optimal scalar quantization on coordinates of a transformed unit vector. +//! +//! The [`TQPack`] scalar function first computes and stores the original L2 norm for each vector +//! row, then normalizes each valid nonzero row internally before SORF transform and scalar +//! quantization. The [`TQUnpack`] scalar function dequantizes through deterministic centroids, +//! applies the inverse SORF transform, truncates back to the original dimension, and re-applies the +//! stored norm. +//! +//! The packed storage is a row-aligned extension tree: +//! +//! ```text +//! Extension( +//! Struct { +//! norms: Primitive<f32>, +//! codes: FixedSizeList<Primitive<u8>, padded_dim, vector_validity>, +//! } +//! ) +//! ``` +//! +//! Stored norms are authoritative for future TurboQuant-aware scalar functions. Decoded quantized +//! directions are not guaranteed to have unit norm after scalar quantization and inverse transform. +//! +//! # Source map +//! +//! Implementation details are documented next to the code that owns them: +//! +//! - `vector/storage.rs`: physical storage shape, full-length child arrays, and field-level +//! validity for null vectors. +//! - `vector/normalize.rs`: TurboQuant-local normalization and how it differs from the tensor +//! crate's null-row zeroing helper. +//! - `vector/quantize.rs`: SORF transform, centroid lookup, and why invalid rows are skipped rather +//! than quantized. +//! - `centroids.rs`: deterministic Max-Lloyd centroid computation and process-local caching. +//! - `sorf/`: the Walsh-Hadamard-based structured transform and the stable SplitMix64 sign stream. +//! +//! The current encoding is intentionally MSE-only. It does not yet implement the paper's QJL +//! residual correction for unbiased inner-product estimation, and it still uses internal +//! power-of-2 padding rather than the block decomposition proposed in RFC 0033. + +mod centroids; +mod config; +mod scalar_fns; +mod sorf; +mod vector; +mod vtable; + +pub use config::TurboQuantConfig; +pub use scalar_fns::TQPack; +pub use scalar_fns::TQUnpack; +pub use vtable::TurboQuant; +pub use vtable::TurboQuantMetadata; + +// TODO(connor): We need to somehow make sure that callers call `vortex_tensor::initialize` first. +/// Register the TurboQuant extension type with a Vortex session.
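+///
+/// A minimal registration sketch (illustrative only): the exact signature of
+/// `vortex_tensor::initialize` is assumed here to mirror this function, and obtaining the
+/// [`vortex_session::VortexSession`] is left to the caller.
+///
+/// ```ignore
+/// fn register(session: &vortex_session::VortexSession) {
+///     // Per the TODO above, register the tensor `Vector` extension type first.
+///     vortex_tensor::initialize(session);
+///     // Then register the TurboQuant dtype, the pack/unpack scalar functions, and the
+///     // unpack array plugin.
+///     vortex_turboquant::initialize(session);
+/// }
+/// ```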
+pub fn initialize(session: &vortex_session::VortexSession) { + use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayPlugin; + use vortex_array::dtype::session::DTypeSessionExt; + use vortex_array::scalar_fn::session::ScalarFnSessionExt; + use vortex_array::session::ArraySessionExt; + + session.dtypes().register(TurboQuant); + + session.scalar_fns().register(TQPack); + session.scalar_fns().register(TQUnpack); + + let session_arrays = session.arrays(); + session_arrays.register(ScalarFnArrayPlugin::new(TQUnpack)); +} + +#[cfg(test)] +mod tests; diff --git a/vortex-turboquant/src/scalar_fns/metadata.rs b/vortex-turboquant/src/scalar_fns/metadata.rs new file mode 100644 index 00000000000..f5eddfe51d9 --- /dev/null +++ b/vortex-turboquant/src/scalar_fns/metadata.rs @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use prost::Message; +use vortex_error::VortexResult; +use vortex_error::vortex_err; + +use crate::TurboQuantConfig; + +#[derive(Clone, PartialEq, Message)] +pub(super) struct TQScalarFnMetadata { + #[prost(uint32, tag = "1")] + bit_width: u32, + #[prost(uint64, tag = "2")] + seed: u64, + #[prost(uint32, tag = "3")] + num_rounds: u32, +} + +impl TQScalarFnMetadata { + pub(super) fn from_config(config: &TurboQuantConfig) -> Self { + Self { + bit_width: u32::from(config.bit_width()), + seed: config.seed(), + num_rounds: u32::from(config.num_rounds()), + } + } + + pub(super) fn to_config(&self) -> VortexResult { + let bit_width = u8::try_from(self.bit_width) + .map_err(|_| vortex_err!("TurboQuant bit_width does not fit u8"))?; + let num_rounds = u8::try_from(self.num_rounds) + .map_err(|_| vortex_err!("TurboQuant num_rounds does not fit u8"))?; + + TurboQuantConfig::try_new(bit_width, self.seed, num_rounds) + } +} + +pub(super) fn serialize_config(config: &TurboQuantConfig) -> Vec { + TQScalarFnMetadata::from_config(config).encode_to_vec() +} + +pub(super) fn deserialize_config(metadata: &[u8]) -> VortexResult { + TQScalarFnMetadata::decode(metadata) + .map_err(|e| vortex_err!("Failed to decode TurboQuant scalar function metadata: {e}"))? + .to_config() +} diff --git a/vortex-turboquant/src/scalar_fns/mod.rs b/vortex-turboquant/src/scalar_fns/mod.rs new file mode 100644 index 00000000000..9151ab9c1c7 --- /dev/null +++ b/vortex-turboquant/src/scalar_fns/mod.rs @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Scalar functions for lazy TurboQuant vector pack and unpack operations. + +mod metadata; +mod pack; +mod unpack; + +pub use pack::TQPack; +pub use unpack::TQUnpack; diff --git a/vortex-turboquant/src/scalar_fns/pack.rs b/vortex-turboquant/src/scalar_fns/pack.rs new file mode 100644 index 00000000000..7748952ea76 --- /dev/null +++ b/vortex-turboquant/src/scalar_fns/pack.rs @@ -0,0 +1,177 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant pack scalar function. 
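+//!
+//! A minimal construction sketch (illustrative; `vectors` stands for a `Vector` extension array
+//! of length `len`, and the config values are arbitrary):
+//!
+//! ```ignore
+//! use vortex_turboquant::{TQPack, TurboQuantConfig};
+//!
+//! // 4 bits per coordinate, seed 7, 3 SORF rounds.
+//! let config = TurboQuantConfig::try_new(4, 7, 3)?;
+//! // Packing is lazy: the returned ScalarFnArray quantizes only when executed.
+//! let packed = TQPack::try_new_array(vectors, &config, len)?;
+//! ```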
+ +use std::fmt; +use std::fmt::Formatter; + +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::arrays::ScalarFnArray; +use vortex_array::arrays::scalar_fn::ScalarFnArrayView; +use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts; +use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable; +use vortex_array::dtype::DType; +use vortex_array::dtype::extension::ExtDType; +use vortex_array::expr::Expression; +use vortex_array::scalar_fn::Arity; +use vortex_array::scalar_fn::ChildName; +use vortex_array::scalar_fn::ExecutionArgs; +use vortex_array::scalar_fn::ScalarFnId; +use vortex_array::scalar_fn::ScalarFnVTable; +use vortex_array::scalar_fn::TypedScalarFnInstance; +use vortex_array::serde::ArrayChildren; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_ensure; +use vortex_error::vortex_err; +use vortex_session::VortexSession; +use vortex_tensor::vector::AnyVector; + +use super::metadata::deserialize_config; +use super::metadata::serialize_config; +use crate::TurboQuantConfig; +use crate::config::MIN_DIMENSION; +use crate::vector::pack::pack_vector; +use crate::vector::tq_padded_dim; +use crate::vtable::TurboQuant; +use crate::vtable::TurboQuantMetadata; +use crate::vtable::tq_storage_dtype; + +/// Lazy TurboQuant vector pack scalar function. +#[derive(Clone)] +pub struct TQPack; + +impl TQPack { + /// Creates a new [`TypedScalarFnInstance`] wrapping TurboQuant packing. + pub fn new(config: &TurboQuantConfig) -> TypedScalarFnInstance { + TypedScalarFnInstance::new(TQPack, config.clone()) + } + + /// Constructs a [`ScalarFnArray`] that lazily packs a `Vector` child into `TurboQuant`. + pub fn try_new_array( + child: ArrayRef, + config: &TurboQuantConfig, + len: usize, + ) -> VortexResult { + ScalarFnArray::try_new(TQPack::new(config).erased(), vec![child], len) + } +} + +impl ScalarFnVTable for TQPack { + type Options = TurboQuantConfig; + + fn id(&self) -> ScalarFnId { + ScalarFnId::new("vortex.turboquant.pack") + } + + fn serialize(&self, options: &Self::Options) -> VortexResult>> { + Ok(Some(serialize_config(options))) + } + + fn deserialize( + &self, + metadata: &[u8], + _session: &VortexSession, + ) -> VortexResult { + deserialize_config(metadata) + } + + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(1) + } + + fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName { + match child_idx { + 0 => ChildName::from("vector"), + _ => unreachable!("TQPack must have exactly one child"), + } + } + + fn fmt_sql( + &self, + options: &Self::Options, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> fmt::Result { + write!(f, "tq_pack(")?; + expr.child(0).fmt_sql(f)?; + write!(f, ", {options})") + } + + fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { + let input_dtype = &arg_dtypes[0]; + let vector_metadata = input_dtype + .as_extension_opt() + .and_then(|ext_dtype| ext_dtype.metadata_opt::()) + .ok_or_else(|| { + vortex_err!("TQPack expects a Vector extension array, got {input_dtype}") + })?; + + let dimensions = vector_metadata.dimensions(); + vortex_ensure!( + dimensions >= MIN_DIMENSION, + "TurboQuant requires dimension >= {MIN_DIMENSION}, got {dimensions}", + ); + tq_padded_dim(dimensions)?; + + let metadata = TurboQuantMetadata { + element_ptype: vector_metadata.element_ptype(), + dimensions, + bit_width: options.bit_width(), + seed: options.seed(), + num_rounds: options.num_rounds(), + }; + let storage_dtype = 
tq_storage_dtype(&metadata, input_dtype.nullability())?; + let ext_dtype = ExtDType::::try_new(metadata, storage_dtype)?.erased(); + + Ok(DType::Extension(ext_dtype)) + } + + fn execute( + &self, + options: &Self::Options, + args: &dyn ExecutionArgs, + ctx: &mut ExecutionCtx, + ) -> VortexResult { + pack_vector(args.get(0)?, options, ctx) + } + + fn validity( + &self, + _options: &Self::Options, + expression: &Expression, + ) -> VortexResult> { + Ok(Some(expression.child(0).validity()?)) + } + + fn is_null_sensitive(&self, _options: &Self::Options) -> bool { + false + } + + fn is_fallible(&self, _options: &Self::Options) -> bool { + false + } +} + +impl ScalarFnArrayVTable for TQPack { + fn serialize( + &self, + _view: &ScalarFnArrayView, + _session: &VortexSession, + ) -> VortexResult>> { + Ok(None) + } + + fn deserialize( + &self, + _dtype: &DType, + _len: usize, + _metadata: &[u8], + _children: &dyn ArrayChildren, + _session: &VortexSession, + ) -> VortexResult> { + vortex_bail!("TQPack scalar-fn arrays are not serializable") + } +} diff --git a/vortex-turboquant/src/scalar_fns/unpack.rs b/vortex-turboquant/src/scalar_fns/unpack.rs new file mode 100644 index 00000000000..4d1d414a220 --- /dev/null +++ b/vortex-turboquant/src/scalar_fns/unpack.rs @@ -0,0 +1,213 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant unpack scalar function. + +use std::fmt; +use std::fmt::Formatter; +use std::sync::Arc; + +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::arrays::ScalarFnArray; +use vortex_array::arrays::scalar_fn::ScalarFnArrayView; +use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts; +use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable; +use vortex_array::dtype::DType; +use vortex_array::dtype::Nullability; +use vortex_array::dtype::extension::ExtDType; +use vortex_array::expr::Expression; +use vortex_array::extension::EmptyMetadata; +use vortex_array::scalar_fn::Arity; +use vortex_array::scalar_fn::ChildName; +use vortex_array::scalar_fn::ExecutionArgs; +use vortex_array::scalar_fn::ScalarFnId; +use vortex_array::scalar_fn::ScalarFnVTable; +use vortex_array::scalar_fn::TypedScalarFnInstance; +use vortex_array::serde::ArrayChildren; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure_eq; +use vortex_error::vortex_err; +use vortex_session::VortexSession; +use vortex_tensor::vector::AnyVector; +use vortex_tensor::vector::Vector; + +use super::metadata::deserialize_config; +use super::metadata::serialize_config; +use crate::TurboQuantConfig; +use crate::vector::unpack::unpack_vector; +use crate::vtable::TurboQuant; +use crate::vtable::TurboQuantMetadata; +use crate::vtable::tq_metadata; +use crate::vtable::tq_storage_dtype; + +/// Lazy TurboQuant vector unpack scalar function. +#[derive(Clone)] +pub struct TQUnpack; + +impl TQUnpack { + /// Creates a new [`TypedScalarFnInstance`] wrapping TurboQuant unpacking. + pub fn new(config: &TurboQuantConfig) -> TypedScalarFnInstance { + TypedScalarFnInstance::new(TQUnpack, config.clone()) + } + + /// Constructs a [`ScalarFnArray`] that lazily unpacks a `TurboQuant` child into a `Vector`. 
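+ ///
+ /// A minimal sketch (illustrative; `packed` stands for a TurboQuant-typed array of length
+ /// `len`, and `config` must match the metadata it was packed with):
+ ///
+ /// ```ignore
+ /// let vectors = TQUnpack::try_new_array(packed, &config, len)?;
+ /// ```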
+ pub fn try_new_array( + child: ArrayRef, + config: &TurboQuantConfig, + len: usize, + ) -> VortexResult { + ScalarFnArray::try_new(TQUnpack::new(config).erased(), vec![child], len) + } +} + +impl ScalarFnVTable for TQUnpack { + type Options = TurboQuantConfig; + + fn id(&self) -> ScalarFnId { + ScalarFnId::new("vortex.turboquant.unpack") + } + + fn serialize(&self, options: &Self::Options) -> VortexResult>> { + Ok(Some(serialize_config(options))) + } + + fn deserialize( + &self, + metadata: &[u8], + _session: &VortexSession, + ) -> VortexResult { + deserialize_config(metadata) + } + + fn arity(&self, _options: &Self::Options) -> Arity { + Arity::Exact(1) + } + + fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName { + match child_idx { + 0 => ChildName::from("turboquant"), + _ => unreachable!("TQUnpack must have exactly one child"), + } + } + + fn fmt_sql( + &self, + options: &Self::Options, + expr: &Expression, + f: &mut Formatter<'_>, + ) -> fmt::Result { + write!(f, "tq_unpack(")?; + expr.child(0).fmt_sql(f)?; + write!(f, ", {options})") + } + + fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { + let child_dtype = &arg_dtypes[0]; + let metadata = tq_metadata(child_dtype)?; + validate_config_matches_metadata(options, &metadata)?; + + let storage_dtype = DType::FixedSizeList( + Arc::new(DType::Primitive( + metadata.element_ptype, + Nullability::NonNullable, + )), + metadata.dimensions, + child_dtype.nullability(), + ); + let ext_dtype = ExtDType::::try_new(EmptyMetadata, storage_dtype)?.erased(); + + Ok(DType::Extension(ext_dtype)) + } + + fn execute( + &self, + _options: &Self::Options, + args: &dyn ExecutionArgs, + ctx: &mut ExecutionCtx, + ) -> VortexResult { + unpack_vector(args.get(0)?, ctx) + } + + fn validity( + &self, + _options: &Self::Options, + expression: &Expression, + ) -> VortexResult> { + Ok(Some(expression.child(0).validity()?)) + } + + fn is_null_sensitive(&self, _options: &Self::Options) -> bool { + false + } + + fn is_fallible(&self, _options: &Self::Options) -> bool { + false + } +} + +impl ScalarFnArrayVTable for TQUnpack { + fn serialize( + &self, + view: &ScalarFnArrayView, + _session: &VortexSession, + ) -> VortexResult>> { + Ok(Some(serialize_config(view.options))) + } + + fn deserialize( + &self, + dtype: &DType, + len: usize, + metadata: &[u8], + children: &dyn ArrayChildren, + _session: &VortexSession, + ) -> VortexResult> { + let options = deserialize_config(metadata)?; + let vector_metadata = dtype + .as_extension_opt() + .and_then(|ext_dtype| ext_dtype.metadata_opt::()) + .ok_or_else(|| { + vortex_err!("TQUnpack parent dtype must be a Vector extension array, got {dtype}") + })?; + + let metadata = TurboQuantMetadata { + element_ptype: vector_metadata.element_ptype(), + dimensions: vector_metadata.dimensions(), + bit_width: options.bit_width(), + seed: options.seed(), + num_rounds: options.num_rounds(), + }; + let storage_dtype = tq_storage_dtype(&metadata, dtype.nullability())?; + let child_dtype = + DType::Extension(ExtDType::::try_new(metadata, storage_dtype)?.erased()); + let child = children.get(0, &child_dtype, len)?; + + Ok(ScalarFnArrayParts { + options, + children: vec![child], + }) + } +} + +fn validate_config_matches_metadata( + config: &TurboQuantConfig, + metadata: &TurboQuantMetadata, +) -> VortexResult<()> { + vortex_ensure_eq!( + config.bit_width(), + metadata.bit_width, + "TQUnpack config bit_width must match TurboQuant child metadata" + ); + vortex_ensure_eq!( + config.seed(), 
+ metadata.seed, + "TQUnpack config seed must match TurboQuant child metadata" + ); + vortex_ensure_eq!( + config.num_rounds(), + metadata.num_rounds, + "TQUnpack config num_rounds must match TurboQuant child metadata" + ); + Ok(()) +} diff --git a/vortex-turboquant/src/sorf/mod.rs b/vortex-turboquant/src/sorf/mod.rs new file mode 100644 index 00000000000..cce477aa906 --- /dev/null +++ b/vortex-turboquant/src/sorf/mod.rs @@ -0,0 +1,7 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +mod splitmix64; +mod transform; + +pub(crate) use transform::SorfMatrix; diff --git a/vortex-turboquant/src/sorf/splitmix64.rs b/vortex-turboquant/src/sorf/splitmix64.rs new file mode 100644 index 00000000000..1233e4dc7ee --- /dev/null +++ b/vortex-turboquant/src/sorf/splitmix64.rs @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Frozen local SplitMix64 stream used to define SORF sign diagonals. +//! +//! This is a direct translation of the `splitmix64.c` [reference implementation][impl]. +//! +//! The state is a single `u64`, and `next_u64()` first adds [`SPLITMIX64_INCREMENT`] with wrapping +//! arithmetic, then applies the two reference mixing steps and final xor-shift. +//! +//! [impl]: https://prng.di.unimi.it/splitmix64.c + +/// SplitMix64 additive constant from the reference implementation. +const SPLITMIX64_INCREMENT: u64 = 0x9E37_79B9_7F4A_7C15; + +/// First SplitMix64 mixing multiplier from the reference implementation. +const SPLITMIX64_MUL1: u64 = 0xBF58_476D_1CE4_E5B9; + +/// Second SplitMix64 mixing multiplier from the reference implementation. +const SPLITMIX64_MUL2: u64 = 0x94D0_49BB_1331_11EB; + +/// Frozen local SplitMix64 stream used to define SORF sign diagonals. +pub(crate) struct SplitMix64 { + state: u64, +} + +impl SplitMix64 { + pub(crate) fn new(seed: u64) -> Self { + Self { state: seed } + } + + pub(crate) fn next_u64(&mut self) -> u64 { + self.state = self.state.wrapping_add(SPLITMIX64_INCREMENT); + let mut z = self.state; + z = (z ^ (z >> 30)).wrapping_mul(SPLITMIX64_MUL1); + z = (z ^ (z >> 27)).wrapping_mul(SPLITMIX64_MUL2); + z ^ (z >> 31) + } +} + +#[cfg(test)] +mod tests { + use super::SplitMix64; + + const SPLITMIX64_SEED0_GOLDEN: [u64; 4] = [ + 0xE220_A839_7B1D_CDAF, + 0x6E78_9E6A_A1B9_65F4, + 0x06C4_5D18_8009_454F, + 0xF88B_B8A8_724C_81EC, + ]; + + const SPLITMIX64_SEED42_GOLDEN: [u64; 4] = [ + 0xBDD7_3226_2FEB_6E95, + 0x28EF_E333_B266_F103, + 0x4752_6757_130F_9F52, + 0x581C_E1FF_0E4A_E394, + ]; + + #[test] + fn splitmix64_seed0_matches_golden_outputs() { + let mut rng = SplitMix64::new(0); + let actual: Vec<_> = (0..SPLITMIX64_SEED0_GOLDEN.len()) + .map(|_| rng.next_u64()) + .collect(); + assert_eq!(actual, SPLITMIX64_SEED0_GOLDEN); + } + + #[test] + fn splitmix64_seed42_matches_golden_outputs() { + let mut rng = SplitMix64::new(42); + let actual: Vec<_> = (0..SPLITMIX64_SEED42_GOLDEN.len()) + .map(|_| rng.next_u64()) + .collect(); + assert_eq!(actual, SPLITMIX64_SEED42_GOLDEN); + } +} diff --git a/vortex-turboquant/src/sorf/transform.rs b/vortex-turboquant/src/sorf/transform.rs new file mode 100644 index 00000000000..3fa221fe03a --- /dev/null +++ b/vortex-turboquant/src/sorf/transform.rs @@ -0,0 +1,419 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! SORF (Structured Orthogonal Random Features) orthogonal transform. +//! +//! Implements the SORF construction from [Yu et al. 
2016][sorf-paper]: a fast structured +//! approximation to a random orthogonal matrix using random sign diagonals interleaved with the +//! Fast Walsh-Hadamard Transform (FWHT). +//! +//! [sorf-paper]: https://proceedings.neurips.cc/paper_files/paper/2016/file/53adaf494dc89ef7196d73636eb2451b-Paper.pdf +//! +//! For `k` rounds, the transform is `norm * H * D_k * ... * H * D_1 * x`, where `D_1` is the +//! first sign diagonal applied. The number of rounds is configurable (typically 3). Each round +//! applies a random sign diagonal `D_i` and then the Hadamard matrix `H`, giving O(d log d) cost +//! per matrix-vector product instead of the O(d^2) cost of a dense orthogonal matrix. +//! +//! This implementation defines those sign diagonals using a frozen local SplitMix64 stream rather +//! than an +//! external RNG crate. The contract is: +//! +//! - state is a single `u64` seed, +//! - each `next_u64()` call uses the SplitMix64 reference algorithm with wrapping `u64` +//! arithmetic, +//! - signs are generated in round-major, block-major order, +//! - each generated `u64` contributes 64 signs in least-significant-bit-first order, +//! - bit `1` means `+1` and bit `0` means `-1`. +//! +//! This makes SORF sign generation stable as an extension format contract even if external RNG +//! implementations change. +//! +//! This transform is the crate's practical structured transform choice for TurboQuant. It is not +//! the dense random Gaussian or orthogonal matrix used by some theoretical analyses, so theoretical +//! bounds from those models need separate validation before being presented as implementation +//! guarantees. +//! +//! The FWHT exploits the Kronecker product structure of the Hadamard matrix (`H_n = H_2 (x) H_2 +//! (x) ... (x) H_2`, with `log2(n)` factors) to compute the matrix-vector product in O(n log n) +//! time using only in-place 2-element butterfly operations. No row of the full n x n Hadamard +//! matrix is ever materialized. +//! +//! For dimensions that are not powers of 2, the input is zero-padded to the next power of 2 before +//! the transform and truncated afterward. +//! +//! # Sign representation +//! +//! Signs are stored internally as `u32` XOR masks: `0x00000000` for +1 (no-op) and `0x80000000` for +//! -1 (flip IEEE 754 sign bit). The sign application function uses integer XOR instead of +//! floating-point multiply, which avoids FP dependency chains and auto-vectorizes into +//! `vpxor`/`veor`. + +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; + +use super::splitmix64::SplitMix64; + +/// IEEE 754 sign bit mask for f32. +const F32_SIGN_BIT: u32 = 0x8000_0000; + +/// A Walsh-Hadamard-based structured orthogonal transform matrix. +/// +/// All computation is done in f32. The sign diagonals are stored as IEEE 754 XOR masks on +/// f32 bit patterns, and the Walsh-Hadamard butterfly operates on `&mut [f32]` slices. +pub(crate) struct SorfMatrix { + /// Flat XOR masks for all `num_rounds` diagonal matrices, total length + /// `num_rounds * padded_dim`. + /// + /// Indexed as `round * padded_dim + i`. `0x00000000` = multiply by +1 (no-op), `0x80000000` = + /// multiply by -1 (flip sign bit). + sign_masks: Vec, + + /// The number of sign-diagonal + WHT rounds. + num_rounds: usize, + + /// The padded dimension (next power of 2 >= dimension). + padded_dim: usize, + + /// Normalization factor: `padded_dim^(-num_rounds/2)`, applied once at the end. + /// + /// This is stored for convenience. 
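+    ///
+    /// For example, with `padded_dim = 128` and `num_rounds = 3` this is `128^(-1.5) ≈ 6.9e-4`.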
+ norm_factor: f32, +} + +impl SorfMatrix { + /// Create a new structured Walsh-Hadamard-based orthogonal transform for a padded dimension. + /// + /// `padded_dimensions` must already be a power of two. Callers that start from an unpadded + /// logical dimension are responsible for padding it before constructing the matrix. + pub(crate) fn try_new( + padded_dimensions: usize, + num_rounds: usize, + seed: u64, + ) -> VortexResult { + vortex_ensure!(num_rounds >= 1, "num_rounds must be >= 1, got {num_rounds}"); + vortex_ensure!( + padded_dimensions.is_power_of_two(), + "padded_dimensions must be a power of two, got {padded_dimensions}" + ); + + let padded_dim = padded_dimensions; + let sign_masks = gen_sign_masks_from_seed(seed, padded_dim, num_rounds); + + // Compute in f64 for precision, then store as f32 since the WHT operates on f32 buffers. + // The result is always in (0, 1] for any valid padded_dim >= 2 and num_rounds >= 1, so + // the f64 -> f32 cast is a precision loss only (it cannot overflow to infinity). + #[expect( + clippy::cast_possible_truncation, + reason = "the norm factor is in (0, 1] so the f64 -> f32 cast cannot overflow" + )] + let norm_factor = (padded_dim as f64).powf(-(num_rounds as f64) / 2.0) as f32; + + Ok(Self { + sign_masks, + num_rounds, + padded_dim, + norm_factor, + }) + } + + /// Returns the padded dimension (next power of 2 >= dim). + /// + /// All `transform`/`inverse_transform` buffers must be this length. + pub(crate) fn padded_dim(&self) -> usize { + self.padded_dim + } + + /// Apply the forward orthogonal transform: `output = R(input)`. + /// + /// Both `input` and `output` must have length [`padded_dim()`](Self::padded_dim). The caller is + /// responsible for zero-padding input beyond `dim` positions. + pub(crate) fn transform(&self, input: &[f32], output: &mut [f32]) { + debug_assert_eq!(input.len(), self.padded_dim); + debug_assert_eq!(output.len(), self.padded_dim); + + output.copy_from_slice(input); + self.apply_srht(output); + } + + /// Apply the inverse orthogonal transform: `output = R⁻¹(input)`. + /// + /// Both `input` and `output` must have length `padded_dim()`. + pub(crate) fn inverse_transform(&self, input: &[f32], output: &mut [f32]) { + debug_assert_eq!(input.len(), self.padded_dim); + debug_assert_eq!(output.len(), self.padded_dim); + + output.copy_from_slice(input); + self.apply_inverse_srht(output); + } + + /// Apply the forward structured transform: `norm · H · D_k · ... · H · D₁ · x`. + fn apply_srht(&self, buf: &mut [f32]) { + for round in 0..self.num_rounds { + self.apply_signs_xor(buf, round); + walsh_hadamard_transform(buf); + } + + buf.iter_mut().for_each(|val| *val *= self.norm_factor); + } + + /// Apply the inverse structured transform. + /// + /// Forward is: `norm · H · D_k · ... · H · D₁`. + /// Inverse is: `norm · D₁ · H · ... · D_k · H`. + fn apply_inverse_srht(&self, buf: &mut [f32]) { + for round in (0..self.num_rounds).rev() { + walsh_hadamard_transform(buf); + self.apply_signs_xor(buf, round); + } + + buf.iter_mut().for_each(|val| *val *= self.norm_factor); + } + + /// Apply one round's sign masks via XOR on the IEEE 754 sign bit. + /// + /// This is branchless and auto-vectorizes into `vpxor` (x86) / `veor` (ARM). Equivalent to + /// multiplying each element by +/-1.0, but avoids FP dependency chains. 
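+    ///
+    /// For example, XORing [`F32_SIGN_BIT`] into `1.5f32` (bits `0x3FC0_0000`) yields
+    /// `0xBFC0_0000`, i.e. `-1.5f32`, while XORing a `0x0000_0000` mask leaves the value
+    /// untouched.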
+ fn apply_signs_xor(&self, buf: &mut [f32], round: usize) { + let masks = &self.sign_masks[round * self.padded_dim..][..self.padded_dim]; + for (val, &mask) in buf.iter_mut().zip(masks.iter()) { + *val = f32::from_bits(val.to_bits() ^ mask); + } + } +} + +/// Generate XOR sign masks from the frozen local SplitMix64 stream. +/// +/// Signs are produced in round-major, block-major order. For each block we call +/// [`SplitMix64::next_u64`] exactly once and unpack its bits from least significant to most +/// significant. Bit `1` means positive sign / `0x00000000`; bit `0` means negative sign / +/// [`F32_SIGN_BIT`]. +fn gen_sign_masks_from_seed(seed: u64, padded_dim: usize, num_rounds: usize) -> Vec { + let mut rng = SplitMix64::new(seed); + let mut sign_masks = Vec::with_capacity(num_rounds * padded_dim); + + for _round in 0..num_rounds { + for base_idx in (0..padded_dim).step_by(64) { + let word = rng.next_u64(); + let bits_in_block = (padded_dim - base_idx).min(64); + sign_masks.extend((0..bits_in_block).map(|bit_idx| sign_mask_from_word(word, bit_idx))); + } + } + + sign_masks +} + +/// Convert one bit from a SplitMix64 output word into an XOR sign mask. +fn sign_mask_from_word(word: u64, bit_idx: usize) -> u32 { + if ((word >> bit_idx) & 1) != 0 { + 0u32 + } else { + F32_SIGN_BIT + } +} + +/// In-place Fast Walsh-Hadamard Transform (FWHT), unnormalized and iterative. +/// +/// Input length must be a power of 2. Runs in O(n log n) via `log2(n)` stages of `n / 2` +/// [`butterfly`] operations each. See the [module-level docs](self) for why this avoids +/// materializing the full Hadamard matrix. +/// +/// The chunk-based iteration gives LLVM enough structure to auto-vectorize each butterfly call +/// into NEON/AVX SIMD instructions. +fn walsh_hadamard_transform(buf: &mut [f32]) { + let len = buf.len(); + debug_assert!(len.is_power_of_two()); + + let mut half = 1; + while half < len { + let stride = half * 2; + // Process in chunks of `stride` elements. Within each chunk, + // split into non-overlapping (lo, hi) halves for the butterfly. + for chunk in buf.chunks_exact_mut(stride) { + let (lo, hi) = chunk.split_at_mut(half); + butterfly(lo, hi); + } + half *= 2; + } +} + +/// Butterfly: `(lo[i], hi[i]) -> (lo[i] + hi[i], lo[i] - hi[i])`. +/// +/// This is multiplication by the 2x2 Hadamard kernel `H_2 = [[1, 1], [1, -1]]` on each element +/// pair. Factored into a separate function so LLVM can see the slice lengths match and +/// auto-vectorize. 
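+///
+/// For example, `lo = [1.0, 2.0]` and `hi = [3.0, 4.0]` become `lo = [4.0, 6.0]` and
+/// `hi = [-2.0, -2.0]`.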
+fn butterfly(lo: &mut [f32], hi: &mut [f32]) { + debug_assert_eq!(lo.len(), hi.len()); + for (a, b) in lo.iter_mut().zip(hi.iter_mut()) { + let sum = *a + *b; + let diff = *a - *b; + *a = sum; + *b = diff; + } +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + use vortex_error::VortexResult; + + use super::*; + + fn dim_to_usize(dim: u32) -> usize { + usize::try_from(dim).unwrap() + } + + fn rounds_to_usize(num_rounds: u8) -> usize { + usize::from(num_rounds) + } + + #[test] + fn deterministic_from_seed() -> VortexResult<()> { + let padded_dim = dim_to_usize(64u32); + let num_rounds = rounds_to_usize(3u8); + let seed = 42u64; + let transform1 = SorfMatrix::try_new(padded_dim, num_rounds, seed)?; + let transform2 = SorfMatrix::try_new(padded_dim, num_rounds, seed)?; + let pd = transform1.padded_dim(); + + let mut input = vec![0.0f32; pd]; + for i in 0..padded_dim { + input[i] = i as f32; + } + let mut out1 = vec![0.0f32; pd]; + let mut out2 = vec![0.0f32; pd]; + + transform1.transform(&input, &mut out1); + transform2.transform(&input, &mut out2); + + assert_eq!(out1, out2); + Ok(()) + } + + #[test] + fn one_word_generates_64_signs_lsb_first() { + let seed = 42u64; + let padded_dim = dim_to_usize(64u32); + let num_rounds = rounds_to_usize(1u8); + let masks = gen_sign_masks_from_seed(seed, padded_dim, num_rounds); + assert_eq!(masks.len(), padded_dim); + + let mut rng = SplitMix64::new(seed); + let word = rng.next_u64(); + let expected: Vec<_> = (0..padded_dim) + .map(|bit_idx| sign_mask_from_word(word, bit_idx)) + .collect(); + assert_eq!(masks, expected); + } + + #[test] + fn rejects_non_power_of_two_padded_dimensions() { + assert!(SorfMatrix::try_new(dim_to_usize(100u32), rounds_to_usize(3u8), 42u64).is_err()); + } + + #[test] + fn tail_block_uses_only_required_bits() { + let seed = 42u64; + let padded_dim = dim_to_usize(32u32); + let num_rounds = rounds_to_usize(1u8); + let masks = gen_sign_masks_from_seed(seed, padded_dim, num_rounds); + assert_eq!(masks.len(), padded_dim); + + let mut rng = SplitMix64::new(seed); + let word = rng.next_u64(); + let expected: Vec<_> = (0..padded_dim) + .map(|bit_idx| sign_mask_from_word(word, bit_idx)) + .collect(); + assert_eq!(masks, expected); + } + + /// Verify roundtrip is exact to f32 precision across many dimensions and round counts, + /// including non-power-of-two dimensions that require padding. 
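+    ///
+    /// Exactness follows from `H · H = n · I` and `D_i · D_i = I`: composing the inverse with the
+    /// forward transform collapses to `n^k · I` (with `n = padded_dim`, `k = num_rounds`), which
+    /// the two `n^(-k/2)` normalization factors cancel exactly, leaving only f32 rounding error.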
+ #[rstest] + #[case(32u32, 3u8)] + #[case(64u32, 3u8)] + #[case(100u32, 3u8)] + #[case(128u32, 1u8)] + #[case(128u32, 2u8)] + #[case(128u32, 3u8)] + #[case(128u32, 5u8)] + #[case(256u32, 3u8)] + #[case(512u32, 3u8)] + #[case(768u32, 3u8)] + #[case(1024u32, 3u8)] + fn roundtrip_exact(#[case] dim: u32, #[case] num_rounds: u8) -> VortexResult<()> { + let dim = dim_to_usize(dim); + let num_rounds = rounds_to_usize(num_rounds); + let transform = SorfMatrix::try_new(dim.next_power_of_two(), num_rounds, 42u64)?; + let padded_dim = transform.padded_dim(); + + let mut input = vec![0.0f32; padded_dim]; + for i in 0..dim { + input[i] = (i as f32 + 1.0) * 0.01; + } + let mut transformed = vec![0.0f32; padded_dim]; + let mut recovered = vec![0.0f32; padded_dim]; + + transform.transform(&input, &mut transformed); + transform.inverse_transform(&transformed, &mut recovered); + + let max_err: f32 = input + .iter() + .zip(recovered.iter()) + .map(|(a, b)| (a - b).abs()) + .fold(0.0f32, f32::max); + let max_val: f32 = input.iter().map(|x| x.abs()).fold(0.0f32, f32::max); + let rel_err = max_err / max_val; + + // SRHT roundtrip should be exact up to f32 precision (~1e-6). + assert!( + rel_err < 1e-5, + "roundtrip relative error too large for dim={dim}, rounds={num_rounds}: {rel_err:.2e}" + ); + Ok(()) + } + + /// Verify norm preservation across dimensions and round counts. + #[rstest] + #[case(128u32, 1u8)] + #[case(128u32, 3u8)] + #[case(128u32, 5u8)] + #[case(768u32, 3u8)] + fn preserves_norm(#[case] dim: u32, #[case] num_rounds: u8) -> VortexResult<()> { + let dim = dim_to_usize(dim); + let num_rounds = rounds_to_usize(num_rounds); + let transform = SorfMatrix::try_new(dim.next_power_of_two(), num_rounds, 7u64)?; + let padded_dim = transform.padded_dim(); + + let mut input = vec![0.0f32; padded_dim]; + for i in 0..dim { + input[i] = (i as f32) * 0.01; + } + let input_norm: f32 = input.iter().map(|x| x * x).sum::().sqrt(); + + let mut transformed = vec![0.0f32; padded_dim]; + transform.transform(&input, &mut transformed); + let transformed_norm: f32 = transformed.iter().map(|x| x * x).sum::().sqrt(); + + assert!( + (input_norm - transformed_norm).abs() / input_norm < 1e-5, + "norm not preserved for dim={dim}: {} vs {} (rel err: {:.2e})", + input_norm, + transformed_norm, + (input_norm - transformed_norm).abs() / input_norm + ); + Ok(()) + } + + #[test] + fn wht_basic() { + // WHT of [1, 0, 0, 0] should be [1, 1, 1, 1] + let mut buf = vec![1.0f32, 0.0, 0.0, 0.0]; + walsh_hadamard_transform(&mut buf); + assert_eq!(buf, vec![1.0, 1.0, 1.0, 1.0]); + + // WHT is self-inverse (up to scaling by n) + walsh_hadamard_transform(&mut buf); + // After two WHTs: each element multiplied by n=4 + assert_eq!(buf, vec![4.0, 0.0, 0.0, 0.0]); + } +} diff --git a/vortex-turboquant/src/tests/file.rs b/vortex-turboquant/src/tests/file.rs new file mode 100644 index 00000000000..3c5dbe5ccd9 --- /dev/null +++ b/vortex-turboquant/src/tests/file.rs @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::IntoArray; +use vortex_array::VortexSessionExecute; +use vortex_array::stream::ArrayStreamExt; +use vortex_array::validity::Validity; +use vortex_error::VortexResult; +use vortex_file::OpenOptionsSessionExt; +use vortex_file::VortexWriteOptions; +use vortex_io::runtime::BlockingRuntime; +use vortex_io::runtime::single::SingleThreadRuntime; +use vortex_tensor::vector::Vector; + +use super::execute_tq_pack; +use super::execute_tq_unpack_from_metadata; 
+use super::f32_vector_array; +use super::file_session; +use super::vector_validity; +use crate::TQUnpack; +use crate::TurboQuantConfig; +use crate::vtable::tq_metadata; + +#[test] +fn file_roundtrip_with_initialize_session() -> VortexResult<()> { + let runtime = SingleThreadRuntime::default(); + let session = file_session(&runtime); + let mut ctx = session.create_execution_ctx(); + let input = f32_vector_array(128, 2, 0.25, Validity::from_iter([true, false]))?; + let packed = execute_tq_pack(input, &TurboQuantConfig::default(), &mut ctx)?; + + let mut file_bytes = Vec::new(); + VortexWriteOptions::new(session.clone()) + .blocking(&runtime) + .write(&mut file_bytes, packed.to_array_iterator())?; + + let file = session.open_options().open_buffer(file_bytes)?; + let read = runtime.block_on(async { file.scan()?.into_array_stream()?.read_all().await })?; + + let metadata = tq_metadata(read.dtype())?; + assert_eq!(metadata.dimensions, 128); + let unpacked = execute_tq_unpack_from_metadata(read, &mut ctx)?; + let validity = vector_validity(unpacked, &mut ctx)?.execute_mask(2, &mut ctx)?; + assert!(validity.value(0)); + assert!(!validity.value(1)); + Ok(()) +} + +#[test] +fn file_roundtrip_lazy_unpack_scalar_fn_with_initialize_session() -> VortexResult<()> { + let runtime = SingleThreadRuntime::default(); + let session = file_session(&runtime); + let mut ctx = session.create_execution_ctx(); + let input = f32_vector_array(128, 2, 0.25, Validity::from_iter([true, false]))?; + let config = TurboQuantConfig::try_new(3, 42, 3)?; + let packed = execute_tq_pack(input, &config, &mut ctx)?; + let unpacked = TQUnpack::try_new_array(packed, &config, 2)?.into_array(); + + let mut file_bytes = Vec::new(); + VortexWriteOptions::new(session.clone()) + .blocking(&runtime) + .write(&mut file_bytes, unpacked.to_array_iterator())?; + + let file = session.open_options().open_buffer(file_bytes)?; + let read = runtime.block_on(async { file.scan()?.into_array_stream()?.read_all().await })?; + + assert!(read.dtype().as_extension().is::()); + + let validity = vector_validity(read, &mut ctx)?.execute_mask(2, &mut ctx)?; + assert!(validity.value(0)); + assert!(!validity.value(1)); + Ok(()) +} diff --git a/vortex-turboquant/src/tests/malformed.rs b/vortex-turboquant/src/tests/malformed.rs new file mode 100644 index 00000000000..7ca368dd00b --- /dev/null +++ b/vortex-turboquant/src/tests/malformed.rs @@ -0,0 +1,189 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use rstest::rstest; +use vortex_array::IntoArray; +use vortex_array::VortexSessionExecute; +use vortex_array::arrays::ExtensionArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::StructArray; +use vortex_array::dtype::FieldNames; +use vortex_array::dtype::Nullability; +use vortex_array::dtype::PType; +use vortex_array::validity::Validity; +use vortex_buffer::Buffer; +use vortex_error::VortexResult; + +use super::execute_tq_unpack_from_metadata; +use super::test_session; +use super::vector_validity; +use crate::TurboQuant; +use crate::TurboQuantMetadata; + +#[rstest] +#[case::nullable_norms_under_nonnullable_struct( + Nullability::NonNullable, + Nullability::Nullable, + Nullability::NonNullable +)] +#[case::nullable_codes_under_nonnullable_struct( + Nullability::NonNullable, + Nullability::NonNullable, + Nullability::Nullable +)] +#[case::nonnullable_norms_under_nullable_struct( + Nullability::Nullable, + Nullability::NonNullable, + 
Nullability::Nullable +)] +#[case::nonnullable_codes_under_nullable_struct( + Nullability::Nullable, + Nullability::Nullable, + Nullability::NonNullable +)] +fn unpack_accepts_child_nullability_that_covers_struct_validity( + #[case] struct_nullability: Nullability, + #[case] norms_nullability: Nullability, + #[case] codes_nullability: Nullability, +) -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let metadata = TurboQuantMetadata { + element_ptype: PType::F32, + dimensions: 128, + bit_width: 1, + seed: 42, + num_rounds: 3, + }; + let norms = + PrimitiveArray::new::(Buffer::copy_from([1.0]), Validity::from(norms_nullability)) + .into_array(); + let codes = PrimitiveArray::new::(vec![0u8; 128], Validity::NonNullable); + let codes = FixedSizeListArray::try_new( + codes.into_array(), + 128, + Validity::from(codes_nullability), + 1, + ) + .unwrap() + .into_array(); + let storage = StructArray::try_new( + FieldNames::from(["norms", "codes"]), + vec![norms, codes], + 1, + Validity::from(struct_nullability), + ) + .unwrap(); + let tq = ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage.into_array()) + .unwrap() + .into_array(); + + execute_tq_unpack_from_metadata(tq, &mut ctx)?; + Ok(()) +} + +#[test] +fn unpack_accepts_struct_mask_with_all_valid_children() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let metadata = TurboQuantMetadata { + element_ptype: PType::F32, + dimensions: 128, + bit_width: 1, + seed: 42, + num_rounds: 3, + }; + let norms = + PrimitiveArray::new::(Buffer::copy_from([1.0, 1.0, 1.0]), Validity::NonNullable) + .into_array(); + let codes = PrimitiveArray::new::(vec![0u8; 3 * 128], Validity::NonNullable); + let codes = FixedSizeListArray::try_new(codes.into_array(), 128, Validity::NonNullable, 3)? + .into_array(); + let storage = StructArray::try_new( + FieldNames::from(["norms", "codes"]), + vec![norms, codes], + 3, + Validity::from_iter([true, false, true]), + )?; + let tq = ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage.into_array())? + .into_array(); + + let unpacked = execute_tq_unpack_from_metadata(tq, &mut ctx)?; + let validity = vector_validity(unpacked, &mut ctx)?.execute_mask(3, &mut ctx)?; + assert!(validity.value(0)); + assert!(!validity.value(1)); + assert!(validity.value(2)); + Ok(()) +} + +#[test] +fn unpack_rejects_child_masks_that_disagree_with_struct_validity() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let metadata = TurboQuantMetadata { + element_ptype: PType::F32, + dimensions: 128, + bit_width: 1, + seed: 42, + num_rounds: 3, + }; + let norms = PrimitiveArray::new::( + Buffer::copy_from([1.0, 1.0, 1.0]), + Validity::from_iter([true, true, false]), + ) + .into_array(); + let codes = PrimitiveArray::new::(vec![0u8; 3 * 128], Validity::NonNullable); + let codes = FixedSizeListArray::try_new( + codes.into_array(), + 128, + Validity::from_iter([true, false, true]), + 3, + )? + .into_array(); + let storage = StructArray::try_new( + FieldNames::from(["norms", "codes"]), + vec![norms, codes], + 3, + Validity::from_iter([true, false, true]), + )?; + let tq = ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage.into_array())? 
+ .into_array(); + + assert!(execute_tq_unpack_from_metadata(tq, &mut ctx).is_err()); + Ok(()) +} + +#[test] +#[should_panic(expected = "TurboQuant code exceeds centroid count")] +fn unpack_panics_on_codes_outside_centroid_table() { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let metadata = TurboQuantMetadata { + element_ptype: PType::F32, + dimensions: 128, + bit_width: 1, + seed: 42, + num_rounds: 3, + }; + let norms = + PrimitiveArray::new::(Buffer::copy_from([1.0]), Validity::NonNullable).into_array(); + let mut codes = vec![0u8; 128]; + codes[0] = 2; + let codes = PrimitiveArray::new::(codes, Validity::NonNullable); + let codes = FixedSizeListArray::try_new(codes.into_array(), 128, Validity::NonNullable, 1) + .unwrap() + .into_array(); + let storage = StructArray::try_new( + FieldNames::from(["norms", "codes"]), + vec![norms, codes], + 1, + Validity::NonNullable, + ) + .unwrap(); + let tq = ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage.into_array()) + .unwrap() + .into_array(); + + drop(execute_tq_unpack_from_metadata(tq, &mut ctx)); +} diff --git a/vortex-turboquant/src/tests/metadata.rs b/vortex-turboquant/src/tests/metadata.rs new file mode 100644 index 00000000000..e0d1042f02f --- /dev/null +++ b/vortex-turboquant/src/tests/metadata.rs @@ -0,0 +1,173 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::sync::Arc; + +use prost::Message; +use rstest::rstest; +use vortex_array::dtype::DType; +use vortex_array::dtype::FieldNames; +use vortex_array::dtype::Nullability; +use vortex_array::dtype::PType; +use vortex_array::dtype::StructFields; +use vortex_array::dtype::extension::ExtDType; +use vortex_array::dtype::extension::ExtVTable; +use vortex_error::VortexResult; +use vortex_error::vortex_err; + +use crate::TurboQuant; +use crate::TurboQuantMetadata; +use crate::vector::storage::CODES_FIELD; +use crate::vector::storage::NORMS_FIELD; +use crate::vector::tq_padded_dim; + +#[derive(Clone, PartialEq, Message)] +struct MetadataWire { + #[prost(enumeration = "PType", tag = "1")] + element_ptype: i32, + #[prost(uint32, tag = "2")] + dimensions: u32, + #[prost(uint32, tag = "3")] + bit_width: u32, + #[prost(uint64, tag = "4")] + seed: u64, + #[prost(uint32, tag = "5")] + num_rounds: u32, +} + +fn tq_storage_dtype( + metadata: &TurboQuantMetadata, + row_nullability: Nullability, +) -> VortexResult { + let padded_dim = u32::try_from(tq_padded_dim(metadata.dimensions)?) 
+ .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit u32"))?; + Ok(DType::Struct( + StructFields::new( + FieldNames::from([NORMS_FIELD, CODES_FIELD]), + vec![ + DType::Primitive(metadata.element_ptype, row_nullability), + DType::FixedSizeList( + Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)), + padded_dim, + row_nullability, + ), + ], + ), + row_nullability, + )) +} + +#[rstest] +#[case::f16(PType::F16)] +#[case::f32(PType::F32)] +#[case::f64(PType::F64)] +fn metadata_serialization_roundtrips(#[case] element_ptype: PType) -> VortexResult<()> { + let metadata = TurboQuantMetadata { + element_ptype, + dimensions: 128, + bit_width: 4, + seed: 7, + num_rounds: 3, + }; + + let encoded = TurboQuant.serialize_metadata(&metadata)?; + let decoded = TurboQuant.deserialize_metadata(&encoded)?; + + assert_eq!(decoded, metadata); + Ok(()) +} + +#[test] +fn metadata_serialization_uses_ptype_discriminants() -> VortexResult<()> { + let metadata = TurboQuantMetadata { + element_ptype: PType::F32, + dimensions: 128, + bit_width: 4, + seed: 7, + num_rounds: 3, + }; + + let encoded = TurboQuant.serialize_metadata(&metadata)?; + let wire = MetadataWire::decode(encoded.as_slice())?; + + assert_eq!(wire.element_ptype, PType::F32 as i32); + assert_eq!(wire.dimensions, 128); + Ok(()) +} + +#[test] +fn metadata_display_matches_field_order() { + let metadata = TurboQuantMetadata { + element_ptype: PType::F32, + dimensions: 128, + bit_width: 4, + seed: 7, + num_rounds: 3, + }; + + assert_eq!( + metadata.to_string(), + "element_ptype: f32, dimensions: 128, bit_width: 4, seed: 7, num_rounds: 3" + ); +} + +#[test] +fn dtype_validation_accepts_expected_storage() -> VortexResult<()> { + let metadata = TurboQuantMetadata { + element_ptype: PType::F32, + dimensions: 129, + bit_width: 2, + seed: 42, + num_rounds: 3, + }; + + ExtDType::::try_new( + metadata, + tq_storage_dtype(&metadata, Nullability::Nullable)?, + )?; + Ok(()) +} + +#[test] +fn dtype_validation_accepts_nonnullable_storage() -> VortexResult<()> { + let metadata = TurboQuantMetadata { + element_ptype: PType::F32, + dimensions: 129, + bit_width: 2, + seed: 42, + num_rounds: 3, + }; + + ExtDType::::try_new( + metadata, + tq_storage_dtype(&metadata, Nullability::NonNullable)?, + )?; + Ok(()) +} + +#[test] +fn dtype_validation_rejects_malformed_storage() { + let metadata = TurboQuantMetadata { + element_ptype: PType::F32, + dimensions: 128, + bit_width: 2, + seed: 42, + num_rounds: 3, + }; + let storage = DType::Struct( + StructFields::new( + FieldNames::from(["norms", "codes"]), + vec![ + DType::Primitive(PType::F32, Nullability::Nullable), + DType::FixedSizeList( + DType::Primitive(PType::U8, Nullability::Nullable).into(), + 128, + Nullability::NonNullable, + ), + ], + ), + Nullability::Nullable, + ); + + assert!(ExtDType::::try_new(metadata, storage).is_err()); +} diff --git a/vortex-turboquant/src/tests/mod.rs b/vortex-turboquant/src/tests/mod.rs new file mode 100644 index 00000000000..3b2bce0c4e2 --- /dev/null +++ b/vortex-turboquant/src/tests/mod.rs @@ -0,0 +1,153 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +#![cfg_attr( + test, + allow(clippy::unwrap_used, clippy::expect_used, clippy::unwrap_in_result) +)] + +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::ExtensionArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use 
vortex_array::arrays::StructArray; +use vortex_array::arrays::extension::ExtensionArrayExt; +use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; +use vortex_array::dtype::NativePType; +use vortex_array::dtype::PType; +use vortex_array::extension::EmptyMetadata; +use vortex_array::memory::MemorySession; +use vortex_array::session::ArraySession; +use vortex_array::validity::Validity; +use vortex_buffer::Buffer; +use vortex_error::VortexResult; +use vortex_error::vortex_err; +use vortex_io::runtime::BlockingRuntime; +use vortex_io::runtime::single::SingleThreadRuntime; +use vortex_io::session::RuntimeSession; +use vortex_io::session::RuntimeSessionExt; +use vortex_layout::session::LayoutSession; +use vortex_session::VortexSession; +use vortex_tensor::vector::Vector; + +use crate::TQPack; +use crate::TQUnpack; +use crate::TurboQuantConfig; +use crate::initialize; +use crate::vtable::tq_metadata; + +mod file; +mod malformed; +mod metadata; +mod pack_unpack; +mod parity; +mod scalar_fns; + +fn test_session() -> VortexSession { + let session = VortexSession::empty().with::(); + initialize(&session); + session +} + +fn file_session(runtime: &SingleThreadRuntime) -> VortexSession { + let session = VortexSession::empty() + .with::() + .with::() + .with::() + .with::() + .with_handle(runtime.handle()); + vortex_file::register_default_encodings(&session); + vortex_tensor::initialize(&session); + initialize(&session); + session +} + +fn vector_array( + dimensions: u32, + values: &[T], + validity: Validity, +) -> VortexResult { + assert!(dimensions > 0, "dimensions must be > 0"); + let row_count = values.len() / dimensions as usize; + + let elements = PrimitiveArray::new::( + values.iter().copied().collect::>(), + Validity::NonNullable, + ); + let fsl = FixedSizeListArray::try_new(elements.into_array(), dimensions, validity, row_count)?; + + Ok(ExtensionArray::try_new_from_vtable(Vector, EmptyMetadata, fsl.into_array())?.into_array()) +} + +fn f32_vector_array( + dimensions: u32, + rows: usize, + scale: f32, + validity: Validity, +) -> VortexResult { + let values = (0..rows * dimensions as usize) + .map(|i| ((i % 17) as f32 - 8.0) * scale) + .collect::>(); + vector_array(dimensions, &values, validity) +} + +fn vector_values_f32(array: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult> { + let ext: ExtensionArray = array.execute(ctx)?; + let fsl: FixedSizeListArray = ext.storage_array().clone().execute(ctx)?; + let elements: PrimitiveArray = fsl.elements().clone().execute(ctx)?; + Ok(elements.as_slice::().to_vec()) +} + +fn vector_validity(array: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { + let ext: ExtensionArray = array.execute(ctx)?; + let fsl: FixedSizeListArray = ext.storage_array().clone().execute(ctx)?; + fsl.validity() +} + +fn vector_element_ptype(array: &ExtensionArray) -> VortexResult { + Ok(array + .storage_array() + .dtype() + .as_fixed_size_list_element_opt() + .ok_or_else(|| vortex_err!("expected FixedSizeList vector storage"))? + .as_ptype()) +} + +fn turboquant_storage(array: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { + let ext: ExtensionArray = array.execute(ctx)?; + ext.storage_array().clone().execute(ctx) +} + +fn execute_tq_pack( + input: ArrayRef, + config: &TurboQuantConfig, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let len = input.len(); + TQPack::try_new_array(input, config, len)? 
+ .into_array() + .execute(ctx) +} + +fn execute_tq_unpack( + input: ArrayRef, + config: &TurboQuantConfig, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let len = input.len(); + TQUnpack::try_new_array(input, config, len)? + .into_array() + .execute(ctx) +} + +fn execute_tq_unpack_from_metadata( + input: ArrayRef, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let metadata = tq_metadata(input.dtype())?; + let config = TurboQuantConfig::try_new(metadata.bit_width, metadata.seed, metadata.num_rounds)?; + + execute_tq_unpack(input, &config, ctx) +} diff --git a/vortex-turboquant/src/tests/pack_unpack.rs b/vortex-turboquant/src/tests/pack_unpack.rs new file mode 100644 index 00000000000..9c5a56f007a --- /dev/null +++ b/vortex-turboquant/src/tests/pack_unpack.rs @@ -0,0 +1,258 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use rstest::rstest; +use vortex_array::IntoArray; +use vortex_array::VortexSessionExecute; +use vortex_array::arrays::ExtensionArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::extension::ExtensionArrayExt; +use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; +use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; +use vortex_array::arrays::struct_::StructArrayExt; +use vortex_array::dtype::Nullability; +use vortex_array::dtype::PType; +use vortex_array::validity::Validity; +use vortex_buffer::Buffer; +use vortex_error::VortexResult; + +use super::execute_tq_pack; +use super::execute_tq_unpack; +use super::f32_vector_array; +use super::test_session; +use super::turboquant_storage; +use super::vector_array; +use super::vector_element_ptype; +use super::vector_validity; +use super::vector_values_f32; +use crate::TurboQuantConfig; +use crate::centroids::compute_or_get_centroids; +use crate::vector::normalize::tq_normalize_as_l2_denorm; + +#[rstest] +#[case::zero_bits(0, 42, 3)] +#[case::too_many_bits(9, 42, 3)] +#[case::zero_rounds(2, 42, 0)] +fn config_rejects_invalid_values(#[case] bit_width: u8, #[case] seed: u64, #[case] num_rounds: u8) { + assert!(TurboQuantConfig::try_new(bit_width, seed, num_rounds).is_err()); +} + +#[test] +fn pack_rejects_non_vector_input() { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let input = PrimitiveArray::new::(Buffer::copy_from([1.0, 2.0]), Validity::NonNullable) + .into_array(); + + assert!(execute_tq_pack(input, &TurboQuantConfig::default(), &mut ctx).is_err()); +} + +#[test] +fn pack_rejects_small_dimensions() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let input = f32_vector_array(127, 1, 1.0, Validity::NonNullable)?; + + assert!(execute_tq_pack(input, &TurboQuantConfig::default(), &mut ctx).is_err()); + Ok(()) +} + +#[test] +fn pack_rejects_padded_dimension_overflow() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let input = vector_array::(2_147_483_649, &[], Validity::NonNullable)?; + + assert!(execute_tq_pack(input, &TurboQuantConfig::default(), &mut ctx).is_err()); + Ok(()) +} + +#[test] +fn centroid_cache_is_deterministic() -> VortexResult<()> { + let first = compute_or_get_centroids(128, 3)?; + let second = compute_or_get_centroids(128, 3)?; + + assert_eq!(first.as_slice(), second.as_slice()); + Ok(()) +} + +#[test] +fn pack_unpack_empty_vectors() -> VortexResult<()> { + let session = test_session(); + let mut ctx = 
session.create_execution_ctx(); + let input = vector_array::(128, &[], Validity::NonNullable)?; + + let packed = execute_tq_pack(input, &TurboQuantConfig::default(), &mut ctx)?; + let unpacked = execute_tq_unpack(packed, &TurboQuantConfig::default(), &mut ctx)?; + + assert!(unpacked.is_empty()); + Ok(()) +} + +#[test] +fn pack_stores_norms_and_struct_validity() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let validity = Validity::from_iter([true, false, true]); + let input = f32_vector_array(128, 3, 0.25, validity)?; + + let config = TurboQuantConfig::try_new(3, 1, 2)?; + let packed = execute_tq_pack(input, &config, &mut ctx)?; + let storage = turboquant_storage(packed, &mut ctx)?; + let mask = storage.struct_validity().execute_mask(3, &mut ctx)?; + let norms: PrimitiveArray = storage + .unmasked_field_by_name("norms")? + .clone() + .execute(&mut ctx)?; + let codes: FixedSizeListArray = storage + .unmasked_field_by_name("codes")? + .clone() + .execute(&mut ctx)?; + + assert!(mask.value(0)); + assert!(!mask.value(1)); + assert!(mask.value(2)); + assert_eq!(norms.validity()?.nullability(), Nullability::Nullable); + assert_eq!(codes.validity()?.nullability(), Nullability::Nullable); + + let norms_validity = norms.validity()?.execute_mask(3, &mut ctx)?; + let codes_validity = codes.validity()?.execute_mask(3, &mut ctx)?; + assert!(norms_validity.value(0)); + assert!(!norms_validity.value(1)); + assert!(norms_validity.value(2)); + assert!(codes_validity.value(0)); + assert!(!codes_validity.value(1)); + assert!(codes_validity.value(2)); + + let codes_values: PrimitiveArray = codes.elements().clone().execute(&mut ctx)?; + assert!( + codes_values.as_slice::()[128..256] + .iter() + .all(|&code| code == 0) + ); + Ok(()) +} + +#[test] +fn normalize_as_l2_denorm_preserves_child_validity() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let mut values = vec![0.0f32; 3 * 128]; + values[0] = 3.0; + values[1] = 4.0; + values[128..256].fill(13.0); + values[256] = 1.0; + let input = vector_array(128, &values, Validity::from_iter([true, false, true]))?; + + let l2_denorm = tq_normalize_as_l2_denorm(input, &mut ctx)?; + let normalized = l2_denorm.child_at(0).clone(); + let norms = l2_denorm.child_at(1).clone(); + + let normalized_ext: ExtensionArray = normalized.execute(&mut ctx)?; + let normalized_fsl: FixedSizeListArray = + normalized_ext.storage_array().clone().execute(&mut ctx)?; + let normalized_values: PrimitiveArray = normalized_fsl.elements().clone().execute(&mut ctx)?; + let norms: PrimitiveArray = norms.execute(&mut ctx)?; + let normalized_validity = normalized_fsl.validity()?.execute_mask(3, &mut ctx)?; + let norms_validity = norms.validity()?.execute_mask(3, &mut ctx)?; + + assert!(normalized_validity.value(0)); + assert!(!normalized_validity.value(1)); + assert!(normalized_validity.value(2)); + assert!(norms_validity.value(0)); + assert!(!norms_validity.value(1)); + assert!(norms_validity.value(2)); + assert_eq!(norms.validity()?.nullability(), Nullability::Nullable); + assert_eq!(norms.as_slice::()[0], 5.0); + assert!( + normalized_values.as_slice::()[128..256] + .iter() + .all(|&value| value == 0.0) + ); + Ok(()) +} + +#[test] +fn pack_unpack_preserves_nulls_and_zero_norm_rows() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let mut values = vec![0.0f32; 3 * 128]; + values[0] = 3.0; + values[1] = 4.0; + values[256] = 1.0; + 
values[257] = -1.0; + let input = vector_array(128, &values, Validity::from_iter([true, true, false]))?; + + let packed = execute_tq_pack(input, &TurboQuantConfig::default(), &mut ctx)?; + let unpacked = execute_tq_unpack(packed, &TurboQuantConfig::default(), &mut ctx)?; + let output = vector_values_f32(unpacked.clone(), &mut ctx)?; + let validity = vector_validity(unpacked, &mut ctx)?.execute_mask(3, &mut ctx)?; + + assert!(validity.value(0)); + assert!(validity.value(1)); + assert!(!validity.value(2)); + assert!(output[128..256].iter().all(|&v| v == 0.0)); + Ok(()) +} + +#[rstest] +#[case::f16(PType::F16)] +#[case::f64(PType::F64)] +fn pack_unpack_supports_non_f32_inputs(#[case] ptype: PType) -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let config = TurboQuantConfig::try_new(3, 42, 3)?; + + match ptype { + PType::F16 => { + let values = (0..2 * 128) + .map(|i| half::f16::from_f32(((i % 17) as f32 - 8.0) * 0.25)) + .collect::>(); + let input = vector_array(128, &values, Validity::NonNullable)?; + let packed = execute_tq_pack(input, &config, &mut ctx)?; + let unpacked = execute_tq_unpack(packed, &config, &mut ctx)?; + let ext: ExtensionArray = unpacked.execute(&mut ctx)?; + assert_eq!(vector_element_ptype(&ext)?, PType::F16); + } + PType::F64 => { + let values = (0..2 * 128) + .map(|i| ((i % 17) as f64 - 8.0) * 0.25) + .collect::>(); + let input = vector_array(128, &values, Validity::NonNullable)?; + let packed = execute_tq_pack(input, &config, &mut ctx)?; + let unpacked = execute_tq_unpack(packed, &config, &mut ctx)?; + let ext: ExtensionArray = unpacked.execute(&mut ctx)?; + assert_eq!(vector_element_ptype(&ext)?, PType::F64); + } + _ => unreachable!("test only passes f16/f64"), + } + Ok(()) +} + +#[test] +fn unpack_scales_by_stored_norms() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let base = f32_vector_array(128, 1, 0.5, Validity::NonNullable)?; + let scaled = f32_vector_array(128, 1, 1.0, Validity::NonNullable)?; + let config = TurboQuantConfig::try_new(2, 99, 3)?; + + let base_values = vector_values_f32( + execute_tq_unpack(execute_tq_pack(base, &config, &mut ctx)?, &config, &mut ctx)?, + &mut ctx, + )?; + let scaled_values = vector_values_f32( + execute_tq_unpack( + execute_tq_pack(scaled, &config, &mut ctx)?, + &config, + &mut ctx, + )?, + &mut ctx, + )?; + + for (base, scaled) in base_values.iter().zip(scaled_values.iter()) { + assert!((*scaled - 2.0 * *base).abs() < 1e-5); + } + Ok(()) +} diff --git a/vortex-turboquant/src/tests/parity.rs b/vortex-turboquant/src/tests/parity.rs new file mode 100644 index 00000000000..b6cb6c1b0cd --- /dev/null +++ b/vortex-turboquant/src/tests/parity.rs @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::VortexSessionExecute; +use vortex_array::validity::Validity; +use vortex_error::VortexResult; + +use super::execute_tq_pack; +use super::execute_tq_unpack; +use super::f32_vector_array; +use super::test_session; +use super::vector_values_f32; +use crate::TurboQuantConfig; + +#[test] +fn unpack_pack_matches_old_turboquant_decode() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let input = f32_vector_array(128, 2, 0.125, Validity::NonNullable)?; + let config = TurboQuantConfig::try_new(3, 42, 3)?; + + let new_packed = execute_tq_pack(input.clone(), &config, &mut ctx)?; + let new_decoded = 
execute_tq_unpack(new_packed, &config, &mut ctx)?; + let old_config = vortex_tensor::encodings::turboquant::TurboQuantConfig { + bit_width: config.bit_width(), + seed: config.seed(), + num_rounds: config.num_rounds(), + }; + let old_decoded = + vortex_tensor::encodings::turboquant::turboquant_encode(input, &old_config, &mut ctx)? + .execute(&mut ctx)?; + + let new_values = vector_values_f32(new_decoded, &mut ctx)?; + let old_values = vector_values_f32(old_decoded, &mut ctx)?; + + assert_eq!(new_values, old_values); + Ok(()) +} diff --git a/vortex-turboquant/src/tests/scalar_fns.rs b/vortex-turboquant/src/tests/scalar_fns.rs new file mode 100644 index 00000000000..a937643d217 --- /dev/null +++ b/vortex-turboquant/src/tests/scalar_fns.rs @@ -0,0 +1,105 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::ArrayPlugin; +use vortex_array::IntoArray; +use vortex_array::VortexSessionExecute; +use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayPlugin; +use vortex_array::scalar_fn::ScalarFnVTable; +use vortex_array::validity::Validity; +use vortex_error::VortexResult; + +use super::execute_tq_pack; +use super::f32_vector_array; +use super::test_session; +use super::vector_validity; +use crate::TQPack; +use crate::TQUnpack; +use crate::TurboQuant; +use crate::TurboQuantConfig; +use crate::vtable::tq_metadata; + +#[test] +fn scalar_fn_ids_and_config_options_roundtrip() -> VortexResult<()> { + let session = test_session(); + let config = TurboQuantConfig::try_new(4, 7, 2)?; + + assert_eq!(TQPack.id().as_ref(), "vortex.turboquant.pack"); + assert_eq!(TQUnpack.id().as_ref(), "vortex.turboquant.unpack"); + + let pack_metadata = TQPack.serialize(&config)?.unwrap(); + let unpack_metadata = TQUnpack.serialize(&config)?.unwrap(); + + assert_eq!(TQPack.deserialize(&pack_metadata, &session)?, config); + assert_eq!(TQUnpack.deserialize(&unpack_metadata, &session)?, config); + Ok(()) +} + +#[test] +fn scalar_fn_arrays_pack_and_unpack_vectors() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let input = f32_vector_array(128, 2, 0.25, Validity::from_iter([true, false]))?; + let config = TurboQuantConfig::try_new(3, 42, 3)?; + + let packed_lazy = TQPack::try_new_array(input, &config, 2)?; + let packed_metadata = tq_metadata(packed_lazy.dtype())?; + assert_eq!(packed_metadata.dimensions, 128); + assert_eq!(packed_metadata.bit_width, config.bit_width()); + assert!(packed_lazy.dtype().as_extension().is::()); + + let packed = packed_lazy.into_array().execute(&mut ctx)?; + let unpacked_lazy = TQUnpack::try_new_array(packed, &config, 2)?; + let unpacked = unpacked_lazy.into_array().execute(&mut ctx)?; + let validity = vector_validity(unpacked, &mut ctx)?.execute_mask(2, &mut ctx)?; + + assert!(validity.value(0)); + assert!(!validity.value(1)); + Ok(()) +} + +#[test] +fn scalar_fn_array_metadata_stores_only_config() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let input = f32_vector_array(128, 2, 0.25, Validity::NonNullable)?; + let config = TurboQuantConfig::try_new(3, 42, 3)?; + let config_metadata = TQUnpack.serialize(&config)?.unwrap(); + + let packed = execute_tq_pack(input, &config, &mut ctx)?; + let unpacked_lazy = TQUnpack::try_new_array(packed, &config, 2)?.into_array(); + let unpack_plugin = ScalarFnArrayPlugin::new(TQUnpack); + assert_eq!( + unpack_plugin.serialize(&unpacked_lazy, &session)?.unwrap(), + config_metadata 
+ ); + Ok(()) +} + +#[test] +fn pack_scalar_fn_array_is_not_serializable() -> VortexResult<()> { + let session = test_session(); + let input = f32_vector_array(128, 2, 0.25, Validity::NonNullable)?; + let config = TurboQuantConfig::try_new(3, 42, 3)?; + let packed_lazy = TQPack::try_new_array(input, &config, 2)?.into_array(); + let pack_plugin = ScalarFnArrayPlugin::new(TQPack); + + assert!(pack_plugin.serialize(&packed_lazy, &session)?.is_none()); + Ok(()) +} + +#[test] +fn unpack_rejects_config_that_disagrees_with_turboquant_child() -> VortexResult<()> { + let session = test_session(); + let mut ctx = session.create_execution_ctx(); + let input = f32_vector_array(128, 1, 0.25, Validity::NonNullable)?; + let config = TurboQuantConfig::try_new(3, 42, 3)?; + let different_config = TurboQuantConfig::try_new(4, 42, 3)?; + + let packed = TQPack::try_new_array(input, &config, 1)? + .into_array() + .execute(&mut ctx)?; + + assert!(TQUnpack::try_new_array(packed, &different_config, 1).is_err()); + Ok(()) +} diff --git a/vortex-turboquant/src/vector/mod.rs b/vortex-turboquant/src/vector/mod.rs new file mode 100644 index 00000000000..3a7c978cfae --- /dev/null +++ b/vortex-turboquant/src/vector/mod.rs @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +mod quantize; + +pub(crate) mod normalize; +pub(crate) mod pack; +pub(crate) mod storage; +pub(crate) mod unpack; + +use vortex_error::VortexResult; +use vortex_error::vortex_err; + +/// Compute the padded SORF dimension for an original vector dimension. +pub(crate) fn tq_padded_dim(dimensions: u32) -> VortexResult { + let padded_dim = dimensions + .checked_next_power_of_two() + .ok_or_else(|| vortex_err!("TurboQuant padded dimension overflow for {dimensions}"))?; + + usize::try_from(padded_dim) + .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit usize")) +} diff --git a/vortex-turboquant/src/vector/normalize.rs b/vortex-turboquant/src/vector/normalize.rs new file mode 100644 index 00000000000..5a66f7c0e4f --- /dev/null +++ b/vortex-turboquant/src/vector/normalize.rs @@ -0,0 +1,146 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant-local vector normalization. + +// TODO(connor): Remove this comment once we delete the other version in `vortex-tensor`. +// The tensor crate also has a `normalize_as_l2_denorm` helper, but TurboQuant needs different +// validity semantics: a null vector is not a zero vector, so invalid rows keep their row validity +// on both `L2Denorm` children and downstream quantization skips them. 
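+// For example, input row validity `[true, false, true]` ends up as the mask on both the normalized
+// vectors and the norms, and the invalid middle row is filled with zero placeholders rather than
+// being treated as a genuine zero vector.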
+ +use num_traits::Float; +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::ExtensionArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::ScalarFnArray; +use vortex_array::arrays::extension::ExtensionArrayExt; +use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; +use vortex_array::dtype::NativePType; +use vortex_array::extension::EmptyMetadata; +use vortex_array::match_each_float_ptype; +use vortex_array::validity::Validity; +use vortex_buffer::BufferMut; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure_eq; +use vortex_error::vortex_err; +use vortex_mask::Mask; +use vortex_tensor::scalar_fns::l2_denorm::L2Denorm; +use vortex_tensor::scalar_fns::l2_norm::L2Norm; +use vortex_tensor::vector::AnyVector; +use vortex_tensor::vector::Vector; + +/// Normalize a `Vector` array and wrap it with its original row norms with [`L2Denorm`]. +/// +/// This preserves input row validity on both [`L2Denorm`] children. Or in other words, validity is +/// propagated down to the children so that TurboQuant can skip quantizing those vectors (as it does +/// not have a good way to represent 0 vectors in its quantized domain). +pub(crate) fn tq_normalize_as_l2_denorm( + input: ArrayRef, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let row_count = input.len(); + let vector_metadata = input + .dtype() + .as_extension_opt() + .and_then(|ext_dtype| ext_dtype.metadata_opt::()) + .ok_or_else(|| vortex_err!("TurboQuant normalization expects a Vector extension array"))?; + let dimensions = vector_metadata.dimensions() as usize; + let vector_validity = input.validity()?; + + // Use `L2Norm` to calculate the normals for each vector. + let norms: ArrayRef = L2Norm::try_new_array(input.clone(), row_count)? + .into_array() + .execute(ctx)?; + let primitive_norms: PrimitiveArray = norms.clone().execute(ctx)?; + + let input: ExtensionArray = input.execute(ctx)?; + let storage: FixedSizeListArray = input.storage_array().clone().execute(ctx)?; + vortex_ensure_eq!( + storage.list_size() as usize, + dimensions, + "Vector storage dimension must be {dimensions}, got {}", + storage.list_size() + ); + let elements: PrimitiveArray = storage.elements().clone().execute(ctx)?; + + let mask = vector_validity.execute_mask(row_count, ctx)?; + + let normalized = match_each_float_ptype!(elements.ptype(), |T| { + normalize_vectors::( + &elements, + &primitive_norms, + &mask, + dimensions, + vector_validity.clone(), + ) + })?; + + // SAFETY: matches the lossy-encoding relaxation documented on + // [`L2Denorm::new_array_unchecked`]. Norms come from `L2Norm` over the same input, so they + // match the vector element type and row count. Valid nonzero rows are divided by their stored + // norm and are unit-norm. Valid zero-norm rows and invalid rows use physical zero placeholders; + // invalid rows remain guarded by row-level invalid validity. 
+ unsafe { L2Denorm::new_array_unchecked(normalized, norms, row_count) } +} + +fn normalize_vectors( + elements: &PrimitiveArray, + norms: &PrimitiveArray, + mask: &Mask, + dimensions: usize, + vector_validity: Validity, +) -> VortexResult +where + T: Float + NativePType, +{ + let num_vectors = norms.len(); + + let values = elements.as_slice::(); + let norm_values = norms.as_slice::(); + + let output_len = num_vectors + .checked_mul(dimensions) + .ok_or_else(|| vortex_err!("TurboQuant normalized vector length overflow"))?; + let mut output = BufferMut::::with_capacity(output_len); + + for i in 0..num_vectors { + let row_values = &values[i * dimensions..][..dimensions]; + let norm = norm_values[i]; + + if !mask.value(i) || norm == T::zero() { + // Invalid rows are placeholders guarded by validity. A valid vector with L2 norm zero + // is all zeros. In both cases, append zero placeholders without per-element work. + + // SAFETY: `output` was allocated with capacity `output_len`, and the loop appends + // exactly `dimensions` values for each of `num_vectors` iterations. + unsafe { output.push_n_unchecked(T::zero(), dimensions) }; + } else { + for &value in row_values { + // SAFETY: same capacity invariant as above. + unsafe { output.push_unchecked(value / norm) }; + } + } + } + + // Vector elements are always non-nullable. + let elements = PrimitiveArray::new::(output.freeze(), Validity::NonNullable); + + #[expect( + clippy::cast_possible_truncation, + reason = "this initially came from a u32" + )] + let storage = FixedSizeListArray::try_new( + elements.into_array(), + dimensions as u32, + vector_validity, + num_vectors, + )?; + + Ok( + ExtensionArray::try_new_from_vtable(Vector, EmptyMetadata, storage.into_array())? + .into_array(), + ) +} diff --git a/vortex-turboquant/src/vector/pack.rs b/vortex-turboquant/src/vector/pack.rs new file mode 100644 index 00000000000..e277ed1c847 --- /dev/null +++ b/vortex-turboquant/src/vector/pack.rs @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant packing (quantization) logic. +//! +//! The input to [`pack_vector()`] must be a [`Vector`](vortex_tensor::vector::Vector) extension +//! array. [`pack_vector()`] computes original row norms, normalizes valid rows internally via +//! [`tq_normalize_as_l2_denorm()`], quantizes the normalized child, and stores row-aligned norms and +//! codes in the TurboQuant extension storage. + +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::Extension; +use vortex_array::arrays::ExtensionArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::extension::ExtensionArrayExt; +use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; +use vortex_error::VortexResult; +use vortex_error::vortex_ensure; +use vortex_error::vortex_err; +use vortex_tensor::vector::AnyVector; + +use super::normalize::tq_normalize_as_l2_denorm; +use super::quantize::empty_quantization; +use super::quantize::turboquant_quantize_core; +use super::storage::build_codes_child; +use super::storage::build_storage; +use super::tq_padded_dim; +use crate::TurboQuantConfig; +use crate::config::MIN_DIMENSION; +use crate::vtable::TurboQuant; +use crate::vtable::TurboQuantMetadata; + +/// Lossily pack a `Vector` extension array into a `TurboQuant` extension array. +/// +/// Valid rows are normalized internally before SORF transform and scalar quantization. 
The original +/// row norms are stored explicitly, and original vector nulls are preserved on the storage struct +/// and both row-aligned child arrays. +pub(crate) fn pack_vector( + input: ArrayRef, + config: &TurboQuantConfig, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let num_vectors = input.len(); + let vector_metadata = input + .dtype() + .as_extension_opt() + .and_then(|ext_dtype| ext_dtype.metadata_opt::()) + .ok_or_else(|| vortex_err!("TurboQuant pack expects a Vector extension array"))?; + + let element_ptype = vector_metadata.element_ptype(); + + let dimensions = vector_metadata.dimensions(); + vortex_ensure!( + dimensions >= MIN_DIMENSION, + "TurboQuant requires dimension >= {MIN_DIMENSION}, got {dimensions}", + ); + let padded_dim = tq_padded_dim(dimensions)?; + + let vector_validity = input.validity()?; + + // We must normalize the vectors in order to apply the transform during quantization. + // NB: The 2 child arrays share the same validity with `input`. + let l2_denorm = tq_normalize_as_l2_denorm(input, ctx)?; + let normalized = l2_denorm.child_at(0).clone(); + let norms = l2_denorm.child_at(1).clone(); + + let normalized_ext = normalized + .as_opt::() + .ok_or_else(|| vortex_err!("normalized TurboQuant input must be a Vector extension"))?; + let normalized_fsl: FixedSizeListArray = normalized_ext.storage_array().clone().execute(ctx)?; + + let core = if normalized_fsl.is_empty() { + empty_quantization(padded_dim) + } else { + // SAFETY: `tq_normalize_as_l2_denorm` returned this normalized Vector child. + unsafe { turboquant_quantize_core(&normalized_fsl, config, ctx)? } + }; + let codes = build_codes_child(num_vectors, core, vector_validity.clone())?; + + // Now that we have the codes into the centroid codebook and the norms, we can build the + // TurboQuant extension array. + let metadata = TurboQuantMetadata { + element_ptype, + dimensions, + bit_width: config.bit_width(), + seed: config.seed(), + num_rounds: config.num_rounds(), + }; + let storage = build_storage(norms, codes, num_vectors, vector_validity)?; + + Ok(ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage)?.into_array()) +} diff --git a/vortex-turboquant/src/vector/quantize.rs b/vortex-turboquant/src/vector/quantize.rs new file mode 100644 index 00000000000..c551d33861f --- /dev/null +++ b/vortex-turboquant/src/vector/quantize.rs @@ -0,0 +1,150 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Core TurboQuant quantization helpers. +//! +//! Quantization consumes the TurboQuant-local normalized `Vector` child. Valid rows are transformed +//! and mapped to scalar centroid indices. Invalid rows remain in the full-length output but are +//! skipped: their physical code bytes are placeholders guarded by the `codes` row validity. +//! +//! This matters because TurboQuant's scalar codebook is optimized for coordinates of transformed +//! unit-norm vectors. The codebook does not generally contain an exact zero centroid, and a +//! physical code byte of `0` means "centroid 0", not "zero coordinate". Null vectors therefore +//! should not be converted to zero vectors and fed through the quantizer. 
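+//!
+//! As a schematic illustration (hypothetical centroid values): with a 2-bit codebook of
+//! `[-0.7, -0.2, 0.2, 0.7]`, the stored code byte `0` decodes to the centroid `-0.7`, not to the
+//! coordinate `0.0`, so feeding a null row through the quantizer as zeros would decode to a
+//! garbage direction instead of staying null.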
+ +use half::f16; +use vortex_array::ExecutionCtx; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; +use vortex_array::dtype::PType; +use vortex_buffer::Buffer; +use vortex_buffer::BufferMut; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_err; + +use super::tq_padded_dim; +use crate::TurboQuantConfig; +use crate::centroids::compute_centroid_boundaries; +use crate::centroids::compute_or_get_centroids; +use crate::centroids::find_nearest_centroid; +use crate::sorf::SorfMatrix; + +/// Shared intermediate results from the quantization loop. +pub(crate) struct QuantizationResult { + pub(crate) all_indices: Buffer, + pub(crate) padded_dim: usize, +} + +pub(crate) fn empty_quantization(padded_dim: usize) -> QuantizationResult { + QuantizationResult { + all_indices: Buffer::empty(), + padded_dim, + } +} + +/// Core quantization: transform and quantize already-normalized rows. +/// +/// # Safety +/// +/// The input `fsl` must contain unit-norm vectors (already L2-normalized) for every valid row. +/// Invalid rows are left row-aligned in the output but are not transformed or quantized. The +/// transform and centroid lookup happen in f32. +pub(crate) unsafe fn turboquant_quantize_core( + fsl: &FixedSizeListArray, + config: &TurboQuantConfig, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let dimension = fsl.list_size(); + let num_vectors = fsl.len(); + let padded_dim = tq_padded_dim(dimension)?; + + let sorf_transform = + SorfMatrix::try_new(padded_dim, config.num_rounds() as usize, config.seed())?; + debug_assert_eq!(sorf_transform.padded_dim(), padded_dim); + let padded_dim_u32 = u32::try_from(padded_dim) + .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit u32"))?; + + let elements_prim: PrimitiveArray = fsl.elements().clone().execute(ctx)?; + let f32_elements = cast_to_f32(elements_prim)?; + let validity = fsl.validity()?; + let mask = validity.execute_mask(num_vectors, ctx)?; + + let centroids = compute_or_get_centroids(padded_dim_u32, config.bit_width())?; + let boundaries = compute_centroid_boundaries(¢roids); + + let codes_len = num_vectors + .checked_mul(padded_dim) + .ok_or_else(|| vortex_err!("TurboQuant codes length overflow"))?; + let mut all_indices = BufferMut::::with_capacity(codes_len); + let mut padded = vec![0.0f32; padded_dim]; + let mut transformed = vec![0.0f32; padded_dim]; + + let f32_slice = f32_elements.as_slice(); + let dimension = dimension as usize; + for row in 0..num_vectors { + if !mask.value(row) { + // The row-level FSL validity marks these bytes invalid, so keep full-length storage + // without spending SORF/quantization work on a null vector. + + // SAFETY: `all_indices` was allocated with capacity `codes_len`, and the loop appends + // exactly `padded_dim` codes for each of `num_vectors` iterations. + unsafe { all_indices.push_n_unchecked(0, padded_dim) }; + + continue; + } + + let x = &f32_slice[row * dimension..(row + 1) * dimension]; + + // Zero-pad to the next power of 2. + padded[..dimension].copy_from_slice(x); + padded[dimension..].fill(0.0); + + sorf_transform.transform(&padded, &mut transformed); + + // SAFETY: `all_indices` was allocated with capacity `codes_len`, and the loop appends + // exactly `padded_dim` codes for each of `num_vectors` iterations. 
+        for &value in transformed.iter() {
+            unsafe { all_indices.push_unchecked(find_nearest_centroid(value, &boundaries)) };
+        }
+    }
+
+    Ok(QuantizationResult {
+        all_indices: all_indices.freeze(),
+        padded_dim,
+    })
+}
+
+/// Cast a float [`PrimitiveArray`] to a `Buffer<f32>`.
+///
+/// Several operations in this crate (SORF transform, TurboQuant quantization) work exclusively
+/// in f32. This function handles the cast from any float ptype:
+///
+/// - f16: losslessly widened to f32.
+/// - f32: zero-copy buffer extraction.
+/// - f64: truncated to f32 precision. Values outside f32 range become +/- infinity. This is
+///   acceptable because callers of this function operate in f32 and document this constraint.
+fn cast_to_f32(prim: PrimitiveArray) -> VortexResult<Buffer<f32>> {
+    match prim.ptype() {
+        PType::F16 => Ok(prim
+            .as_slice::<f16>()
+            .iter()
+            .map(|&v| f32::from(v))
+            .collect()),
+        PType::F32 => Ok(prim.into_buffer()),
+        PType::F64 => Ok(prim
+            .as_slice::<f64>()
+            .iter()
+            .map(|&v| {
+                #[expect(
+                    clippy::cast_possible_truncation,
+                    reason = "f64 values outside f32 range become infinity, matching tensor TQ"
+                )]
+                let v = v as f32;
+                v
+            })
+            .collect()),
+        other => vortex_bail!("expected float elements, got {other:?}"),
+    }
+}
diff --git a/vortex-turboquant/src/vector/storage.rs b/vortex-turboquant/src/vector/storage.rs
new file mode 100644
index 00000000000..4794d51d323
--- /dev/null
+++ b/vortex-turboquant/src/vector/storage.rs
@@ -0,0 +1,167 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright the Vortex contributors
+
+//! TurboQuant physical storage helpers.
+//!
+//! TurboQuant storage is row-aligned and full length:
+//!
+//! ```text
+//! Struct {
+//!     norms: Primitive<element_ptype, vector_validity>,
+//!     codes: FixedSizeList<Primitive<u8>, padded_dim, vector_validity>,
+//! }
+//! ```
+//!
+//! Row nullability is carried on the outer struct and on the `norms` and `codes` field arrays.
+//! This is deliberate duplication: null vectors remain null throughout pack/unpack instead of being
+//! converted into zero vectors. The code bytes for invalid rows are physical placeholders only; the
+//! field-level validity records that those rows were not quantized.
+//!
+//! Parsing treats the outer struct validity as authoritative. Child validity may be wider than the
+//! struct validity (for example, after a generic mask only updates the struct validity), but each
+//! child must be valid wherever the struct row is valid.
+
+use vortex_array::ArrayRef;
+use vortex_array::ExecutionCtx;
+use vortex_array::IntoArray;
+use vortex_array::arrays::ExtensionArray;
+use vortex_array::arrays::FixedSizeListArray;
+use vortex_array::arrays::PrimitiveArray;
+use vortex_array::arrays::StructArray;
+use vortex_array::arrays::extension::ExtensionArrayExt;
+use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt;
+use vortex_array::arrays::struct_::StructArrayExt;
+use vortex_array::dtype::FieldNames;
+use vortex_array::validity::Validity;
+use vortex_error::VortexResult;
+use vortex_error::vortex_ensure;
+use vortex_error::vortex_err;
+
+use super::quantize::QuantizationResult;
+use crate::vtable::TurboQuantMetadata;
+use crate::vtable::tq_metadata;
+
+/// Name of the stored row-norm child.
+pub(crate) const NORMS_FIELD: &str = "norms";
+
+/// Name of the stored quantized-code child.
+pub(crate) const CODES_FIELD: &str = "codes";
+
+/// Parsed TurboQuant storage arrays.
+///
+/// A helper for working with the executed storage children of a TurboQuant extension array.
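+///
+/// `norms` holds one value per row; `codes` holds the flattened `u8` code bytes of the codes
+/// child, so row `i` occupies `codes[i * padded_dim..][..padded_dim]`.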
+pub(crate) struct TurboQuantParsedStorage { + pub(crate) metadata: TurboQuantMetadata, + pub(crate) vector_validity: Validity, + pub(crate) norms: PrimitiveArray, + pub(crate) codes: PrimitiveArray, + pub(crate) len: usize, +} + +/// Build the `codes: FixedSizeList, padded_dim>` storage child. +/// +/// Each row of `padded_dim` u8 codes indexes into the deterministic centroid codebook derived +/// from `(padded_dim, bit_width)`. The centroid values are intentionally not stored in the array. +pub(crate) fn build_codes_child( + num_vectors: usize, + quantization: QuantizationResult, + vector_validity: Validity, +) -> VortexResult { + let codes = PrimitiveArray::new::(quantization.all_indices, Validity::NonNullable); + let padded_dim_u32 = u32::try_from(quantization.padded_dim) + .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit u32"))?; + + Ok(FixedSizeListArray::try_new( + codes.into_array(), + padded_dim_u32, + vector_validity, + num_vectors, + )? + .into_array()) +} + +/// Build the TurboQuant `Struct { norms, codes }` storage array. +pub(crate) fn build_storage( + norms: ArrayRef, + codes: ArrayRef, + len: usize, + vector_validity: Validity, +) -> VortexResult { + Ok(StructArray::try_new( + FieldNames::from([NORMS_FIELD, CODES_FIELD]), + vec![norms, codes], + len, + vector_validity, + )? + .into_array()) +} + +/// Parse a TurboQuant extension array into executed storage children. +pub(crate) fn parse_storage( + input: ArrayRef, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let metadata = tq_metadata(input.dtype())?; + let ext: ExtensionArray = input.execute(ctx)?; + let storage: StructArray = ext.storage_array().clone().execute(ctx)?; + + let norms: PrimitiveArray = storage + .unmasked_field_by_name(NORMS_FIELD)? + .clone() + .execute(ctx)?; + + let codes_fsl: FixedSizeListArray = storage + .unmasked_field_by_name(CODES_FIELD)? + .clone() + .execute(ctx)?; + let codes: PrimitiveArray = codes_fsl.elements().clone().execute(ctx)?; + + let len = storage.len(); + let struct_validity = storage.struct_validity(); + let norms_validity = norms.validity()?; + let codes_validity = codes_fsl.validity()?; + validate_child_validity_covers_struct( + &struct_validity, + &norms_validity, + &codes_validity, + len, + ctx, + )?; + + Ok(TurboQuantParsedStorage { + metadata, + vector_validity: struct_validity, + norms, + codes, + len, + }) +} + +fn validate_child_validity_covers_struct( + struct_validity: &Validity, + norms_validity: &Validity, + codes_validity: &Validity, + len: usize, + ctx: &mut ExecutionCtx, +) -> VortexResult<()> { + let struct_mask = struct_validity.execute_mask(len, ctx)?; + let norms_mask = norms_validity.execute_mask(len, ctx)?; + let codes_mask = codes_validity.execute_mask(len, ctx)?; + + for row in 0..len { + if !struct_mask.value(row) { + continue; + } + + vortex_ensure!( + norms_mask.value(row), + "TurboQuant {NORMS_FIELD} row validity must cover storage validity" + ); + vortex_ensure!( + codes_mask.value(row), + "TurboQuant {CODES_FIELD} row validity must cover storage validity" + ); + } + + Ok(()) +} diff --git a/vortex-turboquant/src/vector/unpack.rs b/vortex-turboquant/src/vector/unpack.rs new file mode 100644 index 00000000000..0c812d7bc02 --- /dev/null +++ b/vortex-turboquant/src/vector/unpack.rs @@ -0,0 +1,166 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! TurboQuant unpacking (dequantization) logic. +//! +//! 
Note that because TurboQuant is a lossy compression scheme, unpacking does not roundtrip with
+//! the initial packing.
+
+use num_traits::Float;
+use num_traits::FromPrimitive;
+use vortex_array::ArrayRef;
+use vortex_array::ExecutionCtx;
+use vortex_array::IntoArray;
+use vortex_array::arrays::FixedSizeListArray;
+use vortex_array::arrays::PrimitiveArray;
+use vortex_array::dtype::NativePType;
+use vortex_array::dtype::Nullability;
+use vortex_array::match_each_float_ptype;
+use vortex_array::validity::Validity;
+use vortex_buffer::BufferMut;
+use vortex_error::VortexExpect;
+use vortex_error::VortexResult;
+use vortex_error::vortex_err;
+use vortex_tensor::vector::Vector;
+
+use super::storage::parse_storage;
+use super::tq_padded_dim;
+use crate::centroids::compute_or_get_centroids;
+use crate::sorf::SorfMatrix;
+use crate::vtable::TurboQuantMetadata;
+
+/// Decode a `TurboQuant` extension array back into a `Vector` extension array.
+///
+/// The decoded directions are inverse-transformed, truncated to the original dimension, and
+/// multiplied by the stored row norms. The conversion is lossy and does not roundtrip with
+/// [`TQPack`](crate::TQPack).
+pub(crate) fn unpack_vector(input: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
+    // Get the input TurboQuant array into a form that is easier to work with.
+    let parsed = parse_storage(input, ctx)?;
+    let metadata = parsed.metadata;
+    if parsed.len == 0 {
+        return build_empty_vector(metadata, parsed.vector_validity);
+    }
+
+    let padded_dim = tq_padded_dim(metadata.dimensions)?;
+    let transform = SorfMatrix::try_new(padded_dim, metadata.num_rounds as usize, metadata.seed)?;
+    let padded_dim = u32::try_from(padded_dim)
+        .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit u32"))?;
+
+    // The centroids are fully determined by `(padded_dim, bit_width)`, so we compute them (or
+    // fetch them from a cache) on read instead of storing them in the array.
+ let centroids = compute_or_get_centroids(padded_dim, metadata.bit_width)?; + + match_each_float_ptype!(metadata.element_ptype, |T| { + unpack_typed::( + DecodeInputs { + metadata: &metadata, + sorf_matrix: &transform, + centroids: ¢roids, + norms: &parsed.norms, + codes: &parsed.codes, + }, + parsed.vector_validity, + parsed.len, + ctx, + ) + }) +} + +fn build_empty_vector( + metadata: TurboQuantMetadata, + vector_validity: Validity, +) -> VortexResult { + match_each_float_ptype!(metadata.element_ptype, |T| { + let elements = PrimitiveArray::empty::(Nullability::NonNullable); + let fsl = FixedSizeListArray::try_new( + elements.into_array(), + metadata.dimensions, + vector_validity, + 0, + )?; + + Vector::try_new_vector_array(fsl.into_array()) + }) +} + +struct DecodeInputs<'a> { + metadata: &'a TurboQuantMetadata, + sorf_matrix: &'a SorfMatrix, + centroids: &'a [f32], + norms: &'a PrimitiveArray, + codes: &'a PrimitiveArray, +} + +fn unpack_typed( + decode: DecodeInputs<'_>, + vector_validity: Validity, + num_vectors: usize, + ctx: &mut ExecutionCtx, +) -> VortexResult +where + T: NativePType + Float + FromPrimitive, +{ + let metadata = decode.metadata; + let dimensions = usize::try_from(metadata.dimensions) + .vortex_expect("dimensions stays representable as usize"); + let padded_dim = decode.sorf_matrix.padded_dim(); + let centroids = decode.centroids; + let norms = decode.norms.as_slice::(); + let codes = decode.codes.as_slice::(); + let mask = vector_validity.execute_mask(num_vectors, ctx)?; + + let mut decoded = vec![0.0f32; padded_dim]; + let mut inverse = vec![0.0f32; padded_dim]; + let output_len = num_vectors + .checked_mul(dimensions) + .ok_or_else(|| vortex_err!("TurboQuant decoded vector length overflow"))?; + let mut output = BufferMut::::with_capacity(output_len); + + for i in 0..num_vectors { + if !mask.value(i) { + // Null rows still need `dimensions` placeholder elements so the FSL stays row-aligned. + // The row's nullability bit is authoritative, so consumers should be treating these + // zeros as undefined. + + // SAFETY: `output` was allocated with capacity `output_len`, and the loop appends + // exactly `dimensions` values for each of `num_vectors` iterations. + unsafe { output.push_n_unchecked(T::zero(), dimensions) }; + + continue; + } + + // Perform the gather from codes to values. + let code_row = &codes[i * padded_dim..][..padded_dim]; + for (dst, &code) in decoded.iter_mut().zip(code_row.iter()) { + *dst = *centroids + .get(usize::from(code)) + .vortex_expect("TurboQuant code exceeds centroid count"); + } + + // Undo the transform. + decode.sorf_matrix.inverse_transform(&decoded, &mut inverse); + + // Multiply all elements by the corresponding normal. + let norm = norms[i]; + for &value in inverse.iter().take(dimensions) { + // `T::from_f32` is infallible for the supported float ptypes (`f16`, `f32`, `f64`): + // values outside `f16` range saturate to `±inf` rather than returning `None`. + let value = T::from_f32(value) + .vortex_expect("from_f32 is infallible for supported float types"); + + // SAFETY: same capacity invariant as above. 
+ unsafe { output.push_unchecked(value * norm) }; + } + } + + let elements = PrimitiveArray::new::(output.freeze(), Validity::NonNullable); + let fsl = FixedSizeListArray::try_new( + elements.into_array(), + metadata.dimensions, + vector_validity, + num_vectors, + )?; + + Vector::try_new_vector_array(fsl.into_array()) +} diff --git a/vortex-turboquant/src/vtable.rs b/vortex-turboquant/src/vtable.rs new file mode 100644 index 00000000000..51dd0933ca6 --- /dev/null +++ b/vortex-turboquant/src/vtable.rs @@ -0,0 +1,233 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use std::fmt; +use std::sync::Arc; + +use prost::Message; +use vortex_array::dtype::DType; +use vortex_array::dtype::FieldNames; +use vortex_array::dtype::Nullability; +use vortex_array::dtype::PType; +use vortex_array::dtype::StructFields; +use vortex_array::dtype::extension::ExtDType; +use vortex_array::dtype::extension::ExtId; +use vortex_array::dtype::extension::ExtVTable; +use vortex_array::scalar::ScalarValue; +use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_ensure; +use vortex_error::vortex_ensure_eq; +use vortex_error::vortex_err; + +use crate::TurboQuantConfig; +use crate::config::MIN_DIMENSION; +use crate::vector::storage::CODES_FIELD; +use crate::vector::storage::NORMS_FIELD; +use crate::vector::tq_padded_dim; + +/// TurboQuant logical extension type. +#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] +pub struct TurboQuant; + +/// Serialized metadata for a TurboQuant extension array. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub struct TurboQuantMetadata { + /// Original vector element type and stored norm type. + pub element_ptype: PType, + /// Original vector dimension before SORF padding. + pub dimensions: u32, + /// Bits per coordinate in the scalar quantizer codebook. + pub bit_width: u8, + /// Seed used to derive the deterministic SORF transform. + pub seed: u64, + /// Number of sign-diagonal plus Walsh-Hadamard rounds in the SORF transform. 
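+    ///
+    /// Unpack rebuilds the same deterministic SORF transform from `seed` and `num_rounds`, so
+    /// decoding must use the values recorded here at pack time.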
+ pub num_rounds: u8, +} + +impl ExtVTable for TurboQuant { + type Metadata = TurboQuantMetadata; + type NativeValue<'a> = &'a ScalarValue; + + fn id(&self) -> ExtId { + ExtId::new("vortex.turboquant") + } + + fn serialize_metadata(&self, metadata: &Self::Metadata) -> VortexResult> { + validate_tq_metadata(metadata)?; + + let proto = TurboQuantMetadataProto { + element_ptype: metadata.element_ptype as i32, + dimensions: metadata.dimensions, + bit_width: u32::from(metadata.bit_width), + seed: metadata.seed, + num_rounds: u32::from(metadata.num_rounds), + }; + + Ok(proto.encode_to_vec()) + } + + fn deserialize_metadata(&self, metadata: &[u8]) -> VortexResult { + let proto = TurboQuantMetadataProto::decode(metadata) + .map_err(|e| vortex_err!("Failed to decode TurboQuantMetadata: {e}"))?; + let bit_width = u8::try_from(proto.bit_width) + .map_err(|_| vortex_err!("TurboQuant bit_width does not fit u8"))?; + let num_rounds = u8::try_from(proto.num_rounds) + .map_err(|_| vortex_err!("TurboQuant num_rounds does not fit u8"))?; + let element_ptype = PType::try_from(proto.element_ptype).map_err(|e| { + vortex_err!( + "invalid TurboQuant metadata element_ptype code {}: {e}", + proto.element_ptype + ) + })?; + + let metadata = TurboQuantMetadata { + element_ptype, + dimensions: proto.dimensions, + bit_width, + seed: proto.seed, + num_rounds, + }; + validate_tq_metadata(&metadata)?; + + Ok(metadata) + } + + fn validate_dtype(ext_dtype: &ExtDType) -> VortexResult<()> { + validate_tq_metadata(ext_dtype.metadata())?; + validate_tq_storage_dtype(ext_dtype.metadata(), ext_dtype.storage_dtype()) + } + + fn unpack_native<'a>( + _ext_dtype: &'a ExtDType, + storage_value: &'a ScalarValue, + ) -> VortexResult> { + Ok(storage_value) + } +} + +#[derive(Clone, PartialEq, Message)] +struct TurboQuantMetadataProto { + #[prost(enumeration = "PType", tag = "1")] + element_ptype: i32, + #[prost(uint32, tag = "2")] + dimensions: u32, + #[prost(uint32, tag = "3")] + bit_width: u32, + #[prost(uint64, tag = "4")] + seed: u64, + #[prost(uint32, tag = "5")] + num_rounds: u32, +} + +/// Extract TurboQuant metadata from a dtype. +/// +/// Returns an error when the dtype is not the TurboQuant extension type. +pub(crate) fn tq_metadata(dtype: &DType) -> VortexResult { + let ext_dtype = dtype + .as_extension_opt() + .ok_or_else(|| vortex_err!("expected a TurboQuant extension array, got {dtype}"))?; + + let metadata = ext_dtype + .metadata_opt::() + .ok_or_else(|| vortex_err!("expected a TurboQuant extension array, got {dtype}"))?; + + Ok(*metadata) +} + +pub(crate) fn tq_storage_dtype( + metadata: &TurboQuantMetadata, + row_nullability: Nullability, +) -> VortexResult { + let padded_dim = u32::try_from(tq_padded_dim(metadata.dimensions)?) 
+ .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit u32"))?; + + Ok(DType::Struct( + StructFields::new( + FieldNames::from([NORMS_FIELD, CODES_FIELD]), + vec![ + DType::Primitive(metadata.element_ptype, row_nullability), + DType::FixedSizeList( + Arc::new(DType::Primitive(PType::U8, Nullability::NonNullable)), + padded_dim, + row_nullability, + ), + ], + ), + row_nullability, + )) +} + +fn validate_tq_metadata(metadata: &TurboQuantMetadata) -> VortexResult<()> { + vortex_ensure!( + metadata.dimensions >= MIN_DIMENSION, + "TurboQuant dimensions must be >= {MIN_DIMENSION}, got {}", + metadata.dimensions + ); + vortex_ensure!( + metadata.element_ptype.is_float(), + "TurboQuant element_ptype must be a float, got {:?}", + metadata.element_ptype + ); + + tq_padded_dim(metadata.dimensions)?; + + TurboQuantConfig::try_new(metadata.bit_width, metadata.seed, metadata.num_rounds).map(|_| ()) +} + +fn validate_tq_storage_dtype(metadata: &TurboQuantMetadata, dtype: &DType) -> VortexResult<()> { + let DType::Struct(fields, _) = dtype else { + vortex_bail!("TurboQuant storage dtype must be a Struct, got {dtype}"); + }; + let expected_names = FieldNames::from([NORMS_FIELD, CODES_FIELD]); + vortex_ensure_eq!( + fields.names(), + &expected_names, + "TurboQuant storage fields must be {expected_names}, got {}", + fields.names() + ); + + let Some(norms_dtype) = fields.field(NORMS_FIELD) else { + vortex_bail!("TurboQuant storage missing {NORMS_FIELD} field"); + }; + let DType::Primitive(norms_ptype, _) = norms_dtype else { + vortex_bail!("TurboQuant {NORMS_FIELD} field must be primitive, got {norms_dtype}"); + }; + vortex_ensure_eq!( + norms_ptype, + metadata.element_ptype, + "TurboQuant {NORMS_FIELD} ptype must be {}, got {norms_ptype}", + metadata.element_ptype + ); + + let Some(codes_dtype) = fields.field(CODES_FIELD) else { + vortex_bail!("TurboQuant storage missing {CODES_FIELD} field"); + }; + let DType::FixedSizeList(element_dtype, list_size, _) = codes_dtype else { + vortex_bail!("TurboQuant {CODES_FIELD} field must be fixed-size-list, got {codes_dtype}"); + }; + let padded_dim = u32::try_from(tq_padded_dim(metadata.dimensions)?) 
+ .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit u32"))?; + vortex_ensure_eq!( + list_size, + padded_dim, + "TurboQuant {CODES_FIELD} list size must be {padded_dim}, got {list_size}" + ); + vortex_ensure_eq!( + element_dtype.as_ref(), + &DType::Primitive(PType::U8, Nullability::NonNullable), + "TurboQuant {CODES_FIELD} elements must be non-nullable u8, got {element_dtype}" + ); + + Ok(()) +} + +impl fmt::Display for TurboQuantMetadata { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "element_ptype: {}, dimensions: {}, bit_width: {}, seed: {}, num_rounds: {}", + self.element_ptype, self.dimensions, self.bit_width, self.seed, self.num_rounds + ) + } +} From 3c3836fedc5f8028eda15cb394186299817121e6 Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 8 May 2026 14:17:11 -0400 Subject: [PATCH 3/5] perf(turboquant): specialize pack and unpack for mask variants Signed-off-by: Connor Tsui --- Cargo.lock | 2 + vortex-turboquant/Cargo.toml | 6 + vortex-turboquant/benches/pack_unpack.rs | 149 ++++++++++++++++++++++ vortex-turboquant/public-api.lock | 6 - vortex-turboquant/src/scalar_fns/pack.rs | 34 +---- vortex-turboquant/src/tests/scalar_fns.rs | 12 -- vortex-turboquant/src/vector/normalize.rs | 120 ++++++++++++++--- vortex-turboquant/src/vector/quantize.rs | 71 ++++++++--- vortex-turboquant/src/vector/storage.rs | 54 ++++---- vortex-turboquant/src/vector/unpack.rs | 72 ++++++++--- 10 files changed, 395 insertions(+), 131 deletions(-) create mode 100644 vortex-turboquant/benches/pack_unpack.rs diff --git a/Cargo.lock b/Cargo.lock index 8f4d775e035..bf8c74d7058 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11255,9 +11255,11 @@ dependencies = [ name = "vortex-turboquant" version = "0.1.0" dependencies = [ + "codspeed-divan-compat", "half", "num-traits", "prost 0.14.3", + "rand 0.10.1", "rstest", "vortex-array", "vortex-buffer", diff --git a/vortex-turboquant/Cargo.toml b/vortex-turboquant/Cargo.toml index f8148d8ff8c..c5ebba4b227 100644 --- a/vortex-turboquant/Cargo.toml +++ b/vortex-turboquant/Cargo.toml @@ -29,7 +29,13 @@ vortex-tensor = { workspace = true } vortex-utils = { workspace = true, features = ["dashmap"] } [dev-dependencies] +divan = { workspace = true } +rand = { workspace = true } rstest = { workspace = true } vortex-file = { workspace = true } vortex-io = { workspace = true } vortex-layout = { workspace = true } + +[[bench]] +name = "pack_unpack" +harness = false diff --git a/vortex-turboquant/benches/pack_unpack.rs b/vortex-turboquant/benches/pack_unpack.rs new file mode 100644 index 00000000000..15992f480c1 --- /dev/null +++ b/vortex-turboquant/benches/pack_unpack.rs @@ -0,0 +1,149 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +//! Benchmarks for `turboquant_pack` and `turboquant_unpack` across different validity-mask +//! shapes. +//! +//! The four mask shapes (`AllTrue`, `AllFalse`, dense `Values`, sparse `Values`) exercise the +//! variant-specialized paths added in the mask refactor in `vector/normalize.rs`, +//! `vector/quantize.rs`, and `vector/unpack.rs`. 
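+//!
+//! Run with `cargo bench --bench pack_unpack`.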
+ +#![expect(clippy::unwrap_used)] + +use std::sync::LazyLock; + +use divan::Bencher; +use rand::RngExt; +use rand::SeedableRng as _; +use rand::rngs::StdRng; +use vortex_array::ArrayRef; +use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::VortexSessionExecute; +use vortex_array::arrays::ExtensionArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::extension::EmptyMetadata; +use vortex_array::session::ArraySession; +use vortex_array::validity::Validity; +use vortex_buffer::Buffer; +use vortex_session::VortexSession; +use vortex_tensor::vector::Vector; +use vortex_turboquant::TQPack; +use vortex_turboquant::TQUnpack; +use vortex_turboquant::TurboQuantConfig; + +fn main() { + divan::main(); +} + +static SESSION: LazyLock = LazyLock::new(|| { + let session = VortexSession::empty().with::(); + vortex_turboquant::initialize(&session); + session +}); + +/// Shape of the validity mask used to drive the variant-specialized paths. +#[derive(Copy, Clone)] +enum MaskShape { + AllValid, + AllInvalid, + DenseValues, + SparseValues, +} + +impl std::fmt::Debug for MaskShape { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(match self { + MaskShape::AllValid => "all_valid", + MaskShape::AllInvalid => "all_invalid", + MaskShape::DenseValues => "dense_95pct", + MaskShape::SparseValues => "sparse_5pct", + }) + } +} + +impl MaskShape { + fn build(self, rows: usize, rng: &mut StdRng) -> Validity { + match self { + MaskShape::AllValid => Validity::NonNullable, + MaskShape::AllInvalid => Validity::AllInvalid, + MaskShape::DenseValues => Validity::from_iter((0..rows).map(|_| rng.random_bool(0.95))), + MaskShape::SparseValues => { + Validity::from_iter((0..rows).map(|_| rng.random_bool(0.05))) + } + } + } +} + +const MASK_SHAPES: &[MaskShape] = &[ + MaskShape::AllValid, + MaskShape::AllInvalid, + MaskShape::DenseValues, + MaskShape::SparseValues, +]; + +const ROWS: usize = 4096; +const DIMENSIONS: u32 = 128; + +fn build_vector_array(shape: MaskShape) -> ArrayRef { + let mut rng = StdRng::seed_from_u64(0xC0FFEE); + let dim = DIMENSIONS as usize; + let values: Buffer = (0..ROWS * dim).map(|_| rng.random::()).collect(); + let elements = PrimitiveArray::new::(values, Validity::NonNullable); + let validity = shape.build(ROWS, &mut rng); + let fsl = + FixedSizeListArray::try_new(elements.into_array(), DIMENSIONS, validity, ROWS).unwrap(); + + ExtensionArray::try_new_from_vtable(Vector, EmptyMetadata, fsl.into_array()) + .unwrap() + .into_array() +} + +fn pack(vec: ArrayRef, config: &TurboQuantConfig, ctx: &mut ExecutionCtx) -> ArrayRef { + let len = vec.len(); + TQPack::try_new_array(vec, config, len) + .unwrap() + .into_array() + .execute(ctx) + .unwrap() +} + +fn unpack(packed: ArrayRef, config: &TurboQuantConfig, ctx: &mut ExecutionCtx) -> ArrayRef { + let len = packed.len(); + TQUnpack::try_new_array(packed, config, len) + .unwrap() + .into_array() + .execute(ctx) + .unwrap() +} + +fn config() -> TurboQuantConfig { + // 4 bits, 4 SORF rounds, fixed seed: representative defaults from the test fixtures. 
+ TurboQuantConfig::try_new(4, 0xDEADBEEF, 4).unwrap() +} + +#[divan::bench(args = MASK_SHAPES)] +fn turboquant_pack(bencher: Bencher, shape: &MaskShape) { + let shape = *shape; + let cfg = config(); + bencher + .with_inputs(|| (build_vector_array(shape), SESSION.create_execution_ctx())) + .input_counter(|_| divan::counter::ItemsCount::new(ROWS)) + .bench_values(|(arr, mut ctx)| pack(arr, &cfg, &mut ctx)) +} + +#[divan::bench(args = MASK_SHAPES)] +fn turboquant_unpack(bencher: Bencher, shape: &MaskShape) { + let shape = *shape; + let cfg = config(); + bencher + .with_inputs(|| { + let arr = build_vector_array(shape); + let mut ctx = SESSION.create_execution_ctx(); + let packed = pack(arr, &cfg, &mut ctx); + (packed, SESSION.create_execution_ctx()) + }) + .input_counter(|_| divan::counter::ItemsCount::new(ROWS)) + .bench_values(|(packed, mut ctx)| unpack(packed, &cfg, &mut ctx)) +} diff --git a/vortex-turboquant/public-api.lock b/vortex-turboquant/public-api.lock index e18dcffe235..4c837b5fea3 100644 --- a/vortex-turboquant/public-api.lock +++ b/vortex-turboquant/public-api.lock @@ -12,12 +12,6 @@ impl core::clone::Clone for vortex_turboquant::TQPack pub fn vortex_turboquant::TQPack::clone(&self) -> vortex_turboquant::TQPack -impl vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable for vortex_turboquant::TQPack - -pub fn vortex_turboquant::TQPack::deserialize(&self, &vortex_array::dtype::DType, usize, &[u8], &dyn vortex_array::serde::ArrayChildren, &vortex_session::VortexSession) -> vortex_error::VortexResult> - -pub fn vortex_turboquant::TQPack::serialize(&self, &vortex_array::arrays::scalar_fn::vtable::ScalarFnArrayView<'_, Self>, &vortex_session::VortexSession) -> vortex_error::VortexResult>> - impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_turboquant::TQPack pub type vortex_turboquant::TQPack::Options = vortex_turboquant::TurboQuantConfig diff --git a/vortex-turboquant/src/scalar_fns/pack.rs b/vortex-turboquant/src/scalar_fns/pack.rs index 7748952ea76..e8be14d24a2 100644 --- a/vortex-turboquant/src/scalar_fns/pack.rs +++ b/vortex-turboquant/src/scalar_fns/pack.rs @@ -9,9 +9,6 @@ use std::fmt::Formatter; use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; use vortex_array::arrays::ScalarFnArray; -use vortex_array::arrays::scalar_fn::ScalarFnArrayView; -use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts; -use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable; use vortex_array::dtype::DType; use vortex_array::dtype::extension::ExtDType; use vortex_array::expr::Expression; @@ -21,9 +18,7 @@ use vortex_array::scalar_fn::ExecutionArgs; use vortex_array::scalar_fn::ScalarFnId; use vortex_array::scalar_fn::ScalarFnVTable; use vortex_array::scalar_fn::TypedScalarFnInstance; -use vortex_array::serde::ArrayChildren; use vortex_error::VortexResult; -use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use vortex_error::vortex_err; use vortex_session::VortexSession; @@ -39,7 +34,13 @@ use crate::vtable::TurboQuant; use crate::vtable::TurboQuantMetadata; use crate::vtable::tq_storage_dtype; -/// Lazy TurboQuant vector pack scalar function. +/// TurboQuant vector pack scalar function. +/// +/// `TQPack` itself is a `ScalarFnVTable` and so its options round-trip through expression +/// serialization. +/// +/// Unlike `TQUnpack`, it deliberately does **not** implement `ScalarFnArrayVTable` since the +/// persisted artifact would be the original vector array, not the TurboQuant-quantized array. 
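+///
+/// A minimal sketch of lazy usage (assuming an in-scope `input: ArrayRef` and
+/// `config: TurboQuantConfig`):
+///
+/// ```ignore
+/// let len = input.len();
+/// let packed = TQPack::try_new_array(input, &config, len)?.into_array();
+/// ```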
#[derive(Clone)] pub struct TQPack; @@ -154,24 +155,3 @@ impl ScalarFnVTable for TQPack { false } } - -impl ScalarFnArrayVTable for TQPack { - fn serialize( - &self, - _view: &ScalarFnArrayView, - _session: &VortexSession, - ) -> VortexResult>> { - Ok(None) - } - - fn deserialize( - &self, - _dtype: &DType, - _len: usize, - _metadata: &[u8], - _children: &dyn ArrayChildren, - _session: &VortexSession, - ) -> VortexResult> { - vortex_bail!("TQPack scalar-fn arrays are not serializable") - } -} diff --git a/vortex-turboquant/src/tests/scalar_fns.rs b/vortex-turboquant/src/tests/scalar_fns.rs index a937643d217..99c8603f272 100644 --- a/vortex-turboquant/src/tests/scalar_fns.rs +++ b/vortex-turboquant/src/tests/scalar_fns.rs @@ -76,18 +76,6 @@ fn scalar_fn_array_metadata_stores_only_config() -> VortexResult<()> { Ok(()) } -#[test] -fn pack_scalar_fn_array_is_not_serializable() -> VortexResult<()> { - let session = test_session(); - let input = f32_vector_array(128, 2, 0.25, Validity::NonNullable)?; - let config = TurboQuantConfig::try_new(3, 42, 3)?; - let packed_lazy = TQPack::try_new_array(input, &config, 2)?.into_array(); - let pack_plugin = ScalarFnArrayPlugin::new(TQPack); - - assert!(pack_plugin.serialize(&packed_lazy, &session)?.is_none()); - Ok(()) -} - #[test] fn unpack_rejects_config_that_disagrees_with_turboquant_child() -> VortexResult<()> { let session = test_session(); diff --git a/vortex-turboquant/src/vector/normalize.rs b/vortex-turboquant/src/vector/normalize.rs index 5a66f7c0e4f..642949eecf6 100644 --- a/vortex-turboquant/src/vector/normalize.rs +++ b/vortex-turboquant/src/vector/normalize.rs @@ -27,6 +27,7 @@ use vortex_error::VortexResult; use vortex_error::vortex_ensure_eq; use vortex_error::vortex_err; use vortex_mask::Mask; +use vortex_mask::MaskValues; use vortex_tensor::scalar_fns::l2_denorm::L2Denorm; use vortex_tensor::scalar_fns::l2_norm::L2Norm; use vortex_tensor::vector::AnyVector; @@ -106,23 +107,39 @@ where .ok_or_else(|| vortex_err!("TurboQuant normalized vector length overflow"))?; let mut output = BufferMut::::with_capacity(output_len); - for i in 0..num_vectors { - let row_values = &values[i * dimensions..][..dimensions]; - let norm = norm_values[i]; - - if !mask.value(i) || norm == T::zero() { - // Invalid rows are placeholders guarded by validity. A valid vector with L2 norm zero - // is all zeros. In both cases, append zero placeholders without per-element work. - - // SAFETY: `output` was allocated with capacity `output_len`, and the loop appends - // exactly `dimensions` values for each of `num_vectors` iterations. - unsafe { output.push_n_unchecked(T::zero(), dimensions) }; - } else { - for &value in row_values { - // SAFETY: same capacity invariant as above. - unsafe { output.push_unchecked(value / norm) }; + // The total number of pushes is always exactly `num_vectors * dimensions == output_len` + // across every arm below, which is the invariant the per-row `unsafe` blocks rely on. + match mask { + Mask::AllFalse(_) => { + // Every row is invalid: bulk-fill the output with zero placeholders. + // + // SAFETY: `output` was allocated with capacity `output_len`, and this push writes + // exactly `output_len` zero placeholders. + unsafe { output.push_n_unchecked(T::zero(), output_len) }; + } + Mask::AllTrue(_) => { + for i in 0..num_vectors { + // SAFETY: `output` was allocated with capacity `output_len = num_vectors * + // dimensions`. 
This loop runs `num_vectors` times and each call pushes exactly + // `dimensions` elements, so capacity for `dimensions` more elements always + // remains. + unsafe { normalize_one_row::(&mut output, values, norm_values, dimensions, i) }; } } + Mask::Values(values_mask) => { + // SAFETY: `output` was allocated with capacity `output_len = num_vectors * + // dimensions`, which is the bound the helper requires. + unsafe { + normalize_vectors_with_mask::( + &mut output, + values, + norm_values, + dimensions, + num_vectors, + values_mask, + ) + }; + } } // Vector elements are always non-nullable. @@ -144,3 +161,76 @@ where .into_array(), ) } + +/// Normalize a single valid row, or push `dimensions` zero placeholders if the row's L2 norm +/// is zero. +/// +/// A valid vector with L2 norm zero is all zeros, so dividing through it would be undefined. +/// Treating it the same as an invalid row preserves the original semantics. +/// +/// # Safety +/// +/// `output` must have capacity for at least `dimensions` more elements before this call. +unsafe fn normalize_one_row( + output: &mut BufferMut, + values: &[T], + norm_values: &[T], + dimensions: usize, + i: usize, +) where + T: Float + NativePType, +{ + let norm = norm_values[i]; + + if norm == T::zero() { + // SAFETY: caller guarantees capacity for `dimensions` more elements. + unsafe { output.push_n_unchecked(T::zero(), dimensions) }; + } else { + let row_values = &values[i * dimensions..][..dimensions]; + + for &value in row_values { + // SAFETY: caller guarantees capacity for `dimensions` more elements. + unsafe { output.push_unchecked(value / norm) }; + } + } +} + +/// Walk the pre-cached run boundaries of a `Values` mask, bulk-pushing zero placeholders for +/// invalid runs and normalizing valid runs row by row. +/// +/// # Safety +/// +/// `output` must have capacity for at least `num_vectors * dimensions` more elements before +/// this call. +unsafe fn normalize_vectors_with_mask( + output: &mut BufferMut, + values: &[T], + norm_values: &[T], + dimensions: usize, + num_vectors: usize, + values_mask: &MaskValues, +) where + T: Float + NativePType, +{ + let mut cursor = 0; + + for &(start, end) in values_mask.slices() { + if start > cursor { + // SAFETY: capacity invariant from caller. + unsafe { output.push_n_unchecked(T::zero(), (start - cursor) * dimensions) }; + } + + for i in start..end { + // SAFETY: capacity invariant from caller — each call pushes `dimensions` and the + // total number of valid rows in the mask is bounded by `num_vectors`. + unsafe { normalize_one_row::(output, values, norm_values, dimensions, i) }; + } + + cursor = end; + } + + if cursor < num_vectors { + // SAFETY: capacity invariant from caller. 
+ unsafe { output.push_n_unchecked(T::zero(), (num_vectors - cursor) * dimensions) }; + } +} diff --git a/vortex-turboquant/src/vector/quantize.rs b/vortex-turboquant/src/vector/quantize.rs index c551d33861f..0861b9f6805 100644 --- a/vortex-turboquant/src/vector/quantize.rs +++ b/vortex-turboquant/src/vector/quantize.rs @@ -23,6 +23,7 @@ use vortex_buffer::BufferMut; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_err; +use vortex_mask::Mask; use super::tq_padded_dim; use crate::TurboQuantConfig; @@ -78,36 +79,66 @@ pub(crate) unsafe fn turboquant_quantize_core( .checked_mul(padded_dim) .ok_or_else(|| vortex_err!("TurboQuant codes length overflow"))?; let mut all_indices = BufferMut::::with_capacity(codes_len); + let mut padded = vec![0.0f32; padded_dim]; let mut transformed = vec![0.0f32; padded_dim]; + // Pad, SORF-transform, and quantize a single row, pushing `padded_dim` codes into + // `all_indices`. Captures the read-only inputs and the scratch buffers so each call site + // only needs to pass `all_indices` and the row index. + // + // NB: `all_indices` cannot be captured here: the `Values` arm interleaves the closure call + // with direct `all_indices.push_n_unchecked` calls. let f32_slice = f32_elements.as_slice(); let dimension = dimension as usize; - for row in 0..num_vectors { - if !mask.value(row) { - // The row-level FSL validity marks these bytes invalid, so keep full-length storage - // without spending SORF/quantization work on a null vector. - - // SAFETY: `all_indices` was allocated with capacity `codes_len`, and the loop appends - // exactly `padded_dim` codes for each of `num_vectors` iterations. - unsafe { all_indices.push_n_unchecked(0, padded_dim) }; - - continue; - } - - let x = &f32_slice[row * dimension..(row + 1) * dimension]; - - // Zero-pad to the next power of 2. - padded[..dimension].copy_from_slice(x); + let mut quantize_row = |all_indices: &mut BufferMut, row: usize| { + // Reuse `padded` and `transformed` from the outer scope. + padded[..dimension].copy_from_slice(&f32_slice[row * dimension..][..dimension]); padded[dimension..].fill(0.0); - sorf_transform.transform(&padded, &mut transformed); - // SAFETY: `all_indices` was allocated with capacity `codes_len`, and the loop appends - // exactly `padded_dim` codes for each of `num_vectors` iterations. - for &value in transformed.iter() { + for &value in &transformed { + // SAFETY: total pushes across all match arms equal `codes_len`. unsafe { all_indices.push_unchecked(find_nearest_centroid(value, &boundaries)) }; } + }; + + // The total number of pushes is always exactly `num_vectors * padded_dim == codes_len` + // across every arm below, which is the invariant the per-row `unsafe` blocks rely on. + match &mask { + Mask::AllFalse(_) => { + // Every row is invalid: bulk-fill placeholder zero codes. + // + // SAFETY: `all_indices` was allocated with capacity `codes_len`, and this push + // writes exactly `codes_len` zero codes. + unsafe { all_indices.push_n_unchecked(0, codes_len) }; + } + Mask::AllTrue(_) => { + for row in 0..num_vectors { + quantize_row(&mut all_indices, row); + } + } + Mask::Values(values_mask) => { + let mut cursor = 0; + + for &(start, end) in values_mask.slices() { + if start > cursor { + // SAFETY: total pushes across all arms equal `codes_len`. 
+ unsafe { all_indices.push_n_unchecked(0, (start - cursor) * padded_dim) }; + } + + for row in start..end { + quantize_row(&mut all_indices, row); + } + + cursor = end; + } + + if cursor < num_vectors { + // SAFETY: total pushes across all arms equal `codes_len`. + unsafe { all_indices.push_n_unchecked(0, (num_vectors - cursor) * padded_dim) }; + } + } } Ok(QuantizationResult { diff --git a/vortex-turboquant/src/vector/storage.rs b/vortex-turboquant/src/vector/storage.rs index 4794d51d323..a5db2e4a167 100644 --- a/vortex-turboquant/src/vector/storage.rs +++ b/vortex-turboquant/src/vector/storage.rs @@ -36,6 +36,7 @@ use vortex_array::validity::Validity; use vortex_error::VortexResult; use vortex_error::vortex_ensure; use vortex_error::vortex_err; +use vortex_mask::Mask; use super::quantize::QuantizationResult; use crate::vtable::TurboQuantMetadata; @@ -120,13 +121,11 @@ pub(crate) fn parse_storage( let struct_validity = storage.struct_validity(); let norms_validity = norms.validity()?; let codes_validity = codes_fsl.validity()?; - validate_child_validity_covers_struct( - &struct_validity, - &norms_validity, - &codes_validity, - len, - ctx, - )?; + + let struct_mask = struct_validity.execute_mask(len, ctx)?; + let norms_mask = norms_validity.execute_mask(len, ctx)?; + let codes_mask = codes_validity.execute_mask(len, ctx)?; + validate_child_validity_covers_struct(&struct_mask, &norms_mask, &codes_mask)?; Ok(TurboQuantParsedStorage { metadata, @@ -137,31 +136,24 @@ pub(crate) fn parse_storage( }) } +/// Validate that both child masks cover the struct mask: every row that the struct considers +/// valid must also be valid in the `norms` and `codes` children. +/// +/// `struct_mask & !child_mask` selects rows where the struct is valid but the child is not. If +/// no such row exists, the child covers the struct. [`Mask::bitand_not`] is variant-specialized, +/// so this short-circuits in `O(1)` when either mask is `AllTrue` or `AllFalse`. 
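+/// For example, with `struct_mask = [1, 0, 1]` and `norms_mask = [1, 1, 0]`,
+/// `struct_mask & !norms_mask = [0, 0, 1]`: row 2 is valid on the struct but not on `norms`,
+/// so validation fails.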
fn validate_child_validity_covers_struct( - struct_validity: &Validity, - norms_validity: &Validity, - codes_validity: &Validity, - len: usize, - ctx: &mut ExecutionCtx, + struct_mask: &Mask, + norms_mask: &Mask, + codes_mask: &Mask, ) -> VortexResult<()> { - let struct_mask = struct_validity.execute_mask(len, ctx)?; - let norms_mask = norms_validity.execute_mask(len, ctx)?; - let codes_mask = codes_validity.execute_mask(len, ctx)?; - - for row in 0..len { - if !struct_mask.value(row) { - continue; - } - - vortex_ensure!( - norms_mask.value(row), - "TurboQuant {NORMS_FIELD} row validity must cover storage validity" - ); - vortex_ensure!( - codes_mask.value(row), - "TurboQuant {CODES_FIELD} row validity must cover storage validity" - ); - } - + vortex_ensure!( + struct_mask.clone().bitand_not(norms_mask).all_false(), + "TurboQuant {NORMS_FIELD} row validity must cover storage validity" + ); + vortex_ensure!( + struct_mask.clone().bitand_not(codes_mask).all_false(), + "TurboQuant {CODES_FIELD} row validity must cover storage validity" + ); Ok(()) } diff --git a/vortex-turboquant/src/vector/unpack.rs b/vortex-turboquant/src/vector/unpack.rs index 0c812d7bc02..59f54a92b47 100644 --- a/vortex-turboquant/src/vector/unpack.rs +++ b/vortex-turboquant/src/vector/unpack.rs @@ -21,6 +21,7 @@ use vortex_buffer::BufferMut; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_err; +use vortex_mask::Mask; use vortex_tensor::vector::Vector; use super::storage::parse_storage; @@ -110,48 +111,79 @@ where let codes = decode.codes.as_slice::(); let mask = vector_validity.execute_mask(num_vectors, ctx)?; - let mut decoded = vec![0.0f32; padded_dim]; - let mut inverse = vec![0.0f32; padded_dim]; let output_len = num_vectors .checked_mul(dimensions) .ok_or_else(|| vortex_err!("TurboQuant decoded vector length overflow"))?; let mut output = BufferMut::::with_capacity(output_len); - for i in 0..num_vectors { - if !mask.value(i) { - // Null rows still need `dimensions` placeholder elements so the FSL stays row-aligned. - // The row's nullability bit is authoritative, so consumers should be treating these - // zeros as undefined. - - // SAFETY: `output` was allocated with capacity `output_len`, and the loop appends - // exactly `dimensions` values for each of `num_vectors` iterations. - unsafe { output.push_n_unchecked(T::zero(), dimensions) }; - - continue; - } + let mut decoded = vec![0.0f32; padded_dim]; + let mut inverse = vec![0.0f32; padded_dim]; - // Perform the gather from codes to values. + // Decode a single row: gather codes through the centroid table, apply the inverse SORF + // transform, then denormalize and push `dimensions` elements into `output`. Captures the + // read-only inputs and the scratch buffers so each call site only needs to pass `output` + // and the row index. + let mut decode_row = |output: &mut BufferMut, i: usize| { let code_row = &codes[i * padded_dim..][..padded_dim]; + + // Gather the values according to the codes. for (dst, &code) in decoded.iter_mut().zip(code_row.iter()) { *dst = *centroids .get(usize::from(code)) .vortex_expect("TurboQuant code exceeds centroid count"); } - // Undo the transform. decode.sorf_matrix.inverse_transform(&decoded, &mut inverse); - // Multiply all elements by the corresponding normal. 
let norm = norms[i]; for &value in inverse.iter().take(dimensions) { - // `T::from_f32` is infallible for the supported float ptypes (`f16`, `f32`, `f64`): - // values outside `f16` range saturate to `±inf` rather than returning `None`. + // `T::from_f32` is infallible for the supported float ptypes (`f16`, `f32`, + // `f64`): values outside `f16` range saturate to `±inf` rather than returning + // `None`. let value = T::from_f32(value) .vortex_expect("from_f32 is infallible for supported float types"); - // SAFETY: same capacity invariant as above. + // SAFETY: total pushes across all match arms equal `output_len`. unsafe { output.push_unchecked(value * norm) }; } + }; + + // The total number of pushes is always exactly `num_vectors * dimensions == output_len` + // across every arm below, which is the invariant the per-row `unsafe` blocks rely on. + match &mask { + Mask::AllFalse(_) => { + // Every row is invalid: bulk-fill the output with zero placeholders. + // + // SAFETY: `output` was allocated with capacity `output_len`, and this push writes + // exactly `output_len` zero placeholders. + unsafe { output.push_n_unchecked(T::zero(), output_len) }; + } + Mask::AllTrue(_) => { + for i in 0..num_vectors { + decode_row(&mut output, i); + } + } + Mask::Values(values_mask) => { + let mut cursor = 0; + + for &(start, end) in values_mask.slices() { + if start > cursor { + // SAFETY: total pushes across all arms equal `output_len`. + unsafe { output.push_n_unchecked(T::zero(), (start - cursor) * dimensions) }; + } + + for i in start..end { + decode_row(&mut output, i); + } + + cursor = end; + } + + if cursor < num_vectors { + // SAFETY: total pushes across all arms equal `output_len`. + unsafe { output.push_n_unchecked(T::zero(), (num_vectors - cursor) * dimensions) }; + } + } } let elements = PrimitiveArray::new::(output.freeze(), Validity::NonNullable); From 53143995e36122d1387d5a69741853f37c2623dd Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 8 May 2026 14:48:29 -0400 Subject: [PATCH 4/5] rename pack/unpack to encode/decode Signed-off-by: Connor Tsui --- vortex-turboquant/Cargo.toml | 2 +- .../{pack_unpack.rs => encode_decode.rs} | 30 +++---- vortex-turboquant/public-api.lock | 82 +++++++++---------- vortex-turboquant/src/config.rs | 2 +- vortex-turboquant/src/lib.rs | 18 ++-- .../src/scalar_fns/{unpack.rs => decode.rs} | 40 ++++----- .../src/scalar_fns/{pack.rs => encode.rs} | 36 ++++---- vortex-turboquant/src/scalar_fns/mod.rs | 10 +-- .../{pack_unpack.rs => encode_decode.rs} | 66 ++++++++------- vortex-turboquant/src/tests/file.rs | 22 ++--- vortex-turboquant/src/tests/malformed.rs | 20 ++--- vortex-turboquant/src/tests/mod.rs | 18 ++-- vortex-turboquant/src/tests/parity.rs | 10 +-- vortex-turboquant/src/tests/scalar_fns.rs | 54 ++++++------ .../src/vector/{unpack.rs => decode.rs} | 14 ++-- .../src/vector/{pack.rs => encode.rs} | 12 +-- vortex-turboquant/src/vector/mod.rs | 4 +- vortex-turboquant/src/vector/storage.rs | 2 +- 18 files changed, 223 insertions(+), 219 deletions(-) rename vortex-turboquant/benches/{pack_unpack.rs => encode_decode.rs} (81%) rename vortex-turboquant/src/scalar_fns/{unpack.rs => decode.rs} (86%) rename vortex-turboquant/src/scalar_fns/{pack.rs => encode.rs} (81%) rename vortex-turboquant/src/tests/{pack_unpack.rs => encode_decode.rs} (78%) rename vortex-turboquant/src/vector/{unpack.rs => decode.rs} (96%) rename vortex-turboquant/src/vector/{pack.rs => encode.rs} (88%) diff --git a/vortex-turboquant/Cargo.toml b/vortex-turboquant/Cargo.toml index 
c5ebba4b227..ab3f63583d3 100644 --- a/vortex-turboquant/Cargo.toml +++ b/vortex-turboquant/Cargo.toml @@ -37,5 +37,5 @@ vortex-io = { workspace = true } vortex-layout = { workspace = true } [[bench]] -name = "pack_unpack" +name = "encode_decode" harness = false diff --git a/vortex-turboquant/benches/pack_unpack.rs b/vortex-turboquant/benches/encode_decode.rs similarity index 81% rename from vortex-turboquant/benches/pack_unpack.rs rename to vortex-turboquant/benches/encode_decode.rs index 15992f480c1..845c4ae85a8 100644 --- a/vortex-turboquant/benches/pack_unpack.rs +++ b/vortex-turboquant/benches/encode_decode.rs @@ -1,12 +1,12 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -//! Benchmarks for `turboquant_pack` and `turboquant_unpack` across different validity-mask +//! Benchmarks for `turboquant_encode` and `turboquant_decode` across different validity-mask //! shapes. //! //! The four mask shapes (`AllTrue`, `AllFalse`, dense `Values`, sparse `Values`) exercise the //! variant-specialized paths added in the mask refactor in `vector/normalize.rs`, -//! `vector/quantize.rs`, and `vector/unpack.rs`. +//! `vector/quantize.rs`, and `vector/decode.rs`. #![expect(clippy::unwrap_used)] @@ -29,8 +29,8 @@ use vortex_array::validity::Validity; use vortex_buffer::Buffer; use vortex_session::VortexSession; use vortex_tensor::vector::Vector; -use vortex_turboquant::TQPack; -use vortex_turboquant::TQUnpack; +use vortex_turboquant::TQDecode; +use vortex_turboquant::TQEncode; use vortex_turboquant::TurboQuantConfig; fn main() { @@ -100,18 +100,18 @@ fn build_vector_array(shape: MaskShape) -> ArrayRef { .into_array() } -fn pack(vec: ArrayRef, config: &TurboQuantConfig, ctx: &mut ExecutionCtx) -> ArrayRef { +fn encode(vec: ArrayRef, config: &TurboQuantConfig, ctx: &mut ExecutionCtx) -> ArrayRef { let len = vec.len(); - TQPack::try_new_array(vec, config, len) + TQEncode::try_new_array(vec, config, len) .unwrap() .into_array() .execute(ctx) .unwrap() } -fn unpack(packed: ArrayRef, config: &TurboQuantConfig, ctx: &mut ExecutionCtx) -> ArrayRef { - let len = packed.len(); - TQUnpack::try_new_array(packed, config, len) +fn decode(encoded: ArrayRef, config: &TurboQuantConfig, ctx: &mut ExecutionCtx) -> ArrayRef { + let len = encoded.len(); + TQDecode::try_new_array(encoded, config, len) .unwrap() .into_array() .execute(ctx) @@ -124,26 +124,26 @@ fn config() -> TurboQuantConfig { } #[divan::bench(args = MASK_SHAPES)] -fn turboquant_pack(bencher: Bencher, shape: &MaskShape) { +fn turboquant_encode(bencher: Bencher, shape: &MaskShape) { let shape = *shape; let cfg = config(); bencher .with_inputs(|| (build_vector_array(shape), SESSION.create_execution_ctx())) .input_counter(|_| divan::counter::ItemsCount::new(ROWS)) - .bench_values(|(arr, mut ctx)| pack(arr, &cfg, &mut ctx)) + .bench_values(|(arr, mut ctx)| encode(arr, &cfg, &mut ctx)) } #[divan::bench(args = MASK_SHAPES)] -fn turboquant_unpack(bencher: Bencher, shape: &MaskShape) { +fn turboquant_decode(bencher: Bencher, shape: &MaskShape) { let shape = *shape; let cfg = config(); bencher .with_inputs(|| { let arr = build_vector_array(shape); let mut ctx = SESSION.create_execution_ctx(); - let packed = pack(arr, &cfg, &mut ctx); - (packed, SESSION.create_execution_ctx()) + let encoded = encode(arr, &cfg, &mut ctx); + (encoded, SESSION.create_execution_ctx()) }) .input_counter(|_| divan::counter::ItemsCount::new(ROWS)) - .bench_values(|(packed, mut ctx)| unpack(packed, &cfg, &mut ctx)) + 
.bench_values(|(encoded, mut ctx)| decode(encoded, &cfg, &mut ctx)) } diff --git a/vortex-turboquant/public-api.lock b/vortex-turboquant/public-api.lock index 4c837b5fea3..95e162f40f6 100644 --- a/vortex-turboquant/public-api.lock +++ b/vortex-turboquant/public-api.lock @@ -1,86 +1,86 @@ pub mod vortex_turboquant -pub struct vortex_turboquant::TQPack +pub struct vortex_turboquant::TQDecode -impl vortex_turboquant::TQPack +impl vortex_turboquant::TQDecode -pub fn vortex_turboquant::TQPack::new(&vortex_turboquant::TurboQuantConfig) -> vortex_array::scalar_fn::typed::TypedScalarFnInstance +pub fn vortex_turboquant::TQDecode::new(&vortex_turboquant::TurboQuantConfig) -> vortex_array::scalar_fn::typed::TypedScalarFnInstance -pub fn vortex_turboquant::TQPack::try_new_array(vortex_array::array::erased::ArrayRef, &vortex_turboquant::TurboQuantConfig, usize) -> vortex_error::VortexResult +pub fn vortex_turboquant::TQDecode::try_new_array(vortex_array::array::erased::ArrayRef, &vortex_turboquant::TurboQuantConfig, usize) -> vortex_error::VortexResult -impl core::clone::Clone for vortex_turboquant::TQPack +impl core::clone::Clone for vortex_turboquant::TQDecode -pub fn vortex_turboquant::TQPack::clone(&self) -> vortex_turboquant::TQPack +pub fn vortex_turboquant::TQDecode::clone(&self) -> vortex_turboquant::TQDecode -impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_turboquant::TQPack +impl vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable for vortex_turboquant::TQDecode -pub type vortex_turboquant::TQPack::Options = vortex_turboquant::TurboQuantConfig +pub fn vortex_turboquant::TQDecode::deserialize(&self, &vortex_array::dtype::DType, usize, &[u8], &dyn vortex_array::serde::ArrayChildren, &vortex_session::VortexSession) -> vortex_error::VortexResult> -pub fn vortex_turboquant::TQPack::arity(&self, &Self::Options) -> vortex_array::scalar_fn::vtable::Arity +pub fn vortex_turboquant::TQDecode::serialize(&self, &vortex_array::arrays::scalar_fn::vtable::ScalarFnArrayView<'_, Self>, &vortex_session::VortexSession) -> vortex_error::VortexResult>> -pub fn vortex_turboquant::TQPack::child_name(&self, &Self::Options, usize) -> vortex_array::scalar_fn::vtable::ChildName +impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_turboquant::TQDecode -pub fn vortex_turboquant::TQPack::deserialize(&self, &[u8], &vortex_session::VortexSession) -> vortex_error::VortexResult +pub type vortex_turboquant::TQDecode::Options = vortex_turboquant::TurboQuantConfig -pub fn vortex_turboquant::TQPack::execute(&self, &Self::Options, &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult +pub fn vortex_turboquant::TQDecode::arity(&self, &Self::Options) -> vortex_array::scalar_fn::vtable::Arity -pub fn vortex_turboquant::TQPack::fmt_sql(&self, &Self::Options, &vortex_array::expr::expression::Expression, &mut core::fmt::Formatter<'_>) -> core::fmt::Result +pub fn vortex_turboquant::TQDecode::child_name(&self, &Self::Options, usize) -> vortex_array::scalar_fn::vtable::ChildName -pub fn vortex_turboquant::TQPack::id(&self) -> vortex_array::scalar_fn::ScalarFnId +pub fn vortex_turboquant::TQDecode::deserialize(&self, &[u8], &vortex_session::VortexSession) -> vortex_error::VortexResult -pub fn vortex_turboquant::TQPack::is_fallible(&self, &Self::Options) -> bool +pub fn vortex_turboquant::TQDecode::execute(&self, &Self::Options, &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, &mut vortex_array::executor::ExecutionCtx) -> 
vortex_error::VortexResult -pub fn vortex_turboquant::TQPack::is_null_sensitive(&self, &Self::Options) -> bool +pub fn vortex_turboquant::TQDecode::fmt_sql(&self, &Self::Options, &vortex_array::expr::expression::Expression, &mut core::fmt::Formatter<'_>) -> core::fmt::Result -pub fn vortex_turboquant::TQPack::return_dtype(&self, &Self::Options, &[vortex_array::dtype::DType]) -> vortex_error::VortexResult +pub fn vortex_turboquant::TQDecode::id(&self) -> vortex_array::scalar_fn::ScalarFnId -pub fn vortex_turboquant::TQPack::serialize(&self, &Self::Options) -> vortex_error::VortexResult>> +pub fn vortex_turboquant::TQDecode::is_fallible(&self, &Self::Options) -> bool -pub fn vortex_turboquant::TQPack::validity(&self, &Self::Options, &vortex_array::expr::expression::Expression) -> vortex_error::VortexResult> +pub fn vortex_turboquant::TQDecode::is_null_sensitive(&self, &Self::Options) -> bool -pub struct vortex_turboquant::TQUnpack +pub fn vortex_turboquant::TQDecode::return_dtype(&self, &Self::Options, &[vortex_array::dtype::DType]) -> vortex_error::VortexResult -impl vortex_turboquant::TQUnpack +pub fn vortex_turboquant::TQDecode::serialize(&self, &Self::Options) -> vortex_error::VortexResult>> -pub fn vortex_turboquant::TQUnpack::new(&vortex_turboquant::TurboQuantConfig) -> vortex_array::scalar_fn::typed::TypedScalarFnInstance +pub fn vortex_turboquant::TQDecode::validity(&self, &Self::Options, &vortex_array::expr::expression::Expression) -> vortex_error::VortexResult> -pub fn vortex_turboquant::TQUnpack::try_new_array(vortex_array::array::erased::ArrayRef, &vortex_turboquant::TurboQuantConfig, usize) -> vortex_error::VortexResult +pub struct vortex_turboquant::TQEncode -impl core::clone::Clone for vortex_turboquant::TQUnpack +impl vortex_turboquant::TQEncode -pub fn vortex_turboquant::TQUnpack::clone(&self) -> vortex_turboquant::TQUnpack +pub fn vortex_turboquant::TQEncode::new(&vortex_turboquant::TurboQuantConfig) -> vortex_array::scalar_fn::typed::TypedScalarFnInstance -impl vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable for vortex_turboquant::TQUnpack +pub fn vortex_turboquant::TQEncode::try_new_array(vortex_array::array::erased::ArrayRef, &vortex_turboquant::TurboQuantConfig, usize) -> vortex_error::VortexResult -pub fn vortex_turboquant::TQUnpack::deserialize(&self, &vortex_array::dtype::DType, usize, &[u8], &dyn vortex_array::serde::ArrayChildren, &vortex_session::VortexSession) -> vortex_error::VortexResult> +impl core::clone::Clone for vortex_turboquant::TQEncode -pub fn vortex_turboquant::TQUnpack::serialize(&self, &vortex_array::arrays::scalar_fn::vtable::ScalarFnArrayView<'_, Self>, &vortex_session::VortexSession) -> vortex_error::VortexResult>> +pub fn vortex_turboquant::TQEncode::clone(&self) -> vortex_turboquant::TQEncode -impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_turboquant::TQUnpack +impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_turboquant::TQEncode -pub type vortex_turboquant::TQUnpack::Options = vortex_turboquant::TurboQuantConfig +pub type vortex_turboquant::TQEncode::Options = vortex_turboquant::TurboQuantConfig -pub fn vortex_turboquant::TQUnpack::arity(&self, &Self::Options) -> vortex_array::scalar_fn::vtable::Arity +pub fn vortex_turboquant::TQEncode::arity(&self, &Self::Options) -> vortex_array::scalar_fn::vtable::Arity -pub fn vortex_turboquant::TQUnpack::child_name(&self, &Self::Options, usize) -> vortex_array::scalar_fn::vtable::ChildName +pub fn vortex_turboquant::TQEncode::child_name(&self, &Self::Options, 
usize) -> vortex_array::scalar_fn::vtable::ChildName -pub fn vortex_turboquant::TQUnpack::deserialize(&self, &[u8], &vortex_session::VortexSession) -> vortex_error::VortexResult +pub fn vortex_turboquant::TQEncode::deserialize(&self, &[u8], &vortex_session::VortexSession) -> vortex_error::VortexResult -pub fn vortex_turboquant::TQUnpack::execute(&self, &Self::Options, &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult +pub fn vortex_turboquant::TQEncode::execute(&self, &Self::Options, &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult -pub fn vortex_turboquant::TQUnpack::fmt_sql(&self, &Self::Options, &vortex_array::expr::expression::Expression, &mut core::fmt::Formatter<'_>) -> core::fmt::Result +pub fn vortex_turboquant::TQEncode::fmt_sql(&self, &Self::Options, &vortex_array::expr::expression::Expression, &mut core::fmt::Formatter<'_>) -> core::fmt::Result -pub fn vortex_turboquant::TQUnpack::id(&self) -> vortex_array::scalar_fn::ScalarFnId +pub fn vortex_turboquant::TQEncode::id(&self) -> vortex_array::scalar_fn::ScalarFnId -pub fn vortex_turboquant::TQUnpack::is_fallible(&self, &Self::Options) -> bool +pub fn vortex_turboquant::TQEncode::is_fallible(&self, &Self::Options) -> bool -pub fn vortex_turboquant::TQUnpack::is_null_sensitive(&self, &Self::Options) -> bool +pub fn vortex_turboquant::TQEncode::is_null_sensitive(&self, &Self::Options) -> bool -pub fn vortex_turboquant::TQUnpack::return_dtype(&self, &Self::Options, &[vortex_array::dtype::DType]) -> vortex_error::VortexResult +pub fn vortex_turboquant::TQEncode::return_dtype(&self, &Self::Options, &[vortex_array::dtype::DType]) -> vortex_error::VortexResult -pub fn vortex_turboquant::TQUnpack::serialize(&self, &Self::Options) -> vortex_error::VortexResult>> +pub fn vortex_turboquant::TQEncode::serialize(&self, &Self::Options) -> vortex_error::VortexResult>> -pub fn vortex_turboquant::TQUnpack::validity(&self, &Self::Options, &vortex_array::expr::expression::Expression) -> vortex_error::VortexResult> +pub fn vortex_turboquant::TQEncode::validity(&self, &Self::Options, &vortex_array::expr::expression::Expression) -> vortex_error::VortexResult> pub struct vortex_turboquant::TurboQuant diff --git a/vortex-turboquant/src/config.rs b/vortex-turboquant/src/config.rs index 8b5ed03e93c..57cd8b1e94b 100644 --- a/vortex-turboquant/src/config.rs +++ b/vortex-turboquant/src/config.rs @@ -15,7 +15,7 @@ pub(crate) const MIN_DIMENSION: u32 = 128; /// Maximum supported number of bits per quantized coordinate. pub(crate) const MAX_BIT_WIDTH: u8 = 8; -/// Configuration for lossy TurboQuant packing. +/// Configuration for lossy TurboQuant encoding. #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct TurboQuantConfig { bit_width: u8, diff --git a/vortex-turboquant/src/lib.rs b/vortex-turboquant/src/lib.rs index 3db71ceab6a..60e36edff3a 100644 --- a/vortex-turboquant/src/lib.rs +++ b/vortex-turboquant/src/lib.rs @@ -5,7 +5,7 @@ //! //! Implements a Stage 1 TurboQuant encoding ([arXiv:2504.19874], [RFC 0033]) for lossy compression //! of high-dimensional vector data. The extension operates on -//! [`Vector`](vortex_tensor::vector::Vector) extension arrays, packing their `FixedSizeList` +//! [`Vector`](vortex_tensor::vector::Vector) extension arrays, encoding their `FixedSizeList` //! storage into quantized codes after a structured orthogonal surrogate transform. //! //! 
[arXiv:2504.19874]: https://arxiv.org/abs/2504.19874 @@ -16,13 +16,13 @@ //! TurboQuant minimizes mean-squared reconstruction error (1-8 bits per coordinate) //! using MSE-optimal scalar quantization on coordinates of a transformed unit vector. //! -//! The [`TQPack`] scalar function first computes and stores the original L2 norm for each vector +//! The [`TQEncode`] scalar function first computes and stores the original L2 norm for each vector //! row, then normalizes each valid nonzero row internally before SORF transform and scalar -//! quantization. The [`TQUnpack`] scalar function dequantizes through deterministic centroids, +//! quantization. The [`TQDecode`] scalar function dequantizes through deterministic centroids, //! applies the inverse SORF transform, truncates back to the original dimension, and re-applies the //! stored norm. //! -//! The packed storage is a row-aligned extension tree: +//! The encoded storage is a row-aligned extension tree: //! //! ```text //! Extension( @@ -61,8 +61,8 @@ mod vector; mod vtable; pub use config::TurboQuantConfig; -pub use scalar_fns::TQPack; -pub use scalar_fns::TQUnpack; +pub use scalar_fns::TQDecode; +pub use scalar_fns::TQEncode; pub use vtable::TurboQuant; pub use vtable::TurboQuantMetadata; @@ -76,11 +76,11 @@ pub fn initialize(session: &vortex_session::VortexSession) { session.dtypes().register(TurboQuant); - session.scalar_fns().register(TQPack); - session.scalar_fns().register(TQUnpack); + session.scalar_fns().register(TQEncode); + session.scalar_fns().register(TQDecode); let session_arrays = session.arrays(); - session_arrays.register(ScalarFnArrayPlugin::new(TQUnpack)); + session_arrays.register(ScalarFnArrayPlugin::new(TQDecode)); } #[cfg(test)] diff --git a/vortex-turboquant/src/scalar_fns/unpack.rs b/vortex-turboquant/src/scalar_fns/decode.rs similarity index 86% rename from vortex-turboquant/src/scalar_fns/unpack.rs rename to vortex-turboquant/src/scalar_fns/decode.rs index 4d1d414a220..8a937b23a64 100644 --- a/vortex-turboquant/src/scalar_fns/unpack.rs +++ b/vortex-turboquant/src/scalar_fns/decode.rs @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -//! TurboQuant unpack scalar function. +//! TurboQuant decode scalar function. use std::fmt; use std::fmt::Formatter; @@ -35,37 +35,37 @@ use vortex_tensor::vector::Vector; use super::metadata::deserialize_config; use super::metadata::serialize_config; use crate::TurboQuantConfig; -use crate::vector::unpack::unpack_vector; +use crate::vector::decode::decode_vector; use crate::vtable::TurboQuant; use crate::vtable::TurboQuantMetadata; use crate::vtable::tq_metadata; use crate::vtable::tq_storage_dtype; -/// Lazy TurboQuant vector unpack scalar function. +/// Lazy TurboQuant vector decode scalar function. #[derive(Clone)] -pub struct TQUnpack; +pub struct TQDecode; -impl TQUnpack { - /// Creates a new [`TypedScalarFnInstance`] wrapping TurboQuant unpacking. - pub fn new(config: &TurboQuantConfig) -> TypedScalarFnInstance { - TypedScalarFnInstance::new(TQUnpack, config.clone()) +impl TQDecode { + /// Creates a new [`TypedScalarFnInstance`] wrapping TurboQuant decoding. + pub fn new(config: &TurboQuantConfig) -> TypedScalarFnInstance { + TypedScalarFnInstance::new(TQDecode, config.clone()) } - /// Constructs a [`ScalarFnArray`] that lazily unpacks a `TurboQuant` child into a `Vector`. + /// Constructs a [`ScalarFnArray`] that lazily decodes a `TurboQuant` child into a `Vector`. 
pub fn try_new_array( child: ArrayRef, config: &TurboQuantConfig, len: usize, ) -> VortexResult { - ScalarFnArray::try_new(TQUnpack::new(config).erased(), vec![child], len) + ScalarFnArray::try_new(TQDecode::new(config).erased(), vec![child], len) } } -impl ScalarFnVTable for TQUnpack { +impl ScalarFnVTable for TQDecode { type Options = TurboQuantConfig; fn id(&self) -> ScalarFnId { - ScalarFnId::new("vortex.turboquant.unpack") + ScalarFnId::new("vortex.turboquant.decode") } fn serialize(&self, options: &Self::Options) -> VortexResult>> { @@ -87,7 +87,7 @@ impl ScalarFnVTable for TQUnpack { fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName { match child_idx { 0 => ChildName::from("turboquant"), - _ => unreachable!("TQUnpack must have exactly one child"), + _ => unreachable!("TQDecode must have exactly one child"), } } @@ -97,7 +97,7 @@ impl ScalarFnVTable for TQUnpack { expr: &Expression, f: &mut Formatter<'_>, ) -> fmt::Result { - write!(f, "tq_unpack(")?; + write!(f, "tq_decode(")?; expr.child(0).fmt_sql(f)?; write!(f, ", {options})") } @@ -126,7 +126,7 @@ impl ScalarFnVTable for TQUnpack { args: &dyn ExecutionArgs, ctx: &mut ExecutionCtx, ) -> VortexResult { - unpack_vector(args.get(0)?, ctx) + decode_vector(args.get(0)?, ctx) } fn validity( @@ -146,7 +146,7 @@ impl ScalarFnVTable for TQUnpack { } } -impl ScalarFnArrayVTable for TQUnpack { +impl ScalarFnArrayVTable for TQDecode { fn serialize( &self, view: &ScalarFnArrayView, @@ -168,7 +168,7 @@ impl ScalarFnArrayVTable for TQUnpack { .as_extension_opt() .and_then(|ext_dtype| ext_dtype.metadata_opt::()) .ok_or_else(|| { - vortex_err!("TQUnpack parent dtype must be a Vector extension array, got {dtype}") + vortex_err!("TQDecode parent dtype must be a Vector extension array, got {dtype}") })?; let metadata = TurboQuantMetadata { @@ -197,17 +197,17 @@ fn validate_config_matches_metadata( vortex_ensure_eq!( config.bit_width(), metadata.bit_width, - "TQUnpack config bit_width must match TurboQuant child metadata" + "TQDecode config bit_width must match TurboQuant child metadata" ); vortex_ensure_eq!( config.seed(), metadata.seed, - "TQUnpack config seed must match TurboQuant child metadata" + "TQDecode config seed must match TurboQuant child metadata" ); vortex_ensure_eq!( config.num_rounds(), metadata.num_rounds, - "TQUnpack config num_rounds must match TurboQuant child metadata" + "TQDecode config num_rounds must match TurboQuant child metadata" ); Ok(()) } diff --git a/vortex-turboquant/src/scalar_fns/pack.rs b/vortex-turboquant/src/scalar_fns/encode.rs similarity index 81% rename from vortex-turboquant/src/scalar_fns/pack.rs rename to vortex-turboquant/src/scalar_fns/encode.rs index e8be14d24a2..d683cd7e502 100644 --- a/vortex-turboquant/src/scalar_fns/pack.rs +++ b/vortex-turboquant/src/scalar_fns/encode.rs @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -//! TurboQuant pack scalar function. +//! TurboQuant encode scalar function. use std::fmt; use std::fmt::Formatter; @@ -28,43 +28,43 @@ use super::metadata::deserialize_config; use super::metadata::serialize_config; use crate::TurboQuantConfig; use crate::config::MIN_DIMENSION; -use crate::vector::pack::pack_vector; +use crate::vector::encode::encode_vector; use crate::vector::tq_padded_dim; use crate::vtable::TurboQuant; use crate::vtable::TurboQuantMetadata; use crate::vtable::tq_storage_dtype; -/// TurboQuant vector pack scalar function. +/// TurboQuant vector encode scalar function. 
/// -/// `TQPack` itself is a `ScalarFnVTable` and so its options round-trip through expression +/// `TQEncode` itself is a `ScalarFnVTable` and so its options round-trip through expression /// serialization. /// -/// Unlike `TQUnpack`, it deliberately does **not** implement `ScalarFnArrayVTable` since the +/// Unlike `TQDecode`, it deliberately does **not** implement `ScalarFnArrayVTable` since the /// persisted artifact would be the original vector array, not the TurboQuant-quantized array. #[derive(Clone)] -pub struct TQPack; +pub struct TQEncode; -impl TQPack { - /// Creates a new [`TypedScalarFnInstance`] wrapping TurboQuant packing. - pub fn new(config: &TurboQuantConfig) -> TypedScalarFnInstance { - TypedScalarFnInstance::new(TQPack, config.clone()) +impl TQEncode { + /// Creates a new [`TypedScalarFnInstance`] wrapping TurboQuant encoding. + pub fn new(config: &TurboQuantConfig) -> TypedScalarFnInstance { + TypedScalarFnInstance::new(TQEncode, config.clone()) } - /// Constructs a [`ScalarFnArray`] that lazily packs a `Vector` child into `TurboQuant`. + /// Constructs a [`ScalarFnArray`] that lazily encodes a `Vector` child into `TurboQuant`. pub fn try_new_array( child: ArrayRef, config: &TurboQuantConfig, len: usize, ) -> VortexResult { - ScalarFnArray::try_new(TQPack::new(config).erased(), vec![child], len) + ScalarFnArray::try_new(TQEncode::new(config).erased(), vec![child], len) } } -impl ScalarFnVTable for TQPack { +impl ScalarFnVTable for TQEncode { type Options = TurboQuantConfig; fn id(&self) -> ScalarFnId { - ScalarFnId::new("vortex.turboquant.pack") + ScalarFnId::new("vortex.turboquant.encode") } fn serialize(&self, options: &Self::Options) -> VortexResult>> { @@ -86,7 +86,7 @@ impl ScalarFnVTable for TQPack { fn child_name(&self, _options: &Self::Options, child_idx: usize) -> ChildName { match child_idx { 0 => ChildName::from("vector"), - _ => unreachable!("TQPack must have exactly one child"), + _ => unreachable!("TQEncode must have exactly one child"), } } @@ -96,7 +96,7 @@ impl ScalarFnVTable for TQPack { expr: &Expression, f: &mut Formatter<'_>, ) -> fmt::Result { - write!(f, "tq_pack(")?; + write!(f, "tq_encode(")?; expr.child(0).fmt_sql(f)?; write!(f, ", {options})") } @@ -107,7 +107,7 @@ impl ScalarFnVTable for TQPack { .as_extension_opt() .and_then(|ext_dtype| ext_dtype.metadata_opt::()) .ok_or_else(|| { - vortex_err!("TQPack expects a Vector extension array, got {input_dtype}") + vortex_err!("TQEncode expects a Vector extension array, got {input_dtype}") })?; let dimensions = vector_metadata.dimensions(); @@ -136,7 +136,7 @@ impl ScalarFnVTable for TQPack { args: &dyn ExecutionArgs, ctx: &mut ExecutionCtx, ) -> VortexResult { - pack_vector(args.get(0)?, options, ctx) + encode_vector(args.get(0)?, options, ctx) } fn validity( diff --git a/vortex-turboquant/src/scalar_fns/mod.rs b/vortex-turboquant/src/scalar_fns/mod.rs index 9151ab9c1c7..1acea9f70f3 100644 --- a/vortex-turboquant/src/scalar_fns/mod.rs +++ b/vortex-turboquant/src/scalar_fns/mod.rs @@ -1,11 +1,11 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -//! Scalar functions for lazy TurboQuant vector pack and unpack operations. +//! Scalar functions for lazy TurboQuant vector encode and decode operations. 
+mod decode; +mod encode; mod metadata; -mod pack; -mod unpack; -pub use pack::TQPack; -pub use unpack::TQUnpack; +pub use decode::TQDecode; +pub use encode::TQEncode; diff --git a/vortex-turboquant/src/tests/pack_unpack.rs b/vortex-turboquant/src/tests/encode_decode.rs similarity index 78% rename from vortex-turboquant/src/tests/pack_unpack.rs rename to vortex-turboquant/src/tests/encode_decode.rs index 9c5a56f007a..f9b9f3d3569 100644 --- a/vortex-turboquant/src/tests/pack_unpack.rs +++ b/vortex-turboquant/src/tests/encode_decode.rs @@ -17,8 +17,8 @@ use vortex_array::validity::Validity; use vortex_buffer::Buffer; use vortex_error::VortexResult; -use super::execute_tq_pack; -use super::execute_tq_unpack; +use super::execute_tq_decode; +use super::execute_tq_encode; use super::f32_vector_array; use super::test_session; use super::turboquant_storage; @@ -39,32 +39,32 @@ fn config_rejects_invalid_values(#[case] bit_width: u8, #[case] seed: u64, #[cas } #[test] -fn pack_rejects_non_vector_input() { +fn encode_rejects_non_vector_input() { let session = test_session(); let mut ctx = session.create_execution_ctx(); let input = PrimitiveArray::new::(Buffer::copy_from([1.0, 2.0]), Validity::NonNullable) .into_array(); - assert!(execute_tq_pack(input, &TurboQuantConfig::default(), &mut ctx).is_err()); + assert!(execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx).is_err()); } #[test] -fn pack_rejects_small_dimensions() -> VortexResult<()> { +fn encode_rejects_small_dimensions() -> VortexResult<()> { let session = test_session(); let mut ctx = session.create_execution_ctx(); let input = f32_vector_array(127, 1, 1.0, Validity::NonNullable)?; - assert!(execute_tq_pack(input, &TurboQuantConfig::default(), &mut ctx).is_err()); + assert!(execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx).is_err()); Ok(()) } #[test] -fn pack_rejects_padded_dimension_overflow() -> VortexResult<()> { +fn encode_rejects_padded_dimension_overflow() -> VortexResult<()> { let session = test_session(); let mut ctx = session.create_execution_ctx(); let input = vector_array::(2_147_483_649, &[], Validity::NonNullable)?; - assert!(execute_tq_pack(input, &TurboQuantConfig::default(), &mut ctx).is_err()); + assert!(execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx).is_err()); Ok(()) } @@ -78,28 +78,28 @@ fn centroid_cache_is_deterministic() -> VortexResult<()> { } #[test] -fn pack_unpack_empty_vectors() -> VortexResult<()> { +fn encode_decode_empty_vectors() -> VortexResult<()> { let session = test_session(); let mut ctx = session.create_execution_ctx(); let input = vector_array::(128, &[], Validity::NonNullable)?; - let packed = execute_tq_pack(input, &TurboQuantConfig::default(), &mut ctx)?; - let unpacked = execute_tq_unpack(packed, &TurboQuantConfig::default(), &mut ctx)?; + let encoded = execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx)?; + let decoded = execute_tq_decode(encoded, &TurboQuantConfig::default(), &mut ctx)?; - assert!(unpacked.is_empty()); + assert!(decoded.is_empty()); Ok(()) } #[test] -fn pack_stores_norms_and_struct_validity() -> VortexResult<()> { +fn encode_stores_norms_and_struct_validity() -> VortexResult<()> { let session = test_session(); let mut ctx = session.create_execution_ctx(); let validity = Validity::from_iter([true, false, true]); let input = f32_vector_array(128, 3, 0.25, validity)?; let config = TurboQuantConfig::try_new(3, 1, 2)?; - let packed = execute_tq_pack(input, &config, &mut ctx)?; - let storage = turboquant_storage(packed, &mut 
ctx)?; + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + let storage = turboquant_storage(encoded, &mut ctx)?; let mask = storage.struct_validity().execute_mask(3, &mut ctx)?; let norms: PrimitiveArray = storage .unmasked_field_by_name("norms")? @@ -174,7 +174,7 @@ fn normalize_as_l2_denorm_preserves_child_validity() -> VortexResult<()> { } #[test] -fn pack_unpack_preserves_nulls_and_zero_norm_rows() -> VortexResult<()> { +fn encode_decode_preserves_nulls_and_zero_norm_rows() -> VortexResult<()> { let session = test_session(); let mut ctx = session.create_execution_ctx(); let mut values = vec![0.0f32; 3 * 128]; @@ -184,10 +184,10 @@ fn pack_unpack_preserves_nulls_and_zero_norm_rows() -> VortexResult<()> { values[257] = -1.0; let input = vector_array(128, &values, Validity::from_iter([true, true, false]))?; - let packed = execute_tq_pack(input, &TurboQuantConfig::default(), &mut ctx)?; - let unpacked = execute_tq_unpack(packed, &TurboQuantConfig::default(), &mut ctx)?; - let output = vector_values_f32(unpacked.clone(), &mut ctx)?; - let validity = vector_validity(unpacked, &mut ctx)?.execute_mask(3, &mut ctx)?; + let encoded = execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx)?; + let decoded = execute_tq_decode(encoded, &TurboQuantConfig::default(), &mut ctx)?; + let output = vector_values_f32(decoded.clone(), &mut ctx)?; + let validity = vector_validity(decoded, &mut ctx)?.execute_mask(3, &mut ctx)?; assert!(validity.value(0)); assert!(validity.value(1)); @@ -199,7 +199,7 @@ fn pack_unpack_preserves_nulls_and_zero_norm_rows() -> VortexResult<()> { #[rstest] #[case::f16(PType::F16)] #[case::f64(PType::F64)] -fn pack_unpack_supports_non_f32_inputs(#[case] ptype: PType) -> VortexResult<()> { +fn encode_decode_supports_non_f32_inputs(#[case] ptype: PType) -> VortexResult<()> { let session = test_session(); let mut ctx = session.create_execution_ctx(); let config = TurboQuantConfig::try_new(3, 42, 3)?; @@ -210,9 +210,9 @@ fn pack_unpack_supports_non_f32_inputs(#[case] ptype: PType) -> VortexResult<()> .map(|i| half::f16::from_f32(((i % 17) as f32 - 8.0) * 0.25)) .collect::>(); let input = vector_array(128, &values, Validity::NonNullable)?; - let packed = execute_tq_pack(input, &config, &mut ctx)?; - let unpacked = execute_tq_unpack(packed, &config, &mut ctx)?; - let ext: ExtensionArray = unpacked.execute(&mut ctx)?; + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + let decoded = execute_tq_decode(encoded, &config, &mut ctx)?; + let ext: ExtensionArray = decoded.execute(&mut ctx)?; assert_eq!(vector_element_ptype(&ext)?, PType::F16); } PType::F64 => { @@ -220,9 +220,9 @@ fn pack_unpack_supports_non_f32_inputs(#[case] ptype: PType) -> VortexResult<()> .map(|i| ((i % 17) as f64 - 8.0) * 0.25) .collect::>(); let input = vector_array(128, &values, Validity::NonNullable)?; - let packed = execute_tq_pack(input, &config, &mut ctx)?; - let unpacked = execute_tq_unpack(packed, &config, &mut ctx)?; - let ext: ExtensionArray = unpacked.execute(&mut ctx)?; + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + let decoded = execute_tq_decode(encoded, &config, &mut ctx)?; + let ext: ExtensionArray = decoded.execute(&mut ctx)?; assert_eq!(vector_element_ptype(&ext)?, PType::F64); } _ => unreachable!("test only passes f16/f64"), @@ -231,7 +231,7 @@ fn pack_unpack_supports_non_f32_inputs(#[case] ptype: PType) -> VortexResult<()> } #[test] -fn unpack_scales_by_stored_norms() -> VortexResult<()> { +fn decode_scales_by_stored_norms() -> VortexResult<()> { let 
session = test_session(); let mut ctx = session.create_execution_ctx(); let base = f32_vector_array(128, 1, 0.5, Validity::NonNullable)?; @@ -239,12 +239,16 @@ fn unpack_scales_by_stored_norms() -> VortexResult<()> { let config = TurboQuantConfig::try_new(2, 99, 3)?; let base_values = vector_values_f32( - execute_tq_unpack(execute_tq_pack(base, &config, &mut ctx)?, &config, &mut ctx)?, + execute_tq_decode( + execute_tq_encode(base, &config, &mut ctx)?, + &config, + &mut ctx, + )?, &mut ctx, )?; let scaled_values = vector_values_f32( - execute_tq_unpack( - execute_tq_pack(scaled, &config, &mut ctx)?, + execute_tq_decode( + execute_tq_encode(scaled, &config, &mut ctx)?, &config, &mut ctx, )?, diff --git a/vortex-turboquant/src/tests/file.rs b/vortex-turboquant/src/tests/file.rs index 3c5dbe5ccd9..f403431516d 100644 --- a/vortex-turboquant/src/tests/file.rs +++ b/vortex-turboquant/src/tests/file.rs @@ -12,12 +12,12 @@ use vortex_io::runtime::BlockingRuntime; use vortex_io::runtime::single::SingleThreadRuntime; use vortex_tensor::vector::Vector; -use super::execute_tq_pack; -use super::execute_tq_unpack_from_metadata; +use super::execute_tq_decode_from_metadata; +use super::execute_tq_encode; use super::f32_vector_array; use super::file_session; use super::vector_validity; -use crate::TQUnpack; +use crate::TQDecode; use crate::TurboQuantConfig; use crate::vtable::tq_metadata; @@ -27,39 +27,39 @@ fn file_roundtrip_with_initialize_session() -> VortexResult<()> { let session = file_session(&runtime); let mut ctx = session.create_execution_ctx(); let input = f32_vector_array(128, 2, 0.25, Validity::from_iter([true, false]))?; - let packed = execute_tq_pack(input, &TurboQuantConfig::default(), &mut ctx)?; + let encoded = execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx)?; let mut file_bytes = Vec::new(); VortexWriteOptions::new(session.clone()) .blocking(&runtime) - .write(&mut file_bytes, packed.to_array_iterator())?; + .write(&mut file_bytes, encoded.to_array_iterator())?; let file = session.open_options().open_buffer(file_bytes)?; let read = runtime.block_on(async { file.scan()?.into_array_stream()?.read_all().await })?; let metadata = tq_metadata(read.dtype())?; assert_eq!(metadata.dimensions, 128); - let unpacked = execute_tq_unpack_from_metadata(read, &mut ctx)?; - let validity = vector_validity(unpacked, &mut ctx)?.execute_mask(2, &mut ctx)?; + let decoded = execute_tq_decode_from_metadata(read, &mut ctx)?; + let validity = vector_validity(decoded, &mut ctx)?.execute_mask(2, &mut ctx)?; assert!(validity.value(0)); assert!(!validity.value(1)); Ok(()) } #[test] -fn file_roundtrip_lazy_unpack_scalar_fn_with_initialize_session() -> VortexResult<()> { +fn file_roundtrip_lazy_decode_scalar_fn_with_initialize_session() -> VortexResult<()> { let runtime = SingleThreadRuntime::default(); let session = file_session(&runtime); let mut ctx = session.create_execution_ctx(); let input = f32_vector_array(128, 2, 0.25, Validity::from_iter([true, false]))?; let config = TurboQuantConfig::try_new(3, 42, 3)?; - let packed = execute_tq_pack(input, &config, &mut ctx)?; - let unpacked = TQUnpack::try_new_array(packed, &config, 2)?.into_array(); + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + let decoded = TQDecode::try_new_array(encoded, &config, 2)?.into_array(); let mut file_bytes = Vec::new(); VortexWriteOptions::new(session.clone()) .blocking(&runtime) - .write(&mut file_bytes, unpacked.to_array_iterator())?; + .write(&mut file_bytes, decoded.to_array_iterator())?; let file = 
session.open_options().open_buffer(file_bytes)?; let read = runtime.block_on(async { file.scan()?.into_array_stream()?.read_all().await })?; diff --git a/vortex-turboquant/src/tests/malformed.rs b/vortex-turboquant/src/tests/malformed.rs index 7ca368dd00b..f99f0ee5105 100644 --- a/vortex-turboquant/src/tests/malformed.rs +++ b/vortex-turboquant/src/tests/malformed.rs @@ -15,7 +15,7 @@ use vortex_array::validity::Validity; use vortex_buffer::Buffer; use vortex_error::VortexResult; -use super::execute_tq_unpack_from_metadata; +use super::execute_tq_decode_from_metadata; use super::test_session; use super::vector_validity; use crate::TurboQuant; @@ -42,7 +42,7 @@ use crate::TurboQuantMetadata; Nullability::Nullable, Nullability::NonNullable )] -fn unpack_accepts_child_nullability_that_covers_struct_validity( +fn decode_accepts_child_nullability_that_covers_struct_validity( #[case] struct_nullability: Nullability, #[case] norms_nullability: Nullability, #[case] codes_nullability: Nullability, @@ -79,12 +79,12 @@ fn unpack_accepts_child_nullability_that_covers_struct_validity( .unwrap() .into_array(); - execute_tq_unpack_from_metadata(tq, &mut ctx)?; + execute_tq_decode_from_metadata(tq, &mut ctx)?; Ok(()) } #[test] -fn unpack_accepts_struct_mask_with_all_valid_children() -> VortexResult<()> { +fn decode_accepts_struct_mask_with_all_valid_children() -> VortexResult<()> { let session = test_session(); let mut ctx = session.create_execution_ctx(); let metadata = TurboQuantMetadata { @@ -109,8 +109,8 @@ fn unpack_accepts_struct_mask_with_all_valid_children() -> VortexResult<()> { let tq = ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage.into_array())? .into_array(); - let unpacked = execute_tq_unpack_from_metadata(tq, &mut ctx)?; - let validity = vector_validity(unpacked, &mut ctx)?.execute_mask(3, &mut ctx)?; + let decoded = execute_tq_decode_from_metadata(tq, &mut ctx)?; + let validity = vector_validity(decoded, &mut ctx)?.execute_mask(3, &mut ctx)?; assert!(validity.value(0)); assert!(!validity.value(1)); assert!(validity.value(2)); @@ -118,7 +118,7 @@ fn unpack_accepts_struct_mask_with_all_valid_children() -> VortexResult<()> { } #[test] -fn unpack_rejects_child_masks_that_disagree_with_struct_validity() -> VortexResult<()> { +fn decode_rejects_child_masks_that_disagree_with_struct_validity() -> VortexResult<()> { let session = test_session(); let mut ctx = session.create_execution_ctx(); let metadata = TurboQuantMetadata { @@ -150,13 +150,13 @@ fn unpack_rejects_child_masks_that_disagree_with_struct_validity() -> VortexResu let tq = ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage.into_array())? 
.into_array(); - assert!(execute_tq_unpack_from_metadata(tq, &mut ctx).is_err()); + assert!(execute_tq_decode_from_metadata(tq, &mut ctx).is_err()); Ok(()) } #[test] #[should_panic(expected = "TurboQuant code exceeds centroid count")] -fn unpack_panics_on_codes_outside_centroid_table() { +fn decode_panics_on_codes_outside_centroid_table() { let session = test_session(); let mut ctx = session.create_execution_ctx(); let metadata = TurboQuantMetadata { @@ -185,5 +185,5 @@ fn unpack_panics_on_codes_outside_centroid_table() { .unwrap() .into_array(); - drop(execute_tq_unpack_from_metadata(tq, &mut ctx)); + drop(execute_tq_decode_from_metadata(tq, &mut ctx)); } diff --git a/vortex-turboquant/src/tests/mod.rs b/vortex-turboquant/src/tests/mod.rs index 3b2bce0c4e2..c60d52cd957 100644 --- a/vortex-turboquant/src/tests/mod.rs +++ b/vortex-turboquant/src/tests/mod.rs @@ -32,16 +32,16 @@ use vortex_layout::session::LayoutSession; use vortex_session::VortexSession; use vortex_tensor::vector::Vector; -use crate::TQPack; -use crate::TQUnpack; +use crate::TQDecode; +use crate::TQEncode; use crate::TurboQuantConfig; use crate::initialize; use crate::vtable::tq_metadata; +mod encode_decode; mod file; mod malformed; mod metadata; -mod pack_unpack; mod parity; mod scalar_fns; @@ -120,34 +120,34 @@ fn turboquant_storage(array: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult VortexResult { let len = input.len(); - TQPack::try_new_array(input, config, len)? + TQEncode::try_new_array(input, config, len)? .into_array() .execute(ctx) } -fn execute_tq_unpack( +fn execute_tq_decode( input: ArrayRef, config: &TurboQuantConfig, ctx: &mut ExecutionCtx, ) -> VortexResult { let len = input.len(); - TQUnpack::try_new_array(input, config, len)? + TQDecode::try_new_array(input, config, len)? 
.into_array() .execute(ctx) } -fn execute_tq_unpack_from_metadata( +fn execute_tq_decode_from_metadata( input: ArrayRef, ctx: &mut ExecutionCtx, ) -> VortexResult { let metadata = tq_metadata(input.dtype())?; let config = TurboQuantConfig::try_new(metadata.bit_width, metadata.seed, metadata.num_rounds)?; - execute_tq_unpack(input, &config, ctx) + execute_tq_decode(input, &config, ctx) } diff --git a/vortex-turboquant/src/tests/parity.rs b/vortex-turboquant/src/tests/parity.rs index b6cb6c1b0cd..e5fc82fc1fd 100644 --- a/vortex-turboquant/src/tests/parity.rs +++ b/vortex-turboquant/src/tests/parity.rs @@ -5,22 +5,22 @@ use vortex_array::VortexSessionExecute; use vortex_array::validity::Validity; use vortex_error::VortexResult; -use super::execute_tq_pack; -use super::execute_tq_unpack; +use super::execute_tq_decode; +use super::execute_tq_encode; use super::f32_vector_array; use super::test_session; use super::vector_values_f32; use crate::TurboQuantConfig; #[test] -fn unpack_pack_matches_old_turboquant_decode() -> VortexResult<()> { +fn encode_decode_matches_old_turboquant_decode() -> VortexResult<()> { let session = test_session(); let mut ctx = session.create_execution_ctx(); let input = f32_vector_array(128, 2, 0.125, Validity::NonNullable)?; let config = TurboQuantConfig::try_new(3, 42, 3)?; - let new_packed = execute_tq_pack(input.clone(), &config, &mut ctx)?; - let new_decoded = execute_tq_unpack(new_packed, &config, &mut ctx)?; + let new_encoded = execute_tq_encode(input.clone(), &config, &mut ctx)?; + let new_decoded = execute_tq_decode(new_encoded, &config, &mut ctx)?; let old_config = vortex_tensor::encodings::turboquant::TurboQuantConfig { bit_width: config.bit_width(), seed: config.seed(), diff --git a/vortex-turboquant/src/tests/scalar_fns.rs b/vortex-turboquant/src/tests/scalar_fns.rs index 99c8603f272..3def4b7b991 100644 --- a/vortex-turboquant/src/tests/scalar_fns.rs +++ b/vortex-turboquant/src/tests/scalar_fns.rs @@ -9,12 +9,12 @@ use vortex_array::scalar_fn::ScalarFnVTable; use vortex_array::validity::Validity; use vortex_error::VortexResult; -use super::execute_tq_pack; +use super::execute_tq_encode; use super::f32_vector_array; use super::test_session; use super::vector_validity; -use crate::TQPack; -use crate::TQUnpack; +use crate::TQDecode; +use crate::TQEncode; use crate::TurboQuant; use crate::TurboQuantConfig; use crate::vtable::tq_metadata; @@ -24,34 +24,34 @@ fn scalar_fn_ids_and_config_options_roundtrip() -> VortexResult<()> { let session = test_session(); let config = TurboQuantConfig::try_new(4, 7, 2)?; - assert_eq!(TQPack.id().as_ref(), "vortex.turboquant.pack"); - assert_eq!(TQUnpack.id().as_ref(), "vortex.turboquant.unpack"); + assert_eq!(TQEncode.id().as_ref(), "vortex.turboquant.encode"); + assert_eq!(TQDecode.id().as_ref(), "vortex.turboquant.decode"); - let pack_metadata = TQPack.serialize(&config)?.unwrap(); - let unpack_metadata = TQUnpack.serialize(&config)?.unwrap(); + let encode_metadata = TQEncode.serialize(&config)?.unwrap(); + let decode_metadata = TQDecode.serialize(&config)?.unwrap(); - assert_eq!(TQPack.deserialize(&pack_metadata, &session)?, config); - assert_eq!(TQUnpack.deserialize(&unpack_metadata, &session)?, config); + assert_eq!(TQEncode.deserialize(&encode_metadata, &session)?, config); + assert_eq!(TQDecode.deserialize(&decode_metadata, &session)?, config); Ok(()) } #[test] -fn scalar_fn_arrays_pack_and_unpack_vectors() -> VortexResult<()> { +fn scalar_fn_arrays_encode_and_decode_vectors() -> VortexResult<()> { let session = 
test_session(); let mut ctx = session.create_execution_ctx(); let input = f32_vector_array(128, 2, 0.25, Validity::from_iter([true, false]))?; let config = TurboQuantConfig::try_new(3, 42, 3)?; - let packed_lazy = TQPack::try_new_array(input, &config, 2)?; - let packed_metadata = tq_metadata(packed_lazy.dtype())?; - assert_eq!(packed_metadata.dimensions, 128); - assert_eq!(packed_metadata.bit_width, config.bit_width()); - assert!(packed_lazy.dtype().as_extension().is::()); + let encoded_lazy = TQEncode::try_new_array(input, &config, 2)?; + let encoded_metadata = tq_metadata(encoded_lazy.dtype())?; + assert_eq!(encoded_metadata.dimensions, 128); + assert_eq!(encoded_metadata.bit_width, config.bit_width()); + assert!(encoded_lazy.dtype().as_extension().is::()); - let packed = packed_lazy.into_array().execute(&mut ctx)?; - let unpacked_lazy = TQUnpack::try_new_array(packed, &config, 2)?; - let unpacked = unpacked_lazy.into_array().execute(&mut ctx)?; - let validity = vector_validity(unpacked, &mut ctx)?.execute_mask(2, &mut ctx)?; + let encoded = encoded_lazy.into_array().execute(&mut ctx)?; + let decoded_lazy = TQDecode::try_new_array(encoded, &config, 2)?; + let decoded = decoded_lazy.into_array().execute(&mut ctx)?; + let validity = vector_validity(decoded, &mut ctx)?.execute_mask(2, &mut ctx)?; assert!(validity.value(0)); assert!(!validity.value(1)); @@ -64,30 +64,30 @@ fn scalar_fn_array_metadata_stores_only_config() -> VortexResult<()> { let mut ctx = session.create_execution_ctx(); let input = f32_vector_array(128, 2, 0.25, Validity::NonNullable)?; let config = TurboQuantConfig::try_new(3, 42, 3)?; - let config_metadata = TQUnpack.serialize(&config)?.unwrap(); + let config_metadata = TQDecode.serialize(&config)?.unwrap(); - let packed = execute_tq_pack(input, &config, &mut ctx)?; - let unpacked_lazy = TQUnpack::try_new_array(packed, &config, 2)?.into_array(); - let unpack_plugin = ScalarFnArrayPlugin::new(TQUnpack); + let encoded = execute_tq_encode(input, &config, &mut ctx)?; + let decoded_lazy = TQDecode::try_new_array(encoded, &config, 2)?.into_array(); + let decode_plugin = ScalarFnArrayPlugin::new(TQDecode); assert_eq!( - unpack_plugin.serialize(&unpacked_lazy, &session)?.unwrap(), + decode_plugin.serialize(&decoded_lazy, &session)?.unwrap(), config_metadata ); Ok(()) } #[test] -fn unpack_rejects_config_that_disagrees_with_turboquant_child() -> VortexResult<()> { +fn decode_rejects_config_that_disagrees_with_turboquant_child() -> VortexResult<()> { let session = test_session(); let mut ctx = session.create_execution_ctx(); let input = f32_vector_array(128, 1, 0.25, Validity::NonNullable)?; let config = TurboQuantConfig::try_new(3, 42, 3)?; let different_config = TurboQuantConfig::try_new(4, 42, 3)?; - let packed = TQPack::try_new_array(input, &config, 1)? + let encoded = TQEncode::try_new_array(input, &config, 1)? .into_array() .execute(&mut ctx)?; - assert!(TQUnpack::try_new_array(packed, &different_config, 1).is_err()); + assert!(TQDecode::try_new_array(encoded, &different_config, 1).is_err()); Ok(()) } diff --git a/vortex-turboquant/src/vector/unpack.rs b/vortex-turboquant/src/vector/decode.rs similarity index 96% rename from vortex-turboquant/src/vector/unpack.rs rename to vortex-turboquant/src/vector/decode.rs index 59f54a92b47..1caf441725b 100644 --- a/vortex-turboquant/src/vector/unpack.rs +++ b/vortex-turboquant/src/vector/decode.rs @@ -1,10 +1,10 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -//! 
TurboQuant unpacking (dequantization) logic. +//! TurboQuant decoding (dequantization) logic. //! -//! Note that because TurboQuant is a lossy compression scheme, unpacking does not roundtrip with -//! the initial packing. +//! Note that because TurboQuant is a lossy compression scheme, decoding does not roundtrip with +//! the initial encoding. use num_traits::Float; use num_traits::FromPrimitive; @@ -34,8 +34,8 @@ use crate::vtable::TurboQuantMetadata; /// /// The decoded directions are inverse-transformed, truncated to the original dimension, and /// multiplied by the stored row norms. The conversion is lossy and does not roundtrip with -/// [`TQPack`](crate::TQPack). -pub(crate) fn unpack_vector(input: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { +/// [`TQEncode`](crate::TQEncode). +pub(crate) fn decode_vector(input: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { // Get the input TurboQuant array into a form that is easier to work with. let parsed = parse_storage(input, ctx)?; let metadata = parsed.metadata; @@ -53,7 +53,7 @@ pub(crate) fn unpack_vector(input: ArrayRef, ctx: &mut ExecutionCtx) -> VortexRe let centroids = compute_or_get_centroids(padded_dim, metadata.bit_width)?; match_each_float_ptype!(metadata.element_ptype, |T| { - unpack_typed::( + decode_typed::( DecodeInputs { metadata: &metadata, sorf_matrix: &transform, @@ -93,7 +93,7 @@ struct DecodeInputs<'a> { codes: &'a PrimitiveArray, } -fn unpack_typed( +fn decode_typed( decode: DecodeInputs<'_>, vector_validity: Validity, num_vectors: usize, diff --git a/vortex-turboquant/src/vector/pack.rs b/vortex-turboquant/src/vector/encode.rs similarity index 88% rename from vortex-turboquant/src/vector/pack.rs rename to vortex-turboquant/src/vector/encode.rs index e277ed1c847..204c74aa3f7 100644 --- a/vortex-turboquant/src/vector/pack.rs +++ b/vortex-turboquant/src/vector/encode.rs @@ -1,10 +1,10 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -//! TurboQuant packing (quantization) logic. +//! TurboQuant encoding (quantization) logic. //! -//! The input to [`pack_vector()`] must be a [`Vector`](vortex_tensor::vector::Vector) extension -//! array. [`pack_vector()`] computes original row norms, normalizes valid rows internally via +//! The input to [`encode_vector()`] must be a [`Vector`](vortex_tensor::vector::Vector) extension +//! array. [`encode_vector()`] computes original row norms, normalizes valid rows internally via //! [`tq_normalize_as_l2_denorm()`], quantizes the normalized child, and stores row-aligned norms and //! codes in the TurboQuant extension storage. @@ -32,12 +32,12 @@ use crate::config::MIN_DIMENSION; use crate::vtable::TurboQuant; use crate::vtable::TurboQuantMetadata; -/// Lossily pack a `Vector` extension array into a `TurboQuant` extension array. +/// Lossily encode a `Vector` extension array into a `TurboQuant` extension array. /// /// Valid rows are normalized internally before SORF transform and scalar quantization. The original /// row norms are stored explicitly, and original vector nulls are preserved on the storage struct /// and both row-aligned child arrays. 
-pub(crate) fn pack_vector(
+pub(crate) fn encode_vector(
     input: ArrayRef,
     config: &TurboQuantConfig,
     ctx: &mut ExecutionCtx,
@@ -47,7 +47,7 @@ pub(crate) fn pack_vector(
         .dtype()
         .as_extension_opt()
         .and_then(|ext_dtype| ext_dtype.metadata_opt::())
-        .ok_or_else(|| vortex_err!("TurboQuant pack expects a Vector extension array"))?;
+        .ok_or_else(|| vortex_err!("TurboQuant encode expects a Vector extension array"))?;
 
     let element_ptype = vector_metadata.element_ptype();
 
diff --git a/vortex-turboquant/src/vector/mod.rs b/vortex-turboquant/src/vector/mod.rs
index 3a7c978cfae..7015fded568 100644
--- a/vortex-turboquant/src/vector/mod.rs
+++ b/vortex-turboquant/src/vector/mod.rs
@@ -3,10 +3,10 @@
 
 mod quantize;
 
+pub(crate) mod decode;
+pub(crate) mod encode;
 pub(crate) mod normalize;
-pub(crate) mod pack;
 pub(crate) mod storage;
-pub(crate) mod unpack;
 
 use vortex_error::VortexResult;
 use vortex_error::vortex_err;
diff --git a/vortex-turboquant/src/vector/storage.rs b/vortex-turboquant/src/vector/storage.rs
index a5db2e4a167..016eff3b0dc 100644
--- a/vortex-turboquant/src/vector/storage.rs
+++ b/vortex-turboquant/src/vector/storage.rs
@@ -13,7 +13,7 @@
 //! ```
 //!
 //! Row nullability is carried on the outer struct and on the `norms` and `codes` field arrays.
-//! This is deliberate duplication: null vectors remain null throughout pack/unpack instead of being
+//! This is deliberate duplication: null vectors remain null throughout encode/decode instead of being
 //! converted into zero vectors. The code bytes for invalid rows are physical placeholders only; the
 //! field-level validity records that those rows were not quantized.
 //!

From 11ac1c02d21759be3f53d0e9ec1971b48e2350ce Mon Sep 17 00:00:00 2001
From: Connor Tsui
Date: Fri, 8 May 2026 15:33:55 -0400
Subject: [PATCH 5/5] refactor(turboquant): move encode/decode into scalar_fns and remove the scalar fn array plugin

Signed-off-by: Connor Tsui
---
 vortex-turboquant/benches/encode_decode.rs   |  12 +-
 vortex-turboquant/public-api.lock            |  14 +-
 vortex-turboquant/src/lib.rs                 |   6 -
 vortex-turboquant/src/scalar_fns/decode.rs   | 270 +++++++++++++------
 vortex-turboquant/src/scalar_fns/encode.rs   |  71 ++++-
 vortex-turboquant/src/tests/encode_decode.rs |  20 +-
 vortex-turboquant/src/tests/file.rs          |   2 +-
 vortex-turboquant/src/tests/mod.rs           |  20 +-
 vortex-turboquant/src/tests/parity.rs        |   2 +-
 vortex-turboquant/src/tests/scalar_fns.rs    |  52 +---
 vortex-turboquant/src/vector/decode.rs       | 198 --------------
 vortex-turboquant/src/vector/encode.rs       |  94 -------
 vortex-turboquant/src/vector/mod.rs          |   5 +-
 13 files changed, 283 insertions(+), 483 deletions(-)
 delete mode 100644 vortex-turboquant/src/vector/decode.rs
 delete mode 100644 vortex-turboquant/src/vector/encode.rs

diff --git a/vortex-turboquant/benches/encode_decode.rs b/vortex-turboquant/benches/encode_decode.rs
index 845c4ae85a8..f88c37c347a 100644
--- a/vortex-turboquant/benches/encode_decode.rs
+++ b/vortex-turboquant/benches/encode_decode.rs
@@ -6,7 +6,7 @@
 //!
 //! The four mask shapes (`AllTrue`, `AllFalse`, dense `Values`, sparse `Values`) exercise the
 //! variant-specialized paths added in the mask refactor in `vector/normalize.rs`,
-//! `vector/quantize.rs`, and `vector/decode.rs`.
+//! `vector/quantize.rs`, and `scalar_fns/decode.rs`.
#![expect(clippy::unwrap_used)] @@ -101,17 +101,15 @@ fn build_vector_array(shape: MaskShape) -> ArrayRef { } fn encode(vec: ArrayRef, config: &TurboQuantConfig, ctx: &mut ExecutionCtx) -> ArrayRef { - let len = vec.len(); - TQEncode::try_new_array(vec, config, len) + TQEncode::try_new_array(vec, config) .unwrap() .into_array() .execute(ctx) .unwrap() } -fn decode(encoded: ArrayRef, config: &TurboQuantConfig, ctx: &mut ExecutionCtx) -> ArrayRef { - let len = encoded.len(); - TQDecode::try_new_array(encoded, config, len) +fn decode(encoded: ArrayRef, ctx: &mut ExecutionCtx) -> ArrayRef { + TQDecode::try_new_array(encoded) .unwrap() .into_array() .execute(ctx) @@ -145,5 +143,5 @@ fn turboquant_decode(bencher: Bencher, shape: &MaskShape) { (encoded, SESSION.create_execution_ctx()) }) .input_counter(|_| divan::counter::ItemsCount::new(ROWS)) - .bench_values(|(encoded, mut ctx)| decode(encoded, &cfg, &mut ctx)) + .bench_values(|(encoded, mut ctx)| decode(encoded, &mut ctx)) } diff --git a/vortex-turboquant/public-api.lock b/vortex-turboquant/public-api.lock index 95e162f40f6..ac6b8eeac5b 100644 --- a/vortex-turboquant/public-api.lock +++ b/vortex-turboquant/public-api.lock @@ -4,23 +4,17 @@ pub struct vortex_turboquant::TQDecode impl vortex_turboquant::TQDecode -pub fn vortex_turboquant::TQDecode::new(&vortex_turboquant::TurboQuantConfig) -> vortex_array::scalar_fn::typed::TypedScalarFnInstance +pub fn vortex_turboquant::TQDecode::new() -> vortex_array::scalar_fn::typed::TypedScalarFnInstance -pub fn vortex_turboquant::TQDecode::try_new_array(vortex_array::array::erased::ArrayRef, &vortex_turboquant::TurboQuantConfig, usize) -> vortex_error::VortexResult +pub fn vortex_turboquant::TQDecode::try_new_array(vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult impl core::clone::Clone for vortex_turboquant::TQDecode pub fn vortex_turboquant::TQDecode::clone(&self) -> vortex_turboquant::TQDecode -impl vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable for vortex_turboquant::TQDecode - -pub fn vortex_turboquant::TQDecode::deserialize(&self, &vortex_array::dtype::DType, usize, &[u8], &dyn vortex_array::serde::ArrayChildren, &vortex_session::VortexSession) -> vortex_error::VortexResult> - -pub fn vortex_turboquant::TQDecode::serialize(&self, &vortex_array::arrays::scalar_fn::vtable::ScalarFnArrayView<'_, Self>, &vortex_session::VortexSession) -> vortex_error::VortexResult>> - impl vortex_array::scalar_fn::vtable::ScalarFnVTable for vortex_turboquant::TQDecode -pub type vortex_turboquant::TQDecode::Options = vortex_turboquant::TurboQuantConfig +pub type vortex_turboquant::TQDecode::Options = vortex_array::extension::EmptyMetadata pub fn vortex_turboquant::TQDecode::arity(&self, &Self::Options) -> vortex_array::scalar_fn::vtable::Arity @@ -50,7 +44,7 @@ impl vortex_turboquant::TQEncode pub fn vortex_turboquant::TQEncode::new(&vortex_turboquant::TurboQuantConfig) -> vortex_array::scalar_fn::typed::TypedScalarFnInstance -pub fn vortex_turboquant::TQEncode::try_new_array(vortex_array::array::erased::ArrayRef, &vortex_turboquant::TurboQuantConfig, usize) -> vortex_error::VortexResult +pub fn vortex_turboquant::TQEncode::try_new_array(vortex_array::array::erased::ArrayRef, &vortex_turboquant::TurboQuantConfig) -> vortex_error::VortexResult impl core::clone::Clone for vortex_turboquant::TQEncode diff --git a/vortex-turboquant/src/lib.rs b/vortex-turboquant/src/lib.rs index 60e36edff3a..7aeb60368dd 100644 --- a/vortex-turboquant/src/lib.rs +++ b/vortex-turboquant/src/lib.rs @@ -69,18 
+69,12 @@ pub use vtable::TurboQuantMetadata; // TODO(connor): We need to somehow make sure that callers call `vortex_tensor::initialize` first. /// Register the TurboQuant extension type with a Vortex session. pub fn initialize(session: &vortex_session::VortexSession) { - use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayPlugin; use vortex_array::dtype::session::DTypeSessionExt; use vortex_array::scalar_fn::session::ScalarFnSessionExt; - use vortex_array::session::ArraySessionExt; - session.dtypes().register(TurboQuant); session.scalar_fns().register(TQEncode); session.scalar_fns().register(TQDecode); - - let session_arrays = session.arrays(); - session_arrays.register(ScalarFnArrayPlugin::new(TQDecode)); } #[cfg(test)] diff --git a/vortex-turboquant/src/scalar_fns/decode.rs b/vortex-turboquant/src/scalar_fns/decode.rs index 8a937b23a64..22e4f2560b8 100644 --- a/vortex-turboquant/src/scalar_fns/decode.rs +++ b/vortex-turboquant/src/scalar_fns/decode.rs @@ -7,39 +7,43 @@ use std::fmt; use std::fmt::Formatter; use std::sync::Arc; +use num_traits::Float; +use num_traits::FromPrimitive; use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; +use vortex_array::IntoArray; +use vortex_array::arrays::FixedSizeListArray; +use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::ScalarFnArray; -use vortex_array::arrays::scalar_fn::ScalarFnArrayView; -use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts; -use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable; use vortex_array::dtype::DType; +use vortex_array::dtype::NativePType; use vortex_array::dtype::Nullability; use vortex_array::dtype::extension::ExtDType; use vortex_array::expr::Expression; use vortex_array::extension::EmptyMetadata; +use vortex_array::match_each_float_ptype; use vortex_array::scalar_fn::Arity; use vortex_array::scalar_fn::ChildName; use vortex_array::scalar_fn::ExecutionArgs; use vortex_array::scalar_fn::ScalarFnId; use vortex_array::scalar_fn::ScalarFnVTable; use vortex_array::scalar_fn::TypedScalarFnInstance; -use vortex_array::serde::ArrayChildren; +use vortex_array::validity::Validity; +use vortex_buffer::BufferMut; +use vortex_error::VortexExpect; use vortex_error::VortexResult; -use vortex_error::vortex_ensure_eq; +use vortex_error::vortex_ensure; use vortex_error::vortex_err; +use vortex_mask::Mask; use vortex_session::VortexSession; -use vortex_tensor::vector::AnyVector; use vortex_tensor::vector::Vector; -use super::metadata::deserialize_config; -use super::metadata::serialize_config; -use crate::TurboQuantConfig; -use crate::vector::decode::decode_vector; -use crate::vtable::TurboQuant; +use crate::centroids::compute_or_get_centroids; +use crate::sorf::SorfMatrix; +use crate::vector::storage::parse_storage; +use crate::vector::tq_padded_dim; use crate::vtable::TurboQuantMetadata; use crate::vtable::tq_metadata; -use crate::vtable::tq_storage_dtype; /// Lazy TurboQuant vector decode scalar function. #[derive(Clone)] @@ -47,29 +51,26 @@ pub struct TQDecode; impl TQDecode { /// Creates a new [`TypedScalarFnInstance`] wrapping TurboQuant decoding. - pub fn new(config: &TurboQuantConfig) -> TypedScalarFnInstance { - TypedScalarFnInstance::new(TQDecode, config.clone()) + pub fn new() -> TypedScalarFnInstance { + TypedScalarFnInstance::new(TQDecode, EmptyMetadata) } /// Constructs a [`ScalarFnArray`] that lazily decodes a `TurboQuant` child into a `Vector`. 
-    pub fn try_new_array(
-        child: ArrayRef,
-        config: &TurboQuantConfig,
-        len: usize,
-    ) -> VortexResult<ScalarFnArray> {
-        ScalarFnArray::try_new(TQDecode::new(config).erased(), vec![child], len)
+    pub fn try_new_array(child: ArrayRef) -> VortexResult<ScalarFnArray> {
+        let len = child.len();
+        ScalarFnArray::try_new(TQDecode::new().erased(), vec![child], len)
     }
 }
 
 impl ScalarFnVTable for TQDecode {
-    type Options = TurboQuantConfig;
+    type Options = EmptyMetadata;
 
     fn id(&self) -> ScalarFnId {
         ScalarFnId::new("vortex.turboquant.decode")
     }
 
-    fn serialize(&self, options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
-        Ok(Some(serialize_config(options)))
+    fn serialize(&self, _options: &Self::Options) -> VortexResult<Option<Vec<u8>>> {
+        Ok(Some(vec![]))
     }
 
     fn deserialize(
@@ -77,7 +78,12 @@ impl ScalarFnVTable for TQDecode {
         metadata: &[u8],
         _session: &VortexSession,
     ) -> VortexResult<Self::Options> {
-        deserialize_config(metadata)
+        vortex_ensure!(
+            metadata.is_empty(),
+            "TQDecode options metadata must be empty"
+        );
+
+        Ok(EmptyMetadata)
     }
 
     fn arity(&self, _options: &Self::Options) -> Arity {
@@ -93,19 +99,18 @@ impl ScalarFnVTable for TQDecode {
 
     fn fmt_sql(
         &self,
-        options: &Self::Options,
+        _options: &Self::Options,
         expr: &Expression,
         f: &mut Formatter<'_>,
     ) -> fmt::Result {
         write!(f, "tq_decode(")?;
         expr.child(0).fmt_sql(f)?;
-        write!(f, ", {options})")
+        write!(f, ")")
     }
 
-    fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
+    fn return_dtype(&self, _options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult<DType> {
         let child_dtype = &arg_dtypes[0];
         let metadata = tq_metadata(child_dtype)?;
-        validate_config_matches_metadata(options, &metadata)?;
 
         let storage_dtype = DType::FixedSizeList(
             Arc::new(DType::Primitive(
@@ -146,68 +151,157 @@ impl ScalarFnVTable for TQDecode {
     }
 }
 
-impl ScalarFnArrayVTable for TQDecode {
-    fn serialize(
-        &self,
-        view: &ScalarFnArrayView<Self>,
-        _session: &VortexSession,
-    ) -> VortexResult<Option<Vec<u8>>> {
-        Ok(Some(serialize_config(view.options)))
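
With `Options = EmptyMetadata`, the decode parameters no longer ride on the scalar fn: `bit_width`, `seed`, and `num_rounds` are read from the `TurboQuant` extension dtype of the child array. A minimal sketch of the resulting call pattern, assuming a session and execution context set up the same way as in the tests later in this patch:

    // Hedged sketch: encode still takes a config, decode derives it from the child dtype.
    let config = TurboQuantConfig::try_new(3, 42, 3)?;
    let encoded = TQEncode::try_new_array(input, &config)?
        .into_array()
        .execute(&mut ctx)?;
    // No config argument and no explicit length: both come from `encoded` itself.
    let decoded = TQDecode::try_new_array(encoded)?
        .into_array()
        .execute(&mut ctx)?;
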
+/// Decode a `TurboQuant` extension array back into a `Vector` extension array.
+///
+/// The decoded directions are inverse-transformed, truncated to the original dimension, and
+/// multiplied by the stored row norms. The conversion is lossy and does not roundtrip with
+/// [`TQEncode`](crate::TQEncode).
+pub(crate) fn decode_vector(input: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
+    let parsed = parse_storage(input, ctx)?;
+    let metadata = parsed.metadata;
+    if parsed.len == 0 {
+        return build_empty_vector(metadata, parsed.vector_validity);
     }
 
-    fn deserialize(
-        &self,
-        dtype: &DType,
-        len: usize,
-        metadata: &[u8],
-        children: &dyn ArrayChildren,
-        _session: &VortexSession,
-    ) -> VortexResult<ScalarFnArrayParts<Self>> {
-        let options = deserialize_config(metadata)?;
-        let vector_metadata = dtype
-            .as_extension_opt()
-            .and_then(|ext_dtype| ext_dtype.metadata_opt::<AnyVector>())
-            .ok_or_else(|| {
-                vortex_err!("TQDecode parent dtype must be a Vector extension array, got {dtype}")
-            })?;
-
-        let metadata = TurboQuantMetadata {
-            element_ptype: vector_metadata.element_ptype(),
-            dimensions: vector_metadata.dimensions(),
-            bit_width: options.bit_width(),
-            seed: options.seed(),
-            num_rounds: options.num_rounds(),
-        };
-        let storage_dtype = tq_storage_dtype(&metadata, dtype.nullability())?;
-        let child_dtype =
-            DType::Extension(ExtDType::<TurboQuant>::try_new(metadata, storage_dtype)?.erased());
-        let child = children.get(0, &child_dtype, len)?;
-
-        Ok(ScalarFnArrayParts {
-            options,
-            children: vec![child],
-        })
-    }
+    let padded_dim = tq_padded_dim(metadata.dimensions)?;
+    let transform = SorfMatrix::try_new(padded_dim, metadata.num_rounds as usize, metadata.seed)?;
+    let padded_dim = u32::try_from(padded_dim)
+        .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit u32"))?;
+
+    let centroids = compute_or_get_centroids(padded_dim, metadata.bit_width)?;
+
+    match_each_float_ptype!(metadata.element_ptype, |T| {
+        decode_typed::<T>(
+            DecodeInputs {
+                metadata: &metadata,
+                sorf_matrix: &transform,
+                centroids: &centroids,
+                norms: &parsed.norms,
+                codes: &parsed.codes,
+            },
+            parsed.vector_validity,
+            parsed.len,
+            ctx,
+        )
+    })
 }
 
-fn validate_config_matches_metadata(
-    config: &TurboQuantConfig,
-    metadata: &TurboQuantMetadata,
-) -> VortexResult<()> {
-    vortex_ensure_eq!(
-        config.bit_width(),
-        metadata.bit_width,
-        "TQDecode config bit_width must match TurboQuant child metadata"
-    );
-    vortex_ensure_eq!(
-        config.seed(),
-        metadata.seed,
-        "TQDecode config seed must match TurboQuant child metadata"
-    );
-    vortex_ensure_eq!(
-        config.num_rounds(),
-        metadata.num_rounds,
-        "TQDecode config num_rounds must match TurboQuant child metadata"
-    );
-    Ok(())
-}
+fn build_empty_vector(
+    metadata: TurboQuantMetadata,
+    vector_validity: Validity,
+) -> VortexResult<ArrayRef> {
+    match_each_float_ptype!(metadata.element_ptype, |T| {
+        let elements = PrimitiveArray::empty::<T>(Nullability::NonNullable);
+        let fsl = FixedSizeListArray::try_new(
+            elements.into_array(),
+            metadata.dimensions,
+            vector_validity,
+            0,
+        )?;
+
+        Vector::try_new_vector_array(fsl.into_array())
+    })
+}
+
+struct DecodeInputs<'a> {
+    metadata: &'a TurboQuantMetadata,
+    sorf_matrix: &'a SorfMatrix,
+    centroids: &'a [f32],
+    norms: &'a PrimitiveArray,
+    codes: &'a PrimitiveArray,
+}
+
+fn decode_typed<T>(
+    decode: DecodeInputs<'_>,
+    vector_validity: Validity,
+    num_vectors: usize,
+    ctx: &mut ExecutionCtx,
+) -> VortexResult<ArrayRef>
+where
+    T: NativePType + Float + FromPrimitive,
+{
+    let metadata = decode.metadata;
+    let dimensions = usize::try_from(metadata.dimensions)
+        .vortex_expect("dimensions stays representable as usize");
+    let padded_dim = decode.sorf_matrix.padded_dim();
+    let centroids = decode.centroids;
+    let norms = decode.norms.as_slice::<T>();
+    let codes = decode.codes.as_slice::<u8>();
+    let mask = vector_validity.execute_mask(num_vectors, ctx)?;
+
+    let output_len = num_vectors
+        .checked_mul(dimensions)
+        .ok_or_else(|| vortex_err!("TurboQuant decoded vector length overflow"))?;
+    let mut output = BufferMut::<T>::with_capacity(output_len);
+
+    let mut decoded = vec![0.0f32; padded_dim];
+    let mut inverse = vec![0.0f32; padded_dim];
+
+    let mut decode_row = |output: &mut BufferMut<T>, i: usize| {
+        let code_row = &codes[i * padded_dim..][..padded_dim];
+
+        for (dst, &code) in decoded.iter_mut().zip(code_row.iter()) {
+            *dst = *centroids
+                .get(usize::from(code))
+                .vortex_expect("TurboQuant code exceeds centroid count");
+        }
+
+        decode.sorf_matrix.inverse_transform(&decoded, &mut inverse);
+
+        let norm = norms[i];
+        for &value in inverse.iter().take(dimensions) {
+            // `T::from_f32` is infallible for the supported float ptypes (`f16`, `f32`,
+            // `f64`): values outside `f16` range saturate to `±inf` rather than returning
+            // `None`.
+            let value = T::from_f32(value)
+                .vortex_expect("from_f32 is infallible for supported float types");
+
+            // SAFETY: total pushes across all match arms equal `output_len`.
+            unsafe { output.push_unchecked(value * norm) };
+        }
+    };
+
+    match &mask {
+        Mask::AllFalse(_) => {
+            // SAFETY: `output` was allocated with capacity `output_len`, and this push writes
+            // exactly `output_len` zero placeholders.
+            unsafe { output.push_n_unchecked(T::zero(), output_len) };
+        }
+        Mask::AllTrue(_) => {
+            for i in 0..num_vectors {
+                decode_row(&mut output, i);
+            }
+        }
+        Mask::Values(values_mask) => {
+            let mut cursor = 0;
+
+            for &(start, end) in values_mask.slices() {
+                if start > cursor {
+                    // SAFETY: total pushes across all arms equal `output_len`.
+                    unsafe { output.push_n_unchecked(T::zero(), (start - cursor) * dimensions) };
+                }
+
+                for i in start..end {
+                    decode_row(&mut output, i);
+                }
+
+                cursor = end;
+            }
+
+            if cursor < num_vectors {
+                // SAFETY: total pushes across all arms equal `output_len`.
+                unsafe { output.push_n_unchecked(T::zero(), (num_vectors - cursor) * dimensions) };
+            }
+        }
+    }
+
+    let elements = PrimitiveArray::new::<T>(output.freeze(), Validity::NonNullable);
+    let fsl = FixedSizeListArray::try_new(
+        elements.into_array(),
+        metadata.dimensions,
+        vector_validity,
+        num_vectors,
+    )?;
+
+    Vector::try_new_vector_array(fsl.into_array())
 }
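
The mask handling in `decode_typed` pushes exactly `num_vectors * dimensions` elements across all three arms, which is the invariant the per-row `unsafe` pushes rely on. A standalone model of the `Mask::Values` arm using a plain `Vec` (illustrative only; `fill_rows` and its inputs are hypothetical, not part of this patch):

    // Valid slices are decoded, gaps are zero-filled; total length is rows * dim by construction.
    fn fill_rows(valid_slices: &[(usize, usize)], num_rows: usize, dim: usize) -> Vec<f32> {
        let mut out = Vec::with_capacity(num_rows * dim);
        let mut cursor = 0;
        for &(start, end) in valid_slices {
            out.resize(out.len() + (start - cursor) * dim, 0.0); // gap before this slice
            for row in start..end {
                out.extend((0..dim).map(|d| (row * dim + d) as f32)); // stand-in for decode_row
            }
            cursor = end;
        }
        out.resize(num_rows * dim, 0.0); // trailing gap
        out
    }
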
diff --git a/vortex-turboquant/src/scalar_fns/encode.rs b/vortex-turboquant/src/scalar_fns/encode.rs
index d683cd7e502..29ce7cc580a 100644
--- a/vortex-turboquant/src/scalar_fns/encode.rs
+++ b/vortex-turboquant/src/scalar_fns/encode.rs
@@ -8,7 +8,13 @@ use std::fmt::Formatter;
 
 use vortex_array::ArrayRef;
 use vortex_array::ExecutionCtx;
+use vortex_array::IntoArray;
+use vortex_array::arrays::Extension;
+use vortex_array::arrays::ExtensionArray;
+use vortex_array::arrays::FixedSizeListArray;
 use vortex_array::arrays::ScalarFnArray;
+use vortex_array::arrays::extension::ExtensionArrayExt;
+use vortex_array::arrays::scalar_fn::ScalarFnArrayExt;
 use vortex_array::dtype::DType;
 use vortex_array::dtype::extension::ExtDType;
 use vortex_array::expr::Expression;
@@ -28,7 +34,11 @@ use super::metadata::deserialize_config;
 use super::metadata::serialize_config;
 use crate::TurboQuantConfig;
 use crate::config::MIN_DIMENSION;
-use crate::vector::encode::encode_vector;
+use crate::vector::normalize::tq_normalize_as_l2_denorm;
+use crate::vector::quantize::empty_quantization;
+use crate::vector::quantize::turboquant_quantize_core;
+use crate::vector::storage::build_codes_child;
+use crate::vector::storage::build_storage;
 use crate::vector::tq_padded_dim;
 use crate::vtable::TurboQuant;
 use crate::vtable::TurboQuantMetadata;
@@ -54,8 +64,8 @@ impl TQEncode {
     pub fn try_new_array(
         child: ArrayRef,
         config: &TurboQuantConfig,
-        len: usize,
     ) -> VortexResult<ScalarFnArray> {
+        let len = child.len();
         ScalarFnArray::try_new(TQEncode::new(config).erased(), vec![child], len)
     }
 }
@@ -155,3 +165,60 @@ impl ScalarFnVTable for TQEncode {
         false
     }
 }
+
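
`encode_vector` below splits each row into a unit direction and an L2 norm before quantizing, and decode multiplies the reconstructed direction back by the stored norm: roughly code = Q(H · x/‖x‖₂) on encode and x̂ = ‖x‖₂ · H⁻¹ · C[code] on decode. A toy model of that norm split (the `split_l2` helper is illustrative, not this crate's API):

    // A row is stored as (unit direction, norm); reconstruction is direction * norm.
    fn split_l2(row: &[f32]) -> (Vec<f32>, f32) {
        let norm = row.iter().map(|v| v * v).sum::<f32>().sqrt();
        if norm == 0.0 {
            return (vec![0.0; row.len()], 0.0); // zero rows stay zero, like zero-norm handling here
        }
        (row.iter().map(|v| v / norm).collect(), norm)
    }
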
+/// Lossily encode a `Vector` extension array into a `TurboQuant` extension array.
+///
+/// Valid rows are normalized internally before the SORF transform and scalar quantization. The
+/// original row norms are stored explicitly, and original vector nulls are preserved on the
+/// storage struct and both row-aligned child arrays.
+pub(crate) fn encode_vector(
+    input: ArrayRef,
+    config: &TurboQuantConfig,
+    ctx: &mut ExecutionCtx,
+) -> VortexResult<ArrayRef> {
+    let num_vectors = input.len();
+    let vector_metadata = input
+        .dtype()
+        .as_extension_opt()
+        .and_then(|ext_dtype| ext_dtype.metadata_opt::<AnyVector>())
+        .ok_or_else(|| vortex_err!("TurboQuant encode expects a Vector extension array"))?;
+
+    let element_ptype = vector_metadata.element_ptype();
+
+    let dimensions = vector_metadata.dimensions();
+    vortex_ensure!(
+        dimensions >= MIN_DIMENSION,
+        "TurboQuant requires dimension >= {MIN_DIMENSION}, got {dimensions}",
+    );
+    let padded_dim = tq_padded_dim(dimensions)?;
+
+    let vector_validity = input.validity()?;
+
+    let l2_denorm = tq_normalize_as_l2_denorm(input, ctx)?;
+    let normalized = l2_denorm.child_at(0).clone();
+    let norms = l2_denorm.child_at(1).clone();
+
+    let normalized_ext = normalized
+        .as_opt::<AnyVector>()
+        .ok_or_else(|| vortex_err!("normalized TurboQuant input must be a Vector extension"))?;
+    let normalized_fsl: FixedSizeListArray = normalized_ext.storage_array().clone().execute(ctx)?;
+
+    let core = if normalized_fsl.is_empty() {
+        empty_quantization(padded_dim)
+    } else {
+        // SAFETY: `tq_normalize_as_l2_denorm` returned this normalized Vector child.
+        unsafe { turboquant_quantize_core(&normalized_fsl, config, ctx)? }
+    };
+    let codes = build_codes_child(num_vectors, core, vector_validity.clone())?;
+
+    let metadata = TurboQuantMetadata {
+        element_ptype,
+        dimensions,
+        bit_width: config.bit_width(),
+        seed: config.seed(),
+        num_rounds: config.num_rounds(),
+    };
+    let storage = build_storage(norms, codes, num_vectors, vector_validity)?;
+
+    Ok(ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage)?.into_array())
+}
diff --git a/vortex-turboquant/src/tests/encode_decode.rs b/vortex-turboquant/src/tests/encode_decode.rs
index f9b9f3d3569..ed5aab190aa 100644
--- a/vortex-turboquant/src/tests/encode_decode.rs
+++ b/vortex-turboquant/src/tests/encode_decode.rs
@@ -84,7 +84,7 @@ fn encode_decode_empty_vectors() -> VortexResult<()> {
     let input = vector_array::<f32>(128, &[], Validity::NonNullable)?;
 
     let encoded = execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx)?;
-    let decoded = execute_tq_decode(encoded, &TurboQuantConfig::default(), &mut ctx)?;
+    let decoded = execute_tq_decode(encoded, &mut ctx)?;
 
     assert!(decoded.is_empty());
     Ok(())
@@ -185,7 +185,7 @@ fn encode_decode_preserves_nulls_and_zero_norm_rows() -> VortexResult<()> {
     let input = vector_array(128, &values, Validity::from_iter([true, true, false]))?;
 
     let encoded = execute_tq_encode(input, &TurboQuantConfig::default(), &mut ctx)?;
-    let decoded = execute_tq_decode(encoded, &TurboQuantConfig::default(), &mut ctx)?;
+    let decoded = execute_tq_decode(encoded, &mut ctx)?;
 
     let output = vector_values_f32(decoded.clone(), &mut ctx)?;
     let validity = vector_validity(decoded, &mut ctx)?.execute_mask(3, &mut ctx)?;
@@ -211,7 +211,7 @@ fn encode_decode_supports_non_f32_inputs(#[case] ptype: PType) -> VortexResult<(
             .collect::<Vec<_>>();
         let input = vector_array(128, &values, Validity::NonNullable)?;
         let encoded = execute_tq_encode(input, &config, &mut ctx)?;
-        let decoded = execute_tq_decode(encoded, &config, &mut ctx)?;
+        let decoded = execute_tq_decode(encoded, &mut ctx)?;
         let ext: ExtensionArray = decoded.execute(&mut ctx)?;
         assert_eq!(vector_element_ptype(&ext)?, PType::F16);
     }
@@ -221,7 +221,7 @@ fn encode_decode_supports_non_f32_inputs(#[case] ptype: PType) -> VortexResult<(
             .collect::<Vec<_>>();
         let input = vector_array(128, &values, Validity::NonNullable)?;
         let encoded = execute_tq_encode(input, &config, &mut ctx)?;
-        let decoded = execute_tq_decode(encoded, &config, &mut ctx)?;
+        let decoded = execute_tq_decode(encoded, &mut ctx)?;
         let ext: ExtensionArray = decoded.execute(&mut ctx)?;
         assert_eq!(vector_element_ptype(&ext)?, PType::F64);
     }
@@ -239,19 +239,11 @@ fn decode_scales_by_stored_norms() -> VortexResult<()> {
     let config = TurboQuantConfig::try_new(2, 99, 3)?;
 
     let base_values = vector_values_f32(
-        execute_tq_decode(
-            execute_tq_encode(base, &config, &mut ctx)?,
-            &config,
-            &mut ctx,
-        )?,
+        execute_tq_decode(execute_tq_encode(base, &config, &mut ctx)?, &mut ctx)?,
         &mut ctx,
     )?;
     let scaled_values = vector_values_f32(
-        execute_tq_decode(
-            execute_tq_encode(scaled, &config, &mut ctx)?,
-            &config,
-            &mut ctx,
-        )?,
+        execute_tq_decode(execute_tq_encode(scaled, &config, &mut ctx)?, &mut ctx)?,
         &mut ctx,
     )?;
 
diff --git a/vortex-turboquant/src/tests/file.rs b/vortex-turboquant/src/tests/file.rs
index f403431516d..e59b7a95c75 100644
--- a/vortex-turboquant/src/tests/file.rs
+++ b/vortex-turboquant/src/tests/file.rs
@@ -54,7 +54,7 @@ fn file_roundtrip_lazy_decode_scalar_fn_with_initialize_session() -> VortexResul
     let input = f32_vector_array(128, 2, 0.25, Validity::from_iter([true, false]))?;
     let config = TurboQuantConfig::try_new(3, 42, 3)?;
     let encoded = execute_tq_encode(input, &config, &mut ctx)?;
-    let decoded = TQDecode::try_new_array(encoded, &config, 2)?.into_array();
+    let decoded = TQDecode::try_new_array(encoded)?.into_array();
 
     let mut file_bytes = Vec::new();
     VortexWriteOptions::new(session.clone())
diff --git a/vortex-turboquant/src/tests/mod.rs b/vortex-turboquant/src/tests/mod.rs
index c60d52cd957..ffa1db175a7 100644
--- a/vortex-turboquant/src/tests/mod.rs
+++ b/vortex-turboquant/src/tests/mod.rs
@@ -36,7 +36,6 @@ use crate::TQDecode;
 use crate::TQEncode;
 use crate::TurboQuantConfig;
 use crate::initialize;
-use crate::vtable::tq_metadata;
 
 mod encode_decode;
 mod file;
@@ -125,29 +124,18 @@ fn execute_tq_encode(
     config: &TurboQuantConfig,
     ctx: &mut ExecutionCtx,
 ) -> VortexResult<ArrayRef> {
-    let len = input.len();
-    TQEncode::try_new_array(input, config, len)?
+    TQEncode::try_new_array(input, config)?
         .into_array()
         .execute(ctx)
 }
 
-fn execute_tq_decode(
-    input: ArrayRef,
-    config: &TurboQuantConfig,
-    ctx: &mut ExecutionCtx,
-) -> VortexResult<ArrayRef> {
-    let len = input.len();
-    TQDecode::try_new_array(input, config, len)?
-        .into_array()
-        .execute(ctx)
+fn execute_tq_decode(input: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
+    TQDecode::try_new_array(input)?.into_array().execute(ctx)
 }
 
 fn execute_tq_decode_from_metadata(
     input: ArrayRef,
     ctx: &mut ExecutionCtx,
 ) -> VortexResult<ArrayRef> {
-    let metadata = tq_metadata(input.dtype())?;
-    let config = TurboQuantConfig::try_new(metadata.bit_width, metadata.seed, metadata.num_rounds)?;
-
-    execute_tq_decode(input, &config, ctx)
+    execute_tq_decode(input, ctx)
 }
diff --git a/vortex-turboquant/src/tests/parity.rs b/vortex-turboquant/src/tests/parity.rs
index e5fc82fc1fd..2995fb8b296 100644
--- a/vortex-turboquant/src/tests/parity.rs
+++ b/vortex-turboquant/src/tests/parity.rs
@@ -20,7 +20,7 @@ fn encode_decode_matches_old_turboquant_decode() -> VortexResult<()> {
     let config = TurboQuantConfig::try_new(3, 42, 3)?;
 
     let new_encoded = execute_tq_encode(input.clone(), &config, &mut ctx)?;
-    let new_decoded = execute_tq_decode(new_encoded, &config, &mut ctx)?;
+    let new_decoded = execute_tq_decode(new_encoded, &mut ctx)?;
     let old_config = vortex_tensor::encodings::turboquant::TurboQuantConfig {
         bit_width: config.bit_width(),
         seed: config.seed(),
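
Because `TQDecode`'s options are now `EmptyMetadata`, its serialized options are an empty byte vector and deserialization only has to reject non-empty input, as the vtable earlier in this patch does. A standalone model of that contract (hypothetical free functions that mirror `serialize`/`deserialize`, not the real trait):

    fn serialize_empty_options() -> Vec<u8> {
        Vec::new() // EmptyMetadata carries no bytes
    }

    fn deserialize_empty_options(metadata: &[u8]) -> Result<(), String> {
        if !metadata.is_empty() {
            return Err("TQDecode options metadata must be empty".to_string());
        }
        Ok(())
    }
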
diff --git a/vortex-turboquant/src/tests/scalar_fns.rs b/vortex-turboquant/src/tests/scalar_fns.rs
index 3def4b7b991..de125e8a5f1 100644
--- a/vortex-turboquant/src/tests/scalar_fns.rs
+++ b/vortex-turboquant/src/tests/scalar_fns.rs
@@ -1,15 +1,13 @@
 // SPDX-License-Identifier: Apache-2.0
 // SPDX-FileCopyrightText: Copyright the Vortex contributors
 
-use vortex_array::ArrayPlugin;
 use vortex_array::IntoArray;
 use vortex_array::VortexSessionExecute;
-use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayPlugin;
+use vortex_array::extension::EmptyMetadata;
 use vortex_array::scalar_fn::ScalarFnVTable;
 use vortex_array::validity::Validity;
 use vortex_error::VortexResult;
 
-use super::execute_tq_encode;
 use super::f32_vector_array;
 use super::test_session;
 use super::vector_validity;
@@ -20,7 +18,7 @@ use crate::TurboQuantConfig;
 use crate::vtable::tq_metadata;
 
 #[test]
-fn scalar_fn_ids_and_config_options_roundtrip() -> VortexResult<()> {
+fn scalar_fn_ids_and_options_roundtrip() -> VortexResult<()> {
     let session = test_session();
     let config = TurboQuantConfig::try_new(4, 7, 2)?;
 
@@ -28,10 +26,14 @@ fn scalar_fn_ids_and_config_options_roundtrip() -> VortexResult<()> {
     assert_eq!(TQDecode.id().as_ref(), "vortex.turboquant.decode");
 
     let encode_metadata = TQEncode.serialize(&config)?.unwrap();
-    let decode_metadata = TQDecode.serialize(&config)?.unwrap();
+    let decode_metadata = TQDecode.serialize(&EmptyMetadata)?.unwrap();
 
     assert_eq!(TQEncode.deserialize(&encode_metadata, &session)?, config);
-    assert_eq!(TQDecode.deserialize(&decode_metadata, &session)?, config);
+    assert!(decode_metadata.is_empty());
+    assert_eq!(
+        TQDecode.deserialize(&decode_metadata, &session)?,
+        EmptyMetadata
+    );
 
     Ok(())
 }
@@ -42,14 +44,14 @@ fn scalar_fn_arrays_encode_and_decode_vectors() -> VortexResult<()> {
     let input = f32_vector_array(128, 2, 0.25, Validity::from_iter([true, false]))?;
     let config = TurboQuantConfig::try_new(3, 42, 3)?;
 
-    let encoded_lazy = TQEncode::try_new_array(input, &config, 2)?;
+    let encoded_lazy = TQEncode::try_new_array(input, &config)?;
     let encoded_metadata = tq_metadata(encoded_lazy.dtype())?;
     assert_eq!(encoded_metadata.dimensions, 128);
     assert_eq!(encoded_metadata.bit_width, config.bit_width());
     assert!(encoded_lazy.dtype().as_extension().is::<TurboQuant>());
 
     let encoded = encoded_lazy.into_array().execute(&mut ctx)?;
-    let decoded_lazy = TQDecode::try_new_array(encoded, &config, 2)?;
+    let decoded_lazy = TQDecode::try_new_array(encoded)?;
     let decoded = decoded_lazy.into_array().execute(&mut ctx)?;
 
     let validity = vector_validity(decoded, &mut ctx)?.execute_mask(2, &mut ctx)?;
@@ -57,37 +59,3 @@ fn scalar_fn_arrays_encode_and_decode_vectors() -> VortexResult<()> {
     assert!(validity.value(0));
     assert!(!validity.value(1));
     Ok(())
 }
-
-#[test]
-fn scalar_fn_array_metadata_stores_only_config() -> VortexResult<()> {
-    let session = test_session();
-    let mut ctx = session.create_execution_ctx();
-    let input = f32_vector_array(128, 2, 0.25, Validity::NonNullable)?;
-    let config = TurboQuantConfig::try_new(3, 42, 3)?;
-    let config_metadata = TQDecode.serialize(&config)?.unwrap();
-
-    let encoded = execute_tq_encode(input, &config, &mut ctx)?;
-    let decoded_lazy = TQDecode::try_new_array(encoded, &config, 2)?.into_array();
-    let decode_plugin = ScalarFnArrayPlugin::new(TQDecode);
-    assert_eq!(
-        decode_plugin.serialize(&decoded_lazy, &session)?.unwrap(),
-        config_metadata
-    );
-    Ok(())
-}
-
-#[test]
-fn decode_rejects_config_that_disagrees_with_turboquant_child() -> VortexResult<()> {
-    let session = test_session();
-    let mut ctx = session.create_execution_ctx();
-    let input = f32_vector_array(128, 1, 0.25, Validity::NonNullable)?;
-    let config = TurboQuantConfig::try_new(3, 42, 3)?;
-    let different_config = TurboQuantConfig::try_new(4, 42, 3)?;
-
-    let encoded = TQEncode::try_new_array(input, &config, 1)?
-        .into_array()
-        .execute(&mut ctx)?;
-
-    assert!(TQDecode::try_new_array(encoded, &different_config, 1).is_err());
-    Ok(())
-}
diff --git a/vortex-turboquant/src/vector/decode.rs b/vortex-turboquant/src/vector/decode.rs
deleted file mode 100644
index 1caf441725b..00000000000
--- a/vortex-turboquant/src/vector/decode.rs
+++ /dev/null
@@ -1,198 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-// SPDX-FileCopyrightText: Copyright the Vortex contributors
-
-//! TurboQuant decoding (dequantization) logic.
-//!
-//! Note that because TurboQuant is a lossy compression scheme, decoding does not roundtrip with
-//! the initial encoding.
-
-use num_traits::Float;
-use num_traits::FromPrimitive;
-use vortex_array::ArrayRef;
-use vortex_array::ExecutionCtx;
-use vortex_array::IntoArray;
-use vortex_array::arrays::FixedSizeListArray;
-use vortex_array::arrays::PrimitiveArray;
-use vortex_array::dtype::NativePType;
-use vortex_array::dtype::Nullability;
-use vortex_array::match_each_float_ptype;
-use vortex_array::validity::Validity;
-use vortex_buffer::BufferMut;
-use vortex_error::VortexExpect;
-use vortex_error::VortexResult;
-use vortex_error::vortex_err;
-use vortex_mask::Mask;
-use vortex_tensor::vector::Vector;
-
-use super::storage::parse_storage;
-use super::tq_padded_dim;
-use crate::centroids::compute_or_get_centroids;
-use crate::sorf::SorfMatrix;
-use crate::vtable::TurboQuantMetadata;
-
-/// Decode a `TurboQuant` extension array back into a `Vector` extension array.
-///
-/// The decoded directions are inverse-transformed, truncated to the original dimension, and
-/// multiplied by the stored row norms. The conversion is lossy and does not roundtrip with
-/// [`TQEncode`](crate::TQEncode).
-pub(crate) fn decode_vector(input: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<ArrayRef> {
-    // Get the input TurboQuant array into a form that is easier to work with.
-    let parsed = parse_storage(input, ctx)?;
-    let metadata = parsed.metadata;
-    if parsed.len == 0 {
-        return build_empty_vector(metadata, parsed.vector_validity);
-    }
-
-    let padded_dim = tq_padded_dim(metadata.dimensions)?;
-    let transform = SorfMatrix::try_new(padded_dim, metadata.num_rounds as usize, metadata.seed)?;
-    let padded_dim = u32::try_from(padded_dim)
-        .map_err(|_| vortex_err!("TurboQuant padded dimension does not fit u32"))?;
-
-    // We retrieve the centroids on read because they are mostly known statically for the given
-    // settings.
-    let centroids = compute_or_get_centroids(padded_dim, metadata.bit_width)?;
-
-    match_each_float_ptype!(metadata.element_ptype, |T| {
-        decode_typed::<T>(
-            DecodeInputs {
-                metadata: &metadata,
-                sorf_matrix: &transform,
-                centroids: &centroids,
-                norms: &parsed.norms,
-                codes: &parsed.codes,
-            },
-            parsed.vector_validity,
-            parsed.len,
-            ctx,
-        )
-    })
-}
-
-fn build_empty_vector(
-    metadata: TurboQuantMetadata,
-    vector_validity: Validity,
-) -> VortexResult<ArrayRef> {
-    match_each_float_ptype!(metadata.element_ptype, |T| {
-        let elements = PrimitiveArray::empty::<T>(Nullability::NonNullable);
-        let fsl = FixedSizeListArray::try_new(
-            elements.into_array(),
-            metadata.dimensions,
-            vector_validity,
-            0,
-        )?;
-
-        Vector::try_new_vector_array(fsl.into_array())
-    })
-}
-
-struct DecodeInputs<'a> {
-    metadata: &'a TurboQuantMetadata,
-    sorf_matrix: &'a SorfMatrix,
-    centroids: &'a [f32],
-    norms: &'a PrimitiveArray,
-    codes: &'a PrimitiveArray,
-}
-
-fn decode_typed<T>(
-    decode: DecodeInputs<'_>,
-    vector_validity: Validity,
-    num_vectors: usize,
-    ctx: &mut ExecutionCtx,
-) -> VortexResult<ArrayRef>
-where
-    T: NativePType + Float + FromPrimitive,
-{
-    let metadata = decode.metadata;
-    let dimensions = usize::try_from(metadata.dimensions)
-        .vortex_expect("dimensions stays representable as usize");
-    let padded_dim = decode.sorf_matrix.padded_dim();
-    let centroids = decode.centroids;
-    let norms = decode.norms.as_slice::<T>();
-    let codes = decode.codes.as_slice::<u8>();
-    let mask = vector_validity.execute_mask(num_vectors, ctx)?;
-
-    let output_len = num_vectors
-        .checked_mul(dimensions)
-        .ok_or_else(|| vortex_err!("TurboQuant decoded vector length overflow"))?;
-    let mut output = BufferMut::<T>::with_capacity(output_len);
-
-    let mut decoded = vec![0.0f32; padded_dim];
-    let mut inverse = vec![0.0f32; padded_dim];
-
-    // Decode a single row: gather codes through the centroid table, apply the inverse SORF
-    // transform, then denormalize and push `dimensions` elements into `output`. Captures the
-    // read-only inputs and the scratch buffers so each call site only needs to pass `output`
-    // and the row index.
-    let mut decode_row = |output: &mut BufferMut<T>, i: usize| {
-        let code_row = &codes[i * padded_dim..][..padded_dim];
-
-        // Gather the values according to the codes.
-        for (dst, &code) in decoded.iter_mut().zip(code_row.iter()) {
-            *dst = *centroids
-                .get(usize::from(code))
-                .vortex_expect("TurboQuant code exceeds centroid count");
-        }
-
-        decode.sorf_matrix.inverse_transform(&decoded, &mut inverse);
-
-        let norm = norms[i];
-        for &value in inverse.iter().take(dimensions) {
-            // `T::from_f32` is infallible for the supported float ptypes (`f16`, `f32`,
-            // `f64`): values outside `f16` range saturate to `±inf` rather than returning
-            // `None`.
-            let value = T::from_f32(value)
-                .vortex_expect("from_f32 is infallible for supported float types");
-
-            // SAFETY: total pushes across all match arms equal `output_len`.
-            unsafe { output.push_unchecked(value * norm) };
-        }
-    };
-
-    // The total number of pushes is always exactly `num_vectors * dimensions == output_len`
-    // across every arm below, which is the invariant the per-row `unsafe` blocks rely on.
-    match &mask {
-        Mask::AllFalse(_) => {
-            // Every row is invalid: bulk-fill the output with zero placeholders.
-            //
-            // SAFETY: `output` was allocated with capacity `output_len`, and this push writes
-            // exactly `output_len` zero placeholders.
-            unsafe { output.push_n_unchecked(T::zero(), output_len) };
-        }
-        Mask::AllTrue(_) => {
-            for i in 0..num_vectors {
-                decode_row(&mut output, i);
-            }
-        }
-        Mask::Values(values_mask) => {
-            let mut cursor = 0;
-
-            for &(start, end) in values_mask.slices() {
-                if start > cursor {
-                    // SAFETY: total pushes across all arms equal `output_len`.
-                    unsafe { output.push_n_unchecked(T::zero(), (start - cursor) * dimensions) };
-                }
-
-                for i in start..end {
-                    decode_row(&mut output, i);
-                }
-
-                cursor = end;
-            }
-
-            if cursor < num_vectors {
-                // SAFETY: total pushes across all arms equal `output_len`.
-                unsafe { output.push_n_unchecked(T::zero(), (num_vectors - cursor) * dimensions) };
-            }
-        }
-    }
-
-    let elements = PrimitiveArray::new::<T>(output.freeze(), Validity::NonNullable);
-    let fsl = FixedSizeListArray::try_new(
-        elements.into_array(),
-        metadata.dimensions,
-        vector_validity,
-        num_vectors,
-    )?;
-
-    Vector::try_new_vector_array(fsl.into_array())
-}
diff --git a/vortex-turboquant/src/vector/encode.rs b/vortex-turboquant/src/vector/encode.rs
deleted file mode 100644
index 204c74aa3f7..00000000000
--- a/vortex-turboquant/src/vector/encode.rs
+++ /dev/null
@@ -1,94 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-// SPDX-FileCopyrightText: Copyright the Vortex contributors
-
-//! TurboQuant encoding (quantization) logic.
-//!
-//! The input to [`encode_vector()`] must be a [`Vector`](vortex_tensor::vector::Vector) extension
-//! array. [`encode_vector()`] computes original row norms, normalizes valid rows internally via
-//! [`tq_normalize_as_l2_denorm()`], quantizes the normalized child, and stores row-aligned norms and
-//! codes in the TurboQuant extension storage.
-
-use vortex_array::ArrayRef;
-use vortex_array::ExecutionCtx;
-use vortex_array::IntoArray;
-use vortex_array::arrays::Extension;
-use vortex_array::arrays::ExtensionArray;
-use vortex_array::arrays::FixedSizeListArray;
-use vortex_array::arrays::extension::ExtensionArrayExt;
-use vortex_array::arrays::scalar_fn::ScalarFnArrayExt;
-use vortex_error::VortexResult;
-use vortex_error::vortex_ensure;
-use vortex_error::vortex_err;
-use vortex_tensor::vector::AnyVector;
-
-use super::normalize::tq_normalize_as_l2_denorm;
-use super::quantize::empty_quantization;
-use super::quantize::turboquant_quantize_core;
-use super::storage::build_codes_child;
-use super::storage::build_storage;
-use super::tq_padded_dim;
-use crate::TurboQuantConfig;
-use crate::config::MIN_DIMENSION;
-use crate::vtable::TurboQuant;
-use crate::vtable::TurboQuantMetadata;
-
-/// Lossily encode a `Vector` extension array into a `TurboQuant` extension array.
-///
-/// Valid rows are normalized internally before SORF transform and scalar quantization. The original
-/// row norms are stored explicitly, and original vector nulls are preserved on the storage struct
-/// and both row-aligned child arrays.
-pub(crate) fn encode_vector(
-    input: ArrayRef,
-    config: &TurboQuantConfig,
-    ctx: &mut ExecutionCtx,
-) -> VortexResult<ArrayRef> {
-    let num_vectors = input.len();
-    let vector_metadata = input
-        .dtype()
-        .as_extension_opt()
-        .and_then(|ext_dtype| ext_dtype.metadata_opt::<AnyVector>())
-        .ok_or_else(|| vortex_err!("TurboQuant encode expects a Vector extension array"))?;
-
-    let element_ptype = vector_metadata.element_ptype();
-
-    let dimensions = vector_metadata.dimensions();
-    vortex_ensure!(
-        dimensions >= MIN_DIMENSION,
-        "TurboQuant requires dimension >= {MIN_DIMENSION}, got {dimensions}",
-    );
-    let padded_dim = tq_padded_dim(dimensions)?;
-
-    let vector_validity = input.validity()?;
-
-    // We must normalize the vectors in order to apply the transform during quantization.
-    // NB: The 2 child arrays share the same validity with `input`.
-    let l2_denorm = tq_normalize_as_l2_denorm(input, ctx)?;
-    let normalized = l2_denorm.child_at(0).clone();
-    let norms = l2_denorm.child_at(1).clone();
-
-    let normalized_ext = normalized
-        .as_opt::<AnyVector>()
-        .ok_or_else(|| vortex_err!("normalized TurboQuant input must be a Vector extension"))?;
-    let normalized_fsl: FixedSizeListArray = normalized_ext.storage_array().clone().execute(ctx)?;
-
-    let core = if normalized_fsl.is_empty() {
-        empty_quantization(padded_dim)
-    } else {
-        // SAFETY: `tq_normalize_as_l2_denorm` returned this normalized Vector child.
-        unsafe { turboquant_quantize_core(&normalized_fsl, config, ctx)? }
-    };
-    let codes = build_codes_child(num_vectors, core, vector_validity.clone())?;
-
-    // Now that we have the codes into the centroid codebook and the norms, we can build the
-    // TurboQuant extension array.
-    let metadata = TurboQuantMetadata {
-        element_ptype,
-        dimensions,
-        bit_width: config.bit_width(),
-        seed: config.seed(),
-        num_rounds: config.num_rounds(),
-    };
-    let storage = build_storage(norms, codes, num_vectors, vector_validity)?;
-
-    Ok(ExtensionArray::try_new_from_vtable(TurboQuant, metadata, storage)?.into_array())
-}
diff --git a/vortex-turboquant/src/vector/mod.rs b/vortex-turboquant/src/vector/mod.rs
index 7015fded568..58c4271a398 100644
--- a/vortex-turboquant/src/vector/mod.rs
+++ b/vortex-turboquant/src/vector/mod.rs
@@ -1,11 +1,8 @@
 // SPDX-License-Identifier: Apache-2.0
 // SPDX-FileCopyrightText: Copyright the Vortex contributors
 
-mod quantize;
-
-pub(crate) mod decode;
-pub(crate) mod encode;
 pub(crate) mod normalize;
+pub(crate) mod quantize;
 pub(crate) mod storage;
 
 use vortex_error::VortexResult;
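
Since TurboQuant is lossy, end-to-end checks should compare decoded rows directionally rather than bit-for-bit. A hedged sketch of such a check (the `cosine` helper and the 0.9 threshold are illustrative; the tests in this patch use `vector_values_f32` and structural assertions instead):

    // Cosine similarity between an original row and its encode/decode round-trip.
    fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
        let na = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let nb = b.iter().map(|x| x * x).sum::<f32>().sqrt();
        dot / (na * nb)
    }
    // assert!(cosine(&original_row, &decoded_row) > 0.9); // tolerance depends on bit_width
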