mirror of
https://github.com/frappe/gunicorn.git
synced 2026-07-05 12:11:29 +08:00
Merge pull request #3549 from benoitc/feature/optional-http-parser
Optimize ASGI performance with fast parser integration
This commit is contained in:
commit
3667a10478
269
benchmarks/http_parser_benchmark.py
Normal file
269
benchmarks/http_parser_benchmark.py
Normal file
@ -0,0 +1,269 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Benchmark comparing HTTP parser implementations.
|
||||
|
||||
Compares:
|
||||
- WSGI Python parser vs Fast parser (gunicorn_h1c)
|
||||
- ASGI Python parser vs Fast parser (gunicorn_h1c)
|
||||
|
||||
Usage:
|
||||
python benchmarks/http_parser_benchmark.py
|
||||
"""
|
||||
|
||||
import io
|
||||
import time
|
||||
import statistics
|
||||
from typing import NamedTuple
|
||||
|
||||
from gunicorn.config import Config
|
||||
from gunicorn.http.message import Request, _check_fast_parser
|
||||
from gunicorn.http.unreader import IterUnreader
|
||||
|
||||
|
||||
# Probe once, at import time, for the optional C parser.  FAST_AVAILABLE
# gates every fast-parser benchmark below.
try:
    import gunicorn_h1c  # noqa: F401 - imported only to test availability
except ImportError:
    FAST_AVAILABLE = False
    print("WARNING: gunicorn_h1c not installed. Fast parser benchmarks will be skipped.")
    print("Install with: pip install gunicorn_h1c\n")
else:
    FAST_AVAILABLE = True
|
||||
|
||||
|
||||
class BenchmarkResult(NamedTuple):
    """Timing summary for one benchmarked parser/request combination."""

    # Display label of the benchmarked parser.
    name: str
    # Number of timed parse operations.
    iterations: int
    # Sum of all per-iteration timings, in seconds.
    total_time: float
    # Mean, fastest and slowest per-iteration timings, in microseconds.
    avg_time_us: float
    min_time_us: float
    max_time_us: float
    # Throughput derived from iterations and total_time.
    requests_per_sec: float
|
||||
|
||||
|
||||
# Benchmark fixtures: header-only requests of increasing size.  Each line
# ends with an explicit CRLF and the header block is closed by an empty
# line.  The POST fixtures declare a Content-Length but carry no body bytes.
SIMPLE_REQUEST = b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n"

MEDIUM_REQUEST = (
    b"POST /api/users HTTP/1.1\r\n"
    b"Host: api.example.com\r\n"
    b"Content-Type: application/json\r\n"
    b"Content-Length: 42\r\n"
    b"Accept: application/json\r\n"
    b"Authorization: Bearer token123\r\n"
    b"X-Request-ID: abc-123-def-456\r\n"
    b"\r\n"
)

COMPLEX_REQUEST = (
    b"POST /api/v2/resources/items HTTP/1.1\r\n"
    b"Host: api.example.com\r\n"
    b"Content-Type: application/json; charset=utf-8\r\n"
    b"Content-Length: 1024\r\n"
    b"Accept: application/json, text/plain, */*\r\n"
    b"Accept-Language: en-US,en;q=0.9,fr;q=0.8\r\n"
    b"Accept-Encoding: gzip, deflate, br\r\n"
    b"Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ\r\n"
    b"X-Request-ID: 550e8400-e29b-41d4-a716-446655440000\r\n"
    b"X-Correlation-ID: 7f3d8c2a-1b4e-4a6f-9c8d-2e5f6a7b8c9d\r\n"
    b"X-Forwarded-For: 203.0.113.195, 70.41.3.18, 150.172.238.178\r\n"
    b"X-Forwarded-Proto: https\r\n"
    b"X-Real-IP: 203.0.113.195\r\n"
    b"User-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36\r\n"
    b"Cache-Control: no-cache, no-store, must-revalidate\r\n"
    b"Pragma: no-cache\r\n"
    b"Cookie: session=abc123; preferences=dark_mode\r\n"
    b'If-None-Match: "etag-value-here"\r\n'
    b"If-Modified-Since: Wed, 21 Oct 2024 07:28:00 GMT\r\n"
    b"\r\n"
)
|
||||
|
||||
|
||||
def create_wsgi_config(use_fast: bool) -> Config:
    """Return a fresh ``Config`` with ``http_parser`` set for the benchmark.

    Args:
        use_fast: Select the ``'fast'`` parser when true, ``'python'``
            otherwise.
    """
    config = Config()
    parser_choice = 'python'
    if use_fast:
        parser_choice = 'fast'
    config.set('http_parser', parser_choice)
    return config
|
||||
|
||||
|
||||
def benchmark_wsgi_parser(request_data: bytes, cfg: Config, iterations: int) -> BenchmarkResult:
    """Time WSGI request parsing of ``request_data``.

    Each iteration builds a fresh ``IterUnreader`` (outside the timed
    region) and times only the construction of ``Request``.

    Args:
        request_data: Raw request bytes handed to the unreader.
        cfg: Config whose ``http_parser`` value labels the result.
        iterations: Number of timed parses; must be at least 1.

    Returns:
        A ``BenchmarkResult`` named ``"WSGI <http_parser>"``.

    Raises:
        ValueError: If ``iterations`` is less than 1.
        RuntimeError: If a parsed request has no method.
    """
    # Fail early: with no samples the statistics below are undefined.
    if iterations < 1:
        raise ValueError("iterations must be at least 1")

    parser_type = cfg.http_parser
    times = []

    for _ in range(iterations):
        # Create fresh unreader for each iteration
        unreader = IterUnreader(iter([request_data]))

        start = time.perf_counter()
        req = Request(cfg, unreader, ('127.0.0.1', 8000), req_number=1)
        end = time.perf_counter()

        times.append(end - start)

        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if req.method is None:
            raise RuntimeError("WSGI parser produced a request without a method")

    total_time = sum(times)

    return BenchmarkResult(
        name=f"WSGI {parser_type}",
        iterations=iterations,
        total_time=total_time,
        avg_time_us=statistics.mean(times) * 1_000_000,
        min_time_us=min(times) * 1_000_000,
        max_time_us=max(times) * 1_000_000,
        # A coarse clock can report zero elapsed time for tiny runs.
        requests_per_sec=iterations / total_time if total_time > 0 else float('inf'),
    )
|
||||
|
||||
|
||||
def benchmark_asgi_parser(request_data: bytes, cfg: Config, iterations: int) -> BenchmarkResult:
    """Time ASGI request parsing of ``request_data``.

    Each iteration builds a fresh ``HttpParser`` (outside the timed region)
    and times a single ``feed()`` call with a ``bytearray`` copy of the data.

    Args:
        request_data: Raw request bytes fed to the parser.
        cfg: Config passed to the parser; its ``http_parser`` value labels
            the result.
        iterations: Number of timed parses; must be at least 1.

    Returns:
        A ``BenchmarkResult`` named ``"ASGI <http_parser>"``.

    Raises:
        ValueError: If ``iterations`` is less than 1.
        RuntimeError: If ``feed()`` returns no request, or one without a
            method.
    """
    # Fail early: with no samples the statistics below are undefined.
    if iterations < 1:
        raise ValueError("iterations must be at least 1")

    # NOTE(review): this expects ``feed()`` to return a request object.  The
    # callback-based PythonProtocol in gunicorn/asgi/parser.py returns nothing
    # from feed(); confirm that module still provides ``HttpParser``.
    from gunicorn.asgi.parser import HttpParser

    parser_type = cfg.http_parser
    times = []

    for _ in range(iterations):
        # Create fresh parser for each iteration
        parser = HttpParser(cfg, ('127.0.0.1', 8000), is_ssl=False)

        start = time.perf_counter()
        result = parser.feed(bytearray(request_data))
        end = time.perf_counter()

        times.append(end - start)

        # Explicit checks instead of ``assert`` so they survive ``python -O``.
        if result is None:
            raise RuntimeError("ASGI parser returned no request")
        if result.method is None:
            raise RuntimeError("ASGI parser produced a request without a method")

    total_time = sum(times)

    return BenchmarkResult(
        name=f"ASGI {parser_type}",
        iterations=iterations,
        total_time=total_time,
        avg_time_us=statistics.mean(times) * 1_000_000,
        min_time_us=min(times) * 1_000_000,
        max_time_us=max(times) * 1_000_000,
        # A coarse clock can report zero elapsed time for tiny runs.
        requests_per_sec=iterations / total_time if total_time > 0 else float('inf'),
    )
|
||||
|
||||
|
||||
def print_result(result: BenchmarkResult, baseline: BenchmarkResult = None):
    """Print one result line, optionally with a speed ratio against a baseline.

    Args:
        result: The measurement to print.
        baseline: Optional measurement to compare against.  When given, and
            both averages are positive, the line ends with
            ``(N.NNx faster)`` or ``(N.NNx slower)``.
    """
    speedup = ""
    # The ratio divides by result.avg_time_us, so both averages must be
    # positive; a zero reading (coarse clock) simply omits the comparison.
    if baseline is not None and baseline.avg_time_us > 0 and result.avg_time_us > 0:
        ratio = baseline.avg_time_us / result.avg_time_us
        if ratio > 1:
            speedup = f" ({ratio:.2f}x faster)"
        elif ratio < 1:
            speedup = f" ({1/ratio:.2f}x slower)"

    print(f" {result.name:20} {result.avg_time_us:8.2f} us/req "
          f"({result.requests_per_sec:,.0f} req/s){speedup}")
|
||||
|
||||
|
||||
def run_benchmark_suite(name: str, request_data: bytes, iterations: int):
    """Benchmark one request fixture with every available parser and print it.

    Runs the WSGI benchmarks first, then the ASGI ones; within each, the
    Python parser runs first and the fast parser only when FAST_AVAILABLE.

    Args:
        name: Human-readable fixture name for the banner.
        request_data: Raw request bytes to parse.
        iterations: Number of timed parses per benchmark.

    Returns:
        The ``BenchmarkResult`` objects in the order they were measured.
    """
    rule = '=' * 60
    print(f"\n{rule}")
    print(f"Benchmark: {name}")
    print(f"Request size: {len(request_data)} bytes, Iterations: {iterations:,}")
    print(rule)

    results = []
    # One (python_result, fast_result_or_None) pair per worker flavour.
    pairs = []
    for bench in (benchmark_wsgi_parser, benchmark_asgi_parser):
        python_result = bench(request_data, create_wsgi_config(use_fast=False), iterations)
        results.append(python_result)

        fast_result = None
        if FAST_AVAILABLE:
            fast_result = bench(request_data, create_wsgi_config(use_fast=True), iterations)
            results.append(fast_result)

        pairs.append((python_result, fast_result))

    print("\nResults (avg time per request):")
    print("-" * 60)

    for position, (python_result, fast_result) in enumerate(pairs):
        if position:
            # Blank line between the WSGI and ASGI groups.
            print()
        print_result(python_result)
        if fast_result is not None:
            print_result(fast_result, python_result)

    return results
|
||||
|
||||
|
||||
def main():
    """Warm up the WSGI parsers, run every suite, and print a summary."""
    print("HTTP Parser Benchmark")
    print("=" * 60)
    availability = 'Available' if FAST_AVAILABLE else 'Not installed'
    print(f"Fast parser (gunicorn_h1c): {availability}")

    # Warmup: 100 untimed WSGI parses per available parser.
    print("\nWarming up...")
    warmup_modes = [False]
    if FAST_AVAILABLE:
        warmup_modes.append(True)
    for use_fast in warmup_modes:
        cfg = create_wsgi_config(use_fast=use_fast)
        for _ in range(100):
            unreader = IterUnreader(iter([SIMPLE_REQUEST]))
            Request(cfg, unreader, ('127.0.0.1', 8000), req_number=1)

    # Run benchmarks
    iterations = 10000
    suites = (
        ("Simple GET Request", SIMPLE_REQUEST),
        ("Medium POST Request", MEDIUM_REQUEST),
        ("Complex POST Request", COMPLEX_REQUEST),
    )
    all_results = []
    for title, payload in suites:
        all_results.extend(run_benchmark_suite(title, payload, iterations))

    # Summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)

    if not FAST_AVAILABLE:
        print("\nInstall gunicorn_h1c to see fast parser comparison:")
        print(" pip install gunicorn_h1c")
    else:
        def mean_avg_us(label):
            """Average the per-suite mean time of results named ``label``."""
            return statistics.mean([r.avg_time_us for r in all_results if r.name == label])

        # Calculate overall speedups (all four means before any output).
        wsgi_python_avg = mean_avg_us("WSGI python")
        wsgi_fast_avg = mean_avg_us("WSGI fast")
        asgi_python_avg = mean_avg_us("ASGI python")
        asgi_fast_avg = mean_avg_us("ASGI fast")

        print(f"\nWSGI: Fast parser is {wsgi_python_avg/wsgi_fast_avg:.2f}x faster than Python parser")
        print(f"ASGI: Fast parser is {asgi_python_avg/asgi_fast_avg:.2f}x faster than Python parser")

    print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -4,12 +4,18 @@
|
||||
|
||||
# Simple WSGI app for benchmarking
|
||||
|
||||
import time
|
||||
|
||||
|
||||
def application(environ, start_response):
|
||||
"""Basic hello world response."""
|
||||
path = environ.get('PATH_INFO', '/')
|
||||
|
||||
if path == '/large':
|
||||
body = b'X' * 65536 # 64KB
|
||||
elif path == '/slow':
|
||||
time.sleep(0.01) # 10ms simulated I/O
|
||||
body = b'Slow response'
|
||||
else:
|
||||
body = b'Hello, World!'
|
||||
|
||||
|
||||
@ -24,9 +24,11 @@ gunicorn main:app --worker-class asgi --bind 0.0.0.0:8000
|
||||
The ASGI worker provides:
|
||||
|
||||
- **HTTP/1.1** with keepalive connections
|
||||
- **HTTP/2** with multiplexing and server push (requires SSL)
|
||||
- **WebSocket** support for real-time applications
|
||||
- **Lifespan protocol** for startup/shutdown hooks
|
||||
- **Optional uvloop** for improved performance
|
||||
- **Optional fast HTTP parser** via C extension for high throughput
|
||||
- **Optional uvloop** for improved event loop performance
|
||||
- **SSL/TLS** support
|
||||
- **uWSGI protocol** for nginx `uwsgi_pass` integration
|
||||
|
||||
@ -225,6 +227,56 @@ asgi_loop = "auto" # Use uvloop if available
|
||||
asgi_lifespan = "auto" # Auto-detect lifespan support
|
||||
```
|
||||
|
||||
## Performance
|
||||
|
||||
### Fast HTTP Parser
|
||||
|
||||
For maximum performance, install the optional `gunicorn_h1c` C extension:
|
||||
|
||||
```bash
|
||||
pip install gunicorn[fast]
|
||||
```
|
||||
|
||||
This provides a high-performance HTTP parser using picohttpparser with SIMD
|
||||
optimizations, offering significant speedups for HTTP parsing compared to the
|
||||
pure Python implementation.
|
||||
|
||||
The parser is automatically used when available (`--http-parser auto`), or you
|
||||
can explicitly require it:
|
||||
|
||||
```bash
|
||||
gunicorn myapp:app --worker-class asgi --http-parser fast
|
||||
```
|
||||
|
||||
| Parser | Description |
|
||||
|--------|-------------|
|
||||
| `auto` | Use fast parser if available, otherwise Python (default) |
|
||||
| `fast` | Require fast parser, fail if unavailable |
|
||||
| `python` | Force pure Python parser |
|
||||
|
||||
### Performance Tips
|
||||
|
||||
1. **Use uvloop** for improved event loop performance:
|
||||
```bash
|
||||
pip install uvloop
|
||||
gunicorn myapp:app --worker-class asgi --asgi-loop uvloop
|
||||
```
|
||||
|
||||
2. **Install the fast parser** for optimized HTTP parsing:
|
||||
```bash
|
||||
pip install gunicorn[fast]
|
||||
```
|
||||
|
||||
3. **Tune worker count** based on CPU cores:
|
||||
```bash
|
||||
gunicorn myapp:app --worker-class asgi --workers $(nproc)
|
||||
```
|
||||
|
||||
4. **Increase connections** for I/O-bound applications:
|
||||
```bash
|
||||
gunicorn myapp:app --worker-class asgi --worker-connections 2000
|
||||
```
|
||||
|
||||
## Comparison with Other ASGI Servers
|
||||
|
||||
| Feature | Gunicorn ASGI | Uvicorn | Hypercorn |
|
||||
|
||||
@ -3,6 +3,14 @@
|
||||
|
||||
## unreleased
|
||||
|
||||
### New Features
|
||||
|
||||
- **Fast HTTP Parser (gunicorn_h1c 0.4.1)**: Integrate new exception types and limit
|
||||
parameters from gunicorn_h1c 0.4.1 for both WSGI and ASGI workers
|
||||
- Requires gunicorn_h1c >= 0.4.1 for `http_parser='fast'`
|
||||
- Falls back to Python parser in `auto` mode if version not met
|
||||
- Proper HTTP status codes for limit errors (414, 431)
|
||||
|
||||
### Performance
|
||||
|
||||
- **ASGI HTTP Parser Optimizations**: Improve ASGI worker HTTP parsing performance
|
||||
|
||||
@ -1971,3 +1971,23 @@ need to increase this value.
|
||||
This setting only affects the ``asgi`` worker type.
|
||||
|
||||
!!! info "Added in 25.0.0"
|
||||
|
||||
### `http_parser`
|
||||
|
||||
**Command line:** `--http-parser STRING`
|
||||
|
||||
**Default:** `'auto'`
|
||||
|
||||
HTTP parser implementation for ASGI workers.
|
||||
|
||||
- auto: Use H1CProtocol if gunicorn_h1c is available, else PythonProtocol (default)
|
||||
- fast: Require H1CProtocol from gunicorn_h1c (fail if unavailable)
|
||||
- python: Force pure Python PythonProtocol parser
|
||||
|
||||
ASGI workers use callback-based parsing in data_received() for efficient
|
||||
incremental parsing. The gunicorn_h1c C extension provides significantly
|
||||
faster HTTP parsing using picohttpparser with SIMD optimizations.
|
||||
|
||||
Install it with: pip install gunicorn[fast]
|
||||
|
||||
!!! info "Added in 25.0.0"
|
||||
|
||||
@ -6,7 +6,9 @@ WORKDIR /app
|
||||
RUN pip install --no-cache-dir \
|
||||
sentence-transformers \
|
||||
fastapi \
|
||||
pydantic
|
||||
pydantic \
|
||||
pytest \
|
||||
requests
|
||||
|
||||
# Copy gunicorn source
|
||||
COPY . /app/gunicorn-src
|
||||
|
||||
543
gunicorn/asgi/parser.py
Normal file
543
gunicorn/asgi/parser.py
Normal file
@ -0,0 +1,543 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
HTTP parser for ASGI workers.
|
||||
|
||||
Provides callback-based parsing using either the fast C parser (gunicorn_h1c)
|
||||
or the pure Python PythonProtocol fallback.
|
||||
"""
|
||||
|
||||
|
||||
class ParseError(Exception):
    """Root of the errors raised while parsing an HTTP request."""


