Fix WebSocket and body receiver issues in ASGI protocol

- Fix body receiver timeout handling to prevent infinite loops
- Add WebSocket data forwarding via callbacks instead of StreamReader
- Fix HTTP/2 stream race condition where DATA frames arrive before first read
- Update WebSocketProtocol constructor (removed reader parameter)
This commit is contained in:
Benoit Chesneau 2026-03-23 13:38:47 +01:00
parent af8897a14c
commit f9ca296d21
3 changed files with 31 additions and 12 deletions

View File

@ -248,12 +248,10 @@ class BodyReceiver:
if self._chunks: if self._chunks:
return self._pop_chunk() return self._pop_chunk()
if self._complete: # Complete OR timeout - mark body finished to prevent infinite loops
self._body_finished = True # Apps should not loop forever waiting for body that won't arrive
return {"type": "http.request", "body": b"", "more_body": False} self._body_finished = True
return {"type": "http.request", "body": b"", "more_body": False}
# Timeout or other condition - return empty with more_body=True
return {"type": "http.request", "body": b"", "more_body": True}
async def _wait_for_data(self): async def _wait_for_data(self):
"""Wait for body data to arrive via callback.""" """Wait for body data to arrive via callback."""
@ -321,6 +319,9 @@ class ASGIProtocol(asyncio.Protocol):
# Write flow control # Write flow control
self._flow_control = None self._flow_control = None
# WebSocket protocol (set during upgrade, receives data via callbacks)
self._websocket = None
def connection_made(self, transport): def connection_made(self, transport):
"""Called when a connection is established.""" """Called when a connection is established."""
self.transport = transport self.transport = transport
@ -454,6 +455,10 @@ class ASGIProtocol(asyncio.Protocol):
def data_received(self, data): def data_received(self, data):
"""Called when data is received on the connection.""" """Called when data is received on the connection."""
if self._websocket:
# WebSocket path - forward to WebSocket protocol
self._websocket.feed_data(data)
return
if self.reader: if self.reader:
# HTTP/2 path - use StreamReader # HTTP/2 path - use StreamReader
self.reader.feed_data(data) self.reader.feed_data(data)
@ -545,6 +550,10 @@ class ASGIProtocol(asyncio.Protocol):
if self.reader: if self.reader:
self.reader.feed_eof() self.reader.feed_eof()
# Signal EOF to WebSocket if active
if self._websocket:
self._websocket.feed_eof()
# Signal disconnect to the app via the body receiver # Signal disconnect to the app via the body receiver
if self._body_receiver is not None: if self._body_receiver is not None:
self._body_receiver.signal_disconnect() self._body_receiver.signal_disconnect()
@ -750,10 +759,17 @@ class ASGIProtocol(asyncio.Protocol):
"""Handle WebSocket upgrade request.""" """Handle WebSocket upgrade request."""
from gunicorn.asgi.websocket import WebSocketProtocol from gunicorn.asgi.websocket import WebSocketProtocol
# Stop callback parser - WebSocket uses its own data handling
self._callback_parser = None
scope = self._build_websocket_scope(request, sockname, peername) scope = self._build_websocket_scope(request, sockname, peername)
ws_protocol = WebSocketProtocol( ws_protocol = WebSocketProtocol(
self.transport, self.reader, scope, self.app, self.log self.transport, scope, self.app, self.log
) )
# Store reference so data_received() forwards to WebSocket
self._websocket = ws_protocol
await ws_protocol.run() await ws_protocol.run()
async def _handle_http_request(self, request, sockname, peername): async def _handle_http_request(self, request, sockname, peername):
@ -772,9 +788,8 @@ class ASGIProtocol(asyncio.Protocol):
response_headers = [] response_headers = []
response_sent = 0 response_sent = 0
# Create body receiver - reads directly on demand, no Queue/Task overhead # Use body receiver created in _on_headers_complete (receives data via callbacks)
body_receiver = BodyReceiver(request, self) body_receiver = self._body_receiver
self._body_receiver = body_receiver
async def send(message): async def send(message):
nonlocal response_started, response_complete, exc_to_raise nonlocal response_started, response_complete, exc_to_raise

View File

@ -309,6 +309,10 @@ class HTTP2Stream:
# Initialize event lazily (avoids event loop issues at construction) # Initialize event lazily (avoids event loop issues at construction)
if self._body_event is None: if self._body_event is None:
self._body_event = asyncio.Event() self._body_event = asyncio.Event()
# If data already arrived before event existed, set it now
# This prevents race where DATA frames arrive before first read
if self._body_chunks or self._body_complete:
self._body_event.set()
while True: while True:
# Return chunk if available # Return chunk if available

View File

@ -389,7 +389,7 @@ class TestWebSocketProtocol:
from gunicorn.asgi.websocket import WebSocketProtocol from gunicorn.asgi.websocket import WebSocketProtocol
# Create a minimal protocol instance # Create a minimal protocol instance
protocol = WebSocketProtocol(None, None, {}, None, mock.Mock()) protocol = WebSocketProtocol(None, {}, None, mock.Mock())
# Test unmasking (XOR operation) # Test unmasking (XOR operation)
masking_key = bytes([0x37, 0xfa, 0x21, 0x3d]) masking_key = bytes([0x37, 0xfa, 0x21, 0x3d])
@ -402,7 +402,7 @@ class TestWebSocketProtocol:
"""Test WebSocket frame unmasking with empty payload.""" """Test WebSocket frame unmasking with empty payload."""
from gunicorn.asgi.websocket import WebSocketProtocol from gunicorn.asgi.websocket import WebSocketProtocol
protocol = WebSocketProtocol(None, None, {}, None, mock.Mock()) protocol = WebSocketProtocol(None, {}, None, mock.Mock())
masking_key = bytes([0x37, 0xfa, 0x21, 0x3d]) masking_key = bytes([0x37, 0xfa, 0x21, 0x3d])
unmasked = protocol._unmask(b"", masking_key) unmasked = protocol._unmask(b"", masking_key)