Fix WebSocket race condition in callback-based _read_exact()

Add double-check after clearing _data_event to prevent deadlock when
data arrives between clear() and wait(). The race condition occurred
when:
1. Task A checks buffer, needs more data
2. Task A clears _data_event
3. Task B (feed_data) sets event
4. Task A awaits on cleared event - deadlock

The fix re-checks the buffer after clear() to catch data that arrived
in the race window.

Also adds tests for edge cases: race condition simulation, EOF during
wait, fragmented message reassembly, and control frames during
fragmentation.
This commit is contained in:
Benoit Chesneau 2026-03-23 13:08:57 +01:00
parent f76a4942c3
commit 241c479701
2 changed files with 410 additions and 16 deletions

View File

@ -40,20 +40,22 @@ WS_GUID = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
class WebSocketProtocol:
"""WebSocket connection handler for ASGI applications."""
"""WebSocket connection handler for ASGI applications.
def __init__(self, transport, reader, scope, app, log):
Uses callback-based data feeding instead of StreamReader for efficiency.
Data is fed via feed_data() from the parent protocol's data_received().
"""
def __init__(self, transport, scope, app, log):
"""Initialize WebSocket protocol handler.
Args:
transport: asyncio transport for writing
reader: asyncio StreamReader for reading
scope: ASGI WebSocket scope dict
app: ASGI application callable
log: Logger instance
"""
self.transport = transport
self.reader = reader
self.scope = scope
self.app = app
self.log = log
@ -70,6 +72,26 @@ class WebSocketProtocol:
# Receive queue for incoming messages
self._receive_queue = asyncio.Queue()
# Callback-based data reception (replaces StreamReader)
self._buffer = bytearray()
self._data_event = asyncio.Event()
self._eof = False
def feed_data(self, data):
"""Feed incoming data from the parent protocol's data_received().
Args:
data: bytes received on the connection
"""
if data:
self._buffer.extend(data)
self._data_event.set()
def feed_eof(self):
"""Signal that the connection has been closed."""
self._eof = True
self._data_event.set()
async def run(self):
"""Run the WebSocket ASGI application."""
# Send initial connect event
@ -295,14 +317,25 @@ class WebSocketProtocol:
return (opcode, payload)
async def _read_exact(self, n):
"""Read exactly n bytes from the reader."""
try:
data = await self.reader.readexactly(n)
return data
except asyncio.IncompleteReadError:
return None
except Exception:
return None
"""Read exactly n bytes from internal buffer.
Waits for data via the callback-fed buffer instead of StreamReader.
"""
while len(self._buffer) < n:
if self._eof:
return None
self._data_event.clear()
# Critical: check buffer AGAIN after clearing to avoid race
# condition where data arrives between clear() and wait()
if len(self._buffer) >= n:
break
await self._data_event.wait()
if self._eof and len(self._buffer) < n:
return None
data = bytes(self._buffer[:n])
del self._buffer[:n]
return data
def _unmask(self, payload, masking_key):
"""Unmask WebSocket payload data."""

View File

