feat: add PROXY protocol v2 support with version selection (#3451)

Extend --proxy-protocol to accept version values (off, v1, v2, auto) instead
of being boolean-only. This allows explicit control over which PROXY protocol
versions are accepted.

Changes:
- Add InvalidProxyHeader exception for v2 binary header errors
- Add validate_proxy_protocol() validator with backwards compatibility
- Update ProxyProtocol setting with nargs="?" and const="auto"
- Add PROXY v2 constants (PP_V2_SIGNATURE, PPCommand, PPFamily, PPProtocol)
- Add _parse_proxy_protocol_v1() and _parse_proxy_protocol_v2() methods
- Update both sync (message.py) and async (asgi/message.py) parsers
- Add hex escape handling in treq.py for v2 binary test data
- Add test cases for v2 TCPv4 and TCPv6

Backwards compatible: --proxy-protocol alone (or True) maps to "auto".

Closes #2912
This commit is contained in:
Benoit Chesneau 2026-01-23 18:40:44 +01:00 committed by GitHub
parent f95ac41b8f
commit f3190f84cc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 522 additions and 102 deletions

View File

@ -1148,16 +1148,27 @@ command line arguments to control server configuration instead.
### `proxy_protocol` ### `proxy_protocol`
**Command line:** `--proxy-protocol` **Command line:** `--proxy-protocol MODE`
**Default:** `False` **Default:** `'off'`
Enable detect PROXY protocol (PROXY mode). Enable PROXY protocol support.
Allow using HTTP and Proxy together. It may be useful for work with Allow using HTTP and PROXY protocol together. It may be useful for work
stunnel as HTTPS frontend and Gunicorn as HTTP server. with stunnel as HTTPS frontend and Gunicorn as HTTP server, or with
HAProxy.
PROXY protocol: http://haproxy.1wt.eu/download/1.5/doc/proxy-protocol.txt Accepted values:
* ``off`` - Disabled (default)
* ``v1`` - PROXY protocol v1 only (text format)
* ``v2`` - PROXY protocol v2 only (binary format)
* ``auto`` - Auto-detect v1 or v2
Using ``--proxy-protocol`` without a value is equivalent to ``auto``.
PROXY protocol v1: http://haproxy.1wt.eu/download/1.5/doc/proxy-protocol.txt
PROXY protocol v2: https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt
Example for stunnel config:: Example for stunnel config::
@ -1168,6 +1179,9 @@ Example for stunnel config::
cert = /etc/ssl/certs/stunnel.pem cert = /etc/ssl/certs/stunnel.pem
key = /etc/ssl/certs/stunnel.key key = /etc/ssl/certs/stunnel.key
!!! info "Changed in 24.0.0"
Extended to support version selection (v1, v2, auto).
### `proxy_allow_ips` ### `proxy_allow_ips`
**Command line:** `--proxy-allow-from` **Command line:** `--proxy-allow-from`

View File

@ -9,17 +9,22 @@ Reuses the parsing logic from the sync version, adapted for async I/O.
""" """
import io import io
import ipaddress
import re import re
import socket import socket
import struct
from gunicorn.http.errors import ( from gunicorn.http.errors import (
InvalidHeader, InvalidHeaderName, NoMoreData, InvalidHeader, InvalidHeaderName, NoMoreData,
InvalidRequestLine, InvalidRequestMethod, InvalidHTTPVersion, InvalidRequestLine, InvalidRequestMethod, InvalidHTTPVersion,
LimitRequestLine, LimitRequestHeaders, LimitRequestLine, LimitRequestHeaders,
UnsupportedTransferCoding, ObsoleteFolding, UnsupportedTransferCoding, ObsoleteFolding,
InvalidProxyLine, ForbiddenProxyRequest, InvalidProxyLine, InvalidProxyHeader, ForbiddenProxyRequest,
InvalidSchemeHeaders, InvalidSchemeHeaders,
) )
from gunicorn.http.message import (
PP_V2_SIGNATURE, PPCommand, PPFamily, PPProtocol
)
from gunicorn.util import bytes_to_str, split_request_uri from gunicorn.util import bytes_to_str, split_request_uri
MAX_REQUEST_LINE = 8190 MAX_REQUEST_LINE = 8190
@ -34,6 +39,22 @@ VERSION_RE = re.compile(r"HTTP/(\d)\.(\d)")
RFC9110_5_5_INVALID_AND_DANGEROUS = re.compile(r"[\0\r\n]") RFC9110_5_5_INVALID_AND_DANGEROUS = re.compile(r"[\0\r\n]")
def _ip_in_allow_list(ip_str, allow_list):
"""Check if IP address is in the allow list (which may contain networks)."""
if '*' in allow_list:
return True
try:
ip = ipaddress.ip_address(ip_str)
except ValueError:
return False
for network in allow_list:
if network == '*':
return True
if ip in network:
return True
return False
class AsyncRequest: class AsyncRequest:
"""Async HTTP request parser. """Async HTTP request parser.
@ -111,33 +132,29 @@ class AsyncRequest:
async def _parse(self): async def _parse(self):
"""Parse the request from the unreader.""" """Parse the request from the unreader."""
buf = io.BytesIO() buf = bytearray()
await self._get_data(buf, stop=True) await self._read_into(buf, stop=True)
# Handle proxy protocol if enabled and this is the first request
mode = self.cfg.proxy_protocol
if mode != "off" and self.req_number == 1:
buf = await self._handle_proxy_protocol(buf, mode)
# Get request line # Get request line
line, rbuf = await self._read_line(buf, self.limit_request_line) line, buf = await self._read_line(buf, self.limit_request_line)
# Proxy protocol
if self._proxy_protocol(bytes_to_str(line)):
# Get next request line
buf = io.BytesIO()
buf.write(rbuf)
line, rbuf = await self._read_line(buf, self.limit_request_line)
self._parse_request_line(line) self._parse_request_line(line)
buf = io.BytesIO()
buf.write(rbuf)
# Headers # Headers
data = buf.getvalue() data = bytes(buf)
while True: while True:
idx = data.find(b"\r\n\r\n") idx = data.find(b"\r\n\r\n")
done = data[:2] == b"\r\n" done = data[:2] == b"\r\n"
if idx < 0 and not done: if idx < 0 and not done:
await self._get_data(buf) await self._read_into(buf)
data = buf.getvalue() data = bytes(buf)
if len(data) > self.max_buffer_headers: if len(data) > self.max_buffer_headers:
raise LimitRequestHeaders("max buffer headers") raise LimitRequestHeaders("max buffer headers")
else: else:
@ -151,18 +168,18 @@ class AsyncRequest:
self._set_body_reader() self._set_body_reader()
async def _get_data(self, buf, stop=False): async def _read_into(self, buf, stop=False):
"""Read data from unreader into buffer.""" """Read data from unreader and append to bytearray buffer."""
data = await self.unreader.read() data = await self.unreader.read()
if not data: if not data:
if stop: if stop:
raise StopIteration() raise StopIteration()
raise NoMoreData(buf.getvalue()) raise NoMoreData(bytes(buf))
buf.write(data) buf.extend(data)
async def _read_line(self, buf, limit=0): async def _read_line(self, buf, limit=0):
"""Read a line from the buffer/stream.""" """Read a line from buffer, returning (line, remaining_buffer)."""
data = buf.getvalue() data = bytes(buf)
while True: while True:
idx = data.find(b"\r\n") idx = data.find(b"\r\n")
@ -172,36 +189,54 @@ class AsyncRequest:
break break
if len(data) - 2 > limit > 0: if len(data) - 2 > limit > 0:
raise LimitRequestLine(len(data), limit) raise LimitRequestLine(len(data), limit)
await self._get_data(buf) await self._read_into(buf)
data = buf.getvalue() data = bytes(buf)
return (data[:idx], data[idx + 2:]) return (data[:idx], bytearray(data[idx + 2:]))
def _proxy_protocol(self, line): async def _handle_proxy_protocol(self, buf, mode):
"""Detect, check and parse proxy protocol.""" """Handle PROXY protocol detection and parsing.
if not self.cfg.proxy_protocol:
return False
if self.req_number != 1: Returns the buffer with proxy protocol data consumed.
return False """
# Ensure we have enough data to detect v2 signature (12 bytes)
while len(buf) < 12:
await self._read_into(buf)
if not line.startswith("PROXY"): # Check for v2 signature first
return False if mode in ("v2", "auto") and buf[:12] == PP_V2_SIGNATURE:
self._proxy_protocol_access_check()
return await self._parse_proxy_protocol_v2(buf)
self._proxy_protocol_access_check() # Check for v1 prefix
self._parse_proxy_protocol(line) if mode in ("v1", "auto") and buf[:6] == b"PROXY ":
self._proxy_protocol_access_check()
return await self._parse_proxy_protocol_v1(buf)
return True # Not proxy protocol - return buffer unchanged
return buf
def _proxy_protocol_access_check(self): def _proxy_protocol_access_check(self):
"""Check if proxy protocol is allowed from this peer.""" """Check if proxy protocol is allowed from this peer."""
if ("*" not in self.cfg.proxy_allow_ips and if (isinstance(self.peer_addr, tuple) and
isinstance(self.peer_addr, tuple) and not _ip_in_allow_list(self.peer_addr[0], self.cfg.proxy_allow_ips)):
self.peer_addr[0] not in self.cfg.proxy_allow_ips):
raise ForbiddenProxyRequest(self.peer_addr[0]) raise ForbiddenProxyRequest(self.peer_addr[0])
def _parse_proxy_protocol(self, line): async def _parse_proxy_protocol_v1(self, buf):
"""Parse proxy protocol header line.""" """Parse PROXY protocol v1 (text format).
Returns buffer with v1 header consumed.
"""
# Read until we find \r\n
data = bytes(buf)
while b"\r\n" not in data:
await self._read_into(buf)
data = bytes(buf)
idx = data.find(b"\r\n")
line = bytes_to_str(data[:idx])
remaining = bytearray(data[idx + 2:])
bits = line.split(" ") bits = line.split(" ")
if len(bits) != 6: if len(bits) != 6:
@ -244,6 +279,101 @@ class AsyncRequest:
"proxy_port": d_port "proxy_port": d_port
} }
return remaining
async def _parse_proxy_protocol_v2(self, buf):
"""Parse PROXY protocol v2 (binary format).
Returns buffer with v2 header consumed.
"""
# We need at least 16 bytes for the header (12 signature + 4 header)
while len(buf) < 16:
await self._read_into(buf)
# Parse header fields (after 12-byte signature)
ver_cmd = buf[12]
fam_proto = buf[13]
length = struct.unpack(">H", bytes(buf[14:16]))[0]
# Validate version (high nibble must be 0x2)
version = (ver_cmd & 0xF0) >> 4
if version != 2:
raise InvalidProxyHeader("unsupported version %d" % version)
# Extract command (low nibble)
command = ver_cmd & 0x0F
if command not in (PPCommand.LOCAL, PPCommand.PROXY):
raise InvalidProxyHeader("unsupported command %d" % command)
# Ensure we have the complete header
total_header_size = 16 + length
while len(buf) < total_header_size:
await self._read_into(buf)
# For LOCAL command, no address info is provided
if command == PPCommand.LOCAL:
self.proxy_protocol_info = {
"proxy_protocol": "LOCAL",
"client_addr": None,
"client_port": None,
"proxy_addr": None,
"proxy_port": None
}
return bytearray(buf[total_header_size:])
# Extract address family and protocol
family = (fam_proto & 0xF0) >> 4
protocol = fam_proto & 0x0F
# We only support TCP (STREAM)
if protocol != PPProtocol.STREAM:
raise InvalidProxyHeader("only TCP protocol is supported")
addr_data = bytes(buf[16:16 + length])
if family == PPFamily.INET: # IPv4
if length < 12: # 4+4+2+2
raise InvalidProxyHeader("insufficient address data for IPv4")
s_addr = socket.inet_ntop(socket.AF_INET, addr_data[0:4])
d_addr = socket.inet_ntop(socket.AF_INET, addr_data[4:8])
s_port = struct.unpack(">H", addr_data[8:10])[0]
d_port = struct.unpack(">H", addr_data[10:12])[0]
proto = "TCP4"
elif family == PPFamily.INET6: # IPv6
if length < 36: # 16+16+2+2
raise InvalidProxyHeader("insufficient address data for IPv6")
s_addr = socket.inet_ntop(socket.AF_INET6, addr_data[0:16])
d_addr = socket.inet_ntop(socket.AF_INET6, addr_data[16:32])
s_port = struct.unpack(">H", addr_data[32:34])[0]
d_port = struct.unpack(">H", addr_data[34:36])[0]
proto = "TCP6"
elif family == PPFamily.UNSPEC:
# No address info provided with PROXY command
self.proxy_protocol_info = {
"proxy_protocol": "UNSPEC",
"client_addr": None,
"client_port": None,
"proxy_addr": None,
"proxy_port": None
}
return bytearray(buf[total_header_size:])
else:
raise InvalidProxyHeader("unsupported address family %d" % family)
# Set data
self.proxy_protocol_info = {
"proxy_protocol": proto,
"client_addr": s_addr,
"client_port": s_port,
"proxy_addr": d_addr,
"proxy_port": d_port
}
return bytearray(buf[total_header_size:])
def _parse_request_line(self, line_bytes): def _parse_request_line(self, line_bytes):
"""Parse the HTTP request line.""" """Parse the HTTP request line."""
bits = [bytes_to_str(bit) for bit in line_bytes.split(b" ", 2)] bits = [bytes_to_str(bit) for bit in line_bytes.split(b" ", 2)]
@ -299,9 +429,8 @@ class AsyncRequest:
forwarder_headers = [] forwarder_headers = []
if from_trailer: if from_trailer:
pass pass
elif ('*' in cfg.forwarded_allow_ips or elif (not isinstance(self.peer_addr, tuple)
not isinstance(self.peer_addr, tuple) or _ip_in_allow_list(self.peer_addr[0], cfg.forwarded_allow_ips)):
or self.peer_addr[0] in cfg.forwarded_allow_ips):
secure_scheme_headers = cfg.secure_scheme_headers secure_scheme_headers = cfg.secure_scheme_headers
forwarder_headers = cfg.forwarder_headers forwarder_headers = cfg.forwarder_headers

