gunicorn/tests/test_asgi_worker.py
Benoit Chesneau ae1eea8108 asgi: Add native ASGI worker with HTTP and WebSocket support
Add a new ASGI worker type that provides native async support using
gunicorn's own HTTP parsing infrastructure adapted for asyncio.

Features:
- HTTP/1.1 with keepalive support
- WebSocket connections (RFC 6455)
- ASGI lifespan protocol for startup/shutdown hooks
- Optional uvloop support for improved performance
- Full proxy protocol support (inherited from gunicorn)

New configuration options:
- --asgi-loop: Event loop selection (auto/asyncio/uvloop)
- --asgi-lifespan: Lifespan protocol control (auto/on/off)
- --root-path: ASGI root path for reverse proxy setups

Usage: gunicorn -k asgi myapp:app
2026-01-22 17:05:29 +01:00

644 lines
19 KiB
Python

#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
Tests for the ASGI worker.
Includes unit tests for worker components and integration tests
that actually start the server and make HTTP requests.
"""
import asyncio
import errno
import os
import signal
import socket
import sys
import time
import threading
from unittest import mock
import pytest
from gunicorn.config import Config
from gunicorn.workers import gasgi
# ============================================================================
# Mock Classes
# ============================================================================
class FakeSocket:
    """In-memory stand-in for a connected socket.

    ``recv`` hands out successive slices of the bytes given to the
    constructor, ``send`` pretends to write everything, and both raise
    ``OSError`` once ``close`` has been called.
    """

    def __init__(self, data=b''):
        # Bytes not yet returned by recv().
        self.data = data
        self.closed = False
        self.blocking = True
        # Pseudo file descriptor derived from the object's identity.
        self._fileno = id(self) % 65536

    def fileno(self):
        return self._fileno

    def setblocking(self, blocking):
        # Only recorded for inspection; no real I/O mode changes.
        self.blocking = blocking

    def recv(self, size):
        """Remove and return up to ``size`` bytes of buffered data."""
        if self.closed:
            raise OSError(errno.EBADF, "Bad file descriptor")
        chunk, self.data = self.data[:size], self.data[size:]
        return chunk

    def send(self, data):
        """Report all of ``data`` as sent without storing it."""
        if self.closed:
            raise OSError(errno.EPIPE, "Broken pipe")
        return len(data)

    def close(self):
        self.closed = True

    def getsockname(self):
        return ('127.0.0.1', 8000)

    def getpeername(self):
        return ('127.0.0.1', 12345)
class FakeApp:
    """Stub application object exposing a tiny ASGI callable.

    Every scope the callable receives is appended to ``calls`` so tests can
    check what was dispatched to it.
    """

    def __init__(self):
        self.calls = []

    def wsgi(self):
        # The worker is constructed with this object as ``app``; the ASGI
        # callable is handed out through ``wsgi()``.
        return self.asgi_app

    async def asgi_app(self, scope, receive, send):
        self.calls.append(scope)
        kind = scope["type"]

        if kind == "http":
            # Fixed plain-text 200 response; the request body is never read.
            await send({
                "type": "http.response.start",
                "status": 200,
                "headers": [(b"content-type", b"text/plain")],
            })
            await send({
                "type": "http.response.body",
                "body": b"Hello from ASGI!",
            })
            return

        if kind != "lifespan":
            return

        # Acknowledge startup and shutdown; any other event type is ignored
        # and the loop waits for the next one.
        while True:
            event = (await receive())["type"]
            if event == "lifespan.startup":
                await send({"type": "lifespan.startup.complete"})
            elif event == "lifespan.shutdown":
                await send({"type": "lifespan.shutdown.complete"})
                break
class FakeListener:
    """Listener double that wraps a fresh ``FakeSocket``."""

    def __init__(self):
        self.sock = FakeSocket()

    def __str__(self):
        return "http://127.0.0.1:8000"

    def getsockname(self):
        return ('127.0.0.1', 8000)

    def close(self):
        # Closing the listener closes the wrapped socket.
        self.sock.close()
# ============================================================================
# Helper Functions
# ============================================================================
def _has_uvloop():
    """Return True if the optional ``uvloop`` package can be imported."""
    try:
        import uvloop  # noqa: F401  (imported only to probe availability)
    except ImportError:
        return False
    return True
# ============================================================================
# Unit Tests for ASGIWorker
# ============================================================================
class TestASGIWorkerInit:
    """ASGIWorker construction and the state it starts out with."""

    def create_worker(self, **kwargs):
        """Build an ASGIWorker around ``FakeApp`` with a one-worker config.

        Keyword arguments are applied as config settings after the defaults
        below, so they can override them.
        """
        cfg = Config()
        for name, value in (('workers', 1), ('worker_connections', 1000), *kwargs.items()):
            cfg.set(name, value)
        return gasgi.ASGIWorker(
            age=1,
            ppid=os.getpid(),
            sockets=[],
            app=FakeApp(),
            timeout=30,
            cfg=cfg,
            log=mock.Mock(),
        )

    def test_worker_init(self):
        """A fresh worker has no loop, no servers, no connections and empty state."""
        w = self.create_worker()
        assert w.worker_connections == 1000
        assert w.nr_conns == 0
        assert w.loop is None
        assert w.servers == []
        assert w.state == {}

    def test_worker_connections_config(self):
        """worker_connections is taken from the config."""
        assert self.create_worker(worker_connections=500).worker_connections == 500
class TestASGIWorkerEventLoop:
    """Tests for ``ASGIWorker._setup_event_loop`` and the ``asgi_loop`` setting."""

    def create_worker(self, **kwargs):
        """Create a worker around ``FakeApp``; kwargs are applied as config settings."""
        cfg = Config()
        cfg.set('workers', 1)
        cfg.set('worker_connections', 1000)
        for key, value in kwargs.items():
            cfg.set(key, value)
        worker = gasgi.ASGIWorker(
            age=1,
            ppid=os.getpid(),
            sockets=[],
            app=FakeApp(),
            timeout=30,
            cfg=cfg,
            log=mock.Mock(),
        )
        return worker

    @staticmethod
    def _close_loop(worker):
        # Close the worker's loop if one was created. Called from ``finally``
        # so a failing assertion does not leak an open event loop.
        if worker.loop is not None:
            worker.loop.close()

    def test_setup_asyncio_loop(self):
        """asgi_loop='asyncio' produces an asyncio event loop."""
        worker = self.create_worker(asgi_loop='asyncio')
        worker._setup_event_loop()
        try:
            assert worker.loop is not None
            assert isinstance(worker.loop, asyncio.AbstractEventLoop)
        finally:
            self._close_loop(worker)

    def test_setup_auto_loop_falls_back_to_asyncio(self):
        """asgi_loop='auto' uses a non-uvloop asyncio loop when uvloop can't be imported."""
        worker = self.create_worker(asgi_loop='auto')
        # A None entry in sys.modules makes ``import uvloop`` raise ImportError.
        with mock.patch.dict('sys.modules', {'uvloop': None}):
            worker._setup_event_loop()
        try:
            assert isinstance(worker.loop, asyncio.AbstractEventLoop)
            # Checking only "some loop exists" would also pass if uvloop were
            # picked; make sure the fallback really happened.
            assert not type(worker.loop).__module__.startswith('uvloop')
        finally:
            self._close_loop(worker)

    @pytest.mark.skipif(
        not _has_uvloop(),
        reason="uvloop not installed"
    )
    def test_setup_uvloop(self):
        """asgi_loop='uvloop' produces a uvloop loop."""
        worker = self.create_worker(asgi_loop='uvloop')
        worker._setup_event_loop()
        try:
            import uvloop
            assert isinstance(worker.loop, uvloop.Loop)
        finally:
            self._close_loop(worker)
