feat: Support IntegralDivide function #1428

Open: wants to merge 9 commits into base branch main
46 changes: 44 additions & 2 deletions native/core/src/execution/planner.rs
@@ -136,6 +136,11 @@ struct JoinParameters {
pub join_type: DFJoinType,
}

#[derive(Default)]
struct BinaryExprOptions {
pub is_integral_div: bool,
}

pub const TEST_EXEC_CONTEXT_ID: i64 = -1;

/// The query planner for converting Spark query plans to DataFusion query plans.
@@ -211,6 +216,16 @@ impl PhysicalPlanner {
DataFusionOperator::Divide,
input_schema,
),
ExprStruct::IntegralDivide(expr) => self.create_binary_expr_with_options(
expr.left.as_ref().unwrap(),
expr.right.as_ref().unwrap(),
expr.return_type.as_ref(),
DataFusionOperator::Divide,
input_schema,
BinaryExprOptions {
is_integral_div: true,
},
),
ExprStruct::Remainder(expr) => self.create_binary_expr(
expr.left.as_ref().unwrap(),
expr.right.as_ref().unwrap(),
@@ -873,6 +888,25 @@ impl PhysicalPlanner {
return_type: Option<&spark_expression::DataType>,
op: DataFusionOperator,
input_schema: SchemaRef,
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
self.create_binary_expr_with_options(
left,
right,
return_type,
op,
input_schema,
BinaryExprOptions::default(),
)
}

fn create_binary_expr_with_options(
&self,
left: &Expr,
right: &Expr,
return_type: Option<&spark_expression::DataType>,
op: DataFusionOperator,
input_schema: SchemaRef,
options: BinaryExprOptions,
) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
let left = self.create_expr(left, Arc::clone(&input_schema))?;
let right = self.create_expr(right, Arc::clone(&input_schema))?;
@@ -922,13 +956,21 @@
Ok(DataType::Decimal128(_p2, _s2)),
) => {
let data_type = return_type.map(to_arrow_datatype).unwrap();
let func_name = if options.is_integral_div {

(Inline review discussion on this change)

Contributor:
I just realized we may be able to reuse the previous case-match instead of adding this one.
We needed to treat decimal_div differently because we had to deal with rounding. However, for IntegralDivide we only need to round down, right?
I.e. instead of 77 digits for the scale, we only need 76 digits, which fits into Decimal256.

It will need a calculation similar to

                || (op == DataFusionOperator::Modulo
                    && max(s1, s2) as u8 + max(p1 - s1 as u8, p2 - s2 as u8)
                        > DECIMAL128_MAX_PRECISION)

That way, we would not need the decimal_div change?
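
For illustration, here is a minimal sketch of the guard being suggested, assuming arrow-rs's DECIMAL256_MAX_SCALE constant and the same p/s naming as the snippet above (hypothetical helper, not code from this PR):

use arrow_schema::DECIMAL256_MAX_SCALE;

// Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3) requires widening the
// dividend's scale to s2 + s3 + 1. For IntegralDivide s3 == 0, so the widened
// scale stays within Decimal256; only the general division path can exceed it.
fn needs_bigint_division(s2: i8, s3: i8) -> bool {
    s2 + s3 + 1 > DECIMAL256_MAX_SCALE
}

// needs_bigint_division(38, 38) == true  (general decimal_div worst case: 77 digits)
// needs_bigint_division(38, 0)  == false (IntegralDivide worst case: 39 digits)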

Member (Author):
Thanks for your suggestion; it makes sense to me. But I am getting an overflow error, and I will continue debugging this tomorrow.

[screenshot: overflow error]

Contributor:
Thanks for trying. If this is too much trouble, we can file an issue ticket and it can be worked on separately.

Member (Author):
This seems to be a bug in Arrow; I have reported an issue: apache/arrow-rs#7216

Member (Author):
> Thanks for trying. If this is too much trouble, we can file an issue ticket and it can be worked on separately.

Could you please continue reviewing this PR, and let us keep the changes to decimal_div?

Contributor:
Sure, perhaps it would be good to mention the ticket in a code comment.
Otherwise, my only remaining comments are:

Do we need to update https://github.com/apache/datafusion-comet/blob/main/spark/src/main/scala/org/apache/spark/sql/comet/DecimalPrecision.scala?

Additionally, the decimal random number test at https://github.com/apache/datafusion-comet/blob/main/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala#L1793 is another good test to extend.

Member (Author):
Sorry, I missed this before.

> Do we need to update https://github.com/apache/datafusion-comet/blob/main/spark/src/main/scala/org/apache/spark/sql/comet/DecimalPrecision.scala?

I guess it may not be necessary; IntegralDivide always returns a long type.

> Additionally, the decimal random number test at https://github.com/apache/datafusion-comet/blob/main/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala#L1793 is another good test to extend.

I added the div operator to that test case.

(End of inline review discussion; the diff resumes below.)

// Decimal256 division in Arrow may overflow, so we still need this variant of decimal_div.
// Otherwise, we may be able to reuse the previous case-match instead of here,
// see more: https://github.com/apache/datafusion-comet/pull/1428#discussion_r1972648463
"decimal_integral_div"
} else {
"decimal_div"
};
let fun_expr = create_comet_physical_fun(
"decimal_div",
func_name,
data_type.clone(),
&self.session_ctx.state(),
)?;
Ok(Arc::new(ScalarFunctionExpr::new(
"decimal_div",
func_name,
fun_expr,
vec![left, right],
data_type,
Expand Down
1 change: 1 addition & 0 deletions native/proto/src/proto/expr.proto
@@ -89,6 +89,7 @@ message Expr {
BinaryExpr array_intersect = 62;
ArrayJoin array_join = 63;
BinaryExpr arrays_overlap = 64;
MathExpr integral_divide = 65;
}
}

15 changes: 13 additions & 2 deletions native/spark-expr/benches/decimal_div.rs
@@ -19,7 +19,7 @@ use arrow::compute::cast;
use arrow_array::builder::Decimal128Builder;
use arrow_schema::DataType;
use criterion::{black_box, criterion_group, criterion_main, Criterion};
- use datafusion_comet_spark_expr::spark_decimal_div;
+ use datafusion_comet_spark_expr::{spark_decimal_div, spark_decimal_integral_div};
use datafusion_expr::ColumnarValue;
use std::sync::Arc;

@@ -40,14 +40,25 @@ fn criterion_benchmark(c: &mut Criterion) {
let c2 = cast(c2.as_ref(), &c2_type).unwrap();

let args = [ColumnarValue::Array(c1), ColumnarValue::Array(c2)];
c.bench_function("decimal_div", |b| {

let mut group = c.benchmark_group("decimal div");
group.bench_function("decimal_div", |b| {
b.iter(|| {
black_box(spark_decimal_div(
black_box(&args),
black_box(&DataType::Decimal128(10, 4)),
))
})
});

group.bench_function("decimal_integral_div", |b| {
b.iter(|| {
black_box(spark_decimal_integral_div(
black_box(&args),
black_box(&DataType::Decimal128(10, 4)),
))
})
});
}

criterion_group!(benches, criterion_benchmark);
13 changes: 10 additions & 3 deletions native/spark-expr/src/comet_scalar_funcs.rs
@@ -17,9 +17,9 @@

use crate::hash_funcs::*;
use crate::{
- spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div, spark_floor, spark_hex,
- spark_isnan, spark_make_decimal, spark_read_side_padding, spark_round, spark_unhex,
- spark_unscaled_value, SparkChrFunc,
+ spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div, spark_decimal_integral_div,
+ spark_floor, spark_hex, spark_isnan, spark_make_decimal, spark_read_side_padding, spark_round,
+ spark_unhex, spark_unscaled_value, SparkChrFunc,
};
use arrow_schema::DataType;
use datafusion_common::{DataFusionError, Result as DataFusionResult};
@@ -90,6 +90,13 @@ pub fn create_comet_physical_fun(
"decimal_div" => {
make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type)
}
"decimal_integral_div" => {
make_comet_scalar_udf!(
"decimal_integral_div",
spark_decimal_integral_div,
data_type
)
}
"murmur3_hash" => {
let func = Arc::new(spark_murmur3_hash);
make_comet_scalar_udf!("murmur3_hash", func, without data_type)
6 changes: 3 additions & 3 deletions native/spark-expr/src/lib.rs
@@ -67,9 +67,9 @@ pub use error::{SparkError, SparkResult};
pub use hash_funcs::*;
pub use json_funcs::ToJson;
pub use math_funcs::{
- create_negate_expr, spark_ceil, spark_decimal_div, spark_floor, spark_hex, spark_make_decimal,
- spark_round, spark_unhex, spark_unscaled_value, CheckOverflow, NegativeExpr,
- NormalizeNaNAndZero,
+ create_negate_expr, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor,
+ spark_hex, spark_make_decimal, spark_round, spark_unhex, spark_unscaled_value, CheckOverflow,
+ NegativeExpr, NormalizeNaNAndZero,
};
pub use string_funcs::*;

29 changes: 26 additions & 3 deletions native/spark-expr/src/math_funcs/div.rs
@@ -27,14 +27,29 @@ use datafusion_common::DataFusionError;
use num::{BigInt, Signed, ToPrimitive};
use std::sync::Arc;

pub fn spark_decimal_div(
args: &[ColumnarValue],
data_type: &DataType,
) -> Result<ColumnarValue, DataFusionError> {
spark_decimal_div_internal(args, data_type, false)
}

pub fn spark_decimal_integral_div(
args: &[ColumnarValue],
data_type: &DataType,
) -> Result<ColumnarValue, DataFusionError> {
spark_decimal_div_internal(args, data_type, true)
}

// Let Decimal(p3, s3) as return type i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3).
// Conversely, Decimal(p1, s1) = Decimal(p2, s2) * Decimal(p3, s3). This means that, in order to
// get enough scale that matches with Spark behavior, it requires to widen s1 to s2 + s3 + 1. Since
// both s2 and s3 are 38 at max., s1 is 77 at max. DataFusion division cannot handle such scale >
// Decimal256Type::MAX_SCALE. Therefore, we need to implement this decimal division using BigInt.
- pub fn spark_decimal_div(
+ fn spark_decimal_div_internal(
args: &[ColumnarValue],
data_type: &DataType,
is_integral_div: bool,
) -> Result<ColumnarValue, DataFusionError> {
let left = &args[0];
let right = &args[1];
@@ -69,7 +84,9 @@ pub fn spark_decimal_div(
let l = BigInt::from(l) * &l_mul;
let r = BigInt::from(r) * &r_mul;
let div = if r.eq(&zero) { zero.clone() } else { &l / &r };
- let res = if div.is_negative() {
+ let res = if is_integral_div {
+ div
+ } else if div.is_negative() {
div - &five
} else {
div + &five
@@ -83,7 +100,13 @@ pub fn spark_decimal_div(
let l = l * l_mul;
let r = r * r_mul;
let div = if r == 0 { 0 } else { l / r };
- let res = if div.is_negative() { div - 5 } else { div + 5 } / 10;
+ let res = if is_integral_div {
+ div
+ } else if div.is_negative() {
+ div - 5
+ } else {
+ div + 5
+ } / 10;
res.to_i128().unwrap_or(i128::MAX)
})?
};
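
For context on the change above: the trailing "/ 10" applies to every branch of the if-expression, so the quotient is computed with one extra decimal digit of scale and then either rounded half away from zero (decimal_div) or simply truncated (integral divide). A standalone sketch of that final step, using a hypothetical helper name (not code from this PR):

// `div` is the quotient carrying one extra decimal digit of scale.
fn finish_quotient(div: i128, is_integral_div: bool) -> i128 {
    if is_integral_div {
        div / 10 // truncate: drop the extra digit without rounding
    } else if div < 0 {
        (div - 5) / 10 // round half away from zero (negative side)
    } else {
        (div + 5) / 10 // round half away from zero (positive side)
    }
}

// finish_quotient(77, false) == 8 and finish_quotient(77, true) == 7;
// finish_quotient(-77, false) == -8 and finish_quotient(-77, true) == -7.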
1 change: 1 addition & 0 deletions native/spark-expr/src/math_funcs/mod.rs
@@ -27,6 +27,7 @@ mod utils;

pub use ceil::spark_ceil;
pub use div::spark_decimal_div;
pub use div::spark_decimal_integral_div;
pub use floor::spark_floor;
pub use hex::spark_hex;
pub use internal::*;
39 changes: 39 additions & 0 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -20,6 +20,7 @@
package org.apache.comet.serde

import scala.collection.JavaConverters._
import scala.math.min

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
@@ -631,6 +632,44 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
}
None

case div @ IntegralDivide(left, right, _)
if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) =>
val rightExpr = nullIfWhenPrimitive(right)

val dataType = (left.dataType, right.dataType) match {
case (l: DecimalType, r: DecimalType) =>
// copy from IntegralDivide.resultDecimalType
val intDig = l.precision - l.scale + r.scale
DecimalType(min(if (intDig == 0) 1 else intDig, DecimalType.MAX_PRECISION), 0)
case _ => left.dataType
}

val divideExpr = createMathExpression(
expr,
left,
rightExpr,
inputs,
binding,
dataType,
getFailOnError(div),
(builder, mathExpr) => builder.setIntegralDivide(mathExpr))

if (divideExpr.isDefined) {
// cast result to long
castToProto(expr, None, LongType, divideExpr.get, CometEvalMode.LEGACY)
} else {
None
}

case div @ IntegralDivide(left, _, _) =>
if (!supportedDataType(left.dataType)) {
withInfo(div, s"Unsupported datatype ${left.dataType}")
}
if (decimalBeforeSpark34(left.dataType)) {
withInfo(div, "Decimal support requires Spark 3.4 or later")
}
None

case rem @ Remainder(left, right, _)
if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) =>
val rightExpr = nullIfWhenPrimitive(right)
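
Stepping back from the IntegralDivide case above: the intermediate decimal type mirrors Spark's IntegralDivide.resultDecimalType, and the result is then cast to long. A quick Rust-form restatement of that calculation (hypothetical helper, shown only to make the arithmetic concrete; the PR's actual logic is the Scala just shown):

// Integer digits of l div r for Decimal(p1, s1) and Decimal(p2, s2): p1 - s1 + s2.
fn integral_divide_intermediate_type(p1: u8, s1: u8, s2: u8) -> (u8, i8) {
    const SPARK_MAX_PRECISION: u8 = 38; // DecimalType.MAX_PRECISION in Spark
    let int_dig = p1 - s1 + s2;
    let precision = if int_dig == 0 { 1 } else { int_dig.min(SPARK_MAX_PRECISION) };
    (precision, 0) // scale is always 0; the plan then casts the result to long
}

// e.g. Decimal(10, 2) div Decimal(5, 3): 10 - 2 + 3 = 11 integer digits,
// so the intermediate type is Decimal(11, 0) and the final result is a long.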
55 changes: 54 additions & 1 deletion spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -1811,7 +1811,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
"spark.sql.decimalOperations.allowPrecisionLoss" -> allowPrecisionLoss.toString) {
val a = makeNum(p1, s1)
val b = makeNum(p2, s2)
- var ops = Seq("+", "-", "*", "/", "%")
+ val ops = Seq("+", "-", "*", "/", "%", "div")
for (op <- ops) {
checkSparkAnswerAndOperator(s"select a, b, a $op b from $table")
checkSparkAnswerAndOperator(s"select $a, b, $a $op b from $table")
@@ -2641,4 +2641,57 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}

test("test integral divide") {
Seq(true, false).foreach { dictionaryEnabled =>
withTempDir { dir =>
val path1 = new Path(dir.toURI.toString, "test1.parquet")
val path2 = new Path(dir.toURI.toString, "test2.parquet")
makeParquetFileAllTypes(
path1,
dictionaryEnabled = dictionaryEnabled,
0,
0,
randomSize = 10000)
makeParquetFileAllTypes(
path2,
dictionaryEnabled = dictionaryEnabled,
0,
0,
randomSize = 10000)
withParquetTable(path1.toString, "tbl1") {
withParquetTable(path2.toString, "tbl2") {
// disable broadcast, as comet on spark 3.3 does not support broadcast exchange
withSQLConf(
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
checkSparkAnswerAndOperator("""
|select
| t1._2 div t2._2, div(t1._2, t2._2),
| t1._3 div t2._3, div(t1._3, t2._3),
| t1._4 div t2._4, div(t1._4, t2._4),
| t1._5 div t2._5, div(t1._5, t2._5),
| t1._9 div t2._9, div(t1._9, t2._9),
| t1._10 div t2._10, div(t1._10, t2._10),
| t1._11 div t2._11, div(t1._11, t2._11)
| from tbl1 t1 join tbl2 t2 on t1._id = t2._id
| order by t1._id""".stripMargin)

if (isSpark34Plus) {
// decimal support requires Spark 3.4 or later
checkSparkAnswerAndOperator("""
|select
| t1._12 div t2._12, div(t1._12, t2._12),
| t1._15 div t2._15, div(t1._15, t2._15),
| t1._16 div t2._16, div(t1._16, t2._16),
| t1._17 div t2._17, div(t1._17, t2._17)
| from tbl1 t1 join tbl2 t2 on t1._id = t2._id
| order by t1._id""".stripMargin)
}
}
}
}
}
}
}

}