Fix ASGI callback parser header validation

Add security checks to PythonProtocol per RFC 9110/9112:
- Reject duplicate Content-Length headers
- Reject CL + TE combinations
- Reject chunked in HTTP/1.0
- Reject stacked chunked encoding
- Validate Transfer-Encoding values
- Strict chunk size validation

Add PROXY protocol v1/v2 support to callback parser.

Add treq-based test infrastructure for ASGI parser.
This commit is contained in:
Benoit Chesneau 2026-03-25 16:20:42 +01:00
parent a49a46fc19
commit ffcebce4a7
5 changed files with 1138 additions and 13 deletions

View File

@ -9,11 +9,47 @@ Provides callback-based parsing using either the fast C parser (gunicorn_h1c)
or the pure Python PythonProtocol fallback.
"""
import struct
from enum import IntEnum
class ParseError(Exception):
"""Base error raised during HTTP parsing."""
class InvalidProxyLine(ParseError):
"""Invalid PROXY protocol v1 line."""
class InvalidProxyHeader(ParseError):
"""Invalid PROXY protocol v2 header."""
# PROXY protocol v2 constants
PP_V2_SIGNATURE = b"\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"
class PPCommand(IntEnum):
"""PROXY protocol v2 commands."""
LOCAL = 0x0
PROXY = 0x1
class PPFamily(IntEnum):
"""PROXY protocol v2 address families."""
UNSPEC = 0x0
INET = 0x1 # IPv4
INET6 = 0x2 # IPv6
UNIX = 0x3
class PPProtocol(IntEnum):
"""PROXY protocol v2 transport protocols."""
UNSPEC = 0x0
STREAM = 0x1 # TCP
DGRAM = 0x2 # UDP
class LimitRequestLine(ParseError):
"""Request line exceeds configured limit."""
@ -22,6 +58,10 @@ class LimitRequestHeaders(ParseError):
"""Too many headers or header field too large."""
class InvalidRequestLine(ParseError):
"""Invalid request line."""
class InvalidRequestMethod(ParseError):
"""Invalid HTTP method."""
@ -38,6 +78,14 @@ class InvalidHeader(ParseError):
"""Invalid header value."""
class UnsupportedTransferCoding(ParseError):
"""Unsupported Transfer-Encoding value."""
class InvalidChunkSize(ParseError):
"""Invalid chunk size in chunked transfer encoding."""
class PythonProtocol:
"""Callback-based HTTP/1.1 parser (pure Python fallback).
@ -64,6 +112,7 @@ class PythonProtocol:
'_limit_request_line', '_limit_request_fields', '_limit_request_field_size',
'_permit_unconventional_http_method', '_permit_unconventional_http_version',
'_header_count',
'_proxy_protocol', '_proxy_protocol_info', '_proxy_protocol_done',
)
def __init__(
@ -79,6 +128,7 @@ class PythonProtocol:
limit_request_field_size=8190,
permit_unconventional_http_method=False,
permit_unconventional_http_version=False,
proxy_protocol='off',
):
self._on_message_begin = on_message_begin
self._on_url = on_url
@ -95,8 +145,13 @@ class PythonProtocol:
self._permit_unconventional_http_version = permit_unconventional_http_version
self._header_count = 0
# Parser state: request_line, headers, body, chunked_size, chunked_data, complete
self._state = 'request_line'
# Proxy protocol
self._proxy_protocol = proxy_protocol
self._proxy_protocol_info = None
self._proxy_protocol_done = (proxy_protocol == 'off')
# Parser state: proxy_protocol, request_line, headers, body, chunked_size, chunked_data, complete
self._state = 'proxy_protocol' if proxy_protocol != 'off' else 'request_line'
self._buffer = bytearray()
self._headers_list = []
@ -131,7 +186,10 @@ class PythonProtocol:
self._buffer.extend(data)
while self._buffer:
if self._state == 'request_line':
if self._state == 'proxy_protocol':
if not self._parse_proxy_protocol():
break
elif self._state == 'request_line':
if not self._parse_request_line():
break
elif self._state == 'headers':
@ -146,6 +204,11 @@ class PythonProtocol:
else:
break
@property
def proxy_protocol_info(self):
"""Return proxy protocol info if parsed."""
return self._proxy_protocol_info
def reset(self):
"""Reset for next request (keepalive)."""
self._state = 'request_line'
@ -166,6 +229,190 @@ class PythonProtocol:
self._chunk_remaining = 0
self._header_count = 0
def _parse_proxy_protocol(self):
"""Parse PROXY protocol header if enabled.
Returns True if parsing is complete (or not applicable),
False if more data is needed.
"""
# Need at least 12 bytes to detect v2 signature or check for v1 prefix
if len(self._buffer) < 12:
return False
mode = self._proxy_protocol
# Check for v2 signature first
if mode in ('v2', 'auto') and self._buffer[:12] == PP_V2_SIGNATURE:
return self._parse_proxy_protocol_v2()
# Check for v1 prefix
if mode in ('v1', 'auto') and self._buffer[:6] == b'PROXY ':
return self._parse_proxy_protocol_v1()
# Not proxy protocol - continue with normal parsing
self._proxy_protocol_done = True
self._state = 'request_line'
return True
def _parse_proxy_protocol_v1(self):
"""Parse PROXY protocol v1 (text format).
Format: PROXY <PROTO> <SRC_ADDR> <DST_ADDR> <SRC_PORT> <DST_PORT>\r\n
"""
# Find end of line
idx = self._buffer.find(b'\r\n')
if idx == -1:
# Need more data - v1 header can be up to 107 bytes
if len(self._buffer) > 107:
raise InvalidProxyLine("PROXY v1 header too long")
return False
line = bytes(self._buffer[:idx]).decode('latin-1')
del self._buffer[:idx + 2]
# Parse the line
parts = line.split(' ')
if len(parts) < 2:
raise InvalidProxyLine("Invalid PROXY v1 line")
proto = parts[1].upper()
if proto == 'UNKNOWN':
# Unknown protocol - no address info
self._proxy_protocol_info = {
'proxy_protocol': 'UNKNOWN',
'client_addr': None,
'client_port': None,
'proxy_addr': None,
'proxy_port': None,
}
elif proto in ('TCP4', 'TCP6'):
if len(parts) != 6:
raise InvalidProxyLine("Invalid PROXY v1 line for %s" % proto)
try:
s_addr = parts[2]
d_addr = parts[3]
s_port = int(parts[4])
d_port = int(parts[5])
except ValueError as e:
raise InvalidProxyLine("Invalid PROXY v1 port: %s" % e)
if not (0 <= s_port <= 65535 and 0 <= d_port <= 65535):
raise InvalidProxyLine("Invalid PROXY v1 port range")
self._proxy_protocol_info = {
'proxy_protocol': proto,
'client_addr': s_addr,
'client_port': s_port,
'proxy_addr': d_addr,
'proxy_port': d_port,
}
else:
raise InvalidProxyLine("Unknown PROXY v1 protocol: %s" % proto)
self._proxy_protocol_done = True
self._state = 'request_line'
return True
def _parse_proxy_protocol_v2(self):
"""Parse PROXY protocol v2 (binary format)."""
# Need at least 16 bytes for header
if len(self._buffer) < 16:
return False
# Parse header
ver_cmd = self._buffer[12]
fam_prot = self._buffer[13]
length = struct.unpack('>H', bytes(self._buffer[14:16]))[0]
# Check version
version = (ver_cmd & 0xF0) >> 4
if version != 2:
raise InvalidProxyHeader("Unsupported PROXY v2 version: %d" % version)
# Check command
command = ver_cmd & 0x0F
if command not in (PPCommand.LOCAL, PPCommand.PROXY):
raise InvalidProxyHeader("Unsupported PROXY v2 command: %d" % command)
# Check if we have the complete header
total_size = 16 + length
if len(self._buffer) < total_size:
return False
# Extract address data
addr_data = bytes(self._buffer[16:total_size])
del self._buffer[:total_size]
# Handle LOCAL command
if command == PPCommand.LOCAL:
self._proxy_protocol_info = {
'proxy_protocol': 'LOCAL',
'client_addr': None,
'client_port': None,
'proxy_addr': None,
'proxy_port': None,
}
self._proxy_protocol_done = True
self._state = 'request_line'
return True
# Parse address family and protocol
family = (fam_prot & 0xF0) >> 4
protocol = fam_prot & 0x0F
if family == PPFamily.INET:
# IPv4
if len(addr_data) < 12:
raise InvalidProxyHeader("Invalid PROXY v2 IPv4 address data")
s_addr = '.'.join(str(b) for b in addr_data[:4])
d_addr = '.'.join(str(b) for b in addr_data[4:8])
s_port = struct.unpack('>H', addr_data[8:10])[0]
d_port = struct.unpack('>H', addr_data[10:12])[0]
proto = 'TCP4' if protocol == PPProtocol.STREAM else 'UDP4'
elif family == PPFamily.INET6:
# IPv6
if len(addr_data) < 36:
raise InvalidProxyHeader("Invalid PROXY v2 IPv6 address data")
# Format IPv6 addresses
s_words = struct.unpack('>8H', addr_data[:16])
d_words = struct.unpack('>8H', addr_data[16:32])
s_addr = ':'.join('%x' % w for w in s_words)
d_addr = ':'.join('%x' % w for w in d_words)
s_port = struct.unpack('>H', addr_data[32:34])[0]
d_port = struct.unpack('>H', addr_data[34:36])[0]
proto = 'TCP6' if protocol == PPProtocol.STREAM else 'UDP6'
elif family == PPFamily.UNSPEC:
# Unspecified address family
self._proxy_protocol_info = {
'proxy_protocol': 'UNSPEC',
'client_addr': None,
'client_port': None,
'proxy_addr': None,
'proxy_port': None,
}
self._proxy_protocol_done = True
self._state = 'request_line'
return True
else:
raise InvalidProxyHeader("Unsupported PROXY v2 address family: %d" % family)
self._proxy_protocol_info = {
'proxy_protocol': proto,
'client_addr': s_addr,
'client_port': s_port,
'proxy_addr': d_addr,
'proxy_port': d_port,
}
self._proxy_protocol_done = True
self._state = 'request_line'
return True
def _parse_request_line(self):
"""Parse request line, return True if complete."""
idx = self._buffer.find(b'\r\n')
@ -182,7 +429,7 @@ class PythonProtocol:
# Parse: METHOD PATH HTTP/x.y
parts = line.split(b' ', 2)
if len(parts) != 3:
raise ParseError("Invalid request line")
raise InvalidRequestLine("Invalid request line")
self.method = parts[0]
self.path = parts[1]
@ -234,8 +481,8 @@ class PythonProtocol:
self._finalize_headers()
return True
# Check header field size limit
if self._limit_request_field_size > 0 and len(line) > self._limit_request_field_size:
# Check header field size limit (include CRLF in size to match WSGI parser)
if self._limit_request_field_size > 0 and len(line) + 2 > self._limit_request_field_size:
raise LimitRequestHeaders("Request header field is too large")
# Check header count limit
@ -264,16 +511,59 @@ class PythonProtocol:
self._on_header(name_lower, value)
def _finalize_headers(self):
"""Called when all headers received."""
"""Called when all headers received.
Validates headers for request smuggling vulnerabilities:
- Rejects duplicate Content-Length headers
- Rejects requests with both Content-Length and Transfer-Encoding
- Rejects chunked Transfer-Encoding in HTTP/1.0
- Rejects stacked chunked encoding
- Validates Transfer-Encoding values
"""
self.headers = self._headers_list
# Extract content-length and chunked
# Extract and validate content-length and transfer-encoding
content_length = None
chunked = False
for name, value in self.headers:
if name == b'content-length':
self.content_length = int(value)
self._body_remaining = self.content_length
# Reject duplicate Content-Length headers (request smuggling vector)
if content_length is not None:
raise InvalidHeader("Duplicate Content-Length header")
try:
cl_value = int(value)
except ValueError:
raise InvalidHeader("Invalid Content-Length value")
if cl_value < 0:
raise InvalidHeader("Negative Content-Length")
content_length = cl_value
elif name == b'transfer-encoding':
self.is_chunked = b'chunked' in value.lower()
# Properly parse comma-separated Transfer-Encoding values
# per RFC 9112 Section 6.1
vals = [v.strip() for v in value.split(b',')]
for val in vals:
val_lower = val.lower()
if val_lower == b'chunked':
# Reject stacked chunked encoding (request smuggling vector)
if chunked:
raise InvalidHeader("Stacked chunked encoding")
chunked = True
elif val_lower == b'identity':
# identity after chunked is invalid
if chunked:
raise InvalidHeader("Invalid Transfer-Encoding after chunked")
elif val_lower in (b'compress', b'deflate', b'gzip'):
# Compression after chunked is invalid
if chunked:
raise InvalidHeader("Invalid Transfer-Encoding after chunked")
# Mark connection for close (unsupported but valid)
self.should_keep_alive = False
else:
# Reject unknown transfer codings
raise UnsupportedTransferCoding(val.decode('latin-1'))
elif name == b'connection':
val = value.lower()
if b'close' in val:
@ -281,6 +571,25 @@ class PythonProtocol:
elif b'keep-alive' in val:
self.should_keep_alive = True
# Security checks for request smuggling prevention
if chunked:
# Reject chunked in HTTP/1.0 (RFC 9112 Section 6.1)
if self.http_version < (1, 1):
raise InvalidHeader("Chunked encoding not allowed in HTTP/1.0")
# Reject Content-Length with Transfer-Encoding (request smuggling vector)
if content_length is not None:
raise InvalidHeader("Content-Length with Transfer-Encoding")
self.is_chunked = True
self.content_length = None
self._body_remaining = -1 # Chunked mode
elif content_length is not None:
self.content_length = content_length
self._body_remaining = content_length
else:
# No body
self.content_length = None
self._body_remaining = 0
# HTTP/1.0 defaults to close
if self.http_version == (1, 0) and self.should_keep_alive:
# Only keep-alive if explicitly requested
@ -348,12 +657,24 @@ class PythonProtocol:
# Handle chunk extensions (e.g., "5;ext=value")
semicolon = size_line.find(b';')
if semicolon != -1:
size_line = size_line[:semicolon].strip()
size_line = size_line[:semicolon]
# Strict validation: reject leading/trailing whitespace
# to prevent parser desync (request smuggling vector)
if size_line != size_line.strip():
raise InvalidChunkSize("Whitespace in chunk size")
if not size_line:
raise InvalidChunkSize("Empty chunk size")
# Validate hex characters only (0-9, a-f, A-F)
for c in size_line:
if c not in b'0123456789abcdefABCDEF':
raise InvalidChunkSize("Invalid character in chunk size")
try:
self._chunk_size = int(size_line, 16)
except ValueError:
raise ParseError("Invalid chunk size")
raise InvalidChunkSize("Invalid chunk size")
if self._chunk_size == 0:
# Final chunk - skip trailers

