-
-
Notifications
You must be signed in to change notification settings - Fork 763
Commit
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -11,27 +11,35 @@ | |||||||||||||||
|
||||||||||||||||
|
||||||||||||||||
class ProxyHeadersMiddleware: | ||||||||||||||||
def __init__(self, app, num_proxies=1): | ||||||||||||||||
def __init__(self, app, trusted_hosts="127.0.0.1"): | ||||||||||||||||
self.app = app | ||||||||||||||||
self.num_proxies = num_proxies | ||||||||||||||||
if isinstance(trusted_hosts, str): | ||||||||||||||||
self.trusted_hosts = [item.strip() for item in trusted_hosts.split(",")] | ||||||||||||||||
else: | ||||||||||||||||
self.trusted_hosts = trusted_hosts | ||||||||||||||||
self.always_trust = "*" in self.trusted_hosts | ||||||||||||||||
|
||||||||||||||||
async def __call__(self, scope, receive, send): | ||||||||||||||||
if scope["type"] in ("http", "websocket"): | ||||||||||||||||
headers = dict(scope["headers"]) | ||||||||||||||||
|
||||||||||||||||
if b"x-forwarded-proto" in headers: | ||||||||||||||||
# Determine if the incoming request was http or https based on | ||||||||||||||||
# the X-Forwarded-Proto header. | ||||||||||||||||
x_forwarded_proto = headers[b"x-forwarded-proto"].decode("ascii") | ||||||||||||||||
scope["scheme"] = x_forwarded_proto.strip() | ||||||||||||||||
|
||||||||||||||||
if b"x-forwarded-for" in headers: | ||||||||||||||||
# Determine the client address from the last trusted IP in the | ||||||||||||||||
# X-Forwarded-For header. We've lost the connecting client's port | ||||||||||||||||
# information by now, so only include the host. | ||||||||||||||||
x_forwarded_for = headers[b"x-forwarded-for"].decode("ascii") | ||||||||||||||||
host = x_forwarded_for.split(",")[-self.num_proxies].strip() | ||||||||||||||||
port = 0 | ||||||||||||||||
scope["client"] = (host, port) | ||||||||||||||||
|
||||||||||||||||
await self.app(scope, receive, send) | ||||||||||||||||
client_addr = scope.get("client") | ||||||||||||||||
client_host = client_addr[0] if client_addr else None | ||||||||||||||||
|
||||||||||||||||
if self.always_trust or client_host in self.trusted_hosts: | ||||||||||||||||
headers = dict(scope["headers"]) | ||||||||||||||||
|
||||||||||||||||
if b"x-forwarded-proto" in headers: | ||||||||||||||||
# Determine if the incoming request was http or https based on | ||||||||||||||||
# the X-Forwarded-Proto header. | ||||||||||||||||
x_forwarded_proto = headers[b"x-forwarded-proto"].decode("ascii") | ||||||||||||||||
scope["scheme"] = x_forwarded_proto.strip() | ||||||||||||||||
|
||||||||||||||||
if b"x-forwarded-for" in headers: | ||||||||||||||||
# Determine the client address from the last trusted IP in the | ||||||||||||||||
# X-Forwarded-For header. We've lost the connecting client's port | ||||||||||||||||
# information by now, so only include the host. | ||||||||||||||||
x_forwarded_for = headers[b"x-forwarded-for"].decode("ascii") | ||||||||||||||||
host = x_forwarded_for.split(",")[-1].strip() | ||||||||||||||||
This comment has been minimized.
Sorry, something went wrong.
This comment has been minimized.
Sorry, something went wrong.
tomchristie
Author
Member
|
def get_trusted_client_host(self, x_forwarded_for_hosts): # type: (List[str]) -> str | |
if self.always_trust: | |
return x_forwarded_for_hosts[0] | |
for host in reversed(x_forwarded_for_hosts): | |
if host not in self.trusted_hosts: | |
return host |
Previous code was working for arbitrary number of proxies in the chain, but new code only designed to work for a single proxy?
[-1] will extract the IP of the host connecting to the last proxy. Which in case of a single proxy will be client; in case of multiple proxies it will be a proxy before the last one. https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For