#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.

"""
Tests for the ASGI worker.

Includes unit tests for worker components and integration tests
that actually start the server and make HTTP requests.
"""

import asyncio
import errno
import os
import signal
import socket
import sys
import time
import threading
from unittest import mock

import pytest

from gunicorn.config import Config
from gunicorn.workers import gasgi


# ============================================================================
# Mock Classes
# ============================================================================

class FakeSocket:
    """In-memory stand-in for a socket.

    ``recv`` drains the bytes given at construction time, ``send`` pretends
    every byte was written, and both raise ``OSError`` once ``close`` has
    been called.
    """

    def __init__(self, data=b''):
        self.data = data
        self.closed = False
        self.blocking = True
        # Derive a small pseudo file descriptor from the object identity.
        self._fileno = id(self) % 65536

    def fileno(self):
        """Return the pseudo file descriptor chosen at construction."""
        return self._fileno

    def setblocking(self, blocking):
        """Record the requested blocking mode."""
        self.blocking = blocking

    def recv(self, size):
        """Return up to ``size`` buffered bytes; raise EBADF when closed."""
        if self.closed:
            raise OSError(errno.EBADF, "Bad file descriptor")
        result = self.data[:size]
        self.data = self.data[size:]
        return result

    def send(self, data):
        """Report ``len(data)`` bytes as sent; raise EPIPE when closed."""
        if self.closed:
            raise OSError(errno.EPIPE, "Broken pipe")
        return len(data)

    def close(self):
        """Mark the socket as closed."""
        self.closed = True

    def getsockname(self):
        """Return a fixed local address."""
        return ('127.0.0.1', 8000)

    def getpeername(self):
        """Return a fixed remote address."""
        return ('127.0.0.1', 12345)


class FakeApp:
    """Mock ASGI application that records every scope it is called with."""

    def __init__(self):
        self.calls = []

    def wsgi(self):
        """Return the ASGI callable (the worker is handed this object as ``app``)."""
        return self.asgi_app

    async def asgi_app(self, scope, receive, send):
        """Answer lifespan events and reply 200 to any HTTP request."""
        self.calls.append(scope)

        if scope["type"] == "lifespan":
            while True:
                message = await receive()
                if message["type"] == "lifespan.startup":
                    await send({"type": "lifespan.startup.complete"})
                elif message["type"] == "lifespan.shutdown":
                    await send({"type": "lifespan.shutdown.complete"})
                    return
        elif scope["type"] == "http":
            await send({
                "type": "http.response.start",
                "status": 200,
                "headers": [(b"content-type", b"text/plain")],
            })
            await send({
                "type": "http.response.body",
                "body": b"Hello from ASGI!",
            })


class FakeListener:
    """Mock listener socket."""

    def __init__(self):
        self.sock = FakeSocket()

    def getsockname(self):
        """Return a fixed listening address."""
        return ('127.0.0.1', 8000)

    def close(self):
        """Close the wrapped fake socket."""
        self.sock.close()

    def __str__(self):
        return "http://127.0.0.1:8000"


# ============================================================================
# Helper Functions
# ============================================================================

def _has_uvloop():
    """Check if uvloop is available."""
    try:
        import uvloop  # noqa: F401 - imported only to probe availability
        return True
    except ImportError:
        return False


def _build_worker(**settings):
    """Create an ``ASGIWorker`` backed by fakes.

    The config starts from ``workers=1`` and ``worker_connections=1000``;
    every keyword argument is then applied with ``cfg.set`` and may override
    those defaults.
    """
    cfg = Config()
    cfg.set('workers', 1)
    cfg.set('worker_connections', 1000)
    for key, value in settings.items():
        cfg.set(key, value)

    return gasgi.ASGIWorker(
        age=1,
        ppid=os.getpid(),
        sockets=[],
        app=FakeApp(),
        timeout=30,
        cfg=cfg,
        log=mock.Mock(),
    )


def _close_loop(worker):
    """Close the worker's event loop if one was created and is still open."""
    loop = worker.loop
    if loop is not None and not loop.is_closed():
        loop.close()