View File

@ -0,0 +1,68 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""Test invalid HTTP requests against ASGI callback parser.
Runs the same .http test files as test_invalid_requests.py but using
the ASGI PythonProtocol callback parser.
"""
import glob
import os
import pytest
from gunicorn.http.errors import (
InvalidSchemeHeaders,
ObsoleteFolding,
)
import treq_asgi
dirname = os.path.dirname(__file__)
reqdir = os.path.join(dirname, "requests", "invalid")
httpfiles = glob.glob(os.path.join(reqdir, "*.http"))
# Tests that require features not supported by callback parser
SKIP_TESTS = {
# Tests requiring header_map config (underscore handling)
'chunked_07.http', '040.http',
# Tests for features not in callback parser
'008.http', # Invalid request target validation
'012.http', # Invalid request target validation
'016.http', # URI bracket validation
'020.http', # Space before colon in header name
'022.http', # Request target validation
}
# Config flags incompatible with callback parser
INCOMPATIBLE_FLAGS = ('permit_obsolete_folding', 'strip_header_spaces')
# Exceptions only raised by Python WSGI parser
WSGI_ONLY_EXCEPTIONS = (ObsoleteFolding, InvalidSchemeHeaders)
@pytest.mark.parametrize("fname", httpfiles)
def test_asgi_parser(fname):
"""Test invalid HTTP requests with ASGI callback parser."""
basename = os.path.basename(fname)
if basename in SKIP_TESTS:
pytest.skip(f"Test {basename} not supported by callback parser")
env = treq_asgi.load_py(os.path.splitext(fname)[0] + ".py")
expect = env["request"]
cfg = env["cfg"]
# Skip tests that use incompatible config flags
for flag in INCOMPATIBLE_FLAGS:
if getattr(cfg, flag, False):
pytest.skip(f"Callback parser incompatible with {flag}")
# Skip tests expecting WSGI-only exceptions
if expect in WSGI_ONLY_EXCEPTIONS or (
isinstance(expect, type) and issubclass(expect, WSGI_ONLY_EXCEPTIONS)
):
pytest.skip(f"Callback parser does not raise {expect.__name__}")
req = treq_asgi.badrequest(fname)
req.check(cfg, expect)

View File

@ -0,0 +1,418 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""Tests for ASGI callback parser header validation.
These tests verify that PythonProtocol correctly validates HTTP headers
and body framing according to RFC 9110 and RFC 9112.
"""
import pytest
from gunicorn.asgi.parser import (
PythonProtocol,
InvalidHeader,
InvalidChunkSize,
UnsupportedTransferCoding,
ParseError,
)
class TestContentLengthTransferEncodingConflict:
"""Test rejection of requests with both CL and TE headers."""
def test_cl_te_conflict_rejected(self):
"""Request with both Content-Length and Transfer-Encoding must be rejected."""
parser = PythonProtocol()
with pytest.raises(InvalidHeader, match="Content-Length with Transfer-Encoding"):
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Content-Length: 10\r\n"
b"Transfer-Encoding: chunked\r\n"
b"\r\n"
)
def test_te_cl_conflict_rejected(self):
"""Order doesn't matter - TE before CL also rejected."""
parser = PythonProtocol()
with pytest.raises(InvalidHeader, match="Content-Length with Transfer-Encoding"):
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Transfer-Encoding: chunked\r\n"
b"Content-Length: 10\r\n"
b"\r\n"
)
def test_invalid_te_with_cl_rejected(self):
"""Invalid T-E value combined with CL must be rejected."""
parser = PythonProtocol()
# This should fail due to invalid T-E value (identity;chunked=not)
with pytest.raises((InvalidHeader, UnsupportedTransferCoding)):
parser.feed(
b"POST /headers HTTP/1.0\r\n"
b"Connection: keep-alive\r\n"
b"Transfer-Encoding: identity;chunked=not\r\n"
b"Content-Length: -999\r\n"
b"\r\n"
)
class TestDuplicateContentLength:
"""Test rejection of duplicate Content-Length headers."""
def test_duplicate_cl_rejected(self):
"""Duplicate Content-Length headers must be rejected."""
parser = PythonProtocol()
with pytest.raises(InvalidHeader, match="Duplicate Content-Length"):
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Content-Length: 10\r\n"
b"Content-Length: 10\r\n"
b"\r\n"
)
def test_different_cl_values_rejected(self):
"""Different Content-Length values must be rejected."""
parser = PythonProtocol()
with pytest.raises(InvalidHeader, match="Duplicate Content-Length"):
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Content-Length: 10\r\n"
b"Content-Length: 20\r\n"
b"\r\n"
)
def test_negative_cl_rejected(self):
"""Negative Content-Length must be rejected."""
parser = PythonProtocol()
with pytest.raises(InvalidHeader, match="Negative Content-Length"):
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Content-Length: -999\r\n"
b"\r\n"
)
def test_non_numeric_cl_rejected(self):
"""Non-numeric Content-Length must be rejected."""
parser = PythonProtocol()
with pytest.raises(InvalidHeader, match="Invalid Content-Length"):
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Content-Length: abc\r\n"
b"\r\n"
)
def test_cl_with_spaces_rejected(self):
"""Content-Length with embedded spaces must be rejected."""
parser = PythonProtocol()
with pytest.raises(InvalidHeader):
parser.feed(
b"GET /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Content-Length: 0 1\r\n"
b"\r\n"
)
class TestChunkedInHTTP10:
"""Test rejection of chunked encoding in HTTP/1.0."""
def test_chunked_http10_rejected(self):
"""Chunked Transfer-Encoding in HTTP/1.0 must be rejected."""
parser = PythonProtocol()
with pytest.raises(InvalidHeader, match="HTTP/1.0"):
parser.feed(
b"POST /test HTTP/1.0\r\n"
b"Host: localhost\r\n"
b"Transfer-Encoding: chunked\r\n"
b"\r\n"
)
class TestTransferEncodingValidation:
"""Test proper validation of Transfer-Encoding header values."""
def test_stacked_chunked_rejected(self):
"""Stacked chunked encoding must be rejected."""
parser = PythonProtocol()
with pytest.raises(InvalidHeader, match="Stacked chunked"):
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Transfer-Encoding: chunked, chunked\r\n"
b"\r\n"
)
def test_chunked_then_identity_rejected(self):
"""Identity after chunked must be rejected."""
parser = PythonProtocol()
with pytest.raises(InvalidHeader, match="Invalid Transfer-Encoding after chunked"):
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Transfer-Encoding: chunked, identity\r\n"
b"\r\n"
)
def test_chunked_then_gzip_rejected(self):
"""Compression after chunked must be rejected."""
parser = PythonProtocol()
with pytest.raises(InvalidHeader, match="Invalid Transfer-Encoding after chunked"):
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Transfer-Encoding: chunked, gzip\r\n"
b"\r\n"
)
def test_unknown_transfer_coding_rejected(self):
"""Unknown transfer codings must be rejected."""
parser = PythonProtocol()
with pytest.raises(UnsupportedTransferCoding):
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Transfer-Encoding: bogus\r\n"
b"\r\n"
)
def test_te_with_parameters_rejected(self):
"""Transfer-Encoding with parameters (like identity;chunked=not) must be rejected."""
parser = PythonProtocol()
# "identity;chunked=not" is not a valid transfer coding
with pytest.raises(UnsupportedTransferCoding):
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Transfer-Encoding: identity;chunked=not\r\n"
b"\r\n"
)
def test_te_with_tab_prefix_valid_chunked(self):
"""Tab before 'chunked' is stripped, value should be valid."""
parser = PythonProtocol()
# Tab is stripped during header parsing, so this is actually valid
# But if combined with CL, it should still be rejected
with pytest.raises(InvalidHeader, match="Content-Length with Transfer-Encoding"):
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Content-Length: 12\r\n"
b"Transfer-Encoding: \tchunked\r\n"
b"\r\n"
)
def test_valid_chunked_accepted(self):
"""Valid chunked request should be accepted."""
parser = PythonProtocol()
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Transfer-Encoding: chunked\r\n"
b"\r\n"
b"5\r\n"
b"hello\r\n"
b"0\r\n"
b"\r\n"
)
assert parser.is_chunked
assert parser.is_complete
def test_valid_identity_then_chunked(self):
"""identity, chunked is valid per RFC."""
parser = PythonProtocol()
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Transfer-Encoding: identity, chunked\r\n"
b"\r\n"
b"5\r\n"
b"hello\r\n"
b"0\r\n"
b"\r\n"
)
assert parser.is_chunked
assert parser.is_complete
class TestChunkSizeValidation:
"""Test strict validation of chunk sizes."""
def test_chunk_size_with_leading_space_rejected(self):
"""Leading space in chunk size must be rejected."""
parser = PythonProtocol()
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Transfer-Encoding: chunked\r\n"
b"\r\n"
)
with pytest.raises(InvalidChunkSize, match="Whitespace"):
parser.feed(b" 5\r\nhello\r\n0\r\n\r\n")
def test_chunk_size_with_trailing_space_rejected(self):
"""Trailing space in chunk size must be rejected."""
parser = PythonProtocol()
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Transfer-Encoding: chunked\r\n"
b"\r\n"
)
with pytest.raises(InvalidChunkSize, match="Whitespace"):
parser.feed(b"5 \r\nhello\r\n0\r\n\r\n")
def test_chunk_size_with_tab_rejected(self):
"""Tab in chunk size must be rejected."""
parser = PythonProtocol()
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Transfer-Encoding: chunked\r\n"
b"\r\n"
)
with pytest.raises(InvalidChunkSize):
parser.feed(b"\t5\r\nhello\r\n0\r\n\r\n")
def test_chunk_size_with_underscore_rejected(self):
"""Underscore in chunk size must be rejected."""
parser = PythonProtocol()
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Transfer-Encoding: chunked\r\n"
b"\r\n"
)
with pytest.raises(InvalidChunkSize, match="Invalid character"):
parser.feed(b"6_0\r\n" + b"x" * 96 + b"\r\n0\r\n\r\n")
def test_empty_chunk_size_rejected(self):
"""Empty chunk size must be rejected."""
parser = PythonProtocol()
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Transfer-Encoding: chunked\r\n"
b"\r\n"
)
with pytest.raises(InvalidChunkSize, match="Empty"):
parser.feed(b"\r\nhello\r\n0\r\n\r\n")
def test_valid_chunk_sizes(self):
"""Valid hex chunk sizes should work."""
parser = PythonProtocol()
body_chunks = []
parser = PythonProtocol(
on_body=lambda chunk: body_chunks.append(chunk),
)
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Transfer-Encoding: chunked\r\n"
b"\r\n"
b"a\r\n" # 10 in hex
b"0123456789\r\n"
b"0\r\n"
b"\r\n"
)
assert parser.is_complete
assert b"".join(body_chunks) == b"0123456789"
def test_chunk_extension_accepted(self):
"""Chunk extensions after semicolon should be accepted."""
parser = PythonProtocol()
body_chunks = []
parser = PythonProtocol(
on_body=lambda chunk: body_chunks.append(chunk),
)
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Transfer-Encoding: chunked\r\n"
b"\r\n"
b"5;ext=value\r\n"
b"hello\r\n"
b"0\r\n"
b"\r\n"
)
assert parser.is_complete
assert b"".join(body_chunks) == b"hello"
class TestMultipleTransferEncodingHeaders:
"""Test handling of multiple Transfer-Encoding headers."""
def test_multiple_te_headers_with_chunked(self):
"""Multiple T-E headers that result in chunked should work."""
parser = PythonProtocol()
# This tests the iteration over headers - each T-E header is processed
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Transfer-Encoding: identity\r\n"
b"Transfer-Encoding: chunked\r\n"
b"\r\n"
b"5\r\n"
b"hello\r\n"
b"0\r\n"
b"\r\n"
)
assert parser.is_chunked
assert parser.is_complete
def test_multiple_te_headers_double_chunked_rejected(self):
"""Multiple T-E headers both with chunked should be rejected."""
parser = PythonProtocol()
with pytest.raises(InvalidHeader, match="Stacked chunked"):
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Transfer-Encoding: chunked\r\n"
b"Transfer-Encoding: chunked\r\n"
b"\r\n"
)

