mirror of
https://github.com/frappe/gunicorn.git
synced 2026-07-01 18:21:30 +08:00
- Fix path expectation in test_scope_path_preserved (router strips /http prefix) - Fix lifespan state check to use scope_state instead of module_state - Add tolerance for partial failures in proxy concurrent test - Add retry logic with proper assertions in HTTPS proxy FastAPI test
679 lines
23 KiB
Python
679 lines
23 KiB
Python
#
|
|
# This file is part of gunicorn released under the MIT license.
|
|
# See the NOTICE for more information.
|
|
|
|
"""
|
|
WebSocket RFC 6455 compliance tests.
|
|
|
|
Tests that gunicorn's WebSocket implementation conforms to RFC 6455:
|
|
https://tools.ietf.org/html/rfc6455
|
|
"""
|
|
|
|
import asyncio
|
|
import base64
|
|
import hashlib
|
|
import struct
|
|
from unittest import mock
|
|
|
|
import pytest
|
|
|
|
|
|
# ============================================================================
|
|
# WebSocket Constants Tests
|
|
# ============================================================================
|
|
|
|
class TestWebSocketConstants:
    """Pin the protocol constants exported by ``gunicorn.asgi.websocket``.

    Every test imports a single constant inside the test body and compares
    it with the literal value the test expects.
    """

    def test_websocket_guid(self):
        """The handshake GUID equals the fixed RFC 6455 Section 1.3 value."""
        from gunicorn.asgi.websocket import WS_GUID as guid

        # RFC 6455 hard-codes this value; it never varies per connection.
        rfc_guid = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
        assert guid == rfc_guid

    def test_opcode_continuation(self):
        """Continuation frames use opcode 0x0."""
        from gunicorn.asgi.websocket import OPCODE_CONTINUATION as opcode

        assert opcode == 0x0

    def test_opcode_text(self):
        """Text frames use opcode 0x1."""
        from gunicorn.asgi.websocket import OPCODE_TEXT as opcode

        assert opcode == 0x1

    def test_opcode_binary(self):
        """Binary frames use opcode 0x2."""
        from gunicorn.asgi.websocket import OPCODE_BINARY as opcode

        assert opcode == 0x2

    def test_opcode_close(self):
        """Close frames use opcode 0x8."""
        from gunicorn.asgi.websocket import OPCODE_CLOSE as opcode

        assert opcode == 0x8

    def test_opcode_ping(self):
        """Ping frames use opcode 0x9."""
        from gunicorn.asgi.websocket import OPCODE_PING as opcode

        assert opcode == 0x9

    def test_opcode_pong(self):
        """Pong frames use opcode 0xA."""
        from gunicorn.asgi.websocket import OPCODE_PONG as opcode

        assert opcode == 0xA
|
|
|
|
|
|
# ============================================================================
|
|
# WebSocket Close Codes Tests (RFC 6455 Section 7.4.1)
|
|
# ============================================================================
|
|
|
|
class TestWebSocketCloseCodes:
    """Pin the close status codes exported by ``gunicorn.asgi.websocket``.

    The expected numbers are the RFC 6455 Section 7.4.1 status codes; each
    test imports one constant inside the test body and compares it.
    """

    def test_close_normal(self):
        """1000: normal closure."""
        from gunicorn.asgi.websocket import CLOSE_NORMAL as status

        assert status == 1000

    def test_close_going_away(self):
        """1001: going away."""
        from gunicorn.asgi.websocket import CLOSE_GOING_AWAY as status

        assert status == 1001

    def test_close_protocol_error(self):
        """1002: protocol error."""
        from gunicorn.asgi.websocket import CLOSE_PROTOCOL_ERROR as status

        assert status == 1002

    def test_close_unsupported(self):
        """1003: unsupported data."""
        from gunicorn.asgi.websocket import CLOSE_UNSUPPORTED as status

        assert status == 1003

    def test_close_no_status(self):
        """1005: no status received."""
        from gunicorn.asgi.websocket import CLOSE_NO_STATUS as status

        assert status == 1005

    def test_close_abnormal(self):
        """1006: abnormal closure."""
        from gunicorn.asgi.websocket import CLOSE_ABNORMAL as status

        assert status == 1006

    def test_close_invalid_data(self):
        """1007: invalid frame payload data."""
        from gunicorn.asgi.websocket import CLOSE_INVALID_DATA as status

        assert status == 1007

    def test_close_policy_violation(self):
        """1008: policy violation."""
        from gunicorn.asgi.websocket import CLOSE_POLICY_VIOLATION as status

        assert status == 1008

    def test_close_message_too_big(self):
        """1009: message too big."""
        from gunicorn.asgi.websocket import CLOSE_MESSAGE_TOO_BIG as status

        assert status == 1009

    def test_close_mandatory_ext(self):
        """1010: mandatory extension."""
        from gunicorn.asgi.websocket import CLOSE_MANDATORY_EXT as status

        assert status == 1010

    def test_close_internal_error(self):
        """1011: internal server error."""
        from gunicorn.asgi.websocket import CLOSE_INTERNAL_ERROR as status

        assert status == 1011
