diff --git a/datafusion/functions-window/src/nth_value.rs b/datafusion/functions-window/src/nth_value.rs index df723772166a6..d608c4e5ac84c 100644 --- a/datafusion/functions-window/src/nth_value.rs +++ b/datafusion/functions-window/src/nth_value.rs @@ -280,6 +280,7 @@ impl WindowUDFImpl for NthValue { state, ignore_nulls: partition_evaluator_args.ignore_nulls(), n: 0, + valid_index_cache: None, })); } @@ -313,6 +314,7 @@ impl WindowUDFImpl for NthValue { state, ignore_nulls: partition_evaluator_args.ignore_nulls(), n, + valid_index_cache: None, })) } @@ -375,6 +377,80 @@ pub(crate) struct NthValueEvaluator { state: NthValueState, ignore_nulls: bool, n: i64, + valid_index_cache: Option, +} + +#[derive(Debug)] +struct ValidIndexCache { + nulls: NullBuffer, + scanned_to: usize, + valid_indices: Vec, +} + +impl ValidIndexCache { + fn new(nulls: &NullBuffer) -> Self { + Self { + nulls: nulls.clone(), + scanned_to: 0, + valid_indices: Vec::new(), + } + } + + fn matches(&self, nulls: &NullBuffer) -> bool { + self.nulls.validity().as_ptr() == nulls.validity().as_ptr() + && self.nulls.offset() == nulls.offset() + && self.nulls.len() == nulls.len() + && self.nulls.null_count() == nulls.null_count() + } + + fn extend_to(&mut self, end: usize) { + let end = end.min(self.nulls.len()); + if end <= self.scanned_to { + return; + } + + self.valid_indices + .extend((self.scanned_to..end).filter(|index| self.nulls.is_valid(*index))); + self.scanned_to = end; + } + + fn valid_index( + &mut self, + range: &Range, + kind: NthValueKind, + n: i64, + ) -> Option { + self.extend_to(range.end); + + let start = self + .valid_indices + .partition_point(|index| *index < range.start); + let end = self + .valid_indices + .partition_point(|index| *index < range.end); + + match kind { + NthValueKind::First => self + .valid_indices + .get(start) + .copied() + .filter(|index| *index < range.end), + NthValueKind::Last => (start < end).then(|| self.valid_indices[end - 1]), + NthValueKind::Nth => match n.cmp(&0) { + Ordering::Greater => { + let index = (n as usize) - 1; + let target = start.checked_add(index)?; + (target < end).then(|| self.valid_indices[target]) + } + Ordering::Less => { + let reverse_index = (-n) as usize; + (reverse_index <= end - start) + .then(|| self.valid_indices[end - reverse_index]) + } + Ordering::Equal => None, + }, + } + } } impl PartitionEvaluator for NthValueEvaluator { @@ -483,15 +559,20 @@ impl PartitionEvaluator for NthValueEvaluator { } impl NthValueEvaluator { - fn valid_index(&self, array: &ArrayRef, range: &Range) -> Option { + fn valid_index(&mut self, array: &ArrayRef, range: &Range) -> Option { let n_range = range.end - range.start; if self.ignore_nulls { // Calculate valid indices, inside the window frame boundaries. - let slice = array.slice(range.start, n_range); - if let Some(nulls) = slice.nulls() + if let Some(nulls) = array.nulls() && nulls.null_count() > 0 { - return self.valid_index_with_nulls(nulls, range.start); + let cache = self + .valid_index_cache + .get_or_insert_with(|| ValidIndexCache::new(nulls)); + if !cache.matches(nulls) { + *cache = ValidIndexCache::new(nulls); + } + return cache.valid_index(range, self.state.kind, self.n); } } // Either no nulls, or nulls are regarded as valid rows @@ -522,34 +603,6 @@ impl NthValueEvaluator { }, } } - - fn valid_index_with_nulls(&self, nulls: &NullBuffer, offset: usize) -> Option { - match self.state.kind { - NthValueKind::First => nulls.valid_indices().next().map(|idx| idx + offset), - NthValueKind::Last => nulls.valid_indices().last().map(|idx| idx + offset), - NthValueKind::Nth => { - match self.n.cmp(&0) { - Ordering::Greater => { - // SQL indices are not 0-based. - let index = (self.n as usize) - 1; - nulls.valid_indices().nth(index).map(|idx| idx + offset) - } - Ordering::Less => { - let reverse_index = (-self.n) as usize; - let valid_indices_len = nulls.len() - nulls.null_count(); - if reverse_index > valid_indices_len { - return None; - } - nulls - .valid_indices() - .nth(valid_indices_len - reverse_index) - .map(|idx| idx + offset) - } - Ordering::Equal => None, - } - } - } - } } #[cfg(test)] @@ -670,6 +723,92 @@ mod tests { Ok(()) } + fn test_i32_ignore_nulls_result( + expr: NthValue, + input_exprs: &[Arc], + expected: Int32Array, + ) -> Result<()> { + let arr: ArrayRef = Arc::new(Int32Array::from(vec![ + None, + Some(2), + None, + Some(4), + Some(5), + None, + ])); + let values = vec![arr]; + let ranges = [ + Range { start: 0, end: 6 }, + Range { start: 0, end: 1 }, + Range { start: 2, end: 3 }, + Range { start: 0, end: 2 }, + Range { start: 0, end: 4 }, + ]; + let input_fields: Vec = + vec![Field::new("f", DataType::Int32, true).into()]; + + let mut evaluator = expr.partition_evaluator(PartitionEvaluatorArgs::new( + input_exprs, + &input_fields, + false, + true, + ))?; + let result = ranges + .iter() + .map(|range| evaluator.evaluate(&values, range)) + .collect::>>()?; + let result = ScalarValue::iter_to_array(result)?; + let result = as_int32_array(&result)?; + assert_eq!(expected, *result); + Ok(()) + } + + #[test] + fn first_value_ignore_nulls_cached_ranges() -> Result<()> { + let expr = Arc::new(Column::new("c3", 0)) as Arc; + test_i32_ignore_nulls_result( + NthValue::first(), + &[expr], + Int32Array::from(vec![Some(2), None, None, Some(2), Some(2)]), + ) + } + + #[test] + fn last_value_ignore_nulls_cached_ranges() -> Result<()> { + let expr = Arc::new(Column::new("c3", 0)) as Arc; + test_i32_ignore_nulls_result( + NthValue::last(), + &[expr], + Int32Array::from(vec![Some(5), None, None, Some(2), Some(4)]), + ) + } + + #[test] + fn nth_value_ignore_nulls_cached_ranges() -> Result<()> { + let expr = Arc::new(Column::new("c3", 0)) as Arc; + let n_value = + Arc::new(Literal::new(ScalarValue::Int32(Some(2)))) as Arc; + + test_i32_ignore_nulls_result( + NthValue::nth(), + &[expr, n_value], + Int32Array::from(vec![Some(4), None, None, None, Some(4)]), + ) + } + + #[test] + fn nth_value_negative_ignore_nulls_cached_ranges() -> Result<()> { + let expr = Arc::new(Column::new("c3", 0)) as Arc; + let n_value = + Arc::new(Literal::new(ScalarValue::Int32(Some(-2)))) as Arc; + + test_i32_ignore_nulls_result( + NthValue::nth(), + &[expr, n_value], + Int32Array::from(vec![Some(4), None, None, None, Some(2)]), + ) + } + #[test] fn nth_value_i64_min_returns_error() { let expr = Arc::new(Column::new("c3", 0)) as Arc;