class TestASGIWorkerSignals:
    """Tests for the worker's exit and quit signal handlers."""

    def create_worker(self, **kwargs):
        """Create a worker (graceful_timeout=5) and set up its event loop.

        Extra keyword arguments are applied as config settings after the
        defaults below. The caller owns ``worker.loop`` and must close it.
        """
        cfg = Config()
        cfg.set('workers', 1)
        cfg.set('worker_connections', 1000)
        cfg.set('graceful_timeout', 5)
        for key, value in kwargs.items():
            cfg.set(key, value)
        worker = gasgi.ASGIWorker(
            age=1,
            ppid=os.getpid(),
            sockets=[],
            app=FakeApp(),
            timeout=30,
            cfg=cfg,
            log=mock.Mock(),
        )
        worker._setup_event_loop()
        return worker

    def test_handle_exit_sets_alive_false(self):
        """handle_exit_signal() sets alive=False."""
        worker = self.create_worker()
        try:
            worker.alive = True
            worker.handle_exit_signal()
            assert worker.alive is False
        finally:
            # Close even if the assertion fails so the loop isn't leaked.
            worker.loop.close()

    def test_handle_quit_sets_alive_false(self):
        """handle_quit_signal() sets alive=False."""
        worker = self.create_worker()
        try:
            worker.alive = True
            # Make the worker_int setting's getter return a no-op hook;
            # presumably handle_quit_signal invokes that hook.
            with mock.patch.object(worker.cfg.settings['worker_int'], 'get',
                                   return_value=lambda w: None):
                worker.handle_quit_signal()
            assert worker.alive is False
        finally:
            worker.loop.close()