|
|
|
|
|
|
# ============================================================================
|
|
# WebSocket Handshake Tests (RFC 6455 Section 4.2.2)
|
|
# ============================================================================
|
|
|
|
class TestWebSocketHandshake:
    """Tests for the Sec-WebSocket-Accept computation.

    The accept value is ``Base64(SHA-1(client_key + WS_GUID))``. Both tests
    feed gunicorn's ``WS_GUID`` into that formula and compare the result with
    a published known answer, so a wrong GUID fails either test.
    """

    @staticmethod
    def _accept_key(client_key, guid):
        """Return the accept value for *client_key* as an ASCII ``str``.

        Args:
            client_key: The ``Sec-WebSocket-Key`` header value, as bytes.
            guid: The GUID appended to the key before hashing, as bytes.

        Returns:
            The base64 text of the 20-byte SHA-1 digest (28 characters).
        """
        digest = hashlib.sha1(client_key + guid).digest()
        return base64.b64encode(digest).decode("ascii")

    def test_accept_key_calculation(self):
        """Test Sec-WebSocket-Accept key calculation per RFC 6455."""
        from gunicorn.asgi.websocket import WS_GUID

        # Example from RFC 6455 Section 1.3
        client_key = b"dGhlIHNhbXBsZSBub25jZQ=="
        expected_accept = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="

        assert self._accept_key(client_key, WS_GUID) == expected_accept

    def test_accept_key_another_example(self):
        """Test accept key calculation with a second known key/answer pair."""
        from gunicorn.asgi.websocket import WS_GUID

        # Widely published handshake sample (e.g. the Wikipedia WebSocket
        # article): this key must map to exactly this accept value.
        client_key = b"x3JJHMbDL1EzLkh9GBhXDw=="
        expected_accept = "HSmrc0sMlYUkAGmm5OPpG2HaGWk="

        accept_key = self._accept_key(client_key, WS_GUID)

        # Structural checks: 20-byte SHA-1 digest -> 28 base64 characters.
        assert len(accept_key) == 28
        assert len(base64.b64decode(accept_key)) == 20

        # Length alone is satisfied by any SHA-1 digest; the exact value is
        # what proves the GUID and the concatenation order are right.
        assert accept_key == expected_accept
|
|
|
|
|
|
# ============================================================================
|
|
# WebSocket Frame Masking Tests (RFC 6455 Section 5.3)
|
|
# ============================================================================
|
|
|
|
class TestWebSocketFrameMasking:
    """Exercise ``WebSocketProtocol._unmask`` against hand-masked payloads."""

    def _create_protocol(self):
        """Build a ``WebSocketProtocol`` from placeholder arguments.

        The five positional arguments are ``None``, ``None``, an empty dict,
        ``None`` and a mock (presumably the logger - other helpers in this
        file pass ``log=mock.Mock()``).
        """
        from gunicorn.asgi.websocket import WebSocketProtocol

        last_arg = mock.Mock()
        return WebSocketProtocol(None, None, {}, None, last_arg)

    def test_unmask_simple(self):
        """Unmask the five masked bytes of "Hello"."""
        key = bytes.fromhex("37fa213d")
        # "Hello" is 48 65 6c 6c 6f. XOR-ing with the key (which wraps around
        # after four bytes) yields 7f 9f 4d 51 58.
        wire_bytes = bytes.fromhex("7f9f4d5158")

        result = self._create_protocol()._unmask(wire_bytes, key)

        assert result == b"Hello"

    def test_unmask_empty(self):
        """An empty payload unmasks to an empty payload."""
        key = bytes.fromhex("37fa213d")

        result = self._create_protocol()._unmask(b"", key)

        assert result == b""

    def test_unmask_longer_message(self):
        """Unmask a payload twice as long as the four-byte key."""
        key = bytes([1, 2, 3, 4])
        plain = b"12345678"  # eight bytes, so the key is applied twice

        # Repeating the key lines one key byte up with every payload byte.
        wire_bytes = bytes(p ^ k for p, k in zip(plain, key * 2))

        result = self._create_protocol()._unmask(wire_bytes, key)

        # Unmasking must invert the masking above.
        assert result == plain

    def test_unmask_binary_data(self):
        """Unmask a non-text payload containing 0x00 and 0xFF."""
        key = bytes.fromhex("abcdef01")
        plain = bytes.fromhex("00ff807f01")

        # Mask by hand; the fifth byte reuses the first key byte.
        wire_bytes = bytes(
            value ^ key[pos % len(key)] for pos, value in enumerate(plain)
        )

        result = self._create_protocol()._unmask(wire_bytes, key)

        assert result == plain
