From 624efcacce6c8c2035281b077e4c98f18866e224 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 10 Feb 2024 10:39:59 -0500 Subject: [PATCH] Add string aggregate grouping fuzz test --- datafusion/core/src/datasource/memory.rs | 60 +++++- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 190 ++++++++++++++++-- .../fuzz_cases/distinct_count_string_fuzz.rs | 104 +--------- .../physical-plan/src/aggregates/mod.rs | 3 + test-utils/src/data_gen.rs | 1 + test-utils/src/lib.rs | 2 + test-utils/src/string_gen.rs | 139 +++++++++++++ 7 files changed, 381 insertions(+), 118 deletions(-) create mode 100644 test-utils/src/string_gen.rs diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index 7c61cc536860..901e74dfc218 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -29,9 +29,10 @@ use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use async_trait::async_trait; use datafusion_common::{ - not_impl_err, plan_err, Constraints, DataFusionError, SchemaExt, + not_impl_err, plan_err, Constraints, DFSchema, DataFusionError, SchemaExt, }; use datafusion_execution::TaskContext; +use parking_lot::Mutex; use tokio::sync::RwLock; use tokio::task::JoinSet; @@ -44,6 +45,7 @@ use crate::physical_plan::memory::MemoryExec; use crate::physical_plan::{common, SendableRecordBatchStream}; use crate::physical_plan::{repartition::RepartitionExec, Partitioning}; use crate::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan}; +use crate::physical_planner::create_physical_sort_expr; /// Type alias for partition data pub type PartitionData = Arc>>; @@ -58,6 +60,9 @@ pub struct MemTable { pub(crate) batches: Vec, constraints: Constraints, column_defaults: HashMap, + /// Optional pre-known sort order(s). Must be `SortExpr`s. 
+ /// inserting data into this table removes the order + pub sort_order: Arc>>>, } impl MemTable { @@ -82,6 +87,7 @@ impl MemTable { .collect::>(), constraints: Constraints::empty(), column_defaults: HashMap::new(), + sort_order: Arc::new(Mutex::new(vec![])), }) } @@ -100,6 +106,21 @@ impl MemTable { self } + /// Specify an optional pre-known sort order(s). Must be `SortExpr`s. + /// + /// If the data is not sorted by this order, DataFusion may produce + /// incorrect results. + /// + /// DataFusion may take advantage of this ordering to omit sorts + /// or use more efficient algorithms. + /// + /// Note that multiple sort orders are supported, if some are known to be + /// equivalent, + pub fn with_sort_order(self, mut sort_order: Vec>) -> Self { + std::mem::swap(self.sort_order.lock().as_mut(), &mut sort_order); + self + } + /// Create a mem table by reading from another data source pub async fn load( t: Arc, @@ -184,7 +205,7 @@ impl TableProvider for MemTable { async fn scan( &self, - _state: &SessionState, + state: &SessionState, projection: Option<&Vec>, _filters: &[Expr], _limit: Option, @@ -194,11 +215,33 @@ impl TableProvider for MemTable { let inner_vec = arc_inner_vec.read().await; partitions.push(inner_vec.clone()) } - Ok(Arc::new(MemoryExec::try_new( - &partitions, - self.schema(), - projection.cloned(), - )?)) + let mut exec = + MemoryExec::try_new(&partitions, self.schema(), projection.cloned())?; + + // add sort information if present + let sort_order = self.sort_order.lock(); + if !sort_order.is_empty() { + let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?; + + let file_sort_order = sort_order + .iter() + .map(|sort_exprs| { + sort_exprs + .iter() + .map(|expr| { + create_physical_sort_expr( + expr, + &df_schema, + state.execution_props(), + ) + }) + .collect::>>() + }) + .collect::>>()?; + exec = exec.with_sort_information(file_sort_order); + } + + Ok(Arc::new(exec)) } /// Returns an ExecutionPlan that inserts the execution results 
of a given [`ExecutionPlan`] into this [`MemTable`]. @@ -219,6 +262,9 @@ impl TableProvider for MemTable { input: Arc, overwrite: bool, ) -> Result> { + // If we are inserting into the table, any sort order may be messed up so reset it here + *self.sort_order.lock() = vec![]; + // Create a physical plan from the logical plan. // Check that the schema of the plan matches the schema of this table. if !self diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 9069dbbd5850..6b371b782cb5 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -22,21 +22,32 @@ use arrow::compute::{concat_batches, SortOptions}; use arrow::datatypes::DataType; use arrow::record_batch::RecordBatch; use arrow::util::pretty::pretty_format_batches; -use datafusion::physical_plan::aggregates::{ - AggregateExec, AggregateMode, PhysicalGroupBy, -}; +use arrow_array::cast::AsArray; +use arrow_array::types::Int64Type; +use arrow_array::Array; +use hashbrown::HashMap; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; +use tokio::task::JoinSet; +use datafusion::common::Result; +use datafusion::datasource::MemTable; +use datafusion::physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, +}; use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::{collect, displayable, ExecutionPlan}; -use datafusion::prelude::{SessionConfig, SessionContext}; +use datafusion::prelude::{DataFrame, SessionConfig, SessionContext}; +use datafusion_common::tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion}; use datafusion_physical_expr::expressions::{col, Sum}; use datafusion_physical_expr::{AggregateExpr, PhysicalSortExpr}; -use test_utils::add_empty_batches; +use datafusion_physical_plan::InputOrderMode; +use test_utils::{add_empty_batches, StringBatchGenerator}; -#[tokio::test(flavor = "multi_thread", worker_threads = 8)] -async 
fn aggregate_test() { +/// Tests that streaming aggregate and batch (non streaming) aggregate produce +/// same results +#[tokio::test(flavor = "multi_thread")] +async fn streaming_aggregate_test() { let test_cases = vec![ vec!["a"], vec!["b", "a"], @@ -50,18 +61,18 @@ async fn aggregate_test() { let n = 300; let distincts = vec![10, 20]; for distinct in distincts { - let mut handles = Vec::new(); + let mut join_set = JoinSet::new(); for i in 0..n { let test_idx = i % test_cases.len(); let group_by_columns = test_cases[test_idx].clone(); - let job = tokio::spawn(run_aggregate_test( + join_set.spawn(run_aggregate_test( make_staggered_batches::(1000, distinct, i as u64), group_by_columns, )); - handles.push(job); } - for job in handles { - job.await.unwrap(); + while let Some(join_handle) = join_set.join_next().await { + // propagate errors + join_handle.unwrap(); } } } @@ -234,3 +245,158 @@ pub(crate) fn make_staggered_batches( } add_empty_batches(batches, &mut rng) } + +/// Test group by with string/large string columns +#[tokio::test(flavor = "multi_thread")] +async fn group_by_strings() { + let mut join_set = JoinSet::new(); + for large in [true, false] { + for sorted in [true, false] { + for generator in StringBatchGenerator::interesting_cases() { + join_set.spawn(group_by_string_test(generator, sorted, large)); + } + } + } + while let Some(join_handle) = join_set.join_next().await { + // propagate errors + join_handle.unwrap(); + } +} + +/// Run GROUP BY using SQL and ensure the results are correct +/// +/// If sorted is true, the input batches will be sorted by the group by column +/// to test the streaming group by case +/// +/// if large is true, the input batches will be LargeStringArray +async fn group_by_string_test( + mut generator: StringBatchGenerator, + sorted: bool, + large: bool, +) { + let column_name = "a"; + let input = if sorted { + generator.make_sorted_input_batches(large) + } else { + generator.make_input_batches() + }; + + let expected = 
compute_counts(&input, column_name); + + let schema = input[0].schema(); + let session_config = SessionConfig::new().with_batch_size(50); + let ctx = SessionContext::new_with_config(session_config); + + let provider = MemTable::try_new(schema.clone(), vec![input]).unwrap(); + let provider = if sorted { + let sort_expr = datafusion::prelude::col("a").sort(true, true); + provider.with_sort_order(vec![vec![sort_expr]]) + } else { + provider + }; + + ctx.register_table("t", Arc::new(provider)).unwrap(); + + let df = ctx + .sql("SELECT a, COUNT(*) FROM t GROUP BY a") + .await + .unwrap(); + verify_ordered_aggregate(&df, sorted).await; + let results = df.collect().await.unwrap(); + + // verify that the results are correct + let actual = extract_result_counts(results); + assert_eq!(expected, actual); +} +async fn verify_ordered_aggregate(frame: &DataFrame, expected_sort: bool) { + struct Visitor { + expected_sort: bool, + } + let mut visitor = Visitor { expected_sort }; + + impl TreeNodeVisitor for Visitor { + type N = Arc; + fn pre_visit(&mut self, node: &Self::N) -> Result { + if let Some(exec) = node.as_any().downcast_ref::() { + if self.expected_sort { + assert!(matches!( + exec.input_order_mode(), + InputOrderMode::PartiallySorted(_) | InputOrderMode::Sorted + )); + } else { + assert!(matches!(exec.input_order_mode(), InputOrderMode::Linear)); + } + } + Ok(VisitRecursion::Continue) + } + } + + let plan = frame.clone().create_physical_plan().await.unwrap(); + plan.visit(&mut visitor).unwrap(); +} + +/// Compute the count of each distinct value in the specified column +/// +/// ```text +/// +---------------+---------------+ +/// | a | b | +/// +---------------+---------------+ +/// | 𭏷񬝜󓴻𼇪󄶛𑩁򽵐󦊟 | 󺚤𘱦𫎛񐕿 | +/// | 󂌿󶴬񰶨񺹭𿑵󖺉 | 񥼧􋽐󮋋󑤐𬿪𜋃 | +/// ``` +fn compute_counts(batches: &[RecordBatch], col: &str) -> HashMap, i64> { + let mut output = HashMap::new(); + for arr in batches + .iter() + .map(|batch| batch.column_by_name(col).unwrap()) + { + for value in to_str_vec(arr) { + 
output.entry(value).and_modify(|e| *e += 1).or_insert(1); + } + } + output +} + +fn to_str_vec(array: &ArrayRef) -> Vec> { + match array.data_type() { + DataType::Utf8 => array + .as_string::() + .iter() + .map(|x| x.map(|x| x.to_string())) + .collect(), + DataType::LargeUtf8 => array + .as_string::() + .iter() + .map(|x| x.map(|x| x.to_string())) + .collect(), + _ => panic!("unexpected type"), + } +} + +/// extracts the value of the first column and the count of the second column +/// ```text +/// +----------------+----------+ +/// | a | COUNT(*) | +/// +----------------+----------+ +/// | 񩢰񴠍 | 8 | +/// | 󇿺򷜄򩨝񜖫𑟑񣶏󣥽𹕉 | 11 | +/// ``` +fn extract_result_counts(results: Vec) -> HashMap, i64> { + let group_arrays = results.iter().map(|batch| batch.column(0)); + + let count_arrays = results + .iter() + .map(|batch| batch.column(1).as_primitive::()); + + let mut output = HashMap::new(); + for (group_arr, count_arr) in group_arrays.zip(count_arrays) { + assert_eq!(group_arr.len(), count_arr.len()); + let group_values = to_str_vec(group_arr); + for (group, count) in group_values.into_iter().zip(count_arr.iter()) { + assert!(output.get(&group).is_none()); + let count = count.unwrap(); // counts can never be null + output.insert(group, count); + } + } + output +} diff --git a/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs b/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs index 343a1756476f..64b858cebc84 100644 --- a/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs @@ -19,43 +19,22 @@ use std::sync::Arc; -use arrow::array::ArrayRef; use arrow::record_batch::RecordBatch; -use arrow_array::{Array, GenericStringArray, OffsetSizeTrait, UInt32Array}; +use arrow_array::{Array, OffsetSizeTrait}; use arrow_array::cast::AsArray; use datafusion::datasource::MemTable; -use rand::rngs::StdRng; -use rand::{thread_rng, Rng, SeedableRng}; use std::collections::HashSet; use 
tokio::task::JoinSet; use datafusion::prelude::{SessionConfig, SessionContext}; -use test_utils::stagger_batch; +use test_utils::StringBatchGenerator; #[tokio::test(flavor = "multi_thread")] async fn distinct_count_string_test() { - // max length of generated strings let mut join_set = JoinSet::new(); - let mut rng = thread_rng(); - for null_pct in [0.0, 0.01, 0.1, 0.5] { - for _ in 0..100 { - let max_len = rng.gen_range(1..50); - let num_strings = rng.gen_range(1..100); - let num_distinct_strings = if num_strings > 1 { - rng.gen_range(1..num_strings) - } else { - num_strings - }; - let generator = BatchGenerator { - max_len, - num_strings, - num_distinct_strings, - null_pct, - rng: StdRng::from_seed(rng.gen()), - }; - join_set.spawn(async move { run_distinct_count_test(generator).await }); - } + for generator in StringBatchGenerator::interesting_cases() { + join_set.spawn(async move { run_distinct_count_test(generator).await }); } while let Some(join_handle) = join_set.join_next().await { // propagate errors @@ -65,7 +44,7 @@ async fn distinct_count_string_test() { /// Run COUNT DISTINCT using SQL and compare the result to computing the /// distinct count using HashSet -async fn run_distinct_count_test(mut generator: BatchGenerator) { +async fn run_distinct_count_test(mut generator: StringBatchGenerator) { let input = generator.make_input_batches(); let schema = input[0].schema(); @@ -136,76 +115,3 @@ fn extract_i64(results: &[RecordBatch], col_idx: usize) -> usize { assert!(!array.is_null(0)); array.value(0).try_into().unwrap() } - -struct BatchGenerator { - //// The maximum length of the strings - max_len: usize, - /// the total number of strings in the output - num_strings: usize, - /// The number of distinct strings in the columns - num_distinct_strings: usize, - /// The percentage of nulls in the columns - null_pct: f64, - /// Random number generator - rng: StdRng, -} - -impl BatchGenerator { - /// Make batches of random strings with a random length columns 
"a" and "b": - /// - /// * "a" is a StringArray - /// * "b" is a LargeStringArray - fn make_input_batches(&mut self) -> Vec { - // use a random number generator to pick a random sized output - - let batch = RecordBatch::try_from_iter(vec![ - ("a", self.gen_data::()), - ("b", self.gen_data::()), - ]) - .unwrap(); - - stagger_batch(batch) - } - - /// Creates a StringArray or LargeStringArray with random strings according - /// to the parameters of the BatchGenerator - fn gen_data(&mut self) -> ArrayRef { - // table of strings from which to draw - let distinct_strings: GenericStringArray = (0..self.num_distinct_strings) - .map(|_| Some(random_string(&mut self.rng, self.max_len))) - .collect(); - - // pick num_strings randomly from the distinct string table - let indicies: UInt32Array = (0..self.num_strings) - .map(|_| { - if self.rng.gen::() < self.null_pct { - None - } else if self.num_distinct_strings > 1 { - let range = 1..(self.num_distinct_strings as u32); - Some(self.rng.gen_range(range)) - } else { - Some(0) - } - }) - .collect(); - - let options = None; - arrow::compute::take(&distinct_strings, &indicies, options).unwrap() - } -} - -/// Return a string of random characters of length 1..=max_len -fn random_string(rng: &mut StdRng, max_len: usize) -> String { - // pick characters at random (not just ascii) - match max_len { - 0 => "".to_string(), - 1 => String::from(rng.gen::()), - _ => { - let len = rng.gen_range(1..=max_len); - rng.sample_iter::(rand::distributions::Standard) - .take(len) - .map(char::from) - .collect::() - } - } -} diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 2d7a8cccc481..156362430558 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -507,6 +507,9 @@ impl AggregateExec { } true } + pub fn input_order_mode(&self) -> &InputOrderMode { + &self.input_order_mode + } } impl DisplayAs for AggregateExec { diff --git 
a/test-utils/src/data_gen.rs b/test-utils/src/data_gen.rs index f5ed8510a79e..45ad51bb44d6 100644 --- a/test-utils/src/data_gen.rs +++ b/test-utils/src/data_gen.rs @@ -46,6 +46,7 @@ impl Default for GeneratorOptions { } } +/// Creates access log like entries #[derive(Default)] struct BatchBuilder { service: StringDictionaryBuilder, diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs index 0c3668d2f8c0..777a24470232 100644 --- a/test-utils/src/lib.rs +++ b/test-utils/src/lib.rs @@ -22,8 +22,10 @@ use rand::prelude::StdRng; use rand::{Rng, SeedableRng}; mod data_gen; +mod string_gen; pub use data_gen::AccessLogGenerator; +pub use string_gen::StringBatchGenerator; pub use env_logger; diff --git a/test-utils/src/string_gen.rs b/test-utils/src/string_gen.rs new file mode 100644 index 000000000000..530fc1535387 --- /dev/null +++ b/test-utils/src/string_gen.rs @@ -0,0 +1,139 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+// +// use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, RecordBatch, UInt32Array}; +use crate::stagger_batch; +use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, UInt32Array}; +use arrow::record_batch::RecordBatch; +use rand::rngs::StdRng; +use rand::{thread_rng, Rng, SeedableRng}; + +/// Randomly generate strings +pub struct StringBatchGenerator { + /// The maximum length of the strings + pub max_len: usize, + /// The total number of strings in the output + pub num_strings: usize, + /// The number of distinct strings in the columns + pub num_distinct_strings: usize, + /// The percentage of nulls in the columns + pub null_pct: f64, + /// Random number generator + pub rng: StdRng, +} + +impl StringBatchGenerator { + /// Make batches of random strings of random length in columns "a" and "b". + /// + /// * "a" is a StringArray + /// * "b" is a LargeStringArray + pub fn make_input_batches(&mut self) -> Vec { + // use a random number generator to pick a random sized output + let batch = RecordBatch::try_from_iter(vec![ + ("a", self.gen_data::()), + ("b", self.gen_data::()), + ]) + .unwrap(); + stagger_batch(batch) + } + + /// Return batches with a single column "a" of random strings, sorted by "a" + /// + /// if large is false, the array is a StringArray + /// if large is true, the array is a LargeStringArray + pub fn make_sorted_input_batches(&mut self, large: bool) -> Vec { + let array = if large { + self.gen_data::() + } else { + self.gen_data::() + }; + + let array = arrow::compute::sort(&array, None).unwrap(); + + let batch = RecordBatch::try_from_iter(vec![("a", array)]).unwrap(); + stagger_batch(batch) + } + + /// Creates a StringArray or LargeStringArray with random strings according + /// to the parameters of the StringBatchGenerator + fn gen_data(&mut self) -> ArrayRef { + // table of strings from which to draw + let distinct_strings: GenericStringArray = (0..self.num_distinct_strings) + .map(|_| Some(random_string(&mut self.rng, 
self.max_len))) + .collect(); + + // pick num_strings randomly from the distinct string table + let indicies: UInt32Array = (0..self.num_strings) + .map(|_| { + if self.rng.gen::() < self.null_pct { + None + } else if self.num_distinct_strings > 1 { + let range = 1..(self.num_distinct_strings as u32); + Some(self.rng.gen_range(range)) + } else { + Some(0) + } + }) + .collect(); + + let options = None; + arrow::compute::take(&distinct_strings, &indicies, options).unwrap() + } + + /// Return a set of `StringBatchGenerator`s that cover a range of interesting + /// cases + pub fn interesting_cases() -> Vec { + let mut cases = vec![]; + let mut rng = thread_rng(); + for null_pct in [0.0, 0.01, 0.1, 0.5] { + for _ in 0..100 { + // max length of generated strings + let max_len = rng.gen_range(1..50); + let num_strings = rng.gen_range(1..100); + let num_distinct_strings = if num_strings > 1 { + rng.gen_range(1..num_strings) + } else { + num_strings + }; + cases.push(StringBatchGenerator { + max_len, + num_strings, + num_distinct_strings, + null_pct, + rng: StdRng::from_seed(rng.gen()), + }) + } + } + cases + } +} + +/// Return a string of random characters of length 1..=max_len (empty if max_len is 0) +fn random_string(rng: &mut StdRng, max_len: usize) -> String { + // pick characters at random (not just ascii) + match max_len { + 0 => "".to_string(), + 1 => String::from(rng.gen::()), + _ => { + let len = rng.gen_range(1..=max_len); + rng.sample_iter::(rand::distributions::Standard) + .take(len) + .map(char::from) + .collect::() + } + } +}