class LimitRequestLine(ParseError):
    """The request line is longer than the configured limit."""


class LimitRequestHeaders(ParseError):
    """A header field is too large, or there are too many header fields."""


class InvalidRequestMethod(ParseError):
    """The request method is not acceptable."""


class InvalidHTTPVersion(ParseError):
    """The HTTP version in the request line is not acceptable."""


class InvalidHeaderName(ParseError):
    """A header field name is not a valid token."""


class InvalidHeader(ParseError):
    """A header line or header value is malformed."""
|
||||
|
||||
|
||||
class PythonProtocol:
    """Callback-based HTTP/1.x request parser (pure Python fallback).

    Intended to mirror the H1CProtocol interface so the C extension and the
    pure Python implementation can be swapped.  Bytes are pushed in with
    feed(); parsing is incremental and callbacks fire synchronously from
    within feed().

    Callbacks (all optional):
        on_message_begin: () -> None - Called once the request line is parsed
        on_url: (url: bytes) -> None - Called with the raw request target
        on_header: (name: bytes, value: bytes) -> None - Called for each
            header, with the name lower-cased
        on_headers_complete: () -> bool - Called when headers are done
            (return a truthy value to skip the body)
        on_body: (chunk: bytes) -> None - Called with body data chunks
        on_message_complete: () -> None - Called when the request is complete

    Limits:
        limit_request_line, limit_request_fields and
        limit_request_field_size; a value of 0 or less disables the check.

    feed() raises ParseError (or a subclass) on malformed input.
    """

    __slots__ = (
        '_on_message_begin', '_on_url', '_on_header',
        '_on_headers_complete', '_on_body', '_on_message_complete',
        '_state', '_buffer', '_headers_list',
        'method', 'path', 'http_version', 'headers',
        'content_length', 'is_chunked', 'should_keep_alive', 'is_complete',
        '_body_remaining', '_skip_body',
        '_chunk_state', '_chunk_size', '_chunk_remaining',
        '_limit_request_line', '_limit_request_fields', '_limit_request_field_size',
        '_permit_unconventional_http_method', '_permit_unconventional_http_version',
        '_header_count',
    )

    # Byte values (ints) used by the validators below.
    _HEX_DIGITS = frozenset(b'0123456789abcdefABCDEF')
    # RFC 9110 delimiters: "(),/:;<=>?@[\]{}
    _TOKEN_DELIMITERS = frozenset(b'"(),/:;<=>?@[\\]{}')
    _UNCONVENTIONAL_METHOD_BYTES = frozenset(b'abcdefghijklmnopqrstuvwxyz#')

    def __init__(
        self,
        on_message_begin=None,
        on_url=None,
        on_header=None,
        on_headers_complete=None,
        on_body=None,
        on_message_complete=None,
        limit_request_line=8190,
        limit_request_fields=100,
        limit_request_field_size=8190,
        permit_unconventional_http_method=False,
        permit_unconventional_http_version=False,
    ):
        self._on_message_begin = on_message_begin
        self._on_url = on_url
        self._on_header = on_header
        self._on_headers_complete = on_headers_complete
        self._on_body = on_body
        self._on_message_complete = on_message_complete

        # Store limits
        self._limit_request_line = limit_request_line
        self._limit_request_fields = limit_request_fields
        self._limit_request_field_size = limit_request_field_size
        self._permit_unconventional_http_method = permit_unconventional_http_method
        self._permit_unconventional_http_version = permit_unconventional_http_version

        self._buffer = bytearray()
        self._init_request_state()

    def _init_request_state(self):
        """Set every per-request field to its initial value."""
        # Parser state: request_line, headers, body, chunked, complete
        self._state = 'request_line'
        self._headers_list = []
        self._header_count = 0

        # Request info (populated during parsing)
        self.method = None
        self.path = None
        self.http_version = None
        self.headers = []
        self.content_length = None
        self.is_chunked = False
        self.should_keep_alive = True
        self.is_complete = False

        # Body state
        self._body_remaining = 0
        self._skip_body = False

        # Chunked transfer state: size, data, crlf, trailer
        self._chunk_state = 'size'
        self._chunk_size = 0
        self._chunk_remaining = 0

    def feed(self, data):
        """Process data, fire callbacks synchronously.

        Parsing stops when more bytes are needed or when the request is
        complete; bytes received after completion stay in the buffer.

        Args:
            data: bytes or bytearray of incoming data

        Raises:
            ParseError: If the HTTP request is malformed
        """
        self._buffer.extend(data)

        while self._buffer:
            state = self._state
            if state == 'request_line':
                progressed = self._parse_request_line()
            elif state == 'headers':
                progressed = self._parse_headers()
            elif state == 'body':
                progressed = self._parse_body()
            elif state == 'chunked':
                progressed = self._parse_chunked_body()
            else:
                # 'complete': nothing more to parse for this request.
                break
            if not progressed:
                break

    def reset(self):
        """Reset for next request (keepalive).

        NOTE(review): this also clears the input buffer, so bytes of a
        pipelined request that arrived with the previous one are dropped;
        confirm callers expect that.
        """
        self._buffer.clear()
        self._init_request_state()

    def _complete_message(self):
        """Mark the request complete and fire on_message_complete."""
        self._state = 'complete'
        self.is_complete = True
        if self._on_message_complete:
            self._on_message_complete()

    def _parse_request_line(self):
        """Parse request line, return True if complete."""
        limit = self._limit_request_line
        idx = self._buffer.find(b'\r\n')
        if idx == -1:
            # No terminator yet.  Do not buffer without bound: once the
            # pending bytes (minus a possible lone trailing CR) exceed the
            # limit, the finished line can only be longer still.
            if limit > 0 and len(self._buffer) - 1 > limit:
                raise LimitRequestLine("Request line is too large")
            return False

        # Check request line length limit
        if limit > 0 and idx > limit:
            raise LimitRequestLine("Request line is too large")

        line = bytes(self._buffer[:idx])
        del self._buffer[:idx + 2]

        # Parse: METHOD PATH HTTP/x.y
        parts = line.split(b' ', 2)
        if len(parts) != 3:
            raise ParseError("Invalid request line")

        self.method, self.path, version = parts

        # Validate method
        if not self._permit_unconventional_http_method:
            if not self._is_valid_method(self.method):
                raise InvalidRequestMethod(self.method.decode('latin-1'))

        # Parse version
        if version == b'HTTP/1.1':
            self.http_version = (1, 1)
        elif version == b'HTTP/1.0':
            self.http_version = (1, 0)
        else:
            if not self._permit_unconventional_http_version:
                raise InvalidHTTPVersion(version.decode('latin-1'))
            # Other HTTP/1.x versions are accepted when permitted, but the
            # minor part must be plain ASCII digits; int() alone would also
            # take signs, whitespace and underscores.
            minor = version[7:]
            if not (version.startswith(b'HTTP/1.') and minor.isdigit()):
                raise InvalidHTTPVersion(version.decode('latin-1'))
            self.http_version = (1, int(minor))

        if self._on_message_begin:
            self._on_message_begin()
        if self._on_url:
            self._on_url(self.path)

        self._state = 'headers'
        return True

    def _parse_headers(self):
        """Parse headers, return True if headers are complete."""
        field_limit = self._limit_request_field_size
        while True:
            idx = self._buffer.find(b'\r\n')
            if idx == -1:
                # Same bound as for the request line: an unterminated
                # header must not grow past the field size limit.
                if field_limit > 0 and len(self._buffer) - 1 > field_limit:
                    raise LimitRequestHeaders("Request header field is too large")
                return False

            line = bytes(self._buffer[:idx])
            del self._buffer[:idx + 2]

            if not line:
                # Empty line = end of headers
                self._finalize_headers()
                return True

            # Check header field size limit
            if field_limit > 0 and len(line) > field_limit:
                raise LimitRequestHeaders("Request header field is too large")

            # Check header count limit
            self._header_count += 1
            if self._limit_request_fields > 0 and self._header_count > self._limit_request_fields:
                raise LimitRequestHeaders("Too many headers")

            # Parse header
            colon = line.find(b':')
            if colon == -1:
                raise InvalidHeader("Missing colon in header")

            name = line[:colon].strip()
            if not self._is_valid_token(name):
                raise InvalidHeaderName(name.decode('latin-1'))

            value = line[colon + 1:].strip()
            if self._has_invalid_header_chars(value):
                raise InvalidHeader("Invalid characters in header value")

            # Store lowercase name for internal use
            name_lower = name.lower()
            self._headers_list.append((name_lower, value))

            if self._on_header:
                self._on_header(name_lower, value)

    def _finalize_headers(self):
        """Called when all headers received.

        Derives content_length, is_chunked and should_keep_alive from the
        headers, fires on_headers_complete and selects the next state.

        Raises:
            InvalidHeader: If Content-Length is not a run of ASCII digits.
        """
        self.headers = self._headers_list

        # Extract content-length and chunked
        for name, value in self.headers:
            if name == b'content-length':
                # Digits only: a bare int() would leak ValueError on junk
                # and accept signed values such as b'-1' or b'+5'.
                if not value.isdigit():
                    raise InvalidHeader("Invalid Content-Length")
                self.content_length = int(value)
                self._body_remaining = self.content_length
            elif name == b'transfer-encoding':
                self.is_chunked = b'chunked' in value.lower()
            elif name == b'connection':
                val = value.lower()
                if b'close' in val:
                    self.should_keep_alive = False
                elif b'keep-alive' in val:
                    self.should_keep_alive = True

        # HTTP/1.0 defaults to close
        if self.http_version == (1, 0) and self.should_keep_alive:
            # Only keep-alive if explicitly requested
            has_keepalive = any(
                name == b'connection' and b'keep-alive' in value.lower()
                for name, value in self.headers
            )
            if not has_keepalive:
                self.should_keep_alive = False

        if self._on_headers_complete:
            # Callbacks commonly return None; store a real boolean.
            self._skip_body = bool(self._on_headers_complete())

        # Determine next state
        if self._skip_body:
            self._complete_message()
        elif self.is_chunked:
            self._state = 'chunked'
            self._chunk_state = 'size'
        elif self.content_length:
            self._state = 'body'
        else:
            # No body
            self._complete_message()

    def _parse_body(self):
        """Parse Content-Length delimited body."""
        if not self._buffer or self._body_remaining <= 0:
            return False

        chunk_size = min(len(self._buffer), self._body_remaining)
        chunk = bytes(self._buffer[:chunk_size])
        del self._buffer[:chunk_size]
        self._body_remaining -= chunk_size

        if self._on_body:
            self._on_body(chunk)

        if self._body_remaining <= 0:
            self._complete_message()

        return True

    def _parse_chunked_body(self):
        """Parse chunked transfer encoding.

        Returns True once the terminating chunk and trailers have been
        consumed, False when more data is needed.

        Raises:
            ParseError: On a malformed chunk size or a chunk that is not
                followed by CRLF.
        """
        while self._buffer:
            if self._chunk_state == 'size':
                # Looking for chunk size line
                idx = self._buffer.find(b'\r\n')
                if idx == -1:
                    return False

                size_line = bytes(self._buffer[:idx])
                del self._buffer[:idx + 2]

                # Handle chunk extensions (e.g., "5;ext=value")
                semicolon = size_line.find(b';')
                if semicolon != -1:
                    size_line = size_line[:semicolon].strip()

                # Hex digits only: int(x, 16) alone would accept a sign, a
                # "0x" prefix, underscores and surrounding whitespace, and
                # a negative size would corrupt the slicing below.
                hex_digits = self._HEX_DIGITS
                if not size_line or any(c not in hex_digits for c in size_line):
                    raise ParseError("Invalid chunk size")
                self._chunk_size = int(size_line, 16)

                if self._chunk_size == 0:
                    # Final chunk - skip trailers
                    self._chunk_state = 'trailer'
                else:
                    self._chunk_remaining = self._chunk_size
                    self._chunk_state = 'data'

            elif self._chunk_state == 'data':
                # Reading chunk data
                to_read = min(len(self._buffer), self._chunk_remaining)
                chunk = bytes(self._buffer[:to_read])
                del self._buffer[:to_read]
                self._chunk_remaining -= to_read

                if self._on_body:
                    self._on_body(chunk)

                if self._chunk_remaining == 0:
                    # Need to consume trailing CRLF
                    self._chunk_state = 'crlf'

            elif self._chunk_state == 'crlf':
                # Chunk data must be followed by CRLF
                if len(self._buffer) < 2:
                    return False
                if self._buffer[:2] != b'\r\n':
                    raise ParseError("Missing CRLF after chunk data")
                del self._buffer[:2]
                self._chunk_state = 'size'

            elif self._chunk_state == 'trailer':
                # Skip trailer headers
                idx = self._buffer.find(b'\r\n')
                if idx == -1:
                    return False

                line = bytes(self._buffer[:idx])
                del self._buffer[:idx + 2]

                if not line:
                    # Empty line = end of trailers
                    self._complete_message()
                    return True

        return False

    def _is_valid_method(self, method):
        """Check if method is valid token with conventional restrictions."""
        if not method:
            return False
        # Check length (3-20 chars)
        if not 3 <= len(method) <= 20:
            return False
        # Check for lowercase or # (unconventional)
        unconventional = self._UNCONVENTIONAL_METHOD_BYTES
        if any(c in unconventional for c in method):
            return False
        return self._is_valid_token(method)

    def _is_valid_token(self, data):
        """Check if data contains only RFC 9110 token characters."""
        if not data:
            return False
        delimiters = self._TOKEN_DELIMITERS
        for c in data:
            # Visible ASCII only, and none of the delimiters.
            if c < 0x21 or c > 0x7e or c in delimiters:
                return False
        return True

    def _has_invalid_header_chars(self, value):
        """Check for NUL, CR, LF in header value."""
        return b'\x00' in value or b'\r' in value or b'\n' in value
|
||||
|
||||
|
||||
class CallbackRequest:
|
||||
"""Request object built from callback parser state.
|
||||
|
||||
Works with both H1CProtocol (C extension) and PythonProtocol.
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
'method', 'uri', 'path', 'query', 'fragment', 'version',
|
||||
'headers', 'headers_bytes', 'scheme', 'raw_path',
|
||||
'content_length', 'chunked', 'must_close',
|
||||
'proxy_protocol_info', '_expect_100_continue',
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
self.method = None
|
||||
self.uri = None
|
||||
self.path = None
|
||||
self.query = None
|
||||
self.fragment = None
|
||||
self.version = None
|
||||
self.headers = []
|
||||
self.headers_bytes = []
|
||||
self.scheme = "http"
|
||||
self.raw_path = b''
|
||||
self.content_length = 0
|
||||
self.chunked = False
|
||||
self.must_close = False
|
||||
self.proxy_protocol_info = None
|
||||
self._expect_100_continue = False
|
||||
|
||||
@classmethod
|
||||
def from_parser(cls, parser, is_ssl=False):
|
||||
"""Build request from callback parser state.
|
||||
|
||||
Args:
|
||||
parser: H1CProtocol or PythonProtocol instance
|
||||
is_ssl: Whether connection is SSL/TLS
|
||||
|
||||
Returns:
|
||||
CallbackRequest instance
|
||||
"""
|
||||
from urllib.parse import unquote_to_bytes
|
||||
|
||||
req = cls()
|
||||
req.method = parser.method.decode('ascii')
|
||||
|
||||
# Parse path and query from URL
|
||||
# Per ASGI spec:
|
||||
# - path: percent-decoded UTF-8 string
|
||||
# - raw_path: original bytes as received
|
||||
raw_url = parser.path
|
||||
if b'?' in raw_url:
|
||||
path_part, query_part = raw_url.split(b'?', 1)
|
||||
req.raw_path = path_part # Store original bytes
|
||||
req.path = unquote_to_bytes(path_part).decode('utf-8', errors='replace')
|
||||
req.query = query_part.decode('latin-1')
|
||||
else:
|
||||
req.raw_path = raw_url # Store original bytes
|
||||
req.path = unquote_to_bytes(raw_url).decode('utf-8', errors='replace')
|
||||
req.query = ''
|
||||
|
||||
req.uri = raw_url.decode('latin-1')
|
||||
req.fragment = ''
|
||||
req.version = parser.http_version
|
||||
|
||||
# Headers - store both bytes (for ASGI scope) and strings (for compatibility)
|
||||
req.headers_bytes = list(parser.headers)
|
||||
req.headers = [
|
||||
(n.decode('latin-1').upper(), v.decode('latin-1'))
|
||||
for n, v in parser.headers
|
||||
]
|
||||
|
||||
req.scheme = 'https' if is_ssl else 'http'
|
||||
req.content_length = parser.content_length or 0
|
||||
req.chunked = parser.is_chunked
|
||||
req.must_close = not parser.should_keep_alive
|
||||
|
||||
# Check for Expect: 100-continue
|
||||
for name, value in parser.headers:
|
||||
if name == b'expect' and value.lower() == b'100-continue':
|
||||
req._expect_100_continue = True
|
||||
break
|
||||
|
||||
return req
|
||||
|
||||
def should_close(self):
|
||||
"""Check if connection should be closed after this request."""
|
||||
if self.must_close:
|
||||
return True
|
||||
for name, value in self.headers:
|
||||
if name == "CONNECTION":
|
||||
v = value.lower().strip(" \t")
|
||||
if v == "close":
|
||||
return True
|
||||
elif v == "keep-alive":
|
||||
return False
|
||||
break
|
||||
return self.version <= (1, 0)
|
||||
|
||||
def get_header(self, name):
    """Return the value of the first header matching *name*, or None.

    The lookup is case-insensitive on the caller's side: *name* is
    upper-cased before being compared with the stored header names.
    """
    wanted = name.upper()
    return next((val for key, val in self.headers if key == wanted), None)
|
||||
File diff suppressed because it is too large
Load Diff
@ -40,20 +40,22 @@ WS_GUID = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
||||
|
||||
|
||||
class WebSocketProtocol:
|
||||
"""WebSocket connection handler for ASGI applications."""
|
||||
"""WebSocket connection handler for ASGI applications.
|
||||
|
||||
def __init__(self, transport, reader, scope, app, log):
|
||||
Uses callback-based data feeding instead of StreamReader for efficiency.
|
||||
Data is fed via feed_data() from the parent protocol's data_received().
|
||||
"""
|
||||
|
||||
def __init__(self, transport, scope, app, log):
|
||||
"""Initialize WebSocket protocol handler.
|
||||
|
||||
Args:
|
||||
transport: asyncio transport for writing
|
||||
reader: asyncio StreamReader for reading
|
||||
scope: ASGI WebSocket scope dict
|
||||
app: ASGI application callable
|
||||
log: Logger instance
|
||||
"""
|
||||
self.transport = transport
|
||||
self.reader = reader
|
||||
self.scope = scope
|
||||
self.app = app
|
||||
self.log = log
|
||||
@ -70,6 +72,26 @@ class WebSocketProtocol:
|
||||
# Receive queue for incoming messages
|
||||
self._receive_queue = asyncio.Queue()
|
||||
|
||||
# Callback-based data reception (replaces StreamReader)
|
||||
self._buffer = bytearray()
|
||||
self._data_event = asyncio.Event()
|
||||
self._eof = False
|
||||
|
||||
def feed_data(self, data):
    """Append bytes received on the connection to the internal buffer.

    Called from the parent protocol's data_received(). Empty input is
    ignored; otherwise the data event is set after the buffer grows.

    Args:
        data: bytes received on the connection
    """
    if not data:
        return
    self._buffer.extend(data)
    self._data_event.set()
|
||||
|
||||
def feed_eof(self):
    """Record that the connection has been closed.

    Sets the EOF flag and then sets the data event.
    """
    self._eof = True
    self._data_event.set()
|
||||
|
||||
async def run(self):
|
||||
"""Run the WebSocket ASGI application."""
|
||||
# Send initial connect event
|
||||
@ -295,14 +317,25 @@ class WebSocketProtocol:
|
||||
return (opcode, payload)
|
||||
|
||||
async def _read_exact(self, n):
|
||||
"""Read exactly n bytes from the reader."""
|
||||
try:
|
||||
data = await self.reader.readexactly(n)
|
||||
return data
|
||||
except asyncio.IncompleteReadError:
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
"""Read exactly n bytes from internal buffer.
|
||||
|
||||
Waits for data via the callback-fed buffer instead of StreamReader.
|
||||
"""
|
||||
while len(self._buffer) < n:
|
||||
if self._eof:
|
||||
return None
|
||||
self._data_event.clear()
|
||||
# Critical: check buffer AGAIN after clearing to avoid race
|
||||
# condition where data arrives between clear() and wait()
|
||||
if len(self._buffer) >= n:
|
||||
break
|
||||
await self._data_event.wait()
|
||||
if self._eof and len(self._buffer) < n:
|
||||
return None
|
||||
|
||||
data = bytes(self._buffer[:n])
|
||||
del self._buffer[:n]
|
||||
return data
|
||||
|
||||
def _unmask(self, payload, masking_key):
|
||||
"""Unmask WebSocket payload data."""
|
||||
|
||||
@ -2772,6 +2772,22 @@ def validate_asgi_lifespan(val):
|
||||
return val
|
||||
|
||||
|
||||
def validate_http_parser(val):
    """Normalize and validate the ``http_parser`` setting.

    ``None`` maps to ``"auto"``. String input is lower-cased and
    stripped of surrounding whitespace before being checked.

    Accepts: auto, fast, python

    Raises:
        TypeError: if *val* is neither None nor a string.
        ValueError: if the normalized value is not an accepted choice.
    """
    if val is None:
        return "auto"
    if not isinstance(val, str):
        raise TypeError("http_parser must be a string")

    choices = ("auto", "fast", "python")
    normalized = val.lower().strip()
    if normalized in choices:
        return normalized
    raise ValueError("http_parser must be one of: %s" % ", ".join(choices))
|
||||
|
||||
|
||||
class ASGILoop(Setting):
|
||||
name = "asgi_loop"
|
||||
section = "Worker Processes"
|
||||
@ -2845,6 +2861,30 @@ class ASGIDisconnectGracePeriod(Setting):
|
||||
"""
|
||||
|
||||
|
||||
class HttpParser(Setting):
|
||||
name = "http_parser"
|
||||
section = "Worker Processes"
|
||||
cli = ["--http-parser"]
|
||||
meta = "STRING"
|
||||
validator = validate_http_parser
|
||||
default = "auto"
|
||||
desc = """\
|
||||
HTTP parser implementation for ASGI workers.
|
||||
|
||||
- auto: Use H1CProtocol if gunicorn_h1c is available, else PythonProtocol (default)
|
||||
- fast: Require H1CProtocol from gunicorn_h1c (fail if unavailable)
|
||||
- python: Force pure Python PythonProtocol parser
|
||||
|
||||
ASGI workers use callback-based parsing in data_received() for efficient
|
||||
incremental parsing. The gunicorn_h1c C extension provides significantly
|
||||
faster HTTP parsing using picohttpparser with SIMD optimizations.
|
||||
|
||||
Install it with: pip install gunicorn[fast]
|
||||
|
||||
.. versionadded:: 25.0.0
|
||||
"""
|
||||
|
||||
|
||||
class RootPath(Setting):
|
||||
name = "root_path"
|
||||
section = "Server Mechanics"
|
||||
|
||||
@ -341,14 +341,24 @@ class Logger:
|
||||
|
||||
return atoms
|
||||
|
||||
@property
def access_log_enabled(self):
    """Whether any access-log destination is configured.

    Used by protocol handlers to skip building log data when logging is disabled.
    True when an access log, any of the logconfig variants, or syslog
    (without ``disable_redirect_access_to_syslog``) is configured.
    """
    cfg = self.cfg
    if cfg.accesslog or cfg.logconfig:
        return True
    if cfg.logconfig_dict or cfg.logconfig_json:
        return True
    # Syslog only counts while access-log redirection to it is enabled.
    return bool(cfg.syslog and not cfg.disable_redirect_access_to_syslog)
