fix: address six WSGI/ASGI parser and protocol findings

- WSGI fast parser now applies the same per-header policy as the Python
  parser (Expect, secure_scheme_headers, forwarded_allow_ips trust gate,
  forwarder_headers / header_map). Shared helpers extracted on Message.

- ASGI keepalive no longer resets the parser when the previous request
  body was not fully framed; the connection closes instead, preventing
  request smuggling on pipelined connections.

- A BodyReceiver._wait_for_data timeout now sets _closed, so receive()
  yields http.disconnect rather than synthesizing more_body=False. The
  timeout honors cfg.timeout (falling back to 30s when unset or <= 0).

- ASGI chunked encoding now skips HEAD, 204, and 304 (matches
  Response.is_chunked in the WSGI path) via a small helper.

- _setup_callback_parser passes proxy_protocol to PythonProtocol; the
  "auto" parser setting falls back to the Python parser when
  proxy_protocol != off (the C parser does not implement PROXY framing).
  _effective_peername replaces the transport peer with the
  PROXY-supplied client address.

- Parser.finish_body accepts a deadline and a 64KiB byte cap; gthread
  passes a deadline and abandons keepalive on incomplete drain so a
  stalled client cannot tie up a worker thread.
This commit is contained in:
Benoit Chesneau 2026-05-03 18:19:08 +02:00
parent 4bcda32a78
commit e90b1c2c1e
10 changed files with 642 additions and 104 deletions

View File