# ============================================================================
# Unit Tests for ASGIWorker
# ============================================================================

class TestASGIWorkerInit:
    """Tests for ASGIWorker initialization."""

    def create_worker(self, **kwargs):
        """Create a worker for testing; ``kwargs`` are extra config settings."""
        return _build_worker(**kwargs)

    def test_worker_init(self):
        """Test worker initialization."""
        worker = self.create_worker()

        assert worker.worker_connections == 1000
        assert worker.nr_conns == 0
        assert worker.loop is None
        assert worker.servers == []
        assert worker.state == {}

    def test_worker_connections_config(self):
        """Test worker_connections configuration."""
        worker = self.create_worker(worker_connections=500)
        assert worker.worker_connections == 500


class TestASGIWorkerEventLoop:
    """Tests for event loop setup.

    Every test closes the loop in a ``finally`` block so that a failing
    assertion cannot leak an open event loop into later tests.
    """

    def create_worker(self, **kwargs):
        """Create a worker for testing; ``kwargs`` are extra config settings."""
        return _build_worker(**kwargs)

    def test_setup_asyncio_loop(self):
        """Test asyncio event loop setup."""
        worker = self.create_worker(asgi_loop='asyncio')
        try:
            worker._setup_event_loop()

            assert worker.loop is not None
            assert isinstance(worker.loop, asyncio.AbstractEventLoop)
        finally:
            _close_loop(worker)

    def test_setup_auto_loop_falls_back_to_asyncio(self):
        """Test that auto mode uses asyncio when uvloop unavailable."""
        worker = self.create_worker(asgi_loop='auto')
        try:
            # A ``None`` entry in sys.modules makes ``import uvloop`` raise
            # ImportError, simulating a missing package.
            with mock.patch.dict('sys.modules', {'uvloop': None}):
                worker._setup_event_loop()

            assert worker.loop is not None
        finally:
            _close_loop(worker)

    @pytest.mark.skipif(
        not _has_uvloop(),
        reason="uvloop not installed"
    )
    def test_setup_uvloop(self):
        """Test uvloop event loop setup."""
        worker = self.create_worker(asgi_loop='uvloop')
        try:
            worker._setup_event_loop()

            import uvloop
            assert isinstance(worker.loop, uvloop.Loop)
        finally:
            _close_loop(worker)


class TestASGIWorkerSignals:
    """Tests for signal handling."""

    def create_worker(self):
        """Create a worker with a 5 second graceful timeout and a live loop."""
        worker = _build_worker(graceful_timeout=5)
        worker._setup_event_loop()
        return worker

    def test_handle_exit_sets_alive_false(self):
        """Test that exit signal sets alive=False."""
        worker = self.create_worker()
        try:
            worker.alive = True

            worker.handle_exit_signal()

            assert worker.alive is False
        finally:
            _close_loop(worker)

    def test_handle_quit_sets_alive_false(self):
        """Test that quit signal sets alive=False."""
        worker = self.create_worker()
        try:
            worker.alive = True

            # Mock the worker_int callback on the worker's cfg settings
            with mock.patch.object(worker.cfg.settings['worker_int'], 'get',
                                   return_value=lambda w: None):
                worker.handle_quit_signal()

            assert worker.alive is False
        finally:
            _close_loop(worker)


# ============================================================================
# Tests for Lifespan Protocol
# ============================================================================

