Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RecordBatch normalization (flattening) #6758

Merged
merged 21 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
bbd7c8b
Added set up for the example of flattening from pyarrow.
Nov 18, 2024
8abcd25
Logic for recursive normalizer with a base normalize function, based …
Nov 20, 2024
6bba7d3
Added recursive normalize function for `Schema`, and started building…
Nov 23, 2024
55eb953
Built out a bit more of the iterative normalize.
Nov 23, 2024
30d6294
Fixed normalize function for `RecordBatch`. Adjusted test case to mat…
Nov 24, 2024
0ed979d
Added tests for `Schema` normalization. Partial tests for `RecordBatch`.
Nov 25, 2024
d9d08cd
Removed stray comments.
Nov 25, 2024
d1b3260
Commenting out exclamation field.
Nov 25, 2024
a12082c
Merge remote-tracking branch 'upstream/main' into feature/record-batc…
Dec 5, 2024
7adda58
Fixed test for `RecordBatch`.
Dec 5, 2024
9c9c699
Formatting.
Dec 5, 2024
4422add
Additional documentation for `normalize` functions. Switched `Schema`…
Dec 31, 2024
d0dc5a7
Forgot to push to the columns in the else case.
Dec 31, 2024
1e40c98
Adjusted the documentation to include the parameters.
Dec 31, 2024
3c424d1
Formatting.
Dec 31, 2024
6d6b026
Edited examples to not be run as tests.
Dec 31, 2024
71380b6
Adjusted based on some of the suggestions. Simplified the matching an…
Jan 5, 2025
af7946b
Additional test cases for List and FixedSizeList in Schema.
Jan 11, 2025
e97cc9c
Additional test cases for deeply nested normalization.
Jan 20, 2025
b90e8f5
Suggestions from Jefffrey on the descriptions and stack initialization.
Jan 20, 2025
6a2e3ca
Forgot parenthesis.
Jan 20, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
283 changes: 280 additions & 3 deletions arrow-array/src/record_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
//! A two-dimensional batch of column-oriented data with a defined
//! [schema](arrow_schema::Schema).

use crate::cast::AsArray;
use crate::{new_empty_array, Array, ArrayRef, StructArray};
use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaBuilder, SchemaRef};
use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema, SchemaBuilder, SchemaRef};
use std::ops::Index;
use std::sync::Arc;

Expand Down Expand Up @@ -394,6 +395,108 @@ impl RecordBatch {
)
}

