Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Made asyncio TaskGroup work with eager task factories #822

Merged
merged 15 commits into from
Nov 30, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ repos:
- id: trailing-whitespace

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.9
rev: v0.7.1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there's 0.8 out now?

hooks:
- id: ruff
args: [--fix, --show-fixes]
- id: ruff-format

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.11.2
rev: v1.13.0
hooks:
- id: mypy
additional_dependencies:
Expand Down
6 changes: 6 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.

**UNRELEASED**

- Improved compatibility with asyncio's eager task factories:

* Updated the annotation of ``TaskInfo.coro`` to allow it to be ``None``
* Updated ``TaskGroup`` to work with asyncio eager task factories

(`#764 <https://github.com/agronholm/anyio/issues/764>`_)
- Fixed a misleading ``ValueError`` in the context of DNS failures
(`#815 <https://github.com/agronholm/anyio/issues/815>`_; PR by @graingert)
- Added the ``wait_readable()`` and ``wait_writable()`` functions which will accept
Expand Down
81 changes: 69 additions & 12 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
Collection,
Coroutine,
Iterable,
Iterator,
MutableMapping,
Sequence,
)
from concurrent.futures import Future
Expand Down Expand Up @@ -351,8 +353,12 @@ def get_callable_name(func: Callable) -> str:

def _task_started(task: asyncio.Task) -> bool:
"""Return ``True`` if the task has been started and has not finished."""
# The task coro should never be None here, as we never add finished tasks to the
# task list
coro = task.get_coro()
assert coro is not None
try:
return getcoroutinestate(task.get_coro()) in (CORO_RUNNING, CORO_SUSPENDED)
return getcoroutinestate(coro) in (CORO_RUNNING, CORO_SUSPENDED)
except AttributeError:
# task coro is async_genenerator_asend https://bugs.python.org/issue37771
raise Exception(f"Cannot determine if task {task} has started or not") from None
Expand Down Expand Up @@ -409,8 +415,10 @@ def __enter__(self) -> CancelScope:
self._parent_scope = task_state.cancel_scope
task_state.cancel_scope = self
if self._parent_scope is not None:
# If using an eager task factory, the parent scope may not even contain
# the host task
self._parent_scope._child_scopes.add(self)
self._parent_scope._tasks.remove(host_task)
self._parent_scope._tasks.discard(host_task)

self._timeout()
self._active = True
Expand Down Expand Up @@ -667,7 +675,45 @@ def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
self.cancel_scope = cancel_scope


_task_states: WeakKeyDictionary[asyncio.Task, TaskState] = WeakKeyDictionary()
class TaskStateStore(MutableMapping["Awaitable[Any] | asyncio.Task", TaskState]):
def __init__(self) -> None:
self._task_states = WeakKeyDictionary[asyncio.Task, TaskState]()
self._preliminary_task_states: dict[Awaitable[Any], TaskState] = {}

def __getitem__(self, key: Awaitable[Any] | asyncio.Task, /) -> TaskState:
assert isinstance(key, asyncio.Task)
try:
return self._task_states[key]
except KeyError:
if coro := key.get_coro():
if state := self._preliminary_task_states.get(coro):
return state

raise KeyError(key)

def __setitem__(
self, key: asyncio.Task | Awaitable[Any], value: TaskState, /
) -> None:
if isinstance(key, asyncio.Task):
self._task_states[key] = value
else:
self._preliminary_task_states[key] = value

def __delitem__(self, key: asyncio.Task | Awaitable[Any], /) -> None:
if isinstance(key, asyncio.Task):
del self._task_states[key]
else:
del self._preliminary_task_states[key]

def __len__(self) -> int:
return len(self._task_states) + len(self._preliminary_task_states)

def __iter__(self) -> Iterator[Awaitable[Any] | asyncio.Task]:
yield from self._task_states
yield from self._preliminary_task_states


_task_states = TaskStateStore()


