Skip to content

Commit

Permalink
implement Inner
Browse files Browse the repository at this point in the history
  • Loading branch information
guojidan committed Jan 8, 2024
1 parent 1a310fa commit 3b08c78
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 62 deletions.
1 change: 1 addition & 0 deletions datafusion-examples/examples/advanced_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use datafusion_expr::{Accumulator, AggregateUDF, AggregateUDFImpl, Signature};
/// a function `accumulator` that returns the `Accumulator` instance.
///
/// To do so, we must implement the `AggregateUDFImpl` trait.
#[derive(Debug, Clone)]
struct GeoMeanUdf {
signature: Signature,
}
Expand Down
14 changes: 12 additions & 2 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ use crate::function::PartitionEvaluatorFactory;
use crate::{
aggregate_function, built_in_function, conditional_expressions::CaseBuilder,
logical_plan::Subquery, AccumulatorFactoryFunction, AggregateUDF,
BuiltinScalarFunction, Expr, LogicalPlan, Operator, ReturnTypeFunction,
ScalarFunctionImplementation, ScalarUDF, Signature, Volatility,
BuiltinScalarFunction, Expr, LogicalPlan, Operator, ScalarFunctionImplementation,
ScalarUDF, Signature, Volatility,
};
use crate::{AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl};
use arrow::datatypes::DataType;
Expand Down Expand Up @@ -1059,6 +1059,16 @@ pub struct SimpleAggregateUDF {
state_type: Vec<DataType>,
}

impl Debug for SimpleAggregateUDF {
    /// Formats this UDAF for debugging.
    ///
    /// Only the `name` and `signature` fields are printed; the `fun` entry is
    /// always rendered as the fixed placeholder string `"<FUNC>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let mut repr = f.debug_struct("AggregateUDF");
        repr.field("name", &self.name);
        repr.field("signature", &self.signature);
        // The implementation itself is not printed; emit a marker instead.
        repr.field("fun", &"<FUNC>");
        repr.finish()
    }
}

impl SimpleAggregateUDF {
/// Create a new `AggregateUDFImpl` from a name, input types, return type, state type and
/// implementation. Implementing [`AggregateUDFImpl`] allows more flexibility
Expand Down
152 changes: 93 additions & 59 deletions datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,45 +43,38 @@ use std::sync::Arc;
///
/// For more information, please see [the examples].
///
/// 1. For simple (less performant) use cases, use [`create_udaf`] and [`simple_udaf.rs`].
///
/// 2. For advanced use cases, use [`AggregateUDFImpl`] and [`advanced_udaf.rs`].
///
/// # API Note
/// This is a separate struct from `AggregateUDFImpl` to maintain backwards
/// compatibility with the older API.
///
/// [the examples]: https://github.com/apache/arrow-datafusion/tree/main/datafusion-examples#single-process
/// [aggregate function]: https://en.wikipedia.org/wiki/Aggregate_function
/// [`Accumulator`]: crate::Accumulator
#[derive(Clone)]
pub struct AggregateUDF {
/// name
name: String,
/// Signature (input arguments)
signature: Signature,
/// Return type
return_type: ReturnTypeFunction,
/// actual implementation
accumulator: AccumulatorFactoryFunction,
/// the accumulator's state's description as a function of the return type
state_type: StateTypeFunction,
}
/// [`create_udaf`]: crate::expr_fn::create_udaf
/// [`simple_udaf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/simple_udaf.rs
/// [`advanced_udaf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs
impl Debug for AggregateUDF {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
f.debug_struct("AggregateUDF")
.field("name", &self.name)
.field("signature", &self.signature)
.field("fun", &"<FUNC>")
.finish()
}
#[derive(Debug, Clone)]
pub struct AggregateUDF {
/// Shared trait object holding the actual implementation; the accessor
/// methods on this struct (`name`, `signature`, `return_type`, ...) delegate to it.
inner: Arc<dyn AggregateUDFImpl>,
}

impl PartialEq for AggregateUDF {
fn eq(&self, other: &Self) -> bool {
self.name == other.name && self.signature == other.signature
self.name() == other.name() && self.signature() == other.signature()
}
}

impl Eq for AggregateUDF {}

impl std::hash::Hash for AggregateUDF {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.name.hash(state);
self.signature.hash(state);
self.name().hash(state);
self.signature().hash(state);
}
}

