gunicorn/tests/test_asgi_compliance.py
Benoit Chesneau 7953c2585b Fix ASGI disconnect handling for Django-style apps
BodyReceiver.receive() now blocks after body is finished until actual
disconnect, instead of returning http.disconnect immediately. This fixes
Django's listen_for_disconnect task thinking the client had disconnected early.

Adds regression tests for the fix.

Fixes #3484
2026-04-02 23:55:27 +02:00

1108 lines
38 KiB
Python

#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
ASGI 3.0 specification compliance tests.
Tests that gunicorn's ASGI implementation conforms to the ASGI 3.0 spec:
https://asgi.readthedocs.io/en/latest/specs/main.html
"""
import asyncio
from unittest import mock

import pytest

from gunicorn.config import Config
# ============================================================================
# ASGI Version Tests
# ============================================================================
class TestASGIVersion:
    """Verify the ``asgi`` sub-dictionary that every HTTP scope must carry."""

    def _create_protocol(self):
        """Return an ASGIProtocol wired to a mocked worker using a default Config."""
        from gunicorn.asgi.protocol import ASGIProtocol

        fake_worker = mock.Mock()
        fake_worker.cfg = Config()
        fake_worker.log = mock.Mock()
        fake_worker.asgi = mock.Mock()
        return ASGIProtocol(fake_worker)

    def _create_mock_request(self, **kwargs):
        """Return a mock parsed request whose attributes can be overridden by keyword.

        ``raw_path`` defaults to the latin-1 encoding of ``path``, or to
        ``b""`` when ``path`` is falsy.
        """
        path = kwargs.get("path", "/")
        fallback_raw_path = path.encode("latin-1") if path else b""

        fake_request = mock.Mock()
        fake_request.method = kwargs.get("method", "GET")
        fake_request.path = path
        fake_request.raw_path = kwargs.get("raw_path", fallback_raw_path)
        fake_request.query = kwargs.get("query", "")
        fake_request.version = kwargs.get("version", (1, 1))
        fake_request.scheme = kwargs.get("scheme", "http")
        fake_request.headers = kwargs.get("headers", [])
        return fake_request

    def _default_scope(self):
        """Build the scope for a default request served on 127.0.0.1:8000 to port 12345."""
        protocol = self._create_protocol()
        request = self._create_mock_request()
        return protocol._build_http_scope(
            request,
            ("127.0.0.1", 8000),
            ("127.0.0.1", 12345),
        )

    def test_asgi_version_present(self):
        """The HTTP scope exposes an 'asgi' entry."""
        assert "asgi" in self._default_scope()

    def test_asgi_version_is_dict(self):
        """The 'asgi' entry is a dictionary."""
        assert isinstance(self._default_scope()["asgi"], dict)

    def test_asgi_version_value(self):
        """The advertised ASGI version is '3.0'."""
        assert self._default_scope()["asgi"]["version"] == "3.0"

    def test_asgi_spec_version_present(self):
        """The 'asgi' entry also advertises a spec_version."""
        assert "spec_version" in self._default_scope()["asgi"]

    def test_asgi_spec_version_value(self):
        """spec_version has the form 'MAJOR.MINOR' with numeric components."""
        components = self._default_scope()["asgi"]["spec_version"].split(".")
        # Exactly two dot-separated parts, each made only of digits.
        assert len(components) == 2
        assert all(component.isdigit() for component in components)
# ============================================================================
# HTTP Scope Keys Tests (ASGI HTTP Connection Scope)
# ============================================================================
class TestHTTPScopeKeys:
    """Test required keys in the HTTP connection scope per the ASGI spec."""

    def _create_protocol(self):
        """Create an ASGIProtocol on a mocked worker with a default Config."""
        from gunicorn.asgi.protocol import ASGIProtocol

        worker = mock.Mock()
        worker.cfg = Config()
        worker.log = mock.Mock()
        worker.asgi = mock.Mock()
        return ASGIProtocol(worker)

    def _create_mock_request(self, **kwargs):
        """Create a mock parsed HTTP request; attributes are overridable by keyword.

        ``raw_path`` defaults to ``path`` encoded as latin-1, or ``b""`` for
        a falsy path.
        """
        request = mock.Mock()
        request.method = kwargs.get("method", "GET")
        path = kwargs.get("path", "/")
        request.path = path
        request.raw_path = kwargs.get("raw_path", path.encode("latin-1") if path else b"")
        request.query = kwargs.get("query", "")
        request.version = kwargs.get("version", (1, 1))
        request.scheme = kwargs.get("scheme", "http")
        request.headers = kwargs.get("headers", [])
        return request

    def test_type_key_present(self):
        """'type' is present and equals 'http'."""
        protocol = self._create_protocol()
        request = self._create_mock_request()
        scope = protocol._build_http_scope(
            request,
            ("127.0.0.1", 8000),
            ("127.0.0.1", 12345),
        )
        assert scope["type"] == "http"

    def test_http_version_key_present(self):
        """'http_version' is present and rendered as 'MAJOR.MINOR'."""
        protocol = self._create_protocol()
        request = self._create_mock_request()
        scope = protocol._build_http_scope(
            request,
            ("127.0.0.1", 8000),
            ("127.0.0.1", 12345),
        )
        assert "http_version" in scope
        assert scope["http_version"] == "1.1"

    def test_http_version_formats(self):
        """Both HTTP/1.0 and HTTP/1.1 version tuples are rendered correctly."""
        protocol = self._create_protocol()
        # HTTP/1.0
        request_10 = self._create_mock_request(version=(1, 0))
        scope_10 = protocol._build_http_scope(request_10, None, None)
        assert scope_10["http_version"] == "1.0"
        # HTTP/1.1
        request_11 = self._create_mock_request(version=(1, 1))
        scope_11 = protocol._build_http_scope(request_11, None, None)
        assert scope_11["http_version"] == "1.1"

    def test_method_key_present(self):
        """'method' is passed through unchanged as an uppercase string."""
        protocol = self._create_protocol()
        for method in ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"]:
            request = self._create_mock_request(method=method)
            scope = protocol._build_http_scope(request, None, None)
            assert scope["method"] == method
            assert scope["method"].isupper()

    def test_scheme_key_present(self):
        """'scheme' mirrors the request scheme for both http and https."""
        protocol = self._create_protocol()
        # HTTP
        request_http = self._create_mock_request(scheme="http")
        scope_http = protocol._build_http_scope(request_http, None, None)
        assert scope_http["scheme"] == "http"
        # HTTPS
        request_https = self._create_mock_request(scheme="https")
        scope_https = protocol._build_http_scope(request_https, None, None)
        assert scope_https["scheme"] == "https"

    def test_path_key_present(self):
        """'path' is present and starts with '/'."""
        protocol = self._create_protocol()
        request = self._create_mock_request(path="/api/users")
        scope = protocol._build_http_scope(request, None, None)
        assert "path" in scope
        assert scope["path"] == "/api/users"
        assert scope["path"].startswith("/")

    def test_raw_path_key_present(self):
        """'raw_path' is present and is bytes."""
        protocol = self._create_protocol()
        request = self._create_mock_request(path="/api/users")
        scope = protocol._build_http_scope(request, None, None)
        assert "raw_path" in scope
        assert isinstance(scope["raw_path"], bytes)
        assert scope["raw_path"] == b"/api/users"

    def test_query_string_key_present(self):
        """'query_string' is present and is the bytes form of the query."""
        protocol = self._create_protocol()
        request = self._create_mock_request(query="page=1&limit=10")
        scope = protocol._build_http_scope(request, None, None)
        assert "query_string" in scope
        assert isinstance(scope["query_string"], bytes)
        assert scope["query_string"] == b"page=1&limit=10"

    def test_query_string_empty(self):
        """'query_string' is empty bytes when there is no query."""
        protocol = self._create_protocol()
        request = self._create_mock_request(query="")
        scope = protocol._build_http_scope(request, None, None)
        assert scope["query_string"] == b""

    def test_root_path_key_present(self):
        """'root_path' is present and is a str."""
        protocol = self._create_protocol()
        request = self._create_mock_request()
        scope = protocol._build_http_scope(request, None, None)
        assert "root_path" in scope
        assert isinstance(scope["root_path"], str)

    def test_headers_key_present(self):
        """'headers' is a non-empty list of 2-tuples when the request has headers."""
        protocol = self._create_protocol()
        request = self._create_mock_request(
            headers=[("HOST", "localhost"), ("ACCEPT", "text/html")]
        )
        scope = protocol._build_http_scope(request, None, None)
        assert "headers" in scope
        assert isinstance(scope["headers"], list)
        # Without this guard the per-item loop below would pass vacuously on [].
        assert scope["headers"], "request headers were not forwarded to the scope"
        for header in scope["headers"]:
            assert isinstance(header, tuple)
            assert len(header) == 2

    def test_headers_are_bytes(self):
        """Header names and values are bytes."""
        protocol = self._create_protocol()
        request = self._create_mock_request(
            headers=[("HOST", "localhost"), ("CONTENT-TYPE", "application/json")]
        )
        scope = protocol._build_http_scope(request, None, None)
        # Guard against a vacuous pass if no headers are forwarded.
        assert scope["headers"], "request headers were not forwarded to the scope"
        for name, value in scope["headers"]:
            assert isinstance(name, bytes), f"Header name should be bytes: {name}"
            assert isinstance(value, bytes), f"Header value should be bytes: {value}"

    def test_headers_names_lowercase(self):
        """Header names are lowercased, including mixed-case input names."""
        protocol = self._create_protocol()
        request = self._create_mock_request(
            headers=[("HOST", "localhost"), ("Content-Type", "application/json")]
        )
        scope = protocol._build_http_scope(request, None, None)
        names = [name for name, _ in scope["headers"]]
        # Guard against a vacuous pass, and check the expected names survive.
        assert names, "request headers were not forwarded to the scope"
        assert {b"host", b"content-type"} <= set(names)
        for name in names:
            assert name == name.lower(), f"Header name should be lowercase: {name}"

    def test_server_key_present(self):
        """'server' is the sockname tuple when one is provided."""
        protocol = self._create_protocol()
        request = self._create_mock_request()
        scope = protocol._build_http_scope(
            request,
            ("127.0.0.1", 8000),
            ("127.0.0.1", 12345),
        )
        assert "server" in scope
        assert scope["server"] == ("127.0.0.1", 8000)

    def test_server_key_none(self):
        """'server' is None when no sockname is provided."""
        protocol = self._create_protocol()
        request = self._create_mock_request()
        scope = protocol._build_http_scope(request, None, None)
        assert scope["server"] is None

    def test_client_key_present(self):
        """'client' is the peername tuple when one is provided."""
        protocol = self._create_protocol()
        request = self._create_mock_request()
        scope = protocol._build_http_scope(
            request,
            ("127.0.0.1", 8000),
            ("192.168.1.100", 54321),
        )
        assert "client" in scope
        assert scope["client"] == ("192.168.1.100", 54321)

    def test_client_key_none(self):
        """'client' is None when no peername is provided."""
        protocol = self._create_protocol()
        request = self._create_mock_request()
        scope = protocol._build_http_scope(request, None, None)
        assert scope["client"] is None
# ============================================================================
# HTTP Message Format Tests
# ============================================================================
class TestHTTPMessageFormats:
    """Shape checks for the HTTP event dictionaries described by the ASGI spec.

    NOTE(review): every test here builds its own message literal and then
    inspects it, so none of them exercises gunicorn code. They only document
    the expected field names and value types.
    """

    def test_http_request_message_format(self):
        """An http.request event has a bytes 'body' and a bool 'more_body'."""
        event = dict(type="http.request", body=b"request body", more_body=False)
        assert event["type"] == "http.request"
        assert isinstance(event["body"], bytes)
        assert isinstance(event["more_body"], bool)

    def test_http_request_message_empty_body(self):
        """An http.request event may carry an empty, final body."""
        event = dict(type="http.request", body=b"", more_body=False)
        assert event["body"] == b""
        assert event["more_body"] is False

    def test_http_response_start_format(self):
        """http.response.start carries an int status in [100, 600) and a header list."""
        response_headers = [
            (b"content-type", b"text/plain"),
            (b"content-length", b"13"),
        ]
        event = dict(type="http.response.start", status=200, headers=response_headers)
        assert event["type"] == "http.response.start"
        status = event["status"]
        assert isinstance(status, int)
        assert 100 <= status < 600
        assert isinstance(event["headers"], list)

    def test_http_response_body_format(self):
        """An http.response.body event has a bytes 'body' and a bool 'more_body'."""
        event = dict(type="http.response.body", body=b"Hello, World!", more_body=False)
        assert event["type"] == "http.response.body"
        assert isinstance(event["body"], bytes)
        assert isinstance(event["more_body"], bool)

    def test_http_response_body_streaming(self):
        """A streamed body flags more_body=True on every chunk except the last."""
        first, last = (
            dict(type="http.response.body", body=b"First chunk", more_body=True),
            dict(type="http.response.body", body=b"Last chunk", more_body=False),
        )
        assert first["more_body"] is True
        assert last["more_body"] is False

    def test_http_disconnect_format(self):
        """An http.disconnect event is identified by its 'type' alone."""
        event = dict(type="http.disconnect")
        assert event["type"] == "http.disconnect"
# ============================================================================
# HTTP Response Status Codes Tests
# ============================================================================
class TestHTTPStatusCodes:
    """Reason phrases produced by ``ASGIProtocol._get_reason_phrase``."""

    def _create_protocol(self):
        """Return an ASGIProtocol built on a mocked worker with a default Config."""
        from gunicorn.asgi.protocol import ASGIProtocol

        fake_worker = mock.Mock()
        fake_worker.cfg = Config()
        fake_worker.log = mock.Mock()
        fake_worker.asgi = mock.Mock()
        return ASGIProtocol(fake_worker)

    def _assert_phrases(self, expected):
        """Assert, in insertion order, that each status code maps to its phrase."""
        protocol = self._create_protocol()
        for status, phrase in expected.items():
            assert protocol._get_reason_phrase(status) == phrase

    def test_reason_phrase_informational(self):
        """1xx status codes."""
        self._assert_phrases({
            100: "Continue",
            101: "Switching Protocols",
            103: "Early Hints",
        })

    def test_reason_phrase_success(self):
        """2xx status codes."""
        self._assert_phrases({
            200: "OK",
            201: "Created",
            202: "Accepted",
            204: "No Content",
            206: "Partial Content",
        })

    def test_reason_phrase_redirect(self):
        """3xx status codes."""
        self._assert_phrases({
            301: "Moved Permanently",
            302: "Found",
            303: "See Other",
            304: "Not Modified",
            307: "Temporary Redirect",
            308: "Permanent Redirect",
        })

    def test_reason_phrase_client_error(self):
        """4xx status codes."""
        self._assert_phrases({
            400: "Bad Request",
            401: "Unauthorized",
            403: "Forbidden",
            404: "Not Found",
            405: "Method Not Allowed",
            408: "Request Timeout",
            409: "Conflict",
            410: "Gone",
            422: "Unprocessable Entity",
            429: "Too Many Requests",
        })

    def test_reason_phrase_server_error(self):
        """5xx status codes."""
        self._assert_phrases({
            500: "Internal Server Error",
            501: "Not Implemented",
            502: "Bad Gateway",
            503: "Service Unavailable",
            504: "Gateway Timeout",
        })

    def test_reason_phrase_unknown(self):
        """Codes missing from the phrase table fall back to 'Unknown'."""
        self._assert_phrases({
            999: "Unknown",
            # 418 (I'm a teapot) is not in the table either, so it also falls back.
            418: "Unknown",
        })
# ============================================================================
# Informational Response Tests (103 Early Hints, etc.)
# ============================================================================
class TestInformationalResponses:
    """Shape checks for ``http.response.informational`` events (1xx responses).

    NOTE(review): these tests inspect message literals built inside the test
    itself; they do not drive gunicorn.
    """

    def test_http_response_informational_format(self):
        """An informational event carries a 1xx status and a list of headers."""
        preload_css = (b"link", b"</style.css>; rel=preload; as=style")
        event = {
            "type": "http.response.informational",
            "status": 103,
            "headers": [preload_css],
        }
        assert event["type"] == "http.response.informational"
        assert 100 <= event["status"] < 200
        assert isinstance(event["headers"], list)

    def test_early_hints_103(self):
        """A 103 Early Hints event may advertise several preload links."""
        preload_links = [
            (b"link", b"</style.css>; rel=preload; as=style"),
            (b"link", b"</script.js>; rel=preload; as=script"),
        ]
        event = {
            "type": "http.response.informational",
            "status": 103,
            "headers": preload_links,
        }
        assert event["status"] == 103
# ============================================================================
# ASGI Extensions Tests
# ============================================================================
class TestASGIExtensions:
    """ASGI ``extensions`` advertised in HTTP/2 scopes."""

    def _create_protocol(self):
        """Return an ASGIProtocol built on a mocked worker with a default Config."""
        from gunicorn.asgi.protocol import ASGIProtocol

        fake_worker = mock.Mock()
        fake_worker.cfg = Config()
        fake_worker.log = mock.Mock()
        fake_worker.asgi = mock.Mock()
        return ASGIProtocol(fake_worker)

    def _create_mock_http2_request(self, **kwargs):
        """Return a mock HTTP/2 request carrying stream-priority attributes.

        Each attribute below can be overridden by a keyword argument of the
        same name; unknown keywords are ignored.
        """
        defaults = {
            "method": "GET",
            "path": "/",
            "query": "",
            "uri": "/",
            "scheme": "https",
            "headers": [],
            "priority_weight": 16,
            "priority_depends_on": 0,
        }
        fake_request = mock.Mock()
        for attr, default in defaults.items():
            setattr(fake_request, attr, kwargs.get(attr, default))
        return fake_request

    def _http2_scope(self, **request_kwargs):
        """Build an HTTP/2 scope (no sockname/peername) for a mock request."""
        protocol = self._create_protocol()
        request = self._create_mock_http2_request(**request_kwargs)
        return protocol._build_http2_scope(request, None, None)

    def test_http2_scope_has_extensions(self):
        """The HTTP/2 scope exposes an 'extensions' dict."""
        scope = self._http2_scope()
        assert "extensions" in scope
        assert isinstance(scope["extensions"], dict)

    def test_http2_priority_extension(self):
        """The priority extension echoes the request's weight and dependency."""
        extensions = self._http2_scope(
            priority_weight=128,
            priority_depends_on=5,
        )["extensions"]
        assert "http.response.priority" in extensions
        priority = extensions["http.response.priority"]
        assert "weight" in priority
        assert "depends_on" in priority
        assert priority["weight"] == 128
        assert priority["depends_on"] == 5

    def test_http2_trailers_extension(self):
        """The trailers extension is advertised in HTTP/2 scopes."""
        assert "http.response.trailers" in self._http2_scope()["extensions"]

    def test_http_response_trailers_message_format(self):
        """An http.response.trailers event carries a list of header pairs."""
        trailer_headers = [
            (b"grpc-status", b"0"),
            (b"grpc-message", b""),
        ]
        event = {
            "type": "http.response.trailers",
            "headers": trailer_headers,
            "more_trailers": False,
        }
        assert event["type"] == "http.response.trailers"
        assert isinstance(event["headers"], list)
