Remove unused AsyncRequest class

AsyncRequest was the legacy pull-based async HTTP parser, now replaced
by the push-based CallbackRequest/PythonProtocol. Remove the unused
code and associated tests.
This commit is contained in:
Benoit Chesneau 2026-03-26 16:08:35 +01:00
parent b00f125755
commit da8bd4850a
5 changed files with 2 additions and 1368 deletions

View File

@ -10,7 +10,6 @@ HTTP parsing infrastructure adapted for async I/O.
Components:
- AsyncUnreader: Async socket reading with pushback buffer
- AsyncRequest: Async HTTP request parser
- ASGIProtocol: asyncio.Protocol implementation for HTTP handling
- WebSocketProtocol: WebSocket protocol handler (RFC 6455)
- LifespanManager: ASGI lifespan protocol support
@ -20,7 +19,6 @@ Usage:
"""
from gunicorn.asgi.unreader import AsyncUnreader
from gunicorn.asgi.message import AsyncRequest
from gunicorn.asgi.lifespan import LifespanManager
__all__ = ['AsyncUnreader', 'AsyncRequest', 'LifespanManager']
__all__ = ['AsyncUnreader', 'LifespanManager']

View File

@ -1,736 +0,0 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
Async version of gunicorn/http/message.py for ASGI workers.
Reuses the parsing logic from the sync version, adapted for async I/O.
"""
import ipaddress
import re
import socket
import struct
from gunicorn.http.errors import (
ExpectationFailed,
InvalidHeader, InvalidHeaderName, NoMoreData,
InvalidRequestLine, InvalidRequestMethod, InvalidHTTPVersion,
LimitRequestLine, LimitRequestHeaders,
UnsupportedTransferCoding, ObsoleteFolding,
InvalidProxyLine, InvalidProxyHeader, ForbiddenProxyRequest,
InvalidSchemeHeaders,
)
from gunicorn.http.message import (
PP_V2_SIGNATURE, PPCommand, PPFamily, PPProtocol
)
from gunicorn.util import bytes_to_str, split_request_uri
MAX_REQUEST_LINE = 8190
MAX_HEADERS = 32768
DEFAULT_MAX_HEADERFIELD_SIZE = 8190
# Reuse regex patterns from sync version
RFC9110_5_6_2_TOKEN_SPECIALS = r"!#$%&'*+-.^_`|~"
TOKEN_RE = re.compile(r"[%s0-9a-zA-Z]+" % (re.escape(RFC9110_5_6_2_TOKEN_SPECIALS)))
METHOD_BADCHAR_RE = re.compile("[a-z#]")
VERSION_RE = re.compile(r"HTTP/(\d)\.(\d)")
RFC9110_5_5_INVALID_AND_DANGEROUS = re.compile(r"[\0\r\n]")
def _ip_in_allow_list(ip_str, allow_list, networks):
    """Tell whether ``ip_str`` is admitted by an allow list.

    Args:
        ip_str: Textual IP address to test.
        allow_list: The raw allow-list entries; a literal "*" entry admits
            every caller without looking at the address at all.
        networks: Iterable of ``ipaddress`` network objects to test
            membership against.

    Returns:
        True when the wildcard is present or the address belongs to one of
        the networks; False when the address cannot be parsed or matches
        no network.
    """
    if '*' in allow_list:
        return True

    try:
        candidate = ipaddress.ip_address(ip_str)
    except ValueError:
        # Unparseable addresses are never considered allowed.
        return False

    return any(candidate in net for net in networks)
class AsyncRequest:
"""Async HTTP request parser.
Parses HTTP/1.x requests using async I/O, reusing gunicorn's
parsing logic where possible.
"""
def __init__(self, cfg, unreader, peer_addr, req_number=1):
    """Create an empty request bound to one connection.

    Args:
        cfg: Config object; read here for ``is_ssl`` and the three
            ``limit_request_*`` settings.
        unreader: Byte source the request will later be parsed from.
        peer_addr: Address of the connected peer; also used as the
            initial ``remote_addr``.
        req_number: Position of this request on the connection.
    """
    self.cfg = cfg
    self.unreader = unreader
    self.peer_addr = peer_addr
    self.remote_addr = peer_addr
    self.req_number = req_number

    # Request-line and header state, filled in while parsing.
    self.version = None
    self.method = None
    self.uri = None
    self.path = None
    self.query = None
    self.fragment = None
    self.headers = []
    self.trailers = []
    self.scheme = "https" if cfg.is_ssl else "http"
    self.must_close = False
    self._expected_100_continue = False
    self.proxy_protocol_info = None

    # Request line limit: negative or too-large settings fall back to
    # the hard maximum.
    line_limit = cfg.limit_request_line
    if not 0 <= line_limit < MAX_REQUEST_LINE:
        line_limit = MAX_REQUEST_LINE
    self.limit_request_line = line_limit

    # Header count limit: non-positive or too-large settings fall back to
    # the hard maximum.
    field_count = cfg.limit_request_fields
    if not 0 < field_count <= MAX_HEADERS:
        field_count = MAX_HEADERS
    self.limit_request_fields = field_count

    # Per-header size limit: only a negative setting is replaced; zero is
    # kept as-is on the attribute.
    field_size = cfg.limit_request_field_size
    if field_size < 0:
        field_size = DEFAULT_MAX_HEADERFIELD_SIZE
    self.limit_request_field_size = field_size

    # Upper bound for the buffered header section. A zero field size
    # uses the default for this computation only. The "+ 2" and "+ 4" are
    # presumably the per-line CRLF and the blank-line terminator.
    per_field = field_size or DEFAULT_MAX_HEADERFIELD_SIZE
    self.max_buffer_headers = field_count * (per_field + 2) + 4

    # Body-related state, decided once headers are known.
    self.content_length = None
    self.chunked = False
    self._body_reader = None
    self._body_remaining = 0
@classmethod
async def parse(cls, cfg, unreader, peer_addr, req_number=1):
    """Build a request object and parse it from the stream in one step.

    Args:
        cfg: gunicorn config object.
        unreader: AsyncUnreader instance supplying the bytes.
        peer_addr: Client address tuple.
        req_number: Request number on this connection (for keepalive).

    Returns:
        The fully parsed request instance.

    Raises:
        NoMoreData: If the stream yields no data.
        Any of the parsing errors raised for malformed requests.
    """
    request = cls(cfg, unreader, peer_addr, req_number)
    await request._parse()
    return request
async def _parse(self):
    """Read and parse the request line and headers from the unreader.

    Bytes following the header section are pushed back onto the unreader
    so body reading can pick them up; the body mode is then selected via
    ``_set_body_reader``.

    Raises:
        LimitRequestHeaders: If the buffered header section grows beyond
            ``max_buffer_headers`` before its end is found.
    """
    pending = bytearray()
    await self._read_into(pending)

    # PROXY protocol is only looked for on the first request of a
    # connection, and only when enabled.
    proxy_mode = self.cfg.proxy_protocol
    if proxy_mode != "off" and self.req_number == 1:
        pending = await self._handle_proxy_protocol(pending, proxy_mode)

    request_line, pending = await self._read_line(
        pending, self.limit_request_line)
    self._parse_request_line(request_line)

    # Keep reading until either the header terminator shows up or the
    # buffer starts with a bare CRLF (a request without headers).
    while True:
        end_of_headers = pending.find(b"\r\n\r\n")
        no_headers = pending[:2] == b"\r\n"
        if no_headers or end_of_headers >= 0:
            break
        await self._read_into(pending)
        if len(pending) > self.max_buffer_headers:
            raise LimitRequestHeaders("max buffer headers")

    if no_headers:
        self.unreader.unread(bytes(pending[2:]))
    else:
        self.headers = self._parse_headers(
            bytes(pending[:end_of_headers]), from_trailer=False)
        self.unreader.unread(bytes(pending[end_of_headers + 4:]))

    self._set_body_reader()
async def _read_into(self, buf):
    """Append the next read from the unreader to ``buf``.

    Raises:
        NoMoreData: If the read comes back empty; the exception carries a
            bytes copy of what had been buffered so far.
    """
    chunk = await self.unreader.read()
    if chunk:
        buf.extend(chunk)
        return
    raise NoMoreData(bytes(buf))
async def _read_line(self, buf, limit=0):
    """Extract the first CRLF-terminated line from ``buf``.

    More data is pulled in through ``_read_into`` until a CRLF is present.

    Args:
        buf: bytearray holding the data read so far; extended in place.
        limit: Maximum accepted line length; a value of 0 or less turns
            the check off.

    Returns:
        A ``(line, remaining)`` tuple: the line as bytes without its CRLF,
        and a new bytearray with everything after the CRLF.

    Raises:
        LimitRequestLine: If the line (or the data buffered while still
            looking for its end) exceeds ``limit``.
    """
    enforce = limit > 0
    eol = buf.find(b"\r\n")
    while eol < 0:
        # No terminator yet: bail out early if what we hold is already
        # longer than any acceptable line.
        if enforce and len(buf) - 2 > limit:
            raise LimitRequestLine(len(buf), limit)
        await self._read_into(buf)
        eol = buf.find(b"\r\n")

    if enforce and eol > limit:
        raise LimitRequestLine(eol, limit)

    return (bytes(buf[:eol]), bytearray(buf[eol + 2:]))
async def _handle_proxy_protocol(self, buf, mode):
    """Detect a PROXY protocol preamble and consume it if present.

    Args:
        buf: bytearray with the data read so far; extended in place until
            at least 12 bytes are available.
        mode: "v1", "v2" or "auto", selecting which formats are looked for.

    Returns:
        The buffer with the PROXY header removed, or ``buf`` itself when
        no enabled format matches.
    """
    # 12 bytes are needed to recognise the v2 signature.
    while len(buf) < 12:
        await self._read_into(buf)

    # v2 is checked before v1.
    if mode in ("v2", "auto") and buf[:12] == PP_V2_SIGNATURE:
        detected = 2
    elif mode in ("v1", "auto") and buf[:6] == b"PROXY ":
        detected = 1
    else:
        # Not proxy protocol: hand the buffer back untouched.
        return buf

    # A preamble was found; the peer must be allowed to send one.
    self._proxy_protocol_access_check()
    if detected == 2:
        return await self._parse_proxy_protocol_v2(buf)
    return await self._parse_proxy_protocol_v1(buf)
def _proxy_protocol_access_check(self):
    """Reject PROXY protocol data coming from a peer that is not allowed.

    Peers whose address is not a tuple are never rejected here.

    Raises:
        ForbiddenProxyRequest: If the peer's host is not covered by the
            configured proxy allow list.
    """
    peer = self.peer_addr
    if not isinstance(peer, tuple):
        return

    host = peer[0]
    if _ip_in_allow_list(host, self.cfg.proxy_allow_ips,
                         self.cfg.proxy_allow_networks()):
        return

    raise ForbiddenProxyRequest(host)
async def _parse_proxy_protocol_v1(self, buf):
    """Parse a PROXY protocol v1 (text format) line from the front of ``buf``.

    The line has six space-separated fields:
    ``PROXY <TCP4|TCP6> <src addr> <dst addr> <src port> <dst port>``.
    On success the decoded values are stored in
    ``self.proxy_protocol_info``.

    Args:
        buf: bytearray with the data read so far; extended in place until
            it contains a CRLF.

    Returns:
        A new bytearray holding everything after the v1 line.

    Raises:
        InvalidProxyLine: If the field count, protocol, addresses or ports
            are not valid.
    """
    # Search the bytearray directly instead of copying the whole buffer to
    # bytes after every read.
    eol = buf.find(b"\r\n")
    while eol < 0:
        await self._read_into(buf)
        eol = buf.find(b"\r\n")

    line = bytes_to_str(bytes(buf[:eol]))
    remaining = bytearray(buf[eol + 2:])

    bits = line.split(" ")
    if len(bits) != 6:
        raise InvalidProxyLine(line)

    proto, s_addr, d_addr = bits[1], bits[2], bits[3]

    families = {"TCP4": socket.AF_INET, "TCP6": socket.AF_INET6}
    if proto not in families:
        raise InvalidProxyLine("protocol '%s' not supported" % proto)

    try:
        socket.inet_pton(families[proto], s_addr)
        socket.inet_pton(families[proto], d_addr)
    except (OSError, ValueError) as exc:
        # inet_pton reports a malformed address with OSError, but a string
        # containing an embedded NUL raises ValueError instead; both mean
        # the proxy line is invalid.
        raise InvalidProxyLine(line) from exc

    try:
        s_port = int(bits[4])
        d_port = int(bits[5])
    except ValueError as exc:
        raise InvalidProxyLine("invalid port %s" % line) from exc

    if not ((0 <= s_port <= 65535) and (0 <= d_port <= 65535)):
        raise InvalidProxyLine("invalid port %s" % line)

    self.proxy_protocol_info = {
        "proxy_protocol": proto,
        "client_addr": s_addr,
        "client_port": s_port,
        "proxy_addr": d_addr,
        "proxy_port": d_port
    }

    return remaining
async def _parse_proxy_protocol_v2(self, buf):
    """Decode a binary PROXY protocol v2 header from the front of ``buf``.

    Reads more data through ``_read_into`` until the fixed 16-byte prefix
    and then the advertised payload are available, stores the result in
    ``self.proxy_protocol_info`` and returns a new bytearray holding
    whatever follows the header.

    Raises:
        InvalidProxyHeader: On a version other than 2, an unsupported
            command, a non-stream transport, an unsupported address
            family, or a payload too short for the declared family.
    """
    # Fixed prefix: 12-byte signature, version/command byte,
    # family/protocol byte and a big-endian 16-bit payload length.
    while len(buf) < 16:
        await self._read_into(buf)

    ver_cmd, fam_proto = buf[12], buf[13]
    (length,) = struct.unpack(">H", bytes(buf[14:16]))

    # High nibble carries the version, low nibble the command.
    version = (ver_cmd & 0xF0) >> 4
    if version != 2:
        raise InvalidProxyHeader("unsupported version %d" % version)

    command = ver_cmd & 0x0F
    if command not in (PPCommand.LOCAL, PPCommand.PROXY):
        raise InvalidProxyHeader("unsupported command %d" % command)

    # Make sure the whole header, payload included, has arrived.
    header_end = 16 + length
    while len(buf) < header_end:
        await self._read_into(buf)

    def without_addresses(label):
        # Info record used when the header carries no address data.
        return {
            "proxy_protocol": label,
            "client_addr": None,
            "client_port": None,
            "proxy_addr": None,
            "proxy_port": None
        }

    if command == PPCommand.LOCAL:
        self.proxy_protocol_info = without_addresses("LOCAL")
        return bytearray(buf[header_end:])

    # High nibble carries the address family, low nibble the transport.
    family = (fam_proto & 0xF0) >> 4
    protocol = fam_proto & 0x0F

    if protocol != PPProtocol.STREAM:
        raise InvalidProxyHeader("only TCP protocol is supported")

    payload = bytes(buf[16:header_end])

    if family == PPFamily.INET:
        # 4 + 4 address bytes followed by two 16-bit ports.
        if length < 12:
            raise InvalidProxyHeader("insufficient address data for IPv4")
        s_addr = socket.inet_ntop(socket.AF_INET, payload[0:4])
        d_addr = socket.inet_ntop(socket.AF_INET, payload[4:8])
        s_port, d_port = struct.unpack(">HH", payload[8:12])
        proto = "TCP4"
    elif family == PPFamily.INET6:
        # 16 + 16 address bytes followed by two 16-bit ports.
        if length < 36:
            raise InvalidProxyHeader("insufficient address data for IPv6")
        s_addr = socket.inet_ntop(socket.AF_INET6, payload[0:16])
        d_addr = socket.inet_ntop(socket.AF_INET6, payload[16:32])
        s_port, d_port = struct.unpack(">HH", payload[32:36])
        proto = "TCP6"
    elif family == PPFamily.UNSPEC:
        self.proxy_protocol_info = without_addresses("UNSPEC")
        return bytearray(buf[header_end:])
    else:
        raise InvalidProxyHeader("unsupported address family %d" % family)

    self.proxy_protocol_info = {
        "proxy_protocol": proto,
        "client_addr": s_addr,
        "client_port": s_port,
        "proxy_addr": d_addr,
        "proxy_port": d_port
    }

    return bytearray(buf[header_end:])
def _parse_request_line(self, line_bytes):
    """Split the request line into method, URI and version and validate them.

    Sets ``method``, ``uri``, ``path``, ``query``, ``fragment`` and
    ``version`` on the instance.

    Raises:
        InvalidRequestLine: If the line does not have three parts, the URI
            is empty, or the URI cannot be split.
        InvalidRequestMethod: If the method fails validation.
        InvalidHTTPVersion: If the version is malformed, or outside 1.x
            while unconventional versions are not permitted.
    """
    cfg = self.cfg

    pieces = [bytes_to_str(piece) for piece in line_bytes.split(b" ", 2)]
    if len(pieces) != 3:
        raise InvalidRequestLine(bytes_to_str(line_bytes))
    method, uri, version_text = pieces

    # Method
    self.method = method
    if not cfg.permit_unconventional_http_method:
        if METHOD_BADCHAR_RE.search(method):
            raise InvalidRequestMethod(method)
        if len(method) < 3 or len(method) > 20:
            raise InvalidRequestMethod(method)
    if TOKEN_RE.fullmatch(method) is None:
        raise InvalidRequestMethod(method)
    if cfg.casefold_http_method:
        self.method = method.upper()

    # URI
    self.uri = uri
    if not uri:
        raise InvalidRequestLine(bytes_to_str(line_bytes))
    try:
        parts = split_request_uri(uri)
    except ValueError:
        raise InvalidRequestLine(bytes_to_str(line_bytes))
    self.path = parts.path or ""
    self.query = parts.query or ""
    self.fragment = parts.fragment or ""

    # Version
    matched = VERSION_RE.fullmatch(version_text)
    if matched is None:
        raise InvalidHTTPVersion(version_text)
    self.version = (int(matched.group(1)), int(matched.group(2)))
    outside_1x = not (1, 0) <= self.version < (2, 0)
    if outside_1x and not cfg.permit_unconventional_http_version:
        raise InvalidHTTPVersion(self.version)
