fix: address six WSGI/ASGI parser and protocol findings

- WSGI fast parser now applies the same per-header policy as the Python
  parser (Expect, secure_scheme_headers, forwarded_allow_ips trust gate,
  forwarder_headers / header_map). Shared helpers extracted on Message.

- ASGI keepalive no longer resets the parser when the previous request
  body was not fully framed; the connection closes instead, preventing
  request smuggling on pipelined connections.

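- On timeout, BodyReceiver._wait_for_data now sets _closed so that
  receive() yields http.disconnect rather than synthesizing a terminal
  http.request with more_body=False. The wait is bounded by cfg.timeout
  (falling back to 30s when it is unset or non-positive).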

- ASGI chunked encoding now skips HEAD, 204, and 304 (matches
  Response.is_chunked in the WSGI path) via a small helper.

- _setup_callback_parser passes proxy_protocol to PythonProtocol; the
  "auto" parser setting falls back to the Python parser when
  proxy_protocol != off (the C parser does not implement PROXY
  framing). _effective_peername replaces the transport peer with the
  PROXY-supplied client address.

- Parser.finish_body accepts a deadline and a 64KiB byte cap; gthread
  passes a deadline and abandons keepalive on incomplete drain so a
  stalled client cannot tie up a worker thread.
This commit is contained in:
Benoit Chesneau 2026-05-03 18:19:08 +02:00
parent 4bcda32a78
commit e90b1c2c1e
10 changed files with 642 additions and 104 deletions

View File

@ -254,11 +254,16 @@ class BodyReceiver:
if self._chunks: if self._chunks:
return self._pop_chunk() return self._pop_chunk()
# Complete OR timeout - mark body finished to prevent infinite loops if self._complete:
# Apps should not loop forever waiting for body that won't arrive
self._body_finished = True self._body_finished = True
return {"type": "http.request", "body": b"", "more_body": False} return {"type": "http.request", "body": b"", "more_body": False}
# Wait returned without data and the message was not framed complete:
# treat as a client disconnect rather than synthesizing end-of-body
# (which would desync the next pipelined request).
self._closed = True
return {"type": "http.disconnect"}
async def _wait_for_data(self): async def _wait_for_data(self):
"""Wait for body data to arrive via callback.""" """Wait for body data to arrive via callback."""
if self._chunks or self._complete or self._closed: if self._chunks or self._complete or self._closed:
@ -268,11 +273,21 @@ class BodyReceiver:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
self._waiter = loop.create_future() self._waiter = loop.create_future()
# Bound the wait by the configured worker timeout (default 30s).
# The protocol-level timeout drives transport disconnect handling;
# this only needs to escape an idle wait if data never arrives.
cfg = getattr(self.protocol, 'cfg', None)
timeout = getattr(cfg, 'timeout', None) if cfg is not None else None
if not timeout or timeout <= 0:
timeout = 30.0
try: try:
# Wait with timeout for data or completion await asyncio.wait_for(self._waiter, timeout=timeout)
await asyncio.wait_for(self._waiter, timeout=30.0)
except asyncio.TimeoutError: except asyncio.TimeoutError:
pass # No data arrived in time: mark the body receiver as disconnected
# so receive() yields http.disconnect rather than a fake terminal
# http.request with more_body=False.
self._closed = True
finally: finally:
self._waiter = None self._waiter = None
@ -436,6 +451,10 @@ class ASGIProtocol(asyncio.Protocol):
for flag in self._FAST_PARSER_INCOMPATIBLE_FLAGS: for flag in self._FAST_PARSER_INCOMPATIBLE_FLAGS:
if getattr(self.cfg, flag, False): if getattr(self.cfg, flag, False):
incompatible.append(flag) incompatible.append(flag)
# PROXY protocol framing is implemented only in PythonProtocol; the C parser
# has no proxy_protocol kwarg and would silently drop the framing.
if getattr(self.cfg, 'proxy_protocol', 'off') != 'off':
incompatible.append('proxy_protocol')
if parser_setting == 'python': if parser_setting == 'python':
parser_class = PythonProtocol parser_class = PythonProtocol
@ -467,8 +486,9 @@ class ASGIProtocol(asyncio.Protocol):
if limit_request_line == 0 and parser_class != PythonProtocol: if limit_request_line == 0 and parser_class != PythonProtocol:
limit_request_line = 1024 * 1024 # 1MB for C parser limit_request_line = 1024 * 1024 # 1MB for C parser
# Create parser with callbacks and limit parameters (both parsers support them) # Create parser with callbacks and limit parameters (both parsers support them).
self._callback_parser = parser_class( # Only the Python parser implements PROXY protocol framing; pass the option there.
parser_kwargs = dict(
on_headers_complete=self._on_headers_complete, on_headers_complete=self._on_headers_complete,
on_body=self._on_body, on_body=self._on_body,
on_message_complete=self._on_message_complete, on_message_complete=self._on_message_complete,
@ -478,6 +498,9 @@ class ASGIProtocol(asyncio.Protocol):
permit_unconventional_http_method=self.cfg.permit_unconventional_http_method, permit_unconventional_http_method=self.cfg.permit_unconventional_http_method,
permit_unconventional_http_version=self.cfg.permit_unconventional_http_version, permit_unconventional_http_version=self.cfg.permit_unconventional_http_version,
) )
if parser_class is PythonProtocol:
parser_kwargs['proxy_protocol'] = getattr(self.cfg, 'proxy_protocol', 'off')
self._callback_parser = parser_class(**parser_kwargs)
def _on_headers_complete(self): def _on_headers_complete(self):
"""Callback: request headers are complete.""" """Callback: request headers are complete."""
@ -729,14 +752,17 @@ class ASGIProtocol(asyncio.Protocol):
request = self._current_request request = self._current_request
# If PROXY protocol provided a real client address, use it.
effective_peer = self._effective_peername(peername)
# Check for WebSocket upgrade # Check for WebSocket upgrade
if self._is_websocket_upgrade(request): if self._is_websocket_upgrade(request):
await self._handle_websocket(request, sockname, peername) await self._handle_websocket(request, sockname, effective_peer)
break # WebSocket takes over the connection break # WebSocket takes over the connection
# Handle HTTP request # Handle HTTP request
keepalive = await self._handle_http_request( keepalive = await self._handle_http_request(
request, sockname, peername request, sockname, effective_peer
) )
# Increment worker request count # Increment worker request count
@ -755,6 +781,18 @@ class ASGIProtocol(asyncio.Protocol):
if not self.cfg.keepalive: if not self.cfg.keepalive:
break break
# Refuse keepalive if the previous request body was not fully
# framed: residual bytes left in the transport stream would be
# parsed as the start of the next request (smuggling).
receiver = self._body_receiver
message_complete = (
receiver is None
or receiver._complete
or receiver._closed
)
if not message_complete:
break
# Resume reading if paused during body consumption # Resume reading if paused during body consumption
self._resume_reading() self._resume_reading()
@ -918,15 +956,15 @@ class ASGIProtocol(asyncio.Protocol):
has_transfer_encoding = True has_transfer_encoding = True
use_chunked = True # Framework already set chunked encoding use_chunked = True # Framework already set chunked encoding
# Use chunked encoding for HTTP/1.1 streaming responses without Content-Length # Use chunked encoding for HTTP/1.1 streaming responses without Content-Length.
# Skip for 1xx informational responses (RFC 9110) # Skip when the response cannot carry a body (HEAD/1xx/204/304) or when
# Skip if Transfer-Encoding already set by framework # Transfer-Encoding was already set by the framework.
is_informational = 100 <= response_status < 200 is_no_body = self._response_omits_body(request.method, response_status)
needs_chunked = ( needs_chunked = (
not has_content_length not has_content_length
and not has_transfer_encoding and not has_transfer_encoding
and request.version >= (1, 1) and request.version >= (1, 1)
and not is_informational and not is_no_body
) )
if needs_chunked: if needs_chunked:
use_chunked = True use_chunked = True
@ -1201,6 +1239,36 @@ class ASGIProtocol(asyncio.Protocol):
# Buffer headers for batching with first body chunk # Buffer headers for batching with first body chunk
self._response_buffer = b"".join(parts) self._response_buffer = b"".join(parts)
def _effective_peername(self, peername):
"""Return the client address advertised via PROXY protocol if any.
Falls back to the transport peername when PROXY protocol is disabled,
the framing was absent, or the parser is the C variant (which currently
does not surface PROXY metadata).
"""
parser = self._callback_parser
info = getattr(parser, 'proxy_protocol_info', None) if parser else None
if not info:
return peername
client_addr = info.get('client_addr')
client_port = info.get('client_port')
if client_addr is None or client_port is None:
return peername
return (client_addr, client_port)
@staticmethod
def _response_omits_body(method, status):
"""Return True when the response MUST NOT have a body (RFC 9110).
Applies to HEAD requests and to status codes that semantically carry no body:
1xx informational, 204 No Content, 304 Not Modified.
"""
return (
method == "HEAD"
or status in (204, 304)
or 100 <= status < 200
)
def _send_body(self, body, chunked=False): def _send_body(self, body, chunked=False):
"""Send response body chunk. """Send response body chunk.

View File

@ -206,26 +206,94 @@ class Message:
def parse(self, unreader): def parse(self, unreader):
raise NotImplementedError() raise NotImplementedError()
def parse_headers(self, data, from_trailer=False): def _peer_trusted_for_forwarded(self):
"""Return the (secure_scheme_headers, forwarder_headers) the peer is allowed to set.
When the peer's address is not in ``forwarded_allow_ips`` (or networks),
configured forwarding/secure-scheme policy must be ignored to prevent
spoofing. Returns ``({}, [])`` when the peer is untrusted.
"""
cfg = self.cfg cfg = self.cfg
if (not isinstance(self.peer_addr, tuple)
or _ip_in_allow_list(self.peer_addr[0], cfg.forwarded_allow_ips,
cfg.forwarded_allow_networks())):
return cfg.secure_scheme_headers, cfg.forwarder_headers
return {}, []
def _apply_header_policy(self, name, value, scheme_state,
secure_scheme_headers, forwarder_headers,
from_trailer=False):
"""Apply per-header policy shared between Python and fast parsers.
Mutates ``self._expected_100_continue`` and ``self.scheme`` as needed.
``scheme_state`` is a single-element list used as a mutable sentinel
so the caller can detect repeated scheme headers.
Returns the (name, value) pair to retain, or ``None`` to drop the
header (per ``header_map='drop'``). Raises the same exceptions the
Python path raises so behavior is identical regardless of parser.
"""
if not from_trailer and name == "EXPECT":
# https://datatracker.ietf.org/doc/html/rfc9110#section-10.1.1
# "The Expect field value is case-insensitive."
if value.lower() == "100-continue":
if self.version < (1, 1):
# https://datatracker.ietf.org/doc/html/rfc9110#section-10.1.1-12
# "A server that receives a 100-continue expectation
# in an HTTP/1.0 request MUST ignore that expectation."
pass
else:
self._expected_100_continue = True
# N.B. understood but ignored expect header does not return 417
else:
raise ExpectationFailed(value)
if name in secure_scheme_headers:
secure = value == secure_scheme_headers[name]
scheme = "https" if secure else "http"
if scheme_state[0]:
if scheme != self.scheme:
raise InvalidSchemeHeaders()
else:
scheme_state[0] = True
self.scheme = scheme
# ambiguous mapping allows fooling downstream, e.g. merging non-identical headers:
# X-Forwarded-For: 2001:db8::ha:cc:ed
# X_Forwarded_For: 127.0.0.1,::1
# HTTP_X_FORWARDED_FOR = 2001:db8::ha:cc:ed,127.0.0.1,::1
# Only modify after fixing *ALL* header transformations; network to wsgi env
if "_" in name:
if name in forwarder_headers or "*" in forwarder_headers:
# This forwarder may override our environment
pass
elif self.cfg.header_map == "dangerous":
# as if we did not know we cannot safely map this
pass
elif self.cfg.header_map == "drop":
# almost as if it never had been there
# but still counts against resource limits
return None
else:
# fail-safe fallthrough: refuse
raise InvalidHeaderName(name)
return (name, value)
def parse_headers(self, data, from_trailer=False):
headers = [] headers = []
# Split lines on \r\n # Split lines on \r\n
lines = [bytes_to_str(line) for line in data.split(b"\r\n")] lines = [bytes_to_str(line) for line in data.split(b"\r\n")]
# handle scheme headers # handle scheme headers
scheme_header = False scheme_state = [False]
secure_scheme_headers = {}
forwarder_headers = []
if from_trailer: if from_trailer:
# nonsense. either a request is https from the beginning # nonsense. either a request is https from the beginning
# .. or we are just behind a proxy who does not remove conflicting trailers # .. or we are just behind a proxy who does not remove conflicting trailers
pass secure_scheme_headers, forwarder_headers = {}, []
elif (not isinstance(self.peer_addr, tuple) else:
or _ip_in_allow_list(self.peer_addr[0], cfg.forwarded_allow_ips, secure_scheme_headers, forwarder_headers = self._peer_trusted_for_forwarded()
cfg.forwarded_allow_networks())):
secure_scheme_headers = cfg.secure_scheme_headers
forwarder_headers = cfg.forwarder_headers
# Parse headers into key/value pairs paying attention # Parse headers into key/value pairs paying attention
# to continuation lines. # to continuation lines.
@ -275,52 +343,14 @@ class Message:
if header_length > self.limit_request_field_size > 0: if header_length > self.limit_request_field_size > 0:
raise LimitRequestHeaders("limit request headers fields size") raise LimitRequestHeaders("limit request headers fields size")
if not from_trailer and name == "EXPECT": kept = self._apply_header_policy(
# https://datatracker.ietf.org/doc/html/rfc9110#section-10.1.1 name, value, scheme_state,
# "The Expect field value is case-insensitive." secure_scheme_headers, forwarder_headers,
if value.lower() == "100-continue": from_trailer=from_trailer,
if self.version < (1, 1): )
# https://datatracker.ietf.org/doc/html/rfc9110#section-10.1.1-12 if kept is None:
# "A server that receives a 100-continue expectation
# in an HTTP/1.0 request MUST ignore that expectation."
pass
else:
self._expected_100_continue = True
# N.B. understood but ignored expect header does not return 417
else:
raise ExpectationFailed(value)
if name in secure_scheme_headers:
secure = value == secure_scheme_headers[name]
scheme = "https" if secure else "http"
if scheme_header:
if scheme != self.scheme:
raise InvalidSchemeHeaders()
else:
scheme_header = True
self.scheme = scheme
# ambiguous mapping allows fooling downstream, e.g. merging non-identical headers:
# X-Forwarded-For: 2001:db8::ha:cc:ed
# X_Forwarded_For: 127.0.0.1,::1
# HTTP_X_FORWARDED_FOR = 2001:db8::ha:cc:ed,127.0.0.1,::1
# Only modify after fixing *ALL* header transformations; network to wsgi env
if "_" in name:
if name in forwarder_headers or "*" in forwarder_headers:
# This forwarder may override our environment
pass
elif self.cfg.header_map == "dangerous":
# as if we did not know we cannot safely map this
pass
elif self.cfg.header_map == "drop":
# almost as if it never had been there
# but still counts against resource limits
continue continue
else: headers.append(kept)
# fail-safe fallthrough: refuse
raise InvalidHeaderName(name)
headers.append((name, value))
return headers return headers
@ -516,25 +546,23 @@ class Request(Message):
# Headers - convert bytes to strings with uppercase names # Headers - convert bytes to strings with uppercase names
# gunicorn_h1c returns headers as (bytes, bytes) tuples # gunicorn_h1c returns headers as (bytes, bytes) tuples
# Header name/value validation done by C parser # Header name/value validation done by C parser; policy (Expect,
# secure_scheme_headers, forwarder trust gate, header_map) is enforced
# below so the fast path mirrors parse_headers().
self.headers = [] self.headers = []
scheme_state = [False]
secure_scheme_headers, forwarder_headers = self._peer_trusted_for_forwarded()
for name_bytes, value_bytes in result['headers']: for name_bytes, value_bytes in result['headers']:
name = bytes_to_str(name_bytes).upper() name = bytes_to_str(name_bytes).upper()
value = bytes_to_str(value_bytes) value = bytes_to_str(value_bytes)
# Handle underscore in header names (policy decision, not validation) kept = self._apply_header_policy(
if "_" in name: name, value, scheme_state,
forwarder_headers = self.cfg.forwarder_headers secure_scheme_headers, forwarder_headers,
if name in forwarder_headers or "*" in forwarder_headers: )
pass if kept is None:
elif self.cfg.header_map == "dangerous":
pass
elif self.cfg.header_map == "drop":
continue continue
else: self.headers.append(kept)
raise InvalidHeaderName(name)
self.headers.append((name, value))
# Return remaining data after headers # Return remaining data after headers
consumed = result['consumed'] consumed = result['consumed']

View File

@ -2,12 +2,20 @@
# This file is part of gunicorn released under the MIT license. # This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information. # See the NOTICE for more information.
import socket
import ssl import ssl
import time
from gunicorn.http.message import Request from gunicorn.http.message import Request
from gunicorn.http.unreader import SocketUnreader, IterUnreader from gunicorn.http.unreader import SocketUnreader, IterUnreader
# Cap on bytes drained from an unconsumed request body before a keepalive
# reset. Defends against a slow-but-steady client that stays under a per-read
# deadline yet streams indefinitely.
_DRAIN_MAX_BYTES = 64 * 1024
class Parser: class Parser:
mesg_class = None mesg_class = None
@ -27,20 +35,60 @@ class Parser:
def __iter__(self): def __iter__(self):
return self return self
def finish_body(self): def finish_body(self, deadline=None, max_bytes=_DRAIN_MAX_BYTES):
"""Discard any unread body of the current message. """Discard any unread body of the current message.
This should be called before returning a keepalive connection to Called before returning a keepalive connection to the poller so the
the poller to ensure the socket doesn't appear readable due to socket does not appear readable due to leftover body bytes.
leftover body bytes.
``deadline`` is an absolute ``time.monotonic()`` value; when set the
socket read timeout is bounded by the remaining time before each read.
``max_bytes`` caps the total drained bytes to defend against a slow
client that keeps trickling under the deadline.
Returns ``True`` when the body was fully drained, ``False`` when the
drain was abandoned (deadline, byte cap, or socket timeout). Callers
that observe ``False`` MUST close the connection rather than serve
another request on it.
""" """
if self.mesg: if not self.mesg:
return True
sock = getattr(self.unreader, "sock", None)
# gettimeout/settimeout only matter when bounding a real socket; a
# mock or non-socket source skips the timeout plumbing.
if sock is not None and hasattr(sock, "gettimeout") and hasattr(sock, "settimeout"):
timeoutable_sock = sock
prior_timeout = sock.gettimeout()
else:
timeoutable_sock = None
prior_timeout = None
drained = 0
try:
while True:
if deadline is not None and timeoutable_sock is not None:
remaining = deadline - time.monotonic()
if remaining <= 0:
return False
timeoutable_sock.settimeout(remaining)
try: try:
data = self.mesg.body.read(1024) data = self.mesg.body.read(1024)
while data: except (socket.timeout, TimeoutError):
data = self.mesg.body.read(1024) return False
except ssl.SSLWantReadError: except ssl.SSLWantReadError:
# SSL socket has no more application data available # SSL socket has no more application data available
return True
if not data:
return True
drained += len(data)
if drained >= max_bytes:
return False
finally:
if timeoutable_sock is not None:
try:
timeoutable_sock.settimeout(prior_timeout)
except OSError:
pass pass
def __next__(self): def __next__(self):

View File

@ -478,9 +478,14 @@ class ThreadWorker(base.Worker):
# Handle the request # Handle the request
keepalive = self.handle_request(req, conn) keepalive = self.handle_request(req, conn)
if keepalive: if keepalive:
# Discard any unread request body before keepalive # Discard any unread request body before keepalive to prevent
# to prevent socket appearing readable due to leftover bytes # the socket from appearing readable due to leftover bytes.
conn.parser.finish_body() # Bound the drain by the worker data timeout: a stalled client
# must not keep this thread blocked.
drain_deadline = time.monotonic() + DEFAULT_WORKER_DATA_TIMEOUT
if not conn.parser.finish_body(deadline=drain_deadline):
# Abandon keepalive when the body could not be fully drained.
return False
return True return True
except http.errors.NoMoreData as e: except http.errors.NoMoreData as e:
self.log.debug("Ignored premature client disconnection. %s", e) self.log.debug("Ignored premature client disconnection. %s", e)

View File

@ -204,3 +204,65 @@ class TestASGIDisconnectGracePeriod:
from gunicorn.config import Config from gunicorn.config import Config
cfg = Config() cfg = Config()
assert cfg.asgi_disconnect_grace_period == 3 assert cfg.asgi_disconnect_grace_period == 3
class TestBodyReceiverIncompleteBody:
    """Cover the receive() path when the request body never finishes framing."""

    @pytest.fixture
    def mock_worker(self):
        """Yield a mock worker whose cfg.timeout is tiny.

        The event loop created for ``worker.loop`` is closed on teardown so
        the fixture does not leak an unclosed loop per test.
        """
        loop = asyncio.new_event_loop()
        worker = mock.Mock()
        worker.nr_conns = 0
        worker.loop = loop
        worker.cfg = mock.Mock()
        worker.cfg.asgi_disconnect_grace_period = 3
        worker.cfg.timeout = 0.05  # tight bound for the test
        worker.log = mock.Mock()
        try:
            yield worker
        finally:
            loop.close()

    @pytest.mark.asyncio
    async def test_receive_yields_disconnect_on_timeout(self, mock_worker):
        """When _wait_for_data times out and the body is not complete, the
        receiver MUST yield http.disconnect rather than synthesize a terminal
        http.request with more_body=False that would desync the next
        pipelined request."""
        from gunicorn.asgi.protocol import ASGIProtocol, BodyReceiver

        protocol = ASGIProtocol(mock_worker)
        protocol.reader = mock.Mock()

        # Announce 100 body bytes but never feed any of them.
        request = mock.Mock()
        request.content_length = 100
        request.chunked = False

        receiver = BodyReceiver(request, protocol)
        protocol._body_receiver = receiver

        msg = await receiver.receive()
        assert msg == {"type": "http.disconnect"}
        assert receiver._closed is True

    @pytest.mark.asyncio
    async def test_receive_yields_terminal_request_when_complete(self, mock_worker):
        """If the body is framed complete, the existing terminal http.request
        with more_body=False must still be returned."""
        from gunicorn.asgi.protocol import ASGIProtocol, BodyReceiver

        protocol = ASGIProtocol(mock_worker)
        protocol.reader = mock.Mock()

        request = mock.Mock()
        request.content_length = 5
        request.chunked = False

        receiver = BodyReceiver(request, protocol)
        protocol._body_receiver = receiver

        # Deliver the whole body and mark the message complete up front.
        receiver.feed(b"hello")
        receiver.set_complete()

        msg = await receiver.receive()
        assert msg["type"] == "http.request"
        assert msg["body"] == b"hello"
        # The body is complete, so this must be the terminal chunk.
        assert msg["more_body"] is False

View File

@ -344,6 +344,43 @@ class TestHTTPVersionForChunked:
assert uses_http1_chunked is False assert uses_http1_chunked is False
# ============================================================================
# No-Body Response Tests (RFC 9110)
# ============================================================================
class TestResponseOmitsBody:
    """Check which (method, status) pairs ASGIProtocol reports as bodyless."""

    def _omits(self, method, status):
        """Call ASGIProtocol._response_omits_body for the given pair."""
        from gunicorn.asgi.protocol import ASGIProtocol
        return ASGIProtocol._response_omits_body(method, status)

    def test_head_omits_body(self):
        for status in (200, 500):
            assert self._omits("HEAD", status) is True

    def test_204_omits_body(self):
        for method in ("GET", "POST"):
            assert self._omits(method, 204) is True

    def test_304_omits_body(self):
        assert self._omits("GET", 304) is True

    def test_informational_omits_body(self):
        for status in (100, 103, 199):
            assert self._omits("GET", status) is True

    def test_get_200_has_body(self):
        assert self._omits("GET", 200) is False

    def test_post_200_has_body(self):
        assert self._omits("POST", 200) is False

    def test_404_has_body(self):
        assert self._omits("GET", 404) is False
# ============================================================================ # ============================================================================
# Streaming Response Message Sequence Tests # Streaming Response Message Sequence Tests
# ============================================================================ # ============================================================================

View File

@ -45,10 +45,6 @@ def test_asgi_parser(fname):
if getattr(cfg, flag, False): if getattr(cfg, flag, False):
pytest.skip(f"Callback parser incompatible with {flag}") pytest.skip(f"Callback parser incompatible with {flag}")
# Skip proxy protocol tests
if getattr(cfg, 'proxy_protocol', 'off') != 'off':
pytest.skip("Callback parser does not support proxy_protocol")
req = treq_asgi.request(fname, expect) req = treq_asgi.request(fname, expect)
# Test with different sending strategies # Test with different sending strategies

View File

@ -582,6 +582,59 @@ class TestASGIProtocol:
assert scope["root_path"] == "/api" assert scope["root_path"] == "/api"
assert scope["http_version"] == "1.1" assert scope["http_version"] == "1.1"
def test_effective_peername_no_proxy(self):
    """With no PROXY info on the parser, the transport peername is kept."""
    from gunicorn.asgi.protocol import ASGIProtocol

    fake_worker = mock.Mock()
    fake_worker.cfg = Config()
    fake_worker.log = mock.Mock()
    fake_worker.asgi = mock.Mock()

    proto = ASGIProtocol(fake_worker)
    proto._callback_parser = mock.Mock(proxy_protocol_info=None)

    transport_peer = ("10.0.0.1", 12345)
    assert proto._effective_peername(transport_peer) == transport_peer
def test_effective_peername_with_proxy(self):
    """A client address from PROXY info replaces the transport peername."""
    from gunicorn.asgi.protocol import ASGIProtocol

    fake_worker = mock.Mock()
    fake_worker.cfg = Config()
    fake_worker.log = mock.Mock()
    fake_worker.asgi = mock.Mock()

    proxy_info = {
        'proxy_protocol': 'TCP4',
        'client_addr': '203.0.113.5',
        'client_port': 56324,
        'proxy_addr': '10.0.0.2',
        'proxy_port': 443,
    }
    proto = ASGIProtocol(fake_worker)
    proto._callback_parser = mock.Mock(proxy_protocol_info=proxy_info)

    resolved = proto._effective_peername(("10.0.0.1", 1))
    assert resolved == ("203.0.113.5", 56324)
def test_effective_peername_unknown_proxy(self):
    """UNKNOWN PROXY info carries no client address, so the transport
    peername is used."""
    from gunicorn.asgi.protocol import ASGIProtocol

    fake_worker = mock.Mock()
    fake_worker.cfg = Config()
    fake_worker.log = mock.Mock()
    fake_worker.asgi = mock.Mock()

    proxy_info = {
        'proxy_protocol': 'UNKNOWN',
        'client_addr': None,
        'client_port': None,
        'proxy_addr': None,
        'proxy_port': None,
    }
    proto = ASGIProtocol(fake_worker)
    proto._callback_parser = mock.Mock(proxy_protocol_info=proxy_info)

    transport_peer = ("10.0.0.1", 12345)
    assert proto._effective_peername(transport_peer) == transport_peer
# ============================================================================ # ============================================================================
# Config Tests # Config Tests

View File

@ -0,0 +1,185 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""Parity tests for WSGI header policy across Python and fast parsers.
These checks ensure that Expect, secure_scheme_headers, forwarder_headers,
and the forwarded_allow_ips trust gate are enforced identically regardless
of the parser implementation selected by ``http_parser``.
"""
import sys
import pytest
from gunicorn.config import Config
from gunicorn.http.parser import RequestParser
from gunicorn.http.errors import (
ExpectationFailed,
InvalidHeaderName,
InvalidSchemeHeaders,
)
def _parse(raw, cfg, peer_addr):
    """Feed ``raw`` to a RequestParser as one chunk; return the first message."""
    messages = iter(RequestParser(cfg, iter([raw]), peer_addr))
    return next(messages)
def _cfg(http_parser, **overrides):
    """Build a Config with ``http_parser`` set first, then each override."""
    config = Config()
    settings = {"http_parser": http_parser, **overrides}
    for key, setting in settings.items():
        config.set(key, setting)
    return config
@pytest.fixture(params=["python", "fast"])
def parser_name(request):
    """Parametrize a test over both parsers, skipping "fast" when unusable.

    The "fast" variant is skipped on PyPy, when ``gunicorn_h1c`` cannot be
    imported, or when its ``H1CProtocol`` has no ``asgi_headers`` attribute.
    """
    selected = request.param
    if selected != "fast":
        return selected

    if hasattr(sys, "pypy_version_info"):
        pytest.skip("gunicorn_h1c not supported on PyPy")
    h1c = pytest.importorskip("gunicorn_h1c")
    if not hasattr(h1c.H1CProtocol, "asgi_headers"):
        pytest.skip("gunicorn_h1c >= 0.6.2 required")
    return selected
class TestExpectPolicy:
    """Expect header handling must not depend on the selected parser."""

    PEER = ("127.0.0.1", 1234)

    def test_expect_100_continue_sets_flag(self, parser_name):
        request_bytes = (
            b"POST / HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"Content-Length: 0\r\n"
            b"Expect: 100-continue\r\n"
            b"\r\n"
        )
        parsed = _parse(request_bytes, _cfg(parser_name), self.PEER)
        assert parsed._expected_100_continue is True

    def test_expect_unknown_value_rejected(self, parser_name):
        request_bytes = (
            b"POST / HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"Content-Length: 0\r\n"
            b"Expect: bogus-extension\r\n"
            b"\r\n"
        )
        config = _cfg(parser_name)
        with pytest.raises(ExpectationFailed):
            _parse(request_bytes, config, self.PEER)

    def test_expect_ignored_in_http10(self, parser_name):
        # An HTTP/1.0 request must not arm the 100-continue flag.
        request_bytes = (
            b"POST / HTTP/1.0\r\n"
            b"Host: example.com\r\n"
            b"Content-Length: 0\r\n"
            b"Expect: 100-continue\r\n"
            b"\r\n"
        )
        parsed = _parse(request_bytes, _cfg(parser_name), self.PEER)
        assert parsed._expected_100_continue is False
class TestSecureSchemeHeaders:
    """secure_scheme_headers only take effect for peers in forwarded_allow_ips."""

    TRUSTED_PEER = ("127.0.0.1", 1234)
    UNTRUSTED_PEER = ("203.0.113.5", 1234)
    PROTO_HTTPS_REQUEST = (
        b"GET / HTTP/1.1\r\n"
        b"Host: example.com\r\n"
        b"X-Forwarded-Proto: https\r\n"
        b"\r\n"
    )

    def test_trusted_peer_promotes_https(self, parser_name):
        config = _cfg(
            parser_name,
            forwarded_allow_ips="127.0.0.1",
            secure_scheme_headers={"X-FORWARDED-PROTO": "https"},
        )
        parsed = _parse(self.PROTO_HTTPS_REQUEST, config, self.TRUSTED_PEER)
        assert parsed.scheme == "https"

    def test_untrusted_peer_keeps_http(self, parser_name):
        config = _cfg(
            parser_name,
            forwarded_allow_ips="127.0.0.1",
            secure_scheme_headers={"X-FORWARDED-PROTO": "https"},
        )
        parsed = _parse(self.PROTO_HTTPS_REQUEST, config, self.UNTRUSTED_PEER)
        assert parsed.scheme == "http"

    def test_conflicting_scheme_headers_rejected(self, parser_name):
        config = _cfg(
            parser_name,
            forwarded_allow_ips="127.0.0.1",
            secure_scheme_headers={
                "X-FORWARDED-PROTO": "https",
                "X-FORWARDED-SSL": "on",
            },
        )
        # One header says https, the other says not-https.
        request_bytes = (
            b"GET / HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"X-Forwarded-Proto: https\r\n"
            b"X-Forwarded-Ssl: off\r\n"
            b"\r\n"
        )
        with pytest.raises(InvalidSchemeHeaders):
            _parse(request_bytes, config, self.TRUSTED_PEER)
class TestForwarderTrustGate:
    """Underscore header handling depends on peer trust and header_map."""

    TRUSTED_PEER = ("127.0.0.1", 1234)
    UNTRUSTED_PEER = ("203.0.113.5", 1234)

    @staticmethod
    def _header_names(parsed):
        """Collect the set of header names on a parsed request."""
        return {header_name for header_name, _ in parsed.headers}

    def test_untrusted_peer_underscore_header_rejected(self, parser_name):
        config = _cfg(
            parser_name,
            forwarded_allow_ips="127.0.0.1",
            forwarder_headers="SCRIPT_NAME",
            header_map="refuse",
        )
        request_bytes = (
            b"GET / HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"Script_Name: /evil\r\n"
            b"\r\n"
        )
        with pytest.raises(InvalidHeaderName):
            _parse(request_bytes, config, self.UNTRUSTED_PEER)

    def test_trusted_peer_underscore_header_accepted(self, parser_name):
        config = _cfg(
            parser_name,
            forwarded_allow_ips="127.0.0.1",
            forwarder_headers="SCRIPT_NAME",
        )
        request_bytes = (
            b"GET / HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"Script_Name: /api\r\n"
            b"\r\n"
        )
        parsed = _parse(request_bytes, config, self.TRUSTED_PEER)
        assert "SCRIPT_NAME" in self._header_names(parsed)

    def test_header_map_drop_silences_underscore(self, parser_name):
        config = _cfg(
            parser_name,
            forwarded_allow_ips="127.0.0.1",
            header_map="drop",
        )
        request_bytes = (
            b"GET / HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"Stray_Name: x\r\n"
            b"\r\n"
        )
        parsed = _parse(request_bytes, config, self.UNTRUSTED_PEER)
        assert "STRAY_NAME" not in self._header_names(parsed)

View File

@ -255,6 +255,62 @@ def test_invalid_http_version_error():
assert str(InvalidHTTPVersion((2, 1))) == 'Invalid HTTP Version: (2, 1)' assert str(InvalidHTTPVersion((2, 1))) == 'Invalid HTTP Version: (2, 1)'
def _build_request_parser(payload):
    """Return a RequestParser over ``payload`` with its first message parsed.

    The parser uses a default Config, a single-chunk iterator as its source
    and no peer address. The first message is pulled from the parser before
    it is returned.
    """
    from gunicorn.config import Config
    from gunicorn.http.parser import RequestParser

    request_parser = RequestParser(Config(), iter([payload]), None)
    next(iter(request_parser))
    return request_parser
def test_finish_body_drains_remainder():
    """A short, fully available body is drained and reported as complete."""
    request_bytes = (
        b"POST / HTTP/1.1\r\n"
        b"Host: example.com\r\n"
        b"Content-Length: 5\r\n"
        b"\r\n"
        b"hello"
    )
    drained = _build_request_parser(request_bytes).finish_body()
    assert drained is True
def test_finish_body_returns_false_when_byte_cap_exceeded():
    """A body larger than ``max_bytes`` makes finish_body() report False."""
    oversized = b"x" * 4096
    request_bytes = (
        b"POST / HTTP/1.1\r\n"
        b"Host: example.com\r\n"
        b"Content-Length: %d\r\n\r\n%s" % (len(oversized), oversized)
    )
    request_parser = _build_request_parser(request_bytes)
    drained = request_parser.finish_body(max_bytes=512)
    assert drained is False
def test_finish_body_returns_false_on_expired_deadline():
    """An already-elapsed deadline makes finish_body() give up at once and
    restore the socket's prior timeout."""
    import time as _time
    from unittest import mock

    # Content-Length promises 100 bytes; only a fragment is supplied.
    request_bytes = (
        b"POST / HTTP/1.1\r\n"
        b"Host: example.com\r\n"
        b"Content-Length: 100\r\n"
        b"\r\n"
        b"only-partial"
    )
    request_parser = _build_request_parser(request_bytes)

    # The deadline branch only runs when the unreader has a sock exposing
    # gettimeout/settimeout, and IterUnreader has none; attach a stub.
    fake_sock = mock.Mock()
    fake_sock.gettimeout.return_value = None
    request_parser.unreader.sock = fake_sock

    # A deadline one second in the past must abandon the drain immediately.
    already_expired = _time.monotonic() - 1.0
    assert request_parser.finish_body(deadline=already_expired) is False
    fake_sock.settimeout.assert_called_with(None)
def test_file_wrapper_iterable(): def test_file_wrapper_iterable():
"""FileWrapper should support the iterator protocol per PEP 3333.""" """FileWrapper should support the iterator protocol per PEP 3333."""
filelike = io.BytesIO(b"hello world") filelike = io.BytesIO(b"hello world")