#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.

"""
ASGI 3.0 specification compliance tests.

Tests that gunicorn's ASGI implementation conforms to the ASGI 3.0 spec:
https://asgi.readthedocs.io/en/latest/specs/main.html
"""

from unittest import mock

import pytest

from gunicorn.config import Config


# ============================================================================
# Shared helpers for HTTP scope tests
# ============================================================================

class _HTTPScopeTestHelpers:
    """Factory helpers shared by the HTTP scope test classes below.

    The leading underscore keeps pytest from collecting this class; test
    classes inherit from it so the helpers are defined exactly once.
    """

    def _create_protocol(self):
        """Create an ASGIProtocol backed by a mock worker with a default Config."""
        from gunicorn.asgi.protocol import ASGIProtocol

        worker = mock.Mock()
        worker.cfg = Config()
        worker.log = mock.Mock()
        worker.asgi = mock.Mock()
        return ASGIProtocol(worker)

    def _create_mock_request(self, **kwargs):
        """Create a mock HTTP request.

        Keyword arguments override the defaults: ``method`` ("GET"),
        ``path`` ("/"), ``raw_path`` (``path`` encoded as latin-1, or
        ``b""`` when ``path`` is empty), ``query`` (""), ``version``
        ((1, 1)), ``scheme`` ("http") and ``headers`` ([]).
        """
        request = mock.Mock()
        request.method = kwargs.get("method", "GET")
        path = kwargs.get("path", "/")
        request.path = path
        request.raw_path = kwargs.get(
            "raw_path", path.encode("latin-1") if path else b""
        )
        request.query = kwargs.get("query", "")
        request.version = kwargs.get("version", (1, 1))
        request.scheme = kwargs.get("scheme", "http")
        request.headers = kwargs.get("headers", [])
        return request


# ============================================================================
# ASGI Version Tests
# ============================================================================

class TestASGIVersion(_HTTPScopeTestHelpers):
    """Test ASGI version information in scope."""

    def _loopback_scope(self):
        """Build the HTTP scope for a default GET request over loopback."""
        protocol = self._create_protocol()
        request = self._create_mock_request()
        return protocol._build_http_scope(
            request,
            ("127.0.0.1", 8000),
            ("127.0.0.1", 12345),
        )

    def test_asgi_version_present(self):
        """Test that 'asgi' key is present in HTTP scope."""
        scope = self._loopback_scope()
        assert "asgi" in scope

    def test_asgi_version_is_dict(self):
        """Test that 'asgi' value is a dictionary."""
        scope = self._loopback_scope()
        assert isinstance(scope["asgi"], dict)

    def test_asgi_version_value(self):
        """Test that ASGI version is '3.0'."""
        scope = self._loopback_scope()
        assert scope["asgi"]["version"] == "3.0"

    def test_asgi_spec_version_present(self):
        """Test that spec_version is present in ASGI dict."""
        scope = self._loopback_scope()
        assert "spec_version" in scope["asgi"]

    def test_asgi_spec_version_value(self):
        """Test that spec_version follows semantic versioning."""
        scope = self._loopback_scope()
        spec_version = scope["asgi"]["spec_version"]
        # Should be in format "X.Y" (major.minor)
        parts = spec_version.split(".")
        assert len(parts) == 2
        assert all(part.isdigit() for part in parts)


# ============================================================================
# HTTP Scope Keys Tests (ASGI HTTP Connection Scope)
# ============================================================================

