diff --git a/tests/middleware/test_proxy_headers.py b/tests/middleware/test_proxy_headers.py index d20adcd88..aacf789d4 100644 --- a/tests/middleware/test_proxy_headers.py +++ b/tests/middleware/test_proxy_headers.py @@ -13,34 +13,79 @@ async def app(scope, receive, send): await response(scope, receive, send) -app = ProxyHeadersMiddleware(app, trusted_hosts="*") - - @pytest.mark.asyncio -async def test_proxy_headers(): - async with httpx.AsyncClient(app=app, base_url="http://testserver") as client: +@pytest.mark.parametrize( + ("trusted_hosts", "response_text"), + [ + # always trust + ("*", "Remote: https://1.2.3.4:0"), + # trusted proxy + ("127.0.0.1", "Remote: https://1.2.3.4:0"), + (["127.0.0.1"], "Remote: https://1.2.3.4:0"), + # trusted proxy list + (["127.0.0.1", "10.0.0.1"], "Remote: https://1.2.3.4:0"), + ("127.0.0.1, 10.0.0.1", "Remote: https://1.2.3.4:0"), + # request from untrusted proxy + ("192.168.0.1", "Remote: http://127.0.0.1:123"), + ], +) +async def test_proxy_headers_trusted_hosts(trusted_hosts, response_text): + app_with_middleware = ProxyHeadersMiddleware(app, trusted_hosts=trusted_hosts) + async with httpx.AsyncClient( + app=app_with_middleware, base_url="http://testserver" + ) as client: headers = {"X-Forwarded-Proto": "https", "X-Forwarded-For": "1.2.3.4"} response = await client.get("/", headers=headers) + assert response.status_code == 200 - assert response.text == "Remote: https://1.2.3.4:0" + assert response.text == response_text @pytest.mark.asyncio -async def test_proxy_headers_no_port(): - async with httpx.AsyncClient(app=app, base_url="http://testserver") as client: - headers = {"X-Forwarded-Proto": "https", "X-Forwarded-For": "1.2.3.4"} +@pytest.mark.parametrize( + ("trusted_hosts", "response_text"), + [ + # always trust + ("*", "Remote: https://1.2.3.4:0"), + # all proxies are trusted + ( + ["127.0.0.1", "10.0.2.1", "192.168.0.2"], + "Remote: https://1.2.3.4:0", + ), + # order doesn't matter + ( + ["10.0.2.1", "192.168.0.2", "127.0.0.1"], + "Remote: https://1.2.3.4:0", + ), + # should set first untrusted as remote address + (["192.168.0.2", "127.0.0.1"], "Remote: https://10.0.2.1:0"), + ], +) +async def test_proxy_headers_multiple_proxies(trusted_hosts, response_text): + app_with_middleware = ProxyHeadersMiddleware(app, trusted_hosts=trusted_hosts) + async with httpx.AsyncClient( + app=app_with_middleware, base_url="http://testserver" + ) as client: + headers = { + "X-Forwarded-Proto": "https", + "X-Forwarded-For": "1.2.3.4, 10.0.2.1, 192.168.0.2", + } response = await client.get("/", headers=headers) + assert response.status_code == 200 - assert response.text == "Remote: https://1.2.3.4:0" + assert response.text == response_text @pytest.mark.asyncio async def test_proxy_headers_invalid_x_forwarded_for(): - async with httpx.AsyncClient(app=app, base_url="http://testserver") as client: + app_with_middleware = ProxyHeadersMiddleware(app, trusted_hosts="*") + async with httpx.AsyncClient( + app=app_with_middleware, base_url="http://testserver" + ) as client: headers = httpx.Headers( { "X-Forwarded-Proto": "https", - "X-Forwarded-For": "\xf0\xfd\xfd\xfd, 1.2.3.4", + "X-Forwarded-For": "1.2.3.4, \xf0\xfd\xfd\xfd", }, encoding="latin-1", ) diff --git a/uvicorn/middleware/proxy_headers.py b/uvicorn/middleware/proxy_headers.py index 05f4f70e1..23901bd73 100644 --- a/uvicorn/middleware/proxy_headers.py +++ b/uvicorn/middleware/proxy_headers.py @@ -8,17 +8,28 @@ https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers#Proxies """ +from typing import List class ProxyHeadersMiddleware: def __init__(self, app, trusted_hosts="127.0.0.1"): self.app = app if isinstance(trusted_hosts, str): - self.trusted_hosts = [item.strip() for item in trusted_hosts.split(",")] + self.trusted_hosts = {item.strip() for item in trusted_hosts.split(",")} else: - self.trusted_hosts = trusted_hosts + self.trusted_hosts = set(trusted_hosts) self.always_trust = "*" in self.trusted_hosts + def get_trusted_client_host( + self, x_forwarded_for_hosts + ): # type: (List[str]) -> str + if self.always_trust: + return x_forwarded_for_hosts[0] + + for host in reversed(x_forwarded_for_hosts): + if host not in self.trusted_hosts: + return host + async def __call__(self, scope, receive, send): if scope["type"] in ("http", "websocket"): client_addr = scope.get("client") @@ -38,7 +49,10 @@ async def __call__(self, scope, receive, send): # X-Forwarded-For header. We've lost the connecting client's port # information by now, so only include the host. x_forwarded_for = headers[b"x-forwarded-for"].decode("latin1") - host = x_forwarded_for.split(",")[-1].strip() + x_forwarded_for_hosts = [ + item.strip() for item in x_forwarded_for.split(",") + ] + host = self.get_trusted_client_host(x_forwarded_for_hosts) port = 0 scope["client"] = (host, port)