Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix broken support for Callable[..., Awaitable] #567

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,6 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
create UNIX datagram sockets (PR by Jean Hominal)
- Improved type annotations:

- Several functions and methods that previously only accepted coroutines as the return
type of the callable have been amended to accept any awaitables:

- ``anyio.run()``
- ``anyio.from_thread.run()``
- ``TaskGroup.start_soon()``
- ``TaskGroup.start()``
- ``BlockingPortal.call()``
- ``BlockingPortal.start_task_soon()``
- ``BlockingPortal.start_task()``

- The ``TaskStatus`` class is now a generic protocol, and should be parametrized to
indicate the type of the value passed to ``task_status.started()``
- The ``Listener`` class is now covariant in its stream type
Expand All @@ -50,6 +39,18 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
``TLSStream.wrap()`` being inadvertently set on Python 3.11.3 and 3.10.11
- Fixed ``CancelScope`` to properly handle asyncio task uncancellation on Python 3.11
(PR by Nikolay Bryskin)
- Several functions and methods that previously only accepted coroutines as the return
type of the callable can now accept any awaitables:

- ``anyio.run()``
- ``anyio.from_thread.run()``
- ``TaskGroup.start_soon()``
- ``TaskGroup.start()``
- ``BlockingPortal.call()``
- ``BlockingPortal.start_task_soon()``
- ``BlockingPortal.start_task()``

(PRs by Alex Grönholm and Ganden Schaffner)

**3.6.1**

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ dependencies = [
"exceptiongroup; python_version < '3.11'",
"idna >= 2.8",
"sniffio >= 1.1",
"typing_extensions; python_version < '3.8'",
"typing_extensions >= 3.10; python_version < '3.10'",
]
dynamic = ["version"]

Expand Down
50 changes: 38 additions & 12 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from asyncio import run as native_run
from asyncio.base_events import _run_until_complete_cb # type: ignore[attr-defined]
from collections import OrderedDict, deque
from collections.abc import AsyncIterator, Iterable
from collections.abc import AsyncIterator, Awaitable, Iterable
from concurrent.futures import Future
from contextvars import Context, copy_context
from dataclasses import dataclass
Expand All @@ -36,7 +36,6 @@
IO,
Any,
AsyncGenerator,
Awaitable,
Callable,
Collection,
ContextManager,
Expand Down Expand Up @@ -83,6 +82,11 @@
if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup, ExceptionGroup

if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec

if sys.version_info >= (3, 8):

def get_coro(task: asyncio.Task) -> Generator | Awaitable[Any]:
Expand All @@ -96,6 +100,7 @@ def get_coro(task: asyncio.Task) -> Generator | Awaitable[Any]:

T_Retval = TypeVar("T_Retval")
T_contra = TypeVar("T_contra", contravariant=True)
P = ParamSpec("P")

# Check whether there is native support for task names in asyncio (3.8+)
_native_task_names = hasattr(asyncio.Task, "get_name")
Expand Down Expand Up @@ -140,6 +145,31 @@ def get_callable_name(func: Callable) -> str:
return ".".join([x for x in (module, qualname) if x])


def ensure_returns_coro(
func: Callable[P, Awaitable[T_Retval]]
) -> Callable[P, Coroutine[Any, Any, T_Retval]]:
@wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, T_Retval]:
awaitable = func(*args, **kwargs)
# Check the common case first.
if asyncio.iscoroutine(awaitable):
return awaitable
elif not isinstance(awaitable, Awaitable):
# The user violated the type annotations.
raise TypeError(
f"Expected an async function, but {func} appears to be synchronous"
)
else:

@wraps(func)
async def inner_wrapper() -> T_Retval:
return await awaitable

return inner_wrapper()

return wrapper


#
# Event loop
#
Expand Down Expand Up @@ -611,11 +641,7 @@ def task_done(_task: asyncio.Task) -> None:
else:
parent_id = id(self.cancel_scope._host_task)

