From 241c479701859c00cfaff631cc35951c254ba491 Mon Sep 17 00:00:00 2001 From: Benoit Chesneau Date: Mon, 23 Mar 2026 13:08:57 +0100 Subject: [PATCH] Fix WebSocket race condition in callback-based _read_exact() Add double-check after clearing _data_event to prevent deadlock when data arrives between clear() and wait(). The race condition occurred when: 1. Task A checks buffer, needs more data 2. Task A clears _data_event 3. Task B (feed_data) sets event 4. Task A awaits on cleared event - deadlock The fix re-checks the buffer after clear() to catch data that arrived in the race window. Also adds tests for edge cases: race condition simulation, EOF during wait, fragmented message reassembly, and control frames during fragmentation. --- gunicorn/asgi/websocket.py | 57 +++- tests/test_asgi_websocket_protocol.py | 369 +++++++++++++++++++++++++- 2 files changed, 410 insertions(+), 16 deletions(-) diff --git a/gunicorn/asgi/websocket.py b/gunicorn/asgi/websocket.py index 737268b6..d1b2251b 100644 --- a/gunicorn/asgi/websocket.py +++ b/gunicorn/asgi/websocket.py @@ -40,20 +40,22 @@ WS_GUID = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" class WebSocketProtocol: - """WebSocket connection handler for ASGI applications.""" + """WebSocket connection handler for ASGI applications. - def __init__(self, transport, reader, scope, app, log): + Uses callback-based data feeding instead of StreamReader for efficiency. + Data is fed via feed_data() from the parent protocol's data_received(). + """ + + def __init__(self, transport, scope, app, log): """Initialize WebSocket protocol handler. Args: transport: asyncio transport for writing - reader: asyncio StreamReader for reading scope: ASGI WebSocket scope dict app: ASGI application callable log: Logger instance """ self.transport = transport - self.reader = reader self.scope = scope self.app = app self.log = log @@ -70,6 +72,26 @@ class WebSocketProtocol: # Receive queue for incoming messages self._receive_queue = asyncio.Queue() + # Callback-based data reception (replaces StreamReader) + self._buffer = bytearray() + self._data_event = asyncio.Event() + self._eof = False + + def feed_data(self, data): + """Feed incoming data from the parent protocol's data_received(). + + Args: + data: bytes received on the connection + """ + if data: + self._buffer.extend(data) + self._data_event.set() + + def feed_eof(self): + """Signal that the connection has been closed.""" + self._eof = True + self._data_event.set() + async def run(self): """Run the WebSocket ASGI application.""" # Send initial connect event @@ -295,14 +317,25 @@ class WebSocketProtocol: return (opcode, payload) async def _read_exact(self, n): - """Read exactly n bytes from the reader.""" - try: - data = await self.reader.readexactly(n) - return data - except asyncio.IncompleteReadError: - return None - except Exception: - return None + """Read exactly n bytes from internal buffer. + + Waits for data via the callback-fed buffer instead of StreamReader. + """ + while len(self._buffer) < n: + if self._eof: + return None + self._data_event.clear() + # Critical: check buffer AGAIN after clearing to avoid race + # condition where data arrives between clear() and wait() + if len(self._buffer) >= n: + break + await self._data_event.wait() + if self._eof and len(self._buffer) < n: + return None + + data = bytes(self._buffer[:n]) + del self._buffer[:n] + return data def _unmask(self, payload, masking_key): """Unmask WebSocket payload data.""" diff --git a/tests/test_asgi_websocket_protocol.py b/tests/test_asgi_websocket_protocol.py index 2f339961..08db5866 100644 --- a/tests/test_asgi_websocket_protocol.py +++ b/tests/test_asgi_websocket_protocol.py @@ -175,7 +175,7 @@ class TestWebSocketFrameMasking: def _create_protocol(self): """Create a WebSocketProtocol instance for testing.""" from gunicorn.asgi.websocket import WebSocketProtocol - return WebSocketProtocol(None, None, {}, None, mock.Mock()) + return WebSocketProtocol(None, {}, None, mock.Mock()) def test_unmask_simple(self): """Test basic unmasking operation.""" @@ -298,7 +298,6 @@ class TestWebSocketProtocolInstance: return WebSocketProtocol( transport=mock.Mock(), - reader=mock.Mock(), scope=scope, app=mock.AsyncMock(), log=mock.Mock(), @@ -601,11 +600,9 @@ class TestWebSocketAsync: } transport = mock.Mock() - reader = mock.Mock() return WebSocketProtocol( transport=transport, - reader=reader, scope=scope, app=mock.AsyncMock(), log=mock.Mock(), @@ -675,3 +672,367 @@ class TestWebSocketAsync: await protocol._send({"type": "websocket.close", "code": 1000}) assert protocol.closed is True + + +# ============================================================================ +# Callback-based Data Feeding Tests +# ============================================================================ + +class TestWebSocketCallbackDataFeeding: + """Tests for callback-based data feeding (replaces StreamReader).""" + + def _create_protocol(self): + """Create a WebSocketProtocol instance for testing.""" + from gunicorn.asgi.websocket import WebSocketProtocol + return WebSocketProtocol( + transport=mock.Mock(), + scope={"type": "websocket", "headers": []}, + app=mock.AsyncMock(), + log=mock.Mock(), + ) + + def test_initial_buffer_empty(self): + """Test that initial buffer is empty.""" + protocol = self._create_protocol() + assert len(protocol._buffer) == 0 + assert protocol._eof is False + + def test_feed_data_adds_to_buffer(self): + """Test that feed_data adds bytes to buffer.""" + protocol = self._create_protocol() + + protocol.feed_data(b"Hello") + assert bytes(protocol._buffer) == b"Hello" + + protocol.feed_data(b" World") + assert bytes(protocol._buffer) == b"Hello World" + + def test_feed_data_ignores_empty(self): + """Test that feed_data ignores empty data.""" + protocol = self._create_protocol() + + protocol.feed_data(b"") + assert len(protocol._buffer) == 0 + + protocol.feed_data(None) + # Should not raise, just be ignored + + def test_feed_data_sets_event(self): + """Test that feed_data sets the data event.""" + protocol = self._create_protocol() + + assert not protocol._data_event.is_set() + protocol.feed_data(b"data") + assert protocol._data_event.is_set() + + def test_feed_eof_sets_flag(self): + """Test that feed_eof sets the EOF flag.""" + protocol = self._create_protocol() + + assert protocol._eof is False + protocol.feed_eof() + assert protocol._eof is True + + def test_feed_eof_sets_event(self): + """Test that feed_eof sets the data event.""" + protocol = self._create_protocol() + + assert not protocol._data_event.is_set() + protocol.feed_eof() + assert protocol._data_event.is_set() + + +class TestWebSocketReadExact: + """Tests for _read_exact method with callback-based buffer.""" + + def _create_protocol(self): + """Create a WebSocketProtocol instance for testing.""" + from gunicorn.asgi.websocket import WebSocketProtocol + return WebSocketProtocol( + transport=mock.Mock(), + scope={"type": "websocket", "headers": []}, + app=mock.AsyncMock(), + log=mock.Mock(), + ) + + @pytest.mark.asyncio + async def test_read_exact_with_sufficient_data(self): + """Test _read_exact returns data when buffer has enough.""" + protocol = self._create_protocol() + + # Pre-fill buffer + protocol.feed_data(b"Hello World") + + result = await protocol._read_exact(5) + assert result == b"Hello" + assert bytes(protocol._buffer) == b" World" + + @pytest.mark.asyncio + async def test_read_exact_consumes_buffer(self): + """Test _read_exact properly consumes buffer.""" + protocol = self._create_protocol() + + protocol.feed_data(b"ABCDEFGH") + + result1 = await protocol._read_exact(3) + assert result1 == b"ABC" + + result2 = await protocol._read_exact(3) + assert result2 == b"DEF" + + assert bytes(protocol._buffer) == b"GH" + + @pytest.mark.asyncio + async def test_read_exact_returns_none_on_eof(self): + """Test _read_exact returns None when EOF with insufficient data.""" + protocol = self._create_protocol() + + protocol.feed_data(b"Hi") + protocol.feed_eof() + + # Request more data than available after EOF + result = await protocol._read_exact(10) + assert result is None + + @pytest.mark.asyncio + async def test_read_exact_waits_for_data(self): + """Test _read_exact waits when buffer is insufficient.""" + import asyncio + protocol = self._create_protocol() + + # Start read that needs more data + read_task = asyncio.create_task(protocol._read_exact(10)) + + # Give task a chance to start waiting + await asyncio.sleep(0.01) + assert not read_task.done() + + # Feed enough data + protocol.feed_data(b"1234567890") + + result = await asyncio.wait_for(read_task, timeout=1.0) + assert result == b"1234567890" + + @pytest.mark.asyncio + async def test_read_exact_handles_incremental_data(self): + """Test _read_exact handles data arriving in chunks.""" + import asyncio + protocol = self._create_protocol() + + # Start read needing 10 bytes + read_task = asyncio.create_task(protocol._read_exact(10)) + + await asyncio.sleep(0.01) + + # Feed data incrementally + protocol.feed_data(b"123") + await asyncio.sleep(0.01) + assert not read_task.done() + + protocol.feed_data(b"456") + await asyncio.sleep(0.01) + assert not read_task.done() + + protocol.feed_data(b"7890") + + result = await asyncio.wait_for(read_task, timeout=1.0) + assert result == b"1234567890" + + @pytest.mark.asyncio + async def test_read_exact_race_condition(self): + """Test _read_exact handles race condition when data arrives during clear/wait gap. + + This tests the fix for the race condition where: + 1. Task A checks buffer, needs more data + 2. Task A clears _data_event + 3. Task B (data_received) calls feed_data(), sets event + 4. Task A would wait forever on cleared event - DEADLOCK + + The fix adds a buffer check after clear() to catch this case. + """ + import asyncio + protocol = self._create_protocol() + + # Pre-fill with partial data + protocol.feed_data(b"12345") + + # Start read needing 10 bytes + read_task = asyncio.create_task(protocol._read_exact(10)) + await asyncio.sleep(0.01) + assert not read_task.done() + + # Simulate race: feed remaining data rapidly + # In the buggy version, if data arrives right after clear() but before wait(), + # the event gets set then immediately the wait() would block on a stale clear + protocol.feed_data(b"67890") + + # Should complete without deadlock + result = await asyncio.wait_for(read_task, timeout=1.0) + assert result == b"1234567890" + + @pytest.mark.asyncio + async def test_read_exact_multiple_feeds_before_wait(self): + """Test _read_exact when all data arrives before wait starts.""" + import asyncio + protocol = self._create_protocol() + + # Feed all data before starting read - should not block + protocol.feed_data(b"Complete message here") + + result = await asyncio.wait_for(protocol._read_exact(8), timeout=0.1) + assert result == b"Complete" + + # Buffer should have remainder + assert bytes(protocol._buffer) == b" message here" + + @pytest.mark.asyncio + async def test_read_exact_eof_during_wait(self): + """Test _read_exact handles EOF arriving while waiting for data.""" + import asyncio + protocol = self._create_protocol() + + # Start read needing more data than we'll provide + read_task = asyncio.create_task(protocol._read_exact(100)) + + await asyncio.sleep(0.01) + assert not read_task.done() + + # Feed some data but not enough + protocol.feed_data(b"partial") + await asyncio.sleep(0.01) + assert not read_task.done() + + # Signal EOF - should cause read to return None + protocol.feed_eof() + + result = await asyncio.wait_for(read_task, timeout=1.0) + assert result is None + + +# ============================================================================ +# WebSocket Fragmented Message Tests (RFC 6455 Section 5.4) +# ============================================================================ + +class TestWebSocketFragmentedMessages: + """Tests for WebSocket fragmented message handling.""" + + def _create_protocol(self): + """Create a WebSocketProtocol instance for testing.""" + from gunicorn.asgi.websocket import WebSocketProtocol + return WebSocketProtocol( + transport=mock.Mock(), + scope={"type": "websocket", "headers": []}, + app=mock.AsyncMock(), + log=mock.Mock(), + ) + + def _create_masked_frame(self, fin, opcode, payload, mask_key=None): + """Create a masked WebSocket frame. + + Args: + fin: FIN bit (1 for final, 0 for continuation) + opcode: Frame opcode + payload: Frame payload bytes + mask_key: 4-byte masking key (generated if None) + + Returns: + bytes: Complete masked frame + """ + if mask_key is None: + mask_key = bytes([0x37, 0xfa, 0x21, 0x3d]) + + frame = bytearray() + + # First byte: FIN + RSV(000) + opcode + frame.append((fin << 7) | opcode) + + # Second byte: MASK(1) + length + length = len(payload) + if length < 126: + frame.append(0x80 | length) + elif length < 65536: + frame.append(0x80 | 126) + frame.extend(struct.pack("!H", length)) + else: + frame.append(0x80 | 127) + frame.extend(struct.pack("!Q", length)) + + # Masking key + frame.extend(mask_key) + + # Masked payload + masked_payload = bytes(b ^ mask_key[i % 4] for i, b in enumerate(payload)) + frame.extend(masked_payload) + + return bytes(frame) + + @pytest.mark.asyncio + async def test_fragmented_message_reassembly(self): + """Test reassembly of fragmented text message with multiple continuation frames.""" + from gunicorn.asgi.websocket import ( + OPCODE_TEXT, OPCODE_CONTINUATION, OPCODE_CONTINUATION as CONT + ) + import asyncio + + protocol = self._create_protocol() + + # Build fragmented message: "Hello" + " " + "World" + "!" + # First frame: opcode=TEXT, FIN=0, payload="Hello" + frame1 = self._create_masked_frame(fin=0, opcode=OPCODE_TEXT, payload=b"Hello") + # Continuation frames: opcode=CONTINUATION, FIN=0 + frame2 = self._create_masked_frame(fin=0, opcode=CONT, payload=b" ") + frame3 = self._create_masked_frame(fin=0, opcode=CONT, payload=b"World") + # Final frame: opcode=CONTINUATION, FIN=1 + frame4 = self._create_masked_frame(fin=1, opcode=CONT, payload=b"!") + + # Feed all frames + protocol.feed_data(frame1 + frame2 + frame3 + frame4) + + # Read frames - first 3 should return CONTINUATION with empty payload (waiting) + result1 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0) + assert result1 == (OPCODE_CONTINUATION, b"") # Fragment started + + result2 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0) + assert result2 == (OPCODE_CONTINUATION, b"") # Fragment continued + + result3 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0) + assert result3 == (OPCODE_CONTINUATION, b"") # Fragment continued + + # Final frame should return complete reassembled message with original opcode + result4 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0) + assert result4 == (OPCODE_TEXT, b"Hello World!") + + @pytest.mark.asyncio + async def test_control_frame_during_fragmentation(self): + """Test that control frames (ping) can arrive during fragmented message. + + RFC 6455 Section 5.4: Control frames MAY be injected in the middle + of a fragmented message. + """ + from gunicorn.asgi.websocket import ( + OPCODE_TEXT, OPCODE_CONTINUATION, OPCODE_PING + ) + import asyncio + + protocol = self._create_protocol() + + # Start fragmented message + frame1 = self._create_masked_frame(fin=0, opcode=OPCODE_TEXT, payload=b"Hello") + # Ping frame in the middle (control frames are always FIN=1) + ping_frame = self._create_masked_frame(fin=1, opcode=OPCODE_PING, payload=b"ping") + # Continue and finish fragmented message + frame2 = self._create_masked_frame(fin=1, opcode=OPCODE_CONTINUATION, payload=b" World") + + protocol.feed_data(frame1 + ping_frame + frame2) + + # First read: fragment started (waiting for more) + result1 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0) + assert result1 == (OPCODE_CONTINUATION, b"") + + # Second read: ping frame (control frames handled separately) + result2 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0) + assert result2 == (OPCODE_PING, b"ping") + + # Third read: complete reassembled message + result3 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0) + assert result3 == (OPCODE_TEXT, b"Hello World")