feat(dirty): update client for binary protocol

Update client and streaming tests to work with the binary protocol:
- Update MockStreamWriter/MockStreamReader to use BinaryProtocol
- Replace string request IDs with integers
- Update test assertions to decode binary protocol messages
- Use HEADER_SIZE and decode_header/decode_message instead of old API
This commit is contained in:
Benoit Chesneau 2026-02-11 23:12:44 +01:00
parent 98b1b649c2
commit 477b7479cc
9 changed files with 258 additions and 177 deletions

View File

@ -12,11 +12,13 @@ import pytest
from gunicorn.dirty.protocol import (
DirtyProtocol,
BinaryProtocol,
make_request,
make_response,
make_chunk_message,
make_end_message,
make_error_response,
HEADER_SIZE,
)
from gunicorn.dirty.arbiter import DirtyArbiter
from gunicorn.dirty.errors import DirtyError
@ -34,16 +36,22 @@ class MockStreamWriter:
self._buffer += data
async def drain(self):
while len(self._buffer) >= DirtyProtocol.HEADER_SIZE:
length = struct.unpack(
DirtyProtocol.HEADER_FORMAT,
self._buffer[:DirtyProtocol.HEADER_SIZE]
)[0]
total_size = DirtyProtocol.HEADER_SIZE + length
# Decode the buffer to extract messages using binary protocol
while len(self._buffer) >= HEADER_SIZE:
# Decode header to get payload length
_, _, length = BinaryProtocol.decode_header(
self._buffer[:HEADER_SIZE]
)
total_size = HEADER_SIZE + length
if len(self._buffer) >= total_size:
msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size]
msg_data = self._buffer[:total_size]
self._buffer = self._buffer[total_size:]
self.messages.append(DirtyProtocol.decode(msg_data))
# decode_message returns (msg_type_str, request_id, payload_dict)
msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data)
# Reconstruct the dict format for backwards compatibility
result = {"type": msg_type_str, "id": request_id}
result.update(payload_dict)
self.messages.append(result)
else:
break
@ -63,7 +71,7 @@ class MockStreamReader:
def __init__(self, messages):
self._data = b''
for msg in messages:
self._data += DirtyProtocol.encode(msg)
self._data += BinaryProtocol._encode_from_dict(msg)
self._pos = 0
async def readexactly(self, n):
@ -107,9 +115,9 @@ class TestArbiterStreamingForwarding:
client_writer = MockStreamWriter()
# Mock worker connection that returns chunks
chunk1 = make_chunk_message("req-123", "Hello")
chunk2 = make_chunk_message("req-123", " World")
end = make_end_message("req-123")
chunk1 = make_chunk_message(123, "Hello")
chunk2 = make_chunk_message(123, " World")
end = make_end_message(123)
mock_reader = MockStreamReader([chunk1, chunk2, end])
@ -118,7 +126,7 @@ class TestArbiterStreamingForwarding:
arbiter._get_worker_connection = mock_get_connection
request = make_request("req-123", "test:App", "generate")
request = make_request(123, "test:App", "generate")
await arbiter._execute_on_worker(1234, request, client_writer)
# Should have forwarded all messages
@ -135,7 +143,7 @@ class TestArbiterStreamingForwarding:
arbiter = create_arbiter()
client_writer = MockStreamWriter()
response = make_response("req-123", {"result": 42})
response = make_response(123, {"result": 42})
mock_reader = MockStreamReader([response])
async def mock_get_connection(pid):
@ -143,7 +151,7 @@ class TestArbiterStreamingForwarding:
arbiter._get_worker_connection = mock_get_connection
request = make_request("req-123", "test:App", "compute")
request = make_request(123, "test:App", "compute")
await arbiter._execute_on_worker(1234, request, client_writer)
assert len(client_writer.messages) == 1
@ -156,8 +164,8 @@ class TestArbiterStreamingForwarding:
arbiter = create_arbiter()
client_writer = MockStreamWriter()
chunk = make_chunk_message("req-123", "First")
error = make_error_response("req-123", DirtyError("Something broke"))
chunk = make_chunk_message(123, "First")
error = make_error_response(123, DirtyError("Something broke"))
mock_reader = MockStreamReader([chunk, error])
@ -166,7 +174,7 @@ class TestArbiterStreamingForwarding:
arbiter._get_worker_connection = mock_get_connection
request = make_request("req-123", "test:App", "generate")
request = make_request(123, "test:App", "generate")
await arbiter._execute_on_worker(1234, request, client_writer)
assert len(client_writer.messages) == 2
@ -190,7 +198,7 @@ class TestArbiterStreamingForwarding:
arbiter._get_worker_connection = mock_get_connection
request = make_request("req-123", "test:App", "generate")
request = make_request(123, "test:App", "generate")
await arbiter._execute_on_worker(1234, request, client_writer)
assert len(client_writer.messages) == 1
@ -208,7 +216,7 @@ class TestArbiterRouteRequestStreaming:
arbiter.workers = {} # No workers
client_writer = MockStreamWriter()
request = make_request("req-123", "test:App", "generate")
request = make_request(123, "test:App", "generate")
await arbiter.route_request(request, client_writer)
assert len(client_writer.messages) == 1
@ -222,13 +230,13 @@ class TestArbiterRouteRequestStreaming:
# Mock _execute_on_worker to complete immediately
async def mock_execute(pid, request, client_writer):
response = make_response("req-123", "result")
response = make_response(123, "result")
await DirtyProtocol.write_message_async(client_writer, response)
arbiter._execute_on_worker = mock_execute
client_writer = MockStreamWriter()
request = make_request("req-123", "test:App", "compute")
request = make_request(123, "test:App", "compute")
# Worker queue should be created
assert 1234 not in arbiter.worker_queues
@ -255,8 +263,8 @@ class TestArbiterStreamingManyChunks:
# Generate 50 chunks + end
messages = []
for i in range(50):
messages.append(make_chunk_message("req-123", f"chunk-{i}"))
messages.append(make_end_message("req-123"))
messages.append(make_chunk_message(123, f"chunk-{i}"))
messages.append(make_end_message(123))
mock_reader = MockStreamReader(messages)
@ -265,7 +273,7 @@ class TestArbiterStreamingManyChunks:
arbiter._get_worker_connection = mock_get_connection
request = make_request("req-123", "test:App", "generate")
request = make_request(123, "test:App", "generate")
await arbiter._execute_on_worker(1234, request, client_writer)
assert len(client_writer.messages) == 51
@ -283,7 +291,7 @@ class TestArbiterBackwardCompatibility:
arbiter = create_arbiter()
client_writer = MockStreamWriter()
response = make_response("req-123", [1, 2, 3, 4, 5])
response = make_response(123, [1, 2, 3, 4, 5])
mock_reader = MockStreamReader([response])
async def mock_get_connection(pid):
@ -291,7 +299,7 @@ class TestArbiterBackwardCompatibility:
arbiter._get_worker_connection = mock_get_connection
request = make_request("req-123", "test:App", "get_list")
request = make_request(123, "test:App", "get_list")
await arbiter._execute_on_worker(1234, request, client_writer)
assert len(client_writer.messages) == 1
@ -304,7 +312,7 @@ class TestArbiterBackwardCompatibility:
arbiter = create_arbiter()
client_writer = MockStreamWriter()
error = make_error_response("req-123", DirtyError("Something failed"))
error = make_error_response(123, DirtyError("Something failed"))
mock_reader = MockStreamReader([error])
async def mock_get_connection(pid):
@ -312,7 +320,7 @@ class TestArbiterBackwardCompatibility:
arbiter._get_worker_connection = mock_get_connection
request = make_request("req-123", "test:App", "fail")
request = make_request(123, "test:App", "fail")
await arbiter._execute_on_worker(1234, request, client_writer)
assert len(client_writer.messages) == 1