# ============================================================================
# State Sharing Tests
# ============================================================================
class TestStateSharing:
    """The worker's lifespan ``state`` is exposed in HTTP request scopes."""

    def _create_protocol_with_state(self, state):
        """Return an ASGIProtocol whose mocked worker carries *state*."""
        from gunicorn.asgi.protocol import ASGIProtocol

        fake_worker = mock.Mock()
        fake_worker.cfg = Config()
        fake_worker.log = mock.Mock()
        fake_worker.asgi = mock.Mock()
        fake_worker.state = state
        return ASGIProtocol(fake_worker)

    def _create_mock_request(self):
        """Return a mock GET / HTTP/1.1 request with no query string or headers."""
        fake_request = mock.Mock()
        fake_request.configure_mock(
            method="GET",
            path="/",
            raw_path=b"/",
            query="",
            version=(1, 1),
            scheme="http",
            headers=[],
        )
        return fake_request

    def test_state_in_http_scope(self):
        """The worker's state dict appears under scope['state']."""
        lifespan_state = {"db": "connected", "cache": "ready"}
        protocol = self._create_protocol_with_state(lifespan_state)
        scope = protocol._build_http_scope(self._create_mock_request(), None, None)
        assert "state" in scope
        assert scope["state"] == lifespan_state

    def test_state_is_same_object(self):
        """scope['state'] aliases the worker's dict instead of copying it."""
        lifespan_state = {"counter": 0}
        protocol = self._create_protocol_with_state(lifespan_state)
        scope = protocol._build_http_scope(self._create_mock_request(), None, None)
        # A write through the scope must be visible on the original dict.
        scope["state"]["counter"] = 1
        assert lifespan_state["counter"] == 1

    def test_state_not_present_without_worker_state(self):
        """No 'state' key is added when the worker has no state attribute."""
        from gunicorn.asgi.protocol import ASGIProtocol

        # spec= limits the mock to these attributes, so worker.state does not exist.
        bare_worker = mock.Mock(spec=["cfg", "log", "asgi"])
        bare_worker.cfg = Config()
        bare_worker.log = mock.Mock()
        bare_worker.asgi = mock.Mock()
        protocol = ASGIProtocol(bare_worker)
        scope = protocol._build_http_scope(self._create_mock_request(), None, None)
        assert "state" not in scope
