asgi: Add native ASGI worker with HTTP and WebSocket support

Add a new ASGI worker type that provides native async support using
gunicorn's own HTTP parsing infrastructure adapted for asyncio.

Features:
- HTTP/1.1 with keepalive support
- WebSocket connections (RFC 6455)
- ASGI lifespan protocol for startup/shutdown hooks
- Optional uvloop support for improved performance
- Full proxy protocol support (inherited from gunicorn)

New configuration options:
- --asgi-loop: Event loop selection (auto/asyncio/uvloop)
- --asgi-lifespan: Lifespan protocol control (auto/on/off)
- --root-path: ASGI root path for reverse proxy setups

Usage: gunicorn -k asgi myapp:app
This commit is contained in:
Benoit Chesneau 2026-01-22 17:05:29 +01:00
parent ea98400820
commit ae1eea8108
15 changed files with 3334 additions and 0 deletions

View File

@ -0,0 +1,7 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
ASGI example applications for gunicorn.
"""

130
examples/asgi/basic_app.py Normal file
View File

@ -0,0 +1,130 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
Basic ASGI application example.
Run with:
gunicorn -k asgi examples.asgi.basic_app:app
Test with:
curl http://127.0.0.1:8000/
curl http://127.0.0.1:8000/hello
curl -X POST http://127.0.0.1:8000/echo -d "test data"
"""
async def app(scope, receive, send):
"""Simple ASGI application demonstrating basic functionality."""
if scope["type"] == "lifespan":
await handle_lifespan(scope, receive, send)
elif scope["type"] == "http":
await handle_http(scope, receive, send)
else:
raise ValueError(f"Unknown scope type: {scope['type']}")
async def handle_lifespan(scope, receive, send):
"""Handle lifespan events (startup/shutdown)."""
while True:
message = await receive()
if message["type"] == "lifespan.startup":
print("ASGI application starting up...")
await send({"type": "lifespan.startup.complete"})
elif message["type"] == "lifespan.shutdown":
print("ASGI application shutting down...")
await send({"type": "lifespan.shutdown.complete"})
return
async def handle_http(scope, receive, send):
"""Handle HTTP requests."""
path = scope["path"]
method = scope["method"]
if path == "/" and method == "GET":
await send_response(send, 200, b"Welcome to gunicorn ASGI!\n")
elif path == "/hello" and method == "GET":
name = get_query_param(scope, "name", "World")
body = f"Hello, {name}!\n".encode()
await send_response(send, 200, body)
elif path == "/echo" and method == "POST":
body = await read_body(receive)
await send_response(send, 200, body, content_type=b"application/octet-stream")
elif path == "/headers":
headers_info = format_headers(scope["headers"])
await send_response(send, 200, headers_info.encode())
elif path == "/info":
info = format_request_info(scope)
await send_response(send, 200, info.encode(), content_type=b"application/json")
else:
await send_response(send, 404, b"Not Found\n")
async def send_response(send, status, body, content_type=b"text/plain"):
"""Send an HTTP response."""
await send({
"type": "http.response.start",
"status": status,
"headers": [
(b"content-type", content_type),
(b"content-length", str(len(body)).encode()),
],
})
await send({
"type": "http.response.body",
"body": body,
})
async def read_body(receive):
"""Read the full request body."""
body = b""
while True:
message = await receive()
body += message.get("body", b"")
if not message.get("more_body", False):
break
return body
def get_query_param(scope, name, default=None):
"""Get a query parameter value."""
query_string = scope.get("query_string", b"").decode()
for param in query_string.split("&"):
if "=" in param:
key, value = param.split("=", 1)
if key == name:
return value
return default
def format_headers(headers):
"""Format headers for display."""
lines = ["Request Headers:"]
for name, value in headers:
lines.append(f" {name.decode()}: {value.decode()}")
return "\n".join(lines) + "\n"
def format_request_info(scope):
"""Format request info as JSON."""
import json
info = {
"method": scope["method"],
"path": scope["path"],
"query_string": scope.get("query_string", b"").decode(),
"http_version": scope["http_version"],
"scheme": scope["scheme"],
"server": list(scope.get("server") or []),
"client": list(scope.get("client") or []),
"root_path": scope.get("root_path", ""),
}
return json.dumps(info, indent=2) + "\n"

View File

@ -0,0 +1,235 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
WebSocket ASGI application example.
Run with:
gunicorn -k asgi examples.asgi.websocket_app:app
Test with:
# Using websocat (install with: cargo install websocat)
websocat ws://127.0.0.1:8000/ws
# Or using Python websockets library
python -c "
import asyncio
import websockets
async def test():
async with websockets.connect('ws://127.0.0.1:8000/ws') as ws:
await ws.send('Hello')
print(await ws.recv())
asyncio.run(test())
"
"""
async def app(scope, receive, send):
"""ASGI application with WebSocket support."""
if scope["type"] == "lifespan":
await handle_lifespan(scope, receive, send)
elif scope["type"] == "http":
await handle_http(scope, receive, send)
elif scope["type"] == "websocket":
await handle_websocket(scope, receive, send)
else:
raise ValueError(f"Unknown scope type: {scope['type']}")
async def handle_lifespan(scope, receive, send):
"""Handle lifespan events."""
while True:
message = await receive()
if message["type"] == "lifespan.startup":
await send({"type": "lifespan.startup.complete"})
elif message["type"] == "lifespan.shutdown":
await send({"type": "lifespan.shutdown.complete"})
return
async def handle_http(scope, receive, send):
"""Handle HTTP requests - serve a simple HTML page for WebSocket testing."""
path = scope["path"]
if path == "/":
html = HTML_PAGE.encode()
await send({
"type": "http.response.start",
"status": 200,
"headers": [
(b"content-type", b"text/html"),
(b"content-length", str(len(html)).encode()),
],
})
await send({
"type": "http.response.body",
"body": html,
})
else:
await send({
"type": "http.response.start",
"status": 404,
"headers": [(b"content-type", b"text/plain")],
})
await send({
"type": "http.response.body",
"body": b"Not Found",
})
async def handle_websocket(scope, receive, send):
"""Handle WebSocket connections."""
path = scope["path"]
if path == "/ws":
await echo_websocket(scope, receive, send)
elif path == "/ws/chat":
await chat_websocket(scope, receive, send)
else:
# Reject the connection
await send({"type": "websocket.close", "code": 4004})
async def echo_websocket(scope, receive, send):
"""Echo WebSocket - sends back whatever it receives."""
# Wait for connection
message = await receive()
if message["type"] != "websocket.connect":
return
# Accept the connection
await send({"type": "websocket.accept"})
# Echo loop
try:
while True:
message = await receive()
if message["type"] == "websocket.disconnect":
break
if message["type"] == "websocket.receive":
if "text" in message:
# Echo text back
await send({
"type": "websocket.send",
"text": f"Echo: {message['text']}"
})
elif "bytes" in message:
# Echo bytes back
await send({
"type": "websocket.send",
"bytes": message["bytes"]
})
except Exception as e:
print(f"WebSocket error: {e}")
finally:
try:
await send({"type": "websocket.close", "code": 1000})
except Exception:
pass
async def chat_websocket(scope, receive, send):
"""Chat WebSocket - simple broadcast example."""
message = await receive()
if message["type"] != "websocket.connect":
return
await send({
"type": "websocket.accept",
"subprotocol": "chat"
})
await send({
"type": "websocket.send",
"text": "Welcome to the chat! Send messages and they will be echoed back."
})
try:
while True:
message = await receive()
if message["type"] == "websocket.disconnect":
break
if message["type"] == "websocket.receive" and "text" in message:
text = message["text"]
await send({
"type": "websocket.send",
"text": f"[You]: {text}"
})
except Exception:
pass
HTML_PAGE = """<!DOCTYPE html>
<html>
<head>
<title>WebSocket Test</title>
<style>
body { font-family: sans-serif; max-width: 800px; margin: 50px auto; padding: 20px; }
#messages { border: 1px solid #ccc; height: 300px; overflow-y: auto; padding: 10px; margin-bottom: 10px; }
#input { width: 80%; padding: 10px; }
button { padding: 10px 20px; }
.sent { color: blue; }
.received { color: green; }
.error { color: red; }
</style>
</head>
<body>
<h1>WebSocket Test</h1>
<div id="messages"></div>
<input type="text" id="input" placeholder="Type a message...">
<button onclick="sendMessage()">Send</button>
<button onclick="connectWS()">Connect</button>
<button onclick="disconnectWS()">Disconnect</button>
<script>
let ws = null;
const messages = document.getElementById('messages');
const input = document.getElementById('input');
function log(msg, className) {
const div = document.createElement('div');
div.className = className || '';
div.textContent = msg;
messages.appendChild(div);
messages.scrollTop = messages.scrollHeight;
}
function connectWS() {
if (ws) {
log('Already connected', 'error');
return;
}
ws = new WebSocket('ws://' + window.location.host + '/ws');
ws.onopen = () => log('Connected!', 'received');
ws.onclose = () => { log('Disconnected', 'error'); ws = null; };
ws.onerror = (e) => log('Error: ' + e, 'error');
ws.onmessage = (e) => log(e.data, 'received');
}
function disconnectWS() {
if (ws) ws.close();
}
function sendMessage() {
if (!ws) { log('Not connected', 'error'); return; }
const msg = input.value;
if (!msg) return;
ws.send(msg);
log('Sent: ' + msg, 'sent');
input.value = '';
}
input.onkeypress = (e) => { if (e.key === 'Enter') sendMessage(); };
// Auto-connect
connectWS();
</script>
</body>
</html>
"""

26
gunicorn/asgi/__init__.py Normal file
View File

@ -0,0 +1,26 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
ASGI support for gunicorn.
This module provides native ASGI worker support, using gunicorn's own
HTTP parsing infrastructure adapted for async I/O.
Components:
- AsyncUnreader: Async socket reading with pushback buffer
- AsyncRequest: Async HTTP request parser
- ASGIProtocol: asyncio.Protocol implementation for HTTP handling
- WebSocketProtocol: WebSocket protocol handler (RFC 6455)
- LifespanManager: ASGI lifespan protocol support
Usage:
gunicorn -k asgi myapp:app
"""
from gunicorn.asgi.unreader import AsyncUnreader
from gunicorn.asgi.message import AsyncRequest
from gunicorn.asgi.lifespan import LifespanManager
__all__ = ['AsyncUnreader', 'AsyncRequest', 'LifespanManager']

178
gunicorn/asgi/lifespan.py Normal file
View File

