diff --git a/tests/dirty/test_arbiter_streaming.py b/tests/dirty/test_arbiter_streaming.py index ef15c33a..a722f2af 100644 --- a/tests/dirty/test_arbiter_streaming.py +++ b/tests/dirty/test_arbiter_streaming.py @@ -12,11 +12,13 @@ import pytest from gunicorn.dirty.protocol import ( DirtyProtocol, + BinaryProtocol, make_request, make_response, make_chunk_message, make_end_message, make_error_response, + HEADER_SIZE, ) from gunicorn.dirty.arbiter import DirtyArbiter from gunicorn.dirty.errors import DirtyError @@ -34,16 +36,22 @@ class MockStreamWriter: self._buffer += data async def drain(self): - while len(self._buffer) >= DirtyProtocol.HEADER_SIZE: - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - self._buffer[:DirtyProtocol.HEADER_SIZE] - )[0] - total_size = DirtyProtocol.HEADER_SIZE + length + # Decode the buffer to extract messages using binary protocol + while len(self._buffer) >= HEADER_SIZE: + # Decode header to get payload length + _, _, length = BinaryProtocol.decode_header( + self._buffer[:HEADER_SIZE] + ) + total_size = HEADER_SIZE + length if len(self._buffer) >= total_size: - msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size] + msg_data = self._buffer[:total_size] self._buffer = self._buffer[total_size:] - self.messages.append(DirtyProtocol.decode(msg_data)) + # decode_message returns (msg_type_str, request_id, payload_dict) + msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data) + # Reconstruct the dict format for backwards compatibility + result = {"type": msg_type_str, "id": request_id} + result.update(payload_dict) + self.messages.append(result) else: break @@ -63,7 +71,7 @@ class MockStreamReader: def __init__(self, messages): self._data = b'' for msg in messages: - self._data += DirtyProtocol.encode(msg) + self._data += BinaryProtocol._encode_from_dict(msg) self._pos = 0 async def readexactly(self, n): @@ -107,9 +115,9 @@ class TestArbiterStreamingForwarding: client_writer = MockStreamWriter() # Mock worker connection that returns chunks - chunk1 = make_chunk_message("req-123", "Hello") - chunk2 = make_chunk_message("req-123", " World") - end = make_end_message("req-123") + chunk1 = make_chunk_message(123, "Hello") + chunk2 = make_chunk_message(123, " World") + end = make_end_message(123) mock_reader = MockStreamReader([chunk1, chunk2, end]) @@ -118,7 +126,7 @@ class TestArbiterStreamingForwarding: arbiter._get_worker_connection = mock_get_connection - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await arbiter._execute_on_worker(1234, request, client_writer) # Should have forwarded all messages @@ -135,7 +143,7 @@ class TestArbiterStreamingForwarding: arbiter = create_arbiter() client_writer = MockStreamWriter() - response = make_response("req-123", {"result": 42}) + response = make_response(123, {"result": 42}) mock_reader = MockStreamReader([response]) async def mock_get_connection(pid): @@ -143,7 +151,7 @@ class TestArbiterStreamingForwarding: arbiter._get_worker_connection = mock_get_connection - request = make_request("req-123", "test:App", "compute") + request = make_request(123, "test:App", "compute") await arbiter._execute_on_worker(1234, request, client_writer) assert len(client_writer.messages) == 1 @@ -156,8 +164,8 @@ class TestArbiterStreamingForwarding: arbiter = create_arbiter() client_writer = MockStreamWriter() - chunk = make_chunk_message("req-123", "First") - error = make_error_response("req-123", DirtyError("Something broke")) + chunk = make_chunk_message(123, "First") + error = make_error_response(123, DirtyError("Something broke")) mock_reader = MockStreamReader([chunk, error]) @@ -166,7 +174,7 @@ class TestArbiterStreamingForwarding: arbiter._get_worker_connection = mock_get_connection - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await arbiter._execute_on_worker(1234, request, client_writer) assert len(client_writer.messages) == 2 @@ -190,7 +198,7 @@ class TestArbiterStreamingForwarding: arbiter._get_worker_connection = mock_get_connection - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await arbiter._execute_on_worker(1234, request, client_writer) assert len(client_writer.messages) == 1 @@ -208,7 +216,7 @@ class TestArbiterRouteRequestStreaming: arbiter.workers = {} # No workers client_writer = MockStreamWriter() - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await arbiter.route_request(request, client_writer) assert len(client_writer.messages) == 1 @@ -222,13 +230,13 @@ class TestArbiterRouteRequestStreaming: # Mock _execute_on_worker to complete immediately async def mock_execute(pid, request, client_writer): - response = make_response("req-123", "result") + response = make_response(123, "result") await DirtyProtocol.write_message_async(client_writer, response) arbiter._execute_on_worker = mock_execute client_writer = MockStreamWriter() - request = make_request("req-123", "test:App", "compute") + request = make_request(123, "test:App", "compute") # Worker queue should be created assert 1234 not in arbiter.worker_queues @@ -255,8 +263,8 @@ class TestArbiterStreamingManyChunks: # Generate 50 chunks + end messages = [] for i in range(50): - messages.append(make_chunk_message("req-123", f"chunk-{i}")) - messages.append(make_end_message("req-123")) + messages.append(make_chunk_message(123, f"chunk-{i}")) + messages.append(make_end_message(123)) mock_reader = MockStreamReader(messages) @@ -265,7 +273,7 @@ class TestArbiterStreamingManyChunks: arbiter._get_worker_connection = mock_get_connection - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await arbiter._execute_on_worker(1234, request, client_writer) assert len(client_writer.messages) == 51 @@ -283,7 +291,7 @@ class TestArbiterBackwardCompatibility: arbiter = create_arbiter() client_writer = MockStreamWriter() - response = make_response("req-123", [1, 2, 3, 4, 5]) + response = make_response(123, [1, 2, 3, 4, 5]) mock_reader = MockStreamReader([response]) async def mock_get_connection(pid): @@ -291,7 +299,7 @@ class TestArbiterBackwardCompatibility: arbiter._get_worker_connection = mock_get_connection - request = make_request("req-123", "test:App", "get_list") + request = make_request(123, "test:App", "get_list") await arbiter._execute_on_worker(1234, request, client_writer) assert len(client_writer.messages) == 1 @@ -304,7 +312,7 @@ class TestArbiterBackwardCompatibility: arbiter = create_arbiter() client_writer = MockStreamWriter() - error = make_error_response("req-123", DirtyError("Something failed")) + error = make_error_response(123, DirtyError("Something failed")) mock_reader = MockStreamReader([error]) async def mock_get_connection(pid): @@ -312,7 +320,7 @@ class TestArbiterBackwardCompatibility: arbiter._get_worker_connection = mock_get_connection - request = make_request("req-123", "test:App", "fail") + request = make_request(123, "test:App", "fail") await arbiter._execute_on_worker(1234, request, client_writer) assert len(client_writer.messages) == 1 diff --git a/tests/dirty/test_client_streaming.py b/tests/dirty/test_client_streaming.py index 7bc13525..eca76e98 100644 --- a/tests/dirty/test_client_streaming.py +++ b/tests/dirty/test_client_streaming.py @@ -11,10 +11,12 @@ from unittest import mock from gunicorn.dirty.protocol import ( DirtyProtocol, + BinaryProtocol, make_chunk_message, make_end_message, make_response, make_error_response, + HEADER_SIZE, ) from gunicorn.dirty.client import DirtyClient, DirtyStreamIterator from gunicorn.dirty.errors import DirtyError, DirtyConnectionError @@ -26,7 +28,7 @@ class MockSocket: def __init__(self, messages): self._data = b'' for msg in messages: - self._data += DirtyProtocol.encode(msg) + self._data += BinaryProtocol._encode_from_dict(msg) self._pos = 0 self._sent = [] self.closed = False @@ -69,10 +71,10 @@ class TestDirtyStreamIterator: def test_stream_iterator_yields_chunks(self): """Test that stream iterator yields chunks correctly.""" messages = [ - make_chunk_message("req-123", "Hello"), - make_chunk_message("req-123", " "), - make_chunk_message("req-123", "World"), - make_end_message("req-123"), + make_chunk_message(123, "Hello"), + make_chunk_message(123, " "), + make_chunk_message(123, "World"), + make_end_message(123), ] client = create_client_with_mock_socket(messages) @@ -83,9 +85,9 @@ class TestDirtyStreamIterator: def test_stream_iterator_yields_complex_chunks(self): """Test that stream iterator yields complex data types.""" messages = [ - make_chunk_message("req-123", {"token": "Hello", "score": 0.9}), - make_chunk_message("req-123", {"token": "World", "score": 0.8}), - make_end_message("req-123"), + make_chunk_message(123, {"token": "Hello", "score": 0.9}), + make_chunk_message(123, {"token": "World", "score": 0.8}), + make_end_message(123), ] client = create_client_with_mock_socket(messages) @@ -98,8 +100,8 @@ class TestDirtyStreamIterator: def test_stream_iterator_handles_error(self): """Test that stream iterator raises on error message.""" messages = [ - make_chunk_message("req-123", "First"), - make_error_response("req-123", DirtyError("Something broke")), + make_chunk_message(123, "First"), + make_error_response(123, DirtyError("Something broke")), ] client = create_client_with_mock_socket(messages) @@ -116,7 +118,7 @@ class TestDirtyStreamIterator: def test_stream_iterator_empty_stream(self): """Test that empty stream (just end) works.""" - messages = [make_end_message("req-123")] + messages = [make_end_message(123)] client = create_client_with_mock_socket(messages) chunks = list(client.stream("test:App", "generate")) @@ -125,8 +127,8 @@ class TestDirtyStreamIterator: def test_stream_iterator_stops_after_exhausted(self): """Test that iterator stays exhausted after StopIteration.""" messages = [ - make_chunk_message("req-123", "Only"), - make_end_message("req-123"), + make_chunk_message(123, "Only"), + make_end_message(123), ] client = create_client_with_mock_socket(messages) @@ -147,10 +149,10 @@ class TestDirtyStreamIterator: def test_stream_iterator_with_for_loop(self): """Test stream iterator works in for loop.""" messages = [ - make_chunk_message("req-123", "a"), - make_chunk_message("req-123", "b"), - make_chunk_message("req-123", "c"), - make_end_message("req-123"), + make_chunk_message(123, "a"), + make_chunk_message(123, "b"), + make_chunk_message(123, "c"), + make_end_message(123), ] client = create_client_with_mock_socket(messages) @@ -163,8 +165,8 @@ class TestDirtyStreamIterator: def test_stream_sends_request_on_first_iteration(self): """Test that request is sent on first next() call.""" messages = [ - make_chunk_message("req-123", "data"), - make_end_message("req-123"), + make_chunk_message(123, "data"), + make_end_message(123), ] client = create_client_with_mock_socket(messages) @@ -179,18 +181,15 @@ class TestDirtyStreamIterator: # Decode sent request sent_data = client._sock._sent[0] - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - sent_data[:DirtyProtocol.HEADER_SIZE] - )[0] - request = DirtyProtocol.decode( - sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length] + _, _, length = BinaryProtocol.decode_header(sent_data[:HEADER_SIZE]) + msg_type_str, request_id, payload = BinaryProtocol.decode_message( + sent_data[:HEADER_SIZE + length] ) - assert request["type"] == "request" - assert request["app_path"] == "test:App" - assert request["action"] == "generate" - assert request["args"] == ["prompt_arg"] + assert msg_type_str == "request" + assert payload["app_path"] == "test:App" + assert payload["action"] == "generate" + assert payload["args"] == ["prompt_arg"] class TestDirtyStreamIteratorEdgeCases: @@ -200,8 +199,8 @@ class TestDirtyStreamIteratorEdgeCases: """Test streaming with many chunks.""" messages = [] for i in range(100): - messages.append(make_chunk_message("req-123", f"chunk-{i}")) - messages.append(make_end_message("req-123")) + messages.append(make_chunk_message(123, f"chunk-{i}")) + messages.append(make_end_message(123)) client = create_client_with_mock_socket(messages) @@ -214,8 +213,8 @@ class TestDirtyStreamIteratorEdgeCases: def test_stream_with_kwargs(self): """Test streaming with keyword arguments.""" messages = [ - make_chunk_message("req-123", "data"), - make_end_message("req-123"), + make_chunk_message(123, "data"), + make_end_message(123), ] client = create_client_with_mock_socket(messages) @@ -224,13 +223,10 @@ class TestDirtyStreamIteratorEdgeCases: # Check the sent request includes kwargs sent_data = client._sock._sent[0] - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - sent_data[:DirtyProtocol.HEADER_SIZE] - )[0] - request = DirtyProtocol.decode( - sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length] + _, _, length = BinaryProtocol.decode_header(sent_data[:HEADER_SIZE]) + msg_type_str, request_id, payload = BinaryProtocol.decode_message( + sent_data[:HEADER_SIZE + length] ) - assert request["args"] == ["arg1"] - assert request["kwargs"] == {"key": "value"} + assert payload["args"] == ["arg1"] + assert payload["kwargs"] == {"key": "value"} diff --git a/tests/dirty/test_client_streaming_async.py b/tests/dirty/test_client_streaming_async.py index 651c73d1..b38eff6c 100644 --- a/tests/dirty/test_client_streaming_async.py +++ b/tests/dirty/test_client_streaming_async.py @@ -10,9 +10,11 @@ import pytest from gunicorn.dirty.protocol import ( DirtyProtocol, + BinaryProtocol, make_chunk_message, make_end_message, make_error_response, + HEADER_SIZE, ) from gunicorn.dirty.client import DirtyClient, DirtyAsyncStreamIterator from gunicorn.dirty.errors import DirtyError, DirtyTimeoutError @@ -24,7 +26,7 @@ class MockAsyncReader: def __init__(self, messages): self._data = b'' for msg in messages: - self._data += DirtyProtocol.encode(msg) + self._data += BinaryProtocol._encode_from_dict(msg) self._pos = 0 async def readexactly(self, n): @@ -76,10 +78,10 @@ class TestDirtyAsyncStreamIterator: async def test_async_stream_yields_chunks(self): """Test that async stream iterator yields chunks correctly.""" messages = [ - make_chunk_message("req-123", "Hello"), - make_chunk_message("req-123", " "), - make_chunk_message("req-123", "World"), - make_end_message("req-123"), + make_chunk_message(123, "Hello"), + make_chunk_message(123, " "), + make_chunk_message(123, "World"), + make_end_message(123), ] client = create_async_client_with_mocks(messages) @@ -93,9 +95,9 @@ class TestDirtyAsyncStreamIterator: async def test_async_stream_yields_complex_chunks(self): """Test that async stream iterator yields complex data types.""" messages = [ - make_chunk_message("req-123", {"token": "Hello", "score": 0.9}), - make_chunk_message("req-123", {"token": "World", "score": 0.8}), - make_end_message("req-123"), + make_chunk_message(123, {"token": "Hello", "score": 0.9}), + make_chunk_message(123, {"token": "World", "score": 0.8}), + make_end_message(123), ] client = create_async_client_with_mocks(messages) @@ -111,8 +113,8 @@ class TestDirtyAsyncStreamIterator: async def test_async_stream_handles_error(self): """Test that async stream iterator raises on error message.""" messages = [ - make_chunk_message("req-123", "First"), - make_error_response("req-123", DirtyError("Something broke")), + make_chunk_message(123, "First"), + make_error_response(123, DirtyError("Something broke")), ] client = create_async_client_with_mocks(messages) @@ -130,7 +132,7 @@ class TestDirtyAsyncStreamIterator: @pytest.mark.asyncio async def test_async_stream_empty_stream(self): """Test that empty stream (just end) works.""" - messages = [make_end_message("req-123")] + messages = [make_end_message(123)] client = create_async_client_with_mocks(messages) chunks = [] @@ -143,8 +145,8 @@ class TestDirtyAsyncStreamIterator: async def test_async_stream_stops_after_exhausted(self): """Test that async iterator stays exhausted after StopAsyncIteration.""" messages = [ - make_chunk_message("req-123", "Only"), - make_end_message("req-123"), + make_chunk_message(123, "Only"), + make_end_message(123), ] client = create_async_client_with_mocks(messages) @@ -166,8 +168,8 @@ class TestDirtyAsyncStreamIterator: async def test_async_stream_sends_request_on_first_iteration(self): """Test that request is sent on first async iteration.""" messages = [ - make_chunk_message("req-123", "data"), - make_end_message("req-123"), + make_chunk_message(123, "data"), + make_end_message(123), ] client = create_async_client_with_mocks(messages) @@ -182,18 +184,15 @@ class TestDirtyAsyncStreamIterator: # Decode sent request sent_data = client._writer._sent[0] - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - sent_data[:DirtyProtocol.HEADER_SIZE] - )[0] - request = DirtyProtocol.decode( - sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length] + _, _, length = BinaryProtocol.decode_header(sent_data[:HEADER_SIZE]) + msg_type_str, request_id, payload = BinaryProtocol.decode_message( + sent_data[:HEADER_SIZE + length] ) - assert request["type"] == "request" - assert request["app_path"] == "test:App" - assert request["action"] == "generate" - assert request["args"] == ["prompt_arg"] + assert msg_type_str == "request" + assert payload["app_path"] == "test:App" + assert payload["action"] == "generate" + assert payload["args"] == ["prompt_arg"] class TestDirtyAsyncStreamIteratorEdgeCases: @@ -204,8 +203,8 @@ class TestDirtyAsyncStreamIteratorEdgeCases: """Test async streaming with many chunks.""" messages = [] for i in range(100): - messages.append(make_chunk_message("req-123", f"chunk-{i}")) - messages.append(make_end_message("req-123")) + messages.append(make_chunk_message(123, f"chunk-{i}")) + messages.append(make_end_message(123)) client = create_async_client_with_mocks(messages) @@ -221,8 +220,8 @@ class TestDirtyAsyncStreamIteratorEdgeCases: async def test_async_stream_with_kwargs(self): """Test async streaming with keyword arguments.""" messages = [ - make_chunk_message("req-123", "data"), - make_end_message("req-123"), + make_chunk_message(123, "data"), + make_end_message(123), ] client = create_async_client_with_mocks(messages) @@ -233,16 +232,13 @@ class TestDirtyAsyncStreamIteratorEdgeCases: # Check the sent request includes kwargs sent_data = client._writer._sent[0] - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - sent_data[:DirtyProtocol.HEADER_SIZE] - )[0] - request = DirtyProtocol.decode( - sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length] + _, _, length = BinaryProtocol.decode_header(sent_data[:HEADER_SIZE]) + msg_type_str, request_id, payload = BinaryProtocol.decode_message( + sent_data[:HEADER_SIZE + length] ) - assert request["args"] == ["arg1"] - assert request["kwargs"] == {"key": "value"} + assert payload["args"] == ["arg1"] + assert payload["kwargs"] == {"key": "value"} class TestDirtyAsyncStreamTimeout: diff --git a/tests/dirty/test_multi_app_routing.py b/tests/dirty/test_multi_app_routing.py index c113bab1..4e01b711 100644 --- a/tests/dirty/test_multi_app_routing.py +++ b/tests/dirty/test_multi_app_routing.py @@ -19,7 +19,12 @@ from concurrent.futures import ThreadPoolExecutor from gunicorn.config import Config from gunicorn.dirty.worker import DirtyWorker from gunicorn.dirty.arbiter import DirtyArbiter -from gunicorn.dirty.protocol import DirtyProtocol, make_request +from gunicorn.dirty.protocol import ( + DirtyProtocol, + BinaryProtocol, + make_request, + HEADER_SIZE, +) from gunicorn.dirty.errors import DirtyAppNotFoundError @@ -71,16 +76,22 @@ class MockStreamWriter: self._buffer += data async def drain(self): - while len(self._buffer) >= DirtyProtocol.HEADER_SIZE: - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - self._buffer[:DirtyProtocol.HEADER_SIZE] - )[0] - total_size = DirtyProtocol.HEADER_SIZE + length + # Decode the buffer to extract messages using binary protocol + while len(self._buffer) >= HEADER_SIZE: + # Decode header to get payload length + _, _, length = BinaryProtocol.decode_header( + self._buffer[:HEADER_SIZE] + ) + total_size = HEADER_SIZE + length if len(self._buffer) >= total_size: - msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size] + msg_data = self._buffer[:total_size] self._buffer = self._buffer[total_size:] - self.messages.append(DirtyProtocol.decode(msg_data)) + # decode_message returns (msg_type_str, request_id, payload_dict) + msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data) + # Reconstruct the dict format for backwards compatibility + result = {"type": msg_type_str, "id": request_id} + result.update(payload_dict) + self.messages.append(result) else: break diff --git a/tests/dirty/test_streaming_integration.py b/tests/dirty/test_streaming_integration.py index 06b9645f..b23fee38 100644 --- a/tests/dirty/test_streaming_integration.py +++ b/tests/dirty/test_streaming_integration.py @@ -18,11 +18,13 @@ from unittest import mock from gunicorn.config import Config from gunicorn.dirty.protocol import ( DirtyProtocol, + BinaryProtocol, make_request, make_chunk_message, make_end_message, make_response, make_error_response, + HEADER_SIZE, ) from gunicorn.dirty.worker import DirtyWorker from gunicorn.dirty.arbiter import DirtyArbiter @@ -67,16 +69,22 @@ class MockStreamWriter: self._buffer += data async def drain(self): - while len(self._buffer) >= DirtyProtocol.HEADER_SIZE: - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - self._buffer[:DirtyProtocol.HEADER_SIZE] - )[0] - total_size = DirtyProtocol.HEADER_SIZE + length + # Decode the buffer to extract messages using binary protocol + while len(self._buffer) >= HEADER_SIZE: + # Decode header to get payload length + _, _, length = BinaryProtocol.decode_header( + self._buffer[:HEADER_SIZE] + ) + total_size = HEADER_SIZE + length if len(self._buffer) >= total_size: - msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size] + msg_data = self._buffer[:total_size] self._buffer = self._buffer[total_size:] - self.messages.append(DirtyProtocol.decode(msg_data)) + # decode_message returns (msg_type_str, request_id, payload_dict) + msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data) + # Reconstruct the dict format for backwards compatibility + result = {"type": msg_type_str, "id": request_id} + result.update(payload_dict) + self.messages.append(result) else: break @@ -96,7 +104,7 @@ class MockStreamReader: def __init__(self, messages): self._data = b'' for msg in messages: - self._data += DirtyProtocol.encode(msg) + self._data += BinaryProtocol._encode_from_dict(msg) self._pos = 0 async def readexactly(self, n): @@ -115,10 +123,10 @@ class TestStreamingEndToEnd: """Test complete flow: sync generator -> worker -> arbiter -> client.""" # Simulate what a worker would produce for a sync generator worker_messages = [ - make_chunk_message("req-123", "Hello"), - make_chunk_message("req-123", " "), - make_chunk_message("req-123", "World"), - make_end_message("req-123"), + make_chunk_message(123, "Hello"), + make_chunk_message(123, " "), + make_chunk_message(123, "World"), + make_end_message(123), ] # Create an arbiter with mocked worker connection @@ -141,7 +149,7 @@ class TestStreamingEndToEnd: client_writer = MockStreamWriter() # Execute request through arbiter - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await arbiter._execute_on_worker(1234, request, client_writer) # Verify all messages were forwarded @@ -158,10 +166,10 @@ class TestStreamingEndToEnd: async def test_async_generator_end_to_end(self): """Test complete flow: async generator -> worker -> arbiter -> client.""" worker_messages = [ - make_chunk_message("req-456", "Async"), - make_chunk_message("req-456", " "), - make_chunk_message("req-456", "Stream"), - make_end_message("req-456"), + make_chunk_message(456, "Async"), + make_chunk_message(456, " "), + make_chunk_message(456, "Stream"), + make_end_message(456), ] cfg = Config() @@ -180,7 +188,7 @@ class TestStreamingEndToEnd: client_writer = MockStreamWriter() - request = make_request("req-456", "test:App", "async_generate") + request = make_request(456, "test:App", "async_generate") await arbiter._execute_on_worker(1234, request, client_writer) assert len(client_writer.messages) == 4 @@ -197,9 +205,9 @@ class TestStreamingErrorHandling: async def test_error_mid_stream(self): """Test that errors during streaming are properly forwarded.""" worker_messages = [ - make_chunk_message("req-789", "First"), - make_chunk_message("req-789", "Second"), - make_error_response("req-789", DirtyError("Stream failed")), + make_chunk_message(789, "First"), + make_chunk_message(789, "Second"), + make_error_response(789, DirtyError("Stream failed")), ] cfg = Config() @@ -218,7 +226,7 @@ class TestStreamingErrorHandling: client_writer = MockStreamWriter() - request = make_request("req-789", "test:App", "generate_with_error") + request = make_request(789, "test:App", "generate_with_error") await arbiter._execute_on_worker(1234, request, client_writer) # Should have 2 chunks + 1 error @@ -335,7 +343,7 @@ class TestStreamingWorkerIntegration: return sync_gen() with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await worker.handle_request(request, writer) # Should have 3 chunks + 1 end @@ -377,7 +385,7 @@ class TestStreamingWorkerIntegration: return async_gen() with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-456", "test:App", "async_generate") + request = make_request(456, "test:App", "async_generate") await worker.handle_request(request, writer) # Should have 2 chunks + 1 end diff --git a/tests/dirty/test_worker_streaming.py b/tests/dirty/test_worker_streaming.py index bb674590..6efc471d 100644 --- a/tests/dirty/test_worker_streaming.py +++ b/tests/dirty/test_worker_streaming.py @@ -12,9 +12,11 @@ import pytest from gunicorn.dirty.protocol import ( DirtyProtocol, + BinaryProtocol, make_request, make_chunk_message, make_end_message, + HEADER_SIZE, ) from gunicorn.dirty.worker import DirtyWorker @@ -30,17 +32,22 @@ class FakeStreamWriter: self._buffer += data async def drain(self): - # Decode the buffer to extract messages - while len(self._buffer) >= DirtyProtocol.HEADER_SIZE: - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - self._buffer[:DirtyProtocol.HEADER_SIZE] - )[0] - total_size = DirtyProtocol.HEADER_SIZE + length + # Decode the buffer to extract messages using binary protocol + while len(self._buffer) >= HEADER_SIZE: + # Decode header to get payload length + _, _, length = BinaryProtocol.decode_header( + self._buffer[:HEADER_SIZE] + ) + total_size = HEADER_SIZE + length if len(self._buffer) >= total_size: - msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size] + msg_data = self._buffer[:total_size] self._buffer = self._buffer[total_size:] - self.messages.append(DirtyProtocol.decode(msg_data)) + # decode_message returns (msg_type_str, request_id, payload_dict) + msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data) + # Reconstruct the dict format for backwards compatibility + result = {"type": msg_type_str, "id": request_id} + result.update(payload_dict) + self.messages.append(result) else: break @@ -101,7 +108,7 @@ class TestWorkerSyncGeneratorStreaming: return generate_tokens() with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await worker.handle_request(request, writer) # Should have 3 chunks + 1 end message @@ -109,7 +116,7 @@ class TestWorkerSyncGeneratorStreaming: # Check chunk messages assert writer.messages[0]["type"] == "chunk" - assert writer.messages[0]["id"] == "req-123" + assert writer.messages[0]["id"] == 123 assert writer.messages[0]["data"] == "Hello" assert writer.messages[1]["type"] == "chunk" @@ -120,7 +127,7 @@ class TestWorkerSyncGeneratorStreaming: # Check end message assert writer.messages[3]["type"] == "end" - assert writer.messages[3]["id"] == "req-123" + assert writer.messages[3]["id"] == 123 @pytest.mark.asyncio async def test_sync_generator_error_mid_stream(self): @@ -136,7 +143,7 @@ class TestWorkerSyncGeneratorStreaming: return generate_with_error() with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await worker.handle_request(request, writer) # Should have 1 chunk + 1 error message @@ -167,7 +174,7 @@ class TestWorkerAsyncGeneratorStreaming: return async_generate_tokens() with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await worker.handle_request(request, writer) # Should have 3 chunks + 1 end message @@ -175,7 +182,7 @@ class TestWorkerAsyncGeneratorStreaming: # Check chunk messages assert writer.messages[0]["type"] == "chunk" - assert writer.messages[0]["id"] == "req-123" + assert writer.messages[0]["id"] == 123 assert writer.messages[0]["data"] == "Hello" assert writer.messages[1]["type"] == "chunk" @@ -186,7 +193,7 @@ class TestWorkerAsyncGeneratorStreaming: # Check end message assert writer.messages[3]["type"] == "end" - assert writer.messages[3]["id"] == "req-123" + assert writer.messages[3]["id"] == 123 @pytest.mark.asyncio async def test_async_generator_error_mid_stream(self): @@ -202,7 +209,7 @@ class TestWorkerAsyncGeneratorStreaming: return async_generate_with_error() with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await worker.handle_request(request, writer) # Should have 1 chunk + 1 error message @@ -228,13 +235,13 @@ class TestWorkerNonStreamingBackwardCompat: return args[0] + args[1] with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "compute", args=(2, 3)) + request = make_request(123, "test:App", "compute", args=(2, 3)) await worker.handle_request(request, writer) # Should have 1 response message assert len(writer.messages) == 1 assert writer.messages[0]["type"] == "response" - assert writer.messages[0]["id"] == "req-123" + assert writer.messages[0]["id"] == 123 assert writer.messages[0]["result"] == 5 @pytest.mark.asyncio @@ -247,7 +254,7 @@ class TestWorkerNonStreamingBackwardCompat: return [1, 2, 3, 4, 5] with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "get_list") + request = make_request(123, "test:App", "get_list") await worker.handle_request(request, writer) # Should have 1 response message (not 5 chunks) @@ -265,7 +272,7 @@ class TestWorkerNonStreamingBackwardCompat: raise RuntimeError("Failed!") with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "fail") + request = make_request(123, "test:App", "fail") await worker.handle_request(request, writer) # Should have 1 error message @@ -283,7 +290,7 @@ class TestWorkerNonStreamingBackwardCompat: return None with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "void") + request = make_request(123, "test:App", "void") await worker.handle_request(request, writer) # Should have 1 response message @@ -309,7 +316,7 @@ class TestWorkerStreamingComplexData: return generate_tokens() with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await worker.handle_request(request, writer) assert len(writer.messages) == 3 # 2 chunks + 1 end @@ -332,7 +339,7 @@ class TestWorkerStreamingComplexData: return empty_generate() with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await worker.handle_request(request, writer) # Should have just 1 end message @@ -353,7 +360,7 @@ class TestWorkerStreamingComplexData: return generate_many() with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await worker.handle_request(request, writer) # Should have 100 chunks + 1 end message @@ -390,7 +397,7 @@ class TestWorkerStreamingHeartbeat: return generate_tokens() with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await worker.handle_request(request, writer) # Should have been notified at least once per chunk + initial @@ -407,7 +414,7 @@ class TestWorkerMessageTypeValidation: writer = FakeStreamWriter() # Send a message with unknown type - message = {"type": "unknown", "id": "req-123"} + message = {"type": "unknown", "id": 123} await worker.handle_request(message, writer) assert len(writer.messages) == 1 diff --git a/tests/docker/http2/certs/server.crt b/tests/docker/http2/certs/server.crt new file mode 100644 index 00000000..b4056d76 --- /dev/null +++ b/tests/docker/http2/certs/server.crt @@ -0,0 +1,21 @@ +-----BEGIN CERTIFICATE----- +MIIDfDCCAmSgAwIBAgIUDxTarKRHe0FIyczGmoYwm377ZpcwDQYJKoZIhvcNAQEL +BQAwOTESMBAGA1UEAwwJbG9jYWxob3N0MRYwFAYDVQQKDA1HdW5pY29ybiBUZXN0 +MQswCQYDVQQGEwJVUzAeFw0yNjAyMDUxMTE1MjJaFw0yNjAyMDYxMTE1MjJaMDkx +EjAQBgNVBAMMCWxvY2FsaG9zdDEWMBQGA1UECgwNR3VuaWNvcm4gVGVzdDELMAkG +A1UEBhMCVVMwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCRQTHakkqY +6l6dMqfs4oiY98+rjvZubwjp0PH7UBuxXCi/4Ao78o0JhKcs+jgAGAXyb8eRjEKt +z4rPoHZYE91D/eD0lWAz9r/LRoutDJd9IO0rfDtHlYXamciuxJJ8cckOrnuTXLtq +AWqjKR3U9RIDD3eumCKG4l7Py0L67zTomwMRPfeIdlWBfxGjWMqOewdTc/O/cuK2 +HL5JP2ixy+iTufs0jhljI9cbu49J606f+TQH9eXRTD716q+KsHPJX1X5dVd7V7Lr +FIp7wSUFdbiy56JfmrGmfJbZgFH67P0ZyiTpQBaVHt1YYRIcOUJZqM+0MAtrsySC +TNA/LsI8tsybAgMBAAGjfDB6MB0GA1UdDgQWBBRK2VkAeM0hL4j/45ckkKbGrb/Q +FjAfBgNVHSMEGDAWgBRK2VkAeM0hL4j/45ckkKbGrb/QFjAPBgNVHRMBAf8EBTAD +AQH/MCcGA1UdEQQgMB6CCWxvY2FsaG9zdIILZ3VuaWNvcm4taDKHBH8AAAEwDQYJ +KoZIhvcNAQELBQADggEBAAXwuw0KTQUC4UEFudQ1rceK6By9WCSJND7xJi+UQ50G +Zrp5tJ2YB4ZWY+APadfuJo+zUxYVZ3jhs0mxgVeiGdDW6yZdHkeX8MlXBTLHR+/a +A7DXn6wCw9NDeDtcY/bKg5iamvoGGTL6szPrqeuZPz4UdbsFlr0MdcjgSNOqnkjr +YS4ukgZ71aWSjfraRRPjFMzkfnQ1xm96A1ngMH4DvU/t62D7r8+SvxQ8M6ERL84Z +FBu4bTXDdYIjJ24ojmDDO2irTVW1FMGXQTPzMaTEbE1rvBYeEYhf10KiMynK9xfO +5j8LWmCkgek0CqBrf3zbDEwu8QxcaxITAIUkSXLOZbo= +-----END CERTIFICATE----- diff --git a/tests/docker/http2/certs/server.key b/tests/docker/http2/certs/server.key new file mode 100644 index 00000000..3d472c7c --- /dev/null +++ b/tests/docker/http2/certs/server.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQCRQTHakkqY6l6d +Mqfs4oiY98+rjvZubwjp0PH7UBuxXCi/4Ao78o0JhKcs+jgAGAXyb8eRjEKtz4rP +oHZYE91D/eD0lWAz9r/LRoutDJd9IO0rfDtHlYXamciuxJJ8cckOrnuTXLtqAWqj +KR3U9RIDD3eumCKG4l7Py0L67zTomwMRPfeIdlWBfxGjWMqOewdTc/O/cuK2HL5J +P2ixy+iTufs0jhljI9cbu49J606f+TQH9eXRTD716q+KsHPJX1X5dVd7V7LrFIp7 +wSUFdbiy56JfmrGmfJbZgFH67P0ZyiTpQBaVHt1YYRIcOUJZqM+0MAtrsySCTNA/ +LsI8tsybAgMBAAECggEANBhGOYZLI9G2sjlXOaG7bOU/wV9KKaw/7Z/HEaOW8wLD +CKHg+cQRai79yCdLi1kSVPNbB2vfBDhRqAp8NzWUn0x/8ChcsvZVriF0edFwyWtU +NErfddp+Absy2t9cTC6A9feFEYJqIug0JyVZciWc2qUi/ubIR0kLyQm00YuWFa/s +GJou8Nhg70rqW+3FB1H8kAEXqob+PFW4xbTwexw1+MbHxN7UKLTzS8uzYGLo2UpB +7bksumyD0o+lZtlx9HZ6CwrB6IPjgJ0HyaD8SrOc7/ozd7rR2LmvMmBCV1uC5VSO +jhr0PScLoNv60fjkVOiF9uqaPY2kNKymsOzpZ7/mwQKBgQDMcz+ve8WGGbE+bbM7 +2uinQ5smm8rWPnfbHJIHQUetrEQKljRovybmjiiXN08uxlX6VA/Vnp4fmL5fzsTD +xTeiCVPsR1huXIfMLGJ6crUgvlbiaB8XsxtVNBpfEEtBe27qjSIj3xtmwqM6+LD1 +FKLsYzgotHUH9JwyLA1RMKPBwQKBgQC14QWtI5YtZcTX46BqxlZ07iAAuy19Jywn +UtgmTawkJuEcseewIjxtJkMz+aSy7V3PsLII8tY48oSjAVx84w50zLJ2OlJnFT1S +zEmIOu9YDcGLZkYXJ2AwndRAIXpJVHwtFM9eDSMh+wVPBFeboYP1dO/VxmN6QV0W +GqDaQfItWwKBgEb31mp2n0j+UB0ofSfQxCOTfx62w4D87CPd1f64tUXe3zuBii21 +9K3hOMvMwiqtZBjh5yEyzxaOsb6WCo0eP0J61GvXFCYy7lx8J67zdFYqXAR5OhnC +7UD1NhY7lLPlQcofNXOYNW3FMF3/B4X7JNbDVjIi+eDKExIDYpgFN0LBAoGADGCf +7kR5t+UxHDAVfq64u4RpESOr2NSNoK92nkSy7lLnBvjkd4wc6KCt+h+HIdYdiEDS +HOHJyl5WwHEbRjR9i11S19DoQrOjVLsqVecM2sU04rO3GWRIm4ZiJ2sf01W4jajY +4+Go/msC1XnKLIE1ZcLrf3Tc2DkSiKqPP8s1G/kCgYA8sCPAXedwhULhOBM45x4J +vkwT1Icm5RHOwOr8t34IFozTLokba6pjhYua3nE+V3FglRct7NpX+Op4gUgHa80g +5zoHboq5/pTUTclx41jndC1YGa3NLvthDWTWmyo/Qj7F/R7jGJf8E3KUDe0tFoSp +JlfEuUHtKpFJReBnmWTFiQ== +-----END PRIVATE KEY----- diff --git a/tests/test_dirty_integration.py b/tests/test_dirty_integration.py index f24c6894..a841cf2c 100644 --- a/tests/test_dirty_integration.py +++ b/tests/test_dirty_integration.py @@ -11,7 +11,7 @@ import pytest from gunicorn.arbiter import Arbiter from gunicorn.config import Config from gunicorn.app.base import BaseApplication -from gunicorn.dirty.protocol import DirtyProtocol +from gunicorn.dirty.protocol import DirtyProtocol, BinaryProtocol, HEADER_SIZE class MockStreamWriter: @@ -26,16 +26,22 @@ class MockStreamWriter: self._buffer += data async def drain(self): - while len(self._buffer) >= DirtyProtocol.HEADER_SIZE: - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - self._buffer[:DirtyProtocol.HEADER_SIZE] - )[0] - total_size = DirtyProtocol.HEADER_SIZE + length + # Decode the buffer to extract messages using binary protocol + while len(self._buffer) >= HEADER_SIZE: + # Decode header to get payload length + _, _, length = BinaryProtocol.decode_header( + self._buffer[:HEADER_SIZE] + ) + total_size = HEADER_SIZE + length if len(self._buffer) >= total_size: - msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size] + msg_data = self._buffer[:total_size] self._buffer = self._buffer[total_size:] - self.messages.append(DirtyProtocol.decode(msg_data)) + # decode_message returns (msg_type_str, request_id, payload_dict) + msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data) + # Reconstruct the dict format for backwards compatibility + result = {"type": msg_type_str, "id": request_id} + result.update(payload_dict) + self.messages.append(result) else: break