# ============================================================================
# HTTP Disconnect Event Tests (ASGI Spec Compliance)
# https://asgi.readthedocs.io/en/latest/specs/www.html#disconnect-receive-event
# ============================================================================
class TestHTTPDisconnectEvent:
    """Test http.disconnect event compliance with the ASGI spec.

    Per the ASGI HTTP connection scope spec:
    - a disconnect event is sent when the client closes the connection;
    - the event type MUST be "http.disconnect";
    - apps should receive this event and clean up gracefully.
    """

    def _create_protocol(self):
        """Create an ASGIProtocol on a mocked worker, event loop and reader."""
        from gunicorn.asgi.protocol import ASGIProtocol

        worker = mock.Mock()
        worker.cfg = Config()
        worker.log = mock.Mock()
        worker.asgi = mock.Mock()
        worker.nr_conns = 1
        # A mock loop lets tests assert on call_later without scheduling real timers.
        worker.loop = mock.Mock()
        protocol = ASGIProtocol(worker)
        protocol.reader = mock.Mock()
        return protocol

    def _attach_body_receiver(self, protocol, content_length):
        """Create a BodyReceiver for a non-chunked request and attach it to *protocol*.

        Args:
            protocol: the ASGIProtocol the receiver reads from.
            content_length: declared request body length in bytes.

        Returns:
            The new BodyReceiver, also stored as ``protocol._body_receiver``.
        """
        from gunicorn.asgi.protocol import BodyReceiver

        request = mock.Mock()
        request.content_length = content_length
        request.chunked = False
        body_receiver = BodyReceiver(request, protocol)
        protocol._body_receiver = body_receiver
        return body_receiver

    @pytest.mark.asyncio
    async def test_disconnect_event_type(self):
        """After connection loss, receive() yields an event of type 'http.disconnect'."""
        protocol = self._create_protocol()
        body_receiver = self._attach_body_receiver(protocol, content_length=100)
        # Simulate the client going away while body bytes are still expected.
        protocol.connection_lost(None)
        assert body_receiver._closed
        # The app-facing event, not just the internal flag, must report the disconnect.
        message = await body_receiver.receive()
        assert message["type"] == "http.disconnect"

    def test_disconnect_event_sent_on_connection_lost(self):
        """connection_lost() signals the attached body receiver."""
        protocol = self._create_protocol()
        body_receiver = self._attach_body_receiver(protocol, content_length=100)
        assert not body_receiver._closed
        # Simulate client disconnect.
        protocol.connection_lost(None)
        assert body_receiver._closed

    def test_disconnect_sets_closed_flag(self):
        """connection_lost() sets the protocol's closed flag."""
        protocol = self._create_protocol()
        assert protocol._closed is False
        protocol.connection_lost(None)
        assert protocol._closed is True

    def test_disconnect_allows_graceful_cleanup(self):
        """Disconnect does not cancel the app task immediately.

        Per the ASGI spec, apps should get a chance to clean up when they
        receive http.disconnect, so cancellation is deferred via call_later.
        """
        protocol = self._create_protocol()
        mock_task = mock.Mock()
        mock_task.done.return_value = False
        protocol._task = mock_task
        protocol.connection_lost(None)
        # The task must NOT be cancelled right away...
        mock_task.cancel.assert_not_called()
        # ...instead a cancellation is scheduled after a grace period.
        protocol.worker.loop.call_later.assert_called_once()

    @pytest.mark.asyncio
    async def test_disconnect_message_format(self):
        """http.disconnect message format per the ASGI spec.

        Once the body is complete and disconnect is signaled, receive()
        returns exactly {"type": "http.disconnect"}.
        """
        protocol = self._create_protocol()
        body_receiver = self._attach_body_receiver(protocol, content_length=0)
        # The first receive() delivers the (empty) request body.
        msg1 = await body_receiver.receive()
        assert msg1["type"] == "http.request"
        assert msg1["more_body"] is False
        # receive() now waits for a real disconnect signal (see issue #3484).
        body_receiver.signal_disconnect()
        msg2 = await body_receiver.receive()
        # Per the ASGI spec, the disconnect message carries only 'type'.
        assert msg2 == {"type": "http.disconnect"}
        assert len(msg2) == 1
