mirror of
https://github.com/frappe/gunicorn.git
synced 2026-07-03 19:21:29 +08:00
Remove unused AsyncRequest class
AsyncRequest was the legacy pull-based async HTTP parser, now replaced by the push-based CallbackRequest/PythonProtocol. Remove the unused code and associated tests.
This commit is contained in:
parent
b00f125755
commit
da8bd4850a
@ -10,7 +10,6 @@ HTTP parsing infrastructure adapted for async I/O.
|
||||
|
||||
Components:
|
||||
- AsyncUnreader: Async socket reading with pushback buffer
|
||||
- AsyncRequest: Async HTTP request parser
|
||||
- ASGIProtocol: asyncio.Protocol implementation for HTTP handling
|
||||
- WebSocketProtocol: WebSocket protocol handler (RFC 6455)
|
||||
- LifespanManager: ASGI lifespan protocol support
|
||||
@ -20,7 +19,6 @@ Usage:
|
||||
"""
|
||||
|
||||
from gunicorn.asgi.unreader import AsyncUnreader
|
||||
from gunicorn.asgi.message import AsyncRequest
|
||||
from gunicorn.asgi.lifespan import LifespanManager
|
||||
|
||||
__all__ = ['AsyncUnreader', 'AsyncRequest', 'LifespanManager']
|
||||
__all__ = ['AsyncUnreader', 'LifespanManager']
|
||||
|
||||
@ -1,736 +0,0 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
Async version of gunicorn/http/message.py for ASGI workers.
|
||||
|
||||
Reuses the parsing logic from the sync version, adapted for async I/O.
|
||||
"""
|
||||
|
||||
import ipaddress
|
||||
import re
|
||||
import socket
|
||||
import struct
|
||||
|
||||
from gunicorn.http.errors import (
|
||||
ExpectationFailed,
|
||||
InvalidHeader, InvalidHeaderName, NoMoreData,
|
||||
InvalidRequestLine, InvalidRequestMethod, InvalidHTTPVersion,
|
||||
LimitRequestLine, LimitRequestHeaders,
|
||||
UnsupportedTransferCoding, ObsoleteFolding,
|
||||
InvalidProxyLine, InvalidProxyHeader, ForbiddenProxyRequest,
|
||||
InvalidSchemeHeaders,
|
||||
)
|
||||
from gunicorn.http.message import (
|
||||
PP_V2_SIGNATURE, PPCommand, PPFamily, PPProtocol
|
||||
)
|
||||
from gunicorn.util import bytes_to_str, split_request_uri
|
||||
|
||||
MAX_REQUEST_LINE = 8190
|
||||
MAX_HEADERS = 32768
|
||||
DEFAULT_MAX_HEADERFIELD_SIZE = 8190
|
||||
|
||||
# Reuse regex patterns from sync version
|
||||
RFC9110_5_6_2_TOKEN_SPECIALS = r"!#$%&'*+-.^_`|~"
|
||||
TOKEN_RE = re.compile(r"[%s0-9a-zA-Z]+" % (re.escape(RFC9110_5_6_2_TOKEN_SPECIALS)))
|
||||
METHOD_BADCHAR_RE = re.compile("[a-z#]")
|
||||
VERSION_RE = re.compile(r"HTTP/(\d)\.(\d)")
|
||||
RFC9110_5_5_INVALID_AND_DANGEROUS = re.compile(r"[\0\r\n]")
|
||||
|
||||
|
||||
def _ip_in_allow_list(ip_str, allow_list, networks):
|
||||
"""Check if IP address is in the allow list.
|
||||
|
||||
Args:
|
||||
ip_str: The IP address string to check
|
||||
allow_list: The original allow list (strings, may contain "*")
|
||||
networks: Pre-computed ipaddress.ip_network objects from config
|
||||
"""
|
||||
if '*' in allow_list:
|
||||
return True
|
||||
try:
|
||||
ip = ipaddress.ip_address(ip_str)
|
||||
except ValueError:
|
||||
return False
|
||||
for network in networks:
|
||||
if ip in network:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class AsyncRequest:
|
||||
"""Async HTTP request parser.
|
||||
|
||||
Parses HTTP/1.x requests using async I/O, reusing gunicorn's
|
||||
parsing logic where possible.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, unreader, peer_addr, req_number=1):
|
||||
self.cfg = cfg
|
||||
self.unreader = unreader
|
||||
self.peer_addr = peer_addr
|
||||
self.remote_addr = peer_addr
|
||||
self.req_number = req_number
|
||||
|
||||
self.version = None
|
||||
self.method = None
|
||||
self.uri = None
|
||||
self.path = None
|
||||
self.query = None
|
||||
self.fragment = None
|
||||
self.headers = []
|
||||
self.trailers = []
|
||||
self.scheme = "https" if cfg.is_ssl else "http"
|
||||
self.must_close = False
|
||||
self._expected_100_continue = False
|
||||
|
||||
self.proxy_protocol_info = None
|
||||
|
||||
# Request line limit
|
||||
self.limit_request_line = cfg.limit_request_line
|
||||
if (self.limit_request_line < 0
|
||||
or self.limit_request_line >= MAX_REQUEST_LINE):
|
||||
self.limit_request_line = MAX_REQUEST_LINE
|
||||
|
||||
# Headers limits
|
||||
self.limit_request_fields = cfg.limit_request_fields
|
||||
if (self.limit_request_fields <= 0
|
||||
or self.limit_request_fields > MAX_HEADERS):
|
||||
self.limit_request_fields = MAX_HEADERS
|
||||
|
||||
self.limit_request_field_size = cfg.limit_request_field_size
|
||||
if self.limit_request_field_size < 0:
|
||||
self.limit_request_field_size = DEFAULT_MAX_HEADERFIELD_SIZE
|
||||
|
||||
# Max header buffer size
|
||||
max_header_field_size = self.limit_request_field_size or DEFAULT_MAX_HEADERFIELD_SIZE
|
||||
self.max_buffer_headers = self.limit_request_fields * \
|
||||
(max_header_field_size + 2) + 4
|
||||
|
||||
# Body-related state
|
||||
self.content_length = None
|
||||
self.chunked = False
|
||||
self._body_reader = None
|
||||
self._body_remaining = 0
|
||||
|
||||
@classmethod
|
||||
async def parse(cls, cfg, unreader, peer_addr, req_number=1):
|
||||
"""Parse an HTTP request from the stream.
|
||||
|
||||
Args:
|
||||
cfg: gunicorn config object
|
||||
unreader: AsyncUnreader instance
|
||||
peer_addr: client address tuple
|
||||
req_number: request number on this connection (for keepalive)
|
||||
|
||||
Returns:
|
||||
AsyncRequest: Parsed request object
|
||||
|
||||
Raises:
|
||||
NoMoreData: If no data available
|
||||
Various parsing errors for malformed requests
|
||||
"""
|
||||
req = cls(cfg, unreader, peer_addr, req_number)
|
||||
await req._parse()
|
||||
return req
|
||||
|
||||
async def _parse(self):
|
||||
"""Parse the request from the unreader."""
|
||||
buf = bytearray()
|
||||
await self._read_into(buf)
|
||||
|
||||
# Handle proxy protocol if enabled and this is the first request
|
||||
mode = self.cfg.proxy_protocol
|
||||
if mode != "off" and self.req_number == 1:
|
||||
buf = await self._handle_proxy_protocol(buf, mode)
|
||||
|
||||
# Get request line
|
||||
line, buf = await self._read_line(buf, self.limit_request_line)
|
||||
|
||||
self._parse_request_line(line)
|
||||
|
||||
# Headers - use bytearray.find() directly to avoid bytes() conversions
|
||||
while True:
|
||||
idx = buf.find(b"\r\n\r\n")
|
||||
done = buf[:2] == b"\r\n"
|
||||
|
||||
if idx < 0 and not done:
|
||||
await self._read_into(buf)
|
||||
if len(buf) > self.max_buffer_headers:
|
||||
raise LimitRequestHeaders("max buffer headers")
|
||||
else:
|
||||
break
|
||||
|
||||
if done:
|
||||
self.unreader.unread(bytes(buf[2:]))
|
||||
else:
|
||||
self.headers = self._parse_headers(bytes(buf[:idx]), from_trailer=False)
|
||||
self.unreader.unread(bytes(buf[idx + 4:]))
|
||||
|
||||
self._set_body_reader()
|
||||
|
||||
async def _read_into(self, buf):
|
||||
"""Read data from unreader and append to bytearray buffer."""
|
||||
data = await self.unreader.read()
|
||||
if not data:
|
||||
raise NoMoreData(bytes(buf))
|
||||
buf.extend(data)
|
||||
|
||||
async def _read_line(self, buf, limit=0):
|
||||
"""Read a line from buffer, returning (line, remaining_buffer).
|
||||
|
||||
Uses bytearray.find() directly to avoid repeated bytes() conversions.
|
||||
"""
|
||||
while True:
|
||||
idx = buf.find(b"\r\n")
|
||||
if idx >= 0:
|
||||
if idx > limit > 0:
|
||||
raise LimitRequestLine(idx, limit)
|
||||
break
|
||||
if len(buf) - 2 > limit > 0:
|
||||
raise LimitRequestLine(len(buf), limit)
|
||||
await self._read_into(buf)
|
||||
|
||||
line = bytes(buf[:idx])
|
||||
remaining = bytearray(buf[idx + 2:])
|
||||
return (line, remaining)
|
||||
|
||||
async def _handle_proxy_protocol(self, buf, mode):
|
||||
"""Handle PROXY protocol detection and parsing.
|
||||
|
||||
Returns the buffer with proxy protocol data consumed.
|
||||
"""
|
||||
# Ensure we have enough data to detect v2 signature (12 bytes)
|
||||
while len(buf) < 12:
|
||||
await self._read_into(buf)
|
||||
|
||||
# Check for v2 signature first
|
||||
if mode in ("v2", "auto") and buf[:12] == PP_V2_SIGNATURE:
|
||||
self._proxy_protocol_access_check()
|
||||
return await self._parse_proxy_protocol_v2(buf)
|
||||
|
||||
# Check for v1 prefix
|
||||
if mode in ("v1", "auto") and buf[:6] == b"PROXY ":
|
||||
self._proxy_protocol_access_check()
|
||||
return await self._parse_proxy_protocol_v1(buf)
|
||||
|
||||
# Not proxy protocol - return buffer unchanged
|
||||
return buf
|
||||
|
||||
def _proxy_protocol_access_check(self):
|
||||
"""Check if proxy protocol is allowed from this peer."""
|
||||
if (isinstance(self.peer_addr, tuple) and
|
||||
not _ip_in_allow_list(self.peer_addr[0], self.cfg.proxy_allow_ips,
|
||||
self.cfg.proxy_allow_networks())):
|
||||
raise ForbiddenProxyRequest(self.peer_addr[0])
|
||||
|
||||
async def _parse_proxy_protocol_v1(self, buf):
|
||||
"""Parse PROXY protocol v1 (text format).
|
||||
|
||||
Returns buffer with v1 header consumed.
|
||||
"""
|
||||
# Read until we find \r\n
|
||||
data = bytes(buf)
|
||||
while b"\r\n" not in data:
|
||||
await self._read_into(buf)
|
||||
data = bytes(buf)
|
||||
|
||||
idx = data.find(b"\r\n")
|
||||
line = bytes_to_str(data[:idx])
|
||||
remaining = bytearray(data[idx + 2:])
|
||||
|
||||
bits = line.split(" ")
|
||||
|
||||
if len(bits) != 6:
|
||||
raise InvalidProxyLine(line)
|
||||
|
||||
proto = bits[1]
|
||||
s_addr = bits[2]
|
||||
d_addr = bits[3]
|
||||
|
||||
if proto not in ["TCP4", "TCP6"]:
|
||||
raise InvalidProxyLine("protocol '%s' not supported" % proto)
|
||||
|
||||
if proto == "TCP4":
|
||||
try:
|
||||
socket.inet_pton(socket.AF_INET, s_addr)
|
||||
socket.inet_pton(socket.AF_INET, d_addr)
|
||||
except OSError:
|
||||
raise InvalidProxyLine(line)
|
||||
elif proto == "TCP6":
|
||||
try:
|
||||
socket.inet_pton(socket.AF_INET6, s_addr)
|
||||
socket.inet_pton(socket.AF_INET6, d_addr)
|
||||
except OSError:
|
||||
raise InvalidProxyLine(line)
|
||||
|
||||
try:
|
||||
s_port = int(bits[4])
|
||||
d_port = int(bits[5])
|
||||
except ValueError:
|
||||
raise InvalidProxyLine("invalid port %s" % line)
|
||||
|
||||
if not ((0 <= s_port <= 65535) and (0 <= d_port <= 65535)):
|
||||
raise InvalidProxyLine("invalid port %s" % line)
|
||||
|
||||
self.proxy_protocol_info = {
|
||||
"proxy_protocol": proto,
|
||||
"client_addr": s_addr,
|
||||
"client_port": s_port,
|
||||
"proxy_addr": d_addr,
|
||||
"proxy_port": d_port
|
||||
}
|
||||
|
||||
return remaining
|
||||
|
||||
async def _parse_proxy_protocol_v2(self, buf):
|
||||
"""Parse PROXY protocol v2 (binary format).
|
||||
|
||||
Returns buffer with v2 header consumed.
|
||||
"""
|
||||
# We need at least 16 bytes for the header (12 signature + 4 header)
|
||||
while len(buf) < 16:
|
||||
await self._read_into(buf)
|
||||
|
||||
# Parse header fields (after 12-byte signature)
|
||||
ver_cmd = buf[12]
|
||||
fam_proto = buf[13]
|
||||
length = struct.unpack(">H", bytes(buf[14:16]))[0]
|
||||
|
||||
# Validate version (high nibble must be 0x2)
|
||||
version = (ver_cmd & 0xF0) >> 4
|
||||
if version != 2:
|
||||
raise InvalidProxyHeader("unsupported version %d" % version)
|
||||
|
||||
# Extract command (low nibble)
|
||||
command = ver_cmd & 0x0F
|
||||
if command not in (PPCommand.LOCAL, PPCommand.PROXY):
|
||||
raise InvalidProxyHeader("unsupported command %d" % command)
|
||||
|
||||
# Ensure we have the complete header
|
||||
total_header_size = 16 + length
|
||||
while len(buf) < total_header_size:
|
||||
await self._read_into(buf)
|
||||
|
||||
# For LOCAL command, no address info is provided
|
||||
if command == PPCommand.LOCAL:
|
||||
self.proxy_protocol_info = {
|
||||
"proxy_protocol": "LOCAL",
|
||||
"client_addr": None,
|
||||
"client_port": None,
|
||||
"proxy_addr": None,
|
||||
"proxy_port": None
|
||||
}
|
||||
return bytearray(buf[total_header_size:])
|
||||
|
||||
# Extract address family and protocol
|
||||
family = (fam_proto & 0xF0) >> 4
|
||||
protocol = fam_proto & 0x0F
|
||||
|
||||
# We only support TCP (STREAM)
|
||||
if protocol != PPProtocol.STREAM:
|
||||
raise InvalidProxyHeader("only TCP protocol is supported")
|
||||
|
||||
addr_data = bytes(buf[16:16 + length])
|
||||
|
||||
if family == PPFamily.INET: # IPv4
|
||||
if length < 12: # 4+4+2+2
|
||||
raise InvalidProxyHeader("insufficient address data for IPv4")
|
||||
s_addr = socket.inet_ntop(socket.AF_INET, addr_data[0:4])
|
||||
d_addr = socket.inet_ntop(socket.AF_INET, addr_data[4:8])
|
||||
s_port = struct.unpack(">H", addr_data[8:10])[0]
|
||||
d_port = struct.unpack(">H", addr_data[10:12])[0]
|
||||
proto = "TCP4"
|
||||
|
||||
elif family == PPFamily.INET6: # IPv6
|
||||
if length < 36: # 16+16+2+2
|
||||
raise InvalidProxyHeader("insufficient address data for IPv6")
|
||||
s_addr = socket.inet_ntop(socket.AF_INET6, addr_data[0:16])
|
||||
d_addr = socket.inet_ntop(socket.AF_INET6, addr_data[16:32])
|
||||
s_port = struct.unpack(">H", addr_data[32:34])[0]
|
||||
d_port = struct.unpack(">H", addr_data[34:36])[0]
|
||||
proto = "TCP6"
|
||||
|
||||
elif family == PPFamily.UNSPEC:
|
||||
# No address info provided with PROXY command
|
||||
self.proxy_protocol_info = {
|
||||
"proxy_protocol": "UNSPEC",
|
||||
"client_addr": None,
|
||||
"client_port": None,
|
||||
"proxy_addr": None,
|
||||
"proxy_port": None
|
||||
}
|
||||
return bytearray(buf[total_header_size:])
|
||||
|
||||
else:
|
||||
raise InvalidProxyHeader("unsupported address family %d" % family)
|
||||
|
||||
# Set data
|
||||
self.proxy_protocol_info = {
|
||||
"proxy_protocol": proto,
|
||||
"client_addr": s_addr,
|
||||
"client_port": s_port,
|
||||
"proxy_addr": d_addr,
|
||||
"proxy_port": d_port
|
||||
}
|
||||
|
||||
return bytearray(buf[total_header_size:])
|
||||
|
||||
def _parse_request_line(self, line_bytes):
|
||||
"""Parse the HTTP request line."""
|
||||
bits = [bytes_to_str(bit) for bit in line_bytes.split(b" ", 2)]
|
||||
if len(bits) != 3:
|
||||
raise InvalidRequestLine(bytes_to_str(line_bytes))
|
||||
|
||||
# Method
|
||||
self.method = bits[0]
|
||||
|
||||
if not self.cfg.permit_unconventional_http_method:
|
||||
if METHOD_BADCHAR_RE.search(self.method):
|
||||
raise InvalidRequestMethod(self.method)
|
||||
if not 3 <= len(bits[0]) <= 20:
|
||||
raise InvalidRequestMethod(self.method)
|
||||
if not TOKEN_RE.fullmatch(self.method):
|
||||
raise InvalidRequestMethod(self.method)
|
||||
if self.cfg.casefold_http_method:
|
||||
self.method = self.method.upper()
|
||||
|
||||
# URI
|
||||
self.uri = bits[1]
|
||||
|
||||
if len(self.uri) == 0:
|
||||
raise InvalidRequestLine(bytes_to_str(line_bytes))
|
||||
|
||||
try:
|
||||
parts = split_request_uri(self.uri)
|
||||
except ValueError:
|
||||
raise InvalidRequestLine(bytes_to_str(line_bytes))
|
||||
self.path = parts.path or ""
|
||||
self.query = parts.query or ""
|
||||
self.fragment = parts.fragment or ""
|
||||
|
||||
# Version
|
||||
match = VERSION_RE.fullmatch(bits[2])
|
||||
if match is None:
|
||||
raise InvalidHTTPVersion(bits[2])
|
||||
self.version = (int(match.group(1)), int(match.group(2)))
|
||||
if not (1, 0) <= self.version < (2, 0):
|
||||
if not self.cfg.permit_unconventional_http_version:
|
||||
raise InvalidHTTPVersion(self.version)
|
||||
|
||||
def _parse_headers(self, data, from_trailer=False):
|
||||
"""Parse HTTP headers from raw data.
|
||||
|
||||
Uses index-based iteration instead of list.pop(0) for O(1) access.
|
||||
"""
|
||||
cfg = self.cfg
|
||||
headers = []
|
||||
|
||||
lines = [bytes_to_str(line) for line in data.split(b"\r\n")]
|
||||
num_lines = len(lines)
|
||||
i = 0
|
||||
|
||||
# Handle scheme headers
|
||||
scheme_header = False
|
||||
secure_scheme_headers = {}
|
||||
forwarder_headers = []
|
||||
if from_trailer:
|
||||
pass
|
||||
elif (not isinstance(self.peer_addr, tuple)
|
||||
or _ip_in_allow_list(self.peer_addr[0], cfg.forwarded_allow_ips,
|
||||
cfg.forwarded_allow_networks())):
|
||||
secure_scheme_headers = cfg.secure_scheme_headers
|
||||
forwarder_headers = cfg.forwarder_headers
|
||||
|
||||
while i < num_lines:
|
||||
if len(headers) >= self.limit_request_fields:
|
||||
raise LimitRequestHeaders("limit request headers fields")
|
||||
|
||||
curr = lines[i]
|
||||
i += 1
|
||||
header_length = len(curr) + len("\r\n")
|
||||
if curr.find(":") <= 0:
|
||||
raise InvalidHeader(curr)
|
||||
name, value = curr.split(":", 1)
|
||||
if self.cfg.strip_header_spaces:
|
||||
name = name.rstrip(" \t")
|
||||
if not TOKEN_RE.fullmatch(name):
|
||||
raise InvalidHeaderName(name)
|
||||
|
||||
name = name.upper()
|
||||
value = [value.strip(" \t")]
|
||||
|
||||
# Consume value continuation lines using index-based iteration
|
||||
while i < num_lines and lines[i].startswith((" ", "\t")):
|
||||
if not self.cfg.permit_obsolete_folding:
|
||||
raise ObsoleteFolding(name)
|
||||
curr = lines[i]
|
||||
i += 1
|
||||
header_length += len(curr) + len("\r\n")
|
||||
if header_length > self.limit_request_field_size > 0:
|
||||
raise LimitRequestHeaders("limit request headers fields size")
|
||||
value.append(curr.strip("\t "))
|
||||
value = " ".join(value)
|
||||
|
||||
if RFC9110_5_5_INVALID_AND_DANGEROUS.search(value):
|
||||
raise InvalidHeader(name)
|
||||
|
||||
if header_length > self.limit_request_field_size > 0:
|
||||
raise LimitRequestHeaders("limit request headers fields size")
|
||||
|
||||
if not from_trailer and name == "EXPECT":
|
||||
# https://datatracker.ietf.org/doc/html/rfc9110#section-10.1.1
|
||||
# "The Expect field value is case-insensitive."
|
||||
if value.lower() == "100-continue":
|
||||
if self.version < (1, 1):
|
||||
# https://datatracker.ietf.org/doc/html/rfc9110#section-10.1.1-12
|
||||
# "A server that receives a 100-continue expectation
|
||||
# in an HTTP/1.0 request MUST ignore that expectation."
|
||||
pass
|
||||
else:
|
||||
self._expected_100_continue = True
|
||||
# N.B. understood but ignored expect header does not return 417
|
||||
else:
|
||||
raise ExpectationFailed(value)
|
||||
|
||||
if name in secure_scheme_headers:
|
||||
secure = value == secure_scheme_headers[name]
|
||||
scheme = "https" if secure else "http"
|
||||
if scheme_header:
|
||||
if scheme != self.scheme:
|
||||
raise InvalidSchemeHeaders()
|
||||
else:
|
||||
scheme_header = True
|
||||
self.scheme = scheme
|
||||
|
||||
if "_" in name:
|
||||
if name in forwarder_headers or "*" in forwarder_headers:
|
||||
pass
|
||||
elif self.cfg.header_map == "dangerous":
|
||||
pass
|
||||
elif self.cfg.header_map == "drop":
|
||||
continue
|
||||
else:
|
||||
raise InvalidHeaderName(name)
|
||||
|
||||
headers.append((name, value))
|
||||
|
||||
return headers
|
||||
|
||||
def _set_body_reader(self):
|
||||
"""Determine how to read the request body."""
|
||||
chunked = False
|
||||
content_length = None
|
||||
|
||||
for (name, value) in self.headers:
|
||||
if name == "CONTENT-LENGTH":
|
||||
if content_length is not None:
|
||||
raise InvalidHeader("CONTENT-LENGTH", req=self)
|
||||
content_length = value
|
||||
elif name == "TRANSFER-ENCODING":
|
||||
vals = [v.strip() for v in value.split(',')]
|
||||
for val in vals:
|
||||
if val.lower() == "chunked":
|
||||
if chunked:
|
||||
raise InvalidHeader("TRANSFER-ENCODING", req=self)
|
||||
chunked = True
|
||||
elif val.lower() == "identity":
|
||||
if chunked:
|
||||
raise InvalidHeader("TRANSFER-ENCODING", req=self)
|
||||
elif val.lower() in ('compress', 'deflate', 'gzip'):
|
||||
if chunked:
|
||||
raise InvalidHeader("TRANSFER-ENCODING", req=self)
|
||||
self.force_close()
|
||||
else:
|
||||
raise UnsupportedTransferCoding(value)
|
||||
|
||||
if chunked:
|
||||
if self.version < (1, 1):
|
||||
raise InvalidHeader("TRANSFER-ENCODING", req=self)
|
||||
if content_length is not None:
|
||||
raise InvalidHeader("CONTENT-LENGTH", req=self)
|
||||
self.chunked = True
|
||||
self.content_length = None
|
||||
self._body_remaining = -1
|
||||
elif content_length is not None:
|
||||
try:
|
||||
if str(content_length).isnumeric():
|
||||
content_length = int(content_length)
|
||||
else:
|
||||
raise InvalidHeader("CONTENT-LENGTH", req=self)
|
||||
except ValueError:
|
||||
raise InvalidHeader("CONTENT-LENGTH", req=self)
|
||||
|
||||
if content_length < 0:
|
||||
raise InvalidHeader("CONTENT-LENGTH", req=self)
|
||||
|
||||
self.content_length = content_length
|
||||
self._body_remaining = content_length
|
||||
else:
|
||||
# No body for requests without Content-Length or Transfer-Encoding
|
||||
self.content_length = 0
|
||||
self._body_remaining = 0
|
||||
|
||||
def force_close(self):
|
||||
"""Mark connection for closing after this request."""
|
||||
self.must_close = True
|
||||
|
||||
def should_close(self):
|
||||
"""Check if connection should be closed after this request."""
|
||||
if self.must_close:
|
||||
return True
|
||||
for (h, v) in self.headers:
|
||||
if h == "CONNECTION":
|
||||
v = v.lower().strip(" \t")
|
||||
if v == "close":
|
||||
return True
|
||||
elif v == "keep-alive":
|
||||
return False
|
||||
break
|
||||
return self.version <= (1, 0)
|
||||
|
||||
def get_header(self, name):
|
||||
"""Get a header value by name (case-insensitive)."""
|
||||
name = name.upper()
|
||||
for (h, v) in self.headers:
|
||||
if h == name:
|
||||
return v
|
||||
return None
|
||||
|
||||
async def read_body(self, size=8192):
|
||||
"""Read a chunk of the request body.
|
||||
|
||||
Args:
|
||||
size: Maximum bytes to read
|
||||
|
||||
Returns:
|
||||
bytes: Body data, empty bytes when body is exhausted
|
||||
"""
|
||||
if self._body_remaining == 0:
|
||||
return b""
|
||||
|
||||
if self.chunked:
|
||||
return await self._read_chunked_body(size)
|
||||
else:
|
||||
return await self._read_length_body(size)
|
||||
|
||||
async def _read_length_body(self, size):
|
||||
"""Read from a length-delimited body."""
|
||||
if self._body_remaining <= 0:
|
||||
return b""
|
||||
|
||||
to_read = min(size, self._body_remaining)
|
||||
data = await self.unreader.read(to_read)
|
||||
if data:
|
||||
self._body_remaining -= len(data)
|
||||
return data
|
||||
|
||||
async def _read_chunked_body(self, size):
|
||||
"""Read from a chunked body."""
|
||||
if self._body_reader is None:
|
||||
self._body_reader = self._chunked_body_reader()
|
||||
|
||||
try:
|
||||
return await anext(self._body_reader)
|
||||
except StopAsyncIteration:
|
||||
self._body_remaining = 0
|
||||
return b""
|
||||
|
||||
async def _chunked_body_reader(self):
|
||||
"""Async generator for reading chunked body."""
|
||||
while True:
|
||||
# Read chunk size line
|
||||
size_line = await self._read_chunk_size_line()
|
||||
# Parse chunk size (handle extensions)
|
||||
chunk_size, *_ = size_line.split(b";", 1)
|
||||
if _:
|
||||
chunk_size = chunk_size.rstrip(b" \t")
|
||||
|
||||
if any(n not in b"0123456789abcdefABCDEF" for n in chunk_size):
|
||||
raise InvalidHeader("Invalid chunk size")
|
||||
if len(chunk_size) == 0:
|
||||
raise InvalidHeader("Invalid chunk size")
|
||||
|
||||
chunk_size = int(chunk_size, 16)
|
||||
|
||||
if chunk_size == 0:
|
||||
# Final chunk - skip trailers and final CRLF
|
||||
await self._skip_trailers()
|
||||
return
|
||||
|
||||
# Read chunk data
|
||||
remaining = chunk_size
|
||||
while remaining > 0:
|
||||
data = await self.unreader.read(min(remaining, 8192))
|
||||
if not data:
|
||||
raise NoMoreData()
|
||||
remaining -= len(data)
|
||||
yield data
|
||||
|
||||
# Skip chunk terminating CRLF
|
||||
crlf = await self.unreader.read(2)
|
||||
if crlf != b"\r\n":
|
||||
# May have partial read, try to get the rest
|
||||
while len(crlf) < 2:
|
||||
more = await self.unreader.read(2 - len(crlf))
|
||||
if not more:
|
||||
break
|
||||
crlf += more
|
||||
if crlf != b"\r\n":
|
||||
raise InvalidHeader("Missing chunk terminator")
|
||||
|
||||
async def _read_chunk_size_line(self):
|
||||
"""Read a chunk size line.
|
||||
|
||||
Performance optimization: reads 64-byte chunks instead of 1 byte at a time,
|
||||
then pushes back any excess data after finding the line terminator.
|
||||
"""
|
||||
buf = bytearray()
|
||||
while True:
|
||||
data = await self.unreader.read(64)
|
||||
if not data:
|
||||
raise NoMoreData()
|
||||
buf.extend(data)
|
||||
idx = buf.find(b"\r\n")
|
||||
if idx >= 0:
|
||||
# Push back any data after the line
|
||||
if idx + 2 < len(buf):
|
||||
self.unreader.unread(bytes(buf[idx + 2:]))
|
||||
return bytes(buf[:idx])
|
||||
|
||||
async def _skip_trailers(self):
|
||||
"""Skip trailer headers after chunked body.
|
||||
|
||||
Performance optimization: reads 64-byte chunks instead of 1 byte at a time,
|
||||
then pushes back any excess data after finding the trailer terminator.
|
||||
"""
|
||||
buf = bytearray()
|
||||
while True:
|
||||
data = await self.unreader.read(64)
|
||||
if not data:
|
||||
return
|
||||
buf.extend(data)
|
||||
# Check for empty trailer (just CRLF)
|
||||
if buf[:2] == b"\r\n":
|
||||
# Push back remaining data
|
||||
if len(buf) > 2:
|
||||
self.unreader.unread(bytes(buf[2:]))
|
||||
return
|
||||
# Check for full trailer terminator
|
||||
idx = buf.find(b"\r\n\r\n")
|
||||
if idx >= 0:
|
||||
# Push back data after the trailer
|
||||
if idx + 4 < len(buf):
|
||||
self.unreader.unread(bytes(buf[idx + 4:]))
|
||||
return
|
||||
|
||||
async def drain_body(self):
|
||||
"""Drain any unread body data.
|
||||
|
||||
Should be called before reusing connection for keepalive.
|
||||
"""
|
||||
while True:
|
||||
data = await self.read_body(8192)
|
||||
if not data:
|
||||
break
|
||||
@ -934,7 +934,7 @@ class ASGIProtocol(asyncio.Protocol):
|
||||
def _build_http_scope(self, request, sockname, peername):
|
||||
"""Build ASGI HTTP scope from parsed request."""
|
||||
# Use pre-computed bytes headers if available (fast path)
|
||||
# Fall back to conversion for legacy requests (AsyncRequest, HTTP/2)
|
||||
# Fall back to conversion for HTTP/2 requests
|
||||
headers_bytes = getattr(request, 'headers_bytes', None)
|
||||
if isinstance(headers_bytes, list):
|
||||
headers = list(headers_bytes) # Copy to avoid mutation
|
||||
|
||||
@ -1,304 +0,0 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
Tests for ASGI worker components.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import ipaddress
|
||||
import pytest
|
||||
|
||||
from gunicorn.asgi.unreader import AsyncUnreader
|
||||
from gunicorn.asgi.message import AsyncRequest
|
||||
|
||||
|
||||
class MockStreamReader:
|
||||
"""Mock asyncio.StreamReader for testing."""
|
||||
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
self.pos = 0
|
||||
|
||||
async def read(self, size=-1):
|
||||
if self.pos >= len(self.data):
|
||||
return b""
|
||||
if size < 0:
|
||||
result = self.data[self.pos:]
|
||||
self.pos = len(self.data)
|
||||
else:
|
||||
result = self.data[self.pos:self.pos + size]
|
||||
self.pos += size
|
||||
return result
|
||||
|
||||
async def readexactly(self, n):
|
||||
if self.pos + n > len(self.data):
|
||||
raise asyncio.IncompleteReadError(
|
||||
self.data[self.pos:], n
|
||||
)
|
||||
result = self.data[self.pos:self.pos + n]
|
||||
self.pos += n
|
||||
return result
|
||||
|
||||
|
||||
class MockConfig:
|
||||
"""Mock gunicorn config for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self.is_ssl = False
|
||||
self.proxy_protocol = "off"
|
||||
self.proxy_allow_ips = ["127.0.0.1"]
|
||||
self.forwarded_allow_ips = ["127.0.0.1"]
|
||||
self._proxy_allow_networks = None
|
||||
self._forwarded_allow_networks = None
|
||||
self.secure_scheme_headers = {}
|
||||
self.forwarder_headers = []
|
||||
self.limit_request_line = 8190
|
||||
self.limit_request_fields = 100
|
||||
self.limit_request_field_size = 8190
|
||||
self.permit_unconventional_http_method = False
|
||||
self.permit_unconventional_http_version = False
|
||||
self.permit_obsolete_folding = False
|
||||
self.casefold_http_method = False
|
||||
self.strip_header_spaces = False
|
||||
self.header_map = "refuse"
|
||||
|
||||
def forwarded_allow_networks(self):
|
||||
if self._forwarded_allow_networks is None:
|
||||
self._forwarded_allow_networks = [
|
||||
ipaddress.ip_network(addr)
|
||||
for addr in self.forwarded_allow_ips
|
||||
if addr != "*"
|
||||
]
|
||||
return self._forwarded_allow_networks
|
||||
|
||||
def proxy_allow_networks(self):
|
||||
if self._proxy_allow_networks is None:
|
||||
self._proxy_allow_networks = [
|
||||
ipaddress.ip_network(addr)
|
||||
for addr in self.proxy_allow_ips
|
||||
if addr != "*"
|
||||
]
|
||||
return self._proxy_allow_networks
|
||||
|
||||
|
||||
# AsyncUnreader Tests
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_unreader_read_chunk():
|
||||
"""Test basic chunk reading."""
|
||||
reader = MockStreamReader(b"hello world")
|
||||
unreader = AsyncUnreader(reader)
|
||||
data = await unreader.read()
|
||||
assert data == b"hello world"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_unreader_read_size():
|
||||
"""Test reading specific size."""
|
||||
reader = MockStreamReader(b"hello world")
|
||||
unreader = AsyncUnreader(reader)
|
||||
data = await unreader.read(5)
|
||||
assert data == b"hello"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_unreader_unread():
|
||||
"""Test unread functionality."""
|
||||
reader = MockStreamReader(b"hello world")
|
||||
unreader = AsyncUnreader(reader)
|
||||
|
||||
# Read all data
|
||||
data = await unreader.read()
|
||||
assert data == b"hello world"
|
||||
|
||||
# Unread some data
|
||||
unreader.unread(b"world")
|
||||
|
||||
# Read again should get unread data
|
||||
data = await unreader.read()
|
||||
assert data == b"world"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_unreader_read_zero():
|
||||
"""Test reading zero bytes."""
|
||||
reader = MockStreamReader(b"hello")
|
||||
unreader = AsyncUnreader(reader)
|
||||
data = await unreader.read(0)
|
||||
assert data == b""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_unreader_read_empty():
|
||||
"""Test reading from empty stream."""
|
||||
reader = MockStreamReader(b"")
|
||||
unreader = AsyncUnreader(reader)
|
||||
data = await unreader.read()
|
||||
assert data == b""
|
||||
|
||||
|
||||
# AsyncRequest Tests
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_simple_get():
|
||||
"""Test parsing a simple GET request."""
|
||||
request_data = b"GET /path HTTP/1.1\r\nHost: localhost\r\n\r\n"
|
||||
reader = MockStreamReader(request_data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
assert request.method == "GET"
|
||||
assert request.path == "/path"
|
||||
assert request.version == (1, 1)
|
||||
assert ("HOST", "localhost") in request.headers
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_with_query():
|
||||
"""Test parsing request with query string."""
|
||||
request_data = b"GET /search?q=test&page=1 HTTP/1.1\r\nHost: localhost\r\n\r\n"
|
||||
reader = MockStreamReader(request_data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
assert request.method == "GET"
|
||||
assert request.path == "/search"
|
||||
assert request.query == "q=test&page=1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_post_with_body():
|
||||
"""Test parsing POST request with body."""
|
||||
request_data = (
|
||||
b"POST /submit HTTP/1.1\r\n"
|
||||
b"Host: localhost\r\n"
|
||||
b"Content-Length: 11\r\n"
|
||||
b"\r\n"
|
||||
b"hello=world"
|
||||
)
|
||||
reader = MockStreamReader(request_data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
assert request.method == "POST"
|
||||
assert request.path == "/submit"
|
||||
assert request.content_length == 11
|
||||
|
||||
# Read body
|
||||
body = await request.read_body(100)
|
||||
assert body == b"hello=world"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_multiple_headers():
|
||||
"""Test parsing request with multiple headers."""
|
||||
request_data = (
|
||||
b"GET / HTTP/1.1\r\n"
|
||||
b"Host: localhost\r\n"
|
||||
b"Accept: text/html\r\n"
|
||||
b"Accept-Language: en-US\r\n"
|
||||
b"Connection: keep-alive\r\n"
|
||||
b"\r\n"
|
||||
)
|
||||
reader = MockStreamReader(request_data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
assert len(request.headers) == 4
|
||||
assert request.get_header("HOST") == "localhost"
|
||||
assert request.get_header("ACCEPT") == "text/html"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_should_close_http10():
|
||||
"""Test connection close detection for HTTP/1.0."""
|
||||
request_data = b"GET / HTTP/1.0\r\nHost: localhost\r\n\r\n"
|
||||
reader = MockStreamReader(request_data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
assert request.version == (1, 0)
|
||||
assert request.should_close() is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_should_close_connection_header():
|
||||
"""Test connection close detection with Connection header."""
|
||||
request_data = b"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"
|
||||
reader = MockStreamReader(request_data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
assert request.should_close() is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_keepalive():
|
||||
"""Test keepalive detection."""
|
||||
request_data = b"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: keep-alive\r\n\r\n"
|
||||
reader = MockStreamReader(request_data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
assert request.should_close() is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_no_body_for_get():
|
||||
"""Test that GET requests have no body by default."""
|
||||
request_data = b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n"
|
||||
reader = MockStreamReader(request_data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
assert request.content_length == 0
|
||||
body = await request.read_body()
|
||||
assert body == b""
|
||||
|
||||
|
||||
# Error handling tests
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_invalid_method():
|
||||
"""Test invalid HTTP method detection."""
|
||||
from gunicorn.http.errors import InvalidRequestMethod
|
||||
|
||||
request_data = b"ge!t / HTTP/1.1\r\nHost: localhost\r\n\r\n"
|
||||
reader = MockStreamReader(request_data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
with pytest.raises(InvalidRequestMethod):
|
||||
await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_invalid_http_version():
|
||||
"""Test invalid HTTP version detection."""
|
||||
from gunicorn.http.errors import InvalidHTTPVersion
|
||||
|
||||
request_data = b"GET / HTTP/2.0\r\nHost: localhost\r\n\r\n"
|
||||
reader = MockStreamReader(request_data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
with pytest.raises(InvalidHTTPVersion):
|
||||
await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
@ -1,324 +0,0 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
Tests for ASGI HTTP parser optimizations.
|
||||
"""
|
||||
|
||||
import ipaddress
|
||||
import pytest
|
||||
|
||||
from gunicorn.asgi.unreader import AsyncUnreader
|
||||
from gunicorn.asgi.message import AsyncRequest
|
||||
|
||||
|
||||
class MockStreamReader:
|
||||
"""Mock asyncio.StreamReader for testing."""
|
||||
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
self.pos = 0
|
||||
|
||||
async def read(self, size=-1):
|
||||
if self.pos >= len(self.data):
|
||||
return b""
|
||||
if size < 0:
|
||||
result = self.data[self.pos:]
|
||||
self.pos = len(self.data)
|
||||
else:
|
||||
result = self.data[self.pos:self.pos + size]
|
||||
self.pos += size
|
||||
return result
|
||||
|
||||
|
||||
class MockConfig:
|
||||
"""Mock gunicorn config for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self.is_ssl = False
|
||||
self.proxy_protocol = "off"
|
||||
self.proxy_allow_ips = ["127.0.0.1"]
|
||||
self.forwarded_allow_ips = ["127.0.0.1"]
|
||||
self._proxy_allow_networks = None
|
||||
self._forwarded_allow_networks = None
|
||||
self.secure_scheme_headers = {}
|
||||
self.forwarder_headers = []
|
||||
self.limit_request_line = 8190
|
||||
self.limit_request_fields = 100
|
||||
self.limit_request_field_size = 8190
|
||||
self.permit_unconventional_http_method = False
|
||||
self.permit_unconventional_http_version = False
|
||||
self.permit_obsolete_folding = False
|
||||
self.casefold_http_method = False
|
||||
self.strip_header_spaces = False
|
||||
self.header_map = "refuse"
|
||||
|
||||
def forwarded_allow_networks(self):
|
||||
if self._forwarded_allow_networks is None:
|
||||
self._forwarded_allow_networks = [
|
||||
ipaddress.ip_network(addr)
|
||||
for addr in self.forwarded_allow_ips
|
||||
if addr != "*"
|
||||
]
|
||||
return self._forwarded_allow_networks
|
||||
|
||||
def proxy_allow_networks(self):
|
||||
if self._proxy_allow_networks is None:
|
||||
self._proxy_allow_networks = [
|
||||
ipaddress.ip_network(addr)
|
||||
for addr in self.proxy_allow_ips
|
||||
if addr != "*"
|
||||
]
|
||||
return self._proxy_allow_networks
|
||||
|
||||
|
||||
# Optimized Chunk Reading Tests
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chunk_size_line_reading():
|
||||
"""Test optimized chunk size line reading."""
|
||||
# Simulate chunked body with chunk size line
|
||||
data = b"a\r\nhello body\r\n0\r\n\r\n"
|
||||
reader = MockStreamReader(data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
req = AsyncRequest(cfg, unreader, ("127.0.0.1", 8000))
|
||||
# Access the private method for testing
|
||||
line = await req._read_chunk_size_line()
|
||||
assert line == b"a"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skip_trailers_empty():
|
||||
"""Test skipping empty trailers."""
|
||||
data = b"\r\n"
|
||||
reader = MockStreamReader(data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
req = AsyncRequest(cfg, unreader, ("127.0.0.1", 8000))
|
||||
# Should not raise
|
||||
await req._skip_trailers()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skip_trailers_with_headers():
|
||||
"""Test skipping trailers with actual headers."""
|
||||
data = b"X-Checksum: abc123\r\n\r\n"
|
||||
reader = MockStreamReader(data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
req = AsyncRequest(cfg, unreader, ("127.0.0.1", 8000))
|
||||
# Should not raise
|
||||
await req._skip_trailers()
|
||||
|
||||
|
||||
# Buffer Reuse Tests
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unreader_buffer_reuse():
|
||||
"""Test that AsyncUnreader reuses buffers efficiently."""
|
||||
data = b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"
|
||||
reader = MockStreamReader(data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
|
||||
# Read in chunks
|
||||
chunk1 = await unreader.read(10)
|
||||
assert chunk1 == b"GET / HTTP"
|
||||
|
||||
# Read more
|
||||
chunk2 = await unreader.read(10)
|
||||
assert chunk2 == b"/1.1\r\nHost"
|
||||
|
||||
# Unread some data
|
||||
unreader.unread(b"/1.1\r\nHost")
|
||||
|
||||
# Read again - should get unreaded data
|
||||
chunk3 = await unreader.read(10)
|
||||
assert chunk3 == b"/1.1\r\nHost"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unreader_unread_prepends():
|
||||
"""Test that unread prepends data."""
|
||||
data = b"original"
|
||||
reader = MockStreamReader(data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
|
||||
# Read some data first
|
||||
await unreader.read(4) # "orig"
|
||||
|
||||
# Unread something different
|
||||
unreader.unread(b"NEW")
|
||||
|
||||
# Should read the new data first
|
||||
result = await unreader.read(3)
|
||||
assert result == b"NEW"
|
||||
|
||||
|
||||
# Header Parsing Optimization Tests
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_header_parsing_index_iteration():
|
||||
"""Test that header parsing uses index-based iteration."""
|
||||
raw_request = (
|
||||
b"GET / HTTP/1.1\r\n"
|
||||
b"Host: example.com\r\n"
|
||||
b"Content-Type: text/plain\r\n"
|
||||
b"X-Custom: value\r\n"
|
||||
b"\r\n"
|
||||
)
|
||||
reader = MockStreamReader(raw_request)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
req = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
assert req.method == "GET"
|
||||
assert req.path == "/"
|
||||
assert len(req.headers) == 3
|
||||
assert ("HOST", "example.com") in req.headers
|
||||
assert ("CONTENT-TYPE", "text/plain") in req.headers
|
||||
assert ("X-CUSTOM", "value") in req.headers
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_many_headers_performance():
|
||||
"""Test parsing request with many headers."""
|
||||
headers = []
|
||||
for i in range(50):
|
||||
headers.append(f"X-Header-{i}: value-{i}\r\n")
|
||||
|
||||
raw_request = (
|
||||
b"GET / HTTP/1.1\r\n"
|
||||
+ "".join(headers).encode()
|
||||
+ b"\r\n"
|
||||
)
|
||||
|
||||
reader = MockStreamReader(raw_request)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
req = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
assert len(req.headers) == 50
|
||||
|
||||
|
||||
# Bytearray Find Optimization Tests
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bytearray_find_optimization():
|
||||
"""Test that bytearray.find() is used instead of bytes().find()."""
|
||||
raw_request = (
|
||||
b"GET /path?query=value HTTP/1.1\r\n"
|
||||
b"Host: example.com\r\n"
|
||||
b"Content-Length: 5\r\n"
|
||||
b"\r\n"
|
||||
b"hello"
|
||||
)
|
||||
reader = MockStreamReader(raw_request)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
req = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
assert req.method == "GET"
|
||||
assert req.path == "/path"
|
||||
assert req.query == "query=value"
|
||||
assert req.content_length == 5
|
||||
|
||||
|
||||
# Chunked Body Tests with Optimized Reading
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chunked_body_optimized_reading():
|
||||
"""Test reading chunked body with optimized chunk reading."""
|
||||
raw_request = (
|
||||
b"POST / HTTP/1.1\r\n"
|
||||
b"Host: example.com\r\n"
|
||||
b"Transfer-Encoding: chunked\r\n"
|
||||
b"\r\n"
|
||||
b"5\r\nhello\r\n"
|
||||
b"6\r\n world\r\n"
|
||||
b"0\r\n\r\n"
|
||||
)
|
||||
reader = MockStreamReader(raw_request)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
req = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
assert req.chunked is True
|
||||
assert req.content_length is None
|
||||
|
||||
# Read body
|
||||
body_parts = []
|
||||
while True:
|
||||
chunk = await req.read_body(1024)
|
||||
if not chunk:
|
||||
break
|
||||
body_parts.append(chunk)
|
||||
|
||||
body = b"".join(body_parts)
|
||||
assert body == b"hello world"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chunked_body_with_extension():
|
||||
"""Test reading chunked body with chunk extensions."""
|
||||
raw_request = (
|
||||
b"POST / HTTP/1.1\r\n"
|
||||
b"Host: example.com\r\n"
|
||||
b"Transfer-Encoding: chunked\r\n"
|
||||
b"\r\n"
|
||||
b"5;ext=value\r\nhello\r\n"
|
||||
b"0\r\n\r\n"
|
||||
)
|
||||
reader = MockStreamReader(raw_request)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
req = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
chunk = await req.read_body(1024)
|
||||
assert chunk == b"hello"
|
||||
|
||||
|
||||
# Edge Cases
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_headers():
|
||||
"""Test request with no headers."""
|
||||
raw_request = (
|
||||
b"GET / HTTP/1.1\r\n"
|
||||
b"\r\n"
|
||||
)
|
||||
reader = MockStreamReader(raw_request)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
req = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
assert req.method == "GET"
|
||||
assert len(req.headers) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_header_value():
|
||||
"""Test request with large header value."""
|
||||
large_value = "x" * 4000 # Within default limit
|
||||
raw_request = (
|
||||
b"GET / HTTP/1.1\r\n"
|
||||
+ f"X-Large-Header: {large_value}\r\n".encode()
|
||||
+ b"\r\n"
|
||||
)
|
||||
reader = MockStreamReader(raw_request)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
req = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
assert req.get_header("X-Large-Header") == large_value
|
||||
Loading…
x
Reference in New Issue
Block a user