diff --git a/gunicorn/asgi/protocol.py b/gunicorn/asgi/protocol.py
index d1f37517..8bdafc70 100644
--- a/gunicorn/asgi/protocol.py
+++ b/gunicorn/asgi/protocol.py
@@ -907,17 +907,28 @@ class ASGIProtocol(asyncio.Protocol):
                 response_status = message["status"]
                 response_headers = message.get("headers", [])
 
-                # Check if Content-Length is present
-                has_content_length = any(
-                    (name.lower() if isinstance(name, str) else name.lower()) == b"content-length"
-                    or (name.lower() if isinstance(name, str) else name.lower()) == "content-length"
-                    for name, _ in response_headers
-                )
+                # The server owns response framing. Per the ASGI HTTP spec, a
+                # Transfer-Encoding header supplied by the application must be
+                # ignored, so drop it rather than trust it. Trusting it would
+                # chunk-encode bodies for HTTP/1.0 clients and 1xx responses,
+                # send Transfer-Encoding together with Content-Length, and
+                # apply chunked framing to non-chunked codings such as gzip.
+                has_content_length = False
+                kept_headers = []
+                for name, value in response_headers:
+                    # bytes and str both implement lower(); accept either.
+                    name_lower = name.lower()
+                    if name_lower in (b"transfer-encoding", "transfer-encoding"):
+                        continue
+                    if name_lower in (b"content-length", "content-length"):
+                        has_content_length = True
+                    kept_headers.append((name, value))
+                response_headers = kept_headers
 
                 # Use chunked encoding for HTTP/1.1 streaming responses without Content-Length
                 # Skip for 1xx informational responses (RFC 9110)
                 is_informational = 100 <= response_status < 200
                 if not has_content_length and request.version >= (1, 1) and not is_informational:
                     use_chunked = True
                     response_headers = list(response_headers) + [(b"transfer-encoding", b"chunked")]
 
diff --git a/tests/test_asgi_protocol_compat.py b/tests/test_asgi_protocol_compat.py
index 1a0a514b..0bf0036a 100644
--- a/tests/test_asgi_protocol_compat.py
+++ b/tests/test_asgi_protocol_compat.py
@@ -827,3 +827,239 @@ class TestWebSocketHandshake:
 
         with pytest.raises(RuntimeError, match="Missing Sec-WebSocket-Key"):
             await protocol._send({"type": "websocket.accept"})