# ============================================================================
# BodyReceiver Disconnect Regression Tests
# https://github.com/benoitc/gunicorn/issues/3484
# ============================================================================
class TestBodyReceiverDisconnect:
    """Regression tests for BodyReceiver._wait_for_disconnect() behavior.

    The original bug: BodyReceiver.receive() immediately returned
    ``http.disconnect`` once ``_body_finished`` was True. Django (and other
    ASGI frameworks) keep calling ``receive()`` after the request body has
    been read, to listen for a client disconnect while the response is
    produced. So Django's ``listen_for_disconnect`` task concluded that the
    client had gone away before the response could be sent.

    The fix: once the body is finished, receive() calls
    _wait_for_disconnect(), which blocks until signal_disconnect() is called
    or the waiter is cancelled.
    """

    def _create_protocol(self):
        """Create an open ASGIProtocol on a mocked worker and event loop."""
        from gunicorn.asgi.protocol import ASGIProtocol

        worker = mock.Mock()
        worker.cfg = Config()
        worker.log = mock.Mock()
        worker.asgi = mock.Mock()
        worker.nr_conns = 1
        worker.loop = mock.Mock()
        protocol = ASGIProtocol(worker)
        protocol._closed = False
        return protocol

    def _make_body_receiver(self, protocol, content_length):
        """Return a BodyReceiver for a non-chunked request of *content_length* bytes."""
        from gunicorn.asgi.protocol import BodyReceiver

        request = mock.Mock()
        request.content_length = content_length
        request.chunked = False
        return BodyReceiver(request, protocol)

    @pytest.mark.asyncio
    async def test_body_receiver_waits_for_disconnect_after_body_finished(self):
        """receive() blocks after the body is finished until disconnect is signaled.

        This is the core regression check: once the body is complete, a
        further receive() must NOT return http.disconnect immediately; it
        must wait for the connection to actually close.
        """
        protocol = self._create_protocol()
        # Zero-length body: the first receive() finishes the body.
        body_receiver = self._make_body_receiver(protocol, content_length=0)

        msg1 = await body_receiver.receive()
        assert msg1["type"] == "http.request"
        assert msg1["body"] == b""
        assert msg1["more_body"] is False

        assert body_receiver._body_finished is True
        assert body_receiver._closed is False

        # Run receive() as a task so we can observe that it is still pending.
        receive_task = asyncio.create_task(body_receiver.receive())
        await asyncio.sleep(0.01)
        assert not receive_task.done()

        body_receiver.signal_disconnect()
        msg2 = await asyncio.wait_for(receive_task, timeout=1.0)
        assert msg2 == {"type": "http.disconnect"}

    @pytest.mark.asyncio
    async def test_body_receiver_immediate_disconnect_if_already_closed(self):
        """receive() returns http.disconnect immediately if already signaled.

        If signal_disconnect() was called before receive(), receive() must not
        block, even though body bytes were still expected.
        """
        protocol = self._create_protocol()
        body_receiver = self._make_body_receiver(protocol, content_length=100)

        body_receiver.signal_disconnect()
        assert body_receiver._closed is True

        msg = await body_receiver.receive()
        assert msg == {"type": "http.disconnect"}

    @pytest.mark.asyncio
    async def test_body_receiver_respects_protocol_closed_state(self):
        """receive() resolves to http.disconnect once the protocol is closed.

        NOTE(review): this test still calls signal_disconnect() after a short
        pause, so it also passes if receive() ignores ``protocol._closed``
        and only wakes on the explicit signal. On its own it does not prove
        that protocol closure is detected without the signal.
        """
        protocol = self._create_protocol()
        body_receiver = self._make_body_receiver(protocol, content_length=0)

        # Consume the (empty) body first.
        msg1 = await body_receiver.receive()
        assert msg1["type"] == "http.request"
        assert msg1["more_body"] is False

        protocol._closed = True
        receive_task = asyncio.create_task(body_receiver.receive())
        await asyncio.sleep(0.01)

        # Wake the waiter in case receive() is blocked on the signal.
        body_receiver.signal_disconnect()
        msg2 = await asyncio.wait_for(receive_task, timeout=1.0)
        assert msg2 == {"type": "http.disconnect"}

    @pytest.mark.asyncio
    async def test_asgi_app_with_disconnect_listener(self):
        """Django-style ASGI app that listens for disconnect in a background task.

        The simulated app:
        1. reads the request body;
        2. sends a response (not simulated here);
        3. calls receive() to wait for the client to disconnect.

        The bug made step 3 return http.disconnect immediately, so Django
        believed the client had disconnected mid-response.
        """
        protocol = self._create_protocol()
        body_receiver = self._make_body_receiver(protocol, content_length=13)

        # Feed the full POST body as the parser callback would.
        body_receiver.feed(b"Hello, World!")
        body_receiver.set_complete()

        # Step 1: the app reads the body.
        msg1 = await body_receiver.receive()
        assert msg1["type"] == "http.request"
        assert msg1["body"] == b"Hello, World!"
        assert msg1["more_body"] is False
        assert body_receiver._body_finished is True

        # Step 3: a background task listens for disconnect, as Django does.
        disconnect_received = asyncio.Event()

        async def listen_for_disconnect():
            """Stand-in for Django's disconnect listener task."""
            msg = await body_receiver.receive()
            if msg["type"] == "http.disconnect":
                disconnect_received.set()
            return msg

        listener_task = asyncio.create_task(listen_for_disconnect())
        await asyncio.sleep(0.01)
        # The listener must still be waiting; it must not report a premature disconnect.
        assert not listener_task.done()
        assert not disconnect_received.is_set()

        # The client closes the connection after receiving the response.
        body_receiver.signal_disconnect()
        msg = await asyncio.wait_for(listener_task, timeout=1.0)
        assert msg == {"type": "http.disconnect"}
        assert disconnect_received.is_set()

    @pytest.mark.asyncio
    async def test_body_receiver_cancellation_during_wait(self):
        """Cancelling a receive() that waits for disconnect leaves the receiver closed.

        Depending on timing, the cancelled task either re-raises
        CancelledError or returns http.disconnect. In both cases the
        receiver must end up marked closed, so later calls return the
        disconnect immediately.
        """
        protocol = self._create_protocol()
        body_receiver = self._make_body_receiver(protocol, content_length=0)

        # Consume the (empty) body so the next receive() waits for disconnect.
        await body_receiver.receive()

        receive_task = asyncio.create_task(body_receiver.receive())
        await asyncio.sleep(0.01)
        assert not receive_task.done()

        receive_task.cancel()
        try:
            msg = await receive_task
        except asyncio.CancelledError:
            # Cancellation propagated out of receive(): acceptable.
            pass
        else:
            # receive() absorbed the cancellation: it must report a disconnect.
            assert msg == {"type": "http.disconnect"}

        assert body_receiver._closed is True