From 212c489afcaf37f60e2fb100d98e4e7873bde0ae Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Sat, 13 Jan 2024 10:01:26 +0100 Subject: [PATCH 01/40] refactor `TreeNode::rewrite()` --- datafusion-examples/examples/rewrite_expr.rs | 6 +- datafusion/common/src/tree_node.rs | 204 ++++++++++-------- .../core/src/datasource/listing/helpers.rs | 18 +- .../physical_plan/parquet/row_filter.rs | 17 +- datafusion/core/src/execution/context/mod.rs | 6 +- .../combine_partial_final_agg.rs | 2 +- .../physical_optimizer/projection_pushdown.rs | 4 +- .../core/src/physical_optimizer/pruning.rs | 2 +- datafusion/expr/src/expr.rs | 2 +- datafusion/expr/src/expr_rewriter/mod.rs | 38 ++-- datafusion/expr/src/expr_rewriter/order_by.rs | 2 +- datafusion/expr/src/logical_plan/display.rs | 18 +- datafusion/expr/src/logical_plan/plan.rs | 79 ++++--- datafusion/expr/src/tree_node/expr.rs | 14 +- datafusion/expr/src/tree_node/plan.rs | 14 +- datafusion/expr/src/utils.rs | 10 +- .../src/analyzer/count_wildcard_rule.rs | 4 +- .../src/analyzer/inline_table_scan.rs | 2 +- datafusion/optimizer/src/analyzer/mod.rs | 4 +- .../optimizer/src/analyzer/rewrite_expr.rs | 4 +- datafusion/optimizer/src/analyzer/subquery.rs | 16 +- .../optimizer/src/analyzer/type_coercion.rs | 10 +- .../optimizer/src/common_subexpr_eliminate.rs | 95 ++++---- datafusion/optimizer/src/decorrelate.rs | 22 +- datafusion/optimizer/src/plan_signature.rs | 4 +- datafusion/optimizer/src/push_down_filter.rs | 14 +- .../optimizer/src/scalar_subquery_to_join.rs | 26 +-- .../simplify_expressions/expr_simplifier.rs | 21 +- .../src/simplify_expressions/guarantees.rs | 4 +- .../simplify_expressions/inlist_simplifier.rs | 4 +- .../or_in_list_simplifier.rs | 4 +- .../src/unwrap_cast_in_comparison.rs | 10 +- datafusion/optimizer/src/utils.rs | 6 +- .../physical-expr/src/equivalence/class.rs | 2 +- .../physical-expr/src/expressions/case.rs | 2 +- datafusion/physical-expr/src/utils/mod.rs | 26 +-- .../library-user-guide/working-with-exprs.md | 
2 +- 37 files changed, 361 insertions(+), 357 deletions(-) diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs index 5e95562033e6..9dfc238ab9e8 100644 --- a/datafusion-examples/examples/rewrite_expr.rs +++ b/datafusion-examples/examples/rewrite_expr.rs @@ -91,7 +91,7 @@ impl AnalyzerRule for MyAnalyzerRule { impl MyAnalyzerRule { fn analyze_plan(plan: LogicalPlan) -> Result { - plan.transform(&|plan| { + plan.transform_up(&|plan| { Ok(match plan { LogicalPlan::Filter(filter) => { let predicate = Self::analyze_expr(filter.predicate.clone())?; @@ -106,7 +106,7 @@ impl MyAnalyzerRule { } fn analyze_expr(expr: Expr) -> Result { - expr.transform(&|expr| { + expr.transform_up(&|expr| { // closure is invoked for all sub expressions Ok(match expr { Expr::Literal(ScalarValue::Int64(i)) => { @@ -161,7 +161,7 @@ impl OptimizerRule for MyOptimizerRule { /// use rewrite_expr to modify the expression tree. fn my_rewrite(expr: Expr) -> Result { - expr.transform(&|expr| { + expr.transform_up(&|expr| { // closure is invoked for all sub expressions Ok(match expr { Expr::Between(Between { diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index c5c4ee824d61..a451fe77088d 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -22,21 +22,21 @@ use std::sync::Arc; use crate::Result; -/// If the function returns [`VisitRecursion::Continue`], the normal execution of the -/// function continues. If it returns [`VisitRecursion::Skip`], the function returns -/// with [`VisitRecursion::Continue`] to jump next recursion step, bypassing further -/// exploration of the current step. In case of [`VisitRecursion::Stop`], the function -/// return with [`VisitRecursion::Stop`] and recursion halts. +/// If the function returns [`TreeNodeRecursion::Continue`], the normal execution of the +/// function continues. 
If it returns [`TreeNodeRecursion::Skip`], the function returns +/// with [`TreeNodeRecursion::Continue`] to jump next recursion step, bypassing further +/// exploration of the current step. In case of [`TreeNodeRecursion::Stop`], the function +/// return with [`TreeNodeRecursion::Stop`] and recursion halts. #[macro_export] macro_rules! handle_tree_recursion { ($EXPR:expr) => { match $EXPR { - VisitRecursion::Continue => {} + TreeNodeRecursion::Continue => {} // If the recursion should skip, do not apply to its children, let // the recursion continue: - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), + TreeNodeRecursion::Skip => return Ok(TreeNodeRecursion::Continue), // If the recursion should stop, do not apply to its children: - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop), } }; } @@ -53,15 +53,15 @@ macro_rules! handle_tree_recursion { /// [`Expr`]: https://docs.rs/datafusion-expr/latest/datafusion_expr/expr/enum.Expr.html pub trait TreeNode: Sized { /// Applies `op` to the node and its children. `op` is applied in a preoder way, - /// and it is controlled by [`VisitRecursion`], which means result of the `op` + /// and it is controlled by [`TreeNodeRecursion`], which means result of the `op` /// on the self node can cause an early return. /// /// The `op` closure can be used to collect some info from the /// tree node or do some checking for the tree node. 
- fn apply Result>( + fn apply Result>( &self, op: &mut F, - ) -> Result { + ) -> Result { handle_tree_recursion!(op(self)?); self.apply_children(&mut |node| node.apply(op)) } @@ -88,7 +88,7 @@ pub trait TreeNode: Sized { /// /// If an Err result is returned, recursion is stopped immediately /// - /// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no + /// If [`TreeNodeRecursion::Stop`] is returned on a call to pre_visit, no /// children of that node will be visited, nor is post_visit /// called on that node. Details see [`TreeNodeVisitor`] /// @@ -97,20 +97,53 @@ pub trait TreeNode: Sized { fn visit>( &self, visitor: &mut V, - ) -> Result { + ) -> Result { handle_tree_recursion!(visitor.pre_visit(self)?); handle_tree_recursion!(self.apply_children(&mut |node| node.visit(visitor))?); visitor.post_visit(self) } - /// Convenience utils for writing optimizers rule: recursively apply the given `op` to the node tree. - /// When `op` does not apply to a given node, it is left unchanged. - /// The default tree traversal direction is transform_up(Postorder Traversal). - fn transform(self, op: &F) -> Result + /// Transforms the tree using `f_down` while traversing the tree top-down + /// (pre-order) and using `f_up` while traversing the tree bottom-up (post-order). + /// + /// E.g. for a tree such as: + /// ```text + /// ParentNode + /// left: ChildNode1 + /// right: ChildNode2 + /// ``` + /// + /// The nodes are visited using the following order: + /// ```text + /// f_down(ParentNode) + /// f_down(ChildNode1) + /// f_up(ChildNode1) + /// f_down(ChildNode2) + /// f_up(ChildNode2) + /// f_up(ParentNode) + /// ``` + /// + /// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled. + /// + /// If `f_down` or `f_up` returns [`Err`], recursion is stopped immediately. 
+ fn transform(self, f_down: &mut FD, f_up: &mut FU) -> Result where - F: Fn(Self) -> Result>, + FD: FnMut(Self) -> Result<(Transformed, TreeNodeRecursion)>, + FU: FnMut(Self) -> Result, { - self.transform_up(op) + let (new_node, tnr) = f_down(self).map(|(t, tnr)| (t.into(), tnr))?; + match tnr { + TreeNodeRecursion::Continue => {} + // If the recursion should skip, do not apply to its children. And let the recursion continue + TreeNodeRecursion::Skip => return Ok(new_node), + // If the recursion should stop, do not apply to its children + TreeNodeRecursion::Stop => { + panic!("Stop can't be used in TreeNode::transform()") + } + } + let node_with_new_children = + new_node.map_children(|node| node.transform(f_down, f_up))?; + f_up(node_with_new_children) } /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its @@ -159,56 +192,50 @@ pub trait TreeNode: Sized { Ok(new_node) } - /// Transform the tree node using the given [TreeNodeRewriter] - /// It performs a depth first walk of an node and its children. + /// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for + /// recursively transforming [`TreeNode`]s. /// - /// For an node tree such as + /// E.g. 
for a tree such as: /// ```text /// ParentNode /// left: ChildNode1 /// right: ChildNode2 /// ``` /// - /// The nodes are visited using the following order + /// The nodes are visited using the following order: /// ```text - /// pre_visit(ParentNode) - /// pre_visit(ChildNode1) - /// mutate(ChildNode1) - /// pre_visit(ChildNode2) - /// mutate(ChildNode2) - /// mutate(ParentNode) + /// TreeNodeRewriter::f_down(ParentNode) + /// TreeNodeRewriter::f_down(ChildNode1) + /// TreeNodeRewriter::f_up(ChildNode1) + /// TreeNodeRewriter::f_down(ChildNode2) + /// TreeNodeRewriter::f_up(ChildNode2) + /// TreeNodeRewriter::f_up(ParentNode) /// ``` /// - /// If an Err result is returned, recursion is stopped immediately - /// - /// If [`false`] is returned on a call to pre_visit, no - /// children of that node will be visited, nor is mutate - /// called on that node + /// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled. /// - /// If using the default [`TreeNodeRewriter::pre_visit`] which - /// returns `true`, [`Self::transform`] should be preferred. - fn rewrite>(self, rewriter: &mut R) -> Result { - let need_mutate = match rewriter.pre_visit(&self)? { - RewriteRecursion::Mutate => return rewriter.mutate(self), - RewriteRecursion::Stop => return Ok(self), - RewriteRecursion::Continue => true, - RewriteRecursion::Skip => false, - }; - - let after_op_children = self.map_children(|node| node.rewrite(rewriter))?; - - // now rewrite this node itself - if need_mutate { - rewriter.mutate(after_op_children) - } else { - Ok(after_op_children) + /// If [`TreeNodeRewriter::f_down()`] or [`TreeNodeRewriter::f_up()`] returns [`Err`], + /// recursion is stopped immediately. + fn rewrite>(self, rewriter: &mut R) -> Result { + let (new_node, tnr) = rewriter.f_down(self)?; + match tnr { + TreeNodeRecursion::Continue => {} + // If the recursion should skip, do not apply to its children. 
And let the recursion continue + TreeNodeRecursion::Skip => return Ok(new_node), + // If the recursion should stop, do not apply to its children + TreeNodeRecursion::Stop => { + panic!("Stop can't be used in TreeNode::rewrite()") + } } + let node_with_new_children = + new_node.map_children(|node| node.rewrite(rewriter))?; + rewriter.f_up(node_with_new_children) } /// Apply the closure `F` to the node's children - fn apply_children(&self, op: &mut F) -> Result + fn apply_children(&self, op: &mut F) -> Result where - F: FnMut(&Self) -> Result; + F: FnMut(&Self) -> Result; /// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder) fn map_children(self, transform: F) -> Result @@ -231,69 +258,58 @@ pub trait TreeNode: Sized { /// If an [`Err`] result is returned, recursion is stopped /// immediately. /// -/// If [`VisitRecursion::Stop`] is returned on a call to pre_visit, no +/// If [`TreeNodeRecursion::Stop`] is returned on a call to pre_visit, no /// children of that tree node are visited, nor is post_visit /// called on that tree node /// -/// If [`VisitRecursion::Stop`] is returned on a call to post_visit, no +/// If [`TreeNodeRecursion::Stop`] is returned on a call to post_visit, no /// siblings of that tree node are visited, nor is post_visit /// called on its parent tree node /// -/// If [`VisitRecursion::Skip`] is returned on a call to pre_visit, no +/// If [`TreeNodeRecursion::Skip`] is returned on a call to pre_visit, no /// children of that tree node are visited. pub trait TreeNodeVisitor: Sized { /// The node type which is visitable. type N: TreeNode; /// Invoked before any children of `node` are visited. - fn pre_visit(&mut self, node: &Self::N) -> Result; + fn pre_visit(&mut self, node: &Self::N) -> Result; /// Invoked after all children of `node` are visited. Default /// implementation does nothing. 
- fn post_visit(&mut self, _node: &Self::N) -> Result { - Ok(VisitRecursion::Continue) + fn post_visit(&mut self, _node: &Self::N) -> Result { + Ok(TreeNodeRecursion::Continue) } } -/// Trait for potentially recursively transform an [`TreeNode`] node -/// tree. When passed to `TreeNode::rewrite`, `TreeNodeRewriter::mutate` is -/// invoked recursively on all nodes of a tree. +/// Trait for potentially recursively transform a [`TreeNode`] node tree. pub trait TreeNodeRewriter: Sized { /// The node type which is rewritable. - type N: TreeNode; + type Node: TreeNode; - /// Invoked before (Preorder) any children of `node` are rewritten / - /// visited. Default implementation returns `Ok(Recursion::Continue)` - fn pre_visit(&mut self, _node: &Self::N) -> Result { - Ok(RewriteRecursion::Continue) + /// Invoked while traversing down the tree before any children are rewritten / + /// visited. + /// Default implementation returns the node unmodified and continues recursion. + fn f_down(&mut self, node: Self::Node) -> Result<(Self::Node, TreeNodeRecursion)> { + Ok((node, TreeNodeRecursion::Continue)) } - /// Invoked after (Postorder) all children of `node` have been mutated and - /// returns a potentially modified node. - fn mutate(&mut self, node: Self::N) -> Result; -} - -/// Controls how the [`TreeNode`] recursion should proceed for [`TreeNode::rewrite`]. -#[derive(Debug)] -pub enum RewriteRecursion { - /// Continue rewrite this node tree. - Continue, - /// Call 'op' immediately and return. - Mutate, - /// Do not rewrite the children of this node. - Stop, - /// Keep recursive but skip apply op on this node - Skip, + /// Invoked while traversing up the tree after all children have been rewritten / + /// visited. + /// Default implementation returns the node unmodified. + fn f_up(&mut self, node: Self::Node) -> Result { + Ok(node) + } } -/// Controls how the [`TreeNode`] recursion should proceed for [`TreeNode::visit`]. 
+/// Controls how [`TreeNode`] recursions should proceed. #[derive(Debug)] -pub enum VisitRecursion { - /// Continue the visit to this node tree. +pub enum TreeNodeRecursion { + /// Continue recursion with the next node. Continue, - /// Keep recursive but skip applying op on the children + /// Skip the current subtree. Skip, - /// Stop the visit to this node tree. + /// Stop recursion. Stop, } @@ -340,14 +356,14 @@ pub trait DynTreeNode { /// [`DynTreeNode`] (such as [`Arc`]) impl TreeNode for Arc { /// Apply the closure `F` to the node's children - fn apply_children(&self, op: &mut F) -> Result + fn apply_children(&self, op: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { for child in self.arc_children() { handle_tree_recursion!(op(&child)?) } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } fn map_children(self, transform: F) -> Result @@ -382,14 +398,14 @@ pub trait ConcreteTreeNode: Sized { impl TreeNode for T { /// Apply the closure `F` to the node's children - fn apply_children(&self, op: &mut F) -> Result + fn apply_children(&self, op: &mut F) -> Result where - F: FnMut(&Self) -> Result, + F: FnMut(&Self) -> Result, { for child in self.children() { handle_tree_recursion!(op(child)?) 
} - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } fn map_children(self, transform: F) -> Result diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index a03bcec7abec..96864672573b 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -37,7 +37,7 @@ use crate::{error::Result, scalar::ScalarValue}; use super::PartitionedFile; use crate::datasource::listing::ListingTableUrl; use crate::execution::context::SessionState; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{internal_err, Column, DFField, DFSchema, DataFusionError}; use datafusion_expr::{Expr, ScalarFunctionDefinition, Volatility}; use datafusion_physical_expr::create_physical_expr; @@ -57,9 +57,9 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { Expr::Column(Column { ref name, .. }) => { is_applicable &= col_names.contains(name); if is_applicable { - Ok(VisitRecursion::Skip) + Ok(TreeNodeRecursion::Skip) } else { - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } } Expr::Literal(_) @@ -88,27 +88,27 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { | Expr::ScalarSubquery(_) | Expr::GetIndexedField { .. } | Expr::GroupingSet(_) - | Expr::Case { .. } => Ok(VisitRecursion::Continue), + | Expr::Case { .. 
} => Ok(TreeNodeRecursion::Continue), Expr::ScalarFunction(scalar_function) => { match &scalar_function.func_def { ScalarFunctionDefinition::BuiltIn(fun) => { match fun.volatility() { - Volatility::Immutable => Ok(VisitRecursion::Continue), + Volatility::Immutable => Ok(TreeNodeRecursion::Continue), // TODO: Stable functions could be `applicable`, but that would require access to the context Volatility::Stable | Volatility::Volatile => { is_applicable = false; - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } } } ScalarFunctionDefinition::UDF(fun) => { match fun.signature().volatility { - Volatility::Immutable => Ok(VisitRecursion::Continue), + Volatility::Immutable => Ok(TreeNodeRecursion::Continue), // TODO: Stable functions could be `applicable`, but that would require access to the context Volatility::Stable | Volatility::Volatile => { is_applicable = false; - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } } } @@ -128,7 +128,7 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { | Expr::Wildcard { .. 
} | Expr::Placeholder(_) => { is_applicable = false; - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } } }) diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs index 3c40509a86d2..ddfeb146b876 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs @@ -20,7 +20,7 @@ use arrow::datatypes::{DataType, Schema}; use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; use datafusion_common::cast::as_boolean_array; -use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeRewriter}; use datafusion_common::{arrow_err, DataFusionError, Result, ScalarValue}; use datafusion_physical_expr::expressions::{Column, Literal}; use datafusion_physical_expr::utils::reassign_predicate_columns; @@ -209,29 +209,32 @@ impl<'a> FilterCandidateBuilder<'a> { } impl<'a> TreeNodeRewriter for FilterCandidateBuilder<'a> { - type N = Arc; + type Node = Arc; - fn pre_visit(&mut self, node: &Arc) -> Result { + fn f_down( + &mut self, + node: Arc, + ) -> Result<(Arc, TreeNodeRecursion)> { if let Some(column) = node.as_any().downcast_ref::() { if let Ok(idx) = self.file_schema.index_of(column.name()) { self.required_column_indices.insert(idx); if DataType::is_nested(self.file_schema.field(idx).data_type()) { self.non_primitive_columns = true; - return Ok(RewriteRecursion::Stop); + return Ok((node, TreeNodeRecursion::Skip)); } } else if self.table_schema.index_of(column.name()).is_err() { // If the column does not exist in the (un-projected) table schema then // it must be a projected column. 
self.projected_columns = true; - return Ok(RewriteRecursion::Stop); + return Ok((node, TreeNodeRecursion::Skip)); } } - Ok(RewriteRecursion::Continue) + Ok((node, TreeNodeRecursion::Continue)) } - fn mutate(&mut self, expr: Arc) -> Result> { + fn f_up(&mut self, expr: Arc) -> Result> { if let Some(column) = expr.as_any().downcast_ref::() { if self.file_schema.field_with_name(column.name()).is_err() { // the column expr must be in the table schema diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index b5ad6174821b..4f57d873cbdf 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -39,7 +39,7 @@ use crate::{ use datafusion_common::{ alias::AliasGenerator, exec_err, not_impl_err, plan_datafusion_err, plan_err, - tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion}, + tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}, }; use datafusion_execution::registry::SerializerRegistry; use datafusion_expr::{ @@ -2108,7 +2108,7 @@ impl<'a> BadPlanVisitor<'a> { impl<'a> TreeNodeVisitor for BadPlanVisitor<'a> { type N = LogicalPlan; - fn pre_visit(&mut self, node: &Self::N) -> Result { + fn pre_visit(&mut self, node: &Self::N) -> Result { match node { LogicalPlan::Ddl(ddl) if !self.options.allow_ddl => { plan_err!("DDL not supported: {}", ddl.name()) @@ -2122,7 +2122,7 @@ impl<'a> TreeNodeVisitor for BadPlanVisitor<'a> { LogicalPlan::Statement(stmt) if !self.options.allow_statements => { plan_err!("Statement not supported: {}", stmt.name()) } - _ => Ok(VisitRecursion::Continue), + _ => Ok(TreeNodeRecursion::Continue), } } } diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index 61eb2381c63b..b26d9763e53a 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ 
-178,7 +178,7 @@ fn normalize_group_exprs(group_exprs: GroupExprsRef) -> GroupExprs { fn discard_column_index(group_expr: Arc) -> Arc { group_expr .clone() - .transform(&|expr| { + .transform_up(&|expr| { let normalized_form: Option> = match expr.as_any().downcast_ref::() { Some(column) => Some(Arc::new(Column::new(column.name(), 0))), diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index 2d20c487e473..64ef92faa865 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -43,7 +43,7 @@ use crate::physical_plan::{Distribution, ExecutionPlan}; use arrow_schema::SchemaRef; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_common::JoinSide; use datafusion_physical_expr::expressions::{Column, Literal}; use datafusion_physical_expr::{ @@ -270,7 +270,7 @@ fn try_unifying_projections( if let Some(column) = expr.as_any().downcast_ref::() { *column_ref_map.entry(column.clone()).or_default() += 1; } - VisitRecursion::Continue + TreeNodeRecursion::Continue }) }) .unwrap(); diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index aa0c26723767..aa72771b1eb3 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -837,7 +837,7 @@ fn rewrite_column_expr( column_old: &phys_expr::Column, column_new: &phys_expr::Column, ) -> Result> { - e.transform(&|expr| { + e.transform_up(&|expr| { if let Some(column) = expr.as_any().downcast_ref::() { if column == column_old { return Ok(Transformed::Yes(Arc::new(column_new.clone()))); diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 
c5d158d87638..e0eebf5c8c18 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1247,7 +1247,7 @@ impl Expr { /// For example, gicen an expression like ` = $0` will infer `$0` to /// have type `int32`. pub fn infer_placeholder_types(self, schema: &DFSchema) -> Result { - self.transform(&|mut expr| { + self.transform_up(&|mut expr| { // Default to assuming the arguments are the same type if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut expr { rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?; diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 1f04c80833f0..76bd51619954 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -33,7 +33,7 @@ pub use order_by::rewrite_sort_cols_by_aggs; /// Recursively call [`Column::normalize_with_schemas`] on all [`Column`] expressions /// in the `expr` expression tree. pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result { - expr.transform(&|expr| { + expr.transform_up(&|expr| { Ok({ if let Expr::Column(c) = expr { let col = LogicalPlanBuilder::normalize(plan, c)?; @@ -57,7 +57,7 @@ pub fn normalize_col_with_schemas( schemas: &[&Arc], using_columns: &[HashSet], ) -> Result { - expr.transform(&|expr| { + expr.transform_up(&|expr| { Ok({ if let Expr::Column(c) = expr { let col = c.normalize_with_schemas(schemas, using_columns)?; @@ -75,7 +75,7 @@ pub fn normalize_col_with_schemas_and_ambiguity_check( schemas: &[&[&DFSchema]], using_columns: &[HashSet], ) -> Result { - expr.transform(&|expr| { + expr.transform_up(&|expr| { Ok({ if let Expr::Column(c) = expr { let col = @@ -102,7 +102,7 @@ pub fn normalize_cols( /// Recursively replace all [`Column`] expressions in a given expression tree with /// `Column` expressions provided by the hash map argument. 
pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Result { - expr.transform(&|expr| { + expr.transform_up(&|expr| { Ok({ if let Expr::Column(c) = &expr { match replace_map.get(c) { @@ -122,7 +122,7 @@ pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Resul /// For example, if there were expressions like `foo.bar` this would /// rewrite it to just `bar`. pub fn unnormalize_col(expr: Expr) -> Expr { - expr.transform(&|expr| { + expr.transform_up(&|expr| { Ok({ if let Expr::Column(c) = expr { let col = Column { @@ -164,7 +164,7 @@ pub fn unnormalize_cols(exprs: impl IntoIterator) -> Vec { /// Recursively remove all the ['OuterReferenceColumn'] and return the inside Column /// in the expression tree. pub fn strip_outer_reference(expr: Expr) -> Expr { - expr.transform(&|expr| { + expr.transform_up(&|expr| { Ok({ if let Expr::OuterReferenceColumn(_, col) = expr { Transformed::Yes(Expr::Column(col)) @@ -250,7 +250,7 @@ pub fn unalias(expr: Expr) -> Expr { /// schema of plan nodes don't change after optimization pub fn rewrite_preserving_name(expr: Expr, rewriter: &mut R) -> Result where - R: TreeNodeRewriter, + R: TreeNodeRewriter, { let original_name = expr.name_for_alias()?; let expr = expr.rewrite(rewriter)?; @@ -263,7 +263,7 @@ mod test { use crate::expr::Sort; use crate::{col, lit, Cast}; use arrow::datatypes::DataType; - use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}; + use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeRewriter}; use datafusion_common::{DFField, DFSchema, ScalarValue}; use std::ops::Add; @@ -273,14 +273,14 @@ mod test { } impl TreeNodeRewriter for RecordingRewriter { - type N = Expr; + type Node = Expr; - fn pre_visit(&mut self, expr: &Expr) -> Result { + fn f_down(&mut self, expr: Expr) -> Result<(Expr, TreeNodeRecursion)> { self.v.push(format!("Previsited {expr}")); - Ok(RewriteRecursion::Continue) + Ok((expr, TreeNodeRecursion::Continue)) } - 
fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result { self.v.push(format!("Mutated {expr}")); Ok(expr) } @@ -305,11 +305,17 @@ mod test { }; // rewrites "foo" --> "bar" - let rewritten = col("state").eq(lit("foo")).transform(&transformer).unwrap(); + let rewritten = col("state") + .eq(lit("foo")) + .transform_up(&transformer) + .unwrap(); assert_eq!(rewritten, col("state").eq(lit("bar"))); // doesn't rewrite - let rewritten = col("state").eq(lit("baz")).transform(&transformer).unwrap(); + let rewritten = col("state") + .eq(lit("baz")) + .transform_up(&transformer) + .unwrap(); assert_eq!(rewritten, col("state").eq(lit("baz"))); } @@ -444,9 +450,9 @@ mod test { } impl TreeNodeRewriter for TestRewriter { - type N = Expr; + type Node = Expr; - fn mutate(&mut self, _: Expr) -> Result { + fn f_up(&mut self, _: Expr) -> Result { Ok(self.rewrite_to.clone()) } } diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index c87a724d5646..1e7efcafd04d 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -83,7 +83,7 @@ fn rewrite_in_terms_of_projection( ) -> Result { // assumption is that each item in exprs, such as "b + c" is // available as an output column named "b + c" - expr.transform(&|expr| { + expr.transform_up(&|expr| { // search for unnormalized names first such as "c1" (such as aliases) if let Some(found) = proj_exprs.iter().find(|a| (**a) == expr) { let col = Expr::Column( diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index 112dbf74dba1..ebef7791f8d8 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -19,7 +19,7 @@ use crate::LogicalPlan; use arrow::datatypes::Schema; use datafusion_common::display::GraphvizBuilder; -use datafusion_common::tree_node::{TreeNodeVisitor, VisitRecursion}; +use 
datafusion_common::tree_node::{TreeNodeRecursion, TreeNodeVisitor}; use datafusion_common::DataFusionError; use std::fmt; @@ -54,7 +54,7 @@ impl<'a, 'b> TreeNodeVisitor for IndentVisitor<'a, 'b> { fn pre_visit( &mut self, plan: &LogicalPlan, - ) -> datafusion_common::Result { + ) -> datafusion_common::Result { if self.indent > 0 { writeln!(self.f)?; } @@ -69,15 +69,15 @@ impl<'a, 'b> TreeNodeVisitor for IndentVisitor<'a, 'b> { } self.indent += 1; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } fn post_visit( &mut self, _plan: &LogicalPlan, - ) -> datafusion_common::Result { + ) -> datafusion_common::Result { self.indent -= 1; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } } @@ -176,7 +176,7 @@ impl<'a, 'b> TreeNodeVisitor for GraphvizVisitor<'a, 'b> { fn pre_visit( &mut self, plan: &LogicalPlan, - ) -> datafusion_common::Result { + ) -> datafusion_common::Result { let id = self.graphviz_builder.next_id(); // Create a new graph node for `plan` such as @@ -204,18 +204,18 @@ impl<'a, 'b> TreeNodeVisitor for GraphvizVisitor<'a, 'b> { } self.parent_ids.push(id); - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } fn post_visit( &mut self, _plan: &LogicalPlan, - ) -> datafusion_common::Result { + ) -> datafusion_common::Result { // always be non-empty as pre_visit always pushes // So it should always be Ok(true) let res = self.parent_ids.pop(); res.ok_or(DataFusionError::Internal("Fail to format".to_string())) - .map(|_| VisitRecursion::Continue) + .map(|_| TreeNodeRecursion::Continue) } } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index aee3a59dd2da..80ce38fe9389 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -45,8 +45,7 @@ use crate::{ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::tree_node::{ - RewriteRecursion, Transformed, TreeNode, TreeNodeRewriter, TreeNodeVisitor, 
- VisitRecursion, + Transformed, TreeNode, TreeNodeRecursion, TreeNodeVisitor, }; use datafusion_common::{ aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, @@ -475,7 +474,7 @@ impl LogicalPlan { })?; using_columns.push(columns); } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(using_columns) @@ -648,31 +647,29 @@ impl LogicalPlan { // Decimal128(Some(69999999999999),30,15) // AND lineitem.l_quantity < Decimal128(Some(2400),15,2) - struct RemoveAliases {} - - impl TreeNodeRewriter for RemoveAliases { - type N = Expr; - - fn pre_visit(&mut self, expr: &Expr) -> Result { - match expr { - Expr::Exists { .. } - | Expr::ScalarSubquery(_) - | Expr::InSubquery(_) => { - // subqueries could contain aliases so we don't recurse into those - Ok(RewriteRecursion::Stop) - } - Expr::Alias(_) => Ok(RewriteRecursion::Mutate), - _ => Ok(RewriteRecursion::Continue), + fn unalias_down( + expr: Expr, + ) -> Result<(Transformed, TreeNodeRecursion)> { + match expr { + Expr::Exists { .. 
} + | Expr::ScalarSubquery(_) + | Expr::InSubquery(_) => { + // subqueries could contain aliases so we don't recurse into those + Ok((Transformed::No(expr), TreeNodeRecursion::Skip)) } + Expr::Alias(_) => Ok(( + Transformed::Yes(expr.unalias()), + TreeNodeRecursion::Skip, + )), + _ => Ok((Transformed::No(expr), TreeNodeRecursion::Continue)), } + } - fn mutate(&mut self, expr: Expr) -> Result { - Ok(expr.unalias()) - } + fn dummy_up(expr: Expr) -> Result { + Ok(expr) } - let mut remove_aliases = RemoveAliases {}; - let predicate = predicate.rewrite(&mut remove_aliases)?; + let predicate = predicate.transform(&mut unalias_down, &mut dummy_up)?; Filter::try_new(predicate, Arc::new(inputs[0].clone())) .map(LogicalPlan::Filter) @@ -1124,9 +1121,9 @@ impl LogicalPlan { impl LogicalPlan { /// applies `op` to any subqueries in the plan - pub(crate) fn apply_subqueries(&self, op: &mut F) -> datafusion_common::Result<()> + pub(crate) fn apply_subqueries(&self, op: &mut F) -> Result<()> where - F: FnMut(&Self) -> datafusion_common::Result, + F: FnMut(&Self) -> Result, { self.inspect_expressions(|expr| { // recursively look for subqueries @@ -1150,7 +1147,7 @@ impl LogicalPlan { } /// applies visitor to any subqueries in the plan - pub(crate) fn visit_subqueries(&self, v: &mut V) -> datafusion_common::Result<()> + pub(crate) fn visit_subqueries(&self, v: &mut V) -> Result<()> where V: TreeNodeVisitor, { @@ -1225,11 +1222,11 @@ impl LogicalPlan { _ => {} } } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok::<(), DataFusionError>(()) })?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(param_types) @@ -1241,7 +1238,7 @@ impl LogicalPlan { expr: Expr, param_values: &ParamValues, ) -> Result { - expr.transform(&|expr| { + expr.transform_up(&|expr| { match &expr { Expr::Placeholder(Placeholder { id, .. 
}) => { let value = param_values.get_placeholders_with_values(id)?; @@ -2840,7 +2837,7 @@ digraph { impl TreeNodeVisitor for OkVisitor { type N = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { let s = match plan { LogicalPlan::Projection { .. } => "pre_visit Projection", LogicalPlan::Filter { .. } => "pre_visit Filter", @@ -2851,10 +2848,10 @@ digraph { }; self.strings.push(s.into()); - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } - fn post_visit(&mut self, plan: &LogicalPlan) -> Result { + fn post_visit(&mut self, plan: &LogicalPlan) -> Result { let s = match plan { LogicalPlan::Projection { .. } => "post_visit Projection", LogicalPlan::Filter { .. } => "post_visit Filter", @@ -2865,7 +2862,7 @@ digraph { }; self.strings.push(s.into()); - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } } @@ -2923,18 +2920,18 @@ digraph { impl TreeNodeVisitor for StoppingVisitor { type N = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { if self.return_false_from_pre_in.dec() { - return Ok(VisitRecursion::Stop); + return Ok(TreeNodeRecursion::Stop); } self.inner.pre_visit(plan)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } - fn post_visit(&mut self, plan: &LogicalPlan) -> Result { + fn post_visit(&mut self, plan: &LogicalPlan) -> Result { if self.return_false_from_post_in.dec() { - return Ok(VisitRecursion::Stop); + return Ok(TreeNodeRecursion::Stop); } self.inner.post_visit(plan) @@ -2992,7 +2989,7 @@ digraph { impl TreeNodeVisitor for ErrorVisitor { type N = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { if self.return_error_from_pre_in.dec() { return not_impl_err!("Error in pre_visit"); } @@ -3000,7 +2997,7 @@ digraph { self.inner.pre_visit(plan) } - fn post_visit(&mut self, plan: 
&LogicalPlan) -> Result { + fn post_visit(&mut self, plan: &LogicalPlan) -> Result { if self.return_error_from_post_in.dec() { return not_impl_err!("Error in post_visit"); } @@ -3306,7 +3303,7 @@ digraph { // after transformation, because plan is not the same anymore, // the parent plan is built again with call to LogicalPlan::with_new_inputs -> with_new_exprs let plan = plan - .transform(&|plan| match plan { + .transform_up(&|plan| match plan { LogicalPlan::TableScan(table) => { let filter = Filter::try_new( external_filter.clone(), diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 05464c96d05e..d937c11633f4 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -24,14 +24,14 @@ use crate::expr::{ }; use crate::{Expr, GetFieldAccess}; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{internal_err, DataFusionError, Result}; impl TreeNode for Expr { - fn apply_children Result>( + fn apply_children Result>( &self, op: &mut F, - ) -> Result { + ) -> Result { let children = match self { Expr::Alias(Alias{expr, .. }) | Expr::Not(expr) @@ -130,13 +130,13 @@ impl TreeNode for Expr { for child in children { match op(child)? 
{ - VisitRecursion::Continue => {} - VisitRecursion::Skip => return Ok(VisitRecursion::Continue), - VisitRecursion::Stop => return Ok(VisitRecursion::Stop), + TreeNodeRecursion::Continue => {} + TreeNodeRecursion::Skip => return Ok(TreeNodeRecursion::Continue), + TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop), } } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } fn map_children Result>( diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs index 589bb917a953..8be24638c1cc 100644 --- a/datafusion/expr/src/tree_node/plan.rs +++ b/datafusion/expr/src/tree_node/plan.rs @@ -19,14 +19,14 @@ use crate::LogicalPlan; -use datafusion_common::tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; use datafusion_common::{handle_tree_recursion, Result}; impl TreeNode for LogicalPlan { - fn apply Result>( + fn apply Result>( &self, op: &mut F, - ) -> Result { + ) -> Result { // Compared to the default implementation, we need to invoke // [`Self::apply_subqueries`] before visiting its children handle_tree_recursion!(op(self)?); @@ -57,7 +57,7 @@ impl TreeNode for LogicalPlan { fn visit>( &self, visitor: &mut V, - ) -> Result { + ) -> Result { // Compared to the default implementation, we need to invoke // [`Self::visit_subqueries`] before visiting its children handle_tree_recursion!(visitor.pre_visit(self)?); @@ -66,14 +66,14 @@ impl TreeNode for LogicalPlan { visitor.post_visit(self) } - fn apply_children Result>( + fn apply_children Result>( &self, op: &mut F, - ) -> Result { + ) -> Result { for child in self.inputs() { handle_tree_recursion!(op(child)?) 
} - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } fn map_children(self, transform: F) -> Result diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 02479c0765bd..88b6d34c48dc 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -31,7 +31,7 @@ use crate::{ }; use arrow::datatypes::{DataType, TimeUnit}; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::utils::get_at_indices; use datafusion_common::{ internal_err, plan_datafusion_err, plan_err, Column, DFField, DFSchema, DFSchemaRef, @@ -662,10 +662,10 @@ where exprs.push(expr.clone()) } // stop recursing down this expr once we find a match - return Ok(VisitRecursion::Skip); + return Ok(TreeNodeRecursion::Skip); } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) }) // pre_visit always returns OK, so this will always too .expect("no way to return error during recursion"); @@ -682,10 +682,10 @@ where if let Err(e) = f(expr) { // save the error for later (it may not be a DataFusionError err = Err(e); - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } else { // keep going - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } }) // The closure always returns OK, so this will always too diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 35a859783239..90046ca2aac0 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -115,9 +115,9 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { struct CountWildcardRewriter {} impl TreeNodeRewriter for CountWildcardRewriter { - type N = Expr; + type Node = Expr; - fn mutate(&mut self, old_expr: Expr) -> Result { + fn f_up(&mut self, old_expr: Expr) -> Result { let new_expr = match old_expr.clone() { 
Expr::WindowFunction(expr::WindowFunction { fun: diff --git a/datafusion/optimizer/src/analyzer/inline_table_scan.rs b/datafusion/optimizer/src/analyzer/inline_table_scan.rs index 90af7aec8293..a418fbf5537b 100644 --- a/datafusion/optimizer/src/analyzer/inline_table_scan.rs +++ b/datafusion/optimizer/src/analyzer/inline_table_scan.rs @@ -74,7 +74,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { Transformed::Yes(plan) } LogicalPlan::Filter(filter) => { - let new_expr = filter.predicate.transform(&rewrite_subquery)?; + let new_expr = filter.predicate.transform_up(&rewrite_subquery)?; Transformed::Yes(LogicalPlan::Filter(Filter::try_new( new_expr, filter.input, diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index 9d47299a5616..b416e1eb1863 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -28,7 +28,7 @@ use crate::analyzer::subquery::check_subquery_expr; use crate::analyzer::type_coercion::TypeCoercion; use crate::utils::log_plan; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::expr::Exists; use datafusion_expr::expr::InSubquery; @@ -136,7 +136,7 @@ fn check_plan(plan: &LogicalPlan) -> Result<()> { })?; } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) diff --git a/datafusion/optimizer/src/analyzer/rewrite_expr.rs b/datafusion/optimizer/src/analyzer/rewrite_expr.rs index 8f1c844ed062..829197b4d948 100644 --- a/datafusion/optimizer/src/analyzer/rewrite_expr.rs +++ b/datafusion/optimizer/src/analyzer/rewrite_expr.rs @@ -94,9 +94,9 @@ pub(crate) struct OperatorToFunctionRewriter { } impl TreeNodeRewriter for OperatorToFunctionRewriter { - type N = Expr; + type Node = Expr; - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut 
self, expr: Expr) -> Result { match expr { Expr::BinaryExpr(BinaryExpr { ref left, diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index 7c5b70b19af0..7ad9832dea54 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -17,7 +17,7 @@ use crate::analyzer::check_plan; use crate::utils::collect_subquery_cols; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{plan_err, DataFusionError, Result}; use datafusion_expr::expr_rewriter::strip_outer_reference; use datafusion_expr::utils::split_conjunction; @@ -146,7 +146,7 @@ fn check_inner_plan( LogicalPlan::Aggregate(_) => { inner_plan.apply_children(&mut |plan| { check_inner_plan(plan, is_scalar, true, can_contain_outer_ref)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -171,7 +171,7 @@ fn check_inner_plan( check_mixed_out_refer_in_window(window)?; inner_plan.apply_children(&mut |plan| { check_inner_plan(plan, is_scalar, is_aggregate, can_contain_outer_ref)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -188,7 +188,7 @@ fn check_inner_plan( | LogicalPlan::SubqueryAlias(_) => { inner_plan.apply_children(&mut |plan| { check_inner_plan(plan, is_scalar, is_aggregate, can_contain_outer_ref)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -206,7 +206,7 @@ fn check_inner_plan( is_aggregate, can_contain_outer_ref, )?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -221,7 +221,7 @@ fn check_inner_plan( JoinType::Full => { inner_plan.apply_children(&mut |plan| { check_inner_plan(plan, is_scalar, is_aggregate, false)?; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(()) } @@ -290,9 +290,9 @@ fn get_correlated_expressions(inner_plan: &LogicalPlan) -> 
Result> { correlated .into_iter() .for_each(|expr| exprs.push(strip_outer_reference(expr.clone()))); - return Ok(VisitRecursion::Continue); + return Ok(TreeNodeRecursion::Continue); } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) })?; Ok(exprs) } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index c0dad2ef4006..0f20ede0f239 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use arrow::datatypes::{DataType, IntervalUnit}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter}; +use datafusion_common::tree_node::TreeNodeRewriter; use datafusion_common::{ exec_err, internal_err, plan_datafusion_err, plan_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, @@ -126,13 +126,9 @@ pub(crate) struct TypeCoercionRewriter { } impl TreeNodeRewriter for TypeCoercionRewriter { - type N = Expr; + type Node = Expr; - fn pre_visit(&mut self, _expr: &Expr) -> Result { - Ok(RewriteRecursion::Continue) - } - - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result { match expr { Expr::ScalarSubquery(Subquery { subquery, diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index fe71171ce545..564addd53f29 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -25,7 +25,7 @@ use crate::{utils, OptimizerConfig, OptimizerRule}; use arrow::datatypes::DataType; use datafusion_common::tree_node::{ - RewriteRecursion, TreeNode, TreeNodeRewriter, TreeNodeVisitor, VisitRecursion, + TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, }; use datafusion_common::{ internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, @@ 
-614,21 +614,21 @@ impl ExprIdentifierVisitor<'_> { impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { type N = Expr; - fn pre_visit(&mut self, expr: &Expr) -> Result { + fn pre_visit(&mut self, expr: &Expr) -> Result { // related to https://github.com/apache/arrow-datafusion/issues/8814 // If the expr contain volatile expression or is a short-circuit expression, skip it. if expr.short_circuits() || is_volatile_expression(expr)? { - return Ok(VisitRecursion::Skip); + return Ok(TreeNodeRecursion::Skip); } self.visit_stack .push(VisitRecord::EnterMark(self.node_count)); self.node_count += 1; // put placeholder self.id_array.push((0, "".to_string())); - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } - fn post_visit(&mut self, expr: &Expr) -> Result { + fn post_visit(&mut self, expr: &Expr) -> Result { self.series_number += 1; let (idx, sub_expr_desc) = self.pop_enter_mark(); @@ -637,7 +637,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { self.id_array[idx].0 = self.series_number; let desc = Self::desc_expr(expr); self.visit_stack.push(VisitRecord::ExprItem(desc)); - return Ok(VisitRecursion::Continue); + return Ok(TreeNodeRecursion::Continue); } let mut desc = Self::desc_expr(expr); desc.push_str(&sub_expr_desc); @@ -651,7 +651,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { .entry(desc) .or_insert_with(|| (expr.clone(), 0, data_type)) .1 += 1; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) } } @@ -694,74 +694,71 @@ struct CommonSubexprRewriter<'a> { } impl TreeNodeRewriter for CommonSubexprRewriter<'_> { - type N = Expr; + type Node = Expr; - fn pre_visit(&mut self, expr: &Expr) -> Result { + fn f_down(&mut self, expr: Expr) -> Result<(Expr, TreeNodeRecursion)> { // The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate // the `id_array`, which records the expr's identifier used to rewrite expr. So if we // skip an expr in `ExprIdentifierVisitor`, we should skip it here, too. 
- if expr.short_circuits() || is_volatile_expression(expr)? { - return Ok(RewriteRecursion::Stop); + if expr.short_circuits() || is_volatile_expression(&expr)? { + return Ok((expr, TreeNodeRecursion::Skip)); } if self.curr_index >= self.id_array.len() || self.max_series_number > self.id_array[self.curr_index].0 { - return Ok(RewriteRecursion::Stop); + return Ok((expr, TreeNodeRecursion::Skip)); } let curr_id = &self.id_array[self.curr_index].1; // skip `Expr`s without identifier (empty identifier). if curr_id.is_empty() { self.curr_index += 1; - return Ok(RewriteRecursion::Skip); + return Ok((expr, TreeNodeRecursion::Continue)); } match self.expr_set.get(curr_id) { Some((_, counter, _)) => { if *counter > 1 { self.affected_id.insert(curr_id.clone()); - Ok(RewriteRecursion::Mutate) + + // This expr tree is finished. + if self.curr_index >= self.id_array.len() { + return Ok((expr, TreeNodeRecursion::Skip)); + } + + let (series_number, id) = &self.id_array[self.curr_index]; + self.curr_index += 1; + // Skip sub-node of a replaced tree, or without identifier, or is not repeated expr. + let expr_set_item = self.expr_set.get(id).ok_or_else(|| { + DataFusionError::Internal("expr_set invalid state".to_string()) + })?; + if *series_number < self.max_series_number + || id.is_empty() + || expr_set_item.1 <= 1 + { + return Ok((expr, TreeNodeRecursion::Skip)); + } + + self.max_series_number = *series_number; + // step index to skip all sub-node (which has smaller series number). + while self.curr_index < self.id_array.len() + && *series_number > self.id_array[self.curr_index].0 + { + self.curr_index += 1; + } + + let expr_name = expr.display_name()?; + // Alias this `Column` expr to it original "expr name", + // `projection_push_down` optimizer use "expr name" to eliminate useless + // projections. 
+ Ok((col(id).alias(expr_name), TreeNodeRecursion::Skip)) } else { self.curr_index += 1; - Ok(RewriteRecursion::Skip) + Ok((expr, TreeNodeRecursion::Continue)) } } _ => internal_err!("expr_set invalid state"), } } - - fn mutate(&mut self, expr: Expr) -> Result { - // This expr tree is finished. - if self.curr_index >= self.id_array.len() { - return Ok(expr); - } - - let (series_number, id) = &self.id_array[self.curr_index]; - self.curr_index += 1; - // Skip sub-node of a replaced tree, or without identifier, or is not repeated expr. - let expr_set_item = self.expr_set.get(id).ok_or_else(|| { - DataFusionError::Internal("expr_set invalid state".to_string()) - })?; - if *series_number < self.max_series_number - || id.is_empty() - || expr_set_item.1 <= 1 - { - return Ok(expr); - } - - self.max_series_number = *series_number; - // step index to skip all sub-node (which has smaller series number). - while self.curr_index < self.id_array.len() - && *series_number > self.id_array[self.curr_index].0 - { - self.curr_index += 1; - } - - let expr_name = expr.display_name()?; - // Alias this `Column` expr to it original "expr name", - // `projection_push_down` optimizer use "expr name" to eliminate useless - // projections. 
- Ok(col(id).alias(expr_name)) - } } fn replace_common_expr( diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index b1000f042c98..49d3c322ca2b 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -18,7 +18,7 @@ use crate::simplify_expressions::{ExprSimplifier, SimplifyContext}; use crate::utils::collect_subquery_cols; use datafusion_common::tree_node::{ - RewriteRecursion, Transformed, TreeNode, TreeNodeRewriter, + Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; use datafusion_common::{plan_err, Result}; use datafusion_common::{Column, DFSchemaRef, DataFusionError, ScalarValue}; @@ -56,19 +56,19 @@ pub const UN_MATCHED_ROW_INDICATOR: &str = "__always_true"; pub type ExprResultMap = HashMap; impl TreeNodeRewriter for PullUpCorrelatedExpr { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn f_down(&mut self, plan: LogicalPlan) -> Result<(LogicalPlan, TreeNodeRecursion)> { match plan { - LogicalPlan::Filter(_) => Ok(RewriteRecursion::Continue), + LogicalPlan::Filter(_) => Ok((plan, TreeNodeRecursion::Continue)), LogicalPlan::Union(_) | LogicalPlan::Sort(_) | LogicalPlan::Extension(_) => { let plan_hold_outer = !plan.all_out_ref_exprs().is_empty(); if plan_hold_outer { // the unsupported case self.can_pull_up = false; - Ok(RewriteRecursion::Stop) + Ok((plan, TreeNodeRecursion::Skip)) } else { - Ok(RewriteRecursion::Continue) + Ok((plan, TreeNodeRecursion::Continue)) } } LogicalPlan::Limit(_) => { @@ -77,21 +77,21 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { (false, true) => { // the unsupported case self.can_pull_up = false; - Ok(RewriteRecursion::Stop) + Ok((plan, TreeNodeRecursion::Skip)) } - _ => Ok(RewriteRecursion::Continue), + _ => Ok((plan, TreeNodeRecursion::Continue)), } } _ if plan.expressions().iter().any(|expr| expr.contains_outer()) => { // the unsupported cases, the plan 
expressions contain out reference columns(like window expressions) self.can_pull_up = false; - Ok(RewriteRecursion::Stop) + Ok((plan, TreeNodeRecursion::Skip)) } - _ => Ok(RewriteRecursion::Continue), + _ => Ok((plan, TreeNodeRecursion::Continue)), } } - fn mutate(&mut self, plan: LogicalPlan) -> Result { + fn f_up(&mut self, plan: LogicalPlan) -> Result { let subquery_schema = plan.schema().clone(); match &plan { LogicalPlan::Filter(plan_filter) => { diff --git a/datafusion/optimizer/src/plan_signature.rs b/datafusion/optimizer/src/plan_signature.rs index 07f495a7262d..8b8814192d38 100644 --- a/datafusion/optimizer/src/plan_signature.rs +++ b/datafusion/optimizer/src/plan_signature.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use std::{ collections::hash_map::DefaultHasher, hash::{Hash, Hasher}, @@ -75,7 +75,7 @@ fn get_node_number(plan: &LogicalPlan) -> NonZeroUsize { let mut node_number = 0; plan.apply(&mut |_plan| { node_number += 1; - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) }) // Closure always return Ok .unwrap(); diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 7086c5cda56f..0ae0bc696a35 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -22,7 +22,7 @@ use crate::optimizer::ApplyOrder; use crate::utils::is_volatile_expression; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_common::{ internal_err, plan_datafusion_err, Column, DFSchema, DFSchemaRef, DataFusionError, JoinConstraint, Result, @@ -222,7 +222,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result 
{ Expr::Column(_) | Expr::Literal(_) | Expr::Placeholder(_) - | Expr::ScalarVariable(_, _) => Ok(VisitRecursion::Skip), + | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Skip), Expr::Exists { .. } | Expr::InSubquery(_) | Expr::ScalarSubquery(_) @@ -232,7 +232,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { .. }) => { is_evaluate = false; - Ok(VisitRecursion::Stop) + Ok(TreeNodeRecursion::Stop) } Expr::Alias(_) | Expr::BinaryExpr(_) @@ -254,7 +254,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { | Expr::Cast(_) | Expr::TryCast(_) | Expr::ScalarFunction(..) - | Expr::InList { .. } => Ok(VisitRecursion::Continue), + | Expr::InList { .. } => Ok(TreeNodeRecursion::Continue), Expr::Sort(_) | Expr::AggregateFunction(_) | Expr::WindowFunction(_) @@ -1039,12 +1039,12 @@ fn contain(e: &Expr, check_map: &HashMap) -> bool { match check_map.get(&c.flat_name()) { Some(_) => { is_contain = true; - VisitRecursion::Stop + TreeNodeRecursion::Stop } - None => VisitRecursion::Continue, + None => TreeNodeRecursion::Continue, } } else { - VisitRecursion::Continue + TreeNodeRecursion::Continue }) }) .unwrap(); diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 34ed4a9475cb..e1c35e468f68 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -21,7 +21,7 @@ use crate::utils::replace_qualified_name; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{ - RewriteRecursion, Transformed, TreeNode, TreeNodeRewriter, + Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; use datafusion_common::{plan_err, Column, DataFusionError, Result, ScalarValue}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; @@ -201,16 +201,9 @@ struct ExtractScalarSubQuery { } impl TreeNodeRewriter for ExtractScalarSubQuery { - type N 
= Expr; + type Node = Expr; - fn pre_visit(&mut self, expr: &Expr) -> Result { - match expr { - Expr::ScalarSubquery(_) => Ok(RewriteRecursion::Mutate), - _ => Ok(RewriteRecursion::Continue), - } - } - - fn mutate(&mut self, expr: Expr) -> Result { + fn f_down(&mut self, expr: Expr) -> Result<(Expr, TreeNodeRecursion)> { match expr { Expr::ScalarSubquery(subquery) => { let subqry_alias = self.alias_gen.next("__scalar_sq"); @@ -220,12 +213,15 @@ impl TreeNodeRewriter for ExtractScalarSubQuery { .subquery .head_output_expr()? .map_or(plan_err!("single expression required."), Ok)?; - Ok(Expr::Column(create_col_from_scalar_expr( - &scalar_expr, - subqry_alias, - )?)) + Ok(( + Expr::Column(create_col_from_scalar_expr( + &scalar_expr, + subqry_alias, + )?), + TreeNodeRecursion::Skip, + )) } - _ => Ok(expr), + _ => Ok((expr, TreeNodeRecursion::Continue)), } } } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 1c1228949171..fd77071ea728 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -33,9 +33,10 @@ use arrow::{ datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; +use datafusion_common::tree_node::TreeNodeRecursion; use datafusion_common::{ cast::{as_large_list_array, as_list_array}, - tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}, + tree_node::{TreeNode, TreeNodeRewriter}, }; use datafusion_common::{ internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, @@ -246,9 +247,9 @@ impl Canonicalizer { } impl TreeNodeRewriter for Canonicalizer { - type N = Expr; + type Node = Expr; - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result { let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr else { return Ok(expr); }; @@ -310,9 +311,9 @@ enum ConstSimplifyResult { } impl<'a> TreeNodeRewriter 
for ConstEvaluator<'a> { - type N = Expr; + type Node = Expr; - fn pre_visit(&mut self, expr: &Expr) -> Result { + fn f_down(&mut self, expr: Expr) -> Result<(Expr, TreeNodeRecursion)> { // Default to being able to evaluate this node self.can_evaluate.push(true); @@ -320,7 +321,7 @@ impl<'a> TreeNodeRewriter for ConstEvaluator<'a> { // stack as not ok (as all parents have at least one child or // descendant that can not be evaluated - if !Self::can_evaluate(expr) { + if !Self::can_evaluate(&expr) { // walk back up stack, marking first parent that is not mutable let parent_iter = self.can_evaluate.iter_mut().rev(); for p in parent_iter { @@ -336,10 +337,10 @@ impl<'a> TreeNodeRewriter for ConstEvaluator<'a> { // NB: do not short circuit recursion even if we find a non // evaluatable node (so we can fold other children, args to // functions, etc) - Ok(RewriteRecursion::Continue) + Ok((expr, TreeNodeRecursion::Continue)) } - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result { match self.can_evaluate.pop() { // Certain expressions such as `CASE` and `COALESCE` are short circuiting // and may not evalute all their sub expressions. 
Thus if @@ -504,10 +505,10 @@ impl<'a, S> Simplifier<'a, S> { } impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { - type N = Expr; + type Node = Expr; /// rewrite the expression simplifying any constant expressions - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result { use datafusion_expr::Operator::{ And, BitwiseAnd, BitwiseOr, BitwiseShiftLeft, BitwiseShiftRight, BitwiseXor, Divide, Eq, Modulo, Multiply, NotEq, Or, RegexIMatch, RegexMatch, diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index aa7bb4f78a93..e7c619c046de 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -57,9 +57,9 @@ impl<'a> GuaranteeRewriter<'a> { } impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { - type N = Expr; + type Node = Expr; - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result { if self.guarantees.is_empty() { return Ok(expr); } diff --git a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs index fa95f1688e6f..867e96d213d9 100644 --- a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs @@ -49,9 +49,9 @@ impl InListSimplifier { } impl TreeNodeRewriter for InListSimplifier { - type N = Expr; + type Node = Expr; - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result { if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = &expr { if let (Expr::InList(l1), Operator::And, Expr::InList(l2)) = (left.as_ref(), op, right.as_ref()) diff --git a/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs index fd5c9ecaf82c..ea02c1f3af8a 
100644 --- a/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs @@ -37,9 +37,9 @@ impl OrInListSimplifier { } impl TreeNodeRewriter for OrInListSimplifier { - type N = Expr; + type Node = Expr; - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result { if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = &expr { if *op == Operator::Or { let left = as_inlist(left); diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 91603e82a54f..0232a28c722a 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -24,7 +24,7 @@ use arrow::datatypes::{ DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, }; use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS}; -use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter}; +use datafusion_common::tree_node::TreeNodeRewriter; use datafusion_common::{ internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, }; @@ -127,13 +127,9 @@ struct UnwrapCastExprRewriter { } impl TreeNodeRewriter for UnwrapCastExprRewriter { - type N = Expr; + type Node = Expr; - fn pre_visit(&mut self, _expr: &Expr) -> Result { - Ok(RewriteRecursion::Continue) - } - - fn mutate(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result { match &expr { // For case: // try_cast/cast(expr as data_type) op literal diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs index 5671dc6ae94d..13b67794c7dd 100644 --- a/datafusion/optimizer/src/utils.rs +++ b/datafusion/optimizer/src/utils.rs @@ -18,7 +18,7 @@ //! 
Collection of utility functions that are leveraged by the query optimizer rules use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::tree_node::{TreeNode, VisitRecursion}; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{Column, DFSchemaRef}; use datafusion_common::{DFSchema, Result}; use datafusion_expr::expr::is_volatile; @@ -100,9 +100,9 @@ pub(crate) fn is_volatile_expression(e: &Expr) -> Result { e.apply(&mut |expr| { Ok(if is_volatile(expr)? { is_volatile_expr = true; - VisitRecursion::Stop + TreeNodeRecursion::Stop } else { - VisitRecursion::Continue + TreeNodeRecursion::Continue }) })?; Ok(is_volatile_expr) diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index f0bd1740d5d2..29a6825ddcf7 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -260,7 +260,7 @@ impl EquivalenceGroup { /// class it matches with (if any). 
pub fn normalize_expr(&self, expr: Arc) -> Arc { expr.clone() - .transform(&|expr| { + .transform_up(&|expr| { for cls in self.iter() { if cls.contains(&expr) { return Ok(Transformed::Yes(cls.canonical_expr().unwrap())); diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 6a168e2f1e5f..b04c66b23728 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -960,7 +960,7 @@ mod tests { let expr2 = expr .clone() - .transform(&|e| { + .transform_up(&|e| { let transformed = match e.as_any().downcast_ref::() { Some(lit_value) => match lit_value.value() { diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index e14ff2692146..8d4f4cad4afa 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -29,9 +29,7 @@ use crate::{PhysicalExpr, PhysicalSortExpr}; use arrow::array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData}; use arrow::compute::{and_kleene, is_not_null, SlicesIterator}; use arrow::datatypes::SchemaRef; -use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeRewriter, VisitRecursion, -}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_common::Result; use datafusion_expr::Operator; @@ -130,11 +128,10 @@ pub fn get_indices_of_exprs_strict>>( pub type ExprTreeNode = ExprContext>; -/// This struct facilitates the [TreeNodeRewriter] mechanism to convert a -/// [PhysicalExpr] tree into a DAEG (i.e. an expression DAG) by collecting -/// identical expressions in one node. Caller specifies the node type in the -/// DAEG via the `constructor` argument, which constructs nodes in the DAEG -/// from the [ExprTreeNode] ancillary object. +/// This struct is used to convert a [PhysicalExpr] tree into a DAEG (i.e. an expression +/// DAG) by collecting identical expressions in one node. 
Caller specifies the node type +/// in the DAEG via the `constructor` argument, which constructs nodes in the DAEG from +/// the [ExprTreeNode] ancillary object. struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode) -> Result> { // The resulting DAEG (expression DAG). graph: StableGraph, @@ -144,16 +141,15 @@ struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode) -> Result< constructor: &'a F, } -impl<'a, T, F: Fn(&ExprTreeNode) -> Result> TreeNodeRewriter - for PhysicalExprDAEGBuilder<'a, T, F> +impl<'a, T, F: Fn(&ExprTreeNode) -> Result> + PhysicalExprDAEGBuilder<'a, T, F> { - type N = ExprTreeNode; // This method mutates an expression node by transforming it to a physical expression // and adding it to the graph. The method returns the mutated expression node. fn mutate( &mut self, mut node: ExprTreeNode, - ) -> Result> { + ) -> Result>> { // Get the expression associated with the input expression node. let expr = &node.expr; @@ -176,7 +172,7 @@ impl<'a, T, F: Fn(&ExprTreeNode) -> Result> TreeNodeRewriter // Set the data field of the input expression node to the corresponding node index. node.data = Some(node_idx); // Return the mutated expression node. - Ok(node) + Ok(Transformed::Yes(node)) } } @@ -197,7 +193,7 @@ where constructor, }; // Use the builder to transform the expression tree node into a DAG. - let root = init.rewrite(&mut builder)?; + let root = init.transform_up_mut(&mut |node| builder.mutate(node))?; // Return a tuple containing the root node index and the DAG. 
Ok((root.data.unwrap(), builder.graph)) } @@ -211,7 +207,7 @@ pub fn collect_columns(expr: &Arc) -> HashSet { columns.insert(column.clone()); } } - Ok(VisitRecursion::Continue) + Ok(TreeNodeRecursion::Continue) }) // pre_visit always returns OK, so this will always too .expect("no way to return error during recursion"); diff --git a/docs/source/library-user-guide/working-with-exprs.md b/docs/source/library-user-guide/working-with-exprs.md index 96be8ef7f1ae..b128d661f31a 100644 --- a/docs/source/library-user-guide/working-with-exprs.md +++ b/docs/source/library-user-guide/working-with-exprs.md @@ -96,7 +96,7 @@ To implement the inlining, we'll need to write a function that takes an `Expr` a ```rust fn rewrite_add_one(expr: Expr) -> Result { - expr.transform(&|expr| { + expr.transform_up(&|expr| { Ok(match expr { Expr::ScalarUDF(scalar_fun) if scalar_fun.fun.name == "add_one" => { let input_arg = scalar_fun.args[0].clone(); From c52f134569b85b37d20c547898a2bf0eda82b5f6 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Mon, 29 Jan 2024 13:06:33 +0100 Subject: [PATCH 02/40] use handle_tree_recursion in `Expr` --- datafusion/expr/src/tree_node/expr.rs | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index d937c11633f4..69ac97917159 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -25,7 +25,7 @@ use crate::expr::{ use crate::{Expr, GetFieldAccess}; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; -use datafusion_common::{internal_err, DataFusionError, Result}; +use datafusion_common::{handle_tree_recursion, internal_err, DataFusionError, Result}; impl TreeNode for Expr { fn apply_children Result>( @@ -129,11 +129,7 @@ impl TreeNode for Expr { }; for child in children { - match op(child)? 
{ - TreeNodeRecursion::Continue => {} - TreeNodeRecursion::Skip => return Ok(TreeNodeRecursion::Continue), - TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop), - } + handle_tree_recursion!(op(child)?); } Ok(TreeNodeRecursion::Continue) From 3fd2214ef76820c550891ca1b147e3945ce5464e Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Mon, 29 Jan 2024 13:33:59 +0100 Subject: [PATCH 03/40] use macro for transform recursions --- datafusion/common/src/tree_node.rs | 41 +++++++++++++++--------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index a451fe77088d..85338a7200df 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -41,6 +41,21 @@ macro_rules! handle_tree_recursion { }; } +macro_rules! handle_tree_recursion_without_stop { + ($TNR:expr, $NODE:expr) => { + match $TNR { + TreeNodeRecursion::Continue => {} + // If the recursion should skip, do not apply to its children, let + // the recursion continue: + TreeNodeRecursion::Skip => return Ok($NODE), + // Stop is not (yet) supported + TreeNodeRecursion::Stop => { + panic!("Stop can't be used in `TreeNode::transform()` and `TreeNode::rewrite()`") + } + } + }; +} + /// Defines a visitable and rewriteable a tree node. This trait is /// implemented for plans ([`ExecutionPlan`] and [`LogicalPlan`]) as /// well as expression trees ([`PhysicalExpr`], [`Expr`]) in @@ -123,7 +138,8 @@ pub trait TreeNode: Sized { /// f_up(ParentNode) /// ``` /// - /// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled. + /// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled, + /// and please note that [`TreeNodeRecursion::Stop`] is not supported. /// /// If `f_down` or `f_up` returns [`Err`], recursion is stopped immediately. 
fn transform(self, f_down: &mut FD, f_up: &mut FU) -> Result @@ -132,15 +148,7 @@ pub trait TreeNode: Sized { FU: FnMut(Self) -> Result, { let (new_node, tnr) = f_down(self).map(|(t, tnr)| (t.into(), tnr))?; - match tnr { - TreeNodeRecursion::Continue => {} - // If the recursion should skip, do not apply to its children. And let the recursion continue - TreeNodeRecursion::Skip => return Ok(new_node), - // If the recursion should stop, do not apply to its children - TreeNodeRecursion::Stop => { - panic!("Stop can't be used in TreeNode::transform()") - } - } + handle_tree_recursion_without_stop!(tnr, new_node); let node_with_new_children = new_node.map_children(|node| node.transform(f_down, f_up))?; f_up(node_with_new_children) @@ -212,21 +220,14 @@ pub trait TreeNode: Sized { /// TreeNodeRewriter::f_up(ParentNode) /// ``` /// - /// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled. + /// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled, + /// and please note that [`TreeNodeRecursion::Stop`] is not supported. /// /// If [`TreeNodeRewriter::f_down()`] or [`TreeNodeRewriter::f_up()`] returns [`Err`], /// recursion is stopped immediately. fn rewrite>(self, rewriter: &mut R) -> Result { let (new_node, tnr) = rewriter.f_down(self)?; - match tnr { - TreeNodeRecursion::Continue => {} - // If the recursion should skip, do not apply to its children. 
And let the recursion continue - TreeNodeRecursion::Skip => return Ok(new_node), - // If the recursion should stop, do not apply to its children - TreeNodeRecursion::Stop => { - panic!("Stop can't be used in TreeNode::rewrite()") - } - } + handle_tree_recursion_without_stop!(tnr, new_node); let node_with_new_children = new_node.map_children(|node| node.rewrite(rewriter))?; rewriter.f_up(node_with_new_children) From 35f9006250655289069e88426d4d25217751d206 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 31 Jan 2024 08:45:42 +0100 Subject: [PATCH 04/40] fix api --- datafusion-examples/examples/rewrite_expr.rs | 14 +- datafusion/common/src/tree_node.rs | 267 ++++++++---- .../physical_plan/parquet/row_filter.rs | 23 +- .../aggregate_statistics.rs | 10 +- .../physical_optimizer/coalesce_batches.rs | 5 +- .../combine_partial_final_agg.rs | 10 +- .../enforce_distribution.rs | 39 +- .../src/physical_optimizer/enforce_sorting.rs | 48 ++- .../src/physical_optimizer/join_selection.rs | 28 +- .../limited_distinct_aggregation.rs | 19 +- .../physical_optimizer/output_requirements.rs | 20 +- .../physical_optimizer/pipeline_checker.rs | 5 +- .../physical_optimizer/projection_pushdown.rs | 27 +- .../core/src/physical_optimizer/pruning.rs | 5 +- .../replace_with_order_preserving_variants.rs | 8 +- .../src/physical_optimizer/sort_pushdown.rs | 4 +- .../core/src/physical_optimizer/test_utils.rs | 26 +- .../physical_optimizer/topk_aggregation.rs | 17 +- datafusion/expr/src/expr.rs | 3 +- datafusion/expr/src/expr_rewriter/mod.rs | 54 ++- datafusion/expr/src/expr_rewriter/order_by.rs | 9 +- datafusion/expr/src/logical_plan/plan.rs | 32 +- datafusion/expr/src/tree_node/expr.rs | 403 ++++++++++-------- datafusion/expr/src/tree_node/plan.rs | 28 +- .../src/analyzer/count_wildcard_rule.rs | 115 +++-- .../src/analyzer/inline_table_scan.rs | 24 +- .../optimizer/src/analyzer/rewrite_expr.rs | 16 +- .../optimizer/src/analyzer/type_coercion.rs | 131 +++--- 
.../optimizer/src/common_subexpr_eliminate.rs | 31 +- datafusion/optimizer/src/decorrelate.rs | 115 ++--- .../src/decorrelate_predicate_subquery.rs | 2 +- datafusion/optimizer/src/push_down_filter.rs | 7 +- .../optimizer/src/scalar_subquery_to_join.rs | 39 +- .../simplify_expressions/expr_simplifier.rs | 289 +++++++------ .../src/simplify_expressions/guarantees.rs | 59 +-- .../simplify_expressions/inlist_simplifier.rs | 16 +- .../or_in_list_simplifier.rs | 8 +- .../src/unwrap_cast_in_comparison.rs | 34 +- .../physical-expr/src/equivalence/class.rs | 10 +- .../physical-expr/src/equivalence/mod.rs | 5 +- .../src/equivalence/projection.rs | 5 +- .../src/equivalence/properties.rs | 5 +- .../physical-expr/src/expressions/case.rs | 10 +- datafusion/physical-expr/src/physical_expr.rs | 7 +- datafusion/physical-expr/src/tree_node.rs | 6 +- datafusion/physical-expr/src/utils/mod.rs | 11 +- datafusion/physical-plan/src/empty.rs | 2 +- .../src/joins/stream_join_utils.rs | 18 +- datafusion/physical-plan/src/joins/utils.rs | 18 +- datafusion/physical-plan/src/lib.rs | 4 +- .../physical-plan/src/placeholder_row.rs | 2 +- .../physical-plan/src/recursive_query.rs | 5 +- datafusion/physical-plan/src/tree_node.rs | 6 +- datafusion/sql/src/utils.rs | 60 +-- .../library-user-guide/working-with-exprs.md | 6 +- 55 files changed, 1248 insertions(+), 922 deletions(-) diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs index 9dfc238ab9e8..88b43ccdede7 100644 --- a/datafusion-examples/examples/rewrite_expr.rs +++ b/datafusion-examples/examples/rewrite_expr.rs @@ -95,12 +95,12 @@ impl MyAnalyzerRule { Ok(match plan { LogicalPlan::Filter(filter) => { let predicate = Self::analyze_expr(filter.predicate.clone())?; - Transformed::Yes(LogicalPlan::Filter(Filter::try_new( + Transformed::yes(LogicalPlan::Filter(Filter::try_new( predicate, filter.input, )?)) } - _ => Transformed::No(plan), + _ => Transformed::no(plan), }) }) } @@ -111,11 +111,11 
@@ impl MyAnalyzerRule { Ok(match expr { Expr::Literal(ScalarValue::Int64(i)) => { // transform to UInt64 - Transformed::Yes(Expr::Literal(ScalarValue::UInt64( + Transformed::yes(Expr::Literal(ScalarValue::UInt64( i.map(|i| i as u64), ))) } - _ => Transformed::No(expr), + _ => Transformed::no(expr), }) }) } @@ -175,12 +175,12 @@ fn my_rewrite(expr: Expr) -> Result { let low: Expr = *low; let high: Expr = *high; if negated { - Transformed::Yes(expr.clone().lt(low).or(expr.gt(high))) + Transformed::yes(expr.clone().lt(low).or(expr.gt(high))) } else { - Transformed::Yes(expr.clone().gt_eq(low).and(expr.lt_eq(high))) + Transformed::yes(expr.clone().gt_eq(low).and(expr.lt_eq(high))) } } - _ => Transformed::No(expr), + _ => Transformed::no(expr), }) }) } diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 85338a7200df..f1619257619b 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -41,21 +41,6 @@ macro_rules! handle_tree_recursion { }; } -macro_rules! handle_tree_recursion_without_stop { - ($TNR:expr, $NODE:expr) => { - match $TNR { - TreeNodeRecursion::Continue => {} - // If the recursion should skip, do not apply to its children, let - // the recursion continue: - TreeNodeRecursion::Skip => return Ok($NODE), - // Stop is not (yet) supported - TreeNodeRecursion::Stop => { - panic!("Stop can't be used in `TreeNode::transform()` and `TreeNode::rewrite()`") - } - } - }; -} - /// Defines a visitable and rewriteable a tree node. This trait is /// implemented for plans ([`ExecutionPlan`] and [`LogicalPlan`]) as /// well as expression trees ([`PhysicalExpr`], [`Expr`]) in @@ -138,66 +123,65 @@ pub trait TreeNode: Sized { /// f_up(ParentNode) /// ``` /// - /// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled, - /// and please note that [`TreeNodeRecursion::Stop`] is not supported. 
+ /// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled. /// /// If `f_down` or `f_up` returns [`Err`], recursion is stopped immediately. - fn transform(self, f_down: &mut FD, f_up: &mut FU) -> Result + fn transform( + self, + f_down: &mut FD, + f_up: &mut FU, + ) -> Result> where - FD: FnMut(Self) -> Result<(Transformed, TreeNodeRecursion)>, - FU: FnMut(Self) -> Result, + FD: FnMut(Self) -> Result>, + FU: FnMut(Self) -> Result>, { - let (new_node, tnr) = f_down(self).map(|(t, tnr)| (t.into(), tnr))?; - handle_tree_recursion_without_stop!(tnr, new_node); - let node_with_new_children = - new_node.map_children(|node| node.transform(f_down, f_up))?; - f_up(node_with_new_children) + f_down(self)?.and_then_transform_children(|t| { + t.map_children(|node| node.transform(f_down, f_up))? + .and_then_transform_sibling(f_up) + }) } /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its /// children(Preorder Traversal). /// When the `op` does not apply to a given node, it is left unchanged. - fn transform_down(self, op: &F) -> Result + fn transform_down(self, f: &F) -> Result> where F: Fn(Self) -> Result>, { - let after_op = op(self)?.into(); - after_op.map_children(|node| node.transform_down(op)) + f(self)?.and_then_transform_children(|t| t.map_children(|n| n.transform_down(f))) } /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its /// children(Preorder Traversal) using a mutable function, `F`. /// When the `op` does not apply to a given node, it is left unchanged. - fn transform_down_mut(self, op: &mut F) -> Result + fn transform_down_mut(self, f: &mut F) -> Result> where F: FnMut(Self) -> Result>, { - let after_op = op(self)?.into(); - after_op.map_children(|node| node.transform_down_mut(op)) + f(self)? 
+ .and_then_transform_children(|t| t.map_children(|n| n.transform_down_mut(f))) } /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its /// children and then itself(Postorder Traversal). /// When the `op` does not apply to a given node, it is left unchanged. - fn transform_up(self, op: &F) -> Result + fn transform_up(self, f: &F) -> Result> where F: Fn(Self) -> Result>, { - let after_op_children = self.map_children(|node| node.transform_up(op))?; - let new_node = op(after_op_children)?.into(); - Ok(new_node) + self.map_children(|node| node.transform_up(f))? + .and_then_transform_sibling(f) } /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its /// children and then itself(Postorder Traversal) using a mutable function, `F`. /// When the `op` does not apply to a given node, it is left unchanged. - fn transform_up_mut(self, op: &mut F) -> Result + fn transform_up_mut(self, f: &mut F) -> Result> where F: FnMut(Self) -> Result>, { - let after_op_children = self.map_children(|node| node.transform_up_mut(op))?; - let new_node = op(after_op_children)?.into(); - Ok(new_node) + self.map_children(|n| n.transform_up_mut(f))? + .and_then_transform_sibling(f) } /// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for @@ -220,17 +204,18 @@ pub trait TreeNode: Sized { /// TreeNodeRewriter::f_up(ParentNode) /// ``` /// - /// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled, - /// and please note that [`TreeNodeRecursion::Stop`] is not supported. + /// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled. /// /// If [`TreeNodeRewriter::f_down()`] or [`TreeNodeRewriter::f_up()`] returns [`Err`], /// recursion is stopped immediately. 
- fn rewrite>(self, rewriter: &mut R) -> Result { - let (new_node, tnr) = rewriter.f_down(self)?; - handle_tree_recursion_without_stop!(tnr, new_node); - let node_with_new_children = - new_node.map_children(|node| node.rewrite(rewriter))?; - rewriter.f_up(node_with_new_children) + fn rewrite>( + self, + rewriter: &mut R, + ) -> Result> { + rewriter.f_down(self)?.and_then_transform_children(|t| { + t.map_children(|n| n.rewrite(rewriter))? + .and_then_transform_sibling(|t| rewriter.f_up(t)) + }) } /// Apply the closure `F` to the node's children @@ -239,9 +224,9 @@ pub trait TreeNode: Sized { F: FnMut(&Self) -> Result; /// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder) - fn map_children(self, transform: F) -> Result + fn map_children(self, f: F) -> Result> where - F: FnMut(Self) -> Result; + F: FnMut(Self) -> Result>; } /// Implements the [visitor @@ -291,20 +276,20 @@ pub trait TreeNodeRewriter: Sized { /// Invoked while traversing down the tree before any children are rewritten / /// visited. /// Default implementation returns the node unmodified and continues recursion. - fn f_down(&mut self, node: Self::Node) -> Result<(Self::Node, TreeNodeRecursion)> { - Ok((node, TreeNodeRecursion::Continue)) + fn f_down(&mut self, node: Self::Node) -> Result> { + Ok(Transformed::no(node)) } /// Invoked while traversing up the tree after all children have been rewritten / /// visited. /// Default implementation returns the node unmodified. - fn f_up(&mut self, node: Self::Node) -> Result { - Ok(node) + fn f_up(&mut self, node: Self::Node) -> Result> { + Ok(Transformed::no(node)) } } /// Controls how [`TreeNode`] recursions should proceed. -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub enum TreeNodeRecursion { /// Continue recursion with the next node. 
Continue, @@ -314,27 +299,147 @@ pub enum TreeNodeRecursion { Stop, } -pub enum Transformed { - /// The item was transformed / rewritten somehow - Yes(T), - /// The item was not transformed - No(T), +pub struct Transformed { + pub data: T, + pub transformed: bool, + pub tnr: TreeNodeRecursion, } impl Transformed { - pub fn into(self) -> T { - match self { - Transformed::Yes(t) => t, - Transformed::No(t) => t, + pub fn new(data: T, transformed: bool, tnr: TreeNodeRecursion) -> Self { + Self { + data, + transformed, + tnr, } } - pub fn into_pair(self) -> (T, bool) { - match self { - Transformed::Yes(t) => (t, true), - Transformed::No(t) => (t, false), + pub fn yes(data: T) -> Self { + Self { + data, + transformed: true, + tnr: TreeNodeRecursion::Continue, } } + + pub fn no(data: T) -> Self { + Self { + data, + transformed: false, + tnr: TreeNodeRecursion::Continue, + } + } + + pub fn map_data U>(self, f: F) -> Transformed { + Transformed { + data: f(self.data), + transformed: self.transformed, + tnr: self.tnr, + } + } + + pub fn flat_map_data Result>( + self, + f: F, + ) -> Result> { + Ok(Transformed { + data: f(self.data)?, + transformed: self.transformed, + tnr: self.tnr, + }) + } + + fn and_then_transform Result>>( + self, + f: F, + children: bool, + ) -> Result> { + match self.tnr { + TreeNodeRecursion::Continue => {} + TreeNodeRecursion::Skip => { + // If the next transformation would happen on children return immediately + // on `Skip`. 
+ if children { + return Ok(Transformed { + tnr: TreeNodeRecursion::Continue, + ..self + }); + } + } + TreeNodeRecursion::Stop => return Ok(self), + }; + let t = f(self.data)?; + Ok(Transformed { + transformed: t.transformed || self.transformed, + ..t + }) + } + + pub fn and_then_transform_sibling Result>>( + self, + f: F, + ) -> Result> { + self.and_then_transform(f, false) + } + + pub fn and_then_transform_children Result>>( + self, + f: F, + ) -> Result> { + self.and_then_transform(f, true) + } +} + +pub trait TransformedIterator: Iterator { + fn map_till_continue_and_collect( + self, + f: F, + ) -> Result>> + where + F: FnMut(Self::Item) -> Result>, + Self: Sized; +} + +impl TransformedIterator for I { + fn map_till_continue_and_collect( + self, + mut f: F, + ) -> Result>> + where + F: FnMut(Self::Item) -> Result>, + { + let mut new_tnr = TreeNodeRecursion::Continue; + let mut new_transformed = false; + let new_data = self + .map(|i| { + if new_tnr == TreeNodeRecursion::Continue + || new_tnr == TreeNodeRecursion::Skip + { + let Transformed { + data, + transformed, + tnr, + } = f(i)?; + new_tnr = if tnr == TreeNodeRecursion::Skip { + // Iterator always considers the elements as siblings so `Skip` + // can be safely converted to `Continue`. 
+ TreeNodeRecursion::Continue + } else { + tnr + }; + new_transformed |= transformed; + Ok(data) + } else { + Ok(i) + } + }) + .collect::>>()?; + Ok(Transformed { + data: new_data, + transformed: new_transformed, + tnr: new_tnr, + }) + } } /// Helper trait for implementing [`TreeNode`] that have children stored as Arc's @@ -350,7 +455,7 @@ pub trait DynTreeNode { &self, arc_self: Arc, new_children: Vec>, - ) -> Result>; + ) -> Result>>; } /// Blanket implementation for Arc for any tye that implements @@ -367,18 +472,18 @@ impl TreeNode for Arc { Ok(TreeNodeRecursion::Continue) } - fn map_children(self, transform: F) -> Result + fn map_children(self, f: F) -> Result> where - F: FnMut(Self) -> Result, + F: FnMut(Self) -> Result>, { let children = self.arc_children(); if !children.is_empty() { - let new_children = - children.into_iter().map(transform).collect::>()?; + let t = children.into_iter().map_till_continue_and_collect(f)?; + // TODO: once we trust `t.transformed` don't create new node if not necessary let arc_self = Arc::clone(&self); - self.with_new_arc_children(arc_self, new_children) + self.with_new_arc_children(arc_self, t.data) } else { - Ok(self) + Ok(Transformed::no(self)) } } } @@ -409,17 +514,19 @@ impl TreeNode for T { Ok(TreeNodeRecursion::Continue) } - fn map_children(self, transform: F) -> Result + fn map_children(self, f: F) -> Result> where - F: FnMut(Self) -> Result, + F: FnMut(Self) -> Result>, { let (new_self, children) = self.take_children(); if !children.is_empty() { - let new_children = - children.into_iter().map(transform).collect::>()?; - new_self.with_new_children(new_children) + children + .into_iter() + .map_till_continue_and_collect(f)? 
+ // TODO: once we trust `transformed` don't create new node if not necessary + .flat_map_data(|new_children| new_self.with_new_children(new_children)) } else { - Ok(new_self) + Ok(Transformed::no(new_self)) } } } diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs index ddfeb146b876..bdd607095f44 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs @@ -20,7 +20,9 @@ use arrow::datatypes::{DataType, Schema}; use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; use datafusion_common::cast::as_boolean_array; -use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeRewriter}; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, +}; use datafusion_common::{arrow_err, DataFusionError, Result, ScalarValue}; use datafusion_physical_expr::expressions::{Column, Literal}; use datafusion_physical_expr::utils::reassign_predicate_columns; @@ -189,7 +191,7 @@ impl<'a> FilterCandidateBuilder<'a> { metadata: &ParquetMetaData, ) -> Result> { let expr = self.expr.clone(); - let expr = expr.rewrite(&mut self)?; + let expr = expr.rewrite(&mut self)?.data; if self.non_primitive_columns || self.projected_columns { Ok(None) @@ -214,27 +216,30 @@ impl<'a> TreeNodeRewriter for FilterCandidateBuilder<'a> { fn f_down( &mut self, node: Arc, - ) -> Result<(Arc, TreeNodeRecursion)> { + ) -> Result>> { if let Some(column) = node.as_any().downcast_ref::() { if let Ok(idx) = self.file_schema.index_of(column.name()) { self.required_column_indices.insert(idx); if DataType::is_nested(self.file_schema.field(idx).data_type()) { self.non_primitive_columns = true; - return Ok((node, TreeNodeRecursion::Skip)); + return Ok(Transformed::new(node, false, TreeNodeRecursion::Skip)); } } else if 
self.table_schema.index_of(column.name()).is_err() { // If the column does not exist in the (un-projected) table schema then // it must be a projected column. self.projected_columns = true; - return Ok((node, TreeNodeRecursion::Skip)); + return Ok(Transformed::new(node, false, TreeNodeRecursion::Skip)); } } - Ok((node, TreeNodeRecursion::Continue)) + Ok(Transformed::no(node)) } - fn f_up(&mut self, expr: Arc) -> Result> { + fn f_up( + &mut self, + expr: Arc, + ) -> Result>> { if let Some(column) = expr.as_any().downcast_ref::() { if self.file_schema.field_with_name(column.name()).is_err() { // the column expr must be in the table schema @@ -242,7 +247,7 @@ impl<'a> TreeNodeRewriter for FilterCandidateBuilder<'a> { Ok(field) => { // return the null value corresponding to the data type let null_value = ScalarValue::try_from(field.data_type())?; - Ok(Arc::new(Literal::new(null_value))) + Ok(Transformed::yes(Arc::new(Literal::new(null_value)))) } Err(e) => { // If the column is not in the table schema, should throw the error @@ -252,7 +257,7 @@ impl<'a> TreeNodeRewriter for FilterCandidateBuilder<'a> { } } - Ok(expr) + Ok(Transformed::no(expr)) } } diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index 4fe11c14a758..5f872831ef93 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -27,7 +27,7 @@ use crate::physical_plan::{expressions, AggregateExpr, ExecutionPlan, Statistics use crate::scalar::ScalarValue; use datafusion_common::stats::Precision; -use datafusion_common::tree_node::TreeNode; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; @@ -85,10 +85,14 @@ impl PhysicalOptimizerRule for AggregateStatistics { Arc::new(PlaceholderRowExec::new(plan.schema())), 
)?)) } else { - plan.map_children(|child| self.optimize(child, _config)) + plan.map_children(|child| { + self.optimize(child, _config).map(Transformed::yes) + }) + .map(|t| t.data) } } else { - plan.map_children(|child| self.optimize(child, _config)) + plan.map_children(|child| self.optimize(child, _config).map(Transformed::yes)) + .map(|t| t.data) } } diff --git a/datafusion/core/src/physical_optimizer/coalesce_batches.rs b/datafusion/core/src/physical_optimizer/coalesce_batches.rs index 7b66ca529094..e3565e451669 100644 --- a/datafusion/core/src/physical_optimizer/coalesce_batches.rs +++ b/datafusion/core/src/physical_optimizer/coalesce_batches.rs @@ -71,14 +71,15 @@ impl PhysicalOptimizerRule for CoalesceBatches { }) .unwrap_or(false); if wrap_in_coalesce { - Ok(Transformed::Yes(Arc::new(CoalesceBatchesExec::new( + Ok(Transformed::yes(Arc::new(CoalesceBatchesExec::new( plan, target_batch_size, )))) } else { - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) } }) + .map(|t| t.data) } fn name(&self) -> &str { diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index b26d9763e53a..ccc9a2909cca 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -109,11 +109,12 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate { }); Ok(if let Some(transformed) = transformed { - Transformed::Yes(transformed) + Transformed::yes(transformed) } else { - Transformed::No(plan) + Transformed::no(plan) }) }) + .map(|t| t.data) } fn name(&self) -> &str { @@ -185,11 +186,12 @@ fn discard_column_index(group_expr: Arc) -> Arc None, }; Ok(if let Some(normalized_form) = normalized_form { - Transformed::Yes(normalized_form) + Transformed::yes(normalized_form) } else { - Transformed::No(expr) + Transformed::no(expr) }) }) + .map(|t| t.data) .unwrap_or(group_expr) } diff --git 
a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index fab26c49c2da..ff033d168e77 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -197,22 +197,25 @@ impl PhysicalOptimizerRule for EnforceDistribution { let adjusted = if top_down_join_key_reordering { // Run a top-down process to adjust input key ordering recursively let plan_requirements = PlanWithKeyRequirements::new_default(plan); - let adjusted = - plan_requirements.transform_down(&adjust_input_keys_ordering)?; + let adjusted = plan_requirements + .transform_down(&adjust_input_keys_ordering)? + .data; adjusted.plan } else { // Run a bottom-up process plan.transform_up(&|plan| { - Ok(Transformed::Yes(reorder_join_keys_to_inputs(plan)?)) + Ok(Transformed::yes(reorder_join_keys_to_inputs(plan)?)) })? + .data }; let distribution_context = DistributionContext::new_default(adjusted); // Distribution enforcement needs to be applied bottom-up. - let distribution_context = - distribution_context.transform_up(&|distribution_context| { + let distribution_context = distribution_context + .transform_up(&|distribution_context| { ensure_distribution(distribution_context, config) - })?; + })? 
+ .data; Ok(distribution_context.plan) } @@ -306,7 +309,7 @@ fn adjust_input_keys_ordering( vec![], &join_constructor, ) - .map(Transformed::Yes); + .map(Transformed::yes); } PartitionMode::CollectLeft => { // Push down requirements to the right side @@ -368,18 +371,18 @@ fn adjust_input_keys_ordering( sort_options.clone(), &join_constructor, ) - .map(Transformed::Yes); + .map(Transformed::yes); } else if let Some(aggregate_exec) = plan.as_any().downcast_ref::() { if !requirements.data.is_empty() { if aggregate_exec.mode() == &AggregateMode::FinalPartitioned { return reorder_aggregate_keys(requirements, aggregate_exec) - .map(Transformed::Yes); + .map(Transformed::yes); } else { requirements.data.clear(); } } else { // Keep everything unchanged - return Ok(Transformed::No(requirements)); + return Ok(Transformed::no(requirements)); } } else if let Some(proj) = plan.as_any().downcast_ref::() { let expr = proj.expr(); @@ -407,7 +410,7 @@ fn adjust_input_keys_ordering( child.data = requirements.data.clone(); } } - Ok(Transformed::Yes(requirements)) + Ok(Transformed::yes(requirements)) } fn reorder_partitioned_join_keys( @@ -1065,7 +1068,7 @@ fn ensure_distribution( let dist_context = update_children(dist_context)?; if dist_context.plan.children().is_empty() { - return Ok(Transformed::No(dist_context)); + return Ok(Transformed::no(dist_context)); } let target_partitions = config.execution.target_partitions; @@ -1245,7 +1248,7 @@ fn ensure_distribution( plan.with_new_children(children_plans)? 
}; - Ok(Transformed::Yes(DistributionContext::new( + Ok(Transformed::yes(DistributionContext::new( plan, data, children, ))) } @@ -1718,7 +1721,7 @@ pub(crate) mod tests { config.optimizer.repartition_file_scans = false; config.optimizer.repartition_file_min_size = 1024; config.optimizer.prefer_existing_sort = prefer_existing_sort; - ensure_distribution(distribution_context, &config).map(|item| item.into().plan) + ensure_distribution(distribution_context, &config).map(|item| item.data.plan) } /// Test whether plan matches with expected plan @@ -1786,22 +1789,22 @@ pub(crate) mod tests { let plan_requirements = PlanWithKeyRequirements::new_default($PLAN.clone()); let adjusted = plan_requirements - .transform_down(&adjust_input_keys_ordering) + .transform_down(&adjust_input_keys_ordering).map(|t| t.data) .and_then(check_integrity)?; // TODO: End state payloads will be checked here. adjusted.plan } else { // Run reorder_join_keys_to_inputs rule $PLAN.clone().transform_up(&|plan| { - Ok(Transformed::Yes(reorder_join_keys_to_inputs(plan)?)) - })? + Ok(Transformed::yes(reorder_join_keys_to_inputs(plan)?)) + })?.data }; // Then run ensure_distribution rule DistributionContext::new_default(adjusted) .transform_up(&|distribution_context| { ensure_distribution(distribution_context, &config) - }) + }).map(|t| t.data) .and_then(check_integrity)?; // TODO: End state payloads will be checked here. 
} diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index 5c46e64a22f6..7a3b2c512111 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -158,33 +158,35 @@ impl PhysicalOptimizerRule for EnforceSorting { let plan_requirements = PlanWithCorrespondingSort::new_default(plan); // Execute a bottom-up traversal to enforce sorting requirements, // remove unnecessary sorts, and optimize sort-sensitive operators: - let adjusted = plan_requirements.transform_up(&ensure_sorting)?; + let adjusted = plan_requirements.transform_up(&ensure_sorting)?.data; let new_plan = if config.optimizer.repartition_sorts { let plan_with_coalesce_partitions = PlanWithCorrespondingCoalescePartitions::new_default(adjusted.plan); - let parallel = - plan_with_coalesce_partitions.transform_up(&parallelize_sorts)?; + let parallel = plan_with_coalesce_partitions + .transform_up(&parallelize_sorts)? + .data; parallel.plan } else { adjusted.plan }; let plan_with_pipeline_fixer = OrderPreservationContext::new_default(new_plan); - let updated_plan = - plan_with_pipeline_fixer.transform_up(&|plan_with_pipeline_fixer| { + let updated_plan = plan_with_pipeline_fixer + .transform_up(&|plan_with_pipeline_fixer| { replace_with_order_preserving_variants( plan_with_pipeline_fixer, false, true, config, ) - })?; + })? + .data; // Execute a top-down traversal to exploit sort push-down opportunities // missed by the bottom-up traversal: let mut sort_pushdown = SortPushDown::new_default(updated_plan.plan); assign_initial_requirements(&mut sort_pushdown); - let adjusted = sort_pushdown.transform_down(&pushdown_sorts)?; + let adjusted = sort_pushdown.transform_down(&pushdown_sorts)?.data; Ok(adjusted.plan) } @@ -221,7 +223,7 @@ fn parallelize_sorts( // `SortPreservingMergeExec` or a `CoalescePartitionsExec`, and they // all have a single child.
Therefore, if the first child has no // connection, we can return immediately. - Ok(Transformed::No(requirements)) + Ok(Transformed::no(requirements)) } else if (is_sort(&requirements.plan) || is_sort_preserving_merge(&requirements.plan)) && requirements.plan.output_partitioning().partition_count() <= 1 @@ -250,7 +252,7 @@ fn parallelize_sorts( } let spm = SortPreservingMergeExec::new(sort_exprs, requirements.plan.clone()); - Ok(Transformed::Yes( + Ok(Transformed::yes( PlanWithCorrespondingCoalescePartitions::new( Arc::new(spm.with_fetch(fetch)), false, @@ -264,7 +266,7 @@ fn parallelize_sorts( // For the removal of self node which is also a `CoalescePartitionsExec`. requirements = requirements.children.swap_remove(0); - Ok(Transformed::Yes( + Ok(Transformed::yes( PlanWithCorrespondingCoalescePartitions::new( Arc::new(CoalescePartitionsExec::new(requirements.plan.clone())), false, @@ -272,7 +274,7 @@ fn parallelize_sorts( ), )) } else { - Ok(Transformed::Yes(requirements)) + Ok(Transformed::yes(requirements)) } } @@ -285,10 +287,12 @@ fn ensure_sorting( // Perform naive analysis at the beginning -- remove already-satisfied sorts: if requirements.children.is_empty() { - return Ok(Transformed::No(requirements)); + return Ok(Transformed::no(requirements)); } let maybe_requirements = analyze_immediate_sort_removal(requirements); - let Transformed::No(mut requirements) = maybe_requirements else { + requirements = if !maybe_requirements.transformed { + maybe_requirements.data + } else { return Ok(maybe_requirements); }; @@ -327,17 +331,17 @@ fn ensure_sorting( // calculate the result in reverse: let child_node = &requirements.children[0]; if is_window(plan) && child_node.data { - return adjust_window_sort_removal(requirements).map(Transformed::Yes); + return adjust_window_sort_removal(requirements).map(Transformed::yes); } else if is_sort_preserving_merge(plan) && child_node.plan.output_partitioning().partition_count() <= 1 { // This `SortPreservingMergeExec` is 
unnecessary, input already has a // single partition. let child_node = requirements.children.swap_remove(0); - return Ok(Transformed::Yes(child_node)); + return Ok(Transformed::yes(child_node)); } - update_sort_ctx_children(requirements, false).map(Transformed::Yes) + update_sort_ctx_children(requirements, false).map(Transformed::yes) } /// Analyzes a given [`SortExec`] (`plan`) to determine whether its input @@ -367,10 +371,10 @@ fn analyze_immediate_sort_removal( child.data = false; } node.data = false; - return Transformed::Yes(node); + return Transformed::yes(node); } } - Transformed::No(node) + Transformed::no(node) } /// Adjusts a [`WindowAggExec`] or a [`BoundedWindowAggExec`] to determine @@ -641,7 +645,7 @@ mod tests { { let plan_requirements = PlanWithCorrespondingSort::new_default($PLAN.clone()); let adjusted = plan_requirements - .transform_up(&ensure_sorting) + .transform_up(&ensure_sorting).map(|t| t.data) .and_then(check_integrity)?; // TODO: End state payloads will be checked here. @@ -649,7 +653,7 @@ mod tests { let plan_with_coalesce_partitions = PlanWithCorrespondingCoalescePartitions::new_default(adjusted.plan); let parallel = plan_with_coalesce_partitions - .transform_up(&parallelize_sorts) + .transform_up(&parallelize_sorts).map(|t| t.data) .and_then(check_integrity)?; // TODO: End state payloads will be checked here. parallel.plan @@ -666,14 +670,14 @@ mod tests { true, state.config_options(), ) - }) + }).map(|t| t.data) .and_then(check_integrity)?; // TODO: End state payloads will be checked here. let mut sort_pushdown = SortPushDown::new_default(updated_plan.plan); assign_initial_requirements(&mut sort_pushdown); sort_pushdown - .transform_down(&pushdown_sorts) + .transform_down(&pushdown_sorts).map(|t| t.data) .and_then(check_integrity)?; // TODO: End state payloads will be checked here.
} diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index 02626056f6cc..98a05b5877e0 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -241,7 +241,9 @@ impl PhysicalOptimizerRule for JoinSelection { Box::new(hash_join_convert_symmetric_subrule), Box::new(hash_join_swap_subrule), ]; - let state = pipeline.transform_up(&|p| apply_subrules(p, &subrules, config))?; + let state = pipeline + .transform_up(&|p| apply_subrules(p, &subrules, config))? + .data; // Next, we apply another subrule that tries to optimize joins using any // statistics their inputs might have. // - For a hash join with partition mode [`PartitionMode::Auto`], we will @@ -256,13 +258,16 @@ impl PhysicalOptimizerRule for JoinSelection { let config = &config.optimizer; let collect_threshold_byte_size = config.hash_join_single_partition_threshold; let collect_threshold_num_rows = config.hash_join_single_partition_threshold_rows; - state.plan.transform_up(&|plan| { - statistical_join_selection_subrule( - plan, - collect_threshold_byte_size, - collect_threshold_num_rows, - ) - }) + state + .plan + .transform_up(&|plan| { + statistical_join_selection_subrule( + plan, + collect_threshold_byte_size, + collect_threshold_num_rows, + ) + }) + .map(|t| t.data) } fn name(&self) -> &str { @@ -438,9 +443,9 @@ fn statistical_join_selection_subrule( }; Ok(if let Some(transformed) = transformed { - Transformed::Yes(transformed) + Transformed::yes(transformed) } else { - Transformed::No(plan) + Transformed::no(plan) }) } @@ -671,7 +676,7 @@ fn apply_subrules( // etc. If this doesn't happen, the final `PipelineChecker` rule will // catch this and raise an error anyway. 
.unwrap_or(true); - Ok(Transformed::Yes(input)) + Ok(Transformed::yes(input)) } #[cfg(test)] @@ -836,6 +841,7 @@ mod tests_statistical { ]; let state = pipeline .transform_up(&|p| apply_subrules(p, &subrules, &ConfigOptions::new())) + .map(|t| t.data) .and_then(check_integrity)?; // TODO: End state payloads will be checked here. let config = ConfigOptions::new().optimizer; diff --git a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs index 9855247151b8..caf8b61c5b2c 100644 --- a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs +++ b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs @@ -106,7 +106,7 @@ impl LimitedDistinctAggregation { let mut rewrite_applicable = true; let mut closure = |plan: Arc| { if !rewrite_applicable { - return Ok(Transformed::No(plan)); + return Ok(Transformed::no(plan)); } if let Some(aggr) = plan.as_any().downcast_ref::() { if found_match_aggr { @@ -117,7 +117,7 @@ impl LimitedDistinctAggregation { // a partial and final aggregation with different groupings disqualifies // rewriting the child aggregation rewrite_applicable = false; - return Ok(Transformed::No(plan)); + return Ok(Transformed::no(plan)); } } } @@ -128,14 +128,18 @@ impl LimitedDistinctAggregation { Some(new_aggr) => { match_aggr = plan; found_match_aggr = true; - return Ok(Transformed::Yes(new_aggr)); + return Ok(Transformed::yes(new_aggr)); } } } rewrite_applicable = false; - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) }; - let child = child.clone().transform_down_mut(&mut closure).ok()?; + let child = child + .clone() + .transform_down_mut(&mut closure) + .map(|t| t.data) + .ok()?; if is_global_limit { return Some(Arc::new(GlobalLimitExec::new( child, @@ -165,12 +169,13 @@ impl PhysicalOptimizerRule for LimitedDistinctAggregation { if let Some(plan) = LimitedDistinctAggregation::transform_limit(plan.clone()) { - 
Transformed::Yes(plan) + Transformed::yes(plan) } else { - Transformed::No(plan) + Transformed::no(plan) }, ) })? + .data } else { plan }; diff --git a/datafusion/core/src/physical_optimizer/output_requirements.rs b/datafusion/core/src/physical_optimizer/output_requirements.rs index 4d03840d3dd3..38877d0bab69 100644 --- a/datafusion/core/src/physical_optimizer/output_requirements.rs +++ b/datafusion/core/src/physical_optimizer/output_requirements.rs @@ -196,15 +196,17 @@ impl PhysicalOptimizerRule for OutputRequirements { ) -> Result> { match self.mode { RuleMode::Add => require_top_ordering(plan), - RuleMode::Remove => plan.transform_up(&|plan| { - if let Some(sort_req) = - plan.as_any().downcast_ref::() - { - Ok(Transformed::Yes(sort_req.input())) - } else { - Ok(Transformed::No(plan)) - } - }), + RuleMode::Remove => plan + .transform_up(&|plan| { + if let Some(sort_req) = + plan.as_any().downcast_ref::() + { + Ok(Transformed::yes(sort_req.input())) + } else { + Ok(Transformed::no(plan)) + } + }) + .map(|t| t.data), } } diff --git a/datafusion/core/src/physical_optimizer/pipeline_checker.rs b/datafusion/core/src/physical_optimizer/pipeline_checker.rs index bb0665c10bcc..c09d9ada7def 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_checker.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_checker.rs @@ -53,7 +53,8 @@ impl PhysicalOptimizerRule for PipelineChecker { ) -> Result> { let pipeline = PipelineStatePropagator::new_default(plan); let state = pipeline - .transform_up(&|p| check_finiteness_requirements(p, &config.optimizer))?; + .transform_up(&|p| check_finiteness_requirements(p, &config.optimizer))? 
+ .data; Ok(state.plan) } @@ -93,7 +94,7 @@ pub fn check_finiteness_requirements( .unbounded_output(&children_unbounded(&input)) .map(|value| { input.data = value; - Transformed::Yes(input) + Transformed::yes(input) }) } diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index b2be307c3bd9..d1af2a29cf91 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -74,6 +74,7 @@ impl PhysicalOptimizerRule for ProjectionPushdown { _config: &ConfigOptions, ) -> Result> { plan.transform_down(&remove_unnecessary_projections) + .map(|t| t.data) } fn name(&self) -> &str { @@ -98,7 +99,7 @@ pub fn remove_unnecessary_projections( // If the projection does not cause any change on the input, we can // safely remove it: if is_projection_removable(projection) { - return Ok(Transformed::Yes(projection.input().clone())); + return Ok(Transformed::yes(projection.input().clone())); } // If it does, check if we can push it under its child(ren): let input = projection.input().as_any(); @@ -112,7 +113,7 @@ pub fn remove_unnecessary_projections( // To unify 3 or more sequential projections: remove_unnecessary_projections(new_plan) } else { - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) }; } else if let Some(output_req) = input.downcast_ref::() { try_swapping_with_output_req(projection, output_req)? @@ -148,10 +149,10 @@ pub fn remove_unnecessary_projections( None } } else { - return Ok(Transformed::No(plan)); + return Ok(Transformed::no(plan)); }; - Ok(maybe_modified.map_or(Transformed::No(plan), Transformed::Yes)) + Ok(maybe_modified.map_or(Transformed::no(plan), Transformed::yes)) } /// Tries to embed `projection` to its input (`csv`). 
If possible, returns @@ -896,16 +897,16 @@ fn update_expr( .clone() .transform_up_mut(&mut |expr: Arc| { if state == RewriteState::RewrittenInvalid { - return Ok(Transformed::No(expr)); + return Ok(Transformed::no(expr)); } let Some(column) = expr.as_any().downcast_ref::() else { - return Ok(Transformed::No(expr)); + return Ok(Transformed::no(expr)); }; if sync_with_child { state = RewriteState::RewrittenValid; // Update the index of `column`: - Ok(Transformed::Yes(projected_exprs[column.index()].0.clone())) + Ok(Transformed::yes(projected_exprs[column.index()].0.clone())) } else { // default to invalid, in case we can't find the relevant column state = RewriteState::RewrittenInvalid; @@ -924,11 +925,12 @@ fn update_expr( ) }) .map_or_else( - || Ok(Transformed::No(expr)), - |c| Ok(Transformed::Yes(c)), + || Ok(Transformed::no(expr)), + |c| Ok(Transformed::yes(c)), ) } - }); + }) + .map(|t| t.data); new_expr.map(|e| (state == RewriteState::RewrittenValid).then_some(e)) } @@ -1045,7 +1047,7 @@ fn new_columns_for_join_on( }) .map(|(index, (_, alias))| Column::new(alias, index)); if let Some(new_column) = new_column { - Ok(Transformed::Yes(Arc::new(new_column))) + Ok(Transformed::yes(Arc::new(new_column))) } else { // If the column is not found in the projection expressions, // it means that the column is not projected. 
In this case, @@ -1056,9 +1058,10 @@ fn new_columns_for_join_on( ))) } } else { - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) } }) + .map(|t| t.data) .ok() }) .collect::>(); diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index aa72771b1eb3..ecf4bc0e1b7e 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -840,12 +840,13 @@ fn rewrite_column_expr( e.transform_up(&|expr| { if let Some(column) = expr.as_any().downcast_ref::() { if column == column_old { - return Ok(Transformed::Yes(Arc::new(column_new.clone()))); + return Ok(Transformed::yes(Arc::new(column_new.clone()))); } } - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) }) + .map(|t| t.data) } fn reverse_operator(op: Operator) -> Result { diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index bc9bd0010dc5..4629152cddd9 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -236,7 +236,7 @@ pub(crate) fn replace_with_order_preserving_variants( ) -> Result> { update_children(&mut requirements); if !(is_sort(&requirements.plan) && requirements.children[0].data) { - return Ok(Transformed::No(requirements)); + return Ok(Transformed::no(requirements)); } // For unbounded cases, we replace with the order-preserving variant in any @@ -260,13 +260,13 @@ pub(crate) fn replace_with_order_preserving_variants( for child in alternate_plan.children.iter_mut() { child.data = false; } - Ok(Transformed::Yes(alternate_plan)) + Ok(Transformed::yes(alternate_plan)) } else { // The alternate plan does not help, use faster order-breaking variants: alternate_plan = plan_with_order_breaking_variants(alternate_plan)?; 
alternate_plan.data = false; requirements.children = vec![alternate_plan]; - Ok(Transformed::Yes(requirements)) + Ok(Transformed::yes(requirements)) } } @@ -395,7 +395,7 @@ mod tests { // Run the rule top-down let config = SessionConfig::new().with_prefer_existing_sort($PREFER_EXISTING_SORT); let plan_with_pipeline_fixer = OrderPreservationContext::new_default(physical_plan); - let parallel = plan_with_pipeline_fixer.transform_up(&|plan_with_pipeline_fixer| replace_with_order_preserving_variants(plan_with_pipeline_fixer, false, false, config.options())).and_then(check_integrity)?; + let parallel = plan_with_pipeline_fixer.transform_up(&|plan_with_pipeline_fixer| replace_with_order_preserving_variants(plan_with_pipeline_fixer, false, false, config.options())).map(|t| t.data).and_then(check_integrity)?; let optimized_physical_plan = parallel.plan; // Get string representation of the plan diff --git a/datafusion/core/src/physical_optimizer/sort_pushdown.rs b/datafusion/core/src/physical_optimizer/sort_pushdown.rs index 3413486c6b46..16b96fce7301 100644 --- a/datafusion/core/src/physical_optimizer/sort_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/sort_pushdown.rs @@ -87,7 +87,7 @@ pub(crate) fn pushdown_sorts( } // Can push down requirements child.data = None; - return Ok(Transformed::Yes(child)); + return Ok(Transformed::yes(child)); } else { // Can not push down requirements requirements.children = vec![child]; @@ -112,7 +112,7 @@ pub(crate) fn pushdown_sorts( requirements = add_sort_above(requirements, sort_reqs, None); assign_initial_requirements(&mut requirements); } - Ok(Transformed::Yes(requirements)) + Ok(Transformed::yes(requirements)) } fn pushdown_requirement_to_children( diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index 5de6cff0b4fa..0ab1d4edfe8d 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs 
@@ -374,15 +374,19 @@ pub fn sort_exec( /// TODO: Once [`ExecutionPlan`] implements [`PartialEq`], string comparisons should be /// replaced with direct plan equality checks. pub fn check_integrity(context: PlanContext) -> Result> { - context.transform_up(&|node| { - let children_plans = node.plan.children(); - assert_eq!(node.children.len(), children_plans.len()); - for (child_plan, child_node) in children_plans.iter().zip(node.children.iter()) { - assert_eq!( - displayable(child_plan.as_ref()).one_line().to_string(), - displayable(child_node.plan.as_ref()).one_line().to_string() - ); - } - Ok(Transformed::No(node)) - }) + context + .transform_up(&|node| { + let children_plans = node.plan.children(); + assert_eq!(node.children.len(), children_plans.len()); + for (child_plan, child_node) in + children_plans.iter().zip(node.children.iter()) + { + assert_eq!( + displayable(child_plan.as_ref()).one_line().to_string(), + displayable(child_node.plan.as_ref()).one_line().to_string() + ); + } + Ok(Transformed::no(node)) + }) + .map(|t| t.data) } diff --git a/datafusion/core/src/physical_optimizer/topk_aggregation.rs b/datafusion/core/src/physical_optimizer/topk_aggregation.rs index dd0261420304..245617a4d446 100644 --- a/datafusion/core/src/physical_optimizer/topk_aggregation.rs +++ b/datafusion/core/src/physical_optimizer/topk_aggregation.rs @@ -101,13 +101,13 @@ impl TopKAggregation { let mut cardinality_preserved = true; let mut closure = |plan: Arc| { if !cardinality_preserved { - return Ok(Transformed::No(plan)); + return Ok(Transformed::no(plan)); } if let Some(aggr) = plan.as_any().downcast_ref::() { // either we run into an Aggregate and transform it match Self::transform_agg(aggr, order, limit) { None => cardinality_preserved = false, - Some(plan) => return Ok(Transformed::Yes(plan)), + Some(plan) => return Ok(Transformed::yes(plan)), } } else { // or we continue down whitelisted nodes of other types @@ -115,9 +115,13 @@ impl TopKAggregation { 
cardinality_preserved = false; } } - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) }; - let child = child.clone().transform_down_mut(&mut closure).ok()?; + let child = child + .clone() + .transform_down_mut(&mut closure) + .map(|t| t.data) + .ok()?; let sort = SortExec::new(sort.expr().to_vec(), child) .with_fetch(sort.fetch()) .with_preserve_partitioning(sort.preserve_partitioning()); @@ -141,12 +145,13 @@ impl PhysicalOptimizerRule for TopKAggregation { plan.transform_down(&|plan| { Ok( if let Some(plan) = TopKAggregation::transform_sort(plan.clone()) { - Transformed::Yes(plan) + Transformed::yes(plan) } else { - Transformed::No(plan) + Transformed::no(plan) }, ) })? + .data } else { plan }; diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 3e94ecbd746a..77ad9591d8c2 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1268,8 +1268,9 @@ impl Expr { rewrite_placeholder(low.as_mut(), expr.as_ref(), schema)?; rewrite_placeholder(high.as_mut(), expr.as_ref(), schema)?; } - Ok(Transformed::Yes(expr)) + Ok(Transformed::yes(expr)) }) + .map(|t| t.data) } /// Returns true if some of this `exprs` subexpressions may not be evaluated diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 76bd51619954..c72c0f00a737 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -37,12 +37,13 @@ pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result { Ok({ if let Expr::Column(c) = expr { let col = LogicalPlanBuilder::normalize(plan, c)?; - Transformed::Yes(Expr::Column(col)) + Transformed::yes(Expr::Column(col)) } else { - Transformed::No(expr) + Transformed::no(expr) } }) }) + .map(|t| t.data) } /// Recursively call [`Column::normalize_with_schemas`] on all [`Column`] expressions @@ -61,12 +62,13 @@ pub fn normalize_col_with_schemas( Ok({ if let Expr::Column(c) = expr { let col = c.normalize_with_schemas(schemas, 
using_columns)?; - Transformed::Yes(Expr::Column(col)) + Transformed::yes(Expr::Column(col)) } else { - Transformed::No(expr) + Transformed::no(expr) } }) }) + .map(|t| t.data) } /// See [`Column::normalize_with_schemas_and_ambiguity_check`] for usage @@ -80,12 +82,13 @@ pub fn normalize_col_with_schemas_and_ambiguity_check( if let Expr::Column(c) = expr { let col = c.normalize_with_schemas_and_ambiguity_check(schemas, using_columns)?; - Transformed::Yes(Expr::Column(col)) + Transformed::yes(Expr::Column(col)) } else { - Transformed::No(expr) + Transformed::no(expr) } }) }) + .map(|t| t.data) } /// Recursively normalize all [`Column`] expressions in a list of expression trees @@ -106,14 +109,15 @@ pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Resul Ok({ if let Expr::Column(c) = &expr { match replace_map.get(c) { - Some(new_c) => Transformed::Yes(Expr::Column((*new_c).to_owned())), - None => Transformed::No(expr), + Some(new_c) => Transformed::yes(Expr::Column((*new_c).to_owned())), + None => Transformed::no(expr), } } else { - Transformed::No(expr) + Transformed::no(expr) } }) }) + .map(|t| t.data) } /// Recursively 'unnormalize' (remove all qualifiers) from an @@ -129,12 +133,13 @@ pub fn unnormalize_col(expr: Expr) -> Expr { relation: None, name: c.name, }; - Transformed::Yes(Expr::Column(col)) + Transformed::yes(Expr::Column(col)) } else { - Transformed::No(expr) + Transformed::no(expr) } }) }) + .map(|t| t.data) .expect("Unnormalize is infallable") } @@ -167,12 +172,13 @@ pub fn strip_outer_reference(expr: Expr) -> Expr { expr.transform_up(&|expr| { Ok({ if let Expr::OuterReferenceColumn(_, col) = expr { - Transformed::Yes(Expr::Column(col)) + Transformed::yes(Expr::Column(col)) } else { - Transformed::No(expr) + Transformed::no(expr) } }) }) + .map(|t| t.data) .expect("strip_outer_reference is infallable") } @@ -253,7 +259,7 @@ where R: TreeNodeRewriter, { let original_name = expr.name_for_alias()?; - let expr = 
expr.rewrite(rewriter)?; + let expr = expr.rewrite(rewriter)?.data; expr.alias_if_changed(original_name) } @@ -263,7 +269,7 @@ mod test { use crate::expr::Sort; use crate::{col, lit, Cast}; use arrow::datatypes::DataType; - use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeRewriter}; + use datafusion_common::tree_node::{TreeNode, TreeNodeRewriter}; use datafusion_common::{DFField, DFSchema, ScalarValue}; use std::ops::Add; @@ -275,14 +281,14 @@ mod test { impl TreeNodeRewriter for RecordingRewriter { type Node = Expr; - fn f_down(&mut self, expr: Expr) -> Result<(Expr, TreeNodeRecursion)> { + fn f_down(&mut self, expr: Expr) -> Result> { self.v.push(format!("Previsited {expr}")); - Ok((expr, TreeNodeRecursion::Continue)) + Ok(Transformed::no(expr)) } - fn f_up(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { self.v.push(format!("Mutated {expr}")); - Ok(expr) + Ok(Transformed::no(expr)) } } @@ -297,10 +303,10 @@ mod test { } else { utf8_val }; - Ok(Transformed::Yes(lit(utf8_val))) + Ok(Transformed::yes(lit(utf8_val))) } // otherwise, return None - _ => Ok(Transformed::No(expr)), + _ => Ok(Transformed::no(expr)), } }; @@ -308,6 +314,7 @@ mod test { let rewritten = col("state") .eq(lit("foo")) .transform_up(&transformer) + .map(|t| t.data) .unwrap(); assert_eq!(rewritten, col("state").eq(lit("bar"))); @@ -315,6 +322,7 @@ mod test { let rewritten = col("state") .eq(lit("baz")) .transform_up(&transformer) + .map(|t| t.data) .unwrap(); assert_eq!(rewritten, col("state").eq(lit("baz"))); } @@ -452,8 +460,8 @@ mod test { impl TreeNodeRewriter for TestRewriter { type Node = Expr; - fn f_up(&mut self, _: Expr) -> Result { - Ok(self.rewrite_to.clone()) + fn f_up(&mut self, _: Expr) -> Result> { + Ok(Transformed::yes(self.rewrite_to.clone())) } } diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index 1e7efcafd04d..1cc35a1a4b94 100644 --- 
a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -91,7 +91,7 @@ fn rewrite_in_terms_of_projection( .to_field(input.schema()) .map(|f| f.qualified_column())?, ); - return Ok(Transformed::Yes(col)); + return Ok(Transformed::yes(col)); } // if that doesn't work, try to match the expression as an @@ -103,7 +103,7 @@ fn rewrite_in_terms_of_projection( e } else { // The expr is not based on Aggregate plan output. Skip it. - return Ok(Transformed::No(expr)); + return Ok(Transformed::no(expr)); }; // expr is an actual expr like min(t.c2), but we are looking @@ -118,7 +118,7 @@ fn rewrite_in_terms_of_projection( // look for the column named the same as this expr if let Some(found) = proj_exprs.iter().find(|a| expr_match(&search_col, a)) { let found = found.clone(); - return Ok(Transformed::Yes(match normalized_expr { + return Ok(Transformed::yes(match normalized_expr { Expr::Cast(Cast { expr: _, data_type }) => Expr::Cast(Cast { expr: Box::new(found), data_type, @@ -131,8 +131,9 @@ fn rewrite_in_terms_of_projection( })); } - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) }) + .map(|t| t.data) } /// Does the underlying expr match e? diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 80ce38fe9389..f7b035609e05 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -647,29 +647,29 @@ impl LogicalPlan { // Decimal128(Some(69999999999999),30,15) // AND lineitem.l_quantity < Decimal128(Some(2400),15,2) - fn unalias_down( - expr: Expr, - ) -> Result<(Transformed, TreeNodeRecursion)> { + fn unalias_down(expr: Expr) -> Result> { match expr { Expr::Exists { .. 
} | Expr::ScalarSubquery(_) | Expr::InSubquery(_) => { // subqueries could contain aliases so we don't recurse into those - Ok((Transformed::No(expr), TreeNodeRecursion::Skip)) + Ok(Transformed::new(expr, false, TreeNodeRecursion::Skip)) } - Expr::Alias(_) => Ok(( - Transformed::Yes(expr.unalias()), + Expr::Alias(_) => Ok(Transformed::new( + expr.unalias(), + true, TreeNodeRecursion::Skip, )), - _ => Ok((Transformed::No(expr), TreeNodeRecursion::Continue)), + _ => Ok(Transformed::no(expr)), } } - fn dummy_up(expr: Expr) -> Result { - Ok(expr) + fn dummy_up(expr: Expr) -> Result> { + Ok(Transformed::no(expr)) } - let predicate = predicate.transform(&mut unalias_down, &mut dummy_up)?; + let predicate = + predicate.transform(&mut unalias_down, &mut dummy_up)?.data; Filter::try_new(predicate, Arc::new(inputs[0].clone())) .map(LogicalPlan::Filter) @@ -1243,19 +1243,20 @@ impl LogicalPlan { Expr::Placeholder(Placeholder { id, .. }) => { let value = param_values.get_placeholders_with_values(id)?; // Replace the placeholder with the value - Ok(Transformed::Yes(Expr::Literal(value))) + Ok(Transformed::yes(Expr::Literal(value))) } Expr::ScalarSubquery(qry) => { let subquery = Arc::new(qry.subquery.replace_params_with_values(param_values)?); - Ok(Transformed::Yes(Expr::ScalarSubquery(Subquery { + Ok(Transformed::yes(Expr::ScalarSubquery(Subquery { subquery, outer_ref_columns: qry.outer_ref_columns.clone(), }))) } - _ => Ok(Transformed::No(expr)), + _ => Ok(Transformed::no(expr)), } }) + .map(|t| t.data) } } @@ -3310,10 +3311,11 @@ digraph { Arc::new(LogicalPlan::TableScan(table)), ) .unwrap(); - Ok(Transformed::Yes(LogicalPlan::Filter(filter))) + Ok(Transformed::yes(LogicalPlan::Filter(filter))) } - x => Ok(Transformed::No(x)), + x => Ok(Transformed::no(x)), }) + .map(|t| t.data) .unwrap(); let expected = "Explain\ diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index f2b0b4c2d266..5e7dd1990923 100644 --- 
a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -24,7 +24,9 @@ use crate::expr::{ }; use crate::{Expr, GetFieldAccess}; -use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; +use datafusion_common::tree_node::{ + Transformed, TransformedIterator, TreeNode, TreeNodeRecursion, +}; use datafusion_common::{handle_tree_recursion, internal_err, DataFusionError, Result}; impl TreeNode for Expr { @@ -135,10 +137,10 @@ impl TreeNode for Expr { Ok(TreeNodeRecursion::Continue) } - fn map_children Result>( - self, - mut transform: F, - ) -> Result { + fn map_children(self, mut f: F) -> Result> + where + F: FnMut(Self) -> Result>, + { Ok(match self { Expr::Column(_) | Expr::Wildcard { .. } @@ -147,27 +149,28 @@ impl TreeNode for Expr { | Expr::Exists { .. } | Expr::ScalarSubquery(_) | Expr::ScalarVariable(_, _) - | Expr::Literal(_) => self, + | Expr::Literal(_) => Transformed::no(self), Expr::Alias(Alias { expr, relation, name, - }) => Expr::Alias(Alias::new(transform(*expr)?, relation, name)), + }) => f(*expr)?.map_data(|e| Expr::Alias(Alias::new(e, relation, name))), Expr::InSubquery(InSubquery { expr, subquery, negated, - }) => Expr::InSubquery(InSubquery::new( - transform_boxed(expr, &mut transform)?, - subquery, - negated, - )), + }) => transform_box(expr, &mut f)? + .map_data(|be| Expr::InSubquery(InSubquery::new(be, subquery, negated))), Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - Expr::BinaryExpr(BinaryExpr::new( - transform_boxed(left, &mut transform)?, - op, - transform_boxed(right, &mut transform)?, - )) + transform_box(left, &mut f)? + .map_data(|new_left| (new_left, right)) + .and_then_transform_sibling(|(new_left, right)| { + Ok(transform_box(right, &mut f)? + .map_data(|new_right| (new_left, new_right))) + })? 
+ .map_data(|(new_left, new_right)| { + Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right)) + }) } Expr::Like(Like { negated, @@ -175,213 +178,281 @@ impl TreeNode for Expr { pattern, escape_char, case_insensitive, - }) => Expr::Like(Like::new( - negated, - transform_boxed(expr, &mut transform)?, - transform_boxed(pattern, &mut transform)?, - escape_char, - case_insensitive, - )), + }) => transform_box(expr, &mut f)? + .map_data(|new_expr| (new_expr, pattern)) + .and_then_transform_sibling(|(new_expr, pattern)| { + Ok(transform_box(pattern, &mut f)? + .map_data(|new_pattern| (new_expr, new_pattern))) + })? + .map_data(|(new_expr, new_pattern)| { + Expr::Like(Like::new( + negated, + new_expr, + new_pattern, + escape_char, + case_insensitive, + )) + }), Expr::SimilarTo(Like { negated, expr, pattern, escape_char, case_insensitive, - }) => Expr::SimilarTo(Like::new( - negated, - transform_boxed(expr, &mut transform)?, - transform_boxed(pattern, &mut transform)?, - escape_char, - case_insensitive, - )), - Expr::Not(expr) => Expr::Not(transform_boxed(expr, &mut transform)?), + }) => transform_box(expr, &mut f)? + .map_data(|new_expr| (new_expr, pattern)) + .and_then_transform_sibling(|(new_expr, pattern)| { + Ok(transform_box(pattern, &mut f)? + .map_data(|new_pattern| (new_expr, new_pattern))) + })? + .map_data(|(new_expr, new_pattern)| { + Expr::SimilarTo(Like::new( + negated, + new_expr, + new_pattern, + escape_char, + case_insensitive, + )) + }), + Expr::Not(expr) => transform_box(expr, &mut f)?.map_data(|be| Expr::Not(be)), Expr::IsNotNull(expr) => { - Expr::IsNotNull(transform_boxed(expr, &mut transform)?) 
+ transform_box(expr, &mut f)?.map_data(|be| Expr::IsNotNull(be)) + } + Expr::IsNull(expr) => { + transform_box(expr, &mut f)?.map_data(|be| Expr::IsNull(be)) + } + Expr::IsTrue(expr) => { + transform_box(expr, &mut f)?.map_data(|be| Expr::IsTrue(be)) + } + Expr::IsFalse(expr) => { + transform_box(expr, &mut f)?.map_data(|be| Expr::IsFalse(be)) } - Expr::IsNull(expr) => Expr::IsNull(transform_boxed(expr, &mut transform)?), - Expr::IsTrue(expr) => Expr::IsTrue(transform_boxed(expr, &mut transform)?), - Expr::IsFalse(expr) => Expr::IsFalse(transform_boxed(expr, &mut transform)?), Expr::IsUnknown(expr) => { - Expr::IsUnknown(transform_boxed(expr, &mut transform)?) + transform_box(expr, &mut f)?.map_data(|be| Expr::IsUnknown(be)) } Expr::IsNotTrue(expr) => { - Expr::IsNotTrue(transform_boxed(expr, &mut transform)?) + transform_box(expr, &mut f)?.map_data(|be| Expr::IsNotTrue(be)) } Expr::IsNotFalse(expr) => { - Expr::IsNotFalse(transform_boxed(expr, &mut transform)?) + transform_box(expr, &mut f)?.map_data(|be| Expr::IsNotFalse(be)) } Expr::IsNotUnknown(expr) => { - Expr::IsNotUnknown(transform_boxed(expr, &mut transform)?) + transform_box(expr, &mut f)?.map_data(|be| Expr::IsNotUnknown(be)) } Expr::Negative(expr) => { - Expr::Negative(transform_boxed(expr, &mut transform)?) + transform_box(expr, &mut f)?.map_data(|be| Expr::Negative(be)) } Expr::Between(Between { expr, negated, low, high, - }) => Expr::Between(Between::new( - transform_boxed(expr, &mut transform)?, - negated, - transform_boxed(low, &mut transform)?, - transform_boxed(high, &mut transform)?, - )), - Expr::Case(case) => { - let expr = transform_option_box(case.expr, &mut transform)?; - let when_then_expr = case - .when_then_expr - .into_iter() - .map(|(when, then)| { - Ok(( - transform_boxed(when, &mut transform)?, - transform_boxed(then, &mut transform)?, + }) => transform_box(expr, &mut f)? 
+ .map_data(|new_expr| (new_expr, low, high)) + .and_then_transform_sibling(|(new_expr, low, high)| { + Ok(transform_box(low, &mut f)? + .map_data(|new_low| (new_expr, new_low, high))) + })? + .and_then_transform_sibling(|(new_expr, new_low, high)| { + Ok(transform_box(high, &mut f)? + .map_data(|new_high| (new_expr, new_low, new_high))) + })? + .map_data(|(new_expr, new_low, new_high)| { + Expr::Between(Between::new(new_expr, negated, new_low, new_high)) + }), + Expr::Case(Case { + expr, + when_then_expr, + else_expr, + }) => transform_option_box(expr, &mut f)? + .map_data(|new_expr| (new_expr, when_then_expr, else_expr)) + .and_then_transform_sibling(|(new_expr, when_then_expr, else_expr)| { + Ok(when_then_expr + .into_iter() + .map_till_continue_and_collect(|(when, then)| { + transform_box(when, &mut f)? + .map_data(|new_when| (new_when, then)) + .and_then_transform_sibling(|(new_when, then)| { + Ok(transform_box(then, &mut f)? + .map_data(|new_then| (new_when, new_then))) + }) + })? + .map_data(|new_when_then_expr| { + (new_expr, new_when_then_expr, else_expr) + })) + })? + .and_then_transform_sibling( + |(new_expr, new_when_then_expr, else_expr)| { + Ok(transform_option_box(else_expr, &mut f)?.map_data( + |new_else_expr| (new_expr, new_when_then_expr, new_else_expr), )) - }) - .collect::>>()?; - let else_expr = transform_option_box(case.else_expr, &mut transform)?; - - Expr::Case(Case::new(expr, when_then_expr, else_expr)) - } - Expr::Cast(Cast { expr, data_type }) => { - Expr::Cast(Cast::new(transform_boxed(expr, &mut transform)?, data_type)) - } - Expr::TryCast(TryCast { expr, data_type }) => Expr::TryCast(TryCast::new( - transform_boxed(expr, &mut transform)?, - data_type, - )), + }, + )? + .map_data(|(new_expr, new_when_then_expr, new_else_expr)| { + Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr)) + }), + Expr::Cast(Cast { expr, data_type }) => transform_box(expr, &mut f)? 
+ .map_data(|be| Expr::Cast(Cast::new(be, data_type))), + Expr::TryCast(TryCast { expr, data_type }) => transform_box(expr, &mut f)? + .map_data(|be| Expr::TryCast(TryCast::new(be, data_type))), Expr::Sort(Sort { expr, asc, nulls_first, - }) => Expr::Sort(Sort::new( - transform_boxed(expr, &mut transform)?, - asc, - nulls_first, - )), - Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { - ScalarFunctionDefinition::BuiltIn(fun) => Expr::ScalarFunction( - ScalarFunction::new(fun, transform_vec(args, &mut transform)?), - ), - ScalarFunctionDefinition::UDF(fun) => Expr::ScalarFunction( - ScalarFunction::new_udf(fun, transform_vec(args, &mut transform)?), - ), - ScalarFunctionDefinition::Name(_) => { - return internal_err!("Function `Expr` with name should be resolved.") - } - }, + }) => transform_box(expr, &mut f)? + .map_data(|be| Expr::Sort(Sort::new(be, asc, nulls_first))), + Expr::ScalarFunction(ScalarFunction { func_def, args }) => { + transform_vec(args, &mut f)?.flat_map_data(|new_args| match func_def { + ScalarFunctionDefinition::BuiltIn(fun) => { + Ok(Expr::ScalarFunction(ScalarFunction::new(fun, new_args))) + } + ScalarFunctionDefinition::UDF(fun) => { + Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, new_args))) + } + ScalarFunctionDefinition::Name(_) => { + return internal_err!( + "Function `Expr` with name should be resolved." + ); + } + })? + } Expr::WindowFunction(WindowFunction { args, fun, partition_by, order_by, window_frame, - }) => Expr::WindowFunction(WindowFunction::new( - fun, - transform_vec(args, &mut transform)?, - transform_vec(partition_by, &mut transform)?, - transform_vec(order_by, &mut transform)?, - window_frame, - )), + }) => transform_vec(args, &mut f)? + .map_data(|new_args| (new_args, partition_by, order_by)) + .and_then_transform_sibling(|(new_args, partition_by, order_by)| { + Ok(transform_vec(partition_by, &mut f)?.map_data( + |new_partition_by| (new_args, new_partition_by, order_by), + )) + })? 
+ .and_then_transform_sibling(|(new_args, new_partition_by, order_by)| { + Ok(transform_vec(order_by, &mut f)?.map_data(|new_order_by| { + (new_args, new_partition_by, new_order_by) + })) + })? + .map_data(|(new_args, new_partition_by, new_order_by)| { + Expr::WindowFunction(WindowFunction::new( + fun, + new_args, + new_partition_by, + new_order_by, + window_frame, + )) + }), Expr::AggregateFunction(AggregateFunction { args, func_def, distinct, filter, order_by, - }) => match func_def { - AggregateFunctionDefinition::BuiltIn(fun) => { - Expr::AggregateFunction(AggregateFunction::new( - fun, - transform_vec(args, &mut transform)?, - distinct, - transform_option_box(filter, &mut transform)?, - transform_option_vec(order_by, &mut transform)?, - )) - } - AggregateFunctionDefinition::UDF(fun) => { - let order_by = order_by - .map(|order_by| transform_vec(order_by, &mut transform)) - .transpose()?; - Expr::AggregateFunction(AggregateFunction::new_udf( - fun, - transform_vec(args, &mut transform)?, - false, - transform_option_box(filter, &mut transform)?, - transform_option_vec(order_by, &mut transform)?, - )) - } - AggregateFunctionDefinition::Name(_) => { - return internal_err!("Function `Expr` with name should be resolved.") - } - }, + }) => transform_vec(args, &mut f)? + .map_data(|new_args| (new_args, filter, order_by)) + .and_then_transform_sibling(|(new_args, filter, order_by)| { + Ok(transform_option_box(filter, &mut f)? + .map_data(|new_filter| (new_args, new_filter, order_by))) + })? + .and_then_transform_sibling(|(new_args, new_filter, order_by)| { + Ok(transform_option_vec(order_by, &mut f)? + .map_data(|new_order_by| (new_args, new_filter, new_order_by))) + })? 
+ .flat_map_data(|(new_args, new_filter, new_order_by)| match func_def { + AggregateFunctionDefinition::BuiltIn(fun) => { + Ok(Expr::AggregateFunction(AggregateFunction::new( + fun, + new_args, + distinct, + new_filter, + new_order_by, + ))) + } + AggregateFunctionDefinition::UDF(fun) => { + Ok(Expr::AggregateFunction(AggregateFunction::new_udf( + fun, + new_args, + false, + new_filter, + new_order_by, + ))) + } + AggregateFunctionDefinition::Name(_) => { + return internal_err!( + "Function `Expr` with name should be resolved." + ); + } + })?, Expr::GroupingSet(grouping_set) => match grouping_set { - GroupingSet::Rollup(exprs) => Expr::GroupingSet(GroupingSet::Rollup( - transform_vec(exprs, &mut transform)?, - )), - GroupingSet::Cube(exprs) => Expr::GroupingSet(GroupingSet::Cube( - transform_vec(exprs, &mut transform)?, - )), - GroupingSet::GroupingSets(lists_of_exprs) => { - Expr::GroupingSet(GroupingSet::GroupingSets( - lists_of_exprs - .into_iter() - .map(|exprs| transform_vec(exprs, &mut transform)) - .collect::>>()?, - )) - } + GroupingSet::Rollup(exprs) => transform_vec(exprs, &mut f)? + .map_data(|ve| Expr::GroupingSet(GroupingSet::Rollup(ve))), + GroupingSet::Cube(exprs) => transform_vec(exprs, &mut f)? + .map_data(|ve| Expr::GroupingSet(GroupingSet::Cube(ve))), + GroupingSet::GroupingSets(lists_of_exprs) => lists_of_exprs + .into_iter() + .map_till_continue_and_collect(|exprs| transform_vec(exprs, &mut f))? + .map_data(|new_lists_of_exprs| { + Expr::GroupingSet(GroupingSet::GroupingSets(new_lists_of_exprs)) + }), }, Expr::InList(InList { expr, list, negated, - }) => Expr::InList(InList::new( - transform_boxed(expr, &mut transform)?, - transform_vec(list, &mut transform)?, - negated, - )), + }) => transform_box(expr, &mut f)? + .map_data(|new_expr| (new_expr, list)) + .and_then_transform_sibling(|(new_expr, list)| { + Ok(transform_vec(list, &mut f)? + .map_data(|new_list| (new_expr, new_list))) + })? 
+ .map_data(|(new_expr, new_list)| { + Expr::InList(InList::new(new_expr, new_list, negated)) + }), Expr::GetIndexedField(GetIndexedField { expr, field }) => { - Expr::GetIndexedField(GetIndexedField::new( - transform_boxed(expr, &mut transform)?, - field, - )) + transform_box(expr, &mut f)? + .map_data(|be| Expr::GetIndexedField(GetIndexedField::new(be, field))) } }) } } -fn transform_boxed Result>( - boxed_expr: Box, - transform: &mut F, -) -> Result> { - // TODO: It might be possible to avoid an allocation (the Box::new) below by reusing the box. - transform(*boxed_expr).map(Box::new) +fn transform_box(be: Box, f: &mut F) -> Result>> +where + F: FnMut(Expr) -> Result>, +{ + Ok(f(*be)?.map_data(Box::new)) } -fn transform_option_box Result>( - option_box: Option>, - transform: &mut F, -) -> Result>> { - option_box - .map(|expr| transform_boxed(expr, transform)) - .transpose() +fn transform_option_box( + obe: Option>, + f: &mut F, +) -> Result>>> +where + F: FnMut(Expr) -> Result>, +{ + obe.map_or(Ok(Transformed::no(None)), |be| { + Ok(transform_box(be, f)?.map_data(Some)) + }) } /// &mut transform a Option<`Vec` of `Expr`s> -fn transform_option_vec Result>( - option_box: Option>, - transform: &mut F, -) -> Result>> { - option_box - .map(|exprs| transform_vec(exprs, transform)) - .transpose() +fn transform_option_vec( + ove: Option>, + f: &mut F, +) -> Result>>> +where + F: FnMut(Expr) -> Result>, +{ + ove.map_or(Ok(Transformed::no(None)), |ve| { + Ok(transform_vec(ve, f)?.map_data(Some)) + }) } /// &mut transform a `Vec` of `Expr`s -fn transform_vec Result>( - v: Vec, - transform: &mut F, -) -> Result> { - v.into_iter().map(transform).collect() +fn transform_vec(ve: Vec, f: &mut F) -> Result>> +where + F: FnMut(Expr) -> Result>, +{ + ve.into_iter().map_till_continue_and_collect(f) } diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs index 8be24638c1cc..64e678344ea4 100644 --- a/datafusion/expr/src/tree_node/plan.rs +++ 
b/datafusion/expr/src/tree_node/plan.rs @@ -19,7 +19,9 @@ use crate::LogicalPlan; -use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}; +use datafusion_common::tree_node::{ + Transformed, TransformedIterator, TreeNode, TreeNodeRecursion, TreeNodeVisitor, +}; use datafusion_common::{handle_tree_recursion, Result}; impl TreeNode for LogicalPlan { @@ -76,26 +78,28 @@ impl TreeNode for LogicalPlan { Ok(TreeNodeRecursion::Continue) } - fn map_children(self, transform: F) -> Result + fn map_children(self, f: F) -> Result> where - F: FnMut(Self) -> Result, + F: FnMut(Self) -> Result>, { let old_children = self.inputs(); - let new_children = old_children + let t = old_children .iter() - .map(|&c| c.clone()) - .map(transform) - .collect::>>()?; - - // if any changes made, make a new child + .map(|c| (*c).clone()) + .map_till_continue_and_collect(f)?; + // TODO: once we trust `t.transformed` remove additional check if old_children .iter() - .zip(new_children.iter()) + .zip(t.data.iter()) .any(|(c1, c2)| c1 != &c2) { - self.with_new_exprs(self.expressions(), new_children.as_slice()) + Ok(Transformed::new( + self.with_new_exprs(self.expressions(), t.data.as_slice())?, + true, + t.tnr, + )) } else { - Ok(self) + Ok(Transformed::new(self, false, t.tnr)) } } } diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 90046ca2aac0..4b6c355cb7e9 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -43,7 +43,7 @@ impl CountWildcardRule { impl AnalyzerRule for CountWildcardRule { fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { - plan.transform_down(&analyze_internal) + plan.transform_down(&analyze_internal).map(|t| t.data) } fn name(&self) -> &str { @@ -61,7 +61,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { .map(|expr| rewrite_preserving_name(expr.clone(), &mut 
rewriter)) .collect::>>()?; - Ok(Transformed::Yes( + Ok(Transformed::yes( LogicalPlanBuilder::from((*window.input).clone()) .window(window_expr)? .build()?, @@ -74,7 +74,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter)) .collect::>>()?; - Ok(Transformed::Yes(LogicalPlan::Aggregate( + Ok(Transformed::yes(LogicalPlan::Aggregate( Aggregate::try_new(agg.input.clone(), agg.group_expr, aggr_expr)?, ))) } @@ -83,7 +83,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { .iter() .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter)) .collect::>>()?; - Ok(Transformed::Yes(LogicalPlan::Sort(Sort { + Ok(Transformed::yes(LogicalPlan::Sort(Sort { expr: sort_expr, input, fetch, @@ -95,7 +95,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { .iter() .map(|expr| rewrite_preserving_name(expr.clone(), &mut rewriter)) .collect::>>()?; - Ok(Transformed::Yes(LogicalPlan::Projection( + Ok(Transformed::yes(LogicalPlan::Projection( Projection::try_new(projection_expr, projection.input)?, ))) } @@ -103,12 +103,12 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { predicate, input, .. 
}) => { let predicate = rewrite_preserving_name(predicate, &mut rewriter)?; - Ok(Transformed::Yes(LogicalPlan::Filter(Filter::try_new( + Ok(Transformed::yes(LogicalPlan::Filter(Filter::try_new( predicate, input, )?))) } - _ => Ok(Transformed::No(plan)), + _ => Ok(Transformed::no(plan)), } } @@ -117,8 +117,8 @@ struct CountWildcardRewriter {} impl TreeNodeRewriter for CountWildcardRewriter { type Node = Expr; - fn f_up(&mut self, old_expr: Expr) -> Result { - let new_expr = match old_expr.clone() { + fn f_up(&mut self, old_expr: Expr) -> Result> { + Ok(match old_expr.clone() { Expr::WindowFunction(expr::WindowFunction { fun: expr::WindowFunctionDefinition::AggregateFunction( @@ -130,7 +130,7 @@ impl TreeNodeRewriter for CountWildcardRewriter { window_frame, }) if args.len() == 1 => match args[0] { Expr::Wildcard { qualifier: None } => { - Expr::WindowFunction(expr::WindowFunction { + Transformed::yes(Expr::WindowFunction(expr::WindowFunction { fun: expr::WindowFunctionDefinition::AggregateFunction( aggregate_function::AggregateFunction::Count, ), @@ -138,10 +138,10 @@ impl TreeNodeRewriter for CountWildcardRewriter { partition_by, order_by, window_frame, - }) + })) } - _ => old_expr, + _ => Transformed::no(old_expr), }, Expr::AggregateFunction(AggregateFunction { func_def: @@ -154,68 +154,65 @@ impl TreeNodeRewriter for CountWildcardRewriter { order_by, }) if args.len() == 1 => match args[0] { Expr::Wildcard { qualifier: None } => { - Expr::AggregateFunction(AggregateFunction::new( + Transformed::yes(Expr::AggregateFunction(AggregateFunction::new( aggregate_function::AggregateFunction::Count, vec![lit(COUNT_STAR_EXPANSION)], distinct, filter, order_by, - )) + ))) } - _ => old_expr, + _ => Transformed::no(old_expr), }, ScalarSubquery(Subquery { subquery, outer_ref_columns, - }) => { - let new_plan = subquery - .as_ref() - .clone() - .transform_down(&analyze_internal)?; - ScalarSubquery(Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns, - }) - } + }) => 
subquery + .as_ref() + .clone() + .transform_down(&analyze_internal)? + .map_data(|new_plan| { + ScalarSubquery(Subquery { + subquery: Arc::new(new_plan), + outer_ref_columns, + }) + }), Expr::InSubquery(InSubquery { expr, subquery, negated, - }) => { - let new_plan = subquery - .subquery - .as_ref() - .clone() - .transform_down(&analyze_internal)?; - - Expr::InSubquery(InSubquery::new( - expr, - Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns: subquery.outer_ref_columns, - }, - negated, - )) - } - Expr::Exists(expr::Exists { subquery, negated }) => { - let new_plan = subquery - .subquery - .as_ref() - .clone() - .transform_down(&analyze_internal)?; - - Expr::Exists(expr::Exists { - subquery: Subquery { - subquery: Arc::new(new_plan), - outer_ref_columns: subquery.outer_ref_columns, - }, - negated, - }) - } - _ => old_expr, - }; - Ok(new_expr) + }) => subquery + .subquery + .as_ref() + .clone() + .transform_down(&analyze_internal)? + .map_data(|new_plan| { + Expr::InSubquery(InSubquery::new( + expr, + Subquery { + subquery: Arc::new(new_plan), + outer_ref_columns: subquery.outer_ref_columns, + }, + negated, + )) + }), + Expr::Exists(expr::Exists { subquery, negated }) => subquery + .subquery + .as_ref() + .clone() + .transform_down(&analyze_internal)? 
+ .map_data(|new_plan| { + Expr::Exists(expr::Exists { + subquery: Subquery { + subquery: Arc::new(new_plan), + outer_ref_columns: subquery.outer_ref_columns, + }, + negated, + }) + }), + _ => Transformed::no(old_expr), + }) } } diff --git a/datafusion/optimizer/src/analyzer/inline_table_scan.rs b/datafusion/optimizer/src/analyzer/inline_table_scan.rs index a418fbf5537b..36f0c3318371 100644 --- a/datafusion/optimizer/src/analyzer/inline_table_scan.rs +++ b/datafusion/optimizer/src/analyzer/inline_table_scan.rs @@ -42,7 +42,7 @@ impl InlineTableScan { impl AnalyzerRule for InlineTableScan { fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { - plan.transform_up(&analyze_internal) + plan.transform_up(&analyze_internal).map(|t| t.data) } fn name(&self) -> &str { @@ -71,16 +71,16 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { // that reference this table. .alias(table_name)? .build()?; - Transformed::Yes(plan) + Transformed::yes(plan) } LogicalPlan::Filter(filter) => { - let new_expr = filter.predicate.transform_up(&rewrite_subquery)?; - Transformed::Yes(LogicalPlan::Filter(Filter::try_new( + let new_expr = filter.predicate.transform_up(&rewrite_subquery)?.data; + Transformed::yes(LogicalPlan::Filter(Filter::try_new( new_expr, filter.input, )?)) } - _ => Transformed::No(plan), + _ => Transformed::no(plan), }) } @@ -88,9 +88,9 @@ fn rewrite_subquery(expr: Expr) -> Result> { match expr { Expr::Exists(Exists { subquery, negated }) => { let plan = subquery.subquery.as_ref().clone(); - let new_plan = plan.transform_up(&analyze_internal)?; + let new_plan = plan.transform_up(&analyze_internal)?.data; let subquery = subquery.with_plan(Arc::new(new_plan)); - Ok(Transformed::Yes(Expr::Exists(Exists { subquery, negated }))) + Ok(Transformed::yes(Expr::Exists(Exists { subquery, negated }))) } Expr::InSubquery(InSubquery { expr, @@ -98,19 +98,19 @@ fn rewrite_subquery(expr: Expr) -> Result> { negated, }) => { let plan = subquery.subquery.as_ref().clone(); - 
let new_plan = plan.transform_up(&analyze_internal)?; + let new_plan = plan.transform_up(&analyze_internal)?.data; let subquery = subquery.with_plan(Arc::new(new_plan)); - Ok(Transformed::Yes(Expr::InSubquery(InSubquery::new( + Ok(Transformed::yes(Expr::InSubquery(InSubquery::new( expr, subquery, negated, )))) } Expr::ScalarSubquery(subquery) => { let plan = subquery.subquery.as_ref().clone(); - let new_plan = plan.transform_up(&analyze_internal)?; + let new_plan = plan.transform_up(&analyze_internal)?.data; let subquery = subquery.with_plan(Arc::new(new_plan)); - Ok(Transformed::Yes(Expr::ScalarSubquery(subquery))) + Ok(Transformed::yes(Expr::ScalarSubquery(subquery))) } - _ => Ok(Transformed::No(expr)), + _ => Ok(Transformed::no(expr)), } } diff --git a/datafusion/optimizer/src/analyzer/rewrite_expr.rs b/datafusion/optimizer/src/analyzer/rewrite_expr.rs index 829197b4d948..8f5f1f4292ca 100644 --- a/datafusion/optimizer/src/analyzer/rewrite_expr.rs +++ b/datafusion/optimizer/src/analyzer/rewrite_expr.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::TreeNodeRewriter; +use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::utils::list_ndims; use datafusion_common::DFSchema; use datafusion_common::DFSchemaRef; @@ -96,8 +96,8 @@ pub(crate) struct OperatorToFunctionRewriter { impl TreeNodeRewriter for OperatorToFunctionRewriter { type Node = Expr; - fn f_up(&mut self, expr: Expr) -> Result { - match expr { + fn f_up(&mut self, expr: Expr) -> Result> { + Ok(match expr { Expr::BinaryExpr(BinaryExpr { ref left, op, @@ -119,16 +119,16 @@ impl TreeNodeRewriter for OperatorToFunctionRewriter { // Convert &Box -> Expr let left = (**left).clone(); let right = (**right).clone(); - return Ok(Expr::ScalarFunction(ScalarFunction { + return Ok(Transformed::yes(Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(fun), args: vec![left, right], - 
})); + }))); } - Ok(expr) + Transformed::no(expr) } - _ => Ok(expr), - } + _ => Transformed::no(expr), + }) } } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 14e15f71b18b..7e5b8de9beae 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use arrow::datatypes::{DataType, IntervalUnit}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::TreeNodeRewriter; +use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::{ exec_err, internal_err, plan_datafusion_err, plan_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, @@ -128,27 +128,27 @@ pub(crate) struct TypeCoercionRewriter { impl TreeNodeRewriter for TypeCoercionRewriter { type Node = Expr; - fn f_up(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { match expr { Expr::ScalarSubquery(Subquery { subquery, outer_ref_columns, }) => { let new_plan = analyze_internal(&self.schema, &subquery)?; - Ok(Expr::ScalarSubquery(Subquery { + Ok(Transformed::yes(Expr::ScalarSubquery(Subquery { subquery: Arc::new(new_plan), outer_ref_columns, - })) + }))) } Expr::Exists(Exists { subquery, negated }) => { let new_plan = analyze_internal(&self.schema, &subquery.subquery)?; - Ok(Expr::Exists(Exists { + Ok(Transformed::yes(Expr::Exists(Exists { subquery: Subquery { subquery: Arc::new(new_plan), outer_ref_columns: subquery.outer_ref_columns, }, negated, - })) + }))) } Expr::InSubquery(InSubquery { expr, @@ -166,42 +166,34 @@ impl TreeNodeRewriter for TypeCoercionRewriter { subquery: Arc::new(new_plan), outer_ref_columns: subquery.outer_ref_columns, }; - Ok(Expr::InSubquery(InSubquery::new( + Ok(Transformed::yes(Expr::InSubquery(InSubquery::new( Box::new(expr.cast_to(&common_type, &self.schema)?), cast_subquery(new_subquery, &common_type)?, negated, - ))) 
- } - Expr::Not(expr) => { - let expr = not(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsTrue(expr) => { - let expr = is_true(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsNotTrue(expr) => { - let expr = is_not_true(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsFalse(expr) => { - let expr = is_false(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsNotFalse(expr) => { - let expr = - is_not_false(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsUnknown(expr) => { - let expr = is_unknown(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) - } - Expr::IsNotUnknown(expr) => { - let expr = - is_not_unknown(get_casted_expr_for_bool_op(&expr, &self.schema)?); - Ok(expr) + )))) } + Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op( + &expr, + &self.schema, + )?))), + Expr::IsTrue(expr) => Ok(Transformed::yes(is_true( + get_casted_expr_for_bool_op(&expr, &self.schema)?, + ))), + Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true( + get_casted_expr_for_bool_op(&expr, &self.schema)?, + ))), + Expr::IsFalse(expr) => Ok(Transformed::yes(is_false( + get_casted_expr_for_bool_op(&expr, &self.schema)?, + ))), + Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false( + get_casted_expr_for_bool_op(&expr, &self.schema)?, + ))), + Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown( + get_casted_expr_for_bool_op(&expr, &self.schema)?, + ))), + Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown( + get_casted_expr_for_bool_op(&expr, &self.schema)?, + ))), Expr::Like(Like { negated, expr, @@ -223,14 +215,13 @@ impl TreeNodeRewriter for TypeCoercionRewriter { })?; let expr = Box::new(expr.cast_to(&coerced_type, &self.schema)?); let pattern = Box::new(pattern.cast_to(&coerced_type, &self.schema)?); - let expr = Expr::Like(Like::new( + Ok(Transformed::yes(Expr::Like(Like::new( negated, 
expr, pattern, escape_char, case_insensitive, - )); - Ok(expr) + )))) } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { let (left_type, right_type) = get_input_types( @@ -238,12 +229,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &op, &right.get_type(&self.schema)?, )?; - - Ok(Expr::BinaryExpr(BinaryExpr::new( + Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new( Box::new(left.cast_to(&left_type, &self.schema)?), op, Box::new(right.cast_to(&right_type, &self.schema)?), - ))) + )))) } Expr::Between(Between { expr, @@ -273,13 +263,12 @@ impl TreeNodeRewriter for TypeCoercionRewriter { "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression" )) })?; - let expr = Expr::Between(Between::new( + Ok(Transformed::yes(Expr::Between(Between::new( Box::new(expr.cast_to(&coercion_type, &self.schema)?), negated, Box::new(low.cast_to(&coercion_type, &self.schema)?), Box::new(high.cast_to(&coercion_type, &self.schema)?), - )); - Ok(expr) + )))) } Expr::InList(InList { expr, @@ -306,18 +295,17 @@ impl TreeNodeRewriter for TypeCoercionRewriter { list_expr.cast_to(&coerced_type, &self.schema) }) .collect::>>()?; - let expr = Expr::InList(InList ::new( + Ok(Transformed::yes(Expr::InList(InList ::new( Box::new(cast_expr), cast_list_expr, negated, - )); - Ok(expr) + )))) } } } Expr::Case(case) => { let case = coerce_case_expression(case, &self.schema)?; - Ok(Expr::Case(case)) + Ok(Transformed::yes(Expr::Case(case))) } Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def { ScalarFunctionDefinition::BuiltIn(fun) => { @@ -331,7 +319,9 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &self.schema, &fun, )?; - Ok(Expr::ScalarFunction(ScalarFunction::new(fun, new_args))) + Ok(Transformed::yes(Expr::ScalarFunction(ScalarFunction::new( + fun, new_args, + )))) } ScalarFunctionDefinition::UDF(fun) => { let new_expr = coerce_arguments_for_signature( @@ -339,7 +329,9 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &self.schema, 
fun.signature(), )?; - Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, new_expr))) + Ok(Transformed::yes(Expr::ScalarFunction( + ScalarFunction::new_udf(fun, new_expr), + ))) } ScalarFunctionDefinition::Name(_) => { internal_err!("Function `Expr` with name should be resolved.") @@ -359,10 +351,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &self.schema, &fun.signature(), )?; - let expr = Expr::AggregateFunction(expr::AggregateFunction::new( - fun, new_expr, distinct, filter, order_by, - )); - Ok(expr) + Ok(Transformed::yes(Expr::AggregateFunction( + expr::AggregateFunction::new( + fun, new_expr, distinct, filter, order_by, + ), + ))) } AggregateFunctionDefinition::UDF(fun) => { let new_expr = coerce_arguments_for_signature( @@ -370,10 +363,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter { &self.schema, fun.signature(), )?; - let expr = Expr::AggregateFunction(expr::AggregateFunction::new_udf( - fun, new_expr, false, filter, order_by, - )); - Ok(expr) + Ok(Transformed::yes(Expr::AggregateFunction( + expr::AggregateFunction::new_udf( + fun, new_expr, false, filter, order_by, + ), + ))) } AggregateFunctionDefinition::Name(_) => { internal_err!("Function `Expr` with name should be resolved.") @@ -401,14 +395,13 @@ impl TreeNodeRewriter for TypeCoercionRewriter { _ => args, }; - let expr = Expr::WindowFunction(WindowFunction::new( + Ok(Transformed::yes(Expr::WindowFunction(WindowFunction::new( fun, args, partition_by, order_by, window_frame, - )); - Ok(expr) + )))) } Expr::Alias(_) | Expr::Column(_) @@ -425,7 +418,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter { | Expr::Wildcard { .. 
} | Expr::GroupingSet(_) | Expr::Placeholder(_) - | Expr::OuterReferenceColumn(_, _) => Ok(expr), + | Expr::OuterReferenceColumn(_, _) => Ok(Transformed::no(expr)), } } } @@ -1283,7 +1276,7 @@ mod test { std::collections::HashMap::new(), )?); let mut rewriter = TypeCoercionRewriter { schema }; - let result = expr.rewrite(&mut rewriter)?; + let result = expr.rewrite(&mut rewriter)?.data; let schema = Arc::new(DFSchema::new_with_metadata( vec![DFField::new_unqualified( @@ -1318,7 +1311,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema }; let expr = is_true(lit(12i32).gt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64))); - let result = expr.rewrite(&mut rewriter)?; + let result = expr.rewrite(&mut rewriter)?.data; assert_eq!(expected, result); // eq @@ -1329,7 +1322,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema }; let expr = is_true(lit(12i32).eq(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64))); - let result = expr.rewrite(&mut rewriter)?; + let result = expr.rewrite(&mut rewriter)?.data; assert_eq!(expected, result); // lt @@ -1340,7 +1333,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema }; let expr = is_true(lit(12i32).lt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64))); - let result = expr.rewrite(&mut rewriter)?; + let result = expr.rewrite(&mut rewriter)?.data; assert_eq!(expected, result); Ok(()) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index f3b8d4b4842a..fafc6340f1a1 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -25,7 +25,7 @@ use crate::{utils, OptimizerConfig, OptimizerRule}; use arrow::datatypes::DataType; use datafusion_common::tree_node::{ - TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, + Transformed, TreeNode, 
TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, }; use datafusion_common::{ internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, @@ -745,24 +745,24 @@ struct CommonSubexprRewriter<'a> { impl TreeNodeRewriter for CommonSubexprRewriter<'_> { type Node = Expr; - fn f_down(&mut self, expr: Expr) -> Result<(Expr, TreeNodeRecursion)> { + fn f_down(&mut self, expr: Expr) -> Result> { // The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate // the `id_array`, which records the expr's identifier used to rewrite expr. So if we // skip an expr in `ExprIdentifierVisitor`, we should skip it here, too. if expr.short_circuits() || is_volatile_expression(&expr)? { - return Ok((expr, TreeNodeRecursion::Skip)); + return Ok(Transformed::new(expr, false, TreeNodeRecursion::Skip)); } if self.curr_index >= self.id_array.len() || self.max_series_number > self.id_array[self.curr_index].0 { - return Ok((expr, TreeNodeRecursion::Skip)); + return Ok(Transformed::new(expr, false, TreeNodeRecursion::Skip)); } let curr_id = &self.id_array[self.curr_index].1; // skip `Expr`s without identifier (empty identifier). if curr_id.is_empty() { self.curr_index += 1; - return Ok((expr, TreeNodeRecursion::Continue)); + return Ok(Transformed::no(expr)); } match self.expr_set.get(curr_id) { Some((_, counter, _)) => { @@ -771,7 +771,11 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { // This expr tree is finished. 
if self.curr_index >= self.id_array.len() { - return Ok((expr, TreeNodeRecursion::Skip)); + return Ok(Transformed::new( + expr, + false, + TreeNodeRecursion::Skip, + )); } let (series_number, id) = &self.id_array[self.curr_index]; @@ -784,7 +788,11 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { || id.is_empty() || expr_set_item.1 <= 1 { - return Ok((expr, TreeNodeRecursion::Skip)); + return Ok(Transformed::new( + expr, + false, + TreeNodeRecursion::Skip, + )); } self.max_series_number = *series_number; @@ -799,10 +807,14 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { // Alias this `Column` expr to it original "expr name", // `projection_push_down` optimizer use "expr name" to eliminate useless // projections. - Ok((col(id).alias(expr_name), TreeNodeRecursion::Skip)) + Ok(Transformed::new( + col(id).alias(expr_name), + true, + TreeNodeRecursion::Skip, + )) } else { self.curr_index += 1; - Ok((expr, TreeNodeRecursion::Continue)) + Ok(Transformed::no(expr)) } } _ => internal_err!("expr_set invalid state"), @@ -823,6 +835,7 @@ fn replace_common_expr( max_series_number: 0, curr_index: 0, }) + .map(|t| t.data) } #[cfg(test)] diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index 49d3c322ca2b..b7119966c41c 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -58,17 +58,17 @@ pub type ExprResultMap = HashMap; impl TreeNodeRewriter for PullUpCorrelatedExpr { type Node = LogicalPlan; - fn f_down(&mut self, plan: LogicalPlan) -> Result<(LogicalPlan, TreeNodeRecursion)> { + fn f_down(&mut self, plan: LogicalPlan) -> Result> { match plan { - LogicalPlan::Filter(_) => Ok((plan, TreeNodeRecursion::Continue)), + LogicalPlan::Filter(_) => Ok(Transformed::no(plan)), LogicalPlan::Union(_) | LogicalPlan::Sort(_) | LogicalPlan::Extension(_) => { let plan_hold_outer = !plan.all_out_ref_exprs().is_empty(); if plan_hold_outer { // the unsupported case self.can_pull_up = 
false; - Ok((plan, TreeNodeRecursion::Skip)) + Ok(Transformed::new(plan, false, TreeNodeRecursion::Skip)) } else { - Ok((plan, TreeNodeRecursion::Continue)) + Ok(Transformed::no(plan)) } } LogicalPlan::Limit(_) => { @@ -77,21 +77,21 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { (false, true) => { // the unsupported case self.can_pull_up = false; - Ok((plan, TreeNodeRecursion::Skip)) + Ok(Transformed::new(plan, false, TreeNodeRecursion::Skip)) } - _ => Ok((plan, TreeNodeRecursion::Continue)), + _ => Ok(Transformed::no(plan)), } } _ if plan.expressions().iter().any(|expr| expr.contains_outer()) => { // the unsupported cases, the plan expressions contain out reference columns(like window expressions) self.can_pull_up = false; - Ok((plan, TreeNodeRecursion::Skip)) + Ok(Transformed::new(plan, false, TreeNodeRecursion::Skip)) } - _ => Ok((plan, TreeNodeRecursion::Continue)), + _ => Ok(Transformed::no(plan)), } } - fn f_up(&mut self, plan: LogicalPlan) -> Result { + fn f_up(&mut self, plan: LogicalPlan) -> Result> { let subquery_schema = plan.schema().clone(); match &plan { LogicalPlan::Filter(plan_filter) => { @@ -140,7 +140,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { .build()?; self.correlated_subquery_cols_map .insert(new_plan.clone(), correlated_subquery_cols); - Ok(new_plan) + Ok(Transformed::yes(new_plan)) } (None, _) => { // if the subquery still has filter expressions, restore them. 
@@ -152,7 +152,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { let new_plan = plan.build()?; self.correlated_subquery_cols_map .insert(new_plan.clone(), correlated_subquery_cols); - Ok(new_plan) + Ok(Transformed::yes(new_plan)) } } } @@ -196,7 +196,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { self.collected_count_expr_map .insert(new_plan.clone(), expr_result_map_for_count_bug); } - Ok(new_plan) + Ok(Transformed::yes(new_plan)) } LogicalPlan::Aggregate(aggregate) if self.in_predicate_opt.is_some() || !self.join_filters.is_empty() => @@ -240,7 +240,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { self.collected_count_expr_map .insert(new_plan.clone(), expr_result_map_for_count_bug); } - Ok(new_plan) + Ok(Transformed::yes(new_plan)) } LogicalPlan::SubqueryAlias(alias) => { let mut local_correlated_cols = BTreeSet::new(); @@ -262,7 +262,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { self.collected_count_expr_map .insert(plan.clone(), input_map.clone()); } - Ok(plan) + Ok(Transformed::no(plan)) } LogicalPlan::Limit(limit) => { let input_expr_map = self @@ -273,7 +273,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { let new_plan = match (self.exists_sub_query, self.join_filters.is_empty()) { // Correlated exist subquery, remove the limit(so that correlated expressions can pull up) - (true, false) => { + (true, false) => Transformed::yes( if limit.fetch.filter(|limit_row| *limit_row == 0).is_some() { LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, @@ -281,17 +281,17 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { }) } else { LogicalPlanBuilder::from((*limit.input).clone()).build()? 
- } - } - _ => plan, + }, + ), + _ => Transformed::no(plan), }; if let Some(input_map) = input_expr_map { self.collected_count_expr_map - .insert(new_plan.clone(), input_map); + .insert(new_plan.data.clone(), input_map); } Ok(new_plan) } - _ => Ok(plan), + _ => Ok(Transformed::no(plan)), } } } @@ -370,31 +370,34 @@ fn agg_exprs_evaluation_result_on_empty_batch( expr_result_map_for_count_bug: &mut ExprResultMap, ) -> Result<()> { for e in agg_expr.iter() { - let result_expr = e.clone().transform_up(&|expr| { - let new_expr = match expr { - Expr::AggregateFunction(expr::AggregateFunction { func_def, .. }) => { - match func_def { + let result_expr = e + .clone() + .transform_up(&|expr| { + let new_expr = match expr { + Expr::AggregateFunction(expr::AggregateFunction { + func_def, .. + }) => match func_def { AggregateFunctionDefinition::BuiltIn(fun) => { if matches!(fun, datafusion_expr::AggregateFunction::Count) { - Transformed::Yes(Expr::Literal(ScalarValue::Int64(Some( + Transformed::yes(Expr::Literal(ScalarValue::Int64(Some( 0, )))) } else { - Transformed::Yes(Expr::Literal(ScalarValue::Null)) + Transformed::yes(Expr::Literal(ScalarValue::Null)) } } AggregateFunctionDefinition::UDF { .. } => { - Transformed::Yes(Expr::Literal(ScalarValue::Null)) + Transformed::yes(Expr::Literal(ScalarValue::Null)) } AggregateFunctionDefinition::Name(_) => { - Transformed::Yes(Expr::Literal(ScalarValue::Null)) + Transformed::yes(Expr::Literal(ScalarValue::Null)) } - } - } - _ => Transformed::No(expr), - }; - Ok(new_expr) - })?; + }, + _ => Transformed::no(expr), + }; + Ok(new_expr) + })? + .data; let result_expr = result_expr.unalias(); let props = ExecutionProps::new(); @@ -415,17 +418,22 @@ fn proj_exprs_evaluation_result_on_empty_batch( expr_result_map_for_count_bug: &mut ExprResultMap, ) -> Result<()> { for expr in proj_expr.iter() { - let result_expr = expr.clone().transform_up(&|expr| { - if let Expr::Column(Column { name, .. 
}) = &expr { - if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) { - Ok(Transformed::Yes(result_expr.clone())) + let result_expr = expr + .clone() + .transform_up(&|expr| { + if let Expr::Column(Column { name, .. }) = &expr { + if let Some(result_expr) = + input_expr_result_map_for_count_bug.get(name) + { + Ok(Transformed::yes(result_expr.clone())) + } else { + Ok(Transformed::no(expr)) + } } else { - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) } - } else { - Ok(Transformed::No(expr)) - } - })?; + })? + .data; if result_expr.ne(expr) { let props = ExecutionProps::new(); let info = SimplifyContext::new(&props).with_schema(schema.clone()); @@ -448,17 +456,20 @@ fn filter_exprs_evaluation_result_on_empty_batch( input_expr_result_map_for_count_bug: &ExprResultMap, expr_result_map_for_count_bug: &mut ExprResultMap, ) -> Result> { - let result_expr = filter_expr.clone().transform_up(&|expr| { - if let Expr::Column(Column { name, .. }) = &expr { - if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) { - Ok(Transformed::Yes(result_expr.clone())) + let result_expr = filter_expr + .clone() + .transform_up(&|expr| { + if let Expr::Column(Column { name, .. }) = &expr { + if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) { + Ok(Transformed::yes(result_expr.clone())) + } else { + Ok(Transformed::no(expr)) + } } else { - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) } - } else { - Ok(Transformed::No(expr)) - } - })?; + })? 
+ .data; let pull_up_expr = if result_expr.ne(filter_expr) { let props = ExecutionProps::new(); let info = SimplifyContext::new(&props).with_schema(schema); diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index 450336376a23..4e94bcc2b085 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -228,7 +228,7 @@ fn build_join( collected_count_expr_map: Default::default(), pull_up_having_expr: None, }; - let new_plan = subquery.clone().rewrite(&mut pull_up)?; + let new_plan = subquery.clone().rewrite(&mut pull_up)?.data; if !pull_up.can_pull_up { return Ok(None); } diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 0ae0bc696a35..bb1ff413088f 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -1022,13 +1022,14 @@ pub fn replace_cols_by_name( e.transform_up(&|expr| { Ok(if let Expr::Column(c) = &expr { match replace_map.get(&c.flat_name()) { - Some(new_c) => Transformed::Yes(new_c.clone()), - None => Transformed::No(expr), + Some(new_c) => Transformed::yes(new_c.clone()), + None => Transformed::no(expr), } } else { - Transformed::No(expr) + Transformed::no(expr) }) }) + .map(|t| t.data) } /// check whether the expression uses the columns in `check_map`. 
diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index e1c35e468f68..0ac053dacd29 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -56,7 +56,7 @@ impl ScalarSubqueryToJoin { sub_query_info: vec![], alias_gen, }; - let new_expr = predicate.clone().rewrite(&mut extract)?; + let new_expr = predicate.clone().rewrite(&mut extract)?.data; Ok((extract.sub_query_info, new_expr)) } } @@ -86,20 +86,22 @@ impl OptimizerRule for ScalarSubqueryToJoin { build_join(&subquery, &cur_input, &alias)? { if !expr_check_map.is_empty() { - rewrite_expr = - rewrite_expr.clone().transform_up(&|expr| { + rewrite_expr = rewrite_expr + .clone() + .transform_up(&|expr| { if let Expr::Column(col) = &expr { if let Some(map_expr) = expr_check_map.get(&col.name) { - Ok(Transformed::Yes(map_expr.clone())) + Ok(Transformed::yes(map_expr.clone())) } else { - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) } } else { - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) } - })?; + })? + .data; } cur_input = optimized_subquery; } else { @@ -141,20 +143,22 @@ impl OptimizerRule for ScalarSubqueryToJoin { if let Some(rewrite_expr) = expr_to_rewrite_expr_map.get(expr) { - let new_expr = - rewrite_expr.clone().transform_up(&|expr| { + let new_expr = rewrite_expr + .clone() + .transform_up(&|expr| { if let Expr::Column(col) = &expr { if let Some(map_expr) = expr_check_map.get(&col.name) { - Ok(Transformed::Yes(map_expr.clone())) + Ok(Transformed::yes(map_expr.clone())) } else { - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) } } else { - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) } - })?; + })? 
+ .data; expr_to_rewrite_expr_map.insert(expr, new_expr); } } @@ -203,7 +207,7 @@ struct ExtractScalarSubQuery { impl TreeNodeRewriter for ExtractScalarSubQuery { type Node = Expr; - fn f_down(&mut self, expr: Expr) -> Result<(Expr, TreeNodeRecursion)> { + fn f_down(&mut self, expr: Expr) -> Result> { match expr { Expr::ScalarSubquery(subquery) => { let subqry_alias = self.alias_gen.next("__scalar_sq"); @@ -213,15 +217,16 @@ impl TreeNodeRewriter for ExtractScalarSubQuery { .subquery .head_output_expr()? .map_or(plan_err!("single expression required."), Ok)?; - Ok(( + Ok(Transformed::new( Expr::Column(create_col_from_scalar_expr( &scalar_expr, subqry_alias, )?), + true, TreeNodeRecursion::Skip, )) } - _ => Ok((expr, TreeNodeRecursion::Continue)), + _ => Ok(Transformed::no(expr)), } } } @@ -278,7 +283,7 @@ fn build_join( collected_count_expr_map: Default::default(), pull_up_having_expr: None, }; - let new_plan = subquery_plan.clone().rewrite(&mut pull_up)?; + let new_plan = subquery_plan.clone().rewrite(&mut pull_up)?.data; if !pull_up.can_pull_up { return Ok(None); } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index fd77071ea728..9f0e4b82e3a5 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -33,7 +33,7 @@ use arrow::{ datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; -use datafusion_common::tree_node::TreeNodeRecursion; +use datafusion_common::tree_node::Transformed; use datafusion_common::{ cast::{as_large_list_array, as_list_array}, tree_node::{TreeNode, TreeNodeRewriter}, @@ -143,18 +143,25 @@ impl ExprSimplifier { // simplifications can enable new constant evaluation) // https://github.com/apache/arrow-datafusion/issues/1160 expr.rewrite(&mut const_evaluator)? + .data .rewrite(&mut simplifier)? + .data .rewrite(&mut or_in_list_simplifier)? 
+ .data .rewrite(&mut inlist_simplifier)? + .data .rewrite(&mut guarantee_rewriter)? + .data // run both passes twice to try an minimize simplifications that we missed .rewrite(&mut const_evaluator)? + .data .rewrite(&mut simplifier) + .map(|t| t.data) } pub fn canonicalize(&self, expr: Expr) -> Result { let mut canonicalizer = Canonicalizer::new(); - expr.rewrite(&mut canonicalizer) + expr.rewrite(&mut canonicalizer).map(|t| t.data) } /// Apply type coercion to an [`Expr`] so that it can be /// evaluated as a [`PhysicalExpr`](datafusion_physical_expr::PhysicalExpr). @@ -169,7 +176,7 @@ impl ExprSimplifier { pub fn coerce(&self, expr: Expr, schema: DFSchemaRef) -> Result { let mut expr_rewrite = TypeCoercionRewriter { schema }; - expr.rewrite(&mut expr_rewrite) + expr.rewrite(&mut expr_rewrite).map(|t| t.data) } /// Input guarantees about the values of columns. @@ -249,30 +256,34 @@ impl Canonicalizer { impl TreeNodeRewriter for Canonicalizer { type Node = Expr; - fn f_up(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr else { - return Ok(expr); + return Ok(Transformed::no(expr)); }; match (left.as_ref(), right.as_ref(), op.swap()) { // (Expr::Column(left_col), Expr::Column(right_col), Some(swapped_op)) if right_col > left_col => { - Ok(Expr::BinaryExpr(BinaryExpr { + Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { left: right, op: swapped_op, right: left, - })) + }))) } // (Expr::Literal(_a), Expr::Column(_b), Some(swapped_op)) => { - Ok(Expr::BinaryExpr(BinaryExpr { + Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr { left: right, op: swapped_op, right: left, - })) + }))) } - _ => Ok(Expr::BinaryExpr(BinaryExpr { left, op, right })), + _ => Ok(Transformed::no(Expr::BinaryExpr(BinaryExpr { + left, + op, + right, + }))), } } } @@ -313,7 +324,7 @@ enum ConstSimplifyResult { impl<'a> TreeNodeRewriter for ConstEvaluator<'a> { type Node = Expr; - fn f_down(&mut self, expr: 
Expr) -> Result<(Expr, TreeNodeRecursion)> { + fn f_down(&mut self, expr: Expr) -> Result> { // Default to being able to evaluate this node self.can_evaluate.push(true); @@ -337,10 +348,10 @@ impl<'a> TreeNodeRewriter for ConstEvaluator<'a> { // NB: do not short circuit recursion even if we find a non // evaluatable node (so we can fold other children, args to // functions, etc) - Ok((expr, TreeNodeRecursion::Continue)) + Ok(Transformed::no(expr)) } - fn f_up(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { match self.can_evaluate.pop() { // Certain expressions such as `CASE` and `COALESCE` are short circuiting // and may not evalute all their sub expressions. Thus if @@ -349,11 +360,15 @@ impl<'a> TreeNodeRewriter for ConstEvaluator<'a> { Some(true) => { let result = self.evaluate_to_scalar(expr); match result { - ConstSimplifyResult::Simplified(s) => Ok(Expr::Literal(s)), - ConstSimplifyResult::SimplifyRuntimeError(_, expr) => Ok(expr), + ConstSimplifyResult::Simplified(s) => { + Ok(Transformed::yes(Expr::Literal(s))) + } + ConstSimplifyResult::SimplifyRuntimeError(_, expr) => { + Ok(Transformed::yes(expr)) + } } } - Some(false) => Ok(expr), + Some(false) => Ok(Transformed::no(expr)), _ => internal_err!("Failed to pop can_evaluate"), } } @@ -508,7 +523,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { type Node = Expr; /// rewrite the expression simplifying any constant expressions - fn f_up(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { use datafusion_expr::Operator::{ And, BitwiseAnd, BitwiseOr, BitwiseShiftLeft, BitwiseShiftRight, BitwiseXor, Divide, Eq, Modulo, Multiply, NotEq, Or, RegexIMatch, RegexMatch, @@ -516,7 +531,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { }; let info = self.info; - let new_expr = match expr { + Ok(match expr { // // Rules for Eq // @@ -529,11 +544,11 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { 
op: Eq, right, }) if is_bool_lit(&left) && info.is_boolean_type(&right)? => { - match as_bool_lit(*left)? { + Transformed::yes(match as_bool_lit(*left)? { Some(true) => *right, Some(false) => Expr::Not(right), None => lit_bool_null(), - } + }) } // A = true --> A // A = false --> !A @@ -543,11 +558,11 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: Eq, right, }) if is_bool_lit(&right) && info.is_boolean_type(&left)? => { - match as_bool_lit(*right)? { + Transformed::yes(match as_bool_lit(*right)? { Some(true) => *left, Some(false) => Expr::Not(left), None => lit_bool_null(), - } + }) } // expr IN () --> false // expr NOT IN () --> true @@ -556,7 +571,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { list, negated, }) if list.is_empty() && *expr != Expr::Literal(ScalarValue::Null) => { - lit(negated) + Transformed::yes(lit(negated)) } // null in (x, y, z) --> null @@ -565,7 +580,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { expr, list: _, negated: _, - }) if is_null(&expr) => lit_bool_null(), + }) if is_null(&expr) => Transformed::yes(lit_bool_null()), // expr IN ((subquery)) -> expr IN (subquery), see ##5529 Expr::InList(InList { @@ -578,7 +593,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { let Expr::ScalarSubquery(subquery) = list.remove(0) else { unreachable!() }; - Expr::InSubquery(InSubquery::new(expr, subquery, negated)) + Transformed::yes(Expr::InSubquery(InSubquery::new( + expr, subquery, negated, + ))) } // if expr is a single column reference: @@ -599,7 +616,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { ) => { let first_val = list[0].clone(); - if negated { + Transformed::yes(if negated { list.into_iter().skip(1).fold( (*expr.clone()).not_eq(first_val), |acc, y| { @@ -631,7 +648,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { acc.or((*expr.clone()).eq(y)) }, ) - } + }) } // // Rules for NotEq @@ -645,11 
+662,11 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: NotEq, right, }) if is_bool_lit(&left) && info.is_boolean_type(&right)? => { - match as_bool_lit(*left)? { + Transformed::yes(match as_bool_lit(*left)? { Some(true) => Expr::Not(right), Some(false) => *right, None => lit_bool_null(), - } + }) } // A != true --> !A // A != false --> A @@ -659,11 +676,11 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: NotEq, right, }) if is_bool_lit(&right) && info.is_boolean_type(&left)? => { - match as_bool_lit(*right)? { + Transformed::yes(match as_bool_lit(*right)? { Some(true) => Expr::Not(left), Some(false) => *left, None => lit_bool_null(), - } + }) } // @@ -675,32 +692,32 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: Or, right: _, - }) if is_true(&left) => *left, + }) if is_true(&left) => Transformed::yes(*left), // false OR A --> A Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if is_false(&left) => *right, + }) if is_false(&left) => Transformed::yes(*right), // A OR true --> true (even if A is null) Expr::BinaryExpr(BinaryExpr { left: _, op: Or, right, - }) if is_true(&right) => *right, + }) if is_true(&right) => Transformed::yes(*right), // A OR false --> A Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if is_false(&right) => *left, + }) if is_false(&right) => Transformed::yes(*left), // A OR !A ---> true (if A not nullable) Expr::BinaryExpr(BinaryExpr { left, op: Or, right, }) if is_not_of(&right, &left) && !info.nullable(&left)? => { - Expr::Literal(ScalarValue::Boolean(Some(true))) + Transformed::yes(Expr::Literal(ScalarValue::Boolean(Some(true)))) } // !A OR A ---> true (if A not nullable) Expr::BinaryExpr(BinaryExpr { @@ -708,32 +725,36 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: Or, right, }) if is_not_of(&left, &right) && !info.nullable(&right)? 
=> { - Expr::Literal(ScalarValue::Boolean(Some(true))) + Transformed::yes(Expr::Literal(ScalarValue::Boolean(Some(true)))) } // (..A..) OR A --> (..A..) Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if expr_contains(&left, &right, Or) => *left, + }) if expr_contains(&left, &right, Or) => Transformed::yes(*left), // A OR (..A..) --> (..A..) Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if expr_contains(&right, &left, Or) => *right, + }) if expr_contains(&right, &left, Or) => Transformed::yes(*right), // A OR (A AND B) --> A (if B not null) Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if !info.nullable(&right)? && is_op_with(And, &right, &left) => *left, + }) if !info.nullable(&right)? && is_op_with(And, &right, &left) => { + Transformed::yes(*left) + } // (A AND B) OR A --> A (if B not null) Expr::BinaryExpr(BinaryExpr { left, op: Or, right, - }) if !info.nullable(&left)? && is_op_with(And, &left, &right) => *right, + }) if !info.nullable(&left)? && is_op_with(And, &left, &right) => { + Transformed::yes(*right) + } // // Rules for AND @@ -744,32 +765,32 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: And, right, - }) if is_true(&left) => *right, + }) if is_true(&left) => Transformed::yes(*right), // false AND A --> false (even if A is null) Expr::BinaryExpr(BinaryExpr { left, op: And, right: _, - }) if is_false(&left) => *left, + }) if is_false(&left) => Transformed::yes(*left), // A AND true --> A Expr::BinaryExpr(BinaryExpr { left, op: And, right, - }) if is_true(&right) => *left, + }) if is_true(&right) => Transformed::yes(*left), // A AND false --> false (even if A is null) Expr::BinaryExpr(BinaryExpr { left: _, op: And, right, - }) if is_false(&right) => *right, + }) if is_false(&right) => Transformed::yes(*right), // A AND !A ---> false (if A not nullable) Expr::BinaryExpr(BinaryExpr { left, op: And, right, }) if is_not_of(&right, &left) && !info.nullable(&left)? 
=> { - Expr::Literal(ScalarValue::Boolean(Some(false))) + Transformed::yes(Expr::Literal(ScalarValue::Boolean(Some(false)))) } // !A AND A ---> false (if A not nullable) Expr::BinaryExpr(BinaryExpr { @@ -777,32 +798,36 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: And, right, }) if is_not_of(&left, &right) && !info.nullable(&right)? => { - Expr::Literal(ScalarValue::Boolean(Some(false))) + Transformed::yes(Expr::Literal(ScalarValue::Boolean(Some(false)))) } // (..A..) AND A --> (..A..) Expr::BinaryExpr(BinaryExpr { left, op: And, right, - }) if expr_contains(&left, &right, And) => *left, + }) if expr_contains(&left, &right, And) => Transformed::yes(*left), // A AND (..A..) --> (..A..) Expr::BinaryExpr(BinaryExpr { left, op: And, right, - }) if expr_contains(&right, &left, And) => *right, + }) if expr_contains(&right, &left, And) => Transformed::yes(*right), // A AND (A OR B) --> A (if B not null) Expr::BinaryExpr(BinaryExpr { left, op: And, right, - }) if !info.nullable(&right)? && is_op_with(Or, &right, &left) => *left, + }) if !info.nullable(&right)? && is_op_with(Or, &right, &left) => { + Transformed::yes(*left) + } // (A OR B) AND A --> A (if B not null) Expr::BinaryExpr(BinaryExpr { left, op: And, right, - }) if !info.nullable(&left)? && is_op_with(Or, &left, &right) => *right, + }) if !info.nullable(&left)? 
&& is_op_with(Or, &left, &right) => { + Transformed::yes(*right) + } // // Rules for Multiply @@ -813,25 +838,25 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: Multiply, right, - }) if is_one(&right) => *left, + }) if is_one(&right) => Transformed::yes(*left), // 1 * A --> A Expr::BinaryExpr(BinaryExpr { left, op: Multiply, right, - }) if is_one(&left) => *right, + }) if is_one(&left) => Transformed::yes(*right), // A * null --> null Expr::BinaryExpr(BinaryExpr { left: _, op: Multiply, right, - }) if is_null(&right) => *right, + }) if is_null(&right) => Transformed::yes(*right), // null * A --> null Expr::BinaryExpr(BinaryExpr { left, op: Multiply, right: _, - }) if is_null(&left) => *left, + }) if is_null(&left) => Transformed::yes(*left), // A * 0 --> 0 (if A is not null and not floating, since NAN * 0 -> NAN) Expr::BinaryExpr(BinaryExpr { @@ -842,7 +867,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { && !info.get_data_type(&left)?.is_floating() && is_zero(&right) => { - *right + Transformed::yes(*right) } // 0 * A --> 0 (if A is not null and not floating, since 0 * NAN -> NAN) Expr::BinaryExpr(BinaryExpr { @@ -853,7 +878,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { && !info.get_data_type(&right)?.is_floating() && is_zero(&left) => { - *left + Transformed::yes(*left) } // @@ -865,19 +890,19 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: Divide, right, - }) if is_one(&right) => *left, + }) if is_one(&right) => Transformed::yes(*left), // null / A --> null Expr::BinaryExpr(BinaryExpr { left, op: Divide, right: _, - }) if is_null(&left) => *left, + }) if is_null(&left) => Transformed::yes(*left), // A / null --> null Expr::BinaryExpr(BinaryExpr { left: _, op: Divide, right, - }) if is_null(&right) => *right, + }) if is_null(&right) => Transformed::yes(*right), // // Rules for Modulo @@ -888,13 +913,13 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for 
Simplifier<'a, S> { left: _, op: Modulo, right, - }) if is_null(&right) => *right, + }) if is_null(&right) => Transformed::yes(*right), // null % A --> null Expr::BinaryExpr(BinaryExpr { left, op: Modulo, right: _, - }) if is_null(&left) => *left, + }) if is_null(&left) => Transformed::yes(*left), // A % 1 --> 0 (if A is not nullable and not floating, since NAN % 1 --> NAN) Expr::BinaryExpr(BinaryExpr { left, @@ -904,7 +929,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { && !info.get_data_type(&left)?.is_floating() && is_one(&right) => { - lit(0) + Transformed::yes(lit(0)) } // @@ -916,28 +941,28 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left: _, op: BitwiseAnd, right, - }) if is_null(&right) => *right, + }) if is_null(&right) => Transformed::yes(*right), // null & A -> null Expr::BinaryExpr(BinaryExpr { left, op: BitwiseAnd, right: _, - }) if is_null(&left) => *left, + }) if is_null(&left) => Transformed::yes(*left), // A & 0 -> 0 (if A not nullable) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseAnd, right, - }) if !info.nullable(&left)? && is_zero(&right) => *right, + }) if !info.nullable(&left)? && is_zero(&right) => Transformed::yes(*right), // 0 & A -> 0 (if A not nullable) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseAnd, right, - }) if !info.nullable(&right)? && is_zero(&left) => *left, + }) if !info.nullable(&right)? && is_zero(&left) => Transformed::yes(*left), // !A & A -> 0 (if A not nullable) Expr::BinaryExpr(BinaryExpr { @@ -945,7 +970,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseAnd, right, }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { - Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?) 
+ Transformed::yes(Expr::Literal(ScalarValue::new_zero( + &info.get_data_type(&left)?, + )?)) } // A & !A -> 0 (if A not nullable) @@ -954,7 +981,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseAnd, right, }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { - Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?) + Transformed::yes(Expr::Literal(ScalarValue::new_zero( + &info.get_data_type(&left)?, + )?)) } // (..A..) & A --> (..A..) @@ -962,14 +991,14 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: BitwiseAnd, right, - }) if expr_contains(&left, &right, BitwiseAnd) => *left, + }) if expr_contains(&left, &right, BitwiseAnd) => Transformed::yes(*left), // A & (..A..) --> (..A..) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseAnd, right, - }) if expr_contains(&right, &left, BitwiseAnd) => *right, + }) if expr_contains(&right, &left, BitwiseAnd) => Transformed::yes(*right), // A & (A | B) --> A (if B not null) Expr::BinaryExpr(BinaryExpr { @@ -977,7 +1006,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseAnd, right, }) if !info.nullable(&right)? && is_op_with(BitwiseOr, &right, &left) => { - *left + Transformed::yes(*left) } // (A | B) & A --> A (if B not null) @@ -986,7 +1015,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseAnd, right, }) if !info.nullable(&left)? 
&& is_op_with(BitwiseOr, &left, &right) => { - *right + Transformed::yes(*right) } // @@ -998,28 +1027,28 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left: _, op: BitwiseOr, right, - }) if is_null(&right) => *right, + }) if is_null(&right) => Transformed::yes(*right), // null | A -> null Expr::BinaryExpr(BinaryExpr { left, op: BitwiseOr, right: _, - }) if is_null(&left) => *left, + }) if is_null(&left) => Transformed::yes(*left), // A | 0 -> A (even if A is null) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseOr, right, - }) if is_zero(&right) => *left, + }) if is_zero(&right) => Transformed::yes(*left), // 0 | A -> A (even if A is null) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseOr, right, - }) if is_zero(&left) => *right, + }) if is_zero(&left) => Transformed::yes(*right), // !A | A -> -1 (if A not nullable) Expr::BinaryExpr(BinaryExpr { @@ -1027,7 +1056,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseOr, right, }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { - Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?) + Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( + &info.get_data_type(&left)?, + )?)) } // A | !A -> -1 (if A not nullable) @@ -1036,7 +1067,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseOr, right, }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { - Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?) + Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( + &info.get_data_type(&left)?, + )?)) } // (..A..) | A --> (..A..) @@ -1044,14 +1077,14 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: BitwiseOr, right, - }) if expr_contains(&left, &right, BitwiseOr) => *left, + }) if expr_contains(&left, &right, BitwiseOr) => Transformed::yes(*left), // A | (..A..) --> (..A..) 
Expr::BinaryExpr(BinaryExpr { left, op: BitwiseOr, right, - }) if expr_contains(&right, &left, BitwiseOr) => *right, + }) if expr_contains(&right, &left, BitwiseOr) => Transformed::yes(*right), // A | (A & B) --> A (if B not null) Expr::BinaryExpr(BinaryExpr { @@ -1059,7 +1092,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseOr, right, }) if !info.nullable(&right)? && is_op_with(BitwiseAnd, &right, &left) => { - *left + Transformed::yes(*left) } // (A & B) | A --> A (if B not null) @@ -1068,7 +1101,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseOr, right, }) if !info.nullable(&left)? && is_op_with(BitwiseAnd, &left, &right) => { - *right + Transformed::yes(*right) } // @@ -1080,28 +1113,28 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left: _, op: BitwiseXor, right, - }) if is_null(&right) => *right, + }) if is_null(&right) => Transformed::yes(*right), // null ^ A -> null Expr::BinaryExpr(BinaryExpr { left, op: BitwiseXor, right: _, - }) if is_null(&left) => *left, + }) if is_null(&left) => Transformed::yes(*left), // A ^ 0 -> A (if A not nullable) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseXor, right, - }) if !info.nullable(&left)? && is_zero(&right) => *left, + }) if !info.nullable(&left)? && is_zero(&right) => Transformed::yes(*left), // 0 ^ A -> A (if A not nullable) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseXor, right, - }) if !info.nullable(&right)? && is_zero(&left) => *right, + }) if !info.nullable(&right)? && is_zero(&left) => Transformed::yes(*right), // !A ^ A -> -1 (if A not nullable) Expr::BinaryExpr(BinaryExpr { @@ -1109,7 +1142,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseXor, right, }) if is_negative_of(&left, &right) && !info.nullable(&right)? => { - Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?) 
+ Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( + &info.get_data_type(&left)?, + )?)) } // A ^ !A -> -1 (if A not nullable) @@ -1118,7 +1153,9 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { op: BitwiseXor, right, }) if is_negative_of(&right, &left) && !info.nullable(&left)? => { - Expr::Literal(ScalarValue::new_negative_one(&info.get_data_type(&left)?)?) + Transformed::yes(Expr::Literal(ScalarValue::new_negative_one( + &info.get_data_type(&left)?, + )?)) } // (..A..) ^ A --> (the expression without A, if number of A is odd, otherwise one A) @@ -1128,11 +1165,11 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { right, }) if expr_contains(&left, &right, BitwiseXor) => { let expr = delete_xor_in_complex_expr(&left, &right, false); - if expr == *right { + Transformed::yes(if expr == *right { Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&right)?)?) } else { expr - } + }) } // A ^ (..A..) --> (the expression without A, if number of A is odd, otherwise one A) @@ -1142,11 +1179,11 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { right, }) if expr_contains(&right, &left, BitwiseXor) => { let expr = delete_xor_in_complex_expr(&right, &left, true); - if expr == *left { + Transformed::yes(if expr == *left { Expr::Literal(ScalarValue::new_zero(&info.get_data_type(&left)?)?) 
} else { expr - } + }) } // @@ -1158,21 +1195,21 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left: _, op: BitwiseShiftRight, right, - }) if is_null(&right) => *right, + }) if is_null(&right) => Transformed::yes(*right), // null >> A -> null Expr::BinaryExpr(BinaryExpr { left, op: BitwiseShiftRight, right: _, - }) if is_null(&left) => *left, + }) if is_null(&left) => Transformed::yes(*left), // A >> 0 -> A (even if A is null) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseShiftRight, right, - }) if is_zero(&right) => *left, + }) if is_zero(&right) => Transformed::yes(*left), // // Rules for BitwiseShiftRight @@ -1183,31 +1220,31 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left: _, op: BitwiseShiftLeft, right, - }) if is_null(&right) => *right, + }) if is_null(&right) => Transformed::yes(*right), // null << A -> null Expr::BinaryExpr(BinaryExpr { left, op: BitwiseShiftLeft, right: _, - }) if is_null(&left) => *left, + }) if is_null(&left) => Transformed::yes(*left), // A << 0 -> A (even if A is null) Expr::BinaryExpr(BinaryExpr { left, op: BitwiseShiftLeft, right, - }) if is_zero(&right) => *left, + }) if is_zero(&right) => Transformed::yes(*left), // // Rules for Not // - Expr::Not(inner) => negate_clause(*inner), + Expr::Not(inner) => Transformed::yes(negate_clause(*inner)), // // Rules for Negative // - Expr::Negative(inner) => distribute_negation(*inner), + Expr::Negative(inner) => Transformed::yes(distribute_negation(*inner)), // // Rules for Case @@ -1261,19 +1298,19 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Log), args, - }) => simpl_log(args, <&S>::clone(&info))?, + }) => Transformed::yes(simpl_log(args, <&S>::clone(&info))?), // power Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Power), args, - }) => simpl_power(args, 
<&S>::clone(&info))?, + }) => Transformed::yes(simpl_power(args, <&S>::clone(&info))?), // concat Expr::ScalarFunction(ScalarFunction { func_def: ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Concat), args, - }) => simpl_concat(args)?, + }) => Transformed::yes(simpl_concat(args)?), // concat_ws Expr::ScalarFunction(ScalarFunction { @@ -1283,11 +1320,13 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { ), args, }) => match &args[..] { - [delimiter, vals @ ..] => simpl_concat_ws(delimiter, vals)?, - _ => Expr::ScalarFunction(ScalarFunction::new( + [delimiter, vals @ ..] => { + Transformed::yes(simpl_concat_ws(delimiter, vals)?) + } + _ => Transformed::yes(Expr::ScalarFunction(ScalarFunction::new( BuiltinScalarFunction::ConcatWithSeparator, args, - )), + ))), }, // @@ -1296,18 +1335,16 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { // a between 3 and 5 --> a >= 3 AND a <=5 // a not between 3 and 5 --> a < 3 OR a > 5 - Expr::Between(between) => { - if between.negated { - let l = *between.expr.clone(); - let r = *between.expr; - or(l.lt(*between.low), r.gt(*between.high)) - } else { - and( - between.expr.clone().gt_eq(*between.low), - between.expr.lt_eq(*between.high), - ) - } - } + Expr::Between(between) => Transformed::yes(if between.negated { + let l = *between.expr.clone(); + let r = *between.expr; + or(l.lt(*between.low), r.gt(*between.high)) + } else { + and( + between.expr.clone().gt_eq(*between.low), + between.expr.lt_eq(*between.high), + ) + }), // // Rules for regexes @@ -1316,7 +1353,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { left, op: op @ (RegexMatch | RegexNotMatch | RegexIMatch | RegexNotIMatch), right, - }) => simplify_regex_expr(left, op, right)?, + }) => Transformed::yes(simplify_regex_expr(left, op, right)?), // Rules for Like Expr::Like(Like { @@ -1331,25 +1368,24 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { 
Expr::Literal(ScalarValue::Utf8(Some(pattern_str))) if pattern_str == "%" ) => { - lit(!negated) + Transformed::yes(lit(!negated)) } // a is not null/unknown --> true (if a is not nullable) Expr::IsNotNull(expr) | Expr::IsNotUnknown(expr) if !info.nullable(&expr)? => { - lit(true) + Transformed::yes(lit(true)) } // a is null/unknown --> false (if a is not nullable) Expr::IsNull(expr) | Expr::IsUnknown(expr) if !info.nullable(&expr)? => { - lit(false) + Transformed::yes(lit(false)) } // no additional rewrites possible - expr => expr, - }; - Ok(new_expr) + expr => Transformed::no(expr), + }) } } @@ -1473,6 +1509,7 @@ mod tests { let evaluated_expr = input_expr .clone() .rewrite(&mut const_evaluator) + .map(|t| t.data) .expect("successfully evaluated"); assert_eq!( diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index e7c619c046de..8b243f82c714 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -21,6 +21,7 @@ use std::{borrow::Cow, collections::HashMap}; +use datafusion_common::tree_node::Transformed; use datafusion_common::{tree_node::TreeNodeRewriter, DataFusionError, Result}; use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr}; @@ -59,21 +60,23 @@ impl<'a> GuaranteeRewriter<'a> { impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { type Node = Expr; - fn f_up(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { if self.guarantees.is_empty() { - return Ok(expr); + return Ok(Transformed::no(expr)); } match &expr { Expr::IsNull(inner) => match self.guarantees.get(inner.as_ref()) { - Some(NullableInterval::Null { .. }) => Ok(lit(true)), - Some(NullableInterval::NotNull { .. }) => Ok(lit(false)), - _ => Ok(expr), + Some(NullableInterval::Null { .. 
}) => Ok(Transformed::yes(lit(true))), + Some(NullableInterval::NotNull { .. }) => { + Ok(Transformed::yes(lit(false))) + } + _ => Ok(Transformed::no(expr)), }, Expr::IsNotNull(inner) => match self.guarantees.get(inner.as_ref()) { - Some(NullableInterval::Null { .. }) => Ok(lit(false)), - Some(NullableInterval::NotNull { .. }) => Ok(lit(true)), - _ => Ok(expr), + Some(NullableInterval::Null { .. }) => Ok(Transformed::yes(lit(false))), + Some(NullableInterval::NotNull { .. }) => Ok(Transformed::yes(lit(true))), + _ => Ok(Transformed::no(expr)), }, Expr::Between(Between { expr: inner, @@ -93,14 +96,14 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { let contains = expr_interval.contains(*interval)?; if contains.is_certainly_true() { - Ok(lit(!negated)) + Ok(Transformed::yes(lit(!negated))) } else if contains.is_certainly_false() { - Ok(lit(*negated)) + Ok(Transformed::yes(lit(*negated))) } else { - Ok(expr) + Ok(Transformed::no(expr)) } } else { - Ok(expr) + Ok(Transformed::no(expr)) } } @@ -135,23 +138,23 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { let result = left_interval.apply_operator(op, right_interval.as_ref())?; if result.is_certainly_true() { - Ok(lit(true)) + Ok(Transformed::yes(lit(true))) } else if result.is_certainly_false() { - Ok(lit(false)) + Ok(Transformed::yes(lit(false))) } else { - Ok(expr) + Ok(Transformed::no(expr)) } } - _ => Ok(expr), + _ => Ok(Transformed::no(expr)), } } // Columns (if interval is collapsed to a single value) Expr::Column(_) => { if let Some(interval) = self.guarantees.get(&expr) { - Ok(interval.single_value().map_or(expr, lit)) + Ok(Transformed::yes(interval.single_value().map_or(expr, lit))) } else { - Ok(expr) + Ok(Transformed::no(expr)) } } @@ -181,17 +184,17 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { }) .collect::>()?; - Ok(Expr::InList(InList { + Ok(Transformed::yes(Expr::InList(InList { expr: inner.clone(), list: new_list, negated: *negated, - })) + }))) } else { - Ok(expr) + 
Ok(Transformed::no(expr)) } } - _ => Ok(expr), + _ => Ok(Transformed::no(expr)), } } } @@ -221,12 +224,12 @@ mod tests { // x IS NULL => guaranteed false let expr = col("x").is_null(); - let output = expr.clone().rewrite(&mut rewriter).unwrap(); + let output = expr.clone().rewrite(&mut rewriter).map(|t| t.data).unwrap(); assert_eq!(output, lit(false)); // x IS NOT NULL => guaranteed true let expr = col("x").is_not_null(); - let output = expr.clone().rewrite(&mut rewriter).unwrap(); + let output = expr.clone().rewrite(&mut rewriter).map(|t| t.data).unwrap(); assert_eq!(output, lit(true)); } @@ -236,7 +239,7 @@ mod tests { T: Clone, { for (expr, expected_value) in cases { - let output = expr.clone().rewrite(rewriter).unwrap(); + let output = expr.clone().rewrite(rewriter).map(|t| t.data).unwrap(); let expected = lit(ScalarValue::from(expected_value.clone())); assert_eq!( output, expected, @@ -248,7 +251,7 @@ mod tests { fn validate_unchanged_cases(rewriter: &mut GuaranteeRewriter, cases: &[Expr]) { for expr in cases { - let output = expr.clone().rewrite(rewriter).unwrap(); + let output = expr.clone().rewrite(rewriter).map(|t| t.data).unwrap(); assert_eq!( &output, expr, "{} was simplified to {}, but expected it to be unchanged", @@ -478,7 +481,7 @@ mod tests { let guarantees = vec![(col("x"), NullableInterval::from(scalar.clone()))]; let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); - let output = col("x").rewrite(&mut rewriter).unwrap(); + let output = col("x").rewrite(&mut rewriter).map(|t| t.data).unwrap(); assert_eq!(output, Expr::Literal(scalar.clone())); } } @@ -522,7 +525,7 @@ mod tests { .collect(), *negated, ); - let output = expr.clone().rewrite(&mut rewriter).unwrap(); + let output = expr.clone().rewrite(&mut rewriter).map(|t| t.data).unwrap(); let expected_list = expected_list .iter() .map(|v| lit(ScalarValue::Int32(Some(*v)))) diff --git a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs 
b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs index 867e96d213d9..c9d9c00c335e 100644 --- a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs @@ -19,7 +19,7 @@ use std::collections::HashSet; -use datafusion_common::tree_node::TreeNodeRewriter; +use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::Result; use datafusion_expr::expr::InList; use datafusion_expr::{lit, BinaryExpr, Expr, Operator}; @@ -51,30 +51,30 @@ impl InListSimplifier { impl TreeNodeRewriter for InListSimplifier { type Node = Expr; - fn f_up(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = &expr { if let (Expr::InList(l1), Operator::And, Expr::InList(l2)) = (left.as_ref(), op, right.as_ref()) { if l1.expr == l2.expr && !l1.negated && !l2.negated { - return inlist_intersection(l1, l2, false); + return Ok(Transformed::yes(inlist_intersection(l1, l2, false)?)); } else if l1.expr == l2.expr && l1.negated && l2.negated { - return inlist_union(l1, l2, true); + return Ok(Transformed::yes(inlist_union(l1, l2, true)?)); } else if l1.expr == l2.expr && !l1.negated && l2.negated { - return inlist_except(l1, l2); + return Ok(Transformed::yes(inlist_except(l1, l2)?)); } else if l1.expr == l2.expr && l1.negated && !l2.negated { - return inlist_except(l2, l1); + return Ok(Transformed::yes(inlist_except(l2, l1)?)); } } else if let (Expr::InList(l1), Operator::Or, Expr::InList(l2)) = (left.as_ref(), op, right.as_ref()) { if l1.expr == l2.expr && l1.negated && l2.negated { - return inlist_intersection(l1, l2, true); + return Ok(Transformed::yes(inlist_intersection(l1, l2, true)?)); } } } - Ok(expr) + Ok(Transformed::no(expr)) } } diff --git a/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs 
b/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs index ea02c1f3af8a..ff50b337e158 100644 --- a/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs @@ -20,7 +20,7 @@ use std::borrow::Cow; use std::collections::HashSet; -use datafusion_common::tree_node::TreeNodeRewriter; +use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::Result; use datafusion_expr::expr::InList; use datafusion_expr::{BinaryExpr, Expr, Operator}; @@ -39,7 +39,7 @@ impl OrInListSimplifier { impl TreeNodeRewriter for OrInListSimplifier { type Node = Expr; - fn f_up(&mut self, expr: Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = &expr { if *op == Operator::Or { let left = as_inlist(left); @@ -66,13 +66,13 @@ impl TreeNodeRewriter for OrInListSimplifier { list, negated: false, }; - return Ok(Expr::InList(merged_inlist)); + return Ok(Transformed::yes(Expr::InList(merged_inlist))); } } } } - Ok(expr) + Ok(Transformed::no(expr)) } } diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 0232a28c722a..52c9eefb9bab 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -24,7 +24,7 @@ use arrow::datatypes::{ DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, }; use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS}; -use datafusion_common::tree_node::TreeNodeRewriter; +use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::{ internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, }; @@ -129,7 +129,7 @@ struct UnwrapCastExprRewriter { impl TreeNodeRewriter for UnwrapCastExprRewriter { type Node = Expr; - fn f_up(&mut self, expr: 
Expr) -> Result { + fn f_up(&mut self, expr: Expr) -> Result> { match &expr { // For case: // try_cast/cast(expr as data_type) op literal @@ -157,11 +157,11 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { try_cast_literal_to_type(left_lit_value, &expr_type)?; if let Some(value) = casted_scalar_value { // unwrap the cast/try_cast for the right expr - return Ok(binary_expr( + return Ok(Transformed::yes(binary_expr( lit(value), *op, expr.as_ref().clone(), - )); + ))); } } ( @@ -176,11 +176,11 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { try_cast_literal_to_type(right_lit_value, &expr_type)?; if let Some(value) = casted_scalar_value { // unwrap the cast/try_cast for the left expr - return Ok(binary_expr( + return Ok(Transformed::yes(binary_expr( expr.as_ref().clone(), *op, lit(value), - )); + ))); } } (_, _) => { @@ -189,7 +189,7 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { }; } // return the new binary op - Ok(binary_expr(left, *op, right)) + Ok(Transformed::yes(binary_expr(left, *op, right))) } // For case: // try_cast/cast(expr as left_type) in (expr1,expr2,expr3) @@ -213,12 +213,12 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { let internal_left_type = internal_left.get_type(&self.schema); if internal_left_type.is_err() { // error data type - return Ok(expr); + return Ok(Transformed::no(expr)); } let internal_left_type = internal_left_type?; if !is_support_data_type(&internal_left_type) { // not supported data type - return Ok(expr); + return Ok(Transformed::no(expr)); } let right_exprs = list .iter() @@ -253,17 +253,19 @@ impl TreeNodeRewriter for UnwrapCastExprRewriter { }) .collect::>>(); match right_exprs { - Ok(right_exprs) => { - Ok(in_list(internal_left, right_exprs, *negated)) - } - Err(_) => Ok(expr), + Ok(right_exprs) => Ok(Transformed::yes(in_list( + internal_left, + right_exprs, + *negated, + ))), + Err(_) => Ok(Transformed::no(expr)), } } else { - Ok(expr) + Ok(Transformed::no(expr)) } } // TODO: handle other expr type 
and dfs visit them - _ => Ok(expr), + _ => Ok(Transformed::no(expr)), } } } @@ -730,7 +732,7 @@ mod tests { let mut expr_rewriter = UnwrapCastExprRewriter { schema: schema.clone(), }; - expr.rewrite(&mut expr_rewriter).unwrap() + expr.rewrite(&mut expr_rewriter).map(|t| t.data).unwrap() } fn expr_test_schema() -> DFSchemaRef { diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index 9ee9be94a5f2..87e71e3458cd 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -263,11 +263,12 @@ impl EquivalenceGroup { .transform_up(&|expr| { for cls in self.iter() { if cls.contains(&expr) { - return Ok(Transformed::Yes(cls.canonical_expr().unwrap())); + return Ok(Transformed::yes(cls.canonical_expr().unwrap())); } } - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) }) + .map(|t| t.data) .unwrap_or(expr) } @@ -458,11 +459,12 @@ impl EquivalenceGroup { column.index() + left_size, )) as _; - return Ok(Transformed::Yes(new_column)); + return Ok(Transformed::yes(new_column)); } - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) }) + .map(|t| t.data) .unwrap(); result.add_equal_conditions(&new_lhs, &new_rhs); } diff --git a/datafusion/physical-expr/src/equivalence/mod.rs b/datafusion/physical-expr/src/equivalence/mod.rs index 387dce2cdc8b..43cb90e72f5f 100644 --- a/datafusion/physical-expr/src/equivalence/mod.rs +++ b/datafusion/physical-expr/src/equivalence/mod.rs @@ -48,12 +48,13 @@ pub fn add_offset_to_expr( offset: usize, ) -> Arc { expr.transform_down(&|e| match e.as_any().downcast_ref::() { - Some(col) => Ok(Transformed::Yes(Arc::new(Column::new( + Some(col) => Ok(Transformed::yes(Arc::new(Column::new( col.name(), offset + col.index(), )))), - None => Ok(Transformed::No(e)), + None => Ok(Transformed::no(e)), }) + .map(|t| t.data) .unwrap() // Note that we can safely unwrap here since our transform always returns // an `Ok` value. 
diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/equivalence/projection.rs index 0f92b2c2f431..a96fbb6e484b 100644 --- a/datafusion/physical-expr/src/equivalence/projection.rs +++ b/datafusion/physical-expr/src/equivalence/projection.rs @@ -68,10 +68,11 @@ impl ProjectionMapping { let matching_input_field = input_schema.field(idx); let matching_input_column = Column::new(matching_input_field.name(), idx); - Ok(Transformed::Yes(Arc::new(matching_input_column))) + Ok(Transformed::yes(Arc::new(matching_input_column))) } - None => Ok(Transformed::No(e)), + None => Ok(Transformed::no(e)), }) + .map(|t| t.data) .map(|source_expr| (source_expr, target_expr)) }) .collect::>>() diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index 2471d9249e16..cf05a97e21dd 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -778,6 +778,7 @@ impl EquivalenceProperties { pub fn get_expr_ordering(&self, expr: Arc) -> ExprOrdering { ExprOrdering::new_default(expr.clone()) .transform_up(&|expr| Ok(update_ordering(expr, self))) + .map(|t| t.data) // Guaranteed to always return `Ok`. 
.unwrap() } @@ -816,9 +817,9 @@ fn update_ordering( // We have a Literal, which is the other possible leaf node type: node.data = node.expr.get_ordering(&[]); } else { - return Transformed::No(node); + return Transformed::no(node); } - Transformed::Yes(node) + Transformed::yes(node) } /// This function determines whether the provided expression is constant diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index b04c66b23728..59c6886d0c0e 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -972,11 +972,12 @@ mod tests { _ => None, }; Ok(if let Some(transformed) = transformed { - Transformed::Yes(transformed) + Transformed::yes(transformed) } else { - Transformed::No(e) + Transformed::no(e) }) }) + .map(|t| t.data) .unwrap(); let expr3 = expr @@ -993,11 +994,12 @@ mod tests { _ => None, }; Ok(if let Some(transformed) = transformed { - Transformed::Yes(transformed) + Transformed::yes(transformed) } else { - Transformed::No(e) + Transformed::no(e) }) }) + .map(|t| t.data) .unwrap(); assert!(expr.ne(&expr2)); diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index a8d1e3638a17..253ed8da695b 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -32,6 +32,7 @@ use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::ColumnarValue; +use datafusion_common::tree_node::Transformed; use itertools::izip; /// Expression that can be evaluated against a RecordBatch @@ -185,7 +186,7 @@ pub type PhysicalExprRef = Arc; pub fn with_new_children_if_necessary( expr: Arc, children: Vec>, -) -> Result> { +) -> Result>> { let old_children = expr.children(); if children.len() != old_children.len() { internal_err!("PhysicalExpr: Wrong number of children") 
@@ -195,9 +196,9 @@ pub fn with_new_children_if_necessary( .zip(old_children.iter()) .any(|(c1, c2)| !Arc::data_ptr_eq(c1, c2)) { - expr.with_new_children(children) + Ok(Transformed::yes(expr.with_new_children(children)?)) } else { - Ok(expr) + Ok(Transformed::no(expr)) } } diff --git a/datafusion/physical-expr/src/tree_node.rs b/datafusion/physical-expr/src/tree_node.rs index 42dc6673af6a..8f21ffb82457 100644 --- a/datafusion/physical-expr/src/tree_node.rs +++ b/datafusion/physical-expr/src/tree_node.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use crate::physical_expr::{with_new_children_if_necessary, PhysicalExpr}; -use datafusion_common::tree_node::{ConcreteTreeNode, DynTreeNode}; +use datafusion_common::tree_node::{ConcreteTreeNode, DynTreeNode, Transformed}; use datafusion_common::Result; impl DynTreeNode for dyn PhysicalExpr { @@ -34,7 +34,7 @@ impl DynTreeNode for dyn PhysicalExpr { &self, arc_self: Arc, new_children: Vec>, - ) -> Result> { + ) -> Result>> { with_new_children_if_necessary(arc_self, new_children) } } @@ -63,7 +63,7 @@ impl ExprContext { pub fn update_expr_from_children(mut self) -> Result { let children_expr = self.children.iter().map(|c| c.expr.clone()).collect(); - self.expr = with_new_children_if_necessary(self.expr, children_expr)?; + self.expr = with_new_children_if_necessary(self.expr, children_expr)?.data; Ok(self) } } diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 8d4f4cad4afa..694a18e147d3 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -172,7 +172,7 @@ impl<'a, T, F: Fn(&ExprTreeNode) -> Result> // Set the data field of the input expression node to the corresponding node index. node.data = Some(node_idx); // Return the mutated expression node. - Ok(Transformed::Yes(node)) + Ok(Transformed::yes(node)) } } @@ -193,7 +193,9 @@ where constructor, }; // Use the builder to transform the expression tree node into a DAG. 
- let root = init.transform_up_mut(&mut |node| builder.mutate(node))?; + let root = init + .transform_up_mut(&mut |node| builder.mutate(node))? + .data; // Return a tuple containing the root node index and the DAG. Ok((root.data.unwrap(), builder.graph)) } @@ -230,13 +232,14 @@ pub fn reassign_predicate_columns( Err(_) if ignore_not_found => usize::MAX, Err(e) => return Err(e.into()), }; - return Ok(Transformed::Yes(Arc::new(Column::new( + return Ok(Transformed::yes(Arc::new(Column::new( column.name(), index, )))); } - Ok(Transformed::No(expr)) + Ok(Transformed::no(expr)) }) + .map(|t| t.data) } /// Reverses the ORDER BY expression, which is useful during equivalent window diff --git a/datafusion/physical-plan/src/empty.rs b/datafusion/physical-plan/src/empty.rs index 41c8dbed1453..bdb8234be791 100644 --- a/datafusion/physical-plan/src/empty.rs +++ b/datafusion/physical-plan/src/empty.rs @@ -165,7 +165,7 @@ mod tests { let schema = test::aggr_test_schema(); let empty = Arc::new(EmptyExec::new(schema.clone())); - let empty2 = with_new_children_if_necessary(empty.clone(), vec![])?.into(); + let empty2 = with_new_children_if_necessary(empty.clone(), vec![])?.data; assert_eq!(empty.schema(), empty2.schema()); let too_many_kids = vec![empty2]; diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index 9a4c98927683..3484ee45ba6a 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -284,14 +284,16 @@ pub fn convert_sort_expr_with_filter_schema( if all_columns_are_included { // Since we are sure that one to one column mapping includes all columns, we convert // the sort expression into a filter expression. 
- let converted_filter_expr = expr.transform_up(&|p| { - convert_filter_columns(p.as_ref(), &column_map).map(|transformed| { - match transformed { - Some(transformed) => Transformed::Yes(transformed), - None => Transformed::No(p), - } - }) - })?; + let converted_filter_expr = expr + .transform_up(&|p| { + convert_filter_columns(p.as_ref(), &column_map).map(|transformed| { + match transformed { + Some(transformed) => Transformed::yes(transformed), + None => Transformed::no(p), + } + }) + })? + .data; // Search the converted `PhysicalExpr` in filter expression; if an exact // match is found, use this sorted expression in graph traversals. if check_filter_expr_contains_sort_information( diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 70f315917351..073d5a035e0c 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -478,13 +478,17 @@ fn replace_on_columns_of_right_ordering( ) -> Result<()> { for (left_col, right_col) in on_columns { for item in right_ordering.iter_mut() { - let new_expr = item.expr.clone().transform_up(&|e| { - if e.eq(right_col) { - Ok(Transformed::Yes(left_col.clone())) - } else { - Ok(Transformed::No(e)) - } - })?; + let new_expr = item + .expr + .clone() + .transform_up(&|e| { + if e.eq(right_col) { + Ok(Transformed::yes(left_col.clone())) + } else { + Ok(Transformed::no(e)) + } + })? 
+ .data; item.expr = new_expr; } } diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index 0a9eab5c8633..0a147a29e1a8 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -489,9 +489,9 @@ pub fn with_new_children_if_necessary( .zip(old_children.iter()) .any(|(c1, c2)| !Arc::data_ptr_eq(c1, c2)) { - Ok(Transformed::Yes(plan.with_new_children(children)?)) + Ok(Transformed::yes(plan.with_new_children(children)?)) } else { - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) } } diff --git a/datafusion/physical-plan/src/placeholder_row.rs b/datafusion/physical-plan/src/placeholder_row.rs index 3ab3de62f37a..04482d7c1cc1 100644 --- a/datafusion/physical-plan/src/placeholder_row.rs +++ b/datafusion/physical-plan/src/placeholder_row.rs @@ -172,7 +172,7 @@ mod tests { let placeholder = Arc::new(PlaceholderRowExec::new(schema)); let placeholder_2 = - with_new_children_if_necessary(placeholder.clone(), vec![])?.into(); + with_new_children_if_necessary(placeholder.clone(), vec![])?.data; assert_eq!(placeholder.schema(), placeholder_2.schema()); let too_many_kids = vec![placeholder_2]; diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index 614ab990ac49..1683159f3cee 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -317,16 +317,17 @@ fn assign_work_table( ) } else { work_table_refs += 1; - Ok(Transformed::Yes(Arc::new( + Ok(Transformed::yes(Arc::new( exec.with_work_table(work_table.clone()), ))) } } else if plan.as_any().is::() { not_impl_err!("Recursive queries cannot be nested") } else { - Ok(Transformed::No(plan)) + Ok(Transformed::no(plan)) } }) + .map(|t| t.data) } impl Stream for RecursiveQueryStream { diff --git a/datafusion/physical-plan/src/tree_node.rs b/datafusion/physical-plan/src/tree_node.rs index b8a5f95c5325..c4223cb73430 100644 --- 
a/datafusion/physical-plan/src/tree_node.rs +++ b/datafusion/physical-plan/src/tree_node.rs @@ -34,8 +34,8 @@ impl DynTreeNode for dyn ExecutionPlan { &self, arc_self: Arc, new_children: Vec>, - ) -> Result> { - with_new_children_if_necessary(arc_self, new_children).map(Transformed::into) + ) -> Result>> { + with_new_children_if_necessary(arc_self, new_children) } } @@ -63,7 +63,7 @@ impl PlanContext { pub fn update_plan_from_children(mut self) -> Result { let children_plans = self.children.iter().map(|c| c.plan.clone()).collect(); - self.plan = with_new_children_if_necessary(self.plan, children_plans)?.into(); + self.plan = with_new_children_if_necessary(self.plan, children_plans)?.data; Ok(self) } } diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 0dc1258ebabe..3f6f3aa483ab 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -33,18 +33,20 @@ use std::collections::HashMap; /// Make a best-effort attempt at resolving all columns in the expression tree pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result { - expr.clone().transform_up(&|nested_expr| { - match nested_expr { - Expr::Column(col) => { - let field = plan.schema().field_from_column(&col)?; - Ok(Transformed::Yes(Expr::Column(field.qualified_column()))) - } - _ => { - // keep recursing - Ok(Transformed::No(nested_expr)) + expr.clone() + .transform_up(&|nested_expr| { + match nested_expr { + Expr::Column(col) => { + let field = plan.schema().field_from_column(&col)?; + Ok(Transformed::yes(Expr::Column(field.qualified_column()))) + } + _ => { + // keep recursing + Ok(Transformed::no(nested_expr)) + } } - } - }) + }) + .map(|t| t.data) } /// Rebuilds an `Expr` as a projection on top of a collection of `Expr`'s. 
@@ -66,13 +68,15 @@ pub(crate) fn rebase_expr( base_exprs: &[Expr], plan: &LogicalPlan, ) -> Result { - expr.clone().transform_down(&|nested_expr| { - if base_exprs.contains(&nested_expr) { - Ok(Transformed::Yes(expr_as_column_expr(&nested_expr, plan)?)) - } else { - Ok(Transformed::No(nested_expr)) - } - }) + expr.clone() + .transform_down(&|nested_expr| { + if base_exprs.contains(&nested_expr) { + Ok(Transformed::yes(expr_as_column_expr(&nested_expr, plan)?)) + } else { + Ok(Transformed::no(nested_expr)) + } + }) + .map(|t| t.data) } /// Determines if the set of `Expr`'s are a valid projection on the input @@ -170,16 +174,18 @@ pub(crate) fn resolve_aliases_to_exprs( expr: &Expr, aliases: &HashMap, ) -> Result { - expr.clone().transform_up(&|nested_expr| match nested_expr { - Expr::Column(c) if c.relation.is_none() => { - if let Some(aliased_expr) = aliases.get(&c.name) { - Ok(Transformed::Yes(aliased_expr.clone())) - } else { - Ok(Transformed::No(Expr::Column(c))) + expr.clone() + .transform_up(&|nested_expr| match nested_expr { + Expr::Column(c) if c.relation.is_none() => { + if let Some(aliased_expr) = aliases.get(&c.name) { + Ok(Transformed::yes(aliased_expr.clone())) + } else { + Ok(Transformed::no(Expr::Column(c))) + } } - } - _ => Ok(Transformed::No(nested_expr)), - }) + _ => Ok(Transformed::no(nested_expr)), + }) + .map(|t| t.data) } /// given a slice of window expressions sharing the same sort key, find their common partition diff --git a/docs/source/library-user-guide/working-with-exprs.md b/docs/source/library-user-guide/working-with-exprs.md index b128d661f31a..ab2b0a2ce960 100644 --- a/docs/source/library-user-guide/working-with-exprs.md +++ b/docs/source/library-user-guide/working-with-exprs.md @@ -92,7 +92,7 @@ In our example, we'll use rewriting to update our `add_one` UDF, to be rewritten ### Rewriting with `transform` -To implement the inlining, we'll need to write a function that takes an `Expr` and returns a `Result`. 
If the expression is _not_ to be rewritten `Transformed::No` is used to wrap the original `Expr`. If the expression _is_ to be rewritten, `Transformed::Yes` is used to wrap the new `Expr`. +To implement the inlining, we'll need to write a function that takes an `Expr` and returns a `Result`. If the expression is _not_ to be rewritten `Transformed::no` is used to wrap the original `Expr`. If the expression _is_ to be rewritten, `Transformed::yes` is used to wrap the new `Expr`. ```rust fn rewrite_add_one(expr: Expr) -> Result { @@ -102,9 +102,9 @@ fn rewrite_add_one(expr: Expr) -> Result { let input_arg = scalar_fun.args[0].clone(); let new_expression = input_arg + lit(1i64); - Transformed::Yes(new_expression) + Transformed::yes(new_expression) } - _ => Transformed::No(expr), + _ => Transformed::no(expr), }) }) } From 6f763dcc3d800873d5932919fa6a1510f1aec4b9 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 31 Jan 2024 20:59:24 +0100 Subject: [PATCH 05/40] minor fixes --- datafusion-examples/examples/rewrite_expr.rs | 3 ++ datafusion/expr/src/tree_node/expr.rs | 36 +++++++------------- 2 files changed, 15 insertions(+), 24 deletions(-) diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs index 88b43ccdede7..cc27125c64b5 100644 --- a/datafusion-examples/examples/rewrite_expr.rs +++ b/datafusion-examples/examples/rewrite_expr.rs @@ -103,6 +103,7 @@ impl MyAnalyzerRule { _ => Transformed::no(plan), }) }) + .map(|t| t.data) } fn analyze_expr(expr: Expr) -> Result { @@ -118,6 +119,7 @@ impl MyAnalyzerRule { _ => Transformed::no(expr), }) }) + .map(|t| t.data) } } @@ -183,6 +185,7 @@ fn my_rewrite(expr: Expr) -> Result { _ => Transformed::no(expr), }) }) + .map(|t| t.data) } #[derive(Default)] diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 5e7dd1990923..5b2e9affc49d 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ 
-214,34 +214,26 @@ impl TreeNode for Expr { case_insensitive, )) }), - Expr::Not(expr) => transform_box(expr, &mut f)?.map_data(|be| Expr::Not(be)), + Expr::Not(expr) => transform_box(expr, &mut f)?.map_data(Expr::Not), Expr::IsNotNull(expr) => { - transform_box(expr, &mut f)?.map_data(|be| Expr::IsNotNull(be)) - } - Expr::IsNull(expr) => { - transform_box(expr, &mut f)?.map_data(|be| Expr::IsNull(be)) - } - Expr::IsTrue(expr) => { - transform_box(expr, &mut f)?.map_data(|be| Expr::IsTrue(be)) - } - Expr::IsFalse(expr) => { - transform_box(expr, &mut f)?.map_data(|be| Expr::IsFalse(be)) + transform_box(expr, &mut f)?.map_data(Expr::IsNotNull) } + Expr::IsNull(expr) => transform_box(expr, &mut f)?.map_data(Expr::IsNull), + Expr::IsTrue(expr) => transform_box(expr, &mut f)?.map_data(Expr::IsTrue), + Expr::IsFalse(expr) => transform_box(expr, &mut f)?.map_data(Expr::IsFalse), Expr::IsUnknown(expr) => { - transform_box(expr, &mut f)?.map_data(|be| Expr::IsUnknown(be)) + transform_box(expr, &mut f)?.map_data(Expr::IsUnknown) } Expr::IsNotTrue(expr) => { - transform_box(expr, &mut f)?.map_data(|be| Expr::IsNotTrue(be)) + transform_box(expr, &mut f)?.map_data(Expr::IsNotTrue) } Expr::IsNotFalse(expr) => { - transform_box(expr, &mut f)?.map_data(|be| Expr::IsNotFalse(be)) + transform_box(expr, &mut f)?.map_data(Expr::IsNotFalse) } Expr::IsNotUnknown(expr) => { - transform_box(expr, &mut f)?.map_data(|be| Expr::IsNotUnknown(be)) - } - Expr::Negative(expr) => { - transform_box(expr, &mut f)?.map_data(|be| Expr::Negative(be)) + transform_box(expr, &mut f)?.map_data(Expr::IsNotUnknown) } + Expr::Negative(expr) => transform_box(expr, &mut f)?.map_data(Expr::Negative), Expr::Between(Between { expr, negated, @@ -310,9 +302,7 @@ impl TreeNode for Expr { Ok(Expr::ScalarFunction(ScalarFunction::new_udf(fun, new_args))) } ScalarFunctionDefinition::Name(_) => { - return internal_err!( - "Function `Expr` with name should be resolved." 
- ); + internal_err!("Function `Expr` with name should be resolved.") } })? } @@ -379,9 +369,7 @@ impl TreeNode for Expr { ))) } AggregateFunctionDefinition::Name(_) => { - return internal_err!( - "Function `Expr` with name should be resolved." - ); + internal_err!("Function `Expr` with name should be resolved.") } })?, Expr::GroupingSet(grouping_set) => match grouping_set { From 84d91c6030852c9a0f275a20a8e233eae739aa42 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 1 Feb 2024 09:39:15 +0100 Subject: [PATCH 06/40] fix --- datafusion/common/src/tree_node.rs | 32 ++++++++++++------------ datafusion/expr/src/tree_node/expr.rs | 36 +++++++++++++-------------- 2 files changed, 33 insertions(+), 35 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index f1619257619b..3008c7c9b0e0 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -135,9 +135,9 @@ pub trait TreeNode: Sized { FD: FnMut(Self) -> Result>, FU: FnMut(Self) -> Result>, { - f_down(self)?.and_then_transform_children(|t| { - t.map_children(|node| node.transform(f_down, f_up))? - .and_then_transform_sibling(f_up) + f_down(self)?.and_then_transform_children(|n| { + n.map_children(|c| c.transform(f_down, f_up))? + .and_then_transform(f_up) }) } @@ -148,7 +148,7 @@ pub trait TreeNode: Sized { where F: Fn(Self) -> Result>, { - f(self)?.and_then_transform_children(|t| t.map_children(|n| n.transform_down(f))) + f(self)?.and_then_transform_children(|n| n.map_children(|c| c.transform_down(f))) } /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its @@ -159,7 +159,7 @@ pub trait TreeNode: Sized { F: FnMut(Self) -> Result>, { f(self)? 
- .and_then_transform_children(|t| t.map_children(|n| n.transform_down_mut(f))) + .and_then_transform_children(|n| n.map_children(|c| c.transform_down_mut(f))) } /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its @@ -169,8 +169,8 @@ pub trait TreeNode: Sized { where F: Fn(Self) -> Result>, { - self.map_children(|node| node.transform_up(f))? - .and_then_transform_sibling(f) + self.map_children(|c| c.transform_up(f))? + .and_then_transform(f) } /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its @@ -180,8 +180,8 @@ pub trait TreeNode: Sized { where F: FnMut(Self) -> Result>, { - self.map_children(|n| n.transform_up_mut(f))? - .and_then_transform_sibling(f) + self.map_children(|c| c.transform_up_mut(f))? + .and_then_transform(f) } /// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for @@ -212,9 +212,9 @@ pub trait TreeNode: Sized { self, rewriter: &mut R, ) -> Result> { - rewriter.f_down(self)?.and_then_transform_children(|t| { - t.map_children(|n| n.rewrite(rewriter))? - .and_then_transform_sibling(|t| rewriter.f_up(t)) + rewriter.f_down(self)?.and_then_transform_children(|n| { + n.map_children(|c| c.rewrite(rewriter))? 
+ .and_then_transform(|n| rewriter.f_up(n)) }) } @@ -349,7 +349,7 @@ impl Transformed { }) } - fn and_then_transform Result>>( + fn and_then Result>>( self, f: F, children: bool, @@ -375,18 +375,18 @@ impl Transformed { }) } - pub fn and_then_transform_sibling Result>>( + pub fn and_then_transform Result>>( self, f: F, ) -> Result> { - self.and_then_transform(f, false) + self.and_then(f, false) } pub fn and_then_transform_children Result>>( self, f: F, ) -> Result> { - self.and_then_transform(f, true) + self.and_then(f, true) } } diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 5b2e9affc49d..0bd36caaa9d0 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -164,7 +164,7 @@ impl TreeNode for Expr { Expr::BinaryExpr(BinaryExpr { left, op, right }) => { transform_box(left, &mut f)? .map_data(|new_left| (new_left, right)) - .and_then_transform_sibling(|(new_left, right)| { + .and_then_transform(|(new_left, right)| { Ok(transform_box(right, &mut f)? .map_data(|new_right| (new_left, new_right))) })? @@ -180,7 +180,7 @@ impl TreeNode for Expr { case_insensitive, }) => transform_box(expr, &mut f)? .map_data(|new_expr| (new_expr, pattern)) - .and_then_transform_sibling(|(new_expr, pattern)| { + .and_then_transform(|(new_expr, pattern)| { Ok(transform_box(pattern, &mut f)? .map_data(|new_pattern| (new_expr, new_pattern))) })? @@ -201,7 +201,7 @@ impl TreeNode for Expr { case_insensitive, }) => transform_box(expr, &mut f)? .map_data(|new_expr| (new_expr, pattern)) - .and_then_transform_sibling(|(new_expr, pattern)| { + .and_then_transform(|(new_expr, pattern)| { Ok(transform_box(pattern, &mut f)? .map_data(|new_pattern| (new_expr, new_pattern))) })? @@ -241,11 +241,11 @@ impl TreeNode for Expr { high, }) => transform_box(expr, &mut f)? 
.map_data(|new_expr| (new_expr, low, high)) - .and_then_transform_sibling(|(new_expr, low, high)| { + .and_then_transform(|(new_expr, low, high)| { Ok(transform_box(low, &mut f)? .map_data(|new_low| (new_expr, new_low, high))) })? - .and_then_transform_sibling(|(new_expr, new_low, high)| { + .and_then_transform(|(new_expr, new_low, high)| { Ok(transform_box(high, &mut f)? .map_data(|new_high| (new_expr, new_low, new_high))) })? @@ -258,13 +258,13 @@ impl TreeNode for Expr { else_expr, }) => transform_option_box(expr, &mut f)? .map_data(|new_expr| (new_expr, when_then_expr, else_expr)) - .and_then_transform_sibling(|(new_expr, when_then_expr, else_expr)| { + .and_then_transform(|(new_expr, when_then_expr, else_expr)| { Ok(when_then_expr .into_iter() .map_till_continue_and_collect(|(when, then)| { transform_box(when, &mut f)? .map_data(|new_when| (new_when, then)) - .and_then_transform_sibling(|(new_when, then)| { + .and_then_transform(|(new_when, then)| { Ok(transform_box(then, &mut f)? .map_data(|new_then| (new_when, new_then))) }) @@ -273,13 +273,11 @@ impl TreeNode for Expr { (new_expr, new_when_then_expr, else_expr) })) })? - .and_then_transform_sibling( - |(new_expr, new_when_then_expr, else_expr)| { - Ok(transform_option_box(else_expr, &mut f)?.map_data( - |new_else_expr| (new_expr, new_when_then_expr, new_else_expr), - )) - }, - )? + .and_then_transform(|(new_expr, new_when_then_expr, else_expr)| { + Ok(transform_option_box(else_expr, &mut f)?.map_data( + |new_else_expr| (new_expr, new_when_then_expr, new_else_expr), + )) + })? .map_data(|(new_expr, new_when_then_expr, new_else_expr)| { Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr)) }), @@ -314,12 +312,12 @@ impl TreeNode for Expr { window_frame, }) => transform_vec(args, &mut f)? 
.map_data(|new_args| (new_args, partition_by, order_by)) - .and_then_transform_sibling(|(new_args, partition_by, order_by)| { + .and_then_transform(|(new_args, partition_by, order_by)| { Ok(transform_vec(partition_by, &mut f)?.map_data( |new_partition_by| (new_args, new_partition_by, order_by), )) })? - .and_then_transform_sibling(|(new_args, new_partition_by, order_by)| { + .and_then_transform(|(new_args, new_partition_by, order_by)| { Ok(transform_vec(order_by, &mut f)?.map_data(|new_order_by| { (new_args, new_partition_by, new_order_by) })) @@ -341,11 +339,11 @@ impl TreeNode for Expr { order_by, }) => transform_vec(args, &mut f)? .map_data(|new_args| (new_args, filter, order_by)) - .and_then_transform_sibling(|(new_args, filter, order_by)| { + .and_then_transform(|(new_args, filter, order_by)| { Ok(transform_option_box(filter, &mut f)? .map_data(|new_filter| (new_args, new_filter, order_by))) })? - .and_then_transform_sibling(|(new_args, new_filter, order_by)| { + .and_then_transform(|(new_args, new_filter, order_by)| { Ok(transform_option_vec(order_by, &mut f)? .map_data(|new_order_by| (new_args, new_filter, new_order_by))) })? @@ -390,7 +388,7 @@ impl TreeNode for Expr { negated, }) => transform_box(expr, &mut f)? .map_data(|new_expr| (new_expr, list)) - .and_then_transform_sibling(|(new_expr, list)| { + .and_then_transform(|(new_expr, list)| { Ok(transform_vec(list, &mut f)? .map_data(|new_list| (new_expr, new_list))) })? 
From 3653aa69362753e97026cffdcedb202d9c151b07 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 1 Feb 2024 15:37:52 +0100 Subject: [PATCH 07/40] don't trust `t.transformed` coming from transformation closures, keep the old way of detecting if changes were made --- datafusion/common/src/tree_node.rs | 22 ++++++++++++------- .../enforce_distribution.rs | 2 +- .../src/physical_optimizer/enforce_sorting.rs | 4 ++-- .../replace_with_order_preserving_variants.rs | 4 ++-- datafusion/expr/src/tree_node/plan.rs | 11 ++++++---- datafusion/physical-expr/src/tree_node.rs | 9 ++++---- datafusion/physical-plan/src/tree_node.rs | 9 ++++---- 7 files changed, 36 insertions(+), 25 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 3008c7c9b0e0..7d3c70562e3b 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -479,9 +479,13 @@ impl TreeNode for Arc { let children = self.arc_children(); if !children.is_empty() { let t = children.into_iter().map_till_continue_and_collect(f)?; - // TODO: once we trust `t.transformed` don't create new node if not necessary + // TODO: Currently `t.transformed` quality comes from if the transformation + // closures fill the field correctly. Once we trust `t.transformed` we can + // remove the additional `t2` check. + // Please note that we need to propagate up `t.tnr` though. let arc_self = Arc::clone(&self); - self.with_new_arc_children(arc_self, t.data) + let t2 = self.with_new_arc_children(arc_self, t.data)?; + Ok(Transformed::new(t2.data, t2.transformed, t.tnr)) } else { Ok(Transformed::no(self)) } @@ -499,7 +503,7 @@ pub trait ConcreteTreeNode: Sized { fn take_children(self) -> (Self, Vec); /// Reattaches updated child nodes to the node, returning the updated node. 
- fn with_new_children(self, children: Vec) -> Result; + fn with_new_children(self, children: Vec) -> Result>; } impl TreeNode for T { @@ -520,11 +524,13 @@ impl TreeNode for T { { let (new_self, children) = self.take_children(); if !children.is_empty() { - children - .into_iter() - .map_till_continue_and_collect(f)? - // TODO: once we trust `transformed` don't create new node if not necessary - .flat_map_data(|new_children| new_self.with_new_children(new_children)) + let t = children.into_iter().map_till_continue_and_collect(f)?; + // TODO: Currently `t.transformed` quality comes from if the transformation + // closures fill the field correctly. Once we trust `t.transformed` we can + // remove the additional `t2` check. + // Please note that we need to propagate up `t.tnr` though. + let t2 = new_self.with_new_children(t.data)?; + Ok(Transformed::new(t2.data, t2.transformed, t.tnr)) } else { Ok(Transformed::no(new_self)) } diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index ff033d168e77..f0cad14a0fa2 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -1030,7 +1030,7 @@ fn replace_order_preserving_variants( } } - context.update_plan_from_children() + context.update_plan_from_children().map(|t| t.data) } /// This utility function adds a [`SortExec`] above an operator according to the diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index 7a3b2c512111..c8e5286e1e3f 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -113,7 +113,7 @@ fn update_sort_ctx_children( } node.data = data; - node.update_plan_from_children() + node.update_plan_from_children().map(|t| t.data) } /// This object is used within the 
[`EnforceSorting`] rule to track the closest @@ -477,7 +477,7 @@ fn remove_corresponding_coalesce_in_sub_plan( .collect::>()?; } - requirements.update_plan_from_children() + requirements.update_plan_from_children().map(|t| t.data) } /// Updates child to remove the unnecessary sort below it. diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index 4629152cddd9..19add190043e 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -138,7 +138,7 @@ fn plan_with_order_preserving_variants( } } - sort_input.update_plan_from_children() + sort_input.update_plan_from_children().map(|t| t.data) } /// Calculates the updated plan by replacing operators that preserve ordering @@ -184,7 +184,7 @@ fn plan_with_order_breaking_variants( let coalesce = CoalescePartitionsExec::new(child); sort_input.plan = Arc::new(coalesce) as _; } else { - return sort_input.update_plan_from_children(); + return sort_input.update_plan_from_children().map(|t| t.data); } sort_input.children[0].data = false; diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs index 917719907776..19dc9085602f 100644 --- a/datafusion/expr/src/tree_node/plan.rs +++ b/datafusion/expr/src/tree_node/plan.rs @@ -87,12 +87,15 @@ impl TreeNode for LogicalPlan { .iter() .map(|c| (*c).clone()) .map_till_continue_and_collect(f)?; - // TODO: once we trust `t.transformed` remove additional check - if old_children + // TODO: Currently `t.transformed` quality comes from if the transformation + // closures fill the field correctly. Once we trust `t.transformed` we can remove + // the additional `t2` check. + // Please note that we need to propagate up `t.tnr` though. 
+ let t2 = old_children .into_iter() .zip(t.data.iter()) - .any(|(c1, c2)| c1 != c2) - { + .any(|(c1, c2)| c1 != c2); + if t2 { Ok(Transformed::new( self.with_new_exprs(self.expressions(), t.data)?, true, diff --git a/datafusion/physical-expr/src/tree_node.rs b/datafusion/physical-expr/src/tree_node.rs index 8f21ffb82457..68a5fc06e8ee 100644 --- a/datafusion/physical-expr/src/tree_node.rs +++ b/datafusion/physical-expr/src/tree_node.rs @@ -61,10 +61,11 @@ impl ExprContext { } } - pub fn update_expr_from_children(mut self) -> Result { + pub fn update_expr_from_children(mut self) -> Result> { let children_expr = self.children.iter().map(|c| c.expr.clone()).collect(); - self.expr = with_new_children_if_necessary(self.expr, children_expr)?.data; - Ok(self) + let t = with_new_children_if_necessary(self.expr, children_expr)?; + self.expr = t.data; + Ok(Transformed::new(self, t.transformed, t.tnr)) } } @@ -93,7 +94,7 @@ impl ConcreteTreeNode for ExprContext { (self, children) } - fn with_new_children(mut self, children: Vec) -> Result { + fn with_new_children(mut self, children: Vec) -> Result> { self.children = children; self.update_expr_from_children() } diff --git a/datafusion/physical-plan/src/tree_node.rs b/datafusion/physical-plan/src/tree_node.rs index c4223cb73430..a3099b0ac934 100644 --- a/datafusion/physical-plan/src/tree_node.rs +++ b/datafusion/physical-plan/src/tree_node.rs @@ -61,10 +61,11 @@ impl PlanContext { } } - pub fn update_plan_from_children(mut self) -> Result { + pub fn update_plan_from_children(mut self) -> Result> { let children_plans = self.children.iter().map(|c| c.plan.clone()).collect(); - self.plan = with_new_children_if_necessary(self.plan, children_plans)?.data; - Ok(self) + let t = with_new_children_if_necessary(self.plan, children_plans)?; + self.plan = t.data; + Ok(Transformed::new(self, t.transformed, t.tnr)) } } @@ -94,7 +95,7 @@ impl ConcreteTreeNode for PlanContext { (self, children) } - fn with_new_children(mut self, children: Vec) 
-> Result { + fn with_new_children(mut self, children: Vec) -> Result> { self.children = children; self.update_plan_from_children() } From 5b9453132d4594f54b3f4aa784c8ba325bb80bba Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Fri, 2 Feb 2024 10:48:23 +0100 Subject: [PATCH 08/40] rephrase todo comment, always propagate up `t.transformed` from the transformation closure, fix projection pushdown closure --- datafusion/common/src/tree_node.rs | 28 ++++++++++++------- .../physical_optimizer/projection_pushdown.rs | 4 ++- datafusion/expr/src/tree_node/plan.rs | 16 +++++++---- 3 files changed, 31 insertions(+), 17 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 7d3c70562e3b..9ca9524bb74b 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -479,13 +479,17 @@ impl TreeNode for Arc { let children = self.arc_children(); if !children.is_empty() { let t = children.into_iter().map_till_continue_and_collect(f)?; - // TODO: Currently `t.transformed` quality comes from if the transformation - // closures fill the field correctly. Once we trust `t.transformed` we can - // remove the additional `t2` check. - // Please note that we need to propagate up `t.tnr` though. + // TODO: Currently `assert_eq!(t.transformed, t2.transformed)` fails as + // `t.transformed` quality comes from if the transformation closures fill the + // field correctly. + // Once we trust `t.transformed` we can remove the additional check in + // `with_new_arc_children()`. let arc_self = Arc::clone(&self); let t2 = self.with_new_arc_children(arc_self, t.data)?; - Ok(Transformed::new(t2.data, t2.transformed, t.tnr)) + + // Propagate up `t.transformed` and `t.tnr` along with the node containing + // transformed children. 
+ Ok(Transformed::new(t2.data, t.transformed, t.tnr)) } else { Ok(Transformed::no(self)) } @@ -525,12 +529,16 @@ impl TreeNode for T { let (new_self, children) = self.take_children(); if !children.is_empty() { let t = children.into_iter().map_till_continue_and_collect(f)?; - // TODO: Currently `t.transformed` quality comes from if the transformation - // closures fill the field correctly. Once we trust `t.transformed` we can - // remove the additional `t2` check. - // Please note that we need to propagate up `t.tnr` though. + // TODO: Currently `assert_eq!(t.transformed, t2.transformed)` fails as + // `t.transformed` quality comes from if the transformation closures fill the + // field correctly. + // Once we trust `t.transformed` we can remove the additional check in + // `with_new_children()`. let t2 = new_self.with_new_children(t.data)?; - Ok(Transformed::new(t2.data, t2.transformed, t.tnr)) + + // Propagate up `t.transformed` and `t.tnr` along with the node containing + // transformed children. 
+ Ok(Transformed::new(t2.data, t.transformed, t.tnr)) } else { Ok(Transformed::no(new_self)) } diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index 3f7ca5ae6b41..07b362b628c0 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -111,7 +111,9 @@ pub fn remove_unnecessary_projections( let maybe_unified = try_unifying_projections(projection, child_projection)?; return if let Some(new_plan) = maybe_unified { // To unify 3 or more sequential projections: - remove_unnecessary_projections(new_plan) + Ok(Transformed::yes( + remove_unnecessary_projections(new_plan)?.data, + )) } else { Ok(Transformed::no(plan)) }; diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs index 19dc9085602f..5d05be5b0202 100644 --- a/datafusion/expr/src/tree_node/plan.rs +++ b/datafusion/expr/src/tree_node/plan.rs @@ -87,22 +87,26 @@ impl TreeNode for LogicalPlan { .iter() .map(|c| (*c).clone()) .map_till_continue_and_collect(f)?; - // TODO: Currently `t.transformed` quality comes from if the transformation - // closures fill the field correctly. Once we trust `t.transformed` we can remove - // the additional `t2` check. - // Please note that we need to propagate up `t.tnr` though. + // TODO: Currently `assert_eq!(t.transformed, t2)` fails as + // `t.transformed` quality comes from if the transformation closures fill the + // field correctly. + // Once we trust `t.transformed` we can remove the additional check in + // `t2`. let t2 = old_children .into_iter() .zip(t.data.iter()) .any(|(c1, c2)| c1 != c2); + + // Propagate up `t.transformed` and `t.tnr` along with the node containing + // transformed children. 
if t2 { Ok(Transformed::new( self.with_new_exprs(self.expressions(), t.data)?, - true, + t.transformed, t.tnr, )) } else { - Ok(Transformed::new(self, false, t.tnr)) + Ok(Transformed::new(self, t.transformed, t.tnr)) } } } From 5c2baef71a66708060f46490bdc5f79ce8e1b5a3 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Wed, 7 Feb 2024 20:20:21 +0100 Subject: [PATCH 09/40] Fix `TreeNodeRecursion` docs --- datafusion/common/src/tree_node.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 9ca9524bb74b..a1aa04ac1b3e 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -293,8 +293,14 @@ pub trait TreeNodeRewriter: Sized { pub enum TreeNodeRecursion { /// Continue recursion with the next node. Continue, - /// Skip the current subtree. + + /// Do not recurse into children. + /// Has effect only if returned from top-down transform closures or + /// [`TreeNodeVisitor::pre_visit`] or [`TreeNodeRewriter::f_down`]. + /// If returned from bottom-up transform closures or [`TreeNodeVisitor::post_visit`] or + /// [`TreeNodeRewriter::f_up`] then works as [`TreeNodeRecursion::Continue`]. Skip, + /// Stop recursion. 
Stop, } From dcf0189f3de0b1754a1c934df82fb6ec393bb9aa Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Sat, 10 Feb 2024 08:48:59 +0100 Subject: [PATCH 10/40] extend Skip (Prune) functionality to Jump as it is defined in https://synnada.notion.site/synnada/TreeNode-Design-Proposal-bceac27d18504a2085145550e267c4c1 --- datafusion/common/src/tree_node.rs | 73 +++++++++---------- .../core/src/datasource/listing/helpers.rs | 2 +- .../physical_plan/parquet/row_filter.rs | 4 +- datafusion/expr/src/logical_plan/plan.rs | 4 +- datafusion/expr/src/utils.rs | 2 +- .../optimizer/src/common_subexpr_eliminate.rs | 12 +-- datafusion/optimizer/src/decorrelate.rs | 6 +- datafusion/optimizer/src/push_down_filter.rs | 2 +- .../optimizer/src/scalar_subquery_to_join.rs | 2 +- 9 files changed, 50 insertions(+), 57 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index a1aa04ac1b3e..428cba5a8ab8 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use crate::Result; /// If the function returns [`TreeNodeRecursion::Continue`], the normal execution of the -/// function continues. If it returns [`TreeNodeRecursion::Skip`], the function returns +/// function continues. If it returns [`TreeNodeRecursion::Jump`], the function returns /// with [`TreeNodeRecursion::Continue`] to jump next recursion step, bypassing further /// exploration of the current step. In case of [`TreeNodeRecursion::Stop`], the function /// return with [`TreeNodeRecursion::Stop`] and recursion halts. @@ -34,7 +34,7 @@ macro_rules! 
handle_tree_recursion { TreeNodeRecursion::Continue => {} // If the recursion should skip, do not apply to its children, let // the recursion continue: - TreeNodeRecursion::Skip => return Ok(TreeNodeRecursion::Continue), + TreeNodeRecursion::Jump => return Ok(TreeNodeRecursion::Continue), // If the recursion should stop, do not apply to its children: TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop), } @@ -135,7 +135,7 @@ pub trait TreeNode: Sized { FD: FnMut(Self) -> Result>, FU: FnMut(Self) -> Result>, { - f_down(self)?.and_then_transform_children(|n| { + f_down(self)?.and_then_transform_on_continue(|n| { n.map_children(|c| c.transform(f_down, f_up))? .and_then_transform(f_up) }) @@ -148,7 +148,8 @@ pub trait TreeNode: Sized { where F: Fn(Self) -> Result>, { - f(self)?.and_then_transform_children(|n| n.map_children(|c| c.transform_down(f))) + f(self)? + .and_then_transform_on_continue(|n| n.map_children(|c| c.transform_down(f))) } /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its @@ -158,8 +159,9 @@ pub trait TreeNode: Sized { where F: FnMut(Self) -> Result>, { - f(self)? - .and_then_transform_children(|n| n.map_children(|c| c.transform_down_mut(f))) + f(self)?.and_then_transform_on_continue(|n| { + n.map_children(|c| c.transform_down_mut(f)) + }) } /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its @@ -212,7 +214,7 @@ pub trait TreeNode: Sized { self, rewriter: &mut R, ) -> Result> { - rewriter.f_down(self)?.and_then_transform_children(|n| { + rewriter.f_down(self)?.and_then_transform_on_continue(|n| { n.map_children(|c| c.rewrite(rewriter))? 
.and_then_transform(|n| rewriter.f_up(n)) }) @@ -252,7 +254,7 @@ pub trait TreeNode: Sized { /// siblings of that tree node are visited, nor is post_visit /// called on its parent tree node /// -/// If [`TreeNodeRecursion::Skip`] is returned on a call to pre_visit, no +/// If [`TreeNodeRecursion::Jump`] is returned on a call to pre_visit, no /// children of that tree node are visited. pub trait TreeNodeVisitor: Sized { /// The node type which is visitable. @@ -294,12 +296,12 @@ pub enum TreeNodeRecursion { /// Continue recursion with the next node. Continue, - /// Do not recurse into children. - /// Has effect only if returned from top-down transform closures or - /// [`TreeNodeVisitor::pre_visit`] or [`TreeNodeRewriter::f_down`]. - /// If returned from bottom-up transform closures or [`TreeNodeVisitor::post_visit`] or - /// [`TreeNodeRewriter::f_up`] then works as [`TreeNodeRecursion::Continue`]. - Skip, + /// In top-down traversals skip recursing into children but continue with the next + /// node, which actually means pruning of the subtree. + /// In bottom-up traversals bypass calling bottom-up closures till the next leaf node. + /// In combined traversals bypass calling bottom-up closures till the first top-down + /// closure. + Jump, /// Stop recursion. Stop, @@ -358,14 +360,12 @@ impl Transformed { fn and_then Result>>( self, f: F, - children: bool, + return_continue_on_jump: bool, ) -> Result> { match self.tnr { TreeNodeRecursion::Continue => {} - TreeNodeRecursion::Skip => { - // If the next transformation would happen on children return immediately - // on `Skip`. 
- if children { + TreeNodeRecursion::Jump => { + if return_continue_on_jump { return Ok(Transformed { tnr: TreeNodeRecursion::Continue, ..self @@ -388,7 +388,7 @@ impl Transformed { self.and_then(f, false) } - pub fn and_then_transform_children Result>>( + pub fn and_then_transform_on_continue Result>>( self, f: F, ) -> Result> { @@ -418,26 +418,19 @@ impl TransformedIterator for I { let mut new_transformed = false; let new_data = self .map(|i| { - if new_tnr == TreeNodeRecursion::Continue - || new_tnr == TreeNodeRecursion::Skip - { - let Transformed { - data, - transformed, - tnr, - } = f(i)?; - new_tnr = if tnr == TreeNodeRecursion::Skip { - // Iterator always considers the elements as siblings so `Skip` - // can be safely converted to `Continue`. - TreeNodeRecursion::Continue - } else { - tnr - }; - new_transformed |= transformed; - Ok(data) - } else { - Ok(i) - } + Ok(match new_tnr { + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => { + let Transformed { + data, + transformed, + tnr, + } = f(i)?; + new_tnr = tnr; + new_transformed |= transformed; + data + } + TreeNodeRecursion::Stop => i, + }) }) .collect::>>()?; Ok(Transformed { diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index 508343671b90..78955104c72a 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -57,7 +57,7 @@ pub fn expr_applicable_for_cols(col_names: &[String], expr: &Expr) -> bool { Expr::Column(Column { ref name, .. 
}) => { is_applicable &= col_names.contains(name); if is_applicable { - Ok(TreeNodeRecursion::Skip) + Ok(TreeNodeRecursion::Jump) } else { Ok(TreeNodeRecursion::Stop) } diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs index bdd607095f44..06687dde6baf 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs @@ -223,13 +223,13 @@ impl<'a> TreeNodeRewriter for FilterCandidateBuilder<'a> { if DataType::is_nested(self.file_schema.field(idx).data_type()) { self.non_primitive_columns = true; - return Ok(Transformed::new(node, false, TreeNodeRecursion::Skip)); + return Ok(Transformed::new(node, false, TreeNodeRecursion::Jump)); } } else if self.table_schema.index_of(column.name()).is_err() { // If the column does not exist in the (un-projected) table schema then // it must be a projected column. self.projected_columns = true; - return Ok(Transformed::new(node, false, TreeNodeRecursion::Skip)); + return Ok(Transformed::new(node, false, TreeNodeRecursion::Jump)); } } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index e965d6ce541d..8940036bbd55 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -652,12 +652,12 @@ impl LogicalPlan { | Expr::ScalarSubquery(_) | Expr::InSubquery(_) => { // subqueries could contain aliases so we don't recurse into those - Ok(Transformed::new(expr, false, TreeNodeRecursion::Skip)) + Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)) } Expr::Alias(_) => Ok(Transformed::new( expr.unalias(), true, - TreeNodeRecursion::Skip, + TreeNodeRecursion::Jump, )), _ => Ok(Transformed::no(expr)), } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 07095c6e2cc4..6ccce160f507 100644 --- a/datafusion/expr/src/utils.rs +++ 
b/datafusion/expr/src/utils.rs @@ -663,7 +663,7 @@ where exprs.push(expr.clone()) } // stop recursing down this expr once we find a match - return Ok(TreeNodeRecursion::Skip); + return Ok(TreeNodeRecursion::Jump); } Ok(TreeNodeRecursion::Continue) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index fafc6340f1a1..721092a813a5 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -667,7 +667,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { // related to https://github.com/apache/arrow-datafusion/issues/8814 // If the expr contain volatile expression or is a short-circuit expression, skip it. if expr.short_circuits() || is_volatile_expression(expr)? { - return Ok(TreeNodeRecursion::Skip); + return Ok(TreeNodeRecursion::Jump); } self.visit_stack .push(VisitRecord::EnterMark(self.node_count)); @@ -750,12 +750,12 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { // the `id_array`, which records the expr's identifier used to rewrite expr. So if we // skip an expr in `ExprIdentifierVisitor`, we should skip it here, too. if expr.short_circuits() || is_volatile_expression(&expr)? 
{ - return Ok(Transformed::new(expr, false, TreeNodeRecursion::Skip)); + return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)); } if self.curr_index >= self.id_array.len() || self.max_series_number > self.id_array[self.curr_index].0 { - return Ok(Transformed::new(expr, false, TreeNodeRecursion::Skip)); + return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)); } let curr_id = &self.id_array[self.curr_index].1; @@ -774,7 +774,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { return Ok(Transformed::new( expr, false, - TreeNodeRecursion::Skip, + TreeNodeRecursion::Jump, )); } @@ -791,7 +791,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { return Ok(Transformed::new( expr, false, - TreeNodeRecursion::Skip, + TreeNodeRecursion::Jump, )); } @@ -810,7 +810,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { Ok(Transformed::new( col(id).alias(expr_name), true, - TreeNodeRecursion::Skip, + TreeNodeRecursion::Jump, )) } else { self.curr_index += 1; diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index b7119966c41c..1aad73d7b785 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -66,7 +66,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { if plan_hold_outer { // the unsupported case self.can_pull_up = false; - Ok(Transformed::new(plan, false, TreeNodeRecursion::Skip)) + Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump)) } else { Ok(Transformed::no(plan)) } @@ -77,7 +77,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { (false, true) => { // the unsupported case self.can_pull_up = false; - Ok(Transformed::new(plan, false, TreeNodeRecursion::Skip)) + Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump)) } _ => Ok(Transformed::no(plan)), } @@ -85,7 +85,7 @@ impl TreeNodeRewriter for PullUpCorrelatedExpr { _ if plan.expressions().iter().any(|expr| expr.contains_outer()) => { // the unsupported cases, the plan 
expressions contain out reference columns(like window expressions) self.can_pull_up = false; - Ok(Transformed::new(plan, false, TreeNodeRecursion::Skip)) + Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump)) } _ => Ok(Transformed::no(plan)), } diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 0e05793c5219..6e8de99d8105 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -222,7 +222,7 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result { Expr::Column(_) | Expr::Literal(_) | Expr::Placeholder(_) - | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Skip), + | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump), Expr::Exists { .. } | Expr::InSubquery(_) | Expr::ScalarSubquery(_) diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 0ac053dacd29..caf3b9ba53a3 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -223,7 +223,7 @@ impl TreeNodeRewriter for ExtractScalarSubQuery { subqry_alias, )?), true, - TreeNodeRecursion::Skip, + TreeNodeRecursion::Jump, )) } _ => Ok(Transformed::no(expr)), From 6edd05f91654898df92d7620cefd0fc4bf63346c Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Sat, 10 Feb 2024 14:38:01 +0100 Subject: [PATCH 11/40] fix Jump and add tests --- datafusion/common/src/tree_node.rs | 254 ++++++++++++++++++++++++++--- 1 file changed, 233 insertions(+), 21 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 428cba5a8ab8..e3d5c4d2ccfe 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -135,10 +135,13 @@ pub trait TreeNode: Sized { FD: FnMut(Self) -> Result>, FU: FnMut(Self) -> Result>, { - f_down(self)?.and_then_transform_on_continue(|n| { - n.map_children(|c| 
c.transform(f_down, f_up))? - .and_then_transform(f_up) - }) + f_down(self)?.and_then_transform_on_continue( + |n| { + n.map_children(|c| c.transform(f_down, f_up))? + .and_then_transform_on_continue(f_up, TreeNodeRecursion::Jump) + }, + TreeNodeRecursion::Continue, + ) } /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its @@ -148,8 +151,10 @@ pub trait TreeNode: Sized { where F: Fn(Self) -> Result>, { - f(self)? - .and_then_transform_on_continue(|n| n.map_children(|c| c.transform_down(f))) + f(self)?.and_then_transform_on_continue( + |n| n.map_children(|c| c.transform_down(f)), + TreeNodeRecursion::Continue, + ) } /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its @@ -159,9 +164,10 @@ pub trait TreeNode: Sized { where F: FnMut(Self) -> Result>, { - f(self)?.and_then_transform_on_continue(|n| { - n.map_children(|c| c.transform_down_mut(f)) - }) + f(self)?.and_then_transform_on_continue( + |n| n.map_children(|c| c.transform_down_mut(f)), + TreeNodeRecursion::Continue, + ) } /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its @@ -172,7 +178,7 @@ pub trait TreeNode: Sized { F: Fn(Self) -> Result>, { self.map_children(|c| c.transform_up(f))? - .and_then_transform(f) + .and_then_transform_on_continue(f, TreeNodeRecursion::Jump) } /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its @@ -183,7 +189,7 @@ pub trait TreeNode: Sized { F: FnMut(Self) -> Result>, { self.map_children(|c| c.transform_up_mut(f))? 
- .and_then_transform(f) + .and_then_transform_on_continue(f, TreeNodeRecursion::Jump) } /// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for @@ -214,10 +220,16 @@ pub trait TreeNode: Sized { self, rewriter: &mut R, ) -> Result> { - rewriter.f_down(self)?.and_then_transform_on_continue(|n| { - n.map_children(|c| c.rewrite(rewriter))? - .and_then_transform(|n| rewriter.f_up(n)) - }) + rewriter.f_down(self)?.and_then_transform_on_continue( + |n| { + n.map_children(|c| c.rewrite(rewriter))? + .and_then_transform_on_continue( + |n| rewriter.f_up(n), + TreeNodeRecursion::Jump, + ) + }, + TreeNodeRecursion::Continue, + ) } /// Apply the closure `F` to the node's children @@ -299,7 +311,7 @@ pub enum TreeNodeRecursion { /// In top-down traversals skip recursing into children but continue with the next /// node, which actually means pruning of the subtree. /// In bottom-up traversals bypass calling bottom-up closures till the next leaf node. - /// In combined traversals bypass calling bottom-up closures till the first top-down + /// In combined traversals bypass calling bottom-up closures till the next top-down /// closure. 
Jump, @@ -360,14 +372,14 @@ impl Transformed { fn and_then Result>>( self, f: F, - return_continue_on_jump: bool, + return_on_jump: Option, ) -> Result> { match self.tnr { TreeNodeRecursion::Continue => {} TreeNodeRecursion::Jump => { - if return_continue_on_jump { + if return_on_jump.is_some() { return Ok(Transformed { - tnr: TreeNodeRecursion::Continue, + tnr: return_on_jump.unwrap(), ..self }); } @@ -385,14 +397,15 @@ impl Transformed { self, f: F, ) -> Result> { - self.and_then(f, false) + self.and_then(f, None) } pub fn and_then_transform_on_continue Result>>( self, f: F, + return_on_jump: TreeNodeRecursion, ) -> Result> { - self.and_then(f, true) + self.and_then(f, Some(return_on_jump)) } } @@ -543,3 +556,202 @@ impl TreeNode for T { } } } + +#[cfg(test)] +mod tests { + use crate::tree_node::{ + Transformed, TransformedIterator, TreeNode, TreeNodeRecursion, TreeNodeRewriter, + }; + use crate::Result; + + struct TestTreeNode { + children: Vec>, + data: T, + } + + impl TestTreeNode { + fn new(children: Vec>, data: T) -> Self { + Self { children, data } + } + } + + impl TreeNode for TestTreeNode { + fn apply_children(&self, op: &mut F) -> Result + where + F: FnMut(&Self) -> Result, + { + for child in &self.children { + handle_tree_recursion!(op(&child)?); + } + Ok(TreeNodeRecursion::Continue) + } + + fn map_children(self, f: F) -> Result> + where + F: FnMut(Self) -> Result>, + { + Ok(self + .children + .into_iter() + .map_till_continue_and_collect(f)? 
+ .map_data(|new_children| Self { + children: new_children, + ..self + })) + } + } + + fn new_test_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "a".to_string()); + let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "e".to_string()); + let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "i".to_string()); + TestTreeNode::new(vec![node_i], "j".to_string()) + } + + #[test] + fn test_rewrite() -> Result<()> { + let tree = new_test_tree(); + + struct TestRewriter { + pub visits: Vec, + } + + impl TestRewriter { + fn new() -> Self { + Self { visits: vec![] } + } + } + + impl TreeNodeRewriter for TestRewriter { + type Node = TestTreeNode; + + fn f_down(&mut self, node: Self::Node) -> Result> { + self.visits.push(format!("f_down {}", node.data)); + Ok(Transformed::no(node)) + } + + fn f_up(&mut self, node: Self::Node) -> Result> { + self.visits.push(format!("f_up {}", node.data)); + Ok(Transformed::no(node)) + } + } + + let mut rewriter = TestRewriter::new(); + tree.rewrite(&mut rewriter)?; + assert_eq!( + rewriter.visits, + vec![ + "f_down j", "f_down i", "f_down f", "f_down e", "f_down c", "f_down b", + "f_up b", "f_down d", "f_down a", "f_up a", "f_up d", "f_up c", "f_up e", + "f_down g", "f_down h", "f_up h", "f_up g", "f_up f", "f_up i", "f_up j" + ] + ); + + Ok(()) + } + + #[test] + fn test_f_down_jump() -> Result<()> { + let tree = new_test_tree(); + + struct FDownJumpRewriter { + pub visits: Vec, + jump_on: String, + } + + impl FDownJumpRewriter { + fn new(jump_on: String) -> Self { + Self { + visits: vec![], + jump_on, + } + } + } + + impl TreeNodeRewriter for 
FDownJumpRewriter { + type Node = TestTreeNode; + + fn f_down(&mut self, node: Self::Node) -> Result> { + self.visits.push(format!("f_down {}", node.data)); + Ok(if node.data == self.jump_on { + Transformed::new(node, false, TreeNodeRecursion::Jump) + } else { + Transformed::no(node) + }) + } + + fn f_up(&mut self, node: Self::Node) -> Result> { + self.visits.push(format!("f_up {}", node.data)); + Ok(Transformed::no(node)) + } + } + + let mut rewriter = FDownJumpRewriter::new("e".to_string()); + tree.rewrite(&mut rewriter)?; + assert_eq!( + rewriter.visits, + vec![ + "f_down j", "f_down i", "f_down f", "f_down e", "f_down g", "f_down h", + "f_up h", "f_up g", "f_up f", "f_up i", "f_up j" + ] + ); + + Ok(()) + } + + #[test] + fn test_f_up_jump() -> Result<()> { + let tree = new_test_tree(); + + struct FUpJumpRewriter { + pub visits: Vec, + jump_on: String, + } + + impl FUpJumpRewriter { + fn new(jump_on: String) -> Self { + Self { + visits: vec![], + jump_on, + } + } + } + + impl TreeNodeRewriter for FUpJumpRewriter { + type Node = TestTreeNode; + + fn f_down(&mut self, node: Self::Node) -> Result> { + self.visits.push(format!("f_down {}", node.data)); + Ok(Transformed::no(node)) + } + + fn f_up(&mut self, node: Self::Node) -> Result> { + self.visits.push(format!("f_up {}", node.data)); + Ok(if node.data == self.jump_on { + Transformed::new(node, false, TreeNodeRecursion::Jump) + } else { + Transformed::no(node) + }) + } + } + + let mut rewriter = FUpJumpRewriter::new("a".to_string()); + tree.rewrite(&mut rewriter)?; + assert_eq!( + rewriter.visits, + vec![ + "f_down j", "f_down i", "f_down f", "f_down e", "f_down c", "f_down b", + "f_up b", "f_down d", "f_down a", "f_up a", "f_down g", "f_down h", + "f_up h", "f_up g", "f_up f", "f_up i", "f_up j" + ] + ); + + Ok(()) + } +} From 623e5fa1d8c901902fa2182a8093a3424107ff2b Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Sat, 10 Feb 2024 14:38:01 +0100 Subject: [PATCH 12/40] jump test fixes --- 
datafusion/common/src/tree_node.rs | 96 +++++++++++++++++++++--------- 1 file changed, 68 insertions(+), 28 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index e3d5c4d2ccfe..e14fe3ff2af6 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -601,17 +601,17 @@ mod tests { } } - fn new_test_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "a".to_string()); - let node_b = TestTreeNode::new(vec![], "b".to_string()); - let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); - let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); - let node_e = TestTreeNode::new(vec![node_c], "e".to_string()); - let node_h = TestTreeNode::new(vec![], "h".to_string()); - let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); - let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string()); - let node_i = TestTreeNode::new(vec![node_f], "i".to_string()); - TestTreeNode::new(vec![node_i], "j".to_string()) + fn new_test_tree<'a>() -> TestTreeNode<&'a str> { + let node_a = TestTreeNode::new(vec![], "a"); + let node_b = TestTreeNode::new(vec![], "b"); + let node_d = TestTreeNode::new(vec![node_a], "d"); + let node_c = TestTreeNode::new(vec![node_b, node_d], "c"); + let node_e = TestTreeNode::new(vec![node_c], "e"); + let node_h = TestTreeNode::new(vec![], "h"); + let node_g = TestTreeNode::new(vec![node_h], "g"); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f"); + let node_i = TestTreeNode::new(vec![node_f], "i"); + TestTreeNode::new(vec![node_i], "j") } #[test] @@ -629,15 +629,15 @@ mod tests { } impl TreeNodeRewriter for TestRewriter { - type Node = TestTreeNode; + type Node = TestTreeNode<&'static str>; fn f_down(&mut self, node: Self::Node) -> Result> { - self.visits.push(format!("f_down {}", node.data)); + self.visits.push(format!("f_down({})", node.data)); Ok(Transformed::no(node)) } fn f_up(&mut self, node: Self::Node) -> Result> { 
- self.visits.push(format!("f_up {}", node.data)); + self.visits.push(format!("f_up({})", node.data)); Ok(Transformed::no(node)) } } @@ -647,9 +647,26 @@ mod tests { assert_eq!( rewriter.visits, vec![ - "f_down j", "f_down i", "f_down f", "f_down e", "f_down c", "f_down b", - "f_up b", "f_down d", "f_down a", "f_up a", "f_up d", "f_up c", "f_up e", - "f_down g", "f_down h", "f_up h", "f_up g", "f_up f", "f_up i", "f_up j" + "f_down(j)", + "f_down(i)", + "f_down(f)", + "f_down(e)", + "f_down(c)", + "f_down(b)", + "f_up(b)", + "f_down(d)", + "f_down(a)", + "f_up(a)", + "f_up(d)", + "f_up(c)", + "f_up(e)", + "f_down(g)", + "f_down(h)", + "f_up(h)", + "f_up(g)", + "f_up(f)", + "f_up(i)", + "f_up(j)" ] ); @@ -675,10 +692,10 @@ mod tests { } impl TreeNodeRewriter for FDownJumpRewriter { - type Node = TestTreeNode; + type Node = TestTreeNode<&'static str>; fn f_down(&mut self, node: Self::Node) -> Result> { - self.visits.push(format!("f_down {}", node.data)); + self.visits.push(format!("f_down({})", node.data)); Ok(if node.data == self.jump_on { Transformed::new(node, false, TreeNodeRecursion::Jump) } else { @@ -687,7 +704,7 @@ mod tests { } fn f_up(&mut self, node: Self::Node) -> Result> { - self.visits.push(format!("f_up {}", node.data)); + self.visits.push(format!("f_up({})", node.data)); Ok(Transformed::no(node)) } } @@ -697,8 +714,17 @@ mod tests { assert_eq!( rewriter.visits, vec![ - "f_down j", "f_down i", "f_down f", "f_down e", "f_down g", "f_down h", - "f_up h", "f_up g", "f_up f", "f_up i", "f_up j" + "f_down(j)", + "f_down(i)", + "f_down(f)", + "f_down(e)", + "f_down(g)", + "f_down(h)", + "f_up(h)", + "f_up(g)", + "f_up(f)", + "f_up(i)", + "f_up(j)" ] ); @@ -724,15 +750,15 @@ mod tests { } impl TreeNodeRewriter for FUpJumpRewriter { - type Node = TestTreeNode; + type Node = TestTreeNode<&'static str>; fn f_down(&mut self, node: Self::Node) -> Result> { - self.visits.push(format!("f_down {}", node.data)); + self.visits.push(format!("f_down({})", node.data)); 
Ok(Transformed::no(node)) } fn f_up(&mut self, node: Self::Node) -> Result> { - self.visits.push(format!("f_up {}", node.data)); + self.visits.push(format!("f_up({})", node.data)); Ok(if node.data == self.jump_on { Transformed::new(node, false, TreeNodeRecursion::Jump) } else { @@ -746,9 +772,23 @@ mod tests { assert_eq!( rewriter.visits, vec![ - "f_down j", "f_down i", "f_down f", "f_down e", "f_down c", "f_down b", - "f_up b", "f_down d", "f_down a", "f_up a", "f_down g", "f_down h", - "f_up h", "f_up g", "f_up f", "f_up i", "f_up j" + "f_down(j)", + "f_down(i)", + "f_down(f)", + "f_down(e)", + "f_down(c)", + "f_down(b)", + "f_up(b)", + "f_down(d)", + "f_down(a)", + "f_up(a)", + "f_down(g)", + "f_down(h)", + "f_up(h)", + "f_up(g)", + "f_up(f)", + "f_up(i)", + "f_up(j)" ] ); From 492777c8e569de80c9aaf7a91c4bab4f2aeb1ed6 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Sat, 10 Feb 2024 15:04:30 +0100 Subject: [PATCH 13/40] fix clippy --- datafusion/common/src/tree_node.rs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index e14fe3ff2af6..ac5876bf74ce 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -377,11 +377,8 @@ impl Transformed { match self.tnr { TreeNodeRecursion::Continue => {} TreeNodeRecursion::Jump => { - if return_on_jump.is_some() { - return Ok(Transformed { - tnr: return_on_jump.unwrap(), - ..self - }); + if let Some(tnr) = return_on_jump { + return Ok(Transformed { tnr, ..self }); } } TreeNodeRecursion::Stop => return Ok(self), @@ -581,7 +578,7 @@ mod tests { F: FnMut(&Self) -> Result, { for child in &self.children { - handle_tree_recursion!(op(&child)?); + handle_tree_recursion!(op(child)?); } Ok(TreeNodeRecursion::Continue) } From 384c3e04d9fb320a969b2d439ac610ac3e3c1b1c Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Sun, 11 Feb 2024 12:25:07 +0100 Subject: [PATCH 14/40] unify "transform" traversals using 
macros, fix "visit" traversal jumps, add visit jump tests, ensure consistent naming `f` instead of `op`, `f_down` instead of `pre_visit` and `f_up` instead of `post_visit` --- datafusion/common/src/tree_node.rs | 486 ++++++++++++------ datafusion/core/src/execution/context/mod.rs | 4 +- datafusion/expr/src/logical_plan/display.rs | 12 +- datafusion/expr/src/logical_plan/plan.rs | 28 +- datafusion/expr/src/tree_node/expr.rs | 12 +- datafusion/expr/src/tree_node/plan.rs | 26 +- .../optimizer/src/common_subexpr_eliminate.rs | 6 +- 7 files changed, 389 insertions(+), 185 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index ac5876bf74ce..0c2813bb8489 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -22,25 +22,114 @@ use std::sync::Arc; use crate::Result; +/// This macro is used to determine continuation after a top-down closure is applied +/// during "visit" traversals. +/// /// If the function returns [`TreeNodeRecursion::Continue`], the normal execution of the -/// function continues. If it returns [`TreeNodeRecursion::Jump`], the function returns -/// with [`TreeNodeRecursion::Continue`] to jump next recursion step, bypassing further -/// exploration of the current step. In case of [`TreeNodeRecursion::Stop`], the function -/// return with [`TreeNodeRecursion::Stop`] and recursion halts. +/// function continues. +/// If it returns [`TreeNodeRecursion::Jump`], the function returns with (propagates up) +/// [`TreeNodeRecursion::Continue`] to jump next recursion step, bypassing further +/// exploration of the current step. +/// In case of [`TreeNodeRecursion::Stop`], the function return with (propagates up) +/// [`TreeNodeRecursion::Stop`] and recursion halts. #[macro_export] -macro_rules! handle_tree_recursion { +macro_rules! 
handle_visit_recursion_down { ($EXPR:expr) => { match $EXPR { TreeNodeRecursion::Continue => {} - // If the recursion should skip, do not apply to its children, let - // the recursion continue: TreeNodeRecursion::Jump => return Ok(TreeNodeRecursion::Continue), - // If the recursion should stop, do not apply to its children: TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop), } }; } +/// This macro is used to determine continuation between visiting siblings during "visit" +/// traversals. +/// +/// If the function returns [`TreeNodeRecursion::Continue`] or +/// [`TreeNodeRecursion::Jump`], the normal execution of the function continues. +/// In case of [`TreeNodeRecursion::Stop`], the function return with (propagates up) +/// [`TreeNodeRecursion::Stop`] and recursion halts. +macro_rules! handle_visit_recursion { + ($EXPR:expr) => { + match $EXPR { + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {} + TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop), + } + }; +} + +/// This macro is used to determine continuation before a bottom-up closure is applied +/// during "visit" traversals. +/// +/// If the function returns [`TreeNodeRecursion::Continue`], the normal execution of the +/// function continues. +/// If it returns [`TreeNodeRecursion::Jump`], the function returns with (propagates up) +/// [`TreeNodeRecursion::Jump`], bypassing further bottom-up closures until a top-down +/// closure is found. +/// In case of [`TreeNodeRecursion::Stop`], the function return with (propagates up) +/// [`TreeNodeRecursion::Stop`] and recursion halts. +#[macro_export] +macro_rules! handle_visit_recursion_up { + ($EXPR:expr) => { + match $EXPR { + TreeNodeRecursion::Continue => {} + TreeNodeRecursion::Jump => return Ok(TreeNodeRecursion::Jump), + TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop), + } + }; +} + +/// This macro is used to determine continuation during top-down "transform" traversals. 
+/// +/// After the bottom-up closure returns with [`Transformed`] depending on the returned +/// [`TreeNodeRecursion`], [`Transformed::and_then()`] decides about recursion +/// continuation and [`TreeNodeRecursion`] state propagation. +#[macro_export] +macro_rules! handle_transform_recursion_down { + ($F_DOWN:expr, $F_SELF:expr) => { + $F_DOWN?.and_then( + |n| n.map_children($F_SELF), + Some(TreeNodeRecursion::Continue), + ) + }; +} + +/// This macro is used to determine continuation during combined "transform" traversals. +/// +/// After the bottom-up closure returns with [`Transformed`] depending on the returned +/// [`TreeNodeRecursion`], [`Transformed::and_then()`] decides about recursion +/// continuation and if [`TreeNodeRecursion`] state propagation is needed. +/// And then after recursing into children returns with [`Transformed`] depending on the +/// returned [`TreeNodeRecursion`], [`Transformed::and_then()`] decides about recursion +/// continuation and [`TreeNodeRecursion`] state propagation. +#[macro_export] +macro_rules! handle_transform_recursion { + ($F_DOWN:expr, $SELF:expr, $F_UP:expr) => { + $F_DOWN?.and_then( + |n| { + n.map_children($SELF)? + .and_then($F_UP, Some(TreeNodeRecursion::Jump)) + }, + Some(TreeNodeRecursion::Continue), + ) + }; +} + +/// This macro is used to determine continuation during bottom-up traversal. +/// +/// After recursing into children returns with [`Transformed`] depending on the returned +/// [`TreeNodeRecursion`], [`Transformed::and_then()`] decides about recursion +/// continuation and [`TreeNodeRecursion`] state propagation. +#[macro_export] +macro_rules! handle_transform_recursion_up { + ($NODE:expr, $F_SELF:expr, $F_UP:expr) => { + $NODE + .map_children($F_SELF)? + .and_then($F_UP, Some(TreeNodeRecursion::Jump)) + }; +} + /// Defines a visitable and rewriteable a tree node. 
This trait is /// implemented for plans ([`ExecutionPlan`] and [`LogicalPlan`]) as /// well as expression trees ([`PhysicalExpr`], [`Expr`]) in @@ -52,18 +141,18 @@ macro_rules! handle_tree_recursion { /// [`LogicalPlan`]: https://docs.rs/datafusion-expr/latest/datafusion_expr/logical_plan/enum.LogicalPlan.html /// [`Expr`]: https://docs.rs/datafusion-expr/latest/datafusion_expr/expr/enum.Expr.html pub trait TreeNode: Sized { - /// Applies `op` to the node and its children. `op` is applied in a preoder way, - /// and it is controlled by [`TreeNodeRecursion`], which means result of the `op` + /// Applies `f` to the node and its children. `f` is applied in a preoder way, + /// and it is controlled by [`TreeNodeRecursion`], which means result of the `f` /// on the self node can cause an early return. /// - /// The `op` closure can be used to collect some info from the + /// The `f` closure can be used to collect some info from the /// tree node or do some checking for the tree node. fn apply Result>( &self, - op: &mut F, + f: &mut F, ) -> Result { - handle_tree_recursion!(op(self)?); - self.apply_children(&mut |node| node.apply(op)) + handle_visit_recursion_down!(f(self)?); + self.apply_children(&mut |n| n.apply(f)) } /// Visit the tree node using the given [TreeNodeVisitor] @@ -92,15 +181,15 @@ pub trait TreeNode: Sized { /// children of that node will be visited, nor is post_visit /// called on that node. Details see [`TreeNodeVisitor`] /// - /// If using the default [`TreeNodeVisitor::post_visit`] that does + /// If using the default [`TreeNodeVisitor::f_up`] that does /// nothing, [`Self::apply`] should be preferred. 
- fn visit>( + fn visit>( &self, visitor: &mut V, ) -> Result { - handle_tree_recursion!(visitor.pre_visit(self)?); - handle_tree_recursion!(self.apply_children(&mut |node| node.visit(visitor))?); - visitor.post_visit(self) + handle_visit_recursion_down!(visitor.f_down(self)?); + handle_visit_recursion_up!(self.apply_children(&mut |n| n.visit(visitor))?); + visitor.f_up(self) } /// Transforms the tree using `f_down` while traversing the tree top-down @@ -135,61 +224,49 @@ pub trait TreeNode: Sized { FD: FnMut(Self) -> Result>, FU: FnMut(Self) -> Result>, { - f_down(self)?.and_then_transform_on_continue( - |n| { - n.map_children(|c| c.transform(f_down, f_up))? - .and_then_transform_on_continue(f_up, TreeNodeRecursion::Jump) - }, - TreeNodeRecursion::Continue, - ) + handle_transform_recursion!(f_down(self), |c| c.transform(f_down, f_up), f_up) } - /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its + /// Convenience utils for writing optimizers rule: recursively apply the given 'f' to the node and all of its /// children(Preorder Traversal). - /// When the `op` does not apply to a given node, it is left unchanged. + /// When the `f` does not apply to a given node, it is left unchanged. fn transform_down(self, f: &F) -> Result> where F: Fn(Self) -> Result>, { - f(self)?.and_then_transform_on_continue( - |n| n.map_children(|c| c.transform_down(f)), - TreeNodeRecursion::Continue, - ) + handle_transform_recursion_down!(f(self), |n| n + .map_children(|c| c.transform_down(f))) } - /// Convenience utils for writing optimizers rule: recursively apply the given 'op' to the node and all of its + /// Convenience utils for writing optimizers rule: recursively apply the given 'f' to the node and all of its /// children(Preorder Traversal) using a mutable function, `F`. - /// When the `op` does not apply to a given node, it is left unchanged. + /// When the `f` does not apply to a given node, it is left unchanged. 
fn transform_down_mut(self, f: &mut F) -> Result> where F: FnMut(Self) -> Result>, { - f(self)?.and_then_transform_on_continue( - |n| n.map_children(|c| c.transform_down_mut(f)), - TreeNodeRecursion::Continue, - ) + handle_transform_recursion_down!(f(self), |n| n + .map_children(|c| c.transform_down_mut(f))) } - /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its + /// Convenience utils for writing optimizers rule: recursively apply the given 'f' first to all of its /// children and then itself(Postorder Traversal). - /// When the `op` does not apply to a given node, it is left unchanged. + /// When the `f` does not apply to a given node, it is left unchanged. fn transform_up(self, f: &F) -> Result> where F: Fn(Self) -> Result>, { - self.map_children(|c| c.transform_up(f))? - .and_then_transform_on_continue(f, TreeNodeRecursion::Jump) + handle_transform_recursion_up!(self, |c| c.transform_up(f), f) } - /// Convenience utils for writing optimizers rule: recursively apply the given 'op' first to all of its + /// Convenience utils for writing optimizers rule: recursively apply the given 'f' first to all of its /// children and then itself(Postorder Traversal) using a mutable function, `F`. - /// When the `op` does not apply to a given node, it is left unchanged. + /// When the `f` does not apply to a given node, it is left unchanged. fn transform_up_mut(self, f: &mut F) -> Result> where F: FnMut(Self) -> Result>, { - self.map_children(|c| c.transform_up_mut(f))? - .and_then_transform_on_continue(f, TreeNodeRecursion::Jump) + handle_transform_recursion_up!(self, |c| c.transform_up_mut(f), f) } /// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for @@ -220,20 +297,13 @@ pub trait TreeNode: Sized { self, rewriter: &mut R, ) -> Result> { - rewriter.f_down(self)?.and_then_transform_on_continue( - |n| { - n.map_children(|c| c.rewrite(rewriter))? 
- .and_then_transform_on_continue( - |n| rewriter.f_up(n), - TreeNodeRecursion::Jump, - ) - }, - TreeNodeRecursion::Continue, - ) + handle_transform_recursion!(rewriter.f_down(self), |c| c.rewrite(rewriter), |n| { + rewriter.f_up(n) + }) } /// Apply the closure `F` to the node's children - fn apply_children(&self, op: &mut F) -> Result + fn apply_children(&self, f: &mut F) -> Result where F: FnMut(&Self) -> Result; @@ -251,8 +321,8 @@ pub trait TreeNode: Sized { /// tree and makes it easier to add new types of tree node and /// algorithms. /// -/// When passed to[`TreeNode::visit`], [`TreeNodeVisitor::pre_visit`] -/// and [`TreeNodeVisitor::post_visit`] are invoked recursively +/// When passed to[`TreeNode::visit`], [`TreeNodeVisitor::f_down`] +/// and [`TreeNodeVisitor::f_up`] are invoked recursively /// on an node tree. /// /// If an [`Err`] result is returned, recursion is stopped @@ -270,14 +340,14 @@ pub trait TreeNode: Sized { /// children of that tree node are visited. pub trait TreeNodeVisitor: Sized { /// The node type which is visitable. - type N: TreeNode; + type Node: TreeNode; /// Invoked before any children of `node` are visited. - fn pre_visit(&mut self, node: &Self::N) -> Result; + fn f_down(&mut self, node: &Self::Node) -> Result; /// Invoked after all children of `node` are visited. Default /// implementation does nothing. - fn post_visit(&mut self, _node: &Self::N) -> Result { + fn f_up(&mut self, _node: &Self::Node) -> Result { Ok(TreeNodeRecursion::Continue) } } @@ -369,6 +439,11 @@ impl Transformed { }) } + /// This is an important function to decide about recursion continuation and + /// [`TreeNodeRecursion`] state propagation. Handling [`TreeNodeRecursion::Continue`] + /// and [`TreeNodeRecursion::Stop`] is always straightforward, but + /// [`TreeNodeRecursion::Jump`] can behave differently when we are traversing down or + /// up on a tree. 
fn and_then Result>>( self, f: F, @@ -396,14 +471,6 @@ impl Transformed { ) -> Result> { self.and_then(f, None) } - - pub fn and_then_transform_on_continue Result>>( - self, - f: F, - return_on_jump: TreeNodeRecursion, - ) -> Result> { - self.and_then(f, Some(return_on_jump)) - } } pub trait TransformedIterator: Iterator { @@ -471,14 +538,16 @@ pub trait DynTreeNode { /// [`DynTreeNode`] (such as [`Arc`]) impl TreeNode for Arc { /// Apply the closure `F` to the node's children - fn apply_children(&self, op: &mut F) -> Result + fn apply_children(&self, f: &mut F) -> Result where F: FnMut(&Self) -> Result, { + let mut tnr = TreeNodeRecursion::Continue; for child in self.arc_children() { - handle_tree_recursion!(op(&child)?) + tnr = f(&child)?; + handle_visit_recursion!(tnr) } - Ok(TreeNodeRecursion::Continue) + Ok(tnr) } fn map_children(self, f: F) -> Result> @@ -521,14 +590,16 @@ pub trait ConcreteTreeNode: Sized { impl TreeNode for T { /// Apply the closure `F` to the node's children - fn apply_children(&self, op: &mut F) -> Result + fn apply_children(&self, f: &mut F) -> Result where F: FnMut(&Self) -> Result, { + let mut tnr = TreeNodeRecursion::Continue; for child in self.children() { - handle_tree_recursion!(op(child)?) 
+ tnr = f(child)?; + handle_visit_recursion!(tnr) } - Ok(TreeNodeRecursion::Continue) + Ok(tnr) } fn map_children(self, f: F) -> Result> @@ -558,6 +629,7 @@ impl TreeNode for T { mod tests { use crate::tree_node::{ Transformed, TransformedIterator, TreeNode, TreeNodeRecursion, TreeNodeRewriter, + TreeNodeVisitor, }; use crate::Result; @@ -573,14 +645,16 @@ mod tests { } impl TreeNode for TestTreeNode { - fn apply_children(&self, op: &mut F) -> Result + fn apply_children(&self, f: &mut F) -> Result where F: FnMut(&Self) -> Result, { + let mut tnr = TreeNodeRecursion::Continue; for child in &self.children { - handle_tree_recursion!(op(child)?); + tnr = f(child)?; + handle_visit_recursion!(tnr); } - Ok(TreeNodeRecursion::Continue) + Ok(tnr) } fn map_children(self, f: F) -> Result> @@ -611,6 +685,190 @@ mod tests { TestTreeNode::new(vec![node_i], "j") } + fn all_visits<'a>() -> Vec<&'a str> { + vec![ + "f_down(j)", + "f_down(i)", + "f_down(f)", + "f_down(e)", + "f_down(c)", + "f_down(b)", + "f_up(b)", + "f_down(d)", + "f_down(a)", + "f_up(a)", + "f_up(d)", + "f_up(c)", + "f_up(e)", + "f_down(g)", + "f_down(h)", + "f_up(h)", + "f_up(g)", + "f_up(f)", + "f_up(i)", + "f_up(j)", + ] + } + + fn f_down_jump_visits<'a>() -> Vec<&'a str> { + vec![ + "f_down(j)", + "f_down(i)", + "f_down(f)", + "f_down(e)", + "f_down(g)", + "f_down(h)", + "f_up(h)", + "f_up(g)", + "f_up(f)", + "f_up(i)", + "f_up(j)", + ] + } + + fn f_up_jump_visits<'a>() -> Vec<&'a str> { + vec![ + "f_down(j)", + "f_down(i)", + "f_down(f)", + "f_down(e)", + "f_down(c)", + "f_down(b)", + "f_up(b)", + "f_down(d)", + "f_down(a)", + "f_up(a)", + "f_down(g)", + "f_down(h)", + "f_up(h)", + "f_up(g)", + "f_up(f)", + "f_up(i)", + "f_up(j)", + ] + } + + #[test] + fn test_visit() -> Result<()> { + let tree = new_test_tree(); + + struct TestVisitor { + pub visits: Vec, + } + + impl TestVisitor { + fn new() -> Self { + Self { visits: vec![] } + } + } + + impl TreeNodeVisitor for TestVisitor { + type Node = 
TestTreeNode<&'static str>; + + fn f_down(&mut self, node: &Self::Node) -> Result { + self.visits.push(format!("f_down({})", node.data)); + Ok(TreeNodeRecursion::Continue) + } + + fn f_up(&mut self, node: &Self::Node) -> Result { + self.visits.push(format!("f_up({})", node.data)); + Ok(TreeNodeRecursion::Continue) + } + } + + let mut visitor = TestVisitor::new(); + tree.visit(&mut visitor)?; + assert_eq!(visitor.visits, all_visits()); + + Ok(()) + } + + #[test] + fn test_visit_f_down_jump() -> Result<()> { + let tree = new_test_tree(); + + struct FDownJumpVisitor { + pub visits: Vec, + jump_on: String, + } + + impl FDownJumpVisitor { + fn new(jump_on: String) -> Self { + Self { + visits: vec![], + jump_on, + } + } + } + + impl TreeNodeVisitor for FDownJumpVisitor { + type Node = TestTreeNode<&'static str>; + + fn f_down(&mut self, node: &Self::Node) -> Result { + self.visits.push(format!("f_down({})", node.data)); + Ok(if node.data == self.jump_on { + TreeNodeRecursion::Jump + } else { + TreeNodeRecursion::Continue + }) + } + + fn f_up(&mut self, node: &Self::Node) -> Result { + self.visits.push(format!("f_up({})", node.data)); + Ok(TreeNodeRecursion::Continue) + } + } + + let mut visitor = FDownJumpVisitor::new("e".to_string()); + tree.visit(&mut visitor)?; + assert_eq!(visitor.visits, f_down_jump_visits()); + + Ok(()) + } + + #[test] + fn test_visit_f_up_jump() -> Result<()> { + let tree = new_test_tree(); + + struct FUpJumpVisitor { + pub visits: Vec, + jump_on: String, + } + + impl FUpJumpVisitor { + fn new(jump_on: String) -> Self { + Self { + visits: vec![], + jump_on, + } + } + } + + impl TreeNodeVisitor for FUpJumpVisitor { + type Node = TestTreeNode<&'static str>; + + fn f_down(&mut self, node: &Self::Node) -> Result { + self.visits.push(format!("f_down({})", node.data)); + Ok(TreeNodeRecursion::Continue) + } + + fn f_up(&mut self, node: &Self::Node) -> Result { + self.visits.push(format!("f_up({})", node.data)); + Ok(if node.data == self.jump_on { + 
TreeNodeRecursion::Jump + } else { + TreeNodeRecursion::Continue + }) + } + } + + let mut visitor = FUpJumpVisitor::new("a".to_string()); + tree.visit(&mut visitor)?; + assert_eq!(visitor.visits, f_up_jump_visits()); + + Ok(()) + } + #[test] fn test_rewrite() -> Result<()> { let tree = new_test_tree(); @@ -641,37 +899,13 @@ mod tests { let mut rewriter = TestRewriter::new(); tree.rewrite(&mut rewriter)?; - assert_eq!( - rewriter.visits, - vec![ - "f_down(j)", - "f_down(i)", - "f_down(f)", - "f_down(e)", - "f_down(c)", - "f_down(b)", - "f_up(b)", - "f_down(d)", - "f_down(a)", - "f_up(a)", - "f_up(d)", - "f_up(c)", - "f_up(e)", - "f_down(g)", - "f_down(h)", - "f_up(h)", - "f_up(g)", - "f_up(f)", - "f_up(i)", - "f_up(j)" - ] - ); + assert_eq!(rewriter.visits, all_visits()); Ok(()) } #[test] - fn test_f_down_jump() -> Result<()> { + fn test_rewrite_f_down_jump() -> Result<()> { let tree = new_test_tree(); struct FDownJumpRewriter { @@ -708,28 +942,13 @@ mod tests { let mut rewriter = FDownJumpRewriter::new("e".to_string()); tree.rewrite(&mut rewriter)?; - assert_eq!( - rewriter.visits, - vec![ - "f_down(j)", - "f_down(i)", - "f_down(f)", - "f_down(e)", - "f_down(g)", - "f_down(h)", - "f_up(h)", - "f_up(g)", - "f_up(f)", - "f_up(i)", - "f_up(j)" - ] - ); + assert_eq!(rewriter.visits, f_down_jump_visits()); Ok(()) } #[test] - fn test_f_up_jump() -> Result<()> { + fn test_rewrite_f_up_jump() -> Result<()> { let tree = new_test_tree(); struct FUpJumpRewriter { @@ -766,28 +985,7 @@ mod tests { let mut rewriter = FUpJumpRewriter::new("a".to_string()); tree.rewrite(&mut rewriter)?; - assert_eq!( - rewriter.visits, - vec![ - "f_down(j)", - "f_down(i)", - "f_down(f)", - "f_down(e)", - "f_down(c)", - "f_down(b)", - "f_up(b)", - "f_down(d)", - "f_down(a)", - "f_up(a)", - "f_down(g)", - "f_down(h)", - "f_up(h)", - "f_up(g)", - "f_up(f)", - "f_up(i)", - "f_up(j)" - ] - ); + assert_eq!(rewriter.visits, f_up_jump_visits()); Ok(()) } diff --git 
a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 06aa0866c71a..29d814d13fe2 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -2116,9 +2116,9 @@ impl<'a> BadPlanVisitor<'a> { } impl<'a> TreeNodeVisitor for BadPlanVisitor<'a> { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit(&mut self, node: &Self::N) -> Result { + fn f_down(&mut self, node: &Self::Node) -> Result { match node { LogicalPlan::Ddl(ddl) if !self.options.allow_ddl => { plan_err!("DDL not supported: {}", ddl.name()) diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index ebef7791f8d8..389a33612d4c 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -49,9 +49,9 @@ impl<'a, 'b> IndentVisitor<'a, 'b> { } impl<'a, 'b> TreeNodeVisitor for IndentVisitor<'a, 'b> { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit( + fn f_down( &mut self, plan: &LogicalPlan, ) -> datafusion_common::Result { @@ -72,7 +72,7 @@ impl<'a, 'b> TreeNodeVisitor for IndentVisitor<'a, 'b> { Ok(TreeNodeRecursion::Continue) } - fn post_visit( + fn f_up( &mut self, _plan: &LogicalPlan, ) -> datafusion_common::Result { @@ -171,9 +171,9 @@ impl<'a, 'b> GraphvizVisitor<'a, 'b> { } impl<'a, 'b> TreeNodeVisitor for GraphvizVisitor<'a, 'b> { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit( + fn f_down( &mut self, plan: &LogicalPlan, ) -> datafusion_common::Result { @@ -207,7 +207,7 @@ impl<'a, 'b> TreeNodeVisitor for GraphvizVisitor<'a, 'b> { Ok(TreeNodeRecursion::Continue) } - fn post_visit( + fn f_up( &mut self, _plan: &LogicalPlan, ) -> datafusion_common::Result { diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 8940036bbd55..c41479384191 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ 
b/datafusion/expr/src/logical_plan/plan.rs @@ -1148,7 +1148,7 @@ impl LogicalPlan { /// applies visitor to any subqueries in the plan pub(crate) fn visit_subqueries(&self, v: &mut V) -> Result<()> where - V: TreeNodeVisitor, + V: TreeNodeVisitor, { self.inspect_expressions(|expr| { // recursively look for subqueries @@ -2834,9 +2834,9 @@ digraph { } impl TreeNodeVisitor for OkVisitor { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn f_down(&mut self, plan: &LogicalPlan) -> Result { let s = match plan { LogicalPlan::Projection { .. } => "pre_visit Projection", LogicalPlan::Filter { .. } => "pre_visit Filter", @@ -2850,7 +2850,7 @@ digraph { Ok(TreeNodeRecursion::Continue) } - fn post_visit(&mut self, plan: &LogicalPlan) -> Result { + fn f_up(&mut self, plan: &LogicalPlan) -> Result { let s = match plan { LogicalPlan::Projection { .. } => "post_visit Projection", LogicalPlan::Filter { .. } => "post_visit Filter", @@ -2917,23 +2917,23 @@ digraph { } impl TreeNodeVisitor for StoppingVisitor { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn f_down(&mut self, plan: &LogicalPlan) -> Result { if self.return_false_from_pre_in.dec() { return Ok(TreeNodeRecursion::Stop); } - self.inner.pre_visit(plan)?; + self.inner.f_down(plan)?; Ok(TreeNodeRecursion::Continue) } - fn post_visit(&mut self, plan: &LogicalPlan) -> Result { + fn f_up(&mut self, plan: &LogicalPlan) -> Result { if self.return_false_from_post_in.dec() { return Ok(TreeNodeRecursion::Stop); } - self.inner.post_visit(plan) + self.inner.f_up(plan) } } @@ -2986,22 +2986,22 @@ digraph { } impl TreeNodeVisitor for ErrorVisitor { - type N = LogicalPlan; + type Node = LogicalPlan; - fn pre_visit(&mut self, plan: &LogicalPlan) -> Result { + fn f_down(&mut self, plan: &LogicalPlan) -> Result { if self.return_error_from_pre_in.dec() { return not_impl_err!("Error in pre_visit"); } - 
self.inner.pre_visit(plan) + self.inner.f_down(plan) } - fn post_visit(&mut self, plan: &LogicalPlan) -> Result { + fn f_up(&mut self, plan: &LogicalPlan) -> Result { if self.return_error_from_post_in.dec() { return not_impl_err!("Error in post_visit"); } - self.inner.post_visit(plan) + self.inner.f_up(plan) } } diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 6877c5d40cba..ce7cca25f8b8 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -27,12 +27,14 @@ use crate::{Expr, GetFieldAccess}; use datafusion_common::tree_node::{ Transformed, TransformedIterator, TreeNode, TreeNodeRecursion, }; -use datafusion_common::{handle_tree_recursion, internal_err, DataFusionError, Result}; +use datafusion_common::{ + handle_visit_recursion_down, internal_err, DataFusionError, Result, +}; impl TreeNode for Expr { fn apply_children Result>( &self, - op: &mut F, + f: &mut F, ) -> Result { let children = match self { Expr::Alias(Alias{expr,..}) @@ -131,11 +133,13 @@ impl TreeNode for Expr { } }; + let mut tnr = TreeNodeRecursion::Continue; for child in children { - handle_tree_recursion!(op(child)?); + tnr = f(child)?; + handle_visit_recursion_down!(tnr); } - Ok(TreeNodeRecursion::Continue) + Ok(tnr) } fn map_children(self, mut f: F) -> Result> diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs index f5524d3b2eba..a62ce9549b1a 100644 --- a/datafusion/expr/src/tree_node/plan.rs +++ b/datafusion/expr/src/tree_node/plan.rs @@ -22,18 +22,18 @@ use crate::LogicalPlan; use datafusion_common::tree_node::{ Transformed, TransformedIterator, TreeNode, TreeNodeRecursion, TreeNodeVisitor, }; -use datafusion_common::{handle_tree_recursion, Result}; +use datafusion_common::{handle_visit_recursion_down, handle_visit_recursion_up, Result}; impl TreeNode for LogicalPlan { fn apply Result>( &self, - op: &mut F, + f: &mut F, ) -> Result { // Compared to the default 
implementation, we need to invoke // [`Self::apply_subqueries`] before visiting its children - handle_tree_recursion!(op(self)?); - self.apply_subqueries(op)?; - self.apply_children(&mut |node| node.apply(op)) + handle_visit_recursion_down!(f(self)?); + self.apply_subqueries(f)?; + self.apply_children(&mut |n| n.apply(f)) } /// To use, define a struct that implements the trait [`TreeNodeVisitor`] and then invoke @@ -56,26 +56,28 @@ impl TreeNode for LogicalPlan { /// visitor.post_visit(Filter) /// visitor.post_visit(Projection) /// ``` - fn visit>( + fn visit>( &self, visitor: &mut V, ) -> Result { // Compared to the default implementation, we need to invoke // [`Self::visit_subqueries`] before visiting its children - handle_tree_recursion!(visitor.pre_visit(self)?); + handle_visit_recursion_down!(visitor.f_down(self)?); self.visit_subqueries(visitor)?; - handle_tree_recursion!(self.apply_children(&mut |node| node.visit(visitor))?); - visitor.post_visit(self) + handle_visit_recursion_up!(self.apply_children(&mut |n| n.visit(visitor))?); + visitor.f_up(self) } fn apply_children Result>( &self, - op: &mut F, + f: &mut F, ) -> Result { + let mut tnr = TreeNodeRecursion::Continue; for child in self.inputs() { - handle_tree_recursion!(op(child)?) 
+ tnr = f(child)?; + handle_visit_recursion_down!(tnr) } - Ok(TreeNodeRecursion::Continue) + Ok(tnr) } fn map_children(self, f: F) -> Result> diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 721092a813a5..df417ccc3f1f 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -661,9 +661,9 @@ impl ExprIdentifierVisitor<'_> { } impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { - type N = Expr; + type Node = Expr; - fn pre_visit(&mut self, expr: &Expr) -> Result { + fn f_down(&mut self, expr: &Expr) -> Result { // related to https://github.com/apache/arrow-datafusion/issues/8814 // If the expr contain volatile expression or is a short-circuit expression, skip it. if expr.short_circuits() || is_volatile_expression(expr)? { @@ -677,7 +677,7 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { Ok(TreeNodeRecursion::Continue) } - fn post_visit(&mut self, expr: &Expr) -> Result { + fn f_up(&mut self, expr: &Expr) -> Result { self.series_number += 1; let (idx, sub_expr_desc) = self.pop_enter_mark(); From c224de7a229431201d36851ab3804ebd1c582ba0 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Sun, 11 Feb 2024 13:04:06 +0100 Subject: [PATCH 15/40] fix macro rewrite --- datafusion/common/src/tree_node.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 0c2813bb8489..ab701761ac2c 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -234,8 +234,7 @@ pub trait TreeNode: Sized { where F: Fn(Self) -> Result>, { - handle_transform_recursion_down!(f(self), |n| n - .map_children(|c| c.transform_down(f))) + handle_transform_recursion_down!(f(self), |c| c.transform_down(f)) } /// Convenience utils for writing optimizers rule: recursively apply the given 'f' to the node and all of its @@ -245,8 
+244,7 @@ pub trait TreeNode: Sized { where F: FnMut(Self) -> Result>, { - handle_transform_recursion_down!(f(self), |n| n - .map_children(|c| c.transform_down_mut(f))) + handle_transform_recursion_down!(f(self), |c| c.transform_down_mut(f)) } /// Convenience utils for writing optimizers rule: recursively apply the given 'f' first to all of its From 87ad9953ac94ee2916ee9266f6a58344c22b5737 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 15 Feb 2024 10:41:37 +0100 Subject: [PATCH 16/40] minor fixes --- datafusion/common/src/tree_node.rs | 16 +++++----- datafusion/expr/src/logical_plan/plan.rs | 39 +++++++++++------------- 2 files changed, 25 insertions(+), 30 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index ab701761ac2c..fee32d5fd6f4 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use crate::Result; /// This macro is used to determine continuation after a top-down closure is applied -/// during "visit" traversals. +/// during visiting traversals. /// /// If the function returns [`TreeNodeRecursion::Continue`], the normal execution of the /// function continues. @@ -43,7 +43,7 @@ macro_rules! handle_visit_recursion_down { }; } -/// This macro is used to determine continuation between visiting siblings during "visit" +/// This macro is used to determine continuation between visiting siblings during visiting /// traversals. /// /// If the function returns [`TreeNodeRecursion::Continue`] or @@ -60,7 +60,7 @@ macro_rules! handle_visit_recursion { } /// This macro is used to determine continuation before a bottom-up closure is applied -/// during "visit" traversals. +/// during visiting traversals. /// /// If the function returns [`TreeNodeRecursion::Continue`], the normal execution of the /// function continues. @@ -80,7 +80,7 @@ macro_rules! 
handle_visit_recursion_up { }; } -/// This macro is used to determine continuation during top-down "transform" traversals. +/// This macro is used to determine continuation during top-down transforming traversals. /// /// After the bottom-up closure returns with [`Transformed`] depending on the returned /// [`TreeNodeRecursion`], [`Transformed::and_then()`] decides about recursion @@ -95,7 +95,7 @@ macro_rules! handle_transform_recursion_down { }; } -/// This macro is used to determine continuation during combined "transform" traversals. +/// This macro is used to determine continuation during combined transforming traversals. /// /// After the bottom-up closure returns with [`Transformed`] depending on the returned /// [`TreeNodeRecursion`], [`Transformed::and_then()`] decides about recursion @@ -105,10 +105,10 @@ macro_rules! handle_transform_recursion_down { /// continuation and [`TreeNodeRecursion`] state propagation. #[macro_export] macro_rules! handle_transform_recursion { - ($F_DOWN:expr, $SELF:expr, $F_UP:expr) => { + ($F_DOWN:expr, $F_SELF:expr, $F_UP:expr) => { $F_DOWN?.and_then( |n| { - n.map_children($SELF)? + n.map_children($F_SELF)? .and_then($F_UP, Some(TreeNodeRecursion::Jump)) }, Some(TreeNodeRecursion::Continue), @@ -116,7 +116,7 @@ macro_rules! handle_transform_recursion { }; } -/// This macro is used to determine continuation during bottom-up traversal. +/// This macro is used to determine continuation during bottom-up transforming traversals. 
/// /// After recursing into children returns with [`Transformed`] depending on the returned /// [`TreeNodeRecursion`], [`Transformed::and_then()`] decides about recursion diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index c41479384191..ca21d45cc91c 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -646,29 +646,24 @@ impl LogicalPlan { // Decimal128(Some(69999999999999),30,15) // AND lineitem.l_quantity < Decimal128(Some(2400),15,2) - fn unalias_down(expr: Expr) -> Result> { - match expr { - Expr::Exists { .. } - | Expr::ScalarSubquery(_) - | Expr::InSubquery(_) => { - // subqueries could contain aliases so we don't recurse into those - Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)) + let predicate = predicate + .transform_down(&mut |expr| { + match expr { + Expr::Exists { .. } + | Expr::ScalarSubquery(_) + | Expr::InSubquery(_) => { + // subqueries could contain aliases so we don't recurse into those + Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)) + } + Expr::Alias(_) => Ok(Transformed::new( + expr.unalias(), + true, + TreeNodeRecursion::Jump, + )), + _ => Ok(Transformed::no(expr)), } - Expr::Alias(_) => Ok(Transformed::new( - expr.unalias(), - true, - TreeNodeRecursion::Jump, - )), - _ => Ok(Transformed::no(expr)), - } - } - - fn dummy_up(expr: Expr) -> Result> { - Ok(Transformed::no(expr)) - } - - let predicate = - predicate.transform(&mut unalias_down, &mut dummy_up)?.data; + })? 
+ .data; Filter::try_new(predicate, Arc::new(inputs.swap_remove(0))) .map(LogicalPlan::Filter) From 8c34a9c62057490954bd9a75955b68e88c0b2d15 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Fri, 16 Feb 2024 09:17:22 +0100 Subject: [PATCH 17/40] minor fix --- datafusion/expr/src/tree_node/plan.rs | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs index a62ce9549b1a..e167342de93e 100644 --- a/datafusion/expr/src/tree_node/plan.rs +++ b/datafusion/expr/src/tree_node/plan.rs @@ -102,13 +102,11 @@ impl TreeNode for LogicalPlan { // Propagate up `t.transformed` and `t.tnr` along with the node containing // transformed children. if t2 { - Ok(Transformed::new( - self.with_new_exprs(self.expressions(), t.data)?, - t.transformed, - t.tnr, - )) + t.flat_map_data(|new_children| { + self.with_new_exprs(self.expressions(), new_children) + }) } else { - Ok(Transformed::new(self, t.transformed, t.tnr)) + Ok(t.map_data(|_| self)) } } } From f2fa09ad81ef75cff7367223511c55f75b6fbda7 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Fri, 16 Feb 2024 10:09:38 +0100 Subject: [PATCH 18/40] refactor tests --- datafusion/common/src/tree_node.rs | 268 +++++++++++------------------ 1 file changed, 97 insertions(+), 171 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index fee32d5fd6f4..869cb59ae3e1 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -630,6 +630,7 @@ mod tests { TreeNodeVisitor, }; use crate::Result; + use std::fmt::Display; struct TestTreeNode { children: Vec>, @@ -708,7 +709,7 @@ mod tests { ] } - fn f_down_jump_visits<'a>() -> Vec<&'a str> { + fn f_down_jump_on_e_visits<'a>() -> Vec<&'a str> { vec![ "f_down(j)", "f_down(i)", @@ -724,7 +725,7 @@ mod tests { ] } - fn f_up_jump_visits<'a>() -> Vec<&'a str> { + fn f_up_jump_on_a_visits<'a>() -> Vec<&'a str> { vec![ "f_down(j)", 
"f_down(i)", @@ -746,35 +747,78 @@ mod tests { ] } - #[test] - fn test_visit() -> Result<()> { - let tree = new_test_tree(); + type TestVisitorF = Box) -> Result>; - struct TestVisitor { - pub visits: Vec, - } + struct TestVisitor { + visits: Vec, + fd: TestVisitorF, + fu: TestVisitorF, + } - impl TestVisitor { - fn new() -> Self { - Self { visits: vec![] } + impl TestVisitor { + fn new(fd: TestVisitorF, fu: TestVisitorF) -> Self { + Self { + visits: vec![], + fd, + fu, } } + } - impl TreeNodeVisitor for TestVisitor { - type Node = TestTreeNode<&'static str>; + impl TreeNodeVisitor for TestVisitor { + type Node = TestTreeNode; - fn f_down(&mut self, node: &Self::Node) -> Result { - self.visits.push(format!("f_down({})", node.data)); - Ok(TreeNodeRecursion::Continue) - } + fn f_down(&mut self, node: &Self::Node) -> Result { + self.visits.push(format!("f_down({})", node.data)); + (*self.fd)(node) + } + + fn f_up(&mut self, node: &Self::Node) -> Result { + self.visits.push(format!("f_up({})", node.data)); + (*self.fu)(node) + } + } - fn f_up(&mut self, node: &Self::Node) -> Result { - self.visits.push(format!("f_up({})", node.data)); - Ok(TreeNodeRecursion::Continue) + type TestRewriterF = + Box) -> Result>>>; + + struct TestRewriter { + visits: Vec, + fd: TestRewriterF, + fu: TestRewriterF, + } + + impl TestRewriter { + fn new(fd: TestRewriterF, fu: TestRewriterF) -> Self { + Self { + visits: vec![], + fd, + fu, } } + } + + impl TreeNodeRewriter for TestRewriter { + type Node = TestTreeNode; - let mut visitor = TestVisitor::new(); + fn f_down(&mut self, node: Self::Node) -> Result> { + self.visits.push(format!("f_down({})", node.data)); + (*self.fd)(node) + } + + fn f_up(&mut self, node: Self::Node) -> Result> { + self.visits.push(format!("f_up({})", node.data)); + (*self.fu)(node) + } + } + + #[test] + fn test_visit() -> Result<()> { + let tree = new_test_tree(); + let mut visitor = TestVisitor::new( + Box::new(|_| Ok(TreeNodeRecursion::Continue)), + Box::new(|_| 
Ok(TreeNodeRecursion::Continue)), + ); tree.visit(&mut visitor)?; assert_eq!(visitor.visits, all_visits()); @@ -784,42 +828,18 @@ mod tests { #[test] fn test_visit_f_down_jump() -> Result<()> { let tree = new_test_tree(); - - struct FDownJumpVisitor { - pub visits: Vec, - jump_on: String, - } - - impl FDownJumpVisitor { - fn new(jump_on: String) -> Self { - Self { - visits: vec![], - jump_on, - } - } - } - - impl TreeNodeVisitor for FDownJumpVisitor { - type Node = TestTreeNode<&'static str>; - - fn f_down(&mut self, node: &Self::Node) -> Result { - self.visits.push(format!("f_down({})", node.data)); - Ok(if node.data == self.jump_on { + let mut visitor = TestVisitor::new( + Box::new(|node| { + Ok(if node.data == "e" { TreeNodeRecursion::Jump } else { TreeNodeRecursion::Continue }) - } - - fn f_up(&mut self, node: &Self::Node) -> Result { - self.visits.push(format!("f_up({})", node.data)); - Ok(TreeNodeRecursion::Continue) - } - } - - let mut visitor = FDownJumpVisitor::new("e".to_string()); + }), + Box::new(|_| Ok(TreeNodeRecursion::Continue)), + ); tree.visit(&mut visitor)?; - assert_eq!(visitor.visits, f_down_jump_visits()); + assert_eq!(visitor.visits, f_down_jump_on_e_visits()); Ok(()) } @@ -827,42 +847,18 @@ mod tests { #[test] fn test_visit_f_up_jump() -> Result<()> { let tree = new_test_tree(); - - struct FUpJumpVisitor { - pub visits: Vec, - jump_on: String, - } - - impl FUpJumpVisitor { - fn new(jump_on: String) -> Self { - Self { - visits: vec![], - jump_on, - } - } - } - - impl TreeNodeVisitor for FUpJumpVisitor { - type Node = TestTreeNode<&'static str>; - - fn f_down(&mut self, node: &Self::Node) -> Result { - self.visits.push(format!("f_down({})", node.data)); - Ok(TreeNodeRecursion::Continue) - } - - fn f_up(&mut self, node: &Self::Node) -> Result { - self.visits.push(format!("f_up({})", node.data)); - Ok(if node.data == self.jump_on { + let mut visitor = TestVisitor::new( + Box::new(|_| Ok(TreeNodeRecursion::Continue)), + Box::new(|node| { + Ok(if 
node.data == "a" { TreeNodeRecursion::Jump } else { TreeNodeRecursion::Continue }) - } - } - - let mut visitor = FUpJumpVisitor::new("a".to_string()); + }), + ); tree.visit(&mut visitor)?; - assert_eq!(visitor.visits, f_up_jump_visits()); + assert_eq!(visitor.visits, f_up_jump_on_a_visits()); Ok(()) } @@ -870,32 +866,10 @@ mod tests { #[test] fn test_rewrite() -> Result<()> { let tree = new_test_tree(); - - struct TestRewriter { - pub visits: Vec, - } - - impl TestRewriter { - fn new() -> Self { - Self { visits: vec![] } - } - } - - impl TreeNodeRewriter for TestRewriter { - type Node = TestTreeNode<&'static str>; - - fn f_down(&mut self, node: Self::Node) -> Result> { - self.visits.push(format!("f_down({})", node.data)); - Ok(Transformed::no(node)) - } - - fn f_up(&mut self, node: Self::Node) -> Result> { - self.visits.push(format!("f_up({})", node.data)); - Ok(Transformed::no(node)) - } - } - - let mut rewriter = TestRewriter::new(); + let mut rewriter = TestRewriter::new( + Box::new(|node| Ok(Transformed::no(node))), + Box::new(|node| Ok(Transformed::no(node))), + ); tree.rewrite(&mut rewriter)?; assert_eq!(rewriter.visits, all_visits()); @@ -905,42 +879,18 @@ mod tests { #[test] fn test_rewrite_f_down_jump() -> Result<()> { let tree = new_test_tree(); - - struct FDownJumpRewriter { - pub visits: Vec, - jump_on: String, - } - - impl FDownJumpRewriter { - fn new(jump_on: String) -> Self { - Self { - visits: vec![], - jump_on, - } - } - } - - impl TreeNodeRewriter for FDownJumpRewriter { - type Node = TestTreeNode<&'static str>; - - fn f_down(&mut self, node: Self::Node) -> Result> { - self.visits.push(format!("f_down({})", node.data)); - Ok(if node.data == self.jump_on { + let mut rewriter = TestRewriter::new( + Box::new(|node| { + Ok(if node.data == "e" { Transformed::new(node, false, TreeNodeRecursion::Jump) } else { Transformed::no(node) }) - } - - fn f_up(&mut self, node: Self::Node) -> Result> { - self.visits.push(format!("f_up({})", node.data)); - 
Ok(Transformed::no(node)) - } - } - - let mut rewriter = FDownJumpRewriter::new("e".to_string()); + }), + Box::new(|node| Ok(Transformed::no(node))), + ); tree.rewrite(&mut rewriter)?; - assert_eq!(rewriter.visits, f_down_jump_visits()); + assert_eq!(rewriter.visits, f_down_jump_on_e_visits()); Ok(()) } @@ -948,42 +898,18 @@ mod tests { #[test] fn test_rewrite_f_up_jump() -> Result<()> { let tree = new_test_tree(); - - struct FUpJumpRewriter { - pub visits: Vec, - jump_on: String, - } - - impl FUpJumpRewriter { - fn new(jump_on: String) -> Self { - Self { - visits: vec![], - jump_on, - } - } - } - - impl TreeNodeRewriter for FUpJumpRewriter { - type Node = TestTreeNode<&'static str>; - - fn f_down(&mut self, node: Self::Node) -> Result> { - self.visits.push(format!("f_down({})", node.data)); - Ok(Transformed::no(node)) - } - - fn f_up(&mut self, node: Self::Node) -> Result> { - self.visits.push(format!("f_up({})", node.data)); - Ok(if node.data == self.jump_on { + let mut rewriter = TestRewriter::new( + Box::new(|node| Ok(Transformed::no(node))), + Box::new(|node| { + Ok(if node.data == "a" { Transformed::new(node, false, TreeNodeRecursion::Jump) } else { Transformed::no(node) }) - } - } - - let mut rewriter = FUpJumpRewriter::new("a".to_string()); + }), + ); tree.rewrite(&mut rewriter)?; - assert_eq!(rewriter.visits, f_up_jump_visits()); + assert_eq!(rewriter.visits, f_up_jump_on_a_visits()); Ok(()) } From a2c791ea4c697f8ef82da2b9400b0f5ff9b65d1a Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Fri, 16 Feb 2024 10:50:06 +0100 Subject: [PATCH 19/40] add transform tests --- datafusion/common/src/tree_node.rs | 85 ++++++++++++++++++++++++++++-- 1 file changed, 81 insertions(+), 4 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 869cb59ae3e1..c11219b6c578 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -631,6 +631,7 @@ mod tests { }; use crate::Result; use 
std::fmt::Display; + use std::sync::Mutex; struct TestTreeNode { children: Vec>, @@ -749,13 +750,13 @@ mod tests { type TestVisitorF = Box) -> Result>; - struct TestVisitor { + struct TestVisitor { visits: Vec, fd: TestVisitorF, fu: TestVisitorF, } - impl TestVisitor { + impl TestVisitor { fn new(fd: TestVisitorF, fu: TestVisitorF) -> Self { Self { visits: vec![], @@ -782,13 +783,13 @@ mod tests { type TestRewriterF = Box) -> Result>>>; - struct TestRewriter { + struct TestRewriter { visits: Vec, fd: TestRewriterF, fu: TestRewriterF, } - impl TestRewriter { + impl TestRewriter { fn new(fd: TestRewriterF, fu: TestRewriterF) -> Self { Self { visits: vec![], @@ -913,4 +914,80 @@ mod tests { Ok(()) } + + #[test] + fn test_transform() -> Result<()> { + let tree = new_test_tree(); + // TreeNode::transform() is not useful for mutable shared object between `f_down` + // and `f_up` closures, so we need a trick to test it. + let visits = Mutex::new(vec![]); + tree.transform( + &mut |node| { + let mut mut_visits = visits.lock().unwrap(); + mut_visits.push(format!("f_down({})", node.data)); + Ok(Transformed::no(node)) + }, + &mut |node| { + let mut mut_visits = visits.lock().unwrap(); + mut_visits.push(format!("f_up({})", node.data)); + Ok(Transformed::no(node)) + }, + )?; + assert_eq!(visits.into_inner().unwrap(), all_visits()); + Ok(()) + } + + #[test] + fn test_transform_f_down_jump() -> Result<()> { + let tree = new_test_tree(); + // TreeNode::transform() is not useful for mutable shared object between `f_down` + // and `f_up` closures, so we need a trick to test it. 
+ let visits = Mutex::new(vec![]); + tree.transform( + &mut |node| { + let mut mut_visits = visits.lock().unwrap(); + mut_visits.push(format!("f_down({})", node.data)); + Ok(if node.data == "e" { + Transformed::new(node, false, TreeNodeRecursion::Jump) + } else { + Transformed::no(node) + }) + }, + &mut |node| { + let mut mut_visits = visits.lock().unwrap(); + mut_visits.push(format!("f_up({})", node.data)); + Ok(Transformed::no(node)) + }, + )?; + assert_eq!(visits.into_inner().unwrap(), f_down_jump_on_e_visits()); + + Ok(()) + } + + #[test] + fn test_transform_f_up_jump() -> Result<()> { + let tree = new_test_tree(); + // TreeNode::transform() is not useful for mutable shared object between `f_down` + // and `f_up` closures, so we need a trick to test it. + let visits = Mutex::new(vec![]); + tree.transform( + &mut |node| { + let mut mut_visits = visits.lock().unwrap(); + mut_visits.push(format!("f_down({})", node.data)); + Ok(Transformed::no(node)) + }, + &mut |node| { + let mut mut_visits = visits.lock().unwrap(); + mut_visits.push(format!("f_up({})", node.data)); + Ok(if node.data == "a" { + Transformed::new(node, false, TreeNodeRecursion::Jump) + } else { + Transformed::no(node) + }) + }, + )?; + assert_eq!(visits.into_inner().unwrap(), f_up_jump_on_a_visits()); + + Ok(()) + } } From b9a2afa79514b386b09aeaf29bf59f99756ce825 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Fri, 16 Feb 2024 11:51:29 +0100 Subject: [PATCH 20/40] add apply, transform_down and transform_up tests --- datafusion/common/src/tree_node.rs | 131 +++++++++++++++++++++++++++-- 1 file changed, 125 insertions(+), 6 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index c11219b6c578..eb7a5cc62238 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -710,6 +710,20 @@ mod tests { ] } + fn down_visits<'a>() -> Vec<&'a str> { + all_visits() + .into_iter() + .filter(|v| v.starts_with("f_down")) + .collect() + } 
+ + fn up_visits<'a>() -> Vec<&'a str> { + all_visits() + .into_iter() + .filter(|v| v.starts_with("f_up")) + .collect() + } + fn f_down_jump_on_e_visits<'a>() -> Vec<&'a str> { vec![ "f_down(j)", @@ -726,6 +740,13 @@ mod tests { ] } + fn f_down_jump_on_e_down_visits<'a>() -> Vec<&'a str> { + f_down_jump_on_e_visits() + .into_iter() + .filter(|v| v.starts_with("f_down")) + .collect() + } + fn f_up_jump_on_a_visits<'a>() -> Vec<&'a str> { vec![ "f_down(j)", @@ -748,6 +769,13 @@ mod tests { ] } + fn f_up_jump_on_a_up_visits<'a>() -> Vec<&'a str> { + f_up_jump_on_a_visits() + .into_iter() + .filter(|v| v.starts_with("f_up")) + .collect() + } + type TestVisitorF = Box) -> Result>; struct TestVisitor { @@ -864,6 +892,36 @@ mod tests { Ok(()) } + #[test] + fn test_apply() -> Result<()> { + let tree = new_test_tree(); + let mut visits = vec![]; + tree.apply(&mut |node| { + visits.push(format!("f_down({})", node.data)); + Ok(TreeNodeRecursion::Continue) + })?; + assert_eq!(visits, down_visits()); + + Ok(()) + } + + #[test] + fn test_apply_f_down_jump() -> Result<()> { + let tree = new_test_tree(); + let mut visits = vec![]; + tree.apply(&mut |node| { + visits.push(format!("f_down({})", node.data)); + Ok(if node.data == "e" { + TreeNodeRecursion::Jump + } else { + TreeNodeRecursion::Continue + }) + })?; + assert_eq!(visits, f_down_jump_on_e_down_visits()); + + Ok(()) + } + #[test] fn test_rewrite() -> Result<()> { let tree = new_test_tree(); @@ -918,8 +976,8 @@ mod tests { #[test] fn test_transform() -> Result<()> { let tree = new_test_tree(); - // TreeNode::transform() is not useful for mutable shared object between `f_down` - // and `f_up` closures, so we need a trick to test it. + // TreeNode::transform() is not useful when there is a mutable object shared + // between `f_down` and `f_up` closures, so we need a trick to test it. 
let visits = Mutex::new(vec![]); tree.transform( &mut |node| { @@ -934,14 +992,15 @@ mod tests { }, )?; assert_eq!(visits.into_inner().unwrap(), all_visits()); + Ok(()) } #[test] fn test_transform_f_down_jump() -> Result<()> { let tree = new_test_tree(); - // TreeNode::transform() is not useful for mutable shared object between `f_down` - // and `f_up` closures, so we need a trick to test it. + // TreeNode::transform() is not useful when there is a mutable object shared + // between `f_down` and `f_up` closures, so we need a trick to test it. let visits = Mutex::new(vec![]); tree.transform( &mut |node| { @@ -967,8 +1026,8 @@ mod tests { #[test] fn test_transform_f_up_jump() -> Result<()> { let tree = new_test_tree(); - // TreeNode::transform() is not useful for mutable shared object between `f_down` - // and `f_up` closures, so we need a trick to test it. + // TreeNode::transform() is not useful when there is a mutable object shared + // between `f_down` and `f_up` closures, so we need a trick to test it. 
let visits = Mutex::new(vec![]); tree.transform( &mut |node| { @@ -990,4 +1049,64 @@ mod tests { Ok(()) } + + #[test] + fn test_transform_down() -> Result<()> { + let tree = new_test_tree(); + let mut visits = vec![]; + tree.transform_down_mut(&mut |node| { + visits.push(format!("f_down({})", node.data)); + Ok(Transformed::no(node)) + })?; + assert_eq!(visits, down_visits()); + + Ok(()) + } + + #[test] + fn test_transform_down_f_down_jump() -> Result<()> { + let tree = new_test_tree(); + let mut visits = vec![]; + tree.transform_down_mut(&mut |node| { + visits.push(format!("f_down({})", node.data)); + Ok(if node.data == "e" { + Transformed::new(node, false, TreeNodeRecursion::Jump) + } else { + Transformed::no(node) + }) + })?; + assert_eq!(visits, f_down_jump_on_e_down_visits()); + + Ok(()) + } + + #[test] + fn test_transform_up() -> Result<()> { + let tree = new_test_tree(); + let mut visits = vec![]; + tree.transform_up_mut(&mut |node| { + visits.push(format!("f_up({})", node.data)); + Ok(Transformed::no(node)) + })?; + assert_eq!(visits, up_visits()); + + Ok(()) + } + + #[test] + fn test_transform_up_f_up_jump() -> Result<()> { + let tree = new_test_tree(); + let mut visits = vec![]; + tree.transform_up_mut(&mut |node| { + visits.push(format!("f_up({})", node.data)); + Ok(if node.data == "a" { + Transformed::new(node, false, TreeNodeRecursion::Jump) + } else { + Transformed::no(node) + }) + })?; + assert_eq!(visits, f_up_jump_on_a_up_visits()); + + Ok(()) + } } From 83fe70934b544eb2c6a0835821bc915b91eda486 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Mon, 19 Feb 2024 10:36:21 +0100 Subject: [PATCH 21/40] refactor tests --- datafusion/common/src/tree_node.rs | 478 +++++++++++++---------------- 1 file changed, 211 insertions(+), 267 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index eb7a5cc62238..90b7016a16af 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -780,16 
+780,16 @@ mod tests { struct TestVisitor { visits: Vec, - fd: TestVisitorF, - fu: TestVisitorF, + f_down: TestVisitorF, + f_up: TestVisitorF, } impl TestVisitor { - fn new(fd: TestVisitorF, fu: TestVisitorF) -> Self { + fn new(f_down: TestVisitorF, f_up: TestVisitorF) -> Self { Self { visits: vec![], - fd, - fu, + f_down, + f_up, } } } @@ -799,30 +799,98 @@ mod tests { fn f_down(&mut self, node: &Self::Node) -> Result { self.visits.push(format!("f_down({})", node.data)); - (*self.fd)(node) + (*self.f_down)(node) } fn f_up(&mut self, node: &Self::Node) -> Result { self.visits.push(format!("f_up({})", node.data)); - (*self.fu)(node) + (*self.f_up)(node) } } + fn visit_continue(_: &TestTreeNode) -> Result { + Ok(TreeNodeRecursion::Continue) + } + + fn visit_jump_on( + data: T, + ) -> impl FnMut(&TestTreeNode) -> Result { + move |node| { + Ok(if node.data == data { + TreeNodeRecursion::Jump + } else { + TreeNodeRecursion::Continue + }) + } + } + + macro_rules! visit_test { + ($NAME:ident, $F_DOWN:expr, $F_UP:expr, $EXPECTED:expr) => { + #[test] + fn $NAME() -> Result<()> { + let tree = new_test_tree(); + let mut visitor = TestVisitor::new(Box::new($F_DOWN), Box::new($F_UP)); + tree.visit(&mut visitor)?; + assert_eq!(visitor.visits, $EXPECTED); + + Ok(()) + } + }; + } + + visit_test!(test_visit, visit_continue, visit_continue, all_visits()); + visit_test!( + test_visit_f_down_jump_on_e, + visit_jump_on("e"), + visit_continue, + f_down_jump_on_e_visits() + ); + visit_test!( + test_visit_f_up_jump_on_a, + visit_continue, + visit_jump_on("a"), + f_up_jump_on_a_visits() + ); + + macro_rules! 
test_apply { + ($NAME:ident, $F:expr, $EXPECTED:expr) => { + #[test] + fn $NAME() -> Result<()> { + let tree = new_test_tree(); + let mut visits = vec![]; + tree.apply(&mut |node| { + visits.push(format!("f_down({})", node.data)); + $F(node) + })?; + assert_eq!(visits, $EXPECTED); + + Ok(()) + } + }; + } + + test_apply!(test_apply, visit_continue, down_visits()); + test_apply!( + test_apply_f_down_jump_on_e, + visit_jump_on("e"), + f_down_jump_on_e_down_visits() + ); + type TestRewriterF = Box) -> Result>>>; struct TestRewriter { visits: Vec, - fd: TestRewriterF, - fu: TestRewriterF, + f_down: TestRewriterF, + f_up: TestRewriterF, } impl TestRewriter { - fn new(fd: TestRewriterF, fu: TestRewriterF) -> Self { + fn new(f_down: TestRewriterF, f_up: TestRewriterF) -> Self { Self { visits: vec![], - fd, - fu, + f_down, + f_up, } } } @@ -832,281 +900,157 @@ mod tests { fn f_down(&mut self, node: Self::Node) -> Result> { self.visits.push(format!("f_down({})", node.data)); - (*self.fd)(node) + (*self.f_down)(node) } fn f_up(&mut self, node: Self::Node) -> Result> { self.visits.push(format!("f_up({})", node.data)); - (*self.fu)(node) + (*self.f_up)(node) } } - #[test] - fn test_visit() -> Result<()> { - let tree = new_test_tree(); - let mut visitor = TestVisitor::new( - Box::new(|_| Ok(TreeNodeRecursion::Continue)), - Box::new(|_| Ok(TreeNodeRecursion::Continue)), - ); - tree.visit(&mut visitor)?; - assert_eq!(visitor.visits, all_visits()); - - Ok(()) - } - - #[test] - fn test_visit_f_down_jump() -> Result<()> { - let tree = new_test_tree(); - let mut visitor = TestVisitor::new( - Box::new(|node| { - Ok(if node.data == "e" { - TreeNodeRecursion::Jump - } else { - TreeNodeRecursion::Continue - }) - }), - Box::new(|_| Ok(TreeNodeRecursion::Continue)), - ); - tree.visit(&mut visitor)?; - assert_eq!(visitor.visits, f_down_jump_on_e_visits()); - - Ok(()) - } - - #[test] - fn test_visit_f_up_jump() -> Result<()> { - let tree = new_test_tree(); - let mut visitor = 
TestVisitor::new( - Box::new(|_| Ok(TreeNodeRecursion::Continue)), - Box::new(|node| { - Ok(if node.data == "a" { - TreeNodeRecursion::Jump - } else { - TreeNodeRecursion::Continue - }) - }), - ); - tree.visit(&mut visitor)?; - assert_eq!(visitor.visits, f_up_jump_on_a_visits()); - - Ok(()) - } - - #[test] - fn test_apply() -> Result<()> { - let tree = new_test_tree(); - let mut visits = vec![]; - tree.apply(&mut |node| { - visits.push(format!("f_down({})", node.data)); - Ok(TreeNodeRecursion::Continue) - })?; - assert_eq!(visits, down_visits()); - - Ok(()) + fn transform_continue( + node: TestTreeNode, + ) -> Result>> { + Ok(Transformed::no(node)) } - #[test] - fn test_apply_f_down_jump() -> Result<()> { - let tree = new_test_tree(); - let mut visits = vec![]; - tree.apply(&mut |node| { - visits.push(format!("f_down({})", node.data)); - Ok(if node.data == "e" { - TreeNodeRecursion::Jump + fn transform_jump_on( + data: T, + ) -> impl FnMut(TestTreeNode) -> Result>> { + move |node| { + Ok(if node.data == data { + Transformed::new(node, false, TreeNodeRecursion::Jump) } else { - TreeNodeRecursion::Continue + Transformed::no(node) }) - })?; - assert_eq!(visits, f_down_jump_on_e_down_visits()); - - Ok(()) - } - - #[test] - fn test_rewrite() -> Result<()> { - let tree = new_test_tree(); - let mut rewriter = TestRewriter::new( - Box::new(|node| Ok(Transformed::no(node))), - Box::new(|node| Ok(Transformed::no(node))), - ); - tree.rewrite(&mut rewriter)?; - assert_eq!(rewriter.visits, all_visits()); - - Ok(()) - } - - #[test] - fn test_rewrite_f_down_jump() -> Result<()> { - let tree = new_test_tree(); - let mut rewriter = TestRewriter::new( - Box::new(|node| { - Ok(if node.data == "e" { - Transformed::new(node, false, TreeNodeRecursion::Jump) - } else { - Transformed::no(node) - }) - }), - Box::new(|node| Ok(Transformed::no(node))), - ); - tree.rewrite(&mut rewriter)?; - assert_eq!(rewriter.visits, f_down_jump_on_e_visits()); - - Ok(()) - } - - #[test] - fn 
test_rewrite_f_up_jump() -> Result<()> { - let tree = new_test_tree(); - let mut rewriter = TestRewriter::new( - Box::new(|node| Ok(Transformed::no(node))), - Box::new(|node| { - Ok(if node.data == "a" { - Transformed::new(node, false, TreeNodeRecursion::Jump) - } else { - Transformed::no(node) - }) - }), - ); - tree.rewrite(&mut rewriter)?; - assert_eq!(rewriter.visits, f_up_jump_on_a_visits()); - - Ok(()) - } - - #[test] - fn test_transform() -> Result<()> { - let tree = new_test_tree(); - // TreeNode::transform() is not useful when there is a mutable object shared - // between `f_down` and `f_up` closures, so we need a trick to test it. - let visits = Mutex::new(vec![]); - tree.transform( - &mut |node| { - let mut mut_visits = visits.lock().unwrap(); - mut_visits.push(format!("f_down({})", node.data)); - Ok(Transformed::no(node)) - }, - &mut |node| { - let mut mut_visits = visits.lock().unwrap(); - mut_visits.push(format!("f_up({})", node.data)); - Ok(Transformed::no(node)) - }, - )?; - assert_eq!(visits.into_inner().unwrap(), all_visits()); - - Ok(()) - } - - #[test] - fn test_transform_f_down_jump() -> Result<()> { - let tree = new_test_tree(); - // TreeNode::transform() is not useful when there is a mutable object shared - // between `f_down` and `f_up` closures, so we need a trick to test it. 
- let visits = Mutex::new(vec![]); - tree.transform( - &mut |node| { - let mut mut_visits = visits.lock().unwrap(); - mut_visits.push(format!("f_down({})", node.data)); - Ok(if node.data == "e" { - Transformed::new(node, false, TreeNodeRecursion::Jump) - } else { - Transformed::no(node) - }) - }, - &mut |node| { - let mut mut_visits = visits.lock().unwrap(); - mut_visits.push(format!("f_up({})", node.data)); - Ok(Transformed::no(node)) - }, - )?; - assert_eq!(visits.into_inner().unwrap(), f_down_jump_on_e_visits()); - - Ok(()) + } } - #[test] - fn test_transform_f_up_jump() -> Result<()> { - let tree = new_test_tree(); - // TreeNode::transform() is not useful when there is a mutable object shared - // between `f_down` and `f_up` closures, so we need a trick to test it. - let visits = Mutex::new(vec![]); - tree.transform( - &mut |node| { - let mut mut_visits = visits.lock().unwrap(); - mut_visits.push(format!("f_down({})", node.data)); - Ok(Transformed::no(node)) - }, - &mut |node| { - let mut mut_visits = visits.lock().unwrap(); - mut_visits.push(format!("f_up({})", node.data)); - Ok(if node.data == "a" { - Transformed::new(node, false, TreeNodeRecursion::Jump) - } else { - Transformed::no(node) - }) - }, - )?; - assert_eq!(visits.into_inner().unwrap(), f_up_jump_on_a_visits()); + macro_rules! 
rewrite_test { + ($NAME:ident, $F_DOWN:expr, $F_UP:expr, $EXPECTED:expr) => { + #[test] + fn $NAME() -> Result<()> { + let tree = new_test_tree(); + let mut rewriter = TestRewriter::new(Box::new($F_DOWN), Box::new($F_UP)); + tree.rewrite(&mut rewriter)?; + assert_eq!(rewriter.visits, $EXPECTED); - Ok(()) + Ok(()) + } + }; } - #[test] - fn test_transform_down() -> Result<()> { - let tree = new_test_tree(); - let mut visits = vec![]; - tree.transform_down_mut(&mut |node| { - visits.push(format!("f_down({})", node.data)); - Ok(Transformed::no(node)) - })?; - assert_eq!(visits, down_visits()); - - Ok(()) + rewrite_test!( + test_rewrite, + transform_continue, + transform_continue, + all_visits() + ); + rewrite_test!( + test_rewrite_f_down_jump_on_e, + transform_jump_on("e"), + transform_continue, + f_down_jump_on_e_visits() + ); + rewrite_test!( + test_rewrite_f_up_jump_on_a, + transform_continue, + transform_jump_on("a"), + f_up_jump_on_a_visits() + ); + + macro_rules! transform_test { + ($NAME:ident, $F_DOWN:expr, $F_UP:expr, $EXPECTED:expr) => { + #[test] + fn $NAME() -> Result<()> { + let tree = new_test_tree(); + // TreeNode::transform() is not useful when there is a mutable object shared + // between `f_down` and `f_up` closures, so we need a trick to test it. 
+ let visits = Mutex::new(vec![]); + tree.transform( + &mut |node| { + let mut mut_visits = visits.lock().unwrap(); + mut_visits.push(format!("f_down({})", node.data)); + $F_DOWN(node) + }, + &mut |node| { + let mut mut_visits = visits.lock().unwrap(); + mut_visits.push(format!("f_up({})", node.data)); + $F_UP(node) + }, + )?; + assert_eq!(visits.into_inner().unwrap(), $EXPECTED); + + Ok(()) + } + }; } - #[test] - fn test_transform_down_f_down_jump() -> Result<()> { - let tree = new_test_tree(); - let mut visits = vec![]; - tree.transform_down_mut(&mut |node| { - visits.push(format!("f_down({})", node.data)); - Ok(if node.data == "e" { - Transformed::new(node, false, TreeNodeRecursion::Jump) - } else { - Transformed::no(node) - }) - })?; - assert_eq!(visits, f_down_jump_on_e_down_visits()); - - Ok(()) + transform_test!( + test_transform, + transform_continue, + transform_continue, + all_visits() + ); + transform_test!( + test_transform_f_down_jump_on_e, + transform_jump_on("e"), + transform_continue, + f_down_jump_on_e_visits() + ); + transform_test!( + test_transform_f_up_jump_on_a, + transform_continue, + transform_jump_on("a"), + f_up_jump_on_a_visits() + ); + + macro_rules! 
transform_down_test { + ($NAME:ident, $F:expr, $EXPECTED:expr) => { + #[test] + fn $NAME() -> Result<()> { + let tree = new_test_tree(); + let mut visits = vec![]; + tree.transform_down_mut(&mut |node| { + visits.push(format!("f_down({})", node.data)); + $F(node) + })?; + assert_eq!(visits, $EXPECTED); + + Ok(()) + } + }; } - #[test] - fn test_transform_up() -> Result<()> { - let tree = new_test_tree(); - let mut visits = vec![]; - tree.transform_up_mut(&mut |node| { - visits.push(format!("f_up({})", node.data)); - Ok(Transformed::no(node)) - })?; - assert_eq!(visits, up_visits()); - - Ok(()) + transform_down_test!(test_transform_down, transform_continue, down_visits()); + transform_down_test!( + test_transform_down_f_down_jump_on_e, + transform_jump_on("e"), + f_down_jump_on_e_down_visits() + ); + + macro_rules! transform_up_test { + ($NAME:ident, $F:expr, $EXPECTED:expr) => { + #[test] + fn $NAME() -> Result<()> { + let tree = new_test_tree(); + let mut visits = vec![]; + tree.transform_up_mut(&mut |node| { + visits.push(format!("f_up({})", node.data)); + $F(node) + })?; + assert_eq!(visits, $EXPECTED); + + Ok(()) + } + }; } - #[test] - fn test_transform_up_f_up_jump() -> Result<()> { - let tree = new_test_tree(); - let mut visits = vec![]; - tree.transform_up_mut(&mut |node| { - visits.push(format!("f_up({})", node.data)); - Ok(if node.data == "a" { - Transformed::new(node, false, TreeNodeRecursion::Jump) - } else { - Transformed::no(node) - }) - })?; - assert_eq!(visits, f_up_jump_on_a_up_visits()); - - Ok(()) - } + transform_up_test!(test_transform_up, transform_continue, up_visits()); + transform_up_test!( + test_transform_up_f_up_jump_on_a, + transform_jump_on("a"), + f_up_jump_on_a_up_visits() + ); } From e5dbc7ebf2069d986d7be1784c216812bc15e529 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Mon, 19 Feb 2024 11:28:04 +0100 Subject: [PATCH 22/40] test jump on both a and e nodes in both top-down and bottom-up traversals --- datafusion/common/src/tree_node.rs 
| 146 ++++++++++++++++++++++++----- 1 file changed, 120 insertions(+), 26 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 90b7016a16af..68b29aa85a0a 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -710,18 +710,28 @@ mod tests { ] } - fn down_visits<'a>() -> Vec<&'a str> { - all_visits() - .into_iter() - .filter(|v| v.starts_with("f_down")) - .collect() - } - - fn up_visits<'a>() -> Vec<&'a str> { - all_visits() - .into_iter() - .filter(|v| v.starts_with("f_up")) - .collect() + fn f_down_jump_on_a_visits<'a>() -> Vec<&'a str> { + vec![ + "f_down(j)", + "f_down(i)", + "f_down(f)", + "f_down(e)", + "f_down(c)", + "f_down(b)", + "f_up(b)", + "f_down(d)", + "f_down(a)", + "f_up(d)", + "f_up(c)", + "f_up(e)", + "f_down(g)", + "f_down(h)", + "f_up(h)", + "f_up(g)", + "f_up(f)", + "f_up(i)", + "f_up(j)", + ] } fn f_down_jump_on_e_visits<'a>() -> Vec<&'a str> { @@ -740,14 +750,29 @@ mod tests { ] } - fn f_down_jump_on_e_down_visits<'a>() -> Vec<&'a str> { - f_down_jump_on_e_visits() - .into_iter() - .filter(|v| v.starts_with("f_down")) - .collect() + fn f_up_jump_on_a_visits<'a>() -> Vec<&'a str> { + vec![ + "f_down(j)", + "f_down(i)", + "f_down(f)", + "f_down(e)", + "f_down(c)", + "f_down(b)", + "f_up(b)", + "f_down(d)", + "f_down(a)", + "f_up(a)", + "f_down(g)", + "f_down(h)", + "f_up(h)", + "f_up(g)", + "f_up(f)", + "f_up(i)", + "f_up(j)", + ] } - fn f_up_jump_on_a_visits<'a>() -> Vec<&'a str> { + fn f_up_jump_on_e_visits<'a>() -> Vec<&'a str> { vec![ "f_down(j)", "f_down(i)", @@ -759,6 +784,9 @@ mod tests { "f_down(d)", "f_down(a)", "f_up(a)", + "f_up(d)", + "f_up(c)", + "f_up(e)", "f_down(g)", "f_down(h)", "f_up(h)", @@ -769,8 +797,15 @@ mod tests { ] } - fn f_up_jump_on_a_up_visits<'a>() -> Vec<&'a str> { - f_up_jump_on_a_visits() + fn down_visits(visits: Vec<&str>) -> Vec<&str> { + visits + .into_iter() + .filter(|v| v.starts_with("f_down")) + .collect() + } + + fn 
up_visits(visits: Vec<&str>) -> Vec<&str> { + visits .into_iter() .filter(|v| v.starts_with("f_up")) .collect() @@ -839,6 +874,12 @@ mod tests { } visit_test!(test_visit, visit_continue, visit_continue, all_visits()); + visit_test!( + test_visit_f_down_jump_on_a, + visit_jump_on("a"), + visit_continue, + f_down_jump_on_a_visits() + ); visit_test!( test_visit_f_down_jump_on_e, visit_jump_on("e"), @@ -851,6 +892,12 @@ mod tests { visit_jump_on("a"), f_up_jump_on_a_visits() ); + visit_test!( + test_visit_f_up_jump_on_e, + visit_continue, + visit_jump_on("e"), + f_up_jump_on_e_visits() + ); macro_rules! test_apply { ($NAME:ident, $F:expr, $EXPECTED:expr) => { @@ -869,11 +916,16 @@ mod tests { }; } - test_apply!(test_apply, visit_continue, down_visits()); + test_apply!(test_apply, visit_continue, down_visits(all_visits())); + test_apply!( + test_apply_f_down_jump_on_a, + visit_jump_on("a"), + down_visits(f_down_jump_on_a_visits()) + ); test_apply!( test_apply_f_down_jump_on_e, visit_jump_on("e"), - f_down_jump_on_e_down_visits() + down_visits(f_down_jump_on_e_visits()) ); type TestRewriterF = @@ -947,6 +999,12 @@ mod tests { transform_continue, all_visits() ); + rewrite_test!( + test_rewrite_f_down_jump_on_a, + transform_jump_on("a"), + transform_continue, + f_down_jump_on_a_visits() + ); rewrite_test!( test_rewrite_f_down_jump_on_e, transform_jump_on("e"), @@ -959,6 +1017,12 @@ mod tests { transform_jump_on("a"), f_up_jump_on_a_visits() ); + rewrite_test!( + test_rewrite_f_up_jump_on_e, + transform_continue, + transform_jump_on("e"), + f_up_jump_on_e_visits() + ); macro_rules! 
transform_test { ($NAME:ident, $F_DOWN:expr, $F_UP:expr, $EXPECTED:expr) => { @@ -993,6 +1057,12 @@ mod tests { transform_continue, all_visits() ); + transform_test!( + test_transform_f_down_jump_on_a, + transform_jump_on("a"), + transform_continue, + f_down_jump_on_a_visits() + ); transform_test!( test_transform_f_down_jump_on_e, transform_jump_on("e"), @@ -1005,6 +1075,12 @@ mod tests { transform_jump_on("a"), f_up_jump_on_a_visits() ); + transform_test!( + test_transform_f_up_jump_on_e, + transform_continue, + transform_jump_on("e"), + f_up_jump_on_e_visits() + ); macro_rules! transform_down_test { ($NAME:ident, $F:expr, $EXPECTED:expr) => { @@ -1023,11 +1099,20 @@ mod tests { }; } - transform_down_test!(test_transform_down, transform_continue, down_visits()); + transform_down_test!( + test_transform_down, + transform_continue, + down_visits(all_visits()) + ); + transform_down_test!( + test_transform_down_f_down_jump_on_a, + transform_jump_on("a"), + down_visits(f_down_jump_on_a_visits()) + ); transform_down_test!( test_transform_down_f_down_jump_on_e, transform_jump_on("e"), - f_down_jump_on_e_down_visits() + down_visits(f_down_jump_on_e_visits()) ); macro_rules! 
transform_up_test { @@ -1047,10 +1132,19 @@ mod tests { }; } - transform_up_test!(test_transform_up, transform_continue, up_visits()); + transform_up_test!( + test_transform_up, + transform_continue, + up_visits(all_visits()) + ); transform_up_test!( test_transform_up_f_up_jump_on_a, transform_jump_on("a"), - f_up_jump_on_a_up_visits() + up_visits(f_up_jump_on_a_visits()) + ); + transform_up_test!( + test_transform_up_f_up_jump_on_e, + transform_jump_on("e"), + up_visits(f_up_jump_on_e_visits()) ); } From cbdc6495056b425557230cb04ee84b09723af6a5 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Mon, 19 Feb 2024 16:36:35 +0100 Subject: [PATCH 23/40] better transform/rewrite tests --- datafusion/common/src/tree_node.rs | 392 +++++++++++++++++++---------- 1 file changed, 259 insertions(+), 133 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 68b29aa85a0a..73a99a2fa2f5 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -387,6 +387,7 @@ pub enum TreeNodeRecursion { Stop, } +#[derive(PartialEq, Debug)] pub struct Transformed { pub data: T, pub transformed: bool, @@ -631,8 +632,8 @@ mod tests { }; use crate::Result; use std::fmt::Display; - use std::sync::Mutex; + #[derive(PartialEq, Debug)] struct TestTreeNode { children: Vec>, data: T, @@ -672,20 +673,20 @@ mod tests { } } - fn new_test_tree<'a>() -> TestTreeNode<&'a str> { - let node_a = TestTreeNode::new(vec![], "a"); - let node_b = TestTreeNode::new(vec![], "b"); - let node_d = TestTreeNode::new(vec![node_a], "d"); - let node_c = TestTreeNode::new(vec![node_b, node_d], "c"); - let node_e = TestTreeNode::new(vec![node_c], "e"); - let node_h = TestTreeNode::new(vec![], "h"); - let node_g = TestTreeNode::new(vec![node_h], "g"); - let node_f = TestTreeNode::new(vec![node_e, node_g], "f"); - let node_i = TestTreeNode::new(vec![node_f], "i"); - TestTreeNode::new(vec![node_i], "j") + fn test_tree() -> TestTreeNode { + let node_a = 
TestTreeNode::new(vec![], "a".to_string()); + let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "e".to_string()); + let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "i".to_string()); + TestTreeNode::new(vec![node_i], "j".to_string()) } - fn all_visits<'a>() -> Vec<&'a str> { + fn all_visits() -> Vec { vec![ "f_down(j)", "f_down(i)", @@ -708,9 +709,53 @@ mod tests { "f_up(i)", "f_up(j)", ] + .into_iter() + .map(|s| s.to_string()) + .collect() } - fn f_down_jump_on_a_visits<'a>() -> Vec<&'a str> { + fn transformed_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string()); + let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_up(f_down(d))".to_string()); + let node_c = + TestTreeNode::new(vec![node_b, node_d], "f_up(f_down(c))".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string()); + let node_h = TestTreeNode::new(vec![], "f_up(f_down(h))".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string()); + let node_f = + TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_up(f_down(i))".to_string()); + TestTreeNode::new(vec![node_i], "f_up(f_down(j))".to_string()) + } + + fn transformed_down_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string()); + let node_b = TestTreeNode::new(vec![], "f_down(b)".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); + let node_c = 
TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); + let node_h = TestTreeNode::new(vec![], "f_down(h)".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); + TestTreeNode::new(vec![node_i], "f_down(j)".to_string()) + } + + fn transformed_up_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string()); + let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_up(d)".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(c)".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_up(e)".to_string()); + let node_h = TestTreeNode::new(vec![], "f_up(h)".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "f_up(g)".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f)".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_up(i)".to_string()); + TestTreeNode::new(vec![node_i], "f_up(j)".to_string()) + } + + fn f_down_jump_on_a_visits() -> Vec { vec![ "f_down(j)", "f_down(i)", @@ -732,9 +777,40 @@ mod tests { "f_up(i)", "f_up(j)", ] + .into_iter() + .map(|s| s.to_string()) + .collect() + } + + fn f_down_jump_on_a_transformed_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string()); + let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_up(f_down(d))".to_string()); + let node_c = + TestTreeNode::new(vec![node_b, node_d], "f_up(f_down(c))".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string()); + let node_h = TestTreeNode::new(vec![], "f_up(f_down(h))".to_string()); + let node_g = TestTreeNode::new(vec![node_h], 
"f_up(f_down(g))".to_string()); + let node_f = + TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_up(f_down(i))".to_string()); + TestTreeNode::new(vec![node_i], "f_up(f_down(j))".to_string()) + } + + fn f_down_jump_on_a_transformed_down_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string()); + let node_b = TestTreeNode::new(vec![], "f_down(b)".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); + let node_h = TestTreeNode::new(vec![], "f_down(h)".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); + TestTreeNode::new(vec![node_i], "f_down(j)".to_string()) } - fn f_down_jump_on_e_visits<'a>() -> Vec<&'a str> { + fn f_down_jump_on_e_visits() -> Vec { vec![ "f_down(j)", "f_down(i)", @@ -748,9 +824,39 @@ mod tests { "f_up(i)", "f_up(j)", ] + .into_iter() + .map(|s| s.to_string()) + .collect() + } + + fn f_down_jump_on_e_transformed_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "a".to_string()); + let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); + let node_h = TestTreeNode::new(vec![], "f_up(f_down(h))".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string()); + let node_f = + TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string()); + let node_i = TestTreeNode::new(vec![node_f], 
"f_up(f_down(i))".to_string()); + TestTreeNode::new(vec![node_i], "f_up(f_down(j))".to_string()) + } + + fn f_down_jump_on_e_transformed_down_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "a".to_string()); + let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); + let node_h = TestTreeNode::new(vec![], "f_down(h)".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "f_down(g)".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); + TestTreeNode::new(vec![node_i], "f_down(j)".to_string()) } - fn f_up_jump_on_a_visits<'a>() -> Vec<&'a str> { + fn f_up_jump_on_a_visits() -> Vec { vec![ "f_down(j)", "f_down(i)", @@ -770,9 +876,39 @@ mod tests { "f_up(i)", "f_up(j)", ] + .into_iter() + .map(|s| s.to_string()) + .collect() } - fn f_up_jump_on_e_visits<'a>() -> Vec<&'a str> { + fn f_up_jump_on_a_transformed_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string()); + let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); + let node_h = TestTreeNode::new(vec![], "f_up(f_down(h))".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string()); + let node_f = + TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_up(f_down(i))".to_string()); + TestTreeNode::new(vec![node_i], "f_up(f_down(j))".to_string()) + } + + fn f_up_jump_on_a_transformed_up_tree() 
-> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string()); + let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "e".to_string()); + let node_h = TestTreeNode::new(vec![], "f_up(h)".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "f_up(g)".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f)".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_up(i)".to_string()); + TestTreeNode::new(vec![node_i], "f_up(j)".to_string()) + } + + fn f_up_jump_on_e_visits() -> Vec { vec![ "f_down(j)", "f_down(i)", @@ -795,19 +931,32 @@ mod tests { "f_up(i)", "f_up(j)", ] + .into_iter() + .map(|s| s.to_string()) + .collect() } - fn down_visits(visits: Vec<&str>) -> Vec<&str> { - visits - .into_iter() - .filter(|v| v.starts_with("f_down")) - .collect() + fn f_up_jump_on_e_transformed_tree() -> TestTreeNode { + transformed_tree() + } + + fn f_up_jump_on_e_transformed_up_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string()); + let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_up(d)".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(c)".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_up(e)".to_string()); + let node_h = TestTreeNode::new(vec![], "f_up(h)".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "f_up(g)".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f)".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_up(i)".to_string()); + TestTreeNode::new(vec![node_i], "f_up(j)".to_string()) } - fn up_visits(visits: Vec<&str>) -> Vec<&str> { + fn down_visits(visits: Vec) -> Vec { visits .into_iter() - .filter(|v| v.starts_with("f_up")) + 
.filter(|v| v.starts_with("f_down")) .collect() } @@ -860,13 +1009,13 @@ mod tests { } macro_rules! visit_test { - ($NAME:ident, $F_DOWN:expr, $F_UP:expr, $EXPECTED:expr) => { + ($NAME:ident, $F_DOWN:expr, $F_UP:expr, $EXPECTED_VISITS:expr) => { #[test] fn $NAME() -> Result<()> { - let tree = new_test_tree(); + let tree = test_tree(); let mut visitor = TestVisitor::new(Box::new($F_DOWN), Box::new($F_UP)); tree.visit(&mut visitor)?; - assert_eq!(visitor.visits, $EXPECTED); + assert_eq!(visitor.visits, $EXPECTED_VISITS); Ok(()) } @@ -876,40 +1025,40 @@ mod tests { visit_test!(test_visit, visit_continue, visit_continue, all_visits()); visit_test!( test_visit_f_down_jump_on_a, - visit_jump_on("a"), + visit_jump_on("a".to_string()), visit_continue, f_down_jump_on_a_visits() ); visit_test!( test_visit_f_down_jump_on_e, - visit_jump_on("e"), + visit_jump_on("e".to_string()), visit_continue, f_down_jump_on_e_visits() ); visit_test!( test_visit_f_up_jump_on_a, visit_continue, - visit_jump_on("a"), + visit_jump_on("a".to_string()), f_up_jump_on_a_visits() ); visit_test!( test_visit_f_up_jump_on_e, visit_continue, - visit_jump_on("e"), + visit_jump_on("e".to_string()), f_up_jump_on_e_visits() ); macro_rules! 
test_apply { - ($NAME:ident, $F:expr, $EXPECTED:expr) => { + ($NAME:ident, $F:expr, $EXPECTED_VISITS:expr) => { #[test] fn $NAME() -> Result<()> { - let tree = new_test_tree(); + let tree = test_tree(); let mut visits = vec![]; tree.apply(&mut |node| { visits.push(format!("f_down({})", node.data)); $F(node) })?; - assert_eq!(visits, $EXPECTED); + assert_eq!(visits, $EXPECTED_VISITS); Ok(()) } @@ -919,12 +1068,12 @@ mod tests { test_apply!(test_apply, visit_continue, down_visits(all_visits())); test_apply!( test_apply_f_down_jump_on_a, - visit_jump_on("a"), + visit_jump_on("a".to_string()), down_visits(f_down_jump_on_a_visits()) ); test_apply!( test_apply_f_down_jump_on_e, - visit_jump_on("e"), + visit_jump_on("e".to_string()), down_visits(f_down_jump_on_e_visits()) ); @@ -932,18 +1081,13 @@ mod tests { Box) -> Result>>>; struct TestRewriter { - visits: Vec, f_down: TestRewriterF, f_up: TestRewriterF, } impl TestRewriter { fn new(f_down: TestRewriterF, f_up: TestRewriterF) -> Self { - Self { - visits: vec![], - f_down, - f_up, - } + Self { f_down, f_up } } } @@ -951,42 +1095,49 @@ mod tests { type Node = TestTreeNode; fn f_down(&mut self, node: Self::Node) -> Result> { - self.visits.push(format!("f_down({})", node.data)); (*self.f_down)(node) } fn f_up(&mut self, node: Self::Node) -> Result> { - self.visits.push(format!("f_up({})", node.data)); (*self.f_up)(node) } } - fn transform_continue( - node: TestTreeNode, - ) -> Result>> { - Ok(Transformed::no(node)) + fn transform_yes( + f: String, + ) -> impl FnMut(TestTreeNode) -> Result>> + { + move |node| { + Ok(Transformed::yes(TestTreeNode::new( + node.children, + format!("{}({})", f, node.data), + ))) + } } - fn transform_jump_on( - data: T, - ) -> impl FnMut(TestTreeNode) -> Result>> { + fn transform_jump_on( + f: String, + data: String, + ) -> impl FnMut(TestTreeNode) -> Result>> + { move |node| { + let new_node = + TestTreeNode::new(node.children, format!("{}({})", f, node.data)); Ok(if node.data == data { - 
Transformed::new(node, false, TreeNodeRecursion::Jump) + Transformed::new(new_node, true, TreeNodeRecursion::Jump) } else { - Transformed::no(node) + Transformed::yes(new_node) }) } } macro_rules! rewrite_test { - ($NAME:ident, $F_DOWN:expr, $F_UP:expr, $EXPECTED:expr) => { + ($NAME:ident, $F_DOWN:expr, $F_UP:expr, $EXPECTED_TREE:expr) => { #[test] fn $NAME() -> Result<()> { - let tree = new_test_tree(); + let tree = test_tree(); let mut rewriter = TestRewriter::new(Box::new($F_DOWN), Box::new($F_UP)); - tree.rewrite(&mut rewriter)?; - assert_eq!(rewriter.visits, $EXPECTED); + assert_eq!(tree.rewrite(&mut rewriter)?, $EXPECTED_TREE); Ok(()) } @@ -995,56 +1146,41 @@ mod tests { rewrite_test!( test_rewrite, - transform_continue, - transform_continue, - all_visits() + transform_yes("f_down".to_string()), + transform_yes("f_up".to_string()), + Transformed::yes(transformed_tree()) ); rewrite_test!( test_rewrite_f_down_jump_on_a, - transform_jump_on("a"), - transform_continue, - f_down_jump_on_a_visits() + transform_jump_on("f_down".to_string(), "a".to_string()), + transform_yes("f_up".to_string()), + Transformed::yes(f_down_jump_on_a_transformed_tree()) ); rewrite_test!( test_rewrite_f_down_jump_on_e, - transform_jump_on("e"), - transform_continue, - f_down_jump_on_e_visits() + transform_jump_on("f_down".to_string(), "e".to_string()), + transform_yes("f_up".to_string()), + Transformed::yes(f_down_jump_on_e_transformed_tree()) ); rewrite_test!( test_rewrite_f_up_jump_on_a, - transform_continue, - transform_jump_on("a"), - f_up_jump_on_a_visits() + transform_yes("f_down".to_string()), + transform_jump_on("f_up".to_string(), "f_down(a)".to_string()), + Transformed::yes(f_up_jump_on_a_transformed_tree()) ); rewrite_test!( test_rewrite_f_up_jump_on_e, - transform_continue, - transform_jump_on("e"), - f_up_jump_on_e_visits() + transform_yes("f_down".to_string()), + transform_jump_on("f_up".to_string(), "f_down(e)".to_string()), + 
Transformed::yes(f_up_jump_on_e_transformed_tree()) ); macro_rules! transform_test { - ($NAME:ident, $F_DOWN:expr, $F_UP:expr, $EXPECTED:expr) => { + ($NAME:ident, $F_DOWN:expr, $F_UP:expr, $EXPECTED_TREE:expr) => { #[test] fn $NAME() -> Result<()> { - let tree = new_test_tree(); - // TreeNode::transform() is not useful when there is a mutable object shared - // between `f_down` and `f_up` closures, so we need a trick to test it. - let visits = Mutex::new(vec![]); - tree.transform( - &mut |node| { - let mut mut_visits = visits.lock().unwrap(); - mut_visits.push(format!("f_down({})", node.data)); - $F_DOWN(node) - }, - &mut |node| { - let mut mut_visits = visits.lock().unwrap(); - mut_visits.push(format!("f_up({})", node.data)); - $F_UP(node) - }, - )?; - assert_eq!(visits.into_inner().unwrap(), $EXPECTED); + let tree = test_tree(); + assert_eq!(tree.transform(&mut $F_DOWN, &mut $F_UP,)?, $EXPECTED_TREE); Ok(()) } @@ -1053,46 +1189,41 @@ mod tests { transform_test!( test_transform, - transform_continue, - transform_continue, - all_visits() + transform_yes("f_down".to_string()), + transform_yes("f_up".to_string()), + Transformed::yes(transformed_tree()) ); transform_test!( test_transform_f_down_jump_on_a, - transform_jump_on("a"), - transform_continue, - f_down_jump_on_a_visits() + transform_jump_on("f_down".to_string(), "a".to_string()), + transform_yes("f_up".to_string()), + Transformed::yes(f_down_jump_on_a_transformed_tree()) ); transform_test!( test_transform_f_down_jump_on_e, - transform_jump_on("e"), - transform_continue, - f_down_jump_on_e_visits() + transform_jump_on("f_down".to_string(), "e".to_string()), + transform_yes("f_up".to_string()), + Transformed::yes(f_down_jump_on_e_transformed_tree()) ); transform_test!( test_transform_f_up_jump_on_a, - transform_continue, - transform_jump_on("a"), - f_up_jump_on_a_visits() + transform_yes("f_down".to_string()), + transform_jump_on("f_up".to_string(), "f_down(a)".to_string()), + 
Transformed::yes(f_up_jump_on_a_transformed_tree()) ); transform_test!( test_transform_f_up_jump_on_e, - transform_continue, - transform_jump_on("e"), - f_up_jump_on_e_visits() + transform_yes("f_down".to_string()), + transform_jump_on("f_up".to_string(), "f_down(e)".to_string()), + Transformed::yes(f_up_jump_on_e_transformed_tree()) ); macro_rules! transform_down_test { - ($NAME:ident, $F:expr, $EXPECTED:expr) => { + ($NAME:ident, $F:expr, $EXPECTED_TREE:expr) => { #[test] fn $NAME() -> Result<()> { - let tree = new_test_tree(); - let mut visits = vec![]; - tree.transform_down_mut(&mut |node| { - visits.push(format!("f_down({})", node.data)); - $F(node) - })?; - assert_eq!(visits, $EXPECTED); + let tree = test_tree(); + assert_eq!(tree.transform_down_mut(&mut $F)?, $EXPECTED_TREE); Ok(()) } @@ -1101,31 +1232,26 @@ mod tests { transform_down_test!( test_transform_down, - transform_continue, - down_visits(all_visits()) + transform_yes("f_down".to_string()), + Transformed::yes(transformed_down_tree()) ); transform_down_test!( test_transform_down_f_down_jump_on_a, - transform_jump_on("a"), - down_visits(f_down_jump_on_a_visits()) + transform_jump_on("f_down".to_string(), "a".to_string()), + Transformed::yes(f_down_jump_on_a_transformed_down_tree()) ); transform_down_test!( test_transform_down_f_down_jump_on_e, - transform_jump_on("e"), - down_visits(f_down_jump_on_e_visits()) + transform_jump_on("f_down".to_string(), "e".to_string()), + Transformed::yes(f_down_jump_on_e_transformed_down_tree()) ); macro_rules! 
transform_up_test { - ($NAME:ident, $F:expr, $EXPECTED:expr) => { + ($NAME:ident, $F:expr, $EXPECTED_TREE:expr) => { #[test] fn $NAME() -> Result<()> { - let tree = new_test_tree(); - let mut visits = vec![]; - tree.transform_up_mut(&mut |node| { - visits.push(format!("f_up({})", node.data)); - $F(node) - })?; - assert_eq!(visits, $EXPECTED); + let tree = test_tree(); + assert_eq!(tree.transform_up_mut(&mut $F)?, $EXPECTED_TREE); Ok(()) } @@ -1134,17 +1260,17 @@ mod tests { transform_up_test!( test_transform_up, - transform_continue, - up_visits(all_visits()) + transform_yes("f_up".to_string()), + Transformed::yes(transformed_up_tree()) ); transform_up_test!( test_transform_up_f_up_jump_on_a, - transform_jump_on("a"), - up_visits(f_up_jump_on_a_visits()) + transform_jump_on("f_up".to_string(), "a".to_string()), + Transformed::yes(f_up_jump_on_a_transformed_up_tree()) ); transform_up_test!( test_transform_up_f_up_jump_on_e, - transform_jump_on("e"), - up_visits(f_up_jump_on_e_visits()) + transform_jump_on("f_up".to_string(), "e".to_string()), + Transformed::yes(f_up_jump_on_e_transformed_up_tree()) ); } From 545d14ad01f804a8b997485095f82e384863f89e Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Mon, 19 Feb 2024 17:42:46 +0100 Subject: [PATCH 24/40] minor fix --- datafusion/common/src/tree_node.rs | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 73a99a2fa2f5..0138ee116e06 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -941,16 +941,7 @@ mod tests { } fn f_up_jump_on_e_transformed_up_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string()); - let node_d = TestTreeNode::new(vec![node_a], "f_up(d)".to_string()); - let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(c)".to_string()); - let node_e = 
TestTreeNode::new(vec![node_c], "f_up(e)".to_string()); - let node_h = TestTreeNode::new(vec![], "f_up(h)".to_string()); - let node_g = TestTreeNode::new(vec![node_h], "f_up(g)".to_string()); - let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f)".to_string()); - let node_i = TestTreeNode::new(vec![node_f], "f_up(i)".to_string()); - TestTreeNode::new(vec![node_i], "f_up(j)".to_string()) + transformed_up_tree() } fn down_visits(visits: Vec) -> Vec { From af989a0f7f31c60ca37cd08e84a17004ddf9f066 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Mon, 19 Feb 2024 20:09:33 +0100 Subject: [PATCH 25/40] simplify tests --- datafusion/common/src/tree_node.rs | 84 +++++++++++++++--------------- 1 file changed, 43 insertions(+), 41 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 0138ee116e06..033679e7daa1 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -987,11 +987,12 @@ mod tests { Ok(TreeNodeRecursion::Continue) } - fn visit_jump_on( - data: T, + fn visit_jump_on>( + data: U, ) -> impl FnMut(&TestTreeNode) -> Result { + let d = data.into(); move |node| { - Ok(if node.data == data { + Ok(if node.data == d { TreeNodeRecursion::Jump } else { TreeNodeRecursion::Continue @@ -1016,26 +1017,26 @@ mod tests { visit_test!(test_visit, visit_continue, visit_continue, all_visits()); visit_test!( test_visit_f_down_jump_on_a, - visit_jump_on("a".to_string()), + visit_jump_on("a"), visit_continue, f_down_jump_on_a_visits() ); visit_test!( test_visit_f_down_jump_on_e, - visit_jump_on("e".to_string()), + visit_jump_on("e"), visit_continue, f_down_jump_on_e_visits() ); visit_test!( test_visit_f_up_jump_on_a, visit_continue, - visit_jump_on("a".to_string()), + visit_jump_on("a"), f_up_jump_on_a_visits() ); visit_test!( test_visit_f_up_jump_on_e, visit_continue, - visit_jump_on("e".to_string()), + visit_jump_on("e"), f_up_jump_on_e_visits() ); @@ -1059,12 +1060,12 @@ mod tests { 
test_apply!(test_apply, visit_continue, down_visits(all_visits())); test_apply!( test_apply_f_down_jump_on_a, - visit_jump_on("a".to_string()), + visit_jump_on("a"), down_visits(f_down_jump_on_a_visits()) ); test_apply!( test_apply_f_down_jump_on_e, - visit_jump_on("e".to_string()), + visit_jump_on("e"), down_visits(f_down_jump_on_e_visits()) ); @@ -1094,8 +1095,8 @@ mod tests { } } - fn transform_yes( - f: String, + fn transform_yes( + f: P, ) -> impl FnMut(TestTreeNode) -> Result>> { move |node| { @@ -1106,15 +1107,16 @@ mod tests { } } - fn transform_jump_on( - f: String, - data: String, + fn transform_jump_on>( + f: P, + data: U, ) -> impl FnMut(TestTreeNode) -> Result>> { + let d = data.into(); move |node| { let new_node = TestTreeNode::new(node.children, format!("{}({})", f, node.data)); - Ok(if node.data == data { + Ok(if node.data == d { Transformed::new(new_node, true, TreeNodeRecursion::Jump) } else { Transformed::yes(new_node) @@ -1137,32 +1139,32 @@ mod tests { rewrite_test!( test_rewrite, - transform_yes("f_down".to_string()), - transform_yes("f_up".to_string()), + transform_yes("f_down"), + transform_yes("f_up"), Transformed::yes(transformed_tree()) ); rewrite_test!( test_rewrite_f_down_jump_on_a, - transform_jump_on("f_down".to_string(), "a".to_string()), - transform_yes("f_up".to_string()), + transform_jump_on("f_down", "a"), + transform_yes("f_up"), Transformed::yes(f_down_jump_on_a_transformed_tree()) ); rewrite_test!( test_rewrite_f_down_jump_on_e, - transform_jump_on("f_down".to_string(), "e".to_string()), - transform_yes("f_up".to_string()), + transform_jump_on("f_down", "e"), + transform_yes("f_up"), Transformed::yes(f_down_jump_on_e_transformed_tree()) ); rewrite_test!( test_rewrite_f_up_jump_on_a, - transform_yes("f_down".to_string()), - transform_jump_on("f_up".to_string(), "f_down(a)".to_string()), + transform_yes("f_down"), + transform_jump_on("f_up", "f_down(a)"), Transformed::yes(f_up_jump_on_a_transformed_tree()) ); rewrite_test!( 
test_rewrite_f_up_jump_on_e, - transform_yes("f_down".to_string()), - transform_jump_on("f_up".to_string(), "f_down(e)".to_string()), + transform_yes("f_down"), + transform_jump_on("f_up", "f_down(e)"), Transformed::yes(f_up_jump_on_e_transformed_tree()) ); @@ -1180,32 +1182,32 @@ mod tests { transform_test!( test_transform, - transform_yes("f_down".to_string()), - transform_yes("f_up".to_string()), + transform_yes("f_down"), + transform_yes("f_up"), Transformed::yes(transformed_tree()) ); transform_test!( test_transform_f_down_jump_on_a, - transform_jump_on("f_down".to_string(), "a".to_string()), - transform_yes("f_up".to_string()), + transform_jump_on("f_down", "a"), + transform_yes("f_up"), Transformed::yes(f_down_jump_on_a_transformed_tree()) ); transform_test!( test_transform_f_down_jump_on_e, - transform_jump_on("f_down".to_string(), "e".to_string()), - transform_yes("f_up".to_string()), + transform_jump_on("f_down", "e"), + transform_yes("f_up"), Transformed::yes(f_down_jump_on_e_transformed_tree()) ); transform_test!( test_transform_f_up_jump_on_a, - transform_yes("f_down".to_string()), - transform_jump_on("f_up".to_string(), "f_down(a)".to_string()), + transform_yes("f_down"), + transform_jump_on("f_up", "f_down(a)"), Transformed::yes(f_up_jump_on_a_transformed_tree()) ); transform_test!( test_transform_f_up_jump_on_e, - transform_yes("f_down".to_string()), - transform_jump_on("f_up".to_string(), "f_down(e)".to_string()), + transform_yes("f_down"), + transform_jump_on("f_up", "f_down(e)"), Transformed::yes(f_up_jump_on_e_transformed_tree()) ); @@ -1223,17 +1225,17 @@ mod tests { transform_down_test!( test_transform_down, - transform_yes("f_down".to_string()), + transform_yes("f_down"), Transformed::yes(transformed_down_tree()) ); transform_down_test!( test_transform_down_f_down_jump_on_a, - transform_jump_on("f_down".to_string(), "a".to_string()), + transform_jump_on("f_down", "a"), Transformed::yes(f_down_jump_on_a_transformed_down_tree()) ); 
transform_down_test!( test_transform_down_f_down_jump_on_e, - transform_jump_on("f_down".to_string(), "e".to_string()), + transform_jump_on("f_down", "e"), Transformed::yes(f_down_jump_on_e_transformed_down_tree()) ); @@ -1251,17 +1253,17 @@ mod tests { transform_up_test!( test_transform_up, - transform_yes("f_up".to_string()), + transform_yes("f_up"), Transformed::yes(transformed_up_tree()) ); transform_up_test!( test_transform_up_f_up_jump_on_a, - transform_jump_on("f_up".to_string(), "a".to_string()), + transform_jump_on("f_up", "a"), Transformed::yes(f_up_jump_on_a_transformed_up_tree()) ); transform_up_test!( test_transform_up_f_up_jump_on_e, - transform_jump_on("f_up".to_string(), "e".to_string()), + transform_jump_on("f_up", "e"), Transformed::yes(f_up_jump_on_e_transformed_up_tree()) ); } From 79db28d2d6300e6d8a880accea384f8de38ef924 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Tue, 20 Feb 2024 20:26:17 +0100 Subject: [PATCH 26/40] add stop tests, reorganize tests --- datafusion/common/src/tree_node.rs | 565 +++++++++++++++++++++++------ 1 file changed, 462 insertions(+), 103 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 033679e7daa1..cb3e91bf2686 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -371,7 +371,7 @@ pub trait TreeNodeRewriter: Sized { } /// Controls how [`TreeNode`] recursions should proceed. -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Clone, Copy)] pub enum TreeNodeRecursion { /// Continue recursion with the next node. 
Continue, @@ -673,6 +673,19 @@ mod tests { } } + // J + // | + // I + // | + // F + // / \ + // E G + // | | + // C H + // / \ + // B D + // | + // A fn test_tree() -> TestTreeNode { let node_a = TestTreeNode::new(vec![], "a".to_string()); let node_b = TestTreeNode::new(vec![], "b".to_string()); @@ -686,6 +699,9 @@ mod tests { TestTreeNode::new(vec![node_i], "j".to_string()) } + // Continue on all nodes + + // Expected visits in a combined traversal fn all_visits() -> Vec { vec![ "f_down(j)", @@ -714,6 +730,7 @@ mod tests { .collect() } + // Expected transformed tree after a combined traversal fn transformed_tree() -> TestTreeNode { let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string()); let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); @@ -729,6 +746,7 @@ mod tests { TestTreeNode::new(vec![node_i], "f_up(f_down(j))".to_string()) } + // Expected transformed tree after a top-down traversal fn transformed_down_tree() -> TestTreeNode { let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string()); let node_b = TestTreeNode::new(vec![], "f_down(b)".to_string()); @@ -742,6 +760,7 @@ mod tests { TestTreeNode::new(vec![node_i], "f_down(j)".to_string()) } + // Expected transformed tree after a bottom-up traversal fn transformed_up_tree() -> TestTreeNode { let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string()); let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string()); @@ -755,6 +774,8 @@ mod tests { TestTreeNode::new(vec![node_i], "f_up(j)".to_string()) } + // f_down Jump on A node + fn f_down_jump_on_a_visits() -> Vec { vec![ "f_down(j)", @@ -810,6 +831,8 @@ mod tests { TestTreeNode::new(vec![node_i], "f_down(j)".to_string()) } + // f_down Jump on E node + fn f_down_jump_on_e_visits() -> Vec { vec![ "f_down(j)", @@ -856,6 +879,8 @@ mod tests { TestTreeNode::new(vec![node_i], "f_down(j)".to_string()) } + // f_up Jump on A node + fn f_up_jump_on_a_visits() -> Vec { vec![ "f_down(j)", @@ -908,6 +933,8 @@ mod tests { 
TestTreeNode::new(vec![node_i], "f_up(j)".to_string()) } + // f_up Jump on E node + fn f_up_jump_on_e_visits() -> Vec { vec![ "f_down(j)", @@ -944,6 +971,182 @@ mod tests { transformed_up_tree() } + // f_down Stop on A node + + fn f_down_stop_on_a_visits() -> Vec { + vec![ + "f_down(j)", + "f_down(i)", + "f_down(f)", + "f_down(e)", + "f_down(c)", + "f_down(b)", + "f_up(b)", + "f_down(d)", + "f_down(a)", + ] + .into_iter() + .map(|s| s.to_string()) + .collect() + } + + fn f_down_stop_on_a_transformed_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string()); + let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); + let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); + TestTreeNode::new(vec![node_i], "f_down(j)".to_string()) + } + + fn f_down_stop_on_a_transformed_down_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string()); + let node_b = TestTreeNode::new(vec![], "f_down(b)".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); + let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); + TestTreeNode::new(vec![node_i], 
"f_down(j)".to_string()) + } + + // f_down Stop on E node + + fn f_down_stop_on_e_visits() -> Vec { + vec!["f_down(j)", "f_down(i)", "f_down(f)", "f_down(e)"] + .into_iter() + .map(|s| s.to_string()) + .collect() + } + + fn f_down_stop_on_e_transformed_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "a".to_string()); + let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); + let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); + TestTreeNode::new(vec![node_i], "f_down(j)".to_string()) + } + + fn f_down_stop_on_e_transformed_down_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "a".to_string()); + let node_b = TestTreeNode::new(vec![], "b".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); + let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); + TestTreeNode::new(vec![node_i], "f_down(j)".to_string()) + } + + // f_up Stop on A node + + fn f_up_stop_on_a_visits() -> Vec { + vec![ + "f_down(j)", + "f_down(i)", + "f_down(f)", + "f_down(e)", + "f_down(c)", + "f_down(b)", + "f_up(b)", + "f_down(d)", + "f_down(a)", + "f_up(a)", + ] + .into_iter() + .map(|s| s.to_string()) + .collect() + } + + fn 
f_up_stop_on_a_transformed_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string()); + let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); + let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); + TestTreeNode::new(vec![node_i], "f_down(j)".to_string()) + } + + fn f_up_stop_on_a_transformed_up_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string()); + let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "e".to_string()); + let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "i".to_string()); + TestTreeNode::new(vec![node_i], "j".to_string()) + } + + // f_up Stop on E node + + fn f_up_stop_on_e_visits() -> Vec { + vec![ + "f_down(j)", + "f_down(i)", + "f_down(f)", + "f_down(e)", + "f_down(c)", + "f_down(b)", + "f_up(b)", + "f_down(d)", + "f_down(a)", + "f_up(a)", + "f_up(d)", + "f_up(c)", + "f_up(e)", + ] + .into_iter() + .map(|s| s.to_string()) + .collect() + } + + fn f_up_stop_on_e_transformed_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "f_up(f_down(a))".to_string()); + let node_b = TestTreeNode::new(vec![], 
"f_up(f_down(b))".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_up(f_down(d))".to_string()); + let node_c = + TestTreeNode::new(vec![node_b, node_d], "f_up(f_down(c))".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string()); + let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f_down(f)".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "f_down(i)".to_string()); + TestTreeNode::new(vec![node_i], "f_down(j)".to_string()) + } + + fn f_up_stop_on_e_transformed_up_tree() -> TestTreeNode { + let node_a = TestTreeNode::new(vec![], "f_up(a)".to_string()); + let node_b = TestTreeNode::new(vec![], "f_up(b)".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_up(d)".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(c)".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_up(e)".to_string()); + let node_h = TestTreeNode::new(vec![], "h".to_string()); + let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); + let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string()); + let node_i = TestTreeNode::new(vec![node_f], "i".to_string()); + TestTreeNode::new(vec![node_i], "j".to_string()) + } + fn down_visits(visits: Vec) -> Vec { visits .into_iter() @@ -987,13 +1190,14 @@ mod tests { Ok(TreeNodeRecursion::Continue) } - fn visit_jump_on>( - data: U, + fn visit_event_on>( + data: D, + event: TreeNodeRecursion, ) -> impl FnMut(&TestTreeNode) -> Result { let d = data.into(); move |node| { Ok(if node.data == d { - TreeNodeRecursion::Jump + event } else { TreeNodeRecursion::Continue }) @@ -1014,32 +1218,6 @@ mod tests { }; } - visit_test!(test_visit, visit_continue, visit_continue, all_visits()); - visit_test!( - test_visit_f_down_jump_on_a, - visit_jump_on("a"), - visit_continue, - f_down_jump_on_a_visits() - ); - visit_test!( - 
test_visit_f_down_jump_on_e, - visit_jump_on("e"), - visit_continue, - f_down_jump_on_e_visits() - ); - visit_test!( - test_visit_f_up_jump_on_a, - visit_continue, - visit_jump_on("a"), - f_up_jump_on_a_visits() - ); - visit_test!( - test_visit_f_up_jump_on_e, - visit_continue, - visit_jump_on("e"), - f_up_jump_on_e_visits() - ); - macro_rules! test_apply { ($NAME:ident, $F:expr, $EXPECTED_VISITS:expr) => { #[test] @@ -1057,18 +1235,6 @@ mod tests { }; } - test_apply!(test_apply, visit_continue, down_visits(all_visits())); - test_apply!( - test_apply_f_down_jump_on_a, - visit_jump_on("a"), - down_visits(f_down_jump_on_a_visits()) - ); - test_apply!( - test_apply_f_down_jump_on_e, - visit_jump_on("e"), - down_visits(f_down_jump_on_e_visits()) - ); - type TestRewriterF = Box) -> Result>>>; @@ -1095,29 +1261,34 @@ mod tests { } } - fn transform_yes( - f: P, - ) -> impl FnMut(TestTreeNode) -> Result>> - { + fn transform_yes>( + transformation_name: N, + ) -> impl FnMut(TestTreeNode) -> Result>> { move |node| { Ok(Transformed::yes(TestTreeNode::new( node.children, - format!("{}({})", f, node.data), + format!("{}({})", transformation_name, node.data).into(), ))) } } - fn transform_jump_on>( - f: P, - data: U, - ) -> impl FnMut(TestTreeNode) -> Result>> - { + fn transform_and_event_on< + N: Display, + T: PartialEq + Display + From, + D: Into, + >( + transformation_name: N, + data: D, + event: TreeNodeRecursion, + ) -> impl FnMut(TestTreeNode) -> Result>> { let d = data.into(); move |node| { - let new_node = - TestTreeNode::new(node.children, format!("{}({})", f, node.data)); + let new_node = TestTreeNode::new( + node.children, + format!("{}({})", transformation_name, node.data).into(), + ); Ok(if node.data == d { - Transformed::new(new_node, true, TreeNodeRecursion::Jump) + Transformed::new(new_node, true, event) } else { Transformed::yes(new_node) }) @@ -1137,6 +1308,114 @@ mod tests { }; } + macro_rules! 
transform_test { + ($NAME:ident, $F_DOWN:expr, $F_UP:expr, $EXPECTED_TREE:expr) => { + #[test] + fn $NAME() -> Result<()> { + let tree = test_tree(); + assert_eq!(tree.transform(&mut $F_DOWN, &mut $F_UP,)?, $EXPECTED_TREE); + + Ok(()) + } + }; + } + + macro_rules! transform_down_test { + ($NAME:ident, $F:expr, $EXPECTED_TREE:expr) => { + #[test] + fn $NAME() -> Result<()> { + let tree = test_tree(); + assert_eq!(tree.transform_down_mut(&mut $F)?, $EXPECTED_TREE); + + Ok(()) + } + }; + } + + macro_rules! transform_up_test { + ($NAME:ident, $F:expr, $EXPECTED_TREE:expr) => { + #[test] + fn $NAME() -> Result<()> { + let tree = test_tree(); + assert_eq!(tree.transform_up_mut(&mut $F)?, $EXPECTED_TREE); + + Ok(()) + } + }; + } + + visit_test!(test_visit, visit_continue, visit_continue, all_visits()); + visit_test!( + test_visit_f_down_jump_on_a, + visit_event_on("a", TreeNodeRecursion::Jump), + visit_continue, + f_down_jump_on_a_visits() + ); + visit_test!( + test_visit_f_down_jump_on_e, + visit_event_on("e", TreeNodeRecursion::Jump), + visit_continue, + f_down_jump_on_e_visits() + ); + visit_test!( + test_visit_f_up_jump_on_a, + visit_continue, + visit_event_on("a", TreeNodeRecursion::Jump), + f_up_jump_on_a_visits() + ); + visit_test!( + test_visit_f_up_jump_on_e, + visit_continue, + visit_event_on("e", TreeNodeRecursion::Jump), + f_up_jump_on_e_visits() + ); + visit_test!( + test_visit_f_down_stop_on_a, + visit_event_on("a", TreeNodeRecursion::Stop), + visit_continue, + f_down_stop_on_a_visits() + ); + visit_test!( + test_visit_f_down_stop_on_e, + visit_event_on("e", TreeNodeRecursion::Stop), + visit_continue, + f_down_stop_on_e_visits() + ); + visit_test!( + test_visit_f_up_stop_on_a, + visit_continue, + visit_event_on("a", TreeNodeRecursion::Stop), + f_up_stop_on_a_visits() + ); + visit_test!( + test_visit_f_up_stop_on_e, + visit_continue, + visit_event_on("e", TreeNodeRecursion::Stop), + f_up_stop_on_e_visits() + ); + + test_apply!(test_apply, visit_continue, 
down_visits(all_visits())); + test_apply!( + test_apply_f_down_jump_on_a, + visit_event_on("a", TreeNodeRecursion::Jump), + down_visits(f_down_jump_on_a_visits()) + ); + test_apply!( + test_apply_f_down_jump_on_e, + visit_event_on("e", TreeNodeRecursion::Jump), + down_visits(f_down_jump_on_e_visits()) + ); + test_apply!( + test_apply_f_down_stop_on_a, + visit_event_on("a", TreeNodeRecursion::Stop), + down_visits(f_down_stop_on_a_visits()) + ); + test_apply!( + test_apply_f_down_stop_on_e, + visit_event_on("e", TreeNodeRecursion::Stop), + down_visits(f_down_stop_on_e_visits()) + ); + rewrite_test!( test_rewrite, transform_yes("f_down"), @@ -1145,40 +1424,68 @@ mod tests { ); rewrite_test!( test_rewrite_f_down_jump_on_a, - transform_jump_on("f_down", "a"), + transform_and_event_on("f_down", "a", TreeNodeRecursion::Jump), transform_yes("f_up"), Transformed::yes(f_down_jump_on_a_transformed_tree()) ); rewrite_test!( test_rewrite_f_down_jump_on_e, - transform_jump_on("f_down", "e"), + transform_and_event_on("f_down", "e", TreeNodeRecursion::Jump), transform_yes("f_up"), Transformed::yes(f_down_jump_on_e_transformed_tree()) ); rewrite_test!( test_rewrite_f_up_jump_on_a, transform_yes("f_down"), - transform_jump_on("f_up", "f_down(a)"), + transform_and_event_on("f_up", "f_down(a)", TreeNodeRecursion::Jump), Transformed::yes(f_up_jump_on_a_transformed_tree()) ); rewrite_test!( test_rewrite_f_up_jump_on_e, transform_yes("f_down"), - transform_jump_on("f_up", "f_down(e)"), + transform_and_event_on("f_up", "f_down(e)", TreeNodeRecursion::Jump), Transformed::yes(f_up_jump_on_e_transformed_tree()) ); - - macro_rules! 
transform_test { - ($NAME:ident, $F_DOWN:expr, $F_UP:expr, $EXPECTED_TREE:expr) => { - #[test] - fn $NAME() -> Result<()> { - let tree = test_tree(); - assert_eq!(tree.transform(&mut $F_DOWN, &mut $F_UP,)?, $EXPECTED_TREE); - - Ok(()) - } - }; - } + rewrite_test!( + test_rewrite_f_down_stop_on_a, + transform_and_event_on("f_down", "a", TreeNodeRecursion::Stop), + transform_yes("f_up"), + Transformed::new( + f_down_stop_on_a_transformed_tree(), + true, + TreeNodeRecursion::Stop + ) + ); + rewrite_test!( + test_rewrite_f_down_stop_on_e, + transform_and_event_on("f_down", "e", TreeNodeRecursion::Stop), + transform_yes("f_up"), + Transformed::new( + f_down_stop_on_e_transformed_tree(), + true, + TreeNodeRecursion::Stop + ) + ); + rewrite_test!( + test_rewrite_f_up_stop_on_a, + transform_yes("f_down"), + transform_and_event_on("f_up", "f_down(a)", TreeNodeRecursion::Stop), + Transformed::new( + f_up_stop_on_a_transformed_tree(), + true, + TreeNodeRecursion::Stop + ) + ); + rewrite_test!( + test_rewrite_f_up_stop_on_e, + transform_yes("f_down"), + transform_and_event_on("f_up", "f_down(e)", TreeNodeRecursion::Stop), + Transformed::new( + f_up_stop_on_e_transformed_tree(), + true, + TreeNodeRecursion::Stop + ) + ); transform_test!( test_transform, @@ -1188,40 +1495,68 @@ mod tests { ); transform_test!( test_transform_f_down_jump_on_a, - transform_jump_on("f_down", "a"), + transform_and_event_on("f_down", "a", TreeNodeRecursion::Jump), transform_yes("f_up"), Transformed::yes(f_down_jump_on_a_transformed_tree()) ); transform_test!( test_transform_f_down_jump_on_e, - transform_jump_on("f_down", "e"), + transform_and_event_on("f_down", "e", TreeNodeRecursion::Jump), transform_yes("f_up"), Transformed::yes(f_down_jump_on_e_transformed_tree()) ); transform_test!( test_transform_f_up_jump_on_a, transform_yes("f_down"), - transform_jump_on("f_up", "f_down(a)"), + transform_and_event_on("f_up", "f_down(a)", TreeNodeRecursion::Jump), 
Transformed::yes(f_up_jump_on_a_transformed_tree()) ); transform_test!( test_transform_f_up_jump_on_e, transform_yes("f_down"), - transform_jump_on("f_up", "f_down(e)"), + transform_and_event_on("f_up", "f_down(e)", TreeNodeRecursion::Jump), Transformed::yes(f_up_jump_on_e_transformed_tree()) ); - - macro_rules! transform_down_test { - ($NAME:ident, $F:expr, $EXPECTED_TREE:expr) => { - #[test] - fn $NAME() -> Result<()> { - let tree = test_tree(); - assert_eq!(tree.transform_down_mut(&mut $F)?, $EXPECTED_TREE); - - Ok(()) - } - }; - } + transform_test!( + test_transform_f_down_stop_on_a, + transform_and_event_on("f_down", "a", TreeNodeRecursion::Stop), + transform_yes("f_up"), + Transformed::new( + f_down_stop_on_a_transformed_tree(), + true, + TreeNodeRecursion::Stop + ) + ); + transform_test!( + test_transform_f_down_stop_on_e, + transform_and_event_on("f_down", "e", TreeNodeRecursion::Stop), + transform_yes("f_up"), + Transformed::new( + f_down_stop_on_e_transformed_tree(), + true, + TreeNodeRecursion::Stop + ) + ); + transform_test!( + test_transform_f_up_stop_on_a, + transform_yes("f_down"), + transform_and_event_on("f_up", "f_down(a)", TreeNodeRecursion::Stop), + Transformed::new( + f_up_stop_on_a_transformed_tree(), + true, + TreeNodeRecursion::Stop + ) + ); + transform_test!( + test_transform_f_up_stop_on_e, + transform_yes("f_down"), + transform_and_event_on("f_up", "f_down(e)", TreeNodeRecursion::Stop), + Transformed::new( + f_up_stop_on_e_transformed_tree(), + true, + TreeNodeRecursion::Stop + ) + ); transform_down_test!( test_transform_down, @@ -1230,26 +1565,32 @@ mod tests { ); transform_down_test!( test_transform_down_f_down_jump_on_a, - transform_jump_on("f_down", "a"), + transform_and_event_on("f_down", "a", TreeNodeRecursion::Jump), Transformed::yes(f_down_jump_on_a_transformed_down_tree()) ); transform_down_test!( test_transform_down_f_down_jump_on_e, - transform_jump_on("f_down", "e"), + transform_and_event_on("f_down", "e", 
TreeNodeRecursion::Jump), Transformed::yes(f_down_jump_on_e_transformed_down_tree()) ); - - macro_rules! transform_up_test { - ($NAME:ident, $F:expr, $EXPECTED_TREE:expr) => { - #[test] - fn $NAME() -> Result<()> { - let tree = test_tree(); - assert_eq!(tree.transform_up_mut(&mut $F)?, $EXPECTED_TREE); - - Ok(()) - } - }; - } + transform_down_test!( + test_transform_down_f_down_stop_on_a, + transform_and_event_on("f_down", "a", TreeNodeRecursion::Stop), + Transformed::new( + f_down_stop_on_a_transformed_down_tree(), + true, + TreeNodeRecursion::Stop + ) + ); + transform_down_test!( + test_transform_down_f_down_stop_on_e, + transform_and_event_on("f_down", "e", TreeNodeRecursion::Stop), + Transformed::new( + f_down_stop_on_e_transformed_down_tree(), + true, + TreeNodeRecursion::Stop + ) + ); transform_up_test!( test_transform_up, @@ -1258,12 +1599,30 @@ mod tests { ); transform_up_test!( test_transform_up_f_up_jump_on_a, - transform_jump_on("f_up", "a"), + transform_and_event_on("f_up", "a", TreeNodeRecursion::Jump), Transformed::yes(f_up_jump_on_a_transformed_up_tree()) ); transform_up_test!( test_transform_up_f_up_jump_on_e, - transform_jump_on("f_up", "e"), + transform_and_event_on("f_up", "e", TreeNodeRecursion::Jump), Transformed::yes(f_up_jump_on_e_transformed_up_tree()) ); + transform_up_test!( + test_transform_up_f_up_stop_on_a, + transform_and_event_on("f_up", "a", TreeNodeRecursion::Stop), + Transformed::new( + f_up_stop_on_a_transformed_up_tree(), + true, + TreeNodeRecursion::Stop + ) + ); + transform_up_test!( + test_transform_up_f_up_stop_on_e, + transform_and_event_on("f_up", "e", TreeNodeRecursion::Stop), + Transformed::new( + f_up_stop_on_e_transformed_up_tree(), + true, + TreeNodeRecursion::Stop + ) + ); } From e1a499ab51aea3623d1c2754f16fb70e6176d119 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Mon, 26 Feb 2024 14:24:18 +0100 Subject: [PATCH 27/40] fix previous merges and remove leftover file --- .../or_in_list_simplifier.rs | 100 
------------------ 1 file changed, 100 deletions(-) delete mode 100644 datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs diff --git a/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs deleted file mode 100644 index ff50b337e158..000000000000 --- a/datafusion/optimizer/src/simplify_expressions/or_in_list_simplifier.rs +++ /dev/null @@ -1,100 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! This module implements a rule that simplifies OR expressions into IN list expressions - -use std::borrow::Cow; -use std::collections::HashSet; - -use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; -use datafusion_common::Result; -use datafusion_expr::expr::InList; -use datafusion_expr::{BinaryExpr, Expr, Operator}; - -/// Combine multiple OR expressions into a single IN list expression if possible -/// -/// i.e. 
`a = 1 OR a = 2 OR a = 3` -> `a IN (1, 2, 3)` -pub(super) struct OrInListSimplifier {} - -impl OrInListSimplifier { - pub(super) fn new() -> Self { - Self {} - } -} - -impl TreeNodeRewriter for OrInListSimplifier { - type Node = Expr; - - fn f_up(&mut self, expr: Expr) -> Result> { - if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = &expr { - if *op == Operator::Or { - let left = as_inlist(left); - let right = as_inlist(right); - if let (Some(lhs), Some(rhs)) = (left, right) { - if lhs.expr.try_into_col().is_ok() - && rhs.expr.try_into_col().is_ok() - && lhs.expr == rhs.expr - && !lhs.negated - && !rhs.negated - { - let lhs = lhs.into_owned(); - let rhs = rhs.into_owned(); - let mut seen: HashSet = HashSet::new(); - let list = lhs - .list - .into_iter() - .chain(rhs.list) - .filter(|e| seen.insert(e.to_owned())) - .collect::>(); - - let merged_inlist = InList { - expr: lhs.expr, - list, - negated: false, - }; - return Ok(Transformed::yes(Expr::InList(merged_inlist))); - } - } - } - } - - Ok(Transformed::no(expr)) - } -} - -/// Try to convert an expression to an in-list expression -fn as_inlist(expr: &Expr) -> Option> { - match expr { - Expr::InList(inlist) => Some(Cow::Borrowed(inlist)), - Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == Operator::Eq => { - match (left.as_ref(), right.as_ref()) { - (Expr::Column(_), Expr::Literal(_)) => Some(Cow::Owned(InList { - expr: left.clone(), - list: vec![*right.clone()], - negated: false, - })), - (Expr::Literal(_), Expr::Column(_)) => Some(Cow::Owned(InList { - expr: right.clone(), - list: vec![*left.clone()], - negated: false, - })), - _ => None, - } - } - _ => None, - } -} From 271a0fdf3e20c4a4507e4c4caa75c1bab6e55a16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Berkay=20=C5=9Eahin?= <124376117+berkaysynnada@users.noreply.github.com> Date: Thu, 29 Feb 2024 14:52:28 +0300 Subject: [PATCH 28/40] Review TreeNode Refactor (#1) * Minor changes * Jump doesn't ignore f_up * update test * Update rewriter * 
LogicalPlan visit update and propagate from children flags * Update tree_node.rs * Update map_children's --------- Co-authored-by: Mustafa Akur --- datafusion/common/src/tree_node.rs | 360 ++++++++++++------ datafusion/expr/src/tree_node/expr.rs | 154 ++++---- datafusion/expr/src/tree_node/plan.rs | 42 +- .../src/analyzer/count_wildcard_rule.rs | 6 +- .../optimizer/src/common_subexpr_eliminate.rs | 14 +- 5 files changed, 347 insertions(+), 229 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index cb3e91bf2686..0f5497e4d543 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -83,12 +83,12 @@ macro_rules! handle_visit_recursion_up { /// This macro is used to determine continuation during top-down transforming traversals. /// /// After the bottom-up closure returns with [`Transformed`] depending on the returned -/// [`TreeNodeRecursion`], [`Transformed::and_then()`] decides about recursion +/// [`TreeNodeRecursion`], [`Transformed::try_transform_node_with()`] decides about recursion /// continuation and [`TreeNodeRecursion`] state propagation. #[macro_export] macro_rules! handle_transform_recursion_down { ($F_DOWN:expr, $F_SELF:expr) => { - $F_DOWN?.and_then( + $F_DOWN?.try_transform_node_with( |n| n.map_children($F_SELF), Some(TreeNodeRecursion::Continue), ) @@ -98,18 +98,18 @@ macro_rules! handle_transform_recursion_down { /// This macro is used to determine continuation during combined transforming traversals. /// /// After the bottom-up closure returns with [`Transformed`] depending on the returned -/// [`TreeNodeRecursion`], [`Transformed::and_then()`] decides about recursion +/// [`TreeNodeRecursion`], [`Transformed::try_transform_node_with()`] decides about recursion /// continuation and if [`TreeNodeRecursion`] state propagation is needed. 
/// And then after recursing into children returns with [`Transformed`] depending on the -/// returned [`TreeNodeRecursion`], [`Transformed::and_then()`] decides about recursion +/// returned [`TreeNodeRecursion`], [`Transformed::try_transform_node_with()`] decides about recursion /// continuation and [`TreeNodeRecursion`] state propagation. #[macro_export] macro_rules! handle_transform_recursion { ($F_DOWN:expr, $F_SELF:expr, $F_UP:expr) => { - $F_DOWN?.and_then( + $F_DOWN?.try_transform_node_with( |n| { n.map_children($F_SELF)? - .and_then($F_UP, Some(TreeNodeRecursion::Jump)) + .try_transform_node_with($F_UP, Some(TreeNodeRecursion::Jump)) }, Some(TreeNodeRecursion::Continue), ) @@ -119,14 +119,14 @@ macro_rules! handle_transform_recursion { /// This macro is used to determine continuation during bottom-up transforming traversals. /// /// After recursing into children returns with [`Transformed`] depending on the returned -/// [`TreeNodeRecursion`], [`Transformed::and_then()`] decides about recursion +/// [`TreeNodeRecursion`], [`Transformed::try_transform_node_with()`] decides about recursion /// continuation and [`TreeNodeRecursion`] state propagation. #[macro_export] macro_rules! handle_transform_recursion_up { ($NODE:expr, $F_SELF:expr, $F_UP:expr) => { $NODE .map_children($F_SELF)? - .and_then($F_UP, Some(TreeNodeRecursion::Jump)) + .try_transform_node_with($F_UP, Some(TreeNodeRecursion::Jump)) }; } @@ -141,20 +141,6 @@ macro_rules! handle_transform_recursion_up { /// [`LogicalPlan`]: https://docs.rs/datafusion-expr/latest/datafusion_expr/logical_plan/enum.LogicalPlan.html /// [`Expr`]: https://docs.rs/datafusion-expr/latest/datafusion_expr/expr/enum.Expr.html pub trait TreeNode: Sized { - /// Applies `f` to the node and its children. `f` is applied in a preoder way, - /// and it is controlled by [`TreeNodeRecursion`], which means result of the `f` - /// on the self node can cause an early return. 
- /// - /// The `f` closure can be used to collect some info from the - /// tree node or do some checking for the tree node. - fn apply Result>( - &self, - f: &mut F, - ) -> Result { - handle_visit_recursion_down!(f(self)?); - self.apply_children(&mut |n| n.apply(f)) - } - /// Visit the tree node using the given [TreeNodeVisitor] /// It performs a depth first walk of an node and its children. /// @@ -187,13 +173,20 @@ pub trait TreeNode: Sized { &self, visitor: &mut V, ) -> Result { - handle_visit_recursion_down!(visitor.f_down(self)?); - handle_visit_recursion_up!(self.apply_children(&mut |n| n.visit(visitor))?); - visitor.f_up(self) + match visitor.f_down(self)? { + TreeNodeRecursion::Continue => { + handle_visit_recursion_up!( + self.apply_children(&mut |n| n.visit(visitor))? + ); + visitor.f_up(self) + } + TreeNodeRecursion::Jump => visitor.f_up(self), + TreeNodeRecursion::Stop => Ok(TreeNodeRecursion::Stop), + } } - /// Transforms the tree using `f_down` while traversing the tree top-down - /// (pre-preorder) and using `f_up` while traversing the tree bottom-up (post-order). + /// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for + /// recursively transforming [`TreeNode`]s. /// /// E.g. for an tree such as: /// ```text @@ -204,27 +197,64 @@ pub trait TreeNode: Sized { /// /// The nodes are visited using the following order: /// ```text - /// f_down(ParentNode) - /// f_down(ChildNode1) - /// f_up(ChildNode1) - /// f_down(ChildNode2) - /// f_up(ChildNode2) - /// f_up(ParentNode) + /// TreeNodeRewriter::f_down(ParentNode) + /// TreeNodeRewriter::f_down(ChildNode1) + /// TreeNodeRewriter::f_up(ChildNode1) + /// TreeNodeRewriter::f_down(ChildNode2) + /// TreeNodeRewriter::f_up(ChildNode2) + /// TreeNodeRewriter::f_up(ParentNode) /// ``` /// /// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled. /// - /// If `f_down` or `f_up` returns [`Err`], recursion is stopped immediately. 
- fn transform( + /// If [`TreeNodeRewriter::f_down()`] or [`TreeNodeRewriter::f_up()`] returns [`Err`], + /// recursion is stopped immediately. + fn rewrite>( self, - f_down: &mut FD, - f_up: &mut FU, - ) -> Result> - where - FD: FnMut(Self) -> Result>, - FU: FnMut(Self) -> Result>, - { - handle_transform_recursion!(f_down(self), |c| c.transform(f_down, f_up), f_up) + rewriter: &mut R, + ) -> Result> { + let pre_visited = rewriter.f_down(self)?; + match pre_visited.tnr { + TreeNodeRecursion::Continue => { + let with_updated_children = pre_visited + .data + .map_children(|c| c.rewrite(rewriter))? + .try_transform_node_with( + |n| rewriter.f_up(n), + Some(TreeNodeRecursion::Jump), + )?; + Ok(Transformed { + transformed: with_updated_children.transformed + || pre_visited.transformed, + ..with_updated_children + }) + } + TreeNodeRecursion::Jump => { + let pre_visited_transformed = pre_visited.transformed; + let post_visited = rewriter.f_up(pre_visited.data)?; + + Ok(Transformed { + tnr: TreeNodeRecursion::Continue, + transformed: post_visited.transformed || pre_visited_transformed, + data: post_visited.data, + }) + } + TreeNodeRecursion::Stop => Ok(pre_visited), + } + } + + /// Applies `f` to the node and its children. `f` is applied in a preoder way, + /// and it is controlled by [`TreeNodeRecursion`], which means result of the `f` + /// on the self node can cause an early return. + /// + /// The `f` closure can be used to collect some info from the + /// tree node or do some checking for the tree node. 
+ fn apply Result>( + &self, + f: &mut F, + ) -> Result { + handle_visit_recursion_down!(f(self)?); + self.apply_children(&mut |n| n.apply(f)) } /// Convenience utils for writing optimizers rule: recursively apply the given 'f' to the node and all of its @@ -267,8 +297,13 @@ pub trait TreeNode: Sized { handle_transform_recursion_up!(self, |c| c.transform_up_mut(f), f) } - /// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for - /// recursively transforming [`TreeNode`]s. + /// Transforms the tree using `f_down` while traversing the tree top-down + /// (pre-preorder) and using `f_up` while traversing the tree bottom-up (post-order). + /// + /// Use this method if you want to start the `f_up` process right where `f_down` jumps. + /// This can make the whole process faster by reducing the number of `f_up` steps. + /// If you don't need this, it's just like using `transform_down_mut` followed by + /// `transform_up_mut` on the same tree. /// /// E.g. for an tree such as: /// ```text @@ -279,25 +314,91 @@ pub trait TreeNode: Sized { /// /// The nodes are visited using the following order: /// ```text - /// TreeNodeRewriter::f_down(ParentNode) - /// TreeNodeRewriter::f_down(ChildNode1) - /// TreeNodeRewriter::f_up(ChildNode1) - /// TreeNodeRewriter::f_down(ChildNode2) - /// TreeNodeRewriter::f_up(ChildNode2) - /// TreeNodeRewriter::f_up(ParentNode) + /// f_down(ParentNode) + /// f_down(ChildNode1) + /// f_up(ChildNode1) + /// f_down(ChildNode2) + /// f_up(ChildNode2) + /// f_up(ParentNode) /// ``` /// /// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled. /// - /// If [`TreeNodeRewriter::f_down()`] or [`TreeNodeRewriter::f_up()`] returns [`Err`], - /// recursion is stopped immediately. - fn rewrite>( + /// If `f_down` or `f_up` returns [`Err`], recursion is stopped immediately. 
+ /// + /// Example: + /// ```text + /// | +---+ + /// | | J | + /// | +---+ + /// | | + /// | +---+ + /// TreeNodeRecursion::Continue | | I | + /// | +---+ + /// | | + /// | +---+ + /// \|/ | F | + /// ' +---+ + /// / \ ___________________ + /// When `f_down` is +---+ \ ---+ + /// applied on node "E", | E | | G | + /// it returns with "jump". +---+ +---+ + /// | | + /// +---+ +---+ + /// | C | | H | + /// +---+ +---+ + /// / \ + /// +---+ +---+ + /// | B | | D | + /// +---+ +---+ + /// | + /// +---+ + /// | A | + /// +---+ + /// + /// Instead of starting from leaf nodes, `f_up` starts from the node "E". + /// +---+ + /// | | J | + /// | +---+ + /// | | + /// | +---+ + /// | | I | + /// | +---+ + /// | | + /// / +---+ + /// / | F | + /// / +---+ + /// / / \ ______________________ + /// | +---+ . \ ---+ + /// | | E | /|\ After `f_down` jumps | G | + /// | +---+ | on node E, `f_up` +---+ + /// \------| ---/ if applied on node E. | + /// +---+ +---+ + /// | C | | H | + /// +---+ +---+ + /// / \ + /// +---+ +---+ + /// | B | | D | + /// +---+ +---+ + /// | + /// +---+ + /// | A | + /// +---+ + /// ``` + fn transform_down_up( self, - rewriter: &mut R, - ) -> Result> { - handle_transform_recursion!(rewriter.f_down(self), |c| c.rewrite(rewriter), |n| { - rewriter.f_up(n) - }) + f_down: &mut FD, + f_up: &mut FU, + ) -> Result> + where + FD: FnMut(Self) -> Result>, + FU: FnMut(Self) -> Result>, + { + handle_transform_recursion!( + f_down(self), + |c| c.transform_down_up(f_down, f_up), + f_up + ) } /// Apply the closure `F` to the node's children @@ -341,7 +442,10 @@ pub trait TreeNodeVisitor: Sized { type Node: TreeNode; /// Invoked before any children of `node` are visited. - fn f_down(&mut self, node: &Self::Node) -> Result; + /// Default implementation returns the node unmodified and continues recursion. + fn f_down(&mut self, _node: &Self::Node) -> Result { + Ok(TreeNodeRecursion::Continue) + } /// Invoked after all children of `node` are visited. 
Default /// implementation does nothing. @@ -355,15 +459,13 @@ pub trait TreeNodeRewriter: Sized { /// The node type which is rewritable. type Node: TreeNode; - /// Invoked while traversing down the tree before any children are rewritten / - /// visited. + /// Invoked while traversing down the tree before any children are rewritten. /// Default implementation returns the node unmodified and continues recursion. fn f_down(&mut self, node: Self::Node) -> Result> { Ok(Transformed::no(node)) } - /// Invoked while traversing up the tree after all children have been rewritten / - /// visited. + /// Invoked while traversing up the tree after all children have been rewritten. /// Default implementation returns the node unmodified. fn f_up(&mut self, node: Self::Node) -> Result> { Ok(Transformed::no(node)) @@ -376,11 +478,15 @@ pub enum TreeNodeRecursion { /// Continue recursion with the next node. Continue, - /// In top-down traversals skip recursing into children but continue with the next + /// In top-down traversals, skip recursing into children but continue with the next /// node, which actually means pruning of the subtree. - /// In bottom-up traversals bypass calling bottom-up closures till the next leaf node. - /// In combined traversals bypass calling bottom-up closures till the next top-down - /// closure. + /// + /// In bottom-up traversals, bypass calling bottom-up closures till the next leaf node. + /// + /// In combined traversals, if it is "f_down" (pre-order) phase, execution "jumps" to + /// next "f_up" (post_order) phase by shortcutting its children. If it is "f_up" (pre-order) + /// phase, execution "jumps" to next "f_down" (pre_order) phase by shortcutting its parent + /// nodes until the first parent node having unvisited children path. Jump, /// Stop recursion. @@ -403,6 +509,7 @@ impl Transformed { } } + /// Wrapper for transformed data with [`TreeNodeRecursion::Continue`] statement. 
pub fn yes(data: T) -> Self { Self { data, @@ -411,6 +518,7 @@ impl Transformed { } } + /// Wrapper for non-transformed data with [`TreeNodeRecursion::Continue`] statement. pub fn no(data: T) -> Self { Self { data, @@ -419,7 +527,8 @@ impl Transformed { } } - pub fn map_data U>(self, f: F) -> Transformed { + /// Applies the given `f` to the data of [`Transformed`] object. + pub fn update_data U>(self, f: F) -> Transformed { Transformed { data: f(self.data), transformed: self.transformed, @@ -427,31 +536,30 @@ impl Transformed { } } - pub fn flat_map_data Result>( - self, - f: F, - ) -> Result> { - Ok(Transformed { - data: f(self.data)?, + /// Maps the data of [`Transformed`] object to the result of the given `f`. + pub fn map_data Result>(self, f: F) -> Result> { + f(self.data).map(|data| Transformed { + data, transformed: self.transformed, tnr: self.tnr, }) } - /// This is an important function to decide about recursion continuation and - /// [`TreeNodeRecursion`] state propagation. Handling [`TreeNodeRecursion::Continue`] - /// and [`TreeNodeRecursion::Stop`] is always straightforward, but - /// [`TreeNodeRecursion::Jump`] can behave differently when we are traversing down or - /// up on a tree. - fn and_then Result>>( + /// According to the TreeNodeRecursion condition on the node, the function decides + /// applying the given `f` to the node's data. Handling [`TreeNodeRecursion::Continue`] + /// and [`TreeNodeRecursion::Stop`] is straightforward, but [`TreeNodeRecursion::Jump`] + /// can behave differently when we are traversing down or up on a tree. If `return_if_jump` + /// is `Some`, `jump` condition on the node would stop the recursion with the given + /// [`TreeNodeRecursion`] statement. 
+ fn try_transform_node_with Result>>( self, f: F, - return_on_jump: Option, + return_if_jump: Option, ) -> Result> { match self.tnr { TreeNodeRecursion::Continue => {} TreeNodeRecursion::Jump => { - if let Some(tnr) = return_on_jump { + if let Some(tnr) = return_if_jump { return Ok(Transformed { tnr, ..self }); } } @@ -464,26 +572,26 @@ impl Transformed { }) } - pub fn and_then_transform Result>>( + /// More simple version of [`Self::try_transform_node_with`]. If [`TreeNodeRecursion`] + /// of the node is [`TreeNodeRecursion::Continue`] or [`TreeNodeRecursion::Jump`], + /// transformation is applied to the node. Otherwise, it remains as it is. + pub fn try_transform_node Result>>( self, f: F, ) -> Result> { - self.and_then(f, None) + self.try_transform_node_with(f, None) } } pub trait TransformedIterator: Iterator { - fn map_till_continue_and_collect( - self, - f: F, - ) -> Result>> + fn map_until_stop_and_collect(self, f: F) -> Result>> where F: FnMut(Self::Item) -> Result>, Self: Sized; } impl TransformedIterator for I { - fn map_till_continue_and_collect( + fn map_until_stop_and_collect( self, mut f: F, ) -> Result>> @@ -555,18 +663,18 @@ impl TreeNode for Arc { { let children = self.arc_children(); if !children.is_empty() { - let t = children.into_iter().map_till_continue_and_collect(f)?; - // TODO: Currently `assert_eq!(t.transformed, t2.transformed)` fails as - // `t.transformed` quality comes from if the transformation closures fill the - // field correctly. - // Once we trust `t.transformed` we can remove the additional check in - // `with_new_arc_children()`. - let arc_self = Arc::clone(&self); - let t2 = self.with_new_arc_children(arc_self, t.data)?; - - // Propagate up `t.transformed` and `t.tnr` along with the node containing - // transformed children. 
- Ok(Transformed::new(t2.data, t.transformed, t.tnr)) + let new_children = children.into_iter().map_until_stop_and_collect(f)?; + // Propagate up `new_children.transformed` and `new_children.tnr` + // along with the node containing transformed children. + if new_children.transformed { + let arc_self = Arc::clone(&self); + new_children.map_data(|children| { + self.with_new_arc_children(arc_self, children) + .map(|new| new.data) + }) + } else { + Ok(Transformed::no(self)) + } } else { Ok(Transformed::no(self)) } @@ -607,17 +715,18 @@ impl TreeNode for T { { let (new_self, children) = self.take_children(); if !children.is_empty() { - let t = children.into_iter().map_till_continue_and_collect(f)?; - // TODO: Currently `assert_eq!(t.transformed, t2.transformed)` fails as - // `t.transformed` quality comes from if the transformation closures fill the - // field correctly. - // Once we trust `t.transformed` we can remove the additional check in - // `with_new_children()`. - let t2 = new_self.with_new_children(t.data)?; - - // Propagate up `t.transformed` and `t.tnr` along with the node containing - // transformed children. - Ok(Transformed::new(t2.data, t.transformed, t.tnr)) + let new_children = children.into_iter().map_until_stop_and_collect(f)?; + if new_children.transformed { + // Propagate up `t.transformed` and `t.tnr` along with + // the node containing transformed children. + new_children.map_data(|children| { + new_self.with_new_children(children).map(|new| new.data) + }) + } else { + Ok(Transformed::no( + new_self.with_new_children(new_children.data)?.data, + )) + } } else { Ok(Transformed::no(new_self)) } @@ -665,8 +774,8 @@ mod tests { Ok(self .children .into_iter() - .map_till_continue_and_collect(f)? - .map_data(|new_children| Self { + .map_until_stop_and_collect(f)? 
+ .update_data(|new_children| Self { children: new_children, ..self })) @@ -700,7 +809,6 @@ mod tests { } // Continue on all nodes - // Expected visits in a combined traversal fn all_visits() -> Vec { vec![ @@ -775,7 +883,6 @@ mod tests { } // f_down Jump on A node - fn f_down_jump_on_a_visits() -> Vec { vec![ "f_down(j)", @@ -787,6 +894,7 @@ mod tests { "f_up(b)", "f_down(d)", "f_down(a)", + "f_up(a)", "f_up(d)", "f_up(c)", "f_up(e)", @@ -832,13 +940,13 @@ mod tests { } // f_down Jump on E node - fn f_down_jump_on_e_visits() -> Vec { vec![ "f_down(j)", "f_down(i)", "f_down(f)", "f_down(e)", + "f_up(e)", "f_down(g)", "f_down(h)", "f_up(h)", @@ -857,7 +965,7 @@ mod tests { let node_b = TestTreeNode::new(vec![], "b".to_string()); let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); - let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string()); let node_h = TestTreeNode::new(vec![], "f_up(f_down(h))".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string()); let node_f = @@ -880,7 +988,6 @@ mod tests { } // f_up Jump on A node - fn f_up_jump_on_a_visits() -> Vec { vec![ "f_down(j)", @@ -934,7 +1041,6 @@ mod tests { } // f_up Jump on E node - fn f_up_jump_on_e_visits() -> Vec { vec![ "f_down(j)", @@ -1017,7 +1123,6 @@ mod tests { } // f_down Stop on E node - fn f_down_stop_on_e_visits() -> Vec { vec!["f_down(j)", "f_down(i)", "f_down(f)", "f_down(e)"] .into_iter() @@ -1052,7 +1157,6 @@ mod tests { } // f_up Stop on A node - fn f_up_stop_on_a_visits() -> Vec { vec![ "f_down(j)", @@ -1098,7 +1202,6 @@ mod tests { } // f_up Stop on E node - fn f_up_stop_on_e_visits() -> Vec { vec![ "f_down(j)", @@ -1313,7 +1416,10 @@ mod tests { #[test] fn $NAME() -> Result<()> { let tree = test_tree(); - assert_eq!(tree.transform(&mut $F_DOWN, &mut $F_UP,)?, $EXPECTED_TREE); + assert_eq!( 
+ tree.transform_down_up(&mut $F_DOWN, &mut $F_UP,)?, + $EXPECTED_TREE + ); Ok(()) } @@ -1426,7 +1532,7 @@ mod tests { test_rewrite_f_down_jump_on_a, transform_and_event_on("f_down", "a", TreeNodeRecursion::Jump), transform_yes("f_up"), - Transformed::yes(f_down_jump_on_a_transformed_tree()) + Transformed::yes(transformed_tree()) ); rewrite_test!( test_rewrite_f_down_jump_on_e, @@ -1497,7 +1603,7 @@ mod tests { test_transform_f_down_jump_on_a, transform_and_event_on("f_down", "a", TreeNodeRecursion::Jump), transform_yes("f_up"), - Transformed::yes(f_down_jump_on_a_transformed_tree()) + Transformed::yes(transformed_tree()) ); transform_test!( test_transform_f_down_jump_on_e, diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 8f043030d562..a955bac2bf37 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -160,21 +160,22 @@ impl TreeNode for Expr { expr, relation, name, - }) => f(*expr)?.map_data(|e| Expr::Alias(Alias::new(e, relation, name))), + }) => f(*expr)?.update_data(|e| Expr::Alias(Alias::new(e, relation, name))), Expr::InSubquery(InSubquery { expr, subquery, negated, - }) => transform_box(expr, &mut f)? - .map_data(|be| Expr::InSubquery(InSubquery::new(be, subquery, negated))), + }) => transform_box(expr, &mut f)?.update_data(|be| { + Expr::InSubquery(InSubquery::new(be, subquery, negated)) + }), Expr::BinaryExpr(BinaryExpr { left, op, right }) => { transform_box(left, &mut f)? - .map_data(|new_left| (new_left, right)) - .and_then_transform(|(new_left, right)| { + .update_data(|new_left| (new_left, right)) + .try_transform_node(|(new_left, right)| { Ok(transform_box(right, &mut f)? - .map_data(|new_right| (new_left, new_right))) + .update_data(|new_right| (new_left, new_right))) })? 
- .map_data(|(new_left, new_right)| { + .update_data(|(new_left, new_right)| { Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right)) }) } @@ -185,12 +186,12 @@ impl TreeNode for Expr { escape_char, case_insensitive, }) => transform_box(expr, &mut f)? - .map_data(|new_expr| (new_expr, pattern)) - .and_then_transform(|(new_expr, pattern)| { + .update_data(|new_expr| (new_expr, pattern)) + .try_transform_node(|(new_expr, pattern)| { Ok(transform_box(pattern, &mut f)? - .map_data(|new_pattern| (new_expr, new_pattern))) + .update_data(|new_pattern| (new_expr, new_pattern))) })? - .map_data(|(new_expr, new_pattern)| { + .update_data(|(new_expr, new_pattern)| { Expr::Like(Like::new( negated, new_expr, @@ -206,12 +207,12 @@ impl TreeNode for Expr { escape_char, case_insensitive, }) => transform_box(expr, &mut f)? - .map_data(|new_expr| (new_expr, pattern)) - .and_then_transform(|(new_expr, pattern)| { + .update_data(|new_expr| (new_expr, pattern)) + .try_transform_node(|(new_expr, pattern)| { Ok(transform_box(pattern, &mut f)? - .map_data(|new_pattern| (new_expr, new_pattern))) + .update_data(|new_pattern| (new_expr, new_pattern))) })? 
- .map_data(|(new_expr, new_pattern)| { + .update_data(|(new_expr, new_pattern)| { Expr::SimilarTo(Like::new( negated, new_expr, @@ -220,42 +221,46 @@ impl TreeNode for Expr { case_insensitive, )) }), - Expr::Not(expr) => transform_box(expr, &mut f)?.map_data(Expr::Not), + Expr::Not(expr) => transform_box(expr, &mut f)?.update_data(Expr::Not), Expr::IsNotNull(expr) => { - transform_box(expr, &mut f)?.map_data(Expr::IsNotNull) + transform_box(expr, &mut f)?.update_data(Expr::IsNotNull) + } + Expr::IsNull(expr) => transform_box(expr, &mut f)?.update_data(Expr::IsNull), + Expr::IsTrue(expr) => transform_box(expr, &mut f)?.update_data(Expr::IsTrue), + Expr::IsFalse(expr) => { + transform_box(expr, &mut f)?.update_data(Expr::IsFalse) } - Expr::IsNull(expr) => transform_box(expr, &mut f)?.map_data(Expr::IsNull), - Expr::IsTrue(expr) => transform_box(expr, &mut f)?.map_data(Expr::IsTrue), - Expr::IsFalse(expr) => transform_box(expr, &mut f)?.map_data(Expr::IsFalse), Expr::IsUnknown(expr) => { - transform_box(expr, &mut f)?.map_data(Expr::IsUnknown) + transform_box(expr, &mut f)?.update_data(Expr::IsUnknown) } Expr::IsNotTrue(expr) => { - transform_box(expr, &mut f)?.map_data(Expr::IsNotTrue) + transform_box(expr, &mut f)?.update_data(Expr::IsNotTrue) } Expr::IsNotFalse(expr) => { - transform_box(expr, &mut f)?.map_data(Expr::IsNotFalse) + transform_box(expr, &mut f)?.update_data(Expr::IsNotFalse) } Expr::IsNotUnknown(expr) => { - transform_box(expr, &mut f)?.map_data(Expr::IsNotUnknown) + transform_box(expr, &mut f)?.update_data(Expr::IsNotUnknown) + } + Expr::Negative(expr) => { + transform_box(expr, &mut f)?.update_data(Expr::Negative) } - Expr::Negative(expr) => transform_box(expr, &mut f)?.map_data(Expr::Negative), Expr::Between(Between { expr, negated, low, high, }) => transform_box(expr, &mut f)? 
- .map_data(|new_expr| (new_expr, low, high)) - .and_then_transform(|(new_expr, low, high)| { + .update_data(|new_expr| (new_expr, low, high)) + .try_transform_node(|(new_expr, low, high)| { Ok(transform_box(low, &mut f)? - .map_data(|new_low| (new_expr, new_low, high))) + .update_data(|new_low| (new_expr, new_low, high))) })? - .and_then_transform(|(new_expr, new_low, high)| { + .try_transform_node(|(new_expr, new_low, high)| { Ok(transform_box(high, &mut f)? - .map_data(|new_high| (new_expr, new_low, new_high))) + .update_data(|new_high| (new_expr, new_low, new_high))) })? - .map_data(|(new_expr, new_low, new_high)| { + .update_data(|(new_expr, new_low, new_high)| { Expr::Between(Between::new(new_expr, negated, new_low, new_high)) }), Expr::Case(Case { @@ -263,42 +268,42 @@ impl TreeNode for Expr { when_then_expr, else_expr, }) => transform_option_box(expr, &mut f)? - .map_data(|new_expr| (new_expr, when_then_expr, else_expr)) - .and_then_transform(|(new_expr, when_then_expr, else_expr)| { + .update_data(|new_expr| (new_expr, when_then_expr, else_expr)) + .try_transform_node(|(new_expr, when_then_expr, else_expr)| { Ok(when_then_expr .into_iter() - .map_till_continue_and_collect(|(when, then)| { + .map_until_stop_and_collect(|(when, then)| { transform_box(when, &mut f)? - .map_data(|new_when| (new_when, then)) - .and_then_transform(|(new_when, then)| { + .update_data(|new_when| (new_when, then)) + .try_transform_node(|(new_when, then)| { Ok(transform_box(then, &mut f)? - .map_data(|new_then| (new_when, new_then))) + .update_data(|new_then| (new_when, new_then))) }) })? - .map_data(|new_when_then_expr| { + .update_data(|new_when_then_expr| { (new_expr, new_when_then_expr, else_expr) })) })? 
- .and_then_transform(|(new_expr, new_when_then_expr, else_expr)| { - Ok(transform_option_box(else_expr, &mut f)?.map_data( + .try_transform_node(|(new_expr, new_when_then_expr, else_expr)| { + Ok(transform_option_box(else_expr, &mut f)?.update_data( |new_else_expr| (new_expr, new_when_then_expr, new_else_expr), )) })? - .map_data(|(new_expr, new_when_then_expr, new_else_expr)| { + .update_data(|(new_expr, new_when_then_expr, new_else_expr)| { Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr)) }), Expr::Cast(Cast { expr, data_type }) => transform_box(expr, &mut f)? - .map_data(|be| Expr::Cast(Cast::new(be, data_type))), + .update_data(|be| Expr::Cast(Cast::new(be, data_type))), Expr::TryCast(TryCast { expr, data_type }) => transform_box(expr, &mut f)? - .map_data(|be| Expr::TryCast(TryCast::new(be, data_type))), + .update_data(|be| Expr::TryCast(TryCast::new(be, data_type))), Expr::Sort(Sort { expr, asc, nulls_first, }) => transform_box(expr, &mut f)? - .map_data(|be| Expr::Sort(Sort::new(be, asc, nulls_first))), + .update_data(|be| Expr::Sort(Sort::new(be, asc, nulls_first))), Expr::ScalarFunction(ScalarFunction { func_def, args }) => { - transform_vec(args, &mut f)?.flat_map_data(|new_args| match func_def { + transform_vec(args, &mut f)?.map_data(|new_args| match func_def { ScalarFunctionDefinition::BuiltIn(fun) => { Ok(Expr::ScalarFunction(ScalarFunction::new(fun, new_args))) } @@ -318,18 +323,20 @@ impl TreeNode for Expr { window_frame, null_treatment, }) => transform_vec(args, &mut f)? - .map_data(|new_args| (new_args, partition_by, order_by)) - .and_then_transform(|(new_args, partition_by, order_by)| { - Ok(transform_vec(partition_by, &mut f)?.map_data( + .update_data(|new_args| (new_args, partition_by, order_by)) + .try_transform_node(|(new_args, partition_by, order_by)| { + Ok(transform_vec(partition_by, &mut f)?.update_data( |new_partition_by| (new_args, new_partition_by, order_by), )) })? 
- .and_then_transform(|(new_args, new_partition_by, order_by)| { - Ok(transform_vec(order_by, &mut f)?.map_data(|new_order_by| { - (new_args, new_partition_by, new_order_by) - })) + .try_transform_node(|(new_args, new_partition_by, order_by)| { + Ok( + transform_vec(order_by, &mut f)?.update_data(|new_order_by| { + (new_args, new_partition_by, new_order_by) + }), + ) })? - .map_data(|(new_args, new_partition_by, new_order_by)| { + .update_data(|(new_args, new_partition_by, new_order_by)| { Expr::WindowFunction(WindowFunction::new( fun, new_args, @@ -346,16 +353,16 @@ impl TreeNode for Expr { filter, order_by, }) => transform_vec(args, &mut f)? - .map_data(|new_args| (new_args, filter, order_by)) - .and_then_transform(|(new_args, filter, order_by)| { + .update_data(|new_args| (new_args, filter, order_by)) + .try_transform_node(|(new_args, filter, order_by)| { Ok(transform_option_box(filter, &mut f)? - .map_data(|new_filter| (new_args, new_filter, order_by))) + .update_data(|new_filter| (new_args, new_filter, order_by))) })? - .and_then_transform(|(new_args, new_filter, order_by)| { + .try_transform_node(|(new_args, new_filter, order_by)| { Ok(transform_option_vec(order_by, &mut f)? - .map_data(|new_order_by| (new_args, new_filter, new_order_by))) + .update_data(|new_order_by| (new_args, new_filter, new_order_by))) })? - .flat_map_data(|(new_args, new_filter, new_order_by)| match func_def { + .map_data(|(new_args, new_filter, new_order_by)| match func_def { AggregateFunctionDefinition::BuiltIn(fun) => { Ok(Expr::AggregateFunction(AggregateFunction::new( fun, @@ -380,13 +387,13 @@ impl TreeNode for Expr { })?, Expr::GroupingSet(grouping_set) => match grouping_set { GroupingSet::Rollup(exprs) => transform_vec(exprs, &mut f)? - .map_data(|ve| Expr::GroupingSet(GroupingSet::Rollup(ve))), + .update_data(|ve| Expr::GroupingSet(GroupingSet::Rollup(ve))), GroupingSet::Cube(exprs) => transform_vec(exprs, &mut f)? 
- .map_data(|ve| Expr::GroupingSet(GroupingSet::Cube(ve))), + .update_data(|ve| Expr::GroupingSet(GroupingSet::Cube(ve))), GroupingSet::GroupingSets(lists_of_exprs) => lists_of_exprs .into_iter() - .map_till_continue_and_collect(|exprs| transform_vec(exprs, &mut f))? - .map_data(|new_lists_of_exprs| { + .map_until_stop_and_collect(|exprs| transform_vec(exprs, &mut f))? + .update_data(|new_lists_of_exprs| { Expr::GroupingSet(GroupingSet::GroupingSets(new_lists_of_exprs)) }), }, @@ -395,17 +402,18 @@ impl TreeNode for Expr { list, negated, }) => transform_box(expr, &mut f)? - .map_data(|new_expr| (new_expr, list)) - .and_then_transform(|(new_expr, list)| { + .update_data(|new_expr| (new_expr, list)) + .try_transform_node(|(new_expr, list)| { Ok(transform_vec(list, &mut f)? - .map_data(|new_list| (new_expr, new_list))) + .update_data(|new_list| (new_expr, new_list))) })? - .map_data(|(new_expr, new_list)| { + .update_data(|(new_expr, new_list)| { Expr::InList(InList::new(new_expr, new_list, negated)) }), Expr::GetIndexedField(GetIndexedField { expr, field }) => { - transform_box(expr, &mut f)? 
- .map_data(|be| Expr::GetIndexedField(GetIndexedField::new(be, field))) + transform_box(expr, &mut f)?.update_data(|be| { + Expr::GetIndexedField(GetIndexedField::new(be, field)) + }) } }) } @@ -415,7 +423,7 @@ fn transform_box(be: Box, f: &mut F) -> Result>> where F: FnMut(Expr) -> Result>, { - Ok(f(*be)?.map_data(Box::new)) + Ok(f(*be)?.update_data(Box::new)) } fn transform_option_box( @@ -426,7 +434,7 @@ where F: FnMut(Expr) -> Result>, { obe.map_or(Ok(Transformed::no(None)), |be| { - Ok(transform_box(be, f)?.map_data(Some)) + Ok(transform_box(be, f)?.update_data(Some)) }) } @@ -439,7 +447,7 @@ where F: FnMut(Expr) -> Result>, { ove.map_or(Ok(Transformed::no(None)), |ve| { - Ok(transform_vec(ve, f)?.map_data(Some)) + Ok(transform_vec(ve, f)?.update_data(Some)) }) } @@ -448,5 +456,5 @@ fn transform_vec(ve: Vec, f: &mut F) -> Result>> where F: FnMut(Expr) -> Result>, { - ve.into_iter().map_till_continue_and_collect(f) + ve.into_iter().map_until_stop_and_collect(f) } diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs index e167342de93e..6b2b9d055c81 100644 --- a/datafusion/expr/src/tree_node/plan.rs +++ b/datafusion/expr/src/tree_node/plan.rs @@ -62,10 +62,20 @@ impl TreeNode for LogicalPlan { ) -> Result { // Compared to the default implementation, we need to invoke // [`Self::visit_subqueries`] before visiting its children - handle_visit_recursion_down!(visitor.f_down(self)?); - self.visit_subqueries(visitor)?; - handle_visit_recursion_up!(self.apply_children(&mut |n| n.visit(visitor))?); - visitor.f_up(self) + match visitor.f_down(self)? { + TreeNodeRecursion::Continue => { + self.visit_subqueries(visitor)?; + handle_visit_recursion_up!( + self.apply_children(&mut |n| n.visit(visitor))? 
+ ); + visitor.f_up(self) + } + TreeNodeRecursion::Jump => { + self.visit_subqueries(visitor)?; + visitor.f_up(self) + } + TreeNodeRecursion::Stop => Ok(TreeNodeRecursion::Stop), + } } fn apply_children Result>( @@ -85,28 +95,18 @@ impl TreeNode for LogicalPlan { F: FnMut(Self) -> Result>, { let old_children = self.inputs(); - let t = old_children + let new_children = old_children .iter() .map(|&c| c.clone()) - .map_till_continue_and_collect(f)?; - // TODO: Currently `assert_eq!(t.transformed, t2)` fails as - // `t.transformed` quality comes from if the transformation closures fill the - // field correctly. - // Once we trust `t.transformed` we can remove the additional check in - // `t2`. - let t2 = old_children - .into_iter() - .zip(t.data.iter()) - .any(|(c1, c2)| c1 != c2); - - // Propagate up `t.transformed` and `t.tnr` along with the node containing - // transformed children. - if t2 { - t.flat_map_data(|new_children| { + .map_until_stop_and_collect(f)?; + // Propagate up `new_children.transformed` and `new_children.tnr` + // along with the node containing transformed children. + if new_children.transformed { + new_children.map_data(|new_children| { self.with_new_exprs(self.expressions(), new_children) }) } else { - Ok(t.map_data(|_| self)) + Ok(new_children.update_data(|_| self)) } } } diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index b066f35b828c..15eb1035f240 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -174,7 +174,7 @@ impl TreeNodeRewriter for CountWildcardRewriter { .as_ref() .clone() .transform_down(&analyze_internal)? - .map_data(|new_plan| { + .update_data(|new_plan| { ScalarSubquery(Subquery { subquery: Arc::new(new_plan), outer_ref_columns, @@ -189,7 +189,7 @@ impl TreeNodeRewriter for CountWildcardRewriter { .as_ref() .clone() .transform_down(&analyze_internal)? 
- .map_data(|new_plan| { + .update_data(|new_plan| { Expr::InSubquery(InSubquery::new( expr, Subquery { @@ -204,7 +204,7 @@ impl TreeNodeRewriter for CountWildcardRewriter { .as_ref() .clone() .transform_down(&analyze_internal)? - .map_data(|new_plan| { + .update_data(|new_plan| { Expr::Exists(expr::Exists { subquery: Subquery { subquery: Arc::new(new_plan), diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index df417ccc3f1f..b21a56d2aee9 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -642,21 +642,20 @@ impl ExprIdentifierVisitor<'_> { /// Find the first `EnterMark` in the stack, and accumulates every `ExprItem` /// before it. - fn pop_enter_mark(&mut self) -> (usize, Identifier) { + fn pop_enter_mark(&mut self) -> Option<(usize, Identifier)> { let mut desc = String::new(); while let Some(item) = self.visit_stack.pop() { match item { VisitRecord::EnterMark(idx) => { - return (idx, desc); + return Some((idx, desc)); } VisitRecord::ExprItem(s) => { desc.push_str(&s); } } } - - unreachable!("Enter mark should paired with node number"); + None } } @@ -680,7 +679,12 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { fn f_up(&mut self, expr: &Expr) -> Result { self.series_number += 1; - let (idx, sub_expr_desc) = self.pop_enter_mark(); + let (idx, sub_expr_desc) = + if let Some((idx, sub_expr_desc)) = self.pop_enter_mark() { + (idx, sub_expr_desc) + } else { + return Ok(TreeNodeRecursion::Continue); + }; // skip exprs should not be recognize. 
if self.expr_mask.ignores(expr) { self.id_array[idx].0 = self.series_number; From 6f8f9404347da408531710bd618449dbfeec933e Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 29 Feb 2024 14:12:24 +0100 Subject: [PATCH 29/40] fix --- datafusion/common/src/tree_node.rs | 80 +++++++++++------------------- 1 file changed, 28 insertions(+), 52 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 0f5497e4d543..3851658db76d 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -105,15 +105,31 @@ macro_rules! handle_transform_recursion_down { /// continuation and [`TreeNodeRecursion`] state propagation. #[macro_export] macro_rules! handle_transform_recursion { - ($F_DOWN:expr, $F_SELF:expr, $F_UP:expr) => { - $F_DOWN?.try_transform_node_with( - |n| { - n.map_children($F_SELF)? - .try_transform_node_with($F_UP, Some(TreeNodeRecursion::Jump)) - }, - Some(TreeNodeRecursion::Continue), - ) - }; + ($F_DOWN:expr, $F_SELF:expr, $F_UP:expr) => {{ + let pre_visited = $F_DOWN?; + match pre_visited.tnr { + TreeNodeRecursion::Continue => { + let with_updated_children = pre_visited + .data + .map_children($F_SELF)? + .try_transform_node_with($F_UP, Some(TreeNodeRecursion::Jump))?; + Ok(Transformed { + transformed: with_updated_children.transformed + || pre_visited.transformed, + ..with_updated_children + }) + } + TreeNodeRecursion::Jump => { + let post_visited = $F_UP(pre_visited.data)?; + Ok(Transformed::new( + post_visited.data, + post_visited.transformed || pre_visited.transformed, + TreeNodeRecursion::Continue, + )) + } + TreeNodeRecursion::Stop => Ok(pre_visited), + } + }}; } /// This macro is used to determine continuation during bottom-up transforming traversals. 
@@ -213,34 +229,9 @@ pub trait TreeNode: Sized { self, rewriter: &mut R, ) -> Result> { - let pre_visited = rewriter.f_down(self)?; - match pre_visited.tnr { - TreeNodeRecursion::Continue => { - let with_updated_children = pre_visited - .data - .map_children(|c| c.rewrite(rewriter))? - .try_transform_node_with( - |n| rewriter.f_up(n), - Some(TreeNodeRecursion::Jump), - )?; - Ok(Transformed { - transformed: with_updated_children.transformed - || pre_visited.transformed, - ..with_updated_children - }) - } - TreeNodeRecursion::Jump => { - let pre_visited_transformed = pre_visited.transformed; - let post_visited = rewriter.f_up(pre_visited.data)?; - - Ok(Transformed { - tnr: TreeNodeRecursion::Continue, - transformed: post_visited.transformed || pre_visited_transformed, - data: post_visited.data, - }) - } - TreeNodeRecursion::Stop => Ok(pre_visited), - } + handle_transform_recursion!(rewriter.f_down(self), |c| c.rewrite(rewriter), |n| { + rewriter.f_up(n) + }) } /// Applies `f` to the node and its children. 
`f` is applied in a preoder way, @@ -911,21 +902,6 @@ mod tests { .collect() } - fn f_down_jump_on_a_transformed_tree() -> TestTreeNode { - let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string()); - let node_b = TestTreeNode::new(vec![], "f_up(f_down(b))".to_string()); - let node_d = TestTreeNode::new(vec![node_a], "f_up(f_down(d))".to_string()); - let node_c = - TestTreeNode::new(vec![node_b, node_d], "f_up(f_down(c))".to_string()); - let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string()); - let node_h = TestTreeNode::new(vec![], "f_up(f_down(h))".to_string()); - let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string()); - let node_f = - TestTreeNode::new(vec![node_e, node_g], "f_up(f_down(f))".to_string()); - let node_i = TestTreeNode::new(vec![node_f], "f_up(f_down(i))".to_string()); - TestTreeNode::new(vec![node_i], "f_up(f_down(j))".to_string()) - } - fn f_down_jump_on_a_transformed_down_tree() -> TestTreeNode { let node_a = TestTreeNode::new(vec![], "f_down(a)".to_string()); let node_b = TestTreeNode::new(vec![], "f_down(b)".to_string()); From 61272eab6d263b0ea7f9f78af6ab516e4e63a2e4 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 29 Feb 2024 14:50:13 +0100 Subject: [PATCH 30/40] minor fixes --- datafusion/common/src/tree_node.rs | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 3851658db76d..825f42e32a6e 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -520,20 +520,12 @@ impl Transformed { /// Applies the given `f` to the data of [`Transformed`] object. pub fn update_data U>(self, f: F) -> Transformed { - Transformed { - data: f(self.data), - transformed: self.transformed, - tnr: self.tnr, - } + Transformed::new(f(self.data), self.transformed, self.tnr) } /// Maps the data of [`Transformed`] object to the result of the given `f`. 
pub fn map_data Result>(self, f: F) -> Result> { - f(self.data).map(|data| Transformed { - data, - transformed: self.transformed, - tnr: self.tnr, - }) + f(self.data).map(|data| Transformed::new(data, self.transformed, self.tnr)) } /// According to the TreeNodeRecursion condition on the node, the function decides @@ -608,11 +600,7 @@ impl TransformedIterator for I { }) }) .collect::>>()?; - Ok(Transformed { - data: new_data, - transformed: new_transformed, - tnr: new_tnr, - }) + Ok(Transformed::new(new_data, new_transformed, new_tnr)) } } From 54a1b469773f0cba48f0cdcce33cec3ee7c71ea1 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 29 Feb 2024 14:53:42 +0100 Subject: [PATCH 31/40] fix f_up call when f_down returns jump --- datafusion/common/src/tree_node.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 825f42e32a6e..8cfcea8a41d6 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -121,11 +121,10 @@ macro_rules! handle_transform_recursion { } TreeNodeRecursion::Jump => { let post_visited = $F_UP(pre_visited.data)?; - Ok(Transformed::new( - post_visited.data, - post_visited.transformed || pre_visited.transformed, - TreeNodeRecursion::Continue, - )) + Ok(Transformed { + transformed: post_visited.transformed || pre_visited.transformed, + ..post_visited + }) } TreeNodeRecursion::Stop => Ok(pre_visited), } From f7800073fe5b738e0ca8cb3749d7c98528d19ee4 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 29 Feb 2024 17:27:51 +0100 Subject: [PATCH 32/40] simplify code --- datafusion/common/src/tree_node.rs | 88 +++++++++++++++--------------- 1 file changed, 45 insertions(+), 43 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 8cfcea8a41d6..70e2b7afa55b 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -90,7 +90,7 @@ macro_rules! 
handle_transform_recursion_down { ($F_DOWN:expr, $F_SELF:expr) => { $F_DOWN?.try_transform_node_with( |n| n.map_children($F_SELF), - Some(TreeNodeRecursion::Continue), + TreeNodeRecursion::Continue, ) }; } @@ -108,26 +108,18 @@ macro_rules! handle_transform_recursion { ($F_DOWN:expr, $F_SELF:expr, $F_UP:expr) => {{ let pre_visited = $F_DOWN?; match pre_visited.tnr { - TreeNodeRecursion::Continue => { - let with_updated_children = pre_visited - .data - .map_children($F_SELF)? - .try_transform_node_with($F_UP, Some(TreeNodeRecursion::Jump))?; - Ok(Transformed { - transformed: with_updated_children.transformed - || pre_visited.transformed, - ..with_updated_children - }) + TreeNodeRecursion::Continue => pre_visited + .data + .map_children($F_SELF)? + .try_transform_node_with($F_UP, TreeNodeRecursion::Jump), + TreeNodeRecursion::Jump => + { + #[allow(clippy::redundant_closure_call)] + $F_UP(pre_visited.data) } - TreeNodeRecursion::Jump => { - let post_visited = $F_UP(pre_visited.data)?; - Ok(Transformed { - transformed: post_visited.transformed || pre_visited.transformed, - ..post_visited - }) - } - TreeNodeRecursion::Stop => Ok(pre_visited), + TreeNodeRecursion::Stop => return Ok(pre_visited), } + .map(|post_visited| post_visited.update_transformed(pre_visited.transformed)) }}; } @@ -141,7 +133,7 @@ macro_rules! handle_transform_recursion_up { ($NODE:expr, $F_SELF:expr, $F_UP:expr) => { $NODE .map_children($F_SELF)? - .try_transform_node_with($F_UP, Some(TreeNodeRecursion::Jump)) + .try_transform_node_with($F_UP, TreeNodeRecursion::Jump) }; } @@ -522,46 +514,56 @@ impl Transformed { Transformed::new(f(self.data), self.transformed, self.tnr) } + /// Updates the transformed state based on the current and the new state. + pub fn update_transformed(self, transformed: bool) -> Self { + Self { + transformed: self.transformed || transformed, + ..self + } + } + + /// Sets a new [`TreeNodeRecursion`]. 
+ pub fn update_tnr(self, tnr: TreeNodeRecursion) -> Self { + Self { tnr, ..self } + } + /// Maps the data of [`Transformed`] object to the result of the given `f`. pub fn map_data Result>(self, f: F) -> Result> { f(self.data).map(|data| Transformed::new(data, self.transformed, self.tnr)) } - /// According to the TreeNodeRecursion condition on the node, the function decides - /// applying the given `f` to the node's data. Handling [`TreeNodeRecursion::Continue`] - /// and [`TreeNodeRecursion::Stop`] is straightforward, but [`TreeNodeRecursion::Jump`] - /// can behave differently when we are traversing down or up on a tree. If `return_if_jump` - /// is `Some`, `jump` condition on the node would stop the recursion with the given - /// [`TreeNodeRecursion`] statement. + /// Handling [`TreeNodeRecursion::Continue`] and [`TreeNodeRecursion::Stop`] is + /// straightforward, but [`TreeNodeRecursion::Jump`] can behave differently when we + /// are traversing down or up on a tree. + /// If [`TreeNodeRecursion`] of the node is [`TreeNodeRecursion::Jump`] recursion is + /// stopped with the given `return_if_jump` [`TreeNodeRecursion`] statement. fn try_transform_node_with Result>>( self, f: F, - return_if_jump: Option, + return_if_jump: TreeNodeRecursion, ) -> Result> { match self.tnr { - TreeNodeRecursion::Continue => {} - TreeNodeRecursion::Jump => { - if let Some(tnr) = return_if_jump { - return Ok(Transformed { tnr, ..self }); - } + TreeNodeRecursion::Continue => { + f(self.data).map(|t| t.update_transformed(self.transformed)) } - TreeNodeRecursion::Stop => return Ok(self), - }; - let t = f(self.data)?; - Ok(Transformed { - transformed: t.transformed || self.transformed, - ..t - }) + TreeNodeRecursion::Jump => Ok(self.update_tnr(return_if_jump)), + TreeNodeRecursion::Stop => Ok(self), + } } - /// More simple version of [`Self::try_transform_node_with`]. 
If [`TreeNodeRecursion`] - /// of the node is [`TreeNodeRecursion::Continue`] or [`TreeNodeRecursion::Jump`], - /// transformation is applied to the node. Otherwise, it remains as it is. + /// If [`TreeNodeRecursion`] of the node is [`TreeNodeRecursion::Continue`] or + /// [`TreeNodeRecursion::Jump`], transformation is applied to the node. Otherwise, it + /// remains as it is. pub fn try_transform_node Result>>( self, f: F, ) -> Result> { - self.try_transform_node_with(f, None) + match self.tnr { + TreeNodeRecursion::Continue => {} + TreeNodeRecursion::Jump => {} + TreeNodeRecursion::Stop => return Ok(self), + }; + f(self.data).map(|t| t.update_transformed(self.transformed)) } } From e75f2e350312fed227e2193274312eaa2d2b68e4 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 29 Feb 2024 17:28:15 +0100 Subject: [PATCH 33/40] minor fix --- datafusion/common/src/tree_node.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 70e2b7afa55b..5f27aa22f036 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -570,8 +570,7 @@ impl Transformed { pub trait TransformedIterator: Iterator { fn map_until_stop_and_collect(self, f: F) -> Result>> where - F: FnMut(Self::Item) -> Result>, - Self: Sized; + F: FnMut(Self::Item) -> Result>; } impl TransformedIterator for I { From 60e76ac83a829224a9d9b3d0f3a346542241ad2e Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 29 Feb 2024 18:25:24 +0100 Subject: [PATCH 34/40] revert unnecessary changes --- datafusion/common/src/tree_node.rs | 20 ++++++++----------- .../enforce_distribution.rs | 2 +- .../src/physical_optimizer/enforce_sorting.rs | 4 ++-- .../replace_with_order_preserving_variants.rs | 4 ++-- datafusion/physical-expr/src/physical_expr.rs | 7 +++---- datafusion/physical-expr/src/tree_node.rs | 12 +++++------ datafusion/physical-plan/src/empty.rs | 2 +- datafusion/physical-plan/src/lib.rs | 7 
+++---- .../physical-plan/src/placeholder_row.rs | 3 +-- datafusion/physical-plan/src/tree_node.rs | 12 +++++------ 10 files changed, 33 insertions(+), 40 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 5f27aa22f036..3c93c6f0ec58 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -617,7 +617,7 @@ pub trait DynTreeNode { &self, arc_self: Arc, new_children: Vec>, - ) -> Result>>; + ) -> Result>; } /// Blanket implementation for Arc for any tye that implements @@ -647,10 +647,8 @@ impl TreeNode for Arc { // along with the node containing transformed children. if new_children.transformed { let arc_self = Arc::clone(&self); - new_children.map_data(|children| { - self.with_new_arc_children(arc_self, children) - .map(|new| new.data) - }) + new_children + .map_data(|children| self.with_new_arc_children(arc_self, children)) } else { Ok(Transformed::no(self)) } @@ -671,7 +669,7 @@ pub trait ConcreteTreeNode: Sized { fn take_children(self) -> (Self, Vec); /// Reattaches updated child nodes to the node, returning the updated node. - fn with_new_children(self, children: Vec) -> Result>; + fn with_new_children(self, children: Vec) -> Result; } impl TreeNode for T { @@ -698,13 +696,11 @@ impl TreeNode for T { if new_children.transformed { // Propagate up `t.transformed` and `t.tnr` along with // the node containing transformed children. 
- new_children.map_data(|children| { - new_self.with_new_children(children).map(|new| new.data) - }) + new_children.map_data(|children| new_self.with_new_children(children)) } else { - Ok(Transformed::no( - new_self.with_new_children(new_children.data)?.data, - )) + new_self + .with_new_children(new_children.data) + .map(Transformed::no) } } else { Ok(Transformed::no(new_self)) diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index d16dcc7eebdb..2ac817db170c 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -1022,7 +1022,7 @@ fn replace_order_preserving_variants( } } - context.update_plan_from_children().map(|t| t.data) + context.update_plan_from_children() } /// This utility function adds a [`SortExec`] above an operator according to the diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index 185c74298588..bdc2659dbe3d 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -115,7 +115,7 @@ fn update_sort_ctx_children( } node.data = data; - node.update_plan_from_children().map(|t| t.data) + node.update_plan_from_children() } /// This object is used within the [`EnforceSorting`] rule to track the closest @@ -519,7 +519,7 @@ fn remove_corresponding_coalesce_in_sub_plan( .collect::>()?; } - requirements.update_plan_from_children().map(|t| t.data) + requirements.update_plan_from_children() } /// Updates child to remove the unnecessary sort below it. 
diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index 04031e814911..6320c69df7e5 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -138,7 +138,7 @@ fn plan_with_order_preserving_variants( } } - sort_input.update_plan_from_children().map(|t| t.data) + sort_input.update_plan_from_children() } /// Calculates the updated plan by replacing operators that preserve ordering @@ -184,7 +184,7 @@ fn plan_with_order_breaking_variants( let coalesce = CoalescePartitionsExec::new(child); sort_input.plan = Arc::new(coalesce) as _; } else { - return sort_input.update_plan_from_children().map(|t| t.data); + return sort_input.update_plan_from_children(); } sort_input.children[0].data = false; diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index c1b36b3e2a5d..39b8de81af56 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -32,7 +32,6 @@ use datafusion_common::{internal_err, not_impl_err, Result}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::ColumnarValue; -use datafusion_common::tree_node::Transformed; use itertools::izip; /// `PhysicalExpr` evaluate DataFusion expressions such as `A + 1`, or `CAST(c1 @@ -243,7 +242,7 @@ pub type PhysicalExprRef = Arc; pub fn with_new_children_if_necessary( expr: Arc, children: Vec>, -) -> Result>> { +) -> Result> { let old_children = expr.children(); if children.len() != old_children.len() { internal_err!("PhysicalExpr: Wrong number of children") @@ -253,9 +252,9 @@ pub fn with_new_children_if_necessary( .zip(old_children.iter()) .any(|(c1, c2)| !Arc::data_ptr_eq(c1, c2)) { - Ok(Transformed::yes(expr.with_new_children(children)?)) + 
Ok(expr.with_new_children(children)?) } else { - Ok(Transformed::no(expr)) + Ok(expr) } } diff --git a/datafusion/physical-expr/src/tree_node.rs b/datafusion/physical-expr/src/tree_node.rs index 68a5fc06e8ee..0e2aa7d63679 100644 --- a/datafusion/physical-expr/src/tree_node.rs +++ b/datafusion/physical-expr/src/tree_node.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use crate::physical_expr::{with_new_children_if_necessary, PhysicalExpr}; -use datafusion_common::tree_node::{ConcreteTreeNode, DynTreeNode, Transformed}; +use datafusion_common::tree_node::{ConcreteTreeNode, DynTreeNode}; use datafusion_common::Result; impl DynTreeNode for dyn PhysicalExpr { @@ -34,7 +34,7 @@ impl DynTreeNode for dyn PhysicalExpr { &self, arc_self: Arc, new_children: Vec>, - ) -> Result>> { + ) -> Result> { with_new_children_if_necessary(arc_self, new_children) } } @@ -61,11 +61,11 @@ impl ExprContext { } } - pub fn update_expr_from_children(mut self) -> Result> { + pub fn update_expr_from_children(mut self) -> Result { let children_expr = self.children.iter().map(|c| c.expr.clone()).collect(); let t = with_new_children_if_necessary(self.expr, children_expr)?; - self.expr = t.data; - Ok(Transformed::new(self, t.transformed, t.tnr)) + self.expr = t; + Ok(self) } } @@ -94,7 +94,7 @@ impl ConcreteTreeNode for ExprContext { (self, children) } - fn with_new_children(mut self, children: Vec) -> Result> { + fn with_new_children(mut self, children: Vec) -> Result { self.children = children; self.update_expr_from_children() } diff --git a/datafusion/physical-plan/src/empty.rs b/datafusion/physical-plan/src/empty.rs index 699e8dfc7776..4ff79cdaae70 100644 --- a/datafusion/physical-plan/src/empty.rs +++ b/datafusion/physical-plan/src/empty.rs @@ -182,7 +182,7 @@ mod tests { let schema = test::aggr_test_schema(); let empty = Arc::new(EmptyExec::new(schema.clone())); - let empty2 = with_new_children_if_necessary(empty.clone(), vec![])?.data; + let empty2 = with_new_children_if_necessary(empty.clone(), 
vec![])?; assert_eq!(empty.schema(), empty2.schema()); let too_many_kids = vec![empty2]; diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index cba90869aacd..b588d31794ff 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -30,7 +30,6 @@ use crate::sorts::sort_preserving_merge::SortPreservingMergeExec; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::Transformed; use datafusion_common::utils::DataPtr; use datafusion_common::Result; use datafusion_execution::TaskContext; @@ -629,7 +628,7 @@ pub fn need_data_exchange(plan: Arc) -> bool { pub fn with_new_children_if_necessary( plan: Arc, children: Vec>, -) -> Result>> { +) -> Result> { let old_children = plan.children(); if children.len() != old_children.len() { internal_err!("Wrong number of children") @@ -639,9 +638,9 @@ pub fn with_new_children_if_necessary( .zip(old_children.iter()) .any(|(c1, c2)| !Arc::data_ptr_eq(c1, c2)) { - Ok(Transformed::yes(plan.with_new_children(children)?)) + Ok(plan.with_new_children(children)?) 
} else { - Ok(Transformed::no(plan)) + Ok(plan) } } diff --git a/datafusion/physical-plan/src/placeholder_row.rs b/datafusion/physical-plan/src/placeholder_row.rs index 76e1afa614a9..3880cf3d77af 100644 --- a/datafusion/physical-plan/src/placeholder_row.rs +++ b/datafusion/physical-plan/src/placeholder_row.rs @@ -184,8 +184,7 @@ mod tests { let placeholder = Arc::new(PlaceholderRowExec::new(schema)); - let placeholder_2 = - with_new_children_if_necessary(placeholder.clone(), vec![])?.data; + let placeholder_2 = with_new_children_if_necessary(placeholder.clone(), vec![])?; assert_eq!(placeholder.schema(), placeholder_2.schema()); let too_many_kids = vec![placeholder_2]; diff --git a/datafusion/physical-plan/src/tree_node.rs b/datafusion/physical-plan/src/tree_node.rs index a3099b0ac934..6fd88160468c 100644 --- a/datafusion/physical-plan/src/tree_node.rs +++ b/datafusion/physical-plan/src/tree_node.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use crate::{displayable, with_new_children_if_necessary, ExecutionPlan}; -use datafusion_common::tree_node::{ConcreteTreeNode, DynTreeNode, Transformed}; +use datafusion_common::tree_node::{ConcreteTreeNode, DynTreeNode}; use datafusion_common::Result; impl DynTreeNode for dyn ExecutionPlan { @@ -34,7 +34,7 @@ impl DynTreeNode for dyn ExecutionPlan { &self, arc_self: Arc, new_children: Vec>, - ) -> Result>> { + ) -> Result> { with_new_children_if_necessary(arc_self, new_children) } } @@ -61,11 +61,11 @@ impl PlanContext { } } - pub fn update_plan_from_children(mut self) -> Result> { + pub fn update_plan_from_children(mut self) -> Result { let children_plans = self.children.iter().map(|c| c.plan.clone()).collect(); let t = with_new_children_if_necessary(self.plan, children_plans)?; - self.plan = t.data; - Ok(Transformed::new(self, t.transformed, t.tnr)) + self.plan = t; + Ok(self) } } @@ -95,7 +95,7 @@ impl ConcreteTreeNode for PlanContext { (self, children) } - fn with_new_children(mut self, children: Vec) -> Result> { + fn 
with_new_children(mut self, children: Vec) -> Result { self.children = children; self.update_plan_from_children() } From 3af92ce57969a5b3fe5438ba798c959fd9af41a4 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Thu, 29 Feb 2024 18:34:25 +0100 Subject: [PATCH 35/40] fix `DynTreeNode` and `ConcreteTreeNode` `transformed` and `tnr` propagation --- datafusion/common/src/tree_node.rs | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 3c93c6f0ec58..4e1d5fca68ba 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -647,10 +647,11 @@ impl TreeNode for Arc { // along with the node containing transformed children. if new_children.transformed { let arc_self = Arc::clone(&self); - new_children - .map_data(|children| self.with_new_arc_children(arc_self, children)) + new_children.map_data(|new_children| { + self.with_new_arc_children(arc_self, new_children) + }) } else { - Ok(Transformed::no(self)) + Ok(Transformed::new(self, false, new_children.tnr)) } } else { Ok(Transformed::no(self)) @@ -693,15 +694,9 @@ impl TreeNode for T { let (new_self, children) = self.take_children(); if !children.is_empty() { let new_children = children.into_iter().map_until_stop_and_collect(f)?; - if new_children.transformed { - // Propagate up `t.transformed` and `t.tnr` along with - // the node containing transformed children. - new_children.map_data(|children| new_self.with_new_children(children)) - } else { - new_self - .with_new_children(new_children.data) - .map(Transformed::no) - } + // Propagate up `new_children.transformed` and `new_children.tnr` along with + // the node containing transformed children. 
+ new_children.map_data(|new_children| new_self.with_new_children(new_children)) } else { Ok(Transformed::no(new_self)) } From 6b4b6dd7e2b376320a56734f4a19f1944ebf2693 Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Fri, 1 Mar 2024 10:15:10 +0100 Subject: [PATCH 36/40] introduce TransformedResult helper --- datafusion-examples/examples/rewrite_expr.rs | 8 +++---- datafusion/common/src/tree_node.rs | 22 +++++++++++++++++++ .../aggregate_statistics.rs | 6 ++--- .../physical_optimizer/coalesce_batches.rs | 4 ++-- .../combine_partial_final_agg.rs | 6 ++--- .../enforce_distribution.rs | 5 +++-- .../src/physical_optimizer/enforce_sorting.rs | 12 +++++----- .../src/physical_optimizer/join_selection.rs | 4 ++-- .../limited_distinct_aggregation.rs | 8 ++----- .../physical_optimizer/output_requirements.rs | 4 ++-- .../physical_optimizer/pipeline_checker.rs | 4 ++-- .../physical_optimizer/projection_pushdown.rs | 11 +++++----- .../core/src/physical_optimizer/pruning.rs | 3 ++- .../replace_with_order_preserving_variants.rs | 4 ++-- .../core/src/physical_optimizer/test_utils.rs | 4 ++-- .../physical_optimizer/topk_aggregation.rs | 8 ++----- datafusion/expr/src/expr.rs | 4 ++-- datafusion/expr/src/expr_rewriter/mod.rs | 20 +++++++++-------- datafusion/expr/src/expr_rewriter/order_by.rs | 4 ++-- datafusion/expr/src/logical_plan/plan.rs | 6 ++--- .../src/analyzer/count_wildcard_rule.rs | 6 +++-- .../src/analyzer/inline_table_scan.rs | 4 ++-- .../optimizer/src/common_subexpr_eliminate.rs | 5 +++-- datafusion/optimizer/src/push_down_filter.rs | 6 +++-- .../simplify_expressions/expr_simplifier.rs | 8 +++---- .../src/simplify_expressions/guarantees.rs | 13 ++++++----- .../src/unwrap_cast_in_comparison.rs | 4 ++-- .../physical-expr/src/equivalence/class.rs | 6 ++--- .../physical-expr/src/equivalence/mod.rs | 4 ++-- .../src/equivalence/projection.rs | 4 ++-- .../src/equivalence/properties.rs | 4 ++-- .../physical-expr/src/expressions/case.rs | 6 ++--- 
datafusion/physical-expr/src/utils/mod.rs | 6 +++-- .../physical-plan/src/recursive_query.rs | 4 ++-- datafusion/sql/src/utils.rs | 8 +++---- 35 files changed, 131 insertions(+), 104 deletions(-) diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs index a2f6e44f115f..2bb432d65be4 100644 --- a/datafusion-examples/examples/rewrite_expr.rs +++ b/datafusion-examples/examples/rewrite_expr.rs @@ -17,7 +17,7 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{plan_err, Result, ScalarValue}; use datafusion_expr::{ AggregateUDF, Between, Expr, Filter, LogicalPlan, ScalarUDF, TableSource, WindowUDF, @@ -103,7 +103,7 @@ impl MyAnalyzerRule { _ => Transformed::no(plan), }) }) - .map(|t| t.data) + .data() } fn analyze_expr(expr: Expr) -> Result { @@ -119,7 +119,7 @@ impl MyAnalyzerRule { _ => Transformed::no(expr), }) }) - .map(|t| t.data) + .data() } } @@ -185,7 +185,7 @@ fn my_rewrite(expr: Expr) -> Result { _ => Transformed::no(expr), }) }) - .map(|t| t.data) + .data() } #[derive(Default)] diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 4e1d5fca68ba..1d10391508f4 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -604,6 +604,28 @@ impl TransformedIterator for I { } } +pub trait TransformedResult { + fn data(self) -> Result; + + fn transformed(self) -> Result; + + fn tnr(self) -> Result; +} + +impl TransformedResult for Result> { + fn data(self) -> Result { + self.map(|t| t.data) + } + + fn transformed(self) -> Result { + self.map(|t| t.transformed) + } + + fn tnr(self) -> Result { + self.map(|t| t.tnr) + } +} + /// Helper trait for implementing [`TreeNode`] that have children stored as Arc's /// /// If some trait object, 
such as `dyn T`, implements this trait, diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index 5f872831ef93..df54222270ce 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -27,7 +27,7 @@ use crate::physical_plan::{expressions, AggregateExpr, ExecutionPlan, Statistics use crate::scalar::ScalarValue; use datafusion_common::stats::Precision; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; @@ -88,11 +88,11 @@ impl PhysicalOptimizerRule for AggregateStatistics { plan.map_children(|child| { self.optimize(child, _config).map(Transformed::yes) }) - .map(|t| t.data) + .data() } } else { plan.map_children(|child| self.optimize(child, _config).map(Transformed::yes)) - .map(|t| t.data) + .data() } } diff --git a/datafusion/core/src/physical_optimizer/coalesce_batches.rs b/datafusion/core/src/physical_optimizer/coalesce_batches.rs index e3565e451669..01213ed8df1a 100644 --- a/datafusion/core/src/physical_optimizer/coalesce_batches.rs +++ b/datafusion/core/src/physical_optimizer/coalesce_batches.rs @@ -27,7 +27,7 @@ use crate::{ repartition::RepartitionExec, Partitioning, }, }; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use std::sync::Arc; /// Optimizer rule that introduces CoalesceBatchesExec to avoid overhead with small batches that @@ -79,7 +79,7 @@ impl PhysicalOptimizerRule for CoalesceBatches { Ok(Transformed::no(plan)) } }) - .map(|t| t.data) + .data() } fn name(&self) -> &str { diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs 
b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index ccc9a2909cca..db09d504c468 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -26,7 +26,7 @@ use crate::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGro use crate::physical_plan::ExecutionPlan; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::{AggregateExpr, PhysicalExpr}; @@ -114,7 +114,7 @@ impl PhysicalOptimizerRule for CombinePartialFinalAggregate { Transformed::no(plan) }) }) - .map(|t| t.data) + .data() } fn name(&self) -> &str { @@ -191,7 +191,7 @@ fn discard_column_index(group_expr: Arc) -> Arc &str { @@ -687,7 +687,7 @@ mod tests { { let plan_requirements = PlanWithCorrespondingSort::new_default($PLAN.clone()); let adjusted = plan_requirements - .transform_up(&ensure_sorting).map(|t| t.data) + .transform_up(&ensure_sorting).data() .and_then(check_integrity)?; // TODO: End state payloads will be checked here. @@ -695,7 +695,7 @@ mod tests { let plan_with_coalesce_partitions = PlanWithCorrespondingCoalescePartitions::new_default(adjusted.plan); let parallel = plan_with_coalesce_partitions - .transform_up(¶llelize_sorts).map(|t| t.data) + .transform_up(¶llelize_sorts).data() .and_then(check_integrity)?; // TODO: End state payloads will be checked here. parallel.plan @@ -712,14 +712,14 @@ mod tests { true, state.config_options(), ) - }).map(|t| t.data) + }).data() .and_then(check_integrity)?; // TODO: End state payloads will be checked here. 
let mut sort_pushdown = SortPushDown::new_default(updated_plan.plan); assign_initial_requirements(&mut sort_pushdown); sort_pushdown - .transform_down(&pushdown_sorts).map(|t| t.data) + .transform_down(&pushdown_sorts).data() .and_then(check_integrity)?; // TODO: End state payloads will be checked here. } diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index 192eb49a6d09..cd710ce46990 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -37,7 +37,7 @@ use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties}; use arrow_schema::Schema; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{internal_err, JoinSide, JoinType}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::sort_properties::SortProperties; @@ -261,7 +261,7 @@ impl PhysicalOptimizerRule for JoinSelection { collect_threshold_num_rows, ) }) - .map(|t| t.data) + .data() } fn name(&self) -> &str { diff --git a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs index faf42353dba0..8b14bf067d3c 100644 --- a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs +++ b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs @@ -26,7 +26,7 @@ use crate::physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; use itertools::Itertools; @@ -138,11 
+138,7 @@ impl LimitedDistinctAggregation { rewrite_applicable = false; Ok(Transformed::no(plan)) }; - let child = child - .clone() - .transform_down_mut(&mut closure) - .map(|t| t.data) - .ok()?; + let child = child.clone().transform_down_mut(&mut closure).data().ok()?; if is_global_limit { return Some(Arc::new(GlobalLimitExec::new( child, diff --git a/datafusion/core/src/physical_optimizer/output_requirements.rs b/datafusion/core/src/physical_optimizer/output_requirements.rs index ded0ac45ee87..bd71b3e8ed80 100644 --- a/datafusion/core/src/physical_optimizer/output_requirements.rs +++ b/datafusion/core/src/physical_optimizer/output_requirements.rs @@ -29,7 +29,7 @@ use crate::physical_plan::sorts::sort::SortExec; use crate::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{Result, Statistics}; use datafusion_physical_expr::{Distribution, LexRequirement, PhysicalSortRequirement}; use datafusion_physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; @@ -203,7 +203,7 @@ impl PhysicalOptimizerRule for OutputRequirements { Ok(Transformed::no(plan)) } }) - .map(|t| t.data), + .data(), } } diff --git a/datafusion/core/src/physical_optimizer/pipeline_checker.rs b/datafusion/core/src/physical_optimizer/pipeline_checker.rs index 35166b0dfdf3..1dc8bc5042bf 100644 --- a/datafusion/core/src/physical_optimizer/pipeline_checker.rs +++ b/datafusion/core/src/physical_optimizer/pipeline_checker.rs @@ -28,7 +28,7 @@ use crate::physical_plan::{ExecutionPlan, ExecutionPlanProperties}; use datafusion_common::config::OptimizerOptions; use datafusion_common::plan_err; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use 
datafusion_physical_expr::intervals::utils::{check_support, is_datatype_supported}; use datafusion_physical_plan::joins::SymmetricHashJoinExec; @@ -51,7 +51,7 @@ impl PhysicalOptimizerRule for PipelineChecker { config: &ConfigOptions, ) -> Result> { plan.transform_up(&|p| check_finiteness_requirements(p, &config.optimizer)) - .map(|t| t.data) + .data() } fn name(&self) -> &str { diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index 1c9a8023c6e0..78efdc9e0ce6 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -43,7 +43,9 @@ use crate::physical_plan::{Distribution, ExecutionPlan, ExecutionPlanProperties} use arrow_schema::SchemaRef; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; use datafusion_common::{DataFusionError, JoinSide}; use datafusion_physical_expr::expressions::{Column, Literal}; use datafusion_physical_expr::{ @@ -73,8 +75,7 @@ impl PhysicalOptimizerRule for ProjectionPushdown { plan: Arc, _config: &ConfigOptions, ) -> Result> { - plan.transform_down(&remove_unnecessary_projections) - .map(|t| t.data) + plan.transform_down(&remove_unnecessary_projections).data() } fn name(&self) -> &str { @@ -931,7 +932,7 @@ fn update_expr( ) } }) - .map(|t| t.data); + .data(); new_expr.map(|e| (state == RewriteState::RewrittenValid).then_some(e)) } @@ -1062,7 +1063,7 @@ fn new_columns_for_join_on( Ok(Transformed::no(expr)) } }) - .map(|t| t.data) + .data() .ok() }) .collect::>(); diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index 2eeb9c970ca5..18df763908fa 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ 
b/datafusion/core/src/physical_optimizer/pruning.rs @@ -36,6 +36,7 @@ use arrow::{ record_batch::RecordBatch, }; use arrow_array::cast::AsArray; +use datafusion_common::tree_node::TransformedResult; use datafusion_common::{ internal_err, plan_err, tree_node::{Transformed, TreeNode}, @@ -1040,7 +1041,7 @@ fn rewrite_column_expr( Ok(Transformed::no(expr)) }) - .map(|t| t.data) + .data() } fn reverse_operator(op: Operator) -> Result { diff --git a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs index 6320c69df7e5..e8b6a78b929e 100644 --- a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs +++ b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs @@ -293,7 +293,7 @@ mod tests { use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - use datafusion_common::tree_node::TreeNode; + use datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::{Result, Statistics}; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_expr::{JoinType, Operator}; @@ -395,7 +395,7 @@ mod tests { // Run the rule top-down let config = SessionConfig::new().with_prefer_existing_sort($PREFER_EXISTING_SORT); let plan_with_pipeline_fixer = OrderPreservationContext::new_default(physical_plan); - let parallel = plan_with_pipeline_fixer.transform_up(&|plan_with_pipeline_fixer| replace_with_order_preserving_variants(plan_with_pipeline_fixer, false, false, config.options())).map(|t| t.data).and_then(check_integrity)?; + let parallel = plan_with_pipeline_fixer.transform_up(&|plan_with_pipeline_fixer| replace_with_order_preserving_variants(plan_with_pipeline_fixer, false, false, config.options())).data().and_then(check_integrity)?; let optimized_physical_plan = parallel.plan; // Get string representation of the plan diff --git 
a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index 0e19f55b2bcd..d944cedb0f96 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -40,7 +40,7 @@ use crate::physical_plan::{ExecutionPlan, InputOrderMode, Partitioning}; use crate::prelude::{CsvReadOptions, SessionContext}; use arrow_schema::{Schema, SchemaRef, SortOptions}; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{JoinType, Statistics}; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_expr::{AggregateFunction, WindowFrame, WindowFunctionDefinition}; @@ -390,5 +390,5 @@ pub fn check_integrity(context: PlanContext) -> Result Result { } }) }) - .map(|t| t.data) + .data() } /// Recursively call [`Column::normalize_with_schemas`] on all [`Column`] expressions @@ -68,7 +70,7 @@ pub fn normalize_col_with_schemas( } }) }) - .map(|t| t.data) + .data() } /// See [`Column::normalize_with_schemas_and_ambiguity_check`] for usage @@ -98,7 +100,7 @@ pub fn normalize_col_with_schemas_and_ambiguity_check( } }) }) - .map(|t| t.data) + .data() } /// Recursively normalize all [`Column`] expressions in a list of expression trees @@ -127,7 +129,7 @@ pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Resul } }) }) - .map(|t| t.data) + .data() } /// Recursively 'unnormalize' (remove all qualifiers) from an @@ -149,7 +151,7 @@ pub fn unnormalize_col(expr: Expr) -> Expr { } }) }) - .map(|t| t.data) + .data() .expect("Unnormalize is infallable") } @@ -188,7 +190,7 @@ pub fn strip_outer_reference(expr: Expr) -> Expr { } }) }) - .map(|t| t.data) + .data() .expect("strip_outer_reference is infallable") } @@ -324,7 +326,7 @@ mod test { let rewritten = col("state") .eq(lit("foo")) .transform_up(&transformer) - .map(|t| t.data) + .data() 
.unwrap(); assert_eq!(rewritten, col("state").eq(lit("bar"))); @@ -332,7 +334,7 @@ mod test { let rewritten = col("state") .eq(lit("baz")) .transform_up(&transformer) - .map(|t| t.data) + .data() .unwrap(); assert_eq!(rewritten, col("state").eq(lit("baz"))); } diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index 1cc35a1a4b94..4345a386eda3 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -20,7 +20,7 @@ use crate::expr::{Alias, Sort}; use crate::expr_rewriter::normalize_col; use crate::{Cast, Expr, ExprSchemable, LogicalPlan, TryCast}; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{Column, Result}; /// Rewrite sort on aggregate expressions to sort on the column of aggregate output @@ -133,7 +133,7 @@ fn rewrite_in_terms_of_projection( Ok(Transformed::no(expr)) }) - .map(|t| t.data) + .data() } /// Does the underlying expr match e? 
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 8891ea77f60d..db5ef8931612 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -46,7 +46,7 @@ use crate::{ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeRecursion, TreeNodeVisitor, + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeVisitor, }; use datafusion_common::{ aggregate_functional_dependencies, internal_err, plan_err, Column, Constraints, @@ -1252,7 +1252,7 @@ impl LogicalPlan { _ => Ok(Transformed::no(expr)), } }) - .map(|t| t.data) + .data() } } @@ -3314,7 +3314,7 @@ digraph { } x => Ok(Transformed::no(x)), }) - .map(|t| t.data) + .data() .unwrap(); let expected = "Explain\ diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 15eb1035f240..99e32c0bac74 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -17,7 +17,9 @@ use crate::analyzer::AnalyzerRule; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRewriter, +}; use datafusion_common::Result; use datafusion_expr::expr::{AggregateFunction, AggregateFunctionDefinition, InSubquery}; use datafusion_expr::expr_rewriter::rewrite_preserving_name; @@ -43,7 +45,7 @@ impl CountWildcardRule { impl AnalyzerRule for CountWildcardRule { fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { - plan.transform_down(&analyze_internal).map(|t| t.data) + plan.transform_down(&analyze_internal).data() } fn name(&self) -> &str { diff --git a/datafusion/optimizer/src/analyzer/inline_table_scan.rs 
b/datafusion/optimizer/src/analyzer/inline_table_scan.rs index 36f0c3318371..ef297a19c69d 100644 --- a/datafusion/optimizer/src/analyzer/inline_table_scan.rs +++ b/datafusion/optimizer/src/analyzer/inline_table_scan.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use crate::analyzer::AnalyzerRule; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; use datafusion_expr::expr::Exists; use datafusion_expr::expr::InSubquery; @@ -42,7 +42,7 @@ impl InlineTableScan { impl AnalyzerRule for InlineTableScan { fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { - plan.transform_up(&analyze_internal).map(|t| t.data) + plan.transform_up(&analyze_internal).data() } fn name(&self) -> &str { diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index b21a56d2aee9..323556ad7158 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -25,7 +25,8 @@ use crate::{utils, OptimizerConfig, OptimizerRule}; use arrow::datatypes::DataType; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, + TreeNodeVisitor, }; use datafusion_common::{ internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, @@ -839,7 +840,7 @@ fn replace_common_expr( max_series_number: 0, curr_index: 0, }) - .map(|t| t.data) + .data() } #[cfg(test)] diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index f5fd71ca1551..a63133c5166f 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -22,7 +22,9 @@ use crate::optimizer::ApplyOrder; use 
crate::utils::is_volatile_expression; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; use datafusion_common::{ internal_err, plan_datafusion_err, Column, DFSchema, DFSchemaRef, JoinConstraint, Result, @@ -999,7 +1001,7 @@ pub fn replace_cols_by_name( Transformed::no(expr) }) }) - .map(|t| t.data) + .data() } /// check whether the expression uses the columns in `check_map`. diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index e0236450c837..dd618d960747 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -31,7 +31,7 @@ use arrow::{ datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; -use datafusion_common::tree_node::Transformed; +use datafusion_common::tree_node::{Transformed, TransformedResult}; use datafusion_common::{ cast::{as_large_list_array, as_list_array}, tree_node::{TreeNode, TreeNodeRewriter}, @@ -161,7 +161,7 @@ impl ExprSimplifier { .rewrite(&mut const_evaluator)? .data .rewrite(&mut simplifier) - .map(|t| t.data) + .data() } /// Apply type coercion to an [`Expr`] so that it can be @@ -177,7 +177,7 @@ impl ExprSimplifier { pub fn coerce(&self, expr: Expr, schema: DFSchemaRef) -> Result { let mut expr_rewrite = TypeCoercionRewriter { schema }; - expr.rewrite(&mut expr_rewrite).map(|t| t.data) + expr.rewrite(&mut expr_rewrite).data() } /// Input guarantees about the values of columns. 
@@ -1477,7 +1477,7 @@ mod tests { let evaluated_expr = input_expr .clone() .rewrite(&mut const_evaluator) - .map(|t| t.data) + .data() .expect("successfully evaluated"); assert_eq!( diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index 8b243f82c714..9f8553cb0cc2 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -204,6 +204,7 @@ mod tests { use super::*; use arrow::datatypes::DataType; + use datafusion_common::tree_node::TransformedResult; use datafusion_common::{tree_node::TreeNode, ScalarValue}; use datafusion_expr::{col, lit, Operator}; @@ -224,12 +225,12 @@ mod tests { // x IS NULL => guaranteed false let expr = col("x").is_null(); - let output = expr.clone().rewrite(&mut rewriter).map(|t| t.data).unwrap(); + let output = expr.clone().rewrite(&mut rewriter).data().unwrap(); assert_eq!(output, lit(false)); // x IS NOT NULL => guaranteed true let expr = col("x").is_not_null(); - let output = expr.clone().rewrite(&mut rewriter).map(|t| t.data).unwrap(); + let output = expr.clone().rewrite(&mut rewriter).data().unwrap(); assert_eq!(output, lit(true)); } @@ -239,7 +240,7 @@ mod tests { T: Clone, { for (expr, expected_value) in cases { - let output = expr.clone().rewrite(rewriter).map(|t| t.data).unwrap(); + let output = expr.clone().rewrite(rewriter).data().unwrap(); let expected = lit(ScalarValue::from(expected_value.clone())); assert_eq!( output, expected, @@ -251,7 +252,7 @@ mod tests { fn validate_unchanged_cases(rewriter: &mut GuaranteeRewriter, cases: &[Expr]) { for expr in cases { - let output = expr.clone().rewrite(rewriter).map(|t| t.data).unwrap(); + let output = expr.clone().rewrite(rewriter).data().unwrap(); assert_eq!( &output, expr, "{} was simplified to {}, but expected it to be unchanged", @@ -481,7 +482,7 @@ mod tests { let guarantees = vec![(col("x"), 
NullableInterval::from(scalar.clone()))]; let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); - let output = col("x").rewrite(&mut rewriter).map(|t| t.data).unwrap(); + let output = col("x").rewrite(&mut rewriter).data().unwrap(); assert_eq!(output, Expr::Literal(scalar.clone())); } } @@ -525,7 +526,7 @@ mod tests { .collect(), *negated, ); - let output = expr.clone().rewrite(&mut rewriter).map(|t| t.data).unwrap(); + let output = expr.clone().rewrite(&mut rewriter).data().unwrap(); let expected_list = expected_list .iter() .map(|v| lit(ScalarValue::Int32(Some(*v)))) diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 4c68ac979fb2..9cc34c9b1611 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -476,7 +476,7 @@ mod tests { use crate::unwrap_cast_in_comparison::UnwrapCastExprRewriter; use arrow::compute::{cast_with_options, CastOptions}; use arrow::datatypes::{DataType, Field}; - use datafusion_common::tree_node::TreeNode; + use datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::{DFField, DFSchema, DFSchemaRef, ScalarValue}; use datafusion_expr::{cast, col, in_list, lit, try_cast, Expr}; use std::collections::HashMap; @@ -730,7 +730,7 @@ mod tests { let mut expr_rewriter = UnwrapCastExprRewriter { schema: schema.clone(), }; - expr.rewrite(&mut expr_rewriter).map(|t| t.data).unwrap() + expr.rewrite(&mut expr_rewriter).data().unwrap() } fn expr_test_schema() -> DFSchemaRef { diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index 87e71e3458cd..0ebc2f52a28a 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -22,7 +22,7 @@ use crate::{ LexRequirement, LexRequirementRef, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, 
PhysicalSortRequirement, }; -use datafusion_common::tree_node::TreeNode; +use datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::{tree_node::Transformed, JoinType}; use std::sync::Arc; @@ -268,7 +268,7 @@ impl EquivalenceGroup { } Ok(Transformed::no(expr)) }) - .map(|t| t.data) + .data() .unwrap_or(expr) } @@ -464,7 +464,7 @@ impl EquivalenceGroup { Ok(Transformed::no(expr)) }) - .map(|t| t.data) + .data() .unwrap(); result.add_equal_conditions(&new_lhs, &new_rhs); } diff --git a/datafusion/physical-expr/src/equivalence/mod.rs b/datafusion/physical-expr/src/equivalence/mod.rs index 896f97dc26a0..6b928ea24c6b 100644 --- a/datafusion/physical-expr/src/equivalence/mod.rs +++ b/datafusion/physical-expr/src/equivalence/mod.rs @@ -22,7 +22,7 @@ mod properties; use crate::expressions::Column; use crate::{LexRequirement, PhysicalExpr, PhysicalSortRequirement}; pub use class::{EquivalenceClass, EquivalenceGroup}; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; pub use ordering::OrderingEquivalenceClass; pub use projection::ProjectionMapping; pub use properties::{join_equivalence_properties, EquivalenceProperties}; @@ -54,7 +54,7 @@ pub fn add_offset_to_expr( )))), None => Ok(Transformed::no(e)), }) - .map(|t| t.data) + .data() .unwrap() // Note that we can safely unwrap here since our transform always returns // an `Ok` value. 
diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/equivalence/projection.rs index a96fbb6e484b..ad1f754a96d1 100644 --- a/datafusion/physical-expr/src/equivalence/projection.rs +++ b/datafusion/physical-expr/src/equivalence/projection.rs @@ -21,7 +21,7 @@ use crate::expressions::Column; use crate::PhysicalExpr; use arrow::datatypes::SchemaRef; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; /// Stores the mapping between source expressions and target expressions for a @@ -72,7 +72,7 @@ impl ProjectionMapping { } None => Ok(Transformed::no(e)), }) - .map(|t| t.data) + .data() .map(|source_expr| (source_expr, target_expr)) }) .collect::>>() diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index d0da9d220e71..88550813fe23 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -36,7 +36,7 @@ use crate::{ }; use arrow_schema::SortOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; /// A `EquivalenceProperties` object stores useful information related to a schema. /// Currently, it keeps track of: @@ -848,7 +848,7 @@ impl EquivalenceProperties { pub fn get_expr_ordering(&self, expr: Arc) -> ExprOrdering { ExprOrdering::new_default(expr.clone()) .transform_up(&|expr| Ok(update_ordering(expr, self))) - .map(|t| t.data) + .data() // Guaranteed to always return `Ok`. 
.unwrap() } diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 59c6886d0c0e..22609f5afb26 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -423,7 +423,7 @@ mod tests { use arrow::datatypes::*; use datafusion_common::cast::{as_float64_array, as_int32_array}; use datafusion_common::plan_err; - use datafusion_common::tree_node::{Transformed, TreeNode}; + use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::ScalarValue; use datafusion_expr::type_coercion::binary::comparison_coercion; use datafusion_expr::Operator; @@ -977,7 +977,7 @@ mod tests { Transformed::no(e) }) }) - .map(|t| t.data) + .data() .unwrap(); let expr3 = expr @@ -999,7 +999,7 @@ mod tests { Transformed::no(e) }) }) - .map(|t| t.data) + .data() .unwrap(); assert!(expr.ne(&expr2)); diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 694a18e147d3..0b0dca6bb4b6 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -29,7 +29,9 @@ use crate::{PhysicalExpr, PhysicalSortExpr}; use arrow::array::{make_array, Array, ArrayRef, BooleanArray, MutableArrayData}; use arrow::compute::{and_kleene, is_not_null, SlicesIterator}; use arrow::datatypes::SchemaRef; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; +use datafusion_common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, +}; use datafusion_common::Result; use datafusion_expr::Operator; @@ -239,7 +241,7 @@ pub fn reassign_predicate_columns( } Ok(Transformed::no(expr)) }) - .map(|t| t.data) + .data() } /// Reverses the ORDER BY expression, which is useful during equivalent window diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index 5fec7079d762..2e4b97bc224b 
100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -30,7 +30,7 @@ use crate::{DisplayAs, DisplayFormatType, ExecutionMode, ExecutionPlan}; use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{not_impl_err, DataFusionError, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; @@ -327,7 +327,7 @@ fn assign_work_table( Ok(Transformed::no(plan)) } }) - .map(|t| t.data) + .data() } impl Stream for RecursiveQueryStream { diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index 3f6f3aa483ab..d6f53a73dcb1 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -20,7 +20,7 @@ use arrow_schema::{ DataType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE, }; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use sqlparser::ast::Ident; use datafusion_common::{exec_err, internal_err, plan_err}; @@ -46,7 +46,7 @@ pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result { } } }) - .map(|t| t.data) + .data() } /// Rebuilds an `Expr` as a projection on top of a collection of `Expr`'s. 
@@ -76,7 +76,7 @@ pub(crate) fn rebase_expr( Ok(Transformed::no(nested_expr)) } }) - .map(|t| t.data) + .data() } /// Determines if the set of `Expr`'s are a valid projection on the input @@ -185,7 +185,7 @@ pub(crate) fn resolve_aliases_to_exprs( } _ => Ok(Transformed::no(nested_expr)), }) - .map(|t| t.data) + .data() } /// given a slice of window expressions sharing the same sort key, find their common partition From 08e3c7b8d920ef2a541679db45b5ce8d93a2f51d Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Fri, 1 Mar 2024 10:42:34 +0100 Subject: [PATCH 37/40] fix docs --- datafusion/common/src/tree_node.rs | 36 +++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 1d10391508f4..f90b657c5b08 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -160,19 +160,18 @@ pub trait TreeNode: Sized { /// /// The nodes are visited using the following order /// ```text - /// pre_visit(ParentNode) - /// pre_visit(ChildNode1) - /// post_visit(ChildNode1) - /// pre_visit(ChildNode2) - /// post_visit(ChildNode2) - /// post_visit(ParentNode) + /// TreeNodeVisitor::f_down(ParentNode) + /// TreeNodeVisitor::f_down(ChildNode1) + /// TreeNodeVisitor::f_up(ChildNode1) + /// TreeNodeVisitor::f_down(ChildNode2) + /// TreeNodeVisitor::f_up(ChildNode2) + /// TreeNodeVisitor::f_up(ParentNode) /// ``` /// - /// If an Err result is returned, recursion is stopped immediately + /// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled. /// - /// If [`TreeNodeRecursion::Stop`] is returned on a call to pre_visit, no - /// children of that node will be visited, nor is post_visit - /// called on that node. Details see [`TreeNodeVisitor`] + /// If [`TreeNodeVisitor::f_down()`] or [`TreeNodeVisitor::f_up()`] returns [`Err`], + /// recursion is stopped immediately. 
/// /// If using the default [`TreeNodeVisitor::f_up`] that does /// nothing, [`Self::apply`] should be preferred. @@ -475,6 +474,21 @@ pub enum TreeNodeRecursion { Stop, } +/// This struct is used with [`TreeNode::rewrite`], [`TreeNode::transform_down`], +/// [`TreeNode::transform_down_mut`], [`TreeNode::transform_up`], +/// [`TreeNode::transform_up_mut`] and [`TreeNode::transform_down_up`] methods to control +/// transformations and return the transformed result. +/// +/// API users can provide transformation closures and [`TreeNodeRewriter`] +/// implementations to control transformation by specifying: +/// - the possibly transformed node, +/// - if any change was made to the node and +/// - how to proceed with the recursion. +/// +/// The APIs return this struct with the: +/// - final possibly transformed tree, +/// - if any change was made to any node and +/// - how the recursion ended. #[derive(PartialEq, Debug)] pub struct Transformed { pub data: T, @@ -567,6 +581,7 @@ impl Transformed { } } +/// Transformation helper to process tree nodes that are siblings. pub trait TransformedIterator: Iterator { fn map_until_stop_and_collect(self, f: F) -> Result>> where @@ -604,6 +619,7 @@ impl TransformedIterator for I { } } +/// Transformation helper to access [`Transformed`] fields in a [`Result`] easily. 
pub trait TransformedResult { fn data(self) -> Result; From 9f4b28aeedf417adf74ffd862f3c13e5db60727b Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Fri, 1 Mar 2024 18:10:23 +0100 Subject: [PATCH 38/40] restore transform as alias to trassform_up --- datafusion/common/src/tree_node.rs | 10 ++++++++++ .../combine_partial_final_agg.rs | 2 +- .../physical_optimizer/projection_pushdown.rs | 2 +- .../core/src/physical_optimizer/pruning.rs | 2 +- datafusion/expr/src/expr.rs | 2 +- datafusion/expr/src/expr_rewriter/mod.rs | 16 ++++++++-------- datafusion/expr/src/expr_rewriter/order_by.rs | 2 +- datafusion/expr/src/logical_plan/plan.rs | 4 ++-- .../optimizer/src/analyzer/inline_table_scan.rs | 2 +- .../physical-expr/src/equivalence/class.rs | 4 ++-- datafusion/physical-expr/src/expressions/case.rs | 2 +- datafusion/physical-plan/src/joins/utils.rs | 2 +- .../library-user-guide/working-with-exprs.md | 2 +- 13 files changed, 31 insertions(+), 21 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index f90b657c5b08..3de0c5bf945c 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -238,6 +238,16 @@ pub trait TreeNode: Sized { self.apply_children(&mut |n| n.apply(f)) } + /// Convenience utils for writing optimizers rule: recursively apply the given `f` to the node tree. + /// When `f` does not apply to a given node, it is left unchanged. + /// The default tree traversal direction is transform_up(Postorder Traversal). + fn transform(self, f: &F) -> Result> + where + F: Fn(Self) -> Result>, + { + self.transform_up(f) + } + /// Convenience utils for writing optimizers rule: recursively apply the given 'f' to the node and all of its /// children(Preorder Traversal). /// When the `f` does not apply to a given node, it is left unchanged. 
diff --git a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs index db09d504c468..c45e14100e82 100644 --- a/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs +++ b/datafusion/core/src/physical_optimizer/combine_partial_final_agg.rs @@ -179,7 +179,7 @@ fn normalize_group_exprs(group_exprs: GroupExprsRef) -> GroupExprs { fn discard_column_index(group_expr: Arc) -> Arc { group_expr .clone() - .transform_up(&|expr| { + .transform(&|expr| { let normalized_form: Option> = match expr.as_any().downcast_ref::() { Some(column) => Some(Arc::new(Column::new(column.name(), 0))), diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index 78efdc9e0ce6..5fe0d46b8043 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -1038,7 +1038,7 @@ fn new_columns_for_join_on( // Rewrite all columns in `on` (*on) .clone() - .transform_up(&|expr| { + .transform(&|expr| { if let Some(column) = expr.as_any().downcast_ref::() { // Find the column in the projection expressions let new_column = projection_exprs diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index 18df763908fa..b3dd8294d507 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -1032,7 +1032,7 @@ fn rewrite_column_expr( column_old: &phys_expr::Column, column_new: &phys_expr::Column, ) -> Result> { - e.transform_up(&|expr| { + e.transform(&|expr| { if let Some(column) = expr.as_any().downcast_ref::() { if column == column_old { return Ok(Transformed::yes(Arc::new(column_new.clone()))); diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 89a0a8cdb952..06b276fb41fd 100644 --- 
a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1259,7 +1259,7 @@ impl Expr { /// For example, gicen an expression like ` = $0` will infer `$0` to /// have type `int32`. pub fn infer_placeholder_types(self, schema: &DFSchema) -> Result { - self.transform_up(&|mut expr| { + self.transform(&|mut expr| { // Default to assuming the arguments are the same type if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = &mut expr { rewrite_placeholder(left.as_mut(), right.as_ref(), schema)?; diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 62b0e7ad6b1d..8d7a314a89fe 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -35,7 +35,7 @@ pub use order_by::rewrite_sort_cols_by_aggs; /// Recursively call [`Column::normalize_with_schemas`] on all [`Column`] expressions /// in the `expr` expression tree. pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result { - expr.transform_up(&|expr| { + expr.transform(&|expr| { Ok({ if let Expr::Column(c) = expr { let col = LogicalPlanBuilder::normalize(plan, c)?; @@ -60,7 +60,7 @@ pub fn normalize_col_with_schemas( schemas: &[&Arc], using_columns: &[HashSet], ) -> Result { - expr.transform_up(&|expr| { + expr.transform(&|expr| { Ok({ if let Expr::Column(c) = expr { let col = c.normalize_with_schemas(schemas, using_columns)?; @@ -89,7 +89,7 @@ pub fn normalize_col_with_schemas_and_ambiguity_check( return Ok(Expr::Unnest(Unnest { exprs: vec![e] })); } - expr.transform_up(&|expr| { + expr.transform(&|expr| { Ok({ if let Expr::Column(c) = expr { let col = @@ -117,7 +117,7 @@ pub fn normalize_cols( /// Recursively replace all [`Column`] expressions in a given expression tree with /// `Column` expressions provided by the hash map argument. 
pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Result { - expr.transform_up(&|expr| { + expr.transform(&|expr| { Ok({ if let Expr::Column(c) = &expr { match replace_map.get(c) { @@ -138,7 +138,7 @@ pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Resul /// For example, if there were expressions like `foo.bar` this would /// rewrite it to just `bar`. pub fn unnormalize_col(expr: Expr) -> Expr { - expr.transform_up(&|expr| { + expr.transform(&|expr| { Ok({ if let Expr::Column(c) = expr { let col = Column { @@ -181,7 +181,7 @@ pub fn unnormalize_cols(exprs: impl IntoIterator) -> Vec { /// Recursively remove all the ['OuterReferenceColumn'] and return the inside Column /// in the expression tree. pub fn strip_outer_reference(expr: Expr) -> Expr { - expr.transform_up(&|expr| { + expr.transform(&|expr| { Ok({ if let Expr::OuterReferenceColumn(_, col) = expr { Transformed::yes(Expr::Column(col)) @@ -325,7 +325,7 @@ mod test { // rewrites "foo" --> "bar" let rewritten = col("state") .eq(lit("foo")) - .transform_up(&transformer) + .transform(&transformer) .data() .unwrap(); assert_eq!(rewritten, col("state").eq(lit("bar"))); @@ -333,7 +333,7 @@ mod test { // doesn't rewrite let rewritten = col("state") .eq(lit("baz")) - .transform_up(&transformer) + .transform(&transformer) .data() .unwrap(); assert_eq!(rewritten, col("state").eq(lit("baz"))); diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index 4345a386eda3..06d1dc061168 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -83,7 +83,7 @@ fn rewrite_in_terms_of_projection( ) -> Result { // assumption is that each item in exprs, such as "b + c" is // available as an output column named "b + c" - expr.transform_up(&|expr| { + expr.transform(&|expr| { // search for unnormalized names first such as "c1" (such as aliases) if let Some(found) = 
proj_exprs.iter().find(|a| (**a) == expr) { let col = Expr::Column( diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 59a0fe59399c..825d3f037023 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -1234,7 +1234,7 @@ impl LogicalPlan { expr: Expr, param_values: &ParamValues, ) -> Result { - expr.transform_up(&|expr| { + expr.transform(&|expr| { match &expr { Expr::Placeholder(Placeholder { id, .. }) => { let value = param_values.get_placeholders_with_values(id)?; @@ -3303,7 +3303,7 @@ digraph { // after transformation, because plan is not the same anymore, // the parent plan is built again with call to LogicalPlan::with_new_inputs -> with_new_exprs let plan = plan - .transform_up(&|plan| match plan { + .transform(&|plan| match plan { LogicalPlan::TableScan(table) => { let filter = Filter::try_new( external_filter.clone(), diff --git a/datafusion/optimizer/src/analyzer/inline_table_scan.rs b/datafusion/optimizer/src/analyzer/inline_table_scan.rs index ef297a19c69d..ada7dca45759 100644 --- a/datafusion/optimizer/src/analyzer/inline_table_scan.rs +++ b/datafusion/optimizer/src/analyzer/inline_table_scan.rs @@ -74,7 +74,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { Transformed::yes(plan) } LogicalPlan::Filter(filter) => { - let new_expr = filter.predicate.transform_up(&rewrite_subquery)?.data; + let new_expr = filter.predicate.transform(&rewrite_subquery)?.data; Transformed::yes(LogicalPlan::Filter(Filter::try_new( new_expr, filter.input, diff --git a/datafusion/physical-expr/src/equivalence/class.rs b/datafusion/physical-expr/src/equivalence/class.rs index 0ebc2f52a28a..6d34752d500e 100644 --- a/datafusion/physical-expr/src/equivalence/class.rs +++ b/datafusion/physical-expr/src/equivalence/class.rs @@ -260,7 +260,7 @@ impl EquivalenceGroup { /// class it matches with (if any). 
pub fn normalize_expr(&self, expr: Arc) -> Arc { expr.clone() - .transform_up(&|expr| { + .transform(&|expr| { for cls in self.iter() { if cls.contains(&expr) { return Ok(Transformed::yes(cls.canonical_expr().unwrap())); @@ -450,7 +450,7 @@ impl EquivalenceGroup { // Rewrite rhs to point to the right side of the join: let new_rhs = rhs .clone() - .transform_up(&|expr| { + .transform(&|expr| { if let Some(column) = expr.as_any().downcast_ref::() { diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 22609f5afb26..e6ce8316c27e 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -960,7 +960,7 @@ mod tests { let expr2 = expr .clone() - .transform_up(&|e| { + .transform(&|e| { let transformed = match e.as_any().downcast_ref::() { Some(lit_value) => match lit_value.value() { diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 9d8d5f301920..083c2f03be7b 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -478,7 +478,7 @@ fn replace_on_columns_of_right_ordering( let new_expr = item .expr .clone() - .transform_up(&|e| { + .transform(&|e| { if e.eq(right_col) { Ok(Transformed::yes(left_col.clone())) } else { diff --git a/docs/source/library-user-guide/working-with-exprs.md b/docs/source/library-user-guide/working-with-exprs.md index 79e12eb4c24c..a839420aa5b2 100644 --- a/docs/source/library-user-guide/working-with-exprs.md +++ b/docs/source/library-user-guide/working-with-exprs.md @@ -96,7 +96,7 @@ To implement the inlining, we'll need to write a function that takes an `Expr` a ```rust fn rewrite_add_one(expr: Expr) -> Result { - expr.transform_up(&|expr| { + expr.transform(&|expr| { Ok(match expr { Expr::ScalarUDF(scalar_fun) if scalar_fun.fun.name == "add_one" => { let input_arg = scalar_fun.args[0].clone(); From 
af6ab4a7182031d7049b82b6bfcd3a40c59ac26d Mon Sep 17 00:00:00 2001 From: Peter Toth Date: Fri, 1 Mar 2024 18:14:41 +0100 Subject: [PATCH 39/40] restore transform as alias to trassform_up 2 --- datafusion-examples/examples/rewrite_expr.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs index 2bb432d65be4..cc1396f770e4 100644 --- a/datafusion-examples/examples/rewrite_expr.rs +++ b/datafusion-examples/examples/rewrite_expr.rs @@ -91,7 +91,7 @@ impl AnalyzerRule for MyAnalyzerRule { impl MyAnalyzerRule { fn analyze_plan(plan: LogicalPlan) -> Result { - plan.transform_up(&|plan| { + plan.transform(&|plan| { Ok(match plan { LogicalPlan::Filter(filter) => { let predicate = Self::analyze_expr(filter.predicate.clone())?; @@ -107,7 +107,7 @@ impl MyAnalyzerRule { } fn analyze_expr(expr: Expr) -> Result { - expr.transform_up(&|expr| { + expr.transform(&|expr| { // closure is invoked for all sub expressions Ok(match expr { Expr::Literal(ScalarValue::Int64(i)) => { @@ -163,7 +163,7 @@ impl OptimizerRule for MyOptimizerRule { /// use rewrite_expr to modify the expression tree. 
fn my_rewrite(expr: Expr) -> Result { - expr.transform_up(&|expr| { + expr.transform(&|expr| { // closure is invoked for all sub expressions Ok(match expr { Expr::Between(Between { From c944b812c24b5032499296ee4244efc1cf2cf4f2 Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Mon, 4 Mar 2024 22:05:28 +0300 Subject: [PATCH 40/40] Simplifications and comment improvements (#2) --- datafusion/common/src/tree_node.rs | 682 ++++++++---------- .../core/src/datasource/listing/helpers.rs | 24 +- .../physical_plan/parquet/row_filter.rs | 19 +- .../physical_optimizer/coalesce_batches.rs | 6 +- .../enforce_distribution.rs | 23 +- .../src/physical_optimizer/enforce_sorting.rs | 20 +- .../src/physical_optimizer/join_selection.rs | 8 +- .../limited_distinct_aggregation.rs | 11 +- .../physical_optimizer/projection_pushdown.rs | 6 +- .../core/src/physical_optimizer/pruning.rs | 9 +- .../physical_optimizer/topk_aggregation.rs | 11 +- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 15 +- datafusion/expr/src/expr.rs | 27 +- datafusion/expr/src/expr_rewriter/mod.rs | 15 +- datafusion/expr/src/expr_rewriter/order_by.rs | 1 + datafusion/expr/src/logical_plan/display.rs | 4 +- datafusion/expr/src/logical_plan/plan.rs | 4 +- datafusion/expr/src/tree_node/expr.rs | 4 +- datafusion/expr/src/tree_node/plan.rs | 15 +- .../src/analyzer/count_wildcard_rule.rs | 4 +- .../src/analyzer/inline_table_scan.rs | 37 +- .../optimizer/src/analyzer/rewrite_expr.rs | 69 +- datafusion/optimizer/src/analyzer/subquery.rs | 11 +- .../optimizer/src/analyzer/type_coercion.rs | 44 +- .../optimizer/src/common_subexpr_eliminate.rs | 14 +- datafusion/optimizer/src/decorrelate.rs | 25 +- .../src/decorrelate_predicate_subquery.rs | 17 +- datafusion/optimizer/src/plan_signature.rs | 2 +- .../optimizer/src/scalar_subquery_to_join.rs | 33 +- .../simplify_expressions/expr_simplifier.rs | 62 +- .../src/simplify_expressions/guarantees.rs | 8 +- .../simplify_expressions/inlist_simplifier.rs | 6 +- 
.../src/unwrap_cast_in_comparison.rs | 13 +- datafusion/optimizer/src/utils.rs | 8 +- .../physical-expr/src/equivalence/class.rs | 19 +- .../physical-expr/src/equivalence/mod.rs | 17 +- .../src/equivalence/projection.rs | 5 +- .../src/equivalence/properties.rs | 15 +- .../physical-expr/src/expressions/case.rs | 16 +- datafusion/physical-expr/src/tree_node.rs | 3 +- datafusion/physical-expr/src/utils/mod.rs | 8 +- .../src/joins/stream_join_utils.rs | 6 +- datafusion/physical-plan/src/joins/utils.rs | 6 +- datafusion/physical-plan/src/lib.rs | 2 +- datafusion/physical-plan/src/tree_node.rs | 3 +- datafusion/sql/src/utils.rs | 14 +- .../sqllogictest/test_files/group_by.slt | 2 +- 47 files changed, 663 insertions(+), 710 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 3de0c5bf945c..2d653a27c47b 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -22,88 +22,52 @@ use std::sync::Arc; use crate::Result; -/// This macro is used to determine continuation after a top-down closure is applied -/// during visiting traversals. +/// This macro is used to control continuation behaviors during tree traversals +/// based on the specified direction. Depending on `$DIRECTION` and the value of +/// the given expression (`$EXPR`), which should be a variant of [`TreeNodeRecursion`], +/// the macro results in the following behavior: /// -/// If the function returns [`TreeNodeRecursion::Continue`], the normal execution of the -/// function continues. -/// If it returns [`TreeNodeRecursion::Jump`], the function returns with (propagates up) -/// [`TreeNodeRecursion::Continue`] to jump next recursion step, bypassing further -/// exploration of the current step. -/// In case of [`TreeNodeRecursion::Stop`], the function return with (propagates up) -/// [`TreeNodeRecursion::Stop`] and recursion halts. +/// - If the expression returns [`TreeNodeRecursion::Continue`], normal execution +/// continues. 
+/// - If it returns [`TreeNodeRecursion::Stop`], recursion halts and propagates +/// [`TreeNodeRecursion::Stop`]. +/// - If it returns [`TreeNodeRecursion::Jump`], the continuation behavior depends +/// on the traversal direction: +/// - For `UP` direction, the function returns with [`TreeNodeRecursion::Jump`], +/// bypassing further bottom-up closures until the next top-down closure. +/// - For `DOWN` direction, the function returns with [`TreeNodeRecursion::Continue`], +/// skipping further exploration. +/// - If no direction is specified, `Jump` is treated like `Continue`. #[macro_export] -macro_rules! handle_visit_recursion_down { - ($EXPR:expr) => { - match $EXPR { - TreeNodeRecursion::Continue => {} - TreeNodeRecursion::Jump => return Ok(TreeNodeRecursion::Continue), - TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop), - } - }; -} - -/// This macro is used to determine continuation between visiting siblings during visiting -/// traversals. -/// -/// If the function returns [`TreeNodeRecursion::Continue`] or -/// [`TreeNodeRecursion::Jump`], the normal execution of the function continues. -/// In case of [`TreeNodeRecursion::Stop`], the function return with (propagates up) -/// [`TreeNodeRecursion::Stop`] and recursion halts. macro_rules! handle_visit_recursion { - ($EXPR:expr) => { - match $EXPR { - TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {} - TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop), - } + // Internal helper macro for handling the `Jump` case based on the direction: + (@handle_jump UP) => { + return Ok(TreeNodeRecursion::Jump) + }; + (@handle_jump DOWN) => { + return Ok(TreeNodeRecursion::Continue) + }; + (@handle_jump) => { + {} // Treat `Jump` like `Continue`, do nothing and continue execution. }; -} -/// This macro is used to determine continuation before a bottom-up closure is applied -/// during visiting traversals. 
-/// -/// If the function returns [`TreeNodeRecursion::Continue`], the normal execution of the -/// function continues. -/// If it returns [`TreeNodeRecursion::Jump`], the function returns with (propagates up) -/// [`TreeNodeRecursion::Jump`], bypassing further bottom-up closures until a top-down -/// closure is found. -/// In case of [`TreeNodeRecursion::Stop`], the function return with (propagates up) -/// [`TreeNodeRecursion::Stop`] and recursion halts. -#[macro_export] -macro_rules! handle_visit_recursion_up { - ($EXPR:expr) => { + // Main macro logic with variables to handle directionality. + ($EXPR:expr $(, $DIRECTION:ident)?) => { match $EXPR { TreeNodeRecursion::Continue => {} - TreeNodeRecursion::Jump => return Ok(TreeNodeRecursion::Jump), + TreeNodeRecursion::Jump => handle_visit_recursion!(@handle_jump $($DIRECTION)?), TreeNodeRecursion::Stop => return Ok(TreeNodeRecursion::Stop), } }; } -/// This macro is used to determine continuation during top-down transforming traversals. -/// -/// After the bottom-up closure returns with [`Transformed`] depending on the returned -/// [`TreeNodeRecursion`], [`Transformed::try_transform_node_with()`] decides about recursion -/// continuation and [`TreeNodeRecursion`] state propagation. -#[macro_export] -macro_rules! handle_transform_recursion_down { - ($F_DOWN:expr, $F_SELF:expr) => { - $F_DOWN?.try_transform_node_with( - |n| n.map_children($F_SELF), - TreeNodeRecursion::Continue, - ) - }; -} - -/// This macro is used to determine continuation during combined transforming traversals. +/// This macro is used to determine continuation during combined transforming +/// traversals. /// -/// After the bottom-up closure returns with [`Transformed`] depending on the returned -/// [`TreeNodeRecursion`], [`Transformed::try_transform_node_with()`] decides about recursion -/// continuation and if [`TreeNodeRecursion`] state propagation is needed. 
-/// And then after recursing into children returns with [`Transformed`] depending on the -/// returned [`TreeNodeRecursion`], [`Transformed::try_transform_node_with()`] decides about recursion -/// continuation and [`TreeNodeRecursion`] state propagation. -#[macro_export] +/// Depending on the [`TreeNodeRecursion`] the bottom-up closure returns, +/// [`Transformed::try_transform_node_with()`] decides recursion continuation +/// and if state propagation is necessary. Then, the same procedure recursively +/// applies to the children of the node in question. macro_rules! handle_transform_recursion { ($F_DOWN:expr, $F_SELF:expr, $F_UP:expr) => {{ let pre_visited = $F_DOWN?; @@ -112,35 +76,20 @@ macro_rules! handle_transform_recursion { .data .map_children($F_SELF)? .try_transform_node_with($F_UP, TreeNodeRecursion::Jump), - TreeNodeRecursion::Jump => - { - #[allow(clippy::redundant_closure_call)] - $F_UP(pre_visited.data) - } + #[allow(clippy::redundant_closure_call)] + TreeNodeRecursion::Jump => $F_UP(pre_visited.data), TreeNodeRecursion::Stop => return Ok(pre_visited), } - .map(|post_visited| post_visited.update_transformed(pre_visited.transformed)) + .map(|mut post_visited| { + post_visited.transformed |= pre_visited.transformed; + post_visited + }) }}; } -/// This macro is used to determine continuation during bottom-up transforming traversals. -/// -/// After recursing into children returns with [`Transformed`] depending on the returned -/// [`TreeNodeRecursion`], [`Transformed::try_transform_node_with()`] decides about recursion -/// continuation and [`TreeNodeRecursion`] state propagation. -#[macro_export] -macro_rules! handle_transform_recursion_up { - ($NODE:expr, $F_SELF:expr, $F_UP:expr) => { - $NODE - .map_children($F_SELF)? - .try_transform_node_with($F_UP, TreeNodeRecursion::Jump) - }; -} - -/// Defines a visitable and rewriteable a tree node. 
This trait is -/// implemented for plans ([`ExecutionPlan`] and [`LogicalPlan`]) as -/// well as expression trees ([`PhysicalExpr`], [`Expr`]) in -/// DataFusion +/// Defines a visitable and rewriteable tree node. This trait is implemented +/// for plans ([`ExecutionPlan`] and [`LogicalPlan`]) as well as expression +/// trees ([`PhysicalExpr`], [`Expr`]) in DataFusion. /// /// /// [`ExecutionPlan`]: https://docs.rs/datafusion/latest/datafusion/physical_plan/trait.ExecutionPlan.html @@ -148,17 +97,17 @@ macro_rules! handle_transform_recursion_up { /// [`LogicalPlan`]: https://docs.rs/datafusion-expr/latest/datafusion_expr/logical_plan/enum.LogicalPlan.html /// [`Expr`]: https://docs.rs/datafusion-expr/latest/datafusion_expr/expr/enum.Expr.html pub trait TreeNode: Sized { - /// Visit the tree node using the given [TreeNodeVisitor] - /// It performs a depth first walk of an node and its children. + /// Visit the tree node using the given [`TreeNodeVisitor`], performing a + /// depth-first walk of the node and its children. /// - /// For an node tree such as + /// Consider the following tree structure: /// ```text /// ParentNode /// left: ChildNode1 /// right: ChildNode2 /// ``` /// - /// The nodes are visited using the following order + /// Here, the nodes would be visited using the following order: /// ```text /// TreeNodeVisitor::f_down(ParentNode) /// TreeNodeVisitor::f_down(ChildNode1) @@ -168,21 +117,22 @@ pub trait TreeNode: Sized { /// TreeNodeVisitor::f_up(ParentNode) /// ``` /// - /// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled. + /// See [`TreeNodeRecursion`] for more details on controlling the traversal. /// /// If [`TreeNodeVisitor::f_down()`] or [`TreeNodeVisitor::f_up()`] returns [`Err`], - /// recursion is stopped immediately. + /// the recursion stops immediately. /// - /// If using the default [`TreeNodeVisitor::f_up`] that does - /// nothing, [`Self::apply`] should be preferred. 
+ /// If using the default [`TreeNodeVisitor::f_up`] that does nothing, consider using + /// [`Self::apply`]. fn visit>( &self, visitor: &mut V, ) -> Result { match visitor.f_down(self)? { TreeNodeRecursion::Continue => { - handle_visit_recursion_up!( - self.apply_children(&mut |n| n.visit(visitor))? + handle_visit_recursion!( + self.apply_children(&mut |n| n.visit(visitor))?, + UP ); visitor.f_up(self) } @@ -194,14 +144,14 @@ pub trait TreeNode: Sized { /// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for /// recursively transforming [`TreeNode`]s. /// - /// E.g. for an tree such as: + /// Consider the following tree structure: /// ```text /// ParentNode /// left: ChildNode1 /// right: ChildNode2 /// ``` /// - /// The nodes are visited using the following order: + /// Here, the nodes would be visited using the following order: /// ```text /// TreeNodeRewriter::f_down(ParentNode) /// TreeNodeRewriter::f_down(ChildNode1) @@ -211,10 +161,10 @@ pub trait TreeNode: Sized { /// TreeNodeRewriter::f_up(ParentNode) /// ``` /// - /// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled. + /// See [`TreeNodeRecursion`] for more details on controlling the traversal. /// - /// If [`TreeNodeRewriter::f_down()`] or [`TreeNodeRewriter::f_up()`] returns [`Err`], - /// recursion is stopped immediately. + /// If [`TreeNodeRewriter::f_down()`] or [`TreeNodeRewriter::f_up()`] returns [`Err`], + /// the recursion stops immediately. fn rewrite>( self, rewriter: &mut R, @@ -224,79 +174,90 @@ pub trait TreeNode: Sized { }) } - /// Applies `f` to the node and its children. `f` is applied in a preoder way, - /// and it is controlled by [`TreeNodeRecursion`], which means result of the `f` - /// on the self node can cause an early return. + /// Applies `f` to the node and its children.
`f` is applied in a pre-order + /// way, and it is controlled by [`TreeNodeRecursion`], which means result + /// of the `f` on a node can cause an early return. /// - /// The `f` closure can be used to collect some info from the - /// tree node or do some checking for the tree node. + /// The `f` closure can be used to collect some information from tree nodes + /// or run a check on the tree. fn apply Result>( &self, f: &mut F, ) -> Result { - handle_visit_recursion_down!(f(self)?); + handle_visit_recursion!(f(self)?, DOWN); self.apply_children(&mut |n| n.apply(f)) } - /// Convenience utils for writing optimizers rule: recursively apply the given `f` to the node tree. - /// When `f` does not apply to a given node, it is left unchanged. - /// The default tree traversal direction is transform_up(Postorder Traversal). - fn transform(self, f: &F) -> Result> - where - F: Fn(Self) -> Result>, - { + /// Convenience utility for writing optimizer rules: Recursively apply the + /// given function `f` to the tree in a bottom-up (post-order) fashion. When + /// `f` does not apply to a given node, it is left unchanged. + fn transform Result>>( + self, + f: &F, + ) -> Result> { self.transform_up(f) } - /// Convenience utils for writing optimizers rule: recursively apply the given 'f' to the node and all of its - /// children(Preorder Traversal). - /// When the `f` does not apply to a given node, it is left unchanged. - fn transform_down(self, f: &F) -> Result> - where - F: Fn(Self) -> Result>, - { - handle_transform_recursion_down!(f(self), |c| c.transform_down(f)) + /// Convenience utility for writing optimizer rules: Recursively apply the + /// given function `f` to a node and then to its children (pre-order traversal). + /// When `f` does not apply to a given node, it is left unchanged. 
+ fn transform_down Result>>( + self, + f: &F, + ) -> Result> { + f(self)?.try_transform_node_with( + |n| n.map_children(|c| c.transform_down(f)), + TreeNodeRecursion::Continue, + ) } - /// Convenience utils for writing optimizers rule: recursively apply the given 'f' to the node and all of its - /// children(Preorder Traversal) using a mutable function, `F`. - /// When the `f` does not apply to a given node, it is left unchanged. - fn transform_down_mut(self, f: &mut F) -> Result> - where - F: FnMut(Self) -> Result>, - { - handle_transform_recursion_down!(f(self), |c| c.transform_down_mut(f)) + /// Convenience utility for writing optimizer rules: Recursively apply the + /// given mutable function `f` to a node and then to its children (pre-order + /// traversal). When `f` does not apply to a given node, it is left unchanged. + fn transform_down_mut Result>>( + self, + f: &mut F, + ) -> Result> { + f(self)?.try_transform_node_with( + |n| n.map_children(|c| c.transform_down_mut(f)), + TreeNodeRecursion::Continue, + ) } - /// Convenience utils for writing optimizers rule: recursively apply the given 'f' first to all of its - /// children and then itself(Postorder Traversal). - /// When the `f` does not apply to a given node, it is left unchanged. - fn transform_up(self, f: &F) -> Result> - where - F: Fn(Self) -> Result>, - { - handle_transform_recursion_up!(self, |c| c.transform_up(f), f) + /// Convenience utility for writing optimizer rules: Recursively apply the + /// given function `f` to all children of a node, and then to the node itself + /// (post-order traversal). When `f` does not apply to a given node, it is + /// left unchanged. + fn transform_up Result>>( + self, + f: &F, + ) -> Result> { + self.map_children(|c| c.transform_up(f))? 
+ .try_transform_node_with(f, TreeNodeRecursion::Jump) } - /// Convenience utils for writing optimizers rule: recursively apply the given 'f' first to all of its - /// children and then itself(Postorder Traversal) using a mutable function, `F`. - /// When the `f` does not apply to a given node, it is left unchanged. - fn transform_up_mut(self, f: &mut F) -> Result> - where - F: FnMut(Self) -> Result>, - { - handle_transform_recursion_up!(self, |c| c.transform_up_mut(f), f) + /// Convenience utility for writing optimizer rules: Recursively apply the + /// given mutable function `f` to all children of a node, and then to the + /// node itself (post-order traversal). When `f` does not apply to a given + /// node, it is left unchanged. + fn transform_up_mut Result>>( + self, + f: &mut F, + ) -> Result> { + self.map_children(|c| c.transform_up_mut(f))? + .try_transform_node_with(f, TreeNodeRecursion::Jump) } /// Transforms the tree using `f_down` while traversing the tree top-down - /// (pre-preorder) and using `f_up` while traversing the tree bottom-up (post-order). + /// (pre-order), and using `f_up` while traversing the tree bottom-up + /// (post-order). /// /// Use this method if you want to start the `f_up` process right where `f_down` jumps. /// This can make the whole process faster by reducing the number of `f_up` steps. /// If you don't need this, it's just like using `transform_down_mut` followed by /// `transform_up_mut` on the same tree. /// - /// E.g. for an tree such as: + /// Consider the following tree structure: /// ```text /// ParentNode /// left: ChildNode1 @@ -313,78 +274,77 @@ pub trait TreeNode: Sized { /// f_up(ParentNode) /// ``` /// - /// See [`TreeNodeRecursion`] for more details on how the traversal can be controlled. + /// See [`TreeNodeRecursion`] for more details on controlling the traversal. /// - /// If `f_down` or `f_up` returns [`Err`], recursion is stopped immediately. 
+ /// If `f_down` or `f_up` returns [`Err`], the recursion stops immediately. /// /// Example: /// ```text - /// | +---+ - /// | | J | - /// | +---+ - /// | | - /// | +---+ - /// TreeNodeRecursion::Continue | | I | - /// | +---+ - /// | | - /// | +---+ - /// \|/ | F | - /// ' +---+ - /// / \ ___________________ - /// When `f_down` is +---+ \ ---+ - /// applied on node "E", | E | | G | - /// it returns with "jump". +---+ +---+ - /// | | - /// +---+ +---+ - /// | C | | H | - /// +---+ +---+ - /// / \ - /// +---+ +---+ - /// | B | | D | - /// +---+ +---+ - /// | - /// +---+ - /// | A | - /// +---+ + /// | +---+ + /// | | J | + /// | +---+ + /// | | + /// | +---+ + /// TreeNodeRecursion::Continue | | I | + /// | +---+ + /// | | + /// | +---+ + /// \|/ | F | + /// ' +---+ + /// / \ ___________________ + /// When `f_down` is +---+ \ ---+ + /// applied on node "E", | E | | G | + /// it returns with "Jump". +---+ +---+ + /// | | + /// +---+ +---+ + /// | C | | H | + /// +---+ +---+ + /// / \ + /// +---+ +---+ + /// | B | | D | + /// +---+ +---+ + /// | + /// +---+ + /// | A | + /// +---+ /// - /// Instead of starting from leaf nodes, `f_up` starts from the node "E". - /// +---+ - /// | | J | - /// | +---+ - /// | | - /// | +---+ - /// | | I | - /// | +---+ - /// | | - /// / +---+ - /// / | F | - /// / +---+ - /// / / \ ______________________ - /// | +---+ . \ ---+ - /// | | E | /|\ After `f_down` jumps | G | - /// | +---+ | on node E, `f_up` +---+ - /// \------| ---/ if applied on node E. | - /// +---+ +---+ - /// | C | | H | - /// +---+ +---+ - /// / \ - /// +---+ +---+ - /// | B | | D | - /// +---+ +---+ - /// | - /// +---+ - /// | A | - /// +---+ - /// ``` - fn transform_down_up( + /// Instead of starting from leaf nodes, `f_up` starts from the node "E". + /// +---+ + /// | | J | + /// | +---+ + /// | | + /// | +---+ + /// | | I | + /// | +---+ + /// | | + /// / +---+ + /// / | F | + /// / +---+ + /// / / \ ______________________ + /// | +---+ . 
\ ---+ + /// | | E | /|\ After `f_down` jumps | G | + /// | +---+ | on node E, `f_up` +---+ + /// \------| ---/ if applied on node E. | + /// +---+ +---+ + /// | C | | H | + /// +---+ +---+ + /// / \ + /// +---+ +---+ + /// | B | | D | + /// +---+ +---+ + /// | + /// +---+ + /// | A | + /// +---+ + /// ``` + fn transform_down_up< + FD: FnMut(Self) -> Result>, + FU: FnMut(Self) -> Result>, + >( self, f_down: &mut FD, f_up: &mut FU, - ) -> Result> - where - FD: FnMut(Self) -> Result>, - FU: FnMut(Self) -> Result>, - { + ) -> Result> { handle_transform_recursion!( f_down(self), |c| c.transform_down_up(f_down, f_up), @@ -392,72 +352,60 @@ pub trait TreeNode: Sized { ) } - /// Apply the closure `F` to the node's children - fn apply_children(&self, f: &mut F) -> Result - where - F: FnMut(&Self) -> Result; + /// Apply the closure `F` to the node's children. + fn apply_children Result>( + &self, + f: &mut F, + ) -> Result; - /// Apply transform `F` to the node's children, the transform `F` might have a direction(Preorder or Postorder) - fn map_children(self, f: F) -> Result> - where - F: FnMut(Self) -> Result>; + /// Apply transform `F` to the node's children. Note that the transform `F` + /// might have a direction (pre-order or post-order). + fn map_children Result>>( + self, + f: F, + ) -> Result>; } -/// Implements the [visitor -/// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) for recursively walking [`TreeNode`]s. -/// -/// [`TreeNodeVisitor`] allows keeping the algorithms -/// separate from the code to traverse the structure of the `TreeNode` -/// tree and makes it easier to add new types of tree node and -/// algorithms. -/// -/// When passed to[`TreeNode::visit`], [`TreeNodeVisitor::f_down`] -/// and [`TreeNodeVisitor::f_up`] are invoked recursively -/// on an node tree. +/// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) +/// for recursively walking [`TreeNode`]s. 
/// -/// If an [`Err`] result is returned, recursion is stopped -/// immediately. +/// A [`TreeNodeVisitor`] allows one to express algorithms separately from the +/// code traversing the structure of the `TreeNode` tree, making it easier to +/// add new types of tree nodes and algorithms. /// -/// If [`TreeNodeRecursion::Stop`] is returned on a call to pre_visit, no -/// children of that tree node are visited, nor is post_visit -/// called on that tree node -/// -/// If [`TreeNodeRecursion::Stop`] is returned on a call to post_visit, no -/// siblings of that tree node are visited, nor is post_visit -/// called on its parent tree node -/// -/// If [`TreeNodeRecursion::Jump`] is returned on a call to pre_visit, no -/// children of that tree node are visited. +/// When passed to [`TreeNode::visit`], [`TreeNodeVisitor::f_down`] and +/// [`TreeNodeVisitor::f_up`] are invoked recursively on the tree. +/// See [`TreeNodeRecursion`] for more details on controlling the traversal. pub trait TreeNodeVisitor: Sized { /// The node type which is visitable. type Node: TreeNode; /// Invoked before any children of `node` are visited. - /// Default implementation returns the node unmodified and continues recursion. + /// Default implementation simply continues the recursion. fn f_down(&mut self, _node: &Self::Node) -> Result { Ok(TreeNodeRecursion::Continue) } - /// Invoked after all children of `node` are visited. Default - /// implementation does nothing. + /// Invoked after all children of `node` are visited. + /// Default implementation simply continues the recursion. fn f_up(&mut self, _node: &Self::Node) -> Result { Ok(TreeNodeRecursion::Continue) } } -/// Trait for potentially recursively transform a [`TreeNode`] node tree. +/// Trait for potentially recursively transforming a tree of [`TreeNode`]s. pub trait TreeNodeRewriter: Sized { /// The node type which is rewritable. type Node: TreeNode; /// Invoked while traversing down the tree before any children are rewritten. 
- /// Default implementation returns the node unmodified and continues recursion. + /// Default implementation returns the node as is and continues recursion. fn f_down(&mut self, node: Self::Node) -> Result> { Ok(Transformed::no(node)) } /// Invoked while traversing up the tree after all children have been rewritten. - /// Default implementation returns the node unmodified. + /// Default implementation returns the node as is and continues recursion. fn f_up(&mut self, node: Self::Node) -> Result> { Ok(Transformed::no(node)) } @@ -468,37 +416,42 @@ pub trait TreeNodeRewriter: Sized { pub enum TreeNodeRecursion { /// Continue recursion with the next node. Continue, - - /// In top-down traversals, skip recursing into children but continue with the next - /// node, which actually means pruning of the subtree. + /// In top-down traversals, skip recursing into children but continue with + /// the next node, which actually means pruning of the subtree. /// - /// In bottom-up traversals, bypass calling bottom-up closures till the next leaf node. + /// In bottom-up traversals, bypass calling bottom-up closures till the next + /// leaf node. /// - /// In combined traversals, if it is "f_down" (pre-order) phase, execution "jumps" to - /// next "f_up" (post_order) phase by shortcutting its children. If it is "f_up" (pre-order) - /// phase, execution "jumps" to next "f_down" (pre_order) phase by shortcutting its parent - /// nodes until the first parent node having unvisited children path. + /// In combined traversals, if it is the `f_down` (pre-order) phase, execution + /// "jumps" to the next `f_up` (post-order) phase by shortcutting its children. + /// If it is the `f_up` (post-order) phase, execution "jumps" to the next `f_down` + /// (pre-order) phase by shortcutting its parent nodes until the first parent node + /// having unvisited children path. Jump, - /// Stop recursion. 
Stop, } -/// This struct is used with [`TreeNode::rewrite`], [`TreeNode::transform_down`], -/// [`TreeNode::transform_down_mut`], [`TreeNode::transform_up`], -/// [`TreeNode::transform_up_mut`] and [`TreeNode::transform_down_up`] methods to control -/// transformations and return the transformed result. +/// This struct is used by tree transformation APIs such as +/// - [`TreeNode::rewrite`], +/// - [`TreeNode::transform_down`], +/// - [`TreeNode::transform_down_mut`], +/// - [`TreeNode::transform_up`], +/// - [`TreeNode::transform_up_mut`], +/// - [`TreeNode::transform_down_up`] /// -/// API users can provide transformation closures and [`TreeNodeRewriter`] -/// implementations to control transformation by specifying: -/// - the possibly transformed node, -/// - if any change was made to the node and -/// - how to proceed with the recursion. +/// to control the transformation and return the transformed result. /// -/// The APIs return this struct with the: -/// - final possibly transformed tree, -/// - if any change was made to any node and -/// - how the recursion ended. +/// Specifically, API users can provide transformation closures or [`TreeNodeRewriter`] +/// implementations to control the transformation by returning: +/// - The resulting (possibly transformed) node, +/// - A flag indicating whether any change was made to the node, and +/// - A flag specifying how to proceed with the recursion. +/// +/// At the end of the transformation, the return value will contain: +/// - The final (possibly transformed) tree, +/// - A flag indicating whether any change was made to the tree, and +/// - A flag specifying how the recursion ended. #[derive(PartialEq, Debug)] pub struct Transformed { pub data: T, @@ -507,6 +460,7 @@ pub struct Transformed { } impl Transformed { + /// Create a new `Transformed` object with the given information. 
pub fn new(data: T, transformed: bool, tnr: TreeNodeRecursion) -> Self { Self { data, @@ -517,115 +471,99 @@ impl Transformed { /// Wrapper for transformed data with [`TreeNodeRecursion::Continue`] statement. pub fn yes(data: T) -> Self { - Self { - data, - transformed: true, - tnr: TreeNodeRecursion::Continue, - } + Self::new(data, true, TreeNodeRecursion::Continue) } - /// Wrapper for non-transformed data with [`TreeNodeRecursion::Continue`] statement. + /// Wrapper for unchanged data with [`TreeNodeRecursion::Continue`] statement. pub fn no(data: T) -> Self { - Self { - data, - transformed: false, - tnr: TreeNodeRecursion::Continue, - } + Self::new(data, false, TreeNodeRecursion::Continue) } - /// Applies the given `f` to the data of [`Transformed`] object. + /// Applies the given `f` to the data of this [`Transformed`] object. pub fn update_data U>(self, f: F) -> Transformed { Transformed::new(f(self.data), self.transformed, self.tnr) } - /// Updates the transformed state based on the current and the new state. - pub fn update_transformed(self, transformed: bool) -> Self { - Self { - transformed: self.transformed || transformed, - ..self - } - } - - /// Sets a new [`TreeNodeRecursion`]. - pub fn update_tnr(self, tnr: TreeNodeRecursion) -> Self { - Self { tnr, ..self } - } - /// Maps the data of [`Transformed`] object to the result of the given `f`. pub fn map_data Result>(self, f: F) -> Result> { f(self.data).map(|data| Transformed::new(data, self.transformed, self.tnr)) } - /// Handling [`TreeNodeRecursion::Continue`] and [`TreeNodeRecursion::Stop`] is - /// straightforward, but [`TreeNodeRecursion::Jump`] can behave differently when we - /// are traversing down or up on a tree. - /// If [`TreeNodeRecursion`] of the node is [`TreeNodeRecursion::Jump`] recursion is - /// stopped with the given `return_if_jump` [`TreeNodeRecursion`] statement. 
+ /// Handling [`TreeNodeRecursion::Continue`] and [`TreeNodeRecursion::Stop`] + /// is straightforward, but [`TreeNodeRecursion::Jump`] can behave differently + /// when we are traversing down or up on a tree. If [`TreeNodeRecursion`] of + /// the node is [`TreeNodeRecursion::Jump`], recursion stops with the given + /// `return_if_jump` value. fn try_transform_node_with Result>>( - self, + mut self, f: F, return_if_jump: TreeNodeRecursion, ) -> Result> { match self.tnr { TreeNodeRecursion::Continue => { - f(self.data).map(|t| t.update_transformed(self.transformed)) + return f(self.data).map(|mut t| { + t.transformed |= self.transformed; + t + }); } - TreeNodeRecursion::Jump => Ok(self.update_tnr(return_if_jump)), - TreeNodeRecursion::Stop => Ok(self), + TreeNodeRecursion::Jump => { + self.tnr = return_if_jump; + } + TreeNodeRecursion::Stop => {} } + Ok(self) } /// If [`TreeNodeRecursion`] of the node is [`TreeNodeRecursion::Continue`] or - /// [`TreeNodeRecursion::Jump`], transformation is applied to the node. Otherwise, it - /// remains as it is. + /// [`TreeNodeRecursion::Jump`], transformation is applied to the node. + /// Otherwise, it remains as it is. pub fn try_transform_node Result>>( self, f: F, ) -> Result> { - match self.tnr { - TreeNodeRecursion::Continue => {} - TreeNodeRecursion::Jump => {} - TreeNodeRecursion::Stop => return Ok(self), - }; - f(self.data).map(|t| t.update_transformed(self.transformed)) + if self.tnr == TreeNodeRecursion::Stop { + Ok(self) + } else { + f(self.data).map(|mut t| { + t.transformed |= self.transformed; + t + }) + } } } /// Transformation helper to process tree nodes that are siblings. 
pub trait TransformedIterator: Iterator { - fn map_until_stop_and_collect(self, f: F) -> Result>> - where - F: FnMut(Self::Item) -> Result>; + fn map_until_stop_and_collect< + F: FnMut(Self::Item) -> Result>, + >( + self, + f: F, + ) -> Result>>; } impl TransformedIterator for I { - fn map_until_stop_and_collect( + fn map_until_stop_and_collect< + F: FnMut(Self::Item) -> Result>, + >( self, mut f: F, - ) -> Result>> - where - F: FnMut(Self::Item) -> Result>, - { - let mut new_tnr = TreeNodeRecursion::Continue; - let mut new_transformed = false; - let new_data = self - .map(|i| { - Ok(match new_tnr { - TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => { - let Transformed { - data, - transformed, - tnr, - } = f(i)?; - new_tnr = tnr; - new_transformed |= transformed; - data - } - TreeNodeRecursion::Stop => i, - }) + ) -> Result>> { + let mut tnr = TreeNodeRecursion::Continue; + let mut transformed = false; + let data = self + .map(|item| match tnr { + TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => { + f(item).map(|result| { + tnr = result.tnr; + transformed |= result.transformed; + result.data + }) + } + TreeNodeRecursion::Stop => Ok(item), }) .collect::>>()?; - Ok(Transformed::new(new_data, new_transformed, new_tnr)) + Ok(Transformed::new(data, transformed, tnr)) } } @@ -652,15 +590,14 @@ impl TransformedResult for Result> { } } -/// Helper trait for implementing [`TreeNode`] that have children stored as Arc's -/// -/// If some trait object, such as `dyn T`, implements this trait, -/// its related `Arc` will automatically implement [`TreeNode`] +/// Helper trait for implementing [`TreeNode`] that have children stored as +/// `Arc`s. If some trait object, such as `dyn T`, implements this trait, +/// its related `Arc` will automatically implement [`TreeNode`]. pub trait DynTreeNode { - /// Returns all children of the specified TreeNode + /// Returns all children of the specified `TreeNode`. 
fn arc_children(&self) -> Vec>; - /// construct a new self with the specified children + /// Constructs a new node with the specified children. fn with_new_arc_children( &self, arc_self: Arc, @@ -668,14 +605,13 @@ pub trait DynTreeNode { ) -> Result>; } -/// Blanket implementation for Arc for any tye that implements -/// [`DynTreeNode`] (such as [`Arc`]) +/// Blanket implementation for any `Arc` where `T` implements [`DynTreeNode`] +/// (such as [`Arc`]). impl TreeNode for Arc { - /// Apply the closure `F` to the node's children - fn apply_children(&self, f: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { + fn apply_children Result>( + &self, + f: &mut F, + ) -> Result { let mut tnr = TreeNodeRecursion::Continue; for child in self.arc_children() { tnr = f(&child)?; @@ -684,10 +620,10 @@ impl TreeNode for Arc { Ok(tnr) } - fn map_children(self, f: F) -> Result> - where - F: FnMut(Self) -> Result>, - { + fn map_children Result>>( + self, + f: F, + ) -> Result> { let children = self.arc_children(); if !children.is_empty() { let new_children = children.into_iter().map_until_stop_and_collect(f)?; @@ -722,11 +658,10 @@ pub trait ConcreteTreeNode: Sized { } impl TreeNode for T { - /// Apply the closure `F` to the node's children - fn apply_children(&self, f: &mut F) -> Result - where - F: FnMut(&Self) -> Result, - { + fn apply_children Result>( + &self, + f: &mut F, + ) -> Result { let mut tnr = TreeNodeRecursion::Continue; for child in self.children() { tnr = f(child)?; @@ -735,10 +670,10 @@ impl TreeNode for T { Ok(tnr) } - fn map_children(self, f: F) -> Result> - where - F: FnMut(Self) -> Result>, - { + fn map_children Result>>( + self, + f: F, + ) -> Result> { let (new_self, children) = self.take_children(); if !children.is_empty() { let new_children = children.into_iter().map_until_stop_and_collect(f)?; @@ -753,12 +688,13 @@ impl TreeNode for T { #[cfg(test)] mod tests { + use std::fmt::Display; + use crate::tree_node::{ Transformed, TransformedIterator, 
TreeNode, TreeNodeRecursion, TreeNodeRewriter, TreeNodeVisitor, }; use crate::Result; - use std::fmt::Display; #[derive(PartialEq, Debug)] struct TestTreeNode { diff --git a/datafusion/core/src/datasource/listing/helpers.rs b/datafusion/core/src/datasource/listing/helpers.rs index 78955104c72a..eef25792d00a 100644 --- a/datafusion/core/src/datasource/listing/helpers.rs +++ b/datafusion/core/src/datasource/listing/helpers.rs @@ -19,29 +19,27 @@ use std::sync::Arc; -use arrow::compute::{and, cast, prep_null_mask_filter}; +use super::PartitionedFile; +use crate::datasource::listing::ListingTableUrl; +use crate::execution::context::SessionState; +use crate::{error::Result, scalar::ScalarValue}; + use arrow::{ - array::{ArrayRef, StringBuilder}, + array::{Array, ArrayRef, AsArray, StringBuilder}, + compute::{and, cast, prep_null_mask_filter}, datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; -use arrow_array::cast::AsArray; -use arrow_array::Array; use arrow_schema::Fields; -use futures::stream::FuturesUnordered; -use futures::{stream::BoxStream, StreamExt, TryStreamExt}; -use log::{debug, trace}; - -use crate::{error::Result, scalar::ScalarValue}; - -use super::PartitionedFile; -use crate::datasource::listing::ListingTableUrl; -use crate::execution::context::SessionState; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{internal_err, Column, DFField, DFSchema, DataFusionError}; use datafusion_expr::{Expr, ScalarFunctionDefinition, Volatility}; use datafusion_physical_expr::create_physical_expr; use datafusion_physical_expr::execution_props::ExecutionProps; + +use futures::stream::{BoxStream, FuturesUnordered}; +use futures::{StreamExt, TryStreamExt}; +use log::{debug, trace}; use object_store::path::Path; use object_store::{ObjectMeta, ObjectStore}; diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs index 
06687dde6baf..c0e37a7150d9 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_filter.rs @@ -15,28 +15,28 @@ // specific language governing permissions and limitations // under the License. +use std::collections::BTreeSet; +use std::sync::Arc; + +use super::ParquetFileMetrics; +use crate::physical_plan::metrics; + use arrow::array::BooleanArray; use arrow::datatypes::{DataType, Schema}; use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; use datafusion_common::cast::as_boolean_array; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; use datafusion_common::{arrow_err, DataFusionError, Result, ScalarValue}; use datafusion_physical_expr::expressions::{Column, Literal}; use datafusion_physical_expr::utils::reassign_predicate_columns; -use std::collections::BTreeSet; - use datafusion_physical_expr::{split_conjunction, PhysicalExpr}; + use parquet::arrow::arrow_reader::{ArrowPredicate, RowFilter}; use parquet::arrow::ProjectionMask; use parquet::file::metadata::ParquetMetaData; -use std::sync::Arc; - -use crate::physical_plan::metrics; - -use super::ParquetFileMetrics; /// This module contains utilities for enabling the pushdown of DataFusion filter predicates (which /// can be any DataFusion `Expr` that evaluates to a `BooleanArray`) to the parquet decoder level in `arrow-rs`. 
@@ -190,8 +190,7 @@ impl<'a> FilterCandidateBuilder<'a> { mut self, metadata: &ParquetMetaData, ) -> Result> { - let expr = self.expr.clone(); - let expr = expr.rewrite(&mut self)?.data; + let expr = self.expr.clone().rewrite(&mut self).data()?; if self.non_primitive_columns || self.projected_columns { Ok(None) diff --git a/datafusion/core/src/physical_optimizer/coalesce_batches.rs b/datafusion/core/src/physical_optimizer/coalesce_batches.rs index 01213ed8df1a..7c0082037da0 100644 --- a/datafusion/core/src/physical_optimizer/coalesce_batches.rs +++ b/datafusion/core/src/physical_optimizer/coalesce_batches.rs @@ -18,8 +18,10 @@ //! CoalesceBatches optimizer that groups batches together rows //! in bigger batches to avoid overhead with small batches -use crate::config::ConfigOptions; +use std::sync::Arc; + use crate::{ + config::ConfigOptions, error::Result, physical_optimizer::PhysicalOptimizerRule, physical_plan::{ @@ -27,8 +29,8 @@ use crate::{ repartition::RepartitionExec, Partitioning, }, }; + use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use std::sync::Arc; /// Optimizer rule that introduces CoalesceBatchesExec to avoid overhead with small batches that /// are produced by highly selective filters diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs b/datafusion/core/src/physical_optimizer/enforce_distribution.rs index 8b6610e871f7..822cd0541ae2 100644 --- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs +++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs @@ -45,7 +45,7 @@ use crate::physical_plan::windows::WindowAggExec; use crate::physical_plan::{Distribution, ExecutionPlan, Partitioning}; use arrow::compute::SortOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_expr::logical_plan::JoinType; use datafusion_physical_expr::expressions::{Column, NoOp}; use 
datafusion_physical_expr::utils::map_columns_before_projection; @@ -198,15 +198,15 @@ impl PhysicalOptimizerRule for EnforceDistribution { // Run a top-down process to adjust input key ordering recursively let plan_requirements = PlanWithKeyRequirements::new_default(plan); let adjusted = plan_requirements - .transform_down(&adjust_input_keys_ordering)? - .data; + .transform_down(&adjust_input_keys_ordering) + .data()?; adjusted.plan } else { // Run a bottom-up process plan.transform_up(&|plan| { Ok(Transformed::yes(reorder_join_keys_to_inputs(plan)?)) - })? - .data + }) + .data()? }; let distribution_context = DistributionContext::new_default(adjusted); @@ -214,8 +214,8 @@ impl PhysicalOptimizerRule for EnforceDistribution { let distribution_context = distribution_context .transform_up(&|distribution_context| { ensure_distribution(distribution_context, config) - })? - .data; + }) + .data()?; Ok(distribution_context.plan) } @@ -1788,7 +1788,8 @@ pub(crate) mod tests { let plan_requirements = PlanWithKeyRequirements::new_default($PLAN.clone()); let adjusted = plan_requirements - .transform_down(&adjust_input_keys_ordering).data() + .transform_down(&adjust_input_keys_ordering) + .data() .and_then(check_integrity)?; // TODO: End state payloads will be checked here. adjusted.plan @@ -1796,14 +1797,16 @@ pub(crate) mod tests { // Run reorder_join_keys_to_inputs rule $PLAN.clone().transform_up(&|plan| { Ok(Transformed::yes(reorder_join_keys_to_inputs(plan)?)) - })?.data + }) + .data()? }; // Then run ensure_distribution rule DistributionContext::new_default(adjusted) .transform_up(&|distribution_context| { ensure_distribution(distribution_context, &config) - }).data() + }) + .data() .and_then(check_integrity)?; // TODO: End state payloads will be checked here. 
} diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index d31bc9c452b9..79dd5758cc2f 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -165,8 +165,8 @@ impl PhysicalOptimizerRule for EnforceSorting { let plan_with_coalesce_partitions = PlanWithCorrespondingCoalescePartitions::new_default(adjusted.plan); let parallel = plan_with_coalesce_partitions - .transform_up(&parallelize_sorts)? - .data; + .transform_up(&parallelize_sorts) + .data()?; parallel.plan } else { adjusted.plan @@ -181,8 +181,8 @@ impl PhysicalOptimizerRule for EnforceSorting { true, config, ) - })? - .data; + }) + .data()?; // Execute a top-down traversal to exploit sort push-down opportunities // missed by the bottom-up traversal: @@ -687,7 +687,8 @@ mod tests { { let plan_requirements = PlanWithCorrespondingSort::new_default($PLAN.clone()); let adjusted = plan_requirements - .transform_up(&ensure_sorting).data() + .transform_up(&ensure_sorting) + .data() .and_then(check_integrity)?; // TODO: End state payloads will be checked here. @@ -695,7 +696,8 @@ mod tests { let plan_with_coalesce_partitions = PlanWithCorrespondingCoalescePartitions::new_default(adjusted.plan); let parallel = plan_with_coalesce_partitions - .transform_up(&parallelize_sorts).data() + .transform_up(&parallelize_sorts) + .data() .and_then(check_integrity)?; // TODO: End state payloads will be checked here. parallel.plan @@ -712,14 +714,16 @@ mod tests { true, state.config_options(), ) - }).data() + }) + .data() .and_then(check_integrity)?; // TODO: End state payloads will be checked here.
let mut sort_pushdown = SortPushDown::new_default(updated_plan.plan); assign_initial_requirements(&mut sort_pushdown); sort_pushdown - .transform_down(&pushdown_sorts).data() + .transform_down(&pushdown_sorts) + .data() .and_then(check_integrity)?; // TODO: End state payloads will be checked here. } diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs b/datafusion/core/src/physical_optimizer/join_selection.rs index cd710ce46990..47ca2028fd86 100644 --- a/datafusion/core/src/physical_optimizer/join_selection.rs +++ b/datafusion/core/src/physical_optimizer/join_selection.rs @@ -237,8 +237,8 @@ impl PhysicalOptimizerRule for JoinSelection { Box::new(hash_join_swap_subrule), ]; let new_plan = plan - .transform_up(&|p| apply_subrules(p, &subrules, config))? - .data; + .transform_up(&|p| apply_subrules(p, &subrules, config)) + .data()?; // Next, we apply another subrule that tries to optimize joins using any // statistics their inputs might have. // - For a hash join with partition mode [`PartitionMode::Auto`], we will @@ -813,8 +813,8 @@ mod tests_statistical { Box::new(hash_join_swap_subrule), ]; let new_plan = plan - .transform_up(&|p| apply_subrules(p, &subrules, &ConfigOptions::new()))? - .data; + .transform_up(&|p| apply_subrules(p, &subrules, &ConfigOptions::new())) + .data()?; // TODO: End state payloads will be checked here. 
let config = ConfigOptions::new().optimizer; let collect_left_threshold = config.hash_join_single_partition_threshold; diff --git a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs index 8b14bf067d3c..9509d4e4c828 100644 --- a/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs +++ b/datafusion/core/src/physical_optimizer/limited_distinct_aggregation.rs @@ -162,7 +162,7 @@ impl PhysicalOptimizerRule for LimitedDistinctAggregation { plan: Arc, config: &ConfigOptions, ) -> Result> { - let plan = if config.optimizer.enable_distinct_aggregation_soft_limit { + if config.optimizer.enable_distinct_aggregation_soft_limit { plan.transform_down(&|plan| { Ok( if let Some(plan) = @@ -173,12 +173,11 @@ impl PhysicalOptimizerRule for LimitedDistinctAggregation { Transformed::no(plan) }, ) - })? - .data + }) + .data() } else { - plan - }; - Ok(plan) + Ok(plan) + } } fn name(&self) -> &str { diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index 5fe0d46b8043..17d30a2b4ec1 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -112,9 +112,9 @@ pub fn remove_unnecessary_projections( let maybe_unified = try_unifying_projections(projection, child_projection)?; return if let Some(new_plan) = maybe_unified { // To unify 3 or more sequential projections: - Ok(Transformed::yes( - remove_unnecessary_projections(new_plan)?.data, - )) + remove_unnecessary_projections(new_plan) + .data() + .map(Transformed::yes) } else { Ok(Transformed::no(plan)) }; diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index b3dd8294d507..05d2d852e057 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ 
b/datafusion/core/src/physical_optimizer/pruning.rs @@ -29,21 +29,22 @@ use crate::{ logical_expr::Operator, physical_plan::{ColumnarValue, PhysicalExpr}, }; -use arrow::record_batch::RecordBatchOptions; + use arrow::{ array::{new_null_array, ArrayRef, BooleanArray}, datatypes::{DataType, Field, Schema, SchemaRef}, - record_batch::RecordBatch, + record_batch::{RecordBatch, RecordBatchOptions}, }; use arrow_array::cast::AsArray; use datafusion_common::tree_node::TransformedResult; use datafusion_common::{ - internal_err, plan_err, + internal_err, plan_datafusion_err, plan_err, tree_node::{Transformed, TreeNode}, + ScalarValue, }; -use datafusion_common::{plan_datafusion_err, ScalarValue}; use datafusion_physical_expr::utils::{collect_columns, Guarantee, LiteralGuarantee}; use datafusion_physical_expr::{expressions as phys_expr, PhysicalExprRef}; + use log::trace; /// A source of runtime statistical information to [`PruningPredicate`]s. diff --git a/datafusion/core/src/physical_optimizer/topk_aggregation.rs b/datafusion/core/src/physical_optimizer/topk_aggregation.rs index e85243b30ba1..c47e5e25d143 100644 --- a/datafusion/core/src/physical_optimizer/topk_aggregation.rs +++ b/datafusion/core/src/physical_optimizer/topk_aggregation.rs @@ -140,7 +140,7 @@ impl PhysicalOptimizerRule for TopKAggregation { plan: Arc, config: &ConfigOptions, ) -> Result> { - let plan = if config.optimizer.enable_topk_aggregation { + if config.optimizer.enable_topk_aggregation { plan.transform_down(&|plan| { Ok( if let Some(plan) = TopKAggregation::transform_sort(plan.clone()) { @@ -149,12 +149,11 @@ impl PhysicalOptimizerRule for TopKAggregation { Transformed::no(plan) }, ) - })? 
- .data + }) + .data() } else { - plan - }; - Ok(plan) + Ok(plan) + } } fn name(&self) -> &str { diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index c24db11763ea..59905d859dc8 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -17,19 +17,12 @@ use std::sync::Arc; -use arrow::array::{ArrayRef, Int64Array}; +use arrow::array::{Array, ArrayRef, AsArray, Int64Array}; use arrow::compute::{concat_batches, SortOptions}; use arrow::datatypes::DataType; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; -use arrow_array::cast::AsArray; use arrow_array::types::Int64Type; -use arrow_array::Array; -use hashbrown::HashMap; -use rand::rngs::StdRng; -use rand::{Rng, SeedableRng}; -use tokio::task::JoinSet; - use datafusion::common::Result; use datafusion::datasource::MemTable; use datafusion::physical_plan::aggregates::{ @@ -44,6 +37,11 @@ use datafusion_physical_expr::{AggregateExpr, PhysicalSortExpr}; use datafusion_physical_plan::InputOrderMode; use test_utils::{add_empty_batches, StringBatchGenerator}; +use hashbrown::HashMap; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use tokio::task::JoinSet; + /// Tests that streaming aggregate and batch (non streaming) aggregate produce /// same results #[tokio::test(flavor = "multi_thread")] @@ -316,6 +314,7 @@ async fn verify_ordered_aggregate(frame: &DataFrame, expected_sort: bool) { impl TreeNodeVisitor for Visitor { type Node = Arc; + fn f_down(&mut self, node: &Self::Node) -> Result { if let Some(exec) = node.as_any().downcast_ref::() { if self.expected_sort { diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 06b276fb41fd..68b123ab1f28 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -17,28 +17,27 @@ //! 
Logical Expressions: [`Expr`] +use std::collections::HashSet; +use std::fmt::{self, Display, Formatter, Write}; +use std::hash::Hash; +use std::str::FromStr; +use std::sync::Arc; + use crate::expr_fn::binary_expr; use crate::logical_plan::Subquery; use crate::utils::{expr_to_columns, find_out_reference_exprs}; use crate::window_frame; +use crate::{ + aggregate_function, built_in_function, built_in_window_function, udaf, + BuiltinScalarFunction, ExprSchemable, Operator, Signature, +}; -use crate::Operator; -use crate::{aggregate_function, ExprSchemable}; -use crate::{built_in_function, BuiltinScalarFunction}; -use crate::{built_in_window_function, udaf}; use arrow::datatypes::DataType; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::{internal_err, DFSchema, OwnedTableReference}; -use datafusion_common::{plan_err, Column, Result, ScalarValue}; +use datafusion_common::{ + internal_err, plan_err, Column, DFSchema, OwnedTableReference, Result, ScalarValue, +}; use sqlparser::ast::NullTreatment; -use std::collections::HashSet; -use std::fmt; -use std::fmt::{Display, Formatter, Write}; -use std::hash::Hash; -use std::str::FromStr; -use std::sync::Arc; - -use crate::Signature; /// `Expr` is a central struct of DataFusion's query API, and /// represent logical expressions such as `A + 1`, or `CAST(c1 AS diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 8d7a314a89fe..cd9a8344dec4 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -17,17 +17,18 @@ //! 
Expression rewriter +use std::collections::HashMap; +use std::collections::HashSet; +use std::sync::Arc; + use crate::expr::{Alias, Unnest}; use crate::logical_plan::Projection; use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder}; + use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRewriter, }; -use datafusion_common::Result; -use datafusion_common::{Column, DFSchema}; -use std::collections::HashMap; -use std::collections::HashSet; -use std::sync::Arc; +use datafusion_common::{Column, DFSchema, Result}; mod order_by; pub use order_by::rewrite_sort_cols_by_aggs; @@ -277,13 +278,15 @@ where #[cfg(test)] mod test { + use std::ops::Add; + use super::*; use crate::expr::Sort; use crate::{col, lit, Cast}; + use arrow::datatypes::DataType; use datafusion_common::tree_node::{TreeNode, TreeNodeRewriter}; use datafusion_common::{DFField, DFSchema, ScalarValue}; - use std::ops::Add; #[derive(Default)] struct RecordingRewriter { diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index 06d1dc061168..b1bc11a83f90 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -20,6 +20,7 @@ use crate::expr::{Alias, Sort}; use crate::expr_rewriter::normalize_col; use crate::{Cast, Expr, ExprSchemable, LogicalPlan, TryCast}; + use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{Column, Result}; diff --git a/datafusion/expr/src/logical_plan/display.rs b/datafusion/expr/src/logical_plan/display.rs index 389a33612d4c..e0cb44626e24 100644 --- a/datafusion/expr/src/logical_plan/display.rs +++ b/datafusion/expr/src/logical_plan/display.rs @@ -16,12 +16,14 @@ // under the License. //! 
This module provides logic for displaying LogicalPlans in various styles +use std::fmt; + use crate::LogicalPlan; + use arrow::datatypes::Schema; use datafusion_common::display::GraphvizBuilder; use datafusion_common::tree_node::{TreeNodeRecursion, TreeNodeVisitor}; use datafusion_common::DataFusionError; -use std::fmt; /// Formats plans with a single line per node. For example: /// diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 825d3f037023..ca021c4bfc28 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -664,8 +664,8 @@ impl LogicalPlan { )), _ => Ok(Transformed::no(expr)), } - })? - .data; + }) + .data()?; Filter::try_new(predicate, Arc::new(inputs.swap_remove(0))) .map(LogicalPlan::Filter) diff --git a/datafusion/expr/src/tree_node/expr.rs b/datafusion/expr/src/tree_node/expr.rs index 41192a6f29b7..67d48f986f13 100644 --- a/datafusion/expr/src/tree_node/expr.rs +++ b/datafusion/expr/src/tree_node/expr.rs @@ -27,7 +27,7 @@ use crate::{Expr, GetFieldAccess}; use datafusion_common::tree_node::{ Transformed, TransformedIterator, TreeNode, TreeNodeRecursion, }; -use datafusion_common::{handle_visit_recursion_down, internal_err, Result}; +use datafusion_common::{handle_visit_recursion, internal_err, Result}; impl TreeNode for Expr { fn apply_children Result>( @@ -134,7 +134,7 @@ impl TreeNode for Expr { let mut tnr = TreeNodeRecursion::Continue; for child in children { tnr = f(child)?; - handle_visit_recursion_down!(tnr); + handle_visit_recursion!(tnr, DOWN); } Ok(tnr) diff --git a/datafusion/expr/src/tree_node/plan.rs b/datafusion/expr/src/tree_node/plan.rs index 6b2b9d055c81..02d5d1851289 100644 --- a/datafusion/expr/src/tree_node/plan.rs +++ b/datafusion/expr/src/tree_node/plan.rs @@ -22,7 +22,7 @@ use crate::LogicalPlan; use datafusion_common::tree_node::{ Transformed, TransformedIterator, TreeNode, TreeNodeRecursion, TreeNodeVisitor, }; -use 
datafusion_common::{handle_visit_recursion_down, handle_visit_recursion_up, Result}; +use datafusion_common::{handle_visit_recursion, Result}; impl TreeNode for LogicalPlan { fn apply Result>( @@ -31,7 +31,7 @@ impl TreeNode for LogicalPlan { ) -> Result { // Compared to the default implementation, we need to invoke // [`Self::apply_subqueries`] before visiting its children - handle_visit_recursion_down!(f(self)?); + handle_visit_recursion!(f(self)?, DOWN); self.apply_subqueries(f)?; self.apply_children(&mut |n| n.apply(f)) } @@ -65,8 +65,9 @@ impl TreeNode for LogicalPlan { match visitor.f_down(self)? { TreeNodeRecursion::Continue => { self.visit_subqueries(visitor)?; - handle_visit_recursion_up!( - self.apply_children(&mut |n| n.visit(visitor))? + handle_visit_recursion!( + self.apply_children(&mut |n| n.visit(visitor))?, + UP ); visitor.f_up(self) } @@ -85,7 +86,7 @@ impl TreeNode for LogicalPlan { let mut tnr = TreeNodeRecursion::Continue; for child in self.inputs() { tnr = f(child)?; - handle_visit_recursion_down!(tnr) + handle_visit_recursion!(tnr, DOWN) } Ok(tnr) } @@ -94,8 +95,8 @@ impl TreeNode for LogicalPlan { where F: FnMut(Self) -> Result>, { - let old_children = self.inputs(); - let new_children = old_children + let new_children = self + .inputs() .iter() .map(|&c| c.clone()) .map_until_stop_and_collect(f)?; diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 99e32c0bac74..93b24d71c496 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -15,7 +15,10 @@ // specific language governing permissions and limitations // under the License. 
+use std::sync::Arc; + use crate::analyzer::AnalyzerRule; + use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRewriter, @@ -29,7 +32,6 @@ use datafusion_expr::{ aggregate_function, expr, lit, Aggregate, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Projection, Sort, Subquery, }; -use std::sync::Arc; /// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`. /// diff --git a/datafusion/optimizer/src/analyzer/inline_table_scan.rs b/datafusion/optimizer/src/analyzer/inline_table_scan.rs index ada7dca45759..b21ec851dfcd 100644 --- a/datafusion/optimizer/src/analyzer/inline_table_scan.rs +++ b/datafusion/optimizer/src/analyzer/inline_table_scan.rs @@ -20,11 +20,11 @@ use std::sync::Arc; use crate::analyzer::AnalyzerRule; + use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; -use datafusion_expr::expr::Exists; -use datafusion_expr::expr::InSubquery; +use datafusion_expr::expr::{Exists, InSubquery}; use datafusion_expr::{ logical_plan::LogicalPlan, Expr, Filter, LogicalPlanBuilder, TableScan, }; @@ -51,7 +51,7 @@ impl AnalyzerRule for InlineTableScan { } fn analyze_internal(plan: LogicalPlan) -> Result> { - Ok(match plan { + match plan { // Match only on scans without filter / projection / fetch // Views and DataFrames won't have those added // during the early stage of planning @@ -64,31 +64,29 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { }) if filters.is_empty() && source.get_logical_plan().is_some() => { let sub_plan = source.get_logical_plan().unwrap(); let projection_exprs = generate_projection_expr(&projection, sub_plan)?; - let plan = LogicalPlanBuilder::from(sub_plan.clone()) + LogicalPlanBuilder::from(sub_plan.clone()) .project(projection_exprs)? 
// Ensures that the reference to the inlined table remains the // same, meaning we don't have to change any of the parent nodes // that reference this table. .alias(table_name)? - .build()?; - Transformed::yes(plan) + .build() + .map(Transformed::yes) } LogicalPlan::Filter(filter) => { - let new_expr = filter.predicate.transform(&rewrite_subquery)?.data; - Transformed::yes(LogicalPlan::Filter(Filter::try_new( - new_expr, - filter.input, - )?)) + let new_expr = filter.predicate.transform(&rewrite_subquery).data()?; + Filter::try_new(new_expr, filter.input) + .map(|e| Transformed::yes(LogicalPlan::Filter(e))) } - _ => Transformed::no(plan), - }) + _ => Ok(Transformed::no(plan)), + } } fn rewrite_subquery(expr: Expr) -> Result> { match expr { Expr::Exists(Exists { subquery, negated }) => { let plan = subquery.subquery.as_ref().clone(); - let new_plan = plan.transform_up(&analyze_internal)?.data; + let new_plan = plan.transform_up(&analyze_internal).data()?; let subquery = subquery.with_plan(Arc::new(new_plan)); Ok(Transformed::yes(Expr::Exists(Exists { subquery, negated }))) } @@ -98,7 +96,7 @@ fn rewrite_subquery(expr: Expr) -> Result> { negated, }) => { let plan = subquery.subquery.as_ref().clone(); - let new_plan = plan.transform_up(&analyze_internal)?.data; + let new_plan = plan.transform_up(&analyze_internal).data()?; let subquery = subquery.with_plan(Arc::new(new_plan)); Ok(Transformed::yes(Expr::InSubquery(InSubquery::new( expr, subquery, negated, @@ -106,7 +104,7 @@ fn rewrite_subquery(expr: Expr) -> Result> { } Expr::ScalarSubquery(subquery) => { let plan = subquery.subquery.as_ref().clone(); - let new_plan = plan.transform_up(&analyze_internal)?.data; + let new_plan = plan.transform_up(&analyze_internal).data()?; let subquery = subquery.with_plan(Arc::new(new_plan)); Ok(Transformed::yes(Expr::ScalarSubquery(subquery))) } @@ -135,13 +133,12 @@ fn generate_projection_expr( mod tests { use std::{sync::Arc, vec}; - use arrow::datatypes::{DataType, Field, 
Schema}; - - use datafusion_expr::{col, lit, LogicalPlan, LogicalPlanBuilder, TableSource}; - use crate::analyzer::inline_table_scan::InlineTableScan; use crate::test::assert_analyzed_plan_eq; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_expr::{col, lit, LogicalPlan, LogicalPlanBuilder, TableSource}; + pub struct RawTableSource {} impl TableSource for RawTableSource { diff --git a/datafusion/optimizer/src/analyzer/rewrite_expr.rs b/datafusion/optimizer/src/analyzer/rewrite_expr.rs index 7bf852f8891c..41ebcd8e501a 100644 --- a/datafusion/optimizer/src/analyzer/rewrite_expr.rs +++ b/datafusion/optimizer/src/analyzer/rewrite_expr.rs @@ -19,21 +19,19 @@ use std::sync::Arc; +use super::AnalyzerRule; + use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::utils::list_ndims; -use datafusion_common::DFSchema; -use datafusion_common::DFSchemaRef; -use datafusion_common::Result; +use datafusion_common::{DFSchema, DFSchemaRef, Result}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::utils::merge_schema; -use datafusion_expr::BuiltinScalarFunction; -use datafusion_expr::Operator; -use datafusion_expr::ScalarFunctionDefinition; -use datafusion_expr::{BinaryExpr, Expr, LogicalPlan}; - -use super::AnalyzerRule; +use datafusion_expr::{ + BinaryExpr, BuiltinScalarFunction, Expr, LogicalPlan, Operator, + ScalarFunctionDefinition, +}; #[derive(Default)] pub struct OperatorToFunction {} @@ -97,38 +95,31 @@ impl TreeNodeRewriter for OperatorToFunctionRewriter { type Node = Expr; fn f_up(&mut self, expr: Expr) -> Result> { - Ok(match expr { - Expr::BinaryExpr(BinaryExpr { - ref left, + if let Expr::BinaryExpr(BinaryExpr { + ref left, + op, + ref right, + }) = expr + { + if let Some(fun) = rewrite_array_concat_operator_to_func_for_column( + left.as_ref(), op, - ref right, - }) => { - if let Some(fun) = 
rewrite_array_concat_operator_to_func_for_column( - left.as_ref(), - op, - right.as_ref(), - self.schema.as_ref(), - )? - .or_else(|| { - rewrite_array_concat_operator_to_func( - left.as_ref(), - op, - right.as_ref(), - ) - }) { - // Convert &Box -> Expr - let left = (**left).clone(); - let right = (**right).clone(); - return Ok(Transformed::yes(Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn(fun), - args: vec![left, right], - }))); - } - - Transformed::no(expr) + right.as_ref(), + self.schema.as_ref(), + )? + .or_else(|| { + rewrite_array_concat_operator_to_func(left.as_ref(), op, right.as_ref()) + }) { + // Convert &Box -> Expr + let left = (**left).clone(); + let right = (**right).clone(); + return Ok(Transformed::yes(Expr::ScalarFunction(ScalarFunction { + func_def: ScalarFunctionDefinition::BuiltIn(fun), + args: vec![left, right], + }))); } - _ => Transformed::no(expr), - }) + } + Ok(Transformed::no(expr)) } } diff --git a/datafusion/optimizer/src/analyzer/subquery.rs b/datafusion/optimizer/src/analyzer/subquery.rs index 3aab62438fe5..b7f513727d39 100644 --- a/datafusion/optimizer/src/analyzer/subquery.rs +++ b/datafusion/optimizer/src/analyzer/subquery.rs @@ -15,8 +15,11 @@ // specific language governing permissions and limitations // under the License. 
+use std::ops::Deref; + use crate::analyzer::check_plan; use crate::utils::collect_subquery_cols; + use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_common::{plan_err, Result}; use datafusion_expr::expr_rewriter::strip_outer_reference; @@ -25,7 +28,6 @@ use datafusion_expr::{ Aggregate, BinaryExpr, Cast, Expr, Filter, Join, JoinType, LogicalPlan, Operator, Window, }; -use std::ops::Deref; /// Do necessary check on subquery expressions and fail the invalid plan /// 1) Check whether the outer plan is in the allowed outer plans list to use subquery expressions, @@ -287,10 +289,9 @@ fn get_correlated_expressions(inner_plan: &LogicalPlan) -> Result> { .into_iter() .partition(|e| e.contains_outer()); - correlated - .into_iter() - .for_each(|expr| exprs.push(strip_outer_reference(expr.clone()))); - return Ok(TreeNodeRecursion::Continue); + for expr in correlated { + exprs.push(strip_outer_reference(expr.clone())); + } } Ok(TreeNodeRecursion::Continue) })?; diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 8c208fab9aa6..08f49ed15b09 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -19,8 +19,9 @@ use std::sync::Arc; -use arrow::datatypes::{DataType, IntervalUnit}; +use crate::analyzer::AnalyzerRule; +use arrow::datatypes::{DataType, IntervalUnit}; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; use datafusion_common::{ @@ -50,8 +51,6 @@ use datafusion_expr::{ WindowFrameBound, WindowFrameUnits, }; -use crate::analyzer::AnalyzerRule; - #[derive(Default)] pub struct TypeCoercion {} @@ -753,31 +752,26 @@ mod test { use std::any::Any; use std::sync::{Arc, OnceLock}; - use arrow::array::{FixedSizeListArray, Int32Array}; - use arrow::datatypes::{DataType, TimeUnit}; + use crate::analyzer::type_coercion::{ + cast_expr, 
coerce_case_expression, TypeCoercion, TypeCoercionRewriter, + }; + use crate::test::assert_analyzed_plan_eq; - use arrow::datatypes::Field; - use datafusion_common::tree_node::TreeNode; + use arrow::array::{FixedSizeListArray, Int32Array}; + use arrow::datatypes::{DataType, Field, TimeUnit}; + use datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::{DFField, DFSchema, DFSchemaRef, Result, ScalarValue}; use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction}; + use datafusion_expr::logical_plan::{EmptyRelation, Projection}; use datafusion_expr::{ - cast, col, concat, concat_ws, create_udaf, is_true, AccumulatorFactoryFunction, - AggregateFunction, AggregateUDF, BinaryExpr, BuiltinScalarFunction, Case, - ColumnarValue, ExprSchemable, Filter, Operator, ScalarUDFImpl, - SimpleAggregateUDF, Subquery, - }; - use datafusion_expr::{ - lit, - logical_plan::{EmptyRelation, Projection}, - Expr, LogicalPlan, ScalarUDF, Signature, Volatility, + cast, col, concat, concat_ws, create_udaf, is_true, lit, + AccumulatorFactoryFunction, AggregateFunction, AggregateUDF, BinaryExpr, + BuiltinScalarFunction, Case, ColumnarValue, Expr, ExprSchemable, Filter, + LogicalPlan, Operator, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, + Subquery, Volatility, }; use datafusion_physical_expr::expressions::AvgAccumulator; - use crate::analyzer::type_coercion::{ - cast_expr, coerce_case_expression, TypeCoercion, TypeCoercionRewriter, - }; - use crate::test::assert_analyzed_plan_eq; - fn empty() -> Arc { Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, @@ -1278,7 +1272,7 @@ mod test { std::collections::HashMap::new(), )?); let mut rewriter = TypeCoercionRewriter { schema }; - let result = expr.rewrite(&mut rewriter)?.data; + let result = expr.rewrite(&mut rewriter).data()?; let schema = Arc::new(DFSchema::new_with_metadata( vec![DFField::new_unqualified( @@ -1313,7 +1307,7 @@ mod test { let mut rewriter = 
TypeCoercionRewriter { schema }; let expr = is_true(lit(12i32).gt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64))); - let result = expr.rewrite(&mut rewriter)?.data; + let result = expr.rewrite(&mut rewriter).data()?; assert_eq!(expected, result); // eq @@ -1324,7 +1318,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema }; let expr = is_true(lit(12i32).eq(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64))); - let result = expr.rewrite(&mut rewriter)?.data; + let result = expr.rewrite(&mut rewriter).data()?; assert_eq!(expected, result); // lt @@ -1335,7 +1329,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema }; let expr = is_true(lit(12i32).lt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64))); - let result = expr.rewrite(&mut rewriter)?.data; + let result = expr.rewrite(&mut rewriter).data()?; assert_eq!(expected, result); Ok(()) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 323556ad7158..30c184a28e33 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -29,7 +29,8 @@ use datafusion_common::tree_node::{ TreeNodeVisitor, }; use datafusion_common::{ - internal_err, Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, + internal_datafusion_err, internal_err, Column, DFField, DFSchema, DFSchemaRef, + DataFusionError, Result, }; use datafusion_expr::expr::Alias; use datafusion_expr::logical_plan::{ @@ -680,12 +681,9 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { fn f_up(&mut self, expr: &Expr) -> Result { self.series_number += 1; - let (idx, sub_expr_desc) = - if let Some((idx, sub_expr_desc)) = self.pop_enter_mark() { - (idx, sub_expr_desc) - } else { - return Ok(TreeNodeRecursion::Continue); - }; + let Some((idx, sub_expr_desc)) = self.pop_enter_mark() else 
{ + return Ok(TreeNodeRecursion::Continue); + }; // skip exprs should not be recognize. if self.expr_mask.ignores(expr) { self.id_array[idx].0 = self.series_number; @@ -787,7 +785,7 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> { self.curr_index += 1; // Skip sub-node of a replaced tree, or without identifier, or is not repeated expr. let expr_set_item = self.expr_set.get(id).ok_or_else(|| { - DataFusionError::Internal("expr_set invalid state".to_string()) + internal_datafusion_err!("expr_set invalid state") })?; if *series_number < self.max_series_number || id.is_empty() diff --git a/datafusion/optimizer/src/decorrelate.rs b/datafusion/optimizer/src/decorrelate.rs index 09dccc1fc703..fd548ba4948e 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -15,19 +15,20 @@ // specific language governing permissions and limitations // under the License. +use std::collections::{BTreeSet, HashMap}; +use std::ops::Deref; + use crate::simplify_expressions::{ExprSimplifier, SimplifyContext}; use crate::utils::collect_subquery_cols; + use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; -use datafusion_common::{plan_err, Result}; -use datafusion_common::{Column, DFSchemaRef, ScalarValue}; +use datafusion_common::{plan_err, Column, DFSchemaRef, Result, ScalarValue}; use datafusion_expr::expr::{AggregateFunctionDefinition, Alias}; use datafusion_expr::utils::{conjunction, find_join_exprs, split_conjunction}; use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; use datafusion_physical_expr::execution_props::ExecutionProps; -use std::collections::{BTreeSet, HashMap}; -use std::ops::Deref; /// This struct rewrite the sub query plan by pull up the correlated expressions(contains outer reference columns) from the inner subquery's 'Filter'. 
/// It adds the inner reference columns to the 'Projection' or 'Aggregate' of the subquery if they are missing, so that they can be evaluated by the parent operator as the join condition. @@ -396,8 +397,8 @@ fn agg_exprs_evaluation_result_on_empty_batch( _ => Transformed::no(expr), }; Ok(new_expr) - })? - .data; + }) + .data()?; let result_expr = result_expr.unalias(); let props = ExecutionProps::new(); @@ -432,8 +433,9 @@ fn proj_exprs_evaluation_result_on_empty_batch( } else { Ok(Transformed::no(expr)) } - })? - .data; + }) + .data()?; + if result_expr.ne(expr) { let props = ExecutionProps::new(); let info = SimplifyContext::new(&props).with_schema(schema.clone()); @@ -468,8 +470,9 @@ fn filter_exprs_evaluation_result_on_empty_batch( } else { Ok(Transformed::no(expr)) } - })? - .data; + }) + .data()?; + let pull_up_expr = if result_expr.ne(filter_expr) { let props = ExecutionProps::new(); let info = SimplifyContext::new(&props).with_schema(schema); diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs index 013ab83aaa95..b94cf37c5c12 100644 --- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs +++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs @@ -15,12 +15,17 @@ // specific language governing permissions and limitations // under the License. 
+use std::collections::BTreeSet; +use std::ops::Deref; +use std::sync::Arc; + use crate::decorrelate::PullUpCorrelatedExpr; use crate::optimizer::ApplyOrder; use crate::utils::replace_qualified_name; use crate::{OptimizerConfig, OptimizerRule}; + use datafusion_common::alias::AliasGenerator; -use datafusion_common::tree_node::TreeNode; +use datafusion_common::tree_node::{TransformedResult, TreeNode}; use datafusion_common::{plan_err, Result}; use datafusion_expr::expr::{Exists, InSubquery}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; @@ -30,10 +35,8 @@ use datafusion_expr::{ exists, in_subquery, not_exists, not_in_subquery, BinaryExpr, Expr, Filter, LogicalPlan, LogicalPlanBuilder, Operator, }; + use log::debug; -use std::collections::BTreeSet; -use std::ops::Deref; -use std::sync::Arc; /// Optimizer rule for rewriting predicate(IN/EXISTS) subquery to left semi/anti joins #[derive(Default)] @@ -228,7 +231,7 @@ fn build_join( collected_count_expr_map: Default::default(), pull_up_having_expr: None, }; - let new_plan = subquery.clone().rewrite(&mut pull_up)?.data; + let new_plan = subquery.clone().rewrite(&mut pull_up).data()?; if !pull_up.can_pull_up { return Ok(None); } @@ -321,8 +324,11 @@ impl SubqueryInfo { #[cfg(test)] mod tests { + use std::ops::Add; + use super::*; use crate::test::*; + use arrow::datatypes::DataType; use datafusion_common::Result; use datafusion_expr::{ @@ -330,7 +336,6 @@ mod tests { logical_plan::LogicalPlanBuilder, not_exists, not_in_subquery, or, out_ref_col, Operator, }; - use std::ops::Add; fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_display_indent( diff --git a/datafusion/optimizer/src/plan_signature.rs b/datafusion/optimizer/src/plan_signature.rs index 8b8814192d38..4143d52a053e 100644 --- a/datafusion/optimizer/src/plan_signature.rs +++ b/datafusion/optimizer/src/plan_signature.rs @@ -15,13 +15,13 @@ // specific language governing permissions and 
limitations // under the License. -use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use std::{ collections::hash_map::DefaultHasher, hash::{Hash, Hasher}, num::NonZeroUsize, }; +use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; use datafusion_expr::LogicalPlan; /// Non-unique identifier of a [`LogicalPlan`]. diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 4501733b00a6..8acc36e479ca 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -15,21 +15,23 @@ // specific language governing permissions and limitations // under the License. +use std::collections::{BTreeSet, HashMap}; +use std::sync::Arc; + use crate::decorrelate::{PullUpCorrelatedExpr, UN_MATCHED_ROW_INDICATOR}; use crate::optimizer::ApplyOrder; use crate::utils::replace_qualified_name; use crate::{OptimizerConfig, OptimizerRule}; + use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; use datafusion_common::{plan_err, Column, Result, ScalarValue}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; use datafusion_expr::logical_plan::{JoinType, Subquery}; use datafusion_expr::utils::conjunction; use datafusion_expr::{expr, EmptyRelation, Expr, LogicalPlan, LogicalPlanBuilder}; -use std::collections::{BTreeSet, HashMap}; -use std::sync::Arc; /// Optimizer rule for rewriting subquery filters to joins #[derive(Default)] @@ -56,8 +58,11 @@ impl ScalarSubqueryToJoin { sub_query_info: vec![], alias_gen, }; - let new_expr = predicate.clone().rewrite(&mut extract)?.data; - Ok((extract.sub_query_info, new_expr)) + predicate + .clone() + .rewrite(&mut extract) + .data() + .map(|new_expr| (extract.sub_query_info, new_expr)) } } @@ -100,8 +105,8 @@ impl 
OptimizerRule for ScalarSubqueryToJoin { } else { Ok(Transformed::no(expr)) } - })? - .data; + }) + .data()?; } cur_input = optimized_subquery; } else { @@ -157,8 +162,8 @@ impl OptimizerRule for ScalarSubqueryToJoin { } else { Ok(Transformed::no(expr)) } - })? - .data; + }) + .data()?; expr_to_rewrite_expr_map.insert(expr, new_expr); } } @@ -283,7 +288,7 @@ fn build_join( collected_count_expr_map: Default::default(), pull_up_having_expr: None, }; - let new_plan = subquery_plan.clone().rewrite(&mut pull_up)?.data; + let new_plan = subquery_plan.clone().rewrite(&mut pull_up).data()?; if !pull_up.can_pull_up { return Ok(None); } @@ -372,15 +377,17 @@ fn build_join( #[cfg(test)] mod tests { + use std::ops::Add; + use super::*; use crate::test::*; + use arrow::datatypes::DataType; use datafusion_common::Result; + use datafusion_expr::logical_plan::LogicalPlanBuilder; use datafusion_expr::{ - col, lit, logical_plan::LogicalPlanBuilder, max, min, out_ref_col, - scalar_subquery, sum, Between, + col, lit, max, min, out_ref_col, scalar_subquery, sum, Between, }; - use std::ops::Add; /// Test multiple correlated subqueries #[test] diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 9b0224cafd2a..6b5dd1b4681e 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -19,16 +19,21 @@ use std::ops::Not; +use super::inlist_simplifier::{InListSimplifier, ShortenInListSimplifier}; +use super::utils::*; +use crate::analyzer::type_coercion::TypeCoercionRewriter; +use crate::simplify_expressions::guarantees::GuaranteeRewriter; +use crate::simplify_expressions::regex::simplify_regex_expr; +use crate::simplify_expressions::SimplifyInfo; + use arrow::{ array::{new_null_array, AsArray}, datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; - -use 
datafusion_common::tree_node::{Transformed, TransformedResult}; use datafusion_common::{ cast::{as_large_list_array, as_list_array}, - tree_node::{TreeNode, TreeNodeRewriter}, + tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, }; use datafusion_common::{ internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, @@ -40,14 +45,6 @@ use datafusion_expr::{ use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps}; -use crate::analyzer::type_coercion::TypeCoercionRewriter; -use crate::simplify_expressions::guarantees::GuaranteeRewriter; -use crate::simplify_expressions::regex::simplify_regex_expr; -use crate::simplify_expressions::SimplifyInfo; - -use super::inlist_simplifier::{InListSimplifier, ShortenInListSimplifier}; -use super::utils::*; - /// This structure handles API for expression simplification pub struct ExprSimplifier { info: S, @@ -132,36 +129,34 @@ impl ExprSimplifier { /// let expr = simplifier.simplify(expr).unwrap(); /// assert_eq!(expr, b_lt_2); /// ``` - pub fn simplify(&self, expr: Expr) -> Result { + pub fn simplify(&self, mut expr: Expr) -> Result { let mut simplifier = Simplifier::new(&self.info); let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?; let mut shorten_in_list_simplifier = ShortenInListSimplifier::new(); let mut inlist_simplifier = InListSimplifier::new(); let mut guarantee_rewriter = GuaranteeRewriter::new(&self.guarantees); - let expr = if self.canonicalize { - expr.rewrite(&mut Canonicalizer::new())?.data - } else { - expr - }; + if self.canonicalize { + expr = expr.rewrite(&mut Canonicalizer::new()).data()? 
+ } // TODO iterate until no changes are made during rewrite // (evaluating constants can enable new simplifications and // simplifications can enable new constant evaluation) // https://github.com/apache/arrow-datafusion/issues/1160 - expr.rewrite(&mut const_evaluator)? - .data - .rewrite(&mut simplifier)? - .data - .rewrite(&mut inlist_simplifier)? - .data - .rewrite(&mut shorten_in_list_simplifier)? - .data - .rewrite(&mut guarantee_rewriter)? - .data + expr.rewrite(&mut const_evaluator) + .data()? + .rewrite(&mut simplifier) + .data()? + .rewrite(&mut inlist_simplifier) + .data()? + .rewrite(&mut shorten_in_list_simplifier) + .data()? + .rewrite(&mut guarantee_rewriter) + .data()? // run both passes twice to try an minimize simplifications that we missed - .rewrite(&mut const_evaluator)? - .data + .rewrite(&mut const_evaluator) + .data()? .rewrite(&mut simplifier) .data() } @@ -1372,16 +1367,15 @@ mod tests { sync::Arc, }; + use super::*; + use crate::simplify_expressions::SimplifyContext; + use crate::test::test_table_scan_with_name; + use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{assert_contains, DFField, ToDFSchema}; use datafusion_expr::{interval_arithmetic::Interval, *}; use datafusion_physical_expr::execution_props::ExecutionProps; - use crate::simplify_expressions::SimplifyContext; - use crate::test::test_table_scan_with_name; - - use super::*; - // ------------------------------ // --- ExprSimplifier tests ----- // ------------------------------ diff --git a/datafusion/optimizer/src/simplify_expressions/guarantees.rs b/datafusion/optimizer/src/simplify_expressions/guarantees.rs index 9f8553cb0cc2..6eb583257dcb 100644 --- a/datafusion/optimizer/src/simplify_expressions/guarantees.rs +++ b/datafusion/optimizer/src/simplify_expressions/guarantees.rs @@ -21,8 +21,8 @@ use std::{borrow::Cow, collections::HashMap}; -use datafusion_common::tree_node::Transformed; -use datafusion_common::{tree_node::TreeNodeRewriter, DataFusionError, 
Result}; +use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; +use datafusion_common::{DataFusionError, Result}; use datafusion_expr::interval_arithmetic::{Interval, NullableInterval}; use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr}; @@ -204,8 +204,8 @@ mod tests { use super::*; use arrow::datatypes::DataType; - use datafusion_common::tree_node::TransformedResult; - use datafusion_common::{tree_node::TreeNode, ScalarValue}; + use datafusion_common::tree_node::{TransformedResult, TreeNode}; + use datafusion_common::ScalarValue; use datafusion_expr::{col, lit, Operator}; #[test] diff --git a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs index 8cbb321c2755..fa1d7cfc1239 100644 --- a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs @@ -17,6 +17,9 @@ //! This module implements a rule that simplifies the values for `InList`s +use super::utils::{is_null, lit_bool_null}; +use super::THRESHOLD_INLINE_INLIST; + use std::borrow::Cow; use std::collections::HashSet; @@ -25,9 +28,6 @@ use datafusion_common::{Result, ScalarValue}; use datafusion_expr::expr::{InList, InSubquery}; use datafusion_expr::{lit, BinaryExpr, Expr, Operator}; -use super::utils::{is_null, lit_bool_null}; -use super::THRESHOLD_INLINE_INLIST; - pub(super) struct ShortenInListSimplifier {} impl ShortenInListSimplifier { diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 9cc34c9b1611..196a35ee9ae8 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -18,8 +18,13 @@ //! Unwrap-cast binary comparison rule can be used to the binary/inlist comparison expr now, and other type //! of expr can be added if needed. //! 
This rule can reduce adding the `Expr::Cast` the expr instead of adding the `Expr::Cast` to literal expr. + +use std::cmp::Ordering; +use std::sync::Arc; + use crate::optimizer::ApplyOrder; use crate::{OptimizerConfig, OptimizerRule}; + use arrow::datatypes::{ DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, }; @@ -32,8 +37,6 @@ use datafusion_expr::utils::merge_schema; use datafusion_expr::{ binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator, }; -use std::cmp::Ordering; -use std::sync::Arc; /// [`UnwrapCastInComparison`] attempts to remove casts from /// comparisons to literals ([`ScalarValue`]s) by applying the casts @@ -472,15 +475,17 @@ fn cast_between_timestamp(from: DataType, to: DataType, value: i128) -> Option`]s that are known /// to have the same value for all tuples in a relation. These are generated by @@ -479,15 +481,14 @@ impl EquivalenceGroup { #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; use crate::equivalence::tests::create_test_params; use crate::equivalence::{EquivalenceClass, EquivalenceGroup}; - use crate::expressions::lit; - use crate::expressions::Column; - use crate::expressions::Literal; - use datafusion_common::Result; - use datafusion_common::ScalarValue; - use std::sync::Arc; + use crate::expressions::{lit, Column, Literal}; + + use datafusion_common::{Result, ScalarValue}; #[test] fn test_bridge_groups() -> Result<()> { diff --git a/datafusion/physical-expr/src/equivalence/mod.rs b/datafusion/physical-expr/src/equivalence/mod.rs index 6b928ea24c6b..46909f23616f 100644 --- a/datafusion/physical-expr/src/equivalence/mod.rs +++ b/datafusion/physical-expr/src/equivalence/mod.rs @@ -15,18 +15,22 @@ // specific language governing permissions and limitations // under the License. 
+use std::sync::Arc; + +use crate::expressions::Column; +use crate::{LexRequirement, PhysicalExpr, PhysicalSortRequirement}; + +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; + mod class; mod ordering; mod projection; mod properties; -use crate::expressions::Column; -use crate::{LexRequirement, PhysicalExpr, PhysicalSortRequirement}; + pub use class::{EquivalenceClass, EquivalenceGroup}; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; pub use ordering::OrderingEquivalenceClass; pub use projection::ProjectionMapping; pub use properties::{join_equivalence_properties, EquivalenceProperties}; -use std::sync::Arc; /// This function constructs a duplicate-free `LexOrderingReq` by filtering out /// duplicate entries that have same physical expression inside. For example, @@ -62,19 +66,22 @@ pub fn add_offset_to_expr( #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; use crate::expressions::{col, Column}; use crate::PhysicalSortExpr; + use arrow::compute::{lexsort_to_indices, SortColumn}; use arrow::datatypes::{DataType, Field, Schema}; use arrow_array::{ArrayRef, Float64Array, RecordBatch, UInt32Array}; use arrow_schema::{SchemaRef, SortOptions}; use datafusion_common::{plan_datafusion_err, Result}; + use itertools::izip; use rand::rngs::StdRng; use rand::seq::SliceRandom; use rand::{Rng, SeedableRng}; - use std::sync::Arc; pub fn output_schema( mapping: &ProjectionMapping, diff --git a/datafusion/physical-expr/src/equivalence/projection.rs b/datafusion/physical-expr/src/equivalence/projection.rs index ad1f754a96d1..c60742c724e4 100644 --- a/datafusion/physical-expr/src/equivalence/projection.rs +++ b/datafusion/physical-expr/src/equivalence/projection.rs @@ -109,6 +109,8 @@ impl ProjectionMapping { #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; use crate::equivalence::tests::{ apply_projection, convert_to_orderings, convert_to_orderings_owned, @@ -120,12 +122,13 @@ mod tests { use 
crate::expressions::{col, BinaryExpr, Literal}; use crate::functions::create_physical_expr; use crate::PhysicalSortExpr; + use arrow::datatypes::{DataType, Field, Schema}; use arrow_schema::{SortOptions, TimeUnit}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{BuiltinScalarFunction, Operator}; + use itertools::Itertools; - use std::sync::Arc; #[test] fn project_orderings() -> Result<()> { diff --git a/datafusion/physical-expr/src/equivalence/properties.rs b/datafusion/physical-expr/src/equivalence/properties.rs index 88550813fe23..f234a1fa08cd 100644 --- a/datafusion/physical-expr/src/equivalence/properties.rs +++ b/datafusion/physical-expr/src/equivalence/properties.rs @@ -15,11 +15,6 @@ // specific language governing permissions and limitations // under the License. -use crate::expressions::CastExpr; -use arrow_schema::SchemaRef; -use datafusion_common::{JoinSide, JoinType, Result}; -use indexmap::{IndexMap, IndexSet}; -use itertools::Itertools; use std::hash::{Hash, Hasher}; use std::sync::Arc; @@ -27,7 +22,7 @@ use super::ordering::collapse_lex_ordering; use crate::equivalence::{ collapse_lex_req, EquivalenceGroup, OrderingEquivalenceClass, ProjectionMapping, }; -use crate::expressions::Literal; +use crate::expressions::{CastExpr, Literal}; use crate::sort_properties::{ExprOrdering, SortProperties}; use crate::{ physical_exprs_contains, LexOrdering, LexOrderingRef, LexRequirement, @@ -35,8 +30,12 @@ use crate::{ PhysicalSortRequirement, }; -use arrow_schema::SortOptions; +use arrow_schema::{SchemaRef, SortOptions}; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::{JoinSide, JoinType, Result}; + +use indexmap::{IndexMap, IndexSet}; +use itertools::Itertools; /// A `EquivalenceProperties` object stores useful information related to a schema. 
/// Currently, it keeps track of: @@ -1298,10 +1297,12 @@ mod tests { use crate::expressions::{col, BinaryExpr, Column}; use crate::functions::create_physical_expr; use crate::PhysicalSortExpr; + use arrow::datatypes::{DataType, Field, Schema}; use arrow_schema::{Fields, SortOptions, TimeUnit}; use datafusion_common::Result; use datafusion_expr::{BuiltinScalarFunction, Operator}; + use itertools::Itertools; #[test] diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index e6ce8316c27e..609349509b86 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -19,18 +19,18 @@ use std::borrow::Cow; use std::hash::{Hash, Hasher}; use std::{any::Any, sync::Arc}; -use crate::expressions::try_cast; -use crate::expressions::NoOp; +use crate::expressions::{try_cast, NoOp}; use crate::physical_expr::down_cast_any_ref; use crate::PhysicalExpr; + use arrow::array::*; use arrow::compute::kernels::cmp::eq; use arrow::compute::kernels::zip::zip; use arrow::compute::{and, is_null, not, nullif, or, prep_null_mask_filter}; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; -use datafusion_common::{cast::as_boolean_array, internal_err, DataFusionError, Result}; -use datafusion_common::{exec_err, ScalarValue}; +use datafusion_common::cast::as_boolean_array; +use datafusion_common::{exec_err, internal_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::ColumnarValue; use itertools::Itertools; @@ -414,17 +414,15 @@ pub fn case( #[cfg(test)] mod tests { use super::*; - use crate::expressions::col; - use crate::expressions::lit; - use crate::expressions::{binary, cast}; + use crate::expressions::{binary, cast, col, lit}; + use arrow::array::StringArray; use arrow::buffer::Buffer; use arrow::datatypes::DataType::Float64; use arrow::datatypes::*; use datafusion_common::cast::{as_float64_array, as_int32_array}; - use 
datafusion_common::plan_err; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; - use datafusion_common::ScalarValue; + use datafusion_common::{plan_err, ScalarValue}; use datafusion_expr::type_coercion::binary::comparison_coercion; use datafusion_expr::Operator; diff --git a/datafusion/physical-expr/src/tree_node.rs b/datafusion/physical-expr/src/tree_node.rs index 0e2aa7d63679..42dc6673af6a 100644 --- a/datafusion/physical-expr/src/tree_node.rs +++ b/datafusion/physical-expr/src/tree_node.rs @@ -63,8 +63,7 @@ impl ExprContext { pub fn update_expr_from_children(mut self) -> Result { let children_expr = self.children.iter().map(|c| c.expr.clone()).collect(); - let t = with_new_children_if_necessary(self.expr, children_expr)?; - self.expr = t; + self.expr = with_new_children_if_necessary(self.expr, children_expr)?; Ok(self) } } diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 0b0dca6bb4b6..b8e99403d695 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -130,10 +130,10 @@ pub fn get_indices_of_exprs_strict>>( pub type ExprTreeNode = ExprContext>; -/// This struct is used to convert a [PhysicalExpr] tree into a DAEG (i.e. an expression +/// This struct is used to convert a [`PhysicalExpr`] tree into a DAEG (i.e. an expression /// DAG) by collecting identical expressions in one node. Caller specifies the node type /// in the DAEG via the `constructor` argument, which constructs nodes in the DAEG from -/// the [ExprTreeNode] ancillary object. +/// the [`ExprTreeNode`] ancillary object. struct PhysicalExprDAEGBuilder<'a, T, F: Fn(&ExprTreeNode) -> Result> { // The resulting DAEG (expression DAG). graph: StableGraph, @@ -196,8 +196,8 @@ where }; // Use the builder to transform the expression tree node into a DAG. let root = init - .transform_up_mut(&mut |node| builder.mutate(node))? 
- .data; + .transform_up_mut(&mut |node| builder.mutate(node)) + .data()?; // Return a tuple containing the root node index and the DAG. Ok((root.data.unwrap(), builder.graph)) } diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index 3484ee45ba6a..9824c723d9d1 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -31,7 +31,7 @@ use arrow::compute::concat_batches; use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, RecordBatch}; use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder}; use arrow_schema::{Schema, SchemaRef}; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{ arrow_datafusion_err, plan_datafusion_err, DataFusionError, JoinSide, Result, ScalarValue, @@ -292,8 +292,8 @@ pub fn convert_sort_expr_with_filter_schema( None => Transformed::no(p), } }) - })? - .data; + }) + .data()?; // Search the converted `PhysicalExpr` in filter expression; if an exact // match is found, use this sorted expression in graph traversals. 
if check_filter_expr_contains_sort_information( diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index 083c2f03be7b..1cb2b100e2d6 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -39,6 +39,7 @@ use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray}; use arrow_buffer::ArrowNativeType; use datafusion_common::cast::as_boolean_array; use datafusion_common::stats::Precision; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{ plan_err, DataFusionError, JoinSide, JoinType, Result, SharedResult, }; @@ -50,7 +51,6 @@ use datafusion_physical_expr::{ LexOrdering, LexOrderingRef, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr, }; -use datafusion_common::tree_node::{Transformed, TreeNode}; use futures::future::{BoxFuture, Shared}; use futures::{ready, FutureExt}; use hashbrown::raw::RawTable; @@ -484,8 +484,8 @@ fn replace_on_columns_of_right_ordering( } else { Ok(Transformed::no(e)) } - })? - .data; + }) + .data()?; item.expr = new_expr; } } diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index 9d4edd6c0de4..a9a7070023ab 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -661,7 +661,7 @@ pub fn with_new_children_if_necessary( .zip(old_children.iter()) .any(|(c1, c2)| !Arc::data_ptr_eq(c1, c2)) { - Ok(plan.with_new_children(children)?) 
+ plan.with_new_children(children) } else { Ok(plan) } diff --git a/datafusion/physical-plan/src/tree_node.rs b/datafusion/physical-plan/src/tree_node.rs index 6fd88160468c..52a52f81bdaf 100644 --- a/datafusion/physical-plan/src/tree_node.rs +++ b/datafusion/physical-plan/src/tree_node.rs @@ -63,8 +63,7 @@ impl PlanContext { pub fn update_plan_from_children(mut self) -> Result { let children_plans = self.children.iter().map(|c| c.plan.clone()).collect(); - let t = with_new_children_if_necessary(self.plan, children_plans)?; - self.plan = t; + self.plan = with_new_children_if_necessary(self.plan, children_plans)?; Ok(self) } } diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs index d6f53a73dcb1..abb896ab113e 100644 --- a/datafusion/sql/src/utils.rs +++ b/datafusion/sql/src/utils.rs @@ -17,19 +17,19 @@ //! SQL Utility Functions +use std::collections::HashMap; + use arrow_schema::{ DataType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE, }; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use sqlparser::ast::Ident; - -use datafusion_common::{exec_err, internal_err, plan_err}; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::{ + exec_err, internal_err, plan_err, DataFusionError, Result, ScalarValue, +}; use datafusion_expr::expr::{Alias, GroupingSet, WindowFunction}; -use datafusion_expr::expr_vec_fmt; use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs}; -use datafusion_expr::{Expr, LogicalPlan}; -use std::collections::HashMap; +use datafusion_expr::{expr_vec_fmt, Expr, LogicalPlan}; +use sqlparser::ast::Ident; /// Make a best-effort attempt at resolving all columns in the expression tree pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result { diff --git a/datafusion/sqllogictest/test_files/group_by.slt b/datafusion/sqllogictest/test_files/group_by.slt index 79877fa421e3..906926a5a9ab 100644 --- 
a/datafusion/sqllogictest/test_files/group_by.slt +++ b/datafusion/sqllogictest/test_files/group_by.slt @@ -3214,7 +3214,7 @@ JOIN sales_global AS e ON s.currency = e.currency AND s.ts >= e.ts GROUP BY s.sn, s.zip_code, s.country, s.ts, s.currency -ORDER BY s.sn +ORDER BY s.sn, s.zip_code ---- 0 GRC 0 2022-01-01T06:00:00 EUR 30 0 GRC 4 2022-01-03T10:00:00 EUR 80