# ============================================================================
# Tests for Lifespan Protocol
# ============================================================================
class TestLifespanManager:
    """Tests for ``gunicorn.asgi.lifespan.LifespanManager``."""

    @pytest.mark.asyncio
    async def test_lifespan_startup_complete(self):
        """Startup and shutdown both reach the app and are acknowledged."""
        from gunicorn.asgi.lifespan import LifespanManager
        startup_called = False
        shutdown_called = False

        async def app(scope, receive, send):
            nonlocal startup_called, shutdown_called
            assert scope["type"] == "lifespan"
            while True:
                message = await receive()
                if message["type"] == "lifespan.startup":
                    startup_called = True
                    await send({"type": "lifespan.startup.complete"})
                elif message["type"] == "lifespan.shutdown":
                    shutdown_called = True
                    await send({"type": "lifespan.shutdown.complete"})
                    return

        manager = LifespanManager(app, mock.Mock())
        await manager.startup()
        try:
            assert startup_called
            assert manager._startup_complete.is_set()
            assert not manager._startup_failed
        finally:
            # Shut down even if an assertion fails, so the app's receive()
            # loop is not left pending.
            await manager.shutdown()
        assert shutdown_called

    @pytest.mark.asyncio
    async def test_lifespan_startup_failed(self):
        """A lifespan.startup.failed event makes startup() raise RuntimeError."""
        from gunicorn.asgi.lifespan import LifespanManager

        async def app(scope, receive, send):
            message = await receive()
            if message["type"] == "lifespan.startup":
                await send({
                    "type": "lifespan.startup.failed",
                    "message": "Database connection failed"
                })

        manager = LifespanManager(app, mock.Mock())
        with pytest.raises(RuntimeError, match="Database connection failed"):
            await manager.startup()

    @pytest.mark.asyncio
    async def test_lifespan_state_shared(self):
        """The state dict given to the manager is the one the app sees in its scope."""
        from gunicorn.asgi.lifespan import LifespanManager
        state = {}
        # Event types the app received, in order.
        received = []

        async def app(scope, receive, send):
            assert "state" in scope
            scope["state"]["db"] = "connected"
            received.append((await receive())["type"])
            await send({"type": "lifespan.startup.complete"})
            received.append((await receive())["type"])
            await send({"type": "lifespan.shutdown.complete"})

        manager = LifespanManager(app, mock.Mock(), state)
        await manager.startup()
        try:
            assert state.get("db") == "connected"
        finally:
            await manager.shutdown()
        # The app acknowledges events without looking at them; verify they
        # were the expected startup/shutdown pair.
        assert received == ["lifespan.startup", "lifespan.shutdown"]
# ============================================================================
# Tests for WebSocket Protocol
# ============================================================================
class TestWebSocketProtocol:
    """Constants and frame helpers from ``gunicorn.asgi.websocket``."""

    @staticmethod
    def _make_protocol():
        # Bare instance built from placeholder arguments; the trailing Mock
        # presumably stands in for the logger.
        from gunicorn.asgi.websocket import WebSocketProtocol
        return WebSocketProtocol(None, None, {}, None, mock.Mock())

    def test_websocket_guid(self):
        """WS_GUID equals the fixed GUID from RFC 6455."""
        from gunicorn.asgi.websocket import WS_GUID
        assert WS_GUID == b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

    def test_websocket_opcodes(self):
        """The opcode constants carry their RFC 6455 values."""
        from gunicorn.asgi import websocket
        expected = {
            "OPCODE_TEXT": 0x1,
            "OPCODE_BINARY": 0x2,
            "OPCODE_CLOSE": 0x8,
            "OPCODE_PING": 0x9,
            "OPCODE_PONG": 0xA,
        }
        assert {name: getattr(websocket, name) for name in expected} == expected

    def test_websocket_accept_key_calculation(self):
        """base64(sha1(key + WS_GUID)) reproduces the RFC 6455 handshake example."""
        import base64
        import hashlib
        from gunicorn.asgi.websocket import WS_GUID
        client_key = b"dGhlIHNhbXBsZSBub25jZQ=="
        digest = hashlib.sha1(client_key + WS_GUID).digest()
        assert base64.b64encode(digest).decode("ascii") == "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="

    def test_websocket_frame_masking(self):
        """_unmask XORs the payload with the repeating 4-byte masking key."""
        key = bytes([0x37, 0xfa, 0x21, 0x3d])
        # "Hello" masked with the key above (the RFC 6455 section 5.7 example).
        payload = bytes([0x7f, 0x9f, 0x4d, 0x51, 0x58])
        assert self._make_protocol()._unmask(payload, key) == b"Hello"

    def test_websocket_frame_masking_empty(self):
        """An empty payload unmasks to an empty payload."""
        key = bytes([0x37, 0xfa, 0x21, 0x3d])
        assert self._make_protocol()._unmask(b"", key) == b""