View File

@ -2082,20 +2082,57 @@ class NewSSLContext(Setting):
""" """
def validate_proxy_protocol(val):
"""Validate proxy_protocol setting.
Accepts: off, false, v1, v2, auto, true
Returns normalized value: off, v1, v2, or auto
"""
if val is None:
return "off"
if isinstance(val, bool):
return "auto" if val else "off"
if not isinstance(val, str):
raise TypeError("proxy_protocol must be string or bool")
val = val.lower().strip()
mapping = {
"false": "off", "off": "off", "0": "off", "none": "off",
"true": "auto", "auto": "auto", "1": "auto",
"v1": "v1", "v2": "v2",
}
if val not in mapping:
raise ValueError("proxy_protocol must be: off, v1, v2, or auto")
return mapping[val]
class ProxyProtocol(Setting): class ProxyProtocol(Setting):
name = "proxy_protocol" name = "proxy_protocol"
section = "Server Mechanics" section = "Server Mechanics"
cli = ["--proxy-protocol"] cli = ["--proxy-protocol"]
validator = validate_bool meta = "MODE"
default = False validator = validate_proxy_protocol
action = "store_true" default = "off"
nargs = "?"
const = "auto"
desc = """\ desc = """\
Enable detect PROXY protocol (PROXY mode). Enable PROXY protocol support.
Allow using HTTP and Proxy together. It may be useful for work with Allow using HTTP and PROXY protocol together. It may be useful for work
stunnel as HTTPS frontend and Gunicorn as HTTP server. with stunnel as HTTPS frontend and Gunicorn as HTTP server, or with
HAProxy.
PROXY protocol: http://haproxy.1wt.eu/download/1.5/doc/proxy-protocol.txt Accepted values:
* ``off`` - Disabled (default)
* ``v1`` - PROXY protocol v1 only (text format)
* ``v2`` - PROXY protocol v2 only (binary format)
* ``auto`` - Auto-detect v1 or v2
Using ``--proxy-protocol`` without a value is equivalent to ``auto``.
PROXY protocol v1: http://haproxy.1wt.eu/download/1.5/doc/proxy-protocol.txt
PROXY protocol v2: https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt
Example for stunnel config:: Example for stunnel config::
@ -2105,6 +2142,9 @@ class ProxyProtocol(Setting):
connect = 80 connect = 80
cert = /etc/ssl/certs/stunnel.pem cert = /etc/ssl/certs/stunnel.pem
key = /etc/ssl/certs/stunnel.key key = /etc/ssl/certs/stunnel.key
.. versionchanged:: 24.0.0
Extended to support version selection (v1, v2, auto).
""" """