View File

@ -11,10 +11,12 @@ from unittest import mock
from gunicorn.dirty.protocol import (
DirtyProtocol,
BinaryProtocol,
make_chunk_message,
make_end_message,
make_response,
make_error_response,
HEADER_SIZE,
)
from gunicorn.dirty.client import DirtyClient, DirtyStreamIterator
from gunicorn.dirty.errors import DirtyError, DirtyConnectionError
@ -26,7 +28,7 @@ class MockSocket:
def __init__(self, messages):
self._data = b''
for msg in messages:
self._data += DirtyProtocol.encode(msg)
self._data += BinaryProtocol._encode_from_dict(msg)
self._pos = 0
self._sent = []
self.closed = False
@ -69,10 +71,10 @@ class TestDirtyStreamIterator:
def test_stream_iterator_yields_chunks(self):
"""Test that stream iterator yields chunks correctly."""
messages = [
make_chunk_message("req-123", "Hello"),
make_chunk_message("req-123", " "),
make_chunk_message("req-123", "World"),
make_end_message("req-123"),
make_chunk_message(123, "Hello"),
make_chunk_message(123, " "),
make_chunk_message(123, "World"),
make_end_message(123),
]
client = create_client_with_mock_socket(messages)
@ -83,9 +85,9 @@ class TestDirtyStreamIterator:
def test_stream_iterator_yields_complex_chunks(self):
"""Test that stream iterator yields complex data types."""
messages = [
make_chunk_message("req-123", {"token": "Hello", "score": 0.9}),
make_chunk_message("req-123", {"token": "World", "score": 0.8}),
make_end_message("req-123"),
make_chunk_message(123, {"token": "Hello", "score": 0.9}),
make_chunk_message(123, {"token": "World", "score": 0.8}),
make_end_message(123),
]
client = create_client_with_mock_socket(messages)
@ -98,8 +100,8 @@ class TestDirtyStreamIterator:
def test_stream_iterator_handles_error(self):
"""Test that stream iterator raises on error message."""
messages = [
make_chunk_message("req-123", "First"),
make_error_response("req-123", DirtyError("Something broke")),
make_chunk_message(123, "First"),
make_error_response(123, DirtyError("Something broke")),
]
client = create_client_with_mock_socket(messages)
@ -116,7 +118,7 @@ class TestDirtyStreamIterator:
def test_stream_iterator_empty_stream(self):
"""Test that empty stream (just end) works."""
messages = [make_end_message("req-123")]
messages = [make_end_message(123)]
client = create_client_with_mock_socket(messages)
chunks = list(client.stream("test:App", "generate"))
@ -125,8 +127,8 @@ class TestDirtyStreamIterator:
def test_stream_iterator_stops_after_exhausted(self):
"""Test that iterator stays exhausted after StopIteration."""
messages = [
make_chunk_message("req-123", "Only"),
make_end_message("req-123"),
make_chunk_message(123, "Only"),
make_end_message(123),
]
client = create_client_with_mock_socket(messages)
@ -147,10 +149,10 @@ class TestDirtyStreamIterator:
def test_stream_iterator_with_for_loop(self):
"""Test stream iterator works in for loop."""
messages = [
make_chunk_message("req-123", "a"),
make_chunk_message("req-123", "b"),
make_chunk_message("req-123", "c"),
make_end_message("req-123"),
make_chunk_message(123, "a"),
make_chunk_message(123, "b"),
make_chunk_message(123, "c"),
make_end_message(123),
]
client = create_client_with_mock_socket(messages)
@ -163,8 +165,8 @@ class TestDirtyStreamIterator:
def test_stream_sends_request_on_first_iteration(self):
"""Test that request is sent on first next() call."""
messages = [
make_chunk_message("req-123", "data"),
make_end_message("req-123"),
make_chunk_message(123, "data"),
make_end_message(123),
]
client = create_client_with_mock_socket(messages)
@ -179,18 +181,15 @@ class TestDirtyStreamIterator:
# Decode sent request
sent_data = client._sock._sent[0]
length = struct.unpack(
DirtyProtocol.HEADER_FORMAT,
sent_data[:DirtyProtocol.HEADER_SIZE]
)[0]
request = DirtyProtocol.decode(
sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length]
_, _, length = BinaryProtocol.decode_header(sent_data[:HEADER_SIZE])
msg_type_str, request_id, payload = BinaryProtocol.decode_message(
sent_data[:HEADER_SIZE + length]
)
assert request["type"] == "request"
assert request["app_path"] == "test:App"
assert request["action"] == "generate"
assert request["args"] == ["prompt_arg"]
assert msg_type_str == "request"
assert payload["app_path"] == "test:App"
assert payload["action"] == "generate"
assert payload["args"] == ["prompt_arg"]
class TestDirtyStreamIteratorEdgeCases:
@ -200,8 +199,8 @@ class TestDirtyStreamIteratorEdgeCases:
"""Test streaming with many chunks."""
messages = []
for i in range(100):
messages.append(make_chunk_message("req-123", f"chunk-{i}"))
messages.append(make_end_message("req-123"))
messages.append(make_chunk_message(123, f"chunk-{i}"))
messages.append(make_end_message(123))
client = create_client_with_mock_socket(messages)
@ -214,8 +213,8 @@ class TestDirtyStreamIteratorEdgeCases:
def test_stream_with_kwargs(self):
"""Test streaming with keyword arguments."""
messages = [
make_chunk_message("req-123", "data"),
make_end_message("req-123"),
make_chunk_message(123, "data"),
make_end_message(123),
]
client = create_client_with_mock_socket(messages)
@ -224,13 +223,10 @@ class TestDirtyStreamIteratorEdgeCases:
# Check the sent request includes kwargs
sent_data = client._sock._sent[0]
length = struct.unpack(
DirtyProtocol.HEADER_FORMAT,
sent_data[:DirtyProtocol.HEADER_SIZE]
)[0]
request = DirtyProtocol.decode(
sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length]
_, _, length = BinaryProtocol.decode_header(sent_data[:HEADER_SIZE])
msg_type_str, request_id, payload = BinaryProtocol.decode_message(
sent_data[:HEADER_SIZE + length]
)
assert request["args"] == ["arg1"]
assert request["kwargs"] == {"key": "value"}
assert payload["args"] == ["arg1"]
assert payload["kwargs"] == {"key": "value"}

