diff --git a/gunicorn/asgi/protocol.py b/gunicorn/asgi/protocol.py index 365ed59a..5c66ece8 100644 --- a/gunicorn/asgi/protocol.py +++ b/gunicorn/asgi/protocol.py @@ -254,10 +254,15 @@ class BodyReceiver: if self._chunks: return self._pop_chunk() - # Complete OR timeout - mark body finished to prevent infinite loops - # Apps should not loop forever waiting for body that won't arrive - self._body_finished = True - return {"type": "http.request", "body": b"", "more_body": False} + if self._complete: + self._body_finished = True + return {"type": "http.request", "body": b"", "more_body": False} + + # Wait returned without data and the message was not framed complete: + # treat as a client disconnect rather than synthesizing end-of-body + # (which would desync the next pipelined request). + self._closed = True + return {"type": "http.disconnect"} async def _wait_for_data(self): """Wait for body data to arrive via callback.""" @@ -268,11 +273,21 @@ class BodyReceiver: loop = asyncio.get_event_loop() self._waiter = loop.create_future() + # Bound the wait by the configured worker timeout (default 30s). + # The protocol-level timeout drives transport disconnect handling; + # this only needs to escape an idle wait if data never arrives. + cfg = getattr(self.protocol, 'cfg', None) + timeout = getattr(cfg, 'timeout', None) if cfg is not None else None + if not timeout or timeout <= 0: + timeout = 30.0 + try: - # Wait with timeout for data or completion - await asyncio.wait_for(self._waiter, timeout=30.0) + await asyncio.wait_for(self._waiter, timeout=timeout) except asyncio.TimeoutError: - pass + # No data arrived in time: mark the body receiver as disconnected + # so receive() yields http.disconnect rather than a fake terminal + # http.request with more_body=False. + self._closed = True finally: self._waiter = None @@ -436,6 +451,10 @@ class ASGIProtocol(asyncio.Protocol): for flag in self._FAST_PARSER_INCOMPATIBLE_FLAGS: if getattr(self.cfg, flag, False): incompatible.append(flag) + # PROXY protocol framing is implemented only in PythonProtocol; the C parser + # has no proxy_protocol kwarg and would silently drop the framing. + if getattr(self.cfg, 'proxy_protocol', 'off') != 'off': + incompatible.append('proxy_protocol') if parser_setting == 'python': parser_class = PythonProtocol @@ -467,8 +486,9 @@ class ASGIProtocol(asyncio.Protocol): if limit_request_line == 0 and parser_class != PythonProtocol: limit_request_line = 1024 * 1024 # 1MB for C parser - # Create parser with callbacks and limit parameters (both parsers support them) - self._callback_parser = parser_class( + # Create parser with callbacks and limit parameters (both parsers support them). + # Only the Python parser implements PROXY protocol framing; pass the option there. + parser_kwargs = dict( on_headers_complete=self._on_headers_complete, on_body=self._on_body, on_message_complete=self._on_message_complete, @@ -478,6 +498,9 @@ class ASGIProtocol(asyncio.Protocol): permit_unconventional_http_method=self.cfg.permit_unconventional_http_method, permit_unconventional_http_version=self.cfg.permit_unconventional_http_version, ) + if parser_class is PythonProtocol: + parser_kwargs['proxy_protocol'] = getattr(self.cfg, 'proxy_protocol', 'off') + self._callback_parser = parser_class(**parser_kwargs) def _on_headers_complete(self): """Callback: request headers are complete.""" @@ -729,14 +752,17 @@ class ASGIProtocol(asyncio.Protocol): request = self._current_request + # If PROXY protocol provided a real client address, use it. + effective_peer = self._effective_peername(peername) + # Check for WebSocket upgrade if self._is_websocket_upgrade(request): - await self._handle_websocket(request, sockname, peername) + await self._handle_websocket(request, sockname, effective_peer) break # WebSocket takes over the connection # Handle HTTP request keepalive = await self._handle_http_request( - request, sockname, peername + request, sockname, effective_peer ) # Increment worker request count @@ -755,6 +781,18 @@ class ASGIProtocol(asyncio.Protocol): if not self.cfg.keepalive: break + # Refuse keepalive if the previous request body was not fully + # framed: residual bytes left in the transport stream would be + # parsed as the start of the next request (smuggling). + receiver = self._body_receiver + message_complete = ( + receiver is None + or receiver._complete + or receiver._closed + ) + if not message_complete: + break + # Resume reading if paused during body consumption self._resume_reading() @@ -918,15 +956,15 @@ class ASGIProtocol(asyncio.Protocol): has_transfer_encoding = True use_chunked = True # Framework already set chunked encoding - # Use chunked encoding for HTTP/1.1 streaming responses without Content-Length - # Skip for 1xx informational responses (RFC 9110) - # Skip if Transfer-Encoding already set by framework - is_informational = 100 <= response_status < 200 + # Use chunked encoding for HTTP/1.1 streaming responses without Content-Length. + # Skip when the response cannot carry a body (HEAD/1xx/204/304) or when + # Transfer-Encoding was already set by the framework. + is_no_body = self._response_omits_body(request.method, response_status) needs_chunked = ( not has_content_length and not has_transfer_encoding and request.version >= (1, 1) - and not is_informational + and not is_no_body ) if needs_chunked: use_chunked = True @@ -1201,6 +1239,36 @@ class ASGIProtocol(asyncio.Protocol): # Buffer headers for batching with first body chunk self._response_buffer = b"".join(parts) + def _effective_peername(self, peername): + """Return the client address advertised via PROXY protocol if any. + + Falls back to the transport peername when PROXY protocol is disabled, + the framing was absent, or the parser is the C variant (which currently + does not surface PROXY metadata). + """ + parser = self._callback_parser + info = getattr(parser, 'proxy_protocol_info', None) if parser else None + if not info: + return peername + client_addr = info.get('client_addr') + client_port = info.get('client_port') + if client_addr is None or client_port is None: + return peername + return (client_addr, client_port) + + @staticmethod + def _response_omits_body(method, status): + """Return True when the response MUST NOT have a body (RFC 9110). + + Applies to HEAD requests and to status codes that semantically carry no body: + 1xx informational, 204 No Content, 304 Not Modified. + """ + return ( + method == "HEAD" + or status in (204, 304) + or 100 <= status < 200 + ) + def _send_body(self, body, chunked=False): """Send response body chunk. diff --git a/gunicorn/http/message.py b/gunicorn/http/message.py index 1cc94fb7..6409c43e 100644 --- a/gunicorn/http/message.py +++ b/gunicorn/http/message.py @@ -206,26 +206,94 @@ class Message: def parse(self, unreader): raise NotImplementedError() - def parse_headers(self, data, from_trailer=False): + def _peer_trusted_for_forwarded(self): + """Return the (secure_scheme_headers, forwarder_headers) the peer is allowed to set. + + When the peer's address is not in ``forwarded_allow_ips`` (or networks), + configured forwarding/secure-scheme policy must be ignored to prevent + spoofing. Returns ``({}, [])`` when the peer is untrusted. + """ cfg = self.cfg + if (not isinstance(self.peer_addr, tuple) + or _ip_in_allow_list(self.peer_addr[0], cfg.forwarded_allow_ips, + cfg.forwarded_allow_networks())): + return cfg.secure_scheme_headers, cfg.forwarder_headers + return {}, [] + + def _apply_header_policy(self, name, value, scheme_state, + secure_scheme_headers, forwarder_headers, + from_trailer=False): + """Apply per-header policy shared between Python and fast parsers. + + Mutates ``self._expected_100_continue`` and ``self.scheme`` as needed. + ``scheme_state`` is a single-element list used as a mutable sentinel + so the caller can detect repeated scheme headers. + + Returns the (name, value) pair to retain, or ``None`` to drop the + header (per ``header_map='drop'``). Raises the same exceptions the + Python path raises so behavior is identical regardless of parser. + """ + if not from_trailer and name == "EXPECT": + # https://datatracker.ietf.org/doc/html/rfc9110#section-10.1.1 + # "The Expect field value is case-insensitive." + if value.lower() == "100-continue": + if self.version < (1, 1): + # https://datatracker.ietf.org/doc/html/rfc9110#section-10.1.1-12 + # "A server that receives a 100-continue expectation + # in an HTTP/1.0 request MUST ignore that expectation." + pass + else: + self._expected_100_continue = True + # N.B. understood but ignored expect header does not return 417 + else: + raise ExpectationFailed(value) + + if name in secure_scheme_headers: + secure = value == secure_scheme_headers[name] + scheme = "https" if secure else "http" + if scheme_state[0]: + if scheme != self.scheme: + raise InvalidSchemeHeaders() + else: + scheme_state[0] = True + self.scheme = scheme + + # ambiguous mapping allows fooling downstream, e.g. merging non-identical headers: + # X-Forwarded-For: 2001:db8::ha:cc:ed + # X_Forwarded_For: 127.0.0.1,::1 + # HTTP_X_FORWARDED_FOR = 2001:db8::ha:cc:ed,127.0.0.1,::1 + # Only modify after fixing *ALL* header transformations; network to wsgi env + if "_" in name: + if name in forwarder_headers or "*" in forwarder_headers: + # This forwarder may override our environment + pass + elif self.cfg.header_map == "dangerous": + # as if we did not know we cannot safely map this + pass + elif self.cfg.header_map == "drop": + # almost as if it never had been there + # but still counts against resource limits + return None + else: + # fail-safe fallthrough: refuse + raise InvalidHeaderName(name) + + return (name, value) + + def parse_headers(self, data, from_trailer=False): headers = [] # Split lines on \r\n lines = [bytes_to_str(line) for line in data.split(b"\r\n")] # handle scheme headers - scheme_header = False - secure_scheme_headers = {} - forwarder_headers = [] + scheme_state = [False] if from_trailer: # nonsense. either a request is https from the beginning # .. or we are just behind a proxy who does not remove conflicting trailers - pass - elif (not isinstance(self.peer_addr, tuple) - or _ip_in_allow_list(self.peer_addr[0], cfg.forwarded_allow_ips, - cfg.forwarded_allow_networks())): - secure_scheme_headers = cfg.secure_scheme_headers - forwarder_headers = cfg.forwarder_headers + secure_scheme_headers, forwarder_headers = {}, [] + else: + secure_scheme_headers, forwarder_headers = self._peer_trusted_for_forwarded() # Parse headers into key/value pairs paying attention # to continuation lines. @@ -275,52 +343,14 @@ class Message: if header_length > self.limit_request_field_size > 0: raise LimitRequestHeaders("limit request headers fields size") - if not from_trailer and name == "EXPECT": - # https://datatracker.ietf.org/doc/html/rfc9110#section-10.1.1 - # "The Expect field value is case-insensitive." - if value.lower() == "100-continue": - if self.version < (1, 1): - # https://datatracker.ietf.org/doc/html/rfc9110#section-10.1.1-12 - # "A server that receives a 100-continue expectation - # in an HTTP/1.0 request MUST ignore that expectation." - pass - else: - self._expected_100_continue = True - # N.B. understood but ignored expect header does not return 417 - else: - raise ExpectationFailed(value) - - if name in secure_scheme_headers: - secure = value == secure_scheme_headers[name] - scheme = "https" if secure else "http" - if scheme_header: - if scheme != self.scheme: - raise InvalidSchemeHeaders() - else: - scheme_header = True - self.scheme = scheme - - # ambiguous mapping allows fooling downstream, e.g. merging non-identical headers: - # X-Forwarded-For: 2001:db8::ha:cc:ed - # X_Forwarded_For: 127.0.0.1,::1 - # HTTP_X_FORWARDED_FOR = 2001:db8::ha:cc:ed,127.0.0.1,::1 - # Only modify after fixing *ALL* header transformations; network to wsgi env - if "_" in name: - if name in forwarder_headers or "*" in forwarder_headers: - # This forwarder may override our environment - pass - elif self.cfg.header_map == "dangerous": - # as if we did not know we cannot safely map this - pass - elif self.cfg.header_map == "drop": - # almost as if it never had been there - # but still counts against resource limits - continue - else: - # fail-safe fallthrough: refuse - raise InvalidHeaderName(name) - - headers.append((name, value)) + kept = self._apply_header_policy( + name, value, scheme_state, + secure_scheme_headers, forwarder_headers, + from_trailer=from_trailer, + ) + if kept is None: + continue + headers.append(kept) return headers @@ -516,25 +546,23 @@ class Request(Message): # Headers - convert bytes to strings with uppercase names # gunicorn_h1c returns headers as (bytes, bytes) tuples - # Header name/value validation done by C parser + # Header name/value validation done by C parser; policy (Expect, + # secure_scheme_headers, forwarder trust gate, header_map) is enforced + # below so the fast path mirrors parse_headers(). self.headers = [] + scheme_state = [False] + secure_scheme_headers, forwarder_headers = self._peer_trusted_for_forwarded() for name_bytes, value_bytes in result['headers']: name = bytes_to_str(name_bytes).upper() value = bytes_to_str(value_bytes) - # Handle underscore in header names (policy decision, not validation) - if "_" in name: - forwarder_headers = self.cfg.forwarder_headers - if name in forwarder_headers or "*" in forwarder_headers: - pass - elif self.cfg.header_map == "dangerous": - pass - elif self.cfg.header_map == "drop": - continue - else: - raise InvalidHeaderName(name) - - self.headers.append((name, value)) + kept = self._apply_header_policy( + name, value, scheme_state, + secure_scheme_headers, forwarder_headers, + ) + if kept is None: + continue + self.headers.append(kept) # Return remaining data after headers consumed = result['consumed'] diff --git a/gunicorn/http/parser.py b/gunicorn/http/parser.py index 260beafa..0d774e93 100644 --- a/gunicorn/http/parser.py +++ b/gunicorn/http/parser.py @@ -2,12 +2,20 @@ # This file is part of gunicorn released under the MIT license. # See the NOTICE for more information. +import socket import ssl +import time from gunicorn.http.message import Request from gunicorn.http.unreader import SocketUnreader, IterUnreader +# Cap on bytes drained from an unconsumed request body before a keepalive +# reset. Defends against a slow-but-steady client that stays under a per-read +# deadline yet streams indefinitely. +_DRAIN_MAX_BYTES = 64 * 1024 + + class Parser: mesg_class = None @@ -27,21 +35,61 @@ class Parser: def __iter__(self): return self - def finish_body(self): + def finish_body(self, deadline=None, max_bytes=_DRAIN_MAX_BYTES): """Discard any unread body of the current message. - This should be called before returning a keepalive connection to - the poller to ensure the socket doesn't appear readable due to - leftover body bytes. + Called before returning a keepalive connection to the poller so the + socket does not appear readable due to leftover body bytes. + + ``deadline`` is an absolute ``time.monotonic()`` value; when set the + socket read timeout is bounded by the remaining time before each read. + ``max_bytes`` caps the total drained bytes to defend against a slow + client that keeps trickling under the deadline. + + Returns ``True`` when the body was fully drained, ``False`` when the + drain was abandoned (deadline, byte cap, or socket timeout). Callers + that observe ``False`` MUST close the connection rather than serve + another request on it. """ - if self.mesg: - try: - data = self.mesg.body.read(1024) - while data: + if not self.mesg: + return True + + sock = getattr(self.unreader, "sock", None) + # gettimeout/settimeout only matter when bounding a real socket; a + # mock or non-socket source skips the timeout plumbing. + if sock is not None and hasattr(sock, "gettimeout") and hasattr(sock, "settimeout"): + timeoutable_sock = sock + prior_timeout = sock.gettimeout() + else: + timeoutable_sock = None + prior_timeout = None + + drained = 0 + try: + while True: + if deadline is not None and timeoutable_sock is not None: + remaining = deadline - time.monotonic() + if remaining <= 0: + return False + timeoutable_sock.settimeout(remaining) + try: data = self.mesg.body.read(1024) - except ssl.SSLWantReadError: - # SSL socket has no more application data available - pass + except (socket.timeout, TimeoutError): + return False + except ssl.SSLWantReadError: + # SSL socket has no more application data available + return True + if not data: + return True + drained += len(data) + if drained >= max_bytes: + return False + finally: + if timeoutable_sock is not None: + try: + timeoutable_sock.settimeout(prior_timeout) + except OSError: + pass def __next__(self): # Stop if HTTP dictates a stop. diff --git a/gunicorn/workers/gthread.py b/gunicorn/workers/gthread.py index cd5f86f5..bf7711fc 100644 --- a/gunicorn/workers/gthread.py +++ b/gunicorn/workers/gthread.py @@ -478,9 +478,14 @@ class ThreadWorker(base.Worker): # Handle the request keepalive = self.handle_request(req, conn) if keepalive: - # Discard any unread request body before keepalive - # to prevent socket appearing readable due to leftover bytes - conn.parser.finish_body() + # Discard any unread request body before keepalive to prevent + # the socket from appearing readable due to leftover bytes. + # Bound the drain by the worker data timeout: a stalled client + # must not keep this thread blocked. + drain_deadline = time.monotonic() + DEFAULT_WORKER_DATA_TIMEOUT + if not conn.parser.finish_body(deadline=drain_deadline): + # Abandon keepalive when the body could not be fully drained. + return False return True except http.errors.NoMoreData as e: self.log.debug("Ignored premature client disconnection. %s", e) diff --git a/tests/test_asgi_disconnect.py b/tests/test_asgi_disconnect.py index 7de45bb3..a0270051 100644 --- a/tests/test_asgi_disconnect.py +++ b/tests/test_asgi_disconnect.py @@ -204,3 +204,65 @@ class TestASGIDisconnectGracePeriod: from gunicorn.config import Config cfg = Config() assert cfg.asgi_disconnect_grace_period == 3 + + +class TestBodyReceiverIncompleteBody: + """Cover the receive() path when the request body never finishes framing.""" + + @pytest.fixture + def mock_worker(self): + worker = mock.Mock() + worker.nr_conns = 0 + worker.loop = asyncio.new_event_loop() + worker.cfg = mock.Mock() + worker.cfg.asgi_disconnect_grace_period = 3 + worker.cfg.timeout = 0.05 # tight bound for the test + worker.log = mock.Mock() + return worker + + @pytest.mark.asyncio + async def test_receive_yields_disconnect_on_timeout(self, mock_worker): + """When _wait_for_data times out and the body is not complete, the + receiver MUST yield http.disconnect rather than synthesize a terminal + http.request with more_body=False — that would desync the next + pipelined request.""" + from gunicorn.asgi.protocol import ASGIProtocol, BodyReceiver + + protocol = ASGIProtocol(mock_worker) + protocol.reader = mock.Mock() + + request = mock.Mock() + request.content_length = 100 + request.chunked = False + + receiver = BodyReceiver(request, protocol) + protocol._body_receiver = receiver + + msg = await receiver.receive() + assert msg == {"type": "http.disconnect"} + assert receiver._closed is True + + @pytest.mark.asyncio + async def test_receive_yields_terminal_request_when_complete(self, mock_worker): + """If the body is framed complete, the existing terminal http.request + with more_body=False must still be returned.""" + from gunicorn.asgi.protocol import ASGIProtocol, BodyReceiver + + protocol = ASGIProtocol(mock_worker) + protocol.reader = mock.Mock() + + request = mock.Mock() + request.content_length = 5 + request.chunked = False + + receiver = BodyReceiver(request, protocol) + protocol._body_receiver = receiver + + receiver.feed(b"hello") + receiver.set_complete() + + msg = await receiver.receive() + assert msg["type"] == "http.request" + assert msg["body"] == b"hello" + # more_body may be False since the body is complete + assert msg["more_body"] is False diff --git a/tests/test_asgi_streaming.py b/tests/test_asgi_streaming.py index d4bd7ed4..338eba10 100644 --- a/tests/test_asgi_streaming.py +++ b/tests/test_asgi_streaming.py @@ -344,6 +344,43 @@ class TestHTTPVersionForChunked: assert uses_http1_chunked is False +# ============================================================================ +# No-Body Response Tests (RFC 9110) +# ============================================================================ + +class TestResponseOmitsBody: + """Verify HEAD/1xx/204/304 are flagged as bodyless responses.""" + + def _omits(self, method, status): + from gunicorn.asgi.protocol import ASGIProtocol + return ASGIProtocol._response_omits_body(method, status) + + def test_head_omits_body(self): + assert self._omits("HEAD", 200) is True + assert self._omits("HEAD", 500) is True + + def test_204_omits_body(self): + assert self._omits("GET", 204) is True + assert self._omits("POST", 204) is True + + def test_304_omits_body(self): + assert self._omits("GET", 304) is True + + def test_informational_omits_body(self): + assert self._omits("GET", 100) is True + assert self._omits("GET", 103) is True + assert self._omits("GET", 199) is True + + def test_get_200_has_body(self): + assert self._omits("GET", 200) is False + + def test_post_200_has_body(self): + assert self._omits("POST", 200) is False + + def test_404_has_body(self): + assert self._omits("GET", 404) is False + + # ============================================================================ # Streaming Response Message Sequence Tests # ============================================================================ diff --git a/tests/test_asgi_valid_requests.py b/tests/test_asgi_valid_requests.py index 3ab33f97..a9147c73 100644 --- a/tests/test_asgi_valid_requests.py +++ b/tests/test_asgi_valid_requests.py @@ -45,10 +45,6 @@ def test_asgi_parser(fname): if getattr(cfg, flag, False): pytest.skip(f"Callback parser incompatible with {flag}") - # Skip proxy protocol tests - if getattr(cfg, 'proxy_protocol', 'off') != 'off': - pytest.skip("Callback parser does not support proxy_protocol") - req = treq_asgi.request(fname, expect) # Test with different sending strategies diff --git a/tests/test_asgi_worker.py b/tests/test_asgi_worker.py index bf534601..fbe9a6e6 100644 --- a/tests/test_asgi_worker.py +++ b/tests/test_asgi_worker.py @@ -582,6 +582,59 @@ class TestASGIProtocol: assert scope["root_path"] == "/api" assert scope["http_version"] == "1.1" + def test_effective_peername_no_proxy(self): + """Without PROXY framing the transport peername is returned as-is.""" + from gunicorn.asgi.protocol import ASGIProtocol + + worker = mock.Mock() + worker.cfg = Config() + worker.log = mock.Mock() + worker.asgi = mock.Mock() + protocol = ASGIProtocol(worker) + protocol._callback_parser = mock.Mock(proxy_protocol_info=None) + + peer = ("10.0.0.1", 12345) + assert protocol._effective_peername(peer) == peer + + def test_effective_peername_with_proxy(self): + """PROXY-supplied client address overrides the transport peername.""" + from gunicorn.asgi.protocol import ASGIProtocol + + worker = mock.Mock() + worker.cfg = Config() + worker.log = mock.Mock() + worker.asgi = mock.Mock() + protocol = ASGIProtocol(worker) + protocol._callback_parser = mock.Mock(proxy_protocol_info={ + 'proxy_protocol': 'TCP4', + 'client_addr': '203.0.113.5', + 'client_port': 56324, + 'proxy_addr': '10.0.0.2', + 'proxy_port': 443, + }) + + assert protocol._effective_peername(("10.0.0.1", 1)) == ("203.0.113.5", 56324) + + def test_effective_peername_unknown_proxy(self): + """UNKNOWN PROXY framing has no client info; fall back to transport peername.""" + from gunicorn.asgi.protocol import ASGIProtocol + + worker = mock.Mock() + worker.cfg = Config() + worker.log = mock.Mock() + worker.asgi = mock.Mock() + protocol = ASGIProtocol(worker) + protocol._callback_parser = mock.Mock(proxy_protocol_info={ + 'proxy_protocol': 'UNKNOWN', + 'client_addr': None, + 'client_port': None, + 'proxy_addr': None, + 'proxy_port': None, + }) + + peer = ("10.0.0.1", 12345) + assert protocol._effective_peername(peer) == peer + # ============================================================================ # Config Tests diff --git a/tests/test_header_policy_parity.py b/tests/test_header_policy_parity.py new file mode 100644 index 00000000..90905e0c --- /dev/null +++ b/tests/test_header_policy_parity.py @@ -0,0 +1,185 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +"""Parity tests for WSGI header policy across Python and fast parsers. + +These checks ensure that Expect, secure_scheme_headers, forwarder_headers, +and the forwarded_allow_ips trust gate are enforced identically regardless +of the parser implementation selected by ``http_parser``. +""" + +import sys + +import pytest + +from gunicorn.config import Config +from gunicorn.http.parser import RequestParser +from gunicorn.http.errors import ( + ExpectationFailed, + InvalidHeaderName, + InvalidSchemeHeaders, +) + + +def _parse(raw, cfg, peer_addr): + parser = RequestParser(cfg, iter([raw]), peer_addr) + return next(iter(parser)) + + +def _cfg(http_parser, **overrides): + cfg = Config() + cfg.set("http_parser", http_parser) + for k, v in overrides.items(): + cfg.set(k, v) + return cfg + + +@pytest.fixture(params=["python", "fast"]) +def parser_name(request): + if request.param == "fast": + if hasattr(sys, "pypy_version_info"): + pytest.skip("gunicorn_h1c not supported on PyPy") + gunicorn_h1c = pytest.importorskip("gunicorn_h1c") + if not hasattr(gunicorn_h1c.H1CProtocol, "asgi_headers"): + pytest.skip("gunicorn_h1c >= 0.6.2 required") + return request.param + + +class TestExpectPolicy: + def test_expect_100_continue_sets_flag(self, parser_name): + cfg = _cfg(parser_name) + raw = ( + b"POST / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Content-Length: 0\r\n" + b"Expect: 100-continue\r\n" + b"\r\n" + ) + req = _parse(raw, cfg, ("127.0.0.1", 1234)) + assert req._expected_100_continue is True + + def test_expect_unknown_value_rejected(self, parser_name): + cfg = _cfg(parser_name) + raw = ( + b"POST / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Content-Length: 0\r\n" + b"Expect: bogus-extension\r\n" + b"\r\n" + ) + with pytest.raises(ExpectationFailed): + _parse(raw, cfg, ("127.0.0.1", 1234)) + + def test_expect_ignored_in_http10(self, parser_name): + cfg = _cfg(parser_name) + raw = ( + b"POST / HTTP/1.0\r\n" + b"Host: example.com\r\n" + b"Content-Length: 0\r\n" + b"Expect: 100-continue\r\n" + b"\r\n" + ) + req = _parse(raw, cfg, ("127.0.0.1", 1234)) + assert req._expected_100_continue is False + + +class TestSecureSchemeHeaders: + def test_trusted_peer_promotes_https(self, parser_name): + cfg = _cfg( + parser_name, + forwarded_allow_ips="127.0.0.1", + secure_scheme_headers={"X-FORWARDED-PROTO": "https"}, + ) + raw = ( + b"GET / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"X-Forwarded-Proto: https\r\n" + b"\r\n" + ) + req = _parse(raw, cfg, ("127.0.0.1", 1234)) + assert req.scheme == "https" + + def test_untrusted_peer_keeps_http(self, parser_name): + cfg = _cfg( + parser_name, + forwarded_allow_ips="127.0.0.1", + secure_scheme_headers={"X-FORWARDED-PROTO": "https"}, + ) + raw = ( + b"GET / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"X-Forwarded-Proto: https\r\n" + b"\r\n" + ) + req = _parse(raw, cfg, ("203.0.113.5", 1234)) + assert req.scheme == "http" + + def test_conflicting_scheme_headers_rejected(self, parser_name): + cfg = _cfg( + parser_name, + forwarded_allow_ips="127.0.0.1", + secure_scheme_headers={ + "X-FORWARDED-PROTO": "https", + "X-FORWARDED-SSL": "on", + }, + ) + raw = ( + b"GET / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"X-Forwarded-Proto: https\r\n" + b"X-Forwarded-Ssl: off\r\n" + b"\r\n" + ) + with pytest.raises(InvalidSchemeHeaders): + _parse(raw, cfg, ("127.0.0.1", 1234)) + + +class TestForwarderTrustGate: + def test_untrusted_peer_underscore_header_rejected(self, parser_name): + cfg = _cfg( + parser_name, + forwarded_allow_ips="127.0.0.1", + forwarder_headers="SCRIPT_NAME", + header_map="refuse", + ) + raw = ( + b"GET / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Script_Name: /evil\r\n" + b"\r\n" + ) + with pytest.raises(InvalidHeaderName): + _parse(raw, cfg, ("203.0.113.5", 1234)) + + def test_trusted_peer_underscore_header_accepted(self, parser_name): + cfg = _cfg( + parser_name, + forwarded_allow_ips="127.0.0.1", + forwarder_headers="SCRIPT_NAME", + ) + raw = ( + b"GET / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Script_Name: /api\r\n" + b"\r\n" + ) + req = _parse(raw, cfg, ("127.0.0.1", 1234)) + names = {n for n, _ in req.headers} + assert "SCRIPT_NAME" in names + + def test_header_map_drop_silences_underscore(self, parser_name): + cfg = _cfg( + parser_name, + forwarded_allow_ips="127.0.0.1", + header_map="drop", + ) + raw = ( + b"GET / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Stray_Name: x\r\n" + b"\r\n" + ) + req = _parse(raw, cfg, ("203.0.113.5", 1234)) + names = {n for n, _ in req.headers} + assert "STRAY_NAME" not in names diff --git a/tests/test_http.py b/tests/test_http.py index ef9b5ea5..0531916d 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -255,6 +255,62 @@ def test_invalid_http_version_error(): assert str(InvalidHTTPVersion((2, 1))) == 'Invalid HTTP Version: (2, 1)' +def _build_request_parser(payload): + """Construct a RequestParser that drains the given bytes.""" + from gunicorn.config import Config + from gunicorn.http.parser import RequestParser + + cfg = Config() + parser = RequestParser(cfg, iter([payload]), None) + next(iter(parser)) + return parser + + +def test_finish_body_drains_remainder(): + payload = ( + b"POST / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Content-Length: 5\r\n" + b"\r\n" + b"hello" + ) + parser = _build_request_parser(payload) + assert parser.finish_body() is True + + +def test_finish_body_returns_false_when_byte_cap_exceeded(): + body = b"x" * (4096) + payload = ( + b"POST / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Content-Length: %d\r\n\r\n%s" % (len(body), body) + ) + parser = _build_request_parser(payload) + assert parser.finish_body(max_bytes=512) is False + + +def test_finish_body_returns_false_on_expired_deadline(): + payload = ( + b"POST / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Content-Length: 100\r\n" + b"\r\n" + b"only-partial" + ) + parser = _build_request_parser(payload) + # Force an already-elapsed deadline; the drain must abandon immediately. + import time as _time + expired = _time.monotonic() - 1.0 + # IterUnreader has no socket; deadline path is exercised only when sock + # is present. Stub a sock with gettimeout/settimeout to drive the branch. + from unittest import mock + sock = mock.Mock() + sock.gettimeout.return_value = None + parser.unreader.sock = sock + assert parser.finish_body(deadline=expired) is False + sock.settimeout.assert_called_with(None) + + def test_file_wrapper_iterable(): """FileWrapper should support the iterator protocol per PEP 3333.""" filelike = io.BytesIO(b"hello world")