Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add string aggregate grouping fuzz test, add MemTable::with_sort_exprs #9190

Merged
merged 1 commit into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 53 additions & 7 deletions datafusion/core/src/datasource/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ use arrow::datatypes::SchemaRef;
use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use datafusion_common::{
not_impl_err, plan_err, Constraints, DataFusionError, SchemaExt,
not_impl_err, plan_err, Constraints, DFSchema, DataFusionError, SchemaExt,
};
use datafusion_execution::TaskContext;
use parking_lot::Mutex;
use tokio::sync::RwLock;
use tokio::task::JoinSet;

Expand All @@ -44,6 +45,7 @@ use crate::physical_plan::memory::MemoryExec;
use crate::physical_plan::{common, SendableRecordBatchStream};
use crate::physical_plan::{repartition::RepartitionExec, Partitioning};
use crate::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan};
use crate::physical_planner::create_physical_sort_expr;

/// Type alias for partition data
pub type PartitionData = Arc<RwLock<Vec<RecordBatch>>>;
Expand All @@ -58,6 +60,9 @@ pub struct MemTable {
pub(crate) batches: Vec<PartitionData>,
constraints: Constraints,
column_defaults: HashMap<String, Expr>,
/// Optional pre-known sort order(s). Must be `SortExpr`s.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I needed this feature to write a test that ran with sorted aggregates

/// inserting data into this table removes the order
pub sort_order: Arc<Mutex<Vec<Vec<Expr>>>>,
}

impl MemTable {
Expand All @@ -82,6 +87,7 @@ impl MemTable {
.collect::<Vec<_>>(),
constraints: Constraints::empty(),
column_defaults: HashMap::new(),
sort_order: Arc::new(Mutex::new(vec![])),
})
}

Expand All @@ -100,6 +106,21 @@ impl MemTable {
self
}

/// Specify an optional pre-known sort order(s). Must be `SortExpr`s.
///
/// If the data is not sorted by this order, DataFusion may produce
/// incorrect results.
///
/// DataFusion may take advantage of this ordering to omit sorts
/// or use more efficient algorithms.
///
/// Note that multiple sort orders are supported, if some are known to be
/// equivalent.
pub fn with_sort_order(self, sort_order: Vec<Vec<Expr>>) -> Self {
    // Replace any previously configured ordering(s) with the new ones.
    // A plain assignment through the lock is clearer than mem::swap and
    // avoids the `mut` parameter.
    *self.sort_order.lock() = sort_order;
    self
}

