Skip to content

Commit

Permalink
Added support for asyncio eager task factories
Browse files Browse the repository at this point in the history
Fixes #764.
  • Loading branch information
agronholm committed Nov 10, 2024
1 parent bdf09a6 commit 1807bed
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 8 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ repos:
- id: trailing-whitespace

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.9
rev: v0.7.1
hooks:
- id: ruff
args: [--fix, --show-fixes]
- id: ruff-format

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.11.2
rev: v1.13.0
hooks:
- id: mypy
additional_dependencies:
Expand Down
6 changes: 6 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.

**UNRELEASED**

- Added support for asyncio's eager task factories:

* Updated the annotation of ``TaskInfo.coro`` to allow it to be ``None``
* Updated ``TaskGroup`` to work with asyncio eager task factories

(`#764 <https://github.com/agronholm/anyio/issues/764>`_)
- Fixed a misleading ``ValueError`` in the context of DNS failures
(`#815 <https://github.com/agronholm/anyio/issues/815>`_; PR by @graingert)

Expand Down
8 changes: 6 additions & 2 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,8 +347,12 @@ def get_callable_name(func: Callable) -> str:

def _task_started(task: asyncio.Task) -> bool:
"""Return ``True`` if the task has been started and has not finished."""
# The task coro should never be None here, as we never add finished tasks to the
# task list
coro = task.get_coro()
assert coro is not None
try:
return getcoroutinestate(task.get_coro()) in (CORO_RUNNING, CORO_SUSPENDED)
return getcoroutinestate(coro) in (CORO_RUNNING, CORO_SUSPENDED)
except AttributeError:
# task coro is async_genenerator_asend https://bugs.python.org/issue37771
raise Exception(f"Cannot determine if task {task} has started or not") from None
Expand Down Expand Up @@ -842,14 +846,14 @@ def task_done(_task: asyncio.Task) -> None:

name = get_callable_name(func) if name is None else str(name)
task = create_task(coro, name=name)
task.add_done_callback(task_done)

# Make the spawned task inherit the task group's cancel scope
_task_states[task] = TaskState(
parent_id=parent_id, cancel_scope=self.cancel_scope
)
self.cancel_scope._tasks.add(task)
self._tasks.add(task)
task.add_done_callback(task_done)
return task

def start_soon(
Expand Down
4 changes: 2 additions & 2 deletions src/anyio/_core/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ def __init__(
id: int,
parent_id: int | None,
name: str | None,
coro: Generator[Any, Any, Any] | Awaitable[Any],
coro: Generator[Any, Any, Any] | Awaitable[Any] | None,
):
func = get_current_task
self._name = f"{func.__module__}.{func.__qualname__}"
self.id: int = id
self.parent_id: int | None = parent_id
self.name: str | None = name
self.coro: Generator[Any, Any, Any] | Awaitable[Any] = coro
self.coro: Generator[Any, Any, Any] | Awaitable[Any] | None = coro

def __eq__(self, other: object) -> bool:
if isinstance(other, TaskInfo):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def _identity(v: _T) -> _T:
)


@_ignore_win32_resource_warnings # type: ignore[operator]
@_ignore_win32_resource_warnings
class TestTCPStream:
@pytest.fixture
def server_sock(self, family: AnyIPAddressFamily) -> Iterator[socket.socket]:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_taskgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,7 @@ async def host_agen_fn() -> AsyncGenerator[None, None]:
host_agen = host_agen_fn()
try:
loop = asyncio.get_running_loop()
await loop.create_task(host_agen.__anext__()) # type: ignore[arg-type]
await loop.create_task(host_agen.__anext__())
finally:
await host_agen.aclose()

Expand Down

0 comments on commit 1807bed

Please sign in to comment.