|
|
|
|
|
|
# ============================================================================
|
|
# WebSocket Frame Format Tests (RFC 6455 Section 5.2)
|
|
# ============================================================================
|
|
|
|
class TestWebSocketFrameFormat:
    """Sanity checks on the frame header bit layout.

    These tests operate on literal byte values and ``struct`` only; they do
    not call into gunicorn.
    """

    def test_frame_header_structure(self):
        """Decode FIN, RSV1-3 and the opcode from a first header byte."""
        # Byte 0: FIN(1) | RSV1(1) | RSV2(1) | RSV3(1) | OPCODE(4)
        # Byte 1: MASK(1) | PAYLOAD_LEN(7)
        header0 = 0x81  # 1000 0001 - final text frame, reserved bits clear

        fin = (header0 >> 7) & 1
        rsv1 = (header0 >> 6) & 1
        rsv2 = (header0 >> 5) & 1
        rsv3 = (header0 >> 4) & 1
        opcode = header0 & 0x0F

        assert fin == 1
        assert rsv1 == 0
        assert rsv2 == 0
        assert rsv3 == 0
        assert opcode == 1  # text

    def test_payload_length_7bit(self):
        """Lengths 0-125 are carried directly in the low seven bits."""
        header1 = 0x80 | 100  # MASK bit set, payload length 100

        mask_bit = (header1 >> 7) & 1
        length = header1 & 0x7F

        assert mask_bit == 1
        assert length == 100

    def test_payload_length_16bit(self):
        """Indicator 126 means the next two bytes hold the length."""
        header1 = 0x80 | 126  # MASK bit set, length indicator 126
        assert header1 & 0x7F == 126

        # The extended length is a big-endian unsigned 16-bit integer.
        wire = struct.pack("!H", 1000)
        (decoded,) = struct.unpack("!H", wire)
        assert decoded == 1000

    def test_payload_length_64bit(self):
        """Indicator 127 means the next eight bytes hold the length."""
        header1 = 0x80 | 127  # MASK bit set, length indicator 127
        assert header1 & 0x7F == 127

        # The extended length is a big-endian unsigned 64-bit integer.
        wire = struct.pack("!Q", 100000)
        (decoded,) = struct.unpack("!Q", wire)
        assert decoded == 100000
|
|
|
|
|
|
# ============================================================================
|
|
# WebSocket Protocol Instance Tests
|
|
# ============================================================================
|
|
|
|
class TestWebSocketProtocolInstance:
    """Check the attribute values of a freshly built ``WebSocketProtocol``."""

    def _create_protocol(self, scope=None):
        """Return a ``WebSocketProtocol`` wired entirely to mocks.

        Args:
            scope: Scope dict to pass through. When ``None``, a minimal
                websocket scope with an empty header list is used.
        """
        from gunicorn.asgi.websocket import WebSocketProtocol

        if scope is None:
            effective_scope = {"type": "websocket", "headers": []}
        else:
            effective_scope = scope

        collaborators = {
            "transport": mock.Mock(),
            "reader": mock.Mock(),
            "scope": effective_scope,
            "app": mock.AsyncMock(),
            "log": mock.Mock(),
        }
        return WebSocketProtocol(**collaborators)

    def test_initial_state(self):
        """A new protocol is neither accepted nor closed."""
        ws = self._create_protocol()

        assert ws.accepted is False
        assert ws.closed is False
        assert ws.close_code is None
        assert ws.close_reason == ""

    def test_fragment_state_initial(self):
        """A new protocol has no buffered fragments and no fragment opcode."""
        ws = self._create_protocol()

        assert ws._fragments == []
        assert ws._fragment_opcode is None
