diff --git a/gunicorn/asgi/websocket.py b/gunicorn/asgi/websocket.py index d1b2251b..dbb2d2c4 100644 --- a/gunicorn/asgi/websocket.py +++ b/gunicorn/asgi/websocket.py @@ -65,6 +65,11 @@ class WebSocketProtocol: self.close_code = None self.close_reason = "" + # Close handshake state (RFC 6455 Section 7.1.1) + self._close_sent = False + self._close_received = False + self._close_event = asyncio.Event() + # Message reassembly state self._fragments = [] self._fragment_opcode = None @@ -105,16 +110,21 @@ class WebSocketProtocol: except Exception: self.log.exception("Error in WebSocket ASGI application") finally: + # Send close frame if not already closed + if not self.closed and self.accepted and not self._close_sent: + await self._send_close(CLOSE_INTERNAL_ERROR, "Application error") + # Wait for client's close response + try: + await asyncio.wait_for(self._close_event.wait(), timeout=5.0) + except asyncio.TimeoutError: + self.closed = True + read_task.cancel() try: await read_task except asyncio.CancelledError: pass - # Send close frame if not already closed - if not self.closed and self.accepted: - await self._send_close(CLOSE_INTERNAL_ERROR, "Application error") - async def _receive(self): """ASGI receive callable.""" return await self._receive_queue.get() @@ -144,7 +154,14 @@ class WebSocketProtocol: code = message.get("code", CLOSE_NORMAL) reason = message.get("reason", "") await self._send_close(code, reason) - self.closed = True + + # Wait for client's close frame (RFC 6455 close handshake) + try: + await asyncio.wait_for(self._close_event.wait(), timeout=5.0) + except asyncio.TimeoutError: + self.log.debug("WebSocket close handshake timeout") + self.closed = True + self._close_event.set() async def _send_accept(self, message): """Send WebSocket handshake accept response.""" @@ -191,7 +208,9 @@ class WebSocketProtocol: async def _read_frames(self): """Read and process incoming WebSocket frames.""" try: - while not self.closed: + # Continue reading while not closed, or if we sent close but haven't + # received client's close response yet (RFC 6455 close handshake) + while not self.closed or (self._close_sent and not self._close_received): frame = await self._read_frame() if frame is None: break @@ -353,11 +372,14 @@ class WebSocketProtocol: self.close_code = CLOSE_NO_STATUS self.close_reason = "" + self._close_received = True + # Echo close frame back if we haven't already sent one - if not self.closed: + if not self._close_sent: await self._send_close(self.close_code, self.close_reason) self.closed = True + self._close_event.set() async def _handle_continuation(self, payload): # pylint: disable=unused-argument """Handle continuation frame (already processed in _read_frame).""" @@ -394,8 +416,16 @@ class WebSocketProtocol: async def _send_close(self, code, reason=""): """Send a close frame.""" + if self._close_sent: + return # Already sent + payload = struct.pack("!H", code) if reason: payload += reason.encode("utf-8")[:123] # Max 125 bytes total await self._send_frame(OPCODE_CLOSE, payload) - self.closed = True + self._close_sent = True + + # If we already received a close, handshake is complete + if self._close_received: + self.closed = True + self._close_event.set() diff --git a/tests/test_asgi_protocol_compat.py b/tests/test_asgi_protocol_compat.py index 982941fc..1a0a514b 100644 --- a/tests/test_asgi_protocol_compat.py +++ b/tests/test_asgi_protocol_compat.py @@ -441,6 +441,16 @@ class TestWebSocketAcceptThenCloseE2E: written_data = [] transport.write = mock.Mock(side_effect=lambda d: written_data.append(d)) + protocol = WebSocketProtocol( + transport=transport, + scope={ + "type": "websocket", + "headers": [(b"sec-websocket-key", b"dGhlIHNhbXBsZSBub25jZQ==")], + }, + app=None, # Will be replaced + log=mock.Mock(), + ) + # App that accepts then immediately closes (Django pattern) async def close_app(scope, receive, send): # Wait for connect message @@ -453,19 +463,32 @@ class TestWebSocketAcceptThenCloseE2E: # Immediately close with code await send({"type": "websocket.close", "code": 1000}) - protocol = WebSocketProtocol( - transport=transport, - scope={ - "type": "websocket", - "headers": [(b"sec-websocket-key", b"dGhlIHNhbXBsZSBub25jZQ==")], - }, - app=close_app, - log=mock.Mock(), - ) + protocol.app = close_app + + # Helper to simulate client close frame response after server sends close + async def feed_client_close_after_delay(): + # Wait for server to send close frame + await asyncio.sleep(0.1) + # Masked close frame with code 1000: FIN=1, opcode=8, masked, len=2 + # Mask key: 0x00000000 for simplicity, payload: 0x03E8 (1000) + client_close = bytes([ + 0x88, # FIN + opcode 8 (close) + 0x82, # Masked + length 2 + 0x00, 0x00, 0x00, 0x00, # Mask key + 0x03, 0xE8, # Close code 1000 (masked with 0s = unchanged) + ]) + protocol.feed_data(client_close) + + # Run both concurrently + async def run_with_client_response(): + await asyncio.gather( + protocol.run(), + feed_client_close_after_delay(), + ) # Run the WebSocket - this should complete without timeout try: - await asyncio.wait_for(protocol.run(), timeout=2.0) + await asyncio.wait_for(run_with_client_response(), timeout=2.0) except asyncio.TimeoutError: pytest.fail("WebSocket run() timed out - close frame likely not sent") @@ -490,6 +513,16 @@ class TestWebSocketAcceptThenCloseE2E: written_data = [] transport.write = mock.Mock(side_effect=lambda d: written_data.append(d)) + protocol = WebSocketProtocol( + transport=transport, + scope={ + "type": "websocket", + "headers": [(b"sec-websocket-key", b"dGhlIHNhbXBsZSBub25jZQ==")], + }, + app=None, # Will be replaced + log=mock.Mock(), + ) + async def close_app(scope, receive, send): message = await receive() assert message["type"] == "websocket.connect" @@ -497,17 +530,27 @@ class TestWebSocketAcceptThenCloseE2E: await send({"type": "websocket.accept"}) await send({"type": "websocket.close", "code": 1008}) - protocol = WebSocketProtocol( - transport=transport, - scope={ - "type": "websocket", - "headers": [(b"sec-websocket-key", b"dGhlIHNhbXBsZSBub25jZQ==")], - }, - app=close_app, - log=mock.Mock(), - ) + protocol.app = close_app - await asyncio.wait_for(protocol.run(), timeout=2.0) + # Helper to simulate client close frame response + async def feed_client_close_after_delay(): + await asyncio.sleep(0.1) + # Masked close frame with code 1008 + client_close = bytes([ + 0x88, # FIN + opcode 8 (close) + 0x82, # Masked + length 2 + 0x00, 0x00, 0x00, 0x00, # Mask key + 0x03, 0xF0, # Close code 1008 (masked with 0s = unchanged) + ]) + protocol.feed_data(client_close) + + async def run_with_client_response(): + await asyncio.gather( + protocol.run(), + feed_client_close_after_delay(), + ) + + await asyncio.wait_for(run_with_client_response(), timeout=2.0) combined = b"".join(written_data)