#!/usr/bin/env python
"""
Benchmark comparing HTTP parser implementations.

Compares:
- WSGI Python parser vs Fast parser (gunicorn_h1c)
- ASGI Python parser vs Fast parser (gunicorn_h1c)

Usage:
    python benchmarks/http_parser_benchmark.py
"""

import io
import time
import statistics
from typing import NamedTuple, Optional

from gunicorn.config import Config
from gunicorn.http.message import Request, _check_fast_parser
from gunicorn.http.unreader import IterUnreader


# Check if fast parser is available
try:
    import gunicorn_h1c
    FAST_AVAILABLE = True
except ImportError:
    FAST_AVAILABLE = False
    print("WARNING: gunicorn_h1c not installed. Fast parser benchmarks will be skipped.")
    print("Install with: pip install gunicorn_h1c\n")


class BenchmarkResult(NamedTuple):
    """Aggregated timings for one parser on one request fixture."""

    name: str
    iterations: int
    total_time: float
    avg_time_us: float
    min_time_us: float
    max_time_us: float
    requests_per_sec: float


# Test requests of varying complexity
SIMPLE_REQUEST = b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n"

MEDIUM_REQUEST = b"""POST /api/users HTTP/1.1\r
Host: api.example.com\r
Content-Type: application/json\r
Content-Length: 42\r
Accept: application/json\r
Authorization: Bearer token123\r
X-Request-ID: abc-123-def-456\r
\r
"""

COMPLEX_REQUEST = b"""POST /api/v2/resources/items HTTP/1.1\r
Host: api.example.com\r
Content-Type: application/json; charset=utf-8\r
Content-Length: 1024\r
Accept: application/json, text/plain, */*\r
Accept-Language: en-US,en;q=0.9,fr;q=0.8\r
Accept-Encoding: gzip, deflate, br\r
Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ\r
X-Request-ID: 550e8400-e29b-41d4-a716-446655440000\r
X-Correlation-ID: 7f3d8c2a-1b4e-4a6f-9c8d-2e5f6a7b8c9d\r
X-Forwarded-For: 203.0.113.195, 70.41.3.18, 150.172.238.178\r
X-Forwarded-Proto: https\r
X-Real-IP: 203.0.113.195\r
User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36\r
Cache-Control: no-cache, no-store, must-revalidate\r
Pragma: no-cache\r
Cookie: session=abc123; preferences=dark_mode\r
If-None-Match: "etag-value-here"\r
If-Modified-Since: Wed, 21 Oct 2024 07:28:00 GMT\r
\r
"""


def create_wsgi_config(use_fast: bool) -> Config:
    """Create a config whose ``http_parser`` is 'fast' or 'python'."""
    cfg = Config()
    cfg.set('http_parser', 'fast' if use_fast else 'python')
    return cfg


def _safe_ratio(numerator: float, denominator: float) -> float:
    """Return numerator/denominator, or 0.0 when the denominator is not positive."""
    return numerator / denominator if denominator > 0 else 0.0


def _summarize(name: str, times: list) -> BenchmarkResult:
    """Fold a list of per-iteration durations (seconds) into a BenchmarkResult."""
    total_time = sum(times)
    return BenchmarkResult(
        name=name,
        iterations=len(times),
        total_time=total_time,
        avg_time_us=statistics.mean(times) * 1_000_000,
        min_time_us=min(times) * 1_000_000,
        max_time_us=max(times) * 1_000_000,
        requests_per_sec=_safe_ratio(len(times), total_time),
    )


def benchmark_wsgi_parser(request_data: bytes, cfg: Config, iterations: int) -> BenchmarkResult:
    """Benchmark the WSGI parser selected by ``cfg.http_parser``.

    Only construction of ``Request`` (which parses the message) is timed.

    Raises:
        RuntimeError: if a parsed request has no method.
    """
    times = []
    parser_type = cfg.http_parser

    for _ in range(iterations):
        # Create fresh unreader for each iteration
        unreader = IterUnreader(iter([request_data]))

        start = time.perf_counter()
        req = Request(cfg, unreader, ('127.0.0.1', 8000), req_number=1)
        end = time.perf_counter()

        times.append(end - start)

        # Explicit check: a bare assert would vanish under ``python -O`` and
        # the benchmark would then happily time failed parses.
        if req.method is None:
            raise RuntimeError(f"WSGI {parser_type} parser produced no method")

    return _summarize(f"WSGI {parser_type}", times)


def benchmark_asgi_parser(request_data: bytes, cfg: Config, iterations: int) -> BenchmarkResult:
    """Benchmark the ASGI callback parser selected by ``cfg.http_parser``.

    Only ``feed()`` is timed; a fresh parser is built for every iteration.

    Raises:
        RuntimeError: if the parser did not populate ``method`` after feed().
    """
    # gunicorn.asgi.parser exposes PythonProtocol (there is no HttpParser in
    # that module), and feed() reports results through parser attributes
    # rather than a return value.
    from gunicorn.asgi.parser import PythonProtocol

    parser_type = cfg.http_parser
    if parser_type == 'fast':
        # NOTE(review): H1CProtocol is the class name used by the gunicorn
        # docs for the C parser; assumed constructible without arguments like
        # PythonProtocol - confirm against gunicorn_h1c.
        parser_cls = gunicorn_h1c.H1CProtocol
    else:
        parser_cls = PythonProtocol

    times = []
    for _ in range(iterations):
        # Create fresh parser for each iteration
        parser = parser_cls()

        start = time.perf_counter()
        parser.feed(request_data)
        end = time.perf_counter()

        times.append(end - start)

        # ``method`` is set as soon as the request line has been parsed, so it
        # is valid even for fixtures that announce a body they never send.
        if parser.method is None:
            raise RuntimeError(f"ASGI {parser_type} parser produced no method")

    return _summarize(f"ASGI {parser_type}", times)


def print_result(result: BenchmarkResult, baseline: Optional[BenchmarkResult] = None):
    """Print one result line, with a speed-up suffix relative to ``baseline``."""
    speedup = ""
    # Both averages must be positive for the ratio (and its inverse) to exist.
    if baseline and baseline.avg_time_us > 0 and result.avg_time_us > 0:
        ratio = baseline.avg_time_us / result.avg_time_us
        if ratio > 1:
            speedup = f" ({ratio:.2f}x faster)"
        elif ratio < 1:
            speedup = f" ({1/ratio:.2f}x slower)"

    print(f"  {result.name:20} {result.avg_time_us:8.2f} us/req  "
          f"({result.requests_per_sec:,.0f} req/s){speedup}")


def run_benchmark_suite(name: str, request_data: bytes, iterations: int):
    """Run WSGI and ASGI benchmarks for one fixture and print the comparison.

    Returns:
        List of BenchmarkResult in the order WSGI python, WSGI fast,
        ASGI python, ASGI fast (fast entries only when gunicorn_h1c exists).
    """
    print(f"\n{'='*60}")
    print(f"Benchmark: {name}")
    print(f"Request size: {len(request_data)} bytes, Iterations: {iterations:,}")
    print('='*60)

    # Run everything first so printing does not interleave with measurements.
    pairs = []
    for bench in (benchmark_wsgi_parser, benchmark_asgi_parser):
        python_result = bench(request_data, create_wsgi_config(use_fast=False), iterations)
        fast_result = None
        if FAST_AVAILABLE:
            fast_result = bench(request_data, create_wsgi_config(use_fast=True), iterations)
        pairs.append((python_result, fast_result))

    # Print results
    print("\nResults (avg time per request):")
    print("-" * 60)

    results = []
    for index, (python_result, fast_result) in enumerate(pairs):
        if index:
            print()  # blank line between the WSGI and ASGI groups
        print_result(python_result)
        results.append(python_result)
        if fast_result is not None:
            print_result(fast_result, python_result)
            results.append(fast_result)

    return results


def _mean_avg_us(results, name: str) -> float:
    """Mean of avg_time_us over results with the given name (0.0 if none)."""
    values = [r.avg_time_us for r in results if r.name == name]
    return statistics.mean(values) if values else 0.0


def main():
    """Warm up, run all fixtures, and print an overall summary."""
    print("HTTP Parser Benchmark")
    print("=" * 60)
    print(f"Fast parser (gunicorn_h1c): {'Available' if FAST_AVAILABLE else 'Not installed'}")

    # Warmup
    print("\nWarming up...")
    warmup_modes = [False, True] if FAST_AVAILABLE else [False]
    for use_fast in warmup_modes:
        cfg = create_wsgi_config(use_fast=use_fast)
        for _ in range(100):
            unreader = IterUnreader(iter([SIMPLE_REQUEST]))
            Request(cfg, unreader, ('127.0.0.1', 8000), req_number=1)

    # Run benchmarks
    iterations = 10000

    all_results = []
    all_results.extend(run_benchmark_suite("Simple GET Request", SIMPLE_REQUEST, iterations))
    all_results.extend(run_benchmark_suite("Medium POST Request", MEDIUM_REQUEST, iterations))
    all_results.extend(run_benchmark_suite("Complex POST Request", COMPLEX_REQUEST, iterations))

    # Summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)

    if FAST_AVAILABLE:
        # Calculate overall speedups
        wsgi_ratio = _safe_ratio(_mean_avg_us(all_results, "WSGI python"),
                                 _mean_avg_us(all_results, "WSGI fast"))
        asgi_ratio = _safe_ratio(_mean_avg_us(all_results, "ASGI python"),
                                 _mean_avg_us(all_results, "ASGI fast"))

        print(f"\nWSGI: Fast parser is {wsgi_ratio:.2f}x faster than Python parser")
        print(f"ASGI: Fast parser is {asgi_ratio:.2f}x faster than Python parser")
    else:
        print("\nInstall gunicorn_h1c to see fast parser comparison:")
        print("  pip install gunicorn_h1c")

    print()


if __name__ == "__main__":
    main()
+ +The parser is automatically used when available (`--http-parser auto`), or you +can explicitly require it: + +```bash +gunicorn myapp:app --worker-class asgi --http-parser fast +``` + +| Parser | Description | +|--------|-------------| +| `auto` | Use fast parser if available, otherwise Python (default) | +| `fast` | Require fast parser, fail if unavailable | +| `python` | Force pure Python parser | + +### Performance Tips + +1. **Use uvloop** for improved event loop performance: + ```bash + pip install uvloop + gunicorn myapp:app --worker-class asgi --asgi-loop uvloop + ``` + +2. **Install the fast parser** for optimized HTTP parsing: + ```bash + pip install gunicorn[fast] + ``` + +3. **Tune worker count** based on CPU cores: + ```bash + gunicorn myapp:app --worker-class asgi --workers $(nproc) + ``` + +4. **Increase connections** for I/O-bound applications: + ```bash + gunicorn myapp:app --worker-class asgi --worker-connections 2000 + ``` + ## Comparison with Other ASGI Servers | Feature | Gunicorn ASGI | Uvicorn | Hypercorn | diff --git a/docs/content/news.md b/docs/content/news.md index 6315a88c..51f37e0c 100644 --- a/docs/content/news.md +++ b/docs/content/news.md @@ -3,6 +3,14 @@ ## unreleased +### New Features + +- **Fast HTTP Parser (gunicorn_h1c 0.4.1)**: Integrate new exception types and limit + parameters from gunicorn_h1c 0.4.1 for both WSGI and ASGI workers + - Requires gunicorn_h1c >= 0.4.1 for `http_parser='fast'` + - Falls back to Python parser in `auto` mode if version not met + - Proper HTTP status codes for limit errors (414, 431) + ### Performance - **ASGI HTTP Parser Optimizations**: Improve ASGI worker HTTP parsing performance diff --git a/docs/content/reference/settings.md b/docs/content/reference/settings.md index ac45ac9f..84ddd16c 100644 --- a/docs/content/reference/settings.md +++ b/docs/content/reference/settings.md @@ -1971,3 +1971,23 @@ need to increase this value. This setting only affects the ``asgi`` worker type. !!! 
info "Added in 25.0.0" + +### `http_parser` + +**Command line:** `--http-parser STRING` + +**Default:** `'auto'` + +HTTP parser implementation for ASGI workers. + +- auto: Use H1CProtocol if gunicorn_h1c is available, else PythonProtocol (default) +- fast: Require H1CProtocol from gunicorn_h1c (fail if unavailable) +- python: Force pure Python PythonProtocol parser + +ASGI workers use callback-based parsing in data_received() for efficient +incremental parsing. The gunicorn_h1c C extension provides significantly +faster HTTP parsing using picohttpparser with SIMD optimizations. + +Install it with: pip install gunicorn[fast] + +!!! info "Added in 25.0.0" diff --git a/examples/embedding_service/Dockerfile b/examples/embedding_service/Dockerfile index b931d6c3..2afb6394 100644 --- a/examples/embedding_service/Dockerfile +++ b/examples/embedding_service/Dockerfile @@ -6,7 +6,9 @@ WORKDIR /app RUN pip install --no-cache-dir \ sentence-transformers \ fastapi \ - pydantic + pydantic \ + pytest \ + requests # Copy gunicorn source COPY . /app/gunicorn-src diff --git a/gunicorn/asgi/parser.py b/gunicorn/asgi/parser.py new file mode 100644 index 00000000..e7432a77 --- /dev/null +++ b/gunicorn/asgi/parser.py @@ -0,0 +1,543 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +HTTP parser for ASGI workers. + +Provides callback-based parsing using either the fast C parser (gunicorn_h1c) +or the pure Python PythonProtocol fallback. 
class ParseError(Exception):
    """Base error raised during HTTP parsing."""


class LimitRequestLine(ParseError):
    """Request line exceeds configured limit."""


class LimitRequestHeaders(ParseError):
    """Too many headers or header field too large."""


class InvalidRequestMethod(ParseError):
    """Invalid HTTP method."""


class InvalidHTTPVersion(ParseError):
    """Invalid HTTP version."""


class InvalidHeaderName(ParseError):
    """Invalid header name."""


class InvalidHeader(ParseError):
    """Invalid header value."""


# Visible-ASCII bytes that RFC 9110 classifies as delimiters, i.e. not tokens.
_TOKEN_DELIMITERS = b'"(),/:;<=>?@[\\]{}'

# The only bytes allowed in a chunk-size field.
_HEX_DIGITS = b'0123456789abcdefABCDEF'