class TestHTTPScopeKeys(_HTTPScopeTestHelpers):
    """Test required keys in HTTP connection scope per ASGI spec."""

    def test_type_key_present(self):
        """Test 'type' key is present and equals 'http'."""
        protocol = self._create_protocol()
        request = self._create_mock_request()
        scope = protocol._build_http_scope(
            request,
            ("127.0.0.1", 8000),
            ("127.0.0.1", 12345),
        )
        assert scope["type"] == "http"

    def test_http_version_key_present(self):
        """Test 'http_version' key is present."""
        protocol = self._create_protocol()
        request = self._create_mock_request()
        scope = protocol._build_http_scope(
            request,
            ("127.0.0.1", 8000),
            ("127.0.0.1", 12345),
        )
        assert "http_version" in scope
        assert scope["http_version"] == "1.1"

    def test_http_version_formats(self):
        """Test various HTTP version formats."""
        protocol = self._create_protocol()

        # HTTP/1.0
        request_10 = self._create_mock_request(version=(1, 0))
        scope_10 = protocol._build_http_scope(request_10, None, None)
        assert scope_10["http_version"] == "1.0"

        # HTTP/1.1
        request_11 = self._create_mock_request(version=(1, 1))
        scope_11 = protocol._build_http_scope(request_11, None, None)
        assert scope_11["http_version"] == "1.1"

    def test_method_key_present(self):
        """Test 'method' key is present and is uppercase string."""
        protocol = self._create_protocol()
        for method in ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"]:
            request = self._create_mock_request(method=method)
            scope = protocol._build_http_scope(request, None, None)
            assert scope["method"] == method
            assert scope["method"].isupper()

    def test_scheme_key_present(self):
        """Test 'scheme' key is present."""
        protocol = self._create_protocol()

        # HTTP
        request_http = self._create_mock_request(scheme="http")
        scope_http = protocol._build_http_scope(request_http, None, None)
        assert scope_http["scheme"] == "http"

        # HTTPS
        request_https = self._create_mock_request(scheme="https")
        scope_https = protocol._build_http_scope(request_https, None, None)
        assert scope_https["scheme"] == "https"

    def test_path_key_present(self):
        """Test 'path' key is present and starts with /."""
        protocol = self._create_protocol()
        request = self._create_mock_request(path="/api/users")
        scope = protocol._build_http_scope(request, None, None)
        assert "path" in scope
        assert scope["path"] == "/api/users"
        assert scope["path"].startswith("/")

    def test_raw_path_key_present(self):
        """Test 'raw_path' key is present and is bytes."""
        protocol = self._create_protocol()
        request = self._create_mock_request(path="/api/users")
        scope = protocol._build_http_scope(request, None, None)
        assert "raw_path" in scope
        assert isinstance(scope["raw_path"], bytes)
        assert scope["raw_path"] == b"/api/users"

    def test_query_string_key_present(self):
        """Test 'query_string' key is present and is bytes."""
        protocol = self._create_protocol()
        request = self._create_mock_request(query="page=1&limit=10")
        scope = protocol._build_http_scope(request, None, None)
        assert "query_string" in scope
        assert isinstance(scope["query_string"], bytes)
        assert scope["query_string"] == b"page=1&limit=10"

    def test_query_string_empty(self):
        """Test 'query_string' is empty bytes when no query."""
        protocol = self._create_protocol()
        request = self._create_mock_request(query="")
        scope = protocol._build_http_scope(request, None, None)
        assert scope["query_string"] == b""

    def test_root_path_key_present(self):
        """Test 'root_path' key is present."""
        protocol = self._create_protocol()
        request = self._create_mock_request()
        scope = protocol._build_http_scope(request, None, None)
        assert "root_path" in scope
        assert isinstance(scope["root_path"], str)

    def test_headers_key_present(self):
        """Test 'headers' key is present and is list of 2-tuples."""
        protocol = self._create_protocol()
        request = self._create_mock_request(
            headers=[("HOST", "localhost"), ("ACCEPT", "text/html")]
        )
        scope = protocol._build_http_scope(request, None, None)
        assert "headers" in scope
        assert isinstance(scope["headers"], list)
        for header in scope["headers"]:
            assert isinstance(header, tuple)
            assert len(header) == 2

    def test_headers_are_bytes(self):
        """Test that header names and values are bytes."""
        protocol = self._create_protocol()
        request = self._create_mock_request(
            headers=[("HOST", "localhost"), ("CONTENT-TYPE", "application/json")]
        )
        scope = protocol._build_http_scope(request, None, None)
        for name, value in scope["headers"]:
            assert isinstance(name, bytes), f"Header name should be bytes: {name}"
            assert isinstance(value, bytes), f"Header value should be bytes: {value}"

    def test_headers_names_lowercase(self):
        """Test that header names are lowercase."""
        protocol = self._create_protocol()
        request = self._create_mock_request(
            headers=[("HOST", "localhost"), ("Content-Type", "application/json")]
        )
        scope = protocol._build_http_scope(request, None, None)
        for name, _ in scope["headers"]:
            assert name == name.lower(), f"Header name should be lowercase: {name}"

    def test_server_key_present(self):
        """Test 'server' key is present when sockname provided."""
        protocol = self._create_protocol()
        request = self._create_mock_request()
        scope = protocol._build_http_scope(
            request,
            ("127.0.0.1", 8000),
            ("127.0.0.1", 12345),
        )
        assert "server" in scope
        assert scope["server"] == ("127.0.0.1", 8000)

    def test_server_key_none(self):
        """Test 'server' key is None when sockname not provided."""
        protocol = self._create_protocol()
        request = self._create_mock_request()
        scope = protocol._build_http_scope(request, None, None)
        assert scope["server"] is None

    def test_client_key_present(self):
        """Test 'client' key is present when peername provided."""
        protocol = self._create_protocol()
        request = self._create_mock_request()
        scope = protocol._build_http_scope(
            request,
            ("127.0.0.1", 8000),
            ("192.168.1.100", 54321),
        )
        assert "client" in scope
        assert scope["client"] == ("192.168.1.100", 54321)

    def test_client_key_none(self):
        """Test 'client' key is None when peername not provided."""
        protocol = self._create_protocol()
        request = self._create_mock_request()
        scope = protocol._build_http_scope(request, None, None)
        assert scope["client"] is None
