diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 52a87347..69e0878f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,14 +22,14 @@ repos: - id: trailing-whitespace - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.9 + rev: v0.8.1 hooks: - id: ruff args: [--fix, --show-fixes] - id: ruff-format - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.11.2 + rev: v1.13.0 hooks: - id: mypy additional_dependencies: diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 05e69a09..ce4c4db0 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -5,6 +5,8 @@ This library adheres to `Semantic Versioning 2.0 `_. **UNRELEASED** +- Updated ``TaskGroup`` to work with asyncio's eager task factories + (`#764 `_) - Fixed a misleading ``ValueError`` in the context of DNS failures (`#815 `_; PR by @graingert) - Added the ``wait_readable()`` and ``wait_writable()`` functions which will accept diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index fe648a6f..38b68f4d 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -28,6 +28,8 @@ Collection, Coroutine, Iterable, + Iterator, + MutableMapping, Sequence, ) from concurrent.futures import Future @@ -351,8 +353,12 @@ def get_callable_name(func: Callable) -> str: def _task_started(task: asyncio.Task) -> bool: """Return ``True`` if the task has been started and has not finished.""" + # The task coro should never be None here, as we never add finished tasks to the + # task list + coro = task.get_coro() + assert coro is not None try: - return getcoroutinestate(task.get_coro()) in (CORO_RUNNING, CORO_SUSPENDED) + return getcoroutinestate(coro) in (CORO_RUNNING, CORO_SUSPENDED) except AttributeError: # task coro is async_genenerator_asend https://bugs.python.org/issue37771 raise Exception(f"Cannot determine if task {task} has started or not") from None @@ -409,8 +415,10 @@ def __enter__(self) -> CancelScope: self._parent_scope = task_state.cancel_scope task_state.cancel_scope = self if self._parent_scope is not None: + # If using an eager task factory, the parent scope may not even contain + # the host task self._parent_scope._child_scopes.add(self) - self._parent_scope._tasks.remove(host_task) + self._parent_scope._tasks.discard(host_task) self._timeout() self._active = True @@ -667,7 +675,45 @@ def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None): self.cancel_scope = cancel_scope -_task_states: WeakKeyDictionary[asyncio.Task, TaskState] = WeakKeyDictionary() +class TaskStateStore(MutableMapping["Awaitable[Any] | asyncio.Task", TaskState]): + def __init__(self) -> None: + self._task_states = WeakKeyDictionary[asyncio.Task, TaskState]() + self._preliminary_task_states: dict[Awaitable[Any], TaskState] = {} + + def __getitem__(self, key: Awaitable[Any] | asyncio.Task, /) -> TaskState: + assert isinstance(key, asyncio.Task) + try: + return self._task_states[key] + except KeyError: + if coro := key.get_coro(): + if state := self._preliminary_task_states.get(coro): + return state + + raise KeyError(key) + + def __setitem__( + self, key: asyncio.Task | Awaitable[Any], value: TaskState, / + ) -> None: + if isinstance(key, asyncio.Task): + self._task_states[key] = value + else: + self._preliminary_task_states[key] = value + + def __delitem__(self, key: asyncio.Task | Awaitable[Any], /) -> None: + if isinstance(key, asyncio.Task): + del self._task_states[key] + else: + del self._preliminary_task_states[key] + + def __len__(self) -> int: + return len(self._task_states) + len(self._preliminary_task_states) + + def __iter__(self) -> Iterator[Awaitable[Any] | asyncio.Task]: + yield from self._task_states + yield from self._preliminary_task_states + + +_task_states = TaskStateStore() # @@ -787,7 +833,7 @@ def _spawn( task_status_future: asyncio.Future | None = None, ) -> asyncio.Task: def task_done(_task: asyncio.Task) -> None: - task_state = _task_states[_task] + # task_state = _task_states[_task] assert task_state.cancel_scope is not None assert _task in task_state.cancel_scope._tasks task_state.cancel_scope._tasks.remove(_task) @@ -844,16 +890,26 @@ def task_done(_task: asyncio.Task) -> None: f"the return value ({coro!r}) is not a coroutine object" ) - name = get_callable_name(func) if name is None else str(name) - task = create_task(coro, name=name) - task.add_done_callback(task_done) - # Make the spawned task inherit the task group's cancel scope - _task_states[task] = TaskState( + _task_states[coro] = task_state = TaskState( parent_id=parent_id, cancel_scope=self.cancel_scope ) + name = get_callable_name(func) if name is None else str(name) + try: + task = create_task(coro, name=name) + finally: + del _task_states[coro] + + _task_states[task] = task_state self.cancel_scope._tasks.add(task) self._tasks.add(task) + + if task.done(): + # This can happen with eager task factories + task_done(task) + else: + task.add_done_callback(task_done) + return task def start_soon( @@ -2086,7 +2142,9 @@ def __init__(self, task: asyncio.Task): else: parent_id = task_state.parent_id - super().__init__(id(task), parent_id, task.get_name(), task.get_coro()) + coro = task.get_coro() + assert coro is not None, "created TaskInfo from a completed Task" + super().__init__(id(task), parent_id, task.get_name(), coro) self._task = weakref.ref(task) def has_pending_cancellation(self) -> bool: @@ -2339,10 +2397,11 @@ def create_cancel_scope( @classmethod def current_effective_deadline(cls) -> float: + if (task := current_task()) is None: + return math.inf + try: - cancel_scope = _task_states[ - current_task() # type: ignore[index] - ].cancel_scope + cancel_scope = _task_states[task].cancel_scope except KeyError: return math.inf diff --git a/src/anyio/abc/_tasks.py b/src/anyio/abc/_tasks.py index 88aecf38..f6e5c40c 100644 --- a/src/anyio/abc/_tasks.py +++ b/src/anyio/abc/_tasks.py @@ -40,6 +40,12 @@ class TaskGroup(metaclass=ABCMeta): :ivar cancel_scope: the cancel scope inherited by all child tasks :vartype cancel_scope: CancelScope + + .. note:: On asyncio, support for eager task factories is considered to be + **experimental**. In particular, they don't follow the usual semantics of new + tasks being scheduled on the next iteration of the event loop, and may thus + cause unexpected behavior in code that wasn't written with such semantics in + mind. """ cancel_scope: CancelScope diff --git a/tests/test_sockets.py b/tests/test_sockets.py index a1189bb8..8965ea61 100644 --- a/tests/test_sockets.py +++ b/tests/test_sockets.py @@ -149,7 +149,7 @@ def _identity(v: _T) -> _T: ) -@_ignore_win32_resource_warnings # type: ignore[operator] +@_ignore_win32_resource_warnings class TestTCPStream: @pytest.fixture def server_sock(self, family: AnyIPAddressFamily) -> Iterator[socket.socket]: diff --git a/tests/test_taskgroups.py b/tests/test_taskgroups.py index 78ef9983..84101e47 100644 --- a/tests/test_taskgroups.py +++ b/tests/test_taskgroups.py @@ -10,7 +10,8 @@ from typing import Any, NoReturn, cast import pytest -from exceptiongroup import ExceptionGroup, catch +from exceptiongroup import catch +from pytest import FixtureRequest from pytest_mock import MockerFixture import anyio @@ -783,7 +784,7 @@ async def host_agen_fn() -> AsyncGenerator[None, None]: host_agen = host_agen_fn() try: loop = asyncio.get_running_loop() - await loop.create_task(host_agen.__anext__()) # type: ignore[arg-type] + await loop.create_task(host_agen.__anext__()) finally: await host_agen.aclose() @@ -1704,3 +1705,24 @@ async def typetest_optional_status( task_status: TaskStatus[int] = TASK_STATUS_IGNORED, ) -> None: task_status.started(1) + + +@pytest.mark.skipif( + sys.version_info < (3, 12), + reason="Eager task factories require Python 3.12", +) +@pytest.mark.parametrize("anyio_backend", ["asyncio"]) +async def test_eager_task_factory(request: FixtureRequest) -> None: + async def sync_coro() -> None: + # This should trigger fetching the task state + with CancelScope(): # noqa: ASYNC100 + pass + + loop = asyncio.get_running_loop() + old_task_factory = loop.get_task_factory() + loop.set_task_factory(asyncio.eager_task_factory) + request.addfinalizer(lambda: loop.set_task_factory(old_task_factory)) + + async with create_task_group() as tg: + tg.start_soon(sync_coro) + tg.cancel_scope.cancel()