/// Create a mem table by reading from another data source
pub async fn load(
t: Arc<dyn TableProvider>,
Expand Down Expand Up @@ -184,7 +205,7 @@ impl TableProvider for MemTable {

async fn scan(
&self,
_state: &SessionState,
state: &SessionState,
projection: Option<&Vec<usize>>,
_filters: &[Expr],
_limit: Option<usize>,
Expand All @@ -194,11 +215,33 @@ impl TableProvider for MemTable {
let inner_vec = arc_inner_vec.read().await;
partitions.push(inner_vec.clone())
}
Ok(Arc::new(MemoryExec::try_new(
&partitions,
self.schema(),
projection.cloned(),
)?))
let mut exec =
MemoryExec::try_new(&partitions, self.schema(), projection.cloned())?;

// add sort information if present
let sort_order = self.sort_order.lock();
if !sort_order.is_empty() {
let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?;

let file_sort_order = sort_order
.iter()
.map(|sort_exprs| {
sort_exprs
.iter()
.map(|expr| {
create_physical_sort_expr(
expr,
&df_schema,
state.execution_props(),
)
})
.collect::<Result<Vec<_>>>()
})
.collect::<Result<Vec<_>>>()?;
exec = exec.with_sort_information(file_sort_order);
}

Ok(Arc::new(exec))
}

/// Returns an ExecutionPlan that inserts the execution results of a given [`ExecutionPlan`] into this [`MemTable`].
Expand All @@ -219,6 +262,9 @@ impl TableProvider for MemTable {
input: Arc<dyn ExecutionPlan>,
overwrite: bool,
) -> Result<Arc<dyn ExecutionPlan>> {
// If we are inserting into the table, any sort order may be messed up so reset it here
*self.sort_order.lock() = vec![];

// Create a physical plan from the logical plan.
// Check that the schema of the plan matches the schema of this table.
if !self
Expand Down
190 changes: 178 additions & 12 deletions datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,32 @@ use arrow::compute::{concat_batches, SortOptions};
use arrow::datatypes::DataType;
use arrow::record_batch::RecordBatch;
use arrow::util::pretty::pretty_format_batches;
use datafusion::physical_plan::aggregates::{
AggregateExec, AggregateMode, PhysicalGroupBy,
};
use arrow_array::cast::AsArray;
use arrow_array::types::Int64Type;
use arrow_array::Array;
use hashbrown::HashMap;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use tokio::task::JoinSet;

use datafusion::common::Result;
use datafusion::datasource::MemTable;
use datafusion::physical_plan::aggregates::{
AggregateExec, AggregateMode, PhysicalGroupBy,
};
use datafusion::physical_plan::memory::MemoryExec;
use datafusion::physical_plan::{collect, displayable, ExecutionPlan};
use datafusion::prelude::{SessionConfig, SessionContext};
use datafusion::prelude::{DataFrame, SessionConfig, SessionContext};
use datafusion_common::tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion};
use datafusion_physical_expr::expressions::{col, Sum};
use datafusion_physical_expr::{AggregateExpr, PhysicalSortExpr};
use test_utils::add_empty_batches;
use datafusion_physical_plan::InputOrderMode;
use test_utils::{add_empty_batches, StringBatchGenerator};

#[tokio::test(flavor = "multi_thread", worker_threads = 8)]
async fn aggregate_test() {
/// Tests that streaming aggregate and batch (non streaming) aggregate produce
/// same results
#[tokio::test(flavor = "multi_thread")]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no reason I know of to limit this to 8 worker threads. Using the `multi_thread` flavor without a fixed worker count lets the runtime use more cores when available.

async fn streaming_aggregate_test() {
let test_cases = vec![
vec!["a"],
vec!["b", "a"],
Expand All @@ -50,18 +61,18 @@ async fn aggregate_test() {
let n = 300;
let distincts = vec![10, 20];
for distinct in distincts {
let mut handles = Vec::new();
let mut join_set = JoinSet::new();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using `JoinSet` for consistency; it also automatically aborts any outstanding tasks when the set is dropped (e.g. on panic).

for i in 0..n {
let test_idx = i % test_cases.len();
let group_by_columns = test_cases[test_idx].clone();
let job = tokio::spawn(run_aggregate_test(
join_set.spawn(run_aggregate_test(
make_staggered_batches::<true>(1000, distinct, i as u64),
group_by_columns,
));
handles.push(job);
}
for job in handles {
job.await.unwrap();
while let Some(join_handle) = join_set.join_next().await {
// propagate errors
join_handle.unwrap();
}
}
}
Expand Down Expand Up @@ -234,3 +245,158 @@ pub(crate) fn make_staggered_batches<const STREAM: bool>(
}
add_empty_batches(batches, &mut rng)
}

/// Test group by with string/large string columns
#[tokio::test(flavor = "multi_thread")]
async fn group_by_strings() {
    let mut tasks = JoinSet::new();
    // Exercise every combination of (LargeUtf8 vs Utf8) x (pre-sorted vs
    // unsorted input) for each interesting generator configuration.
    for use_large in [true, false] {
        for presorted in [true, false] {
            for case in StringBatchGenerator::interesting_cases() {
                tasks.spawn(group_by_string_test(case, presorted, use_large));
            }
        }
    }
    // Drain the set, propagating any panic from a worker task.
    while let Some(result) = tasks.join_next().await {
        result.unwrap();
    }
}

/// Run GROUP BY <x> using SQL and ensure the results are correct
///
/// If sorted is true, the input batches will be sorted by the group by column
/// to test the streaming group by case
///
/// if large is true, the input batches will be LargeStringArray
async fn group_by_string_test(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here is the new test

mut generator: StringBatchGenerator,
sorted: bool,
large: bool,
) {
let column_name = "a";
let input = if sorted {
generator.make_sorted_input_batches(large)
} else {
generator.make_input_batches()
};

let expected = compute_counts(&input, column_name);

let schema = input[0].schema();
let session_config = SessionConfig::new().with_batch_size(50);
let ctx = SessionContext::new_with_config(session_config);

let provider = MemTable::try_new(schema.clone(), vec![input]).unwrap();
let provider = if sorted {
let sort_expr = datafusion::prelude::col("a").sort(true, true);
provider.with_sort_order(vec![vec![sort_expr]])
} else {
provider
};

ctx.register_table("t", Arc::new(provider)).unwrap();

let df = ctx
.sql("SELECT a, COUNT(*) FROM t GROUP BY a")
.await
.unwrap();
verify_ordered_aggregate(&df, sorted).await;
let results = df.collect().await.unwrap();

// verify that the results are correct
let actual = extract_result_counts(results);
assert_eq!(expected, actual);
}
/// Walks the physical plan of `frame` and asserts that every `AggregateExec`
/// uses a sorted input mode when `expected_sort` is true, and `Linear`
/// (unsorted) mode otherwise.
async fn verify_ordered_aggregate(frame: &DataFrame, expected_sort: bool) {
    struct OrderModeChecker {
        expect_sorted: bool,
    }

    impl TreeNodeVisitor for OrderModeChecker {
        type N = Arc<dyn ExecutionPlan>;
        fn pre_visit(&mut self, plan: &Self::N) -> Result<VisitRecursion> {
            if let Some(agg) = plan.as_any().downcast_ref::<AggregateExec>() {
                let mode = agg.input_order_mode();
                if self.expect_sorted {
                    assert!(matches!(
                        mode,
                        InputOrderMode::PartiallySorted(_) | InputOrderMode::Sorted
                    ));
                } else {
                    assert!(matches!(mode, InputOrderMode::Linear));
                }
            }
            Ok(VisitRecursion::Continue)
        }
    }

    let plan = frame.clone().create_physical_plan().await.unwrap();
    let mut checker = OrderModeChecker {
        expect_sorted: expected_sort,
    };
    plan.visit(&mut checker).unwrap();
}

/// Compute the count of each distinct value in the specified column.
///
/// Returns a map from column value (`None` for nulls) to the number of rows
/// containing that value, accumulated across all input batches.
fn compute_counts(batches: &[RecordBatch], col: &str) -> HashMap<Option<String>, i64> {
    let mut counts = HashMap::new();
    for batch in batches {
        let array = batch.column_by_name(col).unwrap();
        for value in to_str_vec(array) {
            // single-lookup increment via the entry API
            *counts.entry(value).or_insert(0) += 1;
        }
    }
    counts
}

/// Convert a string array (`Utf8` or `LargeUtf8`) into a vector of optional
/// owned strings, preserving nulls as `None`. Panics on any other data type.
fn to_str_vec(array: &ArrayRef) -> Vec<Option<String>> {
    match array.data_type() {
        DataType::Utf8 => {
            let strings = array.as_string::<i32>();
            strings.iter().map(|v| v.map(String::from)).collect()
        }
        DataType::LargeUtf8 => {
            let strings = array.as_string::<i64>();
            strings.iter().map(|v| v.map(String::from)).collect()
        }
        _ => panic!("unexpected type"),
    }
}

/// Extracts the value of the first column and the count of the second column
/// from query results of the form `SELECT a, COUNT(*) ... GROUP BY a`.
///
/// Returns a map from group value (`None` for a null group) to its count.
/// Panics if any group value appears in more than one output row.
fn extract_result_counts(results: Vec<RecordBatch>) -> HashMap<Option<String>, i64> {
    let group_arrays = results.iter().map(|batch| batch.column(0));

    let count_arrays = results
        .iter()
        .map(|batch| batch.column(1).as_primitive::<Int64Type>());

    let mut output = HashMap::new();
    for (group_arr, count_arr) in group_arrays.zip(count_arrays) {
        assert_eq!(group_arr.len(), count_arr.len());
        let group_values = to_str_vec(group_arr);
        for (group, count) in group_values.into_iter().zip(count_arr.iter()) {
            let count = count.unwrap(); // counts can never be null
            // use insert's return value instead of a separate `get` +
            // `insert` (one lookup instead of two, and clearer intent)
            let previous = output.insert(group, count);
            assert!(previous.is_none(), "duplicate group value in results");
        }
    }
    output
}
Loading
Loading