class TestLifespanManager:
    """Tests for ASGI lifespan protocol."""

    @pytest.mark.asyncio
    async def test_lifespan_startup_complete(self):
        """Test successful lifespan startup."""
        from gunicorn.asgi.lifespan import LifespanManager

        startup_called = False
        shutdown_called = False

        async def app(scope, receive, send):
            nonlocal startup_called, shutdown_called
            assert scope["type"] == "lifespan"

            while True:
                message = await receive()
                if message["type"] == "lifespan.startup":
                    startup_called = True
                    await send({"type": "lifespan.startup.complete"})
                elif message["type"] == "lifespan.shutdown":
                    shutdown_called = True
                    await send({"type": "lifespan.shutdown.complete"})
                    return

        manager = LifespanManager(app, mock.Mock())

        await manager.startup()
        assert startup_called
        assert manager._startup_complete.is_set()
        assert not manager._startup_failed

        await manager.shutdown()
        assert shutdown_called

    @pytest.mark.asyncio
    async def test_lifespan_startup_failed(self):
        """Test lifespan startup failure."""
        from gunicorn.asgi.lifespan import LifespanManager

        async def app(scope, receive, send):
            message = await receive()
            if message["type"] == "lifespan.startup":
                await send({
                    "type": "lifespan.startup.failed",
                    "message": "Database connection failed"
                })

        manager = LifespanManager(app, mock.Mock())

        with pytest.raises(RuntimeError, match="Database connection failed"):
            await manager.startup()

    @pytest.mark.asyncio
    async def test_lifespan_state_shared(self):
        """Test that lifespan state is shared with app."""
        from gunicorn.asgi.lifespan import LifespanManager

        state = {}

        async def app(scope, receive, send):
            assert "state" in scope
            scope["state"]["db"] = "connected"

            # The event payloads are irrelevant here; only the ordering
            # (startup, then shutdown) matters.
            await receive()
            await send({"type": "lifespan.startup.complete"})

            await receive()
            await send({"type": "lifespan.shutdown.complete"})

        manager = LifespanManager(app, mock.Mock(), state)

        await manager.startup()
        assert state.get("db") == "connected"

        await manager.shutdown()


# ============================================================================
# Tests for WebSocket
# Protocol
# ============================================================================

class TestWebSocketProtocol:
    """Checks for the WebSocket constants and the frame unmasking helper."""

    @staticmethod
    def _new_protocol():
        """Build a bare ``WebSocketProtocol`` with placeholder collaborators."""
        from gunicorn.asgi.websocket import WebSocketProtocol

        return WebSocketProtocol(None, None, {}, None, mock.Mock())

    def test_websocket_guid(self):
        """The handshake GUID matches the value fixed by RFC 6455."""
        from gunicorn.asgi.websocket import WS_GUID

        rfc_guid = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
        assert WS_GUID == rfc_guid

    def test_websocket_opcodes(self):
        """Each opcode constant carries its wire value."""
        from gunicorn.asgi import websocket

        wire_values = {
            "OPCODE_TEXT": 0x1,
            "OPCODE_BINARY": 0x2,
            "OPCODE_CLOSE": 0x8,
            "OPCODE_PING": 0x9,
            "OPCODE_PONG": 0xA,
        }
        for constant, value in wire_values.items():
            assert getattr(websocket, constant) == value

    def test_websocket_accept_key_calculation(self):
        """The accept key derived from WS_GUID matches the RFC 6455 example."""
        import base64
        import hashlib
        from gunicorn.asgi.websocket import WS_GUID

        # Sample handshake values taken from RFC 6455.
        client_key = b"dGhlIHNhbXBsZSBub25jZQ=="
        expected_accept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="

        digest = hashlib.sha1(client_key + WS_GUID).digest()
        accept_key = base64.b64encode(digest).decode("ascii")

        assert accept_key == expected_accept

    def test_websocket_frame_masking(self):
        """Unmasking XORs the payload with the 4-byte key."""
        protocol = self._new_protocol()

        key = bytes.fromhex("37fa213d")
        masked_hello = bytes.fromhex("7f9f4d5158")  # "Hello" after masking

        assert protocol._unmask(masked_hello, key) == b"Hello"

    def test_websocket_frame_masking_empty(self):
        """Unmasking an empty payload yields an empty payload."""
        protocol = self._new_protocol()

        key = bytes.fromhex("37fa213d")

        assert protocol._unmask(b"", key) == b""


# ============================================================================
# Integration Tests
# ============================================================================

