Merge pull request #3444 from benoitc/asgi-worker

Add native ASGI worker and uWSGI binary protocol support
This commit is contained in:
Benoit Chesneau 2026-01-22 20:31:23 +01:00 committed by GitHub
commit 5b50487bab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
36 changed files with 5193 additions and 6 deletions

View File

@ -0,0 +1,45 @@
name: Docker Integration Tests
on:
push:
branches: [master]
paths:
- 'gunicorn/uwsgi/**'
- 'tests/docker/uwsgi/**'
- '.github/workflows/docker-integration.yml'
pull_request:
paths:
- 'gunicorn/uwsgi/**'
- 'tests/docker/uwsgi/**'
- '.github/workflows/docker-integration.yml'
permissions:
contents: read
env:
FORCE_COLOR: 1
jobs:
uwsgi-nginx:
name: uWSGI Protocol with nginx
runs-on: ubuntu-latest
timeout-minutes: 15
steps:
- uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: "3.12"
cache: pip
cache-dependency-path: requirements_test.txt
- name: Install test dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pytest pytest-cov requests
- name: Run uWSGI integration tests
run: |
pytest tests/docker/uwsgi/ -v --tb=short

View File

@ -40,7 +40,7 @@ jobs:
python${{ matrix.python-version }} -m venv venv
. venv/bin/activate
pip install --upgrade pip
pip install pytest pytest-cov coverage
pip install pytest pytest-cov pytest-asyncio coverage
pip install -e .
pytest --cov=gunicorn -v tests/ \
--ignore=tests/workers/test_ggevent.py \

View File

@ -0,0 +1,7 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
ASGI example applications for gunicorn.
"""

130
examples/asgi/basic_app.py Normal file
View File

@ -0,0 +1,130 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
Basic ASGI application example.
Run with:
gunicorn -k asgi examples.asgi.basic_app:app
Test with:
curl http://127.0.0.1:8000/
curl http://127.0.0.1:8000/hello
curl -X POST http://127.0.0.1:8000/echo -d "test data"
"""
async def app(scope, receive, send):
"""Simple ASGI application demonstrating basic functionality."""
if scope["type"] == "lifespan":
await handle_lifespan(scope, receive, send)
elif scope["type"] == "http":
await handle_http(scope, receive, send)
else:
raise ValueError(f"Unknown scope type: {scope['type']}")
async def handle_lifespan(scope, receive, send):
"""Handle lifespan events (startup/shutdown)."""
while True:
message = await receive()
if message["type"] == "lifespan.startup":
print("ASGI application starting up...")
await send({"type": "lifespan.startup.complete"})
elif message["type"] == "lifespan.shutdown":
print("ASGI application shutting down...")
await send({"type": "lifespan.shutdown.complete"})
return
async def handle_http(scope, receive, send):
"""Handle HTTP requests."""
path = scope["path"]
method = scope["method"]
if path == "/" and method == "GET":
await send_response(send, 200, b"Welcome to gunicorn ASGI!\n")
elif path == "/hello" and method == "GET":
name = get_query_param(scope, "name", "World")
body = f"Hello, {name}!\n".encode()
await send_response(send, 200, body)
elif path == "/echo" and method == "POST":
body = await read_body(receive)
await send_response(send, 200, body, content_type=b"application/octet-stream")
elif path == "/headers":
headers_info = format_headers(scope["headers"])
await send_response(send, 200, headers_info.encode())
elif path == "/info":
info = format_request_info(scope)
await send_response(send, 200, info.encode(), content_type=b"application/json")
else:
await send_response(send, 404, b"Not Found\n")
async def send_response(send, status, body, content_type=b"text/plain"):
"""Send an HTTP response."""
await send({
"type": "http.response.start",
"status": status,
"headers": [
(b"content-type", content_type),
(b"content-length", str(len(body)).encode()),
],
})
await send({
"type": "http.response.body",
"body": body,
})
async def read_body(receive):
"""Read the full request body."""
body = b""
while True:
message = await receive()
body += message.get("body", b"")
if not message.get("more_body", False):
break
return body
def get_query_param(scope, name, default=None):
"""Get a query parameter value."""
query_string = scope.get("query_string", b"").decode()
for param in query_string.split("&"):
if "=" in param:
key, value = param.split("=", 1)
if key == name:
return value
return default
def format_headers(headers):
"""Format headers for display."""
lines = ["Request Headers:"]
for name, value in headers:
lines.append(f" {name.decode()}: {value.decode()}")
return "\n".join(lines) + "\n"
def format_request_info(scope):
"""Format request info as JSON."""
import json
info = {
"method": scope["method"],
"path": scope["path"],
"query_string": scope.get("query_string", b"").decode(),
"http_version": scope["http_version"],
"scheme": scope["scheme"],
"server": list(scope.get("server") or []),
"client": list(scope.get("client") or []),
"root_path": scope.get("root_path", ""),
}
return json.dumps(info, indent=2) + "\n"

View File

@ -0,0 +1,235 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
WebSocket ASGI application example.
Run with:
gunicorn -k asgi examples.asgi.websocket_app:app
Test with:
# Using websocat (install with: cargo install websocat)
websocat ws://127.0.0.1:8000/ws
# Or using Python websockets library
python -c "
import asyncio
import websockets
async def test():
async with websockets.connect('ws://127.0.0.1:8000/ws') as ws:
await ws.send('Hello')
print(await ws.recv())
asyncio.run(test())
"
"""
async def app(scope, receive, send):
"""ASGI application with WebSocket support."""
if scope["type"] == "lifespan":
await handle_lifespan(scope, receive, send)
elif scope["type"] == "http":
await handle_http(scope, receive, send)
elif scope["type"] == "websocket":
await handle_websocket(scope, receive, send)
else:
raise ValueError(f"Unknown scope type: {scope['type']}")
async def handle_lifespan(scope, receive, send):
"""Handle lifespan events."""
while True:
message = await receive()
if message["type"] == "lifespan.startup":
await send({"type": "lifespan.startup.complete"})
elif message["type"] == "lifespan.shutdown":
await send({"type": "lifespan.shutdown.complete"})
return
async def handle_http(scope, receive, send):
"""Handle HTTP requests - serve a simple HTML page for WebSocket testing."""
path = scope["path"]
if path == "/":
html = HTML_PAGE.encode()
await send({
"type": "http.response.start",
"status": 200,
"headers": [
(b"content-type", b"text/html"),
(b"content-length", str(len(html)).encode()),
],
})
await send({
"type": "http.response.body",
"body": html,
})
else:
await send({
"type": "http.response.start",
"status": 404,
"headers": [(b"content-type", b"text/plain")],
})
await send({
"type": "http.response.body",
"body": b"Not Found",
})
async def handle_websocket(scope, receive, send):
"""Handle WebSocket connections."""
path = scope["path"]
if path == "/ws":
await echo_websocket(scope, receive, send)
elif path == "/ws/chat":
await chat_websocket(scope, receive, send)
else:
# Reject the connection
await send({"type": "websocket.close", "code": 4004})
async def echo_websocket(scope, receive, send):
"""Echo WebSocket - sends back whatever it receives."""
# Wait for connection
message = await receive()
if message["type"] != "websocket.connect":
return
# Accept the connection
await send({"type": "websocket.accept"})
# Echo loop
try:
while True:
message = await receive()
if message["type"] == "websocket.disconnect":
break
if message["type"] == "websocket.receive":
if "text" in message:
# Echo text back
await send({
"type": "websocket.send",
"text": f"Echo: {message['text']}"
})
elif "bytes" in message:
# Echo bytes back
await send({
"type": "websocket.send",
"bytes": message["bytes"]
})
except Exception as e:
print(f"WebSocket error: {e}")
finally:
try:
await send({"type": "websocket.close", "code": 1000})
except Exception:
pass
async def chat_websocket(scope, receive, send):
"""Chat WebSocket - simple broadcast example."""
message = await receive()
if message["type"] != "websocket.connect":
return
await send({
"type": "websocket.accept",
"subprotocol": "chat"
})
await send({
"type": "websocket.send",
"text": "Welcome to the chat! Send messages and they will be echoed back."
})
try:
while True:
message = await receive()
if message["type"] == "websocket.disconnect":
break
if message["type"] == "websocket.receive" and "text" in message:
text = message["text"]
await send({
"type": "websocket.send",
"text": f"[You]: {text}"
})
except Exception:
pass
HTML_PAGE = """<!DOCTYPE html>
<html>
<head>
<title>WebSocket Test</title>
<style>
body { font-family: sans-serif; max-width: 800px; margin: 50px auto; padding: 20px; }
#messages { border: 1px solid #ccc; height: 300px; overflow-y: auto; padding: 10px; margin-bottom: 10px; }
#input { width: 80%; padding: 10px; }
button { padding: 10px 20px; }
.sent { color: blue; }
.received { color: green; }
.error { color: red; }
</style>
</head>
<body>
<h1>WebSocket Test</h1>
<div id="messages"></div>
<input type="text" id="input" placeholder="Type a message...">
<button onclick="sendMessage()">Send</button>
<button onclick="connectWS()">Connect</button>
<button onclick="disconnectWS()">Disconnect</button>
<script>
let ws = null;
const messages = document.getElementById('messages');
const input = document.getElementById('input');
function log(msg, className) {
const div = document.createElement('div');
div.className = className || '';
div.textContent = msg;
messages.appendChild(div);
messages.scrollTop = messages.scrollHeight;
}
function connectWS() {
if (ws) {
log('Already connected', 'error');
return;
}
ws = new WebSocket('ws://' + window.location.host + '/ws');
ws.onopen = () => log('Connected!', 'received');
ws.onclose = () => { log('Disconnected', 'error'); ws = null; };
ws.onerror = (e) => log('Error: ' + e, 'error');
ws.onmessage = (e) => log(e.data, 'received');
}
function disconnectWS() {
if (ws) ws.close();
}
function sendMessage() {
if (!ws) { log('Not connected', 'error'); return; }
const msg = input.value;
if (!msg) return;
ws.send(msg);
log('Sent: ' + msg, 'sent');
input.value = '';
}
input.onkeypress = (e) => { if (e.key === 'Enter') sendMessage(); };
// Auto-connect
connectWS();
</script>
</body>
</html>
"""

26
gunicorn/asgi/__init__.py Normal file
View File

@ -0,0 +1,26 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
ASGI support for gunicorn.
This module provides native ASGI worker support, using gunicorn's own
HTTP parsing infrastructure adapted for async I/O.
Components:
- AsyncUnreader: Async socket reading with pushback buffer
- AsyncRequest: Async HTTP request parser
- ASGIProtocol: asyncio.Protocol implementation for HTTP handling
- WebSocketProtocol: WebSocket protocol handler (RFC 6455)
- LifespanManager: ASGI lifespan protocol support
Usage:
gunicorn -k asgi myapp:app
"""
from gunicorn.asgi.unreader import AsyncUnreader
from gunicorn.asgi.message import AsyncRequest
from gunicorn.asgi.lifespan import LifespanManager
__all__ = ['AsyncUnreader', 'AsyncRequest', 'LifespanManager']

178
gunicorn/asgi/lifespan.py Normal file
View File

