Fix WebSocket and body receiver issues in ASGI protocol

- Fix body receiver timeout handling to prevent infinite loops
- Add WebSocket data forwarding via callbacks instead of StreamReader
- Fix HTTP/2 stream race condition where DATA frames arrive before first read
- Update WebSocketProtocol constructor (removed reader parameter)
This commit is contained in:
Benoit Chesneau 2026-03-23 13:38:47 +01:00
parent af8897a14c
commit f9ca296d21
3 changed files with 31 additions and 12 deletions

View File

@ -248,12 +248,10 @@ class BodyReceiver:
if self._chunks:
return self._pop_chunk()
if self._complete:
self._body_finished = True
return {"type": "http.request", "body": b"", "more_body": False}
# Timeout or other condition - return empty with more_body=True
return {"type": "http.request", "body": b"", "more_body": True}
# Complete OR timeout - mark body finished to prevent infinite loops
# Apps should not loop forever waiting for body that won't arrive
self._body_finished = True
return {"type": "http.request", "body": b"", "more_body": False}
async def _wait_for_data(self):
"""Wait for body data to arrive via callback."""
@ -321,6 +319,9 @@ class ASGIProtocol(asyncio.Protocol):
# Write flow control
self._flow_control = None
# WebSocket protocol (set during upgrade, receives data via callbacks)
self._websocket = None
def connection_made(self, transport):
"""Called when a connection is established."""
self.transport = transport
@ -454,6 +455,10 @@ class ASGIProtocol(asyncio.Protocol):
def data_received(self, data):
"""Called when data is received on the connection."""
if self._websocket:
# WebSocket path - forward to WebSocket protocol
self._websocket.feed_data(data)
return
if self.reader:
# HTTP/2 path - use StreamReader
self.reader.feed_data(data)
@ -545,6 +550,10 @@ class ASGIProtocol(asyncio.Protocol):
if self.reader:
self.reader.feed_eof()
# Signal EOF to WebSocket if active
if self._websocket:
self._websocket.feed_eof()
# Signal disconnect to the app via the body receiver
if self._body_receiver is not None:
self._body_receiver.signal_disconnect()
@ -750,10 +759,17 @@ class ASGIProtocol(asyncio.Protocol):
"""Handle WebSocket upgrade request."""
from gunicorn.asgi.websocket import WebSocketProtocol
# Stop callback parser - WebSocket uses its own data handling
self._callback_parser = None
scope = self._build_websocket_scope(request, sockname, peername)
ws_protocol = WebSocketProtocol(
self.transport, self.reader, scope, self.app, self.log
self.transport, scope, self.app, self.log
)
# Store reference so data_received() forwards to WebSocket
self._websocket = ws_protocol
await ws_protocol.run()
async def _handle_http_request(self, request, sockname, peername):
@ -772,9 +788,8 @@ class ASGIProtocol(asyncio.Protocol):
response_headers = []
response_sent = 0
# Create body receiver - reads directly on demand, no Queue/Task overhead
body_receiver = BodyReceiver(request, self)
self._body_receiver = body_receiver
# Use body receiver created in _on_headers_complete (receives data via callbacks)
body_receiver = self._body_receiver
async def send(message):
nonlocal response_started, response_complete, exc_to_raise

View File

@ -309,6 +309,10 @@ class HTTP2Stream:
# Initialize event lazily (avoids event loop issues at construction)
if self._body_event is None:
self._body_event = asyncio.Event()
# If data already arrived before event existed, set it now
# This prevents race where DATA frames arrive before first read
if self._body_chunks or self._body_complete:
self._body_event.set()
while True:
# Return chunk if available

View File

@ -389,7 +389,7 @@ class TestWebSocketProtocol:
from gunicorn.asgi.websocket import WebSocketProtocol
# Create a minimal protocol instance
protocol = WebSocketProtocol(None, None, {}, None, mock.Mock())
protocol = WebSocketProtocol(None, {}, None, mock.Mock())
# Test unmasking (XOR operation)
masking_key = bytes([0x37, 0xfa, 0x21, 0x3d])
@ -402,7 +402,7 @@ class TestWebSocketProtocol:
"""Test WebSocket frame unmasking with empty payload."""
from gunicorn.asgi.websocket import WebSocketProtocol
protocol = WebSocketProtocol(None, None, {}, None, mock.Mock())
protocol = WebSocketProtocol(None, {}, None, mock.Mock())
masking_key = bytes([0x37, 0xfa, 0x21, 0x3d])
unmasked = protocol._unmask(b"", masking_key)