Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex committed Feb 11, 2024
1 parent fd088ff commit aa46d41
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 58 deletions.
116 changes: 60 additions & 56 deletions tests/protocols/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import socket
import threading
import time
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, overload

import pytest

Expand All @@ -16,6 +16,7 @@
from uvicorn.lifespan.on import LifespanOn
from uvicorn.main import ServerState
from uvicorn.protocols.http.h11_impl import H11Protocol
from uvicorn.protocols.utils import ClientDisconnected

try:
from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol
Expand Down Expand Up @@ -222,12 +223,32 @@ def add_done_callback(self, callback):
pass


@overload
def get_connected_protocol(
app: ASGIApplication,
http_protocol_cls: type[HttpToolsProtocol],
lifespan: LifespanOff | LifespanOn | None = ...,
**kwargs: Any,
) -> HttpToolsProtocol:
...


@overload
def get_connected_protocol(
app: ASGIApplication,
http_protocol_cls: type[H11Protocol],
lifespan: LifespanOff | LifespanOn | None = ...,
**kwargs: Any,
) -> H11Protocol:
...


def get_connected_protocol(
app: ASGIApplication,
http_protocol_cls: HTTPProtocol,
lifespan: LifespanOff | LifespanOn | None = None,
**kwargs: Any,
):
) -> HttpToolsProtocol | H11Protocol:
loop = MockLoop()
transport = MockTransport()
config = Config(app=app, **kwargs)
Expand Down Expand Up @@ -369,9 +390,7 @@ async def test_close(http_protocol_cls: HTTPProtocol):


