From f9ca296d21c9f80afecbd8ba7136cf874cd6a6ff Mon Sep 17 00:00:00 2001 From: Benoit Chesneau Date: Mon, 23 Mar 2026 13:38:47 +0100 Subject: [PATCH] Fix WebSocket and body receiver issues in ASGI protocol - Fix body receiver timeout handling to prevent infinite loops - Add WebSocket data forwarding via callbacks instead of StreamReader - Fix HTTP/2 stream race condition where DATA frames arrive before first read - Update WebSocketProtocol constructor (removed reader parameter) --- gunicorn/asgi/protocol.py | 35 +++++++++++++++++++++++++---------- gunicorn/http2/stream.py | 4 ++++ tests/test_asgi_worker.py | 4 ++-- 3 files changed, 31 insertions(+), 12 deletions(-) diff --git a/gunicorn/asgi/protocol.py b/gunicorn/asgi/protocol.py index b0dc9acb..7a593f15 100644 --- a/gunicorn/asgi/protocol.py +++ b/gunicorn/asgi/protocol.py @@ -248,12 +248,10 @@ class BodyReceiver: if self._chunks: return self._pop_chunk() - if self._complete: - self._body_finished = True - return {"type": "http.request", "body": b"", "more_body": False} - - # Timeout or other condition - return empty with more_body=True - return {"type": "http.request", "body": b"", "more_body": True} + # Complete OR timeout - mark body finished to prevent infinite loops + # Apps should not loop forever waiting for body that won't arrive + self._body_finished = True + return {"type": "http.request", "body": b"", "more_body": False} async def _wait_for_data(self): """Wait for body data to arrive via callback.""" @@ -321,6 +319,9 @@ class ASGIProtocol(asyncio.Protocol): # Write flow control self._flow_control = None + # WebSocket protocol (set during upgrade, receives data via callbacks) + self._websocket = None + def connection_made(self, transport): """Called when a connection is established.""" self.transport = transport @@ -454,6 +455,10 @@ class ASGIProtocol(asyncio.Protocol): def data_received(self, data): """Called when data is received on the connection.""" + if self._websocket: + # WebSocket path - forward to WebSocket protocol + self._websocket.feed_data(data) + return if self.reader: # HTTP/2 path - use StreamReader self.reader.feed_data(data) @@ -545,6 +550,10 @@ class ASGIProtocol(asyncio.Protocol): if self.reader: self.reader.feed_eof() + # Signal EOF to WebSocket if active + if self._websocket: + self._websocket.feed_eof() + # Signal disconnect to the app via the body receiver if self._body_receiver is not None: self._body_receiver.signal_disconnect() @@ -750,10 +759,17 @@ class ASGIProtocol(asyncio.Protocol): """Handle WebSocket upgrade request.""" from gunicorn.asgi.websocket import WebSocketProtocol + # Stop callback parser - WebSocket uses its own data handling + self._callback_parser = None + scope = self._build_websocket_scope(request, sockname, peername) ws_protocol = WebSocketProtocol( - self.transport, self.reader, scope, self.app, self.log + self.transport, scope, self.app, self.log ) + + # Store reference so data_received() forwards to WebSocket + self._websocket = ws_protocol + await ws_protocol.run() async def _handle_http_request(self, request, sockname, peername): @@ -772,9 +788,8 @@ class ASGIProtocol(asyncio.Protocol): response_headers = [] response_sent = 0 - # Create body receiver - reads directly on demand, no Queue/Task overhead - body_receiver = BodyReceiver(request, self) - self._body_receiver = body_receiver + # Use body receiver created in _on_headers_complete (receives data via callbacks) + body_receiver = self._body_receiver async def send(message): nonlocal response_started, response_complete, exc_to_raise diff --git a/gunicorn/http2/stream.py b/gunicorn/http2/stream.py index 8d03fdaf..34b7be18 100644 --- a/gunicorn/http2/stream.py +++ b/gunicorn/http2/stream.py @@ -309,6 +309,10 @@ class HTTP2Stream: # Initialize event lazily (avoids event loop issues at construction) if self._body_event is None: self._body_event = asyncio.Event() + # If data already arrived before event existed, set it now + # This prevents race where DATA frames arrive before first read + if self._body_chunks or self._body_complete: + self._body_event.set() while True: # Return chunk if available diff --git a/tests/test_asgi_worker.py b/tests/test_asgi_worker.py index f52c1726..bf534601 100644 --- a/tests/test_asgi_worker.py +++ b/tests/test_asgi_worker.py @@ -389,7 +389,7 @@ class TestWebSocketProtocol: from gunicorn.asgi.websocket import WebSocketProtocol # Create a minimal protocol instance - protocol = WebSocketProtocol(None, None, {}, None, mock.Mock()) + protocol = WebSocketProtocol(None, {}, None, mock.Mock()) # Test unmasking (XOR operation) masking_key = bytes([0x37, 0xfa, 0x21, 0x3d]) @@ -402,7 +402,7 @@ class TestWebSocketProtocol: """Test WebSocket frame unmasking with empty payload.""" from gunicorn.asgi.websocket import WebSocketProtocol - protocol = WebSocketProtocol(None, None, {}, None, mock.Mock()) + protocol = WebSocketProtocol(None, {}, None, mock.Mock()) masking_key = bytes([0x37, 0xfa, 0x21, 0x3d]) unmasked = protocol._unmask(b"", masking_key)