Make TableProvider.scan() and PhysicalPlanner::create_physical_plan() async #1013

Merged · 5 commits · Sep 21, 2021
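
This PR makes two core planning APIs asynchronous: TableProvider::scan() and the planner's create_physical_plan(). Because Rust traits cannot yet contain native async methods, the change relies on the async-trait crate throughout; every implementor and call site gains an .await, and the affected tests move from #[test] to #[tokio::test]. A minimal sketch of the shape of the change (signatures simplified; the real scan() also takes a projection, batch size, filters, and a limit):

use std::sync::Arc;
use async_trait::async_trait;

// Stand-in type; the real method returns Result<Arc<dyn ExecutionPlan>>.
struct ExecutionPlan;

#[async_trait]
trait TableProvider: Send + Sync {
    // Before this PR the method was synchronous: fn scan(&self, ...) -> ...
    async fn scan(&self) -> Arc<ExecutionPlan>;
}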
4 changes: 3 additions & 1 deletion ballista/rust/core/src/utils.rs
@@ -30,6 +30,7 @@ use crate::memory_stream::MemoryStream;
use crate::serde::scheduler::PartitionStats;

use crate::config::BallistaConfig;
use async_trait::async_trait;
use datafusion::arrow::datatypes::Schema;
use datafusion::arrow::error::Result as ArrowResult;
use datafusion::arrow::{
@@ -269,8 +270,9 @@ impl BallistaQueryPlanner {
}
}

#[async_trait]
impl QueryPlanner for BallistaQueryPlanner {
fn create_physical_plan(
async fn create_physical_plan(
&self,
logical_plan: &LogicalPlan,
_ctx_state: &ExecutionContextState,
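
The #[async_trait] attribute added above is what lets create_physical_plan be declared async inside a trait. Roughly, the macro rewrites the async method into an ordinary method returning a boxed future, which is what makes it legal on a trait object. A simplified sketch of the expansion (argument and result types elided here, so Output is shown as ()):

use std::future::Future;
use std::pin::Pin;

trait QueryPlannerDesugared {
    // Approximately what `async fn create_physical_plan(&self, ...)` becomes:
    fn create_physical_plan<'a>(
        &'a self,
        // ...original arguments elided...
    ) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
}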
2 changes: 2 additions & 0 deletions ballista/rust/scheduler/src/lib.rs
@@ -418,6 +418,7 @@ impl SchedulerGrpc for SchedulerServer {

let plan = fail_job!(datafusion_ctx
.create_physical_plan(&optimized_plan)
.await
.map_err(|e| {
let msg = format!("Could not create physical plan: {}", e);
error!("{}", msg);
@@ -447,6 +448,7 @@ impl SchedulerGrpc for SchedulerServer {
let mut planner = DistributedPlanner::new();
let stages = fail_job!(planner
.plan_query_stages(&job_id_spawn, plan)
.await
.map_err(|e| {
let msg = format!("Could not plan query stages: {}", e);
error!("{}", msg);
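
One detail worth noting in these call sites: the .await is inserted between the async call and .map_err, because the future must be awaited first to produce the Result that map_err adapts. A minimal self-contained sketch of the same ordering (parse_port is an illustrative name, not from this codebase):

async fn parse_port(input: &str) -> Result<u16, String> {
    async { input.trim().parse::<u16>() }
        .await // first: drive the future to completion, yielding a Result
        .map_err(|e| format!("could not parse {:?}: {}", input, e)) // then adapt the error
}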
59 changes: 36 additions & 23 deletions ballista/rust/scheduler/src/planner.rs
@@ -31,6 +31,8 @@ use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::repartition::RepartitionExec;
use datafusion::physical_plan::windows::WindowAggExec;
use datafusion::physical_plan::{ExecutionPlan, Partitioning};
use futures::future::BoxFuture;
use futures::FutureExt;
use log::info;

type PartialQueryStageResult = (Arc<dyn ExecutionPlan>, Vec<Arc<ShuffleWriterExec>>);
@@ -55,14 +57,15 @@ impl DistributedPlanner {
/// Returns a vector of ExecutionPlans, where the root node is a [ShuffleWriterExec].
/// Plans that depend on the input of other plans will have leaf nodes of type [UnresolvedShuffleExec].
/// A [ShuffleWriterExec] is created whenever the partitioning changes.
pub fn plan_query_stages(
&mut self,
job_id: &str,
pub async fn plan_query_stages<'a>(
&'a mut self,
job_id: &'a str,
execution_plan: Arc<dyn ExecutionPlan>,
) -> Result<Vec<Arc<ShuffleWriterExec>>> {
info!("planning query stages");
let (new_plan, mut stages) =
self.plan_query_stages_internal(job_id, execution_plan)?;
let (new_plan, mut stages) = self
.plan_query_stages_internal(job_id, execution_plan)
.await?;
stages.push(create_shuffle_writer(
job_id,
self.next_stage_id(),
@@ -75,11 +78,12 @@ impl DistributedPlanner {
/// Returns a potentially modified version of the input execution_plan along with the resulting query stages.
/// This function is needed because the input execution_plan might need to be modified, but it might not hold a
/// complete query stage (its parent might also belong to the same stage)
fn plan_query_stages_internal(
&mut self,
job_id: &str,
fn plan_query_stages_internal<'a>(
&'a mut self,
job_id: &'a str,
execution_plan: Arc<dyn ExecutionPlan>,
) -> Result<PartialQueryStageResult> {
) -> BoxFuture<'a, Result<PartialQueryStageResult>> {
async move {
// recurse down and replace children
if execution_plan.children().is_empty() {
return Ok((execution_plan, vec![]));
@@ -88,8 +92,9 @@
let mut stages = vec![];
let mut children = vec![];
for child in execution_plan.children() {
let (new_child, mut child_stages) =
self.plan_query_stages_internal(job_id, child.clone())?;
let (new_child, mut child_stages) = self
.plan_query_stages_internal(job_id, child.clone())
.await?;
children.push(new_child);
stages.append(&mut child_stages);
}
@@ -161,6 +166,8 @@ impl DistributedPlanner {
Ok((execution_plan.with_new_children(children)?, stages))
}
}
.boxed()
}

/// Generate a new stage ID
fn next_stage_id(&mut self) -> usize {
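
Note the pattern used for plan_query_stages_internal above: the function recurses into its children, and a plain async fn cannot await itself because the compiler cannot compute a finite size for the resulting future type. The diff therefore keeps it a normal fn that wraps its body in async move { ... } and returns a heap-allocated, type-erased BoxFuture via .boxed(). A minimal self-contained sketch of the same technique:

use futures::future::BoxFuture;
use futures::FutureExt;

// Recursive async: return a boxed future instead of writing `async fn`.
fn countdown(n: u32) -> BoxFuture<'static, u32> {
    async move {
        if n == 0 {
            0
        } else {
            countdown(n - 1).await // recursion is fine behind the Box
        }
    }
    .boxed()
}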
@@ -262,8 +269,8 @@ mod test {
};
}

#[test]
fn distributed_hash_aggregate_plan() -> Result<(), BallistaError> {
#[tokio::test]
async fn distributed_hash_aggregate_plan() -> Result<(), BallistaError> {
let mut ctx = datafusion_test_context("testdata")?;

// simplified form of TPC-H query 1
@@ -276,11 +283,13 @@

let plan = df.to_logical_plan();
let plan = ctx.optimize(&plan)?;
let plan = ctx.create_physical_plan(&plan)?;
let plan = ctx.create_physical_plan(&plan).await?;

let mut planner = DistributedPlanner::new();
let job_uuid = Uuid::new_v4();
let stages = planner.plan_query_stages(&job_uuid.to_string(), plan)?;
let stages = planner
.plan_query_stages(&job_uuid.to_string(), plan)
.await?;
for stage in &stages {
println!("{}", displayable(stage.as_ref()).indent().to_string());
}
@@ -345,8 +354,8 @@ mod test {
Ok(())
}

#[test]
fn distributed_join_plan() -> Result<(), BallistaError> {
#[tokio::test]
async fn distributed_join_plan() -> Result<(), BallistaError> {
let mut ctx = datafusion_test_context("testdata")?;

// simplified form of TPC-H query 12
@@ -386,11 +395,13 @@ order by

let plan = df.to_logical_plan();
let plan = ctx.optimize(&plan)?;
let plan = ctx.create_physical_plan(&plan)?;
let plan = ctx.create_physical_plan(&plan).await?;

let mut planner = DistributedPlanner::new();
let job_uuid = Uuid::new_v4();
let stages = planner.plan_query_stages(&job_uuid.to_string(), plan)?;
let stages = planner
.plan_query_stages(&job_uuid.to_string(), plan)
.await?;
for stage in &stages {
println!("{}", displayable(stage.as_ref()).indent().to_string());
}
@@ -516,8 +527,8 @@ order by
Ok(())
}

#[test]
fn roundtrip_serde_hash_aggregate() -> Result<(), BallistaError> {
#[tokio::test]
async fn roundtrip_serde_hash_aggregate() -> Result<(), BallistaError> {
let mut ctx = datafusion_test_context("testdata")?;

// simplified form of TPC-H query 1
@@ -530,11 +541,13 @@ order by

let plan = df.to_logical_plan();
let plan = ctx.optimize(&plan)?;
let plan = ctx.create_physical_plan(&plan)?;
let plan = ctx.create_physical_plan(&plan).await?;

let mut planner = DistributedPlanner::new();
let job_uuid = Uuid::new_v4();
let stages = planner.plan_query_stages(&job_uuid.to_string(), plan)?;
let stages = planner
.plan_query_stages(&job_uuid.to_string(), plan)
.await?;

let partial_hash = stages[0].children()[0].clone();
let partial_hash_serde = roundtrip_operator(partial_hash.clone())?;
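
Since these planner tests now await create_physical_plan and plan_query_stages, they switch from #[test] to #[tokio::test], which runs the async test body on a Tokio runtime. A rough sketch of what that attribute expands to (Tokio builds a current-thread runtime by default):

#[test]
fn runs_async_body() {
    tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .unwrap()
        .block_on(async {
            // ...the async test body goes here...
            assert_eq!(1 + 1, 2);
        });
}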
2 changes: 1 addition & 1 deletion benchmarks/src/bin/nyctaxi.rs
@@ -121,7 +121,7 @@ async fn execute_sql(ctx: &mut ExecutionContext, sql: &str, debug: bool) -> Resu
if debug {
println!("Optimized logical plan:\n{:?}", plan);
}
let physical_plan = ctx.create_physical_plan(&plan)?;
let physical_plan = ctx.create_physical_plan(&plan).await?;
let result = collect(physical_plan).await?;
if debug {
pretty::print_batches(&result)?;
14 changes: 7 additions & 7 deletions benchmarks/src/bin/tpch.rs
@@ -348,7 +348,7 @@ async fn execute_query(
if debug {
println!("=== Optimized logical plan ===\n{:?}\n", plan);
}
let physical_plan = ctx.create_physical_plan(&plan)?;
let physical_plan = ctx.create_physical_plan(&plan).await?;
if debug {
println!(
"=== Physical plan ===\n{}\n",
@@ -394,7 +394,7 @@ async fn convert_tbl(opt: ConvertOpt) -> Result<()> {
// create the physical plan
let csv = csv.to_logical_plan();
let csv = ctx.optimize(&csv)?;
let csv = ctx.create_physical_plan(&csv)?;
let csv = ctx.create_physical_plan(&csv).await?;

let output_path = output_root_path.join(table);
let output_path = output_path.to_str().unwrap().to_owned();
@@ -1063,7 +1063,7 @@ mod tests {
use datafusion::physical_plan::ExecutionPlan;
use std::convert::TryInto;

fn round_trip_query(n: usize) -> Result<()> {
async fn round_trip_query(n: usize) -> Result<()> {
let config = ExecutionConfig::new()
.with_target_partitions(1)
.with_batch_size(10);
@@ -1110,7 +1110,7 @@

// test physical plan roundtrip
if env::var("TPCH_DATA").is_ok() {
let physical_plan = ctx.create_physical_plan(&plan)?;
let physical_plan = ctx.create_physical_plan(&plan).await?;
let proto: protobuf::PhysicalPlanNode =
(physical_plan.clone()).try_into().unwrap();
let round_trip: Arc<dyn ExecutionPlan> = (&proto).try_into().unwrap();
@@ -1126,9 +1126,9 @@

macro_rules! test_round_trip {
($tn:ident, $query:expr) => {
#[test]
fn $tn() -> Result<()> {
round_trip_query($query)
#[tokio::test]
async fn $tn() -> Result<()> {
round_trip_query($query).await
}
};
}
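
With the macro updated, each invocation now expands to an async Tokio test that awaits round_trip_query. Illustratively, an invocation like test_round_trip!(q1, 1) (the actual invocations sit outside this hunk) would generate:

#[tokio::test]
async fn q1() -> Result<()> {
    round_trip_query(1).await
}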
8 changes: 5 additions & 3 deletions datafusion/src/datasource/avro.rs
@@ -27,6 +27,7 @@ use std::{
};

use arrow::datatypes::SchemaRef;
use async_trait::async_trait;

use crate::physical_plan::avro::{AvroExec, AvroReadOptions};
use crate::{
@@ -120,6 +121,7 @@ impl AvroFile {
}
}

#[async_trait]
impl TableProvider for AvroFile {
fn as_any(&self) -> &dyn Any {
self
@@ -129,7 +131,7 @@ impl TableProvider for AvroFile {
self.schema.clone()
}

fn scan(
async fn scan(
&self,
projection: &Option<Vec<usize>>,
batch_size: usize,
@@ -185,7 +187,7 @@ mod tests {
async fn read_small_batches() -> Result<()> {
let table = load_table("alltypes_plain.avro")?;
let projection = None;
let exec = table.scan(&projection, 2, &[], None)?;
let exec = table.scan(&projection, 2, &[], None).await?;
let stream = exec.execute(0).await?;

let _ = stream
@@ -414,7 +416,7 @@ mod tests {
table: Arc<dyn TableProvider>,
projection: &Option<Vec<usize>>,
) -> Result<RecordBatch> {
let exec = table.scan(projection, 1024, &[], None)?;
let exec = table.scan(projection, 1024, &[], None).await?;
let mut it = exec.execute(0).await?;
it.next()
.await
4 changes: 3 additions & 1 deletion datafusion/src/datasource/csv.rs
@@ -34,6 +34,7 @@
//! ```

use arrow::datatypes::SchemaRef;
use async_trait::async_trait;
use std::any::Any;
use std::io::{Read, Seek};
use std::string::String;
@@ -157,6 +158,7 @@ impl CsvFile {
}
}

#[async_trait]
impl TableProvider for CsvFile {
fn as_any(&self) -> &dyn Any {
self
@@ -166,7 +168,7 @@ impl TableProvider for CsvFile {
self.schema.clone()
}

fn scan(
async fn scan(
&self,
projection: &Option<Vec<usize>>,
batch_size: usize,
5 changes: 4 additions & 1 deletion datafusion/src/datasource/datasource.rs
@@ -20,6 +20,8 @@
use std::any::Any;
use std::sync::Arc;

use async_trait::async_trait;

use crate::arrow::datatypes::SchemaRef;
use crate::error::Result;
use crate::logical_plan::Expr;
@@ -54,6 +56,7 @@ pub enum TableType {
}

/// Source table
#[async_trait]
pub trait TableProvider: Sync + Send {
/// Returns the table provider as [`Any`](std::any::Any) so that it can be
/// downcast to a specific implementation.
@@ -68,7 +71,7 @@ pub trait TableProvider: Sync + Send {
}

/// Create an ExecutionPlan that will scan the table.
fn scan(
async fn scan(
&self,
projection: &Option<Vec<usize>>,
batch_size: usize,
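
This is the core trait change: scan() can now await while constructing the ExecutionPlan, which opens the door to providers that do real I/O during planning (for example, listing files or fetching metadata from a remote store) rather than eagerly at registration time. A hedged sketch of a call site against the new signature, mirroring the updated avro tests above (first_batch is an illustrative helper, not part of DataFusion):

use std::sync::Arc;
use datafusion::arrow::record_batch::RecordBatch;
use datafusion::datasource::TableProvider;
use datafusion::error::Result;
use futures::StreamExt;

// Fetch the first batch of the first partition of any table provider.
async fn first_batch(table: Arc<dyn TableProvider>) -> Result<Option<RecordBatch>> {
    // No projection, batch size 1024, no filters, no limit.
    let exec = table.scan(&None, 1024, &[], None).await?; // scan() is now async
    let mut stream = exec.execute(0).await?;
    // The stream yields ArrowResult<RecordBatch>; convert the error type.
    stream.next().await.transpose().map_err(Into::into)
}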
4 changes: 3 additions & 1 deletion datafusion/src/datasource/empty.rs
@@ -21,6 +21,7 @@ use std::any::Any;
use std::sync::Arc;

use arrow::datatypes::*;
use async_trait::async_trait;

use crate::datasource::TableProvider;
use crate::error::Result;
@@ -39,6 +40,7 @@ impl EmptyTable {
}
}

#[async_trait]
impl TableProvider for EmptyTable {
fn as_any(&self) -> &dyn Any {
self
@@ -48,7 +50,7 @@
self.schema.clone()
}

fn scan(
async fn scan(
&self,
projection: &Option<Vec<usize>>,
_batch_size: usize,
5 changes: 4 additions & 1 deletion datafusion/src/datasource/json.rs
@@ -36,6 +36,7 @@ use crate::{
},
};
use arrow::{datatypes::SchemaRef, json::reader::infer_json_schema_from_seekable};
use async_trait::async_trait;

trait SeekRead: Read + Seek {}

@@ -101,6 +102,8 @@ impl NdJsonFile {
})
}
}

#[async_trait]
impl TableProvider for NdJsonFile {
fn as_any(&self) -> &dyn Any {
self
@@ -110,7 +113,7 @@
self.schema.clone()
}

fn scan(
async fn scan(
&self,
projection: &Option<Vec<usize>>,
batch_size: usize,