+
+
+# =============================================================================
+# Transfer-Encoding Header Duplicate Prevention Tests
+# =============================================================================
+
+class TestTransferEncodingChunked:
+    """Test Transfer-Encoding: chunked handling for streaming responses.
+
+    Reproduces failures:
+    - test_streaming_response[blacksheep] - multiple Transfer-Encoding headers
+    - test_streaming_large_response[blacksheep] - multiple Transfer-Encoding headers
+    - test_sse_events[blacksheep] - multiple Transfer-Encoding headers
+
+    Root cause: BlackSheep's StreamedContent sets Transfer-Encoding: chunked,
+    and gunicorn was adding another one without checking if it already exists.
+
+    Every test drives ASGIProtocol end to end and asserts on the bytes written
+    to the transport, so the production header handling is what is exercised.
+    """
+
+    def _create_protocol(self):
+        """Create an ASGIProtocol wired to mocks.
+
+        Returns a (protocol, written_data) pair; written_data collects every
+        bytes object passed to transport.write().
+        """
+        from gunicorn.asgi.protocol import ASGIProtocol
+        from gunicorn.config import Config
+
+        worker = mock.Mock()
+        worker.cfg = Config()
+        worker.log = mock.Mock()
+        worker.log.access_log_enabled = False
+        worker.asgi = mock.Mock()
+        worker.nr = 0
+        worker.max_requests = 10000
+        worker.alive = True
+        worker.state = {}
+
+        protocol = ASGIProtocol(worker)
+        protocol.transport = mock.Mock()
+        protocol._closed = False
+        protocol._flow_control = mock.Mock()
+        protocol._flow_control.drain = mock.AsyncMock()
+
+        written_data = []
+        protocol.transport.write = mock.Mock(side_effect=written_data.append)
+        return protocol, written_data
+
+    def _create_mock_request(self, version=(1, 1)):
+        """Create a mock HTTP request."""
+        request = mock.Mock()
+        request.method = "GET"
+        request.path = "/stream"
+        request.raw_path = b"/stream"
+        request.query = ""
+        request.version = version
+        request.scheme = "http"
+        request.headers = []
+        request.uri = "/stream"
+        request.should_close = mock.Mock(return_value=False)
+        request.content_length = 0
+        request.chunked = False
+        return request
+
+    @staticmethod
+    def _streaming_app(headers, chunks):
+        """Build an ASGI app that streams each chunk, then an empty final body."""
+        async def app(scope, receive, send):
+            await send({
+                "type": "http.response.start",
+                "status": 200,
+                "headers": headers,
+            })
+            for chunk in chunks:
+                await send({
+                    "type": "http.response.body",
+                    "body": chunk,
+                    "more_body": True,
+                })
+            await send({
+                "type": "http.response.body",
+                "body": b"",
+                "more_body": False,
+            })
+
+        return app
+
+    async def _run_app(self, app, version=(1, 1)):
+        """Serve a single request with *app* and return the raw response bytes."""
+        from gunicorn.asgi.protocol import BodyReceiver
+
+        protocol, written_data = self._create_protocol()
+        request = self._create_mock_request(version=version)
+
+        # The request carries no body: mark it as fully received
+        protocol._body_receiver = BodyReceiver(request, protocol)
+        protocol._body_receiver.set_complete()
+
+        protocol.app = app
+
+        sockname = ("127.0.0.1", 8000)
+        peername = ("127.0.0.1", 50000)
+        await protocol._handle_http_request(request, sockname, peername)
+
+        return b"".join(written_data)
+
+    @pytest.mark.asyncio
+    async def test_no_duplicate_transfer_encoding_when_framework_sets_it(self):
+        """Exactly one Transfer-Encoding header is sent when the framework sets one.
+
+        This reproduces the BlackSheep streaming issue where frameworks that
+        set their own Transfer-Encoding: chunked header get duplicate headers.
+        """
+        app = self._streaming_app(
+            [
+                (b"content-type", b"text/plain"),
+                (b"transfer-encoding", b"chunked"),  # Framework sets this
+            ],
+            [b"chunk-0\n"],
+        )
+
+        response = await self._run_app(app)
+
+        # Verify only one Transfer-Encoding header in response
+        te_count = response.lower().count(b"transfer-encoding")
+        assert te_count == 1, f"Expected 1 Transfer-Encoding header, got {te_count}"
+
+    @pytest.mark.asyncio
+    async def test_adds_transfer_encoding_when_not_present(self):
+        """Gunicorn should add Transfer-Encoding for streaming without Content-Length."""
+        app = self._streaming_app([(b"content-type", b"text/plain")], [b"chunk-0\n"])
+
+        response = await self._run_app(app)
+
+        te_count = response.lower().count(b"transfer-encoding")
+        assert te_count == 1, f"Expected 1 Transfer-Encoding header, got {te_count}"
+        assert b"transfer-encoding: chunked" in response.lower()
+
+    @pytest.mark.asyncio
+    async def test_no_transfer_encoding_when_content_length_set(self):
+        """Gunicorn should not add Transfer-Encoding when Content-Length is present."""
+        body = b"Hello, World!"
+
+        async def app_with_content_length(scope, receive, send):
+            await send({
+                "type": "http.response.start",
+                "status": 200,
+                "headers": [
+                    (b"content-type", b"text/plain"),
+                    (b"content-length", str(len(body)).encode()),
+                ],
+            })
+            await send({
+                "type": "http.response.body",
+                "body": body,
+                "more_body": False,
+            })
+
+        response = await self._run_app(app_with_content_length)
+
+        te_count = response.lower().count(b"transfer-encoding")
+        assert te_count == 0, f"Expected no Transfer-Encoding header, got {te_count}"
+        assert b"content-length: 13" in response.lower()
+
+    @pytest.mark.asyncio
+    async def test_chunked_body_encoding_with_framework_te(self):
+        """Body chunks should still be properly encoded when framework sets TE."""
+        app = self._streaming_app(
+            [
+                (b"content-type", b"text/plain"),
+                (b"transfer-encoding", b"chunked"),
+            ],
+            [b"Hello", b"World"],
+        )
+
+        response = await self._run_app(app)
+
+        # Body should be chunked encoded
+        assert b"5\r\nHello\r\n" in response, "First chunk not properly encoded"
+        assert b"5\r\nWorld\r\n" in response, "Second chunk not properly encoded"
+        assert b"0\r\n\r\n" in response, "Terminal chunk missing"
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize("header_name", [
+        b"transfer-encoding",
+        b"Transfer-Encoding",
+        b"TRANSFER-ENCODING",
+    ])
+    async def test_framework_transfer_encoding_detected_in_any_case(self, header_name):
+        """A framework-set header is recognised regardless of letter case."""
+        app = self._streaming_app([(header_name, b"chunked")], [b"data"])
+
+        response = await self._run_app(app)
+
+        te_count = response.lower().count(b"transfer-encoding")
+        assert te_count == 1, f"Expected 1 Transfer-Encoding header, got {te_count}"
+
+    @pytest.mark.asyncio
+    async def test_framework_transfer_encoding_dropped_with_content_length(self):
+        """Transfer-Encoding must never be sent together with Content-Length."""
+        body = b"Hello, World!"
+
+        async def app_with_both_headers(scope, receive, send):
+            await send({
+                "type": "http.response.start",
+                "status": 200,
+                "headers": [
+                    (b"content-length", str(len(body)).encode()),
+                    (b"transfer-encoding", b"chunked"),
+                ],
+            })
+            await send({
+                "type": "http.response.body",
+                "body": body,
+                "more_body": False,
+            })
+
+        response = await self._run_app(app_with_both_headers)
+
+        assert b"transfer-encoding" not in response.lower()
+        assert b"content-length: 13" in response.lower()
+        # The body follows the header block verbatim, without chunk framing
+        assert response.endswith(b"\r\n\r\n" + body)
+
+    @pytest.mark.asyncio
+    async def test_framework_transfer_encoding_ignored_for_http_1_0(self):
+        """HTTP/1.0 clients cannot parse chunked bodies, so none may be sent."""
+        app = self._streaming_app([(b"transfer-encoding", b"chunked")], [b"chunk-0\n"])
+
+        response = await self._run_app(app, version=(1, 0))
+
+        assert b"transfer-encoding" not in response.lower()
+        # The body follows the header block verbatim, without chunk framing
+        assert response.endswith(b"\r\n\r\nchunk-0\n")