Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 57 additions & 89 deletions src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,19 @@
// 3. Attach _distance column.
//
// ── With filters, high selectivity (> threshold) ─────────────────────────────
// 1. Pre-scan: scan_provider with projection (scalar + _key cols only,
// no vector column) and filter pushdown. Collect valid_keys.
// 1. Pre-scan: CoalescePartitionsExec → FilterExec → DataSourceExec
// (_key + filter cols only). Collect valid_keys from all partitions.
// 2. selectivity = valid_keys.len() / index.size()
// 3. filtered_search(query, k, |key| valid_keys.contains(key))
// 4. lookup_provider.fetch_by_keys() → O(k) rows. Attach _distance.
//
// ── With filters, low selectivity (≤ threshold) — Parquet-native ─────────────
// ── With filters, low selectivity (≤ threshold) — index-get ──────────────────
// 1. Pre-scan: same as above, collect valid_keys and compute selectivity.
// 2. Full scan: scan_provider with all columns (including vector) and
// filter pushdown. Evaluate WHERE per batch, compute distances for
// passing rows, maintain top-k heap. Return directly — no USearch,
// no lookup_provider.
// 2. index.get(key) for each valid_key → compute distances → top-k heap.
// 3. lookup_provider.fetch_by_keys() → O(k) rows. Attach _distance.
//
// All I/O is deferred to USearchExec::execute() — plan_extension is purely
// structural (validate registry entry, compile PhysicalExprs, build scan plans).
// structural (validate registry, compile PhysicalExprs, build scan plans).
//
// The Sort node is kept in the logical plan so DataFusion handles ordering
// by _distance / dist alias.
Expand All @@ -34,16 +32,20 @@ use std::collections::{BinaryHeap, HashMap, HashSet};
use std::fmt;
use std::sync::Arc;

use arrow_array::{Array, BooleanArray, Float32Array, RecordBatch};
use arrow_array::{Array, Float32Array, RecordBatch};
use arrow_schema::SchemaRef;
use async_trait::async_trait;
use datafusion::common::Result;
use datafusion::error::DataFusionError;
use datafusion::execution::context::QueryPlanner;
use datafusion::execution::{SendableRecordBatchStream, SessionState, TaskContext};
use datafusion::logical_expr::{LogicalPlan, UserDefinedLogicalNode};
use datafusion::physical_expr::{EquivalenceProperties, PhysicalExpr, create_physical_expr};
use datafusion::physical_expr::{
EquivalenceProperties, PhysicalExpr, conjunction, create_physical_expr,
};
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
use datafusion::physical_plan::filter::FilterExec;
use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
use datafusion::physical_plan::{
DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
Expand Down Expand Up @@ -149,18 +151,16 @@ impl ExtensionPlanner for USearchExecPlanner {
}
};

// Compile filter Exprs → PhysicalExprs (synchronous, no I/O).
let exec_props = session_state.execution_props();
let physical_filters: Vec<Arc<dyn PhysicalExpr>> = node
.filters
.iter()
.map(|f| create_physical_expr(f, &node.schema, exec_props))
.collect::<Result<_>>()?;