/// Flatten the struct columns of this [`RecordBatch`] into top-level columns.
///
/// Every column whose data type is [`DataType::Struct`] is replaced by its child columns,
/// and each child is named after the path leading to it, with the path components joined
/// by `separator`. Expansion is applied repeatedly to nested structs until `max_level`
/// levels have been unwrapped. Columns of any other data type are kept as a single column;
/// their output [`Field`] is rebuilt from the flattened name, the data type and the
/// nullability of the source field.
///
/// The relative order of the columns (and of the children inside each struct) is preserved.
///
/// For instance, a batch with the schema
///
/// ```text
/// "foo": StructArray<"bar": Utf8>
/// ```
///
/// normalized with the separator `"."` yields a batch with the schema
///
/// ```text
/// "foo.bar": Utf8
/// ```
///
/// # Parameters
///
/// * `separator` - text placed between the parent name and the child name.
/// * `max_level` - how many levels of nesting to unwrap. Both `None` and `Some(0)` mean
///   "no limit".
///
/// # Errors
///
/// Returns the [`ArrowError`] produced by [`RecordBatch::try_new`] if the flattened
/// schema and columns cannot be assembled into a batch.
///
/// # Example
///
/// ```
/// # use std::sync::Arc;
/// # use arrow_array::{ArrayRef, Int64Array, StringArray, StructArray, RecordBatch};
/// # use arrow_schema::{DataType, Field, Fields, Schema};
/// #
/// let animals: ArrayRef = Arc::new(StringArray::from(vec!["Parrot", ""]));
/// let n_legs: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(4)]));
///
/// let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
/// let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
///
/// let a = Arc::new(StructArray::from(vec![
///     (animals_field.clone(), Arc::new(animals.clone()) as ArrayRef),
///     (n_legs_field.clone(), Arc::new(n_legs.clone()) as ArrayRef),
/// ]));
///
/// let schema = Schema::new(vec![
///     Field::new(
///         "a",
///         DataType::Struct(Fields::from(vec![animals_field, n_legs_field])),
///         false,
///     )
/// ]);
///
/// let normalized = RecordBatch::try_new(Arc::new(schema), vec![a])
///     .expect("valid conversion")
///     .normalize(".", None)
///     .expect("valid normalization");
///
/// let expected = RecordBatch::try_from_iter_with_nullable(vec![
///     ("a.animals", animals.clone(), true),
///     ("a.n_legs", n_legs.clone(), true),
/// ])
/// .expect("valid conversion");
///
/// assert_eq!(expected, normalized);
/// ```
pub fn normalize(&self, separator: &str, max_level: Option<usize>) -> Result<Self, ArrowError> {
    // `None` and `Some(0)` are both treated as "unlimited depth".
    let depth_limit = match max_level {
        None | Some(0) => usize::MAX,
        Some(limit) => limit,
    };

    // Work list of (depth, column, name components, field). It is used as a LIFO stack,
    // so entries are seeded in reverse to make `pop` visit columns in schema order.
    let mut pending: Vec<(usize, &ArrayRef, Vec<&str>, &FieldRef)> = self
        .columns
        .iter()
        .zip(self.schema.fields())
        .rev()
        .map(|(column, field)| (0, column, vec![field.name().as_str()], field))
        .collect();

    let mut flat_fields: Vec<FieldRef> = Vec::new();
    let mut flat_columns: Vec<ArrayRef> = Vec::new();

    while let Some((depth, column, path, field)) = pending.pop() {
        match field.data_type() {
            DataType::Struct(children) if depth < depth_limit => {
                // Children are pushed in reverse as well, so they pop in declaration order.
                let nested = column
                    .as_struct()
                    .columns()
                    .iter()
                    .zip(children.iter())
                    .rev();
                pending.extend(nested.map(|(child_column, child_field)| {
                    let mut child_path = path.clone();
                    child_path.push(child_field.name().as_str());
                    (depth + 1, child_column, child_path, child_field)
                }));
            }
            data_type => {
                // Leaf (or depth limit reached): emit the column under its joined path.
                let flat_field = Field::new(
                    path.join(separator),
                    data_type.clone(),
                    field.is_nullable(),
                );
                flat_fields.push(Arc::new(flat_field));
                flat_columns.push(Arc::clone(column));
            }
        }
    }

    RecordBatch::try_new(Arc::new(Schema::new(flat_fields)), flat_columns)
}

/// Returns the number of columns in the record batch.
///
/// # Example
Expand Down Expand Up @@ -768,15 +871,14 @@ where

#[cfg(test)]
mod tests {
use std::collections::HashMap;

use super::*;
use crate::{
BooleanArray, Int32Array, Int64Array, Int8Array, ListArray, StringArray, StringViewArray,
};
use arrow_buffer::{Buffer, ToByteSlice};
use arrow_data::{ArrayData, ArrayDataBuilder};
use arrow_schema::Fields;
use std::collections::HashMap;

#[test]
fn create_record_batch() {
Expand Down Expand Up @@ -1197,6 +1299,181 @@ mod tests {
assert_ne!(batch1, batch2);
}

#[test]
fn normalize_simple() {
    // Fields and values of the three children of the struct column `a`.
    let animals_field = Arc::new(Field::new("animals", DataType::Utf8, true));
    let n_legs_field = Arc::new(Field::new("n_legs", DataType::Int64, true));
    let year_field = Arc::new(Field::new("year", DataType::Int64, true));

    let animals: ArrayRef = Arc::new(StringArray::from(vec!["Parrot", ""]));
    let n_legs: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(4)]));
    let year: ArrayRef = Arc::new(Int64Array::from(vec![None, Some(2022)]));

    // A plain top-level column that must pass through untouched.
    let month: ArrayRef = Arc::new(Int64Array::from(vec![Some(4), Some(6)]));

    let nested: ArrayRef = Arc::new(StructArray::from(vec![
        (animals_field.clone(), Arc::new(animals.clone()) as ArrayRef),
        (n_legs_field.clone(), Arc::new(n_legs.clone()) as ArrayRef),
        (year_field.clone(), Arc::new(year.clone()) as ArrayRef),
    ]));

    let schema = Arc::new(Schema::new(vec![
        Field::new(
            "a",
            DataType::Struct(Fields::from(vec![animals_field, n_legs_field, year_field])),
            false,
        ),
        Field::new("month", DataType::Int64, true),
    ]));

    let batch =
        RecordBatch::try_new(schema, vec![nested, month.clone()]).expect("valid conversion");

    let expected = RecordBatch::try_from_iter_with_nullable(vec![
        ("a.animals", animals, true),
        ("a.n_legs", n_legs, true),
        ("a.year", year, true),
        ("month", month, true),
    ])
    .expect("valid conversion");

    // `Some(0)` and `None` must have the same effect: flatten without a depth limit.
    for max_level in [Some(0), None] {
        let normalized = batch
            .normalize(".", max_level)
            .expect("valid normalization");

        assert_eq!(expected, normalized);
    }
}

