diff --git a/arrow/src/array/array_union.rs b/arrow/src/array/array_union.rs index 2a4a42d95b78..63cf5c2a09f8 100644 --- a/arrow/src/array/array_union.rs +++ b/arrow/src/array/array_union.rs @@ -61,7 +61,6 @@ use std::any::Any; /// type_id_buffer, /// Some(value_offsets_buffer), /// children, -/// None, /// ).unwrap(); /// /// let value = array.value(0).as_any().downcast_ref::().unwrap().value(0); @@ -94,7 +93,6 @@ use std::any::Any; /// type_id_buffer, /// None, /// children, -/// None, /// ).unwrap(); /// /// let value = array.value(0).as_any().downcast_ref::().unwrap().value(0); @@ -140,7 +138,6 @@ impl UnionArray { type_ids: Buffer, value_offsets: Option, child_arrays: Vec<(Field, ArrayRef)>, - bitmap_data: Option, ) -> Self { let (field_types, field_values): (Vec<_>, Vec<_>) = child_arrays.into_iter().unzip(); @@ -152,13 +149,11 @@ impl UnionArray { UnionMode::Sparse }; - let mut builder = ArrayData::builder(DataType::Union(field_types, mode)) + let builder = ArrayData::builder(DataType::Union(field_types, mode)) .add_buffer(type_ids) .child_data(field_values.into_iter().map(|a| a.data().clone()).collect()) .len(len); - if let Some(bitmap) = bitmap_data { - builder = builder.null_bit_buffer(bitmap) - } + let data = match value_offsets { Some(b) => builder.add_buffer(b).build_unchecked(), None => builder.build_unchecked(), @@ -171,7 +166,6 @@ impl UnionArray { type_ids: Buffer, value_offsets: Option, child_arrays: Vec<(Field, ArrayRef)>, - bitmap: Option, ) -> Result { if let Some(b) = &value_offsets { if ((type_ids.len()) * 4) != b.len() { @@ -216,7 +210,7 @@ impl UnionArray { // Unsafe Justification: arguments were validated above (and // re-revalidated as part of data().validate() below) let new_self = - unsafe { Self::new_unchecked(type_ids, value_offsets, child_arrays, bitmap) }; + unsafe { Self::new_unchecked(type_ids, value_offsets, child_arrays) }; new_self.data().validate()?; Ok(new_self) @@ -512,7 +506,7 @@ mod tests { builder.append::("a", 1).unwrap(); builder.append::("c", 3).unwrap(); builder.append::("a", 10).unwrap(); - builder.append_null().unwrap(); + builder.append_null::("a").unwrap(); builder.append::("a", 6).unwrap(); let union = builder.build().unwrap(); @@ -522,29 +516,29 @@ mod tests { match i { 0 => { let slot = slot.as_any().downcast_ref::().unwrap(); - assert!(!union.is_null(i)); + assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0); assert_eq!(1_i32, value); } 1 => { let slot = slot.as_any().downcast_ref::().unwrap(); - assert!(!union.is_null(i)); + assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0); assert_eq!(3_i64, value); } 2 => { let slot = slot.as_any().downcast_ref::().unwrap(); - assert!(!union.is_null(i)); + assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0); assert_eq!(10_i32, value); } - 3 => assert!(union.is_null(i)), + 3 => assert!(slot.is_null(0)), 4 => { let slot = slot.as_any().downcast_ref::().unwrap(); - assert!(!union.is_null(i)); + assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0); assert_eq!(6_i32, value); @@ -560,7 +554,7 @@ mod tests { builder.append::("a", 1).unwrap(); builder.append::("c", 3).unwrap(); builder.append::("a", 10).unwrap(); - builder.append_null().unwrap(); + builder.append_null::("a").unwrap(); builder.append::("a", 6).unwrap(); let union = builder.build().unwrap(); @@ -573,15 +567,15 @@ mod tests { match i { 0 => { let slot = slot.as_any().downcast_ref::().unwrap(); - assert!(!union.is_null(i)); + assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0); assert_eq!(10_i32, value); } - 1 => assert!(new_union.is_null(i)), + 1 => assert!(slot.is_null(0)), 2 => { let slot = slot.as_any().downcast_ref::().unwrap(); - assert!(!union.is_null(i)); + assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0); assert_eq!(6_i32, value); @@ -614,13 +608,9 @@ mod tests { Arc::new(float_array), ), ]; - let array = UnionArray::try_new( - type_id_buffer, - Some(value_offsets_buffer), - children, - None, - ) - .unwrap(); + let array = + UnionArray::try_new(type_id_buffer, Some(value_offsets_buffer), children) + .unwrap(); // Check type ids assert_eq!(Buffer::from_slice_ref(&type_ids), array.data().buffers()[0]); @@ -800,7 +790,7 @@ mod tests { fn test_sparse_mixed_with_nulls() { let mut builder = UnionBuilder::new_sparse(5); builder.append::("a", 1).unwrap(); - builder.append_null().unwrap(); + builder.append_null::("a").unwrap(); builder.append::("c", 3.0).unwrap(); builder.append::("a", 4).unwrap(); let union = builder.build().unwrap(); @@ -824,22 +814,22 @@ mod tests { match i { 0 => { let slot = slot.as_any().downcast_ref::().unwrap(); - assert!(!union.is_null(i)); + assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0); assert_eq!(1_i32, value); } - 1 => assert!(union.is_null(i)), + 1 => assert!(slot.is_null(0)), 2 => { let slot = slot.as_any().downcast_ref::().unwrap(); - assert!(!union.is_null(i)); + assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0); assert_eq!(value, 3_f64); } 3 => { let slot = slot.as_any().downcast_ref::().unwrap(); - assert!(!union.is_null(i)); + assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0); assert_eq!(4_i32, value); @@ -853,9 +843,9 @@ mod tests { fn test_sparse_mixed_with_nulls_and_offset() { let mut builder = UnionBuilder::new_sparse(5); builder.append::("a", 1).unwrap(); - builder.append_null().unwrap(); + builder.append_null::("a").unwrap(); builder.append::("c", 3.0).unwrap(); - builder.append_null().unwrap(); + builder.append_null::("c").unwrap(); builder.append::("a", 4).unwrap(); let union = builder.build().unwrap(); @@ -866,18 +856,18 @@ mod tests { for i in 0..new_union.len() { let slot = new_union.value(i); match i { - 0 => assert!(new_union.is_null(i)), + 0 => assert!(slot.is_null(0)), 1 => { let slot = slot.as_any().downcast_ref::().unwrap(); - assert!(!new_union.is_null(i)); + assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0); assert_eq!(value, 3_f64); } - 2 => assert!(new_union.is_null(i)), + 2 => assert!(slot.is_null(0)), 3 => { let slot = slot.as_any().downcast_ref::().unwrap(); - assert!(!new_union.is_null(i)); + assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0); assert_eq!(4_i32, value); @@ -886,4 +876,12 @@ mod tests { } } } + + #[test] + fn test_type_check() { + let mut builder = UnionBuilder::new_sparse(2); + builder.append::("a", 1.0).unwrap(); + let err = builder.append::("a", 1).unwrap_err().to_string(); + assert!(err.contains("Attempt to write col \"a\" with type Int32 doesn't match existing type Float32"), "{}", err); + } } diff --git a/arrow/src/array/builder.rs b/arrow/src/array/builder.rs index e98627baef2b..1c64b5062b94 100644 --- a/arrow/src/array/builder.rs +++ b/arrow/src/array/builder.rs @@ -1894,23 +1894,19 @@ struct FieldData { values_buffer: Option, /// The number of array slots represented by the buffer slots: usize, - /// A builder for the bitmap if required (for Sparse Unions) - bitmap_builder: Option, + /// A builder for the null bitmap + bitmap_builder: BooleanBufferBuilder, } impl FieldData { /// Creates a new `FieldData`. - fn new( - type_id: i8, - data_type: DataType, - bitmap_builder: Option, - ) -> Self { + fn new(type_id: i8, data_type: DataType) -> Self { Self { type_id, data_type, values_buffer: Some(MutableBuffer::new(1)), slots: 0, - bitmap_builder, + bitmap_builder: BooleanBufferBuilder::new(1), } } @@ -1931,28 +1927,26 @@ impl FieldData { self.values_buffer = Some(mutable_buffer); self.slots += 1; - if let Some(b) = &mut self.bitmap_builder { - b.append(true) - }; + self.bitmap_builder.append(true); Ok(()) } /// Appends a null to this `FieldData`. #[allow(clippy::unnecessary_wraps)] fn append_null(&mut self) -> Result<()> { - if let Some(b) = &mut self.bitmap_builder { - let values_buffer = self - .values_buffer - .take() - .expect("Values buffer was never created"); - let mut builder: BufferBuilder = - mutable_buffer_to_builder(values_buffer, self.slots); - builder.advance(1); - let mutable_buffer = builder_to_mutable_buffer(builder); - self.values_buffer = Some(mutable_buffer); - self.slots += 1; - b.append(false); - }; + let values_buffer = self + .values_buffer + .take() + .expect("Values buffer was never created"); + + let mut builder: BufferBuilder = + mutable_buffer_to_builder(values_buffer, self.slots); + + builder.advance(1); + let mutable_buffer = builder_to_mutable_buffer(builder); + self.values_buffer = Some(mutable_buffer); + self.slots += 1; + self.bitmap_builder.append(false); Ok(()) } @@ -2047,8 +2041,6 @@ pub struct UnionBuilder { type_id_builder: Int8BufferBuilder, /// Builder to keep track of offsets (`None` for sparse unions) value_offset_builder: Option, - /// Optional builder for null slots - bitmap_builder: Option, } impl UnionBuilder { @@ -2059,7 +2051,6 @@ impl UnionBuilder { fields: HashMap::default(), type_id_builder: Int8BufferBuilder::new(capacity), value_offset_builder: Some(Int32BufferBuilder::new(capacity)), - bitmap_builder: None, } } @@ -2070,39 +2061,13 @@ impl UnionBuilder { fields: HashMap::default(), type_id_builder: Int8BufferBuilder::new(capacity), value_offset_builder: None, - bitmap_builder: None, } } /// Appends a null to this builder. #[inline] - pub fn append_null(&mut self) -> Result<()> { - if self.bitmap_builder.is_none() { - let mut builder = BooleanBufferBuilder::new(self.len + 1); - for _ in 0..self.len { - builder.append(true); - } - self.bitmap_builder = Some(builder) - } - self.bitmap_builder - .as_mut() - .expect("Cannot be None") - .append(false); - - self.type_id_builder.append(i8::default()); - - match &mut self.value_offset_builder { - // Handle dense union - Some(value_offset_builder) => value_offset_builder.append(i32::default()), - // Handle sparse union - None => { - for (_, fd) in self.fields.iter_mut() { - fd.append_null_dynamic()?; - } - } - }; - self.len += 1; - Ok(()) + pub fn append_null(&mut self, type_name: &str) -> Result<()> { + self.append_option::(type_name, None) } /// Appends a value to this builder. @@ -2111,22 +2076,28 @@ impl UnionBuilder { &mut self, type_name: &str, v: T::Native, + ) -> Result<()> { + self.append_option::(type_name, Some(v)) + } + + fn append_option( + &mut self, + type_name: &str, + v: Option, ) -> Result<()> { let type_name = type_name.to_string(); let mut field_data = match self.fields.remove(&type_name) { - Some(data) => data, - None => match self.value_offset_builder { - Some(_) => { - // For Dense Union, we don't build bitmap in individual field - FieldData::new(self.fields.len() as i8, T::DATA_TYPE, None) + Some(data) => { + if data.data_type != T::DATA_TYPE { + return Err(ArrowError::InvalidArgumentError(format!("Attempt to write col \"{}\" with type {} doesn't match existing type {}", type_name, T::DATA_TYPE, data.data_type))); } + data + } + None => match self.value_offset_builder { + Some(_) => FieldData::new(self.fields.len() as i8, T::DATA_TYPE), None => { - let mut fd = FieldData::new( - self.fields.len() as i8, - T::DATA_TYPE, - Some(BooleanBufferBuilder::new(1)), - ); + let mut fd = FieldData::new(self.fields.len() as i8, T::DATA_TYPE); for _ in 0..self.len { fd.append_null::()?; } @@ -2143,20 +2114,19 @@ impl UnionBuilder { } // Sparse Union None => { - for (name, fd) in self.fields.iter_mut() { - if name != &type_name { - fd.append_null_dynamic()?; - } + for (_, fd) in self.fields.iter_mut() { + // Append to all bar the FieldData currently being appended to + fd.append_null_dynamic()?; } } } - field_data.append_to_values_buffer::(v)?; - self.fields.insert(type_name, field_data); - // Update the bitmap builder if it exists - if let Some(b) = &mut self.bitmap_builder { - b.append(true); + match v { + Some(v) => field_data.append_to_values_buffer::(v)?, + None => field_data.append_null::()?, } + + self.fields.insert(type_name, field_data); self.len += 1; Ok(()) } @@ -2173,7 +2143,7 @@ impl UnionBuilder { data_type, values_buffer, slots, - bitmap_builder, + mut bitmap_builder, }, ) in self.fields.into_iter() { @@ -2182,16 +2152,10 @@ impl UnionBuilder { .into(); let arr_data_builder = ArrayDataBuilder::new(data_type.clone()) .add_buffer(buffer) - .len(slots); - // .build(); - let arr_data_ref = unsafe { - match bitmap_builder { - Some(mut bb) => arr_data_builder - .null_bit_buffer(bb.finish()) - .build_unchecked(), - None => arr_data_builder.build_unchecked(), - } - }; + .len(slots) + .null_bit_buffer(bitmap_builder.finish()); + + let arr_data_ref = unsafe { arr_data_builder.build_unchecked() }; let array_ref = make_array(arr_data_ref); children.push((type_id, (Field::new(&name, data_type, false), array_ref))) } @@ -2201,9 +2165,8 @@ impl UnionBuilder { .expect("This will never be None as type ids are always i8 values.") }); let children: Vec<_> = children.into_iter().map(|(_, b)| b).collect(); - let bitmap = self.bitmap_builder.map(|mut b| b.finish()); - UnionArray::try_new(type_id_buffer, value_offsets_buffer, children, bitmap) + UnionArray::try_new(type_id_buffer, value_offsets_buffer, children) } } diff --git a/arrow/src/array/data.rs b/arrow/src/array/data.rs index 2afc00b58958..c0ecef75d1c0 100644 --- a/arrow/src/array/data.rs +++ b/arrow/src/array/data.rs @@ -621,6 +621,13 @@ impl ArrayData { // Check that the data layout conforms to the spec let layout = layout(&self.data_type); + if !layout.can_contain_null_mask && self.null_bitmap.is_some() { + return Err(ArrowError::InvalidArgumentError(format!( + "Arrays of type {:?} cannot contain a null bitmask", + self.data_type, + ))); + } + if self.buffers.len() != layout.buffers.len() { return Err(ArrowError::InvalidArgumentError(format!( "Expected {} buffers in array of type {:?}, got {}", @@ -1224,9 +1231,13 @@ fn layout(data_type: &DataType) -> DataTypeLayout { // https://github.com/apache/arrow/blob/661c7d749150905a63dd3b52e0a04dac39030d95/cpp/src/arrow/type.h (and .cc) use std::mem::size_of; match data_type { - DataType::Null => DataTypeLayout::new_empty(), + DataType::Null => DataTypeLayout { + buffers: vec![], + can_contain_null_mask: false, + }, DataType::Boolean => DataTypeLayout { buffers: vec![BufferSpec::BitMap], + can_contain_null_mask: true, }, DataType::Int8 => DataTypeLayout::new_fixed_width(size_of::()), DataType::Int16 => DataTypeLayout::new_fixed_width(size_of::()), @@ -1287,6 +1298,7 @@ fn layout(data_type: &DataType) -> DataTypeLayout { ] } }, + can_contain_null_mask: false, } } DataType::Dictionary(key_type, _value_type) => layout(key_type), @@ -1308,6 +1320,9 @@ fn layout(data_type: &DataType) -> DataTypeLayout { struct DataTypeLayout { /// A vector of buffer layout specifications, one for each expected buffer pub buffers: Vec, + + /// Can contain a null bitmask + pub can_contain_null_mask: bool, } impl DataTypeLayout { @@ -1315,6 +1330,7 @@ impl DataTypeLayout { pub fn new_fixed_width(byte_width: usize) -> Self { Self { buffers: vec![BufferSpec::FixedWidth { byte_width }], + can_contain_null_mask: true, } } @@ -1322,7 +1338,10 @@ impl DataTypeLayout { /// (e.g. FixedSizeList). Note such arrays may still have a Null /// Bitmap pub fn new_empty() -> Self { - Self { buffers: vec![] } + Self { + buffers: vec![], + can_contain_null_mask: true, + } } /// Describes a basic numeric array where each element has a fixed @@ -1338,6 +1357,7 @@ impl DataTypeLayout { // values BufferSpec::VariableWidth, ], + can_contain_null_mask: true, } } } diff --git a/arrow/src/array/equal/boolean.rs b/arrow/src/array/equal/boolean.rs index 35c9786e49f9..de34d7fab189 100644 --- a/arrow/src/array/equal/boolean.rs +++ b/arrow/src/array/equal/boolean.rs @@ -16,7 +16,6 @@ // under the License. use crate::array::{data::count_nulls, ArrayData}; -use crate::buffer::Buffer; use crate::util::bit_util::get_bit; use super::utils::{equal_bits, equal_len}; @@ -24,8 +23,6 @@ use super::utils::{equal_bits, equal_len}; pub(super) fn boolean_equal( lhs: &ArrayData, rhs: &ArrayData, - lhs_nulls: Option<&Buffer>, - rhs_nulls: Option<&Buffer>, mut lhs_start: usize, mut rhs_start: usize, mut len: usize, @@ -33,8 +30,8 @@ pub(super) fn boolean_equal( let lhs_values = lhs.buffers()[0].as_slice(); let rhs_values = rhs.buffers()[0].as_slice(); - let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len); - let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len); + let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); + let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len); if lhs_null_count == 0 && rhs_null_count == 0 { // Optimize performance for starting offset at u8 boundary. @@ -73,8 +70,8 @@ pub(super) fn boolean_equal( ) } else { // get a ref of the null buffer bytes, to use in testing for nullness - let lhs_null_bytes = lhs_nulls.as_ref().unwrap().as_slice(); - let rhs_null_bytes = rhs_nulls.as_ref().unwrap().as_slice(); + let lhs_null_bytes = lhs.null_buffer().as_ref().unwrap().as_slice(); + let rhs_null_bytes = rhs.null_buffer().as_ref().unwrap().as_slice(); let lhs_start = lhs.offset() + lhs_start; let rhs_start = rhs.offset() + rhs_start; diff --git a/arrow/src/array/equal/decimal.rs b/arrow/src/array/equal/decimal.rs index 1ee6ec9b5436..e9879f3f281e 100644 --- a/arrow/src/array/equal/decimal.rs +++ b/arrow/src/array/equal/decimal.rs @@ -16,7 +16,6 @@ // under the License. use crate::array::{data::count_nulls, ArrayData}; -use crate::buffer::Buffer; use crate::datatypes::DataType; use crate::util::bit_util::get_bit; @@ -25,8 +24,6 @@ use super::utils::equal_len; pub(super) fn decimal_equal( lhs: &ArrayData, rhs: &ArrayData, - lhs_nulls: Option<&Buffer>, - rhs_nulls: Option<&Buffer>, lhs_start: usize, rhs_start: usize, len: usize, @@ -39,8 +36,8 @@ pub(super) fn decimal_equal( let lhs_values = &lhs.buffers()[0].as_slice()[lhs.offset() * size..]; let rhs_values = &rhs.buffers()[0].as_slice()[rhs.offset() * size..]; - let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len); - let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len); + let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); + let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len); if lhs_null_count == 0 && rhs_null_count == 0 { equal_len( @@ -52,8 +49,8 @@ pub(super) fn decimal_equal( ) } else { // get a ref of the null buffer bytes, to use in testing for nullness - let lhs_null_bytes = lhs_nulls.as_ref().unwrap().as_slice(); - let rhs_null_bytes = rhs_nulls.as_ref().unwrap().as_slice(); + let lhs_null_bytes = lhs.null_buffer().as_ref().unwrap().as_slice(); + let rhs_null_bytes = rhs.null_buffer().as_ref().unwrap().as_slice(); // with nulls, we need to compare item by item whenever it is not null (0..len).all(|i| { let lhs_pos = lhs_start + i; diff --git a/arrow/src/array/equal/dictionary.rs b/arrow/src/array/equal/dictionary.rs index 22add2494d2b..4c9bcf798760 100644 --- a/arrow/src/array/equal/dictionary.rs +++ b/arrow/src/array/equal/dictionary.rs @@ -16,7 +16,6 @@ // under the License. use crate::array::{data::count_nulls, ArrayData}; -use crate::buffer::Buffer; use crate::datatypes::ArrowNativeType; use crate::util::bit_util::get_bit; @@ -25,8 +24,6 @@ use super::equal_range; pub(super) fn dictionary_equal( lhs: &ArrayData, rhs: &ArrayData, - lhs_nulls: Option<&Buffer>, - rhs_nulls: Option<&Buffer>, lhs_start: usize, rhs_start: usize, len: usize, @@ -37,8 +34,8 @@ pub(super) fn dictionary_equal( let lhs_values = &lhs.child_data()[0]; let rhs_values = &rhs.child_data()[0]; - let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len); - let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len); + let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); + let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len); if lhs_null_count == 0 && rhs_null_count == 0 { (0..len).all(|i| { @@ -48,8 +45,6 @@ pub(super) fn dictionary_equal( equal_range( lhs_values, rhs_values, - lhs_values.null_buffer(), - rhs_values.null_buffer(), lhs_keys[lhs_pos].to_usize().unwrap(), rhs_keys[rhs_pos].to_usize().unwrap(), 1, @@ -57,8 +52,8 @@ pub(super) fn dictionary_equal( }) } else { // get a ref of the null buffer bytes, to use in testing for nullness - let lhs_null_bytes = lhs_nulls.as_ref().unwrap().as_slice(); - let rhs_null_bytes = rhs_nulls.as_ref().unwrap().as_slice(); + let lhs_null_bytes = lhs.null_buffer().as_ref().unwrap().as_slice(); + let rhs_null_bytes = rhs.null_buffer().as_ref().unwrap().as_slice(); (0..len).all(|i| { let lhs_pos = lhs_start + i; let rhs_pos = rhs_start + i; @@ -71,8 +66,6 @@ pub(super) fn dictionary_equal( && equal_range( lhs_values, rhs_values, - lhs_values.null_buffer(), - rhs_values.null_buffer(), lhs_keys[lhs_pos].to_usize().unwrap(), rhs_keys[rhs_pos].to_usize().unwrap(), 1, diff --git a/arrow/src/array/equal/fixed_binary.rs b/arrow/src/array/equal/fixed_binary.rs index 5f8f93232d53..aea0e08a9ebf 100644 --- a/arrow/src/array/equal/fixed_binary.rs +++ b/arrow/src/array/equal/fixed_binary.rs @@ -16,7 +16,6 @@ // under the License. use crate::array::{data::count_nulls, ArrayData}; -use crate::buffer::Buffer; use crate::datatypes::DataType; use crate::util::bit_util::get_bit; @@ -25,8 +24,6 @@ use super::utils::equal_len; pub(super) fn fixed_binary_equal( lhs: &ArrayData, rhs: &ArrayData, - lhs_nulls: Option<&Buffer>, - rhs_nulls: Option<&Buffer>, lhs_start: usize, rhs_start: usize, len: usize, @@ -39,8 +36,8 @@ pub(super) fn fixed_binary_equal( let lhs_values = &lhs.buffers()[0].as_slice()[lhs.offset() * size..]; let rhs_values = &rhs.buffers()[0].as_slice()[rhs.offset() * size..]; - let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len); - let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len); + let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); + let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len); if lhs_null_count == 0 && rhs_null_count == 0 { equal_len( @@ -52,8 +49,8 @@ pub(super) fn fixed_binary_equal( ) } else { // get a ref of the null buffer bytes, to use in testing for nullness - let lhs_null_bytes = lhs_nulls.as_ref().unwrap().as_slice(); - let rhs_null_bytes = rhs_nulls.as_ref().unwrap().as_slice(); + let lhs_null_bytes = lhs.null_buffer().as_ref().unwrap().as_slice(); + let rhs_null_bytes = rhs.null_buffer().as_ref().unwrap().as_slice(); // with nulls, we need to compare item by item whenever it is not null (0..len).all(|i| { let lhs_pos = lhs_start + i; diff --git a/arrow/src/array/equal/fixed_list.rs b/arrow/src/array/equal/fixed_list.rs index e708a06efcdb..82a347c86574 100644 --- a/arrow/src/array/equal/fixed_list.rs +++ b/arrow/src/array/equal/fixed_list.rs @@ -16,7 +16,6 @@ // under the License. use crate::array::{data::count_nulls, ArrayData}; -use crate::buffer::Buffer; use crate::datatypes::DataType; use crate::util::bit_util::get_bit; @@ -25,8 +24,6 @@ use super::equal_range; pub(super) fn fixed_list_equal( lhs: &ArrayData, rhs: &ArrayData, - lhs_nulls: Option<&Buffer>, - rhs_nulls: Option<&Buffer>, lhs_start: usize, rhs_start: usize, len: usize, @@ -39,23 +36,21 @@ pub(super) fn fixed_list_equal( let lhs_values = &lhs.child_data()[0]; let rhs_values = &rhs.child_data()[0]; - let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len); - let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len); + let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); + let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len); if lhs_null_count == 0 && rhs_null_count == 0 { equal_range( lhs_values, rhs_values, - lhs_values.null_buffer(), - rhs_values.null_buffer(), - size * lhs_start, - size * rhs_start, + (lhs_start + lhs.offset()) * size, + (rhs_start + rhs.offset()) * size, size * len, ) } else { // get a ref of the null buffer bytes, to use in testing for nullness - let lhs_null_bytes = lhs_nulls.as_ref().unwrap().as_slice(); - let rhs_null_bytes = rhs_nulls.as_ref().unwrap().as_slice(); + let lhs_null_bytes = lhs.null_buffer().as_ref().unwrap().as_slice(); + let rhs_null_bytes = rhs.null_buffer().as_ref().unwrap().as_slice(); // with nulls, we need to compare item by item whenever it is not null (0..len).all(|i| { let lhs_pos = lhs_start + i; @@ -69,10 +64,8 @@ pub(super) fn fixed_list_equal( && equal_range( lhs_values, rhs_values, - lhs_values.null_buffer(), - rhs_values.null_buffer(), - lhs_pos * size, - rhs_pos * size, + (lhs_pos + lhs.offset()) * size, + (rhs_pos + rhs.offset()) * size, size, // 1 * size since we are comparing a single entry ) }) diff --git a/arrow/src/array/equal/list.rs b/arrow/src/array/equal/list.rs index 000b31a1f785..09ad896f46eb 100644 --- a/arrow/src/array/equal/list.rs +++ b/arrow/src/array/equal/list.rs @@ -15,17 +15,13 @@ // specific language governing permissions and limitations // under the License. -use crate::datatypes::DataType; use crate::{ array::ArrayData, array::{data::count_nulls, OffsetSizeTrait}, - buffer::Buffer, util::bit_util::get_bit, }; -use super::{ - equal_range, equal_values, utils::child_logical_null_buffer, utils::equal_nulls, -}; +use super::equal_range; fn lengths_equal(lhs: &[T], rhs: &[T]) -> bool { // invariant from `base_equal` @@ -49,66 +45,9 @@ fn lengths_equal(lhs: &[T], rhs: &[T]) -> bool { }) } -#[allow(clippy::too_many_arguments)] -#[inline] -fn offset_value_equal( - lhs_values: &ArrayData, - rhs_values: &ArrayData, - lhs_nulls: Option<&Buffer>, - rhs_nulls: Option<&Buffer>, - lhs_offsets: &[T], - rhs_offsets: &[T], - lhs_pos: usize, - rhs_pos: usize, - len: usize, - data_type: &DataType, -) -> bool { - let lhs_start = lhs_offsets[lhs_pos].to_usize().unwrap(); - let rhs_start = rhs_offsets[rhs_pos].to_usize().unwrap(); - let lhs_len = lhs_offsets[lhs_pos + len] - lhs_offsets[lhs_pos]; - let rhs_len = rhs_offsets[rhs_pos + len] - rhs_offsets[rhs_pos]; - - lhs_len == rhs_len && { - match data_type { - DataType::Map(_, _) => { - // Don't use `equal_range` which calls `utils::base_equal` that checks - // struct fields, but we don't enforce struct field names. - equal_nulls( - lhs_values, - rhs_values, - lhs_nulls, - rhs_nulls, - lhs_start, - rhs_start, - lhs_len.to_usize().unwrap(), - ) && equal_values( - lhs_values, - rhs_values, - lhs_nulls, - rhs_nulls, - lhs_start, - rhs_start, - lhs_len.to_usize().unwrap(), - ) - } - _ => equal_range( - lhs_values, - rhs_values, - lhs_nulls, - rhs_nulls, - lhs_start, - rhs_start, - lhs_len.to_usize().unwrap(), - ), - } - } -} - pub(super) fn list_equal( lhs: &ArrayData, rhs: &ArrayData, - lhs_nulls: Option<&Buffer>, - rhs_nulls: Option<&Buffer>, lhs_start: usize, rhs_start: usize, len: usize, @@ -123,7 +62,7 @@ pub(super) fn list_equal( // no child values. This causes panics when trying to count set bits. // // We caught this by chance from an accidental test-case, but due to the nature of this - // crash only occuring on list equality checks, we are adding a check here, instead of + // crash only occurring on list equality checks, we are adding a check here, instead of // on the buffer/bitmap utilities, as a length check would incur a penalty for almost all // other use-cases. // @@ -134,10 +73,11 @@ pub(super) fn list_equal( // however, one is more likely to slice into a list array and get a region that has 0 // child values. // The test that triggered this behaviour had [4, 4] as a slice of 1 value slot. - let lhs_child_length = lhs_offsets.get(len).unwrap().to_usize().unwrap() - - lhs_offsets.first().unwrap().to_usize().unwrap(); - let rhs_child_length = rhs_offsets.get(len).unwrap().to_usize().unwrap() - - rhs_offsets.first().unwrap().to_usize().unwrap(); + let lhs_child_length = lhs_offsets[lhs_start + len].to_usize().unwrap() + - lhs_offsets[lhs_start].to_usize().unwrap(); + + let rhs_child_length = rhs_offsets[rhs_start + len].to_usize().unwrap() + - rhs_offsets[rhs_start].to_usize().unwrap(); if lhs_child_length == 0 && lhs_child_length == rhs_child_length { return true; @@ -146,64 +86,33 @@ pub(super) fn list_equal( let lhs_values = &lhs.child_data()[0]; let rhs_values = &rhs.child_data()[0]; - let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len); - let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len); + let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); + let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len); - // compute the child logical bitmap - let child_lhs_nulls = - child_logical_null_buffer(lhs, lhs_nulls, lhs.child_data().get(0).unwrap()); - let child_rhs_nulls = - child_logical_null_buffer(rhs, rhs_nulls, rhs.child_data().get(0).unwrap()); + if lhs_null_count != rhs_null_count { + return false; + } if lhs_null_count == 0 && rhs_null_count == 0 { - lengths_equal( - &lhs_offsets[lhs_start..lhs_start + len], - &rhs_offsets[rhs_start..rhs_start + len], - ) && { - match lhs.data_type() { - DataType::Map(_, _) => { - // Don't use `equal_range` which calls `utils::base_equal` that checks - // struct fields, but we don't enforce struct field names. - equal_nulls( - lhs_values, - rhs_values, - child_lhs_nulls.as_ref(), - child_rhs_nulls.as_ref(), - lhs_offsets[lhs_start].to_usize().unwrap(), - rhs_offsets[rhs_start].to_usize().unwrap(), - (lhs_offsets[lhs_start + len] - lhs_offsets[lhs_start]) - .to_usize() - .unwrap(), - ) && equal_values( - lhs_values, - rhs_values, - child_lhs_nulls.as_ref(), - child_rhs_nulls.as_ref(), - lhs_offsets[lhs_start].to_usize().unwrap(), - rhs_offsets[rhs_start].to_usize().unwrap(), - (lhs_offsets[lhs_start + len] - lhs_offsets[lhs_start]) - .to_usize() - .unwrap(), - ) - } - _ => equal_range( - lhs_values, - rhs_values, - child_lhs_nulls.as_ref(), - child_rhs_nulls.as_ref(), - lhs_offsets[lhs_start].to_usize().unwrap(), - rhs_offsets[rhs_start].to_usize().unwrap(), - (lhs_offsets[lhs_start + len] - lhs_offsets[lhs_start]) - .to_usize() - .unwrap(), - ), - } - } + lhs_child_length == rhs_child_length + && lengths_equal( + &lhs_offsets[lhs_start..lhs_start + len], + &rhs_offsets[rhs_start..rhs_start + len], + ) + && equal_range( + lhs_values, + rhs_values, + lhs_offsets[lhs_start].to_usize().unwrap(), + rhs_offsets[rhs_start].to_usize().unwrap(), + lhs_child_length, + ) } else { // get a ref of the parent null buffer bytes, to use in testing for nullness - let lhs_null_bytes = lhs_nulls.unwrap().as_slice(); - let rhs_null_bytes = rhs_nulls.unwrap().as_slice(); + let lhs_null_bytes = lhs.null_buffer().unwrap().as_slice(); + let rhs_null_bytes = rhs.null_buffer().unwrap().as_slice(); + // with nulls, we need to compare item by item whenever it is not null + // TODO: Could potentially compare runs of not NULL values (0..len).all(|i| { let lhs_pos = lhs_start + i; let rhs_pos = rhs_start + i; @@ -211,20 +120,56 @@ pub(super) fn list_equal( let lhs_is_null = !get_bit(lhs_null_bytes, lhs_pos + lhs.offset()); let rhs_is_null = !get_bit(rhs_null_bytes, rhs_pos + rhs.offset()); + if lhs_is_null != rhs_is_null { + return false; + } + + let lhs_offset_start = lhs_offsets[lhs_pos].to_usize().unwrap(); + let lhs_offset_end = lhs_offsets[lhs_pos + 1].to_usize().unwrap(); + let rhs_offset_start = rhs_offsets[rhs_pos].to_usize().unwrap(); + let rhs_offset_end = rhs_offsets[rhs_pos + 1].to_usize().unwrap(); + + let lhs_len = lhs_offset_end - lhs_offset_start; + let rhs_len = rhs_offset_end - rhs_offset_start; + lhs_is_null - || (lhs_is_null == rhs_is_null) - && offset_value_equal::( + || (lhs_len == rhs_len + && equal_range( lhs_values, rhs_values, - child_lhs_nulls.as_ref(), - child_rhs_nulls.as_ref(), - lhs_offsets, - rhs_offsets, - lhs_pos, - rhs_pos, - 1, - lhs.data_type(), - ) + lhs_offset_start, + rhs_offset_start, + lhs_len, + )) }) } } + +#[cfg(test)] +mod tests { + use crate::array::{Int64Builder, ListBuilder}; + + #[test] + fn list_array_non_zero_nulls() { + // Tests handling of list arrays with non-empty null ranges + let mut builder = ListBuilder::new(Int64Builder::new(10)); + builder.values().append_value(1).unwrap(); + builder.values().append_value(2).unwrap(); + builder.values().append_value(3).unwrap(); + builder.append(true).unwrap(); + builder.append(false).unwrap(); + let array1 = builder.finish(); + + let mut builder = ListBuilder::new(Int64Builder::new(10)); + builder.values().append_value(1).unwrap(); + builder.values().append_value(2).unwrap(); + builder.values().append_value(3).unwrap(); + builder.append(true).unwrap(); + builder.values().append_null().unwrap(); + builder.values().append_null().unwrap(); + builder.append(false).unwrap(); + let array2 = builder.finish(); + + assert_eq!(array1, array2); + } +} diff --git a/arrow/src/array/equal/mod.rs b/arrow/src/array/equal/mod.rs index 07c173b13326..f5f0d60c713e 100644 --- a/arrow/src/array/equal/mod.rs +++ b/arrow/src/array/equal/mod.rs @@ -25,10 +25,7 @@ use super::{ GenericStringArray, MapArray, NullArray, OffsetSizeTrait, PrimitiveArray, StringOffsetSizeTrait, StructArray, }; -use crate::{ - buffer::Buffer, - datatypes::{ArrowPrimitiveType, DataType, IntervalUnit}, -}; +use crate::datatypes::{ArrowPrimitiveType, DataType, IntervalUnit}; use half::f16; mod boolean; @@ -144,147 +141,99 @@ impl PartialEq for StructArray { } /// Compares the values of two [ArrayData] starting at `lhs_start` and `rhs_start` respectively -/// for `len` slots. The null buffers `lhs_nulls` and `rhs_nulls` inherit parent nullability. -/// -/// If an array is a child of a struct or list, the array's nulls have to be merged with the parent. -/// This then affects the null count of the array, thus the merged nulls are passed separately -/// as `lhs_nulls` and `rhs_nulls` variables to functions. -/// The nulls are merged with a bitwise AND, and null counts are recomputed where necessary. +/// for `len` slots. #[inline] fn equal_values( lhs: &ArrayData, rhs: &ArrayData, - lhs_nulls: Option<&Buffer>, - rhs_nulls: Option<&Buffer>, lhs_start: usize, rhs_start: usize, len: usize, ) -> bool { match lhs.data_type() { DataType::Null => null_equal(lhs, rhs, lhs_start, rhs_start, len), - DataType::Boolean => { - boolean_equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) - } - DataType::UInt8 => primitive_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::UInt16 => primitive_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::UInt32 => primitive_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::UInt64 => primitive_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::Int8 => primitive_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::Int16 => primitive_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::Int32 => primitive_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::Int64 => primitive_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::Float32 => primitive_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::Float64 => primitive_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), + DataType::Boolean => boolean_equal(lhs, rhs, lhs_start, rhs_start, len), + DataType::UInt8 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::UInt16 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::UInt32 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::UInt64 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Int8 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Int16 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Int32 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Int64 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Float32 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Float64 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), DataType::Date32 | DataType::Time32(_) - | DataType::Interval(IntervalUnit::YearMonth) => primitive_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), + | DataType::Interval(IntervalUnit::YearMonth) => { + primitive_equal::(lhs, rhs, lhs_start, rhs_start, len) + } DataType::Date64 | DataType::Interval(IntervalUnit::DayTime) | DataType::Time64(_) | DataType::Timestamp(_, _) - | DataType::Duration(_) => primitive_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::Interval(IntervalUnit::MonthDayNano) => primitive_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::Utf8 | DataType::Binary => variable_sized_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::LargeUtf8 | DataType::LargeBinary => variable_sized_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::FixedSizeBinary(_) => { - fixed_binary_equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + | DataType::Duration(_) => { + primitive_equal::(lhs, rhs, lhs_start, rhs_start, len) } - DataType::Decimal(_, _) => { - decimal_equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + DataType::Interval(IntervalUnit::MonthDayNano) => { + primitive_equal::(lhs, rhs, lhs_start, rhs_start, len) } - DataType::List(_) => { - list_equal::(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + DataType::Utf8 | DataType::Binary => { + variable_sized_equal::(lhs, rhs, lhs_start, rhs_start, len) } - DataType::LargeList(_) => { - list_equal::(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + DataType::LargeUtf8 | DataType::LargeBinary => { + variable_sized_equal::(lhs, rhs, lhs_start, rhs_start, len) } - DataType::FixedSizeList(_, _) => { - fixed_list_equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) - } - DataType::Struct(_) => { - struct_equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + DataType::FixedSizeBinary(_) => { + fixed_binary_equal(lhs, rhs, lhs_start, rhs_start, len) } - DataType::Union(_, _) => { - union_equal(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + DataType::Decimal(_, _) => decimal_equal(lhs, rhs, lhs_start, rhs_start, len), + DataType::List(_) => list_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::LargeList(_) => list_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::FixedSizeList(_, _) => { + fixed_list_equal(lhs, rhs, lhs_start, rhs_start, len) } + DataType::Struct(_) => struct_equal(lhs, rhs, lhs_start, rhs_start, len), + DataType::Union(_, _) => union_equal(lhs, rhs, lhs_start, rhs_start, len), DataType::Dictionary(data_type, _) => match data_type.as_ref() { - DataType::Int8 => dictionary_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::Int16 => dictionary_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::Int32 => dictionary_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::Int64 => dictionary_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::UInt8 => dictionary_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::UInt16 => dictionary_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::UInt32 => dictionary_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::UInt64 => dictionary_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), + DataType::Int8 => dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Int16 => { + dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len) + } + DataType::Int32 => { + dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len) + } + DataType::Int64 => { + dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len) + } + DataType::UInt8 => { + dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len) + } + DataType::UInt16 => { + dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len) + } + DataType::UInt32 => { + dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len) + } + DataType::UInt64 => { + dictionary_equal::(lhs, rhs, lhs_start, rhs_start, len) + } _ => unreachable!(), }, - DataType::Float16 => primitive_equal::( - lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len, - ), - DataType::Map(_, _) => { - list_equal::(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) - } + DataType::Float16 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Map(_, _) => list_equal::(lhs, rhs, lhs_start, rhs_start, len), } } fn equal_range( lhs: &ArrayData, rhs: &ArrayData, - lhs_nulls: Option<&Buffer>, - rhs_nulls: Option<&Buffer>, lhs_start: usize, rhs_start: usize, len: usize, ) -> bool { - utils::base_equal(lhs, rhs) - && utils::equal_nulls(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) - && equal_values(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + utils::equal_nulls(lhs, rhs, lhs_start, rhs_start, len) + && equal_values(lhs, rhs, lhs_start, rhs_start, len) } /// Logically compares two [ArrayData]. @@ -300,12 +249,10 @@ fn equal_range( /// This function may panic whenever any of the [ArrayData] does not follow the Arrow specification. /// (e.g. wrong number of buffers, buffer `len` does not correspond to the declared `len`) pub fn equal(lhs: &ArrayData, rhs: &ArrayData) -> bool { - let lhs_nulls = lhs.null_buffer(); - let rhs_nulls = rhs.null_buffer(); utils::base_equal(lhs, rhs) && lhs.null_count() == rhs.null_count() - && utils::equal_nulls(lhs, rhs, lhs_nulls, rhs_nulls, 0, 0, lhs.len()) - && equal_values(lhs, rhs, lhs_nulls, rhs_nulls, 0, 0, lhs.len()) + && utils::equal_nulls(lhs, rhs, 0, 0, lhs.len()) + && equal_values(lhs, rhs, 0, 0, lhs.len()) } #[cfg(test)] @@ -494,6 +441,13 @@ mod tests { (1, 2), true, ), + ( + vec![Some(1), Some(2), None, Some(0)], + (2, 2), + vec![Some(4), Some(5), Some(0), None], + (2, 2), + false, + ), ]; for (lhs, slice_lhs, rhs, slice_rhs, expected) in cases { @@ -990,6 +944,11 @@ mod tests { None, ]); test_equal(&a, &b, false); + + let b = create_fixed_size_list_array(&[None, Some(&[4, 5, 6]), None, None]); + + test_equal(&a.slice(2, 4), &b, true); + test_equal(&a.slice(3, 3), &b.slice(1, 3), true); } #[test] @@ -1359,7 +1318,7 @@ mod tests { builder.append::("b", 2).unwrap(); builder.append::("c", 3).unwrap(); builder.append::("a", 4).unwrap(); - builder.append_null().unwrap(); + builder.append_null::("a").unwrap(); builder.append::("a", 6).unwrap(); builder.append::("b", 7).unwrap(); let union1 = builder.build().unwrap(); @@ -1369,7 +1328,7 @@ mod tests { builder.append::("b", 2).unwrap(); builder.append::("c", 3).unwrap(); builder.append::("a", 4).unwrap(); - builder.append_null().unwrap(); + builder.append_null::("a").unwrap(); builder.append::("a", 6).unwrap(); builder.append::("b", 7).unwrap(); let union2 = builder.build().unwrap(); @@ -1389,8 +1348,8 @@ mod tests { builder.append::("b", 2).unwrap(); builder.append::("c", 3).unwrap(); builder.append::("a", 4).unwrap(); - builder.append_null().unwrap(); - builder.append_null().unwrap(); + builder.append_null::("c").unwrap(); + builder.append_null::("b").unwrap(); builder.append::("b", 7).unwrap(); let union4 = builder.build().unwrap(); @@ -1406,7 +1365,7 @@ mod tests { builder.append::("b", 2).unwrap(); builder.append::("c", 3).unwrap(); builder.append::("a", 4).unwrap(); - builder.append_null().unwrap(); + builder.append_null::("a").unwrap(); builder.append::("a", 6).unwrap(); builder.append::("b", 7).unwrap(); let union1 = builder.build().unwrap(); @@ -1416,7 +1375,7 @@ mod tests { builder.append::("b", 2).unwrap(); builder.append::("c", 3).unwrap(); builder.append::("a", 4).unwrap(); - builder.append_null().unwrap(); + builder.append_null::("a").unwrap(); builder.append::("a", 6).unwrap(); builder.append::("b", 7).unwrap(); let union2 = builder.build().unwrap(); @@ -1436,8 +1395,8 @@ mod tests { builder.append::("b", 2).unwrap(); builder.append::("c", 3).unwrap(); builder.append::("a", 4).unwrap(); - builder.append_null().unwrap(); - builder.append_null().unwrap(); + builder.append_null::("a").unwrap(); + builder.append_null::("a").unwrap(); builder.append::("b", 7).unwrap(); let union4 = builder.build().unwrap(); diff --git a/arrow/src/array/equal/primitive.rs b/arrow/src/array/equal/primitive.rs index db7587915c8a..09882cd78509 100644 --- a/arrow/src/array/equal/primitive.rs +++ b/arrow/src/array/equal/primitive.rs @@ -18,7 +18,6 @@ use std::mem::size_of; use crate::array::{data::count_nulls, ArrayData}; -use crate::buffer::Buffer; use crate::util::bit_util::get_bit; use super::utils::equal_len; @@ -26,8 +25,6 @@ use super::utils::equal_len; pub(super) fn primitive_equal( lhs: &ArrayData, rhs: &ArrayData, - lhs_nulls: Option<&Buffer>, - rhs_nulls: Option<&Buffer>, lhs_start: usize, rhs_start: usize, len: usize, @@ -36,8 +33,8 @@ pub(super) fn primitive_equal( let lhs_values = &lhs.buffers()[0].as_slice()[lhs.offset() * byte_width..]; let rhs_values = &rhs.buffers()[0].as_slice()[rhs.offset() * byte_width..]; - let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len); - let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len); + let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); + let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len); if lhs_null_count == 0 && rhs_null_count == 0 { // without nulls, we just need to compare slices @@ -50,8 +47,8 @@ pub(super) fn primitive_equal( ) } else { // get a ref of the null buffer bytes, to use in testing for nullness - let lhs_null_bytes = lhs_nulls.as_ref().unwrap().as_slice(); - let rhs_null_bytes = rhs_nulls.as_ref().unwrap().as_slice(); + let lhs_null_bytes = lhs.null_buffer().as_ref().unwrap().as_slice(); + let rhs_null_bytes = rhs.null_buffer().as_ref().unwrap().as_slice(); // with nulls, we need to compare item by item whenever it is not null (0..len).all(|i| { let lhs_pos = lhs_start + i; diff --git a/arrow/src/array/equal/structure.rs b/arrow/src/array/equal/structure.rs index b3cc4029e9ec..0f943e40cac6 100644 --- a/arrow/src/array/equal/structure.rs +++ b/arrow/src/array/equal/structure.rs @@ -15,24 +15,15 @@ // specific language governing permissions and limitations // under the License. -use crate::{ - array::data::count_nulls, array::ArrayData, buffer::Buffer, util::bit_util::get_bit, -}; +use crate::{array::data::count_nulls, array::ArrayData, util::bit_util::get_bit}; -use super::{equal_range, utils::child_logical_null_buffer}; +use super::equal_range; /// Compares the values of two [ArrayData] starting at `lhs_start` and `rhs_start` respectively -/// for `len` slots. The null buffers `lhs_nulls` and `rhs_nulls` inherit parent nullability. -/// -/// If an array is a child of a struct or list, the array's nulls have to be merged with the parent. -/// This then affects the null count of the array, thus the merged nulls are passed separately -/// as `lhs_nulls` and `rhs_nulls` variables to functions. -/// The nulls are merged with a bitwise AND, and null counts are recomputed where necessary. -fn equal_values( +/// for `len` slots. +fn equal_child_values( lhs: &ArrayData, rhs: &ArrayData, - lhs_nulls: Option<&Buffer>, - rhs_nulls: Option<&Buffer>, lhs_start: usize, rhs_start: usize, len: usize, @@ -41,39 +32,27 @@ fn equal_values( .iter() .zip(rhs.child_data()) .all(|(lhs_values, rhs_values)| { - // merge the null data - let lhs_merged_nulls = child_logical_null_buffer(lhs, lhs_nulls, lhs_values); - let rhs_merged_nulls = child_logical_null_buffer(rhs, rhs_nulls, rhs_values); - equal_range( - lhs_values, - rhs_values, - lhs_merged_nulls.as_ref(), - rhs_merged_nulls.as_ref(), - lhs_start, - rhs_start, - len, - ) + equal_range(lhs_values, rhs_values, lhs_start, rhs_start, len) }) } pub(super) fn struct_equal( lhs: &ArrayData, rhs: &ArrayData, - lhs_nulls: Option<&Buffer>, - rhs_nulls: Option<&Buffer>, lhs_start: usize, rhs_start: usize, len: usize, ) -> bool { // we have to recalculate null counts from the null buffers - let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len); - let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len); + let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); + let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len); + if lhs_null_count == 0 && rhs_null_count == 0 { - equal_values(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + equal_child_values(lhs, rhs, lhs_start, rhs_start, len) } else { // get a ref of the null buffer bytes, to use in testing for nullness - let lhs_null_bytes = lhs_nulls.as_ref().unwrap().as_slice(); - let rhs_null_bytes = rhs_nulls.as_ref().unwrap().as_slice(); + let lhs_null_bytes = lhs.null_buffer().as_ref().unwrap().as_slice(); + let rhs_null_bytes = rhs.null_buffer().as_ref().unwrap().as_slice(); // with nulls, we need to compare item by item whenever it is not null (0..len).all(|i| { let lhs_pos = lhs_start + i; @@ -82,9 +61,11 @@ pub(super) fn struct_equal( let lhs_is_null = !get_bit(lhs_null_bytes, lhs_pos + lhs.offset()); let rhs_is_null = !get_bit(rhs_null_bytes, rhs_pos + rhs.offset()); - lhs_is_null - || (lhs_is_null == rhs_is_null) - && equal_values(lhs, rhs, lhs_nulls, rhs_nulls, lhs_pos, rhs_pos, 1) + if lhs_is_null != rhs_is_null { + return false; + } + + lhs_is_null || equal_child_values(lhs, rhs, lhs_pos, rhs_pos, 1) }) } } diff --git a/arrow/src/array/equal/union.rs b/arrow/src/array/equal/union.rs index 36cd19725b5d..021b0a3b7fe7 100644 --- a/arrow/src/array/equal/union.rs +++ b/arrow/src/array/equal/union.rs @@ -15,13 +15,9 @@ // specific language governing permissions and limitations // under the License. -use crate::{ - array::ArrayData, buffer::Buffer, datatypes::DataType, datatypes::UnionMode, -}; +use crate::{array::ArrayData, datatypes::DataType, datatypes::UnionMode}; -use super::{ - equal_range, equal_values, utils::child_logical_null_buffer, utils::equal_nulls, -}; +use super::equal_range; fn equal_dense( lhs: &ArrayData, @@ -41,11 +37,9 @@ fn equal_dense( let lhs_values = &lhs.child_data()[*l_type_id as usize]; let rhs_values = &rhs.child_data()[*r_type_id as usize]; - equal_values( + equal_range( lhs_values, rhs_values, - None, - None, *l_offset as usize, *r_offset as usize, 1, @@ -56,8 +50,6 @@ fn equal_dense( fn equal_sparse( lhs: &ArrayData, rhs: &ArrayData, - lhs_nulls: Option<&Buffer>, - rhs_nulls: Option<&Buffer>, lhs_start: usize, rhs_start: usize, len: usize, @@ -66,26 +58,13 @@ fn equal_sparse( .iter() .zip(rhs.child_data()) .all(|(lhs_values, rhs_values)| { - // merge the null data - let lhs_merged_nulls = child_logical_null_buffer(lhs, lhs_nulls, lhs_values); - let rhs_merged_nulls = child_logical_null_buffer(rhs, rhs_nulls, rhs_values); - equal_range( - lhs_values, - rhs_values, - lhs_merged_nulls.as_ref(), - rhs_merged_nulls.as_ref(), - lhs_start, - rhs_start, - len, - ) + equal_range(lhs_values, rhs_values, lhs_start, rhs_start, len) }) } pub(super) fn union_equal( lhs: &ArrayData, rhs: &ArrayData, - lhs_nulls: Option<&Buffer>, - rhs_nulls: Option<&Buffer>, lhs_start: usize, rhs_start: usize, len: usize, @@ -104,9 +83,7 @@ pub(super) fn union_equal( let lhs_offsets_range = &lhs_offsets[lhs_start..lhs_start + len]; let rhs_offsets_range = &rhs_offsets[rhs_start..rhs_start + len]; - // nullness is kept in the parent UnionArray, so we compare its nulls here lhs_type_id_range == rhs_type_id_range - && equal_nulls(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) && equal_dense( lhs, rhs, @@ -121,7 +98,7 @@ pub(super) fn union_equal( DataType::Union(_, UnionMode::Sparse), ) => { lhs_type_id_range == rhs_type_id_range - && equal_sparse(lhs, rhs, lhs_nulls, rhs_nulls, lhs_start, rhs_start, len) + && equal_sparse(lhs, rhs, lhs_start, rhs_start, len) } _ => unimplemented!( "Logical equality not yet implemented between dense and sparse union arrays" diff --git a/arrow/src/array/equal/utils.rs b/arrow/src/array/equal/utils.rs index b6690f936ff6..8875239caf52 100644 --- a/arrow/src/array/equal/utils.rs +++ b/arrow/src/array/equal/utils.rs @@ -15,10 +15,8 @@ // specific language governing permissions and limitations // under the License. -use crate::array::{data::count_nulls, ArrayData, OffsetSizeTrait}; -use crate::bitmap::Bitmap; -use crate::buffer::{Buffer, MutableBuffer}; -use crate::datatypes::{DataType, UnionMode}; +use crate::array::{data::count_nulls, ArrayData}; +use crate::datatypes::DataType; use crate::util::bit_util; // whether bits along the positions are equal @@ -41,17 +39,20 @@ pub(super) fn equal_bits( pub(super) fn equal_nulls( lhs: &ArrayData, rhs: &ArrayData, - lhs_nulls: Option<&Buffer>, - rhs_nulls: Option<&Buffer>, lhs_start: usize, rhs_start: usize, len: usize, ) -> bool { - let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len); - let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len); + let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); + let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len); + + if lhs_null_count != rhs_null_count { + return false; + } + if lhs_null_count > 0 || rhs_null_count > 0 { - let lhs_values = lhs_nulls.unwrap().as_slice(); - let rhs_values = rhs_nulls.unwrap().as_slice(); + let lhs_values = lhs.null_buffer().unwrap().as_slice(); + let rhs_values = rhs.null_buffer().unwrap().as_slice(); equal_bits( lhs_values, rhs_values, @@ -111,227 +112,3 @@ pub(super) fn equal_len( ) -> bool { lhs_values[lhs_start..(lhs_start + len)] == rhs_values[rhs_start..(rhs_start + len)] } - -/// Computes the logical validity bitmap of the array data using the -/// parent's array data. The parent should be a list or struct, else -/// the logical bitmap of the array is returned unaltered. -/// -/// Parent data is passed along with the parent's logical bitmap, as -/// nested arrays could have a logical bitmap different to the physical -/// one on the `ArrayData`. -pub(super) fn child_logical_null_buffer( - parent_data: &ArrayData, - logical_null_buffer: Option<&Buffer>, - child_data: &ArrayData, -) -> Option { - let parent_len = parent_data.len(); - let parent_bitmap = logical_null_buffer - .cloned() - .map(Bitmap::from) - .unwrap_or_else(|| { - let ceil = bit_util::ceil(parent_len, 8); - Bitmap::from(Buffer::from(vec![0b11111111; ceil])) - }); - let self_null_bitmap = child_data.null_bitmap().cloned().unwrap_or_else(|| { - let ceil = bit_util::ceil(child_data.len(), 8); - Bitmap::from(Buffer::from(vec![0b11111111; ceil])) - }); - match parent_data.data_type() { - DataType::List(_) | DataType::Map(_, _) => Some(logical_list_bitmap::( - parent_data, - parent_bitmap, - self_null_bitmap, - )), - DataType::LargeList(_) => Some(logical_list_bitmap::( - parent_data, - parent_bitmap, - self_null_bitmap, - )), - DataType::FixedSizeList(_, len) => { - let len = *len as usize; - let array_offset = parent_data.offset(); - let bitmap_len = bit_util::ceil(parent_len * len, 8); - let mut buffer = MutableBuffer::from_len_zeroed(bitmap_len); - let null_slice = buffer.as_slice_mut(); - (array_offset..parent_len + array_offset).for_each(|index| { - let start = index * len; - let end = start + len; - let mask = parent_bitmap.is_set(index); - (start..end).for_each(|child_index| { - if mask && self_null_bitmap.is_set(child_index) { - bit_util::set_bit(null_slice, child_index); - } - }); - }); - Some(buffer.into()) - } - DataType::Struct(_) => { - // Arrow implementations are free to pad data, which can result in null buffers not - // having the same length. - // Rust bitwise comparisons will return an error if left AND right is performed on - // buffers of different length. - // This might be a valid case during integration testing, where we read Arrow arrays - // from IPC data, which has padding. - // - // We first perform a bitwise comparison, and if there is an error, we revert to a - // slower method that indexes into the buffers one-by-one. - let result = &parent_bitmap & &self_null_bitmap; - if let Ok(bitmap) = result { - return Some(bitmap.bits); - } - // slow path - let array_offset = parent_data.offset(); - let mut buffer = MutableBuffer::new_null(parent_len); - let null_slice = buffer.as_slice_mut(); - (0..parent_len).for_each(|index| { - if parent_bitmap.is_set(index + array_offset) - && self_null_bitmap.is_set(index + array_offset) - { - bit_util::set_bit(null_slice, index); - } - }); - Some(buffer.into()) - } - DataType::Union(_, mode) => union_child_logical_null_buffer( - parent_data, - parent_len, - &parent_bitmap, - &self_null_bitmap, - mode, - ), - DataType::Dictionary(_, _) => { - unimplemented!("Logical equality not yet implemented for nested dictionaries") - } - data_type => panic!("Data type {:?} is not a supported nested type", data_type), - } -} - -pub(super) fn union_child_logical_null_buffer( - parent_data: &ArrayData, - parent_len: usize, - parent_bitmap: &Bitmap, - self_null_bitmap: &Bitmap, - mode: &UnionMode, -) -> Option { - match mode { - UnionMode::Sparse => { - // See the logic of `DataType::Struct` in `child_logical_null_buffer`. - let result = parent_bitmap & self_null_bitmap; - if let Ok(bitmap) = result { - return Some(bitmap.bits); - } - - // slow path - let array_offset = parent_data.offset(); - let mut buffer = MutableBuffer::new_null(parent_len); - let null_slice = buffer.as_slice_mut(); - (0..parent_len).for_each(|index| { - if parent_bitmap.is_set(index + array_offset) - && self_null_bitmap.is_set(index + array_offset) - { - bit_util::set_bit(null_slice, index); - } - }); - Some(buffer.into()) - } - UnionMode::Dense => { - // We don't keep bitmap in child data of Dense UnionArray - unimplemented!("Logical equality not yet implemented for dense union arrays") - } - } -} - -// Calculate a list child's logical bitmap/buffer -#[inline] -fn logical_list_bitmap( - parent_data: &ArrayData, - parent_bitmap: Bitmap, - child_bitmap: Bitmap, -) -> Buffer { - let offsets = parent_data.buffer::(0); - let offset_start = offsets.first().unwrap().to_usize().unwrap(); - let offset_len = offsets.get(parent_data.len()).unwrap().to_usize().unwrap(); - let mut buffer = MutableBuffer::new_null(offset_len - offset_start); - let null_slice = buffer.as_slice_mut(); - - offsets - .windows(2) - .enumerate() - .take(parent_data.len()) - .for_each(|(index, window)| { - let start = window[0].to_usize().unwrap(); - let end = window[1].to_usize().unwrap(); - let mask = parent_bitmap.is_set(index); - (start..end).for_each(|child_index| { - if mask && child_bitmap.is_set(child_index) { - bit_util::set_bit(null_slice, child_index - offset_start); - } - }); - }); - buffer.into() -} - -#[cfg(test)] -mod tests { - use super::*; - - use crate::datatypes::{Field, ToByteSlice}; - - #[test] - fn test_logical_null_buffer() { - let child_data = ArrayData::builder(DataType::Int32) - .len(11) - .add_buffer(Buffer::from( - vec![1i32, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11].to_byte_slice(), - )) - .build() - .unwrap(); - - let data = ArrayData::builder(DataType::List(Box::new(Field::new( - "item", - DataType::Int32, - false, - )))) - .len(7) - .add_buffer(Buffer::from(vec![0, 0, 3, 5, 6, 9, 10, 11].to_byte_slice())) - .null_bit_buffer(Buffer::from(vec![0b01011010])) - .add_child_data(child_data.clone()) - .build() - .unwrap(); - - // Get the child logical null buffer. The child is non-nullable, but because the list has nulls, - // we expect the child to logically have some nulls, inherited from the parent: - // [1, 2, 3, null, null, 6, 7, 8, 9, null, 11] - let nulls = child_logical_null_buffer( - &data, - data.null_buffer(), - data.child_data().get(0).unwrap(), - ); - let expected = Some(Buffer::from(vec![0b11100111, 0b00000101])); - assert_eq!(nulls, expected); - - // test with offset - let data = ArrayData::builder(DataType::List(Box::new(Field::new( - "item", - DataType::Int32, - false, - )))) - .len(4) - .offset(3) - .add_buffer(Buffer::from(vec![0, 0, 3, 5, 6, 9, 10, 11].to_byte_slice())) - // the null_bit_buffer doesn't have an offset, i.e. cleared the 3 offset bits 0b[---]01011[010] - .null_bit_buffer(Buffer::from(vec![0b00001011])) - .add_child_data(child_data) - .build() - .unwrap(); - - let nulls = child_logical_null_buffer( - &data, - data.null_buffer(), - data.child_data().get(0).unwrap(), - ); - - let expected = Some(Buffer::from(vec![0b00101111])); - assert_eq!(nulls, expected); - } -} diff --git a/arrow/src/array/equal/variable_size.rs b/arrow/src/array/equal/variable_size.rs index 946f107f3971..f40f79e404ac 100644 --- a/arrow/src/array/equal/variable_size.rs +++ b/arrow/src/array/equal/variable_size.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use crate::buffer::Buffer; use crate::util::bit_util::get_bit; use crate::{ array::data::count_nulls, @@ -51,8 +50,6 @@ fn offset_value_equal( pub(super) fn variable_sized_equal( lhs: &ArrayData, rhs: &ArrayData, - lhs_nulls: Option<&Buffer>, - rhs_nulls: Option<&Buffer>, lhs_start: usize, rhs_start: usize, len: usize, @@ -64,8 +61,8 @@ pub(super) fn variable_sized_equal( let lhs_values = lhs.buffers()[1].as_slice(); let rhs_values = rhs.buffers()[1].as_slice(); - let lhs_null_count = count_nulls(lhs_nulls, lhs_start, len); - let rhs_null_count = count_nulls(rhs_nulls, rhs_start, len); + let lhs_null_count = count_nulls(lhs.null_buffer(), lhs_start + lhs.offset(), len); + let rhs_null_count = count_nulls(rhs.null_buffer(), rhs_start + rhs.offset(), len); if lhs_null_count == 0 && rhs_null_count == 0 @@ -87,10 +84,13 @@ pub(super) fn variable_sized_equal( let rhs_pos = rhs_start + i; // the null bits can still be `None`, indicating that the value is valid. - let lhs_is_null = !lhs_nulls + let lhs_is_null = !lhs + .null_buffer() .map(|v| get_bit(v.as_slice(), lhs.offset() + lhs_pos)) .unwrap_or(true); - let rhs_is_null = !rhs_nulls + + let rhs_is_null = !rhs + .null_buffer() .map(|v| get_bit(v.as_slice(), rhs.offset() + rhs_pos)) .unwrap_or(true); diff --git a/arrow/src/array/transform/union.rs b/arrow/src/array/transform/union.rs index ec672daf4d99..bbea508219d0 100644 --- a/arrow/src/array/transform/union.rs +++ b/arrow/src/array/transform/union.rs @@ -22,140 +22,50 @@ use super::{Extend, _MutableArrayData}; pub(super) fn build_extend_sparse(array: &ArrayData) -> Extend { let type_ids = array.buffer::(0); - if array.null_count() == 0 { - Box::new( - move |mutable: &mut _MutableArrayData, - index: usize, - start: usize, - len: usize| { - // extends type_ids - mutable - .buffer1 - .extend_from_slice(&type_ids[start..start + len]); - - mutable - .child_data - .iter_mut() - .for_each(|child| child.extend(index, start, start + len)) - }, - ) - } else { - Box::new( - move |mutable: &mut _MutableArrayData, - index: usize, - start: usize, - len: usize| { - // extends type_ids - mutable - .buffer1 - .extend_from_slice(&type_ids[start..start + len]); + Box::new( + move |mutable: &mut _MutableArrayData, index: usize, start: usize, len: usize| { + // extends type_ids + mutable + .buffer1 + .extend_from_slice(&type_ids[start..start + len]); - (start..start + len).for_each(|i| { - if array.is_valid(i) { - mutable - .child_data - .iter_mut() - .for_each(|child| child.extend(index, i, i + 1)) - } else { - mutable - .child_data - .iter_mut() - .for_each(|child| child.extend_nulls(1)) - } - }) - }, - ) - } + mutable + .child_data + .iter_mut() + .for_each(|child| child.extend(index, start, start + len)) + }, + ) } pub(super) fn build_extend_dense(array: &ArrayData) -> Extend { let type_ids = array.buffer::(0); let offsets = array.buffer::(1); - if array.null_count() == 0 { - Box::new( - move |mutable: &mut _MutableArrayData, - index: usize, - start: usize, - len: usize| { - // extends type_ids - mutable - .buffer1 - .extend_from_slice(&type_ids[start..start + len]); - // extends offsets - mutable - .buffer2 - .extend_from_slice(&offsets[start..start + len]); - - (start..start + len).for_each(|i| { - let type_id = type_ids[i] as usize; - let offset_start = offsets[start] as usize; - - mutable.child_data[type_id].extend( - index, - offset_start, - offset_start + 1, - ) - }) - }, - ) - } else { - Box::new( - move |mutable: &mut _MutableArrayData, - index: usize, - start: usize, - len: usize| { - // extends type_ids - mutable - .buffer1 - .extend_from_slice(&type_ids[start..start + len]); - // extends offsets - mutable - .buffer2 - .extend_from_slice(&offsets[start..start + len]); + Box::new( + move |mutable: &mut _MutableArrayData, index: usize, start: usize, len: usize| { + // extends type_ids + mutable + .buffer1 + .extend_from_slice(&type_ids[start..start + len]); - (start..start + len).for_each(|i| { - let type_id = type_ids[i] as usize; - let offset_start = offsets[start] as usize; + (start..start + len).for_each(|i| { + let type_id = type_ids[i] as usize; + let src_offset = offsets[i] as usize; + let child_data = &mut mutable.child_data[type_id]; + let dst_offset = child_data.len(); - if array.is_valid(i) { - mutable.child_data[type_id].extend( - index, - offset_start, - offset_start + 1, - ) - } else { - mutable.child_data[type_id].extend_nulls(1) - } - }) - }, - ) - } + // Extend offsets + mutable.buffer2.push(dst_offset as i32); + mutable.child_data[type_id].extend(index, src_offset, src_offset + 1) + }) + }, + ) } -pub(super) fn extend_nulls_dense(mutable: &mut _MutableArrayData, len: usize) { - let mut count: usize = 0; - let num = len / mutable.child_data.len(); - mutable - .child_data - .iter_mut() - .enumerate() - .for_each(|(idx, child)| { - let n = if count + num > len { len - count } else { num }; - count += n; - mutable - .buffer1 - .extend_from_slice(vec![idx as i8; n].as_slice()); - mutable - .buffer2 - .extend_from_slice(vec![child.len() as i32; n].as_slice()); - child.extend_nulls(n) - }) +pub(super) fn extend_nulls_dense(_mutable: &mut _MutableArrayData, _len: usize) { + panic!("cannot call extend_nulls on UnionArray as cannot infer type"); } -pub(super) fn extend_nulls_sparse(mutable: &mut _MutableArrayData, len: usize) { - mutable - .child_data - .iter_mut() - .for_each(|child| child.extend_nulls(len)) +pub(super) fn extend_nulls_sparse(_mutable: &mut _MutableArrayData, _len: usize) { + panic!("cannot call extend_nulls on UnionArray as cannot infer type"); } diff --git a/arrow/src/compute/kernels/filter.rs b/arrow/src/compute/kernels/filter.rs index df59ba63c79d..b4abcd5a441e 100644 --- a/arrow/src/compute/kernels/filter.rs +++ b/arrow/src/compute/kernels/filter.rs @@ -1670,22 +1670,53 @@ mod tests { test_filter_union_array(array); } + #[test] + fn test_filter_run_union_array_dense() { + let mut builder = UnionBuilder::new_dense(3); + builder.append::("A", 1).unwrap(); + builder.append::("A", 3).unwrap(); + builder.append::("A", 34).unwrap(); + let array = builder.build().unwrap(); + + let filter_array = BooleanArray::from(vec![true, true, false]); + let c = filter(&array, &filter_array).unwrap(); + let filtered = c.as_any().downcast_ref::().unwrap(); + + let mut builder = UnionBuilder::new_dense(3); + builder.append::("A", 1).unwrap(); + builder.append::("A", 3).unwrap(); + let expected = builder.build().unwrap(); + + assert_eq!(filtered.data(), expected.data()); + } + #[test] fn test_filter_union_array_dense_with_nulls() { let mut builder = UnionBuilder::new_dense(4); builder.append::("A", 1).unwrap(); builder.append::("B", 3.2).unwrap(); - builder.append_null().unwrap(); + builder.append_null::("B").unwrap(); builder.append::("A", 34).unwrap(); let array = builder.build().unwrap(); + let filter_array = BooleanArray::from(vec![true, true, false, false]); + let c = filter(&array, &filter_array).unwrap(); + let filtered = c.as_any().downcast_ref::().unwrap(); + + let mut builder = UnionBuilder::new_dense(2); + builder.append::("A", 1).unwrap(); + builder.append::("B", 3.2).unwrap(); + let expected_array = builder.build().unwrap(); + + compare_union_arrays(filtered, &expected_array); + let filter_array = BooleanArray::from(vec![true, false, true, false]); let c = filter(&array, &filter_array).unwrap(); let filtered = c.as_any().downcast_ref::().unwrap(); - let mut builder = UnionBuilder::new_dense(1); + let mut builder = UnionBuilder::new_dense(2); builder.append::("A", 1).unwrap(); - builder.append_null().unwrap(); + builder.append_null::("B").unwrap(); let expected_array = builder.build().unwrap(); compare_union_arrays(filtered, &expected_array); @@ -1707,7 +1738,7 @@ mod tests { let mut builder = UnionBuilder::new_sparse(4); builder.append::("A", 1).unwrap(); builder.append::("B", 3.2).unwrap(); - builder.append_null().unwrap(); + builder.append_null::("B").unwrap(); builder.append::("A", 34).unwrap(); let array = builder.build().unwrap(); @@ -1715,9 +1746,9 @@ mod tests { let c = filter(&array, &filter_array).unwrap(); let filtered = c.as_any().downcast_ref::().unwrap(); - let mut builder = UnionBuilder::new_dense(1); + let mut builder = UnionBuilder::new_sparse(2); builder.append::("A", 1).unwrap(); - builder.append_null().unwrap(); + builder.append_null::("B").unwrap(); let expected_array = builder.build().unwrap(); compare_union_arrays(filtered, &expected_array); @@ -1732,9 +1763,9 @@ mod tests { let slot1 = union1.value(i); let slot2 = union2.value(i); - assert_eq!(union1.is_null(i), union2.is_null(i)); + assert_eq!(slot1.is_null(0), slot2.is_null(0)); - if !union1.is_null(i) && !union2.is_null(i) { + if !slot1.is_null(0) && !slot2.is_null(0) { match type_id { 0 => { let slot1 = slot1.as_any().downcast_ref::().unwrap(); diff --git a/arrow/src/ipc/reader.rs b/arrow/src/ipc/reader.rs index 143fa929da7c..098d9bb353e3 100644 --- a/arrow/src/ipc/reader.rs +++ b/arrow/src/ipc/reader.rs @@ -190,11 +190,10 @@ fn create_array( let len = union_node.length() as usize; - let null_buffer: Buffer = read_buffer(&buffers[buffer_index], data); let type_ids: Buffer = - read_buffer(&buffers[buffer_index + 1], data)[..len].into(); + read_buffer(&buffers[buffer_index], data)[..len].into(); - buffer_index += 2; + buffer_index += 1; let value_offsets = match mode { UnionMode::Dense => { @@ -224,13 +223,7 @@ fn create_array( children.push((field.clone(), triple.0)); } - let array = UnionArray::try_new( - type_ids, - value_offsets, - children, - Some(null_buffer), - )?; - + let array = UnionArray::try_new(type_ids, value_offsets, children)?; Arc::new(array) } Null => { @@ -1359,7 +1352,7 @@ mod tests { fn check_union_with_builder(mut builder: UnionBuilder) { builder.append::("a", 1).unwrap(); - builder.append_null().unwrap(); + builder.append_null::("a").unwrap(); builder.append::("c", 3.0).unwrap(); builder.append::("a", 4).unwrap(); builder.append::("d", 11).unwrap(); diff --git a/arrow/src/ipc/writer.rs b/arrow/src/ipc/writer.rs index a5b35f364e72..3bb471f058bb 100644 --- a/arrow/src/ipc/writer.rs +++ b/arrow/src/ipc/writer.rs @@ -811,7 +811,11 @@ fn write_array_data( let mut offset = offset; nodes.push(ipc::FieldNode::new(num_rows as i64, null_count as i64)); // NullArray does not have any buffers, thus the null buffer is not generated - if array_data.data_type() != &DataType::Null { + // UnionArray does not have a validity buffer + if !matches!( + array_data.data_type(), + DataType::Null | DataType::Union(_, _) + ) { // write null buffer if exists let null_buffer = match array_data.null_buffer() { None => { @@ -1273,8 +1277,7 @@ mod tests { let offsets = Buffer::from_slice_ref(&[0_i32, 1, 2]); let union = - UnionArray::try_new(types, Some(offsets), vec![(dctfield, array)], None) - .unwrap(); + UnionArray::try_new(types, Some(offsets), vec![(dctfield, array)]).unwrap(); let schema = Arc::new(Schema::new(vec![Field::new( "union",