diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst
index 191dc392..bfaa73a9 100644
--- a/docs/versionhistory.rst
+++ b/docs/versionhistory.rst
@@ -14,17 +14,6 @@ This library adheres to `Semantic Versioning 2.0 `_.
create UNIX datagram sockets (PR by Jean Hominal)
- Improved type annotations:
- - Several functions and methods that previously only accepted coroutines as the return
- type of the callable have been amended to accept any awaitables:
-
- - ``anyio.run()``
- - ``anyio.from_thread.run()``
- - ``TaskGroup.start_soon()``
- - ``TaskGroup.start()``
- - ``BlockingPortal.call()``
- - ``BlockingPortal.start_task_soon()``
- - ``BlockingPortal.start_task()``
-
- The ``TaskStatus`` class is now a generic protocol, and should be parametrized to
indicate the type of the value passed to ``task_status.started()``
- The ``Listener`` class is now covariant in its stream type
@@ -50,6 +39,18 @@ This library adheres to `Semantic Versioning 2.0 `_.
``TLSStream.wrap()`` being inadvertently set on Python 3.11.3 and 3.10.11
- Fixed ``CancelScope`` to properly handle asyncio task uncancellation on Python 3.11
(PR by Nikolay Bryskin)
+- Several functions and methods that previously only accepted coroutines as the return
+ type of the callable can now accept any awaitables:
+
+ - ``anyio.run()``
+ - ``anyio.from_thread.run()``
+ - ``TaskGroup.start_soon()``
+ - ``TaskGroup.start()``
+ - ``BlockingPortal.call()``
+ - ``BlockingPortal.start_task_soon()``
+ - ``BlockingPortal.start_task()``
+
+ (PRs by Alex Grönholm and Ganden Schaffner)
**3.6.1**
diff --git a/pyproject.toml b/pyproject.toml
index ea6bc4f8..67f43361 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,7 +29,7 @@ dependencies = [
"exceptiongroup; python_version < '3.11'",
"idna >= 2.8",
"sniffio >= 1.1",
- "typing_extensions; python_version < '3.8'",
+ "typing_extensions >= 3.10; python_version < '3.10'",
]
dynamic = ["version"]
diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py
index df8672da..7817dba2 100644
--- a/src/anyio/_backends/_asyncio.py
+++ b/src/anyio/_backends/_asyncio.py
@@ -19,7 +19,7 @@
from asyncio import run as native_run
from asyncio.base_events import _run_until_complete_cb # type: ignore[attr-defined]
from collections import OrderedDict, deque
-from collections.abc import AsyncIterator, Iterable
+from collections.abc import AsyncIterator, Awaitable, Iterable
from concurrent.futures import Future
from contextvars import Context, copy_context
from dataclasses import dataclass
@@ -36,7 +36,6 @@
IO,
Any,
AsyncGenerator,
- Awaitable,
Callable,
Collection,
ContextManager,
@@ -83,6 +82,11 @@
if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup, ExceptionGroup
+if sys.version_info >= (3, 10):
+ from typing import ParamSpec
+else:
+ from typing_extensions import ParamSpec
+
if sys.version_info >= (3, 8):
def get_coro(task: asyncio.Task) -> Generator | Awaitable[Any]:
@@ -96,6 +100,7 @@ def get_coro(task: asyncio.Task) -> Generator | Awaitable[Any]:
T_Retval = TypeVar("T_Retval")
T_contra = TypeVar("T_contra", contravariant=True)
+P = ParamSpec("P")
# Check whether there is native support for task names in asyncio (3.8+)
_native_task_names = hasattr(asyncio.Task, "get_name")
@@ -140,6 +145,31 @@ def get_callable_name(func: Callable) -> str:
return ".".join([x for x in (module, qualname) if x])
+def ensure_returns_coro(
+ func: Callable[P, Awaitable[T_Retval]]
+) -> Callable[P, Coroutine[Any, Any, T_Retval]]:
+ @wraps(func)
+ def wrapper(*args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, T_Retval]:
+ awaitable = func(*args, **kwargs)
+ # Check the common case first.
+ if asyncio.iscoroutine(awaitable):
+ return awaitable
+ elif not isinstance(awaitable, Awaitable):
+ # The user violated the type annotations.
+ raise TypeError(
+ f"Expected an async function, but {func} appears to be synchronous"
+ )
+ else:
+
+ @wraps(func)
+ async def inner_wrapper() -> T_Retval:
+ return await awaitable
+
+ return inner_wrapper()
+
+ return wrapper
+
+
#
# Event loop
#
@@ -611,11 +641,7 @@ def task_done(_task: asyncio.Task) -> None:
else:
parent_id = id(self.cancel_scope._host_task)
- coro = func(*args, **kwargs)
- if not asyncio.iscoroutine(coro):
- raise TypeError(
- f"Expected an async function, but {func} appears to be synchronous"
- )
+ coro = ensure_returns_coro(func)(*args, **kwargs)
foreign_coro = not hasattr(coro, "cr_frame") and not hasattr(coro, "gi_frame")
if foreign_coro or sys.version_info < (3, 8):
@@ -1784,9 +1810,9 @@ async def _run_tests_and_fixtures(
],
) -> None:
with receive_stream:
- async for coro, future in receive_stream:
+ async for awaitable, future in receive_stream:
try:
- retval = await coro
+ retval = await awaitable
except BaseException as exc:
if not future.cancelled():
future.set_exception(exc)
@@ -1803,9 +1829,9 @@ async def _call_in_runner_task(
self._run_tests_and_fixtures(receive_stream)
)
- coro = func(*args, **kwargs)
+ awaitable = func(*args, **kwargs)
future: asyncio.Future[T_Retval] = self._loop.create_future()
- self._send_stream.send_nowait((coro, future))
+ self._send_stream.send_nowait((awaitable, future))
return await future
def close(self) -> None:
@@ -2051,7 +2077,7 @@ def run_async_from_thread(
) -> T_Retval:
loop = cast(AbstractEventLoop, token)
f: concurrent.futures.Future[T_Retval] = asyncio.run_coroutine_threadsafe(
- func(*args), loop
+ ensure_returns_coro(func)(*args), loop
)
return f.result()
diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py
index 48b8b40b..6dd79f47 100644
--- a/src/anyio/_backends/_trio.py
+++ b/src/anyio/_backends/_trio.py
@@ -3,10 +3,11 @@
import array
import math
import socket
-from collections.abc import AsyncIterator, Iterable
+import sys
+from collections.abc import AsyncIterator, Awaitable, Coroutine, Iterable
from concurrent.futures import Future
from dataclasses import dataclass
-from functools import partial
+from functools import partial, wraps
from io import IOBase
from os import PathLike
from signal import Signals
@@ -16,11 +17,9 @@
IO,
Any,
AsyncGenerator,
- Awaitable,
Callable,
Collection,
ContextManager,
- Coroutine,
Generic,
Mapping,
NoReturn,
@@ -61,9 +60,38 @@
from ..abc._eventloop import AsyncBackend
from ..streams.memory import MemoryObjectSendStream
-T = TypeVar("T")
+if sys.version_info >= (3, 10):
+ from typing import ParamSpec
+else:
+ from typing_extensions import ParamSpec
+
T_Retval = TypeVar("T_Retval")
T_SockAddr = TypeVar("T_SockAddr", str, IPSockAddrType)
+P = ParamSpec("P")
+
+
+def ensure_returns_coro(
+ func: Callable[P, Awaitable[T_Retval]]
+) -> Callable[P, Coroutine[Any, Any, T_Retval]]:
+ @wraps(func)
+ def wrapper(*args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, T_Retval]:
+ awaitable = func(*args, **kwargs)
+ # Check the common case first.
+ if isinstance(awaitable, Coroutine):
+ return awaitable
+ elif not isinstance(awaitable, Awaitable):
+ # The user violated the type annotations. Still, we should pass this on to
+ # Trio so it can raise with an appropriate message.
+ return awaitable
+ else:
+
+ @wraps(func)
+ async def inner_wrapper() -> T_Retval:
+ return await awaitable
+
+ return inner_wrapper()
+
+ return wrapper
#
@@ -154,13 +182,15 @@ async def __aexit__(
finally:
self._active = False
- def start_soon(self, func: Callable, *args: object, name: object = None) -> None:
+ def start_soon(
+ self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
+ ) -> None:
if not self._active:
raise RuntimeError(
"This task group is not active; no new tasks can be started."
)
- self._nursery.start_soon(func, *args, name=name)
+ self._nursery.start_soon(ensure_returns_coro(func), *args, name=name)
async def start(
self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
@@ -170,7 +200,7 @@ async def start(
"This task group is not active; no new tasks can be started."
)
- return await self._nursery.start(func, *args, name=name)
+ return await self._nursery.start(ensure_returns_coro(func), *args, name=name)
#
@@ -710,15 +740,17 @@ def __init__(self, **options: Any) -> None:
from queue import Queue
self._call_queue: Queue[Callable[..., object]] = Queue()
- self._send_stream: MemoryObjectSendStream | None = None
+ self._send_stream: MemoryObjectSendStream[
+ tuple[Awaitable[Any], list[Outcome]]
+ ] | None = None
self._options = options
async def _run_tests_and_fixtures(self) -> None:
self._send_stream, receive_stream = create_memory_object_stream(1)
with receive_stream:
- async for coro, outcome_holder in receive_stream:
+ async for awaitable, outcome_holder in receive_stream:
try:
- retval = await coro
+ retval = await awaitable
except BaseException as exc:
outcome_holder.append(Error(exc))
else:
@@ -793,7 +825,7 @@ def run(
kwargs: dict[str, Any],
options: dict[str, Any],
) -> T_Retval:
- return trio.run(func, *args)
+ return trio.run(ensure_returns_coro(func), *args)
@classmethod
def current_token(cls) -> object:
@@ -871,7 +903,9 @@ def run_async_from_thread(
args: tuple[Any, ...],
token: object,
) -> T_Retval:
- return trio.from_thread.run(func, *args, trio_token=cast(TrioToken, token))
+ return trio.from_thread.run(
+ ensure_returns_coro(func), *args, trio_token=cast(TrioToken, token)
+ )
@classmethod
def run_sync_from_thread(
diff --git a/src/anyio/abc/_tasks.py b/src/anyio/abc/_tasks.py
index 3c8a06bc..551d3983 100644
--- a/src/anyio/abc/_tasks.py
+++ b/src/anyio/abc/_tasks.py
@@ -2,7 +2,7 @@
import sys
from abc import ABCMeta, abstractmethod
-from collections.abc import Awaitable, Callable, Coroutine
+from collections.abc import Awaitable, Callable
from types import TracebackType
from typing import TYPE_CHECKING, Any, TypeVar, overload
@@ -48,7 +48,7 @@ class TaskGroup(metaclass=ABCMeta):
@abstractmethod
def start_soon(
self,
- func: Callable[..., Coroutine[Any, Any, Any]],
+ func: Callable[..., Awaitable[Any]],
*args: object,
name: object = None,
) -> None:
diff --git a/tests/misc.py b/tests/misc.py
new file mode 100644
index 00000000..5407085a
--- /dev/null
+++ b/tests/misc.py
@@ -0,0 +1,26 @@
+from __future__ import annotations
+
+import sys
+from typing import Any, Awaitable, Callable, Generator, TypeVar
+
+if sys.version_info >= (3, 10):
+ from typing import ParamSpec
+else:
+ from typing_extensions import ParamSpec
+
+T = TypeVar("T")
+P = ParamSpec("P")
+
+
+def return_non_coro_awaitable(
+ func: Callable[P, Awaitable[T]]
+) -> Callable[P, Awaitable[T]]:
+ class Wrapper:
+ def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None:
+ self.args = args
+ self.kwargs = kwargs
+
+ def __await__(self) -> Generator[Any, None, T]:
+ return func(*self.args, **self.kwargs).__await__()
+
+ return Wrapper
diff --git a/tests/test_eventloop.py b/tests/test_eventloop.py
index 9bc79683..e01b7e25 100644
--- a/tests/test_eventloop.py
+++ b/tests/test_eventloop.py
@@ -1,14 +1,16 @@
from __future__ import annotations
-import asyncio
import math
import sys
+from typing import Any
import pytest
from pytest_mock.plugin import MockerFixture
from anyio import run, sleep_forever, sleep_until
+from .misc import return_non_coro_awaitable
+
if sys.version_info < (3, 8):
from mock import AsyncMock
else:
@@ -41,11 +43,14 @@ async def test_sleep_forever(fake_sleep: AsyncMock) -> None:
fake_sleep.assert_called_once_with(math.inf)
-def test_run_task() -> None:
- """Test that anyio.run() on asyncio will work with a callable returning a Future."""
-
- async def async_add(x: int, y: int) -> int:
- return x + y
+def test_run_non_corofunc(
+ anyio_backend_name: str, anyio_backend_options: dict[str, Any]
+) -> None:
+ @return_non_coro_awaitable
+ async def func() -> str:
+ return "foo"
- result = run(asyncio.create_task, async_add(1, 2), backend="asyncio")
- assert result == 3
+ result = run(
+ func, backend=anyio_backend_name, backend_options=anyio_backend_options
+ )
+ assert result == "foo"
diff --git a/tests/test_from_thread.py b/tests/test_from_thread.py
index 93886f53..6c32a701 100644
--- a/tests/test_from_thread.py
+++ b/tests/test_from_thread.py
@@ -27,6 +27,8 @@
from anyio.from_thread import BlockingPortal, start_blocking_portal
from anyio.lowlevel import checkpoint
+from .misc import return_non_coro_awaitable
+
if sys.version_info >= (3, 8):
from typing import Literal
else:
@@ -39,10 +41,12 @@
async def async_add(a: int, b: int) -> int:
assert threading.current_thread() is threading.main_thread()
+ await checkpoint()
return a + b
async def asyncgen_add(a: int, b: int) -> AsyncGenerator[int, Any]:
+ await checkpoint()
yield a + b
@@ -68,6 +72,12 @@ async def test_run_corofunc_from_thread(self) -> None:
result = await to_thread.run_sync(thread_worker_async, async_add, 1, 2)
assert result == 3
+ async def test_run_non_corofunc_from_thread(self) -> None:
+ result = await to_thread.run_sync(
+ thread_worker_async, return_non_coro_awaitable(async_add), 1, 2
+ )
+ assert result == 3
+
async def test_run_asyncgen_from_thread(self) -> None:
gen = asyncgen_add(1, 2)
try:
@@ -115,13 +125,6 @@ async def test_run_sync_from_thread_exception(self) -> None:
exc.match("unsupported operand type")
- async def test_run_anyio_async_func_from_thread(self) -> None:
- def worker(*args: int) -> Literal[True]:
- from_thread.run(sleep, *args)
- return True
-
- assert await to_thread.run_sync(worker, 0)
-
def test_run_async_from_unclaimed_thread(self) -> None:
async def foo() -> None:
pass
@@ -182,6 +185,13 @@ async def test_call_corofunc(self) -> None:
result = await to_thread.run_sync(portal.call, async_add, 1, 2)
assert result == 3
+ async def test_call_non_corofunc(self) -> None:
+ async with BlockingPortal() as portal:
+ result = await to_thread.run_sync(
+ portal.call, return_non_coro_awaitable(async_add), 1, 2
+ )
+ assert result == 3
+
async def test_call_anext(self) -> None:
gen = asyncgen_add(1, 2)
try:
@@ -298,6 +308,17 @@ async def event_waiter() -> Literal["test"]:
portal.call(event2.wait)
assert future.result() == "test"
+ def test_start_task_soon_non_corofunc(
+ self, anyio_backend_name: str, anyio_backend_options: dict[str, Any]
+ ) -> None:
+ @return_non_coro_awaitable
+ async def taskfunc() -> Literal["test"]:
+ return "test"
+
+ with start_blocking_portal(anyio_backend_name, anyio_backend_options) as portal:
+ future = portal.start_task_soon(taskfunc)
+ assert future.result() == "test"
+
def test_start_task_soon_cancel_later(
self, anyio_backend_name: str, anyio_backend_options: dict[str, Any]
) -> None:
@@ -397,22 +418,34 @@ async def run_in_context() -> AsyncGenerator[None, None]:
def test_start_no_value(
self, anyio_backend_name: str, anyio_backend_options: dict[str, Any]
) -> None:
- def taskfunc(*, task_status: TaskStatus) -> None:
+ async def taskfunc(*, task_status: TaskStatus) -> None:
task_status.started()
with start_blocking_portal(anyio_backend_name, anyio_backend_options) as portal:
- future, value = portal.start_task(taskfunc) # type: ignore[arg-type]
+ future, value = portal.start_task(taskfunc)
assert value is None
assert future.result() is None
def test_start_with_value(
self, anyio_backend_name: str, anyio_backend_options: dict[str, Any]
) -> None:
- def taskfunc(*, task_status: TaskStatus) -> None:
+ async def taskfunc(*, task_status: TaskStatus) -> None:
+ task_status.started("foo")
+
+ with start_blocking_portal(anyio_backend_name, anyio_backend_options) as portal:
+ future, value = portal.start_task(taskfunc)
+ assert value == "foo"
+ assert future.result() is None
+
+ def test_start_non_corofunc(
+ self, anyio_backend_name: str, anyio_backend_options: dict[str, Any]
+ ) -> None:
+ @return_non_coro_awaitable
+ async def taskfunc(*, task_status: TaskStatus) -> None:
task_status.started("foo")
with start_blocking_portal(anyio_backend_name, anyio_backend_options) as portal:
- future, value = portal.start_task(taskfunc) # type: ignore[arg-type]
+ future, value = portal.start_task(taskfunc)
assert value == "foo"
assert future.result() is None
@@ -442,23 +475,21 @@ def taskfunc(*, task_status: TaskStatus) -> NoReturn:
def test_start_no_started_call(
self, anyio_backend_name: str, anyio_backend_options: dict[str, Any]
) -> None:
- def taskfunc(*, task_status: TaskStatus) -> None:
+ async def taskfunc(*, task_status: TaskStatus) -> None:
pass
with start_blocking_portal(anyio_backend_name, anyio_backend_options) as portal:
with pytest.raises(RuntimeError, match="Task exited"):
- portal.start_task(taskfunc) # type: ignore[arg-type]
+ portal.start_task(taskfunc)
def test_start_with_name(
self, anyio_backend_name: str, anyio_backend_options: dict[str, Any]
) -> None:
- def taskfunc(*, task_status: TaskStatus) -> None:
+ async def taskfunc(*, task_status: TaskStatus) -> None:
task_status.started(get_current_task().name)
with start_blocking_portal(anyio_backend_name, anyio_backend_options) as portal:
- future, start_value = portal.start_task(
- taskfunc, name="testname" # type: ignore[arg-type]
- )
+ future, start_value = portal.start_task(taskfunc, name="testname")
assert start_value == "testname"
def test_contextvar_propagation_sync(
diff --git a/tests/test_taskgroups.py b/tests/test_taskgroups.py
index a5fc3c49..f7f72127 100644
--- a/tests/test_taskgroups.py
+++ b/tests/test_taskgroups.py
@@ -26,6 +26,8 @@
from anyio.abc import TaskGroup, TaskStatus
from anyio.lowlevel import checkpoint
+from .misc import return_non_coro_awaitable
+
if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup
@@ -1094,6 +1096,29 @@ async def starter_task() -> None:
assert permanent_parent_id == root_task_id
+async def test_start_soon_non_corofunc() -> None:
+ finished = False
+
+ @return_non_coro_awaitable
+ async def taskfunc() -> None:
+ nonlocal finished
+ finished = True
+
+ async with create_task_group() as tg:
+ tg.start_soon(taskfunc)
+ assert finished
+
+
+async def test_start_non_corofunc() -> None:
+ @return_non_coro_awaitable
+ async def taskfunc(*, task_status: TaskStatus[str]) -> None:
+ task_status.started("foo")
+
+ async with create_task_group() as tg:
+ value = await tg.start(taskfunc)
+ assert value == "foo"
+
+
@pytest.mark.skipif(
sys.version_info < (3, 11),
reason="Task uncancelling is only supported on Python 3.11",