// For the filtered path, pre-plan a provider scan:
// Pre-scan (_key + filter cols only) — cheap selectivity estimation.
// The scan receives the node's WHERE filters for Parquet predicate pushdown.
let (provider_scan, prescan_filters) = if !node.filters.is_empty() {
// For the filtered path, build a pre-scan plan:
// CoalescePartitionsExec → FilterExec → DataSourceExec
// DataSourceExec may have multiple partitions (file groups); FilterExec
// evaluates the predicate per partition; CoalescePartitionsExec merges
// all partitions into a single stream of matching rows.
// DataFusion's physical optimizer pushes the predicate from FilterExec
// into the Parquet reader for row group / bloom / page index pruning.
let provider_scan = if !node.filters.is_empty() {
let scan_schema = registered.scan_provider.schema();

// Pre-scan projection: _key + columns referenced by filters.
Expand All @@ -185,29 +185,40 @@ impl ExtensionPlanner for USearchExecPlanner {
})
.collect();

let pre_scan = registered
// Don't pass filters to scan() — FilterExec handles filtering, and
// DataFusion's physical optimizer pushes it into the Parquet reader
// for row group / bloom / page index pruning.
let data_source = registered
.scan_provider
.scan(session_state, Some(&scalar_projection), &node.filters, None)
.scan(session_state, Some(&scalar_projection), &[], None)
.await?;

// Compile physical filters against the pre-scan's projected schema
// so column indices match the narrower batch layout.
let pre_scan_schema = pre_scan.schema();
let pre_scan_df_schema =
datafusion::common::DFSchema::try_from(pre_scan_schema.as_ref().clone())?;
let prescan_physical_filters: Vec<Arc<dyn PhysicalExpr>> = node
// Compile physical filters against the projected schema and wrap
// in a FilterExec. Column qualifiers are stripped because the
// projected schema (from Arrow Schema) is unqualified.
let proj_schema = data_source.schema();
let proj_df_schema =
datafusion::common::DFSchema::try_from(proj_schema.as_ref().clone())?;
let phys_filters: Vec<Arc<dyn PhysicalExpr>> = node
.filters
.iter()
.map(|f| {
// Strip table qualifiers — the projected schema is unqualified.
let unqualified = strip_column_qualifier(f);
create_physical_expr(&unqualified, &pre_scan_df_schema, exec_props)
create_physical_expr(&unqualified, &proj_df_schema, exec_props)
})
.collect::<Result<_>>()?;
let predicate = conjunction(phys_filters);
let filtered: Arc<dyn ExecutionPlan> =
Arc::new(FilterExec::try_new(predicate, data_source)?);

// Merge all partitions into a single stream so the pre-scan
// collects valid keys from the entire dataset, not just one
// partition's file group.
let coalesced: Arc<dyn ExecutionPlan> = Arc::new(CoalescePartitionsExec::new(filtered));

(Some(pre_scan), prescan_physical_filters)
Some(coalesced)
} else {
(None, vec![])
None
};