@ -0,0 +1,178 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
ASGI lifespan protocol manager.
Manages startup and shutdown events for ASGI applications,
enabling frameworks like FastAPI to run initialization and
cleanup code.
"""
import asyncio
class LifespanManager:
"""Manages ASGI lifespan events (startup/shutdown).
The lifespan protocol allows ASGI applications to run code at
startup and shutdown. This is essential for applications that
need to initialize database connections, caches, or other
resources.
ASGI lifespan messages:
- Server sends: {"type": "lifespan.startup"}
- App responds: {"type": "lifespan.startup.complete"} or
{"type": "lifespan.startup.failed", "message": "..."}
- Server sends: {"type": "lifespan.shutdown"}
- App responds: {"type": "lifespan.shutdown.complete"}
"""
def __init__(self, app, logger, state=None):
"""Initialize the lifespan manager.
Args:
app: ASGI application callable
logger: Logger instance
state: Shared state dict for the application
"""
self.app = app
self.logger = logger
self.state = state if state is not None else {}
self._startup_complete = asyncio.Event()
self._shutdown_complete = asyncio.Event()
self._startup_failed = False
self._startup_error = None
self._shutdown_error = None
self._receive_queue = asyncio.Queue()
self._task = None
self._app_finished = False
async def startup(self):
"""Run lifespan startup and wait for completion.
Raises:
RuntimeError: If startup fails or app doesn't support lifespan
"""
scope = {
"type": "lifespan",
"asgi": {"version": "3.0", "spec_version": "2.4"},
"state": self.state,
}
# Send startup event
await self._receive_queue.put({"type": "lifespan.startup"})
# Run lifespan in background task
self._task = asyncio.create_task(self._run_lifespan(scope))
# Wait for startup with timeout
try:
await asyncio.wait_for(
self._startup_complete.wait(),
timeout=30.0 # Reasonable startup timeout
)
except asyncio.TimeoutError:
if self._task:
self._task.cancel()
raise RuntimeError("Lifespan startup timed out")
if self._startup_failed:
if self._task:
self._task.cancel()
msg = self._startup_error or "Unknown error"
raise RuntimeError(f"Lifespan startup failed: {msg}")
self.logger.debug("ASGI lifespan startup complete")
async def shutdown(self):
"""Signal shutdown and wait for completion.
This should be called during graceful shutdown.
"""
if self._app_finished:
self.logger.debug("ASGI lifespan already finished")
return
# Send shutdown event
await self._receive_queue.put({"type": "lifespan.shutdown"})
# Wait for shutdown with timeout
try:
await asyncio.wait_for(
self._shutdown_complete.wait(),
timeout=30.0 # Reasonable shutdown timeout
)
except asyncio.TimeoutError:
self.logger.warning("Lifespan shutdown timed out")
if self._shutdown_error:
self.logger.error("Lifespan shutdown error: %s", self._shutdown_error)
# Cancel the task if still running
if self._task and not self._task.done():
self._task.cancel()
try:
await self._task
except asyncio.CancelledError:
pass
self.logger.debug("ASGI lifespan shutdown complete")
async def _run_lifespan(self, scope):
"""Run the ASGI lifespan protocol."""
try:
await self.app(scope, self._receive, self._send)
except asyncio.CancelledError:
raise
except Exception as e:
self.logger.debug("Lifespan application raised: %s", e)
# If startup hasn't completed, mark it as failed
if not self._startup_complete.is_set():
self._startup_failed = True
self._startup_error = str(e)
self._startup_complete.set()
# If shutdown hasn't completed, mark error
elif not self._shutdown_complete.is_set():
self._shutdown_error = str(e)
self._shutdown_complete.set()
finally:
self._app_finished = True
# Ensure events are set to unblock waiters
if not self._startup_complete.is_set():
self._startup_failed = True
self._startup_error = "Application exited before startup complete"
self._startup_complete.set()
if not self._shutdown_complete.is_set():
self._shutdown_complete.set()
async def _receive(self):
"""ASGI receive callable for lifespan."""
return await self._receive_queue.get()
async def _send(self, message):
"""ASGI send callable for lifespan."""
msg_type = message["type"]
if msg_type == "lifespan.startup.complete":
self._startup_complete.set()
self.logger.debug("Received lifespan.startup.complete")
elif msg_type == "lifespan.startup.failed":
self._startup_failed = True
self._startup_error = message.get("message", "")
self._startup_complete.set()
self.logger.debug("Received lifespan.startup.failed: %s",
self._startup_error)
elif msg_type == "lifespan.shutdown.complete":
self._shutdown_complete.set()
self.logger.debug("Received lifespan.shutdown.complete")
elif msg_type == "lifespan.shutdown.failed":
self._shutdown_error = message.get("message", "")
self._shutdown_complete.set()
self.logger.debug("Received lifespan.shutdown.failed: %s",
self._shutdown_error)

562
gunicorn/asgi/message.py Normal file
View File

@ -0,0 +1,562 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
Async version of gunicorn/http/message.py for ASGI workers.
Reuses the parsing logic from the sync version, adapted for async I/O.
"""
import io
import re
import socket
from gunicorn.http.errors import (
InvalidHeader, InvalidHeaderName, NoMoreData,
InvalidRequestLine, InvalidRequestMethod, InvalidHTTPVersion,
LimitRequestLine, LimitRequestHeaders,
UnsupportedTransferCoding, ObsoleteFolding,
InvalidProxyLine, ForbiddenProxyRequest,
InvalidSchemeHeaders,
)
from gunicorn.util import bytes_to_str, split_request_uri
MAX_REQUEST_LINE = 8190
MAX_HEADERS = 32768
DEFAULT_MAX_HEADERFIELD_SIZE = 8190
# Reuse regex patterns from sync version
RFC9110_5_6_2_TOKEN_SPECIALS = r"!#$%&'*+-.^_`|~"
TOKEN_RE = re.compile(r"[%s0-9a-zA-Z]+" % (re.escape(RFC9110_5_6_2_TOKEN_SPECIALS)))
METHOD_BADCHAR_RE = re.compile("[a-z#]")
VERSION_RE = re.compile(r"HTTP/(\d)\.(\d)")
RFC9110_5_5_INVALID_AND_DANGEROUS = re.compile(r"[\0\r\n]")
class AsyncRequest:
"""Async HTTP request parser.
Parses HTTP/1.x requests using async I/O, reusing gunicorn's
parsing logic where possible.
"""
def __init__(self, cfg, unreader, peer_addr, req_number=1):
self.cfg = cfg
self.unreader = unreader
self.peer_addr = peer_addr
self.remote_addr = peer_addr
self.req_number = req_number
self.version = None
self.method = None
self.uri = None
self.path = None
self.query = None
self.fragment = None
self.headers = []
self.trailers = []
self.scheme = "https" if cfg.is_ssl else "http"
self.must_close = False
self.proxy_protocol_info = None
# Request line limit
self.limit_request_line = cfg.limit_request_line
if (self.limit_request_line < 0
or self.limit_request_line >= MAX_REQUEST_LINE):
self.limit_request_line = MAX_REQUEST_LINE
# Headers limits
self.limit_request_fields = cfg.limit_request_fields
if (self.limit_request_fields <= 0
or self.limit_request_fields > MAX_HEADERS):
self.limit_request_fields = MAX_HEADERS
self.limit_request_field_size = cfg.limit_request_field_size
if self.limit_request_field_size < 0:
self.limit_request_field_size = DEFAULT_MAX_HEADERFIELD_SIZE
# Max header buffer size
max_header_field_size = self.limit_request_field_size or DEFAULT_MAX_HEADERFIELD_SIZE
self.max_buffer_headers = self.limit_request_fields * \
(max_header_field_size + 2) + 4
# Body-related state
self.content_length = None
self.chunked = False
self._body_reader = None
self._body_remaining = 0
@classmethod
async def parse(cls, cfg, unreader, peer_addr, req_number=1):
"""Parse an HTTP request from the stream.
Args:
cfg: gunicorn config object
unreader: AsyncUnreader instance
peer_addr: client address tuple
req_number: request number on this connection (for keepalive)
Returns:
AsyncRequest: Parsed request object
Raises:
NoMoreData: If no data available
Various parsing errors for malformed requests
"""
req = cls(cfg, unreader, peer_addr, req_number)
await req._parse()
return req
async def _parse(self):
"""Parse the request from the unreader."""
buf = io.BytesIO()
await self._get_data(buf, stop=True)
# Get request line
line, rbuf = await self._read_line(buf, self.limit_request_line)
# Proxy protocol
if self._proxy_protocol(bytes_to_str(line)):
# Get next request line
buf = io.BytesIO()
buf.write(rbuf)
line, rbuf = await self._read_line(buf, self.limit_request_line)
self._parse_request_line(line)
buf = io.BytesIO()
buf.write(rbuf)
# Headers
data = buf.getvalue()
while True:
idx = data.find(b"\r\n\r\n")
done = data[:2] == b"\r\n"
if idx < 0 and not done:
await self._get_data(buf)
data = buf.getvalue()
if len(data) > self.max_buffer_headers:
raise LimitRequestHeaders("max buffer headers")
else:
break
if done:
self.unreader.unread(data[2:])
else:
self.headers = self._parse_headers(data[:idx], from_trailer=False)
self.unreader.unread(data[idx + 4:])
self._set_body_reader()
async def _get_data(self, buf, stop=False):
"""Read data from unreader into buffer."""
data = await self.unreader.read()
if not data:
if stop:
raise StopIteration()
raise NoMoreData(buf.getvalue())
buf.write(data)
async def _read_line(self, buf, limit=0):
"""Read a line from the buffer/stream."""
data = buf.getvalue()
while True:
idx = data.find(b"\r\n")
if idx >= 0:
if idx > limit > 0:
raise LimitRequestLine(idx, limit)
break
if len(data) - 2 > limit > 0:
raise LimitRequestLine(len(data), limit)
await self._get_data(buf)
data = buf.getvalue()
return (data[:idx], data[idx + 2:])
def _proxy_protocol(self, line):
"""Detect, check and parse proxy protocol."""
if not self.cfg.proxy_protocol:
return False
if self.req_number != 1:
return False
if not line.startswith("PROXY"):
return False
self._proxy_protocol_access_check()
self._parse_proxy_protocol(line)
return True
def _proxy_protocol_access_check(self):
"""Check if proxy protocol is allowed from this peer."""
if ("*" not in self.cfg.proxy_allow_ips and
isinstance(self.peer_addr, tuple) and
self.peer_addr[0] not in self.cfg.proxy_allow_ips):
raise ForbiddenProxyRequest(self.peer_addr[0])
def _parse_proxy_protocol(self, line):
"""Parse proxy protocol header line."""
bits = line.split(" ")
if len(bits) != 6:
raise InvalidProxyLine(line)
proto = bits[1]
s_addr = bits[2]
d_addr = bits[3]
if proto not in ["TCP4", "TCP6"]:
raise InvalidProxyLine("protocol '%s' not supported" % proto)
if proto == "TCP4":
try:
socket.inet_pton(socket.AF_INET, s_addr)
socket.inet_pton(socket.AF_INET, d_addr)
except OSError:
raise InvalidProxyLine(line)
elif proto == "TCP6":
try:
socket.inet_pton(socket.AF_INET6, s_addr)
socket.inet_pton(socket.AF_INET6, d_addr)
except OSError:
raise InvalidProxyLine(line)
try:
s_port = int(bits[4])
d_port = int(bits[5])
except ValueError:
raise InvalidProxyLine("invalid port %s" % line)
if not ((0 <= s_port <= 65535) and (0 <= d_port <= 65535)):
raise InvalidProxyLine("invalid port %s" % line)
self.proxy_protocol_info = {
"proxy_protocol": proto,
"client_addr": s_addr,
"client_port": s_port,
"proxy_addr": d_addr,
"proxy_port": d_port
}
def _parse_request_line(self, line_bytes):
"""Parse the HTTP request line."""
bits = [bytes_to_str(bit) for bit in line_bytes.split(b" ", 2)]
if len(bits) != 3:
raise InvalidRequestLine(bytes_to_str(line_bytes))
# Method
self.method = bits[0]
if not self.cfg.permit_unconventional_http_method:
if METHOD_BADCHAR_RE.search(self.method):
raise InvalidRequestMethod(self.method)
if not 3 <= len(bits[0]) <= 20:
raise InvalidRequestMethod(self.method)
if not TOKEN_RE.fullmatch(self.method):
raise InvalidRequestMethod(self.method)
if self.cfg.casefold_http_method:
self.method = self.method.upper()
# URI
self.uri = bits[1]
if len(self.uri) == 0:
raise InvalidRequestLine(bytes_to_str(line_bytes))
try:
parts = split_request_uri(self.uri)
except ValueError:
raise InvalidRequestLine(bytes_to_str(line_bytes))
self.path = parts.path or ""
self.query = parts.query or ""
self.fragment = parts.fragment or ""
# Version
match = VERSION_RE.fullmatch(bits[2])
if match is None:
raise InvalidHTTPVersion(bits[2])
self.version = (int(match.group(1)), int(match.group(2)))
if not (1, 0) <= self.version < (2, 0):
if not self.cfg.permit_unconventional_http_version:
raise InvalidHTTPVersion(self.version)
def _parse_headers(self, data, from_trailer=False):
"""Parse HTTP headers from raw data."""
cfg = self.cfg
headers = []
lines = [bytes_to_str(line) for line in data.split(b"\r\n")]
# Handle scheme headers
scheme_header = False
secure_scheme_headers = {}
forwarder_headers = []
if from_trailer:
pass
elif ('*' in cfg.forwarded_allow_ips or
not isinstance(self.peer_addr, tuple)
or self.peer_addr[0] in cfg.forwarded_allow_ips):
secure_scheme_headers = cfg.secure_scheme_headers
forwarder_headers = cfg.forwarder_headers
while lines:
if len(headers) >= self.limit_request_fields:
raise LimitRequestHeaders("limit request headers fields")
curr = lines.pop(0)
header_length = len(curr) + len("\r\n")
if curr.find(":") <= 0:
raise InvalidHeader(curr)
name, value = curr.split(":", 1)
if self.cfg.strip_header_spaces:
name = name.rstrip(" \t")
if not TOKEN_RE.fullmatch(name):
raise InvalidHeaderName(name)
name = name.upper()
value = [value.strip(" \t")]
# Consume value continuation lines
while lines and lines[0].startswith((" ", "\t")):
if not self.cfg.permit_obsolete_folding:
raise ObsoleteFolding(name)
curr = lines.pop(0)
header_length += len(curr) + len("\r\n")
if header_length > self.limit_request_field_size > 0:
raise LimitRequestHeaders("limit request headers fields size")
value.append(curr.strip("\t "))
value = " ".join(value)
if RFC9110_5_5_INVALID_AND_DANGEROUS.search(value):
raise InvalidHeader(name)
if header_length > self.limit_request_field_size > 0:
raise LimitRequestHeaders("limit request headers fields size")
if name in secure_scheme_headers:
secure = value == secure_scheme_headers[name]
scheme = "https" if secure else "http"
if scheme_header:
if scheme != self.scheme:
raise InvalidSchemeHeaders()
else:
scheme_header = True
self.scheme = scheme
if "_" in name:
if name in forwarder_headers or "*" in forwarder_headers:
pass
elif self.cfg.header_map == "dangerous":
pass
elif self.cfg.header_map == "drop":
continue
else:
raise InvalidHeaderName(name)
headers.append((name, value))
return headers
def _set_body_reader(self):
"""Determine how to read the request body."""
chunked = False
content_length = None
for (name, value) in self.headers:
if name == "CONTENT-LENGTH":
if content_length is not None:
raise InvalidHeader("CONTENT-LENGTH", req=self)
content_length = value
elif name == "TRANSFER-ENCODING":
vals = [v.strip() for v in value.split(',')]
for val in vals:
if val.lower() == "chunked":
if chunked:
raise InvalidHeader("TRANSFER-ENCODING", req=self)
chunked = True
elif val.lower() == "identity":
if chunked:
raise InvalidHeader("TRANSFER-ENCODING", req=self)
elif val.lower() in ('compress', 'deflate', 'gzip'):
if chunked:
raise InvalidHeader("TRANSFER-ENCODING", req=self)
self.force_close()
else:
raise UnsupportedTransferCoding(value)
if chunked:
if self.version < (1, 1):
raise InvalidHeader("TRANSFER-ENCODING", req=self)
if content_length is not None:
raise InvalidHeader("CONTENT-LENGTH", req=self)
self.chunked = True
self.content_length = None
self._body_remaining = -1
elif content_length is not None:
try:
if str(content_length).isnumeric():
content_length = int(content_length)
else:
raise InvalidHeader("CONTENT-LENGTH", req=self)
except ValueError:
raise InvalidHeader("CONTENT-LENGTH", req=self)
if content_length < 0:
raise InvalidHeader("CONTENT-LENGTH", req=self)
self.content_length = content_length
self._body_remaining = content_length
else:
# No body for requests without Content-Length or Transfer-Encoding
self.content_length = 0
self._body_remaining = 0
def force_close(self):
"""Mark connection for closing after this request."""
self.must_close = True
def should_close(self):
"""Check if connection should be closed after this request."""
if self.must_close:
return True
for (h, v) in self.headers:
if h == "CONNECTION":
v = v.lower().strip(" \t")
if v == "close":
return True
elif v == "keep-alive":
return False
break
return self.version <= (1, 0)
def get_header(self, name):
"""Get a header value by name (case-insensitive)."""
name = name.upper()
for (h, v) in self.headers:
if h == name:
return v
return None
async def read_body(self, size=8192):
"""Read a chunk of the request body.
Args:
size: Maximum bytes to read
Returns:
bytes: Body data, empty bytes when body is exhausted
"""
if self._body_remaining == 0:
return b""
if self.chunked:
return await self._read_chunked_body(size)
else:
return await self._read_length_body(size)
async def _read_length_body(self, size):
"""Read from a length-delimited body."""
if self._body_remaining <= 0:
return b""
to_read = min(size, self._body_remaining)
data = await self.unreader.read(to_read)
if data:
self._body_remaining -= len(data)
return data
async def _read_chunked_body(self, size):
"""Read from a chunked body."""
if self._body_reader is None:
self._body_reader = self._chunked_body_reader()
try:
return await self._body_reader.__anext__()
except StopAsyncIteration:
self._body_remaining = 0
return b""
async def _chunked_body_reader(self):
"""Async generator for reading chunked body."""
while True:
# Read chunk size line
size_line = await self._read_chunk_size_line()
# Parse chunk size (handle extensions)
chunk_size, *_ = size_line.split(b";", 1)
if _ :
chunk_size = chunk_size.rstrip(b" \t")
if any(n not in b"0123456789abcdefABCDEF" for n in chunk_size):
raise InvalidHeader("Invalid chunk size")
if len(chunk_size) == 0:
raise InvalidHeader("Invalid chunk size")
chunk_size = int(chunk_size, 16)
if chunk_size == 0:
# Final chunk - skip trailers and final CRLF
await self._skip_trailers()
return
# Read chunk data
remaining = chunk_size
while remaining > 0:
data = await self.unreader.read(min(remaining, 8192))
if not data:
raise NoMoreData()
remaining -= len(data)
yield data
# Skip chunk terminating CRLF
crlf = await self.unreader.read(2)
if crlf != b"\r\n":
# May have partial read, try to get the rest
while len(crlf) < 2:
more = await self.unreader.read(2 - len(crlf))
if not more:
break
crlf += more
if crlf != b"\r\n":
raise InvalidHeader("Missing chunk terminator")
async def _read_chunk_size_line(self):
"""Read a chunk size line."""
buf = io.BytesIO()
while True:
data = await self.unreader.read(1)
if not data:
raise NoMoreData()
buf.write(data)
if buf.getvalue().endswith(b"\r\n"):
return buf.getvalue()[:-2]
async def _skip_trailers(self):
"""Skip trailer headers after chunked body."""
buf = io.BytesIO()
while True:
data = await self.unreader.read(1)
if not data:
return
buf.write(data)
content = buf.getvalue()
if content.endswith(b"\r\n\r\n"):
# Could parse trailers here if needed
return
if content == b"\r\n":
return
async def drain_body(self):
"""Drain any unread body data.
Should be called before reusing connection for keepalive.
"""
while True:
data = await self.read_body(8192)
if not data:
break

