diff --git a/newsfragments/4788.fixed.md b/newsfragments/4788.fixed.md new file mode 100644 index 00000000000..804cd60fd3d --- /dev/null +++ b/newsfragments/4788.fixed.md @@ -0,0 +1,4 @@ +* Fixed thread-unsafe access of dict internals in BoundDictIterator on the + free-threaded build. +* Avoided creating unnecessary critical sections in BoundDictIterator + implementation on the free-threaded build. diff --git a/src/types/dict.rs b/src/types/dict.rs index b3c8e37962b..0d2e6ff335f 100644 --- a/src/types/dict.rs +++ b/src/types/dict.rs @@ -181,7 +181,8 @@ pub trait PyDictMethods<'py>: crate::sealed::Sealed { /// Iterates over the contents of this dictionary while holding a critical section on the dict. /// This is useful when the GIL is disabled and the dictionary is shared between threads. /// It is not guaranteed that the dictionary will not be modified during iteration when the - /// closure calls arbitrary Python code that releases the current critical section. + /// closure calls arbitrary Python code that releases the critical section held by the + /// iterator. Otherwise, the dictionary will not be modified during iteration. /// /// This method is a small performance optimization over `.iter().try_for_each()` when the /// nightly feature is not enabled because we cannot implement an optimised version of @@ -396,19 +397,26 @@ impl<'a, 'py> Borrowed<'a, 'py, PyDict> { /// Iterates over the contents of this dictionary without incrementing reference counts. /// /// # Safety - /// It must be known that this dictionary will not be modified during iteration. + /// It must be known that this dictionary will not be modified during iteration, + /// for example, when parsing arguments in a keyword arguments dictionary. 
pub(crate) unsafe fn iter_borrowed(self) -> BorrowedDictIter<'a, 'py> { BorrowedDictIter::new(self) } } fn dict_len(dict: &Bound<'_, PyDict>) -> Py_ssize_t { - #[cfg(any(not(Py_3_8), PyPy, GraalPy, Py_LIMITED_API))] + #[cfg(any(not(Py_3_8), PyPy, GraalPy, Py_LIMITED_API, Py_GIL_DISABLED))] unsafe { ffi::PyDict_Size(dict.as_ptr()) } - #[cfg(all(Py_3_8, not(PyPy), not(GraalPy), not(Py_LIMITED_API)))] + #[cfg(all( + Py_3_8, + not(PyPy), + not(GraalPy), + not(Py_LIMITED_API), + not(Py_GIL_DISABLED) + ))] unsafe { (*dict.as_ptr().cast::<ffi::PyDictObject>()).ma_used } @@ -429,8 +437,11 @@ enum DictIterImpl { } impl DictIterImpl { + #[deny(unsafe_op_in_unsafe_fn)] #[inline] - fn next<'py>( + /// Safety: the dict should be locked with a critical section on the free-threaded build + /// and otherwise not shared between threads in code that releases the GIL. + unsafe fn next_unchecked<'py>( &mut self, dict: &Bound<'py, PyDict>, ) -> Option<(Bound<'py, PyAny>, Bound<'py, PyAny>)> { @@ -440,7 +451,7 @@ impl DictIterImpl { remaining, ppos, .. - } => crate::sync::with_critical_section(dict, || { + } => { let ma_used = dict_len(dict); // These checks are similar to what CPython does.
@@ -470,20 +481,20 @@ impl DictIterImpl { let mut key: *mut ffi::PyObject = std::ptr::null_mut(); let mut value: *mut ffi::PyObject = std::ptr::null_mut(); - if unsafe { ffi::PyDict_Next(dict.as_ptr(), ppos, &mut key, &mut value) } != 0 { + if unsafe { ffi::PyDict_Next(dict.as_ptr(), ppos, &mut key, &mut value) != 0 } { *remaining -= 1; let py = dict.py(); // Safety: // - PyDict_Next returns borrowed values // - we have already checked that `PyDict_Next` succeeded, so we can assume these to be non-null Some(( - unsafe { key.assume_borrowed_unchecked(py) }.to_owned(), - unsafe { value.assume_borrowed_unchecked(py) }.to_owned(), + unsafe { key.assume_borrowed_unchecked(py).to_owned() }, + unsafe { value.assume_borrowed_unchecked(py).to_owned() }, )) } else { None } - }), + } } } @@ -504,7 +515,17 @@ impl<'py> Iterator for BoundDictIterator<'py> { #[inline] fn next(&mut self) -> Option<Self::Item> { - self.inner.next(&self.dict) + #[cfg(Py_GIL_DISABLED)] + { + self.inner + .with_critical_section(&self.dict, |inner| unsafe { + inner.next_unchecked(&self.dict) + }) + } + #[cfg(not(Py_GIL_DISABLED))] + { + unsafe { self.inner.next_unchecked(&self.dict) } + } } #[inline] @@ -522,7 +543,7 @@ impl<'py> Iterator for BoundDictIterator<'py> { { self.inner.with_critical_section(&self.dict, |inner| { let mut accum = init; - while let Some(x) = inner.next(&self.dict) { + while let Some(x) = unsafe { inner.next_unchecked(&self.dict) } { accum = f(accum, x); } accum @@ -539,7 +560,7 @@ impl<'py> Iterator for BoundDictIterator<'py> { { self.inner.with_critical_section(&self.dict, |inner| { let mut accum = init; - while let Some(x) = inner.next(&self.dict) { + while let Some(x) = unsafe { inner.next_unchecked(&self.dict) } { accum = f(accum, x)?
} R::from_output(accum) @@ -554,7 +575,7 @@ impl<'py> Iterator for BoundDictIterator<'py> { F: FnMut(Self::Item) -> bool, { self.inner.with_critical_section(&self.dict, |inner| { - while let Some(x) = inner.next(&self.dict) { + while let Some(x) = unsafe { inner.next_unchecked(&self.dict) } { if !f(x) { return false; } @@ -571,7 +592,7 @@ impl<'py> Iterator for BoundDictIterator<'py> { F: FnMut(Self::Item) -> bool, { self.inner.with_critical_section(&self.dict, |inner| { - while let Some(x) = inner.next(&self.dict) { + while let Some(x) = unsafe { inner.next_unchecked(&self.dict) } { if f(x) { return true; } @@ -588,7 +609,7 @@ impl<'py> Iterator for BoundDictIterator<'py> { P: FnMut(&Self::Item) -> bool, { self.inner.with_critical_section(&self.dict, |inner| { - while let Some(x) = inner.next(&self.dict) { + while let Some(x) = unsafe { inner.next_unchecked(&self.dict) } { if predicate(&x) { return Some(x); } @@ -605,7 +626,7 @@ impl<'py> Iterator for BoundDictIterator<'py> { F: FnMut(Self::Item) -> Option<B>, { self.inner.with_critical_section(&self.dict, |inner| { - while let Some(x) = inner.next(&self.dict) { + while let Some(x) = unsafe { inner.next_unchecked(&self.dict) } { if let found @ Some(_) = f(x) { return found; } @@ -623,7 +644,7 @@ impl<'py> Iterator for BoundDictIterator<'py> { { self.inner.with_critical_section(&self.dict, |inner| { let mut acc = 0; - while let Some(x) = inner.next(&self.dict) { + while let Some(x) = unsafe { inner.next_unchecked(&self.dict) } { if predicate(x) { return Some(acc); }