# ============================================================================
# Integration Tests
# ============================================================================
class TestASGIIntegration:
    """Integration tests that start a real asyncio TCP server."""

    @pytest.fixture
    def free_port(self):
        """Return a port that was free when this fixture ran.

        Kept for compatibility. It is racy because the port is released
        before the caller binds it, so the tests below bind port 0 instead.
        """
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(('127.0.0.1', 0))
            return s.getsockname()[1]

    @pytest.mark.asyncio
    async def test_http_request_response(self):
        """A GET request gets the app's 200 status and body back."""
        async def app(scope, receive, send):
            if scope["type"] == "http":
                await send({
                    "type": "http.response.start",
                    "status": 200,
                    "headers": [(b"content-type", b"text/plain")],
                })
                await send({
                    "type": "http.response.body",
                    "body": b"Hello, World!",
                })

        loop = asyncio.get_running_loop()
        # Port 0 lets the OS choose a free port atomically; read the real
        # port back from the bound socket instead of guessing one up front.
        server = await loop.create_server(
            lambda: _TestProtocol(app),
            '127.0.0.1',
            0,
        )
        port = server.sockets[0].getsockname()[1]
        try:
            reader, writer = await asyncio.wait_for(
                asyncio.open_connection('127.0.0.1', port), timeout=5)
            try:
                request = f"GET / HTTP/1.1\r\nHost: 127.0.0.1:{port}\r\n\r\n"
                writer.write(request.encode())
                await writer.drain()
                # Headers and body are written separately and the server
                # closes the connection after the body, so read until EOF
                # instead of trusting a single read() to return everything.
                response = await asyncio.wait_for(reader.read(), timeout=5)
            finally:
                writer.close()
                await writer.wait_closed()
            response_text = response.decode()
            assert "HTTP/1.1 200" in response_text
            assert "Hello, World!" in response_text
        finally:
            server.close()
            await server.wait_closed()
class _TestProtocol(asyncio.Protocol):
    """Minimal HTTP/1.1 server protocol used by the integration tests.

    Each received chunk is treated as one complete request: only its request
    line is parsed, the ASGI app is run once, and the connection is closed
    after the final response body, or when the app finishes without one.
    """

    def __init__(self, app):
        self.app = app
        self.transport = None
        # Strong references to in-flight handler tasks. The event loop keeps
        # only weak references to tasks, so an unreferenced task can be
        # garbage-collected before it finishes.
        self._tasks = set()

    def connection_made(self, transport):
        self.transport = transport

    def data_received(self, data):
        # Very simple HTTP handling: hand the chunk to a background task.
        task = asyncio.create_task(self._handle(data))
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)

    @staticmethod
    def _status_line(status):
        """Return the HTTP/1.1 status line with the standard reason phrase."""
        from http import HTTPStatus
        try:
            reason = HTTPStatus(status).phrase
        except ValueError:
            reason = "Unknown"
        return f"HTTP/1.1 {status} {reason}\r\n"

    async def _handle(self, data):
        # Parse only the request line: "METHOD PATH VERSION".
        lines = data.decode().split('\r\n')
        method, path, _ = lines[0].split(' ')
        scope = {
            "type": "http",
            "asgi": {"version": "3.0"},
            "http_version": "1.1",
            "method": method,
            "path": path,
            "query_string": b"",
            "headers": [],
            "server": ("127.0.0.1", 8000),
            "client": ("127.0.0.1", 12345),
        }

        async def receive():
            return {"type": "http.request", "body": b"", "more_body": False}

        async def send(message):
            if message["type"] == "http.response.start":
                parts = [self._status_line(message["status"])]
                for name, value in message.get("headers", []):
                    if isinstance(name, bytes):
                        name = name.decode()
                    if isinstance(value, bytes):
                        value = value.decode()
                    parts.append(f"{name}: {value}\r\n")
                parts.append("\r\n")
                self.transport.write("".join(parts).encode())
            elif message["type"] == "http.response.body":
                self.transport.write(message.get("body", b""))
                if not message.get("more_body", False):
                    self.transport.close()

        try:
            await self.app(scope, receive, send)
        finally:
            # If the app raised or returned without finishing the response,
            # close anyway so the client sees EOF instead of hanging.
            if not self.transport.is_closing():
                self.transport.close()
