diff --git a/diskann-disk/src/search/pq/pq_scratch.rs b/diskann-disk/src/search/pq/pq_scratch.rs
index e90af5212..1df666e18 100644
--- a/diskann-disk/src/search/pq/pq_scratch.rs
+++ b/diskann-disk/src/search/pq/pq_scratch.rs
@@ -78,6 +78,12 @@ impl PQScratch {
         self.query_scratch.copy_from_slice(&query[..dim]);
         Ok(())
     }
+
+    /// Return the largest number of PQ vectors whose distances can be computed using this
+    /// scratch data structure.
+    pub(crate) fn max_vectors(&self) -> usize {
+        self.aligned_dist_scratch.len()
+    }
 }
 
 #[cfg(test)]
@@ -112,6 +118,8 @@ mod tests {
             0
         );
 
+        assert_eq!(pq_scratch.max_vectors(), graph_degree);
+
         // Test set() method
         let query: Vec<f32> = (1..=dim).map(|i| i as f32).collect();
         pq_scratch.set(&query).unwrap();
diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs
index 1dd518781..da54a58db 100644
--- a/diskann-disk/src/search/provider/disk_provider.rs
+++ b/diskann-disk/src/search/provider/disk_provider.rs
@@ -924,21 +924,42 @@ where
         let mut accessor = strategy
             .search_accessor(provider, &DefaultContext)
             .into_ann_result()?;
-        let computer = accessor.build_query_computer(query).into_ann_result()?;
+
+        // Derive the batch size from the scratch data structure. Providing too many vectors
+        // will panic.
+        let batch_size = accessor.scratch.pq_scratch.max_vectors();
+
+        // This check should always hold since `graph_degree` comes from
+        // `diskann::graph::Config` and is forced to be non-zero. But this is defensive
+        // against misconfiguration.
+        if batch_size == 0 {
+            return Err(ANNError::message(
+                diskann::ANNErrorKind::IndexError,
+                "pq scratch must support at least one vector",
+            ));
+        }
+
+        let mut id_buffer = Vec::with_capacity(batch_size);
 
         let mut best = NeighborPriorityQueue::new(neighbors_before_reranking);
         let mut cmps = 0u32;
 
-        let num_points = provider.num_points as u32;
-        for id in 0..num_points {
-            if vector_filter(&id) {
-                let element = accessor.get_element(id).await.into_ann_result()?;
-                let dist = computer.evaluate_similarity(element);
-                best.insert(Neighbor::new(id, dist));
-                cmps += 1;
+        let mut iter = (0..provider.num_points as u32).filter(vector_filter);
+        loop {
+            id_buffer.clear();
+            id_buffer.extend(iter.by_ref().take(batch_size));
+
+            if id_buffer.is_empty() {
+                break;
             }
+
+            accessor.pq_distances(&id_buffer, |dist, id| best.insert(Neighbor::new(id, dist)))?;
+            cmps += id_buffer.len() as u32;
         }
 
+        // FIXME: This is a temporary bridge. We don't really need the query computer, but
+        // we do need to satisfy the trait definition until PR 1067 lands.
+        let computer = accessor.build_query_computer(query).into_ann_result()?;
         let result_count = strategy
             .default_post_processor()
             .post_process(&mut accessor, query, &computer, best.iter(), output)