# ============================================================================
# HTTP Message Format Tests
# ============================================================================

class TestHTTPMessageFormats:
    """Check the shape of the HTTP event dicts defined by the ASGI spec.

    Each test builds a message locally and validates its field types and
    values; these tests document the expected layout rather than exercising
    server code.
    """

    def test_http_request_message_format(self):
        """An http.request event carries bytes ``body`` and bool ``more_body``."""
        event = dict(type="http.request", body=b"request body", more_body=False)

        assert event["type"] == "http.request"
        assert isinstance(event["body"], bytes)
        assert isinstance(event["more_body"], bool)

    def test_http_request_message_empty_body(self):
        """An http.request event may carry an empty, final body."""
        event = dict(type="http.request", body=b"", more_body=False)

        assert event["body"] == b""
        assert event["more_body"] is False

    def test_http_response_start_format(self):
        """An http.response.start event has an int status and a header list."""
        response_headers = [
            (b"content-type", b"text/plain"),
            (b"content-length", b"13"),
        ]
        event = dict(type="http.response.start", status=200, headers=response_headers)

        assert event["type"] == "http.response.start"
        status = event["status"]
        assert isinstance(status, int)
        assert status in range(100, 600)
        assert isinstance(event["headers"], list)

    def test_http_response_body_format(self):
        """An http.response.body event carries bytes ``body`` and bool ``more_body``."""
        event = dict(type="http.response.body", body=b"Hello, World!", more_body=False)

        assert event["type"] == "http.response.body"
        assert isinstance(event["body"], bytes)
        assert isinstance(event["more_body"], bool)

    def test_http_response_body_streaming(self):
        """A streamed body sets ``more_body`` on every chunk except the last."""
        first = dict(type="http.response.body", body=b"First chunk", more_body=True)
        last = dict(type="http.response.body", body=b"Last chunk", more_body=False)

        assert first["more_body"] is True
        assert last["more_body"] is False

    def test_http_disconnect_format(self):
        """An http.disconnect event consists of its type alone."""
        event = dict(type="http.disconnect")

        assert event["type"] == "http.disconnect"