Expand All @@ -98,48 +91,32 @@ impl AggregateUDF {
accumulator: &AccumulatorFactoryFunction,
state_type: &StateTypeFunction,
) -> Self {
Self {
Self::new_from_impl(AggregateUDFLegacyWrapper {
name: name.to_owned(),
signature: signature.clone(),
return_type: return_type.clone(),
accumulator: accumulator.clone(),
state_type: state_type.clone(),
}
})
}

/// Create a new `AggregateUDF` from an [`AggregateUDFImpl`] trait object
///
/// Note this is the same as using the `From` impl (`AggregateUDF::from`)
pub fn new_from_impl<F>(fun: F) -> AggregateUDF
where
F: AggregateUDFImpl + Send + Sync + 'static,
F: AggregateUDFImpl + 'static,
{
let arc_fun = Arc::new(fun);
let captured_self = arc_fun.clone();
let return_type: ReturnTypeFunction = Arc::new(move |arg_types| {
let return_type = captured_self.return_type(arg_types)?;
Ok(Arc::new(return_type))
});

let captured_self = arc_fun.clone();
let accumulator: AccumulatorFactoryFunction =
Arc::new(move |arg| captured_self.accumulator(arg));

let captured_self = arc_fun.clone();
let state_type: StateTypeFunction = Arc::new(move |return_type| {
let state_type = captured_self.state_type(return_type)?;
Ok(Arc::new(state_type))
});

Self {
name: arc_fun.name().to_string(),
signature: arc_fun.signature().clone(),
return_type: return_type.clone(),
accumulator,
state_type,
inner: Arc::new(fun),
}
}

/// Return the underlying [`AggregateUDFImpl`] trait object for this function
///
/// The returned `Arc` is a clone of the handle stored in this struct, so it
/// refers to the same implementation object.
pub fn inner(&self) -> Arc<dyn AggregateUDFImpl> {
self.inner.clone()
}

/// creates an [`Expr`] that calls the aggregate function.
///
/// This utility allows using the UDAF without requiring access to
Expand All @@ -155,34 +132,36 @@ impl AggregateUDF {
}

/// Returns this function's name
///
/// See [`AggregateUDFImpl::name`] for more details.
pub fn name(&self) -> &str {
&self.name
self.inner.name()
}

/// Returns this function's signature (what input types are accepted)
///
/// See [`AggregateUDFImpl::signature`] for more details.
pub fn signature(&self) -> &Signature {
&self.signature
self.inner.signature()
}

/// Return the type of the function given its input types
///
/// See [`AggregateUDFImpl::return_type`] for more details.
pub fn return_type(&self, args: &[DataType]) -> Result<DataType> {
// Old API returns an Arc of the datatype for some reason
let res = (self.return_type)(args)?;
Ok(res.as_ref().clone())
self.inner.return_type(args)
}

/// Return an accumulator for the given aggregate, given
/// its return datatype.
pub fn accumulator(&self, return_type: &DataType) -> Result<Box<dyn Accumulator>> {
(self.accumulator)(return_type)
self.inner.accumulator(return_type)
}

/// Return the type of the intermediate state used by this aggregator, given
/// its return datatype. Supports multi-phase aggregations
pub fn state_type(&self, return_type: &DataType) -> Result<Vec<DataType>> {
// old API returns an Arc for some reason, try and unwrap it here
let res = (self.state_type)(return_type)?;
Ok(Arc::try_unwrap(res).unwrap_or_else(|res| res.as_ref().clone()))
self.inner.state_type(return_type)
}
}