coro = func(*args, **kwargs)
if not asyncio.iscoroutine(coro):
raise TypeError(
f"Expected an async function, but {func} appears to be synchronous"
)
coro = ensure_returns_coro(func)(*args, **kwargs)

foreign_coro = not hasattr(coro, "cr_frame") and not hasattr(coro, "gi_frame")
if foreign_coro or sys.version_info < (3, 8):
Expand Down Expand Up @@ -1784,9 +1810,9 @@ async def _run_tests_and_fixtures(
],
) -> None:
with receive_stream:
async for coro, future in receive_stream:
async for awaitable, future in receive_stream:
try:
retval = await coro
retval = await awaitable
except BaseException as exc:
if not future.cancelled():
future.set_exception(exc)
Expand All @@ -1803,9 +1829,9 @@ async def _call_in_runner_task(
self._run_tests_and_fixtures(receive_stream)
)

coro = func(*args, **kwargs)
awaitable = func(*args, **kwargs)
future: asyncio.Future[T_Retval] = self._loop.create_future()
self._send_stream.send_nowait((coro, future))
self._send_stream.send_nowait((awaitable, future))
return await future

def close(self) -> None:
Expand Down Expand Up @@ -2051,7 +2077,7 @@ def run_async_from_thread(
) -> T_Retval:
loop = cast(AbstractEventLoop, token)
f: concurrent.futures.Future[T_Retval] = asyncio.run_coroutine_threadsafe(
func(*args), loop
ensure_returns_coro(func)(*args), loop
)
return f.result()

Expand Down
60 changes: 47 additions & 13 deletions src/anyio/_backends/_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import array
import math
import socket
from collections.abc import AsyncIterator, Iterable
import sys
from collections.abc import AsyncIterator, Awaitable, Coroutine, Iterable
from concurrent.futures import Future
from dataclasses import dataclass
from functools import partial
from functools import partial, wraps
from io import IOBase
from os import PathLike
from signal import Signals
Expand All @@ -16,11 +17,9 @@
IO,
Any,
AsyncGenerator,
Awaitable,
Callable,
Collection,
ContextManager,
Coroutine,
Generic,
Mapping,
NoReturn,
Expand Down Expand Up @@ -61,9 +60,38 @@
from ..abc._eventloop import AsyncBackend
from ..streams.memory import MemoryObjectSendStream

T = TypeVar("T")
if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec

T_Retval = TypeVar("T_Retval")
T_SockAddr = TypeVar("T_SockAddr", str, IPSockAddrType)
P = ParamSpec("P")


def ensure_returns_coro(
func: Callable[P, Awaitable[T_Retval]]
) -> Callable[P, Coroutine[Any, Any, T_Retval]]:
@wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, T_Retval]:
awaitable = func(*args, **kwargs)
# Check the common case first.
if isinstance(awaitable, Coroutine):
return awaitable
elif not isinstance(awaitable, Awaitable):
# The user violated the type annotations. Still, we should pass this on to
# Trio so it can raise with an appropriate message.
return awaitable
else:

@wraps(func)
async def inner_wrapper() -> T_Retval:
return await awaitable

return inner_wrapper()

return wrapper


