feat(dirty): implement binary protocol

Replace JSON-based protocol with binary format using 16-byte header:
- Magic bytes (GD), version, message type, payload length, request ID
- TLV-encoded payloads for efficient binary data transfer
- No base64 encoding needed for binary data
- Backwards compatible API (DirtyProtocol alias, dict-based interface)

Header format inspired by OpenBSD msgctl/msgsnd.
This commit is contained in:
Benoit Chesneau 2026-02-11 22:58:43 +01:00
parent 0e0dc669c8
commit 1665857c0e
2 changed files with 711 additions and 267 deletions

View File

@ -3,89 +3,304 @@
# See the NOTICE for more information.
"""
Dirty Arbiters Protocol
Dirty Worker Binary Protocol
Length-prefixed JSON message framing over Unix sockets.
Provides both async (primary) and sync (for HTTP workers) APIs.
Binary message framing over Unix sockets, inspired by OpenBSD msgctl/msgsnd.
Replaces JSON protocol for efficient binary data transfer.
Message Format:
+----------------+------------------+
| 4-byte length | JSON payload |
+----------------+------------------+
Header Format (16 bytes):
+--------+--------+--------+--------+--------+--------+--------+--------+
| Magic (2B) | Ver(1) | MType | Payload Length (4B) |
+--------+--------+--------+--------+--------+--------+--------+--------+
| Request ID (8 bytes) |
+--------+--------+--------+--------+--------+--------+--------+--------+
The length field is a 4-byte unsigned integer in network byte order (big-endian).
- Magic: 0x47 0x44 ("GD" for Gunicorn Dirty)
- Version: 0x01
- MType: Message type (REQUEST, RESPONSE, ERROR, CHUNK, END)
- Length: Payload size (big-endian uint32, max 64MB)
- Request ID: uint64 (replaces UUID string)
Payload is TLV-encoded (see tlv.py).
"""
import asyncio
import json
import struct
import socket
import struct
from .errors import DirtyProtocolError
from .tlv import TLVEncoder
class DirtyProtocol:
"""Length-prefixed JSON messages over Unix sockets."""
# Protocol constants
MAGIC = b"GD" # 0x47 0x44
VERSION = 0x01
# 4-byte unsigned int, network byte order (big-endian)
HEADER_FORMAT = "!I"
HEADER_SIZE = struct.calcsize(HEADER_FORMAT)
# Message types (1 byte)
MSG_TYPE_REQUEST = 0x01
MSG_TYPE_RESPONSE = 0x02
MSG_TYPE_ERROR = 0x03
MSG_TYPE_CHUNK = 0x04
MSG_TYPE_END = 0x05
# Maximum message size (64 MB)
MAX_MESSAGE_SIZE = 64 * 1024 * 1024
# Message type names (for backwards compatibility with old API)
MSG_TYPE_REQUEST_STR = "request"
MSG_TYPE_RESPONSE_STR = "response"
MSG_TYPE_ERROR_STR = "error"
MSG_TYPE_CHUNK_STR = "chunk"
MSG_TYPE_END_STR = "end"
# Message types for future streaming support
MSG_TYPE_REQUEST = "request"
MSG_TYPE_RESPONSE = "response"
MSG_TYPE_ERROR = "error"
MSG_TYPE_CHUNK = "chunk"
MSG_TYPE_END = "end"
# Map int types to string names
MSG_TYPE_TO_STR = {
MSG_TYPE_REQUEST: MSG_TYPE_REQUEST_STR,
MSG_TYPE_RESPONSE: MSG_TYPE_RESPONSE_STR,
MSG_TYPE_ERROR: MSG_TYPE_ERROR_STR,
MSG_TYPE_CHUNK: MSG_TYPE_CHUNK_STR,
MSG_TYPE_END: MSG_TYPE_END_STR,
}
# Map string names to int types
MSG_TYPE_FROM_STR = {v: k for k, v in MSG_TYPE_TO_STR.items()}
# Header format: Magic (2) + Version (1) + Type (1) + Length (4) + RequestID (8) = 16
HEADER_FORMAT = ">2sBBIQ"
HEADER_SIZE = struct.calcsize(HEADER_FORMAT)
# Maximum message size (64 MB)
MAX_MESSAGE_SIZE = 64 * 1024 * 1024
class BinaryProtocol:
"""Binary message protocol for dirty worker IPC."""
# Export constants for external use
HEADER_SIZE = HEADER_SIZE
MAX_MESSAGE_SIZE = MAX_MESSAGE_SIZE
MSG_TYPE_REQUEST = MSG_TYPE_REQUEST_STR
MSG_TYPE_RESPONSE = MSG_TYPE_RESPONSE_STR
MSG_TYPE_ERROR = MSG_TYPE_ERROR_STR
MSG_TYPE_CHUNK = MSG_TYPE_CHUNK_STR
MSG_TYPE_END = MSG_TYPE_END_STR
@staticmethod
def encode_header(msg_type: int, request_id: int, payload_length: int) -> bytes:
    """
    Encode the 16-byte message header.

    Args:
        msg_type: Message type (MSG_TYPE_REQUEST, etc.)
        request_id: Unique request identifier (uint64)
        payload_length: Length of the TLV-encoded payload

    Returns:
        bytes: 16-byte header

    Raises:
        DirtyProtocolError: If the payload exceeds MAX_MESSAGE_SIZE or a
            field does not fit its fixed-width slot in the header
    """
    # Enforce the size limit on the sending side as well: decode_header()
    # rejects anything above MAX_MESSAGE_SIZE, so writing such a message
    # would only fail later on the peer.
    if payload_length > MAX_MESSAGE_SIZE:
        raise DirtyProtocolError(
            f"Message too large: {payload_length} bytes "
            f"(max: {MAX_MESSAGE_SIZE})"
        )

    try:
        return struct.pack(HEADER_FORMAT, MAGIC, VERSION, msg_type,
                           payload_length, request_id)
    except struct.error as e:
        # Out-of-range msg_type / request_id / length: surface it as a
        # protocol error instead of leaking struct.error to callers.
        raise DirtyProtocolError(f"Failed to encode header: {e}") from e