@ -0,0 +1,178 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
ASGI lifespan protocol manager.
Manages startup and shutdown events for ASGI applications,
enabling frameworks like FastAPI to run initialization and
cleanup code.
"""
import asyncio
class LifespanManager:
"""Manages ASGI lifespan events (startup/shutdown).
The lifespan protocol allows ASGI applications to run code at
startup and shutdown. This is essential for applications that
need to initialize database connections, caches, or other
resources.
ASGI lifespan messages:
- Server sends: {"type": "lifespan.startup"}
- App responds: {"type": "lifespan.startup.complete"} or
{"type": "lifespan.startup.failed", "message": "..."}
- Server sends: {"type": "lifespan.shutdown"}
- App responds: {"type": "lifespan.shutdown.complete"}
"""
def __init__(self, app, logger, state=None):
"""Initialize the lifespan manager.
Args:
app: ASGI application callable
logger: Logger instance
state: Shared state dict for the application
"""
self.app = app
self.logger = logger
self.state = state if state is not None else {}
self._startup_complete = asyncio.Event()
self._shutdown_complete = asyncio.Event()
self._startup_failed = False
self._startup_error = None
self._shutdown_error = None
self._receive_queue = asyncio.Queue()
self._task = None
self._app_finished = False
async def startup(self):
"""Run lifespan startup and wait for completion.
Raises:
RuntimeError: If startup fails or app doesn't support lifespan
"""
scope = {
"type": "lifespan",
"asgi": {"version": "3.0", "spec_version": "2.4"},
"state": self.state,
}
# Send startup event
await self._receive_queue.put({"type": "lifespan.startup"})
# Run lifespan in background task
self._task = asyncio.create_task(self._run_lifespan(scope))
# Wait for startup with timeout
try:
await asyncio.wait_for(
self._startup_complete.wait(),
timeout=30.0 # Reasonable startup timeout
)
except asyncio.TimeoutError:
if self._task:
self._task.cancel()
raise RuntimeError("Lifespan startup timed out")
if self._startup_failed:
if self._task:
self._task.cancel()
msg = self._startup_error or "Unknown error"
raise RuntimeError(f"Lifespan startup failed: {msg}")
self.logger.debug("ASGI lifespan startup complete")
async def shutdown(self):
"""Signal shutdown and wait for completion.
This should be called during graceful shutdown.
"""
if self._app_finished:
self.logger.debug("ASGI lifespan already finished")
return
# Send shutdown event
await self._receive_queue.put({"type": "lifespan.shutdown"})
# Wait for shutdown with timeout
try:
await asyncio.wait_for(
self._shutdown_complete.wait(),
timeout=30.0 # Reasonable shutdown timeout
)
except asyncio.TimeoutError:
self.logger.warning("Lifespan shutdown timed out")
if self._shutdown_error:
self.logger.error("Lifespan shutdown error: %s", self._shutdown_error)
# Cancel the task if still running
if self._task and not self._task.done():
self._task.cancel()
try:
await self._task
except asyncio.CancelledError:
pass
self.logger.debug("ASGI lifespan shutdown complete")
async def _run_lifespan(self, scope):
"""Run the ASGI lifespan protocol."""
try:
await self.app(scope, self._receive, self._send)
except asyncio.CancelledError:
raise
except Exception as e:
self.logger.debug("Lifespan application raised: %s", e)
# If startup hasn't completed, mark it as failed
if not self._startup_complete.is_set():
self._startup_failed = True
self._startup_error = str(e)
self._startup_complete.set()
# If shutdown hasn't completed, mark error
elif not self._shutdown_complete.is_set():
self._shutdown_error = str(e)
self._shutdown_complete.set()
finally:
self._app_finished = True
# Ensure events are set to unblock waiters
if not self._startup_complete.is_set():
self._startup_failed = True
self._startup_error = "Application exited before startup complete"
self._startup_complete.set()
if not self._shutdown_complete.is_set():
self._shutdown_complete.set()
async def _receive(self):
"""ASGI receive callable for lifespan."""
return await self._receive_queue.get()
async def _send(self, message):
"""ASGI send callable for lifespan."""
msg_type = message["type"]
if msg_type == "lifespan.startup.complete":
self._startup_complete.set()
self.logger.debug("Received lifespan.startup.complete")
elif msg_type == "lifespan.startup.failed":
self._startup_failed = True
self._startup_error = message.get("message", "")
self._startup_complete.set()
self.logger.debug("Received lifespan.startup.failed: %s",
self._startup_error)
elif msg_type == "lifespan.shutdown.complete":
self._shutdown_complete.set()
self.logger.debug("Received lifespan.shutdown.complete")
elif msg_type == "lifespan.shutdown.failed":
self._shutdown_error = message.get("message", "")
self._shutdown_complete.set()
self.logger.debug("Received lifespan.shutdown.failed: %s",
self._shutdown_error)

562
gunicorn/asgi/message.py Normal file
View File

@ -0,0 +1,562 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
Async version of gunicorn/http/message.py for ASGI workers.
Reuses the parsing logic from the sync version, adapted for async I/O.
"""
import io
import re
import socket
from gunicorn.http.errors import (
InvalidHeader, InvalidHeaderName, NoMoreData,
InvalidRequestLine, InvalidRequestMethod, InvalidHTTPVersion,
LimitRequestLine, LimitRequestHeaders,
UnsupportedTransferCoding, ObsoleteFolding,
InvalidProxyLine, ForbiddenProxyRequest,
InvalidSchemeHeaders,
)
from gunicorn.util import bytes_to_str, split_request_uri
MAX_REQUEST_LINE = 8190
MAX_HEADERS = 32768
DEFAULT_MAX_HEADERFIELD_SIZE = 8190
# Reuse regex patterns from sync version
RFC9110_5_6_2_TOKEN_SPECIALS = r"!#$%&'*+-.^_`|~"
TOKEN_RE = re.compile(r"[%s0-9a-zA-Z]+" % (re.escape(RFC9110_5_6_2_TOKEN_SPECIALS)))
METHOD_BADCHAR_RE = re.compile("[a-z#]")
VERSION_RE = re.compile(r"HTTP/(\d)\.(\d)")
RFC9110_5_5_INVALID_AND_DANGEROUS = re.compile(r"[\0\r\n]")
class AsyncRequest:
"""Async HTTP request parser.
Parses HTTP/1.x requests using async I/O, reusing gunicorn's
parsing logic where possible.
"""
def __init__(self, cfg, unreader, peer_addr, req_number=1):
self.cfg = cfg
self.unreader = unreader
self.peer_addr = peer_addr
self.remote_addr = peer_addr
self.req_number = req_number
self.version = None
self.method = None
self.uri = None
self.path = None
self.query = None
self.fragment = None
self.headers = []
self.trailers = []
self.scheme = "https" if cfg.is_ssl else "http"
self.must_close = False
self.proxy_protocol_info = None
# Request line limit
self.limit_request_line = cfg.limit_request_line
if (self.limit_request_line < 0
or self.limit_request_line >= MAX_REQUEST_LINE):
self.limit_request_line = MAX_REQUEST_LINE
# Headers limits
self.limit_request_fields = cfg.limit_request_fields
if (self.limit_request_fields <= 0
or self.limit_request_fields > MAX_HEADERS):
self.limit_request_fields = MAX_HEADERS
self.limit_request_field_size = cfg.limit_request_field_size
if self.limit_request_field_size < 0:
self.limit_request_field_size = DEFAULT_MAX_HEADERFIELD_SIZE
# Max header buffer size
max_header_field_size = self.limit_request_field_size or DEFAULT_MAX_HEADERFIELD_SIZE
self.max_buffer_headers = self.limit_request_fields * \
(max_header_field_size + 2) + 4
# Body-related state
self.content_length = None
self.chunked = False
self._body_reader = None
self._body_remaining = 0
@classmethod
async def parse(cls, cfg, unreader, peer_addr, req_number=1):
"""Parse an HTTP request from the stream.
Args:
cfg: gunicorn config object
unreader: AsyncUnreader instance
peer_addr: client address tuple
req_number: request number on this connection (for keepalive)
Returns:
AsyncRequest: Parsed request object
Raises:
NoMoreData: If no data available
Various parsing errors for malformed requests
"""
req = cls(cfg, unreader, peer_addr, req_number)
await req._parse()
return req
async def _parse(self):
"""Parse the request from the unreader."""
buf = io.BytesIO()
await self._get_data(buf, stop=True)
# Get request line
line, rbuf = await self._read_line(buf, self.limit_request_line)
# Proxy protocol
if self._proxy_protocol(bytes_to_str(line)):
# Get next request line
buf = io.BytesIO()
buf.write(rbuf)
line, rbuf = await self._read_line(buf, self.limit_request_line)
self._parse_request_line(line)
buf = io.BytesIO()
buf.write(rbuf)
# Headers
data = buf.getvalue()
while True:
idx = data.find(b"\r\n\r\n")
done = data[:2] == b"\r\n"
if idx < 0 and not done:
await self._get_data(buf)
data = buf.getvalue()
if len(data) > self.max_buffer_headers:
raise LimitRequestHeaders("max buffer headers")
else:
break
if done:
self.unreader.unread(data[2:])
else:
self.headers = self._parse_headers(data[:idx], from_trailer=False)
self.unreader.unread(data[idx + 4:])
self._set_body_reader()
async def _get_data(self, buf, stop=False):
"""Read data from unreader into buffer."""
data = await self.unreader.read()
if not data:
if stop:
raise StopIteration()
raise NoMoreData(buf.getvalue())
buf.write(data)
async def _read_line(self, buf, limit=0):
"""Read a line from the buffer/stream."""
data = buf.getvalue()
while True:
idx = data.find(b"\r\n")
if idx >= 0:
if idx > limit > 0:
raise LimitRequestLine(idx, limit)
break
if len(data) - 2 > limit > 0:
raise LimitRequestLine(len(data), limit)
await self._get_data(buf)
data = buf.getvalue()
return (data[:idx], data[idx + 2:])
def _proxy_protocol(self, line):
"""Detect, check and parse proxy protocol."""
if not self.cfg.proxy_protocol:
return False
if self.req_number != 1:
return False
if not line.startswith("PROXY"):
return False
self._proxy_protocol_access_check()
self._parse_proxy_protocol(line)
return True
def _proxy_protocol_access_check(self):
"""Check if proxy protocol is allowed from this peer."""
if ("*" not in self.cfg.proxy_allow_ips and
isinstance(self.peer_addr, tuple) and
self.peer_addr[0] not in self.cfg.proxy_allow_ips):
raise ForbiddenProxyRequest(self.peer_addr[0])
def _parse_proxy_protocol(self, line):
"""Parse proxy protocol header line."""
bits = line.split(" ")
if len(bits) != 6:
raise InvalidProxyLine(line)
proto = bits[1]
s_addr = bits[2]
d_addr = bits[3]
if proto not in ["TCP4", "TCP6"]:
raise InvalidProxyLine("protocol '%s' not supported" % proto)
if proto == "TCP4":
try:
socket.inet_pton(socket.AF_INET, s_addr)
socket.inet_pton(socket.AF_INET, d_addr)
except OSError:
raise InvalidProxyLine(line)
elif proto == "TCP6":
try:
socket.inet_pton(socket.AF_INET6, s_addr)
socket.inet_pton(socket.AF_INET6, d_addr)
except OSError:
raise InvalidProxyLine(line)
try:
s_port = int(bits[4])
d_port = int(bits[5])
except ValueError:
raise InvalidProxyLine("invalid port %s" % line)
if not ((0 <= s_port <= 65535) and (0 <= d_port <= 65535)):
raise InvalidProxyLine("invalid port %s" % line)
self.proxy_protocol_info = {
"proxy_protocol": proto,
"client_addr": s_addr,
"client_port": s_port,
"proxy_addr": d_addr,
"proxy_port": d_port
}
def _parse_request_line(self, line_bytes):
"""Parse the HTTP request line."""
bits = [bytes_to_str(bit) for bit in line_bytes.split(b" ", 2)]
if len(bits) != 3:
raise InvalidRequestLine(bytes_to_str(line_bytes))
# Method
self.method = bits[0]
if not self.cfg.permit_unconventional_http_method:
if METHOD_BADCHAR_RE.search(self.method):
raise InvalidRequestMethod(self.method)
if not 3 <= len(bits[0]) <= 20:
raise InvalidRequestMethod(self.method)
if not TOKEN_RE.fullmatch(self.method):
raise InvalidRequestMethod(self.method)
if self.cfg.casefold_http_method:
self.method = self.method.upper()
# URI
self.uri = bits[1]
if len(self.uri) == 0:
raise InvalidRequestLine(bytes_to_str(line_bytes))
try:
parts = split_request_uri(self.uri)
except ValueError:
raise InvalidRequestLine(bytes_to_str(line_bytes))
self.path = parts.path or ""
self.query = parts.query or ""
self.fragment = parts.fragment or ""
# Version
match = VERSION_RE.fullmatch(bits[2])
if match is None:
raise InvalidHTTPVersion(bits[2])
self.version = (int(match.group(1)), int(match.group(2)))
if not (1, 0) <= self.version < (2, 0):
if not self.cfg.permit_unconventional_http_version:
raise InvalidHTTPVersion(self.version)
def _parse_headers(self, data, from_trailer=False):
"""Parse HTTP headers from raw data."""
cfg = self.cfg
headers = []
lines = [bytes_to_str(line) for line in data.split(b"\r\n")]
# Handle scheme headers
scheme_header = False
secure_scheme_headers = {}
forwarder_headers = []
if from_trailer:
pass
elif ('*' in cfg.forwarded_allow_ips or
not isinstance(self.peer_addr, tuple)
or self.peer_addr[0] in cfg.forwarded_allow_ips):
secure_scheme_headers = cfg.secure_scheme_headers
forwarder_headers = cfg.forwarder_headers
while lines:
if len(headers) >= self.limit_request_fields:
raise LimitRequestHeaders("limit request headers fields")
curr = lines.pop(0)
header_length = len(curr) + len("\r\n")
if curr.find(":") <= 0:
raise InvalidHeader(curr)
name, value = curr.split(":", 1)
if self.cfg.strip_header_spaces:
name = name.rstrip(" \t")
if not TOKEN_RE.fullmatch(name):
raise InvalidHeaderName(name)
name = name.upper()
value = [value.strip(" \t")]
# Consume value continuation lines
while lines and lines[0].startswith((" ", "\t")):
if not self.cfg.permit_obsolete_folding:
raise ObsoleteFolding(name)
curr = lines.pop(0)
header_length += len(curr) + len("\r\n")
if header_length > self.limit_request_field_size > 0:
raise LimitRequestHeaders("limit request headers fields size")
value.append(curr.strip("\t "))
value = " ".join(value)
if RFC9110_5_5_INVALID_AND_DANGEROUS.search(value):
raise InvalidHeader(name)
if header_length > self.limit_request_field_size > 0:
raise LimitRequestHeaders("limit request headers fields size")
if name in secure_scheme_headers:
secure = value == secure_scheme_headers[name]
scheme = "https" if secure else "http"
if scheme_header:
if scheme != self.scheme:
raise InvalidSchemeHeaders()
else:
scheme_header = True
self.scheme = scheme
if "_" in name:
if name in forwarder_headers or "*" in forwarder_headers:
pass
elif self.cfg.header_map == "dangerous":
pass
elif self.cfg.header_map == "drop":
continue
else:
raise InvalidHeaderName(name)
headers.append((name, value))
return headers
def _set_body_reader(self):
"""Determine how to read the request body."""
chunked = False
content_length = None
for (name, value) in self.headers:
if name == "CONTENT-LENGTH":
if content_length is not None:
raise InvalidHeader("CONTENT-LENGTH", req=self)
content_length = value
elif name == "TRANSFER-ENCODING":
vals = [v.strip() for v in value.split(',')]
for val in vals:
if val.lower() == "chunked":
if chunked:
raise InvalidHeader("TRANSFER-ENCODING", req=self)
chunked = True
elif val.lower() == "identity":
if chunked:
raise InvalidHeader("TRANSFER-ENCODING", req=self)
elif val.lower() in ('compress', 'deflate', 'gzip'):
if chunked:
raise InvalidHeader("TRANSFER-ENCODING", req=self)
self.force_close()
else:
raise UnsupportedTransferCoding(value)
if chunked:
if self.version < (1, 1):
raise InvalidHeader("TRANSFER-ENCODING", req=self)
if content_length is not None:
raise InvalidHeader("CONTENT-LENGTH", req=self)
self.chunked = True
self.content_length = None
self._body_remaining = -1
elif content_length is not None:
try:
if str(content_length).isnumeric():
content_length = int(content_length)
else:
raise InvalidHeader("CONTENT-LENGTH", req=self)
except ValueError:
raise InvalidHeader("CONTENT-LENGTH", req=self)
if content_length < 0:
raise InvalidHeader("CONTENT-LENGTH", req=self)
self.content_length = content_length
self._body_remaining = content_length
else:
# No body for requests without Content-Length or Transfer-Encoding
self.content_length = 0
self._body_remaining = 0
def force_close(self):
"""Mark connection for closing after this request."""
self.must_close = True
def should_close(self):
"""Check if connection should be closed after this request."""
if self.must_close:
return True
for (h, v) in self.headers:
if h == "CONNECTION":
v = v.lower().strip(" \t")
if v == "close":
return True
elif v == "keep-alive":
return False
break
return self.version <= (1, 0)
def get_header(self, name):
"""Get a header value by name (case-insensitive)."""
name = name.upper()
for (h, v) in self.headers:
if h == name:
return v
return None
async def read_body(self, size=8192):
"""Read a chunk of the request body.
Args:
size: Maximum bytes to read
Returns:
bytes: Body data, empty bytes when body is exhausted
"""
if self._body_remaining == 0:
return b""
if self.chunked:
return await self._read_chunked_body(size)
else:
return await self._read_length_body(size)
async def _read_length_body(self, size):
"""Read from a length-delimited body."""
if self._body_remaining <= 0:
return b""
to_read = min(size, self._body_remaining)
data = await self.unreader.read(to_read)
if data:
self._body_remaining -= len(data)
return data
async def _read_chunked_body(self, size):
"""Read from a chunked body."""
if self._body_reader is None:
self._body_reader = self._chunked_body_reader()
try:
return await anext(self._body_reader)
except StopAsyncIteration:
self._body_remaining = 0
return b""
async def _chunked_body_reader(self):
"""Async generator for reading chunked body."""
while True:
# Read chunk size line
size_line = await self._read_chunk_size_line()
# Parse chunk size (handle extensions)
chunk_size, *_ = size_line.split(b";", 1)
if _:
chunk_size = chunk_size.rstrip(b" \t")
if any(n not in b"0123456789abcdefABCDEF" for n in chunk_size):
raise InvalidHeader("Invalid chunk size")
if len(chunk_size) == 0:
raise InvalidHeader("Invalid chunk size")
chunk_size = int(chunk_size, 16)
if chunk_size == 0:
# Final chunk - skip trailers and final CRLF
await self._skip_trailers()
return
# Read chunk data
remaining = chunk_size
while remaining > 0:
data = await self.unreader.read(min(remaining, 8192))
if not data:
raise NoMoreData()
remaining -= len(data)
yield data
# Skip chunk terminating CRLF
crlf = await self.unreader.read(2)
if crlf != b"\r\n":
# May have partial read, try to get the rest
while len(crlf) < 2:
more = await self.unreader.read(2 - len(crlf))
if not more:
break
crlf += more
if crlf != b"\r\n":
raise InvalidHeader("Missing chunk terminator")
async def _read_chunk_size_line(self):
"""Read a chunk size line."""
buf = io.BytesIO()
while True:
data = await self.unreader.read(1)
if not data:
raise NoMoreData()
buf.write(data)
if buf.getvalue().endswith(b"\r\n"):
return buf.getvalue()[:-2]
async def _skip_trailers(self):
"""Skip trailer headers after chunked body."""
buf = io.BytesIO()
while True:
data = await self.unreader.read(1)
if not data:
return
buf.write(data)
content = buf.getvalue()
if content.endswith(b"\r\n\r\n"):
# Could parse trailers here if needed
return
if content == b"\r\n":
return
async def drain_body(self):
"""Drain any unread body data.
Should be called before reusing connection for keepalive.
"""
while True:
data = await self.read_body(8192)
if not data:
break

470
gunicorn/asgi/protocol.py Normal file
View File