class PythonProtocol:
    """Callback-based HTTP/1.1 parser (pure Python fallback).

    Mirrors H1CProtocol interface for seamless switching between
    the C extension and pure Python implementations.

    Callbacks:
        on_message_begin: () -> None - Called when request starts
        on_url: (url: bytes) -> None - Called with request URL/path
        on_header: (name: bytes, value: bytes) -> None - Called for each header
        on_headers_complete: () -> bool - Called when headers done (return True to skip body)
        on_body: (chunk: bytes) -> None - Called with body data chunks
        on_message_complete: () -> None - Called when request is complete

    A limit of 0 (or less) disables the corresponding check.
    """

    __slots__ = (
        '_on_message_begin', '_on_url', '_on_header',
        '_on_headers_complete', '_on_body', '_on_message_complete',
        '_state', '_buffer', '_headers_list',
        'method', 'path', 'http_version', 'headers',
        'content_length', 'is_chunked', 'should_keep_alive', 'is_complete',
        '_body_remaining', '_skip_body',
        '_chunk_state', '_chunk_size', '_chunk_remaining',
        '_limit_request_line', '_limit_request_fields', '_limit_request_field_size',
        '_permit_unconventional_http_method', '_permit_unconventional_http_version',
        '_header_count',
    )

    def __init__(
        self,
        on_message_begin=None,
        on_url=None,
        on_header=None,
        on_headers_complete=None,
        on_body=None,
        on_message_complete=None,
        limit_request_line=8190,
        limit_request_fields=100,
        limit_request_field_size=8190,
        permit_unconventional_http_method=False,
        permit_unconventional_http_version=False,
    ):
        self._on_message_begin = on_message_begin
        self._on_url = on_url
        self._on_header = on_header
        self._on_headers_complete = on_headers_complete
        self._on_body = on_body
        self._on_message_complete = on_message_complete

        # Store limits
        self._limit_request_line = limit_request_line
        self._limit_request_fields = limit_request_fields
        self._limit_request_field_size = limit_request_field_size
        self._permit_unconventional_http_method = permit_unconventional_http_method
        self._permit_unconventional_http_version = permit_unconventional_http_version

        # All per-request state lives in reset() so construction and
        # keepalive re-use cannot drift apart.
        self._buffer = bytearray()
        self.reset()

    def feed(self, data):
        """Process data, fire callbacks synchronously.

        Args:
            data: bytes or bytearray of incoming data

        Raises:
            ParseError: If the HTTP request is malformed
        """
        self._buffer.extend(data)

        while self._buffer:
            if self._state == 'request_line':
                if not self._parse_request_line():
                    break
            elif self._state == 'headers':
                if not self._parse_headers():
                    break
            elif self._state == 'body':
                if not self._parse_body():
                    break
            elif self._state == 'chunked':
                if not self._parse_chunked_body():
                    break
            else:
                # 'complete': any extra bytes stay buffered until reset().
                break

    def reset(self):
        """Reset for next request (keepalive).

        Clears the internal buffer, so bytes received after the end of the
        previous message are discarded.
        """
        # Parser state: request_line, headers, body, chunked, complete
        self._state = 'request_line'
        self._buffer.clear()
        self._headers_list = []
        self._header_count = 0

        # Request info (populated during parsing)
        self.method = None
        self.path = None
        self.http_version = None
        self.headers = []
        self.content_length = None
        self.is_chunked = False
        self.should_keep_alive = True
        self.is_complete = False

        # Body state
        self._body_remaining = 0
        self._skip_body = False

        # Chunked transfer sub-state: size, data, crlf, trailer
        self._chunk_state = 'size'
        self._chunk_size = 0
        self._chunk_remaining = 0

    def _parse_request_line(self):
        """Parse request line, return True if complete.

        Raises:
            LimitRequestLine: line longer than ``limit_request_line``.
            InvalidRequestMethod / InvalidHTTPVersion: failed validation.
            ParseError: line is not ``METHOD PATH VERSION``.
        """
        limit = self._limit_request_line
        idx = self._buffer.find(b'\r\n')
        if idx == -1:
            # Enforce the limit before the terminator arrives, otherwise a
            # peer that never sends CRLF grows the buffer without bound.
            # "+ 1" allows for a trailing '\r' whose '\n' is still in flight.
            if limit > 0 and len(self._buffer) > limit + 1:
                raise LimitRequestLine("Request line is too large")
            return False

        # Check request line length limit
        if limit > 0 and idx > limit:
            raise LimitRequestLine("Request line is too large")

        line = bytes(self._buffer[:idx])
        del self._buffer[:idx + 2]

        # Parse: METHOD PATH HTTP/x.y
        parts = line.split(b' ', 2)
        if len(parts) != 3:
            raise ParseError("Invalid request line")

        self.method, self.path, version = parts

        # Validate method
        if not self._permit_unconventional_http_method:
            if not self._is_valid_method(self.method):
                raise InvalidRequestMethod(self.method.decode('latin-1'))

        # Parse version
        if version == b'HTTP/1.1':
            self.http_version = (1, 1)
        elif version == b'HTTP/1.0':
            self.http_version = (1, 0)
        else:
            if not self._permit_unconventional_http_version:
                raise InvalidHTTPVersion(version.decode('latin-1'))
            # Try to parse other HTTP/1.x versions if permitted
            if not version.startswith(b'HTTP/1.'):
                raise InvalidHTTPVersion(version.decode('latin-1'))
            try:
                self.http_version = (1, int(version[7:]))
            except ValueError as exc:
                raise InvalidHTTPVersion(version.decode('latin-1')) from exc

        if self._on_message_begin:
            self._on_message_begin()
        if self._on_url:
            self._on_url(self.path)

        self._state = 'headers'
        return True

    def _parse_headers(self):
        """Parse headers, return True if headers are complete.

        Raises:
            LimitRequestHeaders: a field is too large or there are too many.
            InvalidHeaderName: name is not a token (this includes whitespace
                before the colon and folded continuation lines).
            InvalidHeader: missing colon or forbidden bytes in the value.
        """
        size_limit = self._limit_request_field_size
        while True:
            idx = self._buffer.find(b'\r\n')
            if idx == -1:
                # Same reasoning as for the request line: bound the buffer
                # even when the line terminator never shows up.
                if size_limit > 0 and len(self._buffer) > size_limit + 1:
                    raise LimitRequestHeaders("Request header field is too large")
                return False

            line = bytes(self._buffer[:idx])
            del self._buffer[:idx + 2]

            if not line:
                # Empty line = end of headers
                self._finalize_headers()
                return True

            # Check header field size limit
            if size_limit > 0 and len(line) > size_limit:
                raise LimitRequestHeaders("Request header field is too large")

            # Check header count limit
            self._header_count += 1
            if self._limit_request_fields > 0 and self._header_count > self._limit_request_fields:
                raise LimitRequestHeaders("Too many headers")

            # Parse header
            colon = line.find(b':')
            if colon == -1:
                raise InvalidHeader("Missing colon in header")

            # The name is deliberately NOT stripped: whitespace between the
            # field name and the colon must be rejected (RFC 9112 section 5.1)
            # because intermediaries disagree on how to read it.
            name = line[:colon]
            if not self._is_valid_token(name):
                raise InvalidHeaderName(name.decode('latin-1'))

            value = line[colon + 1:].strip()
            if self._has_invalid_header_chars(value):
                raise InvalidHeader("Invalid characters in header value")

            # Store lowercase name for internal use
            name_lower = name.lower()
            self._headers_list.append((name_lower, value))

            if self._on_header:
                self._on_header(name_lower, value)

    def _finalize_headers(self):
        """Derive framing/keepalive info from the headers and pick next state.

        Raises:
            InvalidHeader: Content-Length is not a plain decimal number, or
                several Content-Length headers disagree.
        """
        self.headers = self._headers_list

        # Extract content-length and chunked
        for name, value in self.headers:
            if name == b'content-length':
                # int() alone would accept "-1", "+5", "1_0" and would raise
                # ValueError (not a ParseError) on garbage.
                if not value.isdigit():
                    raise InvalidHeader("Invalid Content-Length value")
                length = int(value)
                if self.content_length is not None and self.content_length != length:
                    raise InvalidHeader("Conflicting Content-Length headers")
                self.content_length = length
                self._body_remaining = length
            elif name == b'transfer-encoding':
                self.is_chunked = b'chunked' in value.lower()
            elif name == b'connection':
                val = value.lower()
                if b'close' in val:
                    self.should_keep_alive = False
                elif b'keep-alive' in val:
                    self.should_keep_alive = True

        # HTTP/1.0 defaults to close
        if self.http_version == (1, 0) and self.should_keep_alive:
            # Only keep-alive if explicitly requested
            has_keepalive = any(
                name == b'connection' and b'keep-alive' in value.lower()
                for name, value in self.headers
            )
            if not has_keepalive:
                self.should_keep_alive = False

        if self._on_headers_complete:
            self._skip_body = self._on_headers_complete()

        # Determine next state
        if self._skip_body:
            self._finish_message()
        elif self.is_chunked:
            self._state = 'chunked'
            self._chunk_state = 'size'
        elif self.content_length and self.content_length > 0:
            self._state = 'body'
        else:
            # No body
            self._finish_message()

    def _finish_message(self):
        """Mark the message complete and fire on_message_complete."""
        self._state = 'complete'
        self.is_complete = True
        if self._on_message_complete:
            self._on_message_complete()

    def _parse_body(self):
        """Parse Content-Length delimited body; True if any data was consumed."""
        if not self._buffer or self._body_remaining <= 0:
            return False

        chunk_size = min(len(self._buffer), self._body_remaining)
        chunk = bytes(self._buffer[:chunk_size])
        del self._buffer[:chunk_size]
        self._body_remaining -= chunk_size

        if self._on_body:
            self._on_body(chunk)

        if self._body_remaining <= 0:
            self._finish_message()

        return True

    def _parse_chunked_body(self):
        """Parse chunked transfer encoding.

        Returns:
            True once the terminating chunk and trailers were consumed,
            False when more data is needed.

        Raises:
            ParseError: malformed chunk size or missing CRLF after a chunk.
        """
        while self._buffer:
            if self._chunk_state == 'size':
                # Looking for chunk size line
                idx = self._buffer.find(b'\r\n')
                if idx == -1:
                    return False

                size_line = bytes(self._buffer[:idx])
                del self._buffer[:idx + 2]

                # Handle chunk extensions (e.g., "5;ext=value")
                size_token = size_line.partition(b';')[0].strip()

                # Only bare hex digits are a chunk size. int(x, 16) would
                # also take "-5", "+5", "0x5" or "5_0"; a negative size then
                # corrupts the slicing arithmetic in the 'data' state.
                if not size_token or any(c not in _HEX_DIGITS for c in size_token):
                    raise ParseError("Invalid chunk size")
                self._chunk_size = int(size_token, 16)

                if self._chunk_size == 0:
                    # Final chunk - skip trailers
                    self._chunk_state = 'trailer'
                else:
                    self._chunk_remaining = self._chunk_size
                    self._chunk_state = 'data'

            elif self._chunk_state == 'data':
                # Reading chunk data
                to_read = min(len(self._buffer), self._chunk_remaining)
                chunk = bytes(self._buffer[:to_read])
                del self._buffer[:to_read]
                self._chunk_remaining -= to_read

                if self._on_body:
                    self._on_body(chunk)

                if self._chunk_remaining == 0:
                    # Need to consume trailing CRLF
                    self._chunk_state = 'crlf'

            elif self._chunk_state == 'crlf':
                # CRLF after chunk data
                if len(self._buffer) < 2:
                    return False
                # Verify instead of blindly dropping two bytes, so a framing
                # error cannot silently desynchronise the stream.
                if self._buffer[:2] != b'\r\n':
                    raise ParseError("Missing CRLF after chunk data")
                del self._buffer[:2]
                self._chunk_state = 'size'

            elif self._chunk_state == 'trailer':
                # Skip trailer headers
                idx = self._buffer.find(b'\r\n')
                if idx == -1:
                    return False

                line = bytes(self._buffer[:idx])
                del self._buffer[:idx + 2]

                if not line:
                    # Empty line = end of trailers
                    self._finish_message()
                    return True

        return False

    def _is_valid_method(self, method):
        """Check if method is valid token with conventional restrictions."""
        # Length must be 3-20 bytes (this also rejects an empty method).
        if not 3 <= len(method) <= 20:
            return False
        # Lowercase letters and '#' are tokens but unconventional in methods.
        if any(c in b'abcdefghijklmnopqrstuvwxyz#' for c in method):
            return False
        return self._is_valid_token(method)

    def _is_valid_token(self, data):
        """Check if data contains only RFC 9110 token characters."""
        if not data:
            return False
        # Visible ASCII (0x21-0x7e) minus the delimiter set.
        return all(0x21 <= c <= 0x7e and c not in _TOKEN_DELIMITERS for c in data)

    def _has_invalid_header_chars(self, value):
        """Check for NUL, CR, LF in header value."""
        return b'\x00' in value or b'\r' in value or b'\n' in value


class CallbackRequest:
    """Request object built from callback parser state.

    Works with both H1CProtocol (C extension) and PythonProtocol.
    """

    __slots__ = (
        'method', 'uri', 'path', 'query', 'fragment', 'version',
        'headers', 'headers_bytes', 'scheme', 'raw_path',
        'content_length', 'chunked', 'must_close',
        'proxy_protocol_info', '_expect_100_continue',
    )

    def __init__(self):
        self.method = None
        self.uri = None
        self.path = None
        self.query = None
        self.fragment = None
        self.version = None
        self.headers = []          # [(UPPERCASE str name, str value)]
        self.headers_bytes = []    # [(bytes name, bytes value)] as parsed
        self.scheme = "http"
        self.raw_path = b''
        self.content_length = 0
        self.chunked = False
        self.must_close = False
        self.proxy_protocol_info = None
        self._expect_100_continue = False

    @classmethod
    def from_parser(cls, parser, is_ssl=False):
        """Build request from callback parser state.

        Args:
            parser: H1CProtocol or PythonProtocol instance
            is_ssl: Whether connection is SSL/TLS

        Returns:
            CallbackRequest instance
        """
        from urllib.parse import unquote_to_bytes

        req = cls()
        req.method = parser.method.decode('ascii')

        # Parse path and query from URL
        # Per ASGI spec:
        # - path: percent-decoded UTF-8 string
        # - raw_path: original bytes as received
        raw_url = parser.path
        # partition() yields an empty query when there is no '?'.
        path_part, _, query_part = raw_url.partition(b'?')
        req.raw_path = path_part  # Store original bytes
        req.path = unquote_to_bytes(path_part).decode('utf-8', errors='replace')
        req.query = query_part.decode('latin-1')

        req.uri = raw_url.decode('latin-1')
        req.fragment = ''
        req.version = parser.http_version

        # Headers - store both bytes (for ASGI scope) and strings (for compatibility)
        req.headers_bytes = list(parser.headers)
        req.headers = [
            (n.decode('latin-1').upper(), v.decode('latin-1'))
            for n, v in parser.headers
        ]

        req.scheme = 'https' if is_ssl else 'http'
        req.content_length = parser.content_length or 0
        req.chunked = parser.is_chunked
        req.must_close = not parser.should_keep_alive

        # Check for Expect: 100-continue
        req._expect_100_continue = any(
            name == b'expect' and value.lower() == b'100-continue'
            for name, value in parser.headers
        )

        return req

    def should_close(self):
        """Check if connection should be closed after this request."""
        if self.must_close:
            return True
        for name, value in self.headers:
            if name == "CONNECTION":
                v = value.lower().strip(" \t")
                if v == "close":
                    return True
                if v == "keep-alive":
                    return False
                # Only the first Connection header is consulted.
                break
        return self.version <= (1, 0)

    def get_header(self, name):
        """Get a header value by name (case-insensitive)."""
        name = name.upper()
        for h, v in self.headers:
            if h == name:
                return v
        return None
import asyncio import errno -from datetime import datetime +import ipaddress +import time from gunicorn.asgi.unreader import AsyncUnreader -from gunicorn.asgi.message import AsyncRequest +from gunicorn.asgi.parser import ( + PythonProtocol, CallbackRequest, ParseError, + LimitRequestLine, LimitRequestHeaders +) from gunicorn.asgi.uwsgi import AsyncUWSGIRequest from gunicorn.http.errors import NoMoreData from gunicorn.uwsgi.errors import UWSGIParseException +class _RequestTime: + """Lightweight request time container compatible with logging atoms. + + Uses time.monotonic() elapsed seconds instead of datetime.now() syscalls. + Provides .seconds and .microseconds attributes for glogging.py compatibility. + """ + + __slots__ = ('seconds', 'microseconds') + + def __init__(self, elapsed): + self.seconds = int(elapsed) + self.microseconds = int((elapsed - self.seconds) * 1_000_000) + + def _normalize_sockaddr(sockaddr): """Normalize socket address to ASGI-compatible (host, port) tuple. @@ -30,6 +48,101 @@ def _normalize_sockaddr(sockaddr): return tuple(sockaddr[:2]) if sockaddr else None +def _check_trusted_proxy(peer_addr, allow_list, networks): + """Check if peer address is in the trusted proxy list. + + Cached at connection start to avoid repeated IP parsing per request. 
class FlowControl:
    """Track read/write pause state for one transport.

    ``drain()`` waits on an internal event that is cleared by
    ``pause_writing()`` and set again by ``resume_writing()``, so writers block
    while the transport reports a full write buffer. The reading helpers
    forward to the transport only when the state actually changes.
    """
    __slots__ = ('_transport', 'read_paused', 'write_paused', '_is_writable_event')

    def __init__(self, transport):
        writable = asyncio.Event()
        writable.set()  # start out writable
        self._transport = transport
        self.read_paused = False
        self.write_paused = False
        self._is_writable_event = writable

    async def drain(self):
        """Block until writing is allowed again."""
        await self._is_writable_event.wait()

    def pause_reading(self):
        """Ask the transport to stop reading (no-op if already paused)."""
        if self.read_paused:
            return
        self.read_paused = True
        self._transport.pause_reading()

    def resume_reading(self):
        """Ask the transport to resume reading (no-op if not paused)."""
        if not self.read_paused:
            return
        self.read_paused = False
        self._transport.resume_reading()

    def pause_writing(self):
        """Mark the transport as not writable; drain() will block."""
        if self.write_paused:
            return
        self.write_paused = True
        self._is_writable_event.clear()

    def resume_writing(self):
        """Mark the transport as writable again and release drain() waiters."""
        if not self.write_paused:
            return
        self.write_paused = False
        self._is_writable_event.set()
class BodyReceiver:
    """Body receiver for callback-based parsers.

    Body chunks are fed directly via the feed() method from parser callbacks.
    Uses Future-based waiting for efficient async receive().

    Args:
        request: Object exposing ``content_length`` and ``chunked``.
        protocol: Object exposing a boolean ``_closed`` attribute.
        timeout: Seconds receive() waits for more body data before giving up.
    """

    __slots__ = ('_chunks', '_complete', '_body_finished', '_closed', '_waiter',
                 '_timeout', 'request', 'protocol')

    def __init__(self, request, protocol, timeout=30.0):
        self.request = request
        self.protocol = protocol
        self._chunks = []
        self._complete = False
        self._body_finished = False  # True after returning more_body=False
        self._closed = False
        self._waiter = None
        self._timeout = timeout

    def feed(self, chunk):
        """Feed a body chunk directly (called by parser callback)."""
        if chunk:
            self._chunks.append(chunk)
            self._wake_waiter()

    def set_complete(self):
        """Mark body as complete (called when message ends)."""
        self._complete = True
        self._wake_waiter()

    def signal_disconnect(self):
        """Signal that connection has been lost."""
        self._closed = True
        self._wake_waiter()

    def _wake_waiter(self):
        """Wake up any pending receive() call."""
        if self._waiter is not None and not self._waiter.done():
            self._waiter.set_result(None)

    async def receive(self):  # pylint: disable=too-many-return-statements
        """ASGI receive callable - returns body chunks or disconnect.

        Returns:
            An ``http.request`` event carrying the next chunk, or an
            ``http.disconnect`` event once the body has been fully delivered,
            the connection is gone, or no data arrived within the timeout.

        Raises:
            asyncio.CancelledError: propagated unchanged when the awaiting
                task is cancelled, so cancellation is never masked as a
                disconnect event.
        """
        # Already disconnected or body finished
        if self._closed or self._body_finished:
            return {"type": "http.disconnect"}

        # Fast path: chunk already available
        if self._chunks:
            return self._pop_chunk()

        # Body complete with no more chunks
        if self._complete:
            self._body_finished = True
            return {"type": "http.request", "body": b"", "more_body": False}

        # No body expected
        if self.request.content_length == 0 and not self.request.chunked:
            self._complete = True
            self._body_finished = True
            return {"type": "http.request", "body": b"", "more_body": False}

        # Check protocol closed state
        if self.protocol._closed:
            self._closed = True
            return {"type": "http.disconnect"}

        # Wait for body chunk to arrive via callback. CancelledError is
        # deliberately not caught here: swallowing it would leave the
        # cancelled task running.
        await self._wait_for_data()
        return self._build_receive_result()

    def _pop_chunk(self):
        """Pop a chunk and return the appropriate message."""
        chunk = self._chunks.pop(0)
        more = bool(self._chunks) or not self._complete
        if not more:
            self._body_finished = True
        return {"type": "http.request", "body": chunk, "more_body": more}

    def _build_receive_result(self):
        """Build receive result after waiting for data."""
        if self._closed:
            return {"type": "http.disconnect"}

        if self._chunks:
            return self._pop_chunk()

        # Either way no further body will be delivered; later receive()
        # calls return http.disconnect instead of waiting again.
        self._body_finished = True
        if self._complete:
            return {"type": "http.request", "body": b"", "more_body": False}

        # Timed out with the body still incomplete. Reporting
        # more_body=False here would present a truncated body as complete.
        return {"type": "http.disconnect"}

    async def _wait_for_data(self):
        """Wait for body data to arrive via callback (bounded by the timeout)."""
        if self._chunks or self._complete or self._closed:
            return

        # Create a new waiter on the loop that is running this coroutine
        self._waiter = asyncio.get_running_loop().create_future()

        try:
            await asyncio.wait_for(self._waiter, timeout=self._timeout)
        except asyncio.TimeoutError:
            pass
        finally:
            self._waiter = None
