diff --git a/docs/content/dirty.md b/docs/content/dirty.md index 8afede06..bb98f6ee 100644 --- a/docs/content/dirty.md +++ b/docs/content/dirty.md @@ -17,7 +17,7 @@ Dirty Arbiters solve this by providing: - **Separate worker pool** - Completely separate from HTTP workers, can be killed/restarted independently - **Stateful workers** - Loaded resources persist in dirty worker memory -- **Message-passing IPC** - Communication via Unix sockets with JSON serialization +- **Message-passing IPC** - Communication via Unix sockets with binary TLV protocol - **Explicit API** - Clear `execute()` calls (no hidden IPC) - **Asyncio-based** - Clean concurrent handling with streaming support @@ -476,6 +476,64 @@ Worker -> Arbiter -> Client: chunk (data: "World") Worker -> Arbiter -> Client: end ``` +## Binary Protocol + +The dirty worker IPC uses a binary protocol inspired by OpenBSD msgctl/msgsnd for efficient data transfer. This eliminates base64 encoding overhead for binary data like images, audio, or model weights. + +### Header Format (16 bytes) + +``` ++--------+--------+--------+--------+--------+--------+--------+--------+ +| Magic (2B) | Ver(1) | MType | Payload Length (4B) | ++--------+--------+--------+--------+--------+--------+--------+--------+ +| Request ID (8 bytes) | ++--------+--------+--------+--------+--------+--------+--------+--------+ +``` + +- **Magic**: `0x47 0x44` ("GD" for Gunicorn Dirty) +- **Version**: `0x01` +- **MType**: Message type (`0x01`=REQUEST, `0x02`=RESPONSE, `0x03`=ERROR, `0x04`=CHUNK, `0x05`=END) +- **Length**: Payload size (big-endian uint32, max 64MB) +- **Request ID**: uint64 identifier + +### TLV Payload Encoding + +Payloads use Type-Length-Value encoding: + +| Type | Code | Description | +|------|------|-------------| +| None | `0x00` | No value bytes | +| Bool | `0x01` | 1 byte (0x00/0x01) | +| Int64 | `0x05` | 8 bytes big-endian signed | +| Float64 | `0x06` | 8 bytes IEEE 754 | +| Bytes | `0x10` | 4-byte length + raw bytes | +| String | `0x11` | 4-byte length + UTF-8 | +| List | `0x20` | 4-byte count + elements | +| Dict | `0x21` | 4-byte count + key-value pairs | + +### Binary Data Benefits + +The binary protocol allows passing raw bytes directly without encoding: + +```python +# Image processing with binary data +def resize(self, image_data, width, height): + """Resize an image - image_data is raw bytes.""" + img = Image.open(io.BytesIO(image_data)) + resized = img.resize((width, height)) + buffer = io.BytesIO() + resized.save(buffer, format='PNG') + return buffer.getvalue() # Returns raw bytes + +# Called from HTTP worker +thumbnail = client.execute( + "myapp.images:ImageApp", + "thumbnail", + raw_image_bytes, # No base64 encoding needed + size=256 +) +``` + ### Error Handling in Streams Errors during streaming are delivered as error messages: @@ -768,7 +826,7 @@ except DirtyConnectionError: 2. **Set appropriate timeouts** based on your workload 3. **Handle errors gracefully** - dirty workers may restart 4. **Use meaningful action names** for easier debugging -5. **Keep responses JSON-serializable** - results are passed via IPC +5. **Keep responses serializable** - results are passed via binary IPC (supports bytes directly) ## Monitoring diff --git a/examples/celery_alternative/docker-compose.yml b/examples/celery_alternative/docker-compose.yml index ebf8d303..66ce282e 100644 --- a/examples/celery_alternative/docker-compose.yml +++ b/examples/celery_alternative/docker-compose.yml @@ -15,7 +15,7 @@ services: context: ../.. # Gunicorn repo root dockerfile: examples/celery_alternative/Dockerfile ports: - - "8000:8000" + - "8003:8000" environment: - GUNICORN_WORKERS=4 - DIRTY_WORKERS=9 diff --git a/examples/dirty_example/Dockerfile b/examples/dirty_example/Dockerfile new file mode 100644 index 00000000..302578dc --- /dev/null +++ b/examples/dirty_example/Dockerfile @@ -0,0 +1,25 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +FROM python:3.12-slim + +WORKDIR /app + +# Copy gunicorn source +COPY . /app/gunicorn-src + +# Install gunicorn and dependencies +# setproctitle is needed for process title changes +RUN pip install --no-cache-dir /app/gunicorn-src setproctitle + +# Copy example files +COPY examples/dirty_example/ /app/examples/dirty_example/ + +WORKDIR /app + +# Expose the port +EXPOSE 8000 + +# Default command - run the example tests +CMD ["python", "-m", "pytest", "-v", "examples/dirty_example/"] diff --git a/examples/dirty_example/docker-compose.yml b/examples/dirty_example/docker-compose.yml new file mode 100644 index 00000000..c55669a3 --- /dev/null +++ b/examples/dirty_example/docker-compose.yml @@ -0,0 +1,54 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +services: + # Run the example tests (protocol, dirty app, worker integration) + tests: + build: + context: ../.. + dockerfile: examples/dirty_example/Dockerfile + command: > + bash -c " + echo '=== Running Protocol Tests ===' && + python examples/dirty_example/test_protocol.py && + echo '' && + echo '=== Running Dirty App Tests ===' && + python examples/dirty_example/test_dirty_app.py && + echo '' && + echo '=== Running Worker Integration Tests ===' && + python examples/dirty_example/test_worker_integration.py && + echo '' && + echo '=== All tests passed! ===' + " + + # Run the full gunicorn server with dirty workers + server: + build: + context: ../.. + dockerfile: examples/dirty_example/Dockerfile + ports: + - "8001:8000" + environment: + - GUNICORN_BIND=0.0.0.0:8000 + command: > + gunicorn examples.dirty_example.wsgi_app:app + -c examples/dirty_example/gunicorn_conf.py + healthcheck: + test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/')"] + interval: 5s + timeout: 5s + retries: 5 + start_period: 10s + + # Run integration test against the server + integration-test: + build: + context: ../.. + dockerfile: examples/dirty_example/Dockerfile + depends_on: + server: + condition: service_healthy + environment: + - TEST_BASE_URL=http://server:8000 + command: python examples/dirty_example/test_integration.py diff --git a/examples/dirty_example/gunicorn_conf.py b/examples/dirty_example/gunicorn_conf.py index a7877c3e..ba3d28e4 100644 --- a/examples/dirty_example/gunicorn_conf.py +++ b/examples/dirty_example/gunicorn_conf.py @@ -11,7 +11,9 @@ Run with: """ # Basic settings -bind = "127.0.0.1:8000" +# Use 0.0.0.0 for Docker, override with GUNICORN_BIND env var if needed +import os +bind = os.environ.get("GUNICORN_BIND", "127.0.0.1:8000") workers = 2 worker_class = "sync" timeout = 30 diff --git a/examples/dirty_example/test_integration.py b/examples/dirty_example/test_integration.py new file mode 100644 index 00000000..20008522 --- /dev/null +++ b/examples/dirty_example/test_integration.py @@ -0,0 +1,81 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +#!/usr/bin/env python +""" +Integration test for the dirty example server. + +This tests that the full gunicorn server with dirty workers responds +correctly to HTTP requests. + +Run with: + python examples/dirty_example/test_integration.py [base_url] + +Default base_url is http://localhost:8000 +""" + +import sys +import os +import json +import urllib.request +import urllib.error + + +def test_endpoint(base, path, expected_key=None): + """Test an endpoint and check for expected key in response.""" + url = base + path + print(f"Testing: {url}") + try: + with urllib.request.urlopen(url, timeout=10) as resp: + data = json.loads(resp.read()) + print(f" Response: {str(data)[:200]}") + if expected_key and expected_key not in data: + print(f" ERROR: Expected key '{expected_key}' not found!") + return False + return True + except urllib.error.HTTPError as e: + print(f" HTTP ERROR {e.code}: {e.reason}") + return False + except Exception as e: + print(f" ERROR: {e}") + return False + + +def main(): + # Get base URL from env or command line + base = os.environ.get("TEST_BASE_URL", "http://localhost:8000") + if len(sys.argv) > 1: + base = sys.argv[1] + + print(f"Testing dirty example server at: {base}") + print("=" * 60) + + # Define tests: (path, expected_key_in_response) + tests = [ + ("/", "endpoints"), + ("/models", "models"), + ("/load?name=test-model", "status"), + ("/inference?model=default&data=hello", "prediction"), + ("/fibonacci?n=20", "result"), + ("/prime?n=17", "is_prime"), + ("/stats", "ml_app"), + ("/unload?name=test-model", "status"), + ] + + failed = 0 + for path, key in tests: + if not test_endpoint(base, path, key): + failed += 1 + print() + + print("=" * 60) + if failed: + print(f"FAILED: {failed} tests failed") + sys.exit(1) + else: + print("SUCCESS: All integration tests passed!") + + +if __name__ == "__main__": + main() diff --git a/examples/dirty_example/test_protocol.py b/examples/dirty_example/test_protocol.py index 31c45b92..8a677d03 100644 --- a/examples/dirty_example/test_protocol.py +++ b/examples/dirty_example/test_protocol.py @@ -4,7 +4,10 @@ #!/usr/bin/env python """ -Test script to demonstrate the Dirty Protocol layer. +Test script to demonstrate the Dirty Binary Protocol layer. + +The binary protocol uses a 16-byte header + TLV-encoded payloads for efficient +binary data transfer without base64 encoding overhead. Run with: python examples/dirty_example/test_protocol.py @@ -18,10 +21,14 @@ import socket sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) from gunicorn.dirty.protocol import ( + BinaryProtocol, DirtyProtocol, make_request, make_response, make_error_response, + HEADER_SIZE, + MAGIC, + VERSION, ) from gunicorn.dirty.errors import DirtyError, DirtyTimeoutError @@ -29,13 +36,13 @@ from gunicorn.dirty.errors import DirtyError, DirtyTimeoutError def test_protocol_encode_decode(): """Test protocol encoding and decoding.""" print("=" * 60) - print("Testing Protocol Encode/Decode") + print("Testing Binary Protocol Encode/Decode") print("=" * 60) - # Test request + # Test request with integer ID (recommended for binary protocol) print("\n1. Creating a request message...") request = make_request( - request_id="req-001", + request_id=12345, # Integer IDs are efficient app_path="myapp.ml:MLApp", action="inference", args=("model1",), @@ -43,18 +50,53 @@ def test_protocol_encode_decode(): ) print(f" Request: {request}") - # Encode - print("\n2. Encoding message...") - encoded = DirtyProtocol.encode(request) + # Encode using binary protocol + print("\n2. Encoding message with binary protocol...") + encoded = BinaryProtocol._encode_from_dict(request) print(f" Encoded length: {len(encoded)} bytes") - print(f" Header (4 bytes): {encoded[:4].hex()}") + print(f" Header ({HEADER_SIZE} bytes): {encoded[:HEADER_SIZE].hex()}") + print(f" Magic: {MAGIC!r}") + print(f" Version: {VERSION}") - # Decode - print("\n3. Decoding payload...") - payload = encoded[DirtyProtocol.HEADER_SIZE:] - decoded = DirtyProtocol.decode(payload) - print(f" Decoded: {decoded}") - print(f" Match: {decoded == request}") + # Decode header + print("\n3. Decoding header...") + msg_type, request_id, payload_len = BinaryProtocol.decode_header(encoded[:HEADER_SIZE]) + print(f" Message type: {msg_type} (0x{msg_type:02x})") + print(f" Request ID: {request_id}") + print(f" Payload length: {payload_len} bytes") + + # Decode full message + print("\n4. Decoding full message...") + msg_type_str, req_id, payload = BinaryProtocol.decode_message(encoded) + print(f" Type: {msg_type_str}") + print(f" Request ID: {req_id}") + print(f" Payload: {payload}") + + +def test_binary_data_handling(): + """Test binary data handling - the main advantage of binary protocol.""" + print("\n" + "=" * 60) + print("Testing Binary Data Handling") + print("=" * 60) + + # Create binary data (e.g., image, audio, model weights) + binary_data = bytes(range(256)) # All byte values + print(f"\n1. Original binary data: {len(binary_data)} bytes") + print(f" First 16 bytes: {binary_data[:16].hex()}") + + # Create response with binary data (no base64 needed!) + print("\n2. Encoding binary data in response...") + response = make_response(67890, {"image_data": binary_data, "format": "raw"}) + encoded = BinaryProtocol._encode_from_dict(response) + print(f" Encoded total size: {len(encoded)} bytes") + + # Decode and verify + print("\n3. Decoding binary data...") + msg_type_str, req_id, payload = BinaryProtocol.decode_message(encoded) + recovered_data = payload["result"]["image_data"] + print(f" Recovered data size: {len(recovered_data)} bytes") + print(f" Data matches: {recovered_data == binary_data}") + print(f" First 16 bytes: {recovered_data[:16].hex()}") def test_protocol_response(): @@ -65,13 +107,13 @@ def test_protocol_response(): # Success response print("\n1. Creating success response...") - response = make_response("req-001", {"result": "Hello, World!", "confidence": 0.95}) + response = make_response(12345, {"result": "Hello, World!", "confidence": 0.95}) print(f" Response: {response}") # Error response print("\n2. Creating error response...") error = DirtyTimeoutError("Operation timed out", timeout=30) - error_response = make_error_response("req-001", error) + error_response = make_error_response(12345, error) print(f" Error response: {error_response}") @@ -88,7 +130,7 @@ def test_socket_communication(): # Send a request print("\n1. Sending request over socket...") request = make_request( - request_id="socket-req-001", + request_id=100001, app_path="test:App", action="compute", args=(1, 2, 3), @@ -101,19 +143,20 @@ def test_socket_communication(): print("\n2. Receiving request...") received = DirtyProtocol.read_message(server_sock) print(f" Received: {received}") - print(f" Match: {received == request}") + print(f" Request ID: {received['id']}") - # Send a response - print("\n3. Sending response...") - response = make_response("socket-req-001", {"sum": 6}) + # Send a response with binary data + print("\n3. Sending response with binary data...") + binary_result = b"\x00\x01\x02\x03\xff\xfe\xfd\xfc" + response = make_response(100001, {"data": binary_result, "sum": 6}) DirtyProtocol.write_message(server_sock, response) - print(f" Sent: {response}") + print(f" Sent binary data: {binary_result.hex()}") # Receive the response print("\n4. Receiving response...") received = DirtyProtocol.read_message(client_sock) - print(f" Received: {received}") - print(f" Match: {received == response}") + print(f" Received binary data: {received['result']['data'].hex()}") + print(f" Sum: {received['result']['sum']}") finally: server_sock.close() @@ -132,7 +175,7 @@ async def test_async_communication(): try: # Create message request = make_request( - request_id="async-req-001", + request_id=200001, app_path="async:App", action="process", args=("data",), @@ -141,7 +184,7 @@ async def test_async_communication(): # Write to pipe print("\n1. Writing async message...") - encoded = DirtyProtocol.encode(request) + encoded = BinaryProtocol._encode_from_dict(request) os.write(write_fd, encoded) os.close(write_fd) write_fd = None @@ -156,7 +199,7 @@ async def test_async_communication(): received = await DirtyProtocol.read_message_async(reader) print(f" Received: {received}") - print(f" Match: {received == request}") + print(f" Request ID: {received['id']}") finally: if write_fd is not None: @@ -193,10 +236,11 @@ def test_error_serialization(): if __name__ == "__main__": print("\n" + "#" * 60) - print("# Dirty Protocol Demonstration") + print("# Dirty Binary Protocol Demonstration") print("#" * 60) test_protocol_encode_decode() + test_binary_data_handling() test_protocol_response() test_socket_communication() asyncio.run(test_async_communication()) diff --git a/examples/dirty_example/test_worker_integration.py b/examples/dirty_example/test_worker_integration.py index 1711ee3d..acca9961 100644 --- a/examples/dirty_example/test_worker_integration.py +++ b/examples/dirty_example/test_worker_integration.py @@ -22,7 +22,7 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspa from gunicorn.config import Config from gunicorn.dirty.worker import DirtyWorker -from gunicorn.dirty.protocol import DirtyProtocol, make_request +from gunicorn.dirty.protocol import DirtyProtocol, BinaryProtocol, make_request, HEADER_SIZE class MockLog: @@ -35,6 +35,36 @@ class MockLog: def reopen_files(self): pass +class MockWriter: + """Mock StreamWriter that captures written responses.""" + + def __init__(self): + self.messages = [] + self._buffer = b"" + + def write(self, data): + self._buffer += data + + async def drain(self): + # Decode messages from buffer using binary protocol + while len(self._buffer) >= HEADER_SIZE: + _, _, length = BinaryProtocol.decode_header(self._buffer[:HEADER_SIZE]) + total_size = HEADER_SIZE + length + if len(self._buffer) >= total_size: + msg_data = self._buffer[:total_size] + self._buffer = self._buffer[total_size:] + msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data) + result = {"type": msg_type_str, "id": request_id} + result.update(payload_dict) + self.messages.append(result) + else: + break + + def get_last_response(self): + """Get the last response message.""" + return self.messages[-1] if self.messages else None + + async def test_worker_request_handling(): """Test that a worker can load apps and handle requests.""" print("=" * 60) @@ -75,52 +105,60 @@ async def test_worker_request_handling(): # Test handle_request with a proper request message print("\n3. Testing handle_request() - load_model...") request = make_request( - request_id="test-001", + request_id=1001, app_path="examples.dirty_example.dirty_app:MLApp", action="load_model", args=("gpt-4",), kwargs={} ) - response = await worker.handle_request(request) + writer = MockWriter() + await worker.handle_request(request, writer) + response = writer.get_last_response() print(f" Response type: {response['type']}") print(f" Result: {response.get('result', response.get('error'))}") # Test inference print("\n4. Testing handle_request() - inference...") request = make_request( - request_id="test-002", + request_id=1002, app_path="examples.dirty_example.dirty_app:MLApp", action="inference", args=("default", "Hello AI!"), kwargs={} ) - response = await worker.handle_request(request) + writer = MockWriter() + await worker.handle_request(request, writer) + response = writer.get_last_response() print(f" Response type: {response['type']}") print(f" Result: {response.get('result', response.get('error'))}") # Test error handling print("\n5. Testing error handling - unknown action...") request = make_request( - request_id="test-003", + request_id=1003, app_path="examples.dirty_example.dirty_app:MLApp", action="nonexistent_action", args=(), kwargs={} ) - response = await worker.handle_request(request) + writer = MockWriter() + await worker.handle_request(request, writer) + response = writer.get_last_response() print(f" Response type: {response['type']}") print(f" Error: {response.get('error', {}).get('message')}") # Test app not found print("\n6. Testing error handling - app not found...") request = make_request( - request_id="test-004", + request_id=1004, app_path="nonexistent:App", action="test", args=(), kwargs={} ) - response = await worker.handle_request(request) + writer = MockWriter() + await worker.handle_request(request, writer) + response = writer.get_last_response() print(f" Response type: {response['type']}") print(f" Error type: {response.get('error', {}).get('error_type')}") diff --git a/gunicorn/dirty/protocol.py b/gunicorn/dirty/protocol.py index e5ac6cfa..15fab29a 100644 --- a/gunicorn/dirty/protocol.py +++ b/gunicorn/dirty/protocol.py @@ -3,89 +3,304 @@ # See the NOTICE for more information. """ -Dirty Arbiters Protocol +Dirty Worker Binary Protocol -Length-prefixed JSON message framing over Unix sockets. -Provides both async (primary) and sync (for HTTP workers) APIs. +Binary message framing over Unix sockets, inspired by OpenBSD msgctl/msgsnd. +Replaces JSON protocol for efficient binary data transfer. -Message Format: -+----------------+------------------+ -| 4-byte length | JSON payload | -+----------------+------------------+ +Header Format (16 bytes): ++--------+--------+--------+--------+--------+--------+--------+--------+ +| Magic (2B) | Ver(1) | MType | Payload Length (4B) | ++--------+--------+--------+--------+--------+--------+--------+--------+ +| Request ID (8 bytes) | ++--------+--------+--------+--------+--------+--------+--------+--------+ -The length field is a 4-byte unsigned integer in network byte order (big-endian). +- Magic: 0x47 0x44 ("GD" for Gunicorn Dirty) +- Version: 0x01 +- MType: Message type (REQUEST, RESPONSE, ERROR, CHUNK, END) +- Length: Payload size (big-endian uint32, max 64MB) +- Request ID: uint64 (replaces UUID string) + +Payload is TLV-encoded (see tlv.py). """ import asyncio -import json -import struct import socket +import struct from .errors import DirtyProtocolError +from .tlv import TLVEncoder -class DirtyProtocol: - """Length-prefixed JSON messages over Unix sockets.""" +# Protocol constants +MAGIC = b"GD" # 0x47 0x44 +VERSION = 0x01 - # 4-byte unsigned int, network byte order (big-endian) - HEADER_FORMAT = "!I" - HEADER_SIZE = struct.calcsize(HEADER_FORMAT) +# Message types (1 byte) +MSG_TYPE_REQUEST = 0x01 +MSG_TYPE_RESPONSE = 0x02 +MSG_TYPE_ERROR = 0x03 +MSG_TYPE_CHUNK = 0x04 +MSG_TYPE_END = 0x05 - # Maximum message size (64 MB) - MAX_MESSAGE_SIZE = 64 * 1024 * 1024 +# Message type names (for backwards compatibility with old API) +MSG_TYPE_REQUEST_STR = "request" +MSG_TYPE_RESPONSE_STR = "response" +MSG_TYPE_ERROR_STR = "error" +MSG_TYPE_CHUNK_STR = "chunk" +MSG_TYPE_END_STR = "end" - # Message types for future streaming support - MSG_TYPE_REQUEST = "request" - MSG_TYPE_RESPONSE = "response" - MSG_TYPE_ERROR = "error" - MSG_TYPE_CHUNK = "chunk" - MSG_TYPE_END = "end" +# Map int types to string names +MSG_TYPE_TO_STR = { + MSG_TYPE_REQUEST: MSG_TYPE_REQUEST_STR, + MSG_TYPE_RESPONSE: MSG_TYPE_RESPONSE_STR, + MSG_TYPE_ERROR: MSG_TYPE_ERROR_STR, + MSG_TYPE_CHUNK: MSG_TYPE_CHUNK_STR, + MSG_TYPE_END: MSG_TYPE_END_STR, +} + +# Map string names to int types +MSG_TYPE_FROM_STR = {v: k for k, v in MSG_TYPE_TO_STR.items()} + +# Header format: Magic (2) + Version (1) + Type (1) + Length (4) + RequestID (8) = 16 +HEADER_FORMAT = ">2sBBIQ" +HEADER_SIZE = struct.calcsize(HEADER_FORMAT) + +# Maximum message size (64 MB) +MAX_MESSAGE_SIZE = 64 * 1024 * 1024 + + +class BinaryProtocol: + """Binary message protocol for dirty worker IPC.""" + + # Export constants for external use + HEADER_SIZE = HEADER_SIZE + MAX_MESSAGE_SIZE = MAX_MESSAGE_SIZE + + MSG_TYPE_REQUEST = MSG_TYPE_REQUEST_STR + MSG_TYPE_RESPONSE = MSG_TYPE_RESPONSE_STR + MSG_TYPE_ERROR = MSG_TYPE_ERROR_STR + MSG_TYPE_CHUNK = MSG_TYPE_CHUNK_STR + MSG_TYPE_END = MSG_TYPE_END_STR @staticmethod - def encode(message: dict) -> bytes: + def encode_header(msg_type: int, request_id: int, payload_length: int) -> bytes: """ - Encode a message dict to length-prefixed bytes. + Encode the 16-byte message header. Args: - message: Dictionary to encode as JSON + msg_type: Message type (MSG_TYPE_REQUEST, etc.) + request_id: Unique request identifier (uint64) + payload_length: Length of the TLV-encoded payload Returns: - bytes: Length-prefixed encoded message + bytes: 16-byte header + """ + return struct.pack(HEADER_FORMAT, MAGIC, VERSION, msg_type, + payload_length, request_id) + + @staticmethod + def decode_header(data: bytes) -> tuple: + """ + Decode the 16-byte message header. + + Args: + data: 16 bytes of header data + + Returns: + tuple: (msg_type, request_id, payload_length) Raises: - DirtyProtocolError: If encoding fails + DirtyProtocolError: If header is invalid """ - try: - payload = json.dumps(message).encode("utf-8") - if len(payload) > DirtyProtocol.MAX_MESSAGE_SIZE: + if len(data) < HEADER_SIZE: + raise DirtyProtocolError( + f"Header too short: {len(data)} bytes, expected {HEADER_SIZE}", + raw_data=data + ) + + magic, version, msg_type, length, request_id = struct.unpack( + HEADER_FORMAT, data[:HEADER_SIZE] + ) + + if magic != MAGIC: + raise DirtyProtocolError( + f"Invalid magic: {magic!r}, expected {MAGIC!r}", + raw_data=data[:20] + ) + + if version != VERSION: + raise DirtyProtocolError( + f"Unsupported protocol version: {version}, expected {VERSION}", + raw_data=data[:20] + ) + + if msg_type not in MSG_TYPE_TO_STR: + raise DirtyProtocolError( + f"Unknown message type: 0x{msg_type:02x}", + raw_data=data[:20] + ) + + if length > MAX_MESSAGE_SIZE: + raise DirtyProtocolError( + f"Message too large: {length} bytes (max: {MAX_MESSAGE_SIZE})" + ) + + return msg_type, request_id, length + + @staticmethod + def encode_request(request_id: int, app_path: str, action: str, + args: tuple = None, kwargs: dict = None) -> bytes: + """ + Encode a request message. + + Args: + request_id: Unique request identifier (uint64) + app_path: Import path of the dirty app + action: Action to call on the app + args: Positional arguments + kwargs: Keyword arguments + + Returns: + bytes: Complete message (header + payload) + """ + payload_dict = { + "app_path": app_path, + "action": action, + "args": list(args) if args else [], + "kwargs": kwargs or {}, + } + payload = TLVEncoder.encode(payload_dict) + header = BinaryProtocol.encode_header(MSG_TYPE_REQUEST, request_id, + len(payload)) + return header + payload + + @staticmethod + def encode_response(request_id: int, result) -> bytes: + """ + Encode a success response message. + + Args: + request_id: Request identifier this responds to + result: Result value (must be TLV-serializable) + + Returns: + bytes: Complete message (header + payload) + """ + payload_dict = {"result": result} + payload = TLVEncoder.encode(payload_dict) + header = BinaryProtocol.encode_header(MSG_TYPE_RESPONSE, request_id, + len(payload)) + return header + payload + + @staticmethod + def encode_error(request_id: int, error) -> bytes: + """ + Encode an error response message. + + Args: + request_id: Request identifier this responds to + error: DirtyError instance, dict, or Exception + + Returns: + bytes: Complete message (header + payload) + """ + from .errors import DirtyError + + if isinstance(error, DirtyError): + error_dict = error.to_dict() + elif isinstance(error, dict): + error_dict = error + else: + error_dict = { + "error_type": type(error).__name__, + "message": str(error), + "details": {}, + } + + payload_dict = {"error": error_dict} + payload = TLVEncoder.encode(payload_dict) + header = BinaryProtocol.encode_header(MSG_TYPE_ERROR, request_id, + len(payload)) + return header + payload + + @staticmethod + def encode_chunk(request_id: int, data) -> bytes: + """ + Encode a chunk message for streaming responses. + + Args: + request_id: Request identifier this chunk belongs to + data: Chunk data (must be TLV-serializable) + + Returns: + bytes: Complete message (header + payload) + """ + payload_dict = {"data": data} + payload = TLVEncoder.encode(payload_dict) + header = BinaryProtocol.encode_header(MSG_TYPE_CHUNK, request_id, + len(payload)) + return header + payload + + @staticmethod + def encode_end(request_id: int) -> bytes: + """ + Encode an end-of-stream message. + + Args: + request_id: Request identifier this ends + + Returns: + bytes: Complete message (header + empty payload) + """ + # End message has empty payload + header = BinaryProtocol.encode_header(MSG_TYPE_END, request_id, 0) + return header + + @staticmethod + def decode_message(data: bytes) -> tuple: + """ + Decode a complete message (header + payload). + + Args: + data: Complete message bytes + + Returns: + tuple: (msg_type_str, request_id, payload_dict) + msg_type_str is the string name (e.g., "request") + payload_dict is the decoded TLV payload as a dict + + Raises: + DirtyProtocolError: If message is malformed + """ + msg_type, request_id, length = BinaryProtocol.decode_header(data) + + if len(data) < HEADER_SIZE + length: + raise DirtyProtocolError( + f"Incomplete message: expected {HEADER_SIZE + length} bytes, " + f"got {len(data)}", + raw_data=data[:50] + ) + + if length == 0: + # End message has empty payload + payload_dict = {} + else: + payload_data = data[HEADER_SIZE:HEADER_SIZE + length] + try: + payload_dict = TLVEncoder.decode_full(payload_data) + except DirtyProtocolError: + raise + except Exception as e: raise DirtyProtocolError( - f"Message too large: {len(payload)} bytes " - f"(max: {DirtyProtocol.MAX_MESSAGE_SIZE})" + f"Failed to decode TLV payload: {e}", + raw_data=payload_data[:50] ) - length = struct.pack(DirtyProtocol.HEADER_FORMAT, len(payload)) - return length + payload - except (TypeError, ValueError) as e: - raise DirtyProtocolError(f"Failed to encode message: {e}") - @staticmethod - def decode(data: bytes) -> dict: - """ - Decode bytes (without length prefix) to message dict. + # Convert to dict format similar to old JSON protocol + msg_type_str = MSG_TYPE_TO_STR[msg_type] - Args: - data: JSON bytes to decode - - Returns: - dict: Decoded message - - Raises: - DirtyProtocolError: If decoding fails - """ - try: - return json.loads(data.decode("utf-8")) - except (json.JSONDecodeError, UnicodeDecodeError) as e: - raise DirtyProtocolError(f"Failed to decode message: {e}", - raw_data=data) + return msg_type_str, request_id, payload_dict # ------------------------------------------------------------------------- # Async API (primary - for DirtyArbiter and DirtyWorker) @@ -94,53 +309,62 @@ class DirtyProtocol: @staticmethod async def read_message_async(reader: asyncio.StreamReader) -> dict: """ - Read a complete message from async stream. + Read a complete binary message from async stream. Args: reader: asyncio StreamReader Returns: - dict: Decoded message + dict: Message dict with 'type', 'id', and payload fields Raises: DirtyProtocolError: If read fails or message is malformed asyncio.IncompleteReadError: If connection closed mid-read """ - # Read length header + # Read header try: - header = await reader.readexactly(DirtyProtocol.HEADER_SIZE) + header = await reader.readexactly(HEADER_SIZE) except asyncio.IncompleteReadError as e: if len(e.partial) == 0: # Clean close - no data was read raise raise DirtyProtocolError( f"Incomplete header: got {len(e.partial)} bytes, " - f"expected {DirtyProtocol.HEADER_SIZE}", + f"expected {HEADER_SIZE}", raw_data=e.partial ) - length = struct.unpack(DirtyProtocol.HEADER_FORMAT, header)[0] - - if length > DirtyProtocol.MAX_MESSAGE_SIZE: - raise DirtyProtocolError( - f"Message too large: {length} bytes " - f"(max: {DirtyProtocol.MAX_MESSAGE_SIZE})" - ) - - if length == 0: - raise DirtyProtocolError("Empty message received") + msg_type, request_id, length = BinaryProtocol.decode_header(header) # Read payload - try: - payload = await reader.readexactly(length) - except asyncio.IncompleteReadError as e: - raise DirtyProtocolError( - f"Incomplete message: got {len(e.partial)} bytes, " - f"expected {length}", - raw_data=e.partial - ) + if length > 0: + try: + payload_data = await reader.readexactly(length) + except asyncio.IncompleteReadError as e: + raise DirtyProtocolError( + f"Incomplete payload: got {len(e.partial)} bytes, " + f"expected {length}", + raw_data=e.partial + ) - return DirtyProtocol.decode(payload) + try: + payload_dict = TLVEncoder.decode_full(payload_data) + except DirtyProtocolError: + raise + except Exception as e: + raise DirtyProtocolError( + f"Failed to decode TLV payload: {e}", + raw_data=payload_data[:50] + ) + else: + payload_dict = {} + + # Build response dict + msg_type_str = MSG_TYPE_TO_STR[msg_type] + result = {"type": msg_type_str, "id": request_id} + result.update(payload_dict) + + return result @staticmethod async def write_message_async(writer: asyncio.StreamWriter, @@ -148,15 +372,17 @@ class DirtyProtocol: """ Write a message to async stream. + Accepts dict format for backwards compatibility. + Args: writer: asyncio StreamWriter - message: Dictionary to send + message: Message dict with 'type', 'id', and payload fields Raises: DirtyProtocolError: If encoding fails ConnectionError: If write fails """ - data = DirtyProtocol.encode(message) + data = BinaryProtocol._encode_from_dict(message) writer.write(data) await writer.drain() @@ -201,27 +427,36 @@ class DirtyProtocol: sock: Socket to read from Returns: - dict: Decoded message + dict: Message dict with 'type', 'id', and payload fields Raises: DirtyProtocolError: If read fails or message is malformed """ - # Read length header - header = DirtyProtocol._recv_exactly(sock, DirtyProtocol.HEADER_SIZE) - length = struct.unpack(DirtyProtocol.HEADER_FORMAT, header)[0] - - if length > DirtyProtocol.MAX_MESSAGE_SIZE: - raise DirtyProtocolError( - f"Message too large: {length} bytes " - f"(max: {DirtyProtocol.MAX_MESSAGE_SIZE})" - ) - - if length == 0: - raise DirtyProtocolError("Empty message received") + # Read header + header = BinaryProtocol._recv_exactly(sock, HEADER_SIZE) + msg_type, request_id, length = BinaryProtocol.decode_header(header) # Read payload - payload = DirtyProtocol._recv_exactly(sock, length) - return DirtyProtocol.decode(payload) + if length > 0: + payload_data = BinaryProtocol._recv_exactly(sock, length) + try: + payload_dict = TLVEncoder.decode_full(payload_data) + except DirtyProtocolError: + raise + except Exception as e: + raise DirtyProtocolError( + f"Failed to decode TLV payload: {e}", + raw_data=payload_data[:50] + ) + else: + payload_dict = {} + + # Build response dict + msg_type_str = MSG_TYPE_TO_STR[msg_type] + result = {"type": msg_type_str, "id": request_id} + result.update(payload_dict) + + return result @staticmethod def write_message(sock: socket.socket, message: dict) -> None: @@ -230,31 +465,92 @@ class DirtyProtocol: Args: sock: Socket to write to - message: Dictionary to send + message: Message dict with 'type', 'id', and payload fields Raises: DirtyProtocolError: If encoding fails OSError: If write fails """ - data = DirtyProtocol.encode(message) + data = BinaryProtocol._encode_from_dict(message) sock.sendall(data) + @staticmethod + def _encode_from_dict(message: dict) -> bytes: + """ + Encode a message dict to binary format. -# Message builder helpers -def make_request(request_id: str, app_path: str, action: str, + Supports the old dict-based API for backwards compatibility. + + Args: + message: Message dict with 'type', 'id', and payload fields + + Returns: + bytes: Complete encoded message + """ + msg_type_str = message.get("type") + request_id = message.get("id", 0) + + # Handle string or int request IDs + if isinstance(request_id, str): + # For backwards compat with UUID strings, hash to int + request_id = hash(request_id) & 0xFFFFFFFFFFFFFFFF + + msg_type = MSG_TYPE_FROM_STR.get(msg_type_str) + if msg_type is None: + raise DirtyProtocolError(f"Unknown message type: {msg_type_str}") + + if msg_type == MSG_TYPE_REQUEST: + return BinaryProtocol.encode_request( + request_id, + message.get("app_path", ""), + message.get("action", ""), + message.get("args"), + message.get("kwargs") + ) + elif msg_type == MSG_TYPE_RESPONSE: + return BinaryProtocol.encode_response( + request_id, + message.get("result") + ) + elif msg_type == MSG_TYPE_ERROR: + return BinaryProtocol.encode_error( + request_id, + message.get("error", {}) + ) + elif msg_type == MSG_TYPE_CHUNK: + return BinaryProtocol.encode_chunk( + request_id, + message.get("data") + ) + elif msg_type == MSG_TYPE_END: + return BinaryProtocol.encode_end(request_id) + else: + raise DirtyProtocolError(f"Unhandled message type: {msg_type}") + + +# ============================================================================= +# Backwards Compatibility Aliases +# ============================================================================= + +# Alias BinaryProtocol as DirtyProtocol for drop-in replacement +DirtyProtocol = BinaryProtocol + + +# Message builder helpers (backwards compatible with old API) +def make_request(request_id, app_path: str, action: str, args: tuple = None, kwargs: dict = None) -> dict: """ - Build a request message. + Build a request message dict. Args: - request_id: Unique request identifier + request_id: Unique request identifier (int or str) app_path: Import path of the dirty app (e.g., 'myapp.ml:MLApp') action: Action to call on the app args: Positional arguments kwargs: Keyword arguments Returns: - dict: Request message + dict: Request message dict """ return { "type": DirtyProtocol.MSG_TYPE_REQUEST, @@ -266,16 +562,16 @@ def make_request(request_id: str, app_path: str, action: str, } -def make_response(request_id: str, result) -> dict: +def make_response(request_id, result) -> dict: """ - Build a success response message. + Build a success response message dict. Args: request_id: Request identifier this responds to - result: Result value (must be JSON-serializable) + result: Result value Returns: - dict: Response message + dict: Response message dict """ return { "type": DirtyProtocol.MSG_TYPE_RESPONSE, @@ -284,16 +580,16 @@ def make_response(request_id: str, result) -> dict: } -def make_error_response(request_id: str, error) -> dict: +def make_error_response(request_id, error) -> dict: """ - Build an error response message. + Build an error response message dict. Args: request_id: Request identifier this responds to error: DirtyError instance or dict with error info Returns: - dict: Error response message + dict: Error response message dict """ from .errors import DirtyError if isinstance(error, DirtyError): @@ -314,16 +610,16 @@ def make_error_response(request_id: str, error) -> dict: } -def make_chunk_message(request_id: str, data) -> dict: +def make_chunk_message(request_id, data) -> dict: """ - Build a chunk message for streaming responses. + Build a chunk message dict for streaming responses. Args: request_id: Request identifier this chunk belongs to - data: Chunk data (must be JSON-serializable) + data: Chunk data Returns: - dict: Chunk message + dict: Chunk message dict """ return { "type": DirtyProtocol.MSG_TYPE_CHUNK, @@ -332,15 +628,15 @@ def make_chunk_message(request_id: str, data) -> dict: } -def make_end_message(request_id: str) -> dict: +def make_end_message(request_id) -> dict: """ - Build an end-of-stream message. + Build an end-of-stream message dict. Args: request_id: Request identifier this ends Returns: - dict: End message + dict: End message dict """ return { "type": DirtyProtocol.MSG_TYPE_END, diff --git a/gunicorn/dirty/tlv.py b/gunicorn/dirty/tlv.py new file mode 100644 index 00000000..ec18cd76 --- /dev/null +++ b/gunicorn/dirty/tlv.py @@ -0,0 +1,303 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +TLV (Type-Length-Value) Binary Encoder/Decoder + +Provides efficient binary serialization for dirty worker protocol messages. +Inspired by OpenBSD msgctl/msgsnd message format. + +Type Codes: + 0x00: None (no value bytes) + 0x01: bool (1 byte: 0x00 or 0x01) + 0x05: int64 (8 bytes big-endian signed) + 0x06: float64 (8 bytes IEEE 754) + 0x10: bytes (4-byte length + raw bytes) + 0x11: string (4-byte length + UTF-8 encoded) + 0x20: list (4-byte count + encoded elements) + 0x21: dict (4-byte count + encoded key-value pairs) +""" + +import struct + +from .errors import DirtyProtocolError + + +# Type codes +TYPE_NONE = 0x00 +TYPE_BOOL = 0x01 +TYPE_INT64 = 0x05 +TYPE_FLOAT64 = 0x06 +TYPE_BYTES = 0x10 +TYPE_STRING = 0x11 +TYPE_LIST = 0x20 +TYPE_DICT = 0x21 + +# Maximum sizes for safety +MAX_STRING_SIZE = 64 * 1024 * 1024 # 64 MB +MAX_BYTES_SIZE = 64 * 1024 * 1024 # 64 MB +MAX_LIST_SIZE = 1024 * 1024 # 1 million items +MAX_DICT_SIZE = 1024 * 1024 # 1 million items + + +class TLVEncoder: + """ + TLV binary encoder/decoder. + + Encodes Python values to binary TLV format and decodes back. + Supports: None, bool, int, float, bytes, str, list, dict. + """ + + @staticmethod + def encode(value) -> bytes: # pylint: disable=too-many-return-statements + """ + Encode a Python value to TLV binary format. + + Args: + value: Python value to encode (None, bool, int, float, + bytes, str, list, or dict) + + Returns: + bytes: TLV-encoded binary data + + Raises: + DirtyProtocolError: If value type is not supported + """ + if value is None: + return bytes([TYPE_NONE]) + + if isinstance(value, bool): + # bool must come before int since bool is a subclass of int + return bytes([TYPE_BOOL, 0x01 if value else 0x00]) + + if isinstance(value, int): + return bytes([TYPE_INT64]) + struct.pack(">q", value) + + if isinstance(value, float): + return bytes([TYPE_FLOAT64]) + struct.pack(">d", value) + + if isinstance(value, bytes): + if len(value) > MAX_BYTES_SIZE: + raise DirtyProtocolError( + f"Bytes too large: {len(value)} bytes " + f"(max: {MAX_BYTES_SIZE})" + ) + return bytes([TYPE_BYTES]) + struct.pack(">I", len(value)) + value + + if isinstance(value, str): + encoded = value.encode("utf-8") + if len(encoded) > MAX_STRING_SIZE: + raise DirtyProtocolError( + f"String too large: {len(encoded)} bytes " + f"(max: {MAX_STRING_SIZE})" + ) + return bytes([TYPE_STRING]) + struct.pack(">I", len(encoded)) + encoded + + if isinstance(value, (list, tuple)): + if len(value) > MAX_LIST_SIZE: + raise DirtyProtocolError( + f"List too large: {len(value)} items " + f"(max: {MAX_LIST_SIZE})" + ) + parts = [bytes([TYPE_LIST]), struct.pack(">I", len(value))] + for item in value: + parts.append(TLVEncoder.encode(item)) + return b"".join(parts) + + if isinstance(value, dict): + if len(value) > MAX_DICT_SIZE: + raise DirtyProtocolError( + f"Dict too large: {len(value)} items " + f"(max: {MAX_DICT_SIZE})" + ) + parts = [bytes([TYPE_DICT]), struct.pack(">I", len(value))] + for k, v in value.items(): + # Convert keys to strings (like JSON) + if not isinstance(k, str): + k = str(k) + parts.append(TLVEncoder.encode(k)) + parts.append(TLVEncoder.encode(v)) + return b"".join(parts) + + raise DirtyProtocolError( + f"Unsupported type for TLV encoding: {type(value).__name__}" + ) + + @staticmethod + def decode(data: bytes, offset: int = 0) -> tuple: # pylint: disable=too-many-return-statements + """ + Decode a TLV-encoded value from binary data. + + Args: + data: Binary data to decode + offset: Starting offset in the data + + Returns: + tuple: (decoded_value, new_offset) + + Raises: + DirtyProtocolError: If data is malformed or truncated + """ + if offset >= len(data): + raise DirtyProtocolError( + "Truncated TLV data: no type byte", + raw_data=data[offset:offset + 20] + ) + + type_code = data[offset] + offset += 1 + + if type_code == TYPE_NONE: + return None, offset + + if type_code == TYPE_BOOL: + if offset >= len(data): + raise DirtyProtocolError( + "Truncated TLV data: missing bool value", + raw_data=data[offset - 1:offset + 20] + ) + value = data[offset] != 0x00 + return value, offset + 1 + + if type_code == TYPE_INT64: + if offset + 8 > len(data): + raise DirtyProtocolError( + "Truncated TLV data: incomplete int64", + raw_data=data[offset - 1:offset + 20] + ) + value = struct.unpack(">q", data[offset:offset + 8])[0] + return value, offset + 8 + + if type_code == TYPE_FLOAT64: + if offset + 8 > len(data): + raise DirtyProtocolError( + "Truncated TLV data: incomplete float64", + raw_data=data[offset - 1:offset + 20] + ) + value = struct.unpack(">d", data[offset:offset + 8])[0] + return value, offset + 8 + + if type_code == TYPE_BYTES: + if offset + 4 > len(data): + raise DirtyProtocolError( + "Truncated TLV data: incomplete bytes length", + raw_data=data[offset - 1:offset + 20] + ) + length = struct.unpack(">I", data[offset:offset + 4])[0] + offset += 4 + + if length > MAX_BYTES_SIZE: + raise DirtyProtocolError( + f"Bytes too large: {length} bytes (max: {MAX_BYTES_SIZE})" + ) + + if offset + length > len(data): + raise DirtyProtocolError( + f"Truncated TLV data: expected {length} bytes, " + f"got {len(data) - offset}", + raw_data=data[offset - 5:offset + 20] + ) + value = data[offset:offset + length] + return value, offset + length + + if type_code == TYPE_STRING: + if offset + 4 > len(data): + raise DirtyProtocolError( + "Truncated TLV data: incomplete string length", + raw_data=data[offset - 1:offset + 20] + ) + length = struct.unpack(">I", data[offset:offset + 4])[0] + offset += 4 + + if length > MAX_STRING_SIZE: + raise DirtyProtocolError( + f"String too large: {length} bytes (max: {MAX_STRING_SIZE})" + ) + + if offset + length > len(data): + raise DirtyProtocolError( + f"Truncated TLV data: expected {length} bytes for string, " + f"got {len(data) - offset}", + raw_data=data[offset - 5:offset + 20] + ) + try: + value = data[offset:offset + length].decode("utf-8") + except UnicodeDecodeError as e: + raise DirtyProtocolError( + f"Invalid UTF-8 in string: {e}", + raw_data=data[offset:offset + min(length, 20)] + ) + return value, offset + length + + if type_code == TYPE_LIST: + if offset + 4 > len(data): + raise DirtyProtocolError( + "Truncated TLV data: incomplete list count", + raw_data=data[offset - 1:offset + 20] + ) + count = struct.unpack(">I", data[offset:offset + 4])[0] + offset += 4 + + if count > MAX_LIST_SIZE: + raise DirtyProtocolError( + f"List too large: {count} items (max: {MAX_LIST_SIZE})" + ) + + items = [] + for _ in range(count): + item, offset = TLVEncoder.decode(data, offset) + items.append(item) + return items, offset + + if type_code == TYPE_DICT: + if offset + 4 > len(data): + raise DirtyProtocolError( + "Truncated TLV data: incomplete dict count", + raw_data=data[offset - 1:offset + 20] + ) + count = struct.unpack(">I", data[offset:offset + 4])[0] + offset += 4 + + if count > MAX_DICT_SIZE: + raise DirtyProtocolError( + f"Dict too large: {count} items (max: {MAX_DICT_SIZE})" + ) + + result = {} + for _ in range(count): + key, offset = TLVEncoder.decode(data, offset) + if not isinstance(key, str): + raise DirtyProtocolError( + f"Dict key must be string, got {type(key).__name__}" + ) + value, offset = TLVEncoder.decode(data, offset) + result[key] = value + return result, offset + + raise DirtyProtocolError( + f"Unknown TLV type code: 0x{type_code:02x}", + raw_data=data[offset - 1:offset + 20] + ) + + @staticmethod + def decode_full(data: bytes): + """ + Decode a complete TLV-encoded value, ensuring all data is consumed. + + Args: + data: Binary data to decode + + Returns: + Decoded Python value + + Raises: + DirtyProtocolError: If data is malformed or has trailing bytes + """ + value, offset = TLVEncoder.decode(data, 0) + if offset != len(data): + raise DirtyProtocolError( + f"Trailing data after TLV: {len(data) - offset} bytes", + raw_data=data[offset:offset + 20] + ) + return value diff --git a/tests/dirty/test_arbiter_streaming.py b/tests/dirty/test_arbiter_streaming.py index ef15c33a..a722f2af 100644 --- a/tests/dirty/test_arbiter_streaming.py +++ b/tests/dirty/test_arbiter_streaming.py @@ -12,11 +12,13 @@ import pytest from gunicorn.dirty.protocol import ( DirtyProtocol, + BinaryProtocol, make_request, make_response, make_chunk_message, make_end_message, make_error_response, + HEADER_SIZE, ) from gunicorn.dirty.arbiter import DirtyArbiter from gunicorn.dirty.errors import DirtyError @@ -34,16 +36,22 @@ class MockStreamWriter: self._buffer += data async def drain(self): - while len(self._buffer) >= DirtyProtocol.HEADER_SIZE: - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - self._buffer[:DirtyProtocol.HEADER_SIZE] - )[0] - total_size = DirtyProtocol.HEADER_SIZE + length + # Decode the buffer to extract messages using binary protocol + while len(self._buffer) >= HEADER_SIZE: + # Decode header to get payload length + _, _, length = BinaryProtocol.decode_header( + self._buffer[:HEADER_SIZE] + ) + total_size = HEADER_SIZE + length if len(self._buffer) >= total_size: - msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size] + msg_data = self._buffer[:total_size] self._buffer = self._buffer[total_size:] - self.messages.append(DirtyProtocol.decode(msg_data)) + # decode_message returns (msg_type_str, request_id, payload_dict) + msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data) + # Reconstruct the dict format for backwards compatibility + result = {"type": msg_type_str, "id": request_id} + result.update(payload_dict) + self.messages.append(result) else: break @@ -63,7 +71,7 @@ class MockStreamReader: def __init__(self, messages): self._data = b'' for msg in messages: - self._data += DirtyProtocol.encode(msg) + self._data += BinaryProtocol._encode_from_dict(msg) self._pos = 0 async def readexactly(self, n): @@ -107,9 +115,9 @@ class TestArbiterStreamingForwarding: client_writer = MockStreamWriter() # Mock worker connection that returns chunks - chunk1 = make_chunk_message("req-123", "Hello") - chunk2 = make_chunk_message("req-123", " World") - end = make_end_message("req-123") + chunk1 = make_chunk_message(123, "Hello") + chunk2 = make_chunk_message(123, " World") + end = make_end_message(123) mock_reader = MockStreamReader([chunk1, chunk2, end]) @@ -118,7 +126,7 @@ class TestArbiterStreamingForwarding: arbiter._get_worker_connection = mock_get_connection - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await arbiter._execute_on_worker(1234, request, client_writer) # Should have forwarded all messages @@ -135,7 +143,7 @@ class TestArbiterStreamingForwarding: arbiter = create_arbiter() client_writer = MockStreamWriter() - response = make_response("req-123", {"result": 42}) + response = make_response(123, {"result": 42}) mock_reader = MockStreamReader([response]) async def mock_get_connection(pid): @@ -143,7 +151,7 @@ class TestArbiterStreamingForwarding: arbiter._get_worker_connection = mock_get_connection - request = make_request("req-123", "test:App", "compute") + request = make_request(123, "test:App", "compute") await arbiter._execute_on_worker(1234, request, client_writer) assert len(client_writer.messages) == 1 @@ -156,8 +164,8 @@ class TestArbiterStreamingForwarding: arbiter = create_arbiter() client_writer = MockStreamWriter() - chunk = make_chunk_message("req-123", "First") - error = make_error_response("req-123", DirtyError("Something broke")) + chunk = make_chunk_message(123, "First") + error = make_error_response(123, DirtyError("Something broke")) mock_reader = MockStreamReader([chunk, error]) @@ -166,7 +174,7 @@ class TestArbiterStreamingForwarding: arbiter._get_worker_connection = mock_get_connection - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await arbiter._execute_on_worker(1234, request, client_writer) assert len(client_writer.messages) == 2 @@ -190,7 +198,7 @@ class TestArbiterStreamingForwarding: arbiter._get_worker_connection = mock_get_connection - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await arbiter._execute_on_worker(1234, request, client_writer) assert len(client_writer.messages) == 1 @@ -208,7 +216,7 @@ class TestArbiterRouteRequestStreaming: arbiter.workers = {} # No workers client_writer = MockStreamWriter() - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await arbiter.route_request(request, client_writer) assert len(client_writer.messages) == 1 @@ -222,13 +230,13 @@ class TestArbiterRouteRequestStreaming: # Mock _execute_on_worker to complete immediately async def mock_execute(pid, request, client_writer): - response = make_response("req-123", "result") + response = make_response(123, "result") await DirtyProtocol.write_message_async(client_writer, response) arbiter._execute_on_worker = mock_execute client_writer = MockStreamWriter() - request = make_request("req-123", "test:App", "compute") + request = make_request(123, "test:App", "compute") # Worker queue should be created assert 1234 not in arbiter.worker_queues @@ -255,8 +263,8 @@ class TestArbiterStreamingManyChunks: # Generate 50 chunks + end messages = [] for i in range(50): - messages.append(make_chunk_message("req-123", f"chunk-{i}")) - messages.append(make_end_message("req-123")) + messages.append(make_chunk_message(123, f"chunk-{i}")) + messages.append(make_end_message(123)) mock_reader = MockStreamReader(messages) @@ -265,7 +273,7 @@ class TestArbiterStreamingManyChunks: arbiter._get_worker_connection = mock_get_connection - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await arbiter._execute_on_worker(1234, request, client_writer) assert len(client_writer.messages) == 51 @@ -283,7 +291,7 @@ class TestArbiterBackwardCompatibility: arbiter = create_arbiter() client_writer = MockStreamWriter() - response = make_response("req-123", [1, 2, 3, 4, 5]) + response = make_response(123, [1, 2, 3, 4, 5]) mock_reader = MockStreamReader([response]) async def mock_get_connection(pid): @@ -291,7 +299,7 @@ class TestArbiterBackwardCompatibility: arbiter._get_worker_connection = mock_get_connection - request = make_request("req-123", "test:App", "get_list") + request = make_request(123, "test:App", "get_list") await arbiter._execute_on_worker(1234, request, client_writer) assert len(client_writer.messages) == 1 @@ -304,7 +312,7 @@ class TestArbiterBackwardCompatibility: arbiter = create_arbiter() client_writer = MockStreamWriter() - error = make_error_response("req-123", DirtyError("Something failed")) + error = make_error_response(123, DirtyError("Something failed")) mock_reader = MockStreamReader([error]) async def mock_get_connection(pid): @@ -312,7 +320,7 @@ class TestArbiterBackwardCompatibility: arbiter._get_worker_connection = mock_get_connection - request = make_request("req-123", "test:App", "fail") + request = make_request(123, "test:App", "fail") await arbiter._execute_on_worker(1234, request, client_writer) assert len(client_writer.messages) == 1 diff --git a/tests/dirty/test_client_streaming.py b/tests/dirty/test_client_streaming.py index 7bc13525..eca76e98 100644 --- a/tests/dirty/test_client_streaming.py +++ b/tests/dirty/test_client_streaming.py @@ -11,10 +11,12 @@ from unittest import mock from gunicorn.dirty.protocol import ( DirtyProtocol, + BinaryProtocol, make_chunk_message, make_end_message, make_response, make_error_response, + HEADER_SIZE, ) from gunicorn.dirty.client import DirtyClient, DirtyStreamIterator from gunicorn.dirty.errors import DirtyError, DirtyConnectionError @@ -26,7 +28,7 @@ class MockSocket: def __init__(self, messages): self._data = b'' for msg in messages: - self._data += DirtyProtocol.encode(msg) + self._data += BinaryProtocol._encode_from_dict(msg) self._pos = 0 self._sent = [] self.closed = False @@ -69,10 +71,10 @@ class TestDirtyStreamIterator: def test_stream_iterator_yields_chunks(self): """Test that stream iterator yields chunks correctly.""" messages = [ - make_chunk_message("req-123", "Hello"), - make_chunk_message("req-123", " "), - make_chunk_message("req-123", "World"), - make_end_message("req-123"), + make_chunk_message(123, "Hello"), + make_chunk_message(123, " "), + make_chunk_message(123, "World"), + make_end_message(123), ] client = create_client_with_mock_socket(messages) @@ -83,9 +85,9 @@ class TestDirtyStreamIterator: def test_stream_iterator_yields_complex_chunks(self): """Test that stream iterator yields complex data types.""" messages = [ - make_chunk_message("req-123", {"token": "Hello", "score": 0.9}), - make_chunk_message("req-123", {"token": "World", "score": 0.8}), - make_end_message("req-123"), + make_chunk_message(123, {"token": "Hello", "score": 0.9}), + make_chunk_message(123, {"token": "World", "score": 0.8}), + make_end_message(123), ] client = create_client_with_mock_socket(messages) @@ -98,8 +100,8 @@ class TestDirtyStreamIterator: def test_stream_iterator_handles_error(self): """Test that stream iterator raises on error message.""" messages = [ - make_chunk_message("req-123", "First"), - make_error_response("req-123", DirtyError("Something broke")), + make_chunk_message(123, "First"), + make_error_response(123, DirtyError("Something broke")), ] client = create_client_with_mock_socket(messages) @@ -116,7 +118,7 @@ class TestDirtyStreamIterator: def test_stream_iterator_empty_stream(self): """Test that empty stream (just end) works.""" - messages = [make_end_message("req-123")] + messages = [make_end_message(123)] client = create_client_with_mock_socket(messages) chunks = list(client.stream("test:App", "generate")) @@ -125,8 +127,8 @@ class TestDirtyStreamIterator: def test_stream_iterator_stops_after_exhausted(self): """Test that iterator stays exhausted after StopIteration.""" messages = [ - make_chunk_message("req-123", "Only"), - make_end_message("req-123"), + make_chunk_message(123, "Only"), + make_end_message(123), ] client = create_client_with_mock_socket(messages) @@ -147,10 +149,10 @@ class TestDirtyStreamIterator: def test_stream_iterator_with_for_loop(self): """Test stream iterator works in for loop.""" messages = [ - make_chunk_message("req-123", "a"), - make_chunk_message("req-123", "b"), - make_chunk_message("req-123", "c"), - make_end_message("req-123"), + make_chunk_message(123, "a"), + make_chunk_message(123, "b"), + make_chunk_message(123, "c"), + make_end_message(123), ] client = create_client_with_mock_socket(messages) @@ -163,8 +165,8 @@ class TestDirtyStreamIterator: def test_stream_sends_request_on_first_iteration(self): """Test that request is sent on first next() call.""" messages = [ - make_chunk_message("req-123", "data"), - make_end_message("req-123"), + make_chunk_message(123, "data"), + make_end_message(123), ] client = create_client_with_mock_socket(messages) @@ -179,18 +181,15 @@ class TestDirtyStreamIterator: # Decode sent request sent_data = client._sock._sent[0] - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - sent_data[:DirtyProtocol.HEADER_SIZE] - )[0] - request = DirtyProtocol.decode( - sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length] + _, _, length = BinaryProtocol.decode_header(sent_data[:HEADER_SIZE]) + msg_type_str, request_id, payload = BinaryProtocol.decode_message( + sent_data[:HEADER_SIZE + length] ) - assert request["type"] == "request" - assert request["app_path"] == "test:App" - assert request["action"] == "generate" - assert request["args"] == ["prompt_arg"] + assert msg_type_str == "request" + assert payload["app_path"] == "test:App" + assert payload["action"] == "generate" + assert payload["args"] == ["prompt_arg"] class TestDirtyStreamIteratorEdgeCases: @@ -200,8 +199,8 @@ class TestDirtyStreamIteratorEdgeCases: """Test streaming with many chunks.""" messages = [] for i in range(100): - messages.append(make_chunk_message("req-123", f"chunk-{i}")) - messages.append(make_end_message("req-123")) + messages.append(make_chunk_message(123, f"chunk-{i}")) + messages.append(make_end_message(123)) client = create_client_with_mock_socket(messages) @@ -214,8 +213,8 @@ class TestDirtyStreamIteratorEdgeCases: def test_stream_with_kwargs(self): """Test streaming with keyword arguments.""" messages = [ - make_chunk_message("req-123", "data"), - make_end_message("req-123"), + make_chunk_message(123, "data"), + make_end_message(123), ] client = create_client_with_mock_socket(messages) @@ -224,13 +223,10 @@ class TestDirtyStreamIteratorEdgeCases: # Check the sent request includes kwargs sent_data = client._sock._sent[0] - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - sent_data[:DirtyProtocol.HEADER_SIZE] - )[0] - request = DirtyProtocol.decode( - sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length] + _, _, length = BinaryProtocol.decode_header(sent_data[:HEADER_SIZE]) + msg_type_str, request_id, payload = BinaryProtocol.decode_message( + sent_data[:HEADER_SIZE + length] ) - assert request["args"] == ["arg1"] - assert request["kwargs"] == {"key": "value"} + assert payload["args"] == ["arg1"] + assert payload["kwargs"] == {"key": "value"} diff --git a/tests/dirty/test_client_streaming_async.py b/tests/dirty/test_client_streaming_async.py index 651c73d1..b38eff6c 100644 --- a/tests/dirty/test_client_streaming_async.py +++ b/tests/dirty/test_client_streaming_async.py @@ -10,9 +10,11 @@ import pytest from gunicorn.dirty.protocol import ( DirtyProtocol, + BinaryProtocol, make_chunk_message, make_end_message, make_error_response, + HEADER_SIZE, ) from gunicorn.dirty.client import DirtyClient, DirtyAsyncStreamIterator from gunicorn.dirty.errors import DirtyError, DirtyTimeoutError @@ -24,7 +26,7 @@ class MockAsyncReader: def __init__(self, messages): self._data = b'' for msg in messages: - self._data += DirtyProtocol.encode(msg) + self._data += BinaryProtocol._encode_from_dict(msg) self._pos = 0 async def readexactly(self, n): @@ -76,10 +78,10 @@ class TestDirtyAsyncStreamIterator: async def test_async_stream_yields_chunks(self): """Test that async stream iterator yields chunks correctly.""" messages = [ - make_chunk_message("req-123", "Hello"), - make_chunk_message("req-123", " "), - make_chunk_message("req-123", "World"), - make_end_message("req-123"), + make_chunk_message(123, "Hello"), + make_chunk_message(123, " "), + make_chunk_message(123, "World"), + make_end_message(123), ] client = create_async_client_with_mocks(messages) @@ -93,9 +95,9 @@ class TestDirtyAsyncStreamIterator: async def test_async_stream_yields_complex_chunks(self): """Test that async stream iterator yields complex data types.""" messages = [ - make_chunk_message("req-123", {"token": "Hello", "score": 0.9}), - make_chunk_message("req-123", {"token": "World", "score": 0.8}), - make_end_message("req-123"), + make_chunk_message(123, {"token": "Hello", "score": 0.9}), + make_chunk_message(123, {"token": "World", "score": 0.8}), + make_end_message(123), ] client = create_async_client_with_mocks(messages) @@ -111,8 +113,8 @@ class TestDirtyAsyncStreamIterator: async def test_async_stream_handles_error(self): """Test that async stream iterator raises on error message.""" messages = [ - make_chunk_message("req-123", "First"), - make_error_response("req-123", DirtyError("Something broke")), + make_chunk_message(123, "First"), + make_error_response(123, DirtyError("Something broke")), ] client = create_async_client_with_mocks(messages) @@ -130,7 +132,7 @@ class TestDirtyAsyncStreamIterator: @pytest.mark.asyncio async def test_async_stream_empty_stream(self): """Test that empty stream (just end) works.""" - messages = [make_end_message("req-123")] + messages = [make_end_message(123)] client = create_async_client_with_mocks(messages) chunks = [] @@ -143,8 +145,8 @@ class TestDirtyAsyncStreamIterator: async def test_async_stream_stops_after_exhausted(self): """Test that async iterator stays exhausted after StopAsyncIteration.""" messages = [ - make_chunk_message("req-123", "Only"), - make_end_message("req-123"), + make_chunk_message(123, "Only"), + make_end_message(123), ] client = create_async_client_with_mocks(messages) @@ -166,8 +168,8 @@ class TestDirtyAsyncStreamIterator: async def test_async_stream_sends_request_on_first_iteration(self): """Test that request is sent on first async iteration.""" messages = [ - make_chunk_message("req-123", "data"), - make_end_message("req-123"), + make_chunk_message(123, "data"), + make_end_message(123), ] client = create_async_client_with_mocks(messages) @@ -182,18 +184,15 @@ class TestDirtyAsyncStreamIterator: # Decode sent request sent_data = client._writer._sent[0] - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - sent_data[:DirtyProtocol.HEADER_SIZE] - )[0] - request = DirtyProtocol.decode( - sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length] + _, _, length = BinaryProtocol.decode_header(sent_data[:HEADER_SIZE]) + msg_type_str, request_id, payload = BinaryProtocol.decode_message( + sent_data[:HEADER_SIZE + length] ) - assert request["type"] == "request" - assert request["app_path"] == "test:App" - assert request["action"] == "generate" - assert request["args"] == ["prompt_arg"] + assert msg_type_str == "request" + assert payload["app_path"] == "test:App" + assert payload["action"] == "generate" + assert payload["args"] == ["prompt_arg"] class TestDirtyAsyncStreamIteratorEdgeCases: @@ -204,8 +203,8 @@ class TestDirtyAsyncStreamIteratorEdgeCases: """Test async streaming with many chunks.""" messages = [] for i in range(100): - messages.append(make_chunk_message("req-123", f"chunk-{i}")) - messages.append(make_end_message("req-123")) + messages.append(make_chunk_message(123, f"chunk-{i}")) + messages.append(make_end_message(123)) client = create_async_client_with_mocks(messages) @@ -221,8 +220,8 @@ class TestDirtyAsyncStreamIteratorEdgeCases: async def test_async_stream_with_kwargs(self): """Test async streaming with keyword arguments.""" messages = [ - make_chunk_message("req-123", "data"), - make_end_message("req-123"), + make_chunk_message(123, "data"), + make_end_message(123), ] client = create_async_client_with_mocks(messages) @@ -233,16 +232,13 @@ class TestDirtyAsyncStreamIteratorEdgeCases: # Check the sent request includes kwargs sent_data = client._writer._sent[0] - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - sent_data[:DirtyProtocol.HEADER_SIZE] - )[0] - request = DirtyProtocol.decode( - sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length] + _, _, length = BinaryProtocol.decode_header(sent_data[:HEADER_SIZE]) + msg_type_str, request_id, payload = BinaryProtocol.decode_message( + sent_data[:HEADER_SIZE + length] ) - assert request["args"] == ["arg1"] - assert request["kwargs"] == {"key": "value"} + assert payload["args"] == ["arg1"] + assert payload["kwargs"] == {"key": "value"} class TestDirtyAsyncStreamTimeout: diff --git a/tests/dirty/test_multi_app_routing.py b/tests/dirty/test_multi_app_routing.py index c113bab1..4e01b711 100644 --- a/tests/dirty/test_multi_app_routing.py +++ b/tests/dirty/test_multi_app_routing.py @@ -19,7 +19,12 @@ from concurrent.futures import ThreadPoolExecutor from gunicorn.config import Config from gunicorn.dirty.worker import DirtyWorker from gunicorn.dirty.arbiter import DirtyArbiter -from gunicorn.dirty.protocol import DirtyProtocol, make_request +from gunicorn.dirty.protocol import ( + DirtyProtocol, + BinaryProtocol, + make_request, + HEADER_SIZE, +) from gunicorn.dirty.errors import DirtyAppNotFoundError @@ -71,16 +76,22 @@ class MockStreamWriter: self._buffer += data async def drain(self): - while len(self._buffer) >= DirtyProtocol.HEADER_SIZE: - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - self._buffer[:DirtyProtocol.HEADER_SIZE] - )[0] - total_size = DirtyProtocol.HEADER_SIZE + length + # Decode the buffer to extract messages using binary protocol + while len(self._buffer) >= HEADER_SIZE: + # Decode header to get payload length + _, _, length = BinaryProtocol.decode_header( + self._buffer[:HEADER_SIZE] + ) + total_size = HEADER_SIZE + length if len(self._buffer) >= total_size: - msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size] + msg_data = self._buffer[:total_size] self._buffer = self._buffer[total_size:] - self.messages.append(DirtyProtocol.decode(msg_data)) + # decode_message returns (msg_type_str, request_id, payload_dict) + msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data) + # Reconstruct the dict format for backwards compatibility + result = {"type": msg_type_str, "id": request_id} + result.update(payload_dict) + self.messages.append(result) else: break diff --git a/tests/dirty/test_streaming_integration.py b/tests/dirty/test_streaming_integration.py index 06b9645f..b23fee38 100644 --- a/tests/dirty/test_streaming_integration.py +++ b/tests/dirty/test_streaming_integration.py @@ -18,11 +18,13 @@ from unittest import mock from gunicorn.config import Config from gunicorn.dirty.protocol import ( DirtyProtocol, + BinaryProtocol, make_request, make_chunk_message, make_end_message, make_response, make_error_response, + HEADER_SIZE, ) from gunicorn.dirty.worker import DirtyWorker from gunicorn.dirty.arbiter import DirtyArbiter @@ -67,16 +69,22 @@ class MockStreamWriter: self._buffer += data async def drain(self): - while len(self._buffer) >= DirtyProtocol.HEADER_SIZE: - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - self._buffer[:DirtyProtocol.HEADER_SIZE] - )[0] - total_size = DirtyProtocol.HEADER_SIZE + length + # Decode the buffer to extract messages using binary protocol + while len(self._buffer) >= HEADER_SIZE: + # Decode header to get payload length + _, _, length = BinaryProtocol.decode_header( + self._buffer[:HEADER_SIZE] + ) + total_size = HEADER_SIZE + length if len(self._buffer) >= total_size: - msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size] + msg_data = self._buffer[:total_size] self._buffer = self._buffer[total_size:] - self.messages.append(DirtyProtocol.decode(msg_data)) + # decode_message returns (msg_type_str, request_id, payload_dict) + msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data) + # Reconstruct the dict format for backwards compatibility + result = {"type": msg_type_str, "id": request_id} + result.update(payload_dict) + self.messages.append(result) else: break @@ -96,7 +104,7 @@ class MockStreamReader: def __init__(self, messages): self._data = b'' for msg in messages: - self._data += DirtyProtocol.encode(msg) + self._data += BinaryProtocol._encode_from_dict(msg) self._pos = 0 async def readexactly(self, n): @@ -115,10 +123,10 @@ class TestStreamingEndToEnd: """Test complete flow: sync generator -> worker -> arbiter -> client.""" # Simulate what a worker would produce for a sync generator worker_messages = [ - make_chunk_message("req-123", "Hello"), - make_chunk_message("req-123", " "), - make_chunk_message("req-123", "World"), - make_end_message("req-123"), + make_chunk_message(123, "Hello"), + make_chunk_message(123, " "), + make_chunk_message(123, "World"), + make_end_message(123), ] # Create an arbiter with mocked worker connection @@ -141,7 +149,7 @@ class TestStreamingEndToEnd: client_writer = MockStreamWriter() # Execute request through arbiter - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await arbiter._execute_on_worker(1234, request, client_writer) # Verify all messages were forwarded @@ -158,10 +166,10 @@ class TestStreamingEndToEnd: async def test_async_generator_end_to_end(self): """Test complete flow: async generator -> worker -> arbiter -> client.""" worker_messages = [ - make_chunk_message("req-456", "Async"), - make_chunk_message("req-456", " "), - make_chunk_message("req-456", "Stream"), - make_end_message("req-456"), + make_chunk_message(456, "Async"), + make_chunk_message(456, " "), + make_chunk_message(456, "Stream"), + make_end_message(456), ] cfg = Config() @@ -180,7 +188,7 @@ class TestStreamingEndToEnd: client_writer = MockStreamWriter() - request = make_request("req-456", "test:App", "async_generate") + request = make_request(456, "test:App", "async_generate") await arbiter._execute_on_worker(1234, request, client_writer) assert len(client_writer.messages) == 4 @@ -197,9 +205,9 @@ class TestStreamingErrorHandling: async def test_error_mid_stream(self): """Test that errors during streaming are properly forwarded.""" worker_messages = [ - make_chunk_message("req-789", "First"), - make_chunk_message("req-789", "Second"), - make_error_response("req-789", DirtyError("Stream failed")), + make_chunk_message(789, "First"), + make_chunk_message(789, "Second"), + make_error_response(789, DirtyError("Stream failed")), ] cfg = Config() @@ -218,7 +226,7 @@ class TestStreamingErrorHandling: client_writer = MockStreamWriter() - request = make_request("req-789", "test:App", "generate_with_error") + request = make_request(789, "test:App", "generate_with_error") await arbiter._execute_on_worker(1234, request, client_writer) # Should have 2 chunks + 1 error @@ -335,7 +343,7 @@ class TestStreamingWorkerIntegration: return sync_gen() with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await worker.handle_request(request, writer) # Should have 3 chunks + 1 end @@ -377,7 +385,7 @@ class TestStreamingWorkerIntegration: return async_gen() with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-456", "test:App", "async_generate") + request = make_request(456, "test:App", "async_generate") await worker.handle_request(request, writer) # Should have 2 chunks + 1 end diff --git a/tests/dirty/test_worker_streaming.py b/tests/dirty/test_worker_streaming.py index bb674590..6efc471d 100644 --- a/tests/dirty/test_worker_streaming.py +++ b/tests/dirty/test_worker_streaming.py @@ -12,9 +12,11 @@ import pytest from gunicorn.dirty.protocol import ( DirtyProtocol, + BinaryProtocol, make_request, make_chunk_message, make_end_message, + HEADER_SIZE, ) from gunicorn.dirty.worker import DirtyWorker @@ -30,17 +32,22 @@ class FakeStreamWriter: self._buffer += data async def drain(self): - # Decode the buffer to extract messages - while len(self._buffer) >= DirtyProtocol.HEADER_SIZE: - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - self._buffer[:DirtyProtocol.HEADER_SIZE] - )[0] - total_size = DirtyProtocol.HEADER_SIZE + length + # Decode the buffer to extract messages using binary protocol + while len(self._buffer) >= HEADER_SIZE: + # Decode header to get payload length + _, _, length = BinaryProtocol.decode_header( + self._buffer[:HEADER_SIZE] + ) + total_size = HEADER_SIZE + length if len(self._buffer) >= total_size: - msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size] + msg_data = self._buffer[:total_size] self._buffer = self._buffer[total_size:] - self.messages.append(DirtyProtocol.decode(msg_data)) + # decode_message returns (msg_type_str, request_id, payload_dict) + msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data) + # Reconstruct the dict format for backwards compatibility + result = {"type": msg_type_str, "id": request_id} + result.update(payload_dict) + self.messages.append(result) else: break @@ -101,7 +108,7 @@ class TestWorkerSyncGeneratorStreaming: return generate_tokens() with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await worker.handle_request(request, writer) # Should have 3 chunks + 1 end message @@ -109,7 +116,7 @@ class TestWorkerSyncGeneratorStreaming: # Check chunk messages assert writer.messages[0]["type"] == "chunk" - assert writer.messages[0]["id"] == "req-123" + assert writer.messages[0]["id"] == 123 assert writer.messages[0]["data"] == "Hello" assert writer.messages[1]["type"] == "chunk" @@ -120,7 +127,7 @@ class TestWorkerSyncGeneratorStreaming: # Check end message assert writer.messages[3]["type"] == "end" - assert writer.messages[3]["id"] == "req-123" + assert writer.messages[3]["id"] == 123 @pytest.mark.asyncio async def test_sync_generator_error_mid_stream(self): @@ -136,7 +143,7 @@ class TestWorkerSyncGeneratorStreaming: return generate_with_error() with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await worker.handle_request(request, writer) # Should have 1 chunk + 1 error message @@ -167,7 +174,7 @@ class TestWorkerAsyncGeneratorStreaming: return async_generate_tokens() with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await worker.handle_request(request, writer) # Should have 3 chunks + 1 end message @@ -175,7 +182,7 @@ class TestWorkerAsyncGeneratorStreaming: # Check chunk messages assert writer.messages[0]["type"] == "chunk" - assert writer.messages[0]["id"] == "req-123" + assert writer.messages[0]["id"] == 123 assert writer.messages[0]["data"] == "Hello" assert writer.messages[1]["type"] == "chunk" @@ -186,7 +193,7 @@ class TestWorkerAsyncGeneratorStreaming: # Check end message assert writer.messages[3]["type"] == "end" - assert writer.messages[3]["id"] == "req-123" + assert writer.messages[3]["id"] == 123 @pytest.mark.asyncio async def test_async_generator_error_mid_stream(self): @@ -202,7 +209,7 @@ class TestWorkerAsyncGeneratorStreaming: return async_generate_with_error() with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await worker.handle_request(request, writer) # Should have 1 chunk + 1 error message @@ -228,13 +235,13 @@ class TestWorkerNonStreamingBackwardCompat: return args[0] + args[1] with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "compute", args=(2, 3)) + request = make_request(123, "test:App", "compute", args=(2, 3)) await worker.handle_request(request, writer) # Should have 1 response message assert len(writer.messages) == 1 assert writer.messages[0]["type"] == "response" - assert writer.messages[0]["id"] == "req-123" + assert writer.messages[0]["id"] == 123 assert writer.messages[0]["result"] == 5 @pytest.mark.asyncio @@ -247,7 +254,7 @@ class TestWorkerNonStreamingBackwardCompat: return [1, 2, 3, 4, 5] with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "get_list") + request = make_request(123, "test:App", "get_list") await worker.handle_request(request, writer) # Should have 1 response message (not 5 chunks) @@ -265,7 +272,7 @@ class TestWorkerNonStreamingBackwardCompat: raise RuntimeError("Failed!") with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "fail") + request = make_request(123, "test:App", "fail") await worker.handle_request(request, writer) # Should have 1 error message @@ -283,7 +290,7 @@ class TestWorkerNonStreamingBackwardCompat: return None with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "void") + request = make_request(123, "test:App", "void") await worker.handle_request(request, writer) # Should have 1 response message @@ -309,7 +316,7 @@ class TestWorkerStreamingComplexData: return generate_tokens() with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await worker.handle_request(request, writer) assert len(writer.messages) == 3 # 2 chunks + 1 end @@ -332,7 +339,7 @@ class TestWorkerStreamingComplexData: return empty_generate() with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await worker.handle_request(request, writer) # Should have just 1 end message @@ -353,7 +360,7 @@ class TestWorkerStreamingComplexData: return generate_many() with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await worker.handle_request(request, writer) # Should have 100 chunks + 1 end message @@ -390,7 +397,7 @@ class TestWorkerStreamingHeartbeat: return generate_tokens() with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await worker.handle_request(request, writer) # Should have been notified at least once per chunk + initial @@ -407,7 +414,7 @@ class TestWorkerMessageTypeValidation: writer = FakeStreamWriter() # Send a message with unknown type - message = {"type": "unknown", "id": "req-123"} + message = {"type": "unknown", "id": 123} await worker.handle_request(message, writer) assert len(writer.messages) == 1 diff --git a/tests/docker/http2/certs/server.crt b/tests/docker/http2/certs/server.crt new file mode 100644 index 00000000..b4056d76 --- /dev/null +++ b/tests/docker/http2/certs/server.crt @@ -0,0 +1,21 @@ +-----BEGIN CERTIFICATE----- +MIIDfDCCAmSgAwIBAgIUDxTarKRHe0FIyczGmoYwm377ZpcwDQYJKoZIhvcNAQEL +BQAwOTESMBAGA1UEAwwJbG9jYWxob3N0MRYwFAYDVQQKDA1HdW5pY29ybiBUZXN0 +MQswCQYDVQQGEwJVUzAeFw0yNjAyMDUxMTE1MjJaFw0yNjAyMDYxMTE1MjJaMDkx +EjAQBgNVBAMMCWxvY2FsaG9zdDEWMBQGA1UECgwNR3VuaWNvcm4gVGVzdDELMAkG +A1UEBhMCVVMwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCRQTHakkqY +6l6dMqfs4oiY98+rjvZubwjp0PH7UBuxXCi/4Ao78o0JhKcs+jgAGAXyb8eRjEKt +z4rPoHZYE91D/eD0lWAz9r/LRoutDJd9IO0rfDtHlYXamciuxJJ8cckOrnuTXLtq +AWqjKR3U9RIDD3eumCKG4l7Py0L67zTomwMRPfeIdlWBfxGjWMqOewdTc/O/cuK2 +HL5JP2ixy+iTufs0jhljI9cbu49J606f+TQH9eXRTD716q+KsHPJX1X5dVd7V7Lr +FIp7wSUFdbiy56JfmrGmfJbZgFH67P0ZyiTpQBaVHt1YYRIcOUJZqM+0MAtrsySC +TNA/LsI8tsybAgMBAAGjfDB6MB0GA1UdDgQWBBRK2VkAeM0hL4j/45ckkKbGrb/Q +FjAfBgNVHSMEGDAWgBRK2VkAeM0hL4j/45ckkKbGrb/QFjAPBgNVHRMBAf8EBTAD +AQH/MCcGA1UdEQQgMB6CCWxvY2FsaG9zdIILZ3VuaWNvcm4taDKHBH8AAAEwDQYJ +KoZIhvcNAQELBQADggEBAAXwuw0KTQUC4UEFudQ1rceK6By9WCSJND7xJi+UQ50G +Zrp5tJ2YB4ZWY+APadfuJo+zUxYVZ3jhs0mxgVeiGdDW6yZdHkeX8MlXBTLHR+/a +A7DXn6wCw9NDeDtcY/bKg5iamvoGGTL6szPrqeuZPz4UdbsFlr0MdcjgSNOqnkjr +YS4ukgZ71aWSjfraRRPjFMzkfnQ1xm96A1ngMH4DvU/t62D7r8+SvxQ8M6ERL84Z +FBu4bTXDdYIjJ24ojmDDO2irTVW1FMGXQTPzMaTEbE1rvBYeEYhf10KiMynK9xfO +5j8LWmCkgek0CqBrf3zbDEwu8QxcaxITAIUkSXLOZbo= +-----END CERTIFICATE----- diff --git a/tests/docker/http2/certs/server.key b/tests/docker/http2/certs/server.key new file mode 100644 index 00000000..3d472c7c --- /dev/null +++ b/tests/docker/http2/certs/server.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQCRQTHakkqY6l6d +Mqfs4oiY98+rjvZubwjp0PH7UBuxXCi/4Ao78o0JhKcs+jgAGAXyb8eRjEKtz4rP +oHZYE91D/eD0lWAz9r/LRoutDJd9IO0rfDtHlYXamciuxJJ8cckOrnuTXLtqAWqj +KR3U9RIDD3eumCKG4l7Py0L67zTomwMRPfeIdlWBfxGjWMqOewdTc/O/cuK2HL5J +P2ixy+iTufs0jhljI9cbu49J606f+TQH9eXRTD716q+KsHPJX1X5dVd7V7LrFIp7 +wSUFdbiy56JfmrGmfJbZgFH67P0ZyiTpQBaVHt1YYRIcOUJZqM+0MAtrsySCTNA/ +LsI8tsybAgMBAAECggEANBhGOYZLI9G2sjlXOaG7bOU/wV9KKaw/7Z/HEaOW8wLD +CKHg+cQRai79yCdLi1kSVPNbB2vfBDhRqAp8NzWUn0x/8ChcsvZVriF0edFwyWtU +NErfddp+Absy2t9cTC6A9feFEYJqIug0JyVZciWc2qUi/ubIR0kLyQm00YuWFa/s +GJou8Nhg70rqW+3FB1H8kAEXqob+PFW4xbTwexw1+MbHxN7UKLTzS8uzYGLo2UpB +7bksumyD0o+lZtlx9HZ6CwrB6IPjgJ0HyaD8SrOc7/ozd7rR2LmvMmBCV1uC5VSO +jhr0PScLoNv60fjkVOiF9uqaPY2kNKymsOzpZ7/mwQKBgQDMcz+ve8WGGbE+bbM7 +2uinQ5smm8rWPnfbHJIHQUetrEQKljRovybmjiiXN08uxlX6VA/Vnp4fmL5fzsTD +xTeiCVPsR1huXIfMLGJ6crUgvlbiaB8XsxtVNBpfEEtBe27qjSIj3xtmwqM6+LD1 +FKLsYzgotHUH9JwyLA1RMKPBwQKBgQC14QWtI5YtZcTX46BqxlZ07iAAuy19Jywn +UtgmTawkJuEcseewIjxtJkMz+aSy7V3PsLII8tY48oSjAVx84w50zLJ2OlJnFT1S +zEmIOu9YDcGLZkYXJ2AwndRAIXpJVHwtFM9eDSMh+wVPBFeboYP1dO/VxmN6QV0W +GqDaQfItWwKBgEb31mp2n0j+UB0ofSfQxCOTfx62w4D87CPd1f64tUXe3zuBii21 +9K3hOMvMwiqtZBjh5yEyzxaOsb6WCo0eP0J61GvXFCYy7lx8J67zdFYqXAR5OhnC +7UD1NhY7lLPlQcofNXOYNW3FMF3/B4X7JNbDVjIi+eDKExIDYpgFN0LBAoGADGCf +7kR5t+UxHDAVfq64u4RpESOr2NSNoK92nkSy7lLnBvjkd4wc6KCt+h+HIdYdiEDS +HOHJyl5WwHEbRjR9i11S19DoQrOjVLsqVecM2sU04rO3GWRIm4ZiJ2sf01W4jajY +4+Go/msC1XnKLIE1ZcLrf3Tc2DkSiKqPP8s1G/kCgYA8sCPAXedwhULhOBM45x4J +vkwT1Icm5RHOwOr8t34IFozTLokba6pjhYua3nE+V3FglRct7NpX+Op4gUgHa80g +5zoHboq5/pTUTclx41jndC1YGa3NLvthDWTWmyo/Qj7F/R7jGJf8E3KUDe0tFoSp +JlfEuUHtKpFJReBnmWTFiQ== +-----END PRIVATE KEY----- diff --git a/tests/test_dirty_arbiter.py b/tests/test_dirty_arbiter.py index 40abb504..05f35cb0 100644 --- a/tests/test_dirty_arbiter.py +++ b/tests/test_dirty_arbiter.py @@ -14,7 +14,12 @@ import pytest from gunicorn.config import Config from gunicorn.dirty.arbiter import DirtyArbiter from gunicorn.dirty.errors import DirtyError -from gunicorn.dirty.protocol import DirtyProtocol, make_request +from gunicorn.dirty.protocol import ( + DirtyProtocol, + BinaryProtocol, + make_request, + HEADER_SIZE, +) class MockStreamWriter: @@ -29,16 +34,22 @@ class MockStreamWriter: self._buffer += data async def drain(self): - while len(self._buffer) >= DirtyProtocol.HEADER_SIZE: - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - self._buffer[:DirtyProtocol.HEADER_SIZE] - )[0] - total_size = DirtyProtocol.HEADER_SIZE + length + # Decode the buffer to extract messages using binary protocol + while len(self._buffer) >= HEADER_SIZE: + # Decode header to get payload length + _, _, length = BinaryProtocol.decode_header( + self._buffer[:HEADER_SIZE] + ) + total_size = HEADER_SIZE + length if len(self._buffer) >= total_size: - msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size] + msg_data = self._buffer[:total_size] self._buffer = self._buffer[total_size:] - self.messages.append(DirtyProtocol.decode(msg_data)) + # decode_message returns (msg_type_str, request_id, payload_dict) + msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data) + # Reconstruct the dict format for backwards compatibility + result = {"type": msg_type_str, "id": request_id} + result.update(payload_dict) + self.messages.append(result) else: break diff --git a/tests/test_dirty_integration.py b/tests/test_dirty_integration.py index f24c6894..a841cf2c 100644 --- a/tests/test_dirty_integration.py +++ b/tests/test_dirty_integration.py @@ -11,7 +11,7 @@ import pytest from gunicorn.arbiter import Arbiter from gunicorn.config import Config from gunicorn.app.base import BaseApplication -from gunicorn.dirty.protocol import DirtyProtocol +from gunicorn.dirty.protocol import DirtyProtocol, BinaryProtocol, HEADER_SIZE class MockStreamWriter: @@ -26,16 +26,22 @@ class MockStreamWriter: self._buffer += data async def drain(self): - while len(self._buffer) >= DirtyProtocol.HEADER_SIZE: - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - self._buffer[:DirtyProtocol.HEADER_SIZE] - )[0] - total_size = DirtyProtocol.HEADER_SIZE + length + # Decode the buffer to extract messages using binary protocol + while len(self._buffer) >= HEADER_SIZE: + # Decode header to get payload length + _, _, length = BinaryProtocol.decode_header( + self._buffer[:HEADER_SIZE] + ) + total_size = HEADER_SIZE + length if len(self._buffer) >= total_size: - msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size] + msg_data = self._buffer[:total_size] self._buffer = self._buffer[total_size:] - self.messages.append(DirtyProtocol.decode(msg_data)) + # decode_message returns (msg_type_str, request_id, payload_dict) + msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data) + # Reconstruct the dict format for backwards compatibility + result = {"type": msg_type_str, "id": request_id} + result.update(payload_dict) + self.messages.append(result) else: break diff --git a/tests/test_dirty_protocol.py b/tests/test_dirty_protocol.py index dbabc51e..48fe3333 100644 --- a/tests/test_dirty_protocol.py +++ b/tests/test_dirty_protocol.py @@ -2,7 +2,7 @@ # This file is part of gunicorn released under the MIT license. # See the NOTICE for more information. -"""Tests for dirty arbiter protocol module.""" +"""Tests for dirty worker binary protocol module.""" import asyncio import os @@ -11,12 +11,23 @@ import struct import pytest from gunicorn.dirty.protocol import ( + BinaryProtocol, DirtyProtocol, make_request, make_response, make_error_response, make_chunk_message, make_end_message, + MAGIC, + VERSION, + HEADER_SIZE, + HEADER_FORMAT, + MSG_TYPE_REQUEST, + MSG_TYPE_RESPONSE, + MSG_TYPE_ERROR, + MSG_TYPE_CHUNK, + MSG_TYPE_END, + MAX_MESSAGE_SIZE, ) from gunicorn.dirty.errors import ( DirtyError, @@ -26,118 +37,194 @@ from gunicorn.dirty.errors import ( ) -class TestDirtyProtocolEncodeDecode: - """Tests for encode/decode functionality.""" +class TestBinaryProtocolHeader: + """Tests for header encoding/decoding.""" - def test_encode_decode_roundtrip(self): - """Test basic encode/decode roundtrip.""" - message = {"type": "request", "id": "123", "data": "hello"} - encoded = DirtyProtocol.encode(message) + def test_header_size(self): + """Test header size is 16 bytes.""" + assert HEADER_SIZE == 16 - # Check header format - assert len(encoded) > DirtyProtocol.HEADER_SIZE - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - encoded[:DirtyProtocol.HEADER_SIZE] - )[0] - assert length == len(encoded) - DirtyProtocol.HEADER_SIZE + def test_encode_header(self): + """Test header encoding.""" + header = BinaryProtocol.encode_header(MSG_TYPE_REQUEST, 12345, 100) + assert len(header) == HEADER_SIZE + assert header[:2] == MAGIC + assert header[2] == VERSION + assert header[3] == MSG_TYPE_REQUEST - # Decode payload - payload = encoded[DirtyProtocol.HEADER_SIZE:] - decoded = DirtyProtocol.decode(payload) - assert decoded == message + def test_decode_header(self): + """Test header decoding.""" + header = BinaryProtocol.encode_header(MSG_TYPE_RESPONSE, 67890, 200) + msg_type, request_id, length = BinaryProtocol.decode_header(header) + assert msg_type == MSG_TYPE_RESPONSE + assert request_id == 67890 + assert length == 200 - def test_encode_decode_complex_data(self): - """Test with complex nested data structures.""" - message = { - "type": "response", - "id": "456", - "result": { - "models": ["gpt-4", "claude-3"], - "config": {"temperature": 0.7, "max_tokens": 1000}, - "metadata": None, - }, - "args": [1, 2, 3], - } - encoded = DirtyProtocol.encode(message) - payload = encoded[DirtyProtocol.HEADER_SIZE:] - decoded = DirtyProtocol.decode(payload) - assert decoded == message + def test_decode_header_invalid_magic(self): + """Test header decoding with invalid magic.""" + header = b"XX" + b"\x01\x01" + b"\x00" * 12 + with pytest.raises(DirtyProtocolError) as exc_info: + BinaryProtocol.decode_header(header) + assert "magic" in str(exc_info.value).lower() - def test_encode_decode_unicode(self): - """Test with unicode characters.""" - message = { - "type": "request", - "data": "Hello, world!" - } - encoded = DirtyProtocol.encode(message) - payload = encoded[DirtyProtocol.HEADER_SIZE:] - decoded = DirtyProtocol.decode(payload) - assert decoded == message + def test_decode_header_invalid_version(self): + """Test header decoding with invalid version.""" + header = MAGIC + b"\x99\x01" + b"\x00" * 12 + with pytest.raises(DirtyProtocolError) as exc_info: + BinaryProtocol.decode_header(header) + assert "version" in str(exc_info.value).lower() - def test_encode_large_message(self): + def test_decode_header_invalid_type(self): + """Test header decoding with invalid message type.""" + header = MAGIC + bytes([VERSION, 0xFF]) + b"\x00" * 12 + with pytest.raises(DirtyProtocolError) as exc_info: + BinaryProtocol.decode_header(header) + assert "type" in str(exc_info.value).lower() + + def test_decode_header_too_large(self): + """Test header decoding rejects too-large messages.""" + header = struct.pack(HEADER_FORMAT, MAGIC, VERSION, MSG_TYPE_REQUEST, + MAX_MESSAGE_SIZE + 1, 0) + with pytest.raises(DirtyProtocolError) as exc_info: + BinaryProtocol.decode_header(header) + assert "too large" in str(exc_info.value).lower() + + def test_decode_header_too_short(self): + """Test header decoding with too-short data.""" + header = MAGIC + b"\x01" + with pytest.raises(DirtyProtocolError) as exc_info: + BinaryProtocol.decode_header(header) + assert "short" in str(exc_info.value).lower() + + +class TestBinaryProtocolEncodeDecode: + """Tests for message encoding/decoding.""" + + def test_encode_decode_request(self): + """Test request encoding/decoding roundtrip.""" + encoded = BinaryProtocol.encode_request( + request_id=12345, + app_path="myapp.ml:MLApp", + action="predict", + args=("data",), + kwargs={"temperature": 0.7} + ) + assert len(encoded) > HEADER_SIZE + + msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded) + assert msg_type_str == "request" + assert request_id == 12345 + assert payload["app_path"] == "myapp.ml:MLApp" + assert payload["action"] == "predict" + assert payload["args"] == ["data"] + assert payload["kwargs"] == {"temperature": 0.7} + + def test_encode_decode_response(self): + """Test response encoding/decoding roundtrip.""" + result = {"predictions": [0.1, 0.9], "metadata": {"model": "v1"}} + encoded = BinaryProtocol.encode_response(request_id=67890, result=result) + + msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded) + assert msg_type_str == "response" + assert request_id == 67890 + assert payload["result"] == result + + def test_encode_decode_error(self): + """Test error encoding/decoding roundtrip.""" + error = DirtyTimeoutError("Timed out", timeout=30) + encoded = BinaryProtocol.encode_error(request_id=11111, error=error) + + msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded) + assert msg_type_str == "error" + assert request_id == 11111 + assert payload["error"]["error_type"] == "DirtyTimeoutError" + assert "Timed out" in payload["error"]["message"] + + def test_encode_decode_chunk(self): + """Test chunk encoding/decoding roundtrip.""" + chunk_data = {"token": "hello", "index": 5} + encoded = BinaryProtocol.encode_chunk(request_id=22222, data=chunk_data) + + msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded) + assert msg_type_str == "chunk" + assert request_id == 22222 + assert payload["data"] == chunk_data + + def test_encode_decode_end(self): + """Test end message encoding/decoding roundtrip.""" + encoded = BinaryProtocol.encode_end(request_id=33333) + assert len(encoded) == HEADER_SIZE # End has no payload + + msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded) + assert msg_type_str == "end" + assert request_id == 33333 + assert payload == {} + + def test_encode_decode_binary_data(self): + """Test binary data passes through without base64 encoding.""" + binary_data = bytes(range(256)) + encoded = BinaryProtocol.encode_response( + request_id=44444, + result={"data": binary_data} + ) + + msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded) + assert payload["result"]["data"] == binary_data + + def test_encode_decode_large_message(self): """Test encoding a large message.""" - large_data = "x" * (1024 * 1024) # 1 MB of data - message = {"type": "request", "data": large_data} - encoded = DirtyProtocol.encode(message) - payload = encoded[DirtyProtocol.HEADER_SIZE:] - decoded = DirtyProtocol.decode(payload) - assert decoded == message + large_data = b"x" * (1024 * 1024) # 1 MB + encoded = BinaryProtocol.encode_response( + request_id=55555, + result={"data": large_data} + ) - def test_encode_empty_dict(self): - """Test encoding an empty dictionary.""" - message = {} - encoded = DirtyProtocol.encode(message) - payload = encoded[DirtyProtocol.HEADER_SIZE:] - decoded = DirtyProtocol.decode(payload) - assert decoded == message - - def test_encode_message_too_large(self): - """Test that encoding a message that's too large raises error.""" - large_data = "x" * (DirtyProtocol.MAX_MESSAGE_SIZE + 1000) - message = {"data": large_data} - with pytest.raises(DirtyProtocolError) as exc_info: - DirtyProtocol.encode(message) - assert "too large" in str(exc_info.value) - - def test_encode_non_serializable(self): - """Test that encoding non-JSON-serializable data raises error.""" - message = {"func": lambda x: x} - with pytest.raises(DirtyProtocolError) as exc_info: - DirtyProtocol.encode(message) - assert "Failed to encode" in str(exc_info.value) - - def test_decode_invalid_json(self): - """Test decoding invalid JSON raises error.""" - invalid_data = b"not valid json" - with pytest.raises(DirtyProtocolError) as exc_info: - DirtyProtocol.decode(invalid_data) - assert "Failed to decode" in str(exc_info.value) - - def test_decode_invalid_unicode(self): - """Test decoding invalid unicode raises error.""" - invalid_data = b"\x80\x81\x82" - with pytest.raises(DirtyProtocolError) as exc_info: - DirtyProtocol.decode(invalid_data) - assert "Failed to decode" in str(exc_info.value) + msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded) + assert payload["result"]["data"] == large_data -class TestDirtyProtocolSync: +class TestBinaryProtocolSync: """Tests for synchronous socket operations.""" def test_read_write_message(self): """Test read/write through socket pair.""" - # Create a socket pair for testing server_sock, client_sock = socket.socketpair() try: - message = {"type": "request", "id": "123", "action": "test"} + message = make_request( + request_id=12345, + app_path="test:App", + action="run" + ) - # Write message - DirtyProtocol.write_message(client_sock, message) + BinaryProtocol.write_message(client_sock, message) + received = BinaryProtocol.read_message(server_sock) - # Read message - received = DirtyProtocol.read_message(server_sock) - assert received == message + assert received["type"] == "request" + assert received["id"] == hash("12345") & 0xFFFFFFFFFFFFFFFF or \ + received["id"] == 12345 + assert received["app_path"] == "test:App" + assert received["action"] == "run" + finally: + server_sock.close() + client_sock.close() + + def test_read_write_with_int_id(self): + """Test read/write with integer request ID.""" + server_sock, client_sock = socket.socketpair() + try: + message = { + "type": "request", + "id": 999888777, + "app_path": "test:App", + "action": "run", + "args": [], + "kwargs": {} + } + + BinaryProtocol.write_message(client_sock, message) + received = BinaryProtocol.read_message(server_sock) + + assert received["id"] == 999888777 finally: server_sock.close() client_sock.close() @@ -147,19 +234,17 @@ class TestDirtyProtocolSync: server_sock, client_sock = socket.socketpair() try: messages = [ - {"type": "request", "id": "1"}, - {"type": "request", "id": "2"}, - {"type": "request", "id": "3"}, + make_request(i, f"app{i}:App", f"action{i}") + for i in range(1, 4) ] - # Write all messages for msg in messages: - DirtyProtocol.write_message(client_sock, msg) + BinaryProtocol.write_message(client_sock, msg) - # Read all messages - for expected in messages: - received = DirtyProtocol.read_message(server_sock) - assert received == expected + for i, _ in enumerate(messages, 1): + received = BinaryProtocol.read_message(server_sock) + assert received["app_path"] == f"app{i}:App" + assert received["action"] == f"action{i}" finally: server_sock.close() client_sock.close() @@ -169,39 +254,51 @@ class TestDirtyProtocolSync: server_sock, client_sock = socket.socketpair() client_sock.close() with pytest.raises(DirtyProtocolError) as exc_info: - DirtyProtocol.read_message(server_sock) + BinaryProtocol.read_message(server_sock) assert "closed" in str(exc_info.value).lower() server_sock.close() + def test_binary_data_roundtrip(self): + """Test binary data roundtrip through socket.""" + server_sock, client_sock = socket.socketpair() + try: + binary_payload = b"\x00\x01\x02\xff\xfe\xfd" + message = make_response(12345, {"binary": binary_payload}) -class TestDirtyProtocolAsync: + BinaryProtocol.write_message(client_sock, message) + received = BinaryProtocol.read_message(server_sock) + + assert received["result"]["binary"] == binary_payload + finally: + server_sock.close() + client_sock.close() + + +class TestBinaryProtocolAsync: """Tests for async stream operations.""" @pytest.mark.asyncio async def test_async_read_write(self): """Test async read/write with mock streams.""" - message = {"type": "request", "id": "123"} + message = make_request(12345, "test:App", "run") - # Create a pipe for testing read_fd, write_fd = os.pipe() try: reader = asyncio.StreamReader() _ = asyncio.StreamReaderProtocol(reader) - # Write the message to the pipe - encoded = DirtyProtocol.encode(message) + encoded = BinaryProtocol._encode_from_dict(message) os.write(write_fd, encoded) os.close(write_fd) write_fd = None - # Feed data to reader data = os.read(read_fd, len(encoded)) reader.feed_data(data) reader.feed_eof() - # Read the message - received = await DirtyProtocol.read_message_async(reader) - assert received == message + received = await BinaryProtocol.read_message_async(reader) + assert received["type"] == "request" + assert received["app_path"] == "test:App" finally: if write_fd is not None: os.close(write_fd) @@ -211,12 +308,11 @@ class TestDirtyProtocolAsync: async def test_async_read_incomplete_header(self): """Test async read with incomplete header.""" reader = asyncio.StreamReader() - # Feed only 2 bytes instead of 4 - reader.feed_data(b"\x00\x00") + reader.feed_data(MAGIC + b"\x01") # Only 3 bytes reader.feed_eof() with pytest.raises((asyncio.IncompleteReadError, DirtyProtocolError)): - await DirtyProtocol.read_message_async(reader) + await BinaryProtocol.read_message_async(reader) @pytest.mark.asyncio async def test_async_read_empty_connection(self): @@ -225,36 +321,33 @@ class TestDirtyProtocolAsync: reader.feed_eof() with pytest.raises(asyncio.IncompleteReadError): - await DirtyProtocol.read_message_async(reader) + await BinaryProtocol.read_message_async(reader) + + @pytest.mark.asyncio + async def test_async_read_invalid_magic(self): + """Test async read rejects invalid magic.""" + reader = asyncio.StreamReader() + header = b"XX" + bytes([VERSION, MSG_TYPE_REQUEST]) + b"\x00" * 12 + reader.feed_data(header) + reader.feed_eof() + + with pytest.raises(DirtyProtocolError) as exc_info: + await BinaryProtocol.read_message_async(reader) + assert "magic" in str(exc_info.value).lower() @pytest.mark.asyncio async def test_async_read_message_too_large(self): """Test async read rejects too-large messages.""" reader = asyncio.StreamReader() - # Create a header claiming an absurdly large message - header = struct.pack( - DirtyProtocol.HEADER_FORMAT, - DirtyProtocol.MAX_MESSAGE_SIZE + 1000 - ) + header = struct.pack(HEADER_FORMAT, MAGIC, VERSION, MSG_TYPE_REQUEST, + MAX_MESSAGE_SIZE + 1000, 0) reader.feed_data(header) reader.feed_eof() with pytest.raises(DirtyProtocolError) as exc_info: - await DirtyProtocol.read_message_async(reader) + await BinaryProtocol.read_message_async(reader) assert "too large" in str(exc_info.value) - @pytest.mark.asyncio - async def test_async_read_empty_message(self): - """Test async read rejects empty messages.""" - reader = asyncio.StreamReader() - header = struct.pack(DirtyProtocol.HEADER_FORMAT, 0) - reader.feed_data(header) - reader.feed_eof() - - with pytest.raises(DirtyProtocolError) as exc_info: - await DirtyProtocol.read_message_async(reader) - assert "Empty message" in str(exc_info.value) - class TestMessageBuilders: """Tests for message builder helper functions.""" @@ -340,9 +433,9 @@ class TestMessageBuilders: assert chunk["id"] == "req-456" assert chunk["data"] == data - def test_make_chunk_message_with_list_data(self): - """Test chunk message with list data.""" - data = [1, 2, 3, "token"] + def test_make_chunk_message_with_binary_data(self): + """Test chunk message with binary data.""" + data = b"\x00\x01\x02\xff" chunk = make_chunk_message("req-789", data) assert chunk["data"] == data @@ -353,22 +446,22 @@ class TestMessageBuilders: assert end["id"] == "req-123" assert "data" not in end - def test_chunk_and_end_encode_decode(self): + def test_chunk_and_end_roundtrip(self): """Test that chunk and end messages can be encoded/decoded.""" - chunk = make_chunk_message("req-123", {"token": "hello"}) - end = make_end_message("req-123") + chunk = make_chunk_message(12345, {"token": "hello"}) + end = make_end_message(12345) # Test chunk roundtrip - encoded_chunk = DirtyProtocol.encode(chunk) - payload = encoded_chunk[DirtyProtocol.HEADER_SIZE:] - decoded = DirtyProtocol.decode(payload) - assert decoded == chunk + encoded_chunk = BinaryProtocol._encode_from_dict(chunk) + msg_type, req_id, payload = BinaryProtocol.decode_message(encoded_chunk) + assert msg_type == "chunk" + assert payload["data"] == {"token": "hello"} # Test end roundtrip - encoded_end = DirtyProtocol.encode(end) - payload = encoded_end[DirtyProtocol.HEADER_SIZE:] - decoded = DirtyProtocol.decode(payload) - assert decoded == end + encoded_end = BinaryProtocol._encode_from_dict(end) + msg_type, req_id, payload = BinaryProtocol.decode_message(encoded_end) + assert msg_type == "end" + assert payload == {} class TestDirtyErrors: @@ -417,3 +510,58 @@ class TestDirtyErrors: assert error.action == "run" assert error.traceback == "Traceback..." assert "myapp:App" in str(error) + + +class TestBackwardsCompatibility: + """Tests for backwards compatibility with old JSON API.""" + + def test_dirty_protocol_alias(self): + """Test that DirtyProtocol is an alias for BinaryProtocol.""" + assert DirtyProtocol is BinaryProtocol + + def test_header_size_attribute(self): + """Test HEADER_SIZE is accessible on class.""" + assert DirtyProtocol.HEADER_SIZE == 16 + + def test_msg_type_constants(self): + """Test message type constants are strings for compatibility.""" + assert DirtyProtocol.MSG_TYPE_REQUEST == "request" + assert DirtyProtocol.MSG_TYPE_RESPONSE == "response" + assert DirtyProtocol.MSG_TYPE_ERROR == "error" + assert DirtyProtocol.MSG_TYPE_CHUNK == "chunk" + assert DirtyProtocol.MSG_TYPE_END == "end" + + def test_encode_decode_preserves_dict_format(self): + """Test that read_message returns dict compatible with old API.""" + server_sock, client_sock = socket.socketpair() + try: + message = { + "type": "response", + "id": 12345, + "result": {"status": "ok"} + } + + DirtyProtocol.write_message(client_sock, message) + received = DirtyProtocol.read_message(server_sock) + + # Old API: access via dict keys + assert received["type"] == "response" + assert received["result"]["status"] == "ok" + finally: + server_sock.close() + client_sock.close() + + def test_string_request_id_handled(self): + """Test that string request IDs are handled (hashed to int).""" + server_sock, client_sock = socket.socketpair() + try: + message = make_request("uuid-string-id", "test:App", "run") + + DirtyProtocol.write_message(client_sock, message) + received = DirtyProtocol.read_message(server_sock) + + # Request ID should be converted to int + assert isinstance(received["id"], int) + finally: + server_sock.close() + client_sock.close() diff --git a/tests/test_dirty_tlv.py b/tests/test_dirty_tlv.py new file mode 100644 index 00000000..c36b839a --- /dev/null +++ b/tests/test_dirty_tlv.py @@ -0,0 +1,554 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +"""Tests for dirty TLV binary encoder/decoder.""" + +import math +import struct +import pytest + +from gunicorn.dirty.tlv import ( + TLVEncoder, + TYPE_NONE, + TYPE_BOOL, + TYPE_INT64, + TYPE_FLOAT64, + TYPE_BYTES, + TYPE_STRING, + TYPE_LIST, + TYPE_DICT, + MAX_STRING_SIZE, + MAX_BYTES_SIZE, + MAX_LIST_SIZE, + MAX_DICT_SIZE, +) +from gunicorn.dirty.errors import DirtyProtocolError + + +class TestTLVEncoderBasicTypes: + """Tests for basic type encoding/decoding.""" + + def test_encode_decode_none(self): + """Test None encoding/decoding.""" + encoded = TLVEncoder.encode(None) + assert encoded == bytes([TYPE_NONE]) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value is None + assert offset == 1 + + def test_encode_decode_true(self): + """Test True encoding/decoding.""" + encoded = TLVEncoder.encode(True) + assert encoded == bytes([TYPE_BOOL, 0x01]) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value is True + assert offset == 2 + + def test_encode_decode_false(self): + """Test False encoding/decoding.""" + encoded = TLVEncoder.encode(False) + assert encoded == bytes([TYPE_BOOL, 0x00]) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value is False + assert offset == 2 + + def test_encode_decode_positive_int(self): + """Test positive integer encoding/decoding.""" + encoded = TLVEncoder.encode(42) + assert encoded[0] == TYPE_INT64 + assert len(encoded) == 9 # 1 type + 8 value + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == 42 + assert offset == 9 + + def test_encode_decode_negative_int(self): + """Test negative integer encoding/decoding.""" + encoded = TLVEncoder.encode(-12345) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == -12345 + + def test_encode_decode_large_int(self): + """Test large integer encoding/decoding.""" + large_val = 2**62 + encoded = TLVEncoder.encode(large_val) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == large_val + + def test_encode_decode_zero(self): + """Test zero encoding/decoding.""" + encoded = TLVEncoder.encode(0) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == 0 + + def test_encode_decode_float(self): + """Test float encoding/decoding.""" + encoded = TLVEncoder.encode(3.14159) + assert encoded[0] == TYPE_FLOAT64 + assert len(encoded) == 9 # 1 type + 8 value + + value, offset = TLVEncoder.decode(encoded, 0) + assert abs(value - 3.14159) < 1e-10 + + def test_encode_decode_negative_float(self): + """Test negative float encoding/decoding.""" + encoded = TLVEncoder.encode(-273.15) + + value, offset = TLVEncoder.decode(encoded, 0) + assert abs(value - (-273.15)) < 1e-10 + + def test_encode_decode_float_infinity(self): + """Test infinity encoding/decoding.""" + encoded = TLVEncoder.encode(float('inf')) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == float('inf') + + def test_encode_decode_float_nan(self): + """Test NaN encoding/decoding.""" + encoded = TLVEncoder.encode(float('nan')) + + value, offset = TLVEncoder.decode(encoded, 0) + assert math.isnan(value) + + +class TestTLVEncoderBytes: + """Tests for bytes encoding/decoding.""" + + def test_encode_decode_empty_bytes(self): + """Test empty bytes encoding/decoding.""" + encoded = TLVEncoder.encode(b"") + assert encoded[0] == TYPE_BYTES + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == b"" + + def test_encode_decode_bytes(self): + """Test bytes encoding/decoding.""" + data = b"\x00\x01\x02\xff\xfe\xfd" + encoded = TLVEncoder.encode(data) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == data + + def test_encode_decode_large_bytes(self): + """Test large bytes encoding/decoding.""" + data = b"x" * 10000 + encoded = TLVEncoder.encode(data) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == data + + def test_bytes_too_large(self): + """Test that bytes exceeding max size raises error.""" + # We won't actually allocate MAX_BYTES_SIZE, just check the encoding + with pytest.raises(DirtyProtocolError) as exc_info: + TLVEncoder.encode(b"x" * (MAX_BYTES_SIZE + 1)) + assert "too large" in str(exc_info.value).lower() + + +class TestTLVEncoderString: + """Tests for string encoding/decoding.""" + + def test_encode_decode_empty_string(self): + """Test empty string encoding/decoding.""" + encoded = TLVEncoder.encode("") + assert encoded[0] == TYPE_STRING + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == "" + + def test_encode_decode_ascii_string(self): + """Test ASCII string encoding/decoding.""" + encoded = TLVEncoder.encode("hello world") + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == "hello world" + + def test_encode_decode_unicode_string(self): + """Test Unicode string encoding/decoding.""" + text = "Hello, world! \u00a9 \u2603 \U0001F600" + encoded = TLVEncoder.encode(text) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == text + + def test_encode_decode_chinese(self): + """Test Chinese characters encoding/decoding.""" + text = "Hello, world!" + encoded = TLVEncoder.encode(text) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == text + + def test_encode_decode_emoji(self): + """Test emoji encoding/decoding.""" + text = "Test emoji" + encoded = TLVEncoder.encode(text) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == text + + def test_encode_decode_large_string(self): + """Test large string encoding/decoding.""" + text = "x" * 10000 + encoded = TLVEncoder.encode(text) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == text + + +class TestTLVEncoderList: + """Tests for list encoding/decoding.""" + + def test_encode_decode_empty_list(self): + """Test empty list encoding/decoding.""" + encoded = TLVEncoder.encode([]) + assert encoded[0] == TYPE_LIST + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == [] + + def test_encode_decode_simple_list(self): + """Test simple list encoding/decoding.""" + data = [1, 2, 3] + encoded = TLVEncoder.encode(data) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == data + + def test_encode_decode_mixed_list(self): + """Test mixed type list encoding/decoding.""" + data = [1, "hello", 3.14, True, None, b"bytes"] + encoded = TLVEncoder.encode(data) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == data + + def test_encode_decode_nested_list(self): + """Test nested list encoding/decoding.""" + data = [[1, 2], [3, [4, 5]], ["a", "b"]] + encoded = TLVEncoder.encode(data) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == data + + def test_encode_decode_tuple_as_list(self): + """Test that tuples are encoded as lists.""" + data = (1, 2, 3) + encoded = TLVEncoder.encode(data) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == [1, 2, 3] # Decoded as list + + def test_encode_decode_large_list(self): + """Test large list encoding/decoding.""" + data = list(range(1000)) + encoded = TLVEncoder.encode(data) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == data + + +class TestTLVEncoderDict: + """Tests for dict encoding/decoding.""" + + def test_encode_decode_empty_dict(self): + """Test empty dict encoding/decoding.""" + encoded = TLVEncoder.encode({}) + assert encoded[0] == TYPE_DICT + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == {} + + def test_encode_decode_simple_dict(self): + """Test simple dict encoding/decoding.""" + data = {"a": 1, "b": 2, "c": 3} + encoded = TLVEncoder.encode(data) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == data + + def test_encode_decode_mixed_values_dict(self): + """Test dict with mixed value types.""" + data = { + "int": 42, + "float": 3.14, + "string": "hello", + "bool": True, + "none": None, + "bytes": b"data", + "list": [1, 2, 3], + } + encoded = TLVEncoder.encode(data) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == data + + def test_encode_decode_nested_dict(self): + """Test nested dict encoding/decoding.""" + data = { + "outer": { + "inner": { + "value": 42 + }, + "list": [{"a": 1}, {"b": 2}] + } + } + encoded = TLVEncoder.encode(data) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == data + + def test_encode_dict_non_string_key_converted(self): + """Test that non-string keys are converted to strings (like JSON).""" + data = {1: "value", 2: "other"} + encoded = TLVEncoder.encode(data) + decoded, _ = TLVEncoder.decode(encoded, 0) + # Keys should be converted to strings + assert decoded == {"1": "value", "2": "other"} + + +class TestTLVEncoderComplexStructures: + """Tests for complex nested structures.""" + + def test_encode_decode_request_like(self): + """Test encoding/decoding a request-like structure.""" + data = { + "id": 12345, + "app_path": "myapp.ml:MLApp", + "action": "predict", + "args": [b"input_data", 0.7], + "kwargs": {"temperature": 0.7, "max_tokens": 1000}, + } + encoded = TLVEncoder.encode(data) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == data + + def test_encode_decode_response_like(self): + """Test encoding/decoding a response-like structure.""" + data = { + "id": 12345, + "result": { + "predictions": [0.1, 0.2, 0.7], + "metadata": {"model": "v1.0", "latency_ms": 42}, + } + } + encoded = TLVEncoder.encode(data) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == data + + def test_encode_decode_deeply_nested(self): + """Test deeply nested structures.""" + data = {"a": {"b": {"c": {"d": {"e": {"f": "deep"}}}}}} + encoded = TLVEncoder.encode(data) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == data + + +class TestTLVEncoderRoundtrip: + """Tests for complete roundtrip using decode_full.""" + + def test_decode_full_simple(self): + """Test decode_full with simple value.""" + data = {"key": "value"} + encoded = TLVEncoder.encode(data) + + value = TLVEncoder.decode_full(encoded) + assert value == data + + def test_decode_full_trailing_data(self): + """Test decode_full raises on trailing data.""" + encoded = TLVEncoder.encode(42) + b"extra" + + with pytest.raises(DirtyProtocolError) as exc_info: + TLVEncoder.decode_full(encoded) + assert "trailing" in str(exc_info.value).lower() + + +class TestTLVEncoderErrors: + """Tests for error handling.""" + + def test_decode_empty_data(self): + """Test decoding empty data raises error.""" + with pytest.raises(DirtyProtocolError) as exc_info: + TLVEncoder.decode(b"", 0) + assert "truncated" in str(exc_info.value).lower() + + def test_decode_truncated_int(self): + """Test decoding truncated int raises error.""" + # TYPE_INT64 followed by only 4 bytes instead of 8 + data = bytes([TYPE_INT64, 0, 0, 0, 0]) + with pytest.raises(DirtyProtocolError) as exc_info: + TLVEncoder.decode(data, 0) + assert "truncated" in str(exc_info.value).lower() + + def test_decode_truncated_float(self): + """Test decoding truncated float raises error.""" + data = bytes([TYPE_FLOAT64, 0, 0, 0, 0]) + with pytest.raises(DirtyProtocolError) as exc_info: + TLVEncoder.decode(data, 0) + assert "truncated" in str(exc_info.value).lower() + + def test_decode_truncated_bytes_length(self): + """Test decoding truncated bytes length raises error.""" + data = bytes([TYPE_BYTES, 0, 0]) # Only 2 bytes of length + with pytest.raises(DirtyProtocolError) as exc_info: + TLVEncoder.decode(data, 0) + assert "truncated" in str(exc_info.value).lower() + + def test_decode_truncated_bytes_data(self): + """Test decoding truncated bytes data raises error.""" + # Says 10 bytes but only provides 5 + data = bytes([TYPE_BYTES]) + struct.pack(">I", 10) + b"12345" + with pytest.raises(DirtyProtocolError) as exc_info: + TLVEncoder.decode(data, 0) + assert "truncated" in str(exc_info.value).lower() + + def test_decode_truncated_string_length(self): + """Test decoding truncated string length raises error.""" + data = bytes([TYPE_STRING, 0]) + with pytest.raises(DirtyProtocolError) as exc_info: + TLVEncoder.decode(data, 0) + assert "truncated" in str(exc_info.value).lower() + + def test_decode_truncated_string_data(self): + """Test decoding truncated string data raises error.""" + data = bytes([TYPE_STRING]) + struct.pack(">I", 10) + b"hello" + with pytest.raises(DirtyProtocolError) as exc_info: + TLVEncoder.decode(data, 0) + assert "truncated" in str(exc_info.value).lower() + + def test_decode_invalid_utf8(self): + """Test decoding invalid UTF-8 raises error.""" + # Valid length, but invalid UTF-8 bytes + data = bytes([TYPE_STRING]) + struct.pack(">I", 3) + b"\x80\x81\x82" + with pytest.raises(DirtyProtocolError) as exc_info: + TLVEncoder.decode(data, 0) + assert "utf-8" in str(exc_info.value).lower() + + def test_decode_truncated_list_count(self): + """Test decoding truncated list count raises error.""" + data = bytes([TYPE_LIST, 0]) + with pytest.raises(DirtyProtocolError) as exc_info: + TLVEncoder.decode(data, 0) + assert "truncated" in str(exc_info.value).lower() + + def test_decode_truncated_dict_count(self): + """Test decoding truncated dict count raises error.""" + data = bytes([TYPE_DICT, 0]) + with pytest.raises(DirtyProtocolError) as exc_info: + TLVEncoder.decode(data, 0) + assert "truncated" in str(exc_info.value).lower() + + def test_decode_unknown_type(self): + """Test decoding unknown type raises error.""" + data = bytes([0xFF]) # Unknown type + with pytest.raises(DirtyProtocolError) as exc_info: + TLVEncoder.decode(data, 0) + assert "unknown" in str(exc_info.value).lower() + + def test_encode_unsupported_type(self): + """Test encoding unsupported type raises error.""" + with pytest.raises(DirtyProtocolError) as exc_info: + TLVEncoder.encode(object()) + assert "unsupported type" in str(exc_info.value).lower() + + def test_encode_function_raises_error(self): + """Test encoding a function raises error.""" + with pytest.raises(DirtyProtocolError) as exc_info: + TLVEncoder.encode(lambda x: x) + assert "unsupported type" in str(exc_info.value).lower() + + def test_decode_dict_non_string_key_in_data(self): + """Test decoding dict with non-string key raises error.""" + # Manually construct a dict with int key + # TYPE_DICT, count=1, TYPE_INT64 key, TYPE_INT64 value + data = ( + bytes([TYPE_DICT]) + + struct.pack(">I", 1) + + bytes([TYPE_INT64]) + + struct.pack(">q", 1) # Key (int, not string) + + bytes([TYPE_INT64]) + + struct.pack(">q", 2) # Value + ) + with pytest.raises(DirtyProtocolError) as exc_info: + TLVEncoder.decode(data, 0) + assert "string" in str(exc_info.value).lower() + + +class TestTLVEncoderOffset: + """Tests for offset handling.""" + + def test_decode_with_offset(self): + """Test decoding from specific offset.""" + # Create data with prefix + prefix = b"garbage" + encoded = TLVEncoder.encode(42) + data = prefix + encoded + + value, offset = TLVEncoder.decode(data, len(prefix)) + assert value == 42 + assert offset == len(prefix) + len(encoded) + + def test_decode_multiple_values(self): + """Test decoding multiple consecutive values.""" + v1 = TLVEncoder.encode("hello") + v2 = TLVEncoder.encode(42) + v3 = TLVEncoder.encode([1, 2, 3]) + data = v1 + v2 + v3 + + offset = 0 + val1, offset = TLVEncoder.decode(data, offset) + assert val1 == "hello" + + val2, offset = TLVEncoder.decode(data, offset) + assert val2 == 42 + + val3, offset = TLVEncoder.decode(data, offset) + assert val3 == [1, 2, 3] + + assert offset == len(data) + + +class TestTLVEncoderBinaryData: + """Tests for binary data handling (the main motivation for this protocol).""" + + def test_binary_data_no_encoding(self): + """Test that binary data is passed through without encoding.""" + # This is the key advantage over JSON - binary data doesn't need base64 + binary_data = bytes(range(256)) # All byte values + encoded = TLVEncoder.encode(binary_data) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == binary_data + + def test_binary_with_null_bytes(self): + """Test binary data with embedded null bytes.""" + binary_data = b"\x00\x00\xff\x00\x00" + encoded = TLVEncoder.encode(binary_data) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == binary_data + + def test_binary_in_nested_structure(self): + """Test binary data inside nested structures.""" + data = { + "image": b"\x89PNG\r\n\x1a\n" + b"\x00" * 100, + "metadata": {"width": 640, "height": 480}, + "chunks": [b"chunk1", b"chunk2", b"chunk3"], + } + encoded = TLVEncoder.encode(data) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == data diff --git a/tests/test_dirty_worker.py b/tests/test_dirty_worker.py index f68a2276..e50e7c41 100644 --- a/tests/test_dirty_worker.py +++ b/tests/test_dirty_worker.py @@ -12,7 +12,13 @@ import pytest from gunicorn.config import Config from gunicorn.dirty.worker import DirtyWorker -from gunicorn.dirty.protocol import DirtyProtocol, make_request +from gunicorn.dirty.protocol import ( + DirtyProtocol, + BinaryProtocol, + make_request, + HEADER_SIZE, + HEADER_FORMAT, +) from gunicorn.dirty.errors import DirtyAppNotFoundError @@ -56,17 +62,22 @@ class MockStreamWriter: self._buffer += data async def drain(self): - # Decode the buffer to extract messages - while len(self._buffer) >= DirtyProtocol.HEADER_SIZE: - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - self._buffer[:DirtyProtocol.HEADER_SIZE] - )[0] - total_size = DirtyProtocol.HEADER_SIZE + length + # Decode the buffer to extract messages using binary protocol + while len(self._buffer) >= HEADER_SIZE: + # Decode header to get payload length + _, _, length = BinaryProtocol.decode_header( + self._buffer[:HEADER_SIZE] + ) + total_size = HEADER_SIZE + length if len(self._buffer) >= total_size: - msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size] + msg_data = self._buffer[:total_size] self._buffer = self._buffer[total_size:] - self.messages.append(DirtyProtocol.decode(msg_data)) + # decode_message returns (msg_type_str, request_id, payload_dict) + msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data) + # Reconstruct the dict format for backwards compatibility + result = {"type": msg_type_str, "id": request_id} + result.update(payload_dict) + self.messages.append(result) else: break @@ -246,7 +257,7 @@ class TestDirtyWorkerHandleRequest: worker.load_apps() request = make_request( - request_id="test-123", + request_id=123, app_path="tests.support_dirty_app:TestDirtyApp", action="compute", args=(2, 3), @@ -259,7 +270,7 @@ class TestDirtyWorkerHandleRequest: assert len(writer.messages) == 1 response = writer.messages[0] assert response["type"] == DirtyProtocol.MSG_TYPE_RESPONSE - assert response["id"] == "test-123" + assert response["id"] == 123 assert response["result"] == 6 @pytest.mark.asyncio @@ -282,7 +293,7 @@ class TestDirtyWorkerHandleRequest: worker.load_apps() request = make_request( - request_id="test-456", + request_id=456, app_path="tests.support_dirty_app:TestDirtyApp", action="compute", args=(2, 3), @@ -295,7 +306,7 @@ class TestDirtyWorkerHandleRequest: assert len(writer.messages) == 1 response = writer.messages[0] assert response["type"] == DirtyProtocol.MSG_TYPE_ERROR - assert response["id"] == "test-456" + assert response["id"] == 456 assert "Unknown operation" in response["error"]["message"] @pytest.mark.asyncio @@ -315,7 +326,7 @@ class TestDirtyWorkerHandleRequest: socket_path=socket_path ) - request = {"type": "unknown", "id": "test-789"} + request = {"type": "unknown", "id": 789} writer = MockStreamWriter() await worker.handle_request(request, writer) @@ -697,7 +708,7 @@ class TestDirtyWorkerRunAsync: # Create a simple test using stream reader/writer request = make_request( - request_id="conn-test", + request_id=999, app_path="tests.support_dirty_app:TestDirtyApp", action="compute", args=(5, 3), @@ -706,7 +717,7 @@ class TestDirtyWorkerRunAsync: # Mock reader and writer reader = asyncio.StreamReader() - encoded_request = DirtyProtocol.encode(request) + encoded_request = BinaryProtocol._encode_from_dict(request) reader.feed_data(encoded_request) reader.feed_eof()