|
||||
|
||||
def access(self, resp, req, environ, request_time):
|
||||
""" See http://httpd.apache.org/docs/2.0/logs.html#combined
|
||||
for format details
|
||||
"""
|
||||
|
||||
if not (self.cfg.accesslog or self.cfg.logconfig or
|
||||
self.cfg.logconfig_dict or self.cfg.logconfig_json or
|
||||
(self.cfg.syslog and not self.cfg.disable_redirect_access_to_syslog)):
|
||||
if not self.access_log_enabled:
|
||||
return
|
||||
|
||||
# wrap atoms:
|
||||
|
||||
@ -114,11 +114,13 @@ class ChunkMissingTerminator(IOError):
|
||||
|
||||
|
||||
class LimitRequestLine(ParseException):
|
||||
def __init__(self, size, max_size):
|
||||
def __init__(self, size, max_size=None):
|
||||
self.size = size
|
||||
self.max_size = max_size
|
||||
|
||||
def __str__(self):
|
||||
if self.max_size is None:
|
||||
return str(self.size)
|
||||
return "Request Line is too large (%s > %s)" % (self.size, self.max_size)
|
||||
|
||||
|
||||
|
||||
@ -21,6 +21,80 @@ from gunicorn.http.errors import InvalidSchemeHeaders
|
||||
from gunicorn.util import bytes_to_str, split_request_uri
|
||||
|
||||
|
||||
# Fast parser availability (cached at module level)
|
||||
_fast_parser_available = None
|
||||
_fast_parser_module = None
|
||||
|
||||
# Compatibility flags not supported by the fast parser
|
||||
_FAST_PARSER_INCOMPATIBLE_FLAGS = (
|
||||
'permit_obsolete_folding',
|
||||
'strip_header_spaces',
|
||||
)
|
||||
|
||||
|
||||
def _check_fast_parser(cfg):
|
||||
"""Check if fast C parser is available and should be used.
|
||||
|
||||
Returns False if:
|
||||
- http_parser='python' is explicitly set
|
||||
- gunicorn_h1c is not installed (in 'auto' mode)
|
||||
- gunicorn_h1c < 0.4.1 (in 'auto' mode)
|
||||
- Incompatible compatibility flags are enabled (in 'auto' mode)
|
||||
|
||||
Raises RuntimeError if:
|
||||
- http_parser='fast' but gunicorn_h1c is not installed
|
||||
- http_parser='fast' but gunicorn_h1c < 0.4.1
|
||||
- http_parser='fast' but incompatible flags are enabled
|
||||
"""
|
||||
global _fast_parser_available, _fast_parser_module # pylint: disable=global-statement
|
||||
|
||||
parser_setting = getattr(cfg, 'http_parser', 'auto')
|
||||
if parser_setting == 'python':
|
||||
return False
|
||||
|
||||
if _fast_parser_available is None:
|
||||
try:
|
||||
import gunicorn_h1c
|
||||
_fast_parser_available = True
|
||||
_fast_parser_module = gunicorn_h1c
|
||||
except ImportError:
|
||||
_fast_parser_available = False
|
||||
|
||||
if not _fast_parser_available and parser_setting == 'fast':
|
||||
raise RuntimeError("gunicorn_h1c not installed but http_parser='fast'")
|
||||
|
||||
if not _fast_parser_available:
|
||||
return False
|
||||
|
||||
# Require >= 0.4.1 for limit enforcement
|
||||
if not hasattr(_fast_parser_module, 'LimitRequestLine'):
|
||||
if parser_setting == 'fast':
|
||||
raise RuntimeError(
|
||||
"gunicorn_h1c >= 0.4.1 required for http_parser='fast'. "
|
||||
"Please upgrade: pip install --upgrade gunicorn_h1c"
|
||||
)
|
||||
# In 'auto' mode, fall back to Python parser
|
||||
return False
|
||||
|
||||
# Check for incompatible compatibility flags
|
||||
incompatible = []
|
||||
for flag in _FAST_PARSER_INCOMPATIBLE_FLAGS:
|
||||
if getattr(cfg, flag, False):
|
||||
incompatible.append(flag)
|
||||
|
||||
if incompatible:
|
||||
if parser_setting == 'fast':
|
||||
raise RuntimeError(
|
||||
"http_parser='fast' is incompatible with compatibility flags: %s. "
|
||||
"Use http_parser='python' or disable these flags."
|
||||
% ', '.join(incompatible)
|
||||
)
|
||||
# In 'auto' mode, fall back to Python parser
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
# PROXY protocol v2 constants
|
||||
PP_V2_SIGNATURE = b"\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A"
|
||||
|
||||
@ -99,7 +173,7 @@ class Message:
|
||||
or self.limit_request_fields > MAX_HEADERS):
|
||||
self.limit_request_fields = MAX_HEADERS
|
||||
self.limit_request_field_size = cfg.limit_request_field_size
|
||||
if self.limit_request_field_size < 0:
|
||||
if self.limit_request_field_size <= 0:
|
||||
self.limit_request_field_size = DEFAULT_MAX_HEADERFIELD_SIZE
|
||||
|
||||
# set max header buffer size
|
||||
@ -315,12 +389,16 @@ class Request(Message):
|
||||
|
||||
# get max request line size
|
||||
self.limit_request_line = cfg.limit_request_line
|
||||
if (self.limit_request_line < 0
|
||||
if (self.limit_request_line <= 0
|
||||
or self.limit_request_line >= MAX_REQUEST_LINE):
|
||||
self.limit_request_line = MAX_REQUEST_LINE
|
||||
|
||||
self.req_number = req_number
|
||||
self.proxy_protocol_info = None
|
||||
|
||||
# Check if fast parser should be used
|
||||
self._use_fast = _check_fast_parser(cfg)
|
||||
|
||||
super().__init__(cfg, unreader, peer_addr)
|
||||
|
||||
def get_data(self, unreader, buf, stop=False):
|
||||
@ -340,6 +418,102 @@ class Request(Message):
|
||||
if mode != "off" and self.req_number == 1:
|
||||
buf = self._handle_proxy_protocol(unreader, buf, mode)
|
||||
|
||||
# Use fast parser if available
|
||||
if self._use_fast:
|
||||
return self._parse_fast(unreader, buf)
|
||||
|
||||
return self._parse_python(unreader, buf)
|
||||
|
||||
def _parse_fast(self, unreader, buf):
    """Parse the request head with the fast C parser (gunicorn_h1c >= 0.4.1).

    Reads from *unreader* into *buf* until the C parser stops reporting
    an incomplete message, then fills in method, uri, path, query,
    fragment, version and headers on this request.

    Returns the bytes that follow the parsed head.

    Raises the gunicorn counterparts of the C parser's exceptions, and
    LimitRequestHeaders when the buffered head grows beyond
    ``max_buffer_headers + limit_request_line``.
    """
    h1c = _fast_parser_module
    raw = bytes(buf)
    prev_len = 0

    # Keep reading until the parser sees a complete head
    while True:
        try:
            # Pass all limit parameters to C parser
            parsed = h1c.parse_request(
                raw,
                last_len=prev_len,
                limit_request_line=self.limit_request_line,
                limit_request_fields=self.limit_request_fields,
                limit_request_field_size=self.limit_request_field_size,
                permit_unconventional_http_method=self.cfg.permit_unconventional_http_method,
                permit_unconventional_http_version=self.cfg.permit_unconventional_http_version,
            )
        except h1c.IncompleteError:
            # Remember how much was already scanned, then pull more data
            prev_len = len(raw)
            self.read_into(unreader, buf)
            raw = bytes(buf)
            if len(raw) > self.max_buffer_headers + self.limit_request_line:
                raise LimitRequestHeaders("max buffer headers")
        except h1c.LimitRequestLine as e:
            raise LimitRequestLine(str(e))
        except h1c.LimitRequestHeaders as e:
            raise LimitRequestHeaders(str(e))
        except h1c.InvalidRequestMethod as e:
            raise InvalidRequestMethod(str(e))
        except h1c.InvalidHTTPVersion as e:
            raise InvalidHTTPVersion(str(e))
        except h1c.InvalidHeaderName as e:
            raise InvalidHeaderName(str(e))
        except h1c.InvalidHeader as e:
            raise InvalidHeader(str(e))
        except h1c.ParseError as e:
            raise InvalidRequestLine(str(e))
        else:
            break

    # Request line
    self.method = bytes_to_str(parsed['method'])
    self.uri = bytes_to_str(parsed['path'])

    # Casefold method if configured (validation done by C parser)
    if self.cfg.casefold_http_method:
        self.method = self.method.upper()

    # Split the URI into its parts
    if not self.uri:
        raise InvalidRequestLine(self.uri)
    try:
        uri_parts = split_request_uri(self.uri)
    except ValueError:
        raise InvalidRequestLine(self.uri)
    self.path = uri_parts.path or ""
    self.query = uri_parts.query or ""
    self.fragment = uri_parts.fragment or ""

    # Version (validation done by C parser)
    self.version = (1, parsed['minor_version'])

    # Headers arrive as (bytes, bytes) pairs; store them as strings with
    # upper-cased names. Name/value validation is done by the C parser.
    self.headers = []
    for raw_name, raw_value in parsed['headers']:
        name = bytes_to_str(raw_name).upper()
        value = bytes_to_str(raw_value)

        # Underscore in a header name is a policy decision, not validation
        if "_" in name:
            trusted = self.cfg.forwarder_headers
            allowed = (
                name in trusted
                or "*" in trusted
                or self.cfg.header_map == "dangerous"
            )
            if not allowed:
                if self.cfg.header_map == "drop":
                    continue
                raise InvalidHeaderName(name)

        self.headers.append((name, value))

    # Whatever follows the head belongs to the body
    return raw[parsed['consumed']:]
|
||||
|
||||
def _parse_python(self, unreader, buf):
|
||||
"""Parse request using pure Python parser."""
|
||||
# Get request line
|
||||
line, buf = self.read_line(unreader, buf, self.limit_request_line)
|
||||
|
||||
|
||||
@ -71,6 +71,11 @@ class HTTP2Stream:
|
||||
self.priority_depends_on = 0
|
||||
self.priority_exclusive = False
|
||||
|
||||
# Streaming body support (avoids buffering entire uploads)
|
||||
self._body_chunks = []
|
||||
self._body_event = None # Lazy-init asyncio.Event
|
||||
self._body_complete = False
|
||||
|
||||
@property
|
||||
def is_client_stream(self):
|
||||
"""Check if this is a client-initiated stream (odd stream ID)."""
|
||||
@ -122,7 +127,7 @@ class HTTP2Stream:
|
||||
self.request_complete = True
|
||||
|
||||
def receive_data(self, data, end_stream=False):
|
||||
"""Process received DATA frame.
|
||||
"""Process received DATA frame with streaming support.
|
||||
|
||||
Args:
|
||||
data: Bytes received
|
||||
@ -137,11 +142,21 @@ class HTTP2Stream:
|
||||
f"Cannot receive data in state {self.state.name}"
|
||||
)
|
||||
|
||||
# Add to chunks queue for streaming reads
|
||||
if data:
|
||||
self._body_chunks.append(data)
|
||||
if self._body_event:
|
||||
self._body_event.set()
|
||||
|
||||
# Also write to legacy BytesIO for compatibility
|
||||
self.request_body.write(data)
|
||||
|
||||
if end_stream:
|
||||
self._half_close_remote()
|
||||
self.request_complete = True
|
||||
self._body_complete = True
|
||||
if self._body_event:
|
||||
self._body_event.set()
|
||||
|
||||
def receive_trailers(self, trailers):
|
||||
"""Process received trailing headers.
|
||||
@ -283,6 +298,35 @@ class HTTP2Stream:
|
||||
"""
|
||||
return self.request_body.getvalue()
|
||||
|
||||
async def read_body_chunk(self):
    """Return the next queued body chunk, waiting until one arrives.

    Returns:
        bytes: Next chunk of body data, or None if body is complete.
    """
    import asyncio

    # The event is created on first use (avoids event loop issues at
    # construction).
    if self._body_event is None:
        self._body_event = asyncio.Event()
        # Data or completion may already have been recorded before the
        # event existed; set it so that state is reflected.
        if self._body_chunks or self._body_complete:
            self._body_event.set()

    while not self._body_chunks:
        # Queue drained and nothing more expected
        if self._body_complete:
            return None
        # Wait for more data
        self._body_event.clear()
        await self._body_event.wait()

    return self._body_chunks.pop(0)
