From ae1eea8108b10e8aa3d5f937c7eb8076a0d2ef56 Mon Sep 17 00:00:00 2001 From: Benoit Chesneau Date: Thu, 22 Jan 2026 17:05:29 +0100 Subject: [PATCH 1/8] asgi: Add native ASGI worker with HTTP and WebSocket support Add a new ASGI worker type that provides native async support using gunicorn's own HTTP parsing infrastructure adapted for asyncio. Features: - HTTP/1.1 with keepalive support - WebSocket connections (RFC 6455) - ASGI lifespan protocol for startup/shutdown hooks - Optional uvloop support for improved performance - Full proxy protocol support (inherited from gunicorn) New configuration options: - --asgi-loop: Event loop selection (auto/asyncio/uvloop) - --asgi-lifespan: Lifespan protocol control (auto/on/off) - --root-path: ASGI root path for reverse proxy setups Usage: gunicorn -k asgi myapp:app --- examples/asgi/__init__.py | 7 + examples/asgi/basic_app.py | 130 +++++++ examples/asgi/websocket_app.py | 235 ++++++++++++ gunicorn/asgi/__init__.py | 26 ++ gunicorn/asgi/lifespan.py | 178 +++++++++ gunicorn/asgi/message.py | 562 ++++++++++++++++++++++++++++ gunicorn/asgi/protocol.py | 424 ++++++++++++++++++++++ gunicorn/asgi/unreader.py | 100 +++++ gunicorn/asgi/websocket.py | 369 +++++++++++++++++++ gunicorn/config.py | 91 +++++ gunicorn/workers/__init__.py | 1 + gunicorn/workers/gasgi.py | 282 +++++++++++++++ pyproject.toml | 1 + tests/test_asgi.py | 285 +++++++++++++++ tests/test_asgi_worker.py | 643 +++++++++++++++++++++++++++++++++ 15 files changed, 3334 insertions(+) create mode 100644 examples/asgi/__init__.py create mode 100644 examples/asgi/basic_app.py create mode 100644 examples/asgi/websocket_app.py create mode 100644 gunicorn/asgi/__init__.py create mode 100644 gunicorn/asgi/lifespan.py create mode 100644 gunicorn/asgi/message.py create mode 100644 gunicorn/asgi/protocol.py create mode 100644 gunicorn/asgi/unreader.py create mode 100644 gunicorn/asgi/websocket.py create mode 100644 gunicorn/workers/gasgi.py create mode 100644 tests/test_asgi.py 
create mode 100644 tests/test_asgi_worker.py diff --git a/examples/asgi/__init__.py b/examples/asgi/__init__.py new file mode 100644 index 00000000..1c9ecbeb --- /dev/null +++ b/examples/asgi/__init__.py @@ -0,0 +1,7 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +ASGI example applications for gunicorn. +""" diff --git a/examples/asgi/basic_app.py b/examples/asgi/basic_app.py new file mode 100644 index 00000000..73a160fe --- /dev/null +++ b/examples/asgi/basic_app.py @@ -0,0 +1,130 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +Basic ASGI application example. + +Run with: + gunicorn -k asgi examples.asgi.basic_app:app + +Test with: + curl http://127.0.0.1:8000/ + curl http://127.0.0.1:8000/hello + curl -X POST http://127.0.0.1:8000/echo -d "test data" +""" + + +async def app(scope, receive, send): + """Simple ASGI application demonstrating basic functionality.""" + + if scope["type"] == "lifespan": + await handle_lifespan(scope, receive, send) + elif scope["type"] == "http": + await handle_http(scope, receive, send) + else: + raise ValueError(f"Unknown scope type: {scope['type']}") + + +async def handle_lifespan(scope, receive, send): + """Handle lifespan events (startup/shutdown).""" + while True: + message = await receive() + if message["type"] == "lifespan.startup": + print("ASGI application starting up...") + await send({"type": "lifespan.startup.complete"}) + elif message["type"] == "lifespan.shutdown": + print("ASGI application shutting down...") + await send({"type": "lifespan.shutdown.complete"}) + return + + +async def handle_http(scope, receive, send): + """Handle HTTP requests.""" + path = scope["path"] + method = scope["method"] + + if path == "/" and method == "GET": + await send_response(send, 200, b"Welcome to gunicorn ASGI!\n") + + elif path == "/hello" and method == "GET": + name = get_query_param(scope, 
"name", "World") + body = f"Hello, {name}!\n".encode() + await send_response(send, 200, body) + + elif path == "/echo" and method == "POST": + body = await read_body(receive) + await send_response(send, 200, body, content_type=b"application/octet-stream") + + elif path == "/headers": + headers_info = format_headers(scope["headers"]) + await send_response(send, 200, headers_info.encode()) + + elif path == "/info": + info = format_request_info(scope) + await send_response(send, 200, info.encode(), content_type=b"application/json") + + else: + await send_response(send, 404, b"Not Found\n") + + +async def send_response(send, status, body, content_type=b"text/plain"): + """Send an HTTP response.""" + await send({ + "type": "http.response.start", + "status": status, + "headers": [ + (b"content-type", content_type), + (b"content-length", str(len(body)).encode()), + ], + }) + await send({ + "type": "http.response.body", + "body": body, + }) + + +async def read_body(receive): + """Read the full request body.""" + body = b"" + while True: + message = await receive() + body += message.get("body", b"") + if not message.get("more_body", False): + break + return body + + +def get_query_param(scope, name, default=None): + """Get a query parameter value.""" + query_string = scope.get("query_string", b"").decode() + for param in query_string.split("&"): + if "=" in param: + key, value = param.split("=", 1) + if key == name: + return value + return default + + +def format_headers(headers): + """Format headers for display.""" + lines = ["Request Headers:"] + for name, value in headers: + lines.append(f" {name.decode()}: {value.decode()}") + return "\n".join(lines) + "\n" + + +def format_request_info(scope): + """Format request info as JSON.""" + import json + info = { + "method": scope["method"], + "path": scope["path"], + "query_string": scope.get("query_string", b"").decode(), + "http_version": scope["http_version"], + "scheme": scope["scheme"], + "server": list(scope.get("server") 
or []), + "client": list(scope.get("client") or []), + "root_path": scope.get("root_path", ""), + } + return json.dumps(info, indent=2) + "\n" diff --git a/examples/asgi/websocket_app.py b/examples/asgi/websocket_app.py new file mode 100644 index 00000000..8423c30e --- /dev/null +++ b/examples/asgi/websocket_app.py @@ -0,0 +1,235 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +WebSocket ASGI application example. + +Run with: + gunicorn -k asgi examples.asgi.websocket_app:app + +Test with: + # Using websocat (install with: cargo install websocat) + websocat ws://127.0.0.1:8000/ws + + # Or using Python websockets library + python -c " + import asyncio + import websockets + async def test(): + async with websockets.connect('ws://127.0.0.1:8000/ws') as ws: + await ws.send('Hello') + print(await ws.recv()) + asyncio.run(test()) + " +""" + + +async def app(scope, receive, send): + """ASGI application with WebSocket support.""" + + if scope["type"] == "lifespan": + await handle_lifespan(scope, receive, send) + elif scope["type"] == "http": + await handle_http(scope, receive, send) + elif scope["type"] == "websocket": + await handle_websocket(scope, receive, send) + else: + raise ValueError(f"Unknown scope type: {scope['type']}") + + +async def handle_lifespan(scope, receive, send): + """Handle lifespan events.""" + while True: + message = await receive() + if message["type"] == "lifespan.startup": + await send({"type": "lifespan.startup.complete"}) + elif message["type"] == "lifespan.shutdown": + await send({"type": "lifespan.shutdown.complete"}) + return + + +async def handle_http(scope, receive, send): + """Handle HTTP requests - serve a simple HTML page for WebSocket testing.""" + path = scope["path"] + + if path == "/": + html = HTML_PAGE.encode() + await send({ + "type": "http.response.start", + "status": 200, + "headers": [ + (b"content-type", b"text/html"), + (b"content-length", 
str(len(html)).encode()), + ], + }) + await send({ + "type": "http.response.body", + "body": html, + }) + else: + await send({ + "type": "http.response.start", + "status": 404, + "headers": [(b"content-type", b"text/plain")], + }) + await send({ + "type": "http.response.body", + "body": b"Not Found", + }) + + +async def handle_websocket(scope, receive, send): + """Handle WebSocket connections.""" + path = scope["path"] + + if path == "/ws": + await echo_websocket(scope, receive, send) + elif path == "/ws/chat": + await chat_websocket(scope, receive, send) + else: + # Reject the connection + await send({"type": "websocket.close", "code": 4004}) + + +async def echo_websocket(scope, receive, send): + """Echo WebSocket - sends back whatever it receives.""" + # Wait for connection + message = await receive() + if message["type"] != "websocket.connect": + return + + # Accept the connection + await send({"type": "websocket.accept"}) + + # Echo loop + try: + while True: + message = await receive() + + if message["type"] == "websocket.disconnect": + break + + if message["type"] == "websocket.receive": + if "text" in message: + # Echo text back + await send({ + "type": "websocket.send", + "text": f"Echo: {message['text']}" + }) + elif "bytes" in message: + # Echo bytes back + await send({ + "type": "websocket.send", + "bytes": message["bytes"] + }) + except Exception as e: + print(f"WebSocket error: {e}") + finally: + try: + await send({"type": "websocket.close", "code": 1000}) + except Exception: + pass + + +async def chat_websocket(scope, receive, send): + """Chat WebSocket - simple broadcast example.""" + message = await receive() + if message["type"] != "websocket.connect": + return + + await send({ + "type": "websocket.accept", + "subprotocol": "chat" + }) + + await send({ + "type": "websocket.send", + "text": "Welcome to the chat! Send messages and they will be echoed back." 
+ }) + + try: + while True: + message = await receive() + + if message["type"] == "websocket.disconnect": + break + + if message["type"] == "websocket.receive" and "text" in message: + text = message["text"] + await send({ + "type": "websocket.send", + "text": f"[You]: {text}" + }) + except Exception: + pass + + +HTML_PAGE = """ + + + WebSocket Test + + + +

WebSocket Test

+
+ + + + + + + + +""" diff --git a/gunicorn/asgi/__init__.py b/gunicorn/asgi/__init__.py new file mode 100644 index 00000000..c2f13b2a --- /dev/null +++ b/gunicorn/asgi/__init__.py @@ -0,0 +1,26 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +ASGI support for gunicorn. + +This module provides native ASGI worker support, using gunicorn's own +HTTP parsing infrastructure adapted for async I/O. + +Components: +- AsyncUnreader: Async socket reading with pushback buffer +- AsyncRequest: Async HTTP request parser +- ASGIProtocol: asyncio.Protocol implementation for HTTP handling +- WebSocketProtocol: WebSocket protocol handler (RFC 6455) +- LifespanManager: ASGI lifespan protocol support + +Usage: + gunicorn -k asgi myapp:app +""" + +from gunicorn.asgi.unreader import AsyncUnreader +from gunicorn.asgi.message import AsyncRequest +from gunicorn.asgi.lifespan import LifespanManager + +__all__ = ['AsyncUnreader', 'AsyncRequest', 'LifespanManager'] diff --git a/gunicorn/asgi/lifespan.py b/gunicorn/asgi/lifespan.py new file mode 100644 index 00000000..9811cf56 --- /dev/null +++ b/gunicorn/asgi/lifespan.py @@ -0,0 +1,178 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +ASGI lifespan protocol manager. + +Manages startup and shutdown events for ASGI applications, +enabling frameworks like FastAPI to run initialization and +cleanup code. +""" + +import asyncio + + +class LifespanManager: + """Manages ASGI lifespan events (startup/shutdown). + + The lifespan protocol allows ASGI applications to run code at + startup and shutdown. This is essential for applications that + need to initialize database connections, caches, or other + resources. 
+ + ASGI lifespan messages: + - Server sends: {"type": "lifespan.startup"} + - App responds: {"type": "lifespan.startup.complete"} or + {"type": "lifespan.startup.failed", "message": "..."} + - Server sends: {"type": "lifespan.shutdown"} + - App responds: {"type": "lifespan.shutdown.complete"} + """ + + def __init__(self, app, logger, state=None): + """Initialize the lifespan manager. + + Args: + app: ASGI application callable + logger: Logger instance + state: Shared state dict for the application + """ + self.app = app + self.logger = logger + self.state = state if state is not None else {} + + self._startup_complete = asyncio.Event() + self._shutdown_complete = asyncio.Event() + self._startup_failed = False + self._startup_error = None + self._shutdown_error = None + self._receive_queue = asyncio.Queue() + self._task = None + self._app_finished = False + + async def startup(self): + """Run lifespan startup and wait for completion. + + Raises: + RuntimeError: If startup fails or app doesn't support lifespan + """ + scope = { + "type": "lifespan", + "asgi": {"version": "3.0", "spec_version": "2.4"}, + "state": self.state, + } + + # Send startup event + await self._receive_queue.put({"type": "lifespan.startup"}) + + # Run lifespan in background task + self._task = asyncio.create_task(self._run_lifespan(scope)) + + # Wait for startup with timeout + try: + await asyncio.wait_for( + self._startup_complete.wait(), + timeout=30.0 # Reasonable startup timeout + ) + except asyncio.TimeoutError: + if self._task: + self._task.cancel() + raise RuntimeError("Lifespan startup timed out") + + if self._startup_failed: + if self._task: + self._task.cancel() + msg = self._startup_error or "Unknown error" + raise RuntimeError(f"Lifespan startup failed: {msg}") + + self.logger.debug("ASGI lifespan startup complete") + + async def shutdown(self): + """Signal shutdown and wait for completion. + + This should be called during graceful shutdown. 
+ """ + if self._app_finished: + self.logger.debug("ASGI lifespan already finished") + return + + # Send shutdown event + await self._receive_queue.put({"type": "lifespan.shutdown"}) + + # Wait for shutdown with timeout + try: + await asyncio.wait_for( + self._shutdown_complete.wait(), + timeout=30.0 # Reasonable shutdown timeout + ) + except asyncio.TimeoutError: + self.logger.warning("Lifespan shutdown timed out") + + if self._shutdown_error: + self.logger.error("Lifespan shutdown error: %s", self._shutdown_error) + + # Cancel the task if still running + if self._task and not self._task.done(): + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + + self.logger.debug("ASGI lifespan shutdown complete") + + async def _run_lifespan(self, scope): + """Run the ASGI lifespan protocol.""" + try: + await self.app(scope, self._receive, self._send) + except asyncio.CancelledError: + raise + except Exception as e: + self.logger.debug("Lifespan application raised: %s", e) + # If startup hasn't completed, mark it as failed + if not self._startup_complete.is_set(): + self._startup_failed = True + self._startup_error = str(e) + self._startup_complete.set() + # If shutdown hasn't completed, mark error + elif not self._shutdown_complete.is_set(): + self._shutdown_error = str(e) + self._shutdown_complete.set() + finally: + self._app_finished = True + # Ensure events are set to unblock waiters + if not self._startup_complete.is_set(): + self._startup_failed = True + self._startup_error = "Application exited before startup complete" + self._startup_complete.set() + if not self._shutdown_complete.is_set(): + self._shutdown_complete.set() + + async def _receive(self): + """ASGI receive callable for lifespan.""" + return await self._receive_queue.get() + + async def _send(self, message): + """ASGI send callable for lifespan.""" + msg_type = message["type"] + + if msg_type == "lifespan.startup.complete": + self._startup_complete.set() + 
self.logger.debug("Received lifespan.startup.complete") + + elif msg_type == "lifespan.startup.failed": + self._startup_failed = True + self._startup_error = message.get("message", "") + self._startup_complete.set() + self.logger.debug("Received lifespan.startup.failed: %s", + self._startup_error) + + elif msg_type == "lifespan.shutdown.complete": + self._shutdown_complete.set() + self.logger.debug("Received lifespan.shutdown.complete") + + elif msg_type == "lifespan.shutdown.failed": + self._shutdown_error = message.get("message", "") + self._shutdown_complete.set() + self.logger.debug("Received lifespan.shutdown.failed: %s", + self._shutdown_error) diff --git a/gunicorn/asgi/message.py b/gunicorn/asgi/message.py new file mode 100644 index 00000000..d7d20c83 --- /dev/null +++ b/gunicorn/asgi/message.py @@ -0,0 +1,562 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +Async version of gunicorn/http/message.py for ASGI workers. + +Reuses the parsing logic from the sync version, adapted for async I/O. +""" + +import io +import re +import socket + +from gunicorn.http.errors import ( + InvalidHeader, InvalidHeaderName, NoMoreData, + InvalidRequestLine, InvalidRequestMethod, InvalidHTTPVersion, + LimitRequestLine, LimitRequestHeaders, + UnsupportedTransferCoding, ObsoleteFolding, + InvalidProxyLine, ForbiddenProxyRequest, + InvalidSchemeHeaders, +) +from gunicorn.util import bytes_to_str, split_request_uri + +MAX_REQUEST_LINE = 8190 +MAX_HEADERS = 32768 +DEFAULT_MAX_HEADERFIELD_SIZE = 8190 + +# Reuse regex patterns from sync version +RFC9110_5_6_2_TOKEN_SPECIALS = r"!#$%&'*+-.^_`|~" +TOKEN_RE = re.compile(r"[%s0-9a-zA-Z]+" % (re.escape(RFC9110_5_6_2_TOKEN_SPECIALS))) +METHOD_BADCHAR_RE = re.compile("[a-z#]") +VERSION_RE = re.compile(r"HTTP/(\d)\.(\d)") +RFC9110_5_5_INVALID_AND_DANGEROUS = re.compile(r"[\0\r\n]") + + +class AsyncRequest: + """Async HTTP request parser. 
+ + Parses HTTP/1.x requests using async I/O, reusing gunicorn's + parsing logic where possible. + """ + + def __init__(self, cfg, unreader, peer_addr, req_number=1): + self.cfg = cfg + self.unreader = unreader + self.peer_addr = peer_addr + self.remote_addr = peer_addr + self.req_number = req_number + + self.version = None + self.method = None + self.uri = None + self.path = None + self.query = None + self.fragment = None + self.headers = [] + self.trailers = [] + self.scheme = "https" if cfg.is_ssl else "http" + self.must_close = False + + self.proxy_protocol_info = None + + # Request line limit + self.limit_request_line = cfg.limit_request_line + if (self.limit_request_line < 0 + or self.limit_request_line >= MAX_REQUEST_LINE): + self.limit_request_line = MAX_REQUEST_LINE + + # Headers limits + self.limit_request_fields = cfg.limit_request_fields + if (self.limit_request_fields <= 0 + or self.limit_request_fields > MAX_HEADERS): + self.limit_request_fields = MAX_HEADERS + + self.limit_request_field_size = cfg.limit_request_field_size + if self.limit_request_field_size < 0: + self.limit_request_field_size = DEFAULT_MAX_HEADERFIELD_SIZE + + # Max header buffer size + max_header_field_size = self.limit_request_field_size or DEFAULT_MAX_HEADERFIELD_SIZE + self.max_buffer_headers = self.limit_request_fields * \ + (max_header_field_size + 2) + 4 + + # Body-related state + self.content_length = None + self.chunked = False + self._body_reader = None + self._body_remaining = 0 + + @classmethod + async def parse(cls, cfg, unreader, peer_addr, req_number=1): + """Parse an HTTP request from the stream. 
+ + Args: + cfg: gunicorn config object + unreader: AsyncUnreader instance + peer_addr: client address tuple + req_number: request number on this connection (for keepalive) + + Returns: + AsyncRequest: Parsed request object + + Raises: + NoMoreData: If no data available + Various parsing errors for malformed requests + """ + req = cls(cfg, unreader, peer_addr, req_number) + await req._parse() + return req + + async def _parse(self): + """Parse the request from the unreader.""" + buf = io.BytesIO() + await self._get_data(buf, stop=True) + + # Get request line + line, rbuf = await self._read_line(buf, self.limit_request_line) + + # Proxy protocol + if self._proxy_protocol(bytes_to_str(line)): + # Get next request line + buf = io.BytesIO() + buf.write(rbuf) + line, rbuf = await self._read_line(buf, self.limit_request_line) + + self._parse_request_line(line) + buf = io.BytesIO() + buf.write(rbuf) + + # Headers + data = buf.getvalue() + + while True: + idx = data.find(b"\r\n\r\n") + done = data[:2] == b"\r\n" + + if idx < 0 and not done: + await self._get_data(buf) + data = buf.getvalue() + if len(data) > self.max_buffer_headers: + raise LimitRequestHeaders("max buffer headers") + else: + break + + if done: + self.unreader.unread(data[2:]) + else: + self.headers = self._parse_headers(data[:idx], from_trailer=False) + self.unreader.unread(data[idx + 4:]) + + self._set_body_reader() + + async def _get_data(self, buf, stop=False): + """Read data from unreader into buffer.""" + data = await self.unreader.read() + if not data: + if stop: + raise StopIteration() + raise NoMoreData(buf.getvalue()) + buf.write(data) + + async def _read_line(self, buf, limit=0): + """Read a line from the buffer/stream.""" + data = buf.getvalue() + + while True: + idx = data.find(b"\r\n") + if idx >= 0: + if idx > limit > 0: + raise LimitRequestLine(idx, limit) + break + if len(data) - 2 > limit > 0: + raise LimitRequestLine(len(data), limit) + await self._get_data(buf) + data = buf.getvalue() + 
+ return (data[:idx], data[idx + 2:]) + + def _proxy_protocol(self, line): + """Detect, check and parse proxy protocol.""" + if not self.cfg.proxy_protocol: + return False + + if self.req_number != 1: + return False + + if not line.startswith("PROXY"): + return False + + self._proxy_protocol_access_check() + self._parse_proxy_protocol(line) + + return True + + def _proxy_protocol_access_check(self): + """Check if proxy protocol is allowed from this peer.""" + if ("*" not in self.cfg.proxy_allow_ips and + isinstance(self.peer_addr, tuple) and + self.peer_addr[0] not in self.cfg.proxy_allow_ips): + raise ForbiddenProxyRequest(self.peer_addr[0]) + + def _parse_proxy_protocol(self, line): + """Parse proxy protocol header line.""" + bits = line.split(" ") + + if len(bits) != 6: + raise InvalidProxyLine(line) + + proto = bits[1] + s_addr = bits[2] + d_addr = bits[3] + + if proto not in ["TCP4", "TCP6"]: + raise InvalidProxyLine("protocol '%s' not supported" % proto) + + if proto == "TCP4": + try: + socket.inet_pton(socket.AF_INET, s_addr) + socket.inet_pton(socket.AF_INET, d_addr) + except OSError: + raise InvalidProxyLine(line) + elif proto == "TCP6": + try: + socket.inet_pton(socket.AF_INET6, s_addr) + socket.inet_pton(socket.AF_INET6, d_addr) + except OSError: + raise InvalidProxyLine(line) + + try: + s_port = int(bits[4]) + d_port = int(bits[5]) + except ValueError: + raise InvalidProxyLine("invalid port %s" % line) + + if not ((0 <= s_port <= 65535) and (0 <= d_port <= 65535)): + raise InvalidProxyLine("invalid port %s" % line) + + self.proxy_protocol_info = { + "proxy_protocol": proto, + "client_addr": s_addr, + "client_port": s_port, + "proxy_addr": d_addr, + "proxy_port": d_port + } + + def _parse_request_line(self, line_bytes): + """Parse the HTTP request line.""" + bits = [bytes_to_str(bit) for bit in line_bytes.split(b" ", 2)] + if len(bits) != 3: + raise InvalidRequestLine(bytes_to_str(line_bytes)) + + # Method + self.method = bits[0] + + if not 
self.cfg.permit_unconventional_http_method: + if METHOD_BADCHAR_RE.search(self.method): + raise InvalidRequestMethod(self.method) + if not 3 <= len(bits[0]) <= 20: + raise InvalidRequestMethod(self.method) + if not TOKEN_RE.fullmatch(self.method): + raise InvalidRequestMethod(self.method) + if self.cfg.casefold_http_method: + self.method = self.method.upper() + + # URI + self.uri = bits[1] + + if len(self.uri) == 0: + raise InvalidRequestLine(bytes_to_str(line_bytes)) + + try: + parts = split_request_uri(self.uri) + except ValueError: + raise InvalidRequestLine(bytes_to_str(line_bytes)) + self.path = parts.path or "" + self.query = parts.query or "" + self.fragment = parts.fragment or "" + + # Version + match = VERSION_RE.fullmatch(bits[2]) + if match is None: + raise InvalidHTTPVersion(bits[2]) + self.version = (int(match.group(1)), int(match.group(2))) + if not (1, 0) <= self.version < (2, 0): + if not self.cfg.permit_unconventional_http_version: + raise InvalidHTTPVersion(self.version) + + def _parse_headers(self, data, from_trailer=False): + """Parse HTTP headers from raw data.""" + cfg = self.cfg + headers = [] + + lines = [bytes_to_str(line) for line in data.split(b"\r\n")] + + # Handle scheme headers + scheme_header = False + secure_scheme_headers = {} + forwarder_headers = [] + if from_trailer: + pass + elif ('*' in cfg.forwarded_allow_ips or + not isinstance(self.peer_addr, tuple) + or self.peer_addr[0] in cfg.forwarded_allow_ips): + secure_scheme_headers = cfg.secure_scheme_headers + forwarder_headers = cfg.forwarder_headers + + while lines: + if len(headers) >= self.limit_request_fields: + raise LimitRequestHeaders("limit request headers fields") + + curr = lines.pop(0) + header_length = len(curr) + len("\r\n") + if curr.find(":") <= 0: + raise InvalidHeader(curr) + name, value = curr.split(":", 1) + if self.cfg.strip_header_spaces: + name = name.rstrip(" \t") + if not TOKEN_RE.fullmatch(name): + raise InvalidHeaderName(name) + + name = name.upper() + 
value = [value.strip(" \t")] + + # Consume value continuation lines + while lines and lines[0].startswith((" ", "\t")): + if not self.cfg.permit_obsolete_folding: + raise ObsoleteFolding(name) + curr = lines.pop(0) + header_length += len(curr) + len("\r\n") + if header_length > self.limit_request_field_size > 0: + raise LimitRequestHeaders("limit request headers fields size") + value.append(curr.strip("\t ")) + value = " ".join(value) + + if RFC9110_5_5_INVALID_AND_DANGEROUS.search(value): + raise InvalidHeader(name) + + if header_length > self.limit_request_field_size > 0: + raise LimitRequestHeaders("limit request headers fields size") + + if name in secure_scheme_headers: + secure = value == secure_scheme_headers[name] + scheme = "https" if secure else "http" + if scheme_header: + if scheme != self.scheme: + raise InvalidSchemeHeaders() + else: + scheme_header = True + self.scheme = scheme + + if "_" in name: + if name in forwarder_headers or "*" in forwarder_headers: + pass + elif self.cfg.header_map == "dangerous": + pass + elif self.cfg.header_map == "drop": + continue + else: + raise InvalidHeaderName(name) + + headers.append((name, value)) + + return headers + + def _set_body_reader(self): + """Determine how to read the request body.""" + chunked = False + content_length = None + + for (name, value) in self.headers: + if name == "CONTENT-LENGTH": + if content_length is not None: + raise InvalidHeader("CONTENT-LENGTH", req=self) + content_length = value + elif name == "TRANSFER-ENCODING": + vals = [v.strip() for v in value.split(',')] + for val in vals: + if val.lower() == "chunked": + if chunked: + raise InvalidHeader("TRANSFER-ENCODING", req=self) + chunked = True + elif val.lower() == "identity": + if chunked: + raise InvalidHeader("TRANSFER-ENCODING", req=self) + elif val.lower() in ('compress', 'deflate', 'gzip'): + if chunked: + raise InvalidHeader("TRANSFER-ENCODING", req=self) + self.force_close() + else: + raise UnsupportedTransferCoding(value) + + 
if chunked: + if self.version < (1, 1): + raise InvalidHeader("TRANSFER-ENCODING", req=self) + if content_length is not None: + raise InvalidHeader("CONTENT-LENGTH", req=self) + self.chunked = True + self.content_length = None + self._body_remaining = -1 + elif content_length is not None: + try: + if str(content_length).isnumeric(): + content_length = int(content_length) + else: + raise InvalidHeader("CONTENT-LENGTH", req=self) + except ValueError: + raise InvalidHeader("CONTENT-LENGTH", req=self) + + if content_length < 0: + raise InvalidHeader("CONTENT-LENGTH", req=self) + + self.content_length = content_length + self._body_remaining = content_length + else: + # No body for requests without Content-Length or Transfer-Encoding + self.content_length = 0 + self._body_remaining = 0 + + def force_close(self): + """Mark connection for closing after this request.""" + self.must_close = True + + def should_close(self): + """Check if connection should be closed after this request.""" + if self.must_close: + return True + for (h, v) in self.headers: + if h == "CONNECTION": + v = v.lower().strip(" \t") + if v == "close": + return True + elif v == "keep-alive": + return False + break + return self.version <= (1, 0) + + def get_header(self, name): + """Get a header value by name (case-insensitive).""" + name = name.upper() + for (h, v) in self.headers: + if h == name: + return v + return None + + async def read_body(self, size=8192): + """Read a chunk of the request body. 
+ + Args: + size: Maximum bytes to read + + Returns: + bytes: Body data, empty bytes when body is exhausted + """ + if self._body_remaining == 0: + return b"" + + if self.chunked: + return await self._read_chunked_body(size) + else: + return await self._read_length_body(size) + + async def _read_length_body(self, size): + """Read from a length-delimited body.""" + if self._body_remaining <= 0: + return b"" + + to_read = min(size, self._body_remaining) + data = await self.unreader.read(to_read) + if data: + self._body_remaining -= len(data) + return data + + async def _read_chunked_body(self, size): + """Read from a chunked body.""" + if self._body_reader is None: + self._body_reader = self._chunked_body_reader() + + try: + return await self._body_reader.__anext__() + except StopAsyncIteration: + self._body_remaining = 0 + return b"" + + async def _chunked_body_reader(self): + """Async generator for reading chunked body.""" + while True: + # Read chunk size line + size_line = await self._read_chunk_size_line() + # Parse chunk size (handle extensions) + chunk_size, *_ = size_line.split(b";", 1) + if _ : + chunk_size = chunk_size.rstrip(b" \t") + + if any(n not in b"0123456789abcdefABCDEF" for n in chunk_size): + raise InvalidHeader("Invalid chunk size") + if len(chunk_size) == 0: + raise InvalidHeader("Invalid chunk size") + + chunk_size = int(chunk_size, 16) + + if chunk_size == 0: + # Final chunk - skip trailers and final CRLF + await self._skip_trailers() + return + + # Read chunk data + remaining = chunk_size + while remaining > 0: + data = await self.unreader.read(min(remaining, 8192)) + if not data: + raise NoMoreData() + remaining -= len(data) + yield data + + # Skip chunk terminating CRLF + crlf = await self.unreader.read(2) + if crlf != b"\r\n": + # May have partial read, try to get the rest + while len(crlf) < 2: + more = await self.unreader.read(2 - len(crlf)) + if not more: + break + crlf += more + if crlf != b"\r\n": + raise InvalidHeader("Missing chunk 
terminator") + + async def _read_chunk_size_line(self): + """Read a chunk size line.""" + buf = io.BytesIO() + while True: + data = await self.unreader.read(1) + if not data: + raise NoMoreData() + buf.write(data) + if buf.getvalue().endswith(b"\r\n"): + return buf.getvalue()[:-2] + + async def _skip_trailers(self): + """Skip trailer headers after chunked body.""" + buf = io.BytesIO() + while True: + data = await self.unreader.read(1) + if not data: + return + buf.write(data) + content = buf.getvalue() + if content.endswith(b"\r\n\r\n"): + # Could parse trailers here if needed + return + if content == b"\r\n": + return + + async def drain_body(self): + """Drain any unread body data. + + Should be called before reusing connection for keepalive. + """ + while True: + data = await self.read_body(8192) + if not data: + break diff --git a/gunicorn/asgi/protocol.py b/gunicorn/asgi/protocol.py new file mode 100644 index 00000000..cededd68 --- /dev/null +++ b/gunicorn/asgi/protocol.py @@ -0,0 +1,424 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +ASGI protocol handler for gunicorn. + +Implements asyncio.Protocol to handle HTTP/1.x connections and dispatch +to ASGI applications. +""" + +import asyncio +import base64 +import hashlib +import traceback +from datetime import datetime + +from gunicorn.asgi.unreader import AsyncUnreader +from gunicorn.asgi.message import AsyncRequest +from gunicorn.http.errors import NoMoreData + + +class ASGIProtocol(asyncio.Protocol): + """HTTP/1.1 protocol handler for ASGI applications. + + Handles connection lifecycle, request parsing, and ASGI app invocation. 
def connection_lost(self, exc):
    """React to the transport going away.

    Marks the protocol closed, gives the connection slot back to the
    worker, signals end-of-stream to the reader and cancels the
    request-handling task if it is still running.

    Args:
        exc: Exception that caused the loss, or ``None``; not used here.
    """
    self._closed = True
    self.worker.nr_conns -= 1

    stream = self.reader
    if stream:
        stream.feed_eof()

    handler = self._task
    if not handler or handler.done():
        return
    handler.cancel()
def _is_websocket_upgrade(self, request):
    """Tell whether ``request`` asks to switch to the WebSocket protocol.

    A request qualifies when its ``Upgrade`` header lists ``websocket``
    and its ``Connection`` header lists ``upgrade``. Both headers are
    comma-separated token lists and may be split over several header
    lines, so tokens are collected from every occurrence and compared
    case-insensitively as whole tokens (``Connection: keep-alive,
    Upgrade`` matches; a value that merely contains the substring
    ``upgrade`` does not).

    Args:
        request: Parsed request whose ``headers`` attribute is an
            iterable of ``(name, value)`` string pairs; names are
            compared against their upper-case spelling.

    Returns:
        bool: ``True`` for a WebSocket upgrade request, else ``False``.
    """
    upgrade_tokens = set()
    connection_tokens = set()

    for name, value in request.headers:
        if name == "UPGRADE":
            bucket = upgrade_tokens
        elif name == "CONNECTION":
            bucket = connection_tokens
        else:
            continue
        bucket.update(token.strip().lower() for token in value.split(","))

    return "websocket" in upgrade_tokens and "upgrade" in connection_tokens
async def _read_body_to_queue(self, request, queue):
    """Pump the request body into ``queue`` as ``http.request`` events.

    Each non-empty chunk read from ``request`` is queued with
    ``more_body`` set; the first empty read queues a final event with an
    empty body and ``more_body`` cleared. If reading raises, the error is
    logged at debug level and the same final event is queued so the
    consumer is never left waiting.

    Args:
        request: Object exposing an awaitable ``read_body(size)``.
        queue: Queue that receives the event dicts.
    """
    def event(body, more_body):
        return {
            "type": "http.request",
            "body": body,
            "more_body": more_body,
        }

    try:
        chunk = await request.read_body(65536)
        while chunk:
            await queue.put(event(chunk, True))
            chunk = await request.read_body(65536)
        await queue.put(event(b"", False))
    except Exception as e:
        self.log.debug("Error reading body: %s", e)
        await queue.put(event(b"", False))
async def _send_response_start(self, status, headers, request):
    """Write the status line and header block of an HTTP response.

    The status line echoes the HTTP version of ``request``. Header names
    and values given as ``bytes`` are decoded as latin-1; everything is
    written to the transport in one call, terminated by a blank line.

    A ``Server: gunicorn/asgi`` header is appended only when the
    application did not send its own ``Server`` header, so the response
    never carries two of them.

    NOTE(review): no framing header is added here when the application
    supplies neither ``Content-Length`` nor ``Transfer-Encoding``.

    Args:
        status: Integer HTTP status code.
        headers: Iterable of ``(name, value)`` pairs (``bytes`` or ``str``).
        request: Parsed request; only ``request.version`` is read.
    """
    reason = self._get_reason_phrase(status)
    major, minor = request.version[0], request.version[1]
    lines = [f"HTTP/{major}.{minor} {status} {reason}\r\n"]

    has_server = False
    for name, value in headers:
        if isinstance(name, bytes):
            name = name.decode("latin-1")
        if isinstance(value, bytes):
            value = value.decode("latin-1")
        if name.lower() == "server":
            has_server = True
        lines.append(f"{name}: {value}\r\n")

    # Add server header if not present
    if not has_server:
        lines.append("Server: gunicorn/asgi\r\n")

    lines.append("\r\n")
    self.transport.write("".join(lines).encode("latin-1"))
class AsyncUnreader:
    """Pushback-capable reader layered over an asyncio stream.

    Bytes that were consumed too eagerly can be handed back with
    :meth:`unread`; subsequent reads serve that buffered data before
    touching the underlying stream again.
    """

    def __init__(self, reader, max_chunk=8192):
        """Wrap ``reader``.

        Args:
            reader: Object exposing an awaitable ``read(n)``, such as an
                ``asyncio.StreamReader``.
            max_chunk: Byte count passed to every underlying ``read``.
        """
        self.max_chunk = max_chunk
        self.buf = io.BytesIO()
        self.reader = reader

    def _drain_buffer(self):
        """Return everything buffered and start over with an empty buffer."""
        pending = self.buf.getvalue()
        self.buf = io.BytesIO()
        return pending

    async def read(self, size=None):
        """Read from the pushback buffer first, then from the stream.

        Args:
            size: Number of bytes wanted. ``None`` or a negative value
                means "whatever is buffered, otherwise one chunk from
                the stream". ``0`` returns ``b""`` immediately.

        Returns:
            bytes: Exactly ``size`` bytes when that many are available;
            fewer if the stream ends first.

        Raises:
            TypeError: If ``size`` is neither ``None`` nor an ``int``.
        """
        if size is not None:
            if not isinstance(size, int):
                raise TypeError("size parameter must be an int or long.")
            if size == 0:
                return b""
            if size < 0:
                size = None

        # Position at the end so tell() reports the buffered byte count
        # and later writes append.
        self.buf.seek(0, io.SEEK_END)

        if size is None:
            if self.buf.tell():
                return self._drain_buffer()
            return await self._read_chunk()

        # Top the buffer up until the request can be satisfied.
        while self.buf.tell() < size:
            piece = await self._read_chunk()
            if not piece:
                # Stream ended early: hand out what we have.
                return self._drain_buffer()
            self.buf.write(piece)

        pending = self.buf.getvalue()
        head, rest = pending[:size], pending[size:]
        self.buf = io.BytesIO()
        self.buf.write(rest)
        return head

    async def _read_chunk(self):
        """Fetch up to ``max_chunk`` bytes; any read error yields ``b""``."""
        try:
            chunk = await self.reader.read(self.max_chunk)
        except Exception:
            chunk = b""
        return chunk

    def unread(self, data):
        """Append ``data`` to the pushback buffer (no-op for empty data).

        Args:
            data: bytes to push back
        """
        if not data:
            return
        self.buf.seek(0, io.SEEK_END)
        self.buf.write(data)

    def has_buffered_data(self):
        """Report whether the pushback buffer holds any bytes.

        The buffer position is left where it was.
        """
        mark = self.buf.tell()
        end = self.buf.seek(0, io.SEEK_END)
        self.buf.seek(mark)
        return end > 0
def __init__(self, transport, reader, scope, app, log):
    """Set up per-connection WebSocket state.

    Args:
        transport: asyncio transport used for writing
        reader: asyncio StreamReader used for reading
        scope: ASGI WebSocket scope dict
        app: ASGI application callable
        log: Logger instance
    """
    # Collaborators supplied by the caller.
    self.transport, self.reader = transport, reader
    self.scope, self.app, self.log = scope, app, log

    # Handshake and close bookkeeping.
    self.accepted = False
    self.closed = False
    self.close_code, self.close_reason = None, ""

    # Events waiting to be handed to the application's receive().
    self._receive_queue = asyncio.Queue()

    # Pieces of a fragmented message and the opcode that started it.
    self._fragment_opcode = None
    self._fragments = []
async def _send(self, message):
    """ASGI ``send`` callable for a WebSocket connection.

    Handles three event types:

    * ``websocket.accept`` - completes the opening handshake.
    * ``websocket.send`` - sends a text frame when ``text`` is not
      ``None``, otherwise a binary frame when ``bytes`` is not ``None``.
      A key that is present but ``None`` is treated like a missing key,
      so ``{"text": None, "bytes": b"..."}`` sends the bytes.
    * ``websocket.close`` - after the handshake, sends a close frame.
      Before the handshake the connection is still plain HTTP, so the
      upgrade is refused with an HTTP 403 response instead of writing a
      WebSocket frame the client could not interpret.

    Args:
        message: ASGI event dict; ``message["type"]`` selects the action.

    Raises:
        RuntimeError: On a second accept, or on ``websocket.send``
            before accept / after close.
    """
    msg_type = message["type"]

    if msg_type == "websocket.accept":
        if self.accepted:
            raise RuntimeError("WebSocket already accepted")
        await self._send_accept(message)
        self.accepted = True

    elif msg_type == "websocket.send":
        if not self.accepted:
            raise RuntimeError("WebSocket not accepted")
        if self.closed:
            raise RuntimeError("WebSocket closed")

        text = message.get("text")
        data = message.get("bytes")
        if text is not None:
            await self._send_frame(OPCODE_TEXT, text.encode("utf-8"))
        elif data is not None:
            await self._send_frame(OPCODE_BINARY, data)

    elif msg_type == "websocket.close":
        if not self.accepted:
            # Handshake never happened: reject it at the HTTP level.
            self.transport.write(
                b"HTTP/1.1 403 Forbidden\r\n"
                b"Content-Length: 0\r\n"
                b"Connection: close\r\n"
                b"\r\n"
            )
            self.closed = True
            return

        code = message.get("code", CLOSE_NORMAL)
        reason = message.get("reason", "")
        await self._send_close(code, reason)
        self.closed = True
async def _read_frames(self):
    """Read frames until the connection ends and feed the receive queue.

    Text and binary messages are queued as ``websocket.receive`` events,
    pings are answered with pongs, pongs are ignored, and a close frame
    is passed to ``_handle_close``.

    When reading stops, a ``websocket.disconnect`` event is queued so the
    application's ``receive()`` can observe the end of the connection.
    This happens whenever the peer ended the exchange (close frame, or
    ``_read_frame`` returning ``None``), even though the close handling
    has by then already set ``self.closed``. It also happens when the
    loop stops for another reason while the socket was not yet marked
    closed. The event carries ``self.close_code`` or, when none was
    recorded, ``CLOSE_ABNORMAL``.

    Raises:
        asyncio.CancelledError: Re-raised when the task is cancelled.
    """
    peer_finished = False
    try:
        while not self.closed:
            frame = await self._read_frame()
            if frame is None:
                # EOF, read failure or a protocol error already handled
                # inside _read_frame.
                peer_finished = True
                break

            opcode, payload = frame

            if opcode == OPCODE_CLOSE:
                peer_finished = True
                await self._handle_close(payload)
                break
            elif opcode == OPCODE_PING:
                await self._send_frame(OPCODE_PONG, payload)
            elif opcode == OPCODE_PONG:
                # Ignore pongs
                pass
            elif opcode == OPCODE_TEXT:
                await self._receive_queue.put({
                    "type": "websocket.receive",
                    "text": payload.decode("utf-8"),
                })
            elif opcode == OPCODE_BINARY:
                await self._receive_queue.put({
                    "type": "websocket.receive",
                    "bytes": payload,
                })
            elif opcode == OPCODE_CONTINUATION:
                # Handle fragmented messages
                await self._handle_continuation(payload)

    except asyncio.CancelledError:
        raise
    except Exception as e:
        self.log.debug("WebSocket read error: %s", e)
    finally:
        # Signal disconnect. `closed` alone cannot be the gate: the close
        # handshake sets it before we get here, which would otherwise
        # leave the application without a disconnect event.
        if peer_finished or not self.closed:
            self.closed = True
            await self._receive_queue.put({
                "type": "websocket.disconnect",
                "code": self.close_code or CLOSE_ABNORMAL,
            })
+ + Returns: + tuple: (opcode, payload) or None if connection closed + """ + # Read frame header (2 bytes minimum) + header = await self._read_exact(2) + if not header: + return None + + first_byte, second_byte = header[0], header[1] + + fin = (first_byte >> 7) & 1 + rsv1 = (first_byte >> 6) & 1 + rsv2 = (first_byte >> 5) & 1 + rsv3 = (first_byte >> 4) & 1 + opcode = first_byte & 0x0F + + # RSV bits must be 0 (no extensions) + if rsv1 or rsv2 or rsv3: + await self._send_close(CLOSE_PROTOCOL_ERROR, "RSV bits set") + return None + + masked = (second_byte >> 7) & 1 + payload_len = second_byte & 0x7F + + # Client frames must be masked (RFC 6455) + if not masked: + await self._send_close(CLOSE_PROTOCOL_ERROR, "Frame not masked") + return None + + # Extended payload length + if payload_len == 126: + ext_len = await self._read_exact(2) + if not ext_len: + return None + payload_len = struct.unpack("!H", ext_len)[0] + elif payload_len == 127: + ext_len = await self._read_exact(8) + if not ext_len: + return None + payload_len = struct.unpack("!Q", ext_len)[0] + + # Read masking key + masking_key = await self._read_exact(4) + if not masking_key: + return None + + # Read payload + payload = await self._read_exact(payload_len) + if payload is None: + return None + + # Unmask payload + payload = self._unmask(payload, masking_key) + + # Handle fragmented messages + if opcode == OPCODE_CONTINUATION: + if self._fragment_opcode is None: + await self._send_close(CLOSE_PROTOCOL_ERROR, "Unexpected continuation") + return None + self._fragments.append(payload) + if fin: + # Reassemble complete message + full_payload = b"".join(self._fragments) + final_opcode = self._fragment_opcode + self._fragments = [] + self._fragment_opcode = None + return (final_opcode, full_payload) + return (OPCODE_CONTINUATION, b"") # Fragment received, wait for more + elif opcode in (OPCODE_TEXT, OPCODE_BINARY): + if not fin: + # Start of fragmented message + self._fragment_opcode = opcode + self._fragments = 
def _unmask(self, payload, masking_key):
    """Remove the client-to-server masking from a frame payload.

    Every payload byte ``i`` is XORed with ``masking_key[i % 4]``. Rather
    than looping byte by byte in Python, the key is repeated to the
    payload length and both byte strings are XORed as single big
    integers, which produces the same bytes.

    Args:
        payload: Masked payload bytes (may be empty).
        masking_key: The 4-byte masking key from the frame header.

    Returns:
        bytes: The unmasked payload; an empty payload is returned as is.
    """
    if not payload:
        return payload

    size = len(payload)
    key = bytes(masking_key[:4])
    # Key stream of exactly `size` bytes: key repeated, then trimmed.
    stream = (key * (size // 4 + 1))[:size]

    unmasked = int.from_bytes(payload, "big") ^ int.from_bytes(stream, "big")
    # to_bytes with an explicit length keeps leading zero bytes.
    return unmasked.to_bytes(size, "big")
async def _send_close(self, code, reason=""):
    """Send a close frame and mark the connection closed.

    The frame body is the 2-byte status code followed by the UTF-8
    encoded reason. Two protocol constraints are enforced:

    * A control frame payload may not exceed 125 bytes, leaving 123 for
      the reason. The reason is cut at that limit on a character
      boundary, so a multi-byte character is never split into invalid
      UTF-8.
    * ``CLOSE_NO_STATUS`` and ``CLOSE_ABNORMAL`` are reserved values that
      must not appear on the wire; for those an empty close frame is
      sent instead.

    Args:
        code: WebSocket close status code.
        reason: Optional human-readable reason.
    """
    if code in (CLOSE_NO_STATUS, CLOSE_ABNORMAL):
        payload = b""
    else:
        payload = struct.pack("!H", code)
        if reason:
            clipped = reason.encode("utf-8")[:123]  # Max 125 bytes total
            # Drop a trailing partial character left by the byte slice.
            payload += clipped.decode("utf-8", "ignore").encode("utf-8")

    await self._send_frame(OPCODE_CLOSE, payload)
    self.closed = True
# Tri-state switch for the ASGI lifespan protocol. The ASGI worker reads it
# as ``cfg.asgi_lifespan`` before it starts listening: "off" skips lifespan
# entirely; with "on" a failed startup is logged as an error and the worker
# stops serving; with any other value a failed startup is logged at debug
# level and the worker carries on without lifespan.
class ASGILifespan(Setting):
    name = "asgi_lifespan"
    section = "Worker Processes"
    cli = ["--asgi-lifespan"]
    meta = "STRING"
    # Normalises the value and rejects anything but auto/on/off.
    validator = validate_asgi_lifespan
    default = "auto"
    desc = """\
        Control ASGI lifespan protocol handling.

        - auto: Detect if app supports lifespan, enable if so
        - on: Always run lifespan protocol (fail if unsupported)
        - off: Never run lifespan protocol

        The lifespan protocol allows ASGI applications to run code at
        startup and shutdown. This is essential for frameworks like
        FastAPI that need to initialize database connections, caches,
        or other resources.

        This setting only affects the ``asgi`` worker type.

        .. versionadded:: 24.0.0
        """
def __init__(self, *args, **kwargs):
    """Create the worker and its asyncio-related bookkeeping.

    All arguments are forwarded unchanged to the base worker.
    """
    super().__init__(*args, **kwargs)

    # Event loop and the servers created on it; filled in later.
    self.loop = None
    self.servers = []

    # Connection accounting.
    self.worker_connections = self.cfg.worker_connections
    self.nr_conns = 0

    # Lifespan manager (if any) and the state dict shared with scopes.
    self.lifespan = None
    self.state = {}  # Shared state for lifespan
def _setup_event_loop(self):
    """Create the event loop selected by ``asgi_loop`` and install it.

    ``auto`` uses uvloop when it can be imported and asyncio otherwise.
    ``uvloop`` uses uvloop, falling back to asyncio with a warning when
    the import fails. Any other value uses asyncio. The chosen loop is
    stored on ``self.loop`` and set as the current event loop.
    """
    requested = getattr(self.cfg, 'asgi_loop', 'auto')

    # uvloop is only looked up when the setting could actually use it.
    uvloop = None
    if requested in ("auto", "uvloop"):
        try:
            import uvloop
        except ImportError:
            uvloop = None

    if uvloop is not None:
        self.loop = uvloop.new_event_loop()
        self.log.debug("Using uvloop event loop")
    elif requested == "uvloop":
        self.log.warning("uvloop not available, falling back to asyncio")
        self.loop = asyncio.new_event_loop()
    else:
        self.loop = asyncio.new_event_loop()
        self.log.debug("Using asyncio event loop")

    asyncio.set_event_loop(self.loop)
def run(self):
    """Drive the worker: run the serving coroutine to completion.

    Exceptions escaping the serving coroutine are logged and not
    re-raised. Cleanup always runs afterwards.
    """
    try:
        serving = self._serve()
        self.loop.run_until_complete(serving)
    except Exception as err:
        self.log.exception("Worker exception: %s", err)
    finally:
        self._cleanup()
def _get_ssl_context(self):
    """Build the TLS context for the listeners, if TLS is configured.

    Returns:
        ``None`` when ``cfg.is_ssl`` is false, otherwise the context
        produced by ``gunicorn.sock.ssl_context``.

    Raises:
        Exception: Whatever context creation raised. The error is logged
            and re-raised on purpose: returning ``None`` here would make
            the caller start the listeners without TLS, silently serving
            plaintext on a socket that was configured as secure.
    """
    if not self.cfg.is_ssl:
        return None

    try:
        from gunicorn import sock
        return sock.ssl_context(self.cfg)
    except Exception as e:
        self.log.error("Failed to create SSL context: %s", e)
        raise
exit.""" + try: + # Cancel all pending tasks + pending = asyncio.all_tasks(self.loop) + for task in pending: + task.cancel() + + # Run loop until all tasks are cancelled + if pending: + self.loop.run_until_complete( + asyncio.gather(*pending, return_exceptions=True) + ) + + self.loop.close() + except Exception as e: + self.log.debug("Cleanup error: %s", e) + + # Close sockets + for s in self.sockets: + try: + s.close() + except Exception: + pass diff --git a/pyproject.toml b/pyproject.toml index ce681f65..7803dc55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,7 @@ testing = [ "coverage", "pytest", "pytest-cov", + "pytest-asyncio", ] [project.scripts] diff --git a/tests/test_asgi.py b/tests/test_asgi.py new file mode 100644 index 00000000..227f7ea2 --- /dev/null +++ b/tests/test_asgi.py @@ -0,0 +1,285 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +Tests for ASGI worker components. +""" + +import asyncio +import io +import pytest +from unittest import mock + +from gunicorn.asgi.unreader import AsyncUnreader +from gunicorn.asgi.message import AsyncRequest + + +class MockStreamReader: + """Mock asyncio.StreamReader for testing.""" + + def __init__(self, data): + self.data = data + self.pos = 0 + + async def read(self, size=-1): + if self.pos >= len(self.data): + return b"" + if size < 0: + result = self.data[self.pos:] + self.pos = len(self.data) + else: + result = self.data[self.pos:self.pos + size] + self.pos += size + return result + + async def readexactly(self, n): + if self.pos + n > len(self.data): + raise asyncio.IncompleteReadError( + self.data[self.pos:], n + ) + result = self.data[self.pos:self.pos + n] + self.pos += n + return result + + +class MockConfig: + """Mock gunicorn config for testing.""" + + def __init__(self): + self.is_ssl = False + self.proxy_protocol = False + self.proxy_allow_ips = ["127.0.0.1"] + self.forwarded_allow_ips = ["127.0.0.1"] + 
self.secure_scheme_headers = {} + self.forwarder_headers = [] + self.limit_request_line = 8190 + self.limit_request_fields = 100 + self.limit_request_field_size = 8190 + self.permit_unconventional_http_method = False + self.permit_unconventional_http_version = False + self.permit_obsolete_folding = False + self.casefold_http_method = False + self.strip_header_spaces = False + self.header_map = "refuse" + + +# AsyncUnreader Tests + +@pytest.mark.asyncio +async def test_async_unreader_read_chunk(): + """Test basic chunk reading.""" + reader = MockStreamReader(b"hello world") + unreader = AsyncUnreader(reader) + data = await unreader.read() + assert data == b"hello world" + + +@pytest.mark.asyncio +async def test_async_unreader_read_size(): + """Test reading specific size.""" + reader = MockStreamReader(b"hello world") + unreader = AsyncUnreader(reader) + data = await unreader.read(5) + assert data == b"hello" + + +@pytest.mark.asyncio +async def test_async_unreader_unread(): + """Test unread functionality.""" + reader = MockStreamReader(b"hello world") + unreader = AsyncUnreader(reader) + + # Read all data + data = await unreader.read() + assert data == b"hello world" + + # Unread some data + unreader.unread(b"world") + + # Read again should get unread data + data = await unreader.read() + assert data == b"world" + + +@pytest.mark.asyncio +async def test_async_unreader_read_zero(): + """Test reading zero bytes.""" + reader = MockStreamReader(b"hello") + unreader = AsyncUnreader(reader) + data = await unreader.read(0) + assert data == b"" + + +@pytest.mark.asyncio +async def test_async_unreader_read_empty(): + """Test reading from empty stream.""" + reader = MockStreamReader(b"") + unreader = AsyncUnreader(reader) + data = await unreader.read() + assert data == b"" + + +# AsyncRequest Tests + +@pytest.mark.asyncio +async def test_async_request_simple_get(): + """Test parsing a simple GET request.""" + request_data = b"GET /path HTTP/1.1\r\nHost: localhost\r\n\r\n" + 
reader = MockStreamReader(request_data) + unreader = AsyncUnreader(reader) + cfg = MockConfig() + + request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) + + assert request.method == "GET" + assert request.path == "/path" + assert request.version == (1, 1) + assert ("HOST", "localhost") in request.headers + + +@pytest.mark.asyncio +async def test_async_request_with_query(): + """Test parsing request with query string.""" + request_data = b"GET /search?q=test&page=1 HTTP/1.1\r\nHost: localhost\r\n\r\n" + reader = MockStreamReader(request_data) + unreader = AsyncUnreader(reader) + cfg = MockConfig() + + request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) + + assert request.method == "GET" + assert request.path == "/search" + assert request.query == "q=test&page=1" + + +@pytest.mark.asyncio +async def test_async_request_post_with_body(): + """Test parsing POST request with body.""" + request_data = ( + b"POST /submit HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Content-Length: 11\r\n" + b"\r\n" + b"hello=world" + ) + reader = MockStreamReader(request_data) + unreader = AsyncUnreader(reader) + cfg = MockConfig() + + request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) + + assert request.method == "POST" + assert request.path == "/submit" + assert request.content_length == 11 + + # Read body + body = await request.read_body(100) + assert body == b"hello=world" + + +@pytest.mark.asyncio +async def test_async_request_multiple_headers(): + """Test parsing request with multiple headers.""" + request_data = ( + b"GET / HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Accept: text/html\r\n" + b"Accept-Language: en-US\r\n" + b"Connection: keep-alive\r\n" + b"\r\n" + ) + reader = MockStreamReader(request_data) + unreader = AsyncUnreader(reader) + cfg = MockConfig() + + request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) + + assert len(request.headers) == 4 + assert request.get_header("HOST") == "localhost" + 
assert request.get_header("ACCEPT") == "text/html" + + +@pytest.mark.asyncio +async def test_async_request_should_close_http10(): + """Test connection close detection for HTTP/1.0.""" + request_data = b"GET / HTTP/1.0\r\nHost: localhost\r\n\r\n" + reader = MockStreamReader(request_data) + unreader = AsyncUnreader(reader) + cfg = MockConfig() + + request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) + + assert request.version == (1, 0) + assert request.should_close() is True + + +@pytest.mark.asyncio +async def test_async_request_should_close_connection_header(): + """Test connection close detection with Connection header.""" + request_data = b"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n" + reader = MockStreamReader(request_data) + unreader = AsyncUnreader(reader) + cfg = MockConfig() + + request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) + + assert request.should_close() is True + + +@pytest.mark.asyncio +async def test_async_request_keepalive(): + """Test keepalive detection.""" + request_data = b"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: keep-alive\r\n\r\n" + reader = MockStreamReader(request_data) + unreader = AsyncUnreader(reader) + cfg = MockConfig() + + request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) + + assert request.should_close() is False + + +@pytest.mark.asyncio +async def test_async_request_no_body_for_get(): + """Test that GET requests have no body by default.""" + request_data = b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n" + reader = MockStreamReader(request_data) + unreader = AsyncUnreader(reader) + cfg = MockConfig() + + request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) + + assert request.content_length == 0 + body = await request.read_body() + assert body == b"" + + +# Error handling tests + +@pytest.mark.asyncio +async def test_async_request_invalid_method(): + """Test invalid HTTP method detection.""" + from gunicorn.http.errors import 
InvalidRequestMethod + + request_data = b"ge!t / HTTP/1.1\r\nHost: localhost\r\n\r\n" + reader = MockStreamReader(request_data) + unreader = AsyncUnreader(reader) + cfg = MockConfig() + + with pytest.raises(InvalidRequestMethod): + await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) + + +@pytest.mark.asyncio +async def test_async_request_invalid_http_version(): + """Test invalid HTTP version detection.""" + from gunicorn.http.errors import InvalidHTTPVersion + + request_data = b"GET / HTTP/2.0\r\nHost: localhost\r\n\r\n" + reader = MockStreamReader(request_data) + unreader = AsyncUnreader(reader) + cfg = MockConfig() + + with pytest.raises(InvalidHTTPVersion): + await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) diff --git a/tests/test_asgi_worker.py b/tests/test_asgi_worker.py new file mode 100644 index 00000000..9266af4d --- /dev/null +++ b/tests/test_asgi_worker.py @@ -0,0 +1,643 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +Tests for the ASGI worker. + +Includes unit tests for worker components and integration tests +that actually start the server and make HTTP requests. 
+""" + +import asyncio +import errno +import os +import signal +import socket +import sys +import time +import threading +from unittest import mock + +import pytest + +from gunicorn.config import Config +from gunicorn.workers import gasgi + + +# ============================================================================ +# Mock Classes +# ============================================================================ + +class FakeSocket: + """Mock socket for testing.""" + + def __init__(self, data=b''): + self.data = data + self.closed = False + self.blocking = True + self._fileno = id(self) % 65536 + + def fileno(self): + return self._fileno + + def setblocking(self, blocking): + self.blocking = blocking + + def recv(self, size): + if self.closed: + raise OSError(errno.EBADF, "Bad file descriptor") + result = self.data[:size] + self.data = self.data[size:] + return result + + def send(self, data): + if self.closed: + raise OSError(errno.EPIPE, "Broken pipe") + return len(data) + + def close(self): + self.closed = True + + def getsockname(self): + return ('127.0.0.1', 8000) + + def getpeername(self): + return ('127.0.0.1', 12345) + + +class FakeApp: + """Mock ASGI application for testing.""" + + def __init__(self): + self.calls = [] + + def wsgi(self): + return self.asgi_app + + async def asgi_app(self, scope, receive, send): + self.calls.append(scope) + if scope["type"] == "lifespan": + while True: + message = await receive() + if message["type"] == "lifespan.startup": + await send({"type": "lifespan.startup.complete"}) + elif message["type"] == "lifespan.shutdown": + await send({"type": "lifespan.shutdown.complete"}) + return + elif scope["type"] == "http": + await send({ + "type": "http.response.start", + "status": 200, + "headers": [(b"content-type", b"text/plain")], + }) + await send({ + "type": "http.response.body", + "body": b"Hello from ASGI!", + }) + + +class FakeListener: + """Mock listener socket.""" + + def __init__(self): + self.sock = FakeSocket() + + 
def getsockname(self): + return ('127.0.0.1', 8000) + + def close(self): + self.sock.close() + + def __str__(self): + return "http://127.0.0.1:8000" + + +# ============================================================================ +# Helper Functions +# ============================================================================ + +def _has_uvloop(): + """Check if uvloop is available.""" + try: + import uvloop + return True + except ImportError: + return False + + +# ============================================================================ +# Unit Tests for ASGIWorker +# ============================================================================ + +class TestASGIWorkerInit: + """Tests for ASGIWorker initialization.""" + + def create_worker(self, **kwargs): + """Create a worker for testing.""" + cfg = Config() + cfg.set('workers', 1) + cfg.set('worker_connections', 1000) + + for key, value in kwargs.items(): + cfg.set(key, value) + + worker = gasgi.ASGIWorker( + age=1, + ppid=os.getpid(), + sockets=[], + app=FakeApp(), + timeout=30, + cfg=cfg, + log=mock.Mock(), + ) + return worker + + def test_worker_init(self): + """Test worker initialization.""" + worker = self.create_worker() + + assert worker.worker_connections == 1000 + assert worker.nr_conns == 0 + assert worker.loop is None + assert worker.servers == [] + assert worker.state == {} + + def test_worker_connections_config(self): + """Test worker_connections configuration.""" + worker = self.create_worker(worker_connections=500) + assert worker.worker_connections == 500 + + +class TestASGIWorkerEventLoop: + """Tests for event loop setup.""" + + def create_worker(self, **kwargs): + """Create a worker for testing.""" + cfg = Config() + cfg.set('workers', 1) + cfg.set('worker_connections', 1000) + + for key, value in kwargs.items(): + cfg.set(key, value) + + worker = gasgi.ASGIWorker( + age=1, + ppid=os.getpid(), + sockets=[], + app=FakeApp(), + timeout=30, + cfg=cfg, + log=mock.Mock(), + ) + return worker + 
+ def test_setup_asyncio_loop(self): + """Test asyncio event loop setup.""" + worker = self.create_worker(asgi_loop='asyncio') + worker._setup_event_loop() + + assert worker.loop is not None + assert isinstance(worker.loop, asyncio.AbstractEventLoop) + worker.loop.close() + + def test_setup_auto_loop_falls_back_to_asyncio(self): + """Test that auto mode uses asyncio when uvloop unavailable.""" + worker = self.create_worker(asgi_loop='auto') + + # Mock uvloop import failure + with mock.patch.dict('sys.modules', {'uvloop': None}): + worker._setup_event_loop() + + assert worker.loop is not None + worker.loop.close() + + @pytest.mark.skipif( + not _has_uvloop(), + reason="uvloop not installed" + ) + def test_setup_uvloop(self): + """Test uvloop event loop setup.""" + worker = self.create_worker(asgi_loop='uvloop') + worker._setup_event_loop() + + import uvloop + assert isinstance(worker.loop, uvloop.Loop) + worker.loop.close() + + +class TestASGIWorkerSignals: + """Tests for signal handling.""" + + def create_worker(self): + """Create a worker for testing.""" + cfg = Config() + cfg.set('workers', 1) + cfg.set('worker_connections', 1000) + cfg.set('graceful_timeout', 5) + + worker = gasgi.ASGIWorker( + age=1, + ppid=os.getpid(), + sockets=[], + app=FakeApp(), + timeout=30, + cfg=cfg, + log=mock.Mock(), + ) + worker._setup_event_loop() + return worker + + def test_handle_exit_sets_alive_false(self): + """Test that exit signal sets alive=False.""" + worker = self.create_worker() + worker.alive = True + + worker.handle_exit_signal() + + assert worker.alive is False + worker.loop.close() + + def test_handle_quit_sets_alive_false(self): + """Test that quit signal sets alive=False.""" + worker = self.create_worker() + worker.alive = True + + # Mock the worker_int callback on the worker's cfg settings + with mock.patch.object(worker.cfg.settings['worker_int'], 'get', return_value=lambda w: None): + worker.handle_quit_signal() + + assert worker.alive is False + 
worker.loop.close() + + +# ============================================================================ +# Tests for Lifespan Protocol +# ============================================================================ + +class TestLifespanManager: + """Tests for ASGI lifespan protocol.""" + + @pytest.mark.asyncio + async def test_lifespan_startup_complete(self): + """Test successful lifespan startup.""" + from gunicorn.asgi.lifespan import LifespanManager + + startup_called = False + shutdown_called = False + + async def app(scope, receive, send): + nonlocal startup_called, shutdown_called + assert scope["type"] == "lifespan" + while True: + message = await receive() + if message["type"] == "lifespan.startup": + startup_called = True + await send({"type": "lifespan.startup.complete"}) + elif message["type"] == "lifespan.shutdown": + shutdown_called = True + await send({"type": "lifespan.shutdown.complete"}) + return + + manager = LifespanManager(app, mock.Mock()) + await manager.startup() + + assert startup_called + assert manager._startup_complete.is_set() + assert not manager._startup_failed + + await manager.shutdown() + assert shutdown_called + + @pytest.mark.asyncio + async def test_lifespan_startup_failed(self): + """Test lifespan startup failure.""" + from gunicorn.asgi.lifespan import LifespanManager + + async def app(scope, receive, send): + message = await receive() + if message["type"] == "lifespan.startup": + await send({ + "type": "lifespan.startup.failed", + "message": "Database connection failed" + }) + + manager = LifespanManager(app, mock.Mock()) + + with pytest.raises(RuntimeError, match="Database connection failed"): + await manager.startup() + + @pytest.mark.asyncio + async def test_lifespan_state_shared(self): + """Test that lifespan state is shared with app.""" + from gunicorn.asgi.lifespan import LifespanManager + + state = {} + + async def app(scope, receive, send): + assert "state" in scope + scope["state"]["db"] = "connected" + message = 
await receive() + await send({"type": "lifespan.startup.complete"}) + message = await receive() + await send({"type": "lifespan.shutdown.complete"}) + + manager = LifespanManager(app, mock.Mock(), state) + await manager.startup() + + assert state.get("db") == "connected" + + await manager.shutdown() + + +# ============================================================================ +# Tests for WebSocket Protocol +# ============================================================================ + +class TestWebSocketProtocol: + """Tests for WebSocket protocol handling.""" + + def test_websocket_guid(self): + """Test WebSocket GUID constant.""" + from gunicorn.asgi.websocket import WS_GUID + assert WS_GUID == b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + + def test_websocket_opcodes(self): + """Test WebSocket opcode constants.""" + from gunicorn.asgi import websocket + + assert websocket.OPCODE_TEXT == 0x1 + assert websocket.OPCODE_BINARY == 0x2 + assert websocket.OPCODE_CLOSE == 0x8 + assert websocket.OPCODE_PING == 0x9 + assert websocket.OPCODE_PONG == 0xA + + def test_websocket_accept_key_calculation(self): + """Test WebSocket accept key calculation per RFC 6455.""" + import base64 + import hashlib + from gunicorn.asgi.websocket import WS_GUID + + # Example from RFC 6455 + client_key = b"dGhlIHNhbXBsZSBub25jZQ==" + expected_accept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" + + accept_key = base64.b64encode( + hashlib.sha1(client_key + WS_GUID).digest() + ).decode("ascii") + + assert accept_key == expected_accept + + def test_websocket_frame_masking(self): + """Test WebSocket frame unmasking.""" + from gunicorn.asgi.websocket import WebSocketProtocol + + # Create a minimal protocol instance + protocol = WebSocketProtocol(None, None, {}, None, mock.Mock()) + + # Test unmasking (XOR operation) + masking_key = bytes([0x37, 0xfa, 0x21, 0x3d]) + masked_data = bytes([0x7f, 0x9f, 0x4d, 0x51, 0x58]) # "Hello" masked + + unmasked = protocol._unmask(masked_data, masking_key) + assert 
unmasked == b"Hello" + + def test_websocket_frame_masking_empty(self): + """Test WebSocket frame unmasking with empty payload.""" + from gunicorn.asgi.websocket import WebSocketProtocol + + protocol = WebSocketProtocol(None, None, {}, None, mock.Mock()) + + masking_key = bytes([0x37, 0xfa, 0x21, 0x3d]) + unmasked = protocol._unmask(b"", masking_key) + assert unmasked == b"" + + +# ============================================================================ +# Integration Tests +# ============================================================================ + +class TestASGIIntegration: + """Integration tests that start actual servers.""" + + @pytest.fixture + def free_port(self): + """Get a free port for testing.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(('127.0.0.1', 0)) + return s.getsockname()[1] + + @pytest.mark.asyncio + async def test_http_request_response(self, free_port): + """Test basic HTTP request/response cycle.""" + # Simple ASGI app + async def app(scope, receive, send): + if scope["type"] == "http": + await send({ + "type": "http.response.start", + "status": 200, + "headers": [(b"content-type", b"text/plain")], + }) + await send({ + "type": "http.response.body", + "body": b"Hello, World!", + }) + + # Start server + loop = asyncio.get_event_loop() + server = await loop.create_server( + lambda: _TestProtocol(app), + '127.0.0.1', + free_port, + ) + + try: + # Use asyncio to make HTTP request + reader, writer = await asyncio.open_connection('127.0.0.1', free_port) + request = f"GET / HTTP/1.1\r\nHost: 127.0.0.1:{free_port}\r\n\r\n" + writer.write(request.encode()) + await writer.drain() + + # Read response + response = await reader.read(4096) + response_text = response.decode() + + assert "HTTP/1.1 200" in response_text + assert "Hello, World!" 
in response_text + + writer.close() + await writer.wait_closed() + finally: + server.close() + await server.wait_closed() + + +class _TestProtocol(asyncio.Protocol): + """Minimal protocol for integration testing.""" + + def __init__(self, app): + self.app = app + self.transport = None + + def connection_made(self, transport): + self.transport = transport + + def data_received(self, data): + # Very simple HTTP parsing for testing + asyncio.create_task(self._handle(data)) + + async def _handle(self, data): + # Parse basic HTTP request + lines = data.decode().split('\r\n') + method, path, _ = lines[0].split(' ') + + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": method, + "path": path, + "query_string": b"", + "headers": [], + "server": ("127.0.0.1", 8000), + "client": ("127.0.0.1", 12345), + } + + async def receive(): + return {"type": "http.request", "body": b"", "more_body": False} + + async def send(message): + if message["type"] == "http.response.start": + status = message["status"] + headers = message.get("headers", []) + response = f"HTTP/1.1 {status} OK\r\n" + for name, value in headers: + if isinstance(name, bytes): + name = name.decode() + if isinstance(value, bytes): + value = value.decode() + response += f"{name}: {value}\r\n" + response += "\r\n" + self.transport.write(response.encode()) + elif message["type"] == "http.response.body": + body = message.get("body", b"") + self.transport.write(body) + if not message.get("more_body", False): + self.transport.close() + + await self.app(scope, receive, send) + + +# ============================================================================ +# ASGI Protocol Tests +# ============================================================================ + +class TestASGIProtocol: + """Tests for ASGIProtocol.""" + + def test_reason_phrases(self): + """Test HTTP reason phrase lookup.""" + from gunicorn.asgi.protocol import ASGIProtocol + + # Create minimal worker mock + worker 
= mock.Mock() + worker.cfg = Config() + worker.log = mock.Mock() + worker.asgi = mock.Mock() + + protocol = ASGIProtocol(worker) + + assert protocol._get_reason_phrase(200) == "OK" + assert protocol._get_reason_phrase(404) == "Not Found" + assert protocol._get_reason_phrase(500) == "Internal Server Error" + assert protocol._get_reason_phrase(999) == "Unknown" + + def test_scope_building(self): + """Test HTTP scope building.""" + from gunicorn.asgi.protocol import ASGIProtocol + from gunicorn.asgi.message import AsyncRequest + + worker = mock.Mock() + worker.cfg = Config() + worker.cfg.set('root_path', '/api') + worker.log = mock.Mock() + worker.asgi = mock.Mock() + + protocol = ASGIProtocol(worker) + + # Create mock request + request = mock.Mock() + request.method = "GET" + request.path = "/users" + request.query = "page=1" + request.version = (1, 1) + request.scheme = "http" + request.headers = [("HOST", "localhost"), ("ACCEPT", "text/html")] + + scope = protocol._build_http_scope( + request, + ("127.0.0.1", 8000), # sockname + ("127.0.0.1", 12345), # peername + ) + + assert scope["type"] == "http" + assert scope["method"] == "GET" + assert scope["path"] == "/users" + assert scope["query_string"] == b"page=1" + assert scope["root_path"] == "/api" + assert scope["http_version"] == "1.1" + + +# ============================================================================ +# Config Tests +# ============================================================================ + +class TestASGIConfig: + """Tests for ASGI configuration options.""" + + def test_asgi_loop_default(self): + """Test default asgi_loop value.""" + cfg = Config() + assert cfg.asgi_loop == "auto" + + def test_asgi_loop_validation(self): + """Test asgi_loop validation.""" + cfg = Config() + + cfg.set('asgi_loop', 'asyncio') + assert cfg.asgi_loop == 'asyncio' + + cfg.set('asgi_loop', 'uvloop') + assert cfg.asgi_loop == 'uvloop' + + with pytest.raises(ValueError): + cfg.set('asgi_loop', 'invalid') + + def 
test_asgi_lifespan_default(self): + """Test default asgi_lifespan value.""" + cfg = Config() + assert cfg.asgi_lifespan == "auto" + + def test_asgi_lifespan_validation(self): + """Test asgi_lifespan validation.""" + cfg = Config() + + cfg.set('asgi_lifespan', 'on') + assert cfg.asgi_lifespan == 'on' + + cfg.set('asgi_lifespan', 'off') + assert cfg.asgi_lifespan == 'off' + + with pytest.raises(ValueError): + cfg.set('asgi_lifespan', 'invalid') + + def test_root_path_default(self): + """Test default root_path value.""" + cfg = Config() + assert cfg.root_path == "" + + def test_root_path_setting(self): + """Test root_path configuration.""" + cfg = Config() + cfg.set('root_path', '/api/v1') + assert cfg.root_path == '/api/v1' From 11c6a97c47f2a4a16cf683ae0f224b3880fca6e4 Mon Sep 17 00:00:00 2001 From: Benoit Chesneau Date: Thu, 22 Jan 2026 18:03:14 +0100 Subject: [PATCH 2/8] asgi: Fix pylint and pycodestyle warnings - Remove unused imports (ssl, os, base64, hashlib, traceback) - Remove unused variables (body_parts, has_content_length, etc.) 
- Fix no-else-break patterns in protocol.py and websocket.py - Replace __anext__() with anext() builtin - Remove unnecessary pass statements - Add proper access logging to ASGI protocol handler - Add ASGIResponseInfo class and _build_environ method for logging - Disable too-many-return-statements for _read_frame method - Fix raising-bad-type error (use 'is not None' check) - Fix whitespace before colon in message.py --- gunicorn/asgi/message.py | 4 +- gunicorn/asgi/protocol.py | 114 ++++++++++++++++++++++++------------- gunicorn/asgi/websocket.py | 13 ++--- gunicorn/workers/gasgi.py | 1 - 4 files changed, 83 insertions(+), 49 deletions(-) diff --git a/gunicorn/asgi/message.py b/gunicorn/asgi/message.py index d7d20c83..a2d8e825 100644 --- a/gunicorn/asgi/message.py +++ b/gunicorn/asgi/message.py @@ -477,7 +477,7 @@ class AsyncRequest: self._body_reader = self._chunked_body_reader() try: - return await self._body_reader.__anext__() + return await anext(self._body_reader) except StopAsyncIteration: self._body_remaining = 0 return b"" @@ -489,7 +489,7 @@ class AsyncRequest: size_line = await self._read_chunk_size_line() # Parse chunk size (handle extensions) chunk_size, *_ = size_line.split(b";", 1) - if _ : + if _: chunk_size = chunk_size.rstrip(b" \t") if any(n not in b"0123456789abcdefABCDEF" for n in chunk_size): diff --git a/gunicorn/asgi/protocol.py b/gunicorn/asgi/protocol.py index cededd68..0eb1d045 100644 --- a/gunicorn/asgi/protocol.py +++ b/gunicorn/asgi/protocol.py @@ -10,9 +10,6 @@ to ASGI applications. 
""" import asyncio -import base64 -import hashlib -import traceback from datetime import datetime from gunicorn.asgi.unreader import AsyncUnreader @@ -20,6 +17,22 @@ from gunicorn.asgi.message import AsyncRequest from gunicorn.http.errors import NoMoreData +class ASGIResponseInfo: + """Simple container for ASGI response info for access logging.""" + + def __init__(self, status, headers, sent): + self.status = status + self.sent = sent + # Convert headers to list of string tuples for logging + self.headers = [] + for name, value in headers: + if isinstance(name, bytes): + name = name.decode("latin-1") + if isinstance(value, bytes): + value = value.decode("latin-1") + self.headers.append((name, value)) + + class ASGIProtocol(asyncio.Protocol): """HTTP/1.1 protocol handler for ASGI applications. @@ -97,30 +110,30 @@ class ASGIProtocol(asyncio.Protocol): if self._is_websocket_upgrade(request): await self._handle_websocket(request, sockname, peername) break # WebSocket takes over the connection - else: - # Handle HTTP request - keepalive = await self._handle_http_request( - request, sockname, peername - ) - # Increment worker request count - self.worker.nr += 1 + # Handle HTTP request + keepalive = await self._handle_http_request( + request, sockname, peername + ) - # Check max_requests - if self.worker.nr >= self.worker.max_requests: - self.log.info("Autorestarting worker after current request.") - self.worker.alive = False - keepalive = False + # Increment worker request count + self.worker.nr += 1 - if not keepalive or not self.worker.alive: - break + # Check max_requests + if self.worker.nr >= self.worker.max_requests: + self.log.info("Autorestarting worker after current request.") + self.worker.alive = False + keepalive = False - # Check connection limits for keepalive - if not self.cfg.keepalive: - break + if not keepalive or not self.worker.alive: + break - # Drain any unread body before next request - await request.drain_body() + # Check connection limits for 
keepalive + if not self.cfg.keepalive: + break + + # Drain any unread body before next request + await request.drain_body() except asyncio.CancelledError: pass @@ -155,9 +168,13 @@ class ASGIProtocol(asyncio.Protocol): scope = self._build_http_scope(request, sockname, peername) response_started = False response_complete = False - body_parts = [] exc_to_raise = None + # Response tracking for access logging + response_status = 500 + response_headers = [] + response_sent = 0 + # Receive queue for body receive_queue = asyncio.Queue() @@ -177,6 +194,7 @@ class ASGIProtocol(asyncio.Protocol): async def send(message): nonlocal response_started, response_complete, exc_to_raise + nonlocal response_status, response_headers, response_sent msg_type = message["type"] @@ -185,9 +203,9 @@ class ASGIProtocol(asyncio.Protocol): exc_to_raise = RuntimeError("Response already started") return response_started = True - status = message["status"] - headers = message.get("headers", []) - await self._send_response_start(status, headers, request) + response_status = message["status"] + response_headers = message.get("headers", []) + await self._send_response_start(response_status, response_headers, request) elif msg_type == "http.response.body": if not response_started: @@ -202,32 +220,42 @@ class ASGIProtocol(asyncio.Protocol): if body: await self._send_body(body) + response_sent += len(body) if not more_body: response_complete = True + # Build environ for logging + environ = self._build_environ(request, sockname, peername) + resp = None + try: request_start = datetime.now() self.cfg.pre_request(self.worker, request) await self.app(scope, receive, send) - if exc_to_raise: + if exc_to_raise is not None: raise exc_to_raise # Ensure response was sent if not response_started: await self._send_error_response(500, "Internal Server Error") + response_status = 500 - except Exception as e: + except Exception: self.log.exception("Error in ASGI application") if not response_started: await 
self._send_error_response(500, "Internal Server Error") + response_status = 500 return False finally: try: request_time = datetime.now() - request_start - self.cfg.post_request(self.worker, request, {}, None) + # Create response info for logging + resp = ASGIResponseInfo(response_status, response_headers, response_sent) + self.log.access(resp, request, environ, request_time) + self.cfg.post_request(self.worker, request, environ, resp) except Exception: self.log.exception("Exception in post_request hook") @@ -291,6 +319,24 @@ class ASGIProtocol(asyncio.Protocol): return scope + def _build_environ(self, request, sockname, peername): + """Build minimal WSGI-like environ dict for access logging.""" + environ = { + "REQUEST_METHOD": request.method, + "RAW_URI": request.uri, + "PATH_INFO": request.path, + "QUERY_STRING": request.query or "", + "SERVER_PROTOCOL": f"HTTP/{request.version[0]}.{request.version[1]}", + "REMOTE_ADDR": peername[0] if peername else "-", + } + + # Add HTTP headers as environ vars + for name, value in request.headers: + key = "HTTP_" + name.replace("-", "_") + environ[key] = value + + return environ + def _build_websocket_scope(self, request, sockname, peername): """Build ASGI WebSocket scope from parsed request.""" # Build headers list as bytes tuples @@ -334,9 +380,6 @@ class ASGIProtocol(asyncio.Protocol): # Build headers header_lines = [] - has_content_length = False - has_transfer_encoding = False - has_connection = False for name, value in headers: if isinstance(name, bytes): @@ -344,13 +387,6 @@ class ASGIProtocol(asyncio.Protocol): if isinstance(value, bytes): value = value.decode("latin-1") header_lines.append(f"{name}: {value}\r\n") - name_lower = name.lower() - if name_lower == "content-length": - has_content_length = True - elif name_lower == "transfer-encoding": - has_transfer_encoding = True - elif name_lower == "connection": - has_connection = True # Add server header if not present header_lines.append("Server: gunicorn/asgi\r\n") 
diff --git a/gunicorn/asgi/websocket.py b/gunicorn/asgi/websocket.py index bcde84ee..737268b6 100644 --- a/gunicorn/asgi/websocket.py +++ b/gunicorn/asgi/websocket.py @@ -12,7 +12,6 @@ import asyncio import base64 import hashlib import struct -import os # WebSocket frame opcodes @@ -81,7 +80,7 @@ class WebSocketProtocol: try: await self.app(self.scope, self._receive, self._send) - except Exception as e: + except Exception: self.log.exception("Error in WebSocket ASGI application") finally: read_task.cancel() @@ -180,7 +179,8 @@ class WebSocketProtocol: if opcode == OPCODE_CLOSE: await self._handle_close(payload) break - elif opcode == OPCODE_PING: + + if opcode == OPCODE_PING: await self._send_frame(OPCODE_PONG, payload) elif opcode == OPCODE_PONG: # Ignore pongs @@ -212,7 +212,7 @@ class WebSocketProtocol: "code": self.close_code or CLOSE_ABNORMAL, }) - async def _read_frame(self): + async def _read_frame(self): # pylint: disable=too-many-return-statements """Read a single WebSocket frame. Returns: @@ -326,10 +326,9 @@ class WebSocketProtocol: self.closed = True - async def _handle_continuation(self, payload): + async def _handle_continuation(self, payload): # pylint: disable=unused-argument """Handle continuation frame (already processed in _read_frame).""" - # This is called for partial fragments, nothing to do - pass + # This is called for partial fragments, nothing to do here async def _send_frame(self, opcode, payload): """Send a WebSocket frame. diff --git a/gunicorn/workers/gasgi.py b/gunicorn/workers/gasgi.py index b0d57cf0..118d11de 100644 --- a/gunicorn/workers/gasgi.py +++ b/gunicorn/workers/gasgi.py @@ -12,7 +12,6 @@ HTTP parsing infrastructure. 
import asyncio import os import signal -import ssl import sys from gunicorn.workers import base From 903a1fdf3cf779c8561d53ce9c9cc9f6017b4eac Mon Sep 17 00:00:00 2001 From: Benoit Chesneau Date: Thu, 22 Jan 2026 18:32:01 +0100 Subject: [PATCH 3/8] tests: Add pytest-asyncio for ASGI worker tests The ASGI worker tests use @pytest.mark.asyncio decorator which requires the pytest-asyncio plugin to be installed. --- requirements_test.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements_test.txt b/requirements_test.txt index b618d1a7..efa91f20 100644 --- a/requirements_test.txt +++ b/requirements_test.txt @@ -3,3 +3,4 @@ eventlet coverage pytest>=7.2.0 pytest-cov +pytest-asyncio From ac7296ec49cf32556c19963d20bc7ba83d51acc3 Mon Sep 17 00:00:00 2001 From: Benoit Chesneau Date: Thu, 22 Jan 2026 18:04:23 +0100 Subject: [PATCH 4/8] uwsgi: Add native uWSGI binary protocol support Add support for the uWSGI binary protocol, enabling gunicorn to work with nginx's uwsgi_pass directive. New module gunicorn/uwsgi/ with: - UWSGIRequest: Parses 4-byte binary header and key-value vars block - UWSGIParser: Protocol parser following existing Parser pattern - Error classes: InvalidUWSGIHeader, UnsupportedModifier, ForbiddenUWSGIRequest New configuration options: - --protocol: Select 'http' (default) or 'uwsgi' protocol - --uwsgi-allow-from: IP allowlist for uWSGI requests (default: localhost) Worker integration via get_parser() factory in gunicorn/http/__init__.py, updates to sync, gthread, and base_async workers. 
Example nginx config: upstream gunicorn { server 127.0.0.1:8000; } location / { uwsgi_pass gunicorn; include uwsgi_params; } --- gunicorn/config.py | 47 ++++ gunicorn/http/__init__.py | 21 +- gunicorn/uwsgi/__init__.py | 21 ++ gunicorn/uwsgi/errors.py | 46 ++++ gunicorn/uwsgi/message.py | 232 ++++++++++++++++++ gunicorn/uwsgi/parser.py | 12 + gunicorn/workers/base_async.py | 2 +- gunicorn/workers/gthread.py | 2 +- gunicorn/workers/sync.py | 2 +- tests/test_uwsgi.py | 435 +++++++++++++++++++++++++++++++++ 10 files changed, 816 insertions(+), 4 deletions(-) create mode 100644 gunicorn/uwsgi/__init__.py create mode 100644 gunicorn/uwsgi/errors.py create mode 100644 gunicorn/uwsgi/message.py create mode 100644 gunicorn/uwsgi/parser.py create mode 100644 tests/test_uwsgi.py diff --git a/gunicorn/config.py b/gunicorn/config.py index 522dcae9..1c36f987 100644 --- a/gunicorn/config.py +++ b/gunicorn/config.py @@ -2096,6 +2096,53 @@ class ProxyAllowFrom(Setting): """ +class Protocol(Setting): + name = "protocol" + section = "Server Mechanics" + cli = ["--protocol"] + meta = "STRING" + validator = validate_string + default = "http" + desc = """\ + The protocol for incoming connections. + + * ``http`` - Standard HTTP/1.x (default) + * ``uwsgi`` - uWSGI binary protocol (for nginx uwsgi_pass) + + When using the uWSGI protocol, Gunicorn can receive requests from + nginx using the uwsgi_pass directive:: + + upstream gunicorn { + server 127.0.0.1:8000; + } + location / { + uwsgi_pass gunicorn; + include uwsgi_params; + } + """ + + +class UWSGIAllowFrom(Setting): + name = "uwsgi_allow_ips" + section = "Server Mechanics" + cli = ["--uwsgi-allow-from"] + validator = validate_string_to_addr_list + default = "127.0.0.1,::1" + desc = """\ + IPs allowed to send uWSGI protocol requests (comma separated). + + Set to ``*`` to allow all IPs. 
This is useful for setups where you + don't know in advance the IP address of front-end, but instead have + ensured via other means that only your authorized front-ends can + access Gunicorn. + + .. note:: + + This option does not affect UNIX socket connections. Connections not associated with + an IP address are treated as allowed, unconditionally. + """ + + class KeyFile(Setting): name = "keyfile" section = "SSL" diff --git a/gunicorn/http/__init__.py b/gunicorn/http/__init__.py index 11473bb0..1d35b7c7 100644 --- a/gunicorn/http/__init__.py +++ b/gunicorn/http/__init__.py @@ -5,4 +5,23 @@ from gunicorn.http.message import Message, Request from gunicorn.http.parser import RequestParser -__all__ = ['Message', 'Request', 'RequestParser'] + +def get_parser(cfg, source, source_addr): + """Get appropriate parser based on protocol config. + + Args: + cfg: Gunicorn config object + source: Socket or iterable source + source_addr: Source address tuple or None + + Returns: + Parser instance (RequestParser or UWSGIParser) + """ + protocol = getattr(cfg, 'protocol', 'http') + if protocol == 'uwsgi': + from gunicorn.uwsgi.parser import UWSGIParser + return UWSGIParser(cfg, source, source_addr) + return RequestParser(cfg, source, source_addr) + + +__all__ = ['Message', 'Request', 'RequestParser', 'get_parser'] diff --git a/gunicorn/uwsgi/__init__.py b/gunicorn/uwsgi/__init__.py new file mode 100644 index 00000000..cdf4f60c --- /dev/null +++ b/gunicorn/uwsgi/__init__.py @@ -0,0 +1,21 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. 
+ +from gunicorn.uwsgi.message import UWSGIRequest +from gunicorn.uwsgi.parser import UWSGIParser +from gunicorn.uwsgi.errors import ( + UWSGIParseException, + InvalidUWSGIHeader, + UnsupportedModifier, + ForbiddenUWSGIRequest, +) + +__all__ = [ + 'UWSGIRequest', + 'UWSGIParser', + 'UWSGIParseException', + 'InvalidUWSGIHeader', + 'UnsupportedModifier', + 'ForbiddenUWSGIRequest', +] diff --git a/gunicorn/uwsgi/errors.py b/gunicorn/uwsgi/errors.py new file mode 100644 index 00000000..cdbaee21 --- /dev/null +++ b/gunicorn/uwsgi/errors.py @@ -0,0 +1,46 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +# We don't need to call super() in __init__ methods of our +# BaseException and Exception classes because we also define +# our own __str__ methods so there is no need to pass 'message' +# to the base class to get a meaningful output from 'str(exc)'. +# pylint: disable=super-init-not-called + + +class UWSGIParseException(Exception): + """Base exception for uWSGI protocol parsing errors.""" + + +class InvalidUWSGIHeader(UWSGIParseException): + """Raised when the uWSGI header is malformed.""" + + def __init__(self, msg=""): + self.msg = msg + self.code = 400 + + def __str__(self): + return "Invalid uWSGI header: %s" % self.msg + + +class UnsupportedModifier(UWSGIParseException): + """Raised when modifier1 is not 0 (WSGI request).""" + + def __init__(self, modifier): + self.modifier = modifier + self.code = 501 + + def __str__(self): + return "Unsupported uWSGI modifier1: %d" % self.modifier + + +class ForbiddenUWSGIRequest(UWSGIParseException): + """Raised when source IP is not in the allow list.""" + + def __init__(self, host): + self.host = host + self.code = 403 + + def __str__(self): + return "uWSGI request from %r not allowed" % self.host diff --git a/gunicorn/uwsgi/message.py b/gunicorn/uwsgi/message.py new file mode 100644 index 00000000..a63172eb --- /dev/null +++ b/gunicorn/uwsgi/message.py @@ -0,0 
+1,232 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +import io + +from gunicorn.http.body import LengthReader, Body +from gunicorn.uwsgi.errors import ( + InvalidUWSGIHeader, + UnsupportedModifier, + ForbiddenUWSGIRequest, +) + + +# Maximum number of variables to prevent DoS +MAX_UWSGI_VARS = 1000 + + +class UWSGIRequest: + """uWSGI protocol request parser. + + The uWSGI protocol uses a 4-byte binary header: + - Byte 0: modifier1 (packet type, 0 = WSGI request) + - Bytes 1-2: datasize (16-bit little-endian, size of vars block) + - Byte 3: modifier2 (additional flags, typically 0) + + After the header: + 1. Vars block (datasize bytes): Key-value pairs containing WSGI environ + - Each pair: 2-byte key_size (LE) + key + 2-byte val_size (LE) + value + 2. Request body (determined by CONTENT_LENGTH in vars) + """ + + def __init__(self, cfg, unreader, peer_addr, req_number=1): + self.cfg = cfg + self.unreader = unreader + self.peer_addr = peer_addr + self.remote_addr = peer_addr + self.req_number = req_number + + # Request attributes (compatible with HTTP Request interface) + self.method = None + self.uri = None + self.path = None + self.query = None + self.fragment = "" + self.version = (1, 1) # uWSGI is HTTP/1.1 compatible + self.headers = [] + self.trailers = [] + self.body = None + self.scheme = "https" if cfg.is_ssl else "http" + self.must_close = False + + # uWSGI specific + self.uwsgi_vars = {} + self.modifier1 = 0 + self.modifier2 = 0 + + # Proxy protocol compatibility + self.proxy_protocol_info = None + + # Check if the source IP is allowed + self._check_allowed_ip() + + # Parse the request + unused = self.parse(self.unreader) + self.unreader.unread(unused) + self.set_body_reader() + + def _check_allowed_ip(self): + """Verify source IP is in the allowed list.""" + allow_ips = getattr(self.cfg, 'uwsgi_allow_ips', ['127.0.0.1', '::1']) + + # UNIX sockets don't have IP addresses + if not 
isinstance(self.peer_addr, tuple): + return + + # Wildcard allows all + if '*' in allow_ips: + return + + if self.peer_addr[0] not in allow_ips: + raise ForbiddenUWSGIRequest(self.peer_addr[0]) + + def force_close(self): + """Force the connection to close after this request.""" + self.must_close = True + + def parse(self, unreader): + """Parse uWSGI packet header and vars block.""" + # Read the 4-byte header + header = self._read_exact(unreader, 4) + if len(header) < 4: + raise InvalidUWSGIHeader("incomplete header") + + self.modifier1 = header[0] + datasize = int.from_bytes(header[1:3], 'little') + self.modifier2 = header[3] + + # Only modifier1=0 (WSGI request) is supported + if self.modifier1 != 0: + raise UnsupportedModifier(self.modifier1) + + # Read the vars block + if datasize > 0: + vars_data = self._read_exact(unreader, datasize) + if len(vars_data) < datasize: + raise InvalidUWSGIHeader("incomplete vars block") + self._parse_vars(vars_data) + + # Extract HTTP request info from vars + self._extract_request_info() + + return b"" + + def _read_exact(self, unreader, size): + """Read exactly size bytes from the unreader.""" + buf = io.BytesIO() + remaining = size + + while remaining > 0: + data = unreader.read() + if not data: + break + buf.write(data) + remaining = size - buf.tell() + + result = buf.getvalue() + # Put back any extra bytes + if len(result) > size: + unreader.unread(result[size:]) + result = result[:size] + + return result + + def _parse_vars(self, data): + """Parse uWSGI vars block into key-value pairs. 
+ + Format: key_size (2 bytes LE) + key + val_size (2 bytes LE) + value + """ + pos = 0 + var_count = 0 + + while pos < len(data): + if var_count >= MAX_UWSGI_VARS: + raise InvalidUWSGIHeader("too many variables") + + # Key size (2 bytes, little-endian) + if pos + 2 > len(data): + raise InvalidUWSGIHeader("truncated key size") + key_size = int.from_bytes(data[pos:pos + 2], 'little') + pos += 2 + + # Key + if pos + key_size > len(data): + raise InvalidUWSGIHeader("truncated key") + key = data[pos:pos + key_size].decode('latin-1') + pos += key_size + + # Value size (2 bytes, little-endian) + if pos + 2 > len(data): + raise InvalidUWSGIHeader("truncated value size") + val_size = int.from_bytes(data[pos:pos + 2], 'little') + pos += 2 + + # Value + if pos + val_size > len(data): + raise InvalidUWSGIHeader("truncated value") + value = data[pos:pos + val_size].decode('latin-1') + pos += val_size + + self.uwsgi_vars[key] = value + var_count += 1 + + def _extract_request_info(self): + """Extract HTTP request info from uWSGI vars.""" + # Method + self.method = self.uwsgi_vars.get('REQUEST_METHOD', 'GET') + + # URI and path + self.path = self.uwsgi_vars.get('PATH_INFO', '/') + self.query = self.uwsgi_vars.get('QUERY_STRING', '') + + # Build URI + if self.query: + self.uri = "%s?%s" % (self.path, self.query) + else: + self.uri = self.path + + # Scheme + if self.uwsgi_vars.get('HTTPS', '').lower() in ('on', '1', 'true'): + self.scheme = 'https' + elif 'wsgi.url_scheme' in self.uwsgi_vars: + self.scheme = self.uwsgi_vars['wsgi.url_scheme'] + + # Extract HTTP headers (HTTP_* vars) + for key, value in self.uwsgi_vars.items(): + if key.startswith('HTTP_'): + # Convert HTTP_HEADER_NAME to HEADER-NAME + header_name = key[5:].replace('_', '-') + self.headers.append((header_name, value)) + elif key == 'CONTENT_TYPE': + self.headers.append(('CONTENT-TYPE', value)) + elif key == 'CONTENT_LENGTH': + self.headers.append(('CONTENT-LENGTH', value)) + + def set_body_reader(self): + """Set up 
the body reader based on CONTENT_LENGTH.""" + content_length = 0 + + # Get content length from vars + if 'CONTENT_LENGTH' in self.uwsgi_vars: + try: + content_length = max(int(self.uwsgi_vars['CONTENT_LENGTH']), 0) + except ValueError: + content_length = 0 + + self.body = Body(LengthReader(self.unreader, content_length)) + + def should_close(self): + """Determine if the connection should be closed after this request.""" + if self.must_close: + return True + + # Check HTTP_CONNECTION header + connection = self.uwsgi_vars.get('HTTP_CONNECTION', '').lower() + if connection == 'close': + return True + elif connection == 'keep-alive': + return False + + # Default to keep-alive for HTTP/1.1 + return False diff --git a/gunicorn/uwsgi/parser.py b/gunicorn/uwsgi/parser.py new file mode 100644 index 00000000..fede8c56 --- /dev/null +++ b/gunicorn/uwsgi/parser.py @@ -0,0 +1,12 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +from gunicorn.http.parser import Parser +from gunicorn.uwsgi.message import UWSGIRequest + + +class UWSGIParser(Parser): + """Parser for uWSGI protocol requests.""" + + mesg_class = UWSGIRequest diff --git a/gunicorn/workers/base_async.py b/gunicorn/workers/base_async.py index 9466d6aa..22ea09ab 100644 --- a/gunicorn/workers/base_async.py +++ b/gunicorn/workers/base_async.py @@ -32,7 +32,7 @@ class AsyncWorker(base.Worker): def handle(self, listener, client, addr): req = None try: - parser = http.RequestParser(self.cfg, client, addr) + parser = http.get_parser(self.cfg, client, addr) try: listener_name = listener.getsockname() if not self.cfg.keepalive: diff --git a/gunicorn/workers/gthread.py b/gunicorn/workers/gthread.py index 47270725..7cab9920 100644 --- a/gunicorn/workers/gthread.py +++ b/gunicorn/workers/gthread.py @@ -58,7 +58,7 @@ class TConn: self.sock = sock.ssl_wrap_socket(self.sock, self.cfg) # initialize the parser - self.parser = http.RequestParser(self.cfg, self.sock, 
self.client) + self.parser = http.get_parser(self.cfg, self.sock, self.client) def set_timeout(self): # Use monotonic clock for reliability (time.time() can jump due to NTP) diff --git a/gunicorn/workers/sync.py b/gunicorn/workers/sync.py index 4c029f91..99dbdaac 100644 --- a/gunicorn/workers/sync.py +++ b/gunicorn/workers/sync.py @@ -129,7 +129,7 @@ class SyncWorker(base.Worker): try: if self.cfg.is_ssl: client = sock.ssl_wrap_socket(client, self.cfg) - parser = http.RequestParser(self.cfg, client, addr) + parser = http.get_parser(self.cfg, client, addr) req = next(parser) self.handle_request(listener, req, client, addr) except http.errors.NoMoreData as e: diff --git a/tests/test_uwsgi.py b/tests/test_uwsgi.py new file mode 100644 index 00000000..26ff09f5 --- /dev/null +++ b/tests/test_uwsgi.py @@ -0,0 +1,435 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +import io +import pytest +from unittest import mock + +from gunicorn.uwsgi import ( + UWSGIRequest, + UWSGIParser, + UWSGIParseException, + InvalidUWSGIHeader, + UnsupportedModifier, + ForbiddenUWSGIRequest, +) +from gunicorn.http.unreader import IterUnreader + + +def make_uwsgi_packet(vars_dict, modifier1=0, modifier2=0): + """Create uWSGI packet for testing. 
+ + Args: + vars_dict: Dict of WSGI environ variables + modifier1: Packet type (0 = WSGI request) + modifier2: Additional flags + + Returns: + bytes: Complete uWSGI packet + """ + vars_data = b'' + for key, value in vars_dict.items(): + k = key.encode('latin-1') + v = value.encode('latin-1') + vars_data += len(k).to_bytes(2, 'little') + k + vars_data += len(v).to_bytes(2, 'little') + v + + header = bytes([modifier1]) + len(vars_data).to_bytes(2, 'little') + bytes([modifier2]) + return header + vars_data + + +def make_uwsgi_packet_with_body(vars_dict, body=b'', modifier1=0, modifier2=0): + """Create uWSGI packet with body for testing.""" + if body: + vars_dict = dict(vars_dict) + vars_dict['CONTENT_LENGTH'] = str(len(body)) + return make_uwsgi_packet(vars_dict, modifier1, modifier2) + body + + +class MockConfig: + """Mock config object for testing.""" + + def __init__(self, is_ssl=False, uwsgi_allow_ips=None): + self.is_ssl = is_ssl + self.uwsgi_allow_ips = uwsgi_allow_ips or ['127.0.0.1', '::1'] + + +class TestUWSGIPacketConstruction: + """Test the packet construction helper.""" + + def test_empty_vars(self): + packet = make_uwsgi_packet({}) + assert packet == b'\x00\x00\x00\x00' # modifier1=0, size=0, modifier2=0 + + def test_single_var(self): + packet = make_uwsgi_packet({'KEY': 'val'}) + # Header: modifier1(0) + size(10 in LE) + modifier2(0) + # Var: key_size(3 in LE) + 'KEY' + val_size(3 in LE) + 'val' + # Size = 2 + 3 + 2 + 3 = 10 bytes + expected_header = b'\x00\x0a\x00\x00' + expected_var = b'\x03\x00KEY\x03\x00val' + assert packet == expected_header + expected_var + + def test_multiple_vars(self): + packet = make_uwsgi_packet({'A': '1', 'B': '2'}) + assert len(packet) == 4 + (2 + 1 + 2 + 1) * 2 # header + 2 vars + + +class TestUWSGIRequest: + """Test UWSGIRequest parsing.""" + + def test_parse_simple_request(self): + """Test parsing a simple GET request.""" + packet = make_uwsgi_packet({ + 'REQUEST_METHOD': 'GET', + 'PATH_INFO': '/test', + 'QUERY_STRING': 
'foo=bar', + }) + unreader = IterUnreader([packet]) + cfg = MockConfig() + + req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + + assert req.method == 'GET' + assert req.path == '/test' + assert req.query == 'foo=bar' + assert req.uri == '/test?foo=bar' + + def test_parse_post_request_with_body(self): + """Test parsing a POST request with body.""" + body = b'name=test&value=123' + packet = make_uwsgi_packet_with_body({ + 'REQUEST_METHOD': 'POST', + 'PATH_INFO': '/submit', + 'CONTENT_TYPE': 'application/x-www-form-urlencoded', + }, body) + unreader = IterUnreader([packet]) + cfg = MockConfig() + + req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + + assert req.method == 'POST' + assert req.path == '/submit' + assert req.body.read() == body + + def test_parse_headers(self): + """Test that HTTP_* vars become headers.""" + packet = make_uwsgi_packet({ + 'REQUEST_METHOD': 'GET', + 'PATH_INFO': '/', + 'HTTP_HOST': 'example.com', + 'HTTP_USER_AGENT': 'TestClient/1.0', + 'HTTP_ACCEPT': 'text/html', + }) + unreader = IterUnreader([packet]) + cfg = MockConfig() + + req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + + headers_dict = dict(req.headers) + assert headers_dict['HOST'] == 'example.com' + assert headers_dict['USER-AGENT'] == 'TestClient/1.0' + assert headers_dict['ACCEPT'] == 'text/html' + + def test_parse_content_type_header(self): + """Test that CONTENT_TYPE becomes a header.""" + packet = make_uwsgi_packet({ + 'REQUEST_METHOD': 'POST', + 'PATH_INFO': '/', + 'CONTENT_TYPE': 'application/json', + 'CONTENT_LENGTH': '0', + }) + unreader = IterUnreader([packet]) + cfg = MockConfig() + + req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + + headers_dict = dict(req.headers) + assert headers_dict['CONTENT-TYPE'] == 'application/json' + assert headers_dict['CONTENT-LENGTH'] == '0' + + def test_https_scheme(self): + """Test scheme detection from HTTPS variable.""" + packet = make_uwsgi_packet({ + 'REQUEST_METHOD': 'GET', + 'PATH_INFO': '/', + 
'HTTPS': 'on', + }) + unreader = IterUnreader([packet]) + cfg = MockConfig() + + req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + + assert req.scheme == 'https' + + def test_wsgi_url_scheme(self): + """Test scheme from wsgi.url_scheme variable.""" + packet = make_uwsgi_packet({ + 'REQUEST_METHOD': 'GET', + 'PATH_INFO': '/', + 'wsgi.url_scheme': 'https', + }) + unreader = IterUnreader([packet]) + cfg = MockConfig() + + req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + + assert req.scheme == 'https' + + def test_default_values(self): + """Test default values when vars are missing.""" + packet = make_uwsgi_packet({}) + unreader = IterUnreader([packet]) + cfg = MockConfig() + + req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + + assert req.method == 'GET' + assert req.path == '/' + assert req.query == '' + assert req.uri == '/' + + def test_uwsgi_vars_preserved(self): + """Test that all vars are preserved in uwsgi_vars.""" + packet = make_uwsgi_packet({ + 'REQUEST_METHOD': 'GET', + 'PATH_INFO': '/', + 'SERVER_NAME': 'localhost', + 'SERVER_PORT': '8000', + 'CUSTOM_VAR': 'custom_value', + }) + unreader = IterUnreader([packet]) + cfg = MockConfig() + + req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + + assert req.uwsgi_vars['SERVER_NAME'] == 'localhost' + assert req.uwsgi_vars['SERVER_PORT'] == '8000' + assert req.uwsgi_vars['CUSTOM_VAR'] == 'custom_value' + + +class TestUWSGIRequestErrors: + """Test UWSGIRequest error handling.""" + + def test_incomplete_header(self): + """Test error on incomplete header.""" + unreader = IterUnreader([b'\x00\x00']) # Only 2 bytes + cfg = MockConfig() + + with pytest.raises(InvalidUWSGIHeader) as exc_info: + UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + assert 'incomplete header' in str(exc_info.value) + + def test_incomplete_vars_block(self): + """Test error on truncated vars block.""" + # Header says 100 bytes of vars, but we only provide 10 + header = b'\x00\x64\x00\x00' # modifier1=0, size=100, 
modifier2=0 + unreader = IterUnreader([header + b'1234567890']) + cfg = MockConfig() + + with pytest.raises(InvalidUWSGIHeader) as exc_info: + UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + assert 'incomplete vars block' in str(exc_info.value) + + def test_unsupported_modifier(self): + """Test error on non-zero modifier1.""" + packet = bytes([1]) + b'\x00\x00\x00' # modifier1=1 + unreader = IterUnreader([packet]) + cfg = MockConfig() + + with pytest.raises(UnsupportedModifier) as exc_info: + UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + assert exc_info.value.modifier == 1 + assert exc_info.value.code == 501 + + def test_truncated_key_size(self): + """Test error on truncated key size.""" + header = b'\x00\x01\x00\x00' # size=1, but need at least 2 bytes for key_size + unreader = IterUnreader([header + b'X']) + cfg = MockConfig() + + with pytest.raises(InvalidUWSGIHeader) as exc_info: + UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + assert 'truncated' in str(exc_info.value) + + def test_forbidden_ip(self): + """Test error when source IP not in allow list.""" + packet = make_uwsgi_packet({'REQUEST_METHOD': 'GET', 'PATH_INFO': '/'}) + unreader = IterUnreader([packet]) + cfg = MockConfig(uwsgi_allow_ips=['192.168.1.1']) + + with pytest.raises(ForbiddenUWSGIRequest) as exc_info: + UWSGIRequest(cfg, unreader, ('10.0.0.1', 12345)) + assert exc_info.value.code == 403 + assert '10.0.0.1' in str(exc_info.value) + + def test_allowed_ip_wildcard(self): + """Test that wildcard allows any IP.""" + packet = make_uwsgi_packet({'REQUEST_METHOD': 'GET', 'PATH_INFO': '/'}) + unreader = IterUnreader([packet]) + cfg = MockConfig(uwsgi_allow_ips=['*']) + + # Should not raise + req = UWSGIRequest(cfg, unreader, ('10.0.0.1', 12345)) + assert req.method == 'GET' + + def test_unix_socket_always_allowed(self): + """Test that UNIX socket connections are always allowed.""" + packet = make_uwsgi_packet({'REQUEST_METHOD': 'GET', 'PATH_INFO': '/'}) + unreader = 
IterUnreader([packet]) + cfg = MockConfig(uwsgi_allow_ips=['127.0.0.1']) + + # UNIX socket has non-tuple peer_addr + req = UWSGIRequest(cfg, unreader, None) + assert req.method == 'GET' + + +class TestUWSGIRequestConnection: + """Test connection handling.""" + + def test_should_close_default(self): + """Test default keep-alive behavior.""" + packet = make_uwsgi_packet({'REQUEST_METHOD': 'GET', 'PATH_INFO': '/'}) + unreader = IterUnreader([packet]) + cfg = MockConfig() + + req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + + assert req.should_close() is False + + def test_should_close_connection_close(self): + """Test Connection: close header.""" + packet = make_uwsgi_packet({ + 'REQUEST_METHOD': 'GET', + 'PATH_INFO': '/', + 'HTTP_CONNECTION': 'close', + }) + unreader = IterUnreader([packet]) + cfg = MockConfig() + + req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + + assert req.should_close() is True + + def test_should_close_connection_keepalive(self): + """Test Connection: keep-alive header.""" + packet = make_uwsgi_packet({ + 'REQUEST_METHOD': 'GET', + 'PATH_INFO': '/', + 'HTTP_CONNECTION': 'keep-alive', + }) + unreader = IterUnreader([packet]) + cfg = MockConfig() + + req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + + assert req.should_close() is False + + def test_force_close(self): + """Test force_close method.""" + packet = make_uwsgi_packet({'REQUEST_METHOD': 'GET', 'PATH_INFO': '/'}) + unreader = IterUnreader([packet]) + cfg = MockConfig() + + req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + req.force_close() + + assert req.should_close() is True + + +class TestUWSGIParser: + """Test UWSGIParser.""" + + def test_parser_iteration(self): + """Test iterating over parser for multiple requests.""" + packet = make_uwsgi_packet({ + 'REQUEST_METHOD': 'GET', + 'PATH_INFO': '/test', + 'HTTP_CONNECTION': 'close', # Single request + }) + cfg = MockConfig() + + # Parser expects an iterable source, not an unreader + parser = 
UWSGIParser(cfg, [packet], ('127.0.0.1', 12345)) + req = next(parser) + + assert req.method == 'GET' + assert req.path == '/test' + + def test_parser_mesg_class(self): + """Test that parser uses UWSGIRequest.""" + assert UWSGIParser.mesg_class is UWSGIRequest + + +class TestExceptionStrings: + """Test exception string representations.""" + + def test_invalid_uwsgi_header_str(self): + exc = InvalidUWSGIHeader("test message") + assert str(exc) == "Invalid uWSGI header: test message" + assert exc.code == 400 + + def test_unsupported_modifier_str(self): + exc = UnsupportedModifier(5) + assert str(exc) == "Unsupported uWSGI modifier1: 5" + assert exc.code == 501 + + def test_forbidden_uwsgi_request_str(self): + exc = ForbiddenUWSGIRequest("10.0.0.1") + assert str(exc) == "uWSGI request from '10.0.0.1' not allowed" + assert exc.code == 403 + + +class TestUWSGIBody: + """Test body reading.""" + + def test_read_body_in_chunks(self): + """Test reading body in multiple chunks.""" + body = b'A' * 1000 + packet = make_uwsgi_packet_with_body({ + 'REQUEST_METHOD': 'POST', + 'PATH_INFO': '/', + }, body) + unreader = IterUnreader([packet]) + cfg = MockConfig() + + req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + + result = b'' + chunk = req.body.read(100) + while chunk: + result += chunk + chunk = req.body.read(100) + + assert result == body + + def test_invalid_content_length(self): + """Test handling of invalid CONTENT_LENGTH.""" + packet = make_uwsgi_packet({ + 'REQUEST_METHOD': 'POST', + 'PATH_INFO': '/', + 'CONTENT_LENGTH': 'invalid', + }) + unreader = IterUnreader([packet]) + cfg = MockConfig() + + req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + + # Invalid content length should default to 0 + assert req.body.read() == b'' + + def test_negative_content_length(self): + """Test handling of negative CONTENT_LENGTH.""" + packet = make_uwsgi_packet({ + 'REQUEST_METHOD': 'POST', + 'PATH_INFO': '/', + 'CONTENT_LENGTH': '-5', + }) + unreader = 
IterUnreader([packet]) + cfg = MockConfig() + + req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + + # Negative content length should default to 0 + assert req.body.read() == b'' From ecc471f3b4e1732decd386ee744a21e2553d354a Mon Sep 17 00:00:00 2001 From: Benoit Chesneau Date: Thu, 22 Jan 2026 19:06:30 +0100 Subject: [PATCH 5/8] tests: Add Docker integration tests for uWSGI protocol with nginx Add comprehensive integration tests verifying gunicorn's uWSGI binary protocol works correctly with nginx's uwsgi_pass directive. Test categories: - Basic GET/POST requests with query strings and large bodies - Header preservation (custom headers, Host, Content-Type) - HTTP keep-alive connections - Error responses (400-503 status codes) - WSGI environ variables - Large response streaming (1MB) - Concurrent request handling - Edge cases (binary data, unicode, long headers) Architecture: pytest -> nginx:8080 -> uwsgi_pass -> gunicorn:8000 Also adds GitHub Actions workflow that runs on changes to uwsgi module or docker test files. 
--- .github/workflows/docker-integration.yml | 45 +++ tests/docker/uwsgi/Dockerfile.gunicorn | 16 + tests/docker/uwsgi/Dockerfile.nginx | 12 + tests/docker/uwsgi/README.md | 154 +++++++++ tests/docker/uwsgi/app.py | 222 +++++++++++++ tests/docker/uwsgi/conftest.py | 121 +++++++ tests/docker/uwsgi/docker-compose.yml | 29 ++ tests/docker/uwsgi/nginx.conf | 46 +++ tests/docker/uwsgi/test_uwsgi_integration.py | 312 +++++++++++++++++++ tests/docker/uwsgi/uwsgi_params | 16 + 10 files changed, 973 insertions(+) create mode 100644 .github/workflows/docker-integration.yml create mode 100644 tests/docker/uwsgi/Dockerfile.gunicorn create mode 100644 tests/docker/uwsgi/Dockerfile.nginx create mode 100644 tests/docker/uwsgi/README.md create mode 100644 tests/docker/uwsgi/app.py create mode 100644 tests/docker/uwsgi/conftest.py create mode 100644 tests/docker/uwsgi/docker-compose.yml create mode 100644 tests/docker/uwsgi/nginx.conf create mode 100644 tests/docker/uwsgi/test_uwsgi_integration.py create mode 100644 tests/docker/uwsgi/uwsgi_params diff --git a/.github/workflows/docker-integration.yml b/.github/workflows/docker-integration.yml new file mode 100644 index 00000000..c63c7bff --- /dev/null +++ b/.github/workflows/docker-integration.yml @@ -0,0 +1,45 @@ +name: Docker Integration Tests + +on: + push: + branches: [master] + paths: + - 'gunicorn/uwsgi/**' + - 'tests/docker/uwsgi/**' + - '.github/workflows/docker-integration.yml' + pull_request: + paths: + - 'gunicorn/uwsgi/**' + - 'tests/docker/uwsgi/**' + - '.github/workflows/docker-integration.yml' + +permissions: + contents: read + +env: + FORCE_COLOR: 1 + +jobs: + uwsgi-nginx: + name: uWSGI Protocol with nginx + runs-on: ubuntu-latest + timeout-minutes: 15 + + steps: + - uses: actions/checkout@v6 + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: "3.12" + cache: pip + cache-dependency-path: requirements_test.txt + + - name: Install test dependencies + run: | + python -m pip install 
--upgrade pip + python -m pip install pytest requests + + - name: Run uWSGI integration tests + run: | + pytest tests/docker/uwsgi/ -v --tb=short diff --git a/tests/docker/uwsgi/Dockerfile.gunicorn b/tests/docker/uwsgi/Dockerfile.gunicorn new file mode 100644 index 00000000..2fd73a74 --- /dev/null +++ b/tests/docker/uwsgi/Dockerfile.gunicorn @@ -0,0 +1,16 @@ +FROM python:3.11-slim + +WORKDIR /app + +# Copy gunicorn source +COPY . /app/gunicorn-src/ + +# Install gunicorn from source +RUN pip install --no-cache-dir /app/gunicorn-src/ + +# Copy test application +COPY tests/docker/uwsgi/app.py /app/ + +EXPOSE 8000 + +CMD ["gunicorn", "--protocol", "uwsgi", "--uwsgi-allow-from", "*", "--bind", "0.0.0.0:8000", "--workers", "2", "--log-level", "debug", "app:application"] diff --git a/tests/docker/uwsgi/Dockerfile.nginx b/tests/docker/uwsgi/Dockerfile.nginx new file mode 100644 index 00000000..e934a0f7 --- /dev/null +++ b/tests/docker/uwsgi/Dockerfile.nginx @@ -0,0 +1,12 @@ +FROM nginx:alpine + +# Remove default config +RUN rm /etc/nginx/conf.d/default.conf + +# Copy custom config +COPY nginx.conf /etc/nginx/nginx.conf +COPY uwsgi_params /etc/nginx/uwsgi_params + +EXPOSE 8080 + +CMD ["nginx", "-g", "daemon off;"] diff --git a/tests/docker/uwsgi/README.md b/tests/docker/uwsgi/README.md new file mode 100644 index 00000000..d8c78f19 --- /dev/null +++ b/tests/docker/uwsgi/README.md @@ -0,0 +1,154 @@ +# uWSGI Protocol Docker Integration Tests + +This directory contains Docker-based integration tests that verify gunicorn's +uWSGI binary protocol implementation works correctly with nginx's `uwsgi_pass` +directive. + +## Architecture + +``` +[pytest] --HTTP--> [nginx:8080] --uwsgi_pass--> [gunicorn:8000] +``` + +The tests make HTTP requests to nginx, which proxies them to gunicorn using the +uWSGI binary protocol. This validates the complete request/response cycle through +the protocol. 
+ +## Prerequisites + +- Docker +- Docker Compose (v2) +- Python 3.8+ +- pytest +- requests + +## Running Tests + +### From repository root: + +```bash +# Run all uWSGI integration tests +pytest tests/docker/uwsgi/ -v + +# Run specific test class +pytest tests/docker/uwsgi/ -v -k TestBasicRequests + +# Skip Docker tests (for CI environments without Docker) +pytest tests/ -v --ignore=tests/docker +``` + +### Manual testing: + +```bash +cd tests/docker/uwsgi + +# Start services +docker compose up -d + +# Wait for services to be healthy +docker compose ps + +# Test endpoints +curl http://localhost:8080/ +curl -X POST -d "test body" http://localhost:8080/echo +curl http://localhost:8080/headers +curl "http://localhost:8080/query?foo=bar" +curl http://localhost:8080/environ +curl http://localhost:8080/error/404 +curl http://localhost:8080/large > /dev/null # 1MB response + +# View logs +docker compose logs gunicorn +docker compose logs nginx + +# Stop services +docker compose down -v +``` + +## Test Categories + +| Category | Description | +|----------|-------------| +| `TestBasicRequests` | GET, POST, query strings, large bodies | +| `TestHeaderPreservation` | Custom headers, Host, Content-Type, User-Agent | +| `TestKeepAlive` | Multiple requests per connection | +| `TestErrorResponses` | HTTP error codes (400, 404, 500, etc.) | +| `TestEnvironVariables` | WSGI environ: REQUEST_METHOD, PATH_INFO, etc.
| +| `TestLargeResponses` | 1MB response body streaming | +| `TestConcurrency` | Parallel request handling | +| `TestSpecialCases` | Edge cases: binary data, unicode, long headers | + +## Files + +| File | Purpose | +|------|---------| +| `docker-compose.yml` | Orchestrates nginx + gunicorn containers | +| `Dockerfile.gunicorn` | Builds gunicorn image with test app | +| `Dockerfile.nginx` | Builds nginx with uwsgi config | +| `nginx.conf` | nginx configuration using `uwsgi_pass` | +| `uwsgi_params` | Standard uwsgi parameter mappings | +| `app.py` | Test WSGI application with multiple endpoints | +| `conftest.py` | pytest fixtures for Docker lifecycle | +| `test_uwsgi_integration.py` | Test cases | + +## Test App Endpoints + +| Endpoint | Method | Description | +|----------|--------|-------------| +| `/` | GET | Basic hello response | +| `/echo` | POST | Echo request body | +| `/headers` | GET/POST | Return received headers as JSON | +| `/environ` | GET/POST | Return WSGI environ as JSON | +| `/query` | GET | Return query params as JSON | +| `/json` | POST | Parse and echo JSON body | +| `/error/{code}` | GET | Return specified HTTP error | +| `/large` | GET | Return 1MB response | + +## Gunicorn Configuration + +The gunicorn container runs with: + +```bash +gunicorn \ + --protocol uwsgi \ + --uwsgi-allow-from "*" \ + --bind 0.0.0.0:8000 \ + --workers 2 \ + --log-level debug \ + app:application +``` + +Key settings: +- `--protocol uwsgi`: Enable uWSGI binary protocol +- `--uwsgi-allow-from "*"`: Accept connections from Docker network IPs + +## Troubleshooting + +### Services won't start + +Check Docker logs: +```bash +docker compose logs +``` + +### Connection refused + +Wait for health checks: +```bash +docker compose ps # Check health status +``` + +### Tests timing out + +Increase `STARTUP_TIMEOUT` in `conftest.py` or check if ports are in use: +```bash +lsof -i :8080 +lsof -i :8000 +``` + +### Rebuild after code changes + +```bash +docker compose build 
--no-cache +docker compose up -d +``` diff --git a/tests/docker/uwsgi/app.py b/tests/docker/uwsgi/app.py new file mode 100644 index 00000000..6eb681cf --- /dev/null +++ b/tests/docker/uwsgi/app.py @@ -0,0 +1,222 @@ +""" +Test WSGI application for uWSGI protocol integration tests. + +This application provides various endpoints to test different aspects +of the uWSGI binary protocol when proxied through nginx. +""" + +import json + + +def application(environ, start_response): + """Main WSGI application entry point.""" + path = environ.get('PATH_INFO', '/') + method = environ.get('REQUEST_METHOD', 'GET') + + # Route to appropriate handler + if path == '/': + return handle_root(environ, start_response) + elif path == '/echo': + return handle_echo(environ, start_response) + elif path == '/headers': + return handle_headers(environ, start_response) + elif path == '/environ': + return handle_environ(environ, start_response) + elif path.startswith('/error/'): + return handle_error(environ, start_response, path) + elif path == '/large': + return handle_large(environ, start_response) + elif path == '/json': + return handle_json(environ, start_response) + elif path == '/query': + return handle_query(environ, start_response) + else: + return handle_not_found(environ, start_response) + + +def handle_root(environ, start_response): + """Basic root endpoint.""" + status = '200 OK' + headers = [('Content-Type', 'text/plain')] + start_response(status, headers) + return [b'Hello from gunicorn uWSGI!\n'] + + +def handle_echo(environ, start_response): + """Echo back the request body.""" + try: + content_length = int(environ.get('CONTENT_LENGTH', 0)) + except (ValueError, TypeError): + content_length = 0 + + body = b'' + if content_length > 0: + body = environ['wsgi.input'].read(content_length) + + status = '200 OK' + headers = [ + ('Content-Type', 'application/octet-stream'), + ('Content-Length', str(len(body))) + ] + start_response(status, headers) + return [body] + + +def 
handle_headers(environ, start_response): + """Return received HTTP headers as JSON.""" + headers_dict = {} + for key, value in environ.items(): + if key.startswith('HTTP_'): + # Convert HTTP_X_CUSTOM_HEADER to X-Custom-Header + header_name = key[5:].replace('_', '-').title() + headers_dict[header_name] = value + + # Also include some special headers + if 'CONTENT_TYPE' in environ: + headers_dict['Content-Type'] = environ['CONTENT_TYPE'] + if 'CONTENT_LENGTH' in environ: + headers_dict['Content-Length'] = environ['CONTENT_LENGTH'] + + body = json.dumps(headers_dict, indent=2).encode('utf-8') + status = '200 OK' + headers = [ + ('Content-Type', 'application/json'), + ('Content-Length', str(len(body))) + ] + start_response(status, headers) + return [body] + + +def handle_environ(environ, start_response): + """Return WSGI environ variables as JSON.""" + # Filter to serializable values + safe_environ = {} + skip_keys = {'wsgi.input', 'wsgi.errors', 'wsgi.file_wrapper'} + + for key, value in environ.items(): + if key in skip_keys: + continue + try: + # Test if value is JSON serializable + json.dumps(value) + safe_environ[key] = value + except (TypeError, ValueError): + safe_environ[key] = str(value) + + body = json.dumps(safe_environ, indent=2).encode('utf-8') + status = '200 OK' + headers = [ + ('Content-Type', 'application/json'), + ('Content-Length', str(len(body))) + ] + start_response(status, headers) + return [body] + + +def handle_error(environ, start_response, path): + """Return specified HTTP error code.""" + try: + code = int(path.split('/')[-1]) + except ValueError: + code = 500 + + status_messages = { + 400: 'Bad Request', + 401: 'Unauthorized', + 403: 'Forbidden', + 404: 'Not Found', + 500: 'Internal Server Error', + 502: 'Bad Gateway', + 503: 'Service Unavailable', + } + + message = status_messages.get(code, 'Error') + status = f'{code} {message}' + body = json.dumps({'error': message, 'code': code}).encode('utf-8') + + headers = [ + ('Content-Type', 
'application/json'), + ('Content-Length', str(len(body))) + ] + start_response(status, headers) + return [body] + + +def handle_large(environ, start_response): + """Return a 1MB response body for testing large responses.""" + # Generate 1MB of data (1024 * 1024 bytes) + chunk_size = 1024 + num_chunks = 1024 + chunk = b'X' * chunk_size + + status = '200 OK' + headers = [ + ('Content-Type', 'application/octet-stream'), + ('Content-Length', str(chunk_size * num_chunks)) + ] + start_response(status, headers) + + # Return as generator for streaming + def generate(): + for _ in range(num_chunks): + yield chunk + + return generate() + + +def handle_json(environ, start_response): + """Handle JSON POST requests.""" + try: + content_length = int(environ.get('CONTENT_LENGTH', 0)) + except (ValueError, TypeError): + content_length = 0 + + if content_length > 0: + body = environ['wsgi.input'].read(content_length) + try: + data = json.loads(body.decode('utf-8')) + response = {'received': data, 'status': 'ok'} + except json.JSONDecodeError: + response = {'error': 'Invalid JSON', 'status': 'error'} + else: + response = {'error': 'No body', 'status': 'error'} + + body = json.dumps(response).encode('utf-8') + status = '200 OK' + headers = [ + ('Content-Type', 'application/json'), + ('Content-Length', str(len(body))) + ] + start_response(status, headers) + return [body] + + +def handle_query(environ, start_response): + """Return query string parameters as JSON.""" + from urllib.parse import parse_qs + query_string = environ.get('QUERY_STRING', '') + params = parse_qs(query_string) + + # Convert lists to single values where appropriate + simple_params = {k: v[0] if len(v) == 1 else v for k, v in params.items()} + + body = json.dumps(simple_params).encode('utf-8') + status = '200 OK' + headers = [ + ('Content-Type', 'application/json'), + ('Content-Length', str(len(body))) + ] + start_response(status, headers) + return [body] + + +def handle_not_found(environ, start_response): + 
"""Handle 404 for unknown paths.""" + body = json.dumps({'error': 'Not Found', 'path': environ.get('PATH_INFO')}).encode('utf-8') + status = '404 Not Found' + headers = [ + ('Content-Type', 'application/json'), + ('Content-Length', str(len(body))) + ] + start_response(status, headers) + return [body] diff --git a/tests/docker/uwsgi/conftest.py b/tests/docker/uwsgi/conftest.py new file mode 100644 index 00000000..a31e0de3 --- /dev/null +++ b/tests/docker/uwsgi/conftest.py @@ -0,0 +1,121 @@ +""" +pytest fixtures for uWSGI Docker integration tests. +""" + +import os +import subprocess +import time + +import pytest +import requests + + +COMPOSE_FILE = os.path.join(os.path.dirname(__file__), 'docker-compose.yml') +NGINX_URL = 'http://127.0.0.1:8080' +STARTUP_TIMEOUT = 60 # seconds + + +def is_docker_available(): + """Check if Docker is available.""" + try: + result = subprocess.run( + ['docker', 'info'], + capture_output=True, + timeout=10 + ) + return result.returncode == 0 + except (subprocess.TimeoutExpired, FileNotFoundError): + return False + + +def is_compose_available(): + """Check if docker compose is available.""" + try: + result = subprocess.run( + ['docker', 'compose', 'version'], + capture_output=True, + timeout=10 + ) + return result.returncode == 0 + except (subprocess.TimeoutExpired, FileNotFoundError): + return False + + +docker_available = pytest.mark.skipif( + not is_docker_available() or not is_compose_available(), + reason="Docker or docker compose not available" +) + + +@pytest.fixture(scope='session') +def docker_services(): + """ + Start Docker Compose services for the test session. + + This fixture builds and starts the gunicorn and nginx containers, + waits for them to be healthy, and tears them down after all tests. 
+ """ + if not is_docker_available() or not is_compose_available(): + pytest.skip("Docker or docker compose not available") + + # Build and start services + subprocess.run( + ['docker', 'compose', '-f', COMPOSE_FILE, 'build'], + check=True, + capture_output=True + ) + + subprocess.run( + ['docker', 'compose', '-f', COMPOSE_FILE, 'up', '-d'], + check=True, + capture_output=True + ) + + # Wait for services to be healthy + start_time = time.time() + while time.time() - start_time < STARTUP_TIMEOUT: + try: + response = requests.get(f'{NGINX_URL}/', timeout=2) + if response.status_code == 200: + break + except requests.RequestException: + pass + time.sleep(1) + else: + # Get logs for debugging + logs = subprocess.run( + ['docker', 'compose', '-f', COMPOSE_FILE, 'logs'], + capture_output=True, + text=True + ) + subprocess.run( + ['docker', 'compose', '-f', COMPOSE_FILE, 'down', '-v'], + capture_output=True + ) + pytest.fail( + f"Services did not become healthy within {STARTUP_TIMEOUT}s.\n" + f"Logs:\n{logs.stdout}\n{logs.stderr}" + ) + + yield + + # Teardown + subprocess.run( + ['docker', 'compose', '-f', COMPOSE_FILE, 'down', '-v'], + capture_output=True + ) + + +@pytest.fixture +def nginx_url(docker_services): + """Return the nginx base URL.""" + return NGINX_URL + + +@pytest.fixture +def session(docker_services): + """Return a requests Session with keep-alive enabled.""" + with requests.Session() as s: + # Enable keep-alive + s.headers['Connection'] = 'keep-alive' + yield s diff --git a/tests/docker/uwsgi/docker-compose.yml b/tests/docker/uwsgi/docker-compose.yml new file mode 100644 index 00000000..71c30355 --- /dev/null +++ b/tests/docker/uwsgi/docker-compose.yml @@ -0,0 +1,29 @@ +services: + gunicorn: + build: + context: ../../.. 
+ dockerfile: tests/docker/uwsgi/Dockerfile.gunicorn + expose: + - "8000" + healthcheck: + test: ["CMD", "python", "-c", "import socket; s=socket.socket(); s.connect(('localhost', 8000)); s.close()"] + interval: 2s + timeout: 5s + retries: 10 + start_period: 5s + + nginx: + build: + context: . + dockerfile: Dockerfile.nginx + ports: + - "8080:8080" + depends_on: + gunicorn: + condition: service_healthy + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8080/"] + interval: 2s + timeout: 5s + retries: 10 + start_period: 5s diff --git a/tests/docker/uwsgi/nginx.conf b/tests/docker/uwsgi/nginx.conf new file mode 100644 index 00000000..052f4f81 --- /dev/null +++ b/tests/docker/uwsgi/nginx.conf @@ -0,0 +1,46 @@ +worker_processes 1; + +events { + worker_connections 1024; +} + +http { + include /etc/nginx/mime.types; + default_type application/octet-stream; + + log_format main '$remote_addr - $remote_user [$time_local] "$request" ' + '$status $body_bytes_sent "$http_referer" ' + '"$http_user_agent"'; + + access_log /var/log/nginx/access.log main; + error_log /var/log/nginx/error.log debug; + + sendfile on; + keepalive_timeout 65; + + upstream gunicorn { + server gunicorn:8000; + } + + server { + listen 8080; + server_name localhost; + + # Increase buffer sizes for large headers + uwsgi_buffer_size 32k; + uwsgi_buffers 8 32k; + uwsgi_busy_buffers_size 64k; + + # Read timeout for large responses + uwsgi_read_timeout 300s; + + location / { + uwsgi_pass gunicorn; + include uwsgi_params; + + # Pass additional headers + uwsgi_param HTTP_X_FORWARDED_FOR $proxy_add_x_forwarded_for; + uwsgi_param HTTP_X_REAL_IP $remote_addr; + } + } +} diff --git a/tests/docker/uwsgi/test_uwsgi_integration.py b/tests/docker/uwsgi/test_uwsgi_integration.py new file mode 100644 index 00000000..eea9a9e5 --- /dev/null +++ b/tests/docker/uwsgi/test_uwsgi_integration.py @@ -0,0 +1,312 @@ +""" +Integration tests for gunicorn's uWSGI binary protocol with nginx. 
+ +These tests verify that gunicorn correctly implements the uWSGI binary +protocol by running actual requests through nginx's uwsgi_pass directive. +""" + +import concurrent.futures +import json + +import pytest +import requests + +from conftest import docker_available + + +@docker_available +class TestBasicRequests: + """Test basic HTTP request handling through uWSGI protocol.""" + + def test_get_root(self, nginx_url): + """Test basic GET request to root endpoint.""" + response = requests.get(f'{nginx_url}/') + assert response.status_code == 200 + assert b'Hello from gunicorn uWSGI!' in response.content + + def test_get_with_query_string(self, nginx_url): + """Test GET request with query string parameters.""" + response = requests.get(f'{nginx_url}/query?foo=bar&baz=qux') + assert response.status_code == 200 + data = response.json() + assert data['foo'] == 'bar' + assert data['baz'] == 'qux' + + def test_post_echo(self, nginx_url): + """Test POST request with body echo.""" + test_body = b'This is a test body content' + response = requests.post(f'{nginx_url}/echo', data=test_body) + assert response.status_code == 200 + assert response.content == test_body + + def test_post_json(self, nginx_url): + """Test POST request with JSON body.""" + test_data = {'key': 'value', 'number': 42, 'nested': {'a': 1}} + response = requests.post( + f'{nginx_url}/json', + json=test_data, + headers={'Content-Type': 'application/json'} + ) + assert response.status_code == 200 + data = response.json() + assert data['status'] == 'ok' + assert data['received'] == test_data + + def test_post_large_body(self, nginx_url): + """Test POST with large request body (100KB).""" + large_body = b'X' * (100 * 1024) + response = requests.post(f'{nginx_url}/echo', data=large_body) + assert response.status_code == 200 + assert len(response.content) == len(large_body) + assert response.content == large_body + + +@docker_available +class TestHeaderPreservation: + """Test that headers are correctly passed 
through uWSGI protocol.""" + + def test_custom_headers(self, nginx_url): + """Test custom headers are passed to the application.""" + custom_headers = { + 'X-Custom-Header': 'custom-value', + 'X-Another-Header': 'another-value' + } + response = requests.get(f'{nginx_url}/headers', headers=custom_headers) + assert response.status_code == 200 + data = response.json() + assert data.get('X-Custom-Header') == 'custom-value' + assert data.get('X-Another-Header') == 'another-value' + + def test_host_header(self, nginx_url): + """Test Host header is passed correctly.""" + response = requests.get( + f'{nginx_url}/headers', + headers={'Host': 'test.example.com'} + ) + assert response.status_code == 200 + data = response.json() + assert data.get('Host') == 'test.example.com' + + def test_content_type_header(self, nginx_url): + """Test Content-Type header is passed correctly.""" + response = requests.post( + f'{nginx_url}/headers', + data='test', + headers={'Content-Type': 'application/x-custom-type'} + ) + assert response.status_code == 200 + data = response.json() + assert data.get('Content-Type') == 'application/x-custom-type' + + def test_user_agent_header(self, nginx_url): + """Test User-Agent header is passed correctly.""" + response = requests.get( + f'{nginx_url}/headers', + headers={'User-Agent': 'TestAgent/1.0'} + ) + assert response.status_code == 200 + data = response.json() + assert data.get('User-Agent') == 'TestAgent/1.0' + + +@docker_available +class TestKeepAlive: + """Test HTTP keep-alive with multiple requests per connection.""" + + def test_multiple_requests_same_session(self, session, nginx_url): + """Test multiple requests using same session/connection.""" + for i in range(5): + response = session.get(f'{nginx_url}/') + assert response.status_code == 200 + + def test_mixed_requests_same_session(self, session, nginx_url): + """Test mixed GET and POST requests using same session.""" + # GET request + response = session.get(f'{nginx_url}/') + assert 
response.status_code == 200 + + # POST request + response = session.post(f'{nginx_url}/echo', data=b'test') + assert response.status_code == 200 + assert response.content == b'test' + + # Another GET + response = session.get(f'{nginx_url}/headers') + assert response.status_code == 200 + + # JSON POST + response = session.post(f'{nginx_url}/json', json={'test': 1}) + assert response.status_code == 200 + + +@docker_available +class TestErrorResponses: + """Test HTTP error responses through uWSGI protocol.""" + + @pytest.mark.parametrize('code', [400, 401, 403, 404, 500, 502, 503]) + def test_error_codes(self, nginx_url, code): + """Test various HTTP error codes are returned correctly.""" + response = requests.get(f'{nginx_url}/error/{code}') + assert response.status_code == code + data = response.json() + assert data['code'] == code + + def test_not_found(self, nginx_url): + """Test 404 for non-existent path.""" + response = requests.get(f'{nginx_url}/nonexistent/path') + assert response.status_code == 404 + data = response.json() + assert data['error'] == 'Not Found' + assert data['path'] == '/nonexistent/path' + + +@docker_available +class TestEnvironVariables: + """Test WSGI environ variables are correctly set.""" + + def test_request_method(self, nginx_url): + """Test REQUEST_METHOD is set correctly.""" + response = requests.get(f'{nginx_url}/environ') + assert response.status_code == 200 + data = response.json() + assert data.get('REQUEST_METHOD') == 'GET' + + response = requests.post(f'{nginx_url}/environ', data='') + data = response.json() + assert data.get('REQUEST_METHOD') == 'POST' + + def test_path_info(self, nginx_url): + """Test PATH_INFO is set correctly.""" + response = requests.get(f'{nginx_url}/environ') + assert response.status_code == 200 + data = response.json() + assert data.get('PATH_INFO') == '/environ' + + def test_query_string(self, nginx_url): + """Test QUERY_STRING is set correctly.""" + response = 
requests.get(f'{nginx_url}/environ?foo=bar&test=123') + assert response.status_code == 200 + data = response.json() + assert data.get('QUERY_STRING') == 'foo=bar&test=123' + + def test_server_protocol(self, nginx_url): + """Test SERVER_PROTOCOL is set.""" + response = requests.get(f'{nginx_url}/environ') + assert response.status_code == 200 + data = response.json() + assert 'SERVER_PROTOCOL' in data + assert data['SERVER_PROTOCOL'].startswith('HTTP/') + + def test_content_length(self, nginx_url): + """Test CONTENT_LENGTH is set for POST requests.""" + body = 'test body content' + response = requests.post(f'{nginx_url}/environ', data=body) + assert response.status_code == 200 + data = response.json() + assert data.get('CONTENT_LENGTH') == str(len(body)) + + +@docker_available +class TestLargeResponses: + """Test large response handling through uWSGI protocol.""" + + def test_1mb_response(self, nginx_url): + """Test 1MB response body is received correctly.""" + response = requests.get(f'{nginx_url}/large') + assert response.status_code == 200 + assert len(response.content) == 1024 * 1024 + # Verify content is all 'X' characters + assert response.content == b'X' * (1024 * 1024) + + def test_large_response_content_length(self, nginx_url): + """Test Content-Length header for large response.""" + response = requests.get(f'{nginx_url}/large') + assert response.status_code == 200 + assert response.headers.get('Content-Length') == str(1024 * 1024) + + +@docker_available +class TestConcurrency: + """Test concurrent request handling.""" + + def test_parallel_requests(self, nginx_url): + """Test handling multiple parallel requests.""" + num_requests = 20 + + def make_request(i): + response = requests.get(f'{nginx_url}/query?id={i}') + return response.status_code, response.json().get('id') + + with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(make_request, i) for i in range(num_requests)] + results = [f.result() for f in 
concurrent.futures.as_completed(futures)] + + # All requests should succeed + assert all(status == 200 for status, _ in results) + # All IDs should be present + ids = set(id_val for _, id_val in results) + assert ids == set(str(i) for i in range(num_requests)) + + def test_parallel_mixed_requests(self, nginx_url): + """Test parallel GET and POST requests.""" + def get_request(): + return requests.get(f'{nginx_url}/').status_code + + def post_request(data): + response = requests.post(f'{nginx_url}/echo', data=data) + return response.status_code, response.content + + with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: + get_futures = [executor.submit(get_request) for _ in range(10)] + post_futures = [ + executor.submit(post_request, f'data-{i}'.encode()) + for i in range(10) + ] + + get_results = [f.result() for f in get_futures] + post_results = [f.result() for f in post_futures] + + assert all(status == 200 for status in get_results) + assert all(status == 200 for status, _ in post_results) + + +@docker_available +class TestSpecialCases: + """Test edge cases and special scenarios.""" + + def test_empty_body_post(self, nginx_url): + """Test POST with empty body.""" + response = requests.post(f'{nginx_url}/echo', data=b'') + assert response.status_code == 200 + assert response.content == b'' + + def test_binary_body(self, nginx_url): + """Test POST with binary body containing null bytes.""" + binary_data = bytes(range(256)) + response = requests.post(f'{nginx_url}/echo', data=binary_data) + assert response.status_code == 200 + assert response.content == binary_data + + def test_unicode_in_query_string(self, nginx_url): + """Test unicode characters in query string.""" + response = requests.get(f'{nginx_url}/query', params={'name': 'test'}) + assert response.status_code == 200 + data = response.json() + assert data.get('name') == 'test' + + def test_special_characters_in_path(self, nginx_url): + """Test handling of special path that triggers 404.""" 
+ # This should return 404 since the path doesn't exist + response = requests.get(f'{nginx_url}/path/with/slashes') + assert response.status_code == 404 + + def test_long_header_value(self, nginx_url): + """Test handling of long header values.""" + long_value = 'X' * 4096 # 4KB header value + response = requests.get( + f'{nginx_url}/headers', + headers={'X-Long-Header': long_value} + ) + assert response.status_code == 200 + data = response.json() + assert data.get('X-Long-Header') == long_value diff --git a/tests/docker/uwsgi/uwsgi_params b/tests/docker/uwsgi/uwsgi_params new file mode 100644 index 00000000..5abf809b --- /dev/null +++ b/tests/docker/uwsgi/uwsgi_params @@ -0,0 +1,16 @@ +uwsgi_param QUERY_STRING $query_string; +uwsgi_param REQUEST_METHOD $request_method; +uwsgi_param CONTENT_TYPE $content_type; +uwsgi_param CONTENT_LENGTH $content_length; + +uwsgi_param REQUEST_URI $request_uri; +uwsgi_param PATH_INFO $document_uri; +uwsgi_param DOCUMENT_ROOT $document_root; +uwsgi_param SERVER_PROTOCOL $server_protocol; +uwsgi_param REQUEST_SCHEME $scheme; +uwsgi_param HTTPS $https if_not_empty; + +uwsgi_param REMOTE_ADDR $remote_addr; +uwsgi_param REMOTE_PORT $remote_port; +uwsgi_param SERVER_PORT $server_port; +uwsgi_param SERVER_NAME $server_name; From 99ffa0cc6b1bb199c27cf42d5ececf3487274e41 Mon Sep 17 00:00:00 2001 From: Benoit Chesneau Date: Thu, 22 Jan 2026 19:24:30 +0100 Subject: [PATCH 6/8] tests: Exclude docker tests from regular pytest runs - Add tests/docker to norecursedirs in pyproject.toml to prevent docker tests from running during regular test suite (they require docker and the requests library) - Add -p no:cov to docker integration workflow to disable coverage plugin since pytest-cov is not installed in that environment --- .github/workflows/docker-integration.yml | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/docker-integration.yml b/.github/workflows/docker-integration.yml index 
c63c7bff..2e918780 100644 --- a/.github/workflows/docker-integration.yml +++ b/.github/workflows/docker-integration.yml @@ -42,4 +42,4 @@ jobs: - name: Run uWSGI integration tests run: | - pytest tests/docker/uwsgi/ -v --tb=short + pytest tests/docker/uwsgi/ -v --tb=short -p no:cov diff --git a/pyproject.toml b/pyproject.toml index 7803dc55..3fecbd30 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,7 +71,7 @@ main = "gunicorn.app.pasterapp:serve" [tool.pytest.ini_options] # # can override these: python -m pytest --override-ini="addopts=" -norecursedirs = ["examples", "lib", "local", "src"] +norecursedirs = ["examples", "lib", "local", "src", "tests/docker"] testpaths = ["tests/"] addopts = "--assert=plain --cov=gunicorn --cov-report=xml" From 1521266e2fda309f590dac365b47521615279b0e Mon Sep 17 00:00:00 2001 From: Benoit Chesneau Date: Thu, 22 Jan 2026 19:28:11 +0100 Subject: [PATCH 7/8] asgi/uwsgi: Address PR review feedback - asgi: Check HTTP method is GET for WebSocket upgrade per RFC 6455 Section 4.1. Previously HEAD and other methods with upgrade headers could trigger WebSocket handling. - uwsgi: Add detailed docstring explaining header mapping from CGI-style environment variables to HTTP headers, including the lossy nature of underscore-to-hyphen conversion. --- gunicorn/asgi/protocol.py | 12 +++++++++++- gunicorn/uwsgi/message.py | 27 +++++++++++++++++++++++++-- 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/gunicorn/asgi/protocol.py b/gunicorn/asgi/protocol.py index 0eb1d045..01569ce4 100644 --- a/gunicorn/asgi/protocol.py +++ b/gunicorn/asgi/protocol.py @@ -143,7 +143,17 @@ class ASGIProtocol(asyncio.Protocol): self._close_transport() def _is_websocket_upgrade(self, request): - """Check if request is a WebSocket upgrade.""" + """Check if request is a WebSocket upgrade. 
+ + Per RFC 6455 Section 4.1, the opening handshake requires: + - HTTP method MUST be GET + - Upgrade header MUST be "websocket" (case-insensitive) + - Connection header MUST contain "Upgrade" + """ + # RFC 6455: The method of the request MUST be GET + if request.method != "GET": + return False + upgrade = None connection = None for name, value in request.headers: diff --git a/gunicorn/uwsgi/message.py b/gunicorn/uwsgi/message.py index a63172eb..db69b5e2 100644 --- a/gunicorn/uwsgi/message.py +++ b/gunicorn/uwsgi/message.py @@ -172,7 +172,29 @@ class UWSGIRequest: var_count += 1 def _extract_request_info(self): - """Extract HTTP request info from uWSGI vars.""" + """Extract HTTP request info from uWSGI vars. + + Header Mapping (CGI/WSGI to HTTP): + + The uWSGI protocol passes HTTP headers using CGI-style environment + variable naming. This method converts them back to HTTP header format: + + - HTTP_* vars: Strip 'HTTP_' prefix, replace '_' with '-' + Example: HTTP_X_FORWARDED_FOR -> X-FORWARDED-FOR + Example: HTTP_ACCEPT_ENCODING -> ACCEPT-ENCODING + + - CONTENT_TYPE: Mapped directly to CONTENT-TYPE header + (CGI spec excludes HTTP_ prefix for this header) + + - CONTENT_LENGTH: Mapped directly to CONTENT-LENGTH header + (CGI spec excludes HTTP_ prefix for this header) + + Note: The underscore-to-hyphen conversion is lossy. Headers that + originally contained underscores (e.g., X_Custom_Header) cannot be + distinguished from hyphenated headers (X-Custom-Header) after + passing through nginx/uWSGI. This is a CGI/WSGI specification + limitation, not specific to this implementation. 
+ """ # Method self.method = self.uwsgi_vars.get('REQUEST_METHOD', 'GET') @@ -192,7 +214,8 @@ class UWSGIRequest: elif 'wsgi.url_scheme' in self.uwsgi_vars: self.scheme = self.uwsgi_vars['wsgi.url_scheme'] - # Extract HTTP headers (HTTP_* vars) + # Extract HTTP headers from CGI-style vars + # See docstring above for mapping details for key, value in self.uwsgi_vars.items(): if key.startswith('HTTP_'): # Convert HTTP_HEADER_NAME to HEADER-NAME From 81b653457c72b506ab85b66ef92c78f101f59dbf Mon Sep 17 00:00:00 2001 From: Benoit Chesneau Date: Thu, 22 Jan 2026 19:33:57 +0100 Subject: [PATCH 8/8] ci: Fix test dependencies for Docker and FreeBSD workflows - Docker integration: Install pytest-cov to support coverage addopts - FreeBSD: Install pytest-asyncio for ASGI async test support --- .github/workflows/docker-integration.yml | 4 ++-- .github/workflows/freebsd.yml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/docker-integration.yml b/.github/workflows/docker-integration.yml index 2e918780..1333b103 100644 --- a/.github/workflows/docker-integration.yml +++ b/.github/workflows/docker-integration.yml @@ -38,8 +38,8 @@ jobs: - name: Install test dependencies run: | python -m pip install --upgrade pip - python -m pip install pytest requests + python -m pip install pytest pytest-cov requests - name: Run uWSGI integration tests run: | - pytest tests/docker/uwsgi/ -v --tb=short -p no:cov + pytest tests/docker/uwsgi/ -v --tb=short diff --git a/.github/workflows/freebsd.yml b/.github/workflows/freebsd.yml index 060f40ca..120cc909 100644 --- a/.github/workflows/freebsd.yml +++ b/.github/workflows/freebsd.yml @@ -40,7 +40,7 @@ jobs: python${{ matrix.python-version }} -m venv venv . venv/bin/activate pip install --upgrade pip - pip install pytest pytest-cov coverage + pip install pytest pytest-cov pytest-asyncio coverage pip install -e . pytest --cov=gunicorn -v tests/ \ --ignore=tests/workers/test_ggevent.py \