# ============================================================================
# ASGI Protocol Tests
# ============================================================================
class TestASGIProtocol:
    """Tests for ``gunicorn.asgi.protocol.ASGIProtocol`` helpers."""

    @staticmethod
    def _make_protocol(**settings):
        """Build an ASGIProtocol around a mocked worker.

        ``settings`` are applied to the worker's Config before the protocol
        is constructed, in case the constructor reads them.
        """
        from gunicorn.asgi.protocol import ASGIProtocol
        worker = mock.Mock()
        worker.cfg = Config()
        for key, value in settings.items():
            worker.cfg.set(key, value)
        worker.log = mock.Mock()
        worker.asgi = mock.Mock()
        return ASGIProtocol(worker)

    def test_reason_phrases(self):
        """Known status codes map to their reason phrases; unknown ones to 'Unknown'."""
        protocol = self._make_protocol()
        assert protocol._get_reason_phrase(200) == "OK"
        assert protocol._get_reason_phrase(404) == "Not Found"
        assert protocol._get_reason_phrase(500) == "Internal Server Error"
        assert protocol._get_reason_phrase(999) == "Unknown"

    def test_scope_building(self):
        """_build_http_scope maps request fields and root_path into the scope."""
        protocol = self._make_protocol(root_path='/api')
        request = mock.Mock()
        request.method = "GET"
        request.path = "/users"
        request.query = "page=1"
        request.version = (1, 1)
        request.scheme = "http"
        request.headers = [("HOST", "localhost"), ("ACCEPT", "text/html")]
        scope = protocol._build_http_scope(
            request,
            ("127.0.0.1", 8000),   # sockname
            ("127.0.0.1", 12345),  # peername
        )
        assert scope["type"] == "http"
        assert scope["method"] == "GET"
        assert scope["path"] == "/users"
        assert scope["query_string"] == b"page=1"
        assert scope["root_path"] == "/api"
        assert scope["http_version"] == "1.1"
# ============================================================================
# Config Tests
# ============================================================================
class TestASGIConfig:
    """Defaults and validation of the ASGI-specific settings."""

    def test_asgi_loop_default(self):
        """asgi_loop defaults to 'auto'."""
        assert Config().asgi_loop == "auto"

    def test_asgi_loop_validation(self):
        """asgi_loop accepts 'asyncio' and 'uvloop' and rejects anything else."""
        cfg = Config()
        for choice in ('asyncio', 'uvloop'):
            cfg.set('asgi_loop', choice)
            assert cfg.asgi_loop == choice
        with pytest.raises(ValueError):
            cfg.set('asgi_loop', 'invalid')

    def test_asgi_lifespan_default(self):
        """asgi_lifespan defaults to 'auto'."""
        assert Config().asgi_lifespan == "auto"

    def test_asgi_lifespan_validation(self):
        """asgi_lifespan accepts 'on' and 'off' and rejects anything else."""
        cfg = Config()
        for choice in ('on', 'off'):
            cfg.set('asgi_lifespan', choice)
            assert cfg.asgi_lifespan == choice
        with pytest.raises(ValueError):
            cfg.set('asgi_lifespan', 'invalid')

    def test_root_path_default(self):
        """root_path defaults to the empty string."""
        assert Config().root_path == ""

    def test_root_path_setting(self):
        """root_path stores the configured prefix verbatim."""
        cfg = Config()
        cfg.set('root_path', '/api/v1')
        assert cfg.root_path == '/api/v1'