View File

@ -10,9 +10,11 @@ import pytest
from gunicorn.dirty.protocol import (
DirtyProtocol,
BinaryProtocol,
make_chunk_message,
make_end_message,
make_error_response,
HEADER_SIZE,
)
from gunicorn.dirty.client import DirtyClient, DirtyAsyncStreamIterator
from gunicorn.dirty.errors import DirtyError, DirtyTimeoutError
@ -24,7 +26,7 @@ class MockAsyncReader:
def __init__(self, messages):
self._data = b''
for msg in messages:
self._data += DirtyProtocol.encode(msg)
self._data += BinaryProtocol._encode_from_dict(msg)
self._pos = 0
async def readexactly(self, n):
@ -76,10 +78,10 @@ class TestDirtyAsyncStreamIterator:
async def test_async_stream_yields_chunks(self):
"""Test that async stream iterator yields chunks correctly."""
messages = [
make_chunk_message("req-123", "Hello"),
make_chunk_message("req-123", " "),
make_chunk_message("req-123", "World"),
make_end_message("req-123"),
make_chunk_message(123, "Hello"),
make_chunk_message(123, " "),
make_chunk_message(123, "World"),
make_end_message(123),
]
client = create_async_client_with_mocks(messages)
@ -93,9 +95,9 @@ class TestDirtyAsyncStreamIterator:
async def test_async_stream_yields_complex_chunks(self):
"""Test that async stream iterator yields complex data types."""
messages = [
make_chunk_message("req-123", {"token": "Hello", "score": 0.9}),
make_chunk_message("req-123", {"token": "World", "score": 0.8}),
make_end_message("req-123"),
make_chunk_message(123, {"token": "Hello", "score": 0.9}),
make_chunk_message(123, {"token": "World", "score": 0.8}),
make_end_message(123),
]
client = create_async_client_with_mocks(messages)
@ -111,8 +113,8 @@ class TestDirtyAsyncStreamIterator:
async def test_async_stream_handles_error(self):
"""Test that async stream iterator raises on error message."""
messages = [
make_chunk_message("req-123", "First"),
make_error_response("req-123", DirtyError("Something broke")),
make_chunk_message(123, "First"),
make_error_response(123, DirtyError("Something broke")),
]
client = create_async_client_with_mocks(messages)
@ -130,7 +132,7 @@ class TestDirtyAsyncStreamIterator:
@pytest.mark.asyncio
async def test_async_stream_empty_stream(self):
"""Test that empty stream (just end) works."""
messages = [make_end_message("req-123")]
messages = [make_end_message(123)]
client = create_async_client_with_mocks(messages)
chunks = []
@ -143,8 +145,8 @@ class TestDirtyAsyncStreamIterator:
async def test_async_stream_stops_after_exhausted(self):
"""Test that async iterator stays exhausted after StopAsyncIteration."""
messages = [
make_chunk_message("req-123", "Only"),
make_end_message("req-123"),
make_chunk_message(123, "Only"),
make_end_message(123),
]
client = create_async_client_with_mocks(messages)
@ -166,8 +168,8 @@ class TestDirtyAsyncStreamIterator:
async def test_async_stream_sends_request_on_first_iteration(self):
"""Test that request is sent on first async iteration."""
messages = [
make_chunk_message("req-123", "data"),
make_end_message("req-123"),
make_chunk_message(123, "data"),
make_end_message(123),
]
client = create_async_client_with_mocks(messages)
@ -182,18 +184,15 @@ class TestDirtyAsyncStreamIterator:
# Decode sent request
sent_data = client._writer._sent[0]
length = struct.unpack(
DirtyProtocol.HEADER_FORMAT,
sent_data[:DirtyProtocol.HEADER_SIZE]
)[0]
request = DirtyProtocol.decode(
sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length]
_, _, length = BinaryProtocol.decode_header(sent_data[:HEADER_SIZE])
msg_type_str, request_id, payload = BinaryProtocol.decode_message(
sent_data[:HEADER_SIZE + length]
)
assert request["type"] == "request"
assert request["app_path"] == "test:App"
assert request["action"] == "generate"
assert request["args"] == ["prompt_arg"]
assert msg_type_str == "request"
assert payload["app_path"] == "test:App"
assert payload["action"] == "generate"
assert payload["args"] == ["prompt_arg"]
class TestDirtyAsyncStreamIteratorEdgeCases:
@ -204,8 +203,8 @@ class TestDirtyAsyncStreamIteratorEdgeCases:
"""Test async streaming with many chunks."""
messages = []
for i in range(100):
messages.append(make_chunk_message("req-123", f"chunk-{i}"))
messages.append(make_end_message("req-123"))
messages.append(make_chunk_message(123, f"chunk-{i}"))
messages.append(make_end_message(123))
client = create_async_client_with_mocks(messages)
@ -221,8 +220,8 @@ class TestDirtyAsyncStreamIteratorEdgeCases:
async def test_async_stream_with_kwargs(self):
"""Test async streaming with keyword arguments."""
messages = [
make_chunk_message("req-123", "data"),
make_end_message("req-123"),
make_chunk_message(123, "data"),
make_end_message(123),
]
client = create_async_client_with_mocks(messages)
@ -233,16 +232,13 @@ class TestDirtyAsyncStreamIteratorEdgeCases:
# Check the sent request includes kwargs
sent_data = client._writer._sent[0]
length = struct.unpack(
DirtyProtocol.HEADER_FORMAT,
sent_data[:DirtyProtocol.HEADER_SIZE]
)[0]
request = DirtyProtocol.decode(
sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length]
_, _, length = BinaryProtocol.decode_header(sent_data[:HEADER_SIZE])
msg_type_str, request_id, payload = BinaryProtocol.decode_message(
sent_data[:HEADER_SIZE + length]
)
assert request["args"] == ["arg1"]
assert request["kwargs"] == {"key": "value"}
assert payload["args"] == ["arg1"]
assert payload["kwargs"] == {"key": "value"}
class TestDirtyAsyncStreamTimeout:

