Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert Average to UDAF #10942 #10964

Merged
merged 22 commits into from
Jun 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions datafusion-examples/examples/dataframe_subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use arrow_schema::DataType;
use std::sync::Arc;

use datafusion::error::Result;
use datafusion::functions_aggregate::average::avg;
use datafusion::prelude::*;
use datafusion::test_util::arrow_test_data;
use datafusion_common::ScalarValue;
Expand Down
38 changes: 18 additions & 20 deletions datafusion-examples/examples/simplify_udaf_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,21 @@
// specific language governing permissions and limitations
// under the License.

use arrow_schema::{Field, Schema};
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
use datafusion_expr::function::{AggregateFunctionSimplification, StateFieldsArgs};
use datafusion_expr::simplify::SimplifyInfo;

use std::{any::Any, sync::Arc};

use arrow_schema::{Field, Schema};

use datafusion::arrow::{array::Float32Array, record_batch::RecordBatch};
use datafusion::error::Result;
use datafusion::functions_aggregate::average::avg_udaf;
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
use datafusion::{assert_batches_eq, prelude::*};
use datafusion_common::cast::as_float64_array;
use datafusion_expr::function::{AggregateFunctionSimplification, StateFieldsArgs};
use datafusion_expr::simplify::SimplifyInfo;
use datafusion_expr::{
expr::{AggregateFunction, AggregateFunctionDefinition},
function::AccumulatorArgs,
Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
expr::AggregateFunction, function::AccumulatorArgs, Accumulator, AggregateUDF,
AggregateUDFImpl, GroupsAccumulator, Signature,
};

/// This example shows how to use the AggregateUDFImpl::simplify API to simplify/replace user
Expand Down Expand Up @@ -92,18 +92,16 @@ impl AggregateUDFImpl for BetterAvgUdaf {
// with built-in aggregate function to illustrate the use
let simplify = |aggregate_function: datafusion_expr::expr::AggregateFunction,
_: &dyn SimplifyInfo| {
Ok(Expr::AggregateFunction(AggregateFunction {
func_def: AggregateFunctionDefinition::BuiltIn(
// yes it is the same Avg, `BetterAvgUdaf` was just a
// marketing pitch :)
datafusion_expr::aggregate_function::AggregateFunction::Avg,
),
args: aggregate_function.args,
distinct: aggregate_function.distinct,
filter: aggregate_function.filter,
order_by: aggregate_function.order_by,
null_treatment: aggregate_function.null_treatment,
}))
Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
avg_udaf(),
// yes it is the same Avg, `BetterAvgUdaf` was just a
// marketing pitch :)
aggregate_function.args,
aggregate_function.distinct,
aggregate_function.filter,
aggregate_function.order_by,
aggregate_function.null_treatment,
)))
};

Some(Box::new(simplify))
Expand Down
10 changes: 5 additions & 5 deletions datafusion-examples/examples/simplify_udwf_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
use std::any::Any;

use arrow_schema::DataType;

use datafusion::execution::context::SessionContext;
use datafusion::functions_aggregate::average::avg_udaf;
use datafusion::{error::Result, execution::options::CsvReadOptions};
use datafusion_expr::function::WindowFunctionSimplification;
use datafusion_expr::{
expr::WindowFunction, simplify::SimplifyInfo, AggregateFunction, Expr,
PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl,
expr::WindowFunction, simplify::SimplifyInfo, Expr, PartitionEvaluator, Signature,
Volatility, WindowUDF, WindowUDFImpl,
};

