mirror of
https://github.com/frappe/gunicorn.git
synced 2026-07-04 11:41:32 +08:00
Optimize ASGI performance with fast parser integration
Wire HttpParser to ASGI hot path, replacing AsyncRequest.parse() with direct buffer-based parsing. Add FastAsyncRequest wrapper for body reading. Replace per-request Queue/Task with BodyReceiver for on-demand body reading. Keep headers as bytes end-to-end to avoid conversion overhead. Add backpressure control and keepalive timer. Cache response status lines and Date header. Benchmark shows 3x improvement: ~875K req/s for simple GET (was ~340K).
This commit is contained in:
parent
d89564b83c
commit
fa967743c0
@ -1971,3 +1971,21 @@ need to increase this value.
|
||||
This setting only affects the ``asgi`` worker type.
|
||||
|
||||
!!! info "Added in 25.0.0"
|
||||
|
||||
### `http_parser`
|
||||
|
||||
**Command line:** `--http-parser STRING`
|
||||
|
||||
**Default:** `'auto'`
|
||||
|
||||
HTTP parser implementation.
|
||||
|
||||
- auto: Use gunicorn_h1c if available, otherwise pure Python (default)
|
||||
- fast: Require gunicorn_h1c C extension (fail if unavailable)
|
||||
- python: Force pure Python parser
|
||||
|
||||
The gunicorn_h1c C extension provides significantly faster HTTP
|
||||
parsing using picohttpparser with SIMD optimizations. Install it
|
||||
with: pip install gunicorn[fast]
|
||||
|
||||
!!! info "Added in 25.0.0"
|
||||
|
||||
@ -52,11 +52,16 @@ def _ip_in_allow_list(ip_str, allow_list, networks):
|
||||
|
||||
|
||||
class ParseResult:
|
||||
"""Result of header parsing."""
|
||||
"""Result of header parsing.
|
||||
|
||||
Headers are stored as bytes tuples for performance:
|
||||
- headers_bytes: list of (name_bytes_lowercase, value_bytes)
|
||||
- headers: list of (name_str_uppercase, value_str) for compatibility
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
'method', 'uri', 'path', 'query', 'fragment', 'version',
|
||||
'headers', 'scheme', 'content_length', 'chunked',
|
||||
'headers', 'headers_bytes', 'scheme', 'content_length', 'chunked',
|
||||
'keep_alive', 'consumed', 'proxy_protocol_info',
|
||||
'must_close', 'expect_100_continue',
|
||||
)
|
||||
@ -68,7 +73,8 @@ class ParseResult:
|
||||
self.query = None
|
||||
self.fragment = None
|
||||
self.version = None
|
||||
self.headers = []
|
||||
self.headers = [] # (name_str_uppercase, value_str) for compatibility
|
||||
self.headers_bytes = [] # (name_bytes_lowercase, value_bytes) for ASGI scope
|
||||
self.scheme = "http"
|
||||
self.content_length = 0
|
||||
self.chunked = False
|
||||
@ -164,15 +170,25 @@ class HttpParser:
|
||||
return self._feed_python(buffer)
|
||||
|
||||
def _feed_fast(self, buffer):
|
||||
"""Parse using fast C parser."""
|
||||
try:
|
||||
result = HttpParser._h1c_module.parse_request(bytes(buffer))
|
||||
"""Parse using fast C parser with optimized API.
|
||||
|
||||
# gunicorn_h1c returns bytes, convert to strings (latin-1)
|
||||
Uses parse_request_fast() which:
|
||||
- Accepts bytearray directly (no bytes() copy)
|
||||
- Returns pre-computed content_length, has_chunked, connection_close
|
||||
- Returns headers as bytes tuples (no intermediate conversion)
|
||||
"""
|
||||
h1c = HttpParser._h1c_module
|
||||
try:
|
||||
# Use parse_request_fast - accepts bytearray directly
|
||||
req = h1c.parse_request_fast(buffer)
|
||||
|
||||
# Build ParseResult from fast request object
|
||||
pr = ParseResult()
|
||||
pr.method = bytes_to_str(result['method'])
|
||||
# gunicorn_h1c returns 'path' which is the full URI (path?query)
|
||||
pr.uri = bytes_to_str(result['path'])
|
||||
|
||||
# Method and path (bytes -> str)
|
||||
pr.method = bytes_to_str(req.method)
|
||||
pr.uri = bytes_to_str(req.path)
|
||||
|
||||
# Parse path/query from URI
|
||||
try:
|
||||
parts = split_request_uri(pr.uri)
|
||||
@ -183,25 +199,52 @@ class HttpParser:
|
||||
pr.path = pr.uri
|
||||
pr.query = ""
|
||||
pr.fragment = ""
|
||||
pr.version = (1, result['minor_version'])
|
||||
|
||||
# Headers - convert to uppercase strings
|
||||
pr.headers = [(bytes_to_str(n).upper(), bytes_to_str(v)) for n, v in result['headers']]
|
||||
pr.version = (1, req.minor_version)
|
||||
pr.consumed = req.consumed
|
||||
|
||||
# Headers - store both bytes (for ASGI scope) and strings (for compatibility)
|
||||
# gunicorn_h1c returns headers as (name_bytes, value_bytes)
|
||||
headers_bytes = []
|
||||
headers_str = []
|
||||
for n, v in req.headers:
|
||||
# ASGI requires lowercase header names
|
||||
headers_bytes.append((n.lower(), v))
|
||||
# Compatibility: uppercase string names
|
||||
headers_str.append((bytes_to_str(n).upper(), bytes_to_str(v)))
|
||||
pr.headers_bytes = headers_bytes
|
||||
pr.headers = headers_str
|
||||
|
||||
# Use pre-computed body info from C parser
|
||||
pr.content_length = req.content_length if req.content_length >= 0 else 0
|
||||
pr.chunked = req.has_chunked
|
||||
|
||||
# connection_close: -1 = not set, 0 = keep-alive, 1 = close
|
||||
if req.connection_close == 1:
|
||||
pr.must_close = True
|
||||
pr.keep_alive = False
|
||||
elif req.connection_close == 0:
|
||||
pr.must_close = False
|
||||
pr.keep_alive = True
|
||||
else:
|
||||
# Not set - default based on HTTP version
|
||||
pr.keep_alive = req.minor_version >= 1
|
||||
pr.must_close = False
|
||||
|
||||
pr.consumed = result['consumed']
|
||||
pr.keep_alive = result['minor_version'] >= 1
|
||||
pr.scheme = "https" if self.is_ssl else "http"
|
||||
|
||||
# Parse body info from headers
|
||||
self._parse_body_info(pr)
|
||||
# Apply scheme headers for trusted proxies
|
||||
if self._is_trusted_proxy:
|
||||
self._apply_scheme_headers(pr)
|
||||
|
||||
self._result = pr
|
||||
return pr
|
||||
|
||||
except Exception as e:
|
||||
if "incomplete" in str(e).lower():
|
||||
return None
|
||||
raise
|
||||
except h1c.IncompleteError:
|
||||
return None
|
||||
except h1c.ParseError as e:
|
||||
# Map to gunicorn HTTP errors
|
||||
raise InvalidRequestLine(str(e))
|
||||
|
||||
def _feed_python(self, buffer):
|
||||
"""Parse using pure Python parser."""
|
||||
@ -263,9 +306,15 @@ class HttpParser:
|
||||
if buffer[headers_start:headers_start + 2] == b"\r\n":
|
||||
# Empty headers
|
||||
pr.consumed = headers_start + 2
|
||||
pr.headers_bytes = []
|
||||
else:
|
||||
headers_data = bytes(buffer[headers_start:headers_end])
|
||||
pr.headers = self._parse_headers(headers_data)
|
||||
# Also generate bytes headers for ASGI scope
|
||||
pr.headers_bytes = [
|
||||
(n.lower().encode('latin-1'), v.encode('latin-1'))
|
||||
for n, v in pr.headers
|
||||
]
|
||||
pr.consumed = headers_end + 4
|
||||
|
||||
# Set scheme
|
||||
@ -637,3 +686,238 @@ class HttpParser:
|
||||
"""Reset parser state for next request on keep-alive connection."""
|
||||
self._result = None
|
||||
self.req_number += 1
|
||||
|
||||
|
||||
class FastAsyncRequest:
|
||||
"""Fast async HTTP request wrapper.
|
||||
|
||||
Wraps a ParseResult from HttpParser and provides async body reading.
|
||||
This is a lightweight adapter that allows protocol.py to use the fast
|
||||
parser while maintaining compatibility with the existing interface.
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
'method', 'uri', 'path', 'query', 'fragment', 'version',
|
||||
'headers', 'headers_bytes', 'scheme', 'content_length', 'chunked',
|
||||
'must_close', 'proxy_protocol_info',
|
||||
'_reader', '_buffer', '_body_remaining', '_body_reader',
|
||||
'_expect_100_continue',
|
||||
)
|
||||
|
||||
def __init__(self, parse_result, reader, buffer, consumed):
|
||||
"""Initialize from a ParseResult.
|
||||
|
||||
Args:
|
||||
parse_result: ParseResult from HttpParser.feed()
|
||||
reader: asyncio.StreamReader for body reading
|
||||
buffer: bytearray buffer with remaining data after headers
|
||||
consumed: bytes consumed from buffer by parser
|
||||
"""
|
||||
# Copy attributes from ParseResult
|
||||
self.method = parse_result.method
|
||||
self.uri = parse_result.uri
|
||||
self.path = parse_result.path
|
||||
self.query = parse_result.query
|
||||
self.fragment = parse_result.fragment
|
||||
self.version = parse_result.version
|
||||
self.headers = parse_result.headers
|
||||
self.headers_bytes = parse_result.headers_bytes # Pre-computed bytes headers
|
||||
self.scheme = parse_result.scheme
|
||||
self.content_length = parse_result.content_length
|
||||
self.chunked = parse_result.chunked
|
||||
self.must_close = parse_result.must_close
|
||||
self.proxy_protocol_info = parse_result.proxy_protocol_info
|
||||
self._expect_100_continue = parse_result.expect_100_continue
|
||||
|
||||
# Body reading state
|
||||
self._reader = reader
|
||||
# Keep remaining data after headers in buffer
|
||||
self._buffer = bytearray(buffer[consumed:])
|
||||
if self.chunked:
|
||||
self._body_remaining = -1
|
||||
elif self.content_length:
|
||||
self._body_remaining = self.content_length
|
||||
else:
|
||||
self._body_remaining = 0
|
||||
self._body_reader = None
|
||||
|
||||
def should_close(self):
|
||||
"""Check if connection should be closed after this request."""
|
||||
if self.must_close:
|
||||
return True
|
||||
for name, value in self.headers:
|
||||
if name == "CONNECTION":
|
||||
v = value.lower().strip(" \t")
|
||||
if v == "close":
|
||||
return True
|
||||
elif v == "keep-alive":
|
||||
return False
|
||||
break
|
||||
return self.version <= (1, 0)
|
||||
|
||||
def get_header(self, name):
|
||||
"""Get a header value by name (case-insensitive)."""
|
||||
name = name.upper()
|
||||
for h, v in self.headers:
|
||||
if h == name:
|
||||
return v
|
||||
return None
|
||||
|
||||
async def read_body(self, size=8192):
|
||||
"""Read a chunk of the request body.
|
||||
|
||||
Args:
|
||||
size: Maximum bytes to read
|
||||
|
||||
Returns:
|
||||
bytes: Body data, empty bytes when body is exhausted
|
||||
"""
|
||||
if self._body_remaining == 0:
|
||||
return b""
|
||||
|
||||
if self.chunked:
|
||||
return await self._read_chunked_body(size)
|
||||
else:
|
||||
return await self._read_length_body(size)
|
||||
|
||||
async def _read_length_body(self, size):
|
||||
"""Read from a length-delimited body."""
|
||||
if self._body_remaining <= 0:
|
||||
return b""
|
||||
|
||||
to_read = min(size, self._body_remaining)
|
||||
|
||||
# First, use data from our buffer
|
||||
if self._buffer:
|
||||
if len(self._buffer) <= to_read:
|
||||
data = bytes(self._buffer)
|
||||
self._buffer.clear()
|
||||
else:
|
||||
data = bytes(self._buffer[:to_read])
|
||||
del self._buffer[:to_read]
|
||||
self._body_remaining -= len(data)
|
||||
return data
|
||||
|
||||
# Read from stream
|
||||
try:
|
||||
data = await self._reader.read(to_read)
|
||||
if data:
|
||||
self._body_remaining -= len(data)
|
||||
return data
|
||||
except Exception:
|
||||
return b""
|
||||
|
||||
async def _read_chunked_body(self, size):
|
||||
"""Read from a chunked body."""
|
||||
if self._body_reader is None:
|
||||
self._body_reader = self._chunked_body_reader()
|
||||
|
||||
try:
|
||||
return await anext(self._body_reader)
|
||||
except StopAsyncIteration:
|
||||
self._body_remaining = 0
|
||||
return b""
|
||||
|
||||
async def _chunked_body_reader(self):
|
||||
"""Async generator for reading chunked body."""
|
||||
while True:
|
||||
# Read chunk size line
|
||||
size_line = await self._read_until_crlf()
|
||||
# Parse chunk size (handle extensions)
|
||||
chunk_size, *_ = size_line.split(b";", 1)
|
||||
if _:
|
||||
chunk_size = chunk_size.rstrip(b" \t")
|
||||
|
||||
if any(n not in b"0123456789abcdefABCDEF" for n in chunk_size):
|
||||
raise InvalidHeader("Invalid chunk size")
|
||||
if len(chunk_size) == 0:
|
||||
raise InvalidHeader("Invalid chunk size")
|
||||
|
||||
chunk_size = int(chunk_size, 16)
|
||||
|
||||
if chunk_size == 0:
|
||||
# Final chunk - skip trailers and final CRLF
|
||||
await self._skip_trailers()
|
||||
return
|
||||
|
||||
# Read chunk data
|
||||
remaining = chunk_size
|
||||
while remaining > 0:
|
||||
data = await self._read_data(min(remaining, 8192))
|
||||
if not data:
|
||||
raise NoMoreData()
|
||||
remaining -= len(data)
|
||||
yield data
|
||||
|
||||
# Skip chunk terminating CRLF
|
||||
crlf = await self._read_data(2)
|
||||
if crlf != b"\r\n":
|
||||
# May have partial read
|
||||
while len(crlf) < 2:
|
||||
more = await self._read_data(2 - len(crlf))
|
||||
if not more:
|
||||
break
|
||||
crlf += more
|
||||
|
||||
async def _read_data(self, size):
|
||||
"""Read data from buffer or stream."""
|
||||
if self._buffer:
|
||||
if len(self._buffer) <= size:
|
||||
data = bytes(self._buffer)
|
||||
self._buffer.clear()
|
||||
return data
|
||||
else:
|
||||
data = bytes(self._buffer[:size])
|
||||
del self._buffer[:size]
|
||||
return data
|
||||
try:
|
||||
return await self._reader.read(size)
|
||||
except Exception:
|
||||
return b""
|
||||
|
||||
async def _read_until_crlf(self):
|
||||
"""Read bytes until CRLF."""
|
||||
result = bytearray()
|
||||
while True:
|
||||
# Check buffer first
|
||||
if self._buffer:
|
||||
idx = self._buffer.find(b"\r\n")
|
||||
if idx >= 0:
|
||||
result.extend(self._buffer[:idx])
|
||||
del self._buffer[:idx + 2]
|
||||
return bytes(result)
|
||||
result.extend(self._buffer)
|
||||
self._buffer.clear()
|
||||
|
||||
# Read more data
|
||||
try:
|
||||
data = await self._reader.read(64)
|
||||
except Exception:
|
||||
break
|
||||
if not data:
|
||||
break
|
||||
idx = data.find(b"\r\n")
|
||||
if idx >= 0:
|
||||
result.extend(data[:idx])
|
||||
# Put remaining data back in buffer
|
||||
remaining = data[idx + 2:]
|
||||
if remaining:
|
||||
self._buffer.extend(remaining)
|
||||
return bytes(result)
|
||||
result.extend(data)
|
||||
|
||||
return bytes(result)
|
||||
|
||||
async def _skip_trailers(self):
|
||||
"""Skip trailer headers after chunked body."""
|
||||
while True:
|
||||
line = await self._read_until_crlf()
|
||||
if not line:
|
||||
return
|
||||
|
||||
async def drain_body(self):
|
||||
"""Drain any unread body data."""
|
||||
while True:
|
||||
data = await self.read_body(8192)
|
||||
if not data:
|
||||
break
|
||||
|
||||
@ -11,10 +11,11 @@ and dispatch to ASGI applications.
|
||||
|
||||
import asyncio
|
||||
import errno
|
||||
import ipaddress
|
||||
from datetime import datetime
|
||||
|
||||
from gunicorn.asgi.unreader import AsyncUnreader
|
||||
from gunicorn.asgi.message import AsyncRequest
|
||||
from gunicorn.asgi.parser import HttpParser, FastAsyncRequest
|
||||
from gunicorn.asgi.uwsgi import AsyncUWSGIRequest
|
||||
from gunicorn.http.errors import NoMoreData
|
||||
from gunicorn.uwsgi.errors import UWSGIParseException
|
||||
@ -30,6 +31,56 @@ def _normalize_sockaddr(sockaddr):
|
||||
return tuple(sockaddr[:2]) if sockaddr else None
|
||||
|
||||
|
||||
def _check_trusted_proxy(peer_addr, allow_list, networks):
|
||||
"""Check if peer address is in the trusted proxy list.
|
||||
|
||||
Cached at connection start to avoid repeated IP parsing per request.
|
||||
"""
|
||||
if not isinstance(peer_addr, tuple):
|
||||
return False
|
||||
if '*' in allow_list:
|
||||
return True
|
||||
try:
|
||||
ip = ipaddress.ip_address(peer_addr[0])
|
||||
except ValueError:
|
||||
return False
|
||||
for network in networks:
|
||||
if ip in network:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# Cached response bytes for common cases
|
||||
_CACHED_STATUS_LINES = {}
|
||||
_CACHED_SERVER_HEADER = b"Server: gunicorn/asgi\r\n"
|
||||
|
||||
# Date header cache (updated once per second)
|
||||
_cached_date_header = b""
|
||||
_cached_date_time = 0.0
|
||||
|
||||
|
||||
def _get_cached_date_header():
|
||||
"""Get cached Date header, updating once per second."""
|
||||
global _cached_date_header, _cached_date_time # pylint: disable=global-statement
|
||||
import time
|
||||
now = time.time()
|
||||
if now - _cached_date_time >= 1.0:
|
||||
# Update date header
|
||||
from email.utils import formatdate
|
||||
_cached_date_header = f"Date: {formatdate(usegmt=True)}\r\n".encode("latin-1")
|
||||
_cached_date_time = now
|
||||
return _cached_date_header
|
||||
|
||||
|
||||
def _get_cached_status_line(version, status, reason):
|
||||
"""Get cached status line bytes."""
|
||||
key = (version, status)
|
||||
if key not in _CACHED_STATUS_LINES:
|
||||
line = f"HTTP/{version[0]}.{version[1]} {status} {reason}\r\n"
|
||||
_CACHED_STATUS_LINES[key] = line.encode("latin-1")
|
||||
return _CACHED_STATUS_LINES[key]
|
||||
|
||||
|
||||
class ASGIResponseInfo:
|
||||
"""Simple container for ASGI response info for access logging."""
|
||||
|
||||
@ -46,6 +97,96 @@ class ASGIResponseInfo:
|
||||
self.headers.append((name, value))
|
||||
|
||||
|
||||
class BodyReceiver:
|
||||
"""Lightweight body receiver that reads directly on demand.
|
||||
|
||||
Replaces per-request Queue and Task with direct on-demand reading.
|
||||
This reduces allocations and improves performance for most requests
|
||||
where body is read sequentially.
|
||||
"""
|
||||
|
||||
__slots__ = ('request', 'protocol', 'body_complete', '_disconnect_event')
|
||||
|
||||
def __init__(self, request, protocol):
|
||||
self.request = request
|
||||
self.protocol = protocol
|
||||
self.body_complete = False
|
||||
self._disconnect_event = asyncio.Event()
|
||||
|
||||
def signal_disconnect(self):
|
||||
"""Signal that connection has been lost."""
|
||||
self._disconnect_event.set()
|
||||
|
||||
async def receive(self): # pylint: disable=too-many-return-statements
|
||||
"""ASGI receive callable - reads body on demand."""
|
||||
# Already finished body - return disconnect
|
||||
if self.body_complete:
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
# No body expected - must return body message before disconnect
|
||||
if self.request.content_length == 0 and not self.request.chunked:
|
||||
self.body_complete = True
|
||||
return {
|
||||
"type": "http.request",
|
||||
"body": b"",
|
||||
"more_body": False,
|
||||
}
|
||||
|
||||
# Check for disconnect before reading (only when body hasn't been returned)
|
||||
if self.protocol._closed:
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
# Read body chunk directly (no intermediate Queue)
|
||||
try:
|
||||
# Create tasks for reading body and waiting for disconnect
|
||||
read_task = asyncio.create_task(self.request.read_body(65536))
|
||||
disconnect_task = asyncio.create_task(self._disconnect_event.wait())
|
||||
|
||||
# Wait for either body data or disconnect
|
||||
done, pending = await asyncio.wait(
|
||||
[read_task, disconnect_task],
|
||||
return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
|
||||
# Cancel pending tasks
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Check what completed
|
||||
if disconnect_task in done:
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
chunk = read_task.result()
|
||||
|
||||
if chunk:
|
||||
return {
|
||||
"type": "http.request",
|
||||
"body": chunk,
|
||||
"more_body": True,
|
||||
}
|
||||
else:
|
||||
self.body_complete = True
|
||||
return {
|
||||
"type": "http.request",
|
||||
"body": b"",
|
||||
"more_body": False,
|
||||
}
|
||||
|
||||
except asyncio.CancelledError:
|
||||
return {"type": "http.disconnect"}
|
||||
except Exception:
|
||||
self.body_complete = True
|
||||
return {
|
||||
"type": "http.request",
|
||||
"body": b"",
|
||||
"more_body": False,
|
||||
}
|
||||
|
||||
|
||||
class ASGIProtocol(asyncio.Protocol):
|
||||
"""HTTP/1.1 protocol handler for ASGI applications.
|
||||
|
||||
@ -66,7 +207,14 @@ class ASGIProtocol(asyncio.Protocol):
|
||||
|
||||
# Connection state
|
||||
self._closed = False
|
||||
self._receive_queue = None # Set per-request for disconnect signaling
|
||||
self._body_receiver = None # Set per-request for disconnect signaling
|
||||
|
||||
# Backpressure control
|
||||
self._reading_paused = False
|
||||
self._max_buffer_size = 65536 * 4 # 256KB max buffer
|
||||
|
||||
# Keep-alive timer
|
||||
self._keepalive_handle = None
|
||||
|
||||
def connection_made(self, transport):
|
||||
"""Called when a connection is established."""
|
||||
@ -98,6 +246,54 @@ class ASGIProtocol(asyncio.Protocol):
|
||||
"""Called when data is received on the connection."""
|
||||
if self.reader:
|
||||
self.reader.feed_data(data)
|
||||
# Backpressure: pause reading if buffer is too large
|
||||
if not self._reading_paused and self._is_buffer_full():
|
||||
self._pause_reading()
|
||||
|
||||
def _is_buffer_full(self):
|
||||
"""Check if internal buffer is full."""
|
||||
# Check StreamReader internal buffer size
|
||||
if hasattr(self.reader, '_buffer'):
|
||||
return len(self.reader._buffer) > self._max_buffer_size
|
||||
return False
|
||||
|
||||
def _pause_reading(self):
|
||||
"""Pause reading from transport due to backpressure."""
|
||||
if not self._reading_paused and self.transport:
|
||||
self._reading_paused = True
|
||||
try:
|
||||
self.transport.pause_reading()
|
||||
except (AttributeError, RuntimeError):
|
||||
pass
|
||||
|
||||
def _resume_reading(self):
|
||||
"""Resume reading from transport."""
|
||||
if self._reading_paused and self.transport:
|
||||
self._reading_paused = False
|
||||
try:
|
||||
self.transport.resume_reading()
|
||||
except (AttributeError, RuntimeError):
|
||||
pass
|
||||
|
||||
def _arm_keepalive_timer(self):
|
||||
"""Arm keepalive timeout timer after response completion."""
|
||||
if self._keepalive_handle:
|
||||
self._keepalive_handle.cancel()
|
||||
keepalive_timeout = self.cfg.keepalive
|
||||
if keepalive_timeout > 0:
|
||||
self._keepalive_handle = self.worker.loop.call_later(
|
||||
keepalive_timeout, self._keepalive_timeout
|
||||
)
|
||||
|
||||
def _cancel_keepalive_timer(self):
|
||||
"""Cancel keepalive timer when new request arrives."""
|
||||
if self._keepalive_handle:
|
||||
self._keepalive_handle.cancel()
|
||||
self._keepalive_handle = None
|
||||
|
||||
def _keepalive_timeout(self):
|
||||
"""Called when keepalive timeout expires."""
|
||||
self._close_transport()
|
||||
|
||||
def connection_lost(self, exc):
|
||||
"""Called when the connection is lost or closed.
|
||||
@ -115,12 +311,16 @@ class ASGIProtocol(asyncio.Protocol):
|
||||
|
||||
self._closed = True
|
||||
self.worker.nr_conns -= 1
|
||||
|
||||
# Cancel keepalive timer
|
||||
self._cancel_keepalive_timer()
|
||||
|
||||
if self.reader:
|
||||
self.reader.feed_eof()
|
||||
|
||||
# Signal disconnect to the app via the receive queue
|
||||
if self._receive_queue is not None:
|
||||
self._receive_queue.put_nowait({"type": "http.disconnect"})
|
||||
# Signal disconnect to the app via the body receiver
|
||||
if self._body_receiver is not None:
|
||||
self._body_receiver.signal_disconnect()
|
||||
|
||||
# Schedule task cancellation after grace period if task doesn't complete
|
||||
if self._task and not self._task.done():
|
||||
@ -160,67 +360,18 @@ class ASGIProtocol(asyncio.Protocol):
|
||||
|
||||
async def _handle_connection(self):
|
||||
"""Main request handling loop for this connection."""
|
||||
unreader = AsyncUnreader(self.reader)
|
||||
|
||||
try:
|
||||
peername = self.transport.get_extra_info('peername')
|
||||
sockname = self.transport.get_extra_info('sockname')
|
||||
|
||||
while not self._closed:
|
||||
self.req_count += 1
|
||||
# Check protocol type - use old path for uWSGI
|
||||
protocol_type = getattr(self.cfg, 'protocol', 'http')
|
||||
if protocol_type == 'uwsgi':
|
||||
await self._handle_connection_uwsgi(peername, sockname)
|
||||
return
|
||||
|
||||
try:
|
||||
# Parse request based on protocol
|
||||
protocol = getattr(self.cfg, 'protocol', 'http')
|
||||
if protocol == 'uwsgi':
|
||||
request = await AsyncUWSGIRequest.parse(
|
||||
self.cfg,
|
||||
unreader,
|
||||
peername,
|
||||
self.req_count
|
||||
)
|
||||
else:
|
||||
request = await AsyncRequest.parse(
|
||||
self.cfg,
|
||||
unreader,
|
||||
peername,
|
||||
self.req_count
|
||||
)
|
||||
except NoMoreData:
|
||||
# Client disconnected
|
||||
break
|
||||
except UWSGIParseException as e:
|
||||
self.log.debug("uWSGI parse error: %s", e)
|
||||
break
|
||||
|
||||
# Check for WebSocket upgrade
|
||||
if self._is_websocket_upgrade(request):
|
||||
await self._handle_websocket(request, sockname, peername)
|
||||
break # WebSocket takes over the connection
|
||||
|
||||
# Handle HTTP request
|
||||
keepalive = await self._handle_http_request(
|
||||
request, sockname, peername
|
||||
)
|
||||
|
||||
# Increment worker request count
|
||||
self.worker.nr += 1
|
||||
|
||||
# Check max_requests
|
||||
if self.worker.nr >= self.worker.max_requests:
|
||||
self.log.info("Autorestarting worker after current request.")
|
||||
self.worker.alive = False
|
||||
keepalive = False
|
||||
|
||||
if not keepalive or not self.worker.alive:
|
||||
break
|
||||
|
||||
# Check connection limits for keepalive
|
||||
if not self.cfg.keepalive:
|
||||
break
|
||||
|
||||
# Drain any unread body before next request
|
||||
await request.drain_body()
|
||||
# Fast path: use HttpParser for HTTP protocol
|
||||
await self._handle_connection_fast(peername, sockname)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
@ -229,6 +380,161 @@ class ASGIProtocol(asyncio.Protocol):
|
||||
finally:
|
||||
self._close_transport()
|
||||
|
||||
async def _handle_connection_fast(self, peername, sockname):
|
||||
"""Fast HTTP connection handling using HttpParser."""
|
||||
# Check if peer is trusted proxy once per connection
|
||||
is_trusted = _check_trusted_proxy(
|
||||
peername,
|
||||
self.cfg.forwarded_allow_ips,
|
||||
self.cfg.forwarded_allow_networks()
|
||||
)
|
||||
|
||||
# Get SSL state
|
||||
ssl_object = self.transport.get_extra_info('ssl_object')
|
||||
is_ssl = ssl_object is not None
|
||||
|
||||
# Create parser and buffer
|
||||
parser = HttpParser(
|
||||
self.cfg, peername, is_ssl=is_ssl,
|
||||
req_number=1, is_trusted_proxy=is_trusted
|
||||
)
|
||||
buffer = bytearray()
|
||||
|
||||
while not self._closed:
|
||||
self.req_count += 1
|
||||
|
||||
# Cancel keepalive timer when new request starts
|
||||
self._cancel_keepalive_timer()
|
||||
|
||||
try:
|
||||
# Parse request using fast parser
|
||||
request = await self._parse_request_fast(
|
||||
parser, buffer, peername
|
||||
)
|
||||
except NoMoreData:
|
||||
# Client disconnected
|
||||
break
|
||||
|
||||
# Check for WebSocket upgrade
|
||||
if self._is_websocket_upgrade(request):
|
||||
await self._handle_websocket(request, sockname, peername)
|
||||
break # WebSocket takes over the connection
|
||||
|
||||
# Handle HTTP request
|
||||
keepalive = await self._handle_http_request(
|
||||
request, sockname, peername
|
||||
)
|
||||
|
||||
# Increment worker request count
|
||||
self.worker.nr += 1
|
||||
|
||||
# Check max_requests
|
||||
if self.worker.nr >= self.worker.max_requests:
|
||||
self.log.info("Autorestarting worker after current request.")
|
||||
self.worker.alive = False
|
||||
keepalive = False
|
||||
|
||||
if not keepalive or not self.worker.alive:
|
||||
break
|
||||
|
||||
# Check connection limits for keepalive
|
||||
if not self.cfg.keepalive:
|
||||
break
|
||||
|
||||
# Drain any unread body before next request
|
||||
await request.drain_body()
|
||||
|
||||
# Resume reading if paused during body consumption
|
||||
self._resume_reading()
|
||||
|
||||
# Reset parser for next request (keep trusted proxy check)
|
||||
parser.reset()
|
||||
|
||||
# Arm keepalive timer between requests
|
||||
self._arm_keepalive_timer()
|
||||
|
||||
async def _parse_request_fast(self, parser, buffer, peername):
|
||||
"""Parse request using fast HttpParser.
|
||||
|
||||
Returns a FastAsyncRequest wrapping the ParseResult.
|
||||
"""
|
||||
# Read data until we have complete headers
|
||||
while True:
|
||||
# Try to parse current buffer
|
||||
if buffer:
|
||||
try:
|
||||
result = parser.feed(buffer)
|
||||
if result is not None:
|
||||
# Headers complete - create request wrapper
|
||||
request = FastAsyncRequest(
|
||||
result, self.reader, buffer, result.consumed
|
||||
)
|
||||
# Clear consumed data from buffer
|
||||
del buffer[:result.consumed]
|
||||
return request
|
||||
except Exception as e:
|
||||
# Re-raise HTTP parsing errors
|
||||
if 'incomplete' not in str(e).lower():
|
||||
raise
|
||||
|
||||
# Need more data
|
||||
try:
|
||||
data = await self.reader.read(65536)
|
||||
except Exception:
|
||||
data = b""
|
||||
|
||||
if not data:
|
||||
raise NoMoreData(bytes(buffer))
|
||||
|
||||
buffer.extend(data)
|
||||
|
||||
async def _handle_connection_uwsgi(self, peername, sockname):
|
||||
"""Handle uWSGI protocol connections (legacy path)."""
|
||||
unreader = AsyncUnreader(self.reader)
|
||||
|
||||
while not self._closed:
|
||||
self.req_count += 1
|
||||
|
||||
try:
|
||||
request = await AsyncUWSGIRequest.parse(
|
||||
self.cfg,
|
||||
unreader,
|
||||
peername,
|
||||
self.req_count
|
||||
)
|
||||
except NoMoreData:
|
||||
break
|
||||
except UWSGIParseException as e:
|
||||
self.log.debug("uWSGI parse error: %s", e)
|
||||
break
|
||||
|
||||
# Check for WebSocket upgrade
|
||||
if self._is_websocket_upgrade(request):
|
||||
await self._handle_websocket(request, sockname, peername)
|
||||
break
|
||||
|
||||
# Handle HTTP request
|
||||
keepalive = await self._handle_http_request(
|
||||
request, sockname, peername
|
||||
)
|
||||
|
||||
# Increment worker request count
|
||||
self.worker.nr += 1
|
||||
|
||||
# Check max_requests
|
||||
if self.worker.nr >= self.worker.max_requests:
|
||||
self.log.info("Autorestarting worker after current request.")
|
||||
self.worker.alive = False
|
||||
keepalive = False
|
||||
|
||||
if not keepalive or not self.worker.alive:
|
||||
break
|
||||
|
||||
if not self.cfg.keepalive:
|
||||
break
|
||||
|
||||
await request.drain_body()
|
||||
|
||||
def _is_websocket_upgrade(self, request):
|
||||
"""Check if request is a WebSocket upgrade.
|
||||
|
||||
@ -273,36 +579,9 @@ class ASGIProtocol(asyncio.Protocol):
|
||||
response_headers = []
|
||||
response_sent = 0
|
||||
|
||||
# Receive queue for body - stored on self for disconnect signaling
|
||||
receive_queue = asyncio.Queue()
|
||||
self._receive_queue = receive_queue
|
||||
body_complete = False
|
||||
|
||||
# Pre-populate with initial body state
|
||||
if request.content_length == 0 and not request.chunked:
|
||||
await receive_queue.put({
|
||||
"type": "http.request",
|
||||
"body": b"",
|
||||
"more_body": False,
|
||||
})
|
||||
body_complete = True
|
||||
else:
|
||||
# Start body reading task
|
||||
asyncio.create_task(self._read_body_to_queue(request, receive_queue))
|
||||
|
||||
async def receive():
|
||||
nonlocal body_complete
|
||||
# Check if already disconnected before waiting
|
||||
if self._closed and body_complete:
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
msg = await receive_queue.get()
|
||||
|
||||
# Track when body is complete
|
||||
if msg.get("type") == "http.request" and not msg.get("more_body", True):
|
||||
body_complete = True
|
||||
|
||||
return msg
|
||||
# Create body receiver - reads directly on demand, no Queue/Task overhead
|
||||
body_receiver = BodyReceiver(request, self)
|
||||
self._body_receiver = body_receiver
|
||||
|
||||
async def send(message):
|
||||
nonlocal response_started, response_complete, exc_to_raise
|
||||
@ -373,7 +652,7 @@ class ASGIProtocol(asyncio.Protocol):
|
||||
request_start = datetime.now()
|
||||
self.cfg.pre_request(self.worker, request)
|
||||
|
||||
await self.app(scope, receive, send)
|
||||
await self.app(scope, body_receiver.receive, send)
|
||||
|
||||
if exc_to_raise is not None:
|
||||
raise exc_to_raise
|
||||
@ -394,8 +673,8 @@ class ASGIProtocol(asyncio.Protocol):
|
||||
response_status = 500
|
||||
return False
|
||||
finally:
|
||||
# Clear the receive queue reference
|
||||
self._receive_queue = None
|
||||
# Clear the body receiver reference
|
||||
self._body_receiver = None
|
||||
|
||||
try:
|
||||
request_time = datetime.now() - request_start
|
||||
@ -412,38 +691,17 @@ class ASGIProtocol(asyncio.Protocol):
|
||||
|
||||
return self.worker.alive and self.cfg.keepalive
|
||||
|
||||
async def _read_body_to_queue(self, request, queue):
|
||||
"""Read request body and put chunks on the queue."""
|
||||
try:
|
||||
while True:
|
||||
chunk = await request.read_body(65536)
|
||||
if chunk:
|
||||
await queue.put({
|
||||
"type": "http.request",
|
||||
"body": chunk,
|
||||
"more_body": True,
|
||||
})
|
||||
else:
|
||||
await queue.put({
|
||||
"type": "http.request",
|
||||
"body": b"",
|
||||
"more_body": False,
|
||||
})
|
||||
break
|
||||
except Exception as e:
|
||||
self.log.debug("Error reading body: %s", e)
|
||||
await queue.put({
|
||||
"type": "http.request",
|
||||
"body": b"",
|
||||
"more_body": False,
|
||||
})
|
||||
|
||||
def _build_http_scope(self, request, sockname, peername):
|
||||
"""Build ASGI HTTP scope from parsed request."""
|
||||
# Build headers list as bytes tuples
|
||||
headers = []
|
||||
for name, value in request.headers:
|
||||
headers.append((name.lower().encode("latin-1"), value.encode("latin-1")))
|
||||
# Use pre-computed bytes headers if available (fast path)
|
||||
# Fall back to conversion for legacy requests (AsyncRequest, HTTP/2)
|
||||
headers_bytes = getattr(request, 'headers_bytes', None)
|
||||
if isinstance(headers_bytes, list):
|
||||
headers = list(headers_bytes) # Copy to avoid mutation
|
||||
else:
|
||||
headers = []
|
||||
for name, value in request.headers:
|
||||
headers.append((name.lower().encode("latin-1"), value.encode("latin-1")))
|
||||
|
||||
server = _normalize_sockaddr(sockname)
|
||||
client = _normalize_sockaddr(peername)
|
||||
@ -563,26 +821,54 @@ class ASGIProtocol(asyncio.Protocol):
|
||||
self._safe_write(response.encode("latin-1"))
|
||||
|
||||
async def _send_response_start(self, status, headers, request):
|
||||
"""Send HTTP response status and headers."""
|
||||
# Build status line
|
||||
reason = self._get_reason_phrase(status)
|
||||
status_line = f"HTTP/{request.version[0]}.{request.version[1]} {status} {reason}\r\n"
|
||||
"""Send HTTP response status and headers.
|
||||
|
||||
# Build headers
|
||||
header_lines = []
|
||||
Uses cached status lines and headers for common cases to avoid
|
||||
repeated string formatting and encoding.
|
||||
"""
|
||||
# Get cached status line bytes
|
||||
reason = self._get_reason_phrase(status)
|
||||
status_line = _get_cached_status_line(request.version, status, reason)
|
||||
|
||||
# Build headers as bytes directly
|
||||
parts = [status_line]
|
||||
|
||||
has_date = False
|
||||
has_server = False
|
||||
|
||||
for name, value in headers:
|
||||
if isinstance(name, bytes):
|
||||
name = name.decode("latin-1")
|
||||
name_lower = name.lower()
|
||||
parts.append(name)
|
||||
else:
|
||||
name_lower = name.lower().encode("latin-1")
|
||||
parts.append(name.encode("latin-1"))
|
||||
|
||||
parts.append(b": ")
|
||||
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode("latin-1")
|
||||
header_lines.append(f"{name}: {value}\r\n")
|
||||
parts.append(value)
|
||||
else:
|
||||
parts.append(value.encode("latin-1"))
|
||||
|
||||
# Add server header if not present
|
||||
header_lines.append("Server: gunicorn/asgi\r\n")
|
||||
parts.append(b"\r\n")
|
||||
|
||||
response = status_line + "".join(header_lines) + "\r\n"
|
||||
self._safe_write(response.encode("latin-1"))
|
||||
# Track if Date/Server headers are present
|
||||
if name_lower == b"date":
|
||||
has_date = True
|
||||
elif name_lower == b"server":
|
||||
has_server = True
|
||||
|
||||
# Add default headers if not present
|
||||
if not has_server:
|
||||
parts.append(_CACHED_SERVER_HEADER)
|
||||
if not has_date:
|
||||
parts.append(_get_cached_date_header())
|
||||
|
||||
parts.append(b"\r\n")
|
||||
|
||||
# Write as single buffer
|
||||
self._safe_write(b"".join(parts))
|
||||
|
||||
async def _send_body(self, body, chunked=False):
|
||||
"""Send response body chunk."""
|
||||
|
||||
@ -28,7 +28,7 @@ _fast_parser_module = None
|
||||
|
||||
def _check_fast_parser(cfg):
|
||||
"""Check if fast C parser is available and should be used."""
|
||||
global _fast_parser_available, _fast_parser_module
|
||||
global _fast_parser_available, _fast_parser_module # pylint: disable=global-statement
|
||||
|
||||
parser_setting = getattr(cfg, 'http_parser', 'auto')
|
||||
if parser_setting == 'python':
|
||||
|
||||
@ -7,10 +7,8 @@ Tests for ASGI worker components.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import ipaddress
|
||||
import pytest
|
||||
from unittest import mock
|
||||
|
||||
from gunicorn.asgi.unreader import AsyncUnreader
|
||||
from gunicorn.asgi.message import AsyncRequest
|
||||
|
||||
@ -9,9 +9,10 @@ Tests that gunicorn's ASGI implementation conforms to the ASGI 3.0 spec:
|
||||
https://asgi.readthedocs.io/en/latest/specs/main.html
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from gunicorn.config import Config
|
||||
|
||||
|
||||
@ -730,31 +731,46 @@ class TestHTTPDisconnectEvent:
|
||||
return protocol
|
||||
|
||||
def test_disconnect_event_type(self):
|
||||
"""Test that disconnect event has correct type per ASGI spec."""
|
||||
"""Test that disconnect event signals body receiver per ASGI spec."""
|
||||
from gunicorn.asgi.protocol import BodyReceiver
|
||||
|
||||
protocol = self._create_protocol()
|
||||
protocol._receive_queue = asyncio.Queue()
|
||||
|
||||
# Create a mock request for the body receiver
|
||||
mock_request = mock.Mock()
|
||||
mock_request.content_length = 100
|
||||
mock_request.chunked = False
|
||||
|
||||
body_receiver = BodyReceiver(mock_request, protocol)
|
||||
protocol._body_receiver = body_receiver
|
||||
|
||||
# Simulate client disconnect
|
||||
protocol.connection_lost(None)
|
||||
|
||||
# Get the message from queue
|
||||
msg = protocol._receive_queue.get_nowait()
|
||||
|
||||
# Per ASGI spec: type MUST be "http.disconnect"
|
||||
assert msg["type"] == "http.disconnect"
|
||||
# Per ASGI spec: disconnect should be signaled
|
||||
assert body_receiver._disconnect_event.is_set()
|
||||
|
||||
def test_disconnect_event_sent_on_connection_lost(self):
|
||||
"""Test that http.disconnect is sent when connection is lost."""
|
||||
protocol = self._create_protocol()
|
||||
protocol._receive_queue = asyncio.Queue()
|
||||
"""Test that disconnect is signaled when connection is lost."""
|
||||
from gunicorn.asgi.protocol import BodyReceiver
|
||||
|
||||
assert protocol._receive_queue.empty()
|
||||
protocol = self._create_protocol()
|
||||
|
||||
# Create a mock request for the body receiver
|
||||
mock_request = mock.Mock()
|
||||
mock_request.content_length = 100
|
||||
mock_request.chunked = False
|
||||
|
||||
body_receiver = BodyReceiver(mock_request, protocol)
|
||||
protocol._body_receiver = body_receiver
|
||||
|
||||
assert not body_receiver._disconnect_event.is_set()
|
||||
|
||||
# Simulate client disconnect
|
||||
protocol.connection_lost(None)
|
||||
|
||||
# Queue should have disconnect message
|
||||
assert not protocol._receive_queue.empty()
|
||||
# Disconnect should have been signaled
|
||||
assert body_receiver._disconnect_event.is_set()
|
||||
|
||||
def test_disconnect_sets_closed_flag(self):
|
||||
"""Test that connection_lost sets the closed flag."""
|
||||
@ -788,18 +804,33 @@ class TestHTTPDisconnectEvent:
|
||||
# Cancellation should be scheduled after grace period
|
||||
protocol.worker.loop.call_later.assert_called_once()
|
||||
|
||||
def test_disconnect_message_format(self):
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_message_format(self):
|
||||
"""Test http.disconnect message format per ASGI spec.
|
||||
|
||||
The disconnect message should only contain 'type' key.
|
||||
When body is complete and disconnect is signaled, receive()
|
||||
should return {"type": "http.disconnect"}.
|
||||
"""
|
||||
from gunicorn.asgi.protocol import BodyReceiver
|
||||
|
||||
protocol = self._create_protocol()
|
||||
protocol._receive_queue = asyncio.Queue()
|
||||
|
||||
protocol.connection_lost(None)
|
||||
# Create a mock request with no body
|
||||
mock_request = mock.Mock()
|
||||
mock_request.content_length = 0
|
||||
mock_request.chunked = False
|
||||
|
||||
msg = protocol._receive_queue.get_nowait()
|
||||
body_receiver = BodyReceiver(mock_request, protocol)
|
||||
protocol._body_receiver = body_receiver
|
||||
|
||||
# Get initial body message (empty body)
|
||||
msg1 = await body_receiver.receive()
|
||||
assert msg1["type"] == "http.request"
|
||||
assert msg1["more_body"] is False
|
||||
|
||||
# Now receive should return disconnect
|
||||
msg2 = await body_receiver.receive()
|
||||
|
||||
# Per ASGI spec, disconnect message only has 'type'
|
||||
assert msg == {"type": "http.disconnect"}
|
||||
assert len(msg) == 1
|
||||
assert msg2 == {"type": "http.disconnect"}
|
||||
assert len(msg2) == 1
|
||||
|
||||
@ -50,41 +50,58 @@ class TestASGIGracefulDisconnect:
|
||||
|
||||
assert protocol._closed is True
|
||||
|
||||
def test_disconnect_sends_message_to_queue(self, mock_worker):
|
||||
"""Test that connection_lost sends http.disconnect to receive queue."""
|
||||
def test_disconnect_signals_body_receiver(self, mock_worker):
|
||||
"""Test that connection_lost signals the body receiver."""
|
||||
from gunicorn.asgi.protocol import BodyReceiver
|
||||
|
||||
protocol = ASGIProtocol(mock_worker)
|
||||
protocol.reader = mock.Mock()
|
||||
mock_worker.nr_conns = 1
|
||||
|
||||
# Create a receive queue (simulating active request)
|
||||
protocol._receive_queue = asyncio.Queue()
|
||||
# Create a mock request for the body receiver
|
||||
mock_request = mock.Mock()
|
||||
mock_request.content_length = 100
|
||||
mock_request.chunked = False
|
||||
|
||||
# Create a body receiver (simulating active request)
|
||||
body_receiver = BodyReceiver(mock_request, protocol)
|
||||
protocol._body_receiver = body_receiver
|
||||
|
||||
# Verify disconnect event is not set initially
|
||||
assert not body_receiver._disconnect_event.is_set()
|
||||
|
||||
# Simulate connection lost
|
||||
protocol.connection_lost(None)
|
||||
|
||||
# Check that disconnect message was sent
|
||||
assert not protocol._receive_queue.empty()
|
||||
msg = protocol._receive_queue.get_nowait()
|
||||
assert msg == {"type": "http.disconnect"}
|
||||
# Check that disconnect event was signaled
|
||||
assert body_receiver._disconnect_event.is_set()
|
||||
|
||||
def test_disconnect_is_idempotent(self, mock_worker):
|
||||
"""Test that connection_lost can be called multiple times safely."""
|
||||
from gunicorn.asgi.protocol import BodyReceiver
|
||||
|
||||
protocol = ASGIProtocol(mock_worker)
|
||||
protocol.reader = mock.Mock()
|
||||
mock_worker.nr_conns = 2 # Start with 2 so we can verify only 1 is decremented
|
||||
|
||||
protocol._receive_queue = asyncio.Queue()
|
||||
# Create a mock request for the body receiver
|
||||
mock_request = mock.Mock()
|
||||
mock_request.content_length = 100
|
||||
mock_request.chunked = False
|
||||
|
||||
body_receiver = BodyReceiver(mock_request, protocol)
|
||||
protocol._body_receiver = body_receiver
|
||||
|
||||
# First call should work
|
||||
protocol.connection_lost(None)
|
||||
assert protocol._closed is True
|
||||
assert mock_worker.nr_conns == 1
|
||||
assert protocol._receive_queue.qsize() == 1
|
||||
assert body_receiver._disconnect_event.is_set()
|
||||
|
||||
# Second call should be a no-op
|
||||
protocol.connection_lost(None)
|
||||
assert mock_worker.nr_conns == 1 # Should not decrement again
|
||||
assert protocol._receive_queue.qsize() == 1 # Should not add another message
|
||||
# Event is still set (no way to "double set" an event, so this is fine)
|
||||
|
||||
def test_disconnect_does_not_cancel_immediately(self, mock_worker):
|
||||
"""Test that connection_lost doesn't cancel task immediately."""
|
||||
@ -156,41 +173,26 @@ class TestASGIGracefulDisconnect:
|
||||
@pytest.mark.asyncio
|
||||
async def test_receive_returns_disconnect_when_closed(self, mock_worker):
|
||||
"""Test that receive() returns http.disconnect when connection is closed."""
|
||||
from gunicorn.asgi.protocol import BodyReceiver
|
||||
|
||||
protocol = ASGIProtocol(mock_worker)
|
||||
protocol._closed = True
|
||||
|
||||
# Create receive queue with body complete
|
||||
receive_queue = asyncio.Queue()
|
||||
protocol._receive_queue = receive_queue
|
||||
# Create a mock request with no body
|
||||
mock_request = mock.Mock()
|
||||
mock_request.content_length = 0
|
||||
mock_request.chunked = False
|
||||
|
||||
# Add initial body message
|
||||
await receive_queue.put({
|
||||
"type": "http.request",
|
||||
"body": b"",
|
||||
"more_body": False,
|
||||
})
|
||||
body_receiver = BodyReceiver(mock_request, protocol)
|
||||
protocol._body_receiver = body_receiver
|
||||
|
||||
# Simulate what happens in _handle_http_request
|
||||
body_complete = False
|
||||
|
||||
async def receive():
|
||||
nonlocal body_complete
|
||||
if protocol._closed and body_complete:
|
||||
return {"type": "http.disconnect"}
|
||||
|
||||
msg = await receive_queue.get()
|
||||
|
||||
if msg.get("type") == "http.request" and not msg.get("more_body", True):
|
||||
body_complete = True
|
||||
|
||||
return msg
|
||||
|
||||
# First receive gets the body
|
||||
msg1 = await receive()
|
||||
# First receive gets the body (empty)
|
||||
msg1 = await body_receiver.receive()
|
||||
assert msg1["type"] == "http.request"
|
||||
assert msg1["more_body"] is False
|
||||
|
||||
# Second receive should get disconnect
|
||||
msg2 = await receive()
|
||||
# Second receive should get disconnect (body complete)
|
||||
msg2 = await body_receiver.receive()
|
||||
assert msg2["type"] == "http.disconnect"
|
||||
|
||||
|
||||
|
||||
@ -11,7 +11,6 @@ and extension support.
|
||||
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from gunicorn.config import Config
|
||||
|
||||
|
||||
@ -6,7 +6,6 @@
|
||||
Tests for ASGI HTTP parser optimizations.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import ipaddress
|
||||
import pytest
|
||||
|
||||
|
||||
@ -9,7 +9,6 @@ Tests for chunked transfer encoding, Server-Sent Events (SSE),
|
||||
and streaming response handling.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
@ -9,7 +9,6 @@ Tests that gunicorn's WebSocket implementation conforms to RFC 6455:
|
||||
https://tools.ietf.org/html/rfc6455
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import struct
|
||||
|
||||
@ -12,11 +12,7 @@ that actually start the server and make HTTP requests.
|
||||
import asyncio
|
||||
import errno
|
||||
import os
|
||||
import signal
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
import threading
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
@ -120,7 +116,7 @@ class FakeListener:
|
||||
def _has_uvloop():
|
||||
"""Check if uvloop is available."""
|
||||
try:
|
||||
import uvloop
|
||||
import uvloop # noqa: F401
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
@ -337,9 +333,9 @@ class TestLifespanManager:
|
||||
async def app(scope, receive, send):
|
||||
assert "state" in scope
|
||||
scope["state"]["db"] = "connected"
|
||||
message = await receive()
|
||||
_ = await receive()
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
message = await receive()
|
||||
_ = await receive()
|
||||
await send({"type": "lifespan.shutdown.complete"})
|
||||
|
||||
manager = LifespanManager(app, mock.Mock(), state)
|
||||
@ -555,7 +551,6 @@ class TestASGIProtocol:
|
||||
def test_scope_building(self):
|
||||
"""Test HTTP scope building."""
|
||||
from gunicorn.asgi.protocol import ASGIProtocol
|
||||
from gunicorn.asgi.message import AsyncRequest
|
||||
|
||||
worker = mock.Mock()
|
||||
worker.cfg = Config()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user