View File

@ -131,6 +131,15 @@ class InvalidProxyLine(ParseException):
return "Invalid PROXY line: %r" % self.line return "Invalid PROXY line: %r" % self.line
class InvalidProxyHeader(ParseException):
def __init__(self, msg):
self.msg = msg
self.code = 400
def __str__(self):
return "Invalid PROXY header: %s" % self.msg
class ForbiddenProxyRequest(ParseException): class ForbiddenProxyRequest(ParseException):
def __init__(self, host): def __init__(self, host):
self.host = host self.host = host

View File

@ -2,10 +2,11 @@
# This file is part of gunicorn released under the MIT license. # This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information. # See the NOTICE for more information.
import io from enum import IntEnum
import ipaddress import ipaddress
import re import re
import socket import socket
import struct
from gunicorn.http.body import ChunkedReader, LengthReader, EOFReader, Body from gunicorn.http.body import ChunkedReader, LengthReader, EOFReader, Body
from gunicorn.http.errors import ( from gunicorn.http.errors import (
@ -14,10 +15,36 @@ from gunicorn.http.errors import (
LimitRequestLine, LimitRequestHeaders, LimitRequestLine, LimitRequestHeaders,
UnsupportedTransferCoding, ObsoleteFolding, UnsupportedTransferCoding, ObsoleteFolding,
) )
from gunicorn.http.errors import InvalidProxyLine, ForbiddenProxyRequest from gunicorn.http.errors import InvalidProxyLine, InvalidProxyHeader, ForbiddenProxyRequest
from gunicorn.http.errors import InvalidSchemeHeaders from gunicorn.http.errors import InvalidSchemeHeaders
from gunicorn.util import bytes_to_str, split_request_uri from gunicorn.util import bytes_to_str, split_request_uri
# PROXY protocol v2 constants
PP_V2_SIGNATURE = b"\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"
class PPCommand(IntEnum):
"""PROXY protocol v2 commands."""
LOCAL = 0x0
PROXY = 0x1
class PPFamily(IntEnum):
"""PROXY protocol v2 address families."""
UNSPEC = 0x0
INET = 0x1 # IPv4
INET6 = 0x2 # IPv6
UNIX = 0x3
class PPProtocol(IntEnum):
"""PROXY protocol v2 transport protocols."""
UNSPEC = 0x0
STREAM = 0x1 # TCP
DGRAM = 0x2 # UDP
MAX_REQUEST_LINE = 8190 MAX_REQUEST_LINE = 8190
MAX_HEADERS = 32768 MAX_HEADERS = 32768
DEFAULT_MAX_HEADERFIELD_SIZE = 8190 DEFAULT_MAX_HEADERFIELD_SIZE = 8190
@ -283,26 +310,21 @@ class Request(Message):
buf.write(data) buf.write(data)
def parse(self, unreader): def parse(self, unreader):
buf = io.BytesIO() buf = bytearray()
self.get_data(unreader, buf, stop=True) self.read_into(unreader, buf, stop=True)
# get request line # Handle proxy protocol if enabled and this is the first request
line, rbuf = self.read_line(unreader, buf, self.limit_request_line) mode = self.cfg.proxy_protocol
if mode != "off" and self.req_number == 1:
buf = self._handle_proxy_protocol(unreader, buf, mode)
# proxy protocol # Get request line
if self.proxy_protocol(bytes_to_str(line)): line, buf = self.read_line(unreader, buf, self.limit_request_line)
# get next request line
buf = io.BytesIO()
buf.write(rbuf)
line, rbuf = self.read_line(unreader, buf, self.limit_request_line)
self.parse_request_line(line) self.parse_request_line(line)
buf = io.BytesIO()
buf.write(rbuf)
# Headers # Headers
data = buf.getvalue() data = bytes(buf)
idx = data.find(b"\r\n\r\n")
done = data[:2] == b"\r\n" done = data[:2] == b"\r\n"
while True: while True:
@ -310,8 +332,8 @@ class Request(Message):
done = data[:2] == b"\r\n" done = data[:2] == b"\r\n"
if idx < 0 and not done: if idx < 0 and not done:
self.get_data(unreader, buf) self.read_into(unreader, buf)
data = buf.getvalue() data = bytes(buf)
if len(data) > self.max_buffer_headers: if len(data) > self.max_buffer_headers:
raise LimitRequestHeaders("max buffer headers") raise LimitRequestHeaders("max buffer headers")
else: else:
@ -324,11 +346,20 @@ class Request(Message):
self.headers = self.parse_headers(data[:idx], from_trailer=False) self.headers = self.parse_headers(data[:idx], from_trailer=False)
ret = data[idx + 4:] ret = data[idx + 4:]
buf = None
return ret return ret
def read_into(self, unreader, buf, stop=False):
"""Read data from unreader and append to bytearray buffer."""
data = unreader.read()
if not data:
if stop:
raise StopIteration()
raise NoMoreData(bytes(buf))
buf.extend(data)
def read_line(self, unreader, buf, limit=0): def read_line(self, unreader, buf, limit=0):
data = buf.getvalue() """Read a line from buffer, returning (line, remaining_buffer)."""
data = bytes(buf)
while True: while True:
idx = data.find(b"\r\n") idx = data.find(b"\r\n")
@ -339,40 +370,61 @@ class Request(Message):
break break
if len(data) - 2 > limit > 0: if len(data) - 2 > limit > 0:
raise LimitRequestLine(len(data), limit) raise LimitRequestLine(len(data), limit)
self.get_data(unreader, buf) self.read_into(unreader, buf)
data = buf.getvalue() data = bytes(buf)
return (data[:idx], # request line, return (data[:idx], # request line,
data[idx + 2:]) # residue in the buffer, skip \r\n bytearray(data[idx + 2:])) # residue in the buffer, skip \r\n
def proxy_protocol(self, line): def read_bytes(self, unreader, buf, count):
"""\ """Read exactly count bytes from buffer/unreader."""
Detect, check and parse proxy protocol. while len(buf) < count:
self.read_into(unreader, buf)
return bytes(buf[:count]), bytearray(buf[count:])
:raises: ForbiddenProxyRequest, InvalidProxyLine. def _handle_proxy_protocol(self, unreader, buf, mode):
:return: True for proxy protocol line else False """Handle PROXY protocol detection and parsing.
Returns the buffer with proxy protocol data consumed.
""" """
if not self.cfg.proxy_protocol: # Ensure we have enough data to detect v2 signature (12 bytes)
return False while len(buf) < 12:
self.read_into(unreader, buf)
if self.req_number != 1: # Check for v2 signature first
return False if mode in ("v2", "auto") and buf[:12] == PP_V2_SIGNATURE:
self.proxy_protocol_access_check()
return self._parse_proxy_protocol_v2(unreader, buf)
if not line.startswith("PROXY"): # Check for v1 prefix
return False if mode in ("v1", "auto") and buf[:6] == b"PROXY ":
self.proxy_protocol_access_check()
return self._parse_proxy_protocol_v1(unreader, buf)
self.proxy_protocol_access_check() # Not proxy protocol - return buffer unchanged
self.parse_proxy_protocol(line) return buf
return True
def proxy_protocol_access_check(self): def proxy_protocol_access_check(self):
# check in allow list """Check if proxy protocol is allowed from this peer."""
if (isinstance(self.peer_addr, tuple) and if (isinstance(self.peer_addr, tuple) and
not _ip_in_allow_list(self.peer_addr[0], self.cfg.proxy_allow_ips)): not _ip_in_allow_list(self.peer_addr[0], self.cfg.proxy_allow_ips)):
raise ForbiddenProxyRequest(self.peer_addr[0]) raise ForbiddenProxyRequest(self.peer_addr[0])
def parse_proxy_protocol(self, line): def _parse_proxy_protocol_v1(self, unreader, buf):
"""Parse PROXY protocol v1 (text format).
Returns buffer with v1 header consumed.
"""
# Read until we find \r\n
data = bytes(buf)
while b"\r\n" not in data:
self.read_into(unreader, buf)
data = bytes(buf)
idx = data.find(b"\r\n")
line = bytes_to_str(data[:idx])
remaining = bytearray(data[idx + 2:])
bits = line.split(" ") bits = line.split(" ")
if len(bits) != 6: if len(bits) != 6:
@ -417,6 +469,101 @@ class Request(Message):
"proxy_port": d_port "proxy_port": d_port
} }
return remaining
def _parse_proxy_protocol_v2(self, unreader, buf):
"""Parse PROXY protocol v2 (binary format).
Returns buffer with v2 header consumed.
"""
# We need at least 16 bytes for the header (12 signature + 4 header)
while len(buf) < 16:
self.read_into(unreader, buf)
# Parse header fields (after 12-byte signature)
ver_cmd = buf[12]
fam_proto = buf[13]
length = struct.unpack(">H", bytes(buf[14:16]))[0]
# Validate version (high nibble must be 0x2)
version = (ver_cmd & 0xF0) >> 4
if version != 2:
raise InvalidProxyHeader("unsupported version %d" % version)
# Extract command (low nibble)
command = ver_cmd & 0x0F
if command not in (PPCommand.LOCAL, PPCommand.PROXY):
raise InvalidProxyHeader("unsupported command %d" % command)
# Ensure we have the complete header
total_header_size = 16 + length
while len(buf) < total_header_size:
self.read_into(unreader, buf)
# For LOCAL command, no address info is provided
if command == PPCommand.LOCAL:
self.proxy_protocol_info = {
"proxy_protocol": "LOCAL",
"client_addr": None,
"client_port": None,
"proxy_addr": None,
"proxy_port": None
}
return bytearray(buf[total_header_size:])
# Extract address family and protocol
family = (fam_proto & 0xF0) >> 4
protocol = fam_proto & 0x0F
# We only support TCP (STREAM)
if protocol != PPProtocol.STREAM:
raise InvalidProxyHeader("only TCP protocol is supported")
addr_data = bytes(buf[16:16 + length])
if family == PPFamily.INET: # IPv4
if length < 12: # 4+4+2+2
raise InvalidProxyHeader("insufficient address data for IPv4")
s_addr = socket.inet_ntop(socket.AF_INET, addr_data[0:4])
d_addr = socket.inet_ntop(socket.AF_INET, addr_data[4:8])
s_port = struct.unpack(">H", addr_data[8:10])[0]
d_port = struct.unpack(">H", addr_data[10:12])[0]
proto = "TCP4"
elif family == PPFamily.INET6: # IPv6
if length < 36: # 16+16+2+2
raise InvalidProxyHeader("insufficient address data for IPv6")
s_addr = socket.inet_ntop(socket.AF_INET6, addr_data[0:16])
d_addr = socket.inet_ntop(socket.AF_INET6, addr_data[16:32])
s_port = struct.unpack(">H", addr_data[32:34])[0]
d_port = struct.unpack(">H", addr_data[34:36])[0]
proto = "TCP6"
elif family == PPFamily.UNSPEC:
# No address info provided with PROXY command
self.proxy_protocol_info = {
"proxy_protocol": "UNSPEC",
"client_addr": None,
"client_port": None,
"proxy_addr": None,
"proxy_port": None
}
return bytearray(buf[total_header_size:])
else:
raise InvalidProxyHeader("unsupported address family %d" % family)
# Set data
self.proxy_protocol_info = {
"proxy_protocol": proto,
"client_addr": s_addr,
"client_port": s_port,
"proxy_addr": d_addr,
"proxy_port": d_port
}
return bytearray(buf[total_header_size:])
def parse_request_line(self, line_bytes): def parse_request_line(self, line_bytes):
bits = [bytes_to_str(bit) for bit in line_bytes.split(b" ", 2)] bits = [bytes_to_str(bit) for bit in line_bytes.split(b" ", 2)]
if len(bits) != 3: if len(bits) != 3:

