diff --git a/tests/test_dirty_arbiter.py b/tests/test_dirty_arbiter.py index 40abb504..05f35cb0 100644 --- a/tests/test_dirty_arbiter.py +++ b/tests/test_dirty_arbiter.py @@ -14,7 +14,12 @@ import pytest from gunicorn.config import Config from gunicorn.dirty.arbiter import DirtyArbiter from gunicorn.dirty.errors import DirtyError -from gunicorn.dirty.protocol import DirtyProtocol, make_request +from gunicorn.dirty.protocol import ( + DirtyProtocol, + BinaryProtocol, + make_request, + HEADER_SIZE, +) class MockStreamWriter: @@ -29,16 +34,22 @@ class MockStreamWriter: self._buffer += data async def drain(self): - while len(self._buffer) >= DirtyProtocol.HEADER_SIZE: - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - self._buffer[:DirtyProtocol.HEADER_SIZE] - )[0] - total_size = DirtyProtocol.HEADER_SIZE + length + # Decode the buffer to extract messages using binary protocol + while len(self._buffer) >= HEADER_SIZE: + # Decode header to get payload length + _, _, length = BinaryProtocol.decode_header( + self._buffer[:HEADER_SIZE] + ) + total_size = HEADER_SIZE + length if len(self._buffer) >= total_size: - msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size] + msg_data = self._buffer[:total_size] self._buffer = self._buffer[total_size:] - self.messages.append(DirtyProtocol.decode(msg_data)) + # decode_message returns (msg_type_str, request_id, payload_dict) + msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data) + # Reconstruct the dict format for backwards compatibility + result = {"type": msg_type_str, "id": request_id} + result.update(payload_dict) + self.messages.append(result) else: break