feat(dirty): update worker for binary protocol

Update worker tests to work with the binary protocol:
- Use integer request IDs instead of strings
- Update MockStreamWriter to decode binary messages
- Import binary protocol constants from module level
This commit is contained in:
Benoit Chesneau 2026-02-11 23:01:21 +01:00
parent 1665857c0e
commit 6d2139bb6c

View File

@ -12,7 +12,13 @@ import pytest
from gunicorn.config import Config from gunicorn.config import Config
from gunicorn.dirty.worker import DirtyWorker from gunicorn.dirty.worker import DirtyWorker
from gunicorn.dirty.protocol import DirtyProtocol, make_request from gunicorn.dirty.protocol import (
DirtyProtocol,
BinaryProtocol,
make_request,
HEADER_SIZE,
HEADER_FORMAT,
)
from gunicorn.dirty.errors import DirtyAppNotFoundError from gunicorn.dirty.errors import DirtyAppNotFoundError
@ -56,17 +62,22 @@ class MockStreamWriter:
self._buffer += data self._buffer += data
async def drain(self): async def drain(self):
# Decode the buffer to extract messages # Decode the buffer to extract messages using binary protocol
while len(self._buffer) >= DirtyProtocol.HEADER_SIZE: while len(self._buffer) >= HEADER_SIZE:
length = struct.unpack( # Decode header to get payload length
DirtyProtocol.HEADER_FORMAT, _, _, length = BinaryProtocol.decode_header(
self._buffer[:DirtyProtocol.HEADER_SIZE] self._buffer[:HEADER_SIZE]
)[0] )
total_size = DirtyProtocol.HEADER_SIZE + length total_size = HEADER_SIZE + length
if len(self._buffer) >= total_size: if len(self._buffer) >= total_size:
msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size] msg_data = self._buffer[:total_size]
self._buffer = self._buffer[total_size:] self._buffer = self._buffer[total_size:]
self.messages.append(DirtyProtocol.decode(msg_data)) # decode_message returns (msg_type_str, request_id, payload_dict)
msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data)
# Reconstruct the dict format for backwards compatibility
result = {"type": msg_type_str, "id": request_id}
result.update(payload_dict)
self.messages.append(result)
else: else:
break break
@ -246,7 +257,7 @@ class TestDirtyWorkerHandleRequest:
worker.load_apps() worker.load_apps()
request = make_request( request = make_request(
request_id="test-123", request_id=123,
app_path="tests.support_dirty_app:TestDirtyApp", app_path="tests.support_dirty_app:TestDirtyApp",
action="compute", action="compute",
args=(2, 3), args=(2, 3),
@ -259,7 +270,7 @@ class TestDirtyWorkerHandleRequest:
assert len(writer.messages) == 1 assert len(writer.messages) == 1
response = writer.messages[0] response = writer.messages[0]
assert response["type"] == DirtyProtocol.MSG_TYPE_RESPONSE assert response["type"] == DirtyProtocol.MSG_TYPE_RESPONSE
assert response["id"] == "test-123" assert response["id"] == 123
assert response["result"] == 6 assert response["result"] == 6
@pytest.mark.asyncio @pytest.mark.asyncio
@ -282,7 +293,7 @@ class TestDirtyWorkerHandleRequest:
worker.load_apps() worker.load_apps()
request = make_request( request = make_request(
request_id="test-456", request_id=456,
app_path="tests.support_dirty_app:TestDirtyApp", app_path="tests.support_dirty_app:TestDirtyApp",
action="compute", action="compute",
args=(2, 3), args=(2, 3),
@ -295,7 +306,7 @@ class TestDirtyWorkerHandleRequest:
assert len(writer.messages) == 1 assert len(writer.messages) == 1
response = writer.messages[0] response = writer.messages[0]
assert response["type"] == DirtyProtocol.MSG_TYPE_ERROR assert response["type"] == DirtyProtocol.MSG_TYPE_ERROR
assert response["id"] == "test-456" assert response["id"] == 456
assert "Unknown operation" in response["error"]["message"] assert "Unknown operation" in response["error"]["message"]
@pytest.mark.asyncio @pytest.mark.asyncio
@ -315,7 +326,7 @@ class TestDirtyWorkerHandleRequest:
socket_path=socket_path socket_path=socket_path
) )
request = {"type": "unknown", "id": "test-789"} request = {"type": "unknown", "id": 789}
writer = MockStreamWriter() writer = MockStreamWriter()
await worker.handle_request(request, writer) await worker.handle_request(request, writer)
@ -697,7 +708,7 @@ class TestDirtyWorkerRunAsync:
# Create a simple test using stream reader/writer # Create a simple test using stream reader/writer
request = make_request( request = make_request(
request_id="conn-test", request_id=999,
app_path="tests.support_dirty_app:TestDirtyApp", app_path="tests.support_dirty_app:TestDirtyApp",
action="compute", action="compute",
args=(5, 3), args=(5, 3),
@ -706,7 +717,7 @@ class TestDirtyWorkerRunAsync:
# Mock reader and writer # Mock reader and writer
reader = asyncio.StreamReader() reader = asyncio.StreamReader()
encoded_request = DirtyProtocol.encode(request) encoded_request = BinaryProtocol._encode_from_dict(request)
reader.feed_data(encoded_request) reader.feed_data(encoded_request)
reader.feed_eof() reader.feed_eof()