From 1665857c0e4ee59cf4c68469753078ed2d3d6e15 Mon Sep 17 00:00:00 2001 From: Benoit Chesneau Date: Wed, 11 Feb 2026 22:58:43 +0100 Subject: [PATCH] feat(dirty): implement binary protocol Replace JSON-based protocol with binary format using 16-byte header: - Magic bytes (GD), version, message type, payload length, request ID - TLV-encoded payloads for efficient binary data transfer - No base64 encoding needed for binary data - Backwards compatible API (DirtyProtocol alias, dict-based interface) Header format inspired by OpenBSD msgctl/msgsnd. --- gunicorn/dirty/protocol.py | 530 +++++++++++++++++++++++++++-------- tests/test_dirty_protocol.py | 448 +++++++++++++++++++---------- 2 files changed, 711 insertions(+), 267 deletions(-) diff --git a/gunicorn/dirty/protocol.py b/gunicorn/dirty/protocol.py index e5ac6cfa..15fab29a 100644 --- a/gunicorn/dirty/protocol.py +++ b/gunicorn/dirty/protocol.py @@ -3,89 +3,304 @@ # See the NOTICE for more information. """ -Dirty Arbiters Protocol +Dirty Worker Binary Protocol -Length-prefixed JSON message framing over Unix sockets. -Provides both async (primary) and sync (for HTTP workers) APIs. +Binary message framing over Unix sockets, inspired by OpenBSD msgctl/msgsnd. +Replaces JSON protocol for efficient binary data transfer. -Message Format: -+----------------+------------------+ -| 4-byte length | JSON payload | -+----------------+------------------+ +Header Format (16 bytes): ++--------+--------+--------+--------+--------+--------+--------+--------+ +| Magic (2B) | Ver(1) | MType | Payload Length (4B) | ++--------+--------+--------+--------+--------+--------+--------+--------+ +| Request ID (8 bytes) | ++--------+--------+--------+--------+--------+--------+--------+--------+ -The length field is a 4-byte unsigned integer in network byte order (big-endian). +- Magic: 0x47 0x44 ("GD" for Gunicorn Dirty) +- Version: 0x01 +- MType: Message type (REQUEST, RESPONSE, ERROR, CHUNK, END) +- Length: Payload size (big-endian uint32, max 64MB) +- Request ID: uint64 (replaces UUID string) + +Payload is TLV-encoded (see tlv.py). """ import asyncio -import json -import struct import socket +import struct from .errors import DirtyProtocolError +from .tlv import TLVEncoder -class DirtyProtocol: - """Length-prefixed JSON messages over Unix sockets.""" +# Protocol constants +MAGIC = b"GD" # 0x47 0x44 +VERSION = 0x01 - # 4-byte unsigned int, network byte order (big-endian) - HEADER_FORMAT = "!I" - HEADER_SIZE = struct.calcsize(HEADER_FORMAT) +# Message types (1 byte) +MSG_TYPE_REQUEST = 0x01 +MSG_TYPE_RESPONSE = 0x02 +MSG_TYPE_ERROR = 0x03 +MSG_TYPE_CHUNK = 0x04 +MSG_TYPE_END = 0x05 - # Maximum message size (64 MB) - MAX_MESSAGE_SIZE = 64 * 1024 * 1024 +# Message type names (for backwards compatibility with old API) +MSG_TYPE_REQUEST_STR = "request" +MSG_TYPE_RESPONSE_STR = "response" +MSG_TYPE_ERROR_STR = "error" +MSG_TYPE_CHUNK_STR = "chunk" +MSG_TYPE_END_STR = "end" - # Message types for future streaming support - MSG_TYPE_REQUEST = "request" - MSG_TYPE_RESPONSE = "response" - MSG_TYPE_ERROR = "error" - MSG_TYPE_CHUNK = "chunk" - MSG_TYPE_END = "end" +# Map int types to string names +MSG_TYPE_TO_STR = { + MSG_TYPE_REQUEST: MSG_TYPE_REQUEST_STR, + MSG_TYPE_RESPONSE: MSG_TYPE_RESPONSE_STR, + MSG_TYPE_ERROR: MSG_TYPE_ERROR_STR, + MSG_TYPE_CHUNK: MSG_TYPE_CHUNK_STR, + MSG_TYPE_END: MSG_TYPE_END_STR, +} + +# Map string names to int types +MSG_TYPE_FROM_STR = {v: k for k, v in MSG_TYPE_TO_STR.items()} + +# Header format: Magic (2) + Version (1) + Type (1) + Length (4) + RequestID (8) = 16 +HEADER_FORMAT = ">2sBBIQ" +HEADER_SIZE = struct.calcsize(HEADER_FORMAT) + +# Maximum message size (64 MB) +MAX_MESSAGE_SIZE = 64 * 1024 * 1024 + + +class BinaryProtocol: + """Binary message protocol for dirty worker IPC.""" + + # Export constants for external use + HEADER_SIZE = HEADER_SIZE + MAX_MESSAGE_SIZE = MAX_MESSAGE_SIZE + + MSG_TYPE_REQUEST = MSG_TYPE_REQUEST_STR + MSG_TYPE_RESPONSE = MSG_TYPE_RESPONSE_STR + MSG_TYPE_ERROR = MSG_TYPE_ERROR_STR + MSG_TYPE_CHUNK = MSG_TYPE_CHUNK_STR + MSG_TYPE_END = MSG_TYPE_END_STR @staticmethod - def encode(message: dict) -> bytes: + def encode_header(msg_type: int, request_id: int, payload_length: int) -> bytes: """ - Encode a message dict to length-prefixed bytes. + Encode the 16-byte message header. Args: - message: Dictionary to encode as JSON + msg_type: Message type (MSG_TYPE_REQUEST, etc.) + request_id: Unique request identifier (uint64) + payload_length: Length of the TLV-encoded payload Returns: - bytes: Length-prefixed encoded message + bytes: 16-byte header + """ + return struct.pack(HEADER_FORMAT, MAGIC, VERSION, msg_type, + payload_length, request_id) + + @staticmethod + def decode_header(data: bytes) -> tuple: + """ + Decode the 16-byte message header. + + Args: + data: 16 bytes of header data + + Returns: + tuple: (msg_type, request_id, payload_length) Raises: - DirtyProtocolError: If encoding fails + DirtyProtocolError: If header is invalid """ - try: - payload = json.dumps(message).encode("utf-8") - if len(payload) > DirtyProtocol.MAX_MESSAGE_SIZE: + if len(data) < HEADER_SIZE: + raise DirtyProtocolError( + f"Header too short: {len(data)} bytes, expected {HEADER_SIZE}", + raw_data=data + ) + + magic, version, msg_type, length, request_id = struct.unpack( + HEADER_FORMAT, data[:HEADER_SIZE] + ) + + if magic != MAGIC: + raise DirtyProtocolError( + f"Invalid magic: {magic!r}, expected {MAGIC!r}", + raw_data=data[:20] + ) + + if version != VERSION: + raise DirtyProtocolError( + f"Unsupported protocol version: {version}, expected {VERSION}", + raw_data=data[:20] + ) + + if msg_type not in MSG_TYPE_TO_STR: + raise DirtyProtocolError( + f"Unknown message type: 0x{msg_type:02x}", + raw_data=data[:20] + ) + + if length > MAX_MESSAGE_SIZE: + raise DirtyProtocolError( + f"Message too large: {length} bytes (max: {MAX_MESSAGE_SIZE})" + ) + + return msg_type, request_id, length + + @staticmethod + def encode_request(request_id: int, app_path: str, action: str, + args: tuple = None, kwargs: dict = None) -> bytes: + """ + Encode a request message. + + Args: + request_id: Unique request identifier (uint64) + app_path: Import path of the dirty app + action: Action to call on the app + args: Positional arguments + kwargs: Keyword arguments + + Returns: + bytes: Complete message (header + payload) + """ + payload_dict = { + "app_path": app_path, + "action": action, + "args": list(args) if args else [], + "kwargs": kwargs or {}, + } + payload = TLVEncoder.encode(payload_dict) + header = BinaryProtocol.encode_header(MSG_TYPE_REQUEST, request_id, + len(payload)) + return header + payload + + @staticmethod + def encode_response(request_id: int, result) -> bytes: + """ + Encode a success response message. + + Args: + request_id: Request identifier this responds to + result: Result value (must be TLV-serializable) + + Returns: + bytes: Complete message (header + payload) + """ + payload_dict = {"result": result} + payload = TLVEncoder.encode(payload_dict) + header = BinaryProtocol.encode_header(MSG_TYPE_RESPONSE, request_id, + len(payload)) + return header + payload + + @staticmethod + def encode_error(request_id: int, error) -> bytes: + """ + Encode an error response message. + + Args: + request_id: Request identifier this responds to + error: DirtyError instance, dict, or Exception + + Returns: + bytes: Complete message (header + payload) + """ + from .errors import DirtyError + + if isinstance(error, DirtyError): + error_dict = error.to_dict() + elif isinstance(error, dict): + error_dict = error + else: + error_dict = { + "error_type": type(error).__name__, + "message": str(error), + "details": {}, + } + + payload_dict = {"error": error_dict} + payload = TLVEncoder.encode(payload_dict) + header = BinaryProtocol.encode_header(MSG_TYPE_ERROR, request_id, + len(payload)) + return header + payload + + @staticmethod + def encode_chunk(request_id: int, data) -> bytes: + """ + Encode a chunk message for streaming responses. + + Args: + request_id: Request identifier this chunk belongs to + data: Chunk data (must be TLV-serializable) + + Returns: + bytes: Complete message (header + payload) + """ + payload_dict = {"data": data} + payload = TLVEncoder.encode(payload_dict) + header = BinaryProtocol.encode_header(MSG_TYPE_CHUNK, request_id, + len(payload)) + return header + payload + + @staticmethod + def encode_end(request_id: int) -> bytes: + """ + Encode an end-of-stream message. + + Args: + request_id: Request identifier this ends + + Returns: + bytes: Complete message (header + empty payload) + """ + # End message has empty payload + header = BinaryProtocol.encode_header(MSG_TYPE_END, request_id, 0) + return header + + @staticmethod + def decode_message(data: bytes) -> tuple: + """ + Decode a complete message (header + payload). + + Args: + data: Complete message bytes + + Returns: + tuple: (msg_type_str, request_id, payload_dict) + msg_type_str is the string name (e.g., "request") + payload_dict is the decoded TLV payload as a dict + + Raises: + DirtyProtocolError: If message is malformed + """ + msg_type, request_id, length = BinaryProtocol.decode_header(data) + + if len(data) < HEADER_SIZE + length: + raise DirtyProtocolError( + f"Incomplete message: expected {HEADER_SIZE + length} bytes, " + f"got {len(data)}", + raw_data=data[:50] + ) + + if length == 0: + # End message has empty payload + payload_dict = {} + else: + payload_data = data[HEADER_SIZE:HEADER_SIZE + length] + try: + payload_dict = TLVEncoder.decode_full(payload_data) + except DirtyProtocolError: + raise + except Exception as e: raise DirtyProtocolError( - f"Message too large: {len(payload)} bytes " - f"(max: {DirtyProtocol.MAX_MESSAGE_SIZE})" + f"Failed to decode TLV payload: {e}", + raw_data=payload_data[:50] ) - length = struct.pack(DirtyProtocol.HEADER_FORMAT, len(payload)) - return length + payload - except (TypeError, ValueError) as e: - raise DirtyProtocolError(f"Failed to encode message: {e}") - @staticmethod - def decode(data: bytes) -> dict: - """ - Decode bytes (without length prefix) to message dict. + # Convert to dict format similar to old JSON protocol + msg_type_str = MSG_TYPE_TO_STR[msg_type] - Args: - data: JSON bytes to decode - - Returns: - dict: Decoded message - - Raises: - DirtyProtocolError: If decoding fails - """ - try: - return json.loads(data.decode("utf-8")) - except (json.JSONDecodeError, UnicodeDecodeError) as e: - raise DirtyProtocolError(f"Failed to decode message: {e}", - raw_data=data) + return msg_type_str, request_id, payload_dict # ------------------------------------------------------------------------- # Async API (primary - for DirtyArbiter and DirtyWorker) @@ -94,53 +309,62 @@ class DirtyProtocol: @staticmethod async def read_message_async(reader: asyncio.StreamReader) -> dict: """ - Read a complete message from async stream. + Read a complete binary message from async stream. Args: reader: asyncio StreamReader Returns: - dict: Decoded message + dict: Message dict with 'type', 'id', and payload fields Raises: DirtyProtocolError: If read fails or message is malformed asyncio.IncompleteReadError: If connection closed mid-read """ - # Read length header + # Read header try: - header = await reader.readexactly(DirtyProtocol.HEADER_SIZE) + header = await reader.readexactly(HEADER_SIZE) except asyncio.IncompleteReadError as e: if len(e.partial) == 0: # Clean close - no data was read raise raise DirtyProtocolError( f"Incomplete header: got {len(e.partial)} bytes, " - f"expected {DirtyProtocol.HEADER_SIZE}", + f"expected {HEADER_SIZE}", raw_data=e.partial ) - length = struct.unpack(DirtyProtocol.HEADER_FORMAT, header)[0] - - if length > DirtyProtocol.MAX_MESSAGE_SIZE: - raise DirtyProtocolError( - f"Message too large: {length} bytes " - f"(max: {DirtyProtocol.MAX_MESSAGE_SIZE})" - ) - - if length == 0: - raise DirtyProtocolError("Empty message received") + msg_type, request_id, length = BinaryProtocol.decode_header(header) # Read payload - try: - payload = await reader.readexactly(length) - except asyncio.IncompleteReadError as e: - raise DirtyProtocolError( - f"Incomplete message: got {len(e.partial)} bytes, " - f"expected {length}", - raw_data=e.partial - ) + if length > 0: + try: + payload_data = await reader.readexactly(length) + except asyncio.IncompleteReadError as e: + raise DirtyProtocolError( + f"Incomplete payload: got {len(e.partial)} bytes, " + f"expected {length}", + raw_data=e.partial + ) - return DirtyProtocol.decode(payload) + try: + payload_dict = TLVEncoder.decode_full(payload_data) + except DirtyProtocolError: + raise + except Exception as e: + raise DirtyProtocolError( + f"Failed to decode TLV payload: {e}", + raw_data=payload_data[:50] + ) + else: + payload_dict = {} + + # Build response dict + msg_type_str = MSG_TYPE_TO_STR[msg_type] + result = {"type": msg_type_str, "id": request_id} + result.update(payload_dict) + + return result @staticmethod async def write_message_async(writer: asyncio.StreamWriter, @@ -148,15 +372,17 @@ class DirtyProtocol: """ Write a message to async stream. + Accepts dict format for backwards compatibility. + Args: writer: asyncio StreamWriter - message: Dictionary to send + message: Message dict with 'type', 'id', and payload fields Raises: DirtyProtocolError: If encoding fails ConnectionError: If write fails """ - data = DirtyProtocol.encode(message) + data = BinaryProtocol._encode_from_dict(message) writer.write(data) await writer.drain() @@ -201,27 +427,36 @@ class DirtyProtocol: sock: Socket to read from Returns: - dict: Decoded message + dict: Message dict with 'type', 'id', and payload fields Raises: DirtyProtocolError: If read fails or message is malformed """ - # Read length header - header = DirtyProtocol._recv_exactly(sock, DirtyProtocol.HEADER_SIZE) - length = struct.unpack(DirtyProtocol.HEADER_FORMAT, header)[0] - - if length > DirtyProtocol.MAX_MESSAGE_SIZE: - raise DirtyProtocolError( - f"Message too large: {length} bytes " - f"(max: {DirtyProtocol.MAX_MESSAGE_SIZE})" - ) - - if length == 0: - raise DirtyProtocolError("Empty message received") + # Read header + header = BinaryProtocol._recv_exactly(sock, HEADER_SIZE) + msg_type, request_id, length = BinaryProtocol.decode_header(header) # Read payload - payload = DirtyProtocol._recv_exactly(sock, length) - return DirtyProtocol.decode(payload) + if length > 0: + payload_data = BinaryProtocol._recv_exactly(sock, length) + try: + payload_dict = TLVEncoder.decode_full(payload_data) + except DirtyProtocolError: + raise + except Exception as e: + raise DirtyProtocolError( + f"Failed to decode TLV payload: {e}", + raw_data=payload_data[:50] + ) + else: + payload_dict = {} + + # Build response dict + msg_type_str = MSG_TYPE_TO_STR[msg_type] + result = {"type": msg_type_str, "id": request_id} + result.update(payload_dict) + + return result @staticmethod def write_message(sock: socket.socket, message: dict) -> None: @@ -230,31 +465,92 @@ class DirtyProtocol: Args: sock: Socket to write to - message: Dictionary to send + message: Message dict with 'type', 'id', and payload fields Raises: DirtyProtocolError: If encoding fails OSError: If write fails """ - data = DirtyProtocol.encode(message) + data = BinaryProtocol._encode_from_dict(message) sock.sendall(data) + @staticmethod + def _encode_from_dict(message: dict) -> bytes: + """ + Encode a message dict to binary format. -# Message builder helpers -def make_request(request_id: str, app_path: str, action: str, + Supports the old dict-based API for backwards compatibility. + + Args: + message: Message dict with 'type', 'id', and payload fields + + Returns: + bytes: Complete encoded message + """ + msg_type_str = message.get("type") + request_id = message.get("id", 0) + + # Handle string or int request IDs + if isinstance(request_id, str): + # For backwards compat with UUID strings, hash to int + request_id = hash(request_id) & 0xFFFFFFFFFFFFFFFF + + msg_type = MSG_TYPE_FROM_STR.get(msg_type_str) + if msg_type is None: + raise DirtyProtocolError(f"Unknown message type: {msg_type_str}") + + if msg_type == MSG_TYPE_REQUEST: + return BinaryProtocol.encode_request( + request_id, + message.get("app_path", ""), + message.get("action", ""), + message.get("args"), + message.get("kwargs") + ) + elif msg_type == MSG_TYPE_RESPONSE: + return BinaryProtocol.encode_response( + request_id, + message.get("result") + ) + elif msg_type == MSG_TYPE_ERROR: + return BinaryProtocol.encode_error( + request_id, + message.get("error", {}) + ) + elif msg_type == MSG_TYPE_CHUNK: + return BinaryProtocol.encode_chunk( + request_id, + message.get("data") + ) + elif msg_type == MSG_TYPE_END: + return BinaryProtocol.encode_end(request_id) + else: + raise DirtyProtocolError(f"Unhandled message type: {msg_type}") + + +# ============================================================================= +# Backwards Compatibility Aliases +# ============================================================================= + +# Alias BinaryProtocol as DirtyProtocol for drop-in replacement +DirtyProtocol = BinaryProtocol + + +# Message builder helpers (backwards compatible with old API) +def make_request(request_id, app_path: str, action: str, args: tuple = None, kwargs: dict = None) -> dict: """ - Build a request message. + Build a request message dict. Args: - request_id: Unique request identifier + request_id: Unique request identifier (int or str) app_path: Import path of the dirty app (e.g., 'myapp.ml:MLApp') action: Action to call on the app args: Positional arguments kwargs: Keyword arguments Returns: - dict: Request message + dict: Request message dict """ return { "type": DirtyProtocol.MSG_TYPE_REQUEST, @@ -266,16 +562,16 @@ def make_request(request_id: str, app_path: str, action: str, } -def make_response(request_id: str, result) -> dict: +def make_response(request_id, result) -> dict: """ - Build a success response message. + Build a success response message dict. Args: request_id: Request identifier this responds to - result: Result value (must be JSON-serializable) + result: Result value Returns: - dict: Response message + dict: Response message dict """ return { "type": DirtyProtocol.MSG_TYPE_RESPONSE, @@ -284,16 +580,16 @@ def make_response(request_id: str, result) -> dict: } -def make_error_response(request_id: str, error) -> dict: +def make_error_response(request_id, error) -> dict: """ - Build an error response message. + Build an error response message dict. Args: request_id: Request identifier this responds to error: DirtyError instance or dict with error info Returns: - dict: Error response message + dict: Error response message dict """ from .errors import DirtyError if isinstance(error, DirtyError): @@ -314,16 +610,16 @@ def make_error_response(request_id: str, error) -> dict: } -def make_chunk_message(request_id: str, data) -> dict: +def make_chunk_message(request_id, data) -> dict: """ - Build a chunk message for streaming responses. + Build a chunk message dict for streaming responses. Args: request_id: Request identifier this chunk belongs to - data: Chunk data (must be JSON-serializable) + data: Chunk data Returns: - dict: Chunk message + dict: Chunk message dict """ return { "type": DirtyProtocol.MSG_TYPE_CHUNK, @@ -332,15 +628,15 @@ def make_chunk_message(request_id: str, data) -> dict: } -def make_end_message(request_id: str) -> dict: +def make_end_message(request_id) -> dict: """ - Build an end-of-stream message. + Build an end-of-stream message dict. Args: request_id: Request identifier this ends Returns: - dict: End message + dict: End message dict """ return { "type": DirtyProtocol.MSG_TYPE_END, diff --git a/tests/test_dirty_protocol.py b/tests/test_dirty_protocol.py index dbabc51e..48fe3333 100644 --- a/tests/test_dirty_protocol.py +++ b/tests/test_dirty_protocol.py @@ -2,7 +2,7 @@ # This file is part of gunicorn released under the MIT license. # See the NOTICE for more information. -"""Tests for dirty arbiter protocol module.""" +"""Tests for dirty worker binary protocol module.""" import asyncio import os @@ -11,12 +11,23 @@ import struct import pytest from gunicorn.dirty.protocol import ( + BinaryProtocol, DirtyProtocol, make_request, make_response, make_error_response, make_chunk_message, make_end_message, + MAGIC, + VERSION, + HEADER_SIZE, + HEADER_FORMAT, + MSG_TYPE_REQUEST, + MSG_TYPE_RESPONSE, + MSG_TYPE_ERROR, + MSG_TYPE_CHUNK, + MSG_TYPE_END, + MAX_MESSAGE_SIZE, ) from gunicorn.dirty.errors import ( DirtyError, @@ -26,118 +37,194 @@ from gunicorn.dirty.errors import ( ) -class TestDirtyProtocolEncodeDecode: - """Tests for encode/decode functionality.""" +class TestBinaryProtocolHeader: + """Tests for header encoding/decoding.""" - def test_encode_decode_roundtrip(self): - """Test basic encode/decode roundtrip.""" - message = {"type": "request", "id": "123", "data": "hello"} - encoded = DirtyProtocol.encode(message) + def test_header_size(self): + """Test header size is 16 bytes.""" + assert HEADER_SIZE == 16 - # Check header format - assert len(encoded) > DirtyProtocol.HEADER_SIZE - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - encoded[:DirtyProtocol.HEADER_SIZE] - )[0] - assert length == len(encoded) - DirtyProtocol.HEADER_SIZE + def test_encode_header(self): + """Test header encoding.""" + header = BinaryProtocol.encode_header(MSG_TYPE_REQUEST, 12345, 100) + assert len(header) == HEADER_SIZE + assert header[:2] == MAGIC + assert header[2] == VERSION + assert header[3] == MSG_TYPE_REQUEST - # Decode payload - payload = encoded[DirtyProtocol.HEADER_SIZE:] - decoded = DirtyProtocol.decode(payload) - assert decoded == message + def test_decode_header(self): + """Test header decoding.""" + header = BinaryProtocol.encode_header(MSG_TYPE_RESPONSE, 67890, 200) + msg_type, request_id, length = BinaryProtocol.decode_header(header) + assert msg_type == MSG_TYPE_RESPONSE + assert request_id == 67890 + assert length == 200 - def test_encode_decode_complex_data(self): - """Test with complex nested data structures.""" - message = { - "type": "response", - "id": "456", - "result": { - "models": ["gpt-4", "claude-3"], - "config": {"temperature": 0.7, "max_tokens": 1000}, - "metadata": None, - }, - "args": [1, 2, 3], - } - encoded = DirtyProtocol.encode(message) - payload = encoded[DirtyProtocol.HEADER_SIZE:] - decoded = DirtyProtocol.decode(payload) - assert decoded == message + def test_decode_header_invalid_magic(self): + """Test header decoding with invalid magic.""" + header = b"XX" + b"\x01\x01" + b"\x00" * 12 + with pytest.raises(DirtyProtocolError) as exc_info: + BinaryProtocol.decode_header(header) + assert "magic" in str(exc_info.value).lower() - def test_encode_decode_unicode(self): - """Test with unicode characters.""" - message = { - "type": "request", - "data": "Hello, world!" - } - encoded = DirtyProtocol.encode(message) - payload = encoded[DirtyProtocol.HEADER_SIZE:] - decoded = DirtyProtocol.decode(payload) - assert decoded == message + def test_decode_header_invalid_version(self): + """Test header decoding with invalid version.""" + header = MAGIC + b"\x99\x01" + b"\x00" * 12 + with pytest.raises(DirtyProtocolError) as exc_info: + BinaryProtocol.decode_header(header) + assert "version" in str(exc_info.value).lower() - def test_encode_large_message(self): + def test_decode_header_invalid_type(self): + """Test header decoding with invalid message type.""" + header = MAGIC + bytes([VERSION, 0xFF]) + b"\x00" * 12 + with pytest.raises(DirtyProtocolError) as exc_info: + BinaryProtocol.decode_header(header) + assert "type" in str(exc_info.value).lower() + + def test_decode_header_too_large(self): + """Test header decoding rejects too-large messages.""" + header = struct.pack(HEADER_FORMAT, MAGIC, VERSION, MSG_TYPE_REQUEST, + MAX_MESSAGE_SIZE + 1, 0) + with pytest.raises(DirtyProtocolError) as exc_info: + BinaryProtocol.decode_header(header) + assert "too large" in str(exc_info.value).lower() + + def test_decode_header_too_short(self): + """Test header decoding with too-short data.""" + header = MAGIC + b"\x01" + with pytest.raises(DirtyProtocolError) as exc_info: + BinaryProtocol.decode_header(header) + assert "short" in str(exc_info.value).lower() + + +class TestBinaryProtocolEncodeDecode: + """Tests for message encoding/decoding.""" + + def test_encode_decode_request(self): + """Test request encoding/decoding roundtrip.""" + encoded = BinaryProtocol.encode_request( + request_id=12345, + app_path="myapp.ml:MLApp", + action="predict", + args=("data",), + kwargs={"temperature": 0.7} + ) + assert len(encoded) > HEADER_SIZE + + msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded) + assert msg_type_str == "request" + assert request_id == 12345 + assert payload["app_path"] == "myapp.ml:MLApp" + assert payload["action"] == "predict" + assert payload["args"] == ["data"] + assert payload["kwargs"] == {"temperature": 0.7} + + def test_encode_decode_response(self): + """Test response encoding/decoding roundtrip.""" + result = {"predictions": [0.1, 0.9], "metadata": {"model": "v1"}} + encoded = BinaryProtocol.encode_response(request_id=67890, result=result) + + msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded) + assert msg_type_str == "response" + assert request_id == 67890 + assert payload["result"] == result + + def test_encode_decode_error(self): + """Test error encoding/decoding roundtrip.""" + error = DirtyTimeoutError("Timed out", timeout=30) + encoded = BinaryProtocol.encode_error(request_id=11111, error=error) + + msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded) + assert msg_type_str == "error" + assert request_id == 11111 + assert payload["error"]["error_type"] == "DirtyTimeoutError" + assert "Timed out" in payload["error"]["message"] + + def test_encode_decode_chunk(self): + """Test chunk encoding/decoding roundtrip.""" + chunk_data = {"token": "hello", "index": 5} + encoded = BinaryProtocol.encode_chunk(request_id=22222, data=chunk_data) + + msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded) + assert msg_type_str == "chunk" + assert request_id == 22222 + assert payload["data"] == chunk_data + + def test_encode_decode_end(self): + """Test end message encoding/decoding roundtrip.""" + encoded = BinaryProtocol.encode_end(request_id=33333) + assert len(encoded) == HEADER_SIZE # End has no payload + + msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded) + assert msg_type_str == "end" + assert request_id == 33333 + assert payload == {} + + def test_encode_decode_binary_data(self): + """Test binary data passes through without base64 encoding.""" + binary_data = bytes(range(256)) + encoded = BinaryProtocol.encode_response( + request_id=44444, + result={"data": binary_data} + ) + + msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded) + assert payload["result"]["data"] == binary_data + + def test_encode_decode_large_message(self): """Test encoding a large message.""" - large_data = "x" * (1024 * 1024) # 1 MB of data - message = {"type": "request", "data": large_data} - encoded = DirtyProtocol.encode(message) - payload = encoded[DirtyProtocol.HEADER_SIZE:] - decoded = DirtyProtocol.decode(payload) - assert decoded == message + large_data = b"x" * (1024 * 1024) # 1 MB + encoded = BinaryProtocol.encode_response( + request_id=55555, + result={"data": large_data} + ) - def test_encode_empty_dict(self): - """Test encoding an empty dictionary.""" - message = {} - encoded = DirtyProtocol.encode(message) - payload = encoded[DirtyProtocol.HEADER_SIZE:] - decoded = DirtyProtocol.decode(payload) - assert decoded == message - - def test_encode_message_too_large(self): - """Test that encoding a message that's too large raises error.""" - large_data = "x" * (DirtyProtocol.MAX_MESSAGE_SIZE + 1000) - message = {"data": large_data} - with pytest.raises(DirtyProtocolError) as exc_info: - DirtyProtocol.encode(message) - assert "too large" in str(exc_info.value) - - def test_encode_non_serializable(self): - """Test that encoding non-JSON-serializable data raises error.""" - message = {"func": lambda x: x} - with pytest.raises(DirtyProtocolError) as exc_info: - DirtyProtocol.encode(message) - assert "Failed to encode" in str(exc_info.value) - - def test_decode_invalid_json(self): - """Test decoding invalid JSON raises error.""" - invalid_data = b"not valid json" - with pytest.raises(DirtyProtocolError) as exc_info: - DirtyProtocol.decode(invalid_data) - assert "Failed to decode" in str(exc_info.value) - - def test_decode_invalid_unicode(self): - """Test decoding invalid unicode raises error.""" - invalid_data = b"\x80\x81\x82" - with pytest.raises(DirtyProtocolError) as exc_info: - DirtyProtocol.decode(invalid_data) - assert "Failed to decode" in str(exc_info.value) + msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded) + assert payload["result"]["data"] == large_data -class TestDirtyProtocolSync: +class TestBinaryProtocolSync: """Tests for synchronous socket operations.""" def test_read_write_message(self): """Test read/write through socket pair.""" - # Create a socket pair for testing server_sock, client_sock = socket.socketpair() try: - message = {"type": "request", "id": "123", "action": "test"} + message = make_request( + request_id=12345, + app_path="test:App", + action="run" + ) - # Write message - DirtyProtocol.write_message(client_sock, message) + BinaryProtocol.write_message(client_sock, message) + received = BinaryProtocol.read_message(server_sock) - # Read message - received = DirtyProtocol.read_message(server_sock) - assert received == message + assert received["type"] == "request" + assert received["id"] == hash("12345") & 0xFFFFFFFFFFFFFFFF or \ + received["id"] == 12345 + assert received["app_path"] == "test:App" + assert received["action"] == "run" + finally: + server_sock.close() + client_sock.close() + + def test_read_write_with_int_id(self): + """Test read/write with integer request ID.""" + server_sock, client_sock = socket.socketpair() + try: + message = { + "type": "request", + "id": 999888777, + "app_path": "test:App", + "action": "run", + "args": [], + "kwargs": {} + } + + BinaryProtocol.write_message(client_sock, message) + received = BinaryProtocol.read_message(server_sock) + + assert received["id"] == 999888777 finally: server_sock.close() client_sock.close() @@ -147,19 +234,17 @@ class TestDirtyProtocolSync: server_sock, client_sock = socket.socketpair() try: messages = [ - {"type": "request", "id": "1"}, - {"type": "request", "id": "2"}, - {"type": "request", "id": "3"}, + make_request(i, f"app{i}:App", f"action{i}") + for i in range(1, 4) ] - # Write all messages for msg in messages: - DirtyProtocol.write_message(client_sock, msg) + BinaryProtocol.write_message(client_sock, msg) - # Read all messages - for expected in messages: - received = DirtyProtocol.read_message(server_sock) - assert received == expected + for i, _ in enumerate(messages, 1): + received = BinaryProtocol.read_message(server_sock) + assert received["app_path"] == f"app{i}:App" + assert received["action"] == f"action{i}" finally: server_sock.close() client_sock.close() @@ -169,39 +254,51 @@ class TestDirtyProtocolSync: server_sock, client_sock = socket.socketpair() client_sock.close() with pytest.raises(DirtyProtocolError) as exc_info: - DirtyProtocol.read_message(server_sock) + BinaryProtocol.read_message(server_sock) assert "closed" in str(exc_info.value).lower() server_sock.close() + def test_binary_data_roundtrip(self): + """Test binary data roundtrip through socket.""" + server_sock, client_sock = socket.socketpair() + try: + binary_payload = b"\x00\x01\x02\xff\xfe\xfd" + message = make_response(12345, {"binary": binary_payload}) -class TestDirtyProtocolAsync: + BinaryProtocol.write_message(client_sock, message) + received = BinaryProtocol.read_message(server_sock) + + assert received["result"]["binary"] == binary_payload + finally: + server_sock.close() + client_sock.close() + + +class TestBinaryProtocolAsync: """Tests for async stream operations.""" @pytest.mark.asyncio async def test_async_read_write(self): """Test async read/write with mock streams.""" - message = {"type": "request", "id": "123"} + message = make_request(12345, "test:App", "run") - # Create a pipe for testing read_fd, write_fd = os.pipe() try: reader = asyncio.StreamReader() _ = asyncio.StreamReaderProtocol(reader) - # Write the message to the pipe - encoded = DirtyProtocol.encode(message) + encoded = BinaryProtocol._encode_from_dict(message) os.write(write_fd, encoded) os.close(write_fd) write_fd = None - # Feed data to reader data = os.read(read_fd, len(encoded)) reader.feed_data(data) reader.feed_eof() - # Read the message - received = await DirtyProtocol.read_message_async(reader) - assert received == message + received = await BinaryProtocol.read_message_async(reader) + assert received["type"] == "request" + assert received["app_path"] == "test:App" finally: if write_fd is not None: os.close(write_fd) @@ -211,12 +308,11 @@ class TestDirtyProtocolAsync: async def test_async_read_incomplete_header(self): """Test async read with incomplete header.""" reader = asyncio.StreamReader() - # Feed only 2 bytes instead of 4 - reader.feed_data(b"\x00\x00") + reader.feed_data(MAGIC + b"\x01") # Only 3 bytes reader.feed_eof() with pytest.raises((asyncio.IncompleteReadError, DirtyProtocolError)): - await DirtyProtocol.read_message_async(reader) + await BinaryProtocol.read_message_async(reader) @pytest.mark.asyncio async def test_async_read_empty_connection(self): @@ -225,36 +321,33 @@ class TestDirtyProtocolAsync: reader.feed_eof() with pytest.raises(asyncio.IncompleteReadError): - await DirtyProtocol.read_message_async(reader) + await BinaryProtocol.read_message_async(reader) + + @pytest.mark.asyncio + async def test_async_read_invalid_magic(self): + """Test async read rejects invalid magic.""" + reader = asyncio.StreamReader() + header = b"XX" + bytes([VERSION, MSG_TYPE_REQUEST]) + b"\x00" * 12 + reader.feed_data(header) + reader.feed_eof() + + with pytest.raises(DirtyProtocolError) as exc_info: + await BinaryProtocol.read_message_async(reader) + assert "magic" in str(exc_info.value).lower() @pytest.mark.asyncio async def test_async_read_message_too_large(self): """Test async read rejects too-large messages.""" reader = asyncio.StreamReader() - # Create a header claiming an absurdly large message - header = struct.pack( - DirtyProtocol.HEADER_FORMAT, - DirtyProtocol.MAX_MESSAGE_SIZE + 1000 - ) + header = struct.pack(HEADER_FORMAT, MAGIC, VERSION, MSG_TYPE_REQUEST, + MAX_MESSAGE_SIZE + 1000, 0) reader.feed_data(header) reader.feed_eof() with pytest.raises(DirtyProtocolError) as exc_info: - await DirtyProtocol.read_message_async(reader) + await BinaryProtocol.read_message_async(reader) assert "too large" in str(exc_info.value) - @pytest.mark.asyncio - async def test_async_read_empty_message(self): - """Test async read rejects empty messages.""" - reader = asyncio.StreamReader() - header = struct.pack(DirtyProtocol.HEADER_FORMAT, 0) - reader.feed_data(header) - reader.feed_eof() - - with pytest.raises(DirtyProtocolError) as exc_info: - await DirtyProtocol.read_message_async(reader) - assert "Empty message" in str(exc_info.value) - class TestMessageBuilders: """Tests for message builder helper functions.""" @@ -340,9 +433,9 @@ class TestMessageBuilders: assert chunk["id"] == "req-456" assert chunk["data"] == data - def test_make_chunk_message_with_list_data(self): - """Test chunk message with list data.""" - data = [1, 2, 3, "token"] + def test_make_chunk_message_with_binary_data(self): + """Test chunk message with binary data.""" + data = b"\x00\x01\x02\xff" chunk = make_chunk_message("req-789", data) assert chunk["data"] == data @@ -353,22 +446,22 @@ class TestMessageBuilders: assert end["id"] == "req-123" assert "data" not in end - def test_chunk_and_end_encode_decode(self): + def test_chunk_and_end_roundtrip(self): """Test that chunk and end messages can be encoded/decoded.""" - chunk = make_chunk_message("req-123", {"token": "hello"}) - end = make_end_message("req-123") + chunk = make_chunk_message(12345, {"token": "hello"}) + end = make_end_message(12345) # Test chunk roundtrip - encoded_chunk = DirtyProtocol.encode(chunk) - payload = encoded_chunk[DirtyProtocol.HEADER_SIZE:] - decoded = DirtyProtocol.decode(payload) - assert decoded == chunk + encoded_chunk = BinaryProtocol._encode_from_dict(chunk) + msg_type, req_id, payload = BinaryProtocol.decode_message(encoded_chunk) + assert msg_type == "chunk" + assert payload["data"] == {"token": "hello"} # Test end roundtrip - encoded_end = DirtyProtocol.encode(end) - payload = encoded_end[DirtyProtocol.HEADER_SIZE:] - decoded = DirtyProtocol.decode(payload) - assert decoded == end + encoded_end = BinaryProtocol._encode_from_dict(end) + msg_type, req_id, payload = BinaryProtocol.decode_message(encoded_end) + assert msg_type == "end" + assert payload == {} class TestDirtyErrors: @@ -417,3 +510,58 @@ class TestDirtyErrors: assert error.action == "run" assert error.traceback == "Traceback..." assert "myapp:App" in str(error) + + +class TestBackwardsCompatibility: + """Tests for backwards compatibility with old JSON API.""" + + def test_dirty_protocol_alias(self): + """Test that DirtyProtocol is an alias for BinaryProtocol.""" + assert DirtyProtocol is BinaryProtocol + + def test_header_size_attribute(self): + """Test HEADER_SIZE is accessible on class.""" + assert DirtyProtocol.HEADER_SIZE == 16 + + def test_msg_type_constants(self): + """Test message type constants are strings for compatibility.""" + assert DirtyProtocol.MSG_TYPE_REQUEST == "request" + assert DirtyProtocol.MSG_TYPE_RESPONSE == "response" + assert DirtyProtocol.MSG_TYPE_ERROR == "error" + assert DirtyProtocol.MSG_TYPE_CHUNK == "chunk" + assert DirtyProtocol.MSG_TYPE_END == "end" + + def test_encode_decode_preserves_dict_format(self): + """Test that read_message returns dict compatible with old API.""" + server_sock, client_sock = socket.socketpair() + try: + message = { + "type": "response", + "id": 12345, + "result": {"status": "ok"} + } + + DirtyProtocol.write_message(client_sock, message) + received = DirtyProtocol.read_message(server_sock) + + # Old API: access via dict keys + assert received["type"] == "response" + assert received["result"]["status"] == "ok" + finally: + server_sock.close() + client_sock.close() + + def test_string_request_id_handled(self): + """Test that string request IDs are handled (hashed to int).""" + server_sock, client_sock = socket.socketpair() + try: + message = make_request("uuid-string-id", "test:App", "run") + + DirtyProtocol.write_message(client_sock, message) + received = DirtyProtocol.read_message(server_sock) + + # Request ID should be converted to int + assert isinstance(received["id"], int) + finally: + server_sock.close() + client_sock.close()