424
gunicorn/asgi/protocol.py Normal file
View File

@ -0,0 +1,424 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
ASGI protocol handler for gunicorn.
Implements asyncio.Protocol to handle HTTP/1.x connections and dispatch
to ASGI applications.
"""
import asyncio
import base64
import hashlib
import traceback
from datetime import datetime
from gunicorn.asgi.unreader import AsyncUnreader
from gunicorn.asgi.message import AsyncRequest
from gunicorn.http.errors import NoMoreData
class ASGIProtocol(asyncio.Protocol):
"""HTTP/1.1 protocol handler for ASGI applications.
Handles connection lifecycle, request parsing, and ASGI app invocation.
"""
def __init__(self, worker):
self.worker = worker
self.cfg = worker.cfg
self.log = worker.log
self.app = worker.asgi
self.transport = None
self.reader = None
self.writer = None
self._task = None
self.req_count = 0
# Connection state
self._closed = False
def connection_made(self, transport):
"""Called when a connection is established."""
self.transport = transport
self.worker.nr_conns += 1
# Create stream reader/writer
self.reader = asyncio.StreamReader()
self.writer = transport
# Start handling requests
self._task = self.worker.loop.create_task(self._handle_connection())
def data_received(self, data):
"""Called when data is received on the connection."""
if self.reader:
self.reader.feed_data(data)
def connection_lost(self, exc):
"""Called when the connection is lost or closed."""
self._closed = True
self.worker.nr_conns -= 1
if self.reader:
self.reader.feed_eof()
if self._task and not self._task.done():
self._task.cancel()
async def _handle_connection(self):
"""Main request handling loop for this connection."""
unreader = AsyncUnreader(self.reader)
try:
peername = self.transport.get_extra_info('peername')
sockname = self.transport.get_extra_info('sockname')
while not self._closed:
self.req_count += 1
try:
# Parse HTTP request
request = await AsyncRequest.parse(
self.cfg,
unreader,
peername,
self.req_count
)
except StopIteration:
# No more data, close connection
break
except NoMoreData:
# Client disconnected
break
# Check for WebSocket upgrade
if self._is_websocket_upgrade(request):
await self._handle_websocket(request, sockname, peername)
break # WebSocket takes over the connection
else:
# Handle HTTP request
keepalive = await self._handle_http_request(
request, sockname, peername
)
# Increment worker request count
self.worker.nr += 1
# Check max_requests
if self.worker.nr >= self.worker.max_requests:
self.log.info("Autorestarting worker after current request.")
self.worker.alive = False
keepalive = False
if not keepalive or not self.worker.alive:
break
# Check connection limits for keepalive
if not self.cfg.keepalive:
break
# Drain any unread body before next request
await request.drain_body()
except asyncio.CancelledError:
pass
except Exception as e:
self.log.exception("Error handling connection: %s", e)
finally:
self._close_transport()
def _is_websocket_upgrade(self, request):
"""Check if request is a WebSocket upgrade."""
upgrade = None
connection = None
for name, value in request.headers:
if name == "UPGRADE":
upgrade = value.lower()
elif name == "CONNECTION":
connection = value.lower()
return upgrade == "websocket" and connection and "upgrade" in connection
async def _handle_websocket(self, request, sockname, peername):
"""Handle WebSocket upgrade request."""
from gunicorn.asgi.websocket import WebSocketProtocol
scope = self._build_websocket_scope(request, sockname, peername)
ws_protocol = WebSocketProtocol(
self.transport, self.reader, scope, self.app, self.log
)
await ws_protocol.run()
async def _handle_http_request(self, request, sockname, peername):
"""Handle a single HTTP request."""
scope = self._build_http_scope(request, sockname, peername)
response_started = False
response_complete = False
body_parts = []
exc_to_raise = None
# Receive queue for body
receive_queue = asyncio.Queue()
# Pre-populate with initial body state
if request.content_length == 0 and not request.chunked:
await receive_queue.put({
"type": "http.request",
"body": b"",
"more_body": False,
})
else:
# Start body reading task
asyncio.create_task(self._read_body_to_queue(request, receive_queue))
async def receive():
return await receive_queue.get()
async def send(message):
nonlocal response_started, response_complete, exc_to_raise
msg_type = message["type"]
if msg_type == "http.response.start":
if response_started:
exc_to_raise = RuntimeError("Response already started")
return
response_started = True
status = message["status"]
headers = message.get("headers", [])
await self._send_response_start(status, headers, request)
elif msg_type == "http.response.body":
if not response_started:
exc_to_raise = RuntimeError("Response not started")
return
if response_complete:
exc_to_raise = RuntimeError("Response already complete")
return
body = message.get("body", b"")
more_body = message.get("more_body", False)
if body:
await self._send_body(body)
if not more_body:
response_complete = True
try:
request_start = datetime.now()
self.cfg.pre_request(self.worker, request)
await self.app(scope, receive, send)
if exc_to_raise:
raise exc_to_raise
# Ensure response was sent
if not response_started:
await self._send_error_response(500, "Internal Server Error")
except Exception as e:
self.log.exception("Error in ASGI application")
if not response_started:
await self._send_error_response(500, "Internal Server Error")
return False
finally:
try:
request_time = datetime.now() - request_start
self.cfg.post_request(self.worker, request, {}, None)
except Exception:
self.log.exception("Exception in post_request hook")
# Determine keepalive
if request.should_close():
return False
return self.worker.alive and self.cfg.keepalive
async def _read_body_to_queue(self, request, queue):
"""Read request body and put chunks on the queue."""
try:
while True:
chunk = await request.read_body(65536)
if chunk:
await queue.put({
"type": "http.request",
"body": chunk,
"more_body": True,
})
else:
await queue.put({
"type": "http.request",
"body": b"",
"more_body": False,
})
break
except Exception as e:
self.log.debug("Error reading body: %s", e)
await queue.put({
"type": "http.request",
"body": b"",
"more_body": False,
})
def _build_http_scope(self, request, sockname, peername):
"""Build ASGI HTTP scope from parsed request."""
# Build headers list as bytes tuples
headers = []
for name, value in request.headers:
headers.append((name.lower().encode("latin-1"), value.encode("latin-1")))
scope = {
"type": "http",
"asgi": {"version": "3.0", "spec_version": "2.4"},
"http_version": f"{request.version[0]}.{request.version[1]}",
"method": request.method,
"scheme": request.scheme,
"path": request.path,
"raw_path": request.path.encode("latin-1") if request.path else b"",
"query_string": request.query.encode("latin-1") if request.query else b"",
"root_path": self.cfg.root_path or "",
"headers": headers,
"server": sockname if sockname else None,
"client": peername if peername else None,
}
# Add state dict for lifespan sharing
if hasattr(self.worker, 'state'):
scope["state"] = self.worker.state
return scope
def _build_websocket_scope(self, request, sockname, peername):
"""Build ASGI WebSocket scope from parsed request."""
# Build headers list as bytes tuples
headers = []
for name, value in request.headers:
headers.append((name.lower().encode("latin-1"), value.encode("latin-1")))
# Extract subprotocols from Sec-WebSocket-Protocol header
subprotocols = []
for name, value in request.headers:
if name == "SEC-WEBSOCKET-PROTOCOL":
subprotocols = [s.strip() for s in value.split(",")]
break
scope = {
"type": "websocket",
"asgi": {"version": "3.0", "spec_version": "2.4"},
"http_version": f"{request.version[0]}.{request.version[1]}",
"scheme": "wss" if request.scheme == "https" else "ws",
"path": request.path,
"raw_path": request.path.encode("latin-1") if request.path else b"",
"query_string": request.query.encode("latin-1") if request.query else b"",
"root_path": self.cfg.root_path or "",
"headers": headers,
"server": sockname if sockname else None,
"client": peername if peername else None,
"subprotocols": subprotocols,
}
# Add state dict for lifespan sharing
if hasattr(self.worker, 'state'):
scope["state"] = self.worker.state
return scope
async def _send_response_start(self, status, headers, request):
"""Send HTTP response status and headers."""
# Build status line
reason = self._get_reason_phrase(status)
status_line = f"HTTP/{request.version[0]}.{request.version[1]} {status} {reason}\r\n"
# Build headers
header_lines = []
has_content_length = False
has_transfer_encoding = False
has_connection = False
for name, value in headers:
if isinstance(name, bytes):
name = name.decode("latin-1")
if isinstance(value, bytes):
value = value.decode("latin-1")
header_lines.append(f"{name}: {value}\r\n")
name_lower = name.lower()
if name_lower == "content-length":
has_content_length = True
elif name_lower == "transfer-encoding":
has_transfer_encoding = True
elif name_lower == "connection":
has_connection = True
# Add server header if not present
header_lines.append("Server: gunicorn/asgi\r\n")
response = status_line + "".join(header_lines) + "\r\n"
self.transport.write(response.encode("latin-1"))
async def _send_body(self, body):
"""Send response body chunk."""
if body:
self.transport.write(body)
async def _send_error_response(self, status, message):
"""Send an error response."""
body = message.encode("utf-8")
response = (
f"HTTP/1.1 {status} {message}\r\n"
f"Content-Type: text/plain\r\n"
f"Content-Length: {len(body)}\r\n"
f"Connection: close\r\n"
f"\r\n"
)
self.transport.write(response.encode("latin-1"))
self.transport.write(body)
def _get_reason_phrase(self, status):
"""Get HTTP reason phrase for status code."""
reasons = {
100: "Continue",
101: "Switching Protocols",
200: "OK",
201: "Created",
202: "Accepted",
204: "No Content",
206: "Partial Content",
301: "Moved Permanently",
302: "Found",
303: "See Other",
304: "Not Modified",
307: "Temporary Redirect",
308: "Permanent Redirect",
400: "Bad Request",
401: "Unauthorized",
403: "Forbidden",
404: "Not Found",
405: "Method Not Allowed",
408: "Request Timeout",
409: "Conflict",
410: "Gone",
411: "Length Required",
413: "Payload Too Large",
414: "URI Too Long",
415: "Unsupported Media Type",
422: "Unprocessable Entity",
429: "Too Many Requests",
500: "Internal Server Error",
501: "Not Implemented",
502: "Bad Gateway",
503: "Service Unavailable",
504: "Gateway Timeout",
}
return reasons.get(status, "Unknown")
def _close_transport(self):
"""Close the transport safely."""
if self.transport and not self._closed:
try:
self.transport.close()
except Exception:
pass
self._closed = True

100
gunicorn/asgi/unreader.py Normal file
View File

@ -0,0 +1,100 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
Async version of gunicorn/http/unreader.py for ASGI workers.
Provides async reading with pushback buffer support.
"""
import io
class AsyncUnreader:
"""Async socket reader with pushback buffer support.
This class wraps an asyncio StreamReader and provides the ability
to "unread" data back into a buffer for re-parsing.
"""
def __init__(self, reader, max_chunk=8192):
"""Initialize the async unreader.
Args:
reader: asyncio.StreamReader instance
max_chunk: Maximum bytes to read at once
"""
self.reader = reader
self.buf = io.BytesIO()
self.max_chunk = max_chunk
async def read(self, size=None):
"""Read data from the stream, using buffered data first.
Args:
size: Number of bytes to read. If None, returns all buffered
data or reads a single chunk.
Returns:
bytes: Data read from buffer or stream
"""
if size is not None and not isinstance(size, int):
raise TypeError("size parameter must be an int or long.")
if size is not None:
if size == 0:
return b""
if size < 0:
size = None
# Move to end to check buffer size
self.buf.seek(0, io.SEEK_END)
# If no size specified, return buffered data or read chunk
if size is None and self.buf.tell():
ret = self.buf.getvalue()
self.buf = io.BytesIO()
return ret
if size is None:
chunk = await self._read_chunk()
return chunk
# Read until we have enough data
while self.buf.tell() < size:
chunk = await self._read_chunk()
if not chunk:
ret = self.buf.getvalue()
self.buf = io.BytesIO()
return ret
self.buf.write(chunk)
data = self.buf.getvalue()
self.buf = io.BytesIO()
self.buf.write(data[size:])
return data[:size]
async def _read_chunk(self):
"""Read a chunk of data from the underlying stream."""
try:
return await self.reader.read(self.max_chunk)
except Exception:
return b""
def unread(self, data):
"""Push data back into the buffer for re-reading.
Args:
data: bytes to push back
"""
if data:
self.buf.seek(0, io.SEEK_END)
self.buf.write(data)
def has_buffered_data(self):
"""Check if there's data in the pushback buffer."""
pos = self.buf.tell()
self.buf.seek(0, io.SEEK_END)
has_data = self.buf.tell() > 0
self.buf.seek(pos)
return has_data

