From ffcebce4a772b65939a7ce0a2b11f37ddb19226f Mon Sep 17 00:00:00 2001 From: Benoit Chesneau Date: Wed, 25 Mar 2026 16:20:42 +0100 Subject: [PATCH] Fix ASGI callback parser header validation Add security checks to PythonProtocol per RFC 9110/9112: - Reject duplicate Content-Length headers - Reject CL + TE combinations - Reject chunked in HTTP/1.0 - Reject stacked chunked encoding - Validate Transfer-Encoding values - Strict chunk size validation Add PROXY protocol v1/v2 support to callback parser. Add treq-based test infrastructure for ASGI parser. --- gunicorn/asgi/parser.py | 347 +++++++++++++++++++++- tests/test_asgi_invalid_requests.py | 68 +++++ tests/test_asgi_parser_validation.py | 418 +++++++++++++++++++++++++++ tests/test_asgi_valid_requests.py | 53 ++++ tests/treq_asgi.py | 265 +++++++++++++++++ 5 files changed, 1138 insertions(+), 13 deletions(-) create mode 100644 tests/test_asgi_invalid_requests.py create mode 100644 tests/test_asgi_parser_validation.py create mode 100644 tests/test_asgi_valid_requests.py create mode 100644 tests/treq_asgi.py diff --git a/gunicorn/asgi/parser.py b/gunicorn/asgi/parser.py index e7432a77..843cc489 100644 --- a/gunicorn/asgi/parser.py +++ b/gunicorn/asgi/parser.py @@ -9,11 +9,47 @@ Provides callback-based parsing using either the fast C parser (gunicorn_h1c) or the pure Python PythonProtocol fallback. """ +import struct +from enum import IntEnum + class ParseError(Exception): """Base error raised during HTTP parsing.""" +class InvalidProxyLine(ParseError): + """Invalid PROXY protocol v1 line.""" + + +class InvalidProxyHeader(ParseError): + """Invalid PROXY protocol v2 header.""" + + +# PROXY protocol v2 constants +PP_V2_SIGNATURE = b"\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A" + + +class PPCommand(IntEnum): + """PROXY protocol v2 commands.""" + LOCAL = 0x0 + PROXY = 0x1 + + +class PPFamily(IntEnum): + """PROXY protocol v2 address families.""" + UNSPEC = 0x0 + INET = 0x1 # IPv4 + INET6 = 0x2 # IPv6 + UNIX = 0x3 + + +class PPProtocol(IntEnum): + """PROXY protocol v2 transport protocols.""" + UNSPEC = 0x0 + STREAM = 0x1 # TCP + DGRAM = 0x2 # UDP + + class LimitRequestLine(ParseError): """Request line exceeds configured limit.""" @@ -22,6 +58,10 @@ class LimitRequestHeaders(ParseError): """Too many headers or header field too large.""" +class InvalidRequestLine(ParseError): + """Invalid request line.""" + + class InvalidRequestMethod(ParseError): """Invalid HTTP method.""" @@ -38,6 +78,14 @@ class InvalidHeader(ParseError): """Invalid header value.""" +class UnsupportedTransferCoding(ParseError): + """Unsupported Transfer-Encoding value.""" + + +class InvalidChunkSize(ParseError): + """Invalid chunk size in chunked transfer encoding.""" + + class PythonProtocol: """Callback-based HTTP/1.1 parser (pure Python fallback). @@ -64,6 +112,7 @@ class PythonProtocol: '_limit_request_line', '_limit_request_fields', '_limit_request_field_size', '_permit_unconventional_http_method', '_permit_unconventional_http_version', '_header_count', + '_proxy_protocol', '_proxy_protocol_info', '_proxy_protocol_done', ) def __init__( @@ -79,6 +128,7 @@ class PythonProtocol: limit_request_field_size=8190, permit_unconventional_http_method=False, permit_unconventional_http_version=False, + proxy_protocol='off', ): self._on_message_begin = on_message_begin self._on_url = on_url @@ -95,8 +145,13 @@ class PythonProtocol: self._permit_unconventional_http_version = permit_unconventional_http_version self._header_count = 0 - # Parser state: request_line, headers, body, chunked_size, chunked_data, complete - self._state = 'request_line' + # Proxy protocol + self._proxy_protocol = proxy_protocol + self._proxy_protocol_info = None + self._proxy_protocol_done = (proxy_protocol == 'off') + + # Parser state: proxy_protocol, request_line, headers, body, chunked_size, chunked_data, complete + self._state = 'proxy_protocol' if proxy_protocol != 'off' else 'request_line' self._buffer = bytearray() self._headers_list = [] @@ -131,7 +186,10 @@ class PythonProtocol: self._buffer.extend(data) while self._buffer: - if self._state == 'request_line': + if self._state == 'proxy_protocol': + if not self._parse_proxy_protocol(): + break + elif self._state == 'request_line': if not self._parse_request_line(): break elif self._state == 'headers': @@ -146,6 +204,11 @@ class PythonProtocol: else: break + @property + def proxy_protocol_info(self): + """Return proxy protocol info if parsed.""" + return self._proxy_protocol_info + def reset(self): """Reset for next request (keepalive).""" self._state = 'request_line' @@ -166,6 +229,190 @@ class PythonProtocol: self._chunk_remaining = 0 self._header_count = 0 + def _parse_proxy_protocol(self): + """Parse PROXY protocol header if enabled. + + Returns True if parsing is complete (or not applicable), + False if more data is needed. + """ + # Need at least 12 bytes to detect v2 signature or check for v1 prefix + if len(self._buffer) < 12: + return False + + mode = self._proxy_protocol + + # Check for v2 signature first + if mode in ('v2', 'auto') and self._buffer[:12] == PP_V2_SIGNATURE: + return self._parse_proxy_protocol_v2() + + # Check for v1 prefix + if mode in ('v1', 'auto') and self._buffer[:6] == b'PROXY ': + return self._parse_proxy_protocol_v1() + + # Not proxy protocol - continue with normal parsing + self._proxy_protocol_done = True + self._state = 'request_line' + return True + + def _parse_proxy_protocol_v1(self): + """Parse PROXY protocol v1 (text format). + + Format: PROXY \r\n + """ + # Find end of line + idx = self._buffer.find(b'\r\n') + if idx == -1: + # Need more data - v1 header can be up to 107 bytes + if len(self._buffer) > 107: + raise InvalidProxyLine("PROXY v1 header too long") + return False + + line = bytes(self._buffer[:idx]).decode('latin-1') + del self._buffer[:idx + 2] + + # Parse the line + parts = line.split(' ') + if len(parts) < 2: + raise InvalidProxyLine("Invalid PROXY v1 line") + + proto = parts[1].upper() + + if proto == 'UNKNOWN': + # Unknown protocol - no address info + self._proxy_protocol_info = { + 'proxy_protocol': 'UNKNOWN', + 'client_addr': None, + 'client_port': None, + 'proxy_addr': None, + 'proxy_port': None, + } + elif proto in ('TCP4', 'TCP6'): + if len(parts) != 6: + raise InvalidProxyLine("Invalid PROXY v1 line for %s" % proto) + + try: + s_addr = parts[2] + d_addr = parts[3] + s_port = int(parts[4]) + d_port = int(parts[5]) + except ValueError as e: + raise InvalidProxyLine("Invalid PROXY v1 port: %s" % e) + + if not (0 <= s_port <= 65535 and 0 <= d_port <= 65535): + raise InvalidProxyLine("Invalid PROXY v1 port range") + + self._proxy_protocol_info = { + 'proxy_protocol': proto, + 'client_addr': s_addr, + 'client_port': s_port, + 'proxy_addr': d_addr, + 'proxy_port': d_port, + } + else: + raise InvalidProxyLine("Unknown PROXY v1 protocol: %s" % proto) + + self._proxy_protocol_done = True + self._state = 'request_line' + return True + + def _parse_proxy_protocol_v2(self): + """Parse PROXY protocol v2 (binary format).""" + # Need at least 16 bytes for header + if len(self._buffer) < 16: + return False + + # Parse header + ver_cmd = self._buffer[12] + fam_prot = self._buffer[13] + length = struct.unpack('>H', bytes(self._buffer[14:16]))[0] + + # Check version + version = (ver_cmd & 0xF0) >> 4 + if version != 2: + raise InvalidProxyHeader("Unsupported PROXY v2 version: %d" % version) + + # Check command + command = ver_cmd & 0x0F + if command not in (PPCommand.LOCAL, PPCommand.PROXY): + raise InvalidProxyHeader("Unsupported PROXY v2 command: %d" % command) + + # Check if we have the complete header + total_size = 16 + length + if len(self._buffer) < total_size: + return False + + # Extract address data + addr_data = bytes(self._buffer[16:total_size]) + del self._buffer[:total_size] + + # Handle LOCAL command + if command == PPCommand.LOCAL: + self._proxy_protocol_info = { + 'proxy_protocol': 'LOCAL', + 'client_addr': None, + 'client_port': None, + 'proxy_addr': None, + 'proxy_port': None, + } + self._proxy_protocol_done = True + self._state = 'request_line' + return True + + # Parse address family and protocol + family = (fam_prot & 0xF0) >> 4 + protocol = fam_prot & 0x0F + + if family == PPFamily.INET: + # IPv4 + if len(addr_data) < 12: + raise InvalidProxyHeader("Invalid PROXY v2 IPv4 address data") + s_addr = '.'.join(str(b) for b in addr_data[:4]) + d_addr = '.'.join(str(b) for b in addr_data[4:8]) + s_port = struct.unpack('>H', addr_data[8:10])[0] + d_port = struct.unpack('>H', addr_data[10:12])[0] + proto = 'TCP4' if protocol == PPProtocol.STREAM else 'UDP4' + + elif family == PPFamily.INET6: + # IPv6 + if len(addr_data) < 36: + raise InvalidProxyHeader("Invalid PROXY v2 IPv6 address data") + # Format IPv6 addresses + s_words = struct.unpack('>8H', addr_data[:16]) + d_words = struct.unpack('>8H', addr_data[16:32]) + s_addr = ':'.join('%x' % w for w in s_words) + d_addr = ':'.join('%x' % w for w in d_words) + s_port = struct.unpack('>H', addr_data[32:34])[0] + d_port = struct.unpack('>H', addr_data[34:36])[0] + proto = 'TCP6' if protocol == PPProtocol.STREAM else 'UDP6' + + elif family == PPFamily.UNSPEC: + # Unspecified address family + self._proxy_protocol_info = { + 'proxy_protocol': 'UNSPEC', + 'client_addr': None, + 'client_port': None, + 'proxy_addr': None, + 'proxy_port': None, + } + self._proxy_protocol_done = True + self._state = 'request_line' + return True + + else: + raise InvalidProxyHeader("Unsupported PROXY v2 address family: %d" % family) + + self._proxy_protocol_info = { + 'proxy_protocol': proto, + 'client_addr': s_addr, + 'client_port': s_port, + 'proxy_addr': d_addr, + 'proxy_port': d_port, + } + + self._proxy_protocol_done = True + self._state = 'request_line' + return True + def _parse_request_line(self): """Parse request line, return True if complete.""" idx = self._buffer.find(b'\r\n') @@ -182,7 +429,7 @@ class PythonProtocol: # Parse: METHOD PATH HTTP/x.y parts = line.split(b' ', 2) if len(parts) != 3: - raise ParseError("Invalid request line") + raise InvalidRequestLine("Invalid request line") self.method = parts[0] self.path = parts[1] @@ -234,8 +481,8 @@ class PythonProtocol: self._finalize_headers() return True - # Check header field size limit - if self._limit_request_field_size > 0 and len(line) > self._limit_request_field_size: + # Check header field size limit (include CRLF in size to match WSGI parser) + if self._limit_request_field_size > 0 and len(line) + 2 > self._limit_request_field_size: raise LimitRequestHeaders("Request header field is too large") # Check header count limit @@ -264,16 +511,59 @@ class PythonProtocol: self._on_header(name_lower, value) def _finalize_headers(self): - """Called when all headers received.""" + """Called when all headers received. + + Validates headers for request smuggling vulnerabilities: + - Rejects duplicate Content-Length headers + - Rejects requests with both Content-Length and Transfer-Encoding + - Rejects chunked Transfer-Encoding in HTTP/1.0 + - Rejects stacked chunked encoding + - Validates Transfer-Encoding values + """ self.headers = self._headers_list - # Extract content-length and chunked + # Extract and validate content-length and transfer-encoding + content_length = None + chunked = False + for name, value in self.headers: if name == b'content-length': - self.content_length = int(value) - self._body_remaining = self.content_length + # Reject duplicate Content-Length headers (request smuggling vector) + if content_length is not None: + raise InvalidHeader("Duplicate Content-Length header") + try: + cl_value = int(value) + except ValueError: + raise InvalidHeader("Invalid Content-Length value") + if cl_value < 0: + raise InvalidHeader("Negative Content-Length") + content_length = cl_value + elif name == b'transfer-encoding': - self.is_chunked = b'chunked' in value.lower() + # Properly parse comma-separated Transfer-Encoding values + # per RFC 9112 Section 6.1 + vals = [v.strip() for v in value.split(b',')] + for val in vals: + val_lower = val.lower() + if val_lower == b'chunked': + # Reject stacked chunked encoding (request smuggling vector) + if chunked: + raise InvalidHeader("Stacked chunked encoding") + chunked = True + elif val_lower == b'identity': + # identity after chunked is invalid + if chunked: + raise InvalidHeader("Invalid Transfer-Encoding after chunked") + elif val_lower in (b'compress', b'deflate', b'gzip'): + # Compression after chunked is invalid + if chunked: + raise InvalidHeader("Invalid Transfer-Encoding after chunked") + # Mark connection for close (unsupported but valid) + self.should_keep_alive = False + else: + # Reject unknown transfer codings + raise UnsupportedTransferCoding(val.decode('latin-1')) + elif name == b'connection': val = value.lower() if b'close' in val: @@ -281,6 +571,25 @@ class PythonProtocol: elif b'keep-alive' in val: self.should_keep_alive = True + # Security checks for request smuggling prevention + if chunked: + # Reject chunked in HTTP/1.0 (RFC 9112 Section 6.1) + if self.http_version < (1, 1): + raise InvalidHeader("Chunked encoding not allowed in HTTP/1.0") + # Reject Content-Length with Transfer-Encoding (request smuggling vector) + if content_length is not None: + raise InvalidHeader("Content-Length with Transfer-Encoding") + self.is_chunked = True + self.content_length = None + self._body_remaining = -1 # Chunked mode + elif content_length is not None: + self.content_length = content_length + self._body_remaining = content_length + else: + # No body + self.content_length = None + self._body_remaining = 0 + # HTTP/1.0 defaults to close if self.http_version == (1, 0) and self.should_keep_alive: # Only keep-alive if explicitly requested @@ -348,12 +657,24 @@ class PythonProtocol: # Handle chunk extensions (e.g., "5;ext=value") semicolon = size_line.find(b';') if semicolon != -1: - size_line = size_line[:semicolon].strip() + size_line = size_line[:semicolon] + + # Strict validation: reject leading/trailing whitespace + # to prevent parser desync (request smuggling vector) + if size_line != size_line.strip(): + raise InvalidChunkSize("Whitespace in chunk size") + if not size_line: + raise InvalidChunkSize("Empty chunk size") + + # Validate hex characters only (0-9, a-f, A-F) + for c in size_line: + if c not in b'0123456789abcdefABCDEF': + raise InvalidChunkSize("Invalid character in chunk size") try: self._chunk_size = int(size_line, 16) except ValueError: - raise ParseError("Invalid chunk size") + raise InvalidChunkSize("Invalid chunk size") if self._chunk_size == 0: # Final chunk - skip trailers diff --git a/tests/test_asgi_invalid_requests.py b/tests/test_asgi_invalid_requests.py new file mode 100644 index 00000000..0b7521b8 --- /dev/null +++ b/tests/test_asgi_invalid_requests.py @@ -0,0 +1,68 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +"""Test invalid HTTP requests against ASGI callback parser. + +Runs the same .http test files as test_invalid_requests.py but using +the ASGI PythonProtocol callback parser. +""" + +import glob +import os + +import pytest + +from gunicorn.http.errors import ( + InvalidSchemeHeaders, + ObsoleteFolding, +) +import treq_asgi + +dirname = os.path.dirname(__file__) +reqdir = os.path.join(dirname, "requests", "invalid") +httpfiles = glob.glob(os.path.join(reqdir, "*.http")) + +# Tests that require features not supported by callback parser +SKIP_TESTS = { + # Tests requiring header_map config (underscore handling) + 'chunked_07.http', '040.http', + # Tests for features not in callback parser + '008.http', # Invalid request target validation + '012.http', # Invalid request target validation + '016.http', # URI bracket validation + '020.http', # Space before colon in header name + '022.http', # Request target validation +} + +# Config flags incompatible with callback parser +INCOMPATIBLE_FLAGS = ('permit_obsolete_folding', 'strip_header_spaces') + +# Exceptions only raised by Python WSGI parser +WSGI_ONLY_EXCEPTIONS = (ObsoleteFolding, InvalidSchemeHeaders) + + +@pytest.mark.parametrize("fname", httpfiles) +def test_asgi_parser(fname): + """Test invalid HTTP requests with ASGI callback parser.""" + basename = os.path.basename(fname) + if basename in SKIP_TESTS: + pytest.skip(f"Test {basename} not supported by callback parser") + + env = treq_asgi.load_py(os.path.splitext(fname)[0] + ".py") + expect = env["request"] + cfg = env["cfg"] + + # Skip tests that use incompatible config flags + for flag in INCOMPATIBLE_FLAGS: + if getattr(cfg, flag, False): + pytest.skip(f"Callback parser incompatible with {flag}") + + # Skip tests expecting WSGI-only exceptions + if expect in WSGI_ONLY_EXCEPTIONS or ( + isinstance(expect, type) and issubclass(expect, WSGI_ONLY_EXCEPTIONS) + ): + pytest.skip(f"Callback parser does not raise {expect.__name__}") + + req = treq_asgi.badrequest(fname) + req.check(cfg, expect) diff --git a/tests/test_asgi_parser_validation.py b/tests/test_asgi_parser_validation.py new file mode 100644 index 00000000..48d5d0b2 --- /dev/null +++ b/tests/test_asgi_parser_validation.py @@ -0,0 +1,418 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +"""Tests for ASGI callback parser header validation. + +These tests verify that PythonProtocol correctly validates HTTP headers +and body framing according to RFC 9110 and RFC 9112. +""" + +import pytest + +from gunicorn.asgi.parser import ( + PythonProtocol, + InvalidHeader, + InvalidChunkSize, + UnsupportedTransferCoding, + ParseError, +) + + +class TestContentLengthTransferEncodingConflict: + """Test rejection of requests with both CL and TE headers.""" + + def test_cl_te_conflict_rejected(self): + """Request with both Content-Length and Transfer-Encoding must be rejected.""" + parser = PythonProtocol() + + with pytest.raises(InvalidHeader, match="Content-Length with Transfer-Encoding"): + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Content-Length: 10\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + ) + + def test_te_cl_conflict_rejected(self): + """Order doesn't matter - TE before CL also rejected.""" + parser = PythonProtocol() + + with pytest.raises(InvalidHeader, match="Content-Length with Transfer-Encoding"): + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Transfer-Encoding: chunked\r\n" + b"Content-Length: 10\r\n" + b"\r\n" + ) + + def test_invalid_te_with_cl_rejected(self): + """Invalid T-E value combined with CL must be rejected.""" + parser = PythonProtocol() + + # This should fail due to invalid T-E value (identity;chunked=not) + with pytest.raises((InvalidHeader, UnsupportedTransferCoding)): + parser.feed( + b"POST /headers HTTP/1.0\r\n" + b"Connection: keep-alive\r\n" + b"Transfer-Encoding: identity;chunked=not\r\n" + b"Content-Length: -999\r\n" + b"\r\n" + ) + + +class TestDuplicateContentLength: + """Test rejection of duplicate Content-Length headers.""" + + def test_duplicate_cl_rejected(self): + """Duplicate Content-Length headers must be rejected.""" + parser = PythonProtocol() + + with pytest.raises(InvalidHeader, match="Duplicate Content-Length"): + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Content-Length: 10\r\n" + b"Content-Length: 10\r\n" + b"\r\n" + ) + + def test_different_cl_values_rejected(self): + """Different Content-Length values must be rejected.""" + parser = PythonProtocol() + + with pytest.raises(InvalidHeader, match="Duplicate Content-Length"): + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Content-Length: 10\r\n" + b"Content-Length: 20\r\n" + b"\r\n" + ) + + def test_negative_cl_rejected(self): + """Negative Content-Length must be rejected.""" + parser = PythonProtocol() + + with pytest.raises(InvalidHeader, match="Negative Content-Length"): + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Content-Length: -999\r\n" + b"\r\n" + ) + + def test_non_numeric_cl_rejected(self): + """Non-numeric Content-Length must be rejected.""" + parser = PythonProtocol() + + with pytest.raises(InvalidHeader, match="Invalid Content-Length"): + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Content-Length: abc\r\n" + b"\r\n" + ) + + def test_cl_with_spaces_rejected(self): + """Content-Length with embedded spaces must be rejected.""" + parser = PythonProtocol() + + with pytest.raises(InvalidHeader): + parser.feed( + b"GET /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Content-Length: 0 1\r\n" + b"\r\n" + ) + + +class TestChunkedInHTTP10: + """Test rejection of chunked encoding in HTTP/1.0.""" + + def test_chunked_http10_rejected(self): + """Chunked Transfer-Encoding in HTTP/1.0 must be rejected.""" + parser = PythonProtocol() + + with pytest.raises(InvalidHeader, match="HTTP/1.0"): + parser.feed( + b"POST /test HTTP/1.0\r\n" + b"Host: localhost\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + ) + + +class TestTransferEncodingValidation: + """Test proper validation of Transfer-Encoding header values.""" + + def test_stacked_chunked_rejected(self): + """Stacked chunked encoding must be rejected.""" + parser = PythonProtocol() + + with pytest.raises(InvalidHeader, match="Stacked chunked"): + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Transfer-Encoding: chunked, chunked\r\n" + b"\r\n" + ) + + def test_chunked_then_identity_rejected(self): + """Identity after chunked must be rejected.""" + parser = PythonProtocol() + + with pytest.raises(InvalidHeader, match="Invalid Transfer-Encoding after chunked"): + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Transfer-Encoding: chunked, identity\r\n" + b"\r\n" + ) + + def test_chunked_then_gzip_rejected(self): + """Compression after chunked must be rejected.""" + parser = PythonProtocol() + + with pytest.raises(InvalidHeader, match="Invalid Transfer-Encoding after chunked"): + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Transfer-Encoding: chunked, gzip\r\n" + b"\r\n" + ) + + def test_unknown_transfer_coding_rejected(self): + """Unknown transfer codings must be rejected.""" + parser = PythonProtocol() + + with pytest.raises(UnsupportedTransferCoding): + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Transfer-Encoding: bogus\r\n" + b"\r\n" + ) + + def test_te_with_parameters_rejected(self): + """Transfer-Encoding with parameters (like identity;chunked=not) must be rejected.""" + parser = PythonProtocol() + + # "identity;chunked=not" is not a valid transfer coding + with pytest.raises(UnsupportedTransferCoding): + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Transfer-Encoding: identity;chunked=not\r\n" + b"\r\n" + ) + + def test_te_with_tab_prefix_valid_chunked(self): + """Tab before 'chunked' is stripped, value should be valid.""" + parser = PythonProtocol() + + # Tab is stripped during header parsing, so this is actually valid + # But if combined with CL, it should still be rejected + with pytest.raises(InvalidHeader, match="Content-Length with Transfer-Encoding"): + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Content-Length: 12\r\n" + b"Transfer-Encoding: \tchunked\r\n" + b"\r\n" + ) + + def test_valid_chunked_accepted(self): + """Valid chunked request should be accepted.""" + parser = PythonProtocol() + + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b"5\r\n" + b"hello\r\n" + b"0\r\n" + b"\r\n" + ) + + assert parser.is_chunked + assert parser.is_complete + + def test_valid_identity_then_chunked(self): + """identity, chunked is valid per RFC.""" + parser = PythonProtocol() + + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Transfer-Encoding: identity, chunked\r\n" + b"\r\n" + b"5\r\n" + b"hello\r\n" + b"0\r\n" + b"\r\n" + ) + + assert parser.is_chunked + assert parser.is_complete + + +class TestChunkSizeValidation: + """Test strict validation of chunk sizes.""" + + def test_chunk_size_with_leading_space_rejected(self): + """Leading space in chunk size must be rejected.""" + parser = PythonProtocol() + + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + ) + + with pytest.raises(InvalidChunkSize, match="Whitespace"): + parser.feed(b" 5\r\nhello\r\n0\r\n\r\n") + + def test_chunk_size_with_trailing_space_rejected(self): + """Trailing space in chunk size must be rejected.""" + parser = PythonProtocol() + + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + ) + + with pytest.raises(InvalidChunkSize, match="Whitespace"): + parser.feed(b"5 \r\nhello\r\n0\r\n\r\n") + + def test_chunk_size_with_tab_rejected(self): + """Tab in chunk size must be rejected.""" + parser = PythonProtocol() + + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + ) + + with pytest.raises(InvalidChunkSize): + parser.feed(b"\t5\r\nhello\r\n0\r\n\r\n") + + def test_chunk_size_with_underscore_rejected(self): + """Underscore in chunk size must be rejected.""" + parser = PythonProtocol() + + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + ) + + with pytest.raises(InvalidChunkSize, match="Invalid character"): + parser.feed(b"6_0\r\n" + b"x" * 96 + b"\r\n0\r\n\r\n") + + def test_empty_chunk_size_rejected(self): + """Empty chunk size must be rejected.""" + parser = PythonProtocol() + + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + ) + + with pytest.raises(InvalidChunkSize, match="Empty"): + parser.feed(b"\r\nhello\r\n0\r\n\r\n") + + def test_valid_chunk_sizes(self): + """Valid hex chunk sizes should work.""" + parser = PythonProtocol() + body_chunks = [] + + parser = PythonProtocol( + on_body=lambda chunk: body_chunks.append(chunk), + ) + + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b"a\r\n" # 10 in hex + b"0123456789\r\n" + b"0\r\n" + b"\r\n" + ) + + assert parser.is_complete + assert b"".join(body_chunks) == b"0123456789" + + def test_chunk_extension_accepted(self): + """Chunk extensions after semicolon should be accepted.""" + parser = PythonProtocol() + body_chunks = [] + + parser = PythonProtocol( + on_body=lambda chunk: body_chunks.append(chunk), + ) + + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b"5;ext=value\r\n" + b"hello\r\n" + b"0\r\n" + b"\r\n" + ) + + assert parser.is_complete + assert b"".join(body_chunks) == b"hello" + + +class TestMultipleTransferEncodingHeaders: + """Test handling of multiple Transfer-Encoding headers.""" + + def test_multiple_te_headers_with_chunked(self): + """Multiple T-E headers that result in chunked should work.""" + parser = PythonProtocol() + + # This tests the iteration over headers - each T-E header is processed + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Transfer-Encoding: identity\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b"5\r\n" + b"hello\r\n" + b"0\r\n" + b"\r\n" + ) + + assert parser.is_chunked + assert parser.is_complete + + def test_multiple_te_headers_double_chunked_rejected(self): + """Multiple T-E headers both with chunked should be rejected.""" + parser = PythonProtocol() + + with pytest.raises(InvalidHeader, match="Stacked chunked"): + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Transfer-Encoding: chunked\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + ) diff --git a/tests/test_asgi_valid_requests.py b/tests/test_asgi_valid_requests.py new file mode 100644 index 00000000..8aacef7f --- /dev/null +++ b/tests/test_asgi_valid_requests.py @@ -0,0 +1,53 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +"""Test valid HTTP requests against ASGI callback parser. + +Runs the same .http test files as test_valid_requests.py but using +the ASGI PythonProtocol callback parser. +""" + +import glob +import os + +import pytest + +import treq_asgi + +dirname = os.path.dirname(__file__) +reqdir = os.path.join(dirname, "requests", "valid") +httpfiles = glob.glob(os.path.join(reqdir, "*.http")) + +# Tests that require features not supported by callback parser +SKIP_TESTS = set() + +# Tests that use config options incompatible with callback parser +INCOMPATIBLE_BOOL_FLAGS = ('permit_obsolete_folding', 'strip_header_spaces') + + +@pytest.mark.parametrize("fname", httpfiles) +def test_asgi_parser(fname): + """Test valid HTTP requests with ASGI callback parser.""" + basename = os.path.basename(fname) + if basename in SKIP_TESTS: + pytest.skip(f"Test {basename} not supported by callback parser") + + env = treq_asgi.load_py(os.path.splitext(fname)[0] + ".py") + expect = env['request'] + cfg = env['cfg'] + + # Skip tests that use incompatible config flags + for flag in INCOMPATIBLE_BOOL_FLAGS: + if getattr(cfg, flag, False): + pytest.skip(f"Callback parser incompatible with {flag}") + + # Skip proxy protocol tests + if getattr(cfg, 'proxy_protocol', 'off') != 'off': + pytest.skip("Callback parser does not support proxy_protocol") + + req = treq_asgi.request(fname, expect) + + # Test with different sending strategies + for sender in [req.send_all, req.send_lines, req.send_random]: + req.check(cfg, sender) diff --git a/tests/treq_asgi.py b/tests/treq_asgi.py new file mode 100644 index 00000000..dbb6d2d2 --- /dev/null +++ b/tests/treq_asgi.py @@ -0,0 +1,265 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +"""Test request utilities for ASGI callback parser. + +Provides the same test infrastructure as treq.py but for testing +the ASGI PythonProtocol callback parser. +""" + +import importlib.machinery +import os +import random +import types + +from gunicorn.config import Config +from gunicorn.asgi.parser import ( + PythonProtocol, + ParseError, + InvalidHeader, + InvalidHeaderName, + InvalidRequestLine, + InvalidRequestMethod, + InvalidHTTPVersion, + LimitRequestLine, + LimitRequestHeaders, + UnsupportedTransferCoding, + InvalidChunkSize, + InvalidProxyLine, + InvalidProxyHeader, +) +from gunicorn.util import split_request_uri + +dirname = os.path.dirname(__file__) +random.seed() + + +def uri(data): + ret = {"raw": data} + parts = split_request_uri(data) + ret["scheme"] = parts.scheme or '' + ret["host"] = parts.netloc.rsplit(":", 1)[0] or None + ret["port"] = parts.port or 80 + ret["path"] = parts.path or '' + ret["query"] = parts.query or '' + ret["fragment"] = parts.fragment or '' + return ret + + +def load_py(fname): + """Load test configuration from Python file.""" + module_name = '__config__' + mod = types.ModuleType(module_name) + setattr(mod, 'uri', uri) + setattr(mod, 'cfg', Config()) + loader = importlib.machinery.SourceFileLoader(module_name, fname) + loader.exec_module(mod) + return vars(mod) + + +def decode_hex_escapes(data): + """Decode hex escape sequences like \\xAB in test data.""" + result = bytearray() + i = 0 + while i < len(data): + if i + 3 < len(data) and data[i:i+2] == b'\\x': + hex_chars = data[i+2:i+4] + try: + byte_val = int(hex_chars, 16) + result.append(byte_val) + i += 4 + continue + except ValueError: + pass + result.append(data[i]) + i += 1 + return bytes(result) + + +# Map WSGI parser exceptions to ASGI parser exceptions +EXCEPTION_MAP = { + 'InvalidRequestLine': (InvalidRequestLine, ParseError), + 'InvalidRequestMethod': (InvalidRequestMethod, ParseError), + 'InvalidHTTPVersion': (InvalidHTTPVersion, ParseError), + 'InvalidHeader': (InvalidHeader, ParseError), + 'InvalidHeaderName': (InvalidHeaderName, ParseError), + 'LimitRequestLine': (LimitRequestLine, ParseError), + 'LimitRequestHeaders': (LimitRequestHeaders, ParseError), + 'UnsupportedTransferCoding': (UnsupportedTransferCoding, ParseError), + 'InvalidChunkSize': (InvalidChunkSize, ParseError), + 'InvalidProxyLine': (InvalidProxyLine, ParseError), + 'InvalidProxyHeader': (InvalidProxyHeader, ParseError), +} + + +def map_exception(wsgi_exc): + """Map a WSGI exception class to equivalent ASGI parser exceptions.""" + exc_name = wsgi_exc.__name__ + if exc_name in EXCEPTION_MAP: + return EXCEPTION_MAP[exc_name] + # For other exceptions, accept any ParseError + return (ParseError,) + + +class request: + """Test valid HTTP requests against ASGI callback parser.""" + + def __init__(self, fname, expect): + self.fname = fname + self.name = os.path.basename(fname) + + self.expect = expect + if not isinstance(self.expect, list): + self.expect = [self.expect] + + with open(self.fname, 'rb') as handle: + self.data = handle.read() + self.data = self.data.replace(b"\n", b"").replace(b"\\r\\n", b"\r\n") + self.data = self.data.replace(b"\\0", b"\000").replace(b"\\n", b"\n").replace(b"\\t", b"\t") + self.data = decode_hex_escapes(self.data) + if b"\\" in self.data: + raise AssertionError("Unexpected backslash in test data") + + def send_all(self): + yield self.data + + def send_lines(self): + lines = self.data + pos = lines.find(b"\r\n") + while pos > 0: + yield lines[:pos+2] + lines = lines[pos+2:] + pos = lines.find(b"\r\n") + if lines: + yield lines + + def send_bytes(self): + for d in self.data: + yield bytes([d]) + + def send_random(self): + maxs = max(1, round(len(self.data) / 10)) + read = 0 + while read < len(self.data): + chunk = random.randint(1, maxs) + yield self.data[read:read+chunk] + read += chunk + + def check(self, cfg, sender): + """Parse request and verify it matches expected values.""" + body_chunks = [] + + # Handle limit_request_field_size=0 meaning "use default" + field_size = cfg.limit_request_field_size + if field_size <= 0: + field_size = 8190 # Default max + + parser = PythonProtocol( + on_body=lambda chunk: body_chunks.append(chunk), + limit_request_line=cfg.limit_request_line, + limit_request_fields=cfg.limit_request_fields, + limit_request_field_size=field_size, + permit_unconventional_http_method=cfg.permit_unconventional_http_method, + permit_unconventional_http_version=cfg.permit_unconventional_http_version, + proxy_protocol=getattr(cfg, 'proxy_protocol', 'off'), + ) + + for chunk in sender(): + parser.feed(chunk) + + # Verify parsed request matches expected + exp = self.expect[0] # For now, handle single request + + assert parser.method == exp["method"].encode('latin-1'), \ + f"Method mismatch: {parser.method} != {exp['method']}" + + # Path comparison - parser stores raw bytes + expected_path = exp["uri"]["raw"].encode('latin-1') + assert parser.path == expected_path, \ + f"Path mismatch: {parser.path} != {expected_path}" + + assert parser.http_version == exp["version"], \ + f"Version mismatch: {parser.http_version} != {exp['version']}" + + # Headers - convert to comparable format + parsed_headers = [ + (n.decode('latin-1').upper(), v.decode('latin-1')) + for n, v in parser.headers + ] + assert parsed_headers == exp["headers"], \ + f"Headers mismatch: {parsed_headers} != {exp['headers']}" + + # Body + body = b"".join(body_chunks) + expected_body = exp["body"] + assert body == expected_body, \ + f"Body mismatch: {body!r} != {expected_body!r}" + + assert parser.is_complete, "Parser did not complete" + + +class badrequest: + """Test invalid HTTP requests against ASGI callback parser.""" + + def __init__(self, fname): + self.fname = fname + self.name = os.path.basename(fname) + + with open(self.fname) as handle: + self.data = handle.read() + self.data = self.data.replace("\n", "").replace("\\r\\n", "\r\n") + self.data = self.data.replace("\\0", "\000").replace("\\n", "\n").replace("\\t", "\t") + if "\\" in self.data: + raise AssertionError("Unexpected backslash in test data") + self.data = self.data.encode('latin1') + + def send_all(self): + yield self.data + + def send_random(self): + maxs = max(1, round(len(self.data) / 10)) + read = 0 + while read < len(self.data): + chunk = random.randint(1, maxs) + yield self.data[read:read+chunk] + read += chunk + + def check(self, cfg, expected_exc): + """Verify parser raises expected exception.""" + # Handle limit_request_field_size=0 meaning "use default" + field_size = cfg.limit_request_field_size + if field_size <= 0: + field_size = 8190 # Default max + + parser = PythonProtocol( + limit_request_line=cfg.limit_request_line, + limit_request_fields=cfg.limit_request_fields, + limit_request_field_size=field_size, + permit_unconventional_http_method=cfg.permit_unconventional_http_method, + permit_unconventional_http_version=cfg.permit_unconventional_http_version, + proxy_protocol=getattr(cfg, 'proxy_protocol', 'off'), + ) + + # Get acceptable exception types + acceptable = map_exception(expected_exc) + + raised = False + try: + for chunk in self.send_random(): + parser.feed(chunk) + # If we get here without exception, try to check if parser completed + # Some invalid requests might parse headers but fail on body + if not parser.is_complete: + # Parser stalled - this counts as detecting invalid input + raised = True + except acceptable: + raised = True + except ParseError: + # Accept any ParseError as valid rejection + raised = True + + if not raised: + raise AssertionError( + f"Expected {expected_exc.__name__} but parser accepted the request" + )