Skip to content

Commit

Permalink
fix return type conflict when calling builtin math fuctions
Browse files Browse the repository at this point in the history
  • Loading branch information
lvheyang committed Jul 14, 2021
1 parent 7d24567 commit 5b3d6df
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 10 deletions.
70 changes: 68 additions & 2 deletions datafusion/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -916,8 +916,8 @@ mod tests {
physical_plan::expressions::AvgAccumulator,
};
use arrow::array::{
Array, ArrayRef, BinaryArray, DictionaryArray, Float64Array, Int32Array,
Int64Array, LargeBinaryArray, LargeStringArray, StringArray,
Array, ArrayRef, BinaryArray, DictionaryArray, Float32Array, Float64Array,
Int32Array, Int64Array, LargeBinaryArray, LargeStringArray, StringArray,
TimestampNanosecondArray,
};
use arrow::compute::add;
Expand Down Expand Up @@ -2364,6 +2364,72 @@ mod tests {
assert_batches_sorted_eq!(expected, &results);
}

#[tokio::test]
async fn case_builtin_math_expression() {
let mut ctx = ExecutionContext::new();

let schema = Arc::new(Schema::new(vec![
Field::new("f64", DataType::Float64, false),
Field::new("f32", DataType::Float32, false),
Field::new("i32", DataType::Int32, false),
]));

// define data.
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Float64Array::from(vec![1.0])),
Arc::new(Float32Array::from(vec![1.0])),
Arc::new(Int32Array::from(vec![1])),
],
)
.unwrap();

let provider = MemTable::try_new(schema, vec![vec![batch]]).unwrap();
ctx.register_table("t", Arc::new(provider)).unwrap();

let expected = vec![
"+-----------+",
"| sqrt(f64) |",
"+-----------+",
"| 1 |",
"+-----------+",
];

let results = plan_and_collect(&mut ctx, "SELECT sqrt(f64) FROM t")
.await
.unwrap();

assert_batches_sorted_eq!(expected, &results);

let expected = vec![
"+-----------+",
"| sqrt(f32) |",
"+-----------+",
"| 1 |",
"+-----------+",
];

let results = plan_and_collect(&mut ctx, "SELECT sqrt(f32) FROM t")
.await
.unwrap();
assert_batches_sorted_eq!(expected, &results);

let expected = vec![
"+-----------+",
"| sqrt(i32) |",
"+-----------+",
"| 1 |",
"+-----------+",
];

let results = plan_and_collect(&mut ctx, "SELECT sqrt(i32) FROM t")
.await
.unwrap();

assert_batches_sorted_eq!(expected, &results);
}

#[tokio::test]
async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> {
let mut ctx = ExecutionContext::new();
Expand Down
25 changes: 18 additions & 7 deletions datafusion/src/physical_plan/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,18 @@ pub fn return_type(
| BuiltinScalarFunction::Sin
| BuiltinScalarFunction::Sqrt
| BuiltinScalarFunction::Tan
| BuiltinScalarFunction::Trunc => Ok(DataType::Float64),
| BuiltinScalarFunction::Trunc => {
if arg_types.is_empty() {
return Err(DataFusionError::Internal(format!(
"builtin scalar function {} does not support empty arguments",
fun
)));
}
match arg_types[0] {
DataType::Float32 => Ok(DataType::Float32),
_ => Ok(DataType::Float64),
}
}
}
}

Expand Down Expand Up @@ -1427,8 +1438,8 @@ mod tests {
};
use arrow::{
array::{
Array, ArrayRef, BinaryArray, BooleanArray, FixedSizeListArray, Float64Array,
Int32Array, StringArray, UInt32Array, UInt64Array,
Array, ArrayRef, BinaryArray, BooleanArray, FixedSizeListArray, Float32Array,
Float64Array, Int32Array, StringArray, UInt32Array, UInt64Array,
},
datatypes::Field,
record_batch::RecordBatch,
Expand Down Expand Up @@ -1857,10 +1868,10 @@ mod tests {
test_function!(
Exp,
&[lit(ScalarValue::Float32(Some(1.0)))],
Ok(Some((1.0_f32).exp() as f64)),
f64,
Float64,
Float64Array
Ok(Some((1.0_f32).exp())),
f32,
Float32,
Float32Array
);
test_function!(
InitCap,
Expand Down
2 changes: 1 addition & 1 deletion datafusion/src/physical_plan/math_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ macro_rules! unary_primitive_array_op {
},
ColumnarValue::Scalar(a) => match a {
ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(
ScalarValue::Float64(a.map(|x| x.$FUNC() as f64)),
ScalarValue::Float32(a.map(|x| x.$FUNC())),
)),
ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(
ScalarValue::Float64(a.map(|x| x.$FUNC())),
Expand Down

0 comments on commit 5b3d6df

Please sign in to comment.