View File

@ -19,7 +19,12 @@ from concurrent.futures import ThreadPoolExecutor
from gunicorn.config import Config
from gunicorn.dirty.worker import DirtyWorker
from gunicorn.dirty.arbiter import DirtyArbiter
from gunicorn.dirty.protocol import DirtyProtocol, make_request
from gunicorn.dirty.protocol import (
DirtyProtocol,
BinaryProtocol,
make_request,
HEADER_SIZE,
)
from gunicorn.dirty.errors import DirtyAppNotFoundError
@ -71,16 +76,22 @@ class MockStreamWriter:
self._buffer += data
async def drain(self):
while len(self._buffer) >= DirtyProtocol.HEADER_SIZE:
length = struct.unpack(
DirtyProtocol.HEADER_FORMAT,
self._buffer[:DirtyProtocol.HEADER_SIZE]
)[0]
total_size = DirtyProtocol.HEADER_SIZE + length
# Decode the buffer to extract messages using binary protocol
while len(self._buffer) >= HEADER_SIZE:
# Decode header to get payload length
_, _, length = BinaryProtocol.decode_header(
self._buffer[:HEADER_SIZE]
)
total_size = HEADER_SIZE + length
if len(self._buffer) >= total_size:
msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size]
msg_data = self._buffer[:total_size]
self._buffer = self._buffer[total_size:]
self.messages.append(DirtyProtocol.decode(msg_data))
# decode_message returns (msg_type_str, request_id, payload_dict)
msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data)
# Reconstruct the dict format for backwards compatibility
result = {"type": msg_type_str, "id": request_id}
result.update(payload_dict)
self.messages.append(result)
else:
break