def _parse_headers(self, data, from_trailer=False):
    """Turn a raw CRLF-separated header section into ``(NAME, value)`` pairs.

    Names are upper-cased. Folded continuation lines are joined with a
    single space when obsolete folding is permitted.

    Args:
        data: The header section as bytes, without the terminating blank
            line.
        from_trailer: When true, scheme/forwarder header trust and
            ``Expect`` handling are skipped.

    Returns:
        List of ``(name, value)`` string tuples in order of appearance.

    Raises:
        LimitRequestHeaders: Too many fields, or a field that is too long.
        InvalidHeader, InvalidHeaderName, ObsoleteFolding,
        ExpectationFailed, InvalidSchemeHeaders: On the corresponding
            validation failures.
    """
    cfg = self.cfg
    field_limit = self.limit_request_field_size

    rows = [bytes_to_str(raw) for raw in data.split(b"\r\n")]
    total = len(rows)

    # Scheme and forwarder headers are only honoured for regular headers
    # sent by a trusted peer (or over a non-tuple peer address).
    secure_scheme_headers = {}
    forwarder_headers = []
    if not from_trailer and (
            not isinstance(self.peer_addr, tuple)
            or _ip_in_allow_list(self.peer_addr[0], cfg.forwarded_allow_ips,
                                 cfg.forwarded_allow_networks())):
        secure_scheme_headers = cfg.secure_scheme_headers
        forwarder_headers = cfg.forwarder_headers

    parsed = []
    scheme_seen = False
    cursor = 0
    while cursor < total:
        if len(parsed) >= self.limit_request_fields:
            raise LimitRequestHeaders("limit request headers fields")

        row = rows[cursor]
        cursor += 1
        size = len(row) + 2  # the line plus its CRLF

        # A colon must exist and must not be the first character.
        colon = row.find(":")
        if colon < 1:
            raise InvalidHeader(row)
        name, first_part = row[:colon], row[colon + 1:]

        if cfg.strip_header_spaces:
            name = name.rstrip(" \t")
        if TOKEN_RE.fullmatch(name) is None:
            raise InvalidHeaderName(name)
        name = name.upper()

        # Gather continuation lines (those starting with space or tab).
        fragments = [first_part.strip(" \t")]
        while cursor < total and rows[cursor][:1] in (" ", "\t"):
            if not cfg.permit_obsolete_folding:
                raise ObsoleteFolding(name)
            row = rows[cursor]
            cursor += 1
            size += len(row) + 2
            if field_limit > 0 and size > field_limit:
                raise LimitRequestHeaders("limit request headers fields size")
            fragments.append(row.strip("\t "))
        value = " ".join(fragments)

        if RFC9110_5_5_INVALID_AND_DANGEROUS.search(value):
            raise InvalidHeader(name)

        if field_limit > 0 and size > field_limit:
            raise LimitRequestHeaders("limit request headers fields size")

        if name == "EXPECT" and not from_trailer:
            # https://datatracker.ietf.org/doc/html/rfc9110#section-10.1.1
            # "The Expect field value is case-insensitive."
            if value.lower() != "100-continue":
                raise ExpectationFailed(value)
            # https://datatracker.ietf.org/doc/html/rfc9110#section-10.1.1-12
            # "A server that receives a 100-continue expectation
            #  in an HTTP/1.0 request MUST ignore that expectation."
            # N.B. understood but ignored expect header does not return 417
            if not self.version < (1, 1):
                self._expected_100_continue = True

        if name in secure_scheme_headers:
            is_secure = value == secure_scheme_headers[name]
            scheme = "https" if is_secure else "http"
            if not scheme_seen:
                scheme_seen = True
                self.scheme = scheme
            elif scheme != self.scheme:
                # Two scheme headers that disagree with each other.
                raise InvalidSchemeHeaders()

        # Names containing an underscore are only kept when explicitly
        # listed as forwarder headers or when the header_map policy says so.
        if ("_" in name
                and name not in forwarder_headers
                and "*" not in forwarder_headers):
            policy = cfg.header_map
            if policy == "drop":
                continue
            if policy != "dangerous":
                raise InvalidHeaderName(name)

        parsed.append((name, value))

    return parsed
def _set_body_reader(self):
    """Decide from the headers how the request body is delimited.

    Sets ``chunked``, ``content_length`` and ``_body_remaining``
    (``-1`` for chunked bodies, ``0`` when there is no body).

    Raises:
        InvalidHeader: On duplicate or non-numeric Content-Length, on a
            coding listed after "chunked", on chunked with HTTP < 1.1, or
            on chunked combined with Content-Length.
        UnsupportedTransferCoding: On an unknown transfer coding.
    """
    saw_chunked = False
    declared_length = None

    for name, value in self.headers:
        if name == "CONTENT-LENGTH":
            if declared_length is not None:
                raise InvalidHeader("CONTENT-LENGTH", req=self)
            declared_length = value
        elif name == "TRANSFER-ENCODING":
            for coding in (part.strip().lower() for part in value.split(',')):
                if coding not in ("chunked", "identity",
                                  "compress", "deflate", "gzip"):
                    raise UnsupportedTransferCoding(value)
                # Nothing may follow "chunked", including a second one.
                if saw_chunked:
                    raise InvalidHeader("TRANSFER-ENCODING", req=self)
                if coding == "chunked":
                    saw_chunked = True
                elif coding != "identity":
                    # Compression codings: force the connection closed.
                    self.force_close()

    if saw_chunked:
        if self.version < (1, 1):
            raise InvalidHeader("TRANSFER-ENCODING", req=self)
        if declared_length is not None:
            raise InvalidHeader("CONTENT-LENGTH", req=self)
        self.chunked = True
        self.content_length = None
        self._body_remaining = -1
        return

    if declared_length is None:
        # No body for requests without Content-Length or Transfer-Encoding
        self.content_length = 0
        self._body_remaining = 0
        return

    if not str(declared_length).isnumeric():
        raise InvalidHeader("CONTENT-LENGTH", req=self)
    try:
        length = int(declared_length)
    except ValueError:
        raise InvalidHeader("CONTENT-LENGTH", req=self)
    if length < 0:
        raise InvalidHeader("CONTENT-LENGTH", req=self)

    self.content_length = length
    self._body_remaining = length
def force_close(self):
    """Set ``must_close`` so the connection is reported as not reusable."""
    self.must_close = True
def should_close(self):
    """Report whether the connection must be closed after this request.

    Precedence: an explicit ``must_close`` flag, then the first
    ``Connection`` header if it is exactly "close" or "keep-alive",
    otherwise the HTTP version (1.0 and below close by default).
    """
    if self.must_close:
        return True

    # Only the first CONNECTION header is consulted.
    connection = next(
        (value for name, value in self.headers if name == "CONNECTION"),
        None,
    )
    if connection is not None:
        token = connection.lower().strip(" \t")
        if token == "close":
            return True
        if token == "keep-alive":
            return False

    return self.version <= (1, 0)
def get_header(self, name):
    """Return the value of the first header called ``name``, or None.

    The lookup upper-cases ``name`` before comparing it with the stored
    header names.
    """
    wanted = name.upper()
    return next(
        (value for key, value in self.headers if key == wanted),
        None,
    )