@staticmethod
def decode_header(data: bytes) -> tuple:
    """
    Parse and validate the fixed-size message header.

    Only the first HEADER_SIZE bytes of ``data`` are interpreted; any
    trailing bytes are ignored here.

    Args:
        data: Buffer starting with the 16-byte header

    Returns:
        tuple: (msg_type, request_id, payload_length)

    Raises:
        DirtyProtocolError: If the buffer is too short, or the magic,
            version, message type or payload length is not acceptable
    """
    available = len(data)
    if available < HEADER_SIZE:
        raise DirtyProtocolError(
            f"Header too short: {available} bytes, expected {HEADER_SIZE}",
            raw_data=data
        )

    fields = struct.unpack(HEADER_FORMAT, data[:HEADER_SIZE])
    magic, version, msg_type, length, request_id = fields

    # Checks run in wire order: magic, version, type, then length.
    if magic != MAGIC:
        raise DirtyProtocolError(
            f"Invalid magic: {magic!r}, expected {MAGIC!r}",
            raw_data=data[:20]
        )

    if version != VERSION:
        raise DirtyProtocolError(
            f"Unsupported protocol version: {version}, expected {VERSION}",
            raw_data=data[:20]
        )

    if msg_type not in MSG_TYPE_TO_STR:
        raise DirtyProtocolError(
            f"Unknown message type: 0x{msg_type:02x}",
            raw_data=data[:20]
        )

    if length > MAX_MESSAGE_SIZE:
        raise DirtyProtocolError(
            f"Message too large: {length} bytes (max: {MAX_MESSAGE_SIZE})"
        )

    return msg_type, request_id, length
@staticmethod
def encode_request(request_id: int, app_path: str, action: str,
                   args: tuple = None, kwargs: dict = None) -> bytes:
    """
    Build the wire bytes for a request message.

    The payload is a TLV-encoded dict with the keys ``app_path``,
    ``action``, ``args`` (always a list) and ``kwargs`` (always a dict).

    Args:
        request_id: Unique request identifier (uint64)
        app_path: Import path of the dirty app
        action: Action to call on the app
        args: Positional arguments; a falsy value is sent as an empty list
        kwargs: Keyword arguments; a falsy value is sent as an empty dict

    Returns:
        bytes: Complete message (header + payload)
    """
    positional = list(args) if args else []
    keyword = kwargs or {}

    body = TLVEncoder.encode({
        "app_path": app_path,
        "action": action,
        "args": positional,
        "kwargs": keyword,
    })
    prefix = BinaryProtocol.encode_header(MSG_TYPE_REQUEST, request_id,
                                          len(body))
    return prefix + body
@staticmethod
def encode_response(request_id: int, result) -> bytes:
    """
    Build the wire bytes for a success response.

    The payload is a TLV-encoded dict holding the value under ``result``.

    Args:
        request_id: Identifier of the request being answered
        result: Result value (must be TLV-serializable)

    Returns:
        bytes: Complete message (header + payload)
    """
    body = TLVEncoder.encode({"result": result})
    prefix = BinaryProtocol.encode_header(MSG_TYPE_RESPONSE, request_id,
                                          len(body))
    return prefix + body
@staticmethod
def encode_error(request_id: int, error) -> bytes:
    """
    Build the wire bytes for an error response.

    The payload is a TLV-encoded dict holding the error description under
    ``error``. A DirtyError is described by its ``to_dict()``, a dict is
    sent as given, and any other object is described by its type name and
    ``str()`` with empty details.

    Args:
        request_id: Identifier of the request being answered
        error: DirtyError instance, dict, or Exception

    Returns:
        bytes: Complete message (header + payload)
    """
    from .errors import DirtyError

    if not isinstance(error, (DirtyError, dict)):
        described = {
            "error_type": type(error).__name__,
            "message": str(error),
            "details": {},
        }
    elif isinstance(error, DirtyError):
        described = error.to_dict()
    else:
        described = error

    body = TLVEncoder.encode({"error": described})
    prefix = BinaryProtocol.encode_header(MSG_TYPE_ERROR, request_id,
                                          len(body))
    return prefix + body
@staticmethod
def encode_chunk(request_id: int, data) -> bytes:
    """
    Build the wire bytes for one chunk of a streaming response.

    The payload is a TLV-encoded dict holding the chunk under ``data``.

    Args:
        request_id: Identifier of the request this chunk belongs to
        data: Chunk data (must be TLV-serializable)

    Returns:
        bytes: Complete message (header + payload)
    """
    body = TLVEncoder.encode({"data": data})
    prefix = BinaryProtocol.encode_header(MSG_TYPE_CHUNK, request_id,
                                          len(body))
    return prefix + body