View File

@ -18,11 +18,13 @@ from unittest import mock
from gunicorn.config import Config
from gunicorn.dirty.protocol import (
DirtyProtocol,
BinaryProtocol,
make_request,
make_chunk_message,
make_end_message,
make_response,
make_error_response,
HEADER_SIZE,
)
from gunicorn.dirty.worker import DirtyWorker
from gunicorn.dirty.arbiter import DirtyArbiter
@ -67,16 +69,22 @@ class MockStreamWriter:
self._buffer += data
async def drain(self):
while len(self._buffer) >= DirtyProtocol.HEADER_SIZE:
length = struct.unpack(
DirtyProtocol.HEADER_FORMAT,
self._buffer[:DirtyProtocol.HEADER_SIZE]
)[0]
total_size = DirtyProtocol.HEADER_SIZE + length
# Decode the buffer to extract messages using binary protocol
while len(self._buffer) >= HEADER_SIZE:
# Decode header to get payload length
_, _, length = BinaryProtocol.decode_header(
self._buffer[:HEADER_SIZE]
)
total_size = HEADER_SIZE + length
if len(self._buffer) >= total_size:
msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size]
msg_data = self._buffer[:total_size]
self._buffer = self._buffer[total_size:]
self.messages.append(DirtyProtocol.decode(msg_data))
# decode_message returns (msg_type_str, request_id, payload_dict)
msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data)
# Reconstruct the dict format for backwards compatibility
result = {"type": msg_type_str, "id": request_id}
result.update(payload_dict)
self.messages.append(result)
else:
break
@ -96,7 +104,7 @@ class MockStreamReader:
def __init__(self, messages):
self._data = b''
for msg in messages:
self._data += DirtyProtocol.encode(msg)
self._data += BinaryProtocol._encode_from_dict(msg)
self._pos = 0
async def readexactly(self, n):
@ -115,10 +123,10 @@ class TestStreamingEndToEnd:
"""Test complete flow: sync generator -> worker -> arbiter -> client."""
# Simulate what a worker would produce for a sync generator
worker_messages = [
make_chunk_message("req-123", "Hello"),
make_chunk_message("req-123", " "),
make_chunk_message("req-123", "World"),
make_end_message("req-123"),
make_chunk_message(123, "Hello"),
make_chunk_message(123, " "),
make_chunk_message(123, "World"),
make_end_message(123),
]
# Create an arbiter with mocked worker connection
@ -141,7 +149,7 @@ class TestStreamingEndToEnd:
client_writer = MockStreamWriter()
# Execute request through arbiter
request = make_request("req-123", "test:App", "generate")
request = make_request(123, "test:App", "generate")
await arbiter._execute_on_worker(1234, request, client_writer)
# Verify all messages were forwarded
@ -158,10 +166,10 @@ class TestStreamingEndToEnd:
async def test_async_generator_end_to_end(self):
"""Test complete flow: async generator -> worker -> arbiter -> client."""
worker_messages = [
make_chunk_message("req-456", "Async"),
make_chunk_message("req-456", " "),
make_chunk_message("req-456", "Stream"),
make_end_message("req-456"),
make_chunk_message(456, "Async"),
make_chunk_message(456, " "),
make_chunk_message(456, "Stream"),
make_end_message(456),
]
cfg = Config()
@ -180,7 +188,7 @@ class TestStreamingEndToEnd:
client_writer = MockStreamWriter()
request = make_request("req-456", "test:App", "async_generate")
request = make_request(456, "test:App", "async_generate")
await arbiter._execute_on_worker(1234, request, client_writer)
assert len(client_writer.messages) == 4
@ -197,9 +205,9 @@ class TestStreamingErrorHandling:
async def test_error_mid_stream(self):
"""Test that errors during streaming are properly forwarded."""
worker_messages = [
make_chunk_message("req-789", "First"),
make_chunk_message("req-789", "Second"),
make_error_response("req-789", DirtyError("Stream failed")),
make_chunk_message(789, "First"),
make_chunk_message(789, "Second"),
make_error_response(789, DirtyError("Stream failed")),
]
cfg = Config()
@ -218,7 +226,7 @@ class TestStreamingErrorHandling:
client_writer = MockStreamWriter()
request = make_request("req-789", "test:App", "generate_with_error")
request = make_request(789, "test:App", "generate_with_error")
await arbiter._execute_on_worker(1234, request, client_writer)
# Should have 2 chunks + 1 error
@ -335,7 +343,7 @@ class TestStreamingWorkerIntegration:
return sync_gen()
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "generate")
request = make_request(123, "test:App", "generate")
await worker.handle_request(request, writer)
# Should have 3 chunks + 1 end
@ -377,7 +385,7 @@ class TestStreamingWorkerIntegration:
return async_gen()
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-456", "test:App", "async_generate")
request = make_request(456, "test:App", "async_generate")
await worker.handle_request(request, writer)
# Should have 2 chunks + 1 end

View File

@ -12,9 +12,11 @@ import pytest
from gunicorn.dirty.protocol import (
DirtyProtocol,
BinaryProtocol,
make_request,
make_chunk_message,
make_end_message,
HEADER_SIZE,
)
from gunicorn.dirty.worker import DirtyWorker
@ -30,17 +32,22 @@ class FakeStreamWriter:
self._buffer += data
async def drain(self):
# Decode the buffer to extract messages
while len(self._buffer) >= DirtyProtocol.HEADER_SIZE:
length = struct.unpack(
DirtyProtocol.HEADER_FORMAT,
self._buffer[:DirtyProtocol.HEADER_SIZE]
)[0]
total_size = DirtyProtocol.HEADER_SIZE + length
# Decode the buffer to extract messages using binary protocol
while len(self._buffer) >= HEADER_SIZE:
# Decode header to get payload length
_, _, length = BinaryProtocol.decode_header(
self._buffer[:HEADER_SIZE]
)
total_size = HEADER_SIZE + length
if len(self._buffer) >= total_size:
msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size]
msg_data = self._buffer[:total_size]
self._buffer = self._buffer[total_size:]
self.messages.append(DirtyProtocol.decode(msg_data))
# decode_message returns (msg_type_str, request_id, payload_dict)
msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data)
# Reconstruct the dict format for backwards compatibility
result = {"type": msg_type_str, "id": request_id}
result.update(payload_dict)
self.messages.append(result)
else:
break
@ -101,7 +108,7 @@ class TestWorkerSyncGeneratorStreaming:
return generate_tokens()
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "generate")
request = make_request(123, "test:App", "generate")
await worker.handle_request(request, writer)
# Should have 3 chunks + 1 end message
@ -109,7 +116,7 @@ class TestWorkerSyncGeneratorStreaming:
# Check chunk messages
assert writer.messages[0]["type"] == "chunk"
assert writer.messages[0]["id"] == "req-123"
assert writer.messages[0]["id"] == 123
assert writer.messages[0]["data"] == "Hello"
assert writer.messages[1]["type"] == "chunk"
@ -120,7 +127,7 @@ class TestWorkerSyncGeneratorStreaming:
# Check end message
assert writer.messages[3]["type"] == "end"
assert writer.messages[3]["id"] == "req-123"
assert writer.messages[3]["id"] == 123
@pytest.mark.asyncio
async def test_sync_generator_error_mid_stream(self):
@ -136,7 +143,7 @@ class TestWorkerSyncGeneratorStreaming:
return generate_with_error()
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "generate")
request = make_request(123, "test:App", "generate")
await worker.handle_request(request, writer)
# Should have 1 chunk + 1 error message
@ -167,7 +174,7 @@ class TestWorkerAsyncGeneratorStreaming:
return async_generate_tokens()
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "generate")
request = make_request(123, "test:App", "generate")
await worker.handle_request(request, writer)
# Should have 3 chunks + 1 end message
@ -175,7 +182,7 @@ class TestWorkerAsyncGeneratorStreaming:
# Check chunk messages
assert writer.messages[0]["type"] == "chunk"
assert writer.messages[0]["id"] == "req-123"
assert writer.messages[0]["id"] == 123
assert writer.messages[0]["data"] == "Hello"
assert writer.messages[1]["type"] == "chunk"
@ -186,7 +193,7 @@ class TestWorkerAsyncGeneratorStreaming:
# Check end message
assert writer.messages[3]["type"] == "end"
assert writer.messages[3]["id"] == "req-123"
assert writer.messages[3]["id"] == 123
@pytest.mark.asyncio
async def test_async_generator_error_mid_stream(self):
@ -202,7 +209,7 @@ class TestWorkerAsyncGeneratorStreaming:
return async_generate_with_error()
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "generate")
request = make_request(123, "test:App", "generate")
await worker.handle_request(request, writer)
# Should have 1 chunk + 1 error message
@ -228,13 +235,13 @@ class TestWorkerNonStreamingBackwardCompat:
return args[0] + args[1]
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "compute", args=(2, 3))
request = make_request(123, "test:App", "compute", args=(2, 3))
await worker.handle_request(request, writer)
# Should have 1 response message
assert len(writer.messages) == 1
assert writer.messages[0]["type"] == "response"
assert writer.messages[0]["id"] == "req-123"
assert writer.messages[0]["id"] == 123
assert writer.messages[0]["result"] == 5
@pytest.mark.asyncio
@ -247,7 +254,7 @@ class TestWorkerNonStreamingBackwardCompat:
return [1, 2, 3, 4, 5]
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "get_list")
request = make_request(123, "test:App", "get_list")
await worker.handle_request(request, writer)
# Should have 1 response message (not 5 chunks)
@ -265,7 +272,7 @@ class TestWorkerNonStreamingBackwardCompat:
raise RuntimeError("Failed!")
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "fail")
request = make_request(123, "test:App", "fail")
await worker.handle_request(request, writer)
# Should have 1 error message
@ -283,7 +290,7 @@ class TestWorkerNonStreamingBackwardCompat:
return None
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "void")
request = make_request(123, "test:App", "void")
await worker.handle_request(request, writer)
# Should have 1 response message
@ -309,7 +316,7 @@ class TestWorkerStreamingComplexData:
return generate_tokens()
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "generate")
request = make_request(123, "test:App", "generate")
await worker.handle_request(request, writer)
assert len(writer.messages) == 3 # 2 chunks + 1 end
@ -332,7 +339,7 @@ class TestWorkerStreamingComplexData:
return empty_generate()
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "generate")
request = make_request(123, "test:App", "generate")
await worker.handle_request(request, writer)
# Should have just 1 end message
@ -353,7 +360,7 @@ class TestWorkerStreamingComplexData:
return generate_many()
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "generate")
request = make_request(123, "test:App", "generate")
await worker.handle_request(request, writer)
# Should have 100 chunks + 1 end message
@ -390,7 +397,7 @@ class TestWorkerStreamingHeartbeat:
return generate_tokens()
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "generate")
request = make_request(123, "test:App", "generate")
await worker.handle_request(request, writer)
# Should have been notified at least once per chunk + initial
@ -407,7 +414,7 @@ class TestWorkerMessageTypeValidation:
writer = FakeStreamWriter()
# Send a message with unknown type
message = {"type": "unknown", "id": "req-123"}
message = {"type": "unknown", "id": 123}
await worker.handle_request(message, writer)
assert len(writer.messages) == 1