|
||||
|
||||
def get_pseudo_headers(self):
|
||||
"""Extract HTTP/2 pseudo-headers from request headers.
|
||||
|
||||
|
||||
@ -53,6 +53,7 @@ tornado = ["tornado>=6.5.0"]
|
||||
gthread = []
|
||||
setproctitle = ["setproctitle"]
|
||||
http2 = ["h2>=4.1.0"]
|
||||
fast = ["gunicorn_h1c>=0.4.1"]
|
||||
testing = [
|
||||
"gevent>=24.10.1",
|
||||
"eventlet>=0.40.3",
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
gevent
|
||||
eventlet
|
||||
coverage
|
||||
pytest>=7.2.0
|
||||
pytest-cov
|
||||
pytest-asyncio
|
||||
gunicorn_h1c>=0.4.1
|
||||
|
||||
@ -7,8 +7,21 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
# Add the tests directory to sys.path so test support modules can be imported
|
||||
# as 'tests.module_name' (e.g., 'tests.support_dirty_apps:CounterApp')
|
||||
tests_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
if tests_dir not in sys.path:
|
||||
sys.path.insert(0, tests_dir)
|
||||
|
||||
|
||||
@pytest.fixture(params=["python", "fast"])
|
||||
def http_parser(request):
|
||||
"""Parametrize tests over http_parser implementations."""
|
||||
if request.param == "fast":
|
||||
gunicorn_h1c = pytest.importorskip("gunicorn_h1c", reason="gunicorn_h1c required")
|
||||
# Require >= 0.4.1 for limit enforcement
|
||||
if not hasattr(gunicorn_h1c, 'LimitRequestLine'):
|
||||
pytest.skip("gunicorn_h1c >= 0.4.1 required")
|
||||
return request.param
|
||||
|
||||
1
tests/requests/invalid/limit_header_default_01.http
Normal file
1
tests/requests/invalid/limit_header_default_01.http
Normal file
File diff suppressed because one or more lines are too long
11
tests/requests/invalid/limit_header_default_01.py
Normal file
11
tests/requests/invalid/limit_header_default_01.py
Normal file
@ -0,0 +1,11 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
from gunicorn.config import Config
|
||||
from gunicorn.http.errors import LimitRequestHeaders
|
||||
|
||||
cfg = Config()
|
||||
# Setting limit_request_field_size=0 should use default max (8190)
|
||||
cfg.set('limit_request_field_size', 0)
|
||||
request = LimitRequestHeaders
|
||||
1
tests/requests/invalid/limit_line_default_01.http
Normal file
1
tests/requests/invalid/limit_line_default_01.http
Normal file
File diff suppressed because one or more lines are too long
11
tests/requests/invalid/limit_line_default_01.py
Normal file
11
tests/requests/invalid/limit_line_default_01.py
Normal file
@ -0,0 +1,11 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
from gunicorn.config import Config
|
||||
from gunicorn.http.errors import LimitRequestLine
|
||||
|
||||
cfg = Config()
|
||||
# Setting limit_request_line=0 should use default max (8190)
|
||||
cfg.set('limit_request_line', 0)
|
||||
request = LimitRequestLine
|
||||
@ -5,8 +5,9 @@
|
||||
from gunicorn.config import Config
|
||||
|
||||
cfg = Config()
|
||||
cfg.set('limit_request_line', 0)
|
||||
cfg.set('limit_request_field_size', 0)
|
||||
# Request line is 8194 bytes, header line is 8209 bytes (both include CRLF)
|
||||
cfg.set('limit_request_line', 8200)
|
||||
cfg.set('limit_request_field_size', 8210)
|
||||
request = {
|
||||
"method": "PUT",
|
||||
"uri":
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
from gunicorn.config import Config
|
||||
|
||||
cfg = Config()
|
||||
cfg.set('limit_request_line', 0)
|
||||
# Header line is 8209 bytes (name + ": " + value + CRLF)
|
||||
cfg.set('limit_request_field_size', 8210)
|
||||
request = {
|
||||
"method": "GET",
|
||||
|
||||
@ -7,10 +7,8 @@ Tests for ASGI worker components.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import ipaddress
|
||||
import pytest
|
||||
from unittest import mock
|
||||
|
||||
from gunicorn.asgi.unreader import AsyncUnreader
|
||||
from gunicorn.asgi.message import AsyncRequest
|
||||
|
||||
518
tests/test_asgi_callback_parser.py
Normal file
518
tests/test_asgi_callback_parser.py
Normal file
@ -0,0 +1,518 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""Tests for ASGI callback parsers.
|
||||
|
||||
Tests both PythonProtocol and H1CProtocol (if available) to ensure
|
||||
consistent behavior across implementations.
|
||||
"""
|
||||
|
||||
from gunicorn.asgi.parser import PythonProtocol
|
||||
|
||||
|
||||
def get_parser_class(http_parser):
    """Map the ``http_parser`` test parameter to a parser class.

    ``"fast"`` selects H1CProtocol, imported from gunicorn_h1c only when
    requested; any other value selects PythonProtocol.
    """
    if http_parser != "fast":
        return PythonProtocol

    from gunicorn_h1c import H1CProtocol
    return H1CProtocol
|
||||
|
||||
|
||||
def normalize_headers(headers):
    """Build a dict of headers keyed by lower-cased name for comparison.

    H1CProtocol preserves original case, PythonProtocol lowercases.
    When a name occurs more than once, the last value wins.
    """
    lowered = {}
    for key, val in headers:
        lowered[key.lower()] = val
    return lowered