class TestASGIIntegration:
    """Integration tests that start actual servers."""

    @pytest.fixture
    def free_port(self):
        """Get a free port for testing.

        The probing socket is closed before the port is returned, so another
        process could in principle grab the port before the test binds it.
        """
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(('127.0.0.1', 0))
            return s.getsockname()[1]

    @pytest.mark.asyncio
    async def test_http_request_response(self, free_port):
        """Test basic HTTP request/response cycle."""

        # Simple ASGI app
        async def app(scope, receive, send):
            if scope["type"] == "http":
                await send({
                    "type": "http.response.start",
                    "status": 200,
                    "headers": [(b"content-type", b"text/plain")],
                })
                await send({
                    "type": "http.response.body",
                    "body": b"Hello, World!",
                })

        # Start server. We are inside a coroutine, so ask for the running
        # loop rather than relying on the get_event_loop() policy lookup.
        loop = asyncio.get_running_loop()
        server = await loop.create_server(
            lambda: _TestProtocol(app),
            '127.0.0.1',
            free_port,
        )

        try:
            # Use asyncio to make HTTP request
            reader, writer = await asyncio.open_connection('127.0.0.1', free_port)
            try:
                request = f"GET / HTTP/1.1\r\nHost: 127.0.0.1:{free_port}\r\n\r\n"
                writer.write(request.encode())
                await writer.drain()

                # The server writes the headers and the body separately and
                # then closes the connection. Read until EOF so the body is
                # never missed; the timeout keeps a broken server from
                # hanging the test run.
                response = await asyncio.wait_for(reader.read(), timeout=5)
                response_text = response.decode()

                assert "HTTP/1.1 200" in response_text
                assert "Hello, World!" in response_text
            finally:
                # Close the client side even when an assertion fails.
                writer.close()
                await writer.wait_closed()
        finally:
            server.close()
            await server.wait_closed()


class _TestProtocol(asyncio.Protocol):
    """Minimal protocol for integration testing.

    Treats each ``data_received`` chunk as one complete HTTP request, runs
    the ASGI ``app`` for it and closes the connection once the response body
    is finished.
    """

    def __init__(self, app):
        self.app = app
        self.transport = None
        # Strong references to in-flight handler tasks: the event loop only
        # keeps weak references, so an unreferenced task may be collected
        # before it finishes.
        self._tasks = set()

    def connection_made(self, transport):
        self.transport = transport

    def data_received(self, data):
        # Very simple HTTP parsing for testing
        task = asyncio.create_task(self._handle(data))
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)

    async def _handle(self, data):
        """Parse the request line, build an HTTP scope and invoke the app."""
        # Parse basic HTTP request
        lines = data.decode().split('\r\n')
        method, path, _ = lines[0].split(' ')

        scope = {
            "type": "http",
            "asgi": {"version": "3.0"},
            "http_version": "1.1",
            "method": method,
            "path": path,
            "query_string": b"",
            "headers": [],
            "server": ("127.0.0.1", 8000),
            "client": ("127.0.0.1", 12345),
        }

        async def receive():
            return {"type": "http.request", "body": b"", "more_body": False}

        async def send(message):
            if message["type"] == "http.response.start":
                status = message["status"]
                headers = message.get("headers", [])
                response = f"HTTP/1.1 {status} OK\r\n"
                for name, value in headers:
                    if isinstance(name, bytes):
                        name = name.decode()
                    if isinstance(value, bytes):
                        value = value.decode()
                    response += f"{name}: {value}\r\n"
                response += "\r\n"
                self.transport.write(response.encode())
            elif message["type"] == "http.response.body":
                body = message.get("body", b"")
                self.transport.write(body)
                if not message.get("more_body", False):
                    self.transport.close()

        try:
            await self.app(scope, receive, send)
        finally:
            # Never leave the client waiting if the app raised or returned
            # without sending a final body chunk.
            if not self.transport.is_closing():
                self.transport.close()