View File

@ -0,0 +1,21 @@
-----BEGIN CERTIFICATE-----
MIIDfDCCAmSgAwIBAgIUDxTarKRHe0FIyczGmoYwm377ZpcwDQYJKoZIhvcNAQEL
BQAwOTESMBAGA1UEAwwJbG9jYWxob3N0MRYwFAYDVQQKDA1HdW5pY29ybiBUZXN0
MQswCQYDVQQGEwJVUzAeFw0yNjAyMDUxMTE1MjJaFw0yNjAyMDYxMTE1MjJaMDkx
EjAQBgNVBAMMCWxvY2FsaG9zdDEWMBQGA1UECgwNR3VuaWNvcm4gVGVzdDELMAkG
A1UEBhMCVVMwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCRQTHakkqY
6l6dMqfs4oiY98+rjvZubwjp0PH7UBuxXCi/4Ao78o0JhKcs+jgAGAXyb8eRjEKt
z4rPoHZYE91D/eD0lWAz9r/LRoutDJd9IO0rfDtHlYXamciuxJJ8cckOrnuTXLtq
AWqjKR3U9RIDD3eumCKG4l7Py0L67zTomwMRPfeIdlWBfxGjWMqOewdTc/O/cuK2
HL5JP2ixy+iTufs0jhljI9cbu49J606f+TQH9eXRTD716q+KsHPJX1X5dVd7V7Lr
FIp7wSUFdbiy56JfmrGmfJbZgFH67P0ZyiTpQBaVHt1YYRIcOUJZqM+0MAtrsySC
TNA/LsI8tsybAgMBAAGjfDB6MB0GA1UdDgQWBBRK2VkAeM0hL4j/45ckkKbGrb/Q
FjAfBgNVHSMEGDAWgBRK2VkAeM0hL4j/45ckkKbGrb/QFjAPBgNVHRMBAf8EBTAD
AQH/MCcGA1UdEQQgMB6CCWxvY2FsaG9zdIILZ3VuaWNvcm4taDKHBH8AAAEwDQYJ
KoZIhvcNAQELBQADggEBAAXwuw0KTQUC4UEFudQ1rceK6By9WCSJND7xJi+UQ50G
Zrp5tJ2YB4ZWY+APadfuJo+zUxYVZ3jhs0mxgVeiGdDW6yZdHkeX8MlXBTLHR+/a
A7DXn6wCw9NDeDtcY/bKg5iamvoGGTL6szPrqeuZPz4UdbsFlr0MdcjgSNOqnkjr
YS4ukgZ71aWSjfraRRPjFMzkfnQ1xm96A1ngMH4DvU/t62D7r8+SvxQ8M6ERL84Z
FBu4bTXDdYIjJ24ojmDDO2irTVW1FMGXQTPzMaTEbE1rvBYeEYhf10KiMynK9xfO
5j8LWmCkgek0CqBrf3zbDEwu8QxcaxITAIUkSXLOZbo=
-----END CERTIFICATE-----

View File

