diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 191dc392..bfaa73a9 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -14,17 +14,6 @@ This library adheres to `Semantic Versioning 2.0 `_. create UNIX datagram sockets (PR by Jean Hominal) - Improved type annotations: - - Several functions and methods that previously only accepted coroutines as the return - type of the callable have been amended to accept any awaitables: - - - ``anyio.run()`` - - ``anyio.from_thread.run()`` - - ``TaskGroup.start_soon()`` - - ``TaskGroup.start()`` - - ``BlockingPortal.call()`` - - ``BlockingPortal.start_task_soon()`` - - ``BlockingPortal.start_task()`` - - The ``TaskStatus`` class is now a generic protocol, and should be parametrized to indicate the type of the value passed to ``task_status.started()`` - The ``Listener`` class is now covariant in its stream type @@ -50,6 +39,18 @@ This library adheres to `Semantic Versioning 2.0 `_. ``TLSStream.wrap()`` being inadvertently set on Python 3.11.3 and 3.10.11 - Fixed ``CancelScope`` to properly handle asyncio task uncancellation on Python 3.11 (PR by Nikolay Bryskin) +- Several functions and methods that previously only accepted coroutines as the return + type of the callable can now accept any awaitables: + + - ``anyio.run()`` + - ``anyio.from_thread.run()`` + - ``TaskGroup.start_soon()`` + - ``TaskGroup.start()`` + - ``BlockingPortal.call()`` + - ``BlockingPortal.start_task_soon()`` + - ``BlockingPortal.start_task()`` + + (PRs by Alex Grönholm and Ganden Schaffner) **3.6.1** diff --git a/pyproject.toml b/pyproject.toml index ea6bc4f8..67f43361 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ dependencies = [ "exceptiongroup; python_version < '3.11'", "idna >= 2.8", "sniffio >= 1.1", - "typing_extensions; python_version < '3.8'", + "typing_extensions >= 3.10; python_version < '3.10'", ] dynamic = ["version"] diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index df8672da..7817dba2 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -19,7 +19,7 @@ from asyncio import run as native_run from asyncio.base_events import _run_until_complete_cb # type: ignore[attr-defined] from collections import OrderedDict, deque -from collections.abc import AsyncIterator, Iterable +from collections.abc import AsyncIterator, Awaitable, Iterable from concurrent.futures import Future from contextvars import Context, copy_context from dataclasses import dataclass @@ -36,7 +36,6 @@ IO, Any, AsyncGenerator, - Awaitable, Callable, Collection, ContextManager, @@ -83,6 +82,11 @@ if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup, ExceptionGroup +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec + if sys.version_info >= (3, 8): def get_coro(task: asyncio.Task) -> Generator | Awaitable[Any]: @@ -96,6 +100,7 @@ def get_coro(task: asyncio.Task) -> Generator | Awaitable[Any]: T_Retval = TypeVar("T_Retval") T_contra = TypeVar("T_contra", contravariant=True) +P = ParamSpec("P") # Check whether there is native support for task names in asyncio (3.8+) _native_task_names = hasattr(asyncio.Task, "get_name") @@ -140,6 +145,31 @@ def get_callable_name(func: Callable) -> str: return ".".join([x for x in (module, qualname) if x]) +def ensure_returns_coro( + func: Callable[P, Awaitable[T_Retval]] +) -> Callable[P, Coroutine[Any, Any, T_Retval]]: + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, T_Retval]: + awaitable = func(*args, **kwargs) + # Check the common case first. + if asyncio.iscoroutine(awaitable): + return awaitable + elif not isinstance(awaitable, Awaitable): + # The user violated the type annotations. + raise TypeError( + f"Expected an async function, but {func} appears to be synchronous" + ) + else: + + @wraps(func) + async def inner_wrapper() -> T_Retval: + return await awaitable + + return inner_wrapper() + + return wrapper + + # # Event loop # @@ -611,11 +641,7 @@ def task_done(_task: asyncio.Task) -> None: else: parent_id = id(self.cancel_scope._host_task) - coro = func(*args, **kwargs) - if not asyncio.iscoroutine(coro): - raise TypeError( - f"Expected an async function, but {func} appears to be synchronous" - ) + coro = ensure_returns_coro(func)(*args, **kwargs) foreign_coro = not hasattr(coro, "cr_frame") and not hasattr(coro, "gi_frame") if foreign_coro or sys.version_info < (3, 8): @@ -1784,9 +1810,9 @@ async def _run_tests_and_fixtures( ], ) -> None: with receive_stream: - async for coro, future in receive_stream: + async for awaitable, future in receive_stream: try: - retval = await coro + retval = await awaitable except BaseException as exc: if not future.cancelled(): future.set_exception(exc) @@ -1803,9 +1829,9 @@ async def _call_in_runner_task( self._run_tests_and_fixtures(receive_stream) ) - coro = func(*args, **kwargs) + awaitable = func(*args, **kwargs) future: asyncio.Future[T_Retval] = self._loop.create_future() - self._send_stream.send_nowait((coro, future)) + self._send_stream.send_nowait((awaitable, future)) return await future def close(self) -> None: @@ -2051,7 +2077,7 @@ def run_async_from_thread( ) -> T_Retval: loop = cast(AbstractEventLoop, token) f: concurrent.futures.Future[T_Retval] = asyncio.run_coroutine_threadsafe( - func(*args), loop + ensure_returns_coro(func)(*args), loop ) return f.result() diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py index 48b8b40b..6dd79f47 100644 --- a/src/anyio/_backends/_trio.py +++ b/src/anyio/_backends/_trio.py @@ -3,10 +3,11 @@ import array import math import socket -from collections.abc import AsyncIterator, Iterable +import sys +from collections.abc import AsyncIterator, Awaitable, Coroutine, Iterable from concurrent.futures import Future from dataclasses import dataclass -from functools import partial +from functools import partial, wraps from io import IOBase from os import PathLike from signal import Signals @@ -16,11 +17,9 @@ IO, Any, AsyncGenerator, - Awaitable, Callable, Collection, ContextManager, - Coroutine, Generic, Mapping, NoReturn, @@ -61,9 +60,38 @@ from ..abc._eventloop import AsyncBackend from ..streams.memory import MemoryObjectSendStream -T = TypeVar("T") +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec + T_Retval = TypeVar("T_Retval") T_SockAddr = TypeVar("T_SockAddr", str, IPSockAddrType) +P = ParamSpec("P") + + +def ensure_returns_coro( + func: Callable[P, Awaitable[T_Retval]] +) -> Callable[P, Coroutine[Any, Any, T_Retval]]: + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, T_Retval]: + awaitable = func(*args, **kwargs) + # Check the common case first. + if isinstance(awaitable, Coroutine): + return awaitable + elif not isinstance(awaitable, Awaitable): + # The user violated the type annotations. Still, we should pass this on to + # Trio so it can raise with an appropriate message. + return awaitable + else: + + @wraps(func) + async def inner_wrapper() -> T_Retval: + return await awaitable + + return inner_wrapper() + + return wrapper # @@ -154,13 +182,15 @@ async def __aexit__( finally: self._active = False - def start_soon(self, func: Callable, *args: object, name: object = None) -> None: + def start_soon( + self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None + ) -> None: if not self._active: raise RuntimeError( "This task group is not active; no new tasks can be started." ) - self._nursery.start_soon(func, *args, name=name) + self._nursery.start_soon(ensure_returns_coro(func), *args, name=name) async def start( self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None @@ -170,7 +200,7 @@ async def start( "This task group is not active; no new tasks can be started." ) - return await self._nursery.start(func, *args, name=name) + return await self._nursery.start(ensure_returns_coro(func), *args, name=name) # @@ -710,15 +740,17 @@ def __init__(self, **options: Any) -> None: from queue import Queue self._call_queue: Queue[Callable[..., object]] = Queue() - self._send_stream: MemoryObjectSendStream | None = None + self._send_stream: MemoryObjectSendStream[ + tuple[Awaitable[Any], list[Outcome]] + ] | None = None self._options = options async def _run_tests_and_fixtures(self) -> None: self._send_stream, receive_stream = create_memory_object_stream(1) with receive_stream: - async for coro, outcome_holder in receive_stream: + async for awaitable, outcome_holder in receive_stream: try: - retval = await coro + retval = await awaitable except BaseException as exc: outcome_holder.append(Error(exc)) else: @@ -793,7 +825,7 @@ def run( kwargs: dict[str, Any], options: dict[str, Any], ) -> T_Retval: - return trio.run(func, *args) + return trio.run(ensure_returns_coro(func), *args) @classmethod def current_token(cls) -> object: @@ -871,7 +903,9 @@ def run_async_from_thread( args: tuple[Any, ...], token: object, ) -> T_Retval: - return trio.from_thread.run(func, *args, trio_token=cast(TrioToken, token)) + return trio.from_thread.run( + ensure_returns_coro(func), *args, trio_token=cast(TrioToken, token) + ) @classmethod def run_sync_from_thread( diff --git a/src/anyio/abc/_tasks.py b/src/anyio/abc/_tasks.py index 3c8a06bc..551d3983 100644 --- a/src/anyio/abc/_tasks.py +++ b/src/anyio/abc/_tasks.py @@ -2,7 +2,7 @@ import sys from abc import ABCMeta, abstractmethod -from collections.abc import Awaitable, Callable, Coroutine +from collections.abc import Awaitable, Callable from types import TracebackType from typing import TYPE_CHECKING, Any, TypeVar, overload @@ -48,7 +48,7 @@ class TaskGroup(metaclass=ABCMeta): @abstractmethod def start_soon( self, - func: Callable[..., Coroutine[Any, Any, Any]], + func: Callable[..., Awaitable[Any]], *args: object, name: object = None, ) -> None: diff --git a/tests/misc.py b/tests/misc.py new file mode 100644 index 00000000..5407085a --- /dev/null +++ b/tests/misc.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import sys +from typing import Any, Awaitable, Callable, Generator, TypeVar + +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec + +T = TypeVar("T") +P = ParamSpec("P") + + +def return_non_coro_awaitable( + func: Callable[P, Awaitable[T]] +) -> Callable[P, Awaitable[T]]: + class Wrapper: + def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None: + self.args = args + self.kwargs = kwargs + + def __await__(self) -> Generator[Any, None, T]: + return func(*self.args, **self.kwargs).__await__() + + return Wrapper diff --git a/tests/test_eventloop.py b/tests/test_eventloop.py index 9bc79683..e01b7e25 100644 --- a/tests/test_eventloop.py +++ b/tests/test_eventloop.py @@ -1,14 +1,16 @@ from __future__ import annotations -import asyncio import math import sys +from typing import Any import pytest from pytest_mock.plugin import MockerFixture from anyio import run, sleep_forever, sleep_until +from .misc import return_non_coro_awaitable + if sys.version_info < (3, 8): from mock import AsyncMock else: @@ -41,11 +43,14 @@ async def test_sleep_forever(fake_sleep: AsyncMock) -> None: fake_sleep.assert_called_once_with(math.inf) -def test_run_task() -> None: - """Test that anyio.run() on asyncio will work with a callable returning a Future.""" - - async def async_add(x: int, y: int) -> int: - return x + y +def test_run_non_corofunc( + anyio_backend_name: str, anyio_backend_options: dict[str, Any] +) -> None: + @return_non_coro_awaitable + async def func() -> str: + return "foo" - result = run(asyncio.create_task, async_add(1, 2), backend="asyncio") - assert result == 3 + result = run( + func, backend=anyio_backend_name, backend_options=anyio_backend_options + ) + assert result == "foo" diff --git a/tests/test_from_thread.py b/tests/test_from_thread.py index 93886f53..6c32a701 100644 --- a/tests/test_from_thread.py +++ b/tests/test_from_thread.py @@ -27,6 +27,8 @@ from anyio.from_thread import BlockingPortal, start_blocking_portal from anyio.lowlevel import checkpoint +from .misc import return_non_coro_awaitable + if sys.version_info >= (3, 8): from typing import Literal else: @@ -39,10 +41,12 @@ async def async_add(a: int, b: int) -> int: assert threading.current_thread() is threading.main_thread() + await checkpoint() return a + b async def asyncgen_add(a: int, b: int) -> AsyncGenerator[int, Any]: + await checkpoint() yield a + b @@ -68,6 +72,12 @@ async def test_run_corofunc_from_thread(self) -> None: result = await to_thread.run_sync(thread_worker_async, async_add, 1, 2) assert result == 3 + async def test_run_non_corofunc_from_thread(self) -> None: + result = await to_thread.run_sync( + thread_worker_async, return_non_coro_awaitable(async_add), 1, 2 + ) + assert result == 3 + async def test_run_asyncgen_from_thread(self) -> None: gen = asyncgen_add(1, 2) try: @@ -115,13 +125,6 @@ async def test_run_sync_from_thread_exception(self) -> None: exc.match("unsupported operand type") - async def test_run_anyio_async_func_from_thread(self) -> None: - def worker(*args: int) -> Literal[True]: - from_thread.run(sleep, *args) - return True - - assert await to_thread.run_sync(worker, 0) - def test_run_async_from_unclaimed_thread(self) -> None: async def foo() -> None: pass @@ -182,6 +185,13 @@ async def test_call_corofunc(self) -> None: result = await to_thread.run_sync(portal.call, async_add, 1, 2) assert result == 3 + async def test_call_non_corofunc(self) -> None: + async with BlockingPortal() as portal: + result = await to_thread.run_sync( + portal.call, return_non_coro_awaitable(async_add), 1, 2 + ) + assert result == 3 + async def test_call_anext(self) -> None: gen = asyncgen_add(1, 2) try: @@ -298,6 +308,17 @@ async def event_waiter() -> Literal["test"]: portal.call(event2.wait) assert future.result() == "test" + def test_start_task_soon_non_corofunc( + self, anyio_backend_name: str, anyio_backend_options: dict[str, Any] + ) -> None: + @return_non_coro_awaitable + async def taskfunc() -> Literal["test"]: + return "test" + + with start_blocking_portal(anyio_backend_name, anyio_backend_options) as portal: + future = portal.start_task_soon(taskfunc) + assert future.result() == "test" + def test_start_task_soon_cancel_later( self, anyio_backend_name: str, anyio_backend_options: dict[str, Any] ) -> None: @@ -397,22 +418,34 @@ async def run_in_context() -> AsyncGenerator[None, None]: def test_start_no_value( self, anyio_backend_name: str, anyio_backend_options: dict[str, Any] ) -> None: - def taskfunc(*, task_status: TaskStatus) -> None: + async def taskfunc(*, task_status: TaskStatus) -> None: task_status.started() with start_blocking_portal(anyio_backend_name, anyio_backend_options) as portal: - future, value = portal.start_task(taskfunc) # type: ignore[arg-type] + future, value = portal.start_task(taskfunc) assert value is None assert future.result() is None def test_start_with_value( self, anyio_backend_name: str, anyio_backend_options: dict[str, Any] ) -> None: - def taskfunc(*, task_status: TaskStatus) -> None: + async def taskfunc(*, task_status: TaskStatus) -> None: + task_status.started("foo") + + with start_blocking_portal(anyio_backend_name, anyio_backend_options) as portal: + future, value = portal.start_task(taskfunc) + assert value == "foo" + assert future.result() is None + + def test_start_non_corofunc( + self, anyio_backend_name: str, anyio_backend_options: dict[str, Any] + ) -> None: + @return_non_coro_awaitable + async def taskfunc(*, task_status: TaskStatus) -> None: task_status.started("foo") with start_blocking_portal(anyio_backend_name, anyio_backend_options) as portal: - future, value = portal.start_task(taskfunc) # type: ignore[arg-type] + future, value = portal.start_task(taskfunc) assert value == "foo" assert future.result() is None @@ -442,23 +475,21 @@ def taskfunc(*, task_status: TaskStatus) -> NoReturn: def test_start_no_started_call( self, anyio_backend_name: str, anyio_backend_options: dict[str, Any] ) -> None: - def taskfunc(*, task_status: TaskStatus) -> None: + async def taskfunc(*, task_status: TaskStatus) -> None: pass with start_blocking_portal(anyio_backend_name, anyio_backend_options) as portal: with pytest.raises(RuntimeError, match="Task exited"): - portal.start_task(taskfunc) # type: ignore[arg-type] + portal.start_task(taskfunc) def test_start_with_name( self, anyio_backend_name: str, anyio_backend_options: dict[str, Any] ) -> None: - def taskfunc(*, task_status: TaskStatus) -> None: + async def taskfunc(*, task_status: TaskStatus) -> None: task_status.started(get_current_task().name) with start_blocking_portal(anyio_backend_name, anyio_backend_options) as portal: - future, start_value = portal.start_task( - taskfunc, name="testname" # type: ignore[arg-type] - ) + future, start_value = portal.start_task(taskfunc, name="testname") assert start_value == "testname" def test_contextvar_propagation_sync( diff --git a/tests/test_taskgroups.py b/tests/test_taskgroups.py index a5fc3c49..f7f72127 100644 --- a/tests/test_taskgroups.py +++ b/tests/test_taskgroups.py @@ -26,6 +26,8 @@ from anyio.abc import TaskGroup, TaskStatus from anyio.lowlevel import checkpoint +from .misc import return_non_coro_awaitable + if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup @@ -1094,6 +1096,29 @@ async def starter_task() -> None: assert permanent_parent_id == root_task_id +async def test_start_soon_non_corofunc() -> None: + finished = False + + @return_non_coro_awaitable + async def taskfunc() -> None: + nonlocal finished + finished = True + + async with create_task_group() as tg: + tg.start_soon(taskfunc) + assert finished + + +async def test_start_non_corofunc() -> None: + @return_non_coro_awaitable + async def taskfunc(*, task_status: TaskStatus[str]) -> None: + task_status.started("foo") + + async with create_task_group() as tg: + value = await tg.start(taskfunc) + assert value == "foo" + + @pytest.mark.skipif( sys.version_info < (3, 11), reason="Task uncancelling is only supported on Python 3.11",