async def read_body(self, size=8192):
    """Read the next piece of the request body.

    Args:
        size: Maximum number of bytes to read.

    Returns:
        Body bytes; ``b""`` once nothing remains.
    """
    if self._body_remaining == 0:
        return b""

    # Pick the reader matching the body framing.
    reader = self._read_chunked_body if self.chunked else self._read_length_body
    return await reader(size)
async def _read_length_body(self, size):
    """Read from a body delimited by Content-Length.

    Never asks the unreader for more than what is still outstanding, and
    decrements ``_body_remaining`` by the number of bytes received.
    """
    outstanding = self._body_remaining
    if outstanding <= 0:
        return b""

    data = await self.unreader.read(min(size, outstanding))
    if data:
        self._body_remaining -= len(data)
    return data
async def _read_chunked_body(self, size):
"""Read from a chunked body."""
if self._body_reader is None:
self._body_reader = self._chunked_body_reader()
try:
return await anext(self._body_reader)
except StopAsyncIteration:
self._body_remaining = 0
return b""
async def _chunked_body_reader(self):
    """Async generator yielding the data of a chunked request body.

    For every chunk: read the size line, validate the hexadecimal size
    (ignoring any chunk extension after ";"), yield the chunk data in
    pieces of at most 8192 bytes, then require the terminating CRLF. A
    zero-size chunk ends the body after trailers are skipped.

    Raises:
        InvalidHeader: On an invalid chunk size or a missing terminator.
        NoMoreData: If the stream ends in the middle of a chunk.
    """
    hex_digits = frozenset(b"0123456789abcdefABCDEF")

    while True:
        size_line = await self._read_chunk_size_line()

        # Drop a chunk extension, if any, and the whitespace before it.
        size_field, has_extension, _extension = size_line.partition(b";")
        if has_extension:
            size_field = size_field.rstrip(b" \t")

        if not hex_digits.issuperset(size_field):
            raise InvalidHeader("Invalid chunk size")
        if not size_field:
            raise InvalidHeader("Invalid chunk size")
        outstanding = int(size_field, 16)

        if outstanding == 0:
            # Final chunk - skip trailers and final CRLF
            await self._skip_trailers()
            return

        while outstanding > 0:
            data = await self.unreader.read(min(outstanding, 8192))
            if not data:
                raise NoMoreData()
            outstanding -= len(data)
            yield data

        # The chunk data must be followed by CRLF; a short read is topped
        # up until two bytes are available or the stream ends.
        terminator = await self.unreader.read(2)
        while terminator != b"\r\n" and len(terminator) < 2:
            extra = await self.unreader.read(2 - len(terminator))
            if not extra:
                break
            terminator += extra
        if terminator != b"\r\n":
            raise InvalidHeader("Missing chunk terminator")
async def _read_chunk_size_line(self):
    """Read one chunk-size line and return it without its CRLF.

    Data is requested 64 bytes at a time; anything read beyond the line
    terminator is pushed back onto the unreader.

    Raises:
        NoMoreData: If the stream ends before a CRLF is seen.
    """
    collected = bytearray()
    while (chunk := await self.unreader.read(64)):
        collected.extend(chunk)
        eol = collected.find(b"\r\n")
        if eol < 0:
            continue
        # Push back any data after the line
        surplus = bytes(collected[eol + 2:])
        if surplus:
            self.unreader.unread(surplus)
        return bytes(collected[:eol])
    raise NoMoreData()
async def _skip_trailers(self):
    """Discard the trailer section that follows the last chunk.

    Data is requested 64 bytes at a time. The section ends either at a
    leading bare CRLF (no trailers) or at the first blank line; bytes read
    past that point are pushed back onto the unreader. Returns quietly if
    the stream ends first.
    """
    seen = bytearray()
    while (chunk := await self.unreader.read(64)):
        seen.extend(chunk)

        if seen[:2] == b"\r\n":
            # Empty trailer section: just the closing CRLF.
            consumed = 2
        else:
            blank_line = seen.find(b"\r\n\r\n")
            if blank_line < 0:
                continue
            consumed = blank_line + 4

        leftover = bytes(seen[consumed:])
        if leftover:
            self.unreader.unread(leftover)
        return
async def drain_body(self):
    """Read and discard whatever is left of the request body.

    Should be called before the connection is reused for keepalive.
    Reads in 8192-byte requests until ``read_body`` returns nothing.
    """
    while await self.read_body(8192):
        pass

View File

@ -934,7 +934,7 @@ class ASGIProtocol(asyncio.Protocol):
def _build_http_scope(self, request, sockname, peername):
"""Build ASGI HTTP scope from parsed request."""
# Use pre-computed bytes headers if available (fast path)
# Fall back to conversion for legacy requests (AsyncRequest, HTTP/2)
# Fall back to conversion for HTTP/2 requests
headers_bytes = getattr(request, 'headers_bytes', None)
if isinstance(headers_bytes, list):
headers = list(headers_bytes) # Copy to avoid mutation

View File