@ -0,0 +1,470 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
ASGI protocol handler for gunicorn.
Implements asyncio.Protocol to handle HTTP/1.x connections and dispatch
to ASGI applications.
"""
import asyncio
from datetime import datetime
from gunicorn.asgi.unreader import AsyncUnreader
from gunicorn.asgi.message import AsyncRequest
from gunicorn.http.errors import NoMoreData
class ASGIResponseInfo:
"""Simple container for ASGI response info for access logging."""
def __init__(self, status, headers, sent):
self.status = status
self.sent = sent
# Convert headers to list of string tuples for logging
self.headers = []
for name, value in headers:
if isinstance(name, bytes):
name = name.decode("latin-1")
if isinstance(value, bytes):
value = value.decode("latin-1")
self.headers.append((name, value))
class ASGIProtocol(asyncio.Protocol):
"""HTTP/1.1 protocol handler for ASGI applications.
Handles connection lifecycle, request parsing, and ASGI app invocation.
"""
def __init__(self, worker):
self.worker = worker
self.cfg = worker.cfg
self.log = worker.log
self.app = worker.asgi
self.transport = None
self.reader = None
self.writer = None
self._task = None
self.req_count = 0
# Connection state
self._closed = False
def connection_made(self, transport):
"""Called when a connection is established."""
self.transport = transport
self.worker.nr_conns += 1
# Create stream reader/writer
self.reader = asyncio.StreamReader()
self.writer = transport
# Start handling requests
self._task = self.worker.loop.create_task(self._handle_connection())
def data_received(self, data):
"""Called when data is received on the connection."""
if self.reader:
self.reader.feed_data(data)
def connection_lost(self, exc):
"""Called when the connection is lost or closed."""
self._closed = True
self.worker.nr_conns -= 1
if self.reader:
self.reader.feed_eof()
if self._task and not self._task.done():
self._task.cancel()
async def _handle_connection(self):
"""Main request handling loop for this connection."""
unreader = AsyncUnreader(self.reader)
try:
peername = self.transport.get_extra_info('peername')
sockname = self.transport.get_extra_info('sockname')
while not self._closed:
self.req_count += 1
try:
# Parse HTTP request
request = await AsyncRequest.parse(
self.cfg,
unreader,
peername,
self.req_count
)
except StopIteration:
# No more data, close connection
break
except NoMoreData:
# Client disconnected
break
# Check for WebSocket upgrade
if self._is_websocket_upgrade(request):
await self._handle_websocket(request, sockname, peername)
break # WebSocket takes over the connection
# Handle HTTP request
keepalive = await self._handle_http_request(
request, sockname, peername
)
# Increment worker request count
self.worker.nr += 1
# Check max_requests
if self.worker.nr >= self.worker.max_requests:
self.log.info("Autorestarting worker after current request.")
self.worker.alive = False
keepalive = False
if not keepalive or not self.worker.alive:
break
# Check connection limits for keepalive
if not self.cfg.keepalive:
break
# Drain any unread body before next request
await request.drain_body()
except asyncio.CancelledError:
pass
except Exception as e:
self.log.exception("Error handling connection: %s", e)
finally:
self._close_transport()
def _is_websocket_upgrade(self, request):
"""Check if request is a WebSocket upgrade.
Per RFC 6455 Section 4.1, the opening handshake requires:
- HTTP method MUST be GET
- Upgrade header MUST be "websocket" (case-insensitive)
- Connection header MUST contain "Upgrade"
"""
# RFC 6455: The method of the request MUST be GET
if request.method != "GET":
return False
upgrade = None
connection = None
for name, value in request.headers:
if name == "UPGRADE":
upgrade = value.lower()
elif name == "CONNECTION":
connection = value.lower()
return upgrade == "websocket" and connection and "upgrade" in connection
async def _handle_websocket(self, request, sockname, peername):
"""Handle WebSocket upgrade request."""
from gunicorn.asgi.websocket import WebSocketProtocol
scope = self._build_websocket_scope(request, sockname, peername)
ws_protocol = WebSocketProtocol(
self.transport, self.reader, scope, self.app, self.log
)
await ws_protocol.run()
async def _handle_http_request(self, request, sockname, peername):
"""Handle a single HTTP request."""
scope = self._build_http_scope(request, sockname, peername)
response_started = False
response_complete = False
exc_to_raise = None
# Response tracking for access logging
response_status = 500
response_headers = []
response_sent = 0
# Receive queue for body
receive_queue = asyncio.Queue()
# Pre-populate with initial body state
if request.content_length == 0 and not request.chunked:
await receive_queue.put({
"type": "http.request",
"body": b"",
"more_body": False,
})
else:
# Start body reading task
asyncio.create_task(self._read_body_to_queue(request, receive_queue))
async def receive():
return await receive_queue.get()
async def send(message):
nonlocal response_started, response_complete, exc_to_raise
nonlocal response_status, response_headers, response_sent
msg_type = message["type"]
if msg_type == "http.response.start":
if response_started:
exc_to_raise = RuntimeError("Response already started")
return
response_started = True
response_status = message["status"]
response_headers = message.get("headers", [])
await self._send_response_start(response_status, response_headers, request)
elif msg_type == "http.response.body":
if not response_started:
exc_to_raise = RuntimeError("Response not started")
return
if response_complete:
exc_to_raise = RuntimeError("Response already complete")
return
body = message.get("body", b"")
more_body = message.get("more_body", False)
if body:
await self._send_body(body)
response_sent += len(body)
if not more_body:
response_complete = True
# Build environ for logging
environ = self._build_environ(request, sockname, peername)
resp = None
try:
request_start = datetime.now()
self.cfg.pre_request(self.worker, request)
await self.app(scope, receive, send)
if exc_to_raise is not None:
raise exc_to_raise
# Ensure response was sent
if not response_started:
await self._send_error_response(500, "Internal Server Error")
response_status = 500
except Exception:
self.log.exception("Error in ASGI application")
if not response_started:
await self._send_error_response(500, "Internal Server Error")
response_status = 500
return False
finally:
try:
request_time = datetime.now() - request_start
# Create response info for logging
resp = ASGIResponseInfo(response_status, response_headers, response_sent)
self.log.access(resp, request, environ, request_time)
self.cfg.post_request(self.worker, request, environ, resp)
except Exception:
self.log.exception("Exception in post_request hook")
# Determine keepalive
if request.should_close():
return False
return self.worker.alive and self.cfg.keepalive
async def _read_body_to_queue(self, request, queue):
"""Read request body and put chunks on the queue."""
try:
while True:
chunk = await request.read_body(65536)
if chunk:
await queue.put({
"type": "http.request",
"body": chunk,
"more_body": True,
})
else:
await queue.put({
"type": "http.request",
"body": b"",
"more_body": False,
})
break
except Exception as e:
self.log.debug("Error reading body: %s", e)
await queue.put({
"type": "http.request",
"body": b"",
"more_body": False,
})
def _build_http_scope(self, request, sockname, peername):
"""Build ASGI HTTP scope from parsed request."""
# Build headers list as bytes tuples
headers = []
for name, value in request.headers:
headers.append((name.lower().encode("latin-1"), value.encode("latin-1")))
scope = {
"type": "http",
"asgi": {"version": "3.0", "spec_version": "2.4"},
"http_version": f"{request.version[0]}.{request.version[1]}",
"method": request.method,
"scheme": request.scheme,
"path": request.path,
"raw_path": request.path.encode("latin-1") if request.path else b"",
"query_string": request.query.encode("latin-1") if request.query else b"",
"root_path": self.cfg.root_path or "",
"headers": headers,
"server": sockname if sockname else None,
"client": peername if peername else None,
}
# Add state dict for lifespan sharing
if hasattr(self.worker, 'state'):
scope["state"] = self.worker.state
return scope
def _build_environ(self, request, sockname, peername):
"""Build minimal WSGI-like environ dict for access logging."""
environ = {
"REQUEST_METHOD": request.method,
"RAW_URI": request.uri,
"PATH_INFO": request.path,
"QUERY_STRING": request.query or "",
"SERVER_PROTOCOL": f"HTTP/{request.version[0]}.{request.version[1]}",
"REMOTE_ADDR": peername[0] if peername else "-",
}
# Add HTTP headers as environ vars
for name, value in request.headers:
key = "HTTP_" + name.replace("-", "_")
environ[key] = value
return environ
def _build_websocket_scope(self, request, sockname, peername):
"""Build ASGI WebSocket scope from parsed request."""
# Build headers list as bytes tuples
headers = []
for name, value in request.headers:
headers.append((name.lower().encode("latin-1"), value.encode("latin-1")))
# Extract subprotocols from Sec-WebSocket-Protocol header
subprotocols = []
for name, value in request.headers:
if name == "SEC-WEBSOCKET-PROTOCOL":
subprotocols = [s.strip() for s in value.split(",")]
break
scope = {
"type": "websocket",
"asgi": {"version": "3.0", "spec_version": "2.4"},
"http_version": f"{request.version[0]}.{request.version[1]}",
"scheme": "wss" if request.scheme == "https" else "ws",
"path": request.path,
"raw_path": request.path.encode("latin-1") if request.path else b"",
"query_string": request.query.encode("latin-1") if request.query else b"",
"root_path": self.cfg.root_path or "",
"headers": headers,
"server": sockname if sockname else None,
"client": peername if peername else None,
"subprotocols": subprotocols,
}
# Add state dict for lifespan sharing
if hasattr(self.worker, 'state'):
scope["state"] = self.worker.state
return scope
async def _send_response_start(self, status, headers, request):
"""Send HTTP response status and headers."""
# Build status line
reason = self._get_reason_phrase(status)
status_line = f"HTTP/{request.version[0]}.{request.version[1]} {status} {reason}\r\n"
# Build headers
header_lines = []
for name, value in headers:
if isinstance(name, bytes):
name = name.decode("latin-1")
if isinstance(value, bytes):
value = value.decode("latin-1")
header_lines.append(f"{name}: {value}\r\n")
# Add server header if not present
header_lines.append("Server: gunicorn/asgi\r\n")
response = status_line + "".join(header_lines) + "\r\n"
self.transport.write(response.encode("latin-1"))
async def _send_body(self, body):
"""Send response body chunk."""
if body:
self.transport.write(body)
async def _send_error_response(self, status, message):
"""Send an error response."""
body = message.encode("utf-8")
response = (
f"HTTP/1.1 {status} {message}\r\n"
f"Content-Type: text/plain\r\n"
f"Content-Length: {len(body)}\r\n"
f"Connection: close\r\n"
f"\r\n"
)
self.transport.write(response.encode("latin-1"))
self.transport.write(body)
def _get_reason_phrase(self, status):
"""Get HTTP reason phrase for status code."""
reasons = {
100: "Continue",
101: "Switching Protocols",
200: "OK",
201: "Created",
202: "Accepted",
204: "No Content",
206: "Partial Content",
301: "Moved Permanently",
302: "Found",
303: "See Other",
304: "Not Modified",
307: "Temporary Redirect",
308: "Permanent Redirect",
400: "Bad Request",
401: "Unauthorized",
403: "Forbidden",
404: "Not Found",
405: "Method Not Allowed",
408: "Request Timeout",
409: "Conflict",
410: "Gone",
411: "Length Required",
413: "Payload Too Large",
414: "URI Too Long",
415: "Unsupported Media Type",
422: "Unprocessable Entity",
429: "Too Many Requests",
500: "Internal Server Error",
501: "Not Implemented",
502: "Bad Gateway",
503: "Service Unavailable",
504: "Gateway Timeout",
}
return reasons.get(status, "Unknown")
def _close_transport(self):
"""Close the transport safely."""
if self.transport and not self._closed:
try:
self.transport.close()
except Exception:
pass
self._closed = True

100
gunicorn/asgi/unreader.py Normal file
View File

@ -0,0 +1,100 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
Async version of gunicorn/http/unreader.py for ASGI workers.
Provides async reading with pushback buffer support.
"""
import io
class AsyncUnreader:
"""Async socket reader with pushback buffer support.
This class wraps an asyncio StreamReader and provides the ability
to "unread" data back into a buffer for re-parsing.
"""
def __init__(self, reader, max_chunk=8192):
"""Initialize the async unreader.
Args:
reader: asyncio.StreamReader instance
max_chunk: Maximum bytes to read at once
"""
self.reader = reader
self.buf = io.BytesIO()
self.max_chunk = max_chunk
async def read(self, size=None):
"""Read data from the stream, using buffered data first.
Args:
size: Number of bytes to read. If None, returns all buffered
data or reads a single chunk.
Returns:
bytes: Data read from buffer or stream
"""
if size is not None and not isinstance(size, int):
raise TypeError("size parameter must be an int or long.")
if size is not None:
if size == 0:
return b""
if size < 0:
size = None
# Move to end to check buffer size
self.buf.seek(0, io.SEEK_END)
# If no size specified, return buffered data or read chunk
if size is None and self.buf.tell():
ret = self.buf.getvalue()
self.buf = io.BytesIO()
return ret
if size is None:
chunk = await self._read_chunk()
return chunk
# Read until we have enough data
while self.buf.tell() < size:
chunk = await self._read_chunk()
if not chunk:
ret = self.buf.getvalue()
self.buf = io.BytesIO()
return ret
self.buf.write(chunk)
data = self.buf.getvalue()
self.buf = io.BytesIO()
self.buf.write(data[size:])
return data[:size]
async def _read_chunk(self):
"""Read a chunk of data from the underlying stream."""
try:
return await self.reader.read(self.max_chunk)
except Exception:
return b""
def unread(self, data):
"""Push data back into the buffer for re-reading.
Args:
data: bytes to push back
"""
if data:
self.buf.seek(0, io.SEEK_END)
self.buf.write(data)
def has_buffered_data(self):
"""Check if there's data in the pushback buffer."""
pos = self.buf.tell()
self.buf.seek(0, io.SEEK_END)
has_data = self.buf.tell() > 0
self.buf.seek(pos)
return has_data

368
gunicorn/asgi/websocket.py Normal file
View File

