mirror of
https://github.com/frappe/gunicorn.git
synced 2026-07-04 11:41:32 +08:00
Merge pull request #3444 from benoitc/asgi-worker
Add native ASGI worker and uWSGI binary protocol support
This commit is contained in:
commit
5b50487bab
45
.github/workflows/docker-integration.yml
vendored
Normal file
45
.github/workflows/docker-integration.yml
vendored
Normal file
@ -0,0 +1,45 @@
|
||||
name: Docker Integration Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [master]
|
||||
paths:
|
||||
- 'gunicorn/uwsgi/**'
|
||||
- 'tests/docker/uwsgi/**'
|
||||
- '.github/workflows/docker-integration.yml'
|
||||
pull_request:
|
||||
paths:
|
||||
- 'gunicorn/uwsgi/**'
|
||||
- 'tests/docker/uwsgi/**'
|
||||
- '.github/workflows/docker-integration.yml'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
env:
|
||||
FORCE_COLOR: 1
|
||||
|
||||
jobs:
|
||||
uwsgi-nginx:
|
||||
name: uWSGI Protocol with nginx
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 15
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.12"
|
||||
cache: pip
|
||||
cache-dependency-path: requirements_test.txt
|
||||
|
||||
- name: Install test dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install pytest pytest-cov requests
|
||||
|
||||
- name: Run uWSGI integration tests
|
||||
run: |
|
||||
pytest tests/docker/uwsgi/ -v --tb=short
|
||||
2
.github/workflows/freebsd.yml
vendored
2
.github/workflows/freebsd.yml
vendored
@ -40,7 +40,7 @@ jobs:
|
||||
python${{ matrix.python-version }} -m venv venv
|
||||
. venv/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install pytest pytest-cov coverage
|
||||
pip install pytest pytest-cov pytest-asyncio coverage
|
||||
pip install -e .
|
||||
pytest --cov=gunicorn -v tests/ \
|
||||
--ignore=tests/workers/test_ggevent.py \
|
||||
|
||||
7
examples/asgi/__init__.py
Normal file
7
examples/asgi/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
ASGI example applications for gunicorn.
|
||||
"""
|
||||
130
examples/asgi/basic_app.py
Normal file
130
examples/asgi/basic_app.py
Normal file
@ -0,0 +1,130 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
Basic ASGI application example.
|
||||
|
||||
Run with:
|
||||
gunicorn -k asgi examples.asgi.basic_app:app
|
||||
|
||||
Test with:
|
||||
curl http://127.0.0.1:8000/
|
||||
curl http://127.0.0.1:8000/hello
|
||||
curl -X POST http://127.0.0.1:8000/echo -d "test data"
|
||||
"""
|
||||
|
||||
|
||||
async def app(scope, receive, send):
|
||||
"""Simple ASGI application demonstrating basic functionality."""
|
||||
|
||||
if scope["type"] == "lifespan":
|
||||
await handle_lifespan(scope, receive, send)
|
||||
elif scope["type"] == "http":
|
||||
await handle_http(scope, receive, send)
|
||||
else:
|
||||
raise ValueError(f"Unknown scope type: {scope['type']}")
|
||||
|
||||
|
||||
async def handle_lifespan(scope, receive, send):
|
||||
"""Handle lifespan events (startup/shutdown)."""
|
||||
while True:
|
||||
message = await receive()
|
||||
if message["type"] == "lifespan.startup":
|
||||
print("ASGI application starting up...")
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
elif message["type"] == "lifespan.shutdown":
|
||||
print("ASGI application shutting down...")
|
||||
await send({"type": "lifespan.shutdown.complete"})
|
||||
return
|
||||
|
||||
|
||||
async def handle_http(scope, receive, send):
|
||||
"""Handle HTTP requests."""
|
||||
path = scope["path"]
|
||||
method = scope["method"]
|
||||
|
||||
if path == "/" and method == "GET":
|
||||
await send_response(send, 200, b"Welcome to gunicorn ASGI!\n")
|
||||
|
||||
elif path == "/hello" and method == "GET":
|
||||
name = get_query_param(scope, "name", "World")
|
||||
body = f"Hello, {name}!\n".encode()
|
||||
await send_response(send, 200, body)
|
||||
|
||||
elif path == "/echo" and method == "POST":
|
||||
body = await read_body(receive)
|
||||
await send_response(send, 200, body, content_type=b"application/octet-stream")
|
||||
|
||||
elif path == "/headers":
|
||||
headers_info = format_headers(scope["headers"])
|
||||
await send_response(send, 200, headers_info.encode())
|
||||
|
||||
elif path == "/info":
|
||||
info = format_request_info(scope)
|
||||
await send_response(send, 200, info.encode(), content_type=b"application/json")
|
||||
|
||||
else:
|
||||
await send_response(send, 404, b"Not Found\n")
|
||||
|
||||
|
||||
async def send_response(send, status, body, content_type=b"text/plain"):
|
||||
"""Send an HTTP response."""
|
||||
await send({
|
||||
"type": "http.response.start",
|
||||
"status": status,
|
||||
"headers": [
|
||||
(b"content-type", content_type),
|
||||
(b"content-length", str(len(body)).encode()),
|
||||
],
|
||||
})
|
||||
await send({
|
||||
"type": "http.response.body",
|
||||
"body": body,
|
||||
})
|
||||
|
||||
|
||||
async def read_body(receive):
|
||||
"""Read the full request body."""
|
||||
body = b""
|
||||
while True:
|
||||
message = await receive()
|
||||
body += message.get("body", b"")
|
||||
if not message.get("more_body", False):
|
||||
break
|
||||
return body
|
||||
|
||||
|
||||
def get_query_param(scope, name, default=None):
|
||||
"""Get a query parameter value."""
|
||||
query_string = scope.get("query_string", b"").decode()
|
||||
for param in query_string.split("&"):
|
||||
if "=" in param:
|
||||
key, value = param.split("=", 1)
|
||||
if key == name:
|
||||
return value
|
||||
return default
|
||||
|
||||
|
||||
def format_headers(headers):
|
||||
"""Format headers for display."""
|
||||
lines = ["Request Headers:"]
|
||||
for name, value in headers:
|
||||
lines.append(f" {name.decode()}: {value.decode()}")
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
|
||||
def format_request_info(scope):
|
||||
"""Format request info as JSON."""
|
||||
import json
|
||||
info = {
|
||||
"method": scope["method"],
|
||||
"path": scope["path"],
|
||||
"query_string": scope.get("query_string", b"").decode(),
|
||||
"http_version": scope["http_version"],
|
||||
"scheme": scope["scheme"],
|
||||
"server": list(scope.get("server") or []),
|
||||
"client": list(scope.get("client") or []),
|
||||
"root_path": scope.get("root_path", ""),
|
||||
}
|
||||
return json.dumps(info, indent=2) + "\n"
|
||||
235
examples/asgi/websocket_app.py
Normal file
235
examples/asgi/websocket_app.py
Normal file
@ -0,0 +1,235 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
WebSocket ASGI application example.
|
||||
|
||||
Run with:
|
||||
gunicorn -k asgi examples.asgi.websocket_app:app
|
||||
|
||||
Test with:
|
||||
# Using websocat (install with: cargo install websocat)
|
||||
websocat ws://127.0.0.1:8000/ws
|
||||
|
||||
# Or using Python websockets library
|
||||
python -c "
|
||||
import asyncio
|
||||
import websockets
|
||||
async def test():
|
||||
async with websockets.connect('ws://127.0.0.1:8000/ws') as ws:
|
||||
await ws.send('Hello')
|
||||
print(await ws.recv())
|
||||
asyncio.run(test())
|
||||
"
|
||||
"""
|
||||
|
||||
|
||||
async def app(scope, receive, send):
|
||||
"""ASGI application with WebSocket support."""
|
||||
|
||||
if scope["type"] == "lifespan":
|
||||
await handle_lifespan(scope, receive, send)
|
||||
elif scope["type"] == "http":
|
||||
await handle_http(scope, receive, send)
|
||||
elif scope["type"] == "websocket":
|
||||
await handle_websocket(scope, receive, send)
|
||||
else:
|
||||
raise ValueError(f"Unknown scope type: {scope['type']}")
|
||||
|
||||
|
||||
async def handle_lifespan(scope, receive, send):
|
||||
"""Handle lifespan events."""
|
||||
while True:
|
||||
message = await receive()
|
||||
if message["type"] == "lifespan.startup":
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
elif message["type"] == "lifespan.shutdown":
|
||||
await send({"type": "lifespan.shutdown.complete"})
|
||||
return
|
||||
|
||||
|
||||
async def handle_http(scope, receive, send):
|
||||
"""Handle HTTP requests - serve a simple HTML page for WebSocket testing."""
|
||||
path = scope["path"]
|
||||
|
||||
if path == "/":
|
||||
html = HTML_PAGE.encode()
|
||||
await send({
|
||||
"type": "http.response.start",
|
||||
"status": 200,
|
||||
"headers": [
|
||||
(b"content-type", b"text/html"),
|
||||
(b"content-length", str(len(html)).encode()),
|
||||
],
|
||||
})
|
||||
await send({
|
||||
"type": "http.response.body",
|
||||
"body": html,
|
||||
})
|
||||
else:
|
||||
await send({
|
||||
"type": "http.response.start",
|
||||
"status": 404,
|
||||
"headers": [(b"content-type", b"text/plain")],
|
||||
})
|
||||
await send({
|
||||
"type": "http.response.body",
|
||||
"body": b"Not Found",
|
||||
})
|
||||
|
||||
|
||||
async def handle_websocket(scope, receive, send):
|
||||
"""Handle WebSocket connections."""
|
||||
path = scope["path"]
|
||||
|
||||
if path == "/ws":
|
||||
await echo_websocket(scope, receive, send)
|
||||
elif path == "/ws/chat":
|
||||
await chat_websocket(scope, receive, send)
|
||||
else:
|
||||
# Reject the connection
|
||||
await send({"type": "websocket.close", "code": 4004})
|
||||
|
||||
|
||||
async def echo_websocket(scope, receive, send):
|
||||
"""Echo WebSocket - sends back whatever it receives."""
|
||||
# Wait for connection
|
||||
message = await receive()
|
||||
if message["type"] != "websocket.connect":
|
||||
return
|
||||
|
||||
# Accept the connection
|
||||
await send({"type": "websocket.accept"})
|
||||
|
||||
# Echo loop
|
||||
try:
|
||||
while True:
|
||||
message = await receive()
|
||||
|
||||
if message["type"] == "websocket.disconnect":
|
||||
break
|
||||
|
||||
if message["type"] == "websocket.receive":
|
||||
if "text" in message:
|
||||
# Echo text back
|
||||
await send({
|
||||
"type": "websocket.send",
|
||||
"text": f"Echo: {message['text']}"
|
||||
})
|
||||
elif "bytes" in message:
|
||||
# Echo bytes back
|
||||
await send({
|
||||
"type": "websocket.send",
|
||||
"bytes": message["bytes"]
|
||||
})
|
||||
except Exception as e:
|
||||
print(f"WebSocket error: {e}")
|
||||
finally:
|
||||
try:
|
||||
await send({"type": "websocket.close", "code": 1000})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
async def chat_websocket(scope, receive, send):
|
||||
"""Chat WebSocket - simple broadcast example."""
|
||||
message = await receive()
|
||||
if message["type"] != "websocket.connect":
|
||||
return
|
||||
|
||||
await send({
|
||||
"type": "websocket.accept",
|
||||
"subprotocol": "chat"
|
||||
})
|
||||
|
||||
await send({
|
||||
"type": "websocket.send",
|
||||
"text": "Welcome to the chat! Send messages and they will be echoed back."
|
||||
})
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await receive()
|
||||
|
||||
if message["type"] == "websocket.disconnect":
|
||||
break
|
||||
|
||||
if message["type"] == "websocket.receive" and "text" in message:
|
||||
text = message["text"]
|
||||
await send({
|
||||
"type": "websocket.send",
|
||||
"text": f"[You]: {text}"
|
||||
})
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
HTML_PAGE = """<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>WebSocket Test</title>
|
||||
<style>
|
||||
body { font-family: sans-serif; max-width: 800px; margin: 50px auto; padding: 20px; }
|
||||
#messages { border: 1px solid #ccc; height: 300px; overflow-y: auto; padding: 10px; margin-bottom: 10px; }
|
||||
#input { width: 80%; padding: 10px; }
|
||||
button { padding: 10px 20px; }
|
||||
.sent { color: blue; }
|
||||
.received { color: green; }
|
||||
.error { color: red; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>WebSocket Test</h1>
|
||||
<div id="messages"></div>
|
||||
<input type="text" id="input" placeholder="Type a message...">
|
||||
<button onclick="sendMessage()">Send</button>
|
||||
<button onclick="connectWS()">Connect</button>
|
||||
<button onclick="disconnectWS()">Disconnect</button>
|
||||
|
||||
<script>
|
||||
let ws = null;
|
||||
const messages = document.getElementById('messages');
|
||||
const input = document.getElementById('input');
|
||||
|
||||
function log(msg, className) {
|
||||
const div = document.createElement('div');
|
||||
div.className = className || '';
|
||||
div.textContent = msg;
|
||||
messages.appendChild(div);
|
||||
messages.scrollTop = messages.scrollHeight;
|
||||
}
|
||||
|
||||
function connectWS() {
|
||||
if (ws) {
|
||||
log('Already connected', 'error');
|
||||
return;
|
||||
}
|
||||
ws = new WebSocket('ws://' + window.location.host + '/ws');
|
||||
ws.onopen = () => log('Connected!', 'received');
|
||||
ws.onclose = () => { log('Disconnected', 'error'); ws = null; };
|
||||
ws.onerror = (e) => log('Error: ' + e, 'error');
|
||||
ws.onmessage = (e) => log(e.data, 'received');
|
||||
}
|
||||
|
||||
function disconnectWS() {
|
||||
if (ws) ws.close();
|
||||
}
|
||||
|
||||
function sendMessage() {
|
||||
if (!ws) { log('Not connected', 'error'); return; }
|
||||
const msg = input.value;
|
||||
if (!msg) return;
|
||||
ws.send(msg);
|
||||
log('Sent: ' + msg, 'sent');
|
||||
input.value = '';
|
||||
}
|
||||
|
||||
input.onkeypress = (e) => { if (e.key === 'Enter') sendMessage(); };
|
||||
|
||||
// Auto-connect
|
||||
connectWS();
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
26
gunicorn/asgi/__init__.py
Normal file
26
gunicorn/asgi/__init__.py
Normal file
@ -0,0 +1,26 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
ASGI support for gunicorn.
|
||||
|
||||
This module provides native ASGI worker support, using gunicorn's own
|
||||
HTTP parsing infrastructure adapted for async I/O.
|
||||
|
||||
Components:
|
||||
- AsyncUnreader: Async socket reading with pushback buffer
|
||||
- AsyncRequest: Async HTTP request parser
|
||||
- ASGIProtocol: asyncio.Protocol implementation for HTTP handling
|
||||
- WebSocketProtocol: WebSocket protocol handler (RFC 6455)
|
||||
- LifespanManager: ASGI lifespan protocol support
|
||||
|
||||
Usage:
|
||||
gunicorn -k asgi myapp:app
|
||||
"""
|
||||
|
||||
from gunicorn.asgi.unreader import AsyncUnreader
|
||||
from gunicorn.asgi.message import AsyncRequest
|
||||
from gunicorn.asgi.lifespan import LifespanManager
|
||||
|
||||
__all__ = ['AsyncUnreader', 'AsyncRequest', 'LifespanManager']
|
||||
178
gunicorn/asgi/lifespan.py
Normal file
178
gunicorn/asgi/lifespan.py
Normal file
@ -0,0 +1,178 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
ASGI lifespan protocol manager.
|
||||
|
||||
Manages startup and shutdown events for ASGI applications,
|
||||
enabling frameworks like FastAPI to run initialization and
|
||||
cleanup code.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
|
||||
class LifespanManager:
|
||||
"""Manages ASGI lifespan events (startup/shutdown).
|
||||
|
||||
The lifespan protocol allows ASGI applications to run code at
|
||||
startup and shutdown. This is essential for applications that
|
||||
need to initialize database connections, caches, or other
|
||||
resources.
|
||||
|
||||
ASGI lifespan messages:
|
||||
- Server sends: {"type": "lifespan.startup"}
|
||||
- App responds: {"type": "lifespan.startup.complete"} or
|
||||
{"type": "lifespan.startup.failed", "message": "..."}
|
||||
- Server sends: {"type": "lifespan.shutdown"}
|
||||
- App responds: {"type": "lifespan.shutdown.complete"}
|
||||
"""
|
||||
|
||||
def __init__(self, app, logger, state=None):
|
||||
"""Initialize the lifespan manager.
|
||||
|
||||
Args:
|
||||
app: ASGI application callable
|
||||
logger: Logger instance
|
||||
state: Shared state dict for the application
|
||||
"""
|
||||
self.app = app
|
||||
self.logger = logger
|
||||
self.state = state if state is not None else {}
|
||||
|
||||
self._startup_complete = asyncio.Event()
|
||||
self._shutdown_complete = asyncio.Event()
|
||||
self._startup_failed = False
|
||||
self._startup_error = None
|
||||
self._shutdown_error = None
|
||||
self._receive_queue = asyncio.Queue()
|
||||
self._task = None
|
||||
self._app_finished = False
|
||||
|
||||
async def startup(self):
|
||||
"""Run lifespan startup and wait for completion.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If startup fails or app doesn't support lifespan
|
||||
"""
|
||||
scope = {
|
||||
"type": "lifespan",
|
||||
"asgi": {"version": "3.0", "spec_version": "2.4"},
|
||||
"state": self.state,
|
||||
}
|
||||
|
||||
# Send startup event
|
||||
await self._receive_queue.put({"type": "lifespan.startup"})
|
||||
|
||||
# Run lifespan in background task
|
||||
self._task = asyncio.create_task(self._run_lifespan(scope))
|
||||
|
||||
# Wait for startup with timeout
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._startup_complete.wait(),
|
||||
timeout=30.0 # Reasonable startup timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
raise RuntimeError("Lifespan startup timed out")
|
||||
|
||||
if self._startup_failed:
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
msg = self._startup_error or "Unknown error"
|
||||
raise RuntimeError(f"Lifespan startup failed: {msg}")
|
||||
|
||||
self.logger.debug("ASGI lifespan startup complete")
|
||||
|
||||
async def shutdown(self):
|
||||
"""Signal shutdown and wait for completion.
|
||||
|
||||
This should be called during graceful shutdown.
|
||||
"""
|
||||
if self._app_finished:
|
||||
self.logger.debug("ASGI lifespan already finished")
|
||||
return
|
||||
|
||||
# Send shutdown event
|
||||
await self._receive_queue.put({"type": "lifespan.shutdown"})
|
||||
|
||||
# Wait for shutdown with timeout
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._shutdown_complete.wait(),
|
||||
timeout=30.0 # Reasonable shutdown timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
self.logger.warning("Lifespan shutdown timed out")
|
||||
|
||||
if self._shutdown_error:
|
||||
self.logger.error("Lifespan shutdown error: %s", self._shutdown_error)
|
||||
|
||||
# Cancel the task if still running
|
||||
if self._task and not self._task.done():
|
||||
self._task.cancel()
|
||||
try:
|
||||
await self._task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
self.logger.debug("ASGI lifespan shutdown complete")
|
||||
|
||||
async def _run_lifespan(self, scope):
|
||||
"""Run the ASGI lifespan protocol."""
|
||||
try:
|
||||
await self.app(scope, self._receive, self._send)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.debug("Lifespan application raised: %s", e)
|
||||
# If startup hasn't completed, mark it as failed
|
||||
if not self._startup_complete.is_set():
|
||||
self._startup_failed = True
|
||||
self._startup_error = str(e)
|
||||
self._startup_complete.set()
|
||||
# If shutdown hasn't completed, mark error
|
||||
elif not self._shutdown_complete.is_set():
|
||||
self._shutdown_error = str(e)
|
||||
self._shutdown_complete.set()
|
||||
finally:
|
||||
self._app_finished = True
|
||||
# Ensure events are set to unblock waiters
|
||||
if not self._startup_complete.is_set():
|
||||
self._startup_failed = True
|
||||
self._startup_error = "Application exited before startup complete"
|
||||
self._startup_complete.set()
|
||||
if not self._shutdown_complete.is_set():
|
||||
self._shutdown_complete.set()
|
||||
|
||||
async def _receive(self):
|
||||
"""ASGI receive callable for lifespan."""
|
||||
return await self._receive_queue.get()
|
||||
|
||||
async def _send(self, message):
|
||||
"""ASGI send callable for lifespan."""
|
||||
msg_type = message["type"]
|
||||
|
||||
if msg_type == "lifespan.startup.complete":
|
||||
self._startup_complete.set()
|
||||
self.logger.debug("Received lifespan.startup.complete")
|
||||
|
||||
elif msg_type == "lifespan.startup.failed":
|
||||
self._startup_failed = True
|
||||
self._startup_error = message.get("message", "")
|
||||
self._startup_complete.set()
|
||||
self.logger.debug("Received lifespan.startup.failed: %s",
|
||||
self._startup_error)
|
||||
|
||||
elif msg_type == "lifespan.shutdown.complete":
|
||||
self._shutdown_complete.set()
|
||||
self.logger.debug("Received lifespan.shutdown.complete")
|
||||
|
||||
elif msg_type == "lifespan.shutdown.failed":
|
||||
self._shutdown_error = message.get("message", "")
|
||||
self._shutdown_complete.set()
|
||||
self.logger.debug("Received lifespan.shutdown.failed: %s",
|
||||
self._shutdown_error)
|
||||
562
gunicorn/asgi/message.py
Normal file
562
gunicorn/asgi/message.py
Normal file
@ -0,0 +1,562 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
Async version of gunicorn/http/message.py for ASGI workers.
|
||||
|
||||
Reuses the parsing logic from the sync version, adapted for async I/O.
|
||||
"""
|
||||
|
||||
import io
|
||||
import re
|
||||
import socket
|
||||
|
||||
from gunicorn.http.errors import (
|
||||
InvalidHeader, InvalidHeaderName, NoMoreData,
|
||||
InvalidRequestLine, InvalidRequestMethod, InvalidHTTPVersion,
|
||||
LimitRequestLine, LimitRequestHeaders,
|
||||
UnsupportedTransferCoding, ObsoleteFolding,
|
||||
InvalidProxyLine, ForbiddenProxyRequest,
|
||||
InvalidSchemeHeaders,
|
||||
)
|
||||
from gunicorn.util import bytes_to_str, split_request_uri
|
||||
|
||||
MAX_REQUEST_LINE = 8190
|
||||
MAX_HEADERS = 32768
|
||||
DEFAULT_MAX_HEADERFIELD_SIZE = 8190
|
||||
|
||||
# Reuse regex patterns from sync version
|
||||
RFC9110_5_6_2_TOKEN_SPECIALS = r"!#$%&'*+-.^_`|~"
|
||||
TOKEN_RE = re.compile(r"[%s0-9a-zA-Z]+" % (re.escape(RFC9110_5_6_2_TOKEN_SPECIALS)))
|
||||
METHOD_BADCHAR_RE = re.compile("[a-z#]")
|
||||
VERSION_RE = re.compile(r"HTTP/(\d)\.(\d)")
|
||||
RFC9110_5_5_INVALID_AND_DANGEROUS = re.compile(r"[\0\r\n]")
|
||||
|
||||
|
||||
class AsyncRequest:
|
||||
"""Async HTTP request parser.
|
||||
|
||||
Parses HTTP/1.x requests using async I/O, reusing gunicorn's
|
||||
parsing logic where possible.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, unreader, peer_addr, req_number=1):
|
||||
self.cfg = cfg
|
||||
self.unreader = unreader
|
||||
self.peer_addr = peer_addr
|
||||
self.remote_addr = peer_addr
|
||||
self.req_number = req_number
|
||||
|
||||
self.version = None
|
||||
self.method = None
|
||||
self.uri = None
|
||||
self.path = None
|
||||
self.query = None
|
||||
self.fragment = None
|
||||
self.headers = []
|
||||
self.trailers = []
|
||||
self.scheme = "https" if cfg.is_ssl else "http"
|
||||
self.must_close = False
|
||||
|
||||
self.proxy_protocol_info = None
|
||||
|
||||
# Request line limit
|
||||
self.limit_request_line = cfg.limit_request_line
|
||||
if (self.limit_request_line < 0
|
||||
or self.limit_request_line >= MAX_REQUEST_LINE):
|
||||
self.limit_request_line = MAX_REQUEST_LINE
|
||||
|
||||
# Headers limits
|
||||
self.limit_request_fields = cfg.limit_request_fields
|
||||
if (self.limit_request_fields <= 0
|
||||
or self.limit_request_fields > MAX_HEADERS):
|
||||
self.limit_request_fields = MAX_HEADERS
|
||||
|
||||
self.limit_request_field_size = cfg.limit_request_field_size
|
||||
if self.limit_request_field_size < 0:
|
||||
self.limit_request_field_size = DEFAULT_MAX_HEADERFIELD_SIZE
|
||||
|
||||
# Max header buffer size
|
||||
max_header_field_size = self.limit_request_field_size or DEFAULT_MAX_HEADERFIELD_SIZE
|
||||
self.max_buffer_headers = self.limit_request_fields * \
|
||||
(max_header_field_size + 2) + 4
|
||||
|
||||
# Body-related state
|
||||
self.content_length = None
|
||||
self.chunked = False
|
||||
self._body_reader = None
|
||||
self._body_remaining = 0
|
||||
|
||||
@classmethod
|
||||
async def parse(cls, cfg, unreader, peer_addr, req_number=1):
|
||||
"""Parse an HTTP request from the stream.
|
||||
|
||||
Args:
|
||||
cfg: gunicorn config object
|
||||
unreader: AsyncUnreader instance
|
||||
peer_addr: client address tuple
|
||||
req_number: request number on this connection (for keepalive)
|
||||
|
||||
Returns:
|
||||
AsyncRequest: Parsed request object
|
||||
|
||||
Raises:
|
||||
NoMoreData: If no data available
|
||||
Various parsing errors for malformed requests
|
||||
"""
|
||||
req = cls(cfg, unreader, peer_addr, req_number)
|
||||
await req._parse()
|
||||
return req
|
||||
|
||||
async def _parse(self):
|
||||
"""Parse the request from the unreader."""
|
||||
buf = io.BytesIO()
|
||||
await self._get_data(buf, stop=True)
|
||||
|
||||
# Get request line
|
||||
line, rbuf = await self._read_line(buf, self.limit_request_line)
|
||||
|
||||
# Proxy protocol
|
||||
if self._proxy_protocol(bytes_to_str(line)):
|
||||
# Get next request line
|
||||
buf = io.BytesIO()
|
||||
buf.write(rbuf)
|
||||
line, rbuf = await self._read_line(buf, self.limit_request_line)
|
||||
|
||||
self._parse_request_line(line)
|
||||
buf = io.BytesIO()
|
||||
buf.write(rbuf)
|
||||
|
||||
# Headers
|
||||
data = buf.getvalue()
|
||||
|
||||
while True:
|
||||
idx = data.find(b"\r\n\r\n")
|
||||
done = data[:2] == b"\r\n"
|
||||
|
||||
if idx < 0 and not done:
|
||||
await self._get_data(buf)
|
||||
data = buf.getvalue()
|
||||
if len(data) > self.max_buffer_headers:
|
||||
raise LimitRequestHeaders("max buffer headers")
|
||||
else:
|
||||
break
|
||||
|
||||
if done:
|
||||
self.unreader.unread(data[2:])
|
||||
else:
|
||||
self.headers = self._parse_headers(data[:idx], from_trailer=False)
|
||||
self.unreader.unread(data[idx + 4:])
|
||||
|
||||
self._set_body_reader()
|
||||
|
||||
async def _get_data(self, buf, stop=False):
|
||||
"""Read data from unreader into buffer."""
|
||||
data = await self.unreader.read()
|
||||
if not data:
|
||||
if stop:
|
||||
raise StopIteration()
|
||||
raise NoMoreData(buf.getvalue())
|
||||
buf.write(data)
|
||||
|
||||
async def _read_line(self, buf, limit=0):
|
||||
"""Read a line from the buffer/stream."""
|
||||
data = buf.getvalue()
|
||||
|
||||
while True:
|
||||
idx = data.find(b"\r\n")
|
||||
if idx >= 0:
|
||||
if idx > limit > 0:
|
||||
raise LimitRequestLine(idx, limit)
|
||||
break
|
||||
if len(data) - 2 > limit > 0:
|
||||
raise LimitRequestLine(len(data), limit)
|
||||
await self._get_data(buf)
|
||||
data = buf.getvalue()
|
||||
|
||||
return (data[:idx], data[idx + 2:])
|
||||
|
||||
def _proxy_protocol(self, line):
|
||||
"""Detect, check and parse proxy protocol."""
|
||||
if not self.cfg.proxy_protocol:
|
||||
return False
|
||||
|
||||
if self.req_number != 1:
|
||||
return False
|
||||
|
||||
if not line.startswith("PROXY"):
|
||||
return False
|
||||
|
||||
self._proxy_protocol_access_check()
|
||||
self._parse_proxy_protocol(line)
|
||||
|
||||
return True
|
||||
|
||||
def _proxy_protocol_access_check(self):
|
||||
"""Check if proxy protocol is allowed from this peer."""
|
||||
if ("*" not in self.cfg.proxy_allow_ips and
|
||||
isinstance(self.peer_addr, tuple) and
|
||||
self.peer_addr[0] not in self.cfg.proxy_allow_ips):
|
||||
raise ForbiddenProxyRequest(self.peer_addr[0])
|
||||
|
||||
def _parse_proxy_protocol(self, line):
|
||||
"""Parse proxy protocol header line."""
|
||||
bits = line.split(" ")
|
||||
|
||||
if len(bits) != 6:
|
||||
raise InvalidProxyLine(line)
|
||||
|
||||
proto = bits[1]
|
||||
s_addr = bits[2]
|
||||
d_addr = bits[3]
|
||||
|
||||
if proto not in ["TCP4", "TCP6"]:
|
||||
raise InvalidProxyLine("protocol '%s' not supported" % proto)
|
||||
|
||||
if proto == "TCP4":
|
||||
try:
|
||||
socket.inet_pton(socket.AF_INET, s_addr)
|
||||
socket.inet_pton(socket.AF_INET, d_addr)
|
||||
except OSError:
|
||||
raise InvalidProxyLine(line)
|
||||
elif proto == "TCP6":
|
||||
try:
|
||||
socket.inet_pton(socket.AF_INET6, s_addr)
|
||||
socket.inet_pton(socket.AF_INET6, d_addr)
|
||||
except OSError:
|
||||
raise InvalidProxyLine(line)
|
||||
|
||||
try:
|
||||
s_port = int(bits[4])
|
||||
d_port = int(bits[5])
|
||||
except ValueError:
|
||||
raise InvalidProxyLine("invalid port %s" % line)
|
||||
|
||||
if not ((0 <= s_port <= 65535) and (0 <= d_port <= 65535)):
|
||||
raise InvalidProxyLine("invalid port %s" % line)
|
||||
|
||||
self.proxy_protocol_info = {
|
||||
"proxy_protocol": proto,
|
||||
"client_addr": s_addr,
|
||||
"client_port": s_port,
|
||||
"proxy_addr": d_addr,
|
||||
"proxy_port": d_port
|
||||
}
|
||||
|
||||
def _parse_request_line(self, line_bytes):
|
||||
"""Parse the HTTP request line."""
|
||||
bits = [bytes_to_str(bit) for bit in line_bytes.split(b" ", 2)]
|
||||
if len(bits) != 3:
|
||||
raise InvalidRequestLine(bytes_to_str(line_bytes))
|
||||
|
||||
# Method
|
||||
self.method = bits[0]
|
||||
|
||||
if not self.cfg.permit_unconventional_http_method:
|
||||
if METHOD_BADCHAR_RE.search(self.method):
|
||||
raise InvalidRequestMethod(self.method)
|
||||
if not 3 <= len(bits[0]) <= 20:
|
||||
raise InvalidRequestMethod(self.method)
|
||||
if not TOKEN_RE.fullmatch(self.method):
|
||||
raise InvalidRequestMethod(self.method)
|
||||
if self.cfg.casefold_http_method:
|
||||
self.method = self.method.upper()
|
||||
|
||||
# URI
|
||||
self.uri = bits[1]
|
||||
|
||||
if len(self.uri) == 0:
|
||||
raise InvalidRequestLine(bytes_to_str(line_bytes))
|
||||
|
||||
try:
|
||||
parts = split_request_uri(self.uri)
|
||||
except ValueError:
|
||||
raise InvalidRequestLine(bytes_to_str(line_bytes))
|
||||
self.path = parts.path or ""
|
||||
self.query = parts.query or ""
|
||||
self.fragment = parts.fragment or ""
|
||||
|
||||
# Version
|
||||
match = VERSION_RE.fullmatch(bits[2])
|
||||
if match is None:
|
||||
raise InvalidHTTPVersion(bits[2])
|
||||
self.version = (int(match.group(1)), int(match.group(2)))
|
||||
if not (1, 0) <= self.version < (2, 0):
|
||||
if not self.cfg.permit_unconventional_http_version:
|
||||
raise InvalidHTTPVersion(self.version)
|
||||
|
||||
def _parse_headers(self, data, from_trailer=False):
|
||||
"""Parse HTTP headers from raw data."""
|
||||
cfg = self.cfg
|
||||
headers = []
|
||||
|
||||
lines = [bytes_to_str(line) for line in data.split(b"\r\n")]
|
||||
|
||||
# Handle scheme headers
|
||||
scheme_header = False
|
||||
secure_scheme_headers = {}
|
||||
forwarder_headers = []
|
||||
if from_trailer:
|
||||
pass
|
||||
elif ('*' in cfg.forwarded_allow_ips or
|
||||
not isinstance(self.peer_addr, tuple)
|
||||
or self.peer_addr[0] in cfg.forwarded_allow_ips):
|
||||
secure_scheme_headers = cfg.secure_scheme_headers
|
||||
forwarder_headers = cfg.forwarder_headers
|
||||
|
||||
while lines:
|
||||
if len(headers) >= self.limit_request_fields:
|
||||
raise LimitRequestHeaders("limit request headers fields")
|
||||
|
||||
curr = lines.pop(0)
|
||||
header_length = len(curr) + len("\r\n")
|
||||
if curr.find(":") <= 0:
|
||||
raise InvalidHeader(curr)
|
||||
name, value = curr.split(":", 1)
|
||||
if self.cfg.strip_header_spaces:
|
||||
name = name.rstrip(" \t")
|
||||
if not TOKEN_RE.fullmatch(name):
|
||||
raise InvalidHeaderName(name)
|
||||
|
||||
name = name.upper()
|
||||
value = [value.strip(" \t")]
|
||||
|
||||
# Consume value continuation lines
|
||||
while lines and lines[0].startswith((" ", "\t")):
|
||||
if not self.cfg.permit_obsolete_folding:
|
||||
raise ObsoleteFolding(name)
|
||||
curr = lines.pop(0)
|
||||
header_length += len(curr) + len("\r\n")
|
||||
if header_length > self.limit_request_field_size > 0:
|
||||
raise LimitRequestHeaders("limit request headers fields size")
|
||||
value.append(curr.strip("\t "))
|
||||
value = " ".join(value)
|
||||
|
||||
if RFC9110_5_5_INVALID_AND_DANGEROUS.search(value):
|
||||
raise InvalidHeader(name)
|
||||
|
||||
if header_length > self.limit_request_field_size > 0:
|
||||
raise LimitRequestHeaders("limit request headers fields size")
|
||||
|
||||
if name in secure_scheme_headers:
|
||||
secure = value == secure_scheme_headers[name]
|
||||
scheme = "https" if secure else "http"
|
||||
if scheme_header:
|
||||
if scheme != self.scheme:
|
||||
raise InvalidSchemeHeaders()
|
||||
else:
|
||||
scheme_header = True
|
||||
self.scheme = scheme
|
||||
|
||||
if "_" in name:
|
||||
if name in forwarder_headers or "*" in forwarder_headers:
|
||||
pass
|
||||
elif self.cfg.header_map == "dangerous":
|
||||
pass
|
||||
elif self.cfg.header_map == "drop":
|
||||
continue
|
||||
else:
|
||||
raise InvalidHeaderName(name)
|
||||
|
||||
headers.append((name, value))
|
||||
|
||||
return headers
|
||||
|
||||
def _set_body_reader(self):
|
||||
"""Determine how to read the request body."""
|
||||
chunked = False
|
||||
content_length = None
|
||||
|
||||
for (name, value) in self.headers:
|
||||
if name == "CONTENT-LENGTH":
|
||||
if content_length is not None:
|
||||
raise InvalidHeader("CONTENT-LENGTH", req=self)
|
||||
content_length = value
|
||||
elif name == "TRANSFER-ENCODING":
|
||||
vals = [v.strip() for v in value.split(',')]
|
||||
for val in vals:
|
||||
if val.lower() == "chunked":
|
||||
if chunked:
|
||||
raise InvalidHeader("TRANSFER-ENCODING", req=self)
|
||||
chunked = True
|
||||
elif val.lower() == "identity":
|
||||
if chunked:
|
||||
raise InvalidHeader("TRANSFER-ENCODING", req=self)
|
||||
elif val.lower() in ('compress', 'deflate', 'gzip'):
|
||||
if chunked:
|
||||
raise InvalidHeader("TRANSFER-ENCODING", req=self)
|
||||
self.force_close()
|
||||
else:
|
||||
raise UnsupportedTransferCoding(value)
|
||||
|
||||
if chunked:
|
||||
if self.version < (1, 1):
|
||||
raise InvalidHeader("TRANSFER-ENCODING", req=self)
|
||||
if content_length is not None:
|
||||
raise InvalidHeader("CONTENT-LENGTH", req=self)
|
||||
self.chunked = True
|
||||
self.content_length = None
|
||||
self._body_remaining = -1
|
||||
elif content_length is not None:
|
||||
try:
|
||||
if str(content_length).isnumeric():
|
||||
content_length = int(content_length)
|
||||
else:
|
||||
raise InvalidHeader("CONTENT-LENGTH", req=self)
|
||||
except ValueError:
|
||||
raise InvalidHeader("CONTENT-LENGTH", req=self)
|
||||
|
||||
if content_length < 0:
|
||||
raise InvalidHeader("CONTENT-LENGTH", req=self)
|
||||
|
||||
self.content_length = content_length
|
||||
self._body_remaining = content_length
|
||||
else:
|
||||
# No body for requests without Content-Length or Transfer-Encoding
|
||||
self.content_length = 0
|
||||
self._body_remaining = 0
|
||||
|
||||
def force_close(self):
|
||||
"""Mark connection for closing after this request."""
|
||||
self.must_close = True
|
||||
|
||||
def should_close(self):
|
||||
"""Check if connection should be closed after this request."""
|
||||
if self.must_close:
|
||||
return True
|
||||
for (h, v) in self.headers:
|
||||
if h == "CONNECTION":
|
||||
v = v.lower().strip(" \t")
|
||||
if v == "close":
|
||||
return True
|
||||
elif v == "keep-alive":
|
||||
return False
|
||||
break
|
||||
return self.version <= (1, 0)
|
||||
|
||||
def get_header(self, name):
|
||||
"""Get a header value by name (case-insensitive)."""
|
||||
name = name.upper()
|
||||
for (h, v) in self.headers:
|
||||
if h == name:
|
||||
return v
|
||||
return None
|
||||
|
||||
async def read_body(self, size=8192):
|
||||
"""Read a chunk of the request body.
|
||||
|
||||
Args:
|
||||
size: Maximum bytes to read
|
||||
|
||||
Returns:
|
||||
bytes: Body data, empty bytes when body is exhausted
|
||||
"""
|
||||
if self._body_remaining == 0:
|
||||
return b""
|
||||
|
||||
if self.chunked:
|
||||
return await self._read_chunked_body(size)
|
||||
else:
|
||||
return await self._read_length_body(size)
|
||||
|
||||
async def _read_length_body(self, size):
|
||||
"""Read from a length-delimited body."""
|
||||
if self._body_remaining <= 0:
|
||||
return b""
|
||||
|
||||
to_read = min(size, self._body_remaining)
|
||||
data = await self.unreader.read(to_read)
|
||||
if data:
|
||||
self._body_remaining -= len(data)
|
||||
return data
|
||||
|
||||
async def _read_chunked_body(self, size):
|
||||
"""Read from a chunked body."""
|
||||
if self._body_reader is None:
|
||||
self._body_reader = self._chunked_body_reader()
|
||||
|
||||
try:
|
||||
return await anext(self._body_reader)
|
||||
except StopAsyncIteration:
|
||||
self._body_remaining = 0
|
||||
return b""
|
||||
|
||||
async def _chunked_body_reader(self):
|
||||
"""Async generator for reading chunked body."""
|
||||
while True:
|
||||
# Read chunk size line
|
||||
size_line = await self._read_chunk_size_line()
|
||||
# Parse chunk size (handle extensions)
|
||||
chunk_size, *_ = size_line.split(b";", 1)
|
||||
if _:
|
||||
chunk_size = chunk_size.rstrip(b" \t")
|
||||
|
||||
if any(n not in b"0123456789abcdefABCDEF" for n in chunk_size):
|
||||
raise InvalidHeader("Invalid chunk size")
|
||||
if len(chunk_size) == 0:
|
||||
raise InvalidHeader("Invalid chunk size")
|
||||
|
||||
chunk_size = int(chunk_size, 16)
|
||||
|
||||
if chunk_size == 0:
|
||||
# Final chunk - skip trailers and final CRLF
|
||||
await self._skip_trailers()
|
||||
return
|
||||
|
||||
# Read chunk data
|
||||
remaining = chunk_size
|
||||
while remaining > 0:
|
||||
data = await self.unreader.read(min(remaining, 8192))
|
||||
if not data:
|
||||
raise NoMoreData()
|
||||
remaining -= len(data)
|
||||
yield data
|
||||
|
||||
# Skip chunk terminating CRLF
|
||||
crlf = await self.unreader.read(2)
|
||||
if crlf != b"\r\n":
|
||||
# May have partial read, try to get the rest
|
||||
while len(crlf) < 2:
|
||||
more = await self.unreader.read(2 - len(crlf))
|
||||
if not more:
|
||||
break
|
||||
crlf += more
|
||||
if crlf != b"\r\n":
|
||||
raise InvalidHeader("Missing chunk terminator")
|
||||
|
||||
async def _read_chunk_size_line(self):
|
||||
"""Read a chunk size line."""
|
||||
buf = io.BytesIO()
|
||||
while True:
|
||||
data = await self.unreader.read(1)
|
||||
if not data:
|
||||
raise NoMoreData()
|
||||
buf.write(data)
|
||||
if buf.getvalue().endswith(b"\r\n"):
|
||||
return buf.getvalue()[:-2]
|
||||
|
||||
async def _skip_trailers(self):
|
||||
"""Skip trailer headers after chunked body."""
|
||||
buf = io.BytesIO()
|
||||
while True:
|
||||
data = await self.unreader.read(1)
|
||||
if not data:
|
||||
return
|
||||
buf.write(data)
|
||||
content = buf.getvalue()
|
||||
if content.endswith(b"\r\n\r\n"):
|
||||
# Could parse trailers here if needed
|
||||
return
|
||||
if content == b"\r\n":
|
||||
return
|
||||
|
||||
async def drain_body(self):
|
||||
"""Drain any unread body data.
|
||||
|
||||
Should be called before reusing connection for keepalive.
|
||||
"""
|
||||
while True:
|
||||
data = await self.read_body(8192)
|
||||
if not data:
|
||||
break
|
||||
470
gunicorn/asgi/protocol.py
Normal file
470
gunicorn/asgi/protocol.py
Normal file
@ -0,0 +1,470 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
ASGI protocol handler for gunicorn.
|
||||
|
||||
Implements asyncio.Protocol to handle HTTP/1.x connections and dispatch
|
||||
to ASGI applications.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
|
||||
from gunicorn.asgi.unreader import AsyncUnreader
|
||||
from gunicorn.asgi.message import AsyncRequest
|
||||
from gunicorn.http.errors import NoMoreData
|
||||
|
||||
|
||||
class ASGIResponseInfo:
|
||||
"""Simple container for ASGI response info for access logging."""
|
||||
|
||||
def __init__(self, status, headers, sent):
|
||||
self.status = status
|
||||
self.sent = sent
|
||||
# Convert headers to list of string tuples for logging
|
||||
self.headers = []
|
||||
for name, value in headers:
|
||||
if isinstance(name, bytes):
|
||||
name = name.decode("latin-1")
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode("latin-1")
|
||||
self.headers.append((name, value))
|
||||
|
||||
|
||||
class ASGIProtocol(asyncio.Protocol):
|
||||
"""HTTP/1.1 protocol handler for ASGI applications.
|
||||
|
||||
Handles connection lifecycle, request parsing, and ASGI app invocation.
|
||||
"""
|
||||
|
||||
def __init__(self, worker):
|
||||
self.worker = worker
|
||||
self.cfg = worker.cfg
|
||||
self.log = worker.log
|
||||
self.app = worker.asgi
|
||||
|
||||
self.transport = None
|
||||
self.reader = None
|
||||
self.writer = None
|
||||
self._task = None
|
||||
self.req_count = 0
|
||||
|
||||
# Connection state
|
||||
self._closed = False
|
||||
|
||||
def connection_made(self, transport):
|
||||
"""Called when a connection is established."""
|
||||
self.transport = transport
|
||||
self.worker.nr_conns += 1
|
||||
|
||||
# Create stream reader/writer
|
||||
self.reader = asyncio.StreamReader()
|
||||
self.writer = transport
|
||||
|
||||
# Start handling requests
|
||||
self._task = self.worker.loop.create_task(self._handle_connection())
|
||||
|
||||
def data_received(self, data):
|
||||
"""Called when data is received on the connection."""
|
||||
if self.reader:
|
||||
self.reader.feed_data(data)
|
||||
|
||||
def connection_lost(self, exc):
|
||||
"""Called when the connection is lost or closed."""
|
||||
self._closed = True
|
||||
self.worker.nr_conns -= 1
|
||||
if self.reader:
|
||||
self.reader.feed_eof()
|
||||
if self._task and not self._task.done():
|
||||
self._task.cancel()
|
||||
|
||||
async def _handle_connection(self):
|
||||
"""Main request handling loop for this connection."""
|
||||
unreader = AsyncUnreader(self.reader)
|
||||
|
||||
try:
|
||||
peername = self.transport.get_extra_info('peername')
|
||||
sockname = self.transport.get_extra_info('sockname')
|
||||
|
||||
while not self._closed:
|
||||
self.req_count += 1
|
||||
|
||||
try:
|
||||
# Parse HTTP request
|
||||
request = await AsyncRequest.parse(
|
||||
self.cfg,
|
||||
unreader,
|
||||
peername,
|
||||
self.req_count
|
||||
)
|
||||
except StopIteration:
|
||||
# No more data, close connection
|
||||
break
|
||||
except NoMoreData:
|
||||
# Client disconnected
|
||||
break
|
||||
|
||||
# Check for WebSocket upgrade
|
||||
if self._is_websocket_upgrade(request):
|
||||
await self._handle_websocket(request, sockname, peername)
|
||||
break # WebSocket takes over the connection
|
||||
|
||||
# Handle HTTP request
|
||||
keepalive = await self._handle_http_request(
|
||||
request, sockname, peername
|
||||
)
|
||||
|
||||
# Increment worker request count
|
||||
self.worker.nr += 1
|
||||
|
||||
# Check max_requests
|
||||
if self.worker.nr >= self.worker.max_requests:
|
||||
self.log.info("Autorestarting worker after current request.")
|
||||
self.worker.alive = False
|
||||
keepalive = False
|
||||
|
||||
if not keepalive or not self.worker.alive:
|
||||
break
|
||||
|
||||
# Check connection limits for keepalive
|
||||
if not self.cfg.keepalive:
|
||||
break
|
||||
|
||||
# Drain any unread body before next request
|
||||
await request.drain_body()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
self.log.exception("Error handling connection: %s", e)
|
||||
finally:
|
||||
self._close_transport()
|
||||
|
||||
def _is_websocket_upgrade(self, request):
|
||||
"""Check if request is a WebSocket upgrade.
|
||||
|
||||
Per RFC 6455 Section 4.1, the opening handshake requires:
|
||||
- HTTP method MUST be GET
|
||||
- Upgrade header MUST be "websocket" (case-insensitive)
|
||||
- Connection header MUST contain "Upgrade"
|
||||
"""
|
||||
# RFC 6455: The method of the request MUST be GET
|
||||
if request.method != "GET":
|
||||
return False
|
||||
|
||||
upgrade = None
|
||||
connection = None
|
||||
for name, value in request.headers:
|
||||
if name == "UPGRADE":
|
||||
upgrade = value.lower()
|
||||
elif name == "CONNECTION":
|
||||
connection = value.lower()
|
||||
return upgrade == "websocket" and connection and "upgrade" in connection
|
||||
|
||||
async def _handle_websocket(self, request, sockname, peername):
|
||||
"""Handle WebSocket upgrade request."""
|
||||
from gunicorn.asgi.websocket import WebSocketProtocol
|
||||
|
||||
scope = self._build_websocket_scope(request, sockname, peername)
|
||||
ws_protocol = WebSocketProtocol(
|
||||
self.transport, self.reader, scope, self.app, self.log
|
||||
)
|
||||
await ws_protocol.run()
|
||||
|
||||
async def _handle_http_request(self, request, sockname, peername):
|
||||
"""Handle a single HTTP request."""
|
||||
scope = self._build_http_scope(request, sockname, peername)
|
||||
response_started = False
|
||||
response_complete = False
|
||||
exc_to_raise = None
|
||||
|
||||
# Response tracking for access logging
|
||||
response_status = 500
|
||||
response_headers = []
|
||||
response_sent = 0
|
||||
|
||||
# Receive queue for body
|
||||
receive_queue = asyncio.Queue()
|
||||
|
||||
# Pre-populate with initial body state
|
||||
if request.content_length == 0 and not request.chunked:
|
||||
await receive_queue.put({
|
||||
"type": "http.request",
|
||||
"body": b"",
|
||||
"more_body": False,
|
||||
})
|
||||
else:
|
||||
# Start body reading task
|
||||
asyncio.create_task(self._read_body_to_queue(request, receive_queue))
|
||||
|
||||
async def receive():
|
||||
return await receive_queue.get()
|
||||
|
||||
async def send(message):
|
||||
nonlocal response_started, response_complete, exc_to_raise
|
||||
nonlocal response_status, response_headers, response_sent
|
||||
|
||||
msg_type = message["type"]
|
||||
|
||||
if msg_type == "http.response.start":
|
||||
if response_started:
|
||||
exc_to_raise = RuntimeError("Response already started")
|
||||
return
|
||||
response_started = True
|
||||
response_status = message["status"]
|
||||
response_headers = message.get("headers", [])
|
||||
await self._send_response_start(response_status, response_headers, request)
|
||||
|
||||
elif msg_type == "http.response.body":
|
||||
if not response_started:
|
||||
exc_to_raise = RuntimeError("Response not started")
|
||||
return
|
||||
if response_complete:
|
||||
exc_to_raise = RuntimeError("Response already complete")
|
||||
return
|
||||
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
|
||||
if body:
|
||||
await self._send_body(body)
|
||||
response_sent += len(body)
|
||||
|
||||
if not more_body:
|
||||
response_complete = True
|
||||
|
||||
# Build environ for logging
|
||||
environ = self._build_environ(request, sockname, peername)
|
||||
resp = None
|
||||
|
||||
try:
|
||||
request_start = datetime.now()
|
||||
self.cfg.pre_request(self.worker, request)
|
||||
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
if exc_to_raise is not None:
|
||||
raise exc_to_raise
|
||||
|
||||
# Ensure response was sent
|
||||
if not response_started:
|
||||
await self._send_error_response(500, "Internal Server Error")
|
||||
response_status = 500
|
||||
|
||||
except Exception:
|
||||
self.log.exception("Error in ASGI application")
|
||||
if not response_started:
|
||||
await self._send_error_response(500, "Internal Server Error")
|
||||
response_status = 500
|
||||
return False
|
||||
finally:
|
||||
try:
|
||||
request_time = datetime.now() - request_start
|
||||
# Create response info for logging
|
||||
resp = ASGIResponseInfo(response_status, response_headers, response_sent)
|
||||
self.log.access(resp, request, environ, request_time)
|
||||
self.cfg.post_request(self.worker, request, environ, resp)
|
||||
except Exception:
|
||||
self.log.exception("Exception in post_request hook")
|
||||
|
||||
# Determine keepalive
|
||||
if request.should_close():
|
||||
return False
|
||||
|
||||
return self.worker.alive and self.cfg.keepalive
|
||||
|
||||
async def _read_body_to_queue(self, request, queue):
|
||||
"""Read request body and put chunks on the queue."""
|
||||
try:
|
||||
while True:
|
||||
chunk = await request.read_body(65536)
|
||||
if chunk:
|
||||
await queue.put({
|
||||
"type": "http.request",
|
||||
"body": chunk,
|
||||
"more_body": True,
|
||||
})
|
||||
else:
|
||||
await queue.put({
|
||||
"type": "http.request",
|
||||
"body": b"",
|
||||
"more_body": False,
|
||||
})
|
||||
break
|
||||
except Exception as e:
|
||||
self.log.debug("Error reading body: %s", e)
|
||||
await queue.put({
|
||||
"type": "http.request",
|
||||
"body": b"",
|
||||
"more_body": False,
|
||||
})
|
||||
|
||||
def _build_http_scope(self, request, sockname, peername):
|
||||
"""Build ASGI HTTP scope from parsed request."""
|
||||
# Build headers list as bytes tuples
|
||||
headers = []
|
||||
for name, value in request.headers:
|
||||
headers.append((name.lower().encode("latin-1"), value.encode("latin-1")))
|
||||
|
||||
scope = {
|
||||
"type": "http",
|
||||
"asgi": {"version": "3.0", "spec_version": "2.4"},
|
||||
"http_version": f"{request.version[0]}.{request.version[1]}",
|
||||
"method": request.method,
|
||||
"scheme": request.scheme,
|
||||
"path": request.path,
|
||||
"raw_path": request.path.encode("latin-1") if request.path else b"",
|
||||
"query_string": request.query.encode("latin-1") if request.query else b"",
|
||||
"root_path": self.cfg.root_path or "",
|
||||
"headers": headers,
|
||||
"server": sockname if sockname else None,
|
||||
"client": peername if peername else None,
|
||||
}
|
||||
|
||||
# Add state dict for lifespan sharing
|
||||
if hasattr(self.worker, 'state'):
|
||||
scope["state"] = self.worker.state
|
||||
|
||||
return scope
|
||||
|
||||
def _build_environ(self, request, sockname, peername):
|
||||
"""Build minimal WSGI-like environ dict for access logging."""
|
||||
environ = {
|
||||
"REQUEST_METHOD": request.method,
|
||||
"RAW_URI": request.uri,
|
||||
"PATH_INFO": request.path,
|
||||
"QUERY_STRING": request.query or "",
|
||||
"SERVER_PROTOCOL": f"HTTP/{request.version[0]}.{request.version[1]}",
|
||||
"REMOTE_ADDR": peername[0] if peername else "-",
|
||||
}
|
||||
|
||||
# Add HTTP headers as environ vars
|
||||
for name, value in request.headers:
|
||||
key = "HTTP_" + name.replace("-", "_")
|
||||
environ[key] = value
|
||||
|
||||
return environ
|
||||
|
||||
def _build_websocket_scope(self, request, sockname, peername):
|
||||
"""Build ASGI WebSocket scope from parsed request."""
|
||||
# Build headers list as bytes tuples
|
||||
headers = []
|
||||
for name, value in request.headers:
|
||||
headers.append((name.lower().encode("latin-1"), value.encode("latin-1")))
|
||||
|
||||
# Extract subprotocols from Sec-WebSocket-Protocol header
|
||||
subprotocols = []
|
||||
for name, value in request.headers:
|
||||
if name == "SEC-WEBSOCKET-PROTOCOL":
|
||||
subprotocols = [s.strip() for s in value.split(",")]
|
||||
break
|
||||
|
||||
scope = {
|
||||
"type": "websocket",
|
||||
"asgi": {"version": "3.0", "spec_version": "2.4"},
|
||||
"http_version": f"{request.version[0]}.{request.version[1]}",
|
||||
"scheme": "wss" if request.scheme == "https" else "ws",
|
||||
"path": request.path,
|
||||
"raw_path": request.path.encode("latin-1") if request.path else b"",
|
||||
"query_string": request.query.encode("latin-1") if request.query else b"",
|
||||
"root_path": self.cfg.root_path or "",
|
||||
"headers": headers,
|
||||
"server": sockname if sockname else None,
|
||||
"client": peername if peername else None,
|
||||
"subprotocols": subprotocols,
|
||||
}
|
||||
|
||||
# Add state dict for lifespan sharing
|
||||
if hasattr(self.worker, 'state'):
|
||||
scope["state"] = self.worker.state
|
||||
|
||||
return scope
|
||||
|
||||
async def _send_response_start(self, status, headers, request):
|
||||
"""Send HTTP response status and headers."""
|
||||
# Build status line
|
||||
reason = self._get_reason_phrase(status)
|
||||
status_line = f"HTTP/{request.version[0]}.{request.version[1]} {status} {reason}\r\n"
|
||||
|
||||
# Build headers
|
||||
header_lines = []
|
||||
|
||||
for name, value in headers:
|
||||
if isinstance(name, bytes):
|
||||
name = name.decode("latin-1")
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode("latin-1")
|
||||
header_lines.append(f"{name}: {value}\r\n")
|
||||
|
||||
# Add server header if not present
|
||||
header_lines.append("Server: gunicorn/asgi\r\n")
|
||||
|
||||
response = status_line + "".join(header_lines) + "\r\n"
|
||||
self.transport.write(response.encode("latin-1"))
|
||||
|
||||
async def _send_body(self, body):
|
||||
"""Send response body chunk."""
|
||||
if body:
|
||||
self.transport.write(body)
|
||||
|
||||
async def _send_error_response(self, status, message):
|
||||
"""Send an error response."""
|
||||
body = message.encode("utf-8")
|
||||
response = (
|
||||
f"HTTP/1.1 {status} {message}\r\n"
|
||||
f"Content-Type: text/plain\r\n"
|
||||
f"Content-Length: {len(body)}\r\n"
|
||||
f"Connection: close\r\n"
|
||||
f"\r\n"
|
||||
)
|
||||
self.transport.write(response.encode("latin-1"))
|
||||
self.transport.write(body)
|
||||
|
||||
def _get_reason_phrase(self, status):
|
||||
"""Get HTTP reason phrase for status code."""
|
||||
reasons = {
|
||||
100: "Continue",
|
||||
101: "Switching Protocols",
|
||||
200: "OK",
|
||||
201: "Created",
|
||||
202: "Accepted",
|
||||
204: "No Content",
|
||||
206: "Partial Content",
|
||||
301: "Moved Permanently",
|
||||
302: "Found",
|
||||
303: "See Other",
|
||||
304: "Not Modified",
|
||||
307: "Temporary Redirect",
|
||||
308: "Permanent Redirect",
|
||||
400: "Bad Request",
|
||||
401: "Unauthorized",
|
||||
403: "Forbidden",
|
||||
404: "Not Found",
|
||||
405: "Method Not Allowed",
|
||||
408: "Request Timeout",
|
||||
409: "Conflict",
|
||||
410: "Gone",
|
||||
411: "Length Required",
|
||||
413: "Payload Too Large",
|
||||
414: "URI Too Long",
|
||||
415: "Unsupported Media Type",
|
||||
422: "Unprocessable Entity",
|
||||
429: "Too Many Requests",
|
||||
500: "Internal Server Error",
|
||||
501: "Not Implemented",
|
||||
502: "Bad Gateway",
|
||||
503: "Service Unavailable",
|
||||
504: "Gateway Timeout",
|
||||
}
|
||||
return reasons.get(status, "Unknown")
|
||||
|
||||
def _close_transport(self):
|
||||
"""Close the transport safely."""
|
||||
if self.transport and not self._closed:
|
||||
try:
|
||||
self.transport.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._closed = True
|
||||
100
gunicorn/asgi/unreader.py
Normal file
100
gunicorn/asgi/unreader.py
Normal file
@ -0,0 +1,100 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
Async version of gunicorn/http/unreader.py for ASGI workers.
|
||||
|
||||
Provides async reading with pushback buffer support.
|
||||
"""
|
||||
|
||||
import io
|
||||
|
||||
|
||||
class AsyncUnreader:
|
||||
"""Async socket reader with pushback buffer support.
|
||||
|
||||
This class wraps an asyncio StreamReader and provides the ability
|
||||
to "unread" data back into a buffer for re-parsing.
|
||||
"""
|
||||
|
||||
def __init__(self, reader, max_chunk=8192):
|
||||
"""Initialize the async unreader.
|
||||
|
||||
Args:
|
||||
reader: asyncio.StreamReader instance
|
||||
max_chunk: Maximum bytes to read at once
|
||||
"""
|
||||
self.reader = reader
|
||||
self.buf = io.BytesIO()
|
||||
self.max_chunk = max_chunk
|
||||
|
||||
async def read(self, size=None):
|
||||
"""Read data from the stream, using buffered data first.
|
||||
|
||||
Args:
|
||||
size: Number of bytes to read. If None, returns all buffered
|
||||
data or reads a single chunk.
|
||||
|
||||
Returns:
|
||||
bytes: Data read from buffer or stream
|
||||
"""
|
||||
if size is not None and not isinstance(size, int):
|
||||
raise TypeError("size parameter must be an int or long.")
|
||||
|
||||
if size is not None:
|
||||
if size == 0:
|
||||
return b""
|
||||
if size < 0:
|
||||
size = None
|
||||
|
||||
# Move to end to check buffer size
|
||||
self.buf.seek(0, io.SEEK_END)
|
||||
|
||||
# If no size specified, return buffered data or read chunk
|
||||
if size is None and self.buf.tell():
|
||||
ret = self.buf.getvalue()
|
||||
self.buf = io.BytesIO()
|
||||
return ret
|
||||
if size is None:
|
||||
chunk = await self._read_chunk()
|
||||
return chunk
|
||||
|
||||
# Read until we have enough data
|
||||
while self.buf.tell() < size:
|
||||
chunk = await self._read_chunk()
|
||||
if not chunk:
|
||||
ret = self.buf.getvalue()
|
||||
self.buf = io.BytesIO()
|
||||
return ret
|
||||
self.buf.write(chunk)
|
||||
|
||||
data = self.buf.getvalue()
|
||||
self.buf = io.BytesIO()
|
||||
self.buf.write(data[size:])
|
||||
return data[:size]
|
||||
|
||||
async def _read_chunk(self):
|
||||
"""Read a chunk of data from the underlying stream."""
|
||||
try:
|
||||
return await self.reader.read(self.max_chunk)
|
||||
except Exception:
|
||||
return b""
|
||||
|
||||
def unread(self, data):
|
||||
"""Push data back into the buffer for re-reading.
|
||||
|
||||
Args:
|
||||
data: bytes to push back
|
||||
"""
|
||||
if data:
|
||||
self.buf.seek(0, io.SEEK_END)
|
||||
self.buf.write(data)
|
||||
|
||||
def has_buffered_data(self):
|
||||
"""Check if there's data in the pushback buffer."""
|
||||
pos = self.buf.tell()
|
||||
self.buf.seek(0, io.SEEK_END)
|
||||
has_data = self.buf.tell() > 0
|
||||
self.buf.seek(pos)
|
||||
return has_data
|
||||
368
gunicorn/asgi/websocket.py
Normal file
368
gunicorn/asgi/websocket.py
Normal file
@ -0,0 +1,368 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
WebSocket protocol handler for ASGI.
|
||||
|
||||
Implements RFC 6455 WebSocket protocol for ASGI applications.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import struct
|
||||
|
||||
|
||||
# WebSocket frame opcodes
|
||||
OPCODE_CONTINUATION = 0x0
|
||||
OPCODE_TEXT = 0x1
|
||||
OPCODE_BINARY = 0x2
|
||||
OPCODE_CLOSE = 0x8
|
||||
OPCODE_PING = 0x9
|
||||
OPCODE_PONG = 0xA
|
||||
|
||||
# WebSocket close codes
|
||||
CLOSE_NORMAL = 1000
|
||||
CLOSE_GOING_AWAY = 1001
|
||||
CLOSE_PROTOCOL_ERROR = 1002
|
||||
CLOSE_UNSUPPORTED = 1003
|
||||
CLOSE_NO_STATUS = 1005
|
||||
CLOSE_ABNORMAL = 1006
|
||||
CLOSE_INVALID_DATA = 1007
|
||||
CLOSE_POLICY_VIOLATION = 1008
|
||||
CLOSE_MESSAGE_TOO_BIG = 1009
|
||||
CLOSE_MANDATORY_EXT = 1010
|
||||
CLOSE_INTERNAL_ERROR = 1011
|
||||
|
||||
# WebSocket handshake GUID (RFC 6455)
|
||||
WS_GUID = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
||||
|
||||
|
||||
class WebSocketProtocol:
|
||||
"""WebSocket connection handler for ASGI applications."""
|
||||
|
||||
def __init__(self, transport, reader, scope, app, log):
|
||||
"""Initialize WebSocket protocol handler.
|
||||
|
||||
Args:
|
||||
transport: asyncio transport for writing
|
||||
reader: asyncio StreamReader for reading
|
||||
scope: ASGI WebSocket scope dict
|
||||
app: ASGI application callable
|
||||
log: Logger instance
|
||||
"""
|
||||
self.transport = transport
|
||||
self.reader = reader
|
||||
self.scope = scope
|
||||
self.app = app
|
||||
self.log = log
|
||||
|
||||
self.accepted = False
|
||||
self.closed = False
|
||||
self.close_code = None
|
||||
self.close_reason = ""
|
||||
|
||||
# Message reassembly state
|
||||
self._fragments = []
|
||||
self._fragment_opcode = None
|
||||
|
||||
# Receive queue for incoming messages
|
||||
self._receive_queue = asyncio.Queue()
|
||||
|
||||
async def run(self):
|
||||
"""Run the WebSocket ASGI application."""
|
||||
# Send initial connect event
|
||||
await self._receive_queue.put({"type": "websocket.connect"})
|
||||
|
||||
# Start frame reading task
|
||||
read_task = asyncio.create_task(self._read_frames())
|
||||
|
||||
try:
|
||||
await self.app(self.scope, self._receive, self._send)
|
||||
except Exception:
|
||||
self.log.exception("Error in WebSocket ASGI application")
|
||||
finally:
|
||||
read_task.cancel()
|
||||
try:
|
||||
await read_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Send close frame if not already closed
|
||||
if not self.closed and self.accepted:
|
||||
await self._send_close(CLOSE_INTERNAL_ERROR, "Application error")
|
||||
|
||||
async def _receive(self):
|
||||
"""ASGI receive callable."""
|
||||
return await self._receive_queue.get()
|
||||
|
||||
async def _send(self, message):
|
||||
"""ASGI send callable."""
|
||||
msg_type = message["type"]
|
||||
|
||||
if msg_type == "websocket.accept":
|
||||
if self.accepted:
|
||||
raise RuntimeError("WebSocket already accepted")
|
||||
await self._send_accept(message)
|
||||
self.accepted = True
|
||||
|
||||
elif msg_type == "websocket.send":
|
||||
if not self.accepted:
|
||||
raise RuntimeError("WebSocket not accepted")
|
||||
if self.closed:
|
||||
raise RuntimeError("WebSocket closed")
|
||||
|
||||
if "text" in message:
|
||||
await self._send_frame(OPCODE_TEXT, message["text"].encode("utf-8"))
|
||||
elif "bytes" in message:
|
||||
await self._send_frame(OPCODE_BINARY, message["bytes"])
|
||||
|
||||
elif msg_type == "websocket.close":
|
||||
code = message.get("code", CLOSE_NORMAL)
|
||||
reason = message.get("reason", "")
|
||||
await self._send_close(code, reason)
|
||||
self.closed = True
|
||||
|
||||
async def _send_accept(self, message):
|
||||
"""Send WebSocket handshake accept response."""
|
||||
# Get Sec-WebSocket-Key from headers
|
||||
ws_key = None
|
||||
for name, value in self.scope["headers"]:
|
||||
if name == b"sec-websocket-key":
|
||||
ws_key = value
|
||||
break
|
||||
|
||||
if not ws_key:
|
||||
raise RuntimeError("Missing Sec-WebSocket-Key header")
|
||||
|
||||
# Calculate accept key
|
||||
accept_key = base64.b64encode(
|
||||
hashlib.sha1(ws_key + WS_GUID).digest()
|
||||
).decode("ascii")
|
||||
|
||||
# Build response headers
|
||||
headers = [
|
||||
"HTTP/1.1 101 Switching Protocols\r\n",
|
||||
"Upgrade: websocket\r\n",
|
||||
"Connection: Upgrade\r\n",
|
||||
f"Sec-WebSocket-Accept: {accept_key}\r\n",
|
||||
]
|
||||
|
||||
# Add selected subprotocol if specified
|
||||
subprotocol = message.get("subprotocol")
|
||||
if subprotocol:
|
||||
headers.append(f"Sec-WebSocket-Protocol: {subprotocol}\r\n")
|
||||
|
||||
# Add any extra headers from message
|
||||
extra_headers = message.get("headers", [])
|
||||
for name, value in extra_headers:
|
||||
if isinstance(name, bytes):
|
||||
name = name.decode("latin-1")
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode("latin-1")
|
||||
headers.append(f"{name}: {value}\r\n")
|
||||
|
||||
headers.append("\r\n")
|
||||
self.transport.write("".join(headers).encode("latin-1"))
|
||||
|
||||
async def _read_frames(self):
|
||||
"""Read and process incoming WebSocket frames."""
|
||||
try:
|
||||
while not self.closed:
|
||||
frame = await self._read_frame()
|
||||
if frame is None:
|
||||
break
|
||||
|
||||
opcode, payload = frame
|
||||
|
||||
if opcode == OPCODE_CLOSE:
|
||||
await self._handle_close(payload)
|
||||
break
|
||||
|
||||
if opcode == OPCODE_PING:
|
||||
await self._send_frame(OPCODE_PONG, payload)
|
||||
elif opcode == OPCODE_PONG:
|
||||
# Ignore pongs
|
||||
pass
|
||||
elif opcode == OPCODE_TEXT:
|
||||
await self._receive_queue.put({
|
||||
"type": "websocket.receive",
|
||||
"text": payload.decode("utf-8"),
|
||||
})
|
||||
elif opcode == OPCODE_BINARY:
|
||||
await self._receive_queue.put({
|
||||
"type": "websocket.receive",
|
||||
"bytes": payload,
|
||||
})
|
||||
elif opcode == OPCODE_CONTINUATION:
|
||||
# Handle fragmented messages
|
||||
await self._handle_continuation(payload)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
self.log.debug("WebSocket read error: %s", e)
|
||||
finally:
|
||||
# Signal disconnect
|
||||
if not self.closed:
|
||||
self.closed = True
|
||||
await self._receive_queue.put({
|
||||
"type": "websocket.disconnect",
|
||||
"code": self.close_code or CLOSE_ABNORMAL,
|
||||
})
|
||||
|
||||
async def _read_frame(self): # pylint: disable=too-many-return-statements
|
||||
"""Read a single WebSocket frame.
|
||||
|
||||
Returns:
|
||||
tuple: (opcode, payload) or None if connection closed
|
||||
"""
|
||||
# Read frame header (2 bytes minimum)
|
||||
header = await self._read_exact(2)
|
||||
if not header:
|
||||
return None
|
||||
|
||||
first_byte, second_byte = header[0], header[1]
|
||||
|
||||
fin = (first_byte >> 7) & 1
|
||||
rsv1 = (first_byte >> 6) & 1
|
||||
rsv2 = (first_byte >> 5) & 1
|
||||
rsv3 = (first_byte >> 4) & 1
|
||||
opcode = first_byte & 0x0F
|
||||
|
||||
# RSV bits must be 0 (no extensions)
|
||||
if rsv1 or rsv2 or rsv3:
|
||||
await self._send_close(CLOSE_PROTOCOL_ERROR, "RSV bits set")
|
||||
return None
|
||||
|
||||
masked = (second_byte >> 7) & 1
|
||||
payload_len = second_byte & 0x7F
|
||||
|
||||
# Client frames must be masked (RFC 6455)
|
||||
if not masked:
|
||||
await self._send_close(CLOSE_PROTOCOL_ERROR, "Frame not masked")
|
||||
return None
|
||||
|
||||
# Extended payload length
|
||||
if payload_len == 126:
|
||||
ext_len = await self._read_exact(2)
|
||||
if not ext_len:
|
||||
return None
|
||||
payload_len = struct.unpack("!H", ext_len)[0]
|
||||
elif payload_len == 127:
|
||||
ext_len = await self._read_exact(8)
|
||||
if not ext_len:
|
||||
return None
|
||||
payload_len = struct.unpack("!Q", ext_len)[0]
|
||||
|
||||
# Read masking key
|
||||
masking_key = await self._read_exact(4)
|
||||
if not masking_key:
|
||||
return None
|
||||
|
||||
# Read payload
|
||||
payload = await self._read_exact(payload_len)
|
||||
if payload is None:
|
||||
return None
|
||||
|
||||
# Unmask payload
|
||||
payload = self._unmask(payload, masking_key)
|
||||
|
||||
# Handle fragmented messages
|
||||
if opcode == OPCODE_CONTINUATION:
|
||||
if self._fragment_opcode is None:
|
||||
await self._send_close(CLOSE_PROTOCOL_ERROR, "Unexpected continuation")
|
||||
return None
|
||||
self._fragments.append(payload)
|
||||
if fin:
|
||||
# Reassemble complete message
|
||||
full_payload = b"".join(self._fragments)
|
||||
final_opcode = self._fragment_opcode
|
||||
self._fragments = []
|
||||
self._fragment_opcode = None
|
||||
return (final_opcode, full_payload)
|
||||
return (OPCODE_CONTINUATION, b"") # Fragment received, wait for more
|
||||
elif opcode in (OPCODE_TEXT, OPCODE_BINARY):
|
||||
if not fin:
|
||||
# Start of fragmented message
|
||||
self._fragment_opcode = opcode
|
||||
self._fragments = [payload]
|
||||
return (OPCODE_CONTINUATION, b"") # Fragment started, wait for more
|
||||
return (opcode, payload)
|
||||
else:
|
||||
# Control frames
|
||||
return (opcode, payload)
|
||||
|
||||
async def _read_exact(self, n):
|
||||
"""Read exactly n bytes from the reader."""
|
||||
try:
|
||||
data = await self.reader.readexactly(n)
|
||||
return data
|
||||
except asyncio.IncompleteReadError:
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _unmask(self, payload, masking_key):
|
||||
"""Unmask WebSocket payload data."""
|
||||
if not payload:
|
||||
return payload
|
||||
# XOR each byte with corresponding mask byte
|
||||
return bytes(b ^ masking_key[i % 4] for i, b in enumerate(payload))
|
||||
|
||||
async def _handle_close(self, payload):
|
||||
"""Handle incoming close frame."""
|
||||
if len(payload) >= 2:
|
||||
self.close_code = struct.unpack("!H", payload[:2])[0]
|
||||
self.close_reason = payload[2:].decode("utf-8", errors="replace")
|
||||
else:
|
||||
self.close_code = CLOSE_NO_STATUS
|
||||
self.close_reason = ""
|
||||
|
||||
# Echo close frame back if we haven't already sent one
|
||||
if not self.closed:
|
||||
await self._send_close(self.close_code, self.close_reason)
|
||||
|
||||
self.closed = True
|
||||
|
||||
async def _handle_continuation(self, payload): # pylint: disable=unused-argument
|
||||
"""Handle continuation frame (already processed in _read_frame)."""
|
||||
# This is called for partial fragments, nothing to do here
|
||||
|
||||
async def _send_frame(self, opcode, payload):
|
||||
"""Send a WebSocket frame.
|
||||
|
||||
Server frames are not masked (RFC 6455).
|
||||
"""
|
||||
if isinstance(payload, str):
|
||||
payload = payload.encode("utf-8")
|
||||
|
||||
length = len(payload)
|
||||
frame = bytearray()
|
||||
|
||||
# First byte: FIN + opcode
|
||||
frame.append(0x80 | opcode)
|
||||
|
||||
# Second byte: length (no mask bit for server)
|
||||
if length < 126:
|
||||
frame.append(length)
|
||||
elif length < 65536:
|
||||
frame.append(126)
|
||||
frame.extend(struct.pack("!H", length))
|
||||
else:
|
||||
frame.append(127)
|
||||
frame.extend(struct.pack("!Q", length))
|
||||
|
||||
# Payload
|
||||
frame.extend(payload)
|
||||
|
||||
self.transport.write(bytes(frame))
|
||||
|
||||
async def _send_close(self, code, reason=""):
|
||||
"""Send a close frame."""
|
||||
payload = struct.pack("!H", code)
|
||||
if reason:
|
||||
payload += reason.encode("utf-8")[:123] # Max 125 bytes total
|
||||
await self._send_frame(OPCODE_CLOSE, payload)
|
||||
self.closed = True
|
||||
@ -2096,6 +2096,53 @@ class ProxyAllowFrom(Setting):
|
||||
"""
|
||||
|
||||
|
||||
class Protocol(Setting):
|
||||
name = "protocol"
|
||||
section = "Server Mechanics"
|
||||
cli = ["--protocol"]
|
||||
meta = "STRING"
|
||||
validator = validate_string
|
||||
default = "http"
|
||||
desc = """\
|
||||
The protocol for incoming connections.
|
||||
|
||||
* ``http`` - Standard HTTP/1.x (default)
|
||||
* ``uwsgi`` - uWSGI binary protocol (for nginx uwsgi_pass)
|
||||
|
||||
When using the uWSGI protocol, Gunicorn can receive requests from
|
||||
nginx using the uwsgi_pass directive::
|
||||
|
||||
upstream gunicorn {
|
||||
server 127.0.0.1:8000;
|
||||
}
|
||||
location / {
|
||||
uwsgi_pass gunicorn;
|
||||
include uwsgi_params;
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
class UWSGIAllowFrom(Setting):
|
||||
name = "uwsgi_allow_ips"
|
||||
section = "Server Mechanics"
|
||||
cli = ["--uwsgi-allow-from"]
|
||||
validator = validate_string_to_addr_list
|
||||
default = "127.0.0.1,::1"
|
||||
desc = """\
|
||||
IPs allowed to send uWSGI protocol requests (comma separated).
|
||||
|
||||
Set to ``*`` to allow all IPs. This is useful for setups where you
|
||||
don't know in advance the IP address of front-end, but instead have
|
||||
ensured via other means that only your authorized front-ends can
|
||||
access Gunicorn.
|
||||
|
||||
.. note::
|
||||
|
||||
This option does not affect UNIX socket connections. Connections not associated with
|
||||
an IP address are treated as allowed, unconditionally.
|
||||
"""
|
||||
|
||||
|
||||
class KeyFile(Setting):
|
||||
name = "keyfile"
|
||||
section = "SSL"
|
||||
@ -2440,3 +2487,94 @@ class HeaderMap(Setting):
|
||||
|
||||
.. versionadded:: 22.0.0
|
||||
"""
|
||||
|
||||
|
||||
def validate_asgi_loop(val):
|
||||
if val is None:
|
||||
return "auto"
|
||||
if not isinstance(val, str):
|
||||
raise TypeError("Invalid type for casting: %s" % val)
|
||||
val = val.lower().strip()
|
||||
if val not in ("auto", "asyncio", "uvloop"):
|
||||
raise ValueError("Invalid ASGI loop: %s" % val)
|
||||
return val
|
||||
|
||||
|
||||
def validate_asgi_lifespan(val):
|
||||
if val is None:
|
||||
return "auto"
|
||||
if not isinstance(val, str):
|
||||
raise TypeError("Invalid type for casting: %s" % val)
|
||||
val = val.lower().strip()
|
||||
if val not in ("auto", "on", "off"):
|
||||
raise ValueError("Invalid ASGI lifespan: %s" % val)
|
||||
return val
|
||||
|
||||
|
||||
class ASGILoop(Setting):
|
||||
name = "asgi_loop"
|
||||
section = "Worker Processes"
|
||||
cli = ["--asgi-loop"]
|
||||
meta = "STRING"
|
||||
validator = validate_asgi_loop
|
||||
default = "auto"
|
||||
desc = """\
|
||||
Event loop implementation for ASGI workers.
|
||||
|
||||
- auto: Use uvloop if available, otherwise asyncio
|
||||
- asyncio: Use Python's built-in asyncio event loop
|
||||
- uvloop: Use uvloop (must be installed separately)
|
||||
|
||||
This setting only affects the ``asgi`` worker type.
|
||||
|
||||
uvloop typically provides better performance but requires
|
||||
installing the uvloop package.
|
||||
|
||||
.. versionadded:: 24.0.0
|
||||
"""
|
||||
|
||||
|
||||
class ASGILifespan(Setting):
|
||||
name = "asgi_lifespan"
|
||||
section = "Worker Processes"
|
||||
cli = ["--asgi-lifespan"]
|
||||
meta = "STRING"
|
||||
validator = validate_asgi_lifespan
|
||||
default = "auto"
|
||||
desc = """\
|
||||
Control ASGI lifespan protocol handling.
|
||||
|
||||
- auto: Detect if app supports lifespan, enable if so
|
||||
- on: Always run lifespan protocol (fail if unsupported)
|
||||
- off: Never run lifespan protocol
|
||||
|
||||
The lifespan protocol allows ASGI applications to run code at
|
||||
startup and shutdown. This is essential for frameworks like
|
||||
FastAPI that need to initialize database connections, caches,
|
||||
or other resources.
|
||||
|
||||
This setting only affects the ``asgi`` worker type.
|
||||
|
||||
.. versionadded:: 24.0.0
|
||||
"""
|
||||
|
||||
|
||||
class RootPath(Setting):
|
||||
name = "root_path"
|
||||
section = "Server Mechanics"
|
||||
cli = ["--root-path"]
|
||||
meta = "STRING"
|
||||
validator = validate_string
|
||||
default = ""
|
||||
desc = """\
|
||||
The root path for ASGI applications.
|
||||
|
||||
This is used to set the ``root_path`` in the ASGI scope, which
|
||||
allows applications to know their mount point when behind a
|
||||
reverse proxy.
|
||||
|
||||
For example, if your application is mounted at ``/api``, set
|
||||
this to ``/api``.
|
||||
|
||||
.. versionadded:: 24.0.0
|
||||
"""
|
||||
|
||||
@ -5,4 +5,23 @@
|
||||
from gunicorn.http.message import Message, Request
|
||||
from gunicorn.http.parser import RequestParser
|
||||
|
||||
__all__ = ['Message', 'Request', 'RequestParser']
|
||||
|
||||
def get_parser(cfg, source, source_addr):
|
||||
"""Get appropriate parser based on protocol config.
|
||||
|
||||
Args:
|
||||
cfg: Gunicorn config object
|
||||
source: Socket or iterable source
|
||||
source_addr: Source address tuple or None
|
||||
|
||||
Returns:
|
||||
Parser instance (RequestParser or UWSGIParser)
|
||||
"""
|
||||
protocol = getattr(cfg, 'protocol', 'http')
|
||||
if protocol == 'uwsgi':
|
||||
from gunicorn.uwsgi.parser import UWSGIParser
|
||||
return UWSGIParser(cfg, source, source_addr)
|
||||
return RequestParser(cfg, source, source_addr)
|
||||
|
||||
|
||||
__all__ = ['Message', 'Request', 'RequestParser', 'get_parser']
|
||||
|
||||
21
gunicorn/uwsgi/__init__.py
Normal file
21
gunicorn/uwsgi/__init__.py
Normal file
@ -0,0 +1,21 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
from gunicorn.uwsgi.message import UWSGIRequest
|
||||
from gunicorn.uwsgi.parser import UWSGIParser
|
||||
from gunicorn.uwsgi.errors import (
|
||||
UWSGIParseException,
|
||||
InvalidUWSGIHeader,
|
||||
UnsupportedModifier,
|
||||
ForbiddenUWSGIRequest,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'UWSGIRequest',
|
||||
'UWSGIParser',
|
||||
'UWSGIParseException',
|
||||
'InvalidUWSGIHeader',
|
||||
'UnsupportedModifier',
|
||||
'ForbiddenUWSGIRequest',
|
||||
]
|
||||
46
gunicorn/uwsgi/errors.py
Normal file
46
gunicorn/uwsgi/errors.py
Normal file
@ -0,0 +1,46 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
# We don't need to call super() in __init__ methods of our
|
||||
# BaseException and Exception classes because we also define
|
||||
# our own __str__ methods so there is no need to pass 'message'
|
||||
# to the base class to get a meaningful output from 'str(exc)'.
|
||||
# pylint: disable=super-init-not-called
|
||||
|
||||
|
||||
class UWSGIParseException(Exception):
|
||||
"""Base exception for uWSGI protocol parsing errors."""
|
||||
|
||||
|
||||
class InvalidUWSGIHeader(UWSGIParseException):
|
||||
"""Raised when the uWSGI header is malformed."""
|
||||
|
||||
def __init__(self, msg=""):
|
||||
self.msg = msg
|
||||
self.code = 400
|
||||
|
||||
def __str__(self):
|
||||
return "Invalid uWSGI header: %s" % self.msg
|
||||
|
||||
|
||||
class UnsupportedModifier(UWSGIParseException):
|
||||
"""Raised when modifier1 is not 0 (WSGI request)."""
|
||||
|
||||
def __init__(self, modifier):
|
||||
self.modifier = modifier
|
||||
self.code = 501
|
||||
|
||||
def __str__(self):
|
||||
return "Unsupported uWSGI modifier1: %d" % self.modifier
|
||||
|
||||
|
||||
class ForbiddenUWSGIRequest(UWSGIParseException):
|
||||
"""Raised when source IP is not in the allow list."""
|
||||
|
||||
def __init__(self, host):
|
||||
self.host = host
|
||||
self.code = 403
|
||||
|
||||
def __str__(self):
|
||||
return "uWSGI request from %r not allowed" % self.host
|
||||
255
gunicorn/uwsgi/message.py
Normal file
255
gunicorn/uwsgi/message.py
Normal file
@ -0,0 +1,255 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
import io
|
||||
|
||||
from gunicorn.http.body import LengthReader, Body
|
||||
from gunicorn.uwsgi.errors import (
|
||||
InvalidUWSGIHeader,
|
||||
UnsupportedModifier,
|
||||
ForbiddenUWSGIRequest,
|
||||
)
|
||||
|
||||
|
||||
# Maximum number of variables to prevent DoS
|
||||
MAX_UWSGI_VARS = 1000
|
||||
|
||||
|
||||
class UWSGIRequest:
|
||||
"""uWSGI protocol request parser.
|
||||
|
||||
The uWSGI protocol uses a 4-byte binary header:
|
||||
- Byte 0: modifier1 (packet type, 0 = WSGI request)
|
||||
- Bytes 1-2: datasize (16-bit little-endian, size of vars block)
|
||||
- Byte 3: modifier2 (additional flags, typically 0)
|
||||
|
||||
After the header:
|
||||
1. Vars block (datasize bytes): Key-value pairs containing WSGI environ
|
||||
- Each pair: 2-byte key_size (LE) + key + 2-byte val_size (LE) + value
|
||||
2. Request body (determined by CONTENT_LENGTH in vars)
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, unreader, peer_addr, req_number=1):
|
||||
self.cfg = cfg
|
||||
self.unreader = unreader
|
||||
self.peer_addr = peer_addr
|
||||
self.remote_addr = peer_addr
|
||||
self.req_number = req_number
|
||||
|
||||
# Request attributes (compatible with HTTP Request interface)
|
||||
self.method = None
|
||||
self.uri = None
|
||||
self.path = None
|
||||
self.query = None
|
||||
self.fragment = ""
|
||||
self.version = (1, 1) # uWSGI is HTTP/1.1 compatible
|
||||
self.headers = []
|
||||
self.trailers = []
|
||||
self.body = None
|
||||
self.scheme = "https" if cfg.is_ssl else "http"
|
||||
self.must_close = False
|
||||
|
||||
# uWSGI specific
|
||||
self.uwsgi_vars = {}
|
||||
self.modifier1 = 0
|
||||
self.modifier2 = 0
|
||||
|
||||
# Proxy protocol compatibility
|
||||
self.proxy_protocol_info = None
|
||||
|
||||
# Check if the source IP is allowed
|
||||
self._check_allowed_ip()
|
||||
|
||||
# Parse the request
|
||||
unused = self.parse(self.unreader)
|
||||
self.unreader.unread(unused)
|
||||
self.set_body_reader()
|
||||
|
||||
def _check_allowed_ip(self):
|
||||
"""Verify source IP is in the allowed list."""
|
||||
allow_ips = getattr(self.cfg, 'uwsgi_allow_ips', ['127.0.0.1', '::1'])
|
||||
|
||||
# UNIX sockets don't have IP addresses
|
||||
if not isinstance(self.peer_addr, tuple):
|
||||
return
|
||||
|
||||
# Wildcard allows all
|
||||
if '*' in allow_ips:
|
||||
return
|
||||
|
||||
if self.peer_addr[0] not in allow_ips:
|
||||
raise ForbiddenUWSGIRequest(self.peer_addr[0])
|
||||
|
||||
def force_close(self):
|
||||
"""Force the connection to close after this request."""
|
||||
self.must_close = True
|
||||
|
||||
def parse(self, unreader):
|
||||
"""Parse uWSGI packet header and vars block."""
|
||||
# Read the 4-byte header
|
||||
header = self._read_exact(unreader, 4)
|
||||
if len(header) < 4:
|
||||
raise InvalidUWSGIHeader("incomplete header")
|
||||
|
||||
self.modifier1 = header[0]
|
||||
datasize = int.from_bytes(header[1:3], 'little')
|
||||
self.modifier2 = header[3]
|
||||
|
||||
# Only modifier1=0 (WSGI request) is supported
|
||||
if self.modifier1 != 0:
|
||||
raise UnsupportedModifier(self.modifier1)
|
||||
|
||||
# Read the vars block
|
||||
if datasize > 0:
|
||||
vars_data = self._read_exact(unreader, datasize)
|
||||
if len(vars_data) < datasize:
|
||||
raise InvalidUWSGIHeader("incomplete vars block")
|
||||
self._parse_vars(vars_data)
|
||||
|
||||
# Extract HTTP request info from vars
|
||||
self._extract_request_info()
|
||||
|
||||
return b""
|
||||
|
||||
def _read_exact(self, unreader, size):
|
||||
"""Read exactly size bytes from the unreader."""
|
||||
buf = io.BytesIO()
|
||||
remaining = size
|
||||
|
||||
while remaining > 0:
|
||||
data = unreader.read()
|
||||
if not data:
|
||||
break
|
||||
buf.write(data)
|
||||
remaining = size - buf.tell()
|
||||
|
||||
result = buf.getvalue()
|
||||
# Put back any extra bytes
|
||||
if len(result) > size:
|
||||
unreader.unread(result[size:])
|
||||
result = result[:size]
|
||||
|
||||
return result
|
||||
|
||||
def _parse_vars(self, data):
|
||||
"""Parse uWSGI vars block into key-value pairs.
|
||||
|
||||
Format: key_size (2 bytes LE) + key + val_size (2 bytes LE) + value
|
||||
"""
|
||||
pos = 0
|
||||
var_count = 0
|
||||
|
||||
while pos < len(data):
|
||||
if var_count >= MAX_UWSGI_VARS:
|
||||
raise InvalidUWSGIHeader("too many variables")
|
||||
|
||||
# Key size (2 bytes, little-endian)
|
||||
if pos + 2 > len(data):
|
||||
raise InvalidUWSGIHeader("truncated key size")
|
||||
key_size = int.from_bytes(data[pos:pos + 2], 'little')
|
||||
pos += 2
|
||||
|
||||
# Key
|
||||
if pos + key_size > len(data):
|
||||
raise InvalidUWSGIHeader("truncated key")
|
||||
key = data[pos:pos + key_size].decode('latin-1')
|
||||
pos += key_size
|
||||
|
||||
# Value size (2 bytes, little-endian)
|
||||
if pos + 2 > len(data):
|
||||
raise InvalidUWSGIHeader("truncated value size")
|
||||
val_size = int.from_bytes(data[pos:pos + 2], 'little')
|
||||
pos += 2
|
||||
|
||||
# Value
|
||||
if pos + val_size > len(data):
|
||||
raise InvalidUWSGIHeader("truncated value")
|
||||
value = data[pos:pos + val_size].decode('latin-1')
|
||||
pos += val_size
|
||||
|
||||
self.uwsgi_vars[key] = value
|
||||
var_count += 1
|
||||
|
||||
def _extract_request_info(self):
|
||||
"""Extract HTTP request info from uWSGI vars.
|
||||
|
||||
Header Mapping (CGI/WSGI to HTTP):
|
||||
|
||||
The uWSGI protocol passes HTTP headers using CGI-style environment
|
||||
variable naming. This method converts them back to HTTP header format:
|
||||
|
||||
- HTTP_* vars: Strip 'HTTP_' prefix, replace '_' with '-'
|
||||
Example: HTTP_X_FORWARDED_FOR -> X-FORWARDED-FOR
|
||||
Example: HTTP_ACCEPT_ENCODING -> ACCEPT-ENCODING
|
||||
|
||||
- CONTENT_TYPE: Mapped directly to CONTENT-TYPE header
|
||||
(CGI spec excludes HTTP_ prefix for this header)
|
||||
|
||||
- CONTENT_LENGTH: Mapped directly to CONTENT-LENGTH header
|
||||
(CGI spec excludes HTTP_ prefix for this header)
|
||||
|
||||
Note: The underscore-to-hyphen conversion is lossy. Headers that
|
||||
originally contained underscores (e.g., X_Custom_Header) cannot be
|
||||
distinguished from hyphenated headers (X-Custom-Header) after
|
||||
passing through nginx/uWSGI. This is a CGI/WSGI specification
|
||||
limitation, not specific to this implementation.
|
||||
"""
|
||||
# Method
|
||||
self.method = self.uwsgi_vars.get('REQUEST_METHOD', 'GET')
|
||||
|
||||
# URI and path
|
||||
self.path = self.uwsgi_vars.get('PATH_INFO', '/')
|
||||
self.query = self.uwsgi_vars.get('QUERY_STRING', '')
|
||||
|
||||
# Build URI
|
||||
if self.query:
|
||||
self.uri = "%s?%s" % (self.path, self.query)
|
||||
else:
|
||||
self.uri = self.path
|
||||
|
||||
# Scheme
|
||||
if self.uwsgi_vars.get('HTTPS', '').lower() in ('on', '1', 'true'):
|
||||
self.scheme = 'https'
|
||||
elif 'wsgi.url_scheme' in self.uwsgi_vars:
|
||||
self.scheme = self.uwsgi_vars['wsgi.url_scheme']
|
||||
|
||||
# Extract HTTP headers from CGI-style vars
|
||||
# See docstring above for mapping details
|
||||
for key, value in self.uwsgi_vars.items():
|
||||
if key.startswith('HTTP_'):
|
||||
# Convert HTTP_HEADER_NAME to HEADER-NAME
|
||||
header_name = key[5:].replace('_', '-')
|
||||
self.headers.append((header_name, value))
|
||||
elif key == 'CONTENT_TYPE':
|
||||
self.headers.append(('CONTENT-TYPE', value))
|
||||
elif key == 'CONTENT_LENGTH':
|
||||
self.headers.append(('CONTENT-LENGTH', value))
|
||||
|
||||
def set_body_reader(self):
|
||||
"""Set up the body reader based on CONTENT_LENGTH."""
|
||||
content_length = 0
|
||||
|
||||
# Get content length from vars
|
||||
if 'CONTENT_LENGTH' in self.uwsgi_vars:
|
||||
try:
|
||||
content_length = max(int(self.uwsgi_vars['CONTENT_LENGTH']), 0)
|
||||
except ValueError:
|
||||
content_length = 0
|
||||
|
||||
self.body = Body(LengthReader(self.unreader, content_length))
|
||||
|
||||
def should_close(self):
|
||||
"""Determine if the connection should be closed after this request."""
|
||||
if self.must_close:
|
||||
return True
|
||||
|
||||
# Check HTTP_CONNECTION header
|
||||
connection = self.uwsgi_vars.get('HTTP_CONNECTION', '').lower()
|
||||
if connection == 'close':
|
||||
return True
|
||||
elif connection == 'keep-alive':
|
||||
return False
|
||||
|
||||
# Default to keep-alive for HTTP/1.1
|
||||
return False
|
||||
12
gunicorn/uwsgi/parser.py
Normal file
12
gunicorn/uwsgi/parser.py
Normal file
@ -0,0 +1,12 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
from gunicorn.http.parser import Parser
|
||||
from gunicorn.uwsgi.message import UWSGIRequest
|
||||
|
||||
|
||||
class UWSGIParser(Parser):
|
||||
"""Parser for uWSGI protocol requests."""
|
||||
|
||||
mesg_class = UWSGIRequest
|
||||
@ -11,4 +11,5 @@ SUPPORTED_WORKERS = {
|
||||
"gevent_pywsgi": "gunicorn.workers.ggevent.GeventPyWSGIWorker",
|
||||
"tornado": "gunicorn.workers.gtornado.TornadoWorker",
|
||||
"gthread": "gunicorn.workers.gthread.ThreadWorker",
|
||||
"asgi": "gunicorn.workers.gasgi.ASGIWorker",
|
||||
}
|
||||
|
||||
@ -32,7 +32,7 @@ class AsyncWorker(base.Worker):
|
||||
def handle(self, listener, client, addr):
|
||||
req = None
|
||||
try:
|
||||
parser = http.RequestParser(self.cfg, client, addr)
|
||||
parser = http.get_parser(self.cfg, client, addr)
|
||||
try:
|
||||
listener_name = listener.getsockname()
|
||||
if not self.cfg.keepalive:
|
||||
|
||||
281
gunicorn/workers/gasgi.py
Normal file
281
gunicorn/workers/gasgi.py
Normal file
@ -0,0 +1,281 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
ASGI worker for gunicorn.
|
||||
|
||||
Provides native asyncio-based ASGI support using gunicorn's own
|
||||
HTTP parsing infrastructure.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
|
||||
from gunicorn.workers import base
|
||||
from gunicorn.asgi.protocol import ASGIProtocol
|
||||
|
||||
|
||||
class ASGIWorker(base.Worker):
|
||||
"""ASGI worker using asyncio event loop.
|
||||
|
||||
Supports:
|
||||
- HTTP/1.1 with keepalive
|
||||
- WebSocket connections
|
||||
- Lifespan protocol (startup/shutdown hooks)
|
||||
- Optional uvloop for improved performance
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.worker_connections = self.cfg.worker_connections
|
||||
self.loop = None
|
||||
self.servers = []
|
||||
self.nr_conns = 0
|
||||
self.lifespan = None
|
||||
self.state = {} # Shared state for lifespan
|
||||
|
||||
@classmethod
|
||||
def check_config(cls, cfg, log):
|
||||
"""Validate configuration for ASGI worker."""
|
||||
if cfg.threads > 1:
|
||||
log.warning("ASGI worker does not use threads configuration. "
|
||||
"Use worker_connections instead.")
|
||||
|
||||
def init_process(self):
|
||||
"""Initialize the worker process."""
|
||||
# Setup event loop before calling super()
|
||||
self._setup_event_loop()
|
||||
super().init_process()
|
||||
|
||||
def _setup_event_loop(self):
|
||||
"""Setup the asyncio event loop."""
|
||||
loop_type = getattr(self.cfg, 'asgi_loop', 'auto')
|
||||
|
||||
if loop_type == "auto":
|
||||
try:
|
||||
import uvloop
|
||||
loop_type = "uvloop"
|
||||
except ImportError:
|
||||
loop_type = "asyncio"
|
||||
|
||||
if loop_type == "uvloop":
|
||||
try:
|
||||
import uvloop
|
||||
self.loop = uvloop.new_event_loop()
|
||||
self.log.debug("Using uvloop event loop")
|
||||
except ImportError:
|
||||
self.log.warning("uvloop not available, falling back to asyncio")
|
||||
self.loop = asyncio.new_event_loop()
|
||||
else:
|
||||
self.loop = asyncio.new_event_loop()
|
||||
self.log.debug("Using asyncio event loop")
|
||||
|
||||
asyncio.set_event_loop(self.loop)
|
||||
|
||||
def load_wsgi(self):
|
||||
"""Load the ASGI application."""
|
||||
try:
|
||||
self.asgi = self.app.wsgi()
|
||||
except SyntaxError as e:
|
||||
if not self.cfg.reload:
|
||||
raise
|
||||
self.log.exception(e)
|
||||
self.asgi = self._make_error_app(str(e))
|
||||
|
||||
def _make_error_app(self, error_msg):
|
||||
"""Create an error ASGI app for syntax errors during reload."""
|
||||
async def error_app(scope, receive, send):
|
||||
if scope["type"] == "http":
|
||||
await send({
|
||||
"type": "http.response.start",
|
||||
"status": 500,
|
||||
"headers": [(b"content-type", b"text/plain")],
|
||||
})
|
||||
await send({
|
||||
"type": "http.response.body",
|
||||
"body": f"Application error: {error_msg}".encode(),
|
||||
})
|
||||
elif scope["type"] == "lifespan":
|
||||
message = await receive()
|
||||
if message["type"] == "lifespan.startup":
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
message = await receive()
|
||||
if message["type"] == "lifespan.shutdown":
|
||||
await send({"type": "lifespan.shutdown.complete"})
|
||||
return error_app
|
||||
|
||||
def init_signals(self):
|
||||
"""Initialize signal handlers for asyncio."""
|
||||
# Reset all signals first
|
||||
for s in self.SIGNALS:
|
||||
signal.signal(s, signal.SIG_DFL)
|
||||
|
||||
# Set up signal handlers via the event loop
|
||||
self.loop.add_signal_handler(signal.SIGQUIT, self.handle_quit_signal)
|
||||
self.loop.add_signal_handler(signal.SIGTERM, self.handle_exit_signal)
|
||||
self.loop.add_signal_handler(signal.SIGINT, self.handle_quit_signal)
|
||||
self.loop.add_signal_handler(signal.SIGUSR1, self.handle_usr1_signal)
|
||||
self.loop.add_signal_handler(signal.SIGWINCH, self.handle_winch_signal)
|
||||
self.loop.add_signal_handler(signal.SIGABRT, self.handle_abort_signal)
|
||||
|
||||
def handle_quit_signal(self):
|
||||
"""Handle SIGQUIT - immediate shutdown."""
|
||||
self.alive = False
|
||||
self.cfg.worker_int(self)
|
||||
|
||||
def handle_exit_signal(self):
|
||||
"""Handle SIGTERM - graceful shutdown."""
|
||||
self.alive = False
|
||||
|
||||
def handle_usr1_signal(self):
|
||||
"""Handle SIGUSR1 - reopen log files."""
|
||||
self.log.reopen_files()
|
||||
|
||||
def handle_winch_signal(self):
|
||||
"""Handle SIGWINCH - ignored in worker."""
|
||||
self.log.debug("worker: SIGWINCH ignored.")
|
||||
|
||||
def handle_abort_signal(self):
|
||||
"""Handle SIGABRT - abort."""
|
||||
self.alive = False
|
||||
self.cfg.worker_abort(self)
|
||||
sys.exit(1)
|
||||
|
||||
def run(self):
|
||||
"""Main entry point for the worker."""
|
||||
try:
|
||||
self.loop.run_until_complete(self._serve())
|
||||
except Exception as e:
|
||||
self.log.exception("Worker exception: %s", e)
|
||||
finally:
|
||||
self._cleanup()
|
||||
|
||||
async def _serve(self):
|
||||
"""Main async serving loop."""
|
||||
# Run lifespan startup
|
||||
lifespan_mode = getattr(self.cfg, 'asgi_lifespan', 'auto')
|
||||
if lifespan_mode != "off":
|
||||
from gunicorn.asgi.lifespan import LifespanManager
|
||||
self.lifespan = LifespanManager(self.asgi, self.log, self.state)
|
||||
try:
|
||||
await self.lifespan.startup()
|
||||
except Exception as e:
|
||||
if lifespan_mode == "on":
|
||||
self.log.error("ASGI lifespan startup failed: %s", e)
|
||||
return
|
||||
else:
|
||||
# auto mode - app doesn't support lifespan
|
||||
self.log.debug("ASGI lifespan not supported by app: %s", e)
|
||||
self.lifespan = None
|
||||
|
||||
# Create servers for each listener socket
|
||||
ssl_context = self._get_ssl_context()
|
||||
|
||||
for sock in self.sockets:
|
||||
try:
|
||||
server = await self.loop.create_server(
|
||||
lambda: ASGIProtocol(self),
|
||||
sock=sock.sock,
|
||||
ssl=ssl_context,
|
||||
reuse_address=True,
|
||||
start_serving=True,
|
||||
)
|
||||
self.servers.append(server)
|
||||
self.log.info("ASGI server listening on %s", sock)
|
||||
except Exception as e:
|
||||
self.log.error("Failed to create server on %s: %s", sock, e)
|
||||
|
||||
if not self.servers:
|
||||
self.log.error("No servers could be started")
|
||||
return
|
||||
|
||||
# Main loop with heartbeat
|
||||
try:
|
||||
while self.alive:
|
||||
self.notify()
|
||||
|
||||
# Check if parent is still alive
|
||||
if self.ppid != os.getppid():
|
||||
self.log.info("Parent changed, shutting down: %s", self)
|
||||
break
|
||||
|
||||
# Check connection limit
|
||||
# (Connections are managed by nr_conns in ASGIProtocol)
|
||||
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Graceful shutdown
|
||||
await self._shutdown()
|
||||
|
||||
async def _shutdown(self):
|
||||
"""Perform graceful shutdown."""
|
||||
self.log.info("Worker shutting down...")
|
||||
|
||||
# Stop accepting new connections
|
||||
for server in self.servers:
|
||||
server.close()
|
||||
|
||||
# Wait for servers to close
|
||||
for server in self.servers:
|
||||
await server.wait_closed()
|
||||
|
||||
# Wait for in-flight connections (with timeout)
|
||||
graceful_timeout = self.cfg.graceful_timeout
|
||||
if self.nr_conns > 0:
|
||||
self.log.info("Waiting for %d connections to finish...", self.nr_conns)
|
||||
deadline = self.loop.time() + graceful_timeout
|
||||
while self.nr_conns > 0 and self.loop.time() < deadline:
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
if self.nr_conns > 0:
|
||||
self.log.warning("Closing %d connections after timeout", self.nr_conns)
|
||||
|
||||
# Run lifespan shutdown
|
||||
if self.lifespan:
|
||||
try:
|
||||
await self.lifespan.shutdown()
|
||||
except Exception as e:
|
||||
self.log.error("ASGI lifespan shutdown error: %s", e)
|
||||
|
||||
def _get_ssl_context(self):
|
||||
"""Get SSL context if configured."""
|
||||
if not self.cfg.is_ssl:
|
||||
return None
|
||||
|
||||
try:
|
||||
from gunicorn import sock
|
||||
return sock.ssl_context(self.cfg)
|
||||
except Exception as e:
|
||||
self.log.error("Failed to create SSL context: %s", e)
|
||||
return None
|
||||
|
||||
def _cleanup(self):
|
||||
"""Clean up resources on exit."""
|
||||
try:
|
||||
# Cancel all pending tasks
|
||||
pending = asyncio.all_tasks(self.loop)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
|
||||
# Run loop until all tasks are cancelled
|
||||
if pending:
|
||||
self.loop.run_until_complete(
|
||||
asyncio.gather(*pending, return_exceptions=True)
|
||||
)
|
||||
|
||||
self.loop.close()
|
||||
except Exception as e:
|
||||
self.log.debug("Cleanup error: %s", e)
|
||||
|
||||
# Close sockets
|
||||
for s in self.sockets:
|
||||
try:
|
||||
s.close()
|
||||
except Exception:
|
||||
pass
|
||||
@ -58,7 +58,7 @@ class TConn:
|
||||
self.sock = sock.ssl_wrap_socket(self.sock, self.cfg)
|
||||
|
||||
# initialize the parser
|
||||
self.parser = http.RequestParser(self.cfg, self.sock, self.client)
|
||||
self.parser = http.get_parser(self.cfg, self.sock, self.client)
|
||||
|
||||
def set_timeout(self):
|
||||
# Use monotonic clock for reliability (time.time() can jump due to NTP)
|
||||
|
||||
@ -129,7 +129,7 @@ class SyncWorker(base.Worker):
|
||||
try:
|
||||
if self.cfg.is_ssl:
|
||||
client = sock.ssl_wrap_socket(client, self.cfg)
|
||||
parser = http.RequestParser(self.cfg, client, addr)
|
||||
parser = http.get_parser(self.cfg, client, addr)
|
||||
req = next(parser)
|
||||
self.handle_request(listener, req, client, addr)
|
||||
except http.errors.NoMoreData as e:
|
||||
|
||||
@ -58,6 +58,7 @@ testing = [
|
||||
"coverage",
|
||||
"pytest",
|
||||
"pytest-cov",
|
||||
"pytest-asyncio",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
@ -70,7 +71,7 @@ main = "gunicorn.app.pasterapp:serve"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
# # can override these: python -m pytest --override-ini="addopts="
|
||||
norecursedirs = ["examples", "lib", "local", "src"]
|
||||
norecursedirs = ["examples", "lib", "local", "src", "tests/docker"]
|
||||
testpaths = ["tests/"]
|
||||
addopts = "--assert=plain --cov=gunicorn --cov-report=xml"
|
||||
|
||||
|
||||
@ -3,3 +3,4 @@ eventlet
|
||||
coverage
|
||||
pytest>=7.2.0
|
||||
pytest-cov
|
||||
pytest-asyncio
|
||||
|
||||
16
tests/docker/uwsgi/Dockerfile.gunicorn
Normal file
16
tests/docker/uwsgi/Dockerfile.gunicorn
Normal file
@ -0,0 +1,16 @@
|
||||
FROM python:3.11-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy gunicorn source
|
||||
COPY . /app/gunicorn-src/
|
||||
|
||||
# Install gunicorn from source
|
||||
RUN pip install --no-cache-dir /app/gunicorn-src/
|
||||
|
||||
# Copy test application
|
||||
COPY tests/docker/uwsgi/app.py /app/
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
CMD ["gunicorn", "--protocol", "uwsgi", "--uwsgi-allow-from", "*", "--bind", "0.0.0.0:8000", "--workers", "2", "--log-level", "debug", "app:application"]
|
||||
12
tests/docker/uwsgi/Dockerfile.nginx
Normal file
12
tests/docker/uwsgi/Dockerfile.nginx
Normal file
@ -0,0 +1,12 @@
|
||||
FROM nginx:alpine
|
||||
|
||||
# Remove default config
|
||||
RUN rm /etc/nginx/conf.d/default.conf
|
||||
|
||||
# Copy custom config
|
||||
COPY nginx.conf /etc/nginx/nginx.conf
|
||||
COPY uwsgi_params /etc/nginx/uwsgi_params
|
||||
|
||||
EXPOSE 8080
|
||||
|
||||
CMD ["nginx", "-g", "daemon off;"]
|
||||
154
tests/docker/uwsgi/README.md
Normal file
154
tests/docker/uwsgi/README.md
Normal file
@ -0,0 +1,154 @@
|
||||
# uWSGI Protocol Docker Integration Tests
|
||||
|
||||
This directory contains Docker-based integration tests that verify gunicorn's
|
||||
uWSGI binary protocol implementation works correctly with nginx's `uwsgi_pass`
|
||||
directive.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
[pytest] --HTTP--> [nginx:8080] --uwsgi_pass--> [gunicorn:8000]
|
||||
```
|
||||
|
||||
The tests make HTTP requests to nginx, which proxies them to gunicorn using the
|
||||
uWSGI binary protocol. This validates the complete request/response cycle through
|
||||
the protocol.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Docker
|
||||
- Docker Compose (v2)
|
||||
- Python 3.8+
|
||||
- pytest
|
||||
- requests
|
||||
|
||||
## Running Tests
|
||||
|
||||
### From repository root:
|
||||
|
||||
```bash
|
||||
# Run all uWSGI integration tests
|
||||
pytest tests/docker/uwsgi/ -v
|
||||
|
||||
# Run specific test class
|
||||
pytest tests/docker/uwsgi/ -v -k TestBasicRequests
|
||||
|
||||
# Skip Docker tests (for CI environments without Docker)
|
||||
pytest tests/ -v -m "not docker"
|
||||
```
|
||||
|
||||
### Manual testing:
|
||||
|
||||
```bash
|
||||
cd tests/docker/uwsgi
|
||||
|
||||
# Start services
|
||||
docker compose up -d
|
||||
|
||||
# Wait for services to be healthy
|
||||
docker compose ps
|
||||
|
||||
# Test endpoints
|
||||
curl http://localhost:8080/
|
||||
curl -X POST -d "test body" http://localhost:8080/echo
|
||||
curl http://localhost:8080/headers
|
||||
curl "http://localhost:8080/query?foo=bar"
|
||||
curl http://localhost:8080/environ
|
||||
curl http://localhost:8080/error/404
|
||||
curl http://localhost:8080/large > /dev/null # 1MB response
|
||||
|
||||
# View logs
|
||||
docker compose logs gunicorn
|
||||
docker compose logs nginx
|
||||
|
||||
# Stop services
|
||||
docker compose down -v
|
||||
```
|
||||
|
||||
## Test Categories
|
||||
|
||||
| Category | Description |
|
||||
|----------|-------------|
|
||||
| `TestBasicRequests` | GET, POST, query strings, large bodies |
|
||||
| `TestHeaderPreservation` | Custom headers, Host, Content-Type, User-Agent |
|
||||
| `TestKeepAlive` | Multiple requests per connection |
|
||||
| `TestErrorResponses` | HTTP error codes (400, 404, 500, etc.) |
|
||||
| `TestEnvironVariables` | WSGI environ: REQUEST_METHOD, PATH_INFO, etc. |
|
||||
| `TestLargeResponses` | 1MB response body streaming |
|
||||
| `TestConcurrency` | Parallel request handling |
|
||||
| `TestSpecialCases` | Edge cases: binary data, unicode, long headers |
|
||||
|
||||
## Files
|
||||
|
||||
| File | Purpose |
|
||||
|------|---------|
|
||||
| `docker-compose.yml` | Orchestrates nginx + gunicorn containers |
|
||||
| `Dockerfile.gunicorn` | Builds gunicorn image with test app |
|
||||
| `Dockerfile.nginx` | Builds nginx with uwsgi config |
|
||||
| `nginx.conf` | nginx configuration using `uwsgi_pass` |
|
||||
| `uwsgi_params` | Standard uwsgi parameter mappings |
|
||||
| `app.py` | Test WSGI application with multiple endpoints |
|
||||
| `conftest.py` | pytest fixtures for Docker lifecycle |
|
||||
| `test_uwsgi_integration.py` | Test cases |
|
||||
|
||||
## Test App Endpoints
|
||||
|
||||
| Endpoint | Method | Description |
|
||||
|----------|--------|-------------|
|
||||
| `/` | GET | Basic hello response |
|
||||
| `/echo` | POST | Echo request body |
|
||||
| `/headers` | GET/POST | Return received headers as JSON |
|
||||
| `/environ` | GET/POST | Return WSGI environ as JSON |
|
||||
| `/query` | GET | Return query params as JSON |
|
||||
| `/json` | POST | Parse and echo JSON body |
|
||||
| `/error/{code}` | GET | Return specified HTTP error |
|
||||
| `/large` | GET | Return 1MB response |
|
||||
|
||||
## Gunicorn Configuration
|
||||
|
||||
The gunicorn container runs with:
|
||||
|
||||
```bash
|
||||
gunicorn \
|
||||
--protocol uwsgi \
|
||||
--uwsgi-allow-from "*" \
|
||||
--bind 0.0.0.0:8000 \
|
||||
--workers 2 \
|
||||
--log-level debug \
|
||||
app:application
|
||||
```
|
||||
|
||||
Key settings:
|
||||
- `--protocol uwsgi`: Enable uWSGI binary protocol
|
||||
- `--uwsgi-allow-from "*"`: Accept connections from Docker network IPs
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Services won't start
|
||||
|
||||
Check Docker logs:
|
||||
```bash
|
||||
docker compose logs
|
||||
```
|
||||
|
||||
### Connection refused
|
||||
|
||||
Wait for health checks:
|
||||
```bash
|
||||
docker compose ps # Check health status
|
||||
```
|
||||
|
||||
### Tests timing out
|
||||
|
||||
Increase `STARTUP_TIMEOUT` in `conftest.py` or check if ports are in use:
|
||||
```bash
|
||||
lsof -i :8080
|
||||
lsof -i :8000
|
||||
```
|
||||
|
||||
### Rebuild after code changes
|
||||
|
||||
```bash
|
||||
docker compose build --no-cache
|
||||
docker compose up -d
|
||||
```
|
||||
222
tests/docker/uwsgi/app.py
Normal file
222
tests/docker/uwsgi/app.py
Normal file
@ -0,0 +1,222 @@
|
||||
"""
|
||||
Test WSGI application for uWSGI protocol integration tests.
|
||||
|
||||
This application provides various endpoints to test different aspects
|
||||
of the uWSGI binary protocol when proxied through nginx.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
|
||||
def application(environ, start_response):
|
||||
"""Main WSGI application entry point."""
|
||||
path = environ.get('PATH_INFO', '/')
|
||||
method = environ.get('REQUEST_METHOD', 'GET')
|
||||
|
||||
# Route to appropriate handler
|
||||
if path == '/':
|
||||
return handle_root(environ, start_response)
|
||||
elif path == '/echo':
|
||||
return handle_echo(environ, start_response)
|
||||
elif path == '/headers':
|
||||
return handle_headers(environ, start_response)
|
||||
elif path == '/environ':
|
||||
return handle_environ(environ, start_response)
|
||||
elif path.startswith('/error/'):
|
||||
return handle_error(environ, start_response, path)
|
||||
elif path == '/large':
|
||||
return handle_large(environ, start_response)
|
||||
elif path == '/json':
|
||||
return handle_json(environ, start_response)
|
||||
elif path == '/query':
|
||||
return handle_query(environ, start_response)
|
||||
else:
|
||||
return handle_not_found(environ, start_response)
|
||||
|
||||
|
||||
def handle_root(environ, start_response):
|
||||
"""Basic root endpoint."""
|
||||
status = '200 OK'
|
||||
headers = [('Content-Type', 'text/plain')]
|
||||
start_response(status, headers)
|
||||
return [b'Hello from gunicorn uWSGI!\n']
|
||||
|
||||
|
||||
def handle_echo(environ, start_response):
|
||||
"""Echo back the request body."""
|
||||
try:
|
||||
content_length = int(environ.get('CONTENT_LENGTH', 0))
|
||||
except (ValueError, TypeError):
|
||||
content_length = 0
|
||||
|
||||
body = b''
|
||||
if content_length > 0:
|
||||
body = environ['wsgi.input'].read(content_length)
|
||||
|
||||
status = '200 OK'
|
||||
headers = [
|
||||
('Content-Type', 'application/octet-stream'),
|
||||
('Content-Length', str(len(body)))
|
||||
]
|
||||
start_response(status, headers)
|
||||
return [body]
|
||||
|
||||
|
||||
def handle_headers(environ, start_response):
|
||||
"""Return received HTTP headers as JSON."""
|
||||
headers_dict = {}
|
||||
for key, value in environ.items():
|
||||
if key.startswith('HTTP_'):
|
||||
# Convert HTTP_X_CUSTOM_HEADER to X-Custom-Header
|
||||
header_name = key[5:].replace('_', '-').title()
|
||||
headers_dict[header_name] = value
|
||||
|
||||
# Also include some special headers
|
||||
if 'CONTENT_TYPE' in environ:
|
||||
headers_dict['Content-Type'] = environ['CONTENT_TYPE']
|
||||
if 'CONTENT_LENGTH' in environ:
|
||||
headers_dict['Content-Length'] = environ['CONTENT_LENGTH']
|
||||
|
||||
body = json.dumps(headers_dict, indent=2).encode('utf-8')
|
||||
status = '200 OK'
|
||||
headers = [
|
||||
('Content-Type', 'application/json'),
|
||||
('Content-Length', str(len(body)))
|
||||
]
|
||||
start_response(status, headers)
|
||||
return [body]
|
||||
|
||||
|
||||
def handle_environ(environ, start_response):
|
||||
"""Return WSGI environ variables as JSON."""
|
||||
# Filter to serializable values
|
||||
safe_environ = {}
|
||||
skip_keys = {'wsgi.input', 'wsgi.errors', 'wsgi.file_wrapper'}
|
||||
|
||||
for key, value in environ.items():
|
||||
if key in skip_keys:
|
||||
continue
|
||||
try:
|
||||
# Test if value is JSON serializable
|
||||
json.dumps(value)
|
||||
safe_environ[key] = value
|
||||
except (TypeError, ValueError):
|
||||
safe_environ[key] = str(value)
|
||||
|
||||
body = json.dumps(safe_environ, indent=2).encode('utf-8')
|
||||
status = '200 OK'
|
||||
headers = [
|
||||
('Content-Type', 'application/json'),
|
||||
('Content-Length', str(len(body)))
|
||||
]
|
||||
start_response(status, headers)
|
||||
return [body]
|
||||
|
||||
|
||||
def handle_error(environ, start_response, path):
|
||||
"""Return specified HTTP error code."""
|
||||
try:
|
||||
code = int(path.split('/')[-1])
|
||||
except ValueError:
|
||||
code = 500
|
||||
|
||||
status_messages = {
|
||||
400: 'Bad Request',
|
||||
401: 'Unauthorized',
|
||||
403: 'Forbidden',
|
||||
404: 'Not Found',
|
||||
500: 'Internal Server Error',
|
||||
502: 'Bad Gateway',
|
||||
503: 'Service Unavailable',
|
||||
}
|
||||
|
||||
message = status_messages.get(code, 'Error')
|
||||
status = f'{code} {message}'
|
||||
body = json.dumps({'error': message, 'code': code}).encode('utf-8')
|
||||
|
||||
headers = [
|
||||
('Content-Type', 'application/json'),
|
||||
('Content-Length', str(len(body)))
|
||||
]
|
||||
start_response(status, headers)
|
||||
return [body]
|
||||
|
||||
|
||||
def handle_large(environ, start_response):
|
||||
"""Return a 1MB response body for testing large responses."""
|
||||
# Generate 1MB of data (1024 * 1024 bytes)
|
||||
chunk_size = 1024
|
||||
num_chunks = 1024
|
||||
chunk = b'X' * chunk_size
|
||||
|
||||
status = '200 OK'
|
||||
headers = [
|
||||
('Content-Type', 'application/octet-stream'),
|
||||
('Content-Length', str(chunk_size * num_chunks))
|
||||
]
|
||||
start_response(status, headers)
|
||||
|
||||
# Return as generator for streaming
|
||||
def generate():
|
||||
for _ in range(num_chunks):
|
||||
yield chunk
|
||||
|
||||
return generate()
|
||||
|
||||
|
||||
def handle_json(environ, start_response):
|
||||
"""Handle JSON POST requests."""
|
||||
try:
|
||||
content_length = int(environ.get('CONTENT_LENGTH', 0))
|
||||
except (ValueError, TypeError):
|
||||
content_length = 0
|
||||
|
||||
if content_length > 0:
|
||||
body = environ['wsgi.input'].read(content_length)
|
||||
try:
|
||||
data = json.loads(body.decode('utf-8'))
|
||||
response = {'received': data, 'status': 'ok'}
|
||||
except json.JSONDecodeError:
|
||||
response = {'error': 'Invalid JSON', 'status': 'error'}
|
||||
else:
|
||||
response = {'error': 'No body', 'status': 'error'}
|
||||
|
||||
body = json.dumps(response).encode('utf-8')
|
||||
status = '200 OK'
|
||||
headers = [
|
||||
('Content-Type', 'application/json'),
|
||||
('Content-Length', str(len(body)))
|
||||
]
|
||||
start_response(status, headers)
|
||||
return [body]
|
||||
|
||||
|
||||
def handle_query(environ, start_response):
|
||||
"""Return query string parameters as JSON."""
|
||||
from urllib.parse import parse_qs
|
||||
query_string = environ.get('QUERY_STRING', '')
|
||||
params = parse_qs(query_string)
|
||||
|
||||
# Convert lists to single values where appropriate
|
||||
simple_params = {k: v[0] if len(v) == 1 else v for k, v in params.items()}
|
||||
|
||||
body = json.dumps(simple_params).encode('utf-8')
|
||||
status = '200 OK'
|
||||
headers = [
|
||||
('Content-Type', 'application/json'),
|
||||
('Content-Length', str(len(body)))
|
||||
]
|
||||
start_response(status, headers)
|
||||
return [body]
|
||||
|
||||
|
||||
def handle_not_found(environ, start_response):
|
||||
"""Handle 404 for unknown paths."""
|
||||
body = json.dumps({'error': 'Not Found', 'path': environ.get('PATH_INFO')}).encode('utf-8')
|
||||
status = '404 Not Found'
|
||||
headers = [
|
||||
('Content-Type', 'application/json'),
|
||||
('Content-Length', str(len(body)))
|
||||
]
|
||||
start_response(status, headers)
|
||||
return [body]
|
||||
121
tests/docker/uwsgi/conftest.py
Normal file
121
tests/docker/uwsgi/conftest.py
Normal file
@ -0,0 +1,121 @@
|
||||
"""
|
||||
pytest fixtures for uWSGI Docker integration tests.
|
||||
"""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
COMPOSE_FILE = os.path.join(os.path.dirname(__file__), 'docker-compose.yml')
|
||||
NGINX_URL = 'http://127.0.0.1:8080'
|
||||
STARTUP_TIMEOUT = 60 # seconds
|
||||
|
||||
|
||||
def is_docker_available():
|
||||
"""Check if Docker is available."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
['docker', 'info'],
|
||||
capture_output=True,
|
||||
timeout=10
|
||||
)
|
||||
return result.returncode == 0
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError):
|
||||
return False
|
||||
|
||||
|
||||
def is_compose_available():
|
||||
"""Check if docker compose is available."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
['docker', 'compose', 'version'],
|
||||
capture_output=True,
|
||||
timeout=10
|
||||
)
|
||||
return result.returncode == 0
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError):
|
||||
return False
|
||||
|
||||
|
||||
docker_available = pytest.mark.skipif(
|
||||
not is_docker_available() or not is_compose_available(),
|
||||
reason="Docker or docker compose not available"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope='session')
|
||||
def docker_services():
|
||||
"""
|
||||
Start Docker Compose services for the test session.
|
||||
|
||||
This fixture builds and starts the gunicorn and nginx containers,
|
||||
waits for them to be healthy, and tears them down after all tests.
|
||||
"""
|
||||
if not is_docker_available() or not is_compose_available():
|
||||
pytest.skip("Docker or docker compose not available")
|
||||
|
||||
# Build and start services
|
||||
subprocess.run(
|
||||
['docker', 'compose', '-f', COMPOSE_FILE, 'build'],
|
||||
check=True,
|
||||
capture_output=True
|
||||
)
|
||||
|
||||
subprocess.run(
|
||||
['docker', 'compose', '-f', COMPOSE_FILE, 'up', '-d'],
|
||||
check=True,
|
||||
capture_output=True
|
||||
)
|
||||
|
||||
# Wait for services to be healthy
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < STARTUP_TIMEOUT:
|
||||
try:
|
||||
response = requests.get(f'{NGINX_URL}/', timeout=2)
|
||||
if response.status_code == 200:
|
||||
break
|
||||
except requests.RequestException:
|
||||
pass
|
||||
time.sleep(1)
|
||||
else:
|
||||
# Get logs for debugging
|
||||
logs = subprocess.run(
|
||||
['docker', 'compose', '-f', COMPOSE_FILE, 'logs'],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
subprocess.run(
|
||||
['docker', 'compose', '-f', COMPOSE_FILE, 'down', '-v'],
|
||||
capture_output=True
|
||||
)
|
||||
pytest.fail(
|
||||
f"Services did not become healthy within {STARTUP_TIMEOUT}s.\n"
|
||||
f"Logs:\n{logs.stdout}\n{logs.stderr}"
|
||||
)
|
||||
|
||||
yield
|
||||
|
||||
# Teardown
|
||||
subprocess.run(
|
||||
['docker', 'compose', '-f', COMPOSE_FILE, 'down', '-v'],
|
||||
capture_output=True
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def nginx_url(docker_services):
|
||||
"""Return the nginx base URL."""
|
||||
return NGINX_URL
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session(docker_services):
|
||||
"""Return a requests Session with keep-alive enabled."""
|
||||
with requests.Session() as s:
|
||||
# Enable keep-alive
|
||||
s.headers['Connection'] = 'keep-alive'
|
||||
yield s
|
||||
29
tests/docker/uwsgi/docker-compose.yml
Normal file
29
tests/docker/uwsgi/docker-compose.yml
Normal file
@ -0,0 +1,29 @@
|
||||
services:
|
||||
gunicorn:
|
||||
build:
|
||||
context: ../../..
|
||||
dockerfile: tests/docker/uwsgi/Dockerfile.gunicorn
|
||||
expose:
|
||||
- "8000"
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import socket; s=socket.socket(); s.connect(('localhost', 8000)); s.close()"]
|
||||
interval: 2s
|
||||
timeout: 5s
|
||||
retries: 10
|
||||
start_period: 5s
|
||||
|
||||
nginx:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.nginx
|
||||
ports:
|
||||
- "8080:8080"
|
||||
depends_on:
|
||||
gunicorn:
|
||||
condition: service_healthy
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8080/"]
|
||||
interval: 2s
|
||||
timeout: 5s
|
||||
retries: 10
|
||||
start_period: 5s
|
||||
46
tests/docker/uwsgi/nginx.conf
Normal file
46
tests/docker/uwsgi/nginx.conf
Normal file
@ -0,0 +1,46 @@
|
||||
worker_processes 1;
|
||||
|
||||
events {
|
||||
worker_connections 1024;
|
||||
}
|
||||
|
||||
http {
|
||||
include /etc/nginx/mime.types;
|
||||
default_type application/octet-stream;
|
||||
|
||||
log_format main '$remote_addr - $remote_user [$time_local] "$request" '
|
||||
'$status $body_bytes_sent "$http_referer" '
|
||||
'"$http_user_agent"';
|
||||
|
||||
access_log /var/log/nginx/access.log main;
|
||||
error_log /var/log/nginx/error.log debug;
|
||||
|
||||
sendfile on;
|
||||
keepalive_timeout 65;
|
||||
|
||||
upstream gunicorn {
|
||||
server gunicorn:8000;
|
||||
}
|
||||
|
||||
server {
|
||||
listen 8080;
|
||||
server_name localhost;
|
||||
|
||||
# Increase buffer sizes for large headers
|
||||
uwsgi_buffer_size 32k;
|
||||
uwsgi_buffers 8 32k;
|
||||
uwsgi_busy_buffers_size 64k;
|
||||
|
||||
# Read timeout for large responses
|
||||
uwsgi_read_timeout 300s;
|
||||
|
||||
location / {
|
||||
uwsgi_pass gunicorn;
|
||||
include uwsgi_params;
|
||||
|
||||
# Pass additional headers
|
||||
uwsgi_param HTTP_X_FORWARDED_FOR $proxy_add_x_forwarded_for;
|
||||
uwsgi_param HTTP_X_REAL_IP $remote_addr;
|
||||
}
|
||||
}
|
||||
}
|
||||
312
tests/docker/uwsgi/test_uwsgi_integration.py
Normal file
312
tests/docker/uwsgi/test_uwsgi_integration.py
Normal file
@ -0,0 +1,312 @@
|
||||
"""
|
||||
Integration tests for gunicorn's uWSGI binary protocol with nginx.
|
||||
|
||||
These tests verify that gunicorn correctly implements the uWSGI binary
|
||||
protocol by running actual requests through nginx's uwsgi_pass directive.
|
||||
"""
|
||||
|
||||
import concurrent.futures
|
||||
import json
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from conftest import docker_available
|
||||
|
||||
|
||||
@docker_available
|
||||
class TestBasicRequests:
|
||||
"""Test basic HTTP request handling through uWSGI protocol."""
|
||||
|
||||
def test_get_root(self, nginx_url):
|
||||
"""Test basic GET request to root endpoint."""
|
||||
response = requests.get(f'{nginx_url}/')
|
||||
assert response.status_code == 200
|
||||
assert b'Hello from gunicorn uWSGI!' in response.content
|
||||
|
||||
def test_get_with_query_string(self, nginx_url):
|
||||
"""Test GET request with query string parameters."""
|
||||
response = requests.get(f'{nginx_url}/query?foo=bar&baz=qux')
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data['foo'] == 'bar'
|
||||
assert data['baz'] == 'qux'
|
||||
|
||||
def test_post_echo(self, nginx_url):
|
||||
"""Test POST request with body echo."""
|
||||
test_body = b'This is a test body content'
|
||||
response = requests.post(f'{nginx_url}/echo', data=test_body)
|
||||
assert response.status_code == 200
|
||||
assert response.content == test_body
|
||||
|
||||
def test_post_json(self, nginx_url):
|
||||
"""Test POST request with JSON body."""
|
||||
test_data = {'key': 'value', 'number': 42, 'nested': {'a': 1}}
|
||||
response = requests.post(
|
||||
f'{nginx_url}/json',
|
||||
json=test_data,
|
||||
headers={'Content-Type': 'application/json'}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data['status'] == 'ok'
|
||||
assert data['received'] == test_data
|
||||
|
||||
def test_post_large_body(self, nginx_url):
|
||||
"""Test POST with large request body (100KB)."""
|
||||
large_body = b'X' * (100 * 1024)
|
||||
response = requests.post(f'{nginx_url}/echo', data=large_body)
|
||||
assert response.status_code == 200
|
||||
assert len(response.content) == len(large_body)
|
||||
assert response.content == large_body
|
||||
|
||||
|
||||
@docker_available
|
||||
class TestHeaderPreservation:
|
||||
"""Test that headers are correctly passed through uWSGI protocol."""
|
||||
|
||||
def test_custom_headers(self, nginx_url):
|
||||
"""Test custom headers are passed to the application."""
|
||||
custom_headers = {
|
||||
'X-Custom-Header': 'custom-value',
|
||||
'X-Another-Header': 'another-value'
|
||||
}
|
||||
response = requests.get(f'{nginx_url}/headers', headers=custom_headers)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data.get('X-Custom-Header') == 'custom-value'
|
||||
assert data.get('X-Another-Header') == 'another-value'
|
||||
|
||||
def test_host_header(self, nginx_url):
|
||||
"""Test Host header is passed correctly."""
|
||||
response = requests.get(
|
||||
f'{nginx_url}/headers',
|
||||
headers={'Host': 'test.example.com'}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data.get('Host') == 'test.example.com'
|
||||
|
||||
def test_content_type_header(self, nginx_url):
|
||||
"""Test Content-Type header is passed correctly."""
|
||||
response = requests.post(
|
||||
f'{nginx_url}/headers',
|
||||
data='test',
|
||||
headers={'Content-Type': 'application/x-custom-type'}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data.get('Content-Type') == 'application/x-custom-type'
|
||||
|
||||
def test_user_agent_header(self, nginx_url):
|
||||
"""Test User-Agent header is passed correctly."""
|
||||
response = requests.get(
|
||||
f'{nginx_url}/headers',
|
||||
headers={'User-Agent': 'TestAgent/1.0'}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data.get('User-Agent') == 'TestAgent/1.0'
|
||||
|
||||
|
||||
@docker_available
|
||||
class TestKeepAlive:
|
||||
"""Test HTTP keep-alive with multiple requests per connection."""
|
||||
|
||||
def test_multiple_requests_same_session(self, session, nginx_url):
|
||||
"""Test multiple requests using same session/connection."""
|
||||
for i in range(5):
|
||||
response = session.get(f'{nginx_url}/')
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_mixed_requests_same_session(self, session, nginx_url):
|
||||
"""Test mixed GET and POST requests using same session."""
|
||||
# GET request
|
||||
response = session.get(f'{nginx_url}/')
|
||||
assert response.status_code == 200
|
||||
|
||||
# POST request
|
||||
response = session.post(f'{nginx_url}/echo', data=b'test')
|
||||
assert response.status_code == 200
|
||||
assert response.content == b'test'
|
||||
|
||||
# Another GET
|
||||
response = session.get(f'{nginx_url}/headers')
|
||||
assert response.status_code == 200
|
||||
|
||||
# JSON POST
|
||||
response = session.post(f'{nginx_url}/json', json={'test': 1})
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@docker_available
|
||||
class TestErrorResponses:
|
||||
"""Test HTTP error responses through uWSGI protocol."""
|
||||
|
||||
@pytest.mark.parametrize('code', [400, 401, 403, 404, 500, 502, 503])
|
||||
def test_error_codes(self, nginx_url, code):
|
||||
"""Test various HTTP error codes are returned correctly."""
|
||||
response = requests.get(f'{nginx_url}/error/{code}')
|
||||
assert response.status_code == code
|
||||
data = response.json()
|
||||
assert data['code'] == code
|
||||
|
||||
def test_not_found(self, nginx_url):
|
||||
"""Test 404 for non-existent path."""
|
||||
response = requests.get(f'{nginx_url}/nonexistent/path')
|
||||
assert response.status_code == 404
|
||||
data = response.json()
|
||||
assert data['error'] == 'Not Found'
|
||||
assert data['path'] == '/nonexistent/path'
|
||||
|
||||
|
||||
@docker_available
|
||||
class TestEnvironVariables:
|
||||
"""Test WSGI environ variables are correctly set."""
|
||||
|
||||
def test_request_method(self, nginx_url):
|
||||
"""Test REQUEST_METHOD is set correctly."""
|
||||
response = requests.get(f'{nginx_url}/environ')
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data.get('REQUEST_METHOD') == 'GET'
|
||||
|
||||
response = requests.post(f'{nginx_url}/environ', data='')
|
||||
data = response.json()
|
||||
assert data.get('REQUEST_METHOD') == 'POST'
|
||||
|
||||
def test_path_info(self, nginx_url):
|
||||
"""Test PATH_INFO is set correctly."""
|
||||
response = requests.get(f'{nginx_url}/environ')
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data.get('PATH_INFO') == '/environ'
|
||||
|
||||
def test_query_string(self, nginx_url):
|
||||
"""Test QUERY_STRING is set correctly."""
|
||||
response = requests.get(f'{nginx_url}/environ?foo=bar&test=123')
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data.get('QUERY_STRING') == 'foo=bar&test=123'
|
||||
|
||||
def test_server_protocol(self, nginx_url):
|
||||
"""Test SERVER_PROTOCOL is set."""
|
||||
response = requests.get(f'{nginx_url}/environ')
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert 'SERVER_PROTOCOL' in data
|
||||
assert data['SERVER_PROTOCOL'].startswith('HTTP/')
|
||||
|
||||
def test_content_length(self, nginx_url):
|
||||
"""Test CONTENT_LENGTH is set for POST requests."""
|
||||
body = 'test body content'
|
||||
response = requests.post(f'{nginx_url}/environ', data=body)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data.get('CONTENT_LENGTH') == str(len(body))
|
||||
|
||||
|
||||
@docker_available
|
||||
class TestLargeResponses:
|
||||
"""Test large response handling through uWSGI protocol."""
|
||||
|
||||
def test_1mb_response(self, nginx_url):
|
||||
"""Test 1MB response body is received correctly."""
|
||||
response = requests.get(f'{nginx_url}/large')
|
||||
assert response.status_code == 200
|
||||
assert len(response.content) == 1024 * 1024
|
||||
# Verify content is all 'X' characters
|
||||
assert response.content == b'X' * (1024 * 1024)
|
||||
|
||||
def test_large_response_content_length(self, nginx_url):
|
||||
"""Test Content-Length header for large response."""
|
||||
response = requests.get(f'{nginx_url}/large')
|
||||
assert response.status_code == 200
|
||||
assert response.headers.get('Content-Length') == str(1024 * 1024)
|
||||
|
||||
|
||||
@docker_available
|
||||
class TestConcurrency:
|
||||
"""Test concurrent request handling."""
|
||||
|
||||
def test_parallel_requests(self, nginx_url):
|
||||
"""Test handling multiple parallel requests."""
|
||||
num_requests = 20
|
||||
|
||||
def make_request(i):
|
||||
response = requests.get(f'{nginx_url}/query?id={i}')
|
||||
return response.status_code, response.json().get('id')
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
|
||||
futures = [executor.submit(make_request, i) for i in range(num_requests)]
|
||||
results = [f.result() for f in concurrent.futures.as_completed(futures)]
|
||||
|
||||
# All requests should succeed
|
||||
assert all(status == 200 for status, _ in results)
|
||||
# All IDs should be present
|
||||
ids = set(id_val for _, id_val in results)
|
||||
assert ids == set(str(i) for i in range(num_requests))
|
||||
|
||||
def test_parallel_mixed_requests(self, nginx_url):
|
||||
"""Test parallel GET and POST requests."""
|
||||
def get_request():
|
||||
return requests.get(f'{nginx_url}/').status_code
|
||||
|
||||
def post_request(data):
|
||||
response = requests.post(f'{nginx_url}/echo', data=data)
|
||||
return response.status_code, response.content
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
|
||||
get_futures = [executor.submit(get_request) for _ in range(10)]
|
||||
post_futures = [
|
||||
executor.submit(post_request, f'data-{i}'.encode())
|
||||
for i in range(10)
|
||||
]
|
||||
|
||||
get_results = [f.result() for f in get_futures]
|
||||
post_results = [f.result() for f in post_futures]
|
||||
|
||||
assert all(status == 200 for status in get_results)
|
||||
assert all(status == 200 for status, _ in post_results)
|
||||
|
||||
|
||||
@docker_available
|
||||
class TestSpecialCases:
|
||||
"""Test edge cases and special scenarios."""
|
||||
|
||||
def test_empty_body_post(self, nginx_url):
|
||||
"""Test POST with empty body."""
|
||||
response = requests.post(f'{nginx_url}/echo', data=b'')
|
||||
assert response.status_code == 200
|
||||
assert response.content == b''
|
||||
|
||||
def test_binary_body(self, nginx_url):
|
||||
"""Test POST with binary body containing null bytes."""
|
||||
binary_data = bytes(range(256))
|
||||
response = requests.post(f'{nginx_url}/echo', data=binary_data)
|
||||
assert response.status_code == 200
|
||||
assert response.content == binary_data
|
||||
|
||||
def test_unicode_in_query_string(self, nginx_url):
|
||||
"""Test unicode characters in query string."""
|
||||
response = requests.get(f'{nginx_url}/query', params={'name': 'test'})
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data.get('name') == 'test'
|
||||
|
||||
def test_special_characters_in_path(self, nginx_url):
|
||||
"""Test handling of special path that triggers 404."""
|
||||
# This should return 404 since the path doesn't exist
|
||||
response = requests.get(f'{nginx_url}/path/with/slashes')
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_long_header_value(self, nginx_url):
|
||||
"""Test handling of long header values."""
|
||||
long_value = 'X' * 4096 # 4KB header value
|
||||
response = requests.get(
|
||||
f'{nginx_url}/headers',
|
||||
headers={'X-Long-Header': long_value}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data.get('X-Long-Header') == long_value
|
||||
16
tests/docker/uwsgi/uwsgi_params
Normal file
16
tests/docker/uwsgi/uwsgi_params
Normal file
@ -0,0 +1,16 @@
|
||||
uwsgi_param QUERY_STRING $query_string;
|
||||
uwsgi_param REQUEST_METHOD $request_method;
|
||||
uwsgi_param CONTENT_TYPE $content_type;
|
||||
uwsgi_param CONTENT_LENGTH $content_length;
|
||||
|
||||
uwsgi_param REQUEST_URI $request_uri;
|
||||
uwsgi_param PATH_INFO $document_uri;
|
||||
uwsgi_param DOCUMENT_ROOT $document_root;
|
||||
uwsgi_param SERVER_PROTOCOL $server_protocol;
|
||||
uwsgi_param REQUEST_SCHEME $scheme;
|
||||
uwsgi_param HTTPS $https if_not_empty;
|
||||
|
||||
uwsgi_param REMOTE_ADDR $remote_addr;
|
||||
uwsgi_param REMOTE_PORT $remote_port;
|
||||
uwsgi_param SERVER_PORT $server_port;
|
||||
uwsgi_param SERVER_NAME $server_name;
|
||||
285
tests/test_asgi.py
Normal file
285
tests/test_asgi.py
Normal file
@ -0,0 +1,285 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
Tests for ASGI worker components.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import pytest
|
||||
from unittest import mock
|
||||
|
||||
from gunicorn.asgi.unreader import AsyncUnreader
|
||||
from gunicorn.asgi.message import AsyncRequest
|
||||
|
||||
|
||||
class MockStreamReader:
|
||||
"""Mock asyncio.StreamReader for testing."""
|
||||
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
self.pos = 0
|
||||
|
||||
async def read(self, size=-1):
|
||||
if self.pos >= len(self.data):
|
||||
return b""
|
||||
if size < 0:
|
||||
result = self.data[self.pos:]
|
||||
self.pos = len(self.data)
|
||||
else:
|
||||
result = self.data[self.pos:self.pos + size]
|
||||
self.pos += size
|
||||
return result
|
||||
|
||||
async def readexactly(self, n):
|
||||
if self.pos + n > len(self.data):
|
||||
raise asyncio.IncompleteReadError(
|
||||
self.data[self.pos:], n
|
||||
)
|
||||
result = self.data[self.pos:self.pos + n]
|
||||
self.pos += n
|
||||
return result
|
||||
|
||||
|
||||
class MockConfig:
|
||||
"""Mock gunicorn config for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self.is_ssl = False
|
||||
self.proxy_protocol = False
|
||||
self.proxy_allow_ips = ["127.0.0.1"]
|
||||
self.forwarded_allow_ips = ["127.0.0.1"]
|
||||
self.secure_scheme_headers = {}
|
||||
self.forwarder_headers = []
|
||||
self.limit_request_line = 8190
|
||||
self.limit_request_fields = 100
|
||||
self.limit_request_field_size = 8190
|
||||
self.permit_unconventional_http_method = False
|
||||
self.permit_unconventional_http_version = False
|
||||
self.permit_obsolete_folding = False
|
||||
self.casefold_http_method = False
|
||||
self.strip_header_spaces = False
|
||||
self.header_map = "refuse"
|
||||
|
||||
|
||||
# AsyncUnreader Tests
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_unreader_read_chunk():
|
||||
"""Test basic chunk reading."""
|
||||
reader = MockStreamReader(b"hello world")
|
||||
unreader = AsyncUnreader(reader)
|
||||
data = await unreader.read()
|
||||
assert data == b"hello world"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_unreader_read_size():
|
||||
"""Test reading specific size."""
|
||||
reader = MockStreamReader(b"hello world")
|
||||
unreader = AsyncUnreader(reader)
|
||||
data = await unreader.read(5)
|
||||
assert data == b"hello"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_unreader_unread():
|
||||
"""Test unread functionality."""
|
||||
reader = MockStreamReader(b"hello world")
|
||||
unreader = AsyncUnreader(reader)
|
||||
|
||||
# Read all data
|
||||
data = await unreader.read()
|
||||
assert data == b"hello world"
|
||||
|
||||
# Unread some data
|
||||
unreader.unread(b"world")
|
||||
|
||||
# Read again should get unread data
|
||||
data = await unreader.read()
|
||||
assert data == b"world"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_unreader_read_zero():
|
||||
"""Test reading zero bytes."""
|
||||
reader = MockStreamReader(b"hello")
|
||||
unreader = AsyncUnreader(reader)
|
||||
data = await unreader.read(0)
|
||||
assert data == b""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_unreader_read_empty():
|
||||
"""Test reading from empty stream."""
|
||||
reader = MockStreamReader(b"")
|
||||
unreader = AsyncUnreader(reader)
|
||||
data = await unreader.read()
|
||||
assert data == b""
|
||||
|
||||
|
||||
# AsyncRequest Tests
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_simple_get():
|
||||
"""Test parsing a simple GET request."""
|
||||
request_data = b"GET /path HTTP/1.1\r\nHost: localhost\r\n\r\n"
|
||||
reader = MockStreamReader(request_data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
assert request.method == "GET"
|
||||
assert request.path == "/path"
|
||||
assert request.version == (1, 1)
|
||||
assert ("HOST", "localhost") in request.headers
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_with_query():
|
||||
"""Test parsing request with query string."""
|
||||
request_data = b"GET /search?q=test&page=1 HTTP/1.1\r\nHost: localhost\r\n\r\n"
|
||||
reader = MockStreamReader(request_data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
assert request.method == "GET"
|
||||
assert request.path == "/search"
|
||||
assert request.query == "q=test&page=1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_post_with_body():
|
||||
"""Test parsing POST request with body."""
|
||||
request_data = (
|
||||
b"POST /submit HTTP/1.1\r\n"
|
||||
b"Host: localhost\r\n"
|
||||
b"Content-Length: 11\r\n"
|
||||
b"\r\n"
|
||||
b"hello=world"
|
||||
)
|
||||
reader = MockStreamReader(request_data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
assert request.method == "POST"
|
||||
assert request.path == "/submit"
|
||||
assert request.content_length == 11
|
||||
|
||||
# Read body
|
||||
body = await request.read_body(100)
|
||||
assert body == b"hello=world"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_multiple_headers():
|
||||
"""Test parsing request with multiple headers."""
|
||||
request_data = (
|
||||
b"GET / HTTP/1.1\r\n"
|
||||
b"Host: localhost\r\n"
|
||||
b"Accept: text/html\r\n"
|
||||
b"Accept-Language: en-US\r\n"
|
||||
b"Connection: keep-alive\r\n"
|
||||
b"\r\n"
|
||||
)
|
||||
reader = MockStreamReader(request_data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
assert len(request.headers) == 4
|
||||
assert request.get_header("HOST") == "localhost"
|
||||
assert request.get_header("ACCEPT") == "text/html"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_should_close_http10():
|
||||
"""Test connection close detection for HTTP/1.0."""
|
||||
request_data = b"GET / HTTP/1.0\r\nHost: localhost\r\n\r\n"
|
||||
reader = MockStreamReader(request_data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
assert request.version == (1, 0)
|
||||
assert request.should_close() is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_should_close_connection_header():
|
||||
"""Test connection close detection with Connection header."""
|
||||
request_data = b"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"
|
||||
reader = MockStreamReader(request_data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
assert request.should_close() is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_keepalive():
|
||||
"""Test keepalive detection."""
|
||||
request_data = b"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: keep-alive\r\n\r\n"
|
||||
reader = MockStreamReader(request_data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
assert request.should_close() is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_no_body_for_get():
|
||||
"""Test that GET requests have no body by default."""
|
||||
request_data = b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n"
|
||||
reader = MockStreamReader(request_data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
assert request.content_length == 0
|
||||
body = await request.read_body()
|
||||
assert body == b""
|
||||
|
||||
|
||||
# Error handling tests
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_invalid_method():
|
||||
"""Test invalid HTTP method detection."""
|
||||
from gunicorn.http.errors import InvalidRequestMethod
|
||||
|
||||
request_data = b"ge!t / HTTP/1.1\r\nHost: localhost\r\n\r\n"
|
||||
reader = MockStreamReader(request_data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
with pytest.raises(InvalidRequestMethod):
|
||||
await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_request_invalid_http_version():
|
||||
"""Test invalid HTTP version detection."""
|
||||
from gunicorn.http.errors import InvalidHTTPVersion
|
||||
|
||||
request_data = b"GET / HTTP/2.0\r\nHost: localhost\r\n\r\n"
|
||||
reader = MockStreamReader(request_data)
|
||||
unreader = AsyncUnreader(reader)
|
||||
cfg = MockConfig()
|
||||
|
||||
with pytest.raises(InvalidHTTPVersion):
|
||||
await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))
|
||||
643
tests/test_asgi_worker.py
Normal file
643
tests/test_asgi_worker.py
Normal file
@ -0,0 +1,643 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
Tests for the ASGI worker.
|
||||
|
||||
Includes unit tests for worker components and integration tests
|
||||
that actually start the server and make HTTP requests.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import errno
|
||||
import os
|
||||
import signal
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
import threading
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from gunicorn.config import Config
|
||||
from gunicorn.workers import gasgi
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Mock Classes
|
||||
# ============================================================================
|
||||
|
||||
class FakeSocket:
|
||||
"""Mock socket for testing."""
|
||||
|
||||
def __init__(self, data=b''):
|
||||
self.data = data
|
||||
self.closed = False
|
||||
self.blocking = True
|
||||
self._fileno = id(self) % 65536
|
||||
|
||||
def fileno(self):
|
||||
return self._fileno
|
||||
|
||||
def setblocking(self, blocking):
|
||||
self.blocking = blocking
|
||||
|
||||
def recv(self, size):
|
||||
if self.closed:
|
||||
raise OSError(errno.EBADF, "Bad file descriptor")
|
||||
result = self.data[:size]
|
||||
self.data = self.data[size:]
|
||||
return result
|
||||
|
||||
def send(self, data):
|
||||
if self.closed:
|
||||
raise OSError(errno.EPIPE, "Broken pipe")
|
||||
return len(data)
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
def getsockname(self):
|
||||
return ('127.0.0.1', 8000)
|
||||
|
||||
def getpeername(self):
|
||||
return ('127.0.0.1', 12345)
|
||||
|
||||
|
||||
class FakeApp:
|
||||
"""Mock ASGI application for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
def wsgi(self):
|
||||
return self.asgi_app
|
||||
|
||||
async def asgi_app(self, scope, receive, send):
|
||||
self.calls.append(scope)
|
||||
if scope["type"] == "lifespan":
|
||||
while True:
|
||||
message = await receive()
|
||||
if message["type"] == "lifespan.startup":
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
elif message["type"] == "lifespan.shutdown":
|
||||
await send({"type": "lifespan.shutdown.complete"})
|
||||
return
|
||||
elif scope["type"] == "http":
|
||||
await send({
|
||||
"type": "http.response.start",
|
||||
"status": 200,
|
||||
"headers": [(b"content-type", b"text/plain")],
|
||||
})
|
||||
await send({
|
||||
"type": "http.response.body",
|
||||
"body": b"Hello from ASGI!",
|
||||
})
|
||||
|
||||
|
||||
class FakeListener:
|
||||
"""Mock listener socket."""
|
||||
|
||||
def __init__(self):
|
||||
self.sock = FakeSocket()
|
||||
|
||||
def getsockname(self):
|
||||
return ('127.0.0.1', 8000)
|
||||
|
||||
def close(self):
|
||||
self.sock.close()
|
||||
|
||||
def __str__(self):
|
||||
return "http://127.0.0.1:8000"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Helper Functions
|
||||
# ============================================================================
|
||||
|
||||
def _has_uvloop():
|
||||
"""Check if uvloop is available."""
|
||||
try:
|
||||
import uvloop
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Unit Tests for ASGIWorker
|
||||
# ============================================================================
|
||||
|
||||
class TestASGIWorkerInit:
|
||||
"""Tests for ASGIWorker initialization."""
|
||||
|
||||
def create_worker(self, **kwargs):
|
||||
"""Create a worker for testing."""
|
||||
cfg = Config()
|
||||
cfg.set('workers', 1)
|
||||
cfg.set('worker_connections', 1000)
|
||||
|
||||
for key, value in kwargs.items():
|
||||
cfg.set(key, value)
|
||||
|
||||
worker = gasgi.ASGIWorker(
|
||||
age=1,
|
||||
ppid=os.getpid(),
|
||||
sockets=[],
|
||||
app=FakeApp(),
|
||||
timeout=30,
|
||||
cfg=cfg,
|
||||
log=mock.Mock(),
|
||||
)
|
||||
return worker
|
||||
|
||||
def test_worker_init(self):
|
||||
"""Test worker initialization."""
|
||||
worker = self.create_worker()
|
||||
|
||||
assert worker.worker_connections == 1000
|
||||
assert worker.nr_conns == 0
|
||||
assert worker.loop is None
|
||||
assert worker.servers == []
|
||||
assert worker.state == {}
|
||||
|
||||
def test_worker_connections_config(self):
|
||||
"""Test worker_connections configuration."""
|
||||
worker = self.create_worker(worker_connections=500)
|
||||
assert worker.worker_connections == 500
|
||||
|
||||
|
||||
class TestASGIWorkerEventLoop:
|
||||
"""Tests for event loop setup."""
|
||||
|
||||
def create_worker(self, **kwargs):
|
||||
"""Create a worker for testing."""
|
||||
cfg = Config()
|
||||
cfg.set('workers', 1)
|
||||
cfg.set('worker_connections', 1000)
|
||||
|
||||
for key, value in kwargs.items():
|
||||
cfg.set(key, value)
|
||||
|
||||
worker = gasgi.ASGIWorker(
|
||||
age=1,
|
||||
ppid=os.getpid(),
|
||||
sockets=[],
|
||||
app=FakeApp(),
|
||||
timeout=30,
|
||||
cfg=cfg,
|
||||
log=mock.Mock(),
|
||||
)
|
||||
return worker
|
||||
|
||||
def test_setup_asyncio_loop(self):
|
||||
"""Test asyncio event loop setup."""
|
||||
worker = self.create_worker(asgi_loop='asyncio')
|
||||
worker._setup_event_loop()
|
||||
|
||||
assert worker.loop is not None
|
||||
assert isinstance(worker.loop, asyncio.AbstractEventLoop)
|
||||
worker.loop.close()
|
||||
|
||||
def test_setup_auto_loop_falls_back_to_asyncio(self):
|
||||
"""Test that auto mode uses asyncio when uvloop unavailable."""
|
||||
worker = self.create_worker(asgi_loop='auto')
|
||||
|
||||
# Mock uvloop import failure
|
||||
with mock.patch.dict('sys.modules', {'uvloop': None}):
|
||||
worker._setup_event_loop()
|
||||
|
||||
assert worker.loop is not None
|
||||
worker.loop.close()
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not _has_uvloop(),
|
||||
reason="uvloop not installed"
|
||||
)
|
||||
def test_setup_uvloop(self):
|
||||
"""Test uvloop event loop setup."""
|
||||
worker = self.create_worker(asgi_loop='uvloop')
|
||||
worker._setup_event_loop()
|
||||
|
||||
import uvloop
|
||||
assert isinstance(worker.loop, uvloop.Loop)
|
||||
worker.loop.close()
|
||||
|
||||
|
||||
class TestASGIWorkerSignals:
|
||||
"""Tests for signal handling."""
|
||||
|
||||
def create_worker(self):
|
||||
"""Create a worker for testing."""
|
||||
cfg = Config()
|
||||
cfg.set('workers', 1)
|
||||
cfg.set('worker_connections', 1000)
|
||||
cfg.set('graceful_timeout', 5)
|
||||
|
||||
worker = gasgi.ASGIWorker(
|
||||
age=1,
|
||||
ppid=os.getpid(),
|
||||
sockets=[],
|
||||
app=FakeApp(),
|
||||
timeout=30,
|
||||
cfg=cfg,
|
||||
log=mock.Mock(),
|
||||
)
|
||||
worker._setup_event_loop()
|
||||
return worker
|
||||
|
||||
def test_handle_exit_sets_alive_false(self):
|
||||
"""Test that exit signal sets alive=False."""
|
||||
worker = self.create_worker()
|
||||
worker.alive = True
|
||||
|
||||
worker.handle_exit_signal()
|
||||
|
||||
assert worker.alive is False
|
||||
worker.loop.close()
|
||||
|
||||
def test_handle_quit_sets_alive_false(self):
|
||||
"""Test that quit signal sets alive=False."""
|
||||
worker = self.create_worker()
|
||||
worker.alive = True
|
||||
|
||||
# Mock the worker_int callback on the worker's cfg settings
|
||||
with mock.patch.object(worker.cfg.settings['worker_int'], 'get', return_value=lambda w: None):
|
||||
worker.handle_quit_signal()
|
||||
|
||||
assert worker.alive is False
|
||||
worker.loop.close()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for Lifespan Protocol
|
||||
# ============================================================================
|
||||
|
||||
class TestLifespanManager:
|
||||
"""Tests for ASGI lifespan protocol."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifespan_startup_complete(self):
|
||||
"""Test successful lifespan startup."""
|
||||
from gunicorn.asgi.lifespan import LifespanManager
|
||||
|
||||
startup_called = False
|
||||
shutdown_called = False
|
||||
|
||||
async def app(scope, receive, send):
|
||||
nonlocal startup_called, shutdown_called
|
||||
assert scope["type"] == "lifespan"
|
||||
while True:
|
||||
message = await receive()
|
||||
if message["type"] == "lifespan.startup":
|
||||
startup_called = True
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
elif message["type"] == "lifespan.shutdown":
|
||||
shutdown_called = True
|
||||
await send({"type": "lifespan.shutdown.complete"})
|
||||
return
|
||||
|
||||
manager = LifespanManager(app, mock.Mock())
|
||||
await manager.startup()
|
||||
|
||||
assert startup_called
|
||||
assert manager._startup_complete.is_set()
|
||||
assert not manager._startup_failed
|
||||
|
||||
await manager.shutdown()
|
||||
assert shutdown_called
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifespan_startup_failed(self):
|
||||
"""Test lifespan startup failure."""
|
||||
from gunicorn.asgi.lifespan import LifespanManager
|
||||
|
||||
async def app(scope, receive, send):
|
||||
message = await receive()
|
||||
if message["type"] == "lifespan.startup":
|
||||
await send({
|
||||
"type": "lifespan.startup.failed",
|
||||
"message": "Database connection failed"
|
||||
})
|
||||
|
||||
manager = LifespanManager(app, mock.Mock())
|
||||
|
||||
with pytest.raises(RuntimeError, match="Database connection failed"):
|
||||
await manager.startup()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifespan_state_shared(self):
|
||||
"""Test that lifespan state is shared with app."""
|
||||
from gunicorn.asgi.lifespan import LifespanManager
|
||||
|
||||
state = {}
|
||||
|
||||
async def app(scope, receive, send):
|
||||
assert "state" in scope
|
||||
scope["state"]["db"] = "connected"
|
||||
message = await receive()
|
||||
await send({"type": "lifespan.startup.complete"})
|
||||
message = await receive()
|
||||
await send({"type": "lifespan.shutdown.complete"})
|
||||
|
||||
manager = LifespanManager(app, mock.Mock(), state)
|
||||
await manager.startup()
|
||||
|
||||
assert state.get("db") == "connected"
|
||||
|
||||
await manager.shutdown()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for WebSocket Protocol
|
||||
# ============================================================================
|
||||
|
||||
class TestWebSocketProtocol:
|
||||
"""Tests for WebSocket protocol handling."""
|
||||
|
||||
def test_websocket_guid(self):
|
||||
"""Test WebSocket GUID constant."""
|
||||
from gunicorn.asgi.websocket import WS_GUID
|
||||
assert WS_GUID == b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
||||
|
||||
def test_websocket_opcodes(self):
|
||||
"""Test WebSocket opcode constants."""
|
||||
from gunicorn.asgi import websocket
|
||||
|
||||
assert websocket.OPCODE_TEXT == 0x1
|
||||
assert websocket.OPCODE_BINARY == 0x2
|
||||
assert websocket.OPCODE_CLOSE == 0x8
|
||||
assert websocket.OPCODE_PING == 0x9
|
||||
assert websocket.OPCODE_PONG == 0xA
|
||||
|
||||
def test_websocket_accept_key_calculation(self):
|
||||
"""Test WebSocket accept key calculation per RFC 6455."""
|
||||
import base64
|
||||
import hashlib
|
||||
from gunicorn.asgi.websocket import WS_GUID
|
||||
|
||||
# Example from RFC 6455
|
||||
client_key = b"dGhlIHNhbXBsZSBub25jZQ=="
|
||||
expected_accept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
|
||||
|
||||
accept_key = base64.b64encode(
|
||||
hashlib.sha1(client_key + WS_GUID).digest()
|
||||
).decode("ascii")
|
||||
|
||||
assert accept_key == expected_accept
|
||||
|
||||
def test_websocket_frame_masking(self):
|
||||
"""Test WebSocket frame unmasking."""
|
||||
from gunicorn.asgi.websocket import WebSocketProtocol
|
||||
|
||||
# Create a minimal protocol instance
|
||||
protocol = WebSocketProtocol(None, None, {}, None, mock.Mock())
|
||||
|
||||
# Test unmasking (XOR operation)
|
||||
masking_key = bytes([0x37, 0xfa, 0x21, 0x3d])
|
||||
masked_data = bytes([0x7f, 0x9f, 0x4d, 0x51, 0x58]) # "Hello" masked
|
||||
|
||||
unmasked = protocol._unmask(masked_data, masking_key)
|
||||
assert unmasked == b"Hello"
|
||||
|
||||
def test_websocket_frame_masking_empty(self):
|
||||
"""Test WebSocket frame unmasking with empty payload."""
|
||||
from gunicorn.asgi.websocket import WebSocketProtocol
|
||||
|
||||
protocol = WebSocketProtocol(None, None, {}, None, mock.Mock())
|
||||
|
||||
masking_key = bytes([0x37, 0xfa, 0x21, 0x3d])
|
||||
unmasked = protocol._unmask(b"", masking_key)
|
||||
assert unmasked == b""
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Integration Tests
|
||||
# ============================================================================
|
||||
|
||||
class TestASGIIntegration:
|
||||
"""Integration tests that start actual servers."""
|
||||
|
||||
@pytest.fixture
|
||||
def free_port(self):
|
||||
"""Get a free port for testing."""
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(('127.0.0.1', 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_request_response(self, free_port):
|
||||
"""Test basic HTTP request/response cycle."""
|
||||
# Simple ASGI app
|
||||
async def app(scope, receive, send):
|
||||
if scope["type"] == "http":
|
||||
await send({
|
||||
"type": "http.response.start",
|
||||
"status": 200,
|
||||
"headers": [(b"content-type", b"text/plain")],
|
||||
})
|
||||
await send({
|
||||
"type": "http.response.body",
|
||||
"body": b"Hello, World!",
|
||||
})
|
||||
|
||||
# Start server
|
||||
loop = asyncio.get_event_loop()
|
||||
server = await loop.create_server(
|
||||
lambda: _TestProtocol(app),
|
||||
'127.0.0.1',
|
||||
free_port,
|
||||
)
|
||||
|
||||
try:
|
||||
# Use asyncio to make HTTP request
|
||||
reader, writer = await asyncio.open_connection('127.0.0.1', free_port)
|
||||
request = f"GET / HTTP/1.1\r\nHost: 127.0.0.1:{free_port}\r\n\r\n"
|
||||
writer.write(request.encode())
|
||||
await writer.drain()
|
||||
|
||||
# Read response
|
||||
response = await reader.read(4096)
|
||||
response_text = response.decode()
|
||||
|
||||
assert "HTTP/1.1 200" in response_text
|
||||
assert "Hello, World!" in response_text
|
||||
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
finally:
|
||||
server.close()
|
||||
await server.wait_closed()
|
||||
|
||||
|
||||
class _TestProtocol(asyncio.Protocol):
|
||||
"""Minimal protocol for integration testing."""
|
||||
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
self.transport = None
|
||||
|
||||
def connection_made(self, transport):
|
||||
self.transport = transport
|
||||
|
||||
def data_received(self, data):
|
||||
# Very simple HTTP parsing for testing
|
||||
asyncio.create_task(self._handle(data))
|
||||
|
||||
async def _handle(self, data):
|
||||
# Parse basic HTTP request
|
||||
lines = data.decode().split('\r\n')
|
||||
method, path, _ = lines[0].split(' ')
|
||||
|
||||
scope = {
|
||||
"type": "http",
|
||||
"asgi": {"version": "3.0"},
|
||||
"http_version": "1.1",
|
||||
"method": method,
|
||||
"path": path,
|
||||
"query_string": b"",
|
||||
"headers": [],
|
||||
"server": ("127.0.0.1", 8000),
|
||||
"client": ("127.0.0.1", 12345),
|
||||
}
|
||||
|
||||
async def receive():
|
||||
return {"type": "http.request", "body": b"", "more_body": False}
|
||||
|
||||
async def send(message):
|
||||
if message["type"] == "http.response.start":
|
||||
status = message["status"]
|
||||
headers = message.get("headers", [])
|
||||
response = f"HTTP/1.1 {status} OK\r\n"
|
||||
for name, value in headers:
|
||||
if isinstance(name, bytes):
|
||||
name = name.decode()
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode()
|
||||
response += f"{name}: {value}\r\n"
|
||||
response += "\r\n"
|
||||
self.transport.write(response.encode())
|
||||
elif message["type"] == "http.response.body":
|
||||
body = message.get("body", b"")
|
||||
self.transport.write(body)
|
||||
if not message.get("more_body", False):
|
||||
self.transport.close()
|
||||
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ASGI Protocol Tests
|
||||
# ============================================================================
|
||||
|
||||
class TestASGIProtocol:
|
||||
"""Tests for ASGIProtocol."""
|
||||
|
||||
def test_reason_phrases(self):
|
||||
"""Test HTTP reason phrase lookup."""
|
||||
from gunicorn.asgi.protocol import ASGIProtocol
|
||||
|
||||
# Create minimal worker mock
|
||||
worker = mock.Mock()
|
||||
worker.cfg = Config()
|
||||
worker.log = mock.Mock()
|
||||
worker.asgi = mock.Mock()
|
||||
|
||||
protocol = ASGIProtocol(worker)
|
||||
|
||||
assert protocol._get_reason_phrase(200) == "OK"
|
||||
assert protocol._get_reason_phrase(404) == "Not Found"
|
||||
assert protocol._get_reason_phrase(500) == "Internal Server Error"
|
||||
assert protocol._get_reason_phrase(999) == "Unknown"
|
||||
|
||||
def test_scope_building(self):
|
||||
"""Test HTTP scope building."""
|
||||
from gunicorn.asgi.protocol import ASGIProtocol
|
||||
from gunicorn.asgi.message import AsyncRequest
|
||||
|
||||
worker = mock.Mock()
|
||||
worker.cfg = Config()
|
||||
worker.cfg.set('root_path', '/api')
|
||||
worker.log = mock.Mock()
|
||||
worker.asgi = mock.Mock()
|
||||
|
||||
protocol = ASGIProtocol(worker)
|
||||
|
||||
# Create mock request
|
||||
request = mock.Mock()
|
||||
request.method = "GET"
|
||||
request.path = "/users"
|
||||
request.query = "page=1"
|
||||
request.version = (1, 1)
|
||||
request.scheme = "http"
|
||||
request.headers = [("HOST", "localhost"), ("ACCEPT", "text/html")]
|
||||
|
||||
scope = protocol._build_http_scope(
|
||||
request,
|
||||
("127.0.0.1", 8000), # sockname
|
||||
("127.0.0.1", 12345), # peername
|
||||
)
|
||||
|
||||
assert scope["type"] == "http"
|
||||
assert scope["method"] == "GET"
|
||||
assert scope["path"] == "/users"
|
||||
assert scope["query_string"] == b"page=1"
|
||||
assert scope["root_path"] == "/api"
|
||||
assert scope["http_version"] == "1.1"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Config Tests
|
||||
# ============================================================================
|
||||
|
||||
class TestASGIConfig:
|
||||
"""Tests for ASGI configuration options."""
|
||||
|
||||
def test_asgi_loop_default(self):
|
||||
"""Test default asgi_loop value."""
|
||||
cfg = Config()
|
||||
assert cfg.asgi_loop == "auto"
|
||||
|
||||
def test_asgi_loop_validation(self):
|
||||
"""Test asgi_loop validation."""
|
||||
cfg = Config()
|
||||
|
||||
cfg.set('asgi_loop', 'asyncio')
|
||||
assert cfg.asgi_loop == 'asyncio'
|
||||
|
||||
cfg.set('asgi_loop', 'uvloop')
|
||||
assert cfg.asgi_loop == 'uvloop'
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
cfg.set('asgi_loop', 'invalid')
|
||||
|
||||
def test_asgi_lifespan_default(self):
|
||||
"""Test default asgi_lifespan value."""
|
||||
cfg = Config()
|
||||
assert cfg.asgi_lifespan == "auto"
|
||||
|
||||
def test_asgi_lifespan_validation(self):
|
||||
"""Test asgi_lifespan validation."""
|
||||
cfg = Config()
|
||||
|
||||
cfg.set('asgi_lifespan', 'on')
|
||||
assert cfg.asgi_lifespan == 'on'
|
||||
|
||||
cfg.set('asgi_lifespan', 'off')
|
||||
assert cfg.asgi_lifespan == 'off'
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
cfg.set('asgi_lifespan', 'invalid')
|
||||
|
||||
def test_root_path_default(self):
|
||||
"""Test default root_path value."""
|
||||
cfg = Config()
|
||||
assert cfg.root_path == ""
|
||||
|
||||
def test_root_path_setting(self):
|
||||
"""Test root_path configuration."""
|
||||
cfg = Config()
|
||||
cfg.set('root_path', '/api/v1')
|
||||
assert cfg.root_path == '/api/v1'
|
||||
435
tests/test_uwsgi.py
Normal file
435
tests/test_uwsgi.py
Normal file
@ -0,0 +1,435 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
import io
|
||||
import pytest
|
||||
from unittest import mock
|
||||
|
||||
from gunicorn.uwsgi import (
|
||||
UWSGIRequest,
|
||||
UWSGIParser,
|
||||
UWSGIParseException,
|
||||
InvalidUWSGIHeader,
|
||||
UnsupportedModifier,
|
||||
ForbiddenUWSGIRequest,
|
||||
)
|
||||
from gunicorn.http.unreader import IterUnreader
|
||||
|
||||
|
||||
def make_uwsgi_packet(vars_dict, modifier1=0, modifier2=0):
|
||||
"""Create uWSGI packet for testing.
|
||||
|
||||
Args:
|
||||
vars_dict: Dict of WSGI environ variables
|
||||
modifier1: Packet type (0 = WSGI request)
|
||||
modifier2: Additional flags
|
||||
|
||||
Returns:
|
||||
bytes: Complete uWSGI packet
|
||||
"""
|
||||
vars_data = b''
|
||||
for key, value in vars_dict.items():
|
||||
k = key.encode('latin-1')
|
||||
v = value.encode('latin-1')
|
||||
vars_data += len(k).to_bytes(2, 'little') + k
|
||||
vars_data += len(v).to_bytes(2, 'little') + v
|
||||
|
||||
header = bytes([modifier1]) + len(vars_data).to_bytes(2, 'little') + bytes([modifier2])
|
||||
return header + vars_data
|
||||
|
||||
|
||||
def make_uwsgi_packet_with_body(vars_dict, body=b'', modifier1=0, modifier2=0):
|
||||
"""Create uWSGI packet with body for testing."""
|
||||
if body:
|
||||
vars_dict = dict(vars_dict)
|
||||
vars_dict['CONTENT_LENGTH'] = str(len(body))
|
||||
return make_uwsgi_packet(vars_dict, modifier1, modifier2) + body
|
||||
|
||||
|
||||
class MockConfig:
|
||||
"""Mock config object for testing."""
|
||||
|
||||
def __init__(self, is_ssl=False, uwsgi_allow_ips=None):
|
||||
self.is_ssl = is_ssl
|
||||
self.uwsgi_allow_ips = uwsgi_allow_ips or ['127.0.0.1', '::1']
|
||||
|
||||
|
||||
class TestUWSGIPacketConstruction:
|
||||
"""Test the packet construction helper."""
|
||||
|
||||
def test_empty_vars(self):
|
||||
packet = make_uwsgi_packet({})
|
||||
assert packet == b'\x00\x00\x00\x00' # modifier1=0, size=0, modifier2=0
|
||||
|
||||
def test_single_var(self):
|
||||
packet = make_uwsgi_packet({'KEY': 'val'})
|
||||
# Header: modifier1(0) + size(10 in LE) + modifier2(0)
|
||||
# Var: key_size(3 in LE) + 'KEY' + val_size(3 in LE) + 'val'
|
||||
# Size = 2 + 3 + 2 + 3 = 10 bytes
|
||||
expected_header = b'\x00\x0a\x00\x00'
|
||||
expected_var = b'\x03\x00KEY\x03\x00val'
|
||||
assert packet == expected_header + expected_var
|
||||
|
||||
def test_multiple_vars(self):
|
||||
packet = make_uwsgi_packet({'A': '1', 'B': '2'})
|
||||
assert len(packet) == 4 + (2 + 1 + 2 + 1) * 2 # header + 2 vars
|
||||
|
||||
|
||||
class TestUWSGIRequest:
|
||||
"""Test UWSGIRequest parsing."""
|
||||
|
||||
def test_parse_simple_request(self):
|
||||
"""Test parsing a simple GET request."""
|
||||
packet = make_uwsgi_packet({
|
||||
'REQUEST_METHOD': 'GET',
|
||||
'PATH_INFO': '/test',
|
||||
'QUERY_STRING': 'foo=bar',
|
||||
})
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig()
|
||||
|
||||
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
|
||||
assert req.method == 'GET'
|
||||
assert req.path == '/test'
|
||||
assert req.query == 'foo=bar'
|
||||
assert req.uri == '/test?foo=bar'
|
||||
|
||||
def test_parse_post_request_with_body(self):
|
||||
"""Test parsing a POST request with body."""
|
||||
body = b'name=test&value=123'
|
||||
packet = make_uwsgi_packet_with_body({
|
||||
'REQUEST_METHOD': 'POST',
|
||||
'PATH_INFO': '/submit',
|
||||
'CONTENT_TYPE': 'application/x-www-form-urlencoded',
|
||||
}, body)
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig()
|
||||
|
||||
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
|
||||
assert req.method == 'POST'
|
||||
assert req.path == '/submit'
|
||||
assert req.body.read() == body
|
||||
|
||||
def test_parse_headers(self):
|
||||
"""Test that HTTP_* vars become headers."""
|
||||
packet = make_uwsgi_packet({
|
||||
'REQUEST_METHOD': 'GET',
|
||||
'PATH_INFO': '/',
|
||||
'HTTP_HOST': 'example.com',
|
||||
'HTTP_USER_AGENT': 'TestClient/1.0',
|
||||
'HTTP_ACCEPT': 'text/html',
|
||||
})
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig()
|
||||
|
||||
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
|
||||
headers_dict = dict(req.headers)
|
||||
assert headers_dict['HOST'] == 'example.com'
|
||||
assert headers_dict['USER-AGENT'] == 'TestClient/1.0'
|
||||
assert headers_dict['ACCEPT'] == 'text/html'
|
||||
|
||||
def test_parse_content_type_header(self):
|
||||
"""Test that CONTENT_TYPE becomes a header."""
|
||||
packet = make_uwsgi_packet({
|
||||
'REQUEST_METHOD': 'POST',
|
||||
'PATH_INFO': '/',
|
||||
'CONTENT_TYPE': 'application/json',
|
||||
'CONTENT_LENGTH': '0',
|
||||
})
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig()
|
||||
|
||||
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
|
||||
headers_dict = dict(req.headers)
|
||||
assert headers_dict['CONTENT-TYPE'] == 'application/json'
|
||||
assert headers_dict['CONTENT-LENGTH'] == '0'
|
||||
|
||||
def test_https_scheme(self):
|
||||
"""Test scheme detection from HTTPS variable."""
|
||||
packet = make_uwsgi_packet({
|
||||
'REQUEST_METHOD': 'GET',
|
||||
'PATH_INFO': '/',
|
||||
'HTTPS': 'on',
|
||||
})
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig()
|
||||
|
||||
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
|
||||
assert req.scheme == 'https'
|
||||
|
||||
def test_wsgi_url_scheme(self):
|
||||
"""Test scheme from wsgi.url_scheme variable."""
|
||||
packet = make_uwsgi_packet({
|
||||
'REQUEST_METHOD': 'GET',
|
||||
'PATH_INFO': '/',
|
||||
'wsgi.url_scheme': 'https',
|
||||
})
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig()
|
||||
|
||||
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
|
||||
assert req.scheme == 'https'
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values when vars are missing."""
|
||||
packet = make_uwsgi_packet({})
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig()
|
||||
|
||||
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
|
||||
assert req.method == 'GET'
|
||||
assert req.path == '/'
|
||||
assert req.query == ''
|
||||
assert req.uri == '/'
|
||||
|
||||
def test_uwsgi_vars_preserved(self):
|
||||
"""Test that all vars are preserved in uwsgi_vars."""
|
||||
packet = make_uwsgi_packet({
|
||||
'REQUEST_METHOD': 'GET',
|
||||
'PATH_INFO': '/',
|
||||
'SERVER_NAME': 'localhost',
|
||||
'SERVER_PORT': '8000',
|
||||
'CUSTOM_VAR': 'custom_value',
|
||||
})
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig()
|
||||
|
||||
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
|
||||
assert req.uwsgi_vars['SERVER_NAME'] == 'localhost'
|
||||
assert req.uwsgi_vars['SERVER_PORT'] == '8000'
|
||||
assert req.uwsgi_vars['CUSTOM_VAR'] == 'custom_value'
|
||||
|
||||
|
||||
class TestUWSGIRequestErrors:
|
||||
"""Test UWSGIRequest error handling."""
|
||||
|
||||
def test_incomplete_header(self):
|
||||
"""Test error on incomplete header."""
|
||||
unreader = IterUnreader([b'\x00\x00']) # Only 2 bytes
|
||||
cfg = MockConfig()
|
||||
|
||||
with pytest.raises(InvalidUWSGIHeader) as exc_info:
|
||||
UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
assert 'incomplete header' in str(exc_info.value)
|
||||
|
||||
def test_incomplete_vars_block(self):
|
||||
"""Test error on truncated vars block."""
|
||||
# Header says 100 bytes of vars, but we only provide 10
|
||||
header = b'\x00\x64\x00\x00' # modifier1=0, size=100, modifier2=0
|
||||
unreader = IterUnreader([header + b'1234567890'])
|
||||
cfg = MockConfig()
|
||||
|
||||
with pytest.raises(InvalidUWSGIHeader) as exc_info:
|
||||
UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
assert 'incomplete vars block' in str(exc_info.value)
|
||||
|
||||
def test_unsupported_modifier(self):
|
||||
"""Test error on non-zero modifier1."""
|
||||
packet = bytes([1]) + b'\x00\x00\x00' # modifier1=1
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig()
|
||||
|
||||
with pytest.raises(UnsupportedModifier) as exc_info:
|
||||
UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
assert exc_info.value.modifier == 1
|
||||
assert exc_info.value.code == 501
|
||||
|
||||
def test_truncated_key_size(self):
|
||||
"""Test error on truncated key size."""
|
||||
header = b'\x00\x01\x00\x00' # size=1, but need at least 2 bytes for key_size
|
||||
unreader = IterUnreader([header + b'X'])
|
||||
cfg = MockConfig()
|
||||
|
||||
with pytest.raises(InvalidUWSGIHeader) as exc_info:
|
||||
UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
assert 'truncated' in str(exc_info.value)
|
||||
|
||||
def test_forbidden_ip(self):
|
||||
"""Test error when source IP not in allow list."""
|
||||
packet = make_uwsgi_packet({'REQUEST_METHOD': 'GET', 'PATH_INFO': '/'})
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig(uwsgi_allow_ips=['192.168.1.1'])
|
||||
|
||||
with pytest.raises(ForbiddenUWSGIRequest) as exc_info:
|
||||
UWSGIRequest(cfg, unreader, ('10.0.0.1', 12345))
|
||||
assert exc_info.value.code == 403
|
||||
assert '10.0.0.1' in str(exc_info.value)
|
||||
|
||||
def test_allowed_ip_wildcard(self):
|
||||
"""Test that wildcard allows any IP."""
|
||||
packet = make_uwsgi_packet({'REQUEST_METHOD': 'GET', 'PATH_INFO': '/'})
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig(uwsgi_allow_ips=['*'])
|
||||
|
||||
# Should not raise
|
||||
req = UWSGIRequest(cfg, unreader, ('10.0.0.1', 12345))
|
||||
assert req.method == 'GET'
|
||||
|
||||
def test_unix_socket_always_allowed(self):
|
||||
"""Test that UNIX socket connections are always allowed."""
|
||||
packet = make_uwsgi_packet({'REQUEST_METHOD': 'GET', 'PATH_INFO': '/'})
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig(uwsgi_allow_ips=['127.0.0.1'])
|
||||
|
||||
# UNIX socket has non-tuple peer_addr
|
||||
req = UWSGIRequest(cfg, unreader, None)
|
||||
assert req.method == 'GET'
|
||||
|
||||
|
||||
class TestUWSGIRequestConnection:
|
||||
"""Test connection handling."""
|
||||
|
||||
def test_should_close_default(self):
|
||||
"""Test default keep-alive behavior."""
|
||||
packet = make_uwsgi_packet({'REQUEST_METHOD': 'GET', 'PATH_INFO': '/'})
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig()
|
||||
|
||||
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
|
||||
assert req.should_close() is False
|
||||
|
||||
def test_should_close_connection_close(self):
|
||||
"""Test Connection: close header."""
|
||||
packet = make_uwsgi_packet({
|
||||
'REQUEST_METHOD': 'GET',
|
||||
'PATH_INFO': '/',
|
||||
'HTTP_CONNECTION': 'close',
|
||||
})
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig()
|
||||
|
||||
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
|
||||
assert req.should_close() is True
|
||||
|
||||
def test_should_close_connection_keepalive(self):
|
||||
"""Test Connection: keep-alive header."""
|
||||
packet = make_uwsgi_packet({
|
||||
'REQUEST_METHOD': 'GET',
|
||||
'PATH_INFO': '/',
|
||||
'HTTP_CONNECTION': 'keep-alive',
|
||||
})
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig()
|
||||
|
||||
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
|
||||
assert req.should_close() is False
|
||||
|
||||
def test_force_close(self):
|
||||
"""Test force_close method."""
|
||||
packet = make_uwsgi_packet({'REQUEST_METHOD': 'GET', 'PATH_INFO': '/'})
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig()
|
||||
|
||||
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
req.force_close()
|
||||
|
||||
assert req.should_close() is True
|
||||
|
||||
|
||||
class TestUWSGIParser:
|
||||
"""Test UWSGIParser."""
|
||||
|
||||
def test_parser_iteration(self):
|
||||
"""Test iterating over parser for multiple requests."""
|
||||
packet = make_uwsgi_packet({
|
||||
'REQUEST_METHOD': 'GET',
|
||||
'PATH_INFO': '/test',
|
||||
'HTTP_CONNECTION': 'close', # Single request
|
||||
})
|
||||
cfg = MockConfig()
|
||||
|
||||
# Parser expects an iterable source, not an unreader
|
||||
parser = UWSGIParser(cfg, [packet], ('127.0.0.1', 12345))
|
||||
req = next(parser)
|
||||
|
||||
assert req.method == 'GET'
|
||||
assert req.path == '/test'
|
||||
|
||||
def test_parser_mesg_class(self):
|
||||
"""Test that parser uses UWSGIRequest."""
|
||||
assert UWSGIParser.mesg_class is UWSGIRequest
|
||||
|
||||
|
||||
class TestExceptionStrings:
|
||||
"""Test exception string representations."""
|
||||
|
||||
def test_invalid_uwsgi_header_str(self):
|
||||
exc = InvalidUWSGIHeader("test message")
|
||||
assert str(exc) == "Invalid uWSGI header: test message"
|
||||
assert exc.code == 400
|
||||
|
||||
def test_unsupported_modifier_str(self):
|
||||
exc = UnsupportedModifier(5)
|
||||
assert str(exc) == "Unsupported uWSGI modifier1: 5"
|
||||
assert exc.code == 501
|
||||
|
||||
def test_forbidden_uwsgi_request_str(self):
|
||||
exc = ForbiddenUWSGIRequest("10.0.0.1")
|
||||
assert str(exc) == "uWSGI request from '10.0.0.1' not allowed"
|
||||
assert exc.code == 403
|
||||
|
||||
|
||||
class TestUWSGIBody:
|
||||
"""Test body reading."""
|
||||
|
||||
def test_read_body_in_chunks(self):
|
||||
"""Test reading body in multiple chunks."""
|
||||
body = b'A' * 1000
|
||||
packet = make_uwsgi_packet_with_body({
|
||||
'REQUEST_METHOD': 'POST',
|
||||
'PATH_INFO': '/',
|
||||
}, body)
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig()
|
||||
|
||||
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
|
||||
result = b''
|
||||
chunk = req.body.read(100)
|
||||
while chunk:
|
||||
result += chunk
|
||||
chunk = req.body.read(100)
|
||||
|
||||
assert result == body
|
||||
|
||||
def test_invalid_content_length(self):
|
||||
"""Test handling of invalid CONTENT_LENGTH."""
|
||||
packet = make_uwsgi_packet({
|
||||
'REQUEST_METHOD': 'POST',
|
||||
'PATH_INFO': '/',
|
||||
'CONTENT_LENGTH': 'invalid',
|
||||
})
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig()
|
||||
|
||||
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
|
||||
# Invalid content length should default to 0
|
||||
assert req.body.read() == b''
|
||||
|
||||
def test_negative_content_length(self):
|
||||
"""Test handling of negative CONTENT_LENGTH."""
|
||||
packet = make_uwsgi_packet({
|
||||
'REQUEST_METHOD': 'POST',
|
||||
'PATH_INFO': '/',
|
||||
'CONTENT_LENGTH': '-5',
|
||||
})
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig()
|
||||
|
||||
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
|
||||
# Negative content length should default to 0
|
||||
assert req.body.read() == b''
|
||||
Loading…
x
Reference in New Issue
Block a user