@ -1,304 +0,0 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
Tests for ASGI worker components.
"""
import asyncio
import ipaddress
import pytest
from gunicorn.asgi.unreader import AsyncUnreader
from gunicorn.asgi.message import AsyncRequest
class MockStreamReader:
    """In-memory stand-in for asyncio.StreamReader used by the tests.

    Serves bytes from a fixed buffer; ``pos`` is the offset of the next
    unread byte and never exceeds ``len(data)``.
    """

    def __init__(self, data):
        self.data = data
        self.pos = 0

    async def read(self, size=-1):
        """Return up to ``size`` bytes, or everything left if ``size`` < 0.

        Returns ``b""`` once the buffer is exhausted.
        """
        if self.pos >= len(self.data):
            return b""
        if size < 0:
            result = self.data[self.pos:]
        else:
            result = self.data[self.pos:self.pos + size]
        # Advance by what was actually served, not by what was asked for,
        # so an oversized request cannot push pos past the end of the data.
        self.pos += len(result)
        return result

    async def readexactly(self, n):
        """Return exactly ``n`` bytes.

        Raises:
            asyncio.IncompleteReadError: If fewer than ``n`` bytes remain;
                the error carries the bytes that were available.
        """
        if self.pos + n > len(self.data):
            raise asyncio.IncompleteReadError(
                self.data[self.pos:], n
            )
        result = self.data[self.pos:self.pos + n]
        self.pos += n
        return result
class MockConfig:
    """Bare-bones replacement for the gunicorn config object in tests.

    Exposes only the settings set below, plus the two lazily computed
    network lists.
    """

    def __init__(self):
        # Transport / proxy settings
        self.is_ssl = False
        self.proxy_protocol = "off"
        self.proxy_allow_ips = ["127.0.0.1"]
        self.forwarded_allow_ips = ["127.0.0.1"]
        # Caches for the parsed network lists (filled on first use)
        self._proxy_allow_networks = None
        self._forwarded_allow_networks = None
        self.secure_scheme_headers = {}
        self.forwarder_headers = []
        # Size limits
        self.limit_request_line = 8190
        self.limit_request_fields = 100
        self.limit_request_field_size = 8190
        # Parser strictness switches
        self.permit_unconventional_http_method = False
        self.permit_unconventional_http_version = False
        self.permit_obsolete_folding = False
        self.casefold_http_method = False
        self.strip_header_spaces = False
        self.header_map = "refuse"

    @staticmethod
    def _as_networks(entries):
        """Convert allow-list entries to networks, skipping the "*" wildcard."""
        return [ipaddress.ip_network(entry) for entry in entries if entry != "*"]

    def forwarded_allow_networks(self):
        """Return the cached networks built from ``forwarded_allow_ips``."""
        cached = self._forwarded_allow_networks
        if cached is None:
            cached = self._as_networks(self.forwarded_allow_ips)
            self._forwarded_allow_networks = cached
        return cached

    def proxy_allow_networks(self):
        """Return the cached networks built from ``proxy_allow_ips``."""
        cached = self._proxy_allow_networks
        if cached is None:
            cached = self._as_networks(self.proxy_allow_ips)
            self._proxy_allow_networks = cached
        return cached
# AsyncUnreader Tests
@pytest.mark.asyncio
async def test_async_unreader_read_chunk():
    """An unsized read hands back the data the stream provides."""
    unreader = AsyncUnreader(MockStreamReader(b"hello world"))

    chunk = await unreader.read()

    assert chunk == b"hello world"
@pytest.mark.asyncio
async def test_async_unreader_read_size():
    """A sized read returns only the requested number of bytes."""
    unreader = AsyncUnreader(MockStreamReader(b"hello world"))

    prefix = await unreader.read(5)

    assert prefix == b"hello"
@pytest.mark.asyncio
async def test_async_unreader_unread():
    """Data pushed back with unread() is what the next read returns."""
    unreader = AsyncUnreader(MockStreamReader(b"hello world"))

    # Consume everything the stream has.
    first = await unreader.read()
    assert first == b"hello world"

    # Push part of it back ...
    unreader.unread(b"world")

    # ... and expect exactly that on the following read.
    second = await unreader.read()
    assert second == b"world"
@pytest.mark.asyncio
async def test_async_unreader_read_zero():
    """Asking for zero bytes yields an empty result."""
    unreader = AsyncUnreader(MockStreamReader(b"hello"))

    nothing = await unreader.read(0)

    assert nothing == b""
@pytest.mark.asyncio
async def test_async_unreader_read_empty():
    """Reading from a stream with no data yields an empty result."""
    unreader = AsyncUnreader(MockStreamReader(b""))

    nothing = await unreader.read()

    assert nothing == b""
# AsyncRequest Tests
@pytest.mark.asyncio
async def test_async_request_simple_get():
    """A minimal GET request is parsed into method, path, version and headers."""
    raw = b"GET /path HTTP/1.1\r\nHost: localhost\r\n\r\n"
    unreader = AsyncUnreader(MockStreamReader(raw))

    parsed = await AsyncRequest.parse(MockConfig(), unreader, ("127.0.0.1", 8000))

    assert parsed.method == "GET"
    assert parsed.path == "/path"
    assert parsed.version == (1, 1)
    assert ("HOST", "localhost") in parsed.headers
@pytest.mark.asyncio
async def test_async_request_with_query():
    """The query string is split off the path."""
    raw = b"GET /search?q=test&page=1 HTTP/1.1\r\nHost: localhost\r\n\r\n"
    unreader = AsyncUnreader(MockStreamReader(raw))

    parsed = await AsyncRequest.parse(MockConfig(), unreader, ("127.0.0.1", 8000))

    assert parsed.method == "GET"
    assert parsed.path == "/search"
    assert parsed.query == "q=test&page=1"
@pytest.mark.asyncio
async def test_async_request_post_with_body():
    """A POST with Content-Length exposes its length and its body bytes."""
    raw = b"".join([
        b"POST /submit HTTP/1.1\r\n",
        b"Host: localhost\r\n",
        b"Content-Length: 11\r\n",
        b"\r\n",
        b"hello=world",
    ])
    unreader = AsyncUnreader(MockStreamReader(raw))

    parsed = await AsyncRequest.parse(MockConfig(), unreader, ("127.0.0.1", 8000))

    assert parsed.method == "POST"
    assert parsed.path == "/submit"
    assert parsed.content_length == 11

    # The body must be readable after the headers were parsed.
    payload = await parsed.read_body(100)
    assert payload == b"hello=world"
@pytest.mark.asyncio
async def test_async_request_multiple_headers():
    """All headers are kept and can be looked up by name."""
    raw = b"".join([
        b"GET / HTTP/1.1\r\n",
        b"Host: localhost\r\n",
        b"Accept: text/html\r\n",
        b"Accept-Language: en-US\r\n",
        b"Connection: keep-alive\r\n",
        b"\r\n",
    ])
    unreader = AsyncUnreader(MockStreamReader(raw))

    parsed = await AsyncRequest.parse(MockConfig(), unreader, ("127.0.0.1", 8000))

    assert len(parsed.headers) == 4
    assert parsed.get_header("HOST") == "localhost"
    assert parsed.get_header("ACCEPT") == "text/html"
@pytest.mark.asyncio
async def test_async_request_should_close_http10():
    """An HTTP/1.0 request without a Connection header closes the connection."""
    raw = b"GET / HTTP/1.0\r\nHost: localhost\r\n\r\n"
    unreader = AsyncUnreader(MockStreamReader(raw))

    parsed = await AsyncRequest.parse(MockConfig(), unreader, ("127.0.0.1", 8000))

    assert parsed.version == (1, 0)
    assert parsed.should_close() is True
@pytest.mark.asyncio
async def test_async_request_should_close_connection_header():
    """``Connection: close`` makes should_close() report True."""
    raw = b"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"
    unreader = AsyncUnreader(MockStreamReader(raw))

    parsed = await AsyncRequest.parse(MockConfig(), unreader, ("127.0.0.1", 8000))

    assert parsed.should_close() is True
@pytest.mark.asyncio
async def test_async_request_keepalive():
    """``Connection: keep-alive`` makes should_close() report False."""
    raw = b"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: keep-alive\r\n\r\n"
    unreader = AsyncUnreader(MockStreamReader(raw))

    parsed = await AsyncRequest.parse(MockConfig(), unreader, ("127.0.0.1", 8000))

    assert parsed.should_close() is False
@pytest.mark.asyncio
async def test_async_request_no_body_for_get():
    """A GET without body headers has length 0 and an empty body."""
    raw = b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n"
    unreader = AsyncUnreader(MockStreamReader(raw))

    parsed = await AsyncRequest.parse(MockConfig(), unreader, ("127.0.0.1", 8000))

    assert parsed.content_length == 0
    payload = await parsed.read_body()
    assert payload == b""
# Error handling tests
@pytest.mark.asyncio
async def test_async_request_invalid_method():
    """A malformed method token is rejected with InvalidRequestMethod."""
    from gunicorn.http.errors import InvalidRequestMethod

    raw = b"ge!t / HTTP/1.1\r\nHost: localhost\r\n\r\n"
    unreader = AsyncUnreader(MockStreamReader(raw))
    config = MockConfig()

    with pytest.raises(InvalidRequestMethod):
        await AsyncRequest.parse(config, unreader, ("127.0.0.1", 8000))
@pytest.mark.asyncio
async def test_async_request_invalid_http_version():
    """An HTTP/2.0 request line is rejected with InvalidHTTPVersion."""
    from gunicorn.http.errors import InvalidHTTPVersion

    raw = b"GET / HTTP/2.0\r\nHost: localhost\r\n\r\n"
    unreader = AsyncUnreader(MockStreamReader(raw))
    config = MockConfig()

    with pytest.raises(InvalidHTTPVersion):
        await AsyncRequest.parse(config, unreader, ("127.0.0.1", 8000))

View File

@ -1,324 +0,0 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
Tests for ASGI HTTP parser optimizations.
"""
import ipaddress
import pytest
from gunicorn.asgi.unreader import AsyncUnreader
from gunicorn.asgi.message import AsyncRequest
class MockStreamReader:
    """In-memory stand-in for asyncio.StreamReader used by the tests.

    Serves bytes from a fixed buffer; ``pos`` is the offset of the next
    unread byte and never exceeds ``len(data)``.
    """

    def __init__(self, data):
        self.data = data
        self.pos = 0

    async def read(self, size=-1):
        """Return up to ``size`` bytes, or everything left if ``size`` < 0.

        Returns ``b""`` once the buffer is exhausted.
        """
        if self.pos >= len(self.data):
            return b""
        if size < 0:
            result = self.data[self.pos:]
        else:
            result = self.data[self.pos:self.pos + size]
        # Advance by what was actually served, not by what was asked for,
        # so an oversized request cannot push pos past the end of the data.
        self.pos += len(result)
        return result
