From df30f85172452eafb8b09767089faf6749970fce Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 11 May 2024 06:33:23 -0400 Subject: [PATCH 1/6] Minor: Clarify usecase for `LogicalPlan::recompute_schema` (#10443) --- datafusion/expr/src/logical_plan/plan.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 7dca12f793699..9832b69f841a9 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -479,6 +479,11 @@ impl LogicalPlan { /// expressions. For example [`LogicalPlan::Filter`] schema is always the /// same as its input schema. /// + /// This is useful after modifying a plan's `Expr`s (or input plans) via + /// methods such as [Self::map_children] and [Self::map_expressions]. Unlike + /// [Self::with_new_exprs], this method does not require a new set of + /// expressions or input plans. + /// /// # Return value /// Returns an error if there is some issue recomputing the schema. 
/// From d8bcff5db22f47e5da778b8012bca9e16df35540 Mon Sep 17 00:00:00 2001 From: Jeffrey Vo Date: Sat, 11 May 2024 20:56:04 +1000 Subject: [PATCH 2/6] doc: fix old master branch references to main (#10458) --- README.md | 2 +- docs/source/index.rst | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 8c2392850953d..197e5d2b3fe16 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ Here are links to some important information - [Rust Getting Started](https://datafusion.apache.org/user-guide/example-usage.html) - [Rust DataFrame API](https://datafusion.apache.org/user-guide/dataframe.html) - [Rust API docs](https://docs.rs/datafusion/latest/datafusion) -- [Rust Examples](https://github.com/apache/datafusion/tree/master/datafusion-examples) +- [Rust Examples](https://github.com/apache/datafusion/tree/main/datafusion-examples) - [Python DataFrame API](https://arrow.apache.org/datafusion-python/) - [Architecture](https://docs.rs/datafusion/latest/datafusion/index.html#architecture) diff --git a/docs/source/index.rst b/docs/source/index.rst index ca978d4bc30ca..5d6dcd3f87a20 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -46,7 +46,7 @@ The `example usage`_ section in the user guide and the `datafusion-examples`_ co Please see the `developer’s guide`_ for contributing and `communication`_ for getting in touch with us. .. _example usage: user-guide/example-usage.html -.. _datafusion-examples: https://github.com/apache/datafusion/tree/master/datafusion-examples +.. _datafusion-examples: https://github.com/apache/datafusion/tree/main/datafusion-examples .. _developer’s guide: contributor-guide/index.html#developer-s-guide .. 
_communication: contributor-guide/communication.html From c2095e420088b4db2639829a11714839d59024f4 Mon Sep 17 00:00:00 2001 From: NoeB Date: Sat, 11 May 2024 14:30:22 +0200 Subject: [PATCH 3/6] Move bit_and_or_xor unit tests to slt (#10457) * move bit_and_or_xor unit tests to slt Signed-off-by: NoeB * Apply suggestions from code review --------- Signed-off-by: NoeB --- .../src/aggregate/bit_and_or_xor.rs | 127 ------------ .../sqllogictest/test_files/aggregate.slt | 195 ++++++++++++++++++ 2 files changed, 195 insertions(+), 127 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs b/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs index 7244686a5195f..3fa225c5e4791 100644 --- a/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs +++ b/datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs @@ -693,130 +693,3 @@ where + self.values.capacity() * std::mem::size_of::() } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::expressions::col; - use crate::expressions::tests::aggregate; - use crate::generic_test_op; - use arrow::array::*; - use arrow::datatypes::*; - - #[test] - fn bit_and_i32() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![4, 7, 15])); - generic_test_op!(a, DataType::Int32, BitAnd, ScalarValue::from(4i32)) - } - - #[test] - fn bit_and_i32_with_nulls() -> Result<()> { - let a: ArrayRef = - Arc::new(Int32Array::from(vec![Some(1), None, Some(3), Some(5)])); - generic_test_op!(a, DataType::Int32, BitAnd, ScalarValue::from(1i32)) - } - - #[test] - fn bit_and_i32_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - generic_test_op!(a, DataType::Int32, BitAnd, ScalarValue::Int32(None)) - } - - #[test] - fn bit_and_u32() -> Result<()> { - let a: ArrayRef = Arc::new(UInt32Array::from(vec![4_u32, 7_u32, 15_u32])); - generic_test_op!(a, DataType::UInt32, BitAnd, ScalarValue::from(4u32)) - } - - #[test] - fn bit_or_i32() -> Result<()> { - let a: ArrayRef 
= Arc::new(Int32Array::from(vec![4, 7, 15])); - generic_test_op!(a, DataType::Int32, BitOr, ScalarValue::from(15i32)) - } - - #[test] - fn bit_or_i32_with_nulls() -> Result<()> { - let a: ArrayRef = - Arc::new(Int32Array::from(vec![Some(1), None, Some(3), Some(5)])); - generic_test_op!(a, DataType::Int32, BitOr, ScalarValue::from(7i32)) - } - - #[test] - fn bit_or_i32_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - generic_test_op!(a, DataType::Int32, BitOr, ScalarValue::Int32(None)) - } - - #[test] - fn bit_or_u32() -> Result<()> { - let a: ArrayRef = Arc::new(UInt32Array::from(vec![4_u32, 7_u32, 15_u32])); - generic_test_op!(a, DataType::UInt32, BitOr, ScalarValue::from(15u32)) - } - - #[test] - fn bit_xor_i32() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![4, 7, 4, 7, 15])); - generic_test_op!(a, DataType::Int32, BitXor, ScalarValue::from(15i32)) - } - - #[test] - fn bit_xor_i32_with_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(1), - Some(1), - None, - Some(3), - Some(5), - ])); - generic_test_op!(a, DataType::Int32, BitXor, ScalarValue::from(6i32)) - } - - #[test] - fn bit_xor_i32_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - generic_test_op!(a, DataType::Int32, BitXor, ScalarValue::Int32(None)) - } - - #[test] - fn bit_xor_u32() -> Result<()> { - let a: ArrayRef = - Arc::new(UInt32Array::from(vec![4_u32, 7_u32, 4_u32, 7_u32, 15_u32])); - generic_test_op!(a, DataType::UInt32, BitXor, ScalarValue::from(15u32)) - } - - #[test] - fn bit_xor_distinct_i32() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![4, 7, 4, 7, 15])); - generic_test_op!(a, DataType::Int32, DistinctBitXor, ScalarValue::from(12i32)) - } - - #[test] - fn bit_xor_distinct_i32_with_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(1), - Some(1), - None, - Some(3), - Some(5), - ])); - 
generic_test_op!(a, DataType::Int32, DistinctBitXor, ScalarValue::from(7i32)) - } - - #[test] - fn bit_xor_distinct_i32_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - generic_test_op!(a, DataType::Int32, DistinctBitXor, ScalarValue::Int32(None)) - } - - #[test] - fn bit_xor_distinct_u32() -> Result<()> { - let a: ArrayRef = - Arc::new(UInt32Array::from(vec![4_u32, 7_u32, 4_u32, 7_u32, 15_u32])); - generic_test_op!( - a, - DataType::UInt32, - DistinctBitXor, - ScalarValue::from(12u32) - ) - } -} diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 40f78e7f4d24d..1e0d522492e79 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -2285,6 +2285,201 @@ ORDER BY tag 33 11 NULL 33 11 NULL 33 11 NULL B +# bit_and_i32 +statement ok +create table t (c int) as values (4), (7), (15); + +query IT +Select bit_and(c), arrow_typeof(bit_and(c)) from t; +---- +4 Int32 + +statement ok +drop table t; + +# bit_and_i32_with_nulls +statement ok +create table t (c int) as values (1), (NULL), (3), (5); + +query IT +Select bit_and(c), arrow_typeof(bit_and(c)) from t; +---- +1 Int32 + +statement ok +drop table t; + +# bit_and_i32_all_nulls +statement ok +create table t (c int) as values (NULL), (NULL); + +query IT +Select bit_and(c), arrow_typeof(bit_and(c)) from t; +---- +NULL Int32 + +statement ok +drop table t; + +# bit_and_u32 +statement ok +create table t (c int unsigned) as values (4), (7), (15); + +query IT +Select bit_and(c), arrow_typeof(bit_and(c)) from t; +---- +4 UInt32 + +statement ok +drop table t; + +# bit_or_i32 +statement ok +create table t (c int) as values (4), (7), (15); + +query IT +Select bit_or(c), arrow_typeof(bit_or(c)) from t; +---- +15 Int32 + +statement ok +drop table t; + +# bit_or_i32_with_nulls +statement ok +create table t (c int) as values (1), (NULL), (3), (5); + +query IT 
+Select bit_or(c), arrow_typeof(bit_or(c)) from t; +---- +7 Int32 + +statement ok +drop table t; + +#bit_or_i32_all_nulls +statement ok +create table t (c int) as values (NULL), (NULL); + +query IT +Select bit_or(c), arrow_typeof(bit_or(c)) from t; +---- +NULL Int32 + +statement ok +drop table t; + + +#bit_or_u32 +statement ok +create table t (c int unsigned) as values (4), (7), (15); + +query IT +Select bit_or(c), arrow_typeof(bit_or(c)) from t; +---- +15 UInt32 + +statement ok +drop table t; + +#bit_xor_i32 +statement ok +create table t (c int) as values (4), (7), (4), (7), (15); + +query IT +Select bit_xor(c), arrow_typeof(bit_xor(c)) from t; +---- +15 Int32 + +statement ok +drop table t; + +# bit_xor_i32_with_nulls +statement ok +create table t (c int) as values (1), (1), (NULL), (3), (5); + +query IT +Select bit_xor(c), arrow_typeof(bit_xor(c)) from t; +---- +6 Int32 + +statement ok +drop table t; + +# bit_xor_i32_all_nulls +statement ok +create table t (c int) as values (NULL), (NULL); + +query IT +Select bit_xor(c), arrow_typeof(bit_xor(c)) from t; +---- +NULL Int32 + +statement ok +drop table t; + +# bit_xor_u32 +statement ok +create table t (c int unsigned) as values (4), (7), (4), (7), (15); + +query IT +Select bit_xor(c), arrow_typeof(bit_xor(c)) from t; +---- +15 UInt32 + +statement ok +drop table t; + +# bit_xor_distinct_i32 +statement ok +create table t (c int) as values (4), (7), (4), (7), (15); + +query IT +Select bit_xor(DISTINCT c), arrow_typeof(bit_xor(DISTINCT c)) from t; +---- +12 Int32 + +statement ok +drop table t; + +# bit_xor_distinct_i32_with_nulls +statement ok +create table t (c int) as values (1), (1), (NULL), (3), (5); + +query IT +Select bit_xor(DISTINCT c), arrow_typeof(bit_xor(DISTINCT c)) from t; +---- +7 Int32 + + +statement ok +drop table t; + +# bit_xor_distinct_i32_all_nulls +statement ok +create table t (c int ) as values (NULL), (NULL); + +query IT +Select bit_xor(DISTINCT c), arrow_typeof(bit_xor(DISTINCT c)) from t; +---- 
+NULL Int32 + + +statement ok +drop table t; + +# bit_xor_distinct_u32 +statement ok +create table t (c int unsigned) as values (4), (7), (4), (7), (15); + +query IT +Select bit_xor(DISTINCT c), arrow_typeof(bit_xor(DISTINCT c)) from t; +---- +12 UInt32 + +statement ok +drop table t; + statement ok create table bool_aggregate_functions ( c1 boolean not null, From 76f9e2eb44444b1b6adaf97c4601f5bd32d352d1 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sat, 11 May 2024 20:56:40 +0800 Subject: [PATCH 4/6] Introduce user-defined signature (#10439) * introduce new sig Signed-off-by: jayzhan211 * add udfimpl Signed-off-by: jayzhan211 * replace fun Signed-off-by: jayzhan211 * replace array Signed-off-by: jayzhan211 * coalesce Signed-off-by: jayzhan211 * nvl2 Signed-off-by: jayzhan211 * rm variadic equal Signed-off-by: jayzhan211 * fix test Signed-off-by: jayzhan211 * rm err msg to fix ci Signed-off-by: jayzhan211 * user defined sig Signed-off-by: jayzhan211 * add err msg Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * cleanup Signed-off-by: jayzhan211 * fix ci Signed-off-by: jayzhan211 * fix ci Signed-off-by: jayzhan211 * upd comment Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- datafusion/expr/src/expr_schema.rs | 7 +- datafusion/expr/src/signature.rs | 23 +-- .../expr/src/type_coercion/functions.rs | 176 +++++++++++++++--- datafusion/expr/src/udaf.rs | 4 + datafusion/expr/src/udf.rs | 29 +++ datafusion/functions-array/src/make_array.rs | 31 ++- datafusion/functions/src/core/coalesce.rs | 29 ++- datafusion/functions/src/core/nvl2.rs | 44 +++-- .../optimizer/src/analyzer/type_coercion.rs | 59 ++++-- .../physical-expr/src/scalar_function.rs | 4 +- datafusion/sqllogictest/test_files/array.slt | 4 +- .../sqllogictest/test_files/arrow_typeof.slt | 3 +- .../sqllogictest/test_files/coalesce.slt | 16 +- .../sqllogictest/test_files/encoding.slt | 2 +- datafusion/sqllogictest/test_files/errors.slt | 12 +- datafusion/sqllogictest/test_files/expr.slt | 
15 +- datafusion/sqllogictest/test_files/math.slt | 4 +- datafusion/sqllogictest/test_files/scalar.slt | 17 +- datafusion/sqllogictest/test_files/struct.slt | 2 +- .../sqllogictest/test_files/timestamps.slt | 2 +- 20 files changed, 359 insertions(+), 124 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 4aca52d67c4f6..ce79f9da64593 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -23,7 +23,7 @@ use crate::expr::{ }; use crate::field_util::GetFieldAccessSchema; use crate::type_coercion::binary::get_result_type; -use crate::type_coercion::functions::data_types; +use crate::type_coercion::functions::data_types_with_scalar_udf; use crate::{utils, LogicalPlan, Projection, Subquery}; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field}; @@ -139,9 +139,10 @@ impl ExprSchemable for Expr { .map(|e| e.get_type(schema)) .collect::>>()?; // verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` - data_types(&arg_data_types, func.signature()).map_err(|_| { + data_types_with_scalar_udf(&arg_data_types, func).map_err(|err| { plan_datafusion_err!( - "{}", + "{} and {}", + err, utils::generate_signature_error_msg( func.name(), func.signature().clone(), diff --git a/datafusion/expr/src/signature.rs b/datafusion/expr/src/signature.rs index e2505d6fd65f8..5d925c8605ee2 100644 --- a/datafusion/expr/src/signature.rs +++ b/datafusion/expr/src/signature.rs @@ -91,15 +91,12 @@ pub enum TypeSignature { /// # Examples /// A function such as `concat` is `Variadic(vec![DataType::Utf8, DataType::LargeUtf8])` Variadic(Vec), - /// One or more arguments of an arbitrary but equal type. - /// DataFusion attempts to coerce all argument types to match the first argument's type + /// The acceptable signature and coercions rules to coerce arguments to this + /// signature are special for this function. 
If this signature is specified, + /// Datafusion will call [`ScalarUDFImpl::coerce_types`] to prepare argument types. /// - /// # Examples - /// Given types in signature should be coercible to the same final type. - /// A function such as `make_array` is `VariadicEqual`. - /// - /// `make_array(i32, i64) -> make_array(i64, i64)` - VariadicEqual, + /// [`ScalarUDFImpl::coerce_types`]: crate::udf::ScalarUDFImpl::coerce_types + UserDefined, /// One or more arguments with arbitrary types VariadicAny, /// Fixed number of arguments of an arbitrary but equal type out of a list of valid types. @@ -190,8 +187,8 @@ impl TypeSignature { .collect::>() .join(", ")] } - TypeSignature::VariadicEqual => { - vec!["CoercibleT, .., CoercibleT".to_string()] + TypeSignature::UserDefined => { + vec!["UserDefined".to_string()] } TypeSignature::VariadicAny => vec!["Any, .., Any".to_string()], TypeSignature::OneOf(sigs) => { @@ -255,10 +252,10 @@ impl Signature { volatility, } } - /// An arbitrary number of arguments of the same type. - pub fn variadic_equal(volatility: Volatility) -> Self { + /// User-defined coercion rules for the function. 
+ pub fn user_defined(volatility: Volatility) -> Self { Self { - type_signature: TypeSignature::VariadicEqual, + type_signature: TypeSignature::UserDefined, volatility, } } diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index eb4f325ff818c..583d75e1ccfca 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -20,16 +20,114 @@ use std::sync::Arc; use crate::signature::{ ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD, }; -use crate::{Signature, TypeSignature}; +use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature}; use arrow::{ compute::can_cast_types, datatypes::{DataType, TimeUnit}, }; use datafusion_common::utils::{coerced_fixed_size_list_to_list, list_ndims}; -use datafusion_common::{internal_datafusion_err, internal_err, plan_err, Result}; +use datafusion_common::{ + exec_err, internal_datafusion_err, internal_err, plan_err, Result, +}; use super::binary::{comparison_binary_numeric_coercion, comparison_coercion}; +/// Performs type coercion for scalar function arguments. +/// +/// Returns the data types to which each argument must be coerced to +/// match `signature`. +/// +/// For more details on coercion in general, please see the +/// [`type_coercion`](crate::type_coercion) module. 
+pub fn data_types_with_scalar_udf( + current_types: &[DataType], + func: &ScalarUDF, +) -> Result> { + let signature = func.signature(); + + if current_types.is_empty() { + if signature.type_signature.supports_zero_argument() { + return Ok(vec![]); + } else { + return plan_err!( + "[data_types_with_scalar_udf] signature {:?} does not support zero arguments.", + &signature.type_signature + ); + } + } + + let valid_types = + get_valid_types_with_scalar_udf(&signature.type_signature, current_types, func)?; + + if valid_types + .iter() + .any(|data_type| data_type == current_types) + { + return Ok(current_types.to_vec()); + } + + // Try and coerce the argument types to match the signature, returning the + // coerced types from the first matching signature. + for valid_types in valid_types { + if let Some(types) = maybe_data_types(&valid_types, current_types) { + return Ok(types); + } + } + + // none possible -> Error + plan_err!( + "[data_types_with_scalar_udf] Coercion from {:?} to the signature {:?} failed.", + current_types, + &signature.type_signature + ) +} + +pub fn data_types_with_aggregate_udf( + current_types: &[DataType], + func: &AggregateUDF, +) -> Result> { + let signature = func.signature(); + + if current_types.is_empty() { + if signature.type_signature.supports_zero_argument() { + return Ok(vec![]); + } else { + return plan_err!( + "[data_types_with_aggregate_udf] Coercion from {:?} to the signature {:?} failed.", + current_types, + &signature.type_signature + ); + } + } + + let valid_types = get_valid_types_with_aggregate_udf( + &signature.type_signature, + current_types, + func, + )?; + if valid_types + .iter() + .any(|data_type| data_type == current_types) + { + return Ok(current_types.to_vec()); + } + + // Try and coerce the argument types to match the signature, returning the + // coerced types from the first matching signature. 
+ for valid_types in valid_types { + if let Some(types) = maybe_data_types(&valid_types, current_types) { + return Ok(types); + } + } + + // none possible -> Error + plan_err!( + "[data_types_with_aggregate_udf] Coercion from {:?} to the signature {:?} failed.", + current_types, + &signature.type_signature + ) +} + /// Performs type coercion for function arguments. /// /// Returns the data types to which each argument must be coerced to @@ -46,7 +144,7 @@ pub fn data_types( return Ok(vec![]); } else { return plan_err!( - "Coercion from {:?} to the signature {:?} failed.", + "[data_types] Coercion from {:?} to the signature {:?} failed.", current_types, &signature.type_signature ); @@ -72,12 +170,56 @@ pub fn data_types( // none possible -> Error plan_err!( - "Coercion from {:?} to the signature {:?} failed.", + "[data_types] Coercion from {:?} to the signature {:?} failed.", current_types, &signature.type_signature ) } +fn get_valid_types_with_scalar_udf( + signature: &TypeSignature, + current_types: &[DataType], + func: &ScalarUDF, +) -> Result>> { + let valid_types = match signature { + TypeSignature::UserDefined => match func.coerce_types(current_types) { + Ok(coerced_types) => vec![coerced_types], + Err(e) => return exec_err!("User-defined coercion failed with {:?}", e), + }, + TypeSignature::OneOf(signatures) => signatures + .iter() + .filter_map(|t| get_valid_types_with_scalar_udf(t, current_types, func).ok()) + .flatten() + .collect::>(), + _ => get_valid_types(signature, current_types)?, + }; + + Ok(valid_types) +} + +fn get_valid_types_with_aggregate_udf( + signature: &TypeSignature, + current_types: &[DataType], + func: &AggregateUDF, +) -> Result>> { + let valid_types = match signature { + TypeSignature::UserDefined => match func.coerce_types(current_types) { + Ok(coerced_types) => vec![coerced_types], + Err(e) => return exec_err!("User-defined coercion failed with {:?}", e), + }, + TypeSignature::OneOf(signatures) => signatures + .iter() + 
.filter_map(|t| { + get_valid_types_with_aggregate_udf(t, current_types, func).ok() + }) + .flatten() + .collect::>(), + _ => get_valid_types(signature, current_types)?, + }; + + Ok(valid_types) +} + /// Returns a Vec of all possible valid argument types for the given signature. fn get_valid_types( signature: &TypeSignature, @@ -184,32 +326,14 @@ fn get_valid_types( .iter() .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) .collect(), - TypeSignature::VariadicEqual => { - let new_type = current_types.iter().skip(1).try_fold( - current_types.first().unwrap().clone(), - |acc, x| { - // The coerced types found by `comparison_coercion` are not guaranteed to be - // coercible for the arguments. `comparison_coercion` returns more loose - // types that can be coerced to both `acc` and `x` for comparison purpose. - // See `maybe_data_types` for the actual coercion. - let coerced_type = comparison_coercion(&acc, x); - if let Some(coerced_type) = coerced_type { - Ok(coerced_type) - } else { - internal_err!("Coercion from {acc:?} to {x:?} failed.") - } - }, - ); - - match new_type { - Ok(new_type) => vec![vec![new_type; current_types.len()]], - Err(e) => return Err(e), - } + TypeSignature::UserDefined => { + return internal_err!( + "User-defined signature should be handled by function-specific coerce_types." 
+ ) } TypeSignature::VariadicAny => { vec![current_types.to_vec()] } - TypeSignature::Exact(valid_types) => vec![valid_types.clone()], TypeSignature::ArraySignature(ref function_signature) => match function_signature { diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 67c3b51ca3739..e5a47ddcd8b6a 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -195,6 +195,10 @@ impl AggregateUDF { pub fn create_groups_accumulator(&self) -> Result> { self.inner.create_groups_accumulator() } + + pub fn coerce_types(&self, _args: &[DataType]) -> Result> { + not_impl_err!("coerce_types not implemented for {:?} yet", self.name()) + } } impl From for AggregateUDF diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 29ee4a86e57dc..fadea26e7f4ee 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -213,6 +213,11 @@ impl ScalarUDF { pub fn short_circuits(&self) -> bool { self.inner.short_circuits() } + + /// See [`ScalarUDFImpl::coerce_types`] for more details. + pub fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + self.inner.coerce_types(arg_types) + } } impl From for ScalarUDF @@ -420,6 +425,29 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { fn short_circuits(&self) -> bool { false } + + /// Coerce arguments of a function call to types that the function can evaluate. + /// + /// This function is only called if [`ScalarUDFImpl::signature`] returns [`crate::TypeSignature::UserDefined`]. 
Most + /// UDFs should return one of the other variants of `TypeSignature` which handle common + /// cases + /// + /// See the [type coercion module](crate::type_coercion) + /// documentation for more details on type coercion + /// + /// For example, if your function requires a floating point argument, but the user calls + /// it like `my_func(1::int)` (aka with `1` as an integer), coerce_types could return `[DataType::Float64]` + /// to ensure the argument was cast to `1::double` + /// + /// # Parameters + /// * `arg_types`: The types of the arguments this function is called with + /// + /// # Return value + /// A Vec the same length as `arg_types`. DataFusion will `CAST` the function call + /// arguments to these specific types. + fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { + not_impl_err!("Function {} does not implement coerce_types", self.name()) + } } /// ScalarUDF that adds an alias to the underlying function. It is better to @@ -446,6 +474,7 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { fn as_any(&self) -> &dyn Any { self } + fn name(&self) -> &str { self.inner.name() } diff --git a/datafusion/functions-array/src/make_array.rs b/datafusion/functions-array/src/make_array.rs index 770276938f6be..4f7dda933f427 100644 --- a/datafusion/functions-array/src/make_array.rs +++ b/datafusion/functions-array/src/make_array.rs @@ -26,12 +26,12 @@ use arrow_array::{ use arrow_buffer::OffsetBuffer; use arrow_schema::DataType::{LargeList, List, Null}; use arrow_schema::{DataType, Field}; +use datafusion_common::internal_err; use datafusion_common::{plan_err, utils::array_into_list_array, Result}; use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::Expr; -use datafusion_expr::{ - ColumnarValue, ScalarUDFImpl, Signature, TypeSignature, Volatility, -}; +use datafusion_expr::type_coercion::binary::comparison_coercion; +use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{Expr, TypeSignature}; use 
crate::utils::make_scalar_function; @@ -58,10 +58,10 @@ impl MakeArray { pub fn new() -> Self { Self { signature: Signature::one_of( - vec![TypeSignature::VariadicEqual, TypeSignature::Any(0)], + vec![TypeSignature::UserDefined, TypeSignature::Any(0)], Volatility::Immutable, ), - aliases: vec![String::from("make_array"), String::from("make_list")], + aliases: vec![String::from("make_list")], } } } @@ -111,6 +111,25 @@ impl ScalarUDFImpl for MakeArray { fn aliases(&self) -> &[String] { &self.aliases } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let new_type = arg_types.iter().skip(1).try_fold( + arg_types.first().unwrap().clone(), + |acc, x| { + // The coerced types found by `comparison_coercion` are not guaranteed to be + // coercible for the arguments. `comparison_coercion` returns more loose + // types that can be coerced to both `acc` and `x` for comparison purpose. + // See `maybe_data_types` for the actual coercion. + let coerced_type = comparison_coercion(&acc, x); + if let Some(coerced_type) = coerced_type { + Ok(coerced_type) + } else { + internal_err!("Coercion from {acc:?} to {x:?} failed.") + } + }, + )?; + Ok(vec![new_type; arg_types.len()]) + } } /// `make_array_inner` is the implementation of the `make_array` function. 
diff --git a/datafusion/functions/src/core/coalesce.rs b/datafusion/functions/src/core/coalesce.rs index 76f2a3ed741b4..63778eb7738ac 100644 --- a/datafusion/functions/src/core/coalesce.rs +++ b/datafusion/functions/src/core/coalesce.rs @@ -22,8 +22,8 @@ use arrow::compute::kernels::zip::zip; use arrow::compute::{and, is_not_null, is_null}; use arrow::datatypes::DataType; -use datafusion_common::{exec_err, Result}; -use datafusion_expr::type_coercion::functions::data_types; +use datafusion_common::{exec_err, internal_err, Result}; +use datafusion_expr::type_coercion::binary::comparison_coercion; use datafusion_expr::ColumnarValue; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; @@ -41,7 +41,7 @@ impl Default for CoalesceFunc { impl CoalesceFunc { pub fn new() -> Self { Self { - signature: Signature::variadic_equal(Volatility::Immutable), + signature: Signature::user_defined(Volatility::Immutable), } } } @@ -60,9 +60,7 @@ impl ScalarUDFImpl for CoalesceFunc { } fn return_type(&self, arg_types: &[DataType]) -> Result { - // COALESCE has multiple args and they might get coerced, get a preview of this - let coerced_types = data_types(arg_types, self.signature()); - coerced_types.map(|types| types[0].clone()) + Ok(arg_types[0].clone()) } /// coalesce evaluates to the first value which is not NULL @@ -124,6 +122,25 @@ impl ScalarUDFImpl for CoalesceFunc { fn short_circuits(&self) -> bool { true } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let new_type = arg_types.iter().skip(1).try_fold( + arg_types.first().unwrap().clone(), + |acc, x| { + // The coerced types found by `comparison_coercion` are not guaranteed to be + // coercible for the arguments. `comparison_coercion` returns more loose + // types that can be coerced to both `acc` and `x` for comparison purpose. + // See `maybe_data_types` for the actual coercion. 
+ let coerced_type = comparison_coercion(&acc, x); + if let Some(coerced_type) = coerced_type { + Ok(coerced_type) + } else { + internal_err!("Coercion from {acc:?} to {x:?} failed.") + } + }, + )?; + Ok(vec![new_type; arg_types.len()]) + } } #[cfg(test)] diff --git a/datafusion/functions/src/core/nvl2.rs b/datafusion/functions/src/core/nvl2.rs index 66b9ef566a78b..573ac72425fb4 100644 --- a/datafusion/functions/src/core/nvl2.rs +++ b/datafusion/functions/src/core/nvl2.rs @@ -19,8 +19,11 @@ use arrow::array::Array; use arrow::compute::is_not_null; use arrow::compute::kernels::zip::zip; use arrow::datatypes::DataType; -use datafusion_common::{internal_err, plan_datafusion_err, Result}; -use datafusion_expr::{utils, ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_common::{exec_err, internal_err, Result}; +use datafusion_expr::{ + type_coercion::binary::comparison_coercion, ColumnarValue, ScalarUDFImpl, Signature, + Volatility, +}; #[derive(Debug)] pub struct NVL2Func { @@ -36,7 +39,7 @@ impl Default for NVL2Func { impl NVL2Func { pub fn new() -> Self { Self { - signature: Signature::variadic_equal(Volatility::Immutable), + signature: Signature::user_defined(Volatility::Immutable), } } } @@ -55,22 +58,37 @@ impl ScalarUDFImpl for NVL2Func { } fn return_type(&self, arg_types: &[DataType]) -> Result { - if arg_types.len() != 3 { - return Err(plan_datafusion_err!( - "{}", - utils::generate_signature_error_msg( - self.name(), - self.signature().clone(), - arg_types, - ) - )); - } Ok(arg_types[1].clone()) } fn invoke(&self, args: &[ColumnarValue]) -> Result { nvl2_func(args) } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 3 { + return exec_err!( + "NVL2 takes exactly three arguments, but got {}", + arg_types.len() + ); + } + let new_type = arg_types.iter().skip(1).try_fold( + arg_types.first().unwrap().clone(), + |acc, x| { + // The coerced types found by `comparison_coercion` are not guaranteed to be + // 
coercible for the arguments. `comparison_coercion` returns more loose + // types that can be coerced to both `acc` and `x` for comparison purpose. + // See `maybe_data_types` for the actual coercion. + let coerced_type = comparison_coercion(&acc, x); + if let Some(coerced_type) = coerced_type { + Ok(coerced_type) + } else { + internal_err!("Coercion from {acc:?} to {x:?} failed.") + } + }, + )?; + Ok(vec![new_type; arg_types.len()]) + } } fn nvl2_func(args: &[ColumnarValue]) -> Result { diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 61b1d1d77b20c..e5c7afa10e3a2 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -37,7 +37,9 @@ use datafusion_expr::logical_plan::Subquery; use datafusion_expr::type_coercion::binary::{ comparison_coercion, get_input_types, like_coercion, }; -use datafusion_expr::type_coercion::functions::data_types; +use datafusion_expr::type_coercion::functions::{ + data_types_with_aggregate_udf, data_types_with_scalar_udf, +}; use datafusion_expr::type_coercion::other::{ get_coerce_type_for_case_expression, get_coerce_type_for_list, }; @@ -45,8 +47,8 @@ use datafusion_expr::type_coercion::{is_datetime, is_utf8_or_large_utf8}; use datafusion_expr::utils::merge_schema; use datafusion_expr::{ is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown, not, - type_coercion, AggregateFunction, Expr, ExprSchemable, LogicalPlan, Operator, - ScalarUDF, Signature, WindowFrame, WindowFrameBound, WindowFrameUnits, + type_coercion, AggregateFunction, AggregateUDF, Expr, ExprSchemable, LogicalPlan, + Operator, ScalarUDF, Signature, WindowFrame, WindowFrameBound, WindowFrameUnits, }; use crate::analyzer::AnalyzerRule; @@ -303,8 +305,11 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { Ok(Transformed::yes(Expr::Case(case))) } Expr::ScalarFunction(ScalarFunction { func, args }) => { - let 
new_expr = - coerce_arguments_for_signature(args, self.schema, func.signature())?; + let new_expr = coerce_arguments_for_signature_with_scalar_udf( + args, + self.schema, + &func, + )?; let new_expr = coerce_arguments_for_fun(new_expr, self.schema, &func)?; Ok(Transformed::yes(Expr::ScalarFunction( ScalarFunction::new_udf(func, new_expr), @@ -337,10 +342,10 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { ))) } AggregateFunctionDefinition::UDF(fun) => { - let new_expr = coerce_arguments_for_signature( + let new_expr = coerce_arguments_for_signature_with_aggregate_udf( args, self.schema, - fun.signature(), + &fun, )?; Ok(Transformed::yes(Expr::AggregateFunction( expr::AggregateFunction::new_udf( @@ -532,10 +537,37 @@ fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchema) -> Result<Expr> { /// `signature`, if possible. /// /// See the module level documentation for more detail on coercion. -fn coerce_arguments_for_signature( +fn coerce_arguments_for_signature_with_scalar_udf( expressions: Vec<Expr>, schema: &DFSchema, - signature: &Signature, + func: &ScalarUDF, +) -> Result<Vec<Expr>> { + if expressions.is_empty() { + return Ok(expressions); + } + + let current_types = expressions + .iter() + .map(|e| e.get_type(schema)) + .collect::<Result<Vec<_>>>()?; + + let new_types = data_types_with_scalar_udf(&current_types, func)?; + + expressions + .into_iter() + .enumerate() + .map(|(i, expr)| expr.cast_to(&new_types[i], schema)) + .collect() +} + +/// Returns `expressions` coerced to types compatible with +/// `signature`, if possible. +/// +/// See the module level documentation for more detail on coercion. 
+fn coerce_arguments_for_signature_with_aggregate_udf( + expressions: Vec<Expr>, + schema: &DFSchema, + func: &AggregateUDF, ) -> Result<Vec<Expr>> { if expressions.is_empty() { return Ok(expressions); @@ -546,7 +578,7 @@ fn coerce_arguments_for_signature( .map(|e| e.get_type(schema)) .collect::<Result<Vec<_>>>()?; - let new_types = data_types(&current_types, signature)?; + let new_types = data_types_with_aggregate_udf(&current_types, func)?; expressions .into_iter() @@ -833,12 +865,9 @@ mod test { signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable), }) .call(vec![lit("Apple")]); - let plan_err = Projection::try_new(vec![udf], empty) + Projection::try_new(vec![udf], empty) .expect_err("Expected an error due to incorrect function input"); - let expected_error = "Error during planning: No function matches the given name and argument types 'TestScalarUDF(Utf8)'. You might need to add explicit type casts."; - - assert!(plan_err.to_string().starts_with(expected_error)); Ok(()) } @@ -914,7 +943,7 @@ mod test { .err() .unwrap(); assert_eq!( - "type_coercion\ncaused by\nError during planning: Coercion from [Utf8] to the signature Uniform(1, [Float64]) failed.", + "type_coercion\ncaused by\nError during planning: [data_types_with_aggregate_udf] Coercion from [Utf8] to the signature Uniform(1, [Float64]) failed.", err.strip_backtrace() ); Ok(()) diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 180f2a7946bd5..1244a9b4db38f 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -39,7 +39,7 @@ use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::{internal_err, DFSchema, Result}; -use datafusion_expr::type_coercion::functions::data_types; +use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf; use datafusion_expr::{expr_vec_fmt, ColumnarValue, Expr, FuncMonotonicity, ScalarUDF}; use 
crate::physical_expr::{down_cast_any_ref, physical_exprs_equal}; @@ -220,7 +220,7 @@ pub fn create_physical_expr( .collect::>>()?; // verify that input data types is consistent with function's `TypeSignature` - data_types(&input_expr_types, fun.signature())?; + data_types_with_scalar_udf(&input_expr_types, fun)?; // Since we have arg_types, we dont need args and schema. let return_type = diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index eaec0f4d8d6a2..eeb5dc01b6e7e 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -1137,7 +1137,7 @@ from arrays_values_without_nulls; ## array_element (aliases: array_extract, list_extract, list_element) # array_element error -query error DataFusion error: Error during planning: No function matches the given name and argument types 'array_element\(Int64, Int64\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tarray_element\(array, index\) +query error select array_element(1, 2); # array_element with null @@ -4625,7 +4625,7 @@ NULL 10 ## array_dims (aliases: `list_dims`) # array dims error -query error DataFusion error: Error during planning: No function matches the given name and argument types 'array_dims\(Int64\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tarray_dims\(array\) +query error select array_dims(1); # array_dims scalar function diff --git a/datafusion/sqllogictest/test_files/arrow_typeof.slt b/datafusion/sqllogictest/test_files/arrow_typeof.slt index 3e8694f3b2c2b..94cce61245e17 100644 --- a/datafusion/sqllogictest/test_files/arrow_typeof.slt +++ b/datafusion/sqllogictest/test_files/arrow_typeof.slt @@ -92,10 +92,9 @@ SELECT arrow_cast('1', 'Int16') 1 # Basic error test -query error DataFusion error: Error during planning: No function matches the given name and argument types 'arrow_cast\(Utf8\)'. You might need to add explicit type casts. 
+query error SELECT arrow_cast('1') - query error DataFusion error: Error during planning: arrow_cast requires its second argument to be a constant string, got Literal\(Int64\(43\)\) SELECT arrow_cast('1', 43) diff --git a/datafusion/sqllogictest/test_files/coalesce.slt b/datafusion/sqllogictest/test_files/coalesce.slt index 527d4fe9c41e4..a0317ac4a5f4e 100644 --- a/datafusion/sqllogictest/test_files/coalesce.slt +++ b/datafusion/sqllogictest/test_files/coalesce.slt @@ -23,7 +23,7 @@ select coalesce(1, 2, 3); 1 # test with first null -query IT +query ?T select coalesce(null, 3, 2, 1), arrow_typeof(coalesce(null, 3, 2, 1)); ---- 3 Int64 @@ -35,7 +35,7 @@ select coalesce(null, null); NULL # cast to float -query RT +query IT select coalesce(1, 2.0), arrow_typeof(coalesce(1, 2.0)) @@ -51,7 +51,7 @@ select ---- 2 Float64 -query RT +query IT select coalesce(1, arrow_cast(2.0, 'Float32')), arrow_typeof(coalesce(1, arrow_cast(2.0, 'Float32'))) @@ -177,7 +177,7 @@ select 2 Decimal256(22, 2) # coalesce string -query TT +query T? select coalesce('', 'test'), coalesce(null, 'test'); @@ -226,7 +226,7 @@ select coalesce(column1, 'none_set') from test1; foo none_set -query T +query ? select coalesce(null, column1, 'none_set') from test1; ---- foo @@ -248,12 +248,12 @@ select coalesce(34, arrow_cast(123, 'Dictionary(Int32, Int8)')); ---- 34 -query I +query ? select coalesce(arrow_cast(123, 'Dictionary(Int32, Int8)'), 34); ---- 123 -query I +query ? 
select coalesce(null, 34, arrow_cast(123, 'Dictionary(Int32, Int8)')); ---- 34 @@ -288,7 +288,7 @@ SELECT COALESCE(c1, c2) FROM test NULL # numeric string is coerced to numeric in both Postgres and DuckDB -query T +query I SELECT COALESCE(c1, c2, '-1') FROM test; ---- 0 diff --git a/datafusion/sqllogictest/test_files/encoding.slt b/datafusion/sqllogictest/test_files/encoding.slt index 9f4f508e23f32..626af88aa9b8c 100644 --- a/datafusion/sqllogictest/test_files/encoding.slt +++ b/datafusion/sqllogictest/test_files/encoding.slt @@ -40,7 +40,7 @@ select decode(12, 'hex') query error DataFusion error: Error during planning: There is no built\-in encoding named 'non_encoding', currently supported encodings are: base64, hex select decode(hex_field, 'non_encoding') from test; -query error DataFusion error: Error during planning: No function matches the given name and argument types 'to_hex\(Utf8\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tto_hex\(Int64\) +query error select to_hex(hex_field) from test; # Arrays tests diff --git a/datafusion/sqllogictest/test_files/errors.slt b/datafusion/sqllogictest/test_files/errors.slt index ab281eac31f52..b5464e2a274c3 100644 --- a/datafusion/sqllogictest/test_files/errors.slt +++ b/datafusion/sqllogictest/test_files/errors.slt @@ -38,7 +38,7 @@ WITH HEADER ROW LOCATION '../../testing/data/csv/aggregate_test_100.csv' # csv_query_error -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'sin\(Utf8\)'. 
You might need to add explicit type casts.\n\tCandidate functions:\n\tsin\(Float64/Float32\) +statement error SELECT sin(c1) FROM aggregate_test_100 # cast_expressions_error @@ -80,23 +80,23 @@ SELECT COUNT(*) FROM way.too.many.namespaces.as.ident.prefixes.aggregate_test_10 # # error message for wrong function signature (Variadic: arbitrary number of args all from some common types) -statement error Error during planning: No function matches the given name and argument types 'concat\(\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tconcat\(Utf8, ..\) +statement error SELECT concat(); # error message for wrong function signature (Uniform: t args all from some common types) -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'nullif\(Int64\)'. You might need to add explicit type casts. +statement error SELECT nullif(1); # error message for wrong function signature (Exact: exact number of args of an exact type) -statement error Error during planning: No function matches the given name and argument types 'pi\(Float64\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tpi\(\) +statement error SELECT pi(3.14); # error message for wrong function signature (Any: fixed number of args of arbitrary types) -statement error Error during planning: No function matches the given name and argument types 'arrow_typeof\(Int64, Int64\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tarrow_typeof\(Any\) +statement error SELECT arrow_typeof(1, 1); # error message for wrong function signature (OneOf: fixed number of args of arbitrary types) -statement error Error during planning: No function matches the given name and argument types 'power\(Int64, Int64, Int64\)'. 
You might need to add explicit type casts.\n\tCandidate functions:\n\tpower\(Int64, Int64\)\n\tpower\(Float64, Float64\) +statement error SELECT power(1, 2, 3); # diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index 7e7ebd8529daf..129a67208354f 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -1899,22 +1899,21 @@ a # The 'from' and 'for' parameters don't support string types, because they should be treated as # regular expressions, which we have not implemented yet. -query error DataFusion error: Error during planning: No function matches the given name and argument types +query error SELECT substring('alphabet' FROM '3') -query error DataFusion error: Error during planning: No function matches the given name and argument types +query error SELECT substring('alphabet' FROM '3' FOR '2') -query error DataFusion error: Error during planning: No function matches the given name and argument types +query error SELECT substring('alphabet' FROM '3' FOR 2) -query error DataFusion error: Error during planning: No function matches the given name and argument types +query error SELECT substring('alphabet' FROM 3 FOR '2') -query error DataFusion error: Error during planning: No function matches the given name and argument types +query error SELECT substring('alphabet' FOR '2') - ##### csv_query_nullif_divide_by_0 @@ -2275,13 +2274,13 @@ select f64, round(1.0 / f64) as i64_1, acos(round(1.0 / f64)) from doubles; 10.1 0 1.570796326795 # common subexpr with coalesce (short-circuited) -query RRR rowsort +query RRR select f64, coalesce(1.0 / f64, 0.0), acos(coalesce(1.0 / f64, 0.0)) from doubles; ---- 10.1 0.09900990099 1.471623942989 # common subexpr with coalesce (short-circuited) and alias -query RRR rowsort +query RRR select f64, coalesce(1.0 / f64, 0.0) as f64_1, acos(coalesce(1.0 / f64, 0.0)) from doubles; ---- 10.1 0.09900990099 1.471623942989 diff --git 
a/datafusion/sqllogictest/test_files/math.slt b/datafusion/sqllogictest/test_files/math.slt index 802323ca45ee2..3315ff4549248 100644 --- a/datafusion/sqllogictest/test_files/math.slt +++ b/datafusion/sqllogictest/test_files/math.slt @@ -113,11 +113,11 @@ SELECT iszero(1.0), iszero(0.0), iszero(-0.0), iszero(NULL) false true true NULL # abs: empty argumnet -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'abs\(\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tabs\(Any\) +statement error SELECT abs(); # abs: wrong number of arguments -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'abs\(Int64, Int64\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tabs\(Any\) +statement error SELECT abs(1, 2); # abs: unsupported argument type diff --git a/datafusion/sqllogictest/test_files/scalar.slt b/datafusion/sqllogictest/test_files/scalar.slt index 7fb2d55ff84a9..c52881b7b0ba3 100644 --- a/datafusion/sqllogictest/test_files/scalar.slt +++ b/datafusion/sqllogictest/test_files/scalar.slt @@ -1799,34 +1799,33 @@ statement ok drop table test # error message for wrong function signature (Variadic: arbitrary number of args all from some common types) -statement error Error during planning: No function matches the given name and argument types 'concat\(\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tconcat\(Utf8, ..\) +statement error SELECT concat(); # error message for wrong function signature (Uniform: t args all from some common types) -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'nullif\(Int64\)'. You might need to add explicit type casts. 
+statement error SELECT nullif(1); - # error message for wrong function signature (Exact: exact number of args of an exact type) -statement error Error during planning: No function matches the given name and argument types 'pi\(Float64\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tpi\(\) +statement error SELECT pi(3.14); # error message for wrong function signature (Any: fixed number of args of arbitrary types) -statement error Error during planning: No function matches the given name and argument types 'arrow_typeof\(Int64, Int64\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tarrow_typeof\(Any\) +statement error SELECT arrow_typeof(1, 1); # error message for wrong function signature (OneOf: fixed number of args of arbitrary types) -statement error Error during planning: No function matches the given name and argument types 'power\(Int64, Int64, Int64\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tpower\(Int64, Int64\)\n\tpower\(Float64, Float64\) +statement error SELECT power(1, 2, 3); # The following functions need 1 argument -statement error Error during planning: No function matches the given name and argument types 'abs\(\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tabs\(Any\) +statement error SELECT abs(); -statement error Error during planning: No function matches the given name and argument types 'acos\(\)'. You might need to add explicit type casts.\n\tCandidate functions:\n\tacos\(Float64/Float32\) +statement error SELECT acos(); -statement error Error during planning: No function matches the given name and argument types 'isnan\(\)'. 
You might need to add explicit type casts.\n\tCandidate functions:\n\tisnan\(Float32\)\n\tisnan\(Float64\) +statement error SELECT isnan(); # turn off enable_ident_normalization diff --git a/datafusion/sqllogictest/test_files/struct.slt b/datafusion/sqllogictest/test_files/struct.slt index 3e685cbb45a03..46a08709c3a3b 100644 --- a/datafusion/sqllogictest/test_files/struct.slt +++ b/datafusion/sqllogictest/test_files/struct.slt @@ -92,7 +92,7 @@ physical_plan 02)--MemoryExec: partitions=1, partition_sizes=[1] # error on 0 arguments -query error DataFusion error: Error during planning: No function matches the given name and argument types 'named_struct\(\)'. You might need to add explicit type casts. +query error select named_struct(); # error on odd number of arguments #1 diff --git a/datafusion/sqllogictest/test_files/timestamps.slt b/datafusion/sqllogictest/test_files/timestamps.slt index 32a28231d0340..13fb8fba0d315 100644 --- a/datafusion/sqllogictest/test_files/timestamps.slt +++ b/datafusion/sqllogictest/test_files/timestamps.slt @@ -538,7 +538,7 @@ select to_timestamp_seconds(cast (1 as int)); ########## # invalid second arg type -query error DataFusion error: Error during planning: No function matches the given name and argument types 'date_bin\(Interval\(MonthDayNano\), Int64, Timestamp\(Nanosecond, None\)\)'\. 
+query error SELECT DATE_BIN(INTERVAL '0 second', 25, TIMESTAMP '1970-01-01T00:00:00Z') # not support interval 0 From 6d413a4b52fcca76c29cba997661d3ce41c49d72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?= Date: Sun, 12 May 2024 00:34:32 +0800 Subject: [PATCH 5/6] Remove AggregateFunctionDefinition::Name (#10441) --- datafusion/core/src/physical_planner.rs | 8 -------- datafusion/expr/src/expr.rs | 10 ++-------- datafusion/expr/src/expr_schema.rs | 3 --- datafusion/expr/src/tree_node.rs | 5 +---- datafusion/optimizer/src/analyzer/type_coercion.rs | 3 --- datafusion/optimizer/src/decorrelate.rs | 3 --- datafusion/proto/src/logical_plan/to_proto.rs | 6 ------ datafusion/substrait/src/logical_plan/producer.rs | 3 --- 8 files changed, 3 insertions(+), 38 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 132bc3953cd3d..d4a9a949fc418 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -276,9 +276,6 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { .collect::>>()?; Ok(format!("{}({})", fun.name(), names.join(","))) } - AggregateFunctionDefinition::Name(_) => { - internal_err!("Aggregate function `Expr` with name should be resolved.") - } }, Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => Ok(format!( @@ -1947,11 +1944,6 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( )?; (agg_expr, filter, physical_sort_exprs) } - AggregateFunctionDefinition::Name(_) => { - return internal_err!( - "Aggregate function name should have been resolved" - ) - } }; Ok((agg_expr, filter, order_by)) } diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index c531d7af17568..84e4cb6435a3a 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -515,9 +515,6 @@ pub enum AggregateFunctionDefinition { BuiltIn(aggregate_function::AggregateFunction), /// 
Resolved to a user defined aggregate function UDF(Arc), - /// A aggregation function constructed with name. This variant can not be executed directly - /// and instead must be resolved to one of the other variants prior to physical planning. - Name(Arc), } impl AggregateFunctionDefinition { @@ -526,7 +523,6 @@ impl AggregateFunctionDefinition { match self { AggregateFunctionDefinition::BuiltIn(fun) => fun.name(), AggregateFunctionDefinition::UDF(udf) => udf.name(), - AggregateFunctionDefinition::Name(func_name) => func_name.as_ref(), } } } @@ -1857,8 +1853,7 @@ pub(crate) fn create_name(e: &Expr) -> Result { null_treatment, }) => { let name = match func_def { - AggregateFunctionDefinition::BuiltIn(..) - | AggregateFunctionDefinition::Name(..) => { + AggregateFunctionDefinition::BuiltIn(..) => { create_function_name(func_def.name(), *distinct, args)? } AggregateFunctionDefinition::UDF(..) => { @@ -1878,8 +1873,7 @@ pub(crate) fn create_name(e: &Expr) -> Result { info += &format!(" {}", nt); } match func_def { - AggregateFunctionDefinition::BuiltIn(..) - | AggregateFunctionDefinition::Name(..) => { + AggregateFunctionDefinition::BuiltIn(..) => { Ok(format!("{}{}", name, info)) } AggregateFunctionDefinition::UDF(fun) => { diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index ce79f9da64593..2c08dbe0429a2 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -174,9 +174,6 @@ impl ExprSchemable for Expr { AggregateFunctionDefinition::UDF(fun) => { Ok(fun.return_type(&data_types)?) 
} - AggregateFunctionDefinition::Name(_) => { - internal_err!("Function `Expr` with name should be resolved.") - } } } Expr::Not(_) diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 710164eca3d0e..1b3b5e8fcb836 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -27,7 +27,7 @@ use crate::{Expr, GetFieldAccess}; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, }; -use datafusion_common::{internal_err, map_until_stop_and_collect, Result}; +use datafusion_common::{map_until_stop_and_collect, Result}; impl TreeNode for Expr { fn apply_children Result>( @@ -348,9 +348,6 @@ impl TreeNode for Expr { null_treatment, ))) } - AggregateFunctionDefinition::Name(_) => { - internal_err!("Function `Expr` with name should be resolved.") - } }, )?, Expr::GroupingSet(grouping_set) => match grouping_set { diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index e5c7afa10e3a2..994adf732785d 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -358,9 +358,6 @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> { ), ))) } - AggregateFunctionDefinition::Name(_) => { - internal_err!("Function `Expr` with name should be resolved.") - } }, Expr::WindowFunction(WindowFunction { fun, diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index a6abec9efd8c2..3959223e68c1a 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -398,9 +398,6 @@ fn agg_exprs_evaluation_result_on_empty_batch( AggregateFunctionDefinition::UDF { .. 
} => { Transformed::yes(Expr::Literal(ScalarValue::Null)) } - AggregateFunctionDefinition::Name(_) => { - Transformed::yes(Expr::Literal(ScalarValue::Null)) - } }, _ => Transformed::no(expr), }; diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 4c29d7551bc66..ecdbde6faf597 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -746,12 +746,6 @@ pub fn serialize_expr( }, ))), }, - AggregateFunctionDefinition::Name(_) => { - return Err(Error::NotImplemented( - "Proto serialization error: Trying to serialize a unresolved function" - .to_string(), - )); - } }, Expr::ScalarVariable(_, _) => { diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 39b2b0aa16066..db5d341bc225f 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -733,9 +733,6 @@ pub fn to_substrait_agg_measure( } }) } - AggregateFunctionDefinition::Name(name) => { - internal_err!("AggregateFunctionDefinition::Name({:?}) should be resolved during `AnalyzerRule`", name) - } } } From 1eff714ef8356dc305047386ba250b62bed6a795 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 11 May 2024 12:36:29 -0400 Subject: [PATCH 6/6] Remove some Expr clones in `EliminateCrossJoin`(3%-5% faster planning) (#10430) * Remove some Expr clones in `EliminateCrossJoin` * Apply suggestions from code review Co-authored-by: comphead * fix --------- Co-authored-by: comphead --- .../optimizer/src/eliminate_cross_join.rs | 123 ++++----- datafusion/optimizer/src/join_key_set.rs | 240 ++++++++++++++++++ datafusion/optimizer/src/lib.rs | 1 + 3 files changed, 291 insertions(+), 73 deletions(-) create mode 100644 datafusion/optimizer/src/join_key_set.rs diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 
a807ee5ff2c50..923be75748037 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -16,11 +16,11 @@ // under the License. //! [`EliminateCrossJoin`] converts `CROSS JOIN` to `INNER JOIN` if join predicates are available. -use std::collections::HashSet; use std::sync::Arc; use crate::{utils, OptimizerConfig, OptimizerRule}; +use crate::join_key_set::JoinKeySet; use datafusion_common::{plan_err, Result}; use datafusion_expr::expr::{BinaryExpr, Expr}; use datafusion_expr::logical_plan::{ @@ -55,7 +55,7 @@ impl OptimizerRule for EliminateCrossJoin { plan: &LogicalPlan, config: &dyn OptimizerConfig, ) -> Result> { - let mut possible_join_keys: Vec<(Expr, Expr)> = vec![]; + let mut possible_join_keys = JoinKeySet::new(); let mut all_inputs: Vec = vec![]; let parent_predicate = match plan { LogicalPlan::Filter(filter) => { @@ -76,7 +76,7 @@ impl OptimizerRule for EliminateCrossJoin { extract_possible_join_keys( &filter.predicate, &mut possible_join_keys, - )?; + ); Some(&filter.predicate) } _ => { @@ -101,7 +101,7 @@ impl OptimizerRule for EliminateCrossJoin { }; // Join keys are handled locally: - let mut all_join_keys = HashSet::<(Expr, Expr)>::new(); + let mut all_join_keys = JoinKeySet::new(); let mut left = all_inputs.remove(0); while !all_inputs.is_empty() { left = find_inner_join( @@ -131,7 +131,7 @@ impl OptimizerRule for EliminateCrossJoin { .map(|f| Some(LogicalPlan::Filter(f))) } else { // Remove join expressions from filter: - match remove_join_expressions(predicate, &all_join_keys)? { + match remove_join_expressions(predicate.clone(), &all_join_keys) { Some(filter_expr) => Filter::try_new(filter_expr, Arc::new(left)) .map(|f| Some(LogicalPlan::Filter(f))), _ => Ok(Some(left)), @@ -150,7 +150,7 @@ impl OptimizerRule for EliminateCrossJoin { /// Returns a boolean indicating whether the flattening was successful. 
fn try_flatten_join_inputs( plan: &LogicalPlan, - possible_join_keys: &mut Vec<(Expr, Expr)>, + possible_join_keys: &mut JoinKeySet, all_inputs: &mut Vec, ) -> Result { let children = match plan { @@ -160,7 +160,7 @@ fn try_flatten_join_inputs( // issue: https://github.com/apache/datafusion/issues/4844 return Ok(false); } - possible_join_keys.extend(join.on.clone()); + possible_join_keys.insert_all(join.on.iter()); vec![&join.left, &join.right] } LogicalPlan::CrossJoin(join) => { @@ -204,8 +204,8 @@ fn try_flatten_join_inputs( fn find_inner_join( left_input: &LogicalPlan, rights: &mut Vec, - possible_join_keys: &[(Expr, Expr)], - all_join_keys: &mut HashSet<(Expr, Expr)>, + possible_join_keys: &JoinKeySet, + all_join_keys: &mut JoinKeySet, ) -> Result { for (i, right_input) in rights.iter().enumerate() { let mut join_keys = vec![]; @@ -228,7 +228,7 @@ fn find_inner_join( // Found one or more matching join keys if !join_keys.is_empty() { - all_join_keys.extend(join_keys.clone()); + all_join_keys.insert_all(join_keys.iter()); let right_input = rights.remove(i); let join_schema = Arc::new(build_join_schema( left_input.schema(), @@ -265,90 +265,67 @@ fn find_inner_join( })) } -fn intersect( - accum: &mut Vec<(Expr, Expr)>, - vec1: &[(Expr, Expr)], - vec2: &[(Expr, Expr)], -) { - if !(vec1.is_empty() || vec2.is_empty()) { - for x1 in vec1.iter() { - for x2 in vec2.iter() { - if x1.0 == x2.0 && x1.1 == x2.1 || x1.1 == x2.0 && x1.0 == x2.1 { - accum.push((x1.0.clone(), x1.1.clone())); - } - } - } - } -} - /// Extract join keys from a WHERE clause -fn extract_possible_join_keys(expr: &Expr, accum: &mut Vec<(Expr, Expr)>) -> Result<()> { +fn extract_possible_join_keys(expr: &Expr, join_keys: &mut JoinKeySet) { if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr { match op { Operator::Eq => { - // Ensure that we don't add the same Join keys multiple times - if !(accum.contains(&(*left.clone(), *right.clone())) - || accum.contains(&(*right.clone(), *left.clone()))) 
- { - accum.push((*left.clone(), *right.clone())); - } + // insert handles ensuring we don't add the same Join keys multiple times + join_keys.insert(left, right); } Operator::And => { - extract_possible_join_keys(left, accum)?; - extract_possible_join_keys(right, accum)? + extract_possible_join_keys(left, join_keys); + extract_possible_join_keys(right, join_keys) } // Fix for issue#78 join predicates from inside of OR expr also pulled up properly. Operator::Or => { - let mut left_join_keys = vec![]; - let mut right_join_keys = vec![]; + let mut left_join_keys = JoinKeySet::new(); + let mut right_join_keys = JoinKeySet::new(); - extract_possible_join_keys(left, &mut left_join_keys)?; - extract_possible_join_keys(right, &mut right_join_keys)?; + extract_possible_join_keys(left, &mut left_join_keys); + extract_possible_join_keys(right, &mut right_join_keys); - intersect(accum, &left_join_keys, &right_join_keys) + join_keys.insert_intersection(left_join_keys, right_join_keys) } _ => (), }; } - Ok(()) } /// Remove join expressions from a filter expression -/// Returns Some() when there are few remaining predicates in filter_expr -/// Returns None otherwise -fn remove_join_expressions( - expr: &Expr, - join_keys: &HashSet<(Expr, Expr)>, -) -> Result> { +/// +/// # Returns +/// * `Some()` when there are few remaining predicates in filter_expr +/// * `None` otherwise +fn remove_join_expressions(expr: Expr, join_keys: &JoinKeySet) -> Option { match expr { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - match op { - Operator::Eq => { - if join_keys.contains(&(*left.clone(), *right.clone())) - || join_keys.contains(&(*right.clone(), *left.clone())) - { - Ok(None) - } else { - Ok(Some(expr.clone())) - } - } - // Fix for issue#78 join predicates from inside of OR expr also pulled up properly. 
- Operator::And | Operator::Or => { - let l = remove_join_expressions(left, join_keys)?; - let r = remove_join_expressions(right, join_keys)?; - match (l, r) { - (Some(ll), Some(rr)) => Ok(Some(Expr::BinaryExpr( - BinaryExpr::new(Box::new(ll), *op, Box::new(rr)), - ))), - (Some(ll), _) => Ok(Some(ll)), - (_, Some(rr)) => Ok(Some(rr)), - _ => Ok(None), - } - } - _ => Ok(Some(expr.clone())), + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Eq, + right, + }) if join_keys.contains(&left, &right) => { + // was a join key, so remove it + None + } + // Fix for issue#78 join predicates from inside of OR expr also pulled up properly. + Expr::BinaryExpr(BinaryExpr { left, op, right }) + if matches!(op, Operator::And | Operator::Or) => + { + let l = remove_join_expressions(*left, join_keys); + let r = remove_join_expressions(*right, join_keys); + match (l, r) { + (Some(ll), Some(rr)) => Some(Expr::BinaryExpr(BinaryExpr::new( + Box::new(ll), + op, + Box::new(rr), + ))), + (Some(ll), _) => Some(ll), + (_, Some(rr)) => Some(rr), + _ => None, } } - _ => Ok(Some(expr.clone())), + + _ => Some(expr), } } diff --git a/datafusion/optimizer/src/join_key_set.rs b/datafusion/optimizer/src/join_key_set.rs new file mode 100644 index 0000000000000..c47afa012c174 --- /dev/null +++ b/datafusion/optimizer/src/join_key_set.rs @@ -0,0 +1,240 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [JoinKeySet] for tracking the set of join keys in a plan. + +use datafusion_expr::Expr; +use indexmap::{Equivalent, IndexSet}; + +/// Tracks a set of equality Join keys +/// +/// A join key is an expression that is used to join two tables via an equality +/// predicate such as `a.x = b.y` +/// +/// This struct models `a.x + 5 = b.y AND a.z = b.z` as two join keys +/// 1. `(a.x + 5, b.y)` +/// 2. `(a.z, b.z)` +/// +/// # Important properties: +/// +/// 1. Retains insert order +/// 2. Can quickly look up if a pair of expressions are in the set. +#[derive(Debug)] +pub struct JoinKeySet { + inner: IndexSet<(Expr, Expr)>, +} + +impl JoinKeySet { + /// Create a new empty set + pub fn new() -> Self { + Self { + inner: IndexSet::new(), + } + } + + /// Return true if the set contains a join pair + /// where left = right or right = left + pub fn contains(&self, left: &Expr, right: &Expr) -> bool { + self.inner.contains(&ExprPair::new(left, right)) + || self.inner.contains(&ExprPair::new(right, left)) + } + + /// Insert the join key `(left = right)` into the set if join pair `(right = + /// left)` is not already in the set + /// + /// returns true if the pair was inserted + pub fn insert(&mut self, left: &Expr, right: &Expr) -> bool { + if self.contains(left, right) { + false + } else { + self.inner.insert((left.clone(), right.clone())); + true + } + } + + /// Inserts potentially many join keys into the set, copying only when necessary + /// + /// returns true if any of the pairs were inserted + pub fn insert_all<'a>( + &mut self, + iter: impl 
Iterator<Item = &'a (Expr, Expr)>, + ) -> bool { + let mut inserted = false; + for (left, right) in iter { + inserted |= self.insert(left, right); + } + inserted + } + + /// Inserts any join keys that are common to both `s1` and `s2` into self + pub fn insert_intersection(&mut self, s1: JoinKeySet, s2: JoinKeySet) { + // note can't use inner.intersection as we need to consider both (l, r) + // and (r, l) in equality + for (left, right) in s1.inner.iter() { + if s2.contains(left, right) { + self.insert(left, right); + } + } + } + + /// returns true if this set is empty + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + /// Return the length of this set + #[cfg(test)] + pub fn len(&self) -> usize { + self.inner.len() + } + + /// Return an iterator over the join keys in this set + pub fn iter(&self) -> impl Iterator<Item = (&Expr, &Expr)> { + self.inner.iter().map(|(l, r)| (l, r)) + } +} + +/// Custom comparison operation to avoid copying owned values +/// +/// This behaves like a `(Expr, Expr)` tuple for hashing and comparison, but +/// avoids copying the values simply to compare them.
+ +#[derive(Debug, Eq, PartialEq, Hash)] +struct ExprPair<'a>(&'a Expr, &'a Expr); + +impl<'a> ExprPair<'a> { + fn new(left: &'a Expr, right: &'a Expr) -> Self { + Self(left, right) + } +} + +impl<'a> Equivalent<(Expr, Expr)> for ExprPair<'a> { + fn equivalent(&self, other: &(Expr, Expr)) -> bool { + self.0 == &other.0 && self.1 == &other.1 + } +} + +#[cfg(test)] +mod test { + use crate::join_key_set::JoinKeySet; + use datafusion_expr::{col, Expr}; + + #[test] + fn test_insert() { + let mut set = JoinKeySet::new(); + // new sets should be empty + assert!(set.is_empty()); + + // insert (a = b) + assert!(set.insert(&col("a"), &col("b"))); + assert!(!set.is_empty()); + + // insert (a=b) again returns false + assert!(!set.insert(&col("a"), &col("b"))); + assert_eq!(set.len(), 1); + + // insert (b = a) , should be considered equivalent + assert!(!set.insert(&col("b"), &col("a"))); + assert_eq!(set.len(), 1); + + // insert (a = c) should be considered different + assert!(set.insert(&col("a"), &col("c"))); + assert_eq!(set.len(), 2); + } + + #[test] + fn test_contains() { + let mut set = JoinKeySet::new(); + assert!(set.insert(&col("a"), &col("b"))); + assert!(set.contains(&col("a"), &col("b"))); + assert!(set.contains(&col("b"), &col("a"))); + assert!(!set.contains(&col("a"), &col("c"))); + + assert!(set.insert(&col("a"), &col("c"))); + assert!(set.contains(&col("a"), &col("c"))); + assert!(set.contains(&col("c"), &col("a"))); + } + + #[test] + fn test_iterator() { + // put in c = a and + let mut set = JoinKeySet::new(); + // put in c = a , b = c, and a = c and expect to get only the first 2 + set.insert(&col("c"), &col("a")); + set.insert(&col("b"), &col("c")); + set.insert(&col("a"), &col("c")); + assert_contents(&set, vec![(&col("c"), &col("a")), (&col("b"), &col("c"))]); + } + + #[test] + fn test_insert_intersection() { + // a = b, b = c, c = d + let mut set1 = JoinKeySet::new(); + set1.insert(&col("a"), &col("b")); + set1.insert(&col("b"), &col("c")); + 
set1.insert(&col("c"), &col("d")); + + // a = a, b = b, b = c, d = c + // should only intersect on b = c and c = d + let mut set2 = JoinKeySet::new(); + set2.insert(&col("a"), &col("a")); + set2.insert(&col("b"), &col("b")); + set2.insert(&col("b"), &col("c")); + set2.insert(&col("d"), &col("c")); + + let mut set = JoinKeySet::new(); + // put something in there already + set.insert(&col("x"), &col("y")); + set.insert_intersection(set1, set2); + + assert_contents( + &set, + vec![ + (&col("x"), &col("y")), + (&col("b"), &col("c")), + (&col("c"), &col("d")), + ], + ); + } + + fn assert_contents(set: &JoinKeySet, expected: Vec<(&Expr, &Expr)>) { + let contents: Vec<_> = set.iter().collect(); + assert_eq!(contents, expected); + } + + #[test] + fn test_insert_many() { + let mut set = JoinKeySet::new(); + + // insert (a=b), (b=c), (b=a) + set.insert_all( + vec![ + &(col("a"), col("b")), + &(col("b"), col("c")), + &(col("b"), col("a")), + ] + .into_iter(), + ); + assert_eq!(set.len(), 2); + assert!(set.contains(&col("a"), &col("b"))); + assert!(set.contains(&col("b"), &col("c"))); + assert!(set.contains(&col("b"), &col("a"))); + + // should not contain (a=c) + assert!(!set.contains(&col("a"), &col("c"))); + } +} diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 9176d67c1d18a..793c87f8bc0c7 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -62,6 +62,7 @@ pub use analyzer::{Analyzer, AnalyzerRule}; pub use optimizer::{Optimizer, OptimizerConfig, OptimizerContext, OptimizerRule}; pub use utils::optimize_children; +pub(crate) mod join_key_set; mod plan_signature; #[cfg(test)]