# ============================================================================
# HTTP Response Status Codes Tests
# ============================================================================

class TestHTTPStatusCodes:
    """Check the reason phrases ASGIProtocol reports for HTTP status codes."""

    def _create_protocol(self):
        """Build an ASGIProtocol around a mock worker that uses a default Config."""
        from gunicorn.asgi.protocol import ASGIProtocol

        fake_worker = mock.Mock()
        fake_worker.cfg = Config()
        fake_worker.log = mock.Mock()
        fake_worker.asgi = mock.Mock()
        return ASGIProtocol(fake_worker)

    def _assert_phrases(self, expected):
        """Assert that ``_get_reason_phrase`` maps each code in *expected* to its phrase.

        A single protocol instance serves every lookup; the codes are
        checked in the dict's insertion order.
        """
        protocol = self._create_protocol()
        for status, phrase in expected.items():
            assert protocol._get_reason_phrase(status) == phrase

    def test_reason_phrase_informational(self):
        """1xx codes map to their standard phrases."""
        self._assert_phrases({
            100: "Continue",
            101: "Switching Protocols",
            103: "Early Hints",
        })

    def test_reason_phrase_success(self):
        """2xx codes map to their standard phrases."""
        self._assert_phrases({
            200: "OK",
            201: "Created",
            202: "Accepted",
            204: "No Content",
            206: "Partial Content",
        })

    def test_reason_phrase_redirect(self):
        """3xx codes map to their standard phrases."""
        self._assert_phrases({
            301: "Moved Permanently",
            302: "Found",
            303: "See Other",
            304: "Not Modified",
            307: "Temporary Redirect",
            308: "Permanent Redirect",
        })

    def test_reason_phrase_client_error(self):
        """4xx codes map to their standard phrases."""
        self._assert_phrases({
            400: "Bad Request",
            401: "Unauthorized",
            403: "Forbidden",
            404: "Not Found",
            405: "Method Not Allowed",
            408: "Request Timeout",
            409: "Conflict",
            410: "Gone",
            422: "Unprocessable Entity",
            429: "Too Many Requests",
        })

    def test_reason_phrase_server_error(self):
        """5xx codes map to their standard phrases."""
        self._assert_phrases({
            500: "Internal Server Error",
            501: "Not Implemented",
            502: "Bad Gateway",
            503: "Service Unavailable",
            504: "Gateway Timeout",
        })

    def test_reason_phrase_unknown(self):
        """Status codes without a known phrase (e.g. 999, 418) map to "Unknown"."""
        self._assert_phrases({
            999: "Unknown",
            418: "Unknown",  # I'm a teapot is not defined
        })
# ============================================================================
# Informational Response Tests (103 Early Hints, etc.)
# ============================================================================

class TestInformationalResponses:
    """Test support for HTTP 1xx informational responses."""

    # NOTE(review): the Link header values below have no "<uri>" target
    # before "; rel=preload" -- it looks like it may have been stripped;
    # confirm against the intended fixtures.

    def test_http_response_informational_format(self):
        """Test http.response.informational message format."""
        message = {
            "type": "http.response.informational",
            "status": 103,
            "headers": [
                (b"link", b"; rel=preload; as=style"),
            ],
        }
        assert message["type"] == "http.response.informational"
        assert 100 <= message["status"] < 200
        assert isinstance(message["headers"], list)

    def test_early_hints_103(self):
        """Test 103 Early Hints message format."""
        message = {
            "type": "http.response.informational",
            "status": 103,
            "headers": [
                (b"link", b"; rel=preload; as=style"),
                (b"link", b"; rel=preload; as=script"),
            ],
        }
        assert message["status"] == 103


# ============================================================================
# ASGI Extensions Tests
# ============================================================================

