Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Encourage creation of aiohttp objects from coroutines #3372

Merged
merged 9 commits into from
Oct 30, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/3331.removal
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Encourage creation of aiohttp public objects inside a coroutine
10 changes: 5 additions & 5 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
all: test

.install-deps: $(shell find requirements -type f)
@pip install -r requirements/cython.txt
@pip install -U -r requirements/dev.txt
pip install -r requirements/cython.txt
pip install -r requirements/dev.txt
@touch .install-deps

isort:
Expand All @@ -17,7 +17,7 @@ flake: .flake
.flake: .install-deps $(shell find aiohttp -type f) \
$(shell find tests -type f) \
$(shell find examples -type f)
@flake8 aiohttp examples tests
flake8 aiohttp examples tests
python setup.py check -rms
@if ! isort -c -rc aiohttp tests examples; then \
echo "Import sort errors, run 'make isort' to fix them!!!"; \
Expand All @@ -30,15 +30,15 @@ flake: .flake
@touch .flake

check_changes:
@./tools/check_changes.py
./tools/check_changes.py

mypy: .flake
if python -c "import sys; sys.exit(sys.implementation.name!='cpython')"; then \
mypy aiohttp; \
fi

.develop: .install-deps $(shell find aiohttp -type f) .flake check_changes mypy
@pip install -e .
# pip install -e .
@touch .develop

test: .develop
Expand Down
3 changes: 2 additions & 1 deletion aiohttp/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from multidict import CIMultiDict # noqa
from yarl import URL

from .helpers import get_running_loop
from .typedefs import LooseCookies


Expand Down Expand Up @@ -133,7 +134,7 @@ class AbstractCookieJar(Sized, IterableBase):

def __init__(self, *,
loop: Optional[asyncio.AbstractEventLoop]=None) -> None:
self._loop = loop or asyncio.get_event_loop()
self._loop = get_running_loop(loop)

@abstractmethod
def clear(self) -> None:
Expand Down
4 changes: 1 addition & 3 deletions aiohttp/base_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@ class BaseProtocol(asyncio.Protocol):
__slots__ = ('_loop', '_paused', '_drain_waiter',
'_connection_lost', 'transport')

def __init__(self, loop: Optional[asyncio.AbstractEventLoop]=None) -> None:
if loop is None:
loop = asyncio.get_event_loop()
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop # type: asyncio.AbstractEventLoop
self._paused = False
self._drain_waiter = None # type: Optional[asyncio.Future[None]]
Expand Down
20 changes: 4 additions & 16 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
from .connector import BaseConnector, TCPConnector
from .cookiejar import CookieJar
from .helpers import (DEBUG, PY_36, BasicAuth, CeilTimeout, TimeoutHandle,
proxies_from_env, sentinel, strip_auth_from_url)
get_running_loop, proxies_from_env, sentinel,
strip_auth_from_url)
from .http import WS_KEY, HttpVersion, WebSocketReader, WebSocketWriter
from .http_websocket import (WSHandshakeError, WSMessage, ws_ext_gen, # noqa
ws_ext_parse)
Expand Down Expand Up @@ -121,13 +122,11 @@ def __init__(self, *, connector: Optional[BaseConnector]=None,
trust_env: bool=False,
trace_configs: Optional[List[TraceConfig]]=None) -> None:

implicit_loop = False
if loop is None:
if connector is not None:
loop = connector._loop
else:
implicit_loop = True
loop = asyncio.get_event_loop()

loop = get_running_loop(loop)

