diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 375c3307e..34a8a5aee 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,7 +21,7 @@ env: RUST_BACKTRACE: 1 # The features we want to explicitly test. For example, the `flatbuffers-build` feature # of `diskann-quantization` requires additional setup and so must not be included by default. - DISKANN_FEATURES: "virtual_storage,bf_tree,spherical-quantization,product-quantization,tracing,experimental_diversity_search,disk-index,flatbuffers,linalg,codegen" + DISKANN_FEATURES: "virtual_storage,bf_tree,spherical-quantization,product-quantization,tracing,experimental_diversity_search,disk-index,flatbuffers,linalg,codegen,document-index" # Use the Rust version specified in rust-toolchain.toml rust_stable: "1.92" diff --git a/Cargo.lock b/Cargo.lock index 9bce846e3..6373d32bd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -799,6 +799,7 @@ dependencies = [ "serde_json", "tempfile", "thiserror 2.0.17", + "tokio", "tracing", ] @@ -918,6 +919,7 @@ dependencies = [ "rayon", "rstest", "serde", + "serde_json", "tracing", "tracing-subscriber", "vfs", diff --git a/diskann-benchmark/Cargo.toml b/diskann-benchmark/Cargo.toml index bebaf4b8e..39d64b1bb 100644 --- a/diskann-benchmark/Cargo.toml +++ b/diskann-benchmark/Cargo.toml @@ -63,6 +63,9 @@ scalar-quantization = [] # Enable minmax-quantization based algorithms minmax-quantization = [] +# Enable Document Index benchmarks +document-index = [] + # Enable Disk Index benchmarks disk-index = [ "diskann-disk/perf_test", diff --git a/diskann-benchmark/example/document-filter.json b/diskann-benchmark/example/document-filter.json new file mode 100644 index 000000000..0bce1572d --- /dev/null +++ b/diskann-benchmark/example/document-filter.json @@ -0,0 +1,39 @@ +{ + "search_directories": [ + "test_data/disk_index_search" + ], + "jobs": [ + { + "type": "document-index-build", + "content": { + "build": { + "data_type": "float32", + "data": 
"disk_index_siftsmall_learn_256pts_data.fbin", + "data_labels": "data.256.label.jsonl", + "distance": "squared_l2", + "max_degree": 32, + "l_build": 50, + "alpha": 1.2, + "num_threads": 4 + }, + "search": { + "queries": "disk_index_sample_query_10pts.fbin", + "query_predicates": "query.10.label.jsonl", + "groundtruth": "disk_index_10pts_idx_uint32_truth_search_filter_res.bin", + "beta": 0.5, + "reps": 5, + "num_threads": [ + 1 + ], + "runs": [ + { + "search_n": 20, + "search_l": [20, 30, 40], + "recall_k": 10 + } + ] + } + } + } + ] +} \ No newline at end of file diff --git a/diskann-benchmark/src/backend/document_index/benchmark.rs b/diskann-benchmark/src/backend/document_index/benchmark.rs new file mode 100644 index 000000000..8f3d56175 --- /dev/null +++ b/diskann-benchmark/src/backend/document_index/benchmark.rs @@ -0,0 +1,845 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Benchmark for DocumentInsertStrategy which allows inserting Documents +//! (vector + attributes) into a DiskANN index built with DocumentProvider. +//! Also benchmarks filtered search using InlineBetaStrategy. 
+ +use std::io::Write; +use std::num::NonZeroUsize; +use std::path::Path; +use std::sync::Arc; + +use anyhow::Result; +use diskann::{ + graph::{ + glue, search::Knn, search_output_buffer, DiskANNIndex, SearchOutputBuffer, + StartPointStrategy, + }, + provider::DefaultContext, + ANNError, ANNErrorKind, +}; +use diskann_benchmark_core::{ + build::{self, Build, Parallelism}, + recall, search as search_api, tokio, +}; +use diskann_benchmark_runner::{ + dispatcher::{DispatchRule, FailureScore, MatchScore}, + output::Output, + registry::Benchmarks, + utils::{datatype, fmt, percentiles, MicroSeconds}, + Any, +}; +use diskann_label_filter::{ + attribute::{Attribute, AttributeValue}, + document::Document, + encoded_attribute_provider::{ + document_insert_strategy::DocumentInsertStrategy, document_provider::DocumentProvider, + roaring_attribute_store::RoaringAttributeStore, + }, + inline_beta_search::inline_beta_filter::InlineBetaStrategy, + query::FilteredQuery, + read_and_parse_queries, read_baselabels, + traits::attribute_store::AttributeStore, + ASTExpr, +}; + +use diskann_providers::model::graph::provider::async_::{ + common::{self, NoStore, TableBasedDeletes}, + inmem::{CreateFullPrecision, DefaultProvider, DefaultProviderParameters, SetStartPoints}, +}; +use diskann_utils::views::Matrix; +use diskann_utils::views::MatrixView; +use diskann_utils::{future::AsyncFriendly, sampling::medoid::ComputeMedoid}; +use diskann_vector::distance::SquaredL2; +use diskann_vector::PureDistanceFunction; +use serde::Serialize; + +use crate::{ + backend::index::build::ProgressMeter, + inputs::document_index::DocumentIndexBuild, + utils::{ + datafiles::{self, BinFile}, + recall::RecallMetrics, + }, +}; + +/// Register the document index benchmarks. +pub(crate) fn register_benchmarks(benchmarks: &mut Benchmarks) { + benchmarks.register::>( + "document-index-build-f32", + |job, _checkpoint, out| { + let stats = job.run(out)?; + Ok(serde_json::to_value(stats)?) 
+ }, + ); +} + +/// Document index benchmark job. +pub(super) struct DocumentIndexJob<'a, T> { + input: &'a DocumentIndexBuild, + _type: std::marker::PhantomData, +} + +impl<'a, T> DocumentIndexJob<'a, T> { + fn new(input: &'a DocumentIndexBuild) -> Self { + Self { + input, + _type: std::marker::PhantomData, + } + } +} + +impl diskann_benchmark_runner::dispatcher::Map for DocumentIndexJob<'static, T> { + type Type<'a> = DocumentIndexJob<'a, T>; +} + +// Dispatch from the concrete input type +impl<'a, T> DispatchRule<&'a DocumentIndexBuild> for DocumentIndexJob<'a, T> +where + datatype::Type: DispatchRule, +{ + type Error = std::convert::Infallible; + + fn try_match(_from: &&'a DocumentIndexBuild) -> Result { + datatype::Type::::try_match(&_from.build.data_type) + } + + fn convert(from: &'a DocumentIndexBuild) -> Result { + Ok(DocumentIndexJob::new(from)) + } +} + +// Central dispatch mapping from Any +impl<'a, T> DispatchRule<&'a Any> for DocumentIndexJob<'a, T> +where + datatype::Type: DispatchRule, +{ + type Error = anyhow::Error; + + fn try_match(from: &&'a Any) -> Result { + from.try_match::() + } + + fn convert(from: &'a Any) -> Result { + from.convert::() + } + + fn description(f: &mut std::fmt::Formatter, from: Option<&&'a Any>) -> std::fmt::Result { + Any::description::(f, from, DocumentIndexBuild::tag()) + } +} +/// Convert a HashMap to Vec +fn hashmap_to_attributes(map: std::collections::HashMap) -> Vec { + map.into_iter() + .map(|(k, v)| Attribute::from_value(k, v)) + .collect() +} + +fn find_medoid_index(x: MatrixView<'_, T>, y: &[T]) -> Option +where + for<'a> diskann_vector::distance::SquaredL2: PureDistanceFunction<&'a [T], &'a [T], f32>, +{ + let mut min_dist = f32::INFINITY; + let mut min_ind = x.nrows(); + for (i, row) in x.row_iter().enumerate() { + let dist = SquaredL2::evaluate(row, y); + if dist < min_dist { + min_dist = dist; + min_ind = i; + } + } + + // No closest neighbor found. 
+ if min_ind == x.nrows() { + None + } else { + Some(min_ind) + } +} + +/// Compute the index of the row closest to the medoid (centroid) of the data. +fn compute_medoid_index(data: &Matrix) -> anyhow::Result +where + T: bytemuck::Pod + Copy + 'static + ComputeMedoid, + for<'a> diskann_vector::distance::SquaredL2: PureDistanceFunction<&'a [T], &'a [T], f32>, +{ + let dim = data.ncols(); + if dim == 0 || data.nrows() == 0 { + return Ok(0); + } + + // returns row closes to centroid. + let medoid = T::compute_medoid(data.as_view()); + + find_medoid_index(data.as_view(), medoid.as_slice()) + .ok_or_else(|| anyhow::anyhow!("Failed to find medoid index: no closest row found")) +} + +impl<'a, T> DocumentIndexJob<'a, T> { + fn run(self, mut output: &mut dyn Output) -> Result + where + T: diskann::utils::VectorRepr + + diskann::graph::SampleableForStart + + diskann_utils::sampling::WithApproximateNorm + + 'static, + for<'b> diskann_vector::distance::SquaredL2: PureDistanceFunction<&'b [T], &'b [T]>, + { + let build = &self.input.build; + + // 1. Load vectors from data file in the original data type + writeln!(output, "Loading vectors ({})...", build.data_type)?; + let timer = std::time::Instant::now(); + let data_path: &Path = build.data.as_ref(); + writeln!(output, "Data path is: {}", data_path.to_string_lossy())?; + let data: Matrix = datafiles::load_dataset(BinFile(data_path))?; + let data_load_time: MicroSeconds = timer.elapsed().into(); + let num_vectors = data.nrows(); + let dim = data.ncols(); + writeln!( + output, + " Loaded {} vectors of dimension {}", + num_vectors, dim + )?; + + // 2. 
Load and parse labels from the data_labels file + writeln!(output, "Loading labels...")?; + let timer = std::time::Instant::now(); + let label_path: &Path = build.data_labels.as_ref(); + let labels = read_baselabels(label_path)?; + let label_load_time: MicroSeconds = timer.elapsed().into(); + let label_count = labels.len(); + writeln!(output, " Loaded {} label documents", label_count)?; + + if num_vectors != label_count { + return Err(anyhow::anyhow!( + "Mismatch: {} vectors but {} label documents", + num_vectors, + label_count + )); + } + + // Convert labels to attribute vectors + let attributes: Vec> = labels + .into_iter() + .map(|doc| hashmap_to_attributes(doc.flatten_metadata_with_separator(""))) + .collect(); + + // 3. Create the index configuration + let config = build.build_config()?; + + // 4. Create the data provider directly + writeln!(output, "Creating index...")?; + let params = DefaultProviderParameters { + max_points: num_vectors, + frozen_points: diskann::utils::ONE, + metric: build.distance.into(), + dim, + prefetch_lookahead: None, + prefetch_cache_line_level: None, + max_degree: build.max_degree as u32, + }; + + // Create the underlying provider + let fp_precursor = CreateFullPrecision::::new(dim, None); + let inner_provider = + DefaultProvider::new_empty(params, fp_precursor, NoStore, TableBasedDeletes)?; + + // Set start points using medoid strategy + let start_points = StartPointStrategy::Medoid + .compute(data.as_view()) + .map_err(|e| anyhow::anyhow!("Failed to compute start points: {}", e))?; + inner_provider.set_start_points(start_points.row_iter())?; + + // 5. 
Create DocumentProvider wrapping the inner provider + let attribute_store = RoaringAttributeStore::::new(); + + // Store attributes for the start point (medoid) + // Start points are stored at indices num_vectors..num_vectors+frozen_points + let medoid_idx = compute_medoid_index(&data)?; + let start_point_id = num_vectors as u32; // Start points begin at max_points + let medoid_attrs = attributes.get(medoid_idx).cloned().unwrap_or_default(); + attribute_store.set_element(&start_point_id, &medoid_attrs)?; + + let doc_provider = DocumentProvider::new(inner_provider, attribute_store); + + // Create a new DiskANNIndex with DocumentProvider + let doc_index = Arc::new(DiskANNIndex::new(config, doc_provider, None)); + + // 6. Build index by inserting vectors and attributes (parallel) + writeln!( + output, + "Building index with {} vectors using {} threads...", + num_vectors, build.num_threads + )?; + let timer = std::time::Instant::now(); + + let rt = tokio::runtime(build.num_threads)?; + let data_arc = Arc::new(data); + let attributes_arc = Arc::new(attributes); + + let builder = DocumentIndexBuilder::new( + doc_index.clone(), + data_arc.clone(), + attributes_arc.clone(), + DocumentInsertStrategy::new(common::FullPrecision), + ); + let num_tasks = NonZeroUsize::new(build.num_threads).unwrap_or(diskann::utils::ONE); + let parallelism = Parallelism::dynamic(diskann::utils::ONE, num_tasks); + let build_results = + build::build_tracked(builder, parallelism, &rt, Some(&ProgressMeter::new(output)))?; + let insert_latencies: Vec = build_results + .take_output() + .into_iter() + .map(|r| r.latency) + .collect(); + + let build_time: MicroSeconds = timer.elapsed().into(); + writeln!(output, " Index built in {} s", build_time.as_seconds())?; + + let insert_percentiles = percentiles::compute_percentiles(&mut insert_latencies.clone())?; + // ===================== + // Search Phase + // ===================== + let search_input = &self.input.search; + + // Load query vectors (same type 
as data for compatible distance computation) + writeln!(output, "\nLoading query vectors...")?; + let query_path: &Path = search_input.queries.as_ref(); + let queries: Matrix = datafiles::load_dataset(BinFile(query_path))?; + let num_queries = queries.nrows(); + writeln!(output, " Loaded {} queries", num_queries)?; + + // Load and parse query predicates + writeln!(output, "Loading query predicates...")?; + let predicate_path: &Path = search_input.query_predicates.as_ref(); + let parsed_predicates = read_and_parse_queries(predicate_path)?; + writeln!(output, " Loaded {} predicates", parsed_predicates.len())?; + + if num_queries != parsed_predicates.len() { + return Err(anyhow::anyhow!( + "Mismatch: {} queries but {} predicates", + num_queries, + parsed_predicates.len() + )); + } + + // Load groundtruth + writeln!(output, "Loading groundtruth...")?; + let gt_path: &Path = search_input.groundtruth.as_ref(); + let groundtruth: Vec> = datafiles::load_range_groundtruth(BinFile(gt_path))?; + writeln!( + output, + " Loaded groundtruth with {} rows", + groundtruth.len() + )?; + + // Run filtered searches + writeln!( + output, + "\nRunning filtered searches (beta={})...", + search_input.beta + )?; + let mut search_results = Vec::new(); + + for num_threads in &search_input.num_threads { + for run in &search_input.runs { + for &search_l in &run.search_l { + writeln!( + output, + " threads={}, search_n={}, search_l={}...", + num_threads, run.search_n, search_l + )?; + + let search_run_result = run_filtered_search( + &doc_index, + &queries, + &parsed_predicates, + &groundtruth, + search_input.beta, + *num_threads, + run.search_n, + search_l, + run.recall_k, + search_input.reps, + )?; + + writeln!( + output, + " recall={:.4}, mean_qps={:.1}", + search_run_result.recall.average, + if search_run_result.qps.is_empty() { + 0.0 + } else { + search_run_result.qps.iter().sum::() + / search_run_result.qps.len() as f64 + } + )?; + + search_results.push(search_run_result); + } + } + } + + 
let stats = DocumentIndexStats { + num_vectors, + dim, + label_count, + data_load_time, + label_load_time, + build_time, + insert_latencies: insert_percentiles, + search: search_results, + }; + + writeln!(output, "\n{}", stats)?; + Ok(stats) + } +} + +/// Per-query output from [`FilteredSearcher::search`]. +struct FilteredSearchOutput { + distances: Vec, + comparisons: u32, + hops: u32, +} + +/// Implements [`search_api::Search`] for parallelized inline-beta filtered search. +/// +/// Each query is paired with a predicate at the same index in `predicates`. The +/// [`InlineBetaStrategy`] is used with a [`FilteredQuery`] containing the raw vector +/// and the predicate's [`ASTExpr`]. +struct FilteredSearcher +where + DP: diskann::provider::DataProvider, +{ + index: Arc>, + queries: Arc>, + predicates: Arc>, + beta: f32, +} + +impl search_api::Search for FilteredSearcher +where + DP: diskann::provider::DataProvider< + Context = DefaultContext, + ExternalId = u32, + InternalId = u32, + > + Send + + Sync + + 'static, + for<'a> InlineBetaStrategy: glue::SearchStrategy> + + glue::DefaultPostProcessor, u32>, + T: bytemuck::Pod + Copy + Send + Sync + 'static, +{ + type Id = DP::ExternalId; + type Parameters = Knn; + type Output = FilteredSearchOutput; + + fn num_queries(&self) -> usize { + self.queries.nrows() + } + + fn id_count(&self, parameters: &Knn) -> search_api::IdCount { + search_api::IdCount::Fixed(parameters.k_value()) + } + + async fn search( + &self, + parameters: &Knn, + buffer: &mut O, + index: usize, + ) -> diskann::ANNResult + where + O: diskann::graph::SearchOutputBuffer + Send, + { + let ctx = DefaultContext; + let query_vec = self.queries.row(index); + let (_, ref ast_expr) = self.predicates[index]; + let strategy = InlineBetaStrategy::new(self.beta, common::FullPrecision); + let filtered_query = FilteredQuery::new(query_vec, ast_expr.clone()); + + // Use a concrete IdDistance scratch buffer so that both the IDs and distances + // are captured. 
Afterwards, the valid IDs are forwarded into the framework buffer. + let k = parameters.k_value().get(); + let mut ids = vec![0u32; k]; + let mut distances = vec![0.0f32; k]; + let mut scratch = search_output_buffer::IdDistance::new(&mut ids, &mut distances); + + let stats = &self + .index + .search(*parameters, &strategy, &ctx, &filtered_query, &mut scratch) + .await?; + + let count = scratch.current_len(); + for (&id, &dist) in std::iter::zip(&ids[..count], &distances[..count]) { + if buffer.push(id, dist).is_full() { + break; + } + } + + Ok(FilteredSearchOutput { + distances: distances[..count].to_vec(), + comparisons: stats.cmps, + hops: stats.hops, + }) + } +} + +/// Aggregates per-rep [`search_api::SearchResults`] into a [`SearchRunStats`]. +struct FilteredSearchAggregator<'a> { + groundtruth: &'a Vec>, + predicates: &'a [(usize, ASTExpr)], + recall_k: usize, +} + +impl search_api::Aggregate for FilteredSearchAggregator<'_> { + type Output = SearchRunStats; + + fn aggregate( + &mut self, + run: search_api::Run, + results: Vec>, + ) -> anyhow::Result { + let parameters = run.parameters(); + let search_n = parameters.k_value().get(); + let num_queries = results.first().map(|r| r.len()).unwrap_or(0); + + // Recall from first rep only. + let recall_metrics: RecallMetrics = match results.first() { + Some(first) => (&recall::knn( + self.groundtruth, + None, + first.ids().as_rows(), + self.recall_k, + search_n, + true, + )?) + .into(), + None => anyhow::bail!("no search results"), + }; + + // Per-query details from first rep (only queries with recall < 1). 
+ let first = results.first().unwrap(); + let per_query_details: Vec = (0..num_queries) + .filter_map(|query_idx| { + let result_ids: Vec = first.ids().as_rows().row(query_idx).to_vec(); + let result_distances: Vec = first + .output() + .get(query_idx) + .map(|o| o.distances.clone()) + .unwrap_or_default(); + let gt_ids: Vec = self + .groundtruth + .get(query_idx) + .map(|gt| gt.iter().take(20).copied().collect()) + .unwrap_or_default(); + + let result_set: std::collections::HashSet = + result_ids.iter().copied().collect(); + let gt_set: std::collections::HashSet = + gt_ids.iter().take(self.recall_k).copied().collect(); + let intersection = result_set.intersection(>_set).count(); + let per_query_recall = if gt_set.is_empty() { + 1.0 + } else { + intersection as f64 / gt_set.len() as f64 + }; + + if per_query_recall >= 1.0 { + return None; + } + + let (_, ref ast_expr) = self.predicates[query_idx]; + Some(PerQueryDetails { + query_id: query_idx, + filter: format!("{:?}", ast_expr), + recall: per_query_recall, + result_ids, + result_distances, + groundtruth_ids: gt_ids, + }) + }) + .collect(); + + // Wall-clock latency and QPS per rep. + let rep_latencies: Vec = + results.iter().map(|r| r.end_to_end_latency()).collect(); + let qps: Vec = rep_latencies + .iter() + .map(|l| num_queries as f64 / l.as_seconds()) + .collect(); + + // Per-query latencies, comparisons, and hops aggregated across all reps. + let mut all_latencies: Vec = Vec::new(); + let mut all_cmps: Vec = Vec::new(); + let mut all_hops: Vec = Vec::new(); + for r in &results { + all_latencies.extend_from_slice(r.latencies()); + for o in r.output() { + all_cmps.push(o.comparisons); + all_hops.push(o.hops); + } + } + + let percentiles::Percentiles { mean, p90, p99, .. 
} = + percentiles::compute_percentiles(&mut all_latencies)?; + + let mean_cmps = if all_cmps.is_empty() { + 0.0 + } else { + all_cmps.iter().map(|&x| x as f32).sum::() / all_cmps.len() as f32 + }; + let mean_hops = if all_hops.is_empty() { + 0.0 + } else { + all_hops.iter().map(|&x| x as f32).sum::() / all_hops.len() as f32 + }; + + Ok(SearchRunStats { + num_threads: run.setup().threads.get(), + num_queries, + search_n, + search_l: parameters.l_value().get(), + recall: recall_metrics, + qps, + wall_clock_time: rep_latencies, + mean_latency: mean, + p90_latency: p90, + p99_latency: p99, + mean_cmps, + mean_hops, + per_query_details: Some(per_query_details), + }) + } +} + +/// Run filtered search with the given parameters. +#[allow(clippy::too_many_arguments)] +fn run_filtered_search( + index: &Arc>, + queries: &Matrix, + predicates: &[(usize, ASTExpr)], + groundtruth: &Vec>, + beta: f32, + num_threads: NonZeroUsize, + search_n: usize, + search_l: usize, + recall_k: usize, + reps: NonZeroUsize, +) -> anyhow::Result +where + T: bytemuck::Pod + Copy + Send + Sync + 'static, + DP: diskann::provider::DataProvider< + Context = DefaultContext, + ExternalId = u32, + InternalId = u32, + > + Send + + Sync + + 'static, + for<'a> InlineBetaStrategy: glue::SearchStrategy> + + glue::DefaultPostProcessor, u32>, +{ + let searcher = Arc::new(FilteredSearcher { + index: index.clone(), + queries: Arc::new(queries.clone()), + predicates: Arc::new(predicates.to_vec()), + beta, + }); + + let parameters = Knn::new_default(search_n, search_l)?; + let setup = search_api::Setup { + threads: num_threads, + tasks: num_threads, + reps, + }; + + let mut results = search_api::search_all( + searcher, + [search_api::Run::new(parameters, setup)], + FilteredSearchAggregator { + groundtruth, + predicates, + recall_k, + }, + )?; + + results + .pop() + .ok_or_else(|| anyhow::anyhow!("no search results")) +} + +/// Per-query detailed results for debugging/analysis +#[derive(Debug, Serialize)] +pub struct 
PerQueryDetails { + pub query_id: usize, + pub filter: String, + pub recall: f64, + pub result_ids: Vec, + pub result_distances: Vec, + pub groundtruth_ids: Vec, +} + +/// Results from a single search configuration (one search_l value). +#[derive(Debug, Serialize)] +pub struct SearchRunStats { + pub num_threads: usize, + pub num_queries: usize, + pub search_n: usize, + pub search_l: usize, + pub recall: RecallMetrics, + pub qps: Vec, + pub wall_clock_time: Vec, + pub mean_latency: f64, + pub p90_latency: MicroSeconds, + pub p99_latency: MicroSeconds, + pub mean_cmps: f32, + pub mean_hops: f32, + #[serde(skip_serializing_if = "Option::is_none")] + pub per_query_details: Option>, +} + +#[derive(Debug, Serialize)] +pub struct DocumentIndexStats { + pub num_vectors: usize, + pub dim: usize, + pub label_count: usize, + pub data_load_time: MicroSeconds, + pub label_load_time: MicroSeconds, + pub build_time: MicroSeconds, + pub insert_latencies: percentiles::Percentiles, + pub search: Vec, +} + +impl std::fmt::Display for DocumentIndexStats { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "Document Index Build Stats:")?; + writeln!(f, " Vectors: {} x {}", self.num_vectors, self.dim)?; + writeln!(f, " Label Count: {}", self.label_count)?; + writeln!( + f, + " Data Load Time: {} s", + self.data_load_time.as_seconds() + )?; + writeln!( + f, + " Label Load Time: {} s", + self.label_load_time.as_seconds() + )?; + writeln!(f, " Total Build Time: {} s", self.build_time.as_seconds())?; + writeln!(f, " Insert Latencies:")?; + writeln!(f, " Mean: {} us", self.insert_latencies.mean)?; + writeln!(f, " P50: {} us", self.insert_latencies.median)?; + writeln!(f, " P90: {} us", self.insert_latencies.p90)?; + writeln!(f, " P99: {} us", self.insert_latencies.p99)?; + + if !self.search.is_empty() { + let header = [ + "L", + "KNN", + "Avg Cmps", + "Avg Hops", + "QPS -mean(max)", + "Avg Latency", + "p99 Latency", + "Recall", + "Threads", + "Queries", + 
"WallClock(s)", + ]; + writeln!(f, "\nFiltered Search Results:")?; + let mut table = fmt::Table::new(header, self.search.len()); + self.search.iter().enumerate().for_each(|(row_idx, s)| { + let mut row = table.row(row_idx); + let mean_qps = percentiles::mean(&s.qps).unwrap_or(0.0); + let max_qps = s.qps.iter().cloned().fold(0.0_f64, f64::max); + let mean_wall_clock = percentiles::mean( + &s.wall_clock_time + .iter() + .map(|l| l.as_seconds()) + .collect::>(), + ) + .unwrap_or(0.0); + row.insert(s.search_l, 0); + row.insert(s.search_n, 1); + row.insert(format!("{:.1}", s.mean_cmps), 2); + row.insert(format!("{:.1}", s.mean_hops), 3); + row.insert(format!("{:.1}({:.1})", mean_qps, max_qps), 4); + row.insert(format!("{:.1} s", s.mean_latency), 5); + row.insert(format!("{:.1} s", s.p99_latency), 6); + row.insert(format!("{:.4}", s.recall.average), 7); + row.insert(s.num_threads, 8); + row.insert(s.num_queries, 9); + row.insert(format!("{:.3} s", mean_wall_clock), 10); + }); + write!(f, "{}", table)?; + } + Ok(()) + } +} + +// ================================ +// Parallel Build Support +// ================================ + +/// Implements [`Build`] for parallel document insertion into a [`DiskANNIndex`] +/// backed by a [`DocumentProvider`]. Each call to [`Build::build`] inserts a +/// contiguous range of vectors and their associated attributes. 
+struct DocumentIndexBuilder { + index: Arc>, + data: Arc>, + attributes: Arc>>, + strategy: DocumentInsertStrategy, +} + +impl DocumentIndexBuilder { + fn new( + index: Arc>, + data: Arc>, + attributes: Arc>>, + strategy: DocumentInsertStrategy, + ) -> Arc { + Arc::new(Self { + index, + data, + attributes, + strategy, + }) + } +} + +impl Build for DocumentIndexBuilder +where + DP: diskann::provider::DataProvider + + for<'doc> diskann::provider::SetElement> + + AsyncFriendly, + for<'doc> DocumentInsertStrategy: + diskann::graph::glue::InsertStrategy>, + DocumentInsertStrategy: AsyncFriendly, + T: AsyncFriendly, +{ + type Output = (); + + fn num_data(&self) -> usize { + self.data.nrows() + } + + async fn build(&self, range: std::ops::Range) -> diskann::ANNResult { + let ctx = DefaultContext; + for i in range { + let attrs = self.attributes.get(i).cloned().ok_or_else(|| { + ANNError::message( + ANNErrorKind::Opaque, + format!("Failed to get attributes at index {}", i), + ) + })?; + let doc = Document::new(self.data.row(i), attrs); + self.index + .insert(self.strategy, &ctx, &(i as u32), &doc) + .await?; + } + Ok(()) + } +} diff --git a/diskann-benchmark/src/backend/document_index/mod.rs b/diskann-benchmark/src/backend/document_index/mod.rs new file mode 100644 index 000000000..022470578 --- /dev/null +++ b/diskann-benchmark/src/backend/document_index/mod.rs @@ -0,0 +1,32 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Backend benchmark implementation for document index with label filters. +//! +//! This benchmark tests the DocumentInsertStrategy which enables inserting +//! Document objects (vector + attributes) into a DiskANN index. + +use diskann_benchmark_runner::registry::Benchmarks; + +cfg_if::cfg_if! { + if #[cfg(feature = "document-index")] { + mod benchmark; + + /// Register document index benchmarks when the `document-index` feature is enabled. 
+ pub(crate) fn register_benchmarks(registry: &mut Benchmarks) { + benchmark::register_benchmarks(registry); + } + } else { + crate::utils::stub_impl!( + "document-index", + inputs::document_index::DocumentIndexBuild + ); + + /// Register a stub that guides users to enable the `document-index` feature. + pub(crate) fn register_benchmarks(registry: &mut Benchmarks) { + imp::register("document-index", registry); + } + } +} diff --git a/diskann-benchmark/src/backend/index/build.rs b/diskann-benchmark/src/backend/index/build.rs index ef6284d2a..b674bf5e8 100644 --- a/diskann-benchmark/src/backend/index/build.rs +++ b/diskann-benchmark/src/backend/index/build.rs @@ -213,12 +213,12 @@ impl std::fmt::Display for BuildStats { } } -pub struct ProgressMeter<'a> { +pub(crate) struct ProgressMeter<'a> { output: &'a mut dyn Output, } impl<'a> ProgressMeter<'a> { - pub fn new(output: &'a mut dyn Output) -> Self { + pub(crate) fn new(output: &'a mut dyn Output) -> Self { Self { output } } } diff --git a/diskann-benchmark/src/backend/index/mod.rs b/diskann-benchmark/src/backend/index/mod.rs index 269887c6d..07ed0ccb8 100644 --- a/diskann-benchmark/src/backend/index/mod.rs +++ b/diskann-benchmark/src/backend/index/mod.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. 
*/ -mod build; +pub(crate) mod build; mod search; mod streaming; diff --git a/diskann-benchmark/src/backend/mod.rs b/diskann-benchmark/src/backend/mod.rs index 24fe91d7e..5dc1967de 100644 --- a/diskann-benchmark/src/backend/mod.rs +++ b/diskann-benchmark/src/backend/mod.rs @@ -4,6 +4,7 @@ */ mod disk_index; +mod document_index; mod exhaustive; mod filters; mod index; @@ -13,4 +14,5 @@ pub(crate) fn register_benchmarks(registry: &mut diskann_benchmark_runner::regis disk_index::register_benchmarks(registry); index::register_benchmarks(registry); filters::register_benchmarks(registry); + document_index::register_benchmarks(registry); } diff --git a/diskann-benchmark/src/inputs/document_index.rs b/diskann-benchmark/src/inputs/document_index.rs new file mode 100644 index 000000000..f1ed3c063 --- /dev/null +++ b/diskann-benchmark/src/inputs/document_index.rs @@ -0,0 +1,182 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Input types for document index benchmarks using DocumentInsertStrategy. + +use std::num::NonZeroUsize; + +use anyhow::Context; +use diskann::graph::{ + config::{Builder, ConfigError, MaxDegree, PruneKind}, + Config, +}; +use diskann_benchmark_runner::{ + files::InputFile, utils::datatype::DataType, CheckDeserialization, Checker, +}; +use serde::{Deserialize, Serialize}; + +use super::async_::GraphSearch; +use crate::inputs::{as_input, Example, Input}; + +////////////// +// Registry // +////////////// + +as_input!(DocumentIndexBuild); + +pub(super) fn register_inputs( + registry: &mut diskann_benchmark_runner::registry::Inputs, +) -> anyhow::Result<()> { + registry.register(Input::::new())?; + Ok(()) +} + +/// Build parameters for document index construction. 
+#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct DocumentBuildParams { + pub(crate) data_type: DataType, + pub(crate) data: InputFile, + pub(crate) data_labels: InputFile, + pub(crate) distance: crate::utils::SimilarityMeasure, + pub(crate) max_degree: usize, + pub(crate) l_build: usize, + pub(crate) alpha: f32, + pub(crate) num_threads: usize, +} + +impl DocumentBuildParams { + pub(crate) fn build_config(&self) -> Result { + let metric = self.distance.into(); + let prune_kind = PruneKind::from_metric(metric); + let mut config_builder = Builder::new( + self.max_degree, // pruned_degree + MaxDegree::default_slack(), // max_degree + self.l_build, + prune_kind, + ); + config_builder.alpha(self.alpha); + let config = config_builder.build()?; + Ok(config) + } +} + +impl CheckDeserialization for DocumentBuildParams { + fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + self.data.check_deserialization(checker)?; + self.data_labels.check_deserialization(checker)?; + self.build_config()?; + Ok(()) + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct DocumentSearchParams { + pub(crate) queries: InputFile, + pub(crate) query_predicates: InputFile, + pub(crate) groundtruth: InputFile, + pub(crate) beta: f32, + pub(crate) reps: NonZeroUsize, + pub(crate) num_threads: Vec, + pub(crate) runs: Vec, +} + +fn default_reps() -> NonZeroUsize { + NonZeroUsize::new(5).unwrap() +} +fn default_thread_counts() -> Vec { + vec![NonZeroUsize::new(1).unwrap()] +} + +impl CheckDeserialization for DocumentSearchParams { + fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + self.queries.check_deserialization(checker)?; + self.query_predicates.check_deserialization(checker)?; + self.groundtruth.check_deserialization(checker)?; + if self.beta <= 0.0 || self.beta > 1.0 { + return Err(anyhow::anyhow!( + "beta must be in range (0, 1], got: {}", + self.beta + )); + } + for (i, run) in 
self.runs.iter_mut().enumerate() { + run.check_deserialization(checker) + .with_context(|| format!("search run {}", i))?; + } + Ok(()) + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct DocumentIndexBuild { + pub(crate) build: DocumentBuildParams, + pub(crate) search: DocumentSearchParams, +} + +impl DocumentIndexBuild { + pub(crate) const fn tag() -> &'static str { + "document-index-build" + } +} + +impl CheckDeserialization for DocumentIndexBuild { + fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + self.build.check_deserialization(checker)?; + self.search.check_deserialization(checker)?; + Ok(()) + } +} + +impl Example for DocumentIndexBuild { + fn example() -> Self { + Self { + build: DocumentBuildParams { + data_type: DataType::Float32, + data: InputFile::new("data.fbin"), + data_labels: InputFile::new("data.label.jsonl"), + distance: crate::utils::SimilarityMeasure::SquaredL2, + max_degree: 32, + l_build: 50, + alpha: 1.2, + num_threads: 1, + }, + search: DocumentSearchParams { + queries: InputFile::new("queries.fbin"), + query_predicates: InputFile::new("query.label.jsonl"), + groundtruth: InputFile::new("groundtruth.bin"), + beta: 0.5, + reps: default_reps(), + num_threads: default_thread_counts(), + runs: vec![GraphSearch { + search_n: 10, + search_l: vec![20, 30, 40, 50], + recall_k: 10, + }], + }, + } + } +} + +impl std::fmt::Display for DocumentIndexBuild { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "Document Index Build with Label Filters\n")?; + writeln!(f, "tag: \"{}\"", Self::tag())?; + writeln!( + f, + "\nBuild: data={}, labels={}, R={}, L={}, alpha={}", + self.build.data.display(), + self.build.data_labels.display(), + self.build.max_degree, + self.build.l_build, + self.build.alpha + )?; + writeln!( + f, + "Search: queries={}, beta={}", + self.search.queries.display(), + self.search.beta + )?; + Ok(()) + } +} diff --git 
a/diskann-benchmark/src/inputs/mod.rs b/diskann-benchmark/src/inputs/mod.rs index a0ae1a982..65de65a41 100644 --- a/diskann-benchmark/src/inputs/mod.rs +++ b/diskann-benchmark/src/inputs/mod.rs @@ -5,6 +5,7 @@ pub(crate) mod async_; pub(crate) mod disk; +pub(crate) mod document_index; pub(crate) mod exhaustive; pub(crate) mod filters; pub(crate) mod save_and_load; @@ -16,6 +17,7 @@ pub(crate) fn register_inputs( exhaustive::register_inputs(registry)?; disk::register_inputs(registry)?; filters::register_inputs(registry)?; + document_index::register_inputs(registry)?; Ok(()) } diff --git a/diskann-benchmark/src/main.rs b/diskann-benchmark/src/main.rs index b3de5901e..cdca116e3 100644 --- a/diskann-benchmark/src/main.rs +++ b/diskann-benchmark/src/main.rs @@ -779,4 +779,71 @@ mod tests { let mut output = Memory::new(); cli.check_target(&mut output).unwrap(); } + + #[test] + fn document_filter_integration() { + let input_path = example_directory().join("document-filter.json"); + + let tempdir = tempfile::tempdir().unwrap(); + let output_path = tempdir.path().join("output.json"); + assert!(!output_path.exists()); + + let modified_input_path = tempdir.path().join("input.json"); + + let mut raw = value_from_file(&input_path); + prefix_search_directories(&mut raw, &root_directory()); + save_to_file(&modified_input_path, &raw); + + run_document_filter_integration(&modified_input_path, &output_path, &raw); + } + + #[cfg(feature = "document-index")] + fn run_document_filter_integration( + input_path: &std::path::Path, + output_path: &std::path::Path, + raw: &serde_json::Value, + ) { + let command = Commands::Run { + input_file: input_path.to_owned(), + output_file: output_path.to_owned(), + dry_run: false, + }; + let cli = Cli::from_commands(command, true); + let mut output = Memory::new(); + + cli.run(&mut output).unwrap(); + + let output = String::from_utf8(output.into_inner()).unwrap(); + println!("output = {}", output); + // Check that the results file is generated. 
+ assert!(output_path.exists()); + + let results: Vec = load_from_file(output_path); + assert_eq!(results.len(), num_jobs(raw)); + } + + #[cfg(not(feature = "document-index"))] + fn run_document_filter_integration( + input_path: &std::path::Path, + output_path: &std::path::Path, + _raw: &serde_json::Value, + ) { + let command = Commands::Run { + input_file: input_path.to_owned(), + output_file: output_path.to_owned(), + dry_run: false, + }; + let cli = Cli::from_commands(command, true); + let mut output = Memory::new(); + + let err = cli.run(&mut output).unwrap_err(); + println!("err = {:?}", err); + + let output = String::from_utf8(output.into_inner()).unwrap(); + assert!(output.contains("\"document-index\" feature")); + println!("output = {}", output); + + // The output file should not have been created because we failed. + assert!(!output_path.exists()); + } } diff --git a/diskann-label-filter/Cargo.toml b/diskann-label-filter/Cargo.toml index 98fe3879c..b204dae23 100644 --- a/diskann-label-filter/Cargo.toml +++ b/diskann-label-filter/Cargo.toml @@ -33,6 +33,8 @@ tempfile.workspace = true anyhow.workspace = true futures-util.workspace = true tracing.workspace = true +diskann = { workspace = true, features = ["testing"] } +tokio = { workspace = true, features = ["rt"] } diff --git a/diskann-label-filter/src/document.rs b/diskann-label-filter/src/document.rs index 31cad4772..5c817525c 100644 --- a/diskann-label-filter/src/document.rs +++ b/diskann-label-filter/src/document.rs @@ -8,12 +8,12 @@ use diskann_utils::reborrow::Reborrow; ///Simple container class that clients can use to /// supply diskann with a vector and its attributes -pub struct Document<'a, V> { +pub struct Document<'a, V: ?Sized> { vector: &'a V, attributes: Vec, } -impl<'a, V> Document<'a, V> { +impl<'a, V: ?Sized> Document<'a, V> { pub fn new(vector: &'a V, attributes: Vec) -> Self { Self { vector, attributes } } diff --git 
a/diskann-label-filter/src/encoded_attribute_provider/ast_label_id_mapper.rs b/diskann-label-filter/src/encoded_attribute_provider/ast_label_id_mapper.rs index 0fa21cc02..8b39d8731 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/ast_label_id_mapper.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/ast_label_id_mapper.rs @@ -31,19 +31,14 @@ impl ASTLabelIdMapper { Self { attribute_map } } - fn _lookup( - encoder: &AttributeEncoder, - attribute: &Attribute, - field: &str, - op: &CompareOp, - ) -> ANNResult> { + fn _lookup(encoder: &AttributeEncoder, attribute: &Attribute) -> ANNResult> { match encoder.get(attribute) { Some(attribute_id) => Ok(ASTIdExpr::Terminal(attribute_id)), None => Err(ANNError::message( ANNErrorKind::Opaque, format!( - "{}+{} present in the query does not exist in the dataset.", - field, op + "{} present in the query does not exist in the dataset.", + attribute ), )), } @@ -120,10 +115,10 @@ impl ASTVisitor for ASTLabelIdMapper { if let Some(attribute) = label_or_none { match self.attribute_map.read() { - Ok(guard) => Self::_lookup(&guard, &attribute, field, op), + Ok(guard) => Self::_lookup(&guard, &attribute), Err(poison_error) => { let attr_map = poison_error.into_inner(); - Self::_lookup(&attr_map, &attribute, field, op) + Self::_lookup(&attr_map, &attribute) } } } else { diff --git a/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs new file mode 100644 index 000000000..6e6886eda --- /dev/null +++ b/diskann-label-filter/src/encoded_attribute_provider/document_insert_strategy.rs @@ -0,0 +1,221 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ + +//! A strategy wrapper that enables insertion of [Document] objects into a +//! [DiskANNIndex] using a [DocumentProvider]. 
+ +use diskann::{ + graph::glue::{ExpandBeam, InsertStrategy, PruneStrategy, SearchExt, SearchStrategy}, + provider::{Accessor, BuildQueryComputer, DataProvider, DelegateNeighbor, HasId}, + ANNResult, +}; + +use super::document_provider::DocumentProvider; +use crate::document::Document; +use crate::encoded_attribute_provider::roaring_attribute_store::RoaringAttributeStore; + +/// A strategy wrapper that enables insertion of [Document] objects. +pub struct DocumentInsertStrategy { + inner: Inner, +} + +impl Clone for DocumentInsertStrategy { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } +} + +impl Copy for DocumentInsertStrategy {} + +impl DocumentInsertStrategy { + pub fn new(inner: Inner) -> Self { + Self { inner } + } + + pub fn inner(&self) -> &Inner { + &self.inner + } +} + +/// Wrapper accessor for Document queries +pub struct DocumentSearchAccessor { + inner: Inner, +} + +impl DocumentSearchAccessor { + pub fn new(inner: Inner) -> Self { + Self { inner } + } +} + +impl HasId for DocumentSearchAccessor +where + Inner: HasId, +{ + type Id = Inner::Id; +} + +impl Accessor for DocumentSearchAccessor +where + Inner: Accessor, +{ + type ElementRef<'a> = Inner::ElementRef<'a>; + type Element<'a> + = Inner::Element<'a> + where + Self: 'a; + type Extended = Inner::Extended; + type GetError = Inner::GetError; + + fn get_element( + &mut self, + id: Self::Id, + ) -> impl std::future::Future, Self::GetError>> + Send { + self.inner.get_element(id) + } + + fn on_elements_unordered( + &mut self, + itr: Itr, + f: F, + ) -> impl std::future::Future> + Send + where + Self: Sync, + Itr: Iterator + Send, + F: Send + for<'b> FnMut(Self::ElementRef<'b>, Self::Id), + { + self.inner.on_elements_unordered(itr, f) + } +} + +impl<'doc, Inner, VT> BuildQueryComputer> for DocumentSearchAccessor +where + Inner: BuildQueryComputer, + VT: ?Sized, +{ + type QueryComputerError = Inner::QueryComputerError; + type QueryComputer = Inner::QueryComputer; + + fn 
build_query_computer( + &self, + from: &Document<'doc, VT>, + ) -> Result { + self.inner.build_query_computer(from.vector()) + } +} + +impl<'this, Inner> DelegateNeighbor<'this> for DocumentSearchAccessor +where + Inner: DelegateNeighbor<'this>, +{ + type Delegate = Inner::Delegate; + fn delegate_neighbor(&'this mut self) -> Self::Delegate { + self.inner.delegate_neighbor() + } +} + +impl<'doc, Inner, VT> ExpandBeam> for DocumentSearchAccessor +where + Inner: ExpandBeam, + VT: ?Sized, +{ +} + +impl SearchExt for DocumentSearchAccessor +where + Inner: SearchExt, +{ + fn starting_points( + &self, + ) -> impl std::future::Future>> + Send { + self.inner.starting_points() + } + fn terminate_early(&mut self) -> bool { + self.inner.terminate_early() + } +} + +impl<'doc, Inner, DP, VT> + SearchStrategy>, Document<'doc, VT>> + for DocumentInsertStrategy +where + Inner: InsertStrategy, + DP: DataProvider, + VT: Sync + Send + ?Sized + 'static, +{ + type QueryComputer = Inner::QueryComputer; + type SearchAccessorError = Inner::SearchAccessorError; + type SearchAccessor<'a> = DocumentSearchAccessor>; + + fn search_accessor<'a>( + &'a self, + provider: &'a DocumentProvider>, + context: &'a > as DataProvider>::Context, + ) -> Result, Self::SearchAccessorError> { + let inner_accessor = self + .inner + .search_accessor(provider.inner_provider(), context)?; + Ok(DocumentSearchAccessor::new(inner_accessor)) + } +} + +impl<'doc, Inner, DP, VT> + InsertStrategy>, Document<'doc, VT>> + for DocumentInsertStrategy +where + Inner: InsertStrategy, + DP: DataProvider, + VT: Sync + Send + ?Sized + 'static, +{ + type PruneStrategy = DocumentPruneStrategy; + + fn prune_strategy(&self) -> Self::PruneStrategy { + DocumentPruneStrategy::new(self.inner.prune_strategy()) + } + + fn insert_search_accessor<'a>( + &'a self, + provider: &'a DocumentProvider>, + context: &'a > as DataProvider>::Context, + ) -> Result, Self::SearchAccessorError> { + let inner_accessor = self + .inner + 
.insert_search_accessor(provider.inner_provider(), context)?; + Ok(DocumentSearchAccessor::new(inner_accessor)) + } +} + +#[derive(Clone, Copy)] +pub struct DocumentPruneStrategy { + inner: Inner, +} + +impl DocumentPruneStrategy { + pub fn new(inner: Inner) -> Self { + Self { inner } + } +} + +impl PruneStrategy>> + for DocumentPruneStrategy +where + DP: DataProvider, + Inner: PruneStrategy, +{ + type DistanceComputer = Inner::DistanceComputer; + type PruneAccessor<'a> = Inner::PruneAccessor<'a>; + type PruneAccessorError = Inner::PruneAccessorError; + + fn prune_accessor<'a>( + &'a self, + provider: &'a DocumentProvider>, + context: &'a DP::Context, + ) -> Result, Self::PruneAccessorError> { + self.inner + .prune_accessor(provider.inner_provider(), context) + } +} diff --git a/diskann-label-filter/src/encoded_attribute_provider/document_provider.rs b/diskann-label-filter/src/encoded_attribute_provider/document_provider.rs index 6b496271b..1fabf5f54 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/document_provider.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/document_provider.rs @@ -77,7 +77,7 @@ impl<'a, VT, DP, AS> SetElement> for DocumentProvider where DP: DataProvider + Delete + SetElement, AS: AttributeStore + AsyncFriendly, - VT: Sync + Send, + VT: Sync + Send + ?Sized, { type SetError = ANNError; type Guard = >::Guard; diff --git a/diskann-label-filter/src/encoded_attribute_provider/roaring_attribute_store.rs b/diskann-label-filter/src/encoded_attribute_provider/roaring_attribute_store.rs index 6b82a68b1..cab250c76 100644 --- a/diskann-label-filter/src/encoded_attribute_provider/roaring_attribute_store.rs +++ b/diskann-label-filter/src/encoded_attribute_provider/roaring_attribute_store.rs @@ -15,7 +15,7 @@ use diskann::{utils::VectorId, ANNError, ANNErrorKind, ANNResult}; use diskann_utils::future::AsyncFriendly; use std::sync::{Arc, RwLock}; -pub(crate) struct RoaringAttributeStore +pub struct RoaringAttributeStore where 
IT: VectorId + AsyncFriendly, { @@ -24,6 +24,15 @@ where inv_index: Arc>>, } +impl Default for RoaringAttributeStore +where + IT: VectorId, +{ + fn default() -> Self { + Self::new() + } +} + impl RoaringAttributeStore where IT: VectorId, diff --git a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs index 962d361d7..ab82dad56 100644 --- a/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs +++ b/diskann-label-filter/src/inline_beta_search/encoded_document_accessor.rs @@ -11,7 +11,7 @@ use diskann::{ provider::{Accessor, AsNeighbor, BuildQueryComputer, DelegateNeighbor, HasId}, ANNError, ANNErrorKind, }; -use diskann_utils::{future::AsyncFriendly, Reborrow}; +use diskann_utils::Reborrow; use roaring::RoaringTreemap; use crate::traits::attribute_accessor::AttributeAccessor; @@ -28,7 +28,7 @@ use crate::{ type AttrAccessor = EncodedAttributeAccessor::Id>>; -pub(crate) struct EncodedDocumentAccessor +pub struct EncodedDocumentAccessor where IA: HasId, { @@ -136,7 +136,7 @@ where Some(set) => Ok(set.into_owned()), None => Err(ANNError::message( ANNErrorKind::IndexError, - "No labels were found for vector", + format!("No labels were found for vector:{:?}", id), )), } })?; @@ -204,17 +204,17 @@ where } } -impl BuildQueryComputer> for EncodedDocumentAccessor +impl<'a, IA, Q> BuildQueryComputer> for EncodedDocumentAccessor where IA: BuildQueryComputer, - Q: AsyncFriendly + Clone, + Q: Send + Sync + ?Sized, { type QueryComputerError = ANNError; type QueryComputer = InlineBetaComputer; fn build_query_computer( &self, - from: &FilteredQuery, + from: &FilteredQuery<'a, Q>, ) -> Result { let inner_computer = self .inner_accessor @@ -234,7 +234,7 @@ impl ExpandBeam for EncodedDocumentAccessor where IA: Accessor, EncodedDocumentAccessor: BuildQueryComputer + AsNeighbor, - Q: Clone + AsyncFriendly, + Q: Send + Sync + ?Sized, { } diff --git 
a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs index 6e4015948..f7a4f505d 100644 --- a/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs +++ b/diskann-label-filter/src/inline_beta_search/inline_beta_filter.rs @@ -9,7 +9,6 @@ use diskann::neighbor::Neighbor; use diskann::provider::{Accessor, BuildQueryComputer, DataProvider}; use diskann::ANNError; -use diskann_utils::future::AsyncFriendly; use diskann_vector::PreprocessedDistanceFunction; use roaring::RoaringTreemap; @@ -28,13 +27,22 @@ pub struct InlineBetaStrategy { inner: Strategy, } +impl InlineBetaStrategy { + /// Create a new InlineBetaStrategy with the given beta value and inner strategy. + pub fn new(beta: f32, inner: Strategy) -> Self { + Self { beta, inner } + } +} + impl - SearchStrategy>, FilteredQuery> - for InlineBetaStrategy + SearchStrategy< + DocumentProvider>, + FilteredQuery<'_, Q>, + > for InlineBetaStrategy where DP: DataProvider, Strategy: SearchStrategy, - Q: AsyncFriendly + Clone, + Q: Send + Sync + ?Sized, { type QueryComputer = InlineBetaComputer; type SearchAccessorError = ANNError; @@ -66,12 +74,12 @@ where impl diskann::graph::glue::DefaultPostProcessor< DocumentProvider>, - FilteredQuery, + FilteredQuery<'_, Q>, > for InlineBetaStrategy where DP: DataProvider, Strategy: diskann::graph::glue::DefaultPostProcessor, - Q: AsyncFriendly + Clone, + Q: Send + Sync + ?Sized, { type Processor = FilterResults; @@ -115,22 +123,15 @@ where let (vec, attrs) = changing.destructure(); let sim = self.inner_computer.evaluate_similarity(vec); let pred_eval = PredicateEvaluator::new(attrs); - match self.filter_expr.encoded_filter_expr().accept(&pred_eval) { - Ok(matched) => { - if matched { - sim * self.beta_value - } else { - sim - } - } - Err(_) => { - //TODO: If predicate evaluation fails, we are taking the approach that we will simply - //return the score returned by the inner computer, as though 
no predicate was specified. - tracing::warn!( - "Predicate evaluation failed in OnlineBetaComputer::evaluate_similarity()" - ); - sim - } + if self + .filter_expr + .encoded_filter_expr() + .accept(&pred_eval) + .expect("Expected predicate evaluation to not error out!") + { + sim * self.beta_value + } else { + sim } } } @@ -139,19 +140,28 @@ pub struct FilterResults { inner_post_processor: IPP, } -impl SearchPostProcess, FilteredQuery> +impl FilterResults { + #[cfg(test)] + pub(crate) fn new(inner_post_processor: IPP) -> Self { + Self { + inner_post_processor, + } + } +} + +impl<'a, Q, IA, IPP> SearchPostProcess, FilteredQuery<'a, Q>> for FilterResults where IA: BuildQueryComputer, - Q: Clone + AsyncFriendly, IPP: SearchPostProcess + Send + Sync, + Q: Send + Sync + ?Sized, { type Error = ANNError; async fn post_process( &self, accessor: &mut EncodedDocumentAccessor, - query: &FilteredQuery, + query: &FilteredQuery<'a, Q>, computer: &InlineBetaComputer<>::QueryComputer>, candidates: I, output: &mut B, diff --git a/diskann-label-filter/src/lib.rs b/diskann-label-filter/src/lib.rs index 106845f98..414b868d4 100644 --- a/diskann-label-filter/src/lib.rs +++ b/diskann-label-filter/src/lib.rs @@ -40,6 +40,7 @@ pub mod encoded_attribute_provider { pub(crate) mod ast_id_expr; pub(crate) mod ast_label_id_mapper; pub(crate) mod attribute_encoder; + pub mod document_insert_strategy; pub mod document_provider; pub mod encoded_attribute_accessor; pub(crate) mod encoded_filter_expr; @@ -52,6 +53,10 @@ pub mod tests { #[cfg(test)] pub mod common; #[cfg(test)] + pub mod document_insert_strategy_test; + #[cfg(test)] + pub mod inline_beta_filter_test; + #[cfg(test)] pub mod roaring_attribute_store_test; } diff --git a/diskann-label-filter/src/query.rs b/diskann-label-filter/src/query.rs index 15c42501f..b1a7a6409 100644 --- a/diskann-label-filter/src/query.rs +++ b/diskann-label-filter/src/query.rs @@ -9,18 +9,18 @@ use crate::ASTExpr; /// The Readme.md file in the label-filter 
folder describes the format /// of the query expression. #[derive(Clone)] -pub struct FilteredQuery { - query: V, +pub struct FilteredQuery<'a, V: ?Sized> { + query: &'a V, filter_expr: ASTExpr, } -impl FilteredQuery { - pub fn new(query: V, filter_expr: ASTExpr) -> Self { +impl<'a, V: ?Sized> FilteredQuery<'a, V> { + pub fn new(query: &'a V, filter_expr: ASTExpr) -> Self { Self { query, filter_expr } } - pub(crate) fn query(&self) -> &V { - &self.query + pub(crate) fn query(&self) -> &'a V { + self.query } pub(crate) fn filter_expr(&self) -> &ASTExpr { diff --git a/diskann-label-filter/src/tests/document_insert_strategy_test.rs b/diskann-label-filter/src/tests/document_insert_strategy_test.rs new file mode 100644 index 000000000..5fb78dd75 --- /dev/null +++ b/diskann-label-filter/src/tests/document_insert_strategy_test.rs @@ -0,0 +1,124 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ + +use diskann::{ + graph::{ + glue::{InsertStrategy, PruneStrategy, SearchExt, SearchStrategy}, + test::provider::{Config, Context, Provider, StartPoint, Strategy}, + }, + provider::BuildQueryComputer, +}; +use diskann_vector::distance::Metric; + +use crate::{ + document::Document, + encoded_attribute_provider::{ + document_insert_strategy::{ + DocumentInsertStrategy, DocumentPruneStrategy, DocumentSearchAccessor, + }, + document_provider::DocumentProvider, + roaring_attribute_store::RoaringAttributeStore, + }, +}; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +/// Build a minimal test provider with a single start point and three dimensions. 
+fn make_test_provider() -> Provider { + let config = Config::new( + Metric::L2, + 10, + StartPoint::new(u32::MAX, vec![1.0f32, 2.0, 0.0]), + ) + .expect("test provider config should be valid"); + Provider::new(config) +} + +fn make_doc_provider(provider: Provider) -> DocumentProvider> { + DocumentProvider::new(provider, RoaringAttributeStore::new()) +} + +/// `search_accessor` successfully creates a `DocumentSearchAccessor` wrapping the +/// inner accessor. +#[test] +fn test_search_accessor_creates_wrapped_accessor() { + let strategy = DocumentInsertStrategy::new(Strategy::new()); + let provider = make_doc_provider(make_test_provider()); + let context = Context::new(); + + let result = as SearchStrategy< + DocumentProvider>, + Document<'_, [f32]>, + >>::search_accessor(&strategy, &provider, &context); + + assert!(result.is_ok()); +} + +#[test] +fn test_insert_search_accessor_creates_wrapped_accessor() { + let strategy = DocumentInsertStrategy::new(Strategy::new()); + let provider = make_doc_provider(make_test_provider()); + let context = Context::new(); + + let result = as InsertStrategy< + DocumentProvider>, + Document<'_, [f32]>, + >>::insert_search_accessor(&strategy, &provider, &context); + + assert!(result.is_ok()); +} + +#[test] +fn test_prune_accessor_delegates_to_inner_provider() { + let doc_prune_strategy = DocumentPruneStrategy::new(Strategy::new()); + let provider = make_doc_provider(make_test_provider()); + let context = Context::new(); + + let result = as PruneStrategy< + DocumentProvider>, + >>::prune_accessor(&doc_prune_strategy, &provider, &context); + + assert!(result.is_ok()); +} + +#[test] +fn test_build_query_computer_extracts_vector_from_document() { + let provider = make_test_provider(); + let context = Context::new(); + let strategy_inner = Strategy::new(); + let inner_accessor = strategy_inner + .search_accessor(&provider, &context) + .expect("creating search accessor should succeed"); + let doc_accessor = 
DocumentSearchAccessor::new(inner_accessor); + + let vector = vec![1.0f32, 2.0, 0.0]; + let doc = Document::new(vector.as_slice(), vec![]); + + let result = as BuildQueryComputer>>::build_query_computer(&doc_accessor, &doc); + + assert!( + result.is_ok(), + "build_query_computer should succeed for a valid vector" + ); +} + +#[test] +fn test_terminate_early_delegates_to_inner() { + let provider = make_test_provider(); + let context = Context::new(); + let strategy_inner = Strategy::new(); + let mut inner_accessor = strategy_inner + .search_accessor(&provider, &context) + .expect("creating search accessor should succeed"); + let inner_terminate_early = inner_accessor.terminate_early(); + let mut doc_accessor = DocumentSearchAccessor::new(inner_accessor); + assert_eq!( + inner_terminate_early, + doc_accessor.terminate_early(), + "terminate_early should have same value as inner accessor" + ); +} diff --git a/diskann-label-filter/src/tests/inline_beta_filter_test.rs b/diskann-label-filter/src/tests/inline_beta_filter_test.rs new file mode 100644 index 000000000..079f91eff --- /dev/null +++ b/diskann-label-filter/src/tests/inline_beta_filter_test.rs @@ -0,0 +1,224 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ + +use std::sync::{Arc, RwLock}; + +use diskann::{ + graph::{ + glue::{self, SearchPostProcess, SearchStrategy}, + search_output_buffer::IdDistance, + test::provider::{Config, Context, Provider, StartPoint, Strategy}, + }, + neighbor::Neighbor, + provider::{BuildQueryComputer, SetElement}, +}; +use diskann_vector::{distance::Metric, PreprocessedDistanceFunction}; +use roaring::RoaringTreemap; +use serde_json::Value; + +use crate::{ + attribute::{Attribute, AttributeValue}, + document::EncodedDocument, + encoded_attribute_provider::{ + attribute_encoder::AttributeEncoder, encoded_filter_expr::EncodedFilterExpr, + roaring_attribute_store::RoaringAttributeStore, + }, + inline_beta_search::{ + encoded_document_accessor::EncodedDocumentAccessor, + inline_beta_filter::{FilterResults, InlineBetaComputer}, + }, + query::FilteredQuery, + traits::attribute_store::AttributeStore, + ASTExpr, CompareOp, +}; + +// ----------------------------------------------------------------------- +// Stub inner distance computer +// ----------------------------------------------------------------------- + +/// Always returns a fixed constant distance, regardless of the vector value. +struct ConstComputer(f32); + +impl PreprocessedDistanceFunction<&[f32], f32> for ConstComputer { + fn evaluate_similarity(&self, _: &[f32]) -> f32 { + self.0 + } +} + +// ----------------------------------------------------------------------- +// Helper: build an AttributeEncoder + ASTExpr for `field == value`, +// returning (attr_map, ast_expr, encoded_id_of_that_attribute). 
+// ----------------------------------------------------------------------- + +fn setup_encoder_and_filter( + field: &str, + value: &str, +) -> (Arc>, ASTExpr, u64) { + let mut encoder = AttributeEncoder::new(); + let attr = Attribute::from_value(field, AttributeValue::String(value.to_owned())); + let encoded_id = encoder.insert(&attr); + let attr_map = Arc::new(RwLock::new(encoder)); + let ast_expr = ASTExpr::Compare { + field: field.to_string(), + op: CompareOp::Eq(Value::String(value.to_string())), + }; + (attr_map, ast_expr, encoded_id) +} + +// ----------------------------------------------------------------------- +// Test 1: when the filter matches, evaluate_similarity returns inner * beta +// ----------------------------------------------------------------------- + +#[test] +fn test_evaluate_similarity_filter_match_scales_by_beta() { + let (attr_map, ast_expr, color_red_id) = setup_encoder_and_filter("color", "red"); + let filter_expr = EncodedFilterExpr::new(&ast_expr, attr_map).expect("filter expr"); + + let beta = 2.5_f32; + let inner_dist = 4.0_f32; + let computer = InlineBetaComputer::new(ConstComputer(inner_dist), beta, filter_expr); + + // Bitmap contains the encoded ID for "color=red" → predicate matches + let mut matching_map = RoaringTreemap::new(); + matching_map.insert(color_red_id); + let doc = EncodedDocument::new(&[1.0f32, 0.0][..], &matching_map); + + assert_eq!( + computer.evaluate_similarity(doc), + inner_dist * beta, + "a matched filter should multiply the inner similarity by beta" + ); +} + +// ----------------------------------------------------------------------- +// Test 2: when the filter does not match, evaluate_similarity is unchanged +// ----------------------------------------------------------------------- + +#[test] +fn test_evaluate_similarity_no_filter_match_preserves_score() { + let (attr_map, ast_expr, _) = setup_encoder_and_filter("color", "red"); + let filter_expr = EncodedFilterExpr::new(&ast_expr, 
attr_map).expect("filter expr"); + + let beta = 2.5_f32; + let inner_dist = 4.0_f32; + let computer = InlineBetaComputer::new(ConstComputer(inner_dist), beta, filter_expr); + + // Empty bitmap → no attribute matches the predicate + let empty_map = RoaringTreemap::new(); + let doc = EncodedDocument::new(&[1.0f32, 0.0][..], &empty_map); + + assert_eq!( + computer.evaluate_similarity(doc), + inner_dist, + "an unmatched filter should leave the inner similarity unchanged" + ); +} + +// ----------------------------------------------------------------------- +// Test 3: post_process forwards only filter-matching candidates to the +// inner post processor (and therefore to the output buffer). +// ----------------------------------------------------------------------- + +#[test] +fn test_post_process_only_passes_matching_candidates_to_inner() { + let rt = tokio::runtime::Builder::new_current_thread() + .build() + .expect("test tokio runtime"); + + // IDs 0 and 1 carry color=red (should pass the filter) + // IDs 2 and 3 carry color=blue (should be dropped by the filter) + let attr_store = RoaringAttributeStore::::new(); + let red = Attribute::from_value("color", AttributeValue::String("red".to_owned())); + let blue = Attribute::from_value("color", AttributeValue::String("blue".to_owned())); + for id in 0u32..2 { + attr_store + .set_element(&id, std::slice::from_ref(&red)) + .expect("set red attr"); + } + for id in 2u32..4 { + attr_store + .set_element(&id, std::slice::from_ref(&blue)) + .expect("set blue attr"); + } + + // The attribute_map is shared so EncodedFilterExpr sees the same encodings + // as those stored by the attribute store. 
+ let attr_map = attr_store.attribute_map(); + + let ast_expr = ASTExpr::Compare { + field: "color".to_string(), + op: CompareOp::Eq(Value::String("red".to_string())), + }; + let filter_expr = EncodedFilterExpr::new(&ast_expr, attr_map.clone()).expect("filter expr"); + + // Build the inner vector provider: start point at u32::MAX + 2-D zero vectors for 0..3 + let config = Config::new(Metric::L2, 10, StartPoint::new(u32::MAX, vec![1.0f32, 0.0])) + .expect("provider config"); + let inner_provider = Provider::new(config); + let ctx = Context::new(); + rt.block_on(async { + for id in 0u32..4 { + inner_provider + .set_element(&ctx, &id, &[0.0f32, 0.0] as &[f32]) + .await + .expect("add vector to inner provider"); + } + }); + + // Obtain the inner search accessor and derive an inner computer from it + let strategy = Strategy::new(); + let inner_accessor = strategy + .search_accessor(&inner_provider, &ctx) + .expect("inner accessor"); + let inner_computer = inner_accessor + .build_query_computer(&[0.0f32, 0.0][..]) + .expect("inner computer"); + + // Wrap accessor + attribute store into an EncodedDocumentAccessor + let attribute_accessor = attr_store.attribute_accessor().expect("attribute accessor"); + let mut doc_accessor = + EncodedDocumentAccessor::new(inner_accessor, attribute_accessor, attr_map, 2.0); + + let computer = InlineBetaComputer::new(inner_computer, 2.0, filter_expr); + + // Four candidates: 0 and 1 match (red); 2 and 3 do not (blue) + let candidates = [ + Neighbor::new(0u32, 1.0_f32), + Neighbor::new(1u32, 2.0_f32), + Neighbor::new(2u32, 3.0_f32), + Neighbor::new(3u32, 4.0_f32), + ]; + + let mut ids = [u32::MAX; 4]; + let mut distances = [f32::MAX; 4]; + let mut output = IdDistance::new(&mut ids, &mut distances); + + let query_vec = [0.0f32, 0.0]; + let filter_query = FilteredQuery::new(&query_vec[..], ast_expr); + + // CopyIds simply copies whatever it receives into the output buffer, + // so the output reflects exactly what FilterResults lets through. 
+ let count = rt + .block_on(FilterResults::new(glue::CopyIds).post_process( + &mut doc_accessor, + &filter_query, + &computer, + candidates.into_iter(), + &mut output, + )) + .expect("post_process"); + + // Only the two red-labeled candidates should have been forwarded + assert_eq!(count, 2, "exactly 2 of 4 candidates should pass the filter"); + let passed = &ids[..count]; + assert!( + passed.contains(&0), + "ID 0 (color=red) should pass the filter" + ); + assert!( + passed.contains(&1), + "ID 1 (color=red) should pass the filter" + ); +} diff --git a/diskann-tools/Cargo.toml b/diskann-tools/Cargo.toml index ae987dca9..0803dae32 100644 --- a/diskann-tools/Cargo.toml +++ b/diskann-tools/Cargo.toml @@ -12,7 +12,7 @@ license.workspace = true [dependencies] byteorder.workspace = true clap = { workspace = true, features = ["derive"] } -diskann-providers = { workspace = true, default-features = false } # see `linalg/Cargo.toml` +diskann-providers = { workspace = true, default-features = false } # see `linalg/Cargo.toml` diskann-vector = { workspace = true } diskann-disk = { workspace = true } diskann-utils = { workspace = true } @@ -24,6 +24,7 @@ ordered-float = "4.2.0" rand_distr.workspace = true rand.workspace = true serde = { workspace = true, features = ["derive"] } +serde_json.workspace = true bincode.workspace = true opentelemetry.workspace = true diskann-quantization = { workspace = true } diff --git a/diskann-tools/src/utils/ground_truth.rs b/diskann-tools/src/utils/ground_truth.rs index 883f0c2ec..e32325083 100644 --- a/diskann-tools/src/utils/ground_truth.rs +++ b/diskann-tools/src/utils/ground_truth.rs @@ -4,7 +4,7 @@ */ use bit_set::BitSet; -use diskann_label_filter::{eval_query_expr, read_and_parse_queries, read_baselabels}; +use diskann_label_filter::{eval_query_expr, read_and_parse_queries, read_baselabels, ASTExpr}; use std::{io::Write, mem::size_of, str::FromStr}; @@ -26,18 +26,100 @@ use diskann_utils::{ use diskann_vector::{distance::Metric, 
DistanceFunction}; use itertools::Itertools; use rayon::prelude::*; +use serde_json::{Map, Value}; use crate::utils::{search_index_utils, CMDResult, CMDToolError}; +/// Evaluates a query expression against a label, expanding array-valued fields by recursion. +/// +/// For each key in the JSON object, if the value is an array the expression is evaluated +/// against one element at a time (any-match semantics) without materialising the full +/// Cartesian product. Non-object labels are evaluated directly. +fn eval_query_with_array_expansion(query_expr: &ASTExpr, label: &Value) -> bool { + match label { + Value::Object(map) => { + let entries: Vec<(&String, &Value)> = map.iter().collect(); + eval_map_recursive(query_expr, &entries, Map::new()) + } + _ => eval_query_expr(query_expr, label), + } +} + +/// Walk `entries` one field at a time, accumulating scalar values into `current`. +/// +/// * Scalar fields are inserted directly and the walk continues with the remaining entries. +/// * Array fields branch once per element; evaluation short-circuits on the first branch +/// that returns `true`. +/// * An empty array is treated as an absent field (preserving the previous behaviour). +/// * When all fields have been consumed, `eval_query_expr` is called on the accumulated object. +fn eval_map_recursive( + query_expr: &ASTExpr, + entries: &[(&String, &Value)], + mut current: Map, +) -> bool { + match entries { + [] => eval_query_expr(query_expr, &Value::Object(current)), + [(key, Value::Array(arr)), rest @ ..] => { + if arr.is_empty() { + // Omit this key, matching the original behaviour for empty arrays. + eval_map_recursive(query_expr, rest, current) + } else { + for item in arr { + let mut branch = current.clone(); + branch.insert((*key).clone(), item.clone()); + if eval_map_recursive(query_expr, rest, branch) { + return true; + } + } + false + } + } + [(key, val), rest @ ..] 
=> {
+            current.insert((*key).clone(), (*val).clone());
+            eval_map_recursive(query_expr, rest, current)
+        }
+    }
+}
+
 pub fn read_labels_and_compute_bitmap(
     base_label_filename: &str,
     query_label_filename: &str,
 ) -> CMDResult<Vec<BitSet>> {
     // Read base labels
     let base_labels = read_baselabels(base_label_filename)?;
+    tracing::info!(
+        "Loaded {} base labels from {}",
+        base_labels.len(),
+        base_label_filename
+    );
+
+    // Print first few base labels for debugging
+    for (i, label) in base_labels.iter().take(3).enumerate() {
+        tracing::debug!(
+            "Base label sample [{}]: doc_id={}, label={}",
+            i,
+            label.doc_id,
+            label.label
+        );
+    }
 
     // Parse queries and evaluate against labels
     let parsed_queries = read_and_parse_queries(query_label_filename)?;
+    tracing::info!(
+        "Loaded {} queries from {}",
+        parsed_queries.len(),
+        query_label_filename
+    );
+
+    // Print first few queries for debugging
+    for (i, (query_id, query_expr)) in parsed_queries.iter().take(3).enumerate() {
+        tracing::debug!(
+            "Query sample [{}]: query_id={}, expr={:?}",
+            i,
+            query_id,
+            query_expr
+        );
+    }
 
     // using the global threadpool is fine here
     #[allow(clippy::disallowed_methods)]
@@ -46,7 +128,17 @@ pub fn read_labels_and_compute_bitmap(
         .map(|(_query_id, query_expr)| {
             let mut bitmap = BitSet::new();
             for base_label in base_labels.iter() {
-                if eval_query_expr(query_expr, &base_label.label) {
+                // Handle case where base_label.label is an array - check if any element matches
+                // Also expand array-valued fields within objects (e.g., {"country": ["AU", "NZ"]})
+                let matches = if let Some(array) = base_label.label.as_array() {
+                    array
+                        .iter()
+                        .any(|item| eval_query_with_array_expansion(query_expr, item))
+                } else {
+                    eval_query_with_array_expansion(query_expr, &base_label.label)
+                };
+
+                if matches {
                     bitmap.insert(base_label.doc_id);
                 }
             }
@@ -54,6 +146,44 @@ pub fn read_labels_and_compute_bitmap(
         })
         .collect();
 
+    // Debug: Print match statistics for each query
+    let total_matches: usize = 
query_bitmaps.iter().map(|b| b.len()).sum();
+    let queries_with_matches = query_bitmaps.iter().filter(|b| !b.is_empty()).count();
+    tracing::info!(
+        "Filter matching summary: {} total matches across {} queries ({} queries have matches)",
+        total_matches,
+        query_bitmaps.len(),
+        queries_with_matches
+    );
+
+    // Print per-query match counts
+    for (i, bitmap) in query_bitmaps.iter().enumerate() {
+        if i < 10 || bitmap.is_empty() {
+            tracing::debug!(
+                "Query {}: {} base vectors matched the filter",
+                i,
+                bitmap.len()
+            );
+        }
+    }
+
+    // If no matches, print more diagnostic info
+    if total_matches == 0 {
+        tracing::warn!("WARNING: No base vectors matched any query filters!");
+        tracing::warn!(
+            "This could indicate a format mismatch between base labels and query filters."
+        );
+
+        // Try to identify what keys exist in base labels vs queries
+        if let Some(first_label) = base_labels.first() {
+            tracing::warn!(
+                "First base label (full): doc_id={}, label={}",
+                first_label.doc_id,
+                first_label.label
+            );
+        }
+    }
+
     Ok(query_bitmaps)
 }
 
@@ -196,6 +326,47 @@ pub fn compute_ground_truth_from_datafiles<
 
     assert_ne!(ground_truth.len(), 0, "No ground-truth results computed");
 
+    // Debug: Print top K matches for each query
+    tracing::info!(
+        "Ground truth computed for {} queries with recall_at={}",
+        ground_truth.len(),
+        recall_at
+    );
+    for (query_idx, npq) in ground_truth.iter().enumerate() {
+        let neighbors: Vec<_> = npq.iter().collect();
+        let neighbor_count = neighbors.len();
+
+        if query_idx < 10 {
+            // Print top K IDs and distances for first 10 queries
+            let top_ids: Vec<_> = neighbors.iter().take(10).map(|n| n.id).collect();
+            let top_dists: Vec<_> = neighbors.iter().take(10).map(|n| n.distance).collect();
+            tracing::debug!(
+                "Query {}: {} neighbors found. 
Top IDs: {:?}, Top distances: {:?}", + query_idx, + neighbor_count, + top_ids, + top_dists + ); + } + + if neighbor_count == 0 { + tracing::warn!("Query {} has 0 neighbors in ground truth!", query_idx); + } + } + + // Summary stats + let total_neighbors: usize = ground_truth.iter().map(|npq| npq.iter().count()).sum(); + let queries_with_neighbors = ground_truth + .iter() + .filter(|npq| npq.iter().count() > 0) + .count(); + tracing::info!( + "Ground truth summary: {} total neighbors, {} queries have neighbors, {} queries have 0 neighbors", + total_neighbors, + queries_with_neighbors, + ground_truth.len() - queries_with_neighbors + ); + if has_vector_filters || has_query_bitmaps { let ground_truth_collection = ground_truth .into_iter()