diff --git a/datafusion/common/src/utils.rs b/datafusion/common/src/utils.rs index d609c695d68e..d8fbb4000273 100644 --- a/datafusion/common/src/utils.rs +++ b/datafusion/common/src/utils.rs @@ -464,62 +464,53 @@ pub fn coerced_type_with_base_type_only( base_type: &DataType, ) -> DataType { match data_type { - DataType::List(field) - | DataType::FixedSizeList(field, _) - | DataType::LargeList(field) => { - let field_type = match field.data_type() { - // nested type could be different list type - DataType::List(_) - | DataType::FixedSizeList(_, _) - | DataType::LargeList(_) => { - coerced_type_with_base_type_only(field.data_type(), base_type) - } - _ => base_type.to_owned(), - }; - if matches!(data_type, DataType::LargeList(_)) { - DataType::LargeList(Arc::new(Field::new( - field.name(), - field_type, - field.is_nullable(), - ))) - } else { - DataType::List(Arc::new(Field::new( - field.name(), - field_type, - field.is_nullable(), - ))) - } + DataType::List(field) | DataType::FixedSizeList(field, _) => { + let field_type = + coerced_type_with_base_type_only(field.data_type(), base_type); + + DataType::List(Arc::new(Field::new( + field.name(), + field_type, + field.is_nullable(), + ))) + } + DataType::LargeList(field) => { + let field_type = + coerced_type_with_base_type_only(field.data_type(), base_type); + + DataType::LargeList(Arc::new(Field::new( + field.name(), + field_type, + field.is_nullable(), + ))) } + _ => base_type.clone(), } } pub fn coerced_fixed_size_list_to_list(data_type: &DataType) -> DataType { match data_type { - DataType::FixedSizeList(field, _) => { - let field_type = match field.data_type() { - DataType::List(_) - | DataType::FixedSizeList(_, _) - | DataType::LargeList(_) => { - coerced_fixed_size_list_to_list(field.data_type()) - } - _ => field.data_type().to_owned(), - }; - if matches!(data_type, DataType::LargeList(_)) { - DataType::LargeList(Arc::new(Field::new( - field.name(), - field_type, - field.is_nullable(), - ))) - } else { - DataType::List(Arc::new(Field::new( - field.name(), - field_type, - field.is_nullable(), - ))) - } + DataType::List(field) | DataType::FixedSizeList(field, _) => { + let field_type = coerced_fixed_size_list_to_list(field.data_type()); + + DataType::List(Arc::new(Field::new( + field.name(), + field_type, + field.is_nullable(), + ))) } - _ => data_type.to_owned(), + DataType::LargeList(field) => { + let field_type = coerced_fixed_size_list_to_list(field.data_type()); + + DataType::LargeList(Arc::new(Field::new( + field.name(), + field_type, + field.is_nullable(), + ))) + } + + _ => data_type.clone(), } } diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index 92b309817dee..806fdaaa5246 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -21,7 +21,7 @@ use arrow::{ compute::can_cast_types, datatypes::{DataType, TimeUnit}, }; -use datafusion_common::utils::list_ndims; +use datafusion_common::utils::{coerced_fixed_size_list_to_list, list_ndims}; use datafusion_common::{ internal_datafusion_err, internal_err, plan_err, DataFusionError, Result, }; @@ -141,7 +141,8 @@ fn get_valid_types( DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _) => { - Ok(vec![vec![array_type.clone(), DataType::Int64]]) + let array_type = coerced_fixed_size_list_to_list(array_type); + Ok(vec![vec![array_type, DataType::Int64]]) } _ => Ok(vec![vec![]]), } diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 4a4b3d849f5f..454ae80758c3 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -23,7 +23,6 @@ use arrow::datatypes::{DataType, IntervalUnit}; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{RewriteRecursion, TreeNodeRewriter}; -use datafusion_common::utils::coerced_fixed_size_list_to_list; use datafusion_common::{ exec_err, internal_err, plan_datafusion_err, plan_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, @@ -590,17 +589,21 @@ fn coerce_arguments_for_fun( if expressions.is_empty() { return Ok(vec![]); } - let mut expressions: Vec = expressions.to_vec(); - // coerce the fixed size list to list for all array fucntions - if fun.name().contains("array") { + // Cast Fixedsizelist to List for array functions + if *fun == BuiltinScalarFunction::MakeArray { expressions = expressions .into_iter() .map(|expr| { let data_type = expr.get_type(schema).unwrap(); - let to_type = coerced_fixed_size_list_to_list(&data_type); - expr.cast_to(&to_type, schema) + if let DataType::FixedSizeList(field, _) = data_type { + let field = field.as_ref().clone(); + let to_type = DataType::List(Arc::new(field)); + expr.cast_to(&to_type, schema) + } else { + Ok(expr) + } }) .collect::>>()?; }