369
gunicorn/asgi/websocket.py Normal file
View File

@ -0,0 +1,369 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
WebSocket protocol handler for ASGI.
Implements RFC 6455 WebSocket protocol for ASGI applications.
"""
import asyncio
import base64
import hashlib
import struct
import os
# WebSocket frame opcodes
OPCODE_CONTINUATION = 0x0
OPCODE_TEXT = 0x1
OPCODE_BINARY = 0x2
OPCODE_CLOSE = 0x8
OPCODE_PING = 0x9
OPCODE_PONG = 0xA
# WebSocket close codes
CLOSE_NORMAL = 1000
CLOSE_GOING_AWAY = 1001
CLOSE_PROTOCOL_ERROR = 1002
CLOSE_UNSUPPORTED = 1003
CLOSE_NO_STATUS = 1005
CLOSE_ABNORMAL = 1006
CLOSE_INVALID_DATA = 1007
CLOSE_POLICY_VIOLATION = 1008
CLOSE_MESSAGE_TOO_BIG = 1009
CLOSE_MANDATORY_EXT = 1010
CLOSE_INTERNAL_ERROR = 1011
# WebSocket handshake GUID (RFC 6455)
WS_GUID = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
class WebSocketProtocol:
"""WebSocket connection handler for ASGI applications."""
def __init__(self, transport, reader, scope, app, log):
"""Initialize WebSocket protocol handler.
Args:
transport: asyncio transport for writing
reader: asyncio StreamReader for reading
scope: ASGI WebSocket scope dict
app: ASGI application callable
log: Logger instance
"""
self.transport = transport
self.reader = reader
self.scope = scope
self.app = app
self.log = log
self.accepted = False
self.closed = False
self.close_code = None
self.close_reason = ""
# Message reassembly state
self._fragments = []
self._fragment_opcode = None
# Receive queue for incoming messages
self._receive_queue = asyncio.Queue()
async def run(self):
"""Run the WebSocket ASGI application."""
# Send initial connect event
await self._receive_queue.put({"type": "websocket.connect"})
# Start frame reading task
read_task = asyncio.create_task(self._read_frames())
try:
await self.app(self.scope, self._receive, self._send)
except Exception as e:
self.log.exception("Error in WebSocket ASGI application")
finally:
read_task.cancel()
try:
await read_task
except asyncio.CancelledError:
pass
# Send close frame if not already closed
if not self.closed and self.accepted:
await self._send_close(CLOSE_INTERNAL_ERROR, "Application error")
async def _receive(self):
"""ASGI receive callable."""
return await self._receive_queue.get()
async def _send(self, message):
"""ASGI send callable."""
msg_type = message["type"]
if msg_type == "websocket.accept":
if self.accepted:
raise RuntimeError("WebSocket already accepted")
await self._send_accept(message)
self.accepted = True
elif msg_type == "websocket.send":
if not self.accepted:
raise RuntimeError("WebSocket not accepted")
if self.closed:
raise RuntimeError("WebSocket closed")
if "text" in message:
await self._send_frame(OPCODE_TEXT, message["text"].encode("utf-8"))
elif "bytes" in message:
await self._send_frame(OPCODE_BINARY, message["bytes"])
elif msg_type == "websocket.close":
code = message.get("code", CLOSE_NORMAL)
reason = message.get("reason", "")
await self._send_close(code, reason)
self.closed = True
async def _send_accept(self, message):
"""Send WebSocket handshake accept response."""
# Get Sec-WebSocket-Key from headers
ws_key = None
for name, value in self.scope["headers"]:
if name == b"sec-websocket-key":
ws_key = value
break
if not ws_key:
raise RuntimeError("Missing Sec-WebSocket-Key header")
# Calculate accept key
accept_key = base64.b64encode(
hashlib.sha1(ws_key + WS_GUID).digest()
).decode("ascii")
# Build response headers
headers = [
"HTTP/1.1 101 Switching Protocols\r\n",
"Upgrade: websocket\r\n",
"Connection: Upgrade\r\n",
f"Sec-WebSocket-Accept: {accept_key}\r\n",
]
# Add selected subprotocol if specified
subprotocol = message.get("subprotocol")
if subprotocol:
headers.append(f"Sec-WebSocket-Protocol: {subprotocol}\r\n")
# Add any extra headers from message
extra_headers = message.get("headers", [])
for name, value in extra_headers:
if isinstance(name, bytes):
name = name.decode("latin-1")
if isinstance(value, bytes):
value = value.decode("latin-1")
headers.append(f"{name}: {value}\r\n")
headers.append("\r\n")
self.transport.write("".join(headers).encode("latin-1"))
async def _read_frames(self):
"""Read and process incoming WebSocket frames."""
try:
while not self.closed:
frame = await self._read_frame()
if frame is None:
break
opcode, payload = frame
if opcode == OPCODE_CLOSE:
await self._handle_close(payload)
break
elif opcode == OPCODE_PING:
await self._send_frame(OPCODE_PONG, payload)
elif opcode == OPCODE_PONG:
# Ignore pongs
pass
elif opcode == OPCODE_TEXT:
await self._receive_queue.put({
"type": "websocket.receive",
"text": payload.decode("utf-8"),
})
elif opcode == OPCODE_BINARY:
await self._receive_queue.put({
"type": "websocket.receive",
"bytes": payload,
})
elif opcode == OPCODE_CONTINUATION:
# Handle fragmented messages
await self._handle_continuation(payload)
except asyncio.CancelledError:
raise
except Exception as e:
self.log.debug("WebSocket read error: %s", e)
finally:
# Signal disconnect
if not self.closed:
self.closed = True
await self._receive_queue.put({
"type": "websocket.disconnect",
"code": self.close_code or CLOSE_ABNORMAL,
})
async def _read_frame(self):
"""Read a single WebSocket frame.
Returns:
tuple: (opcode, payload) or None if connection closed
"""
# Read frame header (2 bytes minimum)
header = await self._read_exact(2)
if not header:
return None
first_byte, second_byte = header[0], header[1]
fin = (first_byte >> 7) & 1
rsv1 = (first_byte >> 6) & 1
rsv2 = (first_byte >> 5) & 1
rsv3 = (first_byte >> 4) & 1
opcode = first_byte & 0x0F
# RSV bits must be 0 (no extensions)
if rsv1 or rsv2 or rsv3:
await self._send_close(CLOSE_PROTOCOL_ERROR, "RSV bits set")
return None
masked = (second_byte >> 7) & 1
payload_len = second_byte & 0x7F
# Client frames must be masked (RFC 6455)
if not masked:
await self._send_close(CLOSE_PROTOCOL_ERROR, "Frame not masked")
return None
# Extended payload length
if payload_len == 126:
ext_len = await self._read_exact(2)
if not ext_len:
return None
payload_len = struct.unpack("!H", ext_len)[0]
elif payload_len == 127:
ext_len = await self._read_exact(8)
if not ext_len:
return None
payload_len = struct.unpack("!Q", ext_len)[0]
# Read masking key
masking_key = await self._read_exact(4)
if not masking_key:
return None
# Read payload
payload = await self._read_exact(payload_len)
if payload is None:
return None
# Unmask payload
payload = self._unmask(payload, masking_key)
# Handle fragmented messages
if opcode == OPCODE_CONTINUATION:
if self._fragment_opcode is None:
await self._send_close(CLOSE_PROTOCOL_ERROR, "Unexpected continuation")
return None
self._fragments.append(payload)
if fin:
# Reassemble complete message
full_payload = b"".join(self._fragments)
final_opcode = self._fragment_opcode
self._fragments = []
self._fragment_opcode = None
return (final_opcode, full_payload)
return (OPCODE_CONTINUATION, b"") # Fragment received, wait for more
elif opcode in (OPCODE_TEXT, OPCODE_BINARY):
if not fin:
# Start of fragmented message
self._fragment_opcode = opcode
self._fragments = [payload]
return (OPCODE_CONTINUATION, b"") # Fragment started, wait for more
return (opcode, payload)
else:
# Control frames
return (opcode, payload)
async def _read_exact(self, n):
"""Read exactly n bytes from the reader."""
try:
data = await self.reader.readexactly(n)
return data
except asyncio.IncompleteReadError:
return None
except Exception:
return None
def _unmask(self, payload, masking_key):
"""Unmask WebSocket payload data."""
if not payload:
return payload
# XOR each byte with corresponding mask byte
return bytes(b ^ masking_key[i % 4] for i, b in enumerate(payload))
async def _handle_close(self, payload):
"""Handle incoming close frame."""
if len(payload) >= 2:
self.close_code = struct.unpack("!H", payload[:2])[0]
self.close_reason = payload[2:].decode("utf-8", errors="replace")
else:
self.close_code = CLOSE_NO_STATUS
self.close_reason = ""
# Echo close frame back if we haven't already sent one
if not self.closed:
await self._send_close(self.close_code, self.close_reason)
self.closed = True
async def _handle_continuation(self, payload):
"""Handle continuation frame (already processed in _read_frame)."""
# This is called for partial fragments, nothing to do
pass
async def _send_frame(self, opcode, payload):
"""Send a WebSocket frame.
Server frames are not masked (RFC 6455).
"""
if isinstance(payload, str):
payload = payload.encode("utf-8")
length = len(payload)
frame = bytearray()
# First byte: FIN + opcode
frame.append(0x80 | opcode)
# Second byte: length (no mask bit for server)
if length < 126:
frame.append(length)
elif length < 65536:
frame.append(126)
frame.extend(struct.pack("!H", length))
else:
frame.append(127)
frame.extend(struct.pack("!Q", length))
# Payload
frame.extend(payload)
self.transport.write(bytes(frame))
async def _send_close(self, code, reason=""):
"""Send a close frame."""
payload = struct.pack("!H", code)
if reason:
payload += reason.encode("utf-8")[:123] # Max 125 bytes total
await self._send_frame(OPCODE_CLOSE, payload)
self.closed = True

View File

@ -2440,3 +2440,94 @@ class HeaderMap(Setting):
.. versionadded:: 22.0.0
"""
def validate_asgi_loop(val):
if val is None:
return "auto"
if not isinstance(val, str):
raise TypeError("Invalid type for casting: %s" % val)
val = val.lower().strip()
if val not in ("auto", "asyncio", "uvloop"):
raise ValueError("Invalid ASGI loop: %s" % val)
return val
def validate_asgi_lifespan(val):
if val is None:
return "auto"
if not isinstance(val, str):
raise TypeError("Invalid type for casting: %s" % val)
val = val.lower().strip()
if val not in ("auto", "on", "off"):
raise ValueError("Invalid ASGI lifespan: %s" % val)
return val
class ASGILoop(Setting):
name = "asgi_loop"
section = "Worker Processes"
cli = ["--asgi-loop"]
meta = "STRING"
validator = validate_asgi_loop
default = "auto"
desc = """\
Event loop implementation for ASGI workers.
- auto: Use uvloop if available, otherwise asyncio
- asyncio: Use Python's built-in asyncio event loop
- uvloop: Use uvloop (must be installed separately)
This setting only affects the ``asgi`` worker type.
uvloop typically provides better performance but requires
installing the uvloop package.
.. versionadded:: 24.0.0
"""
class ASGILifespan(Setting):
name = "asgi_lifespan"
section = "Worker Processes"
cli = ["--asgi-lifespan"]
meta = "STRING"
validator = validate_asgi_lifespan
default = "auto"
desc = """\
Control ASGI lifespan protocol handling.
- auto: Detect if app supports lifespan, enable if so
- on: Always run lifespan protocol (fail if unsupported)
- off: Never run lifespan protocol
The lifespan protocol allows ASGI applications to run code at
startup and shutdown. This is essential for frameworks like
FastAPI that need to initialize database connections, caches,
or other resources.
This setting only affects the ``asgi`` worker type.
.. versionadded:: 24.0.0
"""
class RootPath(Setting):
name = "root_path"
section = "Server Mechanics"
cli = ["--root-path"]
meta = "STRING"
validator = validate_string
default = ""
desc = """\
The root path for ASGI applications.
This is used to set the ``root_path`` in the ASGI scope, which
allows applications to know their mount point when behind a
reverse proxy.
For example, if your application is mounted at ``/api``, set
this to ``/api``.
.. versionadded:: 24.0.0
"""