Ok(Some(Arc::new(USearchExec::new(SearchParams {
Expand All @@ -216,8 +227,7 @@ impl ExtensionPlanner for USearchExecPlanner {
query_vec: node.query_vec_f64(),
k: node.k,
distance_type: node.distance_type.clone(),
physical_filters,
prescan_filters,
has_filters: !node.filters.is_empty(),
schema: registered.schema.clone(),
key_col: registered.key_col.clone(),
scalar_kind: registered.scalar_kind,
Expand All @@ -237,11 +247,9 @@ struct SearchParams {
query_vec: Vec<f64>,
k: usize,
distance_type: DistanceType,
physical_filters: Vec<Arc<dyn PhysicalExpr>>,
/// Physical filters compiled against the pre-scan's projected schema.
/// Column indices match the narrow _key + filter-col projection, not the
/// full table schema. Used by adaptive_filtered_execute for pre-scan evaluation.
prescan_filters: Vec<Arc<dyn PhysicalExpr>>,
/// Whether the query has WHERE-clause filters. Used to choose between the
/// unfiltered HNSW path and the adaptive filtered path.
has_filters: bool,
schema: SchemaRef,
key_col: String,
scalar_kind: ScalarKind,
Expand Down Expand Up @@ -276,10 +284,8 @@ impl DisplayAs for USearchExec {
fn fmt_as(&self, _t: DisplayFormatType, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"USearchExec: table={}, k={}, filters={}",
self.params.table_name,
self.params.k,
self.params.physical_filters.len()
"USearchExec: table={}, k={}, filtered={}",
self.params.table_name, self.params.k, self.params.has_filters
)
}
}
Expand Down Expand Up @@ -346,7 +352,7 @@ impl ExecutionPlan for USearchExec {
fields(
usearch.table = %params.table_name,
usearch.k = params.k,
usearch.filter_count = params.physical_filters.len(),
usearch.has_filters = params.has_filters,
)
)]
async fn usearch_execute(
Expand All @@ -361,7 +367,7 @@ async fn usearch_execute(
))
})?;

if params.physical_filters.is_empty() {
if !params.has_filters {
// ── Unfiltered path ───────────────────────────────────────────────
let matches = {
let _span = tracing::info_span!(
Expand Down Expand Up @@ -424,7 +430,7 @@ async fn usearch_execute(
fields(
usearch.table = %params.table_name,
usearch.k = params.k,
usearch.filter_count = params.physical_filters.len(),
usearch.has_filters = params.has_filters,
usearch.valid_rows = tracing::field::Empty,
usearch.total_rows = tracing::field::Empty,
usearch.selectivity = tracing::field::Empty,
Expand All @@ -449,24 +455,19 @@ async fn adaptive_filtered_execute(
// Key column index in lookup_provider schema — used by attach_distances (high-sel path).
let lookup_key_col_idx = provider_key_col_idx(registered)?;

// ── Phase 1: Pre-scan (scalar + _key only) for selectivity estimation ────
// ── Phase 1: Pre-scan for selectivity estimation ───────────────────────
// The scan_plan is CoalescePartitionsExec → FilterExec → DataSourceExec,
// so execute(0) yields already-filtered rows from all partitions.
let mut stream = scan_plan.execute(0, task_ctx.clone())?;
let mut valid_keys: HashSet<u64> = HashSet::new();

let scan_span = tracing::info_span!("usearch_pre_scan", usearch.table = %params.table_name);
async {
while let Some(batch_result) = stream.next().await {
let batch = batch_result?;
let mask = evaluate_filters(&params.prescan_filters, &batch)?;
let keys = extract_keys_as_u64(batch.column(scan_key_col_idx).as_ref())?;

for row_idx in 0..batch.num_rows() {
if !mask.is_null(row_idx)
&& mask.value(row_idx)
&& let Some(Some(key)) = keys.get(row_idx)
{
valid_keys.insert(*key);
}
for key in keys.into_iter().flatten() {
valid_keys.insert(key);
}
}
Ok::<_, datafusion::error::DataFusionError>(())
Expand Down Expand Up @@ -794,39 +795,6 @@ fn compute_raw_distance_f64(v: &[f64], q: &[f64], dist_type: &DistanceType) -> f

// ── Helpers ───────────────────────────────────────────────────────────────────

/// Evaluate every physical filter expression against `batch` and combine
/// the results with a logical AND.
/// Returns one boolean per row; a row passes only when all filters are true.
fn evaluate_filters(
    filters: &[Arc<dyn PhysicalExpr>],
    batch: &RecordBatch,
) -> Result<BooleanArray> {
    use datafusion::arrow::compute;

    let row_count = batch.num_rows();

    // Evaluate one filter and downcast its result to a BooleanArray.
    let eval_one = |expr: &Arc<dyn PhysicalExpr>| -> Result<BooleanArray> {
        let arr = expr.evaluate(batch)?.into_array(row_count)?;
        match arr.as_any().downcast_ref::<BooleanArray>() {
            Some(bools) => Ok(bools.clone()),
            None => Err(DataFusionError::Execution(
                "filter expression did not return BooleanArray".into(),
            )),
        }
    };

    // Fold the per-filter masks together with AND. The accumulator stays
    // None until the first filter is evaluated.
    let combined = filters
        .iter()
        .try_fold(None::<BooleanArray>, |acc, f| -> Result<Option<BooleanArray>> {
            let next = eval_one(f)?;
            Ok(Some(match acc {
                None => next,
                Some(prev) => compute::and(&prev, &next)
                    .map_err(|e| DataFusionError::ArrowError(Box::new(e), None))?,
            }))
        })?;

    // With no filters every row passes.
    Ok(combined.unwrap_or_else(|| BooleanArray::from(vec![true; row_count])))
}

/// Extract the distance from a single row of a vector column.
///
/// Index of the key column in the lookup provider schema.
Expand Down
Loading