Skip to content

Commit

Permalink
Support async cancellations. (#726)
Browse files Browse the repository at this point in the history
* Add 'AsyncShieldCancellation' context manager

* Update _synchronization.py

* Linting

* Fix docstring wording

* Add interim 'nocover' to show tests passing.

* Add failing test case for HTTP/1.1 cancellations

* Neat cleanup for HTTP/1.1 write cancellations

* Drop 'nocover' for ShieldCancellation

* Add failing test case for HTTP/1.1 cancellations during response reading

* Resolve failing test case

* Add failing test cases for cancellations on connection pools

* Resolve failing test cases

* Add failing test cases for cancellations on HTTP/2 connections

* Resolve failing test cases

* Add failing test cases for cancellations on HTTP/2 connections when reading response

* Resolve failing test cases

* Update CHANGELOG

* Fix yield behaviour
  • Loading branch information
tomchristie authored Jul 4, 2023
1 parent 630e1e9 commit 31a4a56
Show file tree
Hide file tree
Showing 9 changed files with 364 additions and 32 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
## unreleased

- The networking backend interface has [been added to the public API](https://www.encode.io/httpcore/network-backends). Some classes which were previously private implementation detail are now part of the top-level public API. (#699)
- Support async cancellations, ensuring that the connection pool is left in a clean state when cancellations occur. (#726)
- Graceful handling of HTTP/2 GoAway frames, with requests being transparently retried on a new connection. (#730)
- Add exceptions when a synchronous `trace callback` is passed to an asynchronous request or an asynchronous `trace callback` is passed to a synchronous request. (#717)

Expand Down
8 changes: 5 additions & 3 deletions httpcore/_async/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .._backends.base import SOCKET_OPTION, AsyncNetworkBackend
from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol
from .._models import Origin, Request, Response
from .._synchronization import AsyncEvent, AsyncLock
from .._synchronization import AsyncEvent, AsyncLock, AsyncShieldCancellation
from .connection import AsyncHTTPConnection
from .interfaces import AsyncConnectionInterface, AsyncRequestInterface

Expand Down Expand Up @@ -257,7 +257,8 @@ async def handle_async_request(self, request: Request) -> Response:
status.unset_connection()
await self._attempt_to_acquire_connection(status)
except BaseException as exc:
await self.response_closed(status)
with AsyncShieldCancellation():
await self.response_closed(status)
raise exc
else:
break
Expand Down Expand Up @@ -351,4 +352,5 @@ async def aclose(self) -> None:
if hasattr(self._stream, "aclose"):
await self._stream.aclose()
finally:
await self._pool.response_closed(self._status)
with AsyncShieldCancellation():
await self._pool.response_closed(self._status)
10 changes: 6 additions & 4 deletions httpcore/_async/http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
map_exceptions,
)
from .._models import Origin, Request, Response
from .._synchronization import AsyncLock
from .._synchronization import AsyncLock, AsyncShieldCancellation
from .._trace import Trace
from .interfaces import AsyncConnectionInterface

Expand Down Expand Up @@ -115,8 +115,9 @@ async def handle_async_request(self, request: Request) -> Response:
},
)
except BaseException as exc:
async with Trace("response_closed", logger, request) as trace:
await self._response_closed()
with AsyncShieldCancellation():
async with Trace("response_closed", logger, request) as trace:
await self._response_closed()
raise exc

# Sending the request...
Expand Down Expand Up @@ -319,7 +320,8 @@ async def __aiter__(self) -> AsyncIterator[bytes]:
# If we get an exception while streaming the response,
# we want to close the response (and possibly the connection)
# before raising that exception.
await self.aclose()
with AsyncShieldCancellation():
await self.aclose()
raise exc

async def aclose(self) -> None:
Expand Down
26 changes: 17 additions & 9 deletions httpcore/_async/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
RemoteProtocolError,
)
from .._models import Origin, Request, Response
from .._synchronization import AsyncLock, AsyncSemaphore
from .._synchronization import AsyncLock, AsyncSemaphore, AsyncShieldCancellation
from .._trace import Trace
from .interfaces import AsyncConnectionInterface

Expand Down Expand Up @@ -103,9 +103,15 @@ async def handle_async_request(self, request: Request) -> Response:

async with self._init_lock:
if not self._sent_connection_init:
kwargs = {"request": request}
async with Trace("send_connection_init", logger, request, kwargs):
await self._send_connection_init(**kwargs)
try:
kwargs = {"request": request}
async with Trace("send_connection_init", logger, request, kwargs):
await self._send_connection_init(**kwargs)
except BaseException as exc:
with AsyncShieldCancellation():
await self.aclose()
raise exc

self._sent_connection_init = True

# Initially start with just 1 until the remote server provides
Expand Down Expand Up @@ -154,10 +160,11 @@ async def handle_async_request(self, request: Request) -> Response:
"stream_id": stream_id,
},
)
except Exception as exc: # noqa: PIE786
kwargs = {"stream_id": stream_id}
async with Trace("response_closed", logger, request, kwargs):
await self._response_closed(stream_id=stream_id)
except BaseException as exc: # noqa: PIE786
with AsyncShieldCancellation():
kwargs = {"stream_id": stream_id}
async with Trace("response_closed", logger, request, kwargs):
await self._response_closed(stream_id=stream_id)

if isinstance(exc, h2.exceptions.ProtocolError):
# One case where h2 can raise a protocol error is when a
Expand Down Expand Up @@ -570,7 +577,8 @@ async def __aiter__(self) -> typing.AsyncIterator[bytes]:
# If we get an exception while streaming the response,
# we want to close the response (and possibly the connection)
# before raising that exception.
await self.aclose()
with AsyncShieldCancellation():
await self.aclose()
raise exc