#
Expand Down Expand Up @@ -154,13 +182,15 @@ async def __aexit__(
finally:
self._active = False

def start_soon(self, func: Callable, *args: object, name: object = None) -> None:
def start_soon(
self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
) -> None:
if not self._active:
raise RuntimeError(
"This task group is not active; no new tasks can be started."
)

self._nursery.start_soon(func, *args, name=name)
self._nursery.start_soon(ensure_returns_coro(func), *args, name=name)

async def start(
self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
Expand All @@ -170,7 +200,7 @@ async def start(
"This task group is not active; no new tasks can be started."
)

return await self._nursery.start(func, *args, name=name)
return await self._nursery.start(ensure_returns_coro(func), *args, name=name)


#
Expand Down Expand Up @@ -710,15 +740,17 @@ def __init__(self, **options: Any) -> None:
from queue import Queue

self._call_queue: Queue[Callable[..., object]] = Queue()
self._send_stream: MemoryObjectSendStream | None = None
self._send_stream: MemoryObjectSendStream[
tuple[Awaitable[Any], list[Outcome]]
] | None = None
self._options = options

async def _run_tests_and_fixtures(self) -> None:
self._send_stream, receive_stream = create_memory_object_stream(1)
with receive_stream:
async for coro, outcome_holder in receive_stream:
async for awaitable, outcome_holder in receive_stream:
try:
retval = await coro
retval = await awaitable
except BaseException as exc:
outcome_holder.append(Error(exc))
else:
Expand Down Expand Up @@ -793,7 +825,7 @@ def run(
kwargs: dict[str, Any],
options: dict[str, Any],
) -> T_Retval:
return trio.run(func, *args)
return trio.run(ensure_returns_coro(func), *args)

@classmethod
def current_token(cls) -> object:
Expand Down Expand Up @@ -871,7 +903,9 @@ def run_async_from_thread(
args: tuple[Any, ...],
token: object,
) -> T_Retval:
return trio.from_thread.run(func, *args, trio_token=cast(TrioToken, token))
return trio.from_thread.run(
ensure_returns_coro(func), *args, trio_token=cast(TrioToken, token)
)

@classmethod
def run_sync_from_thread(
Expand Down
4 changes: 2 additions & 2 deletions src/anyio/abc/_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import sys
from abc import ABCMeta, abstractmethod
from collections.abc import Awaitable, Callable, Coroutine
from collections.abc import Awaitable, Callable
from types import TracebackType
from typing import TYPE_CHECKING, Any, TypeVar, overload

Expand Down Expand Up @@ -48,7 +48,7 @@ class TaskGroup(metaclass=ABCMeta):
@abstractmethod
def start_soon(
self,
func: Callable[..., Coroutine[Any, Any, Any]],
func: Callable[..., Awaitable[Any]],
*args: object,
name: object = None,
) -> None:
Expand Down
26 changes: 26 additions & 0 deletions tests/misc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from __future__ import annotations

import sys
from typing import Any, Awaitable, Callable, Generator, TypeVar

if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec

T = TypeVar("T")
P = ParamSpec("P")


def return_non_coro_awaitable(
func: Callable[P, Awaitable[T]]
) -> Callable[P, Awaitable[T]]:
class Wrapper:
def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None:
self.args = args
self.kwargs = kwargs

def __await__(self) -> Generator[Any, None, T]:
return func(*self.args, **self.kwargs).__await__()

return Wrapper
21 changes: 13 additions & 8 deletions tests/test_eventloop.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from __future__ import annotations

import asyncio
import math
import sys
from typing import Any

import pytest
from pytest_mock.plugin import MockerFixture

from anyio import run, sleep_forever, sleep_until

from .misc import return_non_coro_awaitable

if sys.version_info < (3, 8):
from mock import AsyncMock
else:
Expand Down Expand Up @@ -41,11 +43,14 @@ async def test_sleep_forever(fake_sleep: AsyncMock) -> None:
fake_sleep.assert_called_once_with(math.inf)


def test_run_task() -> None:
"""Test that anyio.run() on asyncio will work with a callable returning a Future."""

async def async_add(x: int, y: int) -> int:
return x + y
def test_run_non_corofunc(
anyio_backend_name: str, anyio_backend_options: dict[str, Any]
) -> None:
@return_non_coro_awaitable
async def func() -> str:
return "foo"

result = run(asyncio.create_task, async_add(1, 2), backend="asyncio")
assert result == 3
result = run(
func, backend=anyio_backend_name, backend_options=anyio_backend_options
)
assert result == "foo"
Loading