From f3190f84cce874117136d277d4bcb8f2c62dcc9a Mon Sep 17 00:00:00 2001 From: Benoit Chesneau Date: Fri, 23 Jan 2026 18:40:44 +0100 Subject: [PATCH] feat: add PROXY protocol v2 support with version selection (#3451) Extend --proxy-protocol to accept version values (off, v1, v2, auto) instead of being boolean-only. This allows explicit control over which PROXY protocol versions are accepted. Changes: - Add InvalidProxyHeader exception for v2 binary header errors - Add validate_proxy_protocol() validator with backwards compatibility - Update ProxyProtocol setting with nargs="?" and const="auto" - Add PROXY v2 constants (PP_V2_SIGNATURE, PPCommand, PPFamily, PPProtocol) - Add _parse_proxy_protocol_v1() and _parse_proxy_protocol_v2() methods - Update both sync (message.py) and async (asgi/message.py) parsers - Add hex escape handling in treq.py for v2 binary test data - Add test cases for v2 TCPv4 and TCPv6 Backwards compatible: --proxy-protocol alone (or True) maps to "auto". Closes #2912 --- docs/content/reference/settings.md | 26 +++- gunicorn/asgi/message.py | 217 +++++++++++++++++++++------ gunicorn/config.py | 54 ++++++- gunicorn/http/errors.py | 9 ++ gunicorn/http/message.py | 227 ++++++++++++++++++++++++----- tests/requests/valid/pp_03.http | 4 + tests/requests/valid/pp_03.py | 15 ++ tests/requests/valid/pp_04.http | 4 + tests/requests/valid/pp_04.py | 15 ++ tests/requests/valid/pp_05.http | 4 + tests/requests/valid/pp_05.py | 15 ++ tests/test_asgi.py | 7 +- tests/test_gthread.py | 2 +- tests/treq.py | 25 +++- 14 files changed, 522 insertions(+), 102 deletions(-) create mode 100644 tests/requests/valid/pp_03.http create mode 100644 tests/requests/valid/pp_03.py create mode 100644 tests/requests/valid/pp_04.http create mode 100644 tests/requests/valid/pp_04.py create mode 100644 tests/requests/valid/pp_05.http create mode 100644 tests/requests/valid/pp_05.py diff --git a/docs/content/reference/settings.md b/docs/content/reference/settings.md index 7fc136f9..79aecefa 100644 --- a/docs/content/reference/settings.md +++ b/docs/content/reference/settings.md @@ -1148,16 +1148,27 @@ command line arguments to control server configuration instead. ### `proxy_protocol` -**Command line:** `--proxy-protocol` +**Command line:** `--proxy-protocol MODE` -**Default:** `False` +**Default:** `'off'` -Enable detect PROXY protocol (PROXY mode). +Enable PROXY protocol support. -Allow using HTTP and Proxy together. It may be useful for work with -stunnel as HTTPS frontend and Gunicorn as HTTP server. +Allow using HTTP and PROXY protocol together. It may be useful for work +with stunnel as HTTPS frontend and Gunicorn as HTTP server, or with +HAProxy. -PROXY protocol: http://haproxy.1wt.eu/download/1.5/doc/proxy-protocol.txt +Accepted values: + +* ``off`` - Disabled (default) +* ``v1`` - PROXY protocol v1 only (text format) +* ``v2`` - PROXY protocol v2 only (binary format) +* ``auto`` - Auto-detect v1 or v2 + +Using ``--proxy-protocol`` without a value is equivalent to ``auto``. + +PROXY protocol v1: http://haproxy.1wt.eu/download/1.5/doc/proxy-protocol.txt +PROXY protocol v2: https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt Example for stunnel config:: @@ -1168,6 +1179,9 @@ Example for stunnel config:: cert = /etc/ssl/certs/stunnel.pem key = /etc/ssl/certs/stunnel.key +!!! info "Changed in 24.0.0" + Extended to support version selection (v1, v2, auto). + ### `proxy_allow_ips` **Command line:** `--proxy-allow-from` diff --git a/gunicorn/asgi/message.py b/gunicorn/asgi/message.py index a2d8e825..1bb26b99 100644 --- a/gunicorn/asgi/message.py +++ b/gunicorn/asgi/message.py @@ -9,17 +9,22 @@ Reuses the parsing logic from the sync version, adapted for async I/O. """ import io +import ipaddress import re import socket +import struct from gunicorn.http.errors import ( InvalidHeader, InvalidHeaderName, NoMoreData, InvalidRequestLine, InvalidRequestMethod, InvalidHTTPVersion, LimitRequestLine, LimitRequestHeaders, UnsupportedTransferCoding, ObsoleteFolding, - InvalidProxyLine, ForbiddenProxyRequest, + InvalidProxyLine, InvalidProxyHeader, ForbiddenProxyRequest, InvalidSchemeHeaders, ) +from gunicorn.http.message import ( + PP_V2_SIGNATURE, PPCommand, PPFamily, PPProtocol +) from gunicorn.util import bytes_to_str, split_request_uri MAX_REQUEST_LINE = 8190 @@ -34,6 +39,22 @@ VERSION_RE = re.compile(r"HTTP/(\d)\.(\d)") RFC9110_5_5_INVALID_AND_DANGEROUS = re.compile(r"[\0\r\n]") +def _ip_in_allow_list(ip_str, allow_list): + """Check if IP address is in the allow list (which may contain networks).""" + if '*' in allow_list: + return True + try: + ip = ipaddress.ip_address(ip_str) + except ValueError: + return False + for network in allow_list: + if network == '*': + return True + if ip in network: + return True + return False + + class AsyncRequest: """Async HTTP request parser. @@ -111,33 +132,29 @@ class AsyncRequest: async def _parse(self): """Parse the request from the unreader.""" - buf = io.BytesIO() - await self._get_data(buf, stop=True) + buf = bytearray() + await self._read_into(buf, stop=True) + + # Handle proxy protocol if enabled and this is the first request + mode = self.cfg.proxy_protocol + if mode != "off" and self.req_number == 1: + buf = await self._handle_proxy_protocol(buf, mode) # Get request line - line, rbuf = await self._read_line(buf, self.limit_request_line) - - # Proxy protocol - if self._proxy_protocol(bytes_to_str(line)): - # Get next request line - buf = io.BytesIO() - buf.write(rbuf) - line, rbuf = await self._read_line(buf, self.limit_request_line) + line, buf = await self._read_line(buf, self.limit_request_line) self._parse_request_line(line) - buf = io.BytesIO() - buf.write(rbuf) # Headers - data = buf.getvalue() + data = bytes(buf) while True: idx = data.find(b"\r\n\r\n") done = data[:2] == b"\r\n" if idx < 0 and not done: - await self._get_data(buf) - data = buf.getvalue() + await self._read_into(buf) + data = bytes(buf) if len(data) > self.max_buffer_headers: raise LimitRequestHeaders("max buffer headers") else: @@ -151,18 +168,18 @@ class AsyncRequest: self._set_body_reader() - async def _get_data(self, buf, stop=False): - """Read data from unreader into buffer.""" + async def _read_into(self, buf, stop=False): + """Read data from unreader and append to bytearray buffer.""" data = await self.unreader.read() if not data: if stop: raise StopIteration() - raise NoMoreData(buf.getvalue()) - buf.write(data) + raise NoMoreData(bytes(buf)) + buf.extend(data) async def _read_line(self, buf, limit=0): - """Read a line from the buffer/stream.""" - data = buf.getvalue() + """Read a line from buffer, returning (line, remaining_buffer).""" + data = bytes(buf) while True: idx = data.find(b"\r\n") @@ -172,36 +189,54 @@ class AsyncRequest: break if len(data) - 2 > limit > 0: raise LimitRequestLine(len(data), limit) - await self._get_data(buf) - data = buf.getvalue() + await self._read_into(buf) + data = bytes(buf) - return (data[:idx], data[idx + 2:]) + return (data[:idx], bytearray(data[idx + 2:])) - def _proxy_protocol(self, line): - """Detect, check and parse proxy protocol.""" - if not self.cfg.proxy_protocol: - return False + async def _handle_proxy_protocol(self, buf, mode): + """Handle PROXY protocol detection and parsing. - if self.req_number != 1: - return False + Returns the buffer with proxy protocol data consumed. + """ + # Ensure we have enough data to detect v2 signature (12 bytes) + while len(buf) < 12: + await self._read_into(buf) - if not line.startswith("PROXY"): - return False + # Check for v2 signature first + if mode in ("v2", "auto") and buf[:12] == PP_V2_SIGNATURE: + self._proxy_protocol_access_check() + return await self._parse_proxy_protocol_v2(buf) - self._proxy_protocol_access_check() - self._parse_proxy_protocol(line) + # Check for v1 prefix + if mode in ("v1", "auto") and buf[:6] == b"PROXY ": + self._proxy_protocol_access_check() + return await self._parse_proxy_protocol_v1(buf) - return True + # Not proxy protocol - return buffer unchanged + return buf def _proxy_protocol_access_check(self): """Check if proxy protocol is allowed from this peer.""" - if ("*" not in self.cfg.proxy_allow_ips and - isinstance(self.peer_addr, tuple) and - self.peer_addr[0] not in self.cfg.proxy_allow_ips): + if (isinstance(self.peer_addr, tuple) and + not _ip_in_allow_list(self.peer_addr[0], self.cfg.proxy_allow_ips)): raise ForbiddenProxyRequest(self.peer_addr[0]) - def _parse_proxy_protocol(self, line): - """Parse proxy protocol header line.""" + async def _parse_proxy_protocol_v1(self, buf): + """Parse PROXY protocol v1 (text format). + + Returns buffer with v1 header consumed. + """ + # Read until we find \r\n + data = bytes(buf) + while b"\r\n" not in data: + await self._read_into(buf) + data = bytes(buf) + + idx = data.find(b"\r\n") + line = bytes_to_str(data[:idx]) + remaining = bytearray(data[idx + 2:]) + bits = line.split(" ") if len(bits) != 6: @@ -244,6 +279,101 @@ class AsyncRequest: "proxy_port": d_port } + return remaining + + async def _parse_proxy_protocol_v2(self, buf): + """Parse PROXY protocol v2 (binary format). + + Returns buffer with v2 header consumed. + """ + # We need at least 16 bytes for the header (12 signature + 4 header) + while len(buf) < 16: + await self._read_into(buf) + + # Parse header fields (after 12-byte signature) + ver_cmd = buf[12] + fam_proto = buf[13] + length = struct.unpack(">H", bytes(buf[14:16]))[0] + + # Validate version (high nibble must be 0x2) + version = (ver_cmd & 0xF0) >> 4 + if version != 2: + raise InvalidProxyHeader("unsupported version %d" % version) + + # Extract command (low nibble) + command = ver_cmd & 0x0F + if command not in (PPCommand.LOCAL, PPCommand.PROXY): + raise InvalidProxyHeader("unsupported command %d" % command) + + # Ensure we have the complete header + total_header_size = 16 + length + while len(buf) < total_header_size: + await self._read_into(buf) + + # For LOCAL command, no address info is provided + if command == PPCommand.LOCAL: + self.proxy_protocol_info = { + "proxy_protocol": "LOCAL", + "client_addr": None, + "client_port": None, + "proxy_addr": None, + "proxy_port": None + } + return bytearray(buf[total_header_size:]) + + # Extract address family and protocol + family = (fam_proto & 0xF0) >> 4 + protocol = fam_proto & 0x0F + + # We only support TCP (STREAM) + if protocol != PPProtocol.STREAM: + raise InvalidProxyHeader("only TCP protocol is supported") + + addr_data = bytes(buf[16:16 + length]) + + if family == PPFamily.INET: # IPv4 + if length < 12: # 4+4+2+2 + raise InvalidProxyHeader("insufficient address data for IPv4") + s_addr = socket.inet_ntop(socket.AF_INET, addr_data[0:4]) + d_addr = socket.inet_ntop(socket.AF_INET, addr_data[4:8]) + s_port = struct.unpack(">H", addr_data[8:10])[0] + d_port = struct.unpack(">H", addr_data[10:12])[0] + proto = "TCP4" + + elif family == PPFamily.INET6: # IPv6 + if length < 36: # 16+16+2+2 + raise InvalidProxyHeader("insufficient address data for IPv6") + s_addr = socket.inet_ntop(socket.AF_INET6, addr_data[0:16]) + d_addr = socket.inet_ntop(socket.AF_INET6, addr_data[16:32]) + s_port = struct.unpack(">H", addr_data[32:34])[0] + d_port = struct.unpack(">H", addr_data[34:36])[0] + proto = "TCP6" + + elif family == PPFamily.UNSPEC: + # No address info provided with PROXY command + self.proxy_protocol_info = { + "proxy_protocol": "UNSPEC", + "client_addr": None, + "client_port": None, + "proxy_addr": None, + "proxy_port": None + } + return bytearray(buf[total_header_size:]) + + else: + raise InvalidProxyHeader("unsupported address family %d" % family) + + # Set data + self.proxy_protocol_info = { + "proxy_protocol": proto, + "client_addr": s_addr, + "client_port": s_port, + "proxy_addr": d_addr, + "proxy_port": d_port + } + + return bytearray(buf[total_header_size:]) + def _parse_request_line(self, line_bytes): """Parse the HTTP request line.""" bits = [bytes_to_str(bit) for bit in line_bytes.split(b" ", 2)] @@ -299,9 +429,8 @@ class AsyncRequest: forwarder_headers = [] if from_trailer: pass - elif ('*' in cfg.forwarded_allow_ips or - not isinstance(self.peer_addr, tuple) - or self.peer_addr[0] in cfg.forwarded_allow_ips): + elif (not isinstance(self.peer_addr, tuple) + or _ip_in_allow_list(self.peer_addr[0], cfg.forwarded_allow_ips)): secure_scheme_headers = cfg.secure_scheme_headers forwarder_headers = cfg.forwarder_headers diff --git a/gunicorn/config.py b/gunicorn/config.py index 663799f2..700c9429 100644 --- a/gunicorn/config.py +++ b/gunicorn/config.py @@ -2082,20 +2082,57 @@ class NewSSLContext(Setting): """ +def validate_proxy_protocol(val): + """Validate proxy_protocol setting. + + Accepts: off, false, v1, v2, auto, true + Returns normalized value: off, v1, v2, or auto + """ + if val is None: + return "off" + if isinstance(val, bool): + return "auto" if val else "off" + if not isinstance(val, str): + raise TypeError("proxy_protocol must be string or bool") + + val = val.lower().strip() + mapping = { + "false": "off", "off": "off", "0": "off", "none": "off", + "true": "auto", "auto": "auto", "1": "auto", + "v1": "v1", "v2": "v2", + } + if val not in mapping: + raise ValueError("proxy_protocol must be: off, v1, v2, or auto") + return mapping[val] + + class ProxyProtocol(Setting): name = "proxy_protocol" section = "Server Mechanics" cli = ["--proxy-protocol"] - validator = validate_bool - default = False - action = "store_true" + meta = "MODE" + validator = validate_proxy_protocol + default = "off" + nargs = "?" + const = "auto" desc = """\ - Enable detect PROXY protocol (PROXY mode). + Enable PROXY protocol support. - Allow using HTTP and Proxy together. It may be useful for work with - stunnel as HTTPS frontend and Gunicorn as HTTP server. + Allow using HTTP and PROXY protocol together. It may be useful for work + with stunnel as HTTPS frontend and Gunicorn as HTTP server, or with + HAProxy. - PROXY protocol: http://haproxy.1wt.eu/download/1.5/doc/proxy-protocol.txt + Accepted values: + + * ``off`` - Disabled (default) + * ``v1`` - PROXY protocol v1 only (text format) + * ``v2`` - PROXY protocol v2 only (binary format) + * ``auto`` - Auto-detect v1 or v2 + + Using ``--proxy-protocol`` without a value is equivalent to ``auto``. + + PROXY protocol v1: http://haproxy.1wt.eu/download/1.5/doc/proxy-protocol.txt + PROXY protocol v2: https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt Example for stunnel config:: @@ -2105,6 +2142,9 @@ class ProxyProtocol(Setting): connect = 80 cert = /etc/ssl/certs/stunnel.pem key = /etc/ssl/certs/stunnel.key + + .. versionchanged:: 24.0.0 + Extended to support version selection (v1, v2, auto). """ diff --git a/gunicorn/http/errors.py b/gunicorn/http/errors.py index bcb97007..e9c24917 100644 --- a/gunicorn/http/errors.py +++ b/gunicorn/http/errors.py @@ -131,6 +131,15 @@ class InvalidProxyLine(ParseException): return "Invalid PROXY line: %r" % self.line +class InvalidProxyHeader(ParseException): + def __init__(self, msg): + self.msg = msg + self.code = 400 + + def __str__(self): + return "Invalid PROXY header: %s" % self.msg + + class ForbiddenProxyRequest(ParseException): def __init__(self, host): self.host = host diff --git a/gunicorn/http/message.py b/gunicorn/http/message.py index 4e8dd444..81132b34 100644 --- a/gunicorn/http/message.py +++ b/gunicorn/http/message.py @@ -2,10 +2,11 @@ # This file is part of gunicorn released under the MIT license. # See the NOTICE for more information. -import io +from enum import IntEnum import ipaddress import re import socket +import struct from gunicorn.http.body import ChunkedReader, LengthReader, EOFReader, Body from gunicorn.http.errors import ( @@ -14,10 +15,36 @@ from gunicorn.http.errors import ( LimitRequestLine, LimitRequestHeaders, UnsupportedTransferCoding, ObsoleteFolding, ) -from gunicorn.http.errors import InvalidProxyLine, ForbiddenProxyRequest +from gunicorn.http.errors import InvalidProxyLine, InvalidProxyHeader, ForbiddenProxyRequest from gunicorn.http.errors import InvalidSchemeHeaders from gunicorn.util import bytes_to_str, split_request_uri + +# PROXY protocol v2 constants +PP_V2_SIGNATURE = b"\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A" + + +class PPCommand(IntEnum): + """PROXY protocol v2 commands.""" + LOCAL = 0x0 + PROXY = 0x1 + + +class PPFamily(IntEnum): + """PROXY protocol v2 address families.""" + UNSPEC = 0x0 + INET = 0x1 # IPv4 + INET6 = 0x2 # IPv6 + UNIX = 0x3 + + +class PPProtocol(IntEnum): + """PROXY protocol v2 transport protocols.""" + UNSPEC = 0x0 + STREAM = 0x1 # TCP + DGRAM = 0x2 # UDP + + MAX_REQUEST_LINE = 8190 MAX_HEADERS = 32768 DEFAULT_MAX_HEADERFIELD_SIZE = 8190 @@ -283,26 +310,21 @@ class Request(Message): buf.write(data) def parse(self, unreader): - buf = io.BytesIO() - self.get_data(unreader, buf, stop=True) + buf = bytearray() + self.read_into(unreader, buf, stop=True) - # get request line - line, rbuf = self.read_line(unreader, buf, self.limit_request_line) + # Handle proxy protocol if enabled and this is the first request + mode = self.cfg.proxy_protocol + if mode != "off" and self.req_number == 1: + buf = self._handle_proxy_protocol(unreader, buf, mode) - # proxy protocol - if self.proxy_protocol(bytes_to_str(line)): - # get next request line - buf = io.BytesIO() - buf.write(rbuf) - line, rbuf = self.read_line(unreader, buf, self.limit_request_line) + # Get request line + line, buf = self.read_line(unreader, buf, self.limit_request_line) self.parse_request_line(line) - buf = io.BytesIO() - buf.write(rbuf) # Headers - data = buf.getvalue() - idx = data.find(b"\r\n\r\n") + data = bytes(buf) done = data[:2] == b"\r\n" while True: @@ -310,8 +332,8 @@ class Request(Message): done = data[:2] == b"\r\n" if idx < 0 and not done: - self.get_data(unreader, buf) - data = buf.getvalue() + self.read_into(unreader, buf) + data = bytes(buf) if len(data) > self.max_buffer_headers: raise LimitRequestHeaders("max buffer headers") else: @@ -324,11 +346,20 @@ class Request(Message): self.headers = self.parse_headers(data[:idx], from_trailer=False) ret = data[idx + 4:] - buf = None return ret + def read_into(self, unreader, buf, stop=False): + """Read data from unreader and append to bytearray buffer.""" + data = unreader.read() + if not data: + if stop: + raise StopIteration() + raise NoMoreData(bytes(buf)) + buf.extend(data) + def read_line(self, unreader, buf, limit=0): - data = buf.getvalue() + """Read a line from buffer, returning (line, remaining_buffer).""" + data = bytes(buf) while True: idx = data.find(b"\r\n") @@ -339,40 +370,61 @@ class Request(Message): break if len(data) - 2 > limit > 0: raise LimitRequestLine(len(data), limit) - self.get_data(unreader, buf) - data = buf.getvalue() + self.read_into(unreader, buf) + data = bytes(buf) return (data[:idx], # request line, - data[idx + 2:]) # residue in the buffer, skip \r\n + bytearray(data[idx + 2:])) # residue in the buffer, skip \r\n - def proxy_protocol(self, line): - """\ - Detect, check and parse proxy protocol. + def read_bytes(self, unreader, buf, count): + """Read exactly count bytes from buffer/unreader.""" + while len(buf) < count: + self.read_into(unreader, buf) + return bytes(buf[:count]), bytearray(buf[count:]) - :raises: ForbiddenProxyRequest, InvalidProxyLine. - :return: True for proxy protocol line else False + def _handle_proxy_protocol(self, unreader, buf, mode): + """Handle PROXY protocol detection and parsing. + + Returns the buffer with proxy protocol data consumed. """ - if not self.cfg.proxy_protocol: - return False + # Ensure we have enough data to detect v2 signature (12 bytes) + while len(buf) < 12: + self.read_into(unreader, buf) - if self.req_number != 1: - return False + # Check for v2 signature first + if mode in ("v2", "auto") and buf[:12] == PP_V2_SIGNATURE: + self.proxy_protocol_access_check() + return self._parse_proxy_protocol_v2(unreader, buf) - if not line.startswith("PROXY"): - return False + # Check for v1 prefix + if mode in ("v1", "auto") and buf[:6] == b"PROXY ": + self.proxy_protocol_access_check() + return self._parse_proxy_protocol_v1(unreader, buf) - self.proxy_protocol_access_check() - self.parse_proxy_protocol(line) - - return True + # Not proxy protocol - return buffer unchanged + return buf def proxy_protocol_access_check(self): - # check in allow list + """Check if proxy protocol is allowed from this peer.""" if (isinstance(self.peer_addr, tuple) and not _ip_in_allow_list(self.peer_addr[0], self.cfg.proxy_allow_ips)): raise ForbiddenProxyRequest(self.peer_addr[0]) - def parse_proxy_protocol(self, line): + def _parse_proxy_protocol_v1(self, unreader, buf): + """Parse PROXY protocol v1 (text format). + + Returns buffer with v1 header consumed. + """ + # Read until we find \r\n + data = bytes(buf) + while b"\r\n" not in data: + self.read_into(unreader, buf) + data = bytes(buf) + + idx = data.find(b"\r\n") + line = bytes_to_str(data[:idx]) + remaining = bytearray(data[idx + 2:]) + bits = line.split(" ") if len(bits) != 6: @@ -417,6 +469,101 @@ class Request(Message): "proxy_port": d_port } + return remaining + + def _parse_proxy_protocol_v2(self, unreader, buf): + """Parse PROXY protocol v2 (binary format). + + Returns buffer with v2 header consumed. + """ + # We need at least 16 bytes for the header (12 signature + 4 header) + while len(buf) < 16: + self.read_into(unreader, buf) + + # Parse header fields (after 12-byte signature) + ver_cmd = buf[12] + fam_proto = buf[13] + length = struct.unpack(">H", bytes(buf[14:16]))[0] + + # Validate version (high nibble must be 0x2) + version = (ver_cmd & 0xF0) >> 4 + if version != 2: + raise InvalidProxyHeader("unsupported version %d" % version) + + # Extract command (low nibble) + command = ver_cmd & 0x0F + if command not in (PPCommand.LOCAL, PPCommand.PROXY): + raise InvalidProxyHeader("unsupported command %d" % command) + + # Ensure we have the complete header + total_header_size = 16 + length + while len(buf) < total_header_size: + self.read_into(unreader, buf) + + # For LOCAL command, no address info is provided + if command == PPCommand.LOCAL: + self.proxy_protocol_info = { + "proxy_protocol": "LOCAL", + "client_addr": None, + "client_port": None, + "proxy_addr": None, + "proxy_port": None + } + return bytearray(buf[total_header_size:]) + + # Extract address family and protocol + family = (fam_proto & 0xF0) >> 4 + protocol = fam_proto & 0x0F + + # We only support TCP (STREAM) + if protocol != PPProtocol.STREAM: + raise InvalidProxyHeader("only TCP protocol is supported") + + addr_data = bytes(buf[16:16 + length]) + + if family == PPFamily.INET: # IPv4 + if length < 12: # 4+4+2+2 + raise InvalidProxyHeader("insufficient address data for IPv4") + s_addr = socket.inet_ntop(socket.AF_INET, addr_data[0:4]) + d_addr = socket.inet_ntop(socket.AF_INET, addr_data[4:8]) + s_port = struct.unpack(">H", addr_data[8:10])[0] + d_port = struct.unpack(">H", addr_data[10:12])[0] + proto = "TCP4" + + elif family == PPFamily.INET6: # IPv6 + if length < 36: # 16+16+2+2 + raise InvalidProxyHeader("insufficient address data for IPv6") + s_addr = socket.inet_ntop(socket.AF_INET6, addr_data[0:16]) + d_addr = socket.inet_ntop(socket.AF_INET6, addr_data[16:32]) + s_port = struct.unpack(">H", addr_data[32:34])[0] + d_port = struct.unpack(">H", addr_data[34:36])[0] + proto = "TCP6" + + elif family == PPFamily.UNSPEC: + # No address info provided with PROXY command + self.proxy_protocol_info = { + "proxy_protocol": "UNSPEC", + "client_addr": None, + "client_port": None, + "proxy_addr": None, + "proxy_port": None + } + return bytearray(buf[total_header_size:]) + + else: + raise InvalidProxyHeader("unsupported address family %d" % family) + + # Set data + self.proxy_protocol_info = { + "proxy_protocol": proto, + "client_addr": s_addr, + "client_port": s_port, + "proxy_addr": d_addr, + "proxy_port": d_port + } + + return bytearray(buf[total_header_size:]) + def parse_request_line(self, line_bytes): bits = [bytes_to_str(bit) for bit in line_bytes.split(b" ", 2)] if len(bits) != 3: diff --git a/tests/requests/valid/pp_03.http b/tests/requests/valid/pp_03.http new file mode 100644 index 00000000..5a2f784f --- /dev/null +++ b/tests/requests/valid/pp_03.http @@ -0,0 +1,4 @@ +GET /no/proxy/header HTTP/1.1\r\n +Host: example.com\r\n +Content-Length: 0\r\n +\r\n diff --git a/tests/requests/valid/pp_03.py b/tests/requests/valid/pp_03.py new file mode 100644 index 00000000..70112876 --- /dev/null +++ b/tests/requests/valid/pp_03.py @@ -0,0 +1,15 @@ +from gunicorn.config import Config + +cfg = Config() +cfg.set("proxy_protocol", True) + +request = { + "method": "GET", + "uri": uri("/no/proxy/header"), + "version": (1, 1), + "headers": [ + ("HOST", "example.com"), + ("CONTENT-LENGTH", "0") + ], + "body": b"" +} diff --git a/tests/requests/valid/pp_04.http b/tests/requests/valid/pp_04.http new file mode 100644 index 00000000..f4e9ec95 --- /dev/null +++ b/tests/requests/valid/pp_04.http @@ -0,0 +1,4 @@ +\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A\x21\x11\x00\x0C\xC0\xA8\x01\x0A\xC0\xA8\x01\x01\x30\x39\x01\xBBGET /proxy/v2/ipv4 HTTP/1.1\r\n +Host: example.com\r\n +Content-Length: 0\r\n +\r\n diff --git a/tests/requests/valid/pp_04.py b/tests/requests/valid/pp_04.py new file mode 100644 index 00000000..cbf6e7a8 --- /dev/null +++ b/tests/requests/valid/pp_04.py @@ -0,0 +1,15 @@ +from gunicorn.config import Config + +cfg = Config() +cfg.set("proxy_protocol", True) + +request = { + "method": "GET", + "uri": uri("/proxy/v2/ipv4"), + "version": (1, 1), + "headers": [ + ("HOST", "example.com"), + ("CONTENT-LENGTH", "0") + ], + "body": b"" +} diff --git a/tests/requests/valid/pp_05.http b/tests/requests/valid/pp_05.http new file mode 100644 index 00000000..616bde29 --- /dev/null +++ b/tests/requests/valid/pp_05.http @@ -0,0 +1,4 @@ +\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A\x21\x21\x00\x24\x20\x01\x0D\xB8\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x20\x01\x0D\xB8\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xD4\x31\x00\x50GET /proxy/v2/ipv6 HTTP/1.1\r\n +Host: example.com\r\n +Content-Length: 0\r\n +\r\n diff --git a/tests/requests/valid/pp_05.py b/tests/requests/valid/pp_05.py new file mode 100644 index 00000000..80e2b764 --- /dev/null +++ b/tests/requests/valid/pp_05.py @@ -0,0 +1,15 @@ +from gunicorn.config import Config + +cfg = Config() +cfg.set("proxy_protocol", True) + +request = { + "method": "GET", + "uri": uri("/proxy/v2/ipv6"), + "version": (1, 1), + "headers": [ + ("HOST", "example.com"), + ("CONTENT-LENGTH", "0") + ], + "body": b"" +} diff --git a/tests/test_asgi.py b/tests/test_asgi.py index 227f7ea2..e39ae91a 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -8,6 +8,7 @@ Tests for ASGI worker components. import asyncio import io +import ipaddress import pytest from unittest import mock @@ -48,9 +49,9 @@ class MockConfig: def __init__(self): self.is_ssl = False - self.proxy_protocol = False - self.proxy_allow_ips = ["127.0.0.1"] - self.forwarded_allow_ips = ["127.0.0.1"] + self.proxy_protocol = "off" + self.proxy_allow_ips = [ipaddress.ip_network("127.0.0.1")] + self.forwarded_allow_ips = [ipaddress.ip_network("127.0.0.1")] self.secure_scheme_headers = {} self.forwarder_headers = [] self.limit_request_line = 8190 diff --git a/tests/test_gthread.py b/tests/test_gthread.py index 0762cc99..b8839fa1 100644 --- a/tests/test_gthread.py +++ b/tests/test_gthread.py @@ -1385,7 +1385,7 @@ class TestKeepaliveBlockingMode: conn.parser = mock_parser # Mock handle_request to invoke wsgi - original_handle_request = worker.handle_request + _ = worker.handle_request # save reference before overwriting def mock_handle_request(req, conn): # Simplified version that just calls wsgi diff --git a/tests/treq.py b/tests/treq.py index fbe54700..e341780c 100644 --- a/tests/treq.py +++ b/tests/treq.py @@ -39,6 +39,27 @@ def load_py(fname): return vars(mod) +def decode_hex_escapes(data): + """Decode hex escape sequences like \\xAB in test data.""" + import re + result = bytearray() + i = 0 + while i < len(data): + # Check for \xHH hex escape + if i + 3 < len(data) and data[i:i+2] == b'\\x': + hex_chars = data[i+2:i+4] + try: + byte_val = int(hex_chars, 16) + result.append(byte_val) + i += 4 + continue + except ValueError: + pass + result.append(data[i]) + i += 1 + return bytes(result) + + class request: def __init__(self, fname, expect): self.fname = fname @@ -52,8 +73,10 @@ class request: self.data = handle.read() self.data = self.data.replace(b"\n", b"").replace(b"\\r\\n", b"\r\n") self.data = self.data.replace(b"\\0", b"\000").replace(b"\\n", b"\n").replace(b"\\t", b"\t") + # Handle hex escape sequences for binary data (e.g., \x0D for PROXY v2) + self.data = decode_hex_escapes(self.data) if b"\\" in self.data: - raise AssertionError("Unexpected backslash in test data - only handling HTAB, NUL and CRLF") + raise AssertionError("Unexpected backslash in test data - only handling HTAB, NUL, CRLF, and hex escapes") # Functions for sending data to the parser. # These functions mock out reading from a