mirror of
https://github.com/frappe/gunicorn.git
synced 2026-07-01 10:11:30 +08:00
fix: address six WSGI/ASGI parser and protocol findings
- WSGI fast parser now applies the same per-header policy as the Python parser (Expect, secure_scheme_headers, forwarded_allow_ips trust gate, forwarder_headers / header_map). Shared helpers extracted on Message.
- ASGI keepalive no longer resets the parser when the previous request body was not fully framed; the connection closes instead, preventing request smuggling on pipelined connections.
- BodyReceiver._wait_for_data timeout flips _closed and yields http.disconnect rather than synthesizing more_body=False. Timeout honors cfg.timeout.
- ASGI chunked encoding now skips HEAD, 204, and 304 (matches Response.is_chunked in the WSGI path) via a small helper.
- _setup_callback_parser passes proxy_protocol to PythonProtocol; auto falls back to the Python parser when proxy_protocol != off (the C parser does not implement PROXY framing). _effective_peername swaps the transport peer with the PROXY-supplied client address.
- Parser.finish_body accepts a deadline and a 64KiB byte cap; gthread passes a deadline and abandons keepalive on incomplete drain so a stalled client cannot tie up a worker thread.
This commit is contained in:
parent
4bcda32a78
commit
e90b1c2c1e
@ -254,10 +254,15 @@ class BodyReceiver:
|
||||
if self._chunks:
|
||||
return self._pop_chunk()
|
||||
|
||||
# Complete OR timeout - mark body finished to prevent infinite loops
|
||||
# Apps should not loop forever waiting for body that won't arrive
|
||||
self._body_finished = True
|
||||
return {"type": "http.request", "body": b"", "more_body": False}
|
||||
if self._complete:
|
||||
self._body_finished = True
|
||||
return {"type": "http.request", "body": b"", "more_body": False}
|
||||
|
||||
# Wait returned without data and the message was not framed complete:
|
||||
# treat as a client disconnect rather than synthesizing end-of-body
|
||||
# (which would desync the next pipelined request).
|
||||
self._closed = True
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
async def _wait_for_data(self):
|
||||
"""Wait for body data to arrive via callback."""
|
||||
@ -268,11 +273,21 @@ class BodyReceiver:
|
||||
loop = asyncio.get_event_loop()
|
||||
self._waiter = loop.create_future()
|
||||
|
||||
# Bound the wait by the configured worker timeout (default 30s).
|
||||
# The protocol-level timeout drives transport disconnect handling;
|
||||
# this only needs to escape an idle wait if data never arrives.
|
||||
cfg = getattr(self.protocol, 'cfg', None)
|
||||
timeout = getattr(cfg, 'timeout', None) if cfg is not None else None
|
||||
if not timeout or timeout <= 0:
|
||||
timeout = 30.0
|
||||
|
||||
try:
|
||||
# Wait with timeout for data or completion
|
||||
await asyncio.wait_for(self._waiter, timeout=30.0)
|
||||
await asyncio.wait_for(self._waiter, timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
# No data arrived in time: mark the body receiver as disconnected
|
||||
# so receive() yields http.disconnect rather than a fake terminal
|
||||
# http.request with more_body=False.
|
||||
self._closed = True
|
||||
finally:
|
||||
self._waiter = None
|
||||
|
||||
@ -436,6 +451,10 @@ class ASGIProtocol(asyncio.Protocol):
|
||||
for flag in self._FAST_PARSER_INCOMPATIBLE_FLAGS:
|
||||
if getattr(self.cfg, flag, False):
|
||||
incompatible.append(flag)
|
||||
# PROXY protocol framing is implemented only in PythonProtocol; the C parser
|
||||
# has no proxy_protocol kwarg and would silently drop the framing.
|
||||
if getattr(self.cfg, 'proxy_protocol', 'off') != 'off':
|
||||
incompatible.append('proxy_protocol')
|
||||
|
||||
if parser_setting == 'python':
|
||||
parser_class = PythonProtocol
|
||||
@ -467,8 +486,9 @@ class ASGIProtocol(asyncio.Protocol):
|
||||
if limit_request_line == 0 and parser_class != PythonProtocol:
|
||||
limit_request_line = 1024 * 1024 # 1MB for C parser
|
||||
|
||||
# Create parser with callbacks and limit parameters (both parsers support them)
|
||||
self._callback_parser = parser_class(
|
||||
# Create parser with callbacks and limit parameters (both parsers support them).
|
||||
# Only the Python parser implements PROXY protocol framing; pass the option there.
|
||||
parser_kwargs = dict(
|
||||
on_headers_complete=self._on_headers_complete,
|
||||
on_body=self._on_body,
|
||||
on_message_complete=self._on_message_complete,
|
||||
@ -478,6 +498,9 @@ class ASGIProtocol(asyncio.Protocol):
|
||||
permit_unconventional_http_method=self.cfg.permit_unconventional_http_method,
|
||||
permit_unconventional_http_version=self.cfg.permit_unconventional_http_version,
|
||||
)
|
||||
if parser_class is PythonProtocol:
|
||||
parser_kwargs['proxy_protocol'] = getattr(self.cfg, 'proxy_protocol', 'off')
|
||||
self._callback_parser = parser_class(**parser_kwargs)
|
||||
|
||||
def _on_headers_complete(self):
|
||||
"""Callback: request headers are complete."""
|
||||
@ -729,14 +752,17 @@ class ASGIProtocol(asyncio.Protocol):
|
||||
|
||||
request = self._current_request
|
||||
|
||||
# If PROXY protocol provided a real client address, use it.
|
||||
effective_peer = self._effective_peername(peername)
|
||||
|
||||
# Check for WebSocket upgrade
|
||||
if self._is_websocket_upgrade(request):
|
||||
await self._handle_websocket(request, sockname, peername)
|
||||
await self._handle_websocket(request, sockname, effective_peer)
|
||||
break # WebSocket takes over the connection
|
||||
|
||||
# Handle HTTP request
|
||||
keepalive = await self._handle_http_request(
|
||||
request, sockname, peername
|
||||
request, sockname, effective_peer
|
||||
)
|
||||
|
||||
# Increment worker request count
|
||||
@ -755,6 +781,18 @@ class ASGIProtocol(asyncio.Protocol):
|
||||
if not self.cfg.keepalive:
|
||||
break
|
||||
|
||||
# Refuse keepalive if the previous request body was not fully
|
||||
# framed: residual bytes left in the transport stream would be
|
||||
# parsed as the start of the next request (smuggling).
|
||||
receiver = self._body_receiver
|
||||
message_complete = (
|
||||
receiver is None
|
||||
or receiver._complete
|
||||
or receiver._closed
|
||||
)
|
||||
if not message_complete:
|
||||
break
|
||||
|
||||
# Resume reading if paused during body consumption
|
||||
self._resume_reading()
|
||||
|
||||
@ -918,15 +956,15 @@ class ASGIProtocol(asyncio.Protocol):
|
||||
has_transfer_encoding = True
|
||||
use_chunked = True # Framework already set chunked encoding
|
||||
|
||||
# Use chunked encoding for HTTP/1.1 streaming responses without Content-Length
|
||||
# Skip for 1xx informational responses (RFC 9110)
|
||||
# Skip if Transfer-Encoding already set by framework
|
||||
is_informational = 100 <= response_status < 200
|
||||
# Use chunked encoding for HTTP/1.1 streaming responses without Content-Length.
|
||||
# Skip when the response cannot carry a body (HEAD/1xx/204/304) or when
|
||||
# Transfer-Encoding was already set by the framework.
|
||||
is_no_body = self._response_omits_body(request.method, response_status)
|
||||
needs_chunked = (
|
||||
not has_content_length
|
||||
and not has_transfer_encoding
|
||||
and request.version >= (1, 1)
|
||||
and not is_informational
|
||||
and not is_no_body
|
||||
)
|
||||
if needs_chunked:
|
||||
use_chunked = True
|
||||
@ -1201,6 +1239,36 @@ class ASGIProtocol(asyncio.Protocol):
|
||||
# Buffer headers for batching with first body chunk
|
||||
self._response_buffer = b"".join(parts)
|
||||
|
||||
def _effective_peername(self, peername):
|
||||
"""Return the client address advertised via PROXY protocol if any.
|
||||
|
||||
Falls back to the transport peername when PROXY protocol is disabled,
|
||||
the framing was absent, or the parser is the C variant (which currently
|
||||
does not surface PROXY metadata).
|
||||
"""
|
||||
parser = self._callback_parser
|
||||
info = getattr(parser, 'proxy_protocol_info', None) if parser else None
|
||||
if not info:
|
||||
return peername
|
||||
client_addr = info.get('client_addr')
|
||||
client_port = info.get('client_port')
|
||||
if client_addr is None or client_port is None:
|
||||
return peername
|
||||
return (client_addr, client_port)
|
||||
|
||||
@staticmethod
|
||||
def _response_omits_body(method, status):
|
||||
"""Return True when the response MUST NOT have a body (RFC 9110).
|
||||
|
||||
Applies to HEAD requests and to status codes that semantically carry no body:
|
||||
1xx informational, 204 No Content, 304 Not Modified.
|
||||
"""
|
||||
return (
|
||||
method == "HEAD"
|
||||
or status in (204, 304)
|
||||
or 100 <= status < 200
|
||||
)
|
||||
|
||||
def _send_body(self, body, chunked=False):
|
||||
"""Send response body chunk.
|
||||
|
||||
|
||||
@ -206,26 +206,94 @@ class Message:
|
||||
def parse(self, unreader):
|
||||
raise NotImplementedError()
|
||||
|
||||
def parse_headers(self, data, from_trailer=False):
|
||||
def _peer_trusted_for_forwarded(self):
|
||||
"""Return the (secure_scheme_headers, forwarder_headers) the peer is allowed to set.
|
||||
|
||||
When the peer's address is not in ``forwarded_allow_ips`` (or networks),
|
||||
configured forwarding/secure-scheme policy must be ignored to prevent
|
||||
spoofing. Returns ``({}, [])`` when the peer is untrusted.
|
||||
"""
|
||||
cfg = self.cfg
|
||||
if (not isinstance(self.peer_addr, tuple)
|
||||
or _ip_in_allow_list(self.peer_addr[0], cfg.forwarded_allow_ips,
|
||||
cfg.forwarded_allow_networks())):
|
||||
return cfg.secure_scheme_headers, cfg.forwarder_headers
|
||||
return {}, []
|
||||
|
||||
def _apply_header_policy(self, name, value, scheme_state,
|
||||
secure_scheme_headers, forwarder_headers,
|
||||
from_trailer=False):
|
||||
"""Apply per-header policy shared between Python and fast parsers.
|
||||
|
||||
Mutates ``self._expected_100_continue`` and ``self.scheme`` as needed.
|
||||
``scheme_state`` is a single-element list used as a mutable sentinel
|
||||
so the caller can detect repeated scheme headers.
|
||||
|
||||
Returns the (name, value) pair to retain, or ``None`` to drop the
|
||||
header (per ``header_map='drop'``). Raises the same exceptions the
|
||||
Python path raises so behavior is identical regardless of parser.
|
||||
"""
|
||||
if not from_trailer and name == "EXPECT":
|
||||
# https://datatracker.ietf.org/doc/html/rfc9110#section-10.1.1
|
||||
# "The Expect field value is case-insensitive."
|
||||
if value.lower() == "100-continue":
|
||||
if self.version < (1, 1):
|
||||
# https://datatracker.ietf.org/doc/html/rfc9110#section-10.1.1-12
|
||||
# "A server that receives a 100-continue expectation
|
||||
# in an HTTP/1.0 request MUST ignore that expectation."
|
||||
pass
|
||||
else:
|
||||
self._expected_100_continue = True
|
||||
# N.B. understood but ignored expect header does not return 417
|
||||
else:
|
||||
raise ExpectationFailed(value)
|
||||
|
||||
if name in secure_scheme_headers:
|
||||
secure = value == secure_scheme_headers[name]
|
||||
scheme = "https" if secure else "http"
|
||||
if scheme_state[0]:
|
||||
if scheme != self.scheme:
|
||||
raise InvalidSchemeHeaders()
|
||||
else:
|
||||
scheme_state[0] = True
|
||||
self.scheme = scheme
|
||||
|
||||
# ambiguous mapping allows fooling downstream, e.g. merging non-identical headers:
|
||||
# X-Forwarded-For: 2001:db8::ha:cc:ed
|
||||
# X_Forwarded_For: 127.0.0.1,::1
|
||||
# HTTP_X_FORWARDED_FOR = 2001:db8::ha:cc:ed,127.0.0.1,::1
|
||||
# Only modify after fixing *ALL* header transformations; network to wsgi env
|
||||
if "_" in name:
|
||||
if name in forwarder_headers or "*" in forwarder_headers:
|
||||
# This forwarder may override our environment
|
||||
pass
|
||||
elif self.cfg.header_map == "dangerous":
|
||||
# as if we did not know we cannot safely map this
|
||||
pass
|
||||
elif self.cfg.header_map == "drop":
|
||||
# almost as if it never had been there
|
||||
# but still counts against resource limits
|
||||
return None
|
||||
else:
|
||||
# fail-safe fallthrough: refuse
|
||||
raise InvalidHeaderName(name)
|
||||
|
||||
return (name, value)
|
||||
|
||||
def parse_headers(self, data, from_trailer=False):
|
||||
headers = []
|
||||
|
||||
# Split lines on \r\n
|
||||
lines = [bytes_to_str(line) for line in data.split(b"\r\n")]
|
||||
|
||||
# handle scheme headers
|
||||
scheme_header = False
|
||||
secure_scheme_headers = {}
|
||||
forwarder_headers = []
|
||||
scheme_state = [False]
|
||||
if from_trailer:
|
||||
# nonsense. either a request is https from the beginning
|
||||
# .. or we are just behind a proxy who does not remove conflicting trailers
|
||||
pass
|
||||
elif (not isinstance(self.peer_addr, tuple)
|
||||
or _ip_in_allow_list(self.peer_addr[0], cfg.forwarded_allow_ips,
|
||||
cfg.forwarded_allow_networks())):
|
||||
secure_scheme_headers = cfg.secure_scheme_headers
|
||||
forwarder_headers = cfg.forwarder_headers
|
||||
secure_scheme_headers, forwarder_headers = {}, []
|
||||
else:
|
||||
secure_scheme_headers, forwarder_headers = self._peer_trusted_for_forwarded()
|
||||
|
||||
# Parse headers into key/value pairs paying attention
|
||||
# to continuation lines.
|
||||
@ -275,52 +343,14 @@ class Message:
|
||||
if header_length > self.limit_request_field_size > 0:
|
||||
raise LimitRequestHeaders("limit request headers fields size")
|
||||
|
||||
if not from_trailer and name == "EXPECT":
|
||||
# https://datatracker.ietf.org/doc/html/rfc9110#section-10.1.1
|
||||
# "The Expect field value is case-insensitive."
|
||||
if value.lower() == "100-continue":
|
||||
if self.version < (1, 1):
|
||||
# https://datatracker.ietf.org/doc/html/rfc9110#section-10.1.1-12
|
||||
# "A server that receives a 100-continue expectation
|
||||
# in an HTTP/1.0 request MUST ignore that expectation."
|
||||
pass
|
||||
else:
|
||||
self._expected_100_continue = True
|
||||
# N.B. understood but ignored expect header does not return 417
|
||||
else:
|
||||
raise ExpectationFailed(value)
|
||||
|
||||
if name in secure_scheme_headers:
|
||||
secure = value == secure_scheme_headers[name]
|
||||
scheme = "https" if secure else "http"
|
||||
if scheme_header:
|
||||
if scheme != self.scheme:
|
||||
raise InvalidSchemeHeaders()
|
||||
else:
|
||||
scheme_header = True
|
||||
self.scheme = scheme
|
||||
|
||||
# ambiguous mapping allows fooling downstream, e.g. merging non-identical headers:
|
||||
# X-Forwarded-For: 2001:db8::ha:cc:ed
|
||||
# X_Forwarded_For: 127.0.0.1,::1
|
||||
# HTTP_X_FORWARDED_FOR = 2001:db8::ha:cc:ed,127.0.0.1,::1
|
||||
# Only modify after fixing *ALL* header transformations; network to wsgi env
|
||||
if "_" in name:
|
||||
if name in forwarder_headers or "*" in forwarder_headers:
|
||||
# This forwarder may override our environment
|
||||
pass
|
||||
elif self.cfg.header_map == "dangerous":
|
||||
# as if we did not know we cannot safely map this
|
||||
pass
|
||||
elif self.cfg.header_map == "drop":
|
||||
# almost as if it never had been there
|
||||
# but still counts against resource limits
|
||||
continue
|
||||
else:
|
||||
# fail-safe fallthrough: refuse
|
||||
raise InvalidHeaderName(name)
|
||||
|
||||
headers.append((name, value))
|
||||
kept = self._apply_header_policy(
|
||||
name, value, scheme_state,
|
||||
secure_scheme_headers, forwarder_headers,
|
||||
from_trailer=from_trailer,
|
||||
)
|
||||
if kept is None:
|
||||
continue
|
||||
headers.append(kept)
|
||||
|
||||
return headers
|
||||
|
||||
@ -516,25 +546,23 @@ class Request(Message):
|
||||
|
||||
# Headers - convert bytes to strings with uppercase names
|
||||
# gunicorn_h1c returns headers as (bytes, bytes) tuples
|
||||
# Header name/value validation done by C parser
|
||||
# Header name/value validation done by C parser; policy (Expect,
|
||||
# secure_scheme_headers, forwarder trust gate, header_map) is enforced
|
||||
# below so the fast path mirrors parse_headers().
|
||||
self.headers = []
|
||||
scheme_state = [False]
|
||||
secure_scheme_headers, forwarder_headers = self._peer_trusted_for_forwarded()
|
||||
for name_bytes, value_bytes in result['headers']:
|
||||
name = bytes_to_str(name_bytes).upper()
|
||||
value = bytes_to_str(value_bytes)
|
||||
|
||||
# Handle underscore in header names (policy decision, not validation)
|
||||
if "_" in name:
|
||||
forwarder_headers = self.cfg.forwarder_headers
|
||||
if name in forwarder_headers or "*" in forwarder_headers:
|
||||
pass
|
||||
elif self.cfg.header_map == "dangerous":
|
||||
pass
|
||||
elif self.cfg.header_map == "drop":
|
||||
continue
|
||||
else:
|
||||
raise InvalidHeaderName(name)
|
||||
|
||||
self.headers.append((name, value))
|
||||
kept = self._apply_header_policy(
|
||||
name, value, scheme_state,
|
||||
secure_scheme_headers, forwarder_headers,
|
||||
)
|
||||
if kept is None:
|
||||
continue
|
||||
self.headers.append(kept)
|
||||
|
||||
# Return remaining data after headers
|
||||
consumed = result['consumed']
|
||||
|
||||
@ -2,12 +2,20 @@
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
import socket
|
||||
import ssl
|
||||
import time
|
||||
|
||||
from gunicorn.http.message import Request
|
||||
from gunicorn.http.unreader import SocketUnreader, IterUnreader
|
||||
|
||||
|
||||
# Cap on bytes drained from an unconsumed request body before a keepalive
|
||||
# reset. Defends against a slow-but-steady client that stays under a per-read
|
||||
# deadline yet streams indefinitely.
|
||||
_DRAIN_MAX_BYTES = 64 * 1024
|
||||
|
||||
|
||||
class Parser:
|
||||
|
||||
mesg_class = None
|
||||
@ -27,21 +35,61 @@ class Parser:
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def finish_body(self, deadline=None, max_bytes=_DRAIN_MAX_BYTES):
    """Discard any unread body of the current message.

    Called before returning a keepalive connection to the poller so the
    socket does not appear readable due to leftover body bytes.

    ``deadline`` is an absolute ``time.monotonic()`` value. It is checked
    before every read, whatever the data source is; when the unreader
    exposes a socket offering ``gettimeout``/``settimeout``, the socket
    read timeout is additionally bounded by the time remaining.
    ``max_bytes`` caps the total drained bytes to defend against a slow
    client that keeps trickling under the deadline.

    Returns ``True`` when there is no current message, when the body was
    read to its end, or when an SSL socket reports no more application
    data. Returns ``False`` when the drain was abandoned (deadline, byte
    cap, or socket timeout). Callers that observe ``False`` MUST close the
    connection rather than serve another request on it.
    """
    if not self.mesg:
        return True

    sock = getattr(self.unreader, "sock", None)
    # gettimeout/settimeout only matter when bounding a real socket; a
    # mock or non-socket source skips the timeout plumbing.
    if sock is not None and hasattr(sock, "gettimeout") and hasattr(sock, "settimeout"):
        timeoutable_sock = sock
        prior_timeout = sock.gettimeout()
    else:
        timeoutable_sock = None
        prior_timeout = None

    drained = 0
    try:
        while True:
            if deadline is not None:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    # Out of time: abandon even when reads cannot be bounded.
                    return False
                if timeoutable_sock is not None:
                    timeoutable_sock.settimeout(remaining)
            try:
                data = self.mesg.body.read(1024)
            except (socket.timeout, TimeoutError):
                return False
            except ssl.SSLWantReadError:
                # SSL socket has no more application data available
                return True
            if not data:
                return True
            drained += len(data)
            if drained >= max_bytes:
                return False
    finally:
        if timeoutable_sock is not None:
            try:
                timeoutable_sock.settimeout(prior_timeout)
            except OSError:
                # Best effort: the socket may already be closed.
                pass
|
||||
|
||||
def __next__(self):
|
||||
# Stop if HTTP dictates a stop.
|
||||
|
||||
@ -478,9 +478,14 @@ class ThreadWorker(base.Worker):
|
||||
# Handle the request
|
||||
keepalive = self.handle_request(req, conn)
|
||||
if keepalive:
|
||||
# Discard any unread request body before keepalive
|
||||
# to prevent socket appearing readable due to leftover bytes
|
||||
conn.parser.finish_body()
|
||||
# Discard any unread request body before keepalive to prevent
|
||||
# the socket from appearing readable due to leftover bytes.
|
||||
# Bound the drain by the worker data timeout: a stalled client
|
||||
# must not keep this thread blocked.
|
||||
drain_deadline = time.monotonic() + DEFAULT_WORKER_DATA_TIMEOUT
|
||||
if not conn.parser.finish_body(deadline=drain_deadline):
|
||||
# Abandon keepalive when the body could not be fully drained.
|
||||
return False
|
||||
return True
|
||||
except http.errors.NoMoreData as e:
|
||||
self.log.debug("Ignored premature client disconnection. %s", e)
|
||||
|
||||
@ -204,3 +204,65 @@ class TestASGIDisconnectGracePeriod:
|
||||
from gunicorn.config import Config
|
||||
cfg = Config()
|
||||
assert cfg.asgi_disconnect_grace_period == 3
|
||||
|
||||
|
||||
class TestBodyReceiverIncompleteBody:
    """Cover the receive() path when the request body never finishes framing."""

    @pytest.fixture
    def mock_worker(self):
        """Yield a mocked worker with a tight ``cfg.timeout``.

        The event loop created for the worker is closed on teardown so it
        does not leak between tests.
        """
        worker = mock.Mock()
        worker.nr_conns = 0
        worker.loop = asyncio.new_event_loop()
        worker.cfg = mock.Mock()
        worker.cfg.asgi_disconnect_grace_period = 3
        worker.cfg.timeout = 0.05  # tight bound for the test
        worker.log = mock.Mock()
        try:
            yield worker
        finally:
            worker.loop.close()

    @pytest.mark.asyncio
    async def test_receive_yields_disconnect_on_timeout(self, mock_worker):
        """When _wait_for_data times out and the body is not complete, the
        receiver MUST yield http.disconnect rather than synthesize a terminal
        http.request with more_body=False — that would desync the next
        pipelined request."""
        from gunicorn.asgi.protocol import ASGIProtocol, BodyReceiver

        protocol = ASGIProtocol(mock_worker)
        protocol.reader = mock.Mock()

        request = mock.Mock()
        request.content_length = 100
        request.chunked = False

        receiver = BodyReceiver(request, protocol)
        protocol._body_receiver = receiver

        # No data is ever fed, so receive() can only return via the timeout.
        msg = await receiver.receive()
        assert msg == {"type": "http.disconnect"}
        assert receiver._closed is True

    @pytest.mark.asyncio
    async def test_receive_yields_terminal_request_when_complete(self, mock_worker):
        """If the body is framed complete, the existing terminal http.request
        with more_body=False must still be returned."""
        from gunicorn.asgi.protocol import ASGIProtocol, BodyReceiver

        protocol = ASGIProtocol(mock_worker)
        protocol.reader = mock.Mock()

        request = mock.Mock()
        request.content_length = 5
        request.chunked = False

        receiver = BodyReceiver(request, protocol)
        protocol._body_receiver = receiver

        receiver.feed(b"hello")
        receiver.set_complete()

        msg = await receiver.receive()
        assert msg["type"] == "http.request"
        assert msg["body"] == b"hello"
        # The body is complete, so this chunk must be flagged as the last one.
        assert msg["more_body"] is False
|
||||
|
||||
@ -344,6 +344,43 @@ class TestHTTPVersionForChunked:
|
||||
assert uses_http1_chunked is False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# No-Body Response Tests (RFC 9110)
|
||||
# ============================================================================
|
||||
|
||||
class TestResponseOmitsBody:
    """Check which method/status pairs ``_response_omits_body`` flags as bodyless."""

    def _omits(self, method, status):
        """Call ``ASGIProtocol._response_omits_body`` with a late import."""
        from gunicorn.asgi.protocol import ASGIProtocol
        return ASGIProtocol._response_omits_body(method, status)

    def test_head_omits_body(self):
        for status in (200, 500):
            assert self._omits("HEAD", status) is True

    def test_204_omits_body(self):
        for method in ("GET", "POST"):
            assert self._omits(method, 204) is True

    def test_304_omits_body(self):
        assert self._omits("GET", 304) is True

    def test_informational_omits_body(self):
        for status in (100, 103, 199):
            assert self._omits("GET", status) is True

    def test_get_200_has_body(self):
        assert self._omits("GET", 200) is False

    def test_post_200_has_body(self):
        assert self._omits("POST", 200) is False

    def test_404_has_body(self):
        assert self._omits("GET", 404) is False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Streaming Response Message Sequence Tests
|
||||
# ============================================================================
|
||||
|
||||
@ -45,10 +45,6 @@ def test_asgi_parser(fname):
|
||||
if getattr(cfg, flag, False):
|
||||
pytest.skip(f"Callback parser incompatible with {flag}")
|
||||
|
||||
# Skip proxy protocol tests
|
||||
if getattr(cfg, 'proxy_protocol', 'off') != 'off':
|
||||
pytest.skip("Callback parser does not support proxy_protocol")
|
||||
|
||||
req = treq_asgi.request(fname, expect)
|
||||
|
||||
# Test with different sending strategies
|
||||
|
||||
@ -582,6 +582,59 @@ class TestASGIProtocol:
|
||||
assert scope["root_path"] == "/api"
|
||||
assert scope["http_version"] == "1.1"
|
||||
|
||||
def test_effective_peername_no_proxy(self):
    """A parser reporting no PROXY info leaves the transport peername untouched."""
    from gunicorn.asgi.protocol import ASGIProtocol

    worker = mock.Mock(cfg=Config(), log=mock.Mock(), asgi=mock.Mock())
    protocol = ASGIProtocol(worker)
    protocol._callback_parser = mock.Mock(proxy_protocol_info=None)

    transport_peer = ("10.0.0.1", 12345)
    assert protocol._effective_peername(transport_peer) == transport_peer
|
||||
|
||||
def test_effective_peername_with_proxy(self):
    """The client address/port from PROXY info replaces the transport peername."""
    from gunicorn.asgi.protocol import ASGIProtocol

    proxy_info = {
        'proxy_protocol': 'TCP4',
        'client_addr': '203.0.113.5',
        'client_port': 56324,
        'proxy_addr': '10.0.0.2',
        'proxy_port': 443,
    }
    worker = mock.Mock(cfg=Config(), log=mock.Mock(), asgi=mock.Mock())
    protocol = ASGIProtocol(worker)
    protocol._callback_parser = mock.Mock(proxy_protocol_info=proxy_info)

    resolved = protocol._effective_peername(("10.0.0.1", 1))
    assert resolved == ("203.0.113.5", 56324)
|
||||
|
||||
def test_effective_peername_unknown_proxy(self):
    """PROXY info whose client fields are ``None`` falls back to the transport peername."""
    from gunicorn.asgi.protocol import ASGIProtocol

    proxy_info = {
        'proxy_protocol': 'UNKNOWN',
        'client_addr': None,
        'client_port': None,
        'proxy_addr': None,
        'proxy_port': None,
    }
    worker = mock.Mock(cfg=Config(), log=mock.Mock(), asgi=mock.Mock())
    protocol = ASGIProtocol(worker)
    protocol._callback_parser = mock.Mock(proxy_protocol_info=proxy_info)

    transport_peer = ("10.0.0.1", 12345)
    assert protocol._effective_peername(transport_peer) == transport_peer
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Config Tests
|
||||
|
||||
185
tests/test_header_policy_parity.py
Normal file
185
tests/test_header_policy_parity.py
Normal file
@ -0,0 +1,185 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""Parity tests for WSGI header policy across Python and fast parsers.
|
||||
|
||||
These checks ensure that Expect, secure_scheme_headers, forwarder_headers,
|
||||
and the forwarded_allow_ips trust gate are enforced identically regardless
|
||||
of the parser implementation selected by ``http_parser``.
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from gunicorn.config import Config
|
||||
from gunicorn.http.parser import RequestParser
|
||||
from gunicorn.http.errors import (
|
||||
ExpectationFailed,
|
||||
InvalidHeaderName,
|
||||
InvalidSchemeHeaders,
|
||||
)
|
||||
|
||||
|
||||
def _parse(raw, cfg, peer_addr):
    """Feed ``raw`` to a ``RequestParser`` and return the first request it yields."""
    chunks = iter([raw])
    requests = iter(RequestParser(cfg, chunks, peer_addr))
    return next(requests)
|
||||
|
||||
|
||||
def _cfg(http_parser, **overrides):
    """Build a ``Config`` with ``http_parser`` set first, then each override in order."""
    settings = {"http_parser": http_parser, **overrides}
    cfg = Config()
    for key, setting in settings.items():
        cfg.set(key, setting)
    return cfg
|
||||
|
||||
|
||||
@pytest.fixture(params=["python", "fast"])
def parser_name(request):
    """Yield each parser name; skip "fast" when ``gunicorn_h1c`` is unusable."""
    selected = request.param
    if selected != "fast":
        return selected

    # The fast parser needs a usable gunicorn_h1c build.
    if hasattr(sys, "pypy_version_info"):
        pytest.skip("gunicorn_h1c not supported on PyPy")
    h1c = pytest.importorskip("gunicorn_h1c")
    if not hasattr(h1c.H1CProtocol, "asgi_headers"):
        pytest.skip("gunicorn_h1c >= 0.6.2 required")
    return selected
|
||||
|
||||
|
||||
class TestExpectPolicy:
    """Expect header handling, run once per parser implementation."""

    PEER = ("127.0.0.1", 1234)

    @staticmethod
    def _post(http_version, expect):
        """Build an empty-bodied POST carrying the given Expect value."""
        head = (
            b"POST / HTTP/" + http_version,
            b"Host: example.com",
            b"Content-Length: 0",
            b"Expect: " + expect,
        )
        return b"\r\n".join(head) + b"\r\n\r\n"

    def test_expect_100_continue_sets_flag(self, parser_name):
        raw = self._post(b"1.1", b"100-continue")
        req = _parse(raw, _cfg(parser_name), self.PEER)
        assert req._expected_100_continue is True

    def test_expect_unknown_value_rejected(self, parser_name):
        cfg = _cfg(parser_name)
        raw = self._post(b"1.1", b"bogus-extension")
        with pytest.raises(ExpectationFailed):
            _parse(raw, cfg, self.PEER)

    def test_expect_ignored_in_http10(self, parser_name):
        raw = self._post(b"1.0", b"100-continue")
        req = _parse(raw, _cfg(parser_name), self.PEER)
        assert req._expected_100_continue is False
|
||||
|
||||
|
||||
class TestSecureSchemeHeaders:
    """secure_scheme_headers handling, run once per parser implementation."""

    TRUSTED_PEER = ("127.0.0.1", 1234)
    UNTRUSTED_PEER = ("203.0.113.5", 1234)

    @staticmethod
    def _get(*fields):
        """Build a GET request with a Host header plus the given header lines."""
        head = (b"GET / HTTP/1.1", b"Host: example.com") + fields
        return b"\r\n".join(head) + b"\r\n\r\n"

    def test_trusted_peer_promotes_https(self, parser_name):
        cfg = _cfg(
            parser_name,
            forwarded_allow_ips="127.0.0.1",
            secure_scheme_headers={"X-FORWARDED-PROTO": "https"},
        )
        req = _parse(self._get(b"X-Forwarded-Proto: https"), cfg, self.TRUSTED_PEER)
        assert req.scheme == "https"

    def test_untrusted_peer_keeps_http(self, parser_name):
        cfg = _cfg(
            parser_name,
            forwarded_allow_ips="127.0.0.1",
            secure_scheme_headers={"X-FORWARDED-PROTO": "https"},
        )
        req = _parse(self._get(b"X-Forwarded-Proto: https"), cfg, self.UNTRUSTED_PEER)
        assert req.scheme == "http"

    def test_conflicting_scheme_headers_rejected(self, parser_name):
        scheme_headers = {
            "X-FORWARDED-PROTO": "https",
            "X-FORWARDED-SSL": "on",
        }
        cfg = _cfg(
            parser_name,
            forwarded_allow_ips="127.0.0.1",
            secure_scheme_headers=scheme_headers,
        )
        # One header says https, the other says not-https.
        raw = self._get(b"X-Forwarded-Proto: https", b"X-Forwarded-Ssl: off")
        with pytest.raises(InvalidSchemeHeaders):
            _parse(raw, cfg, self.TRUSTED_PEER)
|
||||
|
||||
|
||||
class TestForwarderTrustGate:
    """Underscore-header policy versus peer trust, run once per parser implementation."""

    TRUSTED_PEER = ("127.0.0.1", 1234)
    UNTRUSTED_PEER = ("203.0.113.5", 1234)

    @staticmethod
    def _get(field):
        """Build a GET request with a Host header plus one extra header line."""
        head = (b"GET / HTTP/1.1", b"Host: example.com", field)
        return b"\r\n".join(head) + b"\r\n\r\n"

    def test_untrusted_peer_underscore_header_rejected(self, parser_name):
        cfg = _cfg(
            parser_name,
            forwarded_allow_ips="127.0.0.1",
            forwarder_headers="SCRIPT_NAME",
            header_map="refuse",
        )
        raw = self._get(b"Script_Name: /evil")
        with pytest.raises(InvalidHeaderName):
            _parse(raw, cfg, self.UNTRUSTED_PEER)

    def test_trusted_peer_underscore_header_accepted(self, parser_name):
        cfg = _cfg(
            parser_name,
            forwarded_allow_ips="127.0.0.1",
            forwarder_headers="SCRIPT_NAME",
        )
        req = _parse(self._get(b"Script_Name: /api"), cfg, self.TRUSTED_PEER)
        seen = [header_name for header_name, _value in req.headers]
        assert "SCRIPT_NAME" in seen

    def test_header_map_drop_silences_underscore(self, parser_name):
        cfg = _cfg(
            parser_name,
            forwarded_allow_ips="127.0.0.1",
            header_map="drop",
        )
        req = _parse(self._get(b"Stray_Name: x"), cfg, self.UNTRUSTED_PEER)
        seen = [header_name for header_name, _value in req.headers]
        assert "STRAY_NAME" not in seen
|
||||
@ -255,6 +255,62 @@ def test_invalid_http_version_error():
|
||||
assert str(InvalidHTTPVersion((2, 1))) == 'Invalid HTTP Version: (2, 1)'
|
||||
|
||||
|
||||
def _build_request_parser(payload):
    """Return a ``RequestParser`` over ``payload`` after pulling its first request."""
    from gunicorn.config import Config
    from gunicorn.http.parser import RequestParser

    parser = RequestParser(Config(), iter([payload]), None)
    # Advance once so the parser holds a current message.
    next(iter(parser))
    return parser
|
||||
|
||||
|
||||
def test_finish_body_drains_remainder():
    """A fully available body is drained and reported as complete."""
    head = b"\r\n".join((
        b"POST / HTTP/1.1",
        b"Host: example.com",
        b"Content-Length: 5",
    ))
    parser = _build_request_parser(head + b"\r\n\r\n" + b"hello")
    assert parser.finish_body() is True
|
||||
|
||||
|
||||
def test_finish_body_returns_false_when_byte_cap_exceeded():
    """Draining stops with ``False`` once ``max_bytes`` has been reached."""
    body = b"x" * 4096
    head = (
        b"POST / HTTP/1.1\r\n"
        b"Host: example.com\r\n"
        b"Content-Length: %d\r\n\r\n" % len(body)
    )
    parser = _build_request_parser(head + body)
    assert parser.finish_body(max_bytes=512) is False
|
||||
|
||||
|
||||
def test_finish_body_returns_false_on_expired_deadline():
    """An already-elapsed deadline makes the drain give up and restore the timeout."""
    import time as _time
    from unittest import mock

    parser = _build_request_parser(
        b"POST / HTTP/1.1\r\n"
        b"Host: example.com\r\n"
        b"Content-Length: 100\r\n"
        b"\r\n"
        b"only-partial"
    )

    # IterUnreader has no socket, so attach a stub offering
    # gettimeout/settimeout to drive the socket-timeout branch.
    fake_sock = mock.Mock()
    fake_sock.gettimeout.return_value = None
    parser.unreader.sock = fake_sock

    already_past = _time.monotonic() - 1.0
    assert parser.finish_body(deadline=already_past) is False
    # The prior timeout (None) must be put back on the socket.
    fake_sock.settimeout.assert_called_with(None)
|
||||
|
||||
|
||||
def test_file_wrapper_iterable():
|
||||
"""FileWrapper should support the iterator protocol per PEP 3333."""
|
||||
filelike = io.BytesIO(b"hello world")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user