@pytest.mark.anyio
async def test_chunked_encoding(
http_protocol_cls: HTTPProtocol,
):
async def test_chunked_encoding(http_protocol_cls: HTTPProtocol):
app = Response(
b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"}
)
Expand All @@ -385,9 +404,7 @@ async def test_chunked_encoding(


@pytest.mark.anyio
async def test_chunked_encoding_empty_body(
http_protocol_cls: HTTPProtocol,
):
async def test_chunked_encoding_empty_body(http_protocol_cls: HTTPProtocol):
app = Response(
b"Hello, world!", status_code=200, headers={"transfer-encoding": "chunked"}
)
Expand Down Expand Up @@ -416,9 +433,7 @@ async def test_chunked_encoding_head_request(


@pytest.mark.anyio
async def test_pipelined_requests(
http_protocol_cls: HTTPProtocol,
):
async def test_pipelined_requests(http_protocol_cls: HTTPProtocol):
app = Response("Hello, world", media_type="text/plain")

protocol = get_connected_protocol(app, http_protocol_cls)
Expand All @@ -440,9 +455,7 @@ async def test_pipelined_requests(


@pytest.mark.anyio
async def test_undersized_request(
http_protocol_cls: HTTPProtocol,
):
async def test_undersized_request(http_protocol_cls: HTTPProtocol):
app = Response(b"xxx", headers={"content-length": "10"})

protocol = get_connected_protocol(app, http_protocol_cls)
Expand All @@ -452,9 +465,7 @@ async def test_undersized_request(


@pytest.mark.anyio
async def test_oversized_request(
http_protocol_cls: HTTPProtocol,
):
async def test_oversized_request(http_protocol_cls: HTTPProtocol):
app = Response(b"xxx" * 20, headers={"content-length": "10"})

protocol = get_connected_protocol(app, http_protocol_cls)
Expand All @@ -464,9 +475,7 @@ async def test_oversized_request(


@pytest.mark.anyio
async def test_large_post_request(
http_protocol_cls: HTTPProtocol,
):
async def test_large_post_request(http_protocol_cls: HTTPProtocol):
app = Response("Hello, world", media_type="text/plain")

protocol = get_connected_protocol(app, http_protocol_cls)
Expand All @@ -486,9 +495,7 @@ async def test_invalid_http(http_protocol_cls: HTTPProtocol):


@pytest.mark.anyio
async def test_app_exception(
http_protocol_cls: HTTPProtocol,
):
async def test_app_exception(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
raise Exception()

Expand All @@ -500,9 +507,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable


@pytest.mark.anyio
async def test_exception_during_response(
http_protocol_cls: HTTPProtocol,
):
async def test_exception_during_response(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
await send({"type": "http.response.start", "status": 200})
await send({"type": "http.response.body", "body": b"1", "more_body": True})
Expand All @@ -516,9 +521,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable


@pytest.mark.anyio
async def test_no_response_returned(
http_protocol_cls: HTTPProtocol,
):
async def test_no_response_returned(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
...

Expand All @@ -530,9 +533,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable


@pytest.mark.anyio
async def test_partial_response_returned(
http_protocol_cls: HTTPProtocol,
):
async def test_partial_response_returned(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
await send({"type": "http.response.start", "status": 200})

Expand All @@ -544,9 +545,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable


@pytest.mark.anyio
async def test_duplicate_start_message(
http_protocol_cls: HTTPProtocol,
):
async def test_duplicate_start_message(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
await send({"type": "http.response.start", "status": 200})
await send({"type": "http.response.start", "status": 200})
Expand All @@ -559,9 +558,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable


@pytest.mark.anyio
async def test_missing_start_message(
http_protocol_cls: HTTPProtocol,
):
async def test_missing_start_message(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
await send({"type": "http.response.body", "body": b""})

Expand All @@ -573,9 +570,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable


@pytest.mark.anyio
async def test_message_after_body_complete(
http_protocol_cls: HTTPProtocol,
):
async def test_message_after_body_complete(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
await send({"type": "http.response.start", "status": 200})
await send({"type": "http.response.body", "body": b""})
Expand All @@ -589,9 +584,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable


@pytest.mark.anyio
async def test_value_returned(
http_protocol_cls: HTTPProtocol,
):
async def test_value_returned(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
await send({"type": "http.response.start", "status": 200})
await send({"type": "http.response.body", "body": b""})
Expand All @@ -605,9 +598,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable


@pytest.mark.anyio
async def test_early_disconnect(
http_protocol_cls: HTTPProtocol,
):
async def test_early_disconnect(http_protocol_cls: HTTPProtocol):
got_disconnect_event = False

async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
Expand All @@ -629,9 +620,26 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable


@pytest.mark.anyio
async def test_early_response(
http_protocol_cls: HTTPProtocol,
):
async def test_disconnect_on_send(http_protocol_cls: HTTPProtocol) -> None:
got_disconnected = False

async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
try:
await send({"type": "http.response.start", "status": 200})
except ClientDisconnected:
nonlocal got_disconnected
got_disconnected = True

protocol = get_connected_protocol(app, http_protocol_cls)
protocol.data_received(SIMPLE_GET_REQUEST)
protocol.eof_received()
protocol.connection_lost(None)
await protocol.loop.run_one()
assert got_disconnected


@pytest.mark.anyio
async def test_early_response(http_protocol_cls: HTTPProtocol):
app = Response("Hello, world", media_type="text/plain")

protocol = get_connected_protocol(app, http_protocol_cls)
Expand All @@ -643,9 +651,7 @@ async def test_early_response(


@pytest.mark.anyio
async def test_read_after_response(
http_protocol_cls: HTTPProtocol,
):
async def test_read_after_response(http_protocol_cls: HTTPProtocol):
message_after_response = None

async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
Expand All @@ -663,9 +669,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable


@pytest.mark.anyio
async def test_http10_request(
http_protocol_cls: HTTPProtocol,
):
async def test_http10_request(http_protocol_cls: HTTPProtocol):
async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable):
assert scope["type"] == "http"
content = "Version: %s" % scope["http_version"]
Expand Down
2 changes: 1 addition & 1 deletion uvicorn/protocols/http/h11_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ async def run_asgi(self, app: "ASGI3Application") -> None:
self.scope, self.receive, self.send
)
except ClientDisconnected:
...
pass
except BaseException as exc:
msg = "Exception in ASGI application\n"
self.logger.error(msg, exc_info=exc)
Expand Down
2 changes: 1 addition & 1 deletion uvicorn/protocols/http/httptools_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ async def run_asgi(self, app: ASGI3Application) -> None:
self.scope, self.receive, self.send
)
except ClientDisconnected:
...
pass
except BaseException as exc:
msg = "Exception in ASGI application\n"
self.logger.error(msg, exc_info=exc)
Expand Down

0 comments on commit aa46d41

Please sign in to comment.