|
|
|
|
|
|
# ============================================================================
|
|
# WebSocket ASGI Message Format Tests
|
|
# ============================================================================
|
|
|
|
class TestWebSocketASGIMessages:
    """Shape checks for the WebSocket ASGI event dictionaries.

    Each test builds a literal event dict and reads it back; nothing here
    calls into gunicorn.
    """

    def test_websocket_connect_message(self):
        """``websocket.connect`` carries just a type."""
        event = dict(type="websocket.connect")
        assert event["type"] == "websocket.connect"

    def test_websocket_accept_message(self):
        """``websocket.accept`` with a subprotocol and extra headers."""
        extra_headers = [(b"x-custom-header", b"value")]
        event = dict(
            type="websocket.accept",
            subprotocol="graphql-ws",
            headers=extra_headers,
        )

        assert event["type"] == "websocket.accept"
        assert event["subprotocol"] == "graphql-ws"

    def test_websocket_accept_minimal(self):
        """``websocket.accept`` with only the type key."""
        event = dict(type="websocket.accept")
        assert event["type"] == "websocket.accept"

    def test_websocket_receive_text_message(self):
        """``websocket.receive`` carrying a ``text`` payload of type str."""
        event = dict(type="websocket.receive", text="Hello, WebSocket!")

        assert event["type"] == "websocket.receive"
        assert "text" in event
        assert isinstance(event["text"], str)

    def test_websocket_receive_binary_message(self):
        """``websocket.receive`` carrying a ``bytes`` payload of type bytes."""
        event = dict(type="websocket.receive", bytes=b"\x00\x01\x02\x03")

        assert event["type"] == "websocket.receive"
        assert "bytes" in event
        assert isinstance(event["bytes"], bytes)

    def test_websocket_send_text_message(self):
        """``websocket.send`` carrying text."""
        event = dict(type="websocket.send", text="Response text")

        assert event["type"] == "websocket.send"
        assert event["text"] == "Response text"

    def test_websocket_send_binary_message(self):
        """``websocket.send`` carrying binary data."""
        event = dict(type="websocket.send", bytes=b"\xFF\xFE\xFD")

        assert event["type"] == "websocket.send"
        assert event["bytes"] == b"\xFF\xFE\xFD"

    def test_websocket_disconnect_message(self):
        """``websocket.disconnect`` carries a numeric close code."""
        event = dict(type="websocket.disconnect", code=1000)

        assert event["type"] == "websocket.disconnect"
        assert event["code"] == 1000

    def test_websocket_close_message(self):
        """``websocket.close`` carries a code and a reason string."""
        event = dict(type="websocket.close", code=1000, reason="Normal closure")

        assert event["type"] == "websocket.close"
        assert event["code"] == 1000
        assert event["reason"] == "Normal closure"
