From da8bd4850ac0f2d0df215390dad88392eb538d74 Mon Sep 17 00:00:00 2001 From: Benoit Chesneau Date: Thu, 26 Mar 2026 16:08:35 +0100 Subject: [PATCH] Remove unused AsyncRequest class AsyncRequest was the legacy pull-based async HTTP parser, now replaced by the push-based CallbackRequest/PythonProtocol. Remove the unused code and associated tests. --- gunicorn/asgi/__init__.py | 4 +- gunicorn/asgi/message.py | 736 -------------------------------------- gunicorn/asgi/protocol.py | 2 +- tests/test_asgi.py | 304 ---------------- tests/test_asgi_parser.py | 324 ----------------- 5 files changed, 2 insertions(+), 1368 deletions(-) delete mode 100644 gunicorn/asgi/message.py delete mode 100644 tests/test_asgi.py delete mode 100644 tests/test_asgi_parser.py diff --git a/gunicorn/asgi/__init__.py b/gunicorn/asgi/__init__.py index c2f13b2a..7c5a0610 100644 --- a/gunicorn/asgi/__init__.py +++ b/gunicorn/asgi/__init__.py @@ -10,7 +10,6 @@ HTTP parsing infrastructure adapted for async I/O. Components: - AsyncUnreader: Async socket reading with pushback buffer -- AsyncRequest: Async HTTP request parser - ASGIProtocol: asyncio.Protocol implementation for HTTP handling - WebSocketProtocol: WebSocket protocol handler (RFC 6455) - LifespanManager: ASGI lifespan protocol support @@ -20,7 +19,6 @@ Usage: """ from gunicorn.asgi.unreader import AsyncUnreader -from gunicorn.asgi.message import AsyncRequest from gunicorn.asgi.lifespan import LifespanManager -__all__ = ['AsyncUnreader', 'AsyncRequest', 'LifespanManager'] +__all__ = ['AsyncUnreader', 'LifespanManager'] diff --git a/gunicorn/asgi/message.py b/gunicorn/asgi/message.py deleted file mode 100644 index 45afb851..00000000 --- a/gunicorn/asgi/message.py +++ /dev/null @@ -1,736 +0,0 @@ -# -# This file is part of gunicorn released under the MIT license. -# See the NOTICE for more information. - -""" -Async version of gunicorn/http/message.py for ASGI workers. - -Reuses the parsing logic from the sync version, adapted for async I/O. -""" - -import ipaddress -import re -import socket -import struct - -from gunicorn.http.errors import ( - ExpectationFailed, - InvalidHeader, InvalidHeaderName, NoMoreData, - InvalidRequestLine, InvalidRequestMethod, InvalidHTTPVersion, - LimitRequestLine, LimitRequestHeaders, - UnsupportedTransferCoding, ObsoleteFolding, - InvalidProxyLine, InvalidProxyHeader, ForbiddenProxyRequest, - InvalidSchemeHeaders, -) -from gunicorn.http.message import ( - PP_V2_SIGNATURE, PPCommand, PPFamily, PPProtocol -) -from gunicorn.util import bytes_to_str, split_request_uri - -MAX_REQUEST_LINE = 8190 -MAX_HEADERS = 32768 -DEFAULT_MAX_HEADERFIELD_SIZE = 8190 - -# Reuse regex patterns from sync version -RFC9110_5_6_2_TOKEN_SPECIALS = r"!#$%&'*+-.^_`|~" -TOKEN_RE = re.compile(r"[%s0-9a-zA-Z]+" % (re.escape(RFC9110_5_6_2_TOKEN_SPECIALS))) -METHOD_BADCHAR_RE = re.compile("[a-z#]") -VERSION_RE = re.compile(r"HTTP/(\d)\.(\d)") -RFC9110_5_5_INVALID_AND_DANGEROUS = re.compile(r"[\0\r\n]") - - -def _ip_in_allow_list(ip_str, allow_list, networks): - """Check if IP address is in the allow list. - - Args: - ip_str: The IP address string to check - allow_list: The original allow list (strings, may contain "*") - networks: Pre-computed ipaddress.ip_network objects from config - """ - if '*' in allow_list: - return True - try: - ip = ipaddress.ip_address(ip_str) - except ValueError: - return False - for network in networks: - if ip in network: - return True - return False - - -class AsyncRequest: - """Async HTTP request parser. - - Parses HTTP/1.x requests using async I/O, reusing gunicorn's - parsing logic where possible. - """ - - def __init__(self, cfg, unreader, peer_addr, req_number=1): - self.cfg = cfg - self.unreader = unreader - self.peer_addr = peer_addr - self.remote_addr = peer_addr - self.req_number = req_number - - self.version = None - self.method = None - self.uri = None - self.path = None - self.query = None - self.fragment = None - self.headers = [] - self.trailers = [] - self.scheme = "https" if cfg.is_ssl else "http" - self.must_close = False - self._expected_100_continue = False - - self.proxy_protocol_info = None - - # Request line limit - self.limit_request_line = cfg.limit_request_line - if (self.limit_request_line < 0 - or self.limit_request_line >= MAX_REQUEST_LINE): - self.limit_request_line = MAX_REQUEST_LINE - - # Headers limits - self.limit_request_fields = cfg.limit_request_fields - if (self.limit_request_fields <= 0 - or self.limit_request_fields > MAX_HEADERS): - self.limit_request_fields = MAX_HEADERS - - self.limit_request_field_size = cfg.limit_request_field_size - if self.limit_request_field_size < 0: - self.limit_request_field_size = DEFAULT_MAX_HEADERFIELD_SIZE - - # Max header buffer size - max_header_field_size = self.limit_request_field_size or DEFAULT_MAX_HEADERFIELD_SIZE - self.max_buffer_headers = self.limit_request_fields * \ - (max_header_field_size + 2) + 4 - - # Body-related state - self.content_length = None - self.chunked = False - self._body_reader = None - self._body_remaining = 0 - - @classmethod - async def parse(cls, cfg, unreader, peer_addr, req_number=1): - """Parse an HTTP request from the stream. - - Args: - cfg: gunicorn config object - unreader: AsyncUnreader instance - peer_addr: client address tuple - req_number: request number on this connection (for keepalive) - - Returns: - AsyncRequest: Parsed request object - - Raises: - NoMoreData: If no data available - Various parsing errors for malformed requests - """ - req = cls(cfg, unreader, peer_addr, req_number) - await req._parse() - return req - - async def _parse(self): - """Parse the request from the unreader.""" - buf = bytearray() - await self._read_into(buf) - - # Handle proxy protocol if enabled and this is the first request - mode = self.cfg.proxy_protocol - if mode != "off" and self.req_number == 1: - buf = await self._handle_proxy_protocol(buf, mode) - - # Get request line - line, buf = await self._read_line(buf, self.limit_request_line) - - self._parse_request_line(line) - - # Headers - use bytearray.find() directly to avoid bytes() conversions - while True: - idx = buf.find(b"\r\n\r\n") - done = buf[:2] == b"\r\n" - - if idx < 0 and not done: - await self._read_into(buf) - if len(buf) > self.max_buffer_headers: - raise LimitRequestHeaders("max buffer headers") - else: - break - - if done: - self.unreader.unread(bytes(buf[2:])) - else: - self.headers = self._parse_headers(bytes(buf[:idx]), from_trailer=False) - self.unreader.unread(bytes(buf[idx + 4:])) - - self._set_body_reader() - - async def _read_into(self, buf): - """Read data from unreader and append to bytearray buffer.""" - data = await self.unreader.read() - if not data: - raise NoMoreData(bytes(buf)) - buf.extend(data) - - async def _read_line(self, buf, limit=0): - """Read a line from buffer, returning (line, remaining_buffer). - - Uses bytearray.find() directly to avoid repeated bytes() conversions. - """ - while True: - idx = buf.find(b"\r\n") - if idx >= 0: - if idx > limit > 0: - raise LimitRequestLine(idx, limit) - break - if len(buf) - 2 > limit > 0: - raise LimitRequestLine(len(buf), limit) - await self._read_into(buf) - - line = bytes(buf[:idx]) - remaining = bytearray(buf[idx + 2:]) - return (line, remaining) - - async def _handle_proxy_protocol(self, buf, mode): - """Handle PROXY protocol detection and parsing. - - Returns the buffer with proxy protocol data consumed. - """ - # Ensure we have enough data to detect v2 signature (12 bytes) - while len(buf) < 12: - await self._read_into(buf) - - # Check for v2 signature first - if mode in ("v2", "auto") and buf[:12] == PP_V2_SIGNATURE: - self._proxy_protocol_access_check() - return await self._parse_proxy_protocol_v2(buf) - - # Check for v1 prefix - if mode in ("v1", "auto") and buf[:6] == b"PROXY ": - self._proxy_protocol_access_check() - return await self._parse_proxy_protocol_v1(buf) - - # Not proxy protocol - return buffer unchanged - return buf - - def _proxy_protocol_access_check(self): - """Check if proxy protocol is allowed from this peer.""" - if (isinstance(self.peer_addr, tuple) and - not _ip_in_allow_list(self.peer_addr[0], self.cfg.proxy_allow_ips, - self.cfg.proxy_allow_networks())): - raise ForbiddenProxyRequest(self.peer_addr[0]) - - async def _parse_proxy_protocol_v1(self, buf): - """Parse PROXY protocol v1 (text format). - - Returns buffer with v1 header consumed. - """ - # Read until we find \r\n - data = bytes(buf) - while b"\r\n" not in data: - await self._read_into(buf) - data = bytes(buf) - - idx = data.find(b"\r\n") - line = bytes_to_str(data[:idx]) - remaining = bytearray(data[idx + 2:]) - - bits = line.split(" ") - - if len(bits) != 6: - raise InvalidProxyLine(line) - - proto = bits[1] - s_addr = bits[2] - d_addr = bits[3] - - if proto not in ["TCP4", "TCP6"]: - raise InvalidProxyLine("protocol '%s' not supported" % proto) - - if proto == "TCP4": - try: - socket.inet_pton(socket.AF_INET, s_addr) - socket.inet_pton(socket.AF_INET, d_addr) - except OSError: - raise InvalidProxyLine(line) - elif proto == "TCP6": - try: - socket.inet_pton(socket.AF_INET6, s_addr) - socket.inet_pton(socket.AF_INET6, d_addr) - except OSError: - raise InvalidProxyLine(line) - - try: - s_port = int(bits[4]) - d_port = int(bits[5]) - except ValueError: - raise InvalidProxyLine("invalid port %s" % line) - - if not ((0 <= s_port <= 65535) and (0 <= d_port <= 65535)): - raise InvalidProxyLine("invalid port %s" % line) - - self.proxy_protocol_info = { - "proxy_protocol": proto, - "client_addr": s_addr, - "client_port": s_port, - "proxy_addr": d_addr, - "proxy_port": d_port - } - - return remaining - - async def _parse_proxy_protocol_v2(self, buf): - """Parse PROXY protocol v2 (binary format). - - Returns buffer with v2 header consumed. - """ - # We need at least 16 bytes for the header (12 signature + 4 header) - while len(buf) < 16: - await self._read_into(buf) - - # Parse header fields (after 12-byte signature) - ver_cmd = buf[12] - fam_proto = buf[13] - length = struct.unpack(">H", bytes(buf[14:16]))[0] - - # Validate version (high nibble must be 0x2) - version = (ver_cmd & 0xF0) >> 4 - if version != 2: - raise InvalidProxyHeader("unsupported version %d" % version) - - # Extract command (low nibble) - command = ver_cmd & 0x0F - if command not in (PPCommand.LOCAL, PPCommand.PROXY): - raise InvalidProxyHeader("unsupported command %d" % command) - - # Ensure we have the complete header - total_header_size = 16 + length - while len(buf) < total_header_size: - await self._read_into(buf) - - # For LOCAL command, no address info is provided - if command == PPCommand.LOCAL: - self.proxy_protocol_info = { - "proxy_protocol": "LOCAL", - "client_addr": None, - "client_port": None, - "proxy_addr": None, - "proxy_port": None - } - return bytearray(buf[total_header_size:]) - - # Extract address family and protocol - family = (fam_proto & 0xF0) >> 4 - protocol = fam_proto & 0x0F - - # We only support TCP (STREAM) - if protocol != PPProtocol.STREAM: - raise InvalidProxyHeader("only TCP protocol is supported") - - addr_data = bytes(buf[16:16 + length]) - - if family == PPFamily.INET: # IPv4 - if length < 12: # 4+4+2+2 - raise InvalidProxyHeader("insufficient address data for IPv4") - s_addr = socket.inet_ntop(socket.AF_INET, addr_data[0:4]) - d_addr = socket.inet_ntop(socket.AF_INET, addr_data[4:8]) - s_port = struct.unpack(">H", addr_data[8:10])[0] - d_port = struct.unpack(">H", addr_data[10:12])[0] - proto = "TCP4" - - elif family == PPFamily.INET6: # IPv6 - if length < 36: # 16+16+2+2 - raise InvalidProxyHeader("insufficient address data for IPv6") - s_addr = socket.inet_ntop(socket.AF_INET6, addr_data[0:16]) - d_addr = socket.inet_ntop(socket.AF_INET6, addr_data[16:32]) - s_port = struct.unpack(">H", addr_data[32:34])[0] - d_port = struct.unpack(">H", addr_data[34:36])[0] - proto = "TCP6" - - elif family == PPFamily.UNSPEC: - # No address info provided with PROXY command - self.proxy_protocol_info = { - "proxy_protocol": "UNSPEC", - "client_addr": None, - "client_port": None, - "proxy_addr": None, - "proxy_port": None - } - return bytearray(buf[total_header_size:]) - - else: - raise InvalidProxyHeader("unsupported address family %d" % family) - - # Set data - self.proxy_protocol_info = { - "proxy_protocol": proto, - "client_addr": s_addr, - "client_port": s_port, - "proxy_addr": d_addr, - "proxy_port": d_port - } - - return bytearray(buf[total_header_size:]) - - def _parse_request_line(self, line_bytes): - """Parse the HTTP request line.""" - bits = [bytes_to_str(bit) for bit in line_bytes.split(b" ", 2)] - if len(bits) != 3: - raise InvalidRequestLine(bytes_to_str(line_bytes)) - - # Method - self.method = bits[0] - - if not self.cfg.permit_unconventional_http_method: - if METHOD_BADCHAR_RE.search(self.method): - raise InvalidRequestMethod(self.method) - if not 3 <= len(bits[0]) <= 20: - raise InvalidRequestMethod(self.method) - if not TOKEN_RE.fullmatch(self.method): - raise InvalidRequestMethod(self.method) - if self.cfg.casefold_http_method: - self.method = self.method.upper() - - # URI - self.uri = bits[1] - - if len(self.uri) == 0: - raise InvalidRequestLine(bytes_to_str(line_bytes)) - - try: - parts = split_request_uri(self.uri) - except ValueError: - raise InvalidRequestLine(bytes_to_str(line_bytes)) - self.path = parts.path or "" - self.query = parts.query or "" - self.fragment = parts.fragment or "" - - # Version - match = VERSION_RE.fullmatch(bits[2]) - if match is None: - raise InvalidHTTPVersion(bits[2]) - self.version = (int(match.group(1)), int(match.group(2))) - if not (1, 0) <= self.version < (2, 0): - if not self.cfg.permit_unconventional_http_version: - raise InvalidHTTPVersion(self.version) - - def _parse_headers(self, data, from_trailer=False): - """Parse HTTP headers from raw data. - - Uses index-based iteration instead of list.pop(0) for O(1) access. - """ - cfg = self.cfg - headers = [] - - lines = [bytes_to_str(line) for line in data.split(b"\r\n")] - num_lines = len(lines) - i = 0 - - # Handle scheme headers - scheme_header = False - secure_scheme_headers = {} - forwarder_headers = [] - if from_trailer: - pass - elif (not isinstance(self.peer_addr, tuple) - or _ip_in_allow_list(self.peer_addr[0], cfg.forwarded_allow_ips, - cfg.forwarded_allow_networks())): - secure_scheme_headers = cfg.secure_scheme_headers - forwarder_headers = cfg.forwarder_headers - - while i < num_lines: - if len(headers) >= self.limit_request_fields: - raise LimitRequestHeaders("limit request headers fields") - - curr = lines[i] - i += 1 - header_length = len(curr) + len("\r\n") - if curr.find(":") <= 0: - raise InvalidHeader(curr) - name, value = curr.split(":", 1) - if self.cfg.strip_header_spaces: - name = name.rstrip(" \t") - if not TOKEN_RE.fullmatch(name): - raise InvalidHeaderName(name) - - name = name.upper() - value = [value.strip(" \t")] - - # Consume value continuation lines using index-based iteration - while i < num_lines and lines[i].startswith((" ", "\t")): - if not self.cfg.permit_obsolete_folding: - raise ObsoleteFolding(name) - curr = lines[i] - i += 1 - header_length += len(curr) + len("\r\n") - if header_length > self.limit_request_field_size > 0: - raise LimitRequestHeaders("limit request headers fields size") - value.append(curr.strip("\t ")) - value = " ".join(value) - - if RFC9110_5_5_INVALID_AND_DANGEROUS.search(value): - raise InvalidHeader(name) - - if header_length > self.limit_request_field_size > 0: - raise LimitRequestHeaders("limit request headers fields size") - - if not from_trailer and name == "EXPECT": - # https://datatracker.ietf.org/doc/html/rfc9110#section-10.1.1 - # "The Expect field value is case-insensitive." - if value.lower() == "100-continue": - if self.version < (1, 1): - # https://datatracker.ietf.org/doc/html/rfc9110#section-10.1.1-12 - # "A server that receives a 100-continue expectation - # in an HTTP/1.0 request MUST ignore that expectation." - pass - else: - self._expected_100_continue = True - # N.B. understood but ignored expect header does not return 417 - else: - raise ExpectationFailed(value) - - if name in secure_scheme_headers: - secure = value == secure_scheme_headers[name] - scheme = "https" if secure else "http" - if scheme_header: - if scheme != self.scheme: - raise InvalidSchemeHeaders() - else: - scheme_header = True - self.scheme = scheme - - if "_" in name: - if name in forwarder_headers or "*" in forwarder_headers: - pass - elif self.cfg.header_map == "dangerous": - pass - elif self.cfg.header_map == "drop": - continue - else: - raise InvalidHeaderName(name) - - headers.append((name, value)) - - return headers - - def _set_body_reader(self): - """Determine how to read the request body.""" - chunked = False - content_length = None - - for (name, value) in self.headers: - if name == "CONTENT-LENGTH": - if content_length is not None: - raise InvalidHeader("CONTENT-LENGTH", req=self) - content_length = value - elif name == "TRANSFER-ENCODING": - vals = [v.strip() for v in value.split(',')] - for val in vals: - if val.lower() == "chunked": - if chunked: - raise InvalidHeader("TRANSFER-ENCODING", req=self) - chunked = True - elif val.lower() == "identity": - if chunked: - raise InvalidHeader("TRANSFER-ENCODING", req=self) - elif val.lower() in ('compress', 'deflate', 'gzip'): - if chunked: - raise InvalidHeader("TRANSFER-ENCODING", req=self) - self.force_close() - else: - raise UnsupportedTransferCoding(value) - - if chunked: - if self.version < (1, 1): - raise InvalidHeader("TRANSFER-ENCODING", req=self) - if content_length is not None: - raise InvalidHeader("CONTENT-LENGTH", req=self) - self.chunked = True - self.content_length = None - self._body_remaining = -1 - elif content_length is not None: - try: - if str(content_length).isnumeric(): - content_length = int(content_length) - else: - raise InvalidHeader("CONTENT-LENGTH", req=self) - except ValueError: - raise InvalidHeader("CONTENT-LENGTH", req=self) - - if content_length < 0: - raise InvalidHeader("CONTENT-LENGTH", req=self) - - self.content_length = content_length - self._body_remaining = content_length - else: - # No body for requests without Content-Length or Transfer-Encoding - self.content_length = 0 - self._body_remaining = 0 - - def force_close(self): - """Mark connection for closing after this request.""" - self.must_close = True - - def should_close(self): - """Check if connection should be closed after this request.""" - if self.must_close: - return True - for (h, v) in self.headers: - if h == "CONNECTION": - v = v.lower().strip(" \t") - if v == "close": - return True - elif v == "keep-alive": - return False - break - return self.version <= (1, 0) - - def get_header(self, name): - """Get a header value by name (case-insensitive).""" - name = name.upper() - for (h, v) in self.headers: - if h == name: - return v - return None - - async def read_body(self, size=8192): - """Read a chunk of the request body. - - Args: - size: Maximum bytes to read - - Returns: - bytes: Body data, empty bytes when body is exhausted - """ - if self._body_remaining == 0: - return b"" - - if self.chunked: - return await self._read_chunked_body(size) - else: - return await self._read_length_body(size) - - async def _read_length_body(self, size): - """Read from a length-delimited body.""" - if self._body_remaining <= 0: - return b"" - - to_read = min(size, self._body_remaining) - data = await self.unreader.read(to_read) - if data: - self._body_remaining -= len(data) - return data - - async def _read_chunked_body(self, size): - """Read from a chunked body.""" - if self._body_reader is None: - self._body_reader = self._chunked_body_reader() - - try: - return await anext(self._body_reader) - except StopAsyncIteration: - self._body_remaining = 0 - return b"" - - async def _chunked_body_reader(self): - """Async generator for reading chunked body.""" - while True: - # Read chunk size line - size_line = await self._read_chunk_size_line() - # Parse chunk size (handle extensions) - chunk_size, *_ = size_line.split(b";", 1) - if _: - chunk_size = chunk_size.rstrip(b" \t") - - if any(n not in b"0123456789abcdefABCDEF" for n in chunk_size): - raise InvalidHeader("Invalid chunk size") - if len(chunk_size) == 0: - raise InvalidHeader("Invalid chunk size") - - chunk_size = int(chunk_size, 16) - - if chunk_size == 0: - # Final chunk - skip trailers and final CRLF - await self._skip_trailers() - return - - # Read chunk data - remaining = chunk_size - while remaining > 0: - data = await self.unreader.read(min(remaining, 8192)) - if not data: - raise NoMoreData() - remaining -= len(data) - yield data - - # Skip chunk terminating CRLF - crlf = await self.unreader.read(2) - if crlf != b"\r\n": - # May have partial read, try to get the rest - while len(crlf) < 2: - more = await self.unreader.read(2 - len(crlf)) - if not more: - break - crlf += more - if crlf != b"\r\n": - raise InvalidHeader("Missing chunk terminator") - - async def _read_chunk_size_line(self): - """Read a chunk size line. - - Performance optimization: reads 64-byte chunks instead of 1 byte at a time, - then pushes back any excess data after finding the line terminator. - """ - buf = bytearray() - while True: - data = await self.unreader.read(64) - if not data: - raise NoMoreData() - buf.extend(data) - idx = buf.find(b"\r\n") - if idx >= 0: - # Push back any data after the line - if idx + 2 < len(buf): - self.unreader.unread(bytes(buf[idx + 2:])) - return bytes(buf[:idx]) - - async def _skip_trailers(self): - """Skip trailer headers after chunked body. - - Performance optimization: reads 64-byte chunks instead of 1 byte at a time, - then pushes back any excess data after finding the trailer terminator. - """ - buf = bytearray() - while True: - data = await self.unreader.read(64) - if not data: - return - buf.extend(data) - # Check for empty trailer (just CRLF) - if buf[:2] == b"\r\n": - # Push back remaining data - if len(buf) > 2: - self.unreader.unread(bytes(buf[2:])) - return - # Check for full trailer terminator - idx = buf.find(b"\r\n\r\n") - if idx >= 0: - # Push back data after the trailer - if idx + 4 < len(buf): - self.unreader.unread(bytes(buf[idx + 4:])) - return - - async def drain_body(self): - """Drain any unread body data. - - Should be called before reusing connection for keepalive. - """ - while True: - data = await self.read_body(8192) - if not data: - break diff --git a/gunicorn/asgi/protocol.py b/gunicorn/asgi/protocol.py index ac38d155..4784763f 100644 --- a/gunicorn/asgi/protocol.py +++ b/gunicorn/asgi/protocol.py @@ -934,7 +934,7 @@ class ASGIProtocol(asyncio.Protocol): def _build_http_scope(self, request, sockname, peername): """Build ASGI HTTP scope from parsed request.""" # Use pre-computed bytes headers if available (fast path) - # Fall back to conversion for legacy requests (AsyncRequest, HTTP/2) + # Fall back to conversion for HTTP/2 requests headers_bytes = getattr(request, 'headers_bytes', None) if isinstance(headers_bytes, list): headers = list(headers_bytes) # Copy to avoid mutation diff --git a/tests/test_asgi.py b/tests/test_asgi.py deleted file mode 100644 index 74e11ba2..00000000 --- a/tests/test_asgi.py +++ /dev/null @@ -1,304 +0,0 @@ -# -# This file is part of gunicorn released under the MIT license. -# See the NOTICE for more information. - -""" -Tests for ASGI worker components. -""" - -import asyncio -import ipaddress -import pytest - -from gunicorn.asgi.unreader import AsyncUnreader -from gunicorn.asgi.message import AsyncRequest - - -class MockStreamReader: - """Mock asyncio.StreamReader for testing.""" - - def __init__(self, data): - self.data = data - self.pos = 0 - - async def read(self, size=-1): - if self.pos >= len(self.data): - return b"" - if size < 0: - result = self.data[self.pos:] - self.pos = len(self.data) - else: - result = self.data[self.pos:self.pos + size] - self.pos += size - return result - - async def readexactly(self, n): - if self.pos + n > len(self.data): - raise asyncio.IncompleteReadError( - self.data[self.pos:], n - ) - result = self.data[self.pos:self.pos + n] - self.pos += n - return result - - -class MockConfig: - """Mock gunicorn config for testing.""" - - def __init__(self): - self.is_ssl = False - self.proxy_protocol = "off" - self.proxy_allow_ips = ["127.0.0.1"] - self.forwarded_allow_ips = ["127.0.0.1"] - self._proxy_allow_networks = None - self._forwarded_allow_networks = None - self.secure_scheme_headers = {} - self.forwarder_headers = [] - self.limit_request_line = 8190 - self.limit_request_fields = 100 - self.limit_request_field_size = 8190 - self.permit_unconventional_http_method = False - self.permit_unconventional_http_version = False - self.permit_obsolete_folding = False - self.casefold_http_method = False - self.strip_header_spaces = False - self.header_map = "refuse" - - def forwarded_allow_networks(self): - if self._forwarded_allow_networks is None: - self._forwarded_allow_networks = [ - ipaddress.ip_network(addr) - for addr in self.forwarded_allow_ips - if addr != "*" - ] - return self._forwarded_allow_networks - - def proxy_allow_networks(self): - if self._proxy_allow_networks is None: - self._proxy_allow_networks = [ - ipaddress.ip_network(addr) - for addr in self.proxy_allow_ips - if addr != "*" - ] - return self._proxy_allow_networks - - -# AsyncUnreader Tests - -@pytest.mark.asyncio -async def test_async_unreader_read_chunk(): - """Test basic chunk reading.""" - reader = MockStreamReader(b"hello world") - unreader = AsyncUnreader(reader) - data = await unreader.read() - assert data == b"hello world" - - -@pytest.mark.asyncio -async def test_async_unreader_read_size(): - """Test reading specific size.""" - reader = MockStreamReader(b"hello world") - unreader = AsyncUnreader(reader) - data = await unreader.read(5) - assert data == b"hello" - - -@pytest.mark.asyncio -async def test_async_unreader_unread(): - """Test unread functionality.""" - reader = MockStreamReader(b"hello world") - unreader = AsyncUnreader(reader) - - # Read all data - data = await unreader.read() - assert data == b"hello world" - - # Unread some data - unreader.unread(b"world") - - # Read again should get unread data - data = await unreader.read() - assert data == b"world" - - -@pytest.mark.asyncio -async def test_async_unreader_read_zero(): - """Test reading zero bytes.""" - reader = MockStreamReader(b"hello") - unreader = AsyncUnreader(reader) - data = await unreader.read(0) - assert data == b"" - - -@pytest.mark.asyncio -async def test_async_unreader_read_empty(): - """Test reading from empty stream.""" - reader = MockStreamReader(b"") - unreader = AsyncUnreader(reader) - data = await unreader.read() - assert data == b"" - - -# AsyncRequest Tests - -@pytest.mark.asyncio -async def test_async_request_simple_get(): - """Test parsing a simple GET request.""" - request_data = b"GET /path HTTP/1.1\r\nHost: localhost\r\n\r\n" - reader = MockStreamReader(request_data) - unreader = AsyncUnreader(reader) - cfg = MockConfig() - - request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) - - assert request.method == "GET" - assert request.path == "/path" - assert request.version == (1, 1) - assert ("HOST", "localhost") in request.headers - - -@pytest.mark.asyncio -async def test_async_request_with_query(): - """Test parsing request with query string.""" - request_data = b"GET /search?q=test&page=1 HTTP/1.1\r\nHost: localhost\r\n\r\n" - reader = MockStreamReader(request_data) - unreader = AsyncUnreader(reader) - cfg = MockConfig() - - request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) - - assert request.method == "GET" - assert request.path == "/search" - assert request.query == "q=test&page=1" - - -@pytest.mark.asyncio -async def test_async_request_post_with_body(): - """Test parsing POST request with body.""" - request_data = ( - b"POST /submit HTTP/1.1\r\n" - b"Host: localhost\r\n" - b"Content-Length: 11\r\n" - b"\r\n" - b"hello=world" - ) - reader = MockStreamReader(request_data) - unreader = AsyncUnreader(reader) - cfg = MockConfig() - - request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) - - assert request.method == "POST" - assert request.path == "/submit" - assert request.content_length == 11 - - # Read body - body = await request.read_body(100) - assert body == b"hello=world" - - -@pytest.mark.asyncio -async def test_async_request_multiple_headers(): - """Test parsing request with multiple headers.""" - request_data = ( - b"GET / HTTP/1.1\r\n" - b"Host: localhost\r\n" - b"Accept: text/html\r\n" - b"Accept-Language: en-US\r\n" - b"Connection: keep-alive\r\n" - b"\r\n" - ) - reader = MockStreamReader(request_data) - unreader = AsyncUnreader(reader) - cfg = MockConfig() - - request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) - - assert len(request.headers) == 4 - assert request.get_header("HOST") == "localhost" - assert request.get_header("ACCEPT") == "text/html" - - -@pytest.mark.asyncio -async def test_async_request_should_close_http10(): - """Test connection close detection for HTTP/1.0.""" - request_data = b"GET / HTTP/1.0\r\nHost: localhost\r\n\r\n" - reader = MockStreamReader(request_data) - unreader = AsyncUnreader(reader) - cfg = MockConfig() - - request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) - - assert request.version == (1, 0) - assert request.should_close() is True - - -@pytest.mark.asyncio -async def test_async_request_should_close_connection_header(): - """Test connection close detection with Connection header.""" - request_data = b"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n" - reader = MockStreamReader(request_data) - unreader = AsyncUnreader(reader) - cfg = MockConfig() - - request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) - - assert request.should_close() is True - - -@pytest.mark.asyncio -async def test_async_request_keepalive(): - """Test keepalive detection.""" - request_data = b"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: keep-alive\r\n\r\n" - reader = MockStreamReader(request_data) - unreader = AsyncUnreader(reader) - cfg = MockConfig() - - request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) - - assert request.should_close() is False - - -@pytest.mark.asyncio -async def test_async_request_no_body_for_get(): - """Test that GET requests have no body by default.""" - request_data = b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n" - reader = MockStreamReader(request_data) - unreader = AsyncUnreader(reader) - cfg = MockConfig() - - request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) - - assert request.content_length == 0 - body = await request.read_body() - assert body == b"" - - -# Error handling tests - -@pytest.mark.asyncio -async def test_async_request_invalid_method(): - """Test invalid HTTP method detection.""" - from gunicorn.http.errors import InvalidRequestMethod - - request_data = b"ge!t / HTTP/1.1\r\nHost: localhost\r\n\r\n" - reader = MockStreamReader(request_data) - unreader = AsyncUnreader(reader) - cfg = MockConfig() - - with pytest.raises(InvalidRequestMethod): - await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) - - -@pytest.mark.asyncio -async def test_async_request_invalid_http_version(): - """Test invalid HTTP version detection.""" - from gunicorn.http.errors import InvalidHTTPVersion - - request_data = b"GET / HTTP/2.0\r\nHost: localhost\r\n\r\n" - reader = MockStreamReader(request_data) - unreader = AsyncUnreader(reader) - cfg = MockConfig() - - with pytest.raises(InvalidHTTPVersion): - await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) diff --git a/tests/test_asgi_parser.py b/tests/test_asgi_parser.py deleted file mode 100644 index e20aad17..00000000 --- a/tests/test_asgi_parser.py +++ /dev/null @@ -1,324 +0,0 @@ -# -# This file is part of gunicorn released under the MIT license. -# See the NOTICE for more information. - -""" -Tests for ASGI HTTP parser optimizations. -""" - -import ipaddress -import pytest - -from gunicorn.asgi.unreader import AsyncUnreader -from gunicorn.asgi.message import AsyncRequest - - -class MockStreamReader: - """Mock asyncio.StreamReader for testing.""" - - def __init__(self, data): - self.data = data - self.pos = 0 - - async def read(self, size=-1): - if self.pos >= len(self.data): - return b"" - if size < 0: - result = self.data[self.pos:] - self.pos = len(self.data) - else: - result = self.data[self.pos:self.pos + size] - self.pos += size - return result - - -class MockConfig: - """Mock gunicorn config for testing.""" - - def __init__(self): - self.is_ssl = False - self.proxy_protocol = "off" - self.proxy_allow_ips = ["127.0.0.1"] - self.forwarded_allow_ips = ["127.0.0.1"] - self._proxy_allow_networks = None - self._forwarded_allow_networks = None - self.secure_scheme_headers = {} - self.forwarder_headers = [] - self.limit_request_line = 8190 - self.limit_request_fields = 100 - self.limit_request_field_size = 8190 - self.permit_unconventional_http_method = False - self.permit_unconventional_http_version = False - self.permit_obsolete_folding = False - self.casefold_http_method = False - self.strip_header_spaces = False - self.header_map = "refuse" - - def forwarded_allow_networks(self): - if self._forwarded_allow_networks is None: - self._forwarded_allow_networks = [ - ipaddress.ip_network(addr) - for addr in self.forwarded_allow_ips - if addr != "*" - ] - return self._forwarded_allow_networks - - def proxy_allow_networks(self): - if self._proxy_allow_networks is None: - self._proxy_allow_networks = [ - ipaddress.ip_network(addr) - for addr in self.proxy_allow_ips - if addr != "*" - ] - return self._proxy_allow_networks - - -# Optimized Chunk Reading Tests - -@pytest.mark.asyncio -async def test_chunk_size_line_reading(): - """Test optimized chunk size line reading.""" - # Simulate chunked body with chunk size line - data = b"a\r\nhello body\r\n0\r\n\r\n" - reader = MockStreamReader(data) - unreader = AsyncUnreader(reader) - cfg = MockConfig() - - req = AsyncRequest(cfg, unreader, ("127.0.0.1", 8000)) - # Access the private method for testing - line = await req._read_chunk_size_line() - assert line == b"a" - - -@pytest.mark.asyncio -async def test_skip_trailers_empty(): - """Test skipping empty trailers.""" - data = b"\r\n" - reader = MockStreamReader(data) - unreader = AsyncUnreader(reader) - cfg = MockConfig() - - req = AsyncRequest(cfg, unreader, ("127.0.0.1", 8000)) - # Should not raise - await req._skip_trailers() - - -@pytest.mark.asyncio -async def test_skip_trailers_with_headers(): - """Test skipping trailers with actual headers.""" - data = b"X-Checksum: abc123\r\n\r\n" - reader = MockStreamReader(data) - unreader = AsyncUnreader(reader) - cfg = MockConfig() - - req = AsyncRequest(cfg, unreader, ("127.0.0.1", 8000)) - # Should not raise - await req._skip_trailers() - - -# Buffer Reuse Tests - -@pytest.mark.asyncio -async def test_unreader_buffer_reuse(): - """Test that AsyncUnreader reuses buffers efficiently.""" - data = b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n" - reader = MockStreamReader(data) - unreader = AsyncUnreader(reader) - - # Read in chunks - chunk1 = await unreader.read(10) - assert chunk1 == b"GET / HTTP" - - # Read more - chunk2 = await unreader.read(10) - assert chunk2 == b"/1.1\r\nHost" - - # Unread some data - unreader.unread(b"/1.1\r\nHost") - - # Read again - should get unreaded data - chunk3 = await unreader.read(10) - assert chunk3 == b"/1.1\r\nHost" - - -@pytest.mark.asyncio -async def test_unreader_unread_prepends(): - """Test that unread prepends data.""" - data = b"original" - reader = MockStreamReader(data) - unreader = AsyncUnreader(reader) - - # Read some data first - await unreader.read(4) # "orig" - - # Unread something different - unreader.unread(b"NEW") - - # Should read the new data first - result = await unreader.read(3) - assert result == b"NEW" - - -# Header Parsing Optimization Tests - -@pytest.mark.asyncio -async def test_header_parsing_index_iteration(): - """Test that header parsing uses index-based iteration.""" - raw_request = ( - b"GET / HTTP/1.1\r\n" - b"Host: example.com\r\n" - b"Content-Type: text/plain\r\n" - b"X-Custom: value\r\n" - b"\r\n" - ) - reader = MockStreamReader(raw_request) - unreader = AsyncUnreader(reader) - cfg = MockConfig() - - req = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) - - assert req.method == "GET" - assert req.path == "/" - assert len(req.headers) == 3 - assert ("HOST", "example.com") in req.headers - assert ("CONTENT-TYPE", "text/plain") in req.headers - assert ("X-CUSTOM", "value") in req.headers - - -@pytest.mark.asyncio -async def test_many_headers_performance(): - """Test parsing request with many headers.""" - headers = [] - for i in range(50): - headers.append(f"X-Header-{i}: value-{i}\r\n") - - raw_request = ( - b"GET / HTTP/1.1\r\n" - + "".join(headers).encode() - + b"\r\n" - ) - - reader = MockStreamReader(raw_request) - unreader = AsyncUnreader(reader) - cfg = MockConfig() - - req = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) - - assert len(req.headers) == 50 - - -# Bytearray Find Optimization Tests - -@pytest.mark.asyncio -async def test_bytearray_find_optimization(): - """Test that bytearray.find() is used instead of bytes().find().""" - raw_request = ( - b"GET /path?query=value HTTP/1.1\r\n" - b"Host: example.com\r\n" - b"Content-Length: 5\r\n" - b"\r\n" - b"hello" - ) - reader = MockStreamReader(raw_request) - unreader = AsyncUnreader(reader) - cfg = MockConfig() - - req = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) - - assert req.method == "GET" - assert req.path == "/path" - assert req.query == "query=value" - assert req.content_length == 5 - - -# Chunked Body Tests with Optimized Reading - -@pytest.mark.asyncio -async def test_chunked_body_optimized_reading(): - """Test reading chunked body with optimized chunk reading.""" - raw_request = ( - b"POST / HTTP/1.1\r\n" - b"Host: example.com\r\n" - b"Transfer-Encoding: chunked\r\n" - b"\r\n" - b"5\r\nhello\r\n" - b"6\r\n world\r\n" - b"0\r\n\r\n" - ) - reader = MockStreamReader(raw_request) - unreader = AsyncUnreader(reader) - cfg = MockConfig() - - req = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) - - assert req.chunked is True - assert req.content_length is None - - # Read body - body_parts = [] - while True: - chunk = await req.read_body(1024) - if not chunk: - break - body_parts.append(chunk) - - body = b"".join(body_parts) - assert body == b"hello world" - - -@pytest.mark.asyncio -async def test_chunked_body_with_extension(): - """Test reading chunked body with chunk extensions.""" - raw_request = ( - b"POST / HTTP/1.1\r\n" - b"Host: example.com\r\n" - b"Transfer-Encoding: chunked\r\n" - b"\r\n" - b"5;ext=value\r\nhello\r\n" - b"0\r\n\r\n" - ) - reader = MockStreamReader(raw_request) - unreader = AsyncUnreader(reader) - cfg = MockConfig() - - req = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) - - chunk = await req.read_body(1024) - assert chunk == b"hello" - - -# Edge Cases - -@pytest.mark.asyncio -async def test_empty_headers(): - """Test request with no headers.""" - raw_request = ( - b"GET / HTTP/1.1\r\n" - b"\r\n" - ) - reader = MockStreamReader(raw_request) - unreader = AsyncUnreader(reader) - cfg = MockConfig() - - req = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) - - assert req.method == "GET" - assert len(req.headers) == 0 - - -@pytest.mark.asyncio -async def test_large_header_value(): - """Test request with large header value.""" - large_value = "x" * 4000 # Within default limit - raw_request = ( - b"GET / HTTP/1.1\r\n" - + f"X-Large-Header: {large_value}\r\n".encode() - + b"\r\n" - ) - reader = MockStreamReader(raw_request) - unreader = AsyncUnreader(reader) - cfg = MockConfig() - - req = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) - - assert req.get_header("X-Large-Header") == large_value