mirror of
https://github.com/frappe/gunicorn.git
synced 2026-07-01 18:21:30 +08:00
asgi: Add native ASGI worker with HTTP and WebSocket support
Add a new ASGI worker type that provides native async support using gunicorn's own HTTP parsing infrastructure adapted for asyncio. Features: - HTTP/1.1 with keepalive support - WebSocket connections (RFC 6455) - ASGI lifespan protocol for startup/shutdown hooks - Optional uvloop support for improved performance - Full proxy protocol support (inherited from gunicorn) New configuration options: - --asgi-loop: Event loop selection (auto/asyncio/uvloop) - --asgi-lifespan: Lifespan protocol control (auto/on/off) - --root-path: ASGI root path for reverse proxy setups Usage: gunicorn -k asgi myapp:app
This commit is contained in:
parent
ea98400820
commit
ae1eea8108
7
examples/asgi/__init__.py
Normal file
7
examples/asgi/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
ASGI example applications for gunicorn.
|
||||
"""
|
||||
130
examples/asgi/basic_app.py
Normal file
130
examples/asgi/basic_app.py
Normal file
@ -0,0 +1,130 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
Basic ASGI application example.
|
||||
|
||||
Run with:
|
||||
gunicorn -k asgi examples.asgi.basic_app:app
|
||||
|
||||
Test with:
|
||||
curl http://127.0.0.1:8000/
|
||||
curl http://127.0.0.1:8000/hello
|
||||
curl -X POST http://127.0.0.1:8000/echo -d "test data"
|
||||
"""
|
||||
|
||||
|
||||
async def app(scope, receive, send):
|
||||
"""Simple ASGI application demonstrating basic functionality."""
|
||||
|
||||
if scope["type"] == "lifespan":
|
||||
await handle_lifespan(scope, receive, send)
|
||||
elif scope["type"] == "http":
|
||||
await handle_http(scope, receive, send)
|
||||
else:
|
||||
raise ValueError(f"Unknown scope type: {scope['type']}")
|
||||
|
||||
|
||||
async def handle_lifespan(scope, receive, send):
|
||||
"""Handle lifespan events (startup/shutdown)."""
|
||||
while True:
|
||||
message = await receive()
|
||||
if message["type"] == "lifespan.startup":
|
||||
print("ASGI application starting up...")
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
elif message["type"] == "lifespan.shutdown":
|
||||
print("ASGI application shutting down...")
|
||||
await send({"type": "lifespan.shutdown.complete"})
|
||||
return
|
||||
|
||||
|
||||
async def handle_http(scope, receive, send):
|
||||
"""Handle HTTP requests."""
|
||||
path = scope["path"]
|
||||
method = scope["method"]
|
||||
|
||||
if path == "/" and method == "GET":
|
||||
await send_response(send, 200, b"Welcome to gunicorn ASGI!\n")
|
||||
|
||||
elif path == "/hello" and method == "GET":
|
||||
name = get_query_param(scope, "name", "World")
|
||||
body = f"Hello, {name}!\n".encode()
|
||||
await send_response(send, 200, body)
|
||||
|
||||
elif path == "/echo" and method == "POST":
|
||||
body = await read_body(receive)
|
||||
await send_response(send, 200, body, content_type=b"application/octet-stream")
|
||||
|
||||
elif path == "/headers":
|
||||
headers_info = format_headers(scope["headers"])
|
||||
await send_response(send, 200, headers_info.encode())
|
||||
|
||||
elif path == "/info":
|
||||
info = format_request_info(scope)
|
||||
await send_response(send, 200, info.encode(), content_type=b"application/json")
|
||||
|
||||
else:
|
||||
await send_response(send, 404, b"Not Found\n")
|
||||
|
||||
|
||||
async def send_response(send, status, body, content_type=b"text/plain"):
|
||||
"""Send an HTTP response."""
|
||||
await send({
|
||||
"type": "http.response.start",
|
||||
"status": status,
|
||||
"headers": [
|
||||
(b"content-type", content_type),
|
||||
(b"content-length", str(len(body)).encode()),
|
||||
],
|
||||
})
|
||||
await send({
|
||||
"type": "http.response.body",
|
||||
"body": body,
|
||||
})
|
||||
|
||||
|
||||
async def read_body(receive):
|
||||
"""Read the full request body."""
|
||||
body = b""
|
||||
while True:
|
||||
message = await receive()
|
||||
body += message.get("body", b"")
|
||||
if not message.get("more_body", False):
|
||||
break
|
||||
return body
|
||||
|
||||
|
||||
def get_query_param(scope, name, default=None):
|
||||
"""Get a query parameter value."""
|
||||
query_string = scope.get("query_string", b"").decode()
|
||||
for param in query_string.split("&"):
|
||||
if "=" in param:
|
||||
key, value = param.split("=", 1)
|
||||
if key == name:
|
||||
return value
|
||||
return default
|
||||
|
||||
|
||||
def format_headers(headers):
|
||||
"""Format headers for display."""
|
||||
lines = ["Request Headers:"]
|
||||
for name, value in headers:
|
||||
lines.append(f" {name.decode()}: {value.decode()}")
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
|
||||
def format_request_info(scope):
|
||||
"""Format request info as JSON."""
|
||||
import json
|
||||
info = {
|
||||
"method": scope["method"],
|
||||
"path": scope["path"],
|
||||
"query_string": scope.get("query_string", b"").decode(),
|
||||
"http_version": scope["http_version"],
|
||||
"scheme": scope["scheme"],
|
||||
"server": list(scope.get("server") or []),
|
||||
"client": list(scope.get("client") or []),
|
||||
"root_path": scope.get("root_path", ""),
|
||||
}
|
||||
return json.dumps(info, indent=2) + "\n"
|
||||
235
examples/asgi/websocket_app.py
Normal file
235
examples/asgi/websocket_app.py
Normal file
@ -0,0 +1,235 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
WebSocket ASGI application example.
|
||||
|
||||
Run with:
|
||||
gunicorn -k asgi examples.asgi.websocket_app:app
|
||||
|
||||
Test with:
|
||||
# Using websocat (install with: cargo install websocat)
|
||||
websocat ws://127.0.0.1:8000/ws
|
||||
|
||||
# Or using Python websockets library
|
||||
python -c "
|
||||
import asyncio
|
||||
import websockets
|
||||
async def test():
|
||||
async with websockets.connect('ws://127.0.0.1:8000/ws') as ws:
|
||||
await ws.send('Hello')
|
||||
print(await ws.recv())
|
||||
asyncio.run(test())
|
||||
"
|
||||
"""
|
||||
|
||||
|
||||
async def app(scope, receive, send):
|
||||
"""ASGI application with WebSocket support."""
|
||||
|
||||
if scope["type"] == "lifespan":
|
||||
await handle_lifespan(scope, receive, send)
|
||||
elif scope["type"] == "http":
|
||||
await handle_http(scope, receive, send)
|
||||
elif scope["type"] == "websocket":
|
||||
await handle_websocket(scope, receive, send)
|
||||
else:
|
||||
raise ValueError(f"Unknown scope type: {scope['type']}")
|
||||
|
||||
|
||||
async def handle_lifespan(scope, receive, send):
|
||||
"""Handle lifespan events."""
|
||||
while True:
|
||||
message = await receive()
|
||||
if message["type"] == "lifespan.startup":
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
elif message["type"] == "lifespan.shutdown":
|
||||
await send({"type": "lifespan.shutdown.complete"})
|
||||
return
|
||||
|
||||
|
||||
async def handle_http(scope, receive, send):
|
||||
"""Handle HTTP requests - serve a simple HTML page for WebSocket testing."""
|
||||
path = scope["path"]
|
||||
|
||||
if path == "/":
|
||||
html = HTML_PAGE.encode()
|
||||
await send({
|
||||
"type": "http.response.start",
|
||||
"status": 200,
|
||||
"headers": [
|
||||
(b"content-type", b"text/html"),
|
||||
(b"content-length", str(len(html)).encode()),
|
||||
],
|
||||
})
|
||||
await send({
|
||||
"type": "http.response.body",
|
||||
"body": html,
|
||||
})
|
||||
else:
|
||||
await send({
|
||||
"type": "http.response.start",
|
||||
"status": 404,
|
||||
"headers": [(b"content-type", b"text/plain")],
|
||||
})
|
||||
await send({
|
||||
"type": "http.response.body",
|
||||
"body": b"Not Found",
|
||||
})
|
||||
|
||||
|
||||
async def handle_websocket(scope, receive, send):
|
||||
"""Handle WebSocket connections."""
|
||||
path = scope["path"]
|
||||
|
||||
if path == "/ws":
|
||||
await echo_websocket(scope, receive, send)
|
||||
elif path == "/ws/chat":
|
||||
await chat_websocket(scope, receive, send)
|
||||
else:
|
||||
# Reject the connection
|
||||
await send({"type": "websocket.close", "code": 4004})
|
||||
|
||||
|
||||
async def echo_websocket(scope, receive, send):
|
||||
"""Echo WebSocket - sends back whatever it receives."""
|
||||
# Wait for connection
|
||||
message = await receive()
|
||||
if message["type"] != "websocket.connect":
|
||||
return
|
||||
|
||||
# Accept the connection
|
||||
await send({"type": "websocket.accept"})
|
||||
|
||||
# Echo loop
|
||||
try:
|
||||
while True:
|
||||
message = await receive()
|
||||
|
||||
if message["type"] == "websocket.disconnect":
|
||||
break
|
||||
|
||||
if message["type"] == "websocket.receive":
|
||||
if "text" in message:
|
||||
# Echo text back
|
||||
await send({
|
||||
"type": "websocket.send",
|
||||
"text": f"Echo: {message['text']}"
|
||||
})
|
||||
elif "bytes" in message:
|
||||
# Echo bytes back
|
||||
await send({
|
||||
"type": "websocket.send",
|
||||
"bytes": message["bytes"]
|
||||
})
|
||||
except Exception as e:
|
||||
print(f"WebSocket error: {e}")
|
||||
finally:
|
||||
try:
|
||||
await send({"type": "websocket.close", "code": 1000})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
async def chat_websocket(scope, receive, send):
|
||||
"""Chat WebSocket - simple broadcast example."""
|
||||
message = await receive()
|
||||
if message["type"] != "websocket.connect":
|
||||
return
|
||||
|
||||
await send({
|
||||
"type": "websocket.accept",
|
||||
"subprotocol": "chat"
|
||||
})
|
||||
|
||||
await send({
|
||||
"type": "websocket.send",
|
||||
"text": "Welcome to the chat! Send messages and they will be echoed back."
|
||||
})
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await receive()
|
||||
|
||||
if message["type"] == "websocket.disconnect":
|
||||
break
|
||||
|
||||
if message["type"] == "websocket.receive" and "text" in message:
|
||||
text = message["text"]
|
||||
await send({
|
||||
"type": "websocket.send",
|
||||
"text": f"[You]: {text}"
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
HTML_PAGE = """<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>WebSocket Test</title>
|
||||
<style>
|
||||
body { font-family: sans-serif; max-width: 800px; margin: 50px auto; padding: 20px; }
|
||||
#messages { border: 1px solid #ccc; height: 300px; overflow-y: auto; padding: 10px; margin-bottom: 10px; }
|
||||
#input { width: 80%; padding: 10px; }
|
||||
button { padding: 10px 20px; }
|
||||
.sent { color: blue; }
|
||||
.received { color: green; }
|
||||
.error { color: red; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>WebSocket Test</h1>
|
||||
<div id="messages"></div>
|
||||
<input type="text" id="input" placeholder="Type a message...">
|
||||
<button onclick="sendMessage()">Send</button>
|
||||
<button onclick="connectWS()">Connect</button>
|
||||
<button onclick="disconnectWS()">Disconnect</button>
|
||||
|
||||
<script>
|
||||
let ws = null;
|
||||
const messages = document.getElementById('messages');
|
||||
const input = document.getElementById('input');
|
||||
|
||||
function log(msg, className) {
|
||||
const div = document.createElement('div');
|
||||
div.className = className || '';
|
||||
div.textContent = msg;
|
||||
messages.appendChild(div);
|
||||
messages.scrollTop = messages.scrollHeight;
|
||||
}
|
||||
|
||||
function connectWS() {
|
||||
if (ws) {
|
||||
log('Already connected', 'error');
|
||||
return;
|
||||
}
|
||||
ws = new WebSocket('ws://' + window.location.host + '/ws');
|
||||
ws.onopen = () => log('Connected!', 'received');
|
||||
ws.onclose = () => { log('Disconnected', 'error'); ws = null; };
|
||||
ws.onerror = (e) => log('Error: ' + e, 'error');
|
||||
ws.onmessage = (e) => log(e.data, 'received');
|
||||
}
|
||||
|
||||
function disconnectWS() {
|
||||
if (ws) ws.close();
|
||||
}
|
||||
|
||||
function sendMessage() {
|
||||
if (!ws) { log('Not connected', 'error'); return; }
|
||||
const msg = input.value;
|
||||
if (!msg) return;
|
||||
ws.send(msg);
|
||||
log('Sent: ' + msg, 'sent');
|
||||
input.value = '';
|
||||
}
|
||||
|
||||
input.onkeypress = (e) => { if (e.key === 'Enter') sendMessage(); };
|
||||
|
||||
// Auto-connect
|
||||
connectWS();
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
26
gunicorn/asgi/__init__.py
Normal file
26
gunicorn/asgi/__init__.py
Normal file
@ -0,0 +1,26 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
ASGI support for gunicorn.
|
||||
|
||||
This module provides native ASGI worker support, using gunicorn's own
|
||||
HTTP parsing infrastructure adapted for async I/O.
|
||||
|
||||
Components:
|
||||
- AsyncUnreader: Async socket reading with pushback buffer
|
||||
- AsyncRequest: Async HTTP request parser
|
||||
- ASGIProtocol: asyncio.Protocol implementation for HTTP handling
|
||||
- WebSocketProtocol: WebSocket protocol handler (RFC 6455)
|
||||
- LifespanManager: ASGI lifespan protocol support
|
||||
|
||||
Usage:
|
||||
gunicorn -k asgi myapp:app
|
||||
"""
|
||||
|
||||
from gunicorn.asgi.unreader import AsyncUnreader
|
||||
from gunicorn.asgi.message import AsyncRequest
|
||||
from gunicorn.asgi.lifespan import LifespanManager
|
||||
|
||||
__all__ = ['AsyncUnreader', 'AsyncRequest', 'LifespanManager']
|
||||
178
gunicorn/asgi/lifespan.py
Normal file
178
gunicorn/asgi/lifespan.py
Normal file
@ -0,0 +1,178 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
ASGI lifespan protocol manager.
|
||||
|
||||
Manages startup and shutdown events for ASGI applications,
|
||||
enabling frameworks like FastAPI to run initialization and
|
||||
cleanup code.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
|
||||
class LifespanManager:
|
||||
"""Manages ASGI lifespan events (startup/shutdown).
|
||||
|
||||
The lifespan protocol allows ASGI applications to run code at
|
||||
startup and shutdown. This is essential for applications that
|
||||
need to initialize database connections, caches, or other
|
||||
resources.
|
||||
|
||||
ASGI lifespan messages:
|
||||
- Server sends: {"type": "lifespan.startup"}
|
||||
- App responds: {"type": "lifespan.startup.complete"} or
|
||||
{"type": "lifespan.startup.failed", "message": "..."}
|
||||
- Server sends: {"type": "lifespan.shutdown"}
|
||||
- App responds: {"type": "lifespan.shutdown.complete"}
|
||||
"""
|
||||
|
||||
def __init__(self, app, logger, state=None):
|
||||
"""Initialize the lifespan manager.
|
||||
|
||||
Args:
|
||||
app: ASGI application callable
|
||||
logger: Logger instance
|
||||
state: Shared state dict for the application
|
||||
"""
|
||||
self.app = app
|
||||
self.logger = logger
|
||||
self.state = state if state is not None else {}
|
||||
|
||||
self._startup_complete = asyncio.Event()
|
||||
self._shutdown_complete = asyncio.Event()
|
||||
self._startup_failed = False
|
||||
self._startup_error = None
|
||||
self._shutdown_error = None
|
||||
self._receive_queue = asyncio.Queue()
|
||||
self._task = None
|
||||
self._app_finished = False
|
||||
|
||||
async def startup(self):
|
||||
"""Run lifespan startup and wait for completion.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If startup fails or app doesn't support lifespan
|
||||
"""
|
||||
scope = {
|
||||
"type": "lifespan",
|
||||
"asgi": {"version": "3.0", "spec_version": "2.4"},
|
||||
"state": self.state,
|
||||
}
|
||||
|
||||
# Send startup event
|
||||
await self._receive_queue.put({"type": "lifespan.startup"})
|
||||
|
||||
# Run lifespan in background task
|
||||
self._task = asyncio.create_task(self._run_lifespan(scope))
|
||||
|
||||
# Wait for startup with timeout
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._startup_complete.wait(),
|
||||
timeout=30.0 # Reasonable startup timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
raise RuntimeError("Lifespan startup timed out")
|
||||
|
||||
if self._startup_failed:
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
msg = self._startup_error or "Unknown error"
|
||||
raise RuntimeError(f"Lifespan startup failed: {msg}")
|
||||
|
||||
self.logger.debug("ASGI lifespan startup complete")
|
||||
|
||||
async def shutdown(self):
|
||||
"""Signal shutdown and wait for completion.
|
||||
|
||||
This should be called during graceful shutdown.
|
||||
"""
|
||||
if self._app_finished:
|
||||
self.logger.debug("ASGI lifespan already finished")
|
||||
return
|
||||
|
||||
# Send shutdown event
|
||||
await self._receive_queue.put({"type": "lifespan.shutdown"})
|
||||
|
||||
# Wait for shutdown with timeout
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._shutdown_complete.wait(),
|
||||
timeout=30.0 # Reasonable shutdown timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
self.logger.warning("Lifespan shutdown timed out")
|
||||
|
||||
if self._shutdown_error:
|
||||
self.logger.error("Lifespan shutdown error: %s", self._shutdown_error)
|
||||
|
||||
# Cancel the task if still running
|
||||
if self._task and not self._task.done():
|
||||
self._task.cancel()
|
||||
try:
|
||||
await self._task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
self.logger.debug("ASGI lifespan shutdown complete")
|
||||
|
||||
async def _run_lifespan(self, scope):
|
||||
"""Run the ASGI lifespan protocol."""
|
||||
try:
|
||||
await self.app(scope, self._receive, self._send)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.debug("Lifespan application raised: %s", e)
|
||||
# If startup hasn't completed, mark it as failed
|
||||
if not self._startup_complete.is_set():
|
||||
self._startup_failed = True
|
||||
self._startup_error = str(e)
|
||||
self._startup_complete.set()
|
||||
# If shutdown hasn't completed, mark error
|
||||
elif not self._shutdown_complete.is_set():
|
||||
self._shutdown_error = str(e)
|
||||
self._shutdown_complete.set()
|
||||
finally:
|
||||
self._app_finished = True
|
||||
# Ensure events are set to unblock waiters
|
||||
if not self._startup_complete.is_set():
|
||||
self._startup_failed = True
|
||||
self._startup_error = "Application exited before startup complete"
|
||||
self._startup_complete.set()
|
||||
if not self._shutdown_complete.is_set():
|
||||
self._shutdown_complete.set()
|
||||
|
||||
async def _receive(self):
|
||||
"""ASGI receive callable for lifespan."""
|
||||
return await self._receive_queue.get()
|
||||
|
||||
async def _send(self, message):
|
||||
"""ASGI send callable for lifespan."""
|
||||
msg_type = message["type"]
|
||||
|
||||
if msg_type == "lifespan.startup.complete":
|
||||
self._startup_complete.set()
|
||||
self.logger.debug("Received lifespan.startup.complete")
|
||||
|
||||
elif msg_type == "lifespan.startup.failed":
|
||||
self._startup_failed = True
|
||||
self._startup_error = message.get("message", "")
|
||||
self._startup_complete.set()
|
||||
self.logger.debug("Received lifespan.startup.failed: %s",
|
||||
self._startup_error)
|
||||
|
||||
elif msg_type == "lifespan.shutdown.complete":
|
||||
self._shutdown_complete.set()
|
||||
self.logger.debug("Received lifespan.shutdown.complete")
|
||||
|
||||
elif msg_type == "lifespan.shutdown.failed":
|
||||
self._shutdown_error = message.get("message", "")
|
||||
self._shutdown_complete.set()
|
||||
self.logger.debug("Received lifespan.shutdown.failed: %s",
|
||||
self._shutdown_error)
|
||||
562
gunicorn/asgi/message.py
Normal file
562
gunicorn/asgi/message.py
Normal file
@ -0,0 +1,562 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
Async version of gunicorn/http/message.py for ASGI workers.
|
||||
|
||||
Reuses the parsing logic from the sync version, adapted for async I/O.
|
||||
"""
|
||||
|
||||
import io
|
||||
import re
|
||||
import socket
|
||||
|
||||
from gunicorn.http.errors import (
|
||||
InvalidHeader, InvalidHeaderName, NoMoreData,
|
||||
InvalidRequestLine, InvalidRequestMethod, InvalidHTTPVersion,
|
||||
LimitRequestLine, LimitRequestHeaders,
|
||||
UnsupportedTransferCoding, ObsoleteFolding,
|
||||
InvalidProxyLine, ForbiddenProxyRequest,
|
||||
InvalidSchemeHeaders,
|
||||
)
|
||||
from gunicorn.util import bytes_to_str, split_request_uri
|
||||
|
||||
MAX_REQUEST_LINE = 8190
|
||||
MAX_HEADERS = 32768
|
||||
DEFAULT_MAX_HEADERFIELD_SIZE = 8190
|
||||
|
||||
# Reuse regex patterns from sync version
|
||||
RFC9110_5_6_2_TOKEN_SPECIALS = r"!#$%&'*+-.^_`|~"
|
||||
TOKEN_RE = re.compile(r"[%s0-9a-zA-Z]+" % (re.escape(RFC9110_5_6_2_TOKEN_SPECIALS)))
|
||||
METHOD_BADCHAR_RE = re.compile("[a-z#]")
|
||||
VERSION_RE = re.compile(r"HTTP/(\d)\.(\d)")
|
||||
RFC9110_5_5_INVALID_AND_DANGEROUS = re.compile(r"[\0\r\n]")
|
||||
|
||||
|
||||
class AsyncRequest:
|
||||
"""Async HTTP request parser.
|
||||
|
||||
Parses HTTP/1.x requests using async I/O, reusing gunicorn's
|
||||
parsing logic where possible.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, unreader, peer_addr, req_number=1):
|
||||
self.cfg = cfg
|
||||
self.unreader = unreader
|
||||
self.peer_addr = peer_addr
|
||||
self.remote_addr = peer_addr
|
||||
self.req_number = req_number
|
||||
|
||||
self.version = None
|
||||
self.method = None
|
||||
self.uri = None
|
||||
self.path = None
|
||||
self.query = None
|
||||
self.fragment = None
|
||||
self.headers = []
|
||||
self.trailers = []
|
||||
self.scheme = "https" if cfg.is_ssl else "http"
|
||||
self.must_close = False
|
||||
|
||||
self.proxy_protocol_info = None
|
||||
|
||||
# Request line limit
|
||||
self.limit_request_line = cfg.limit_request_line
|
||||
if (self.limit_request_line < 0
|
||||
or self.limit_request_line >= MAX_REQUEST_LINE):
|
||||
self.limit_request_line = MAX_REQUEST_LINE
|
||||
|
||||
# Headers limits
|
||||
self.limit_request_fields = cfg.limit_request_fields
|
||||
if (self.limit_request_fields <= 0
|
||||
or self.limit_request_fields > MAX_HEADERS):
|
||||
self.limit_request_fields = MAX_HEADERS
|
||||
|
||||
self.limit_request_field_size = cfg.limit_request_field_size
|
||||
if self.limit_request_field_size < 0:
|
||||
self.limit_request_field_size = DEFAULT_MAX_HEADERFIELD_SIZE
|
||||
|
||||
# Max header buffer size
|
||||
max_header_field_size = self.limit_request_field_size or DEFAULT_MAX_HEADERFIELD_SIZE
|
||||
self.max_buffer_headers = self.limit_request_fields * \
|
||||
(max_header_field_size + 2) + 4
|
||||
|
||||
# Body-related state
|
||||
self.content_length = None
|
||||
self.chunked = False
|
||||
self._body_reader = None
|
||||
self._body_remaining = 0
|
||||
|
||||
@classmethod
|
||||
async def parse(cls, cfg, unreader, peer_addr, req_number=1):
|
||||
"""Parse an HTTP request from the stream.
|
||||
|
||||
Args:
|
||||
cfg: gunicorn config object
|
||||
unreader: AsyncUnreader instance
|
||||
peer_addr: client address tuple
|
||||
req_number: request number on this connection (for keepalive)
|
||||
|
||||
Returns:
|
||||
AsyncRequest: Parsed request object
|
||||
|
||||
Raises:
|
||||
NoMoreData: If no data available
|
||||
Various parsing errors for malformed requests
|
||||
"""
|
||||
req = cls(cfg, unreader, peer_addr, req_number)
|
||||
await req._parse()
|
||||
return req
|
||||
|
||||
async def _parse(self):
|
||||
"""Parse the request from the unreader."""
|
||||
buf = io.BytesIO()
|
||||
await self._get_data(buf, stop=True)
|
||||
|
||||
# Get request line
|
||||
line, rbuf = await self._read_line(buf, self.limit_request_line)
|
||||
|
||||
# Proxy protocol
|
||||
if self._proxy_protocol(bytes_to_str(line)):
|
||||
# Get next request line
|
||||
buf = io.BytesIO()
|
||||
buf.write(rbuf)
|
||||
line, rbuf = await self._read_line(buf, self.limit_request_line)
|
||||
|
||||
self._parse_request_line(line)
|
||||
buf = io.BytesIO()
|
||||
buf.write(rbuf)
|
||||
|
||||
# Headers
|
||||
data = buf.getvalue()
|
||||
|
||||
while True:
|
||||
idx = data.find(b"\r\n\r\n")
|
||||
done = data[:2] == b"\r\n"
|
||||
|
||||
if idx < 0 and not done:
|
||||
await self._get_data(buf)
|
||||
data = buf.getvalue()
|
||||
if len(data) > self.max_buffer_headers:
|
||||
raise LimitRequestHeaders("max buffer headers")
|
||||
else:
|
||||
break
|
||||
|
||||
if done:
|
||||
self.unreader.unread(data[2:])
|
||||
else:
|
||||
self.headers = self._parse_headers(data[:idx], from_trailer=False)
|
||||
self.unreader.unread(data[idx + 4:])
|
||||
|
||||
self._set_body_reader()
|
||||
|
||||
async def _get_data(self, buf, stop=False):
|
||||
"""Read data from unreader into buffer."""
|
||||
data = await self.unreader.read()
|
||||
if not data:
|
||||
if stop:
|
||||
raise StopIteration()
|
||||
raise NoMoreData(buf.getvalue())
|
||||
buf.write(data)
|
||||
|
||||
async def _read_line(self, buf, limit=0):
|
||||
"""Read a line from the buffer/stream."""
|
||||
data = buf.getvalue()
|
||||
|
||||
while True:
|
||||
idx = data.find(b"\r\n")
|
||||
if idx >= 0:
|
||||
if idx > limit > 0:
|
||||
raise LimitRequestLine(idx, limit)
|
||||
break
|
||||
if len(data) - 2 > limit > 0:
|
||||
raise LimitRequestLine(len(data), limit)
|
||||
await self._get_data(buf)
|
||||
data = buf.getvalue()
|
||||
|
||||
return (data[:idx], data[idx + 2:])
|
||||
|
||||
def _proxy_protocol(self, line):
|
||||
"""Detect, check and parse proxy protocol."""
|
||||
if not self.cfg.proxy_protocol:
|
||||
return False
|
||||
|
||||
if self.req_number != 1:
|
||||
return False
|
||||
|
||||
if not line.startswith("PROXY"):
|
||||
return False
|
||||
|
||||
self._proxy_protocol_access_check()
|
||||
self._parse_proxy_protocol(line)
|
||||
|
||||
return True
|
||||
|
||||
def _proxy_protocol_access_check(self):
|
||||
"""Check if proxy protocol is allowed from this peer."""
|
||||
if ("*" not in self.cfg.proxy_allow_ips and
|
||||
isinstance(self.peer_addr, tuple) and
|
||||
self.peer_addr[0] not in self.cfg.proxy_allow_ips):
|
||||
raise ForbiddenProxyRequest(self.peer_addr[0])
|
||||
|
||||
def _parse_proxy_protocol(self, line):
|
||||
"""Parse proxy protocol header line."""
|
||||
bits = line.split(" ")
|
||||
|
||||
if len(bits) != 6:
|
||||
raise InvalidProxyLine(line)
|
||||
|
||||
proto = bits[1]
|
||||
s_addr = bits[2]
|
||||
d_addr = bits[3]
|
||||
|
||||
if proto not in ["TCP4", "TCP6"]:
|
||||
raise InvalidProxyLine("protocol '%s' not supported" % proto)
|
||||
|
||||
if proto == "TCP4":
|
||||
try:
|
||||
socket.inet_pton(socket.AF_INET, s_addr)
|
||||
socket.inet_pton(socket.AF_INET, d_addr)
|
||||
except OSError:
|
||||
raise InvalidProxyLine(line)
|
||||
elif proto == "TCP6":
|
||||
try:
|
||||
socket.inet_pton(socket.AF_INET6, s_addr)
|
||||
socket.inet_pton(socket.AF_INET6, d_addr)
|
||||
except OSError:
|
||||
raise InvalidProxyLine(line)
|
||||
|
||||
try:
|
||||
s_port = int(bits[4])
|
||||
d_port = int(bits[5])
|
||||
except ValueError:
|
||||
raise InvalidProxyLine("invalid port %s" % line)
|
||||
|
||||
if not ((0 <= s_port <= 65535) and (0 <= d_port <= 65535)):
|
||||
raise InvalidProxyLine("invalid port %s" % line)
|
||||
|
||||
self.proxy_protocol_info = {
|
||||
"proxy_protocol": proto,
|
||||
"client_addr": s_addr,
|
||||
"client_port": s_port,
|
||||
"proxy_addr": d_addr,
|
||||
"proxy_port": d_port
|
||||
}
|
||||
|
||||
def _parse_request_line(self, line_bytes):
|
||||
"""Parse the HTTP request line."""
|
||||
bits = [bytes_to_str(bit) for bit in line_bytes.split(b" ", 2)]
|
||||
if len(bits) != 3:
|
||||
raise InvalidRequestLine(bytes_to_str(line_bytes))
|
||||
|
||||
# Method
|
||||
self.method = bits[0]
|
||||
|
||||
if not self.cfg.permit_unconventional_http_method:
|
||||
if METHOD_BADCHAR_RE.search(self.method):
|
||||
raise InvalidRequestMethod(self.method)
|
||||
if not 3 <= len(bits[0]) <= 20:
|
||||
raise InvalidRequestMethod(self.method)
|
||||
if not TOKEN_RE.fullmatch(self.method):
|
||||
raise InvalidRequestMethod(self.method)
|
||||
if self.cfg.casefold_http_method:
|
||||
self.method = self.method.upper()
|
||||
|
||||
# URI
|
||||
self.uri = bits[1]
|
||||
|
||||
if len(self.uri) == 0:
|
||||
raise InvalidRequestLine(bytes_to_str(line_bytes))
|
||||
|
||||
try:
|
||||
parts = split_request_uri(self.uri)
|
||||
except ValueError:
|
||||
raise InvalidRequestLine(bytes_to_str(line_bytes))
|
||||
self.path = parts.path or ""
|
||||
self.query = parts.query or ""
|
||||
self.fragment = parts.fragment or ""
|
||||
|
||||
# Version
|
||||
match = VERSION_RE.fullmatch(bits[2])
|
||||
if match is None:
|
||||
raise InvalidHTTPVersion(bits[2])
|
||||
self.version = (int(match.group(1)), int(match.group(2)))
|
||||
if not (1, 0) <= self.version < (2, 0):
|
||||
if not self.cfg.permit_unconventional_http_version:
|
||||
raise InvalidHTTPVersion(self.version)
|
||||
|
||||
def _parse_headers(self, data, from_trailer=False):
|
||||
"""Parse HTTP headers from raw data."""
|
||||
cfg = self.cfg
|
||||
headers = []
|
||||
|
||||
lines = [bytes_to_str(line) for line in data.split(b"\r\n")]
|
||||
|
||||
# Handle scheme headers
|
||||
scheme_header = False
|
||||
secure_scheme_headers = {}
|
||||
forwarder_headers = []
|
||||
if from_trailer:
|
||||
pass
|
||||
elif ('*' in cfg.forwarded_allow_ips or
|
||||
not isinstance(self.peer_addr, tuple)
|
||||
or self.peer_addr[0] in cfg.forwarded_allow_ips):
|
||||
secure_scheme_headers = cfg.secure_scheme_headers
|
||||
forwarder_headers = cfg.forwarder_headers
|
||||
|
||||
while lines:
|
||||
if len(headers) >= self.limit_request_fields:
|
||||
raise LimitRequestHeaders("limit request headers fields")
|
||||
|
||||
curr = lines.pop(0)
|
||||
header_length = len(curr) + len("\r\n")
|
||||
if curr.find(":") <= 0:
|
||||
raise InvalidHeader(curr)
|
||||
name, value = curr.split(":", 1)
|
||||
if self.cfg.strip_header_spaces:
|
||||
name = name.rstrip(" \t")
|
||||
if not TOKEN_RE.fullmatch(name):
|
||||
raise InvalidHeaderName(name)
|
||||
|
||||
name = name.upper()
|
||||
value = [value.strip(" \t")]
|
||||
|
||||
# Consume value continuation lines
|
||||
while lines and lines[0].startswith((" ", "\t")):
|
||||
if not self.cfg.permit_obsolete_folding:
|
||||
raise ObsoleteFolding(name)
|
||||
curr = lines.pop(0)
|
||||
header_length += len(curr) + len("\r\n")
|
||||
if header_length > self.limit_request_field_size > 0:
|
||||
raise LimitRequestHeaders("limit request headers fields size")
|
||||
value.append(curr.strip("\t "))
|
||||
value = " ".join(value)
|
||||
|
||||
if RFC9110_5_5_INVALID_AND_DANGEROUS.search(value):
|
||||
raise InvalidHeader(name)
|
||||
|
||||
if header_length > self.limit_request_field_size > 0:
|
||||
raise LimitRequestHeaders("limit request headers fields size")
|
||||
|
||||
if name in secure_scheme_headers:
|
||||
secure = value == secure_scheme_headers[name]
|
||||
scheme = "https" if secure else "http"
|
||||
if scheme_header:
|
||||
if scheme != self.scheme:
|
||||
raise InvalidSchemeHeaders()
|
||||
else:
|
||||
scheme_header = True
|
||||
self.scheme = scheme
|
||||
|
||||
if "_" in name:
|
||||
if name in forwarder_headers or "*" in forwarder_headers:
|
||||
pass
|
||||
elif self.cfg.header_map == "dangerous":
|
||||
pass
|
||||
elif self.cfg.header_map == "drop":
|
||||
continue
|
||||
else:
|
||||
raise InvalidHeaderName(name)
|
||||
|
||||
headers.append((name, value))
|
||||
|
||||
return headers
|
||||
|
||||
def _set_body_reader(self):
|
||||
"""Determine how to read the request body."""
|
||||
chunked = False
|
||||
content_length = None
|
||||
|
||||
for (name, value) in self.headers:
|
||||
if name == "CONTENT-LENGTH":
|
||||
if content_length is not None:
|
||||
raise InvalidHeader("CONTENT-LENGTH", req=self)
|
||||
content_length = value
|
||||
elif name == "TRANSFER-ENCODING":
|
||||
vals = [v.strip() for v in value.split(',')]
|
||||
for val in vals:
|
||||
if val.lower() == "chunked":
|
||||
if chunked:
|
||||
raise InvalidHeader("TRANSFER-ENCODING", req=self)
|
||||
chunked = True
|
||||
elif val.lower() == "identity":
|
||||
if chunked:
|
||||
raise InvalidHeader("TRANSFER-ENCODING", req=self)
|
||||
elif val.lower() in ('compress', 'deflate', 'gzip'):
|
||||
if chunked:
|
||||
raise InvalidHeader("TRANSFER-ENCODING", req=self)
|
||||
self.force_close()
|
||||
else:
|
||||
raise UnsupportedTransferCoding(value)
|
||||
|
||||
if chunked:
|
||||
if self.version < (1, 1):
|
||||
raise InvalidHeader("TRANSFER-ENCODING", req=self)
|
||||
if content_length is not None:
|
||||
raise InvalidHeader("CONTENT-LENGTH", req=self)
|
||||
self.chunked = True
|
||||
self.content_length = None
|
||||
self._body_remaining = -1
|
||||
elif content_length is not None:
|
||||
try:
|
||||
if str(content_length).isnumeric():
|
||||
content_length = int(content_length)
|
||||
else:
|
||||
raise InvalidHeader("CONTENT-LENGTH", req=self)
|
||||
except ValueError:
|
||||
raise InvalidHeader("CONTENT-LENGTH", req=self)
|
||||
|
||||
if content_length < 0:
|
||||
raise InvalidHeader("CONTENT-LENGTH", req=self)
|
||||
|
||||
self.content_length = content_length
|
||||
self._body_remaining = content_length
|
||||
else:
|
||||
# No body for requests without Content-Length or Transfer-Encoding
|
||||
self.content_length = 0
|
||||
self._body_remaining = 0
|
||||
|
||||
def force_close(self):
|
||||
"""Mark connection for closing after this request."""
|
||||
self.must_close = True
|
||||
|
||||
def should_close(self):
|
||||
"""Check if connection should be closed after this request."""
|
||||
if self.must_close:
|
||||
return True
|
||||
for (h, v) in self.headers:
|
||||
if h == "CONNECTION":
|
||||
v = v.lower().strip(" \t")
|
||||
if v == "close":
|
||||
return True
|
||||
elif v == "keep-alive":
|
||||
return False
|
||||
break
|
||||
return self.version <= (1, 0)
|
||||
|
||||
def get_header(self, name):
|
||||
"""Get a header value by name (case-insensitive)."""
|
||||
name = name.upper()
|
||||
for (h, v) in self.headers:
|
||||
if h == name:
|
||||
return v
|
||||
return None
|
||||
|
||||
async def read_body(self, size=8192):
|
||||
"""Read a chunk of the request body.
|
||||
|
||||
Args:
|
||||
size: Maximum bytes to read
|
||||
|
||||
Returns:
|
||||
bytes: Body data, empty bytes when body is exhausted
|
||||
"""
|
||||
if self._body_remaining == 0:
|
||||
return b""
|
||||
|
||||
if self.chunked:
|
||||
return await self._read_chunked_body(size)
|
||||
else:
|
||||
return await self._read_length_body(size)
|
||||
|
||||
async def _read_length_body(self, size):
|
||||
"""Read from a length-delimited body."""
|
||||
if self._body_remaining <= 0:
|
||||
return b""
|
||||
|
||||
to_read = min(size, self._body_remaining)
|
||||
data = await self.unreader.read(to_read)
|
||||
if data:
|
||||
self._body_remaining -= len(data)
|
||||
return data
|
||||
|
||||
async def _read_chunked_body(self, size):
|
||||
"""Read from a chunked body."""
|
||||
if self._body_reader is None:
|
||||
self._body_reader = self._chunked_body_reader()
|
||||
|
||||
try:
|
||||
return await self._body_reader.__anext__()
|
||||
except StopAsyncIteration:
|
||||
self._body_remaining = 0
|
||||
return b""
|
||||
|
||||
async def _chunked_body_reader(self):
|
||||
"""Async generator for reading chunked body."""
|
||||
while True:
|
||||
# Read chunk size line
|
||||
size_line = await self._read_chunk_size_line()
|
||||
# Parse chunk size (handle extensions)
|
||||
chunk_size, *_ = size_line.split(b";", 1)
|
||||
if _ :
|
||||
chunk_size = chunk_size.rstrip(b" \t")
|
||||
|
||||
if any(n not in b"0123456789abcdefABCDEF" for n in chunk_size):
|
||||
raise InvalidHeader("Invalid chunk size")
|
||||
if len(chunk_size) == 0:
|
||||
raise InvalidHeader("Invalid chunk size")
|
||||
|
||||
chunk_size = int(chunk_size, 16)
|
||||
|
||||
if chunk_size == 0:
|
||||
# Final chunk - skip trailers and final CRLF
|
||||
await self._skip_trailers()
|
||||
return
|
||||
|
||||
# Read chunk data
|
||||
remaining = chunk_size
|
||||
while remaining > 0:
|
||||
data = await self.unreader.read(min(remaining, 8192))
|
||||
if not data:
|
||||
raise NoMoreData()
|
||||
remaining -= len(data)
|
||||
yield data
|
||||
|
||||
# Skip chunk terminating CRLF
|
||||
crlf = await self.unreader.read(2)
|
||||
if crlf != b"\r\n":
|
||||
# May have partial read, try to get the rest
|
||||
while len(crlf) < 2:
|
||||
more = await self.unreader.read(2 - len(crlf))
|
||||
if not more:
|
||||
break
|
||||
crlf += more
|
||||
if crlf != b"\r\n":
|
||||
raise InvalidHeader("Missing chunk terminator")
|
||||
|
||||
async def _read_chunk_size_line(self):
|
||||
"""Read a chunk size line."""
|
||||
buf = io.BytesIO()
|
||||
while True:
|
||||
data = await self.unreader.read(1)
|
||||
if not data:
|
||||
raise NoMoreData()
|
||||
buf.write(data)
|
||||
if buf.getvalue().endswith(b"\r\n"):
|
||||
return buf.getvalue()[:-2]
|
||||
|
||||
async def _skip_trailers(self):
|
||||
"""Skip trailer headers after chunked body."""
|
||||
buf = io.BytesIO()
|
||||
while True:
|
||||
data = await self.unreader.read(1)
|
||||
if not data:
|
||||
return
|
||||
buf.write(data)
|
||||
content = buf.getvalue()
|
||||
if content.endswith(b"\r\n\r\n"):
|
||||
# Could parse trailers here if needed
|
||||
return
|
||||
if content == b"\r\n":
|
||||
return
|
||||
|
||||
async def drain_body(self):
|
||||
"""Drain any unread body data.
|
||||
|
||||
Should be called before reusing connection for keepalive.
|
||||
"""
|
||||
while True:
|
||||
data = await self.read_body(8192)
|
||||
if not data:
|
||||
break
|
||||
424
gunicorn/asgi/protocol.py
Normal file
424
gunicorn/asgi/protocol.py
Normal file
@ -0,0 +1,424 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
ASGI protocol handler for gunicorn.
|
||||
|
||||
Implements asyncio.Protocol to handle HTTP/1.x connections and dispatch
|
||||
to ASGI applications.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
|
||||
from gunicorn.asgi.unreader import AsyncUnreader
|
||||
from gunicorn.asgi.message import AsyncRequest
|
||||
from gunicorn.http.errors import NoMoreData
|
||||
|
||||
|
||||
class ASGIProtocol(asyncio.Protocol):
|
||||
"""HTTP/1.1 protocol handler for ASGI applications.
|
||||
|
||||
Handles connection lifecycle, request parsing, and ASGI app invocation.
|
||||
"""
|
||||
|
||||
def __init__(self, worker):
|
||||
self.worker = worker
|
||||
self.cfg = worker.cfg
|
||||
self.log = worker.log
|
||||
self.app = worker.asgi
|
||||
|
||||
self.transport = None
|
||||
self.reader = None
|
||||
self.writer = None
|
||||
self._task = None
|
||||
self.req_count = 0
|
||||
|
||||
# Connection state
|
||||
self._closed = False
|
||||
|
||||
def connection_made(self, transport):
|
||||
"""Called when a connection is established."""
|
||||
self.transport = transport
|
||||
self.worker.nr_conns += 1
|
||||
|
||||
# Create stream reader/writer
|
||||
self.reader = asyncio.StreamReader()
|
||||
self.writer = transport
|
||||
|
||||
# Start handling requests
|
||||
self._task = self.worker.loop.create_task(self._handle_connection())
|
||||
|
||||
def data_received(self, data):
|
||||
"""Called when data is received on the connection."""
|
||||
if self.reader:
|
||||
self.reader.feed_data(data)
|
||||
|
||||
def connection_lost(self, exc):
|
||||
"""Called when the connection is lost or closed."""
|
||||
self._closed = True
|
||||
self.worker.nr_conns -= 1
|
||||
if self.reader:
|
||||
self.reader.feed_eof()
|
||||
if self._task and not self._task.done():
|
||||
self._task.cancel()
|
||||
|
||||
async def _handle_connection(self):
|
||||
"""Main request handling loop for this connection."""
|
||||
unreader = AsyncUnreader(self.reader)
|
||||
|
||||
try:
|
||||
peername = self.transport.get_extra_info('peername')
|
||||
sockname = self.transport.get_extra_info('sockname')
|
||||
|
||||
while not self._closed:
|
||||
self.req_count += 1
|
||||
|
||||
try:
|
||||
# Parse HTTP request
|
||||
request = await AsyncRequest.parse(
|
||||
self.cfg,
|
||||
unreader,
|
||||
peername,
|
||||
self.req_count
|
||||
)
|
||||
except StopIteration:
|
||||
# No more data, close connection
|
||||
break
|
||||
except NoMoreData:
|
||||
# Client disconnected
|
||||
break
|
||||
|
||||
# Check for WebSocket upgrade
|
||||
if self._is_websocket_upgrade(request):
|
||||
await self._handle_websocket(request, sockname, peername)
|
||||
break # WebSocket takes over the connection
|
||||
else:
|
||||
# Handle HTTP request
|
||||
keepalive = await self._handle_http_request(
|
||||
request, sockname, peername
|
||||
)
|
||||
|
||||
# Increment worker request count
|
||||
self.worker.nr += 1
|
||||
|
||||
# Check max_requests
|
||||
if self.worker.nr >= self.worker.max_requests:
|
||||
self.log.info("Autorestarting worker after current request.")
|
||||
self.worker.alive = False
|
||||
keepalive = False
|
||||
|
||||
if not keepalive or not self.worker.alive:
|
||||
break
|
||||
|
||||
# Check connection limits for keepalive
|
||||
if not self.cfg.keepalive:
|
||||
break
|
||||
|
||||
# Drain any unread body before next request
|
||||
await request.drain_body()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
self.log.exception("Error handling connection: %s", e)
|
||||
finally:
|
||||
self._close_transport()
|
||||
|
||||
def _is_websocket_upgrade(self, request):
|
||||
"""Check if request is a WebSocket upgrade."""
|
||||
upgrade = None
|
||||
connection = None
|
||||
for name, value in request.headers:
|
||||
if name == "UPGRADE":
|
||||
upgrade = value.lower()
|
||||
elif name == "CONNECTION":
|
||||
connection = value.lower()
|
||||
return upgrade == "websocket" and connection and "upgrade" in connection
|
||||
|
||||
async def _handle_websocket(self, request, sockname, peername):
|
||||
"""Handle WebSocket upgrade request."""
|
||||
from gunicorn.asgi.websocket import WebSocketProtocol
|
||||
|
||||
scope = self._build_websocket_scope(request, sockname, peername)
|
||||
ws_protocol = WebSocketProtocol(
|
||||
self.transport, self.reader, scope, self.app, self.log
|
||||
)
|
||||
await ws_protocol.run()
|
||||
|
||||
async def _handle_http_request(self, request, sockname, peername):
|
||||
"""Handle a single HTTP request."""
|
||||
scope = self._build_http_scope(request, sockname, peername)
|
||||
response_started = False
|
||||
response_complete = False
|
||||
body_parts = []
|
||||
exc_to_raise = None
|
||||
|
||||
# Receive queue for body
|
||||
receive_queue = asyncio.Queue()
|
||||
|
||||
# Pre-populate with initial body state
|
||||
if request.content_length == 0 and not request.chunked:
|
||||
await receive_queue.put({
|
||||
"type": "http.request",
|
||||
"body": b"",
|
||||
"more_body": False,
|
||||
})
|
||||
else:
|
||||
# Start body reading task
|
||||
asyncio.create_task(self._read_body_to_queue(request, receive_queue))
|
||||
|
||||
async def receive():
|
||||
return await receive_queue.get()
|
||||
|
||||
async def send(message):
|
||||
nonlocal response_started, response_complete, exc_to_raise
|
||||
|
||||
msg_type = message["type"]
|
||||
|
||||
if msg_type == "http.response.start":
|
||||
if response_started:
|
||||
exc_to_raise = RuntimeError("Response already started")
|
||||
return
|
||||
response_started = True
|
||||
status = message["status"]
|
||||
headers = message.get("headers", [])
|
||||
await self._send_response_start(status, headers, request)
|
||||
|
||||
elif msg_type == "http.response.body":
|
||||
if not response_started:
|
||||
exc_to_raise = RuntimeError("Response not started")
|
||||
return
|
||||
if response_complete:
|
||||
exc_to_raise = RuntimeError("Response already complete")
|
||||
return
|
||||
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
|
||||
if body:
|
||||
await self._send_body(body)
|
||||
|
||||
if not more_body:
|
||||
response_complete = True
|
||||
|
||||
try:
|
||||
request_start = datetime.now()
|
||||
self.cfg.pre_request(self.worker, request)
|
||||
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
if exc_to_raise:
|
||||
raise exc_to_raise
|
||||
|
||||
# Ensure response was sent
|
||||
if not response_started:
|
||||
await self._send_error_response(500, "Internal Server Error")
|
||||
|
||||
except Exception as e:
|
||||
self.log.exception("Error in ASGI application")
|
||||
if not response_started:
|
||||
await self._send_error_response(500, "Internal Server Error")
|
||||
return False
|
||||
finally:
|
||||
try:
|
||||
request_time = datetime.now() - request_start
|
||||
self.cfg.post_request(self.worker, request, {}, None)
|
||||
except Exception:
|
||||
self.log.exception("Exception in post_request hook")
|
||||
|
||||
# Determine keepalive
|
||||
if request.should_close():
|
||||
return False
|
||||
|
||||
return self.worker.alive and self.cfg.keepalive
|
||||
|
||||
async def _read_body_to_queue(self, request, queue):
|
||||
"""Read request body and put chunks on the queue."""
|
||||
try:
|
||||
while True:
|
||||
chunk = await request.read_body(65536)
|
||||
if chunk:
|
||||
await queue.put({
|
||||
"type": "http.request",
|
||||
"body": chunk,
|
||||
"more_body": True,
|
||||
})
|
||||
else:
|
||||
await queue.put({
|
||||
"type": "http.request",
|
||||
"body": b"",
|
||||
"more_body": False,
|
||||
})
|
||||
break
|
||||
except Exception as e:
|
||||
self.log.debug("Error reading body: %s", e)
|
||||
await queue.put({
|
||||
"type": "http.request",
|
||||
"body": b"",
|
||||
"more_body": False,
|
||||
})
|
||||
|
||||
def _build_http_scope(self, request, sockname, peername):
|
||||
"""Build ASGI HTTP scope from parsed request."""
|
||||
# Build headers list as bytes tuples
|
||||
headers = []
|
||||
for name, value in request.headers:
|
||||
headers.append((name.lower().encode("latin-1"), value.encode("latin-1")))
|
||||
|
||||
scope = {
|
||||
"type": "http",
|
||||
"asgi": {"version": "3.0", "spec_version": "2.4"},
|
||||
"http_version": f"{request.version[0]}.{request.version[1]}",
|
||||
"method": request.method,
|
||||
"scheme": request.scheme,
|
||||
"path": request.path,
|
||||
"raw_path": request.path.encode("latin-1") if request.path else b"",
|
||||
"query_string": request.query.encode("latin-1") if request.query else b"",
|
||||
"root_path": self.cfg.root_path or "",
|
||||
"headers": headers,
|
||||
"server": sockname if sockname else None,
|
||||
"client": peername if peername else None,
|
||||
}
|
||||
|
||||
# Add state dict for lifespan sharing
|
||||
if hasattr(self.worker, 'state'):
|
||||
scope["state"] = self.worker.state
|
||||
|
||||
return scope
|
||||
|
||||
def _build_websocket_scope(self, request, sockname, peername):
|
||||
"""Build ASGI WebSocket scope from parsed request."""
|
||||
# Build headers list as bytes tuples
|
||||
headers = []
|
||||
for name, value in request.headers:
|
||||
headers.append((name.lower().encode("latin-1"), value.encode("latin-1")))
|
||||
|
||||
# Extract subprotocols from Sec-WebSocket-Protocol header
|
||||
subprotocols = []
|
||||
for name, value in request.headers:
|
||||
if name == "SEC-WEBSOCKET-PROTOCOL":
|
||||
subprotocols = [s.strip() for s in value.split(",")]
|
||||
break
|
||||
|
||||
scope = {
|
||||
"type": "websocket",
|
||||
"asgi": {"version": "3.0", "spec_version": "2.4"},
|
||||
"http_version": f"{request.version[0]}.{request.version[1]}",
|
||||
"scheme": "wss" if request.scheme == "https" else "ws",
|
||||
"path": request.path,
|
||||
"raw_path": request.path.encode("latin-1") if request.path else b"",
|
||||
"query_string": request.query.encode("latin-1") if request.query else b"",
|
||||
"root_path": self.cfg.root_path or "",
|
||||
"headers": headers,
|
||||
"server": sockname if sockname else None,
|
||||
"client": peername if peername else None,
|
||||
"subprotocols": subprotocols,
|
||||
}
|
||||
|
||||
# Add state dict for lifespan sharing
|
||||
if hasattr(self.worker, 'state'):
|
||||
scope["state"] = self.worker.state
|
||||
|
||||
return scope
|
||||
|
||||
async def _send_response_start(self, status, headers, request):
|
||||
"""Send HTTP response status and headers."""
|
||||
# Build status line
|
||||
reason = self._get_reason_phrase(status)
|
||||
status_line = f"HTTP/{request.version[0]}.{request.version[1]} {status} {reason}\r\n"
|
||||
|
||||
# Build headers
|
||||
header_lines = []
|
||||
has_content_length = False
|
||||
has_transfer_encoding = False
|
||||
has_connection = False
|
||||
|
||||
for name, value in headers:
|
||||
if isinstance(name, bytes):
|
||||
name = name.decode("latin-1")
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode("latin-1")
|
||||
header_lines.append(f"{name}: {value}\r\n")
|
||||
name_lower = name.lower()
|
||||
if name_lower == "content-length":
|
||||
has_content_length = True
|
||||
elif name_lower == "transfer-encoding":
|
||||
has_transfer_encoding = True
|
||||
elif name_lower == "connection":
|
||||
has_connection = True
|
||||
|
||||
# Add server header if not present
|
||||
header_lines.append("Server: gunicorn/asgi\r\n")
|
||||
|
||||
response = status_line + "".join(header_lines) + "\r\n"
|
||||
self.transport.write(response.encode("latin-1"))
|
||||
|
||||
async def _send_body(self, body):
|
||||
"""Send response body chunk."""
|
||||
if body:
|
||||
self.transport.write(body)
|
||||
|
||||
async def _send_error_response(self, status, message):
|
||||
"""Send an error response."""
|
||||
body = message.encode("utf-8")
|
||||
response = (
|
||||
f"HTTP/1.1 {status} {message}\r\n"
|
||||
f"Content-Type: text/plain\r\n"
|
||||
f"Content-Length: {len(body)}\r\n"
|
||||
f"Connection: close\r\n"
|
||||
f"\r\n"
|
||||
)
|
||||
self.transport.write(response.encode("latin-1"))
|
||||
self.transport.write(body)
|
||||
|
||||
def _get_reason_phrase(self, status):
|
||||
"""Get HTTP reason phrase for status code."""
|
||||
reasons = {
|
||||
100: "Continue",
|
||||
101: "Switching Protocols",
|
||||
200: "OK",
|
||||
201: "Created",
|
||||
202: "Accepted",
|
||||
204: "No Content",
|
||||
206: "Partial Content",
|
||||
301: "Moved Permanently",
|
||||
302: "Found",
|
||||
303: "See Other",
|
||||
304: "Not Modified",
|
||||
307: "Temporary Redirect",
|
||||
308: "Permanent Redirect",
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
404: "Not Found",
|
||||
405: "Method Not Allowed",
|
||||
408: "Request Timeout",
|
||||
409: "Conflict",
|
||||
410: "Gone",
|
||||
411: "Length Required",
|
||||
413: "Payload Too Large",
|
||||
414: "URI Too Long",
|
||||
415: "Unsupported Media Type",
|
||||
422: "Unprocessable Entity",
|
||||
429: "Too Many Requests",
|
||||
500: "Internal Server Error",
|
||||
501: "Not Implemented",
|
||||
502: "Bad Gateway",
|
||||
503: "Service Unavailable",
|
||||
504: "Gateway Timeout",
|
||||
}
|
||||
return reasons.get(status, "Unknown")
|
||||
|
||||
def _close_transport(self):
|
||||
"""Close the transport safely."""
|
||||
if self.transport and not self._closed:
|
||||
try:
|
||||
self.transport.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._closed = True
|
||||
100
gunicorn/asgi/unreader.py
Normal file
100
gunicorn/asgi/unreader.py
Normal file
@ -0,0 +1,100 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
Async version of gunicorn/http/unreader.py for ASGI workers.
|
||||
|
||||
Provides async reading with pushback buffer support.
|
||||
"""
|
||||
|
||||
import io
|
||||
|
||||
|
||||
class AsyncUnreader:
|
||||
"""Async socket reader with pushback buffer support.
|
||||
|
||||
This class wraps an asyncio StreamReader and provides the ability
|
||||
to "unread" data back into a buffer for re-parsing.
|
||||
"""
|
||||
|
||||
def __init__(self, reader, max_chunk=8192):
|
||||
"""Initialize the async unreader.
|
||||
|
||||
Args:
|
||||
reader: asyncio.StreamReader instance
|
||||
max_chunk: Maximum bytes to read at once
|
||||
"""
|
||||
self.reader = reader
|
||||
self.buf = io.BytesIO()
|
||||
self.max_chunk = max_chunk
|
||||
|
||||
async def read(self, size=None):
|
||||
"""Read data from the stream, using buffered data first.
|
||||
|
||||
Args:
|
||||
size: Number of bytes to read. If None, returns all buffered
|
||||
data or reads a single chunk.
|
||||
|
||||
Returns:
|
||||
bytes: Data read from buffer or stream
|
||||
"""
|
||||
if size is not None and not isinstance(size, int):
|
||||
raise TypeError("size parameter must be an int or long.")
|
||||
|
||||
if size is not None:
|
||||
if size == 0:
|
||||
return b""
|
||||
if size < 0:
|
||||
size = None
|
||||
|
||||
# Move to end to check buffer size
|
||||
self.buf.seek(0, io.SEEK_END)
|
||||
|
||||
# If no size specified, return buffered data or read chunk
|
||||
if size is None and self.buf.tell():
|
||||
ret = self.buf.getvalue()
|
||||
self.buf = io.BytesIO()
|
||||
return ret
|
||||
if size is None:
|
||||
chunk = await self._read_chunk()
|
||||
return chunk
|
||||
|
||||
# Read until we have enough data
|
||||
while self.buf.tell() < size:
|
||||
chunk = await self._read_chunk()
|
||||
if not chunk:
|
||||
ret = self.buf.getvalue()
|
||||
self.buf = io.BytesIO()
|
||||
return ret
|
||||
self.buf.write(chunk)
|
||||
|
||||
data = self.buf.getvalue()
|
||||
self.buf = io.BytesIO()
|
||||
self.buf.write(data[size:])
|
||||
return data[:size]
|
||||
|
||||
async def _read_chunk(self):
|
||||
"""Read a chunk of data from the underlying stream."""
|
||||
try:
|
||||
return await self.reader.read(self.max_chunk)
|
||||
except Exception:
|
||||
return b""
|
||||
|
||||
def unread(self, data):
|
||||
"""Push data back into the buffer for re-reading.
|
||||
|
||||
Args:
|
||||
data: bytes to push back
|
||||
"""
|
||||
if data:
|
||||
self.buf.seek(0, io.SEEK_END)
|
||||
self.buf.write(data)
|
||||
|
||||
def has_buffered_data(self):
|
||||
"""Check if there's data in the pushback buffer."""
|
||||
pos = self.buf.tell()
|
||||
self.buf.seek(0, io.SEEK_END)
|
||||
has_data = self.buf.tell() > 0
|
||||
self.buf.seek(pos)
|
||||
return has_data
|
||||
369
gunicorn/asgi/websocket.py
Normal file
369
gunicorn/asgi/websocket.py
Normal file
@ -0,0 +1,369 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
WebSocket protocol handler for ASGI.
|
||||
|
||||
Implements RFC 6455 WebSocket protocol for ASGI applications.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import struct
|
||||
import os
|
||||
|
||||
|
||||
# WebSocket frame opcodes
|
||||
OPCODE_CONTINUATION = 0x0
|
||||
OPCODE_TEXT = 0x1
|
||||
OPCODE_BINARY = 0x2
|
||||
OPCODE_CLOSE = 0x8
|
||||
OPCODE_PING = 0x9
|
||||
OPCODE_PONG = 0xA
|
||||
|
||||
# WebSocket close codes
|
||||
CLOSE_NORMAL = 1000
|
||||
CLOSE_GOING_AWAY = 1001
|
||||
CLOSE_PROTOCOL_ERROR = 1002
|
||||
CLOSE_UNSUPPORTED = 1003
|
||||
CLOSE_NO_STATUS = 1005
|
||||
CLOSE_ABNORMAL = 1006
|
||||
CLOSE_INVALID_DATA = 1007
|
||||
CLOSE_POLICY_VIOLATION = 1008
|
||||
CLOSE_MESSAGE_TOO_BIG = 1009
|
||||
CLOSE_MANDATORY_EXT = 1010
|
||||
CLOSE_INTERNAL_ERROR = 1011
|
||||
|
||||
# WebSocket handshake GUID (RFC 6455)
|
||||
WS_GUID = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
||||
|
||||
|
||||
class WebSocketProtocol:
|
||||
"""WebSocket connection handler for ASGI applications."""
|
||||
|
||||
def __init__(self, transport, reader, scope, app, log):
|
||||
"""Initialize WebSocket protocol handler.
|
||||
|
||||
Args:
|
||||
transport: asyncio transport for writing
|
||||
reader: asyncio StreamReader for reading
|
||||
scope: ASGI WebSocket scope dict
|
||||
app: ASGI application callable
|
||||
log: Logger instance
|
||||
"""
|
||||
self.transport = transport
|
||||
self.reader = reader
|
||||
self.scope = scope
|
||||
self.app = app
|
||||
self.log = log
|
||||
|
||||
self.accepted = False
|
||||
self.closed = False
|
||||
self.close_code = None
|
||||
self.close_reason = ""
|
||||
|
||||
# Message reassembly state
|
||||
self._fragments = []
|
||||
self._fragment_opcode = None
|
||||
|
||||
# Receive queue for incoming messages
|
||||
self._receive_queue = asyncio.Queue()
|
||||
|
||||
async def run(self):
|
||||
"""Run the WebSocket ASGI application."""
|
||||
# Send initial connect event
|
||||
await self._receive_queue.put({"type": "websocket.connect"})
|
||||
|
||||
# Start frame reading task
|
||||
read_task = asyncio.create_task(self._read_frames())
|
||||
|
||||
try:
|
||||
await self.app(self.scope, self._receive, self._send)
|
||||
except Exception as e:
|
||||
self.log.exception("Error in WebSocket ASGI application")
|
||||
finally:
|
||||
read_task.cancel()
|
||||
try:
|
||||
await read_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Send close frame if not already closed
|
||||
if not self.closed and self.accepted:
|
||||
await self._send_close(CLOSE_INTERNAL_ERROR, "Application error")
|
||||
|
||||
async def _receive(self):
|
||||
"""ASGI receive callable."""
|
||||
return await self._receive_queue.get()
|
||||
|
||||
async def _send(self, message):
|
||||
"""ASGI send callable."""
|
||||
msg_type = message["type"]
|
||||
|
||||
if msg_type == "websocket.accept":
|
||||
if self.accepted:
|
||||
raise RuntimeError("WebSocket already accepted")
|
||||
await self._send_accept(message)
|
||||
self.accepted = True
|
||||
|
||||
elif msg_type == "websocket.send":
|
||||
if not self.accepted:
|
||||
raise RuntimeError("WebSocket not accepted")
|
||||
if self.closed:
|
||||
raise RuntimeError("WebSocket closed")
|
||||
|
||||
if "text" in message:
|
||||
await self._send_frame(OPCODE_TEXT, message["text"].encode("utf-8"))
|
||||
elif "bytes" in message:
|
||||
await self._send_frame(OPCODE_BINARY, message["bytes"])
|
||||
|
||||
elif msg_type == "websocket.close":
|
||||
code = message.get("code", CLOSE_NORMAL)
|
||||
reason = message.get("reason", "")
|
||||
await self._send_close(code, reason)
|
||||
self.closed = True
|
||||
|
||||
async def _send_accept(self, message):
|
||||
"""Send WebSocket handshake accept response."""
|
||||
# Get Sec-WebSocket-Key from headers
|
||||
ws_key = None
|
||||
for name, value in self.scope["headers"]:
|
||||
if name == b"sec-websocket-key":
|
||||
ws_key = value
|
||||
break
|
||||
|
||||
if not ws_key:
|
||||
raise RuntimeError("Missing Sec-WebSocket-Key header")
|
||||
|
||||
# Calculate accept key
|
||||
accept_key = base64.b64encode(
|
||||
hashlib.sha1(ws_key + WS_GUID).digest()
|
||||
).decode("ascii")
|
||||
|
||||
# Build response headers
|
||||
headers = [
|
||||
"HTTP/1.1 101 Switching Protocols\r\n",
|
||||
"Upgrade: websocket\r\n",
|
||||
"Connection: Upgrade\r\n",
|
||||
f"Sec-WebSocket-Accept: {accept_key}\r\n",
|
||||
]
|
||||
|
||||
# Add selected subprotocol if specified
|
||||
subprotocol = message.get("subprotocol")
|
||||
if subprotocol:
|
||||
headers.append(f"Sec-WebSocket-Protocol: {subprotocol}\r\n")
|
||||
|
||||
# Add any extra headers from message
|
||||
extra_headers = message.get("headers", [])
|
||||
for name, value in extra_headers:
|
||||
if isinstance(name, bytes):
|
||||
name = name.decode("latin-1")
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode("latin-1")
|
||||
headers.append(f"{name}: {value}\r\n")
|
||||
|
||||
headers.append("\r\n")
|
||||
self.transport.write("".join(headers).encode("latin-1"))
|
||||
|
||||
async def _read_frames(self):
|
||||
"""Read and process incoming WebSocket frames."""
|
||||
try:
|
||||
while not self.closed:
|
||||
frame = await self._read_frame()
|
||||
if frame is None:
|
||||
break
|
||||
|
||||
opcode, payload = frame
|
||||
|
||||
if opcode == OPCODE_CLOSE:
|
||||
await self._handle_close(payload)
|
||||
break
|
||||
elif opcode == OPCODE_PING:
|
||||
await self._send_frame(OPCODE_PONG, payload)
|
||||
elif opcode == OPCODE_PONG:
|
||||
# Ignore pongs
|
||||
pass
|
||||
elif opcode == OPCODE_TEXT:
|
||||
await self._receive_queue.put({
|
||||
"type": "websocket.receive",
|
||||
"text": payload.decode("utf-8"),
|
||||
})
|
||||
elif opcode == OPCODE_BINARY:
|
||||
await self._receive_queue.put({
|
||||
"type": "websocket.receive",
|
||||
"bytes": payload,
|
||||
})
|
||||
elif opcode == OPCODE_CONTINUATION:
|
||||
# Handle fragmented messages
|
||||
await self._handle_continuation(payload)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
self.log.debug("WebSocket read error: %s", e)
|
||||
finally:
|
||||
# Signal disconnect
|
||||
if not self.closed:
|
||||
self.closed = True
|
||||
await self._receive_queue.put({
|
||||
"type": "websocket.disconnect",
|
||||
"code": self.close_code or CLOSE_ABNORMAL,
|
||||
})
|
||||
|
||||
async def _read_frame(self):
|
||||
"""Read a single WebSocket frame.
|
||||
|
||||
Returns:
|
||||
tuple: (opcode, payload) or None if connection closed
|
||||
"""
|
||||
# Read frame header (2 bytes minimum)
|
||||
header = await self._read_exact(2)
|
||||
if not header:
|
||||
return None
|
||||
|
||||
first_byte, second_byte = header[0], header[1]
|
||||
|
||||
fin = (first_byte >> 7) & 1
|
||||
rsv1 = (first_byte >> 6) & 1
|
||||
rsv2 = (first_byte >> 5) & 1
|
||||
rsv3 = (first_byte >> 4) & 1
|
||||
opcode = first_byte & 0x0F
|
||||
|
||||
# RSV bits must be 0 (no extensions)
|
||||
if rsv1 or rsv2 or rsv3:
|
||||
await self._send_close(CLOSE_PROTOCOL_ERROR, "RSV bits set")
|
||||
return None
|
||||
|
||||
masked = (second_byte >> 7) & 1
|
||||
payload_len = second_byte & 0x7F
|
||||
|
||||
# Client frames must be masked (RFC 6455)
|
||||
if not masked:
|
||||
await self._send_close(CLOSE_PROTOCOL_ERROR, "Frame not masked")
|
||||
return None
|
||||
|
||||
# Extended payload length
|
||||
if payload_len == 126:
|
||||
ext_len = await self._read_exact(2)
|
||||
if not ext_len:
|
||||
return None
|
||||
payload_len = struct.unpack("!H", ext_len)[0]
|
||||
elif payload_len == 127:
|
||||
ext_len = await self._read_exact(8)
|
||||
if not ext_len:
|
||||
return None
|
||||
payload_len = struct.unpack("!Q", ext_len)[0]
|
||||
|
||||
# Read masking key
|
||||
masking_key = await self._read_exact(4)
|
||||
if not masking_key:
|
||||
return None
|
||||
|
||||
# Read payload
|
||||
payload = await self._read_exact(payload_len)
|
||||
if payload is None:
|
||||
return None
|
||||
|
||||
# Unmask payload
|
||||
payload = self._unmask(payload, masking_key)
|
||||
|
||||
# Handle fragmented messages
|
||||
if opcode == OPCODE_CONTINUATION:
|
||||
if self._fragment_opcode is None:
|
||||
await self._send_close(CLOSE_PROTOCOL_ERROR, "Unexpected continuation")
|
||||
return None
|
||||
self._fragments.append(payload)
|
||||
if fin:
|
||||
# Reassemble complete message
|
||||
full_payload = b"".join(self._fragments)
|
||||
final_opcode = self._fragment_opcode
|
||||
self._fragments = []
|
||||
self._fragment_opcode = None
|
||||
return (final_opcode, full_payload)
|
||||
return (OPCODE_CONTINUATION, b"") # Fragment received, wait for more
|
||||
elif opcode in (OPCODE_TEXT, OPCODE_BINARY):
|
||||
if not fin:
|
||||
# Start of fragmented message
|
||||
self._fragment_opcode = opcode
|
||||
self._fragments = [payload]
|
||||
return (OPCODE_CONTINUATION, b"") # Fragment started, wait for more
|
||||
return (opcode, payload)
|
||||
else:
|
||||
# Control frames
|
||||
return (opcode, payload)
|
||||
|
||||
async def _read_exact(self, n):
|
||||
"""Read exactly n bytes from the reader."""
|
||||
try:
|
||||
data = await self.reader.readexactly(n)
|
||||
return data
|
||||
except asyncio.IncompleteReadError:
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _unmask(self, payload, masking_key):
|
||||
"""Unmask WebSocket payload data."""
|
||||
if not payload:
|
||||
return payload
|
||||
# XOR each byte with corresponding mask byte
|
||||
return bytes(b ^ masking_key[i % 4] for i, b in enumerate(payload))
|
||||
|
||||
async def _handle_close(self, payload):
|
||||
"""Handle incoming close frame."""
|
||||
if len(payload) >= 2:
|
||||
self.close_code = struct.unpack("!H", payload[:2])[0]
|
||||
self.close_reason = payload[2:].decode("utf-8", errors="replace")
|
||||
else:
|
||||
self.close_code = CLOSE_NO_STATUS
|
||||
self.close_reason = ""
|
||||
|
||||
# Echo close frame back if we haven't already sent one
|
||||
if not self.closed:
|
||||
await self._send_close(self.close_code, self.close_reason)
|
||||
|
||||
self.closed = True
|
||||
|
||||
async def _handle_continuation(self, payload):
|
||||
"""Handle continuation frame (already processed in _read_frame)."""
|
||||
# This is called for partial fragments, nothing to do
|
||||
pass
|
||||
|
||||
async def _send_frame(self, opcode, payload):
|
||||
"""Send a WebSocket frame.
|
||||
|
||||
Server frames are not masked (RFC 6455).
|
||||
"""
|
||||
if isinstance(payload, str):
|
||||
payload = payload.encode("utf-8")
|
||||
|
||||
length = len(payload)
|
||||
frame = bytearray()
|
||||
|
||||
# First byte: FIN + opcode
|
||||
frame.append(0x80 | opcode)
|
||||
|
||||
# Second byte: length (no mask bit for server)
|
||||
if length < 126:
|
||||
frame.append(length)
|
||||
elif length < 65536:
|
||||
frame.append(126)
|
||||
frame.extend(struct.pack("!H", length))
|
||||
else:
|
||||
frame.append(127)
|
||||
frame.extend(struct.pack("!Q", length))
|
||||
|
||||
# Payload
|
||||
frame.extend(payload)
|
||||
|
||||
self.transport.write(bytes(frame))
|
||||
|
||||
async def _send_close(self, code, reason=""):
|
||||
"""Send a close frame."""
|
||||
payload = struct.pack("!H", code)
|
||||
if reason:
|
||||
payload += reason.encode("utf-8")[:123] # Max 125 bytes total
|
||||
await self._send_frame(OPCODE_CLOSE, payload)
|
||||
self.closed = True
|
||||
@ -2440,3 +2440,94 @@ class HeaderMap(Setting):
|
||||
|
||||
.. versionadded:: 22.0.0
|
||||
"""
|
||||
|
||||
|
||||
def validate_asgi_loop(val):
|
||||
if val is None:
|
||||
return "auto"
|
||||
if not isinstance(val, str):
|
||||
raise TypeError("Invalid type for casting: %s" % val)
|
||||
val = val.lower().strip()
|
||||
if val not in ("auto", "asyncio", "uvloop"):
|
||||
raise ValueError("Invalid ASGI loop: %s" % val)
|
||||
return val
|
||||
|
||||
|
||||
def validate_asgi_lifespan(val):
|
||||
if val is None:
|
||||
return "auto"
|
||||
if not isinstance(val, str):
|
||||
raise TypeError("Invalid type for casting: %s" % val)
|
||||
val = val.lower().strip()
|
||||
if val not in ("auto", "on", "off"):
|
||||
raise ValueError("Invalid ASGI lifespan: %s" % val)
|
||||
return val
|
||||
|
||||
|
||||
class ASGILoop(Setting):
|
||||
name = "asgi_loop"
|
||||
section = "Worker Processes"
|
||||
cli = ["--asgi-loop"]
|
||||
meta = "STRING"
|
||||
validator = validate_asgi_loop
|
||||
default = "auto"
|
||||
desc = """\
|
||||
Event loop implementation for ASGI workers.
|
||||
|
||||
- auto: Use uvloop if available, otherwise asyncio
|
||||
- asyncio: Use Python's built-in asyncio event loop
|
||||
- uvloop: Use uvloop (must be installed separately)
|
||||
|
||||
This setting only affects the ``asgi`` worker type.
|
||||
|
||||
uvloop typically provides better performance but requires
|
||||
installing the uvloop package.
|
||||
|
||||
.. versionadded:: 24.0.0
|
||||
"""
|
||||
|
||||
|
||||
class ASGILifespan(Setting):
|
||||
name = "asgi_lifespan"
|
||||
section = "Worker Processes"
|
||||
cli = ["--asgi-lifespan"]
|
||||
meta = "STRING"
|
||||
validator = validate_asgi_lifespan
|
||||
default = "auto"
|
||||
desc = """\
|
||||
Control ASGI lifespan protocol handling.
|
||||
|
||||
- auto: Detect if app supports lifespan, enable if so
|
||||
- on: Always run lifespan protocol (fail if unsupported)
|
||||
- off: Never run lifespan protocol
|
||||
|
||||
The lifespan protocol allows ASGI applications to run code at
|
||||
startup and shutdown. This is essential for frameworks like
|
||||
FastAPI that need to initialize database connections, caches,
|
||||
or other resources.
|
||||
|
||||
This setting only affects the ``asgi`` worker type.
|
||||
|
||||
.. versionadded:: 24.0.0
|
||||
"""
|
||||
|
||||
|
||||
class RootPath(Setting):
|
||||
name = "root_path"
|
||||
section = "Server Mechanics"
|
||||
cli = ["--root-path"]
|
||||
meta = "STRING"
|
||||
validator = validate_string
|
||||
default = ""
|
||||
desc = """\
|
||||
The root path for ASGI applications.
|
||||
|
||||
This is used to set the ``root_path`` in the ASGI scope, which
|
||||
allows applications to know their mount point when behind a
|
||||
reverse proxy.
|
||||
|
||||
For example, if your application is mounted at ``/api``, set
|
||||
this to ``/api``.
|
||||
|
||||
.. versionadded:: 24.0.0
|
||||
"""
|
||||
|
||||
@ -11,4 +11,5 @@ SUPPORTED_WORKERS = {
|
||||
"gevent_pywsgi": "gunicorn.workers.ggevent.GeventPyWSGIWorker",
|
||||
"tornado": "gunicorn.workers.gtornado.TornadoWorker",
|
||||
"gthread": "gunicorn.workers.gthread.ThreadWorker",
|
||||
"asgi": "gunicorn.workers.gasgi.ASGIWorker",
|
||||
}
|
||||
|
||||
282
gunicorn/workers/gasgi.py
Normal file
282
gunicorn/workers/gasgi.py
Normal file
@ -0,0 +1,282 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
ASGI worker for gunicorn.
|
||||
|
||||
Provides native asyncio-based ASGI support using gunicorn's own
|
||||
HTTP parsing infrastructure.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import signal
|
||||
import ssl
|
||||
import sys
|
||||
|
||||
from gunicorn.workers import base
|
||||
from gunicorn.asgi.protocol import ASGIProtocol
|
||||
|
||||
|
||||
class ASGIWorker(base.Worker):
|
||||
"""ASGI worker using asyncio event loop.
|
||||
|
||||
Supports:
|
||||
- HTTP/1.1 with keepalive
|
||||
- WebSocket connections
|
||||
- Lifespan protocol (startup/shutdown hooks)
|
||||
- Optional uvloop for improved performance
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.worker_connections = self.cfg.worker_connections
|
||||
self.loop = None
|
||||
self.servers = []
|
||||
self.nr_conns = 0
|
||||
self.lifespan = None
|
||||
self.state = {} # Shared state for lifespan
|
||||
|
||||
@classmethod
|
||||
def check_config(cls, cfg, log):
|
||||
"""Validate configuration for ASGI worker."""
|
||||
if cfg.threads > 1:
|
||||
log.warning("ASGI worker does not use threads configuration. "
|
||||
"Use worker_connections instead.")
|
||||
|
||||
def init_process(self):
|
||||
"""Initialize the worker process."""
|
||||
# Setup event loop before calling super()
|
||||
self._setup_event_loop()
|
||||
super().init_process()
|
||||
|
||||
def _setup_event_loop(self):
|
||||
"""Setup the asyncio event loop."""
|
||||
loop_type = getattr(self.cfg, 'asgi_loop', 'auto')
|
||||
|
||||
if loop_type == "auto":
|
||||
try:
|
||||
import uvloop
|
||||
loop_type = "uvloop"
|
||||
except ImportError:
|
||||
loop_type = "asyncio"
|
||||
|
||||
if loop_type == "uvloop":
|
||||
try:
|
||||
import uvloop
|
||||
self.loop = uvloop.new_event_loop()
|
||||
self.log.debug("Using uvloop event loop")
|
||||
except ImportError:
|
||||
self.log.warning("uvloop not available, falling back to asyncio")
|
||||
self.loop = asyncio.new_event_loop()
|
||||
else:
|
||||
self.loop = asyncio.new_event_loop()
|
||||
self.log.debug("Using asyncio event loop")
|
||||
|
||||
asyncio.set_event_loop(self.loop)
|
||||
|
||||
def load_wsgi(self):
|
||||
"""Load the ASGI application."""
|
||||
try:
|
||||
self.asgi = self.app.wsgi()
|
||||
except SyntaxError as e:
|
||||
if not self.cfg.reload:
|
||||
raise
|
||||
self.log.exception(e)
|
||||
self.asgi = self._make_error_app(str(e))
|
||||
|
||||
def _make_error_app(self, error_msg):
|
||||
"""Create an error ASGI app for syntax errors during reload."""
|
||||
async def error_app(scope, receive, send):
|
||||
if scope["type"] == "http":
|
||||
await send({
|
||||
"type": "http.response.start",
|
||||
"status": 500,
|
||||
"headers": [(b"content-type", b"text/plain")],
|
||||
})
|
||||
await send({
|
||||
"type": "http.response.body",
|
||||
"body": f"Application error: {error_msg}".encode(),
|
||||
})
|
||||
elif scope["type"] == "lifespan":
|
||||
message = await receive()
|
||||
if message["type"] == "lifespan.startup":
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
message = await receive()
|
||||
if message["type"] == "lifespan.shutdown":
|
||||
await send({"type": "lifespan.shutdown.complete"})
|
||||
return error_app
|
||||
|
||||
def init_signals(self):
|
||||
"""Initialize signal handlers for asyncio."""
|
||||
# Reset all signals first
|
||||
for s in self.SIGNALS:
|
||||
signal.signal(s, signal.SIG_DFL)
|
||||
|
||||
# Set up signal handlers via the event loop
|
||||
self.loop.add_signal_handler(signal.SIGQUIT, self.handle_quit_signal)
|
||||
self.loop.add_signal_handler(signal.SIGTERM, self.handle_exit_signal)
|
||||
self.loop.add_signal_handler(signal.SIGINT, self.handle_quit_signal)
|
||||
self.loop.add_signal_handler(signal.SIGUSR1, self.handle_usr1_signal)
|
||||
self.loop.add_signal_handler(signal.SIGWINCH, self.handle_winch_signal)
|
||||
self.loop.add_signal_handler(signal.SIGABRT, self.handle_abort_signal)
|
||||
|
||||
def handle_quit_signal(self):
|
||||
"""Handle SIGQUIT - immediate shutdown."""
|
||||
self.alive = False
|
||||
self.cfg.worker_int(self)
|
||||
|
||||
def handle_exit_signal(self):
|
||||
"""Handle SIGTERM - graceful shutdown."""
|
||||
self.alive = False
|
||||
|
||||
def handle_usr1_signal(self):
|
||||
"""Handle SIGUSR1 - reopen log files."""
|
||||
self.log.reopen_files()
|
||||
|
||||
def handle_winch_signal(self):
|
||||
"""Handle SIGWINCH - ignored in worker."""
|
||||
self.log.debug("worker: SIGWINCH ignored.")
|
||||
|
||||
def handle_abort_signal(self):
|
||||
"""Handle SIGABRT - abort."""
|
||||
self.alive = False
|
||||
self.cfg.worker_abort(self)
|
||||
sys.exit(1)
|
||||
|
||||
def run(self):
|
||||
"""Main entry point for the worker."""
|
||||
try:
|
||||
self.loop.run_until_complete(self._serve())
|
||||
except Exception as e:
|
||||
self.log.exception("Worker exception: %s", e)
|
||||
finally:
|
||||
self._cleanup()
|
||||
|
||||
async def _serve(self):
|
||||
"""Main async serving loop."""
|
||||
# Run lifespan startup
|
||||
lifespan_mode = getattr(self.cfg, 'asgi_lifespan', 'auto')
|
||||
if lifespan_mode != "off":
|
||||
from gunicorn.asgi.lifespan import LifespanManager
|
||||
self.lifespan = LifespanManager(self.asgi, self.log, self.state)
|
||||
try:
|
||||
await self.lifespan.startup()
|
||||
except Exception as e:
|
||||
if lifespan_mode == "on":
|
||||
self.log.error("ASGI lifespan startup failed: %s", e)
|
||||
return
|
||||
else:
|
||||
# auto mode - app doesn't support lifespan
|
||||
self.log.debug("ASGI lifespan not supported by app: %s", e)
|
||||
self.lifespan = None
|
||||
|
||||
# Create servers for each listener socket
|
||||
ssl_context = self._get_ssl_context()
|
||||
|
||||
for sock in self.sockets:
|
||||
try:
|
||||
server = await self.loop.create_server(
|
||||
lambda: ASGIProtocol(self),
|
||||
sock=sock.sock,
|
||||
ssl=ssl_context,
|
||||
reuse_address=True,
|
||||
start_serving=True,
|
||||
)
|
||||
self.servers.append(server)
|
||||
self.log.info("ASGI server listening on %s", sock)
|
||||
except Exception as e:
|
||||
self.log.error("Failed to create server on %s: %s", sock, e)
|
||||
|
||||
if not self.servers:
|
||||
self.log.error("No servers could be started")
|
||||
return
|
||||
|
||||
# Main loop with heartbeat
|
||||
try:
|
||||
while self.alive:
|
||||
self.notify()
|
||||
|
||||
# Check if parent is still alive
|
||||
if self.ppid != os.getppid():
|
||||
self.log.info("Parent changed, shutting down: %s", self)
|
||||
break
|
||||
|
||||
# Check connection limit
|
||||
# (Connections are managed by nr_conns in ASGIProtocol)
|
||||
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Graceful shutdown
|
||||
await self._shutdown()
|
||||
|
||||
async def _shutdown(self):
|
||||
"""Perform graceful shutdown."""
|
||||
self.log.info("Worker shutting down...")
|
||||
|
||||
# Stop accepting new connections
|
||||
for server in self.servers:
|
||||
server.close()
|
||||
|
||||
# Wait for servers to close
|
||||
for server in self.servers:
|
||||
await server.wait_closed()
|
||||
|
||||
# Wait for in-flight connections (with timeout)
|
||||
graceful_timeout = self.cfg.graceful_timeout
|
||||
if self.nr_conns > 0:
|
||||
self.log.info("Waiting for %d connections to finish...", self.nr_conns)
|
||||
deadline = self.loop.time() + graceful_timeout
|
||||
while self.nr_conns > 0 and self.loop.time() < deadline:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
if self.nr_conns > 0:
|
||||
self.log.warning("Closing %d connections after timeout", self.nr_conns)
|
||||
|
||||
# Run lifespan shutdown
|
||||
if self.lifespan:
|
||||
try:
|
||||
await self.lifespan.shutdown()
|
||||
except Exception as e:
|
||||
self.log.error("ASGI lifespan shutdown error: %s", e)
|
||||
|
||||
def _get_ssl_context(self):
|
||||
"""Get SSL context if configured."""
|
||||
if not self.cfg.is_ssl:
|
||||
return None
|
||||
|
||||
try:
|
||||
from gunicorn import sock
|
||||
return sock.ssl_context(self.cfg)
|
||||
except Exception as e:
|
||||
self.log.error("Failed to create SSL context: %s", e)
|
||||
return None
|
||||
|
||||
def _cleanup(self):
|
||||
"""Clean up resources on exit."""
|
||||
try:
|
||||
# Cancel all pending tasks
|
||||
pending = asyncio.all_tasks(self.loop)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
|
||||
# Run loop until all tasks are cancelled
|
||||
if pending:
|
||||
self.loop.run_until_complete(
|
||||
asyncio.gather(*pending, return_exceptions=True)
|
||||
)
|
||||
|
||||
self.loop.close()
|
||||
except Exception as e:
|
||||
self.log.debug("Cleanup error: %s", e)
|
||||
|
||||
# Close sockets
|
||||
for s in self.sockets:
|
||||
try:
|
||||
s.close()
|
||||
except Exception:
|
||||
pass
|
||||
@ -58,6 +58,7 @@ testing = [
|
||||
"coverage",
|
||||
"pytest",
|
||||
"pytest-cov",
|
||||
"pytest-asyncio",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
285
tests/test_asgi.py
Normal file
285
tests/test_asgi.py
Normal file
@ -0,0 +1,285 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
Tests for ASGI worker components.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import pytest
|
||||
from unittest import mock
|
||||
|
||||
from gunicorn.asgi.unreader import AsyncUnreader
|
||||
from gunicorn.asgi.message import AsyncRequest
|
||||
|
||||
|
||||
class MockStreamReader:
|
||||
"""Mock asyncio.StreamReader for testing."""
|
||||
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
self.pos = 0
|
||||
|
||||
async def read(self, size=-1):
|
||||
if self.pos >= len(self.data):
|
||||
return b""
|
||||
if size < 0:
|
||||
result = self.data[self.pos:]
|
||||
self.pos = len(self.data)
|
||||
else:
|
||||
result = self.data[self.pos:self.pos + size]
|
||||
self.pos += size
|
||||
return result
|
||||
|
||||
async def readexactly(self, n):
|
||||
if self.pos + n > len(self.data):
|
||||
raise asyncio.IncompleteReadError(
|
||||
self.data[self.pos:], n
|
||||
)
|
||||
result = self.data[self.pos:self.pos + n]
|
||||
self.pos += n
|
||||
return result
|
||||
|
||||
|
||||
class MockConfig:
|
||||
"""Mock gunicorn config for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self.is_ssl = False
|
||||
self.proxy_protocol = False
|
||||
self.proxy_allow_ips = ["127.0.0.1"]
|
||||
self.forwarded_allow_ips = ["127.0.0.1"]
|
||||
self.secure_scheme_headers = {}
|
||||
self.forwarder_headers = []
|
||||
self.limit_request_line = 8190
|
||||
self.limit_request_fields = 100
|
||||
self.limit_request_field_size = 8190
|
||||
self.permit_unconventional_http_method = False
|
||||
self.permit_unconventional_http_version = False
|
||||
self.permit_obsolete_folding = False
|
||||
self.casefold_http_method = False
|
||||
self.strip_header_spaces = False
|
||||
self.header_map = "refuse"
|
||||
|
||||
|
||||
# AsyncUnreader Tests
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_unreader_read_chunk():
|
||||
"""Test basic chunk reading."""
|
||||
reader = MockStreamReader(b"hello world")
|
||||
unreader = AsyncUnreader(reader)
|
||||
data = await unreader.read()
|
||||
assert data == b"hello world"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_unreader_read_size():
|
||||
"""Test reading specific size."""
|
||||
reader = MockStreamReader(b"hello world")
|
||||
unreader = AsyncUnreader(reader)
|
||||
data = await unreader.read(5)
|
||||
assert data == b"hello"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_unreader_unread():
|
||||
"""Test unread functionality."""
|
||||
reader = MockStreamReader(b"hello world")
|
||||
unreader = AsyncUnreader(reader)
|
||||
|
||||
# Read all data
|
||||
data = await unreader.read()
|
||||
assert data == b"hello world"
|
||||
|
||||
# Unread some data
|
||||
unreader.unread(b"world")
|
||||
|
||||
# Read again should get unread data
|
||||
data = await unreader.read()
|
||||
assert data == b"world"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_unreader_read_zero():
|
||||
"""Test reading zero bytes."""
|
||||
reader = MockStreamReader(b"hello")
|
||||
unreader = AsyncUnreader(reader)
|
||||
data = await unreader.read(0)
|
||||
assert data == b""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_unreader_read_empty():
|
||||
"""Test reading from empty stream."""
|
||||
reader = MockStreamReader(b"")
|
||||
unreader = AsyncUnreader(reader)
|
||||
data = await unreader.read()
|
||||
assert data == b""
|
||||
|
||||
|
||||
# AsyncRequest Tests
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_simple_get():
|
||||
"""Test parsing a simple GET request."""
|
||||
request_data = b"GET /path HTTP/1.1\r\nHost: localhost\r\n\r\n"
|
||||
reader = MockStreamReader(request_data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
assert request.method == "GET"
|
||||
assert request.path == "/path"
|
||||
assert request.version == (1, 1)
|
||||
assert ("HOST", "localhost") in request.headers
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_with_query():
|
||||
"""Test parsing request with query string."""
|
||||
request_data = b"GET /search?q=test&page=1 HTTP/1.1\r\nHost: localhost\r\n\r\n"
|
||||
reader = MockStreamReader(request_data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
assert request.method == "GET"
|
||||
assert request.path == "/search"
|
||||
assert request.query == "q=test&page=1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_post_with_body():
|
||||
"""Test parsing POST request with body."""
|
||||
request_data = (
|
||||
b"POST /submit HTTP/1.1\r\n"
|
||||
b"Host: localhost\r\n"
|
||||
b"Content-Length: 11\r\n"
|
||||
b"\r\n"
|
||||
b"hello=world"
|
||||
)
|
||||
reader = MockStreamReader(request_data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
assert request.method == "POST"
|
||||
assert request.path == "/submit"
|
||||
assert request.content_length == 11
|
||||
|
||||
# Read body
|
||||
body = await request.read_body(100)
|
||||
assert body == b"hello=world"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_multiple_headers():
|
||||
"""Test parsing request with multiple headers."""
|
||||
request_data = (
|
||||
b"GET / HTTP/1.1\r\n"
|
||||
b"Host: localhost\r\n"
|
||||
b"Accept: text/html\r\n"
|
||||
b"Accept-Language: en-US\r\n"
|
||||
b"Connection: keep-alive\r\n"
|
||||
b"\r\n"
|
||||
)
|
||||
reader = MockStreamReader(request_data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
assert len(request.headers) == 4
|
||||
assert request.get_header("HOST") == "localhost"
|
||||
assert request.get_header("ACCEPT") == "text/html"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_should_close_http10():
|
||||
"""Test connection close detection for HTTP/1.0."""
|
||||
request_data = b"GET / HTTP/1.0\r\nHost: localhost\r\n\r\n"
|
||||
reader = MockStreamReader(request_data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
assert request.version == (1, 0)
|
||||
assert request.should_close() is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_should_close_connection_header():
|
||||
"""Test connection close detection with Connection header."""
|
||||
request_data = b"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"
|
||||
reader = MockStreamReader(request_data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
assert request.should_close() is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_keepalive():
|
||||
"""Test keepalive detection."""
|
||||
request_data = b"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: keep-alive\r\n\r\n"
|
||||
reader = MockStreamReader(request_data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
assert request.should_close() is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_no_body_for_get():
|
||||
"""Test that GET requests have no body by default."""
|
||||
request_data = b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n"
|
||||
reader = MockStreamReader(request_data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
assert request.content_length == 0
|
||||
body = await request.read_body()
|
||||
assert body == b""
|
||||
|
||||
|
||||
# Error handling tests
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_invalid_method():
|
||||
"""Test invalid HTTP method detection."""
|
||||
from gunicorn.http.errors import InvalidRequestMethod
|
||||
|
||||
request_data = b"ge!t / HTTP/1.1\r\nHost: localhost\r\n\r\n"
|
||||
reader = MockStreamReader(request_data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
with pytest.raises(InvalidRequestMethod):
|
||||
await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_invalid_http_version():
|
||||
"""Test invalid HTTP version detection."""
|
||||
from gunicorn.http.errors import InvalidHTTPVersion
|
||||
|
||||
request_data = b"GET / HTTP/2.0\r\nHost: localhost\r\n\r\n"
|
||||
reader = MockStreamReader(request_data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
with pytest.raises(InvalidHTTPVersion):
|
||||
await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
643
tests/test_asgi_worker.py
Normal file
643
tests/test_asgi_worker.py
Normal file
@ -0,0 +1,643 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
Tests for the ASGI worker.
|
||||
|
||||
Includes unit tests for worker components and integration tests
|
||||
that actually start the server and make HTTP requests.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import errno
|
||||
import os
|
||||
import signal
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
import threading
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from gunicorn.config import Config
|
||||
from gunicorn.workers import gasgi
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Mock Classes
|
||||
# ============================================================================
|
||||
|
||||
class FakeSocket:
|
||||
"""Mock socket for testing."""
|
||||
|
||||
def __init__(self, data=b''):
|
||||
self.data = data
|
||||
self.closed = False
|
||||
self.blocking = True
|
||||
self._fileno = id(self) % 65536
|
||||
|
||||
def fileno(self):
|
||||
return self._fileno
|
||||
|
||||
def setblocking(self, blocking):
|
||||
self.blocking = blocking
|
||||
|
||||
def recv(self, size):
|
||||
if self.closed:
|
||||
raise OSError(errno.EBADF, "Bad file descriptor")
|
||||
result = self.data[:size]
|
||||
self.data = self.data[size:]
|
||||
return result
|
||||
|
||||
def send(self, data):
|
||||
if self.closed:
|
||||
raise OSError(errno.EPIPE, "Broken pipe")
|
||||
return len(data)
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
def getsockname(self):
|
||||
return ('127.0.0.1', 8000)
|
||||
|
||||
def getpeername(self):
|
||||
return ('127.0.0.1', 12345)
|
||||
|
||||
|
||||
class FakeApp:
|
||||
"""Mock ASGI application for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
def wsgi(self):
|
||||
return self.asgi_app
|
||||
|
||||
async def asgi_app(self, scope, receive, send):
|
||||
self.calls.append(scope)
|
||||
if scope["type"] == "lifespan":
|
||||
while True:
|
||||
message = await receive()
|
||||
if message["type"] == "lifespan.startup":
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
elif message["type"] == "lifespan.shutdown":
|
||||
await send({"type": "lifespan.shutdown.complete"})
|
||||
return
|
||||
elif scope["type"] == "http":
|
||||
await send({
|
||||
"type": "http.response.start",
|
||||
"status": 200,
|
||||
"headers": [(b"content-type", b"text/plain")],
|
||||
})
|
||||
await send({
|
||||
"type": "http.response.body",
|
||||
"body": b"Hello from ASGI!",
|
||||
})
|
||||
|
||||
|
||||
class FakeListener:
|
||||
"""Mock listener socket."""
|
||||
|
||||
def __init__(self):
|
||||
self.sock = FakeSocket()
|
||||
|
||||
def getsockname(self):
|
||||
return ('127.0.0.1', 8000)
|
||||
|
||||
def close(self):
|
||||
self.sock.close()
|
||||
|
||||
def __str__(self):
|
||||
return "http://127.0.0.1:8000"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Helper Functions
|
||||
# ============================================================================
|
||||
|
||||
def _has_uvloop():
|
||||
"""Check if uvloop is available."""
|
||||
try:
|
||||
import uvloop
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Unit Tests for ASGIWorker
|
||||
# ============================================================================
|
||||
|
||||
class TestASGIWorkerInit:
|
||||
"""Tests for ASGIWorker initialization."""
|
||||
|
||||
def create_worker(self, **kwargs):
|
||||
"""Create a worker for testing."""
|
||||
cfg = Config()
|
||||
cfg.set('workers', 1)
|
||||
cfg.set('worker_connections', 1000)
|
||||
|
||||
for key, value in kwargs.items():
|
||||
cfg.set(key, value)
|
||||
|
||||
worker = gasgi.ASGIWorker(
|
||||
age=1,
|
||||
ppid=os.getpid(),
|
||||
sockets=[],
|
||||
app=FakeApp(),
|
||||
timeout=30,
|
||||
cfg=cfg,
|
||||
log=mock.Mock(),
|
||||
)
|
||||
return worker
|
||||
|
||||
def test_worker_init(self):
|
||||
"""Test worker initialization."""
|
||||
worker = self.create_worker()
|
||||
|
||||
assert worker.worker_connections == 1000
|
||||
assert worker.nr_conns == 0
|
||||
assert worker.loop is None
|
||||
assert worker.servers == []
|
||||
assert worker.state == {}
|
||||
|
||||
def test_worker_connections_config(self):
|
||||
"""Test worker_connections configuration."""
|
||||
worker = self.create_worker(worker_connections=500)
|
||||
assert worker.worker_connections == 500
|
||||
|
||||
|
||||
class TestASGIWorkerEventLoop:
|
||||
"""Tests for event loop setup."""
|
||||
|
||||
def create_worker(self, **kwargs):
|
||||
"""Create a worker for testing."""
|
||||
cfg = Config()
|
||||
cfg.set('workers', 1)
|
||||
cfg.set('worker_connections', 1000)
|
||||
|
||||
for key, value in kwargs.items():
|
||||
cfg.set(key, value)
|
||||
|
||||
worker = gasgi.ASGIWorker(
|
||||
age=1,
|
||||
ppid=os.getpid(),
|
||||
sockets=[],
|
||||
app=FakeApp(),
|
||||
timeout=30,
|
||||
cfg=cfg,
|
||||
log=mock.Mock(),
|
||||
)
|
||||
return worker
|
||||
|
||||
def test_setup_asyncio_loop(self):
|
||||
"""Test asyncio event loop setup."""
|
||||
worker = self.create_worker(asgi_loop='asyncio')
|
||||
worker._setup_event_loop()
|
||||
|
||||
assert worker.loop is not None
|
||||
assert isinstance(worker.loop, asyncio.AbstractEventLoop)
|
||||
worker.loop.close()
|
||||
|
||||
def test_setup_auto_loop_falls_back_to_asyncio(self):
|
||||
"""Test that auto mode uses asyncio when uvloop unavailable."""
|
||||
worker = self.create_worker(asgi_loop='auto')
|
||||
|
||||
# Mock uvloop import failure
|
||||
with mock.patch.dict('sys.modules', {'uvloop': None}):
|
||||
worker._setup_event_loop()
|
||||
|
||||
assert worker.loop is not None
|
||||
worker.loop.close()
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not _has_uvloop(),
|
||||
reason="uvloop not installed"
|
||||
)
|
||||
def test_setup_uvloop(self):
|
||||
"""Test uvloop event loop setup."""
|
||||
worker = self.create_worker(asgi_loop='uvloop')
|
||||
worker._setup_event_loop()
|
||||
|
||||
import uvloop
|
||||
assert isinstance(worker.loop, uvloop.Loop)
|
||||
worker.loop.close()
|
||||
|
||||
|
||||
class TestASGIWorkerSignals:
|
||||
"""Tests for signal handling."""
|
||||
|
||||
def create_worker(self):
|
||||
"""Create a worker for testing."""
|
||||
cfg = Config()
|
||||
cfg.set('workers', 1)
|
||||
cfg.set('worker_connections', 1000)
|
||||
cfg.set('graceful_timeout', 5)
|
||||
|
||||
worker = gasgi.ASGIWorker(
|
||||
age=1,
|
||||
ppid=os.getpid(),
|
||||
sockets=[],
|
||||
app=FakeApp(),
|
||||
timeout=30,
|
||||
cfg=cfg,
|
||||
log=mock.Mock(),
|
||||
)
|
||||
worker._setup_event_loop()
|
||||
return worker
|
||||
|
||||
def test_handle_exit_sets_alive_false(self):
|
||||
"""Test that exit signal sets alive=False."""
|
||||
worker = self.create_worker()
|
||||
worker.alive = True
|
||||
|
||||
worker.handle_exit_signal()
|
||||
|
||||
assert worker.alive is False
|
||||
worker.loop.close()
|
||||
|
||||
def test_handle_quit_sets_alive_false(self):
|
||||
"""Test that quit signal sets alive=False."""
|
||||
worker = self.create_worker()
|
||||
worker.alive = True
|
||||
|
||||
# Mock the worker_int callback on the worker's cfg settings
|
||||
with mock.patch.object(worker.cfg.settings['worker_int'], 'get', return_value=lambda w: None):
|
||||
worker.handle_quit_signal()
|
||||
|
||||
assert worker.alive is False
|
||||
worker.loop.close()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for Lifespan Protocol
|
||||
# ============================================================================
|
||||
|
||||
class TestLifespanManager:
|
||||
"""Tests for ASGI lifespan protocol."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifespan_startup_complete(self):
|
||||
"""Test successful lifespan startup."""
|
||||
from gunicorn.asgi.lifespan import LifespanManager
|
||||
|
||||
startup_called = False
|
||||
shutdown_called = False
|
||||
|
||||
async def app(scope, receive, send):
|
||||
nonlocal startup_called, shutdown_called
|
||||
assert scope["type"] == "lifespan"
|
||||
while True:
|
||||
message = await receive()
|
||||
if message["type"] == "lifespan.startup":
|
||||
startup_called = True
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
elif message["type"] == "lifespan.shutdown":
|
||||
shutdown_called = True
|
||||
await send({"type": "lifespan.shutdown.complete"})
|
||||
return
|
||||
|
||||
manager = LifespanManager(app, mock.Mock())
|
||||
await manager.startup()
|
||||
|
||||
assert startup_called
|
||||
assert manager._startup_complete.is_set()
|
||||
assert not manager._startup_failed
|
||||
|
||||
await manager.shutdown()
|
||||
assert shutdown_called
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifespan_startup_failed(self):
|
||||
"""Test lifespan startup failure."""
|
||||
from gunicorn.asgi.lifespan import LifespanManager
|
||||
|
||||
async def app(scope, receive, send):
|
||||
message = await receive()
|
||||
if message["type"] == "lifespan.startup":
|
||||
await send({
|
||||
"type": "lifespan.startup.failed",
|
||||
"message": "Database connection failed"
|
||||
})
|
||||
|
||||
manager = LifespanManager(app, mock.Mock())
|
||||
|
||||
with pytest.raises(RuntimeError, match="Database connection failed"):
|
||||
await manager.startup()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifespan_state_shared(self):
|
||||
"""Test that lifespan state is shared with app."""
|
||||
from gunicorn.asgi.lifespan import LifespanManager
|
||||
|
||||
state = {}
|
||||
|
||||
async def app(scope, receive, send):
|
||||
assert "state" in scope
|
||||
scope["state"]["db"] = "connected"
|
||||
message = await receive()
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
message = await receive()
|
||||
await send({"type": "lifespan.shutdown.complete"})
|
||||
|
||||
manager = LifespanManager(app, mock.Mock(), state)
|
||||
await manager.startup()
|
||||
|
||||
assert state.get("db") == "connected"
|
||||
|
||||
await manager.shutdown()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for WebSocket Protocol
|
||||
# ============================================================================
|
||||
|
||||
class TestWebSocketProtocol:
|
||||
"""Tests for WebSocket protocol handling."""
|
||||
|
||||
def test_websocket_guid(self):
|
||||
"""Test WebSocket GUID constant."""
|
||||
from gunicorn.asgi.websocket import WS_GUID
|
||||
assert WS_GUID == b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
||||
|
||||
def test_websocket_opcodes(self):
|
||||
"""Test WebSocket opcode constants."""
|
||||
from gunicorn.asgi import websocket
|
||||
|
||||
assert websocket.OPCODE_TEXT == 0x1
|
||||
assert websocket.OPCODE_BINARY == 0x2
|
||||
assert websocket.OPCODE_CLOSE == 0x8
|
||||
assert websocket.OPCODE_PING == 0x9
|
||||
assert websocket.OPCODE_PONG == 0xA
|
||||
|
||||
def test_websocket_accept_key_calculation(self):
|
||||
"""Test WebSocket accept key calculation per RFC 6455."""
|
||||
import base64
|
||||
import hashlib
|
||||
from gunicorn.asgi.websocket import WS_GUID
|
||||
|
||||
# Example from RFC 6455
|
||||
client_key = b"dGhlIHNhbXBsZSBub25jZQ=="
|
||||
expected_accept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
|
||||
|
||||
accept_key = base64.b64encode(
|
||||
hashlib.sha1(client_key + WS_GUID).digest()
|
||||
).decode("ascii")
|
||||
|
||||
assert accept_key == expected_accept
|
||||
|
||||
def test_websocket_frame_masking(self):
|
||||
"""Test WebSocket frame unmasking."""
|
||||
from gunicorn.asgi.websocket import WebSocketProtocol
|
||||
|
||||
# Create a minimal protocol instance
|
||||
protocol = WebSocketProtocol(None, None, {}, None, mock.Mock())
|
||||
|
||||
# Test unmasking (XOR operation)
|
||||
masking_key = bytes([0x37, 0xfa, 0x21, 0x3d])
|
||||
masked_data = bytes([0x7f, 0x9f, 0x4d, 0x51, 0x58]) # "Hello" masked
|
||||
|
||||
unmasked = protocol._unmask(masked_data, masking_key)
|
||||
assert unmasked == b"Hello"
|
||||
|
||||
def test_websocket_frame_masking_empty(self):
|
||||
"""Test WebSocket frame unmasking with empty payload."""
|
||||
from gunicorn.asgi.websocket import WebSocketProtocol
|
||||
|
||||
protocol = WebSocketProtocol(None, None, {}, None, mock.Mock())
|
||||
|
||||
masking_key = bytes([0x37, 0xfa, 0x21, 0x3d])
|
||||
unmasked = protocol._unmask(b"", masking_key)
|
||||
assert unmasked == b""
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Integration Tests
|
||||
# ============================================================================
|
||||
|
||||
class TestASGIIntegration:
|
||||
"""Integration tests that start actual servers."""
|
||||
|
||||
@pytest.fixture
|
||||
def free_port(self):
|
||||
"""Get a free port for testing."""
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(('127.0.0.1', 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_request_response(self, free_port):
|
||||
"""Test basic HTTP request/response cycle."""
|
||||
# Simple ASGI app
|
||||
async def app(scope, receive, send):
|
||||
if scope["type"] == "http":
|
||||
await send({
|
||||
"type": "http.response.start",
|
||||
"status": 200,
|
||||
"headers": [(b"content-type", b"text/plain")],
|
||||
})
|
||||
await send({
|
||||
"type": "http.response.body",
|
||||
"body": b"Hello, World!",
|
||||
})
|
||||
|
||||
# Start server
|
||||
loop = asyncio.get_event_loop()
|
||||
server = await loop.create_server(
|
||||
lambda: _TestProtocol(app),
|
||||
'127.0.0.1',
|
||||
free_port,
|
||||
)
|
||||
|
||||
try:
|
||||
# Use asyncio to make HTTP request
|
||||
reader, writer = await asyncio.open_connection('127.0.0.1', free_port)
|
||||
request = f"GET / HTTP/1.1\r\nHost: 127.0.0.1:{free_port}\r\n\r\n"
|
||||
writer.write(request.encode())
|
||||
await writer.drain()
|
||||
|
||||
# Read response
|
||||
response = await reader.read(4096)
|
||||
response_text = response.decode()
|
||||
|
||||
assert "HTTP/1.1 200" in response_text
|
||||
assert "Hello, World!" in response_text
|
||||
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
finally:
|
||||
server.close()
|
||||
await server.wait_closed()
|
||||
|
||||
|
||||
class _TestProtocol(asyncio.Protocol):
|
||||
"""Minimal protocol for integration testing."""
|
||||
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
self.transport = None
|
||||
|
||||
def connection_made(self, transport):
|
||||
self.transport = transport
|
||||
|
||||
def data_received(self, data):
|
||||
# Very simple HTTP parsing for testing
|
||||
asyncio.create_task(self._handle(data))
|
||||
|
||||
async def _handle(self, data):
|
||||
# Parse basic HTTP request
|
||||
lines = data.decode().split('\r\n')
|
||||
method, path, _ = lines[0].split(' ')
|
||||
|
||||
scope = {
|
||||
"type": "http",
|
||||
"asgi": {"version": "3.0"},
|
||||
"http_version": "1.1",
|
||||
"method": method,
|
||||
"path": path,
|
||||
"query_string": b"",
|
||||
"headers": [],
|
||||
"server": ("127.0.0.1", 8000),
|
||||
"client": ("127.0.0.1", 12345),
|
||||
}
|
||||
|
||||
async def receive():
|
||||
return {"type": "http.request", "body": b"", "more_body": False}
|
||||
|
||||
async def send(message):
|
||||
if message["type"] == "http.response.start":
|
||||
status = message["status"]
|
||||
headers = message.get("headers", [])
|
||||
response = f"HTTP/1.1 {status} OK\r\n"
|
||||
for name, value in headers:
|
||||
if isinstance(name, bytes):
|
||||
name = name.decode()
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode()
|
||||
response += f"{name}: {value}\r\n"
|
||||
response += "\r\n"
|
||||
self.transport.write(response.encode())
|
||||
elif message["type"] == "http.response.body":
|
||||
body = message.get("body", b"")
|
||||
self.transport.write(body)
|
||||
if not message.get("more_body", False):
|
||||
self.transport.close()
|
||||
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ASGI Protocol Tests
|
||||
# ============================================================================
|
||||
|
||||
class TestASGIProtocol:
|
||||
"""Tests for ASGIProtocol."""
|
||||
|
||||
def test_reason_phrases(self):
|
||||
"""Test HTTP reason phrase lookup."""
|
||||
from gunicorn.asgi.protocol import ASGIProtocol
|
||||
|
||||
# Create minimal worker mock
|
||||
worker = mock.Mock()
|
||||
worker.cfg = Config()
|
||||
worker.log = mock.Mock()
|
||||
worker.asgi = mock.Mock()
|
||||
|
||||
protocol = ASGIProtocol(worker)
|
||||
|
||||
assert protocol._get_reason_phrase(200) == "OK"
|
||||
assert protocol._get_reason_phrase(404) == "Not Found"
|
||||
assert protocol._get_reason_phrase(500) == "Internal Server Error"
|
||||
assert protocol._get_reason_phrase(999) == "Unknown"
|
||||
|
||||
def test_scope_building(self):
|
||||
"""Test HTTP scope building."""
|
||||
from gunicorn.asgi.protocol import ASGIProtocol
|
||||
from gunicorn.asgi.message import AsyncRequest
|
||||
|
||||
worker = mock.Mock()
|
||||
worker.cfg = Config()
|
||||
worker.cfg.set('root_path', '/api')
|
||||
worker.log = mock.Mock()
|
||||
worker.asgi = mock.Mock()
|
||||
|
||||
protocol = ASGIProtocol(worker)
|
||||
|
||||
# Create mock request
|
||||
request = mock.Mock()
|
||||
request.method = "GET"
|
||||
request.path = "/users"
|
||||
request.query = "page=1"
|
||||
request.version = (1, 1)
|
||||
request.scheme = "http"
|
||||
request.headers = [("HOST", "localhost"), ("ACCEPT", "text/html")]
|
||||
|
||||
scope = protocol._build_http_scope(
|
||||
request,
|
||||
("127.0.0.1", 8000), # sockname
|
||||
("127.0.0.1", 12345), # peername
|
||||
)
|
||||
|
||||
assert scope["type"] == "http"
|
||||
assert scope["method"] == "GET"
|
||||
assert scope["path"] == "/users"
|
||||
assert scope["query_string"] == b"page=1"
|
||||
assert scope["root_path"] == "/api"
|
||||
assert scope["http_version"] == "1.1"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Config Tests
|
||||
# ============================================================================
|
||||
|
||||
class TestASGIConfig:
|
||||
"""Tests for ASGI configuration options."""
|
||||
|
||||
def test_asgi_loop_default(self):
|
||||
"""Test default asgi_loop value."""
|
||||
cfg = Config()
|
||||
assert cfg.asgi_loop == "auto"
|
||||
|
||||
def test_asgi_loop_validation(self):
|
||||
"""Test asgi_loop validation."""
|
||||
cfg = Config()
|
||||
|
||||
cfg.set('asgi_loop', 'asyncio')
|
||||
assert cfg.asgi_loop == 'asyncio'
|
||||
|
||||
cfg.set('asgi_loop', 'uvloop')
|
||||
assert cfg.asgi_loop == 'uvloop'
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
cfg.set('asgi_loop', 'invalid')
|
||||
|
||||
def test_asgi_lifespan_default(self):
|
||||
"""Test default asgi_lifespan value."""
|
||||
cfg = Config()
|
||||
assert cfg.asgi_lifespan == "auto"
|
||||
|
||||
def test_asgi_lifespan_validation(self):
|
||||
"""Test asgi_lifespan validation."""
|
||||
cfg = Config()
|
||||
|
||||
cfg.set('asgi_lifespan', 'on')
|
||||
assert cfg.asgi_lifespan == 'on'
|
||||
|
||||
cfg.set('asgi_lifespan', 'off')
|
||||
assert cfg.asgi_lifespan == 'off'
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
cfg.set('asgi_lifespan', 'invalid')
|
||||
|
||||
def test_root_path_default(self):
|
||||
"""Test default root_path value."""
|
||||
cfg = Config()
|
||||
assert cfg.root_path == ""
|
||||
|
||||
def test_root_path_setting(self):
|
||||
"""Test root_path configuration."""
|
||||
cfg = Config()
|
||||
cfg.set('root_path', '/api/v1')
|
||||
assert cfg.root_path == '/api/v1'
|
||||
Loading…
x
Reference in New Issue
Block a user