Fix WebSocket close handshake to comply with RFC 6455

- Add _close_sent, _close_received, _close_event state variables
- Server now waits for client's close frame response before marking
  connection as closed (5s timeout)
- Update _read_frames loop to continue reading after sending close
- Fix tests to simulate client close frame response
This commit is contained in:
Benoit Chesneau 2026-04-03 14:53:36 +02:00
parent 47bd20a7cb
commit 3fc9a2f002
2 changed files with 101 additions and 28 deletions

View File

@ -65,6 +65,11 @@ class WebSocketProtocol:
self.close_code = None
self.close_reason = ""
# Close handshake state (RFC 6455 Section 7.1.1)
self._close_sent = False
self._close_received = False
self._close_event = asyncio.Event()
# Message reassembly state
self._fragments = []
self._fragment_opcode = None
@ -105,16 +110,21 @@ class WebSocketProtocol:
except Exception:
self.log.exception("Error in WebSocket ASGI application")
finally:
# Send close frame if not already closed
if not self.closed and self.accepted and not self._close_sent:
await self._send_close(CLOSE_INTERNAL_ERROR, "Application error")
# Wait for client's close response
try:
await asyncio.wait_for(self._close_event.wait(), timeout=5.0)
except asyncio.TimeoutError:
self.closed = True
read_task.cancel()
try:
await read_task
except asyncio.CancelledError:
pass
# Send close frame if not already closed
if not self.closed and self.accepted:
await self._send_close(CLOSE_INTERNAL_ERROR, "Application error")
async def _receive(self):
"""ASGI receive callable."""
return await self._receive_queue.get()
@ -144,7 +154,14 @@ class WebSocketProtocol:
code = message.get("code", CLOSE_NORMAL)
reason = message.get("reason", "")
await self._send_close(code, reason)
self.closed = True
# Wait for client's close frame (RFC 6455 close handshake)
try:
await asyncio.wait_for(self._close_event.wait(), timeout=5.0)
except asyncio.TimeoutError:
self.log.debug("WebSocket close handshake timeout")
self.closed = True
self._close_event.set()
async def _send_accept(self, message):
"""Send WebSocket handshake accept response."""
@ -191,7 +208,9 @@ class WebSocketProtocol:
async def _read_frames(self):
"""Read and process incoming WebSocket frames."""
try:
while not self.closed:
# Continue reading while not closed, or if we sent close but haven't
# received client's close response yet (RFC 6455 close handshake)
while not self.closed or (self._close_sent and not self._close_received):
frame = await self._read_frame()
if frame is None:
break
@ -353,11 +372,14 @@ class WebSocketProtocol:
self.close_code = CLOSE_NO_STATUS
self.close_reason = ""
self._close_received = True
# Echo close frame back if we haven't already sent one
if not self.closed:
if not self._close_sent:
await self._send_close(self.close_code, self.close_reason)
self.closed = True
self._close_event.set()
async def _handle_continuation(self, payload): # pylint: disable=unused-argument
"""Handle continuation frame (already processed in _read_frame)."""
@ -394,8 +416,16 @@ class WebSocketProtocol:
async def _send_close(self, code, reason=""):
"""Send a close frame."""
if self._close_sent:
return # Already sent
payload = struct.pack("!H", code)
if reason:
payload += reason.encode("utf-8")[:123] # Max 125 bytes total
await self._send_frame(OPCODE_CLOSE, payload)
self.closed = True
self._close_sent = True
# If we already received a close, handshake is complete
if self._close_received:
self.closed = True
self._close_event.set()

View File

@ -441,6 +441,16 @@ class TestWebSocketAcceptThenCloseE2E:
written_data = []
transport.write = mock.Mock(side_effect=lambda d: written_data.append(d))
protocol = WebSocketProtocol(
transport=transport,
scope={
"type": "websocket",
"headers": [(b"sec-websocket-key", b"dGhlIHNhbXBsZSBub25jZQ==")],
},
app=None, # Will be replaced
log=mock.Mock(),
)
# App that accepts then immediately closes (Django pattern)
async def close_app(scope, receive, send):
# Wait for connect message
@ -453,19 +463,32 @@ class TestWebSocketAcceptThenCloseE2E:
# Immediately close with code
await send({"type": "websocket.close", "code": 1000})
protocol = WebSocketProtocol(
transport=transport,
scope={
"type": "websocket",
"headers": [(b"sec-websocket-key", b"dGhlIHNhbXBsZSBub25jZQ==")],
},
app=close_app,
log=mock.Mock(),
)
protocol.app = close_app
# Helper to simulate client close frame response after server sends close
async def feed_client_close_after_delay():
# Wait for server to send close frame
await asyncio.sleep(0.1)
# Masked close frame with code 1000: FIN=1, opcode=8, masked, len=2
# Mask key: 0x00000000 for simplicity, payload: 0x03E8 (1000)
client_close = bytes([
0x88, # FIN + opcode 8 (close)
0x82, # Masked + length 2
0x00, 0x00, 0x00, 0x00, # Mask key
0x03, 0xE8, # Close code 1000 (masked with 0s = unchanged)
])
protocol.feed_data(client_close)
# Run both concurrently
async def run_with_client_response():
await asyncio.gather(
protocol.run(),
feed_client_close_after_delay(),
)
# Run the WebSocket - this should complete without timeout
try:
await asyncio.wait_for(protocol.run(), timeout=2.0)
await asyncio.wait_for(run_with_client_response(), timeout=2.0)
except asyncio.TimeoutError:
pytest.fail("WebSocket run() timed out - close frame likely not sent")
@ -490,6 +513,16 @@ class TestWebSocketAcceptThenCloseE2E:
written_data = []
transport.write = mock.Mock(side_effect=lambda d: written_data.append(d))
protocol = WebSocketProtocol(
transport=transport,
scope={
"type": "websocket",
"headers": [(b"sec-websocket-key", b"dGhlIHNhbXBsZSBub25jZQ==")],
},
app=None, # Will be replaced
log=mock.Mock(),
)
async def close_app(scope, receive, send):
message = await receive()
assert message["type"] == "websocket.connect"
@ -497,17 +530,27 @@ class TestWebSocketAcceptThenCloseE2E:
await send({"type": "websocket.accept"})
await send({"type": "websocket.close", "code": 1008})
protocol = WebSocketProtocol(
transport=transport,
scope={
"type": "websocket",
"headers": [(b"sec-websocket-key", b"dGhlIHNhbXBsZSBub25jZQ==")],
},
app=close_app,
log=mock.Mock(),
)
protocol.app = close_app
await asyncio.wait_for(protocol.run(), timeout=2.0)
# Helper to simulate client close frame response
async def feed_client_close_after_delay():
await asyncio.sleep(0.1)
# Masked close frame with code 1008
client_close = bytes([
0x88, # FIN + opcode 8 (close)
0x82, # Masked + length 2
0x00, 0x00, 0x00, 0x00, # Mask key
0x03, 0xF0, # Close code 1008 (masked with 0s = unchanged)
])
protocol.feed_data(client_close)
async def run_with_client_response():
await asyncio.gather(
protocol.run(),
feed_client_close_after_delay(),
)
await asyncio.wait_for(run_with_client_response(), timeout=2.0)
combined = b"".join(written_data)