diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index 393c27678f8c..60048dd877d8 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -32,6 +32,7 @@ use datafusion_expr::{Accumulator, AggregateUDF, AggregateUDFImpl, Signature}; /// a function `accumulator` that returns the `Accumulator` instance. /// /// To do so, we must implement the `AggregateUDFImpl` trait. +#[derive(Debug, Clone)] struct GeoMeanUdf { signature: Signature, } diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 8d78726f127b..53756416626c 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -25,8 +25,8 @@ use crate::function::PartitionEvaluatorFactory; use crate::{ aggregate_function, built_in_function, conditional_expressions::CaseBuilder, logical_plan::Subquery, AccumulatorFactoryFunction, AggregateUDF, - BuiltinScalarFunction, Expr, LogicalPlan, Operator, ReturnTypeFunction, - ScalarFunctionImplementation, ScalarUDF, Signature, Volatility, + BuiltinScalarFunction, Expr, LogicalPlan, Operator, ScalarFunctionImplementation, + ScalarUDF, Signature, Volatility, }; use crate::{AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl}; use arrow::datatypes::DataType; @@ -1059,6 +1059,16 @@ pub struct SimpleAggregateUDF { state_type: Vec, } +impl Debug for SimpleAggregateUDF { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("AggregateUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("fun", &"") + .finish() + } +} + impl SimpleAggregateUDF { /// Create a new `AggregateUDFImpl` from a name, input types, return type, state type and /// implementation. 
Implementing [`AggregateUDFImpl`] allows more flexibility diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 5dab4a474b30..667571a3224f 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -43,36 +43,29 @@ use std::sync::Arc; /// /// For more information, please see [the examples]. /// +/// 1. For simple (less performant) use cases, use [`create_udaf`] and [`simple_udaf.rs`]. +/// +/// 2. For advanced use cases, use [`AggregateUDFImpl`] and [`advanced_udaf.rs`]. +/// +/// # API Note +/// This is a separate struct from `AggregateUDFImpl` to maintain backwards +/// compatibility with the older API. +/// /// [the examples]: https://github.com/apache/arrow-datafusion/tree/main/datafusion-examples#single-process /// [aggregate function]: https://en.wikipedia.org/wiki/Aggregate_function /// [`Accumulator`]: crate::Accumulator -#[derive(Clone)] -pub struct AggregateUDF { - /// name - name: String, - /// Signature (input arguments) - signature: Signature, - /// Return type - return_type: ReturnTypeFunction, - /// actual implementation - accumulator: AccumulatorFactoryFunction, - /// the accumulator's state's description as a function of the return type - state_type: StateTypeFunction, -} +/// [`create_udaf`]: crate::expr_fn::create_udaf +/// [`simple_udaf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs +/// [`advanced_udaf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs -impl Debug for AggregateUDF { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - f.debug_struct("AggregateUDF") - .field("name", &self.name) - .field("signature", &self.signature) - .field("fun", &"") - .finish() - } +#[derive(Debug, Clone)] +pub struct AggregateUDF { + inner: Arc, } impl PartialEq for AggregateUDF { fn eq(&self, other: &Self) -> bool { - self.name == other.name && self.signature == other.signature + self.name() == 
other.name() && self.signature() == other.signature() } } @@ -80,8 +73,8 @@ impl Eq for AggregateUDF {} impl std::hash::Hash for AggregateUDF { fn hash(&self, state: &mut H) { - self.name.hash(state); - self.signature.hash(state); + self.name().hash(state); + self.signature().hash(state); } } @@ -98,13 +91,13 @@ impl AggregateUDF { accumulator: &AccumulatorFactoryFunction, state_type: &StateTypeFunction, ) -> Self { - Self { + Self::new_from_impl(AggregateUDFLegacyWrapper { name: name.to_owned(), signature: signature.clone(), return_type: return_type.clone(), accumulator: accumulator.clone(), state_type: state_type.clone(), - } + }) } /// Create a new `AggregateUDF` from a `[AggregateUDFImpl]` trait object @@ -112,34 +105,18 @@ impl AggregateUDF { /// Note this is the same as using the `From` impl (`AggregateUDF::from`) pub fn new_from_impl(fun: F) -> AggregateUDF where - F: AggregateUDFImpl + Send + Sync + 'static, + F: AggregateUDFImpl + 'static, { - let arc_fun = Arc::new(fun); - let captured_self = arc_fun.clone(); - let return_type: ReturnTypeFunction = Arc::new(move |arg_types| { - let return_type = captured_self.return_type(arg_types)?; - Ok(Arc::new(return_type)) - }); - - let captured_self = arc_fun.clone(); - let accumulator: AccumulatorFactoryFunction = - Arc::new(move |arg| captured_self.accumulator(arg)); - - let captured_self = arc_fun.clone(); - let state_type: StateTypeFunction = Arc::new(move |return_type| { - let state_type = captured_self.state_type(return_type)?; - Ok(Arc::new(state_type)) - }); - Self { - name: arc_fun.name().to_string(), - signature: arc_fun.signature().clone(), - return_type: return_type.clone(), - accumulator, - state_type, + inner: Arc::new(fun), } } + /// Return the underlying [`AggregateUDFImpl`] trait object for this function + pub fn inner(&self) -> Arc { + self.inner.clone() + } + /// creates an [`Expr`] that calls the aggregate function. 
/// /// This utility allows using the UDAF without requiring access to @@ -155,34 +132,36 @@ impl AggregateUDF { } /// Returns this function's name + /// + /// See [`AggregateUDFImpl::name`] for more details. pub fn name(&self) -> &str { - &self.name + self.inner.name() } /// Returns this function's signature (what input types are accepted) + /// + /// See [`AggregateUDFImpl::signature`] for more details. pub fn signature(&self) -> &Signature { - &self.signature + self.inner.signature() } /// Return the type of the function given its input types + /// + /// See [`AggregateUDFImpl::return_type`] for more details. pub fn return_type(&self, args: &[DataType]) -> Result { - // Old API returns an Arc of the datatype for some reason - let res = (self.return_type)(args)?; - Ok(res.as_ref().clone()) + self.inner.return_type(args) } /// Return an accumualator the given aggregate, given /// its return datatype. pub fn accumulator(&self, return_type: &DataType) -> Result> { - (self.accumulator)(return_type) + self.inner.accumulator(return_type) } /// Return the type of the intermediate state used by this aggregator, given /// its return datatype. 
Supports multi-phase aggregations pub fn state_type(&self, return_type: &DataType) -> Result> { - // old API returns an Arc for some reason, try and unwrap it here - let res = (self.state_type)(return_type)?; - Ok(Arc::try_unwrap(res).unwrap_or_else(|res| res.as_ref().clone())) + self.inner.state_type(return_type) } } @@ -212,6 +191,7 @@ where /// # use datafusion_common::{DataFusionError, plan_err, Result}; /// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility}; /// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator}; +/// #[derive(Debug, Clone)] /// struct GeoMeanUdf { /// signature: Signature /// }; @@ -248,7 +228,7 @@ where /// // Call the function `geo_mean(col)` /// let expr = geometric_mean.call(vec![col("a")]); /// ``` -pub trait AggregateUDFImpl { +pub trait AggregateUDFImpl: Debug + Send + Sync { /// Returns this object as an [`Any`] trait object fn as_any(&self) -> &dyn Any; @@ -271,3 +251,57 @@ pub trait AggregateUDFImpl { /// accumulator's state() must match the types here. 
fn state_type(&self, return_type: &DataType) -> Result>; } + +/// Implementation of [`AggregateUDFImpl`] that wraps the function style pointers +/// of the older API +pub struct AggregateUDFLegacyWrapper { + /// name + name: String, + /// Signature (input arguments) + signature: Signature, + /// Return type + return_type: ReturnTypeFunction, + /// actual implementation + accumulator: AccumulatorFactoryFunction, + /// the accumulator's state's description as a function of the return type + state_type: StateTypeFunction, +} + +impl Debug for AggregateUDFLegacyWrapper { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("AggregateUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("fun", &"") + .finish() + } +} + +impl AggregateUDFImpl for AggregateUDFLegacyWrapper { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + // Old API returns an Arc of the datatype for some reason + let res = (self.return_type)(arg_types)?; + Ok(res.as_ref().clone()) + } + + fn accumulator(&self, arg: &DataType) -> Result> { + (self.accumulator)(arg) + } + + fn state_type(&self, return_type: &DataType) -> Result> { + let res = (self.state_type)(return_type)?; + Ok(res.as_ref().clone()) + } +} diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 239a5e24cbf2..9b8f94f4b020 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -36,7 +36,7 @@ use std::{ /// /// 1. For simple (less performant) use cases, use [`create_udwf`] and [`simple_udwf.rs`]. /// -/// 2. For advanced use cases, use [`WindowUDFImpl`] and [`advanced_udf.rs`]. +/// 2. For advanced use cases, use [`WindowUDFImpl`] and [`advanced_udwf.rs`]. 
/// /// # API Note /// This is a separate struct from `WindowUDFImpl` to maintain backwards diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 0b5a201c7dff..0ec8a545e1a7 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -898,6 +898,7 @@ mod test { #[test] fn aggregate() -> Result<()> { + #[derive(Debug, Clone)] struct InnerAggregateUDF { signature: Signature, } diff --git a/docs/source/library-user-guide/adding-udfs.md b/docs/source/library-user-guide/adding-udfs.md index baa70f8fe556..64dc25411deb 100644 --- a/docs/source/library-user-guide/adding-udfs.md +++ b/docs/source/library-user-guide/adding-udfs.md @@ -421,6 +421,7 @@ let geometric_mean = create_udaf( Arc::new(vec![DataType::Float64, DataType::UInt32]), ); ``` + [`aggregateudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.AggregateUDF.html [`create_udaf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udaf.html [`advanced_udaf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs