feat(dirty): update client for binary protocol

Update client and streaming tests to work with the binary protocol:
- Update MockStreamWriter/MockStreamReader to use BinaryProtocol
- Replace string request IDs with integers
- Update test assertions to decode binary protocol messages
- Use HEADER_SIZE and decode_header/decode_message instead of old API
This commit is contained in:
Benoit Chesneau 2026-02-11 23:12:44 +01:00
parent 98b1b649c2
commit 477b7479cc
9 changed files with 258 additions and 177 deletions

View File

@ -12,11 +12,13 @@ import pytest
from gunicorn.dirty.protocol import ( from gunicorn.dirty.protocol import (
DirtyProtocol, DirtyProtocol,
BinaryProtocol,
make_request, make_request,
make_response, make_response,
make_chunk_message, make_chunk_message,
make_end_message, make_end_message,
make_error_response, make_error_response,
HEADER_SIZE,
) )
from gunicorn.dirty.arbiter import DirtyArbiter from gunicorn.dirty.arbiter import DirtyArbiter
from gunicorn.dirty.errors import DirtyError from gunicorn.dirty.errors import DirtyError
@ -34,16 +36,22 @@ class MockStreamWriter:
self._buffer += data self._buffer += data
async def drain(self): async def drain(self):
while len(self._buffer) >= DirtyProtocol.HEADER_SIZE: # Decode the buffer to extract messages using binary protocol
length = struct.unpack( while len(self._buffer) >= HEADER_SIZE:
DirtyProtocol.HEADER_FORMAT, # Decode header to get payload length
self._buffer[:DirtyProtocol.HEADER_SIZE] _, _, length = BinaryProtocol.decode_header(
)[0] self._buffer[:HEADER_SIZE]
total_size = DirtyProtocol.HEADER_SIZE + length )
total_size = HEADER_SIZE + length
if len(self._buffer) >= total_size: if len(self._buffer) >= total_size:
msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size] msg_data = self._buffer[:total_size]
self._buffer = self._buffer[total_size:] self._buffer = self._buffer[total_size:]
self.messages.append(DirtyProtocol.decode(msg_data)) # decode_message returns (msg_type_str, request_id, payload_dict)
msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data)
# Reconstruct the dict format for backwards compatibility
result = {"type": msg_type_str, "id": request_id}
result.update(payload_dict)
self.messages.append(result)
else: else:
break break
@ -63,7 +71,7 @@ class MockStreamReader:
def __init__(self, messages): def __init__(self, messages):
self._data = b'' self._data = b''
for msg in messages: for msg in messages:
self._data += DirtyProtocol.encode(msg) self._data += BinaryProtocol._encode_from_dict(msg)
self._pos = 0 self._pos = 0
async def readexactly(self, n): async def readexactly(self, n):
@ -107,9 +115,9 @@ class TestArbiterStreamingForwarding:
client_writer = MockStreamWriter() client_writer = MockStreamWriter()
# Mock worker connection that returns chunks # Mock worker connection that returns chunks
chunk1 = make_chunk_message("req-123", "Hello") chunk1 = make_chunk_message(123, "Hello")
chunk2 = make_chunk_message("req-123", " World") chunk2 = make_chunk_message(123, " World")
end = make_end_message("req-123") end = make_end_message(123)
mock_reader = MockStreamReader([chunk1, chunk2, end]) mock_reader = MockStreamReader([chunk1, chunk2, end])
@ -118,7 +126,7 @@ class TestArbiterStreamingForwarding:
arbiter._get_worker_connection = mock_get_connection arbiter._get_worker_connection = mock_get_connection
request = make_request("req-123", "test:App", "generate") request = make_request(123, "test:App", "generate")
await arbiter._execute_on_worker(1234, request, client_writer) await arbiter._execute_on_worker(1234, request, client_writer)
# Should have forwarded all messages # Should have forwarded all messages
@ -135,7 +143,7 @@ class TestArbiterStreamingForwarding:
arbiter = create_arbiter() arbiter = create_arbiter()
client_writer = MockStreamWriter() client_writer = MockStreamWriter()
response = make_response("req-123", {"result": 42}) response = make_response(123, {"result": 42})
mock_reader = MockStreamReader([response]) mock_reader = MockStreamReader([response])
async def mock_get_connection(pid): async def mock_get_connection(pid):
@ -143,7 +151,7 @@ class TestArbiterStreamingForwarding:
arbiter._get_worker_connection = mock_get_connection arbiter._get_worker_connection = mock_get_connection
request = make_request("req-123", "test:App", "compute") request = make_request(123, "test:App", "compute")
await arbiter._execute_on_worker(1234, request, client_writer) await arbiter._execute_on_worker(1234, request, client_writer)
assert len(client_writer.messages) == 1 assert len(client_writer.messages) == 1
@ -156,8 +164,8 @@ class TestArbiterStreamingForwarding:
arbiter = create_arbiter() arbiter = create_arbiter()
client_writer = MockStreamWriter() client_writer = MockStreamWriter()
chunk = make_chunk_message("req-123", "First") chunk = make_chunk_message(123, "First")
error = make_error_response("req-123", DirtyError("Something broke")) error = make_error_response(123, DirtyError("Something broke"))
mock_reader = MockStreamReader([chunk, error]) mock_reader = MockStreamReader([chunk, error])
@ -166,7 +174,7 @@ class TestArbiterStreamingForwarding:
arbiter._get_worker_connection = mock_get_connection arbiter._get_worker_connection = mock_get_connection
request = make_request("req-123", "test:App", "generate") request = make_request(123, "test:App", "generate")
await arbiter._execute_on_worker(1234, request, client_writer) await arbiter._execute_on_worker(1234, request, client_writer)
assert len(client_writer.messages) == 2 assert len(client_writer.messages) == 2
@ -190,7 +198,7 @@ class TestArbiterStreamingForwarding:
arbiter._get_worker_connection = mock_get_connection arbiter._get_worker_connection = mock_get_connection
request = make_request("req-123", "test:App", "generate") request = make_request(123, "test:App", "generate")
await arbiter._execute_on_worker(1234, request, client_writer) await arbiter._execute_on_worker(1234, request, client_writer)
assert len(client_writer.messages) == 1 assert len(client_writer.messages) == 1
@ -208,7 +216,7 @@ class TestArbiterRouteRequestStreaming:
arbiter.workers = {} # No workers arbiter.workers = {} # No workers
client_writer = MockStreamWriter() client_writer = MockStreamWriter()
request = make_request("req-123", "test:App", "generate") request = make_request(123, "test:App", "generate")
await arbiter.route_request(request, client_writer) await arbiter.route_request(request, client_writer)
assert len(client_writer.messages) == 1 assert len(client_writer.messages) == 1
@ -222,13 +230,13 @@ class TestArbiterRouteRequestStreaming:
# Mock _execute_on_worker to complete immediately # Mock _execute_on_worker to complete immediately
async def mock_execute(pid, request, client_writer): async def mock_execute(pid, request, client_writer):
response = make_response("req-123", "result") response = make_response(123, "result")
await DirtyProtocol.write_message_async(client_writer, response) await DirtyProtocol.write_message_async(client_writer, response)
arbiter._execute_on_worker = mock_execute arbiter._execute_on_worker = mock_execute
client_writer = MockStreamWriter() client_writer = MockStreamWriter()
request = make_request("req-123", "test:App", "compute") request = make_request(123, "test:App", "compute")
# Worker queue should be created # Worker queue should be created
assert 1234 not in arbiter.worker_queues assert 1234 not in arbiter.worker_queues
@ -255,8 +263,8 @@ class TestArbiterStreamingManyChunks:
# Generate 50 chunks + end # Generate 50 chunks + end
messages = [] messages = []
for i in range(50): for i in range(50):
messages.append(make_chunk_message("req-123", f"chunk-{i}")) messages.append(make_chunk_message(123, f"chunk-{i}"))
messages.append(make_end_message("req-123")) messages.append(make_end_message(123))
mock_reader = MockStreamReader(messages) mock_reader = MockStreamReader(messages)
@ -265,7 +273,7 @@ class TestArbiterStreamingManyChunks:
arbiter._get_worker_connection = mock_get_connection arbiter._get_worker_connection = mock_get_connection
request = make_request("req-123", "test:App", "generate") request = make_request(123, "test:App", "generate")
await arbiter._execute_on_worker(1234, request, client_writer) await arbiter._execute_on_worker(1234, request, client_writer)
assert len(client_writer.messages) == 51 assert len(client_writer.messages) == 51
@ -283,7 +291,7 @@ class TestArbiterBackwardCompatibility:
arbiter = create_arbiter() arbiter = create_arbiter()
client_writer = MockStreamWriter() client_writer = MockStreamWriter()
response = make_response("req-123", [1, 2, 3, 4, 5]) response = make_response(123, [1, 2, 3, 4, 5])
mock_reader = MockStreamReader([response]) mock_reader = MockStreamReader([response])
async def mock_get_connection(pid): async def mock_get_connection(pid):
@ -291,7 +299,7 @@ class TestArbiterBackwardCompatibility:
arbiter._get_worker_connection = mock_get_connection arbiter._get_worker_connection = mock_get_connection
request = make_request("req-123", "test:App", "get_list") request = make_request(123, "test:App", "get_list")
await arbiter._execute_on_worker(1234, request, client_writer) await arbiter._execute_on_worker(1234, request, client_writer)
assert len(client_writer.messages) == 1 assert len(client_writer.messages) == 1
@ -304,7 +312,7 @@ class TestArbiterBackwardCompatibility:
arbiter = create_arbiter() arbiter = create_arbiter()
client_writer = MockStreamWriter() client_writer = MockStreamWriter()
error = make_error_response("req-123", DirtyError("Something failed")) error = make_error_response(123, DirtyError("Something failed"))
mock_reader = MockStreamReader([error]) mock_reader = MockStreamReader([error])
async def mock_get_connection(pid): async def mock_get_connection(pid):
@ -312,7 +320,7 @@ class TestArbiterBackwardCompatibility:
arbiter._get_worker_connection = mock_get_connection arbiter._get_worker_connection = mock_get_connection
request = make_request("req-123", "test:App", "fail") request = make_request(123, "test:App", "fail")
await arbiter._execute_on_worker(1234, request, client_writer) await arbiter._execute_on_worker(1234, request, client_writer)
assert len(client_writer.messages) == 1 assert len(client_writer.messages) == 1

View File

@ -11,10 +11,12 @@ from unittest import mock
from gunicorn.dirty.protocol import ( from gunicorn.dirty.protocol import (
DirtyProtocol, DirtyProtocol,
BinaryProtocol,
make_chunk_message, make_chunk_message,
make_end_message, make_end_message,
make_response, make_response,
make_error_response, make_error_response,
HEADER_SIZE,
) )
from gunicorn.dirty.client import DirtyClient, DirtyStreamIterator from gunicorn.dirty.client import DirtyClient, DirtyStreamIterator
from gunicorn.dirty.errors import DirtyError, DirtyConnectionError from gunicorn.dirty.errors import DirtyError, DirtyConnectionError
@ -26,7 +28,7 @@ class MockSocket:
def __init__(self, messages): def __init__(self, messages):
self._data = b'' self._data = b''
for msg in messages: for msg in messages:
self._data += DirtyProtocol.encode(msg) self._data += BinaryProtocol._encode_from_dict(msg)
self._pos = 0 self._pos = 0
self._sent = [] self._sent = []
self.closed = False self.closed = False
@ -69,10 +71,10 @@ class TestDirtyStreamIterator:
def test_stream_iterator_yields_chunks(self): def test_stream_iterator_yields_chunks(self):
"""Test that stream iterator yields chunks correctly.""" """Test that stream iterator yields chunks correctly."""
messages = [ messages = [
make_chunk_message("req-123", "Hello"), make_chunk_message(123, "Hello"),
make_chunk_message("req-123", " "), make_chunk_message(123, " "),
make_chunk_message("req-123", "World"), make_chunk_message(123, "World"),
make_end_message("req-123"), make_end_message(123),
] ]
client = create_client_with_mock_socket(messages) client = create_client_with_mock_socket(messages)
@ -83,9 +85,9 @@ class TestDirtyStreamIterator:
def test_stream_iterator_yields_complex_chunks(self): def test_stream_iterator_yields_complex_chunks(self):
"""Test that stream iterator yields complex data types.""" """Test that stream iterator yields complex data types."""
messages = [ messages = [
make_chunk_message("req-123", {"token": "Hello", "score": 0.9}), make_chunk_message(123, {"token": "Hello", "score": 0.9}),
make_chunk_message("req-123", {"token": "World", "score": 0.8}), make_chunk_message(123, {"token": "World", "score": 0.8}),
make_end_message("req-123"), make_end_message(123),
] ]
client = create_client_with_mock_socket(messages) client = create_client_with_mock_socket(messages)
@ -98,8 +100,8 @@ class TestDirtyStreamIterator:
def test_stream_iterator_handles_error(self): def test_stream_iterator_handles_error(self):
"""Test that stream iterator raises on error message.""" """Test that stream iterator raises on error message."""
messages = [ messages = [
make_chunk_message("req-123", "First"), make_chunk_message(123, "First"),
make_error_response("req-123", DirtyError("Something broke")), make_error_response(123, DirtyError("Something broke")),
] ]
client = create_client_with_mock_socket(messages) client = create_client_with_mock_socket(messages)
@ -116,7 +118,7 @@ class TestDirtyStreamIterator:
def test_stream_iterator_empty_stream(self): def test_stream_iterator_empty_stream(self):
"""Test that empty stream (just end) works.""" """Test that empty stream (just end) works."""
messages = [make_end_message("req-123")] messages = [make_end_message(123)]
client = create_client_with_mock_socket(messages) client = create_client_with_mock_socket(messages)
chunks = list(client.stream("test:App", "generate")) chunks = list(client.stream("test:App", "generate"))
@ -125,8 +127,8 @@ class TestDirtyStreamIterator:
def test_stream_iterator_stops_after_exhausted(self): def test_stream_iterator_stops_after_exhausted(self):
"""Test that iterator stays exhausted after StopIteration.""" """Test that iterator stays exhausted after StopIteration."""
messages = [ messages = [
make_chunk_message("req-123", "Only"), make_chunk_message(123, "Only"),
make_end_message("req-123"), make_end_message(123),
] ]
client = create_client_with_mock_socket(messages) client = create_client_with_mock_socket(messages)
@ -147,10 +149,10 @@ class TestDirtyStreamIterator:
def test_stream_iterator_with_for_loop(self): def test_stream_iterator_with_for_loop(self):
"""Test stream iterator works in for loop.""" """Test stream iterator works in for loop."""
messages = [ messages = [
make_chunk_message("req-123", "a"), make_chunk_message(123, "a"),
make_chunk_message("req-123", "b"), make_chunk_message(123, "b"),
make_chunk_message("req-123", "c"), make_chunk_message(123, "c"),
make_end_message("req-123"), make_end_message(123),
] ]
client = create_client_with_mock_socket(messages) client = create_client_with_mock_socket(messages)
@ -163,8 +165,8 @@ class TestDirtyStreamIterator:
def test_stream_sends_request_on_first_iteration(self): def test_stream_sends_request_on_first_iteration(self):
"""Test that request is sent on first next() call.""" """Test that request is sent on first next() call."""
messages = [ messages = [
make_chunk_message("req-123", "data"), make_chunk_message(123, "data"),
make_end_message("req-123"), make_end_message(123),
] ]
client = create_client_with_mock_socket(messages) client = create_client_with_mock_socket(messages)
@ -179,18 +181,15 @@ class TestDirtyStreamIterator:
# Decode sent request # Decode sent request
sent_data = client._sock._sent[0] sent_data = client._sock._sent[0]
length = struct.unpack( _, _, length = BinaryProtocol.decode_header(sent_data[:HEADER_SIZE])
DirtyProtocol.HEADER_FORMAT, msg_type_str, request_id, payload = BinaryProtocol.decode_message(
sent_data[:DirtyProtocol.HEADER_SIZE] sent_data[:HEADER_SIZE + length]
)[0]
request = DirtyProtocol.decode(
sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length]
) )
assert request["type"] == "request" assert msg_type_str == "request"
assert request["app_path"] == "test:App" assert payload["app_path"] == "test:App"
assert request["action"] == "generate" assert payload["action"] == "generate"
assert request["args"] == ["prompt_arg"] assert payload["args"] == ["prompt_arg"]
class TestDirtyStreamIteratorEdgeCases: class TestDirtyStreamIteratorEdgeCases:
@ -200,8 +199,8 @@ class TestDirtyStreamIteratorEdgeCases:
"""Test streaming with many chunks.""" """Test streaming with many chunks."""
messages = [] messages = []
for i in range(100): for i in range(100):
messages.append(make_chunk_message("req-123", f"chunk-{i}")) messages.append(make_chunk_message(123, f"chunk-{i}"))
messages.append(make_end_message("req-123")) messages.append(make_end_message(123))
client = create_client_with_mock_socket(messages) client = create_client_with_mock_socket(messages)
@ -214,8 +213,8 @@ class TestDirtyStreamIteratorEdgeCases:
def test_stream_with_kwargs(self): def test_stream_with_kwargs(self):
"""Test streaming with keyword arguments.""" """Test streaming with keyword arguments."""
messages = [ messages = [
make_chunk_message("req-123", "data"), make_chunk_message(123, "data"),
make_end_message("req-123"), make_end_message(123),
] ]
client = create_client_with_mock_socket(messages) client = create_client_with_mock_socket(messages)
@ -224,13 +223,10 @@ class TestDirtyStreamIteratorEdgeCases:
# Check the sent request includes kwargs # Check the sent request includes kwargs
sent_data = client._sock._sent[0] sent_data = client._sock._sent[0]
length = struct.unpack( _, _, length = BinaryProtocol.decode_header(sent_data[:HEADER_SIZE])
DirtyProtocol.HEADER_FORMAT, msg_type_str, request_id, payload = BinaryProtocol.decode_message(
sent_data[:DirtyProtocol.HEADER_SIZE] sent_data[:HEADER_SIZE + length]
)[0]
request = DirtyProtocol.decode(
sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length]
) )
assert request["args"] == ["arg1"] assert payload["args"] == ["arg1"]
assert request["kwargs"] == {"key": "value"} assert payload["kwargs"] == {"key": "value"}

View File

@ -10,9 +10,11 @@ import pytest
from gunicorn.dirty.protocol import ( from gunicorn.dirty.protocol import (
DirtyProtocol, DirtyProtocol,
BinaryProtocol,
make_chunk_message, make_chunk_message,
make_end_message, make_end_message,
make_error_response, make_error_response,
HEADER_SIZE,
) )
from gunicorn.dirty.client import DirtyClient, DirtyAsyncStreamIterator from gunicorn.dirty.client import DirtyClient, DirtyAsyncStreamIterator
from gunicorn.dirty.errors import DirtyError, DirtyTimeoutError from gunicorn.dirty.errors import DirtyError, DirtyTimeoutError
@ -24,7 +26,7 @@ class MockAsyncReader:
def __init__(self, messages): def __init__(self, messages):
self._data = b'' self._data = b''
for msg in messages: for msg in messages:
self._data += DirtyProtocol.encode(msg) self._data += BinaryProtocol._encode_from_dict(msg)
self._pos = 0 self._pos = 0
async def readexactly(self, n): async def readexactly(self, n):
@ -76,10 +78,10 @@ class TestDirtyAsyncStreamIterator:
async def test_async_stream_yields_chunks(self): async def test_async_stream_yields_chunks(self):
"""Test that async stream iterator yields chunks correctly.""" """Test that async stream iterator yields chunks correctly."""
messages = [ messages = [
make_chunk_message("req-123", "Hello"), make_chunk_message(123, "Hello"),
make_chunk_message("req-123", " "), make_chunk_message(123, " "),
make_chunk_message("req-123", "World"), make_chunk_message(123, "World"),
make_end_message("req-123"), make_end_message(123),
] ]
client = create_async_client_with_mocks(messages) client = create_async_client_with_mocks(messages)
@ -93,9 +95,9 @@ class TestDirtyAsyncStreamIterator:
async def test_async_stream_yields_complex_chunks(self): async def test_async_stream_yields_complex_chunks(self):
"""Test that async stream iterator yields complex data types.""" """Test that async stream iterator yields complex data types."""
messages = [ messages = [
make_chunk_message("req-123", {"token": "Hello", "score": 0.9}), make_chunk_message(123, {"token": "Hello", "score": 0.9}),
make_chunk_message("req-123", {"token": "World", "score": 0.8}), make_chunk_message(123, {"token": "World", "score": 0.8}),
make_end_message("req-123"), make_end_message(123),
] ]
client = create_async_client_with_mocks(messages) client = create_async_client_with_mocks(messages)
@ -111,8 +113,8 @@ class TestDirtyAsyncStreamIterator:
async def test_async_stream_handles_error(self): async def test_async_stream_handles_error(self):
"""Test that async stream iterator raises on error message.""" """Test that async stream iterator raises on error message."""
messages = [ messages = [
make_chunk_message("req-123", "First"), make_chunk_message(123, "First"),
make_error_response("req-123", DirtyError("Something broke")), make_error_response(123, DirtyError("Something broke")),
] ]
client = create_async_client_with_mocks(messages) client = create_async_client_with_mocks(messages)
@ -130,7 +132,7 @@ class TestDirtyAsyncStreamIterator:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_stream_empty_stream(self): async def test_async_stream_empty_stream(self):
"""Test that empty stream (just end) works.""" """Test that empty stream (just end) works."""
messages = [make_end_message("req-123")] messages = [make_end_message(123)]
client = create_async_client_with_mocks(messages) client = create_async_client_with_mocks(messages)
chunks = [] chunks = []
@ -143,8 +145,8 @@ class TestDirtyAsyncStreamIterator:
async def test_async_stream_stops_after_exhausted(self): async def test_async_stream_stops_after_exhausted(self):
"""Test that async iterator stays exhausted after StopAsyncIteration.""" """Test that async iterator stays exhausted after StopAsyncIteration."""
messages = [ messages = [
make_chunk_message("req-123", "Only"), make_chunk_message(123, "Only"),
make_end_message("req-123"), make_end_message(123),
] ]
client = create_async_client_with_mocks(messages) client = create_async_client_with_mocks(messages)
@ -166,8 +168,8 @@ class TestDirtyAsyncStreamIterator:
async def test_async_stream_sends_request_on_first_iteration(self): async def test_async_stream_sends_request_on_first_iteration(self):
"""Test that request is sent on first async iteration.""" """Test that request is sent on first async iteration."""
messages = [ messages = [
make_chunk_message("req-123", "data"), make_chunk_message(123, "data"),
make_end_message("req-123"), make_end_message(123),
] ]
client = create_async_client_with_mocks(messages) client = create_async_client_with_mocks(messages)
@ -182,18 +184,15 @@ class TestDirtyAsyncStreamIterator:
# Decode sent request # Decode sent request
sent_data = client._writer._sent[0] sent_data = client._writer._sent[0]
length = struct.unpack( _, _, length = BinaryProtocol.decode_header(sent_data[:HEADER_SIZE])
DirtyProtocol.HEADER_FORMAT, msg_type_str, request_id, payload = BinaryProtocol.decode_message(
sent_data[:DirtyProtocol.HEADER_SIZE] sent_data[:HEADER_SIZE + length]
)[0]
request = DirtyProtocol.decode(
sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length]
) )
assert request["type"] == "request" assert msg_type_str == "request"
assert request["app_path"] == "test:App" assert payload["app_path"] == "test:App"
assert request["action"] == "generate" assert payload["action"] == "generate"
assert request["args"] == ["prompt_arg"] assert payload["args"] == ["prompt_arg"]
class TestDirtyAsyncStreamIteratorEdgeCases: class TestDirtyAsyncStreamIteratorEdgeCases:
@ -204,8 +203,8 @@ class TestDirtyAsyncStreamIteratorEdgeCases:
"""Test async streaming with many chunks.""" """Test async streaming with many chunks."""
messages = [] messages = []
for i in range(100): for i in range(100):
messages.append(make_chunk_message("req-123", f"chunk-{i}")) messages.append(make_chunk_message(123, f"chunk-{i}"))
messages.append(make_end_message("req-123")) messages.append(make_end_message(123))
client = create_async_client_with_mocks(messages) client = create_async_client_with_mocks(messages)
@ -221,8 +220,8 @@ class TestDirtyAsyncStreamIteratorEdgeCases:
async def test_async_stream_with_kwargs(self): async def test_async_stream_with_kwargs(self):
"""Test async streaming with keyword arguments.""" """Test async streaming with keyword arguments."""
messages = [ messages = [
make_chunk_message("req-123", "data"), make_chunk_message(123, "data"),
make_end_message("req-123"), make_end_message(123),
] ]
client = create_async_client_with_mocks(messages) client = create_async_client_with_mocks(messages)
@ -233,16 +232,13 @@ class TestDirtyAsyncStreamIteratorEdgeCases:
# Check the sent request includes kwargs # Check the sent request includes kwargs
sent_data = client._writer._sent[0] sent_data = client._writer._sent[0]
length = struct.unpack( _, _, length = BinaryProtocol.decode_header(sent_data[:HEADER_SIZE])
DirtyProtocol.HEADER_FORMAT, msg_type_str, request_id, payload = BinaryProtocol.decode_message(
sent_data[:DirtyProtocol.HEADER_SIZE] sent_data[:HEADER_SIZE + length]
)[0]
request = DirtyProtocol.decode(
sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length]
) )
assert request["args"] == ["arg1"] assert payload["args"] == ["arg1"]
assert request["kwargs"] == {"key": "value"} assert payload["kwargs"] == {"key": "value"}
class TestDirtyAsyncStreamTimeout: class TestDirtyAsyncStreamTimeout:

View File

@ -19,7 +19,12 @@ from concurrent.futures import ThreadPoolExecutor
from gunicorn.config import Config from gunicorn.config import Config
from gunicorn.dirty.worker import DirtyWorker from gunicorn.dirty.worker import DirtyWorker
from gunicorn.dirty.arbiter import DirtyArbiter from gunicorn.dirty.arbiter import DirtyArbiter
from gunicorn.dirty.protocol import DirtyProtocol, make_request from gunicorn.dirty.protocol import (
DirtyProtocol,
BinaryProtocol,
make_request,
HEADER_SIZE,
)
from gunicorn.dirty.errors import DirtyAppNotFoundError from gunicorn.dirty.errors import DirtyAppNotFoundError
@ -71,16 +76,22 @@ class MockStreamWriter:
self._buffer += data self._buffer += data
async def drain(self): async def drain(self):
while len(self._buffer) >= DirtyProtocol.HEADER_SIZE: # Decode the buffer to extract messages using binary protocol
length = struct.unpack( while len(self._buffer) >= HEADER_SIZE:
DirtyProtocol.HEADER_FORMAT, # Decode header to get payload length
self._buffer[:DirtyProtocol.HEADER_SIZE] _, _, length = BinaryProtocol.decode_header(
)[0] self._buffer[:HEADER_SIZE]
total_size = DirtyProtocol.HEADER_SIZE + length )
total_size = HEADER_SIZE + length
if len(self._buffer) >= total_size: if len(self._buffer) >= total_size:
msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size] msg_data = self._buffer[:total_size]
self._buffer = self._buffer[total_size:] self._buffer = self._buffer[total_size:]
self.messages.append(DirtyProtocol.decode(msg_data)) # decode_message returns (msg_type_str, request_id, payload_dict)
msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data)
# Reconstruct the dict format for backwards compatibility
result = {"type": msg_type_str, "id": request_id}
result.update(payload_dict)
self.messages.append(result)
else: else:
break break

View File

@ -18,11 +18,13 @@ from unittest import mock
from gunicorn.config import Config from gunicorn.config import Config
from gunicorn.dirty.protocol import ( from gunicorn.dirty.protocol import (
DirtyProtocol, DirtyProtocol,
BinaryProtocol,
make_request, make_request,
make_chunk_message, make_chunk_message,
make_end_message, make_end_message,
make_response, make_response,
make_error_response, make_error_response,
HEADER_SIZE,
) )
from gunicorn.dirty.worker import DirtyWorker from gunicorn.dirty.worker import DirtyWorker
from gunicorn.dirty.arbiter import DirtyArbiter from gunicorn.dirty.arbiter import DirtyArbiter
@ -67,16 +69,22 @@ class MockStreamWriter:
self._buffer += data self._buffer += data
async def drain(self): async def drain(self):
while len(self._buffer) >= DirtyProtocol.HEADER_SIZE: # Decode the buffer to extract messages using binary protocol
length = struct.unpack( while len(self._buffer) >= HEADER_SIZE:
DirtyProtocol.HEADER_FORMAT, # Decode header to get payload length
self._buffer[:DirtyProtocol.HEADER_SIZE] _, _, length = BinaryProtocol.decode_header(
)[0] self._buffer[:HEADER_SIZE]
total_size = DirtyProtocol.HEADER_SIZE + length )
total_size = HEADER_SIZE + length
if len(self._buffer) >= total_size: if len(self._buffer) >= total_size:
msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size] msg_data = self._buffer[:total_size]
self._buffer = self._buffer[total_size:] self._buffer = self._buffer[total_size:]
self.messages.append(DirtyProtocol.decode(msg_data)) # decode_message returns (msg_type_str, request_id, payload_dict)
msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data)
# Reconstruct the dict format for backwards compatibility
result = {"type": msg_type_str, "id": request_id}
result.update(payload_dict)
self.messages.append(result)
else: else:
break break
@ -96,7 +104,7 @@ class MockStreamReader:
def __init__(self, messages): def __init__(self, messages):
self._data = b'' self._data = b''
for msg in messages: for msg in messages:
self._data += DirtyProtocol.encode(msg) self._data += BinaryProtocol._encode_from_dict(msg)
self._pos = 0 self._pos = 0
async def readexactly(self, n): async def readexactly(self, n):
@ -115,10 +123,10 @@ class TestStreamingEndToEnd:
"""Test complete flow: sync generator -> worker -> arbiter -> client.""" """Test complete flow: sync generator -> worker -> arbiter -> client."""
# Simulate what a worker would produce for a sync generator # Simulate what a worker would produce for a sync generator
worker_messages = [ worker_messages = [
make_chunk_message("req-123", "Hello"), make_chunk_message(123, "Hello"),
make_chunk_message("req-123", " "), make_chunk_message(123, " "),
make_chunk_message("req-123", "World"), make_chunk_message(123, "World"),
make_end_message("req-123"), make_end_message(123),
] ]
# Create an arbiter with mocked worker connection # Create an arbiter with mocked worker connection
@ -141,7 +149,7 @@ class TestStreamingEndToEnd:
client_writer = MockStreamWriter() client_writer = MockStreamWriter()
# Execute request through arbiter # Execute request through arbiter
request = make_request("req-123", "test:App", "generate") request = make_request(123, "test:App", "generate")
await arbiter._execute_on_worker(1234, request, client_writer) await arbiter._execute_on_worker(1234, request, client_writer)
# Verify all messages were forwarded # Verify all messages were forwarded
@ -158,10 +166,10 @@ class TestStreamingEndToEnd:
async def test_async_generator_end_to_end(self): async def test_async_generator_end_to_end(self):
"""Test complete flow: async generator -> worker -> arbiter -> client.""" """Test complete flow: async generator -> worker -> arbiter -> client."""
worker_messages = [ worker_messages = [
make_chunk_message("req-456", "Async"), make_chunk_message(456, "Async"),
make_chunk_message("req-456", " "), make_chunk_message(456, " "),
make_chunk_message("req-456", "Stream"), make_chunk_message(456, "Stream"),
make_end_message("req-456"), make_end_message(456),
] ]
cfg = Config() cfg = Config()
@ -180,7 +188,7 @@ class TestStreamingEndToEnd:
client_writer = MockStreamWriter() client_writer = MockStreamWriter()
request = make_request("req-456", "test:App", "async_generate") request = make_request(456, "test:App", "async_generate")
await arbiter._execute_on_worker(1234, request, client_writer) await arbiter._execute_on_worker(1234, request, client_writer)
assert len(client_writer.messages) == 4 assert len(client_writer.messages) == 4
@ -197,9 +205,9 @@ class TestStreamingErrorHandling:
async def test_error_mid_stream(self): async def test_error_mid_stream(self):
"""Test that errors during streaming are properly forwarded.""" """Test that errors during streaming are properly forwarded."""
worker_messages = [ worker_messages = [
make_chunk_message("req-789", "First"), make_chunk_message(789, "First"),
make_chunk_message("req-789", "Second"), make_chunk_message(789, "Second"),
make_error_response("req-789", DirtyError("Stream failed")), make_error_response(789, DirtyError("Stream failed")),
] ]
cfg = Config() cfg = Config()
@ -218,7 +226,7 @@ class TestStreamingErrorHandling:
client_writer = MockStreamWriter() client_writer = MockStreamWriter()
request = make_request("req-789", "test:App", "generate_with_error") request = make_request(789, "test:App", "generate_with_error")
await arbiter._execute_on_worker(1234, request, client_writer) await arbiter._execute_on_worker(1234, request, client_writer)
# Should have 2 chunks + 1 error # Should have 2 chunks + 1 error
@ -335,7 +343,7 @@ class TestStreamingWorkerIntegration:
return sync_gen() return sync_gen()
with mock.patch.object(worker, 'execute', side_effect=mock_execute): with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "generate") request = make_request(123, "test:App", "generate")
await worker.handle_request(request, writer) await worker.handle_request(request, writer)
# Should have 3 chunks + 1 end # Should have 3 chunks + 1 end
@ -377,7 +385,7 @@ class TestStreamingWorkerIntegration:
return async_gen() return async_gen()
with mock.patch.object(worker, 'execute', side_effect=mock_execute): with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-456", "test:App", "async_generate") request = make_request(456, "test:App", "async_generate")
await worker.handle_request(request, writer) await worker.handle_request(request, writer)
# Should have 2 chunks + 1 end # Should have 2 chunks + 1 end

View File

@ -12,9 +12,11 @@ import pytest
from gunicorn.dirty.protocol import ( from gunicorn.dirty.protocol import (
DirtyProtocol, DirtyProtocol,
BinaryProtocol,
make_request, make_request,
make_chunk_message, make_chunk_message,
make_end_message, make_end_message,
HEADER_SIZE,
) )
from gunicorn.dirty.worker import DirtyWorker from gunicorn.dirty.worker import DirtyWorker
@ -30,17 +32,22 @@ class FakeStreamWriter:
self._buffer += data self._buffer += data
async def drain(self): async def drain(self):
# Decode the buffer to extract messages # Decode the buffer to extract messages using binary protocol
while len(self._buffer) >= DirtyProtocol.HEADER_SIZE: while len(self._buffer) >= HEADER_SIZE:
length = struct.unpack( # Decode header to get payload length
DirtyProtocol.HEADER_FORMAT, _, _, length = BinaryProtocol.decode_header(
self._buffer[:DirtyProtocol.HEADER_SIZE] self._buffer[:HEADER_SIZE]
)[0] )
total_size = DirtyProtocol.HEADER_SIZE + length total_size = HEADER_SIZE + length
if len(self._buffer) >= total_size: if len(self._buffer) >= total_size:
msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size] msg_data = self._buffer[:total_size]
self._buffer = self._buffer[total_size:] self._buffer = self._buffer[total_size:]
self.messages.append(DirtyProtocol.decode(msg_data)) # decode_message returns (msg_type_str, request_id, payload_dict)
msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data)
# Reconstruct the dict format for backwards compatibility
result = {"type": msg_type_str, "id": request_id}
result.update(payload_dict)
self.messages.append(result)
else: else:
break break
@ -101,7 +108,7 @@ class TestWorkerSyncGeneratorStreaming:
return generate_tokens() return generate_tokens()
with mock.patch.object(worker, 'execute', side_effect=mock_execute): with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "generate") request = make_request(123, "test:App", "generate")
await worker.handle_request(request, writer) await worker.handle_request(request, writer)
# Should have 3 chunks + 1 end message # Should have 3 chunks + 1 end message
@ -109,7 +116,7 @@ class TestWorkerSyncGeneratorStreaming:
# Check chunk messages # Check chunk messages
assert writer.messages[0]["type"] == "chunk" assert writer.messages[0]["type"] == "chunk"
assert writer.messages[0]["id"] == "req-123" assert writer.messages[0]["id"] == 123
assert writer.messages[0]["data"] == "Hello" assert writer.messages[0]["data"] == "Hello"
assert writer.messages[1]["type"] == "chunk" assert writer.messages[1]["type"] == "chunk"
@ -120,7 +127,7 @@ class TestWorkerSyncGeneratorStreaming:
# Check end message # Check end message
assert writer.messages[3]["type"] == "end" assert writer.messages[3]["type"] == "end"
assert writer.messages[3]["id"] == "req-123" assert writer.messages[3]["id"] == 123
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_sync_generator_error_mid_stream(self): async def test_sync_generator_error_mid_stream(self):
@ -136,7 +143,7 @@ class TestWorkerSyncGeneratorStreaming:
return generate_with_error() return generate_with_error()
with mock.patch.object(worker, 'execute', side_effect=mock_execute): with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "generate") request = make_request(123, "test:App", "generate")
await worker.handle_request(request, writer) await worker.handle_request(request, writer)
# Should have 1 chunk + 1 error message # Should have 1 chunk + 1 error message
@ -167,7 +174,7 @@ class TestWorkerAsyncGeneratorStreaming:
return async_generate_tokens() return async_generate_tokens()
with mock.patch.object(worker, 'execute', side_effect=mock_execute): with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "generate") request = make_request(123, "test:App", "generate")
await worker.handle_request(request, writer) await worker.handle_request(request, writer)
# Should have 3 chunks + 1 end message # Should have 3 chunks + 1 end message
@ -175,7 +182,7 @@ class TestWorkerAsyncGeneratorStreaming:
# Check chunk messages # Check chunk messages
assert writer.messages[0]["type"] == "chunk" assert writer.messages[0]["type"] == "chunk"
assert writer.messages[0]["id"] == "req-123" assert writer.messages[0]["id"] == 123
assert writer.messages[0]["data"] == "Hello" assert writer.messages[0]["data"] == "Hello"
assert writer.messages[1]["type"] == "chunk" assert writer.messages[1]["type"] == "chunk"
@ -186,7 +193,7 @@ class TestWorkerAsyncGeneratorStreaming:
# Check end message # Check end message
assert writer.messages[3]["type"] == "end" assert writer.messages[3]["type"] == "end"
assert writer.messages[3]["id"] == "req-123" assert writer.messages[3]["id"] == 123
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_generator_error_mid_stream(self): async def test_async_generator_error_mid_stream(self):
@ -202,7 +209,7 @@ class TestWorkerAsyncGeneratorStreaming:
return async_generate_with_error() return async_generate_with_error()
with mock.patch.object(worker, 'execute', side_effect=mock_execute): with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "generate") request = make_request(123, "test:App", "generate")
await worker.handle_request(request, writer) await worker.handle_request(request, writer)
# Should have 1 chunk + 1 error message # Should have 1 chunk + 1 error message
@ -228,13 +235,13 @@ class TestWorkerNonStreamingBackwardCompat:
return args[0] + args[1] return args[0] + args[1]
with mock.patch.object(worker, 'execute', side_effect=mock_execute): with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "compute", args=(2, 3)) request = make_request(123, "test:App", "compute", args=(2, 3))
await worker.handle_request(request, writer) await worker.handle_request(request, writer)
# Should have 1 response message # Should have 1 response message
assert len(writer.messages) == 1 assert len(writer.messages) == 1
assert writer.messages[0]["type"] == "response" assert writer.messages[0]["type"] == "response"
assert writer.messages[0]["id"] == "req-123" assert writer.messages[0]["id"] == 123
assert writer.messages[0]["result"] == 5 assert writer.messages[0]["result"] == 5
@pytest.mark.asyncio @pytest.mark.asyncio
@ -247,7 +254,7 @@ class TestWorkerNonStreamingBackwardCompat:
return [1, 2, 3, 4, 5] return [1, 2, 3, 4, 5]
with mock.patch.object(worker, 'execute', side_effect=mock_execute): with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "get_list") request = make_request(123, "test:App", "get_list")
await worker.handle_request(request, writer) await worker.handle_request(request, writer)
# Should have 1 response message (not 5 chunks) # Should have 1 response message (not 5 chunks)
@ -265,7 +272,7 @@ class TestWorkerNonStreamingBackwardCompat:
raise RuntimeError("Failed!") raise RuntimeError("Failed!")
with mock.patch.object(worker, 'execute', side_effect=mock_execute): with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "fail") request = make_request(123, "test:App", "fail")
await worker.handle_request(request, writer) await worker.handle_request(request, writer)
# Should have 1 error message # Should have 1 error message
@ -283,7 +290,7 @@ class TestWorkerNonStreamingBackwardCompat:
return None return None
with mock.patch.object(worker, 'execute', side_effect=mock_execute): with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "void") request = make_request(123, "test:App", "void")
await worker.handle_request(request, writer) await worker.handle_request(request, writer)
# Should have 1 response message # Should have 1 response message
@ -309,7 +316,7 @@ class TestWorkerStreamingComplexData:
return generate_tokens() return generate_tokens()
with mock.patch.object(worker, 'execute', side_effect=mock_execute): with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "generate") request = make_request(123, "test:App", "generate")
await worker.handle_request(request, writer) await worker.handle_request(request, writer)
assert len(writer.messages) == 3 # 2 chunks + 1 end assert len(writer.messages) == 3 # 2 chunks + 1 end
@ -332,7 +339,7 @@ class TestWorkerStreamingComplexData:
return empty_generate() return empty_generate()
with mock.patch.object(worker, 'execute', side_effect=mock_execute): with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "generate") request = make_request(123, "test:App", "generate")
await worker.handle_request(request, writer) await worker.handle_request(request, writer)
# Should have just 1 end message # Should have just 1 end message
@ -353,7 +360,7 @@ class TestWorkerStreamingComplexData:
return generate_many() return generate_many()
with mock.patch.object(worker, 'execute', side_effect=mock_execute): with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "generate") request = make_request(123, "test:App", "generate")
await worker.handle_request(request, writer) await worker.handle_request(request, writer)
# Should have 100 chunks + 1 end message # Should have 100 chunks + 1 end message
@ -390,7 +397,7 @@ class TestWorkerStreamingHeartbeat:
return generate_tokens() return generate_tokens()
with mock.patch.object(worker, 'execute', side_effect=mock_execute): with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "generate") request = make_request(123, "test:App", "generate")
await worker.handle_request(request, writer) await worker.handle_request(request, writer)
# Should have been notified at least once per chunk + initial # Should have been notified at least once per chunk + initial
@ -407,7 +414,7 @@ class TestWorkerMessageTypeValidation:
writer = FakeStreamWriter() writer = FakeStreamWriter()
# Send a message with unknown type # Send a message with unknown type
message = {"type": "unknown", "id": "req-123"} message = {"type": "unknown", "id": 123}
await worker.handle_request(message, writer) await worker.handle_request(message, writer)
assert len(writer.messages) == 1 assert len(writer.messages) == 1

View File

@ -0,0 +1,21 @@
-----BEGIN CERTIFICATE-----
MIIDfDCCAmSgAwIBAgIUDxTarKRHe0FIyczGmoYwm377ZpcwDQYJKoZIhvcNAQEL
BQAwOTESMBAGA1UEAwwJbG9jYWxob3N0MRYwFAYDVQQKDA1HdW5pY29ybiBUZXN0
MQswCQYDVQQGEwJVUzAeFw0yNjAyMDUxMTE1MjJaFw0yNjAyMDYxMTE1MjJaMDkx
EjAQBgNVBAMMCWxvY2FsaG9zdDEWMBQGA1UECgwNR3VuaWNvcm4gVGVzdDELMAkG
A1UEBhMCVVMwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCRQTHakkqY
6l6dMqfs4oiY98+rjvZubwjp0PH7UBuxXCi/4Ao78o0JhKcs+jgAGAXyb8eRjEKt
z4rPoHZYE91D/eD0lWAz9r/LRoutDJd9IO0rfDtHlYXamciuxJJ8cckOrnuTXLtq
AWqjKR3U9RIDD3eumCKG4l7Py0L67zTomwMRPfeIdlWBfxGjWMqOewdTc/O/cuK2
HL5JP2ixy+iTufs0jhljI9cbu49J606f+TQH9eXRTD716q+KsHPJX1X5dVd7V7Lr
FIp7wSUFdbiy56JfmrGmfJbZgFH67P0ZyiTpQBaVHt1YYRIcOUJZqM+0MAtrsySC
TNA/LsI8tsybAgMBAAGjfDB6MB0GA1UdDgQWBBRK2VkAeM0hL4j/45ckkKbGrb/Q
FjAfBgNVHSMEGDAWgBRK2VkAeM0hL4j/45ckkKbGrb/QFjAPBgNVHRMBAf8EBTAD
AQH/MCcGA1UdEQQgMB6CCWxvY2FsaG9zdIILZ3VuaWNvcm4taDKHBH8AAAEwDQYJ
KoZIhvcNAQELBQADggEBAAXwuw0KTQUC4UEFudQ1rceK6By9WCSJND7xJi+UQ50G
Zrp5tJ2YB4ZWY+APadfuJo+zUxYVZ3jhs0mxgVeiGdDW6yZdHkeX8MlXBTLHR+/a
A7DXn6wCw9NDeDtcY/bKg5iamvoGGTL6szPrqeuZPz4UdbsFlr0MdcjgSNOqnkjr
YS4ukgZ71aWSjfraRRPjFMzkfnQ1xm96A1ngMH4DvU/t62D7r8+SvxQ8M6ERL84Z
FBu4bTXDdYIjJ24ojmDDO2irTVW1FMGXQTPzMaTEbE1rvBYeEYhf10KiMynK9xfO
5j8LWmCkgek0CqBrf3zbDEwu8QxcaxITAIUkSXLOZbo=
-----END CERTIFICATE-----

