From 6d2139bb6cf1ad08fb721a1e4c4b7b1bf792315b Mon Sep 17 00:00:00 2001 From: Benoit Chesneau Date: Wed, 11 Feb 2026 23:01:21 +0100 Subject: [PATCH] feat(dirty): update worker for binary protocol Update worker tests to work with the binary protocol: - Use integer request IDs instead of strings - Update MockStreamWriter to decode binary messages - Import binary protocol constants from module level --- tests/test_dirty_worker.py | 45 ++++++++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 17 deletions(-) diff --git a/tests/test_dirty_worker.py b/tests/test_dirty_worker.py index f68a2276..e50e7c41 100644 --- a/tests/test_dirty_worker.py +++ b/tests/test_dirty_worker.py @@ -12,7 +12,13 @@ import pytest from gunicorn.config import Config from gunicorn.dirty.worker import DirtyWorker -from gunicorn.dirty.protocol import DirtyProtocol, make_request +from gunicorn.dirty.protocol import ( + DirtyProtocol, + BinaryProtocol, + make_request, + HEADER_SIZE, + HEADER_FORMAT, +) from gunicorn.dirty.errors import DirtyAppNotFoundError @@ -56,17 +62,22 @@ class MockStreamWriter: self._buffer += data async def drain(self): - # Decode the buffer to extract messages - while len(self._buffer) >= DirtyProtocol.HEADER_SIZE: - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - self._buffer[:DirtyProtocol.HEADER_SIZE] - )[0] - total_size = DirtyProtocol.HEADER_SIZE + length + # Decode the buffer to extract messages using binary protocol + while len(self._buffer) >= HEADER_SIZE: + # Decode header to get payload length + _, _, length = BinaryProtocol.decode_header( + self._buffer[:HEADER_SIZE] + ) + total_size = HEADER_SIZE + length if len(self._buffer) >= total_size: - msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size] + msg_data = self._buffer[:total_size] self._buffer = self._buffer[total_size:] - self.messages.append(DirtyProtocol.decode(msg_data)) + # decode_message returns (msg_type_str, request_id, payload_dict) + msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data) + # Reconstruct the dict format for backwards compatibility + result = {"type": msg_type_str, "id": request_id} + result.update(payload_dict) + self.messages.append(result) else: break @@ -246,7 +257,7 @@ class TestDirtyWorkerHandleRequest: worker.load_apps() request = make_request( - request_id="test-123", + request_id=123, app_path="tests.support_dirty_app:TestDirtyApp", action="compute", args=(2, 3), @@ -259,7 +270,7 @@ class TestDirtyWorkerHandleRequest: assert len(writer.messages) == 1 response = writer.messages[0] assert response["type"] == DirtyProtocol.MSG_TYPE_RESPONSE - assert response["id"] == "test-123" + assert response["id"] == 123 assert response["result"] == 6 @pytest.mark.asyncio @@ -282,7 +293,7 @@ class TestDirtyWorkerHandleRequest: worker.load_apps() request = make_request( - request_id="test-456", + request_id=456, app_path="tests.support_dirty_app:TestDirtyApp", action="compute", args=(2, 3), @@ -295,7 +306,7 @@ class TestDirtyWorkerHandleRequest: assert len(writer.messages) == 1 response = writer.messages[0] assert response["type"] == DirtyProtocol.MSG_TYPE_ERROR - assert response["id"] == "test-456" + assert response["id"] == 456 assert "Unknown operation" in response["error"]["message"] @pytest.mark.asyncio @@ -315,7 +326,7 @@ class TestDirtyWorkerHandleRequest: socket_path=socket_path ) - request = {"type": "unknown", "id": "test-789"} + request = {"type": "unknown", "id": 789} writer = MockStreamWriter() await worker.handle_request(request, writer) @@ -697,7 +708,7 @@ class TestDirtyWorkerRunAsync: # Create a simple test using stream reader/writer request = make_request( - request_id="conn-test", + request_id=999, app_path="tests.support_dirty_app:TestDirtyApp", action="compute", args=(5, 3), @@ -706,7 +717,7 @@ class TestDirtyWorkerRunAsync: # Mock reader and writer reader = asyncio.StreamReader() - encoded_request = DirtyProtocol.encode(request) + encoded_request = BinaryProtocol._encode_from_dict(request) reader.feed_data(encoded_request) reader.feed_eof()