Add LogicalPlan::recompute_schema for handling rewrite passes #10410

Closed · wants to merge 5 commits

Changes from 2 commits
328 changes: 311 additions & 17 deletions datafusion/expr/src/logical_plan/plan.rs
@@ -52,6 +52,7 @@ use datafusion_common::{

// backwards compatibility
use crate::display::PgJsonVisitor;
use crate::logical_plan::tree_node::unwrap_arc;
pub use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan};
pub use datafusion_common::{JoinConstraint, JoinType};

@@ -467,6 +468,200 @@ impl LogicalPlan {
self.with_new_exprs(self.expressions(), inputs.to_vec())
}

/// Recomputes schema and type information for this LogicalPlan if needed.
///
/// Some `LogicalPlan`s may need to recompute their schema if the number or
/// type of their expressions has changed (for example, due to type
/// coercion). For instance, a [`LogicalPlan::Projection`]'s schema depends
/// on its expressions.
///
/// Other `LogicalPlan`s have schemas that are unaffected by changes to
/// their expressions. For example, a [`LogicalPlan::Filter`]'s schema is
/// always the same as its input's schema.
///
/// # Return value
/// Returns an error if there is some issue recomputing the schema.
///
/// # Notes
///
/// * Does not recursively recompute schema for input (child) plans.
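///
/// # Example
/// A minimal sketch of the intended call pattern (the identity closure below
/// is only a placeholder for a real expression rewrite):
///
/// ```ignore
/// // rewrite this node's expressions in place, then recompute its schema
/// let plan = plan
///     .map_expressions(|e| Ok(Transformed::yes(e)))?
///     .data
///     .recompute_schema()?;
/// ```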
pub fn recompute_schema(self) -> Result<Self> {
Contributor Author:

All this logic is the same as new_with_expr, including some questionable code, such as the handling for Filter and Union.

Contributor:

I have a question which may be stupid: what's the difference between this and getting the inputs and exprs of the LogicalPlan and passing them as params to new_with_expr? 🤔

Contributor Author:

The key difference is that new_with_expr requires providing copies (Vec<Expr> and Vec<LogicalPlan>), whereas this function can be used after modifying the Exprs (or children) in place via methods such as map_children and map_expressions.

Avoiding those copies is a key part of improving planning performance in PRs like #10356 (the change is basically to stop copying).
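
To make this concrete, here is a rough sketch of the two patterns (illustrative only, not code from this PR; rewrite is a stand-in for whatever per-expression transformation a pass applies):

// the copying pattern: with_new_exprs needs owned Vec<Expr> / Vec<LogicalPlan>,
// so every expression and every child plan is cloned up front
let exprs: Vec<Expr> = plan.expressions().into_iter().map(rewrite).collect();
let inputs: Vec<LogicalPlan> = plan.inputs().into_iter().cloned().collect();
let plan = plan.with_new_exprs(exprs, inputs)?;

// the in-place pattern: rewrite the Exprs where they live, then fix up the schema
let plan = plan
    .map_expressions(|e| Ok(Transformed::yes(rewrite(e))))?
    .data
    .recompute_schema()?;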

Contributor Author:

I pushed bd7b62f to try to clarify.

match self {
// Since expr may differ from the previous expr, the schema of the projection
// may change. We need to use the try_new method instead of the try_new_with_schema method.
LogicalPlan::Projection(Projection {
expr,
input,
schema: _,
}) => Projection::try_new(expr, input).map(LogicalPlan::Projection),
LogicalPlan::Dml(_) => Ok(self),
LogicalPlan::Copy(_) => Ok(self),
LogicalPlan::Values(Values { schema, values }) => {
// todo it isn't clear why the schema is not recomputed here
Ok(LogicalPlan::Values(Values { schema, values }))
}
LogicalPlan::Filter(Filter { predicate, input }) => {
// todo: should this logic be moved to Filter::try_new?

// filter predicates should not contain aliased expressions, so we remove any aliases.
// Before this logic was added, we would have aliases within filters, such as for
// benchmark q6:
//
// lineitem.l_shipdate >= Date32(\"8766\")
// AND lineitem.l_shipdate < Date32(\"9131\")
// AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS lineitem.l_discount >=
// Decimal128(Some(49999999999999),30,15)
// AND CAST(lineitem.l_discount AS Decimal128(30, 15)) AS lineitem.l_discount <=
// Decimal128(Some(69999999999999),30,15)
// AND lineitem.l_quantity < Decimal128(Some(2400),15,2)

let predicate = predicate
.transform_down(|expr| {
match expr {
Expr::Exists { .. }
| Expr::ScalarSubquery(_)
| Expr::InSubquery(_) => {
// subqueries could contain aliases so we don't recurse into those
Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump))
}
Expr::Alias(_) => Ok(Transformed::new(
expr.unalias(),
true,
TreeNodeRecursion::Jump,
)),
_ => Ok(Transformed::no(expr)),
}
})
.data()?;

Filter::try_new(predicate, input).map(LogicalPlan::Filter)
}
LogicalPlan::Repartition(_) => Ok(self),
LogicalPlan::Window(Window {
input,
window_expr,
schema: _,
}) => Window::try_new(window_expr, input).map(LogicalPlan::Window),
LogicalPlan::Aggregate(Aggregate {
input,
group_expr,
aggr_expr,
schema: _,
}) => Aggregate::try_new(input, group_expr, aggr_expr)
.map(LogicalPlan::Aggregate),
LogicalPlan::Sort(_) => Ok(self),
LogicalPlan::Join(Join {
left,
right,
filter,
join_type,
join_constraint,
on,
schema: _,
null_equals_null,
}) => {
let schema =
build_join_schema(left.schema(), right.schema(), &join_type)?;

let new_on: Vec<_> = on
.into_iter()
.map(|equi_expr| {
// The SimplifyExpression rule may add an alias to the equi_expr.
(equi_expr.0.unalias(), equi_expr.1.unalias())
})
.collect();

Ok(LogicalPlan::Join(Join {
left,
right,
join_type,
join_constraint,
on: new_on,
filter,
schema: DFSchemaRef::new(schema),
null_equals_null,
}))
}
LogicalPlan::CrossJoin(CrossJoin {
left,
right,
schema: _,
}) => {
let join_schema =
build_join_schema(left.schema(), right.schema(), &JoinType::Inner)?;

Ok(LogicalPlan::CrossJoin(CrossJoin {
left,
right,
schema: join_schema.into(),
}))
}
LogicalPlan::Subquery(_) => Ok(self),
LogicalPlan::SubqueryAlias(SubqueryAlias {
input,
alias,
schema: _,
}) => SubqueryAlias::try_new(input, alias).map(LogicalPlan::SubqueryAlias),
LogicalPlan::Limit(_) => Ok(self),
LogicalPlan::Ddl(_) => Ok(self),
LogicalPlan::Extension(Extension { node }) => {
// todo make an API that does not require cloning
// This requires a copy of the extension node's expressions and inputs
let expr = node.expressions();
let inputs: Vec<_> = node.inputs().into_iter().cloned().collect();
Ok(LogicalPlan::Extension(Extension {
node: node.from_template(&expr, &inputs),
}))
}
LogicalPlan::Union(Union { inputs, schema }) => {
let input_schema = inputs[0].schema();
// If the inputs were not pruned, do not change the schema
// TODO this seems wrong (shouldn't we always use the schema of the input?)
let schema = if schema.fields().len() == input_schema.fields().len() {
schema.clone()
} else {
input_schema.clone()
};
Ok(LogicalPlan::Union(Union { inputs, schema }))
}
LogicalPlan::Distinct(distinct) => {
let distinct = match distinct {
Distinct::All(input) => Distinct::All(input),
Distinct::On(DistinctOn {
on_expr,
select_expr,
sort_expr,
input,
schema: _,
}) => Distinct::On(DistinctOn::try_new(
on_expr,
select_expr,
sort_expr,
input,
)?),
};
Ok(LogicalPlan::Distinct(distinct))
}
LogicalPlan::RecursiveQuery(_) => Ok(self),
LogicalPlan::Analyze(_) => Ok(self),
LogicalPlan::Explain(_) => Ok(self),
LogicalPlan::Prepare(_) => Ok(self),
LogicalPlan::TableScan(_) => Ok(self),
LogicalPlan::EmptyRelation(_) => Ok(self),
LogicalPlan::Statement(_) => Ok(self),
LogicalPlan::DescribeTable(_) => Ok(self),
LogicalPlan::Unnest(Unnest {
input,
columns,
schema: _,
options,
}) => {
// Update schema with unnested column type.
unnest_with_options(unwrap_arc(input), columns, options)
}
}
}

/// Returns a new `LogicalPlan` based on `self` with inputs and
/// expressions replaced.
///
@@ -2526,30 +2721,48 @@ pub struct Unnest {

#[cfg(test)]
mod tests {

use super::*;
use crate::builder::LogicalTableSource;
use crate::logical_plan::table_scan;
-    use crate::{col, count, exists, in_subquery, lit, placeholder, GroupingSet};
+    use crate::{
+        col, count, exists, in_subquery, lit, max, placeholder, sum, GroupingSet,
+    };
use std::sync::OnceLock;

use datafusion_common::tree_node::TreeNodeVisitor;
use datafusion_common::{not_impl_err, Constraint, ScalarValue};

-    fn employee_schema() -> Schema {
-        Schema::new(vec![
-            Field::new("id", DataType::Int32, false),
-            Field::new("first_name", DataType::Utf8, false),
-            Field::new("last_name", DataType::Utf8, false),
-            Field::new("state", DataType::Utf8, false),
-            Field::new("salary", DataType::Int32, false),
-        ])
+    static EMPLOYEE_SCHEMA: OnceLock<Schema> = OnceLock::new();
+    fn employee_schema() -> &'static Schema {
+        EMPLOYEE_SCHEMA.get_or_init(|| {
+            Schema::new(vec![
+                Field::new("id", DataType::Int32, false),
+                Field::new("first_name", DataType::Utf8, false),
+                Field::new("last_name", DataType::Utf8, false),
+                Field::new("state", DataType::Utf8, false),
+                Field::new("salary", DataType::Int32, false),
+            ])
+        })
}

static ID_SCHEMA: OnceLock<Schema> = OnceLock::new();
fn id_schema() -> &'static Schema {
ID_SCHEMA
.get_or_init(|| Schema::new(vec![Field::new("id", DataType::Int32, false)]))
}

static FIRST_NAME_SCHEMA: OnceLock<Schema> = OnceLock::new();
fn first_name_schema() -> &'static Schema {
FIRST_NAME_SCHEMA.get_or_init(|| {
Schema::new(vec![Field::new("first_name", DataType::Utf8, false)])
})
}

fn display_plan() -> Result<LogicalPlan> {
-        let plan1 = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3]))?
+        let plan1 = table_scan(Some("employee_csv"), employee_schema(), Some(vec![3]))?
.build()?;

table_scan(Some("employee_csv"), &employee_schema(), Some(vec![0, 3]))?
table_scan(Some("employee_csv"), employee_schema(), Some(vec![0, 3]))?
.filter(in_subquery(col("state"), Arc::new(plan1)))?
.project(vec![col("id")])?
.build()
@@ -2585,14 +2798,13 @@ mod tests {

#[test]
fn test_display_subquery_alias() -> Result<()> {
-        let plan1 = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3]))?
+        let plan1 = table_scan(Some("employee_csv"), employee_schema(), Some(vec![3]))?
.build()?;
let plan1 = Arc::new(plan1);

-        let plan =
-            table_scan(Some("employee_csv"), &employee_schema(), Some(vec![0, 3]))?
-                .project(vec![col("id"), exists(plan1).alias("exists")])?
-                .build();
+        let plan = table_scan(Some("employee_csv"), employee_schema(), Some(vec![0, 3]))?
+            .project(vec![col("id"), exists(plan1).alias("exists")])?
+            .build();

let expected = "Projection: employee_csv.id, EXISTS (<subquery>) AS exists\
\n Subquery:\
@@ -3138,4 +3350,86 @@ digraph {
let actual = format!("{}", plan.display_indent());
assert_eq!(expected.to_string(), actual)
}

#[test]
Contributor Author:

Here are some basic tests that show what it does and how to use it.

fn recompute_schema_projection() -> Result<()> {
// SELECT id FROM employee_csv
let plan = table_scan(Some("employee_csv"), employee_schema(), None)?
.project(vec![col("id")])?
.build()?;
assert_eq!(plan.schema().as_arrow(), id_schema());

// rewrite to SELECT first_name FROM employee_csv
let plan = plan
.map_expressions(|_| Ok(Transformed::yes(col("first_name"))))?
.data;

// before recompute_schema, the schema is still the same
assert_eq!(plan.schema().as_arrow(), id_schema());
let plan = plan.recompute_schema()?;
assert_eq!(plan.schema().as_arrow(), first_name_schema());

Ok(())
}

#[test]
fn recompute_schema_window() -> Result<()> {
// SELECT id, SUM(salary) OVER () FROM employee_csv
let plan = table_scan(Some("employee_csv"), employee_schema(), None)?
.project(vec![col("id"), col("salary")])?
.window(vec![sum(col("salary"))])?
.build()?;

// rewrite to SELECT id, MAX(salary) OVER () FROM employee_csv
let plan = plan
.map_expressions(|_| Ok(Transformed::yes(max(col("salary")))))?
.data;

// before recompute_schema, the schema should be SUM
let expected_schema = Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("salary", DataType::Int32, false),
Field::new("SUM(employee_csv.salary)", DataType::Int64, true),
]);
assert_eq!(plan.schema().as_arrow(), &expected_schema);

// after recompute_schema, the schema should be MAX
let plan = plan.recompute_schema()?;
let expected_schema = Schema::new(vec![
Field::new("id", DataType::Int32, false),
Field::new("salary", DataType::Int32, false),
Field::new("MAX(salary)", DataType::Int32, true),
]);
assert_eq!(plan.schema().as_arrow(), &expected_schema);
Ok(())
}

#[test]
fn recompute_schema_aggregate() -> Result<()> {
// SELECT SUM(salary) FROM employee_csv
let plan = table_scan(Some("employee_csv"), employee_schema(), None)?
.project(vec![col("salary")])?
.aggregate(vec![] as Vec<Expr>, vec![sum(col("salary"))])?
.build()?;

// rewrite to SELECT MAX(salary) FROM employee_csv
let plan = plan
.map_expressions(|_| Ok(Transformed::yes(max(col("salary")))))?
.data;

// before recompute_schema, the schema should be SUM
let expected_schema = Schema::new(vec![Field::new(
"SUM(employee_csv.salary)",
DataType::Int64,
true,
)]);
assert_eq!(plan.schema().as_arrow(), &expected_schema);

// after recompute_schema, the schema should be MAX
let plan = plan.recompute_schema()?;
let expected_schema =
Schema::new(vec![Field::new("MAX(salary)", DataType::Int32, true)]);
assert_eq!(plan.schema().as_arrow(), &expected_schema);
Ok(())
}
}