async def aclose(self) -> None:
Expand Down
8 changes: 5 additions & 3 deletions httpcore/_sync/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .._backends.base import SOCKET_OPTION, NetworkBackend
from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol
from .._models import Origin, Request, Response
from .._synchronization import Event, Lock
from .._synchronization import Event, Lock, ShieldCancellation
from .connection import HTTPConnection
from .interfaces import ConnectionInterface, RequestInterface

Expand Down Expand Up @@ -257,7 +257,8 @@ def handle_request(self, request: Request) -> Response:
status.unset_connection()
self._attempt_to_acquire_connection(status)
except BaseException as exc:
self.response_closed(status)
with ShieldCancellation():
self.response_closed(status)
raise exc
else:
break
Expand Down Expand Up @@ -351,4 +352,5 @@ def close(self) -> None:
if hasattr(self._stream, "close"):
self._stream.close()
finally:
self._pool.response_closed(self._status)
with ShieldCancellation():
self._pool.response_closed(self._status)
10 changes: 6 additions & 4 deletions httpcore/_sync/http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
map_exceptions,
)
from .._models import Origin, Request, Response
from .._synchronization import Lock
from .._synchronization import Lock, ShieldCancellation
from .._trace import Trace
from .interfaces import ConnectionInterface

Expand Down Expand Up @@ -115,8 +115,9 @@ def handle_request(self, request: Request) -> Response:
},
)
except BaseException as exc:
with Trace("response_closed", logger, request) as trace:
self._response_closed()
with ShieldCancellation():
with Trace("response_closed", logger, request) as trace:
self._response_closed()
raise exc

# Sending the request...
Expand Down Expand Up @@ -319,7 +320,8 @@ def __iter__(self) -> Iterator[bytes]:
# If we get an exception while streaming the response,
# we want to close the response (and possibly the connection)
# before raising that exception.
self.close()
with ShieldCancellation():
self.close()
raise exc

def close(self) -> None:
Expand Down
26 changes: 17 additions & 9 deletions httpcore/_sync/http2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
RemoteProtocolError,
)
from .._models import Origin, Request, Response
from .._synchronization import Lock, Semaphore
from .._synchronization import Lock, Semaphore, ShieldCancellation
from .._trace import Trace
from .interfaces import ConnectionInterface

Expand Down Expand Up @@ -103,9 +103,15 @@ def handle_request(self, request: Request) -> Response:

with self._init_lock:
if not self._sent_connection_init:
kwargs = {"request": request}
with Trace("send_connection_init", logger, request, kwargs):
self._send_connection_init(**kwargs)
try:
kwargs = {"request": request}
with Trace("send_connection_init", logger, request, kwargs):
self._send_connection_init(**kwargs)
except BaseException as exc:
with ShieldCancellation():
self.close()
raise exc

self._sent_connection_init = True

# Initially start with just 1 until the remote server provides
Expand Down Expand Up @@ -154,10 +160,11 @@ def handle_request(self, request: Request) -> Response:
"stream_id": stream_id,
},
)
except Exception as exc: # noqa: PIE786
kwargs = {"stream_id": stream_id}
with Trace("response_closed", logger, request, kwargs):
self._response_closed(stream_id=stream_id)
except BaseException as exc: # noqa: PIE786
with ShieldCancellation():
kwargs = {"stream_id": stream_id}
with Trace("response_closed", logger, request, kwargs):
self._response_closed(stream_id=stream_id)

if isinstance(exc, h2.exceptions.ProtocolError):
# One case where h2 can raise a protocol error is when a
Expand Down Expand Up @@ -570,7 +577,8 @@ def __iter__(self) -> typing.Iterator[bytes]:
# If we get an exception while streaming the response,
# we want to close the response (and possibly the connection)
# before raising that exception.
self.close()
with ShieldCancellation():
self.close()
raise exc

def close(self) -> None:
Expand Down
65 changes: 65 additions & 0 deletions httpcore/_synchronization.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,55 @@ async def release(self) -> None:
self._anyio_semaphore.release()


class AsyncShieldCancellation:
# For certain portions of our codebase where we're dealing with
# closing connections during exception handling we want to shield
# the operation from being cancelled.
#
# with AsyncShieldCancellation():
# ... # clean-up operations, shielded from cancellation.

def __init__(self) -> None:
"""
Detect if we're running under 'asyncio' or 'trio' and create
a shielded scope with the correct implementation.
"""
self._backend = sniffio.current_async_library()

if self._backend == "trio":
if trio is None: # pragma: nocover
raise RuntimeError(
"Running under trio requires the 'trio' package to be installed."
)

self._trio_shield = trio.CancelScope(shield=True)
else:
if anyio is None: # pragma: nocover
raise RuntimeError(
"Running under asyncio requires the 'anyio' package to be installed."
)

self._anyio_shield = anyio.CancelScope(shield=True)

def __enter__(self) -> "AsyncShieldCancellation":
if self._backend == "trio":
self._trio_shield.__enter__()
else:
self._anyio_shield.__enter__()
return self

def __exit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
) -> None:
if self._backend == "trio":
self._trio_shield.__exit__(exc_type, exc_value, traceback)
else:
self._anyio_shield.__exit__(exc_type, exc_value, traceback)


# Our thread-based synchronization primitives...


Expand Down Expand Up @@ -212,3 +261,19 @@ def acquire(self) -> None:

def release(self) -> None:
self._semaphore.release()


class ShieldCancellation:
# Thread-synchronous codebases don't support cancellation semantics.
# We have this class because we need to mirror the async and sync
# cases within our package, but it's just a no-op.
def __enter__(self) -> "ShieldCancellation":
return self

def __exit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
) -> None:
pass
Loading

0 comments on commit 31a4a56

Please sign in to comment.