mirror of
https://github.com/frappe/gunicorn.git
synced 2026-07-02 10:41:30 +08:00
asgi: Add native ASGI worker with HTTP and WebSocket support
Add a new ASGI worker type that provides native async support using gunicorn's own HTTP parsing infrastructure adapted for asyncio. Features: - HTTP/1.1 with keepalive support - WebSocket connections (RFC 6455) - ASGI lifespan protocol for startup/shutdown hooks - Optional uvloop support for improved performance - Full proxy protocol support (inherited from gunicorn) New configuration options: - --asgi-loop: Event loop selection (auto/asyncio/uvloop) - --asgi-lifespan: Lifespan protocol control (auto/on/off) - --root-path: ASGI root path for reverse proxy setups Usage: gunicorn -k asgi myapp:app
This commit is contained in:
parent
ea98400820
commit
ae1eea8108
7
examples/asgi/__init__.py
Normal file
7
examples/asgi/__init__.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
#
|
||||||
|
# This file is part of gunicorn released under the MIT license.
|
||||||
|
# See the NOTICE for more information.
|
||||||
|
|
||||||
|
"""
|
||||||
|
ASGI example applications for gunicorn.
|
||||||
|
"""
|
||||||
130
examples/asgi/basic_app.py
Normal file
130
examples/asgi/basic_app.py
Normal file
@ -0,0 +1,130 @@
|
|||||||
|
#
|
||||||
|
# This file is part of gunicorn released under the MIT license.
|
||||||
|
# See the NOTICE for more information.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Basic ASGI application example.
|
||||||
|
|
||||||
|
Run with:
|
||||||
|
gunicorn -k asgi examples.asgi.basic_app:app
|
||||||
|
|
||||||
|
Test with:
|
||||||
|
curl http://127.0.0.1:8000/
|
||||||
|
curl http://127.0.0.1:8000/hello
|
||||||
|
curl -X POST http://127.0.0.1:8000/echo -d "test data"
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
async def app(scope, receive, send):
|
||||||
|
"""Simple ASGI application demonstrating basic functionality."""
|
||||||
|
|
||||||
|
if scope["type"] == "lifespan":
|
||||||
|
await handle_lifespan(scope, receive, send)
|
||||||
|
elif scope["type"] == "http":
|
||||||
|
await handle_http(scope, receive, send)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown scope type: {scope['type']}")
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_lifespan(scope, receive, send):
|
||||||
|
"""Handle lifespan events (startup/shutdown)."""
|
||||||
|
while True:
|
||||||
|
message = await receive()
|
||||||
|
if message["type"] == "lifespan.startup":
|
||||||
|
print("ASGI application starting up...")
|
||||||
|
await send({"type": "lifespan.startup.complete"})
|
||||||
|
elif message["type"] == "lifespan.shutdown":
|
||||||
|
print("ASGI application shutting down...")
|
||||||
|
await send({"type": "lifespan.shutdown.complete"})
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_http(scope, receive, send):
|
||||||
|
"""Handle HTTP requests."""
|
||||||
|
path = scope["path"]
|
||||||
|
method = scope["method"]
|
||||||
|
|
||||||
|
if path == "/" and method == "GET":
|
||||||
|
await send_response(send, 200, b"Welcome to gunicorn ASGI!\n")
|
||||||
|
|
||||||
|
elif path == "/hello" and method == "GET":
|
||||||
|
name = get_query_param(scope, "name", "World")
|
||||||
|
body = f"Hello, {name}!\n".encode()
|
||||||
|
await send_response(send, 200, body)
|
||||||
|
|
||||||
|
elif path == "/echo" and method == "POST":
|
||||||
|
body = await read_body(receive)
|
||||||
|
await send_response(send, 200, body, content_type=b"application/octet-stream")
|
||||||
|
|
||||||
|
elif path == "/headers":
|
||||||
|
headers_info = format_headers(scope["headers"])
|
||||||
|
await send_response(send, 200, headers_info.encode())
|
||||||
|
|
||||||
|
elif path == "/info":
|
||||||
|
info = format_request_info(scope)
|
||||||
|
await send_response(send, 200, info.encode(), content_type=b"application/json")
|
||||||
|
|
||||||
|
else:
|
||||||
|
await send_response(send, 404, b"Not Found\n")
|
||||||
|
|
||||||
|
|
||||||
|
async def send_response(send, status, body, content_type=b"text/plain"):
|
||||||
|
"""Send an HTTP response."""
|
||||||
|
await send({
|
||||||
|
"type": "http.response.start",
|
||||||
|
"status": status,
|
||||||
|
"headers": [
|
||||||
|
(b"content-type", content_type),
|
||||||
|
(b"content-length", str(len(body)).encode()),
|
||||||
|
],
|
||||||
|
})
|
||||||
|
await send({
|
||||||
|
"type": "http.response.body",
|
||||||
|
"body": body,
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
async def read_body(receive):
|
||||||
|
"""Read the full request body."""
|
||||||
|
body = b""
|
||||||
|
while True:
|
||||||
|
message = await receive()
|
||||||
|
body += message.get("body", b"")
|
||||||
|
if not message.get("more_body", False):
|
||||||
|
break
|
||||||
|
return body
|
||||||
|
|
||||||
|
|
||||||
|
def get_query_param(scope, name, default=None):
|
||||||
|
"""Get a query parameter value."""
|
||||||
|
query_string = scope.get("query_string", b"").decode()
|
||||||
|
for param in query_string.split("&"):
|
||||||
|
if "=" in param:
|
||||||
|
key, value = param.split("=", 1)
|
||||||
|
if key == name:
|
||||||
|
return value
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
def format_headers(headers):
|
||||||
|
"""Format headers for display."""
|
||||||
|
lines = ["Request Headers:"]
|
||||||
|
for name, value in headers:
|
||||||
|
lines.append(f" {name.decode()}: {value.decode()}")
|
||||||
|
return "\n".join(lines) + "\n"
|
||||||
|
|
||||||
|
|
||||||
|
def format_request_info(scope):
|
||||||
|
"""Format request info as JSON."""
|
||||||
|
import json
|
||||||
|
info = {
|
||||||
|
"method": scope["method"],
|
||||||
|
"path": scope["path"],
|
||||||
|
"query_string": scope.get("query_string", b"").decode(),
|
||||||
|
"http_version": scope["http_version"],
|
||||||
|
"scheme": scope["scheme"],
|
||||||
|
"server": list(scope.get("server") or []),
|
||||||
|
"client": list(scope.get("client") or []),
|
||||||
|
"root_path": scope.get("root_path", ""),
|
||||||
|
}
|
||||||
|
return json.dumps(info, indent=2) + "\n"
|
||||||
235
examples/asgi/websocket_app.py
Normal file
235
examples/asgi/websocket_app.py
Normal file
@ -0,0 +1,235 @@
|
|||||||
|
#
|
||||||
|
# This file is part of gunicorn released under the MIT license.
|
||||||
|
# See the NOTICE for more information.
|
||||||
|
|
||||||
|
"""
|
||||||
|
WebSocket ASGI application example.
|
||||||
|
|
||||||
|
Run with:
|
||||||
|
gunicorn -k asgi examples.asgi.websocket_app:app
|
||||||
|
|
||||||
|
Test with:
|
||||||
|
# Using websocat (install with: cargo install websocat)
|
||||||
|
websocat ws://127.0.0.1:8000/ws
|
||||||
|
|
||||||
|
# Or using Python websockets library
|
||||||
|
python -c "
|
||||||
|
import asyncio
|
||||||
|
import websockets
|
||||||
|
async def test():
|
||||||
|
async with websockets.connect('ws://127.0.0.1:8000/ws') as ws:
|
||||||
|
await ws.send('Hello')
|
||||||
|
print(await ws.recv())
|
||||||
|
asyncio.run(test())
|
||||||
|
"
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
async def app(scope, receive, send):
|
||||||
|
"""ASGI application with WebSocket support."""
|
||||||
|
|
||||||
|
if scope["type"] == "lifespan":
|
||||||
|
await handle_lifespan(scope, receive, send)
|
||||||
|
elif scope["type"] == "http":
|
||||||
|
await handle_http(scope, receive, send)
|
||||||
|
elif scope["type"] == "websocket":
|
||||||
|
await handle_websocket(scope, receive, send)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown scope type: {scope['type']}")
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_lifespan(scope, receive, send):
|
||||||
|
"""Handle lifespan events."""
|
||||||
|
while True:
|
||||||
|
message = await receive()
|
||||||
|
if message["type"] == "lifespan.startup":
|
||||||
|
await send({"type": "lifespan.startup.complete"})
|
||||||
|
elif message["type"] == "lifespan.shutdown":
|
||||||
|
await send({"type": "lifespan.shutdown.complete"})
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_http(scope, receive, send):
|
||||||
|
"""Handle HTTP requests - serve a simple HTML page for WebSocket testing."""
|
||||||
|
path = scope["path"]
|
||||||
|
|
||||||
|
if path == "/":
|
||||||
|
html = HTML_PAGE.encode()
|
||||||
|
await send({
|
||||||
|
"type": "http.response.start",
|
||||||
|
"status": 200,
|
||||||
|
"headers": [
|
||||||
|
(b"content-type", b"text/html"),
|
||||||
|
(b"content-length", str(len(html)).encode()),
|
||||||
|
],
|
||||||
|
})
|
||||||
|
await send({
|
||||||
|
"type": "http.response.body",
|
||||||
|
"body": html,
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
await send({
|
||||||
|
"type": "http.response.start",
|
||||||
|
"status": 404,
|
||||||
|
"headers": [(b"content-type", b"text/plain")],
|
||||||
|
})
|
||||||
|
await send({
|
||||||
|
"type": "http.response.body",
|
||||||
|
"body": b"Not Found",
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_websocket(scope, receive, send):
|
||||||
|
"""Handle WebSocket connections."""
|
||||||
|
path = scope["path"]
|
||||||
|
|
||||||
|
if path == "/ws":
|
||||||
|
await echo_websocket(scope, receive, send)
|
||||||
|
elif path == "/ws/chat":
|
||||||
|
await chat_websocket(scope, receive, send)
|
||||||
|
else:
|
||||||
|
# Reject the connection
|
||||||
|
await send({"type": "websocket.close", "code": 4004})
|
||||||
|
|
||||||
|
|
||||||
|
async def echo_websocket(scope, receive, send):
|
||||||
|
"""Echo WebSocket - sends back whatever it receives."""
|
||||||
|
# Wait for connection
|
||||||
|
message = await receive()
|
||||||
|
if message["type"] != "websocket.connect":
|
||||||
|
return
|
||||||
|
|
||||||
|
# Accept the connection
|
||||||
|
await send({"type": "websocket.accept"})
|
||||||
|
|
||||||
|
# Echo loop
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
message = await receive()
|
||||||
|
|
||||||
|
if message["type"] == "websocket.disconnect":
|
||||||
|
break
|
||||||
|
|
||||||
|
if message["type"] == "websocket.receive":
|
||||||
|
if "text" in message:
|
||||||
|
# Echo text back
|
||||||
|
await send({
|
||||||
|
"type": "websocket.send",
|
||||||
|
"text": f"Echo: {message['text']}"
|
||||||
|
})
|
||||||
|
elif "bytes" in message:
|
||||||
|
# Echo bytes back
|
||||||
|
await send({
|
||||||
|
"type": "websocket.send",
|
||||||
|
"bytes": message["bytes"]
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
print(f"WebSocket error: {e}")
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
await send({"type": "websocket.close", "code": 1000})
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def chat_websocket(scope, receive, send):
|
||||||
|
"""Chat WebSocket - simple broadcast example."""
|
||||||
|
message = await receive()
|
||||||
|
if message["type"] != "websocket.connect":
|
||||||
|
return
|
||||||
|
|
||||||
|
await send({
|
||||||
|
"type": "websocket.accept",
|
||||||
|
"subprotocol": "chat"
|
||||||
|
})
|
||||||
|
|
||||||
|
await send({
|
||||||
|
"type": "websocket.send",
|
||||||
|
"text": "Welcome to the chat! Send messages and they will be echoed back."
|
||||||
|
})
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
message = await receive()
|
||||||
|
|
||||||
|
if message["type"] == "websocket.disconnect":
|
||||||
|
break
|
||||||
|
|
||||||
|
if message["type"] == "websocket.receive" and "text" in message:
|
||||||
|
text = message["text"]
|
||||||
|
await send({
|
||||||
|
"type": "websocket.send",
|
||||||
|
"text": f"[You]: {text}"
|
||||||
|
})
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
HTML_PAGE = """<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<title>WebSocket Test</title>
|
||||||
|
<style>
|
||||||
|
body { font-family: sans-serif; max-width: 800px; margin: 50px auto; padding: 20px; }
|
||||||
|
#messages { border: 1px solid #ccc; height: 300px; overflow-y: auto; padding: 10px; margin-bottom: 10px; }
|
||||||
|
#input { width: 80%; padding: 10px; }
|
||||||
|
button { padding: 10px 20px; }
|
||||||
|
.sent { color: blue; }
|
||||||
|
.received { color: green; }
|
||||||
|
.error { color: red; }
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<h1>WebSocket Test</h1>
|
||||||
|
<div id="messages"></div>
|
||||||
|
<input type="text" id="input" placeholder="Type a message...">
|
||||||
|
<button onclick="sendMessage()">Send</button>
|
||||||
|
<button onclick="connectWS()">Connect</button>
|
||||||
|
<button onclick="disconnectWS()">Disconnect</button>
|
||||||
|
|
||||||
|
<script>
|
||||||
|
let ws = null;
|
||||||
|
const messages = document.getElementById('messages');
|
||||||
|
const input = document.getElementById('input');
|
||||||
|
|
||||||
|
function log(msg, className) {
|
||||||
|
const div = document.createElement('div');
|
||||||
|
div.className = className || '';
|
||||||
|
div.textContent = msg;
|
||||||
|
messages.appendChild(div);
|
||||||
|
messages.scrollTop = messages.scrollHeight;
|
||||||
|
}
|
||||||
|
|
||||||
|
function connectWS() {
|
||||||
|
if (ws) {
|
||||||
|
log('Already connected', 'error');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
ws = new WebSocket('ws://' + window.location.host + '/ws');
|
||||||
|
ws.onopen = () => log('Connected!', 'received');
|
||||||
|
ws.onclose = () => { log('Disconnected', 'error'); ws = null; };
|
||||||
|
ws.onerror = (e) => log('Error: ' + e, 'error');
|
||||||
|
ws.onmessage = (e) => log(e.data, 'received');
|
||||||
|
}
|
||||||
|
|
||||||
|
function disconnectWS() {
|
||||||
|
if (ws) ws.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
function sendMessage() {
|
||||||
|
if (!ws) { log('Not connected', 'error'); return; }
|
||||||
|
const msg = input.value;
|
||||||
|
if (!msg) return;
|
||||||
|
ws.send(msg);
|
||||||
|
log('Sent: ' + msg, 'sent');
|
||||||
|
input.value = '';
|
||||||
|
}
|
||||||
|
|
||||||
|
input.onkeypress = (e) => { if (e.key === 'Enter') sendMessage(); };
|
||||||
|
|
||||||
|
// Auto-connect
|
||||||
|
connectWS();
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
"""
|
||||||
26
gunicorn/asgi/__init__.py
Normal file
26
gunicorn/asgi/__init__.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
#
|
||||||
|
# This file is part of gunicorn released under the MIT license.
|
||||||
|
# See the NOTICE for more information.
|
||||||
|
|
||||||
|
"""
|
||||||
|
ASGI support for gunicorn.
|
||||||
|
|
||||||
|
This module provides native ASGI worker support, using gunicorn's own
|
||||||
|
HTTP parsing infrastructure adapted for async I/O.
|
||||||
|
|
||||||
|
Components:
|
||||||
|
- AsyncUnreader: Async socket reading with pushback buffer
|
||||||
|
- AsyncRequest: Async HTTP request parser
|
||||||
|
- ASGIProtocol: asyncio.Protocol implementation for HTTP handling
|
||||||
|
- WebSocketProtocol: WebSocket protocol handler (RFC 6455)
|
||||||
|
- LifespanManager: ASGI lifespan protocol support
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
gunicorn -k asgi myapp:app
|
||||||
|
"""
|
||||||
|
|
||||||
|
from gunicorn.asgi.unreader import AsyncUnreader
|
||||||
|
from gunicorn.asgi.message import AsyncRequest
|
||||||
|
from gunicorn.asgi.lifespan import LifespanManager
|
||||||
|
|
||||||
|
__all__ = ['AsyncUnreader', 'AsyncRequest', 'LifespanManager']
|
||||||
178
gunicorn/asgi/lifespan.py
Normal file
178
gunicorn/asgi/lifespan.py
Normal file
@ -0,0 +1,178 @@
|
|||||||
|
#
|
||||||
|
# This file is part of gunicorn released under the MIT license.
|
||||||
|
# See the NOTICE for more information.
|
||||||
|
|
||||||
|
"""
|
||||||
|
ASGI lifespan protocol manager.
|
||||||
|
|
||||||
|
Manages startup and shutdown events for ASGI applications,
|
||||||
|
enabling frameworks like FastAPI to run initialization and
|
||||||
|
cleanup code.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
|
class LifespanManager:
|
||||||
|
"""Manages ASGI lifespan events (startup/shutdown).
|
||||||
|
|
||||||
|
The lifespan protocol allows ASGI applications to run code at
|
||||||
|
startup and shutdown. This is essential for applications that
|
||||||
|
need to initialize database connections, caches, or other
|
||||||
|
resources.
|
||||||
|
|
||||||
|
ASGI lifespan messages:
|
||||||
|
- Server sends: {"type": "lifespan.startup"}
|
||||||
|
- App responds: {"type": "lifespan.startup.complete"} or
|
||||||
|
{"type": "lifespan.startup.failed", "message": "..."}
|
||||||
|
- Server sends: {"type": "lifespan.shutdown"}
|
||||||
|
- App responds: {"type": "lifespan.shutdown.complete"}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, app, logger, state=None):
|
||||||
|
"""Initialize the lifespan manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app: ASGI application callable
|
||||||
|
logger: Logger instance
|
||||||
|
state: Shared state dict for the application
|
||||||
|
"""
|
||||||
|
self.app = app
|
||||||
|
self.logger = logger
|
||||||
|
self.state = state if state is not None else {}
|
||||||
|
|
||||||
|
self._startup_complete = asyncio.Event()
|
||||||
|
self._shutdown_complete = asyncio.Event()
|
||||||
|
self._startup_failed = False
|
||||||
|
self._startup_error = None
|
||||||
|
self._shutdown_error = None
|
||||||
|
self._receive_queue = asyncio.Queue()
|
||||||
|
self._task = None
|
||||||
|
self._app_finished = False
|
||||||
|
|
||||||
|
async def startup(self):
|
||||||
|
"""Run lifespan startup and wait for completion.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: If startup fails or app doesn't support lifespan
|
||||||
|
"""
|
||||||
|
scope = {
|
||||||
|
"type": "lifespan",
|
||||||
|
"asgi": {"version": "3.0", "spec_version": "2.4"},
|
||||||
|
"state": self.state,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Send startup event
|
||||||
|
await self._receive_queue.put({"type": "lifespan.startup"})
|
||||||
|
|
||||||
|
# Run lifespan in background task
|
||||||
|
self._task = asyncio.create_task(self._run_lifespan(scope))
|
||||||
|
|
||||||
|
# Wait for startup with timeout
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
self._startup_complete.wait(),
|
||||||
|
timeout=30.0 # Reasonable startup timeout
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
if self._task:
|
||||||
|
self._task.cancel()
|
||||||
|
raise RuntimeError("Lifespan startup timed out")
|
||||||
|
|
||||||
|
if self._startup_failed:
|
||||||
|
if self._task:
|
||||||
|
self._task.cancel()
|
||||||
|
msg = self._startup_error or "Unknown error"
|
||||||
|
raise RuntimeError(f"Lifespan startup failed: {msg}")
|
||||||
|
|
||||||
|
self.logger.debug("ASGI lifespan startup complete")
|
||||||
|
|
||||||
|
async def shutdown(self):
|
||||||
|
"""Signal shutdown and wait for completion.
|
||||||
|
|
||||||
|
This should be called during graceful shutdown.
|
||||||
|
"""
|
||||||
|
if self._app_finished:
|
||||||
|
self.logger.debug("ASGI lifespan already finished")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Send shutdown event
|
||||||
|
await self._receive_queue.put({"type": "lifespan.shutdown"})
|
||||||
|
|
||||||
|
# Wait for shutdown with timeout
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
self._shutdown_complete.wait(),
|
||||||
|
timeout=30.0 # Reasonable shutdown timeout
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
self.logger.warning("Lifespan shutdown timed out")
|
||||||
|
|
||||||
|
if self._shutdown_error:
|
||||||
|
self.logger.error("Lifespan shutdown error: %s", self._shutdown_error)
|
||||||
|
|
||||||
|
# Cancel the task if still running
|
||||||
|
if self._task and not self._task.done():
|
||||||
|
self._task.cancel()
|
||||||
|
try:
|
||||||
|
await self._task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.logger.debug("ASGI lifespan shutdown complete")
|
||||||
|
|
||||||
|
async def _run_lifespan(self, scope):
|
||||||
|
"""Run the ASGI lifespan protocol."""
|
||||||
|
try:
|
||||||
|
await self.app(scope, self._receive, self._send)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.debug("Lifespan application raised: %s", e)
|
||||||
|
# If startup hasn't completed, mark it as failed
|
||||||
|
if not self._startup_complete.is_set():
|
||||||
|
self._startup_failed = True
|
||||||
|
self._startup_error = str(e)
|
||||||
|
self._startup_complete.set()
|
||||||
|
# If shutdown hasn't completed, mark error
|
||||||
|
elif not self._shutdown_complete.is_set():
|
||||||
|
self._shutdown_error = str(e)
|
||||||
|
self._shutdown_complete.set()
|
||||||
|
finally:
|
||||||
|
self._app_finished = True
|
||||||
|
# Ensure events are set to unblock waiters
|
||||||
|
if not self._startup_complete.is_set():
|
||||||
|
self._startup_failed = True
|
||||||
|
self._startup_error = "Application exited before startup complete"
|
||||||
|
self._startup_complete.set()
|
||||||
|
if not self._shutdown_complete.is_set():
|
||||||
|
self._shutdown_complete.set()
|
||||||
|
|
||||||
|
async def _receive(self):
|
||||||
|
"""ASGI receive callable for lifespan."""
|
||||||
|
return await self._receive_queue.get()
|
||||||
|
|
||||||
|
async def _send(self, message):
|
||||||
|
"""ASGI send callable for lifespan."""
|
||||||
|
msg_type = message["type"]
|
||||||
|
|
||||||
|
if msg_type == "lifespan.startup.complete":
|
||||||
|
self._startup_complete.set()
|
||||||
|
self.logger.debug("Received lifespan.startup.complete")
|
||||||
|
|
||||||
|
elif msg_type == "lifespan.startup.failed":
|
||||||
|
self._startup_failed = True
|
||||||
|
self._startup_error = message.get("message", "")
|
||||||
|
self._startup_complete.set()
|
||||||
|
self.logger.debug("Received lifespan.startup.failed: %s",
|
||||||
|
self._startup_error)
|
||||||
|
|
||||||
|
elif msg_type == "lifespan.shutdown.complete":
|
||||||
|
self._shutdown_complete.set()
|
||||||
|
self.logger.debug("Received lifespan.shutdown.complete")
|
||||||
|
|
||||||
|
elif msg_type == "lifespan.shutdown.failed":
|
||||||
|
self._shutdown_error = message.get("message", "")
|
||||||
|
self._shutdown_complete.set()
|
||||||
|
self.logger.debug("Received lifespan.shutdown.failed: %s",
|
||||||
|
self._shutdown_error)
|
||||||
562
gunicorn/asgi/message.py
Normal file
562
gunicorn/asgi/message.py
Normal file
@ -0,0 +1,562 @@
|
|||||||
|
#
|
||||||
|
# This file is part of gunicorn released under the MIT license.
|
||||||
|
# See the NOTICE for more information.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Async version of gunicorn/http/message.py for ASGI workers.
|
||||||
|
|
||||||
|
Reuses the parsing logic from the sync version, adapted for async I/O.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import io
|
||||||
|
import re
|
||||||
|
import socket
|
||||||
|
|
||||||
|
from gunicorn.http.errors import (
|
||||||
|
InvalidHeader, InvalidHeaderName, NoMoreData,
|
||||||
|
InvalidRequestLine, InvalidRequestMethod, InvalidHTTPVersion,
|
||||||
|
LimitRequestLine, LimitRequestHeaders,
|
||||||
|
UnsupportedTransferCoding, ObsoleteFolding,
|
||||||
|
InvalidProxyLine, ForbiddenProxyRequest,
|
||||||
|
InvalidSchemeHeaders,
|
||||||
|
)
|
||||||
|
from gunicorn.util import bytes_to_str, split_request_uri
|
||||||
|
|
||||||
|
MAX_REQUEST_LINE = 8190
|
||||||
|
MAX_HEADERS = 32768
|
||||||
|
DEFAULT_MAX_HEADERFIELD_SIZE = 8190
|
||||||
|
|
||||||
|
# Reuse regex patterns from sync version
|
||||||
|
RFC9110_5_6_2_TOKEN_SPECIALS = r"!#$%&'*+-.^_`|~"
|
||||||
|
TOKEN_RE = re.compile(r"[%s0-9a-zA-Z]+" % (re.escape(RFC9110_5_6_2_TOKEN_SPECIALS)))
|
||||||
|
METHOD_BADCHAR_RE = re.compile("[a-z#]")
|
||||||
|
VERSION_RE = re.compile(r"HTTP/(\d)\.(\d)")
|
||||||
|
RFC9110_5_5_INVALID_AND_DANGEROUS = re.compile(r"[\0\r\n]")
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncRequest:
|
||||||
|
"""Async HTTP request parser.
|
||||||
|
|
||||||
|
Parses HTTP/1.x requests using async I/O, reusing gunicorn's
|
||||||
|
parsing logic where possible.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cfg, unreader, peer_addr, req_number=1):
|
||||||
|
self.cfg = cfg
|
||||||
|
self.unreader = unreader
|
||||||
|
self.peer_addr = peer_addr
|
||||||
|
self.remote_addr = peer_addr
|
||||||
|
self.req_number = req_number
|
||||||
|
|
||||||
|
self.version = None
|
||||||
|
self.method = None
|
||||||
|
self.uri = None
|
||||||
|
self.path = None
|
||||||
|
self.query = None
|
||||||
|
self.fragment = None
|
||||||
|
self.headers = []
|
||||||
|
self.trailers = []
|
||||||
|
self.scheme = "https" if cfg.is_ssl else "http"
|
||||||
|
self.must_close = False
|
||||||
|
|
||||||
|
self.proxy_protocol_info = None
|
||||||
|
|
||||||
|
# Request line limit
|
||||||
|
self.limit_request_line = cfg.limit_request_line
|
||||||
|
if (self.limit_request_line < 0
|
||||||
|
or self.limit_request_line >= MAX_REQUEST_LINE):
|
||||||
|
self.limit_request_line = MAX_REQUEST_LINE
|
||||||
|
|
||||||
|
# Headers limits
|
||||||
|
self.limit_request_fields = cfg.limit_request_fields
|
||||||
|
if (self.limit_request_fields <= 0
|
||||||
|
or self.limit_request_fields > MAX_HEADERS):
|
||||||
|
self.limit_request_fields = MAX_HEADERS
|
||||||
|
|
||||||
|
self.limit_request_field_size = cfg.limit_request_field_size
|
||||||
|
if self.limit_request_field_size < 0:
|
||||||
|
self.limit_request_field_size = DEFAULT_MAX_HEADERFIELD_SIZE
|
||||||
|
|
||||||
|
# Max header buffer size
|
||||||
|
max_header_field_size = self.limit_request_field_size or DEFAULT_MAX_HEADERFIELD_SIZE
|
||||||
|
self.max_buffer_headers = self.limit_request_fields * \
|
||||||
|
(max_header_field_size + 2) + 4
|
||||||
|
|
||||||
|
# Body-related state
|
||||||
|
self.content_length = None
|
||||||
|
self.chunked = False
|
||||||
|
self._body_reader = None
|
||||||
|
self._body_remaining = 0
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def parse(cls, cfg, unreader, peer_addr, req_number=1):
|
||||||
|
"""Parse an HTTP request from the stream.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg: gunicorn config object
|
||||||
|
unreader: AsyncUnreader instance
|
||||||
|
peer_addr: client address tuple
|
||||||
|
req_number: request number on this connection (for keepalive)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AsyncRequest: Parsed request object
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NoMoreData: If no data available
|
||||||
|
Various parsing errors for malformed requests
|
||||||
|
"""
|
||||||
|
req = cls(cfg, unreader, peer_addr, req_number)
|
||||||
|
await req._parse()
|
||||||
|
return req
|
||||||
|
|
||||||
|
async def _parse(self):
|
||||||
|
"""Parse the request from the unreader."""
|
||||||
|
buf = io.BytesIO()
|
||||||
|
await self._get_data(buf, stop=True)
|
||||||
|
|
||||||
|
# Get request line
|
||||||
|
line, rbuf = await self._read_line(buf, self.limit_request_line)
|
||||||
|
|
||||||
|
# Proxy protocol
|
||||||
|
if self._proxy_protocol(bytes_to_str(line)):
|
||||||
|
# Get next request line
|
||||||
|
buf = io.BytesIO()
|
||||||
|
buf.write(rbuf)
|
||||||
|
line, rbuf = await self._read_line(buf, self.limit_request_line)
|
||||||
|
|
||||||
|
self._parse_request_line(line)
|
||||||
|
buf = io.BytesIO()
|
||||||
|
buf.write(rbuf)
|
||||||
|
|
||||||
|
# Headers
|
||||||
|
data = buf.getvalue()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
idx = data.find(b"\r\n\r\n")
|
||||||
|
done = data[:2] == b"\r\n"
|
||||||
|
|
||||||
|
if idx < 0 and not done:
|
||||||
|
await self._get_data(buf)
|
||||||
|
data = buf.getvalue()
|
||||||
|
if len(data) > self.max_buffer_headers:
|
||||||
|
raise LimitRequestHeaders("max buffer headers")
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
if done:
|
||||||
|
self.unreader.unread(data[2:])
|
||||||
|
else:
|
||||||
|
self.headers = self._parse_headers(data[:idx], from_trailer=False)
|
||||||
|
self.unreader.unread(data[idx + 4:])
|
||||||
|
|
||||||
|
self._set_body_reader()
|
||||||
|
|
||||||
|
async def _get_data(self, buf, stop=False):
|
||||||
|
"""Read data from unreader into buffer."""
|
||||||
|
data = await self.unreader.read()
|
||||||
|
if not data:
|
||||||
|
if stop:
|
||||||
|
raise StopIteration()
|
||||||
|
raise NoMoreData(buf.getvalue())
|
||||||
|
buf.write(data)
|
||||||
|
|
||||||
|
async def _read_line(self, buf, limit=0):
|
||||||
|
"""Read a line from the buffer/stream."""
|
||||||
|
data = buf.getvalue()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
idx = data.find(b"\r\n")
|
||||||
|
if idx >= 0:
|
||||||
|
if idx > limit > 0:
|
||||||
|
raise LimitRequestLine(idx, limit)
|
||||||
|
break
|
||||||
|
if len(data) - 2 > limit > 0:
|
||||||
|
raise LimitRequestLine(len(data), limit)
|
||||||
|
await self._get_data(buf)
|
||||||
|
data = buf.getvalue()
|
||||||
|
|
||||||
|
return (data[:idx], data[idx + 2:])
|
||||||
|
|
||||||
|
def _proxy_protocol(self, line):
|
||||||
|
"""Detect, check and parse proxy protocol."""
|
||||||
|
if not self.cfg.proxy_protocol:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if self.req_number != 1:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not line.startswith("PROXY"):
|
||||||
|
return False
|
||||||
|
|
||||||
|
self._proxy_protocol_access_check()
|
||||||
|
self._parse_proxy_protocol(line)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _proxy_protocol_access_check(self):
|
||||||
|
"""Check if proxy protocol is allowed from this peer."""
|
||||||
|
if ("*" not in self.cfg.proxy_allow_ips and
|
||||||
|
isinstance(self.peer_addr, tuple) and
|
||||||
|
self.peer_addr[0] not in self.cfg.proxy_allow_ips):
|
||||||
|
raise ForbiddenProxyRequest(self.peer_addr[0])
|
||||||
|
|
||||||
|
def _parse_proxy_protocol(self, line):
|
||||||
|
"""Parse proxy protocol header line."""
|
||||||
|
bits = line.split(" ")
|
||||||
|
|
||||||
|
if len(bits) != 6:
|
||||||
|
raise InvalidProxyLine(line)
|
||||||
|
|
||||||
|
proto = bits[1]
|
||||||
|
s_addr = bits[2]
|
||||||
|
d_addr = bits[3]
|
||||||
|
|
||||||
|
if proto not in ["TCP4", "TCP6"]:
|
||||||
|
raise InvalidProxyLine("protocol '%s' not supported" % proto)
|
||||||
|
|
||||||
|
if proto == "TCP4":
|
||||||
|
try:
|
||||||
|
socket.inet_pton(socket.AF_INET, s_addr)
|
||||||
|
socket.inet_pton(socket.AF_INET, d_addr)
|
||||||
|
except OSError:
|
||||||
|
raise InvalidProxyLine(line)
|
||||||
|
elif proto == "TCP6":
|
||||||
|
try:
|
||||||
|
socket.inet_pton(socket.AF_INET6, s_addr)
|
||||||
|
socket.inet_pton(socket.AF_INET6, d_addr)
|
||||||
|
except OSError:
|
||||||
|
raise InvalidProxyLine(line)
|
||||||
|
|
||||||
|
try:
|
||||||
|
s_port = int(bits[4])
|
||||||
|
d_port = int(bits[5])
|
||||||
|
except ValueError:
|
||||||
|
raise InvalidProxyLine("invalid port %s" % line)
|
||||||
|
|
||||||
|
if not ((0 <= s_port <= 65535) and (0 <= d_port <= 65535)):
|
||||||
|
raise InvalidProxyLine("invalid port %s" % line)
|
||||||
|
|
||||||
|
self.proxy_protocol_info = {
|
||||||
|
"proxy_protocol": proto,
|
||||||
|
"client_addr": s_addr,
|
||||||
|
"client_port": s_port,
|
||||||
|
"proxy_addr": d_addr,
|
||||||
|
"proxy_port": d_port
|
||||||
|
}
|
||||||
|
|
||||||
|
def _parse_request_line(self, line_bytes):
|
||||||
|
"""Parse the HTTP request line."""
|
||||||
|
bits = [bytes_to_str(bit) for bit in line_bytes.split(b" ", 2)]
|
||||||
|
if len(bits) != 3:
|
||||||
|
raise InvalidRequestLine(bytes_to_str(line_bytes))
|
||||||
|
|
||||||
|
# Method
|
||||||
|
self.method = bits[0]
|
||||||
|
|
||||||
|
if not self.cfg.permit_unconventional_http_method:
|
||||||
|
if METHOD_BADCHAR_RE.search(self.method):
|
||||||
|
raise InvalidRequestMethod(self.method)
|
||||||
|
if not 3 <= len(bits[0]) <= 20:
|
||||||
|
raise InvalidRequestMethod(self.method)
|
||||||
|
if not TOKEN_RE.fullmatch(self.method):
|
||||||
|
raise InvalidRequestMethod(self.method)
|
||||||
|
if self.cfg.casefold_http_method:
|
||||||
|
self.method = self.method.upper()
|
||||||
|
|
||||||
|
# URI
|
||||||
|
self.uri = bits[1]
|
||||||
|
|
||||||
|
if len(self.uri) == 0:
|
||||||
|
raise InvalidRequestLine(bytes_to_str(line_bytes))
|
||||||
|
|
||||||
|
try:
|
||||||
|
parts = split_request_uri(self.uri)
|
||||||
|
except ValueError:
|
||||||
|
raise InvalidRequestLine(bytes_to_str(line_bytes))
|
||||||
|
self.path = parts.path or ""
|
||||||
|
self.query = parts.query or ""
|
||||||
|
self.fragment = parts.fragment or ""
|
||||||
|
|
||||||
|
# Version
|
||||||
|
match = VERSION_RE.fullmatch(bits[2])
|
||||||
|
if match is None:
|
||||||
|
raise InvalidHTTPVersion(bits[2])
|
||||||
|
self.version = (int(match.group(1)), int(match.group(2)))
|
||||||
|
if not (1, 0) <= self.version < (2, 0):
|
||||||
|
if not self.cfg.permit_unconventional_http_version:
|
||||||
|
raise InvalidHTTPVersion(self.version)
|
||||||
|
|
||||||
|
def _parse_headers(self, data, from_trailer=False):
|
||||||
|
"""Parse HTTP headers from raw data."""
|
||||||
|
cfg = self.cfg
|
||||||
|
headers = []
|
||||||
|
|
||||||
|
lines = [bytes_to_str(line) for line in data.split(b"\r\n")]
|
||||||
|
|
||||||
|
# Handle scheme headers
|
||||||
|
scheme_header = False
|
||||||
|
secure_scheme_headers = {}
|
||||||
|
forwarder_headers = []
|
||||||
|
if from_trailer:
|
||||||
|
pass
|
||||||
|
elif ('*' in cfg.forwarded_allow_ips or
|
||||||
|
not isinstance(self.peer_addr, tuple)
|
||||||
|
or self.peer_addr[0] in cfg.forwarded_allow_ips):
|
||||||
|
secure_scheme_headers = cfg.secure_scheme_headers
|
||||||
|
forwarder_headers = cfg.forwarder_headers
|
||||||
|
|
||||||
|
while lines:
|
||||||
|
if len(headers) >= self.limit_request_fields:
|
||||||
|
raise LimitRequestHeaders("limit request headers fields")
|
||||||
|
|
||||||
|
curr = lines.pop(0)
|
||||||
|
header_length = len(curr) + len("\r\n")
|
||||||
|
if curr.find(":") <= 0:
|
||||||
|
raise InvalidHeader(curr)
|
||||||
|
name, value = curr.split(":", 1)
|
||||||
|
if self.cfg.strip_header_spaces:
|
||||||
|
name = name.rstrip(" \t")
|
||||||
|
if not TOKEN_RE.fullmatch(name):
|
||||||
|
raise InvalidHeaderName(name)
|
||||||
|
|
||||||
|
name = name.upper()
|
||||||
|
value = [value.strip(" \t")]
|
||||||
|
|
||||||
|
# Consume value continuation lines
|
||||||
|
while lines and lines[0].startswith((" ", "\t")):
|
||||||
|
if not self.cfg.permit_obsolete_folding:
|
||||||
|
raise ObsoleteFolding(name)
|
||||||
|
curr = lines.pop(0)
|
||||||
|
header_length += len(curr) + len("\r\n")
|
||||||
|
if header_length > self.limit_request_field_size > 0:
|
||||||
|
raise LimitRequestHeaders("limit request headers fields size")
|
||||||
|
value.append(curr.strip("\t "))
|
||||||
|
value = " ".join(value)
|
||||||
|
|
||||||
|
if RFC9110_5_5_INVALID_AND_DANGEROUS.search(value):
|
||||||
|
raise InvalidHeader(name)
|
||||||
|
|
||||||
|
if header_length > self.limit_request_field_size > 0:
|
||||||
|
raise LimitRequestHeaders("limit request headers fields size")
|
||||||
|
|
||||||
|
if name in secure_scheme_headers:
|
||||||
|
secure = value == secure_scheme_headers[name]
|
||||||
|
scheme = "https" if secure else "http"
|
||||||
|
if scheme_header:
|
||||||
|
if scheme != self.scheme:
|
||||||
|
raise InvalidSchemeHeaders()
|
||||||
|
else:
|
||||||
|
scheme_header = True
|
||||||
|
self.scheme = scheme
|
||||||
|
|
||||||
|
if "_" in name:
|
||||||
|
if name in forwarder_headers or "*" in forwarder_headers:
|
||||||
|
pass
|
||||||
|
elif self.cfg.header_map == "dangerous":
|
||||||
|
pass
|
||||||
|
elif self.cfg.header_map == "drop":
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
raise InvalidHeaderName(name)
|
||||||
|
|
||||||
|
headers.append((name, value))
|
||||||
|
|
||||||
|
return headers
|
||||||
|
|
||||||
|
def _set_body_reader(self):
|
||||||
|
"""Determine how to read the request body."""
|
||||||
|
chunked = False
|
||||||
|
content_length = None
|
||||||
|
|
||||||
|
for (name, value) in self.headers:
|
||||||
|
if name == "CONTENT-LENGTH":
|
||||||
|
if content_length is not None:
|
||||||
|
raise InvalidHeader("CONTENT-LENGTH", req=self)
|
||||||
|
content_length = value
|
||||||
|
elif name == "TRANSFER-ENCODING":
|
||||||
|
vals = [v.strip() for v in value.split(',')]
|
||||||
|
for val in vals:
|
||||||
|
if val.lower() == "chunked":
|
||||||
|
if chunked:
|
||||||
|
raise InvalidHeader("TRANSFER-ENCODING", req=self)
|
||||||
|
chunked = True
|
||||||
|
elif val.lower() == "identity":
|
||||||
|
if chunked:
|
||||||
|
raise InvalidHeader("TRANSFER-ENCODING", req=self)
|
||||||
|
elif val.lower() in ('compress', 'deflate', 'gzip'):
|
||||||
|
if chunked:
|
||||||
|
raise InvalidHeader("TRANSFER-ENCODING", req=self)
|
||||||
|
self.force_close()
|
||||||
|
else:
|
||||||
|
raise UnsupportedTransferCoding(value)
|
||||||
|
|
||||||
|
if chunked:
|
||||||
|
if self.version < (1, 1):
|
||||||
|
raise InvalidHeader("TRANSFER-ENCODING", req=self)
|
||||||
|
if content_length is not None:
|
||||||
|
raise InvalidHeader("CONTENT-LENGTH", req=self)
|
||||||
|
self.chunked = True
|
||||||
|
self.content_length = None
|
||||||
|
self._body_remaining = -1
|
||||||
|
elif content_length is not None:
|
||||||
|
try:
|
||||||
|
if str(content_length).isnumeric():
|
||||||
|
content_length = int(content_length)
|
||||||
|
else:
|
||||||
|
raise InvalidHeader("CONTENT-LENGTH", req=self)
|
||||||
|
except ValueError:
|
||||||
|
raise InvalidHeader("CONTENT-LENGTH", req=self)
|
||||||
|
|
||||||
|
if content_length < 0:
|
||||||
|
raise InvalidHeader("CONTENT-LENGTH", req=self)
|
||||||
|
|
||||||
|
self.content_length = content_length
|
||||||
|
self._body_remaining = content_length
|
||||||
|
else:
|
||||||
|
# No body for requests without Content-Length or Transfer-Encoding
|
||||||
|
self.content_length = 0
|
||||||
|
self._body_remaining = 0
|
||||||
|
|
||||||
|
def force_close(self):
|
||||||
|
"""Mark connection for closing after this request."""
|
||||||
|
self.must_close = True
|
||||||
|
|
||||||
|
def should_close(self):
|
||||||
|
"""Check if connection should be closed after this request."""
|
||||||
|
if self.must_close:
|
||||||
|
return True
|
||||||
|
for (h, v) in self.headers:
|
||||||
|
if h == "CONNECTION":
|
||||||
|
v = v.lower().strip(" \t")
|
||||||
|
if v == "close":
|
||||||
|
return True
|
||||||
|
elif v == "keep-alive":
|
||||||
|
return False
|
||||||
|
break
|
||||||
|
return self.version <= (1, 0)
|
||||||
|
|
||||||
|
def get_header(self, name):
|
||||||
|
"""Get a header value by name (case-insensitive)."""
|
||||||
|
name = name.upper()
|
||||||
|
for (h, v) in self.headers:
|
||||||
|
if h == name:
|
||||||
|
return v
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def read_body(self, size=8192):
|
||||||
|
"""Read a chunk of the request body.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
size: Maximum bytes to read
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bytes: Body data, empty bytes when body is exhausted
|
||||||
|
"""
|
||||||
|
if self._body_remaining == 0:
|
||||||
|
return b""
|
||||||
|
|
||||||
|
if self.chunked:
|
||||||
|
return await self._read_chunked_body(size)
|
||||||
|
else:
|
||||||
|
return await self._read_length_body(size)
|
||||||
|
|
||||||
|
async def _read_length_body(self, size):
|
||||||
|
"""Read from a length-delimited body."""
|
||||||
|
if self._body_remaining <= 0:
|
||||||
|
return b""
|
||||||
|
|
||||||
|
to_read = min(size, self._body_remaining)
|
||||||
|
data = await self.unreader.read(to_read)
|
||||||
|
if data:
|
||||||
|
self._body_remaining -= len(data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
async def _read_chunked_body(self, size):
|
||||||
|
"""Read from a chunked body."""
|
||||||
|
if self._body_reader is None:
|
||||||
|
self._body_reader = self._chunked_body_reader()
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await self._body_reader.__anext__()
|
||||||
|
except StopAsyncIteration:
|
||||||
|
self._body_remaining = 0
|
||||||
|
return b""
|
||||||
|
|
||||||
|
async def _chunked_body_reader(self):
|
||||||
|
"""Async generator for reading chunked body."""
|
||||||
|
while True:
|
||||||
|
# Read chunk size line
|
||||||
|
size_line = await self._read_chunk_size_line()
|
||||||
|
# Parse chunk size (handle extensions)
|
||||||
|
chunk_size, *_ = size_line.split(b";", 1)
|
||||||
|
if _ :
|
||||||
|
chunk_size = chunk_size.rstrip(b" \t")
|
||||||
|
|
||||||
|
if any(n not in b"0123456789abcdefABCDEF" for n in chunk_size):
|
||||||
|
raise InvalidHeader("Invalid chunk size")
|
||||||
|
if len(chunk_size) == 0:
|
||||||
|
raise InvalidHeader("Invalid chunk size")
|
||||||
|
|
||||||
|
chunk_size = int(chunk_size, 16)
|
||||||
|
|
||||||
|
if chunk_size == 0:
|
||||||
|
# Final chunk - skip trailers and final CRLF
|
||||||
|
await self._skip_trailers()
|
||||||
|
return
|
||||||
|
|
||||||
|
# Read chunk data
|
||||||
|
remaining = chunk_size
|
||||||
|
while remaining > 0:
|
||||||
|
data = await self.unreader.read(min(remaining, 8192))
|
||||||
|
if not data:
|
||||||
|
raise NoMoreData()
|
||||||
|
remaining -= len(data)
|
||||||
|
yield data
|
||||||
|
|
||||||
|
# Skip chunk terminating CRLF
|
||||||
|
crlf = await self.unreader.read(2)
|
||||||
|
if crlf != b"\r\n":
|
||||||
|
# May have partial read, try to get the rest
|
||||||
|
while len(crlf) < 2:
|
||||||
|
more = await self.unreader.read(2 - len(crlf))
|
||||||
|
if not more:
|
||||||
|
break
|
||||||
|
crlf += more
|
||||||
|
if crlf != b"\r\n":
|
||||||
|
raise InvalidHeader("Missing chunk terminator")
|
||||||
|
|
||||||
|
async def _read_chunk_size_line(self):
|
||||||
|
"""Read a chunk size line."""
|
||||||
|
buf = io.BytesIO()
|
||||||
|
while True:
|
||||||
|
data = await self.unreader.read(1)
|
||||||
|
if not data:
|
||||||
|
raise NoMoreData()
|
||||||
|
buf.write(data)
|
||||||
|
if buf.getvalue().endswith(b"\r\n"):
|
||||||
|
return buf.getvalue()[:-2]
|
||||||
|
|
||||||
|
async def _skip_trailers(self):
|
||||||
|
"""Skip trailer headers after chunked body."""
|
||||||
|
buf = io.BytesIO()
|
||||||
|
while True:
|
||||||
|
data = await self.unreader.read(1)
|
||||||
|
if not data:
|
||||||
|
return
|
||||||
|
buf.write(data)
|
||||||
|
content = buf.getvalue()
|
||||||
|
if content.endswith(b"\r\n\r\n"):
|
||||||
|
# Could parse trailers here if needed
|
||||||
|
return
|
||||||
|
if content == b"\r\n":
|
||||||
|
return
|
||||||
|
|
||||||
|
async def drain_body(self):
|
||||||
|
"""Drain any unread body data.
|
||||||
|
|
||||||
|
Should be called before reusing connection for keepalive.
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
data = await self.read_body(8192)
|
||||||
|
if not data:
|
||||||
|
break
|
||||||
424
gunicorn/asgi/protocol.py
Normal file
424
gunicorn/asgi/protocol.py
Normal file
@ -0,0 +1,424 @@
|
|||||||
|
#
|
||||||
|
# This file is part of gunicorn released under the MIT license.
|
||||||
|
# See the NOTICE for more information.
|
||||||
|
|
||||||
|
"""
|
||||||
|
ASGI protocol handler for gunicorn.
|
||||||
|
|
||||||
|
Implements asyncio.Protocol to handle HTTP/1.x connections and dispatch
|
||||||
|
to ASGI applications.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import traceback
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from gunicorn.asgi.unreader import AsyncUnreader
|
||||||
|
from gunicorn.asgi.message import AsyncRequest
|
||||||
|
from gunicorn.http.errors import NoMoreData
|
||||||
|
|
||||||
|
|
||||||
|
class ASGIProtocol(asyncio.Protocol):
|
||||||
|
"""HTTP/1.1 protocol handler for ASGI applications.
|
||||||
|
|
||||||
|
Handles connection lifecycle, request parsing, and ASGI app invocation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, worker):
|
||||||
|
self.worker = worker
|
||||||
|
self.cfg = worker.cfg
|
||||||
|
self.log = worker.log
|
||||||
|
self.app = worker.asgi
|
||||||
|
|
||||||
|
self.transport = None
|
||||||
|
self.reader = None
|
||||||
|
self.writer = None
|
||||||
|
self._task = None
|
||||||
|
self.req_count = 0
|
||||||
|
|
||||||
|
# Connection state
|
||||||
|
self._closed = False
|
||||||
|
|
||||||
|
def connection_made(self, transport):
|
||||||
|
"""Called when a connection is established."""
|
||||||
|
self.transport = transport
|
||||||
|
self.worker.nr_conns += 1
|
||||||
|
|
||||||
|
# Create stream reader/writer
|
||||||
|
self.reader = asyncio.StreamReader()
|
||||||
|
self.writer = transport
|
||||||
|
|
||||||
|
# Start handling requests
|
||||||
|
self._task = self.worker.loop.create_task(self._handle_connection())
|
||||||
|
|
||||||
|
def data_received(self, data):
|
||||||
|
"""Called when data is received on the connection."""
|
||||||
|
if self.reader:
|
||||||
|
self.reader.feed_data(data)
|
||||||
|
|
||||||
|
def connection_lost(self, exc):
|
||||||
|
"""Called when the connection is lost or closed."""
|
||||||
|
self._closed = True
|
||||||
|
self.worker.nr_conns -= 1
|
||||||
|
if self.reader:
|
||||||
|
self.reader.feed_eof()
|
||||||
|
if self._task and not self._task.done():
|
||||||
|
self._task.cancel()
|
||||||
|
|
||||||
|
async def _handle_connection(self):
|
||||||
|
"""Main request handling loop for this connection."""
|
||||||
|
unreader = AsyncUnreader(self.reader)
|
||||||
|
|
||||||
|
try:
|
||||||
|
peername = self.transport.get_extra_info('peername')
|
||||||
|
sockname = self.transport.get_extra_info('sockname')
|
||||||
|
|
||||||
|
while not self._closed:
|
||||||
|
self.req_count += 1
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Parse HTTP request
|
||||||
|
request = await AsyncRequest.parse(
|
||||||
|
self.cfg,
|
||||||
|
unreader,
|
||||||
|
peername,
|
||||||
|
self.req_count
|
||||||
|
)
|
||||||
|
except StopIteration:
|
||||||
|
# No more data, close connection
|
||||||
|
break
|
||||||
|
except NoMoreData:
|
||||||
|
# Client disconnected
|
||||||
|
break
|
||||||
|
|
||||||
|
# Check for WebSocket upgrade
|
||||||
|
if self._is_websocket_upgrade(request):
|
||||||
|
await self._handle_websocket(request, sockname, peername)
|
||||||
|
break # WebSocket takes over the connection
|
||||||
|
else:
|
||||||
|
# Handle HTTP request
|
||||||
|
keepalive = await self._handle_http_request(
|
||||||
|
request, sockname, peername
|
||||||
|
)
|
||||||
|
|
||||||
|
# Increment worker request count
|
||||||
|
self.worker.nr += 1
|
||||||
|
|
||||||
|
# Check max_requests
|
||||||
|
if self.worker.nr >= self.worker.max_requests:
|
||||||
|
self.log.info("Autorestarting worker after current request.")
|
||||||
|
self.worker.alive = False
|
||||||
|
keepalive = False
|
||||||
|
|
||||||
|
if not keepalive or not self.worker.alive:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Check connection limits for keepalive
|
||||||
|
if not self.cfg.keepalive:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Drain any unread body before next request
|
||||||
|
await request.drain_body()
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
self.log.exception("Error handling connection: %s", e)
|
||||||
|
finally:
|
||||||
|
self._close_transport()
|
||||||
|
|
||||||
|
def _is_websocket_upgrade(self, request):
|
||||||
|
"""Check if request is a WebSocket upgrade."""
|
||||||
|
upgrade = None
|
||||||
|
connection = None
|
||||||
|
for name, value in request.headers:
|
||||||
|
if name == "UPGRADE":
|
||||||
|
upgrade = value.lower()
|
||||||
|
elif name == "CONNECTION":
|
||||||
|
connection = value.lower()
|
||||||
|
return upgrade == "websocket" and connection and "upgrade" in connection
|
||||||
|
|
||||||
|
async def _handle_websocket(self, request, sockname, peername):
|
||||||
|
"""Handle WebSocket upgrade request."""
|
||||||
|
from gunicorn.asgi.websocket import WebSocketProtocol
|
||||||
|
|
||||||
|
scope = self._build_websocket_scope(request, sockname, peername)
|
||||||
|
ws_protocol = WebSocketProtocol(
|
||||||
|
self.transport, self.reader, scope, self.app, self.log
|
||||||
|
)
|
||||||
|
await ws_protocol.run()
|
||||||
|
|
||||||
|
async def _handle_http_request(self, request, sockname, peername):
|
||||||
|
"""Handle a single HTTP request."""
|
||||||
|
scope = self._build_http_scope(request, sockname, peername)
|
||||||
|
response_started = False
|
||||||
|
response_complete = False
|
||||||
|
body_parts = []
|
||||||
|
exc_to_raise = None
|
||||||
|
|
||||||
|
# Receive queue for body
|
||||||
|
receive_queue = asyncio.Queue()
|
||||||
|
|
||||||
|
# Pre-populate with initial body state
|
||||||
|
if request.content_length == 0 and not request.chunked:
|
||||||
|
await receive_queue.put({
|
||||||
|
"type": "http.request",
|
||||||
|
"body": b"",
|
||||||
|
"more_body": False,
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
# Start body reading task
|
||||||
|
asyncio.create_task(self._read_body_to_queue(request, receive_queue))
|
||||||
|
|
||||||
|
async def receive():
|
||||||
|
return await receive_queue.get()
|
||||||
|
|
||||||
|
async def send(message):
|
||||||
|
nonlocal response_started, response_complete, exc_to_raise
|
||||||
|
|
||||||
|
msg_type = message["type"]
|
||||||
|
|
||||||
|
if msg_type == "http.response.start":
|
||||||
|
if response_started:
|
||||||
|
exc_to_raise = RuntimeError("Response already started")
|
||||||
|
return
|
||||||
|
response_started = True
|
||||||
|
status = message["status"]
|
||||||
|
headers = message.get("headers", [])
|
||||||
|
await self._send_response_start(status, headers, request)
|
||||||
|
|
||||||
|
elif msg_type == "http.response.body":
|
||||||
|
if not response_started:
|
||||||
|
exc_to_raise = RuntimeError("Response not started")
|
||||||
|
return
|
||||||
|
if response_complete:
|
||||||
|
exc_to_raise = RuntimeError("Response already complete")
|
||||||
|
return
|
||||||
|
|
||||||
|
body = message.get("body", b"")
|
||||||
|
more_body = message.get("more_body", False)
|
||||||
|
|
||||||
|
if body:
|
||||||
|
await self._send_body(body)
|
||||||
|
|
||||||
|
if not more_body:
|
||||||
|
response_complete = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
request_start = datetime.now()
|
||||||
|
self.cfg.pre_request(self.worker, request)
|
||||||
|
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
|
||||||
|
if exc_to_raise:
|
||||||
|
raise exc_to_raise
|
||||||
|
|
||||||
|
# Ensure response was sent
|
||||||
|
if not response_started:
|
||||||
|
await self._send_error_response(500, "Internal Server Error")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.log.exception("Error in ASGI application")
|
||||||
|
if not response_started:
|
||||||
|
await self._send_error_response(500, "Internal Server Error")
|
||||||
|
return False
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
request_time = datetime.now() - request_start
|
||||||
|
self.cfg.post_request(self.worker, request, {}, None)
|
||||||
|
except Exception:
|
||||||
|
self.log.exception("Exception in post_request hook")
|
||||||
|
|
||||||
|
# Determine keepalive
|
||||||
|
if request.should_close():
|
||||||
|
return False
|
||||||
|
|
||||||
|
return self.worker.alive and self.cfg.keepalive
|
||||||
|
|
||||||
|
async def _read_body_to_queue(self, request, queue):
|
||||||
|
"""Read request body and put chunks on the queue."""
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
chunk = await request.read_body(65536)
|
||||||
|
if chunk:
|
||||||
|
await queue.put({
|
||||||
|
"type": "http.request",
|
||||||
|
"body": chunk,
|
||||||
|
"more_body": True,
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
await queue.put({
|
||||||
|
"type": "http.request",
|
||||||
|
"body": b"",
|
||||||
|
"more_body": False,
|
||||||
|
})
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
self.log.debug("Error reading body: %s", e)
|
||||||
|
await queue.put({
|
||||||
|
"type": "http.request",
|
||||||
|
"body": b"",
|
||||||
|
"more_body": False,
|
||||||
|
})
|
||||||
|
|
||||||
|
def _build_http_scope(self, request, sockname, peername):
|
||||||
|
"""Build ASGI HTTP scope from parsed request."""
|
||||||
|
# Build headers list as bytes tuples
|
||||||
|
headers = []
|
||||||
|
for name, value in request.headers:
|
||||||
|
headers.append((name.lower().encode("latin-1"), value.encode("latin-1")))
|
||||||
|
|
||||||
|
scope = {
|
||||||
|
"type": "http",
|
||||||
|
"asgi": {"version": "3.0", "spec_version": "2.4"},
|
||||||
|
"http_version": f"{request.version[0]}.{request.version[1]}",
|
||||||
|
"method": request.method,
|
||||||
|
"scheme": request.scheme,
|
||||||
|
"path": request.path,
|
||||||
|
"raw_path": request.path.encode("latin-1") if request.path else b"",
|
||||||
|
"query_string": request.query.encode("latin-1") if request.query else b"",
|
||||||
|
"root_path": self.cfg.root_path or "",
|
||||||
|
"headers": headers,
|
||||||
|
"server": sockname if sockname else None,
|
||||||
|
"client": peername if peername else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add state dict for lifespan sharing
|
||||||
|
if hasattr(self.worker, 'state'):
|
||||||
|
scope["state"] = self.worker.state
|
||||||
|
|
||||||
|
return scope
|
||||||
|
|
||||||
|
def _build_websocket_scope(self, request, sockname, peername):
|
||||||
|
"""Build ASGI WebSocket scope from parsed request."""
|
||||||
|
# Build headers list as bytes tuples
|
||||||
|
headers = []
|
||||||
|
for name, value in request.headers:
|
||||||
|
headers.append((name.lower().encode("latin-1"), value.encode("latin-1")))
|
||||||
|
|
||||||
|
# Extract subprotocols from Sec-WebSocket-Protocol header
|
||||||
|
subprotocols = []
|
||||||
|
for name, value in request.headers:
|
||||||
|
if name == "SEC-WEBSOCKET-PROTOCOL":
|
||||||
|
subprotocols = [s.strip() for s in value.split(",")]
|
||||||
|
break
|
||||||
|
|
||||||
|
scope = {
|
||||||
|
"type": "websocket",
|
||||||
|
"asgi": {"version": "3.0", "spec_version": "2.4"},
|
||||||
|
"http_version": f"{request.version[0]}.{request.version[1]}",
|
||||||
|
"scheme": "wss" if request.scheme == "https" else "ws",
|
||||||
|
"path": request.path,
|
||||||
|
"raw_path": request.path.encode("latin-1") if request.path else b"",
|
||||||
|
"query_string": request.query.encode("latin-1") if request.query else b"",
|
||||||
|
"root_path": self.cfg.root_path or "",
|
||||||
|
"headers": headers,
|
||||||
|
"server": sockname if sockname else None,
|
||||||
|
"client": peername if peername else None,
|
||||||
|
"subprotocols": subprotocols,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add state dict for lifespan sharing
|
||||||
|
if hasattr(self.worker, 'state'):
|
||||||
|
scope["state"] = self.worker.state
|
||||||
|
|
||||||
|
return scope
|
||||||
|
|
||||||
|
async def _send_response_start(self, status, headers, request):
|
||||||
|
"""Send HTTP response status and headers."""
|
||||||
|
# Build status line
|
||||||
|
reason = self._get_reason_phrase(status)
|
||||||
|
status_line = f"HTTP/{request.version[0]}.{request.version[1]} {status} {reason}\r\n"
|
||||||
|
|
||||||
|
# Build headers
|
||||||
|
header_lines = []
|
||||||
|
has_content_length = False
|
||||||
|
has_transfer_encoding = False
|
||||||
|
has_connection = False
|
||||||
|
|
||||||
|
for name, value in headers:
|
||||||
|
if isinstance(name, bytes):
|
||||||
|
name = name.decode("latin-1")
|
||||||
|
if isinstance(value, bytes):
|
||||||
|
value = value.decode("latin-1")
|
||||||
|
header_lines.append(f"{name}: {value}\r\n")
|
||||||
|
name_lower = name.lower()
|
||||||
|
if name_lower == "content-length":
|
||||||
|
has_content_length = True
|
||||||
|
elif name_lower == "transfer-encoding":
|
||||||
|
has_transfer_encoding = True
|
||||||
|
elif name_lower == "connection":
|
||||||
|
has_connection = True
|
||||||
|
|
||||||
|
# Add server header if not present
|
||||||
|
header_lines.append("Server: gunicorn/asgi\r\n")
|
||||||
|
|
||||||
|
response = status_line + "".join(header_lines) + "\r\n"
|
||||||
|
self.transport.write(response.encode("latin-1"))
|
||||||
|
|
||||||
|
async def _send_body(self, body):
|
||||||
|
"""Send response body chunk."""
|
||||||
|
if body:
|
||||||
|
self.transport.write(body)
|
||||||
|
|
||||||
|
async def _send_error_response(self, status, message):
|
||||||
|
"""Send an error response."""
|
||||||
|
body = message.encode("utf-8")
|
||||||
|
response = (
|
||||||
|
f"HTTP/1.1 {status} {message}\r\n"
|
||||||
|
f"Content-Type: text/plain\r\n"
|
||||||
|
f"Content-Length: {len(body)}\r\n"
|
||||||
|
f"Connection: close\r\n"
|
||||||
|
f"\r\n"
|
||||||
|
)
|
||||||
|
self.transport.write(response.encode("latin-1"))
|
||||||
|
self.transport.write(body)
|
||||||
|
|
||||||
|
def _get_reason_phrase(self, status):
|
||||||
|
"""Get HTTP reason phrase for status code."""
|
||||||
|
reasons = {
|
||||||
|
100: "Continue",
|
||||||
|
101: "Switching Protocols",
|
||||||
|
200: "OK",
|
||||||
|
201: "Created",
|
||||||
|
202: "Accepted",
|
||||||
|
204: "No Content",
|
||||||
|
206: "Partial Content",
|
||||||
|
301: "Moved Permanently",
|
||||||
|
302: "Found",
|
||||||
|
303: "See Other",
|
||||||
|
304: "Not Modified",
|
||||||
|
307: "Temporary Redirect",
|
||||||
|
308: "Permanent Redirect",
|
||||||
|
400: "Bad Request",
|
||||||
|
401: "Unauthorized",
|
||||||
|
403: "Forbidden",
|
||||||
|
404: "Not Found",
|
||||||
|
405: "Method Not Allowed",
|
||||||
|
408: "Request Timeout",
|
||||||
|
409: "Conflict",
|
||||||
|
410: "Gone",
|
||||||
|
411: "Length Required",
|
||||||
|
413: "Payload Too Large",
|
||||||
|
414: "URI Too Long",
|
||||||
|
415: "Unsupported Media Type",
|
||||||
|
422: "Unprocessable Entity",
|
||||||
|
429: "Too Many Requests",
|
||||||
|
500: "Internal Server Error",
|
||||||
|
501: "Not Implemented",
|
||||||
|
502: "Bad Gateway",
|
||||||
|
503: "Service Unavailable",
|
||||||
|
504: "Gateway Timeout",
|
||||||
|
}
|
||||||
|
return reasons.get(status, "Unknown")
|
||||||
|
|
||||||
|
def _close_transport(self):
|
||||||
|
"""Close the transport safely."""
|
||||||
|
if self.transport and not self._closed:
|
||||||
|
try:
|
||||||
|
self.transport.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self._closed = True
|
||||||
100
gunicorn/asgi/unreader.py
Normal file
100
gunicorn/asgi/unreader.py
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
#
|
||||||
|
# This file is part of gunicorn released under the MIT license.
|
||||||
|
# See the NOTICE for more information.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Async version of gunicorn/http/unreader.py for ASGI workers.
|
||||||
|
|
||||||
|
Provides async reading with pushback buffer support.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import io
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncUnreader:
|
||||||
|
"""Async socket reader with pushback buffer support.
|
||||||
|
|
||||||
|
This class wraps an asyncio StreamReader and provides the ability
|
||||||
|
to "unread" data back into a buffer for re-parsing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, reader, max_chunk=8192):
|
||||||
|
"""Initialize the async unreader.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reader: asyncio.StreamReader instance
|
||||||
|
max_chunk: Maximum bytes to read at once
|
||||||
|
"""
|
||||||
|
self.reader = reader
|
||||||
|
self.buf = io.BytesIO()
|
||||||
|
self.max_chunk = max_chunk
|
||||||
|
|
||||||
|
async def read(self, size=None):
|
||||||
|
"""Read data from the stream, using buffered data first.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
size: Number of bytes to read. If None, returns all buffered
|
||||||
|
data or reads a single chunk.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bytes: Data read from buffer or stream
|
||||||
|
"""
|
||||||
|
if size is not None and not isinstance(size, int):
|
||||||
|
raise TypeError("size parameter must be an int or long.")
|
||||||
|
|
||||||
|
if size is not None:
|
||||||
|
if size == 0:
|
||||||
|
return b""
|
||||||
|
if size < 0:
|
||||||
|
size = None
|
||||||
|
|
||||||
|
# Move to end to check buffer size
|
||||||
|
self.buf.seek(0, io.SEEK_END)
|
||||||
|
|
||||||
|
# If no size specified, return buffered data or read chunk
|
||||||
|
if size is None and self.buf.tell():
|
||||||
|
ret = self.buf.getvalue()
|
||||||
|
self.buf = io.BytesIO()
|
||||||
|
return ret
|
||||||
|
if size is None:
|
||||||
|
chunk = await self._read_chunk()
|
||||||
|
return chunk
|
||||||
|
|
||||||
|
# Read until we have enough data
|
||||||
|
while self.buf.tell() < size:
|
||||||
|
chunk = await self._read_chunk()
|
||||||
|
if not chunk:
|
||||||
|
ret = self.buf.getvalue()
|
||||||
|
self.buf = io.BytesIO()
|
||||||
|
return ret
|
||||||
|
self.buf.write(chunk)
|
||||||
|
|
||||||
|
data = self.buf.getvalue()
|
||||||
|
self.buf = io.BytesIO()
|
||||||
|
self.buf.write(data[size:])
|
||||||
|
return data[:size]
|
||||||
|
|
||||||
|
async def _read_chunk(self):
|
||||||
|
"""Read a chunk of data from the underlying stream."""
|
||||||
|
try:
|
||||||
|
return await self.reader.read(self.max_chunk)
|
||||||
|
except Exception:
|
||||||
|
return b""
|
||||||
|
|
||||||
|
def unread(self, data):
|
||||||
|
"""Push data back into the buffer for re-reading.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: bytes to push back
|
||||||
|
"""
|
||||||
|
if data:
|
||||||
|
self.buf.seek(0, io.SEEK_END)
|
||||||
|
self.buf.write(data)
|
||||||
|
|
||||||
|
def has_buffered_data(self):
|
||||||
|
"""Check if there's data in the pushback buffer."""
|
||||||
|
pos = self.buf.tell()
|
||||||
|
self.buf.seek(0, io.SEEK_END)
|
||||||
|
has_data = self.buf.tell() > 0
|
||||||
|
self.buf.seek(pos)
|
||||||
|
return has_data
|
||||||
369
gunicorn/asgi/websocket.py
Normal file
369
gunicorn/asgi/websocket.py
Normal file
@ -0,0 +1,369 @@
|
|||||||
|
#
|
||||||
|
# This file is part of gunicorn released under the MIT license.
|
||||||
|
# See the NOTICE for more information.
|
||||||
|
|
||||||
|
"""
|
||||||
|
WebSocket protocol handler for ASGI.
|
||||||
|
|
||||||
|
Implements RFC 6455 WebSocket protocol for ASGI applications.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import struct
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
# WebSocket frame opcodes
|
||||||
|
OPCODE_CONTINUATION = 0x0
|
||||||
|
OPCODE_TEXT = 0x1
|
||||||
|
OPCODE_BINARY = 0x2
|
||||||
|
OPCODE_CLOSE = 0x8
|
||||||
|
OPCODE_PING = 0x9
|
||||||
|
OPCODE_PONG = 0xA
|
||||||
|
|
||||||
|
# WebSocket close codes
|
||||||
|
CLOSE_NORMAL = 1000
|
||||||
|
CLOSE_GOING_AWAY = 1001
|
||||||
|
CLOSE_PROTOCOL_ERROR = 1002
|
||||||
|
CLOSE_UNSUPPORTED = 1003
|
||||||
|
CLOSE_NO_STATUS = 1005
|
||||||
|
CLOSE_ABNORMAL = 1006
|
||||||
|
CLOSE_INVALID_DATA = 1007
|
||||||
|
CLOSE_POLICY_VIOLATION = 1008
|
||||||
|
CLOSE_MESSAGE_TOO_BIG = 1009
|
||||||
|
CLOSE_MANDATORY_EXT = 1010
|
||||||
|
CLOSE_INTERNAL_ERROR = 1011
|
||||||
|
|
||||||
|
# WebSocket handshake GUID (RFC 6455)
|
||||||
|
WS_GUID = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocketProtocol:
|
||||||
|
"""WebSocket connection handler for ASGI applications."""
|
||||||
|
|
||||||
|
def __init__(self, transport, reader, scope, app, log):
|
||||||
|
"""Initialize WebSocket protocol handler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
transport: asyncio transport for writing
|
||||||
|
reader: asyncio StreamReader for reading
|
||||||
|
scope: ASGI WebSocket scope dict
|
||||||
|
app: ASGI application callable
|
||||||
|
log: Logger instance
|
||||||
|
"""
|
||||||
|
self.transport = transport
|
||||||
|
self.reader = reader
|
||||||
|
self.scope = scope
|
||||||
|
self.app = app
|
||||||
|
self.log = log
|
||||||
|
|
||||||
|
self.accepted = False
|
||||||
|
self.closed = False
|
||||||
|
self.close_code = None
|
||||||
|
self.close_reason = ""
|
||||||
|
|
||||||
|
# Message reassembly state
|
||||||
|
self._fragments = []
|
||||||
|
self._fragment_opcode = None
|
||||||
|
|
||||||
|
# Receive queue for incoming messages
|
||||||
|
self._receive_queue = asyncio.Queue()
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
"""Run the WebSocket ASGI application."""
|
||||||
|
# Send initial connect event
|
||||||
|
await self._receive_queue.put({"type": "websocket.connect"})
|
||||||
|
|
||||||
|
# Start frame reading task
|
||||||
|
read_task = asyncio.create_task(self._read_frames())
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.app(self.scope, self._receive, self._send)
|
||||||
|
except Exception as e:
|
||||||
|
self.log.exception("Error in WebSocket ASGI application")
|
||||||
|
finally:
|
||||||
|
read_task.cancel()
|
||||||
|
try:
|
||||||
|
await read_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Send close frame if not already closed
|
||||||
|
if not self.closed and self.accepted:
|
||||||
|
await self._send_close(CLOSE_INTERNAL_ERROR, "Application error")
|
||||||
|
|
||||||
|
async def _receive(self):
|
||||||
|
"""ASGI receive callable."""
|
||||||
|
return await self._receive_queue.get()
|
||||||
|
|
||||||
|
async def _send(self, message):
|
||||||
|
"""ASGI send callable."""
|
||||||
|
msg_type = message["type"]
|
||||||
|
|
||||||
|
if msg_type == "websocket.accept":
|
||||||
|
if self.accepted:
|
||||||
|
raise RuntimeError("WebSocket already accepted")
|
||||||
|
await self._send_accept(message)
|
||||||
|
self.accepted = True
|
||||||
|
|
||||||
|
elif msg_type == "websocket.send":
|
||||||
|
if not self.accepted:
|
||||||
|
raise RuntimeError("WebSocket not accepted")
|
||||||
|
if self.closed:
|
||||||
|
raise RuntimeError("WebSocket closed")
|
||||||
|
|
||||||
|
if "text" in message:
|
||||||
|
await self._send_frame(OPCODE_TEXT, message["text"].encode("utf-8"))
|
||||||
|
elif "bytes" in message:
|
||||||
|
await self._send_frame(OPCODE_BINARY, message["bytes"])
|
||||||
|
|
||||||
|
elif msg_type == "websocket.close":
|
||||||
|
code = message.get("code", CLOSE_NORMAL)
|
||||||
|
reason = message.get("reason", "")
|
||||||
|
await self._send_close(code, reason)
|
||||||
|
self.closed = True
|
||||||
|
|
||||||
|
async def _send_accept(self, message):
|
||||||
|
"""Send WebSocket handshake accept response."""
|
||||||
|
# Get Sec-WebSocket-Key from headers
|
||||||
|
ws_key = None
|
||||||
|
for name, value in self.scope["headers"]:
|
||||||
|
if name == b"sec-websocket-key":
|
||||||
|
ws_key = value
|
||||||
|
break
|
||||||
|
|
||||||
|
if not ws_key:
|
||||||
|
raise RuntimeError("Missing Sec-WebSocket-Key header")
|
||||||
|
|
||||||
|
# Calculate accept key
|
||||||
|
accept_key = base64.b64encode(
|
||||||
|
hashlib.sha1(ws_key + WS_GUID).digest()
|
||||||
|
).decode("ascii")
|
||||||
|
|
||||||
|
# Build response headers
|
||||||
|
headers = [
|
||||||
|
"HTTP/1.1 101 Switching Protocols\r\n",
|
||||||
|
"Upgrade: websocket\r\n",
|
||||||
|
"Connection: Upgrade\r\n",
|
||||||
|
f"Sec-WebSocket-Accept: {accept_key}\r\n",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add selected subprotocol if specified
|
||||||
|
subprotocol = message.get("subprotocol")
|
||||||
|
if subprotocol:
|
||||||
|
headers.append(f"Sec-WebSocket-Protocol: {subprotocol}\r\n")
|
||||||
|
|
||||||
|
# Add any extra headers from message
|
||||||
|
extra_headers = message.get("headers", [])
|
||||||
|
for name, value in extra_headers:
|
||||||
|
if isinstance(name, bytes):
|
||||||
|
name = name.decode("latin-1")
|
||||||
|
if isinstance(value, bytes):
|
||||||
|
value = value.decode("latin-1")
|
||||||
|
headers.append(f"{name}: {value}\r\n")
|
||||||
|
|
||||||
|
headers.append("\r\n")
|
||||||
|
self.transport.write("".join(headers).encode("latin-1"))
|
||||||
|
|
||||||
|
async def _read_frames(self):
|
||||||
|
"""Read and process incoming WebSocket frames."""
|
||||||
|
try:
|
||||||
|
while not self.closed:
|
||||||
|
frame = await self._read_frame()
|
||||||
|
if frame is None:
|
||||||
|
break
|
||||||
|
|
||||||
|
opcode, payload = frame
|
||||||
|
|
||||||
|
if opcode == OPCODE_CLOSE:
|
||||||
|
await self._handle_close(payload)
|
||||||
|
break
|
||||||
|
elif opcode == OPCODE_PING:
|
||||||
|
await self._send_frame(OPCODE_PONG, payload)
|
||||||
|
elif opcode == OPCODE_PONG:
|
||||||
|
# Ignore pongs
|
||||||
|
pass
|
||||||
|
elif opcode == OPCODE_TEXT:
|
||||||
|
await self._receive_queue.put({
|
||||||
|
"type": "websocket.receive",
|
||||||
|
"text": payload.decode("utf-8"),
|
||||||
|
})
|
||||||
|
elif opcode == OPCODE_BINARY:
|
||||||
|
await self._receive_queue.put({
|
||||||
|
"type": "websocket.receive",
|
||||||
|
"bytes": payload,
|
||||||
|
})
|
||||||
|
elif opcode == OPCODE_CONTINUATION:
|
||||||
|
# Handle fragmented messages
|
||||||
|
await self._handle_continuation(payload)
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
self.log.debug("WebSocket read error: %s", e)
|
||||||
|
finally:
|
||||||
|
# Signal disconnect
|
||||||
|
if not self.closed:
|
||||||
|
self.closed = True
|
||||||
|
await self._receive_queue.put({
|
||||||
|
"type": "websocket.disconnect",
|
||||||
|
"code": self.close_code or CLOSE_ABNORMAL,
|
||||||
|
})
|
||||||
|
|
||||||
|
async def _read_frame(self):
|
||||||
|
"""Read a single WebSocket frame.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (opcode, payload) or None if connection closed
|
||||||
|
"""
|
||||||
|
# Read frame header (2 bytes minimum)
|
||||||
|
header = await self._read_exact(2)
|
||||||
|
if not header:
|
||||||
|
return None
|
||||||
|
|
||||||
|
first_byte, second_byte = header[0], header[1]
|
||||||
|
|
||||||
|
fin = (first_byte >> 7) & 1
|
||||||
|
rsv1 = (first_byte >> 6) & 1
|
||||||
|
rsv2 = (first_byte >> 5) & 1
|
||||||
|
rsv3 = (first_byte >> 4) & 1
|
||||||
|
opcode = first_byte & 0x0F
|
||||||
|
|
||||||
|
# RSV bits must be 0 (no extensions)
|
||||||
|
if rsv1 or rsv2 or rsv3:
|
||||||
|
await self._send_close(CLOSE_PROTOCOL_ERROR, "RSV bits set")
|
||||||
|
return None
|
||||||
|
|
||||||
|
masked = (second_byte >> 7) & 1
|
||||||
|
payload_len = second_byte & 0x7F
|
||||||
|
|
||||||
|
# Client frames must be masked (RFC 6455)
|
||||||
|
if not masked:
|
||||||
|
await self._send_close(CLOSE_PROTOCOL_ERROR, "Frame not masked")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Extended payload length
|
||||||
|
if payload_len == 126:
|
||||||
|
ext_len = await self._read_exact(2)
|
||||||
|
if not ext_len:
|
||||||
|
return None
|
||||||
|
payload_len = struct.unpack("!H", ext_len)[0]
|
||||||
|
elif payload_len == 127:
|
||||||
|
ext_len = await self._read_exact(8)
|
||||||
|
if not ext_len:
|
||||||
|
return None
|
||||||
|
payload_len = struct.unpack("!Q", ext_len)[0]
|
||||||
|
|
||||||
|
# Read masking key
|
||||||
|
masking_key = await self._read_exact(4)
|
||||||
|
if not masking_key:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Read payload
|
||||||
|
payload = await self._read_exact(payload_len)
|
||||||
|
if payload is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Unmask payload
|
||||||
|
payload = self._unmask(payload, masking_key)
|
||||||
|
|
||||||
|
# Handle fragmented messages
|
||||||
|
if opcode == OPCODE_CONTINUATION:
|
||||||
|
if self._fragment_opcode is None:
|
||||||
|
await self._send_close(CLOSE_PROTOCOL_ERROR, "Unexpected continuation")
|
||||||
|
return None
|
||||||
|
self._fragments.append(payload)
|
||||||
|
if fin:
|
||||||
|
# Reassemble complete message
|
||||||
|
full_payload = b"".join(self._fragments)
|
||||||
|
final_opcode = self._fragment_opcode
|
||||||
|
self._fragments = []
|
||||||
|
self._fragment_opcode = None
|
||||||
|
return (final_opcode, full_payload)
|
||||||
|
return (OPCODE_CONTINUATION, b"") # Fragment received, wait for more
|
||||||
|
elif opcode in (OPCODE_TEXT, OPCODE_BINARY):
|
||||||
|
if not fin:
|
||||||
|
# Start of fragmented message
|
||||||
|
self._fragment_opcode = opcode
|
||||||
|
self._fragments = [payload]
|
||||||
|
return (OPCODE_CONTINUATION, b"") # Fragment started, wait for more
|
||||||
|
return (opcode, payload)
|
||||||
|
else:
|
||||||
|
# Control frames
|
||||||
|
return (opcode, payload)
|
||||||
|
|
||||||
|
async def _read_exact(self, n):
|
||||||
|
"""Read exactly n bytes from the reader."""
|
||||||
|
try:
|
||||||
|
data = await self.reader.readexactly(n)
|
||||||
|
return data
|
||||||
|
except asyncio.IncompleteReadError:
|
||||||
|
return None
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _unmask(self, payload, masking_key):
|
||||||
|
"""Unmask WebSocket payload data."""
|
||||||
|
if not payload:
|
||||||
|
return payload
|
||||||
|
# XOR each byte with corresponding mask byte
|
||||||
|
return bytes(b ^ masking_key[i % 4] for i, b in enumerate(payload))
|
||||||
|
|
||||||
|
async def _handle_close(self, payload):
|
||||||
|
"""Handle incoming close frame."""
|
||||||
|
if len(payload) >= 2:
|
||||||
|
self.close_code = struct.unpack("!H", payload[:2])[0]
|
||||||
|
self.close_reason = payload[2:].decode("utf-8", errors="replace")
|
||||||
|
else:
|
||||||
|
self.close_code = CLOSE_NO_STATUS
|
||||||
|
self.close_reason = ""
|
||||||
|
|
||||||
|
# Echo close frame back if we haven't already sent one
|
||||||
|
if not self.closed:
|
||||||
|
await self._send_close(self.close_code, self.close_reason)
|
||||||
|
|
||||||
|
self.closed = True
|
||||||
|
|
||||||
|
async def _handle_continuation(self, payload):
|
||||||
|
"""Handle continuation frame (already processed in _read_frame)."""
|
||||||
|
# This is called for partial fragments, nothing to do
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _send_frame(self, opcode, payload):
|
||||||
|
"""Send a WebSocket frame.
|
||||||
|
|
||||||
|
Server frames are not masked (RFC 6455).
|
||||||
|
"""
|
||||||
|
if isinstance(payload, str):
|
||||||
|
payload = payload.encode("utf-8")
|
||||||
|
|
||||||
|
length = len(payload)
|
||||||
|
frame = bytearray()
|
||||||
|
|
||||||
|
# First byte: FIN + opcode
|
||||||
|
frame.append(0x80 | opcode)
|
||||||
|
|
||||||
|
# Second byte: length (no mask bit for server)
|
||||||
|
if length < 126:
|
||||||
|
frame.append(length)
|
||||||
|
elif length < 65536:
|
||||||
|
frame.append(126)
|
||||||
|
frame.extend(struct.pack("!H", length))
|
||||||
|
else:
|
||||||
|
frame.append(127)
|
||||||
|
frame.extend(struct.pack("!Q", length))
|
||||||
|
|
||||||
|
# Payload
|
||||||
|
frame.extend(payload)
|
||||||
|
|
||||||
|
self.transport.write(bytes(frame))
|
||||||
|
|
||||||
|
async def _send_close(self, code, reason=""):
|
||||||
|
"""Send a close frame."""
|
||||||
|
payload = struct.pack("!H", code)
|
||||||
|
if reason:
|
||||||
|
payload += reason.encode("utf-8")[:123] # Max 125 bytes total
|
||||||
|
await self._send_frame(OPCODE_CLOSE, payload)
|
||||||
|
self.closed = True
|
||||||
@ -2440,3 +2440,94 @@ class HeaderMap(Setting):
|
|||||||
|
|
||||||
.. versionadded:: 22.0.0
|
.. versionadded:: 22.0.0
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def validate_asgi_loop(val):
|
||||||
|
if val is None:
|
||||||
|
return "auto"
|
||||||
|
if not isinstance(val, str):
|
||||||
|
raise TypeError("Invalid type for casting: %s" % val)
|
||||||
|
val = val.lower().strip()
|
||||||
|
if val not in ("auto", "asyncio", "uvloop"):
|
||||||
|
raise ValueError("Invalid ASGI loop: %s" % val)
|
||||||
|
return val
|
||||||
|
|
||||||
|
|
||||||
|
def validate_asgi_lifespan(val):
|
||||||
|
if val is None:
|
||||||
|
return "auto"
|
||||||
|
if not isinstance(val, str):
|
||||||
|
raise TypeError("Invalid type for casting: %s" % val)
|
||||||
|
val = val.lower().strip()
|
||||||
|
if val not in ("auto", "on", "off"):
|
||||||
|
raise ValueError("Invalid ASGI lifespan: %s" % val)
|
||||||
|
return val
|
||||||
|
|
||||||
|
|
||||||
|
class ASGILoop(Setting):
|
||||||
|
name = "asgi_loop"
|
||||||
|
section = "Worker Processes"
|
||||||
|
cli = ["--asgi-loop"]
|
||||||
|
meta = "STRING"
|
||||||
|
validator = validate_asgi_loop
|
||||||
|
default = "auto"
|
||||||
|
desc = """\
|
||||||
|
Event loop implementation for ASGI workers.
|
||||||
|
|
||||||
|
- auto: Use uvloop if available, otherwise asyncio
|
||||||
|
- asyncio: Use Python's built-in asyncio event loop
|
||||||
|
- uvloop: Use uvloop (must be installed separately)
|
||||||
|
|
||||||
|
This setting only affects the ``asgi`` worker type.
|
||||||
|
|
||||||
|
uvloop typically provides better performance but requires
|
||||||
|
installing the uvloop package.
|
||||||
|
|
||||||
|
.. versionadded:: 24.0.0
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class ASGILifespan(Setting):
|
||||||
|
name = "asgi_lifespan"
|
||||||
|
section = "Worker Processes"
|
||||||
|
cli = ["--asgi-lifespan"]
|
||||||
|
meta = "STRING"
|
||||||
|
validator = validate_asgi_lifespan
|
||||||
|
default = "auto"
|
||||||
|
desc = """\
|
||||||
|
Control ASGI lifespan protocol handling.
|
||||||
|
|
||||||
|
- auto: Detect if app supports lifespan, enable if so
|
||||||
|
- on: Always run lifespan protocol (fail if unsupported)
|
||||||
|
- off: Never run lifespan protocol
|
||||||
|
|
||||||
|
The lifespan protocol allows ASGI applications to run code at
|
||||||
|
startup and shutdown. This is essential for frameworks like
|
||||||
|
FastAPI that need to initialize database connections, caches,
|
||||||
|
or other resources.
|
||||||
|
|
||||||
|
This setting only affects the ``asgi`` worker type.
|
||||||
|
|
||||||
|
.. versionadded:: 24.0.0
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class RootPath(Setting):
|
||||||
|
name = "root_path"
|
||||||
|
section = "Server Mechanics"
|
||||||
|
cli = ["--root-path"]
|
||||||
|
meta = "STRING"
|
||||||
|
validator = validate_string
|
||||||
|
default = ""
|
||||||
|
desc = """\
|
||||||
|
The root path for ASGI applications.
|
||||||
|
|
||||||
|
This is used to set the ``root_path`` in the ASGI scope, which
|
||||||
|
allows applications to know their mount point when behind a
|
||||||
|
reverse proxy.
|
||||||
|
|
||||||
|
For example, if your application is mounted at ``/api``, set
|
||||||
|
this to ``/api``.
|
||||||
|
|
||||||
|
.. versionadded:: 24.0.0
|
||||||
|
"""
|
||||||
|
|||||||
@ -11,4 +11,5 @@ SUPPORTED_WORKERS = {
|
|||||||
"gevent_pywsgi": "gunicorn.workers.ggevent.GeventPyWSGIWorker",
|
"gevent_pywsgi": "gunicorn.workers.ggevent.GeventPyWSGIWorker",
|
||||||
"tornado": "gunicorn.workers.gtornado.TornadoWorker",
|
"tornado": "gunicorn.workers.gtornado.TornadoWorker",
|
||||||
"gthread": "gunicorn.workers.gthread.ThreadWorker",
|
"gthread": "gunicorn.workers.gthread.ThreadWorker",
|
||||||
|
"asgi": "gunicorn.workers.gasgi.ASGIWorker",
|
||||||
}
|
}
|
||||||
|
|||||||
282
gunicorn/workers/gasgi.py
Normal file
282
gunicorn/workers/gasgi.py
Normal file
@ -0,0 +1,282 @@
|
|||||||
|
#
|
||||||
|
# This file is part of gunicorn released under the MIT license.
|
||||||
|
# See the NOTICE for more information.
|
||||||
|
|
||||||
|
"""
|
||||||
|
ASGI worker for gunicorn.
|
||||||
|
|
||||||
|
Provides native asyncio-based ASGI support using gunicorn's own
|
||||||
|
HTTP parsing infrastructure.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import ssl
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from gunicorn.workers import base
|
||||||
|
from gunicorn.asgi.protocol import ASGIProtocol
|
||||||
|
|
||||||
|
|
||||||
|
class ASGIWorker(base.Worker):
|
||||||
|
"""ASGI worker using asyncio event loop.
|
||||||
|
|
||||||
|
Supports:
|
||||||
|
- HTTP/1.1 with keepalive
|
||||||
|
- WebSocket connections
|
||||||
|
- Lifespan protocol (startup/shutdown hooks)
|
||||||
|
- Optional uvloop for improved performance
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.worker_connections = self.cfg.worker_connections
|
||||||
|
self.loop = None
|
||||||
|
self.servers = []
|
||||||
|
self.nr_conns = 0
|
||||||
|
self.lifespan = None
|
||||||
|
self.state = {} # Shared state for lifespan
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def check_config(cls, cfg, log):
|
||||||
|
"""Validate configuration for ASGI worker."""
|
||||||
|
if cfg.threads > 1:
|
||||||
|
log.warning("ASGI worker does not use threads configuration. "
|
||||||
|
"Use worker_connections instead.")
|
||||||
|
|
||||||
|
def init_process(self):
|
||||||
|
"""Initialize the worker process."""
|
||||||
|
# Setup event loop before calling super()
|
||||||
|
self._setup_event_loop()
|
||||||
|
super().init_process()
|
||||||
|
|
||||||
|
def _setup_event_loop(self):
|
||||||
|
"""Setup the asyncio event loop."""
|
||||||
|
loop_type = getattr(self.cfg, 'asgi_loop', 'auto')
|
||||||
|
|
||||||
|
if loop_type == "auto":
|
||||||
|
try:
|
||||||
|
import uvloop
|
||||||
|
loop_type = "uvloop"
|
||||||
|
except ImportError:
|
||||||
|
loop_type = "asyncio"
|
||||||
|
|
||||||
|
if loop_type == "uvloop":
|
||||||
|
try:
|
||||||
|
import uvloop
|
||||||
|
self.loop = uvloop.new_event_loop()
|
||||||
|
self.log.debug("Using uvloop event loop")
|
||||||
|
except ImportError:
|
||||||
|
self.log.warning("uvloop not available, falling back to asyncio")
|
||||||
|
self.loop = asyncio.new_event_loop()
|
||||||
|
else:
|
||||||
|
self.loop = asyncio.new_event_loop()
|
||||||
|
self.log.debug("Using asyncio event loop")
|
||||||
|
|
||||||
|
asyncio.set_event_loop(self.loop)
|
||||||
|
|
||||||
|
def load_wsgi(self):
|
||||||
|
"""Load the ASGI application."""
|
||||||
|
try:
|
||||||
|
self.asgi = self.app.wsgi()
|
||||||
|
except SyntaxError as e:
|
||||||
|
if not self.cfg.reload:
|
||||||
|
raise
|
||||||
|
self.log.exception(e)
|
||||||
|
self.asgi = self._make_error_app(str(e))
|
||||||
|
|
||||||
|
def _make_error_app(self, error_msg):
|
||||||
|
"""Create an error ASGI app for syntax errors during reload."""
|
||||||
|
async def error_app(scope, receive, send):
|
||||||
|
if scope["type"] == "http":
|
||||||
|
await send({
|
||||||
|
"type": "http.response.start",
|
||||||
|
"status": 500,
|
||||||
|
"headers": [(b"content-type", b"text/plain")],
|
||||||
|
})
|
||||||
|
await send({
|
||||||
|
"type": "http.response.body",
|
||||||
|
"body": f"Application error: {error_msg}".encode(),
|
||||||
|
})
|
||||||
|
elif scope["type"] == "lifespan":
|
||||||
|
message = await receive()
|
||||||
|
if message["type"] == "lifespan.startup":
|
||||||
|
await send({"type": "lifespan.startup.complete"})
|
||||||
|
message = await receive()
|
||||||
|
if message["type"] == "lifespan.shutdown":
|
||||||
|
await send({"type": "lifespan.shutdown.complete"})
|
||||||
|
return error_app
|
||||||
|
|
||||||
|
def init_signals(self):
|
||||||
|
"""Initialize signal handlers for asyncio."""
|
||||||
|
# Reset all signals first
|
||||||
|
for s in self.SIGNALS:
|
||||||
|
signal.signal(s, signal.SIG_DFL)
|
||||||
|
|
||||||
|
# Set up signal handlers via the event loop
|
||||||
|
self.loop.add_signal_handler(signal.SIGQUIT, self.handle_quit_signal)
|
||||||
|
self.loop.add_signal_handler(signal.SIGTERM, self.handle_exit_signal)
|
||||||
|
self.loop.add_signal_handler(signal.SIGINT, self.handle_quit_signal)
|
||||||
|
self.loop.add_signal_handler(signal.SIGUSR1, self.handle_usr1_signal)
|
||||||
|
self.loop.add_signal_handler(signal.SIGWINCH, self.handle_winch_signal)
|
||||||
|
self.loop.add_signal_handler(signal.SIGABRT, self.handle_abort_signal)
|
||||||
|
|
||||||
|
def handle_quit_signal(self):
|
||||||
|
"""Handle SIGQUIT - immediate shutdown."""
|
||||||
|
self.alive = False
|
||||||
|
self.cfg.worker_int(self)
|
||||||
|
|
||||||
|
def handle_exit_signal(self):
|
||||||
|
"""Handle SIGTERM - graceful shutdown."""
|
||||||
|
self.alive = False
|
||||||
|
|
||||||
|
def handle_usr1_signal(self):
|
||||||
|
"""Handle SIGUSR1 - reopen log files."""
|
||||||
|
self.log.reopen_files()
|
||||||
|
|
||||||
|
def handle_winch_signal(self):
|
||||||
|
"""Handle SIGWINCH - ignored in worker."""
|
||||||
|
self.log.debug("worker: SIGWINCH ignored.")
|
||||||
|
|
||||||
|
def handle_abort_signal(self):
|
||||||
|
"""Handle SIGABRT - abort."""
|
||||||
|
self.alive = False
|
||||||
|
self.cfg.worker_abort(self)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
"""Main entry point for the worker."""
|
||||||
|
try:
|
||||||
|
self.loop.run_until_complete(self._serve())
|
||||||
|
except Exception as e:
|
||||||
|
self.log.exception("Worker exception: %s", e)
|
||||||
|
finally:
|
||||||
|
self._cleanup()
|
||||||
|
|
||||||
|
async def _serve(self):
|
||||||
|
"""Main async serving loop."""
|
||||||
|
# Run lifespan startup
|
||||||
|
lifespan_mode = getattr(self.cfg, 'asgi_lifespan', 'auto')
|
||||||
|
if lifespan_mode != "off":
|
||||||
|
from gunicorn.asgi.lifespan import LifespanManager
|
||||||
|
self.lifespan = LifespanManager(self.asgi, self.log, self.state)
|
||||||
|
try:
|
||||||
|
await self.lifespan.startup()
|
||||||
|
except Exception as e:
|
||||||
|
if lifespan_mode == "on":
|
||||||
|
self.log.error("ASGI lifespan startup failed: %s", e)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
# auto mode - app doesn't support lifespan
|
||||||
|
self.log.debug("ASGI lifespan not supported by app: %s", e)
|
||||||
|
self.lifespan = None
|
||||||
|
|
||||||
|
# Create servers for each listener socket
|
||||||
|
ssl_context = self._get_ssl_context()
|
||||||
|
|
||||||
|
for sock in self.sockets:
|
||||||
|
try:
|
||||||
|
server = await self.loop.create_server(
|
||||||
|
lambda: ASGIProtocol(self),
|
||||||
|
sock=sock.sock,
|
||||||
|
ssl=ssl_context,
|
||||||
|
reuse_address=True,
|
||||||
|
start_serving=True,
|
||||||
|
)
|
||||||
|
self.servers.append(server)
|
||||||
|
self.log.info("ASGI server listening on %s", sock)
|
||||||
|
except Exception as e:
|
||||||
|
self.log.error("Failed to create server on %s: %s", sock, e)
|
||||||
|
|
||||||
|
if not self.servers:
|
||||||
|
self.log.error("No servers could be started")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Main loop with heartbeat
|
||||||
|
try:
|
||||||
|
while self.alive:
|
||||||
|
self.notify()
|
||||||
|
|
||||||
|
# Check if parent is still alive
|
||||||
|
if self.ppid != os.getppid():
|
||||||
|
self.log.info("Parent changed, shutting down: %s", self)
|
||||||
|
break
|
||||||
|
|
||||||
|
# Check connection limit
|
||||||
|
# (Connections are managed by nr_conns in ASGIProtocol)
|
||||||
|
|
||||||
|
await asyncio.sleep(1.0)
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Graceful shutdown
|
||||||
|
await self._shutdown()
|
||||||
|
|
||||||
|
async def _shutdown(self):
|
||||||
|
"""Perform graceful shutdown."""
|
||||||
|
self.log.info("Worker shutting down...")
|
||||||
|
|
||||||
|
# Stop accepting new connections
|
||||||
|
for server in self.servers:
|
||||||
|
server.close()
|
||||||
|
|
||||||
|
# Wait for servers to close
|
||||||
|
for server in self.servers:
|
||||||
|
await server.wait_closed()
|
||||||
|
|
||||||
|
# Wait for in-flight connections (with timeout)
|
||||||
|
graceful_timeout = self.cfg.graceful_timeout
|
||||||
|
if self.nr_conns > 0:
|
||||||
|
self.log.info("Waiting for %d connections to finish...", self.nr_conns)
|
||||||
|
deadline = self.loop.time() + graceful_timeout
|
||||||
|
while self.nr_conns > 0 and self.loop.time() < deadline:
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
if self.nr_conns > 0:
|
||||||
|
self.log.warning("Closing %d connections after timeout", self.nr_conns)
|
||||||
|
|
||||||
|
# Run lifespan shutdown
|
||||||
|
if self.lifespan:
|
||||||
|
try:
|
||||||
|
await self.lifespan.shutdown()
|
||||||
|
except Exception as e:
|
||||||
|
self.log.error("ASGI lifespan shutdown error: %s", e)
|
||||||
|
|
||||||
|
def _get_ssl_context(self):
|
||||||
|
"""Get SSL context if configured."""
|
||||||
|
if not self.cfg.is_ssl:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
from gunicorn import sock
|
||||||
|
return sock.ssl_context(self.cfg)
|
||||||
|
except Exception as e:
|
||||||
|
self.log.error("Failed to create SSL context: %s", e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _cleanup(self):
|
||||||
|
"""Clean up resources on exit."""
|
||||||
|
try:
|
||||||
|
# Cancel all pending tasks
|
||||||
|
pending = asyncio.all_tasks(self.loop)
|
||||||
|
for task in pending:
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
|
# Run loop until all tasks are cancelled
|
||||||
|
if pending:
|
||||||
|
self.loop.run_until_complete(
|
||||||
|
asyncio.gather(*pending, return_exceptions=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.loop.close()
|
||||||
|
except Exception as e:
|
||||||
|
self.log.debug("Cleanup error: %s", e)
|
||||||
|
|
||||||
|
# Close sockets
|
||||||
|
for s in self.sockets:
|
||||||
|
try:
|
||||||
|
s.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
@ -58,6 +58,7 @@ testing = [
|
|||||||
"coverage",
|
"coverage",
|
||||||
"pytest",
|
"pytest",
|
||||||
"pytest-cov",
|
"pytest-cov",
|
||||||
|
"pytest-asyncio",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|||||||
285
tests/test_asgi.py
Normal file
285
tests/test_asgi.py
Normal file
@ -0,0 +1,285 @@
|
|||||||
|
#
|
||||||
|
# This file is part of gunicorn released under the MIT license.
|
||||||
|
# See the NOTICE for more information.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Tests for ASGI worker components.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import io
|
||||||
|
import pytest
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
from gunicorn.asgi.unreader import AsyncUnreader
|
||||||
|
from gunicorn.asgi.message import AsyncRequest
|
||||||
|
|
||||||
|
|
||||||
|
class MockStreamReader:
|
||||||
|
"""Mock asyncio.StreamReader for testing."""
|
||||||
|
|
||||||
|
def __init__(self, data):
|
||||||
|
self.data = data
|
||||||
|
self.pos = 0
|
||||||
|
|
||||||
|
async def read(self, size=-1):
|
||||||
|
if self.pos >= len(self.data):
|
||||||
|
return b""
|
||||||
|
if size < 0:
|
||||||
|
result = self.data[self.pos:]
|
||||||
|
self.pos = len(self.data)
|
||||||
|
else:
|
||||||
|
result = self.data[self.pos:self.pos + size]
|
||||||
|
self.pos += size
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def readexactly(self, n):
|
||||||
|
if self.pos + n > len(self.data):
|
||||||
|
raise asyncio.IncompleteReadError(
|
||||||
|
self.data[self.pos:], n
|
||||||
|
)
|
||||||
|
result = self.data[self.pos:self.pos + n]
|
||||||
|
self.pos += n
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class MockConfig:
|
||||||
|
"""Mock gunicorn config for testing."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.is_ssl = False
|
||||||
|
self.proxy_protocol = False
|
||||||
|
self.proxy_allow_ips = ["127.0.0.1"]
|
||||||
|
self.forwarded_allow_ips = ["127.0.0.1"]
|
||||||
|
self.secure_scheme_headers = {}
|
||||||
|
self.forwarder_headers = []
|
||||||
|
self.limit_request_line = 8190
|
||||||
|
self.limit_request_fields = 100
|
||||||
|
self.limit_request_field_size = 8190
|
||||||
|
self.permit_unconventional_http_method = False
|
||||||
|
self.permit_unconventional_http_version = False
|
||||||
|
self.permit_obsolete_folding = False
|
||||||
|
self.casefold_http_method = False
|
||||||
|
self.strip_header_spaces = False
|
||||||
|
self.header_map = "refuse"
|
||||||
|
|
||||||
|
|
||||||
|
# AsyncUnreader Tests
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_unreader_read_chunk():
|
||||||
|
"""Test basic chunk reading."""
|
||||||
|
reader = MockStreamReader(b"hello world")
|
||||||
|
unreader = AsyncUnreader(reader)
|
||||||
|
data = await unreader.read()
|
||||||
|
assert data == b"hello world"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_unreader_read_size():
|
||||||
|
"""Test reading specific size."""
|
||||||
|
reader = MockStreamReader(b"hello world")
|
||||||
|
unreader = AsyncUnreader(reader)
|
||||||
|
data = await unreader.read(5)
|
||||||
|
assert data == b"hello"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_unreader_unread():
|
||||||
|
"""Test unread functionality."""
|
||||||
|
reader = MockStreamReader(b"hello world")
|
||||||
|
unreader = AsyncUnreader(reader)
|
||||||
|
|
||||||
|
# Read all data
|
||||||
|
data = await unreader.read()
|
||||||
|
assert data == b"hello world"
|
||||||
|
|
||||||
|
# Unread some data
|
||||||
|
unreader.unread(b"world")
|
||||||
|
|
||||||
|
# Read again should get unread data
|
||||||
|
data = await unreader.read()
|
||||||
|
assert data == b"world"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_unreader_read_zero():
|
||||||
|
"""Test reading zero bytes."""
|
||||||
|
reader = MockStreamReader(b"hello")
|
||||||
|
unreader = AsyncUnreader(reader)
|
||||||
|
data = await unreader.read(0)
|
||||||
|
assert data == b""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_unreader_read_empty():
|
||||||
|
"""Test reading from empty stream."""
|
||||||
|
reader = MockStreamReader(b"")
|
||||||
|
unreader = AsyncUnreader(reader)
|
||||||
|
data = await unreader.read()
|
||||||
|
assert data == b""
|
||||||
|
|
||||||
|
|
||||||
|
# AsyncRequest Tests
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_request_simple_get():
|
||||||
|
"""Test parsing a simple GET request."""
|
||||||
|
request_data = b"GET /path HTTP/1.1\r\nHost: localhost\r\n\r\n"
|
||||||
|
reader = MockStreamReader(request_data)
|
||||||
|
unreader = AsyncUnreader(reader)
|
||||||
|
cfg = MockConfig()
|
||||||
|
|
||||||
|
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||||
|
|
||||||
|
assert request.method == "GET"
|
||||||
|
assert request.path == "/path"
|
||||||
|
assert request.version == (1, 1)
|
||||||
|
assert ("HOST", "localhost") in request.headers
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_request_with_query():
|
||||||
|
"""Test parsing request with query string."""
|
||||||
|
request_data = b"GET /search?q=test&page=1 HTTP/1.1\r\nHost: localhost\r\n\r\n"
|
||||||
|
reader = MockStreamReader(request_data)
|
||||||
|
unreader = AsyncUnreader(reader)
|
||||||
|
cfg = MockConfig()
|
||||||
|
|
||||||
|
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||||
|
|
||||||
|
assert request.method == "GET"
|
||||||
|
assert request.path == "/search"
|
||||||
|
assert request.query == "q=test&page=1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_request_post_with_body():
|
||||||
|
"""Test parsing POST request with body."""
|
||||||
|
request_data = (
|
||||||
|
b"POST /submit HTTP/1.1\r\n"
|
||||||
|
b"Host: localhost\r\n"
|
||||||
|
b"Content-Length: 11\r\n"
|
||||||
|
b"\r\n"
|
||||||
|
b"hello=world"
|
||||||
|
)
|
||||||
|
reader = MockStreamReader(request_data)
|
||||||
|
unreader = AsyncUnreader(reader)
|
||||||
|
cfg = MockConfig()
|
||||||
|
|
||||||
|
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||||
|
|
||||||
|
assert request.method == "POST"
|
||||||
|
assert request.path == "/submit"
|
||||||
|
assert request.content_length == 11
|
||||||
|
|
||||||
|
# Read body
|
||||||
|
body = await request.read_body(100)
|
||||||
|
assert body == b"hello=world"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_request_multiple_headers():
|
||||||
|
"""Test parsing request with multiple headers."""
|
||||||
|
request_data = (
|
||||||
|
b"GET / HTTP/1.1\r\n"
|
||||||
|
b"Host: localhost\r\n"
|
||||||
|
b"Accept: text/html\r\n"
|
||||||
|
b"Accept-Language: en-US\r\n"
|
||||||
|
b"Connection: keep-alive\r\n"
|
||||||
|
b"\r\n"
|
||||||
|
)
|
||||||
|
reader = MockStreamReader(request_data)
|
||||||
|
unreader = AsyncUnreader(reader)
|
||||||
|
cfg = MockConfig()
|
||||||
|
|
||||||
|
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||||
|
|
||||||
|
assert len(request.headers) == 4
|
||||||
|
assert request.get_header("HOST") == "localhost"
|
||||||
|
assert request.get_header("ACCEPT") == "text/html"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_request_should_close_http10():
|
||||||
|
"""Test connection close detection for HTTP/1.0."""
|
||||||
|
request_data = b"GET / HTTP/1.0\r\nHost: localhost\r\n\r\n"
|
||||||
|
reader = MockStreamReader(request_data)
|
||||||
|
unreader = AsyncUnreader(reader)
|
||||||
|
cfg = MockConfig()
|
||||||
|
|
||||||
|
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||||
|
|
||||||
|
assert request.version == (1, 0)
|
||||||
|
assert request.should_close() is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_request_should_close_connection_header():
|
||||||
|
"""Test connection close detection with Connection header."""
|
||||||
|
request_data = b"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"
|
||||||
|
reader = MockStreamReader(request_data)
|
||||||
|
unreader = AsyncUnreader(reader)
|
||||||
|
cfg = MockConfig()
|
||||||
|
|
||||||
|
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||||
|
|
||||||
|
assert request.should_close() is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_request_keepalive():
|
||||||
|
"""Test keepalive detection."""
|
||||||
|
request_data = b"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: keep-alive\r\n\r\n"
|
||||||
|
reader = MockStreamReader(request_data)
|
||||||
|
unreader = AsyncUnreader(reader)
|
||||||
|
cfg = MockConfig()
|
||||||
|
|
||||||
|
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||||
|
|
||||||
|
assert request.should_close() is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_request_no_body_for_get():
|
||||||
|
"""Test that GET requests have no body by default."""
|
||||||
|
request_data = b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n"
|
||||||
|
reader = MockStreamReader(request_data)
|
||||||
|
unreader = AsyncUnreader(reader)
|
||||||
|
cfg = MockConfig()
|
||||||
|
|
||||||
|
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||||
|
|
||||||
|
assert request.content_length == 0
|
||||||
|
body = await request.read_body()
|
||||||
|
assert body == b""
|
||||||
|
|
||||||
|
|
||||||
|
# Error handling tests
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_request_invalid_method():
|
||||||
|
"""Test invalid HTTP method detection."""
|
||||||
|
from gunicorn.http.errors import InvalidRequestMethod
|
||||||
|
|
||||||
|
request_data = b"ge!t / HTTP/1.1\r\nHost: localhost\r\n\r\n"
|
||||||
|
reader = MockStreamReader(request_data)
|
||||||
|
unreader = AsyncUnreader(reader)
|
||||||
|
cfg = MockConfig()
|
||||||
|
|
||||||
|
with pytest.raises(InvalidRequestMethod):
|
||||||
|
await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_request_invalid_http_version():
|
||||||
|
"""Test invalid HTTP version detection."""
|
||||||
|
from gunicorn.http.errors import InvalidHTTPVersion
|
||||||
|
|
||||||
|
request_data = b"GET / HTTP/2.0\r\nHost: localhost\r\n\r\n"
|
||||||
|
reader = MockStreamReader(request_data)
|
||||||
|
unreader = AsyncUnreader(reader)
|
||||||
|
cfg = MockConfig()
|
||||||
|
|
||||||
|
with pytest.raises(InvalidHTTPVersion):
|
||||||
|
await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||||
643
tests/test_asgi_worker.py
Normal file
643
tests/test_asgi_worker.py
Normal file
@ -0,0 +1,643 @@
|
|||||||
|
#
|
||||||
|
# This file is part of gunicorn released under the MIT license.
|
||||||
|
# See the NOTICE for more information.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Tests for the ASGI worker.
|
||||||
|
|
||||||
|
Includes unit tests for worker components and integration tests
|
||||||
|
that actually start the server and make HTTP requests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import errno
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import socket
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from gunicorn.config import Config
|
||||||
|
from gunicorn.workers import gasgi
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Mock Classes
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class FakeSocket:
|
||||||
|
"""Mock socket for testing."""
|
||||||
|
|
||||||
|
def __init__(self, data=b''):
|
||||||
|
self.data = data
|
||||||
|
self.closed = False
|
||||||
|
self.blocking = True
|
||||||
|
self._fileno = id(self) % 65536
|
||||||
|
|
||||||
|
def fileno(self):
|
||||||
|
return self._fileno
|
||||||
|
|
||||||
|
def setblocking(self, blocking):
|
||||||
|
self.blocking = blocking
|
||||||
|
|
||||||
|
def recv(self, size):
|
||||||
|
if self.closed:
|
||||||
|
raise OSError(errno.EBADF, "Bad file descriptor")
|
||||||
|
result = self.data[:size]
|
||||||
|
self.data = self.data[size:]
|
||||||
|
return result
|
||||||
|
|
||||||
|
def send(self, data):
|
||||||
|
if self.closed:
|
||||||
|
raise OSError(errno.EPIPE, "Broken pipe")
|
||||||
|
return len(data)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.closed = True
|
||||||
|
|
||||||
|
def getsockname(self):
|
||||||
|
return ('127.0.0.1', 8000)
|
||||||
|
|
||||||
|
def getpeername(self):
|
||||||
|
return ('127.0.0.1', 12345)
|
||||||
|
|
||||||
|
|
||||||
|
class FakeApp:
|
||||||
|
"""Mock ASGI application for testing."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.calls = []
|
||||||
|
|
||||||
|
def wsgi(self):
|
||||||
|
return self.asgi_app
|
||||||
|
|
||||||
|
async def asgi_app(self, scope, receive, send):
|
||||||
|
self.calls.append(scope)
|
||||||
|
if scope["type"] == "lifespan":
|
||||||
|
while True:
|
||||||
|
message = await receive()
|
||||||
|
if message["type"] == "lifespan.startup":
|
||||||
|
await send({"type": "lifespan.startup.complete"})
|
||||||
|
elif message["type"] == "lifespan.shutdown":
|
||||||
|
await send({"type": "lifespan.shutdown.complete"})
|
||||||
|
return
|
||||||
|
elif scope["type"] == "http":
|
||||||
|
await send({
|
||||||
|
"type": "http.response.start",
|
||||||
|
"status": 200,
|
||||||
|
"headers": [(b"content-type", b"text/plain")],
|
||||||
|
})
|
||||||
|
await send({
|
||||||
|
"type": "http.response.body",
|
||||||
|
"body": b"Hello from ASGI!",
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
class FakeListener:
|
||||||
|
"""Mock listener socket."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.sock = FakeSocket()
|
||||||
|
|
||||||
|
def getsockname(self):
|
||||||
|
return ('127.0.0.1', 8000)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
self.sock.close()
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return "http://127.0.0.1:8000"
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Helper Functions
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
def _has_uvloop():
|
||||||
|
"""Check if uvloop is available."""
|
||||||
|
try:
|
||||||
|
import uvloop
|
||||||
|
return True
|
||||||
|
except ImportError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Unit Tests for ASGIWorker
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class TestASGIWorkerInit:
|
||||||
|
"""Tests for ASGIWorker initialization."""
|
||||||
|
|
||||||
|
def create_worker(self, **kwargs):
|
||||||
|
"""Create a worker for testing."""
|
||||||
|
cfg = Config()
|
||||||
|
cfg.set('workers', 1)
|
||||||
|
cfg.set('worker_connections', 1000)
|
||||||
|
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
cfg.set(key, value)
|
||||||
|
|
||||||
|
worker = gasgi.ASGIWorker(
|
||||||
|
age=1,
|
||||||
|
ppid=os.getpid(),
|
||||||
|
sockets=[],
|
||||||
|
app=FakeApp(),
|
||||||
|
timeout=30,
|
||||||
|
cfg=cfg,
|
||||||
|
log=mock.Mock(),
|
||||||
|
)
|
||||||
|
return worker
|
||||||
|
|
||||||
|
def test_worker_init(self):
|
||||||
|
"""Test worker initialization."""
|
||||||
|
worker = self.create_worker()
|
||||||
|
|
||||||
|
assert worker.worker_connections == 1000
|
||||||
|
assert worker.nr_conns == 0
|
||||||
|
assert worker.loop is None
|
||||||
|
assert worker.servers == []
|
||||||
|
assert worker.state == {}
|
||||||
|
|
||||||
|
def test_worker_connections_config(self):
|
||||||
|
"""Test worker_connections configuration."""
|
||||||
|
worker = self.create_worker(worker_connections=500)
|
||||||
|
assert worker.worker_connections == 500
|
||||||
|
|
||||||
|
|
||||||
|
class TestASGIWorkerEventLoop:
|
||||||
|
"""Tests for event loop setup."""
|
||||||
|
|
||||||
|
def create_worker(self, **kwargs):
|
||||||
|
"""Create a worker for testing."""
|
||||||
|
cfg = Config()
|
||||||
|
cfg.set('workers', 1)
|
||||||
|
cfg.set('worker_connections', 1000)
|
||||||
|
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
cfg.set(key, value)
|
||||||
|
|
||||||
|
worker = gasgi.ASGIWorker(
|
||||||
|
age=1,
|
||||||
|
ppid=os.getpid(),
|
||||||
|
sockets=[],
|
||||||
|
app=FakeApp(),
|
||||||
|
timeout=30,
|
||||||
|
cfg=cfg,
|
||||||
|
log=mock.Mock(),
|
||||||
|
)
|
||||||
|
return worker
|
||||||
|
|
||||||
|
def test_setup_asyncio_loop(self):
|
||||||
|
"""Test asyncio event loop setup."""
|
||||||
|
worker = self.create_worker(asgi_loop='asyncio')
|
||||||
|
worker._setup_event_loop()
|
||||||
|
|
||||||
|
assert worker.loop is not None
|
||||||
|
assert isinstance(worker.loop, asyncio.AbstractEventLoop)
|
||||||
|
worker.loop.close()
|
||||||
|
|
||||||
|
def test_setup_auto_loop_falls_back_to_asyncio(self):
|
||||||
|
"""Test that auto mode uses asyncio when uvloop unavailable."""
|
||||||
|
worker = self.create_worker(asgi_loop='auto')
|
||||||
|
|
||||||
|
# Mock uvloop import failure
|
||||||
|
with mock.patch.dict('sys.modules', {'uvloop': None}):
|
||||||
|
worker._setup_event_loop()
|
||||||
|
|
||||||
|
assert worker.loop is not None
|
||||||
|
worker.loop.close()
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not _has_uvloop(),
|
||||||
|
reason="uvloop not installed"
|
||||||
|
)
|
||||||
|
def test_setup_uvloop(self):
|
||||||
|
"""Test uvloop event loop setup."""
|
||||||
|
worker = self.create_worker(asgi_loop='uvloop')
|
||||||
|
worker._setup_event_loop()
|
||||||
|
|
||||||
|
import uvloop
|
||||||
|
assert isinstance(worker.loop, uvloop.Loop)
|
||||||
|
worker.loop.close()
|
||||||
|
|
||||||
|
|
||||||
|
class TestASGIWorkerSignals:
|
||||||
|
"""Tests for signal handling."""
|
||||||
|
|
||||||
|
def create_worker(self):
|
||||||
|
"""Create a worker for testing."""
|
||||||
|
cfg = Config()
|
||||||
|
cfg.set('workers', 1)
|
||||||
|
cfg.set('worker_connections', 1000)
|
||||||
|
cfg.set('graceful_timeout', 5)
|
||||||
|
|
||||||
|
worker = gasgi.ASGIWorker(
|
||||||
|
age=1,
|
||||||
|
ppid=os.getpid(),
|
||||||
|
sockets=[],
|
||||||
|
app=FakeApp(),
|
||||||
|
timeout=30,
|
||||||
|
cfg=cfg,
|
||||||
|
log=mock.Mock(),
|
||||||
|
)
|
||||||
|
worker._setup_event_loop()
|
||||||
|
return worker
|
||||||
|
|
||||||
|
def test_handle_exit_sets_alive_false(self):
|
||||||
|
"""Test that exit signal sets alive=False."""
|
||||||
|
worker = self.create_worker()
|
||||||
|
worker.alive = True
|
||||||
|
|
||||||
|
worker.handle_exit_signal()
|
||||||
|
|
||||||
|
assert worker.alive is False
|
||||||
|
worker.loop.close()
|
||||||
|
|
||||||
|
def test_handle_quit_sets_alive_false(self):
|
||||||
|
"""Test that quit signal sets alive=False."""
|
||||||
|
worker = self.create_worker()
|
||||||
|
worker.alive = True
|
||||||
|
|
||||||
|
# Mock the worker_int callback on the worker's cfg settings
|
||||||
|
with mock.patch.object(worker.cfg.settings['worker_int'], 'get', return_value=lambda w: None):
|
||||||
|
worker.handle_quit_signal()
|
||||||
|
|
||||||
|
assert worker.alive is False
|
||||||
|
worker.loop.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Tests for Lifespan Protocol
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class TestLifespanManager:
|
||||||
|
"""Tests for ASGI lifespan protocol."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_lifespan_startup_complete(self):
|
||||||
|
"""Test successful lifespan startup."""
|
||||||
|
from gunicorn.asgi.lifespan import LifespanManager
|
||||||
|
|
||||||
|
startup_called = False
|
||||||
|
shutdown_called = False
|
||||||
|
|
||||||
|
async def app(scope, receive, send):
|
||||||
|
nonlocal startup_called, shutdown_called
|
||||||
|
assert scope["type"] == "lifespan"
|
||||||
|
while True:
|
||||||
|
message = await receive()
|
||||||
|
if message["type"] == "lifespan.startup":
|
||||||
|
startup_called = True
|
||||||
|
await send({"type": "lifespan.startup.complete"})
|
||||||
|
elif message["type"] == "lifespan.shutdown":
|
||||||
|
shutdown_called = True
|
||||||
|
await send({"type": "lifespan.shutdown.complete"})
|
||||||
|
return
|
||||||
|
|
||||||
|
manager = LifespanManager(app, mock.Mock())
|
||||||
|
await manager.startup()
|
||||||
|
|
||||||
|
assert startup_called
|
||||||
|
assert manager._startup_complete.is_set()
|
||||||
|
assert not manager._startup_failed
|
||||||
|
|
||||||
|
await manager.shutdown()
|
||||||
|
assert shutdown_called
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_lifespan_startup_failed(self):
|
||||||
|
"""Test lifespan startup failure."""
|
||||||
|
from gunicorn.asgi.lifespan import LifespanManager
|
||||||
|
|
||||||
|
async def app(scope, receive, send):
|
||||||
|
message = await receive()
|
||||||
|
if message["type"] == "lifespan.startup":
|
||||||
|
await send({
|
||||||
|
"type": "lifespan.startup.failed",
|
||||||
|
"message": "Database connection failed"
|
||||||
|
})
|
||||||
|
|
||||||
|
manager = LifespanManager(app, mock.Mock())
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="Database connection failed"):
|
||||||
|
await manager.startup()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_lifespan_state_shared(self):
|
||||||
|
"""Test that lifespan state is shared with app."""
|
||||||
|
from gunicorn.asgi.lifespan import LifespanManager
|
||||||
|
|
||||||
|
state = {}
|
||||||
|
|
||||||
|
async def app(scope, receive, send):
|
||||||
|
assert "state" in scope
|
||||||
|
scope["state"]["db"] = "connected"
|
||||||
|
message = await receive()
|
||||||
|
await send({"type": "lifespan.startup.complete"})
|
||||||
|
message = await receive()
|
||||||
|
await send({"type": "lifespan.shutdown.complete"})
|
||||||
|
|
||||||
|
manager = LifespanManager(app, mock.Mock(), state)
|
||||||
|
await manager.startup()
|
||||||
|
|
||||||
|
assert state.get("db") == "connected"
|
||||||
|
|
||||||
|
await manager.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Tests for WebSocket Protocol
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class TestWebSocketProtocol:
|
||||||
|
"""Tests for WebSocket protocol handling."""
|
||||||
|
|
||||||
|
def test_websocket_guid(self):
|
||||||
|
"""Test WebSocket GUID constant."""
|
||||||
|
from gunicorn.asgi.websocket import WS_GUID
|
||||||
|
assert WS_GUID == b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
||||||
|
|
||||||
|
def test_websocket_opcodes(self):
|
||||||
|
"""Test WebSocket opcode constants."""
|
||||||
|
from gunicorn.asgi import websocket
|
||||||
|
|
||||||
|
assert websocket.OPCODE_TEXT == 0x1
|
||||||
|
assert websocket.OPCODE_BINARY == 0x2
|
||||||
|
assert websocket.OPCODE_CLOSE == 0x8
|
||||||
|
assert websocket.OPCODE_PING == 0x9
|
||||||
|
assert websocket.OPCODE_PONG == 0xA
|
||||||
|
|
||||||
|
def test_websocket_accept_key_calculation(self):
|
||||||
|
"""Test WebSocket accept key calculation per RFC 6455."""
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
from gunicorn.asgi.websocket import WS_GUID
|
||||||
|
|
||||||
|
# Example from RFC 6455
|
||||||
|
client_key = b"dGhlIHNhbXBsZSBub25jZQ=="
|
||||||
|
expected_accept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
|
||||||
|
|
||||||
|
accept_key = base64.b64encode(
|
||||||
|
hashlib.sha1(client_key + WS_GUID).digest()
|
||||||
|
).decode("ascii")
|
||||||
|
|
||||||
|
assert accept_key == expected_accept
|
||||||
|
|
||||||
|
def test_websocket_frame_masking(self):
|
||||||
|
"""Test WebSocket frame unmasking."""
|
||||||
|
from gunicorn.asgi.websocket import WebSocketProtocol
|
||||||
|
|
||||||
|
# Create a minimal protocol instance
|
||||||
|
protocol = WebSocketProtocol(None, None, {}, None, mock.Mock())
|
||||||
|
|
||||||
|
# Test unmasking (XOR operation)
|
||||||
|
masking_key = bytes([0x37, 0xfa, 0x21, 0x3d])
|
||||||
|
masked_data = bytes([0x7f, 0x9f, 0x4d, 0x51, 0x58]) # "Hello" masked
|
||||||
|
|
||||||
|
unmasked = protocol._unmask(masked_data, masking_key)
|
||||||
|
assert unmasked == b"Hello"
|
||||||
|
|
||||||
|
def test_websocket_frame_masking_empty(self):
|
||||||
|
"""Test WebSocket frame unmasking with empty payload."""
|
||||||
|
from gunicorn.asgi.websocket import WebSocketProtocol
|
||||||
|
|
||||||
|
protocol = WebSocketProtocol(None, None, {}, None, mock.Mock())
|
||||||
|
|
||||||
|
masking_key = bytes([0x37, 0xfa, 0x21, 0x3d])
|
||||||
|
unmasked = protocol._unmask(b"", masking_key)
|
||||||
|
assert unmasked == b""
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Integration Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class TestASGIIntegration:
|
||||||
|
"""Integration tests that start actual servers."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def free_port(self):
|
||||||
|
"""Get a free port for testing."""
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
|
s.bind(('127.0.0.1', 0))
|
||||||
|
return s.getsockname()[1]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_http_request_response(self, free_port):
|
||||||
|
"""Test basic HTTP request/response cycle."""
|
||||||
|
# Simple ASGI app
|
||||||
|
async def app(scope, receive, send):
|
||||||
|
if scope["type"] == "http":
|
||||||
|
await send({
|
||||||
|
"type": "http.response.start",
|
||||||
|
"status": 200,
|
||||||
|
"headers": [(b"content-type", b"text/plain")],
|
||||||
|
})
|
||||||
|
await send({
|
||||||
|
"type": "http.response.body",
|
||||||
|
"body": b"Hello, World!",
|
||||||
|
})
|
||||||
|
|
||||||
|
# Start server
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
server = await loop.create_server(
|
||||||
|
lambda: _TestProtocol(app),
|
||||||
|
'127.0.0.1',
|
||||||
|
free_port,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Use asyncio to make HTTP request
|
||||||
|
reader, writer = await asyncio.open_connection('127.0.0.1', free_port)
|
||||||
|
request = f"GET / HTTP/1.1\r\nHost: 127.0.0.1:{free_port}\r\n\r\n"
|
||||||
|
writer.write(request.encode())
|
||||||
|
await writer.drain()
|
||||||
|
|
||||||
|
# Read response
|
||||||
|
response = await reader.read(4096)
|
||||||
|
response_text = response.decode()
|
||||||
|
|
||||||
|
assert "HTTP/1.1 200" in response_text
|
||||||
|
assert "Hello, World!" in response_text
|
||||||
|
|
||||||
|
writer.close()
|
||||||
|
await writer.wait_closed()
|
||||||
|
finally:
|
||||||
|
server.close()
|
||||||
|
await server.wait_closed()
|
||||||
|
|
||||||
|
|
||||||
|
class _TestProtocol(asyncio.Protocol):
|
||||||
|
"""Minimal protocol for integration testing."""
|
||||||
|
|
||||||
|
def __init__(self, app):
|
||||||
|
self.app = app
|
||||||
|
self.transport = None
|
||||||
|
|
||||||
|
def connection_made(self, transport):
|
||||||
|
self.transport = transport
|
||||||
|
|
||||||
|
def data_received(self, data):
|
||||||
|
# Very simple HTTP parsing for testing
|
||||||
|
asyncio.create_task(self._handle(data))
|
||||||
|
|
||||||
|
async def _handle(self, data):
|
||||||
|
# Parse basic HTTP request
|
||||||
|
lines = data.decode().split('\r\n')
|
||||||
|
method, path, _ = lines[0].split(' ')
|
||||||
|
|
||||||
|
scope = {
|
||||||
|
"type": "http",
|
||||||
|
"asgi": {"version": "3.0"},
|
||||||
|
"http_version": "1.1",
|
||||||
|
"method": method,
|
||||||
|
"path": path,
|
||||||
|
"query_string": b"",
|
||||||
|
"headers": [],
|
||||||
|
"server": ("127.0.0.1", 8000),
|
||||||
|
"client": ("127.0.0.1", 12345),
|
||||||
|
}
|
||||||
|
|
||||||
|
async def receive():
|
||||||
|
return {"type": "http.request", "body": b"", "more_body": False}
|
||||||
|
|
||||||
|
async def send(message):
|
||||||
|
if message["type"] == "http.response.start":
|
||||||
|
status = message["status"]
|
||||||
|
headers = message.get("headers", [])
|
||||||
|
response = f"HTTP/1.1 {status} OK\r\n"
|
||||||
|
for name, value in headers:
|
||||||
|
if isinstance(name, bytes):
|
||||||
|
name = name.decode()
|
||||||
|
if isinstance(value, bytes):
|
||||||
|
value = value.decode()
|
||||||
|
response += f"{name}: {value}\r\n"
|
||||||
|
response += "\r\n"
|
||||||
|
self.transport.write(response.encode())
|
||||||
|
elif message["type"] == "http.response.body":
|
||||||
|
body = message.get("body", b"")
|
||||||
|
self.transport.write(body)
|
||||||
|
if not message.get("more_body", False):
|
||||||
|
self.transport.close()
|
||||||
|
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# ASGI Protocol Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class TestASGIProtocol:
|
||||||
|
"""Tests for ASGIProtocol."""
|
||||||
|
|
||||||
|
def test_reason_phrases(self):
|
||||||
|
"""Test HTTP reason phrase lookup."""
|
||||||
|
from gunicorn.asgi.protocol import ASGIProtocol
|
||||||
|
|
||||||
|
# Create minimal worker mock
|
||||||
|
worker = mock.Mock()
|
||||||
|
worker.cfg = Config()
|
||||||
|
worker.log = mock.Mock()
|
||||||
|
worker.asgi = mock.Mock()
|
||||||
|
|
||||||
|
protocol = ASGIProtocol(worker)
|
||||||
|
|
||||||
|
assert protocol._get_reason_phrase(200) == "OK"
|
||||||
|
assert protocol._get_reason_phrase(404) == "Not Found"
|
||||||
|
assert protocol._get_reason_phrase(500) == "Internal Server Error"
|
||||||
|
assert protocol._get_reason_phrase(999) == "Unknown"
|
||||||
|
|
||||||
|
def test_scope_building(self):
|
||||||
|
"""Test HTTP scope building."""
|
||||||
|
from gunicorn.asgi.protocol import ASGIProtocol
|
||||||
|
from gunicorn.asgi.message import AsyncRequest
|
||||||
|
|
||||||
|
worker = mock.Mock()
|
||||||
|
worker.cfg = Config()
|
||||||
|
worker.cfg.set('root_path', '/api')
|
||||||
|
worker.log = mock.Mock()
|
||||||
|
worker.asgi = mock.Mock()
|
||||||
|
|
||||||
|
protocol = ASGIProtocol(worker)
|
||||||
|
|
||||||
|
# Create mock request
|
||||||
|
request = mock.Mock()
|
||||||
|
request.method = "GET"
|
||||||
|
request.path = "/users"
|
||||||
|
request.query = "page=1"
|
||||||
|
request.version = (1, 1)
|
||||||
|
request.scheme = "http"
|
||||||
|
request.headers = [("HOST", "localhost"), ("ACCEPT", "text/html")]
|
||||||
|
|
||||||
|
scope = protocol._build_http_scope(
|
||||||
|
request,
|
||||||
|
("127.0.0.1", 8000), # sockname
|
||||||
|
("127.0.0.1", 12345), # peername
|
||||||
|
)
|
||||||
|
|
||||||
|
assert scope["type"] == "http"
|
||||||
|
assert scope["method"] == "GET"
|
||||||
|
assert scope["path"] == "/users"
|
||||||
|
assert scope["query_string"] == b"page=1"
|
||||||
|
assert scope["root_path"] == "/api"
|
||||||
|
assert scope["http_version"] == "1.1"
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Config Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class TestASGIConfig:
|
||||||
|
"""Tests for ASGI configuration options."""
|
||||||
|
|
||||||
|
def test_asgi_loop_default(self):
|
||||||
|
"""Test default asgi_loop value."""
|
||||||
|
cfg = Config()
|
||||||
|
assert cfg.asgi_loop == "auto"
|
||||||
|
|
||||||
|
def test_asgi_loop_validation(self):
|
||||||
|
"""Test asgi_loop validation."""
|
||||||
|
cfg = Config()
|
||||||
|
|
||||||
|
cfg.set('asgi_loop', 'asyncio')
|
||||||
|
assert cfg.asgi_loop == 'asyncio'
|
||||||
|
|
||||||
|
cfg.set('asgi_loop', 'uvloop')
|
||||||
|
assert cfg.asgi_loop == 'uvloop'
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
cfg.set('asgi_loop', 'invalid')
|
||||||
|
|
||||||
|
def test_asgi_lifespan_default(self):
|
||||||
|
"""Test default asgi_lifespan value."""
|
||||||
|
cfg = Config()
|
||||||
|
assert cfg.asgi_lifespan == "auto"
|
||||||
|
|
||||||
|
def test_asgi_lifespan_validation(self):
|
||||||
|
"""Test asgi_lifespan validation."""
|
||||||
|
cfg = Config()
|
||||||
|
|
||||||
|
cfg.set('asgi_lifespan', 'on')
|
||||||
|
assert cfg.asgi_lifespan == 'on'
|
||||||
|
|
||||||
|
cfg.set('asgi_lifespan', 'off')
|
||||||
|
assert cfg.asgi_lifespan == 'off'
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
cfg.set('asgi_lifespan', 'invalid')
|
||||||
|
|
||||||
|
def test_root_path_default(self):
|
||||||
|
"""Test default root_path value."""
|
||||||
|
cfg = Config()
|
||||||
|
assert cfg.root_path == ""
|
||||||
|
|
||||||
|
def test_root_path_setting(self):
|
||||||
|
"""Test root_path configuration."""
|
||||||
|
cfg = Config()
|
||||||
|
cfg.set('root_path', '/api/v1')
|
||||||
|
assert cfg.root_path == '/api/v1'
|
||||||
Loading…
x
Reference in New Issue
Block a user