mirror of
https://github.com/frappe/gunicorn.git
synced 2026-07-01 10:11:30 +08:00
Integrate gunicorn_h1c 0.4.1 exception types and limit parameters
Require gunicorn_h1c >= 0.4.1 for fast parser mode. Add new exception types and limit parameters to PythonProtocol for parity with C parser. Update tests to parametrize across both parser implementations.
This commit is contained in:
parent
f308e7abfa
commit
03cc85ef48
@ -11,7 +11,31 @@ or the pure Python PythonProtocol fallback.
|
||||
|
||||
|
||||
class ParseError(Exception):
|
||||
"""Error raised during HTTP parsing."""
|
||||
"""Base error raised during HTTP parsing."""
|
||||
|
||||
|
||||
class LimitRequestLine(ParseError):
|
||||
"""Request line exceeds configured limit."""
|
||||
|
||||
|
||||
class LimitRequestHeaders(ParseError):
|
||||
"""Too many headers or header field too large."""
|
||||
|
||||
|
||||
class InvalidRequestMethod(ParseError):
|
||||
"""Invalid HTTP method."""
|
||||
|
||||
|
||||
class InvalidHTTPVersion(ParseError):
|
||||
"""Invalid HTTP version."""
|
||||
|
||||
|
||||
class InvalidHeaderName(ParseError):
|
||||
"""Invalid header name."""
|
||||
|
||||
|
||||
class InvalidHeader(ParseError):
|
||||
"""Invalid header value."""
|
||||
|
||||
|
||||
class PythonProtocol:
|
||||
@ -37,6 +61,9 @@ class PythonProtocol:
|
||||
'content_length', 'is_chunked', 'should_keep_alive', 'is_complete',
|
||||
'_body_remaining', '_skip_body',
|
||||
'_chunk_state', '_chunk_size', '_chunk_remaining',
|
||||
'_limit_request_line', '_limit_request_fields', '_limit_request_field_size',
|
||||
'_permit_unconventional_http_method', '_permit_unconventional_http_version',
|
||||
'_header_count',
|
||||
)
|
||||
|
||||
def __init__(
|
||||
@ -47,6 +74,11 @@ class PythonProtocol:
|
||||
on_headers_complete=None,
|
||||
on_body=None,
|
||||
on_message_complete=None,
|
||||
limit_request_line=8190,
|
||||
limit_request_fields=100,
|
||||
limit_request_field_size=8190,
|
||||
permit_unconventional_http_method=False,
|
||||
permit_unconventional_http_version=False,
|
||||
):
|
||||
self._on_message_begin = on_message_begin
|
||||
self._on_url = on_url
|
||||
@ -55,6 +87,14 @@ class PythonProtocol:
|
||||
self._on_body = on_body
|
||||
self._on_message_complete = on_message_complete
|
||||
|
||||
# Store limits
|
||||
self._limit_request_line = limit_request_line
|
||||
self._limit_request_fields = limit_request_fields
|
||||
self._limit_request_field_size = limit_request_field_size
|
||||
self._permit_unconventional_http_method = permit_unconventional_http_method
|
||||
self._permit_unconventional_http_version = permit_unconventional_http_version
|
||||
self._header_count = 0
|
||||
|
||||
# Parser state: request_line, headers, body, chunked_size, chunked_data, complete
|
||||
self._state = 'request_line'
|
||||
self._buffer = bytearray()
|
||||
@ -124,6 +164,7 @@ class PythonProtocol:
|
||||
self._chunk_state = 'size'
|
||||
self._chunk_size = 0
|
||||
self._chunk_remaining = 0
|
||||
self._header_count = 0
|
||||
|
||||
def _parse_request_line(self):
|
||||
"""Parse request line, return True if complete."""
|
||||
@ -131,6 +172,10 @@ class PythonProtocol:
|
||||
if idx == -1:
|
||||
return False
|
||||
|
||||
# Check request line length limit
|
||||
if self._limit_request_line > 0 and idx > self._limit_request_line:
|
||||
raise LimitRequestLine("Request line is too large")
|
||||
|
||||
line = bytes(self._buffer[:idx])
|
||||
del self._buffer[:idx + 2]
|
||||
|
||||
@ -142,6 +187,11 @@ class PythonProtocol:
|
||||
self.method = parts[0]
|
||||
self.path = parts[1]
|
||||
|
||||
# Validate method
|
||||
if not self._permit_unconventional_http_method:
|
||||
if not self._is_valid_method(self.method):
|
||||
raise InvalidRequestMethod(self.method.decode('latin-1'))
|
||||
|
||||
# Parse version
|
||||
version = parts[2]
|
||||
if version == b'HTTP/1.1':
|
||||
@ -149,7 +199,17 @@ class PythonProtocol:
|
||||
elif version == b'HTTP/1.0':
|
||||
self.http_version = (1, 0)
|
||||
else:
|
||||
raise ParseError("Unsupported HTTP version")
|
||||
if not self._permit_unconventional_http_version:
|
||||
raise InvalidHTTPVersion(version.decode('latin-1'))
|
||||
# Try to parse other HTTP/1.x versions if permitted
|
||||
if version.startswith(b'HTTP/1.'):
|
||||
try:
|
||||
minor = int(version[7:])
|
||||
self.http_version = (1, minor)
|
||||
except ValueError:
|
||||
raise InvalidHTTPVersion(version.decode('latin-1'))
|
||||
else:
|
||||
raise InvalidHTTPVersion(version.decode('latin-1'))
|
||||
|
||||
if self._on_message_begin:
|
||||
self._on_message_begin()
|
||||
@ -174,18 +234,34 @@ class PythonProtocol:
|
||||
self._finalize_headers()
|
||||
return True
|
||||
|
||||
# Check header field size limit
|
||||
if self._limit_request_field_size > 0 and len(line) > self._limit_request_field_size:
|
||||
raise LimitRequestHeaders("Request header field is too large")
|
||||
|
||||
# Check header count limit
|
||||
self._header_count += 1
|
||||
if self._limit_request_fields > 0 and self._header_count > self._limit_request_fields:
|
||||
raise LimitRequestHeaders("Too many headers")
|
||||
|
||||
# Parse header
|
||||
colon = line.find(b':')
|
||||
if colon == -1:
|
||||
raise ParseError("Invalid header")
|
||||
raise InvalidHeader("Missing colon in header")
|
||||
|
||||
name = line[:colon].strip()
|
||||
if not self._is_valid_token(name):
|
||||
raise InvalidHeaderName(name.decode('latin-1'))
|
||||
|
||||
name = line[:colon].strip().lower()
|
||||
value = line[colon + 1:].strip()
|
||||
if self._has_invalid_header_chars(value):
|
||||
raise InvalidHeader("Invalid characters in header value")
|
||||
|
||||
self._headers_list.append((name, value))
|
||||
# Store lowercase name for internal use
|
||||
name_lower = name.lower()
|
||||
self._headers_list.append((name_lower, value))
|
||||
|
||||
if self._on_header:
|
||||
self._on_header(name, value)
|
||||
self._on_header(name_lower, value)
|
||||
|
||||
def _finalize_headers(self):
|
||||
"""Called when all headers received."""
|
||||
@ -329,6 +405,35 @@ class PythonProtocol:
|
||||
|
||||
return False
|
||||
|
||||
def _is_valid_method(self, method):
|
||||
"""Check if method is valid token with conventional restrictions."""
|
||||
if not method:
|
||||
return False
|
||||
# Check length (3-20 chars)
|
||||
if not 3 <= len(method) <= 20:
|
||||
return False
|
||||
# Check for lowercase or # (unconventional)
|
||||
for c in method:
|
||||
if c in b'abcdefghijklmnopqrstuvwxyz#':
|
||||
return False
|
||||
return self._is_valid_token(method)
|
||||
|
||||
def _is_valid_token(self, data):
|
||||
"""Check if data contains only RFC 9110 token characters."""
|
||||
if not data:
|
||||
return False
|
||||
for c in data:
|
||||
if c < 0x21 or c > 0x7e:
|
||||
return False
|
||||
# RFC 9110 delimiters: "(),/:;<=>?@[\]{}
|
||||
if c in b'"(),/:;<=>?@[\\]{}"':
|
||||
return False
|
||||
return True
|
||||
|
||||
def _has_invalid_header_chars(self, value):
|
||||
"""Check for NUL, CR, LF in header value."""
|
||||
return b'\x00' in value or b'\r' in value or b'\n' in value
|
||||
|
||||
|
||||
class CallbackRequest:
|
||||
"""Request object built from callback parser state.
|
||||
|
||||
@ -16,7 +16,8 @@ import time
|
||||
|
||||
from gunicorn.asgi.unreader import AsyncUnreader
|
||||
from gunicorn.asgi.parser import (
|
||||
PythonProtocol, CallbackRequest, ParseError
|
||||
PythonProtocol, CallbackRequest, ParseError,
|
||||
LimitRequestLine, LimitRequestHeaders
|
||||
)
|
||||
from gunicorn.asgi.uwsgi import AsyncUWSGIRequest
|
||||
from gunicorn.http.errors import NoMoreData
|
||||
@ -283,6 +284,7 @@ class ASGIProtocol(asyncio.Protocol):
|
||||
# Class-level cache for H1CProtocol availability
|
||||
_h1c_available = None
|
||||
_h1c_protocol_class = None
|
||||
_h1c_has_limits = False # True if >= 0.4.1 (has limit parameters)
|
||||
|
||||
def __init__(self, worker):
|
||||
self.worker = worker
|
||||
@ -354,40 +356,73 @@ class ASGIProtocol(asyncio.Protocol):
|
||||
"""Check if H1CProtocol is available (cached at class level)."""
|
||||
if cls._h1c_available is None:
|
||||
try:
|
||||
import gunicorn_h1c
|
||||
from gunicorn_h1c import H1CProtocol
|
||||
cls._h1c_available = True
|
||||
cls._h1c_protocol_class = H1CProtocol
|
||||
# Require >= 0.4.1 for limit enforcement
|
||||
cls._h1c_has_limits = hasattr(gunicorn_h1c, 'LimitRequestLine')
|
||||
except ImportError:
|
||||
cls._h1c_available = False
|
||||
cls._h1c_has_limits = False
|
||||
return cls._h1c_available
|
||||
|
||||
# Compatibility flags not supported by the fast parser
|
||||
_FAST_PARSER_INCOMPATIBLE_FLAGS = (
|
||||
'permit_obsolete_folding',
|
||||
'strip_header_spaces',
|
||||
)
|
||||
|
||||
def _setup_callback_parser(self):
|
||||
"""Create callback parser based on http_parser setting.
|
||||
|
||||
Parser selection:
|
||||
- auto: Use H1CProtocol if available, else PythonProtocol
|
||||
- fast: Require H1CProtocol (error if unavailable)
|
||||
- auto: Use H1CProtocol if available (>= 0.4.1) and no incompatible flags, else PythonProtocol
|
||||
- fast: Require H1CProtocol >= 0.4.1 (error if unavailable or incompatible flags)
|
||||
- python: Use PythonProtocol only
|
||||
"""
|
||||
parser_setting = getattr(self.cfg, 'http_parser', 'auto')
|
||||
|
||||
# Check for incompatible compatibility flags
|
||||
incompatible = []
|
||||
for flag in self._FAST_PARSER_INCOMPATIBLE_FLAGS:
|
||||
if getattr(self.cfg, flag, False):
|
||||
incompatible.append(flag)
|
||||
|
||||
if parser_setting == 'python':
|
||||
parser_class = PythonProtocol
|
||||
elif parser_setting == 'fast':
|
||||
if not self._check_h1c_protocol_available():
|
||||
raise RuntimeError("gunicorn_h1c required for http_parser='fast'")
|
||||
if not ASGIProtocol._h1c_has_limits:
|
||||
raise RuntimeError(
|
||||
"gunicorn_h1c >= 0.4.1 required for http_parser='fast'. "
|
||||
"Please upgrade: pip install --upgrade gunicorn_h1c"
|
||||
)
|
||||
if incompatible:
|
||||
raise RuntimeError(
|
||||
"http_parser='fast' is incompatible with compatibility flags: %s. "
|
||||
"Use http_parser='python' or disable these flags."
|
||||
% ', '.join(incompatible)
|
||||
)
|
||||
parser_class = ASGIProtocol._h1c_protocol_class
|
||||
else: # auto
|
||||
if self._check_h1c_protocol_available():
|
||||
if (self._check_h1c_protocol_available() and
|
||||
ASGIProtocol._h1c_has_limits and not incompatible):
|
||||
parser_class = ASGIProtocol._h1c_protocol_class
|
||||
else:
|
||||
parser_class = PythonProtocol
|
||||
|
||||
# Create parser with callbacks
|
||||
# Create parser with callbacks and limit parameters (both parsers support them)
|
||||
self._callback_parser = parser_class(
|
||||
on_headers_complete=self._on_headers_complete,
|
||||
on_body=self._on_body,
|
||||
on_message_complete=self._on_message_complete,
|
||||
limit_request_line=self.cfg.limit_request_line,
|
||||
limit_request_fields=self.cfg.limit_request_fields,
|
||||
limit_request_field_size=self.cfg.limit_request_field_size,
|
||||
permit_unconventional_http_method=self.cfg.permit_unconventional_http_method,
|
||||
permit_unconventional_http_version=self.cfg.permit_unconventional_http_version,
|
||||
)
|
||||
|
||||
def _on_headers_complete(self):
|
||||
@ -426,6 +461,14 @@ class ASGIProtocol(asyncio.Protocol):
|
||||
# HTTP/1.x path - feed directly to callback parser
|
||||
try:
|
||||
self._callback_parser.feed(data)
|
||||
except LimitRequestLine as e:
|
||||
self._send_error_response(414, str(e)) # URI Too Long
|
||||
self._close_transport()
|
||||
return
|
||||
except LimitRequestHeaders as e:
|
||||
self._send_error_response(431, str(e)) # Request Header Fields Too Large
|
||||
self._close_transport()
|
||||
return
|
||||
except ParseError as e:
|
||||
self._send_error_response(400, str(e))
|
||||
self._close_transport()
|
||||
|
||||
@ -25,9 +25,27 @@ from gunicorn.util import bytes_to_str, split_request_uri
|
||||
_fast_parser_available = None
|
||||
_fast_parser_module = None
|
||||
|
||||
# Compatibility flags not supported by the fast parser
|
||||
_FAST_PARSER_INCOMPATIBLE_FLAGS = (
|
||||
'permit_obsolete_folding',
|
||||
'strip_header_spaces',
|
||||
)
|
||||
|
||||
|
||||
def _check_fast_parser(cfg):
|
||||
"""Check if fast C parser is available and should be used."""
|
||||
"""Check if fast C parser is available and should be used.
|
||||
|
||||
Returns False if:
|
||||
- http_parser='python' is explicitly set
|
||||
- gunicorn_h1c is not installed (in 'auto' mode)
|
||||
- gunicorn_h1c < 0.4.1 (in 'auto' mode)
|
||||
- Incompatible compatibility flags are enabled (in 'auto' mode)
|
||||
|
||||
Raises RuntimeError if:
|
||||
- http_parser='fast' but gunicorn_h1c is not installed
|
||||
- http_parser='fast' but gunicorn_h1c < 0.4.1
|
||||
- http_parser='fast' but incompatible flags are enabled
|
||||
"""
|
||||
global _fast_parser_available, _fast_parser_module # pylint: disable=global-statement
|
||||
|
||||
parser_setting = getattr(cfg, 'http_parser', 'auto')
|
||||
@ -45,7 +63,36 @@ def _check_fast_parser(cfg):
|
||||
if not _fast_parser_available and parser_setting == 'fast':
|
||||
raise RuntimeError("gunicorn_h1c not installed but http_parser='fast'")
|
||||
|
||||
return _fast_parser_available
|
||||
if not _fast_parser_available:
|
||||
return False
|
||||
|
||||
# Require >= 0.4.1 for limit enforcement
|
||||
if not hasattr(_fast_parser_module, 'LimitRequestLine'):
|
||||
if parser_setting == 'fast':
|
||||
raise RuntimeError(
|
||||
"gunicorn_h1c >= 0.4.1 required for http_parser='fast'. "
|
||||
"Please upgrade: pip install --upgrade gunicorn_h1c"
|
||||
)
|
||||
# In 'auto' mode, fall back to Python parser
|
||||
return False
|
||||
|
||||
# Check for incompatible compatibility flags
|
||||
incompatible = []
|
||||
for flag in _FAST_PARSER_INCOMPATIBLE_FLAGS:
|
||||
if getattr(cfg, flag, False):
|
||||
incompatible.append(flag)
|
||||
|
||||
if incompatible:
|
||||
if parser_setting == 'fast':
|
||||
raise RuntimeError(
|
||||
"http_parser='fast' is incompatible with compatibility flags: %s. "
|
||||
"Use http_parser='python' or disable these flags."
|
||||
% ', '.join(incompatible)
|
||||
)
|
||||
# In 'auto' mode, fall back to Python parser
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
# PROXY protocol v2 constants
|
||||
@ -378,14 +425,23 @@ class Request(Message):
|
||||
return self._parse_python(unreader, buf)
|
||||
|
||||
def _parse_fast(self, unreader, buf):
|
||||
"""Parse request using fast C parser (gunicorn_h1c)."""
|
||||
"""Parse request using fast C parser (gunicorn_h1c >= 0.4.1)."""
|
||||
# Read until we have complete headers
|
||||
data = bytes(buf)
|
||||
last_len = 0
|
||||
|
||||
while True:
|
||||
try:
|
||||
result = _fast_parser_module.parse_request(data, last_len=last_len)
|
||||
# Pass all limit parameters (guaranteed >= 0.4.1)
|
||||
result = _fast_parser_module.parse_request(
|
||||
data,
|
||||
last_len=last_len,
|
||||
limit_request_line=self.limit_request_line,
|
||||
limit_request_fields=self.limit_request_fields,
|
||||
limit_request_field_size=self.limit_request_field_size,
|
||||
permit_unconventional_http_method=self.cfg.permit_unconventional_http_method,
|
||||
permit_unconventional_http_version=self.cfg.permit_unconventional_http_version,
|
||||
)
|
||||
break
|
||||
except _fast_parser_module.IncompleteError:
|
||||
last_len = len(data)
|
||||
@ -393,6 +449,18 @@ class Request(Message):
|
||||
data = bytes(buf)
|
||||
if len(data) > self.max_buffer_headers + self.limit_request_line:
|
||||
raise LimitRequestHeaders("max buffer headers")
|
||||
except _fast_parser_module.LimitRequestLine as e:
|
||||
raise LimitRequestLine(str(e))
|
||||
except _fast_parser_module.LimitRequestHeaders as e:
|
||||
raise LimitRequestHeaders(str(e))
|
||||
except _fast_parser_module.InvalidRequestMethod as e:
|
||||
raise InvalidRequestMethod(str(e))
|
||||
except _fast_parser_module.InvalidHTTPVersion as e:
|
||||
raise InvalidHTTPVersion(str(e))
|
||||
except _fast_parser_module.InvalidHeaderName as e:
|
||||
raise InvalidHeaderName(str(e))
|
||||
except _fast_parser_module.InvalidHeader as e:
|
||||
raise InvalidHeader(str(e))
|
||||
except _fast_parser_module.ParseError as e:
|
||||
raise InvalidRequestLine(str(e))
|
||||
|
||||
@ -400,14 +468,7 @@ class Request(Message):
|
||||
self.method = bytes_to_str(result['method'])
|
||||
self.uri = bytes_to_str(result['path'])
|
||||
|
||||
# Validate method
|
||||
if not self.cfg.permit_unconventional_http_method:
|
||||
if METHOD_BADCHAR_RE.search(self.method):
|
||||
raise InvalidRequestMethod(self.method)
|
||||
if not 3 <= len(self.method) <= 20:
|
||||
raise InvalidRequestMethod(self.method)
|
||||
if not TOKEN_RE.fullmatch(self.method):
|
||||
raise InvalidRequestMethod(self.method)
|
||||
# Casefold method if configured (validation done by C parser)
|
||||
if self.cfg.casefold_http_method:
|
||||
self.method = self.method.upper()
|
||||
|
||||
@ -422,24 +483,18 @@ class Request(Message):
|
||||
self.query = parts.query or ""
|
||||
self.fragment = parts.fragment or ""
|
||||
|
||||
# Version
|
||||
# Version (validation done by C parser)
|
||||
self.version = (1, result['minor_version'])
|
||||
if not (1, 0) <= self.version < (2, 0):
|
||||
if not self.cfg.permit_unconventional_http_version:
|
||||
raise InvalidHTTPVersion(self.version)
|
||||
|
||||
# Headers - convert bytes to strings with uppercase names
|
||||
# gunicorn_h1c returns headers as (bytes, bytes) tuples
|
||||
# Header name/value validation done by C parser
|
||||
self.headers = []
|
||||
for name_bytes, value_bytes in result['headers']:
|
||||
name = bytes_to_str(name_bytes).upper()
|
||||
value = bytes_to_str(value_bytes)
|
||||
|
||||
# Validate header name
|
||||
if not TOKEN_RE.fullmatch(name):
|
||||
raise InvalidHeaderName(name)
|
||||
|
||||
# Handle underscore in header names
|
||||
# Handle underscore in header names (policy decision, not validation)
|
||||
if "_" in name:
|
||||
forwarder_headers = self.cfg.forwarder_headers
|
||||
if name in forwarder_headers or "*" in forwarder_headers:
|
||||
|
||||
@ -53,7 +53,7 @@ tornado = ["tornado>=6.5.0"]
|
||||
gthread = []
|
||||
setproctitle = ["setproctitle"]
|
||||
http2 = ["h2>=4.1.0"]
|
||||
fast = ["gunicorn_h1c>=0.2.0"]
|
||||
fast = ["gunicorn_h1c>=0.4.1"]
|
||||
testing = [
|
||||
"gevent>=24.10.1",
|
||||
"eventlet>=0.40.3",
|
||||
|
||||
@ -4,3 +4,4 @@ coverage
|
||||
pytest>=7.2.0
|
||||
pytest-cov
|
||||
pytest-asyncio
|
||||
gunicorn_h1c>=0.4.1
|
||||
|
||||
@ -18,7 +18,10 @@ if tests_dir not in sys.path:
|
||||
|
||||
@pytest.fixture(params=["python", "fast"])
|
||||
def http_parser(request):
|
||||
"""Parametrize tests over ASGI http_parser implementations."""
|
||||
"""Parametrize tests over http_parser implementations."""
|
||||
if request.param == "fast":
|
||||
pytest.importorskip("gunicorn_h1c", reason="gunicorn_h1c required")
|
||||
gunicorn_h1c = pytest.importorskip("gunicorn_h1c", reason="gunicorn_h1c required")
|
||||
# Require >= 0.4.1 for limit enforcement
|
||||
if not hasattr(gunicorn_h1c, 'LimitRequestLine'):
|
||||
pytest.skip("gunicorn_h1c >= 0.4.1 required")
|
||||
return request.param
|
||||
|
||||
@ -13,13 +13,24 @@ dirname = os.path.dirname(__file__)
|
||||
reqdir = os.path.join(dirname, "requests", "invalid")
|
||||
httpfiles = glob.glob(os.path.join(reqdir, "*.http"))
|
||||
|
||||
# Flags incompatible with fast parser
|
||||
_FAST_INCOMPATIBLE_FLAGS = ('permit_obsolete_folding', 'strip_header_spaces')
|
||||
|
||||
|
||||
@pytest.mark.parametrize("fname", httpfiles)
|
||||
def test_http_parser(fname):
|
||||
env = treq.load_py(os.path.splitext(fname)[0] + ".py")
|
||||
def test_http_parser(fname, http_parser):
|
||||
"""Test invalid HTTP requests with both parser implementations."""
|
||||
env = treq.load_py(os.path.splitext(fname)[0] + ".py", http_parser=http_parser)
|
||||
|
||||
expect = env["request"]
|
||||
cfg = env["cfg"]
|
||||
|
||||
# Skip fast parser tests that use incompatible compatibility flags
|
||||
if http_parser == 'fast':
|
||||
for flag in _FAST_INCOMPATIBLE_FLAGS:
|
||||
if getattr(cfg, flag, False):
|
||||
pytest.skip(f"fast parser incompatible with {flag}")
|
||||
|
||||
req = treq.badrequest(fname)
|
||||
|
||||
with pytest.raises(expect):
|
||||
|
||||
@ -13,13 +13,24 @@ dirname = os.path.dirname(__file__)
|
||||
reqdir = os.path.join(dirname, "requests", "valid")
|
||||
httpfiles = glob.glob(os.path.join(reqdir, "*.http"))
|
||||
|
||||
# Flags incompatible with fast parser
|
||||
_FAST_INCOMPATIBLE_FLAGS = ('permit_obsolete_folding', 'strip_header_spaces')
|
||||
|
||||
|
||||
@pytest.mark.parametrize("fname", httpfiles)
|
||||
def test_http_parser(fname):
|
||||
env = treq.load_py(os.path.splitext(fname)[0] + ".py")
|
||||
def test_http_parser(fname, http_parser):
|
||||
"""Test valid HTTP requests with both parser implementations."""
|
||||
env = treq.load_py(os.path.splitext(fname)[0] + ".py", http_parser=http_parser)
|
||||
|
||||
expect = env['request']
|
||||
cfg = env['cfg']
|
||||
|
||||
# Skip fast parser tests that use incompatible compatibility flags
|
||||
if http_parser == 'fast':
|
||||
for flag in _FAST_INCOMPATIBLE_FLAGS:
|
||||
if getattr(cfg, flag, False):
|
||||
pytest.skip(f"fast parser incompatible with {flag}")
|
||||
|
||||
req = treq.request(fname, expect)
|
||||
|
||||
for case in req.gen_cases(cfg):
|
||||
|
||||
@ -33,22 +33,26 @@ def uri(data):
|
||||
return ret
|
||||
|
||||
|
||||
def load_py(fname):
|
||||
def load_py(fname, http_parser='python'):
|
||||
"""Load test configuration from Python file.
|
||||
|
||||
Args:
|
||||
fname: Path to the .py configuration file
|
||||
http_parser: Parser to use - 'python' or 'fast'
|
||||
"""
|
||||
module_name = '__config__'
|
||||
mod = types.ModuleType(module_name)
|
||||
setattr(mod, 'uri', uri)
|
||||
setattr(mod, 'cfg', Config())
|
||||
loader = importlib.machinery.SourceFileLoader(module_name, fname)
|
||||
loader.exec_module(mod)
|
||||
# Use Python parser for tests to ensure consistent validation behavior
|
||||
# (set after loading so test-specific configs don't override)
|
||||
mod.cfg.set('http_parser', 'python')
|
||||
# Set parser after loading so test-specific configs don't override
|
||||
mod.cfg.set('http_parser', http_parser)
|
||||
return vars(mod)
|
||||
|
||||
|
||||
def decode_hex_escapes(data):
|
||||
"""Decode hex escape sequences like \\xAB in test data."""
|
||||
import re
|
||||
result = bytearray()
|
||||
i = 0
|
||||
while i < len(data):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user