View File

@ -0,0 +1,4 @@
GET /no/proxy/header HTTP/1.1\r\n
Host: example.com\r\n
Content-Length: 0\r\n
\r\n

View File

@ -0,0 +1,15 @@
from gunicorn.config import Config
cfg = Config()
cfg.set("proxy_protocol", True)
request = {
"method": "GET",
"uri": uri("/no/proxy/header"),
"version": (1, 1),
"headers": [
("HOST", "example.com"),
("CONTENT-LENGTH", "0")
],
"body": b""
}

View File

@ -0,0 +1,4 @@
\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A\x21\x11\x00\x0C\xC0\xA8\x01\x0A\xC0\xA8\x01\x01\x30\x39\x01\xBBGET /proxy/v2/ipv4 HTTP/1.1\r\n
Host: example.com\r\n
Content-Length: 0\r\n
\r\n

View File

@ -0,0 +1,15 @@
from gunicorn.config import Config
cfg = Config()
cfg.set("proxy_protocol", True)
request = {
"method": "GET",
"uri": uri("/proxy/v2/ipv4"),
"version": (1, 1),
"headers": [
("HOST", "example.com"),
("CONTENT-LENGTH", "0")
],
"body": b""
}

View File

@ -0,0 +1,4 @@
\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A\x21\x21\x00\x24\x20\x01\x0D\xB8\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x20\x01\x0D\xB8\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xD4\x31\x00\x50GET /proxy/v2/ipv6 HTTP/1.1\r\n
Host: example.com\r\n
Content-Length: 0\r\n
\r\n

