Skip to content

Commit

Permalink
Switched start_blocking_portal() to use daemon threads (#750)
Browse files Browse the repository at this point in the history
ThreadPoolExecutor used daemon threads in Python 3.8 and non-daemon threads in 3.9 onwards. But daemon threads are essential for the use case where we want to have a "loitering" blocking portal which will only be shut down via an atexit hook.
  • Loading branch information
agronholm authored Aug 30, 2024
1 parent f1a46f2 commit e5a8a93
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 51 deletions.
4 changes: 3 additions & 1 deletion docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
(`#737 <https://github.com/agronholm/anyio/issues/737>`_)
- Changed the ``ResourceWarning`` from an unclosed memory object stream to include its
address for easier identification
- Bumped minimum version of Trio to v0.26.1
- Changed ``start_blocking_portal()`` to always use daemonic threads, to accommodate the
"loitering event loop" use case
- Bumped the minimum version of Trio to v0.26.1
- Fixed ``to_process.run_sync()`` failing to initialize if ``__main__.__file__`` pointed
to a file in a nonexistent directory
(`#696 <https://github.com/agronholm/anyio/issues/696>`_)
Expand Down
69 changes: 31 additions & 38 deletions src/anyio/from_thread.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
from __future__ import annotations

import sys
import threading
from collections.abc import Awaitable, Callable, Generator
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
from concurrent.futures import Future
from contextlib import AbstractContextManager, contextmanager
from dataclasses import dataclass, field
from inspect import isawaitable
from threading import Lock, Thread, get_ident
from types import TracebackType
from typing import (
Any,
AsyncContextManager,
ContextManager,
Generic,
Iterable,
TypeVar,
cast,
overload,
Expand Down Expand Up @@ -146,7 +145,7 @@ def __new__(cls) -> BlockingPortal:
return get_async_backend().create_blocking_portal()

def __init__(self) -> None:
self._event_loop_thread_id: int | None = threading.get_ident()
self._event_loop_thread_id: int | None = get_ident()
self._stop_event = Event()
self._task_group = create_task_group()
self._cancelled_exc_class = get_cancelled_exc_class()
Expand All @@ -167,7 +166,7 @@ async def __aexit__(
def _check_running(self) -> None:
if self._event_loop_thread_id is None:
raise RuntimeError("This portal is not running")
if self._event_loop_thread_id == threading.get_ident():
if self._event_loop_thread_id == get_ident():
raise RuntimeError(
"This method cannot be called from the event loop thread"
)
Expand Down Expand Up @@ -202,7 +201,7 @@ async def _call_func(
def callback(f: Future[T_Retval]) -> None:
if f.cancelled() and self._event_loop_thread_id not in (
None,
threading.get_ident(),
get_ident(),
):
self.call(scope.cancel)

Expand Down Expand Up @@ -411,7 +410,7 @@ class BlockingPortalProvider:

backend: str = "asyncio"
backend_options: dict[str, Any] | None = None
_lock: threading.Lock = field(init=False, default_factory=threading.Lock)
_lock: Lock = field(init=False, default_factory=Lock)
_leases: int = field(init=False, default=0)
_portal: BlockingPortal = field(init=False)
_portal_cm: AbstractContextManager[BlockingPortal] | None = field(
Expand Down Expand Up @@ -469,43 +468,37 @@ def start_blocking_portal(

async def run_portal() -> None:
async with BlockingPortal() as portal_:
if future.set_running_or_notify_cancel():
future.set_result(portal_)
await portal_.sleep_until_stopped()
future.set_result(portal_)
await portal_.sleep_until_stopped()

def run_blocking_portal() -> None:
if future.set_running_or_notify_cancel():
try:
_eventloop.run(
run_portal, backend=backend, backend_options=backend_options
)
except BaseException as exc:
if not future.done():
future.set_exception(exc)

future: Future[BlockingPortal] = Future()
with ThreadPoolExecutor(1) as executor:
run_future = executor.submit(
_eventloop.run, # type: ignore[arg-type]
run_portal,
backend=backend,
backend_options=backend_options,
)
thread = Thread(target=run_blocking_portal, daemon=True)
thread.start()
try:
cancel_remaining_tasks = False
portal = future.result()
try:
wait(
cast(Iterable[Future], [run_future, future]),
return_when=FIRST_COMPLETED,
)
yield portal
except BaseException:
future.cancel()
run_future.cancel()
cancel_remaining_tasks = True
raise

if future.done():
portal = future.result()
cancel_remaining_tasks = False
finally:
try:
yield portal
except BaseException:
cancel_remaining_tasks = True
raise
finally:
try:
portal.call(portal.stop, cancel_remaining_tasks)
except RuntimeError:
pass

run_future.result()
portal.call(portal.stop, cancel_remaining_tasks)
except RuntimeError:
pass
finally:
thread.join()


def check_cancelled() -> None:
Expand Down
18 changes: 6 additions & 12 deletions tests/test_from_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from anyio.lowlevel import checkpoint

if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup, ExceptionGroup
from exceptiongroup import ExceptionGroup

pytestmark = pytest.mark.anyio

Expand Down Expand Up @@ -609,23 +609,17 @@ def test_raise_baseexception_from_task(
"""
Test that when a task raises a BaseException, it does not trigger additional
exceptions when trying to close the portal.
"""

async def raise_baseexception() -> None:
assert threading.current_thread().daemon
raise BaseException("fatal error")

with pytest.raises(BaseExceptionGroup) as outer_exc:
with start_blocking_portal(
anyio_backend_name, anyio_backend_options
) as portal:
with pytest.raises(BaseException, match="fatal error") as exc:
portal.call(raise_baseexception)

assert exc.value.__context__ is None
with start_blocking_portal(anyio_backend_name, anyio_backend_options) as portal:
with pytest.raises(BaseException, match="fatal error") as exc:
portal.call(raise_baseexception)

assert len(outer_exc.value.exceptions) == 1
assert str(outer_exc.value.exceptions[0]) == "fatal error"
assert exc.value.__context__ is None

@pytest.mark.parametrize("portal_backend_name", get_all_backends())
async def test_from_async(
Expand Down

0 comments on commit e5a8a93

Please sign in to comment.