Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switched start_blocking_portal() to use daemon threads #750

Merged
merged 6 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
(`#737 <https://github.com/agronholm/anyio/issues/737>`_)
- Changed the ``ResourceWarning`` from an unclosed memory object stream to include its
address for easier identification
- Bumped minimum version of Trio to v0.26.1
- Changed ``start_blocking_portal()`` to always use daemonic threads, to accommodate the
"loitering event loop" use case
- Bumped the minimum version of Trio to v0.26.1
- Fixed ``to_process.run_sync()`` failing to initialize if ``__main__.__file__`` pointed
to a file in a nonexistent directory
(`#696 <https://github.com/agronholm/anyio/issues/696>`_)
Expand Down
69 changes: 31 additions & 38 deletions src/anyio/from_thread.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
from __future__ import annotations

import sys
import threading
from collections.abc import Awaitable, Callable, Generator
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
from concurrent.futures import Future
from contextlib import AbstractContextManager, contextmanager
from dataclasses import dataclass, field
from inspect import isawaitable
from threading import Lock, Thread, get_ident
from types import TracebackType
from typing import (
Any,
AsyncContextManager,
ContextManager,
Generic,
Iterable,
TypeVar,
cast,
overload,
Expand Down Expand Up @@ -146,7 +145,7 @@ def __new__(cls) -> BlockingPortal:
return get_async_backend().create_blocking_portal()

def __init__(self) -> None:
self._event_loop_thread_id: int | None = threading.get_ident()
self._event_loop_thread_id: int | None = get_ident()
self._stop_event = Event()
self._task_group = create_task_group()
self._cancelled_exc_class = get_cancelled_exc_class()
Expand All @@ -167,7 +166,7 @@ async def __aexit__(
def _check_running(self) -> None:
if self._event_loop_thread_id is None:
raise RuntimeError("This portal is not running")
if self._event_loop_thread_id == threading.get_ident():
if self._event_loop_thread_id == get_ident():
raise RuntimeError(
"This method cannot be called from the event loop thread"
)
Expand Down Expand Up @@ -202,7 +201,7 @@ async def _call_func(
def callback(f: Future[T_Retval]) -> None:
if f.cancelled() and self._event_loop_thread_id not in (
None,
threading.get_ident(),
get_ident(),
):
self.call(scope.cancel)

Expand Down Expand Up @@ -411,7 +410,7 @@ class BlockingPortalProvider:

backend: str = "asyncio"
backend_options: dict[str, Any] | None = None
_lock: threading.Lock = field(init=False, default_factory=threading.Lock)
_lock: Lock = field(init=False, default_factory=Lock)
_leases: int = field(init=False, default=0)
_portal: BlockingPortal = field(init=False)
_portal_cm: AbstractContextManager[BlockingPortal] | None = field(
Expand Down Expand Up @@ -469,43 +468,37 @@ def start_blocking_portal(

async def run_portal() -> None:
async with BlockingPortal() as portal_:
if future.set_running_or_notify_cancel():
future.set_result(portal_)
await portal_.sleep_until_stopped()
future.set_result(portal_)
await portal_.sleep_until_stopped()

def run_blocking_portal() -> None:
if future.set_running_or_notify_cancel():
try:
_eventloop.run(
run_portal, backend=backend, backend_options=backend_options
)
except BaseException as exc:
if not future.done():
future.set_exception(exc)

future: Future[BlockingPortal] = Future()
with ThreadPoolExecutor(1) as executor:
run_future = executor.submit(
_eventloop.run, # type: ignore[arg-type]
run_portal,
backend=backend,
backend_options=backend_options,
)
thread = Thread(target=run_blocking_portal, daemon=True)
thread.start()
try:
cancel_remaining_tasks = False
portal = future.result()
try:
wait(
cast(Iterable[Future], [run_future, future]),
return_when=FIRST_COMPLETED,
)
yield portal
except BaseException:
future.cancel()
run_future.cancel()
cancel_remaining_tasks = True
raise

if future.done():
portal = future.result()
cancel_remaining_tasks = False
finally:
try:
yield portal
except BaseException:
cancel_remaining_tasks = True
raise
finally:
try:
portal.call(portal.stop, cancel_remaining_tasks)
except RuntimeError:
pass

run_future.result()
portal.call(portal.stop, cancel_remaining_tasks)
except RuntimeError:
pass
finally:
thread.join()


def check_cancelled() -> None:
Expand Down
18 changes: 6 additions & 12 deletions tests/test_from_thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from anyio.lowlevel import checkpoint

if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup, ExceptionGroup
from exceptiongroup import ExceptionGroup

pytestmark = pytest.mark.anyio

Expand Down Expand Up @@ -609,23 +609,17 @@ def test_raise_baseexception_from_task(
"""
Test that when a task raises a BaseException, it does not trigger additional
exceptions when trying to close the portal.

"""

async def raise_baseexception() -> None:
assert threading.current_thread().daemon
raise BaseException("fatal error")

with pytest.raises(BaseExceptionGroup) as outer_exc:
with start_blocking_portal(
anyio_backend_name, anyio_backend_options
) as portal:
with pytest.raises(BaseException, match="fatal error") as exc:
portal.call(raise_baseexception)

assert exc.value.__context__ is None
with start_blocking_portal(anyio_backend_name, anyio_backend_options) as portal:
with pytest.raises(BaseException, match="fatal error") as exc:
portal.call(raise_baseexception)

assert len(outer_exc.value.exceptions) == 1
assert str(outer_exc.value.exceptions[0]) == "fatal error"
assert exc.value.__context__ is None

@pytest.mark.parametrize("portal_backend_name", get_all_backends())
async def test_from_async(
Expand Down
Loading