Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for wrapping rust closures as python functions #1901

Merged
merged 1 commit into from
Oct 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add commonly-used sequence methods to `PyList` and `PyTuple`. [#1849](https://github.com/PyO3/pyo3/pull/1849)
- Add `as_sequence` methods to `PyList` and `PyTuple`. [#1860](https://github.com/PyO3/pyo3/pull/1860)
- Add `abi3-py310` feature. [#1889](https://github.com/PyO3/pyo3/pull/1889)
- Add `PyCFunction::new_closure` to create a Python function from a Rust closure. [#1901](https://github.com/PyO3/pyo3/pull/1901)

### Changed

Expand Down
117 changes: 105 additions & 12 deletions src/types/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,58 @@ use crate::exceptions::PyValueError;
use crate::prelude::*;
use crate::{
class::methods::{self, PyMethodDef},
ffi, AsPyPointer,
ffi, types, AsPyPointer,
};
use std::os::raw::c_void;

/// Represents a builtin Python function object.
#[repr(transparent)]
pub struct PyCFunction(PyAny);

pyobject_native_type_core!(PyCFunction, ffi::PyCFunction_Type, #checkfunction=ffi::PyCFunction_Check);

const CLOSURE_CAPSULE_NAME: &[u8] = b"pyo3-closure\0";

unsafe extern "C" fn run_closure<F, R>(
capsule_ptr: *mut ffi::PyObject,
args: *mut ffi::PyObject,
kwargs: *mut ffi::PyObject,
) -> *mut ffi::PyObject
where
F: Fn(&types::PyTuple, Option<&types::PyDict>) -> R + Send + 'static,
R: crate::callback::IntoPyCallbackOutput<*mut ffi::PyObject>,
{
crate::callback_body!(py, {
let boxed_fn: &F =
&*(ffi::PyCapsule_GetPointer(capsule_ptr, CLOSURE_CAPSULE_NAME.as_ptr() as *const _)
as *mut F);
let args = py.from_borrowed_ptr::<types::PyTuple>(args);
let kwargs = py.from_borrowed_ptr_or_opt::<types::PyDict>(kwargs);
boxed_fn(args, kwargs)
})
}

unsafe extern "C" fn drop_closure<F, R>(capsule_ptr: *mut ffi::PyObject)
where
F: Fn(&types::PyTuple, Option<&types::PyDict>) -> R + Send + 'static,
R: crate::callback::IntoPyCallbackOutput<*mut ffi::PyObject>,
{
let result = std::panic::catch_unwind(|| {
let boxed_fn: Box<F> = Box::from_raw(ffi::PyCapsule_GetPointer(
capsule_ptr,
CLOSURE_CAPSULE_NAME.as_ptr() as *const _,
) as *mut F);
drop(boxed_fn)
});
if let Err(err) = result {
// This second layer of catch_unwind is useful as eprintln! can also panic.
let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
eprintln!("--- PyO3 intercepted a panic when dropping a closure");
eprintln!("{:?}", err);
}));
}
}

impl PyCFunction {
/// Create a new built-in function with keywords.
pub fn new_with_keywords<'a>(
Expand Down Expand Up @@ -39,23 +82,57 @@ impl PyCFunction {
)
}

/// Create a new function from a closure.
///
/// # Examples
///
/// ```
/// # use pyo3::prelude::*;
/// # use pyo3::{py_run, types};
///
/// Python::with_gil(|py| {
/// let add_one = |args: &types::PyTuple, _kwargs: Option<&types::PyDict>| -> PyResult<_> {
/// let i = args.extract::<(i64,)>()?.0;
/// Ok(i+1)
/// };
/// let add_one = types::PyCFunction::new_closure(add_one, py).unwrap();
/// py_run!(py, add_one, "assert add_one(42) == 43");
/// });
/// ```
pub fn new_closure<F, R>(f: F, py: Python) -> PyResult<&PyCFunction>
where
F: Fn(&types::PyTuple, Option<&types::PyDict>) -> R + Send + 'static,
R: crate::callback::IntoPyCallbackOutput<*mut ffi::PyObject>,
{
let function_ptr = Box::into_raw(Box::new(f));
let capsule = unsafe {
PyObject::from_owned_ptr_or_err(
py,
ffi::PyCapsule_New(
function_ptr as *mut c_void,
CLOSURE_CAPSULE_NAME.as_ptr() as *const _,
Some(drop_closure::<F, R>),
),
)?
};
let method_def = methods::PyMethodDef::cfunction_with_keywords(
"pyo3-closure",
methods::PyCFunctionWithKeywords(run_closure::<F, R>),
"",
);
Self::internal_new_from_pointers(method_def, py, capsule.as_ptr(), std::ptr::null_mut())
}

#[doc(hidden)]
pub fn internal_new(
fn internal_new_from_pointers(
method_def: PyMethodDef,
py_or_module: PyFunctionArguments,
py: Python,
mod_ptr: *mut ffi::PyObject,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, so in the closure case the mod_ptr is set to the capsule, which makes the capsule object appear as the first argument to run_closure. That's interesting!

module_name: *mut ffi::PyObject,
) -> PyResult<&Self> {
let (py, module) = py_or_module.into_py_and_maybe_module();
let def = method_def
.as_method_def()
.map_err(|err| PyValueError::new_err(err.0))?;
let (mod_ptr, module_name) = if let Some(m) = module {
let mod_ptr = m.as_ptr();
let name = m.name()?.into_py(py);
(mod_ptr, name.as_ptr())
} else {
(std::ptr::null_mut(), std::ptr::null_mut())
};

unsafe {
py.from_owned_ptr_or_err::<PyCFunction>(ffi::PyCFunction_NewEx(
Box::into_raw(Box::new(def)),
Expand All @@ -64,6 +141,22 @@ impl PyCFunction {
))
}
}

#[doc(hidden)]
pub fn internal_new(
method_def: PyMethodDef,
py_or_module: PyFunctionArguments,
) -> PyResult<&Self> {
let (py, module) = py_or_module.into_py_and_maybe_module();
let (mod_ptr, module_name) = if let Some(m) = module {
let mod_ptr = m.as_ptr();
let name = m.name()?.into_py(py);
(mod_ptr, name.as_ptr())
} else {
(std::ptr::null_mut(), std::ptr::null_mut())
};
Self::internal_new_from_pointers(method_def, py, mod_ptr, module_name)
}
}

/// Represents a Python function object.
Expand Down
1 change: 1 addition & 0 deletions tests/test_compile_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ fn _test_compile_errors() {
t.compile_fail("tests/ui/invalid_pymethods.rs");
t.compile_fail("tests/ui/invalid_pymethod_names.rs");
t.compile_fail("tests/ui/invalid_argument_attributes.rs");
t.compile_fail("tests/ui/invalid_closure.rs");
t.compile_fail("tests/ui/reject_generics.rs");

tests_rust_1_48(&t);
Expand Down
56 changes: 55 additions & 1 deletion tests/test_pyfunction.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#[cfg(not(Py_LIMITED_API))]
use pyo3::buffer::PyBuffer;
use pyo3::prelude::*;
use pyo3::types::PyCFunction;
use pyo3::types::{self, PyCFunction};
#[cfg(not(Py_LIMITED_API))]
use pyo3::types::{PyDateTime, PyFunction};

Expand Down Expand Up @@ -213,3 +213,57 @@ fn test_conversion_error() {
"argument 'option_arg': 'str' object cannot be interpreted as an integer"
);
}

#[test]
fn test_closure() {
let gil = Python::acquire_gil();
let py = gil.python();

let f = |args: &types::PyTuple, _kwargs: Option<&types::PyDict>| -> PyResult<_> {
let gil = Python::acquire_gil();
let py = gil.python();
let res: Vec<_> = args
.iter()
.map(|elem| {
if let Ok(i) = elem.extract::<i64>() {
(i + 1).into_py(py)
} else if let Ok(f) = elem.extract::<f64>() {
(2. * f).into_py(py)
} else if let Ok(mut s) = elem.extract::<String>() {
s.push_str("-py");
s.into_py(py)
} else {
panic!("unexpected argument type for {:?}", elem)
}
})
.collect();
Ok(res)
};
let closure_py = PyCFunction::new_closure(f, py).unwrap();

py_assert!(py, closure_py, "closure_py(42) == [43]");
py_assert!(
py,
closure_py,
"closure_py(42, 3.14, 'foo') == [43, 6.28, 'foo-py']"
);
}

#[test]
fn test_closure_counter() {
let gil = Python::acquire_gil();
let py = gil.python();

let counter = std::cell::RefCell::new(0);
let counter_fn =
move |_args: &types::PyTuple, _kwargs: Option<&types::PyDict>| -> PyResult<i32> {
let mut counter = counter.borrow_mut();
*counter += 1;
Ok(*counter)
};
let counter_py = PyCFunction::new_closure(counter_fn, py).unwrap();

py_assert!(py, counter_py, "counter_py() == 1");
py_assert!(py, counter_py, "counter_py() == 2");
py_assert!(py, counter_py, "counter_py() == 3");
}
Comment on lines +252 to +269
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is interesting. I guess because of the GIL we can agree that the RefCell usage is safe? Do you think it would even be safe for us to support FnMut?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you're right and holding the GIL should prevent the boxed function to be called in parallel so FnMut should be fine (though I'm not super confident on all this). I've switched to allowing FnMut and also tweaked the example so that it doesn't require a RefCell anymore and instead just uses a Box.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm, I'm afraid I think I spoke too quickly, sorry. I've tweaked the Box example a bit to one which holds an &mut reference to the Box while calling itself recursively in another thread, which looks to me like it breaks the rule that only one &mut reference to a value can exist at a time. (i.e. it's probably possible to create UB with FnMut.)

let mut counter = Box::new(0);
let counter_fn_obj: Arc<RwLock<Option<Py<PyCFunction>>>> = Arc::new(RwLock::new(None));
let counter_fn_obj_clone = counter_fn_obj.clone();
let counter_fn =
    move |_args: &types::PyTuple, _kwargs: Option<&types::PyDict>| -> PyResult<i32> {
        let counter_value = &mut *counter;
        if *counter_value < 5 {
            _args.py().allow_threads(|| {
                let counter_fn_obj_clone = counter_fn_obj_clone.clone();
                let other_thread = std::thread::spawn(move || Python::with_gil(|py| {
                    counter_fn_obj_clone.read().as_ref().expect("function has been made").call0(py)
                }));
                *counter_value += 1;
                other_thread.join().unwrap()
            })?;
        }
        Ok(*counter_value)
    };
let counter_py = PyCFunction::new_closure(counter_fn, py).unwrap();
*counter_fn_obj.write() = Some(counter_py.into());

So I think to be safe let's go back to just Fn. Sorry about that.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I think once we switch back to Fn this looks good to merge in my opinion!)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, thanks for checking this thoroughly, just reverted the FnMut changes.

19 changes: 19 additions & 0 deletions tests/ui/invalid_closure.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
use pyo3::prelude::*;
use pyo3::types::{PyCFunction, PyDict, PyTuple};

fn main() {
let fun: Py<PyCFunction> = Python::with_gil(|py| {
let local_data = vec![0, 1, 2, 3, 4];
let ref_: &[u8] = &local_data;

let closure_fn = |_args: &PyTuple, _kwargs: Option<&PyDict>| -> PyResult<()> {
println!("This is five: {:?}", ref_.len());
Ok(())
};
PyCFunction::new_closure(closure_fn, py).unwrap().into()
});

Python::with_gil(|py| {
fun.call0(py).unwrap();
});
}
28 changes: 28 additions & 0 deletions tests/ui/invalid_closure.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
error[E0597]: `local_data` does not live long enough
--> tests/ui/invalid_closure.rs:7:27
|
7 | let ref_: &[u8] = &local_data;
| ^^^^^^^^^^^ borrowed value does not live long enough
...
13 | PyCFunction::new_closure(closure_fn, py).unwrap().into()
| ---------------------------------------- argument requires that `local_data` is borrowed for `'static`
14 | });
| - `local_data` dropped here while still borrowed

error[E0373]: closure may outlive the current function, but it borrows `ref_`, which is owned by the current function
--> tests/ui/invalid_closure.rs:9:26
|
9 | let closure_fn = |_args: &PyTuple, _kwargs: Option<&PyDict>| -> PyResult<()> {
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ may outlive borrowed value `ref_`
10 | println!("This is five: {:?}", ref_.len());
| ---- `ref_` is borrowed here
|
note: function requires argument type to outlive `'static`
--> tests/ui/invalid_closure.rs:13:9
|
13 | PyCFunction::new_closure(closure_fn, py).unwrap().into()
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
help: to force the closure to take ownership of `ref_` (and any other referenced variables), use the `move` keyword
|
9 | let closure_fn = move |_args: &PyTuple, _kwargs: Option<&PyDict>| -> PyResult<()> {
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^