@ -254,10 +254,15 @@ class BodyReceiver:
if self._chunks:
return self._pop_chunk()
# Complete OR timeout - mark body finished to prevent infinite loops
# Apps should not loop forever waiting for body that won't arrive
self._body_finished = True
return {"type": "http.request", "body": b"", "more_body": False}
if self._complete:
self._body_finished = True
return {"type": "http.request", "body": b"", "more_body": False}
# Wait returned without data and the message was not framed complete:
# treat as a client disconnect rather than synthesizing end-of-body
# (which would desync the next pipelined request).
self._closed = True
return {"type": "http.disconnect"}
async def _wait_for_data(self):
"""Wait for body data to arrive via callback."""
@ -268,11 +273,21 @@ class BodyReceiver:
loop = asyncio.get_event_loop()
self._waiter = loop.create_future()
# Bound the wait by the configured worker timeout (default 30s).
# The protocol-level timeout drives transport disconnect handling;
# this only needs to escape an idle wait if data never arrives.
cfg = getattr(self.protocol, 'cfg', None)
timeout = getattr(cfg, 'timeout', None) if cfg is not None else None
if not timeout or timeout <= 0:
timeout = 30.0
try:
# Wait with timeout for data or completion
await asyncio.wait_for(self._waiter, timeout=30.0)
await asyncio.wait_for(self._waiter, timeout=timeout)
except asyncio.TimeoutError:
pass
# No data arrived in time: mark the body receiver as disconnected
# so receive() yields http.disconnect rather than a fake terminal
# http.request with more_body=False.
self._closed = True
finally:
self._waiter = None
@ -436,6 +451,10 @@ class ASGIProtocol(asyncio.Protocol):
for flag in self._FAST_PARSER_INCOMPATIBLE_FLAGS:
if getattr(self.cfg, flag, False):
incompatible.append(flag)
# PROXY protocol framing is implemented only in PythonProtocol; the C parser
# has no proxy_protocol kwarg and would silently drop the framing.
if getattr(self.cfg, 'proxy_protocol', 'off') != 'off':
incompatible.append('proxy_protocol')
if parser_setting == 'python':
parser_class = PythonProtocol
@ -467,8 +486,9 @@ class ASGIProtocol(asyncio.Protocol):
if limit_request_line == 0 and parser_class != PythonProtocol:
limit_request_line = 1024 * 1024 # 1MB for C parser
# Create parser with callbacks and limit parameters (both parsers support them)
self._callback_parser = parser_class(
# Create parser with callbacks and limit parameters (both parsers support them).
# Only the Python parser implements PROXY protocol framing; pass the option there.
parser_kwargs = dict(
on_headers_complete=self._on_headers_complete,
on_body=self._on_body,
on_message_complete=self._on_message_complete,
@ -478,6 +498,9 @@ class ASGIProtocol(asyncio.Protocol):
permit_unconventional_http_method=self.cfg.permit_unconventional_http_method,
permit_unconventional_http_version=self.cfg.permit_unconventional_http_version,
)
if parser_class is PythonProtocol:
parser_kwargs['proxy_protocol'] = getattr(self.cfg, 'proxy_protocol', 'off')
self._callback_parser = parser_class(**parser_kwargs)
def _on_headers_complete(self):
"""Callback: request headers are complete."""
@ -729,14 +752,17 @@ class ASGIProtocol(asyncio.Protocol):
request = self._current_request
# If PROXY protocol provided a real client address, use it.
effective_peer = self._effective_peername(peername)
# Check for WebSocket upgrade
if self._is_websocket_upgrade(request):
await self._handle_websocket(request, sockname, peername)
await self._handle_websocket(request, sockname, effective_peer)
break # WebSocket takes over the connection
# Handle HTTP request
keepalive = await self._handle_http_request(
request, sockname, peername
request, sockname, effective_peer
)
# Increment worker request count
@ -755,6 +781,18 @@ class ASGIProtocol(asyncio.Protocol):
if not self.cfg.keepalive:
break
# Refuse keepalive if the previous request body was not fully
# framed: residual bytes left in the transport stream would be
# parsed as the start of the next request (smuggling).
receiver = self._body_receiver
message_complete = (
receiver is None
or receiver._complete
or receiver._closed
)
if not message_complete:
break
# Resume reading if paused during body consumption
self._resume_reading()
@ -918,15 +956,15 @@ class ASGIProtocol(asyncio.Protocol):
has_transfer_encoding = True
use_chunked = True # Framework already set chunked encoding
# Use chunked encoding for HTTP/1.1 streaming responses without Content-Length
# Skip for 1xx informational responses (RFC 9110)
# Skip if Transfer-Encoding already set by framework
is_informational = 100 <= response_status < 200
# Use chunked encoding for HTTP/1.1 streaming responses without Content-Length.
# Skip when the response cannot carry a body (HEAD/1xx/204/304) or when
# Transfer-Encoding was already set by the framework.
is_no_body = self._response_omits_body(request.method, response_status)
needs_chunked = (
not has_content_length
and not has_transfer_encoding
and request.version >= (1, 1)
and not is_informational
and not is_no_body
)
if needs_chunked:
use_chunked = True
@ -1201,6 +1239,36 @@ class ASGIProtocol(asyncio.Protocol):
# Buffer headers for batching with first body chunk
self._response_buffer = b"".join(parts)
def _effective_peername(self, peername):
    """Choose the client address to report for this connection.

    The PROXY-protocol client address wins when the callback parser
    exposes a ``proxy_protocol_info`` mapping whose ``client_addr`` and
    ``client_port`` are both set. In every other case (no parser, no
    info, or either field missing) the transport ``peername`` is returned
    unchanged.
    """
    callback_parser = self._callback_parser
    if not callback_parser:
        return peername

    proxy_info = getattr(callback_parser, 'proxy_protocol_info', None)
    if not proxy_info:
        return peername

    endpoint = (proxy_info.get('client_addr'), proxy_info.get('client_port'))
    # A partially filled entry (e.g. UNKNOWN framing) carries no usable
    # client address, so keep the transport peer.
    if any(part is None for part in endpoint):
        return peername
    return endpoint
@staticmethod
def _response_omits_body(method, status):
    """Tell whether a response to ``method`` with ``status`` is bodyless.

    True for HEAD requests and for the statuses 1xx, 204 and 304;
    False for everything else.
    """
    if method == "HEAD":
        return True
    if status in (204, 304):
        return True
    return 100 <= status < 200
def _send_body(self, body, chunked=False):
"""Send response body chunk.

View File

@ -206,26 +206,94 @@ class Message:
def parse(self, unreader):
    """Placeholder parse hook; this implementation always raises.

    Presumably overridden by concrete message classes (not visible here).

    Raises:
        NotImplementedError: unconditionally.
    """
    raise NotImplementedError()
def parse_headers(self, data, from_trailer=False):
def _peer_trusted_for_forwarded(self):
    """Resolve which forwarding policy applies to the current peer.

    Returns a ``(secure_scheme_headers, forwarder_headers)`` pair. The
    configured values are returned when ``peer_addr`` is not a tuple or
    when its host is covered by ``forwarded_allow_ips`` / the allowed
    networks. Otherwise ``({}, [])`` is returned so that an untrusted
    peer cannot spoof scheme or forwarder headers.
    """
    cfg = self.cfg
    peer = self.peer_addr
    untrusted = isinstance(peer, tuple) and not _ip_in_allow_list(
        peer[0], cfg.forwarded_allow_ips, cfg.forwarded_allow_networks())
    if untrusted:
        return {}, []
    return cfg.secure_scheme_headers, cfg.forwarder_headers
def _apply_header_policy(self, name, value, scheme_state,
                         secure_scheme_headers, forwarder_headers,
                         from_trailer=False):
    """Run the per-header policy shared by the Python and fast parsers.

    Side effects: may set ``self._expected_100_continue`` and
    ``self.scheme``. ``scheme_state`` is a one-element list whose first
    item records whether a scheme header has already been seen, which lets
    a later conflicting scheme header be detected.

    Returns the ``(name, value)`` pair to keep, or ``None`` when the
    header must be dropped (``header_map == "drop"``).

    Raises ``ExpectationFailed`` for an unsupported Expect value,
    ``InvalidSchemeHeaders`` for contradictory scheme headers, and
    ``InvalidHeaderName`` for a refused underscore header.
    """
    if name == "EXPECT" and not from_trailer:
        # https://datatracker.ietf.org/doc/html/rfc9110#section-10.1.1
        # "The Expect field value is case-insensitive."
        if value.lower() != "100-continue":
            raise ExpectationFailed(value)
        # https://datatracker.ietf.org/doc/html/rfc9110#section-10.1.1-12
        # "A server that receives a 100-continue expectation
        #  in an HTTP/1.0 request MUST ignore that expectation."
        # N.B. an understood-but-ignored Expect header does not return 417.
        if self.version >= (1, 1):
            self._expected_100_continue = True

    if name in secure_scheme_headers:
        scheme = "https" if value == secure_scheme_headers[name] else "http"
        if not scheme_state[0]:
            # First scheme header seen: it decides the request scheme.
            scheme_state[0] = True
            self.scheme = scheme
        elif scheme != self.scheme:
            raise InvalidSchemeHeaders()

    # ambiguous mapping allows fooling downstream, e.g. merging non-identical headers:
    # X-Forwarded-For: 2001:db8::ha:cc:ed
    # X_Forwarded_For: 127.0.0.1,::1
    # HTTP_X_FORWARDED_FOR = 2001:db8::ha:cc:ed,127.0.0.1,::1
    # Only modify after fixing *ALL* header transformations; network to wsgi env
    if "_" not in name:
        return (name, value)

    if name in forwarder_headers or "*" in forwarder_headers:
        # This forwarder may override our environment.
        return (name, value)

    header_map = self.cfg.header_map
    if header_map == "dangerous":
        # As if we did not know we cannot safely map this.
        return (name, value)
    if header_map == "drop":
        # Almost as if it had never been there, though it still counts
        # against resource limits.
        return None
    # Fail-safe fallthrough: refuse.
    raise InvalidHeaderName(name)
def parse_headers(self, data, from_trailer=False):
headers = []
# Split lines on \r\n
lines = [bytes_to_str(line) for line in data.split(b"\r\n")]
# handle scheme headers
scheme_header = False
secure_scheme_headers = {}
forwarder_headers = []
scheme_state = [False]
if from_trailer:
# nonsense. either a request is https from the beginning
# .. or we are just behind a proxy who does not remove conflicting trailers
pass
elif (not isinstance(self.peer_addr, tuple)
or _ip_in_allow_list(self.peer_addr[0], cfg.forwarded_allow_ips,
cfg.forwarded_allow_networks())):
secure_scheme_headers = cfg.secure_scheme_headers
forwarder_headers = cfg.forwarder_headers
secure_scheme_headers, forwarder_headers = {}, []
else:
secure_scheme_headers, forwarder_headers = self._peer_trusted_for_forwarded()
# Parse headers into key/value pairs paying attention
# to continuation lines.
@ -275,52 +343,14 @@ class Message:
if header_length > self.limit_request_field_size > 0:
raise LimitRequestHeaders("limit request headers fields size")
if not from_trailer and name == "EXPECT":
# https://datatracker.ietf.org/doc/html/rfc9110#section-10.1.1
# "The Expect field value is case-insensitive."
if value.lower() == "100-continue":
if self.version < (1, 1):
# https://datatracker.ietf.org/doc/html/rfc9110#section-10.1.1-12
# "A server that receives a 100-continue expectation
# in an HTTP/1.0 request MUST ignore that expectation."
pass
else:
self._expected_100_continue = True
# N.B. understood but ignored expect header does not return 417
else:
raise ExpectationFailed(value)
if name in secure_scheme_headers:
secure = value == secure_scheme_headers[name]
scheme = "https" if secure else "http"
if scheme_header:
if scheme != self.scheme:
raise InvalidSchemeHeaders()
else:
scheme_header = True
self.scheme = scheme
# ambiguous mapping allows fooling downstream, e.g. merging non-identical headers:
# X-Forwarded-For: 2001:db8::ha:cc:ed
# X_Forwarded_For: 127.0.0.1,::1
# HTTP_X_FORWARDED_FOR = 2001:db8::ha:cc:ed,127.0.0.1,::1
# Only modify after fixing *ALL* header transformations; network to wsgi env
if "_" in name:
if name in forwarder_headers or "*" in forwarder_headers:
# This forwarder may override our environment
pass
elif self.cfg.header_map == "dangerous":
# as if we did not know we cannot safely map this
pass
elif self.cfg.header_map == "drop":
# almost as if it never had been there
# but still counts against resource limits
continue
else:
# fail-safe fallthrough: refuse
raise InvalidHeaderName(name)
headers.append((name, value))
kept = self._apply_header_policy(
name, value, scheme_state,
secure_scheme_headers, forwarder_headers,
from_trailer=from_trailer,
)
if kept is None:
continue
headers.append(kept)
return headers
@ -516,25 +546,23 @@ class Request(Message):
# Headers - convert bytes to strings with uppercase names
# gunicorn_h1c returns headers as (bytes, bytes) tuples
# Header name/value validation done by C parser
# Header name/value validation done by C parser; policy (Expect,
# secure_scheme_headers, forwarder trust gate, header_map) is enforced
# below so the fast path mirrors parse_headers().
self.headers = []
scheme_state = [False]
secure_scheme_headers, forwarder_headers = self._peer_trusted_for_forwarded()
for name_bytes, value_bytes in result['headers']:
name = bytes_to_str(name_bytes).upper()
value = bytes_to_str(value_bytes)
# Handle underscore in header names (policy decision, not validation)
if "_" in name:
forwarder_headers = self.cfg.forwarder_headers
if name in forwarder_headers or "*" in forwarder_headers:
pass
elif self.cfg.header_map == "dangerous":
pass
elif self.cfg.header_map == "drop":
continue
else:
raise InvalidHeaderName(name)
self.headers.append((name, value))
kept = self._apply_header_policy(
name, value, scheme_state,
secure_scheme_headers, forwarder_headers,
)
if kept is None:
continue
self.headers.append(kept)
# Return remaining data after headers
consumed = result['consumed']

View File

@ -2,12 +2,20 @@
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
import socket
import ssl
import time
from gunicorn.http.message import Request
from gunicorn.http.unreader import SocketUnreader, IterUnreader
# Cap on bytes drained from an unconsumed request body before a keepalive
# reset. Defends against a slow-but-steady client that stays under a per-read
# deadline yet streams indefinitely.
_DRAIN_MAX_BYTES = 64 * 1024
class Parser:
mesg_class = None
@ -27,21 +35,61 @@ class Parser:
def __iter__(self):
    """Return the parser itself, so it can be used directly as an iterator."""
    return self
def finish_body(self):
def finish_body(self, deadline=None, max_bytes=_DRAIN_MAX_BYTES):
"""Discard any unread body of the current message.
This should be called before returning a keepalive connection to
the poller to ensure the socket doesn't appear readable due to
leftover body bytes.
Called before returning a keepalive connection to the poller so the
socket does not appear readable due to leftover body bytes.
``deadline`` is an absolute ``time.monotonic()`` value; when set the
socket read timeout is bounded by the remaining time before each read.
``max_bytes`` caps the total drained bytes to defend against a slow
client that keeps trickling under the deadline.
Returns ``True`` when the body was fully drained, ``False`` when the
drain was abandoned (deadline, byte cap, or socket timeout). Callers
that observe ``False`` MUST close the connection rather than serve
another request on it.
"""
if self.mesg:
try:
data = self.mesg.body.read(1024)
while data:
if not self.mesg:
return True
sock = getattr(self.unreader, "sock", None)
# gettimeout/settimeout only matter when bounding a real socket; a
# mock or non-socket source skips the timeout plumbing.
if sock is not None and hasattr(sock, "gettimeout") and hasattr(sock, "settimeout"):
timeoutable_sock = sock
prior_timeout = sock.gettimeout()
else:
timeoutable_sock = None
prior_timeout = None
drained = 0
try:
while True:
if deadline is not None and timeoutable_sock is not None:
remaining = deadline - time.monotonic()
if remaining <= 0:
return False
timeoutable_sock.settimeout(remaining)
try:
data = self.mesg.body.read(1024)
except ssl.SSLWantReadError:
# SSL socket has no more application data available
pass
except (socket.timeout, TimeoutError):
return False
except ssl.SSLWantReadError:
# SSL socket has no more application data available
return True
if not data:
return True
drained += len(data)
if drained >= max_bytes:
return False
finally:
if timeoutable_sock is not None:
try:
timeoutable_sock.settimeout(prior_timeout)
except OSError:
pass
def __next__(self):
# Stop if HTTP dictates a stop.

View File

@ -478,9 +478,14 @@ class ThreadWorker(base.Worker):
# Handle the request
keepalive = self.handle_request(req, conn)
if keepalive:
# Discard any unread request body before keepalive
# to prevent socket appearing readable due to leftover bytes
conn.parser.finish_body()
# Discard any unread request body before keepalive to prevent
# the socket from appearing readable due to leftover bytes.
# Bound the drain by the worker data timeout: a stalled client
# must not keep this thread blocked.
drain_deadline = time.monotonic() + DEFAULT_WORKER_DATA_TIMEOUT
if not conn.parser.finish_body(deadline=drain_deadline):
# Abandon keepalive when the body could not be fully drained.
return False
return True
except http.errors.NoMoreData as e:
self.log.debug("Ignored premature client disconnection. %s", e)

View File

@ -204,3 +204,65 @@ class TestASGIDisconnectGracePeriod:
from gunicorn.config import Config
cfg = Config()
assert cfg.asgi_disconnect_grace_period == 3
class TestBodyReceiverIncompleteBody:
    """Cover the receive() path when the request body never finishes framing."""

    @pytest.fixture
    def mock_worker(self):
        """Yield a mock worker with a tight ``cfg.timeout``.

        The event loop created for the worker is closed on teardown so the
        test does not leak the loop's selector and self-pipe sockets.
        """
        loop = asyncio.new_event_loop()
        worker = mock.Mock()
        worker.nr_conns = 0
        worker.loop = loop
        worker.cfg = mock.Mock()
        worker.cfg.asgi_disconnect_grace_period = 3
        worker.cfg.timeout = 0.05  # tight bound for the test
        worker.log = mock.Mock()
        yield worker
        loop.close()

    @pytest.mark.asyncio
    async def test_receive_yields_disconnect_on_timeout(self, mock_worker):
        """When _wait_for_data times out and the body is not complete, the
        receiver MUST yield http.disconnect rather than synthesize a terminal
        http.request with more_body=False that would desync the next
        pipelined request."""
        from gunicorn.asgi.protocol import ASGIProtocol, BodyReceiver

        protocol = ASGIProtocol(mock_worker)
        protocol.reader = mock.Mock()

        # Declares 100 body bytes but none are ever fed to the receiver.
        request = mock.Mock()
        request.content_length = 100
        request.chunked = False

        receiver = BodyReceiver(request, protocol)
        protocol._body_receiver = receiver

        msg = await receiver.receive()
        assert msg == {"type": "http.disconnect"}
        assert receiver._closed is True

    @pytest.mark.asyncio
    async def test_receive_yields_terminal_request_when_complete(self, mock_worker):
        """If the body is framed complete, the existing terminal http.request
        with more_body=False must still be returned."""
        from gunicorn.asgi.protocol import ASGIProtocol, BodyReceiver

        protocol = ASGIProtocol(mock_worker)
        protocol.reader = mock.Mock()

        request = mock.Mock()
        request.content_length = 5
        request.chunked = False

        receiver = BodyReceiver(request, protocol)
        protocol._body_receiver = receiver
        receiver.feed(b"hello")
        receiver.set_complete()

        msg = await receiver.receive()
        assert msg["type"] == "http.request"
        assert msg["body"] == b"hello"
        # The body is complete, so this must be the terminal message.
        assert msg["more_body"] is False

View File

@ -344,6 +344,43 @@ class TestHTTPVersionForChunked:
assert uses_http1_chunked is False
# ============================================================================
# No-Body Response Tests (RFC 9110)
# ============================================================================
class TestResponseOmitsBody:
    """Check which (method, status) pairs ASGIProtocol treats as bodyless."""

    def _omits(self, method, status):
        """Call ``ASGIProtocol._response_omits_body`` with a late import."""
        from gunicorn.asgi import protocol as asgi_protocol
        return asgi_protocol.ASGIProtocol._response_omits_body(method, status)

    def test_head_omits_body(self):
        for status in (200, 500):
            assert self._omits("HEAD", status) is True

    def test_204_omits_body(self):
        for method in ("GET", "POST"):
            assert self._omits(method, 204) is True

    def test_304_omits_body(self):
        assert self._omits("GET", 304) is True

    def test_informational_omits_body(self):
        for status in (100, 103, 199):
            assert self._omits("GET", status) is True

    def test_get_200_has_body(self):
        assert self._omits("GET", 200) is False

    def test_post_200_has_body(self):
        assert self._omits("POST", 200) is False

    def test_404_has_body(self):
        assert self._omits("GET", 404) is False
# ============================================================================
# Streaming Response Message Sequence Tests
# ============================================================================

View File

@ -45,10 +45,6 @@ def test_asgi_parser(fname):
if getattr(cfg, flag, False):
pytest.skip(f"Callback parser incompatible with {flag}")
# Skip proxy protocol tests
if getattr(cfg, 'proxy_protocol', 'off') != 'off':
pytest.skip("Callback parser does not support proxy_protocol")
req = treq_asgi.request(fname, expect)
# Test with different sending strategies

View File

@ -582,6 +582,59 @@ class TestASGIProtocol:
assert scope["root_path"] == "/api"
assert scope["http_version"] == "1.1"
def test_effective_peername_no_proxy(self):
    """With ``proxy_protocol_info`` unset, the transport peername passes through."""
    from gunicorn.asgi.protocol import ASGIProtocol

    worker = mock.Mock(cfg=Config(), log=mock.Mock(), asgi=mock.Mock())
    proto = ASGIProtocol(worker)
    proto._callback_parser = mock.Mock(proxy_protocol_info=None)

    transport_peer = ("10.0.0.1", 12345)
    assert proto._effective_peername(transport_peer) == transport_peer
def test_effective_peername_with_proxy(self):
    """A PROXY header carrying a client address replaces the transport peer."""
    from gunicorn.asgi.protocol import ASGIProtocol

    proxy_info = {
        'proxy_protocol': 'TCP4',
        'client_addr': '203.0.113.5',
        'client_port': 56324,
        'proxy_addr': '10.0.0.2',
        'proxy_port': 443,
    }
    worker = mock.Mock(cfg=Config(), log=mock.Mock(), asgi=mock.Mock())
    proto = ASGIProtocol(worker)
    proto._callback_parser = mock.Mock(proxy_protocol_info=proxy_info)

    resolved = proto._effective_peername(("10.0.0.1", 1))
    assert resolved == ("203.0.113.5", 56324)
def test_effective_peername_unknown_proxy(self):
    """An UNKNOWN PROXY header has no client info, so the transport peer is kept."""
    from gunicorn.asgi.protocol import ASGIProtocol

    proxy_info = {
        'proxy_protocol': 'UNKNOWN',
        'client_addr': None,
        'client_port': None,
        'proxy_addr': None,
        'proxy_port': None,
    }
    worker = mock.Mock(cfg=Config(), log=mock.Mock(), asgi=mock.Mock())
    proto = ASGIProtocol(worker)
    proto._callback_parser = mock.Mock(proxy_protocol_info=proxy_info)

    transport_peer = ("10.0.0.1", 12345)
    assert proto._effective_peername(transport_peer) == transport_peer
# ============================================================================
# Config Tests

View File

@ -0,0 +1,185 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""Parity tests for WSGI header policy across Python and fast parsers.
These checks ensure that Expect, secure_scheme_headers, forwarder_headers,
and the forwarded_allow_ips trust gate are enforced identically regardless
of the parser implementation selected by ``http_parser``.
"""
import sys
import pytest
from gunicorn.config import Config
from gunicorn.http.parser import RequestParser
from gunicorn.http.errors import (
ExpectationFailed,
InvalidHeaderName,
InvalidSchemeHeaders,
)
def _parse(raw, cfg, peer_addr):
    """Feed ``raw`` to a RequestParser as a single chunk and return its first message."""
    chunks = iter([raw])
    messages = iter(RequestParser(cfg, chunks, peer_addr))
    return next(messages)
def _cfg(http_parser, **overrides):
    """Build a Config with ``http_parser`` set first, then each override in order."""
    cfg = Config()
    settings = {"http_parser": http_parser, **overrides}
    for key, setting in settings.items():
        cfg.set(key, setting)
    return cfg
@pytest.fixture(params=["python", "fast"])
def parser_name(request):
    """Parametrize over parser names, skipping "fast" where it cannot run.

    The "fast" variant is skipped on PyPy, when ``gunicorn_h1c`` cannot be
    imported, or when its ``H1CProtocol`` lacks ``asgi_headers``.
    """
    selected = request.param
    if selected != "fast":
        return selected

    if hasattr(sys, "pypy_version_info"):
        pytest.skip("gunicorn_h1c not supported on PyPy")
    h1c = pytest.importorskip("gunicorn_h1c")
    if not hasattr(h1c.H1CProtocol, "asgi_headers"):
        pytest.skip("gunicorn_h1c >= 0.6.2 required")
    return selected
class TestExpectPolicy:
    """Expect header handling must match across parser implementations."""

    _PEER = ("127.0.0.1", 1234)

    @staticmethod
    def _post(http_version, expectation):
        """Build an empty-bodied POST with the given version and Expect value."""
        return b"".join([
            b"POST / HTTP/", http_version, b"\r\n",
            b"Host: example.com\r\n",
            b"Content-Length: 0\r\n",
            b"Expect: ", expectation, b"\r\n",
            b"\r\n",
        ])

    def test_expect_100_continue_sets_flag(self, parser_name):
        cfg = _cfg(parser_name)
        req = _parse(self._post(b"1.1", b"100-continue"), cfg, self._PEER)
        assert req._expected_100_continue is True

    def test_expect_unknown_value_rejected(self, parser_name):
        cfg = _cfg(parser_name)
        raw = self._post(b"1.1", b"bogus-extension")
        with pytest.raises(ExpectationFailed):
            _parse(raw, cfg, self._PEER)

    def test_expect_ignored_in_http10(self, parser_name):
        cfg = _cfg(parser_name)
        req = _parse(self._post(b"1.0", b"100-continue"), cfg, self._PEER)
        assert req._expected_100_continue is False
class TestSecureSchemeHeaders:
    """secure_scheme_headers must be honored only for trusted peers."""

    _TRUSTED = ("127.0.0.1", 1234)
    _UNTRUSTED = ("203.0.113.5", 1234)

    @staticmethod
    def _get(*header_lines):
        """Build a GET request with a Host header plus ``header_lines``."""
        lines = [b"GET / HTTP/1.1", b"Host: example.com", *header_lines, b"", b""]
        return b"\r\n".join(lines)

    @staticmethod
    def _localhost_only(parser_name, scheme_headers):
        """Config that trusts only 127.0.0.1 for the given scheme headers."""
        return _cfg(
            parser_name,
            forwarded_allow_ips="127.0.0.1",
            secure_scheme_headers=scheme_headers,
        )

    def test_trusted_peer_promotes_https(self, parser_name):
        cfg = self._localhost_only(parser_name, {"X-FORWARDED-PROTO": "https"})
        req = _parse(self._get(b"X-Forwarded-Proto: https"), cfg, self._TRUSTED)
        assert req.scheme == "https"

    def test_untrusted_peer_keeps_http(self, parser_name):
        cfg = self._localhost_only(parser_name, {"X-FORWARDED-PROTO": "https"})
        req = _parse(self._get(b"X-Forwarded-Proto: https"), cfg, self._UNTRUSTED)
        assert req.scheme == "http"

    def test_conflicting_scheme_headers_rejected(self, parser_name):
        cfg = self._localhost_only(
            parser_name,
            {"X-FORWARDED-PROTO": "https", "X-FORWARDED-SSL": "on"},
        )
        raw = self._get(b"X-Forwarded-Proto: https", b"X-Forwarded-Ssl: off")
        with pytest.raises(InvalidSchemeHeaders):
            _parse(raw, cfg, self._TRUSTED)
class TestForwarderTrustGate:
    """Underscore headers must follow the forwarder trust gate and header_map."""

    _TRUSTED = ("127.0.0.1", 1234)
    _UNTRUSTED = ("203.0.113.5", 1234)

    @staticmethod
    def _get(header_line):
        """Build a GET request with a Host header plus one extra header line."""
        return b"\r\n".join(
            [b"GET / HTTP/1.1", b"Host: example.com", header_line, b"", b""]
        )

    @staticmethod
    def _header_names(req):
        """Collect the set of header names present on a parsed request."""
        return {header_name for header_name, _ in req.headers}

    def test_untrusted_peer_underscore_header_rejected(self, parser_name):
        cfg = _cfg(
            parser_name,
            forwarded_allow_ips="127.0.0.1",
            forwarder_headers="SCRIPT_NAME",
            header_map="refuse",
        )
        raw = self._get(b"Script_Name: /evil")
        with pytest.raises(InvalidHeaderName):
            _parse(raw, cfg, self._UNTRUSTED)

    def test_trusted_peer_underscore_header_accepted(self, parser_name):
        cfg = _cfg(
            parser_name,
            forwarded_allow_ips="127.0.0.1",
            forwarder_headers="SCRIPT_NAME",
        )
        req = _parse(self._get(b"Script_Name: /api"), cfg, self._TRUSTED)
        assert "SCRIPT_NAME" in self._header_names(req)

    def test_header_map_drop_silences_underscore(self, parser_name):
        cfg = _cfg(
            parser_name,
            forwarded_allow_ips="127.0.0.1",
            header_map="drop",
        )
        req = _parse(self._get(b"Stray_Name: x"), cfg, self._UNTRUSTED)
        assert "STRAY_NAME" not in self._header_names(req)

View File

@ -255,6 +255,62 @@ def test_invalid_http_version_error():
assert str(InvalidHTTPVersion((2, 1))) == 'Invalid HTTP Version: (2, 1)'
def _build_request_parser(payload):
    """Build a RequestParser over ``payload`` and advance it to its first message.

    The parser uses a default Config, receives ``payload`` as a single
    chunk, and has no peer address.
    """
    from gunicorn.config import Config
    from gunicorn.http.parser import RequestParser

    request_parser = RequestParser(Config(), iter([payload]), None)
    next(iter(request_parser))
    return request_parser
def test_finish_body_drains_remainder():
    """finish_body() reports True once a fully delivered body is drained."""
    request_bytes = b"\r\n".join([
        b"POST / HTTP/1.1",
        b"Host: example.com",
        b"Content-Length: 5",
        b"",
        b"hello",
    ])
    drained = _build_request_parser(request_bytes).finish_body()
    assert drained is True
def test_finish_body_returns_false_when_byte_cap_exceeded():
    """finish_body() gives up (False) when the body exceeds ``max_bytes``."""
    filler = b"x" * 4096
    head = b"\r\n".join([
        b"POST / HTTP/1.1",
        b"Host: example.com",
        b"Content-Length: " + str(len(filler)).encode("ascii"),
        b"",
        b"",
    ])
    parser = _build_request_parser(head + filler)
    outcome = parser.finish_body(max_bytes=512)
    assert outcome is False
def test_finish_body_returns_false_on_expired_deadline():
    """An already-elapsed deadline makes finish_body() abandon the drain.

    The prior socket timeout (``None`` here) must be restored afterwards.
    """
    import time as _time
    from unittest import mock

    truncated = b"\r\n".join([
        b"POST / HTTP/1.1",
        b"Host: example.com",
        b"Content-Length: 100",
        b"",
        b"only-partial",
    ])
    parser = _build_request_parser(truncated)

    # IterUnreader has no socket, and the deadline branch only runs when one
    # is present; attach a stub exposing gettimeout/settimeout to reach it.
    fake_sock = mock.Mock()
    fake_sock.gettimeout.return_value = None
    parser.unreader.sock = fake_sock

    # A deadline in the past must abort the drain immediately.
    already_past = _time.monotonic() - 1.0
    outcome = parser.finish_body(deadline=already_past)

    assert outcome is False
    fake_sock.settimeout.assert_called_with(None)
def test_file_wrapper_iterable():
"""FileWrapper should support the iterator protocol per PEP 3333."""
filelike = io.BytesIO(b"hello world")