def _make_asgi_protocol(**settings):
    """Return an ``ASGIProtocol`` bound to a mocked worker.

    The worker gets a fresh ``Config`` (with ``settings`` applied through
    ``cfg.set``) and ``Mock`` objects for ``log`` and ``asgi``.
    """
    from gunicorn.asgi.protocol import ASGIProtocol

    worker = mock.Mock()
    worker.cfg = Config()
    for key, value in settings.items():
        worker.cfg.set(key, value)
    worker.log = mock.Mock()
    worker.asgi = mock.Mock()
    return ASGIProtocol(worker)


# ============================================================================
# ASGI Protocol Tests
# ============================================================================

class TestASGIProtocol:
    """Tests for ASGIProtocol."""

    def test_reason_phrases(self):
        """Test HTTP reason phrase lookup."""
        protocol = _make_asgi_protocol()

        assert protocol._get_reason_phrase(200) == "OK"
        assert protocol._get_reason_phrase(404) == "Not Found"
        assert protocol._get_reason_phrase(500) == "Internal Server Error"
        assert protocol._get_reason_phrase(999) == "Unknown"

    def test_scope_building(self):
        """Test HTTP scope building."""
        # NOTE(review): AsyncRequest is not used below; the import is kept so
        # the test still fails if the module stops importing - confirm that
        # this is intentional.
        from gunicorn.asgi.message import AsyncRequest  # noqa: F401

        protocol = _make_asgi_protocol(root_path='/api')

        # Create mock request
        request = mock.Mock()
        request.method = "GET"
        request.path = "/users"
        request.query = "page=1"
        request.version = (1, 1)
        request.scheme = "http"
        request.headers = [("HOST", "localhost"), ("ACCEPT", "text/html")]

        scope = protocol._build_http_scope(
            request,
            ("127.0.0.1", 8000),  # sockname
            ("127.0.0.1", 12345),  # peername
        )

        assert scope["type"] == "http"
        assert scope["method"] == "GET"
        assert scope["path"] == "/users"
        assert scope["query_string"] == b"page=1"
        assert scope["root_path"] == "/api"
        assert scope["http_version"] == "1.1"


# ============================================================================
# Config Tests
# ============================================================================

class TestASGIConfig:
    """Tests for ASGI configuration options."""

    def test_asgi_loop_default(self):
        """Test default asgi_loop value."""
        cfg = Config()
        assert cfg.asgi_loop == "auto"

    def test_asgi_loop_validation(self):
        """Test asgi_loop validation."""
        cfg = Config()

        cfg.set('asgi_loop', 'asyncio')
        assert cfg.asgi_loop == 'asyncio'

        cfg.set('asgi_loop', 'uvloop')
        assert cfg.asgi_loop == 'uvloop'

        with pytest.raises(ValueError):
            cfg.set('asgi_loop', 'invalid')

    def test_asgi_lifespan_default(self):
        """Test default asgi_lifespan value."""
        cfg = Config()
        assert cfg.asgi_lifespan == "auto"

    def test_asgi_lifespan_validation(self):
        """Test asgi_lifespan validation."""
        cfg = Config()

        cfg.set('asgi_lifespan', 'on')
        assert cfg.asgi_lifespan == 'on'

        cfg.set('asgi_lifespan', 'off')
        assert cfg.asgi_lifespan == 'off'

        with pytest.raises(ValueError):
            cfg.set('asgi_lifespan', 'invalid')

    def test_root_path_default(self):
        """Test default root_path value."""
        cfg = Config()
        assert cfg.root_path == ""

    def test_root_path_setting(self):
        """Test root_path configuration."""
        cfg = Config()
        cfg.set('root_path', '/api/v1')
        assert cfg.root_path == '/api/v1'


# ============================================================================
# HTTP/2 Priority Tests
# ============================================================================

class TestASGIHTTP2Priority:
    """Test HTTP/2 priority in ASGI scope."""

    def test_http2_priority_in_scope(self):
        """Test that HTTP/2 priority is added to ASGI scope extensions."""
        protocol = _make_asgi_protocol()

        # Create mock HTTP/2 request with priority
        request = mock.Mock()
        request.method = "GET"
        request.path = "/test"
        request.query = ""
        request.version = (2, 0)
        request.scheme = "https"
        request.headers = [("HOST", "localhost")]
        request.priority_weight = 128
        request.priority_depends_on = 3

        scope = protocol._build_http_scope(
            request,
            ("127.0.0.1", 8443),
            ("127.0.0.1", 12345),
        )

        assert "extensions" in scope
        assert "http.response.priority" in scope["extensions"]
        priority = scope["extensions"]["http.response.priority"]
        assert priority["weight"] == 128
        assert priority["depends_on"] == 3

    def test_http2_priority_in_http2_scope(self):
        """Test that HTTP/2 priority is in _build_http2_scope."""
        protocol = _make_asgi_protocol()

        # Create mock HTTP/2 request with priority
        request = mock.Mock()
        request.method = "POST"
        request.path = "/api/data"
        request.query = "id=1"
        request.uri = "/api/data?id=1"
        request.scheme = "https"
        request.headers = [("HOST", "localhost"), ("CONTENT-TYPE", "application/json")]
        request.priority_weight = 256
        request.priority_depends_on = 1

        scope = protocol._build_http2_scope(
            request,
            ("127.0.0.1", 8443),
            ("127.0.0.1", 12345),
        )

        assert scope["http_version"] == "2"
        assert "extensions" in scope
        assert "http.response.priority" in scope["extensions"]
        priority = scope["extensions"]["http.response.priority"]
        assert priority["weight"] == 256
        assert priority["depends_on"] == 1

    def test_no_priority_for_http1_requests(self):
        """Test that HTTP/1.1 requests don't have priority extensions."""
        protocol = _make_asgi_protocol()

        # Create mock HTTP/1.1 request (no priority attributes): ``spec``
        # makes any attribute outside this list raise AttributeError.
        request = mock.Mock(spec=['method', 'path', 'query', 'version', 'scheme', 'headers'])
        request.method = "GET"
        request.path = "/test"
        request.query = ""
        request.version = (1, 1)
        request.scheme = "http"
        request.headers = [("HOST", "localhost")]

        scope = protocol._build_http_scope(
            request,
            ("127.0.0.1", 8000),
            ("127.0.0.1", 12345),
        )

        # HTTP/1.1 requests should not have extensions with priority
        assert "http.response.priority" not in scope.get("extensions", {})


# ============================================================================
# HTTP/2 Trailers Tests
# ============================================================================

class TestASGIHTTP2Trailers:
    """Test HTTP/2 response trailer support in ASGI."""

    def test_http2_trailers_extension_in_scope(self):
        """Test that HTTP/2 scope includes http.response.trailers extension."""
        protocol = _make_asgi_protocol()

        # Create mock HTTP/2 request
        request = mock.Mock()
        request.method = "GET"
        request.path = "/api"
        request.query = ""
        request.uri = "/api"
        request.scheme = "https"
        request.headers = [("HOST", "localhost")]
        request.priority_weight = 16
        request.priority_depends_on = 0

        scope = protocol._build_http2_scope(
            request,
            ("127.0.0.1", 8443),
            ("127.0.0.1", 12345),
        )

        # HTTP/2 scope should have trailers extension
        assert "extensions" in scope
        assert "http.response.trailers" in scope["extensions"]

    def test_http2_scope_has_both_priority_and_trailers(self):
        """Test that HTTP/2 scope includes both priority and trailers extensions."""
        protocol = _make_asgi_protocol()

        request = mock.Mock()
        request.method = "POST"
        request.path = "/grpc"
        request.query = ""
        request.uri = "/grpc"
        request.scheme = "https"
        request.headers = [("HOST", "localhost"), ("CONTENT-TYPE", "application/grpc")]
        request.priority_weight = 128
        request.priority_depends_on = 1

        scope = protocol._build_http2_scope(
            request,
            ("127.0.0.1", 8443),
            ("127.0.0.1", 54321),
        )

        extensions = scope.get("extensions", {})
        assert "http.response.priority" in extensions
        assert "http.response.trailers" in extensions
        assert extensions["http.response.priority"]["weight"] == 128