@ -0,0 +1,368 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
WebSocket protocol handler for ASGI.
Implements RFC 6455 WebSocket protocol for ASGI applications.
"""
import asyncio
import base64
import hashlib
import struct
# WebSocket frame opcodes
OPCODE_CONTINUATION = 0x0
OPCODE_TEXT = 0x1
OPCODE_BINARY = 0x2
OPCODE_CLOSE = 0x8
OPCODE_PING = 0x9
OPCODE_PONG = 0xA
# WebSocket close codes
CLOSE_NORMAL = 1000
CLOSE_GOING_AWAY = 1001
CLOSE_PROTOCOL_ERROR = 1002
CLOSE_UNSUPPORTED = 1003
CLOSE_NO_STATUS = 1005
CLOSE_ABNORMAL = 1006
CLOSE_INVALID_DATA = 1007
CLOSE_POLICY_VIOLATION = 1008
CLOSE_MESSAGE_TOO_BIG = 1009
CLOSE_MANDATORY_EXT = 1010
CLOSE_INTERNAL_ERROR = 1011
# WebSocket handshake GUID (RFC 6455)
WS_GUID = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
class WebSocketProtocol:
"""WebSocket connection handler for ASGI applications."""
def __init__(self, transport, reader, scope, app, log):
"""Initialize WebSocket protocol handler.
Args:
transport: asyncio transport for writing
reader: asyncio StreamReader for reading
scope: ASGI WebSocket scope dict
app: ASGI application callable
log: Logger instance
"""
self.transport = transport
self.reader = reader
self.scope = scope
self.app = app
self.log = log
self.accepted = False
self.closed = False
self.close_code = None
self.close_reason = ""
# Message reassembly state
self._fragments = []
self._fragment_opcode = None
# Receive queue for incoming messages
self._receive_queue = asyncio.Queue()
async def run(self):
"""Run the WebSocket ASGI application."""
# Send initial connect event
await self._receive_queue.put({"type": "websocket.connect"})
# Start frame reading task
read_task = asyncio.create_task(self._read_frames())
try:
await self.app(self.scope, self._receive, self._send)
except Exception:
self.log.exception("Error in WebSocket ASGI application")
finally:
read_task.cancel()
try:
await read_task
except asyncio.CancelledError:
pass
# Send close frame if not already closed
if not self.closed and self.accepted:
await self._send_close(CLOSE_INTERNAL_ERROR, "Application error")
async def _receive(self):
"""ASGI receive callable."""
return await self._receive_queue.get()
async def _send(self, message):
"""ASGI send callable."""
msg_type = message["type"]
if msg_type == "websocket.accept":
if self.accepted:
raise RuntimeError("WebSocket already accepted")
await self._send_accept(message)
self.accepted = True
elif msg_type == "websocket.send":
if not self.accepted:
raise RuntimeError("WebSocket not accepted")
if self.closed:
raise RuntimeError("WebSocket closed")
if "text" in message:
await self._send_frame(OPCODE_TEXT, message["text"].encode("utf-8"))
elif "bytes" in message:
await self._send_frame(OPCODE_BINARY, message["bytes"])
elif msg_type == "websocket.close":
code = message.get("code", CLOSE_NORMAL)
reason = message.get("reason", "")
await self._send_close(code, reason)
self.closed = True
async def _send_accept(self, message):
"""Send WebSocket handshake accept response."""
# Get Sec-WebSocket-Key from headers
ws_key = None
for name, value in self.scope["headers"]:
if name == b"sec-websocket-key":
ws_key = value
break
if not ws_key:
raise RuntimeError("Missing Sec-WebSocket-Key header")
# Calculate accept key
accept_key = base64.b64encode(
hashlib.sha1(ws_key + WS_GUID).digest()
).decode("ascii")
# Build response headers
headers = [
"HTTP/1.1 101 Switching Protocols\r\n",
"Upgrade: websocket\r\n",
"Connection: Upgrade\r\n",
f"Sec-WebSocket-Accept: {accept_key}\r\n",
]
# Add selected subprotocol if specified
subprotocol = message.get("subprotocol")
if subprotocol:
headers.append(f"Sec-WebSocket-Protocol: {subprotocol}\r\n")
# Add any extra headers from message
extra_headers = message.get("headers", [])
for name, value in extra_headers:
if isinstance(name, bytes):
name = name.decode("latin-1")
if isinstance(value, bytes):
value = value.decode("latin-1")
headers.append(f"{name}: {value}\r\n")
headers.append("\r\n")
self.transport.write("".join(headers).encode("latin-1"))
async def _read_frames(self):
"""Read and process incoming WebSocket frames."""
try:
while not self.closed:
frame = await self._read_frame()
if frame is None:
break
opcode, payload = frame
if opcode == OPCODE_CLOSE:
await self._handle_close(payload)
break
if opcode == OPCODE_PING:
await self._send_frame(OPCODE_PONG, payload)
elif opcode == OPCODE_PONG:
# Ignore pongs
pass
elif opcode == OPCODE_TEXT:
await self._receive_queue.put({
"type": "websocket.receive",
"text": payload.decode("utf-8"),
})
elif opcode == OPCODE_BINARY:
await self._receive_queue.put({
"type": "websocket.receive",
"bytes": payload,
})
elif opcode == OPCODE_CONTINUATION:
# Handle fragmented messages
await self._handle_continuation(payload)
except asyncio.CancelledError:
raise
except Exception as e:
self.log.debug("WebSocket read error: %s", e)
finally:
# Signal disconnect
if not self.closed:
self.closed = True
await self._receive_queue.put({
"type": "websocket.disconnect",
"code": self.close_code or CLOSE_ABNORMAL,
})
async def _read_frame(self): # pylint: disable=too-many-return-statements
"""Read a single WebSocket frame.
Returns:
tuple: (opcode, payload) or None if connection closed
"""
# Read frame header (2 bytes minimum)
header = await self._read_exact(2)
if not header:
return None
first_byte, second_byte = header[0], header[1]
fin = (first_byte >> 7) & 1
rsv1 = (first_byte >> 6) & 1
rsv2 = (first_byte >> 5) & 1
rsv3 = (first_byte >> 4) & 1
opcode = first_byte & 0x0F
# RSV bits must be 0 (no extensions)
if rsv1 or rsv2 or rsv3:
await self._send_close(CLOSE_PROTOCOL_ERROR, "RSV bits set")
return None
masked = (second_byte >> 7) & 1
payload_len = second_byte & 0x7F
# Client frames must be masked (RFC 6455)
if not masked:
await self._send_close(CLOSE_PROTOCOL_ERROR, "Frame not masked")
return None
# Extended payload length
if payload_len == 126:
ext_len = await self._read_exact(2)
if not ext_len:
return None
payload_len = struct.unpack("!H", ext_len)[0]
elif payload_len == 127:
ext_len = await self._read_exact(8)
if not ext_len:
return None
payload_len = struct.unpack("!Q", ext_len)[0]
# Read masking key
masking_key = await self._read_exact(4)
if not masking_key:
return None
# Read payload
payload = await self._read_exact(payload_len)
if payload is None:
return None
# Unmask payload
payload = self._unmask(payload, masking_key)
# Handle fragmented messages
if opcode == OPCODE_CONTINUATION:
if self._fragment_opcode is None:
await self._send_close(CLOSE_PROTOCOL_ERROR, "Unexpected continuation")
return None
self._fragments.append(payload)
if fin:
# Reassemble complete message
full_payload = b"".join(self._fragments)
final_opcode = self._fragment_opcode
self._fragments = []
self._fragment_opcode = None
return (final_opcode, full_payload)
return (OPCODE_CONTINUATION, b"") # Fragment received, wait for more
elif opcode in (OPCODE_TEXT, OPCODE_BINARY):
if not fin:
# Start of fragmented message
self._fragment_opcode = opcode
self._fragments = [payload]
return (OPCODE_CONTINUATION, b"") # Fragment started, wait for more
return (opcode, payload)
else:
# Control frames
return (opcode, payload)
async def _read_exact(self, n):
"""Read exactly n bytes from the reader."""
try:
data = await self.reader.readexactly(n)
return data
except asyncio.IncompleteReadError:
return None
except Exception:
return None
def _unmask(self, payload, masking_key):
"""Unmask WebSocket payload data."""
if not payload:
return payload
# XOR each byte with corresponding mask byte
return bytes(b ^ masking_key[i % 4] for i, b in enumerate(payload))
async def _handle_close(self, payload):
"""Handle incoming close frame."""
if len(payload) >= 2:
self.close_code = struct.unpack("!H", payload[:2])[0]
self.close_reason = payload[2:].decode("utf-8", errors="replace")
else:
self.close_code = CLOSE_NO_STATUS
self.close_reason = ""
# Echo close frame back if we haven't already sent one
if not self.closed:
await self._send_close(self.close_code, self.close_reason)
self.closed = True
async def _handle_continuation(self, payload): # pylint: disable=unused-argument
"""Handle continuation frame (already processed in _read_frame)."""
# This is called for partial fragments, nothing to do here
async def _send_frame(self, opcode, payload):
"""Send a WebSocket frame.
Server frames are not masked (RFC 6455).
"""
if isinstance(payload, str):
payload = payload.encode("utf-8")
length = len(payload)
frame = bytearray()
# First byte: FIN + opcode
frame.append(0x80 | opcode)
# Second byte: length (no mask bit for server)
if length < 126:
frame.append(length)
elif length < 65536:
frame.append(126)
frame.extend(struct.pack("!H", length))
else:
frame.append(127)
frame.extend(struct.pack("!Q", length))
# Payload
frame.extend(payload)
self.transport.write(bytes(frame))
async def _send_close(self, code, reason=""):
"""Send a close frame."""
payload = struct.pack("!H", code)
if reason:
payload += reason.encode("utf-8")[:123] # Max 125 bytes total
await self._send_frame(OPCODE_CLOSE, payload)
self.closed = True

View File

@ -2096,6 +2096,53 @@ class ProxyAllowFrom(Setting):
"""
class Protocol(Setting):
name = "protocol"
section = "Server Mechanics"
cli = ["--protocol"]
meta = "STRING"
validator = validate_string
default = "http"
desc = """\
The protocol for incoming connections.
* ``http`` - Standard HTTP/1.x (default)
* ``uwsgi`` - uWSGI binary protocol (for nginx uwsgi_pass)
When using the uWSGI protocol, Gunicorn can receive requests from
nginx using the uwsgi_pass directive::
upstream gunicorn {
server 127.0.0.1:8000;
}
location / {
uwsgi_pass gunicorn;
include uwsgi_params;
}
"""
class UWSGIAllowFrom(Setting):
name = "uwsgi_allow_ips"
section = "Server Mechanics"
cli = ["--uwsgi-allow-from"]
validator = validate_string_to_addr_list
default = "127.0.0.1,::1"
desc = """\
IPs allowed to send uWSGI protocol requests (comma separated).
Set to ``*`` to allow all IPs. This is useful for setups where you
don't know in advance the IP address of front-end, but instead have
ensured via other means that only your authorized front-ends can
access Gunicorn.
.. note::
This option does not affect UNIX socket connections. Connections not associated with
an IP address are treated as allowed, unconditionally.
"""
class KeyFile(Setting):
name = "keyfile"
section = "SSL"
@ -2440,3 +2487,94 @@ class HeaderMap(Setting):
.. versionadded:: 22.0.0
"""
def validate_asgi_loop(val):
if val is None:
return "auto"
if not isinstance(val, str):
raise TypeError("Invalid type for casting: %s" % val)
val = val.lower().strip()
if val not in ("auto", "asyncio", "uvloop"):
raise ValueError("Invalid ASGI loop: %s" % val)
return val
def validate_asgi_lifespan(val):
if val is None:
return "auto"
if not isinstance(val, str):
raise TypeError("Invalid type for casting: %s" % val)
val = val.lower().strip()
if val not in ("auto", "on", "off"):
raise ValueError("Invalid ASGI lifespan: %s" % val)
return val
class ASGILoop(Setting):
name = "asgi_loop"
section = "Worker Processes"
cli = ["--asgi-loop"]
meta = "STRING"
validator = validate_asgi_loop
default = "auto"
desc = """\
Event loop implementation for ASGI workers.
- auto: Use uvloop if available, otherwise asyncio
- asyncio: Use Python's built-in asyncio event loop
- uvloop: Use uvloop (must be installed separately)
This setting only affects the ``asgi`` worker type.
uvloop typically provides better performance but requires
installing the uvloop package.
.. versionadded:: 24.0.0
"""
class ASGILifespan(Setting):
name = "asgi_lifespan"
section = "Worker Processes"
cli = ["--asgi-lifespan"]
meta = "STRING"
validator = validate_asgi_lifespan
default = "auto"
desc = """\
Control ASGI lifespan protocol handling.
- auto: Detect if app supports lifespan, enable if so
- on: Always run lifespan protocol (fail if unsupported)
- off: Never run lifespan protocol
The lifespan protocol allows ASGI applications to run code at
startup and shutdown. This is essential for frameworks like
FastAPI that need to initialize database connections, caches,
or other resources.
This setting only affects the ``asgi`` worker type.
.. versionadded:: 24.0.0
"""
class RootPath(Setting):
name = "root_path"
section = "Server Mechanics"
cli = ["--root-path"]
meta = "STRING"
validator = validate_string
default = ""
desc = """\
The root path for ASGI applications.
This is used to set the ``root_path`` in the ASGI scope, which
allows applications to know their mount point when behind a
reverse proxy.
For example, if your application is mounted at ``/api``, set
this to ``/api``.
.. versionadded:: 24.0.0
"""

View File

@ -5,4 +5,23 @@
from gunicorn.http.message import Message, Request
from gunicorn.http.parser import RequestParser
__all__ = ['Message', 'Request', 'RequestParser']
def get_parser(cfg, source, source_addr):
"""Get appropriate parser based on protocol config.
Args:
cfg: Gunicorn config object
source: Socket or iterable source
source_addr: Source address tuple or None
Returns:
Parser instance (RequestParser or UWSGIParser)
"""
protocol = getattr(cfg, 'protocol', 'http')
if protocol == 'uwsgi':
from gunicorn.uwsgi.parser import UWSGIParser
return UWSGIParser(cfg, source, source_addr)
return RequestParser(cfg, source, source_addr)
__all__ = ['Message', 'Request', 'RequestParser', 'get_parser']

View File

@ -0,0 +1,21 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
from gunicorn.uwsgi.message import UWSGIRequest
from gunicorn.uwsgi.parser import UWSGIParser
from gunicorn.uwsgi.errors import (
UWSGIParseException,
InvalidUWSGIHeader,
UnsupportedModifier,
ForbiddenUWSGIRequest,
)
__all__ = [
'UWSGIRequest',
'UWSGIParser',
'UWSGIParseException',
'InvalidUWSGIHeader',
'UnsupportedModifier',
'ForbiddenUWSGIRequest',
]

46
gunicorn/uwsgi/errors.py Normal file
View File

@ -0,0 +1,46 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
# We don't need to call super() in __init__ methods of our
# BaseException and Exception classes because we also define
# our own __str__ methods so there is no need to pass 'message'
# to the base class to get a meaningful output from 'str(exc)'.
# pylint: disable=super-init-not-called
class UWSGIParseException(Exception):
"""Base exception for uWSGI protocol parsing errors."""
class InvalidUWSGIHeader(UWSGIParseException):
"""Raised when the uWSGI header is malformed."""
def __init__(self, msg=""):
self.msg = msg
self.code = 400
def __str__(self):
return "Invalid uWSGI header: %s" % self.msg
class UnsupportedModifier(UWSGIParseException):
"""Raised when modifier1 is not 0 (WSGI request)."""
def __init__(self, modifier):
self.modifier = modifier
self.code = 501
def __str__(self):
return "Unsupported uWSGI modifier1: %d" % self.modifier
class ForbiddenUWSGIRequest(UWSGIParseException):
"""Raised when source IP is not in the allow list."""
def __init__(self, host):
self.host = host
self.code = 403
def __str__(self):
return "uWSGI request from %r not allowed" % self.host

255
gunicorn/uwsgi/message.py Normal file
View File

@ -0,0 +1,255 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
import io
from gunicorn.http.body import LengthReader, Body
from gunicorn.uwsgi.errors import (
InvalidUWSGIHeader,
UnsupportedModifier,
ForbiddenUWSGIRequest,
)
# Maximum number of variables to prevent DoS
MAX_UWSGI_VARS = 1000
class UWSGIRequest:
"""uWSGI protocol request parser.
The uWSGI protocol uses a 4-byte binary header:
- Byte 0: modifier1 (packet type, 0 = WSGI request)
- Bytes 1-2: datasize (16-bit little-endian, size of vars block)
- Byte 3: modifier2 (additional flags, typically 0)
After the header:
1. Vars block (datasize bytes): Key-value pairs containing WSGI environ
- Each pair: 2-byte key_size (LE) + key + 2-byte val_size (LE) + value
2. Request body (determined by CONTENT_LENGTH in vars)
"""
def __init__(self, cfg, unreader, peer_addr, req_number=1):
self.cfg = cfg
self.unreader = unreader
self.peer_addr = peer_addr
self.remote_addr = peer_addr
self.req_number = req_number
# Request attributes (compatible with HTTP Request interface)
self.method = None
self.uri = None
self.path = None
self.query = None
self.fragment = ""
self.version = (1, 1) # uWSGI is HTTP/1.1 compatible
self.headers = []
self.trailers = []
self.body = None
self.scheme = "https" if cfg.is_ssl else "http"
self.must_close = False
# uWSGI specific
self.uwsgi_vars = {}
self.modifier1 = 0
self.modifier2 = 0
# Proxy protocol compatibility
self.proxy_protocol_info = None
# Check if the source IP is allowed
self._check_allowed_ip()
# Parse the request
unused = self.parse(self.unreader)
self.unreader.unread(unused)
self.set_body_reader()
def _check_allowed_ip(self):
"""Verify source IP is in the allowed list."""
allow_ips = getattr(self.cfg, 'uwsgi_allow_ips', ['127.0.0.1', '::1'])
# UNIX sockets don't have IP addresses
if not isinstance(self.peer_addr, tuple):
return
# Wildcard allows all
if '*' in allow_ips:
return
if self.peer_addr[0] not in allow_ips:
raise ForbiddenUWSGIRequest(self.peer_addr[0])
def force_close(self):
"""Force the connection to close after this request."""
self.must_close = True
def parse(self, unreader):
"""Parse uWSGI packet header and vars block."""
# Read the 4-byte header
header = self._read_exact(unreader, 4)
if len(header) < 4:
raise InvalidUWSGIHeader("incomplete header")
self.modifier1 = header[0]
datasize = int.from_bytes(header[1:3], 'little')
self.modifier2 = header[3]
# Only modifier1=0 (WSGI request) is supported
if self.modifier1 != 0:
raise UnsupportedModifier(self.modifier1)
# Read the vars block
if datasize > 0:
vars_data = self._read_exact(unreader, datasize)
if len(vars_data) < datasize:
raise InvalidUWSGIHeader("incomplete vars block")
self._parse_vars(vars_data)
# Extract HTTP request info from vars
self._extract_request_info()
return b""
def _read_exact(self, unreader, size):
"""Read exactly size bytes from the unreader."""
buf = io.BytesIO()
remaining = size
while remaining > 0:
data = unreader.read()
if not data:
break
buf.write(data)
remaining = size - buf.tell()
result = buf.getvalue()
# Put back any extra bytes
if len(result) > size:
unreader.unread(result[size:])
result = result[:size]
return result
def _parse_vars(self, data):
"""Parse uWSGI vars block into key-value pairs.
Format: key_size (2 bytes LE) + key + val_size (2 bytes LE) + value
"""
pos = 0
var_count = 0
while pos < len(data):
if var_count >= MAX_UWSGI_VARS:
raise InvalidUWSGIHeader("too many variables")
# Key size (2 bytes, little-endian)
if pos + 2 > len(data):
raise InvalidUWSGIHeader("truncated key size")
key_size = int.from_bytes(data[pos:pos + 2], 'little')
pos += 2
# Key
if pos + key_size > len(data):
raise InvalidUWSGIHeader("truncated key")
key = data[pos:pos + key_size].decode('latin-1')
pos += key_size
# Value size (2 bytes, little-endian)
if pos + 2 > len(data):
raise InvalidUWSGIHeader("truncated value size")
val_size = int.from_bytes(data[pos:pos + 2], 'little')
pos += 2
# Value
if pos + val_size > len(data):
raise InvalidUWSGIHeader("truncated value")
value = data[pos:pos + val_size].decode('latin-1')
pos += val_size
self.uwsgi_vars[key] = value
var_count += 1
def _extract_request_info(self):
"""Extract HTTP request info from uWSGI vars.
Header Mapping (CGI/WSGI to HTTP):
The uWSGI protocol passes HTTP headers using CGI-style environment
variable naming. This method converts them back to HTTP header format:
- HTTP_* vars: Strip 'HTTP_' prefix, replace '_' with '-'
Example: HTTP_X_FORWARDED_FOR -> X-FORWARDED-FOR
Example: HTTP_ACCEPT_ENCODING -> ACCEPT-ENCODING
- CONTENT_TYPE: Mapped directly to CONTENT-TYPE header
(CGI spec excludes HTTP_ prefix for this header)
- CONTENT_LENGTH: Mapped directly to CONTENT-LENGTH header
(CGI spec excludes HTTP_ prefix for this header)
Note: The underscore-to-hyphen conversion is lossy. Headers that
originally contained underscores (e.g., X_Custom_Header) cannot be
distinguished from hyphenated headers (X-Custom-Header) after
passing through nginx/uWSGI. This is a CGI/WSGI specification
limitation, not specific to this implementation.
"""
# Method
self.method = self.uwsgi_vars.get('REQUEST_METHOD', 'GET')
# URI and path
self.path = self.uwsgi_vars.get('PATH_INFO', '/')
self.query = self.uwsgi_vars.get('QUERY_STRING', '')
# Build URI
if self.query:
self.uri = "%s?%s" % (self.path, self.query)
else:
self.uri = self.path
# Scheme
if self.uwsgi_vars.get('HTTPS', '').lower() in ('on', '1', 'true'):
self.scheme = 'https'
elif 'wsgi.url_scheme' in self.uwsgi_vars:
self.scheme = self.uwsgi_vars['wsgi.url_scheme']
# Extract HTTP headers from CGI-style vars
# See docstring above for mapping details
for key, value in self.uwsgi_vars.items():
if key.startswith('HTTP_'):
# Convert HTTP_HEADER_NAME to HEADER-NAME
header_name = key[5:].replace('_', '-')
self.headers.append((header_name, value))
elif key == 'CONTENT_TYPE':
self.headers.append(('CONTENT-TYPE', value))
elif key == 'CONTENT_LENGTH':
self.headers.append(('CONTENT-LENGTH', value))
def set_body_reader(self):
"""Set up the body reader based on CONTENT_LENGTH."""
content_length = 0
# Get content length from vars
if 'CONTENT_LENGTH' in self.uwsgi_vars:
try:
content_length = max(int(self.uwsgi_vars['CONTENT_LENGTH']), 0)
except ValueError:
content_length = 0
self.body = Body(LengthReader(self.unreader, content_length))
def should_close(self):
"""Determine if the connection should be closed after this request."""
if self.must_close:
return True
# Check HTTP_CONNECTION header
connection = self.uwsgi_vars.get('HTTP_CONNECTION', '').lower()
if connection == 'close':
return True
elif connection == 'keep-alive':
return False
# Default to keep-alive for HTTP/1.1
return False

12
gunicorn/uwsgi/parser.py Normal file
View File

@ -0,0 +1,12 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
from gunicorn.http.parser import Parser
from gunicorn.uwsgi.message import UWSGIRequest
class UWSGIParser(Parser):
"""Parser for uWSGI protocol requests."""
mesg_class = UWSGIRequest

View File

@ -11,4 +11,5 @@ SUPPORTED_WORKERS = {
"gevent_pywsgi": "gunicorn.workers.ggevent.GeventPyWSGIWorker",
"tornado": "gunicorn.workers.gtornado.TornadoWorker",
"gthread": "gunicorn.workers.gthread.ThreadWorker",
"asgi": "gunicorn.workers.gasgi.ASGIWorker",
}

View File

@ -32,7 +32,7 @@ class AsyncWorker(base.Worker):
def handle(self, listener, client, addr):
req = None
try:
parser = http.RequestParser(self.cfg, client, addr)
parser = http.get_parser(self.cfg, client, addr)
try:
listener_name = listener.getsockname()
if not self.cfg.keepalive:

281
gunicorn/workers/gasgi.py Normal file
View File

@ -0,0 +1,281 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
ASGI worker for gunicorn.
Provides native asyncio-based ASGI support using gunicorn's own
HTTP parsing infrastructure.
"""
import asyncio
import os
import signal
import sys
from gunicorn.workers import base
from gunicorn.asgi.protocol import ASGIProtocol
class ASGIWorker(base.Worker):
"""ASGI worker using asyncio event loop.
Supports:
- HTTP/1.1 with keepalive
- WebSocket connections
- Lifespan protocol (startup/shutdown hooks)
- Optional uvloop for improved performance
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.worker_connections = self.cfg.worker_connections
self.loop = None
self.servers = []
self.nr_conns = 0
self.lifespan = None
self.state = {} # Shared state for lifespan
@classmethod
def check_config(cls, cfg, log):
"""Validate configuration for ASGI worker."""
if cfg.threads > 1:
log.warning("ASGI worker does not use threads configuration. "
"Use worker_connections instead.")
def init_process(self):
"""Initialize the worker process."""
# Setup event loop before calling super()
self._setup_event_loop()
super().init_process()
def _setup_event_loop(self):
"""Setup the asyncio event loop."""
loop_type = getattr(self.cfg, 'asgi_loop', 'auto')
if loop_type == "auto":
try:
import uvloop
loop_type = "uvloop"
except ImportError:
loop_type = "asyncio"
if loop_type == "uvloop":
try:
import uvloop
self.loop = uvloop.new_event_loop()
self.log.debug("Using uvloop event loop")
except ImportError:
self.log.warning("uvloop not available, falling back to asyncio")
self.loop = asyncio.new_event_loop()
else:
self.loop = asyncio.new_event_loop()
self.log.debug("Using asyncio event loop")
asyncio.set_event_loop(self.loop)
def load_wsgi(self):
"""Load the ASGI application."""
try:
self.asgi = self.app.wsgi()
except SyntaxError as e:
if not self.cfg.reload:
raise
self.log.exception(e)
self.asgi = self._make_error_app(str(e))
def _make_error_app(self, error_msg):
"""Create an error ASGI app for syntax errors during reload."""
async def error_app(scope, receive, send):
if scope["type"] == "http":
await send({
"type": "http.response.start",
"status": 500,
"headers": [(b"content-type", b"text/plain")],
})
await send({
"type": "http.response.body",
"body": f"Application error: {error_msg}".encode(),
})
elif scope["type"] == "lifespan":
message = await receive()
if message["type"] == "lifespan.startup":
await send({"type": "lifespan.startup.complete"})
message = await receive()
if message["type"] == "lifespan.shutdown":
await send({"type": "lifespan.shutdown.complete"})
return error_app
def init_signals(self):
"""Initialize signal handlers for asyncio."""
# Reset all signals first
for s in self.SIGNALS:
signal.signal(s, signal.SIG_DFL)
# Set up signal handlers via the event loop
self.loop.add_signal_handler(signal.SIGQUIT, self.handle_quit_signal)
self.loop.add_signal_handler(signal.SIGTERM, self.handle_exit_signal)
self.loop.add_signal_handler(signal.SIGINT, self.handle_quit_signal)
self.loop.add_signal_handler(signal.SIGUSR1, self.handle_usr1_signal)
self.loop.add_signal_handler(signal.SIGWINCH, self.handle_winch_signal)
self.loop.add_signal_handler(signal.SIGABRT, self.handle_abort_signal)
def handle_quit_signal(self):
"""Handle SIGQUIT - immediate shutdown."""
self.alive = False
self.cfg.worker_int(self)
def handle_exit_signal(self):
"""Handle SIGTERM - graceful shutdown."""
self.alive = False
def handle_usr1_signal(self):
"""Handle SIGUSR1 - reopen log files."""
self.log.reopen_files()
def handle_winch_signal(self):
"""Handle SIGWINCH - ignored in worker."""
self.log.debug("worker: SIGWINCH ignored.")
def handle_abort_signal(self):
"""Handle SIGABRT - abort."""
self.alive = False
self.cfg.worker_abort(self)
sys.exit(1)
def run(self):
"""Main entry point for the worker."""
try:
self.loop.run_until_complete(self._serve())
except Exception as e:
self.log.exception("Worker exception: %s", e)
finally:
self._cleanup()
async def _serve(self):
"""Main async serving loop."""
# Run lifespan startup
lifespan_mode = getattr(self.cfg, 'asgi_lifespan', 'auto')
if lifespan_mode != "off":
from gunicorn.asgi.lifespan import LifespanManager
self.lifespan = LifespanManager(self.asgi, self.log, self.state)
try:
await self.lifespan.startup()
except Exception as e:
if lifespan_mode == "on":
self.log.error("ASGI lifespan startup failed: %s", e)
return
else:
# auto mode - app doesn't support lifespan
self.log.debug("ASGI lifespan not supported by app: %s", e)
self.lifespan = None
# Create servers for each listener socket
ssl_context = self._get_ssl_context()
for sock in self.sockets:
try:
server = await self.loop.create_server(
lambda: ASGIProtocol(self),
sock=sock.sock,
ssl=ssl_context,
reuse_address=True,
start_serving=True,
)
self.servers.append(server)
self.log.info("ASGI server listening on %s", sock)
except Exception as e:
self.log.error("Failed to create server on %s: %s", sock, e)
if not self.servers:
self.log.error("No servers could be started")
return
# Main loop with heartbeat
try:
while self.alive:
self.notify()
# Check if parent is still alive
if self.ppid != os.getppid():
self.log.info("Parent changed, shutting down: %s", self)
break
# Check connection limit
# (Connections are managed by nr_conns in ASGIProtocol)
await asyncio.sleep(1.0)
except asyncio.CancelledError:
pass
# Graceful shutdown
await self._shutdown()
async def _shutdown(self):
"""Perform graceful shutdown."""
self.log.info("Worker shutting down...")
# Stop accepting new connections
for server in self.servers:
server.close()
# Wait for servers to close
for server in self.servers:
await server.wait_closed()
# Wait for in-flight connections (with timeout)
graceful_timeout = self.cfg.graceful_timeout
if self.nr_conns > 0:
self.log.info("Waiting for %d connections to finish...", self.nr_conns)
deadline = self.loop.time() + graceful_timeout
while self.nr_conns > 0 and self.loop.time() < deadline:
await asyncio.sleep(0.1)
if self.nr_conns > 0:
self.log.warning("Closing %d connections after timeout", self.nr_conns)
# Run lifespan shutdown
if self.lifespan:
try:
await self.lifespan.shutdown()
except Exception as e:
self.log.error("ASGI lifespan shutdown error: %s", e)
def _get_ssl_context(self):
"""Get SSL context if configured."""
if not self.cfg.is_ssl:
return None
try:
from gunicorn import sock
return sock.ssl_context(self.cfg)
except Exception as e:
self.log.error("Failed to create SSL context: %s", e)
return None
def _cleanup(self):
"""Clean up resources on exit."""
try:
# Cancel all pending tasks
pending = asyncio.all_tasks(self.loop)
for task in pending:
task.cancel()
# Run loop until all tasks are cancelled
if pending:
self.loop.run_until_complete(
asyncio.gather(*pending, return_exceptions=True)
)
self.loop.close()
except Exception as e:
self.log.debug("Cleanup error: %s", e)
# Close sockets
for s in self.sockets:
try:
s.close()
except Exception:
pass

View File

@ -58,7 +58,7 @@ class TConn:
self.sock = sock.ssl_wrap_socket(self.sock, self.cfg)
# initialize the parser
self.parser = http.RequestParser(self.cfg, self.sock, self.client)
self.parser = http.get_parser(self.cfg, self.sock, self.client)
def set_timeout(self):
# Use monotonic clock for reliability (time.time() can jump due to NTP)

View File

@ -129,7 +129,7 @@ class SyncWorker(base.Worker):
try:
if self.cfg.is_ssl:
client = sock.ssl_wrap_socket(client, self.cfg)
parser = http.RequestParser(self.cfg, client, addr)
parser = http.get_parser(self.cfg, client, addr)
req = next(parser)
self.handle_request(listener, req, client, addr)
except http.errors.NoMoreData as e:

View File

@ -58,6 +58,7 @@ testing = [
"coverage",
"pytest",
"pytest-cov",
"pytest-asyncio",
]
[project.scripts]
@ -70,7 +71,7 @@ main = "gunicorn.app.pasterapp:serve"
[tool.pytest.ini_options]
# # can override these: python -m pytest --override-ini="addopts="
norecursedirs = ["examples", "lib", "local", "src"]
norecursedirs = ["examples", "lib", "local", "src", "tests/docker"]
testpaths = ["tests/"]
addopts = "--assert=plain --cov=gunicorn --cov-report=xml"

View File

@ -3,3 +3,4 @@ eventlet
coverage
pytest>=7.2.0
pytest-cov
pytest-asyncio

View File

@ -0,0 +1,16 @@
FROM python:3.11-slim
WORKDIR /app
# Copy gunicorn source
COPY . /app/gunicorn-src/
# Install gunicorn from source
RUN pip install --no-cache-dir /app/gunicorn-src/
# Copy test application
COPY tests/docker/uwsgi/app.py /app/
EXPOSE 8000
CMD ["gunicorn", "--protocol", "uwsgi", "--uwsgi-allow-from", "*", "--bind", "0.0.0.0:8000", "--workers", "2", "--log-level", "debug", "app:application"]

View File

@ -0,0 +1,12 @@
FROM nginx:alpine
# Remove default config
RUN rm /etc/nginx/conf.d/default.conf
# Copy custom config
COPY nginx.conf /etc/nginx/nginx.conf
COPY uwsgi_params /etc/nginx/uwsgi_params
EXPOSE 8080
CMD ["nginx", "-g", "daemon off;"]

View File

@ -0,0 +1,154 @@
# uWSGI Protocol Docker Integration Tests
This directory contains Docker-based integration tests that verify gunicorn's
uWSGI binary protocol implementation works correctly with nginx's `uwsgi_pass`
directive.
## Architecture
```
[pytest] --HTTP--> [nginx:8080] --uwsgi_pass--> [gunicorn:8000]
```
The tests make HTTP requests to nginx, which proxies them to gunicorn using the
uWSGI binary protocol. This validates the complete request/response cycle through
the protocol.
## Prerequisites
- Docker
- Docker Compose (v2)
- Python 3.8+
- pytest
- requests
## Running Tests
### From repository root:
```bash
# Run all uWSGI integration tests
pytest tests/docker/uwsgi/ -v
# Run specific test class
pytest tests/docker/uwsgi/ -v -k TestBasicRequests
# Skip Docker tests (for CI environments without Docker)
pytest tests/ -v -m "not docker"
```
### Manual testing:
```bash
cd tests/docker/uwsgi
# Start services
docker compose up -d
# Wait for services to be healthy
docker compose ps
# Test endpoints
curl http://localhost:8080/
curl -X POST -d "test body" http://localhost:8080/echo
curl http://localhost:8080/headers
curl "http://localhost:8080/query?foo=bar"
curl http://localhost:8080/environ
curl http://localhost:8080/error/404
curl http://localhost:8080/large > /dev/null # 1MB response
# View logs
docker compose logs gunicorn
docker compose logs nginx
# Stop services
docker compose down -v
```
## Test Categories
| Category | Description |
|----------|-------------|
| `TestBasicRequests` | GET, POST, query strings, large bodies |
| `TestHeaderPreservation` | Custom headers, Host, Content-Type, User-Agent |
| `TestKeepAlive` | Multiple requests per connection |
| `TestErrorResponses` | HTTP error codes (400, 404, 500, etc.) |
| `TestEnvironVariables` | WSGI environ: REQUEST_METHOD, PATH_INFO, etc. |
| `TestLargeResponses` | 1MB response body streaming |
| `TestConcurrency` | Parallel request handling |
| `TestSpecialCases` | Edge cases: binary data, unicode, long headers |
## Files
| File | Purpose |
|------|---------|
| `docker-compose.yml` | Orchestrates nginx + gunicorn containers |
| `Dockerfile.gunicorn` | Builds gunicorn image with test app |
| `Dockerfile.nginx` | Builds nginx with uwsgi config |
| `nginx.conf` | nginx configuration using `uwsgi_pass` |
| `uwsgi_params` | Standard uwsgi parameter mappings |
| `app.py` | Test WSGI application with multiple endpoints |
| `conftest.py` | pytest fixtures for Docker lifecycle |
| `test_uwsgi_integration.py` | Test cases |
## Test App Endpoints
| Endpoint | Method | Description |
|----------|--------|-------------|
| `/` | GET | Basic hello response |
| `/echo` | POST | Echo request body |
| `/headers` | GET/POST | Return received headers as JSON |
| `/environ` | GET/POST | Return WSGI environ as JSON |
| `/query` | GET | Return query params as JSON |
| `/json` | POST | Parse and echo JSON body |
| `/error/{code}` | GET | Return specified HTTP error |
| `/large` | GET | Return 1MB response |
## Gunicorn Configuration
The gunicorn container runs with:
```bash
gunicorn \
--protocol uwsgi \
--uwsgi-allow-from "*" \
--bind 0.0.0.0:8000 \
--workers 2 \
--log-level debug \
app:application
```
Key settings:
- `--protocol uwsgi`: Enable uWSGI binary protocol
- `--uwsgi-allow-from "*"`: Accept connections from Docker network IPs
## Troubleshooting
### Services won't start
Check Docker logs:
```bash
docker compose logs
```
### Connection refused
Wait for health checks:
```bash
docker compose ps # Check health status
```
### Tests timing out
Increase `STARTUP_TIMEOUT` in `conftest.py` or check if ports are in use:
```bash
lsof -i :8080
lsof -i :8000
```
### Rebuild after code changes
```bash
docker compose build --no-cache
docker compose up -d
```

222
tests/docker/uwsgi/app.py Normal file
View File

@ -0,0 +1,222 @@
"""
Test WSGI application for uWSGI protocol integration tests.
This application provides various endpoints to test different aspects
of the uWSGI binary protocol when proxied through nginx.
"""
import json
def application(environ, start_response):
"""Main WSGI application entry point."""
path = environ.get('PATH_INFO', '/')
method = environ.get('REQUEST_METHOD', 'GET')
# Route to appropriate handler
if path == '/':
return handle_root(environ, start_response)
elif path == '/echo':
return handle_echo(environ, start_response)
elif path == '/headers':
return handle_headers(environ, start_response)
elif path == '/environ':
return handle_environ(environ, start_response)
elif path.startswith('/error/'):
return handle_error(environ, start_response, path)
elif path == '/large':
return handle_large(environ, start_response)
elif path == '/json':
return handle_json(environ, start_response)
elif path == '/query':
return handle_query(environ, start_response)
else:
return handle_not_found(environ, start_response)
def handle_root(environ, start_response):
"""Basic root endpoint."""
status = '200 OK'
headers = [('Content-Type', 'text/plain')]
start_response(status, headers)
return [b'Hello from gunicorn uWSGI!\n']
def handle_echo(environ, start_response):
"""Echo back the request body."""
try:
content_length = int(environ.get('CONTENT_LENGTH', 0))
except (ValueError, TypeError):
content_length = 0
body = b''
if content_length > 0:
body = environ['wsgi.input'].read(content_length)
status = '200 OK'
headers = [
('Content-Type', 'application/octet-stream'),
('Content-Length', str(len(body)))
]
start_response(status, headers)
return [body]
def handle_headers(environ, start_response):
"""Return received HTTP headers as JSON."""
headers_dict = {}
for key, value in environ.items():
if key.startswith('HTTP_'):
# Convert HTTP_X_CUSTOM_HEADER to X-Custom-Header
header_name = key[5:].replace('_', '-').title()
headers_dict[header_name] = value
# Also include some special headers
if 'CONTENT_TYPE' in environ:
headers_dict['Content-Type'] = environ['CONTENT_TYPE']
if 'CONTENT_LENGTH' in environ:
headers_dict['Content-Length'] = environ['CONTENT_LENGTH']
body = json.dumps(headers_dict, indent=2).encode('utf-8')
status = '200 OK'
headers = [
('Content-Type', 'application/json'),
('Content-Length', str(len(body)))
]
start_response(status, headers)
return [body]
def handle_environ(environ, start_response):
"""Return WSGI environ variables as JSON."""
# Filter to serializable values
safe_environ = {}
skip_keys = {'wsgi.input', 'wsgi.errors', 'wsgi.file_wrapper'}
for key, value in environ.items():
if key in skip_keys:
continue
try:
# Test if value is JSON serializable
json.dumps(value)
safe_environ[key] = value
except (TypeError, ValueError):
safe_environ[key] = str(value)
body = json.dumps(safe_environ, indent=2).encode('utf-8')
status = '200 OK'
headers = [
('Content-Type', 'application/json'),
('Content-Length', str(len(body)))
]
start_response(status, headers)
return [body]
def handle_error(environ, start_response, path):
"""Return specified HTTP error code."""
try:
code = int(path.split('/')[-1])
except ValueError:
code = 500
status_messages = {
400: 'Bad Request',
401: 'Unauthorized',
403: 'Forbidden',
404: 'Not Found',
500: 'Internal Server Error',
502: 'Bad Gateway',
503: 'Service Unavailable',
}
message = status_messages.get(code, 'Error')
status = f'{code} {message}'
body = json.dumps({'error': message, 'code': code}).encode('utf-8')
headers = [
('Content-Type', 'application/json'),
('Content-Length', str(len(body)))
]
start_response(status, headers)
return [body]
def handle_large(environ, start_response):
"""Return a 1MB response body for testing large responses."""
# Generate 1MB of data (1024 * 1024 bytes)
chunk_size = 1024
num_chunks = 1024
chunk = b'X' * chunk_size
status = '200 OK'
headers = [
('Content-Type', 'application/octet-stream'),
('Content-Length', str(chunk_size * num_chunks))
]
start_response(status, headers)
# Return as generator for streaming
def generate():
for _ in range(num_chunks):
yield chunk
return generate()
def handle_json(environ, start_response):
"""Handle JSON POST requests."""
try:
content_length = int(environ.get('CONTENT_LENGTH', 0))
except (ValueError, TypeError):
content_length = 0
if content_length > 0:
body = environ['wsgi.input'].read(content_length)
try:
data = json.loads(body.decode('utf-8'))
response = {'received': data, 'status': 'ok'}
except json.JSONDecodeError:
response = {'error': 'Invalid JSON', 'status': 'error'}
else:
response = {'error': 'No body', 'status': 'error'}
body = json.dumps(response).encode('utf-8')
status = '200 OK'
headers = [
('Content-Type', 'application/json'),
('Content-Length', str(len(body)))
]
start_response(status, headers)
return [body]
def handle_query(environ, start_response):
"""Return query string parameters as JSON."""
from urllib.parse import parse_qs
query_string = environ.get('QUERY_STRING', '')
params = parse_qs(query_string)
# Convert lists to single values where appropriate
simple_params = {k: v[0] if len(v) == 1 else v for k, v in params.items()}
body = json.dumps(simple_params).encode('utf-8')
status = '200 OK'
headers = [
('Content-Type', 'application/json'),
('Content-Length', str(len(body)))
]
start_response(status, headers)
return [body]
def handle_not_found(environ, start_response):
"""Handle 404 for unknown paths."""
body = json.dumps({'error': 'Not Found', 'path': environ.get('PATH_INFO')}).encode('utf-8')
status = '404 Not Found'
headers = [
('Content-Type', 'application/json'),
('Content-Length', str(len(body)))
]
start_response(status, headers)
return [body]

View File

@ -0,0 +1,121 @@
"""
pytest fixtures for uWSGI Docker integration tests.
"""
import os
import subprocess
import time
import pytest
import requests
COMPOSE_FILE = os.path.join(os.path.dirname(__file__), 'docker-compose.yml')
NGINX_URL = 'http://127.0.0.1:8080'
STARTUP_TIMEOUT = 60 # seconds
def is_docker_available():
"""Check if Docker is available."""
try:
result = subprocess.run(
['docker', 'info'],
capture_output=True,
timeout=10
)
return result.returncode == 0
except (subprocess.TimeoutExpired, FileNotFoundError):
return False
def is_compose_available():
"""Check if docker compose is available."""
try:
result = subprocess.run(
['docker', 'compose', 'version'],
capture_output=True,
timeout=10
)
return result.returncode == 0
except (subprocess.TimeoutExpired, FileNotFoundError):
return False
docker_available = pytest.mark.skipif(
not is_docker_available() or not is_compose_available(),
reason="Docker or docker compose not available"
)
@pytest.fixture(scope='session')
def docker_services():
"""
Start Docker Compose services for the test session.
This fixture builds and starts the gunicorn and nginx containers,
waits for them to be healthy, and tears them down after all tests.
"""
if not is_docker_available() or not is_compose_available():
pytest.skip("Docker or docker compose not available")
# Build and start services
subprocess.run(
['docker', 'compose', '-f', COMPOSE_FILE, 'build'],
check=True,
capture_output=True
)
subprocess.run(
['docker', 'compose', '-f', COMPOSE_FILE, 'up', '-d'],
check=True,
capture_output=True
)
# Wait for services to be healthy
start_time = time.time()
while time.time() - start_time < STARTUP_TIMEOUT:
try:
response = requests.get(f'{NGINX_URL}/', timeout=2)
if response.status_code == 200:
break
except requests.RequestException:
pass
time.sleep(1)
else:
# Get logs for debugging
logs = subprocess.run(
['docker', 'compose', '-f', COMPOSE_FILE, 'logs'],
capture_output=True,
text=True
)
subprocess.run(
['docker', 'compose', '-f', COMPOSE_FILE, 'down', '-v'],
capture_output=True
)
pytest.fail(
f"Services did not become healthy within {STARTUP_TIMEOUT}s.\n"
f"Logs:\n{logs.stdout}\n{logs.stderr}"
)
yield
# Teardown
subprocess.run(
['docker', 'compose', '-f', COMPOSE_FILE, 'down', '-v'],
capture_output=True
)
@pytest.fixture
def nginx_url(docker_services):
"""Return the nginx base URL."""
return NGINX_URL
@pytest.fixture
def session(docker_services):
"""Return a requests Session with keep-alive enabled."""
with requests.Session() as s:
# Enable keep-alive
s.headers['Connection'] = 'keep-alive'
yield s

View File

@ -0,0 +1,29 @@
services:
gunicorn:
build:
context: ../../..
dockerfile: tests/docker/uwsgi/Dockerfile.gunicorn
expose:
- "8000"
healthcheck:
test: ["CMD", "python", "-c", "import socket; s=socket.socket(); s.connect(('localhost', 8000)); s.close()"]
interval: 2s
timeout: 5s
retries: 10
start_period: 5s
nginx:
build:
context: .
dockerfile: Dockerfile.nginx
ports:
- "8080:8080"
depends_on:
gunicorn:
condition: service_healthy
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8080/"]
interval: 2s
timeout: 5s
retries: 10
start_period: 5s

View File

@ -0,0 +1,46 @@
worker_processes 1;
events {
worker_connections 1024;
}
http {
include /etc/nginx/mime.types;
default_type application/octet-stream;
log_format main '$remote_addr - $remote_user [$time_local] "$request" '
'$status $body_bytes_sent "$http_referer" '
'"$http_user_agent"';
access_log /var/log/nginx/access.log main;
error_log /var/log/nginx/error.log debug;
sendfile on;
keepalive_timeout 65;
upstream gunicorn {
server gunicorn:8000;
}
server {
listen 8080;
server_name localhost;
# Increase buffer sizes for large headers
uwsgi_buffer_size 32k;
uwsgi_buffers 8 32k;
uwsgi_busy_buffers_size 64k;
# Read timeout for large responses
uwsgi_read_timeout 300s;
location / {
uwsgi_pass gunicorn;
include uwsgi_params;
# Pass additional headers
uwsgi_param HTTP_X_FORWARDED_FOR $proxy_add_x_forwarded_for;
uwsgi_param HTTP_X_REAL_IP $remote_addr;
}
}
}

View File

@ -0,0 +1,312 @@
"""
Integration tests for gunicorn's uWSGI binary protocol with nginx.
These tests verify that gunicorn correctly implements the uWSGI binary
protocol by running actual requests through nginx's uwsgi_pass directive.
"""
import concurrent.futures
import json
import pytest
import requests
from conftest import docker_available
@docker_available
class TestBasicRequests:
"""Test basic HTTP request handling through uWSGI protocol."""
def test_get_root(self, nginx_url):
"""Test basic GET request to root endpoint."""
response = requests.get(f'{nginx_url}/')
assert response.status_code == 200
assert b'Hello from gunicorn uWSGI!' in response.content
def test_get_with_query_string(self, nginx_url):
"""Test GET request with query string parameters."""
response = requests.get(f'{nginx_url}/query?foo=bar&baz=qux')
assert response.status_code == 200
data = response.json()
assert data['foo'] == 'bar'
assert data['baz'] == 'qux'
def test_post_echo(self, nginx_url):
"""Test POST request with body echo."""
test_body = b'This is a test body content'
response = requests.post(f'{nginx_url}/echo', data=test_body)
assert response.status_code == 200
assert response.content == test_body
def test_post_json(self, nginx_url):
"""Test POST request with JSON body."""
test_data = {'key': 'value', 'number': 42, 'nested': {'a': 1}}
response = requests.post(
f'{nginx_url}/json',
json=test_data,
headers={'Content-Type': 'application/json'}
)
assert response.status_code == 200
data = response.json()
assert data['status'] == 'ok'
assert data['received'] == test_data
def test_post_large_body(self, nginx_url):
"""Test POST with large request body (100KB)."""
large_body = b'X' * (100 * 1024)
response = requests.post(f'{nginx_url}/echo', data=large_body)
assert response.status_code == 200
assert len(response.content) == len(large_body)
assert response.content == large_body
@docker_available
class TestHeaderPreservation:
"""Test that headers are correctly passed through uWSGI protocol."""
def test_custom_headers(self, nginx_url):
"""Test custom headers are passed to the application."""
custom_headers = {
'X-Custom-Header': 'custom-value',
'X-Another-Header': 'another-value'
}
response = requests.get(f'{nginx_url}/headers', headers=custom_headers)
assert response.status_code == 200
data = response.json()
assert data.get('X-Custom-Header') == 'custom-value'
assert data.get('X-Another-Header') == 'another-value'
def test_host_header(self, nginx_url):
"""Test Host header is passed correctly."""
response = requests.get(
f'{nginx_url}/headers',
headers={'Host': 'test.example.com'}
)
assert response.status_code == 200
data = response.json()
assert data.get('Host') == 'test.example.com'
def test_content_type_header(self, nginx_url):
"""Test Content-Type header is passed correctly."""
response = requests.post(
f'{nginx_url}/headers',
data='test',
headers={'Content-Type': 'application/x-custom-type'}
)
assert response.status_code == 200
data = response.json()
assert data.get('Content-Type') == 'application/x-custom-type'
def test_user_agent_header(self, nginx_url):
"""Test User-Agent header is passed correctly."""
response = requests.get(
f'{nginx_url}/headers',
headers={'User-Agent': 'TestAgent/1.0'}
)
assert response.status_code == 200
data = response.json()
assert data.get('User-Agent') == 'TestAgent/1.0'
@docker_available
class TestKeepAlive:
"""Test HTTP keep-alive with multiple requests per connection."""
def test_multiple_requests_same_session(self, session, nginx_url):
"""Test multiple requests using same session/connection."""
for i in range(5):
response = session.get(f'{nginx_url}/')
assert response.status_code == 200
def test_mixed_requests_same_session(self, session, nginx_url):
"""Test mixed GET and POST requests using same session."""
# GET request
response = session.get(f'{nginx_url}/')
assert response.status_code == 200
# POST request
response = session.post(f'{nginx_url}/echo', data=b'test')
assert response.status_code == 200
assert response.content == b'test'
# Another GET
response = session.get(f'{nginx_url}/headers')
assert response.status_code == 200
# JSON POST
response = session.post(f'{nginx_url}/json', json={'test': 1})
assert response.status_code == 200
@docker_available
class TestErrorResponses:
"""Test HTTP error responses through uWSGI protocol."""
@pytest.mark.parametrize('code', [400, 401, 403, 404, 500, 502, 503])
def test_error_codes(self, nginx_url, code):
"""Test various HTTP error codes are returned correctly."""
response = requests.get(f'{nginx_url}/error/{code}')
assert response.status_code == code
data = response.json()
assert data['code'] == code
def test_not_found(self, nginx_url):
"""Test 404 for non-existent path."""
response = requests.get(f'{nginx_url}/nonexistent/path')
assert response.status_code == 404
data = response.json()
assert data['error'] == 'Not Found'
assert data['path'] == '/nonexistent/path'
@docker_available
class TestEnvironVariables:
"""Test WSGI environ variables are correctly set."""
def test_request_method(self, nginx_url):
"""Test REQUEST_METHOD is set correctly."""
response = requests.get(f'{nginx_url}/environ')
assert response.status_code == 200
data = response.json()
assert data.get('REQUEST_METHOD') == 'GET'
response = requests.post(f'{nginx_url}/environ', data='')
data = response.json()
assert data.get('REQUEST_METHOD') == 'POST'
def test_path_info(self, nginx_url):
"""Test PATH_INFO is set correctly."""
response = requests.get(f'{nginx_url}/environ')
assert response.status_code == 200
data = response.json()
assert data.get('PATH_INFO') == '/environ'
def test_query_string(self, nginx_url):
"""Test QUERY_STRING is set correctly."""
response = requests.get(f'{nginx_url}/environ?foo=bar&test=123')
assert response.status_code == 200
data = response.json()
assert data.get('QUERY_STRING') == 'foo=bar&test=123'
def test_server_protocol(self, nginx_url):
"""Test SERVER_PROTOCOL is set."""
response = requests.get(f'{nginx_url}/environ')
assert response.status_code == 200
data = response.json()
assert 'SERVER_PROTOCOL' in data
assert data['SERVER_PROTOCOL'].startswith('HTTP/')
def test_content_length(self, nginx_url):
"""Test CONTENT_LENGTH is set for POST requests."""
body = 'test body content'
response = requests.post(f'{nginx_url}/environ', data=body)
assert response.status_code == 200
data = response.json()
assert data.get('CONTENT_LENGTH') == str(len(body))
@docker_available
class TestLargeResponses:
"""Test large response handling through uWSGI protocol."""
def test_1mb_response(self, nginx_url):
"""Test 1MB response body is received correctly."""
response = requests.get(f'{nginx_url}/large')
assert response.status_code == 200
assert len(response.content) == 1024 * 1024
# Verify content is all 'X' characters
assert response.content == b'X' * (1024 * 1024)
def test_large_response_content_length(self, nginx_url):
"""Test Content-Length header for large response."""
response = requests.get(f'{nginx_url}/large')
assert response.status_code == 200
assert response.headers.get('Content-Length') == str(1024 * 1024)
@docker_available
class TestConcurrency:
"""Test concurrent request handling."""
def test_parallel_requests(self, nginx_url):
"""Test handling multiple parallel requests."""
num_requests = 20
def make_request(i):
response = requests.get(f'{nginx_url}/query?id={i}')
return response.status_code, response.json().get('id')
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
futures = [executor.submit(make_request, i) for i in range(num_requests)]
results = [f.result() for f in concurrent.futures.as_completed(futures)]
# All requests should succeed
assert all(status == 200 for status, _ in results)
# All IDs should be present
ids = set(id_val for _, id_val in results)
assert ids == set(str(i) for i in range(num_requests))
def test_parallel_mixed_requests(self, nginx_url):
"""Test parallel GET and POST requests."""
def get_request():
return requests.get(f'{nginx_url}/').status_code
def post_request(data):
response = requests.post(f'{nginx_url}/echo', data=data)
return response.status_code, response.content
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
get_futures = [executor.submit(get_request) for _ in range(10)]
post_futures = [
executor.submit(post_request, f'data-{i}'.encode())
for i in range(10)
]
get_results = [f.result() for f in get_futures]
post_results = [f.result() for f in post_futures]
assert all(status == 200 for status in get_results)
assert all(status == 200 for status, _ in post_results)
@docker_available
class TestSpecialCases:
"""Test edge cases and special scenarios."""
def test_empty_body_post(self, nginx_url):
"""Test POST with empty body."""
response = requests.post(f'{nginx_url}/echo', data=b'')
assert response.status_code == 200
assert response.content == b''
def test_binary_body(self, nginx_url):
"""Test POST with binary body containing null bytes."""
binary_data = bytes(range(256))
response = requests.post(f'{nginx_url}/echo', data=binary_data)
assert response.status_code == 200
assert response.content == binary_data
def test_unicode_in_query_string(self, nginx_url):
"""Test unicode characters in query string."""
response = requests.get(f'{nginx_url}/query', params={'name': 'test'})
assert response.status_code == 200
data = response.json()
assert data.get('name') == 'test'
def test_special_characters_in_path(self, nginx_url):
"""Test handling of special path that triggers 404."""
# This should return 404 since the path doesn't exist
response = requests.get(f'{nginx_url}/path/with/slashes')
assert response.status_code == 404
def test_long_header_value(self, nginx_url):
"""Test handling of long header values."""
long_value = 'X' * 4096 # 4KB header value
response = requests.get(
f'{nginx_url}/headers',
headers={'X-Long-Header': long_value}
)
assert response.status_code == 200
data = response.json()
assert data.get('X-Long-Header') == long_value

View File

@ -0,0 +1,16 @@
uwsgi_param QUERY_STRING $query_string;
uwsgi_param REQUEST_METHOD $request_method;
uwsgi_param CONTENT_TYPE $content_type;
uwsgi_param CONTENT_LENGTH $content_length;
uwsgi_param REQUEST_URI $request_uri;
uwsgi_param PATH_INFO $document_uri;
uwsgi_param DOCUMENT_ROOT $document_root;
uwsgi_param SERVER_PROTOCOL $server_protocol;
uwsgi_param REQUEST_SCHEME $scheme;
uwsgi_param HTTPS $https if_not_empty;
uwsgi_param REMOTE_ADDR $remote_addr;
uwsgi_param REMOTE_PORT $remote_port;
uwsgi_param SERVER_PORT $server_port;
uwsgi_param SERVER_NAME $server_name;

285
tests/test_asgi.py Normal file
View File

@ -0,0 +1,285 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
Tests for ASGI worker components.
"""
import asyncio
import io
import pytest
from unittest import mock
from gunicorn.asgi.unreader import AsyncUnreader
from gunicorn.asgi.message import AsyncRequest
class MockStreamReader:
"""Mock asyncio.StreamReader for testing."""
def __init__(self, data):
self.data = data
self.pos = 0
async def read(self, size=-1):
if self.pos >= len(self.data):
return b""
if size < 0:
result = self.data[self.pos:]
self.pos = len(self.data)
else:
result = self.data[self.pos:self.pos + size]
self.pos += size
return result
async def readexactly(self, n):
if self.pos + n > len(self.data):
raise asyncio.IncompleteReadError(
self.data[self.pos:], n
)
result = self.data[self.pos:self.pos + n]
self.pos += n
return result
class MockConfig:
"""Mock gunicorn config for testing."""
def __init__(self):
self.is_ssl = False
self.proxy_protocol = False
self.proxy_allow_ips = ["127.0.0.1"]
self.forwarded_allow_ips = ["127.0.0.1"]
self.secure_scheme_headers = {}
self.forwarder_headers = []
self.limit_request_line = 8190
self.limit_request_fields = 100
self.limit_request_field_size = 8190
self.permit_unconventional_http_method = False
self.permit_unconventional_http_version = False
self.permit_obsolete_folding = False
self.casefold_http_method = False
self.strip_header_spaces = False
self.header_map = "refuse"
# AsyncUnreader Tests
@pytest.mark.asyncio
async def test_async_unreader_read_chunk():
"""Test basic chunk reading."""
reader = MockStreamReader(b"hello world")
unreader = AsyncUnreader(reader)
data = await unreader.read()
assert data == b"hello world"
@pytest.mark.asyncio
async def test_async_unreader_read_size():
"""Test reading specific size."""
reader = MockStreamReader(b"hello world")
unreader = AsyncUnreader(reader)
data = await unreader.read(5)
assert data == b"hello"
@pytest.mark.asyncio
async def test_async_unreader_unread():
"""Test unread functionality."""
reader = MockStreamReader(b"hello world")
unreader = AsyncUnreader(reader)
# Read all data
data = await unreader.read()
assert data == b"hello world"
# Unread some data
unreader.unread(b"world")
# Read again should get unread data
data = await unreader.read()
assert data == b"world"
@pytest.mark.asyncio
async def test_async_unreader_read_zero():
"""Test reading zero bytes."""
reader = MockStreamReader(b"hello")
unreader = AsyncUnreader(reader)
data = await unreader.read(0)
assert data == b""
@pytest.mark.asyncio
async def test_async_unreader_read_empty():
"""Test reading from empty stream."""
reader = MockStreamReader(b"")
unreader = AsyncUnreader(reader)
data = await unreader.read()
assert data == b""
# AsyncRequest Tests
@pytest.mark.asyncio
async def test_async_request_simple_get():
"""Test parsing a simple GET request."""
request_data = b"GET /path HTTP/1.1\r\nHost: localhost\r\n\r\n"
reader = MockStreamReader(request_data)
unreader = AsyncUnreader(reader)
cfg = MockConfig()
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
assert request.method == "GET"
assert request.path == "/path"
assert request.version == (1, 1)
assert ("HOST", "localhost") in request.headers
@pytest.mark.asyncio
async def test_async_request_with_query():
"""Test parsing request with query string."""
request_data = b"GET /search?q=test&page=1 HTTP/1.1\r\nHost: localhost\r\n\r\n"
reader = MockStreamReader(request_data)
unreader = AsyncUnreader(reader)
cfg = MockConfig()
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
assert request.method == "GET"
assert request.path == "/search"
assert request.query == "q=test&page=1"
@pytest.mark.asyncio
async def test_async_request_post_with_body():
"""Test parsing POST request with body."""
request_data = (
b"POST /submit HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Content-Length: 11\r\n"
b"\r\n"
b"hello=world"
)
reader = MockStreamReader(request_data)
unreader = AsyncUnreader(reader)
cfg = MockConfig()
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
assert request.method == "POST"
assert request.path == "/submit"
assert request.content_length == 11
# Read body
body = await request.read_body(100)
assert body == b"hello=world"
@pytest.mark.asyncio
async def test_async_request_multiple_headers():
"""Test parsing request with multiple headers."""
request_data = (
b"GET / HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Accept: text/html\r\n"
b"Accept-Language: en-US\r\n"
b"Connection: keep-alive\r\n"
b"\r\n"
)
reader = MockStreamReader(request_data)
unreader = AsyncUnreader(reader)
cfg = MockConfig()
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
assert len(request.headers) == 4
assert request.get_header("HOST") == "localhost"
assert request.get_header("ACCEPT") == "text/html"
@pytest.mark.asyncio
async def test_async_request_should_close_http10():
"""Test connection close detection for HTTP/1.0."""
request_data = b"GET / HTTP/1.0\r\nHost: localhost\r\n\r\n"
reader = MockStreamReader(request_data)
unreader = AsyncUnreader(reader)
cfg = MockConfig()
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
assert request.version == (1, 0)
assert request.should_close() is True
@pytest.mark.asyncio
async def test_async_request_should_close_connection_header():
"""Test connection close detection with Connection header."""
request_data = b"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"
reader = MockStreamReader(request_data)
unreader = AsyncUnreader(reader)
cfg = MockConfig()
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
assert request.should_close() is True
@pytest.mark.asyncio
async def test_async_request_keepalive():
"""Test keepalive detection."""
request_data = b"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: keep-alive\r\n\r\n"
reader = MockStreamReader(request_data)
unreader = AsyncUnreader(reader)
cfg = MockConfig()
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
assert request.should_close() is False
@pytest.mark.asyncio
async def test_async_request_no_body_for_get():
"""Test that GET requests have no body by default."""
request_data = b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n"
reader = MockStreamReader(request_data)
unreader = AsyncUnreader(reader)
cfg = MockConfig()
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
assert request.content_length == 0
body = await request.read_body()
assert body == b""
# Error handling tests
@pytest.mark.asyncio
async def test_async_request_invalid_method():
"""Test invalid HTTP method detection."""
from gunicorn.http.errors import InvalidRequestMethod
request_data = b"ge!t / HTTP/1.1\r\nHost: localhost\r\n\r\n"
reader = MockStreamReader(request_data)
unreader = AsyncUnreader(reader)
cfg = MockConfig()
with pytest.raises(InvalidRequestMethod):
await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
@pytest.mark.asyncio
async def test_async_request_invalid_http_version():
"""Test invalid HTTP version detection."""
from gunicorn.http.errors import InvalidHTTPVersion
request_data = b"GET / HTTP/2.0\r\nHost: localhost\r\n\r\n"
reader = MockStreamReader(request_data)
unreader = AsyncUnreader(reader)
cfg = MockConfig()
with pytest.raises(InvalidHTTPVersion):
await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))