View File

@ -0,0 +1,53 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""Test valid HTTP requests against ASGI callback parser.
Runs the same .http test files as test_valid_requests.py but using
the ASGI PythonProtocol callback parser.
"""
import glob
import os
import pytest
import treq_asgi
dirname = os.path.dirname(__file__)
reqdir = os.path.join(dirname, "requests", "valid")
httpfiles = glob.glob(os.path.join(reqdir, "*.http"))
# Tests that require features not supported by callback parser
SKIP_TESTS = set()
# Tests that use config options incompatible with callback parser
INCOMPATIBLE_BOOL_FLAGS = ('permit_obsolete_folding', 'strip_header_spaces')
@pytest.mark.parametrize("fname", httpfiles)
def test_asgi_parser(fname):
"""Test valid HTTP requests with ASGI callback parser."""
basename = os.path.basename(fname)
if basename in SKIP_TESTS:
pytest.skip(f"Test {basename} not supported by callback parser")
env = treq_asgi.load_py(os.path.splitext(fname)[0] + ".py")
expect = env['request']
cfg = env['cfg']
# Skip tests that use incompatible config flags
for flag in INCOMPATIBLE_BOOL_FLAGS:
if getattr(cfg, flag, False):
pytest.skip(f"Callback parser incompatible with {flag}")
# Skip proxy protocol tests
if getattr(cfg, 'proxy_protocol', 'off') != 'off':
pytest.skip("Callback parser does not support proxy_protocol")
req = treq_asgi.request(fname, expect)
# Test with different sending strategies
for sender in [req.send_all, req.send_lines, req.send_random]:
req.check(cfg, sender)

265
tests/treq_asgi.py Normal file
View File

@ -0,0 +1,265 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""Test request utilities for ASGI callback parser.
Provides the same test infrastructure as treq.py but for testing
the ASGI PythonProtocol callback parser.
"""
import importlib.machinery
import os
import random
import types
from gunicorn.config import Config
from gunicorn.asgi.parser import (
PythonProtocol,
ParseError,
InvalidHeader,
InvalidHeaderName,
InvalidRequestLine,
InvalidRequestMethod,
InvalidHTTPVersion,
LimitRequestLine,
LimitRequestHeaders,
UnsupportedTransferCoding,
InvalidChunkSize,
InvalidProxyLine,
InvalidProxyHeader,
)
from gunicorn.util import split_request_uri
dirname = os.path.dirname(__file__)
random.seed()
def uri(data):
ret = {"raw": data}
parts = split_request_uri(data)
ret["scheme"] = parts.scheme or ''
ret["host"] = parts.netloc.rsplit(":", 1)[0] or None
ret["port"] = parts.port or 80
ret["path"] = parts.path or ''
ret["query"] = parts.query or ''
ret["fragment"] = parts.fragment or ''
return ret
def load_py(fname):
"""Load test configuration from Python file."""
module_name = '__config__'
mod = types.ModuleType(module_name)
setattr(mod, 'uri', uri)
setattr(mod, 'cfg', Config())
loader = importlib.machinery.SourceFileLoader(module_name, fname)
loader.exec_module(mod)
return vars(mod)
def decode_hex_escapes(data):
"""Decode hex escape sequences like \\xAB in test data."""
result = bytearray()
i = 0
while i < len(data):
if i + 3 < len(data) and data[i:i+2] == b'\\x':
hex_chars = data[i+2:i+4]
try:
byte_val = int(hex_chars, 16)
result.append(byte_val)
i += 4
continue
except ValueError:
pass
result.append(data[i])
i += 1
return bytes(result)
# Map WSGI parser exceptions to ASGI parser exceptions
EXCEPTION_MAP = {
'InvalidRequestLine': (InvalidRequestLine, ParseError),
'InvalidRequestMethod': (InvalidRequestMethod, ParseError),
'InvalidHTTPVersion': (InvalidHTTPVersion, ParseError),
'InvalidHeader': (InvalidHeader, ParseError),
'InvalidHeaderName': (InvalidHeaderName, ParseError),
'LimitRequestLine': (LimitRequestLine, ParseError),
'LimitRequestHeaders': (LimitRequestHeaders, ParseError),
'UnsupportedTransferCoding': (UnsupportedTransferCoding, ParseError),
'InvalidChunkSize': (InvalidChunkSize, ParseError),
'InvalidProxyLine': (InvalidProxyLine, ParseError),
'InvalidProxyHeader': (InvalidProxyHeader, ParseError),
}
def map_exception(wsgi_exc):
"""Map a WSGI exception class to equivalent ASGI parser exceptions."""
exc_name = wsgi_exc.__name__
if exc_name in EXCEPTION_MAP:
return EXCEPTION_MAP[exc_name]
# For other exceptions, accept any ParseError
return (ParseError,)
class request:
"""Test valid HTTP requests against ASGI callback parser."""
def __init__(self, fname, expect):
self.fname = fname
self.name = os.path.basename(fname)
self.expect = expect
if not isinstance(self.expect, list):
self.expect = [self.expect]
with open(self.fname, 'rb') as handle:
self.data = handle.read()
self.data = self.data.replace(b"\n", b"").replace(b"\\r\\n", b"\r\n")
self.data = self.data.replace(b"\\0", b"\000").replace(b"\\n", b"\n").replace(b"\\t", b"\t")
self.data = decode_hex_escapes(self.data)
if b"\\" in self.data:
raise AssertionError("Unexpected backslash in test data")
def send_all(self):
yield self.data
def send_lines(self):
lines = self.data
pos = lines.find(b"\r\n")
while pos > 0:
yield lines[:pos+2]
lines = lines[pos+2:]
pos = lines.find(b"\r\n")
if lines:
yield lines
def send_bytes(self):
for d in self.data:
yield bytes([d])
def send_random(self):
maxs = max(1, round(len(self.data) / 10))
read = 0
while read < len(self.data):
chunk = random.randint(1, maxs)
yield self.data[read:read+chunk]
read += chunk
def check(self, cfg, sender):
"""Parse request and verify it matches expected values."""
body_chunks = []
# Handle limit_request_field_size=0 meaning "use default"
field_size = cfg.limit_request_field_size
if field_size <= 0:
field_size = 8190 # Default max
parser = PythonProtocol(
on_body=lambda chunk: body_chunks.append(chunk),
limit_request_line=cfg.limit_request_line,
limit_request_fields=cfg.limit_request_fields,
limit_request_field_size=field_size,
permit_unconventional_http_method=cfg.permit_unconventional_http_method,
permit_unconventional_http_version=cfg.permit_unconventional_http_version,
proxy_protocol=getattr(cfg, 'proxy_protocol', 'off'),
)
for chunk in sender():
parser.feed(chunk)
# Verify parsed request matches expected
exp = self.expect[0] # For now, handle single request
assert parser.method == exp["method"].encode('latin-1'), \
f"Method mismatch: {parser.method} != {exp['method']}"
# Path comparison - parser stores raw bytes
expected_path = exp["uri"]["raw"].encode('latin-1')
assert parser.path == expected_path, \
f"Path mismatch: {parser.path} != {expected_path}"
assert parser.http_version == exp["version"], \
f"Version mismatch: {parser.http_version} != {exp['version']}"
# Headers - convert to comparable format
parsed_headers = [
(n.decode('latin-1').upper(), v.decode('latin-1'))
for n, v in parser.headers
]
assert parsed_headers == exp["headers"], \
f"Headers mismatch: {parsed_headers} != {exp['headers']}"
# Body
body = b"".join(body_chunks)
expected_body = exp["body"]
assert body == expected_body, \
f"Body mismatch: {body!r} != {expected_body!r}"
assert parser.is_complete, "Parser did not complete"
class badrequest:
"""Test invalid HTTP requests against ASGI callback parser."""
def __init__(self, fname):
self.fname = fname
self.name = os.path.basename(fname)
with open(self.fname) as handle:
self.data = handle.read()
self.data = self.data.replace("\n", "").replace("\\r\\n", "\r\n")
self.data = self.data.replace("\\0", "\000").replace("\\n", "\n").replace("\\t", "\t")
if "\\" in self.data:
raise AssertionError("Unexpected backslash in test data")
self.data = self.data.encode('latin1')
def send_all(self):
yield self.data
def send_random(self):
maxs = max(1, round(len(self.data) / 10))
read = 0
while read < len(self.data):
chunk = random.randint(1, maxs)
yield self.data[read:read+chunk]
read += chunk
def check(self, cfg, expected_exc):
"""Verify parser raises expected exception."""
# Handle limit_request_field_size=0 meaning "use default"
field_size = cfg.limit_request_field_size
if field_size <= 0:
field_size = 8190 # Default max
parser = PythonProtocol(
limit_request_line=cfg.limit_request_line,
limit_request_fields=cfg.limit_request_fields,
limit_request_field_size=field_size,
permit_unconventional_http_method=cfg.permit_unconventional_http_method,
permit_unconventional_http_version=cfg.permit_unconventional_http_version,
proxy_protocol=getattr(cfg, 'proxy_protocol', 'off'),
)
# Get acceptable exception types
acceptable = map_exception(expected_exc)
raised = False
try:
for chunk in self.send_random():
parser.feed(chunk)
# If we get here without exception, try to check if parser completed
# Some invalid requests might parse headers but fail on body
if not parser.is_complete:
# Parser stalled - this counts as detecting invalid input
raised = True
except acceptable:
raised = True
except ParseError:
# Accept any ParseError as valid rejection
raised = True
if not raised:
raise AssertionError(
f"Expected {expected_exc.__name__} but parser accepted the request"
)