From 03cc85ef48ae6708842083dc92bb1ccb5f7a60ec Mon Sep 17 00:00:00 2001 From: Benoit Chesneau Date: Sun, 22 Mar 2026 13:41:42 +0100 Subject: [PATCH] Integrate gunicorn_h1c 0.4.1 exception types and limit parameters Require gunicorn_h1c >= 0.4.1 for fast parser mode. Add new exception types and limit parameters to PythonProtocol for parity with C parser. Update tests to parametrize across both parser implementations. --- gunicorn/asgi/parser.py | 117 +++++++++++++++++++++++++++++++-- gunicorn/asgi/protocol.py | 53 +++++++++++++-- gunicorn/http/message.py | 97 +++++++++++++++++++++------ pyproject.toml | 2 +- requirements_test.txt | 1 + tests/conftest.py | 7 +- tests/test_invalid_requests.py | 15 ++++- tests/test_valid_requests.py | 15 ++++- tests/treq.py | 14 ++-- 9 files changed, 277 insertions(+), 44 deletions(-) diff --git a/gunicorn/asgi/parser.py b/gunicorn/asgi/parser.py index 088a06fb..000a41a6 100644 --- a/gunicorn/asgi/parser.py +++ b/gunicorn/asgi/parser.py @@ -11,7 +11,31 @@ or the pure Python PythonProtocol fallback. class ParseError(Exception): - """Error raised during HTTP parsing.""" + """Base error raised during HTTP parsing.""" + + +class LimitRequestLine(ParseError): + """Request line exceeds configured limit.""" + + +class LimitRequestHeaders(ParseError): + """Too many headers or header field too large.""" + + +class InvalidRequestMethod(ParseError): + """Invalid HTTP method.""" + + +class InvalidHTTPVersion(ParseError): + """Invalid HTTP version.""" + + +class InvalidHeaderName(ParseError): + """Invalid header name.""" + + +class InvalidHeader(ParseError): + """Invalid header value.""" class PythonProtocol: @@ -37,6 +61,9 @@ class PythonProtocol: 'content_length', 'is_chunked', 'should_keep_alive', 'is_complete', '_body_remaining', '_skip_body', '_chunk_state', '_chunk_size', '_chunk_remaining', + '_limit_request_line', '_limit_request_fields', '_limit_request_field_size', + '_permit_unconventional_http_method', '_permit_unconventional_http_version', + '_header_count', ) def __init__( @@ -47,6 +74,11 @@ class PythonProtocol: on_headers_complete=None, on_body=None, on_message_complete=None, + limit_request_line=8190, + limit_request_fields=100, + limit_request_field_size=8190, + permit_unconventional_http_method=False, + permit_unconventional_http_version=False, ): self._on_message_begin = on_message_begin self._on_url = on_url @@ -55,6 +87,14 @@ class PythonProtocol: self._on_body = on_body self._on_message_complete = on_message_complete + # Store limits + self._limit_request_line = limit_request_line + self._limit_request_fields = limit_request_fields + self._limit_request_field_size = limit_request_field_size + self._permit_unconventional_http_method = permit_unconventional_http_method + self._permit_unconventional_http_version = permit_unconventional_http_version + self._header_count = 0 + # Parser state: request_line, headers, body, chunked_size, chunked_data, complete self._state = 'request_line' self._buffer = bytearray() @@ -124,6 +164,7 @@ class PythonProtocol: self._chunk_state = 'size' self._chunk_size = 0 self._chunk_remaining = 0 + self._header_count = 0 def _parse_request_line(self): """Parse request line, return True if complete.""" @@ -131,6 +172,10 @@ class PythonProtocol: if idx == -1: return False + # Check request line length limit + if self._limit_request_line > 0 and idx > self._limit_request_line: + raise LimitRequestLine("Request line is too large") + line = bytes(self._buffer[:idx]) del self._buffer[:idx + 2] @@ -142,6 +187,11 @@ class PythonProtocol: self.method = parts[0] self.path = parts[1] + # Validate method + if not self._permit_unconventional_http_method: + if not self._is_valid_method(self.method): + raise InvalidRequestMethod(self.method.decode('latin-1')) + # Parse version version = parts[2] if version == b'HTTP/1.1': @@ -149,7 +199,17 @@ class PythonProtocol: elif version == b'HTTP/1.0': self.http_version = (1, 0) else: - raise ParseError("Unsupported HTTP version") + if not self._permit_unconventional_http_version: + raise InvalidHTTPVersion(version.decode('latin-1')) + # Try to parse other HTTP/1.x versions if permitted + if version.startswith(b'HTTP/1.'): + try: + minor = int(version[7:]) + self.http_version = (1, minor) + except ValueError: + raise InvalidHTTPVersion(version.decode('latin-1')) + else: + raise InvalidHTTPVersion(version.decode('latin-1')) if self._on_message_begin: self._on_message_begin() @@ -174,18 +234,34 @@ class PythonProtocol: self._finalize_headers() return True + # Check header field size limit + if self._limit_request_field_size > 0 and len(line) > self._limit_request_field_size: + raise LimitRequestHeaders("Request header field is too large") + + # Check header count limit + self._header_count += 1 + if self._limit_request_fields > 0 and self._header_count > self._limit_request_fields: + raise LimitRequestHeaders("Too many headers") + # Parse header colon = line.find(b':') if colon == -1: - raise ParseError("Invalid header") + raise InvalidHeader("Missing colon in header") + + name = line[:colon].strip() + if not self._is_valid_token(name): + raise InvalidHeaderName(name.decode('latin-1')) - name = line[:colon].strip().lower() value = line[colon + 1:].strip() + if self._has_invalid_header_chars(value): + raise InvalidHeader("Invalid characters in header value") - self._headers_list.append((name, value)) + # Store lowercase name for internal use + name_lower = name.lower() + self._headers_list.append((name_lower, value)) if self._on_header: - self._on_header(name, value) + self._on_header(name_lower, value) def _finalize_headers(self): """Called when all headers received.""" @@ -329,6 +405,35 @@ class PythonProtocol: return False + def _is_valid_method(self, method): + """Check if method is valid token with conventional restrictions.""" + if not method: + return False + # Check length (3-20 chars) + if not 3 <= len(method) <= 20: + return False + # Check for lowercase or # (unconventional) + for c in method: + if c in b'abcdefghijklmnopqrstuvwxyz#': + return False + return self._is_valid_token(method) + + def _is_valid_token(self, data): + """Check if data contains only RFC 9110 token characters.""" + if not data: + return False + for c in data: + if c < 0x21 or c > 0x7e: + return False + # RFC 9110 delimiters: "(),/:;<=>?@[\]{} + if c in b'"(),/:;<=>?@[\\]{}"': + return False + return True + + def _has_invalid_header_chars(self, value): + """Check for NUL, CR, LF in header value.""" + return b'\x00' in value or b'\r' in value or b'\n' in value + class CallbackRequest: """Request object built from callback parser state. diff --git a/gunicorn/asgi/protocol.py b/gunicorn/asgi/protocol.py index a3d15e39..0e29038e 100644 --- a/gunicorn/asgi/protocol.py +++ b/gunicorn/asgi/protocol.py @@ -16,7 +16,8 @@ import time from gunicorn.asgi.unreader import AsyncUnreader from gunicorn.asgi.parser import ( - PythonProtocol, CallbackRequest, ParseError + PythonProtocol, CallbackRequest, ParseError, + LimitRequestLine, LimitRequestHeaders ) from gunicorn.asgi.uwsgi import AsyncUWSGIRequest from gunicorn.http.errors import NoMoreData @@ -283,6 +284,7 @@ class ASGIProtocol(asyncio.Protocol): # Class-level cache for H1CProtocol availability _h1c_available = None _h1c_protocol_class = None + _h1c_has_limits = False # True if >= 0.4.1 (has limit parameters) def __init__(self, worker): self.worker = worker @@ -354,40 +356,73 @@ class ASGIProtocol(asyncio.Protocol): """Check if H1CProtocol is available (cached at class level).""" if cls._h1c_available is None: try: + import gunicorn_h1c from gunicorn_h1c import H1CProtocol cls._h1c_available = True cls._h1c_protocol_class = H1CProtocol + # Require >= 0.4.1 for limit enforcement + cls._h1c_has_limits = hasattr(gunicorn_h1c, 'LimitRequestLine') except ImportError: cls._h1c_available = False + cls._h1c_has_limits = False return cls._h1c_available + # Compatibility flags not supported by the fast parser + _FAST_PARSER_INCOMPATIBLE_FLAGS = ( + 'permit_obsolete_folding', + 'strip_header_spaces', + ) + def _setup_callback_parser(self): """Create callback parser based on http_parser setting. Parser selection: - - auto: Use H1CProtocol if available, else PythonProtocol - - fast: Require H1CProtocol (error if unavailable) + - auto: Use H1CProtocol if available (>= 0.4.1) and no incompatible flags, else PythonProtocol + - fast: Require H1CProtocol >= 0.4.1 (error if unavailable or incompatible flags) - python: Use PythonProtocol only """ parser_setting = getattr(self.cfg, 'http_parser', 'auto') + # Check for incompatible compatibility flags + incompatible = [] + for flag in self._FAST_PARSER_INCOMPATIBLE_FLAGS: + if getattr(self.cfg, flag, False): + incompatible.append(flag) + if parser_setting == 'python': parser_class = PythonProtocol elif parser_setting == 'fast': if not self._check_h1c_protocol_available(): raise RuntimeError("gunicorn_h1c required for http_parser='fast'") + if not ASGIProtocol._h1c_has_limits: + raise RuntimeError( + "gunicorn_h1c >= 0.4.1 required for http_parser='fast'. " + "Please upgrade: pip install --upgrade gunicorn_h1c" + ) + if incompatible: + raise RuntimeError( + "http_parser='fast' is incompatible with compatibility flags: %s. " + "Use http_parser='python' or disable these flags." + % ', '.join(incompatible) + ) parser_class = ASGIProtocol._h1c_protocol_class else: # auto - if self._check_h1c_protocol_available(): + if (self._check_h1c_protocol_available() and + ASGIProtocol._h1c_has_limits and not incompatible): parser_class = ASGIProtocol._h1c_protocol_class else: parser_class = PythonProtocol - # Create parser with callbacks + # Create parser with callbacks and limit parameters (both parsers support them) self._callback_parser = parser_class( on_headers_complete=self._on_headers_complete, on_body=self._on_body, on_message_complete=self._on_message_complete, + limit_request_line=self.cfg.limit_request_line, + limit_request_fields=self.cfg.limit_request_fields, + limit_request_field_size=self.cfg.limit_request_field_size, + permit_unconventional_http_method=self.cfg.permit_unconventional_http_method, + permit_unconventional_http_version=self.cfg.permit_unconventional_http_version, ) def _on_headers_complete(self): @@ -426,6 +461,14 @@ class ASGIProtocol(asyncio.Protocol): # HTTP/1.x path - feed directly to callback parser try: self._callback_parser.feed(data) + except LimitRequestLine as e: + self._send_error_response(414, str(e)) # URI Too Long + self._close_transport() + return + except LimitRequestHeaders as e: + self._send_error_response(431, str(e)) # Request Header Fields Too Large + self._close_transport() + return except ParseError as e: self._send_error_response(400, str(e)) self._close_transport() diff --git a/gunicorn/http/message.py b/gunicorn/http/message.py index d9050ddf..5c6df5eb 100644 --- a/gunicorn/http/message.py +++ b/gunicorn/http/message.py @@ -25,9 +25,27 @@ from gunicorn.util import bytes_to_str, split_request_uri _fast_parser_available = None _fast_parser_module = None +# Compatibility flags not supported by the fast parser +_FAST_PARSER_INCOMPATIBLE_FLAGS = ( + 'permit_obsolete_folding', + 'strip_header_spaces', +) + def _check_fast_parser(cfg): - """Check if fast C parser is available and should be used.""" + """Check if fast C parser is available and should be used. + + Returns False if: + - http_parser='python' is explicitly set + - gunicorn_h1c is not installed (in 'auto' mode) + - gunicorn_h1c < 0.4.1 (in 'auto' mode) + - Incompatible compatibility flags are enabled (in 'auto' mode) + + Raises RuntimeError if: + - http_parser='fast' but gunicorn_h1c is not installed + - http_parser='fast' but gunicorn_h1c < 0.4.1 + - http_parser='fast' but incompatible flags are enabled + """ global _fast_parser_available, _fast_parser_module # pylint: disable=global-statement parser_setting = getattr(cfg, 'http_parser', 'auto') @@ -45,7 +63,36 @@ def _check_fast_parser(cfg): if not _fast_parser_available and parser_setting == 'fast': raise RuntimeError("gunicorn_h1c not installed but http_parser='fast'") - return _fast_parser_available + if not _fast_parser_available: + return False + + # Require >= 0.4.1 for limit enforcement + if not hasattr(_fast_parser_module, 'LimitRequestLine'): + if parser_setting == 'fast': + raise RuntimeError( + "gunicorn_h1c >= 0.4.1 required for http_parser='fast'. " + "Please upgrade: pip install --upgrade gunicorn_h1c" + ) + # In 'auto' mode, fall back to Python parser + return False + + # Check for incompatible compatibility flags + incompatible = [] + for flag in _FAST_PARSER_INCOMPATIBLE_FLAGS: + if getattr(cfg, flag, False): + incompatible.append(flag) + + if incompatible: + if parser_setting == 'fast': + raise RuntimeError( + "http_parser='fast' is incompatible with compatibility flags: %s. " + "Use http_parser='python' or disable these flags." + % ', '.join(incompatible) + ) + # In 'auto' mode, fall back to Python parser + return False + + return True # PROXY protocol v2 constants @@ -378,14 +425,23 @@ class Request(Message): return self._parse_python(unreader, buf) def _parse_fast(self, unreader, buf): - """Parse request using fast C parser (gunicorn_h1c).""" + """Parse request using fast C parser (gunicorn_h1c >= 0.4.1).""" # Read until we have complete headers data = bytes(buf) last_len = 0 while True: try: - result = _fast_parser_module.parse_request(data, last_len=last_len) + # Pass all limit parameters (guaranteed >= 0.4.1) + result = _fast_parser_module.parse_request( + data, + last_len=last_len, + limit_request_line=self.limit_request_line, + limit_request_fields=self.limit_request_fields, + limit_request_field_size=self.limit_request_field_size, + permit_unconventional_http_method=self.cfg.permit_unconventional_http_method, + permit_unconventional_http_version=self.cfg.permit_unconventional_http_version, + ) break except _fast_parser_module.IncompleteError: last_len = len(data) @@ -393,6 +449,18 @@ class Request(Message): data = bytes(buf) if len(data) > self.max_buffer_headers + self.limit_request_line: raise LimitRequestHeaders("max buffer headers") + except _fast_parser_module.LimitRequestLine as e: + raise LimitRequestLine(str(e)) + except _fast_parser_module.LimitRequestHeaders as e: + raise LimitRequestHeaders(str(e)) + except _fast_parser_module.InvalidRequestMethod as e: + raise InvalidRequestMethod(str(e)) + except _fast_parser_module.InvalidHTTPVersion as e: + raise InvalidHTTPVersion(str(e)) + except _fast_parser_module.InvalidHeaderName as e: + raise InvalidHeaderName(str(e)) + except _fast_parser_module.InvalidHeader as e: + raise InvalidHeader(str(e)) except _fast_parser_module.ParseError as e: raise InvalidRequestLine(str(e)) @@ -400,14 +468,7 @@ class Request(Message): self.method = bytes_to_str(result['method']) self.uri = bytes_to_str(result['path']) - # Validate method - if not self.cfg.permit_unconventional_http_method: - if METHOD_BADCHAR_RE.search(self.method): - raise InvalidRequestMethod(self.method) - if not 3 <= len(self.method) <= 20: - raise InvalidRequestMethod(self.method) - if not TOKEN_RE.fullmatch(self.method): - raise InvalidRequestMethod(self.method) + # Casefold method if configured (validation done by C parser) if self.cfg.casefold_http_method: self.method = self.method.upper() @@ -422,24 +483,18 @@ class Request(Message): self.query = parts.query or "" self.fragment = parts.fragment or "" - # Version + # Version (validation done by C parser) self.version = (1, result['minor_version']) - if not (1, 0) <= self.version < (2, 0): - if not self.cfg.permit_unconventional_http_version: - raise InvalidHTTPVersion(self.version) # Headers - convert bytes to strings with uppercase names # gunicorn_h1c returns headers as (bytes, bytes) tuples + # Header name/value validation done by C parser self.headers = [] for name_bytes, value_bytes in result['headers']: name = bytes_to_str(name_bytes).upper() value = bytes_to_str(value_bytes) - # Validate header name - if not TOKEN_RE.fullmatch(name): - raise InvalidHeaderName(name) - - # Handle underscore in header names + # Handle underscore in header names (policy decision, not validation) if "_" in name: forwarder_headers = self.cfg.forwarder_headers if name in forwarder_headers or "*" in forwarder_headers: diff --git a/pyproject.toml b/pyproject.toml index 6b8bc002..2532c830 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,7 @@ tornado = ["tornado>=6.5.0"] gthread = [] setproctitle = ["setproctitle"] http2 = ["h2>=4.1.0"] -fast = ["gunicorn_h1c>=0.2.0"] +fast = ["gunicorn_h1c>=0.4.1"] testing = [ "gevent>=24.10.1", "eventlet>=0.40.3", diff --git a/requirements_test.txt b/requirements_test.txt index efa91f20..4f98848f 100644 --- a/requirements_test.txt +++ b/requirements_test.txt @@ -4,3 +4,4 @@ coverage pytest>=7.2.0 pytest-cov pytest-asyncio +gunicorn_h1c>=0.4.1 diff --git a/tests/conftest.py b/tests/conftest.py index 85fda072..56509c95 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,7 +18,10 @@ if tests_dir not in sys.path: @pytest.fixture(params=["python", "fast"]) def http_parser(request): - """Parametrize tests over ASGI http_parser implementations.""" + """Parametrize tests over http_parser implementations.""" if request.param == "fast": - pytest.importorskip("gunicorn_h1c", reason="gunicorn_h1c required") + gunicorn_h1c = pytest.importorskip("gunicorn_h1c", reason="gunicorn_h1c required") + # Require >= 0.4.1 for limit enforcement + if not hasattr(gunicorn_h1c, 'LimitRequestLine'): + pytest.skip("gunicorn_h1c >= 0.4.1 required") return request.param diff --git a/tests/test_invalid_requests.py b/tests/test_invalid_requests.py index 63224d07..65119dc1 100644 --- a/tests/test_invalid_requests.py +++ b/tests/test_invalid_requests.py @@ -13,13 +13,24 @@ dirname = os.path.dirname(__file__) reqdir = os.path.join(dirname, "requests", "invalid") httpfiles = glob.glob(os.path.join(reqdir, "*.http")) +# Flags incompatible with fast parser +_FAST_INCOMPATIBLE_FLAGS = ('permit_obsolete_folding', 'strip_header_spaces') + @pytest.mark.parametrize("fname", httpfiles) -def test_http_parser(fname): - env = treq.load_py(os.path.splitext(fname)[0] + ".py") +def test_http_parser(fname, http_parser): + """Test invalid HTTP requests with both parser implementations.""" + env = treq.load_py(os.path.splitext(fname)[0] + ".py", http_parser=http_parser) expect = env["request"] cfg = env["cfg"] + + # Skip fast parser tests that use incompatible compatibility flags + if http_parser == 'fast': + for flag in _FAST_INCOMPATIBLE_FLAGS: + if getattr(cfg, flag, False): + pytest.skip(f"fast parser incompatible with {flag}") + req = treq.badrequest(fname) with pytest.raises(expect): diff --git a/tests/test_valid_requests.py b/tests/test_valid_requests.py index 2c71622c..6bfa2229 100644 --- a/tests/test_valid_requests.py +++ b/tests/test_valid_requests.py @@ -13,13 +13,24 @@ dirname = os.path.dirname(__file__) reqdir = os.path.join(dirname, "requests", "valid") httpfiles = glob.glob(os.path.join(reqdir, "*.http")) +# Flags incompatible with fast parser +_FAST_INCOMPATIBLE_FLAGS = ('permit_obsolete_folding', 'strip_header_spaces') + @pytest.mark.parametrize("fname", httpfiles) -def test_http_parser(fname): - env = treq.load_py(os.path.splitext(fname)[0] + ".py") +def test_http_parser(fname, http_parser): + """Test valid HTTP requests with both parser implementations.""" + env = treq.load_py(os.path.splitext(fname)[0] + ".py", http_parser=http_parser) expect = env['request'] cfg = env['cfg'] + + # Skip fast parser tests that use incompatible compatibility flags + if http_parser == 'fast': + for flag in _FAST_INCOMPATIBLE_FLAGS: + if getattr(cfg, flag, False): + pytest.skip(f"fast parser incompatible with {flag}") + req = treq.request(fname, expect) for case in req.gen_cases(cfg): diff --git a/tests/treq.py b/tests/treq.py index ce5a4901..15148809 100644 --- a/tests/treq.py +++ b/tests/treq.py @@ -33,22 +33,26 @@ def uri(data): return ret -def load_py(fname): +def load_py(fname, http_parser='python'): + """Load test configuration from Python file. + + Args: + fname: Path to the .py configuration file + http_parser: Parser to use - 'python' or 'fast' + """ module_name = '__config__' mod = types.ModuleType(module_name) setattr(mod, 'uri', uri) setattr(mod, 'cfg', Config()) loader = importlib.machinery.SourceFileLoader(module_name, fname) loader.exec_module(mod) - # Use Python parser for tests to ensure consistent validation behavior - # (set after loading so test-specific configs don't override) - mod.cfg.set('http_parser', 'python') + # Set parser after loading so test-specific configs don't override + mod.cfg.set('http_parser', http_parser) return vars(mod) def decode_hex_escapes(data): """Decode hex escape sequences like \\xAB in test data.""" - import re result = bytearray() i = 0 while i < len(data):