-
Notifications
You must be signed in to change notification settings - Fork 1.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add string aggregate grouping fuzz test, add MemTable::with_sort_exprs
#9190
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,21 +22,32 @@ use arrow::compute::{concat_batches, SortOptions}; | |
use arrow::datatypes::DataType; | ||
use arrow::record_batch::RecordBatch; | ||
use arrow::util::pretty::pretty_format_batches; | ||
use datafusion::physical_plan::aggregates::{ | ||
AggregateExec, AggregateMode, PhysicalGroupBy, | ||
}; | ||
use arrow_array::cast::AsArray; | ||
use arrow_array::types::Int64Type; | ||
use arrow_array::Array; | ||
use hashbrown::HashMap; | ||
use rand::rngs::StdRng; | ||
use rand::{Rng, SeedableRng}; | ||
use tokio::task::JoinSet; | ||
|
||
use datafusion::common::Result; | ||
use datafusion::datasource::MemTable; | ||
use datafusion::physical_plan::aggregates::{ | ||
AggregateExec, AggregateMode, PhysicalGroupBy, | ||
}; | ||
use datafusion::physical_plan::memory::MemoryExec; | ||
use datafusion::physical_plan::{collect, displayable, ExecutionPlan}; | ||
use datafusion::prelude::{SessionConfig, SessionContext}; | ||
use datafusion::prelude::{DataFrame, SessionConfig, SessionContext}; | ||
use datafusion_common::tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion}; | ||
use datafusion_physical_expr::expressions::{col, Sum}; | ||
use datafusion_physical_expr::{AggregateExpr, PhysicalSortExpr}; | ||
use test_utils::add_empty_batches; | ||
use datafusion_physical_plan::InputOrderMode; | ||
use test_utils::{add_empty_batches, StringBatchGenerator}; | ||
|
||
#[tokio::test(flavor = "multi_thread", worker_threads = 8)] | ||
async fn aggregate_test() { | ||
/// Tests that streaming aggregate and batch (non streaming) aggregate produce | ||
/// same results | ||
#[tokio::test(flavor = "multi_thread")] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There is no reason to limit this to 8 threads that I know of. Using the multi-thread flavor uses more cores if available. |
||
async fn streaming_aggregate_test() { | ||
let test_cases = vec![ | ||
vec!["a"], | ||
vec!["b", "a"], | ||
|
@@ -50,18 +61,18 @@ async fn aggregate_test() { | |
let n = 300; | ||
let distincts = vec![10, 20]; | ||
for distinct in distincts { | ||
let mut handles = Vec::new(); | ||
let mut join_set = JoinSet::new(); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Using JoinSet for consistency -- it automatically cancels outstanding tasks on panic. |
||
for i in 0..n { | ||
let test_idx = i % test_cases.len(); | ||
let group_by_columns = test_cases[test_idx].clone(); | ||
let job = tokio::spawn(run_aggregate_test( | ||
join_set.spawn(run_aggregate_test( | ||
make_staggered_batches::<true>(1000, distinct, i as u64), | ||
group_by_columns, | ||
)); | ||
handles.push(job); | ||
} | ||
for job in handles { | ||
job.await.unwrap(); | ||
while let Some(join_handle) = join_set.join_next().await { | ||
// propagate errors | ||
join_handle.unwrap(); | ||
} | ||
} | ||
} | ||
|
@@ -234,3 +245,158 @@ pub(crate) fn make_staggered_batches<const STREAM: bool>( | |
} | ||
add_empty_batches(batches, &mut rng) | ||
} | ||
|
||
/// Test group by with string/large string columns | ||
#[tokio::test(flavor = "multi_thread")] | ||
async fn group_by_strings() { | ||
let mut join_set = JoinSet::new(); | ||
for large in [true, false] { | ||
for sorted in [true, false] { | ||
for generator in StringBatchGenerator::interesting_cases() { | ||
join_set.spawn(group_by_string_test(generator, sorted, large)); | ||
} | ||
} | ||
} | ||
while let Some(join_handle) = join_set.join_next().await { | ||
// propagate errors | ||
join_handle.unwrap(); | ||
} | ||
} | ||
|
||
/// Run GROUP BY <x> using SQL and ensure the results are correct | ||
/// | ||
/// If sorted is true, the input batches will be sorted by the group by column | ||
/// to test the streaming group by case | ||
/// | ||
/// if large is true, the input batches will be LargeStringArray | ||
async fn group_by_string_test( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. here is the new test |
||
mut generator: StringBatchGenerator, | ||
sorted: bool, | ||
large: bool, | ||
) { | ||
let column_name = "a"; | ||
let input = if sorted { | ||
generator.make_sorted_input_batches(large) | ||
} else { | ||
generator.make_input_batches() | ||
}; | ||
|
||
let expected = compute_counts(&input, column_name); | ||
|
||
let schema = input[0].schema(); | ||
let session_config = SessionConfig::new().with_batch_size(50); | ||
let ctx = SessionContext::new_with_config(session_config); | ||
|
||
let provider = MemTable::try_new(schema.clone(), vec![input]).unwrap(); | ||
let provider = if sorted { | ||
let sort_expr = datafusion::prelude::col("a").sort(true, true); | ||
provider.with_sort_order(vec![vec![sort_expr]]) | ||
} else { | ||
provider | ||
}; | ||
|
||
ctx.register_table("t", Arc::new(provider)).unwrap(); | ||
|
||
let df = ctx | ||
.sql("SELECT a, COUNT(*) FROM t GROUP BY a") | ||
.await | ||
.unwrap(); | ||
verify_ordered_aggregate(&df, sorted).await; | ||
let results = df.collect().await.unwrap(); | ||
|
||
// verify that the results are correct | ||
let actual = extract_result_counts(results); | ||
assert_eq!(expected, actual); | ||
} | ||
async fn verify_ordered_aggregate(frame: &DataFrame, expected_sort: bool) { | ||
struct Visitor { | ||
expected_sort: bool, | ||
} | ||
let mut visitor = Visitor { expected_sort }; | ||
|
||
impl TreeNodeVisitor for Visitor { | ||
type N = Arc<dyn ExecutionPlan>; | ||
fn pre_visit(&mut self, node: &Self::N) -> Result<VisitRecursion> { | ||
if let Some(exec) = node.as_any().downcast_ref::<AggregateExec>() { | ||
if self.expected_sort { | ||
assert!(matches!( | ||
exec.input_order_mode(), | ||
InputOrderMode::PartiallySorted(_) | InputOrderMode::Sorted | ||
)); | ||
} else { | ||
assert!(matches!(exec.input_order_mode(), InputOrderMode::Linear)); | ||
} | ||
} | ||
Ok(VisitRecursion::Continue) | ||
} | ||
} | ||
|
||
let plan = frame.clone().create_physical_plan().await.unwrap(); | ||
plan.visit(&mut visitor).unwrap(); | ||
} | ||
|
||
/// Compute the count of each distinct value in the specified column | ||
/// | ||
/// ```text | ||
/// +---------------+---------------+ | ||
/// | a | b | | ||
/// +---------------+---------------+ | ||
/// | 𭏷𑩁 | 𘱦𫎛 | | ||
/// | | 𬿪 | | ||
/// ``` | ||
fn compute_counts(batches: &[RecordBatch], col: &str) -> HashMap<Option<String>, i64> { | ||
let mut output = HashMap::new(); | ||
for arr in batches | ||
.iter() | ||
.map(|batch| batch.column_by_name(col).unwrap()) | ||
{ | ||
for value in to_str_vec(arr) { | ||
output.entry(value).and_modify(|e| *e += 1).or_insert(1); | ||
} | ||
} | ||
output | ||
} | ||
|
||
fn to_str_vec(array: &ArrayRef) -> Vec<Option<String>> { | ||
match array.data_type() { | ||
DataType::Utf8 => array | ||
.as_string::<i32>() | ||
.iter() | ||
.map(|x| x.map(|x| x.to_string())) | ||
.collect(), | ||
DataType::LargeUtf8 => array | ||
.as_string::<i64>() | ||
.iter() | ||
.map(|x| x.map(|x| x.to_string())) | ||
.collect(), | ||
_ => panic!("unexpected type"), | ||
} | ||
} | ||
|
||
/// extracts the value of the first column and the count of the second column | ||
/// ```text | ||
/// +----------------+----------+ | ||
/// | a | COUNT(*) | | ||
/// +----------------+----------+ | ||
/// | | 8 | | ||
/// | | 11 | | ||
/// ``` | ||
fn extract_result_counts(results: Vec<RecordBatch>) -> HashMap<Option<String>, i64> { | ||
let group_arrays = results.iter().map(|batch| batch.column(0)); | ||
|
||
let count_arrays = results | ||
.iter() | ||
.map(|batch| batch.column(1).as_primitive::<Int64Type>()); | ||
|
||
let mut output = HashMap::new(); | ||
for (group_arr, count_arr) in group_arrays.zip(count_arrays) { | ||
assert_eq!(group_arr.len(), count_arr.len()); | ||
let group_values = to_str_vec(group_arr); | ||
for (group, count) in group_values.into_iter().zip(count_arr.iter()) { | ||
assert!(output.get(&group).is_none()); | ||
let count = count.unwrap(); // counts can never be null | ||
output.insert(group, count); | ||
} | ||
} | ||
output | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I needed this feature to write a test that ran with sorted aggregates