diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 1d3e244d73971..be73b0a9d11be 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -38,8 +38,10 @@ use datafusion_expr::{ColumnarValue, expr_vec_fmt}; mod array_static_filter; mod primitive_filter; +mod result; mod static_filter; mod strategy; +mod transform; use static_filter::StaticFilter; use strategy::instantiate_static_filter; diff --git a/datafusion/physical-expr/src/expressions/in_list/array_static_filter.rs b/datafusion/physical-expr/src/expressions/in_list/array_static_filter.rs index 93bfcd49600d0..75e92dbcc59b4 100644 --- a/datafusion/physical-expr/src/expressions/in_list/array_static_filter.rs +++ b/datafusion/physical-expr/src/expressions/in_list/array_static_filter.rs @@ -23,11 +23,11 @@ use arrow::buffer::{BooleanBuffer, NullBuffer}; use arrow::compute::{SortOptions, take}; use arrow::datatypes::DataType; use arrow::util::bit_iterator::BitIndexIterator; -use datafusion_common::HashMap; use datafusion_common::Result; use datafusion_common::hash_utils::{RandomState, with_hashes}; -use hashbrown::hash_map::RawEntryMut; +use hashbrown::HashTable; +use super::result::build_in_list_result; use super::static_filter::StaticFilter; /// Static filter for InList that stores the array and hash set for O(1) lookups @@ -35,11 +35,92 @@ use super::static_filter::StaticFilter; pub(super) struct ArrayStaticFilter { in_array: ArrayRef, state: RandomState, - /// Used to provide a lookup from value to in list index + /// Stores indices into `in_array` for O(1) lookups. + table: HashTable, +} + +impl ArrayStaticFilter { + /// Computes a [`StaticFilter`] for the provided [`Array`] if there + /// are nulls present or there are more than the configured number of + /// elements. /// - /// Note: usize::hash is not used, instead the raw entry - /// API is used to store entries w.r.t their value - map: HashMap, + /// Note: This is split into a separate function as higher-rank trait bounds currently + /// cause type inference to misbehave + pub(super) fn try_new(in_array: ArrayRef) -> Result { + // Null type has no natural order - return empty hash set + if in_array.data_type() == &DataType::Null { + return Ok(ArrayStaticFilter { + in_array, + state: RandomState::default(), + table: HashTable::new(), + }); + } + + let state = RandomState::default(); + let table = Self::build_haystack_table(&in_array, &state)?; + + Ok(Self { + in_array, + state, + table, + }) + } + + fn build_haystack_table( + haystack: &ArrayRef, + state: &RandomState, + ) -> Result> { + let mut table = HashTable::new(); + + with_hashes([haystack.as_ref()], state, |hashes| -> Result<()> { + let cmp = make_comparator(haystack, haystack, SortOptions::default())?; + + let insert_value = |idx| { + let hash = hashes[idx]; + // Only insert if not already present (deduplication) + if table.find(hash, |&x| cmp(x, idx).is_eq()).is_none() { + table.insert_unique(hash, idx, |&x| hashes[x]); + } + }; + + match haystack.nulls() { + Some(nulls) => { + BitIndexIterator::new(nulls.validity(), nulls.offset(), nulls.len()) + .for_each(insert_value) + } + None => (0..haystack.len()).for_each(insert_value), + } + + Ok(()) + })?; + + Ok(table) + } + + fn find_needles_in_haystack( + &self, + needles: &dyn Array, + negated: bool, + ) -> Result { + let needle_nulls = needles.logical_nulls(); + let haystack_has_nulls = self.in_array.null_count() != 0; + + with_hashes([needles], &self.state, |needle_hashes| { + let cmp = make_comparator(needles, &self.in_array, SortOptions::default())?; + + Ok(build_in_list_result( + needles.len(), + needle_nulls.as_ref(), + haystack_has_nulls, + negated, + #[inline(always)] + |i| { + let hash = needle_hashes[i]; + self.table.find(hash, |&idx| cmp(i, idx).is_eq()).is_some() + }, + )) + }) + } } impl StaticFilter for ArrayStaticFilter { @@ -76,85 +157,6 @@ impl StaticFilter for ArrayStaticFilter { _ => {} } - let needle_nulls = v.logical_nulls(); - let needle_nulls = needle_nulls.as_ref(); - let haystack_has_nulls = self.in_array.null_count() != 0; - - with_hashes([v], &self.state, |hashes| { - let cmp = make_comparator(v, &self.in_array, SortOptions::default())?; - Ok((0..v.len()) - .map(|i| { - // SQL three-valued logic: null IN (...) is always null - if needle_nulls.is_some_and(|nulls| nulls.is_null(i)) { - return None; - } - - let hash = hashes[i]; - let contains = self - .map - .raw_entry() - .from_hash(hash, |idx| cmp(i, *idx).is_eq()) - .is_some(); - - match contains { - true => Some(!negated), - false if haystack_has_nulls => None, - false => Some(negated), - } - }) - .collect()) - }) - } -} - -impl ArrayStaticFilter { - /// Computes a [`StaticFilter`] for the provided [`Array`] if there - /// are nulls present or there are more than the configured number of - /// elements. - /// - /// Note: This is split into a separate function as higher-rank trait bounds currently - /// cause type inference to misbehave - pub(super) fn try_new(in_array: ArrayRef) -> Result { - // Null type has no natural order - return empty hash set - if in_array.data_type() == &DataType::Null { - return Ok(ArrayStaticFilter { - in_array, - state: RandomState::default(), - map: HashMap::with_hasher(()), - }); - } - - let state = RandomState::default(); - let mut map: HashMap = HashMap::with_hasher(()); - - with_hashes([&in_array], &state, |hashes| -> Result<()> { - let cmp = make_comparator(&in_array, &in_array, SortOptions::default())?; - - let insert_value = |idx| { - let hash = hashes[idx]; - if let RawEntryMut::Vacant(v) = map - .raw_entry_mut() - .from_hash(hash, |x| cmp(*x, idx).is_eq()) - { - v.insert_with_hasher(hash, idx, (), |x| hashes[*x]); - } - }; - - match in_array.nulls() { - Some(nulls) => { - BitIndexIterator::new(nulls.validity(), nulls.offset(), nulls.len()) - .for_each(insert_value) - } - None => (0..in_array.len()).for_each(insert_value), - } - - Ok(()) - })?; - - Ok(Self { - in_array, - state, - map, - }) + self.find_needles_in_haystack(v, negated) } } diff --git a/datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs b/datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs index 2c084a1cb247b..fb1d34bb0e5ec 100644 --- a/datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs +++ b/datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs @@ -15,16 +15,258 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{ - Array, ArrayRef, AsArray, BooleanArray, downcast_array, downcast_dictionary_array, -}; +//! Optimized primitive type filters for InList expressions. +//! +//! This module provides membership tests for Arrow primitive types. + +use arrow::array::{Array, ArrayRef, AsArray, BooleanArray}; use arrow::buffer::{BooleanBuffer, NullBuffer}; -use arrow::compute::take; use arrow::datatypes::*; use datafusion_common::{HashSet, Result, exec_datafusion_err}; use std::hash::{Hash, Hasher}; -use super::static_filter::StaticFilter; +use super::result::build_in_list_result; +use super::static_filter::{StaticFilter, handle_dictionary}; + +pub(super) trait BitmapStorage: Send + Sync { + fn new_zeroed() -> Self; + fn set_bit(&mut self, index: usize); + fn get_bit(&self, index: usize) -> bool; +} + +impl BitmapStorage for [u64; 4] { + #[inline] + fn new_zeroed() -> Self { + [0u64; 4] + } + #[inline] + fn set_bit(&mut self, index: usize) { + self[index / 64] |= 1u64 << (index % 64); + } + #[inline(always)] + fn get_bit(&self, index: usize) -> bool { + (self[index / 64] >> (index % 64)) & 1 != 0 + } +} + +impl BitmapStorage for Box<[u64; 1024]> { + #[inline] + fn new_zeroed() -> Self { + Box::new([0u64; 1024]) + } + #[inline] + fn set_bit(&mut self, index: usize) { + self[index / 64] |= 1u64 << (index % 64); + } + #[inline(always)] + fn get_bit(&self, index: usize) -> bool { + (self[index / 64] >> (index % 64)) & 1 != 0 + } +} + +pub(super) trait BitmapFilterConfig: Send + Sync + 'static { + const DATA_TYPE_NAME: &'static str; + + type Native: ArrowNativeType + Copy + Send + Sync; + type ArrowType: ArrowPrimitiveType; + type Storage: BitmapStorage; + + fn to_index(v: Self::Native) -> usize; +} + +pub(super) enum UInt8BitmapConfig {} +impl BitmapFilterConfig for UInt8BitmapConfig { + const DATA_TYPE_NAME: &'static str = "UInt8"; + + type Native = u8; + type ArrowType = UInt8Type; + type Storage = [u64; 4]; + + #[inline(always)] + fn to_index(v: u8) -> usize { + v as usize + } +} + +pub(super) enum UInt16BitmapConfig {} +impl BitmapFilterConfig for UInt16BitmapConfig { + const DATA_TYPE_NAME: &'static str = "UInt16"; + + type Native = u16; + type ArrowType = UInt16Type; + type Storage = Box<[u64; 1024]>; + + #[inline(always)] + fn to_index(v: u16) -> usize { + v as usize + } +} + +/// Bitmap filter for O(1) set membership via single bit test. +/// +/// Small integer domains can store membership in a fixed-size bitmap instead +/// of using a hash table. +pub(super) struct BitmapFilter { + null_count: usize, + bits: C::Storage, +} + +impl BitmapFilter { + pub(super) fn try_new(in_array: &ArrayRef) -> Result { + let prim_array = + in_array.as_primitive_opt::().ok_or_else(|| { + exec_datafusion_err!("BitmapFilter: expected {} array", C::DATA_TYPE_NAME) + })?; + let mut bits = C::Storage::new_zeroed(); + for v in prim_array.iter().flatten() { + bits.set_bit(C::to_index(v)); + } + Ok(Self { + null_count: prim_array.null_count(), + bits, + }) + } + + #[inline(always)] + fn check(&self, needle: C::Native) -> bool { + self.bits.get_bit(C::to_index(needle)) + } + + /// Check membership using a raw values slice (zero-copy path for type reinterpretation). + #[inline] + pub(super) fn contains_slice( + &self, + values: &[C::Native], + nulls: Option<&NullBuffer>, + negated: bool, + ) -> BooleanArray { + build_in_list_result(values.len(), nulls, self.null_count > 0, negated, |i| { + // SAFETY: `build_in_list_result` invokes this closure for + // indices in `0..values.len()`. + let needle = unsafe { *values.get_unchecked(i) }; + self.check(needle) + }) + } +} + +impl StaticFilter for BitmapFilter { + fn null_count(&self) -> usize { + self.null_count + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + handle_dictionary!(self, v, negated); + let v = v.as_primitive_opt::().ok_or_else(|| { + exec_datafusion_err!("BitmapFilter: expected {} array", C::DATA_TYPE_NAME) + })?; + let input_values = v.values(); + Ok(build_in_list_result( + v.len(), + v.nulls(), + self.null_count > 0, + negated, + #[inline(always)] + |i| { + // SAFETY: `build_in_list_result` invokes this closure for + // indices in `0..v.len()`, which matches `input_values.len()`. + let needle = unsafe { *input_values.get_unchecked(i) }; + self.check(needle) + }, + )) + } +} + +/// A branchless filter for very small fixed-width primitive IN lists. +/// +/// Uses const generics to unroll the membership check into a fixed-size +/// comparison chain, outperforming hash lookups for small lists due to: +/// - No branching (uses bitwise OR to combine comparisons) +/// - Better CPU pipelining +/// - No hash computation overhead +pub(super) struct BranchlessFilter { + null_count: usize, + values: [T::Native; N], +} + +impl BranchlessFilter +where + T::Native: Copy + PartialEq, +{ + /// Try to create a branchless filter if the array has exactly N non-null values. + pub(super) fn try_new(in_array: &ArrayRef) -> Option> { + let in_array = in_array.as_primitive_opt::()?; + let non_null_count = in_array.len() - in_array.null_count(); + if non_null_count != N { + return None; + } + // Use default_value() from ArrowPrimitiveType trait instead of Default::default() + let mut arr = [T::default_value(); N]; + let mut i = 0; + for value in in_array.iter().flatten() { + arr[i] = value; + i += 1; + } + debug_assert_eq!(i, N); + Some(Ok(Self { + null_count: in_array.null_count(), + values: arr, + })) + } + + /// Branchless membership check using OR-chain. + #[inline(always)] + fn check(&self, needle: T::Native) -> bool { + self.values + .iter() + .fold(false, |acc, &v| acc | (v == needle)) + } + + /// Check membership using a raw values slice (zero-copy path for type reinterpretation). + #[inline] + pub(super) fn contains_slice( + &self, + values: &[T::Native], + nulls: Option<&NullBuffer>, + negated: bool, + ) -> BooleanArray { + build_in_list_result(values.len(), nulls, self.null_count > 0, negated, |i| { + // SAFETY: `build_in_list_result` invokes this closure for + // indices in `0..values.len()`. + let needle = unsafe { *values.get_unchecked(i) }; + self.check(needle) + }) + } +} + +impl StaticFilter for BranchlessFilter +where + T::Native: Copy + PartialEq + Send + Sync, +{ + fn null_count(&self) -> usize { + self.null_count + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + handle_dictionary!(self, v, negated); + let v = v.as_primitive_opt::().ok_or_else(|| { + exec_datafusion_err!("Failed to downcast array to primitive type") + })?; + let input_values = v.values(); + Ok(build_in_list_result( + v.len(), + v.nulls(), + self.null_count > 0, + negated, + #[inline(always)] + |i| { + // SAFETY: `build_in_list_result` invokes this closure for + // indices in `0..v.len()`, which matches `input_values.len()`. + let needle = unsafe { *input_values.get_unchecked(i) }; + self.check(needle) + }, + )) + } +} /// Wrapper for f32 that implements Hash and Eq using bit comparison. /// This treats NaN values as equal to each other when they have the same bit pattern. @@ -94,9 +336,13 @@ macro_rules! primitive_static_filter { impl $Name { pub(super) fn try_new(in_array: &ArrayRef) -> Result { - let in_array = in_array - .as_primitive_opt::<$ArrowType>() - .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?; + let in_array = + in_array.as_primitive_opt::<$ArrowType>().ok_or_else(|| { + exec_datafusion_err!( + "Failed to downcast an array to a '{}' array", + stringify!($ArrowType) + ) + })?; let mut values = HashSet::with_capacity(in_array.len()); let null_count = in_array.null_count(); @@ -115,19 +361,14 @@ macro_rules! primitive_static_filter { } fn contains(&self, v: &dyn Array, negated: bool) -> Result { - // Handle dictionary arrays by recursing on the values - downcast_dictionary_array! { - v => { - let values_contains = self.contains(v.values().as_ref(), negated)?; - let result = take(&values_contains, v.keys(), None)?; - return Ok(downcast_array(result.as_ref())) - } - _ => {} - } + handle_dictionary!(self, v, negated); - let v = v - .as_primitive_opt::<$ArrowType>() - .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?; + let v = v.as_primitive_opt::<$ArrowType>().ok_or_else(|| { + exec_datafusion_err!( + "Failed to downcast an array to a '{}' array", + stringify!($ArrowType) + ) + })?; let haystack_has_nulls = self.null_count > 0; let needle_values = v.values(); @@ -188,8 +429,10 @@ macro_rules! primitive_static_filter { } (true, true) => { // Both have nulls - combine needle nulls with haystack-induced nulls - let needle_validity = needle_nulls.map(|n| n.inner().clone()) - .unwrap_or_else(|| BooleanBuffer::new_set(needle_values.len())); + let needle_validity = + needle_nulls.map(|n| n.inner().clone()).unwrap_or_else( + || BooleanBuffer::new_set(needle_values.len()), + ); // Valid when original "in set" is true (see above) let haystack_validity = if negated { @@ -210,13 +453,8 @@ macro_rules! primitive_static_filter { }; } -// Generate specialized filters for all integer primitive types -primitive_static_filter!(Int8StaticFilter, Int8Type); -primitive_static_filter!(Int16StaticFilter, Int16Type); primitive_static_filter!(Int32StaticFilter, Int32Type); primitive_static_filter!(Int64StaticFilter, Int64Type); -primitive_static_filter!(UInt8StaticFilter, UInt8Type); -primitive_static_filter!(UInt16StaticFilter, UInt16Type); primitive_static_filter!(UInt32StaticFilter, UInt32Type); primitive_static_filter!(UInt64StaticFilter, UInt64Type); @@ -231,3 +469,75 @@ macro_rules! float_static_filter { // Generate specialized filters for float types using ordered wrappers float_static_filter!(Float32StaticFilter, Float32Type, OrderedFloat32); float_static_filter!(Float64StaticFilter, Float64Type, OrderedFloat64); + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + use arrow::array::{DictionaryArray, Int8Array, UInt8Array, UInt16Array}; + + fn assert_contains( + filter: &dyn StaticFilter, + needles: &dyn Array, + expected: Vec>, + ) -> Result<()> { + assert_eq!( + filter.contains(needles, false)?, + BooleanArray::from(expected) + ); + Ok(()) + } + + #[test] + fn bitmap_filter_u8_handles_nulls() -> Result<()> { + let haystack: ArrayRef = Arc::new(UInt8Array::from(vec![Some(1), None, Some(3)])); + let filter = BitmapFilter::::try_new(&haystack)?; + let needles = UInt8Array::from(vec![Some(1), Some(2), None, Some(3)]); + + assert_contains(&filter, &needles, vec![Some(true), None, None, Some(true)])?; + assert_eq!( + filter.contains(&needles, true)?, + BooleanArray::from(vec![Some(false), None, None, Some(false)]) + ); + + Ok(()) + } + + #[test] + fn bitmap_filter_u8_handles_dictionary_needles() -> Result<()> { + let haystack: ArrayRef = Arc::new(UInt8Array::from(vec![Some(1), None, Some(3)])); + let filter = BitmapFilter::::try_new(&haystack)?; + + let keys = Int8Array::from(vec![Some(0), Some(1), None, Some(2)]); + let values = Arc::new(UInt8Array::from(vec![Some(1), Some(2), Some(3)])); + let needles = DictionaryArray::try_new(keys, values)?; + + assert_contains(&filter, &needles, vec![Some(true), None, None, Some(true)]) + } + + #[test] + fn bitmap_filter_u16_handles_boundaries_and_nulls() -> Result<()> { + let haystack: ArrayRef = Arc::new(UInt16Array::from(vec![ + Some(0), + None, + Some(1024), + Some(u16::MAX), + ])); + let filter = BitmapFilter::::try_new(&haystack)?; + let needles = + UInt16Array::from(vec![Some(0), Some(1), Some(1024), Some(u16::MAX), None]); + + assert_contains( + &filter, + &needles, + vec![Some(true), None, Some(true), Some(true), None], + )?; + assert_eq!( + filter.contains(&needles, true)?, + BooleanArray::from(vec![Some(false), None, Some(false), Some(false), None]) + ); + + Ok(()) + } +} diff --git a/datafusion/physical-expr/src/expressions/in_list/result.rs b/datafusion/physical-expr/src/expressions/in_list/result.rs new file mode 100644 index 0000000000000..3ebdbfe19f743 --- /dev/null +++ b/datafusion/physical-expr/src/expressions/in_list/result.rs @@ -0,0 +1,105 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Result building helpers for InList operations. +//! +//! This module provides unified logic for building BooleanArray results +//! from IN list membership tests, handling null propagation correctly +//! according to SQL three-valued logic. + +use arrow::array::BooleanArray; +use arrow::buffer::{BooleanBuffer, NullBuffer}; + +// Truth table for (needle_nulls, haystack_has_nulls, negated): +// (Some, true, false) => values: valid & contains, nulls: valid & contains +// (None, true, false) => values: contains, nulls: contains +// (Some, true, true) => values: valid & !contains, nulls: valid & contains +// (None, true, true) => values: !contains, nulls: contains +// (Some, false, false) => values: valid & contains, nulls: valid +// (Some, false, true) => values: valid & !contains, nulls: valid +// (None, false, false) => values: contains, nulls: none +// (None, false, true) => values: !contains, nulls: none + +/// Builds a BooleanArray result for IN list operations. +/// +/// This function handles the null propagation logic for SQL IN lists: +/// - If the needle value is null, the result is null +/// - If the needle is not in the set and the haystack has nulls, the result is null +/// - Otherwise, the result is true/false based on membership and negation +/// +/// This version computes contains for all positions, including nulls, then applies +/// null masking via bitmap operations. +#[inline] +pub(crate) fn build_in_list_result( + len: usize, + needle_nulls: Option<&NullBuffer>, + haystack_has_nulls: bool, + negated: bool, + contains: C, +) -> BooleanArray +where + C: FnMut(usize) -> bool, +{ + let contains_buf = BooleanBuffer::collect_bool(len, contains); + build_result_from_contains(needle_nulls, haystack_has_nulls, negated, contains_buf) +} + +/// Builds a BooleanArray result from a pre-computed contains buffer. +/// +/// This version does not assume contains_buf is pre-masked at null positions. +/// It handles nulls using bitmap operations. +#[inline] +pub(crate) fn build_result_from_contains( + needle_nulls: Option<&NullBuffer>, + haystack_has_nulls: bool, + negated: bool, + contains_buf: BooleanBuffer, +) -> BooleanArray { + match (needle_nulls, haystack_has_nulls, negated) { + // Haystack has nulls: result is null unless value is found. + (Some(v), true, false) => { + // values: valid & contains, nulls: valid & contains + let values = v.inner() & &contains_buf; + BooleanArray::new(values.clone(), Some(NullBuffer::new(values))) + } + (None, true, false) => { + BooleanArray::new(contains_buf.clone(), Some(NullBuffer::new(contains_buf))) + } + (Some(v), true, true) => { + // NOT IN with nulls: false if found, null if not found or needle null. + // values: valid & !contains, nulls: valid & contains + let valid = v.inner(); + let values = valid & &(!&contains_buf); + let nulls = valid & &contains_buf; + BooleanArray::new(values, Some(NullBuffer::new(nulls))) + } + (None, true, true) => { + BooleanArray::new(!&contains_buf, Some(NullBuffer::new(contains_buf))) + } + // Haystack has no nulls: result validity follows needle validity. + (Some(v), false, false) => { + // values: valid & contains, nulls: valid + BooleanArray::new(v.inner() & &contains_buf, Some(v.clone())) + } + (Some(v), false, true) => { + // values: valid & !contains, nulls: valid + BooleanArray::new(v.inner() & &(!&contains_buf), Some(v.clone())) + } + (None, false, false) => BooleanArray::new(contains_buf, None), + (None, false, true) => BooleanArray::new(!&contains_buf, None), + } +} diff --git a/datafusion/physical-expr/src/expressions/in_list/static_filter.rs b/datafusion/physical-expr/src/expressions/in_list/static_filter.rs index 218bd27950266..3c964d4183474 100644 --- a/datafusion/physical-expr/src/expressions/in_list/static_filter.rs +++ b/datafusion/physical-expr/src/expressions/in_list/static_filter.rs @@ -35,3 +35,20 @@ pub(super) trait StaticFilter { /// implementation unwraps the dictionary and operates on its values. fn contains(&self, v: &dyn Array, negated: bool) -> Result; } + +/// Evaluate dictionary-encoded needles by applying a filter to dictionary +/// values and remapping the result through the keys. +macro_rules! handle_dictionary { + ($self:ident, $v:ident, $negated:ident) => { + arrow::array::downcast_dictionary_array! { + $v => { + let values_contains = $self.contains($v.values().as_ref(), $negated)?; + let result = arrow::compute::take(&values_contains, $v.keys(), None)?; + return Ok(arrow::array::downcast_array(result.as_ref())) + } + _ => {} + } + }; +} + +pub(super) use handle_dictionary; diff --git a/datafusion/physical-expr/src/expressions/in_list/strategy.rs b/datafusion/physical-expr/src/expressions/in_list/strategy.rs index b7ee3dd1a3b9d..c99bf49f9e742 100644 --- a/datafusion/physical-expr/src/expressions/in_list/strategy.rs +++ b/datafusion/physical-expr/src/expressions/in_list/strategy.rs @@ -19,13 +19,66 @@ use std::sync::Arc; use arrow::array::ArrayRef; use arrow::compute::cast; -use arrow::datatypes::DataType; -use datafusion_common::Result; +use arrow::datatypes::*; +use datafusion_common::{Result, exec_datafusion_err}; use super::array_static_filter::ArrayStaticFilter; use super::primitive_filter::*; use super::static_filter::StaticFilter; +use super::transform::{make_bitmap_filter, make_branchless_filter}; +/// Maximum list size for branchless lookup on 4-byte primitives (Int32, UInt32, Float32). +const BRANCHLESS_MAX_4B: usize = 32; + +/// Maximum list size for branchless lookup on 8-byte primitives (Int64, UInt64, Float64). +const BRANCHLESS_MAX_8B: usize = 16; + +/// Maximum list size for branchless lookup on 16-byte types (Decimal128). +const BRANCHLESS_MAX_16B: usize = 4; + +/// The lookup strategy to use for a given data type and list size. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum FilterStrategy { + /// Bitmap filter for u8/u16 domains. + Bitmap1B, + Bitmap2B, + /// Branchless OR-chain for small lists. + Branchless, + /// Generic ArrayStaticFilter fallback. + Generic, +} + +/// Selects the lookup strategy based on data type and list size. +fn select_strategy(dt: &DataType, len: usize) -> FilterStrategy { + match dt.primitive_width() { + Some(1) => FilterStrategy::Bitmap1B, + Some(2) => FilterStrategy::Bitmap2B, + Some(4) => { + if len <= BRANCHLESS_MAX_4B { + FilterStrategy::Branchless + } else { + FilterStrategy::Generic + } + } + Some(8) => { + if len <= BRANCHLESS_MAX_8B { + FilterStrategy::Branchless + } else { + FilterStrategy::Generic + } + } + Some(16) => { + if len <= BRANCHLESS_MAX_16B { + FilterStrategy::Branchless + } else { + FilterStrategy::Generic + } + } + _ => FilterStrategy::Generic, + } +} + +/// Creates the optimal static filter for the given array. pub(super) fn instantiate_static_filter( in_array: ArrayRef, ) -> Result> { @@ -36,22 +89,46 @@ pub(super) fn instantiate_static_filter( DataType::Dictionary(_, value_type) => cast(&in_array, value_type.as_ref())?, _ => in_array, }; - match in_array.data_type() { - // Integer primitive types - DataType::Int8 => Ok(Arc::new(Int8StaticFilter::try_new(&in_array)?)), - DataType::Int16 => Ok(Arc::new(Int16StaticFilter::try_new(&in_array)?)), - DataType::Int32 => Ok(Arc::new(Int32StaticFilter::try_new(&in_array)?)), - DataType::Int64 => Ok(Arc::new(Int64StaticFilter::try_new(&in_array)?)), - DataType::UInt8 => Ok(Arc::new(UInt8StaticFilter::try_new(&in_array)?)), - DataType::UInt16 => Ok(Arc::new(UInt16StaticFilter::try_new(&in_array)?)), - DataType::UInt32 => Ok(Arc::new(UInt32StaticFilter::try_new(&in_array)?)), - DataType::UInt64 => Ok(Arc::new(UInt64StaticFilter::try_new(&in_array)?)), - // Float primitive types (use ordered wrappers for Hash/Eq) - DataType::Float32 => Ok(Arc::new(Float32StaticFilter::try_new(&in_array)?)), - DataType::Float64 => Ok(Arc::new(Float64StaticFilter::try_new(&in_array)?)), - _ => { - /* fall through to generic implementation for unsupported types (Struct, etc.) */ - Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)) - } + use FilterStrategy::*; + + let len = in_array.len(); + let dt = in_array.data_type(); + let strategy = select_strategy(dt, len); + + match (dt, strategy) { + // Bitmap filters for 1-byte and 2-byte types + (_, Bitmap1B) => make_bitmap_filter::(&in_array), + (_, Bitmap2B) => make_bitmap_filter::(&in_array), + + // Branchless filters for small lists of primitives + (_, Branchless) => dispatch_branchless(&in_array).ok_or_else(|| { + exec_datafusion_err!( + "Branchless strategy selected but no filter for {:?}", + dt + ) + })?, + + // Fallback for larger primitive lists or complex types. + (_, Generic) => match dt { + DataType::Int32 => Ok(Arc::new(Int32StaticFilter::try_new(&in_array)?)), + DataType::Int64 => Ok(Arc::new(Int64StaticFilter::try_new(&in_array)?)), + DataType::UInt32 => Ok(Arc::new(UInt32StaticFilter::try_new(&in_array)?)), + DataType::UInt64 => Ok(Arc::new(UInt64StaticFilter::try_new(&in_array)?)), + DataType::Float32 => Ok(Arc::new(Float32StaticFilter::try_new(&in_array)?)), + DataType::Float64 => Ok(Arc::new(Float64StaticFilter::try_new(&in_array)?)), + _ => Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)), + }, + } +} + +fn dispatch_branchless( + arr: &ArrayRef, +) -> Option>> { + // Dispatch to width-specific branchless filter. + match arr.data_type().primitive_width() { + Some(4) => Some(make_branchless_filter::(arr)), + Some(8) => Some(make_branchless_filter::(arr)), + Some(16) => Some(make_branchless_filter::(arr)), + _ => None, } } diff --git a/datafusion/physical-expr/src/expressions/in_list/transform.rs b/datafusion/physical-expr/src/expressions/in_list/transform.rs new file mode 100644 index 0000000000000..6b4a5523dc7f2 --- /dev/null +++ b/datafusion/physical-expr/src/expressions/in_list/transform.rs @@ -0,0 +1,301 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Type transformation utilities for InList filters. +//! +//! Some filters only depend on fixed-width value bit patterns. For those cases, +//! compatible primitive arrays can be reinterpreted to the filter's unsigned +//! storage type without copying values. + +use std::mem::size_of; +use std::sync::Arc; + +use arrow::array::{Array, ArrayRef, BooleanArray, PrimitiveArray}; +use arrow::buffer::ScalarBuffer; +use arrow::datatypes::ArrowPrimitiveType; +use datafusion_common::{Result, exec_datafusion_err}; + +use super::primitive_filter::{BitmapFilter, BitmapFilterConfig, BranchlessFilter}; +use super::static_filter::{StaticFilter, handle_dictionary}; + +/// Reinterpreting filter for bitmap lookups (u8/u16). +struct ReinterpretedBitmap { + inner: BitmapFilter, +} + +impl StaticFilter for ReinterpretedBitmap { + fn null_count(&self) -> usize { + self.inner.null_count() + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + handle_dictionary!(self, v, negated); + + if v.data_type().primitive_width() != Some(size_of::()) { + return Err(exec_datafusion_err!( + "BitmapFilter: expected {}-byte primitive array, got {}", + size_of::(), + v.data_type() + )); + } + + let data = v.to_data(); + let values: &[C::Native] = &data.buffer::(0)[..v.len()]; + + Ok(self.inner.contains_slice(values, data.nulls(), negated)) + } +} + +/// Reinterpreting filter for branchless lookups. +struct ReinterpretedBranchless { + inner: BranchlessFilter, +} + +impl StaticFilter for ReinterpretedBranchless +where + T: ArrowPrimitiveType + 'static, + T::Native: Copy + PartialEq + Send + Sync + 'static, +{ + fn null_count(&self) -> usize { + self.inner.null_count() + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + handle_dictionary!(self, v, negated); + + if v.data_type().primitive_width() != Some(size_of::()) { + return Err(exec_datafusion_err!( + "BranchlessFilter: expected {}-byte primitive array, got {}", + size_of::(), + v.data_type() + )); + } + + let data = v.to_data(); + let values: &[T::Native] = &data.buffer::(0)[..v.len()]; + + Ok(self.inner.contains_slice(values, data.nulls(), negated)) + } +} + +/// Reinterprets a same-width primitive array as the target primitive type `T`. +/// +/// This is a zero-copy operation: the returned array shares the original values +/// buffer and null buffer. Callers must ensure the source array and target type +/// have the same primitive width. +#[inline] +pub(crate) fn reinterpret_any_primitive_to( + array: &dyn Array, +) -> ArrayRef { + let data = array.to_data(); + let values = data.buffers()[0].clone(); + let buffer = ScalarBuffer::::new(values, data.offset(), data.len()); + Arc::new(PrimitiveArray::::new(buffer, array.nulls().cloned())) +} + +/// Creates a bitmap filter for u8/u16 types, reinterpreting if needed. +pub(crate) fn make_bitmap_filter( + in_array: &ArrayRef, +) -> Result> +where + C: BitmapFilterConfig, +{ + if in_array.data_type() == &C::ArrowType::DATA_TYPE { + return Ok(Arc::new(BitmapFilter::::try_new(in_array)?)); + } + if in_array.data_type().primitive_width() != Some(size_of::()) { + return Err(exec_datafusion_err!( + "BitmapFilter: expected {}-byte primitive array for {} bitmap, got {}", + size_of::(), + C::DATA_TYPE_NAME, + in_array.data_type() + )); + } + + let reinterpreted = reinterpret_any_primitive_to::(in_array.as_ref()); + let inner = BitmapFilter::::try_new(&reinterpreted)?; + Ok(Arc::new(ReinterpretedBitmap { inner })) +} + +/// Creates a branchless filter for primitive types. +/// +/// Dispatches based on byte width and element count: +/// - 4-byte types (Int32, Float32, etc.): supports 0-32 elements +/// - 8-byte types (Int64, Float64, Timestamp, etc.): supports 0-16 elements +/// - 16-byte types (Decimal128): supports 0-4 elements +pub(crate) fn make_branchless_filter( + in_array: &ArrayRef, +) -> Result> +where + D: ArrowPrimitiveType + 'static, + D::Native: Copy + PartialEq + Send + Sync + 'static, +{ + let is_native = in_array.data_type() == &D::DATA_TYPE; + let width = size_of::(); + let arr = if is_native { + Arc::clone(in_array) + } else { + if in_array.data_type().primitive_width() != Some(width) { + return Err(exec_datafusion_err!( + "BranchlessFilter: expected {width}-byte primitive array, got {}", + in_array.data_type() + )); + } + reinterpret_any_primitive_to::(in_array.as_ref()) + }; + let n = arr.len() - arr.null_count(); + + // Helper to create the filter for a known size N + #[inline] + fn create( + arr: &ArrayRef, + is_native: bool, + ) -> Result> + where + D::Native: Copy + PartialEq + Send + Sync + 'static, + { + let inner = BranchlessFilter::::try_new(arr) + .expect("size verified") + .expect("type verified"); + if is_native { + Ok(Arc::new(inner)) + } else { + Ok(Arc::new(ReinterpretedBranchless { inner })) + } + } + + // Match on (width, count) - shared sizes use or-patterns to avoid duplication + match (width, n) { + // All widths: 0-4 + (4 | 8 | 16, 0) => create::(&arr, is_native), + (4 | 8 | 16, 1) => create::(&arr, is_native), + (4 | 8 | 16, 2) => create::(&arr, is_native), + (4 | 8 | 16, 3) => create::(&arr, is_native), + (4 | 8 | 16, 4) => create::(&arr, is_native), + // 4-byte and 8-byte: 5-16 + (4 | 8, 5) => create::(&arr, is_native), + (4 | 8, 6) => create::(&arr, is_native), + (4 | 8, 7) => create::(&arr, is_native), + (4 | 8, 8) => create::(&arr, is_native), + (4 | 8, 9) => create::(&arr, is_native), + (4 | 8, 10) => create::(&arr, is_native), + (4 | 8, 11) => create::(&arr, is_native), + (4 | 8, 12) => create::(&arr, is_native), + (4 | 8, 13) => create::(&arr, is_native), + (4 | 8, 14) => create::(&arr, is_native), + (4 | 8, 15) => create::(&arr, is_native), + (4 | 8, 16) => create::(&arr, is_native), + // 4-byte only: 17-32 + (4, 17) => create::(&arr, is_native), + (4, 18) => create::(&arr, is_native), + (4, 19) => create::(&arr, is_native), + (4, 20) => create::(&arr, is_native), + (4, 21) => create::(&arr, is_native), + (4, 22) => create::(&arr, is_native), + (4, 23) => create::(&arr, is_native), + (4, 24) => create::(&arr, is_native), + (4, 25) => create::(&arr, is_native), + (4, 26) => create::(&arr, is_native), + (4, 27) => create::(&arr, is_native), + (4, 28) => create::(&arr, is_native), + (4, 29) => create::(&arr, is_native), + (4, 30) => create::(&arr, is_native), + (4, 31) => create::(&arr, is_native), + (4, 32) => create::(&arr, is_native), + // Error cases + (4, n) => datafusion_common::exec_err!( + "Branchless filter for 4-byte types supports 0-32 elements, got {n}" + ), + (8, n) => datafusion_common::exec_err!( + "Branchless filter for 8-byte types supports 0-16 elements, got {n}" + ), + (16, n) => datafusion_common::exec_err!( + "Branchless filter for 16-byte types supports 0-4 elements, got {n}" + ), + (w, _) => datafusion_common::exec_err!( + "Branchless filter not supported for {w}-byte types" + ), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + use arrow::array::{ArrayRef, BooleanArray, Int8Array, Int16Array, Int32Array}; + use arrow::datatypes::UInt32Type; + + #[test] + fn reinterpreted_bitmap_handles_signed_boundaries_and_slices() -> Result<()> { + let haystack: ArrayRef = Arc::new( + Int8Array::from(vec![Some(99), Some(i8::MIN), None, Some(-1), Some(42)]) + .slice(1, 3), + ); + let filter = make_bitmap_filter::< + super::super::primitive_filter::UInt8BitmapConfig, + >(&haystack)?; + let needles = + Int8Array::from(vec![Some(7), Some(i8::MIN), Some(-1), None]).slice(1, 3); + + assert_eq!( + filter.contains(&needles, false)?, + BooleanArray::from(vec![Some(true), Some(true), None]) + ); + + let haystack: ArrayRef = Arc::new( + Int16Array::from(vec![ + Some(123), + Some(i16::MIN), + None, + Some(-1), + Some(i16::MAX), + ]) + .slice(1, 4), + ); + let filter = make_bitmap_filter::< + super::super::primitive_filter::UInt16BitmapConfig, + >(&haystack)?; + let needles = + Int16Array::from(vec![Some(0), Some(i16::MIN), Some(7), Some(i16::MAX)]) + .slice(1, 3); + + assert_eq!( + filter.contains(&needles, false)?, + BooleanArray::from(vec![Some(true), None, Some(true)]) + ); + + Ok(()) + } + + #[test] + fn reinterpreted_branchless_handles_slices() -> Result<()> { + let haystack: ArrayRef = Arc::new( + Int32Array::from(vec![Some(99), Some(-7), None, Some(42)]).slice(1, 3), + ); + let filter = make_branchless_filter::(&haystack)?; + let needles = + Int32Array::from(vec![Some(0), Some(-7), Some(1), Some(42)]).slice(1, 3); + + assert_eq!( + filter.contains(&needles, false)?, + BooleanArray::from(vec![Some(true), None, Some(true)]) + ); + + Ok(()) + } +}