@ -0,0 +1,28 @@
-----BEGIN PRIVATE KEY-----
MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQCRQTHakkqY6l6d
Mqfs4oiY98+rjvZubwjp0PH7UBuxXCi/4Ao78o0JhKcs+jgAGAXyb8eRjEKtz4rP
oHZYE91D/eD0lWAz9r/LRoutDJd9IO0rfDtHlYXamciuxJJ8cckOrnuTXLtqAWqj
KR3U9RIDD3eumCKG4l7Py0L67zTomwMRPfeIdlWBfxGjWMqOewdTc/O/cuK2HL5J
P2ixy+iTufs0jhljI9cbu49J606f+TQH9eXRTD716q+KsHPJX1X5dVd7V7LrFIp7
wSUFdbiy56JfmrGmfJbZgFH67P0ZyiTpQBaVHt1YYRIcOUJZqM+0MAtrsySCTNA/
LsI8tsybAgMBAAECggEANBhGOYZLI9G2sjlXOaG7bOU/wV9KKaw/7Z/HEaOW8wLD
CKHg+cQRai79yCdLi1kSVPNbB2vfBDhRqAp8NzWUn0x/8ChcsvZVriF0edFwyWtU
NErfddp+Absy2t9cTC6A9feFEYJqIug0JyVZciWc2qUi/ubIR0kLyQm00YuWFa/s
GJou8Nhg70rqW+3FB1H8kAEXqob+PFW4xbTwexw1+MbHxN7UKLTzS8uzYGLo2UpB
7bksumyD0o+lZtlx9HZ6CwrB6IPjgJ0HyaD8SrOc7/ozd7rR2LmvMmBCV1uC5VSO
jhr0PScLoNv60fjkVOiF9uqaPY2kNKymsOzpZ7/mwQKBgQDMcz+ve8WGGbE+bbM7
2uinQ5smm8rWPnfbHJIHQUetrEQKljRovybmjiiXN08uxlX6VA/Vnp4fmL5fzsTD
xTeiCVPsR1huXIfMLGJ6crUgvlbiaB8XsxtVNBpfEEtBe27qjSIj3xtmwqM6+LD1
FKLsYzgotHUH9JwyLA1RMKPBwQKBgQC14QWtI5YtZcTX46BqxlZ07iAAuy19Jywn
UtgmTawkJuEcseewIjxtJkMz+aSy7V3PsLII8tY48oSjAVx84w50zLJ2OlJnFT1S
zEmIOu9YDcGLZkYXJ2AwndRAIXpJVHwtFM9eDSMh+wVPBFeboYP1dO/VxmN6QV0W
GqDaQfItWwKBgEb31mp2n0j+UB0ofSfQxCOTfx62w4D87CPd1f64tUXe3zuBii21
9K3hOMvMwiqtZBjh5yEyzxaOsb6WCo0eP0J61GvXFCYy7lx8J67zdFYqXAR5OhnC
7UD1NhY7lLPlQcofNXOYNW3FMF3/B4X7JNbDVjIi+eDKExIDYpgFN0LBAoGADGCf
7kR5t+UxHDAVfq64u4RpESOr2NSNoK92nkSy7lLnBvjkd4wc6KCt+h+HIdYdiEDS
HOHJyl5WwHEbRjR9i11S19DoQrOjVLsqVecM2sU04rO3GWRIm4ZiJ2sf01W4jajY
4+Go/msC1XnKLIE1ZcLrf3Tc2DkSiKqPP8s1G/kCgYA8sCPAXedwhULhOBM45x4J
vkwT1Icm5RHOwOr8t34IFozTLokba6pjhYua3nE+V3FglRct7NpX+Op4gUgHa80g
5zoHboq5/pTUTclx41jndC1YGa3NLvthDWTWmyo/Qj7F/R7jGJf8E3KUDe0tFoSp
JlfEuUHtKpFJReBnmWTFiQ==
-----END PRIVATE KEY-----

View File

@ -11,7 +11,7 @@ import pytest
from gunicorn.arbiter import Arbiter
from gunicorn.config import Config
from gunicorn.app.base import BaseApplication
from gunicorn.dirty.protocol import DirtyProtocol
from gunicorn.dirty.protocol import DirtyProtocol, BinaryProtocol, HEADER_SIZE
class MockStreamWriter:
@ -26,16 +26,22 @@ class MockStreamWriter:
self._buffer += data
async def drain(self):
while len(self._buffer) >= DirtyProtocol.HEADER_SIZE:
length = struct.unpack(
DirtyProtocol.HEADER_FORMAT,
self._buffer[:DirtyProtocol.HEADER_SIZE]
)[0]
total_size = DirtyProtocol.HEADER_SIZE + length
# Decode the buffer to extract messages using binary protocol
while len(self._buffer) >= HEADER_SIZE:
# Decode header to get payload length
_, _, length = BinaryProtocol.decode_header(
self._buffer[:HEADER_SIZE]
)
total_size = HEADER_SIZE + length
if len(self._buffer) >= total_size:
msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size]
msg_data = self._buffer[:total_size]
self._buffer = self._buffer[total_size:]
self.messages.append(DirtyProtocol.decode(msg_data))
# decode_message returns (msg_type_str, request_id, payload_dict)
msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data)
# Reconstruct the dict format for backwards compatibility
result = {"type": msg_type_str, "id": request_id}
result.update(payload_dict)
self.messages.append(result)
else:
break