diff --git a/examples/asgi/__init__.py b/examples/asgi/__init__.py new file mode 100644 index 00000000..1c9ecbeb --- /dev/null +++ b/examples/asgi/__init__.py @@ -0,0 +1,7 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +ASGI example applications for gunicorn. +""" diff --git a/examples/asgi/basic_app.py b/examples/asgi/basic_app.py new file mode 100644 index 00000000..73a160fe --- /dev/null +++ b/examples/asgi/basic_app.py @@ -0,0 +1,130 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +Basic ASGI application example. + +Run with: + gunicorn -k asgi examples.asgi.basic_app:app + +Test with: + curl http://127.0.0.1:8000/ + curl http://127.0.0.1:8000/hello + curl -X POST http://127.0.0.1:8000/echo -d "test data" +""" + + +async def app(scope, receive, send): + """Simple ASGI application demonstrating basic functionality.""" + + if scope["type"] == "lifespan": + await handle_lifespan(scope, receive, send) + elif scope["type"] == "http": + await handle_http(scope, receive, send) + else: + raise ValueError(f"Unknown scope type: {scope['type']}") + + +async def handle_lifespan(scope, receive, send): + """Handle lifespan events (startup/shutdown).""" + while True: + message = await receive() + if message["type"] == "lifespan.startup": + print("ASGI application starting up...") + await send({"type": "lifespan.startup.complete"}) + elif message["type"] == "lifespan.shutdown": + print("ASGI application shutting down...") + await send({"type": "lifespan.shutdown.complete"}) + return + + +async def handle_http(scope, receive, send): + """Handle HTTP requests.""" + path = scope["path"] + method = scope["method"] + + if path == "/" and method == "GET": + await send_response(send, 200, b"Welcome to gunicorn ASGI!\n") + + elif path == "/hello" and method == "GET": + name = get_query_param(scope, "name", "World") + body = f"Hello, {name}!\n".encode() + await send_response(send, 200, body) + + elif path == "/echo" and method == "POST": + body = await read_body(receive) + await send_response(send, 200, body, content_type=b"application/octet-stream") + + elif path == "/headers": + headers_info = format_headers(scope["headers"]) + await send_response(send, 200, headers_info.encode()) + + elif path == "/info": + info = format_request_info(scope) + await send_response(send, 200, info.encode(), content_type=b"application/json") + + else: + await send_response(send, 404, b"Not Found\n") + + +async def send_response(send, status, body, content_type=b"text/plain"): + """Send an HTTP response.""" + await send({ + "type": "http.response.start", + "status": status, + "headers": [ + (b"content-type", content_type), + (b"content-length", str(len(body)).encode()), + ], + }) + await send({ + "type": "http.response.body", + "body": body, + }) + + +async def read_body(receive): + """Read the full request body.""" + body = b"" + while True: + message = await receive() + body += message.get("body", b"") + if not message.get("more_body", False): + break + return body + + +def get_query_param(scope, name, default=None): + """Get a query parameter value.""" + query_string = scope.get("query_string", b"").decode() + for param in query_string.split("&"): + if "=" in param: + key, value = param.split("=", 1) + if key == name: + return value + return default + + +def format_headers(headers): + """Format headers for display.""" + lines = ["Request Headers:"] + for name, value in headers: + lines.append(f" {name.decode()}: {value.decode()}") + return "\n".join(lines) + "\n" + + +def format_request_info(scope): + """Format request info as JSON.""" + import json + info = { + "method": scope["method"], + "path": scope["path"], + "query_string": scope.get("query_string", b"").decode(), + "http_version": scope["http_version"], + "scheme": scope["scheme"], + "server": list(scope.get("server") or []), + "client": list(scope.get("client") or []), + "root_path": scope.get("root_path", ""), + } + return json.dumps(info, indent=2) + "\n" diff --git a/examples/asgi/websocket_app.py b/examples/asgi/websocket_app.py new file mode 100644 index 00000000..8423c30e --- /dev/null +++ b/examples/asgi/websocket_app.py @@ -0,0 +1,235 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +WebSocket ASGI application example. + +Run with: + gunicorn -k asgi examples.asgi.websocket_app:app + +Test with: + # Using websocat (install with: cargo install websocat) + websocat ws://127.0.0.1:8000/ws + + # Or using Python websockets library + python -c " + import asyncio + import websockets + async def test(): + async with websockets.connect('ws://127.0.0.1:8000/ws') as ws: + await ws.send('Hello') + print(await ws.recv()) + asyncio.run(test()) + " +""" + + +async def app(scope, receive, send): + """ASGI application with WebSocket support.""" + + if scope["type"] == "lifespan": + await handle_lifespan(scope, receive, send) + elif scope["type"] == "http": + await handle_http(scope, receive, send) + elif scope["type"] == "websocket": + await handle_websocket(scope, receive, send) + else: + raise ValueError(f"Unknown scope type: {scope['type']}") + + +async def handle_lifespan(scope, receive, send): + """Handle lifespan events.""" + while True: + message = await receive() + if message["type"] == "lifespan.startup": + await send({"type": "lifespan.startup.complete"}) + elif message["type"] == "lifespan.shutdown": + await send({"type": "lifespan.shutdown.complete"}) + return + + +async def handle_http(scope, receive, send): + """Handle HTTP requests - serve a simple HTML page for WebSocket testing.""" + path = scope["path"] + + if path == "/": + html = HTML_PAGE.encode() + await send({ + "type": "http.response.start", + "status": 200, + "headers": [ + (b"content-type", b"text/html"), + (b"content-length", str(len(html)).encode()), + ], + }) + await send({ + "type": "http.response.body", + "body": html, + }) + else: + await send({ + "type": "http.response.start", + "status": 404, + "headers": [(b"content-type", b"text/plain")], + }) + await send({ + "type": "http.response.body", + "body": b"Not Found", + }) + + +async def handle_websocket(scope, receive, send): + """Handle WebSocket connections.""" + path = scope["path"] + + if path == "/ws": + await echo_websocket(scope, receive, send) + elif path == "/ws/chat": + await chat_websocket(scope, receive, send) + else: + # Reject the connection + await send({"type": "websocket.close", "code": 4004}) + + +async def echo_websocket(scope, receive, send): + """Echo WebSocket - sends back whatever it receives.""" + # Wait for connection + message = await receive() + if message["type"] != "websocket.connect": + return + + # Accept the connection + await send({"type": "websocket.accept"}) + + # Echo loop + try: + while True: + message = await receive() + + if message["type"] == "websocket.disconnect": + break + + if message["type"] == "websocket.receive": + if "text" in message: + # Echo text back + await send({ + "type": "websocket.send", + "text": f"Echo: {message['text']}" + }) + elif "bytes" in message: + # Echo bytes back + await send({ + "type": "websocket.send", + "bytes": message["bytes"] + }) + except Exception as e: + print(f"WebSocket error: {e}") + finally: + try: + await send({"type": "websocket.close", "code": 1000}) + except Exception: + pass + + +async def chat_websocket(scope, receive, send): + """Chat WebSocket - simple broadcast example.""" + message = await receive() + if message["type"] != "websocket.connect": + return + + await send({ + "type": "websocket.accept", + "subprotocol": "chat" + }) + + await send({ + "type": "websocket.send", + "text": "Welcome to the chat! Send messages and they will be echoed back." + }) + + try: + while True: + message = await receive() + + if message["type"] == "websocket.disconnect": + break + + if message["type"] == "websocket.receive" and "text" in message: + text = message["text"] + await send({ + "type": "websocket.send", + "text": f"[You]: {text}" + }) + except Exception: + pass + + +HTML_PAGE = """ + + + WebSocket Test + + + +

WebSocket Test

