From 3100700a4d4c0f2623319afa0a257311e7de19a5 Mon Sep 17 00:00:00 2001 From: wforget <643348094@qq.com> Date: Thu, 20 Feb 2025 14:03:50 +0800 Subject: [PATCH 1/9] feat: Support IntegralDivide function --- .../org/apache/comet/serde/QueryPlanSerde.scala | 14 ++++++++++++++ .../org/apache/comet/CometExpressionSuite.scala | 13 +++++++++++++ 2 files changed, 27 insertions(+) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index aa1aba11d..cb5e353e0 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -631,6 +631,20 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim } None + case IntegralDivide(left, right, evalMode) + if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => + // convert IntegralDivide(...) to Cast(Divide(...), LongType) + exprToProtoInternal(Cast(Divide(left, right, evalMode), LongType), inputs, binding) + + case div @ IntegralDivide(left, _, _) => + if (!supportedDataType(left.dataType)) { + withInfo(div, s"Unsupported datatype ${left.dataType}") + } + if (decimalBeforeSpark34(left.dataType)) { + withInfo(div, "Decimal support requires Spark 3.4 or later") + } + None + case rem @ Remainder(left, right, _) if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => val rightExpr = nullIfWhenPrimitive(right) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index e5edfd56b..48f5d0df8 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -2641,4 +2641,17 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("test integral divide") { + withTable("t1", "t2") { + sql(s"create table t1(c1 long, c2 int) using parquet") + // TODO: COMET-1412: Support warping div on overflow for Long.MinValue / -1 + sql(s"insert into t1 values(10, 0), (52, 10)") + checkSparkAnswerAndOperator("select c1 div c2, div(c1, c2) from t1 order by c1") + + sql(s"create table t2(c1 decimal(10, 2), c2 decimal(10, 2)) using parquet") + sql(s"insert into t2 values(15.09, 5.0), (13.2, 2), (18.66, 0)") + checkSparkAnswerAndOperator("select c1 div c2, div(c1, c2) from t2 order by c1") + } + } + } From 75f125a36b6ee0ff7bd76316d1a1258ecf407045 Mon Sep 17 00:00:00 2001 From: wforget <643348094@qq.com> Date: Thu, 20 Feb 2025 14:35:35 +0800 Subject: [PATCH 2/9] format --- .../scala/org/apache/comet/serde/QueryPlanSerde.scala | 2 +- .../scala/org/apache/comet/CometExpressionSuite.scala | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index cb5e353e0..73caa4c96 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -632,7 +632,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim None case IntegralDivide(left, right, evalMode) - if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => + if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => // convert IntegralDivide(...) 
to Cast(Divide(...), LongType) exprToProtoInternal(Cast(Divide(left, right, evalMode), LongType), inputs, binding) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 48f5d0df8..ab2432e9e 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -2643,13 +2643,13 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { test("test integral divide") { withTable("t1", "t2") { - sql(s"create table t1(c1 long, c2 int) using parquet") + sql("create table t1(c1 long, c2 int) using parquet") // TODO: COMET-1412: Support warping div on overflow for Long.MinValue / -1 - sql(s"insert into t1 values(10, 0), (52, 10)") + sql("insert into t1 values(10, 0), (52, 10)") checkSparkAnswerAndOperator("select c1 div c2, div(c1, c2) from t1 order by c1") - sql(s"create table t2(c1 decimal(10, 2), c2 decimal(10, 2)) using parquet") - sql(s"insert into t2 values(15.09, 5.0), (13.2, 2), (18.66, 0)") + sql("create table t2(c1 decimal(10, 2), c2 decimal(10, 2)) using parquet") + sql("insert into t2 values(15.09, 5.0), (13.2, 2), (18.66, 0)") checkSparkAnswerAndOperator("select c1 div c2, div(c1, c2) from t2 order by c1") } } From fef19221651418ce8be13c3b6ca5b5e3a500ed6d Mon Sep 17 00:00:00 2001 From: wforget <643348094@qq.com> Date: Thu, 20 Feb 2025 16:05:20 +0800 Subject: [PATCH 3/9] fix --- .../scala/org/apache/comet/CometExpressionSuite.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index ab2432e9e..cac8b55de 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -2648,9 +2648,12 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { sql("insert into t1 values(10, 0), (52, 10)") checkSparkAnswerAndOperator("select c1 div c2, div(c1, c2) from t1 order by c1") - sql("create table t2(c1 decimal(10, 2), c2 decimal(10, 2)) using parquet") - sql("insert into t2 values(15.09, 5.0), (13.2, 2), (18.66, 0)") - checkSparkAnswerAndOperator("select c1 div c2, div(c1, c2) from t2 order by c1") + if (isSpark34Plus) { + // Decimal support requires Spark 3.4 or later + sql("create table t2(c1 decimal(10, 2), c2 decimal(10, 2)) using parquet") + sql("insert into t2 values(15.09, 5.0), (13.2, 2), (18.66, 0)") + checkSparkAnswerAndOperator("select c1 div c2, div(c1, c2) from t2 order by c1") + } } } From 0c1b89945ec9094eef2220c3858082295d565ccb Mon Sep 17 00:00:00 2001 From: wforget <643348094@qq.com> Date: Wed, 26 Feb 2025 15:29:22 +0800 Subject: [PATCH 4/9] decimal integral div --- native/core/src/execution/planner.rs | 50 ++++++++++++++++- native/proto/src/proto/expr.proto | 1 + native/spark-expr/benches/decimal_div.rs | 15 +++++- native/spark-expr/src/comet_scalar_funcs.rs | 13 +++-- native/spark-expr/src/lib.rs | 6 +-- native/spark-expr/src/math_funcs/div.rs | 37 +++++++++++-- native/spark-expr/src/math_funcs/mod.rs | 1 + .../apache/comet/serde/QueryPlanSerde.scala | 28 ++++++++-- .../apache/comet/CometExpressionSuite.scala | 54 +++++++++++++++---- .../org/apache/spark/sql/CometTestBase.scala | 10 ++++ 10 files changed, 186 insertions(+), 29 deletions(-) diff --git a/native/core/src/execution/planner.rs 
b/native/core/src/execution/planner.rs index f42a9ed19..5c6a9b8c1 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -136,6 +136,18 @@ struct JoinParameters { pub join_type: DFJoinType, } +struct BinaryExprOptions { + pub is_integral_div: bool, +} + +impl Default for BinaryExprOptions { + fn default() -> Self { + Self { + is_integral_div: false, + } + } +} + pub const TEST_EXEC_CONTEXT_ID: i64 = -1; /// The query planner for converting Spark query plans to DataFusion query plans. @@ -211,6 +223,16 @@ impl PhysicalPlanner { DataFusionOperator::Divide, input_schema, ), + ExprStruct::IntegralDivide(expr) => self.create_binary_expr_with_options( + expr.left.as_ref().unwrap(), + expr.right.as_ref().unwrap(), + expr.return_type.as_ref(), + DataFusionOperator::Divide, + input_schema, + BinaryExprOptions { + is_integral_div: true, + }, + ), ExprStruct::Remainder(expr) => self.create_binary_expr( expr.left.as_ref().unwrap(), expr.right.as_ref().unwrap(), @@ -873,6 +895,25 @@ impl PhysicalPlanner { return_type: Option<&spark_expression::DataType>, op: DataFusionOperator, input_schema: SchemaRef, + ) -> Result, ExecutionError> { + self.create_binary_expr_with_options( + left, + right, + return_type, + op, + input_schema, + BinaryExprOptions::default(), + ) + } + + fn create_binary_expr_with_options( + &self, + left: &Expr, + right: &Expr, + return_type: Option<&spark_expression::DataType>, + op: DataFusionOperator, + input_schema: SchemaRef, + options: BinaryExprOptions, ) -> Result, ExecutionError> { let left = self.create_expr(left, Arc::clone(&input_schema))?; let right = self.create_expr(right, Arc::clone(&input_schema))?; @@ -922,13 +963,18 @@ impl PhysicalPlanner { Ok(DataType::Decimal128(_p2, _s2)), ) => { let data_type = return_type.map(to_arrow_datatype).unwrap(); + let func_name = if options.is_integral_div { + "decimal_integral_div" + } else { + "decimal_div" + }; let fun_expr = create_comet_physical_fun( - "decimal_div", + func_name, data_type.clone(), &self.session_ctx.state(), )?; Ok(Arc::new(ScalarFunctionExpr::new( - "decimal_div", + func_name, fun_expr, vec![left, right], data_type, diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index fd928fd8a..71ad8cf3f 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -89,6 +89,7 @@ message Expr { BinaryExpr array_intersect = 62; ArrayJoin array_join = 63; BinaryExpr arrays_overlap = 64; + MathExpr integral_divide = 65; } } diff --git a/native/spark-expr/benches/decimal_div.rs b/native/spark-expr/benches/decimal_div.rs index ad527fecb..1d25d815a 100644 --- a/native/spark-expr/benches/decimal_div.rs +++ b/native/spark-expr/benches/decimal_div.rs @@ -19,7 +19,7 @@ use arrow::compute::cast; use arrow_array::builder::Decimal128Builder; use arrow_schema::DataType; use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use datafusion_comet_spark_expr::spark_decimal_div; +use datafusion_comet_spark_expr::{spark_decimal_div, spark_decimal_integral_div}; use datafusion_expr::ColumnarValue; use std::sync::Arc; @@ -40,7 +40,9 @@ fn criterion_benchmark(c: &mut Criterion) { let c2 = cast(c2.as_ref(), &c2_type).unwrap(); let args = [ColumnarValue::Array(c1), ColumnarValue::Array(c2)]; - c.bench_function("decimal_div", |b| { + + let mut group = c.benchmark_group("decimal div"); + group.bench_function("decimal_div", |b| { b.iter(|| { black_box(spark_decimal_div( black_box(&args), @@ -48,6 +50,15 @@ fn criterion_benchmark(c: 
&mut Criterion) { )) }) }); + + group.bench_function("decimal_integral_div", |b| { + b.iter(|| { + black_box(spark_decimal_integral_div( + black_box(&args), + black_box(&DataType::Decimal128(10, 4)), + )) + }) + }); } criterion_group!(benches, criterion_benchmark); diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 6070e81d2..227b6f72e 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -17,9 +17,9 @@ use crate::hash_funcs::*; use crate::{ - spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div, spark_floor, spark_hex, - spark_isnan, spark_make_decimal, spark_read_side_padding, spark_round, spark_unhex, - spark_unscaled_value, SparkChrFunc, + spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div, spark_decimal_integral_div, + spark_floor, spark_hex, spark_isnan, spark_make_decimal, spark_read_side_padding, spark_round, + spark_unhex, spark_unscaled_value, SparkChrFunc, }; use arrow_schema::DataType; use datafusion_common::{DataFusionError, Result as DataFusionResult}; @@ -90,6 +90,13 @@ pub fn create_comet_physical_fun( "decimal_div" => { make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type) } + "decimal_integral_div" => { + make_comet_scalar_udf!( + "decimal_integral_div", + spark_decimal_integral_div, + data_type + ) + } "murmur3_hash" => { let func = Arc::new(spark_murmur3_hash); make_comet_scalar_udf!("murmur3_hash", func, without data_type) diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index 9bf6bb24f..ae8e639b3 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -67,9 +67,9 @@ pub use error::{SparkError, SparkResult}; pub use hash_funcs::*; pub use json_funcs::ToJson; pub use math_funcs::{ - create_negate_expr, spark_ceil, spark_decimal_div, spark_floor, spark_hex, spark_make_decimal, - spark_round, spark_unhex, spark_unscaled_value, CheckOverflow, NegativeExpr, - NormalizeNaNAndZero, + create_negate_expr, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, + spark_hex, spark_make_decimal, spark_round, spark_unhex, spark_unscaled_value, CheckOverflow, + NegativeExpr, NormalizeNaNAndZero, }; pub use string_funcs::*; diff --git a/native/spark-expr/src/math_funcs/div.rs b/native/spark-expr/src/math_funcs/div.rs index 72c23b9e9..0a28197c1 100644 --- a/native/spark-expr/src/math_funcs/div.rs +++ b/native/spark-expr/src/math_funcs/div.rs @@ -27,14 +27,29 @@ use datafusion_common::DataFusionError; use num::{BigInt, Signed, ToPrimitive}; use std::sync::Arc; +pub fn spark_decimal_div( + args: &[ColumnarValue], + data_type: &DataType, +) -> Result { + spark_decimal_div_internal(args, data_type, false) +} + +pub fn spark_decimal_integral_div( + args: &[ColumnarValue], + data_type: &DataType, +) -> Result { + spark_decimal_div_internal(args, data_type, true) +} + // Let Decimal(p3, s3) as return type i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3). // Conversely, Decimal(p1, s1) = Decimal(p2, s2) * Decimal(p3, s3). This means that, in order to // get enough scale that matches with Spark behavior, it requires to widen s1 to s2 + s3 + 1. Since // both s2 and s3 are 38 at max., s1 is 77 at max. DataFusion division cannot handle such scale > // Decimal256Type::MAX_SCALE. Therefore, we need to implement this decimal division using BigInt. 
-pub fn spark_decimal_div( +fn spark_decimal_div_internal( args: &[ColumnarValue], data_type: &DataType, + is_integral_div: bool, ) -> Result { let left = &args[0]; let right = &args[1]; @@ -69,10 +84,14 @@ pub fn spark_decimal_div( let l = BigInt::from(l) * &l_mul; let r = BigInt::from(r) * &r_mul; let div = if r.eq(&zero) { zero.clone() } else { &l / &r }; - let res = if div.is_negative() { - div - &five + let res = if is_integral_div { + div } else { - div + &five + if div.is_negative() { + div - &five + } else { + div + &five + } } / &ten; res.to_i128().unwrap_or(i128::MAX) })? @@ -83,7 +102,15 @@ pub fn spark_decimal_div( let l = l * l_mul; let r = r * r_mul; let div = if r == 0 { 0 } else { l / r }; - let res = if div.is_negative() { div - 5 } else { div + 5 } / 10; + let res = if is_integral_div { + div + } else { + if div.is_negative() { + div - 5 + } else { + div + 5 + } + } / 10; res.to_i128().unwrap_or(i128::MAX) })? }; diff --git a/native/spark-expr/src/math_funcs/mod.rs b/native/spark-expr/src/math_funcs/mod.rs index c559ae15c..03eb9a76c 100644 --- a/native/spark-expr/src/math_funcs/mod.rs +++ b/native/spark-expr/src/math_funcs/mod.rs @@ -27,6 +27,7 @@ mod utils; pub use ceil::spark_ceil; pub use div::spark_decimal_div; +pub use div::spark_decimal_integral_div; pub use floor::spark_floor; pub use hex::spark_hex; pub use internal::*; diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 73caa4c96..7d656567d 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -631,10 +631,32 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim } None - case IntegralDivide(left, right, evalMode) + case div @ IntegralDivide(left, right, _) if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => - // convert IntegralDivide(...) 
to Cast(Divide(...), LongType) - exprToProtoInternal(Cast(Divide(left, right, evalMode), LongType), inputs, binding) + val rightExpr = nullIfWhenPrimitive(right) + + val dataType = (left.dataType, right.dataType) match { + case (l: DecimalType, r: DecimalType) => + div.resultDecimalType(l.precision, l.scale, r.precision, r.scale) + case _ => left.dataType + } + + val divideExpr = createMathExpression( + expr, + left, + rightExpr, + inputs, + binding, + dataType, + getFailOnError(div), + (builder, mathExpr) => builder.setIntegralDivide(mathExpr)) + + if (divideExpr.isDefined) { + // cast result to long + castToProto(expr, None, LongType, divideExpr.get, CometEvalMode.LEGACY) + } else { + None + } case div @ IntegralDivide(left, _, _) => if (!supportedDataType(left.dataType)) { diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index cac8b55de..74cb6d201 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -2642,17 +2642,49 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } test("test integral divide") { - withTable("t1", "t2") { - sql("create table t1(c1 long, c2 int) using parquet") - // TODO: COMET-1412: Support warping div on overflow for Long.MinValue / -1 - sql("insert into t1 values(10, 0), (52, 10)") - checkSparkAnswerAndOperator("select c1 div c2, div(c1, c2) from t1 order by c1") - - if (isSpark34Plus) { - // Decimal support requires Spark 3.4 or later - sql("create table t2(c1 decimal(10, 2), c2 decimal(10, 2)) using parquet") - sql("insert into t2 values(15.09, 5.0), (13.2, 2), (18.66, 0)") - checkSparkAnswerAndOperator("select c1 div c2, div(c1, c2) from t2 order by c1") + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path1 = new Path(dir.toURI.toString, "test1.parquet") + val path2 = new Path(dir.toURI.toString, "test2.parquet") + makeParquetFileAllTypes( + path1, + dictionaryEnabled = dictionaryEnabled, + 0, + 0, + randomSize = 10000) + makeParquetFileAllTypes( + path2, + dictionaryEnabled = dictionaryEnabled, + 0, + 0, + randomSize = 10000) + withParquetTable(path1.toString, "tbl1") { + withParquetTable(path2.toString, "tbl2") { + checkSparkAnswerAndOperator(""" + |select + | t1._2 div t2._2, div(t1._2, t2._2), + | t1._3 div t2._3, div(t1._3, t2._3), + | t1._4 div t2._4, div(t1._4, t2._4), + | t1._5 div t2._5, div(t1._5, t2._5), + | t1._9 div t2._9, div(t1._9, t2._9), + | t1._10 div t2._10, div(t1._10, t2._10), + | t1._11 div t2._11, div(t1._11, t2._11), + | t1._12 div t2._12, div(t1._12, t2._12) + | from tbl1 t1 join tbl2 t2 on t1._id = t2._id + | order by t1._id""".stripMargin) + + if (isSpark34Plus) { + // Decimal support requires Spark 3.4 or later + checkSparkAnswerAndOperator(""" + |select + | t1._15, t2._15, t1._15 div t2._15, div(t1._15, t2._15), + | t1._16, t2._16, t1._16 div t2._16, div(t1._16, t2._16), + | t1._17, t2._17, t1._17 div t2._17, div(t1._17, t2._17) + | from tbl1 t1 join tbl2 t2 on t1._id = t2._id + | order by t1._id""".stripMargin) + } + } + } } } } diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala index cd5ac7f86..10c5ca210 100644 --- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql 
+import java.util.concurrent.atomic.AtomicInteger + import scala.concurrent.duration._ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag @@ -461,6 +463,7 @@ abstract class CometTestBase | optional INT64 _18(TIMESTAMP(MILLIS,true)); | optional INT64 _19(TIMESTAMP(MICROS,true)); | optional INT32 _20(DATE); + | optional INT32 _id; |} """.stripMargin } else { @@ -486,6 +489,7 @@ abstract class CometTestBase | optional INT64 _18(TIMESTAMP(MILLIS,true)); | optional INT64 _19(TIMESTAMP(MICROS,true)); | optional INT32 _20(DATE); + | optional INT32 _id; |} """.stripMargin } @@ -514,6 +518,7 @@ abstract class CometTestBase | optional INT64 _18(TIMESTAMP(MILLIS,true)); | optional INT64 _19(TIMESTAMP(MICROS,true)); | optional INT32 _20(DATE); + | optional INT32 _id; |} """.stripMargin } else { @@ -539,6 +544,7 @@ abstract class CometTestBase | optional INT64 _18(TIMESTAMP(MILLIS,true)); | optional INT64 _19(TIMESTAMP(MICROS,true)); | optional INT32 _20(DATE); + | optional INT32 _id; |} """.stripMargin } @@ -564,6 +570,8 @@ abstract class CometTestBase pageSize = pageSize, dictionaryPageSize = pageSize) + val idGenerator = new AtomicInteger(0) + val rand = scala.util.Random val data = (begin until end).map { i => if (rand.nextBoolean()) { @@ -596,6 +604,7 @@ abstract class CometTestBase record.add(17, i.toLong) record.add(18, i.toLong) record.add(19, i) + record.add(20, idGenerator.getAndIncrement()) case _ => } writer.write(record) @@ -623,6 +632,7 @@ abstract class CometTestBase record.add(17, i) record.add(18, i) record.add(19, i.toInt) + record.add(20, idGenerator.getAndIncrement()) writer.write(record) } From 11f36eb69cb952d5ec6414d9dff3240f2240b9ce Mon Sep 17 00:00:00 2001 From: wforget <643348094@qq.com> Date: Wed, 26 Feb 2025 15:43:05 +0800 Subject: [PATCH 5/9] fix code style --- native/core/src/execution/planner.rs | 9 +-------- native/spark-expr/src/math_funcs/div.rs | 16 ++++++---------- 2 files changed, 7 insertions(+), 18 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 5c6a9b8c1..261df954a 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -136,18 +136,11 @@ struct JoinParameters { pub join_type: DFJoinType, } +#[derive(Default)] struct BinaryExprOptions { pub is_integral_div: bool, } -impl Default for BinaryExprOptions { - fn default() -> Self { - Self { - is_integral_div: false, - } - } -} - pub const TEST_EXEC_CONTEXT_ID: i64 = -1; /// The query planner for converting Spark query plans to DataFusion query plans. diff --git a/native/spark-expr/src/math_funcs/div.rs b/native/spark-expr/src/math_funcs/div.rs index 0a28197c1..8abfc431c 100644 --- a/native/spark-expr/src/math_funcs/div.rs +++ b/native/spark-expr/src/math_funcs/div.rs @@ -86,12 +86,10 @@ fn spark_decimal_div_internal( let div = if r.eq(&zero) { zero.clone() } else { &l / &r }; let res = if is_integral_div { div + } else if div.is_negative() { + div - &five } else { - if div.is_negative() { - div - &five - } else { - div + &five - } + div + &five } / &ten; res.to_i128().unwrap_or(i128::MAX) })? @@ -104,12 +102,10 @@ fn spark_decimal_div_internal( let div = if r == 0 { 0 } else { l / r }; let res = if is_integral_div { div + } else if div.is_negative() { + div - 5 } else { - if div.is_negative() { - div - 5 - } else { - div + 5 - } + div + 5 } / 10; res.to_i128().unwrap_or(i128::MAX) })? 
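Note: the div.rs hunk above finishes the rounding split introduced in PATCH 4/9. spark_decimal_div_internal computes the quotient with one extra digit of scale and then either rounds half away from zero (decimal_div) or truncates toward zero (decimal_integral_div). Below is a minimal standalone sketch of just that last step; the helper name round_last_digit is hypothetical (not part of the patch), and the example assumes a target scale of 0 so a quotient like 6.6 is held as 66.

    // Sketch of the final rounding step in spark_decimal_div_internal.
    // scaled_quotient carries one extra decimal digit; that digit is dropped here.
    fn round_last_digit(scaled_quotient: i128, is_integral_div: bool) -> i128 {
        let adjusted = if is_integral_div {
            scaled_quotient // integral division: no rounding adjustment
        } else if scaled_quotient < 0 {
            scaled_quotient - 5 // round half away from zero, negative side
        } else {
            scaled_quotient + 5 // round half away from zero, positive side
        };
        adjusted / 10 // i128 division truncates toward zero
    }

    fn main() {
        assert_eq!(round_last_digit(66, false), 7);  // 6.6 rounds to 7 for decimal_div
        assert_eq!(round_last_digit(66, true), 6);   // 6.6 truncates to 6 for integral div
        assert_eq!(round_last_digit(-66, false), -7);
        assert_eq!(round_last_digit(-66, true), -6); // truncation is toward zero
    }

This sketch only exercises plain i128 arithmetic; the real function applies the same step element-wise to Decimal128 arrays (or BigInt values on the widened path) behind ColumnarValue.
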
From 66efcfe85144c65af0a800bd6c0c679bc9bf5106 Mon Sep 17 00:00:00 2001 From: wforget <643348094@qq.com> Date: Wed, 26 Feb 2025 16:10:03 +0800 Subject: [PATCH 6/9] compatible with spark 3.3 --- .../main/scala/org/apache/comet/serde/QueryPlanSerde.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 7d656567d..381a75b65 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -20,6 +20,7 @@ package org.apache.comet.serde import scala.collection.JavaConverters._ +import scala.math.min import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ @@ -637,7 +638,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim val dataType = (left.dataType, right.dataType) match { case (l: DecimalType, r: DecimalType) => - div.resultDecimalType(l.precision, l.scale, r.precision, r.scale) + // copy from IntegralDivide.resultDecimalType + val intDig = l.precision - l.scale + r.scale + DecimalType(min(if (intDig == 0) 1 else intDig, DecimalType.MAX_PRECISION), 0) case _ => left.dataType } From fd158bf493150881e6a8d11202a6cff8d217d718 Mon Sep 17 00:00:00 2001 From: wforget <643348094@qq.com> Date: Wed, 26 Feb 2025 18:23:12 +0800 Subject: [PATCH 7/9] test compatible with spark 3.3 --- .../apache/comet/CometExpressionSuite.scala | 41 +++++++++++-------- 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 74cb6d201..90f1cf375 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -2660,28 +2660,33 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { randomSize = 10000) withParquetTable(path1.toString, "tbl1") { withParquetTable(path2.toString, "tbl2") { - checkSparkAnswerAndOperator(""" - |select - | t1._2 div t2._2, div(t1._2, t2._2), - | t1._3 div t2._3, div(t1._3, t2._3), - | t1._4 div t2._4, div(t1._4, t2._4), - | t1._5 div t2._5, div(t1._5, t2._5), - | t1._9 div t2._9, div(t1._9, t2._9), - | t1._10 div t2._10, div(t1._10, t2._10), - | t1._11 div t2._11, div(t1._11, t2._11), - | t1._12 div t2._12, div(t1._12, t2._12) - | from tbl1 t1 join tbl2 t2 on t1._id = t2._id - | order by t1._id""".stripMargin) - - if (isSpark34Plus) { - // Decimal support requires Spark 3.4 or later + // disable broadcast, as comet on spark 3.3 does not support broadcast exchange + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { checkSparkAnswerAndOperator(""" |select - | t1._15, t2._15, t1._15 div t2._15, div(t1._15, t2._15), - | t1._16, t2._16, t1._16 div t2._16, div(t1._16, t2._16), - | t1._17, t2._17, t1._17 div t2._17, div(t1._17, t2._17) + | t1._2 div t2._2, div(t1._2, t2._2), + | t1._3 div t2._3, div(t1._3, t2._3), + | t1._4 div t2._4, div(t1._4, t2._4), + | t1._5 div t2._5, div(t1._5, t2._5), + | t1._9 div t2._9, div(t1._9, t2._9), + | t1._10 div t2._10, div(t1._10, t2._10), + | t1._11 div t2._11, div(t1._11, t2._11) | from tbl1 t1 join tbl2 t2 on t1._id = t2._id | order by t1._id""".stripMargin) + + if (isSpark34Plus) { + // decimal support requires Spark 3.4 or 
later + checkSparkAnswerAndOperator(""" + |select + | t1._12 div t2._12, div(t1._12, t2._12), + | t1._15 div t2._15, div(t1._15, t2._15), + | t1._16 div t2._16, div(t1._16, t2._16), + | t1._17 div t2._17, div(t1._17, t2._17) + | from tbl1 t1 join tbl2 t2 on t1._id = t2._id + | order by t1._id""".stripMargin) + } } } } From fb6e8b5902b1ff39a746c088637ecb65349ff1e2 Mon Sep 17 00:00:00 2001 From: wforget <643348094@qq.com> Date: Fri, 28 Feb 2025 15:57:53 +0800 Subject: [PATCH 8/9] add div operator to `Decimal random number tests` --- .../src/test/scala/org/apache/comet/CometExpressionSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 90f1cf375..ba13b439c 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -1811,7 +1811,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { "spark.sql.decimalOperations.allowPrecisionLoss" -> allowPrecisionLoss.toString) { val a = makeNum(p1, s1) val b = makeNum(p2, s2) - var ops = Seq("+", "-", "*", "/", "%") + val ops = Seq("+", "-", "*", "/", "%", "div") for (op <- ops) { checkSparkAnswerAndOperator(s"select a, b, a $op b from $table") checkSparkAnswerAndOperator(s"select $a, b, $a $op b from $table") From 2f3d3a75ad22061dc37237466cac3b302df46034 Mon Sep 17 00:00:00 2001 From: wforget <643348094@qq.com> Date: Fri, 28 Feb 2025 16:17:48 +0800 Subject: [PATCH 9/9] add comment --- native/core/src/execution/planner.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 261df954a..34074f2ca 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -957,6 +957,9 @@ impl PhysicalPlanner { ) => { let data_type = return_type.map(to_arrow_datatype).unwrap(); let func_name = if options.is_integral_div { + // Decimal256 division in Arrow may overflow, so we still need this variant of decimal_div. + // Otherwise, we may be able to reuse the previous case-match instead of here, + // see more: https://github.com/apache/datafusion-comet/pull/1428#discussion_r1972648463 "decimal_integral_div" } else { "decimal_div"
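
Note on the Scala side of the series: PATCH 4/9 serializes IntegralDivide as a native integral_divide MathExpr and casts the native result to LongType; for decimal inputs, the declared result type follows Spark's IntegralDivide.resultDecimalType rule, which PATCH 6/9 inlines so the serde does not depend on the Spark 3.4+ method. A rough sketch of that type rule follows; the helper name integral_div_result_type is hypothetical and not part of the patch.

    // Integral quotient of decimal(p1, s1) by decimal(p2, s2): scale 0, precision
    // p1 - s1 + s2, special-cased to 1 when that is 0 and capped at 38
    // (DecimalType.MAX_PRECISION).
    fn integral_div_result_type(p1: u32, s1: u32, _p2: u32, s2: u32) -> (u32, u32) {
        let int_dig = p1 - s1 + s2;
        let precision = if int_dig == 0 { 1 } else { int_dig.min(38) };
        (precision, 0) // (precision, scale)
    }

    fn main() {
        // decimal(10, 2) div decimal(10, 2) -> decimal(10, 0), then cast to long
        assert_eq!(integral_div_result_type(10, 2, 10, 2), (10, 0));
        // decimal(38, 10) div decimal(5, 0) -> 38 - 10 + 0 = 28 -> decimal(28, 0)
        assert_eq!(integral_div_result_type(38, 10, 5, 0), (28, 0));
    }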