HTTP/2 connection - uses StreamReader (complex framing) self.reader = asyncio.StreamReader() self._task = self.worker.loop.create_task( self._handle_http2_connection(transport, ssl_object) ) return - # HTTP/1.x connection - # Create stream reader/writer - self.reader = asyncio.StreamReader() + # HTTP/1.x connection - always use callback parser + self._is_ssl = ssl_object is not None self.writer = transport - # Start handling requests + # Setup flow control for HTTP/1.x + self._flow_control = FlowControl(transport) + transport.set_write_buffer_limits(high=HIGH_WATER_LIMIT) + + # Setup callback parser with request ready event + self._request_ready = asyncio.Event() + self._setup_callback_parser() self._task = self.worker.loop.create_task(self._handle_connection()) + @classmethod + def _check_h1c_protocol_available(cls): + """Check if H1CProtocol is available (cached at class level).""" + if cls._h1c_available is None: + try: + import gunicorn_h1c + from gunicorn_h1c import H1CProtocol + cls._h1c_available = True + cls._h1c_protocol_class = H1CProtocol + # Require >= 0.4.1 for limit enforcement + cls._h1c_has_limits = hasattr(gunicorn_h1c, 'LimitRequestLine') + except ImportError: + cls._h1c_available = False + cls._h1c_has_limits = False + return cls._h1c_available + + # Compatibility flags not supported by the fast parser + _FAST_PARSER_INCOMPATIBLE_FLAGS = ( + 'permit_obsolete_folding', + 'strip_header_spaces', + ) + + def _setup_callback_parser(self): + """Create callback parser based on http_parser setting. 
+ + Parser selection: + - auto: Use H1CProtocol if available (>= 0.4.1) and no incompatible flags, else PythonProtocol + - fast: Require H1CProtocol >= 0.4.1 (error if unavailable or incompatible flags) + - python: Use PythonProtocol only + """ + parser_setting = getattr(self.cfg, 'http_parser', 'auto') + + # Check for incompatible compatibility flags + incompatible = [] + for flag in self._FAST_PARSER_INCOMPATIBLE_FLAGS: + if getattr(self.cfg, flag, False): + incompatible.append(flag) + + if parser_setting == 'python': + parser_class = PythonProtocol + elif parser_setting == 'fast': + if not self._check_h1c_protocol_available(): + raise RuntimeError("gunicorn_h1c required for http_parser='fast'") + if not ASGIProtocol._h1c_has_limits: + raise RuntimeError( + "gunicorn_h1c >= 0.4.1 required for http_parser='fast'. " + "Please upgrade: pip install --upgrade gunicorn_h1c" + ) + if incompatible: + raise RuntimeError( + "http_parser='fast' is incompatible with compatibility flags: %s. " + "Use http_parser='python' or disable these flags." 
+ % ', '.join(incompatible) + ) + parser_class = ASGIProtocol._h1c_protocol_class + else: # auto + if (self._check_h1c_protocol_available() and + ASGIProtocol._h1c_has_limits and not incompatible): + parser_class = ASGIProtocol._h1c_protocol_class + else: + parser_class = PythonProtocol + + # Create parser with callbacks and limit parameters (both parsers support them) + self._callback_parser = parser_class( + on_headers_complete=self._on_headers_complete, + on_body=self._on_body, + on_message_complete=self._on_message_complete, + limit_request_line=self.cfg.limit_request_line, + limit_request_fields=self.cfg.limit_request_fields, + limit_request_field_size=self.cfg.limit_request_field_size, + permit_unconventional_http_method=self.cfg.permit_unconventional_http_method, + permit_unconventional_http_version=self.cfg.permit_unconventional_http_version, + ) + + def _on_headers_complete(self): + """Callback: request headers are complete.""" + # Build request from parser state + self._current_request = CallbackRequest.from_parser( + self._callback_parser, is_ssl=self._is_ssl + ) + + # Create body receiver for this request + self._body_receiver = BodyReceiver(self._current_request, self) + + # Signal that request is ready for processing + if self._request_ready: + self._request_ready.set() + + # Return True for HEAD to skip body parsing + return self._callback_parser.method == b'HEAD' + + def _on_body(self, chunk): + """Callback: received body data chunk.""" + if self._body_receiver: + self._body_receiver.feed(chunk) + + def _on_message_complete(self): + """Callback: request is fully received.""" + if self._body_receiver: + self._body_receiver.set_complete() + def data_received(self, data): """Called when data is received on the connection.""" + if self._websocket: + # WebSocket path - forward to WebSocket protocol + self._websocket.feed_data(data) + return if self.reader: + # HTTP/2 path - use StreamReader self.reader.feed_data(data) + elif self._callback_parser: + # 
HTTP/1.x path - feed directly to callback parser + try: + self._callback_parser.feed(data) + except LimitRequestLine as e: + self._send_error_response(414, str(e)) # URI Too Long + self._close_transport() + return + except LimitRequestHeaders as e: + self._send_error_response(431, str(e)) # Request Header Fields Too Large + self._close_transport() + return + except ParseError as e: + self._send_error_response(400, str(e)) + self._close_transport() + return + + # Backpressure: pause reading if buffer is too large + if not self._reading_paused and self._is_buffer_full(): + self._pause_reading() + + def _is_buffer_full(self): + """Check if internal buffer is full (HTTP/2 only).""" + if self.reader and hasattr(self.reader, '_buffer'): + return len(self.reader._buffer) > self._max_buffer_size + return False + + def _pause_reading(self): + """Pause reading from transport due to backpressure.""" + if not self._reading_paused and self.transport: + self._reading_paused = True + try: + self.transport.pause_reading() + except (AttributeError, RuntimeError): + pass + + def _resume_reading(self): + """Resume reading from transport.""" + if self._reading_paused and self.transport: + self._reading_paused = False + try: + self.transport.resume_reading() + except (AttributeError, RuntimeError): + pass + + def _arm_keepalive_timer(self): + """Arm keepalive timeout timer after response completion.""" + if self._keepalive_handle: + self._keepalive_handle.cancel() + keepalive_timeout = self.cfg.keepalive + if keepalive_timeout > 0: + self._keepalive_handle = self.worker.loop.call_later( + keepalive_timeout, self._keepalive_timeout + ) + + def _cancel_keepalive_timer(self): + """Cancel keepalive timer when new request arrives.""" + if self._keepalive_handle: + self._keepalive_handle.cancel() + self._keepalive_handle = None + + def _keepalive_timeout(self): + """Called when keepalive timeout expires.""" + self._close_transport() def connection_lost(self, exc): """Called when the 
connection is lost or closed. @@ -115,12 +543,20 @@ class ASGIProtocol(asyncio.Protocol): self._closed = True self.worker.nr_conns -= 1 + + # Cancel keepalive timer + self._cancel_keepalive_timer() + if self.reader: self.reader.feed_eof() - # Signal disconnect to the app via the receive queue - if self._receive_queue is not None: - self._receive_queue.put_nowait({"type": "http.disconnect"}) + # Signal EOF to WebSocket if active + if self._websocket: + self._websocket.feed_eof() + + # Signal disconnect to the app via the body receiver + if self._body_receiver is not None: + self._body_receiver.signal_disconnect() # Schedule task cancellation after grace period if task doesn't complete if self._task and not self._task.done(): @@ -139,6 +575,16 @@ class ASGIProtocol(asyncio.Protocol): if self._task and not self._task.done(): self._task.cancel() + def pause_writing(self): + """Called by transport when write buffer exceeds high water mark.""" + if self._flow_control: + self._flow_control.pause_writing() + + def resume_writing(self): + """Called by transport when write buffer drains below low water mark.""" + if self._flow_control: + self._flow_control.resume_writing() + def _safe_write(self, data): """Write data to transport, handling connection errors gracefully. @@ -159,40 +605,40 @@ class ASGIProtocol(asyncio.Protocol): pass async def _handle_connection(self): - """Main request handling loop for this connection.""" - unreader = AsyncUnreader(self.reader) + """Main request handling loop using callback-based parser. + Uses synchronous parsing in data_received(), avoiding the async + overhead of pull-based parsing. The parser fires callbacks when + headers and body data are available, and this loop waits on + events rather than actively parsing. 
+ """ try: peername = self.transport.get_extra_info('peername') sockname = self.transport.get_extra_info('sockname') + # Check protocol type - use separate path for uWSGI + protocol_type = getattr(self.cfg, 'protocol', 'http') + if protocol_type == 'uwsgi': + await self._handle_connection_uwsgi(peername, sockname) + return + while not self._closed: self.req_count += 1 + self._cancel_keepalive_timer() - try: - # Parse request based on protocol - protocol = getattr(self.cfg, 'protocol', 'http') - if protocol == 'uwsgi': - request = await AsyncUWSGIRequest.parse( - self.cfg, - unreader, - peername, - self.req_count - ) - else: - request = await AsyncRequest.parse( - self.cfg, - unreader, - peername, - self.req_count - ) - except NoMoreData: - # Client disconnected - break - except UWSGIParseException as e: - self.log.debug("uWSGI parse error: %s", e) + # Wait for headers to be parsed (callback sets the event and _current_request) + # Don't clear if request already arrived (data_received ran before us) + if not self._request_ready.is_set(): + try: + await self._request_ready.wait() + except asyncio.CancelledError: + break + + if self._closed or self._current_request is None: break + request = self._current_request + # Check for WebSocket upgrade if self._is_websocket_upgrade(request): await self._handle_websocket(request, sockname, peername) @@ -219,8 +665,20 @@ class ASGIProtocol(asyncio.Protocol): if not self.cfg.keepalive: break - # Drain any unread body before next request - await request.drain_body() + # Resume reading if paused during body consumption + self._resume_reading() + + # Reset parser for next request + if self._callback_parser: + self._callback_parser.reset() + + # Clear request state for next iteration + self._current_request = None + self._body_receiver = None + self._request_ready.clear() + + # Arm keepalive timer between requests + self._arm_keepalive_timer() except asyncio.CancelledError: pass @@ -229,6 +687,53 @@ class 
async def _handle_connection_uwsgi(self, peername, sockname):
    """Handle uWSGI protocol connections (legacy pull-based path).

    Requests are parsed one at a time with AsyncUWSGIRequest.parse() from an
    AsyncUnreader wrapped around ``self.reader``. The loop ends when the peer
    sends no more data, on a uWSGI parse error, after a WebSocket upgrade,
    or as soon as keep-alive is no longer allowed.

    Args:
        peername: Remote address reported by the transport.
        sockname: Local address reported by the transport.
    """
    # HTTP/1.x transports no longer get a StreamReader when the connection is
    # set up (it is created for HTTP/2 only), so create one here. Once
    # self.reader is set, data_received() feeds it instead of the HTTP
    # callback parser.
    # NOTE(review): bytes delivered before this coroutine first runs would
    # still have gone to the callback parser - confirm the task always
    # starts before the first data_received() call.
    if self.reader is None:
        self.reader = asyncio.StreamReader()

    unreader = AsyncUnreader(self.reader)

    while not self._closed:
        self.req_count += 1

        try:
            request = await AsyncUWSGIRequest.parse(
                self.cfg,
                unreader,
                peername,
                self.req_count
            )
        except NoMoreData:
            # Client disconnected
            break
        except UWSGIParseException as e:
            self.log.debug("uWSGI parse error: %s", e)
            break

        # Check for WebSocket upgrade
        if self._is_websocket_upgrade(request):
            await self._handle_websocket(request, sockname, peername)
            break

        # Handle HTTP request
        keepalive = await self._handle_http_request(
            request, sockname, peername
        )

        # Increment worker request count
        self.worker.nr += 1

        # Check max_requests
        if self.worker.nr >= self.worker.max_requests:
            self.log.info("Autorestarting worker after current request.")
            self.worker.alive = False
            keepalive = False

        if not keepalive or not self.worker.alive or not self.cfg.keepalive:
            break

        # Drain any unread body before parsing the next request
        await request.drain_body()
@@ -254,10 +759,17 @@ class ASGIProtocol(asyncio.Protocol): """Handle WebSocket upgrade request.""" from gunicorn.asgi.websocket import WebSocketProtocol + # Stop callback parser - WebSocket uses its own data handling + self._callback_parser = None + scope = self._build_websocket_scope(request, sockname, peername) ws_protocol = WebSocketProtocol( - self.transport, self.reader, scope, self.app, self.log + self.transport, scope, self.app, self.log ) + + # Store reference so data_received() forwards to WebSocket + self._websocket = ws_protocol + await ws_protocol.run() async def _handle_http_request(self, request, sockname, peername): @@ -268,41 +780,16 @@ class ASGIProtocol(asyncio.Protocol): exc_to_raise = None use_chunked = False + # Reset response buffer for write batching + self._response_buffer = None + # Response tracking for access logging response_status = 500 response_headers = [] response_sent = 0 - # Receive queue for body - stored on self for disconnect signaling - receive_queue = asyncio.Queue() - self._receive_queue = receive_queue - body_complete = False - - # Pre-populate with initial body state - if request.content_length == 0 and not request.chunked: - await receive_queue.put({ - "type": "http.request", - "body": b"", - "more_body": False, - }) - body_complete = True - else: - # Start body reading task - asyncio.create_task(self._read_body_to_queue(request, receive_queue)) - - async def receive(): - nonlocal body_complete - # Check if already disconnected before waiting - if self._closed and body_complete: - return {"type": "http.disconnect"} - - msg = await receive_queue.get() - - # Track when body is complete - if msg.get("type") == "http.request" and not msg.get("more_body", True): - body_complete = True - - return msg + # Use body receiver created in _on_headers_complete (receives data via callbacks) + body_receiver = self._body_receiver async def send(message): nonlocal response_started, response_complete, exc_to_raise @@ -319,7 +806,7 @@ class 
ASGIProtocol(asyncio.Protocol): # Handle informational responses (1xx) like 103 Early Hints info_status = message.get("status") info_headers = message.get("headers", []) - await self._send_informational(info_status, info_headers, request) + self._send_informational(info_status, info_headers, request) return if msg_type == "http.response.start": @@ -342,7 +829,7 @@ class ASGIProtocol(asyncio.Protocol): use_chunked = True response_headers = list(response_headers) + [(b"transfer-encoding", b"chunked")] - await self._send_response_start(response_status, response_headers, request) + self._send_response_start(response_status, response_headers, request) elif msg_type == "http.response.body": if not response_started: @@ -356,31 +843,41 @@ class ASGIProtocol(asyncio.Protocol): more_body = message.get("more_body", False) if body: - await self._send_body(body, chunked=use_chunked) + self._send_body(body, chunked=use_chunked) response_sent += len(body) + # Apply write backpressure for streaming responses + if self._flow_control: + await self._flow_control.drain() if not more_body: if use_chunked: - # Send terminal chunk - self._safe_write(b"0\r\n\r\n") + # Send terminal chunk, combined with any buffered headers + if self._response_buffer: + self._safe_write(self._response_buffer + b"0\r\n\r\n") + self._response_buffer = None + else: + self._safe_write(b"0\r\n\r\n") + elif self._response_buffer: + # Non-chunked empty response - flush headers + self._safe_write(self._response_buffer) + self._response_buffer = None response_complete = True - # Build environ for logging - environ = self._build_environ(request, sockname, peername) - resp = None + # Only build environ for logging if access logging is enabled + access_log_enabled = self.log.access_log_enabled try: - request_start = datetime.now() + request_start = time.monotonic() self.cfg.pre_request(self.worker, request) - await self.app(scope, receive, send) + await self.app(scope, body_receiver.receive, send) if exc_to_raise is 
not None: raise exc_to_raise # Ensure response was sent if not response_started: - await self._send_error_response(500, "Internal Server Error") + self._send_error_response(500, "Internal Server Error") response_status = 500 except asyncio.CancelledError: @@ -390,18 +887,23 @@ class ASGIProtocol(asyncio.Protocol): except Exception: self.log.exception("Error in ASGI application") if not response_started: - await self._send_error_response(500, "Internal Server Error") + self._send_error_response(500, "Internal Server Error") response_status = 500 return False finally: - # Clear the receive queue reference - self._receive_queue = None + # Clear the body receiver reference + self._body_receiver = None try: - request_time = datetime.now() - request_start - # Create response info for logging - resp = ASGIResponseInfo(response_status, response_headers, response_sent) - self.log.access(resp, request, environ, request_time) + request_time = _RequestTime(time.monotonic() - request_start) + # Only build log data if access logging is enabled + if access_log_enabled: + environ = self._build_environ(request, sockname, peername) + resp = ASGIResponseInfo(response_status, response_headers, response_sent) + self.log.access(resp, request, environ, request_time) + else: + environ = None + resp = None self.cfg.post_request(self.worker, request, environ, resp) except Exception: self.log.exception("Exception in post_request hook") @@ -412,38 +914,17 @@ class ASGIProtocol(asyncio.Protocol): return self.worker.alive and self.cfg.keepalive - async def _read_body_to_queue(self, request, queue): - """Read request body and put chunks on the queue.""" - try: - while True: - chunk = await request.read_body(65536) - if chunk: - await queue.put({ - "type": "http.request", - "body": chunk, - "more_body": True, - }) - else: - await queue.put({ - "type": "http.request", - "body": b"", - "more_body": False, - }) - break - except Exception as e: - self.log.debug("Error reading body: %s", e) - await 
queue.put({ - "type": "http.request", - "body": b"", - "more_body": False, - }) - def _build_http_scope(self, request, sockname, peername): """Build ASGI HTTP scope from parsed request.""" - # Build headers list as bytes tuples - headers = [] - for name, value in request.headers: - headers.append((name.lower().encode("latin-1"), value.encode("latin-1"))) + # Use pre-computed bytes headers if available (fast path) + # Fall back to conversion for legacy requests (AsyncRequest, HTTP/2) + headers_bytes = getattr(request, 'headers_bytes', None) + if isinstance(headers_bytes, list): + headers = list(headers_bytes) # Copy to avoid mutation + else: + headers = [] + for name, value in request.headers: + headers.append((name.lower().encode("latin-1"), value.encode("latin-1"))) server = _normalize_sockaddr(sockname) client = _normalize_sockaddr(peername) @@ -455,7 +936,7 @@ class ASGIProtocol(asyncio.Protocol): "method": request.method, "scheme": request.scheme, "path": request.path, - "raw_path": request.path.encode("latin-1") if request.path else b"", + "raw_path": request.raw_path if request.raw_path else b"", "query_string": request.query.encode("latin-1") if request.query else b"", "root_path": self.cfg.root_path or "", "headers": headers, @@ -519,7 +1000,7 @@ class ASGIProtocol(asyncio.Protocol): "http_version": f"{request.version[0]}.{request.version[1]}", "scheme": "wss" if request.scheme == "https" else "ws", "path": request.path, - "raw_path": request.path.encode("latin-1") if request.path else b"", + "raw_path": request.raw_path if request.raw_path else b"", "query_string": request.query.encode("latin-1") if request.query else b"", "root_path": self.cfg.root_path or "", "headers": headers, @@ -534,7 +1015,7 @@ class ASGIProtocol(asyncio.Protocol): return scope - async def _send_informational(self, status, headers, request): + def _send_informational(self, status, headers, request): """Send an informational response (1xx) such as 103 Early Hints. 
Args: @@ -562,39 +1043,86 @@ class ASGIProtocol(asyncio.Protocol): response += "\r\n" self._safe_write(response.encode("latin-1")) - async def _send_response_start(self, status, headers, request): - """Send HTTP response status and headers.""" - # Build status line - reason = self._get_reason_phrase(status) - status_line = f"HTTP/{request.version[0]}.{request.version[1]} {status} {reason}\r\n" + def _send_response_start(self, status, headers, request): + """Send HTTP response status and headers. - # Build headers - header_lines = [] + Uses cached status lines and headers for common cases to avoid + repeated string formatting and encoding. + """ + # Get cached status line bytes + reason = self._get_reason_phrase(status) + status_line = _get_cached_status_line(request.version, status, reason) + + # Build headers as bytes directly + parts = [status_line] + + has_date = False + has_server = False for name, value in headers: if isinstance(name, bytes): - name = name.decode("latin-1") - if isinstance(value, bytes): - value = value.decode("latin-1") - header_lines.append(f"{name}: {value}\r\n") - - # Add server header if not present - header_lines.append("Server: gunicorn/asgi\r\n") - - response = status_line + "".join(header_lines) + "\r\n" - self._safe_write(response.encode("latin-1")) - - async def _send_body(self, body, chunked=False): - """Send response body chunk.""" - if body: - if chunked: - # Chunked encoding: size in hex + CRLF + data + CRLF - chunk = f"{len(body):x}\r\n".encode("latin-1") + body + b"\r\n" - self._safe_write(chunk) + name_lower = name.lower() + parts.append(name) else: + name_lower = name.lower().encode("latin-1") + parts.append(name.encode("latin-1")) + + parts.append(b": ") + + if isinstance(value, bytes): + parts.append(value) + else: + parts.append(value.encode("latin-1")) + + parts.append(b"\r\n") + + # Track if Date/Server headers are present + if name_lower == b"date": + has_date = True + elif name_lower == b"server": + has_server = True 
+ + # Add default headers if not present + if not has_server: + parts.append(_CACHED_SERVER_HEADER) + if not has_date: + parts.append(_get_cached_date_header()) + + parts.append(b"\r\n") + + # Buffer headers for batching with first body chunk + self._response_buffer = b"".join(parts) + + def _send_body(self, body, chunked=False): + """Send response body chunk. + + Combines buffered headers with first body chunk for efficient write batching. + """ + if chunked: + if body: + # Chunked encoding: size in hex + CRLF + data + CRLF + # Use pre-cached prefix for common sizes, else format + size = len(body) + prefix = _CHUNK_PREFIXES.get(size) or f"{size:x}\r\n".encode("latin-1") + chunk_data = prefix + body + b"\r\n" + else: + chunk_data = b"" + + # Combine with buffered headers if present + if self._response_buffer: + self._safe_write(self._response_buffer + chunk_data) + self._response_buffer = None + elif chunk_data: + self._safe_write(chunk_data) + else: + # Non-chunked: combine headers + body or just body + if self._response_buffer: + self._safe_write(self._response_buffer + body) + self._response_buffer = None + elif body: self._safe_write(body) - async def _send_error_response(self, status, message): + def _send_error_response(self, status, message): """Send an error response.""" body = message.encode("utf-8") response = ( @@ -733,32 +1261,82 @@ class ASGIProtocol(asyncio.Protocol): pass self._close_transport() + def _convert_h2_headers(self, headers): + """Convert ASGI headers to HTTP/2 format (lowercase string names).""" + result = [] + for name, value in headers: + if isinstance(name, bytes): + name = name.decode("latin-1") + if isinstance(value, bytes): + value = value.decode("latin-1") + result.append((name.lower(), value)) + return result + async def _handle_http2_request(self, request, h2_conn, sockname, peername): - """Handle a single HTTP/2 request.""" + """Handle a single HTTP/2 request with streaming support. 
+ + Streams both request and response body chunks immediately, + avoiding buffering entire uploads and enabling SSE, streaming + downloads, and other real-time use cases. + """ stream_id = request.stream.stream_id + stream = h2_conn.streams.get(stream_id) scope = self._build_http2_scope(request, sockname, peername) response_started = False response_complete = False + headers_sent = False exc_to_raise = None - response_status = 500 response_headers = [] - response_body = b'' - response_trailers = [] + response_sent = 0 + + # Track if we've finished receiving body + body_received = False async def receive(): - # For HTTP/2, the body is already buffered in the stream - body = request.body.read() + nonlocal body_received + + # Check if stream is closed or missing + if stream is None or stream.state.name == "CLOSED": + return {"type": "http.disconnect"} + + # First call: if body already complete (small requests), return it + if not body_received and stream.request_complete and not stream._body_chunks: + body_received = True + body = stream.get_request_body() + return { + "type": "http.request", + "body": body, + "more_body": False, + } + + # Streaming: read next chunk + try: + chunk = await asyncio.wait_for( + stream.read_body_chunk(), + timeout=30.0 + ) + except asyncio.TimeoutError: + return {"type": "http.disconnect"} + + if chunk is None: + body_received = True + return { + "type": "http.request", + "body": b"", + "more_body": False, + } + return { "type": "http.request", - "body": body, - "more_body": False, + "body": chunk, + "more_body": not stream._body_complete, } async def send(message): - nonlocal response_started, response_complete, exc_to_raise - nonlocal response_status, response_headers, response_body + nonlocal response_started, response_complete, headers_sent + nonlocal response_status, response_headers, response_sent, exc_to_raise msg_type = message["type"] @@ -766,14 +1344,7 @@ class ASGIProtocol(asyncio.Protocol): # Handle informational responses 
(1xx) like 103 Early Hints over HTTP/2 info_status = message.get("status") info_headers = message.get("headers", []) - # Convert headers to list of string tuples - headers = [] - for name, value in info_headers: - if isinstance(name, bytes): - name = name.decode("latin-1") - if isinstance(value, bytes): - value = value.decode("latin-1") - headers.append((name, value)) + headers = self._convert_h2_headers(info_headers) await h2_conn.send_informational(stream_id, info_status, headers) return @@ -784,6 +1355,7 @@ class ASGIProtocol(asyncio.Protocol): response_started = True response_status = message["status"] response_headers = message.get("headers", []) + # Don't send headers yet - wait for first body chunk elif msg_type == "http.response.body": if not response_started: @@ -796,10 +1368,31 @@ class ASGIProtocol(asyncio.Protocol): body = message.get("body", b"") more_body = message.get("more_body", False) + # Send headers with first body chunk + if not headers_sent: + headers = self._convert_h2_headers(response_headers) + response_hdrs = [(':status', str(response_status))] + response_hdrs.extend(headers) + + # Send headers without end_stream since we have body + stream = h2_conn.streams.get(stream_id) + if stream is None: + exc_to_raise = RuntimeError("Stream closed") + return + h2_conn.h2_conn.send_headers(stream_id, response_hdrs, end_stream=False) + stream.send_headers(response_hdrs, end_stream=False) + await h2_conn._send_pending_data() + headers_sent = True + + # Stream body immediately if body: - response_body += body + await h2_conn.send_data(stream_id, body, end_stream=not more_body) + response_sent += len(body) if not more_body: + if not body: + # Empty final chunk - send end_stream + await h2_conn.send_data(stream_id, b"", end_stream=True) response_complete = True elif msg_type == "http.response.trailers": @@ -807,19 +1400,12 @@ class ASGIProtocol(asyncio.Protocol): exc_to_raise = RuntimeError("Cannot send trailers before body complete") return 
trailer_headers = message.get("headers", []) - # Convert to list of tuples with string values - trailers = [] - for name, value in trailer_headers: - if isinstance(name, bytes): - name = name.decode("latin-1") - if isinstance(value, bytes): - value = value.decode("latin-1") - trailers.append((name, value)) - response_trailers.extend(trailers) + trailers = self._convert_h2_headers(trailer_headers) + await h2_conn.send_trailers(stream_id, trailers) - # Build environ for logging - environ = self._build_http2_environ(request, sockname, peername) - request_start = datetime.now() + # Only build environ for logging if access logging is enabled + access_log_enabled = self.log.access_log_enabled + request_start = time.monotonic() try: self.cfg.pre_request(self.worker, request) @@ -828,57 +1414,41 @@ class ASGIProtocol(asyncio.Protocol): if exc_to_raise is not None: raise exc_to_raise - # Send response via HTTP/2 - if response_started: - # Convert headers to list of tuples - headers = [] - for name, value in response_headers: - if isinstance(name, bytes): - name = name.decode("latin-1") - if isinstance(value, bytes): - value = value.decode("latin-1") - headers.append((name, value)) - - if response_trailers: - # Send headers, body, then trailers separately - response_hdrs = [(':status', str(response_status))] - for name, value in headers: - response_hdrs.append((name.lower(), str(value))) - - # Send headers without ending stream - h2_conn.h2_conn.send_headers(stream_id, response_hdrs, end_stream=False) - stream = h2_conn.streams[stream_id] - stream.send_headers(response_hdrs, end_stream=False) - await h2_conn._send_pending_data() - - # Send body without ending stream - if response_body: - h2_conn.h2_conn.send_data(stream_id, response_body, end_stream=False) - stream.send_data(response_body, end_stream=False) - await h2_conn._send_pending_data() - - # Send trailers (ends stream) - await h2_conn.send_trailers(stream_id, response_trailers) - else: - await h2_conn.send_response( 
- stream_id, response_status, headers, response_body - ) - else: + # Handle case where app didn't send any response + if not response_started: await h2_conn.send_error(stream_id, 500, "Internal Server Error") response_status = 500 + # Handle case where headers were started but no body was sent + elif not headers_sent: + # Send headers now (empty body response) + headers = self._convert_h2_headers(response_headers) + response_hdrs = [(':status', str(response_status))] + response_hdrs.extend(headers) + stream = h2_conn.streams.get(stream_id) + if stream: + h2_conn.h2_conn.send_headers(stream_id, response_hdrs, end_stream=True) + stream.send_headers(response_hdrs, end_stream=True) + await h2_conn._send_pending_data() + except Exception: self.log.exception("Error in ASGI application") - if not response_started: + if not headers_sent: await h2_conn.send_error(stream_id, 500, "Internal Server Error") response_status = 500 finally: try: - request_time = datetime.now() - request_start - resp = ASGIResponseInfo( - response_status, response_headers, len(response_body) - ) - self.log.access(resp, request, environ, request_time) + request_time = _RequestTime(time.monotonic() - request_start) + # Only build log data if access logging is enabled + if access_log_enabled: + environ = self._build_http2_environ(request, sockname, peername) + resp = ASGIResponseInfo( + response_status, response_headers, response_sent + ) + self.log.access(resp, request, environ, request_time) + else: + environ = None + resp = None self.cfg.post_request(self.worker, request, environ, resp) except Exception: self.log.exception("Exception in post_request hook") @@ -902,7 +1472,7 @@ class ASGIProtocol(asyncio.Protocol): "method": request.method, "scheme": request.scheme, "path": request.path, - "raw_path": request.path.encode("latin-1") if request.path else b"", + "raw_path": getattr(request, 'raw_path', None) or (request.path.encode("latin-1") if request.path else b""), "query_string": 
request.query.encode("latin-1") if request.query else b"", "root_path": self.cfg.root_path or "", "headers": headers, diff --git a/gunicorn/asgi/websocket.py b/gunicorn/asgi/websocket.py index 737268b6..d1b2251b 100644 --- a/gunicorn/asgi/websocket.py +++ b/gunicorn/asgi/websocket.py @@ -40,20 +40,22 @@ WS_GUID = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" class WebSocketProtocol: - """WebSocket connection handler for ASGI applications.""" + """WebSocket connection handler for ASGI applications. - def __init__(self, transport, reader, scope, app, log): + Uses callback-based data feeding instead of StreamReader for efficiency. + Data is fed via feed_data() from the parent protocol's data_received(). + """ + + def __init__(self, transport, scope, app, log): """Initialize WebSocket protocol handler. Args: transport: asyncio transport for writing - reader: asyncio StreamReader for reading scope: ASGI WebSocket scope dict app: ASGI application callable log: Logger instance """ self.transport = transport - self.reader = reader self.scope = scope self.app = app self.log = log @@ -70,6 +72,26 @@ class WebSocketProtocol: # Receive queue for incoming messages self._receive_queue = asyncio.Queue() + # Callback-based data reception (replaces StreamReader) + self._buffer = bytearray() + self._data_event = asyncio.Event() + self._eof = False + + def feed_data(self, data): + """Feed incoming data from the parent protocol's data_received(). 
+ + Args: + data: bytes received on the connection + """ + if data: + self._buffer.extend(data) + self._data_event.set() + + def feed_eof(self): + """Signal that the connection has been closed.""" + self._eof = True + self._data_event.set() + async def run(self): """Run the WebSocket ASGI application.""" # Send initial connect event @@ -295,14 +317,25 @@ class WebSocketProtocol: return (opcode, payload) async def _read_exact(self, n): - """Read exactly n bytes from the reader.""" - try: - data = await self.reader.readexactly(n) - return data - except asyncio.IncompleteReadError: - return None - except Exception: - return None + """Read exactly n bytes from internal buffer. + + Waits for data via the callback-fed buffer instead of StreamReader. + """ + while len(self._buffer) < n: + if self._eof: + return None + self._data_event.clear() + # Critical: check buffer AGAIN after clearing to avoid race + # condition where data arrives between clear() and wait() + if len(self._buffer) >= n: + break + await self._data_event.wait() + if self._eof and len(self._buffer) < n: + return None + + data = bytes(self._buffer[:n]) + del self._buffer[:n] + return data def _unmask(self, payload, masking_key): """Unmask WebSocket payload data.""" diff --git a/gunicorn/config.py b/gunicorn/config.py index 22ebaf4d..997a9830 100644 --- a/gunicorn/config.py +++ b/gunicorn/config.py @@ -2772,6 +2772,22 @@ def validate_asgi_lifespan(val): return val +def validate_http_parser(val): + """Validate http_parser setting. 
+ + Accepts: auto, fast, python + """ + if val is None: + return "auto" + if not isinstance(val, str): + raise TypeError("http_parser must be a string") + val = val.lower().strip() + valid_values = ("auto", "fast", "python") + if val not in valid_values: + raise ValueError("http_parser must be one of: %s" % ", ".join(valid_values)) + return val + + class ASGILoop(Setting): name = "asgi_loop" section = "Worker Processes" @@ -2845,6 +2861,30 @@ class ASGIDisconnectGracePeriod(Setting): """ +class HttpParser(Setting): + name = "http_parser" + section = "Worker Processes" + cli = ["--http-parser"] + meta = "STRING" + validator = validate_http_parser + default = "auto" + desc = """\ + HTTP parser implementation for ASGI workers. + + - auto: Use H1CProtocol if gunicorn_h1c is available, else PythonProtocol (default) + - fast: Require H1CProtocol from gunicorn_h1c (fail if unavailable) + - python: Force pure Python PythonProtocol parser + + ASGI workers use callback-based parsing in data_received() for efficient + incremental parsing. The gunicorn_h1c C extension provides significantly + faster HTTP parsing using picohttpparser with SIMD optimizations. + + Install it with: pip install gunicorn[fast] + + .. versionadded:: 25.0.0 + """ + + class RootPath(Setting): name = "root_path" section = "Server Mechanics" diff --git a/gunicorn/glogging.py b/gunicorn/glogging.py index ade25eee..075016e2 100644 --- a/gunicorn/glogging.py +++ b/gunicorn/glogging.py @@ -341,14 +341,24 @@ class Logger: return atoms + @property + def access_log_enabled(self): + """Check if access logging is enabled. + + Used by protocol handlers to skip building log data when logging is disabled. 
+ """ + return bool( + self.cfg.accesslog or self.cfg.logconfig or + self.cfg.logconfig_dict or self.cfg.logconfig_json or + (self.cfg.syslog and not self.cfg.disable_redirect_access_to_syslog) + ) + def access(self, resp, req, environ, request_time): """ See http://httpd.apache.org/docs/2.0/logs.html#combined for format details """ - if not (self.cfg.accesslog or self.cfg.logconfig or - self.cfg.logconfig_dict or self.cfg.logconfig_json or - (self.cfg.syslog and not self.cfg.disable_redirect_access_to_syslog)): + if not self.access_log_enabled: return # wrap atoms: diff --git a/gunicorn/http/errors.py b/gunicorn/http/errors.py index 92e5431c..f3ba534e 100644 --- a/gunicorn/http/errors.py +++ b/gunicorn/http/errors.py @@ -114,11 +114,13 @@ class ChunkMissingTerminator(IOError): class LimitRequestLine(ParseException): - def __init__(self, size, max_size): + def __init__(self, size, max_size=None): self.size = size self.max_size = max_size def __str__(self): + if self.max_size is None: + return str(self.size) return "Request Line is too large (%s > %s)" % (self.size, self.max_size) diff --git a/gunicorn/http/message.py b/gunicorn/http/message.py index d12c136f..f3506d5a 100644 --- a/gunicorn/http/message.py +++ b/gunicorn/http/message.py @@ -21,6 +21,80 @@ from gunicorn.http.errors import InvalidSchemeHeaders from gunicorn.util import bytes_to_str, split_request_uri +# Fast parser availability (cached at module level) +_fast_parser_available = None +_fast_parser_module = None + +# Compatibility flags not supported by the fast parser +_FAST_PARSER_INCOMPATIBLE_FLAGS = ( + 'permit_obsolete_folding', + 'strip_header_spaces', +) + + +def _check_fast_parser(cfg): + """Check if fast C parser is available and should be used. 
+ + Returns False if: + - http_parser='python' is explicitly set + - gunicorn_h1c is not installed (in 'auto' mode) + - gunicorn_h1c < 0.4.1 (in 'auto' mode) + - Incompatible compatibility flags are enabled (in 'auto' mode) + + Raises RuntimeError if: + - http_parser='fast' but gunicorn_h1c is not installed + - http_parser='fast' but gunicorn_h1c < 0.4.1 + - http_parser='fast' but incompatible flags are enabled + """ + global _fast_parser_available, _fast_parser_module # pylint: disable=global-statement + + parser_setting = getattr(cfg, 'http_parser', 'auto') + if parser_setting == 'python': + return False + + if _fast_parser_available is None: + try: + import gunicorn_h1c + _fast_parser_available = True + _fast_parser_module = gunicorn_h1c + except ImportError: + _fast_parser_available = False + + if not _fast_parser_available and parser_setting == 'fast': + raise RuntimeError("gunicorn_h1c not installed but http_parser='fast'") + + if not _fast_parser_available: + return False + + # Require >= 0.4.1 for limit enforcement + if not hasattr(_fast_parser_module, 'LimitRequestLine'): + if parser_setting == 'fast': + raise RuntimeError( + "gunicorn_h1c >= 0.4.1 required for http_parser='fast'. " + "Please upgrade: pip install --upgrade gunicorn_h1c" + ) + # In 'auto' mode, fall back to Python parser + return False + + # Check for incompatible compatibility flags + incompatible = [] + for flag in _FAST_PARSER_INCOMPATIBLE_FLAGS: + if getattr(cfg, flag, False): + incompatible.append(flag) + + if incompatible: + if parser_setting == 'fast': + raise RuntimeError( + "http_parser='fast' is incompatible with compatibility flags: %s. " + "Use http_parser='python' or disable these flags." 
+ % ', '.join(incompatible) + ) + # In 'auto' mode, fall back to Python parser + return False + + return True + + # PROXY protocol v2 constants PP_V2_SIGNATURE = b"\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A" @@ -99,7 +173,7 @@ class Message: or self.limit_request_fields > MAX_HEADERS): self.limit_request_fields = MAX_HEADERS self.limit_request_field_size = cfg.limit_request_field_size - if self.limit_request_field_size < 0: + if self.limit_request_field_size <= 0: self.limit_request_field_size = DEFAULT_MAX_HEADERFIELD_SIZE # set max header buffer size @@ -315,12 +389,16 @@ class Request(Message): # get max request line size self.limit_request_line = cfg.limit_request_line - if (self.limit_request_line < 0 + if (self.limit_request_line <= 0 or self.limit_request_line >= MAX_REQUEST_LINE): self.limit_request_line = MAX_REQUEST_LINE self.req_number = req_number self.proxy_protocol_info = None + + # Check if fast parser should be used + self._use_fast = _check_fast_parser(cfg) + super().__init__(cfg, unreader, peer_addr) def get_data(self, unreader, buf, stop=False): @@ -340,6 +418,102 @@ class Request(Message): if mode != "off" and self.req_number == 1: buf = self._handle_proxy_protocol(unreader, buf, mode) + # Use fast parser if available + if self._use_fast: + return self._parse_fast(unreader, buf) + + return self._parse_python(unreader, buf) + + def _parse_fast(self, unreader, buf): + """Parse request using fast C parser (gunicorn_h1c >= 0.4.1).""" + # Read until we have complete headers + data = bytes(buf) + last_len = 0 + + while True: + try: + # Pass all limit parameters to C parser + result = _fast_parser_module.parse_request( + data, + last_len=last_len, + limit_request_line=self.limit_request_line, + limit_request_fields=self.limit_request_fields, + limit_request_field_size=self.limit_request_field_size, + permit_unconventional_http_method=self.cfg.permit_unconventional_http_method, + 
permit_unconventional_http_version=self.cfg.permit_unconventional_http_version, + ) + break + except _fast_parser_module.IncompleteError: + last_len = len(data) + self.read_into(unreader, buf) + data = bytes(buf) + if len(data) > self.max_buffer_headers + self.limit_request_line: + raise LimitRequestHeaders("max buffer headers") + except _fast_parser_module.LimitRequestLine as e: + raise LimitRequestLine(str(e)) + except _fast_parser_module.LimitRequestHeaders as e: + raise LimitRequestHeaders(str(e)) + except _fast_parser_module.InvalidRequestMethod as e: + raise InvalidRequestMethod(str(e)) + except _fast_parser_module.InvalidHTTPVersion as e: + raise InvalidHTTPVersion(str(e)) + except _fast_parser_module.InvalidHeaderName as e: + raise InvalidHeaderName(str(e)) + except _fast_parser_module.InvalidHeader as e: + raise InvalidHeader(str(e)) + except _fast_parser_module.ParseError as e: + raise InvalidRequestLine(str(e)) + + # Extract parsed data + self.method = bytes_to_str(result['method']) + self.uri = bytes_to_str(result['path']) + + # Casefold method if configured (validation done by C parser) + if self.cfg.casefold_http_method: + self.method = self.method.upper() + + # Parse URI parts + if len(self.uri) == 0: + raise InvalidRequestLine(self.uri) + try: + parts = split_request_uri(self.uri) + except ValueError: + raise InvalidRequestLine(self.uri) + self.path = parts.path or "" + self.query = parts.query or "" + self.fragment = parts.fragment or "" + + # Version (validation done by C parser) + self.version = (1, result['minor_version']) + + # Headers - convert bytes to strings with uppercase names + # gunicorn_h1c returns headers as (bytes, bytes) tuples + # Header name/value validation done by C parser + self.headers = [] + for name_bytes, value_bytes in result['headers']: + name = bytes_to_str(name_bytes).upper() + value = bytes_to_str(value_bytes) + + # Handle underscore in header names (policy decision, not validation) + if "_" in name: + 
forwarder_headers = self.cfg.forwarder_headers + if name in forwarder_headers or "*" in forwarder_headers: + pass + elif self.cfg.header_map == "dangerous": + pass + elif self.cfg.header_map == "drop": + continue + else: + raise InvalidHeaderName(name) + + self.headers.append((name, value)) + + # Return remaining data after headers + consumed = result['consumed'] + return data[consumed:] + + def _parse_python(self, unreader, buf): + """Parse request using pure Python parser.""" # Get request line line, buf = self.read_line(unreader, buf, self.limit_request_line) diff --git a/gunicorn/http2/stream.py b/gunicorn/http2/stream.py index 18ba40f4..34b7be18 100644 --- a/gunicorn/http2/stream.py +++ b/gunicorn/http2/stream.py @@ -71,6 +71,11 @@ class HTTP2Stream: self.priority_depends_on = 0 self.priority_exclusive = False + # Streaming body support (avoids buffering entire uploads) + self._body_chunks = [] + self._body_event = None # Lazy-init asyncio.Event + self._body_complete = False + @property def is_client_stream(self): """Check if this is a client-initiated stream (odd stream ID).""" @@ -122,7 +127,7 @@ class HTTP2Stream: self.request_complete = True def receive_data(self, data, end_stream=False): - """Process received DATA frame. + """Process received DATA frame with streaming support. Args: data: Bytes received @@ -137,11 +142,21 @@ class HTTP2Stream: f"Cannot receive data in state {self.state.name}" ) + # Add to chunks queue for streaming reads + if data: + self._body_chunks.append(data) + if self._body_event: + self._body_event.set() + + # Also write to legacy BytesIO for compatibility self.request_body.write(data) if end_stream: self._half_close_remote() self.request_complete = True + self._body_complete = True + if self._body_event: + self._body_event.set() def receive_trailers(self, trailers): """Process received trailing headers. 
@@ -283,6 +298,35 @@ class HTTP2Stream: """ return self.request_body.getvalue() + async def read_body_chunk(self): + """Read next body chunk asynchronously for streaming. + + Returns: + bytes: Next chunk of body data, or None if body is complete. + """ + import asyncio + + # Initialize event lazily (avoids event loop issues at construction) + if self._body_event is None: + self._body_event = asyncio.Event() + # If data already arrived before event existed, set it now + # This prevents race where DATA frames arrive before first read + if self._body_chunks or self._body_complete: + self._body_event.set() + + while True: + # Return chunk if available + if self._body_chunks: + return self._body_chunks.pop(0) + + # No more data expected + if self._body_complete: + return None + + # Wait for more data + self._body_event.clear() + await self._body_event.wait() + def get_pseudo_headers(self): """Extract HTTP/2 pseudo-headers from request headers. diff --git a/pyproject.toml b/pyproject.toml index c23d6622..2532c830 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ tornado = ["tornado>=6.5.0"] gthread = [] setproctitle = ["setproctitle"] http2 = ["h2>=4.1.0"] +fast = ["gunicorn_h1c>=0.4.1"] testing = [ "gevent>=24.10.1", "eventlet>=0.40.3", diff --git a/requirements_test.txt b/requirements_test.txt index efa91f20..98e266d8 100644 --- a/requirements_test.txt +++ b/requirements_test.txt @@ -1,6 +1,6 @@ gevent -eventlet coverage pytest>=7.2.0 pytest-cov pytest-asyncio +gunicorn_h1c>=0.4.1 diff --git a/tests/conftest.py b/tests/conftest.py index 46046918..56509c95 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,8 +7,21 @@ import os import sys +import pytest + # Add the tests directory to sys.path so test support modules can be imported # as 'tests.module_name' (e.g., 'tests.support_dirty_apps:CounterApp') tests_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) if tests_dir not in sys.path: sys.path.insert(0, tests_dir) + + 
+@pytest.fixture(params=["python", "fast"]) +def http_parser(request): + """Parametrize tests over http_parser implementations.""" + if request.param == "fast": + gunicorn_h1c = pytest.importorskip("gunicorn_h1c", reason="gunicorn_h1c required") + # Require >= 0.4.1 for limit enforcement + if not hasattr(gunicorn_h1c, 'LimitRequestLine'): + pytest.skip("gunicorn_h1c >= 0.4.1 required") + return request.param diff --git a/tests/requests/invalid/limit_header_default_01.http b/tests/requests/invalid/limit_header_default_01.http new file mode 100644 index 00000000..9d3d7336 --- /dev/null +++ b/tests/requests/invalid/limit_header_default_01.http @@ -0,0 +1 @@ +GET / HTTP/1.0\r\nX: yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy
yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy
yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy
yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy
yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy\r\n\r\n \ No newline at end of file diff --git a/tests/requests/invalid/limit_header_default_01.py b/tests/requests/invalid/limit_header_default_01.py new file mode 100644 index 00000000..0a8c1fcf --- /dev/null +++ b/tests/requests/invalid/limit_header_default_01.py @@ -0,0 +1,11 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. 
+ +from gunicorn.config import Config +from gunicorn.http.errors import LimitRequestHeaders + +cfg = Config() +# Setting limit_request_field_size=0 should use default max (8190) +cfg.set('limit_request_field_size', 0) +request = LimitRequestHeaders diff --git a/tests/requests/invalid/limit_line_default_01.http b/tests/requests/invalid/limit_line_default_01.http new file mode 100644 index 00000000..f1224938 --- /dev/null +++ b/tests/requests/invalid/limit_line_default_01.http @@ -0,0 +1 @@ +GET /xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx HTTP/1.0\r\n \ No newline at end of file diff --git a/tests/requests/invalid/limit_line_default_01.py b/tests/requests/invalid/limit_line_default_01.py new file mode 100644 index 00000000..03d666b2 --- /dev/null +++ b/tests/requests/invalid/limit_line_default_01.py @@ -0,0 +1,11 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. 
+ +from gunicorn.config import Config +from gunicorn.http.errors import LimitRequestLine + +cfg = Config() +# Setting limit_request_line=0 should use default max (8190) +cfg.set('limit_request_line', 0) +request = LimitRequestLine diff --git a/tests/requests/valid/024.py b/tests/requests/valid/024.py index 3dc8a597..bf9a2f22 100644 --- a/tests/requests/valid/024.py +++ b/tests/requests/valid/024.py @@ -5,8 +5,9 @@ from gunicorn.config import Config cfg = Config() -cfg.set('limit_request_line', 0) -cfg.set('limit_request_field_size', 0) +# Request line is 8194 bytes, header line is 8209 bytes (both include CRLF) +cfg.set('limit_request_line', 8200) +cfg.set('limit_request_field_size', 8210) request = { "method": "PUT", "uri": diff --git a/tests/requests/valid/026.py b/tests/requests/valid/026.py index a56e7ae5..30c9d7b8 100644 --- a/tests/requests/valid/026.py +++ b/tests/requests/valid/026.py @@ -5,7 +5,7 @@ from gunicorn.config import Config cfg = Config() -cfg.set('limit_request_line', 0) +# Header line is 8209 bytes (name + ": " + value + CRLF) cfg.set('limit_request_field_size', 8210) request = { "method": "GET", diff --git a/tests/test_asgi.py b/tests/test_asgi.py index 4e75f7ac..74e11ba2 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -7,10 +7,8 @@ Tests for ASGI worker components. """ import asyncio -import io import ipaddress import pytest -from unittest import mock from gunicorn.asgi.unreader import AsyncUnreader from gunicorn.asgi.message import AsyncRequest diff --git a/tests/test_asgi_callback_parser.py b/tests/test_asgi_callback_parser.py new file mode 100644 index 00000000..7bba3ea2 --- /dev/null +++ b/tests/test_asgi_callback_parser.py @@ -0,0 +1,518 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +"""Tests for ASGI callback parsers. + +Tests both PythonProtocol and H1CProtocol (if available) to ensure +consistent behavior across implementations. 
+""" + +from gunicorn.asgi.parser import PythonProtocol + + +def get_parser_class(http_parser): + """Get the appropriate parser class for the test parameter.""" + if http_parser == "fast": + from gunicorn_h1c import H1CProtocol + return H1CProtocol + return PythonProtocol + + +def normalize_headers(headers): + """Normalize headers to lowercase names for comparison. + + H1CProtocol preserves original case, PythonProtocol lowercases. + """ + return {name.lower(): value for name, value in headers} + + +class TestRequestLineParsing: + """Test request line parsing for both implementations.""" + + def test_simple_get(self, http_parser): + """Parse a simple GET request.""" + parser_class = get_parser_class(http_parser) + events = [] + + parser = parser_class( + on_message_begin=lambda: events.append('begin'), + on_url=lambda url: events.append(('url', url)), + on_headers_complete=lambda: events.append('headers_complete'), + on_message_complete=lambda: events.append('complete'), + ) + + parser.feed(b"GET /path HTTP/1.1\r\n\r\n") + + assert parser.method == b"GET" + assert parser.path == b"/path" + assert parser.http_version == (1, 1) + assert parser.is_complete + assert 'begin' in events + assert ('url', b'/path') in events + assert 'complete' in events + + def test_post_with_query(self, http_parser): + """Parse a POST request with query string.""" + parser_class = get_parser_class(http_parser) + + parser = parser_class() + parser.feed(b"POST /api/data?foo=bar&baz=qux HTTP/1.1\r\n\r\n") + + assert parser.method == b"POST" + assert parser.path == b"/api/data?foo=bar&baz=qux" + assert parser.http_version == (1, 1) + assert parser.is_complete + + def test_http_10_version(self, http_parser): + """Parse HTTP/1.0 request.""" + parser_class = get_parser_class(http_parser) + + parser = parser_class() + parser.feed(b"GET / HTTP/1.0\r\n\r\n") + + assert parser.method == b"GET" + assert parser.http_version == (1, 0) + assert parser.is_complete + + def test_various_methods(self, 
http_parser): + """Test parsing various HTTP methods.""" + parser_class = get_parser_class(http_parser) + methods = [b"GET", b"POST", b"PUT", b"DELETE", b"PATCH", b"HEAD", b"OPTIONS"] + + for method in methods: + parser = parser_class() + parser.feed(method + b" / HTTP/1.1\r\n\r\n") + assert parser.method == method + + +class TestHeaderParsing: + """Test header parsing for both implementations.""" + + def test_single_header(self, http_parser): + """Parse a request with single header.""" + parser_class = get_parser_class(http_parser) + headers = [] + + parser = parser_class( + on_header=lambda n, v: headers.append((n, v)), + ) + parser.feed(b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n") + + assert len(parser.headers) == 1 + header_dict = normalize_headers(parser.headers) + assert header_dict[b"host"] == b"localhost" + callback_dict = normalize_headers(headers) + assert callback_dict[b"host"] == b"localhost" + + def test_multiple_headers(self, http_parser): + """Parse a request with multiple headers.""" + parser_class = get_parser_class(http_parser) + + parser = parser_class() + parser.feed( + b"GET / HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"User-Agent: TestClient\r\n" + b"Accept: */*\r\n" + b"\r\n" + ) + + assert len(parser.headers) == 3 + header_dict = normalize_headers(parser.headers) + assert header_dict[b"host"] == b"localhost" + assert header_dict[b"user-agent"] == b"TestClient" + assert header_dict[b"accept"] == b"*/*" + + def test_header_with_spaces(self, http_parser): + """Parse headers with leading/trailing spaces in values.""" + parser_class = get_parser_class(http_parser) + + parser = parser_class() + parser.feed( + b"GET / HTTP/1.1\r\n" + b"Host: localhost \r\n" + b"\r\n" + ) + + header_dict = normalize_headers(parser.headers) + assert header_dict[b"host"] == b"localhost" + + def test_empty_header_value(self, http_parser): + """Parse header with empty value.""" + parser_class = get_parser_class(http_parser) + + parser = parser_class() + parser.feed( + 
b"GET / HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"X-Empty:\r\n" + b"\r\n" + ) + + header_dict = normalize_headers(parser.headers) + assert header_dict[b"x-empty"] == b"" + + def test_large_header_value(self, http_parser): + """Parse header with large value.""" + parser_class = get_parser_class(http_parser) + + large_value = b"x" * 4096 + + parser = parser_class() + parser.feed( + b"GET / HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"X-Large: " + large_value + b"\r\n" + b"\r\n" + ) + + header_dict = normalize_headers(parser.headers) + assert header_dict[b"x-large"] == large_value + + +class TestBodyHandling: + """Test body parsing for both implementations.""" + + def test_content_length_body(self, http_parser): + """Parse request with Content-Length body.""" + parser_class = get_parser_class(http_parser) + body_chunks = [] + + parser = parser_class( + on_body=lambda chunk: body_chunks.append(chunk), + ) + parser.feed( + b"POST /data HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Content-Length: 13\r\n" + b"\r\n" + b"Hello, World!" + ) + + assert parser.content_length == 13 + assert not parser.is_chunked + assert b"".join(body_chunks) == b"Hello, World!" 
+ assert parser.is_complete + + def test_content_length_incremental(self, http_parser): + """Parse body arriving in multiple chunks.""" + parser_class = get_parser_class(http_parser) + body_chunks = [] + + parser = parser_class( + on_body=lambda chunk: body_chunks.append(chunk), + ) + + # Send headers + parser.feed( + b"POST /data HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Content-Length: 10\r\n" + b"\r\n" + ) + assert not parser.is_complete + + # Send body in parts + parser.feed(b"Hello") + assert not parser.is_complete + parser.feed(b"World") + assert parser.is_complete + + assert b"".join(body_chunks) == b"HelloWorld" + + def test_chunked_encoding(self, http_parser): + """Parse chunked transfer-encoded body.""" + parser_class = get_parser_class(http_parser) + body_chunks = [] + + parser = parser_class( + on_body=lambda chunk: body_chunks.append(chunk), + ) + parser.feed( + b"POST /data HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b"5\r\n" + b"Hello\r\n" + b"6\r\n" + b"World!\r\n" + b"0\r\n" + b"\r\n" + ) + + assert parser.is_chunked + assert b"".join(body_chunks) == b"HelloWorld!" 
+ assert parser.is_complete + + def test_chunked_with_extensions(self, http_parser): + """Parse chunked body with chunk extensions.""" + parser_class = get_parser_class(http_parser) + body_chunks = [] + + parser = parser_class( + on_body=lambda chunk: body_chunks.append(chunk), + ) + parser.feed( + b"POST /data HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b"5;ext=value\r\n" + b"Hello\r\n" + b"0\r\n" + b"\r\n" + ) + + assert b"".join(body_chunks) == b"Hello" + assert parser.is_complete + + def test_no_body_get(self, http_parser): + """GET request has no body.""" + parser_class = get_parser_class(http_parser) + body_chunks = [] + + parser = parser_class( + on_body=lambda chunk: body_chunks.append(chunk), + ) + parser.feed( + b"GET / HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"\r\n" + ) + + assert parser.content_length is None + assert not parser.is_chunked + assert body_chunks == [] + assert parser.is_complete + + +class TestConnectionHandling: + """Test connection handling and keep-alive for both implementations.""" + + def test_http11_keepalive_default(self, http_parser): + """HTTP/1.1 defaults to keep-alive.""" + parser_class = get_parser_class(http_parser) + + parser = parser_class() + parser.feed( + b"GET / HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"\r\n" + ) + + assert parser.should_keep_alive is True + + def test_http11_connection_close(self, http_parser): + """HTTP/1.1 with Connection: close.""" + parser_class = get_parser_class(http_parser) + + parser = parser_class() + parser.feed( + b"GET / HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Connection: close\r\n" + b"\r\n" + ) + + assert parser.should_keep_alive is False + + def test_http10_no_keepalive(self, http_parser): + """HTTP/1.0 defaults to no keep-alive.""" + parser_class = get_parser_class(http_parser) + + parser = parser_class() + parser.feed( + b"GET / HTTP/1.0\r\n" + b"Host: localhost\r\n" + b"\r\n" + ) + + assert parser.should_keep_alive is False + + def 
test_http10_with_keepalive(self, http_parser): + """HTTP/1.0 with Connection: keep-alive.""" + parser_class = get_parser_class(http_parser) + + parser = parser_class() + parser.feed( + b"GET / HTTP/1.0\r\n" + b"Host: localhost\r\n" + b"Connection: keep-alive\r\n" + b"\r\n" + ) + + assert parser.should_keep_alive is True + + +class TestParserReset: + """Test parser reset for keep-alive connections.""" + + def test_reset_after_request(self, http_parser): + """Parser can be reset for a new request.""" + parser_class = get_parser_class(http_parser) + complete_count = [0] + + parser = parser_class( + on_message_complete=lambda: complete_count.__setitem__(0, complete_count[0] + 1), + ) + + # First request + parser.feed(b"GET /first HTTP/1.1\r\n\r\n") + assert parser.path == b"/first" + assert parser.is_complete + + # Reset and send second request + parser.reset() + assert not parser.is_complete + # H1CProtocol resets to b'', PythonProtocol to None + assert not parser.method + + parser.feed(b"GET /second HTTP/1.1\r\n\r\n") + assert parser.path == b"/second" + assert parser.is_complete + + assert complete_count[0] == 2 + + +class TestCallbackBehavior: + """Test callback behavior consistency.""" + + def test_all_callbacks_fire(self, http_parser): + """All callbacks fire in correct order.""" + parser_class = get_parser_class(http_parser) + events = [] + + parser = parser_class( + on_message_begin=lambda: events.append('begin'), + on_url=lambda url: events.append(('url', url)), + on_header=lambda n, v: events.append(('header', n.lower(), v)), + on_headers_complete=lambda: events.append('headers_complete'), + on_body=lambda chunk: events.append(('body', chunk)), + on_message_complete=lambda: events.append('complete'), + ) + + parser.feed( + b"POST / HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Content-Length: 4\r\n" + b"\r\n" + b"test" + ) + + assert events[0] == 'begin' + assert events[1] == ('url', b'/') + assert ('header', b'host', b'localhost') in events + assert ('header', 
b'content-length', b'4') in events + assert 'headers_complete' in events + assert ('body', b'test') in events + assert events[-1] == 'complete' + + def test_skip_body_on_headers_complete(self, http_parser): + """Return True from on_headers_complete skips body parsing.""" + parser_class = get_parser_class(http_parser) + body_chunks = [] + + parser = parser_class( + on_headers_complete=lambda: True, # Skip body + on_body=lambda chunk: body_chunks.append(chunk), + ) + + parser.feed( + b"POST / HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Content-Length: 10\r\n" + b"\r\n" + b"0123456789" + ) + + assert parser.is_complete + assert body_chunks == [] # Body was skipped + + +class TestCallbackRequest: + """Test CallbackRequest building from parser state.""" + + def test_non_ascii_path_decoding(self, http_parser): + """Test that percent-encoded UTF-8 paths are decoded correctly. + + Per ASGI spec: + - path: percent-decoded UTF-8 string + - raw_path: original bytes as received + """ + from gunicorn.asgi.parser import CallbackRequest + + parser_class = get_parser_class(http_parser) + parser = parser_class() + + # ö = %C3%B6 in UTF-8 percent-encoded + parser.feed(b"GET /%C3%B6/ HTTP/1.1\r\nHost: test\r\n\r\n") + + request = CallbackRequest.from_parser(parser) + + # path should be percent-decoded UTF-8 string + assert request.path == "/\u00f6/" # /ö/ + # raw_path should be original bytes + assert request.raw_path == b"/%C3%B6/" + + def test_non_ascii_path_with_query(self, http_parser): + """Test percent-encoded path with query string.""" + from gunicorn.asgi.parser import CallbackRequest + + parser_class = get_parser_class(http_parser) + parser = parser_class() + + # Japanese: /日本/ = /%E6%97%A5%E6%9C%AC/ + parser.feed(b"GET /%E6%97%A5%E6%9C%AC/?q=test HTTP/1.1\r\nHost: test\r\n\r\n") + + request = CallbackRequest.from_parser(parser) + + assert request.path == "/\u65e5\u672c/" # /日本/ + assert request.raw_path == b"/%E6%97%A5%E6%9C%AC/" + assert request.query == "q=test" + + def 
test_invalid_utf8_path(self, http_parser): + """Test that invalid UTF-8 sequences use replacement character.""" + from gunicorn.asgi.parser import CallbackRequest + + parser_class = get_parser_class(http_parser) + parser = parser_class() + + # %FF is invalid UTF-8 + parser.feed(b"GET /%FF HTTP/1.1\r\nHost: test\r\n\r\n") + + request = CallbackRequest.from_parser(parser) + + # Should use replacement character for invalid bytes + assert "\ufffd" in request.path + assert request.raw_path == b"/%FF" + + def test_simple_ascii_path(self, http_parser): + """Test that simple ASCII paths work unchanged.""" + from gunicorn.asgi.parser import CallbackRequest + + parser_class = get_parser_class(http_parser) + parser = parser_class() + + parser.feed(b"GET /api/users HTTP/1.1\r\nHost: test\r\n\r\n") + + request = CallbackRequest.from_parser(parser) + + assert request.path == "/api/users" + assert request.raw_path == b"/api/users" + + def test_percent_encoded_ascii(self, http_parser): + """Test percent-encoded ASCII characters.""" + from gunicorn.asgi.parser import CallbackRequest + + parser_class = get_parser_class(http_parser) + parser = parser_class() + + # Space encoded as %20 + parser.feed(b"GET /hello%20world HTTP/1.1\r\nHost: test\r\n\r\n") + + request = CallbackRequest.from_parser(parser) + + assert request.path == "/hello world" + assert request.raw_path == b"/hello%20world" diff --git a/tests/test_asgi_compliance.py b/tests/test_asgi_compliance.py index b0ac919d..5e97ad0e 100644 --- a/tests/test_asgi_compliance.py +++ b/tests/test_asgi_compliance.py @@ -9,9 +9,10 @@ Tests that gunicorn's ASGI implementation conforms to the ASGI 3.0 spec: https://asgi.readthedocs.io/en/latest/specs/main.html """ -import asyncio from unittest import mock +import pytest + from gunicorn.config import Config @@ -37,7 +38,9 @@ class TestASGIVersion: """Create a mock HTTP request.""" request = mock.Mock() request.method = kwargs.get("method", "GET") - request.path = kwargs.get("path", "/") + 
path = kwargs.get("path", "/") + request.path = path + request.raw_path = kwargs.get("raw_path", path.encode("latin-1") if path else b"") request.query = kwargs.get("query", "") request.version = kwargs.get("version", (1, 1)) request.scheme = kwargs.get("scheme", "http") @@ -136,7 +139,9 @@ class TestHTTPScopeKeys: """Create a mock HTTP request.""" request = mock.Mock() request.method = kwargs.get("method", "GET") - request.path = kwargs.get("path", "/") + path = kwargs.get("path", "/") + request.path = path + request.raw_path = kwargs.get("raw_path", path.encode("latin-1") if path else b"") request.query = kwargs.get("query", "") request.version = kwargs.get("version", (1, 1)) request.scheme = kwargs.get("scheme", "http") @@ -653,6 +658,7 @@ class TestStateSharing: request = mock.Mock() request.method = "GET" request.path = "/" + request.raw_path = b"/" request.query = "" request.version = (1, 1) request.scheme = "http" @@ -730,31 +736,46 @@ class TestHTTPDisconnectEvent: return protocol def test_disconnect_event_type(self): - """Test that disconnect event has correct type per ASGI spec.""" + """Test that disconnect event signals body receiver per ASGI spec.""" + from gunicorn.asgi.protocol import BodyReceiver + protocol = self._create_protocol() - protocol._receive_queue = asyncio.Queue() + + # Create a mock request for the body receiver + mock_request = mock.Mock() + mock_request.content_length = 100 + mock_request.chunked = False + + body_receiver = BodyReceiver(mock_request, protocol) + protocol._body_receiver = body_receiver # Simulate client disconnect protocol.connection_lost(None) - # Get the message from queue - msg = protocol._receive_queue.get_nowait() - - # Per ASGI spec: type MUST be "http.disconnect" - assert msg["type"] == "http.disconnect" + # Per ASGI spec: disconnect should be signaled + assert body_receiver._closed def test_disconnect_event_sent_on_connection_lost(self): - """Test that http.disconnect is sent when connection is lost.""" - 
protocol = self._create_protocol() - protocol._receive_queue = asyncio.Queue() + """Test that disconnect is signaled when connection is lost.""" + from gunicorn.asgi.protocol import BodyReceiver - assert protocol._receive_queue.empty() + protocol = self._create_protocol() + + # Create a mock request for the body receiver + mock_request = mock.Mock() + mock_request.content_length = 100 + mock_request.chunked = False + + body_receiver = BodyReceiver(mock_request, protocol) + protocol._body_receiver = body_receiver + + assert not body_receiver._closed # Simulate client disconnect protocol.connection_lost(None) - # Queue should have disconnect message - assert not protocol._receive_queue.empty() + # Disconnect should have been signaled + assert body_receiver._closed def test_disconnect_sets_closed_flag(self): """Test that connection_lost sets the closed flag.""" @@ -788,18 +809,33 @@ class TestHTTPDisconnectEvent: # Cancellation should be scheduled after grace period protocol.worker.loop.call_later.assert_called_once() - def test_disconnect_message_format(self): + @pytest.mark.asyncio + async def test_disconnect_message_format(self): """Test http.disconnect message format per ASGI spec. - The disconnect message should only contain 'type' key. + When body is complete and disconnect is signaled, receive() + should return {"type": "http.disconnect"}. 
""" + from gunicorn.asgi.protocol import BodyReceiver + protocol = self._create_protocol() - protocol._receive_queue = asyncio.Queue() - protocol.connection_lost(None) + # Create a mock request with no body + mock_request = mock.Mock() + mock_request.content_length = 0 + mock_request.chunked = False - msg = protocol._receive_queue.get_nowait() + body_receiver = BodyReceiver(mock_request, protocol) + protocol._body_receiver = body_receiver + + # Get initial body message (empty body) + msg1 = await body_receiver.receive() + assert msg1["type"] == "http.request" + assert msg1["more_body"] is False + + # Now receive should return disconnect + msg2 = await body_receiver.receive() # Per ASGI spec, disconnect message only has 'type' - assert msg == {"type": "http.disconnect"} - assert len(msg) == 1 + assert msg2 == {"type": "http.disconnect"} + assert len(msg2) == 1 diff --git a/tests/test_asgi_disconnect.py b/tests/test_asgi_disconnect.py index 8423ce7d..7de45bb3 100644 --- a/tests/test_asgi_disconnect.py +++ b/tests/test_asgi_disconnect.py @@ -50,41 +50,58 @@ class TestASGIGracefulDisconnect: assert protocol._closed is True - def test_disconnect_sends_message_to_queue(self, mock_worker): - """Test that connection_lost sends http.disconnect to receive queue.""" + def test_disconnect_signals_body_receiver(self, mock_worker): + """Test that connection_lost signals the body receiver.""" + from gunicorn.asgi.protocol import BodyReceiver + protocol = ASGIProtocol(mock_worker) protocol.reader = mock.Mock() mock_worker.nr_conns = 1 - # Create a receive queue (simulating active request) - protocol._receive_queue = asyncio.Queue() + # Create a mock request for the body receiver + mock_request = mock.Mock() + mock_request.content_length = 100 + mock_request.chunked = False + + # Create a body receiver (simulating active request) + body_receiver = BodyReceiver(mock_request, protocol) + protocol._body_receiver = body_receiver + + # Verify disconnect flag is not set initially + 
assert not body_receiver._closed # Simulate connection lost protocol.connection_lost(None) - # Check that disconnect message was sent - assert not protocol._receive_queue.empty() - msg = protocol._receive_queue.get_nowait() - assert msg == {"type": "http.disconnect"} + # Check that disconnect flag was set + assert body_receiver._closed def test_disconnect_is_idempotent(self, mock_worker): """Test that connection_lost can be called multiple times safely.""" + from gunicorn.asgi.protocol import BodyReceiver + protocol = ASGIProtocol(mock_worker) protocol.reader = mock.Mock() mock_worker.nr_conns = 2 # Start with 2 so we can verify only 1 is decremented - protocol._receive_queue = asyncio.Queue() + # Create a mock request for the body receiver + mock_request = mock.Mock() + mock_request.content_length = 100 + mock_request.chunked = False + + body_receiver = BodyReceiver(mock_request, protocol) + protocol._body_receiver = body_receiver # First call should work protocol.connection_lost(None) assert protocol._closed is True assert mock_worker.nr_conns == 1 - assert protocol._receive_queue.qsize() == 1 + assert body_receiver._closed # Second call should be a no-op protocol.connection_lost(None) assert mock_worker.nr_conns == 1 # Should not decrement again - assert protocol._receive_queue.qsize() == 1 # Should not add another message + # Closed flag is still set def test_disconnect_does_not_cancel_immediately(self, mock_worker): """Test that connection_lost doesn't cancel task immediately.""" @@ -156,41 +173,26 @@ class TestASGIGracefulDisconnect: @pytest.mark.asyncio async def test_receive_returns_disconnect_when_closed(self, mock_worker): """Test that receive() returns http.disconnect when connection is closed.""" + from gunicorn.asgi.protocol import BodyReceiver + protocol = ASGIProtocol(mock_worker) protocol._closed = True - # Create receive queue with body complete - receive_queue = asyncio.Queue() - protocol._receive_queue = receive_queue + # Create a mock request 
with no body + mock_request = mock.Mock() + mock_request.content_length = 0 + mock_request.chunked = False - # Add initial body message - await receive_queue.put({ - "type": "http.request", - "body": b"", - "more_body": False, - }) + body_receiver = BodyReceiver(mock_request, protocol) + protocol._body_receiver = body_receiver - # Simulate what happens in _handle_http_request - body_complete = False - - async def receive(): - nonlocal body_complete - if protocol._closed and body_complete: - return {"type": "http.disconnect"} - - msg = await receive_queue.get() - - if msg.get("type") == "http.request" and not msg.get("more_body", True): - body_complete = True - - return msg - - # First receive gets the body - msg1 = await receive() + # First receive gets the body (empty) + msg1 = await body_receiver.receive() assert msg1["type"] == "http.request" + assert msg1["more_body"] is False - # Second receive should get disconnect - msg2 = await receive() + # Second receive should get disconnect (body complete) + msg2 = await body_receiver.receive() assert msg2["type"] == "http.disconnect" diff --git a/tests/test_asgi_http_scope.py b/tests/test_asgi_http_scope.py index c4ddfe38..84827239 100644 --- a/tests/test_asgi_http_scope.py +++ b/tests/test_asgi_http_scope.py @@ -11,7 +11,6 @@ and extension support. 
from unittest import mock -import pytest from gunicorn.config import Config @@ -40,7 +39,9 @@ class TestHTTPScopeBuilding: """Create a mock HTTP request.""" request = mock.Mock() request.method = kwargs.get("method", "GET") - request.path = kwargs.get("path", "/") + path = kwargs.get("path", "/") + request.path = path + request.raw_path = kwargs.get("raw_path", path.encode("latin-1") if path else b"") request.query = kwargs.get("query", "") request.version = kwargs.get("version", (1, 1)) request.scheme = kwargs.get("scheme", "http") @@ -114,7 +115,9 @@ class TestPathHandling: """Create a mock HTTP request.""" request = mock.Mock() request.method = kwargs.get("method", "GET") - request.path = kwargs.get("path", "/") + path = kwargs.get("path", "/") + request.path = path + request.raw_path = kwargs.get("raw_path", path.encode("latin-1") if path else b"") request.query = kwargs.get("query", "") request.version = kwargs.get("version", (1, 1)) request.scheme = kwargs.get("scheme", "http") @@ -194,7 +197,9 @@ class TestQueryStringHandling: """Create a mock HTTP request.""" request = mock.Mock() request.method = kwargs.get("method", "GET") - request.path = kwargs.get("path", "/") + path = kwargs.get("path", "/") + request.path = path + request.raw_path = kwargs.get("raw_path", path.encode("latin-1") if path else b"") request.query = kwargs.get("query", "") request.version = kwargs.get("version", (1, 1)) request.scheme = kwargs.get("scheme", "http") @@ -270,7 +275,9 @@ class TestHeaderHandling: """Create a mock HTTP request.""" request = mock.Mock() request.method = kwargs.get("method", "GET") - request.path = kwargs.get("path", "/") + path = kwargs.get("path", "/") + request.path = path + request.raw_path = kwargs.get("raw_path", path.encode("latin-1") if path else b"") request.query = kwargs.get("query", "") request.version = kwargs.get("version", (1, 1)) request.scheme = kwargs.get("scheme", "http") @@ -362,7 +369,9 @@ class TestWebSocketScope: """Create a mock 
WebSocket upgrade request.""" request = mock.Mock() request.method = "GET" - request.path = kwargs.get("path", "/ws") + path = kwargs.get("path", "/ws") + request.path = path + request.raw_path = kwargs.get("raw_path", path.encode("latin-1") if path else b"") request.query = kwargs.get("query", "") request.version = kwargs.get("version", (1, 1)) request.scheme = kwargs.get("scheme", "http") @@ -575,6 +584,7 @@ class TestAddressHandling: request = mock.Mock() request.method = "GET" request.path = "/" + request.raw_path = b"/" request.query = "" request.version = (1, 1) request.scheme = "http" diff --git a/tests/test_asgi_parser.py b/tests/test_asgi_parser.py index 15211798..e20aad17 100644 --- a/tests/test_asgi_parser.py +++ b/tests/test_asgi_parser.py @@ -6,7 +6,6 @@ Tests for ASGI HTTP parser optimizations. """ -import asyncio import ipaddress import pytest diff --git a/tests/test_asgi_streaming.py b/tests/test_asgi_streaming.py index 393a07e4..d4bd7ed4 100644 --- a/tests/test_asgi_streaming.py +++ b/tests/test_asgi_streaming.py @@ -9,7 +9,6 @@ Tests for chunked transfer encoding, Server-Sent Events (SSE), and streaming response handling. 
""" -import asyncio from unittest import mock import pytest @@ -221,43 +220,39 @@ class TestProtocolSendBody: return protocol - @pytest.mark.asyncio - async def test_send_body_without_chunking(self): + def test_send_body_without_chunking(self): """Test sending body without chunked encoding.""" protocol = self._create_protocol() - await protocol._send_body(b"Hello, World!", chunked=False) + protocol._send_body(b"Hello, World!", chunked=False) protocol.transport.write.assert_called_once_with(b"Hello, World!") - @pytest.mark.asyncio - async def test_send_body_with_chunking(self): + def test_send_body_with_chunking(self): """Test sending body with chunked encoding.""" protocol = self._create_protocol() - await protocol._send_body(b"Hello", chunked=True) + protocol._send_body(b"Hello", chunked=True) # Should write: "5\r\nHello\r\n" protocol.transport.write.assert_called_once() call_arg = protocol.transport.write.call_args[0][0] assert call_arg == b"5\r\nHello\r\n" - @pytest.mark.asyncio - async def test_send_body_empty_without_chunking(self): + def test_send_body_empty_without_chunking(self): """Test sending empty body without chunked encoding.""" protocol = self._create_protocol() - await protocol._send_body(b"", chunked=False) + protocol._send_body(b"", chunked=False) # Empty body should not write anything protocol.transport.write.assert_not_called() - @pytest.mark.asyncio - async def test_send_body_empty_with_chunking(self): + def test_send_body_empty_with_chunking(self): """Test sending empty body with chunked encoding.""" protocol = self._create_protocol() - await protocol._send_body(b"", chunked=True) + protocol._send_body(b"", chunked=True) # Empty body should not write (terminal chunk handled separately) protocol.transport.write.assert_not_called() diff --git a/tests/test_asgi_websocket_protocol.py b/tests/test_asgi_websocket_protocol.py index 5173042a..08db5866 100644 --- a/tests/test_asgi_websocket_protocol.py +++ b/tests/test_asgi_websocket_protocol.py @@ 
-9,7 +9,6 @@ Tests that gunicorn's WebSocket implementation conforms to RFC 6455: https://tools.ietf.org/html/rfc6455 """ -import asyncio import base64 import hashlib import struct @@ -176,7 +175,7 @@ class TestWebSocketFrameMasking: def _create_protocol(self): """Create a WebSocketProtocol instance for testing.""" from gunicorn.asgi.websocket import WebSocketProtocol - return WebSocketProtocol(None, None, {}, None, mock.Mock()) + return WebSocketProtocol(None, {}, None, mock.Mock()) def test_unmask_simple(self): """Test basic unmasking operation.""" @@ -299,7 +298,6 @@ class TestWebSocketProtocolInstance: return WebSocketProtocol( transport=mock.Mock(), - reader=mock.Mock(), scope=scope, app=mock.AsyncMock(), log=mock.Mock(), @@ -602,11 +600,9 @@ class TestWebSocketAsync: } transport = mock.Mock() - reader = mock.Mock() return WebSocketProtocol( transport=transport, - reader=reader, scope=scope, app=mock.AsyncMock(), log=mock.Mock(), @@ -676,3 +672,367 @@ class TestWebSocketAsync: await protocol._send({"type": "websocket.close", "code": 1000}) assert protocol.closed is True + + +# ============================================================================ +# Callback-based Data Feeding Tests +# ============================================================================ + +class TestWebSocketCallbackDataFeeding: + """Tests for callback-based data feeding (replaces StreamReader).""" + + def _create_protocol(self): + """Create a WebSocketProtocol instance for testing.""" + from gunicorn.asgi.websocket import WebSocketProtocol + return WebSocketProtocol( + transport=mock.Mock(), + scope={"type": "websocket", "headers": []}, + app=mock.AsyncMock(), + log=mock.Mock(), + ) + + def test_initial_buffer_empty(self): + """Test that initial buffer is empty.""" + protocol = self._create_protocol() + assert len(protocol._buffer) == 0 + assert protocol._eof is False + + def test_feed_data_adds_to_buffer(self): + """Test that feed_data adds bytes to buffer.""" + protocol = 
def test_feed_data_ignores_empty(self):
    """Test that feed_data ignores empty data.

    Both an empty bytes object and None must be accepted without raising
    and must leave the receive buffer empty.
    """
    protocol = self._create_protocol()

    protocol.feed_data(b"")
    assert len(protocol._buffer) == 0

    # None should not raise, just be ignored; assert it so that "ignored"
    # is actually verified rather than only implied by the absence of an
    # exception.
    protocol.feed_data(None)
    assert len(protocol._buffer) == 0
@pytest.mark.asyncio
async def test_read_exact_waits_for_data(self):
    """A read for more bytes than are buffered stays pending until fed."""
    import asyncio

    payload = b"1234567890"
    ws = self._create_protocol()

    # Kick off the read on an empty buffer, then yield so the task gets a
    # chance to start waiting.
    pending = asyncio.create_task(ws._read_exact(10))
    await asyncio.sleep(0.01)
    assert not pending.done()

    # Supplying the full amount must let the read finish.
    ws.feed_data(payload)
    assert await asyncio.wait_for(pending, timeout=1.0) == payload
@pytest.mark.asyncio
async def test_read_exact_multiple_feeds_before_wait(self):
    """Data that is already buffered is returned without waiting."""
    import asyncio

    ws = self._create_protocol()
    # Everything arrives before the read starts.
    ws.feed_data(b"Complete message here")

    # The short timeout would trip if the read had to wait for more input.
    head = await asyncio.wait_for(ws._read_exact(8), timeout=0.1)
    assert head == b"Complete"

    # Whatever was not requested must still be sitting in the buffer.
    leftover = bytes(ws._buffer)
    assert leftover == b" message here"
read to return None + protocol.feed_eof() + + result = await asyncio.wait_for(read_task, timeout=1.0) + assert result is None + + +# ============================================================================ +# WebSocket Fragmented Message Tests (RFC 6455 Section 5.4) +# ============================================================================ + +class TestWebSocketFragmentedMessages: + """Tests for WebSocket fragmented message handling.""" + + def _create_protocol(self): + """Create a WebSocketProtocol instance for testing.""" + from gunicorn.asgi.websocket import WebSocketProtocol + return WebSocketProtocol( + transport=mock.Mock(), + scope={"type": "websocket", "headers": []}, + app=mock.AsyncMock(), + log=mock.Mock(), + ) + + def _create_masked_frame(self, fin, opcode, payload, mask_key=None): + """Create a masked WebSocket frame. + + Args: + fin: FIN bit (1 for final, 0 for continuation) + opcode: Frame opcode + payload: Frame payload bytes + mask_key: 4-byte masking key (generated if None) + + Returns: + bytes: Complete masked frame + """ + if mask_key is None: + mask_key = bytes([0x37, 0xfa, 0x21, 0x3d]) + + frame = bytearray() + + # First byte: FIN + RSV(000) + opcode + frame.append((fin << 7) | opcode) + + # Second byte: MASK(1) + length + length = len(payload) + if length < 126: + frame.append(0x80 | length) + elif length < 65536: + frame.append(0x80 | 126) + frame.extend(struct.pack("!H", length)) + else: + frame.append(0x80 | 127) + frame.extend(struct.pack("!Q", length)) + + # Masking key + frame.extend(mask_key) + + # Masked payload + masked_payload = bytes(b ^ mask_key[i % 4] for i, b in enumerate(payload)) + frame.extend(masked_payload) + + return bytes(frame) + + @pytest.mark.asyncio + async def test_fragmented_message_reassembly(self): + """Test reassembly of fragmented text message with multiple continuation frames.""" + from gunicorn.asgi.websocket import ( + OPCODE_TEXT, OPCODE_CONTINUATION, OPCODE_CONTINUATION as CONT + ) + import asyncio 
@pytest.mark.asyncio
async def test_control_frame_during_fragmentation(self):
    """A ping interleaved with a fragmented text message is surfaced on its own.

    RFC 6455 Section 5.4: Control frames MAY be injected in the middle
    of a fragmented message.
    """
    from gunicorn.asgi.websocket import (
        OPCODE_TEXT, OPCODE_CONTINUATION, OPCODE_PING
    )
    import asyncio

    ws = self._create_protocol()
    build = self._create_masked_frame

    wire = b"".join([
        # Opening fragment of the text message (FIN=0)
        build(fin=0, opcode=OPCODE_TEXT, payload=b"Hello"),
        # Ping in the middle (control frames are always FIN=1)
        build(fin=1, opcode=OPCODE_PING, payload=b"ping"),
        # Closing fragment of the text message (FIN=1)
        build(fin=1, opcode=OPCODE_CONTINUATION, payload=b" World"),
    ])
    ws.feed_data(wire)

    expected_reads = [
        (OPCODE_CONTINUATION, b""),      # fragment started, waiting for more
        (OPCODE_PING, b"ping"),          # control frame handled separately
        (OPCODE_TEXT, b"Hello World"),   # complete reassembled message
    ]
    for expected in expected_reads:
        observed = await asyncio.wait_for(ws._read_frame(), timeout=1.0)
        assert observed == expected
import asyncio import errno import os -import signal import socket -import sys -import time -import threading from unittest import mock import pytest @@ -120,7 +116,7 @@ class FakeListener: def _has_uvloop(): """Check if uvloop is available.""" try: - import uvloop + import uvloop # noqa: F401 return True except ImportError: return False @@ -337,9 +333,9 @@ class TestLifespanManager: async def app(scope, receive, send): assert "state" in scope scope["state"]["db"] = "connected" - message = await receive() + _ = await receive() await send({"type": "lifespan.startup.complete"}) - message = await receive() + _ = await receive() await send({"type": "lifespan.shutdown.complete"}) manager = LifespanManager(app, mock.Mock(), state) @@ -393,7 +389,7 @@ class TestWebSocketProtocol: from gunicorn.asgi.websocket import WebSocketProtocol # Create a minimal protocol instance - protocol = WebSocketProtocol(None, None, {}, None, mock.Mock()) + protocol = WebSocketProtocol(None, {}, None, mock.Mock()) # Test unmasking (XOR operation) masking_key = bytes([0x37, 0xfa, 0x21, 0x3d]) @@ -406,7 +402,7 @@ class TestWebSocketProtocol: """Test WebSocket frame unmasking with empty payload.""" from gunicorn.asgi.websocket import WebSocketProtocol - protocol = WebSocketProtocol(None, None, {}, None, mock.Mock()) + protocol = WebSocketProtocol(None, {}, None, mock.Mock()) masking_key = bytes([0x37, 0xfa, 0x21, 0x3d]) unmasked = protocol._unmask(b"", masking_key) @@ -555,7 +551,6 @@ class TestASGIProtocol: def test_scope_building(self): """Test HTTP scope building.""" from gunicorn.asgi.protocol import ASGIProtocol - from gunicorn.asgi.message import AsyncRequest worker = mock.Mock() worker.cfg = Config() @@ -729,10 +724,11 @@ class TestASGIHTTP2Priority: protocol = ASGIProtocol(worker) # Create mock HTTP/1.1 request (no priority attributes) - request = mock.Mock(spec=['method', 'path', 'query', 'version', + request = mock.Mock(spec=['method', 'path', 'raw_path', 'query', 'version', 'scheme', 
# Flags incompatible with fast parser (require Python parser features)
_FAST_INCOMPATIBLE_FLAGS = ('permit_obsolete_folding', 'strip_header_spaces')

# Exceptions that only the Python parser raises (C parser has different validation)
_PYTHON_ONLY_EXCEPTIONS = (ObsoleteFolding, InvalidSchemeHeaders)

# C parser may raise different but valid exceptions for these cases
_FAST_PARSER_ALTERNATES = {
    InvalidRequestMethod: (InvalidRequestLine,),  # e.g. "GET:" raises InvalidRequestLine
}


@pytest.mark.parametrize("fname", httpfiles)
def test_http_parser(fname, http_parser):
    """Check that each invalid request fixture is rejected by the given parser.

    The companion ``.py`` file of every ``.http`` fixture supplies the
    expected exception (``request``) and the config (``cfg``). For the
    fast parser, fixtures relying on Python-parser-only features are
    skipped and documented alternate exceptions are accepted.
    """
    stem, _ext = os.path.splitext(fname)
    env = treq.load_py(stem + ".py", http_parser=http_parser)
    expect, cfg = env["request"], env["cfg"]

    # Unless the fast parser widens it below, only `expect` is accepted.
    acceptable = expect

    if http_parser == 'fast':
        # Compatibility flags the fast parser cannot honour.
        blocking_flag = next(
            (flag for flag in _FAST_INCOMPATIBLE_FLAGS if getattr(cfg, flag, False)),
            None,
        )
        if blocking_flag is not None:
            pytest.skip(f"fast parser incompatible with {blocking_flag}")

        # Exceptions that only the Python parser is able to raise.
        python_only = expect in _PYTHON_ONLY_EXCEPTIONS or (
            isinstance(expect, type) and issubclass(expect, _PYTHON_ONLY_EXCEPTIONS)
        )
        if python_only:
            pytest.skip(f"fast parser does not raise {expect.__name__}")

        # The fast parser may legitimately report a different exception.
        if expect in _FAST_PARSER_ALTERNATES:
            acceptable = (expect,) + _FAST_PARSER_ALTERNATES[expect]

    req = treq.badrequest(fname)

    with pytest.raises(acceptable):
        req.check(cfg)
+""" + +import pytest +from gunicorn.asgi.parser import PythonProtocol, CallbackRequest, ParseError + + +class TestPythonProtocolBasic: + """Test basic request parsing.""" + + def test_simple_get_request(self): + """Test parsing a simple GET request.""" + headers_complete = [] + message_complete = [] + + parser = PythonProtocol( + on_headers_complete=lambda: headers_complete.append(True), + on_message_complete=lambda: message_complete.append(True), + ) + + data = b"GET /path HTTP/1.1\r\nHost: example.com\r\n\r\n" + parser.feed(data) + + assert parser.method == b"GET" + assert parser.path == b"/path" + assert parser.http_version == (1, 1) + assert len(parser.headers) == 1 + assert parser.headers[0] == (b"host", b"example.com") + assert parser.is_complete is True + assert len(headers_complete) == 1 + assert len(message_complete) == 1 + + def test_get_with_query_string(self): + """Test parsing GET with query string.""" + parser = PythonProtocol() + + data = b"GET /search?q=test&page=1 HTTP/1.1\r\nHost: example.com\r\n\r\n" + parser.feed(data) + + assert parser.method == b"GET" + assert parser.path == b"/search?q=test&page=1" + assert parser.is_complete is True + + def test_http_10_request(self): + """Test parsing HTTP/1.0 request.""" + parser = PythonProtocol() + + data = b"GET / HTTP/1.0\r\nHost: example.com\r\n\r\n" + parser.feed(data) + + assert parser.http_version == (1, 0) + assert parser.should_keep_alive is False # HTTP/1.0 default + + def test_http_10_with_keepalive(self): + """Test HTTP/1.0 with explicit keep-alive.""" + parser = PythonProtocol() + + data = b"GET / HTTP/1.0\r\nHost: example.com\r\nConnection: keep-alive\r\n\r\n" + parser.feed(data) + + assert parser.http_version == (1, 0) + assert parser.should_keep_alive is True + + def test_multiple_headers(self): + """Test parsing multiple headers.""" + parser = PythonProtocol() + + data = ( + b"GET / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Accept: text/html\r\n" + b"Accept-Language: en-US\r\n" + 
b"User-Agent: Test/1.0\r\n" + b"\r\n" + ) + parser.feed(data) + + assert len(parser.headers) == 4 + header_names = [h[0] for h in parser.headers] + assert b"host" in header_names + assert b"accept" in header_names + assert b"accept-language" in header_names + assert b"user-agent" in header_names + + +class TestPythonProtocolBody: + """Test request body parsing.""" + + def test_post_with_content_length(self): + """Test POST with Content-Length body.""" + body_chunks = [] + + parser = PythonProtocol( + on_body=lambda chunk: body_chunks.append(chunk), + ) + + data = ( + b"POST /submit HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Content-Length: 13\r\n" + b"\r\n" + b"name=testuser" + ) + parser.feed(data) + + assert parser.method == b"POST" + assert parser.content_length == 13 + assert parser.is_complete is True + assert b"".join(body_chunks) == b"name=testuser" + + def test_chunked_body(self): + """Test chunked transfer encoding.""" + body_chunks = [] + + parser = PythonProtocol( + on_body=lambda chunk: body_chunks.append(chunk), + ) + + data = ( + b"POST / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b"5\r\nhello\r\n" + b"6\r\n world\r\n" + b"0\r\n\r\n" + ) + parser.feed(data) + + assert parser.is_chunked is True + assert parser.is_complete is True + assert b"".join(body_chunks) == b"hello world" + + def test_chunked_with_extension(self): + """Test chunked with chunk extension.""" + body_chunks = [] + + parser = PythonProtocol( + on_body=lambda chunk: body_chunks.append(chunk), + ) + + data = ( + b"POST / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b"5;ext=value\r\nhello\r\n" + b"0\r\n\r\n" + ) + parser.feed(data) + + assert b"".join(body_chunks) == b"hello" + + +class TestPythonProtocolIncremental: + """Test incremental/partial data feeding.""" + + def test_partial_request_line(self): + """Test feeding partial request line.""" + parser = PythonProtocol() + + # Feed partial 
def test_partial_body(self):
    """A Content-Length body split across two feeds completes on the second."""
    received = []
    proto = PythonProtocol(on_body=lambda chunk: received.append(chunk))

    # Headers plus the first half of the 10-byte body.
    first_feed = (
        b"POST / HTTP/1.1\r\n"
        b"Host: example.com\r\n"
        b"Content-Length: 10\r\n"
        b"\r\n"
        b"hello"
    )
    proto.feed(first_feed)
    assert proto.is_complete is False

    # The remaining five bytes finish the message.
    proto.feed(b"world")
    assert proto.is_complete is True

    whole_body = b"".join(received)
    assert whole_body == b"helloworld"
def test_all_callbacks(self):
    """Test all callbacks fire in correct order, and nothing else fires.

    Every callback records an event; the whole recorded sequence is then
    compared against the expected one, so missing, duplicated, reordered
    or trailing events all fail the test.
    """
    events = []

    parser = PythonProtocol(
        on_message_begin=lambda: events.append("begin"),
        on_url=lambda url: events.append(("url", url)),
        on_header=lambda n, v: events.append(("header", n, v)),
        # append() returns None, so this callback does not request that
        # the body be skipped.
        on_headers_complete=lambda: events.append("headers_complete"),
        on_body=lambda chunk: events.append(("body", chunk)),
        on_message_complete=lambda: events.append("complete"),
    )

    parser.feed(
        b"POST / HTTP/1.1\r\n"
        b"Host: example.com\r\n"
        b"Content-Length: 5\r\n"
        b"\r\n"
        b"hello"
    )

    # Compare the full list rather than individual indexes: this also
    # pins the number of events.
    assert events == [
        "begin",
        ("url", b"/"),
        ("header", b"host", b"example.com"),
        ("header", b"content-length", b"5"),
        "headers_complete",
        ("body", b"hello"),
        "complete",
    ]
"example.com") in request.headers + assert ("CONTENT-TYPE", "text/plain") in request.headers + + # Bytes headers (lowercase) + assert (b"host", b"example.com") in request.headers_bytes + assert (b"content-type", b"text/plain") in request.headers_bytes + + def test_from_parser_body_info(self): + """Test body info extraction.""" + parser = PythonProtocol() + parser.feed( + b"POST / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Content-Length: 10\r\n" + b"\r\n" + ) + + request = CallbackRequest.from_parser(parser) + + assert request.content_length == 10 + assert request.chunked is False + + def test_from_parser_chunked(self): + """Test chunked transfer detection.""" + parser = PythonProtocol() + parser.feed( + b"POST / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + ) + + request = CallbackRequest.from_parser(parser) + + assert request.chunked is True + + def test_should_close(self): + """Test should_close method.""" + # HTTP/1.1 with Connection: close + parser = PythonProtocol() + parser.feed(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n") + request = CallbackRequest.from_parser(parser) + assert request.should_close() is True + + # HTTP/1.1 keep-alive (default) + parser = PythonProtocol() + parser.feed(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") + request = CallbackRequest.from_parser(parser) + assert request.should_close() is False + + # HTTP/1.0 (default close) + parser = PythonProtocol() + parser.feed(b"GET / HTTP/1.0\r\n\r\n") + request = CallbackRequest.from_parser(parser) + assert request.should_close() is True + + def test_get_header(self): + """Test get_header method.""" + parser = PythonProtocol() + parser.feed( + b"GET / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"X-Custom: value\r\n" + b"\r\n" + ) + + request = CallbackRequest.from_parser(parser) + + assert request.get_header("Host") == "example.com" + assert request.get_header("x-custom") == "value" + assert request.get_header("X-CUSTOM") == "value" + assert 
request.get_header("X-Missing") is None + + def test_expect_100_continue(self): + """Test Expect: 100-continue detection.""" + parser = PythonProtocol() + parser.feed( + b"POST / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Expect: 100-continue\r\n" + b"Content-Length: 10\r\n" + b"\r\n" + ) + + request = CallbackRequest.from_parser(parser) + + assert request._expect_100_continue is True + + +class TestPythonProtocolConnectionClose: + """Test connection close handling.""" + + def test_connection_close_header(self): + """Test Connection: close header.""" + parser = PythonProtocol() + parser.feed(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n") + + assert parser.should_keep_alive is False + + def test_connection_keepalive_header(self): + """Test Connection: keep-alive header.""" + parser = PythonProtocol() + parser.feed(b"GET / HTTP/1.1\r\nConnection: keep-alive\r\n\r\n") + + assert parser.should_keep_alive is True + + def test_http11_default_keepalive(self): + """Test HTTP/1.1 defaults to keep-alive.""" + parser = PythonProtocol() + parser.feed(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") + + assert parser.should_keep_alive is True + + def test_http10_default_close(self): + """Test HTTP/1.0 defaults to close.""" + parser = PythonProtocol() + parser.feed(b"GET / HTTP/1.0\r\nHost: example.com\r\n\r\n") + + assert parser.should_keep_alive is False diff --git a/tests/test_signal_integration.py b/tests/test_signal_integration.py index f7975f78..2895c089 100644 --- a/tests/test_signal_integration.py +++ b/tests/test_signal_integration.py @@ -165,6 +165,10 @@ class TestSignalHandlingIntegration: proc.kill() pytest.fail("Gunicorn did not exit within timeout after SIGTERM") + @pytest.mark.skipif( + hasattr(sys, 'pypy_version_info'), + reason="SIGINT handling differs on PyPy, use SIGTERM test instead" + ) def test_graceful_shutdown_sigint(self, gunicorn_server): """Verify SIGINT causes graceful shutdown.""" proc, port = gunicorn_server diff --git 
a/tests/test_valid_requests.py b/tests/test_valid_requests.py index 2c71622c..6bfa2229 100644 --- a/tests/test_valid_requests.py +++ b/tests/test_valid_requests.py @@ -13,13 +13,24 @@ dirname = os.path.dirname(__file__) reqdir = os.path.join(dirname, "requests", "valid") httpfiles = glob.glob(os.path.join(reqdir, "*.http")) +# Flags incompatible with fast parser +_FAST_INCOMPATIBLE_FLAGS = ('permit_obsolete_folding', 'strip_header_spaces') + @pytest.mark.parametrize("fname", httpfiles) -def test_http_parser(fname): - env = treq.load_py(os.path.splitext(fname)[0] + ".py") +def test_http_parser(fname, http_parser): + """Test valid HTTP requests with both parser implementations.""" + env = treq.load_py(os.path.splitext(fname)[0] + ".py", http_parser=http_parser) expect = env['request'] cfg = env['cfg'] + + # Skip fast parser tests that use incompatible compatibility flags + if http_parser == 'fast': + for flag in _FAST_INCOMPATIBLE_FLAGS: + if getattr(cfg, flag, False): + pytest.skip(f"fast parser incompatible with {flag}") + req = treq.request(fname, expect) for case in req.gen_cases(cfg): diff --git a/tests/treq.py b/tests/treq.py index a37ddd9f..15148809 100644 --- a/tests/treq.py +++ b/tests/treq.py @@ -33,19 +33,26 @@ def uri(data): return ret -def load_py(fname): +def load_py(fname, http_parser='python'): + """Load test configuration from Python file. 
+ + Args: + fname: Path to the .py configuration file + http_parser: Parser to use - 'python' or 'fast' + """ module_name = '__config__' mod = types.ModuleType(module_name) setattr(mod, 'uri', uri) setattr(mod, 'cfg', Config()) loader = importlib.machinery.SourceFileLoader(module_name, fname) loader.exec_module(mod) + # Set parser after loading so test-specific configs don't override + mod.cfg.set('http_parser', http_parser) return vars(mod) def decode_hex_escapes(data): """Decode hex escape sequences like \\xAB in test data.""" - import re result = bytearray() i = 0 while i < len(data): diff --git a/tests/workers/test_geventlet.py b/tests/workers/test_geventlet.py deleted file mode 100644 index 0719f038..00000000 --- a/tests/workers/test_geventlet.py +++ /dev/null @@ -1,416 +0,0 @@ -# -# This file is part of gunicorn released under the MIT license. -# See the NOTICE for more information. - -import pytest -import sys -from unittest import mock - - -def test_import(): - """Test that the eventlet worker module can be imported.""" - try: - import eventlet - except AttributeError: - if (3, 13) > sys.version_info >= (3, 12): - pytest.skip("Ignoring eventlet failures on Python 3.12") - raise - __import__('gunicorn.workers.geventlet') - - -class TestVersionRequirement: - """Tests for eventlet version requirement checks.""" - - def test_import_error_message(self): - """Test that ImportError gives correct version message.""" - with mock.patch.dict('sys.modules', {'eventlet': None}): - # Clear cached module if present - sys.modules.pop('gunicorn.workers.geventlet', None) - with pytest.raises(RuntimeError, match="eventlet 0.40.3"): - import importlib - import gunicorn.workers.geventlet - importlib.reload(gunicorn.workers.geventlet) - - def test_version_check_requires_0_40_3(self): - """Test that version check requires eventlet 0.40.3 or higher.""" - try: - import eventlet - except (ImportError, AttributeError): - pytest.skip("eventlet not available") - - from 
packaging.version import parse as parse_version - min_version = parse_version('0.40.3') - current_version = parse_version(eventlet.__version__) - - # If we got this far, the import succeeded, meaning version is sufficient - assert current_version >= min_version - - -@pytest.fixture -def eventlet_worker(): - """Fixture to create an EventletWorker instance for testing.""" - try: - import eventlet - except (ImportError, AttributeError): - pytest.skip("eventlet not available") - - from gunicorn.workers.geventlet import EventletWorker - - # Create a minimal mock config - cfg = mock.MagicMock() - cfg.keepalive = 2 - cfg.graceful_timeout = 30 - cfg.is_ssl = False - cfg.worker_connections = 1000 - - # Create worker with mocked dependencies - worker = EventletWorker.__new__(EventletWorker) - worker.cfg = cfg - worker.alive = True - worker.sockets = [] - worker.log = mock.MagicMock() - - return worker - - -class TestEventletWorker: - """Tests for EventletWorker class.""" - - def test_worker_class_exists(self): - """Test that EventletWorker class is properly defined.""" - try: - import eventlet - except (ImportError, AttributeError): - pytest.skip("eventlet not available") - - from gunicorn.workers.geventlet import EventletWorker - from gunicorn.workers.base_async import AsyncWorker - - assert issubclass(EventletWorker, AsyncWorker) - - def test_patch_method_calls_use_hub(self, eventlet_worker): - """Test that patch() calls hubs.use_hub(). - - hubs.use_hub() must be called in patch() (after fork) because it creates - OS resources like kqueue that don't survive fork. 
- """ - from eventlet import hubs - - with mock.patch.object(hubs, 'use_hub') as mock_use_hub: - with mock.patch('gunicorn.workers.geventlet.patch_sendfile'): - eventlet_worker.patch() - - mock_use_hub.assert_called_once() - - def test_patch_method_calls_patch_sendfile(self, eventlet_worker): - """Test that patch() calls patch_sendfile().""" - from eventlet import hubs - - with mock.patch.object(hubs, 'use_hub'): - with mock.patch('gunicorn.workers.geventlet.patch_sendfile') as mock_sf: - eventlet_worker.patch() - - mock_sf.assert_called_once() - - def test_monkey_patch_called_at_import_time(self): - """Test that monkey_patch is called at module import time. - - Note: hubs.use_hub() and eventlet.monkey_patch() are called at module - import time (not in patch()) to ensure all imports are properly patched. - This test verifies the module was patched by checking eventlet state. - """ - try: - import eventlet - except (ImportError, AttributeError): - pytest.skip("eventlet not available") - - # Verify eventlet has been patched by checking that socket is patched - import socket - from eventlet.greenio import GreenSocket - - # After monkey patching, socket.socket should be GreenSocket - assert socket.socket is GreenSocket - - def test_timeout_ctx_returns_eventlet_timeout(self, eventlet_worker): - """Test that timeout_ctx() returns an eventlet.Timeout.""" - import eventlet - - timeout = eventlet_worker.timeout_ctx() - assert isinstance(timeout, eventlet.Timeout) - - def test_timeout_ctx_uses_keepalive_config(self, eventlet_worker): - """Test that timeout_ctx() uses cfg.keepalive value.""" - import eventlet - - eventlet_worker.cfg.keepalive = 5 - with mock.patch.object(eventlet, 'Timeout') as mock_timeout: - eventlet_worker.timeout_ctx() - - mock_timeout.assert_called_once_with(5, False) - - def test_timeout_ctx_with_no_keepalive(self, eventlet_worker): - """Test that timeout_ctx() handles no keepalive (None or 0).""" - import eventlet - - eventlet_worker.cfg.keepalive = 0 
- with mock.patch.object(eventlet, 'Timeout') as mock_timeout: - eventlet_worker.timeout_ctx() - - mock_timeout.assert_called_once_with(None, False) - - def test_handle_quit_spawns_greenthread(self, eventlet_worker): - """Test that handle_quit() spawns a greenthread.""" - import eventlet - - with mock.patch.object(eventlet, 'spawn') as mock_spawn: - eventlet_worker.handle_quit(None, None) - - mock_spawn.assert_called_once() - - def test_handle_usr1_spawns_greenthread(self, eventlet_worker): - """Test that handle_usr1() spawns a greenthread.""" - import eventlet - - with mock.patch.object(eventlet, 'spawn') as mock_spawn: - eventlet_worker.handle_usr1(None, None) - - mock_spawn.assert_called_once() - - def test_handle_wraps_ssl_when_configured(self, eventlet_worker): - """Test that handle() wraps socket with SSL when is_ssl is True.""" - from gunicorn.workers import geventlet - - eventlet_worker.cfg.is_ssl = True - mock_client = mock.MagicMock() - mock_listener = mock.MagicMock() - - with mock.patch.object(geventlet, 'ssl_wrap_socket') as mock_ssl: - mock_ssl.return_value = mock_client - with mock.patch('gunicorn.workers.base_async.AsyncWorker.handle'): - eventlet_worker.handle(mock_listener, mock_client, ('127.0.0.1', 8000)) - - mock_ssl.assert_called_once_with(mock_client, eventlet_worker.cfg) - - def test_handle_no_ssl_when_not_configured(self, eventlet_worker): - """Test that handle() does not wrap SSL when is_ssl is False.""" - from gunicorn.workers import geventlet - - eventlet_worker.cfg.is_ssl = False - mock_client = mock.MagicMock() - mock_listener = mock.MagicMock() - - with mock.patch.object(geventlet, 'ssl_wrap_socket') as mock_ssl: - with mock.patch('gunicorn.workers.base_async.AsyncWorker.handle'): - eventlet_worker.handle(mock_listener, mock_client, ('127.0.0.1', 8000)) - - mock_ssl.assert_not_called() - - -class TestAlreadyHandled: - """Tests for is_already_handled() method.""" - - def test_is_already_handled_new_style(self, eventlet_worker): - 
"""Test is_already_handled with eventlet >= 0.30.3 (WSGI_LOCAL).""" - from gunicorn.workers import geventlet - - # Mock the new-style WSGI_LOCAL.already_handled - mock_wsgi_local = mock.MagicMock() - mock_wsgi_local.already_handled = True - - with mock.patch.object(geventlet, 'EVENTLET_WSGI_LOCAL', mock_wsgi_local): - with pytest.raises(StopIteration): - eventlet_worker.is_already_handled(mock.MagicMock()) - - def test_is_already_handled_old_style(self, eventlet_worker): - """Test is_already_handled with eventlet < 0.30.3 (ALREADY_HANDLED).""" - from gunicorn.workers import geventlet - - sentinel = object() - - with mock.patch.object(geventlet, 'EVENTLET_WSGI_LOCAL', None): - with mock.patch.object(geventlet, 'EVENTLET_ALREADY_HANDLED', sentinel): - with pytest.raises(StopIteration): - eventlet_worker.is_already_handled(sentinel) - - def test_is_already_handled_returns_parent_result(self, eventlet_worker): - """Test is_already_handled falls through to parent when not handled.""" - from gunicorn.workers import geventlet - - with mock.patch.object(geventlet, 'EVENTLET_WSGI_LOCAL', None): - with mock.patch.object(geventlet, 'EVENTLET_ALREADY_HANDLED', None): - with mock.patch('gunicorn.workers.base_async.AsyncWorker.is_already_handled') as mock_parent: - mock_parent.return_value = False - result = eventlet_worker.is_already_handled(mock.MagicMock()) - - assert result is False - mock_parent.assert_called_once() - - -class TestPatchSendfile: - """Tests for patch_sendfile() function.""" - - def test_patch_sendfile_adds_method_when_missing(self): - """Test that patch_sendfile adds sendfile to GreenSocket if missing.""" - try: - import eventlet - except (ImportError, AttributeError): - pytest.skip("eventlet not available") - - from gunicorn.workers.geventlet import patch_sendfile, _eventlet_socket_sendfile - from eventlet.greenio import GreenSocket - - # Remove sendfile if it exists - original = getattr(GreenSocket, 'sendfile', None) - if hasattr(GreenSocket, 'sendfile'): 
- delattr(GreenSocket, 'sendfile') - - try: - patch_sendfile() - assert hasattr(GreenSocket, 'sendfile') - assert GreenSocket.sendfile == _eventlet_socket_sendfile - finally: - # Restore original state - if original is not None: - GreenSocket.sendfile = original - elif hasattr(GreenSocket, 'sendfile'): - delattr(GreenSocket, 'sendfile') - - def test_patch_sendfile_preserves_existing_method(self): - """Test that patch_sendfile does not override existing sendfile.""" - try: - import eventlet - except (ImportError, AttributeError): - pytest.skip("eventlet not available") - - from gunicorn.workers.geventlet import patch_sendfile - from eventlet.greenio import GreenSocket - - # If sendfile exists, it should be preserved - if hasattr(GreenSocket, 'sendfile'): - original = GreenSocket.sendfile - patch_sendfile() - assert GreenSocket.sendfile == original - - -class TestEventletSocketSendfile: - """Tests for _eventlet_socket_sendfile() function.""" - - def test_sendfile_raises_on_non_blocking(self): - """Test that sendfile raises ValueError for non-blocking sockets.""" - try: - import eventlet - except (ImportError, AttributeError): - pytest.skip("eventlet not available") - - from gunicorn.workers.geventlet import _eventlet_socket_sendfile - - mock_socket = mock.MagicMock() - mock_socket.gettimeout.return_value = 0 - - with pytest.raises(ValueError, match="non-blocking"): - _eventlet_socket_sendfile(mock_socket, mock.MagicMock()) - - def test_sendfile_seeks_to_offset(self): - """Test that sendfile seeks to offset if provided.""" - try: - import eventlet - except (ImportError, AttributeError): - pytest.skip("eventlet not available") - - from gunicorn.workers.geventlet import _eventlet_socket_sendfile - - mock_socket = mock.MagicMock() - mock_socket.gettimeout.return_value = 1 - mock_file = mock.MagicMock() - mock_file.read.return_value = b'' - - _eventlet_socket_sendfile(mock_socket, mock_file, offset=100) - - mock_file.seek.assert_any_call(100) - - def 
test_sendfile_returns_total_sent(self): - """Test that sendfile returns the total bytes sent.""" - try: - import eventlet - except (ImportError, AttributeError): - pytest.skip("eventlet not available") - - from gunicorn.workers.geventlet import _eventlet_socket_sendfile - - mock_socket = mock.MagicMock() - mock_socket.gettimeout.return_value = 1 - mock_socket.send.return_value = 10 - - mock_file = mock.MagicMock() - mock_file.read.side_effect = [b'x' * 10, b''] - - result = _eventlet_socket_sendfile(mock_socket, mock_file) - - assert result == 10 - - -class TestEventletServe: - """Tests for _eventlet_serve() function.""" - - def test_serve_creates_green_pool(self): - """Test that _eventlet_serve creates a GreenPool.""" - try: - import eventlet - except (ImportError, AttributeError): - pytest.skip("eventlet not available") - - from gunicorn.workers.geventlet import _eventlet_serve - - mock_sock = mock.MagicMock() - mock_sock.accept.side_effect = eventlet.StopServe() - - with mock.patch.object(eventlet.greenpool, 'GreenPool') as mock_pool: - mock_pool_instance = mock.MagicMock() - mock_pool.return_value = mock_pool_instance - mock_pool_instance.waitall.return_value = None - - _eventlet_serve(mock_sock, mock.MagicMock(), 100) - - mock_pool.assert_called_once_with(100) - - -class TestEventletStop: - """Tests for _eventlet_stop() function.""" - - def test_stop_waits_for_client(self): - """Test that _eventlet_stop waits for the client greenlet.""" - try: - import eventlet - except (ImportError, AttributeError): - pytest.skip("eventlet not available") - - from gunicorn.workers.geventlet import _eventlet_stop - - mock_client = mock.MagicMock() - mock_server = mock.MagicMock() - mock_conn = mock.MagicMock() - - _eventlet_stop(mock_client, mock_server, mock_conn) - - mock_client.wait.assert_called_once() - mock_conn.close.assert_called_once() - - def test_stop_closes_connection_on_greenlet_exit(self): - """Test that connection is closed even on GreenletExit.""" - try: - 
import eventlet - import greenlet - except (ImportError, AttributeError): - pytest.skip("eventlet not available") - - from gunicorn.workers.geventlet import _eventlet_stop - - mock_client = mock.MagicMock() - mock_client.wait.side_effect = greenlet.GreenletExit() - mock_server = mock.MagicMock() - mock_conn = mock.MagicMock() - - # Should not raise - _eventlet_stop(mock_client, mock_server, mock_conn) - - mock_conn.close.assert_called_once()