@staticmethod
def encode_end(request_id: int) -> bytes:
    """
    Build the wire bytes for an end-of-stream marker.

    An end message carries no payload, so the result is just the header
    with a payload length of zero.

    Args:
        request_id: Identifier of the request whose stream is finished

    Returns:
        bytes: Complete message (header only)
    """
    return BinaryProtocol.encode_header(MSG_TYPE_END, request_id, 0)
@staticmethod
def decode_message(data: bytes) -> tuple:
"""
Decode a complete message (header + payload).
Args:
data: Complete message bytes
Returns:
tuple: (msg_type_str, request_id, payload_dict)
msg_type_str is the string name (e.g., "request")
payload_dict is the decoded TLV payload as a dict
Raises:
DirtyProtocolError: If message is malformed
"""
msg_type, request_id, length = BinaryProtocol.decode_header(data)
if len(data) < HEADER_SIZE + length:
raise DirtyProtocolError(
f"Incomplete message: expected {HEADER_SIZE + length} bytes, "
f"got {len(data)}",
raw_data=data[:50]
)
if length == 0:
# End message has empty payload
payload_dict = {}
else:
payload_data = data[HEADER_SIZE:HEADER_SIZE + length]
try:
payload_dict = TLVEncoder.decode_full(payload_data)
except DirtyProtocolError:
raise
except Exception as e:
raise DirtyProtocolError(
f"Message too large: {len(payload)} bytes "
f"(max: {DirtyProtocol.MAX_MESSAGE_SIZE})"
f"Failed to decode TLV payload: {e}",
raw_data=payload_data[:50]
)
length = struct.pack(DirtyProtocol.HEADER_FORMAT, len(payload))
return length + payload
except (TypeError, ValueError) as e:
raise DirtyProtocolError(f"Failed to encode message: {e}")
@staticmethod
def decode(data: bytes) -> dict:
"""
Decode bytes (without length prefix) to message dict.
# Convert to dict format similar to old JSON protocol
msg_type_str = MSG_TYPE_TO_STR[msg_type]
Args:
data: JSON bytes to decode
Returns:
dict: Decoded message
Raises:
DirtyProtocolError: If decoding fails
"""
try:
return json.loads(data.decode("utf-8"))
except (json.JSONDecodeError, UnicodeDecodeError) as e:
raise DirtyProtocolError(f"Failed to decode message: {e}",
raw_data=data)
return msg_type_str, request_id, payload_dict
# -------------------------------------------------------------------------
# Async API (primary - for DirtyArbiter and DirtyWorker)
@ -94,53 +309,62 @@ class DirtyProtocol:
@staticmethod
async def read_message_async(reader: asyncio.StreamReader) -> dict:
"""
Read a complete message from async stream.
Read a complete binary message from async stream.
Args:
reader: asyncio StreamReader
Returns:
dict: Decoded message
dict: Message dict with 'type', 'id', and payload fields
Raises:
DirtyProtocolError: If read fails or message is malformed
asyncio.IncompleteReadError: If connection closed mid-read
"""
# Read length header
# Read header
try:
header = await reader.readexactly(DirtyProtocol.HEADER_SIZE)
header = await reader.readexactly(HEADER_SIZE)
except asyncio.IncompleteReadError as e:
if len(e.partial) == 0:
# Clean close - no data was read
raise
raise DirtyProtocolError(
f"Incomplete header: got {len(e.partial)} bytes, "
f"expected {DirtyProtocol.HEADER_SIZE}",
f"expected {HEADER_SIZE}",
raw_data=e.partial
)
length = struct.unpack(DirtyProtocol.HEADER_FORMAT, header)[0]
if length > DirtyProtocol.MAX_MESSAGE_SIZE:
raise DirtyProtocolError(
f"Message too large: {length} bytes "
f"(max: {DirtyProtocol.MAX_MESSAGE_SIZE})"
)
if length == 0:
raise DirtyProtocolError("Empty message received")
msg_type, request_id, length = BinaryProtocol.decode_header(header)
# Read payload
try:
payload = await reader.readexactly(length)
except asyncio.IncompleteReadError as e:
raise DirtyProtocolError(
f"Incomplete message: got {len(e.partial)} bytes, "
f"expected {length}",
raw_data=e.partial
)
if length > 0:
try:
payload_data = await reader.readexactly(length)
except asyncio.IncompleteReadError as e:
raise DirtyProtocolError(
f"Incomplete payload: got {len(e.partial)} bytes, "
f"expected {length}",
raw_data=e.partial
)
return DirtyProtocol.decode(payload)
try:
payload_dict = TLVEncoder.decode_full(payload_data)
except DirtyProtocolError:
raise
except Exception as e:
raise DirtyProtocolError(
f"Failed to decode TLV payload: {e}",
raw_data=payload_data[:50]
)
else:
payload_dict = {}
# Build response dict
msg_type_str = MSG_TYPE_TO_STR[msg_type]
result = {"type": msg_type_str, "id": request_id}
result.update(payload_dict)
return result
@staticmethod
async def write_message_async(writer: asyncio.StreamWriter,
@ -148,15 +372,17 @@ class DirtyProtocol:
"""
Write a message to async stream.
Accepts dict format for backwards compatibility.
Args:
writer: asyncio StreamWriter
message: Dictionary to send
message: Message dict with 'type', 'id', and payload fields
Raises:
DirtyProtocolError: If encoding fails
ConnectionError: If write fails
"""
data = DirtyProtocol.encode(message)
data = BinaryProtocol._encode_from_dict(message)
writer.write(data)
await writer.drain()
@ -201,27 +427,36 @@ class DirtyProtocol:
sock: Socket to read from
Returns:
dict: Decoded message
dict: Message dict with 'type', 'id', and payload fields
Raises:
DirtyProtocolError: If read fails or message is malformed
"""
# Read length header
header = DirtyProtocol._recv_exactly(sock, DirtyProtocol.HEADER_SIZE)
length = struct.unpack(DirtyProtocol.HEADER_FORMAT, header)[0]
if length > DirtyProtocol.MAX_MESSAGE_SIZE:
raise DirtyProtocolError(
f"Message too large: {length} bytes "
f"(max: {DirtyProtocol.MAX_MESSAGE_SIZE})"
)
if length == 0:
raise DirtyProtocolError("Empty message received")
# Read header
header = BinaryProtocol._recv_exactly(sock, HEADER_SIZE)
msg_type, request_id, length = BinaryProtocol.decode_header(header)
# Read payload
payload = DirtyProtocol._recv_exactly(sock, length)
return DirtyProtocol.decode(payload)
if length > 0:
payload_data = BinaryProtocol._recv_exactly(sock, length)
try:
payload_dict = TLVEncoder.decode_full(payload_data)
except DirtyProtocolError:
raise
except Exception as e:
raise DirtyProtocolError(
f"Failed to decode TLV payload: {e}",
raw_data=payload_data[:50]
)
else:
payload_dict = {}
# Build response dict
msg_type_str = MSG_TYPE_TO_STR[msg_type]
result = {"type": msg_type_str, "id": request_id}
result.update(payload_dict)
return result
@staticmethod
def write_message(sock: socket.socket, message: dict) -> None:
@ -230,31 +465,92 @@ class DirtyProtocol:
Args:
sock: Socket to write to
message: Dictionary to send
message: Message dict with 'type', 'id', and payload fields
Raises:
DirtyProtocolError: If encoding fails
OSError: If write fails
"""
data = DirtyProtocol.encode(message)
data = BinaryProtocol._encode_from_dict(message)
sock.sendall(data)
@staticmethod
def _encode_from_dict(message: dict) -> bytes:
    """
    Encode a message dict to binary format.

    Supports the old dict-based API for backwards compatibility.

    String request IDs (e.g. UUID strings) are mapped to a uint64 with a
    stable 64-bit BLAKE2b digest, so the same string always produces the
    same wire ID. The built-in hash() is salted per interpreter and would
    give a different ID in every process that was not forked from the
    same parent. The decoded message carries the integer ID, not the
    original string.

    Args:
        message: Message dict with 'type', 'id', and payload fields

    Returns:
        bytes: Complete encoded message

    Raises:
        DirtyProtocolError: If the message type is missing or unknown
    """
    import hashlib

    msg_type_str = message.get("type")
    request_id = message.get("id", 0)

    # Handle string or int request IDs
    if isinstance(request_id, str):
        digest = hashlib.blake2b(request_id.encode("utf-8"),
                                 digest_size=8).digest()
        request_id = int.from_bytes(digest, "big")

    msg_type = MSG_TYPE_FROM_STR.get(msg_type_str)
    if msg_type is None:
        raise DirtyProtocolError(f"Unknown message type: {msg_type_str}")

    if msg_type == MSG_TYPE_REQUEST:
        return BinaryProtocol.encode_request(
            request_id,
            message.get("app_path", ""),
            message.get("action", ""),
            message.get("args"),
            message.get("kwargs")
        )
    elif msg_type == MSG_TYPE_RESPONSE:
        return BinaryProtocol.encode_response(
            request_id,
            message.get("result")
        )
    elif msg_type == MSG_TYPE_ERROR:
        return BinaryProtocol.encode_error(
            request_id,
            message.get("error", {})
        )
    elif msg_type == MSG_TYPE_CHUNK:
        return BinaryProtocol.encode_chunk(
            request_id,
            message.get("data")
        )
    elif msg_type == MSG_TYPE_END:
        return BinaryProtocol.encode_end(request_id)
    else:
        # Defensive: only reachable if MSG_TYPE_FROM_STR gains a type
        # that has no encoder branch above.
        raise DirtyProtocolError(f"Unhandled message type: {msg_type}")
# =============================================================================
# Backwards Compatibility Aliases
# =============================================================================
# Alias BinaryProtocol as DirtyProtocol for drop-in replacement
DirtyProtocol = BinaryProtocol
# Message builder helpers (backwards compatible with old API)
def make_request(request_id, app_path: str, action: str,
args: tuple = None, kwargs: dict = None) -> dict:
"""
Build a request message.
Build a request message dict.
Args:
request_id: Unique request identifier
request_id: Unique request identifier (int or str)
app_path: Import path of the dirty app (e.g., 'myapp.ml:MLApp')
action: Action to call on the app
args: Positional arguments
kwargs: Keyword arguments
Returns:
dict: Request message
dict: Request message dict
"""
return {
"type": DirtyProtocol.MSG_TYPE_REQUEST,
@ -266,16 +562,16 @@ def make_request(request_id: str, app_path: str, action: str,
}
def make_response(request_id: str, result) -> dict:
def make_response(request_id, result) -> dict:
"""
Build a success response message.
Build a success response message dict.
Args:
request_id: Request identifier this responds to
result: Result value (must be JSON-serializable)
result: Result value
Returns:
dict: Response message
dict: Response message dict
"""
return {
"type": DirtyProtocol.MSG_TYPE_RESPONSE,
@ -284,16 +580,16 @@ def make_response(request_id: str, result) -> dict:
}
def make_error_response(request_id: str, error) -> dict:
def make_error_response(request_id, error) -> dict:
"""
Build an error response message.
Build an error response message dict.
Args:
request_id: Request identifier this responds to
error: DirtyError instance or dict with error info
Returns:
dict: Error response message
dict: Error response message dict
"""
from .errors import DirtyError
if isinstance(error, DirtyError):
@ -314,16 +610,16 @@ def make_error_response(request_id: str, error) -> dict:
}
def make_chunk_message(request_id: str, data) -> dict:
def make_chunk_message(request_id, data) -> dict:
"""
Build a chunk message for streaming responses.
Build a chunk message dict for streaming responses.
Args:
request_id: Request identifier this chunk belongs to
data: Chunk data (must be JSON-serializable)
data: Chunk data
Returns:
dict: Chunk message
dict: Chunk message dict
"""
return {
"type": DirtyProtocol.MSG_TYPE_CHUNK,
@ -332,15 +628,15 @@ def make_chunk_message(request_id: str, data) -> dict:
}
def make_end_message(request_id: str) -> dict:
def make_end_message(request_id) -> dict:
"""
Build an end-of-stream message.
Build an end-of-stream message dict.
Args:
request_id: Request identifier this ends
Returns:
dict: End message
dict: End message dict
"""
return {
"type": DirtyProtocol.MSG_TYPE_END,

View File

@ -2,7 +2,7 @@
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""Tests for dirty arbiter protocol module."""
"""Tests for dirty worker binary protocol module."""
import asyncio
import os
@ -11,12 +11,23 @@ import struct
import pytest
from gunicorn.dirty.protocol import (
BinaryProtocol,
DirtyProtocol,
make_request,
make_response,
make_error_response,
make_chunk_message,
make_end_message,
MAGIC,
VERSION,
HEADER_SIZE,
HEADER_FORMAT,
MSG_TYPE_REQUEST,
MSG_TYPE_RESPONSE,
MSG_TYPE_ERROR,
MSG_TYPE_CHUNK,
MSG_TYPE_END,
MAX_MESSAGE_SIZE,
)
from gunicorn.dirty.errors import (
DirtyError,
@ -26,118 +37,194 @@ from gunicorn.dirty.errors import (
)
class TestDirtyProtocolEncodeDecode:
"""Tests for encode/decode functionality."""
class TestBinaryProtocolHeader:
"""Tests for header encoding/decoding."""
def test_encode_decode_roundtrip(self):
"""Test basic encode/decode roundtrip."""
message = {"type": "request", "id": "123", "data": "hello"}
encoded = DirtyProtocol.encode(message)
def test_header_size(self):
"""Test header size is 16 bytes."""
assert HEADER_SIZE == 16
# Check header format
assert len(encoded) > DirtyProtocol.HEADER_SIZE
length = struct.unpack(
DirtyProtocol.HEADER_FORMAT,
encoded[:DirtyProtocol.HEADER_SIZE]
)[0]
assert length == len(encoded) - DirtyProtocol.HEADER_SIZE
def test_encode_header(self):
"""Test header encoding."""
header = BinaryProtocol.encode_header(MSG_TYPE_REQUEST, 12345, 100)
assert len(header) == HEADER_SIZE
assert header[:2] == MAGIC
assert header[2] == VERSION
assert header[3] == MSG_TYPE_REQUEST
# Decode payload
payload = encoded[DirtyProtocol.HEADER_SIZE:]
decoded = DirtyProtocol.decode(payload)
assert decoded == message
def test_decode_header(self):
"""Test header decoding."""
header = BinaryProtocol.encode_header(MSG_TYPE_RESPONSE, 67890, 200)
msg_type, request_id, length = BinaryProtocol.decode_header(header)
assert msg_type == MSG_TYPE_RESPONSE
assert request_id == 67890
assert length == 200
def test_encode_decode_complex_data(self):
"""Test with complex nested data structures."""
message = {
"type": "response",
"id": "456",
"result": {
"models": ["gpt-4", "claude-3"],
"config": {"temperature": 0.7, "max_tokens": 1000},
"metadata": None,
},
"args": [1, 2, 3],
}
encoded = DirtyProtocol.encode(message)
payload = encoded[DirtyProtocol.HEADER_SIZE:]
decoded = DirtyProtocol.decode(payload)
assert decoded == message
def test_decode_header_invalid_magic(self):
"""Test header decoding with invalid magic."""
header = b"XX" + b"\x01\x01" + b"\x00" * 12
with pytest.raises(DirtyProtocolError) as exc_info:
BinaryProtocol.decode_header(header)
assert "magic" in str(exc_info.value).lower()
def test_encode_decode_unicode(self):
"""Test with unicode characters."""
message = {
"type": "request",
"data": "Hello, world!"
}
encoded = DirtyProtocol.encode(message)
payload = encoded[DirtyProtocol.HEADER_SIZE:]
decoded = DirtyProtocol.decode(payload)
assert decoded == message
def test_decode_header_invalid_version(self):
"""Test header decoding with invalid version."""
header = MAGIC + b"\x99\x01" + b"\x00" * 12
with pytest.raises(DirtyProtocolError) as exc_info:
BinaryProtocol.decode_header(header)
assert "version" in str(exc_info.value).lower()
def test_encode_large_message(self):
def test_decode_header_invalid_type(self):
"""Test header decoding with invalid message type."""
header = MAGIC + bytes([VERSION, 0xFF]) + b"\x00" * 12
with pytest.raises(DirtyProtocolError) as exc_info:
BinaryProtocol.decode_header(header)
assert "type" in str(exc_info.value).lower()
def test_decode_header_too_large(self):
"""Test header decoding rejects too-large messages."""
header = struct.pack(HEADER_FORMAT, MAGIC, VERSION, MSG_TYPE_REQUEST,
MAX_MESSAGE_SIZE + 1, 0)
with pytest.raises(DirtyProtocolError) as exc_info:
BinaryProtocol.decode_header(header)
assert "too large" in str(exc_info.value).lower()
def test_decode_header_too_short(self):
"""Test header decoding with too-short data."""
header = MAGIC + b"\x01"
with pytest.raises(DirtyProtocolError) as exc_info:
BinaryProtocol.decode_header(header)
assert "short" in str(exc_info.value).lower()
class TestBinaryProtocolEncodeDecode:
"""Tests for message encoding/decoding."""
def test_encode_decode_request(self):
"""Test request encoding/decoding roundtrip."""
encoded = BinaryProtocol.encode_request(
request_id=12345,
app_path="myapp.ml:MLApp",
action="predict",
args=("data",),
kwargs={"temperature": 0.7}
)
assert len(encoded) > HEADER_SIZE
msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
assert msg_type_str == "request"
assert request_id == 12345
assert payload["app_path"] == "myapp.ml:MLApp"
assert payload["action"] == "predict"
assert payload["args"] == ["data"]
assert payload["kwargs"] == {"temperature": 0.7}
def test_encode_decode_response(self):
"""Test response encoding/decoding roundtrip."""
result = {"predictions": [0.1, 0.9], "metadata": {"model": "v1"}}
encoded = BinaryProtocol.encode_response(request_id=67890, result=result)
msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
assert msg_type_str == "response"
assert request_id == 67890
assert payload["result"] == result
def test_encode_decode_error(self):
"""Test error encoding/decoding roundtrip."""
error = DirtyTimeoutError("Timed out", timeout=30)
encoded = BinaryProtocol.encode_error(request_id=11111, error=error)
msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
assert msg_type_str == "error"
assert request_id == 11111
assert payload["error"]["error_type"] == "DirtyTimeoutError"
assert "Timed out" in payload["error"]["message"]
def test_encode_decode_chunk(self):
"""Test chunk encoding/decoding roundtrip."""
chunk_data = {"token": "hello", "index": 5}
encoded = BinaryProtocol.encode_chunk(request_id=22222, data=chunk_data)
msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
assert msg_type_str == "chunk"
assert request_id == 22222
assert payload["data"] == chunk_data
def test_encode_decode_end(self):
"""Test end message encoding/decoding roundtrip."""
encoded = BinaryProtocol.encode_end(request_id=33333)
assert len(encoded) == HEADER_SIZE # End has no payload
msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
assert msg_type_str == "end"
assert request_id == 33333
assert payload == {}
def test_encode_decode_binary_data(self):
"""Test binary data passes through without base64 encoding."""
binary_data = bytes(range(256))
encoded = BinaryProtocol.encode_response(
request_id=44444,
result={"data": binary_data}
)
msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
assert payload["result"]["data"] == binary_data
def test_encode_decode_large_message(self):
"""Test encoding a large message."""
large_data = "x" * (1024 * 1024) # 1 MB of data
message = {"type": "request", "data": large_data}
encoded = DirtyProtocol.encode(message)
payload = encoded[DirtyProtocol.HEADER_SIZE:]
decoded = DirtyProtocol.decode(payload)
assert decoded == message
large_data = b"x" * (1024 * 1024) # 1 MB
encoded = BinaryProtocol.encode_response(
request_id=55555,
result={"data": large_data}
)
def test_encode_empty_dict(self):
"""Test encoding an empty dictionary."""
message = {}
encoded = DirtyProtocol.encode(message)
payload = encoded[DirtyProtocol.HEADER_SIZE:]
decoded = DirtyProtocol.decode(payload)
assert decoded == message
def test_encode_message_too_large(self):
"""Test that encoding a message that's too large raises error."""
large_data = "x" * (DirtyProtocol.MAX_MESSAGE_SIZE + 1000)
message = {"data": large_data}
with pytest.raises(DirtyProtocolError) as exc_info:
DirtyProtocol.encode(message)
assert "too large" in str(exc_info.value)
def test_encode_non_serializable(self):
"""Test that encoding non-JSON-serializable data raises error."""
message = {"func": lambda x: x}
with pytest.raises(DirtyProtocolError) as exc_info:
DirtyProtocol.encode(message)
assert "Failed to encode" in str(exc_info.value)
def test_decode_invalid_json(self):
"""Test decoding invalid JSON raises error."""
invalid_data = b"not valid json"
with pytest.raises(DirtyProtocolError) as exc_info:
DirtyProtocol.decode(invalid_data)
assert "Failed to decode" in str(exc_info.value)
def test_decode_invalid_unicode(self):
"""Test decoding invalid unicode raises error."""
invalid_data = b"\x80\x81\x82"
with pytest.raises(DirtyProtocolError) as exc_info:
DirtyProtocol.decode(invalid_data)
assert "Failed to decode" in str(exc_info.value)
msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
assert payload["result"]["data"] == large_data
class TestDirtyProtocolSync:
class TestBinaryProtocolSync:
"""Tests for synchronous socket operations."""
def test_read_write_message(self):
    """Test read/write through socket pair."""
    # Create a socket pair for testing
    server_sock, client_sock = socket.socketpair()
    try:
        message = make_request(
            request_id=12345,
            app_path="test:App",
            action="run"
        )

        # Write message, then read it back from the other end
        BinaryProtocol.write_message(client_sock, message)
        received = BinaryProtocol.read_message(server_sock)

        assert received["type"] == "request"
        # The request was built with an integer ID, which is written
        # as-is into the uint64 header field. Comparing against
        # hash("12345") would depend on the per-process hash salt and
        # never describes this message, so pin the exact value.
        assert received["id"] == 12345
        assert received["app_path"] == "test:App"
        assert received["action"] == "run"
    finally:
        server_sock.close()
        client_sock.close()
def test_read_write_with_int_id(self):
"""Test read/write with integer request ID."""
server_sock, client_sock = socket.socketpair()
try:
message = {
"type": "request",
"id": 999888777,
"app_path": "test:App",
"action": "run",
"args": [],
"kwargs": {}
}
BinaryProtocol.write_message(client_sock, message)
received = BinaryProtocol.read_message(server_sock)
assert received["id"] == 999888777
finally:
server_sock.close()
client_sock.close()
@ -147,19 +234,17 @@ class TestDirtyProtocolSync:
server_sock, client_sock = socket.socketpair()
try:
messages = [
{"type": "request", "id": "1"},
{"type": "request", "id": "2"},
{"type": "request", "id": "3"},
make_request(i, f"app{i}:App", f"action{i}")
for i in range(1, 4)
]
# Write all messages
for msg in messages:
DirtyProtocol.write_message(client_sock, msg)
BinaryProtocol.write_message(client_sock, msg)
# Read all messages
for expected in messages:
received = DirtyProtocol.read_message(server_sock)
assert received == expected
for i, _ in enumerate(messages, 1):
received = BinaryProtocol.read_message(server_sock)
assert received["app_path"] == f"app{i}:App"
assert received["action"] == f"action{i}"
finally:
server_sock.close()
client_sock.close()
@ -169,39 +254,51 @@ class TestDirtyProtocolSync:
server_sock, client_sock = socket.socketpair()
client_sock.close()
with pytest.raises(DirtyProtocolError) as exc_info:
DirtyProtocol.read_message(server_sock)
BinaryProtocol.read_message(server_sock)
assert "closed" in str(exc_info.value).lower()
server_sock.close()
def test_binary_data_roundtrip(self):
"""Test binary data roundtrip through socket."""
server_sock, client_sock = socket.socketpair()
try:
binary_payload = b"\x00\x01\x02\xff\xfe\xfd"
message = make_response(12345, {"binary": binary_payload})
class TestDirtyProtocolAsync:
BinaryProtocol.write_message(client_sock, message)
received = BinaryProtocol.read_message(server_sock)
assert received["result"]["binary"] == binary_payload
finally:
server_sock.close()
client_sock.close()
class TestBinaryProtocolAsync:
"""Tests for async stream operations."""
@pytest.mark.asyncio
async def test_async_read_write(self):
"""Test async read/write with mock streams."""
message = {"type": "request", "id": "123"}
message = make_request(12345, "test:App", "run")
# Create a pipe for testing
read_fd, write_fd = os.pipe()
try:
reader = asyncio.StreamReader()
_ = asyncio.StreamReaderProtocol(reader)
# Write the message to the pipe
encoded = DirtyProtocol.encode(message)
encoded = BinaryProtocol._encode_from_dict(message)
os.write(write_fd, encoded)
os.close(write_fd)
write_fd = None
# Feed data to reader
data = os.read(read_fd, len(encoded))
reader.feed_data(data)
reader.feed_eof()
# Read the message
received = await DirtyProtocol.read_message_async(reader)
assert received == message
received = await BinaryProtocol.read_message_async(reader)
assert received["type"] == "request"
assert received["app_path"] == "test:App"
finally:
if write_fd is not None:
os.close(write_fd)
@ -211,12 +308,11 @@ class TestDirtyProtocolAsync:
async def test_async_read_incomplete_header(self):
"""Test async read with incomplete header."""
reader = asyncio.StreamReader()
# Feed only 2 bytes instead of 4
reader.feed_data(b"\x00\x00")
reader.feed_data(MAGIC + b"\x01") # Only 3 bytes
reader.feed_eof()
with pytest.raises((asyncio.IncompleteReadError, DirtyProtocolError)):
await DirtyProtocol.read_message_async(reader)
await BinaryProtocol.read_message_async(reader)
@pytest.mark.asyncio
async def test_async_read_empty_connection(self):
@ -225,36 +321,33 @@ class TestDirtyProtocolAsync:
reader.feed_eof()
with pytest.raises(asyncio.IncompleteReadError):
await DirtyProtocol.read_message_async(reader)
await BinaryProtocol.read_message_async(reader)
@pytest.mark.asyncio
async def test_async_read_invalid_magic(self):
"""Test async read rejects invalid magic."""
reader = asyncio.StreamReader()
header = b"XX" + bytes([VERSION, MSG_TYPE_REQUEST]) + b"\x00" * 12
reader.feed_data(header)
reader.feed_eof()
with pytest.raises(DirtyProtocolError) as exc_info:
await BinaryProtocol.read_message_async(reader)
assert "magic" in str(exc_info.value).lower()
@pytest.mark.asyncio
async def test_async_read_message_too_large(self):
    """A header declaring a payload above MAX_MESSAGE_SIZE is rejected.

    Only the header is fed to the reader; the test expects
    DirtyProtocolError with "too large" in its message.
    """
    reader = asyncio.StreamReader()
    # Valid magic, version and message type, but a declared payload
    # length 1000 bytes over the limit; the request ID is 0.
    header = struct.pack(
        HEADER_FORMAT,
        MAGIC,
        VERSION,
        MSG_TYPE_REQUEST,
        MAX_MESSAGE_SIZE + 1000,
        0,
    )
    reader.feed_data(header)
    reader.feed_eof()

    with pytest.raises(DirtyProtocolError) as exc_info:
        await BinaryProtocol.read_message_async(reader)

    assert "too large" in str(exc_info.value)
@pytest.mark.asyncio
async def test_async_read_empty_message(self):
    """A header packed from the single value 0 is rejected as empty."""
    # NOTE(review): HEADER_FORMAT is packed with one value here, while
    # other tests in this file pack five fields into it -- confirm this
    # test still matches the current header layout.
    zero_header = struct.pack(DirtyProtocol.HEADER_FORMAT, 0)

    stream = asyncio.StreamReader()
    stream.feed_data(zero_header)
    stream.feed_eof()

    with pytest.raises(DirtyProtocolError) as caught:
        await DirtyProtocol.read_message_async(stream)

    assert "Empty message" in str(caught.value)
class TestMessageBuilders:
"""Tests for message builder helper functions."""
@ -340,9 +433,9 @@ class TestMessageBuilders:
assert chunk["id"] == "req-456"
assert chunk["data"] == data
def test_make_chunk_message_with_list_data(self):
"""Test chunk message with list data."""
data = [1, 2, 3, "token"]
def test_make_chunk_message_with_binary_data(self):
"""Test chunk message with binary data."""
data = b"\x00\x01\x02\xff"
chunk = make_chunk_message("req-789", data)
assert chunk["data"] == data
@ -353,22 +446,22 @@ class TestMessageBuilders:
assert end["id"] == "req-123"
assert "data" not in end
def test_chunk_and_end_roundtrip(self):
    """Chunk and end messages survive an encode/decode round trip.

    Each message is built with its helper, encoded from its dict form,
    and decoded again; the decoded message type and payload are checked.
    """
    chunk = make_chunk_message(12345, {"token": "hello"})
    end = make_end_message(12345)

    # Chunk: the payload must carry the original data under "data".
    encoded_chunk = BinaryProtocol._encode_from_dict(chunk)
    msg_type, _req_id, payload = BinaryProtocol.decode_message(encoded_chunk)
    assert msg_type == "chunk"
    assert payload["data"] == {"token": "hello"}

    # End: the decoded payload is expected to be an empty dict.
    encoded_end = BinaryProtocol._encode_from_dict(end)
    msg_type, _req_id, payload = BinaryProtocol.decode_message(encoded_end)
    assert msg_type == "end"
    assert payload == {}
class TestDirtyErrors:
@ -417,3 +510,58 @@ class TestDirtyErrors:
assert error.action == "run"
assert error.traceback == "Traceback..."
assert "myapp:App" in str(error)
class TestBackwardsCompatibility:
    """Checks that the legacy DirtyProtocol surface keeps working."""

    def test_dirty_protocol_alias(self):
        """DirtyProtocol and BinaryProtocol are the same class object."""
        assert DirtyProtocol is BinaryProtocol

    def test_header_size_attribute(self):
        """HEADER_SIZE can be read from the class and equals 16."""
        assert DirtyProtocol.HEADER_SIZE == 16

    def test_msg_type_constants(self):
        """Class-level message type constants are the expected strings."""
        expected = {
            "MSG_TYPE_REQUEST": "request",
            "MSG_TYPE_RESPONSE": "response",
            "MSG_TYPE_ERROR": "error",
            "MSG_TYPE_CHUNK": "chunk",
            "MSG_TYPE_END": "end",
        }
        for attr_name, wire_name in expected.items():
            assert getattr(DirtyProtocol, attr_name) == wire_name

    def test_encode_decode_preserves_dict_format(self):
        """A dict written to one socket is read back as a dict."""
        reader_sock, writer_sock = socket.socketpair()
        # Leaving the block closes the reading end first, then the writer.
        with writer_sock, reader_sock:
            outgoing = {
                "type": "response",
                "id": 12345,
                "result": {"status": "ok"},
            }
            DirtyProtocol.write_message(writer_sock, outgoing)
            incoming = DirtyProtocol.read_message(reader_sock)

            # Fields are reached through plain dict keys.
            assert incoming["type"] == "response"
            assert incoming["result"]["status"] == "ok"

    def test_string_request_id_handled(self):
        """A request built with a string ID is read back with an int ID."""
        reader_sock, writer_sock = socket.socketpair()
        # Leaving the block closes the reading end first, then the writer.
        with writer_sock, reader_sock:
            outgoing = make_request("uuid-string-id", "test:App", "run")
            DirtyProtocol.write_message(writer_sock, outgoing)
            incoming = DirtyProtocol.read_message(reader_sock)

            assert isinstance(incoming["id"], int)