Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimized push down filter #10291 #10366

Merged
merged 1 commit into from
May 3, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 81 additions & 58 deletions datafusion/optimizer/src/push_down_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use crate::optimizer::ApplyOrder;
use crate::{OptimizerConfig, OptimizerRule};
use itertools::Itertools;

use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
Expand All @@ -29,6 +28,7 @@ use datafusion_common::{
};
use datafusion_expr::expr::Alias;
use datafusion_expr::expr_rewriter::replace_col;
use datafusion_expr::logical_plan::tree_node::unwrap_arc;
use datafusion_expr::logical_plan::{
CrossJoin, Join, JoinType, LogicalPlan, TableScan, Union,
};
Expand All @@ -38,7 +38,8 @@ use datafusion_expr::{
ScalarFunctionDefinition, TableProviderFilterPushDown,
};

use itertools::Itertools;
use crate::optimizer::ApplyOrder;
use crate::{OptimizerConfig, OptimizerRule};

/// Optimizer rule for pushing (moving) filter expressions down in a plan so
/// they are applied as early as possible.
Expand Down Expand Up @@ -407,7 +408,7 @@ fn push_down_all_join(
right: &LogicalPlan,
on_filter: Vec<Expr>,
is_inner_join: bool,
) -> Result<LogicalPlan> {
) -> Result<Transformed<LogicalPlan>> {
let on_filter_empty = on_filter.is_empty();
// Get pushable predicates from current optimizer state
let (left_preserved, right_preserved) = lr_is_preserved(join_plan)?;
Expand Down Expand Up @@ -505,41 +506,43 @@ fn push_down_all_join(
// wrap the join on the filter whose predicates must be kept
match conjunction(keep_predicates) {
Some(predicate) => {
Filter::try_new(predicate, Arc::new(plan)).map(LogicalPlan::Filter)
let new_filter_plan = Filter::try_new(predicate, Arc::new(plan))?;
Ok(Transformed::yes(LogicalPlan::Filter(new_filter_plan)))
}
None => Ok(plan),
None => Ok(Transformed::no(plan)),
}
}

fn push_down_join(
plan: &LogicalPlan,
join: &Join,
parent_predicate: Option<&Expr>,
) -> Result<Option<LogicalPlan>> {
let predicates = match parent_predicate {
Some(parent_predicate) => split_conjunction_owned(parent_predicate.clone()),
None => vec![],
};
) -> Result<Transformed<LogicalPlan>> {
// Split the parent predicate into individual conjunctive parts.
let predicates = parent_predicate
.map_or_else(Vec::new, |pred| split_conjunction_owned(pred.clone()));

// Convert JOIN ON predicate to Predicates
// Extract conjunctions from the JOIN's ON filter, if present.
let on_filters = join
.filter
.as_ref()
.map(|e| split_conjunction_owned(e.clone()))
.unwrap_or_default();
.map_or_else(Vec::new, |filter| split_conjunction_owned(filter.clone()));

let mut is_inner_join = false;
let infer_predicates = if join.join_type == JoinType::Inner {
is_inner_join = true;

// Only keep join keys where both sides are columns.
let join_col_keys = join
.on
.iter()
.flat_map(|(l, r)| match (l.try_into_col(), r.try_into_col()) {
(Ok(l_col), Ok(r_col)) => Some((l_col, r_col)),
_ => None,
.filter_map(|(l, r)| {
let left_col = l.try_into_col().ok()?;
let right_col = r.try_into_col().ok()?;
Some((left_col, right_col))
})
.collect::<Vec<_>>();

// TODO refine the logic, introduce EquivalenceProperties to logical plan and infer additional filters to push down
// For inner joins, duplicate filters for joined columns so filters can be pushed down
// to both sides. Take the following query as an example:
Expand All @@ -559,6 +562,7 @@ fn push_down_join(
.chain(on_filters.iter())
.filter_map(|predicate| {
let mut join_cols_to_replace = HashMap::new();

let columns = match predicate.to_columns() {
Ok(columns) => columns,
Err(e) => return Some(Err(e)),
Expand Down Expand Up @@ -596,20 +600,32 @@ fn push_down_join(
};

if on_filters.is_empty() && predicates.is_empty() && infer_predicates.is_empty() {
return Ok(None);
return Ok(Transformed::no(plan.clone()));
}
Ok(Some(push_down_all_join(

match push_down_all_join(
predicates,
infer_predicates,
plan,
&join.left,
&join.right,
on_filters,
is_inner_join,
)?))
) {
Ok(plan) => Ok(Transformed::yes(plan.data)),
Err(e) => Err(e),
}
}

impl OptimizerRule for PushDownFilter {
/// Legacy `try_optimize` entry point; not implemented by this rule.
///
/// Both arguments are ignored. The call always evaluates to the result of
/// `internal_err!` carrying the message
/// "Should have called PushDownFilter::rewrite".
fn try_optimize(
&self,
_plan: &LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<Option<LogicalPlan>> {
internal_err!("Should have called PushDownFilter::rewrite")
}

/// Returns the name of this rule: the literal `"push_down_filter"`.
fn name(&self) -> &str {
"push_down_filter"
}
Expand All @@ -618,21 +634,24 @@ impl OptimizerRule for PushDownFilter {
Some(ApplyOrder::TopDown)
}

fn try_optimize(
/// Always returns `true`.
///
/// NOTE(review): presumably this tells the optimizer to call `rewrite`
/// rather than `try_optimize` for this rule; confirm against the
/// `OptimizerRule` trait definition.
fn supports_rewrite(&self) -> bool {
true
}

fn rewrite(
&self,
plan: &LogicalPlan,
plan: LogicalPlan,
_config: &dyn OptimizerConfig,
) -> Result<Option<LogicalPlan>> {
) -> Result<Transformed<LogicalPlan>> {
let filter = match plan {
LogicalPlan::Filter(filter) => filter,
// we also need to pushdown filter in Join.
LogicalPlan::Join(join) => return push_down_join(plan, join, None),
_ => return Ok(None),
LogicalPlan::Filter(ref filter) => filter,
LogicalPlan::Join(ref join) => return push_down_join(&plan, join, None),
_ => return Ok(Transformed::no(plan)),
};

let child_plan = filter.input.as_ref();
let new_plan = match child_plan {
LogicalPlan::Filter(child_filter) => {
LogicalPlan::Filter(ref child_filter) => {
let parents_predicates = split_conjunction(&filter.predicate);
let set: HashSet<&&Expr> = parents_predicates.iter().collect();

Expand All @@ -652,20 +671,18 @@ impl OptimizerRule for PushDownFilter {
new_predicate,
child_filter.input.clone(),
)?);
self.try_optimize(&new_filter, _config)?
.unwrap_or(new_filter)
self.rewrite(new_filter, _config)?.data
}
LogicalPlan::Repartition(_)
| LogicalPlan::Distinct(_)
| LogicalPlan::Sort(_) => {
// commutable
let new_filter = plan.with_new_exprs(
plan.expressions(),
vec![child_plan.inputs()[0].clone()],
)?;
child_plan.with_new_exprs(child_plan.expressions(), vec![new_filter])?
}
LogicalPlan::SubqueryAlias(subquery_alias) => {
LogicalPlan::SubqueryAlias(ref subquery_alias) => {
let mut replace_map = HashMap::new();
for (i, (qualifier, field)) in
subquery_alias.input.schema().iter().enumerate()
Expand All @@ -685,7 +702,7 @@ impl OptimizerRule for PushDownFilter {
)?);
child_plan.with_new_exprs(child_plan.expressions(), vec![new_filter])?
}
LogicalPlan::Projection(projection) => {
LogicalPlan::Projection(ref projection) => {
// A projection is filter-commutable if it does not contain volatile predicates, or if it contains volatile
// predicates that are not used in the filter. However, we should rewrite all predicate expressions.
// collect projection.
Expand Down Expand Up @@ -742,10 +759,10 @@ impl OptimizerRule for PushDownFilter {
}
}
}
None => return Ok(None),
None => return Ok(Transformed::no(plan)),
}
}
LogicalPlan::Union(union) => {
LogicalPlan::Union(ref union) => {
let mut inputs = Vec::with_capacity(union.inputs.len());
for input in &union.inputs {
let mut replace_map = HashMap::new();
Expand All @@ -770,7 +787,7 @@ impl OptimizerRule for PushDownFilter {
schema: plan.schema().clone(),
})
}
LogicalPlan::Aggregate(agg) => {
LogicalPlan::Aggregate(ref agg) => {
// We can push down predicates which are in groupby_expr.
let group_expr_columns = agg
.group_expr
Expand Down Expand Up @@ -821,13 +838,15 @@ impl OptimizerRule for PushDownFilter {
None => new_agg,
}
}
LogicalPlan::Join(join) => {
match push_down_join(&filter.input, join, Some(&filter.predicate))? {
Some(optimized_plan) => optimized_plan,
None => return Ok(None),
}
LogicalPlan::Join(ref join) => {
push_down_join(
&unwrap_arc(filter.clone().input),
join,
Some(&filter.predicate),
)?
.data
}
LogicalPlan::CrossJoin(cross_join) => {
LogicalPlan::CrossJoin(ref cross_join) => {
let predicates = split_conjunction_owned(filter.predicate.clone());
let join = convert_cross_join_to_inner_join(cross_join.clone())?;
let join_plan = LogicalPlan::Join(join);
Expand All @@ -843,9 +862,9 @@ impl OptimizerRule for PushDownFilter {
vec![],
true,
)?;
convert_to_cross_join_if_beneficial(plan)?
convert_to_cross_join_if_beneficial(plan.data)?
}
LogicalPlan::TableScan(scan) => {
LogicalPlan::TableScan(ref scan) => {
let filter_predicates = split_conjunction(&filter.predicate);
let results = scan
.source
Expand Down Expand Up @@ -892,7 +911,7 @@ impl OptimizerRule for PushDownFilter {
None => new_scan,
}
}
LogicalPlan::Extension(extension_plan) => {
LogicalPlan::Extension(ref extension_plan) => {
let prevent_cols =
extension_plan.node.prevent_predicate_push_down_columns();

Expand Down Expand Up @@ -935,9 +954,10 @@ impl OptimizerRule for PushDownFilter {
None => new_extension,
}
}
_ => return Ok(None),
_ => return Ok(Transformed::no(plan)),
};
Ok(Some(new_plan))

Ok(Transformed::yes(new_plan))
}
}

Expand Down Expand Up @@ -1024,16 +1044,12 @@ fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool {

#[cfg(test)]
mod tests {
use super::*;
use std::any::Any;
use std::fmt::{Debug, Formatter};

use crate::optimizer::Optimizer;
use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
use crate::test::*;
use crate::OptimizerContext;

use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use async_trait::async_trait;

use datafusion_common::ScalarValue;
use datafusion_expr::expr::ScalarFunction;
use datafusion_expr::logical_plan::table_scan;
Expand All @@ -1043,7 +1059,13 @@ mod tests {
Volatility,
};

use async_trait::async_trait;
use crate::optimizer::Optimizer;
use crate::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
use crate::test::*;
use crate::OptimizerContext;

use super::*;

/// No-op observer: accepts a plan and a rule, ignores both, and does nothing.
fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {}

fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> {
Expand Down Expand Up @@ -2298,9 +2320,9 @@ mod tests {
table_scan_with_pushdown_provider(TableProviderFilterPushDown::Inexact)?;

let optimized_plan = PushDownFilter::new()
.try_optimize(&plan, &OptimizerContext::new())
.rewrite(plan, &OptimizerContext::new())
.expect("failed to optimize plan")
.unwrap();
.data;

let expected = "\
Filter: a = Int64(1)\
Expand Down Expand Up @@ -2667,8 +2689,9 @@ Projection: a, b
// Originally there was global state which helped to avoid duplicate Filters being generated and pushed down.
// Now the global state is removed, so we need to double-check that duplicate Filters are still avoided.
let optimized_plan = PushDownFilter::new()
.try_optimize(&plan, &OptimizerContext::new())?
.expect("failed to optimize plan");
.rewrite(plan, &OptimizerContext::new())
.expect("failed to optimize plan")
.data;
assert_optimized_plan_eq(optimized_plan, expected)
}

Expand Down