643
tests/test_asgi_worker.py Normal file
View File

@ -0,0 +1,643 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
Tests for the ASGI worker.
Includes unit tests for worker components and integration tests
that actually start the server and make HTTP requests.
"""
import asyncio
import errno
import os
import signal
import socket
import sys
import time
import threading
from unittest import mock
import pytest
from gunicorn.config import Config
from gunicorn.workers import gasgi
# ============================================================================
# Mock Classes
# ============================================================================
class FakeSocket:
"""Mock socket for testing."""
def __init__(self, data=b''):
self.data = data
self.closed = False
self.blocking = True
self._fileno = id(self) % 65536
def fileno(self):
return self._fileno
def setblocking(self, blocking):
self.blocking = blocking
def recv(self, size):
if self.closed:
raise OSError(errno.EBADF, "Bad file descriptor")
result = self.data[:size]
self.data = self.data[size:]
return result
def send(self, data):
if self.closed:
raise OSError(errno.EPIPE, "Broken pipe")
return len(data)
def close(self):
self.closed = True
def getsockname(self):
return ('127.0.0.1', 8000)
def getpeername(self):
return ('127.0.0.1', 12345)
class FakeApp:
"""Mock ASGI application for testing."""
def __init__(self):
self.calls = []
def wsgi(self):
return self.asgi_app
async def asgi_app(self, scope, receive, send):
self.calls.append(scope)
if scope["type"] == "lifespan":
while True:
message = await receive()
if message["type"] == "lifespan.startup":
await send({"type": "lifespan.startup.complete"})
elif message["type"] == "lifespan.shutdown":
await send({"type": "lifespan.shutdown.complete"})
return
elif scope["type"] == "http":
await send({
"type": "http.response.start",
"status": 200,
"headers": [(b"content-type", b"text/plain")],
})
await send({
"type": "http.response.body",
"body": b"Hello from ASGI!",
})
class FakeListener:
"""Mock listener socket."""
def __init__(self):
self.sock = FakeSocket()
def getsockname(self):
return ('127.0.0.1', 8000)
def close(self):
self.sock.close()
def __str__(self):
return "http://127.0.0.1:8000"
# ============================================================================
# Helper Functions
# ============================================================================
def _has_uvloop():
"""Check if uvloop is available."""
try:
import uvloop
return True
except ImportError:
return False
# ============================================================================
# Unit Tests for ASGIWorker
# ============================================================================
class TestASGIWorkerInit:
"""Tests for ASGIWorker initialization."""
def create_worker(self, **kwargs):
"""Create a worker for testing."""
cfg = Config()
cfg.set('workers', 1)
cfg.set('worker_connections', 1000)
for key, value in kwargs.items():
cfg.set(key, value)
worker = gasgi.ASGIWorker(
age=1,
ppid=os.getpid(),
sockets=[],
app=FakeApp(),
timeout=30,
cfg=cfg,
log=mock.Mock(),
)
return worker
def test_worker_init(self):
"""Test worker initialization."""
worker = self.create_worker()
assert worker.worker_connections == 1000
assert worker.nr_conns == 0
assert worker.loop is None
assert worker.servers == []
assert worker.state == {}
def test_worker_connections_config(self):
"""Test worker_connections configuration."""
worker = self.create_worker(worker_connections=500)
assert worker.worker_connections == 500
class TestASGIWorkerEventLoop:
"""Tests for event loop setup."""
def create_worker(self, **kwargs):
"""Create a worker for testing."""
cfg = Config()
cfg.set('workers', 1)
cfg.set('worker_connections', 1000)
for key, value in kwargs.items():
cfg.set(key, value)
worker = gasgi.ASGIWorker(
age=1,
ppid=os.getpid(),
sockets=[],
app=FakeApp(),
timeout=30,
cfg=cfg,
log=mock.Mock(),
)
return worker
def test_setup_asyncio_loop(self):
"""Test asyncio event loop setup."""
worker = self.create_worker(asgi_loop='asyncio')
worker._setup_event_loop()
assert worker.loop is not None
assert isinstance(worker.loop, asyncio.AbstractEventLoop)
worker.loop.close()
def test_setup_auto_loop_falls_back_to_asyncio(self):
"""Test that auto mode uses asyncio when uvloop unavailable."""
worker = self.create_worker(asgi_loop='auto')
# Mock uvloop import failure
with mock.patch.dict('sys.modules', {'uvloop': None}):
worker._setup_event_loop()
assert worker.loop is not None
worker.loop.close()
@pytest.mark.skipif(
not _has_uvloop(),
reason="uvloop not installed"
)
def test_setup_uvloop(self):
"""Test uvloop event loop setup."""
worker = self.create_worker(asgi_loop='uvloop')
worker._setup_event_loop()
import uvloop
assert isinstance(worker.loop, uvloop.Loop)
worker.loop.close()
class TestASGIWorkerSignals:
"""Tests for signal handling."""
def create_worker(self):
"""Create a worker for testing."""
cfg = Config()
cfg.set('workers', 1)
cfg.set('worker_connections', 1000)
cfg.set('graceful_timeout', 5)
worker = gasgi.ASGIWorker(
age=1,
ppid=os.getpid(),
sockets=[],
app=FakeApp(),
timeout=30,
cfg=cfg,
log=mock.Mock(),
)
worker._setup_event_loop()
return worker
def test_handle_exit_sets_alive_false(self):
"""Test that exit signal sets alive=False."""
worker = self.create_worker()
worker.alive = True
worker.handle_exit_signal()
assert worker.alive is False
worker.loop.close()
def test_handle_quit_sets_alive_false(self):
"""Test that quit signal sets alive=False."""
worker = self.create_worker()
worker.alive = True
# Mock the worker_int callback on the worker's cfg settings
with mock.patch.object(worker.cfg.settings['worker_int'], 'get', return_value=lambda w: None):
worker.handle_quit_signal()
assert worker.alive is False
worker.loop.close()
# ============================================================================
# Tests for Lifespan Protocol
# ============================================================================
class TestLifespanManager:
"""Tests for ASGI lifespan protocol."""
@pytest.mark.asyncio
async def test_lifespan_startup_complete(self):
"""Test successful lifespan startup."""
from gunicorn.asgi.lifespan import LifespanManager
startup_called = False
shutdown_called = False
async def app(scope, receive, send):
nonlocal startup_called, shutdown_called
assert scope["type"] == "lifespan"
while True:
message = await receive()
if message["type"] == "lifespan.startup":
startup_called = True
await send({"type": "lifespan.startup.complete"})
elif message["type"] == "lifespan.shutdown":
shutdown_called = True
await send({"type": "lifespan.shutdown.complete"})
return
manager = LifespanManager(app, mock.Mock())
await manager.startup()
assert startup_called
assert manager._startup_complete.is_set()
assert not manager._startup_failed
await manager.shutdown()
assert shutdown_called
@pytest.mark.asyncio
async def test_lifespan_startup_failed(self):
"""Test lifespan startup failure."""
from gunicorn.asgi.lifespan import LifespanManager
async def app(scope, receive, send):
message = await receive()
if message["type"] == "lifespan.startup":
await send({
"type": "lifespan.startup.failed",
"message": "Database connection failed"
})
manager = LifespanManager(app, mock.Mock())
with pytest.raises(RuntimeError, match="Database connection failed"):
await manager.startup()
@pytest.mark.asyncio
async def test_lifespan_state_shared(self):
"""Test that lifespan state is shared with app."""
from gunicorn.asgi.lifespan import LifespanManager
state = {}
async def app(scope, receive, send):
assert "state" in scope
scope["state"]["db"] = "connected"
message = await receive()
await send({"type": "lifespan.startup.complete"})
message = await receive()
await send({"type": "lifespan.shutdown.complete"})
manager = LifespanManager(app, mock.Mock(), state)
await manager.startup()
assert state.get("db") == "connected"
await manager.shutdown()
# ============================================================================
# Tests for WebSocket Protocol
# ============================================================================
class TestWebSocketProtocol:
"""Tests for WebSocket protocol handling."""
def test_websocket_guid(self):
"""Test WebSocket GUID constant."""
from gunicorn.asgi.websocket import WS_GUID
assert WS_GUID == b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
def test_websocket_opcodes(self):
"""Test WebSocket opcode constants."""
from gunicorn.asgi import websocket
assert websocket.OPCODE_TEXT == 0x1
assert websocket.OPCODE_BINARY == 0x2
assert websocket.OPCODE_CLOSE == 0x8
assert websocket.OPCODE_PING == 0x9
assert websocket.OPCODE_PONG == 0xA
def test_websocket_accept_key_calculation(self):
"""Test WebSocket accept key calculation per RFC 6455."""
import base64
import hashlib
from gunicorn.asgi.websocket import WS_GUID
# Example from RFC 6455
client_key = b"dGhlIHNhbXBsZSBub25jZQ=="
expected_accept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
accept_key = base64.b64encode(
hashlib.sha1(client_key + WS_GUID).digest()
).decode("ascii")
assert accept_key == expected_accept
def test_websocket_frame_masking(self):
"""Test WebSocket frame unmasking."""
from gunicorn.asgi.websocket import WebSocketProtocol
# Create a minimal protocol instance
protocol = WebSocketProtocol(None, None, {}, None, mock.Mock())
# Test unmasking (XOR operation)
masking_key = bytes([0x37, 0xfa, 0x21, 0x3d])
masked_data = bytes([0x7f, 0x9f, 0x4d, 0x51, 0x58]) # "Hello" masked
unmasked = protocol._unmask(masked_data, masking_key)
assert unmasked == b"Hello"
def test_websocket_frame_masking_empty(self):
"""Test WebSocket frame unmasking with empty payload."""
from gunicorn.asgi.websocket import WebSocketProtocol
protocol = WebSocketProtocol(None, None, {}, None, mock.Mock())
masking_key = bytes([0x37, 0xfa, 0x21, 0x3d])
unmasked = protocol._unmask(b"", masking_key)
assert unmasked == b""
# ============================================================================
# Integration Tests
# ============================================================================
class TestASGIIntegration:
"""Integration tests that start actual servers."""
@pytest.fixture
def free_port(self):
"""Get a free port for testing."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('127.0.0.1', 0))
return s.getsockname()[1]
@pytest.mark.asyncio
async def test_http_request_response(self, free_port):
"""Test basic HTTP request/response cycle."""
# Simple ASGI app
async def app(scope, receive, send):
if scope["type"] == "http":
await send({
"type": "http.response.start",
"status": 200,
"headers": [(b"content-type", b"text/plain")],
})
await send({
"type": "http.response.body",
"body": b"Hello, World!",
})
# Start server
loop = asyncio.get_event_loop()
server = await loop.create_server(
lambda: _TestProtocol(app),
'127.0.0.1',
free_port,
)
try:
# Use asyncio to make HTTP request
reader, writer = await asyncio.open_connection('127.0.0.1', free_port)
request = f"GET / HTTP/1.1\r\nHost: 127.0.0.1:{free_port}\r\n\r\n"
writer.write(request.encode())
await writer.drain()
# Read response
response = await reader.read(4096)
response_text = response.decode()
assert "HTTP/1.1 200" in response_text
assert "Hello, World!" in response_text
writer.close()
await writer.wait_closed()
finally:
server.close()
await server.wait_closed()
class _TestProtocol(asyncio.Protocol):
"""Minimal protocol for integration testing."""
def __init__(self, app):
self.app = app
self.transport = None
def connection_made(self, transport):
self.transport = transport
def data_received(self, data):
# Very simple HTTP parsing for testing
asyncio.create_task(self._handle(data))
async def _handle(self, data):
# Parse basic HTTP request
lines = data.decode().split('\r\n')
method, path, _ = lines[0].split(' ')
scope = {
"type": "http",
"asgi": {"version": "3.0"},
"http_version": "1.1",
"method": method,
"path": path,
"query_string": b"",
"headers": [],
"server": ("127.0.0.1", 8000),
"client": ("127.0.0.1", 12345),
}
async def receive():
return {"type": "http.request", "body": b"", "more_body": False}
async def send(message):
if message["type"] == "http.response.start":
status = message["status"]
headers = message.get("headers", [])
response = f"HTTP/1.1 {status} OK\r\n"
for name, value in headers:
if isinstance(name, bytes):
name = name.decode()
if isinstance(value, bytes):
value = value.decode()
response += f"{name}: {value}\r\n"
response += "\r\n"
self.transport.write(response.encode())
elif message["type"] == "http.response.body":
body = message.get("body", b"")
self.transport.write(body)
if not message.get("more_body", False):
self.transport.close()
await self.app(scope, receive, send)
# ============================================================================
# ASGI Protocol Tests
# ============================================================================
class TestASGIProtocol:
"""Tests for ASGIProtocol."""
def test_reason_phrases(self):
"""Test HTTP reason phrase lookup."""
from gunicorn.asgi.protocol import ASGIProtocol
# Create minimal worker mock
worker = mock.Mock()
worker.cfg = Config()
worker.log = mock.Mock()
worker.asgi = mock.Mock()
protocol = ASGIProtocol(worker)
assert protocol._get_reason_phrase(200) == "OK"
assert protocol._get_reason_phrase(404) == "Not Found"
assert protocol._get_reason_phrase(500) == "Internal Server Error"
assert protocol._get_reason_phrase(999) == "Unknown"
def test_scope_building(self):
"""Test HTTP scope building."""
from gunicorn.asgi.protocol import ASGIProtocol
from gunicorn.asgi.message import AsyncRequest
worker = mock.Mock()
worker.cfg = Config()
worker.cfg.set('root_path', '/api')
worker.log = mock.Mock()
worker.asgi = mock.Mock()
protocol = ASGIProtocol(worker)
# Create mock request
request = mock.Mock()
request.method = "GET"
request.path = "/users"
request.query = "page=1"
request.version = (1, 1)
request.scheme = "http"
request.headers = [("HOST", "localhost"), ("ACCEPT", "text/html")]
scope = protocol._build_http_scope(
request,
("127.0.0.1", 8000), # sockname
("127.0.0.1", 12345), # peername
)
assert scope["type"] == "http"
assert scope["method"] == "GET"
assert scope["path"] == "/users"
assert scope["query_string"] == b"page=1"
assert scope["root_path"] == "/api"
assert scope["http_version"] == "1.1"
# ============================================================================
# Config Tests
# ============================================================================
class TestASGIConfig:
"""Tests for ASGI configuration options."""
def test_asgi_loop_default(self):
"""Test default asgi_loop value."""
cfg = Config()
assert cfg.asgi_loop == "auto"
def test_asgi_loop_validation(self):
"""Test asgi_loop validation."""
cfg = Config()
cfg.set('asgi_loop', 'asyncio')
assert cfg.asgi_loop == 'asyncio'
cfg.set('asgi_loop', 'uvloop')
assert cfg.asgi_loop == 'uvloop'
with pytest.raises(ValueError):
cfg.set('asgi_loop', 'invalid')
def test_asgi_lifespan_default(self):
"""Test default asgi_lifespan value."""
cfg = Config()
assert cfg.asgi_lifespan == "auto"
def test_asgi_lifespan_validation(self):
"""Test asgi_lifespan validation."""
cfg = Config()
cfg.set('asgi_lifespan', 'on')
assert cfg.asgi_lifespan == 'on'
cfg.set('asgi_lifespan', 'off')
assert cfg.asgi_lifespan == 'off'
with pytest.raises(ValueError):
cfg.set('asgi_lifespan', 'invalid')
def test_root_path_default(self):
"""Test default root_path value."""
cfg = Config()
assert cfg.root_path == ""
def test_root_path_setting(self):
"""Test root_path configuration."""
cfg = Config()
cfg.set('root_path', '/api/v1')
assert cfg.root_path == '/api/v1'

