Skip to content

Commit

Permalink
Fix unclosed generator on trio
Browse files Browse the repository at this point in the history
  • Loading branch information
florimondmanca committed Feb 12, 2023
1 parent 7eb2022 commit b94afad
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 11 deletions.
5 changes: 2 additions & 3 deletions httpcore/_async/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,9 +334,8 @@ def __init__(
self._pool = pool
self._status = status

async def __aiter__(self) -> AsyncIterator[bytes]:
async for part in self._stream:
yield part
def __aiter__(self) -> AsyncIterator[bytes]:
return self._stream.__aiter__()

async def aclose(self) -> None:
try:
Expand Down
30 changes: 22 additions & 8 deletions httpcore/_async/http11.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import enum
import time
from contextlib import AsyncExitStack
from types import TracebackType
from typing import (
AsyncGenerator,
AsyncIterable,
AsyncIterator,
List,
Expand Down Expand Up @@ -173,7 +175,7 @@ async def _receive_response_headers(

return http_version, event.status_code, event.reason, headers

async def _receive_response_body(self, request: Request) -> AsyncIterator[bytes]:
async def _receive_response_body(self, request: Request) -> AsyncGenerator[bytes, None]:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("read", None)

Expand Down Expand Up @@ -304,22 +306,34 @@ def __init__(self, connection: AsyncHTTP11Connection, request: Request) -> None:
self._connection = connection
self._request = request
self._closed = False

async def __aiter__(self) -> AsyncIterator[bytes]:
kwargs = {"request": self._request}
self._stream = self._connection._receive_response_body(**kwargs)
self._trace = Trace("http11.receive_response_body", request, kwargs)

def __aiter__(self) -> AsyncIterator[bytes]:
return self

async def __anext__(self) -> bytes:
if not hasattr(self, "_trace_exit_stack"):
self._trace_exit_stack = AsyncExitStack()
await self._trace_exit_stack.enter_async_context(self._trace)

try:
async with Trace("http11.receive_response_body", self._request, kwargs):
async for chunk in self._connection._receive_response_body(**kwargs):
yield chunk
except BaseException as exc:
return await self._stream.__anext__()
except BaseException:
# If we get an exception while streaming the response,
# we want to close the response (and possibly the connection)
# before raising that exception.
await self._stream.aclose()
await self.aclose()
raise exc
raise

async def aclose(self) -> None:
if hasattr(self, "_trace_exit_stack"):
await self._trace_exit_stack.aclose()

if not self._closed:
await self._stream.aclose()
self._closed = True
async with Trace("http11.response_closed", self._request):
await self._connection._response_closed()

0 comments on commit b94afad

Please sign in to comment.