class TestASGIExtensions:
    """Test ASGI extensions support."""

    def _create_protocol(self):
        """Create an ASGIProtocol instance for testing."""
        from gunicorn.asgi.protocol import ASGIProtocol

        worker = mock.Mock()
        worker.cfg = Config()
        worker.log = mock.Mock()
        worker.asgi = mock.Mock()
        return ASGIProtocol(worker)

    def _create_mock_http2_request(self, **kwargs):
        """Create a mock HTTP/2 request with priority fields.

        Keyword arguments override the defaults, including
        ``priority_weight`` (16) and ``priority_depends_on`` (0).
        """
        request = mock.Mock()
        request.method = kwargs.get("method", "GET")
        request.path = kwargs.get("path", "/")
        request.query = kwargs.get("query", "")
        request.uri = kwargs.get("uri", "/")
        request.scheme = kwargs.get("scheme", "https")
        request.headers = kwargs.get("headers", [])
        request.priority_weight = kwargs.get("priority_weight", 16)
        request.priority_depends_on = kwargs.get("priority_depends_on", 0)
        return request

    def test_http2_scope_has_extensions(self):
        """Test that HTTP/2 scope includes extensions dict."""
        protocol = self._create_protocol()
        request = self._create_mock_http2_request()
        scope = protocol._build_http2_scope(request, None, None)
        assert "extensions" in scope
        assert isinstance(scope["extensions"], dict)

    def test_http2_priority_extension(self):
        """Test http.response.priority extension in HTTP/2 scope."""
        protocol = self._create_protocol()
        request = self._create_mock_http2_request(
            priority_weight=128,
            priority_depends_on=5,
        )
        scope = protocol._build_http2_scope(request, None, None)
        assert "http.response.priority" in scope["extensions"]
        priority = scope["extensions"]["http.response.priority"]
        assert "weight" in priority
        assert "depends_on" in priority
        assert priority["weight"] == 128
        assert priority["depends_on"] == 5

    def test_http2_trailers_extension(self):
        """Test http.response.trailers extension in HTTP/2 scope."""
        protocol = self._create_protocol()
        request = self._create_mock_http2_request()
        scope = protocol._build_http2_scope(request, None, None)
        assert "http.response.trailers" in scope["extensions"]

    def test_http_response_trailers_message_format(self):
        """Test http.response.trailers message format."""
        message = {
            "type": "http.response.trailers",
            "headers": [
                (b"grpc-status", b"0"),
                (b"grpc-message", b""),
            ],
            "more_trailers": False,
        }
        assert message["type"] == "http.response.trailers"
        assert isinstance(message["headers"], list)


# ============================================================================
# State Sharing Tests
# ============================================================================

class TestStateSharing:
    """Test state sharing between lifespan and request scopes."""

    def _create_protocol_with_state(self, state):
        """Create an ASGIProtocol whose worker exposes *state*."""
        from gunicorn.asgi.protocol import ASGIProtocol

        worker = mock.Mock()
        worker.cfg = Config()
        worker.log = mock.Mock()
        worker.asgi = mock.Mock()
        worker.state = state
        return ASGIProtocol(worker)

    def _create_mock_request(self):
        """Create a mock HTTP/1.1 GET request for "/" with no headers."""
        request = mock.Mock()
        request.method = "GET"
        request.path = "/"
        request.raw_path = b"/"
        request.query = ""
        request.version = (1, 1)
        request.scheme = "http"
        request.headers = []
        return request

    def test_state_in_http_scope(self):
        """Test that state dict is included in HTTP scope."""
        state = {"db": "connected", "cache": "ready"}
        protocol = self._create_protocol_with_state(state)
        request = self._create_mock_request()
        scope = protocol._build_http_scope(request, None, None)
        assert "state" in scope
        assert scope["state"] == state

    def test_state_is_same_object(self):
        """Test that state is the same object (not a copy)."""
        state = {"counter": 0}
        protocol = self._create_protocol_with_state(state)
        request = self._create_mock_request()
        scope = protocol._build_http_scope(request, None, None)
        # Modifying scope["state"] should modify the original
        scope["state"]["counter"] = 1
        assert state["counter"] == 1

    def test_state_not_present_without_worker_state(self):
        """Test that state is not in scope if worker has no state."""
        from gunicorn.asgi.protocol import ASGIProtocol

        # spec= restricts the mock so that worker.state does not exist.
        worker = mock.Mock(spec=["cfg", "log", "asgi"])
        worker.cfg = Config()
        worker.log = mock.Mock()
        worker.asgi = mock.Mock()
        protocol = ASGIProtocol(worker)
        request = self._create_mock_request()
        scope = protocol._build_http_scope(request, None, None)
        assert "state" not in scope