|
|
|
|
|
|
# ============================================================================
|
|
# WebSocket Upgrade Detection Tests
|
|
# ============================================================================
|
|
|
|
class TestWebSocketUpgradeDetection:
    """Check ``ASGIProtocol._is_websocket_upgrade`` on mocked requests.

    Requests are mocks exposing ``method`` and ``headers``; headers are
    ``(name, value)`` string pairs with upper-cased names.
    """

    def _create_protocol(self):
        """Return an ``ASGIProtocol`` built around a mocked worker.

        The worker mock gets a real default ``Config`` as ``cfg`` and mocks
        for ``log`` and ``asgi``.
        """
        from gunicorn.asgi.protocol import ASGIProtocol
        from gunicorn.config import Config

        fake_worker = mock.Mock()
        fake_worker.cfg = Config()
        fake_worker.log = mock.Mock()
        fake_worker.asgi = mock.Mock()
        return ASGIProtocol(fake_worker)

    def _create_mock_request(self, method="GET", headers=None):
        """Return a mock request with the given method and header list.

        Args:
            method: Value stored on ``request.method``.
            headers: Header pairs; a falsy value becomes an empty list.
        """
        fake_request = mock.Mock()
        fake_request.method = method
        fake_request.headers = headers if headers else []
        return fake_request

    def test_valid_websocket_upgrade(self):
        """GET with Upgrade: websocket and Connection: upgrade is detected."""
        proto = self._create_protocol()
        header_pairs = [("UPGRADE", "websocket"), ("CONNECTION", "upgrade")]
        req = self._create_mock_request(method="GET", headers=header_pairs)

        assert proto._is_websocket_upgrade(req) is True

    def test_websocket_upgrade_case_insensitive(self):
        """Mixed-case header values are still detected."""
        proto = self._create_protocol()
        header_pairs = [("UPGRADE", "WebSocket"), ("CONNECTION", "Upgrade")]
        req = self._create_mock_request(method="GET", headers=header_pairs)

        assert proto._is_websocket_upgrade(req) is True

    def test_websocket_upgrade_connection_with_keep_alive(self):
        """A Connection value of "upgrade, keep-alive" is still detected."""
        proto = self._create_protocol()
        header_pairs = [
            ("UPGRADE", "websocket"),
            ("CONNECTION", "upgrade, keep-alive"),
        ]
        req = self._create_mock_request(method="GET", headers=header_pairs)

        assert proto._is_websocket_upgrade(req) is True

    def test_not_websocket_wrong_method(self):
        """Upgrade headers on a non-GET request are not an upgrade."""
        proto = self._create_protocol()

        for verb in ("POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"):
            # Fresh header list per request so iterations share no state.
            header_pairs = [("UPGRADE", "websocket"), ("CONNECTION", "upgrade")]
            req = self._create_mock_request(method=verb, headers=header_pairs)
            assert proto._is_websocket_upgrade(req) is False

    def test_not_websocket_missing_upgrade(self):
        """Without an Upgrade header the request is not an upgrade."""
        proto = self._create_protocol()
        header_pairs = [("CONNECTION", "upgrade")]
        req = self._create_mock_request(method="GET", headers=header_pairs)

        assert proto._is_websocket_upgrade(req) is False

    def test_not_websocket_missing_connection(self):
        """Without a Connection header the request is not an upgrade."""
        proto = self._create_protocol()
        header_pairs = [("UPGRADE", "websocket")]
        req = self._create_mock_request(method="GET", headers=header_pairs)

        # Only falsiness is required here: the result may be None or False
        # when the Connection header is missing.
        assert not proto._is_websocket_upgrade(req)

    def test_not_websocket_wrong_upgrade_value(self):
        """An Upgrade header naming another protocol is not an upgrade."""
        proto = self._create_protocol()
        header_pairs = [("UPGRADE", "h2c"), ("CONNECTION", "upgrade")]
        req = self._create_mock_request(method="GET", headers=header_pairs)

        assert proto._is_websocket_upgrade(req) is False
|
|
|
|
|
|
# ============================================================================
|
|
# WebSocket Close Frame Tests
|
|
# ============================================================================
|
|
|
|
class TestWebSocketCloseFrame:
    """Round-trip checks for the close frame payload layout.

    The payload is a two-byte big-endian status code followed by a UTF-8
    reason; the code comes from gunicorn's ``CLOSE_NORMAL`` constant.
    """

    def test_close_frame_payload_format(self):
        """Pack a code plus reason, then split and decode them again."""
        from gunicorn.asgi.websocket import CLOSE_NORMAL

        # Build: 2-byte big-endian code + UTF-8 reason.
        wire = struct.pack("!H", CLOSE_NORMAL) + "Goodbye".encode("utf-8")

        # Parse: first two bytes are the code, the rest is the reason.
        code_bytes, reason_bytes = wire[:2], wire[2:]
        (status,) = struct.unpack("!H", code_bytes)
        reason_text = reason_bytes.decode("utf-8")

        assert status == 1000
        assert reason_text == "Goodbye"

    def test_close_frame_empty_reason(self):
        """A payload holding only the code decodes to an empty reason."""
        from gunicorn.asgi.websocket import CLOSE_NORMAL

        wire = struct.pack("!H", CLOSE_NORMAL)

        code_bytes, reason_bytes = wire[:2], wire[2:]
        (status,) = struct.unpack("!H", code_bytes)
        reason_text = reason_bytes.decode("utf-8")

        assert status == 1000
        assert reason_text == ""

    def test_close_frame_max_reason_length(self):
        """A 123-byte reason fills the 125-byte control frame payload."""
        from gunicorn.asgi.websocket import CLOSE_NORMAL

        # Control frame payloads are capped at 125 bytes; the code takes two,
        # which leaves 123 for the reason.
        longest_reason = "x" * 123
        wire = struct.pack("!H", CLOSE_NORMAL) + longest_reason.encode("utf-8")

        assert len(wire) == 125
