diff --git a/src/udf.rs b/src/udf.rs index 233fcc1..32652fe 100644 --- a/src/udf.rs +++ b/src/udf.rs @@ -139,14 +139,23 @@ fn inner_to_f32(inner: &dyn Array, udf_name: &str) -> Result<Vec<f32>> { /// /// Supports all outer array types (FixedSizeList, List, LargeList) and /// inner element types (Float32, Float64 — Float64 is cast to f32 for the kernel). +/// +/// Returns an error if vector dimensions do not match the query length. fn compute_distances( vec_col: &dyn Array, query_vec: &[f32], kernel: Kernel, udf_name: &str, ) -> Result<Vec<Option<f32>>> { - // FixedSizeListArray — typical for DuckDB FLOAT[N] or pre-cast columns + // FixedSizeListArray — dimension known from type, validate once upfront. if let Some(fsl) = vec_col.as_any().downcast_ref::<FixedSizeListArray>() { + let col_dim = fsl.value_length() as usize; + if col_dim != query_vec.len() { + return Err(DataFusionError::Execution(format!( + "{udf_name}: query vector length ({}) must match column dimensionality ({col_dim})", + query_vec.len(), + ))); + } let mut out = Vec::with_capacity(fsl.len()); for i in 0..fsl.len() { if fsl.is_null(i) { @@ -159,7 +168,7 @@ fn compute_distances( return Ok(out); } - // ListArray — variable-length, e.g. PostgreSQL real[] / float8[] + // ListArray — variable-length, validate per row. if let Some(lst) = vec_col.as_any().downcast_ref::<ListArray>() { let mut out = Vec::with_capacity(lst.len()); for i in 0..lst.len() { @@ -168,12 +177,19 @@ fn compute_distances( continue; } let f32s = inner_to_f32(&*lst.value(i), udf_name)?; + if f32s.len() != query_vec.len() { + return Err(DataFusionError::Execution(format!( + "{udf_name}: query vector length ({}) must match row {i} dimensionality ({})", + query_vec.len(), + f32s.len(), + ))); + } out.push(Some(kernel(&f32s, query_vec))); } return Ok(out); } - // LargeListArray — large-offset variant, e.g. some Postgres/Parquet encodings + // LargeListArray — large-offset variant, validate per row. 
if let Some(lst) = vec_col.as_any().downcast_ref::<LargeListArray>() { let mut out = Vec::with_capacity(lst.len()); for i in 0..lst.len() { @@ -182,6 +198,13 @@ fn compute_distances( continue; } let f32s = inner_to_f32(&*lst.value(i), udf_name)?; + if f32s.len() != query_vec.len() { + return Err(DataFusionError::Execution(format!( + "{udf_name}: query vector length ({}) must match row {i} dimensionality ({})", + query_vec.len(), + f32s.len(), + ))); + } out.push(Some(kernel(&f32s, query_vec))); } return Ok(out); diff --git a/tests/execution.rs b/tests/execution.rs index 542b5fe..5fa7c33 100644 --- a/tests/execution.rs +++ b/tests/execution.rs @@ -458,6 +458,274 @@ async fn exec_parquet_native_where_no_matches() { // Numeric regression — l2_distance must return L2sq (no sqrt) // ═══════════════════════════════════════════════════════════════════════════════ +// ═══════════════════════════════════════════════════════════════════════════════ +// Split-provider tests — lookup_provider WITHOUT vector column +// +// In production, lookup_provider (SQLite) does NOT have the vector column. +// These tests verify that the USearch optimized path fires and works correctly +// when selecting specific columns (not SELECT *). +// ═══════════════════════════════════════════════════════════════════════════════ + +/// Build context with split providers (scan has vector, lookup doesn't) +/// and default brute_force threshold. 
+async fn make_split_provider_ctx(reg_key: &str) -> SessionContext { + let schema = exec_schema(); + let batch = test_batch(&schema); + + // scan_provider: full schema including vector column (simulates Parquet) + let scan_provider = Arc::new( + HashKeyProvider::try_new(schema.clone(), vec![batch.clone()], "id") + .expect("scan HashKeyProvider"), + ); + + // lookup_provider: no vector column (simulates SQLite) + let lookup_batch = { + let ids = batch.column(0).clone(); + let labels = batch.column(1).clone(); + RecordBatch::try_new(lookup_schema(), vec![ids, labels]).expect("lookup batch") + }; + let lookup_provider = Arc::new( + HashKeyProvider::try_new(lookup_schema(), vec![lookup_batch], "id") + .expect("lookup HashKeyProvider"), + ); + + let reg = USearchRegistry::new(); + reg.add( + reg_key, + make_populated_index(), + scan_provider, + lookup_provider, + "id", + MetricKind::L2sq, + ScalarKind::F32, + ) + .expect("reg.add"); + let registry = reg.into_arc(); + + let state = SessionStateBuilder::new() + .with_default_features() + .with_query_planner(Arc::new(USearchQueryPlanner::new(registry.clone()))) + .build(); + let ctx = SessionContext::new_with_state(state); + register_all(&ctx, registry).expect("register_all"); + + let table_provider = Arc::new( + HashKeyProvider::try_new(exec_schema(), vec![test_batch(&exec_schema())], "id") + .expect("table HashKeyProvider"), + ); + ctx.register_table("items", table_provider) + .expect("register_table"); + ctx +} + +/// SELECT specific columns (no vector) with distance UDF — must use USearch path. +/// This is the exact pattern that fails in production while SELECT * works. 
+#[tokio::test] +async fn exec_split_provider_select_specific_columns() { + let ctx = make_split_provider_ctx("items::vector").await; + let sql = + format!("SELECT id, l2_distance(vector, {Q}) AS dist FROM items ORDER BY dist ASC LIMIT 2"); + let ids = collect_ids(&ctx, &sql).await; + assert_eq!(ids[0], 1, "closest must be row 1\nids: {ids:?}"); + assert_eq!(ids.len(), 2, "expected 2 results; got {ids:?}"); +} + +/// SELECT * with distance UDF — should fall back to UDF brute-force +/// (since vector column is not in lookup provider schema). +#[tokio::test] +async fn exec_split_provider_select_star() { + let ctx = make_split_provider_ctx("items::vector").await; + let sql = + format!("SELECT *, l2_distance(vector, {Q}) AS dist FROM items ORDER BY dist ASC LIMIT 2"); + let df = ctx.sql(&sql).await.expect("sql"); + let batches = df.collect().await.expect("collect"); + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 2, "expected 2 results"); +} + +/// SELECT specific columns with fully qualified table name. +#[tokio::test] +async fn exec_split_provider_qualified_select_specific() { + let ctx = make_split_provider_ctx("datafusion::public::items::vector").await; + let sql = format!( + "SELECT id, l2_distance(vector, {Q}) AS dist FROM datafusion.public.items ORDER BY dist ASC LIMIT 2" + ); + let ids = collect_ids(&ctx, &sql).await; + assert_eq!(ids[0], 1, "closest must be row 1\nids: {ids:?}"); +} + +/// negative_dot_product with split providers and IP metric — mirrors production setup. 
+#[tokio::test] +async fn exec_split_provider_negative_dot_product() { + let schema = exec_schema(); + let batch = test_batch(&schema); + + let scan_provider = Arc::new( + HashKeyProvider::try_new(schema.clone(), vec![batch.clone()], "id") + .expect("scan HashKeyProvider"), + ); + + let lookup_batch = { + let ids = batch.column(0).clone(); + let labels = batch.column(1).clone(); + RecordBatch::try_new(lookup_schema(), vec![ids, labels]).expect("lookup batch") + }; + let lookup_provider = Arc::new( + HashKeyProvider::try_new(lookup_schema(), vec![lookup_batch], "id") + .expect("lookup HashKeyProvider"), + ); + + // Build IP-metric index + let opts = IndexOptions { + dimensions: 4, + metric: MetricKind::IP, + quantization: ScalarKind::F32, + ..Default::default() + }; + let index = Arc::new(Index::new(&opts).expect("Index::new")); + index.reserve(4).expect("reserve"); + let rows: &[(u64, [f32; 4])] = &[ + (1, [1.0, 0.0, 0.0, 0.0]), + (2, [0.0, 1.0, 0.0, 0.0]), + (3, [0.0, 0.0, 1.0, 0.0]), + (4, [0.0, 0.0, 0.0, 1.0]), + ]; + for &(key, ref v) in rows { + index.add(key, v.as_slice()).expect("index.add"); + } + + let reg = USearchRegistry::new(); + reg.add( + "items::vector", + index, + scan_provider, + lookup_provider, + "id", + MetricKind::IP, + ScalarKind::F32, + ) + .expect("reg.add"); + let registry = reg.into_arc(); + + let state = SessionStateBuilder::new() + .with_default_features() + .with_query_planner(Arc::new(USearchQueryPlanner::new(registry.clone()))) + .build(); + let ctx = SessionContext::new_with_state(state); + register_all(&ctx, registry).expect("register_all"); + let table_provider = Arc::new( + HashKeyProvider::try_new(exec_schema(), vec![test_batch(&exec_schema())], "id") + .expect("table HashKeyProvider"), + ); + ctx.register_table("items", table_provider) + .expect("register_table"); + + // This is the exact pattern that fails: SELECT specific_cols, negative_dot_product, ORDER BY alias + let sql = "SELECT id, negative_dot_product(vector, 
ARRAY[1.0::float, 0.0::float, 0.0::float, 0.0::float]) AS dist FROM items ORDER BY dist ASC LIMIT 2"; + let ids = collect_ids(&ctx, sql).await; + assert_eq!(ids[0], 1, "closest must be row 1\nids: {ids:?}"); +} + +/// 768-dim negative_dot_product with split providers — reproduces production query pattern. +#[tokio::test] +async fn exec_split_provider_768dim_negative_dot_product() { + let dim = 768i32; + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::UInt64, false), + Field::new("label", DataType::Utf8, false), + Field::new( + "vector", + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), dim), + false, + ), + ])); + + let ids_arr = UInt64Array::from(vec![1u64, 2, 3, 4]); + let labels_arr = StringArray::from(vec!["a", "b", "c", "d"]); + let vecs: Vec<Vec<f32>> = (0..4) + .map(|row| { + (0..dim as usize) + .map(|i| ((row * dim as usize + i) as f32) * 0.001) + .collect() + }) + .collect(); + let mut builder = FixedSizeListBuilder::new(Float32Builder::new(), dim); + for v in &vecs { + builder.values().append_slice(v); + builder.append(true); + } + let vector_col: FixedSizeListArray = builder.finish(); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(ids_arr), + Arc::new(labels_arr), + Arc::new(vector_col), + ], + ) + .unwrap(); + + let scan_provider = + Arc::new(HashKeyProvider::try_new(schema.clone(), vec![batch.clone()], "id").unwrap()); + let lookup_batch = RecordBatch::try_new( + lookup_schema(), + vec![batch.column(0).clone(), batch.column(1).clone()], + ) + .unwrap(); + let lookup_provider = + Arc::new(HashKeyProvider::try_new(lookup_schema(), vec![lookup_batch], "id").unwrap()); + + let opts = IndexOptions { + dimensions: dim as usize, + metric: MetricKind::IP, + quantization: ScalarKind::F32, + ..Default::default() + }; + let index = Arc::new(Index::new(&opts).unwrap()); + index.reserve(4).unwrap(); + for (row, key) in vecs.iter().zip([1u64, 2, 3, 4]) { + index.add(key, row.as_slice()).unwrap(); + 
} + + let reg = USearchRegistry::new(); + reg.add( + "items::vector", + index, + scan_provider, + lookup_provider, + "id", + MetricKind::IP, + ScalarKind::F32, + ) + .unwrap(); + let registry = reg.into_arc(); + + let state = SessionStateBuilder::new() + .with_default_features() + .with_query_planner(Arc::new(USearchQueryPlanner::new(registry.clone()))) + .build(); + let ctx = SessionContext::new_with_state(state); + register_all(&ctx, registry).unwrap(); + + let table_provider = Arc::new(HashKeyProvider::try_new(schema, vec![batch], "id").unwrap()); + ctx.register_table("items", table_provider).unwrap(); + + // Build 768-element query array + let query_arr: Vec<String> = (0..dim) + .map(|i| format!("{:.6}", i as f64 * 0.001)) + .collect(); + let query_str = query_arr.join(","); + let sql = format!( + "SELECT id, negative_dot_product(vector, ARRAY[{}]) AS dist FROM items ORDER BY dist ASC LIMIT 2", + query_str + ); + + let df = ctx.sql(&sql).await.expect("sql failed"); + let batches = df.collect().await.expect("collect failed"); + let total: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total, 2, "expected 2 results"); +} + /// l2_distance must return squared L2, not actual L2. /// Row 1 = [1,0,0,0], query = [1,0,0,0] → L2sq = 0.0 /// Row 2 = [0,1,0,0], query = [1,0,0,0] → L2sq = 2.0 (L2 would be ~1.414) @@ -510,3 +778,70 @@ async fn exec_l2_distance_returns_l2sq() { row2.1 ); } + +// ═══════════════════════════════════════════════════════════════════════════════ +// Dimension mismatch — UDF must reject mismatched query vectors +// ═══════════════════════════════════════════════════════════════════════════════ + +/// Dimension mismatch must error — optimizer path (USearch) catches it. 
+#[tokio::test] +async fn udf_dimension_mismatch_fewer() { + let ctx = make_exec_ctx("items::vector").await; + // Column is 4-dim, query is 3-dim + let sql = "SELECT id, l2_distance(vector, ARRAY[1.0::float, 0.0::float, 0.0::float]) AS dist FROM items ORDER BY dist ASC LIMIT 2"; + let err = ctx + .sql(sql) + .await + .expect("sql") + .collect() + .await + .unwrap_err(); + let msg = err.to_string(); + assert!( + msg.contains("must match"), + "expected dimension mismatch error, got: {msg}" + ); +} + +/// Dimension mismatch must error — optimizer path (USearch) catches it. +#[tokio::test] +async fn udf_dimension_mismatch_more() { + let ctx = make_exec_ctx("items::vector").await; + // Column is 4-dim, query is 5-dim + let sql = "SELECT id, l2_distance(vector, ARRAY[1.0::float, 0.0::float, 0.0::float, 0.0::float, 0.0::float]) AS dist FROM items ORDER BY dist ASC LIMIT 2"; + let err = ctx + .sql(sql) + .await + .expect("sql") + .collect() + .await + .unwrap_err(); + let msg = err.to_string(); + assert!( + msg.contains("must match"), + "expected dimension mismatch error, got: {msg}" + ); +} + +/// SELECT * with mismatched dimensions must also error (not silently truncate). +/// This is the key test: SELECT * bypasses the optimizer (vector column not in +/// lookup schema), so the UDF brute-force path runs. Before the fix, zip() +/// silently truncated and returned wrong results. +#[tokio::test] +async fn udf_dimension_mismatch_select_star() { + let ctx = make_split_provider_ctx("items::vector").await; + // Column is 4-dim, query is 3-dim. SELECT * falls back to UDF path. + let sql = "SELECT *, l2_distance(vector, ARRAY[1.0::float, 0.0::float, 0.0::float]) AS dist FROM items ORDER BY dist ASC LIMIT 2"; + let err = ctx + .sql(sql) + .await + .expect("sql") + .collect() + .await + .unwrap_err(); + let msg = err.to_string(); + assert!( + msg.contains("must match"), + "expected dimension mismatch error, got: {msg}" + ); +}