View File

@ -11,4 +11,5 @@ SUPPORTED_WORKERS = {
"gevent_pywsgi": "gunicorn.workers.ggevent.GeventPyWSGIWorker",
"tornado": "gunicorn.workers.gtornado.TornadoWorker",
"gthread": "gunicorn.workers.gthread.ThreadWorker",
"asgi": "gunicorn.workers.gasgi.ASGIWorker",
}

282
gunicorn/workers/gasgi.py Normal file
View File

@ -0,0 +1,282 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
ASGI worker for gunicorn.
Provides native asyncio-based ASGI support using gunicorn's own
HTTP parsing infrastructure.
"""
import asyncio
import os
import signal
import ssl
import sys
from gunicorn.workers import base
from gunicorn.asgi.protocol import ASGIProtocol
class ASGIWorker(base.Worker):
"""ASGI worker using asyncio event loop.
Supports:
- HTTP/1.1 with keepalive
- WebSocket connections
- Lifespan protocol (startup/shutdown hooks)
- Optional uvloop for improved performance
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.worker_connections = self.cfg.worker_connections
self.loop = None
self.servers = []
self.nr_conns = 0
self.lifespan = None
self.state = {} # Shared state for lifespan
@classmethod
def check_config(cls, cfg, log):
"""Validate configuration for ASGI worker."""
if cfg.threads > 1:
log.warning("ASGI worker does not use threads configuration. "
"Use worker_connections instead.")
def init_process(self):
"""Initialize the worker process."""
# Setup event loop before calling super()
self._setup_event_loop()
super().init_process()
def _setup_event_loop(self):
"""Setup the asyncio event loop."""
loop_type = getattr(self.cfg, 'asgi_loop', 'auto')
if loop_type == "auto":
try:
import uvloop
loop_type = "uvloop"
except ImportError:
loop_type = "asyncio"
if loop_type == "uvloop":
try:
import uvloop
self.loop = uvloop.new_event_loop()
self.log.debug("Using uvloop event loop")
except ImportError:
self.log.warning("uvloop not available, falling back to asyncio")
self.loop = asyncio.new_event_loop()
else:
self.loop = asyncio.new_event_loop()
self.log.debug("Using asyncio event loop")
asyncio.set_event_loop(self.loop)
def load_wsgi(self):
"""Load the ASGI application."""
try:
self.asgi = self.app.wsgi()
except SyntaxError as e:
if not self.cfg.reload:
raise
self.log.exception(e)
self.asgi = self._make_error_app(str(e))
def _make_error_app(self, error_msg):
"""Create an error ASGI app for syntax errors during reload."""
async def error_app(scope, receive, send):
if scope["type"] == "http":
await send({
"type": "http.response.start",
"status": 500,
"headers": [(b"content-type", b"text/plain")],
})
await send({
"type": "http.response.body",
"body": f"Application error: {error_msg}".encode(),
})
elif scope["type"] == "lifespan":
message = await receive()
if message["type"] == "lifespan.startup":
await send({"type": "lifespan.startup.complete"})
message = await receive()
if message["type"] == "lifespan.shutdown":
await send({"type": "lifespan.shutdown.complete"})
return error_app
def init_signals(self):
"""Initialize signal handlers for asyncio."""
# Reset all signals first
for s in self.SIGNALS:
signal.signal(s, signal.SIG_DFL)
# Set up signal handlers via the event loop
self.loop.add_signal_handler(signal.SIGQUIT, self.handle_quit_signal)
self.loop.add_signal_handler(signal.SIGTERM, self.handle_exit_signal)
self.loop.add_signal_handler(signal.SIGINT, self.handle_quit_signal)
self.loop.add_signal_handler(signal.SIGUSR1, self.handle_usr1_signal)
self.loop.add_signal_handler(signal.SIGWINCH, self.handle_winch_signal)
self.loop.add_signal_handler(signal.SIGABRT, self.handle_abort_signal)
def handle_quit_signal(self):
"""Handle SIGQUIT - immediate shutdown."""
self.alive = False
self.cfg.worker_int(self)
def handle_exit_signal(self):
"""Handle SIGTERM - graceful shutdown."""
self.alive = False
def handle_usr1_signal(self):
"""Handle SIGUSR1 - reopen log files."""
self.log.reopen_files()
def handle_winch_signal(self):
"""Handle SIGWINCH - ignored in worker."""
self.log.debug("worker: SIGWINCH ignored.")
def handle_abort_signal(self):
"""Handle SIGABRT - abort."""
self.alive = False
self.cfg.worker_abort(self)
sys.exit(1)
def run(self):
"""Main entry point for the worker."""
try:
self.loop.run_until_complete(self._serve())
except Exception as e:
self.log.exception("Worker exception: %s", e)
finally:
self._cleanup()
async def _serve(self):
"""Main async serving loop."""
# Run lifespan startup
lifespan_mode = getattr(self.cfg, 'asgi_lifespan', 'auto')
if lifespan_mode != "off":
from gunicorn.asgi.lifespan import LifespanManager
self.lifespan = LifespanManager(self.asgi, self.log, self.state)
try:
await self.lifespan.startup()
except Exception as e:
if lifespan_mode == "on":
self.log.error("ASGI lifespan startup failed: %s", e)
return
else:
# auto mode - app doesn't support lifespan
self.log.debug("ASGI lifespan not supported by app: %s", e)
self.lifespan = None
# Create servers for each listener socket
ssl_context = self._get_ssl_context()
for sock in self.sockets:
try:
server = await self.loop.create_server(
lambda: ASGIProtocol(self),
sock=sock.sock,
ssl=ssl_context,
reuse_address=True,
start_serving=True,
)
self.servers.append(server)
self.log.info("ASGI server listening on %s", sock)
except Exception as e:
self.log.error("Failed to create server on %s: %s", sock, e)
if not self.servers:
self.log.error("No servers could be started")
return
# Main loop with heartbeat
try:
while self.alive:
self.notify()
# Check if parent is still alive
if self.ppid != os.getppid():
self.log.info("Parent changed, shutting down: %s", self)
break
# Check connection limit
# (Connections are managed by nr_conns in ASGIProtocol)
await asyncio.sleep(1.0)
except asyncio.CancelledError:
pass
# Graceful shutdown
await self._shutdown()
async def _shutdown(self):
"""Perform graceful shutdown."""
self.log.info("Worker shutting down...")
# Stop accepting new connections
for server in self.servers:
server.close()
# Wait for servers to close
for server in self.servers:
await server.wait_closed()
# Wait for in-flight connections (with timeout)
graceful_timeout = self.cfg.graceful_timeout
if self.nr_conns > 0:
self.log.info("Waiting for %d connections to finish...", self.nr_conns)
deadline = self.loop.time() + graceful_timeout
while self.nr_conns > 0 and self.loop.time() < deadline:
await asyncio.sleep(0.1)
if self.nr_conns > 0:
self.log.warning("Closing %d connections after timeout", self.nr_conns)
# Run lifespan shutdown
if self.lifespan:
try:
await self.lifespan.shutdown()
except Exception as e:
self.log.error("ASGI lifespan shutdown error: %s", e)
def _get_ssl_context(self):
"""Get SSL context if configured."""
if not self.cfg.is_ssl:
return None
try:
from gunicorn import sock
return sock.ssl_context(self.cfg)
except Exception as e:
self.log.error("Failed to create SSL context: %s", e)
return None
def _cleanup(self):
"""Clean up resources on exit."""
try:
# Cancel all pending tasks
pending = asyncio.all_tasks(self.loop)
for task in pending:
task.cancel()
# Run loop until all tasks are cancelled
if pending:
self.loop.run_until_complete(
asyncio.gather(*pending, return_exceptions=True)
)
self.loop.close()
except Exception as e:
self.log.debug("Cleanup error: %s", e)
# Close sockets
for s in self.sockets:
try:
s.close()
except Exception:
pass