#[test]
fn normalize_nested() {
    // Schema under test: "!" { "1" { a, b, c }, "2" { a, b, c } }
    let a = Arc::new(Field::new("a", DataType::Int64, true));
    let b = Arc::new(Field::new("b", DataType::Int64, false));
    let c = Arc::new(Field::new("c", DataType::Int64, true));
    let leaf_fields = Fields::from(vec![a.clone(), b.clone(), c.clone()]);

    let one = Arc::new(Field::new(
        "1",
        DataType::Struct(leaf_fields.clone()),
        false,
    ));
    let two = Arc::new(Field::new("2", DataType::Struct(leaf_fields), true));

    let exclamation = Field::new(
        "!",
        DataType::Struct(Fields::from(vec![one.clone(), two.clone()])),
        false,
    );

    // Leaf values shared by both inner structs.
    let a_values: ArrayRef = Arc::new(Int64Array::from(vec![Some(0), Some(1)]));
    let b_values: ArrayRef = Arc::new(Int64Array::from(vec![Some(2), Some(3)]));
    let c_values: ArrayRef = Arc::new(Int64Array::from(vec![None, Some(4)]));

    // Builds a fresh `{a, b, c}` struct array; used for both "1" and "2".
    let leaf_struct = || {
        Arc::new(StructArray::from(vec![
            (a.clone(), a_values.clone()),
            (b.clone(), b_values.clone()),
            (c.clone(), c_values.clone()),
        ])) as ArrayRef
    };

    let outer: ArrayRef = Arc::new(StructArray::from(vec![
        (one, leaf_struct()),
        (two, leaf_struct()),
    ]));

    let batch = RecordBatch::try_new(Arc::new(Schema::new(vec![exclamation])), vec![outer])
        .expect("valid conversion");

    // One level only: "!" is unwrapped, "1" and "2" stay as struct columns.
    let normalized = batch
        .normalize(".", Some(1))
        .expect("valid normalization");

    let expected = RecordBatch::try_from_iter_with_nullable(vec![
        ("!.1", leaf_struct(), false),
        ("!.2", leaf_struct(), true),
    ])
    .expect("valid conversion");

    assert_eq!(expected, normalized);

    // No limit: everything is unwrapped down to the Int64 leaves.
    let normalized = batch.normalize(".", None).expect("valid normalization");

    let expected = RecordBatch::try_from_iter_with_nullable(vec![
        ("!.1.a", a_values.clone(), true),
        ("!.1.b", b_values.clone(), false),
        ("!.1.c", c_values.clone(), true),
        ("!.2.a", a_values.clone(), true),
        ("!.2.b", b_values.clone(), false),
        ("!.2.c", c_values.clone(), true),
    ])
    .expect("valid conversion");

    assert_eq!(expected, normalized);
}

#[test]
fn normalize_empty() {
    // Same shape as `normalize_simple`, but the batch carries no rows.
    let nested_fields = Fields::from(vec![
        Arc::new(Field::new("animals", DataType::Utf8, true)),
        Arc::new(Field::new("n_legs", DataType::Int64, true)),
        Arc::new(Field::new("year", DataType::Int64, true)),
    ]);

    let schema = Schema::new(vec![
        Field::new("a", DataType::Struct(nested_fields), false),
        Field::new("month", DataType::Int64, true),
    ]);

    // Normalizing the schema first and then creating an empty batch...
    let flat_schema = schema
        .normalize(".", Some(0))
        .expect("valid normalization");
    let expected = RecordBatch::new_empty(Arc::new(flat_schema));

    // ...must match creating an empty batch and normalizing it.
    let normalized = RecordBatch::new_empty(Arc::new(schema))
        .normalize(".", Some(0))
        .expect("valid normalization");

    assert_eq!(expected, normalized);
}

#[test]
fn project() {
let a: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
Expand Down
Loading
Loading