|
||||
|
||||
|
||||
class TestRequestLineParsing:
|
||||
"""Test request line parsing for both implementations."""
|
||||
|
||||
def test_simple_get(self, http_parser):
|
||||
"""Parse a simple GET request."""
|
||||
parser_class = get_parser_class(http_parser)
|
||||
events = []
|
||||
|
||||
parser = parser_class(
|
||||
on_message_begin=lambda: events.append('begin'),
|
||||
on_url=lambda url: events.append(('url', url)),
|
||||
on_headers_complete=lambda: events.append('headers_complete'),
|
||||
on_message_complete=lambda: events.append('complete'),
|
||||
)
|
||||
|
||||
parser.feed(b"GET /path HTTP/1.1\r\n\r\n")
|
||||
|
||||
assert parser.method == b"GET"
|
||||
assert parser.path == b"/path"
|
||||
assert parser.http_version == (1, 1)
|
||||
assert parser.is_complete
|
||||
assert 'begin' in events
|
||||
assert ('url', b'/path') in events
|
||||
assert 'complete' in events
|
||||
|
||||
def test_post_with_query(self, http_parser):
|
||||
"""Parse a POST request with query string."""
|
||||
parser_class = get_parser_class(http_parser)
|
||||
|
||||
parser = parser_class()
|
||||
parser.feed(b"POST /api/data?foo=bar&baz=qux HTTP/1.1\r\n\r\n")
|
||||
|
||||
assert parser.method == b"POST"
|
||||
assert parser.path == b"/api/data?foo=bar&baz=qux"
|
||||
assert parser.http_version == (1, 1)
|
||||
assert parser.is_complete
|
||||
|
||||
def test_http_10_version(self, http_parser):
|
||||
"""Parse HTTP/1.0 request."""
|
||||
parser_class = get_parser_class(http_parser)
|
||||
|
||||
parser = parser_class()
|
||||
parser.feed(b"GET / HTTP/1.0\r\n\r\n")
|
||||
|
||||
assert parser.method == b"GET"
|
||||
assert parser.http_version == (1, 0)
|
||||
assert parser.is_complete
|
||||
|
||||
def test_various_methods(self, http_parser):
|
||||
"""Test parsing various HTTP methods."""
|
||||
parser_class = get_parser_class(http_parser)
|
||||
methods = [b"GET", b"POST", b"PUT", b"DELETE", b"PATCH", b"HEAD", b"OPTIONS"]
|
||||
|
||||
for method in methods:
|
||||
parser = parser_class()
|
||||
parser.feed(method + b" / HTTP/1.1\r\n\r\n")
|
||||
assert parser.method == method
|
||||
|
||||
|
||||
class TestHeaderParsing:
|
||||
"""Test header parsing for both implementations."""
|
||||
|
||||
def test_single_header(self, http_parser):
|
||||
"""Parse a request with single header."""
|
||||
parser_class = get_parser_class(http_parser)
|
||||
headers = []
|
||||
|
||||
parser = parser_class(
|
||||
on_header=lambda n, v: headers.append((n, v)),
|
||||
)
|
||||
parser.feed(b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n")
|
||||
|
||||
assert len(parser.headers) == 1
|
||||
header_dict = normalize_headers(parser.headers)
|
||||
assert header_dict[b"host"] == b"localhost"
|
||||
callback_dict = normalize_headers(headers)
|
||||
assert callback_dict[b"host"] == b"localhost"
|
||||
|
||||
def test_multiple_headers(self, http_parser):
|
||||
"""Parse a request with multiple headers."""
|
||||
parser_class = get_parser_class(http_parser)
|
||||
|
||||
parser = parser_class()
|
||||
parser.feed(
|
||||
b"GET / HTTP/1.1\r\n"
|
||||
b"Host: localhost\r\n"
|
||||
b"User-Agent: TestClient\r\n"
|
||||
b"Accept: */*\r\n"
|
||||
b"\r\n"
|
||||
)
|
||||
|
||||
assert len(parser.headers) == 3
|
||||
header_dict = normalize_headers(parser.headers)
|
||||
assert header_dict[b"host"] == b"localhost"
|
||||
assert header_dict[b"user-agent"] == b"TestClient"
|
||||
assert header_dict[b"accept"] == b"*/*"
|
||||
|
||||
def test_header_with_spaces(self, http_parser):
|
||||
"""Parse headers with leading/trailing spaces in values."""
|
||||
parser_class = get_parser_class(http_parser)
|
||||
|
||||
parser = parser_class()
|
||||
parser.feed(
|
||||
b"GET / HTTP/1.1\r\n"
|
||||
b"Host: localhost \r\n"
|
||||
b"\r\n"
|
||||
)
|
||||
|
||||
header_dict = normalize_headers(parser.headers)
|
||||
assert header_dict[b"host"] == b"localhost"
|
||||
|
||||
def test_empty_header_value(self, http_parser):
|
||||
"""Parse header with empty value."""
|
||||
parser_class = get_parser_class(http_parser)
|
||||
|
||||
parser = parser_class()
|
||||
parser.feed(
|
||||
b"GET / HTTP/1.1\r\n"
|
||||
b"Host: localhost\r\n"
|
||||
b"X-Empty:\r\n"
|
||||
b"\r\n"
|
||||
)
|
||||
|
||||
header_dict = normalize_headers(parser.headers)
|
||||
assert header_dict[b"x-empty"] == b""
|
||||
|
||||
def test_large_header_value(self, http_parser):
|
||||
"""Parse header with large value."""
|
||||
parser_class = get_parser_class(http_parser)
|
||||
|
||||
large_value = b"x" * 4096
|
||||
|
||||
parser = parser_class()
|
||||
parser.feed(
|
||||
b"GET / HTTP/1.1\r\n"
|
||||
b"Host: localhost\r\n"
|
||||
b"X-Large: " + large_value + b"\r\n"
|
||||
b"\r\n"
|
||||
)
|
||||
|
||||
header_dict = normalize_headers(parser.headers)
|
||||
assert header_dict[b"x-large"] == large_value
|
||||
|
||||
|
||||
class TestBodyHandling:
|
||||
"""Test body parsing for both implementations."""
|
||||
|
||||
def test_content_length_body(self, http_parser):
|
||||
"""Parse request with Content-Length body."""
|
||||
parser_class = get_parser_class(http_parser)
|
||||
body_chunks = []
|
||||
|
||||
parser = parser_class(
|
||||
on_body=lambda chunk: body_chunks.append(chunk),
|
||||
)
|
||||
parser.feed(
|
||||
b"POST /data HTTP/1.1\r\n"
|
||||
b"Host: localhost\r\n"
|
||||
b"Content-Length: 13\r\n"
|
||||
b"\r\n"
|
||||
b"Hello, World!"
|
||||
)
|
||||
|
||||
assert parser.content_length == 13
|
||||
assert not parser.is_chunked
|
||||
assert b"".join(body_chunks) == b"Hello, World!"
|
||||
assert parser.is_complete
|
||||
|
||||
def test_content_length_incremental(self, http_parser):
|
||||
"""Parse body arriving in multiple chunks."""
|
||||
parser_class = get_parser_class(http_parser)
|
||||
body_chunks = []
|
||||
|
||||
parser = parser_class(
|
||||
on_body=lambda chunk: body_chunks.append(chunk),
|
||||
)
|
||||
|
||||
# Send headers
|
||||
parser.feed(
|
||||
b"POST /data HTTP/1.1\r\n"
|
||||
b"Host: localhost\r\n"
|
||||
b"Content-Length: 10\r\n"
|
||||
b"\r\n"
|
||||
)
|
||||
assert not parser.is_complete
|
||||
|
||||
# Send body in parts
|
||||
parser.feed(b"Hello")
|
||||
assert not parser.is_complete
|
||||
parser.feed(b"World")
|
||||
assert parser.is_complete
|
||||
|
||||
assert b"".join(body_chunks) == b"HelloWorld"
|
||||
|
||||
def test_chunked_encoding(self, http_parser):
|
||||
"""Parse chunked transfer-encoded body."""
|
||||
parser_class = get_parser_class(http_parser)
|
||||
body_chunks = []
|
||||
|
||||
parser = parser_class(
|
||||
on_body=lambda chunk: body_chunks.append(chunk),
|
||||
)
|
||||
parser.feed(
|
||||
b"POST /data HTTP/1.1\r\n"
|
||||
b"Host: localhost\r\n"
|
||||
b"Transfer-Encoding: chunked\r\n"
|
||||
b"\r\n"
|
||||
b"5\r\n"
|
||||
b"Hello\r\n"
|
||||
b"6\r\n"
|
||||
b"World!\r\n"
|
||||
b"0\r\n"
|
||||
b"\r\n"
|
||||
)
|
||||
|
||||
assert parser.is_chunked
|
||||
assert b"".join(body_chunks) == b"HelloWorld!"
|
||||
assert parser.is_complete
|
||||
|
||||
def test_chunked_with_extensions(self, http_parser):
|
||||
"""Parse chunked body with chunk extensions."""
|
||||
parser_class = get_parser_class(http_parser)
|
||||
body_chunks = []
|
||||
|
||||
parser = parser_class(
|
||||
on_body=lambda chunk: body_chunks.append(chunk),
|
||||
)
|
||||
parser.feed(
|
||||
b"POST /data HTTP/1.1\r\n"
|
||||
b"Host: localhost\r\n"
|
||||
b"Transfer-Encoding: chunked\r\n"
|
||||
b"\r\n"
|
||||
b"5;ext=value\r\n"
|
||||
b"Hello\r\n"
|
||||
b"0\r\n"
|
||||
b"\r\n"
|
||||
)
|
||||
|
||||
assert b"".join(body_chunks) == b"Hello"
|
||||
assert parser.is_complete
|
||||
|
||||
def test_no_body_get(self, http_parser):
|
||||
"""GET request has no body."""
|
||||
parser_class = get_parser_class(http_parser)
|
||||
body_chunks = []
|
||||
|
||||
parser = parser_class(
|
||||
on_body=lambda chunk: body_chunks.append(chunk),
|
||||
)
|
||||
parser.feed(
|
||||
b"GET / HTTP/1.1\r\n"
|
||||
b"Host: localhost\r\n"
|
||||
b"\r\n"
|
||||
)
|
||||
|
||||
assert parser.content_length is None
|
||||
assert not parser.is_chunked
|
||||
assert body_chunks == []
|
||||
assert parser.is_complete
|
||||
|
||||
|
||||
class TestConnectionHandling:
|
||||
"""Test connection handling and keep-alive for both implementations."""
|
||||
|
||||
def test_http11_keepalive_default(self, http_parser):
|
||||
"""HTTP/1.1 defaults to keep-alive."""
|
||||
parser_class = get_parser_class(http_parser)
|
||||
|
||||
parser = parser_class()
|
||||
parser.feed(
|
||||
b"GET / HTTP/1.1\r\n"
|
||||
b"Host: localhost\r\n"
|
||||
b"\r\n"
|
||||
)
|
||||
|
||||
assert parser.should_keep_alive is True
|
||||
|
||||
def test_http11_connection_close(self, http_parser):
|
||||
"""HTTP/1.1 with Connection: close."""
|
||||
parser_class = get_parser_class(http_parser)
|
||||
|
||||
parser = parser_class()
|
||||
parser.feed(
|
||||
b"GET / HTTP/1.1\r\n"
|
||||
b"Host: localhost\r\n"
|
||||
b"Connection: close\r\n"
|
||||
b"\r\n"
|
||||
)
|
||||
|
||||
assert parser.should_keep_alive is False
|
||||
|
||||
def test_http10_no_keepalive(self, http_parser):
|
||||
"""HTTP/1.0 defaults to no keep-alive."""
|
||||
parser_class = get_parser_class(http_parser)
|
||||
|
||||
parser = parser_class()
|
||||
parser.feed(
|
||||
b"GET / HTTP/1.0\r\n"
|
||||
b"Host: localhost\r\n"
|
||||
b"\r\n"
|
||||
)
|
||||
|
||||
assert parser.should_keep_alive is False
|
||||
|
||||
def test_http10_with_keepalive(self, http_parser):
|
||||
"""HTTP/1.0 with Connection: keep-alive."""
|
||||
parser_class = get_parser_class(http_parser)
|
||||
|
||||
parser = parser_class()
|
||||
parser.feed(
|
||||
b"GET / HTTP/1.0\r\n"
|
||||
b"Host: localhost\r\n"
|
||||
b"Connection: keep-alive\r\n"
|
||||
b"\r\n"
|
||||
)
|
||||
|
||||
assert parser.should_keep_alive is True
|
||||
|
||||
|
||||
class TestParserReset:
|
||||
"""Test parser reset for keep-alive connections."""
|
||||
|
||||
def test_reset_after_request(self, http_parser):
|
||||
"""Parser can be reset for a new request."""
|
||||
parser_class = get_parser_class(http_parser)
|
||||
complete_count = [0]
|
||||
|
||||
parser = parser_class(
|
||||
on_message_complete=lambda: complete_count.__setitem__(0, complete_count[0] + 1),
|
||||
)
|
||||
|
||||
# First request
|
||||
parser.feed(b"GET /first HTTP/1.1\r\n\r\n")
|
||||
assert parser.path == b"/first"
|
||||
assert parser.is_complete
|
||||
|
||||
# Reset and send second request
|
||||
parser.reset()
|
||||
assert not parser.is_complete
|
||||
# H1CProtocol resets to b'', PythonProtocol to None
|
||||
assert not parser.method
|
||||
|
||||
parser.feed(b"GET /second HTTP/1.1\r\n\r\n")
|
||||
assert parser.path == b"/second"
|
||||
assert parser.is_complete
|
||||
|
||||
assert complete_count[0] == 2
|
||||
|
||||
|
||||
class TestCallbackBehavior:
    """Check that the parser implementations invoke callbacks consistently."""

    def test_all_callbacks_fire(self, http_parser):
        """Every callback fires for a request that carries a body.

        The begin and URL events must come first and the completion event
        last; headers, headers-complete and body only need to be present.
        """
        recorded = []

        def on_header(name, value):
            # Header names are lowercased so both parsers compare equal.
            recorded.append(('header', name.lower(), value))

        protocol = get_parser_class(http_parser)(
            on_message_begin=lambda: recorded.append('begin'),
            on_url=lambda url: recorded.append(('url', url)),
            on_header=on_header,
            on_headers_complete=lambda: recorded.append('headers_complete'),
            on_body=lambda chunk: recorded.append(('body', chunk)),
            on_message_complete=lambda: recorded.append('complete'),
        )

        raw_request = b"".join([
            b"POST / HTTP/1.1\r\n",
            b"Host: localhost\r\n",
            b"Content-Length: 4\r\n",
            b"\r\n",
            b"test",
        ])
        protocol.feed(raw_request)

        # Position-sensitive events.
        assert recorded[0] == 'begin'
        assert recorded[1] == ('url', b'/')
        # Events that only need to have been emitted at some point.
        for expected in (
            ('header', b'host', b'localhost'),
            ('header', b'content-length', b'4'),
            'headers_complete',
            ('body', b'test'),
        ):
            assert expected in recorded
        assert recorded[-1] == 'complete'

    def test_skip_body_on_headers_complete(self, http_parser):
        """A True return from on_headers_complete suppresses body callbacks."""
        received = []

        def collect_body(chunk):
            received.append(chunk)

        protocol = get_parser_class(http_parser)(
            on_headers_complete=lambda: True,  # ask the parser to skip the body
            on_body=collect_body,
        )

        raw_request = b"".join([
            b"POST / HTTP/1.1\r\n",
            b"Host: localhost\r\n",
            b"Content-Length: 10\r\n",
            b"\r\n",
            b"0123456789",
        ])
        protocol.feed(raw_request)

        assert protocol.is_complete
        # No body chunk may have been delivered.
        assert len(received) == 0
|
||||
|
||||
|
||||
class TestCallbackRequest:
    """Check how CallbackRequest is populated from a parser's state."""

    @staticmethod
    def _request_from(http_parser, raw):
        """Parse ``raw`` with the selected parser and build a CallbackRequest."""
        from gunicorn.asgi.parser import CallbackRequest

        protocol = get_parser_class(http_parser)()
        protocol.feed(raw)
        return CallbackRequest.from_parser(protocol)

    def test_non_ascii_path_decoding(self, http_parser):
        """Percent-encoded UTF-8 in the target is decoded for ``path`` only.

        Per the ASGI spec:
        - path: percent-decoded UTF-8 string
        - raw_path: original bytes as received
        """
        # "%C3%B6" is the UTF-8 percent-encoding of "ö".
        request = self._request_from(
            http_parser, b"GET /%C3%B6/ HTTP/1.1\r\nHost: test\r\n\r\n"
        )

        assert request.path == "/\u00f6/"  # /ö/
        assert request.raw_path == b"/%C3%B6/"

    def test_non_ascii_path_with_query(self, http_parser):
        """A percent-encoded path is decoded and the query string kept apart."""
        # "/%E6%97%A5%E6%9C%AC/" is the percent-encoding of "/日本/".
        request = self._request_from(
            http_parser,
            b"GET /%E6%97%A5%E6%9C%AC/?q=test HTTP/1.1\r\nHost: test\r\n\r\n",
        )

        assert request.path == "/\u65e5\u672c/"  # /日本/
        assert request.raw_path == b"/%E6%97%A5%E6%9C%AC/"
        assert request.query == "q=test"

    def test_invalid_utf8_path(self, http_parser):
        """Bytes that are not valid UTF-8 decode to the replacement character."""
        # %FF can never appear in well-formed UTF-8.
        request = self._request_from(
            http_parser, b"GET /%FF HTTP/1.1\r\nHost: test\r\n\r\n"
        )

        assert "\ufffd" in request.path
        assert request.raw_path == b"/%FF"

    def test_simple_ascii_path(self, http_parser):
        """A plain ASCII path is identical in ``path`` and ``raw_path``."""
        request = self._request_from(
            http_parser, b"GET /api/users HTTP/1.1\r\nHost: test\r\n\r\n"
        )

        assert request.path == "/api/users"
        assert request.raw_path == b"/api/users"

    def test_percent_encoded_ascii(self, http_parser):
        """Percent-encoded ASCII is decoded in ``path`` but kept in ``raw_path``."""
        # %20 encodes a space.
        request = self._request_from(
            http_parser, b"GET /hello%20world HTTP/1.1\r\nHost: test\r\n\r\n"
        )

        assert request.path == "/hello world"
        assert request.raw_path == b"/hello%20world"
|
||||
@ -9,9 +9,10 @@ Tests that gunicorn's ASGI implementation conforms to the ASGI 3.0 spec:
|
||||
https://asgi.readthedocs.io/en/latest/specs/main.html
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from gunicorn.config import Config
|
||||
|
||||
|
||||
@ -37,7 +38,9 @@ class TestASGIVersion:
|
||||
"""Create a mock HTTP request."""
|
||||
request = mock.Mock()
|
||||
request.method = kwargs.get("method", "GET")
|
||||
request.path = kwargs.get("path", "/")
|
||||
path = kwargs.get("path", "/")
|
||||
request.path = path
|
||||
request.raw_path = kwargs.get("raw_path", path.encode("latin-1") if path else b"")
|
||||
request.query = kwargs.get("query", "")
|
||||
request.version = kwargs.get("version", (1, 1))
|
||||
request.scheme = kwargs.get("scheme", "http")
|
||||
@ -136,7 +139,9 @@ class TestHTTPScopeKeys:
|
||||
"""Create a mock HTTP request."""
|
||||
request = mock.Mock()
|
||||
request.method = kwargs.get("method", "GET")
|
||||
request.path = kwargs.get("path", "/")
|
||||
path = kwargs.get("path", "/")
|
||||
request.path = path
|
||||
request.raw_path = kwargs.get("raw_path", path.encode("latin-1") if path else b"")
|
||||
request.query = kwargs.get("query", "")
|
||||
request.version = kwargs.get("version", (1, 1))
|
||||
request.scheme = kwargs.get("scheme", "http")
|
||||
@ -653,6 +658,7 @@ class TestStateSharing:
|
||||
request = mock.Mock()
|
||||
request.method = "GET"
|
||||
request.path = "/"
|
||||
request.raw_path = b"/"
|
||||
request.query = ""
|
||||
request.version = (1, 1)
|
||||
request.scheme = "http"
|
||||
@ -730,31 +736,46 @@ class TestHTTPDisconnectEvent:
|
||||
return protocol
|
||||
|
||||
def test_disconnect_event_type(self):
|
||||
"""Test that disconnect event has correct type per ASGI spec."""
|
||||
"""Test that disconnect event signals body receiver per ASGI spec."""
|
||||
from gunicorn.asgi.protocol import BodyReceiver
|
||||
|
||||
protocol = self._create_protocol()
|
||||
protocol._receive_queue = asyncio.Queue()
|
||||
|
||||
# Create a mock request for the body receiver
|
||||
mock_request = mock.Mock()
|
||||
mock_request.content_length = 100
|
||||
mock_request.chunked = False
|
||||
|
||||
body_receiver = BodyReceiver(mock_request, protocol)
|
||||
protocol._body_receiver = body_receiver
|
||||
|
||||
# Simulate client disconnect
|
||||
protocol.connection_lost(None)
|
||||
|
||||
# Get the message from queue
|
||||
msg = protocol._receive_queue.get_nowait()
|
||||
|
||||
# Per ASGI spec: type MUST be "http.disconnect"
|
||||
assert msg["type"] == "http.disconnect"
|
||||
# Per ASGI spec: disconnect should be signaled
|
||||
assert body_receiver._closed
|
||||
|
||||
def test_disconnect_event_sent_on_connection_lost(self):
|
||||
"""Test that http.disconnect is sent when connection is lost."""
|
||||
protocol = self._create_protocol()
|
||||
protocol._receive_queue = asyncio.Queue()
|
||||
"""Test that disconnect is signaled when connection is lost."""
|
||||
from gunicorn.asgi.protocol import BodyReceiver
|
||||
|
||||
assert protocol._receive_queue.empty()
|
||||
protocol = self._create_protocol()
|
||||
|
||||
# Create a mock request for the body receiver
|
||||
mock_request = mock.Mock()
|
||||
mock_request.content_length = 100
|
||||
mock_request.chunked = False
|
||||
|
||||
body_receiver = BodyReceiver(mock_request, protocol)
|
||||
protocol._body_receiver = body_receiver
|
||||
|
||||
assert not body_receiver._closed
|
||||
|
||||
# Simulate client disconnect
|
||||
protocol.connection_lost(None)
|
||||
|
||||
# Queue should have disconnect message
|
||||
assert not protocol._receive_queue.empty()
|
||||
# Disconnect should have been signaled
|
||||
assert body_receiver._closed
|
||||
|
||||
def test_disconnect_sets_closed_flag(self):
|
||||
"""Test that connection_lost sets the closed flag."""
|
||||
@ -788,18 +809,33 @@ class TestHTTPDisconnectEvent:
|
||||
# Cancellation should be scheduled after grace period
|
||||
protocol.worker.loop.call_later.assert_called_once()
|
||||
|
||||
def test_disconnect_message_format(self):
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_message_format(self):
|
||||
"""Test http.disconnect message format per ASGI spec.
|
||||
|
||||
The disconnect message should only contain 'type' key.
|
||||
When body is complete and disconnect is signaled, receive()
|
||||
should return {"type": "http.disconnect"}.
|
||||
"""
|
||||
from gunicorn.asgi.protocol import BodyReceiver
|
||||
|
||||
protocol = self._create_protocol()
|
||||
protocol._receive_queue = asyncio.Queue()
|
||||
|
||||
protocol.connection_lost(None)
|
||||
# Create a mock request with no body
|
||||
mock_request = mock.Mock()
|
||||
mock_request.content_length = 0
|
||||
mock_request.chunked = False
|
||||
|
||||
msg = protocol._receive_queue.get_nowait()
|
||||
body_receiver = BodyReceiver(mock_request, protocol)
|
||||
protocol._body_receiver = body_receiver
|
||||
|
||||
# Get initial body message (empty body)
|
||||
msg1 = await body_receiver.receive()
|
||||
assert msg1["type"] == "http.request"
|
||||
assert msg1["more_body"] is False
|
||||
|
||||
# Now receive should return disconnect
|
||||
msg2 = await body_receiver.receive()
|
||||
|
||||
# Per ASGI spec, disconnect message only has 'type'
|
||||
assert msg == {"type": "http.disconnect"}
|
||||
assert len(msg) == 1
|
||||
assert msg2 == {"type": "http.disconnect"}
|
||||
assert len(msg2) == 1
|
||||
|
||||
@ -50,41 +50,58 @@ class TestASGIGracefulDisconnect:
|
||||
|
||||
assert protocol._closed is True
|
||||
|
||||
def test_disconnect_sends_message_to_queue(self, mock_worker):
|
||||
"""Test that connection_lost sends http.disconnect to receive queue."""
|
||||
def test_disconnect_signals_body_receiver(self, mock_worker):
|
||||
"""Test that connection_lost signals the body receiver."""
|
||||
from gunicorn.asgi.protocol import BodyReceiver
|
||||
|
||||
protocol = ASGIProtocol(mock_worker)
|
||||
protocol.reader = mock.Mock()
|
||||
mock_worker.nr_conns = 1
|
||||
|
||||
# Create a receive queue (simulating active request)
|
||||
protocol._receive_queue = asyncio.Queue()
|
||||
# Create a mock request for the body receiver
|
||||
mock_request = mock.Mock()
|
||||
mock_request.content_length = 100
|
||||
mock_request.chunked = False
|
||||
|
||||
# Create a body receiver (simulating active request)
|
||||
body_receiver = BodyReceiver(mock_request, protocol)
|
||||
protocol._body_receiver = body_receiver
|
||||
|
||||
# Verify disconnect flag is not set initially
|
||||
assert not body_receiver._closed
|
||||
|
||||
# Simulate connection lost
|
||||
protocol.connection_lost(None)
|
||||
|
||||
# Check that disconnect message was sent
|
||||
assert not protocol._receive_queue.empty()
|
||||
msg = protocol._receive_queue.get_nowait()
|
||||
assert msg == {"type": "http.disconnect"}
|
||||
# Check that disconnect flag was set
|
||||
assert body_receiver._closed
|
||||
|
||||
def test_disconnect_is_idempotent(self, mock_worker):
|
||||
"""Test that connection_lost can be called multiple times safely."""
|
||||
from gunicorn.asgi.protocol import BodyReceiver
|
||||
|
||||
protocol = ASGIProtocol(mock_worker)
|
||||
protocol.reader = mock.Mock()
|
||||
mock_worker.nr_conns = 2 # Start with 2 so we can verify only 1 is decremented
|
||||
|
||||
protocol._receive_queue = asyncio.Queue()
|
||||
# Create a mock request for the body receiver
|
||||
mock_request = mock.Mock()
|
||||
mock_request.content_length = 100
|
||||
mock_request.chunked = False
|
||||
|
||||
body_receiver = BodyReceiver(mock_request, protocol)
|
||||
protocol._body_receiver = body_receiver
|
||||
|
||||
# First call should work
|
||||
protocol.connection_lost(None)
|
||||
assert protocol._closed is True
|
||||
assert mock_worker.nr_conns == 1
|
||||
assert protocol._receive_queue.qsize() == 1
|
||||
assert body_receiver._closed
|
||||
|
||||
# Second call should be a no-op
|
||||
protocol.connection_lost(None)
|
||||
assert mock_worker.nr_conns == 1 # Should not decrement again
|
||||
assert protocol._receive_queue.qsize() == 1 # Should not add another message
|
||||
# Closed flag is still set
|
||||
|
||||
def test_disconnect_does_not_cancel_immediately(self, mock_worker):
|
||||
"""Test that connection_lost doesn't cancel task immediately."""
|
||||
@ -156,41 +173,26 @@ class TestASGIGracefulDisconnect:
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_returns_disconnect_when_closed(self, mock_worker):
|
||||
"""Test that receive() returns http.disconnect when connection is closed."""
|
||||
from gunicorn.asgi.protocol import BodyReceiver
|
||||
|
||||
protocol = ASGIProtocol(mock_worker)
|
||||
protocol._closed = True
|
||||
|
||||
# Create receive queue with body complete
|
||||
receive_queue = asyncio.Queue()
|
||||
protocol._receive_queue = receive_queue
|
||||
# Create a mock request with no body
|
||||
mock_request = mock.Mock()
|
||||
mock_request.content_length = 0
|
||||
mock_request.chunked = False
|
||||
|
||||
# Add initial body message
|
||||
await receive_queue.put({
|
||||
"type": "http.request",
|
||||
"body": b"",
|
||||
"more_body": False,
|
||||
})
|
||||
body_receiver = BodyReceiver(mock_request, protocol)
|
||||
protocol._body_receiver = body_receiver
|
||||
|
||||
# Simulate what happens in _handle_http_request
|
||||
body_complete = False
|
||||
|
||||
async def receive():
|
||||
nonlocal body_complete
|
||||
if protocol._closed and body_complete:
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
msg = await receive_queue.get()
|
||||
|
||||
if msg.get("type") == "http.request" and not msg.get("more_body", True):
|
||||
body_complete = True
|
||||
|
||||
return msg
|
||||
|
||||
# First receive gets the body
|
||||
msg1 = await receive()
|
||||
# First receive gets the body (empty)
|
||||
msg1 = await body_receiver.receive()
|
||||
assert msg1["type"] == "http.request"
|
||||
assert msg1["more_body"] is False
|
||||
|
||||
# Second receive should get disconnect
|
||||
msg2 = await receive()
|
||||
# Second receive should get disconnect (body complete)
|
||||
msg2 = await body_receiver.receive()
|
||||
assert msg2["type"] == "http.disconnect"
|
||||
|
||||
|
||||
|
||||
@ -11,7 +11,6 @@ and extension support.
|
||||
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from gunicorn.config import Config
|
||||
|
||||
@ -40,7 +39,9 @@ class TestHTTPScopeBuilding:
|
||||
"""Create a mock HTTP request."""
|
||||
request = mock.Mock()
|
||||
request.method = kwargs.get("method", "GET")
|
||||
request.path = kwargs.get("path", "/")
|
||||
path = kwargs.get("path", "/")
|
||||
request.path = path
|
||||
request.raw_path = kwargs.get("raw_path", path.encode("latin-1") if path else b"")
|
||||
request.query = kwargs.get("query", "")
|
||||
request.version = kwargs.get("version", (1, 1))
|
||||
request.scheme = kwargs.get("scheme", "http")
|
||||
@ -114,7 +115,9 @@ class TestPathHandling:
|
||||
"""Create a mock HTTP request."""
|
||||
request = mock.Mock()
|
||||
request.method = kwargs.get("method", "GET")
|
||||
request.path = kwargs.get("path", "/")
|
||||
path = kwargs.get("path", "/")
|
||||
request.path = path
|
||||
request.raw_path = kwargs.get("raw_path", path.encode("latin-1") if path else b"")
|
||||
request.query = kwargs.get("query", "")
|
||||
request.version = kwargs.get("version", (1, 1))
|
||||
request.scheme = kwargs.get("scheme", "http")
|
||||
@ -194,7 +197,9 @@ class TestQueryStringHandling:
|
||||
"""Create a mock HTTP request."""
|
||||
request = mock.Mock()
|
||||
request.method = kwargs.get("method", "GET")
|
||||
request.path = kwargs.get("path", "/")
|
||||
path = kwargs.get("path", "/")
|
||||
request.path = path
|
||||
request.raw_path = kwargs.get("raw_path", path.encode("latin-1") if path else b"")
|
||||
request.query = kwargs.get("query", "")
|
||||
request.version = kwargs.get("version", (1, 1))
|
||||
request.scheme = kwargs.get("scheme", "http")
|
||||
@ -270,7 +275,9 @@ class TestHeaderHandling:
|
||||
"""Create a mock HTTP request."""
|
||||
request = mock.Mock()
|
||||
request.method = kwargs.get("method", "GET")
|
||||
request.path = kwargs.get("path", "/")
|
||||
path = kwargs.get("path", "/")
|
||||
request.path = path
|
||||
request.raw_path = kwargs.get("raw_path", path.encode("latin-1") if path else b"")
|
||||
request.query = kwargs.get("query", "")
|
||||
request.version = kwargs.get("version", (1, 1))
|
||||
request.scheme = kwargs.get("scheme", "http")
|
||||
@ -362,7 +369,9 @@ class TestWebSocketScope:
|
||||
"""Create a mock WebSocket upgrade request."""
|
||||
request = mock.Mock()
|
||||
request.method = "GET"
|
||||
request.path = kwargs.get("path", "/ws")
|
||||
path = kwargs.get("path", "/ws")
|
||||
request.path = path
|
||||
request.raw_path = kwargs.get("raw_path", path.encode("latin-1") if path else b"")
|
||||
request.query = kwargs.get("query", "")
|
||||
request.version = kwargs.get("version", (1, 1))
|
||||
request.scheme = kwargs.get("scheme", "http")
|
||||
@ -575,6 +584,7 @@ class TestAddressHandling:
|
||||
request = mock.Mock()
|
||||
request.method = "GET"
|
||||
request.path = "/"
|
||||
request.raw_path = b"/"
|
||||
request.query = ""
|
||||
request.version = (1, 1)
|
||||
request.scheme = "http"
|
||||
|
||||
@ -6,7 +6,6 @@
|
||||
Tests for ASGI HTTP parser optimizations.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import ipaddress
|
||||
import pytest
|
||||
|
||||
|
||||
@ -9,7 +9,6 @@ Tests for chunked transfer encoding, Server-Sent Events (SSE),
|
||||
and streaming response handling.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
@ -221,43 +220,39 @@ class TestProtocolSendBody:
|
||||
|
||||
return protocol
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_body_without_chunking(self):
|
||||
def test_send_body_without_chunking(self):
|
||||
"""Test sending body without chunked encoding."""
|
||||
protocol = self._create_protocol()
|
||||
|
||||
await protocol._send_body(b"Hello, World!", chunked=False)
|
||||
protocol._send_body(b"Hello, World!", chunked=False)
|
||||
|
||||
protocol.transport.write.assert_called_once_with(b"Hello, World!")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_body_with_chunking(self):
|
||||
def test_send_body_with_chunking(self):
|
||||
"""Test sending body with chunked encoding."""
|
||||
protocol = self._create_protocol()
|
||||
|
||||
await protocol._send_body(b"Hello", chunked=True)
|
||||
protocol._send_body(b"Hello", chunked=True)
|
||||
|
||||
# Should write: "5\r\nHello\r\n"
|
||||
protocol.transport.write.assert_called_once()
|
||||
call_arg = protocol.transport.write.call_args[0][0]
|
||||
assert call_arg == b"5\r\nHello\r\n"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_body_empty_without_chunking(self):
|
||||
def test_send_body_empty_without_chunking(self):
|
||||
"""Test sending empty body without chunked encoding."""
|
||||
protocol = self._create_protocol()
|
||||
|
||||
await protocol._send_body(b"", chunked=False)
|
||||
protocol._send_body(b"", chunked=False)
|
||||
|
||||
# Empty body should not write anything
|
||||
protocol.transport.write.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_body_empty_with_chunking(self):
|
||||
def test_send_body_empty_with_chunking(self):
|
||||
"""Test sending empty body with chunked encoding."""
|
||||
protocol = self._create_protocol()
|
||||
|
||||
await protocol._send_body(b"", chunked=True)
|
||||
protocol._send_body(b"", chunked=True)
|
||||
|
||||
# Empty body should not write (terminal chunk handled separately)
|
||||
protocol.transport.write.assert_not_called()
|
||||
|
||||
@ -9,7 +9,6 @@ Tests that gunicorn's WebSocket implementation conforms to RFC 6455:
|
||||
https://tools.ietf.org/html/rfc6455
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import struct
|
||||
@ -176,7 +175,7 @@ class TestWebSocketFrameMasking:
|
||||
def _create_protocol(self):
|
||||
"""Create a WebSocketProtocol instance for testing."""
|
||||
from gunicorn.asgi.websocket import WebSocketProtocol
|
||||
return WebSocketProtocol(None, None, {}, None, mock.Mock())
|
||||
return WebSocketProtocol(None, {}, None, mock.Mock())
|
||||
|
||||
def test_unmask_simple(self):
|
||||
"""Test basic unmasking operation."""
|
||||
@ -299,7 +298,6 @@ class TestWebSocketProtocolInstance:
|
||||
|
||||
return WebSocketProtocol(
|
||||
transport=mock.Mock(),
|
||||
reader=mock.Mock(),
|
||||
scope=scope,
|
||||
app=mock.AsyncMock(),
|
||||
log=mock.Mock(),
|
||||
@ -602,11 +600,9 @@ class TestWebSocketAsync:
|
||||
}
|
||||
|
||||
transport = mock.Mock()
|
||||
reader = mock.Mock()
|
||||
|
||||
return WebSocketProtocol(
|
||||
transport=transport,
|
||||
reader=reader,
|
||||
scope=scope,
|
||||
app=mock.AsyncMock(),
|
||||
log=mock.Mock(),
|
||||
@ -676,3 +672,367 @@ class TestWebSocketAsync:
|
||||
await protocol._send({"type": "websocket.close", "code": 1000})
|
||||
|
||||
assert protocol.closed is True
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Callback-based Data Feeding Tests
|
||||
# ============================================================================
|
||||
|
||||
class TestWebSocketCallbackDataFeeding:
    """Tests for callback-based data feeding (replaces StreamReader)."""

    def _create_protocol(self):
        """Create a WebSocketProtocol instance backed entirely by mocks."""
        from gunicorn.asgi.websocket import WebSocketProtocol
        return WebSocketProtocol(
            transport=mock.Mock(),
            scope={"type": "websocket", "headers": []},
            app=mock.AsyncMock(),
            log=mock.Mock(),
        )

    def test_initial_buffer_empty(self):
        """A fresh protocol has an empty buffer and has not seen EOF."""
        protocol = self._create_protocol()
        assert len(protocol._buffer) == 0
        assert protocol._eof is False

    def test_feed_data_adds_to_buffer(self):
        """Successive feed_data calls append their bytes to the buffer."""
        protocol = self._create_protocol()

        protocol.feed_data(b"Hello")
        assert bytes(protocol._buffer) == b"Hello"

        protocol.feed_data(b" World")
        assert bytes(protocol._buffer) == b"Hello World"

    def test_feed_data_ignores_empty(self):
        """feed_data leaves the buffer untouched for empty or None input."""
        protocol = self._create_protocol()

        protocol.feed_data(b"")
        assert len(protocol._buffer) == 0

        # None must neither raise nor add anything to the buffer.
        protocol.feed_data(None)
        assert len(protocol._buffer) == 0

    def test_feed_data_sets_event(self):
        """feed_data sets the data event."""
        protocol = self._create_protocol()

        assert not protocol._data_event.is_set()
        protocol.feed_data(b"data")
        assert protocol._data_event.is_set()

    def test_feed_eof_sets_flag(self):
        """feed_eof flips the EOF flag from False to True."""
        protocol = self._create_protocol()

        assert protocol._eof is False
        protocol.feed_eof()
        assert protocol._eof is True

    def test_feed_eof_sets_event(self):
        """feed_eof sets the data event."""
        protocol = self._create_protocol()

        assert not protocol._data_event.is_set()
        protocol.feed_eof()
        assert protocol._data_event.is_set()
|
||||
|
||||
|
||||
class TestWebSocketReadExact:
    """Exercise ``_read_exact`` against the callback-fed receive buffer."""

    def _create_protocol(self):
        """Return a WebSocketProtocol wired entirely to mocks."""
        from gunicorn.asgi.websocket import WebSocketProtocol

        options = {
            "transport": mock.Mock(),
            "scope": {"type": "websocket", "headers": []},
            "app": mock.AsyncMock(),
            "log": mock.Mock(),
        }
        return WebSocketProtocol(**options)

    @pytest.mark.asyncio
    async def test_read_exact_with_sufficient_data(self):
        """With enough buffered bytes, the requested prefix is returned."""
        ws = self._create_protocol()
        ws.feed_data(b"Hello World")  # buffer already holds more than needed

        head = await ws._read_exact(5)

        assert head == b"Hello"
        assert bytes(ws._buffer) == b" World"

    @pytest.mark.asyncio
    async def test_read_exact_consumes_buffer(self):
        """Consecutive reads each remove their bytes from the buffer."""
        ws = self._create_protocol()
        ws.feed_data(b"ABCDEFGH")

        first = await ws._read_exact(3)
        assert first == b"ABC"

        second = await ws._read_exact(3)
        assert second == b"DEF"

        assert bytes(ws._buffer) == b"GH"

    @pytest.mark.asyncio
    async def test_read_exact_returns_none_on_eof(self):
        """None is returned when EOF was fed and the buffer is too short."""
        ws = self._create_protocol()
        ws.feed_data(b"Hi")
        ws.feed_eof()

        # Ask for more bytes than were fed before EOF.
        outcome = await ws._read_exact(10)

        assert outcome is None

    @pytest.mark.asyncio
    async def test_read_exact_waits_for_data(self):
        """A read on an empty buffer stays pending until data is fed."""
        import asyncio
        ws = self._create_protocol()

        pending = asyncio.create_task(ws._read_exact(10))

        # Yield so the task can start and block on the empty buffer.
        await asyncio.sleep(0.01)
        assert not pending.done()

        ws.feed_data(b"1234567890")  # exactly the amount requested

        outcome = await asyncio.wait_for(pending, timeout=1.0)
        assert outcome == b"1234567890"

    @pytest.mark.asyncio
    async def test_read_exact_handles_incremental_data(self):
        """A read completes only once chunks add up to the requested size."""
        import asyncio
        ws = self._create_protocol()

        # The read needs 10 bytes in total.
        pending = asyncio.create_task(ws._read_exact(10))
        await asyncio.sleep(0.01)

        # The first two chunks leave the read short of 10 bytes.
        for piece in (b"123", b"456"):
            ws.feed_data(piece)
            await asyncio.sleep(0.01)
            assert not pending.done()

        # The last chunk brings the total to 10 bytes.
        ws.feed_data(b"7890")

        outcome = await asyncio.wait_for(pending, timeout=1.0)
        assert outcome == b"1234567890"

    @pytest.mark.asyncio
    async def test_read_exact_race_condition(self):
        """Data fed while a read is waiting must not leave the read stuck.

        Regression test for this race:
        1. Task A checks the buffer and needs more data
        2. Task A clears _data_event
        3. Task B (data_received) calls feed_data(), which sets the event
        4. Task A would wait forever on the cleared event - DEADLOCK

        The fix adds a buffer check after clear() to catch this case.
        """
        import asyncio
        ws = self._create_protocol()

        # Half of the needed bytes are available up front.
        ws.feed_data(b"12345")

        pending = asyncio.create_task(ws._read_exact(10))
        await asyncio.sleep(0.01)
        assert not pending.done()

        # Feed the remainder while the read is pending. In the buggy version,
        # data arriving between clear() and wait() left the reader blocked on
        # a stale, cleared event.
        ws.feed_data(b"67890")

        # The timeout turns a deadlock into a test failure.
        outcome = await asyncio.wait_for(pending, timeout=1.0)
        assert outcome == b"1234567890"

    @pytest.mark.asyncio
    async def test_read_exact_multiple_feeds_before_wait(self):
        """A read returns promptly when all data was fed before it started."""
        import asyncio
        ws = self._create_protocol()

        # Everything is buffered before the read starts, so it must not block.
        ws.feed_data(b"Complete message here")

        outcome = await asyncio.wait_for(ws._read_exact(8), timeout=0.1)
        assert outcome == b"Complete"

        # The unread remainder stays in the buffer.
        assert bytes(ws._buffer) == b" message here"

    @pytest.mark.asyncio
    async def test_read_exact_eof_during_wait(self):
        """EOF fed while a read is waiting makes that read return None."""
        import asyncio
        ws = self._create_protocol()

        # Request far more than will ever be fed.
        pending = asyncio.create_task(ws._read_exact(100))

        await asyncio.sleep(0.01)
        assert not pending.done()

        # A short chunk keeps the read pending.
        ws.feed_data(b"partial")
        await asyncio.sleep(0.01)
        assert not pending.done()

        # EOF should end the wait with a None result.
        ws.feed_eof()

        outcome = await asyncio.wait_for(pending, timeout=1.0)
        assert outcome is None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# WebSocket Fragmented Message Tests (RFC 6455 Section 5.4)
|
||||
# ============================================================================
|
||||
|
||||
class TestWebSocketFragmentedMessages:
|
||||
"""Tests for WebSocket fragmented message handling."""
|
||||
|
||||
def _create_protocol(self):
    """Return a WebSocketProtocol wired entirely to mocks."""
    from gunicorn.asgi.websocket import WebSocketProtocol

    options = {
        "transport": mock.Mock(),
        "scope": {"type": "websocket", "headers": []},
        "app": mock.AsyncMock(),
        "log": mock.Mock(),
    }
    return WebSocketProtocol(**options)
|
||||
|
||||
def _create_masked_frame(self, fin, opcode, payload, mask_key=None):
|
||||
"""Create a masked WebSocket frame.
|
||||
|
||||
Args:
|
||||
fin: FIN bit (1 for final, 0 for continuation)
|
||||
opcode: Frame opcode
|
||||
payload: Frame payload bytes
|
||||
mask_key: 4-byte masking key (generated if None)
|
||||
|
||||
Returns:
|
||||
bytes: Complete masked frame
|
||||
"""
|
||||
if mask_key is None:
|
||||
mask_key = bytes([0x37, 0xfa, 0x21, 0x3d])
|
||||
|
||||
frame = bytearray()
|
||||
|
||||
# First byte: FIN + RSV(000) + opcode
|
||||
frame.append((fin << 7) | opcode)
|
||||
|
||||
# Second byte: MASK(1) + length
|
||||
length = len(payload)
|
||||
if length < 126:
|
||||
frame.append(0x80 | length)
|
||||
elif length < 65536:
|
||||
frame.append(0x80 | 126)
|
||||
frame.extend(struct.pack("!H", length))
|
||||
else:
|
||||
frame.append(0x80 | 127)
|
||||
frame.extend(struct.pack("!Q", length))
|
||||
|
||||
# Masking key
|
||||
frame.extend(mask_key)
|
||||
|
||||
# Masked payload
|
||||
masked_payload = bytes(b ^ mask_key[i % 4] for i, b in enumerate(payload))
|
||||
frame.extend(masked_payload)
|
||||
|
||||
return bytes(frame)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fragmented_message_reassembly(self):
|
||||
"""Test reassembly of fragmented text message with multiple continuation frames."""
|
||||
from gunicorn.asgi.websocket import (
|
||||
OPCODE_TEXT, OPCODE_CONTINUATION, OPCODE_CONTINUATION as CONT
|
||||
)
|
||||
import asyncio
|
||||
|
||||
protocol = self._create_protocol()
|
||||
|
||||
# Build fragmented message: "Hello" + " " + "World" + "!"
|
||||
# First frame: opcode=TEXT, FIN=0, payload="Hello"
|
||||
frame1 = self._create_masked_frame(fin=0, opcode=OPCODE_TEXT, payload=b"Hello")
|
||||
# Continuation frames: opcode=CONTINUATION, FIN=0
|
||||
frame2 = self._create_masked_frame(fin=0, opcode=CONT, payload=b" ")
|
||||
frame3 = self._create_masked_frame(fin=0, opcode=CONT, payload=b"World")
|
||||
# Final frame: opcode=CONTINUATION, FIN=1
|
||||
frame4 = self._create_masked_frame(fin=1, opcode=CONT, payload=b"!")
|
||||
|
||||
# Feed all frames
|
||||
protocol.feed_data(frame1 + frame2 + frame3 + frame4)
|
||||
|
||||
# Read frames - first 3 should return CONTINUATION with empty payload (waiting)
|
||||
result1 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0)
|
||||
assert result1 == (OPCODE_CONTINUATION, b"") # Fragment started
|
||||
|
||||
result2 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0)
|
||||
assert result2 == (OPCODE_CONTINUATION, b"") # Fragment continued
|
||||
|
||||
result3 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0)
|
||||
assert result3 == (OPCODE_CONTINUATION, b"") # Fragment continued
|
||||
|
||||
# Final frame should return complete reassembled message with original opcode
|
||||
result4 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0)
|
||||
assert result4 == (OPCODE_TEXT, b"Hello World!")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_control_frame_during_fragmentation(self):
|
||||
"""Test that control frames (ping) can arrive during fragmented message.
|
||||
|
||||
RFC 6455 Section 5.4: Control frames MAY be injected in the middle
|
||||
of a fragmented message.
|
||||
"""
|
||||
from gunicorn.asgi.websocket import (
|
||||
OPCODE_TEXT, OPCODE_CONTINUATION, OPCODE_PING
|
||||
)
|
||||
import asyncio
|
||||
|
||||
protocol = self._create_protocol()
|
||||
|
||||
# Start fragmented message
|
||||
frame1 = self._create_masked_frame(fin=0, opcode=OPCODE_TEXT, payload=b"Hello")
|
||||
# Ping frame in the middle (control frames are always FIN=1)
|
||||
ping_frame = self._create_masked_frame(fin=1, opcode=OPCODE_PING, payload=b"ping")
|
||||
# Continue and finish fragmented message
|
||||
frame2 = self._create_masked_frame(fin=1, opcode=OPCODE_CONTINUATION, payload=b" World")
|
||||
|
||||
protocol.feed_data(frame1 + ping_frame + frame2)
|
||||
|
||||
# First read: fragment started (waiting for more)
|
||||
result1 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0)
|
||||
assert result1 == (OPCODE_CONTINUATION, b"")
|
||||
|
||||
# Second read: ping frame (control frames handled separately)
|
||||
result2 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0)
|
||||
assert result2 == (OPCODE_PING, b"ping")
|
||||
|
||||
# Third read: complete reassembled message
|
||||
result3 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0)
|
||||
assert result3 == (OPCODE_TEXT, b"Hello World")
|
||||
|
||||
@ -12,11 +12,7 @@ that actually start the server and make HTTP requests.
|
||||
import asyncio
|
||||
import errno
|
||||
import os
|
||||
import signal
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
import threading
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
@ -120,7 +116,7 @@ class FakeListener:
|
||||
def _has_uvloop():
|
||||
"""Check if uvloop is available."""
|
||||
try:
|
||||
import uvloop
|
||||
import uvloop # noqa: F401
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
@ -337,9 +333,9 @@ class TestLifespanManager:
|
||||
async def app(scope, receive, send):
|
||||
assert "state" in scope
|
||||
scope["state"]["db"] = "connected"
|
||||
message = await receive()
|
||||
_ = await receive()
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
message = await receive()
|
||||
_ = await receive()
|
||||
await send({"type": "lifespan.shutdown.complete"})
|
||||
|
||||
manager = LifespanManager(app, mock.Mock(), state)
|
||||
@ -393,7 +389,7 @@ class TestWebSocketProtocol:
|
||||
from gunicorn.asgi.websocket import WebSocketProtocol
|
||||
|
||||
# Create a minimal protocol instance
|
||||
protocol = WebSocketProtocol(None, None, {}, None, mock.Mock())
|
||||
protocol = WebSocketProtocol(None, {}, None, mock.Mock())
|
||||
|
||||
# Test unmasking (XOR operation)
|
||||
masking_key = bytes([0x37, 0xfa, 0x21, 0x3d])
|
||||
@ -406,7 +402,7 @@ class TestWebSocketProtocol:
|
||||
"""Test WebSocket frame unmasking with empty payload."""
|
||||
from gunicorn.asgi.websocket import WebSocketProtocol
|
||||
|
||||
protocol = WebSocketProtocol(None, None, {}, None, mock.Mock())
|
||||
protocol = WebSocketProtocol(None, {}, None, mock.Mock())
|
||||
|
||||
masking_key = bytes([0x37, 0xfa, 0x21, 0x3d])
|
||||
unmasked = protocol._unmask(b"", masking_key)
|
||||
@ -555,7 +551,6 @@ class TestASGIProtocol:
|
||||
def test_scope_building(self):
|
||||
"""Test HTTP scope building."""
|
||||
from gunicorn.asgi.protocol import ASGIProtocol
|
||||
from gunicorn.asgi.message import AsyncRequest
|
||||
|
||||
worker = mock.Mock()
|
||||
worker.cfg = Config()
|
||||
@ -729,10 +724,11 @@ class TestASGIHTTP2Priority:
|
||||
protocol = ASGIProtocol(worker)
|
||||
|
||||
# Create mock HTTP/1.1 request (no priority attributes)
|
||||
request = mock.Mock(spec=['method', 'path', 'query', 'version',
|
||||
request = mock.Mock(spec=['method', 'path', 'raw_path', 'query', 'version',
|
||||
'scheme', 'headers'])
|
||||
request.method = "GET"
|
||||
request.path = "/test"
|
||||
request.raw_path = b"/test"
|
||||
request.query = ""
|
||||
request.version = (1, 1)
|
||||
request.scheme = "http"
|
||||
|
||||
@ -7,20 +7,57 @@ import os
|
||||
|
||||
import pytest
|
||||
|
||||
from gunicorn.http.errors import (
|
||||
InvalidRequestLine,
|
||||
InvalidRequestMethod,
|
||||
InvalidSchemeHeaders,
|
||||
ObsoleteFolding,
|
||||
)
|
||||
import treq
|
||||
|
||||
dirname = os.path.dirname(__file__)
|
||||
reqdir = os.path.join(dirname, "requests", "invalid")
|
||||
httpfiles = glob.glob(os.path.join(reqdir, "*.http"))
|
||||
|
||||
# Flags incompatible with fast parser (require Python parser features)
|
||||
_FAST_INCOMPATIBLE_FLAGS = ('permit_obsolete_folding', 'strip_header_spaces')
|
||||
|
||||
# Exceptions that only the Python parser raises (C parser has different validation)
|
||||
_PYTHON_ONLY_EXCEPTIONS = (ObsoleteFolding, InvalidSchemeHeaders)
|
||||
|
||||
# C parser may raise different but valid exceptions for these cases
|
||||
_FAST_PARSER_ALTERNATES = {
|
||||
InvalidRequestMethod: (InvalidRequestLine,), # e.g. "GET:" raises InvalidRequestLine
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("fname", httpfiles)
|
||||
def test_http_parser(fname):
|
||||
env = treq.load_py(os.path.splitext(fname)[0] + ".py")
|
||||
def test_http_parser(fname, http_parser):
|
||||
"""Test invalid HTTP requests with both parser implementations."""
|
||||
env = treq.load_py(os.path.splitext(fname)[0] + ".py", http_parser=http_parser)
|
||||
|
||||
expect = env["request"]
|
||||
cfg = env["cfg"]
|
||||
|
||||
# Skip fast parser tests that use incompatible compatibility flags
|
||||
if http_parser == 'fast':
|
||||
for flag in _FAST_INCOMPATIBLE_FLAGS:
|
||||
if getattr(cfg, flag, False):
|
||||
pytest.skip(f"fast parser incompatible with {flag}")
|
||||
|
||||
# Skip tests expecting Python-only exceptions
|
||||
if expect in _PYTHON_ONLY_EXCEPTIONS or (
|
||||
isinstance(expect, type) and issubclass(expect, _PYTHON_ONLY_EXCEPTIONS)
|
||||
):
|
||||
pytest.skip(f"fast parser does not raise {expect.__name__}")
|
||||
|
||||
# Determine acceptable exceptions (fast parser may raise alternates)
|
||||
if http_parser == 'fast' and expect in _FAST_PARSER_ALTERNATES:
|
||||
acceptable = (expect,) + _FAST_PARSER_ALTERNATES[expect]
|
||||
else:
|
||||
acceptable = expect
|
||||
|
||||
req = treq.badrequest(fname)
|
||||
|
||||
with pytest.raises(expect):
|
||||
with pytest.raises(acceptable):
|
||||
req.check(cfg)
|
||||
|
||||
518
tests/test_python_protocol.py
Normal file
518
tests/test_python_protocol.py
Normal file
@ -0,0 +1,518 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
Tests for PythonProtocol callback-based HTTP parser.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from gunicorn.asgi.parser import PythonProtocol, CallbackRequest, ParseError
|
||||
|
||||
|
||||
class TestPythonProtocolBasic:
    """Request-line and header parsing for requests without a body."""

    def test_simple_get_request(self):
        """A minimal GET fills in the parsed fields and fires each callback once."""
        header_events = []
        message_events = []

        def record_headers():
            header_events.append(True)

        def record_message():
            message_events.append(True)

        proto = PythonProtocol(
            on_headers_complete=record_headers,
            on_message_complete=record_message,
        )
        proto.feed(b"GET /path HTTP/1.1\r\nHost: example.com\r\n\r\n")

        # Request line
        assert proto.method == b"GET"
        assert proto.path == b"/path"
        assert proto.http_version == (1, 1)

        # Single header, name expected in lower case
        assert len(proto.headers) == 1
        assert proto.headers[0] == (b"host", b"example.com")

        # Completion state and callback counts
        assert proto.is_complete is True
        assert len(header_events) == 1
        assert len(message_events) == 1

    def test_get_with_query_string(self):
        """The parser's path keeps the query string attached."""
        proto = PythonProtocol()
        proto.feed(b"GET /search?q=test&page=1 HTTP/1.1\r\nHost: example.com\r\n\r\n")

        assert proto.method == b"GET"
        assert proto.path == b"/search?q=test&page=1"
        assert proto.is_complete is True

    def test_http_10_request(self):
        """An HTTP/1.0 request without a Connection header is not kept alive."""
        proto = PythonProtocol()
        proto.feed(b"GET / HTTP/1.0\r\nHost: example.com\r\n\r\n")

        assert proto.http_version == (1, 0)
        assert proto.should_keep_alive is False  # HTTP/1.0 default

    def test_http_10_with_keepalive(self):
        """An explicit ``Connection: keep-alive`` keeps an HTTP/1.0 connection open."""
        proto = PythonProtocol()
        proto.feed(
            b"GET / HTTP/1.0\r\nHost: example.com\r\nConnection: keep-alive\r\n\r\n"
        )

        assert proto.http_version == (1, 0)
        assert proto.should_keep_alive is True

    def test_multiple_headers(self):
        """Every header line is collected, with names in lower case."""
        raw_request = (
            b"GET / HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"Accept: text/html\r\n"
            b"Accept-Language: en-US\r\n"
            b"User-Agent: Test/1.0\r\n"
            b"\r\n"
        )
        proto = PythonProtocol()
        proto.feed(raw_request)

        assert len(proto.headers) == 4
        names = [entry[0] for entry in proto.headers]
        for expected in (b"host", b"accept", b"accept-language", b"user-agent"):
            assert expected in names
|
||||
|
||||
|
||||
class TestPythonProtocolBody:
    """Body delivery through ``on_body`` for sized and chunked requests."""

    def test_post_with_content_length(self):
        """A Content-Length body is handed to ``on_body`` in full."""
        received = []

        def collect(chunk):
            received.append(chunk)

        proto = PythonProtocol(on_body=collect)
        proto.feed(
            b"POST /submit HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"Content-Length: 13\r\n"
            b"\r\n"
            b"name=testuser"
        )

        assert proto.method == b"POST"
        assert proto.content_length == 13
        assert proto.is_complete is True
        assert b"".join(received) == b"name=testuser"

    def test_chunked_body(self):
        """Chunked data is delivered without the chunk framing."""
        received = []

        def collect(chunk):
            received.append(chunk)

        proto = PythonProtocol(on_body=collect)
        proto.feed(
            b"POST / HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"Transfer-Encoding: chunked\r\n"
            b"\r\n"
            b"5\r\nhello\r\n"
            b"6\r\n world\r\n"
            b"0\r\n\r\n"
        )

        assert proto.is_chunked is True
        assert proto.is_complete is True
        assert b"".join(received) == b"hello world"

    def test_chunked_with_extension(self):
        """A chunk extension after the size does not end up in the body."""
        received = []

        def collect(chunk):
            received.append(chunk)

        proto = PythonProtocol(on_body=collect)
        proto.feed(
            b"POST / HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"Transfer-Encoding: chunked\r\n"
            b"\r\n"
            b"5;ext=value\r\nhello\r\n"
            b"0\r\n\r\n"
        )

        assert b"".join(received) == b"hello"
|
||||
|
||||
|
||||
class TestPythonProtocolIncremental:
    """Feeding a request to the parser in several pieces."""

    def test_partial_request_line(self):
        """Nothing is parsed until the request line has been received in full."""
        proto = PythonProtocol()

        # Only part of the request line so far
        proto.feed(b"GET /path ")
        assert proto.method is None
        assert proto.is_complete is False

        # Rest of the request line plus the headers
        proto.feed(b"HTTP/1.1\r\nHost: example.com\r\n\r\n")
        assert proto.method == b"GET"
        assert proto.is_complete is True

    def test_partial_headers(self):
        """A header split across two feeds is reassembled."""
        proto = PythonProtocol()

        for piece in (b"GET / HTTP/1.1\r\n", b"Host: exa"):
            proto.feed(piece)
        assert proto.is_complete is False

        proto.feed(b"mple.com\r\n\r\n")
        assert proto.is_complete is True
        assert proto.headers[0] == (b"host", b"example.com")

    def test_partial_body(self):
        """A Content-Length body may arrive over more than one feed."""
        received = []

        def collect(chunk):
            received.append(chunk)

        proto = PythonProtocol(on_body=collect)

        # Headers plus the first half of the 10-byte body
        proto.feed(
            b"POST / HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"Content-Length: 10\r\n"
            b"\r\n"
            b"hello"
        )
        assert proto.is_complete is False

        # Second half of the body
        proto.feed(b"world")
        assert proto.is_complete is True
        assert b"".join(received) == b"helloworld"

    def test_partial_chunked_body(self):
        """A chunk split in the middle of its data is still delivered whole."""
        received = []

        def collect(chunk):
            received.append(chunk)

        proto = PythonProtocol(on_body=collect)

        # Headers plus a 5-byte chunk of which only 3 bytes are present
        proto.feed(
            b"POST / HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"Transfer-Encoding: chunked\r\n"
            b"\r\n"
            b"5\r\nhel"
        )
        assert proto.is_complete is False

        # Remainder of the chunk and the terminating zero-size chunk
        proto.feed(b"lo\r\n0\r\n\r\n")
        assert proto.is_complete is True
        assert b"".join(received) == b"hello"
|
||||
|
||||
|
||||
class TestPythonProtocolErrors:
    """Malformed input must raise ``ParseError``."""

    def test_invalid_request_line(self):
        """A request line consisting of a single token is rejected."""
        proto = PythonProtocol()
        bad_input = b"INVALID\r\n"

        with pytest.raises(ParseError):
            proto.feed(bad_input)

    def test_invalid_header(self):
        """A header line without a colon is rejected."""
        proto = PythonProtocol()
        bad_input = b"GET / HTTP/1.1\r\nBadHeader\r\n\r\n"

        with pytest.raises(ParseError):
            proto.feed(bad_input)

    def test_unsupported_http_version(self):
        """An HTTP/2.0 request line is rejected."""
        proto = PythonProtocol()
        bad_input = b"GET / HTTP/2.0\r\n\r\n"

        with pytest.raises(ParseError):
            proto.feed(bad_input)

    def test_invalid_chunk_size(self):
        """A chunk size that is not hexadecimal is rejected."""
        proto = PythonProtocol()
        bad_input = (
            b"POST / HTTP/1.1\r\n"
            b"Transfer-Encoding: chunked\r\n"
            b"\r\n"
            b"XYZ\r\n"  # Invalid hex
        )

        with pytest.raises(ParseError):
            proto.feed(bad_input)
|
||||
|
||||
|
||||
class TestPythonProtocolReset:
    """Reusing one parser for several requests via ``reset()``."""

    def test_reset_clears_state(self):
        """After ``reset()`` every parsed field is back to its empty value."""
        proto = PythonProtocol()
        proto.feed(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")

        # Sanity check: the request was parsed before resetting
        assert proto.is_complete is True
        assert proto.method == b"GET"

        proto.reset()

        # Fields expected to be None after reset
        assert proto.method is None
        assert proto.path is None
        assert proto.http_version is None
        # Header list, body info and completion flag
        assert proto.headers == []
        assert proto.content_length is None
        assert proto.is_chunked is False
        assert proto.is_complete is False

    def test_multiple_requests_keepalive(self):
        """Two requests can be parsed back to back with a reset in between."""
        proto = PythonProtocol()

        # First request
        proto.feed(b"GET /first HTTP/1.1\r\nHost: example.com\r\n\r\n")
        assert proto.path == b"/first"
        assert proto.is_complete is True

        proto.reset()

        # Second request on the same parser
        proto.feed(b"GET /second HTTP/1.1\r\nHost: example.com\r\n\r\n")
        assert proto.path == b"/second"
        assert proto.is_complete is True
|
||||
|
||||
|
||||
class TestPythonProtocolCallbacks:
    """Order and effect of the parser callbacks."""

    def test_all_callbacks(self):
        """Every callback fires, in the order the request is parsed."""
        events = []

        def on_message_begin():
            events.append("begin")

        def on_url(url):
            events.append(("url", url))

        def on_header(n, v):
            events.append(("header", n, v))

        def on_headers_complete():
            events.append("headers_complete")

        def on_body(chunk):
            events.append(("body", chunk))

        def on_message_complete():
            events.append("complete")

        proto = PythonProtocol(
            on_message_begin=on_message_begin,
            on_url=on_url,
            on_header=on_header,
            on_headers_complete=on_headers_complete,
            on_body=on_body,
            on_message_complete=on_message_complete,
        )
        proto.feed(
            b"POST / HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"Content-Length: 5\r\n"
            b"\r\n"
            b"hello"
        )

        expected_sequence = [
            "begin",
            ("url", b"/"),
            ("header", b"host", b"example.com"),
            ("header", b"content-length", b"5"),
            "headers_complete",
            ("body", b"hello"),
            "complete",
        ]
        for position, expected in enumerate(expected_sequence):
            assert events[position] == expected

    def test_skip_body_callback(self):
        """Returning True from ``on_headers_complete`` skips the body."""
        received = []

        def skip_body():
            return True

        def collect(chunk):
            received.append(chunk)

        proto = PythonProtocol(
            on_headers_complete=skip_body,
            on_body=collect,
        )
        proto.feed(
            b"POST / HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"Content-Length: 5\r\n"
            b"\r\n"
            b"hello"
        )

        # The message completes without any body chunk being delivered
        assert proto.is_complete is True
        assert len(received) == 0
|
||||
|
||||
|
||||
class TestCallbackRequest:
    """``CallbackRequest`` objects built from a parser that has consumed a request."""

    def test_from_parser_simple(self):
        """Method, path, query, URI, version and scheme are taken from the parser."""
        proto = PythonProtocol()
        proto.feed(b"GET /path?query=value HTTP/1.1\r\nHost: example.com\r\n\r\n")

        req = CallbackRequest.from_parser(proto)

        assert req.method == "GET"
        # The URI is split into path and query, and also kept whole
        assert req.path == "/path"
        assert req.query == "query=value"
        assert req.uri == "/path?query=value"
        assert req.version == (1, 1)
        assert req.scheme == "http"

    def test_from_parser_ssl(self):
        """``is_ssl=True`` yields the https scheme."""
        proto = PythonProtocol()
        proto.feed(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")

        req = CallbackRequest.from_parser(proto, is_ssl=True)

        assert req.scheme == "https"

    def test_from_parser_headers(self):
        """Headers are exposed both as upper-case strings and lower-case bytes."""
        proto = PythonProtocol()
        proto.feed(
            b"GET / HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"Content-Type: text/plain\r\n"
            b"\r\n"
        )

        req = CallbackRequest.from_parser(proto)

        # String headers (uppercase)
        for pair in (("HOST", "example.com"), ("CONTENT-TYPE", "text/plain")):
            assert pair in req.headers

        # Bytes headers (lowercase)
        for pair in ((b"host", b"example.com"), (b"content-type", b"text/plain")):
            assert pair in req.headers_bytes

    def test_from_parser_body_info(self):
        """Content-Length is carried over and the request is not chunked."""
        proto = PythonProtocol()
        proto.feed(
            b"POST / HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"Content-Length: 10\r\n"
            b"\r\n"
        )

        req = CallbackRequest.from_parser(proto)

        assert req.content_length == 10
        assert req.chunked is False

    def test_from_parser_chunked(self):
        """``Transfer-Encoding: chunked`` is reflected on the request."""
        proto = PythonProtocol()
        proto.feed(
            b"POST / HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"Transfer-Encoding: chunked\r\n"
            b"\r\n"
        )

        req = CallbackRequest.from_parser(proto)

        assert req.chunked is True

    def test_should_close(self):
        """``should_close()`` follows the Connection header and the HTTP version."""
        scenarios = [
            # HTTP/1.1 with Connection: close
            (b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n", True),
            # HTTP/1.1 keep-alive (default)
            (b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n", False),
            # HTTP/1.0 (default close)
            (b"GET / HTTP/1.0\r\n\r\n", True),
        ]

        for raw_request, expected in scenarios:
            proto = PythonProtocol()
            proto.feed(raw_request)
            req = CallbackRequest.from_parser(proto)
            assert req.should_close() is expected

    def test_get_header(self):
        """``get_header`` ignores the case of the name and returns None when absent."""
        proto = PythonProtocol()
        proto.feed(
            b"GET / HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"X-Custom: value\r\n"
            b"\r\n"
        )

        req = CallbackRequest.from_parser(proto)

        assert req.get_header("Host") == "example.com"
        # Same header looked up with two different spellings
        for spelling in ("x-custom", "X-CUSTOM"):
            assert req.get_header(spelling) == "value"
        assert req.get_header("X-Missing") is None

    def test_expect_100_continue(self):
        """``Expect: 100-continue`` sets the request's private flag."""
        proto = PythonProtocol()
        proto.feed(
            b"POST / HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"Expect: 100-continue\r\n"
            b"Content-Length: 10\r\n"
            b"\r\n"
        )

        req = CallbackRequest.from_parser(proto)

        assert req._expect_100_continue is True
|
||||
|
||||
|
||||
class TestPythonProtocolConnectionClose:
    """How ``should_keep_alive`` reacts to the version and Connection header."""

    def test_connection_close_header(self):
        """``Connection: close`` on HTTP/1.1 disables keep-alive."""
        proto = PythonProtocol()
        proto.feed(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n")

        keep_alive = proto.should_keep_alive
        assert keep_alive is False

    def test_connection_keepalive_header(self):
        """``Connection: keep-alive`` on HTTP/1.1 keeps the connection open."""
        proto = PythonProtocol()
        proto.feed(b"GET / HTTP/1.1\r\nConnection: keep-alive\r\n\r\n")

        keep_alive = proto.should_keep_alive
        assert keep_alive is True

    def test_http11_default_keepalive(self):
        """HTTP/1.1 without a Connection header keeps the connection open."""
        proto = PythonProtocol()
        proto.feed(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")

        keep_alive = proto.should_keep_alive
        assert keep_alive is True

    def test_http10_default_close(self):
        """HTTP/1.0 without a Connection header closes the connection."""
        proto = PythonProtocol()
        proto.feed(b"GET / HTTP/1.0\r\nHost: example.com\r\n\r\n")

        keep_alive = proto.should_keep_alive
        assert keep_alive is False
|
||||
@ -165,6 +165,10 @@ class TestSignalHandlingIntegration:
|
||||
proc.kill()
|
||||
pytest.fail("Gunicorn did not exit within timeout after SIGTERM")
|
||||
|
||||
@pytest.mark.skipif(
|
||||
hasattr(sys, 'pypy_version_info'),
|
||||
reason="SIGINT handling differs on PyPy, use SIGTERM test instead"
|
||||
)
|
||||
def test_graceful_shutdown_sigint(self, gunicorn_server):
|
||||
"""Verify SIGINT causes graceful shutdown."""
|
||||
proc, port = gunicorn_server
|
||||
|
||||
@ -13,13 +13,24 @@ dirname = os.path.dirname(__file__)
|
||||
reqdir = os.path.join(dirname, "requests", "valid")
|
||||
httpfiles = glob.glob(os.path.join(reqdir, "*.http"))
|
||||
|
||||
# Flags incompatible with fast parser
|
||||
_FAST_INCOMPATIBLE_FLAGS = ('permit_obsolete_folding', 'strip_header_spaces')
|
||||
|
||||
|
||||
@pytest.mark.parametrize("fname", httpfiles)
|
||||
def test_http_parser(fname):
|
||||
env = treq.load_py(os.path.splitext(fname)[0] + ".py")
|
||||
def test_http_parser(fname, http_parser):
|
||||
"""Test valid HTTP requests with both parser implementations."""
|
||||
env = treq.load_py(os.path.splitext(fname)[0] + ".py", http_parser=http_parser)
|
||||
|
||||
expect = env['request']
|
||||
cfg = env['cfg']
|
||||
|
||||
# Skip fast parser tests that use incompatible compatibility flags
|
||||
if http_parser == 'fast':
|
||||
for flag in _FAST_INCOMPATIBLE_FLAGS:
|
||||
if getattr(cfg, flag, False):
|
||||
pytest.skip(f"fast parser incompatible with {flag}")
|
||||
|
||||
req = treq.request(fname, expect)
|
||||
|
||||
for case in req.gen_cases(cfg):
|
||||
|
||||
@ -33,19 +33,26 @@ def uri(data):
|
||||
return ret
|
||||
|
||||
|
||||
def load_py(fname):
|
||||
def load_py(fname, http_parser='python'):
|
||||
"""Load test configuration from Python file.
|
||||
|
||||
Args:
|
||||
fname: Path to the .py configuration file
|
||||
http_parser: Parser to use - 'python' or 'fast'
|
||||
"""
|
||||
module_name = '__config__'
|
||||
mod = types.ModuleType(module_name)
|
||||
setattr(mod, 'uri', uri)
|
||||
setattr(mod, 'cfg', Config())
|
||||
loader = importlib.machinery.SourceFileLoader(module_name, fname)
|
||||
loader.exec_module(mod)
|
||||
# Set parser after loading so test-specific configs don't override
|
||||
mod.cfg.set('http_parser', http_parser)
|
||||
return vars(mod)
|
||||
|
||||
|
||||
def decode_hex_escapes(data):
|
||||
"""Decode hex escape sequences like \\xAB in test data."""
|
||||
import re
|
||||
result = bytearray()
|
||||
i = 0
|
||||
while i < len(data):
|
||||
|
||||
@ -1,416 +0,0 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
from unittest import mock
|
||||
|
||||
|
||||
def test_import():
    """Check that the eventlet worker module can be imported.

    An AttributeError raised while importing eventlet becomes a skip on
    Python 3.12.x and is re-raised on any other version.
    """
    try:
        import eventlet
    except AttributeError:
        running_312 = (3, 12) <= sys.version_info < (3, 13)
        if not running_312:
            raise
        pytest.skip("Ignoring eventlet failures on Python 3.12")
    __import__('gunicorn.workers.geventlet')
|
||||
|
||||
|
||||
class TestVersionRequirement:
    """Checks around the minimum eventlet version needed by the worker module."""

    def test_import_error_message(self):
        """With eventlet unimportable, the worker import raises RuntimeError
        whose message mentions "eventlet 0.40.3".
        """
        import importlib

        # A None entry in sys.modules makes ``import eventlet`` fail
        with mock.patch.dict('sys.modules', {'eventlet': None}):
            # Clear cached module if present
            sys.modules.pop('gunicorn.workers.geventlet', None)
            with pytest.raises(RuntimeError, match="eventlet 0.40.3"):
                import gunicorn.workers.geventlet
                importlib.reload(gunicorn.workers.geventlet)

    def test_version_check_requires_0_40_3(self):
        """The installed eventlet is at least version 0.40.3."""
        try:
            import eventlet
        except (ImportError, AttributeError):
            pytest.skip("eventlet not available")

        from packaging.version import parse as parse_version

        installed = parse_version(eventlet.__version__)
        required = parse_version('0.40.3')

        # If we got this far, the import succeeded, meaning version is sufficient
        assert installed >= required
|
||||
|
||||
|
||||
@pytest.fixture
def eventlet_worker():
    """Provide an ``EventletWorker`` with mocked config and logger.

    The test is skipped when eventlet cannot be imported. The worker is
    allocated with ``__new__`` so that ``__init__`` is not run, and only the
    attributes set below are populated.
    """
    try:
        import eventlet
    except (ImportError, AttributeError):
        pytest.skip("eventlet not available")

    from gunicorn.workers.geventlet import EventletWorker

    # Minimal mock config
    config = mock.MagicMock(
        keepalive=2,
        graceful_timeout=30,
        is_ssl=False,
        worker_connections=1000,
    )

    # Worker with mocked dependencies
    instance = EventletWorker.__new__(EventletWorker)
    instance.cfg = config
    instance.alive = True
    instance.sockets = []
    instance.log = mock.MagicMock()
    return instance
|
||||
|
||||
|
||||
class TestEventletWorker:
|
||||
"""Tests for EventletWorker class."""
|
||||
|
||||
def test_worker_class_exists(self):
|
||||
"""Test that EventletWorker class is properly defined."""
|
||||
try:
|
||||
import eventlet
|
||||
except (ImportError, AttributeError):
|
||||
pytest.skip("eventlet not available")
|
||||
|
||||
from gunicorn.workers.geventlet import EventletWorker
|
||||
from gunicorn.workers.base_async import AsyncWorker
|
||||
|
||||
assert issubclass(EventletWorker, AsyncWorker)
|
||||
|
||||
def test_patch_method_calls_use_hub(self, eventlet_worker):
|
||||
"""Test that patch() calls hubs.use_hub().
|
||||
|
||||
hubs.use_hub() must be called in patch() (after fork) because it creates
|
||||
OS resources like kqueue that don't survive fork.
|
||||
"""
|
||||
from eventlet import hubs
|
||||
|
||||
with mock.patch.object(hubs, 'use_hub') as mock_use_hub:
|
||||
with mock.patch('gunicorn.workers.geventlet.patch_sendfile'):
|
||||
eventlet_worker.patch()
|
||||
|
||||
mock_use_hub.assert_called_once()
|
||||
|
||||
def test_patch_method_calls_patch_sendfile(self, eventlet_worker):
|
||||
"""Test that patch() calls patch_sendfile()."""
|
||||
from eventlet import hubs
|
||||
|
||||
with mock.patch.object(hubs, 'use_hub'):
|
||||
with mock.patch('gunicorn.workers.geventlet.patch_sendfile') as mock_sf:
|
||||
eventlet_worker.patch()
|
||||
|
||||
mock_sf.assert_called_once()
|
||||
|
||||
def test_monkey_patch_called_at_import_time(self):
|
||||
"""Test that monkey_patch is called at module import time.
|
||||
|
||||
Note: hubs.use_hub() and eventlet.monkey_patch() are called at module
|
||||
import time (not in patch()) to ensure all imports are properly patched.
|
||||
This test verifies the module was patched by checking eventlet state.
|
||||
"""
|
||||
try:
|
||||
import eventlet
|
||||
except (ImportError, AttributeError):
|
||||
pytest.skip("eventlet not available")
|
||||
|
||||
# Verify eventlet has been patched by checking that socket is patched
|
||||
import socket
|
||||
from eventlet.greenio import GreenSocket
|
||||
|
||||
# After monkey patching, socket.socket should be GreenSocket
|
||||
assert socket.socket is GreenSocket
|
||||
|
||||
def test_timeout_ctx_returns_eventlet_timeout(self, eventlet_worker):
    """Test that timeout_ctx() returns an eventlet.Timeout.

    The returned timeout is cancelled before the test ends so that its
    timer cannot fire while a later test is running.
    """
    import eventlet

    timeout = eventlet_worker.timeout_ctx()
    try:
        assert isinstance(timeout, eventlet.Timeout)
    finally:
        # eventlet.Timeout starts its timer on construction; since it is
        # not used as a context manager here, it must be cancelled by hand.
        if isinstance(timeout, eventlet.Timeout):
            timeout.cancel()
|
||||
|
||||
def test_timeout_ctx_uses_keepalive_config(self, eventlet_worker):
    """Check that ``timeout_ctx()`` passes ``cfg.keepalive`` to ``eventlet.Timeout``."""
    import eventlet

    keepalive_seconds = 5
    eventlet_worker.cfg.keepalive = keepalive_seconds

    timeout_patch = mock.patch.object(eventlet, 'Timeout')
    with timeout_patch as timeout_cls:
        eventlet_worker.timeout_ctx()

    # Second positional argument is expected to be False.
    timeout_cls.assert_called_once_with(keepalive_seconds, False)
|
||||
|
||||
def test_timeout_ctx_with_no_keepalive(self, eventlet_worker):
    """Test that timeout_ctx() handles no keepalive (None or 0).

    Both "disabled" values are exercised; in each case eventlet.Timeout
    is expected to be constructed with ``None`` as its first argument.
    """
    import eventlet

    for disabled_value in (0, None):
        eventlet_worker.cfg.keepalive = disabled_value
        # A fresh mock per value keeps assert_called_once_with meaningful.
        with mock.patch.object(eventlet, 'Timeout') as mock_timeout:
            eventlet_worker.timeout_ctx()

        mock_timeout.assert_called_once_with(None, False)
|
||||
|
||||
def test_handle_quit_spawns_greenthread(self, eventlet_worker):
    """Check that ``handle_quit()`` calls ``eventlet.spawn`` exactly once."""
    import eventlet

    spawn_patch = mock.patch.object(eventlet, 'spawn')
    with spawn_patch as spawn_spy:
        # Both handler arguments are passed as None.
        eventlet_worker.handle_quit(None, None)

    spawn_spy.assert_called_once()
|
||||
|
||||
def test_handle_usr1_spawns_greenthread(self, eventlet_worker):
    """Check that ``handle_usr1()`` calls ``eventlet.spawn`` exactly once."""
    import eventlet

    spawn_patch = mock.patch.object(eventlet, 'spawn')
    with spawn_patch as spawn_spy:
        # Both handler arguments are passed as None.
        eventlet_worker.handle_usr1(None, None)

    spawn_spy.assert_called_once()
|
||||
|
||||
def test_handle_wraps_ssl_when_configured(self, eventlet_worker):
    """Check that ``handle()`` calls ``ssl_wrap_socket`` when ``cfg.is_ssl`` is True."""
    from gunicorn.workers import geventlet

    eventlet_worker.cfg.is_ssl = True
    client_sock = mock.MagicMock()
    listener_sock = mock.MagicMock()
    peer_addr = ('127.0.0.1', 8000)

    with mock.patch.object(geventlet, 'ssl_wrap_socket') as wrap_spy:
        # The wrapper hands back the same client mock.
        wrap_spy.return_value = client_sock
        # The parent handle() is stubbed so only the SSL wrapping is observed.
        with mock.patch('gunicorn.workers.base_async.AsyncWorker.handle'):
            eventlet_worker.handle(listener_sock, client_sock, peer_addr)

    wrap_spy.assert_called_once_with(client_sock, eventlet_worker.cfg)
|
||||
|
||||
def test_handle_no_ssl_when_not_configured(self, eventlet_worker):
    """Check that ``handle()`` never calls ``ssl_wrap_socket`` when ``cfg.is_ssl`` is False."""
    from gunicorn.workers import geventlet

    eventlet_worker.cfg.is_ssl = False
    client_sock = mock.MagicMock()
    listener_sock = mock.MagicMock()
    peer_addr = ('127.0.0.1', 8000)

    wrap_patch = mock.patch.object(geventlet, 'ssl_wrap_socket')
    # The parent handle() is stubbed so only the SSL decision is observed.
    parent_stub = mock.patch('gunicorn.workers.base_async.AsyncWorker.handle')
    with wrap_patch as wrap_spy, parent_stub:
        eventlet_worker.handle(listener_sock, client_sock, peer_addr)

    wrap_spy.assert_not_called()
|
||||
|
||||
|
||||
class TestAlreadyHandled:
    """Tests for is_already_handled() method."""

    def test_is_already_handled_new_style(self, eventlet_worker):
        """StopIteration is raised when WSGI_LOCAL.already_handled is set (eventlet >= 0.30.3)."""
        from gunicorn.workers import geventlet

        # Stand-in for the new-style WSGI_LOCAL object with the flag set.
        wsgi_local = mock.MagicMock()
        wsgi_local.already_handled = True

        local_patch = mock.patch.object(geventlet, 'EVENTLET_WSGI_LOCAL', wsgi_local)
        with local_patch, pytest.raises(StopIteration):
            eventlet_worker.is_already_handled(mock.MagicMock())

    def test_is_already_handled_old_style(self, eventlet_worker):
        """StopIteration is raised when the response is ALREADY_HANDLED (eventlet < 0.30.3)."""
        from gunicorn.workers import geventlet

        marker = object()

        # No WSGI_LOCAL available, so the old-style marker comparison is used.
        no_local = mock.patch.object(geventlet, 'EVENTLET_WSGI_LOCAL', None)
        old_marker = mock.patch.object(geventlet, 'EVENTLET_ALREADY_HANDLED', marker)
        with no_local, old_marker, pytest.raises(StopIteration):
            eventlet_worker.is_already_handled(marker)

    def test_is_already_handled_returns_parent_result(self, eventlet_worker):
        """The parent implementation's result is returned when nothing was handled."""
        from gunicorn.workers import geventlet

        no_local = mock.patch.object(geventlet, 'EVENTLET_WSGI_LOCAL', None)
        no_marker = mock.patch.object(geventlet, 'EVENTLET_ALREADY_HANDLED', None)
        parent_patch = mock.patch('gunicorn.workers.base_async.AsyncWorker.is_already_handled')
        with no_local, no_marker, parent_patch as parent_spy:
            parent_spy.return_value = False
            outcome = eventlet_worker.is_already_handled(mock.MagicMock())

            assert outcome is False
            parent_spy.assert_called_once()
|
||||
|
||||
|
||||
class TestPatchSendfile:
    """Tests for patch_sendfile() function."""

    def test_patch_sendfile_adds_method_when_missing(self):
        """Test that patch_sendfile adds sendfile to GreenSocket if missing."""
        try:
            # Imported only to decide whether the test has to be skipped.
            import eventlet  # noqa: F401
        except (ImportError, AttributeError):
            pytest.skip("eventlet not available")

        from gunicorn.workers.geventlet import patch_sendfile, _eventlet_socket_sendfile
        from eventlet.greenio import GreenSocket

        # Remove sendfile if it exists, remembering it for the cleanup below
        original = getattr(GreenSocket, 'sendfile', None)
        if hasattr(GreenSocket, 'sendfile'):
            delattr(GreenSocket, 'sendfile')

        try:
            patch_sendfile()
            assert hasattr(GreenSocket, 'sendfile')
            assert GreenSocket.sendfile == _eventlet_socket_sendfile
        finally:
            # Restore original state so other tests see an untouched class
            if original is not None:
                GreenSocket.sendfile = original
            elif hasattr(GreenSocket, 'sendfile'):
                delattr(GreenSocket, 'sendfile')

    def test_patch_sendfile_preserves_existing_method(self):
        """Test that patch_sendfile does not override existing sendfile.

        A sentinel is installed as ``GreenSocket.sendfile`` for the duration
        of the test, so the assertion always runs regardless of whether the
        installed eventlet provides ``sendfile`` itself.
        """
        try:
            # Imported only to decide whether the test has to be skipped.
            import eventlet  # noqa: F401
        except (ImportError, AttributeError):
            pytest.skip("eventlet not available")

        from gunicorn.workers.geventlet import patch_sendfile
        from eventlet.greenio import GreenSocket

        existing = mock.sentinel.existing_sendfile

        # create=True lets the patch add the attribute when it is absent;
        # the patcher restores (or removes) it again on exit.
        with mock.patch.object(GreenSocket, 'sendfile', existing, create=True):
            patch_sendfile()
            assert GreenSocket.sendfile is existing
|
||||
|
||||
|
||||
class TestEventletSocketSendfile:
    """Tests for _eventlet_socket_sendfile() function."""

    def test_sendfile_raises_on_non_blocking(self):
        """A socket whose timeout is 0 is rejected with ValueError."""
        try:
            import eventlet
        except (ImportError, AttributeError):
            pytest.skip("eventlet not available")

        from gunicorn.workers.geventlet import _eventlet_socket_sendfile

        # gettimeout() == 0 is how the mock presents itself as non-blocking.
        fake_sock = mock.MagicMock()
        fake_sock.gettimeout.return_value = 0
        fake_file = mock.MagicMock()

        with pytest.raises(ValueError, match="non-blocking"):
            _eventlet_socket_sendfile(fake_sock, fake_file)

    def test_sendfile_seeks_to_offset(self):
        """The file is seeked to the requested offset."""
        try:
            import eventlet
        except (ImportError, AttributeError):
            pytest.skip("eventlet not available")

        from gunicorn.workers.geventlet import _eventlet_socket_sendfile

        start_at = 100

        fake_sock = mock.MagicMock()
        fake_sock.gettimeout.return_value = 1

        # An empty read makes the file look exhausted immediately.
        fake_file = mock.MagicMock()
        fake_file.read.return_value = b''

        _eventlet_socket_sendfile(fake_sock, fake_file, offset=start_at)

        fake_file.seek.assert_any_call(start_at)

    def test_sendfile_returns_total_sent(self):
        """The return value is the number of bytes sent."""
        try:
            import eventlet
        except (ImportError, AttributeError):
            pytest.skip("eventlet not available")

        from gunicorn.workers.geventlet import _eventlet_socket_sendfile

        payload_size = 10

        fake_sock = mock.MagicMock()
        fake_sock.gettimeout.return_value = 1
        fake_sock.send.return_value = payload_size

        # One chunk of data followed by an empty read.
        fake_file = mock.MagicMock()
        fake_file.read.side_effect = [b'x' * payload_size, b'']

        sent = _eventlet_socket_sendfile(fake_sock, fake_file)

        assert sent == payload_size
|
||||
|
||||
|
||||
class TestEventletServe:
    """Tests for _eventlet_serve() function."""

    def test_serve_creates_green_pool(self):
        """``_eventlet_serve`` builds a GreenPool with the requested concurrency."""
        try:
            import eventlet
        except (ImportError, AttributeError):
            pytest.skip("eventlet not available")

        from gunicorn.workers.geventlet import _eventlet_serve

        concurrency = 100

        # accept() raises StopServe on the first call.
        listener = mock.MagicMock()
        listener.accept.side_effect = eventlet.StopServe()

        with mock.patch.object(eventlet.greenpool, 'GreenPool') as pool_cls:
            pool = mock.MagicMock()
            pool.waitall.return_value = None
            pool_cls.return_value = pool

            _eventlet_serve(listener, mock.MagicMock(), concurrency)

        pool_cls.assert_called_once_with(concurrency)
|
||||
|
||||
|
||||
class TestEventletStop:
    """Tests for _eventlet_stop() function."""

    def test_stop_waits_for_client(self):
        """``_eventlet_stop`` waits on the client and then closes the connection."""
        try:
            import eventlet
        except (ImportError, AttributeError):
            pytest.skip("eventlet not available")

        from gunicorn.workers.geventlet import _eventlet_stop

        client_gt, server_gt, connection = (mock.MagicMock() for _ in range(3))

        _eventlet_stop(client_gt, server_gt, connection)

        client_gt.wait.assert_called_once()
        connection.close.assert_called_once()

    def test_stop_closes_connection_on_greenlet_exit(self):
        """The connection is still closed when ``wait()`` raises GreenletExit."""
        try:
            import eventlet
            import greenlet
        except (ImportError, AttributeError):
            pytest.skip("eventlet not available")

        from gunicorn.workers.geventlet import _eventlet_stop

        client_gt = mock.MagicMock()
        client_gt.wait.side_effect = greenlet.GreenletExit()
        server_gt = mock.MagicMock()
        connection = mock.MagicMock()

        # The call itself must not propagate GreenletExit.
        _eventlet_stop(client_gt, server_gt, connection)

        connection.close.assert_called_once()
|
||||
Loading…
x
Reference in New Issue
Block a user