/// This UDWF will show how to use the WindowUDFImpl::simplify() API
Expand Down Expand Up @@ -71,9 +73,7 @@ impl WindowUDFImpl for SimplifySmoothItUdf {
let simplify = |window_function: datafusion_expr::expr::WindowFunction,
_: &dyn SimplifyInfo| {
Ok(Expr::WindowFunction(WindowFunction {
fun: datafusion_expr::WindowFunctionDefinition::AggregateFunction(
AggregateFunction::Avg,
),
fun: datafusion_expr::WindowFunctionDefinition::AggregateUDF(avg_udaf()),
args: window_function.args,
partition_by: window_function.partition_by,
order_by: window_function.order_by,
Expand Down
12 changes: 5 additions & 7 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,11 @@ use datafusion_common::config::{CsvOptions, FormatOptions, JsonOptions};
use datafusion_common::{
plan_err, Column, DFSchema, DataFusionError, ParamValues, SchemaError, UnnestOptions,
};
use datafusion_expr::lit;
use datafusion_expr::{case, is_null, lit};
use datafusion_expr::{
avg, max, min, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown,
UNNAMED_TABLE,
max, min, utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE,
};
use datafusion_expr::{case, is_null};
use datafusion_functions_aggregate::expr_fn::{count, median, stddev, sum};
use datafusion_functions_aggregate::expr_fn::{avg, count, median, stddev, sum};

use async_trait::async_trait;

Expand Down Expand Up @@ -534,7 +532,7 @@ impl DataFrame {
/// # async fn main() -> Result<()> {
/// let ctx = SessionContext::new();
/// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?
/// // Return a single row (a, b) for each distinct value of a
/// // Return a single row (a, b) for each distinct value of a
/// .distinct_on(vec![col("a")], vec![col("a"), col("b")], None)?;
/// # Ok(())
/// # }
Expand Down Expand Up @@ -2018,7 +2016,7 @@ mod tests {

assert_batches_sorted_eq!(
["+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+",
"| c1 | MIN(aggregate_test_100.c12) | MAX(aggregate_test_100.c12) | AVG(aggregate_test_100.c12) | sum(aggregate_test_100.c12) | count(aggregate_test_100.c12) | count(DISTINCT aggregate_test_100.c12) |",
"| c1 | MIN(aggregate_test_100.c12) | MAX(aggregate_test_100.c12) | avg(aggregate_test_100.c12) | sum(aggregate_test_100.c12) | count(aggregate_test_100.c12) | count(DISTINCT aggregate_test_100.c12) |",
"+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+",
"| a | 0.02182578039211991 | 0.9800193410444061 | 0.48754517466109415 | 10.238448667882977 | 21 | 21 |",
"| b | 0.04893135681998029 | 0.9185813970744787 | 0.41040709263815384 | 7.797734760124923 | 19 | 19 |",
Expand Down
8 changes: 4 additions & 4 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ use datafusion_execution::runtime_env::RuntimeEnv;
use datafusion_expr::expr::{GroupingSet, Sort};
use datafusion_expr::var_provider::{VarProvider, VarType};
use datafusion_expr::{
array_agg, avg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col,
placeholder, scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame,
WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
array_agg, cast, col, exists, expr, in_subquery, lit, max, out_ref_col, placeholder,
scalar_subquery, when, wildcard, Expr, ExprSchemable, WindowFrame, WindowFrameBound,
WindowFrameUnits, WindowFunctionDefinition,
};
use datafusion_functions_aggregate::expr_fn::{count, sum};
use datafusion_functions_aggregate::expr_fn::{avg, count, sum};

#[tokio::test]
async fn test_count_wildcard_on_sort() -> Result<()> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ use datafusion_expr::{
create_udaf, function::AccumulatorArgs, AggregateUDFImpl, GroupsAccumulator,
SimpleAggregateUDF,
};
use datafusion_physical_expr::expressions::AvgAccumulator;
use datafusion_functions_aggregate::average::AvgAccumulator;

/// Test to show the contents of the setup
#[tokio::test]
async fn test_setup() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async fn csv_query_custom_udf_with_cast() -> Result<()> {
let actual = plan_and_collect(&ctx, sql).await.unwrap();
let expected = [
"+------------------------------------------+",
"| AVG(custom_sqrt(aggregate_test_100.c11)) |",
"| avg(custom_sqrt(aggregate_test_100.c11)) |",
"+------------------------------------------+",
"| 0.6584408483418835 |",
"+------------------------------------------+",
Expand All @@ -69,7 +69,7 @@ async fn csv_query_avg_sqrt() -> Result<()> {
let actual = plan_and_collect(&ctx, sql).await.unwrap();
let expected = [
"+------------------------------------------+",
"| AVG(custom_sqrt(aggregate_test_100.c12)) |",
"| avg(custom_sqrt(aggregate_test_100.c12)) |",
"+------------------------------------------+",
"| 0.6706002946036459 |",
"+------------------------------------------+",
Expand Down
22 changes: 0 additions & 22 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ pub enum AggregateFunction {
Min,
/// Maximum
Max,
/// Average
Avg,
/// Aggregation into an array
ArrayAgg,
/// N'th value in a group according to some ordering
Expand All @@ -55,7 +53,6 @@ impl AggregateFunction {
match self {
Min => "MIN",
Max => "MAX",
Avg => "AVG",
ArrayAgg => "ARRAY_AGG",
NthValue => "NTH_VALUE",
Correlation => "CORR",
Expand All @@ -75,9 +72,7 @@ impl FromStr for AggregateFunction {
fn from_str(name: &str) -> Result<AggregateFunction> {
Ok(match name {
// general
"avg" => AggregateFunction::Avg,
"max" => AggregateFunction::Max,
"mean" => AggregateFunction::Avg,
"min" => AggregateFunction::Min,
"array_agg" => AggregateFunction::ArrayAgg,
"nth_value" => AggregateFunction::NthValue,
Expand Down Expand Up @@ -123,7 +118,6 @@ impl AggregateFunction {
AggregateFunction::Correlation => {
correlation_return_type(&coerced_data_types[0])
}
AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]),
AggregateFunction::ArrayAgg => Ok(DataType::List(Arc::new(Field::new(
"item",
coerced_data_types[0].clone(),
Expand All @@ -135,19 +129,6 @@ impl AggregateFunction {
}
}

/// Returns the internal sum datatype of the avg aggregate function.
pub fn sum_type_of_avg(input_expr_types: &[DataType]) -> Result<DataType> {
// Note that this function *must* return the same type that the respective physical expression returns
// or the execution panics.
let fun = AggregateFunction::Avg;
let coerced_data_types = crate::type_coercion::aggregates::coerce_types(
&fun,
input_expr_types,
&fun.signature(),
)?;
avg_sum_type(&coerced_data_types[0])
}

impl AggregateFunction {
/// the signatures supported by the function `fun`.
pub fn signature(&self) -> Signature {
Expand All @@ -168,9 +149,6 @@ impl AggregateFunction {
.collect::<Vec<_>>();
Signature::uniform(1, valid, Volatility::Immutable)
}
AggregateFunction::Avg => {
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable),
AggregateFunction::Correlation => {
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
Expand Down
7 changes: 0 additions & 7 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2280,7 +2280,6 @@ mod test {
"nth_value",
"min",
"max",
"avg",
];
for name in names {
let fun = find_df_window_func(name).unwrap();
Expand Down Expand Up @@ -2309,12 +2308,6 @@ mod test {
aggregate_function::AggregateFunction::Min
))
);
assert_eq!(
find_df_window_func("avg"),
Some(WindowFunctionDefinition::AggregateFunction(
aggregate_function::AggregateFunction::Avg
))
);
assert_eq!(
find_df_window_func("cume_dist"),
Some(WindowFunctionDefinition::BuiltInWindowFunction(
Expand Down
12 changes: 0 additions & 12 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,18 +183,6 @@ pub fn array_agg(expr: Expr) -> Expr {
))
}

/// Create an expression to represent the avg() aggregate function
pub fn avg(expr: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new(
aggregate_function::AggregateFunction::Avg,
vec![expr],
false,
None,
None,
None,
))
}

/// Return a new expression with bitwise AND
pub fn bitwise_and(left: Expr, right: Expr) -> Expr {
Expr::BinaryExpr(BinaryExpr::new(
Expand Down
8 changes: 4 additions & 4 deletions datafusion/expr/src/expr_rewriter/order_by.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ mod test {
use arrow::datatypes::{DataType, Field, Schema};

use crate::{
avg, cast, col, lit, logical_plan::builder::LogicalTableSource, min, try_cast,
LogicalPlanBuilder,
cast, col, lit, logical_plan::builder::LogicalTableSource, min,
test::function_stub::avg, try_cast, LogicalPlanBuilder,
};

use super::*;
Expand Down Expand Up @@ -246,9 +246,9 @@ mod test {
expected: sort(col("c1") + col("MIN(t.c2)")),
},
TestCase {
desc: r#"avg(c3) --> "AVG(t.c3)" as average (column *named* "AVG(t.c3)", aliased)"#,
desc: r#"avg(c3) --> "avg(t.c3)" as average (column *named* "avg(t.c3)", aliased)"#,
input: sort(avg(col("c3"))),
expected: sort(col("AVG(t.c3)").alias("average")),
expected: sort(col("avg(t.c3)").alias("average")),
},
];

Expand Down
Loading