#
Expand Down Expand Up @@ -787,7 +833,7 @@ def _spawn(
task_status_future: asyncio.Future | None = None,
) -> asyncio.Task:
def task_done(_task: asyncio.Task) -> None:
task_state = _task_states[_task]
# task_state = _task_states[_task]
assert task_state.cancel_scope is not None
assert _task in task_state.cancel_scope._tasks
task_state.cancel_scope._tasks.remove(_task)
Expand Down Expand Up @@ -844,16 +890,26 @@ def task_done(_task: asyncio.Task) -> None:
f"the return value ({coro!r}) is not a coroutine object"
)

name = get_callable_name(func) if name is None else str(name)
task = create_task(coro, name=name)
task.add_done_callback(task_done)

# Make the spawned task inherit the task group's cancel scope
_task_states[task] = TaskState(
_task_states[coro] = task_state = TaskState(
parent_id=parent_id, cancel_scope=self.cancel_scope
)
name = get_callable_name(func) if name is None else str(name)
try:
task = create_task(coro, name=name)
finally:
del _task_states[coro]

_task_states[task] = task_state
self.cancel_scope._tasks.add(task)
self._tasks.add(task)

if task.done():
# This can happen with eager task factories
task_done(task)
else:
task.add_done_callback(task_done)

return task

def start_soon(
Expand Down Expand Up @@ -2339,10 +2395,11 @@ def create_cancel_scope(

@classmethod
def current_effective_deadline(cls) -> float:
if (task := current_task()) is None:
return math.inf

try:
cancel_scope = _task_states[
current_task() # type: ignore[index]
].cancel_scope
cancel_scope = _task_states[task].cancel_scope
except KeyError:
return math.inf

Expand Down
4 changes: 2 additions & 2 deletions src/anyio/_core/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ def __init__(
id: int,
parent_id: int | None,
name: str | None,
coro: Generator[Any, Any, Any] | Awaitable[Any],
coro: Generator[Any, Any, Any] | Awaitable[Any] | None,
):
func = get_current_task
self._name = f"{func.__module__}.{func.__qualname__}"
self.id: int = id
self.parent_id: int | None = parent_id
self.name: str | None = name
self.coro: Generator[Any, Any, Any] | Awaitable[Any] = coro
self.coro: Generator[Any, Any, Any] | Awaitable[Any] | None = coro

def __eq__(self, other: object) -> bool:
if isinstance(other, TaskInfo):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def _identity(v: _T) -> _T:
)


@_ignore_win32_resource_warnings # type: ignore[operator]
@_ignore_win32_resource_warnings
class TestTCPStream:
@pytest.fixture
def server_sock(self, family: AnyIPAddressFamily) -> Iterator[socket.socket]:
Expand Down
26 changes: 24 additions & 2 deletions tests/test_taskgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from typing import Any, NoReturn, cast

import pytest
from exceptiongroup import ExceptionGroup, catch
from exceptiongroup import catch
from pytest import FixtureRequest
from pytest_mock import MockerFixture

import anyio
Expand Down Expand Up @@ -783,7 +784,7 @@ async def host_agen_fn() -> AsyncGenerator[None, None]:
host_agen = host_agen_fn()
try:
loop = asyncio.get_running_loop()
await loop.create_task(host_agen.__anext__()) # type: ignore[arg-type]
await loop.create_task(host_agen.__anext__())
finally:
await host_agen.aclose()

Expand Down Expand Up @@ -1704,3 +1705,24 @@ async def typetest_optional_status(
task_status: TaskStatus[int] = TASK_STATUS_IGNORED,
) -> None:
task_status.started(1)


@pytest.mark.skipif(
sys.version_info < (3, 12),
reason="Eager task factories require Python 3.12",
)
@pytest.mark.parametrize("anyio_backend", ["asyncio"])
async def test_eager_task_factory(request: FixtureRequest) -> None:
async def sync_coro() -> None:
# This should trigger fetching the task state
with CancelScope(): # noqa: ASYNC100
pass

loop = asyncio.get_running_loop()
old_task_factory = loop.get_task_factory()
loop.set_task_factory(asyncio.eager_task_factory)
request.addfinalizer(lambda: loop.set_task_factory(old_task_factory))

async with create_task_group() as tg:
tg.start_soon(sync_coro)
tg.cancel_scope.cancel()