From c861f4099626a8b39ab3b36201bff9c214bfe832 Mon Sep 17 00:00:00 2001 From: hhj Date: Thu, 11 Jan 2024 21:16:33 +0800 Subject: [PATCH 1/3] fix: don't extract common sub expr in CASE WHEN clause --- .../optimizer/src/common_subexpr_eliminate.rs | 24 +++++++----- datafusion/optimizer/src/push_down_filter.rs | 39 +++---------------- datafusion/optimizer/src/utils.rs | 17 ++++++++ .../sqllogictest/test_files/functions.slt | 2 +- datafusion/sqllogictest/test_files/select.slt | 19 +++++++++ 5 files changed, 56 insertions(+), 45 deletions(-) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 000329d0d078..fc867df23c36 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -20,6 +20,7 @@ use std::collections::{BTreeSet, HashMap}; use std::sync::Arc; +use crate::utils::is_volatile_expression; use crate::{utils, OptimizerConfig, OptimizerRule}; use arrow::datatypes::DataType; @@ -29,7 +30,7 @@ use datafusion_common::tree_node::{ use datafusion_common::{ internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, }; -use datafusion_expr::expr::{is_volatile, Alias}; +use datafusion_expr::expr::Alias; use datafusion_expr::logical_plan::{ Aggregate, Filter, LogicalPlan, Projection, Sort, Window, }; @@ -518,7 +519,7 @@ enum ExprMask { } impl ExprMask { - fn ignores(&self, expr: &Expr) -> Result<bool> { + fn ignores(&self, expr: &Expr) -> bool { let is_normal_minus_aggregates = matches!( expr, Expr::Literal(..) @@ -529,14 +530,12 @@ impl ExprMask { | Expr::Wildcard { ..
} ); - let is_volatile = is_volatile(expr)?; - let is_aggr = matches!(expr, Expr::AggregateFunction(..)); - Ok(match self { - Self::Normal => is_volatile || is_normal_minus_aggregates || is_aggr, - Self::NormalAndAggregates => is_volatile || is_normal_minus_aggregates, - }) + match self { + Self::Normal => is_normal_minus_aggregates || is_aggr, + Self::NormalAndAggregates => is_normal_minus_aggregates, + } } } @@ -614,7 +613,12 @@ impl ExprIdentifierVisitor<'_> { impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { type N = Expr; - fn pre_visit(&mut self, _expr: &Expr) -> Result<VisitRecursion> { + fn pre_visit(&mut self, expr: &Expr) -> Result<VisitRecursion> { + // related to https://github.com/apache/arrow-datafusion/issues/8814 + // If the expr contains a volatile expression or is a case expression, skip it. + if matches!(expr, Expr::Case(..)) || is_volatile_expression(expr)? { + return Ok(VisitRecursion::Skip); + } self.visit_stack .push(VisitRecord::EnterMark(self.node_count)); self.node_count += 1; @@ -628,7 +632,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { let (idx, sub_expr_desc) = self.pop_enter_mark(); // skip exprs should not be recognize. - if self.expr_mask.ignores(expr)?
{ + if self.expr_mask.ignores(expr) { self.id_array[idx].0 = self.series_number; let desc = Self::desc_expr(expr); self.visit_stack.push(VisitRecord::ExprItem(desc)); diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 4eb925ac0629..7086c5cda56f 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -19,6 +19,7 @@ use std::collections::{HashMap, HashSet}; use std::sync::Arc; use crate::optimizer::ApplyOrder; +use crate::utils::is_volatile_expression; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; @@ -34,7 +35,7 @@ use datafusion_expr::logical_plan::{ use datafusion_expr::utils::{conjunction, split_conjunction, split_conjunction_owned}; use datafusion_expr::{ and, build_join_schema, or, BinaryExpr, Expr, Filter, LogicalPlanBuilder, Operator, - ScalarFunctionDefinition, TableProviderFilterPushDown, Volatility, + ScalarFunctionDefinition, TableProviderFilterPushDown, }; use itertools::Itertools; @@ -739,7 +740,9 @@ impl OptimizerRule for PushDownFilter { (field.qualified_name(), expr) }) - .partition(|(_, value)| is_volatile_expression(value)); + .partition(|(_, value)| { + is_volatile_expression(value).unwrap_or(true) + }); let mut push_predicates = vec![]; let mut keep_predicates = vec![]; @@ -1028,38 +1031,6 @@ pub fn replace_cols_by_name( }) } -/// check whether the expression is volatile predicates -fn is_volatile_expression(e: &Expr) -> bool { - let mut is_volatile = false; - e.apply(&mut |expr| { - Ok(match expr { - Expr::ScalarFunction(f) => match &f.func_def { - ScalarFunctionDefinition::BuiltIn(fun) - if fun.volatility() == Volatility::Volatile => - { - is_volatile = true; - VisitRecursion::Stop - } - ScalarFunctionDefinition::UDF(fun) - if fun.signature().volatility == Volatility::Volatile => - { - is_volatile = true; - VisitRecursion::Stop - } - 
ScalarFunctionDefinition::Name(_) => { - return internal_err!( - "Function `Expr` with name should be resolved." - ); - } - _ => VisitRecursion::Continue, - }, - _ => VisitRecursion::Continue, - }) - }) - .unwrap(); - is_volatile -} - /// check whether the expression uses the columns in `check_map`. fn contain(e: &Expr, check_map: &HashMap) -> bool { let mut is_contain = false; diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 44f2404afade..7c2d210adcfe 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -18,8 +18,10 @@ //! Collection of utility functions that are leveraged by the query optimizer rules use crate::{OptimizerConfig, OptimizerRule}; +use datafusion_common::tree_node::{TreeNode, VisitRecursion}; use datafusion_common::{Column, DFSchemaRef}; use datafusion_common::{DFSchema, Result}; +use datafusion_expr::expr::is_volatile; use datafusion_expr::expr_rewriter::replace_col; use datafusion_expr::utils as expr_utils; use datafusion_expr::{logical_plan::LogicalPlan, Expr, Operator}; @@ -92,6 +94,21 @@ pub fn log_plan(description: &str, plan: &LogicalPlan) { trace!("{description}::\n{}\n", plan.display_indent_schema()); } +/// check whether the expression contains a volatile expression +pub(crate) fn is_volatile_expression(e: &Expr) -> Result<bool> { + let mut is_volatile_expr = false; + e.apply(&mut |expr| { + Ok(if is_volatile(expr)? { + is_volatile_expr = true; + VisitRecursion::Stop + } else { + VisitRecursion::Continue + }) + }) + .unwrap(); + Ok(is_volatile_expr) +} + /// Splits a conjunctive [`Expr`] such as `A AND B AND C` => `[A, B, C]` /// /// See [`split_conjunction_owned`] for more details and an example.
diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 1903088b0748..7bd60a3a154b 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -998,6 +998,6 @@ NULL # Verify that multiple calls to volatile functions like `random()` are not combined / optimized away query B -SELECT r FROM (SELECT r1 == r2 r, r1, r2 FROM (SELECT random() r1, random() r2) WHERE r1 > 0 AND r2 > 0) +SELECT r FROM (SELECT r1 == r2 r, r1, r2 FROM (SELECT random()+1 r1, random()+1 r2) WHERE r1 > 0 AND r2 > 0) ---- false diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index 132bcdd246fe..ca48c07b0914 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -1112,3 +1112,22 @@ SELECT abs(x), abs(x) + abs(y) FROM t; statement ok DROP TABLE t; + +# related to https://github.com/apache/arrow-datafusion/issues/8814 +statement ok +create table t(x int, y int) as values (1,1), (2,2), (3,3), (0,0), (4,0); + +query II +SELECT +CASE WHEN B.x > 0 THEN A.x / B.x ELSE 0 END AS value1, +CASE WHEN B.x > 0 AND B.y > 0 THEN A.x / B.x ELSE 0 END AS value3 +FROM t AS A, (SELECT * FROM t WHERE x = 0) AS B; +---- +0 0 +0 0 +0 0 +0 0 +0 0 + +statement ok +DROP TABLE t; From e6377c57e56c59c01c5992dfb3d025b646610862 Mon Sep 17 00:00:00 2001 From: hhj Date: Thu, 11 Jan 2024 22:02:25 +0800 Subject: [PATCH 2/3] fix ci --- datafusion/sqllogictest/test_files/tpch/q14.slt.part | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/sqllogictest/test_files/tpch/q14.slt.part b/datafusion/sqllogictest/test_files/tpch/q14.slt.part index b584972c25bc..7e614ab49e38 100644 --- a/datafusion/sqllogictest/test_files/tpch/q14.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/q14.slt.part @@ -33,8 +33,8 @@ where ---- logical_plan Projection: Float64(100) * 
CAST(SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END) AS Float64) / CAST(SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS Float64) AS promo_revenue ---Aggregate: groupBy=[[]], aggr=[[SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice AS lineitem.l_extendedprice * Decimal128(Some(1),20,0) - lineitem.l_discount ELSE Decimal128(Some(0),38,4) END) AS SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), SUM(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice AS lineitem.l_extendedprice * Decimal128(Some(1),20,0) - lineitem.l_discount) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] -----Projection: lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) AS lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice, part.p_type +--Aggregate: groupBy=[[]], aggr=[[SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount) ELSE Decimal128(Some(0),38,4) END) AS SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), SUM(lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] +----Projection: lineitem.l_extendedprice, lineitem.l_discount, part.p_type ------Inner Join: lineitem.l_partkey = 
part.p_partkey --------Projection: lineitem.l_partkey, lineitem.l_extendedprice, lineitem.l_discount ----------Filter: lineitem.l_shipdate >= Date32("9374") AND lineitem.l_shipdate < Date32("9404") @@ -45,7 +45,7 @@ ProjectionExec: expr=[100 * CAST(SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") --AggregateExec: mode=Final, gby=[], aggr=[SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] ----CoalescePartitionsExec ------AggregateExec: mode=Partial, gby=[], aggr=[SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END), SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)] ---------ProjectionExec: expr=[l_extendedprice@1 * (Some(1),20,0 - l_discount@2) as lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discount)Decimal128(Some(1),20,0) - lineitem.l_discountlineitem.l_discountDecimal128(Some(1),20,0)lineitem.l_extendedprice, p_type@4 as p_type] +--------ProjectionExec: expr=[l_extendedprice@1 as l_extendedprice, l_discount@2 as l_discount, p_type@4 as p_type] ----------CoalesceBatchesExec: target_batch_size=8192 ------------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)] --------------CoalesceBatchesExec: target_batch_size=8192 From fd3e5ffe317121578766932734eeb36fabe3b629 Mon Sep 17 00:00:00 2001 From: hhj Date: Mon, 15 Jan 2024 19:36:40 +0800 Subject: [PATCH 3/3] fix --- datafusion/optimizer/src/utils.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 7c2d210adcfe..5671dc6ae94d 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -104,8 +104,7 @@ pub(crate) fn is_volatile_expression(e: &Expr) -> Result<bool> { } else { VisitRecursion::Continue }) - }) - .unwrap(); + })?;
Ok(is_volatile_expr) }