def _make_body_receiver(protocol, content_length):
    """Return a BodyReceiver for a mock non-chunked request of *content_length* bytes."""
    from gunicorn.asgi.protocol import BodyReceiver

    request = mock.Mock()
    request.content_length = content_length
    request.chunked = False
    return BodyReceiver(request, protocol)


# ============================================================================
# HTTP Disconnect Event Tests (ASGI Spec Compliance)
# https://asgi.readthedocs.io/en/latest/specs/www.html#disconnect-receive-event
# ============================================================================

class TestHTTPDisconnectEvent:
    """Test http.disconnect event compliance with ASGI spec.

    Per the ASGI HTTP Connection Scope spec:

    - Disconnect event is sent when client closes connection
    - Event type MUST be "http.disconnect"
    - Apps should receive this event and clean up gracefully
    """

    def _create_protocol(self):
        """Create an ASGIProtocol with a mock worker, event loop and reader."""
        from gunicorn.asgi.protocol import ASGIProtocol

        worker = mock.Mock()
        worker.cfg = Config()
        worker.log = mock.Mock()
        worker.asgi = mock.Mock()
        worker.nr_conns = 1
        worker.loop = mock.Mock()
        protocol = ASGIProtocol(worker)
        protocol.reader = mock.Mock()
        return protocol

    def test_disconnect_event_type(self):
        """Test that connection_lost() closes the active body receiver.

        Closing the receiver is how the disconnect is signalled to the app,
        as the ASGI spec requires.
        """
        protocol = self._create_protocol()
        body_receiver = _make_body_receiver(protocol, content_length=100)
        protocol._body_receiver = body_receiver

        # Simulate client disconnect
        protocol.connection_lost(None)

        # Per ASGI spec: disconnect should be signaled
        assert body_receiver._closed

    def test_disconnect_event_sent_on_connection_lost(self):
        """Test that disconnect is signaled when connection is lost."""
        protocol = self._create_protocol()
        body_receiver = _make_body_receiver(protocol, content_length=100)
        protocol._body_receiver = body_receiver

        assert not body_receiver._closed

        # Simulate client disconnect
        protocol.connection_lost(None)

        # Disconnect should have been signaled
        assert body_receiver._closed

    def test_disconnect_sets_closed_flag(self):
        """Test that connection_lost sets the closed flag."""
        protocol = self._create_protocol()
        assert protocol._closed is False

        protocol.connection_lost(None)

        assert protocol._closed is True

    def test_disconnect_allows_graceful_cleanup(self):
        """Test that disconnect doesn't immediately cancel task.

        Per ASGI spec, apps should have opportunity to clean up when
        they receive http.disconnect.
        """
        protocol = self._create_protocol()

        # Create a mock task
        mock_task = mock.Mock()
        mock_task.done.return_value = False
        protocol._task = mock_task

        # Simulate disconnect
        protocol.connection_lost(None)

        # Task should NOT be cancelled immediately
        mock_task.cancel.assert_not_called()
        # Cancellation should be scheduled after grace period
        protocol.worker.loop.call_later.assert_called_once()

    @pytest.mark.asyncio
    async def test_disconnect_message_format(self):
        """Test http.disconnect message format per ASGI spec.

        When body is complete and disconnect is signaled, receive()
        should return {"type": "http.disconnect"}.
        """
        protocol = self._create_protocol()
        # A request with no body
        body_receiver = _make_body_receiver(protocol, content_length=0)
        protocol._body_receiver = body_receiver

        # Get initial body message (empty body)
        msg1 = await body_receiver.receive()
        assert msg1["type"] == "http.request"
        assert msg1["more_body"] is False

        # Signal disconnect (simulating connection_lost); after the body is
        # finished, receive() waits for this signal.
        body_receiver.signal_disconnect()

        # Now receive should return disconnect
        msg2 = await body_receiver.receive()

        # Per ASGI spec, disconnect message only has 'type'
        assert msg2 == {"type": "http.disconnect"}
        assert len(msg2) == 1


