Merge pull request #3549 from benoitc/feature/optional-http-parser

Optimize ASGI performance with fast parser integration
This commit is contained in:
Benoit Chesneau 2026-03-23 14:21:20 +01:00 committed by GitHub
commit 3667a10478
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
38 changed files with 3656 additions and 769 deletions

View File

@ -0,0 +1,269 @@
#!/usr/bin/env python
"""
Benchmark comparing HTTP parser implementations.
Compares:
- WSGI Python parser vs Fast parser (gunicorn_h1c)
- ASGI Python parser vs Fast parser (gunicorn_h1c)
Usage:
python benchmarks/http_parser_benchmark.py
"""
import io
import time
import statistics
from typing import NamedTuple
from gunicorn.config import Config
from gunicorn.http.message import Request, _check_fast_parser
from gunicorn.http.unreader import IterUnreader
# Check if fast parser is available
try:
import gunicorn_h1c
FAST_AVAILABLE = True
except ImportError:
FAST_AVAILABLE = False
print("WARNING: gunicorn_h1c not installed. Fast parser benchmarks will be skipped.")
print("Install with: pip install gunicorn_h1c\n")
class BenchmarkResult(NamedTuple):
name: str
iterations: int
total_time: float
avg_time_us: float
min_time_us: float
max_time_us: float
requests_per_sec: float
# Test requests of varying complexity
SIMPLE_REQUEST = b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n"
MEDIUM_REQUEST = b"""POST /api/users HTTP/1.1\r
Host: api.example.com\r
Content-Type: application/json\r
Content-Length: 42\r
Accept: application/json\r
Authorization: Bearer token123\r
X-Request-ID: abc-123-def-456\r
\r
"""
COMPLEX_REQUEST = b"""POST /api/v2/resources/items HTTP/1.1\r
Host: api.example.com\r
Content-Type: application/json; charset=utf-8\r
Content-Length: 1024\r
Accept: application/json, text/plain, */*\r
Accept-Language: en-US,en;q=0.9,fr;q=0.8\r
Accept-Encoding: gzip, deflate, br\r
Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ\r
X-Request-ID: 550e8400-e29b-41d4-a716-446655440000\r
X-Correlation-ID: 7f3d8c2a-1b4e-4a6f-9c8d-2e5f6a7b8c9d\r
X-Forwarded-For: 203.0.113.195, 70.41.3.18, 150.172.238.178\r
X-Forwarded-Proto: https\r
X-Real-IP: 203.0.113.195\r
User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36\r
Cache-Control: no-cache, no-store, must-revalidate\r
Pragma: no-cache\r
Cookie: session=abc123; preferences=dark_mode\r
If-None-Match: "etag-value-here"\r
If-Modified-Since: Wed, 21 Oct 2024 07:28:00 GMT\r
\r
"""
def create_wsgi_config(use_fast: bool) -> Config:
"""Create a config for WSGI parsing."""
cfg = Config()
cfg.set('http_parser', 'fast' if use_fast else 'python')
return cfg
def benchmark_wsgi_parser(request_data: bytes, cfg: Config, iterations: int) -> BenchmarkResult:
"""Benchmark WSGI parser."""
times = []
parser_type = cfg.http_parser
for _ in range(iterations):
# Create fresh unreader for each iteration
unreader = IterUnreader(iter([request_data]))
start = time.perf_counter()
req = Request(cfg, unreader, ('127.0.0.1', 8000), req_number=1)
end = time.perf_counter()
times.append(end - start)
# Verify parsing worked
assert req.method is not None
total_time = sum(times)
avg_time = statistics.mean(times)
min_time = min(times)
max_time = max(times)
return BenchmarkResult(
name=f"WSGI {parser_type}",
iterations=iterations,
total_time=total_time,
avg_time_us=avg_time * 1_000_000,
min_time_us=min_time * 1_000_000,
max_time_us=max_time * 1_000_000,
requests_per_sec=iterations / total_time,
)
def benchmark_asgi_parser(request_data: bytes, cfg: Config, iterations: int) -> BenchmarkResult:
"""Benchmark ASGI parser."""
from gunicorn.asgi.parser import HttpParser
times = []
parser_type = cfg.http_parser
for _ in range(iterations):
# Create fresh parser for each iteration
parser = HttpParser(cfg, ('127.0.0.1', 8000), is_ssl=False)
start = time.perf_counter()
result = parser.feed(bytearray(request_data))
end = time.perf_counter()
times.append(end - start)
# Verify parsing worked
assert result is not None
assert result.method is not None
total_time = sum(times)
avg_time = statistics.mean(times)
min_time = min(times)
max_time = max(times)
return BenchmarkResult(
name=f"ASGI {parser_type}",
iterations=iterations,
total_time=total_time,
avg_time_us=avg_time * 1_000_000,
min_time_us=min_time * 1_000_000,
max_time_us=max_time * 1_000_000,
requests_per_sec=iterations / total_time,
)
def print_result(result: BenchmarkResult, baseline: BenchmarkResult = None):
"""Print benchmark result."""
speedup = ""
if baseline and baseline.avg_time_us > 0:
ratio = baseline.avg_time_us / result.avg_time_us
if ratio > 1:
speedup = f" ({ratio:.2f}x faster)"
elif ratio < 1:
speedup = f" ({1/ratio:.2f}x slower)"
print(f" {result.name:20} {result.avg_time_us:8.2f} us/req "
f"({result.requests_per_sec:,.0f} req/s){speedup}")
def run_benchmark_suite(name: str, request_data: bytes, iterations: int):
"""Run a complete benchmark suite for a request type."""
print(f"\n{'='*60}")
print(f"Benchmark: {name}")
print(f"Request size: {len(request_data)} bytes, Iterations: {iterations:,}")
print('='*60)
results = []
# WSGI Python
cfg_python = create_wsgi_config(use_fast=False)
result_wsgi_python = benchmark_wsgi_parser(request_data, cfg_python, iterations)
results.append(result_wsgi_python)
# WSGI Fast (if available)
if FAST_AVAILABLE:
cfg_fast = create_wsgi_config(use_fast=True)
result_wsgi_fast = benchmark_wsgi_parser(request_data, cfg_fast, iterations)
results.append(result_wsgi_fast)
# ASGI Python
cfg_python = create_wsgi_config(use_fast=False)
result_asgi_python = benchmark_asgi_parser(request_data, cfg_python, iterations)
results.append(result_asgi_python)
# ASGI Fast (if available)
if FAST_AVAILABLE:
cfg_fast = create_wsgi_config(use_fast=True)
result_asgi_fast = benchmark_asgi_parser(request_data, cfg_fast, iterations)
results.append(result_asgi_fast)
# Print results
print("\nResults (avg time per request):")
print("-" * 60)
# Print WSGI results
print_result(result_wsgi_python)
if FAST_AVAILABLE:
print_result(result_wsgi_fast, result_wsgi_python)
print()
# Print ASGI results
print_result(result_asgi_python)
if FAST_AVAILABLE:
print_result(result_asgi_fast, result_asgi_python)
return results
def main():
print("HTTP Parser Benchmark")
print("=" * 60)
print(f"Fast parser (gunicorn_h1c): {'Available' if FAST_AVAILABLE else 'Not installed'}")
# Warmup
print("\nWarming up...")
cfg = create_wsgi_config(use_fast=False)
for _ in range(100):
unreader = IterUnreader(iter([SIMPLE_REQUEST]))
Request(cfg, unreader, ('127.0.0.1', 8000), req_number=1)
if FAST_AVAILABLE:
cfg = create_wsgi_config(use_fast=True)
for _ in range(100):
unreader = IterUnreader(iter([SIMPLE_REQUEST]))
Request(cfg, unreader, ('127.0.0.1', 8000), req_number=1)
# Run benchmarks
iterations = 10000
all_results = []
all_results.extend(run_benchmark_suite("Simple GET Request", SIMPLE_REQUEST, iterations))
all_results.extend(run_benchmark_suite("Medium POST Request", MEDIUM_REQUEST, iterations))
all_results.extend(run_benchmark_suite("Complex POST Request", COMPLEX_REQUEST, iterations))
# Summary
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
if FAST_AVAILABLE:
# Calculate overall speedups
wsgi_python_avg = statistics.mean([r.avg_time_us for r in all_results if r.name == "WSGI python"])
wsgi_fast_avg = statistics.mean([r.avg_time_us for r in all_results if r.name == "WSGI fast"])
asgi_python_avg = statistics.mean([r.avg_time_us for r in all_results if r.name == "ASGI python"])
asgi_fast_avg = statistics.mean([r.avg_time_us for r in all_results if r.name == "ASGI fast"])
print(f"\nWSGI: Fast parser is {wsgi_python_avg/wsgi_fast_avg:.2f}x faster than Python parser")
print(f"ASGI: Fast parser is {asgi_python_avg/asgi_fast_avg:.2f}x faster than Python parser")
else:
print("\nInstall gunicorn_h1c to see fast parser comparison:")
print(" pip install gunicorn_h1c")
print()
if __name__ == "__main__":
main()

View File

@ -4,12 +4,18 @@
# Simple WSGI app for benchmarking
import time
def application(environ, start_response):
"""Basic hello world response."""
path = environ.get('PATH_INFO', '/')
if path == '/large':
body = b'X' * 65536 # 64KB
elif path == '/slow':
time.sleep(0.01) # 10ms simulated I/O
body = b'Slow response'
else:
body = b'Hello, World!'

View File