class MockConfig:
    """Test double providing the config settings set below.

    The two ``*_allow_networks`` methods parse their allow list once and
    return the same list object on later calls.
    """

    def __init__(self):
        self.is_ssl = False
        self.proxy_protocol = "off"

        # Allow lists and the caches for their parsed form.
        self.proxy_allow_ips = ["127.0.0.1"]
        self.forwarded_allow_ips = ["127.0.0.1"]
        self._proxy_allow_networks = None
        self._forwarded_allow_networks = None

        self.secure_scheme_headers = {}
        self.forwarder_headers = []

        # Request size limits.
        self.limit_request_line = 8190
        self.limit_request_fields = 100
        self.limit_request_field_size = 8190

        # Parser leniency flags.
        self.permit_unconventional_http_method = False
        self.permit_unconventional_http_version = False
        self.permit_obsolete_folding = False
        self.casefold_http_method = False
        self.strip_header_spaces = False
        self.header_map = "refuse"

    def forwarded_allow_networks(self):
        """Networks parsed from ``forwarded_allow_ips`` ("*" entries skipped)."""
        if self._forwarded_allow_networks is not None:
            return self._forwarded_allow_networks
        networks = []
        for addr in self.forwarded_allow_ips:
            if addr != "*":
                networks.append(ipaddress.ip_network(addr))
        self._forwarded_allow_networks = networks
        return networks

    def proxy_allow_networks(self):
        """Networks parsed from ``proxy_allow_ips`` ("*" entries skipped)."""
        if self._proxy_allow_networks is not None:
            return self._proxy_allow_networks
        networks = []
        for addr in self.proxy_allow_ips:
            if addr != "*":
                networks.append(ipaddress.ip_network(addr))
        self._proxy_allow_networks = networks
        return networks
# Optimized Chunk Reading Tests
@pytest.mark.asyncio
async def test_chunk_size_line_reading():
    """The chunk-size line is returned without its CRLF."""
    # A complete chunked body: one 10-byte chunk, then the final chunk.
    body = b"a\r\nhello body\r\n0\r\n\r\n"
    unreader = AsyncUnreader(MockStreamReader(body))
    request = AsyncRequest(MockConfig(), unreader, ("127.0.0.1", 8000))

    # The private helper is exercised directly.
    size_line = await request._read_chunk_size_line()

    assert size_line == b"a"
@pytest.mark.asyncio
async def test_skip_trailers_empty():
    """Skipping an empty trailer section (bare CRLF) completes without error."""
    unreader = AsyncUnreader(MockStreamReader(b"\r\n"))
    request = AsyncRequest(MockConfig(), unreader, ("127.0.0.1", 8000))

    # Passing means no exception was raised.
    await request._skip_trailers()
@pytest.mark.asyncio
async def test_skip_trailers_with_headers():
    """Skipping a trailer section that contains a header completes without error."""
    unreader = AsyncUnreader(MockStreamReader(b"X-Checksum: abc123\r\n\r\n"))
    request = AsyncRequest(MockConfig(), unreader, ("127.0.0.1", 8000))

    # Passing means no exception was raised.
    await request._skip_trailers()
