diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 52e4a000355d..b5637f785fb2 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1340,6 +1340,7 @@ dependencies = [ "datafusion-functions-aggregate-common", "datafusion-physical-expr", "datafusion-physical-expr-common", + "half", "log", "paste", "sqlparser", diff --git a/datafusion/functions-aggregate/Cargo.toml b/datafusion/functions-aggregate/Cargo.toml index 636b2e42d236..d78f68a2604e 100644 --- a/datafusion/functions-aggregate/Cargo.toml +++ b/datafusion/functions-aggregate/Cargo.toml @@ -47,6 +47,7 @@ datafusion-expr = { workspace = true } datafusion-functions-aggregate-common = { workspace = true } datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } +half = { workspace = true } log = { workspace = true } paste = "1.0.14" sqlparser = { workspace = true } diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index 4dcd5ac0e951..961e8639604c 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -34,18 +34,19 @@ use arrow::array::{ ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array, - Decimal128Array, Decimal256Array, Float32Array, Float64Array, Int16Array, Int32Array, - Int64Array, Int8Array, IntervalDayTimeArray, IntervalMonthDayNanoArray, - IntervalYearMonthArray, LargeBinaryArray, LargeStringArray, StringArray, - StringViewArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, - Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, - TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, - UInt64Array, UInt8Array, + Decimal128Array, Decimal256Array, Float16Array, Float32Array, Float64Array, + Int16Array, Int32Array, Int64Array, Int8Array, IntervalDayTimeArray, + IntervalMonthDayNanoArray, IntervalYearMonthArray, 
LargeBinaryArray, + LargeStringArray, StringArray, StringViewArray, Time32MillisecondArray, + Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, + TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, + TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, }; use arrow::compute; use arrow::datatypes::{ - DataType, Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int16Type, - Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, + DataType, Decimal128Type, Decimal256Type, Float16Type, Float32Type, Float64Type, + Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, + UInt8Type, }; use arrow_schema::IntervalUnit; use datafusion_common::{ @@ -66,6 +67,7 @@ use datafusion_expr::GroupsAccumulator; use datafusion_expr::{ function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Signature, Volatility, }; +use half::f16; use std::ops::Deref; fn get_min_max_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> { @@ -181,6 +183,7 @@ impl AggregateUDFImpl for Max { | UInt16 | UInt32 | UInt64 + | Float16 | Float32 | Float64 | Decimal128(_, _) @@ -209,6 +212,9 @@ impl AggregateUDFImpl for Max { UInt16 => instantiate_max_accumulator!(data_type, u16, UInt16Type), UInt32 => instantiate_max_accumulator!(data_type, u32, UInt32Type), UInt64 => instantiate_max_accumulator!(data_type, u64, UInt64Type), + Float16 => { + instantiate_max_accumulator!(data_type, f16, Float16Type) + } Float32 => { instantiate_max_accumulator!(data_type, f32, Float32Type) } @@ -339,6 +345,9 @@ macro_rules! 
min_max_batch { DataType::Float32 => { typed_min_max_batch!($VALUES, Float32Array, Float32, $OP) } + DataType::Float16 => { + typed_min_max_batch!($VALUES, Float16Array, Float16, $OP) + } DataType::Int64 => typed_min_max_batch!($VALUES, Int64Array, Int64, $OP), DataType::Int32 => typed_min_max_batch!($VALUES, Int32Array, Int32, $OP), DataType::Int16 => typed_min_max_batch!($VALUES, Int16Array, Int16, $OP), @@ -623,6 +632,9 @@ macro_rules! min_max { (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => { typed_min_max_float!(lhs, rhs, Float32, $OP) } + (ScalarValue::Float16(lhs), ScalarValue::Float16(rhs)) => { + typed_min_max_float!(lhs, rhs, Float16, $OP) + } (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => { typed_min_max!(lhs, rhs, UInt64, $OP) } @@ -950,6 +962,7 @@ impl AggregateUDFImpl for Min { | UInt16 | UInt32 | UInt64 + | Float16 | Float32 | Float64 | Decimal128(_, _) @@ -978,6 +991,9 @@ impl AggregateUDFImpl for Min { UInt16 => instantiate_min_accumulator!(data_type, u16, UInt16Type), UInt32 => instantiate_min_accumulator!(data_type, u32, UInt32Type), UInt64 => instantiate_min_accumulator!(data_type, u64, UInt64Type), + Float16 => { + instantiate_min_accumulator!(data_type, f16, Float16Type) + } Float32 => { instantiate_min_accumulator!(data_type, f32, Float32Type) } diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 0cda24d6ff5e..1ffc69167e3a 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -5643,3 +5643,31 @@ query I??III?T select count(null), min(null), max(null), bit_and(NULL), bit_or(NULL), bit_xor(NULL), nth_value(NULL, 1), string_agg(NULL, ','); ---- 0 NULL NULL NULL NULL NULL NULL NULL + +# test min/max Float16 without group expression +query RRTT +WITH data AS ( + SELECT arrow_cast(1, 'Float16') AS f + UNION ALL + SELECT arrow_cast(6, 'Float16') AS f +) +SELECT MIN(f), MAX(f), arrow_typeof(MIN(f)), 
arrow_typeof(MAX(f)) FROM data; +---- +1 6 Float16 Float16 + +# test min/max Float16 with group expression +query IRRTT +WITH data AS ( + SELECT 1 as k, arrow_cast(1.8125, 'Float16') AS f + UNION ALL + SELECT 1 as k, arrow_cast(6.8007813, 'Float16') AS f + UNION ALL + SELECT 2 AS k, arrow_cast(8.5, 'Float16') AS f +) +SELECT k, MIN(f), MAX(f), arrow_typeof(MIN(f)), arrow_typeof(MAX(f)) +FROM data +GROUP BY k +ORDER BY k; +---- +1 1.8125 6.8007813 Float16 Float16 +2 8.5 8.5 Float16 Float16 diff --git a/datafusion/sqllogictest/test_files/arrow_typeof.slt b/datafusion/sqllogictest/test_files/arrow_typeof.slt index 448706744305..d1f49838f932 100644 --- a/datafusion/sqllogictest/test_files/arrow_typeof.slt +++ b/datafusion/sqllogictest/test_files/arrow_typeof.slt @@ -102,7 +102,7 @@ query error Error unrecognized word: unknown SELECT arrow_cast('1', 'unknown') # Round Trip tests: -query TTTTTTTTTTTTTTTTTTTTTTT +query TTTTTTTTTTTTTTTTTTTTTTTT SELECT arrow_typeof(arrow_cast(1, 'Int8')) as col_i8, arrow_typeof(arrow_cast(1, 'Int16')) as col_i16, @@ -112,8 +112,7 @@ SELECT arrow_typeof(arrow_cast(1, 'UInt16')) as col_u16, arrow_typeof(arrow_cast(1, 'UInt32')) as col_u32, arrow_typeof(arrow_cast(1, 'UInt64')) as col_u64, - -- can't seem to cast to Float16 for some reason - -- arrow_typeof(arrow_cast(1, 'Float16')) as col_f16, + arrow_typeof(arrow_cast(1, 'Float16')) as col_f16, arrow_typeof(arrow_cast(1, 'Float32')) as col_f32, arrow_typeof(arrow_cast(1, 'Float64')) as col_f64, arrow_typeof(arrow_cast('foo', 'Utf8')) as col_utf8, @@ -130,7 +129,7 @@ SELECT arrow_typeof(arrow_cast(to_timestamp('2020-01-02 01:01:11.1234567890Z'), 'Timestamp(Nanosecond, Some("+08:00"))')) as col_tstz_ns, arrow_typeof(arrow_cast('foo', 'Dictionary(Int32, Utf8)')) as col_dict ---- -Int8 Int16 Int32 Int64 UInt8 UInt16 UInt32 UInt64 Float32 Float64 Utf8 LargeUtf8 Binary LargeBinary Timestamp(Second, None) Timestamp(Millisecond, None) Timestamp(Microsecond, None) Timestamp(Nanosecond, None) 
Timestamp(Second, Some("+08:00")) Timestamp(Millisecond, Some("+08:00")) Timestamp(Microsecond, Some("+08:00")) Timestamp(Nanosecond, Some("+08:00")) Dictionary(Int32, Utf8) +Int8 Int16 Int32 Int64 UInt8 UInt16 UInt32 UInt64 Float16 Float32 Float64 Utf8 LargeUtf8 Binary LargeBinary Timestamp(Second, None) Timestamp(Millisecond, None) Timestamp(Microsecond, None) Timestamp(Nanosecond, None) Timestamp(Second, Some("+08:00")) Timestamp(Millisecond, Some("+08:00")) Timestamp(Microsecond, Some("+08:00")) Timestamp(Nanosecond, Some("+08:00")) Dictionary(Int32, Utf8) @@ -147,15 +146,14 @@ create table foo as select arrow_cast(1, 'UInt16') as col_u16, arrow_cast(1, 'UInt32') as col_u32, arrow_cast(1, 'UInt64') as col_u64, - -- can't seem to cast to Float16 for some reason - -- arrow_cast(1.0, 'Float16') as col_f16, + arrow_cast(1.0, 'Float16') as col_f16, arrow_cast(1.0, 'Float32') as col_f32, arrow_cast(1.0, 'Float64') as col_f64 ; ## Ensure each column in the table has the expected type -query TTTTTTTTTT +query TTTTTTTTTTT SELECT arrow_typeof(col_i8), arrow_typeof(col_i16), @@ -165,12 +163,12 @@ SELECT arrow_typeof(col_u16), arrow_typeof(col_u32), arrow_typeof(col_u64), - -- arrow_typeof(col_f16), + arrow_typeof(col_f16), arrow_typeof(col_f32), arrow_typeof(col_f64) FROM foo; ---- -Int8 Int16 Int32 Int64 UInt8 UInt16 UInt32 UInt64 Float32 Float64 +Int8 Int16 Int32 Int64 UInt8 UInt16 UInt32 UInt64 Float16 Float32 Float64 statement ok