View File

@ -0,0 +1,15 @@
from gunicorn.config import Config
cfg = Config()
cfg.set("proxy_protocol", True)
request = {
"method": "GET",
"uri": uri("/proxy/v2/ipv6"),
"version": (1, 1),
"headers": [
("HOST", "example.com"),
("CONTENT-LENGTH", "0")
],
"body": b""
}

View File

@ -8,6 +8,7 @@ Tests for ASGI worker components.
import asyncio import asyncio
import io import io
import ipaddress
import pytest import pytest
from unittest import mock from unittest import mock
@ -48,9 +49,9 @@ class MockConfig:
def __init__(self): def __init__(self):
self.is_ssl = False self.is_ssl = False
self.proxy_protocol = False self.proxy_protocol = "off"
self.proxy_allow_ips = ["127.0.0.1"] self.proxy_allow_ips = [ipaddress.ip_network("127.0.0.1")]
self.forwarded_allow_ips = ["127.0.0.1"] self.forwarded_allow_ips = [ipaddress.ip_network("127.0.0.1")]
self.secure_scheme_headers = {} self.secure_scheme_headers = {}
self.forwarder_headers = [] self.forwarder_headers = []
self.limit_request_line = 8190 self.limit_request_line = 8190

View File

@ -1385,7 +1385,7 @@ class TestKeepaliveBlockingMode:
conn.parser = mock_parser conn.parser = mock_parser
# Mock handle_request to invoke wsgi # Mock handle_request to invoke wsgi
original_handle_request = worker.handle_request _ = worker.handle_request # save reference before overwriting
def mock_handle_request(req, conn): def mock_handle_request(req, conn):
# Simplified version that just calls wsgi # Simplified version that just calls wsgi

