mirror of
https://github.com/frappe/gunicorn.git
synced 2026-07-01 10:11:30 +08:00
feat(dirty): update client for binary protocol
Update client and streaming tests to work with the binary protocol: - Update MockStreamWriter/MockStreamReader to use BinaryProtocol - Replace string request IDs with integers - Update test assertions to decode binary protocol messages - Use HEADER_SIZE and decode_header/decode_message instead of old API
This commit is contained in:
parent
98b1b649c2
commit
477b7479cc
@ -12,11 +12,13 @@ import pytest
|
||||
|
||||
from gunicorn.dirty.protocol import (
|
||||
DirtyProtocol,
|
||||
BinaryProtocol,
|
||||
make_request,
|
||||
make_response,
|
||||
make_chunk_message,
|
||||
make_end_message,
|
||||
make_error_response,
|
||||
HEADER_SIZE,
|
||||
)
|
||||
from gunicorn.dirty.arbiter import DirtyArbiter
|
||||
from gunicorn.dirty.errors import DirtyError
|
||||
@ -34,16 +36,22 @@ class MockStreamWriter:
|
||||
self._buffer += data
|
||||
|
||||
async def drain(self):
|
||||
while len(self._buffer) >= DirtyProtocol.HEADER_SIZE:
|
||||
length = struct.unpack(
|
||||
DirtyProtocol.HEADER_FORMAT,
|
||||
self._buffer[:DirtyProtocol.HEADER_SIZE]
|
||||
)[0]
|
||||
total_size = DirtyProtocol.HEADER_SIZE + length
|
||||
# Decode the buffer to extract messages using binary protocol
|
||||
while len(self._buffer) >= HEADER_SIZE:
|
||||
# Decode header to get payload length
|
||||
_, _, length = BinaryProtocol.decode_header(
|
||||
self._buffer[:HEADER_SIZE]
|
||||
)
|
||||
total_size = HEADER_SIZE + length
|
||||
if len(self._buffer) >= total_size:
|
||||
msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size]
|
||||
msg_data = self._buffer[:total_size]
|
||||
self._buffer = self._buffer[total_size:]
|
||||
self.messages.append(DirtyProtocol.decode(msg_data))
|
||||
# decode_message returns (msg_type_str, request_id, payload_dict)
|
||||
msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data)
|
||||
# Reconstruct the dict format for backwards compatibility
|
||||
result = {"type": msg_type_str, "id": request_id}
|
||||
result.update(payload_dict)
|
||||
self.messages.append(result)
|
||||
else:
|
||||
break
|
||||
|
||||
@ -63,7 +71,7 @@ class MockStreamReader:
|
||||
def __init__(self, messages):
|
||||
self._data = b''
|
||||
for msg in messages:
|
||||
self._data += DirtyProtocol.encode(msg)
|
||||
self._data += BinaryProtocol._encode_from_dict(msg)
|
||||
self._pos = 0
|
||||
|
||||
async def readexactly(self, n):
|
||||
@ -107,9 +115,9 @@ class TestArbiterStreamingForwarding:
|
||||
client_writer = MockStreamWriter()
|
||||
|
||||
# Mock worker connection that returns chunks
|
||||
chunk1 = make_chunk_message("req-123", "Hello")
|
||||
chunk2 = make_chunk_message("req-123", " World")
|
||||
end = make_end_message("req-123")
|
||||
chunk1 = make_chunk_message(123, "Hello")
|
||||
chunk2 = make_chunk_message(123, " World")
|
||||
end = make_end_message(123)
|
||||
|
||||
mock_reader = MockStreamReader([chunk1, chunk2, end])
|
||||
|
||||
@ -118,7 +126,7 @@ class TestArbiterStreamingForwarding:
|
||||
|
||||
arbiter._get_worker_connection = mock_get_connection
|
||||
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
request = make_request(123, "test:App", "generate")
|
||||
await arbiter._execute_on_worker(1234, request, client_writer)
|
||||
|
||||
# Should have forwarded all messages
|
||||
@ -135,7 +143,7 @@ class TestArbiterStreamingForwarding:
|
||||
arbiter = create_arbiter()
|
||||
client_writer = MockStreamWriter()
|
||||
|
||||
response = make_response("req-123", {"result": 42})
|
||||
response = make_response(123, {"result": 42})
|
||||
mock_reader = MockStreamReader([response])
|
||||
|
||||
async def mock_get_connection(pid):
|
||||
@ -143,7 +151,7 @@ class TestArbiterStreamingForwarding:
|
||||
|
||||
arbiter._get_worker_connection = mock_get_connection
|
||||
|
||||
request = make_request("req-123", "test:App", "compute")
|
||||
request = make_request(123, "test:App", "compute")
|
||||
await arbiter._execute_on_worker(1234, request, client_writer)
|
||||
|
||||
assert len(client_writer.messages) == 1
|
||||
@ -156,8 +164,8 @@ class TestArbiterStreamingForwarding:
|
||||
arbiter = create_arbiter()
|
||||
client_writer = MockStreamWriter()
|
||||
|
||||
chunk = make_chunk_message("req-123", "First")
|
||||
error = make_error_response("req-123", DirtyError("Something broke"))
|
||||
chunk = make_chunk_message(123, "First")
|
||||
error = make_error_response(123, DirtyError("Something broke"))
|
||||
|
||||
mock_reader = MockStreamReader([chunk, error])
|
||||
|
||||
@ -166,7 +174,7 @@ class TestArbiterStreamingForwarding:
|
||||
|
||||
arbiter._get_worker_connection = mock_get_connection
|
||||
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
request = make_request(123, "test:App", "generate")
|
||||
await arbiter._execute_on_worker(1234, request, client_writer)
|
||||
|
||||
assert len(client_writer.messages) == 2
|
||||
@ -190,7 +198,7 @@ class TestArbiterStreamingForwarding:
|
||||
|
||||
arbiter._get_worker_connection = mock_get_connection
|
||||
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
request = make_request(123, "test:App", "generate")
|
||||
await arbiter._execute_on_worker(1234, request, client_writer)
|
||||
|
||||
assert len(client_writer.messages) == 1
|
||||
@ -208,7 +216,7 @@ class TestArbiterRouteRequestStreaming:
|
||||
arbiter.workers = {} # No workers
|
||||
client_writer = MockStreamWriter()
|
||||
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
request = make_request(123, "test:App", "generate")
|
||||
await arbiter.route_request(request, client_writer)
|
||||
|
||||
assert len(client_writer.messages) == 1
|
||||
@ -222,13 +230,13 @@ class TestArbiterRouteRequestStreaming:
|
||||
|
||||
# Mock _execute_on_worker to complete immediately
|
||||
async def mock_execute(pid, request, client_writer):
|
||||
response = make_response("req-123", "result")
|
||||
response = make_response(123, "result")
|
||||
await DirtyProtocol.write_message_async(client_writer, response)
|
||||
|
||||
arbiter._execute_on_worker = mock_execute
|
||||
|
||||
client_writer = MockStreamWriter()
|
||||
request = make_request("req-123", "test:App", "compute")
|
||||
request = make_request(123, "test:App", "compute")
|
||||
|
||||
# Worker queue should be created
|
||||
assert 1234 not in arbiter.worker_queues
|
||||
@ -255,8 +263,8 @@ class TestArbiterStreamingManyChunks:
|
||||
# Generate 50 chunks + end
|
||||
messages = []
|
||||
for i in range(50):
|
||||
messages.append(make_chunk_message("req-123", f"chunk-{i}"))
|
||||
messages.append(make_end_message("req-123"))
|
||||
messages.append(make_chunk_message(123, f"chunk-{i}"))
|
||||
messages.append(make_end_message(123))
|
||||
|
||||
mock_reader = MockStreamReader(messages)
|
||||
|
||||
@ -265,7 +273,7 @@ class TestArbiterStreamingManyChunks:
|
||||
|
||||
arbiter._get_worker_connection = mock_get_connection
|
||||
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
request = make_request(123, "test:App", "generate")
|
||||
await arbiter._execute_on_worker(1234, request, client_writer)
|
||||
|
||||
assert len(client_writer.messages) == 51
|
||||
@ -283,7 +291,7 @@ class TestArbiterBackwardCompatibility:
|
||||
arbiter = create_arbiter()
|
||||
client_writer = MockStreamWriter()
|
||||
|
||||
response = make_response("req-123", [1, 2, 3, 4, 5])
|
||||
response = make_response(123, [1, 2, 3, 4, 5])
|
||||
mock_reader = MockStreamReader([response])
|
||||
|
||||
async def mock_get_connection(pid):
|
||||
@ -291,7 +299,7 @@ class TestArbiterBackwardCompatibility:
|
||||
|
||||
arbiter._get_worker_connection = mock_get_connection
|
||||
|
||||
request = make_request("req-123", "test:App", "get_list")
|
||||
request = make_request(123, "test:App", "get_list")
|
||||
await arbiter._execute_on_worker(1234, request, client_writer)
|
||||
|
||||
assert len(client_writer.messages) == 1
|
||||
@ -304,7 +312,7 @@ class TestArbiterBackwardCompatibility:
|
||||
arbiter = create_arbiter()
|
||||
client_writer = MockStreamWriter()
|
||||
|
||||
error = make_error_response("req-123", DirtyError("Something failed"))
|
||||
error = make_error_response(123, DirtyError("Something failed"))
|
||||
mock_reader = MockStreamReader([error])
|
||||
|
||||
async def mock_get_connection(pid):
|
||||
@ -312,7 +320,7 @@ class TestArbiterBackwardCompatibility:
|
||||
|
||||
arbiter._get_worker_connection = mock_get_connection
|
||||
|
||||
request = make_request("req-123", "test:App", "fail")
|
||||
request = make_request(123, "test:App", "fail")
|
||||
await arbiter._execute_on_worker(1234, request, client_writer)
|
||||
|
||||
assert len(client_writer.messages) == 1
|
||||
|
||||
@ -11,10 +11,12 @@ from unittest import mock
|
||||
|
||||
from gunicorn.dirty.protocol import (
|
||||
DirtyProtocol,
|
||||
BinaryProtocol,
|
||||
make_chunk_message,
|
||||
make_end_message,
|
||||
make_response,
|
||||
make_error_response,
|
||||
HEADER_SIZE,
|
||||
)
|
||||
from gunicorn.dirty.client import DirtyClient, DirtyStreamIterator
|
||||
from gunicorn.dirty.errors import DirtyError, DirtyConnectionError
|
||||
@ -26,7 +28,7 @@ class MockSocket:
|
||||
def __init__(self, messages):
|
||||
self._data = b''
|
||||
for msg in messages:
|
||||
self._data += DirtyProtocol.encode(msg)
|
||||
self._data += BinaryProtocol._encode_from_dict(msg)
|
||||
self._pos = 0
|
||||
self._sent = []
|
||||
self.closed = False
|
||||
@ -69,10 +71,10 @@ class TestDirtyStreamIterator:
|
||||
def test_stream_iterator_yields_chunks(self):
|
||||
"""Test that stream iterator yields chunks correctly."""
|
||||
messages = [
|
||||
make_chunk_message("req-123", "Hello"),
|
||||
make_chunk_message("req-123", " "),
|
||||
make_chunk_message("req-123", "World"),
|
||||
make_end_message("req-123"),
|
||||
make_chunk_message(123, "Hello"),
|
||||
make_chunk_message(123, " "),
|
||||
make_chunk_message(123, "World"),
|
||||
make_end_message(123),
|
||||
]
|
||||
client = create_client_with_mock_socket(messages)
|
||||
|
||||
@ -83,9 +85,9 @@ class TestDirtyStreamIterator:
|
||||
def test_stream_iterator_yields_complex_chunks(self):
|
||||
"""Test that stream iterator yields complex data types."""
|
||||
messages = [
|
||||
make_chunk_message("req-123", {"token": "Hello", "score": 0.9}),
|
||||
make_chunk_message("req-123", {"token": "World", "score": 0.8}),
|
||||
make_end_message("req-123"),
|
||||
make_chunk_message(123, {"token": "Hello", "score": 0.9}),
|
||||
make_chunk_message(123, {"token": "World", "score": 0.8}),
|
||||
make_end_message(123),
|
||||
]
|
||||
client = create_client_with_mock_socket(messages)
|
||||
|
||||
@ -98,8 +100,8 @@ class TestDirtyStreamIterator:
|
||||
def test_stream_iterator_handles_error(self):
|
||||
"""Test that stream iterator raises on error message."""
|
||||
messages = [
|
||||
make_chunk_message("req-123", "First"),
|
||||
make_error_response("req-123", DirtyError("Something broke")),
|
||||
make_chunk_message(123, "First"),
|
||||
make_error_response(123, DirtyError("Something broke")),
|
||||
]
|
||||
client = create_client_with_mock_socket(messages)
|
||||
|
||||
@ -116,7 +118,7 @@ class TestDirtyStreamIterator:
|
||||
|
||||
def test_stream_iterator_empty_stream(self):
|
||||
"""Test that empty stream (just end) works."""
|
||||
messages = [make_end_message("req-123")]
|
||||
messages = [make_end_message(123)]
|
||||
client = create_client_with_mock_socket(messages)
|
||||
|
||||
chunks = list(client.stream("test:App", "generate"))
|
||||
@ -125,8 +127,8 @@ class TestDirtyStreamIterator:
|
||||
def test_stream_iterator_stops_after_exhausted(self):
|
||||
"""Test that iterator stays exhausted after StopIteration."""
|
||||
messages = [
|
||||
make_chunk_message("req-123", "Only"),
|
||||
make_end_message("req-123"),
|
||||
make_chunk_message(123, "Only"),
|
||||
make_end_message(123),
|
||||
]
|
||||
client = create_client_with_mock_socket(messages)
|
||||
|
||||
@ -147,10 +149,10 @@ class TestDirtyStreamIterator:
|
||||
def test_stream_iterator_with_for_loop(self):
|
||||
"""Test stream iterator works in for loop."""
|
||||
messages = [
|
||||
make_chunk_message("req-123", "a"),
|
||||
make_chunk_message("req-123", "b"),
|
||||
make_chunk_message("req-123", "c"),
|
||||
make_end_message("req-123"),
|
||||
make_chunk_message(123, "a"),
|
||||
make_chunk_message(123, "b"),
|
||||
make_chunk_message(123, "c"),
|
||||
make_end_message(123),
|
||||
]
|
||||
client = create_client_with_mock_socket(messages)
|
||||
|
||||
@ -163,8 +165,8 @@ class TestDirtyStreamIterator:
|
||||
def test_stream_sends_request_on_first_iteration(self):
|
||||
"""Test that request is sent on first next() call."""
|
||||
messages = [
|
||||
make_chunk_message("req-123", "data"),
|
||||
make_end_message("req-123"),
|
||||
make_chunk_message(123, "data"),
|
||||
make_end_message(123),
|
||||
]
|
||||
client = create_client_with_mock_socket(messages)
|
||||
|
||||
@ -179,18 +181,15 @@ class TestDirtyStreamIterator:
|
||||
|
||||
# Decode sent request
|
||||
sent_data = client._sock._sent[0]
|
||||
length = struct.unpack(
|
||||
DirtyProtocol.HEADER_FORMAT,
|
||||
sent_data[:DirtyProtocol.HEADER_SIZE]
|
||||
)[0]
|
||||
request = DirtyProtocol.decode(
|
||||
sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length]
|
||||
_, _, length = BinaryProtocol.decode_header(sent_data[:HEADER_SIZE])
|
||||
msg_type_str, request_id, payload = BinaryProtocol.decode_message(
|
||||
sent_data[:HEADER_SIZE + length]
|
||||
)
|
||||
|
||||
assert request["type"] == "request"
|
||||
assert request["app_path"] == "test:App"
|
||||
assert request["action"] == "generate"
|
||||
assert request["args"] == ["prompt_arg"]
|
||||
assert msg_type_str == "request"
|
||||
assert payload["app_path"] == "test:App"
|
||||
assert payload["action"] == "generate"
|
||||
assert payload["args"] == ["prompt_arg"]
|
||||
|
||||
|
||||
class TestDirtyStreamIteratorEdgeCases:
|
||||
@ -200,8 +199,8 @@ class TestDirtyStreamIteratorEdgeCases:
|
||||
"""Test streaming with many chunks."""
|
||||
messages = []
|
||||
for i in range(100):
|
||||
messages.append(make_chunk_message("req-123", f"chunk-{i}"))
|
||||
messages.append(make_end_message("req-123"))
|
||||
messages.append(make_chunk_message(123, f"chunk-{i}"))
|
||||
messages.append(make_end_message(123))
|
||||
|
||||
client = create_client_with_mock_socket(messages)
|
||||
|
||||
@ -214,8 +213,8 @@ class TestDirtyStreamIteratorEdgeCases:
|
||||
def test_stream_with_kwargs(self):
|
||||
"""Test streaming with keyword arguments."""
|
||||
messages = [
|
||||
make_chunk_message("req-123", "data"),
|
||||
make_end_message("req-123"),
|
||||
make_chunk_message(123, "data"),
|
||||
make_end_message(123),
|
||||
]
|
||||
client = create_client_with_mock_socket(messages)
|
||||
|
||||
@ -224,13 +223,10 @@ class TestDirtyStreamIteratorEdgeCases:
|
||||
|
||||
# Check the sent request includes kwargs
|
||||
sent_data = client._sock._sent[0]
|
||||
length = struct.unpack(
|
||||
DirtyProtocol.HEADER_FORMAT,
|
||||
sent_data[:DirtyProtocol.HEADER_SIZE]
|
||||
)[0]
|
||||
request = DirtyProtocol.decode(
|
||||
sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length]
|
||||
_, _, length = BinaryProtocol.decode_header(sent_data[:HEADER_SIZE])
|
||||
msg_type_str, request_id, payload = BinaryProtocol.decode_message(
|
||||
sent_data[:HEADER_SIZE + length]
|
||||
)
|
||||
|
||||
assert request["args"] == ["arg1"]
|
||||
assert request["kwargs"] == {"key": "value"}
|
||||
assert payload["args"] == ["arg1"]
|
||||
assert payload["kwargs"] == {"key": "value"}
|
||||
|
||||
@ -10,9 +10,11 @@ import pytest
|
||||
|
||||
from gunicorn.dirty.protocol import (
|
||||
DirtyProtocol,
|
||||
BinaryProtocol,
|
||||
make_chunk_message,
|
||||
make_end_message,
|
||||
make_error_response,
|
||||
HEADER_SIZE,
|
||||
)
|
||||
from gunicorn.dirty.client import DirtyClient, DirtyAsyncStreamIterator
|
||||
from gunicorn.dirty.errors import DirtyError, DirtyTimeoutError
|
||||
@ -24,7 +26,7 @@ class MockAsyncReader:
|
||||
def __init__(self, messages):
|
||||
self._data = b''
|
||||
for msg in messages:
|
||||
self._data += DirtyProtocol.encode(msg)
|
||||
self._data += BinaryProtocol._encode_from_dict(msg)
|
||||
self._pos = 0
|
||||
|
||||
async def readexactly(self, n):
|
||||
@ -76,10 +78,10 @@ class TestDirtyAsyncStreamIterator:
|
||||
async def test_async_stream_yields_chunks(self):
|
||||
"""Test that async stream iterator yields chunks correctly."""
|
||||
messages = [
|
||||
make_chunk_message("req-123", "Hello"),
|
||||
make_chunk_message("req-123", " "),
|
||||
make_chunk_message("req-123", "World"),
|
||||
make_end_message("req-123"),
|
||||
make_chunk_message(123, "Hello"),
|
||||
make_chunk_message(123, " "),
|
||||
make_chunk_message(123, "World"),
|
||||
make_end_message(123),
|
||||
]
|
||||
client = create_async_client_with_mocks(messages)
|
||||
|
||||
@ -93,9 +95,9 @@ class TestDirtyAsyncStreamIterator:
|
||||
async def test_async_stream_yields_complex_chunks(self):
|
||||
"""Test that async stream iterator yields complex data types."""
|
||||
messages = [
|
||||
make_chunk_message("req-123", {"token": "Hello", "score": 0.9}),
|
||||
make_chunk_message("req-123", {"token": "World", "score": 0.8}),
|
||||
make_end_message("req-123"),
|
||||
make_chunk_message(123, {"token": "Hello", "score": 0.9}),
|
||||
make_chunk_message(123, {"token": "World", "score": 0.8}),
|
||||
make_end_message(123),
|
||||
]
|
||||
client = create_async_client_with_mocks(messages)
|
||||
|
||||
@ -111,8 +113,8 @@ class TestDirtyAsyncStreamIterator:
|
||||
async def test_async_stream_handles_error(self):
|
||||
"""Test that async stream iterator raises on error message."""
|
||||
messages = [
|
||||
make_chunk_message("req-123", "First"),
|
||||
make_error_response("req-123", DirtyError("Something broke")),
|
||||
make_chunk_message(123, "First"),
|
||||
make_error_response(123, DirtyError("Something broke")),
|
||||
]
|
||||
client = create_async_client_with_mocks(messages)
|
||||
|
||||
@ -130,7 +132,7 @@ class TestDirtyAsyncStreamIterator:
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_stream_empty_stream(self):
|
||||
"""Test that empty stream (just end) works."""
|
||||
messages = [make_end_message("req-123")]
|
||||
messages = [make_end_message(123)]
|
||||
client = create_async_client_with_mocks(messages)
|
||||
|
||||
chunks = []
|
||||
@ -143,8 +145,8 @@ class TestDirtyAsyncStreamIterator:
|
||||
async def test_async_stream_stops_after_exhausted(self):
|
||||
"""Test that async iterator stays exhausted after StopAsyncIteration."""
|
||||
messages = [
|
||||
make_chunk_message("req-123", "Only"),
|
||||
make_end_message("req-123"),
|
||||
make_chunk_message(123, "Only"),
|
||||
make_end_message(123),
|
||||
]
|
||||
client = create_async_client_with_mocks(messages)
|
||||
|
||||
@ -166,8 +168,8 @@ class TestDirtyAsyncStreamIterator:
|
||||
async def test_async_stream_sends_request_on_first_iteration(self):
|
||||
"""Test that request is sent on first async iteration."""
|
||||
messages = [
|
||||
make_chunk_message("req-123", "data"),
|
||||
make_end_message("req-123"),
|
||||
make_chunk_message(123, "data"),
|
||||
make_end_message(123),
|
||||
]
|
||||
client = create_async_client_with_mocks(messages)
|
||||
|
||||
@ -182,18 +184,15 @@ class TestDirtyAsyncStreamIterator:
|
||||
|
||||
# Decode sent request
|
||||
sent_data = client._writer._sent[0]
|
||||
length = struct.unpack(
|
||||
DirtyProtocol.HEADER_FORMAT,
|
||||
sent_data[:DirtyProtocol.HEADER_SIZE]
|
||||
)[0]
|
||||
request = DirtyProtocol.decode(
|
||||
sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length]
|
||||
_, _, length = BinaryProtocol.decode_header(sent_data[:HEADER_SIZE])
|
||||
msg_type_str, request_id, payload = BinaryProtocol.decode_message(
|
||||
sent_data[:HEADER_SIZE + length]
|
||||
)
|
||||
|
||||
assert request["type"] == "request"
|
||||
assert request["app_path"] == "test:App"
|
||||
assert request["action"] == "generate"
|
||||
assert request["args"] == ["prompt_arg"]
|
||||
assert msg_type_str == "request"
|
||||
assert payload["app_path"] == "test:App"
|
||||
assert payload["action"] == "generate"
|
||||
assert payload["args"] == ["prompt_arg"]
|
||||
|
||||
|
||||
class TestDirtyAsyncStreamIteratorEdgeCases:
|
||||
@ -204,8 +203,8 @@ class TestDirtyAsyncStreamIteratorEdgeCases:
|
||||
"""Test async streaming with many chunks."""
|
||||
messages = []
|
||||
for i in range(100):
|
||||
messages.append(make_chunk_message("req-123", f"chunk-{i}"))
|
||||
messages.append(make_end_message("req-123"))
|
||||
messages.append(make_chunk_message(123, f"chunk-{i}"))
|
||||
messages.append(make_end_message(123))
|
||||
|
||||
client = create_async_client_with_mocks(messages)
|
||||
|
||||
@ -221,8 +220,8 @@ class TestDirtyAsyncStreamIteratorEdgeCases:
|
||||
async def test_async_stream_with_kwargs(self):
|
||||
"""Test async streaming with keyword arguments."""
|
||||
messages = [
|
||||
make_chunk_message("req-123", "data"),
|
||||
make_end_message("req-123"),
|
||||
make_chunk_message(123, "data"),
|
||||
make_end_message(123),
|
||||
]
|
||||
client = create_async_client_with_mocks(messages)
|
||||
|
||||
@ -233,16 +232,13 @@ class TestDirtyAsyncStreamIteratorEdgeCases:
|
||||
|
||||
# Check the sent request includes kwargs
|
||||
sent_data = client._writer._sent[0]
|
||||
length = struct.unpack(
|
||||
DirtyProtocol.HEADER_FORMAT,
|
||||
sent_data[:DirtyProtocol.HEADER_SIZE]
|
||||
)[0]
|
||||
request = DirtyProtocol.decode(
|
||||
sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length]
|
||||
_, _, length = BinaryProtocol.decode_header(sent_data[:HEADER_SIZE])
|
||||
msg_type_str, request_id, payload = BinaryProtocol.decode_message(
|
||||
sent_data[:HEADER_SIZE + length]
|
||||
)
|
||||
|
||||
assert request["args"] == ["arg1"]
|
||||
assert request["kwargs"] == {"key": "value"}
|
||||
assert payload["args"] == ["arg1"]
|
||||
assert payload["kwargs"] == {"key": "value"}
|
||||
|
||||
|
||||
class TestDirtyAsyncStreamTimeout:
|
||||
|
||||
@ -19,7 +19,12 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
from gunicorn.config import Config
|
||||
from gunicorn.dirty.worker import DirtyWorker
|
||||
from gunicorn.dirty.arbiter import DirtyArbiter
|
||||
from gunicorn.dirty.protocol import DirtyProtocol, make_request
|
||||
from gunicorn.dirty.protocol import (
|
||||
DirtyProtocol,
|
||||
BinaryProtocol,
|
||||
make_request,
|
||||
HEADER_SIZE,
|
||||
)
|
||||
from gunicorn.dirty.errors import DirtyAppNotFoundError
|
||||
|
||||
|
||||
@ -71,16 +76,22 @@ class MockStreamWriter:
|
||||
self._buffer += data
|
||||
|
||||
async def drain(self):
|
||||
while len(self._buffer) >= DirtyProtocol.HEADER_SIZE:
|
||||
length = struct.unpack(
|
||||
DirtyProtocol.HEADER_FORMAT,
|
||||
self._buffer[:DirtyProtocol.HEADER_SIZE]
|
||||
)[0]
|
||||
total_size = DirtyProtocol.HEADER_SIZE + length
|
||||
# Decode the buffer to extract messages using binary protocol
|
||||
while len(self._buffer) >= HEADER_SIZE:
|
||||
# Decode header to get payload length
|
||||
_, _, length = BinaryProtocol.decode_header(
|
||||
self._buffer[:HEADER_SIZE]
|
||||
)
|
||||
total_size = HEADER_SIZE + length
|
||||
if len(self._buffer) >= total_size:
|
||||
msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size]
|
||||
msg_data = self._buffer[:total_size]
|
||||
self._buffer = self._buffer[total_size:]
|
||||
self.messages.append(DirtyProtocol.decode(msg_data))
|
||||
# decode_message returns (msg_type_str, request_id, payload_dict)
|
||||
msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data)
|
||||
# Reconstruct the dict format for backwards compatibility
|
||||
result = {"type": msg_type_str, "id": request_id}
|
||||
result.update(payload_dict)
|
||||
self.messages.append(result)
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
@ -18,11 +18,13 @@ from unittest import mock
|
||||
from gunicorn.config import Config
|
||||
from gunicorn.dirty.protocol import (
|
||||
DirtyProtocol,
|
||||
BinaryProtocol,
|
||||
make_request,
|
||||
make_chunk_message,
|
||||
make_end_message,
|
||||
make_response,
|
||||
make_error_response,
|
||||
HEADER_SIZE,
|
||||
)
|
||||
from gunicorn.dirty.worker import DirtyWorker
|
||||
from gunicorn.dirty.arbiter import DirtyArbiter
|
||||
@ -67,16 +69,22 @@ class MockStreamWriter:
|
||||
self._buffer += data
|
||||
|
||||
async def drain(self):
|
||||
while len(self._buffer) >= DirtyProtocol.HEADER_SIZE:
|
||||
length = struct.unpack(
|
||||
DirtyProtocol.HEADER_FORMAT,
|
||||
self._buffer[:DirtyProtocol.HEADER_SIZE]
|
||||
)[0]
|
||||
total_size = DirtyProtocol.HEADER_SIZE + length
|
||||
# Decode the buffer to extract messages using binary protocol
|
||||
while len(self._buffer) >= HEADER_SIZE:
|
||||
# Decode header to get payload length
|
||||
_, _, length = BinaryProtocol.decode_header(
|
||||
self._buffer[:HEADER_SIZE]
|
||||
)
|
||||
total_size = HEADER_SIZE + length
|
||||
if len(self._buffer) >= total_size:
|
||||
msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size]
|
||||
msg_data = self._buffer[:total_size]
|
||||
self._buffer = self._buffer[total_size:]
|
||||
self.messages.append(DirtyProtocol.decode(msg_data))
|
||||
# decode_message returns (msg_type_str, request_id, payload_dict)
|
||||
msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data)
|
||||
# Reconstruct the dict format for backwards compatibility
|
||||
result = {"type": msg_type_str, "id": request_id}
|
||||
result.update(payload_dict)
|
||||
self.messages.append(result)
|
||||
else:
|
||||
break
|
||||
|
||||
@ -96,7 +104,7 @@ class MockStreamReader:
|
||||
def __init__(self, messages):
|
||||
self._data = b''
|
||||
for msg in messages:
|
||||
self._data += DirtyProtocol.encode(msg)
|
||||
self._data += BinaryProtocol._encode_from_dict(msg)
|
||||
self._pos = 0
|
||||
|
||||
async def readexactly(self, n):
|
||||
@ -115,10 +123,10 @@ class TestStreamingEndToEnd:
|
||||
"""Test complete flow: sync generator -> worker -> arbiter -> client."""
|
||||
# Simulate what a worker would produce for a sync generator
|
||||
worker_messages = [
|
||||
make_chunk_message("req-123", "Hello"),
|
||||
make_chunk_message("req-123", " "),
|
||||
make_chunk_message("req-123", "World"),
|
||||
make_end_message("req-123"),
|
||||
make_chunk_message(123, "Hello"),
|
||||
make_chunk_message(123, " "),
|
||||
make_chunk_message(123, "World"),
|
||||
make_end_message(123),
|
||||
]
|
||||
|
||||
# Create an arbiter with mocked worker connection
|
||||
@ -141,7 +149,7 @@ class TestStreamingEndToEnd:
|
||||
client_writer = MockStreamWriter()
|
||||
|
||||
# Execute request through arbiter
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
request = make_request(123, "test:App", "generate")
|
||||
await arbiter._execute_on_worker(1234, request, client_writer)
|
||||
|
||||
# Verify all messages were forwarded
|
||||
@ -158,10 +166,10 @@ class TestStreamingEndToEnd:
|
||||
async def test_async_generator_end_to_end(self):
|
||||
"""Test complete flow: async generator -> worker -> arbiter -> client."""
|
||||
worker_messages = [
|
||||
make_chunk_message("req-456", "Async"),
|
||||
make_chunk_message("req-456", " "),
|
||||
make_chunk_message("req-456", "Stream"),
|
||||
make_end_message("req-456"),
|
||||
make_chunk_message(456, "Async"),
|
||||
make_chunk_message(456, " "),
|
||||
make_chunk_message(456, "Stream"),
|
||||
make_end_message(456),
|
||||
]
|
||||
|
||||
cfg = Config()
|
||||
@ -180,7 +188,7 @@ class TestStreamingEndToEnd:
|
||||
|
||||
client_writer = MockStreamWriter()
|
||||
|
||||
request = make_request("req-456", "test:App", "async_generate")
|
||||
request = make_request(456, "test:App", "async_generate")
|
||||
await arbiter._execute_on_worker(1234, request, client_writer)
|
||||
|
||||
assert len(client_writer.messages) == 4
|
||||
@ -197,9 +205,9 @@ class TestStreamingErrorHandling:
|
||||
async def test_error_mid_stream(self):
|
||||
"""Test that errors during streaming are properly forwarded."""
|
||||
worker_messages = [
|
||||
make_chunk_message("req-789", "First"),
|
||||
make_chunk_message("req-789", "Second"),
|
||||
make_error_response("req-789", DirtyError("Stream failed")),
|
||||
make_chunk_message(789, "First"),
|
||||
make_chunk_message(789, "Second"),
|
||||
make_error_response(789, DirtyError("Stream failed")),
|
||||
]
|
||||
|
||||
cfg = Config()
|
||||
@ -218,7 +226,7 @@ class TestStreamingErrorHandling:
|
||||
|
||||
client_writer = MockStreamWriter()
|
||||
|
||||
request = make_request("req-789", "test:App", "generate_with_error")
|
||||
request = make_request(789, "test:App", "generate_with_error")
|
||||
await arbiter._execute_on_worker(1234, request, client_writer)
|
||||
|
||||
# Should have 2 chunks + 1 error
|
||||
@ -335,7 +343,7 @@ class TestStreamingWorkerIntegration:
|
||||
return sync_gen()
|
||||
|
||||
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
request = make_request(123, "test:App", "generate")
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
# Should have 3 chunks + 1 end
|
||||
@ -377,7 +385,7 @@ class TestStreamingWorkerIntegration:
|
||||
return async_gen()
|
||||
|
||||
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
|
||||
request = make_request("req-456", "test:App", "async_generate")
|
||||
request = make_request(456, "test:App", "async_generate")
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
# Should have 2 chunks + 1 end
|
||||
|
||||
@ -12,9 +12,11 @@ import pytest
|
||||
|
||||
from gunicorn.dirty.protocol import (
|
||||
DirtyProtocol,
|
||||
BinaryProtocol,
|
||||
make_request,
|
||||
make_chunk_message,
|
||||
make_end_message,
|
||||
HEADER_SIZE,
|
||||
)
|
||||
from gunicorn.dirty.worker import DirtyWorker
|
||||
|
||||
@ -30,17 +32,22 @@ class FakeStreamWriter:
|
||||
self._buffer += data
|
||||
|
||||
async def drain(self):
|
||||
# Decode the buffer to extract messages
|
||||
while len(self._buffer) >= DirtyProtocol.HEADER_SIZE:
|
||||
length = struct.unpack(
|
||||
DirtyProtocol.HEADER_FORMAT,
|
||||
self._buffer[:DirtyProtocol.HEADER_SIZE]
|
||||
)[0]
|
||||
total_size = DirtyProtocol.HEADER_SIZE + length
|
||||
# Decode the buffer to extract messages using binary protocol
|
||||
while len(self._buffer) >= HEADER_SIZE:
|
||||
# Decode header to get payload length
|
||||
_, _, length = BinaryProtocol.decode_header(
|
||||
self._buffer[:HEADER_SIZE]
|
||||
)
|
||||
total_size = HEADER_SIZE + length
|
||||
if len(self._buffer) >= total_size:
|
||||
msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size]
|
||||
msg_data = self._buffer[:total_size]
|
||||
self._buffer = self._buffer[total_size:]
|
||||
self.messages.append(DirtyProtocol.decode(msg_data))
|
||||
# decode_message returns (msg_type_str, request_id, payload_dict)
|
||||
msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data)
|
||||
# Reconstruct the dict format for backwards compatibility
|
||||
result = {"type": msg_type_str, "id": request_id}
|
||||
result.update(payload_dict)
|
||||
self.messages.append(result)
|
||||
else:
|
||||
break
|
||||
|
||||
@ -101,7 +108,7 @@ class TestWorkerSyncGeneratorStreaming:
|
||||
return generate_tokens()
|
||||
|
||||
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
request = make_request(123, "test:App", "generate")
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
# Should have 3 chunks + 1 end message
|
||||
@ -109,7 +116,7 @@ class TestWorkerSyncGeneratorStreaming:
|
||||
|
||||
# Check chunk messages
|
||||
assert writer.messages[0]["type"] == "chunk"
|
||||
assert writer.messages[0]["id"] == "req-123"
|
||||
assert writer.messages[0]["id"] == 123
|
||||
assert writer.messages[0]["data"] == "Hello"
|
||||
|
||||
assert writer.messages[1]["type"] == "chunk"
|
||||
@ -120,7 +127,7 @@ class TestWorkerSyncGeneratorStreaming:
|
||||
|
||||
# Check end message
|
||||
assert writer.messages[3]["type"] == "end"
|
||||
assert writer.messages[3]["id"] == "req-123"
|
||||
assert writer.messages[3]["id"] == 123
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_generator_error_mid_stream(self):
|
||||
@ -136,7 +143,7 @@ class TestWorkerSyncGeneratorStreaming:
|
||||
return generate_with_error()
|
||||
|
||||
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
request = make_request(123, "test:App", "generate")
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
# Should have 1 chunk + 1 error message
|
||||
@ -167,7 +174,7 @@ class TestWorkerAsyncGeneratorStreaming:
|
||||
return async_generate_tokens()
|
||||
|
||||
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
request = make_request(123, "test:App", "generate")
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
# Should have 3 chunks + 1 end message
|
||||
@ -175,7 +182,7 @@ class TestWorkerAsyncGeneratorStreaming:
|
||||
|
||||
# Check chunk messages
|
||||
assert writer.messages[0]["type"] == "chunk"
|
||||
assert writer.messages[0]["id"] == "req-123"
|
||||
assert writer.messages[0]["id"] == 123
|
||||
assert writer.messages[0]["data"] == "Hello"
|
||||
|
||||
assert writer.messages[1]["type"] == "chunk"
|
||||
@ -186,7 +193,7 @@ class TestWorkerAsyncGeneratorStreaming:
|
||||
|
||||
# Check end message
|
||||
assert writer.messages[3]["type"] == "end"
|
||||
assert writer.messages[3]["id"] == "req-123"
|
||||
assert writer.messages[3]["id"] == 123
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_generator_error_mid_stream(self):
|
||||
@ -202,7 +209,7 @@ class TestWorkerAsyncGeneratorStreaming:
|
||||
return async_generate_with_error()
|
||||
|
||||
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
request = make_request(123, "test:App", "generate")
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
# Should have 1 chunk + 1 error message
|
||||
@ -228,13 +235,13 @@ class TestWorkerNonStreamingBackwardCompat:
|
||||
return args[0] + args[1]
|
||||
|
||||
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
|
||||
request = make_request("req-123", "test:App", "compute", args=(2, 3))
|
||||
request = make_request(123, "test:App", "compute", args=(2, 3))
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
# Should have 1 response message
|
||||
assert len(writer.messages) == 1
|
||||
assert writer.messages[0]["type"] == "response"
|
||||
assert writer.messages[0]["id"] == "req-123"
|
||||
assert writer.messages[0]["id"] == 123
|
||||
assert writer.messages[0]["result"] == 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -247,7 +254,7 @@ class TestWorkerNonStreamingBackwardCompat:
|
||||
return [1, 2, 3, 4, 5]
|
||||
|
||||
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
|
||||
request = make_request("req-123", "test:App", "get_list")
|
||||
request = make_request(123, "test:App", "get_list")
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
# Should have 1 response message (not 5 chunks)
|
||||
@ -265,7 +272,7 @@ class TestWorkerNonStreamingBackwardCompat:
|
||||
raise RuntimeError("Failed!")
|
||||
|
||||
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
|
||||
request = make_request("req-123", "test:App", "fail")
|
||||
request = make_request(123, "test:App", "fail")
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
# Should have 1 error message
|
||||
@ -283,7 +290,7 @@ class TestWorkerNonStreamingBackwardCompat:
|
||||
return None
|
||||
|
||||
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
|
||||
request = make_request("req-123", "test:App", "void")
|
||||
request = make_request(123, "test:App", "void")
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
# Should have 1 response message
|
||||
@ -309,7 +316,7 @@ class TestWorkerStreamingComplexData:
|
||||
return generate_tokens()
|
||||
|
||||
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
request = make_request(123, "test:App", "generate")
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
assert len(writer.messages) == 3 # 2 chunks + 1 end
|
||||
@ -332,7 +339,7 @@ class TestWorkerStreamingComplexData:
|
||||
return empty_generate()
|
||||
|
||||
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
request = make_request(123, "test:App", "generate")
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
# Should have just 1 end message
|
||||
@ -353,7 +360,7 @@ class TestWorkerStreamingComplexData:
|
||||
return generate_many()
|
||||
|
||||
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
request = make_request(123, "test:App", "generate")
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
# Should have 100 chunks + 1 end message
|
||||
@ -390,7 +397,7 @@ class TestWorkerStreamingHeartbeat:
|
||||
return generate_tokens()
|
||||
|
||||
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
request = make_request(123, "test:App", "generate")
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
# Should have been notified at least once per chunk + initial
|
||||
@ -407,7 +414,7 @@ class TestWorkerMessageTypeValidation:
|
||||
writer = FakeStreamWriter()
|
||||
|
||||
# Send a message with unknown type
|
||||
message = {"type": "unknown", "id": "req-123"}
|
||||
message = {"type": "unknown", "id": 123}
|
||||
await worker.handle_request(message, writer)
|
||||
|
||||
assert len(writer.messages) == 1
|
||||
|
||||
21
tests/docker/http2/certs/server.crt
Normal file
21
tests/docker/http2/certs/server.crt
Normal file
@ -0,0 +1,21 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIDfDCCAmSgAwIBAgIUDxTarKRHe0FIyczGmoYwm377ZpcwDQYJKoZIhvcNAQEL
|
||||
BQAwOTESMBAGA1UEAwwJbG9jYWxob3N0MRYwFAYDVQQKDA1HdW5pY29ybiBUZXN0
|
||||
MQswCQYDVQQGEwJVUzAeFw0yNjAyMDUxMTE1MjJaFw0yNjAyMDYxMTE1MjJaMDkx
|
||||
EjAQBgNVBAMMCWxvY2FsaG9zdDEWMBQGA1UECgwNR3VuaWNvcm4gVGVzdDELMAkG
|
||||
A1UEBhMCVVMwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCRQTHakkqY
|
||||
6l6dMqfs4oiY98+rjvZubwjp0PH7UBuxXCi/4Ao78o0JhKcs+jgAGAXyb8eRjEKt
|
||||
z4rPoHZYE91D/eD0lWAz9r/LRoutDJd9IO0rfDtHlYXamciuxJJ8cckOrnuTXLtq
|
||||
AWqjKR3U9RIDD3eumCKG4l7Py0L67zTomwMRPfeIdlWBfxGjWMqOewdTc/O/cuK2
|
||||
HL5JP2ixy+iTufs0jhljI9cbu49J606f+TQH9eXRTD716q+KsHPJX1X5dVd7V7Lr
|
||||
FIp7wSUFdbiy56JfmrGmfJbZgFH67P0ZyiTpQBaVHt1YYRIcOUJZqM+0MAtrsySC
|
||||
TNA/LsI8tsybAgMBAAGjfDB6MB0GA1UdDgQWBBRK2VkAeM0hL4j/45ckkKbGrb/Q
|
||||
FjAfBgNVHSMEGDAWgBRK2VkAeM0hL4j/45ckkKbGrb/QFjAPBgNVHRMBAf8EBTAD
|
||||
AQH/MCcGA1UdEQQgMB6CCWxvY2FsaG9zdIILZ3VuaWNvcm4taDKHBH8AAAEwDQYJ
|
||||
KoZIhvcNAQELBQADggEBAAXwuw0KTQUC4UEFudQ1rceK6By9WCSJND7xJi+UQ50G
|
||||
Zrp5tJ2YB4ZWY+APadfuJo+zUxYVZ3jhs0mxgVeiGdDW6yZdHkeX8MlXBTLHR+/a
|
||||
A7DXn6wCw9NDeDtcY/bKg5iamvoGGTL6szPrqeuZPz4UdbsFlr0MdcjgSNOqnkjr
|
||||
YS4ukgZ71aWSjfraRRPjFMzkfnQ1xm96A1ngMH4DvU/t62D7r8+SvxQ8M6ERL84Z
|
||||
FBu4bTXDdYIjJ24ojmDDO2irTVW1FMGXQTPzMaTEbE1rvBYeEYhf10KiMynK9xfO
|
||||
5j8LWmCkgek0CqBrf3zbDEwu8QxcaxITAIUkSXLOZbo=
|
||||
-----END CERTIFICATE-----
|
||||
28
tests/docker/http2/certs/server.key
Normal file
28
tests/docker/http2/certs/server.key
Normal file
@ -0,0 +1,28 @@
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQCRQTHakkqY6l6d
|
||||
Mqfs4oiY98+rjvZubwjp0PH7UBuxXCi/4Ao78o0JhKcs+jgAGAXyb8eRjEKtz4rP
|
||||
oHZYE91D/eD0lWAz9r/LRoutDJd9IO0rfDtHlYXamciuxJJ8cckOrnuTXLtqAWqj
|
||||
KR3U9RIDD3eumCKG4l7Py0L67zTomwMRPfeIdlWBfxGjWMqOewdTc/O/cuK2HL5J
|
||||
P2ixy+iTufs0jhljI9cbu49J606f+TQH9eXRTD716q+KsHPJX1X5dVd7V7LrFIp7
|
||||
wSUFdbiy56JfmrGmfJbZgFH67P0ZyiTpQBaVHt1YYRIcOUJZqM+0MAtrsySCTNA/
|
||||
LsI8tsybAgMBAAECggEANBhGOYZLI9G2sjlXOaG7bOU/wV9KKaw/7Z/HEaOW8wLD
|
||||
CKHg+cQRai79yCdLi1kSVPNbB2vfBDhRqAp8NzWUn0x/8ChcsvZVriF0edFwyWtU
|
||||
NErfddp+Absy2t9cTC6A9feFEYJqIug0JyVZciWc2qUi/ubIR0kLyQm00YuWFa/s
|
||||
GJou8Nhg70rqW+3FB1H8kAEXqob+PFW4xbTwexw1+MbHxN7UKLTzS8uzYGLo2UpB
|
||||
7bksumyD0o+lZtlx9HZ6CwrB6IPjgJ0HyaD8SrOc7/ozd7rR2LmvMmBCV1uC5VSO
|
||||
jhr0PScLoNv60fjkVOiF9uqaPY2kNKymsOzpZ7/mwQKBgQDMcz+ve8WGGbE+bbM7
|
||||
2uinQ5smm8rWPnfbHJIHQUetrEQKljRovybmjiiXN08uxlX6VA/Vnp4fmL5fzsTD
|
||||
xTeiCVPsR1huXIfMLGJ6crUgvlbiaB8XsxtVNBpfEEtBe27qjSIj3xtmwqM6+LD1
|
||||
FKLsYzgotHUH9JwyLA1RMKPBwQKBgQC14QWtI5YtZcTX46BqxlZ07iAAuy19Jywn
|
||||
UtgmTawkJuEcseewIjxtJkMz+aSy7V3PsLII8tY48oSjAVx84w50zLJ2OlJnFT1S
|
||||
zEmIOu9YDcGLZkYXJ2AwndRAIXpJVHwtFM9eDSMh+wVPBFeboYP1dO/VxmN6QV0W
|
||||
GqDaQfItWwKBgEb31mp2n0j+UB0ofSfQxCOTfx62w4D87CPd1f64tUXe3zuBii21
|
||||
9K3hOMvMwiqtZBjh5yEyzxaOsb6WCo0eP0J61GvXFCYy7lx8J67zdFYqXAR5OhnC
|
||||
7UD1NhY7lLPlQcofNXOYNW3FMF3/B4X7JNbDVjIi+eDKExIDYpgFN0LBAoGADGCf
|
||||
7kR5t+UxHDAVfq64u4RpESOr2NSNoK92nkSy7lLnBvjkd4wc6KCt+h+HIdYdiEDS
|
||||
HOHJyl5WwHEbRjR9i11S19DoQrOjVLsqVecM2sU04rO3GWRIm4ZiJ2sf01W4jajY
|
||||
4+Go/msC1XnKLIE1ZcLrf3Tc2DkSiKqPP8s1G/kCgYA8sCPAXedwhULhOBM45x4J
|
||||
vkwT1Icm5RHOwOr8t34IFozTLokba6pjhYua3nE+V3FglRct7NpX+Op4gUgHa80g
|
||||
5zoHboq5/pTUTclx41jndC1YGa3NLvthDWTWmyo/Qj7F/R7jGJf8E3KUDe0tFoSp
|
||||
JlfEuUHtKpFJReBnmWTFiQ==
|
||||
-----END PRIVATE KEY-----
|
||||
@ -11,7 +11,7 @@ import pytest
|
||||
from gunicorn.arbiter import Arbiter
|
||||
from gunicorn.config import Config
|
||||
from gunicorn.app.base import BaseApplication
|
||||
from gunicorn.dirty.protocol import DirtyProtocol
|
||||
from gunicorn.dirty.protocol import DirtyProtocol, BinaryProtocol, HEADER_SIZE
|
||||
|
||||
|
||||
class MockStreamWriter:
|
||||
@ -26,16 +26,22 @@ class MockStreamWriter:
|
||||
self._buffer += data
|
||||
|
||||
async def drain(self):
|
||||
while len(self._buffer) >= DirtyProtocol.HEADER_SIZE:
|
||||
length = struct.unpack(
|
||||
DirtyProtocol.HEADER_FORMAT,
|
||||
self._buffer[:DirtyProtocol.HEADER_SIZE]
|
||||
)[0]
|
||||
total_size = DirtyProtocol.HEADER_SIZE + length
|
||||
# Decode the buffer to extract messages using binary protocol
|
||||
while len(self._buffer) >= HEADER_SIZE:
|
||||
# Decode header to get payload length
|
||||
_, _, length = BinaryProtocol.decode_header(
|
||||
self._buffer[:HEADER_SIZE]
|
||||
)
|
||||
total_size = HEADER_SIZE + length
|
||||
if len(self._buffer) >= total_size:
|
||||
msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size]
|
||||
msg_data = self._buffer[:total_size]
|
||||
self._buffer = self._buffer[total_size:]
|
||||
self.messages.append(DirtyProtocol.decode(msg_data))
|
||||
# decode_message returns (msg_type_str, request_id, payload_dict)
|
||||
msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data)
|
||||
# Reconstruct the dict format for backwards compatibility
|
||||
result = {"type": msg_type_str, "id": request_id}
|
||||
result.update(payload_dict)
|
||||
self.messages.append(result)
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user