@ -175,7 +175,7 @@ class TestWebSocketFrameMasking:
def _create_protocol(self):
"""Create a WebSocketProtocol instance for testing."""
from gunicorn.asgi.websocket import WebSocketProtocol
return WebSocketProtocol(None, None, {}, None, mock.Mock())
return WebSocketProtocol(None, {}, None, mock.Mock())
def test_unmask_simple(self):
"""Test basic unmasking operation."""
@ -298,7 +298,6 @@ class TestWebSocketProtocolInstance:
return WebSocketProtocol(
transport=mock.Mock(),
reader=mock.Mock(),
scope=scope,
app=mock.AsyncMock(),
log=mock.Mock(),
@ -601,11 +600,9 @@ class TestWebSocketAsync:
}
transport = mock.Mock()
reader = mock.Mock()
return WebSocketProtocol(
transport=transport,
reader=reader,
scope=scope,
app=mock.AsyncMock(),
log=mock.Mock(),
@ -675,3 +672,367 @@ class TestWebSocketAsync:
await protocol._send({"type": "websocket.close", "code": 1000})
assert protocol.closed is True
# ============================================================================
# Callback-based Data Feeding Tests
# ============================================================================
class TestWebSocketCallbackDataFeeding:
"""Tests for callback-based data feeding (replaces StreamReader)."""
def _create_protocol(self):
"""Create a WebSocketProtocol instance for testing."""
from gunicorn.asgi.websocket import WebSocketProtocol
return WebSocketProtocol(
transport=mock.Mock(),
scope={"type": "websocket", "headers": []},
app=mock.AsyncMock(),
log=mock.Mock(),
)
def test_initial_buffer_empty(self):
"""Test that initial buffer is empty."""
protocol = self._create_protocol()
assert len(protocol._buffer) == 0
assert protocol._eof is False
def test_feed_data_adds_to_buffer(self):
"""Test that feed_data adds bytes to buffer."""
protocol = self._create_protocol()
protocol.feed_data(b"Hello")
assert bytes(protocol._buffer) == b"Hello"
protocol.feed_data(b" World")
assert bytes(protocol._buffer) == b"Hello World"
def test_feed_data_ignores_empty(self):
"""Test that feed_data ignores empty data."""
protocol = self._create_protocol()
protocol.feed_data(b"")
assert len(protocol._buffer) == 0
protocol.feed_data(None)
# Should not raise, just be ignored
def test_feed_data_sets_event(self):
"""Test that feed_data sets the data event."""
protocol = self._create_protocol()
assert not protocol._data_event.is_set()
protocol.feed_data(b"data")
assert protocol._data_event.is_set()
def test_feed_eof_sets_flag(self):
"""Test that feed_eof sets the EOF flag."""
protocol = self._create_protocol()
assert protocol._eof is False
protocol.feed_eof()
assert protocol._eof is True
def test_feed_eof_sets_event(self):
"""Test that feed_eof sets the data event."""
protocol = self._create_protocol()
assert not protocol._data_event.is_set()
protocol.feed_eof()
assert protocol._data_event.is_set()
class TestWebSocketReadExact:
"""Tests for _read_exact method with callback-based buffer."""
def _create_protocol(self):
"""Create a WebSocketProtocol instance for testing."""
from gunicorn.asgi.websocket import WebSocketProtocol
return WebSocketProtocol(
transport=mock.Mock(),
scope={"type": "websocket", "headers": []},
app=mock.AsyncMock(),
log=mock.Mock(),
)
@pytest.mark.asyncio
async def test_read_exact_with_sufficient_data(self):
"""Test _read_exact returns data when buffer has enough."""
protocol = self._create_protocol()
# Pre-fill buffer
protocol.feed_data(b"Hello World")
result = await protocol._read_exact(5)
assert result == b"Hello"
assert bytes(protocol._buffer) == b" World"
@pytest.mark.asyncio
async def test_read_exact_consumes_buffer(self):
"""Test _read_exact properly consumes buffer."""
protocol = self._create_protocol()
protocol.feed_data(b"ABCDEFGH")
result1 = await protocol._read_exact(3)
assert result1 == b"ABC"
result2 = await protocol._read_exact(3)
assert result2 == b"DEF"
assert bytes(protocol._buffer) == b"GH"
@pytest.mark.asyncio
async def test_read_exact_returns_none_on_eof(self):
"""Test _read_exact returns None when EOF with insufficient data."""
protocol = self._create_protocol()
protocol.feed_data(b"Hi")
protocol.feed_eof()
# Request more data than available after EOF
result = await protocol._read_exact(10)
assert result is None
@pytest.mark.asyncio
async def test_read_exact_waits_for_data(self):
"""Test _read_exact waits when buffer is insufficient."""
import asyncio
protocol = self._create_protocol()
# Start read that needs more data
read_task = asyncio.create_task(protocol._read_exact(10))
# Give task a chance to start waiting
await asyncio.sleep(0.01)
assert not read_task.done()
# Feed enough data
protocol.feed_data(b"1234567890")
result = await asyncio.wait_for(read_task, timeout=1.0)
assert result == b"1234567890"
@pytest.mark.asyncio
async def test_read_exact_handles_incremental_data(self):
"""Test _read_exact handles data arriving in chunks."""
import asyncio
protocol = self._create_protocol()
# Start read needing 10 bytes
read_task = asyncio.create_task(protocol._read_exact(10))
await asyncio.sleep(0.01)
# Feed data incrementally
protocol.feed_data(b"123")
await asyncio.sleep(0.01)
assert not read_task.done()
protocol.feed_data(b"456")
await asyncio.sleep(0.01)
assert not read_task.done()
protocol.feed_data(b"7890")
result = await asyncio.wait_for(read_task, timeout=1.0)
assert result == b"1234567890"
@pytest.mark.asyncio
async def test_read_exact_race_condition(self):
"""Test _read_exact handles race condition when data arrives during clear/wait gap.
This tests the fix for the race condition where:
1. Task A checks buffer, needs more data
2. Task A clears _data_event
3. Task B (data_received) calls feed_data(), sets event
4. Task A would wait forever on cleared event - DEADLOCK
The fix adds a buffer check after clear() to catch this case.
"""
import asyncio
protocol = self._create_protocol()
# Pre-fill with partial data
protocol.feed_data(b"12345")
# Start read needing 10 bytes
read_task = asyncio.create_task(protocol._read_exact(10))
await asyncio.sleep(0.01)
assert not read_task.done()
# Simulate race: feed remaining data rapidly
# In the buggy version, if data arrives right after clear() but before wait(),
# the event gets set then immediately the wait() would block on a stale clear
protocol.feed_data(b"67890")
# Should complete without deadlock
result = await asyncio.wait_for(read_task, timeout=1.0)
assert result == b"1234567890"
@pytest.mark.asyncio
async def test_read_exact_multiple_feeds_before_wait(self):
"""Test _read_exact when all data arrives before wait starts."""
import asyncio
protocol = self._create_protocol()
# Feed all data before starting read - should not block
protocol.feed_data(b"Complete message here")
result = await asyncio.wait_for(protocol._read_exact(8), timeout=0.1)
assert result == b"Complete"
# Buffer should have remainder
assert bytes(protocol._buffer) == b" message here"
@pytest.mark.asyncio
async def test_read_exact_eof_during_wait(self):
"""Test _read_exact handles EOF arriving while waiting for data."""
import asyncio
protocol = self._create_protocol()
# Start read needing more data than we'll provide
read_task = asyncio.create_task(protocol._read_exact(100))
await asyncio.sleep(0.01)
assert not read_task.done()
# Feed some data but not enough
protocol.feed_data(b"partial")
await asyncio.sleep(0.01)
assert not read_task.done()
# Signal EOF - should cause read to return None
protocol.feed_eof()
result = await asyncio.wait_for(read_task, timeout=1.0)
assert result is None
# ============================================================================
# WebSocket Fragmented Message Tests (RFC 6455 Section 5.4)
# ============================================================================
class TestWebSocketFragmentedMessages:
"""Tests for WebSocket fragmented message handling."""
def _create_protocol(self):
"""Create a WebSocketProtocol instance for testing."""
from gunicorn.asgi.websocket import WebSocketProtocol
return WebSocketProtocol(
transport=mock.Mock(),
scope={"type": "websocket", "headers": []},
app=mock.AsyncMock(),
log=mock.Mock(),
)
def _create_masked_frame(self, fin, opcode, payload, mask_key=None):
"""Create a masked WebSocket frame.
Args:
fin: FIN bit (1 for final, 0 for continuation)
opcode: Frame opcode
payload: Frame payload bytes
mask_key: 4-byte masking key (generated if None)
Returns:
bytes: Complete masked frame
"""
if mask_key is None:
mask_key = bytes([0x37, 0xfa, 0x21, 0x3d])
frame = bytearray()
# First byte: FIN + RSV(000) + opcode
frame.append((fin << 7) | opcode)
# Second byte: MASK(1) + length
length = len(payload)
if length < 126:
frame.append(0x80 | length)
elif length < 65536:
frame.append(0x80 | 126)
frame.extend(struct.pack("!H", length))
else:
frame.append(0x80 | 127)
frame.extend(struct.pack("!Q", length))
# Masking key
frame.extend(mask_key)
# Masked payload
masked_payload = bytes(b ^ mask_key[i % 4] for i, b in enumerate(payload))
frame.extend(masked_payload)
return bytes(frame)
@pytest.mark.asyncio
async def test_fragmented_message_reassembly(self):
"""Test reassembly of fragmented text message with multiple continuation frames."""
from gunicorn.asgi.websocket import (
OPCODE_TEXT, OPCODE_CONTINUATION, OPCODE_CONTINUATION as CONT
)
import asyncio
protocol = self._create_protocol()
# Build fragmented message: "Hello" + " " + "World" + "!"
# First frame: opcode=TEXT, FIN=0, payload="Hello"
frame1 = self._create_masked_frame(fin=0, opcode=OPCODE_TEXT, payload=b"Hello")
# Continuation frames: opcode=CONTINUATION, FIN=0
frame2 = self._create_masked_frame(fin=0, opcode=CONT, payload=b" ")
frame3 = self._create_masked_frame(fin=0, opcode=CONT, payload=b"World")
# Final frame: opcode=CONTINUATION, FIN=1
frame4 = self._create_masked_frame(fin=1, opcode=CONT, payload=b"!")
# Feed all frames
protocol.feed_data(frame1 + frame2 + frame3 + frame4)
# Read frames - first 3 should return CONTINUATION with empty payload (waiting)
result1 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0)
assert result1 == (OPCODE_CONTINUATION, b"") # Fragment started
result2 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0)
assert result2 == (OPCODE_CONTINUATION, b"") # Fragment continued
result3 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0)
assert result3 == (OPCODE_CONTINUATION, b"") # Fragment continued
# Final frame should return complete reassembled message with original opcode
result4 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0)
assert result4 == (OPCODE_TEXT, b"Hello World!")
@pytest.mark.asyncio
async def test_control_frame_during_fragmentation(self):
"""Test that control frames (ping) can arrive during fragmented message.
RFC 6455 Section 5.4: Control frames MAY be injected in the middle
of a fragmented message.
"""
from gunicorn.asgi.websocket import (
OPCODE_TEXT, OPCODE_CONTINUATION, OPCODE_PING
)
import asyncio
protocol = self._create_protocol()
# Start fragmented message
frame1 = self._create_masked_frame(fin=0, opcode=OPCODE_TEXT, payload=b"Hello")
# Ping frame in the middle (control frames are always FIN=1)
ping_frame = self._create_masked_frame(fin=1, opcode=OPCODE_PING, payload=b"ping")
# Continue and finish fragmented message
frame2 = self._create_masked_frame(fin=1, opcode=OPCODE_CONTINUATION, payload=b" World")
protocol.feed_data(frame1 + ping_frame + frame2)
# First read: fragment started (waiting for more)
result1 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0)
assert result1 == (OPCODE_CONTINUATION, b"")
# Second read: ping frame (control frames handled separately)
result2 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0)
assert result2 == (OPCODE_PING, b"ping")
# Third read: complete reassembled message
result3 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0)
assert result3 == (OPCODE_TEXT, b"Hello World")