View File

@ -58,6 +58,7 @@ testing = [
"coverage",
"pytest",
"pytest-cov",
"pytest-asyncio",
]
[project.scripts]

285
tests/test_asgi.py Normal file
View File

@ -0,0 +1,285 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
Tests for ASGI worker components.
"""
import asyncio
import io
import pytest
from unittest import mock
from gunicorn.asgi.unreader import AsyncUnreader
from gunicorn.asgi.message import AsyncRequest
class MockStreamReader:
"""Mock asyncio.StreamReader for testing."""
def __init__(self, data):
self.data = data
self.pos = 0
async def read(self, size=-1):
if self.pos >= len(self.data):
return b""
if size < 0:
result = self.data[self.pos:]
self.pos = len(self.data)
else:
result = self.data[self.pos:self.pos + size]
self.pos += size
return result
async def readexactly(self, n):
if self.pos + n > len(self.data):
raise asyncio.IncompleteReadError(
self.data[self.pos:], n
)
result = self.data[self.pos:self.pos + n]
self.pos += n
return result
class MockConfig:
"""Mock gunicorn config for testing."""
def __init__(self):
self.is_ssl = False
self.proxy_protocol = False
self.proxy_allow_ips = ["127.0.0.1"]
self.forwarded_allow_ips = ["127.0.0.1"]
self.secure_scheme_headers = {}
self.forwarder_headers = []
self.limit_request_line = 8190
self.limit_request_fields = 100
self.limit_request_field_size = 8190
self.permit_unconventional_http_method = False
self.permit_unconventional_http_version = False
self.permit_obsolete_folding = False
self.casefold_http_method = False
self.strip_header_spaces = False
self.header_map = "refuse"
# AsyncUnreader Tests
@pytest.mark.asyncio
async def test_async_unreader_read_chunk():
"""Test basic chunk reading."""
reader = MockStreamReader(b"hello world")
unreader = AsyncUnreader(reader)
data = await unreader.read()
assert data == b"hello world"
@pytest.mark.asyncio
async def test_async_unreader_read_size():
"""Test reading specific size."""
reader = MockStreamReader(b"hello world")
unreader = AsyncUnreader(reader)
data = await unreader.read(5)
assert data == b"hello"
@pytest.mark.asyncio
async def test_async_unreader_unread():
"""Test unread functionality."""
reader = MockStreamReader(b"hello world")
unreader = AsyncUnreader(reader)
# Read all data
data = await unreader.read()
assert data == b"hello world"
# Unread some data
unreader.unread(b"world")
# Read again should get unread data
data = await unreader.read()
assert data == b"world"
@pytest.mark.asyncio
async def test_async_unreader_read_zero():
"""Test reading zero bytes."""
reader = MockStreamReader(b"hello")
unreader = AsyncUnreader(reader)
data = await unreader.read(0)
assert data == b""
@pytest.mark.asyncio
async def test_async_unreader_read_empty():
"""Test reading from empty stream."""
reader = MockStreamReader(b"")
unreader = AsyncUnreader(reader)
data = await unreader.read()
assert data == b""
# AsyncRequest Tests
@pytest.mark.asyncio
async def test_async_request_simple_get():
"""Test parsing a simple GET request."""
request_data = b"GET /path HTTP/1.1\r\nHost: localhost\r\n\r\n"
reader = MockStreamReader(request_data)
unreader = AsyncUnreader(reader)
cfg = MockConfig()
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
assert request.method == "GET"
assert request.path == "/path"
assert request.version == (1, 1)
assert ("HOST", "localhost") in request.headers
@pytest.mark.asyncio
async def test_async_request_with_query():
"""Test parsing request with query string."""
request_data = b"GET /search?q=test&page=1 HTTP/1.1\r\nHost: localhost\r\n\r\n"
reader = MockStreamReader(request_data)
unreader = AsyncUnreader(reader)
cfg = MockConfig()
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
assert request.method == "GET"
assert request.path == "/search"
assert request.query == "q=test&page=1"
@pytest.mark.asyncio
async def test_async_request_post_with_body():
"""Test parsing POST request with body."""
request_data = (
b"POST /submit HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Content-Length: 11\r\n"
b"\r\n"
b"hello=world"
)
reader = MockStreamReader(request_data)
unreader = AsyncUnreader(reader)
cfg = MockConfig()
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
assert request.method == "POST"
assert request.path == "/submit"
assert request.content_length == 11
# Read body
body = await request.read_body(100)
assert body == b"hello=world"
@pytest.mark.asyncio
async def test_async_request_multiple_headers():
"""Test parsing request with multiple headers."""
request_data = (
b"GET / HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Accept: text/html\r\n"
b"Accept-Language: en-US\r\n"
b"Connection: keep-alive\r\n"
b"\r\n"
)
reader = MockStreamReader(request_data)
unreader = AsyncUnreader(reader)
cfg = MockConfig()
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
assert len(request.headers) == 4
assert request.get_header("HOST") == "localhost"
assert request.get_header("ACCEPT") == "text/html"
@pytest.mark.asyncio
async def test_async_request_should_close_http10():
"""Test connection close detection for HTTP/1.0."""
request_data = b"GET / HTTP/1.0\r\nHost: localhost\r\n\r\n"
reader = MockStreamReader(request_data)
unreader = AsyncUnreader(reader)
cfg = MockConfig()
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
assert request.version == (1, 0)
assert request.should_close() is True
@pytest.mark.asyncio
async def test_async_request_should_close_connection_header():
"""Test connection close detection with Connection header."""
request_data = b"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"
reader = MockStreamReader(request_data)
unreader = AsyncUnreader(reader)
cfg = MockConfig()
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
assert request.should_close() is True
@pytest.mark.asyncio
async def test_async_request_keepalive():
"""Test keepalive detection."""
request_data = b"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: keep-alive\r\n\r\n"
reader = MockStreamReader(request_data)
unreader = AsyncUnreader(reader)
cfg = MockConfig()
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
assert request.should_close() is False
@pytest.mark.asyncio
async def test_async_request_no_body_for_get():
"""Test that GET requests have no body by default."""
request_data = b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n"
reader = MockStreamReader(request_data)
unreader = AsyncUnreader(reader)
cfg = MockConfig()
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
assert request.content_length == 0
body = await request.read_body()
assert body == b""
# Error handling tests
@pytest.mark.asyncio
async def test_async_request_invalid_method():
"""Test invalid HTTP method detection."""
from gunicorn.http.errors import InvalidRequestMethod
request_data = b"ge!t / HTTP/1.1\r\nHost: localhost\r\n\r\n"
reader = MockStreamReader(request_data)
unreader = AsyncUnreader(reader)
cfg = MockConfig()
with pytest.raises(InvalidRequestMethod):
await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
@pytest.mark.asyncio
async def test_async_request_invalid_http_version():
"""Test invalid HTTP version detection."""
from gunicorn.http.errors import InvalidHTTPVersion
request_data = b"GET / HTTP/2.0\r\nHost: localhost\r\n\r\n"
reader = MockStreamReader(request_data)
unreader = AsyncUnreader(reader)
cfg = MockConfig()
with pytest.raises(InvalidHTTPVersion):
await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))

643
tests/test_asgi_worker.py Normal file
View File

@ -0,0 +1,643 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
Tests for the ASGI worker.
Includes unit tests for worker components and integration tests
that actually start the server and make HTTP requests.
"""
import asyncio
import errno
import os
import signal
import socket
import sys
import time
import threading
from unittest import mock
import pytest
from gunicorn.config import Config
from gunicorn.workers import gasgi
# ============================================================================
# Mock Classes
# ============================================================================
class FakeSocket:
"""Mock socket for testing."""
def __init__(self, data=b''):
self.data = data
self.closed = False
self.blocking = True
self._fileno = id(self) % 65536
def fileno(self):
return self._fileno
def setblocking(self, blocking):
self.blocking = blocking
def recv(self, size):
if self.closed:
raise OSError(errno.EBADF, "Bad file descriptor")
result = self.data[:size]
self.data = self.data[size:]
return result
def send(self, data):
if self.closed:
raise OSError(errno.EPIPE, "Broken pipe")
return len(data)
def close(self):
self.closed = True
def getsockname(self):
return ('127.0.0.1', 8000)
def getpeername(self):
return ('127.0.0.1', 12345)
class FakeApp:
"""Mock ASGI application for testing."""
def __init__(self):
self.calls = []
def wsgi(self):
return self.asgi_app
async def asgi_app(self, scope, receive, send):
self.calls.append(scope)
if scope["type"] == "lifespan":
while True:
message = await receive()
if message["type"] == "lifespan.startup":
await send({"type": "lifespan.startup.complete"})
elif message["type"] == "lifespan.shutdown":
await send({"type": "lifespan.shutdown.complete"})
return
elif scope["type"] == "http":
await send({
"type": "http.response.start",
"status": 200,
"headers": [(b"content-type", b"text/plain")],
})
await send({
"type": "http.response.body",
"body": b"Hello from ASGI!",
})
class FakeListener:
"""Mock listener socket."""
def __init__(self):
self.sock = FakeSocket()
def getsockname(self):
return ('127.0.0.1', 8000)
def close(self):
self.sock.close()
def __str__(self):
return "http://127.0.0.1:8000"
# ============================================================================
# Helper Functions
# ============================================================================
def _has_uvloop():
"""Check if uvloop is available."""
try:
import uvloop
return True
except ImportError:
return False
# ============================================================================
# Unit Tests for ASGIWorker
# ============================================================================
class TestASGIWorkerInit:
"""Tests for ASGIWorker initialization."""
def create_worker(self, **kwargs):
"""Create a worker for testing."""
cfg = Config()
cfg.set('workers', 1)
cfg.set('worker_connections', 1000)
for key, value in kwargs.items():
cfg.set(key, value)
worker = gasgi.ASGIWorker(
age=1,
ppid=os.getpid(),
sockets=[],
app=FakeApp(),
timeout=30,
cfg=cfg,
log=mock.Mock(),
)
return worker
def test_worker_init(self):
"""Test worker initialization."""
worker = self.create_worker()
assert worker.worker_connections == 1000
assert worker.nr_conns == 0
assert worker.loop is None
assert worker.servers == []
assert worker.state == {}
def test_worker_connections_config(self):
"""Test worker_connections configuration."""
worker = self.create_worker(worker_connections=500)
assert worker.worker_connections == 500
class TestASGIWorkerEventLoop:
"""Tests for event loop setup."""
def create_worker(self, **kwargs):
"""Create a worker for testing."""
cfg = Config()
cfg.set('workers', 1)
cfg.set('worker_connections', 1000)
for key, value in kwargs.items():
cfg.set(key, value)
worker = gasgi.ASGIWorker(
age=1,
ppid=os.getpid(),
sockets=[],
app=FakeApp(),
timeout=30,
cfg=cfg,
log=mock.Mock(),
)
return worker
def test_setup_asyncio_loop(self):
"""Test asyncio event loop setup."""
worker = self.create_worker(asgi_loop='asyncio')
worker._setup_event_loop()
assert worker.loop is not None
assert isinstance(worker.loop, asyncio.AbstractEventLoop)
worker.loop.close()
def test_setup_auto_loop_falls_back_to_asyncio(self):
"""Test that auto mode uses asyncio when uvloop unavailable."""
worker = self.create_worker(asgi_loop='auto')
# Mock uvloop import failure
with mock.patch.dict('sys.modules', {'uvloop': None}):
worker._setup_event_loop()
assert worker.loop is not None
worker.loop.close()
@pytest.mark.skipif(
not _has_uvloop(),
reason="uvloop not installed"
)
def test_setup_uvloop(self):
"""Test uvloop event loop setup."""
worker = self.create_worker(asgi_loop='uvloop')
worker._setup_event_loop()
import uvloop
assert isinstance(worker.loop, uvloop.Loop)
worker.loop.close()
class TestASGIWorkerSignals:
"""Tests for signal handling."""
def create_worker(self):
"""Create a worker for testing."""
cfg = Config()
cfg.set('workers', 1)
cfg.set('worker_connections', 1000)
cfg.set('graceful_timeout', 5)
worker = gasgi.ASGIWorker(
age=1,
ppid=os.getpid(),
sockets=[],
app=FakeApp(),
timeout=30,
cfg=cfg,
log=mock.Mock(),
)
worker._setup_event_loop()
return worker
def test_handle_exit_sets_alive_false(self):
"""Test that exit signal sets alive=False."""
worker = self.create_worker()
worker.alive = True
worker.handle_exit_signal()
assert worker.alive is False
worker.loop.close()
def test_handle_quit_sets_alive_false(self):
"""Test that quit signal sets alive=False."""
worker = self.create_worker()
worker.alive = True
# Mock the worker_int callback on the worker's cfg settings
with mock.patch.object(worker.cfg.settings['worker_int'], 'get', return_value=lambda w: None):
worker.handle_quit_signal()
assert worker.alive is False
worker.loop.close()
# ============================================================================
# Tests for Lifespan Protocol
# ============================================================================
class TestLifespanManager:
"""Tests for ASGI lifespan protocol."""
@pytest.mark.asyncio
async def test_lifespan_startup_complete(self):
"""Test successful lifespan startup."""
from gunicorn.asgi.lifespan import LifespanManager
startup_called = False
shutdown_called = False
async def app(scope, receive, send):
nonlocal startup_called, shutdown_called
assert scope["type"] == "lifespan"
while True:
message = await receive()
if message["type"] == "lifespan.startup":
startup_called = True
await send({"type": "lifespan.startup.complete"})
elif message["type"] == "lifespan.shutdown":
shutdown_called = True
await send({"type": "lifespan.shutdown.complete"})
return
manager = LifespanManager(app, mock.Mock())
await manager.startup()
assert startup_called
assert manager._startup_complete.is_set()
assert not manager._startup_failed
await manager.shutdown()
assert shutdown_called
@pytest.mark.asyncio
async def test_lifespan_startup_failed(self):
"""Test lifespan startup failure."""
from gunicorn.asgi.lifespan import LifespanManager
async def app(scope, receive, send):
message = await receive()
if message["type"] == "lifespan.startup":
await send({
"type": "lifespan.startup.failed",
"message": "Database connection failed"
})
manager = LifespanManager(app, mock.Mock())
with pytest.raises(RuntimeError, match="Database connection failed"):
await manager.startup()
@pytest.mark.asyncio
async def test_lifespan_state_shared(self):
"""Test that lifespan state is shared with app."""
from gunicorn.asgi.lifespan import LifespanManager
state = {}
async def app(scope, receive, send):
assert "state" in scope
scope["state"]["db"] = "connected"
message = await receive()
await send({"type": "lifespan.startup.complete"})
message = await receive()
await send({"type": "lifespan.shutdown.complete"})
manager = LifespanManager(app, mock.Mock(), state)
await manager.startup()
assert state.get("db") == "connected"
await manager.shutdown()
# ============================================================================
# Tests for WebSocket Protocol
# ============================================================================
class TestWebSocketProtocol:
"""Tests for WebSocket protocol handling."""
def test_websocket_guid(self):
"""Test WebSocket GUID constant."""
from gunicorn.asgi.websocket import WS_GUID
assert WS_GUID == b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
def test_websocket_opcodes(self):
"""Test WebSocket opcode constants."""
from gunicorn.asgi import websocket
assert websocket.OPCODE_TEXT == 0x1
assert websocket.OPCODE_BINARY == 0x2
assert websocket.OPCODE_CLOSE == 0x8
assert websocket.OPCODE_PING == 0x9
assert websocket.OPCODE_PONG == 0xA
def test_websocket_accept_key_calculation(self):
"""Test WebSocket accept key calculation per RFC 6455."""
import base64
import hashlib
from gunicorn.asgi.websocket import WS_GUID
# Example from RFC 6455
client_key = b"dGhlIHNhbXBsZSBub25jZQ=="
expected_accept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
accept_key = base64.b64encode(
hashlib.sha1(client_key + WS_GUID).digest()
).decode("ascii")
assert accept_key == expected_accept
def test_websocket_frame_masking(self):
"""Test WebSocket frame unmasking."""
from gunicorn.asgi.websocket import WebSocketProtocol
# Create a minimal protocol instance
protocol = WebSocketProtocol(None, None, {}, None, mock.Mock())
# Test unmasking (XOR operation)
masking_key = bytes([0x37, 0xfa, 0x21, 0x3d])
masked_data = bytes([0x7f, 0x9f, 0x4d, 0x51, 0x58]) # "Hello" masked
unmasked = protocol._unmask(masked_data, masking_key)
assert unmasked == b"Hello"
def test_websocket_frame_masking_empty(self):
"""Test WebSocket frame unmasking with empty payload."""
from gunicorn.asgi.websocket import WebSocketProtocol
protocol = WebSocketProtocol(None, None, {}, None, mock.Mock())
masking_key = bytes([0x37, 0xfa, 0x21, 0x3d])
unmasked = protocol._unmask(b"", masking_key)
assert unmasked == b""
# ============================================================================
# Integration Tests
# ============================================================================
class TestASGIIntegration:
"""Integration tests that start actual servers."""
@pytest.fixture
def free_port(self):
"""Get a free port for testing."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('127.0.0.1', 0))
return s.getsockname()[1]
@pytest.mark.asyncio
async def test_http_request_response(self, free_port):
"""Test basic HTTP request/response cycle."""
# Simple ASGI app
async def app(scope, receive, send):
if scope["type"] == "http":
await send({
"type": "http.response.start",
"status": 200,
"headers": [(b"content-type", b"text/plain")],
})
await send({
"type": "http.response.body",
"body": b"Hello, World!",
})
# Start server
loop = asyncio.get_event_loop()
server = await loop.create_server(
lambda: _TestProtocol(app),
'127.0.0.1',
free_port,
)
try:
# Use asyncio to make HTTP request
reader, writer = await asyncio.open_connection('127.0.0.1', free_port)
request = f"GET / HTTP/1.1\r\nHost: 127.0.0.1:{free_port}\r\n\r\n"
writer.write(request.encode())
await writer.drain()
# Read response
response = await reader.read(4096)
response_text = response.decode()
assert "HTTP/1.1 200" in response_text
assert "Hello, World!" in response_text
writer.close()
await writer.wait_closed()
finally:
server.close()
await server.wait_closed()
class _TestProtocol(asyncio.Protocol):
"""Minimal protocol for integration testing."""
def __init__(self, app):
self.app = app
self.transport = None
def connection_made(self, transport):
self.transport = transport
def data_received(self, data):
# Very simple HTTP parsing for testing
asyncio.create_task(self._handle(data))
async def _handle(self, data):
# Parse basic HTTP request
lines = data.decode().split('\r\n')
method, path, _ = lines[0].split(' ')
scope = {
"type": "http",
"asgi": {"version": "3.0"},
"http_version": "1.1",
"method": method,
"path": path,
"query_string": b"",
"headers": [],
"server": ("127.0.0.1", 8000),
"client": ("127.0.0.1", 12345),
}
async def receive():
return {"type": "http.request", "body": b"", "more_body": False}
async def send(message):
if message["type"] == "http.response.start":
status = message["status"]
headers = message.get("headers", [])
response = f"HTTP/1.1 {status} OK\r\n"
for name, value in headers:
if isinstance(name, bytes):
name = name.decode()
if isinstance(value, bytes):
value = value.decode()
response += f"{name}: {value}\r\n"
response += "\r\n"
self.transport.write(response.encode())
elif message["type"] == "http.response.body":
body = message.get("body", b"")
self.transport.write(body)
if not message.get("more_body", False):
self.transport.close()
await self.app(scope, receive, send)
# ============================================================================
# ASGI Protocol Tests
# ============================================================================
class TestASGIProtocol:
"""Tests for ASGIProtocol."""
def test_reason_phrases(self):
"""Test HTTP reason phrase lookup."""
from gunicorn.asgi.protocol import ASGIProtocol
# Create minimal worker mock
worker = mock.Mock()
worker.cfg = Config()
worker.log = mock.Mock()
worker.asgi = mock.Mock()
protocol = ASGIProtocol(worker)
assert protocol._get_reason_phrase(200) == "OK"
assert protocol._get_reason_phrase(404) == "Not Found"
assert protocol._get_reason_phrase(500) == "Internal Server Error"
assert protocol._get_reason_phrase(999) == "Unknown"
def test_scope_building(self):
"""Test HTTP scope building."""
from gunicorn.asgi.protocol import ASGIProtocol
from gunicorn.asgi.message import AsyncRequest
worker = mock.Mock()
worker.cfg = Config()
worker.cfg.set('root_path', '/api')
worker.log = mock.Mock()
worker.asgi = mock.Mock()
protocol = ASGIProtocol(worker)
# Create mock request
request = mock.Mock()
request.method = "GET"
request.path = "/users"
request.query = "page=1"
request.version = (1, 1)
request.scheme = "http"
request.headers = [("HOST", "localhost"), ("ACCEPT", "text/html")]
scope = protocol._build_http_scope(
request,
("127.0.0.1", 8000), # sockname
("127.0.0.1", 12345), # peername
)
assert scope["type"] == "http"
assert scope["method"] == "GET"
assert scope["path"] == "/users"
assert scope["query_string"] == b"page=1"
assert scope["root_path"] == "/api"
assert scope["http_version"] == "1.1"
# ============================================================================
# Config Tests
# ============================================================================
class TestASGIConfig:
"""Tests for ASGI configuration options."""
def test_asgi_loop_default(self):
"""Test default asgi_loop value."""
cfg = Config()
assert cfg.asgi_loop == "auto"
def test_asgi_loop_validation(self):
"""Test asgi_loop validation."""
cfg = Config()
cfg.set('asgi_loop', 'asyncio')
assert cfg.asgi_loop == 'asyncio'
cfg.set('asgi_loop', 'uvloop')
assert cfg.asgi_loop == 'uvloop'
with pytest.raises(ValueError):
cfg.set('asgi_loop', 'invalid')
def test_asgi_lifespan_default(self):
"""Test default asgi_lifespan value."""
cfg = Config()
assert cfg.asgi_lifespan == "auto"
def test_asgi_lifespan_validation(self):
"""Test asgi_lifespan validation."""
cfg = Config()
cfg.set('asgi_lifespan', 'on')
assert cfg.asgi_lifespan == 'on'
cfg.set('asgi_lifespan', 'off')
assert cfg.asgi_lifespan == 'off'
with pytest.raises(ValueError):
cfg.set('asgi_lifespan', 'invalid')
def test_root_path_default(self):
"""Test default root_path value."""
cfg = Config()
assert cfg.root_path == ""
def test_root_path_setting(self):
"""Test root_path configuration."""
cfg = Config()
cfg.set('root_path', '/api/v1')
assert cfg.root_path == '/api/v1'