Expand Down Expand Up @@ -212,6 +191,7 @@ where
/// # use datafusion_common::{DataFusionError, plan_err, Result};
/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility};
/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator};
/// #[derive(Debug, Clone)]
/// struct GeoMeanUdf {
/// signature: Signature
/// };
Expand Down Expand Up @@ -248,7 +228,7 @@ where
/// // Call the function `geo_mean(col)`
/// let expr = geometric_mean.call(vec![col("a")]);
/// ```
pub trait AggregateUDFImpl {
pub trait AggregateUDFImpl: Debug + Send + Sync {
/// Returns this object as an [`Any`] trait object
fn as_any(&self) -> &dyn Any;

Expand All @@ -271,3 +251,57 @@ pub trait AggregateUDFImpl {
/// accumulator's state() must match the types here.
fn state_type(&self, return_type: &DataType) -> Result<Vec<DataType>>;
}

/// Implementation of [`AggregateUDFImpl`] that wraps the function style pointers
/// of the older API
///
/// Each trait method either returns one of the stored fields (`name`,
/// `signature`) or calls the corresponding stored function.
pub struct AggregateUDFLegacyWrapper {
/// Name of the aggregate function, returned by `name()`
name: String,
/// Signature (input arguments), returned by `signature()`
signature: Signature,
/// Function called with the argument types to compute the return type
return_type: ReturnTypeFunction,
/// Function called with a data type to create the accumulator
/// (the actual implementation of the aggregate)
accumulator: AccumulatorFactoryFunction,
/// Function called with the return type to compute the types of the
/// accumulator's state
state_type: StateTypeFunction,
}

impl Debug for AggregateUDFLegacyWrapper {
    /// Formats the wrapper for debugging.
    ///
    /// Prints the `name` and `signature` fields; the stored functions are
    /// summarized by a single `fun` entry with the placeholder `"<FUNC>"`.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let mut repr = f.debug_struct("AggregateUDF");
        repr.field("name", &self.name);
        repr.field("signature", &self.signature);
        // The function-style fields are not printed individually.
        repr.field("fun", &"<FUNC>");
        repr.finish()
    }
}

impl AggregateUDFImpl for AggregateUDFLegacyWrapper {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
&self.name
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
// Old API returns an Arc of the datatype for some reason
let res = (self.return_type)(arg_types)?;
Ok(res.as_ref().clone())
}

fn accumulator(&self, arg: &DataType) -> Result<Box<dyn Accumulator>> {
(self.accumulator)(arg)
}

fn state_type(&self, return_type: &DataType) -> Result<Vec<DataType>> {
let res = (self.state_type)(return_type)?;
Ok(res.as_ref().clone())
}
}
2 changes: 1 addition & 1 deletion datafusion/expr/src/udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use std::{
///
/// 1. For simple (less performant) use cases, use [`create_udwf`] and [`simple_udwf.rs`].
///
/// 2. For advanced use cases, use [`WindowUDFImpl`] and [`advanced_udf.rs`].
/// 2. For advanced use cases, use [`WindowUDFImpl`] and [`advanced_udwf.rs`].
///
/// # API Note
/// This is a separate struct from `WindowUDFImpl` to maintain backwards
Expand Down
1 change: 1 addition & 0 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -898,6 +898,7 @@ mod test {

#[test]
fn aggregate() -> Result<()> {
#[derive(Debug, Clone)]
struct InnerAggregateUDF {
signature: Signature,
}
Expand Down
1 change: 1 addition & 0 deletions docs/source/library-user-guide/adding-udfs.md
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ let geometric_mean = create_udaf(
Arc::new(vec![DataType::Float64, DataType::UInt32]),
);
```

[`aggregateudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.AggregateUDF.html
[`create_udaf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udaf.html
[`advanced_udaf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs
Expand Down

0 comments on commit 3b08c78

Please sign in to comment.