diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 52a87347..69e0878f 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -22,14 +22,14 @@ repos:
- id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit
- rev: v0.6.9
+ rev: v0.8.1
hooks:
- id: ruff
args: [--fix, --show-fixes]
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
- rev: v1.11.2
+ rev: v1.13.0
hooks:
- id: mypy
additional_dependencies:
diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst
index 05e69a09..ce4c4db0 100644
--- a/docs/versionhistory.rst
+++ b/docs/versionhistory.rst
@@ -5,6 +5,8 @@ This library adheres to `Semantic Versioning 2.0 `_.
**UNRELEASED**
+- Updated ``TaskGroup`` to work with asyncio's eager task factories
+ (`#764 `_)
- Fixed a misleading ``ValueError`` in the context of DNS failures
(`#815 `_; PR by @graingert)
- Added the ``wait_readable()`` and ``wait_writable()`` functions which will accept
diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py
index fe648a6f..38b68f4d 100644
--- a/src/anyio/_backends/_asyncio.py
+++ b/src/anyio/_backends/_asyncio.py
@@ -28,6 +28,8 @@
Collection,
Coroutine,
Iterable,
+ Iterator,
+ MutableMapping,
Sequence,
)
from concurrent.futures import Future
@@ -351,8 +353,12 @@ def get_callable_name(func: Callable) -> str:
def _task_started(task: asyncio.Task) -> bool:
"""Return ``True`` if the task has been started and has not finished."""
+ # The task coro should never be None here, as we never add finished tasks to the
+ # task list
+ coro = task.get_coro()
+ assert coro is not None
try:
- return getcoroutinestate(task.get_coro()) in (CORO_RUNNING, CORO_SUSPENDED)
+ return getcoroutinestate(coro) in (CORO_RUNNING, CORO_SUSPENDED)
except AttributeError:
# task coro is async_genenerator_asend https://bugs.python.org/issue37771
raise Exception(f"Cannot determine if task {task} has started or not") from None
@@ -409,8 +415,10 @@ def __enter__(self) -> CancelScope:
self._parent_scope = task_state.cancel_scope
task_state.cancel_scope = self
if self._parent_scope is not None:
+ # If using an eager task factory, the parent scope may not even contain
+ # the host task
self._parent_scope._child_scopes.add(self)
- self._parent_scope._tasks.remove(host_task)
+ self._parent_scope._tasks.discard(host_task)
self._timeout()
self._active = True
@@ -667,7 +675,45 @@ def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
self.cancel_scope = cancel_scope
-_task_states: WeakKeyDictionary[asyncio.Task, TaskState] = WeakKeyDictionary()
+class TaskStateStore(MutableMapping["Awaitable[Any] | asyncio.Task", TaskState]):
+ def __init__(self) -> None:
+ self._task_states = WeakKeyDictionary[asyncio.Task, TaskState]()
+ self._preliminary_task_states: dict[Awaitable[Any], TaskState] = {}
+
+ def __getitem__(self, key: Awaitable[Any] | asyncio.Task, /) -> TaskState:
+ assert isinstance(key, asyncio.Task)
+ try:
+ return self._task_states[key]
+ except KeyError:
+ if coro := key.get_coro():
+ if state := self._preliminary_task_states.get(coro):
+ return state
+
+ raise KeyError(key)
+
+ def __setitem__(
+ self, key: asyncio.Task | Awaitable[Any], value: TaskState, /
+ ) -> None:
+ if isinstance(key, asyncio.Task):
+ self._task_states[key] = value
+ else:
+ self._preliminary_task_states[key] = value
+
+ def __delitem__(self, key: asyncio.Task | Awaitable[Any], /) -> None:
+ if isinstance(key, asyncio.Task):
+ del self._task_states[key]
+ else:
+ del self._preliminary_task_states[key]
+
+ def __len__(self) -> int:
+ return len(self._task_states) + len(self._preliminary_task_states)
+
+ def __iter__(self) -> Iterator[Awaitable[Any] | asyncio.Task]:
+ yield from self._task_states
+ yield from self._preliminary_task_states
+
+
+_task_states = TaskStateStore()
#
@@ -787,7 +833,7 @@ def _spawn(
task_status_future: asyncio.Future | None = None,
) -> asyncio.Task:
def task_done(_task: asyncio.Task) -> None:
- task_state = _task_states[_task]
+ # task_state = _task_states[_task]
assert task_state.cancel_scope is not None
assert _task in task_state.cancel_scope._tasks
task_state.cancel_scope._tasks.remove(_task)
@@ -844,16 +890,26 @@ def task_done(_task: asyncio.Task) -> None:
f"the return value ({coro!r}) is not a coroutine object"
)
- name = get_callable_name(func) if name is None else str(name)
- task = create_task(coro, name=name)
- task.add_done_callback(task_done)
-
# Make the spawned task inherit the task group's cancel scope
- _task_states[task] = TaskState(
+ _task_states[coro] = task_state = TaskState(
parent_id=parent_id, cancel_scope=self.cancel_scope
)
+ name = get_callable_name(func) if name is None else str(name)
+ try:
+ task = create_task(coro, name=name)
+ finally:
+ del _task_states[coro]
+
+ _task_states[task] = task_state
self.cancel_scope._tasks.add(task)
self._tasks.add(task)
+
+ if task.done():
+ # This can happen with eager task factories
+ task_done(task)
+ else:
+ task.add_done_callback(task_done)
+
return task
def start_soon(
@@ -2086,7 +2142,9 @@ def __init__(self, task: asyncio.Task):
else:
parent_id = task_state.parent_id
- super().__init__(id(task), parent_id, task.get_name(), task.get_coro())
+ coro = task.get_coro()
+ assert coro is not None, "created TaskInfo from a completed Task"
+ super().__init__(id(task), parent_id, task.get_name(), coro)
self._task = weakref.ref(task)
def has_pending_cancellation(self) -> bool:
@@ -2339,10 +2397,11 @@ def create_cancel_scope(
@classmethod
def current_effective_deadline(cls) -> float:
+ if (task := current_task()) is None:
+ return math.inf
+
try:
- cancel_scope = _task_states[
- current_task() # type: ignore[index]
- ].cancel_scope
+ cancel_scope = _task_states[task].cancel_scope
except KeyError:
return math.inf
diff --git a/src/anyio/abc/_tasks.py b/src/anyio/abc/_tasks.py
index 88aecf38..f6e5c40c 100644
--- a/src/anyio/abc/_tasks.py
+++ b/src/anyio/abc/_tasks.py
@@ -40,6 +40,12 @@ class TaskGroup(metaclass=ABCMeta):
:ivar cancel_scope: the cancel scope inherited by all child tasks
:vartype cancel_scope: CancelScope
+
+ .. note:: On asyncio, support for eager task factories is considered to be
+ **experimental**. In particular, they don't follow the usual semantics of new
+ tasks being scheduled on the next iteration of the event loop, and may thus
+ cause unexpected behavior in code that wasn't written with such semantics in
+ mind.
"""
cancel_scope: CancelScope
diff --git a/tests/test_sockets.py b/tests/test_sockets.py
index a1189bb8..8965ea61 100644
--- a/tests/test_sockets.py
+++ b/tests/test_sockets.py
@@ -149,7 +149,7 @@ def _identity(v: _T) -> _T:
)
-@_ignore_win32_resource_warnings # type: ignore[operator]
+@_ignore_win32_resource_warnings
class TestTCPStream:
@pytest.fixture
def server_sock(self, family: AnyIPAddressFamily) -> Iterator[socket.socket]:
diff --git a/tests/test_taskgroups.py b/tests/test_taskgroups.py
index 78ef9983..84101e47 100644
--- a/tests/test_taskgroups.py
+++ b/tests/test_taskgroups.py
@@ -10,7 +10,8 @@
from typing import Any, NoReturn, cast
import pytest
-from exceptiongroup import ExceptionGroup, catch
+from exceptiongroup import catch
+from pytest import FixtureRequest
from pytest_mock import MockerFixture
import anyio
@@ -783,7 +784,7 @@ async def host_agen_fn() -> AsyncGenerator[None, None]:
host_agen = host_agen_fn()
try:
loop = asyncio.get_running_loop()
- await loop.create_task(host_agen.__anext__()) # type: ignore[arg-type]
+ await loop.create_task(host_agen.__anext__())
finally:
await host_agen.aclose()
@@ -1704,3 +1705,24 @@ async def typetest_optional_status(
task_status: TaskStatus[int] = TASK_STATUS_IGNORED,
) -> None:
task_status.started(1)
+
+
+@pytest.mark.skipif(
+ sys.version_info < (3, 12),
+ reason="Eager task factories require Python 3.12",
+)
+@pytest.mark.parametrize("anyio_backend", ["asyncio"])
+async def test_eager_task_factory(request: FixtureRequest) -> None:
+ async def sync_coro() -> None:
+ # This should trigger fetching the task state
+ with CancelScope(): # noqa: ASYNC100
+ pass
+
+ loop = asyncio.get_running_loop()
+ old_task_factory = loop.get_task_factory()
+ loop.set_task_factory(asyncio.eager_task_factory)
+ request.addfinalizer(lambda: loop.set_task_factory(old_task_factory))
+
+ async with create_task_group() as tg:
+ tg.start_soon(sync_coro)
+ tg.cancel_scope.cancel()