if connector is None:
connector = TCPConnector(loop=loop)
Expand All @@ -141,17 +140,6 @@ def __init__(self, *, connector: Optional[BaseConnector]=None,
if loop.get_debug():
self._source_traceback = traceback.extract_stack(sys._getframe(1))

if implicit_loop and not loop.is_running():
warnings.warn("Creating a client session outside of coroutine is "
"a very dangerous idea",
stacklevel=2)
context = {'client_session': self,
'message': 'Creating a client session outside '
'of coroutine'}
if self._source_traceback is not None:
context['source_traceback'] = self._source_traceback
loop.call_exception_handler(context)

if cookie_jar is None:
cookie_jar = CookieJar(loop=loop)
self._cookie_jar = cookie_jar
Expand Down
8 changes: 3 additions & 5 deletions aiohttp/client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,10 @@ class ResponseHandler(BaseProtocol,
DataQueue[Tuple[RawResponseMessage, StreamReader]]):
"""Helper class to adapt between Protocol and StreamReader."""

def __init__(self, *,
loop: Optional[asyncio.AbstractEventLoop]=None) -> None:
if loop is None:
loop = asyncio.get_event_loop()
def __init__(self,
loop: asyncio.AbstractEventLoop) -> None:
BaseProtocol.__init__(self, loop=loop)
DataQueue.__init__(self, loop=loop)
DataQueue.__init__(self, loop)

self._should_close = False

Expand Down
6 changes: 3 additions & 3 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
ssl_errors)
from .client_proto import ResponseHandler
from .client_reqrep import ClientRequest, Fingerprint, _merge_ssl_params
from .helpers import PY_36, CeilTimeout, is_ip_address, noop, sentinel
from .helpers import (PY_36, CeilTimeout, get_running_loop, is_ip_address,
noop, sentinel)
from .http import RESPONSES
from .locks import EventResultOrError
from .resolver import DefaultResolver
Expand Down Expand Up @@ -180,8 +181,7 @@ def __init__(self, *,
if keepalive_timeout is sentinel:
keepalive_timeout = 15.0

if loop is None:
loop = asyncio.get_event_loop()
loop = get_running_loop(loop)

self._closed = False
if loop.get_debug():
Expand Down
18 changes: 17 additions & 1 deletion aiohttp/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import re
import sys
import time
import warnings
import weakref
from collections import namedtuple
from contextlib import suppress
Expand All @@ -30,7 +31,7 @@
from yarl import URL

from . import hdrs
from .log import client_logger
from .log import client_logger, internal_logger
from .typedefs import PathLike # noqa


Expand Down Expand Up @@ -231,6 +232,21 @@ def current_task(loop: Optional[asyncio.AbstractEventLoop]=None) -> asyncio.Task
return asyncio.Task.current_task(loop=loop) # type: ignore


def get_running_loop(
loop: Optional[asyncio.AbstractEventLoop]=None
) -> asyncio.AbstractEventLoop:
if loop is None:
loop = asyncio.get_event_loop()
if not loop.is_running():
warnings.warn("The object should be created from async function",
DeprecationWarning, stacklevel=3)
if loop.get_debug():
internal_logger.warning(
"The object should be created from async function",
stack_info=True)
return loop


def isasyncgenfunction(obj: Any) -> bool:
func = getattr(inspect, 'isasyncgenfunction', None)
if func is not None:
Expand Down
10 changes: 3 additions & 7 deletions aiohttp/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Dict, List, Optional

from .abc import AbstractResolver
from .helpers import get_running_loop


__all__ = ('ThreadedResolver', 'AsyncResolver', 'DefaultResolver')
Expand All @@ -22,9 +23,7 @@ class ThreadedResolver(AbstractResolver):
"""

def __init__(self, loop: Optional[asyncio.AbstractEventLoop]=None) -> None:
if loop is None:
loop = asyncio.get_event_loop()
self._loop = loop
self._loop = get_running_loop(loop)

async def resolve(self, host: str, port: int=0,
family: int=socket.AF_INET) -> List[Dict[str, Any]]:
Expand All @@ -50,13 +49,10 @@ class AsyncResolver(AbstractResolver):

def __init__(self, loop: Optional[asyncio.AbstractEventLoop]=None,
*args: Any, **kwargs: Any) -> None:
if loop is None:
loop = asyncio.get_event_loop()

if aiodns is None:
raise RuntimeError("Resolver requires aiodns library")

self._loop = loop
self._loop = get_running_loop(loop)
self._resolver = aiodns.DNSResolver(*args, loop=loop, **kwargs)

if not hasattr(self._resolver, 'gethostbyname'):
Expand Down
2 changes: 1 addition & 1 deletion aiohttp/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ def read_nowait(self) -> bytes:
class DataQueue(Generic[_T]):
"""DataQueue is a general-purpose blocking queue with one reader."""

def __init__(self, *, loop: asyncio.AbstractEventLoop) -> None:
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop
self._eof = False
self._waiter = None # type: Optional[asyncio.Future[bool]]
Expand Down
4 changes: 2 additions & 2 deletions aiohttp/web_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class RequestHandler(BaseProtocol):
'access_logger', '_close', '_force_close')

def __init__(self, manager: 'Server', *,
loop: Optional[asyncio.AbstractEventLoop]=None,
loop: asyncio.AbstractEventLoop,
keepalive_timeout: float=75., # NGINX default is 75 secs
tcp_keepalive: bool=True,
logger: Logger=server_logger,
Expand All @@ -118,7 +118,7 @@ def __init__(self, manager: 'Server', *,
max_field_size: int=8190,
lingering_time: float=10.0):

super().__init__(loop=loop)
super().__init__(loop)

self._request_count = 0
self._keepalive = False
Expand Down
5 changes: 2 additions & 3 deletions aiohttp/web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Awaitable, Callable, Dict, List, Optional # noqa

from .abc import AbstractStreamWriter
from .helpers import get_running_loop
from .http_parser import RawRequestMessage
from .streams import StreamReader
from .web_protocol import RequestHandler, _RequestFactory, _RequestHandler
Expand All @@ -20,9 +21,7 @@ def __init__(self,
request_factory: Optional[_RequestFactory]=None,
loop: Optional[asyncio.AbstractEventLoop]=None,
**kwargs: Any) -> None:
if loop is None:
loop = asyncio.get_event_loop()
self._loop = loop
self._loop = get_running_loop(loop)
self._connections = {} # type: Dict[RequestHandler, asyncio.Transport]
self._kwargs = kwargs
self.requests_count = 0
Expand Down
11 changes: 2 additions & 9 deletions tests/test_base_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,13 @@
async def test_loop() -> None:
loop = asyncio.get_event_loop()
asyncio.set_event_loop(None)
pr = BaseProtocol(loop=loop)
assert pr._loop is loop


async def test_default_loop() -> None:
loop = asyncio.get_event_loop()
asyncio.set_event_loop(loop)
pr = BaseProtocol()
pr = BaseProtocol(loop)
assert pr._loop is loop


async def test_pause_writing() -> None:
loop = asyncio.get_event_loop()
pr = BaseProtocol(loop=loop)
pr = BaseProtocol(loop)
assert not pr._paused
pr.pause_writing()
assert pr._paused
Expand Down
2 changes: 1 addition & 1 deletion tests/test_client_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async def test_http_processing_error(session) -> None:
loop.get_debug.return_value = True

connection = mock.Mock()
connection.protocol = aiohttp.DataQueue(loop=loop)
connection.protocol = aiohttp.DataQueue(loop)
connection.protocol.set_response_params = mock.Mock()
connection.protocol.set_exception(http.HttpProcessingError())

Expand Down
Loading