# Buffer Reuse Tests
@pytest.mark.asyncio
async def test_unreader_buffer_reuse():
    """Sized reads walk through the data, and unread data is served again."""
    raw = b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"
    unreader = AsyncUnreader(MockStreamReader(raw))

    # Two consecutive 10-byte reads.
    first = await unreader.read(10)
    assert first == b"GET / HTTP"

    second = await unreader.read(10)
    assert second == b"/1.1\r\nHost"

    # Push the second piece back and read it once more.
    unreader.unread(b"/1.1\r\nHost")

    replay = await unreader.read(10)
    assert replay == b"/1.1\r\nHost"
@pytest.mark.asyncio
async def test_unreader_unread_prepends():
    """Check that data passed to ``unread()`` is returned by the next read."""
    stream = AsyncUnreader(MockStreamReader(b"original"))

    # Consume the first four bytes so the unreader has already been used.
    await stream.read(4)

    # Push back bytes that never came from the underlying reader.
    stream.unread(b"NEW")

    # The pushed-back bytes must come out first.
    head = await stream.read(3)
    assert head == b"NEW"
# Header Parsing Optimization Tests
@pytest.mark.asyncio
async def test_header_parsing_index_iteration():
    """Parse a GET request with three headers and verify the parsed fields.

    Checks the method, the path, the header count, and that each header
    appears as an ``(UPPERCASE-NAME, value)`` tuple in ``req.headers``.
    """
    payload = (
        b"GET / HTTP/1.1\r\n"
        b"Host: example.com\r\n"
        b"Content-Type: text/plain\r\n"
        b"X-Custom: value\r\n"
        b"\r\n"
    )
    stream = AsyncUnreader(MockStreamReader(payload))
    config = MockConfig()
    parsed = await AsyncRequest.parse(config, stream, ("127.0.0.1", 8000))

    assert parsed.method == "GET"
    assert parsed.path == "/"
    assert len(parsed.headers) == 3

    expected_headers = (
        ("HOST", "example.com"),
        ("CONTENT-TYPE", "text/plain"),
        ("X-CUSTOM", "value"),
    )
    for header in expected_headers:
        assert header in parsed.headers
@pytest.mark.asyncio
async def test_many_headers_performance():
    """Parse a request carrying 50 distinct headers and verify the count."""
    header_block = "".join(
        f"X-Header-{i}: value-{i}\r\n" for i in range(50)
    )
    payload = b"GET / HTTP/1.1\r\n" + header_block.encode() + b"\r\n"

    stream = AsyncUnreader(MockStreamReader(payload))
    config = MockConfig()
    parsed = await AsyncRequest.parse(config, stream, ("127.0.0.1", 8000))

    # Every generated header must be present in the parsed result.
    assert len(parsed.headers) == 50
# Bytearray Find Optimization Tests
@pytest.mark.asyncio
async def test_bytearray_find_optimization():
    """Parse a GET request with a query string and a Content-Length header.

    Verifies the method, the path and query split of the request target,
    and the parsed ``content_length``.
    """
    payload = (
        b"GET /path?query=value HTTP/1.1\r\n"
        b"Host: example.com\r\n"
        b"Content-Length: 5\r\n"
        b"\r\n"
        b"hello"
    )
    stream = AsyncUnreader(MockStreamReader(payload))
    config = MockConfig()
    parsed = await AsyncRequest.parse(config, stream, ("127.0.0.1", 8000))

    assert parsed.method == "GET"
    # The request target is split at "?" into path and query.
    assert parsed.path == "/path"
    assert parsed.query == "query=value"
    assert parsed.content_length == 5
# Chunked Body Tests with Optimized Reading
@pytest.mark.asyncio
async def test_chunked_body_optimized_reading():
    """Parse a chunked POST and read its body until exhaustion.

    The body is sent as two chunks (``hello`` and `` world``) followed by
    the terminating zero-size chunk; the reassembled body must equal
    ``b"hello world"``.
    """
    payload = (
        b"POST / HTTP/1.1\r\n"
        b"Host: example.com\r\n"
        b"Transfer-Encoding: chunked\r\n"
        b"\r\n"
        b"5\r\nhello\r\n"
        b"6\r\n world\r\n"
        b"0\r\n\r\n"
    )
    stream = AsyncUnreader(MockStreamReader(payload))
    config = MockConfig()
    parsed = await AsyncRequest.parse(config, stream, ("127.0.0.1", 8000))

    assert parsed.chunked is True
    assert parsed.content_length is None

    # Keep reading until read_body() returns an empty (falsy) result.
    pieces = []
    piece = await parsed.read_body(1024)
    while piece:
        pieces.append(piece)
        piece = await parsed.read_body(1024)

    assert b"".join(pieces) == b"hello world"
@pytest.mark.asyncio
async def test_chunked_body_with_extension():
    """Read a chunk whose size line carries a chunk extension.

    The size line is ``5;ext=value``; the first ``read_body`` call must
    still return the five data bytes ``b"hello"``.
    """
    payload = (
        b"POST / HTTP/1.1\r\n"
        b"Host: example.com\r\n"
        b"Transfer-Encoding: chunked\r\n"
        b"\r\n"
        b"5;ext=value\r\nhello\r\n"
        b"0\r\n\r\n"
    )
    stream = AsyncUnreader(MockStreamReader(payload))
    config = MockConfig()
    parsed = await AsyncRequest.parse(config, stream, ("127.0.0.1", 8000))

    first_piece = await parsed.read_body(1024)
    assert first_piece == b"hello"
# Edge Cases
@pytest.mark.asyncio
async def test_empty_headers():
    """Parse a request that has a request line but no header fields."""
    payload = b"GET / HTTP/1.1\r\n" b"\r\n"
    stream = AsyncUnreader(MockStreamReader(payload))
    config = MockConfig()
    parsed = await AsyncRequest.parse(config, stream, ("127.0.0.1", 8000))

    assert parsed.method == "GET"
    # With nothing between the request line and the blank line, the
    # parsed header collection must be empty.
    assert len(parsed.headers) == 0
@pytest.mark.asyncio
async def test_large_header_value():
    """Parse a request with a 4000-character header value and read it back."""
    # 4000 characters; the original test notes this is within the default limit.
    large_value = "x" * 4000
    header_line = f"X-Large-Header: {large_value}\r\n".encode()
    payload = b"GET / HTTP/1.1\r\n" + header_line + b"\r\n"

    stream = AsyncUnreader(MockStreamReader(payload))
    config = MockConfig()
    parsed = await AsyncRequest.parse(config, stream, ("127.0.0.1", 8000))

    # The value must be returned unchanged by the header lookup.
    assert parsed.get_header("X-Large-Header") == large_value