Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Overflow in negate operator #11084

Merged
merged 6 commits into from
Jun 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 160 additions & 17 deletions datafusion/common/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use std::iter::repeat;
use std::str::FromStr;
use std::sync::Arc;

use crate::arrow_datafusion_err;
use crate::cast::{
as_decimal128_array, as_decimal256_array, as_dictionary_array,
as_fixed_size_binary_array, as_fixed_size_list_array,
Expand Down Expand Up @@ -1168,6 +1169,13 @@ impl ScalarValue {

/// Calculate arithmetic negation for a scalar value
pub fn arithmetic_negate(&self) -> Result<Self> {
fn neg_checked_with_ctx<T: ArrowNativeTypeOp>(
v: T,
ctx: impl Fn() -> String,
) -> Result<T> {
v.neg_checked()
.map_err(|e| arrow_datafusion_err!(e).context(ctx()))
}
match self {
ScalarValue::Int8(None)
| ScalarValue::Int16(None)
Expand All @@ -1177,40 +1185,91 @@ impl ScalarValue {
| ScalarValue::Float64(None) => Ok(self.clone()),
ScalarValue::Float64(Some(v)) => Ok(ScalarValue::Float64(Some(-v))),
ScalarValue::Float32(Some(v)) => Ok(ScalarValue::Float32(Some(-v))),
ScalarValue::Int8(Some(v)) => Ok(ScalarValue::Int8(Some(-v))),
ScalarValue::Int16(Some(v)) => Ok(ScalarValue::Int16(Some(-v))),
ScalarValue::Int32(Some(v)) => Ok(ScalarValue::Int32(Some(-v))),
ScalarValue::Int64(Some(v)) => Ok(ScalarValue::Int64(Some(-v))),
ScalarValue::IntervalYearMonth(Some(v)) => {
Ok(ScalarValue::IntervalYearMonth(Some(-v)))
}
ScalarValue::Int8(Some(v)) => Ok(ScalarValue::Int8(Some(v.neg_checked()?))),
ScalarValue::Int16(Some(v)) => Ok(ScalarValue::Int16(Some(v.neg_checked()?))),
ScalarValue::Int32(Some(v)) => Ok(ScalarValue::Int32(Some(v.neg_checked()?))),
ScalarValue::Int64(Some(v)) => Ok(ScalarValue::Int64(Some(v.neg_checked()?))),
ScalarValue::IntervalYearMonth(Some(v)) => Ok(
ScalarValue::IntervalYearMonth(Some(neg_checked_with_ctx(*v, || {
format!("In negation of IntervalYearMonth({v})")
})?)),
),
ScalarValue::IntervalDayTime(Some(v)) => {
let (days, ms) = IntervalDayTimeType::to_parts(*v);
let val = IntervalDayTimeType::make_value(-days, -ms);
let val = IntervalDayTimeType::make_value(
neg_checked_with_ctx(days, || {
format!("In negation of days {days} in IntervalDayTime")
})?,
neg_checked_with_ctx(ms, || {
format!("In negation of milliseconds {ms} in IntervalDayTime")
})?,
);
Ok(ScalarValue::IntervalDayTime(Some(val)))
}
ScalarValue::IntervalMonthDayNano(Some(v)) => {
let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(*v);
let val = IntervalMonthDayNanoType::make_value(-months, -days, -nanos);
let val = IntervalMonthDayNanoType::make_value(
neg_checked_with_ctx(months, || {
format!("In negation of months {months} of IntervalMonthDayNano")
})?,
neg_checked_with_ctx(days, || {
format!("In negation of days {days} of IntervalMonthDayNano")
})?,
neg_checked_with_ctx(nanos, || {
format!("In negation of nanos {nanos} of IntervalMonthDayNano")
})?,
);
Ok(ScalarValue::IntervalMonthDayNano(Some(val)))
}
ScalarValue::Decimal128(Some(v), precision, scale) => {
Ok(ScalarValue::Decimal128(Some(-v), *precision, *scale))
Ok(ScalarValue::Decimal128(
Some(neg_checked_with_ctx(*v, || {
format!("In negation of Decimal128({v}, {precision}, {scale})")
})?),
*precision,
*scale,
))
}
ScalarValue::Decimal256(Some(v), precision, scale) => {
Ok(ScalarValue::Decimal256(
Some(neg_checked_with_ctx(*v, || {
format!("In negation of Decimal256({v}, {precision}, {scale})")
})?),
*precision,
*scale,
))
}
ScalarValue::Decimal256(Some(v), precision, scale) => Ok(
ScalarValue::Decimal256(Some(v.neg_wrapping()), *precision, *scale),
),
ScalarValue::TimestampSecond(Some(v), tz) => {
Ok(ScalarValue::TimestampSecond(Some(-v), tz.clone()))
Ok(ScalarValue::TimestampSecond(
Some(neg_checked_with_ctx(*v, || {
format!("In negation of TimestampSecond({v})")
})?),
tz.clone(),
))
}
ScalarValue::TimestampNanosecond(Some(v), tz) => {
Ok(ScalarValue::TimestampNanosecond(Some(-v), tz.clone()))
Ok(ScalarValue::TimestampNanosecond(
Some(neg_checked_with_ctx(*v, || {
format!("In negation of TimestampNanoSecond({v})")
})?),
tz.clone(),
))
}
ScalarValue::TimestampMicrosecond(Some(v), tz) => {
Ok(ScalarValue::TimestampMicrosecond(Some(-v), tz.clone()))
Ok(ScalarValue::TimestampMicrosecond(
Some(neg_checked_with_ctx(*v, || {
format!("In negation of TimestampMicroSecond({v})")
})?),
tz.clone(),
))
}
ScalarValue::TimestampMillisecond(Some(v), tz) => {
Ok(ScalarValue::TimestampMillisecond(Some(-v), tz.clone()))
Ok(ScalarValue::TimestampMillisecond(
Some(neg_checked_with_ctx(*v, || {
format!("In negation of TimestampMilliSecond({v})")
})?),
tz.clone(),
))
}
value => _internal_err!(
"Can not run arithmetic negative on scalar value {value:?}"
Expand Down Expand Up @@ -3501,6 +3560,7 @@ mod tests {
use crate::assert_batches_eq;
use arrow::buffer::OffsetBuffer;
use arrow::compute::{is_null, kernels};
use arrow::error::ArrowError;
use arrow::util::pretty::pretty_format_columns;
use arrow_buffer::Buffer;
use arrow_schema::Fields;
Expand Down Expand Up @@ -5494,6 +5554,89 @@ mod tests {
Ok(())
}

#[test]
#[allow(arithmetic_overflow)] // we want to test them
fn test_scalar_negative_overflows() -> Result<()> {
macro_rules! test_overflow_on_value {
($($val:expr),* $(,)?) => {$(
{
let value: ScalarValue = $val;
let err = value.arithmetic_negate().expect_err("Should receive overflow error on negating {value:?}");
let root_err = err.find_root();
match root_err{
DataFusionError::ArrowError(
ArrowError::ComputeError(_),
_,
) => {}
_ => return Err(err),
};
}
)*};
}
test_overflow_on_value!(
// the integers
i8::MIN.into(),
i16::MIN.into(),
i32::MIN.into(),
i64::MIN.into(),
// for decimals, only value needs to be tested
ScalarValue::try_new_decimal128(i128::MIN, 10, 5)?,
ScalarValue::Decimal256(Some(i256::MIN), 20, 5),
// interval, check all possible values
ScalarValue::IntervalYearMonth(Some(i32::MIN)),
ScalarValue::new_interval_dt(i32::MIN, 999),
ScalarValue::new_interval_dt(1, i32::MIN),
ScalarValue::new_interval_mdn(i32::MIN, 15, 123_456),
ScalarValue::new_interval_mdn(12, i32::MIN, 123_456),
ScalarValue::new_interval_mdn(12, 15, i64::MIN),
// tz doesn't matter when negating
ScalarValue::TimestampSecond(Some(i64::MIN), None),
ScalarValue::TimestampMillisecond(Some(i64::MIN), None),
ScalarValue::TimestampMicrosecond(Some(i64::MIN), None),
ScalarValue::TimestampNanosecond(Some(i64::MIN), None),
);

let float_cases = [
(
ScalarValue::Float16(Some(f16::MIN)),
ScalarValue::Float16(Some(f16::MAX)),
),
(
ScalarValue::Float16(Some(f16::MAX)),
ScalarValue::Float16(Some(f16::MIN)),
),
(f32::MIN.into(), f32::MAX.into()),
(f32::MAX.into(), f32::MIN.into()),
(f64::MIN.into(), f64::MAX.into()),
(f64::MAX.into(), f64::MIN.into()),
];
// skip float 16 because they aren't supported
for (test, expected) in float_cases.into_iter().skip(2) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we just shouldn't list float 16 at all 🤔 I think it might be more confusing to see them listed then skipped than simply not listed (or listed in a separate test that simply fails)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I did this when filing #11083. Maybe it's cleaner to add a test that should fail on Float16 and when Float16 is supported, that test fails and the person making those changes knows to change the test.

assert_eq!(test.arithmetic_negate()?, expected);
}
Ok(())
}

#[test]
#[should_panic(expected = "Can not run arithmetic negative on scalar value Float16")]
fn f16_test_overflow() {
// TODO: if negate supports f16, add these cases to `test_scalar_negative_overflows` test case
let cases = [
(
ScalarValue::Float16(Some(f16::MIN)),
ScalarValue::Float16(Some(f16::MAX)),
),
(
ScalarValue::Float16(Some(f16::MAX)),
ScalarValue::Float16(Some(f16::MIN)),
),
];

for (test, expected) in cases {
assert_eq!(test.arithmetic_negate().unwrap(), expected);
}
}

macro_rules! expect_operation_error {
($TEST_NAME:ident, $FUNCTION:ident, $EXPECTED_ERROR:expr) => {
#[test]
Expand Down