435
tests/test_uwsgi.py Normal file
View File

@ -0,0 +1,435 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
import io
import pytest
from unittest import mock
from gunicorn.uwsgi import (
UWSGIRequest,
UWSGIParser,
UWSGIParseException,
InvalidUWSGIHeader,
UnsupportedModifier,
ForbiddenUWSGIRequest,
)
from gunicorn.http.unreader import IterUnreader
def make_uwsgi_packet(vars_dict, modifier1=0, modifier2=0):
"""Create uWSGI packet for testing.
Args:
vars_dict: Dict of WSGI environ variables
modifier1: Packet type (0 = WSGI request)
modifier2: Additional flags
Returns:
bytes: Complete uWSGI packet
"""
vars_data = b''
for key, value in vars_dict.items():
k = key.encode('latin-1')
v = value.encode('latin-1')
vars_data += len(k).to_bytes(2, 'little') + k
vars_data += len(v).to_bytes(2, 'little') + v
header = bytes([modifier1]) + len(vars_data).to_bytes(2, 'little') + bytes([modifier2])
return header + vars_data
def make_uwsgi_packet_with_body(vars_dict, body=b'', modifier1=0, modifier2=0):
"""Create uWSGI packet with body for testing."""
if body:
vars_dict = dict(vars_dict)
vars_dict['CONTENT_LENGTH'] = str(len(body))
return make_uwsgi_packet(vars_dict, modifier1, modifier2) + body
class MockConfig:
"""Mock config object for testing."""
def __init__(self, is_ssl=False, uwsgi_allow_ips=None):
self.is_ssl = is_ssl
self.uwsgi_allow_ips = uwsgi_allow_ips or ['127.0.0.1', '::1']
class TestUWSGIPacketConstruction:
"""Test the packet construction helper."""
def test_empty_vars(self):
packet = make_uwsgi_packet({})
assert packet == b'\x00\x00\x00\x00' # modifier1=0, size=0, modifier2=0
def test_single_var(self):
packet = make_uwsgi_packet({'KEY': 'val'})
# Header: modifier1(0) + size(10 in LE) + modifier2(0)
# Var: key_size(3 in LE) + 'KEY' + val_size(3 in LE) + 'val'
# Size = 2 + 3 + 2 + 3 = 10 bytes
expected_header = b'\x00\x0a\x00\x00'
expected_var = b'\x03\x00KEY\x03\x00val'
assert packet == expected_header + expected_var
def test_multiple_vars(self):
packet = make_uwsgi_packet({'A': '1', 'B': '2'})
assert len(packet) == 4 + (2 + 1 + 2 + 1) * 2 # header + 2 vars
class TestUWSGIRequest:
"""Test UWSGIRequest parsing."""
def test_parse_simple_request(self):
"""Test parsing a simple GET request."""
packet = make_uwsgi_packet({
'REQUEST_METHOD': 'GET',
'PATH_INFO': '/test',
'QUERY_STRING': 'foo=bar',
})
unreader = IterUnreader([packet])
cfg = MockConfig()
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
assert req.method == 'GET'
assert req.path == '/test'
assert req.query == 'foo=bar'
assert req.uri == '/test?foo=bar'
def test_parse_post_request_with_body(self):
"""Test parsing a POST request with body."""
body = b'name=test&value=123'
packet = make_uwsgi_packet_with_body({
'REQUEST_METHOD': 'POST',
'PATH_INFO': '/submit',
'CONTENT_TYPE': 'application/x-www-form-urlencoded',
}, body)
unreader = IterUnreader([packet])
cfg = MockConfig()
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
assert req.method == 'POST'
assert req.path == '/submit'
assert req.body.read() == body
def test_parse_headers(self):
"""Test that HTTP_* vars become headers."""
packet = make_uwsgi_packet({
'REQUEST_METHOD': 'GET',
'PATH_INFO': '/',
'HTTP_HOST': 'example.com',
'HTTP_USER_AGENT': 'TestClient/1.0',
'HTTP_ACCEPT': 'text/html',
})
unreader = IterUnreader([packet])
cfg = MockConfig()
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
headers_dict = dict(req.headers)
assert headers_dict['HOST'] == 'example.com'
assert headers_dict['USER-AGENT'] == 'TestClient/1.0'
assert headers_dict['ACCEPT'] == 'text/html'
def test_parse_content_type_header(self):
"""Test that CONTENT_TYPE becomes a header."""
packet = make_uwsgi_packet({
'REQUEST_METHOD': 'POST',
'PATH_INFO': '/',
'CONTENT_TYPE': 'application/json',
'CONTENT_LENGTH': '0',
})
unreader = IterUnreader([packet])
cfg = MockConfig()
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
headers_dict = dict(req.headers)
assert headers_dict['CONTENT-TYPE'] == 'application/json'
assert headers_dict['CONTENT-LENGTH'] == '0'
def test_https_scheme(self):
"""Test scheme detection from HTTPS variable."""
packet = make_uwsgi_packet({
'REQUEST_METHOD': 'GET',
'PATH_INFO': '/',
'HTTPS': 'on',
})
unreader = IterUnreader([packet])
cfg = MockConfig()
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
assert req.scheme == 'https'
def test_wsgi_url_scheme(self):
"""Test scheme from wsgi.url_scheme variable."""
packet = make_uwsgi_packet({
'REQUEST_METHOD': 'GET',
'PATH_INFO': '/',
'wsgi.url_scheme': 'https',
})
unreader = IterUnreader([packet])
cfg = MockConfig()
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
assert req.scheme == 'https'
def test_default_values(self):
"""Test default values when vars are missing."""
packet = make_uwsgi_packet({})
unreader = IterUnreader([packet])
cfg = MockConfig()
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
assert req.method == 'GET'
assert req.path == '/'
assert req.query == ''
assert req.uri == '/'
def test_uwsgi_vars_preserved(self):
"""Test that all vars are preserved in uwsgi_vars."""
packet = make_uwsgi_packet({
'REQUEST_METHOD': 'GET',
'PATH_INFO': '/',
'SERVER_NAME': 'localhost',
'SERVER_PORT': '8000',
'CUSTOM_VAR': 'custom_value',
})
unreader = IterUnreader([packet])
cfg = MockConfig()
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
assert req.uwsgi_vars['SERVER_NAME'] == 'localhost'
assert req.uwsgi_vars['SERVER_PORT'] == '8000'
assert req.uwsgi_vars['CUSTOM_VAR'] == 'custom_value'
class TestUWSGIRequestErrors:
"""Test UWSGIRequest error handling."""
def test_incomplete_header(self):
"""Test error on incomplete header."""
unreader = IterUnreader([b'\x00\x00']) # Only 2 bytes
cfg = MockConfig()
with pytest.raises(InvalidUWSGIHeader) as exc_info:
UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
assert 'incomplete header' in str(exc_info.value)
def test_incomplete_vars_block(self):
"""Test error on truncated vars block."""
# Header says 100 bytes of vars, but we only provide 10
header = b'\x00\x64\x00\x00' # modifier1=0, size=100, modifier2=0
unreader = IterUnreader([header + b'1234567890'])
cfg = MockConfig()
with pytest.raises(InvalidUWSGIHeader) as exc_info:
UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
assert 'incomplete vars block' in str(exc_info.value)
def test_unsupported_modifier(self):
"""Test error on non-zero modifier1."""
packet = bytes([1]) + b'\x00\x00\x00' # modifier1=1
unreader = IterUnreader([packet])
cfg = MockConfig()
with pytest.raises(UnsupportedModifier) as exc_info:
UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
assert exc_info.value.modifier == 1
assert exc_info.value.code == 501
def test_truncated_key_size(self):
"""Test error on truncated key size."""
header = b'\x00\x01\x00\x00' # size=1, but need at least 2 bytes for key_size
unreader = IterUnreader([header + b'X'])
cfg = MockConfig()
with pytest.raises(InvalidUWSGIHeader) as exc_info:
UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
assert 'truncated' in str(exc_info.value)
def test_forbidden_ip(self):
"""Test error when source IP not in allow list."""
packet = make_uwsgi_packet({'REQUEST_METHOD': 'GET', 'PATH_INFO': '/'})
unreader = IterUnreader([packet])
cfg = MockConfig(uwsgi_allow_ips=['192.168.1.1'])
with pytest.raises(ForbiddenUWSGIRequest) as exc_info:
UWSGIRequest(cfg, unreader, ('10.0.0.1', 12345))
assert exc_info.value.code == 403
assert '10.0.0.1' in str(exc_info.value)
def test_allowed_ip_wildcard(self):
"""Test that wildcard allows any IP."""
packet = make_uwsgi_packet({'REQUEST_METHOD': 'GET', 'PATH_INFO': '/'})
unreader = IterUnreader([packet])
cfg = MockConfig(uwsgi_allow_ips=['*'])
# Should not raise
req = UWSGIRequest(cfg, unreader, ('10.0.0.1', 12345))
assert req.method == 'GET'
def test_unix_socket_always_allowed(self):
"""Test that UNIX socket connections are always allowed."""
packet = make_uwsgi_packet({'REQUEST_METHOD': 'GET', 'PATH_INFO': '/'})
unreader = IterUnreader([packet])
cfg = MockConfig(uwsgi_allow_ips=['127.0.0.1'])
# UNIX socket has non-tuple peer_addr
req = UWSGIRequest(cfg, unreader, None)
assert req.method == 'GET'
class TestUWSGIRequestConnection:
"""Test connection handling."""
def test_should_close_default(self):
"""Test default keep-alive behavior."""
packet = make_uwsgi_packet({'REQUEST_METHOD': 'GET', 'PATH_INFO': '/'})
unreader = IterUnreader([packet])
cfg = MockConfig()
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
assert req.should_close() is False
def test_should_close_connection_close(self):
"""Test Connection: close header."""
packet = make_uwsgi_packet({
'REQUEST_METHOD': 'GET',
'PATH_INFO': '/',
'HTTP_CONNECTION': 'close',
})
unreader = IterUnreader([packet])
cfg = MockConfig()
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
assert req.should_close() is True
def test_should_close_connection_keepalive(self):
"""Test Connection: keep-alive header."""
packet = make_uwsgi_packet({
'REQUEST_METHOD': 'GET',
'PATH_INFO': '/',
'HTTP_CONNECTION': 'keep-alive',
})
unreader = IterUnreader([packet])
cfg = MockConfig()
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
assert req.should_close() is False
def test_force_close(self):
"""Test force_close method."""
packet = make_uwsgi_packet({'REQUEST_METHOD': 'GET', 'PATH_INFO': '/'})
unreader = IterUnreader([packet])
cfg = MockConfig()
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
req.force_close()
assert req.should_close() is True
class TestUWSGIParser:
"""Test UWSGIParser."""
def test_parser_iteration(self):
"""Test iterating over parser for multiple requests."""
packet = make_uwsgi_packet({
'REQUEST_METHOD': 'GET',
'PATH_INFO': '/test',
'HTTP_CONNECTION': 'close', # Single request
})
cfg = MockConfig()
# Parser expects an iterable source, not an unreader
parser = UWSGIParser(cfg, [packet], ('127.0.0.1', 12345))
req = next(parser)
assert req.method == 'GET'
assert req.path == '/test'
def test_parser_mesg_class(self):
"""Test that parser uses UWSGIRequest."""
assert UWSGIParser.mesg_class is UWSGIRequest
class TestExceptionStrings:
"""Test exception string representations."""
def test_invalid_uwsgi_header_str(self):
exc = InvalidUWSGIHeader("test message")
assert str(exc) == "Invalid uWSGI header: test message"
assert exc.code == 400
def test_unsupported_modifier_str(self):
exc = UnsupportedModifier(5)
assert str(exc) == "Unsupported uWSGI modifier1: 5"
assert exc.code == 501
def test_forbidden_uwsgi_request_str(self):
exc = ForbiddenUWSGIRequest("10.0.0.1")
assert str(exc) == "uWSGI request from '10.0.0.1' not allowed"
assert exc.code == 403
class TestUWSGIBody:
"""Test body reading."""
def test_read_body_in_chunks(self):
"""Test reading body in multiple chunks."""
body = b'A' * 1000
packet = make_uwsgi_packet_with_body({
'REQUEST_METHOD': 'POST',
'PATH_INFO': '/',
}, body)
unreader = IterUnreader([packet])
cfg = MockConfig()
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
result = b''
chunk = req.body.read(100)
while chunk:
result += chunk
chunk = req.body.read(100)
assert result == body
def test_invalid_content_length(self):
"""Test handling of invalid CONTENT_LENGTH."""
packet = make_uwsgi_packet({
'REQUEST_METHOD': 'POST',
'PATH_INFO': '/',
'CONTENT_LENGTH': 'invalid',
})
unreader = IterUnreader([packet])
cfg = MockConfig()
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
# Invalid content length should default to 0
assert req.body.read() == b''
def test_negative_content_length(self):
"""Test handling of negative CONTENT_LENGTH."""
packet = make_uwsgi_packet({
'REQUEST_METHOD': 'POST',
'PATH_INFO': '/',
'CONTENT_LENGTH': '-5',
})
unreader = IterUnreader([packet])
cfg = MockConfig()
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
# Negative content length should default to 0
assert req.body.read() == b''