View File

@ -39,6 +39,27 @@ def load_py(fname):
return vars(mod) return vars(mod)
def decode_hex_escapes(data):
"""Decode hex escape sequences like \\xAB in test data."""
import re
result = bytearray()
i = 0
while i < len(data):
# Check for \xHH hex escape
if i + 3 < len(data) and data[i:i+2] == b'\\x':
hex_chars = data[i+2:i+4]
try:
byte_val = int(hex_chars, 16)
result.append(byte_val)
i += 4
continue
except ValueError:
pass
result.append(data[i])
i += 1
return bytes(result)
class request: class request:
def __init__(self, fname, expect): def __init__(self, fname, expect):
self.fname = fname self.fname = fname
@ -52,8 +73,10 @@ class request:
self.data = handle.read() self.data = handle.read()
self.data = self.data.replace(b"\n", b"").replace(b"\\r\\n", b"\r\n") self.data = self.data.replace(b"\n", b"").replace(b"\\r\\n", b"\r\n")
self.data = self.data.replace(b"\\0", b"\000").replace(b"\\n", b"\n").replace(b"\\t", b"\t") self.data = self.data.replace(b"\\0", b"\000").replace(b"\\n", b"\n").replace(b"\\t", b"\t")
# Handle hex escape sequences for binary data (e.g., \x0D for PROXY v2)
self.data = decode_hex_escapes(self.data)
if b"\\" in self.data: if b"\\" in self.data:
raise AssertionError("Unexpected backslash in test data - only handling HTAB, NUL and CRLF") raise AssertionError("Unexpected backslash in test data - only handling HTAB, NUL, CRLF, and hex escapes")
# Functions for sending data to the parser. # Functions for sending data to the parser.
# These functions mock out reading from a # These functions mock out reading from a