diff --git a/src/planner.rs b/src/planner.rs index 2af009b..323eca6 100644 --- a/src/planner.rs +++ b/src/planner.rs @@ -34,10 +34,7 @@ use std::collections::{BinaryHeap, HashMap, HashSet}; use std::fmt; use std::sync::Arc; -use arrow_array::{ - Array, BooleanArray, FixedSizeListArray, Float32Array, Float64Array, LargeListArray, ListArray, - RecordBatch, -}; +use arrow_array::{Array, BooleanArray, Float32Array, RecordBatch}; use arrow_schema::SchemaRef; use async_trait::async_trait; use datafusion::common::Result; @@ -58,10 +55,28 @@ use usearch::ScalarKind; use tracing::Instrument; +use datafusion::common::tree_node::{Transformed, TreeNode}; +use datafusion::logical_expr::Expr; + use crate::lookup::extract_keys_as_u64; use crate::node::{DistanceType, USearchNode}; use crate::registry::USearchRegistry; +/// Strip table qualifiers from column references so expressions can be +/// resolved against an unqualified Arrow schema. Mirrors the pattern in +/// DataFusion's own `physical_planner.rs::strip_column_qualifiers`. +fn strip_column_qualifier(expr: &Expr) -> Expr { + match expr.clone().transform(|e| match &e { + Expr::Column(col) if col.relation.is_some() => Ok(Transformed::yes(Expr::Column( + datafusion::common::Column::new_unqualified(col.name.clone()), + ))), + _ => Ok(Transformed::no(e)), + }) { + Ok(t) => t.data, + Err(_) => expr.clone(), + } +} + // ── QueryPlanner wrapper ────────────────────────────────────────────────────── pub struct USearchQueryPlanner { @@ -142,17 +157,32 @@ impl ExtensionPlanner for USearchExecPlanner { .map(|f| create_physical_expr(f, &node.schema, exec_props)) .collect::>()?; - // For the filtered path, pre-plan two provider scans: - // 1. Pre-scan (scalar + _key only, no vector col) — cheap selectivity estimation. - // 2. Full scan (all cols including vector) — only executed at runtime if low-sel. - // Both scans receive the node's WHERE filters for Parquet predicate pushdown. - let (provider_scan, full_scan) = if !node.filters.is_empty() { + // For the filtered path, pre-plan a provider scan: + // Pre-scan (_key + filter cols only) — cheap selectivity estimation. + // The scan receives the node's WHERE filters for Parquet predicate pushdown. + let (provider_scan, prescan_filters) = if !node.filters.is_empty() { let scan_schema = registered.scan_provider.schema(); - let vec_col_idx = scan_schema.index_of(&node.vector_col).ok(); - // Pre-scan projection: all columns except vector. + // Pre-scan projection: _key + columns referenced by filters. + // Only these are needed — _key to collect valid keys, and filter + // columns for predicate evaluation. Reading anything else wastes I/O. + let filter_col_names: HashSet<&str> = node + .filters + .iter() + .flat_map(|f| f.column_refs()) + .map(|c| c.name.as_str()) + .collect(); + let key_col_idx = scan_schema.index_of(®istered.key_col).map_err(|_| { + DataFusionError::Execution(format!( + "USearchExec: key column '{}' not found in scan provider schema", + registered.key_col + )) + })?; let scalar_projection: Vec = (0..scan_schema.fields().len()) - .filter(|&i| Some(i) != vec_col_idx) + .filter(|&i| { + i == key_col_idx + || filter_col_names.contains(scan_schema.field(i).name().as_str()) + }) .collect(); let pre_scan = registered @@ -160,22 +190,24 @@ impl ExtensionPlanner for USearchExecPlanner { .scan(session_state, Some(&scalar_projection), &node.filters, None) .await?; - // Full scan: all columns (vector included for distance computation). - // Only created if the vector column exists in the scan provider. - let full = if vec_col_idx.is_some() { - Some( - registered - .scan_provider - .scan(session_state, None, &node.filters, None) - .await?, - ) - } else { - None - }; - - (Some(pre_scan), full) + // Compile physical filters against the pre-scan's projected schema + // so column indices match the narrower batch layout. + let pre_scan_schema = pre_scan.schema(); + let pre_scan_df_schema = + datafusion::common::DFSchema::try_from(pre_scan_schema.as_ref().clone())?; + let prescan_physical_filters: Vec> = node + .filters + .iter() + .map(|f| { + // Strip table qualifiers — the projected schema is unqualified. + let unqualified = strip_column_qualifier(f); + create_physical_expr(&unqualified, &pre_scan_df_schema, exec_props) + }) + .collect::>()?; + + (Some(pre_scan), prescan_physical_filters) } else { - (None, None) + (None, vec![]) }; Ok(Some(Arc::new(USearchExec::new(SearchParams { @@ -185,13 +217,12 @@ impl ExtensionPlanner for USearchExecPlanner { k: node.k, distance_type: node.distance_type.clone(), physical_filters, + prescan_filters, schema: registered.schema.clone(), key_col: registered.key_col.clone(), scalar_kind: registered.scalar_kind, - vector_col: node.vector_col.clone(), brute_force_threshold: registered.config.brute_force_selectivity_threshold, provider_scan, - full_scan, })))) } } @@ -207,17 +238,17 @@ struct SearchParams { k: usize, distance_type: DistanceType, physical_filters: Vec>, + /// Physical filters compiled against the pre-scan's projected schema. + /// Column indices match the narrow _key + filter-col projection, not the + /// full table schema. Used by adaptive_filtered_execute for pre-scan evaluation. + prescan_filters: Vec>, schema: SchemaRef, key_col: String, scalar_kind: ScalarKind, - vector_col: String, brute_force_threshold: f64, - /// Pre-planned provider scan for the filtered path (scalar + _key only, no vector col). + /// Pre-planned provider scan for the filtered path (_key + filter cols only). /// Used for selectivity estimation. None for the unfiltered path. provider_scan: Option>, - /// Pre-planned full scan (all cols including vector) for the low-selectivity - /// Parquet-native path. None when unfiltered or when no vector column exists. - full_scan: Option>, } // ── Physical execution node ─────────────────────────────────────────────────── @@ -264,14 +295,10 @@ impl ExecutionPlan for USearchExec { &self.properties } fn children(&self) -> Vec<&Arc> { - let mut children = Vec::new(); - if let Some(ref scan) = self.params.provider_scan { - children.push(scan); + match self.params.provider_scan { + Some(ref scan) => vec![scan], + None => vec![], } - if let Some(ref full) = self.params.full_scan { - children.push(full); - } - children } fn with_new_children( @@ -286,12 +313,8 @@ impl ExecutionPlan for USearchExec { ))); } let mut params = self.params.clone(); - let mut iter = children.into_iter(); if params.provider_scan.is_some() { - params.provider_scan = Some(iter.next().unwrap()); - } - if params.full_scan.is_some() { - params.full_scan = Some(iter.next().unwrap()); + params.provider_scan = Some(children.into_iter().next().unwrap()); } Ok(Arc::new(USearchExec::new(params))) } @@ -434,7 +457,7 @@ async fn adaptive_filtered_execute( async { while let Some(batch_result) = stream.next().await { let batch = batch_result?; - let mask = evaluate_filters(¶ms.physical_filters, &batch)?; + let mask = evaluate_filters(¶ms.prescan_filters, &batch)?; let keys = extract_keys_as_u64(batch.column(scan_key_col_idx).as_ref())?; for row_idx in 0..batch.num_rows() { @@ -461,10 +484,9 @@ async fn adaptive_filtered_execute( let total = registered.index.size(); let selectivity = valid_keys.len() as f64 / total.max(1) as f64; let threshold = params.brute_force_threshold; - let has_full_scan = params.full_scan.is_some(); - let path = if selectivity <= threshold && has_full_scan { - "parquet-native" + let path = if selectivity <= threshold { + "index-get" } else { "filtered_search" }; @@ -473,12 +495,64 @@ async fn adaptive_filtered_execute( tracing::Span::current().record("usearch.selectivity", selectivity); tracing::Span::current().record("usearch.path", path); - if selectivity <= threshold && has_full_scan { - // ── Low-selectivity: Parquet-native path (no USearch, no SQLite) ── - // Second scan reads all columns including vector. Compute distances - // inline, maintain top-k heap, build result directly from scan data. - let full_scan = params.full_scan.clone().unwrap(); - parquet_native_execute(params, registered, full_scan, task_ctx).await + if selectivity <= threshold { + // ── Low-selectivity: retrieve vectors from USearch index ───────── + // The index stores vectors alongside the graph. Retrieve them by key, + // compute exact distances, keep top-k, then fetch result rows from + // the lookup provider. This avoids the expensive full Parquet scan + // that the previous parquet-native path required. + let top_keys = { + let _span = tracing::info_span!( + "usearch_index_get_distances", + usearch.valid_keys = valid_keys.len(), + ) + .entered(); + index_get_top_k( + ®istered.index, + &valid_keys, + ¶ms.query_vec, + params.k, + registered.scalar_kind, + ¶ms.distance_type, + )? + }; + + if top_keys.is_empty() { + tracing::Span::current().record("usearch.result_count", 0usize); + return Ok(vec![]); + } + + let fetch_keys: Vec = top_keys.iter().map(|&(k, _)| k).collect(); + let key_to_dist: HashMap = top_keys.into_iter().collect(); + + let fetch_keys_count = fetch_keys.len(); + let data_batches = async { + registered + .lookup_provider + .fetch_by_keys(&fetch_keys, ¶ms.key_col, None) + .await + } + .instrument(tracing::info_span!( + "usearch_sqlite_fetch", + usearch.fetch_keys = fetch_keys_count, + )) + .await?; + + let result_batches = { + let _span = tracing::info_span!("usearch_attach_distances").entered(); + attach_distances( + data_batches, + lookup_key_col_idx, + &key_to_dist, + ¶ms.schema, + )? + }; + + tracing::Span::current().record( + "usearch.result_count", + result_batches.iter().map(|b| b.num_rows()).sum::(), + ); + Ok(result_batches) } else { // ── High-selectivity: HNSW filtered_search + SQLite fetch ───────── let matches = tracing::info_span!( @@ -538,164 +612,6 @@ async fn adaptive_filtered_execute( } } -// ── Parquet-native low-selectivity execution ────────────────────────────────── - -/// Execute the low-selectivity path entirely from Parquet — no USearch, no SQLite. -/// -/// Streams the full scan (all columns including vector), evaluates WHERE filters, -/// computes distances for passing rows, maintains a top-k heap, and builds the -/// result directly from scan data. -#[tracing::instrument( - name = "usearch_parquet_native", - skip_all, - fields(usearch.table = %params.table_name) -)] -async fn parquet_native_execute( - params: &SearchParams, - registered: &crate::registry::RegisteredTable, - full_scan: Arc, - task_ctx: Arc, -) -> Result> { - let full_schema = full_scan.schema(); - let vec_col_idx = full_schema.index_of(¶ms.vector_col).map_err(|_| { - DataFusionError::Execution(format!( - "USearchExec: vector column '{}' not found in full scan schema", - params.vector_col - )) - })?; - - // Map each lookup_provider field to its index in the full scan schema by name. - // This avoids silent column mismatches if the two schemas have different orderings. - let lookup_schema = registered.lookup_provider.schema(); - let output_col_indices: Vec = lookup_schema - .fields() - .iter() - .map(|f| { - full_schema.index_of(f.name()).map_err(|_| { - DataFusionError::Execution(format!( - "USearchExec: column '{}' from lookup schema not found in full scan schema", - f.name() - )) - }) - }) - .collect::>()?; - - // Top-k heap: stores (distance, projected_row_slice). - // At low selectivity (<=5%), the number of passing rows is small. - let mut heap: BinaryHeap = BinaryHeap::with_capacity(params.k + 1); - - let mut stream = full_scan.execute(0, task_ctx)?; - while let Some(batch_result) = stream.next().await { - let batch = batch_result?; - let mask = evaluate_filters(¶ms.physical_filters, &batch)?; - - for row_idx in 0..batch.num_rows() { - if mask.is_null(row_idx) || !mask.value(row_idx) { - continue; - } - - let dist = match compute_distance_for_row( - &batch, - vec_col_idx, - row_idx, - ¶ms.query_vec, - registered.scalar_kind, - ¶ms.distance_type, - ) { - Ok(d) if !d.is_nan() => d, - _ => continue, // skip null vectors and NaN distances - }; - - // Project the row to output columns (drop vector col), zero-copy slice. - let row_cols: Vec> = output_col_indices - .iter() - .map(|&i| batch.column(i).slice(row_idx, 1)) - .collect(); - - heap.push(ScoredRow { - distance: dist, - row: row_cols, - }); - if heap.len() > params.k { - heap.pop(); // evict farthest - } - } - } - - if heap.is_empty() { - tracing::Span::current().record("usearch.result_count", 0usize); - return Ok(vec![]); - } - - // Build result: sort by distance ascending, concat into batches. - let mut entries: Vec = heap.into_vec(); - entries.sort_by(|a, b| { - a.distance - .partial_cmp(&b.distance) - .unwrap_or(std::cmp::Ordering::Equal) - }); - - let n = entries.len(); - let mut result_cols: Vec>> = - vec![Vec::with_capacity(n); output_col_indices.len()]; - let mut distances: Vec = Vec::with_capacity(n); - - for entry in &entries { - for (col_idx, col_slice) in entry.row.iter().enumerate() { - result_cols[col_idx].push(col_slice.clone()); - } - distances.push(entry.distance); - } - - // Concatenate per-column arrays. - let concat_cols: Vec> = result_cols - .into_iter() - .map(|slices| { - let refs: Vec<&dyn Array> = slices.iter().map(|a| a.as_ref()).collect(); - datafusion::arrow::compute::concat(&refs) - .map_err(|e| DataFusionError::ArrowError(Box::new(e), None)) - }) - .collect::>()?; - - // Append _distance column. - let mut all_cols = concat_cols; - all_cols.push(Arc::new(Float32Array::from(distances))); - - let result_batch = RecordBatch::try_new(params.schema.clone(), all_cols) - .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?; - - tracing::Span::current().record("usearch.result_count", result_batch.num_rows()); - Ok(vec![result_batch]) -} - -/// A row with its computed distance, for the top-k heap. -/// The max-heap evicts the *farthest* row when it exceeds k. -struct ScoredRow { - distance: f32, - /// Column arrays (one per output column), each a single-row slice. - row: Vec>, -} - -impl PartialEq for ScoredRow { - fn eq(&self, other: &Self) -> bool { - self.distance == other.distance - } -} -impl Eq for ScoredRow {} -impl PartialOrd for ScoredRow { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} -impl Ord for ScoredRow { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - // Max-heap: largest distance at top → gets evicted first. - self.distance - .partial_cmp(&other.distance) - .unwrap_or(std::cmp::Ordering::Less) - } -} - // ── USearch dispatch helpers ────────────────────────────────────────────────── /// Call `index.search` with the native scalar type appropriate for the column. @@ -744,6 +660,138 @@ where } } +/// Retrieve vectors from the USearch index for each valid key, compute exact +/// distances against the query vector, and return the top-k (key, distance) pairs. +/// +/// This is the low-selectivity path: when few rows pass the WHERE filter, it is +/// cheaper to fetch vectors by key from the index (O(1) per key) than to scan the +/// entire Parquet vector column. +/// +/// For `F64` scalar kind, vectors are retrieved and distances computed in f64. +/// For all other kinds (F32, F16, BF16, I8, B1), vectors are retrieved as f32 +/// (USearch dequantizes internally) and distances computed in f32. +fn index_get_top_k( + index: &usearch::Index, + valid_keys: &HashSet, + query_f64: &[f64], + k: usize, + scalar_kind: ScalarKind, + dist_type: &DistanceType, +) -> Result> { + let dim = index.dimensions(); + let mut heap: BinaryHeap = BinaryHeap::with_capacity(k + 1); + + match scalar_kind { + ScalarKind::F64 => { + let mut buf = vec![0.0f64; dim]; + for &key in valid_keys { + let n = index + .get(key, &mut buf) + .map_err(|e| DataFusionError::Execution(format!("index.get({key}): {e}")))?; + if n == 0 { + continue; // key not found in index (e.g. null vector was skipped during build) + } + let dist = compute_raw_distance_f64(&buf, query_f64, dist_type); + if dist.is_nan() { + continue; + } + heap.push(ScoredKey { + distance: dist, + key, + }); + if heap.len() > k { + heap.pop(); + } + } + } + _ => { + let query_f32: Vec = query_f64.iter().map(|&v| v as f32).collect(); + let mut buf = vec![0.0f32; dim]; + for &key in valid_keys { + let n = index + .get(key, &mut buf) + .map_err(|e| DataFusionError::Execution(format!("index.get({key}): {e}")))?; + if n == 0 { + continue; + } + let dist = compute_raw_distance_f32(&buf, &query_f32, dist_type); + if dist.is_nan() { + continue; + } + heap.push(ScoredKey { + distance: dist, + key, + }); + if heap.len() > k { + heap.pop(); + } + } + } + } + + let mut result: Vec<(u64, f32)> = heap + .into_vec() + .into_iter() + .map(|s| (s.key, s.distance)) + .collect(); + result.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)); + Ok(result) +} + +/// A key with its computed distance, for the top-k heap. +struct ScoredKey { + distance: f32, + key: u64, +} + +impl PartialEq for ScoredKey { + fn eq(&self, other: &Self) -> bool { + self.distance == other.distance + } +} +impl Eq for ScoredKey {} +impl PartialOrd for ScoredKey { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} +impl Ord for ScoredKey { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.distance + .partial_cmp(&other.distance) + .unwrap_or(std::cmp::Ordering::Less) + } +} + +fn compute_raw_distance_f32(v: &[f32], q: &[f32], dist_type: &DistanceType) -> f32 { + match dist_type { + DistanceType::L2 => v.iter().zip(q).map(|(a, b)| (a - b) * (a - b)).sum(), + DistanceType::Cosine => { + let dot: f32 = v.iter().zip(q).map(|(a, b)| a * b).sum(); + let norm_v: f32 = v.iter().map(|x| x * x).sum::().sqrt(); + let norm_q: f32 = q.iter().map(|x| x * x).sum::().sqrt(); + let denom = norm_v * norm_q; + if denom == 0.0 { 1.0 } else { 1.0 - dot / denom } + } + DistanceType::NegativeDot => -v.iter().zip(q).map(|(a, b)| a * b).sum::(), + } +} + +fn compute_raw_distance_f64(v: &[f64], q: &[f64], dist_type: &DistanceType) -> f32 { + let d = match dist_type { + DistanceType::L2 => v.iter().zip(q).map(|(a, b)| (a - b) * (a - b)).sum::(), + DistanceType::Cosine => { + let dot: f64 = v.iter().zip(q).map(|(a, b)| a * b).sum(); + let norm_v: f64 = v.iter().map(|x| x * x).sum::().sqrt(); + let norm_q: f64 = q.iter().map(|x| x * x).sum::().sqrt(); + let denom = norm_v * norm_q; + if denom == 0.0 { 1.0 } else { 1.0 - dot / denom } + } + DistanceType::NegativeDot => -v.iter().zip(q).map(|(a, b)| a * b).sum::(), + }; + d as f32 +} + // ── Helpers ─────────────────────────────────────────────────────────────────── /// AND all physical filter expressions against a batch. @@ -781,105 +829,6 @@ fn evaluate_filters( /// Extract the distance from a single row of a vector column. /// -/// Handles all combinations of outer array type (FixedSizeList / List / LargeList) -/// and inner element type (Float32 / Float64). The distance is always returned as -/// `f32` — matching the `_distance` column type — regardless of the column's native -/// precision. Query is accepted as `f64` and cast to the column's native type. -fn compute_distance_for_row( - batch: &RecordBatch, - vec_col_idx: usize, - row_idx: usize, - query_f64: &[f64], - scalar_kind: ScalarKind, - dist_type: &DistanceType, -) -> Result { - let col = batch.column(vec_col_idx); - - if col.is_null(row_idx) { - return Err(DataFusionError::Execution( - "null vector in brute-force distance computation".into(), - )); - } - - // Extract the row's inner array, regardless of outer type. - let row_arr: Arc = - if let Some(fsl) = col.as_any().downcast_ref::() { - fsl.value(row_idx) - } else if let Some(la) = col.as_any().downcast_ref::() { - la.value(row_idx) - } else if let Some(la) = col.as_any().downcast_ref::() { - la.value(row_idx) - } else { - return Err(DataFusionError::Execution(format!( - "vector column type not supported in brute-force path (got {:?})", - col.data_type() - ))); - }; - - // Dispatch distance computation by the column's native element type. - match scalar_kind { - ScalarKind::F64 => { - let f64_arr = row_arr - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Execution("F64 column: inner array is not Float64Array".into()) - })?; - let v = f64_arr.values(); - let query = query_f64; - let dist = match dist_type { - DistanceType::L2 => v - .iter() - .zip(query) - .map(|(a, b)| (a - b) * (a - b)) - .sum::(), - DistanceType::Cosine => { - let dot: f64 = v.iter().zip(query).map(|(a, b)| a * b).sum(); - let norm_v: f64 = v.iter().map(|x| x * x).sum::().sqrt(); - let norm_q: f64 = query.iter().map(|x| x * x).sum::().sqrt(); - let denom = norm_v * norm_q; - if denom == 0.0 { 1.0 } else { 1.0 - dot / denom } - } - DistanceType::NegativeDot => -v.iter().zip(query).map(|(a, b)| a * b).sum::(), - }; - Ok(dist as f32) - } - _ => { - // F32 (and any other kind): extract as f32, cast query to f32. - let f32_arr = row_arr - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Execution(format!( - "F32 column: inner array is not Float32Array (got {:?})", - row_arr.data_type() - )) - })?; - let v = f32_arr.values(); - let query: Vec = query_f64.iter().map(|&x| x as f32).collect(); - let dist = match dist_type { - // L2sq — matches USearch MetricKind::L2sq (no sqrt). - DistanceType::L2 => v - .iter() - .zip(&query) - .map(|(a, b)| (a - b) * (a - b)) - .sum::(), - // Cosine distance = 1 - cosine_similarity. - DistanceType::Cosine => { - let dot: f32 = v.iter().zip(&query).map(|(a, b)| a * b).sum(); - let norm_v: f32 = v.iter().map(|x| x * x).sum::().sqrt(); - let norm_q: f32 = query.iter().map(|x| x * x).sum::().sqrt(); - let denom = norm_v * norm_q; - if denom == 0.0 { 1.0 } else { 1.0 - dot / denom } - } - // Negative inner product — matches USearch MetricKind::IP. - DistanceType::NegativeDot => -v.iter().zip(&query).map(|(a, b)| a * b).sum::(), - }; - Ok(dist) - } - } -} - /// Index of the key column in the lookup provider schema. fn provider_key_col_idx(registered: &crate::registry::RegisteredTable) -> Result { registered