# ============================================================================
# BodyReceiver Disconnect Regression Tests
# https://github.com/benoitc/gunicorn/issues/3484
# ============================================================================

class TestBodyReceiverDisconnect:
    """Regression tests for BodyReceiver._wait_for_disconnect() behavior.

    The original bug: BodyReceiver.receive() immediately returned
    `http.disconnect` when `_body_finished` was True, but Django (and
    other ASGI frameworks) call `receive()` to listen for client
    disconnect AFTER the response is sent. This caused Django's
    `listen_for_disconnect` task to think the client disconnected before
    the response could be sent.

    The fix: After body is finished, receive() now calls
    _wait_for_disconnect() which blocks until signal_disconnect() is
    called or the waiter is cancelled.
    """

    def _create_protocol(self):
        """Create an open ASGIProtocol with a mock worker and event loop."""
        from gunicorn.asgi.protocol import ASGIProtocol

        worker = mock.Mock()
        worker.cfg = Config()
        worker.log = mock.Mock()
        worker.asgi = mock.Mock()
        worker.nr_conns = 1
        worker.loop = mock.Mock()
        protocol = ASGIProtocol(worker)
        protocol._closed = False
        return protocol

    @pytest.mark.asyncio
    async def test_body_receiver_waits_for_disconnect_after_body_finished(self):
        """Test that receive() blocks after body is finished until disconnect is signaled.

        This tests the core regression fix: after body is complete,
        calling receive() should NOT immediately return http.disconnect.
        It should block until the connection actually closes.
        """
        import asyncio

        protocol = self._create_protocol()
        # A request with no body (body finishes immediately)
        body_receiver = _make_body_receiver(protocol, content_length=0)

        # Get the initial body message (empty body, more_body=False)
        msg1 = await body_receiver.receive()
        assert msg1["type"] == "http.request"
        assert msg1["body"] == b""
        assert msg1["more_body"] is False

        # At this point, _body_finished is True
        assert body_receiver._body_finished is True
        assert body_receiver._closed is False

        # Now calling receive() should block, not return immediately.
        # Run it as a task and verify it doesn't complete.
        receive_task = asyncio.create_task(body_receiver.receive())

        # Give the task a moment to start
        await asyncio.sleep(0.01)

        # Task should NOT be done yet (it's waiting for disconnect)
        assert not receive_task.done()

        # Now signal disconnect
        body_receiver.signal_disconnect()

        # Task should complete now
        msg2 = await asyncio.wait_for(receive_task, timeout=1.0)
        assert msg2 == {"type": "http.disconnect"}

    @pytest.mark.asyncio
    async def test_body_receiver_immediate_disconnect_if_already_closed(self):
        """Test that receive() immediately returns http.disconnect if already closed.

        If signal_disconnect() has already been called before receive(),
        it should return http.disconnect immediately without blocking.
        """
        protocol = self._create_protocol()
        # A request that still has an unread body
        body_receiver = _make_body_receiver(protocol, content_length=100)

        # Signal disconnect BEFORE calling receive
        body_receiver.signal_disconnect()
        assert body_receiver._closed is True

        # receive() should return disconnect immediately
        msg = await body_receiver.receive()
        assert msg == {"type": "http.disconnect"}

    @pytest.mark.asyncio
    async def test_body_receiver_respects_protocol_closed_state(self):
        """Test that receive() still yields http.disconnect on a closed protocol.

        The protocol is marked closed after the body has been consumed, and
        a pending receive() must then finish with http.disconnect.

        NOTE: signal_disconnect() is still called to wake the waiter, so
        this test does not prove that protocol._closed alone (without a
        signal) makes receive() return.
        """
        import asyncio

        protocol = self._create_protocol()
        body_receiver = _make_body_receiver(protocol, content_length=0)

        # Consume the body first
        msg1 = await body_receiver.receive()
        assert msg1["type"] == "http.request"
        assert msg1["more_body"] is False

        # Mark protocol as closed
        protocol._closed = True

        receive_task = asyncio.create_task(body_receiver.receive())

        # Give it a moment
        await asyncio.sleep(0.01)

        # Wake up the waiter by signaling disconnect
        body_receiver.signal_disconnect()

        msg2 = await asyncio.wait_for(receive_task, timeout=1.0)
        assert msg2 == {"type": "http.disconnect"}

    @pytest.mark.asyncio
    async def test_asgi_app_with_disconnect_listener(self):
        """Test Django-style ASGI app pattern that listens for disconnect.

        This simulates a real-world scenario where an ASGI app:

        1. Reads the request body
        2. Sends a response
        3. Calls receive() to wait for client disconnect (background task)

        The bug caused step 3 to return immediately with http.disconnect,
        making Django think the client disconnected mid-response.
        """
        import asyncio

        protocol = self._create_protocol()
        # Simulate a POST request with a 13-byte body
        body_receiver = _make_body_receiver(protocol, content_length=13)

        # Simulate sending body data via callback
        body_receiver.feed(b"Hello, World!")
        body_receiver.set_complete()

        # Step 1: App reads the body
        msg1 = await body_receiver.receive()
        assert msg1["type"] == "http.request"
        assert msg1["body"] == b"Hello, World!"
        assert msg1["more_body"] is False

        # At this point body is finished
        assert body_receiver._body_finished is True

        # Step 2: App would send its response here (not simulated).

        # Step 3: App starts listening for disconnect (like Django does)
        disconnect_received = asyncio.Event()

        async def listen_for_disconnect():
            """Simulates Django's disconnect listener task."""
            msg = await body_receiver.receive()
            if msg["type"] == "http.disconnect":
                disconnect_received.set()
            return msg

        listener_task = asyncio.create_task(listen_for_disconnect())

        # Give listener task time to start waiting
        await asyncio.sleep(0.01)

        # Listener must still be blocked: no disconnect may be reported
        # before the client actually goes away.
        assert not listener_task.done()

        # Simulate client closing connection after receiving response
        body_receiver.signal_disconnect()

        # Now listener should complete
        msg = await asyncio.wait_for(listener_task, timeout=1.0)
        assert msg == {"type": "http.disconnect"}
        assert disconnect_received.is_set()

    @pytest.mark.asyncio
    async def test_body_receiver_cancellation_during_wait(self):
        """Test that receive() handles cancellation while waiting for disconnect.

        When the ASGI task is cancelled (e.g., timeout) while receive() is
        waiting, the cancellation may propagate as CancelledError or the
        call may return http.disconnect, depending on timing. Either way
        the body receiver must end up marked as closed.
        """
        import asyncio

        protocol = self._create_protocol()
        body_receiver = _make_body_receiver(protocol, content_length=0)

        # Consume body
        await body_receiver.receive()

        receive_task = asyncio.create_task(body_receiver.receive())

        # Let it start waiting
        await asyncio.sleep(0.01)
        assert not receive_task.done()

        # Cancel the task
        receive_task.cancel()

        # Wait for the task to finish - it may raise CancelledError
        # or return disconnect depending on timing
        try:
            msg = await receive_task
            # If it returns, it should be a disconnect message
            assert msg == {"type": "http.disconnect"}
        except asyncio.CancelledError:
            # Cancellation propagated - this is also valid
            pass

        # Body receiver should be marked as closed after cancellation
        assert body_receiver._closed is True