View File

@ -0,0 +1,28 @@
-----BEGIN PRIVATE KEY-----
MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQCRQTHakkqY6l6d
Mqfs4oiY98+rjvZubwjp0PH7UBuxXCi/4Ao78o0JhKcs+jgAGAXyb8eRjEKtz4rP
oHZYE91D/eD0lWAz9r/LRoutDJd9IO0rfDtHlYXamciuxJJ8cckOrnuTXLtqAWqj
KR3U9RIDD3eumCKG4l7Py0L67zTomwMRPfeIdlWBfxGjWMqOewdTc/O/cuK2HL5J
P2ixy+iTufs0jhljI9cbu49J606f+TQH9eXRTD716q+KsHPJX1X5dVd7V7LrFIp7
wSUFdbiy56JfmrGmfJbZgFH67P0ZyiTpQBaVHt1YYRIcOUJZqM+0MAtrsySCTNA/
LsI8tsybAgMBAAECggEANBhGOYZLI9G2sjlXOaG7bOU/wV9KKaw/7Z/HEaOW8wLD
CKHg+cQRai79yCdLi1kSVPNbB2vfBDhRqAp8NzWUn0x/8ChcsvZVriF0edFwyWtU
NErfddp+Absy2t9cTC6A9feFEYJqIug0JyVZciWc2qUi/ubIR0kLyQm00YuWFa/s
GJou8Nhg70rqW+3FB1H8kAEXqob+PFW4xbTwexw1+MbHxN7UKLTzS8uzYGLo2UpB
7bksumyD0o+lZtlx9HZ6CwrB6IPjgJ0HyaD8SrOc7/ozd7rR2LmvMmBCV1uC5VSO
jhr0PScLoNv60fjkVOiF9uqaPY2kNKymsOzpZ7/mwQKBgQDMcz+ve8WGGbE+bbM7
2uinQ5smm8rWPnfbHJIHQUetrEQKljRovybmjiiXN08uxlX6VA/Vnp4fmL5fzsTD
xTeiCVPsR1huXIfMLGJ6crUgvlbiaB8XsxtVNBpfEEtBe27qjSIj3xtmwqM6+LD1
FKLsYzgotHUH9JwyLA1RMKPBwQKBgQC14QWtI5YtZcTX46BqxlZ07iAAuy19Jywn
UtgmTawkJuEcseewIjxtJkMz+aSy7V3PsLII8tY48oSjAVx84w50zLJ2OlJnFT1S
zEmIOu9YDcGLZkYXJ2AwndRAIXpJVHwtFM9eDSMh+wVPBFeboYP1dO/VxmN6QV0W
GqDaQfItWwKBgEb31mp2n0j+UB0ofSfQxCOTfx62w4D87CPd1f64tUXe3zuBii21
9K3hOMvMwiqtZBjh5yEyzxaOsb6WCo0eP0J61GvXFCYy7lx8J67zdFYqXAR5OhnC
7UD1NhY7lLPlQcofNXOYNW3FMF3/B4X7JNbDVjIi+eDKExIDYpgFN0LBAoGADGCf
7kR5t+UxHDAVfq64u4RpESOr2NSNoK92nkSy7lLnBvjkd4wc6KCt+h+HIdYdiEDS
HOHJyl5WwHEbRjR9i11S19DoQrOjVLsqVecM2sU04rO3GWRIm4ZiJ2sf01W4jajY
4+Go/msC1XnKLIE1ZcLrf3Tc2DkSiKqPP8s1G/kCgYA8sCPAXedwhULhOBM45x4J
vkwT1Icm5RHOwOr8t34IFozTLokba6pjhYua3nE+V3FglRct7NpX+Op4gUgHa80g
5zoHboq5/pTUTclx41jndC1YGa3NLvthDWTWmyo/Qj7F/R7jGJf8E3KUDe0tFoSp
JlfEuUHtKpFJReBnmWTFiQ==
-----END PRIVATE KEY-----

View File

@ -11,7 +11,7 @@ import pytest
from gunicorn.arbiter import Arbiter from gunicorn.arbiter import Arbiter
from gunicorn.config import Config from gunicorn.config import Config
from gunicorn.app.base import BaseApplication from gunicorn.app.base import BaseApplication
from gunicorn.dirty.protocol import DirtyProtocol from gunicorn.dirty.protocol import DirtyProtocol, BinaryProtocol, HEADER_SIZE
class MockStreamWriter: class MockStreamWriter:
@ -26,16 +26,22 @@ class MockStreamWriter:
self._buffer += data self._buffer += data
async def drain(self): async def drain(self):
while len(self._buffer) >= DirtyProtocol.HEADER_SIZE: # Decode the buffer to extract messages using binary protocol
length = struct.unpack( while len(self._buffer) >= HEADER_SIZE:
DirtyProtocol.HEADER_FORMAT, # Decode header to get payload length
self._buffer[:DirtyProtocol.HEADER_SIZE] _, _, length = BinaryProtocol.decode_header(
)[0] self._buffer[:HEADER_SIZE]
total_size = DirtyProtocol.HEADER_SIZE + length )
total_size = HEADER_SIZE + length
if len(self._buffer) >= total_size: if len(self._buffer) >= total_size:
msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size] msg_data = self._buffer[:total_size]
self._buffer = self._buffer[total_size:] self._buffer = self._buffer[total_size:]
self.messages.append(DirtyProtocol.decode(msg_data)) # decode_message returns (msg_type_str, request_id, payload_dict)
msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data)
# Reconstruct the dict format for backwards compatibility
result = {"type": msg_type_str, "id": request_id}
result.update(payload_dict)
self.messages.append(result)
else: else:
break break