diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index e700d4a04da3b..0fe0eed2304d5 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -1452,6 +1452,37 @@ fn mathematics_numerical_coercion( ) -> Option { use arrow::datatypes::DataType::*; + // String coercion + if let Some(coerced) = match (lhs_type, rhs_type) { + (s, Float64) + | (Float64, s) + | (Decimal32(_, _), s) + | (s, Decimal32(_, _)) + | (Decimal64(_, _), s) + | (s, Decimal64(_, _)) + | (Decimal128(_, _), s) + | (s, Decimal128(_, _)) + | (Decimal256(_, _), s) + | (s, Decimal256(_, _)) + if s.is_string() => + { + Some(Float64) + } + (s, Float32) | (Float32, s) if s.is_string() => Some(Float32), + (s, Float16) | (Float16, s) if s.is_string() => Some(Float16), + (s, Int64) | (Int64, s) if s.is_string() => Some(Int64), + (s, Int32) | (Int32, s) if s.is_string() => Some(Int32), + (s, Int16) | (Int16, s) if s.is_string() => Some(Int16), + (s, Int8) | (Int8, s) if s.is_string() => Some(Int8), + (s, UInt64) | (UInt64, s) if s.is_string() => Some(UInt64), + (s, UInt32) | (UInt32, s) if s.is_string() => Some(UInt32), + (s, UInt16) | (UInt16, s) if s.is_string() => Some(UInt16), + (s, UInt8) | (UInt8, s) if s.is_string() => Some(UInt8), + _ => None, + } { + return Some(coerced); + } + // Error on any non-numeric type if !both_numeric_or_null_and_numeric(lhs_type, rhs_type) { return None; diff --git a/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs b/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs index 70a8fc0e35a15..ff2cc2e0456fe 100644 --- a/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs +++ b/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs @@ -18,20 +18,6 @@ use super::*; use datafusion_common::assert_contains; -#[test] -fn test_coercion_error() -> Result<()> { - let coercer = - BinaryTypeCoercer::new(&DataType::Float32, &Operator::Plus, &DataType::Utf8); - let result_type = coercer.get_input_types(); - - let e = result_type.unwrap_err(); - assert_eq!( - e.strip_backtrace(), - "Error during planning: Cannot coerce arithmetic expression Float32 + Utf8 to valid types" - ); - Ok(()) -} - #[test] fn test_date_timestamp_arithmetic_error() -> Result<()> { let (lhs, rhs) = BinaryTypeCoercer::new( @@ -225,6 +211,39 @@ fn test_type_coercion_arithmetic() -> Result<()> { // (Int8, _) | (_, Int8) => Some(Int8) test_coercion_binary_rule!(Int8, Int8, Operator::Plus, Int8); + test_coercion_binary_rule_multiple!( + Utf8, + [ + Float64, + Decimal32(10, 2), + Decimal64(10, 2), + Decimal128(10, 2), + Decimal256(10, 2), + ], + Operator::Plus, + Float64 + ); + test_coercion_binary_rule!(Utf8, Float32, Operator::Plus, Float32); + test_coercion_binary_rule!(Utf8, Float16, Operator::Plus, Float16); + test_coercion_binary_rule!(Utf8, Int64, Operator::Plus, Int64); + test_coercion_binary_rule!(Utf8, Int32, Operator::Plus, Int32); + test_coercion_binary_rule!(Utf8, Int16, Operator::Plus, Int16); + test_coercion_binary_rule!(Utf8, Int8, Operator::Plus, Int8); + test_coercion_binary_rule!(Utf8, UInt64, Operator::Plus, UInt64); + test_coercion_binary_rule!(Utf8, UInt32, Operator::Plus, UInt32); + test_coercion_binary_rule!(Utf8, UInt16, Operator::Plus, UInt16); + test_coercion_binary_rule!(Utf8, UInt8, Operator::Plus, UInt8); + test_coercion_binary_rule!(Float32, Utf8, Operator::Plus, Float32); + test_coercion_binary_rule!(Float16, Utf8, Operator::Plus, Float16); + test_coercion_binary_rule!(Int64, Utf8, Operator::Plus, Int64); + test_coercion_binary_rule!(Int32, Utf8, Operator::Plus, Int32); + test_coercion_binary_rule!(Int16, Utf8, Operator::Plus, Int16); + test_coercion_binary_rule!(Int8, Utf8, Operator::Plus, Int8); + test_coercion_binary_rule!(UInt64, Utf8, Operator::Plus, UInt64); + test_coercion_binary_rule!(UInt32, Utf8, Operator::Plus, UInt32); + test_coercion_binary_rule!(UInt16, Utf8, Operator::Plus, UInt16); + test_coercion_binary_rule!(UInt8, Utf8, Operator::Plus, UInt8); + Ok(()) } diff --git a/datafusion/sql/tests/cases/diagnostic.rs b/datafusion/sql/tests/cases/diagnostic.rs index 7a729739469d3..d2da57e837607 100644 --- a/datafusion/sql/tests/cases/diagnostic.rs +++ b/datafusion/sql/tests/cases/diagnostic.rs @@ -202,20 +202,6 @@ fn test_ambiguous_reference() -> Result<()> { Ok(()) } -#[test] -fn test_incompatible_types_binary_arithmetic() -> Result<()> { - let query = "SELECT /*whole+left*/id/*left*/ + /*right*/first_name/*right+whole*/ FROM person"; - let spans = get_spans(query); - let diag = do_query(query); - assert_snapshot!(diag.message, @"expressions have incompatible types"); - assert_eq!(diag.span, Some(spans["whole"])); - assert_snapshot!(diag.notes[0].message, @"has type UInt32"); - assert_eq!(diag.notes[0].span, Some(spans["left"])); - assert_snapshot!(diag.notes[1].message, @"has type Utf8"); - assert_eq!(diag.notes[1].span, Some(spans["right"])); - Ok(()) -} - #[test] fn test_field_not_found_suggestion() -> Result<()> { let query = "SELECT /*whole*/first_na/*whole*/ FROM person";