diff --git a/datafusion/optimizer/src/filter_null_join_keys.rs b/datafusion/optimizer/src/filter_null_join_keys.rs index 01e3d27c580f..caf94a2c88bd 100644 --- a/datafusion/optimizer/src/filter_null_join_keys.rs +++ b/datafusion/optimizer/src/filter_null_join_keys.rs @@ -26,7 +26,7 @@ use datafusion_expr::utils::conjunction; use datafusion_expr::{logical_plan::Filter, Expr, ExprSchemable, LogicalPlan}; use std::sync::Arc; -/// The FilterNullJoinKeys rule will identify joins with equi-join conditions +/// The FilterNullJoinKeys rule will identify joins with equi-join conditions /// where the join key is nullable and then insert an `IsNotNull` filter on the nullable side since null values /// can never match. #[derive(Default)] @@ -50,7 +50,9 @@ impl OptimizerRule for FilterNullJoinKeys { return Ok(Transformed::no(plan)); } match plan { - LogicalPlan::Join(mut join) if !join.on.is_empty() => { + LogicalPlan::Join(mut join) + if !join.on.is_empty() && !join.null_equals_null => + { let (left_preserved, right_preserved) = on_lr_is_preserved(join.join_type); diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index 5292b66197f6..da5e92eafd11 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -177,15 +177,12 @@ fn intersect() -> Result<()> { let plan = test_sql(sql)?; let expected = "LeftSemi Join: test.col_int32 = test.col_int32, test.col_utf8 = test.col_utf8\ - \n Aggregate: groupBy=[[test.col_int32, test.col_utf8]], aggr=[[]]\ - \n LeftSemi Join: test.col_int32 = test.col_int32, test.col_utf8 = test.col_utf8\ - \n Aggregate: groupBy=[[test.col_int32, test.col_utf8]], aggr=[[]]\ - \n Filter: test.col_int32 IS NOT NULL AND test.col_utf8 IS NOT NULL\ - \n TableScan: test projection=[col_int32, col_utf8]\ - \n Filter: test.col_int32 IS NOT NULL AND test.col_utf8 IS NOT NULL\ - \n TableScan: test projection=[col_int32, col_utf8]\ - \n Filter: test.col_int32 IS NOT NULL AND test.col_utf8 IS NOT NULL\ - \n TableScan: test projection=[col_int32, col_utf8]"; + \n Aggregate: groupBy=[[test.col_int32, test.col_utf8]], aggr=[[]]\ + \n LeftSemi Join: test.col_int32 = test.col_int32, test.col_utf8 = test.col_utf8\ + \n Aggregate: groupBy=[[test.col_int32, test.col_utf8]], aggr=[[]]\ + \n TableScan: test projection=[col_int32, col_utf8]\ + \n TableScan: test projection=[col_int32, col_utf8]\ + \n TableScan: test projection=[col_int32, col_utf8]"; assert_eq!(expected, format!("{plan}")); Ok(()) } @@ -281,11 +278,9 @@ fn test_same_name_but_not_ambiguous() { let expected = "LeftSemi Join: t1.col_int32 = t2.col_int32\ \n Aggregate: groupBy=[[t1.col_int32]], aggr=[[]]\ \n SubqueryAlias: t1\ - \n Filter: test.col_int32 IS NOT NULL\ - \n TableScan: test projection=[col_int32]\ + \n TableScan: test projection=[col_int32]\ \n SubqueryAlias: t2\ - \n Filter: test.col_int32 IS NOT NULL\ - \n TableScan: test projection=[col_int32]"; + \n TableScan: test projection=[col_int32]"; assert_eq!(expected, format!("{plan}")); }