
Evaluate expressions after type coercion #3444

Merged (12 commits) on Sep 12, 2022
10 changes: 5 additions & 5 deletions datafusion/core/tests/sql/aggregates.rs
Original file line number Diff line number Diff line change
@@ -1834,11 +1834,11 @@ async fn aggregate_avg_add() -> Result<()> {
assert_eq!(results.len(), 1);

let expected = vec![
"+--------------+-------------------------+-------------------------+-------------------------+",
"| AVG(test.c1) | AVG(test.c1) + Int64(1) | AVG(test.c1) + Int64(2) | Int64(1) + AVG(test.c1) |",
"+--------------+-------------------------+-------------------------+-------------------------+",
"| 1.5 | 2.5 | 3.5 | 2.5 |",
"+--------------+-------------------------+-------------------------+-------------------------+",
"+--------------+---------------------------+---------------------------+---------------------------+",
Contributor:

I just have a comment about the header of the expression.
The input SQL is AVG(c1) + 1, where 1 has the Int64 data type, but the header shows Float64 after the cast is applied.

Contributor:

Do we have a way to keep the header consistent, so that it does not change when the optimizer plan changes?
cc @andygrove @alamb

Contributor (author):

This has been a concern of mine for a long time as well; I had a PR open for it once.

One approach would be to add an alias for every unnamed expression based on the original query SQL or expression.
This would avoid having the column names changed by the optimizers.
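The alias idea can be sketched standalone. This is a hypothetical, simplified model — the `Expr` enum, `name()`, and `with_stable_name()` below are invented for illustration and are not DataFusion's actual types — but it shows how capturing the display name before any optimizer pass runs keeps the column header stable through later rewrites:

```rust
// Simplified expression tree with a display name used as the output
// column header. Wrapping an expression in Alias freezes that name.
#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Column(String),
    Int64(i64),
    Add(Box<Expr>, Box<Expr>),
    Alias(Box<Expr>, String),
}

impl Expr {
    // Display name, derived from the expression's current shape.
    fn name(&self) -> String {
        match self {
            Expr::Column(c) => c.clone(),
            Expr::Int64(v) => format!("Int64({v})"),
            Expr::Add(l, r) => format!("{} + {}", l.name(), r.name()),
            Expr::Alias(_, n) => n.clone(),
        }
    }

    // Capture the pre-optimization name so later rewrites of the child
    // (e.g. a cast inserted by type coercion) cannot change the header.
    fn with_stable_name(self) -> Expr {
        let n = self.name();
        Expr::Alias(Box::new(self), n)
    }
}

fn main() {
    let expr = Expr::Add(
        Box::new(Expr::Column("AVG(test.c1)".into())),
        Box::new(Expr::Int64(1)),
    )
    .with_stable_name();

    // Even if an optimizer later rewrites the aliased child, the header
    // remains the original "AVG(test.c1) + Int64(1)".
    assert_eq!(expr.name(), "AVG(test.c1) + Int64(1)");
}
```

The design point is that the alias is added exactly once, before any rule runs, so every optimizer pass downstream is free to rewrite the expression body.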

Contributor:

I really like the idea of adding an alias once (maybe as the initial optimizer pass?)

I am not sure how valuable adding the types in the column names is in general, to be honest. I wouldn't mind if rather than Int(1) this was simply rendered 1

Contributor:

> I really like the idea of adding an alias once (maybe as the initial optimizer pass?)
>
> I am not sure how valuable adding the types in the column names is in general, to be honest. I wouldn't mind if rather than Int(1) this was simply rendered 1

Do you have a plan or a draft PR for that? @Dandandan

Contributor:

Perhaps @Dandandan was referring to #280 / #279

Contributor (author):

Yes indeed, we can give those a second life 🎉

I had some concerns with the PR, but I believe it is still a big improvement over the current state of things.

"| AVG(test.c1) | AVG(test.c1) + Float64(1) | AVG(test.c1) + Float64(2) | Float64(1) + AVG(test.c1) |",
"+--------------+---------------------------+---------------------------+---------------------------+",
"| 1.5 | 2.5 | 3.5 | 2.5 |",
"+--------------+---------------------------+---------------------------+---------------------------+",
];
assert_batches_sorted_eq!(expected, &results);

114 changes: 57 additions & 57 deletions datafusion/core/tests/sql/decimal.rs
@@ -376,25 +376,25 @@ async fn decimal_arithmetic_op() -> Result<()> {
actual[0].schema().field(0).data_type()
);
let expected = vec![
"+------------------------------+",
"| decimal_simple.c1 + Int64(1) |",
"+------------------------------+",
"| 1.000010 |",
"| 1.000020 |",
"| 1.000020 |",
"| 1.000030 |",
"| 1.000030 |",
"| 1.000030 |",
"| 1.000040 |",
"| 1.000040 |",
"| 1.000040 |",
"| 1.000040 |",
"| 1.000050 |",
"| 1.000050 |",
"| 1.000050 |",
"| 1.000050 |",
"| 1.000050 |",
"+------------------------------+",
"+----------------------------------------------------+",
"| decimal_simple.c1 + Decimal128(Some(1000000),27,6) |",
"+----------------------------------------------------+",
"| 1.000010 |",
"| 1.000020 |",
"| 1.000020 |",
"| 1.000030 |",
"| 1.000030 |",
"| 1.000030 |",
"| 1.000040 |",
"| 1.000040 |",
"| 1.000040 |",
"| 1.000040 |",
"| 1.000050 |",
"| 1.000050 |",
"| 1.000050 |",
"| 1.000050 |",
"| 1.000050 |",
"+----------------------------------------------------+",
];
assert_batches_eq!(expected, &actual);
// array decimal(10,6) + array decimal(12,7) => decimal(13,7)
@@ -434,25 +434,25 @@ async fn decimal_arithmetic_op() -> Result<()> {
actual[0].schema().field(0).data_type()
);
let expected = vec![
"+------------------------------+",
"| decimal_simple.c1 - Int64(1) |",
"+------------------------------+",
"| -0.999990 |",
"| -0.999980 |",
"| -0.999980 |",
"| -0.999970 |",
"| -0.999970 |",
"| -0.999970 |",
"| -0.999960 |",
"| -0.999960 |",
"| -0.999960 |",
"| -0.999960 |",
"| -0.999950 |",
"| -0.999950 |",
"| -0.999950 |",
"| -0.999950 |",
"| -0.999950 |",
"+------------------------------+",
"+----------------------------------------------------+",
"| decimal_simple.c1 - Decimal128(Some(1000000),27,6) |",
"+----------------------------------------------------+",
"| -0.999990 |",
"| -0.999980 |",
"| -0.999980 |",
"| -0.999970 |",
"| -0.999970 |",
"| -0.999970 |",
"| -0.999960 |",
"| -0.999960 |",
"| -0.999960 |",
"| -0.999960 |",
"| -0.999950 |",
"| -0.999950 |",
"| -0.999950 |",
"| -0.999950 |",
"| -0.999950 |",
"+----------------------------------------------------+",
];
assert_batches_eq!(expected, &actual);

@@ -492,25 +492,25 @@ async fn decimal_arithmetic_op() -> Result<()> {
actual[0].schema().field(0).data_type()
);
let expected = vec![
"+-------------------------------+",
"| decimal_simple.c1 * Int64(20) |",
"+-------------------------------+",
"| 0.000200 |",
"| 0.000400 |",
"| 0.000400 |",
"| 0.000600 |",
"| 0.000600 |",
"| 0.000600 |",
"| 0.000800 |",
"| 0.000800 |",
"| 0.000800 |",
"| 0.000800 |",
"| 0.001000 |",
"| 0.001000 |",
"| 0.001000 |",
"| 0.001000 |",
"| 0.001000 |",
"+-------------------------------+",
"+-----------------------------------------------------+",
"| decimal_simple.c1 * Decimal128(Some(20000000),31,6) |",
"+-----------------------------------------------------+",
"| 0.000200 |",
"| 0.000400 |",
"| 0.000400 |",
"| 0.000600 |",
"| 0.000600 |",
"| 0.000600 |",
"| 0.000800 |",
"| 0.000800 |",
"| 0.000800 |",
"| 0.000800 |",
"| 0.001000 |",
"| 0.001000 |",
"| 0.001000 |",
"| 0.001000 |",
"| 0.001000 |",
"+-----------------------------------------------------+",
];
assert_batches_eq!(expected, &actual);

2 changes: 1 addition & 1 deletion datafusion/core/tests/sql/explain_analyze.rs
@@ -653,7 +653,7 @@ order by
let expected = "\
Sort: #revenue DESC NULLS FIRST\
\n Projection: #customer.c_custkey, #customer.c_name, #SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue, #customer.c_acctbal, #nation.n_name, #customer.c_address, #customer.c_phone, #customer.c_comment\
\n Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name, #customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address, #customer.c_comment]], aggr=[[SUM(#lineitem.l_extendedprice * CAST(Int64(1) AS Float64) - #lineitem.l_discount)]]\
\n Aggregate: groupBy=[[#customer.c_custkey, #customer.c_name, #customer.c_acctbal, #customer.c_phone, #nation.n_name, #customer.c_address, #customer.c_comment]], aggr=[[SUM(#lineitem.l_extendedprice * Float64(1) - #lineitem.l_discount)]]\
Contributor:

very nice

\n Inner Join: #customer.c_nationkey = #nation.n_nationkey\
\n Inner Join: #orders.o_orderkey = #lineitem.l_orderkey\
\n Inner Join: #customer.c_custkey = #orders.o_custkey\
2 changes: 1 addition & 1 deletion datafusion/core/tests/sql/predicates.rs
@@ -428,7 +428,7 @@ async fn multiple_or_predicates() -> Result<()> {
"Explain [plan_type:Utf8, plan:Utf8]",
" Projection: #lineitem.l_partkey [l_partkey:Int64]",
" Projection: #part.p_partkey = #lineitem.l_partkey AS #part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey, #part.p_size >= Int32(1) AS #part.p_size >= Int32(1)Int32(1)#part.p_size, #lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size [#part.p_partkey = #lineitem.l_partkey#lineitem.l_partkey#part.p_partkey:Boolean;N, #part.p_size >= Int32(1)Int32(1)#part.p_size:Boolean;N, l_partkey:Int64, l_quantity:Float64, p_brand:Utf8, p_size:Int32]",
" Filter: #part.p_partkey = #lineitem.l_partkey AND #part.p_brand = Utf8(\"Brand#12\") AND #lineitem.l_quantity >= CAST(Int64(1) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(11) AS Float64) AND #part.p_size <= Int32(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= CAST(Int64(10) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(20) AS Float64) AND #part.p_size <= Int32(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #lineitem.l_quantity >= CAST(Int64(20) AS Float64) AND #lineitem.l_quantity <= CAST(Int64(30) AS Float64) AND #part.p_size <= Int32(15) [l_partkey:Int64, l_quantity:Float64, p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" Filter: #part.p_partkey = #lineitem.l_partkey AND #part.p_brand = Utf8(\"Brand#12\") AND #lineitem.l_quantity >= Float64(1) AND #lineitem.l_quantity <= Float64(11) AND #part.p_size <= Int32(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= Float64(10) AND #lineitem.l_quantity <= Float64(20) AND #part.p_size <= Int32(10) OR #part.p_brand = Utf8(\"Brand#34\") AND #lineitem.l_quantity >= Float64(20) AND #lineitem.l_quantity <= Float64(30) AND #part.p_size <= Int32(15) [l_partkey:Int64, l_quantity:Float64, p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" CrossJoin: [l_partkey:Int64, l_quantity:Float64, p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
" TableScan: lineitem projection=[l_partkey, l_quantity] [l_partkey:Int64, l_quantity:Float64]",
" Filter: #part.p_size >= Int32(1) [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
2 changes: 1 addition & 1 deletion datafusion/core/tests/sql/subqueries.rs
@@ -337,7 +337,7 @@ order by s_name;
TableScan: part projection=[p_partkey, p_name], partial_filters=[#part.p_name LIKE Utf8("forest%")]
Projection: #lineitem.l_partkey, #lineitem.l_suppkey, Float64(0.5) * #SUM(lineitem.l_quantity) AS __value, alias=__sq_3
Aggregate: groupBy=[[#lineitem.l_partkey, #lineitem.l_suppkey]], aggr=[[SUM(#lineitem.l_quantity)]]
Filter: #lineitem.l_shipdate >= CAST(Utf8("1994-01-01") AS Date32)
Filter: #lineitem.l_shipdate >= Date32("8766")
TableScan: lineitem projection=[l_partkey, l_suppkey, l_quantity, l_shipdate], partial_filters=[#lineitem.l_shipdate >= CAST(Utf8("1994-01-01") AS Date32)]"#
.to_string();
assert_eq!(actual, expected);
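The Date32("8766") literal in the new plan is the constant-folded form of CAST(Utf8("1994-01-01") AS Date32): Arrow's Date32 type stores days since the Unix epoch (1970-01-01). A quick standalone check confirms the value (the helper functions here are illustrative, not DataFusion APIs):

```rust
// Gregorian leap-year rule.
fn is_leap(y: i32) -> bool {
    (y % 4 == 0 && y % 100 != 0) || y % 400 == 0
}

// Days from 1970-01-01 to `year`-01-01, i.e. the Date32 encoding
// of January 1st of the given year.
fn days_since_epoch(year: i32) -> i32 {
    (1970..year).map(|y| if is_leap(y) { 366 } else { 365 }).sum()
}

fn main() {
    // 24 years in 1970..1994, six of them leap (1972..1992 step 4):
    // 24 * 365 + 6 = 8766, matching Date32("8766") in the plan.
    assert_eq!(days_since_epoch(1994), 8766);
}
```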
69 changes: 52 additions & 17 deletions datafusion/optimizer/src/type_coercion.rs
@@ -17,6 +17,7 @@

//! Optimizer rule for type validation and coercion

use crate::simplify_expressions::ConstEvaluator;
use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::{DFSchema, DFSchemaRef, Result};
use datafusion_expr::binary_rule::coerce_types;
@@ -27,6 +28,7 @@ use datafusion_expr::type_coercion::data_types;
use datafusion_expr::utils::from_plan;
use datafusion_expr::{Expr, LogicalPlan};
use datafusion_expr::{ExprSchemable, Signature};
use datafusion_physical_expr::execution_props::ExecutionProps;

#[derive(Default)]
pub struct TypeCoercion {}
@@ -64,7 +66,15 @@ impl OptimizerRule for TypeCoercion {
_ => DFSchemaRef::new(DFSchema::empty()),
};

let mut expr_rewrite = TypeCoercionRewriter { schema };
let mut execution_props = ExecutionProps::new();
execution_props.query_execution_start_time =
optimizer_config.query_execution_start_time;
let const_evaluator = ConstEvaluator::try_new(&execution_props)?;

let mut expr_rewrite = TypeCoercionRewriter {
schema,
const_evaluator,
};

let new_expr = plan
.expressions()
@@ -76,11 +86,12 @@
}
}

struct TypeCoercionRewriter {
struct TypeCoercionRewriter<'a> {
schema: DFSchemaRef,
const_evaluator: ConstEvaluator<'a>,
}

impl ExprRewriter for TypeCoercionRewriter {
impl ExprRewriter for TypeCoercionRewriter<'_> {
fn pre_visit(&mut self, _expr: &Expr) -> Result<RewriteRecursion> {
Ok(RewriteRecursion::Continue)
}
@@ -91,22 +102,26 @@ impl ExprRewriter for TypeCoercionRewriter {
let left_type = left.get_type(&self.schema)?;
let right_type = right.get_type(&self.schema)?;
let coerced_type = coerce_types(&left_type, &op, &right_type)?;
Ok(Expr::BinaryExpr {

let expr = Expr::BinaryExpr {
left: Box::new(left.cast_to(&coerced_type, &self.schema)?),
op,
right: Box::new(right.cast_to(&coerced_type, &self.schema)?),
})
};

expr.rewrite(&mut self.const_evaluator)
}
Expr::ScalarUDF { fun, args } => {
let new_expr = coerce_arguments_for_signature(
args.as_slice(),
&self.schema,
&fun.signature,
)?;
Ok(Expr::ScalarUDF {
let expr = Expr::ScalarUDF {
fun,
args: new_expr,
})
};
expr.rewrite(&mut self.const_evaluator)
}
expr => Ok(expr),
}
@@ -145,7 +160,8 @@ mod test {
use crate::type_coercion::TypeCoercion;
use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::DataType;
use datafusion_common::{DFSchema, Result};
use datafusion_common::{DFField, DFSchema, Result, ScalarValue};
use datafusion_expr::{col, ColumnarValue};
use datafusion_expr::{
lit,
logical_plan::{EmptyRelation, Projection},
@@ -156,28 +172,40 @@

#[test]
fn simple_case() -> Result<()> {
let expr = lit(1.2_f64).lt(lit(2_u32));
let expr = col("a").lt(lit(2_u32));
let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: Arc::new(DFSchema::empty()),
schema: Arc::new(
DFSchema::new_with_metadata(
vec![DFField::new(None, "a", DataType::Float64, true)],
std::collections::HashMap::new(),
)
.unwrap(),
),
}));
let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty, None)?);
let rule = TypeCoercion::new();
let mut config = OptimizerConfig::default();
let plan = rule.optimize(&plan, &mut config)?;
assert_eq!(
"Projection: Float64(1.2) < CAST(UInt32(2) AS Float64)\n EmptyRelation",
"Projection: #a < Float64(2)\n EmptyRelation",
&format!("{:?}", plan)
);
Ok(())
}

#[test]
fn nested_case() -> Result<()> {
let expr = lit(1.2_f64).lt(lit(2_u32));
let expr = col("a").lt(lit(2_u32));
let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
schema: Arc::new(DFSchema::empty()),
schema: Arc::new(
DFSchema::new_with_metadata(
vec![DFField::new(None, "a", DataType::Float64, true)],
std::collections::HashMap::new(),
)
.unwrap(),
),
}));
let plan = LogicalPlan::Projection(Projection::try_new(
vec![expr.clone().or(expr)],
@@ -187,8 +215,11 @@
let rule = TypeCoercion::new();
let mut config = OptimizerConfig::default();
let plan = rule.optimize(&plan, &mut config)?;
assert_eq!("Projection: Float64(1.2) < CAST(UInt32(2) AS Float64) OR Float64(1.2) < CAST(UInt32(2) AS Float64)\
\n EmptyRelation", &format!("{:?}", plan));
assert_eq!(
"Projection: #a < Float64(2) OR #a < Float64(2)\
\n EmptyRelation",
&format!("{:?}", plan)
);
Ok(())
}

@@ -197,7 +228,11 @@
let empty = empty();
let return_type: ReturnTypeFunction =
Arc::new(move |_| Ok(Arc::new(DataType::Utf8)));
let fun: ScalarFunctionImplementation = Arc::new(move |_| unimplemented!());
let fun: ScalarFunctionImplementation = Arc::new(move |_| {
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(
"a".to_string(),
))))
});
let udf = Expr::ScalarUDF {
fun: Arc::new(ScalarUDF::new(
"TestScalarUDF",
@@ -212,7 +247,7 @@
let mut config = OptimizerConfig::default();
let plan = rule.optimize(&plan, &mut config)?;
assert_eq!(
"Projection: TestScalarUDF(CAST(Int32(123) AS Float32))\n EmptyRelation",
"Projection: Utf8(\"a\")\n EmptyRelation",
&format!("{:?}", plan)
);
Ok(())