Pyo3 refactorings (#740)
* let pyo3 convert the StorageContexts argument in PySessionContext::register_object_store

* clean PySessionContext methods from_pylist and from_pydict

* clean PySessionContext methods from_polars, from_pandas, from_arrow_table

* prefer bound Python token over Python::with_gil

When available, using an already bound Python token is zero-cost.

Python::with_gil carries a runtime check.

Ref: PyO3/pyo3#4274
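
To make the tradeoff concrete, here is a minimal sketch of the pattern (not taken from this diff; `table_from_pylist` is a hypothetical free function whose pyarrow calls mirror the ones in src/context.rs, assuming the PyO3 0.21 API):

```rust
use pyo3::prelude::*;
use pyo3::types::{PyList, PyTuple};

// Hypothetical helper illustrating the refactor: typing the argument as
// Bound<'_, PyList> lets PyO3 do the extraction at the call boundary,
// and the GIL token is then recovered from the bound value itself.
fn table_from_pylist<'py>(data: Bound<'py, PyList>) -> PyResult<Bound<'py, PyAny>> {
    // Zero-cost: the token is carried by the Bound value;
    // Python::with_gil would instead perform a runtime check.
    let py = data.py();
    let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
    let args = PyTuple::new_bound(py, &[data]);
    table_class.call_method1("from_pylist", args)
}
```

The same reasoning drives the register_object_store change below: declaring the parameter as `StorageContexts` moves the `StorageContexts::extract_bound` conversion into PyO3's argument handling, so the method body reduces to a match over the variants.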
Michael-J-Ward authored Jun 26, 2024
1 parent faa26b2 commit 32d6975
Showing 3 changed files with 89 additions and 116 deletions.
146 changes: 65 additions & 81 deletions src/context.rs
@@ -60,7 +60,7 @@ use datafusion::prelude::{
     AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions,
 };
 use datafusion_common::ScalarValue;
-use pyo3::types::PyTuple;
+use pyo3::types::{PyDict, PyList, PyTuple};
 use tokio::task::JoinHandle;

 /// Configuration options for a SessionContext
@@ -291,24 +291,17 @@ impl PySessionContext {
     pub fn register_object_store(
         &mut self,
         scheme: &str,
-        store: &Bound<'_, PyAny>,
+        store: StorageContexts,
         host: Option<&str>,
     ) -> PyResult<()> {
-        let res: Result<(Arc<dyn ObjectStore>, String), PyErr> =
-            match StorageContexts::extract_bound(store) {
-                Ok(store) => match store {
-                    StorageContexts::AmazonS3(s3) => Ok((s3.inner, s3.bucket_name)),
-                    StorageContexts::GoogleCloudStorage(gcs) => Ok((gcs.inner, gcs.bucket_name)),
-                    StorageContexts::MicrosoftAzure(azure) => {
-                        Ok((azure.inner, azure.container_name))
-                    }
-                    StorageContexts::LocalFileSystem(local) => Ok((local.inner, "".to_string())),
-                },
-                Err(_e) => Err(PyValueError::new_err("Invalid object store")),
-            };

         // for most stores the "host" is the bucket name and can be inferred from the store
-        let (store, upstream_host) = res?;
+        let (store, upstream_host): (Arc<dyn ObjectStore>, String) = match store {
+            StorageContexts::AmazonS3(s3) => (s3.inner, s3.bucket_name),
+            StorageContexts::GoogleCloudStorage(gcs) => (gcs.inner, gcs.bucket_name),
+            StorageContexts::MicrosoftAzure(azure) => (azure.inner, azure.container_name),
+            StorageContexts::LocalFileSystem(local) => (local.inner, "".to_string()),
+        };

         // let users override the host to match the api signature from upstream
         let derived_host = if let Some(host) = host {
             host
@@ -434,105 +427,96 @@ impl PySessionContext {
     }

     /// Construct datafusion dataframe from Python list
     #[allow(clippy::wrong_self_convention)]
     pub fn from_pylist(
         &mut self,
-        data: PyObject,
+        data: Bound<'_, PyList>,
         name: Option<&str>,
-        _py: Python,
     ) -> PyResult<PyDataFrame> {
-        Python::with_gil(|py| {
-            // Instantiate pyarrow Table object & convert to Arrow Table
-            let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
-            let args = PyTuple::new_bound(py, &[data]);
-            let table = table_class.call_method1("from_pylist", args)?.into();
-
-            // Convert Arrow Table to datafusion DataFrame
-            let df = self.from_arrow_table(table, name, py)?;
-            Ok(df)
-        })
+        // Acquire GIL Token
+        let py = data.py();
+
+        // Instantiate pyarrow Table object & convert to Arrow Table
+        let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
+        let args = PyTuple::new_bound(py, &[data]);
+        let table = table_class.call_method1("from_pylist", args)?;
+
+        // Convert Arrow Table to datafusion DataFrame
+        let df = self.from_arrow_table(table, name, py)?;
+        Ok(df)
     }

     /// Construct datafusion dataframe from Python dictionary
     #[allow(clippy::wrong_self_convention)]
     pub fn from_pydict(
         &mut self,
-        data: PyObject,
+        data: Bound<'_, PyDict>,
         name: Option<&str>,
-        _py: Python,
     ) -> PyResult<PyDataFrame> {
-        Python::with_gil(|py| {
-            // Instantiate pyarrow Table object & convert to Arrow Table
-            let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
-            let args = PyTuple::new_bound(py, &[data]);
-            let table = table_class.call_method1("from_pydict", args)?.into();
-
-            // Convert Arrow Table to datafusion DataFrame
-            let df = self.from_arrow_table(table, name, py)?;
-            Ok(df)
-        })
+        // Acquire GIL Token
+        let py = data.py();
+
+        // Instantiate pyarrow Table object & convert to Arrow Table
+        let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
+        let args = PyTuple::new_bound(py, &[data]);
+        let table = table_class.call_method1("from_pydict", args)?;
+
+        // Convert Arrow Table to datafusion DataFrame
+        let df = self.from_arrow_table(table, name, py)?;
+        Ok(df)
     }

     /// Construct datafusion dataframe from Arrow Table
     #[allow(clippy::wrong_self_convention)]
     pub fn from_arrow_table(
         &mut self,
-        data: PyObject,
+        data: Bound<'_, PyAny>,
         name: Option<&str>,
-        _py: Python,
+        py: Python,
     ) -> PyResult<PyDataFrame> {
-        Python::with_gil(|py| {
-            // Instantiate pyarrow Table object & convert to batches
-            let table = data.call_method0(py, "to_batches")?;
-
-            let schema = data.getattr(py, "schema")?;
-            let schema = schema.extract::<PyArrowType<Schema>>(py)?;
-
-            // Cast PyObject to RecordBatch type
-            // Because create_dataframe() expects a vector of vectors of record batches
-            // here we need to wrap the vector of record batches in an additional vector
-            let batches = table.extract::<PyArrowType<Vec<RecordBatch>>>(py)?;
-            let list_of_batches = PyArrowType::from(vec![batches.0]);
-            self.create_dataframe(list_of_batches, name, Some(schema), py)
-        })
+        // Instantiate pyarrow Table object & convert to batches
+        let table = data.call_method0("to_batches")?;
+
+        let schema = data.getattr("schema")?;
+        let schema = schema.extract::<PyArrowType<Schema>>()?;
+
+        // Cast PyAny to RecordBatch type
+        // Because create_dataframe() expects a vector of vectors of record batches
+        // here we need to wrap the vector of record batches in an additional vector
+        let batches = table.extract::<PyArrowType<Vec<RecordBatch>>>()?;
+        let list_of_batches = PyArrowType::from(vec![batches.0]);
+        self.create_dataframe(list_of_batches, name, Some(schema), py)
     }

     /// Construct datafusion dataframe from pandas
     #[allow(clippy::wrong_self_convention)]
     pub fn from_pandas(
         &mut self,
-        data: PyObject,
+        data: Bound<'_, PyAny>,
         name: Option<&str>,
-        _py: Python,
     ) -> PyResult<PyDataFrame> {
-        Python::with_gil(|py| {
-            // Instantiate pyarrow Table object & convert to Arrow Table
-            let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
-            let args = PyTuple::new_bound(py, &[data]);
-            let table = table_class.call_method1("from_pandas", args)?.into();
-
-            // Convert Arrow Table to datafusion DataFrame
-            let df = self.from_arrow_table(table, name, py)?;
-            Ok(df)
-        })
+        // Obtain GIL token
+        let py = data.py();
+
+        // Instantiate pyarrow Table object & convert to Arrow Table
+        let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
+        let args = PyTuple::new_bound(py, &[data]);
+        let table = table_class.call_method1("from_pandas", args)?;
+
+        // Convert Arrow Table to datafusion DataFrame
+        let df = self.from_arrow_table(table, name, py)?;
+        Ok(df)
     }

     /// Construct datafusion dataframe from polars
     #[allow(clippy::wrong_self_convention)]
     pub fn from_polars(
         &mut self,
-        data: PyObject,
+        data: Bound<'_, PyAny>,
         name: Option<&str>,
-        _py: Python,
     ) -> PyResult<PyDataFrame> {
-        Python::with_gil(|py| {
-            // Convert Polars dataframe to Arrow Table
-            let table = data.call_method0(py, "to_arrow")?;
+        // Convert Polars dataframe to Arrow Table
+        let table = data.call_method0("to_arrow")?;

-            // Convert Arrow Table to datafusion DataFrame
-            let df = self.from_arrow_table(table, name, py)?;
-            Ok(df)
-        })
+        // Convert Arrow Table to datafusion DataFrame
+        let df = self.from_arrow_table(table, name, data.py())?;
+        Ok(df)
     }

     pub fn register_table(&mut self, name: &str, table: &PyTable) -> PyResult<()> {
55 changes: 22 additions & 33 deletions src/dataframe.rs
@@ -423,17 +423,15 @@ impl PyDataFrame {

     /// Convert to Arrow Table
     /// Collect the batches and pass to Arrow Table
-    fn to_arrow_table(&self, py: Python) -> PyResult<PyObject> {
+    fn to_arrow_table(&self, py: Python<'_>) -> PyResult<PyObject> {
         let batches = self.collect(py)?.to_object(py);
         let schema: PyObject = self.schema().into_py(py);

-        Python::with_gil(|py| {
-            // Instantiate pyarrow Table object and use its from_batches method
-            let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
-            let args = PyTuple::new_bound(py, &[batches, schema]);
-            let table: PyObject = table_class.call_method1("from_batches", args)?.into();
-            Ok(table)
-        })
+        // Instantiate pyarrow Table object and use its from_batches method
+        let table_class = py.import_bound("pyarrow")?.getattr("Table")?;
+        let args = PyTuple::new_bound(py, &[batches, schema]);
+        let table: PyObject = table_class.call_method1("from_batches", args)?.into();
+        Ok(table)
     }

     fn execute_stream(&self, py: Python) -> PyResult<PyRecordBatchStream> {
@@ -464,51 +462,42 @@ impl PyDataFrame {

     /// Convert to pandas dataframe with pyarrow
     /// Collect the batches, pass to Arrow Table & then convert to Pandas DataFrame
-    fn to_pandas(&self, py: Python) -> PyResult<PyObject> {
+    fn to_pandas(&self, py: Python<'_>) -> PyResult<PyObject> {
         let table = self.to_arrow_table(py)?;

-        Python::with_gil(|py| {
-            // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pandas
-            let result = table.call_method0(py, "to_pandas")?;
-            Ok(result)
-        })
+        // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pandas
+        let result = table.call_method0(py, "to_pandas")?;
+        Ok(result)
     }

     /// Convert to Python list using pyarrow
     /// Each list item represents one row encoded as dictionary
-    fn to_pylist(&self, py: Python) -> PyResult<PyObject> {
+    fn to_pylist(&self, py: Python<'_>) -> PyResult<PyObject> {
         let table = self.to_arrow_table(py)?;

-        Python::with_gil(|py| {
-            // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pylist
-            let result = table.call_method0(py, "to_pylist")?;
-            Ok(result)
-        })
+        // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pylist
+        let result = table.call_method0(py, "to_pylist")?;
+        Ok(result)
     }

     /// Convert to Python dictionary using pyarrow
     /// Each dictionary key is a column and the dictionary value represents the column values
     fn to_pydict(&self, py: Python) -> PyResult<PyObject> {
         let table = self.to_arrow_table(py)?;

-        Python::with_gil(|py| {
-            // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pydict
-            let result = table.call_method0(py, "to_pydict")?;
-            Ok(result)
-        })
+        // See also: https://arrow.apache.org/docs/python/generated/pyarrow.Table.html#pyarrow.Table.to_pydict
+        let result = table.call_method0(py, "to_pydict")?;
+        Ok(result)
     }

     /// Convert to polars dataframe with pyarrow
     /// Collect the batches, pass to Arrow Table & then convert to polars DataFrame
-    fn to_polars(&self, py: Python) -> PyResult<PyObject> {
+    fn to_polars(&self, py: Python<'_>) -> PyResult<PyObject> {
         let table = self.to_arrow_table(py)?;

-        Python::with_gil(|py| {
-            let dataframe = py.import_bound("polars")?.getattr("DataFrame")?;
-            let args = PyTuple::new_bound(py, &[table]);
-            let result: PyObject = dataframe.call1(args)?.into();
-            Ok(result)
-        })
+        let dataframe = py.import_bound("polars")?.getattr("DataFrame")?;
+        let args = PyTuple::new_bound(py, &[table]);
+        let result: PyObject = dataframe.call1(args)?.into();
+        Ok(result)
     }

     // Executes this DataFrame to get the total number of rows.
4 changes: 2 additions & 2 deletions src/sql/logical.rs
@@ -63,7 +63,7 @@ impl PyLogicalPlan {
 impl PyLogicalPlan {
     /// Return the specific logical operator
     pub fn to_variant(&self, py: Python) -> PyResult<PyObject> {
-        Python::with_gil(|_| match self.plan.as_ref() {
+        match self.plan.as_ref() {
             LogicalPlan::Aggregate(plan) => PyAggregate::from(plan.clone()).to_variant(py),
             LogicalPlan::Analyze(plan) => PyAnalyze::from(plan.clone()).to_variant(py),
             LogicalPlan::CrossJoin(plan) => PyCrossJoin::from(plan.clone()).to_variant(py),
@@ -85,7 +85,7 @@ impl PyLogicalPlan {
                 "Cannot convert this plan to a LogicalNode: {:?}",
                 other
             ))),
-        })
+        }
     }

     /// Get the inputs to this plan