|
|
|
|
|
|
# ============================================================================
|
|
# Async WebSocket Tests
|
|
# ============================================================================
|
|
|
|
class TestWebSocketAsync:
|
|
"""Async tests for WebSocket protocol."""
|
|
|
|
def _create_protocol(self, scope=None):
    """Return a ``WebSocketProtocol`` wired entirely to mocks.

    Args:
        scope: Scope dict to pass through. When ``None``, a websocket scope
            with a single ``sec-websocket-key`` header is used.
    """
    from gunicorn.asgi.websocket import WebSocketProtocol

    if scope is None:
        key_header = (b"sec-websocket-key", b"dGhlIHNhbXBsZSBub25jZQ==")
        scope = {"type": "websocket", "headers": [key_header]}

    return WebSocketProtocol(
        transport=mock.Mock(),
        reader=mock.Mock(),
        scope=scope,
        app=mock.AsyncMock(),
        log=mock.Mock(),
    )
|
|
|
|
@pytest.mark.asyncio
async def test_receive_returns_from_queue(self):
    """A message put on ``_receive_queue`` comes back from ``_receive()``."""
    ws = self._create_protocol()

    # Seed the queue directly, then read through the receive callable.
    queued = {"type": "websocket.connect"}
    await ws._receive_queue.put(queued)

    received = await ws._receive()
    assert received["type"] == "websocket.connect"
|
|
|
|
@pytest.mark.asyncio
async def test_send_accept_sets_flag(self):
    """Sending ``websocket.accept`` flips ``accepted`` to ``True``."""
    ws = self._create_protocol()
    # Replace the transport's write with a plain (non-async) mock.
    ws.transport.write = mock.Mock()

    accept_event = {"type": "websocket.accept"}
    await ws._send(accept_event)

    assert ws.accepted is True
|
|
|
|
@pytest.mark.asyncio
async def test_send_accept_twice_raises(self):
    """A second ``websocket.accept`` raises ``RuntimeError``."""
    ws = self._create_protocol()
    ws.transport.write = mock.Mock()

    # The first accept goes through.
    await ws._send({"type": "websocket.accept"})

    # The second must be rejected with a message matching the pattern.
    with pytest.raises(RuntimeError, match="already accepted"):
        await ws._send({"type": "websocket.accept"})
|
|
|
|
@pytest.mark.asyncio
async def test_send_before_accept_raises(self):
    """``websocket.send`` on an unaccepted connection raises ``RuntimeError``."""
    ws = self._create_protocol()
    data_event = {"type": "websocket.send", "text": "hello"}

    with pytest.raises(RuntimeError, match="not accepted"):
        await ws._send(data_event)
|
|
|
|
@pytest.mark.asyncio
async def test_send_after_close_raises(self):
    """``websocket.send`` after the closed flag is set raises ``RuntimeError``."""
    ws = self._create_protocol()
    ws.transport.write = mock.Mock()

    await ws._send({"type": "websocket.accept"})
    # Mark the connection closed by setting the flag directly.
    ws.closed = True

    data_event = {"type": "websocket.send", "text": "hello"}
    with pytest.raises(RuntimeError, match="closed"):
        await ws._send(data_event)
|
|
|
|
@pytest.mark.asyncio
async def test_send_close_sets_flag(self):
    """Sending ``websocket.close`` flips ``closed`` to ``True``."""
    ws = self._create_protocol()
    ws.transport.write = mock.Mock()

    close_event = {"type": "websocket.close", "code": 1000}
    await ws._send(close_event)

    assert ws.closed is True
|