@ -24,9 +24,11 @@ gunicorn main:app --worker-class asgi --bind 0.0.0.0:8000
The ASGI worker provides:
- **HTTP/1.1** with keepalive connections
- **HTTP/2** with multiplexing and server push (requires SSL)
- **WebSocket** support for real-time applications
- **Lifespan protocol** for startup/shutdown hooks
- **Optional uvloop** for improved performance
- **Optional fast HTTP parser** via C extension for high throughput
- **Optional uvloop** for improved event loop performance
- **SSL/TLS** support
- **uWSGI protocol** for nginx `uwsgi_pass` integration
@ -225,6 +227,56 @@ asgi_loop = "auto" # Use uvloop if available
asgi_lifespan = "auto" # Auto-detect lifespan support
```
## Performance
### Fast HTTP Parser
For maximum performance, install the optional `gunicorn_h1c` C extension:
```bash
pip install gunicorn[fast]
```
This provides a high-performance HTTP parser using picohttpparser with SIMD
optimizations, offering significant speedups for HTTP parsing compared to the
pure Python implementation.
The parser is automatically used when available (`--http-parser auto`), or you
can explicitly require it:
```bash
gunicorn myapp:app --worker-class asgi --http-parser fast
```
| Parser | Description |
|--------|-------------|
| `auto` | Use fast parser if available, otherwise Python (default) |
| `fast` | Require fast parser, fail if unavailable |
| `python` | Force pure Python parser |
### Performance Tips
1. **Use uvloop** for improved event loop performance:
```bash
pip install uvloop
gunicorn myapp:app --worker-class asgi --asgi-loop uvloop
```
2. **Install the fast parser** for optimized HTTP parsing:
```bash
pip install gunicorn[fast]
```
3. **Tune worker count** based on CPU cores:
```bash
gunicorn myapp:app --worker-class asgi --workers $(nproc)
```
4. **Increase connections** for I/O-bound applications:
```bash
gunicorn myapp:app --worker-class asgi --worker-connections 2000
```
## Comparison with Other ASGI Servers
| Feature | Gunicorn ASGI | Uvicorn | Hypercorn |

View File

@ -3,6 +3,14 @@
## unreleased
### New Features
- **Fast HTTP Parser (gunicorn_h1c 0.4.1)**: Integrate new exception types and limit
parameters from gunicorn_h1c 0.4.1 for both WSGI and ASGI workers
- Requires gunicorn_h1c >= 0.4.1 for `http_parser='fast'`
- Falls back to Python parser in `auto` mode if version not met
- Proper HTTP status codes for limit errors (414, 431)
### Performance
- **ASGI HTTP Parser Optimizations**: Improve ASGI worker HTTP parsing performance

View File

@ -1971,3 +1971,23 @@ need to increase this value.
This setting only affects the ``asgi`` worker type.
!!! info "Added in 25.0.0"
### `http_parser`
**Command line:** `--http-parser STRING`
**Default:** `'auto'`
HTTP parser implementation for ASGI workers.
- auto: Use H1CProtocol if gunicorn_h1c is available, else PythonProtocol (default)
- fast: Require H1CProtocol from gunicorn_h1c (fail if unavailable)
- python: Force pure Python PythonProtocol parser
ASGI workers use callback-based parsing in data_received() for efficient
incremental parsing. The gunicorn_h1c C extension provides significantly
faster HTTP parsing using picohttpparser with SIMD optimizations.
Install it with: pip install gunicorn[fast]
!!! info "Added in 25.0.0"

View File

@ -6,7 +6,9 @@ WORKDIR /app
RUN pip install --no-cache-dir \
sentence-transformers \
fastapi \
pydantic
pydantic \
pytest \
requests
# Copy gunicorn source
COPY . /app/gunicorn-src

543
gunicorn/asgi/parser.py Normal file
View File

@ -0,0 +1,543 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
HTTP parser for ASGI workers.
Provides callback-based parsing using either the fast C parser (gunicorn_h1c)
or the pure Python PythonProtocol fallback.
"""
class ParseError(Exception):
"""Base error raised during HTTP parsing."""
class LimitRequestLine(ParseError):
"""Request line exceeds configured limit."""
class LimitRequestHeaders(ParseError):
"""Too many headers or header field too large."""
class InvalidRequestMethod(ParseError):
"""Invalid HTTP method."""
class InvalidHTTPVersion(ParseError):
"""Invalid HTTP version."""
class InvalidHeaderName(ParseError):
"""Invalid header name."""
class InvalidHeader(ParseError):
"""Invalid header value."""
class PythonProtocol:
"""Callback-based HTTP/1.1 parser (pure Python fallback).
Mirrors H1CProtocol interface for seamless switching between
the C extension and pure Python implementations.
Callbacks:
on_message_begin: () -> None - Called when request starts
on_url: (url: bytes) -> None - Called with request URL/path
on_header: (name: bytes, value: bytes) -> None - Called for each header
on_headers_complete: () -> bool - Called when headers done (return True to skip body)
on_body: (chunk: bytes) -> None - Called with body data chunks
on_message_complete: () -> None - Called when request is complete
"""
__slots__ = (
'_on_message_begin', '_on_url', '_on_header',
'_on_headers_complete', '_on_body', '_on_message_complete',
'_state', '_buffer', '_headers_list',
'method', 'path', 'http_version', 'headers',
'content_length', 'is_chunked', 'should_keep_alive', 'is_complete',
'_body_remaining', '_skip_body',
'_chunk_state', '_chunk_size', '_chunk_remaining',
'_limit_request_line', '_limit_request_fields', '_limit_request_field_size',
'_permit_unconventional_http_method', '_permit_unconventional_http_version',
'_header_count',
)
def __init__(
self,
on_message_begin=None,
on_url=None,
on_header=None,
on_headers_complete=None,
on_body=None,
on_message_complete=None,
limit_request_line=8190,
limit_request_fields=100,
limit_request_field_size=8190,
permit_unconventional_http_method=False,
permit_unconventional_http_version=False,
):
self._on_message_begin = on_message_begin
self._on_url = on_url
self._on_header = on_header
self._on_headers_complete = on_headers_complete
self._on_body = on_body
self._on_message_complete = on_message_complete
# Store limits
self._limit_request_line = limit_request_line
self._limit_request_fields = limit_request_fields
self._limit_request_field_size = limit_request_field_size
self._permit_unconventional_http_method = permit_unconventional_http_method
self._permit_unconventional_http_version = permit_unconventional_http_version
self._header_count = 0
# Parser state: request_line, headers, body, chunked_size, chunked_data, complete
self._state = 'request_line'
self._buffer = bytearray()
self._headers_list = []
# Request info (populated during parsing)
self.method = None
self.path = None
self.http_version = None
self.headers = []
self.content_length = None
self.is_chunked = False
self.should_keep_alive = True
self.is_complete = False
# Body state
self._body_remaining = 0
self._skip_body = False
# Chunked transfer state
self._chunk_state = 'size' # size, data, trailer
self._chunk_size = 0
self._chunk_remaining = 0
def feed(self, data):
"""Process data, fire callbacks synchronously.
Args:
data: bytes or bytearray of incoming data
Raises:
ParseError: If the HTTP request is malformed
"""
self._buffer.extend(data)
while self._buffer:
if self._state == 'request_line':
if not self._parse_request_line():
break
elif self._state == 'headers':
if not self._parse_headers():
break
elif self._state == 'body':
if not self._parse_body():
break
elif self._state == 'chunked':
if not self._parse_chunked_body():
break
else:
break
def reset(self):
"""Reset for next request (keepalive)."""
self._state = 'request_line'
self._buffer.clear()
self._headers_list = []
self.method = None
self.path = None
self.http_version = None
self.headers = []
self.content_length = None
self.is_chunked = False
self.should_keep_alive = True
self.is_complete = False
self._body_remaining = 0
self._skip_body = False
self._chunk_state = 'size'
self._chunk_size = 0
self._chunk_remaining = 0
self._header_count = 0
def _parse_request_line(self):
"""Parse request line, return True if complete."""
idx = self._buffer.find(b'\r\n')
if idx == -1:
return False
# Check request line length limit
if self._limit_request_line > 0 and idx > self._limit_request_line:
raise LimitRequestLine("Request line is too large")
line = bytes(self._buffer[:idx])
del self._buffer[:idx + 2]
# Parse: METHOD PATH HTTP/x.y
parts = line.split(b' ', 2)
if len(parts) != 3:
raise ParseError("Invalid request line")
self.method = parts[0]
self.path = parts[1]
# Validate method
if not self._permit_unconventional_http_method:
if not self._is_valid_method(self.method):
raise InvalidRequestMethod(self.method.decode('latin-1'))
# Parse version
version = parts[2]
if version == b'HTTP/1.1':
self.http_version = (1, 1)
elif version == b'HTTP/1.0':
self.http_version = (1, 0)
else:
if not self._permit_unconventional_http_version:
raise InvalidHTTPVersion(version.decode('latin-1'))
# Try to parse other HTTP/1.x versions if permitted
if version.startswith(b'HTTP/1.'):
try:
minor = int(version[7:])
self.http_version = (1, minor)
except ValueError:
raise InvalidHTTPVersion(version.decode('latin-1'))
else:
raise InvalidHTTPVersion(version.decode('latin-1'))
if self._on_message_begin:
self._on_message_begin()
if self._on_url:
self._on_url(self.path)
self._state = 'headers'
return True
def _parse_headers(self):
"""Parse headers, return True if headers are complete."""
while True:
idx = self._buffer.find(b'\r\n')
if idx == -1:
return False
line = bytes(self._buffer[:idx])
del self._buffer[:idx + 2]
if not line:
# Empty line = end of headers
self._finalize_headers()
return True
# Check header field size limit
if self._limit_request_field_size > 0 and len(line) > self._limit_request_field_size:
raise LimitRequestHeaders("Request header field is too large")
# Check header count limit
self._header_count += 1
if self._limit_request_fields > 0 and self._header_count > self._limit_request_fields:
raise LimitRequestHeaders("Too many headers")
# Parse header
colon = line.find(b':')
if colon == -1:
raise InvalidHeader("Missing colon in header")
name = line[:colon].strip()
if not self._is_valid_token(name):
raise InvalidHeaderName(name.decode('latin-1'))
value = line[colon + 1:].strip()
if self._has_invalid_header_chars(value):
raise InvalidHeader("Invalid characters in header value")
# Store lowercase name for internal use
name_lower = name.lower()
self._headers_list.append((name_lower, value))
if self._on_header:
self._on_header(name_lower, value)
def _finalize_headers(self):
"""Called when all headers received."""
self.headers = self._headers_list
# Extract content-length and chunked
for name, value in self.headers:
if name == b'content-length':
self.content_length = int(value)
self._body_remaining = self.content_length
elif name == b'transfer-encoding':
self.is_chunked = b'chunked' in value.lower()
elif name == b'connection':
val = value.lower()
if b'close' in val:
self.should_keep_alive = False
elif b'keep-alive' in val:
self.should_keep_alive = True
# HTTP/1.0 defaults to close
if self.http_version == (1, 0) and self.should_keep_alive:
# Only keep-alive if explicitly requested
has_keepalive = any(
name == b'connection' and b'keep-alive' in value.lower()
for name, value in self.headers
)
if not has_keepalive:
self.should_keep_alive = False
if self._on_headers_complete:
self._skip_body = self._on_headers_complete()
# Determine next state
if self._skip_body:
self._state = 'complete'
self.is_complete = True
if self._on_message_complete:
self._on_message_complete()
elif self.is_chunked:
self._state = 'chunked'
self._chunk_state = 'size'
elif self.content_length and self.content_length > 0:
self._state = 'body'
else:
# No body
self._state = 'complete'
self.is_complete = True
if self._on_message_complete:
self._on_message_complete()
def _parse_body(self):
"""Parse Content-Length delimited body."""
if not self._buffer or self._body_remaining <= 0:
return False
chunk_size = min(len(self._buffer), self._body_remaining)
chunk = bytes(self._buffer[:chunk_size])
del self._buffer[:chunk_size]
self._body_remaining -= chunk_size
if self._on_body:
self._on_body(chunk)
if self._body_remaining <= 0:
self._state = 'complete'
self.is_complete = True
if self._on_message_complete:
self._on_message_complete()
return True
def _parse_chunked_body(self):
"""Parse chunked transfer encoding."""
while self._buffer:
if self._chunk_state == 'size':
# Looking for chunk size line
idx = self._buffer.find(b'\r\n')
if idx == -1:
return False
size_line = bytes(self._buffer[:idx])
del self._buffer[:idx + 2]
# Handle chunk extensions (e.g., "5;ext=value")
semicolon = size_line.find(b';')
if semicolon != -1:
size_line = size_line[:semicolon].strip()
try:
self._chunk_size = int(size_line, 16)
except ValueError:
raise ParseError("Invalid chunk size")
if self._chunk_size == 0:
# Final chunk - skip trailers
self._chunk_state = 'trailer'
else:
self._chunk_remaining = self._chunk_size
self._chunk_state = 'data'
elif self._chunk_state == 'data':
# Reading chunk data
if not self._buffer:
return False
to_read = min(len(self._buffer), self._chunk_remaining)
chunk = bytes(self._buffer[:to_read])
del self._buffer[:to_read]
self._chunk_remaining -= to_read
if self._on_body:
self._on_body(chunk)
if self._chunk_remaining == 0:
# Need to consume trailing CRLF
self._chunk_state = 'crlf'
elif self._chunk_state == 'crlf':
# Skip CRLF after chunk data
if len(self._buffer) < 2:
return False
del self._buffer[:2] # Skip \r\n
self._chunk_state = 'size'
elif self._chunk_state == 'trailer':
# Skip trailer headers
idx = self._buffer.find(b'\r\n')
if idx == -1:
return False
line = bytes(self._buffer[:idx])
del self._buffer[:idx + 2]
if not line:
# Empty line = end of trailers
self._state = 'complete'
self.is_complete = True
if self._on_message_complete:
self._on_message_complete()
return True
return False
def _is_valid_method(self, method):
"""Check if method is valid token with conventional restrictions."""
if not method:
return False
# Check length (3-20 chars)
if not 3 <= len(method) <= 20:
return False
# Check for lowercase or # (unconventional)
for c in method:
if c in b'abcdefghijklmnopqrstuvwxyz#':
return False
return self._is_valid_token(method)
def _is_valid_token(self, data):
"""Check if data contains only RFC 9110 token characters."""
if not data:
return False
for c in data:
if c < 0x21 or c > 0x7e:
return False
# RFC 9110 delimiters: "(),/:;<=>?@[\]{}
if c in b'"(),/:;<=>?@[\\]{}"':
return False
return True
def _has_invalid_header_chars(self, value):
"""Check for NUL, CR, LF in header value."""
return b'\x00' in value or b'\r' in value or b'\n' in value
class CallbackRequest:
"""Request object built from callback parser state.
Works with both H1CProtocol (C extension) and PythonProtocol.
"""
__slots__ = (
'method', 'uri', 'path', 'query', 'fragment', 'version',
'headers', 'headers_bytes', 'scheme', 'raw_path',
'content_length', 'chunked', 'must_close',
'proxy_protocol_info', '_expect_100_continue',
)
def __init__(self):
self.method = None
self.uri = None
self.path = None
self.query = None
self.fragment = None
self.version = None
self.headers = []
self.headers_bytes = []
self.scheme = "http"
self.raw_path = b''
self.content_length = 0
self.chunked = False
self.must_close = False
self.proxy_protocol_info = None
self._expect_100_continue = False
@classmethod
def from_parser(cls, parser, is_ssl=False):
"""Build request from callback parser state.
Args:
parser: H1CProtocol or PythonProtocol instance
is_ssl: Whether connection is SSL/TLS
Returns:
CallbackRequest instance
"""
from urllib.parse import unquote_to_bytes
req = cls()
req.method = parser.method.decode('ascii')
# Parse path and query from URL
# Per ASGI spec:
# - path: percent-decoded UTF-8 string
# - raw_path: original bytes as received
raw_url = parser.path
if b'?' in raw_url:
path_part, query_part = raw_url.split(b'?', 1)
req.raw_path = path_part # Store original bytes
req.path = unquote_to_bytes(path_part).decode('utf-8', errors='replace')
req.query = query_part.decode('latin-1')
else:
req.raw_path = raw_url # Store original bytes
req.path = unquote_to_bytes(raw_url).decode('utf-8', errors='replace')
req.query = ''
req.uri = raw_url.decode('latin-1')
req.fragment = ''
req.version = parser.http_version
# Headers - store both bytes (for ASGI scope) and strings (for compatibility)
req.headers_bytes = list(parser.headers)
req.headers = [
(n.decode('latin-1').upper(), v.decode('latin-1'))
for n, v in parser.headers
]
req.scheme = 'https' if is_ssl else 'http'
req.content_length = parser.content_length or 0
req.chunked = parser.is_chunked
req.must_close = not parser.should_keep_alive
# Check for Expect: 100-continue
for name, value in parser.headers:
if name == b'expect' and value.lower() == b'100-continue':
req._expect_100_continue = True
break
return req
def should_close(self):
"""Check if connection should be closed after this request."""
if self.must_close:
return True
for name, value in self.headers:
if name == "CONNECTION":
v = value.lower().strip(" \t")
if v == "close":
return True
elif v == "keep-alive":
return False
break
return self.version <= (1, 0)
def get_header(self, name):
"""Get a header value by name (case-insensitive)."""
name = name.upper()
for h, v in self.headers:
if h == name:
return v
return None

File diff suppressed because it is too large Load Diff

View File

@ -40,20 +40,22 @@ WS_GUID = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
class WebSocketProtocol:
"""WebSocket connection handler for ASGI applications."""
"""WebSocket connection handler for ASGI applications.
def __init__(self, transport, reader, scope, app, log):
Uses callback-based data feeding instead of StreamReader for efficiency.
Data is fed via feed_data() from the parent protocol's data_received().
"""
def __init__(self, transport, scope, app, log):
"""Initialize WebSocket protocol handler.
Args:
transport: asyncio transport for writing
reader: asyncio StreamReader for reading
scope: ASGI WebSocket scope dict
app: ASGI application callable
log: Logger instance
"""
self.transport = transport
self.reader = reader
self.scope = scope
self.app = app
self.log = log
@ -70,6 +72,26 @@ class WebSocketProtocol:
# Receive queue for incoming messages
self._receive_queue = asyncio.Queue()
# Callback-based data reception (replaces StreamReader)
self._buffer = bytearray()
self._data_event = asyncio.Event()
self._eof = False
def feed_data(self, data):
"""Feed incoming data from the parent protocol's data_received().
Args:
data: bytes received on the connection
"""
if data:
self._buffer.extend(data)
self._data_event.set()
def feed_eof(self):
"""Signal that the connection has been closed."""
self._eof = True
self._data_event.set()
async def run(self):
"""Run the WebSocket ASGI application."""
# Send initial connect event
@ -295,14 +317,25 @@ class WebSocketProtocol:
return (opcode, payload)
async def _read_exact(self, n):
"""Read exactly n bytes from the reader."""
try:
data = await self.reader.readexactly(n)
return data
except asyncio.IncompleteReadError:
return None
except Exception:
return None
"""Read exactly n bytes from internal buffer.
Waits for data via the callback-fed buffer instead of StreamReader.
"""
while len(self._buffer) < n:
if self._eof:
return None
self._data_event.clear()
# Critical: check buffer AGAIN after clearing to avoid race
# condition where data arrives between clear() and wait()
if len(self._buffer) >= n:
break
await self._data_event.wait()
if self._eof and len(self._buffer) < n:
return None
data = bytes(self._buffer[:n])
del self._buffer[:n]
return data
def _unmask(self, payload, masking_key):
"""Unmask WebSocket payload data."""

View File

@ -2772,6 +2772,22 @@ def validate_asgi_lifespan(val):
return val
def validate_http_parser(val):
"""Validate http_parser setting.
Accepts: auto, fast, python
"""
if val is None:
return "auto"
if not isinstance(val, str):
raise TypeError("http_parser must be a string")
val = val.lower().strip()
valid_values = ("auto", "fast", "python")
if val not in valid_values:
raise ValueError("http_parser must be one of: %s" % ", ".join(valid_values))
return val
class ASGILoop(Setting):
name = "asgi_loop"
section = "Worker Processes"
@ -2845,6 +2861,30 @@ class ASGIDisconnectGracePeriod(Setting):
"""
class HttpParser(Setting):
name = "http_parser"
section = "Worker Processes"
cli = ["--http-parser"]
meta = "STRING"
validator = validate_http_parser
default = "auto"
desc = """\
HTTP parser implementation for ASGI workers.
- auto: Use H1CProtocol if gunicorn_h1c is available, else PythonProtocol (default)
- fast: Require H1CProtocol from gunicorn_h1c (fail if unavailable)
- python: Force pure Python PythonProtocol parser
ASGI workers use callback-based parsing in data_received() for efficient
incremental parsing. The gunicorn_h1c C extension provides significantly
faster HTTP parsing using picohttpparser with SIMD optimizations.
Install it with: pip install gunicorn[fast]
.. versionadded:: 25.0.0
"""
class RootPath(Setting):
name = "root_path"
section = "Server Mechanics"

View File

@ -341,14 +341,24 @@ class Logger:
return atoms
@property
def access_log_enabled(self):
"""Check if access logging is enabled.
Used by protocol handlers to skip building log data when logging is disabled.
"""
return bool(
self.cfg.accesslog or self.cfg.logconfig or
self.cfg.logconfig_dict or self.cfg.logconfig_json or
(self.cfg.syslog and not self.cfg.disable_redirect_access_to_syslog)
)
def access(self, resp, req, environ, request_time):
""" See http://httpd.apache.org/docs/2.0/logs.html#combined
for format details
"""
if not (self.cfg.accesslog or self.cfg.logconfig or
self.cfg.logconfig_dict or self.cfg.logconfig_json or
(self.cfg.syslog and not self.cfg.disable_redirect_access_to_syslog)):
if not self.access_log_enabled:
return
# wrap atoms:

View File

@ -114,11 +114,13 @@ class ChunkMissingTerminator(IOError):
class LimitRequestLine(ParseException):
def __init__(self, size, max_size):
def __init__(self, size, max_size=None):
self.size = size
self.max_size = max_size
def __str__(self):
if self.max_size is None:
return str(self.size)
return "Request Line is too large (%s > %s)" % (self.size, self.max_size)

View File

@ -21,6 +21,80 @@ from gunicorn.http.errors import InvalidSchemeHeaders
from gunicorn.util import bytes_to_str, split_request_uri
# Fast parser availability (cached at module level)
_fast_parser_available = None
_fast_parser_module = None
# Compatibility flags not supported by the fast parser
_FAST_PARSER_INCOMPATIBLE_FLAGS = (
'permit_obsolete_folding',
'strip_header_spaces',
)
def _check_fast_parser(cfg):
"""Check if fast C parser is available and should be used.
Returns False if:
- http_parser='python' is explicitly set
- gunicorn_h1c is not installed (in 'auto' mode)
- gunicorn_h1c < 0.4.1 (in 'auto' mode)
- Incompatible compatibility flags are enabled (in 'auto' mode)
Raises RuntimeError if:
- http_parser='fast' but gunicorn_h1c is not installed
- http_parser='fast' but gunicorn_h1c < 0.4.1
- http_parser='fast' but incompatible flags are enabled
"""
global _fast_parser_available, _fast_parser_module # pylint: disable=global-statement
parser_setting = getattr(cfg, 'http_parser', 'auto')
if parser_setting == 'python':
return False
if _fast_parser_available is None:
try:
import gunicorn_h1c
_fast_parser_available = True
_fast_parser_module = gunicorn_h1c
except ImportError:
_fast_parser_available = False
if not _fast_parser_available and parser_setting == 'fast':
raise RuntimeError("gunicorn_h1c not installed but http_parser='fast'")
if not _fast_parser_available:
return False
# Require >= 0.4.1 for limit enforcement
if not hasattr(_fast_parser_module, 'LimitRequestLine'):
if parser_setting == 'fast':
raise RuntimeError(
"gunicorn_h1c >= 0.4.1 required for http_parser='fast'. "
"Please upgrade: pip install --upgrade gunicorn_h1c"
)
# In 'auto' mode, fall back to Python parser
return False
# Check for incompatible compatibility flags
incompatible = []
for flag in _FAST_PARSER_INCOMPATIBLE_FLAGS:
if getattr(cfg, flag, False):
incompatible.append(flag)
if incompatible:
if parser_setting == 'fast':
raise RuntimeError(
"http_parser='fast' is incompatible with compatibility flags: %s. "
"Use http_parser='python' or disable these flags."
% ', '.join(incompatible)
)
# In 'auto' mode, fall back to Python parser
return False
return True
# PROXY protocol v2 constants
PP_V2_SIGNATURE = b"\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"
@ -99,7 +173,7 @@ class Message:
or self.limit_request_fields > MAX_HEADERS):
self.limit_request_fields = MAX_HEADERS
self.limit_request_field_size = cfg.limit_request_field_size
if self.limit_request_field_size < 0:
if self.limit_request_field_size <= 0:
self.limit_request_field_size = DEFAULT_MAX_HEADERFIELD_SIZE
# set max header buffer size
@ -315,12 +389,16 @@ class Request(Message):
# get max request line size
self.limit_request_line = cfg.limit_request_line
if (self.limit_request_line < 0
if (self.limit_request_line <= 0
or self.limit_request_line >= MAX_REQUEST_LINE):
self.limit_request_line = MAX_REQUEST_LINE
self.req_number = req_number
self.proxy_protocol_info = None
# Check if fast parser should be used
self._use_fast = _check_fast_parser(cfg)
super().__init__(cfg, unreader, peer_addr)
def get_data(self, unreader, buf, stop=False):
@ -340,6 +418,102 @@ class Request(Message):
if mode != "off" and self.req_number == 1:
buf = self._handle_proxy_protocol(unreader, buf, mode)
# Use fast parser if available
if self._use_fast:
return self._parse_fast(unreader, buf)
return self._parse_python(unreader, buf)
def _parse_fast(self, unreader, buf):
"""Parse request using fast C parser (gunicorn_h1c >= 0.4.1)."""
# Read until we have complete headers
data = bytes(buf)
last_len = 0
while True:
try:
# Pass all limit parameters to C parser
result = _fast_parser_module.parse_request(
data,
last_len=last_len,
limit_request_line=self.limit_request_line,
limit_request_fields=self.limit_request_fields,
limit_request_field_size=self.limit_request_field_size,
permit_unconventional_http_method=self.cfg.permit_unconventional_http_method,
permit_unconventional_http_version=self.cfg.permit_unconventional_http_version,
)
break
except _fast_parser_module.IncompleteError:
last_len = len(data)
self.read_into(unreader, buf)
data = bytes(buf)
if len(data) > self.max_buffer_headers + self.limit_request_line:
raise LimitRequestHeaders("max buffer headers")
except _fast_parser_module.LimitRequestLine as e:
raise LimitRequestLine(str(e))
except _fast_parser_module.LimitRequestHeaders as e:
raise LimitRequestHeaders(str(e))
except _fast_parser_module.InvalidRequestMethod as e:
raise InvalidRequestMethod(str(e))
except _fast_parser_module.InvalidHTTPVersion as e:
raise InvalidHTTPVersion(str(e))
except _fast_parser_module.InvalidHeaderName as e:
raise InvalidHeaderName(str(e))
except _fast_parser_module.InvalidHeader as e:
raise InvalidHeader(str(e))
except _fast_parser_module.ParseError as e:
raise InvalidRequestLine(str(e))
# Extract parsed data
self.method = bytes_to_str(result['method'])
self.uri = bytes_to_str(result['path'])
# Casefold method if configured (validation done by C parser)
if self.cfg.casefold_http_method:
self.method = self.method.upper()
# Parse URI parts
if len(self.uri) == 0:
raise InvalidRequestLine(self.uri)
try:
parts = split_request_uri(self.uri)
except ValueError:
raise InvalidRequestLine(self.uri)
self.path = parts.path or ""
self.query = parts.query or ""
self.fragment = parts.fragment or ""
# Version (validation done by C parser)
self.version = (1, result['minor_version'])
# Headers - convert bytes to strings with uppercase names
# gunicorn_h1c returns headers as (bytes, bytes) tuples
# Header name/value validation done by C parser
self.headers = []
for name_bytes, value_bytes in result['headers']:
name = bytes_to_str(name_bytes).upper()
value = bytes_to_str(value_bytes)
# Handle underscore in header names (policy decision, not validation)
if "_" in name:
forwarder_headers = self.cfg.forwarder_headers
if name in forwarder_headers or "*" in forwarder_headers:
pass
elif self.cfg.header_map == "dangerous":
pass
elif self.cfg.header_map == "drop":
continue
else:
raise InvalidHeaderName(name)
self.headers.append((name, value))
# Return remaining data after headers
consumed = result['consumed']
return data[consumed:]
def _parse_python(self, unreader, buf):
"""Parse request using pure Python parser."""
# Get request line
line, buf = self.read_line(unreader, buf, self.limit_request_line)

View File

@ -71,6 +71,11 @@ class HTTP2Stream:
self.priority_depends_on = 0
self.priority_exclusive = False
# Streaming body support (avoids buffering entire uploads)
self._body_chunks = []
self._body_event = None # Lazy-init asyncio.Event
self._body_complete = False
@property
def is_client_stream(self):
"""Check if this is a client-initiated stream (odd stream ID)."""
@ -122,7 +127,7 @@ class HTTP2Stream:
self.request_complete = True
def receive_data(self, data, end_stream=False):
"""Process received DATA frame.
"""Process received DATA frame with streaming support.
Args:
data: Bytes received
@ -137,11 +142,21 @@ class HTTP2Stream:
f"Cannot receive data in state {self.state.name}"
)
# Add to chunks queue for streaming reads
if data:
self._body_chunks.append(data)
if self._body_event:
self._body_event.set()
# Also write to legacy BytesIO for compatibility
self.request_body.write(data)
if end_stream:
self._half_close_remote()
self.request_complete = True
self._body_complete = True
if self._body_event:
self._body_event.set()
def receive_trailers(self, trailers):
"""Process received trailing headers.
@ -283,6 +298,35 @@ class HTTP2Stream:
"""
return self.request_body.getvalue()
async def read_body_chunk(self):
"""Read next body chunk asynchronously for streaming.
Returns:
bytes: Next chunk of body data, or None if body is complete.
"""
import asyncio
# Initialize event lazily (avoids event loop issues at construction)
if self._body_event is None:
self._body_event = asyncio.Event()
# If data already arrived before event existed, set it now
# This prevents race where DATA frames arrive before first read
if self._body_chunks or self._body_complete:
self._body_event.set()
while True:
# Return chunk if available
if self._body_chunks:
return self._body_chunks.pop(0)
# No more data expected
if self._body_complete:
return None
# Wait for more data
self._body_event.clear()
await self._body_event.wait()
def get_pseudo_headers(self):
"""Extract HTTP/2 pseudo-headers from request headers.

View File

@ -53,6 +53,7 @@ tornado = ["tornado>=6.5.0"]
gthread = []
setproctitle = ["setproctitle"]
http2 = ["h2>=4.1.0"]
fast = ["gunicorn_h1c>=0.4.1"]
testing = [
"gevent>=24.10.1",
"eventlet>=0.40.3",

View File

@ -1,6 +1,6 @@
gevent
eventlet
coverage
pytest>=7.2.0
pytest-cov
pytest-asyncio
gunicorn_h1c>=0.4.1

View File

@ -7,8 +7,21 @@
import os
import sys
import pytest
# Add the tests directory to sys.path so test support modules can be imported
# as 'tests.module_name' (e.g., 'tests.support_dirty_apps:CounterApp')
tests_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if tests_dir not in sys.path:
sys.path.insert(0, tests_dir)
@pytest.fixture(params=["python", "fast"])
def http_parser(request):
"""Parametrize tests over http_parser implementations."""
if request.param == "fast":
gunicorn_h1c = pytest.importorskip("gunicorn_h1c", reason="gunicorn_h1c required")
# Require >= 0.4.1 for limit enforcement
if not hasattr(gunicorn_h1c, 'LimitRequestLine'):
pytest.skip("gunicorn_h1c >= 0.4.1 required")
return request.param

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,11 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
from gunicorn.config import Config
from gunicorn.http.errors import LimitRequestHeaders
cfg = Config()
# Setting limit_request_field_size=0 should use default max (8190)
cfg.set('limit_request_field_size', 0)
request = LimitRequestHeaders

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,11 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
from gunicorn.config import Config
from gunicorn.http.errors import LimitRequestLine
cfg = Config()
# Setting limit_request_line=0 should use default max (8190)
cfg.set('limit_request_line', 0)
request = LimitRequestLine

View File

@ -5,8 +5,9 @@
from gunicorn.config import Config
cfg = Config()
cfg.set('limit_request_line', 0)
cfg.set('limit_request_field_size', 0)
# Request line is 8194 bytes, header line is 8209 bytes (both include CRLF)
cfg.set('limit_request_line', 8200)
cfg.set('limit_request_field_size', 8210)
request = {
"method": "PUT",
"uri":

View File

@ -5,7 +5,7 @@
from gunicorn.config import Config
cfg = Config()
cfg.set('limit_request_line', 0)
# Header line is 8209 bytes (name + ": " + value + CRLF)
cfg.set('limit_request_field_size', 8210)
request = {
"method": "GET",

View File

@ -7,10 +7,8 @@ Tests for ASGI worker components.
"""
import asyncio
import io
import ipaddress
import pytest
from unittest import mock
from gunicorn.asgi.unreader import AsyncUnreader
from gunicorn.asgi.message import AsyncRequest

View File

@ -0,0 +1,518 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""Tests for ASGI callback parsers.
Tests both PythonProtocol and H1CProtocol (if available) to ensure
consistent behavior across implementations.
"""
from gunicorn.asgi.parser import PythonProtocol
def get_parser_class(http_parser):
"""Get the appropriate parser class for the test parameter."""
if http_parser == "fast":
from gunicorn_h1c import H1CProtocol
return H1CProtocol
return PythonProtocol
def normalize_headers(headers):
"""Normalize headers to lowercase names for comparison.
H1CProtocol preserves original case, PythonProtocol lowercases.
"""
return {name.lower(): value for name, value in headers}
class TestRequestLineParsing:
"""Test request line parsing for both implementations."""
def test_simple_get(self, http_parser):
"""Parse a simple GET request."""
parser_class = get_parser_class(http_parser)
events = []
parser = parser_class(
on_message_begin=lambda: events.append('begin'),
on_url=lambda url: events.append(('url', url)),
on_headers_complete=lambda: events.append('headers_complete'),
on_message_complete=lambda: events.append('complete'),
)
parser.feed(b"GET /path HTTP/1.1\r\n\r\n")
assert parser.method == b"GET"
assert parser.path == b"/path"
assert parser.http_version == (1, 1)
assert parser.is_complete
assert 'begin' in events
assert ('url', b'/path') in events
assert 'complete' in events
def test_post_with_query(self, http_parser):
"""Parse a POST request with query string."""
parser_class = get_parser_class(http_parser)
parser = parser_class()
parser.feed(b"POST /api/data?foo=bar&baz=qux HTTP/1.1\r\n\r\n")
assert parser.method == b"POST"
assert parser.path == b"/api/data?foo=bar&baz=qux"
assert parser.http_version == (1, 1)
assert parser.is_complete
def test_http_10_version(self, http_parser):
"""Parse HTTP/1.0 request."""
parser_class = get_parser_class(http_parser)
parser = parser_class()
parser.feed(b"GET / HTTP/1.0\r\n\r\n")
assert parser.method == b"GET"
assert parser.http_version == (1, 0)
assert parser.is_complete
def test_various_methods(self, http_parser):
"""Test parsing various HTTP methods."""
parser_class = get_parser_class(http_parser)
methods = [b"GET", b"POST", b"PUT", b"DELETE", b"PATCH", b"HEAD", b"OPTIONS"]
for method in methods:
parser = parser_class()
parser.feed(method + b" / HTTP/1.1\r\n\r\n")
assert parser.method == method
class TestHeaderParsing:
"""Test header parsing for both implementations."""
def test_single_header(self, http_parser):
"""Parse a request with single header."""
parser_class = get_parser_class(http_parser)
headers = []
parser = parser_class(
on_header=lambda n, v: headers.append((n, v)),
)
parser.feed(b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n")
assert len(parser.headers) == 1
header_dict = normalize_headers(parser.headers)
assert header_dict[b"host"] == b"localhost"
callback_dict = normalize_headers(headers)
assert callback_dict[b"host"] == b"localhost"
def test_multiple_headers(self, http_parser):
"""Parse a request with multiple headers."""
parser_class = get_parser_class(http_parser)
parser = parser_class()
parser.feed(
b"GET / HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"User-Agent: TestClient\r\n"
b"Accept: */*\r\n"
b"\r\n"
)
assert len(parser.headers) == 3
header_dict = normalize_headers(parser.headers)
assert header_dict[b"host"] == b"localhost"
assert header_dict[b"user-agent"] == b"TestClient"
assert header_dict[b"accept"] == b"*/*"
def test_header_with_spaces(self, http_parser):
"""Parse headers with leading/trailing spaces in values."""
parser_class = get_parser_class(http_parser)
parser = parser_class()
parser.feed(
b"GET / HTTP/1.1\r\n"
b"Host: localhost \r\n"
b"\r\n"
)
header_dict = normalize_headers(parser.headers)
assert header_dict[b"host"] == b"localhost"
def test_empty_header_value(self, http_parser):
"""Parse header with empty value."""
parser_class = get_parser_class(http_parser)
parser = parser_class()
parser.feed(
b"GET / HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"X-Empty:\r\n"
b"\r\n"
)
header_dict = normalize_headers(parser.headers)
assert header_dict[b"x-empty"] == b""
def test_large_header_value(self, http_parser):
"""Parse header with large value."""
parser_class = get_parser_class(http_parser)
large_value = b"x" * 4096
parser = parser_class()
parser.feed(
b"GET / HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"X-Large: " + large_value + b"\r\n"
b"\r\n"
)
header_dict = normalize_headers(parser.headers)
assert header_dict[b"x-large"] == large_value
class TestBodyHandling:
"""Test body parsing for both implementations."""
def test_content_length_body(self, http_parser):
"""Parse request with Content-Length body."""
parser_class = get_parser_class(http_parser)
body_chunks = []
parser = parser_class(
on_body=lambda chunk: body_chunks.append(chunk),
)
parser.feed(
b"POST /data HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Content-Length: 13\r\n"
b"\r\n"
b"Hello, World!"
)
assert parser.content_length == 13
assert not parser.is_chunked
assert b"".join(body_chunks) == b"Hello, World!"
assert parser.is_complete
def test_content_length_incremental(self, http_parser):
"""Parse body arriving in multiple chunks."""
parser_class = get_parser_class(http_parser)
body_chunks = []
parser = parser_class(
on_body=lambda chunk: body_chunks.append(chunk),
)
# Send headers
parser.feed(
b"POST /data HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Content-Length: 10\r\n"
b"\r\n"
)
assert not parser.is_complete
# Send body in parts
parser.feed(b"Hello")
assert not parser.is_complete
parser.feed(b"World")
assert parser.is_complete
assert b"".join(body_chunks) == b"HelloWorld"
def test_chunked_encoding(self, http_parser):
"""Parse chunked transfer-encoded body."""
parser_class = get_parser_class(http_parser)
body_chunks = []
parser = parser_class(
on_body=lambda chunk: body_chunks.append(chunk),
)
parser.feed(
b"POST /data HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Transfer-Encoding: chunked\r\n"
b"\r\n"
b"5\r\n"
b"Hello\r\n"
b"6\r\n"
b"World!\r\n"
b"0\r\n"
b"\r\n"
)
assert parser.is_chunked
assert b"".join(body_chunks) == b"HelloWorld!"
assert parser.is_complete
def test_chunked_with_extensions(self, http_parser):
"""Parse chunked body with chunk extensions."""
parser_class = get_parser_class(http_parser)
body_chunks = []
parser = parser_class(
on_body=lambda chunk: body_chunks.append(chunk),
)
parser.feed(
b"POST /data HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Transfer-Encoding: chunked\r\n"
b"\r\n"
b"5;ext=value\r\n"
b"Hello\r\n"
b"0\r\n"
b"\r\n"
)
assert b"".join(body_chunks) == b"Hello"
assert parser.is_complete
def test_no_body_get(self, http_parser):
"""GET request has no body."""
parser_class = get_parser_class(http_parser)
body_chunks = []
parser = parser_class(
on_body=lambda chunk: body_chunks.append(chunk),
)
parser.feed(
b"GET / HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"\r\n"
)
assert parser.content_length is None
assert not parser.is_chunked
assert body_chunks == []
assert parser.is_complete
class TestConnectionHandling:
"""Test connection handling and keep-alive for both implementations."""
def test_http11_keepalive_default(self, http_parser):
"""HTTP/1.1 defaults to keep-alive."""
parser_class = get_parser_class(http_parser)
parser = parser_class()
parser.feed(
b"GET / HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"\r\n"
)
assert parser.should_keep_alive is True
def test_http11_connection_close(self, http_parser):
"""HTTP/1.1 with Connection: close."""
parser_class = get_parser_class(http_parser)
parser = parser_class()
parser.feed(
b"GET / HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Connection: close\r\n"
b"\r\n"
)
assert parser.should_keep_alive is False
def test_http10_no_keepalive(self, http_parser):
"""HTTP/1.0 defaults to no keep-alive."""
parser_class = get_parser_class(http_parser)
parser = parser_class()
parser.feed(
b"GET / HTTP/1.0\r\n"
b"Host: localhost\r\n"
b"\r\n"
)
assert parser.should_keep_alive is False
def test_http10_with_keepalive(self, http_parser):
"""HTTP/1.0 with Connection: keep-alive."""
parser_class = get_parser_class(http_parser)
parser = parser_class()
parser.feed(
b"GET / HTTP/1.0\r\n"
b"Host: localhost\r\n"
b"Connection: keep-alive\r\n"
b"\r\n"
)
assert parser.should_keep_alive is True
class TestParserReset:
"""Test parser reset for keep-alive connections."""
def test_reset_after_request(self, http_parser):
"""Parser can be reset for a new request."""
parser_class = get_parser_class(http_parser)
complete_count = [0]
parser = parser_class(
on_message_complete=lambda: complete_count.__setitem__(0, complete_count[0] + 1),
)
# First request
parser.feed(b"GET /first HTTP/1.1\r\n\r\n")
assert parser.path == b"/first"
assert parser.is_complete
# Reset and send second request
parser.reset()
assert not parser.is_complete
# H1CProtocol resets to b'', PythonProtocol to None
assert not parser.method
parser.feed(b"GET /second HTTP/1.1\r\n\r\n")
assert parser.path == b"/second"
assert parser.is_complete
assert complete_count[0] == 2
class TestCallbackBehavior:
"""Test callback behavior consistency."""
def test_all_callbacks_fire(self, http_parser):
"""All callbacks fire in correct order."""
parser_class = get_parser_class(http_parser)
events = []
parser = parser_class(
on_message_begin=lambda: events.append('begin'),
on_url=lambda url: events.append(('url', url)),
on_header=lambda n, v: events.append(('header', n.lower(), v)),
on_headers_complete=lambda: events.append('headers_complete'),
on_body=lambda chunk: events.append(('body', chunk)),
on_message_complete=lambda: events.append('complete'),
)
parser.feed(
b"POST / HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Content-Length: 4\r\n"
b"\r\n"
b"test"
)
assert events[0] == 'begin'
assert events[1] == ('url', b'/')
assert ('header', b'host', b'localhost') in events
assert ('header', b'content-length', b'4') in events
assert 'headers_complete' in events
assert ('body', b'test') in events
assert events[-1] == 'complete'
def test_skip_body_on_headers_complete(self, http_parser):
"""Return True from on_headers_complete skips body parsing."""
parser_class = get_parser_class(http_parser)
body_chunks = []
parser = parser_class(
on_headers_complete=lambda: True, # Skip body
on_body=lambda chunk: body_chunks.append(chunk),
)
parser.feed(
b"POST / HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Content-Length: 10\r\n"
b"\r\n"
b"0123456789"
)
assert parser.is_complete
assert body_chunks == [] # Body was skipped
class TestCallbackRequest:
"""Test CallbackRequest building from parser state."""
def test_non_ascii_path_decoding(self, http_parser):
"""Test that percent-encoded UTF-8 paths are decoded correctly.
Per ASGI spec:
- path: percent-decoded UTF-8 string
- raw_path: original bytes as received
"""
from gunicorn.asgi.parser import CallbackRequest
parser_class = get_parser_class(http_parser)
parser = parser_class()
# ö = %C3%B6 in UTF-8 percent-encoded
parser.feed(b"GET /%C3%B6/ HTTP/1.1\r\nHost: test\r\n\r\n")
request = CallbackRequest.from_parser(parser)
# path should be percent-decoded UTF-8 string
assert request.path == "/\u00f6/" # /ö/
# raw_path should be original bytes
assert request.raw_path == b"/%C3%B6/"
def test_non_ascii_path_with_query(self, http_parser):
"""Test percent-encoded path with query string."""
from gunicorn.asgi.parser import CallbackRequest
parser_class = get_parser_class(http_parser)
parser = parser_class()
# Japanese: /日本/ = /%E6%97%A5%E6%9C%AC/
parser.feed(b"GET /%E6%97%A5%E6%9C%AC/?q=test HTTP/1.1\r\nHost: test\r\n\r\n")
request = CallbackRequest.from_parser(parser)
assert request.path == "/\u65e5\u672c/" # /日本/
assert request.raw_path == b"/%E6%97%A5%E6%9C%AC/"
assert request.query == "q=test"
def test_invalid_utf8_path(self, http_parser):
"""Test that invalid UTF-8 sequences use replacement character."""
from gunicorn.asgi.parser import CallbackRequest
parser_class = get_parser_class(http_parser)
parser = parser_class()
# %FF is invalid UTF-8
parser.feed(b"GET /%FF HTTP/1.1\r\nHost: test\r\n\r\n")
request = CallbackRequest.from_parser(parser)
# Should use replacement character for invalid bytes
assert "\ufffd" in request.path
assert request.raw_path == b"/%FF"
def test_simple_ascii_path(self, http_parser):
"""Test that simple ASCII paths work unchanged."""
from gunicorn.asgi.parser import CallbackRequest
parser_class = get_parser_class(http_parser)
parser = parser_class()
parser.feed(b"GET /api/users HTTP/1.1\r\nHost: test\r\n\r\n")
request = CallbackRequest.from_parser(parser)
assert request.path == "/api/users"
assert request.raw_path == b"/api/users"
def test_percent_encoded_ascii(self, http_parser):
"""Test percent-encoded ASCII characters."""
from gunicorn.asgi.parser import CallbackRequest
parser_class = get_parser_class(http_parser)
parser = parser_class()
# Space encoded as %20
parser.feed(b"GET /hello%20world HTTP/1.1\r\nHost: test\r\n\r\n")
request = CallbackRequest.from_parser(parser)
assert request.path == "/hello world"
assert request.raw_path == b"/hello%20world"

View File

@ -9,9 +9,10 @@ Tests that gunicorn's ASGI implementation conforms to the ASGI 3.0 spec:
https://asgi.readthedocs.io/en/latest/specs/main.html
"""
import asyncio
from unittest import mock
import pytest
from gunicorn.config import Config
@ -37,7 +38,9 @@ class TestASGIVersion:
"""Create a mock HTTP request."""
request = mock.Mock()
request.method = kwargs.get("method", "GET")
request.path = kwargs.get("path", "/")
path = kwargs.get("path", "/")
request.path = path
request.raw_path = kwargs.get("raw_path", path.encode("latin-1") if path else b"")
request.query = kwargs.get("query", "")
request.version = kwargs.get("version", (1, 1))
request.scheme = kwargs.get("scheme", "http")
@ -136,7 +139,9 @@ class TestHTTPScopeKeys:
"""Create a mock HTTP request."""
request = mock.Mock()
request.method = kwargs.get("method", "GET")
request.path = kwargs.get("path", "/")
path = kwargs.get("path", "/")
request.path = path
request.raw_path = kwargs.get("raw_path", path.encode("latin-1") if path else b"")
request.query = kwargs.get("query", "")
request.version = kwargs.get("version", (1, 1))
request.scheme = kwargs.get("scheme", "http")
@ -653,6 +658,7 @@ class TestStateSharing:
request = mock.Mock()
request.method = "GET"
request.path = "/"
request.raw_path = b"/"
request.query = ""
request.version = (1, 1)
request.scheme = "http"
@ -730,31 +736,46 @@ class TestHTTPDisconnectEvent:
return protocol
def test_disconnect_event_type(self):
"""Test that disconnect event has correct type per ASGI spec."""
"""Test that disconnect event signals body receiver per ASGI spec."""
from gunicorn.asgi.protocol import BodyReceiver
protocol = self._create_protocol()
protocol._receive_queue = asyncio.Queue()
# Create a mock request for the body receiver
mock_request = mock.Mock()
mock_request.content_length = 100
mock_request.chunked = False
body_receiver = BodyReceiver(mock_request, protocol)
protocol._body_receiver = body_receiver
# Simulate client disconnect
protocol.connection_lost(None)
# Get the message from queue
msg = protocol._receive_queue.get_nowait()
# Per ASGI spec: type MUST be "http.disconnect"
assert msg["type"] == "http.disconnect"
# Per ASGI spec: disconnect should be signaled
assert body_receiver._closed
def test_disconnect_event_sent_on_connection_lost(self):
"""Test that http.disconnect is sent when connection is lost."""
protocol = self._create_protocol()
protocol._receive_queue = asyncio.Queue()
"""Test that disconnect is signaled when connection is lost."""
from gunicorn.asgi.protocol import BodyReceiver
assert protocol._receive_queue.empty()
protocol = self._create_protocol()
# Create a mock request for the body receiver
mock_request = mock.Mock()
mock_request.content_length = 100
mock_request.chunked = False
body_receiver = BodyReceiver(mock_request, protocol)
protocol._body_receiver = body_receiver
assert not body_receiver._closed
# Simulate client disconnect
protocol.connection_lost(None)
# Queue should have disconnect message
assert not protocol._receive_queue.empty()
# Disconnect should have been signaled
assert body_receiver._closed
def test_disconnect_sets_closed_flag(self):
"""Test that connection_lost sets the closed flag."""
@ -788,18 +809,33 @@ class TestHTTPDisconnectEvent:
# Cancellation should be scheduled after grace period
protocol.worker.loop.call_later.assert_called_once()
def test_disconnect_message_format(self):
@pytest.mark.asyncio
async def test_disconnect_message_format(self):
"""Test http.disconnect message format per ASGI spec.
The disconnect message should only contain 'type' key.
When body is complete and disconnect is signaled, receive()
should return {"type": "http.disconnect"}.
"""
from gunicorn.asgi.protocol import BodyReceiver
protocol = self._create_protocol()
protocol._receive_queue = asyncio.Queue()
protocol.connection_lost(None)
# Create a mock request with no body
mock_request = mock.Mock()
mock_request.content_length = 0
mock_request.chunked = False
msg = protocol._receive_queue.get_nowait()
body_receiver = BodyReceiver(mock_request, protocol)
protocol._body_receiver = body_receiver
# Get initial body message (empty body)
msg1 = await body_receiver.receive()
assert msg1["type"] == "http.request"
assert msg1["more_body"] is False
# Now receive should return disconnect
msg2 = await body_receiver.receive()
# Per ASGI spec, disconnect message only has 'type'
assert msg == {"type": "http.disconnect"}
assert len(msg) == 1
assert msg2 == {"type": "http.disconnect"}
assert len(msg2) == 1

View File

@ -50,41 +50,58 @@ class TestASGIGracefulDisconnect:
assert protocol._closed is True
def test_disconnect_sends_message_to_queue(self, mock_worker):
"""Test that connection_lost sends http.disconnect to receive queue."""
def test_disconnect_signals_body_receiver(self, mock_worker):
"""Test that connection_lost signals the body receiver."""
from gunicorn.asgi.protocol import BodyReceiver
protocol = ASGIProtocol(mock_worker)
protocol.reader = mock.Mock()
mock_worker.nr_conns = 1
# Create a receive queue (simulating active request)
protocol._receive_queue = asyncio.Queue()
# Create a mock request for the body receiver
mock_request = mock.Mock()
mock_request.content_length = 100
mock_request.chunked = False
# Create a body receiver (simulating active request)
body_receiver = BodyReceiver(mock_request, protocol)
protocol._body_receiver = body_receiver
# Verify disconnect flag is not set initially
assert not body_receiver._closed
# Simulate connection lost
protocol.connection_lost(None)
# Check that disconnect message was sent
assert not protocol._receive_queue.empty()
msg = protocol._receive_queue.get_nowait()
assert msg == {"type": "http.disconnect"}
# Check that disconnect flag was set
assert body_receiver._closed
def test_disconnect_is_idempotent(self, mock_worker):
"""Test that connection_lost can be called multiple times safely."""
from gunicorn.asgi.protocol import BodyReceiver
protocol = ASGIProtocol(mock_worker)
protocol.reader = mock.Mock()
mock_worker.nr_conns = 2 # Start with 2 so we can verify only 1 is decremented
protocol._receive_queue = asyncio.Queue()
# Create a mock request for the body receiver
mock_request = mock.Mock()
mock_request.content_length = 100
mock_request.chunked = False
body_receiver = BodyReceiver(mock_request, protocol)
protocol._body_receiver = body_receiver
# First call should work
protocol.connection_lost(None)
assert protocol._closed is True
assert mock_worker.nr_conns == 1
assert protocol._receive_queue.qsize() == 1
assert body_receiver._closed
# Second call should be a no-op
protocol.connection_lost(None)
assert mock_worker.nr_conns == 1 # Should not decrement again
assert protocol._receive_queue.qsize() == 1 # Should not add another message
# Closed flag is still set
def test_disconnect_does_not_cancel_immediately(self, mock_worker):
"""Test that connection_lost doesn't cancel task immediately."""
@ -156,41 +173,26 @@ class TestASGIGracefulDisconnect:
@pytest.mark.asyncio
async def test_receive_returns_disconnect_when_closed(self, mock_worker):
"""Test that receive() returns http.disconnect when connection is closed."""
from gunicorn.asgi.protocol import BodyReceiver
protocol = ASGIProtocol(mock_worker)
protocol._closed = True
# Create receive queue with body complete
receive_queue = asyncio.Queue()
protocol._receive_queue = receive_queue
# Create a mock request with no body
mock_request = mock.Mock()
mock_request.content_length = 0
mock_request.chunked = False
# Add initial body message
await receive_queue.put({
"type": "http.request",
"body": b"",
"more_body": False,
})
body_receiver = BodyReceiver(mock_request, protocol)
protocol._body_receiver = body_receiver
# Simulate what happens in _handle_http_request
body_complete = False
async def receive():
nonlocal body_complete
if protocol._closed and body_complete:
return {"type": "http.disconnect"}
msg = await receive_queue.get()
if msg.get("type") == "http.request" and not msg.get("more_body", True):
body_complete = True
return msg
# First receive gets the body
msg1 = await receive()
# First receive gets the body (empty)
msg1 = await body_receiver.receive()
assert msg1["type"] == "http.request"
assert msg1["more_body"] is False
# Second receive should get disconnect
msg2 = await receive()
# Second receive should get disconnect (body complete)
msg2 = await body_receiver.receive()
assert msg2["type"] == "http.disconnect"

View File

@ -11,7 +11,6 @@ and extension support.
from unittest import mock
import pytest
from gunicorn.config import Config
@ -40,7 +39,9 @@ class TestHTTPScopeBuilding:
"""Create a mock HTTP request."""
request = mock.Mock()
request.method = kwargs.get("method", "GET")
request.path = kwargs.get("path", "/")
path = kwargs.get("path", "/")
request.path = path
request.raw_path = kwargs.get("raw_path", path.encode("latin-1") if path else b"")
request.query = kwargs.get("query", "")
request.version = kwargs.get("version", (1, 1))
request.scheme = kwargs.get("scheme", "http")
@ -114,7 +115,9 @@ class TestPathHandling:
"""Create a mock HTTP request."""
request = mock.Mock()
request.method = kwargs.get("method", "GET")
request.path = kwargs.get("path", "/")
path = kwargs.get("path", "/")
request.path = path
request.raw_path = kwargs.get("raw_path", path.encode("latin-1") if path else b"")
request.query = kwargs.get("query", "")
request.version = kwargs.get("version", (1, 1))
request.scheme = kwargs.get("scheme", "http")
@ -194,7 +197,9 @@ class TestQueryStringHandling:
"""Create a mock HTTP request."""
request = mock.Mock()
request.method = kwargs.get("method", "GET")
request.path = kwargs.get("path", "/")
path = kwargs.get("path", "/")
request.path = path
request.raw_path = kwargs.get("raw_path", path.encode("latin-1") if path else b"")
request.query = kwargs.get("query", "")
request.version = kwargs.get("version", (1, 1))
request.scheme = kwargs.get("scheme", "http")
@ -270,7 +275,9 @@ class TestHeaderHandling:
"""Create a mock HTTP request."""
request = mock.Mock()
request.method = kwargs.get("method", "GET")
request.path = kwargs.get("path", "/")
path = kwargs.get("path", "/")
request.path = path
request.raw_path = kwargs.get("raw_path", path.encode("latin-1") if path else b"")
request.query = kwargs.get("query", "")
request.version = kwargs.get("version", (1, 1))
request.scheme = kwargs.get("scheme", "http")
@ -362,7 +369,9 @@ class TestWebSocketScope:
"""Create a mock WebSocket upgrade request."""
request = mock.Mock()
request.method = "GET"
request.path = kwargs.get("path", "/ws")
path = kwargs.get("path", "/ws")
request.path = path
request.raw_path = kwargs.get("raw_path", path.encode("latin-1") if path else b"")
request.query = kwargs.get("query", "")
request.version = kwargs.get("version", (1, 1))
request.scheme = kwargs.get("scheme", "http")
@ -575,6 +584,7 @@ class TestAddressHandling:
request = mock.Mock()
request.method = "GET"
request.path = "/"
request.raw_path = b"/"
request.query = ""
request.version = (1, 1)
request.scheme = "http"

View File

@ -6,7 +6,6 @@
Tests for ASGI HTTP parser optimizations.
"""
import asyncio
import ipaddress
import pytest

View File

@ -9,7 +9,6 @@ Tests for chunked transfer encoding, Server-Sent Events (SSE),
and streaming response handling.
"""
import asyncio
from unittest import mock
import pytest
@ -221,43 +220,39 @@ class TestProtocolSendBody:
return protocol
@pytest.mark.asyncio
async def test_send_body_without_chunking(self):
def test_send_body_without_chunking(self):
"""Test sending body without chunked encoding."""
protocol = self._create_protocol()
await protocol._send_body(b"Hello, World!", chunked=False)
protocol._send_body(b"Hello, World!", chunked=False)
protocol.transport.write.assert_called_once_with(b"Hello, World!")
@pytest.mark.asyncio
async def test_send_body_with_chunking(self):
def test_send_body_with_chunking(self):
"""Test sending body with chunked encoding."""
protocol = self._create_protocol()
await protocol._send_body(b"Hello", chunked=True)
protocol._send_body(b"Hello", chunked=True)
# Should write: "5\r\nHello\r\n"
protocol.transport.write.assert_called_once()
call_arg = protocol.transport.write.call_args[0][0]
assert call_arg == b"5\r\nHello\r\n"
@pytest.mark.asyncio
async def test_send_body_empty_without_chunking(self):
def test_send_body_empty_without_chunking(self):
"""Test sending empty body without chunked encoding."""
protocol = self._create_protocol()
await protocol._send_body(b"", chunked=False)
protocol._send_body(b"", chunked=False)
# Empty body should not write anything
protocol.transport.write.assert_not_called()
@pytest.mark.asyncio
async def test_send_body_empty_with_chunking(self):
def test_send_body_empty_with_chunking(self):
"""Test sending empty body with chunked encoding."""
protocol = self._create_protocol()
await protocol._send_body(b"", chunked=True)
protocol._send_body(b"", chunked=True)
# Empty body should not write (terminal chunk handled separately)
protocol.transport.write.assert_not_called()

View File

@ -9,7 +9,6 @@ Tests that gunicorn's WebSocket implementation conforms to RFC 6455:
https://tools.ietf.org/html/rfc6455
"""
import asyncio
import base64
import hashlib
import struct
@ -176,7 +175,7 @@ class TestWebSocketFrameMasking:
def _create_protocol(self):
"""Create a WebSocketProtocol instance for testing."""
from gunicorn.asgi.websocket import WebSocketProtocol
return WebSocketProtocol(None, None, {}, None, mock.Mock())
return WebSocketProtocol(None, {}, None, mock.Mock())
def test_unmask_simple(self):
"""Test basic unmasking operation."""
@ -299,7 +298,6 @@ class TestWebSocketProtocolInstance:
return WebSocketProtocol(
transport=mock.Mock(),
reader=mock.Mock(),
scope=scope,
app=mock.AsyncMock(),
log=mock.Mock(),
@ -602,11 +600,9 @@ class TestWebSocketAsync:
}
transport = mock.Mock()
reader = mock.Mock()
return WebSocketProtocol(
transport=transport,
reader=reader,
scope=scope,
app=mock.AsyncMock(),
log=mock.Mock(),
@ -676,3 +672,367 @@ class TestWebSocketAsync:
await protocol._send({"type": "websocket.close", "code": 1000})
assert protocol.closed is True
# ============================================================================
# Callback-based Data Feeding Tests
# ============================================================================
class TestWebSocketCallbackDataFeeding:
"""Tests for callback-based data feeding (replaces StreamReader)."""
def _create_protocol(self):
"""Create a WebSocketProtocol instance for testing."""
from gunicorn.asgi.websocket import WebSocketProtocol
return WebSocketProtocol(
transport=mock.Mock(),
scope={"type": "websocket", "headers": []},
app=mock.AsyncMock(),
log=mock.Mock(),
)
def test_initial_buffer_empty(self):
"""Test that initial buffer is empty."""
protocol = self._create_protocol()
assert len(protocol._buffer) == 0
assert protocol._eof is False
def test_feed_data_adds_to_buffer(self):
"""Test that feed_data adds bytes to buffer."""
protocol = self._create_protocol()
protocol.feed_data(b"Hello")
assert bytes(protocol._buffer) == b"Hello"
protocol.feed_data(b" World")
assert bytes(protocol._buffer) == b"Hello World"
def test_feed_data_ignores_empty(self):
"""Test that feed_data ignores empty data."""
protocol = self._create_protocol()
protocol.feed_data(b"")
assert len(protocol._buffer) == 0
protocol.feed_data(None)
# Should not raise, just be ignored
def test_feed_data_sets_event(self):
"""Test that feed_data sets the data event."""
protocol = self._create_protocol()
assert not protocol._data_event.is_set()
protocol.feed_data(b"data")
assert protocol._data_event.is_set()
def test_feed_eof_sets_flag(self):
"""Test that feed_eof sets the EOF flag."""
protocol = self._create_protocol()
assert protocol._eof is False
protocol.feed_eof()
assert protocol._eof is True
def test_feed_eof_sets_event(self):
"""Test that feed_eof sets the data event."""
protocol = self._create_protocol()
assert not protocol._data_event.is_set()
protocol.feed_eof()
assert protocol._data_event.is_set()
class TestWebSocketReadExact:
"""Tests for _read_exact method with callback-based buffer."""
def _create_protocol(self):
"""Create a WebSocketProtocol instance for testing."""
from gunicorn.asgi.websocket import WebSocketProtocol
return WebSocketProtocol(
transport=mock.Mock(),
scope={"type": "websocket", "headers": []},
app=mock.AsyncMock(),
log=mock.Mock(),
)
@pytest.mark.asyncio
async def test_read_exact_with_sufficient_data(self):
"""Test _read_exact returns data when buffer has enough."""
protocol = self._create_protocol()
# Pre-fill buffer
protocol.feed_data(b"Hello World")
result = await protocol._read_exact(5)
assert result == b"Hello"
assert bytes(protocol._buffer) == b" World"
@pytest.mark.asyncio
async def test_read_exact_consumes_buffer(self):
"""Test _read_exact properly consumes buffer."""
protocol = self._create_protocol()
protocol.feed_data(b"ABCDEFGH")
result1 = await protocol._read_exact(3)
assert result1 == b"ABC"
result2 = await protocol._read_exact(3)
assert result2 == b"DEF"
assert bytes(protocol._buffer) == b"GH"
@pytest.mark.asyncio
async def test_read_exact_returns_none_on_eof(self):
"""Test _read_exact returns None when EOF with insufficient data."""
protocol = self._create_protocol()
protocol.feed_data(b"Hi")
protocol.feed_eof()
# Request more data than available after EOF
result = await protocol._read_exact(10)
assert result is None
@pytest.mark.asyncio
async def test_read_exact_waits_for_data(self):
"""Test _read_exact waits when buffer is insufficient."""
import asyncio
protocol = self._create_protocol()
# Start read that needs more data
read_task = asyncio.create_task(protocol._read_exact(10))
# Give task a chance to start waiting
await asyncio.sleep(0.01)
assert not read_task.done()
# Feed enough data
protocol.feed_data(b"1234567890")
result = await asyncio.wait_for(read_task, timeout=1.0)
assert result == b"1234567890"
@pytest.mark.asyncio
async def test_read_exact_handles_incremental_data(self):
"""Test _read_exact handles data arriving in chunks."""
import asyncio
protocol = self._create_protocol()
# Start read needing 10 bytes
read_task = asyncio.create_task(protocol._read_exact(10))
await asyncio.sleep(0.01)
# Feed data incrementally
protocol.feed_data(b"123")
await asyncio.sleep(0.01)
assert not read_task.done()
protocol.feed_data(b"456")
await asyncio.sleep(0.01)
assert not read_task.done()
protocol.feed_data(b"7890")
result = await asyncio.wait_for(read_task, timeout=1.0)
assert result == b"1234567890"
@pytest.mark.asyncio
async def test_read_exact_race_condition(self):
"""Test _read_exact handles race condition when data arrives during clear/wait gap.
This tests the fix for the race condition where:
1. Task A checks buffer, needs more data
2. Task A clears _data_event
3. Task B (data_received) calls feed_data(), sets event
4. Task A would wait forever on cleared event - DEADLOCK
The fix adds a buffer check after clear() to catch this case.
"""
import asyncio
protocol = self._create_protocol()
# Pre-fill with partial data
protocol.feed_data(b"12345")
# Start read needing 10 bytes
read_task = asyncio.create_task(protocol._read_exact(10))
await asyncio.sleep(0.01)
assert not read_task.done()
# Simulate race: feed remaining data rapidly
# In the buggy version, if data arrives right after clear() but before wait(),
# the event gets set then immediately the wait() would block on a stale clear
protocol.feed_data(b"67890")
# Should complete without deadlock
result = await asyncio.wait_for(read_task, timeout=1.0)
assert result == b"1234567890"
@pytest.mark.asyncio
async def test_read_exact_multiple_feeds_before_wait(self):
"""Test _read_exact when all data arrives before wait starts."""
import asyncio
protocol = self._create_protocol()
# Feed all data before starting read - should not block
protocol.feed_data(b"Complete message here")
result = await asyncio.wait_for(protocol._read_exact(8), timeout=0.1)
assert result == b"Complete"
# Buffer should have remainder
assert bytes(protocol._buffer) == b" message here"
@pytest.mark.asyncio
async def test_read_exact_eof_during_wait(self):
"""Test _read_exact handles EOF arriving while waiting for data."""
import asyncio
protocol = self._create_protocol()
# Start read needing more data than we'll provide
read_task = asyncio.create_task(protocol._read_exact(100))
await asyncio.sleep(0.01)
assert not read_task.done()
# Feed some data but not enough
protocol.feed_data(b"partial")
await asyncio.sleep(0.01)
assert not read_task.done()
# Signal EOF - should cause read to return None
protocol.feed_eof()
result = await asyncio.wait_for(read_task, timeout=1.0)
assert result is None
# ============================================================================
# WebSocket Fragmented Message Tests (RFC 6455 Section 5.4)
# ============================================================================
class TestWebSocketFragmentedMessages:
"""Tests for WebSocket fragmented message handling."""
def _create_protocol(self):
"""Create a WebSocketProtocol instance for testing."""
from gunicorn.asgi.websocket import WebSocketProtocol
return WebSocketProtocol(
transport=mock.Mock(),
scope={"type": "websocket", "headers": []},
app=mock.AsyncMock(),
log=mock.Mock(),
)
def _create_masked_frame(self, fin, opcode, payload, mask_key=None):
"""Create a masked WebSocket frame.
Args:
fin: FIN bit (1 for final, 0 for continuation)
opcode: Frame opcode
payload: Frame payload bytes
mask_key: 4-byte masking key (generated if None)
Returns:
bytes: Complete masked frame
"""
if mask_key is None:
mask_key = bytes([0x37, 0xfa, 0x21, 0x3d])
frame = bytearray()
# First byte: FIN + RSV(000) + opcode
frame.append((fin << 7) | opcode)
# Second byte: MASK(1) + length
length = len(payload)
if length < 126:
frame.append(0x80 | length)
elif length < 65536:
frame.append(0x80 | 126)
frame.extend(struct.pack("!H", length))
else:
frame.append(0x80 | 127)
frame.extend(struct.pack("!Q", length))
# Masking key
frame.extend(mask_key)
# Masked payload
masked_payload = bytes(b ^ mask_key[i % 4] for i, b in enumerate(payload))
frame.extend(masked_payload)
return bytes(frame)
@pytest.mark.asyncio
async def test_fragmented_message_reassembly(self):
"""Test reassembly of fragmented text message with multiple continuation frames."""
from gunicorn.asgi.websocket import (
OPCODE_TEXT, OPCODE_CONTINUATION, OPCODE_CONTINUATION as CONT
)
import asyncio
protocol = self._create_protocol()
# Build fragmented message: "Hello" + " " + "World" + "!"
# First frame: opcode=TEXT, FIN=0, payload="Hello"
frame1 = self._create_masked_frame(fin=0, opcode=OPCODE_TEXT, payload=b"Hello")
# Continuation frames: opcode=CONTINUATION, FIN=0
frame2 = self._create_masked_frame(fin=0, opcode=CONT, payload=b" ")
frame3 = self._create_masked_frame(fin=0, opcode=CONT, payload=b"World")
# Final frame: opcode=CONTINUATION, FIN=1
frame4 = self._create_masked_frame(fin=1, opcode=CONT, payload=b"!")
# Feed all frames
protocol.feed_data(frame1 + frame2 + frame3 + frame4)
# Read frames - first 3 should return CONTINUATION with empty payload (waiting)
result1 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0)
assert result1 == (OPCODE_CONTINUATION, b"") # Fragment started
result2 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0)
assert result2 == (OPCODE_CONTINUATION, b"") # Fragment continued
result3 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0)
assert result3 == (OPCODE_CONTINUATION, b"") # Fragment continued
# Final frame should return complete reassembled message with original opcode
result4 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0)
assert result4 == (OPCODE_TEXT, b"Hello World!")
@pytest.mark.asyncio
async def test_control_frame_during_fragmentation(self):
"""Test that control frames (ping) can arrive during fragmented message.
RFC 6455 Section 5.4: Control frames MAY be injected in the middle
of a fragmented message.
"""
from gunicorn.asgi.websocket import (
OPCODE_TEXT, OPCODE_CONTINUATION, OPCODE_PING
)
import asyncio
protocol = self._create_protocol()
# Start fragmented message
frame1 = self._create_masked_frame(fin=0, opcode=OPCODE_TEXT, payload=b"Hello")
# Ping frame in the middle (control frames are always FIN=1)
ping_frame = self._create_masked_frame(fin=1, opcode=OPCODE_PING, payload=b"ping")
# Continue and finish fragmented message
frame2 = self._create_masked_frame(fin=1, opcode=OPCODE_CONTINUATION, payload=b" World")
protocol.feed_data(frame1 + ping_frame + frame2)
# First read: fragment started (waiting for more)
result1 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0)
assert result1 == (OPCODE_CONTINUATION, b"")
# Second read: ping frame (control frames handled separately)
result2 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0)
assert result2 == (OPCODE_PING, b"ping")
# Third read: complete reassembled message
result3 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0)
assert result3 == (OPCODE_TEXT, b"Hello World")

View File

@ -12,11 +12,7 @@ that actually start the server and make HTTP requests.
import asyncio
import errno
import os
import signal
import socket
import sys
import time
import threading
from unittest import mock
import pytest
@ -120,7 +116,7 @@ class FakeListener:
def _has_uvloop():
"""Check if uvloop is available."""
try:
import uvloop
import uvloop # noqa: F401
return True
except ImportError:
return False
@ -337,9 +333,9 @@ class TestLifespanManager:
async def app(scope, receive, send):
assert "state" in scope
scope["state"]["db"] = "connected"
message = await receive()
_ = await receive()
await send({"type": "lifespan.startup.complete"})
message = await receive()
_ = await receive()
await send({"type": "lifespan.shutdown.complete"})
manager = LifespanManager(app, mock.Mock(), state)
@ -393,7 +389,7 @@ class TestWebSocketProtocol:
from gunicorn.asgi.websocket import WebSocketProtocol
# Create a minimal protocol instance
protocol = WebSocketProtocol(None, None, {}, None, mock.Mock())
protocol = WebSocketProtocol(None, {}, None, mock.Mock())
# Test unmasking (XOR operation)
masking_key = bytes([0x37, 0xfa, 0x21, 0x3d])
@ -406,7 +402,7 @@ class TestWebSocketProtocol:
"""Test WebSocket frame unmasking with empty payload."""
from gunicorn.asgi.websocket import WebSocketProtocol
protocol = WebSocketProtocol(None, None, {}, None, mock.Mock())
protocol = WebSocketProtocol(None, {}, None, mock.Mock())
masking_key = bytes([0x37, 0xfa, 0x21, 0x3d])
unmasked = protocol._unmask(b"", masking_key)
@ -555,7 +551,6 @@ class TestASGIProtocol:
def test_scope_building(self):
"""Test HTTP scope building."""
from gunicorn.asgi.protocol import ASGIProtocol
from gunicorn.asgi.message import AsyncRequest
worker = mock.Mock()
worker.cfg = Config()
@ -729,10 +724,11 @@ class TestASGIHTTP2Priority:
protocol = ASGIProtocol(worker)
# Create mock HTTP/1.1 request (no priority attributes)
request = mock.Mock(spec=['method', 'path', 'query', 'version',
request = mock.Mock(spec=['method', 'path', 'raw_path', 'query', 'version',
'scheme', 'headers'])
request.method = "GET"
request.path = "/test"
request.raw_path = b"/test"
request.query = ""
request.version = (1, 1)
request.scheme = "http"

View File

@ -7,20 +7,57 @@ import os
import pytest
from gunicorn.http.errors import (
InvalidRequestLine,
InvalidRequestMethod,
InvalidSchemeHeaders,
ObsoleteFolding,
)
import treq
dirname = os.path.dirname(__file__)
reqdir = os.path.join(dirname, "requests", "invalid")
httpfiles = glob.glob(os.path.join(reqdir, "*.http"))
# Flags incompatible with fast parser (require Python parser features)
_FAST_INCOMPATIBLE_FLAGS = ('permit_obsolete_folding', 'strip_header_spaces')
# Exceptions that only the Python parser raises (C parser has different validation)
_PYTHON_ONLY_EXCEPTIONS = (ObsoleteFolding, InvalidSchemeHeaders)
# C parser may raise different but valid exceptions for these cases
_FAST_PARSER_ALTERNATES = {
InvalidRequestMethod: (InvalidRequestLine,), # e.g. "GET:" raises InvalidRequestLine
}
@pytest.mark.parametrize("fname", httpfiles)
def test_http_parser(fname):
env = treq.load_py(os.path.splitext(fname)[0] + ".py")
def test_http_parser(fname, http_parser):
"""Test invalid HTTP requests with both parser implementations."""
env = treq.load_py(os.path.splitext(fname)[0] + ".py", http_parser=http_parser)
expect = env["request"]
cfg = env["cfg"]
# Skip fast parser tests that use incompatible compatibility flags
if http_parser == 'fast':
for flag in _FAST_INCOMPATIBLE_FLAGS:
if getattr(cfg, flag, False):
pytest.skip(f"fast parser incompatible with {flag}")
# Skip tests expecting Python-only exceptions
if expect in _PYTHON_ONLY_EXCEPTIONS or (
isinstance(expect, type) and issubclass(expect, _PYTHON_ONLY_EXCEPTIONS)
):
pytest.skip(f"fast parser does not raise {expect.__name__}")
# Determine acceptable exceptions (fast parser may raise alternates)
if http_parser == 'fast' and expect in _FAST_PARSER_ALTERNATES:
acceptable = (expect,) + _FAST_PARSER_ALTERNATES[expect]
else:
acceptable = expect
req = treq.badrequest(fname)
with pytest.raises(expect):
with pytest.raises(acceptable):
req.check(cfg)

View File

@ -0,0 +1,518 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
Tests for PythonProtocol callback-based HTTP parser.
"""
import pytest
from gunicorn.asgi.parser import PythonProtocol, CallbackRequest, ParseError
class TestPythonProtocolBasic:
"""Test basic request parsing."""
def test_simple_get_request(self):
"""Test parsing a simple GET request."""
headers_complete = []
message_complete = []
parser = PythonProtocol(
on_headers_complete=lambda: headers_complete.append(True),
on_message_complete=lambda: message_complete.append(True),
)
data = b"GET /path HTTP/1.1\r\nHost: example.com\r\n\r\n"
parser.feed(data)
assert parser.method == b"GET"
assert parser.path == b"/path"
assert parser.http_version == (1, 1)
assert len(parser.headers) == 1
assert parser.headers[0] == (b"host", b"example.com")
assert parser.is_complete is True
assert len(headers_complete) == 1
assert len(message_complete) == 1
def test_get_with_query_string(self):
"""Test parsing GET with query string."""
parser = PythonProtocol()
data = b"GET /search?q=test&page=1 HTTP/1.1\r\nHost: example.com\r\n\r\n"
parser.feed(data)
assert parser.method == b"GET"
assert parser.path == b"/search?q=test&page=1"
assert parser.is_complete is True
def test_http_10_request(self):
"""Test parsing HTTP/1.0 request."""
parser = PythonProtocol()
data = b"GET / HTTP/1.0\r\nHost: example.com\r\n\r\n"
parser.feed(data)
assert parser.http_version == (1, 0)
assert parser.should_keep_alive is False # HTTP/1.0 default
def test_http_10_with_keepalive(self):
"""Test HTTP/1.0 with explicit keep-alive."""
parser = PythonProtocol()
data = b"GET / HTTP/1.0\r\nHost: example.com\r\nConnection: keep-alive\r\n\r\n"
parser.feed(data)
assert parser.http_version == (1, 0)
assert parser.should_keep_alive is True
def test_multiple_headers(self):
"""Test parsing multiple headers."""
parser = PythonProtocol()
data = (
b"GET / HTTP/1.1\r\n"
b"Host: example.com\r\n"
b"Accept: text/html\r\n"
b"Accept-Language: en-US\r\n"
b"User-Agent: Test/1.0\r\n"
b"\r\n"
)
parser.feed(data)
assert len(parser.headers) == 4
header_names = [h[0] for h in parser.headers]
assert b"host" in header_names
assert b"accept" in header_names
assert b"accept-language" in header_names
assert b"user-agent" in header_names
class TestPythonProtocolBody:
"""Test request body parsing."""
def test_post_with_content_length(self):
"""Test POST with Content-Length body."""
body_chunks = []
parser = PythonProtocol(
on_body=lambda chunk: body_chunks.append(chunk),
)
data = (
b"POST /submit HTTP/1.1\r\n"
b"Host: example.com\r\n"
b"Content-Length: 13\r\n"
b"\r\n"
b"name=testuser"
)
parser.feed(data)
assert parser.method == b"POST"
assert parser.content_length == 13
assert parser.is_complete is True
assert b"".join(body_chunks) == b"name=testuser"
def test_chunked_body(self):
"""Test chunked transfer encoding."""
body_chunks = []
parser = PythonProtocol(
on_body=lambda chunk: body_chunks.append(chunk),
)
data = (
b"POST / HTTP/1.1\r\n"
b"Host: example.com\r\n"
b"Transfer-Encoding: chunked\r\n"
b"\r\n"
b"5\r\nhello\r\n"
b"6\r\n world\r\n"
b"0\r\n\r\n"
)
parser.feed(data)
assert parser.is_chunked is True
assert parser.is_complete is True
assert b"".join(body_chunks) == b"hello world"
def test_chunked_with_extension(self):
"""Test chunked with chunk extension."""
body_chunks = []
parser = PythonProtocol(
on_body=lambda chunk: body_chunks.append(chunk),
)
data = (
b"POST / HTTP/1.1\r\n"
b"Host: example.com\r\n"
b"Transfer-Encoding: chunked\r\n"
b"\r\n"
b"5;ext=value\r\nhello\r\n"
b"0\r\n\r\n"
)
parser.feed(data)
assert b"".join(body_chunks) == b"hello"
class TestPythonProtocolIncremental:
"""Test incremental/partial data feeding."""
def test_partial_request_line(self):
"""Test feeding partial request line."""
parser = PythonProtocol()
# Feed partial request line
parser.feed(b"GET /path ")
assert parser.method is None
assert parser.is_complete is False
# Complete the request line and headers
parser.feed(b"HTTP/1.1\r\nHost: example.com\r\n\r\n")
assert parser.method == b"GET"
assert parser.is_complete is True
def test_partial_headers(self):
"""Test feeding partial headers."""
parser = PythonProtocol()
parser.feed(b"GET / HTTP/1.1\r\n")
parser.feed(b"Host: exa")
assert parser.is_complete is False
parser.feed(b"mple.com\r\n\r\n")
assert parser.is_complete is True
assert parser.headers[0] == (b"host", b"example.com")
def test_partial_body(self):
"""Test feeding partial body."""
body_chunks = []
parser = PythonProtocol(
on_body=lambda chunk: body_chunks.append(chunk),
)
parser.feed(
b"POST / HTTP/1.1\r\n"
b"Host: example.com\r\n"
b"Content-Length: 10\r\n"
b"\r\n"
b"hello"
)
assert parser.is_complete is False
parser.feed(b"world")
assert parser.is_complete is True
assert b"".join(body_chunks) == b"helloworld"
def test_partial_chunked_body(self):
"""Test feeding partial chunked body."""
body_chunks = []
parser = PythonProtocol(
on_body=lambda chunk: body_chunks.append(chunk),
)
parser.feed(
b"POST / HTTP/1.1\r\n"
b"Host: example.com\r\n"
b"Transfer-Encoding: chunked\r\n"
b"\r\n"
b"5\r\nhel"
)
assert parser.is_complete is False
parser.feed(b"lo\r\n0\r\n\r\n")
assert parser.is_complete is True
assert b"".join(body_chunks) == b"hello"
class TestPythonProtocolErrors:
"""Test error handling."""
def test_invalid_request_line(self):
"""Test invalid request line."""
parser = PythonProtocol()
with pytest.raises(ParseError):
parser.feed(b"INVALID\r\n")
def test_invalid_header(self):
"""Test invalid header (no colon)."""
parser = PythonProtocol()
with pytest.raises(ParseError):
parser.feed(b"GET / HTTP/1.1\r\nBadHeader\r\n\r\n")
def test_unsupported_http_version(self):
"""Test unsupported HTTP version."""
parser = PythonProtocol()
with pytest.raises(ParseError):
parser.feed(b"GET / HTTP/2.0\r\n\r\n")
def test_invalid_chunk_size(self):
"""Test invalid chunk size."""
parser = PythonProtocol()
with pytest.raises(ParseError):
parser.feed(
b"POST / HTTP/1.1\r\n"
b"Transfer-Encoding: chunked\r\n"
b"\r\n"
b"XYZ\r\n" # Invalid hex
)
class TestPythonProtocolReset:
"""Test parser reset for keepalive."""
def test_reset_clears_state(self):
"""Test that reset clears all state."""
parser = PythonProtocol()
parser.feed(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
assert parser.is_complete is True
assert parser.method == b"GET"
parser.reset()
assert parser.method is None
assert parser.path is None
assert parser.http_version is None
assert parser.headers == []
assert parser.content_length is None
assert parser.is_chunked is False
assert parser.is_complete is False
def test_multiple_requests_keepalive(self):
"""Test handling multiple requests on keepalive connection."""
parser = PythonProtocol()
# First request
parser.feed(b"GET /first HTTP/1.1\r\nHost: example.com\r\n\r\n")
assert parser.path == b"/first"
assert parser.is_complete is True
parser.reset()
# Second request
parser.feed(b"GET /second HTTP/1.1\r\nHost: example.com\r\n\r\n")
assert parser.path == b"/second"
assert parser.is_complete is True
class TestPythonProtocolCallbacks:
"""Test callback firing."""
def test_all_callbacks(self):
"""Test all callbacks fire in correct order."""
events = []
parser = PythonProtocol(
on_message_begin=lambda: events.append("begin"),
on_url=lambda url: events.append(("url", url)),
on_header=lambda n, v: events.append(("header", n, v)),
on_headers_complete=lambda: events.append("headers_complete"),
on_body=lambda chunk: events.append(("body", chunk)),
on_message_complete=lambda: events.append("complete"),
)
parser.feed(
b"POST / HTTP/1.1\r\n"
b"Host: example.com\r\n"
b"Content-Length: 5\r\n"
b"\r\n"
b"hello"
)
assert events[0] == "begin"
assert events[1] == ("url", b"/")
assert events[2] == ("header", b"host", b"example.com")
assert events[3] == ("header", b"content-length", b"5")
assert events[4] == "headers_complete"
assert events[5] == ("body", b"hello")
assert events[6] == "complete"
def test_skip_body_callback(self):
"""Test on_headers_complete returning True skips body."""
body_chunks = []
parser = PythonProtocol(
on_headers_complete=lambda: True, # Skip body
on_body=lambda chunk: body_chunks.append(chunk),
)
parser.feed(
b"POST / HTTP/1.1\r\n"
b"Host: example.com\r\n"
b"Content-Length: 5\r\n"
b"\r\n"
b"hello"
)
# Body should be skipped
assert parser.is_complete is True
assert len(body_chunks) == 0
class TestCallbackRequest:
"""Test CallbackRequest adapter."""
def test_from_parser_simple(self):
"""Test creating request from parser state."""
parser = PythonProtocol()
parser.feed(b"GET /path?query=value HTTP/1.1\r\nHost: example.com\r\n\r\n")
request = CallbackRequest.from_parser(parser)
assert request.method == "GET"
assert request.path == "/path"
assert request.query == "query=value"
assert request.uri == "/path?query=value"
assert request.version == (1, 1)
assert request.scheme == "http"
def test_from_parser_ssl(self):
"""Test SSL scheme detection."""
parser = PythonProtocol()
parser.feed(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
request = CallbackRequest.from_parser(parser, is_ssl=True)
assert request.scheme == "https"
def test_from_parser_headers(self):
"""Test header conversion."""
parser = PythonProtocol()
parser.feed(
b"GET / HTTP/1.1\r\n"
b"Host: example.com\r\n"
b"Content-Type: text/plain\r\n"
b"\r\n"
)
request = CallbackRequest.from_parser(parser)
# String headers (uppercase)
assert ("HOST", "example.com") in request.headers
assert ("CONTENT-TYPE", "text/plain") in request.headers
# Bytes headers (lowercase)
assert (b"host", b"example.com") in request.headers_bytes
assert (b"content-type", b"text/plain") in request.headers_bytes
def test_from_parser_body_info(self):
"""Test body info extraction."""
parser = PythonProtocol()
parser.feed(
b"POST / HTTP/1.1\r\n"
b"Host: example.com\r\n"
b"Content-Length: 10\r\n"
b"\r\n"
)
request = CallbackRequest.from_parser(parser)
assert request.content_length == 10
assert request.chunked is False
def test_from_parser_chunked(self):
"""Test chunked transfer detection."""
parser = PythonProtocol()
parser.feed(
b"POST / HTTP/1.1\r\n"
b"Host: example.com\r\n"
b"Transfer-Encoding: chunked\r\n"
b"\r\n"
)
request = CallbackRequest.from_parser(parser)
assert request.chunked is True
def test_should_close(self):
"""Test should_close method."""
# HTTP/1.1 with Connection: close
parser = PythonProtocol()
parser.feed(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n")
request = CallbackRequest.from_parser(parser)
assert request.should_close() is True
# HTTP/1.1 keep-alive (default)
parser = PythonProtocol()
parser.feed(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
request = CallbackRequest.from_parser(parser)
assert request.should_close() is False
# HTTP/1.0 (default close)
parser = PythonProtocol()
parser.feed(b"GET / HTTP/1.0\r\n\r\n")
request = CallbackRequest.from_parser(parser)
assert request.should_close() is True
def test_get_header(self):
"""Test get_header method."""
parser = PythonProtocol()
parser.feed(
b"GET / HTTP/1.1\r\n"
b"Host: example.com\r\n"
b"X-Custom: value\r\n"
b"\r\n"
)
request = CallbackRequest.from_parser(parser)
assert request.get_header("Host") == "example.com"
assert request.get_header("x-custom") == "value"
assert request.get_header("X-CUSTOM") == "value"
assert request.get_header("X-Missing") is None
def test_expect_100_continue(self):
"""Test Expect: 100-continue detection."""
parser = PythonProtocol()
parser.feed(
b"POST / HTTP/1.1\r\n"
b"Host: example.com\r\n"
b"Expect: 100-continue\r\n"
b"Content-Length: 10\r\n"
b"\r\n"
)
request = CallbackRequest.from_parser(parser)
assert request._expect_100_continue is True
class TestPythonProtocolConnectionClose:
"""Test connection close handling."""
def test_connection_close_header(self):
"""Test Connection: close header."""
parser = PythonProtocol()
parser.feed(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n")
assert parser.should_keep_alive is False
def test_connection_keepalive_header(self):
"""Test Connection: keep-alive header."""
parser = PythonProtocol()
parser.feed(b"GET / HTTP/1.1\r\nConnection: keep-alive\r\n\r\n")
assert parser.should_keep_alive is True
def test_http11_default_keepalive(self):
"""Test HTTP/1.1 defaults to keep-alive."""
parser = PythonProtocol()
parser.feed(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
assert parser.should_keep_alive is True
def test_http10_default_close(self):
"""Test HTTP/1.0 defaults to close."""
parser = PythonProtocol()
parser.feed(b"GET / HTTP/1.0\r\nHost: example.com\r\n\r\n")
assert parser.should_keep_alive is False

View File

@ -165,6 +165,10 @@ class TestSignalHandlingIntegration:
proc.kill()
pytest.fail("Gunicorn did not exit within timeout after SIGTERM")
@pytest.mark.skipif(
hasattr(sys, 'pypy_version_info'),
reason="SIGINT handling differs on PyPy, use SIGTERM test instead"
)
def test_graceful_shutdown_sigint(self, gunicorn_server):
"""Verify SIGINT causes graceful shutdown."""
proc, port = gunicorn_server

View File

@ -13,13 +13,24 @@ dirname = os.path.dirname(__file__)
reqdir = os.path.join(dirname, "requests", "valid")
httpfiles = glob.glob(os.path.join(reqdir, "*.http"))
# Flags incompatible with fast parser
_FAST_INCOMPATIBLE_FLAGS = ('permit_obsolete_folding', 'strip_header_spaces')
@pytest.mark.parametrize("fname", httpfiles)
def test_http_parser(fname):
env = treq.load_py(os.path.splitext(fname)[0] + ".py")
def test_http_parser(fname, http_parser):
"""Test valid HTTP requests with both parser implementations."""
env = treq.load_py(os.path.splitext(fname)[0] + ".py", http_parser=http_parser)
expect = env['request']
cfg = env['cfg']
# Skip fast parser tests that use incompatible compatibility flags
if http_parser == 'fast':
for flag in _FAST_INCOMPATIBLE_FLAGS:
if getattr(cfg, flag, False):
pytest.skip(f"fast parser incompatible with {flag}")
req = treq.request(fname, expect)
for case in req.gen_cases(cfg):

View File

@ -33,19 +33,26 @@ def uri(data):
return ret
def load_py(fname):
def load_py(fname, http_parser='python'):
"""Load test configuration from Python file.
Args:
fname: Path to the .py configuration file
http_parser: Parser to use - 'python' or 'fast'
"""
module_name = '__config__'
mod = types.ModuleType(module_name)
setattr(mod, 'uri', uri)
setattr(mod, 'cfg', Config())
loader = importlib.machinery.SourceFileLoader(module_name, fname)
loader.exec_module(mod)
# Set parser after loading so test-specific configs don't override
mod.cfg.set('http_parser', http_parser)
return vars(mod)
def decode_hex_escapes(data):
"""Decode hex escape sequences like \\xAB in test data."""
import re
result = bytearray()
i = 0
while i < len(data):

View File

@ -1,416 +0,0 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
import pytest
import sys
from unittest import mock
def test_import():
"""Test that the eventlet worker module can be imported."""
try:
import eventlet
except AttributeError:
if (3, 13) > sys.version_info >= (3, 12):
pytest.skip("Ignoring eventlet failures on Python 3.12")
raise
__import__('gunicorn.workers.geventlet')
class TestVersionRequirement:
"""Tests for eventlet version requirement checks."""
def test_import_error_message(self):
"""Test that ImportError gives correct version message."""
with mock.patch.dict('sys.modules', {'eventlet': None}):
# Clear cached module if present
sys.modules.pop('gunicorn.workers.geventlet', None)
with pytest.raises(RuntimeError, match="eventlet 0.40.3"):
import importlib
import gunicorn.workers.geventlet
importlib.reload(gunicorn.workers.geventlet)
def test_version_check_requires_0_40_3(self):
"""Test that version check requires eventlet 0.40.3 or higher."""
try:
import eventlet
except (ImportError, AttributeError):
pytest.skip("eventlet not available")
from packaging.version import parse as parse_version
min_version = parse_version('0.40.3')
current_version = parse_version(eventlet.__version__)
# If we got this far, the import succeeded, meaning version is sufficient
assert current_version >= min_version
@pytest.fixture
def eventlet_worker():
"""Fixture to create an EventletWorker instance for testing."""
try:
import eventlet
except (ImportError, AttributeError):
pytest.skip("eventlet not available")
from gunicorn.workers.geventlet import EventletWorker
# Create a minimal mock config
cfg = mock.MagicMock()
cfg.keepalive = 2
cfg.graceful_timeout = 30
cfg.is_ssl = False
cfg.worker_connections = 1000
# Create worker with mocked dependencies
worker = EventletWorker.__new__(EventletWorker)
worker.cfg = cfg
worker.alive = True
worker.sockets = []
worker.log = mock.MagicMock()
return worker
class TestEventletWorker:
"""Tests for EventletWorker class."""
def test_worker_class_exists(self):
"""Test that EventletWorker class is properly defined."""
try:
import eventlet
except (ImportError, AttributeError):
pytest.skip("eventlet not available")
from gunicorn.workers.geventlet import EventletWorker
from gunicorn.workers.base_async import AsyncWorker
assert issubclass(EventletWorker, AsyncWorker)
def test_patch_method_calls_use_hub(self, eventlet_worker):
"""Test that patch() calls hubs.use_hub().
hubs.use_hub() must be called in patch() (after fork) because it creates
OS resources like kqueue that don't survive fork.
"""
from eventlet import hubs
with mock.patch.object(hubs, 'use_hub') as mock_use_hub:
with mock.patch('gunicorn.workers.geventlet.patch_sendfile'):
eventlet_worker.patch()
mock_use_hub.assert_called_once()
def test_patch_method_calls_patch_sendfile(self, eventlet_worker):
"""Test that patch() calls patch_sendfile()."""
from eventlet import hubs
with mock.patch.object(hubs, 'use_hub'):
with mock.patch('gunicorn.workers.geventlet.patch_sendfile') as mock_sf:
eventlet_worker.patch()
mock_sf.assert_called_once()
def test_monkey_patch_called_at_import_time(self):
"""Test that monkey_patch is called at module import time.
Note: hubs.use_hub() and eventlet.monkey_patch() are called at module
import time (not in patch()) to ensure all imports are properly patched.
This test verifies the module was patched by checking eventlet state.
"""
try:
import eventlet
except (ImportError, AttributeError):
pytest.skip("eventlet not available")
# Verify eventlet has been patched by checking that socket is patched
import socket
from eventlet.greenio import GreenSocket
# After monkey patching, socket.socket should be GreenSocket
assert socket.socket is GreenSocket
def test_timeout_ctx_returns_eventlet_timeout(self, eventlet_worker):
"""Test that timeout_ctx() returns an eventlet.Timeout."""
import eventlet
timeout = eventlet_worker.timeout_ctx()
assert isinstance(timeout, eventlet.Timeout)
def test_timeout_ctx_uses_keepalive_config(self, eventlet_worker):
"""Test that timeout_ctx() uses cfg.keepalive value."""
import eventlet
eventlet_worker.cfg.keepalive = 5
with mock.patch.object(eventlet, 'Timeout') as mock_timeout:
eventlet_worker.timeout_ctx()
mock_timeout.assert_called_once_with(5, False)
def test_timeout_ctx_with_no_keepalive(self, eventlet_worker):
"""Test that timeout_ctx() handles no keepalive (None or 0)."""
import eventlet
eventlet_worker.cfg.keepalive = 0
with mock.patch.object(eventlet, 'Timeout') as mock_timeout:
eventlet_worker.timeout_ctx()
mock_timeout.assert_called_once_with(None, False)
def test_handle_quit_spawns_greenthread(self, eventlet_worker):
"""Test that handle_quit() spawns a greenthread."""
import eventlet
with mock.patch.object(eventlet, 'spawn') as mock_spawn:
eventlet_worker.handle_quit(None, None)
mock_spawn.assert_called_once()
def test_handle_usr1_spawns_greenthread(self, eventlet_worker):
"""Test that handle_usr1() spawns a greenthread."""
import eventlet
with mock.patch.object(eventlet, 'spawn') as mock_spawn:
eventlet_worker.handle_usr1(None, None)
mock_spawn.assert_called_once()
def test_handle_wraps_ssl_when_configured(self, eventlet_worker):
"""Test that handle() wraps socket with SSL when is_ssl is True."""
from gunicorn.workers import geventlet
eventlet_worker.cfg.is_ssl = True
mock_client = mock.MagicMock()
mock_listener = mock.MagicMock()
with mock.patch.object(geventlet, 'ssl_wrap_socket') as mock_ssl:
mock_ssl.return_value = mock_client
with mock.patch('gunicorn.workers.base_async.AsyncWorker.handle'):
eventlet_worker.handle(mock_listener, mock_client, ('127.0.0.1', 8000))
mock_ssl.assert_called_once_with(mock_client, eventlet_worker.cfg)
def test_handle_no_ssl_when_not_configured(self, eventlet_worker):
"""Test that handle() does not wrap SSL when is_ssl is False."""
from gunicorn.workers import geventlet
eventlet_worker.cfg.is_ssl = False
mock_client = mock.MagicMock()
mock_listener = mock.MagicMock()
with mock.patch.object(geventlet, 'ssl_wrap_socket') as mock_ssl:
with mock.patch('gunicorn.workers.base_async.AsyncWorker.handle'):
eventlet_worker.handle(mock_listener, mock_client, ('127.0.0.1', 8000))
mock_ssl.assert_not_called()
class TestAlreadyHandled:
"""Tests for is_already_handled() method."""
def test_is_already_handled_new_style(self, eventlet_worker):
"""Test is_already_handled with eventlet >= 0.30.3 (WSGI_LOCAL)."""
from gunicorn.workers import geventlet
# Mock the new-style WSGI_LOCAL.already_handled
mock_wsgi_local = mock.MagicMock()
mock_wsgi_local.already_handled = True
with mock.patch.object(geventlet, 'EVENTLET_WSGI_LOCAL', mock_wsgi_local):
with pytest.raises(StopIteration):
eventlet_worker.is_already_handled(mock.MagicMock())
def test_is_already_handled_old_style(self, eventlet_worker):
"""Test is_already_handled with eventlet < 0.30.3 (ALREADY_HANDLED)."""
from gunicorn.workers import geventlet
sentinel = object()
with mock.patch.object(geventlet, 'EVENTLET_WSGI_LOCAL', None):
with mock.patch.object(geventlet, 'EVENTLET_ALREADY_HANDLED', sentinel):
with pytest.raises(StopIteration):
eventlet_worker.is_already_handled(sentinel)
def test_is_already_handled_returns_parent_result(self, eventlet_worker):
"""Test is_already_handled falls through to parent when not handled."""
from gunicorn.workers import geventlet
with mock.patch.object(geventlet, 'EVENTLET_WSGI_LOCAL', None):
with mock.patch.object(geventlet, 'EVENTLET_ALREADY_HANDLED', None):
with mock.patch('gunicorn.workers.base_async.AsyncWorker.is_already_handled') as mock_parent:
mock_parent.return_value = False
result = eventlet_worker.is_already_handled(mock.MagicMock())
assert result is False
mock_parent.assert_called_once()
class TestPatchSendfile:
"""Tests for patch_sendfile() function."""
def test_patch_sendfile_adds_method_when_missing(self):
"""Test that patch_sendfile adds sendfile to GreenSocket if missing."""
try:
import eventlet
except (ImportError, AttributeError):
pytest.skip("eventlet not available")
from gunicorn.workers.geventlet import patch_sendfile, _eventlet_socket_sendfile
from eventlet.greenio import GreenSocket
# Remove sendfile if it exists
original = getattr(GreenSocket, 'sendfile', None)
if hasattr(GreenSocket, 'sendfile'):
delattr(GreenSocket, 'sendfile')
try:
patch_sendfile()
assert hasattr(GreenSocket, 'sendfile')
assert GreenSocket.sendfile == _eventlet_socket_sendfile
finally:
# Restore original state
if original is not None:
GreenSocket.sendfile = original
elif hasattr(GreenSocket, 'sendfile'):
delattr(GreenSocket, 'sendfile')
def test_patch_sendfile_preserves_existing_method(self):
"""Test that patch_sendfile does not override existing sendfile."""
try:
import eventlet
except (ImportError, AttributeError):
pytest.skip("eventlet not available")
from gunicorn.workers.geventlet import patch_sendfile
from eventlet.greenio import GreenSocket
# If sendfile exists, it should be preserved
if hasattr(GreenSocket, 'sendfile'):
original = GreenSocket.sendfile
patch_sendfile()
assert GreenSocket.sendfile == original
class TestEventletSocketSendfile:
"""Tests for _eventlet_socket_sendfile() function."""
def test_sendfile_raises_on_non_blocking(self):
"""Test that sendfile raises ValueError for non-blocking sockets."""
try:
import eventlet
except (ImportError, AttributeError):
pytest.skip("eventlet not available")
from gunicorn.workers.geventlet import _eventlet_socket_sendfile
mock_socket = mock.MagicMock()
mock_socket.gettimeout.return_value = 0
with pytest.raises(ValueError, match="non-blocking"):
_eventlet_socket_sendfile(mock_socket, mock.MagicMock())
def test_sendfile_seeks_to_offset(self):
"""Test that sendfile seeks to offset if provided."""
try:
import eventlet
except (ImportError, AttributeError):
pytest.skip("eventlet not available")
from gunicorn.workers.geventlet import _eventlet_socket_sendfile
mock_socket = mock.MagicMock()
mock_socket.gettimeout.return_value = 1
mock_file = mock.MagicMock()
mock_file.read.return_value = b''
_eventlet_socket_sendfile(mock_socket, mock_file, offset=100)
mock_file.seek.assert_any_call(100)
def test_sendfile_returns_total_sent(self):
"""Test that sendfile returns the total bytes sent."""
try:
import eventlet
except (ImportError, AttributeError):
pytest.skip("eventlet not available")
from gunicorn.workers.geventlet import _eventlet_socket_sendfile
mock_socket = mock.MagicMock()
mock_socket.gettimeout.return_value = 1
mock_socket.send.return_value = 10
mock_file = mock.MagicMock()
mock_file.read.side_effect = [b'x' * 10, b'']
result = _eventlet_socket_sendfile(mock_socket, mock_file)
assert result == 10
class TestEventletServe:
"""Tests for _eventlet_serve() function."""
def test_serve_creates_green_pool(self):
"""Test that _eventlet_serve creates a GreenPool."""
try:
import eventlet
except (ImportError, AttributeError):
pytest.skip("eventlet not available")
from gunicorn.workers.geventlet import _eventlet_serve
mock_sock = mock.MagicMock()
mock_sock.accept.side_effect = eventlet.StopServe()
with mock.patch.object(eventlet.greenpool, 'GreenPool') as mock_pool:
mock_pool_instance = mock.MagicMock()
mock_pool.return_value = mock_pool_instance
mock_pool_instance.waitall.return_value = None
_eventlet_serve(mock_sock, mock.MagicMock(), 100)
mock_pool.assert_called_once_with(100)
class TestEventletStop:
"""Tests for _eventlet_stop() function."""
def test_stop_waits_for_client(self):
"""Test that _eventlet_stop waits for the client greenlet."""
try:
import eventlet
except (ImportError, AttributeError):
pytest.skip("eventlet not available")
from gunicorn.workers.geventlet import _eventlet_stop
mock_client = mock.MagicMock()
mock_server = mock.MagicMock()
mock_conn = mock.MagicMock()
_eventlet_stop(mock_client, mock_server, mock_conn)
mock_client.wait.assert_called_once()
mock_conn.close.assert_called_once()
def test_stop_closes_connection_on_greenlet_exit(self):
"""Test that connection is closed even on GreenletExit."""
try:
import eventlet
import greenlet
except (ImportError, AttributeError):
pytest.skip("eventlet not available")
from gunicorn.workers.geventlet import _eventlet_stop
mock_client = mock.MagicMock()
mock_client.wait.side_effect = greenlet.GreenletExit()
mock_server = mock.MagicMock()
mock_conn = mock.MagicMock()
# Should not raise
_eventlet_stop(mock_client, mock_server, mock_conn)
mock_conn.close.assert_called_once()