+
+ + + + + + + + +""" diff --git a/gunicorn/asgi/__init__.py b/gunicorn/asgi/__init__.py new file mode 100644 index 00000000..c2f13b2a --- /dev/null +++ b/gunicorn/asgi/__init__.py @@ -0,0 +1,26 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +ASGI support for gunicorn. + +This module provides native ASGI worker support, using gunicorn's own +HTTP parsing infrastructure adapted for async I/O. + +Components: +- AsyncUnreader: Async socket reading with pushback buffer +- AsyncRequest: Async HTTP request parser +- ASGIProtocol: asyncio.Protocol implementation for HTTP handling +- WebSocketProtocol: WebSocket protocol handler (RFC 6455) +- LifespanManager: ASGI lifespan protocol support + +Usage: + gunicorn -k asgi myapp:app +""" + +from gunicorn.asgi.unreader import AsyncUnreader +from gunicorn.asgi.message import AsyncRequest +from gunicorn.asgi.lifespan import LifespanManager + +__all__ = ['AsyncUnreader', 'AsyncRequest', 'LifespanManager'] diff --git a/gunicorn/asgi/lifespan.py b/gunicorn/asgi/lifespan.py new file mode 100644 index 00000000..9811cf56 --- /dev/null +++ b/gunicorn/asgi/lifespan.py @@ -0,0 +1,178 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +ASGI lifespan protocol manager. + +Manages startup and shutdown events for ASGI applications, +enabling frameworks like FastAPI to run initialization and +cleanup code. +""" + +import asyncio + + +class LifespanManager: + """Manages ASGI lifespan events (startup/shutdown). + + The lifespan protocol allows ASGI applications to run code at + startup and shutdown. This is essential for applications that + need to initialize database connections, caches, or other + resources. + + ASGI lifespan messages: + - Server sends: {"type": "lifespan.startup"} + - App responds: {"type": "lifespan.startup.complete"} or + {"type": "lifespan.startup.failed", "message": "..."} + - Server sends: {"type": "lifespan.shutdown"} + - App responds: {"type": "lifespan.shutdown.complete"} + """ + + def __init__(self, app, logger, state=None): + """Initialize the lifespan manager. + + Args: + app: ASGI application callable + logger: Logger instance + state: Shared state dict for the application + """ + self.app = app + self.logger = logger + self.state = state if state is not None else {} + + self._startup_complete = asyncio.Event() + self._shutdown_complete = asyncio.Event() + self._startup_failed = False + self._startup_error = None + self._shutdown_error = None + self._receive_queue = asyncio.Queue() + self._task = None + self._app_finished = False + + async def startup(self): + """Run lifespan startup and wait for completion. + + Raises: + RuntimeError: If startup fails or app doesn't support lifespan + """ + scope = { + "type": "lifespan", + "asgi": {"version": "3.0", "spec_version": "2.4"}, + "state": self.state, + } + + # Send startup event + await self._receive_queue.put({"type": "lifespan.startup"}) + + # Run lifespan in background task + self._task = asyncio.create_task(self._run_lifespan(scope)) + + # Wait for startup with timeout + try: + await asyncio.wait_for( + self._startup_complete.wait(), + timeout=30.0 # Reasonable startup timeout + ) + except asyncio.TimeoutError: + if self._task: + self._task.cancel() + raise RuntimeError("Lifespan startup timed out") + + if self._startup_failed: + if self._task: + self._task.cancel() + msg = self._startup_error or "Unknown error" + raise RuntimeError(f"Lifespan startup failed: {msg}") + + self.logger.debug("ASGI lifespan startup complete") + + async def shutdown(self): + """Signal shutdown and wait for completion. + + This should be called during graceful shutdown. + """ + if self._app_finished: + self.logger.debug("ASGI lifespan already finished") + return + + # Send shutdown event + await self._receive_queue.put({"type": "lifespan.shutdown"}) + + # Wait for shutdown with timeout + try: + await asyncio.wait_for( + self._shutdown_complete.wait(), + timeout=30.0 # Reasonable shutdown timeout + ) + except asyncio.TimeoutError: + self.logger.warning("Lifespan shutdown timed out") + + if self._shutdown_error: + self.logger.error("Lifespan shutdown error: %s", self._shutdown_error) + + # Cancel the task if still running + if self._task and not self._task.done(): + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + + self.logger.debug("ASGI lifespan shutdown complete") + + async def _run_lifespan(self, scope): + """Run the ASGI lifespan protocol.""" + try: + await self.app(scope, self._receive, self._send) + except asyncio.CancelledError: + raise + except Exception as e: + self.logger.debug("Lifespan application raised: %s", e) + # If startup hasn't completed, mark it as failed + if not self._startup_complete.is_set(): + self._startup_failed = True + self._startup_error = str(e) + self._startup_complete.set() + # If shutdown hasn't completed, mark error + elif not self._shutdown_complete.is_set(): + self._shutdown_error = str(e) + self._shutdown_complete.set() + finally: + self._app_finished = True + # Ensure events are set to unblock waiters + if not self._startup_complete.is_set(): + self._startup_failed = True + self._startup_error = "Application exited before startup complete" + self._startup_complete.set() + if not self._shutdown_complete.is_set(): + self._shutdown_complete.set() + + async def _receive(self): + """ASGI receive callable for lifespan.""" + return await self._receive_queue.get() + + async def _send(self, message): + """ASGI send callable for lifespan.""" + msg_type = message["type"] + + if msg_type == "lifespan.startup.complete": + self._startup_complete.set() + self.logger.debug("Received lifespan.startup.complete") + + elif msg_type == "lifespan.startup.failed": + self._startup_failed = True + self._startup_error = message.get("message", "") + self._startup_complete.set() + self.logger.debug("Received lifespan.startup.failed: %s", + self._startup_error) + + elif msg_type == "lifespan.shutdown.complete": + self._shutdown_complete.set() + self.logger.debug("Received lifespan.shutdown.complete") + + elif msg_type == "lifespan.shutdown.failed": + self._shutdown_error = message.get("message", "") + self._shutdown_complete.set() + self.logger.debug("Received lifespan.shutdown.failed: %s", + self._shutdown_error) diff --git a/gunicorn/asgi/message.py b/gunicorn/asgi/message.py new file mode 100644 index 00000000..d7d20c83 --- /dev/null +++ b/gunicorn/asgi/message.py @@ -0,0 +1,562 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +Async version of gunicorn/http/message.py for ASGI workers. + +Reuses the parsing logic from the sync version, adapted for async I/O. +""" + +import io +import re +import socket + +from gunicorn.http.errors import ( + InvalidHeader, InvalidHeaderName, NoMoreData, + InvalidRequestLine, InvalidRequestMethod, InvalidHTTPVersion, + LimitRequestLine, LimitRequestHeaders, + UnsupportedTransferCoding, ObsoleteFolding, + InvalidProxyLine, ForbiddenProxyRequest, + InvalidSchemeHeaders, +) +from gunicorn.util import bytes_to_str, split_request_uri + +MAX_REQUEST_LINE = 8190 +MAX_HEADERS = 32768 +DEFAULT_MAX_HEADERFIELD_SIZE = 8190 + +# Reuse regex patterns from sync version +RFC9110_5_6_2_TOKEN_SPECIALS = r"!#$%&'*+-.^_`|~" +TOKEN_RE = re.compile(r"[%s0-9a-zA-Z]+" % (re.escape(RFC9110_5_6_2_TOKEN_SPECIALS))) +METHOD_BADCHAR_RE = re.compile("[a-z#]") +VERSION_RE = re.compile(r"HTTP/(\d)\.(\d)") +RFC9110_5_5_INVALID_AND_DANGEROUS = re.compile(r"[\0\r\n]") + + +class AsyncRequest: + """Async HTTP request parser. + + Parses HTTP/1.x requests using async I/O, reusing gunicorn's + parsing logic where possible. + """ + + def __init__(self, cfg, unreader, peer_addr, req_number=1): + self.cfg = cfg + self.unreader = unreader + self.peer_addr = peer_addr + self.remote_addr = peer_addr + self.req_number = req_number + + self.version = None + self.method = None + self.uri = None + self.path = None + self.query = None + self.fragment = None + self.headers = [] + self.trailers = [] + self.scheme = "https" if cfg.is_ssl else "http" + self.must_close = False + + self.proxy_protocol_info = None + + # Request line limit + self.limit_request_line = cfg.limit_request_line + if (self.limit_request_line < 0 + or self.limit_request_line >= MAX_REQUEST_LINE): + self.limit_request_line = MAX_REQUEST_LINE + + # Headers limits + self.limit_request_fields = cfg.limit_request_fields + if (self.limit_request_fields <= 0 + or self.limit_request_fields > MAX_HEADERS): + self.limit_request_fields = MAX_HEADERS + + self.limit_request_field_size = cfg.limit_request_field_size + if self.limit_request_field_size < 0: + self.limit_request_field_size = DEFAULT_MAX_HEADERFIELD_SIZE + + # Max header buffer size + max_header_field_size = self.limit_request_field_size or DEFAULT_MAX_HEADERFIELD_SIZE + self.max_buffer_headers = self.limit_request_fields * \ + (max_header_field_size + 2) + 4 + + # Body-related state + self.content_length = None + self.chunked = False + self._body_reader = None + self._body_remaining = 0 + + @classmethod + async def parse(cls, cfg, unreader, peer_addr, req_number=1): + """Parse an HTTP request from the stream. + + Args: + cfg: gunicorn config object + unreader: AsyncUnreader instance + peer_addr: client address tuple + req_number: request number on this connection (for keepalive) + + Returns: + AsyncRequest: Parsed request object + + Raises: + NoMoreData: If no data available + Various parsing errors for malformed requests + """ + req = cls(cfg, unreader, peer_addr, req_number) + await req._parse() + return req + + async def _parse(self): + """Parse the request from the unreader.""" + buf = io.BytesIO() + await self._get_data(buf, stop=True) + + # Get request line + line, rbuf = await self._read_line(buf, self.limit_request_line) + + # Proxy protocol + if self._proxy_protocol(bytes_to_str(line)): + # Get next request line + buf = io.BytesIO() + buf.write(rbuf) + line, rbuf = await self._read_line(buf, self.limit_request_line) + + self._parse_request_line(line) + buf = io.BytesIO() + buf.write(rbuf) + + # Headers + data = buf.getvalue() + + while True: + idx = data.find(b"\r\n\r\n") + done = data[:2] == b"\r\n" + + if idx < 0 and not done: + await self._get_data(buf) + data = buf.getvalue() + if len(data) > self.max_buffer_headers: + raise LimitRequestHeaders("max buffer headers") + else: + break + + if done: + self.unreader.unread(data[2:]) + else: + self.headers = self._parse_headers(data[:idx], from_trailer=False) + self.unreader.unread(data[idx + 4:]) + + self._set_body_reader() + + async def _get_data(self, buf, stop=False): + """Read data from unreader into buffer.""" + data = await self.unreader.read() + if not data: + if stop: + raise StopIteration() + raise NoMoreData(buf.getvalue()) + buf.write(data) + + async def _read_line(self, buf, limit=0): + """Read a line from the buffer/stream.""" + data = buf.getvalue() + + while True: + idx = data.find(b"\r\n") + if idx >= 0: + if idx > limit > 0: + raise LimitRequestLine(idx, limit) + break + if len(data) - 2 > limit > 0: + raise LimitRequestLine(len(data), limit) + await self._get_data(buf) + data = buf.getvalue() + + return (data[:idx], data[idx + 2:]) + + def _proxy_protocol(self, line): + """Detect, check and parse proxy protocol.""" + if not self.cfg.proxy_protocol: + return False + + if self.req_number != 1: + return False + + if not line.startswith("PROXY"): + return False + + self._proxy_protocol_access_check() + self._parse_proxy_protocol(line) + + return True + + def _proxy_protocol_access_check(self): + """Check if proxy protocol is allowed from this peer.""" + if ("*" not in self.cfg.proxy_allow_ips and + isinstance(self.peer_addr, tuple) and + self.peer_addr[0] not in self.cfg.proxy_allow_ips): + raise ForbiddenProxyRequest(self.peer_addr[0]) + + def _parse_proxy_protocol(self, line): + """Parse proxy protocol header line.""" + bits = line.split(" ") + + if len(bits) != 6: + raise InvalidProxyLine(line) + + proto = bits[1] + s_addr = bits[2] + d_addr = bits[3] + + if proto not in ["TCP4", "TCP6"]: + raise InvalidProxyLine("protocol '%s' not supported" % proto) + + if proto == "TCP4": + try: + socket.inet_pton(socket.AF_INET, s_addr) + socket.inet_pton(socket.AF_INET, d_addr) + except OSError: + raise InvalidProxyLine(line) + elif proto == "TCP6": + try: + socket.inet_pton(socket.AF_INET6, s_addr) + socket.inet_pton(socket.AF_INET6, d_addr) + except OSError: + raise InvalidProxyLine(line) + + try: + s_port = int(bits[4]) + d_port = int(bits[5]) + except ValueError: + raise InvalidProxyLine("invalid port %s" % line) + + if not ((0 <= s_port <= 65535) and (0 <= d_port <= 65535)): + raise InvalidProxyLine("invalid port %s" % line) + + self.proxy_protocol_info = { + "proxy_protocol": proto, + "client_addr": s_addr, + "client_port": s_port, + "proxy_addr": d_addr, + "proxy_port": d_port + } + + def _parse_request_line(self, line_bytes): + """Parse the HTTP request line.""" + bits = [bytes_to_str(bit) for bit in line_bytes.split(b" ", 2)] + if len(bits) != 3: + raise InvalidRequestLine(bytes_to_str(line_bytes)) + + # Method + self.method = bits[0] + + if not self.cfg.permit_unconventional_http_method: + if METHOD_BADCHAR_RE.search(self.method): + raise InvalidRequestMethod(self.method) + if not 3 <= len(bits[0]) <= 20: + raise InvalidRequestMethod(self.method) + if not TOKEN_RE.fullmatch(self.method): + raise InvalidRequestMethod(self.method) + if self.cfg.casefold_http_method: + self.method = self.method.upper() + + # URI + self.uri = bits[1] + + if len(self.uri) == 0: + raise InvalidRequestLine(bytes_to_str(line_bytes)) + + try: + parts = split_request_uri(self.uri) + except ValueError: + raise InvalidRequestLine(bytes_to_str(line_bytes)) + self.path = parts.path or "" + self.query = parts.query or "" + self.fragment = parts.fragment or "" + + # Version + match = VERSION_RE.fullmatch(bits[2]) + if match is None: + raise InvalidHTTPVersion(bits[2]) + self.version = (int(match.group(1)), int(match.group(2))) + if not (1, 0) <= self.version < (2, 0): + if not self.cfg.permit_unconventional_http_version: + raise InvalidHTTPVersion(self.version) + + def _parse_headers(self, data, from_trailer=False): + """Parse HTTP headers from raw data.""" + cfg = self.cfg + headers = [] + + lines = [bytes_to_str(line) for line in data.split(b"\r\n")] + + # Handle scheme headers + scheme_header = False + secure_scheme_headers = {} + forwarder_headers = [] + if from_trailer: + pass + elif ('*' in cfg.forwarded_allow_ips or + not isinstance(self.peer_addr, tuple) + or self.peer_addr[0] in cfg.forwarded_allow_ips): + secure_scheme_headers = cfg.secure_scheme_headers + forwarder_headers = cfg.forwarder_headers + + while lines: + if len(headers) >= self.limit_request_fields: + raise LimitRequestHeaders("limit request headers fields") + + curr = lines.pop(0) + header_length = len(curr) + len("\r\n") + if curr.find(":") <= 0: + raise InvalidHeader(curr) + name, value = curr.split(":", 1) + if self.cfg.strip_header_spaces: + name = name.rstrip(" \t") + if not TOKEN_RE.fullmatch(name): + raise InvalidHeaderName(name) + + name = name.upper() + value = [value.strip(" \t")] + + # Consume value continuation lines + while lines and lines[0].startswith((" ", "\t")): + if not self.cfg.permit_obsolete_folding: + raise ObsoleteFolding(name) + curr = lines.pop(0) + header_length += len(curr) + len("\r\n") + if header_length > self.limit_request_field_size > 0: + raise LimitRequestHeaders("limit request headers fields size") + value.append(curr.strip("\t ")) + value = " ".join(value) + + if RFC9110_5_5_INVALID_AND_DANGEROUS.search(value): + raise InvalidHeader(name) + + if header_length > self.limit_request_field_size > 0: + raise LimitRequestHeaders("limit request headers fields size") + + if name in secure_scheme_headers: + secure = value == secure_scheme_headers[name] + scheme = "https" if secure else "http" + if scheme_header: + if scheme != self.scheme: + raise InvalidSchemeHeaders() + else: + scheme_header = True + self.scheme = scheme + + if "_" in name: + if name in forwarder_headers or "*" in forwarder_headers: + pass + elif self.cfg.header_map == "dangerous": + pass + elif self.cfg.header_map == "drop": + continue + else: + raise InvalidHeaderName(name) + + headers.append((name, value)) + + return headers + + def _set_body_reader(self): + """Determine how to read the request body.""" + chunked = False + content_length = None + + for (name, value) in self.headers: + if name == "CONTENT-LENGTH": + if content_length is not None: + raise InvalidHeader("CONTENT-LENGTH", req=self) + content_length = value + elif name == "TRANSFER-ENCODING": + vals = [v.strip() for v in value.split(',')] + for val in vals: + if val.lower() == "chunked": + if chunked: + raise InvalidHeader("TRANSFER-ENCODING", req=self) + chunked = True + elif val.lower() == "identity": + if chunked: + raise InvalidHeader("TRANSFER-ENCODING", req=self) + elif val.lower() in ('compress', 'deflate', 'gzip'): + if chunked: + raise InvalidHeader("TRANSFER-ENCODING", req=self) + self.force_close() + else: + raise UnsupportedTransferCoding(value) + + if chunked: + if self.version < (1, 1): + raise InvalidHeader("TRANSFER-ENCODING", req=self) + if content_length is not None: + raise InvalidHeader("CONTENT-LENGTH", req=self) + self.chunked = True + self.content_length = None + self._body_remaining = -1 + elif content_length is not None: + try: + if str(content_length).isnumeric(): + content_length = int(content_length) + else: + raise InvalidHeader("CONTENT-LENGTH", req=self) + except ValueError: + raise InvalidHeader("CONTENT-LENGTH", req=self) + + if content_length < 0: + raise InvalidHeader("CONTENT-LENGTH", req=self) + + self.content_length = content_length + self._body_remaining = content_length + else: + # No body for requests without Content-Length or Transfer-Encoding + self.content_length = 0 + self._body_remaining = 0 + + def force_close(self): + """Mark connection for closing after this request.""" + self.must_close = True + + def should_close(self): + """Check if connection should be closed after this request.""" + if self.must_close: + return True + for (h, v) in self.headers: + if h == "CONNECTION": + v = v.lower().strip(" \t") + if v == "close": + return True + elif v == "keep-alive": + return False + break + return self.version <= (1, 0) + + def get_header(self, name): + """Get a header value by name (case-insensitive).""" + name = name.upper() + for (h, v) in self.headers: + if h == name: + return v + return None + + async def read_body(self, size=8192): + """Read a chunk of the request body. + + Args: + size: Maximum bytes to read + + Returns: + bytes: Body data, empty bytes when body is exhausted + """ + if self._body_remaining == 0: + return b"" + + if self.chunked: + return await self._read_chunked_body(size) + else: + return await self._read_length_body(size) + + async def _read_length_body(self, size): + """Read from a length-delimited body.""" + if self._body_remaining <= 0: + return b"" + + to_read = min(size, self._body_remaining) + data = await self.unreader.read(to_read) + if data: + self._body_remaining -= len(data) + return data + + async def _read_chunked_body(self, size): + """Read from a chunked body.""" + if self._body_reader is None: + self._body_reader = self._chunked_body_reader() + + try: + return await self._body_reader.__anext__() + except StopAsyncIteration: + self._body_remaining = 0 + return b"" + + async def _chunked_body_reader(self): + """Async generator for reading chunked body.""" + while True: + # Read chunk size line + size_line = await self._read_chunk_size_line() + # Parse chunk size (handle extensions) + chunk_size, *_ = size_line.split(b";", 1) + if _ : + chunk_size = chunk_size.rstrip(b" \t") + + if any(n not in b"0123456789abcdefABCDEF" for n in chunk_size): + raise InvalidHeader("Invalid chunk size") + if len(chunk_size) == 0: + raise InvalidHeader("Invalid chunk size") + + chunk_size = int(chunk_size, 16) + + if chunk_size == 0: + # Final chunk - skip trailers and final CRLF + await self._skip_trailers() + return + + # Read chunk data + remaining = chunk_size + while remaining > 0: + data = await self.unreader.read(min(remaining, 8192)) + if not data: + raise NoMoreData() + remaining -= len(data) + yield data + + # Skip chunk terminating CRLF + crlf = await self.unreader.read(2) + if crlf != b"\r\n": + # May have partial read, try to get the rest + while len(crlf) < 2: + more = await self.unreader.read(2 - len(crlf)) + if not more: + break + crlf += more + if crlf != b"\r\n": + raise InvalidHeader("Missing chunk terminator") + + async def _read_chunk_size_line(self): + """Read a chunk size line.""" + buf = io.BytesIO() + while True: + data = await self.unreader.read(1) + if not data: + raise NoMoreData() + buf.write(data) + if buf.getvalue().endswith(b"\r\n"): + return buf.getvalue()[:-2] + + async def _skip_trailers(self): + """Skip trailer headers after chunked body.""" + buf = io.BytesIO() + while True: + data = await self.unreader.read(1) + if not data: + return + buf.write(data) + content = buf.getvalue() + if content.endswith(b"\r\n\r\n"): + # Could parse trailers here if needed + return + if content == b"\r\n": + return + + async def drain_body(self): + """Drain any unread body data. + + Should be called before reusing connection for keepalive. + """ + while True: + data = await self.read_body(8192) + if not data: + break diff --git a/gunicorn/asgi/protocol.py b/gunicorn/asgi/protocol.py new file mode 100644 index 00000000..cededd68 --- /dev/null +++ b/gunicorn/asgi/protocol.py @@ -0,0 +1,424 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +ASGI protocol handler for gunicorn. + +Implements asyncio.Protocol to handle HTTP/1.x connections and dispatch +to ASGI applications. +""" + +import asyncio +import base64 +import hashlib +import traceback +from datetime import datetime + +from gunicorn.asgi.unreader import AsyncUnreader +from gunicorn.asgi.message import AsyncRequest +from gunicorn.http.errors import NoMoreData + + +class ASGIProtocol(asyncio.Protocol): + """HTTP/1.1 protocol handler for ASGI applications. + + Handles connection lifecycle, request parsing, and ASGI app invocation. + """ + + def __init__(self, worker): + self.worker = worker + self.cfg = worker.cfg + self.log = worker.log + self.app = worker.asgi + + self.transport = None + self.reader = None + self.writer = None + self._task = None + self.req_count = 0 + + # Connection state + self._closed = False + + def connection_made(self, transport): + """Called when a connection is established.""" + self.transport = transport + self.worker.nr_conns += 1 + + # Create stream reader/writer + self.reader = asyncio.StreamReader() + self.writer = transport + + # Start handling requests + self._task = self.worker.loop.create_task(self._handle_connection()) + + def data_received(self, data): + """Called when data is received on the connection.""" + if self.reader: + self.reader.feed_data(data) + + def connection_lost(self, exc): + """Called when the connection is lost or closed.""" + self._closed = True + self.worker.nr_conns -= 1 + if self.reader: + self.reader.feed_eof() + if self._task and not self._task.done(): + self._task.cancel() + + async def _handle_connection(self): + """Main request handling loop for this connection.""" + unreader = AsyncUnreader(self.reader) + + try: + peername = self.transport.get_extra_info('peername') + sockname = self.transport.get_extra_info('sockname') + + while not self._closed: + self.req_count += 1 + + try: + # Parse HTTP request + request = await AsyncRequest.parse( + self.cfg, + unreader, + peername, + self.req_count + ) + except StopIteration: + # No more data, close connection + break + except NoMoreData: + # Client disconnected + break + + # Check for WebSocket upgrade + if self._is_websocket_upgrade(request): + await self._handle_websocket(request, sockname, peername) + break # WebSocket takes over the connection + else: + # Handle HTTP request + keepalive = await self._handle_http_request( + request, sockname, peername + ) + + # Increment worker request count + self.worker.nr += 1 + + # Check max_requests + if self.worker.nr >= self.worker.max_requests: + self.log.info("Autorestarting worker after current request.") + self.worker.alive = False + keepalive = False + + if not keepalive or not self.worker.alive: + break + + # Check connection limits for keepalive + if not self.cfg.keepalive: + break + + # Drain any unread body before next request + await request.drain_body() + + except asyncio.CancelledError: + pass + except Exception as e: + self.log.exception("Error handling connection: %s", e) + finally: + self._close_transport() + + def _is_websocket_upgrade(self, request): + """Check if request is a WebSocket upgrade.""" + upgrade = None + connection = None + for name, value in request.headers: + if name == "UPGRADE": + upgrade = value.lower() + elif name == "CONNECTION": + connection = value.lower() + return upgrade == "websocket" and connection and "upgrade" in connection + + async def _handle_websocket(self, request, sockname, peername): + """Handle WebSocket upgrade request.""" + from gunicorn.asgi.websocket import WebSocketProtocol + + scope = self._build_websocket_scope(request, sockname, peername) + ws_protocol = WebSocketProtocol( + self.transport, self.reader, scope, self.app, self.log + ) + await ws_protocol.run() + + async def _handle_http_request(self, request, sockname, peername): + """Handle a single HTTP request.""" + scope = self._build_http_scope(request, sockname, peername) + response_started = False + response_complete = False + body_parts = [] + exc_to_raise = None + + # Receive queue for body + receive_queue = asyncio.Queue() + + # Pre-populate with initial body state + if request.content_length == 0 and not request.chunked: + await receive_queue.put({ + "type": "http.request", + "body": b"", + "more_body": False, + }) + else: + # Start body reading task + asyncio.create_task(self._read_body_to_queue(request, receive_queue)) + + async def receive(): + return await receive_queue.get() + + async def send(message): + nonlocal response_started, response_complete, exc_to_raise + + msg_type = message["type"] + + if msg_type == "http.response.start": + if response_started: + exc_to_raise = RuntimeError("Response already started") + return + response_started = True + status = message["status"] + headers = message.get("headers", []) + await self._send_response_start(status, headers, request) + + elif msg_type == "http.response.body": + if not response_started: + exc_to_raise = RuntimeError("Response not started") + return + if response_complete: + exc_to_raise = RuntimeError("Response already complete") + return + + body = message.get("body", b"") + more_body = message.get("more_body", False) + + if body: + await self._send_body(body) + + if not more_body: + response_complete = True + + try: + request_start = datetime.now() + self.cfg.pre_request(self.worker, request) + + await self.app(scope, receive, send) + + if exc_to_raise: + raise exc_to_raise + + # Ensure response was sent + if not response_started: + await self._send_error_response(500, "Internal Server Error") + + except Exception as e: + self.log.exception("Error in ASGI application") + if not response_started: + await self._send_error_response(500, "Internal Server Error") + return False + finally: + try: + request_time = datetime.now() - request_start + self.cfg.post_request(self.worker, request, {}, None) + except Exception: + self.log.exception("Exception in post_request hook") + + # Determine keepalive + if request.should_close(): + return False + + return self.worker.alive and self.cfg.keepalive + + async def _read_body_to_queue(self, request, queue): + """Read request body and put chunks on the queue.""" + try: + while True: + chunk = await request.read_body(65536) + if chunk: + await queue.put({ + "type": "http.request", + "body": chunk, + "more_body": True, + }) + else: + await queue.put({ + "type": "http.request", + "body": b"", + "more_body": False, + }) + break + except Exception as e: + self.log.debug("Error reading body: %s", e) + await queue.put({ + "type": "http.request", + "body": b"", + "more_body": False, + }) + + def _build_http_scope(self, request, sockname, peername): + """Build ASGI HTTP scope from parsed request.""" + # Build headers list as bytes tuples + headers = [] + for name, value in request.headers: + headers.append((name.lower().encode("latin-1"), value.encode("latin-1"))) + + scope = { + "type": "http", + "asgi": {"version": "3.0", "spec_version": "2.4"}, + "http_version": f"{request.version[0]}.{request.version[1]}", + "method": request.method, + "scheme": request.scheme, + "path": request.path, + "raw_path": request.path.encode("latin-1") if request.path else b"", + "query_string": request.query.encode("latin-1") if request.query else b"", + "root_path": self.cfg.root_path or "", + "headers": headers, + "server": sockname if sockname else None, + "client": peername if peername else None, + } + + # Add state dict for lifespan sharing + if hasattr(self.worker, 'state'): + scope["state"] = self.worker.state + + return scope + + def _build_websocket_scope(self, request, sockname, peername): + """Build ASGI WebSocket scope from parsed request.""" + # Build headers list as bytes tuples + headers = [] + for name, value in request.headers: + headers.append((name.lower().encode("latin-1"), value.encode("latin-1"))) + + # Extract subprotocols from Sec-WebSocket-Protocol header + subprotocols = [] + for name, value in request.headers: + if name == "SEC-WEBSOCKET-PROTOCOL": + subprotocols = [s.strip() for s in value.split(",")] + break + + scope = { + "type": "websocket", + "asgi": {"version": "3.0", "spec_version": "2.4"}, + "http_version": f"{request.version[0]}.{request.version[1]}", + "scheme": "wss" if request.scheme == "https" else "ws", + "path": request.path, + "raw_path": request.path.encode("latin-1") if request.path else b"", + "query_string": request.query.encode("latin-1") if request.query else b"", + "root_path": self.cfg.root_path or "", + "headers": headers, + "server": sockname if sockname else None, + "client": peername if peername else None, + "subprotocols": subprotocols, + } + + # Add state dict for lifespan sharing + if hasattr(self.worker, 'state'): + scope["state"] = self.worker.state + + return scope + + async def _send_response_start(self, status, headers, request): + """Send HTTP response status and headers.""" + # Build status line + reason = self._get_reason_phrase(status) + status_line = f"HTTP/{request.version[0]}.{request.version[1]} {status} {reason}\r\n" + + # Build headers + header_lines = [] + has_content_length = False + has_transfer_encoding = False + has_connection = False + + for name, value in headers: + if isinstance(name, bytes): + name = name.decode("latin-1") + if isinstance(value, bytes): + value = value.decode("latin-1") + header_lines.append(f"{name}: {value}\r\n") + name_lower = name.lower() + if name_lower == "content-length": + has_content_length = True + elif name_lower == "transfer-encoding": + has_transfer_encoding = True + elif name_lower == "connection": + has_connection = True + + # Add server header if not present + header_lines.append("Server: gunicorn/asgi\r\n") + + response = status_line + "".join(header_lines) + "\r\n" + self.transport.write(response.encode("latin-1")) + + async def _send_body(self, body): + """Send response body chunk.""" + if body: + self.transport.write(body) + + async def _send_error_response(self, status, message): + """Send an error response.""" + body = message.encode("utf-8") + response = ( + f"HTTP/1.1 {status} {message}\r\n" + f"Content-Type: text/plain\r\n" + f"Content-Length: {len(body)}\r\n" + f"Connection: close\r\n" + f"\r\n" + ) + self.transport.write(response.encode("latin-1")) + self.transport.write(body) + + def _get_reason_phrase(self, status): + """Get HTTP reason phrase for status code.""" + reasons = { + 100: "Continue", + 101: "Switching Protocols", + 200: "OK", + 201: "Created", + 202: "Accepted", + 204: "No Content", + 206: "Partial Content", + 301: "Moved Permanently", + 302: "Found", + 303: "See Other", + 304: "Not Modified", + 307: "Temporary Redirect", + 308: "Permanent Redirect", + 400: "Bad Request", + 401: "Unauthorized", + 403: "Forbidden", + 404: "Not Found", + 405: "Method Not Allowed", + 408: "Request Timeout", + 409: "Conflict", + 410: "Gone", + 411: "Length Required", + 413: "Payload Too Large", + 414: "URI Too Long", + 415: "Unsupported Media Type", + 422: "Unprocessable Entity", + 429: "Too Many Requests", + 500: "Internal Server Error", + 501: "Not Implemented", + 502: "Bad Gateway", + 503: "Service Unavailable", + 504: "Gateway Timeout", + } + return reasons.get(status, "Unknown") + + def _close_transport(self): + """Close the transport safely.""" + if self.transport and not self._closed: + try: + self.transport.close() + except Exception: + pass + self._closed = True diff --git a/gunicorn/asgi/unreader.py b/gunicorn/asgi/unreader.py new file mode 100644 index 00000000..c8d9aa82 --- /dev/null +++ b/gunicorn/asgi/unreader.py @@ -0,0 +1,100 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +Async version of gunicorn/http/unreader.py for ASGI workers. + +Provides async reading with pushback buffer support. +""" + +import io + + +class AsyncUnreader: + """Async socket reader with pushback buffer support. + + This class wraps an asyncio StreamReader and provides the ability + to "unread" data back into a buffer for re-parsing. + """ + + def __init__(self, reader, max_chunk=8192): + """Initialize the async unreader. + + Args: + reader: asyncio.StreamReader instance + max_chunk: Maximum bytes to read at once + """ + self.reader = reader + self.buf = io.BytesIO() + self.max_chunk = max_chunk + + async def read(self, size=None): + """Read data from the stream, using buffered data first. + + Args: + size: Number of bytes to read. If None, returns all buffered + data or reads a single chunk. + + Returns: + bytes: Data read from buffer or stream + """ + if size is not None and not isinstance(size, int): + raise TypeError("size parameter must be an int or long.") + + if size is not None: + if size == 0: + return b"" + if size < 0: + size = None + + # Move to end to check buffer size + self.buf.seek(0, io.SEEK_END) + + # If no size specified, return buffered data or read chunk + if size is None and self.buf.tell(): + ret = self.buf.getvalue() + self.buf = io.BytesIO() + return ret + if size is None: + chunk = await self._read_chunk() + return chunk + + # Read until we have enough data + while self.buf.tell() < size: + chunk = await self._read_chunk() + if not chunk: + ret = self.buf.getvalue() + self.buf = io.BytesIO() + return ret + self.buf.write(chunk) + + data = self.buf.getvalue() + self.buf = io.BytesIO() + self.buf.write(data[size:]) + return data[:size] + + async def _read_chunk(self): + """Read a chunk of data from the underlying stream.""" + try: + return await self.reader.read(self.max_chunk) + except Exception: + return b"" + + def unread(self, data): + """Push data back into the buffer for re-reading. + + Args: + data: bytes to push back + """ + if data: + self.buf.seek(0, io.SEEK_END) + self.buf.write(data) + + def has_buffered_data(self): + """Check if there's data in the pushback buffer.""" + pos = self.buf.tell() + self.buf.seek(0, io.SEEK_END) + has_data = self.buf.tell() > 0 + self.buf.seek(pos) + return has_data diff --git a/gunicorn/asgi/websocket.py b/gunicorn/asgi/websocket.py new file mode 100644 index 00000000..bcde84ee --- /dev/null +++ b/gunicorn/asgi/websocket.py @@ -0,0 +1,369 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +WebSocket protocol handler for ASGI. + +Implements RFC 6455 WebSocket protocol for ASGI applications. +""" + +import asyncio +import base64 +import hashlib +import struct +import os + + +# WebSocket frame opcodes +OPCODE_CONTINUATION = 0x0 +OPCODE_TEXT = 0x1 +OPCODE_BINARY = 0x2 +OPCODE_CLOSE = 0x8 +OPCODE_PING = 0x9 +OPCODE_PONG = 0xA + +# WebSocket close codes +CLOSE_NORMAL = 1000 +CLOSE_GOING_AWAY = 1001 +CLOSE_PROTOCOL_ERROR = 1002 +CLOSE_UNSUPPORTED = 1003 +CLOSE_NO_STATUS = 1005 +CLOSE_ABNORMAL = 1006 +CLOSE_INVALID_DATA = 1007 +CLOSE_POLICY_VIOLATION = 1008 +CLOSE_MESSAGE_TOO_BIG = 1009 +CLOSE_MANDATORY_EXT = 1010 +CLOSE_INTERNAL_ERROR = 1011 + +# WebSocket handshake GUID (RFC 6455) +WS_GUID = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + + +class WebSocketProtocol: + """WebSocket connection handler for ASGI applications.""" + + def __init__(self, transport, reader, scope, app, log): + """Initialize WebSocket protocol handler. + + Args: + transport: asyncio transport for writing + reader: asyncio StreamReader for reading + scope: ASGI WebSocket scope dict + app: ASGI application callable + log: Logger instance + """ + self.transport = transport + self.reader = reader + self.scope = scope + self.app = app + self.log = log + + self.accepted = False + self.closed = False + self.close_code = None + self.close_reason = "" + + # Message reassembly state + self._fragments = [] + self._fragment_opcode = None + + # Receive queue for incoming messages + self._receive_queue = asyncio.Queue() + + async def run(self): + """Run the WebSocket ASGI application.""" + # Send initial connect event + await self._receive_queue.put({"type": "websocket.connect"}) + + # Start frame reading task + read_task = asyncio.create_task(self._read_frames()) + + try: + await self.app(self.scope, self._receive, self._send) + except Exception as e: + self.log.exception("Error in WebSocket ASGI application") + finally: + read_task.cancel() + try: + await read_task + except asyncio.CancelledError: + pass + + # Send close frame if not already closed + if not self.closed and self.accepted: + await self._send_close(CLOSE_INTERNAL_ERROR, "Application error") + + async def _receive(self): + """ASGI receive callable.""" + return await self._receive_queue.get() + + async def _send(self, message): + """ASGI send callable.""" + msg_type = message["type"] + + if msg_type == "websocket.accept": + if self.accepted: + raise RuntimeError("WebSocket already accepted") + await self._send_accept(message) + self.accepted = True + + elif msg_type == "websocket.send": + if not self.accepted: + raise RuntimeError("WebSocket not accepted") + if self.closed: + raise RuntimeError("WebSocket closed") + + if "text" in message: + await self._send_frame(OPCODE_TEXT, message["text"].encode("utf-8")) + elif "bytes" in message: + await self._send_frame(OPCODE_BINARY, message["bytes"]) + + elif msg_type == "websocket.close": + code = message.get("code", CLOSE_NORMAL) + reason = message.get("reason", "") + await self._send_close(code, reason) + self.closed = True + + async def _send_accept(self, message): + """Send WebSocket handshake accept response.""" + # Get Sec-WebSocket-Key from headers + ws_key = None + for name, value in self.scope["headers"]: + if name == b"sec-websocket-key": + ws_key = value + break + + if not ws_key: + raise RuntimeError("Missing Sec-WebSocket-Key header") + + # Calculate accept key + accept_key = base64.b64encode( + hashlib.sha1(ws_key + WS_GUID).digest() + ).decode("ascii") + + # Build response headers + headers = [ + "HTTP/1.1 101 Switching Protocols\r\n", + "Upgrade: websocket\r\n", + "Connection: Upgrade\r\n", + f"Sec-WebSocket-Accept: {accept_key}\r\n", + ] + + # Add selected subprotocol if specified + subprotocol = message.get("subprotocol") + if subprotocol: + headers.append(f"Sec-WebSocket-Protocol: {subprotocol}\r\n") + + # Add any extra headers from message + extra_headers = message.get("headers", []) + for name, value in extra_headers: + if isinstance(name, bytes): + name = name.decode("latin-1") + if isinstance(value, bytes): + value = value.decode("latin-1") + headers.append(f"{name}: {value}\r\n") + + headers.append("\r\n") + self.transport.write("".join(headers).encode("latin-1")) + + async def _read_frames(self): + """Read and process incoming WebSocket frames.""" + try: + while not self.closed: + frame = await self._read_frame() + if frame is None: + break + + opcode, payload = frame + + if opcode == OPCODE_CLOSE: + await self._handle_close(payload) + break + elif opcode == OPCODE_PING: + await self._send_frame(OPCODE_PONG, payload) + elif opcode == OPCODE_PONG: + # Ignore pongs + pass + elif opcode == OPCODE_TEXT: + await self._receive_queue.put({ + "type": "websocket.receive", + "text": payload.decode("utf-8"), + }) + elif opcode == OPCODE_BINARY: + await self._receive_queue.put({ + "type": "websocket.receive", + "bytes": payload, + }) + elif opcode == OPCODE_CONTINUATION: + # Handle fragmented messages + await self._handle_continuation(payload) + + except asyncio.CancelledError: + raise + except Exception as e: + self.log.debug("WebSocket read error: %s", e) + finally: + # Signal disconnect + if not self.closed: + self.closed = True + await self._receive_queue.put({ + "type": "websocket.disconnect", + "code": self.close_code or CLOSE_ABNORMAL, + }) + + async def _read_frame(self): + """Read a single WebSocket frame. + + Returns: + tuple: (opcode, payload) or None if connection closed + """ + # Read frame header (2 bytes minimum) + header = await self._read_exact(2) + if not header: + return None + + first_byte, second_byte = header[0], header[1] + + fin = (first_byte >> 7) & 1 + rsv1 = (first_byte >> 6) & 1 + rsv2 = (first_byte >> 5) & 1 + rsv3 = (first_byte >> 4) & 1 + opcode = first_byte & 0x0F + + # RSV bits must be 0 (no extensions) + if rsv1 or rsv2 or rsv3: + await self._send_close(CLOSE_PROTOCOL_ERROR, "RSV bits set") + return None + + masked = (second_byte >> 7) & 1 + payload_len = second_byte & 0x7F + + # Client frames must be masked (RFC 6455) + if not masked: + await self._send_close(CLOSE_PROTOCOL_ERROR, "Frame not masked") + return None + + # Extended payload length + if payload_len == 126: + ext_len = await self._read_exact(2) + if not ext_len: + return None + payload_len = struct.unpack("!H", ext_len)[0] + elif payload_len == 127: + ext_len = await self._read_exact(8) + if not ext_len: + return None + payload_len = struct.unpack("!Q", ext_len)[0] + + # Read masking key + masking_key = await self._read_exact(4) + if not masking_key: + return None + + # Read payload + payload = await self._read_exact(payload_len) + if payload is None: + return None + + # Unmask payload + payload = self._unmask(payload, masking_key) + + # Handle fragmented messages + if opcode == OPCODE_CONTINUATION: + if self._fragment_opcode is None: + await self._send_close(CLOSE_PROTOCOL_ERROR, "Unexpected continuation") + return None + self._fragments.append(payload) + if fin: + # Reassemble complete message + full_payload = b"".join(self._fragments) + final_opcode = self._fragment_opcode + self._fragments = [] + self._fragment_opcode = None + return (final_opcode, full_payload) + return (OPCODE_CONTINUATION, b"") # Fragment received, wait for more + elif opcode in (OPCODE_TEXT, OPCODE_BINARY): + if not fin: + # Start of fragmented message + self._fragment_opcode = opcode + self._fragments = [payload] + return (OPCODE_CONTINUATION, b"") # Fragment started, wait for more + return (opcode, payload) + else: + # Control frames + return (opcode, payload) + + async def _read_exact(self, n): + """Read exactly n bytes from the reader.""" + try: + data = await self.reader.readexactly(n) + return data + except asyncio.IncompleteReadError: + return None + except Exception: + return None + + def _unmask(self, payload, masking_key): + """Unmask WebSocket payload data.""" + if not payload: + return payload + # XOR each byte with corresponding mask byte + return bytes(b ^ masking_key[i % 4] for i, b in enumerate(payload)) + + async def _handle_close(self, payload): + """Handle incoming close frame.""" + if len(payload) >= 2: + self.close_code = struct.unpack("!H", payload[:2])[0] + self.close_reason = payload[2:].decode("utf-8", errors="replace") + else: + self.close_code = CLOSE_NO_STATUS + self.close_reason = "" + + # Echo close frame back if we haven't already sent one + if not self.closed: + await self._send_close(self.close_code, self.close_reason) + + self.closed = True + + async def _handle_continuation(self, payload): + """Handle continuation frame (already processed in _read_frame).""" + # This is called for partial fragments, nothing to do + pass + + async def _send_frame(self, opcode, payload): + """Send a WebSocket frame. + + Server frames are not masked (RFC 6455). + """ + if isinstance(payload, str): + payload = payload.encode("utf-8") + + length = len(payload) + frame = bytearray() + + # First byte: FIN + opcode + frame.append(0x80 | opcode) + + # Second byte: length (no mask bit for server) + if length < 126: + frame.append(length) + elif length < 65536: + frame.append(126) + frame.extend(struct.pack("!H", length)) + else: + frame.append(127) + frame.extend(struct.pack("!Q", length)) + + # Payload + frame.extend(payload) + + self.transport.write(bytes(frame)) + + async def _send_close(self, code, reason=""): + """Send a close frame.""" + payload = struct.pack("!H", code) + if reason: + payload += reason.encode("utf-8")[:123] # Max 125 bytes total + await self._send_frame(OPCODE_CLOSE, payload) + self.closed = True diff --git a/gunicorn/config.py b/gunicorn/config.py index 29b30ad2..522dcae9 100644 --- a/gunicorn/config.py +++ b/gunicorn/config.py @@ -2440,3 +2440,94 @@ class HeaderMap(Setting): .. versionadded:: 22.0.0 """ + + +def validate_asgi_loop(val): + if val is None: + return "auto" + if not isinstance(val, str): + raise TypeError("Invalid type for casting: %s" % val) + val = val.lower().strip() + if val not in ("auto", "asyncio", "uvloop"): + raise ValueError("Invalid ASGI loop: %s" % val) + return val + + +def validate_asgi_lifespan(val): + if val is None: + return "auto" + if not isinstance(val, str): + raise TypeError("Invalid type for casting: %s" % val) + val = val.lower().strip() + if val not in ("auto", "on", "off"): + raise ValueError("Invalid ASGI lifespan: %s" % val) + return val + + +class ASGILoop(Setting): + name = "asgi_loop" + section = "Worker Processes" + cli = ["--asgi-loop"] + meta = "STRING" + validator = validate_asgi_loop + default = "auto" + desc = """\ + Event loop implementation for ASGI workers. + + - auto: Use uvloop if available, otherwise asyncio + - asyncio: Use Python's built-in asyncio event loop + - uvloop: Use uvloop (must be installed separately) + + This setting only affects the ``asgi`` worker type. + + uvloop typically provides better performance but requires + installing the uvloop package. + + .. versionadded:: 24.0.0 + """ + + +class ASGILifespan(Setting): + name = "asgi_lifespan" + section = "Worker Processes" + cli = ["--asgi-lifespan"] + meta = "STRING" + validator = validate_asgi_lifespan + default = "auto" + desc = """\ + Control ASGI lifespan protocol handling. + + - auto: Detect if app supports lifespan, enable if so + - on: Always run lifespan protocol (fail if unsupported) + - off: Never run lifespan protocol + + The lifespan protocol allows ASGI applications to run code at + startup and shutdown. This is essential for frameworks like + FastAPI that need to initialize database connections, caches, + or other resources. + + This setting only affects the ``asgi`` worker type. + + .. versionadded:: 24.0.0 + """ + + +class RootPath(Setting): + name = "root_path" + section = "Server Mechanics" + cli = ["--root-path"] + meta = "STRING" + validator = validate_string + default = "" + desc = """\ + The root path for ASGI applications. + + This is used to set the ``root_path`` in the ASGI scope, which + allows applications to know their mount point when behind a + reverse proxy. + + For example, if your application is mounted at ``/api``, set + this to ``/api``. + + .. versionadded:: 24.0.0 + """ diff --git a/gunicorn/workers/__init__.py b/gunicorn/workers/__init__.py index 3da5f85e..3beb0d70 100644 --- a/gunicorn/workers/__init__.py +++ b/gunicorn/workers/__init__.py @@ -11,4 +11,5 @@ SUPPORTED_WORKERS = { "gevent_pywsgi": "gunicorn.workers.ggevent.GeventPyWSGIWorker", "tornado": "gunicorn.workers.gtornado.TornadoWorker", "gthread": "gunicorn.workers.gthread.ThreadWorker", + "asgi": "gunicorn.workers.gasgi.ASGIWorker", } diff --git a/gunicorn/workers/gasgi.py b/gunicorn/workers/gasgi.py new file mode 100644 index 00000000..b0d57cf0 --- /dev/null +++ b/gunicorn/workers/gasgi.py @@ -0,0 +1,282 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +ASGI worker for gunicorn. + +Provides native asyncio-based ASGI support using gunicorn's own +HTTP parsing infrastructure. +""" + +import asyncio +import os +import signal +import ssl +import sys + +from gunicorn.workers import base +from gunicorn.asgi.protocol import ASGIProtocol + + +class ASGIWorker(base.Worker): + """ASGI worker using asyncio event loop. + + Supports: + - HTTP/1.1 with keepalive + - WebSocket connections + - Lifespan protocol (startup/shutdown hooks) + - Optional uvloop for improved performance + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.worker_connections = self.cfg.worker_connections + self.loop = None + self.servers = [] + self.nr_conns = 0 + self.lifespan = None + self.state = {} # Shared state for lifespan + + @classmethod + def check_config(cls, cfg, log): + """Validate configuration for ASGI worker.""" + if cfg.threads > 1: + log.warning("ASGI worker does not use threads configuration. " + "Use worker_connections instead.") + + def init_process(self): + """Initialize the worker process.""" + # Setup event loop before calling super() + self._setup_event_loop() + super().init_process() + + def _setup_event_loop(self): + """Setup the asyncio event loop.""" + loop_type = getattr(self.cfg, 'asgi_loop', 'auto') + + if loop_type == "auto": + try: + import uvloop + loop_type = "uvloop" + except ImportError: + loop_type = "asyncio" + + if loop_type == "uvloop": + try: + import uvloop + self.loop = uvloop.new_event_loop() + self.log.debug("Using uvloop event loop") + except ImportError: + self.log.warning("uvloop not available, falling back to asyncio") + self.loop = asyncio.new_event_loop() + else: + self.loop = asyncio.new_event_loop() + self.log.debug("Using asyncio event loop") + + asyncio.set_event_loop(self.loop) + + def load_wsgi(self): + """Load the ASGI application.""" + try: + self.asgi = self.app.wsgi() + except SyntaxError as e: + if not self.cfg.reload: + raise + self.log.exception(e) + self.asgi = self._make_error_app(str(e)) + + def _make_error_app(self, error_msg): + """Create an error ASGI app for syntax errors during reload.""" + async def error_app(scope, receive, send): + if scope["type"] == "http": + await send({ + "type": "http.response.start", + "status": 500, + "headers": [(b"content-type", b"text/plain")], + }) + await send({ + "type": "http.response.body", + "body": f"Application error: {error_msg}".encode(), + }) + elif scope["type"] == "lifespan": + message = await receive() + if message["type"] == "lifespan.startup": + await send({"type": "lifespan.startup.complete"}) + message = await receive() + if message["type"] == "lifespan.shutdown": + await send({"type": "lifespan.shutdown.complete"}) + return error_app + + def init_signals(self): + """Initialize signal handlers for asyncio.""" + # Reset all signals first + for s in self.SIGNALS: + signal.signal(s, signal.SIG_DFL) + + # Set up signal handlers via the event loop + self.loop.add_signal_handler(signal.SIGQUIT, self.handle_quit_signal) + self.loop.add_signal_handler(signal.SIGTERM, self.handle_exit_signal) + self.loop.add_signal_handler(signal.SIGINT, self.handle_quit_signal) + self.loop.add_signal_handler(signal.SIGUSR1, self.handle_usr1_signal) + self.loop.add_signal_handler(signal.SIGWINCH, self.handle_winch_signal) + self.loop.add_signal_handler(signal.SIGABRT, self.handle_abort_signal) + + def handle_quit_signal(self): + """Handle SIGQUIT - immediate shutdown.""" + self.alive = False + self.cfg.worker_int(self) + + def handle_exit_signal(self): + """Handle SIGTERM - graceful shutdown.""" + self.alive = False + + def handle_usr1_signal(self): + """Handle SIGUSR1 - reopen log files.""" + self.log.reopen_files() + + def handle_winch_signal(self): + """Handle SIGWINCH - ignored in worker.""" + self.log.debug("worker: SIGWINCH ignored.") + + def handle_abort_signal(self): + """Handle SIGABRT - abort.""" + self.alive = False + self.cfg.worker_abort(self) + sys.exit(1) + + def run(self): + """Main entry point for the worker.""" + try: + self.loop.run_until_complete(self._serve()) + except Exception as e: + self.log.exception("Worker exception: %s", e) + finally: + self._cleanup() + + async def _serve(self): + """Main async serving loop.""" + # Run lifespan startup + lifespan_mode = getattr(self.cfg, 'asgi_lifespan', 'auto') + if lifespan_mode != "off": + from gunicorn.asgi.lifespan import LifespanManager + self.lifespan = LifespanManager(self.asgi, self.log, self.state) + try: + await self.lifespan.startup() + except Exception as e: + if lifespan_mode == "on": + self.log.error("ASGI lifespan startup failed: %s", e) + return + else: + # auto mode - app doesn't support lifespan + self.log.debug("ASGI lifespan not supported by app: %s", e) + self.lifespan = None + + # Create servers for each listener socket + ssl_context = self._get_ssl_context() + + for sock in self.sockets: + try: + server = await self.loop.create_server( + lambda: ASGIProtocol(self), + sock=sock.sock, + ssl=ssl_context, + reuse_address=True, + start_serving=True, + ) + self.servers.append(server) + self.log.info("ASGI server listening on %s", sock) + except Exception as e: + self.log.error("Failed to create server on %s: %s", sock, e) + + if not self.servers: + self.log.error("No servers could be started") + return + + # Main loop with heartbeat + try: + while self.alive: + self.notify() + + # Check if parent is still alive + if self.ppid != os.getppid(): + self.log.info("Parent changed, shutting down: %s", self) + break + + # Check connection limit + # (Connections are managed by nr_conns in ASGIProtocol) + + await asyncio.sleep(1.0) + + except asyncio.CancelledError: + pass + + # Graceful shutdown + await self._shutdown() + + async def _shutdown(self): + """Perform graceful shutdown.""" + self.log.info("Worker shutting down...") + + # Stop accepting new connections + for server in self.servers: + server.close() + + # Wait for servers to close + for server in self.servers: + await server.wait_closed() + + # Wait for in-flight connections (with timeout) + graceful_timeout = self.cfg.graceful_timeout + if self.nr_conns > 0: + self.log.info("Waiting for %d connections to finish...", self.nr_conns) + deadline = self.loop.time() + graceful_timeout + while self.nr_conns > 0 and self.loop.time() < deadline: + await asyncio.sleep(0.1) + + if self.nr_conns > 0: + self.log.warning("Closing %d connections after timeout", self.nr_conns) + + # Run lifespan shutdown + if self.lifespan: + try: + await self.lifespan.shutdown() + except Exception as e: + self.log.error("ASGI lifespan shutdown error: %s", e) + + def _get_ssl_context(self): + """Get SSL context if configured.""" + if not self.cfg.is_ssl: + return None + + try: + from gunicorn import sock + return sock.ssl_context(self.cfg) + except Exception as e: + self.log.error("Failed to create SSL context: %s", e) + return None + + def _cleanup(self): + """Clean up resources on exit.""" + try: + # Cancel all pending tasks + pending = asyncio.all_tasks(self.loop) + for task in pending: + task.cancel() + + # Run loop until all tasks are cancelled + if pending: + self.loop.run_until_complete( + asyncio.gather(*pending, return_exceptions=True) + ) + + self.loop.close() + except Exception as e: + self.log.debug("Cleanup error: %s", e) + + # Close sockets + for s in self.sockets: + try: + s.close() + except Exception: + pass diff --git a/pyproject.toml b/pyproject.toml index ce681f65..7803dc55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,7 @@ testing = [ "coverage", "pytest", "pytest-cov", + "pytest-asyncio", ] [project.scripts] diff --git a/tests/test_asgi.py b/tests/test_asgi.py new file mode 100644 index 00000000..227f7ea2 --- /dev/null +++ b/tests/test_asgi.py @@ -0,0 +1,285 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +Tests for ASGI worker components. +""" + +import asyncio +import io +import pytest +from unittest import mock + +from gunicorn.asgi.unreader import AsyncUnreader +from gunicorn.asgi.message import AsyncRequest + + +class MockStreamReader: + """Mock asyncio.StreamReader for testing.""" + + def __init__(self, data): + self.data = data + self.pos = 0 + + async def read(self, size=-1): + if self.pos >= len(self.data): + return b"" + if size < 0: + result = self.data[self.pos:] + self.pos = len(self.data) + else: + result = self.data[self.pos:self.pos + size] + self.pos += size + return result + + async def readexactly(self, n): + if self.pos + n > len(self.data): + raise asyncio.IncompleteReadError( + self.data[self.pos:], n + ) + result = self.data[self.pos:self.pos + n] + self.pos += n + return result + + +class MockConfig: + """Mock gunicorn config for testing.""" + + def __init__(self): + self.is_ssl = False + self.proxy_protocol = False + self.proxy_allow_ips = ["127.0.0.1"] + self.forwarded_allow_ips = ["127.0.0.1"] + self.secure_scheme_headers = {} + self.forwarder_headers = [] + self.limit_request_line = 8190 + self.limit_request_fields = 100 + self.limit_request_field_size = 8190 + self.permit_unconventional_http_method = False + self.permit_unconventional_http_version = False + self.permit_obsolete_folding = False + self.casefold_http_method = False + self.strip_header_spaces = False + self.header_map = "refuse" + + +# AsyncUnreader Tests + +@pytest.mark.asyncio +async def test_async_unreader_read_chunk(): + """Test basic chunk reading.""" + reader = MockStreamReader(b"hello world") + unreader = AsyncUnreader(reader) + data = await unreader.read() + assert data == b"hello world" + + +@pytest.mark.asyncio +async def test_async_unreader_read_size(): + """Test reading specific size.""" + reader = MockStreamReader(b"hello world") + unreader = AsyncUnreader(reader) + data = await unreader.read(5) + assert data == b"hello" + + +@pytest.mark.asyncio +async def test_async_unreader_unread(): + """Test unread functionality.""" + reader = MockStreamReader(b"hello world") + unreader = AsyncUnreader(reader) + + # Read all data + data = await unreader.read() + assert data == b"hello world" + + # Unread some data + unreader.unread(b"world") + + # Read again should get unread data + data = await unreader.read() + assert data == b"world" + + +@pytest.mark.asyncio +async def test_async_unreader_read_zero(): + """Test reading zero bytes.""" + reader = MockStreamReader(b"hello") + unreader = AsyncUnreader(reader) + data = await unreader.read(0) + assert data == b"" + + +@pytest.mark.asyncio +async def test_async_unreader_read_empty(): + """Test reading from empty stream.""" + reader = MockStreamReader(b"") + unreader = AsyncUnreader(reader) + data = await unreader.read() + assert data == b"" + + +# AsyncRequest Tests + +@pytest.mark.asyncio +async def test_async_request_simple_get(): + """Test parsing a simple GET request.""" + request_data = b"GET /path HTTP/1.1\r\nHost: localhost\r\n\r\n" + reader = MockStreamReader(request_data) + unreader = AsyncUnreader(reader) + cfg = MockConfig() + + request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) + + assert request.method == "GET" + assert request.path == "/path" + assert request.version == (1, 1) + assert ("HOST", "localhost") in request.headers + + +@pytest.mark.asyncio +async def test_async_request_with_query(): + """Test parsing request with query string.""" + request_data = b"GET /search?q=test&page=1 HTTP/1.1\r\nHost: localhost\r\n\r\n" + reader = MockStreamReader(request_data) + unreader = AsyncUnreader(reader) + cfg = MockConfig() + + request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) + + assert request.method == "GET" + assert request.path == "/search" + assert request.query == "q=test&page=1" + + +@pytest.mark.asyncio +async def test_async_request_post_with_body(): + """Test parsing POST request with body.""" + request_data = ( + b"POST /submit HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Content-Length: 11\r\n" + b"\r\n" + b"hello=world" + ) + reader = MockStreamReader(request_data) + unreader = AsyncUnreader(reader) + cfg = MockConfig() + + request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) + + assert request.method == "POST" + assert request.path == "/submit" + assert request.content_length == 11 + + # Read body + body = await request.read_body(100) + assert body == b"hello=world" + + +@pytest.mark.asyncio +async def test_async_request_multiple_headers(): + """Test parsing request with multiple headers.""" + request_data = ( + b"GET / HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Accept: text/html\r\n" + b"Accept-Language: en-US\r\n" + b"Connection: keep-alive\r\n" + b"\r\n" + ) + reader = MockStreamReader(request_data) + unreader = AsyncUnreader(reader) + cfg = MockConfig() + + request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) + + assert len(request.headers) == 4 + assert request.get_header("HOST") == "localhost" + assert request.get_header("ACCEPT") == "text/html" + + +@pytest.mark.asyncio +async def test_async_request_should_close_http10(): + """Test connection close detection for HTTP/1.0.""" + request_data = b"GET / HTTP/1.0\r\nHost: localhost\r\n\r\n" + reader = MockStreamReader(request_data) + unreader = AsyncUnreader(reader) + cfg = MockConfig() + + request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) + + assert request.version == (1, 0) + assert request.should_close() is True + + +@pytest.mark.asyncio +async def test_async_request_should_close_connection_header(): + """Test connection close detection with Connection header.""" + request_data = b"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n" + reader = MockStreamReader(request_data) + unreader = AsyncUnreader(reader) + cfg = MockConfig() + + request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) + + assert request.should_close() is True + + +@pytest.mark.asyncio +async def test_async_request_keepalive(): + """Test keepalive detection.""" + request_data = b"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: keep-alive\r\n\r\n" + reader = MockStreamReader(request_data) + unreader = AsyncUnreader(reader) + cfg = MockConfig() + + request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) + + assert request.should_close() is False + + +@pytest.mark.asyncio +async def test_async_request_no_body_for_get(): + """Test that GET requests have no body by default.""" + request_data = b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n" + reader = MockStreamReader(request_data) + unreader = AsyncUnreader(reader) + cfg = MockConfig() + + request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) + + assert request.content_length == 0 + body = await request.read_body() + assert body == b"" + + +# Error handling tests + +@pytest.mark.asyncio +async def test_async_request_invalid_method(): + """Test invalid HTTP method detection.""" + from gunicorn.http.errors import InvalidRequestMethod + + request_data = b"ge!t / HTTP/1.1\r\nHost: localhost\r\n\r\n" + reader = MockStreamReader(request_data) + unreader = AsyncUnreader(reader) + cfg = MockConfig() + + with pytest.raises(InvalidRequestMethod): + await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) + + +@pytest.mark.asyncio +async def test_async_request_invalid_http_version(): + """Test invalid HTTP version detection.""" + from gunicorn.http.errors import InvalidHTTPVersion + + request_data = b"GET / HTTP/2.0\r\nHost: localhost\r\n\r\n" + reader = MockStreamReader(request_data) + unreader = AsyncUnreader(reader) + cfg = MockConfig() + + with pytest.raises(InvalidHTTPVersion): + await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000)) diff --git a/tests/test_asgi_worker.py b/tests/test_asgi_worker.py new file mode 100644 index 00000000..9266af4d --- /dev/null +++ b/tests/test_asgi_worker.py @@ -0,0 +1,643 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +Tests for the ASGI worker. + +Includes unit tests for worker components and integration tests +that actually start the server and make HTTP requests. +""" + +import asyncio +import errno +import os +import signal +import socket +import sys +import time +import threading +from unittest import mock + +import pytest + +from gunicorn.config import Config +from gunicorn.workers import gasgi + + +# ============================================================================ +# Mock Classes +# ============================================================================ + +class FakeSocket: + """Mock socket for testing.""" + + def __init__(self, data=b''): + self.data = data + self.closed = False + self.blocking = True + self._fileno = id(self) % 65536 + + def fileno(self): + return self._fileno + + def setblocking(self, blocking): + self.blocking = blocking + + def recv(self, size): + if self.closed: + raise OSError(errno.EBADF, "Bad file descriptor") + result = self.data[:size] + self.data = self.data[size:] + return result + + def send(self, data): + if self.closed: + raise OSError(errno.EPIPE, "Broken pipe") + return len(data) + + def close(self): + self.closed = True + + def getsockname(self): + return ('127.0.0.1', 8000) + + def getpeername(self): + return ('127.0.0.1', 12345) + + +class FakeApp: + """Mock ASGI application for testing.""" + + def __init__(self): + self.calls = [] + + def wsgi(self): + return self.asgi_app + + async def asgi_app(self, scope, receive, send): + self.calls.append(scope) + if scope["type"] == "lifespan": + while True: + message = await receive() + if message["type"] == "lifespan.startup": + await send({"type": "lifespan.startup.complete"}) + elif message["type"] == "lifespan.shutdown": + await send({"type": "lifespan.shutdown.complete"}) + return + elif scope["type"] == "http": + await send({ + "type": "http.response.start", + "status": 200, + "headers": [(b"content-type", b"text/plain")], + }) + await send({ + "type": "http.response.body", + "body": b"Hello from ASGI!", + }) + + +class FakeListener: + """Mock listener socket.""" + + def __init__(self): + self.sock = FakeSocket() + + def getsockname(self): + return ('127.0.0.1', 8000) + + def close(self): + self.sock.close() + + def __str__(self): + return "http://127.0.0.1:8000" + + +# ============================================================================ +# Helper Functions +# ============================================================================ + +def _has_uvloop(): + """Check if uvloop is available.""" + try: + import uvloop + return True + except ImportError: + return False + + +# ============================================================================ +# Unit Tests for ASGIWorker +# ============================================================================ + +class TestASGIWorkerInit: + """Tests for ASGIWorker initialization.""" + + def create_worker(self, **kwargs): + """Create a worker for testing.""" + cfg = Config() + cfg.set('workers', 1) + cfg.set('worker_connections', 1000) + + for key, value in kwargs.items(): + cfg.set(key, value) + + worker = gasgi.ASGIWorker( + age=1, + ppid=os.getpid(), + sockets=[], + app=FakeApp(), + timeout=30, + cfg=cfg, + log=mock.Mock(), + ) + return worker + + def test_worker_init(self): + """Test worker initialization.""" + worker = self.create_worker() + + assert worker.worker_connections == 1000 + assert worker.nr_conns == 0 + assert worker.loop is None + assert worker.servers == [] + assert worker.state == {} + + def test_worker_connections_config(self): + """Test worker_connections configuration.""" + worker = self.create_worker(worker_connections=500) + assert worker.worker_connections == 500 + + +class TestASGIWorkerEventLoop: + """Tests for event loop setup.""" + + def create_worker(self, **kwargs): + """Create a worker for testing.""" + cfg = Config() + cfg.set('workers', 1) + cfg.set('worker_connections', 1000) + + for key, value in kwargs.items(): + cfg.set(key, value) + + worker = gasgi.ASGIWorker( + age=1, + ppid=os.getpid(), + sockets=[], + app=FakeApp(), + timeout=30, + cfg=cfg, + log=mock.Mock(), + ) + return worker + + def test_setup_asyncio_loop(self): + """Test asyncio event loop setup.""" + worker = self.create_worker(asgi_loop='asyncio') + worker._setup_event_loop() + + assert worker.loop is not None + assert isinstance(worker.loop, asyncio.AbstractEventLoop) + worker.loop.close() + + def test_setup_auto_loop_falls_back_to_asyncio(self): + """Test that auto mode uses asyncio when uvloop unavailable.""" + worker = self.create_worker(asgi_loop='auto') + + # Mock uvloop import failure + with mock.patch.dict('sys.modules', {'uvloop': None}): + worker._setup_event_loop() + + assert worker.loop is not None + worker.loop.close() + + @pytest.mark.skipif( + not _has_uvloop(), + reason="uvloop not installed" + ) + def test_setup_uvloop(self): + """Test uvloop event loop setup.""" + worker = self.create_worker(asgi_loop='uvloop') + worker._setup_event_loop() + + import uvloop + assert isinstance(worker.loop, uvloop.Loop) + worker.loop.close() + + +class TestASGIWorkerSignals: + """Tests for signal handling.""" + + def create_worker(self): + """Create a worker for testing.""" + cfg = Config() + cfg.set('workers', 1) + cfg.set('worker_connections', 1000) + cfg.set('graceful_timeout', 5) + + worker = gasgi.ASGIWorker( + age=1, + ppid=os.getpid(), + sockets=[], + app=FakeApp(), + timeout=30, + cfg=cfg, + log=mock.Mock(), + ) + worker._setup_event_loop() + return worker + + def test_handle_exit_sets_alive_false(self): + """Test that exit signal sets alive=False.""" + worker = self.create_worker() + worker.alive = True + + worker.handle_exit_signal() + + assert worker.alive is False + worker.loop.close() + + def test_handle_quit_sets_alive_false(self): + """Test that quit signal sets alive=False.""" + worker = self.create_worker() + worker.alive = True + + # Mock the worker_int callback on the worker's cfg settings + with mock.patch.object(worker.cfg.settings['worker_int'], 'get', return_value=lambda w: None): + worker.handle_quit_signal() + + assert worker.alive is False + worker.loop.close() + + +# ============================================================================ +# Tests for Lifespan Protocol +# ============================================================================ + +class TestLifespanManager: + """Tests for ASGI lifespan protocol.""" + + @pytest.mark.asyncio + async def test_lifespan_startup_complete(self): + """Test successful lifespan startup.""" + from gunicorn.asgi.lifespan import LifespanManager + + startup_called = False + shutdown_called = False + + async def app(scope, receive, send): + nonlocal startup_called, shutdown_called + assert scope["type"] == "lifespan" + while True: + message = await receive() + if message["type"] == "lifespan.startup": + startup_called = True + await send({"type": "lifespan.startup.complete"}) + elif message["type"] == "lifespan.shutdown": + shutdown_called = True + await send({"type": "lifespan.shutdown.complete"}) + return + + manager = LifespanManager(app, mock.Mock()) + await manager.startup() + + assert startup_called + assert manager._startup_complete.is_set() + assert not manager._startup_failed + + await manager.shutdown() + assert shutdown_called + + @pytest.mark.asyncio + async def test_lifespan_startup_failed(self): + """Test lifespan startup failure.""" + from gunicorn.asgi.lifespan import LifespanManager + + async def app(scope, receive, send): + message = await receive() + if message["type"] == "lifespan.startup": + await send({ + "type": "lifespan.startup.failed", + "message": "Database connection failed" + }) + + manager = LifespanManager(app, mock.Mock()) + + with pytest.raises(RuntimeError, match="Database connection failed"): + await manager.startup() + + @pytest.mark.asyncio + async def test_lifespan_state_shared(self): + """Test that lifespan state is shared with app.""" + from gunicorn.asgi.lifespan import LifespanManager + + state = {} + + async def app(scope, receive, send): + assert "state" in scope + scope["state"]["db"] = "connected" + message = await receive() + await send({"type": "lifespan.startup.complete"}) + message = await receive() + await send({"type": "lifespan.shutdown.complete"}) + + manager = LifespanManager(app, mock.Mock(), state) + await manager.startup() + + assert state.get("db") == "connected" + + await manager.shutdown() + + +# ============================================================================ +# Tests for WebSocket Protocol +# ============================================================================ + +class TestWebSocketProtocol: + """Tests for WebSocket protocol handling.""" + + def test_websocket_guid(self): + """Test WebSocket GUID constant.""" + from gunicorn.asgi.websocket import WS_GUID + assert WS_GUID == b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + + def test_websocket_opcodes(self): + """Test WebSocket opcode constants.""" + from gunicorn.asgi import websocket + + assert websocket.OPCODE_TEXT == 0x1 + assert websocket.OPCODE_BINARY == 0x2 + assert websocket.OPCODE_CLOSE == 0x8 + assert websocket.OPCODE_PING == 0x9 + assert websocket.OPCODE_PONG == 0xA + + def test_websocket_accept_key_calculation(self): + """Test WebSocket accept key calculation per RFC 6455.""" + import base64 + import hashlib + from gunicorn.asgi.websocket import WS_GUID + + # Example from RFC 6455 + client_key = b"dGhlIHNhbXBsZSBub25jZQ==" + expected_accept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=" + + accept_key = base64.b64encode( + hashlib.sha1(client_key + WS_GUID).digest() + ).decode("ascii") + + assert accept_key == expected_accept + + def test_websocket_frame_masking(self): + """Test WebSocket frame unmasking.""" + from gunicorn.asgi.websocket import WebSocketProtocol + + # Create a minimal protocol instance + protocol = WebSocketProtocol(None, None, {}, None, mock.Mock()) + + # Test unmasking (XOR operation) + masking_key = bytes([0x37, 0xfa, 0x21, 0x3d]) + masked_data = bytes([0x7f, 0x9f, 0x4d, 0x51, 0x58]) # "Hello" masked + + unmasked = protocol._unmask(masked_data, masking_key) + assert unmasked == b"Hello" + + def test_websocket_frame_masking_empty(self): + """Test WebSocket frame unmasking with empty payload.""" + from gunicorn.asgi.websocket import WebSocketProtocol + + protocol = WebSocketProtocol(None, None, {}, None, mock.Mock()) + + masking_key = bytes([0x37, 0xfa, 0x21, 0x3d]) + unmasked = protocol._unmask(b"", masking_key) + assert unmasked == b"" + + +# ============================================================================ +# Integration Tests +# ============================================================================ + +class TestASGIIntegration: + """Integration tests that start actual servers.""" + + @pytest.fixture + def free_port(self): + """Get a free port for testing.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(('127.0.0.1', 0)) + return s.getsockname()[1] + + @pytest.mark.asyncio + async def test_http_request_response(self, free_port): + """Test basic HTTP request/response cycle.""" + # Simple ASGI app + async def app(scope, receive, send): + if scope["type"] == "http": + await send({ + "type": "http.response.start", + "status": 200, + "headers": [(b"content-type", b"text/plain")], + }) + await send({ + "type": "http.response.body", + "body": b"Hello, World!", + }) + + # Start server + loop = asyncio.get_event_loop() + server = await loop.create_server( + lambda: _TestProtocol(app), + '127.0.0.1', + free_port, + ) + + try: + # Use asyncio to make HTTP request + reader, writer = await asyncio.open_connection('127.0.0.1', free_port) + request = f"GET / HTTP/1.1\r\nHost: 127.0.0.1:{free_port}\r\n\r\n" + writer.write(request.encode()) + await writer.drain() + + # Read response + response = await reader.read(4096) + response_text = response.decode() + + assert "HTTP/1.1 200" in response_text + assert "Hello, World!" in response_text + + writer.close() + await writer.wait_closed() + finally: + server.close() + await server.wait_closed() + + +class _TestProtocol(asyncio.Protocol): + """Minimal protocol for integration testing.""" + + def __init__(self, app): + self.app = app + self.transport = None + + def connection_made(self, transport): + self.transport = transport + + def data_received(self, data): + # Very simple HTTP parsing for testing + asyncio.create_task(self._handle(data)) + + async def _handle(self, data): + # Parse basic HTTP request + lines = data.decode().split('\r\n') + method, path, _ = lines[0].split(' ') + + scope = { + "type": "http", + "asgi": {"version": "3.0"}, + "http_version": "1.1", + "method": method, + "path": path, + "query_string": b"", + "headers": [], + "server": ("127.0.0.1", 8000), + "client": ("127.0.0.1", 12345), + } + + async def receive(): + return {"type": "http.request", "body": b"", "more_body": False} + + async def send(message): + if message["type"] == "http.response.start": + status = message["status"] + headers = message.get("headers", []) + response = f"HTTP/1.1 {status} OK\r\n" + for name, value in headers: + if isinstance(name, bytes): + name = name.decode() + if isinstance(value, bytes): + value = value.decode() + response += f"{name}: {value}\r\n" + response += "\r\n" + self.transport.write(response.encode()) + elif message["type"] == "http.response.body": + body = message.get("body", b"") + self.transport.write(body) + if not message.get("more_body", False): + self.transport.close() + + await self.app(scope, receive, send) + + +# ============================================================================ +# ASGI Protocol Tests +# ============================================================================ + +class TestASGIProtocol: + """Tests for ASGIProtocol.""" + + def test_reason_phrases(self): + """Test HTTP reason phrase lookup.""" + from gunicorn.asgi.protocol import ASGIProtocol + + # Create minimal worker mock + worker = mock.Mock() + worker.cfg = Config() + worker.log = mock.Mock() + worker.asgi = mock.Mock() + + protocol = ASGIProtocol(worker) + + assert protocol._get_reason_phrase(200) == "OK" + assert protocol._get_reason_phrase(404) == "Not Found" + assert protocol._get_reason_phrase(500) == "Internal Server Error" + assert protocol._get_reason_phrase(999) == "Unknown" + + def test_scope_building(self): + """Test HTTP scope building.""" + from gunicorn.asgi.protocol import ASGIProtocol + from gunicorn.asgi.message import AsyncRequest + + worker = mock.Mock() + worker.cfg = Config() + worker.cfg.set('root_path', '/api') + worker.log = mock.Mock() + worker.asgi = mock.Mock() + + protocol = ASGIProtocol(worker) + + # Create mock request + request = mock.Mock() + request.method = "GET" + request.path = "/users" + request.query = "page=1" + request.version = (1, 1) + request.scheme = "http" + request.headers = [("HOST", "localhost"), ("ACCEPT", "text/html")] + + scope = protocol._build_http_scope( + request, + ("127.0.0.1", 8000), # sockname + ("127.0.0.1", 12345), # peername + ) + + assert scope["type"] == "http" + assert scope["method"] == "GET" + assert scope["path"] == "/users" + assert scope["query_string"] == b"page=1" + assert scope["root_path"] == "/api" + assert scope["http_version"] == "1.1" + + +# ============================================================================ +# Config Tests +# ============================================================================ + +class TestASGIConfig: + """Tests for ASGI configuration options.""" + + def test_asgi_loop_default(self): + """Test default asgi_loop value.""" + cfg = Config() + assert cfg.asgi_loop == "auto" + + def test_asgi_loop_validation(self): + """Test asgi_loop validation.""" + cfg = Config() + + cfg.set('asgi_loop', 'asyncio') + assert cfg.asgi_loop == 'asyncio' + + cfg.set('asgi_loop', 'uvloop') + assert cfg.asgi_loop == 'uvloop' + + with pytest.raises(ValueError): + cfg.set('asgi_loop', 'invalid') + + def test_asgi_lifespan_default(self): + """Test default asgi_lifespan value.""" + cfg = Config() + assert cfg.asgi_lifespan == "auto" + + def test_asgi_lifespan_validation(self): + """Test asgi_lifespan validation.""" + cfg = Config() + + cfg.set('asgi_lifespan', 'on') + assert cfg.asgi_lifespan == 'on' + + cfg.set('asgi_lifespan', 'off') + assert cfg.asgi_lifespan == 'off' + + with pytest.raises(ValueError): + cfg.set('asgi_lifespan', 'invalid') + + def test_root_path_default(self): + """Test default root_path value.""" + cfg = Config() + assert cfg.root_path == "" + + def test_root_path_setting(self): + """Test root_path configuration.""" + cfg = Config() + cfg.set('root_path', '/api/v1') + assert cfg.root_path == '/api/v1'