From 742e3c5b3f5d4d961d50cc77033ab7774b90c56f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=BC=A0=E6=9E=97=E4=BC=9F?=
Date: Wed, 8 May 2024 00:49:57 +0800
Subject: [PATCH] Remove ScalarFunctionDefinition (#10325)

* Remove ScalarFunctionDefinition

* Fix test

* rename func_def to func

---------

Co-authored-by: Andrew Lamb
---
 .../core/src/datasource/listing/helpers.rs    | 18 +++----
 .../physical_optimizer/projection_pushdown.rs | 19 ++-----
 datafusion/expr/src/expr.rs                   | 54 ++++---------------
 datafusion/expr/src/expr_schema.rs            | 16 +++---
 datafusion/expr/src/lib.rs                    |  2 +-
 datafusion/expr/src/tree_node.rs              | 12 ++---
 datafusion/functions-array/src/rewrite.rs     |  8 +--
 datafusion/functions/src/math/log.rs          | 19 ++-----
 datafusion/functions/src/math/power.rs        | 17 ++----
 datafusion/functions/src/string/concat.rs     |  4 +-
 datafusion/functions/src/string/concat_ws.rs  |  4 +-
 .../optimizer/src/analyzer/type_coercion.rs   | 24 ++++-----
 .../optimizer/src/optimize_projections/mod.rs |  4 +-
 datafusion/optimizer/src/push_down_filter.rs  |  7 +--
 .../simplify_expressions/expr_simplifier.rs   | 32 +++++------
 datafusion/physical-expr/src/planner.rs       | 22 ++++----
 .../physical-expr/src/scalar_function.rs      | 36 +++++--------
 datafusion/proto/src/logical_plan/to_proto.rs | 34 +++++-------
 .../proto/src/physical_plan/from_proto.rs     |  3 +-
 .../proto/src/physical_plan/to_proto.rs       |  7 +--
 .../tests/cases/roundtrip_physical_plan.rs    |  9 ++--
 datafusion/sql/src/unparser/expr.rs           |  4 +-
 22 files changed, 124 insertions(+), 231 deletions(-)

diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs
index 09d9aa881133..0cffa0513171 100644
--- a/datafusion/core/src/datasource/listing/helpers.rs
+++ b/datafusion/core/src/datasource/listing/helpers.rs
@@ -38,7 +38,7 @@ use log::{debug, trace};
 
 use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
 use datafusion_common::{Column, DFSchema, DataFusionError};
-use datafusion_expr::{Expr, ScalarFunctionDefinition, Volatility};
+use datafusion_expr::{Expr, Volatility};
 use datafusion_physical_expr::create_physical_expr;
 use object_store::path::Path;
 use object_store::{ObjectMeta, ObjectStore};
@@ -89,16 +89,12 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool {
             | Expr::Case { .. } => Ok(TreeNodeRecursion::Continue),

             Expr::ScalarFunction(scalar_function) => {
-                match &scalar_function.func_def {
-                    ScalarFunctionDefinition::UDF(fun) => {
-                        match fun.signature().volatility {
-                            Volatility::Immutable => Ok(TreeNodeRecursion::Continue),
-                            // TODO: Stable functions could be `applicable`, but that would require access to the context
-                            Volatility::Stable | Volatility::Volatile => {
-                                is_applicable = false;
-                                Ok(TreeNodeRecursion::Stop)
-                            }
-                        }
+                match scalar_function.func.signature().volatility {
+                    Volatility::Immutable => Ok(TreeNodeRecursion::Continue),
+                    // TODO: Stable functions could be `applicable`, but that would require access to the context
+                    Volatility::Stable | Volatility::Volatile => {
+                        is_applicable = false;
+                        Ok(TreeNodeRecursion::Stop)
                     }
                 }
             }
diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs
index 160dd3a1c4ee..0190f35cc97b 100644
--- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs
+++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs
@@ -1301,8 +1301,7 @@ mod tests {
     use datafusion_execution::object_store::ObjectStoreUrl;
     use datafusion_execution::{SendableRecordBatchStream, TaskContext};
     use datafusion_expr::{
-        ColumnarValue, Operator, ScalarFunctionDefinition, ScalarUDF, ScalarUDFImpl,
-        Signature, Volatility,
+        ColumnarValue, Operator, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
     };
     use datafusion_physical_expr::expressions::{
         BinaryExpr, CaseExpr, CastExpr, NegativeExpr,
@@ -1363,9 +1362,7 @@ mod tests {
             Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))),
             Arc::new(ScalarFunctionExpr::new(
                 "scalar_expr",
-                ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl(
-                    DummyUDF::new(),
-                ))),
+                Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())),
                 vec![
                     Arc::new(BinaryExpr::new(
                         Arc::new(Column::new("b", 1)),
@@ -1431,9 +1428,7 @@ mod tests {
             Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 5)))),
             Arc::new(ScalarFunctionExpr::new(
                 "scalar_expr",
-                ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl(
-                    DummyUDF::new(),
-                ))),
+                Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())),
                 vec![
                     Arc::new(BinaryExpr::new(
                         Arc::new(Column::new("b", 1)),
@@ -1502,9 +1497,7 @@ mod tests {
             Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))),
             Arc::new(ScalarFunctionExpr::new(
                 "scalar_expr",
-                ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl(
-                    DummyUDF::new(),
-                ))),
+                Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())),
                 vec![
                     Arc::new(BinaryExpr::new(
                         Arc::new(Column::new("b", 1)),
@@ -1570,9 +1563,7 @@ mod tests {
             Arc::new(NegativeExpr::new(Arc::new(Column::new("f_new", 5)))),
             Arc::new(ScalarFunctionExpr::new(
                 "scalar_expr",
-                ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl(
-                    DummyUDF::new(),
-                ))),
+                Arc::new(ScalarUDF::new_from_impl(DummyUDF::new())),
                 vec![
                     Arc::new(BinaryExpr::new(
                         Arc::new(Column::new("b_new", 1)),
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index c154cd999a2b..9789dd345faa 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -26,11 +26,11 @@ use std::sync::Arc;
 use crate::expr_fn::binary_expr;
 use crate::logical_plan::Subquery;
 use crate::utils::expr_to_columns;
-use crate::window_frame;
 use crate::{
     aggregate_function, built_in_window_function, udaf, ExprSchemable, Operator,
     Signature,
 };
+use crate::{window_frame, Volatility};
 
 use arrow::datatypes::{DataType, FieldRef};
 use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
@@ -399,18 +399,11 @@ impl Between {
     }
 }
 
-#[derive(Debug, Clone, PartialEq, Eq, Hash)]
-/// Defines which implementation of a function for DataFusion to call.
-pub enum ScalarFunctionDefinition {
-    /// Resolved to a user defined function
-    UDF(Arc<crate::ScalarUDF>),
-}
-
 /// ScalarFunction expression invokes a built-in scalar function
 #[derive(Clone, PartialEq, Eq, Hash, Debug)]
 pub struct ScalarFunction {
     /// The function
-    pub func_def: ScalarFunctionDefinition,
+    pub func: Arc<crate::ScalarUDF>,
     /// List of expressions to feed to the functions as arguments
     pub args: Vec<Expr>,
 }
@@ -418,41 +411,14 @@ pub struct ScalarFunction {
 
 impl ScalarFunction {
     // return the Function's name
     pub fn name(&self) -> &str {
-        self.func_def.name()
-    }
-}
-
-impl ScalarFunctionDefinition {
-    /// Function's name for display
-    pub fn name(&self) -> &str {
-        match self {
-            ScalarFunctionDefinition::UDF(udf) => udf.name(),
-        }
-    }
-
-    /// Whether this function is volatile, i.e. whether it can return different results
-    /// when evaluated multiple times with the same input.
-    pub fn is_volatile(&self) -> Result<bool> {
-        match self {
-            ScalarFunctionDefinition::UDF(udf) => {
-                Ok(udf.signature().volatility == crate::Volatility::Volatile)
-            }
-        }
+        self.func.name()
     }
 }
 
 impl ScalarFunction {
     /// Create a new ScalarFunction expression with a user-defined function (UDF)
     pub fn new_udf(udf: Arc<crate::ScalarUDF>, args: Vec<Expr>) -> Self {
-        Self {
-            func_def: ScalarFunctionDefinition::UDF(udf),
-            args,
-        }
-    }
-
-    /// Create a new ScalarFunction expression with a user-defined function (UDF)
-    pub fn new_func_def(func_def: ScalarFunctionDefinition, args: Vec<Expr>) -> Self {
-        Self { func_def, args }
+        Self { func: udf, args }
     }
 }
@@ -1299,7 +1265,7 @@ impl Expr {
     /// results when evaluated multiple times with the same input.
     pub fn is_volatile(&self) -> Result<bool> {
         self.exists(|expr| {
-            Ok(matches!(expr, Expr::ScalarFunction(func) if func.func_def.is_volatile()?))
+            Ok(matches!(expr, Expr::ScalarFunction(func) if func.func.signature().volatility == Volatility::Volatile ))
         })
     }
 
@@ -1334,9 +1300,7 @@ impl Expr {
     /// and thus any side effects (like divide by zero) may not be encountered
     pub fn short_circuits(&self) -> bool {
         match self {
-            Expr::ScalarFunction(ScalarFunction { func_def, .. }) => {
-                matches!(func_def, ScalarFunctionDefinition::UDF(fun) if fun.short_circuits())
-            }
+            Expr::ScalarFunction(ScalarFunction { func, .. }) => func.short_circuits(),
             Expr::BinaryExpr(BinaryExpr { op, .. }) => {
                 matches!(op, Operator::And | Operator::Or)
             }
@@ -2071,7 +2035,7 @@ mod test {
     }
 
     #[test]
-    fn test_is_volatile_scalar_func_definition() {
+    fn test_is_volatile_scalar_func() {
         // UDF
         #[derive(Debug)]
         struct TestScalarUDF {
@@ -2100,7 +2064,7 @@ mod test {
         let udf = Arc::new(ScalarUDF::from(TestScalarUDF {
            signature: Signature::uniform(1, vec![DataType::Float32], Volatility::Stable),
         }));
-        assert!(!ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap());
+        assert_ne!(udf.signature().volatility, Volatility::Volatile);
 
         let udf = Arc::new(ScalarUDF::from(TestScalarUDF {
             signature: Signature::uniform(
@@ -2109,7 +2073,7 @@ mod test {
                 Volatility::Volatile,
             ),
         }));
-        assert!(ScalarFunctionDefinition::UDF(udf).is_volatile().unwrap());
+        assert_eq!(udf.signature().volatility, Volatility::Volatile);
     }
 
     use super::*;
diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs
index f93f08574906..4aca52d67c4f 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -19,7 +19,7 @@ use super::{Between, Expr, Like};
 use crate::expr::{
     AggregateFunction, AggregateFunctionDefinition, Alias, BinaryExpr, Cast,
     GetFieldAccess, GetIndexedField, InList, InSubquery, Placeholder, ScalarFunction,
-    ScalarFunctionDefinition, Sort, TryCast, Unnest, WindowFunction,
+    Sort, TryCast, Unnest, WindowFunction,
 };
 use crate::field_util::GetFieldAccessSchema;
 use crate::type_coercion::binary::get_result_type;
@@ -133,20 +133,18 @@ impl ExprSchemable for Expr {
                     }
                 }
             }
-            Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
+            Expr::ScalarFunction(ScalarFunction { func, args }) => {
                 let arg_data_types = args
                     .iter()
                     .map(|e| e.get_type(schema))
                     .collect::<Result<Vec<_>>>()?;
-                match func_def {
-                    ScalarFunctionDefinition::UDF(fun) => {
                 // verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
-                data_types(&arg_data_types, fun.signature()).map_err(|_| {
+                data_types(&arg_data_types, func.signature()).map_err(|_| {
                     plan_datafusion_err!(
                         "{}",
                         utils::generate_signature_error_msg(
-                            fun.name(),
-                            fun.signature().clone(),
+                            func.name(),
+                            func.signature().clone(),
                             &arg_data_types,
                         )
                     )
@@ -154,9 +152,7 @@ impl ExprSchemable for Expr {
 
                 // perform additional function arguments validation (due to limited
                 // expressiveness of `TypeSignature`), then infer return type
-                Ok(fun.return_type_from_exprs(args, schema, &arg_data_types)?)
-                    }
-                }
+                Ok(func.return_type_from_exprs(args, schema, &arg_data_types)?)
             }
             Expr::WindowFunction(WindowFunction { fun, args, .. }) => {
                 let data_types = args
diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs
index de4f31029293..e2b68388abb9 100644
--- a/datafusion/expr/src/lib.rs
+++ b/datafusion/expr/src/lib.rs
@@ -63,7 +63,7 @@ pub use built_in_window_function::BuiltInWindowFunction;
 pub use columnar_value::ColumnarValue;
 pub use expr::{
     Between, BinaryExpr, Case, Cast, Expr, GetFieldAccess, GetIndexedField, GroupingSet,
-    Like, ScalarFunctionDefinition, TryCast, WindowFunctionDefinition,
+    Like, TryCast, WindowFunctionDefinition,
 };
 pub use expr_fn::*;
 pub use expr_schema::ExprSchemable;
diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs
index ae3ca9afc4f5..710164eca3d0 100644
--- a/datafusion/expr/src/tree_node.rs
+++ b/datafusion/expr/src/tree_node.rs
@@ -20,7 +20,7 @@
 use crate::expr::{
     AggregateFunction, AggregateFunctionDefinition, Alias, Between, BinaryExpr, Case,
     Cast, GetIndexedField, GroupingSet, InList, InSubquery, Like, Placeholder,
-    ScalarFunction, ScalarFunctionDefinition, Sort, TryCast, Unnest, WindowFunction,
+    ScalarFunction, Sort, TryCast, Unnest, WindowFunction,
 };
 use crate::{Expr, GetFieldAccess};
 
@@ -281,11 +281,11 @@ impl TreeNode for Expr {
                 nulls_first,
             }) => transform_box(expr, &mut f)?
                 .update_data(|be| Expr::Sort(Sort::new(be, asc, nulls_first))),
-            Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
-                transform_vec(args, &mut f)?.map_data(|new_args| match func_def {
-                    ScalarFunctionDefinition::UDF(fun) => {
-                        Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, new_args)))
-                    }
+            Expr::ScalarFunction(ScalarFunction { func, args }) => {
+                transform_vec(args, &mut f)?.map_data(|new_args| {
+                    Ok(Expr::ScalarFunction(ScalarFunction::new_udf(
+                        func, new_args,
+                    )))
                 })?
             }
             Expr::WindowFunction(WindowFunction {
diff --git a/datafusion/functions-array/src/rewrite.rs b/datafusion/functions-array/src/rewrite.rs
index 32d15b5563a5..416e79cbc079 100644
--- a/datafusion/functions-array/src/rewrite.rs
+++ b/datafusion/functions-array/src/rewrite.rs
@@ -182,20 +182,20 @@ impl FunctionRewrite for ArrayFunctionRewriter {
 /// Returns true if expr is a function call to the specified named function.
 /// Returns false otherwise.
 fn is_func(expr: &Expr, func_name: &str) -> bool {
-    let Expr::ScalarFunction(ScalarFunction { func_def, args: _ }) = expr else {
+    let Expr::ScalarFunction(ScalarFunction { func, args: _ }) = expr else {
         return false;
     };
 
-    func_def.name() == func_name
+    func.name() == func_name
 }
 
 /// Returns true if expr is a function call with one of the specified names
 fn is_one_of_func(expr: &Expr, func_names: &[&str]) -> bool {
-    let Expr::ScalarFunction(ScalarFunction { func_def, args: _ }) = expr else {
+    let Expr::ScalarFunction(ScalarFunction { func, args: _ }) = expr else {
         return false;
     };
 
-    func_names.contains(&func_def.name())
+    func_names.contains(&func.name())
 }
 
 /// returns Some(col) if this is Expr::Column
diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs
index f451321ea120..e6c698ad1a80 100644
--- a/datafusion/functions/src/math/log.rs
+++ b/datafusion/functions/src/math/log.rs
@@ -24,9 +24,7 @@ use datafusion_common::{
 };
 use datafusion_expr::expr::ScalarFunction;
 use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
-use datafusion_expr::{
-    lit, ColumnarValue, Expr, FuncMonotonicity, ScalarFunctionDefinition,
-};
+use datafusion_expr::{lit, ColumnarValue, Expr, FuncMonotonicity, ScalarUDF};
 
 use arrow::array::{ArrayRef, Float32Array, Float64Array};
 use datafusion_expr::TypeSignature::*;
@@ -178,8 +176,8 @@ impl ScalarUDFImpl for LogFunc {
                     &info.get_data_type(&base)?,
                 )?)))
             }
-            Expr::ScalarFunction(ScalarFunction { func_def, mut args })
-                if is_pow(&func_def) && args.len() == 2 && base == args[0] =>
+            Expr::ScalarFunction(ScalarFunction { func, mut args })
+                if is_pow(&func) && args.len() == 2 && base == args[0] =>
             {
                 let b = args.pop().unwrap(); // length checked above
                 Ok(ExprSimplifyResult::Simplified(b))
@@ -207,15 +205,8 @@ impl ScalarUDFImpl for LogFunc {
 }
 
 /// Returns true if the function is `PowerFunc`
-fn is_pow(func_def: &ScalarFunctionDefinition) -> bool {
-    match func_def {
-        ScalarFunctionDefinition::UDF(fun) => fun
-            .as_ref()
-            .inner()
-            .as_any()
-            .downcast_ref::<PowerFunc>()
-            .is_some(),
-    }
+fn is_pow(func: &ScalarUDF) -> bool {
+    func.inner().as_any().downcast_ref::<PowerFunc>().is_some()
 }
 
 #[cfg(test)]
diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs
index 8cc6b4c02aeb..7677e8b2af95 100644
--- a/datafusion/functions/src/math/power.rs
+++ b/datafusion/functions/src/math/power.rs
@@ -23,7 +23,7 @@ use datafusion_common::{
 };
 use datafusion_expr::expr::ScalarFunction;
 use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
-use datafusion_expr::{ColumnarValue, Expr, ScalarFunctionDefinition};
+use datafusion_expr::{ColumnarValue, Expr, ScalarUDF};
 
 use arrow::array::{ArrayRef, Float64Array, Int64Array};
 use datafusion_expr::TypeSignature::*;
@@ -140,8 +140,8 @@ impl ScalarUDFImpl for PowerFunc {
             Expr::Literal(value) if value == ScalarValue::new_one(&exponent_type)? => {
                 Ok(ExprSimplifyResult::Simplified(base))
             }
-            Expr::ScalarFunction(ScalarFunction { func_def, mut args })
-                if is_log(&func_def) && args.len() == 2 && base == args[0] =>
+            Expr::ScalarFunction(ScalarFunction { func, mut args })
+                if is_log(&func) && args.len() == 2 && base == args[0] =>
             {
                 let b = args.pop().unwrap(); // length checked above
                 Ok(ExprSimplifyResult::Simplified(b))
@@ -152,15 +152,8 @@ impl ScalarUDFImpl for PowerFunc {
 }
 
 /// Return true if this function call is a call to `Log`
-fn is_log(func_def: &ScalarFunctionDefinition) -> bool {
-    match func_def {
-        ScalarFunctionDefinition::UDF(fun) => fun
-            .as_ref()
-            .inner()
-            .as_any()
-            .downcast_ref::<LogFunc>()
-            .is_some(),
-    }
+fn is_log(func: &ScalarUDF) -> bool {
+    func.inner().as_any().downcast_ref::<LogFunc>().is_some()
 }
 
 #[cfg(test)]
diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs
index 55b7c2f22249..6d15e2206721 100644
--- a/datafusion/functions/src/string/concat.rs
+++ b/datafusion/functions/src/string/concat.rs
@@ -25,7 +25,7 @@ use datafusion_common::cast::as_string_array;
 use datafusion_common::{internal_err, Result, ScalarValue};
 use datafusion_expr::expr::ScalarFunction;
 use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
-use datafusion_expr::{lit, ColumnarValue, Expr, ScalarFunctionDefinition, Volatility};
+use datafusion_expr::{lit, ColumnarValue, Expr, Volatility};
 use datafusion_expr::{ScalarUDFImpl, Signature};
 
 use crate::string::common::*;
@@ -182,7 +182,7 @@ pub fn simplify_concat(args: Vec<Expr>) -> Result<ExprSimplifyResult> {
     if !args.eq(&new_args) {
         Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
             ScalarFunction {
-                func_def: ScalarFunctionDefinition::UDF(concat()),
+                func: concat(),
                 args: new_args,
             },
         )))
diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs
index 1d27712b2c93..4d05f4e707b1 100644
--- a/datafusion/functions/src/string/concat_ws.rs
+++ b/datafusion/functions/src/string/concat_ws.rs
@@ -26,7 +26,7 @@ use datafusion_common::cast::as_string_array;
 use datafusion_common::{exec_err, internal_err, Result, ScalarValue};
 use datafusion_expr::expr::ScalarFunction;
 use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
-use datafusion_expr::{lit, ColumnarValue, Expr, ScalarFunctionDefinition, Volatility};
+use datafusion_expr::{lit, ColumnarValue, Expr, Volatility};
 use datafusion_expr::{ScalarUDFImpl, Signature};
 
 use crate::string::common::*;
@@ -266,7 +266,7 @@ fn simplify_concat_ws(delimiter: &Expr, args: &[Expr]) -> Result<ExprSimplifyResult> {
                 Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
                     ScalarFunction {
-                        func_def: ScalarFunctionDefinition::UDF(concat_ws()),
+                        func: concat_ws(),
                         args: new_args,
                     },
                 )))
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ ... @@ impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> {
                 let case = coerce_case_expression(case, self.schema)?;
                 Ok(Transformed::yes(Expr::Case(case)))
             }
-            Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def {
-                ScalarFunctionDefinition::UDF(fun) => {
-                    let new_expr = coerce_arguments_for_signature(
-                        args,
-                        self.schema,
-                        fun.signature(),
-                    )?;
-                    let new_expr = coerce_arguments_for_fun(new_expr, self.schema, &fun)?;
-                    Ok(Transformed::yes(Expr::ScalarFunction(
-                        ScalarFunction::new_udf(fun, new_expr),
-                    )))
-                }
-            },
+            Expr::ScalarFunction(ScalarFunction { func, args }) => {
+                let new_expr =
+                    coerce_arguments_for_signature(args, self.schema, func.signature())?;
+                let new_expr = coerce_arguments_for_fun(new_expr, self.schema, &func)?;
+                Ok(Transformed::yes(Expr::ScalarFunction(
+                    ScalarFunction::new_udf(func, new_expr),
+                )))
+            }
             Expr::AggregateFunction(expr::AggregateFunction {
                 func_def,
                 args,
diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs
index 0f2aaa6cbcb3..aa2d00537940 100644
--- a/datafusion/optimizer/src/optimize_projections/mod.rs
+++ b/datafusion/optimizer/src/optimize_projections/mod.rs
@@ -555,8 +555,8 @@ fn rewrite_expr(expr: &Expr, input: &Projection) -> Result<Option<Expr>> {
                 .map(|expr| rewrite_expr(expr, input))
                 .collect::<Result<Option<Vec<_>>>>()?
                 .map(|new_args| {
-                    Expr::ScalarFunction(ScalarFunction::new_func_def(
-                        scalar_fn.func_def.clone(),
+                    Expr::ScalarFunction(ScalarFunction::new_udf(
+                        scalar_fn.func.clone(),
                         new_args,
                     ))
                 }));
diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs
index 2355ee604e2a..f58345237bee 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -35,7 +35,7 @@ use datafusion_expr::logical_plan::{
 use datafusion_expr::utils::{conjunction, split_conjunction, split_conjunction_owned};
 use datafusion_expr::{
     and, build_join_schema, or, BinaryExpr, Expr, Filter, LogicalPlanBuilder, Operator,
-    ScalarFunctionDefinition, TableProviderFilterPushDown,
+    TableProviderFilterPushDown,
 };
 
 use crate::optimizer::ApplyOrder;
@@ -228,10 +228,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
         | Expr::ScalarSubquery(_)
         | Expr::OuterReferenceColumn(_, _)
         | Expr::Unnest(_)
-        | Expr::ScalarFunction(datafusion_expr::expr::ScalarFunction {
-            func_def: ScalarFunctionDefinition::UDF(_),
-            ..
-        }) => {
+        | Expr::ScalarFunction(_) => {
             is_evaluate = false;
             Ok(TreeNodeRecursion::Stop)
         }
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index 0f711d6a2c6d..5122de4f09a7 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -35,8 +35,7 @@ use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue};
 use datafusion_expr::expr::{InList, InSubquery};
 use datafusion_expr::simplify::ExprSimplifyResult;
 use datafusion_expr::{
-    and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator,
-    ScalarFunctionDefinition, Volatility,
+    and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility,
 };
 use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval};
 use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps};
@@ -595,11 +594,9 @@ impl<'a> ConstEvaluator<'a> {
             | Expr::GroupingSet(_)
             | Expr::Wildcard { .. }
             | Expr::Placeholder(_) => false,
-            Expr::ScalarFunction(ScalarFunction { func_def, .. }) => match func_def {
-                ScalarFunctionDefinition::UDF(fun) => {
-                    Self::volatility_ok(fun.signature().volatility)
-                }
-            },
+            Expr::ScalarFunction(ScalarFunction { func, .. }) => {
+                Self::volatility_ok(func.signature().volatility)
+            }
             Expr::Literal(_)
             | Expr::Unnest(_)
             | Expr::BinaryExpr { .. }
@@ -1373,18 +1370,17 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
                 // Do a first pass at simplification
                 out_expr.rewrite(self)?
             }
-            Expr::ScalarFunction(ScalarFunction {
-                func_def: ScalarFunctionDefinition::UDF(udf),
-                args,
-            }) => match udf.simplify(args, info)? {
-                ExprSimplifyResult::Original(args) => {
-                    Transformed::no(Expr::ScalarFunction(ScalarFunction {
-                        func_def: ScalarFunctionDefinition::UDF(udf),
-                        args,
-                    }))
+            Expr::ScalarFunction(ScalarFunction { func: udf, args }) => {
+                match udf.simplify(args, info)? {
+                    ExprSimplifyResult::Original(args) => {
+                        Transformed::no(Expr::ScalarFunction(ScalarFunction {
+                            func: udf,
+                            args,
+                        }))
+                    }
+                    ExprSimplifyResult::Simplified(expr) => Transformed::yes(expr),
                 }
-                ExprSimplifyResult::Simplified(expr) => Transformed::yes(expr),
-            },
+            }
 
             //
             // Rules for Between
diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs
index 2621b817b2da..ab57a8e80056 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -28,7 +28,7 @@ use datafusion_expr::var_provider::is_system_variables;
 use datafusion_expr::var_provider::VarType;
 use datafusion_expr::{
     binary_expr, Between, BinaryExpr, Expr, GetFieldAccess, GetIndexedField, Like,
-    Operator, ScalarFunctionDefinition, TryCast,
+    Operator, TryCast,
 };
 
 use crate::scalar_function;
@@ -305,21 +305,17 @@ pub fn create_physical_expr(
             }
         },
 
-        Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
+        Expr::ScalarFunction(ScalarFunction { func, args }) => {
            let physical_args =
                create_physical_exprs(args, input_dfschema, execution_props)?;
 
-            match func_def {
-                ScalarFunctionDefinition::UDF(fun) => {
-                    scalar_function::create_physical_expr(
-                        fun.clone().as_ref(),
-                        &physical_args,
-                        input_schema,
-                        args,
-                        input_dfschema,
-                    )
-                }
-            }
+            scalar_function::create_physical_expr(
+                func.clone().as_ref(),
+                &physical_args,
+                input_schema,
+                args,
+                input_dfschema,
+            )
         }
         Expr::Between(Between {
             expr,
diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs
index 6b84b81e9fae..180f2a7946bd 100644
--- a/datafusion/physical-expr/src/scalar_function.rs
+++ b/datafusion/physical-expr/src/scalar_function.rs
@@ -40,10 +40,7 @@ use arrow::record_batch::RecordBatch;
 use datafusion_common::{internal_err, DFSchema, Result};
 use datafusion_expr::type_coercion::functions::data_types;
-use datafusion_expr::{
-    expr_vec_fmt, ColumnarValue, Expr, FuncMonotonicity, ScalarFunctionDefinition,
-    ScalarUDF,
-};
+use datafusion_expr::{expr_vec_fmt, ColumnarValue, Expr, FuncMonotonicity, ScalarUDF};
 
 use crate::physical_expr::{down_cast_any_ref, physical_exprs_equal};
 use crate::sort_properties::SortProperties;
@@ -51,7 +48,7 @@ use crate::PhysicalExpr;
 
 /// Physical expression of a scalar function
 pub struct ScalarFunctionExpr {
-    fun: ScalarFunctionDefinition,
+    fun: Arc<ScalarUDF>,
     name: String,
     args: Vec<Arc<dyn PhysicalExpr>>,
     return_type: DataType,
@@ -78,7 +75,7 @@ impl ScalarFunctionExpr {
     /// Create a new Scalar function
     pub fn new(
         name: &str,
-        fun: ScalarFunctionDefinition,
+        fun: Arc<ScalarUDF>,
         args: Vec<Arc<dyn PhysicalExpr>>,
         return_type: DataType,
         monotonicity: Option<FuncMonotonicity>,
@@ -93,7 +90,7 @@ impl ScalarFunctionExpr {
     }
 
     /// Get the scalar function implementation
-    pub fn fun(&self) -> &ScalarFunctionDefinition {
+    pub fn fun(&self) -> &ScalarUDF {
         &self.fun
     }
 
@@ -146,22 +143,18 @@ impl PhysicalExpr for ScalarFunctionExpr {
             .collect::<Result<Vec<_>>>()?;
 
         // evaluate the function
-        match self.fun {
-            ScalarFunctionDefinition::UDF(ref fun) => {
-                let output = match self.args.is_empty() {
-                    true => fun.invoke_no_args(batch.num_rows()),
-                    false => fun.invoke(&inputs),
-                }?;
-
-                if let ColumnarValue::Array(array) = &output {
-                    if array.len() != batch.num_rows() {
-                        return internal_err!("UDF returned a different number of rows than expected. Expected: {}, Got: {}",
+        let output = match self.args.is_empty() {
+            true => self.fun.invoke_no_args(batch.num_rows()),
+            false => self.fun.invoke(&inputs),
+        }?;
+
+        if let ColumnarValue::Array(array) = &output {
+            if array.len() != batch.num_rows() {
+                return internal_err!("UDF returned a different number of rows than expected. Expected: {}, Got: {}",
                             batch.num_rows(), array.len());
-                    }
-                }
-                Ok(output)
             }
         }
+        Ok(output)
     }
 
     fn children(&self) -> Vec<Arc<dyn PhysicalExpr>> {
@@ -233,10 +226,9 @@ pub fn create_physical_expr(
     let return_type =
         fun.return_type_from_exprs(args, input_dfschema, &input_expr_types)?;
 
-    let fun_def = ScalarFunctionDefinition::UDF(Arc::new(fun.clone()));
     Ok(Arc::new(ScalarFunctionExpr::new(
         fun.name(),
-        fun_def,
+        Arc::new(fun.clone()),
         input_phy_exprs.to_vec(),
         return_type,
         fun.monotonicity()?,
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs
index dcec2a3b8595..80acd12e4e60 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -36,8 +36,8 @@ use datafusion_common::{
 };
 use datafusion_expr::expr::{
     self, AggregateFunctionDefinition, Alias, Between, BinaryExpr, Cast, GetFieldAccess,
-    GetIndexedField, GroupingSet, InList, Like, Placeholder, ScalarFunction,
-    ScalarFunctionDefinition, Sort, Unnest,
+    GetIndexedField, GroupingSet, InList, Like, Placeholder, ScalarFunction, Sort,
+    Unnest,
 };
 use datafusion_expr::{
     logical_plan::PlanType, logical_plan::StringifiedPlan, AggregateFunction,
@@ -763,25 +763,19 @@ pub fn serialize_expr(
                 "Proto serialization error: Scalar Variable not supported".to_string(),
             ))
         }
-        Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
+        Expr::ScalarFunction(ScalarFunction { func, args }) => {
             let args = serialize_exprs(args, codec)?;
-            match func_def {
-                ScalarFunctionDefinition::UDF(fun) => {
-                    let mut buf = Vec::new();
-                    let _ = codec.try_encode_udf(fun.as_ref(), &mut buf);
-
-                    let fun_definition = if buf.is_empty() { None } else { Some(buf) };
-
-                    protobuf::LogicalExprNode {
-                        expr_type: Some(ExprType::ScalarUdfExpr(
-                            protobuf::ScalarUdfExprNode {
-                                fun_name: fun.name().to_string(),
-                                fun_definition,
-                                args,
-                            },
-                        )),
-                    }
-                }
+            let mut buf = Vec::new();
+            let _ = codec.try_encode_udf(func.as_ref(), &mut buf);
+
+            let fun_definition = if buf.is_empty() { None } else { Some(buf) };
+
+            protobuf::LogicalExprNode {
+                expr_type: Some(ExprType::ScalarUdfExpr(protobuf::ScalarUdfExprNode {
+                    fun_name: func.name().to_string(),
+                    fun_definition,
+                    args,
+                })),
             }
         }
         Expr::Not(expr) => {
diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs
index 33e632b0d942..4bd07fae497f 100644
--- a/datafusion/proto/src/physical_plan/from_proto.rs
+++ b/datafusion/proto/src/physical_plan/from_proto.rs
@@ -53,7 +53,6 @@ use datafusion_common::file_options::json_writer::JsonWriterOptions;
 use datafusion_common::parsers::CompressionTypeVariant;
 use datafusion_common::stats::Precision;
 use datafusion_common::{not_impl_err, DataFusionError, JoinSide, Result, ScalarValue};
-use datafusion_expr::ScalarFunctionDefinition;
 
 use crate::common::proto_error;
 use crate::convert_required;
@@ -342,7 +341,7 @@ pub fn parse_physical_expr(
                 Some(buf) => codec.try_decode_udf(&e.name, buf)?,
                 None => registry.udf(e.name.as_str())?,
             };
-            let scalar_fun_def = ScalarFunctionDefinition::UDF(udf.clone());
+            let scalar_fun_def = udf.clone();
 
             let args = parse_physical_exprs(&e.args, registry, input_schema, codec)?;
 
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs
index a0a0ee72054b..3bc71f5f4c90 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -56,7 +56,6 @@ use datafusion_common::{
     stats::Precision,
     DataFusionError, JoinSide, Result,
 };
-use datafusion_expr::ScalarFunctionDefinition;
 
 use crate::logical_plan::csv_writer_options_to_proto;
 use crate::protobuf::{
@@ -540,11 +539,7 @@ pub fn serialize_physical_expr(
         let args = serialize_physical_exprs(expr.args().to_vec(), codec)?;
 
         let mut buf = Vec::new();
-        match expr.fun() {
-            ScalarFunctionDefinition::UDF(udf) => {
-                codec.try_encode_udf(udf, &mut buf)?;
-            }
-        }
+        codec.try_encode_udf(expr.fun(), &mut buf)?;
 
         let fun_definition = if buf.is_empty() { None } else { Some(buf) };
         Ok(protobuf::PhysicalExprNode {
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index 5e446f93fea7..c2018352c7cf 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -75,9 +75,8 @@ use datafusion_common::parsers::CompressionTypeVariant;
 use datafusion_common::stats::Precision;
 use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result};
 use datafusion_expr::{
-    Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue,
-    ScalarFunctionDefinition, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF,
-    WindowFrame, WindowFrameBound,
+    Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, ScalarUDF,
+    ScalarUDFImpl, Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound,
 };
 use datafusion_proto::physical_plan::{
     AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec,
@@ -618,7 +617,7 @@ fn roundtrip_scalar_udf() -> Result<()> {
         scalar_fn.clone(),
     );
 
-    let fun_def = ScalarFunctionDefinition::UDF(Arc::new(udf.clone()));
+    let fun_def = Arc::new(udf.clone());
 
     let expr = ScalarFunctionExpr::new(
         "dummy",
@@ -750,7 +749,7 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> {
     let udf = ScalarUDF::from(MyRegexUdf::new(pattern.to_string()));
     let udf_expr = Arc::new(ScalarFunctionExpr::new(
         udf.name(),
-        ScalarFunctionDefinition::UDF(Arc::new(udf.clone())),
+        Arc::new(udf.clone()),
         vec![col("text", &schema)?],
         DataType::Int64,
         None,
diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs
index c619c62668cc..804fa6d306b4 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -108,8 +108,8 @@ impl Unparser<'_> {
                     negated: *negated,
                 })
             }
-            Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
-                let func_name = func_def.name();
+            Expr::ScalarFunction(ScalarFunction { func, args }) => {
+                let func_name = func.name();
                 let args = args
                     .iter()