diff --git a/datafusion/functions/src/string/common.rs b/datafusion/functions/src/string/common.rs index 77af82e25c483..cc97c1b2c1957 100644 --- a/datafusion/functions/src/string/common.rs +++ b/datafusion/functions/src/string/common.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use crate::strings::make_and_append_view; use arrow::array::{ Array, ArrayRef, GenericStringArray, GenericStringBuilder, NullBufferBuilder, - OffsetSizeTrait, StringBuilder, StringViewArray, new_null_array, + OffsetSizeTrait, StringViewArray, StringViewBuilder, new_null_array, }; use arrow::buffer::{Buffer, ScalarBuffer}; use arrow::datatypes::DataType; @@ -358,10 +358,8 @@ where >(array, op)?)), DataType::Utf8View => { let string_array = as_string_view_array(array)?; - let mut string_builder = StringBuilder::with_capacity( - string_array.len(), - string_array.get_array_memory_size(), - ); + let mut string_builder = + StringViewBuilder::with_capacity(string_array.len()); for str in string_array.iter() { if let Some(str) = str { @@ -386,7 +384,7 @@ where } ScalarValue::Utf8View(a) => { let result = a.as_ref().map(|x| op(x)); - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result))) + Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(result))) } other => exec_err!("Unsupported data type {other:?} for function {name}"), }, diff --git a/datafusion/functions/src/string/lower.rs b/datafusion/functions/src/string/lower.rs index 3750d3d290a9c..d91e4595c58ac 100644 --- a/datafusion/functions/src/string/lower.rs +++ b/datafusion/functions/src/string/lower.rs @@ -19,7 +19,6 @@ use arrow::datatypes::DataType; use std::any::Any; use crate::string::common::to_lower; -use crate::utils::utf8_to_str_type; use datafusion_common::Result; use datafusion_common::types::logical_string; use datafusion_expr::{ @@ -82,7 +81,7 @@ impl ScalarUDFImpl for LowerFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - utf8_to_str_type(&arg_types[0], "lower") + Ok(arg_types[0].clone()) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -97,8 +96,7 @@ impl ScalarUDFImpl for LowerFunc { #[cfg(test)] mod tests { use super::*; - use arrow::array::{Array, ArrayRef, StringArray}; - use arrow::datatypes::DataType::Utf8; + use arrow::array::{Array, ArrayRef, StringArray, StringViewArray}; use arrow::datatypes::Field; use datafusion_common::config::ConfigOptions; use std::sync::Arc; @@ -111,7 +109,7 @@ mod tests { number_rows: input.len(), args: vec![ColumnarValue::Array(input)], arg_fields, - return_field: Field::new("f", Utf8, true).into(), + return_field: Field::new("f", expected.data_type().clone(), true).into(), config_options: Arc::new(ConfigOptions::default()), }; @@ -197,4 +195,21 @@ mod tests { to_lower(input, expected) } + + #[test] + fn lower_utf8view() -> Result<()> { + let input = Arc::new(StringViewArray::from(vec![ + Some("ARROW"), + None, + Some("TSCHÜSS"), + ])) as ArrayRef; + + let expected = Arc::new(StringViewArray::from(vec![ + Some("arrow"), + None, + Some("tschüss"), + ])) as ArrayRef; + + to_lower(input, expected) + } } diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs index a2a7db1848f59..80375f58c87be 100644 --- a/datafusion/functions/src/string/upper.rs +++ b/datafusion/functions/src/string/upper.rs @@ -16,7 +16,6 @@ // under the License. use crate::string::common::to_upper; -use crate::utils::utf8_to_str_type; use arrow::datatypes::DataType; use datafusion_common::Result; use datafusion_common::types::logical_string; @@ -81,7 +80,7 @@ impl ScalarUDFImpl for UpperFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - utf8_to_str_type(&arg_types[0], "upper") + Ok(arg_types[0].clone()) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -96,8 +95,7 @@ impl ScalarUDFImpl for UpperFunc { #[cfg(test)] mod tests { use super::*; - use arrow::array::{Array, ArrayRef, StringArray}; - use arrow::datatypes::DataType::Utf8; + use arrow::array::{Array, ArrayRef, StringArray, StringViewArray}; use arrow::datatypes::Field; use datafusion_common::config::ConfigOptions; use std::sync::Arc; @@ -110,7 +108,7 @@ mod tests { number_rows: input.len(), args: vec![ColumnarValue::Array(input)], arg_fields: vec![arg_field], - return_field: Field::new("f", Utf8, true).into(), + return_field: Field::new("f", expected.data_type().clone(), true).into(), config_options: Arc::new(ConfigOptions::default()), }; @@ -196,4 +194,21 @@ mod tests { to_upper(input, expected) } + + #[test] + fn upper_utf8view() -> Result<()> { + let input = Arc::new(StringViewArray::from(vec![ + Some("arrow"), + None, + Some("tschüß"), + ])) as ArrayRef; + + let expected = Arc::new(StringViewArray::from(vec![ + Some("ARROW"), + None, + Some("TSCHÜSS"), + ])) as ArrayRef; + + to_upper(input, expected) + } } diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 5a43d18e23879..9dcfd83ab7f92 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -435,6 +435,11 @@ SELECT upper(arrow_cast('foo', 'Dictionary(Int32, Utf8)')) ---- FOO +query T +SELECT upper(arrow_cast(arrow_cast('foo', 'Dictionary(Int32, Utf8)'), 'Dictionary(Int32, Utf8View)')) +---- +FOO + query T SELECT upper('árvore ação αβγ') ---- @@ -445,6 +450,26 @@ SELECT upper(arrow_cast('árvore ação αβγ', 'Dictionary(Int32, Utf8)')) ---- ÁRVORE AÇÃO ΑΒΓ +query T +SELECT arrow_typeof(upper('foo')) +---- +Utf8 + +query T +SELECT arrow_typeof(upper(arrow_cast('foo', 'LargeUtf8'))) +---- +LargeUtf8 + +query T +SELECT arrow_typeof(upper(arrow_cast('foo', 'Utf8View'))) +---- +Utf8View + +query T +SELECT arrow_typeof(upper(arrow_cast(arrow_cast('foo', 'Dictionary(Int32, Utf8)'), 'Dictionary(Int32, Utf8View)'))) +---- +Utf8View + query T SELECT btrim(' foo ') ---- @@ -490,6 +515,11 @@ SELECT lower(arrow_cast('FOObar', 'Dictionary(Int32, Utf8)')) ---- foobar +query T +SELECT lower(arrow_cast(arrow_cast('FOObar', 'Dictionary(Int32, Utf8)'), 'Dictionary(Int32, Utf8View)')) +---- +foobar + query T SELECT lower('ÁRVORE AÇÃO ΑΒΓ') ---- @@ -500,6 +530,26 @@ SELECT lower(arrow_cast('ÁRVORE AÇÃO ΑΒΓ', 'Dictionary(Int32, Utf8)')) ---- árvore ação αβγ +query T +SELECT arrow_typeof(lower('FOObar')) +---- +Utf8 + +query T +SELECT arrow_typeof(lower(arrow_cast('FOObar', 'LargeUtf8'))) +---- +LargeUtf8 + +query T +SELECT arrow_typeof(lower(arrow_cast('FOObar', 'Utf8View'))) +---- +Utf8View + +query T +SELECT arrow_typeof(lower(arrow_cast(arrow_cast('FOObar', 'Dictionary(Int32, Utf8)'), 'Dictionary(Int32, Utf8View)'))) +---- +Utf8View + query T SELECT ltrim(' foo') ----