mirror of
https://github.com/frappe/gunicorn.git
synced 2026-07-01 10:11:30 +08:00
Fix WebSocket race condition in callback-based _read_exact()
Add double-check after clearing _data_event to prevent deadlock when data arrives between clear() and wait(). The race condition occurred when: 1. Task A checks buffer, needs more data 2. Task A clears _data_event 3. Task B (feed_data) sets event 4. Task A awaits on cleared event - deadlock The fix re-checks the buffer after clear() to catch data that arrived in the race window. Also adds tests for edge cases: race condition simulation, EOF during wait, fragmented message reassembly, and control frames during fragmentation.
This commit is contained in:
parent
f76a4942c3
commit
241c479701
@ -40,20 +40,22 @@ WS_GUID = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
||||
|
||||
|
||||
class WebSocketProtocol:
|
||||
"""WebSocket connection handler for ASGI applications."""
|
||||
"""WebSocket connection handler for ASGI applications.
|
||||
|
||||
def __init__(self, transport, reader, scope, app, log):
|
||||
Uses callback-based data feeding instead of StreamReader for efficiency.
|
||||
Data is fed via feed_data() from the parent protocol's data_received().
|
||||
"""
|
||||
|
||||
def __init__(self, transport, scope, app, log):
|
||||
"""Initialize WebSocket protocol handler.
|
||||
|
||||
Args:
|
||||
transport: asyncio transport for writing
|
||||
reader: asyncio StreamReader for reading
|
||||
scope: ASGI WebSocket scope dict
|
||||
app: ASGI application callable
|
||||
log: Logger instance
|
||||
"""
|
||||
self.transport = transport
|
||||
self.reader = reader
|
||||
self.scope = scope
|
||||
self.app = app
|
||||
self.log = log
|
||||
@ -70,6 +72,26 @@ class WebSocketProtocol:
|
||||
# Receive queue for incoming messages
|
||||
self._receive_queue = asyncio.Queue()
|
||||
|
||||
# Callback-based data reception (replaces StreamReader)
|
||||
self._buffer = bytearray()
|
||||
self._data_event = asyncio.Event()
|
||||
self._eof = False
|
||||
|
||||
def feed_data(self, data):
|
||||
"""Feed incoming data from the parent protocol's data_received().
|
||||
|
||||
Args:
|
||||
data: bytes received on the connection
|
||||
"""
|
||||
if data:
|
||||
self._buffer.extend(data)
|
||||
self._data_event.set()
|
||||
|
||||
def feed_eof(self):
|
||||
"""Signal that the connection has been closed."""
|
||||
self._eof = True
|
||||
self._data_event.set()
|
||||
|
||||
async def run(self):
|
||||
"""Run the WebSocket ASGI application."""
|
||||
# Send initial connect event
|
||||
@ -295,14 +317,25 @@ class WebSocketProtocol:
|
||||
return (opcode, payload)
|
||||
|
||||
async def _read_exact(self, n):
|
||||
"""Read exactly n bytes from the reader."""
|
||||
try:
|
||||
data = await self.reader.readexactly(n)
|
||||
return data
|
||||
except asyncio.IncompleteReadError:
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
"""Read exactly n bytes from internal buffer.
|
||||
|
||||
Waits for data via the callback-fed buffer instead of StreamReader.
|
||||
"""
|
||||
while len(self._buffer) < n:
|
||||
if self._eof:
|
||||
return None
|
||||
self._data_event.clear()
|
||||
# Critical: check buffer AGAIN after clearing to avoid race
|
||||
# condition where data arrives between clear() and wait()
|
||||
if len(self._buffer) >= n:
|
||||
break
|
||||
await self._data_event.wait()
|
||||
if self._eof and len(self._buffer) < n:
|
||||
return None
|
||||
|
||||
data = bytes(self._buffer[:n])
|
||||
del self._buffer[:n]
|
||||
return data
|
||||
|
||||
def _unmask(self, payload, masking_key):
|
||||
"""Unmask WebSocket payload data."""
|
||||
|
||||
@ -175,7 +175,7 @@ class TestWebSocketFrameMasking:
|
||||
def _create_protocol(self):
|
||||
"""Create a WebSocketProtocol instance for testing."""
|
||||
from gunicorn.asgi.websocket import WebSocketProtocol
|
||||
return WebSocketProtocol(None, None, {}, None, mock.Mock())
|
||||
return WebSocketProtocol(None, {}, None, mock.Mock())
|
||||
|
||||
def test_unmask_simple(self):
|
||||
"""Test basic unmasking operation."""
|
||||
@ -298,7 +298,6 @@ class TestWebSocketProtocolInstance:
|
||||
|
||||
return WebSocketProtocol(
|
||||
transport=mock.Mock(),
|
||||
reader=mock.Mock(),
|
||||
scope=scope,
|
||||
app=mock.AsyncMock(),
|
||||
log=mock.Mock(),
|
||||
@ -601,11 +600,9 @@ class TestWebSocketAsync:
|
||||
}
|
||||
|
||||
transport = mock.Mock()
|
||||
reader = mock.Mock()
|
||||
|
||||
return WebSocketProtocol(
|
||||
transport=transport,
|
||||
reader=reader,
|
||||
scope=scope,
|
||||
app=mock.AsyncMock(),
|
||||
log=mock.Mock(),
|
||||
@ -675,3 +672,367 @@ class TestWebSocketAsync:
|
||||
await protocol._send({"type": "websocket.close", "code": 1000})
|
||||
|
||||
assert protocol.closed is True
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Callback-based Data Feeding Tests
|
||||
# ============================================================================
|
||||
|
||||
class TestWebSocketCallbackDataFeeding:
|
||||
"""Tests for callback-based data feeding (replaces StreamReader)."""
|
||||
|
||||
def _create_protocol(self):
|
||||
"""Create a WebSocketProtocol instance for testing."""
|
||||
from gunicorn.asgi.websocket import WebSocketProtocol
|
||||
return WebSocketProtocol(
|
||||
transport=mock.Mock(),
|
||||
scope={"type": "websocket", "headers": []},
|
||||
app=mock.AsyncMock(),
|
||||
log=mock.Mock(),
|
||||
)
|
||||
|
||||
def test_initial_buffer_empty(self):
|
||||
"""Test that initial buffer is empty."""
|
||||
protocol = self._create_protocol()
|
||||
assert len(protocol._buffer) == 0
|
||||
assert protocol._eof is False
|
||||
|
||||
def test_feed_data_adds_to_buffer(self):
|
||||
"""Test that feed_data adds bytes to buffer."""
|
||||
protocol = self._create_protocol()
|
||||
|
||||
protocol.feed_data(b"Hello")
|
||||
assert bytes(protocol._buffer) == b"Hello"
|
||||
|
||||
protocol.feed_data(b" World")
|
||||
assert bytes(protocol._buffer) == b"Hello World"
|
||||
|
||||
def test_feed_data_ignores_empty(self):
|
||||
"""Test that feed_data ignores empty data."""
|
||||
protocol = self._create_protocol()
|
||||
|
||||
protocol.feed_data(b"")
|
||||
assert len(protocol._buffer) == 0
|
||||
|
||||
protocol.feed_data(None)
|
||||
# Should not raise, just be ignored
|
||||
|
||||
def test_feed_data_sets_event(self):
|
||||
"""Test that feed_data sets the data event."""
|
||||
protocol = self._create_protocol()
|
||||
|
||||
assert not protocol._data_event.is_set()
|
||||
protocol.feed_data(b"data")
|
||||
assert protocol._data_event.is_set()
|
||||
|
||||
def test_feed_eof_sets_flag(self):
|
||||
"""Test that feed_eof sets the EOF flag."""
|
||||
protocol = self._create_protocol()
|
||||
|
||||
assert protocol._eof is False
|
||||
protocol.feed_eof()
|
||||
assert protocol._eof is True
|
||||
|
||||
def test_feed_eof_sets_event(self):
|
||||
"""Test that feed_eof sets the data event."""
|
||||
protocol = self._create_protocol()
|
||||
|
||||
assert not protocol._data_event.is_set()
|
||||
protocol.feed_eof()
|
||||
assert protocol._data_event.is_set()
|
||||
|
||||
|
||||
class TestWebSocketReadExact:
|
||||
"""Tests for _read_exact method with callback-based buffer."""
|
||||
|
||||
def _create_protocol(self):
|
||||
"""Create a WebSocketProtocol instance for testing."""
|
||||
from gunicorn.asgi.websocket import WebSocketProtocol
|
||||
return WebSocketProtocol(
|
||||
transport=mock.Mock(),
|
||||
scope={"type": "websocket", "headers": []},
|
||||
app=mock.AsyncMock(),
|
||||
log=mock.Mock(),
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_exact_with_sufficient_data(self):
|
||||
"""Test _read_exact returns data when buffer has enough."""
|
||||
protocol = self._create_protocol()
|
||||
|
||||
# Pre-fill buffer
|
||||
protocol.feed_data(b"Hello World")
|
||||
|
||||
result = await protocol._read_exact(5)
|
||||
assert result == b"Hello"
|
||||
assert bytes(protocol._buffer) == b" World"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_exact_consumes_buffer(self):
|
||||
"""Test _read_exact properly consumes buffer."""
|
||||
protocol = self._create_protocol()
|
||||
|
||||
protocol.feed_data(b"ABCDEFGH")
|
||||
|
||||
result1 = await protocol._read_exact(3)
|
||||
assert result1 == b"ABC"
|
||||
|
||||
result2 = await protocol._read_exact(3)
|
||||
assert result2 == b"DEF"
|
||||
|
||||
assert bytes(protocol._buffer) == b"GH"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_exact_returns_none_on_eof(self):
|
||||
"""Test _read_exact returns None when EOF with insufficient data."""
|
||||
protocol = self._create_protocol()
|
||||
|
||||
protocol.feed_data(b"Hi")
|
||||
protocol.feed_eof()
|
||||
|
||||
# Request more data than available after EOF
|
||||
result = await protocol._read_exact(10)
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_exact_waits_for_data(self):
|
||||
"""Test _read_exact waits when buffer is insufficient."""
|
||||
import asyncio
|
||||
protocol = self._create_protocol()
|
||||
|
||||
# Start read that needs more data
|
||||
read_task = asyncio.create_task(protocol._read_exact(10))
|
||||
|
||||
# Give task a chance to start waiting
|
||||
await asyncio.sleep(0.01)
|
||||
assert not read_task.done()
|
||||
|
||||
# Feed enough data
|
||||
protocol.feed_data(b"1234567890")
|
||||
|
||||
result = await asyncio.wait_for(read_task, timeout=1.0)
|
||||
assert result == b"1234567890"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_exact_handles_incremental_data(self):
|
||||
"""Test _read_exact handles data arriving in chunks."""
|
||||
import asyncio
|
||||
protocol = self._create_protocol()
|
||||
|
||||
# Start read needing 10 bytes
|
||||
read_task = asyncio.create_task(protocol._read_exact(10))
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
# Feed data incrementally
|
||||
protocol.feed_data(b"123")
|
||||
await asyncio.sleep(0.01)
|
||||
assert not read_task.done()
|
||||
|
||||
protocol.feed_data(b"456")
|
||||
await asyncio.sleep(0.01)
|
||||
assert not read_task.done()
|
||||
|
||||
protocol.feed_data(b"7890")
|
||||
|
||||
result = await asyncio.wait_for(read_task, timeout=1.0)
|
||||
assert result == b"1234567890"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_exact_race_condition(self):
|
||||
"""Test _read_exact handles race condition when data arrives during clear/wait gap.
|
||||
|
||||
This tests the fix for the race condition where:
|
||||
1. Task A checks buffer, needs more data
|
||||
2. Task A clears _data_event
|
||||
3. Task B (data_received) calls feed_data(), sets event
|
||||
4. Task A would wait forever on cleared event - DEADLOCK
|
||||
|
||||
The fix adds a buffer check after clear() to catch this case.
|
||||
"""
|
||||
import asyncio
|
||||
protocol = self._create_protocol()
|
||||
|
||||
# Pre-fill with partial data
|
||||
protocol.feed_data(b"12345")
|
||||
|
||||
# Start read needing 10 bytes
|
||||
read_task = asyncio.create_task(protocol._read_exact(10))
|
||||
await asyncio.sleep(0.01)
|
||||
assert not read_task.done()
|
||||
|
||||
# Simulate race: feed remaining data rapidly
|
||||
# In the buggy version, if data arrives right after clear() but before wait(),
|
||||
# the event gets set then immediately the wait() would block on a stale clear
|
||||
protocol.feed_data(b"67890")
|
||||
|
||||
# Should complete without deadlock
|
||||
result = await asyncio.wait_for(read_task, timeout=1.0)
|
||||
assert result == b"1234567890"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_exact_multiple_feeds_before_wait(self):
|
||||
"""Test _read_exact when all data arrives before wait starts."""
|
||||
import asyncio
|
||||
protocol = self._create_protocol()
|
||||
|
||||
# Feed all data before starting read - should not block
|
||||
protocol.feed_data(b"Complete message here")
|
||||
|
||||
result = await asyncio.wait_for(protocol._read_exact(8), timeout=0.1)
|
||||
assert result == b"Complete"
|
||||
|
||||
# Buffer should have remainder
|
||||
assert bytes(protocol._buffer) == b" message here"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_exact_eof_during_wait(self):
|
||||
"""Test _read_exact handles EOF arriving while waiting for data."""
|
||||
import asyncio
|
||||
protocol = self._create_protocol()
|
||||
|
||||
# Start read needing more data than we'll provide
|
||||
read_task = asyncio.create_task(protocol._read_exact(100))
|
||||
|
||||
await asyncio.sleep(0.01)
|
||||
assert not read_task.done()
|
||||
|
||||
# Feed some data but not enough
|
||||
protocol.feed_data(b"partial")
|
||||
await asyncio.sleep(0.01)
|
||||
assert not read_task.done()
|
||||
|
||||
# Signal EOF - should cause read to return None
|
||||
protocol.feed_eof()
|
||||
|
||||
result = await asyncio.wait_for(read_task, timeout=1.0)
|
||||
assert result is None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# WebSocket Fragmented Message Tests (RFC 6455 Section 5.4)
|
||||
# ============================================================================
|
||||
|
||||
class TestWebSocketFragmentedMessages:
|
||||
"""Tests for WebSocket fragmented message handling."""
|
||||
|
||||
def _create_protocol(self):
|
||||
"""Create a WebSocketProtocol instance for testing."""
|
||||
from gunicorn.asgi.websocket import WebSocketProtocol
|
||||
return WebSocketProtocol(
|
||||
transport=mock.Mock(),
|
||||
scope={"type": "websocket", "headers": []},
|
||||
app=mock.AsyncMock(),
|
||||
log=mock.Mock(),
|
||||
)
|
||||
|
||||
def _create_masked_frame(self, fin, opcode, payload, mask_key=None):
|
||||
"""Create a masked WebSocket frame.
|
||||
|
||||
Args:
|
||||
fin: FIN bit (1 for final, 0 for continuation)
|
||||
opcode: Frame opcode
|
||||
payload: Frame payload bytes
|
||||
mask_key: 4-byte masking key (generated if None)
|
||||
|
||||
Returns:
|
||||
bytes: Complete masked frame
|
||||
"""
|
||||
if mask_key is None:
|
||||
mask_key = bytes([0x37, 0xfa, 0x21, 0x3d])
|
||||
|
||||
frame = bytearray()
|
||||
|
||||
# First byte: FIN + RSV(000) + opcode
|
||||
frame.append((fin << 7) | opcode)
|
||||
|
||||
# Second byte: MASK(1) + length
|
||||
length = len(payload)
|
||||
if length < 126:
|
||||
frame.append(0x80 | length)
|
||||
elif length < 65536:
|
||||
frame.append(0x80 | 126)
|
||||
frame.extend(struct.pack("!H", length))
|
||||
else:
|
||||
frame.append(0x80 | 127)
|
||||
frame.extend(struct.pack("!Q", length))
|
||||
|
||||
# Masking key
|
||||
frame.extend(mask_key)
|
||||
|
||||
# Masked payload
|
||||
masked_payload = bytes(b ^ mask_key[i % 4] for i, b in enumerate(payload))
|
||||
frame.extend(masked_payload)
|
||||
|
||||
return bytes(frame)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fragmented_message_reassembly(self):
|
||||
"""Test reassembly of fragmented text message with multiple continuation frames."""
|
||||
from gunicorn.asgi.websocket import (
|
||||
OPCODE_TEXT, OPCODE_CONTINUATION, OPCODE_CONTINUATION as CONT
|
||||
)
|
||||
import asyncio
|
||||
|
||||
protocol = self._create_protocol()
|
||||
|
||||
# Build fragmented message: "Hello" + " " + "World" + "!"
|
||||
# First frame: opcode=TEXT, FIN=0, payload="Hello"
|
||||
frame1 = self._create_masked_frame(fin=0, opcode=OPCODE_TEXT, payload=b"Hello")
|
||||
# Continuation frames: opcode=CONTINUATION, FIN=0
|
||||
frame2 = self._create_masked_frame(fin=0, opcode=CONT, payload=b" ")
|
||||
frame3 = self._create_masked_frame(fin=0, opcode=CONT, payload=b"World")
|
||||
# Final frame: opcode=CONTINUATION, FIN=1
|
||||
frame4 = self._create_masked_frame(fin=1, opcode=CONT, payload=b"!")
|
||||
|
||||
# Feed all frames
|
||||
protocol.feed_data(frame1 + frame2 + frame3 + frame4)
|
||||
|
||||
# Read frames - first 3 should return CONTINUATION with empty payload (waiting)
|
||||
result1 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0)
|
||||
assert result1 == (OPCODE_CONTINUATION, b"") # Fragment started
|
||||
|
||||
result2 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0)
|
||||
assert result2 == (OPCODE_CONTINUATION, b"") # Fragment continued
|
||||
|
||||
result3 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0)
|
||||
assert result3 == (OPCODE_CONTINUATION, b"") # Fragment continued
|
||||
|
||||
# Final frame should return complete reassembled message with original opcode
|
||||
result4 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0)
|
||||
assert result4 == (OPCODE_TEXT, b"Hello World!")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_control_frame_during_fragmentation(self):
|
||||
"""Test that control frames (ping) can arrive during fragmented message.
|
||||
|
||||
RFC 6455 Section 5.4: Control frames MAY be injected in the middle
|
||||
of a fragmented message.
|
||||
"""
|
||||
from gunicorn.asgi.websocket import (
|
||||
OPCODE_TEXT, OPCODE_CONTINUATION, OPCODE_PING
|
||||
)
|
||||
import asyncio
|
||||
|
||||
protocol = self._create_protocol()
|
||||
|
||||
# Start fragmented message
|
||||
frame1 = self._create_masked_frame(fin=0, opcode=OPCODE_TEXT, payload=b"Hello")
|
||||
# Ping frame in the middle (control frames are always FIN=1)
|
||||
ping_frame = self._create_masked_frame(fin=1, opcode=OPCODE_PING, payload=b"ping")
|
||||
# Continue and finish fragmented message
|
||||
frame2 = self._create_masked_frame(fin=1, opcode=OPCODE_CONTINUATION, payload=b" World")
|
||||
|
||||
protocol.feed_data(frame1 + ping_frame + frame2)
|
||||
|
||||
# First read: fragment started (waiting for more)
|
||||
result1 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0)
|
||||
assert result1 == (OPCODE_CONTINUATION, b"")
|
||||
|
||||
# Second read: ping frame (control frames handled separately)
|
||||
result2 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0)
|
||||
assert result2 == (OPCODE_PING, b"ping")
|
||||
|
||||
# Third read: complete reassembled message
|
||||
result3 = await asyncio.wait_for(protocol._read_frame(), timeout=1.0)
|
||||
assert result3 == (OPCODE_TEXT, b"Hello World")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user