mirror of
https://github.com/frappe/gunicorn.git
synced 2026-07-04 03:31:29 +08:00
feat(dirty): implement binary protocol
Replace JSON-based protocol with binary format using 16-byte header: - Magic bytes (GD), version, message type, payload length, request ID - TLV-encoded payloads for efficient binary data transfer - No base64 encoding needed for binary data - Backwards compatible API (DirtyProtocol alias, dict-based interface) Header format inspired by OpenBSD msgctl/msgsnd.
This commit is contained in:
parent
0e0dc669c8
commit
1665857c0e
@ -3,89 +3,304 @@
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
Dirty Arbiters Protocol
|
||||
Dirty Worker Binary Protocol
|
||||
|
||||
Length-prefixed JSON message framing over Unix sockets.
|
||||
Provides both async (primary) and sync (for HTTP workers) APIs.
|
||||
Binary message framing over Unix sockets, inspired by OpenBSD msgctl/msgsnd.
|
||||
Replaces JSON protocol for efficient binary data transfer.
|
||||
|
||||
Message Format:
|
||||
+----------------+------------------+
|
||||
| 4-byte length | JSON payload |
|
||||
+----------------+------------------+
|
||||
Header Format (16 bytes):
|
||||
+--------+--------+--------+--------+--------+--------+--------+--------+
|
||||
| Magic (2B) | Ver(1) | MType | Payload Length (4B) |
|
||||
+--------+--------+--------+--------+--------+--------+--------+--------+
|
||||
| Request ID (8 bytes) |
|
||||
+--------+--------+--------+--------+--------+--------+--------+--------+
|
||||
|
||||
The length field is a 4-byte unsigned integer in network byte order (big-endian).
|
||||
- Magic: 0x47 0x44 ("GD" for Gunicorn Dirty)
|
||||
- Version: 0x01
|
||||
- MType: Message type (REQUEST, RESPONSE, ERROR, CHUNK, END)
|
||||
- Length: Payload size (big-endian uint32, max 64MB)
|
||||
- Request ID: uint64 (replaces UUID string)
|
||||
|
||||
Payload is TLV-encoded (see tlv.py).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import struct
|
||||
import socket
|
||||
import struct
|
||||
|
||||
from .errors import DirtyProtocolError
|
||||
from .tlv import TLVEncoder
|
||||
|
||||
|
||||
class DirtyProtocol:
|
||||
"""Length-prefixed JSON messages over Unix sockets."""
|
||||
# Protocol constants
|
||||
MAGIC = b"GD" # 0x47 0x44
|
||||
VERSION = 0x01
|
||||
|
||||
# 4-byte unsigned int, network byte order (big-endian)
|
||||
HEADER_FORMAT = "!I"
|
||||
HEADER_SIZE = struct.calcsize(HEADER_FORMAT)
|
||||
# Message types (1 byte)
|
||||
MSG_TYPE_REQUEST = 0x01
|
||||
MSG_TYPE_RESPONSE = 0x02
|
||||
MSG_TYPE_ERROR = 0x03
|
||||
MSG_TYPE_CHUNK = 0x04
|
||||
MSG_TYPE_END = 0x05
|
||||
|
||||
# Maximum message size (64 MB)
|
||||
MAX_MESSAGE_SIZE = 64 * 1024 * 1024
|
||||
# Message type names (for backwards compatibility with old API)
|
||||
MSG_TYPE_REQUEST_STR = "request"
|
||||
MSG_TYPE_RESPONSE_STR = "response"
|
||||
MSG_TYPE_ERROR_STR = "error"
|
||||
MSG_TYPE_CHUNK_STR = "chunk"
|
||||
MSG_TYPE_END_STR = "end"
|
||||
|
||||
# Message types for future streaming support
|
||||
MSG_TYPE_REQUEST = "request"
|
||||
MSG_TYPE_RESPONSE = "response"
|
||||
MSG_TYPE_ERROR = "error"
|
||||
MSG_TYPE_CHUNK = "chunk"
|
||||
MSG_TYPE_END = "end"
|
||||
# Map int types to string names
|
||||
MSG_TYPE_TO_STR = {
|
||||
MSG_TYPE_REQUEST: MSG_TYPE_REQUEST_STR,
|
||||
MSG_TYPE_RESPONSE: MSG_TYPE_RESPONSE_STR,
|
||||
MSG_TYPE_ERROR: MSG_TYPE_ERROR_STR,
|
||||
MSG_TYPE_CHUNK: MSG_TYPE_CHUNK_STR,
|
||||
MSG_TYPE_END: MSG_TYPE_END_STR,
|
||||
}
|
||||
|
||||
# Map string names to int types
|
||||
MSG_TYPE_FROM_STR = {v: k for k, v in MSG_TYPE_TO_STR.items()}
|
||||
|
||||
# Header format: Magic (2) + Version (1) + Type (1) + Length (4) + RequestID (8) = 16
|
||||
HEADER_FORMAT = ">2sBBIQ"
|
||||
HEADER_SIZE = struct.calcsize(HEADER_FORMAT)
|
||||
|
||||
# Maximum message size (64 MB)
|
||||
MAX_MESSAGE_SIZE = 64 * 1024 * 1024
|
||||
|
||||
|
||||
class BinaryProtocol:
|
||||
"""Binary message protocol for dirty worker IPC."""
|
||||
|
||||
# Export constants for external use
|
||||
HEADER_SIZE = HEADER_SIZE
|
||||
MAX_MESSAGE_SIZE = MAX_MESSAGE_SIZE
|
||||
|
||||
MSG_TYPE_REQUEST = MSG_TYPE_REQUEST_STR
|
||||
MSG_TYPE_RESPONSE = MSG_TYPE_RESPONSE_STR
|
||||
MSG_TYPE_ERROR = MSG_TYPE_ERROR_STR
|
||||
MSG_TYPE_CHUNK = MSG_TYPE_CHUNK_STR
|
||||
MSG_TYPE_END = MSG_TYPE_END_STR
|
||||
|
||||
@staticmethod
|
||||
def encode(message: dict) -> bytes:
|
||||
def encode_header(msg_type: int, request_id: int, payload_length: int) -> bytes:
|
||||
"""
|
||||
Encode a message dict to length-prefixed bytes.
|
||||
Encode the 16-byte message header.
|
||||
|
||||
Args:
|
||||
message: Dictionary to encode as JSON
|
||||
msg_type: Message type (MSG_TYPE_REQUEST, etc.)
|
||||
request_id: Unique request identifier (uint64)
|
||||
payload_length: Length of the TLV-encoded payload
|
||||
|
||||
Returns:
|
||||
bytes: Length-prefixed encoded message
|
||||
bytes: 16-byte header
|
||||
"""
|
||||
return struct.pack(HEADER_FORMAT, MAGIC, VERSION, msg_type,
|
||||
payload_length, request_id)
|
||||
|
||||
@staticmethod
|
||||
def decode_header(data: bytes) -> tuple:
|
||||
"""
|
||||
Decode the 16-byte message header.
|
||||
|
||||
Args:
|
||||
data: 16 bytes of header data
|
||||
|
||||
Returns:
|
||||
tuple: (msg_type, request_id, payload_length)
|
||||
|
||||
Raises:
|
||||
DirtyProtocolError: If encoding fails
|
||||
DirtyProtocolError: If header is invalid
|
||||
"""
|
||||
try:
|
||||
payload = json.dumps(message).encode("utf-8")
|
||||
if len(payload) > DirtyProtocol.MAX_MESSAGE_SIZE:
|
||||
if len(data) < HEADER_SIZE:
|
||||
raise DirtyProtocolError(
|
||||
f"Header too short: {len(data)} bytes, expected {HEADER_SIZE}",
|
||||
raw_data=data
|
||||
)
|
||||
|
||||
magic, version, msg_type, length, request_id = struct.unpack(
|
||||
HEADER_FORMAT, data[:HEADER_SIZE]
|
||||
)
|
||||
|
||||
if magic != MAGIC:
|
||||
raise DirtyProtocolError(
|
||||
f"Invalid magic: {magic!r}, expected {MAGIC!r}",
|
||||
raw_data=data[:20]
|
||||
)
|
||||
|
||||
if version != VERSION:
|
||||
raise DirtyProtocolError(
|
||||
f"Unsupported protocol version: {version}, expected {VERSION}",
|
||||
raw_data=data[:20]
|
||||
)
|
||||
|
||||
if msg_type not in MSG_TYPE_TO_STR:
|
||||
raise DirtyProtocolError(
|
||||
f"Unknown message type: 0x{msg_type:02x}",
|
||||
raw_data=data[:20]
|
||||
)
|
||||
|
||||
if length > MAX_MESSAGE_SIZE:
|
||||
raise DirtyProtocolError(
|
||||
f"Message too large: {length} bytes (max: {MAX_MESSAGE_SIZE})"
|
||||
)
|
||||
|
||||
return msg_type, request_id, length
|
||||
|
||||
@staticmethod
|
||||
def encode_request(request_id: int, app_path: str, action: str,
|
||||
args: tuple = None, kwargs: dict = None) -> bytes:
|
||||
"""
|
||||
Encode a request message.
|
||||
|
||||
Args:
|
||||
request_id: Unique request identifier (uint64)
|
||||
app_path: Import path of the dirty app
|
||||
action: Action to call on the app
|
||||
args: Positional arguments
|
||||
kwargs: Keyword arguments
|
||||
|
||||
Returns:
|
||||
bytes: Complete message (header + payload)
|
||||
"""
|
||||
payload_dict = {
|
||||
"app_path": app_path,
|
||||
"action": action,
|
||||
"args": list(args) if args else [],
|
||||
"kwargs": kwargs or {},
|
||||
}
|
||||
payload = TLVEncoder.encode(payload_dict)
|
||||
header = BinaryProtocol.encode_header(MSG_TYPE_REQUEST, request_id,
|
||||
len(payload))
|
||||
return header + payload
|
||||
|
||||
@staticmethod
|
||||
def encode_response(request_id: int, result) -> bytes:
|
||||
"""
|
||||
Encode a success response message.
|
||||
|
||||
Args:
|
||||
request_id: Request identifier this responds to
|
||||
result: Result value (must be TLV-serializable)
|
||||
|
||||
Returns:
|
||||
bytes: Complete message (header + payload)
|
||||
"""
|
||||
payload_dict = {"result": result}
|
||||
payload = TLVEncoder.encode(payload_dict)
|
||||
header = BinaryProtocol.encode_header(MSG_TYPE_RESPONSE, request_id,
|
||||
len(payload))
|
||||
return header + payload
|
||||
|
||||
@staticmethod
|
||||
def encode_error(request_id: int, error) -> bytes:
|
||||
"""
|
||||
Encode an error response message.
|
||||
|
||||
Args:
|
||||
request_id: Request identifier this responds to
|
||||
error: DirtyError instance, dict, or Exception
|
||||
|
||||
Returns:
|
||||
bytes: Complete message (header + payload)
|
||||
"""
|
||||
from .errors import DirtyError
|
||||
|
||||
if isinstance(error, DirtyError):
|
||||
error_dict = error.to_dict()
|
||||
elif isinstance(error, dict):
|
||||
error_dict = error
|
||||
else:
|
||||
error_dict = {
|
||||
"error_type": type(error).__name__,
|
||||
"message": str(error),
|
||||
"details": {},
|
||||
}
|
||||
|
||||
payload_dict = {"error": error_dict}
|
||||
payload = TLVEncoder.encode(payload_dict)
|
||||
header = BinaryProtocol.encode_header(MSG_TYPE_ERROR, request_id,
|
||||
len(payload))
|
||||
return header + payload
|
||||
|
||||
@staticmethod
|
||||
def encode_chunk(request_id: int, data) -> bytes:
|
||||
"""
|
||||
Encode a chunk message for streaming responses.
|
||||
|
||||
Args:
|
||||
request_id: Request identifier this chunk belongs to
|
||||
data: Chunk data (must be TLV-serializable)
|
||||
|
||||
Returns:
|
||||
bytes: Complete message (header + payload)
|
||||
"""
|
||||
payload_dict = {"data": data}
|
||||
payload = TLVEncoder.encode(payload_dict)
|
||||
header = BinaryProtocol.encode_header(MSG_TYPE_CHUNK, request_id,
|
||||
len(payload))
|
||||
return header + payload
|
||||
|
||||
@staticmethod
|
||||
def encode_end(request_id: int) -> bytes:
|
||||
"""
|
||||
Encode an end-of-stream message.
|
||||
|
||||
Args:
|
||||
request_id: Request identifier this ends
|
||||
|
||||
Returns:
|
||||
bytes: Complete message (header + empty payload)
|
||||
"""
|
||||
# End message has empty payload
|
||||
header = BinaryProtocol.encode_header(MSG_TYPE_END, request_id, 0)
|
||||
return header
|
||||
|
||||
@staticmethod
|
||||
def decode_message(data: bytes) -> tuple:
|
||||
"""
|
||||
Decode a complete message (header + payload).
|
||||
|
||||
Args:
|
||||
data: Complete message bytes
|
||||
|
||||
Returns:
|
||||
tuple: (msg_type_str, request_id, payload_dict)
|
||||
msg_type_str is the string name (e.g., "request")
|
||||
payload_dict is the decoded TLV payload as a dict
|
||||
|
||||
Raises:
|
||||
DirtyProtocolError: If message is malformed
|
||||
"""
|
||||
msg_type, request_id, length = BinaryProtocol.decode_header(data)
|
||||
|
||||
if len(data) < HEADER_SIZE + length:
|
||||
raise DirtyProtocolError(
|
||||
f"Incomplete message: expected {HEADER_SIZE + length} bytes, "
|
||||
f"got {len(data)}",
|
||||
raw_data=data[:50]
|
||||
)
|
||||
|
||||
if length == 0:
|
||||
# End message has empty payload
|
||||
payload_dict = {}
|
||||
else:
|
||||
payload_data = data[HEADER_SIZE:HEADER_SIZE + length]
|
||||
try:
|
||||
payload_dict = TLVEncoder.decode_full(payload_data)
|
||||
except DirtyProtocolError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise DirtyProtocolError(
|
||||
f"Message too large: {len(payload)} bytes "
|
||||
f"(max: {DirtyProtocol.MAX_MESSAGE_SIZE})"
|
||||
f"Failed to decode TLV payload: {e}",
|
||||
raw_data=payload_data[:50]
|
||||
)
|
||||
length = struct.pack(DirtyProtocol.HEADER_FORMAT, len(payload))
|
||||
return length + payload
|
||||
except (TypeError, ValueError) as e:
|
||||
raise DirtyProtocolError(f"Failed to encode message: {e}")
|
||||
|
||||
@staticmethod
|
||||
def decode(data: bytes) -> dict:
|
||||
"""
|
||||
Decode bytes (without length prefix) to message dict.
|
||||
# Convert to dict format similar to old JSON protocol
|
||||
msg_type_str = MSG_TYPE_TO_STR[msg_type]
|
||||
|
||||
Args:
|
||||
data: JSON bytes to decode
|
||||
|
||||
Returns:
|
||||
dict: Decoded message
|
||||
|
||||
Raises:
|
||||
DirtyProtocolError: If decoding fails
|
||||
"""
|
||||
try:
|
||||
return json.loads(data.decode("utf-8"))
|
||||
except (json.JSONDecodeError, UnicodeDecodeError) as e:
|
||||
raise DirtyProtocolError(f"Failed to decode message: {e}",
|
||||
raw_data=data)
|
||||
return msg_type_str, request_id, payload_dict
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Async API (primary - for DirtyArbiter and DirtyWorker)
|
||||
@ -94,53 +309,62 @@ class DirtyProtocol:
|
||||
@staticmethod
|
||||
async def read_message_async(reader: asyncio.StreamReader) -> dict:
|
||||
"""
|
||||
Read a complete message from async stream.
|
||||
Read a complete binary message from async stream.
|
||||
|
||||
Args:
|
||||
reader: asyncio StreamReader
|
||||
|
||||
Returns:
|
||||
dict: Decoded message
|
||||
dict: Message dict with 'type', 'id', and payload fields
|
||||
|
||||
Raises:
|
||||
DirtyProtocolError: If read fails or message is malformed
|
||||
asyncio.IncompleteReadError: If connection closed mid-read
|
||||
"""
|
||||
# Read length header
|
||||
# Read header
|
||||
try:
|
||||
header = await reader.readexactly(DirtyProtocol.HEADER_SIZE)
|
||||
header = await reader.readexactly(HEADER_SIZE)
|
||||
except asyncio.IncompleteReadError as e:
|
||||
if len(e.partial) == 0:
|
||||
# Clean close - no data was read
|
||||
raise
|
||||
raise DirtyProtocolError(
|
||||
f"Incomplete header: got {len(e.partial)} bytes, "
|
||||
f"expected {DirtyProtocol.HEADER_SIZE}",
|
||||
f"expected {HEADER_SIZE}",
|
||||
raw_data=e.partial
|
||||
)
|
||||
|
||||
length = struct.unpack(DirtyProtocol.HEADER_FORMAT, header)[0]
|
||||
|
||||
if length > DirtyProtocol.MAX_MESSAGE_SIZE:
|
||||
raise DirtyProtocolError(
|
||||
f"Message too large: {length} bytes "
|
||||
f"(max: {DirtyProtocol.MAX_MESSAGE_SIZE})"
|
||||
)
|
||||
|
||||
if length == 0:
|
||||
raise DirtyProtocolError("Empty message received")
|
||||
msg_type, request_id, length = BinaryProtocol.decode_header(header)
|
||||
|
||||
# Read payload
|
||||
try:
|
||||
payload = await reader.readexactly(length)
|
||||
except asyncio.IncompleteReadError as e:
|
||||
raise DirtyProtocolError(
|
||||
f"Incomplete message: got {len(e.partial)} bytes, "
|
||||
f"expected {length}",
|
||||
raw_data=e.partial
|
||||
)
|
||||
if length > 0:
|
||||
try:
|
||||
payload_data = await reader.readexactly(length)
|
||||
except asyncio.IncompleteReadError as e:
|
||||
raise DirtyProtocolError(
|
||||
f"Incomplete payload: got {len(e.partial)} bytes, "
|
||||
f"expected {length}",
|
||||
raw_data=e.partial
|
||||
)
|
||||
|
||||
return DirtyProtocol.decode(payload)
|
||||
try:
|
||||
payload_dict = TLVEncoder.decode_full(payload_data)
|
||||
except DirtyProtocolError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise DirtyProtocolError(
|
||||
f"Failed to decode TLV payload: {e}",
|
||||
raw_data=payload_data[:50]
|
||||
)
|
||||
else:
|
||||
payload_dict = {}
|
||||
|
||||
# Build response dict
|
||||
msg_type_str = MSG_TYPE_TO_STR[msg_type]
|
||||
result = {"type": msg_type_str, "id": request_id}
|
||||
result.update(payload_dict)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
async def write_message_async(writer: asyncio.StreamWriter,
|
||||
@ -148,15 +372,17 @@ class DirtyProtocol:
|
||||
"""
|
||||
Write a message to async stream.
|
||||
|
||||
Accepts dict format for backwards compatibility.
|
||||
|
||||
Args:
|
||||
writer: asyncio StreamWriter
|
||||
message: Dictionary to send
|
||||
message: Message dict with 'type', 'id', and payload fields
|
||||
|
||||
Raises:
|
||||
DirtyProtocolError: If encoding fails
|
||||
ConnectionError: If write fails
|
||||
"""
|
||||
data = DirtyProtocol.encode(message)
|
||||
data = BinaryProtocol._encode_from_dict(message)
|
||||
writer.write(data)
|
||||
await writer.drain()
|
||||
|
||||
@ -201,27 +427,36 @@ class DirtyProtocol:
|
||||
sock: Socket to read from
|
||||
|
||||
Returns:
|
||||
dict: Decoded message
|
||||
dict: Message dict with 'type', 'id', and payload fields
|
||||
|
||||
Raises:
|
||||
DirtyProtocolError: If read fails or message is malformed
|
||||
"""
|
||||
# Read length header
|
||||
header = DirtyProtocol._recv_exactly(sock, DirtyProtocol.HEADER_SIZE)
|
||||
length = struct.unpack(DirtyProtocol.HEADER_FORMAT, header)[0]
|
||||
|
||||
if length > DirtyProtocol.MAX_MESSAGE_SIZE:
|
||||
raise DirtyProtocolError(
|
||||
f"Message too large: {length} bytes "
|
||||
f"(max: {DirtyProtocol.MAX_MESSAGE_SIZE})"
|
||||
)
|
||||
|
||||
if length == 0:
|
||||
raise DirtyProtocolError("Empty message received")
|
||||
# Read header
|
||||
header = BinaryProtocol._recv_exactly(sock, HEADER_SIZE)
|
||||
msg_type, request_id, length = BinaryProtocol.decode_header(header)
|
||||
|
||||
# Read payload
|
||||
payload = DirtyProtocol._recv_exactly(sock, length)
|
||||
return DirtyProtocol.decode(payload)
|
||||
if length > 0:
|
||||
payload_data = BinaryProtocol._recv_exactly(sock, length)
|
||||
try:
|
||||
payload_dict = TLVEncoder.decode_full(payload_data)
|
||||
except DirtyProtocolError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise DirtyProtocolError(
|
||||
f"Failed to decode TLV payload: {e}",
|
||||
raw_data=payload_data[:50]
|
||||
)
|
||||
else:
|
||||
payload_dict = {}
|
||||
|
||||
# Build response dict
|
||||
msg_type_str = MSG_TYPE_TO_STR[msg_type]
|
||||
result = {"type": msg_type_str, "id": request_id}
|
||||
result.update(payload_dict)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def write_message(sock: socket.socket, message: dict) -> None:
|
||||
@ -230,31 +465,92 @@ class DirtyProtocol:
|
||||
|
||||
Args:
|
||||
sock: Socket to write to
|
||||
message: Dictionary to send
|
||||
message: Message dict with 'type', 'id', and payload fields
|
||||
|
||||
Raises:
|
||||
DirtyProtocolError: If encoding fails
|
||||
OSError: If write fails
|
||||
"""
|
||||
data = DirtyProtocol.encode(message)
|
||||
data = BinaryProtocol._encode_from_dict(message)
|
||||
sock.sendall(data)
|
||||
|
||||
@staticmethod
|
||||
def _encode_from_dict(message: dict) -> bytes:
|
||||
"""
|
||||
Encode a message dict to binary format.
|
||||
|
||||
# Message builder helpers
|
||||
def make_request(request_id: str, app_path: str, action: str,
|
||||
Supports the old dict-based API for backwards compatibility.
|
||||
|
||||
Args:
|
||||
message: Message dict with 'type', 'id', and payload fields
|
||||
|
||||
Returns:
|
||||
bytes: Complete encoded message
|
||||
"""
|
||||
msg_type_str = message.get("type")
|
||||
request_id = message.get("id", 0)
|
||||
|
||||
# Handle string or int request IDs
|
||||
if isinstance(request_id, str):
|
||||
# For backwards compat with UUID strings, hash to int
|
||||
request_id = hash(request_id) & 0xFFFFFFFFFFFFFFFF
|
||||
|
||||
msg_type = MSG_TYPE_FROM_STR.get(msg_type_str)
|
||||
if msg_type is None:
|
||||
raise DirtyProtocolError(f"Unknown message type: {msg_type_str}")
|
||||
|
||||
if msg_type == MSG_TYPE_REQUEST:
|
||||
return BinaryProtocol.encode_request(
|
||||
request_id,
|
||||
message.get("app_path", ""),
|
||||
message.get("action", ""),
|
||||
message.get("args"),
|
||||
message.get("kwargs")
|
||||
)
|
||||
elif msg_type == MSG_TYPE_RESPONSE:
|
||||
return BinaryProtocol.encode_response(
|
||||
request_id,
|
||||
message.get("result")
|
||||
)
|
||||
elif msg_type == MSG_TYPE_ERROR:
|
||||
return BinaryProtocol.encode_error(
|
||||
request_id,
|
||||
message.get("error", {})
|
||||
)
|
||||
elif msg_type == MSG_TYPE_CHUNK:
|
||||
return BinaryProtocol.encode_chunk(
|
||||
request_id,
|
||||
message.get("data")
|
||||
)
|
||||
elif msg_type == MSG_TYPE_END:
|
||||
return BinaryProtocol.encode_end(request_id)
|
||||
else:
|
||||
raise DirtyProtocolError(f"Unhandled message type: {msg_type}")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Backwards Compatibility Aliases
|
||||
# =============================================================================
|
||||
|
||||
# Alias BinaryProtocol as DirtyProtocol for drop-in replacement
|
||||
DirtyProtocol = BinaryProtocol
|
||||
|
||||
|
||||
# Message builder helpers (backwards compatible with old API)
|
||||
def make_request(request_id, app_path: str, action: str,
|
||||
args: tuple = None, kwargs: dict = None) -> dict:
|
||||
"""
|
||||
Build a request message.
|
||||
Build a request message dict.
|
||||
|
||||
Args:
|
||||
request_id: Unique request identifier
|
||||
request_id: Unique request identifier (int or str)
|
||||
app_path: Import path of the dirty app (e.g., 'myapp.ml:MLApp')
|
||||
action: Action to call on the app
|
||||
args: Positional arguments
|
||||
kwargs: Keyword arguments
|
||||
|
||||
Returns:
|
||||
dict: Request message
|
||||
dict: Request message dict
|
||||
"""
|
||||
return {
|
||||
"type": DirtyProtocol.MSG_TYPE_REQUEST,
|
||||
@ -266,16 +562,16 @@ def make_request(request_id: str, app_path: str, action: str,
|
||||
}
|
||||
|
||||
|
||||
def make_response(request_id: str, result) -> dict:
|
||||
def make_response(request_id, result) -> dict:
|
||||
"""
|
||||
Build a success response message.
|
||||
Build a success response message dict.
|
||||
|
||||
Args:
|
||||
request_id: Request identifier this responds to
|
||||
result: Result value (must be JSON-serializable)
|
||||
result: Result value
|
||||
|
||||
Returns:
|
||||
dict: Response message
|
||||
dict: Response message dict
|
||||
"""
|
||||
return {
|
||||
"type": DirtyProtocol.MSG_TYPE_RESPONSE,
|
||||
@ -284,16 +580,16 @@ def make_response(request_id: str, result) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def make_error_response(request_id: str, error) -> dict:
|
||||
def make_error_response(request_id, error) -> dict:
|
||||
"""
|
||||
Build an error response message.
|
||||
Build an error response message dict.
|
||||
|
||||
Args:
|
||||
request_id: Request identifier this responds to
|
||||
error: DirtyError instance or dict with error info
|
||||
|
||||
Returns:
|
||||
dict: Error response message
|
||||
dict: Error response message dict
|
||||
"""
|
||||
from .errors import DirtyError
|
||||
if isinstance(error, DirtyError):
|
||||
@ -314,16 +610,16 @@ def make_error_response(request_id: str, error) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def make_chunk_message(request_id: str, data) -> dict:
|
||||
def make_chunk_message(request_id, data) -> dict:
|
||||
"""
|
||||
Build a chunk message for streaming responses.
|
||||
Build a chunk message dict for streaming responses.
|
||||
|
||||
Args:
|
||||
request_id: Request identifier this chunk belongs to
|
||||
data: Chunk data (must be JSON-serializable)
|
||||
data: Chunk data
|
||||
|
||||
Returns:
|
||||
dict: Chunk message
|
||||
dict: Chunk message dict
|
||||
"""
|
||||
return {
|
||||
"type": DirtyProtocol.MSG_TYPE_CHUNK,
|
||||
@ -332,15 +628,15 @@ def make_chunk_message(request_id: str, data) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def make_end_message(request_id: str) -> dict:
|
||||
def make_end_message(request_id) -> dict:
|
||||
"""
|
||||
Build an end-of-stream message.
|
||||
Build an end-of-stream message dict.
|
||||
|
||||
Args:
|
||||
request_id: Request identifier this ends
|
||||
|
||||
Returns:
|
||||
dict: End message
|
||||
dict: End message dict
|
||||
"""
|
||||
return {
|
||||
"type": DirtyProtocol.MSG_TYPE_END,
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""Tests for dirty arbiter protocol module."""
|
||||
"""Tests for dirty worker binary protocol module."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
@ -11,12 +11,23 @@ import struct
|
||||
import pytest
|
||||
|
||||
from gunicorn.dirty.protocol import (
|
||||
BinaryProtocol,
|
||||
DirtyProtocol,
|
||||
make_request,
|
||||
make_response,
|
||||
make_error_response,
|
||||
make_chunk_message,
|
||||
make_end_message,
|
||||
MAGIC,
|
||||
VERSION,
|
||||
HEADER_SIZE,
|
||||
HEADER_FORMAT,
|
||||
MSG_TYPE_REQUEST,
|
||||
MSG_TYPE_RESPONSE,
|
||||
MSG_TYPE_ERROR,
|
||||
MSG_TYPE_CHUNK,
|
||||
MSG_TYPE_END,
|
||||
MAX_MESSAGE_SIZE,
|
||||
)
|
||||
from gunicorn.dirty.errors import (
|
||||
DirtyError,
|
||||
@ -26,118 +37,194 @@ from gunicorn.dirty.errors import (
|
||||
)
|
||||
|
||||
|
||||
class TestDirtyProtocolEncodeDecode:
|
||||
"""Tests for encode/decode functionality."""
|
||||
class TestBinaryProtocolHeader:
|
||||
"""Tests for header encoding/decoding."""
|
||||
|
||||
def test_encode_decode_roundtrip(self):
|
||||
"""Test basic encode/decode roundtrip."""
|
||||
message = {"type": "request", "id": "123", "data": "hello"}
|
||||
encoded = DirtyProtocol.encode(message)
|
||||
def test_header_size(self):
|
||||
"""Test header size is 16 bytes."""
|
||||
assert HEADER_SIZE == 16
|
||||
|
||||
# Check header format
|
||||
assert len(encoded) > DirtyProtocol.HEADER_SIZE
|
||||
length = struct.unpack(
|
||||
DirtyProtocol.HEADER_FORMAT,
|
||||
encoded[:DirtyProtocol.HEADER_SIZE]
|
||||
)[0]
|
||||
assert length == len(encoded) - DirtyProtocol.HEADER_SIZE
|
||||
def test_encode_header(self):
|
||||
"""Test header encoding."""
|
||||
header = BinaryProtocol.encode_header(MSG_TYPE_REQUEST, 12345, 100)
|
||||
assert len(header) == HEADER_SIZE
|
||||
assert header[:2] == MAGIC
|
||||
assert header[2] == VERSION
|
||||
assert header[3] == MSG_TYPE_REQUEST
|
||||
|
||||
# Decode payload
|
||||
payload = encoded[DirtyProtocol.HEADER_SIZE:]
|
||||
decoded = DirtyProtocol.decode(payload)
|
||||
assert decoded == message
|
||||
def test_decode_header(self):
|
||||
"""Test header decoding."""
|
||||
header = BinaryProtocol.encode_header(MSG_TYPE_RESPONSE, 67890, 200)
|
||||
msg_type, request_id, length = BinaryProtocol.decode_header(header)
|
||||
assert msg_type == MSG_TYPE_RESPONSE
|
||||
assert request_id == 67890
|
||||
assert length == 200
|
||||
|
||||
def test_encode_decode_complex_data(self):
|
||||
"""Test with complex nested data structures."""
|
||||
message = {
|
||||
"type": "response",
|
||||
"id": "456",
|
||||
"result": {
|
||||
"models": ["gpt-4", "claude-3"],
|
||||
"config": {"temperature": 0.7, "max_tokens": 1000},
|
||||
"metadata": None,
|
||||
},
|
||||
"args": [1, 2, 3],
|
||||
}
|
||||
encoded = DirtyProtocol.encode(message)
|
||||
payload = encoded[DirtyProtocol.HEADER_SIZE:]
|
||||
decoded = DirtyProtocol.decode(payload)
|
||||
assert decoded == message
|
||||
def test_decode_header_invalid_magic(self):
|
||||
"""Test header decoding with invalid magic."""
|
||||
header = b"XX" + b"\x01\x01" + b"\x00" * 12
|
||||
with pytest.raises(DirtyProtocolError) as exc_info:
|
||||
BinaryProtocol.decode_header(header)
|
||||
assert "magic" in str(exc_info.value).lower()
|
||||
|
||||
def test_encode_decode_unicode(self):
|
||||
"""Test with unicode characters."""
|
||||
message = {
|
||||
"type": "request",
|
||||
"data": "Hello, world!"
|
||||
}
|
||||
encoded = DirtyProtocol.encode(message)
|
||||
payload = encoded[DirtyProtocol.HEADER_SIZE:]
|
||||
decoded = DirtyProtocol.decode(payload)
|
||||
assert decoded == message
|
||||
def test_decode_header_invalid_version(self):
|
||||
"""Test header decoding with invalid version."""
|
||||
header = MAGIC + b"\x99\x01" + b"\x00" * 12
|
||||
with pytest.raises(DirtyProtocolError) as exc_info:
|
||||
BinaryProtocol.decode_header(header)
|
||||
assert "version" in str(exc_info.value).lower()
|
||||
|
||||
def test_encode_large_message(self):
|
||||
def test_decode_header_invalid_type(self):
|
||||
"""Test header decoding with invalid message type."""
|
||||
header = MAGIC + bytes([VERSION, 0xFF]) + b"\x00" * 12
|
||||
with pytest.raises(DirtyProtocolError) as exc_info:
|
||||
BinaryProtocol.decode_header(header)
|
||||
assert "type" in str(exc_info.value).lower()
|
||||
|
||||
def test_decode_header_too_large(self):
|
||||
"""Test header decoding rejects too-large messages."""
|
||||
header = struct.pack(HEADER_FORMAT, MAGIC, VERSION, MSG_TYPE_REQUEST,
|
||||
MAX_MESSAGE_SIZE + 1, 0)
|
||||
with pytest.raises(DirtyProtocolError) as exc_info:
|
||||
BinaryProtocol.decode_header(header)
|
||||
assert "too large" in str(exc_info.value).lower()
|
||||
|
||||
def test_decode_header_too_short(self):
|
||||
"""Test header decoding with too-short data."""
|
||||
header = MAGIC + b"\x01"
|
||||
with pytest.raises(DirtyProtocolError) as exc_info:
|
||||
BinaryProtocol.decode_header(header)
|
||||
assert "short" in str(exc_info.value).lower()
|
||||
|
||||
|
||||
class TestBinaryProtocolEncodeDecode:
|
||||
"""Tests for message encoding/decoding."""
|
||||
|
||||
def test_encode_decode_request(self):
|
||||
"""Test request encoding/decoding roundtrip."""
|
||||
encoded = BinaryProtocol.encode_request(
|
||||
request_id=12345,
|
||||
app_path="myapp.ml:MLApp",
|
||||
action="predict",
|
||||
args=("data",),
|
||||
kwargs={"temperature": 0.7}
|
||||
)
|
||||
assert len(encoded) > HEADER_SIZE
|
||||
|
||||
msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
|
||||
assert msg_type_str == "request"
|
||||
assert request_id == 12345
|
||||
assert payload["app_path"] == "myapp.ml:MLApp"
|
||||
assert payload["action"] == "predict"
|
||||
assert payload["args"] == ["data"]
|
||||
assert payload["kwargs"] == {"temperature": 0.7}
|
||||
|
||||
def test_encode_decode_response(self):
|
||||
"""Test response encoding/decoding roundtrip."""
|
||||
result = {"predictions": [0.1, 0.9], "metadata": {"model": "v1"}}
|
||||
encoded = BinaryProtocol.encode_response(request_id=67890, result=result)
|
||||
|
||||
msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
|
||||
assert msg_type_str == "response"
|
||||
assert request_id == 67890
|
||||
assert payload["result"] == result
|
||||
|
||||
def test_encode_decode_error(self):
|
||||
"""Test error encoding/decoding roundtrip."""
|
||||
error = DirtyTimeoutError("Timed out", timeout=30)
|
||||
encoded = BinaryProtocol.encode_error(request_id=11111, error=error)
|
||||
|
||||
msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
|
||||
assert msg_type_str == "error"
|
||||
assert request_id == 11111
|
||||
assert payload["error"]["error_type"] == "DirtyTimeoutError"
|
||||
assert "Timed out" in payload["error"]["message"]
|
||||
|
||||
def test_encode_decode_chunk(self):
|
||||
"""Test chunk encoding/decoding roundtrip."""
|
||||
chunk_data = {"token": "hello", "index": 5}
|
||||
encoded = BinaryProtocol.encode_chunk(request_id=22222, data=chunk_data)
|
||||
|
||||
msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
|
||||
assert msg_type_str == "chunk"
|
||||
assert request_id == 22222
|
||||
assert payload["data"] == chunk_data
|
||||
|
||||
def test_encode_decode_end(self):
|
||||
"""Test end message encoding/decoding roundtrip."""
|
||||
encoded = BinaryProtocol.encode_end(request_id=33333)
|
||||
assert len(encoded) == HEADER_SIZE # End has no payload
|
||||
|
||||
msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
|
||||
assert msg_type_str == "end"
|
||||
assert request_id == 33333
|
||||
assert payload == {}
|
||||
|
||||
def test_encode_decode_binary_data(self):
|
||||
"""Test binary data passes through without base64 encoding."""
|
||||
binary_data = bytes(range(256))
|
||||
encoded = BinaryProtocol.encode_response(
|
||||
request_id=44444,
|
||||
result={"data": binary_data}
|
||||
)
|
||||
|
||||
msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
|
||||
assert payload["result"]["data"] == binary_data
|
||||
|
||||
def test_encode_decode_large_message(self):
|
||||
"""Test encoding a large message."""
|
||||
large_data = "x" * (1024 * 1024) # 1 MB of data
|
||||
message = {"type": "request", "data": large_data}
|
||||
encoded = DirtyProtocol.encode(message)
|
||||
payload = encoded[DirtyProtocol.HEADER_SIZE:]
|
||||
decoded = DirtyProtocol.decode(payload)
|
||||
assert decoded == message
|
||||
large_data = b"x" * (1024 * 1024) # 1 MB
|
||||
encoded = BinaryProtocol.encode_response(
|
||||
request_id=55555,
|
||||
result={"data": large_data}
|
||||
)
|
||||
|
||||
def test_encode_empty_dict(self):
|
||||
"""Test encoding an empty dictionary."""
|
||||
message = {}
|
||||
encoded = DirtyProtocol.encode(message)
|
||||
payload = encoded[DirtyProtocol.HEADER_SIZE:]
|
||||
decoded = DirtyProtocol.decode(payload)
|
||||
assert decoded == message
|
||||
|
||||
def test_encode_message_too_large(self):
|
||||
"""Test that encoding a message that's too large raises error."""
|
||||
large_data = "x" * (DirtyProtocol.MAX_MESSAGE_SIZE + 1000)
|
||||
message = {"data": large_data}
|
||||
with pytest.raises(DirtyProtocolError) as exc_info:
|
||||
DirtyProtocol.encode(message)
|
||||
assert "too large" in str(exc_info.value)
|
||||
|
||||
def test_encode_non_serializable(self):
|
||||
"""Test that encoding non-JSON-serializable data raises error."""
|
||||
message = {"func": lambda x: x}
|
||||
with pytest.raises(DirtyProtocolError) as exc_info:
|
||||
DirtyProtocol.encode(message)
|
||||
assert "Failed to encode" in str(exc_info.value)
|
||||
|
||||
def test_decode_invalid_json(self):
|
||||
"""Test decoding invalid JSON raises error."""
|
||||
invalid_data = b"not valid json"
|
||||
with pytest.raises(DirtyProtocolError) as exc_info:
|
||||
DirtyProtocol.decode(invalid_data)
|
||||
assert "Failed to decode" in str(exc_info.value)
|
||||
|
||||
def test_decode_invalid_unicode(self):
|
||||
"""Test decoding invalid unicode raises error."""
|
||||
invalid_data = b"\x80\x81\x82"
|
||||
with pytest.raises(DirtyProtocolError) as exc_info:
|
||||
DirtyProtocol.decode(invalid_data)
|
||||
assert "Failed to decode" in str(exc_info.value)
|
||||
msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
|
||||
assert payload["result"]["data"] == large_data
|
||||
|
||||
|
||||
class TestDirtyProtocolSync:
|
||||
class TestBinaryProtocolSync:
|
||||
"""Tests for synchronous socket operations."""
|
||||
|
||||
def test_read_write_message(self):
|
||||
"""Test read/write through socket pair."""
|
||||
# Create a socket pair for testing
|
||||
server_sock, client_sock = socket.socketpair()
|
||||
try:
|
||||
message = {"type": "request", "id": "123", "action": "test"}
|
||||
message = make_request(
|
||||
request_id=12345,
|
||||
app_path="test:App",
|
||||
action="run"
|
||||
)
|
||||
|
||||
# Write message
|
||||
DirtyProtocol.write_message(client_sock, message)
|
||||
BinaryProtocol.write_message(client_sock, message)
|
||||
received = BinaryProtocol.read_message(server_sock)
|
||||
|
||||
# Read message
|
||||
received = DirtyProtocol.read_message(server_sock)
|
||||
assert received == message
|
||||
assert received["type"] == "request"
|
||||
assert received["id"] == hash("12345") & 0xFFFFFFFFFFFFFFFF or \
|
||||
received["id"] == 12345
|
||||
assert received["app_path"] == "test:App"
|
||||
assert received["action"] == "run"
|
||||
finally:
|
||||
server_sock.close()
|
||||
client_sock.close()
|
||||
|
||||
def test_read_write_with_int_id(self):
|
||||
"""Test read/write with integer request ID."""
|
||||
server_sock, client_sock = socket.socketpair()
|
||||
try:
|
||||
message = {
|
||||
"type": "request",
|
||||
"id": 999888777,
|
||||
"app_path": "test:App",
|
||||
"action": "run",
|
||||
"args": [],
|
||||
"kwargs": {}
|
||||
}
|
||||
|
||||
BinaryProtocol.write_message(client_sock, message)
|
||||
received = BinaryProtocol.read_message(server_sock)
|
||||
|
||||
assert received["id"] == 999888777
|
||||
finally:
|
||||
server_sock.close()
|
||||
client_sock.close()
|
||||
@ -147,19 +234,17 @@ class TestDirtyProtocolSync:
|
||||
server_sock, client_sock = socket.socketpair()
|
||||
try:
|
||||
messages = [
|
||||
{"type": "request", "id": "1"},
|
||||
{"type": "request", "id": "2"},
|
||||
{"type": "request", "id": "3"},
|
||||
make_request(i, f"app{i}:App", f"action{i}")
|
||||
for i in range(1, 4)
|
||||
]
|
||||
|
||||
# Write all messages
|
||||
for msg in messages:
|
||||
DirtyProtocol.write_message(client_sock, msg)
|
||||
BinaryProtocol.write_message(client_sock, msg)
|
||||
|
||||
# Read all messages
|
||||
for expected in messages:
|
||||
received = DirtyProtocol.read_message(server_sock)
|
||||
assert received == expected
|
||||
for i, _ in enumerate(messages, 1):
|
||||
received = BinaryProtocol.read_message(server_sock)
|
||||
assert received["app_path"] == f"app{i}:App"
|
||||
assert received["action"] == f"action{i}"
|
||||
finally:
|
||||
server_sock.close()
|
||||
client_sock.close()
|
||||
@ -169,39 +254,51 @@ class TestDirtyProtocolSync:
|
||||
server_sock, client_sock = socket.socketpair()
|
||||
client_sock.close()
|
||||
with pytest.raises(DirtyProtocolError) as exc_info:
|
||||
DirtyProtocol.read_message(server_sock)
|
||||
BinaryProtocol.read_message(server_sock)
|
||||
assert "closed" in str(exc_info.value).lower()
|
||||
server_sock.close()
|
||||
|
||||
def test_binary_data_roundtrip(self):
|
||||
"""Test binary data roundtrip through socket."""
|
||||
server_sock, client_sock = socket.socketpair()
|
||||
try:
|
||||
binary_payload = b"\x00\x01\x02\xff\xfe\xfd"
|
||||
message = make_response(12345, {"binary": binary_payload})
|
||||
|
||||
class TestDirtyProtocolAsync:
|
||||
BinaryProtocol.write_message(client_sock, message)
|
||||
received = BinaryProtocol.read_message(server_sock)
|
||||
|
||||
assert received["result"]["binary"] == binary_payload
|
||||
finally:
|
||||
server_sock.close()
|
||||
client_sock.close()
|
||||
|
||||
|
||||
class TestBinaryProtocolAsync:
|
||||
"""Tests for async stream operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_read_write(self):
|
||||
"""Test async read/write with mock streams."""
|
||||
message = {"type": "request", "id": "123"}
|
||||
message = make_request(12345, "test:App", "run")
|
||||
|
||||
# Create a pipe for testing
|
||||
read_fd, write_fd = os.pipe()
|
||||
try:
|
||||
reader = asyncio.StreamReader()
|
||||
_ = asyncio.StreamReaderProtocol(reader)
|
||||
|
||||
# Write the message to the pipe
|
||||
encoded = DirtyProtocol.encode(message)
|
||||
encoded = BinaryProtocol._encode_from_dict(message)
|
||||
os.write(write_fd, encoded)
|
||||
os.close(write_fd)
|
||||
write_fd = None
|
||||
|
||||
# Feed data to reader
|
||||
data = os.read(read_fd, len(encoded))
|
||||
reader.feed_data(data)
|
||||
reader.feed_eof()
|
||||
|
||||
# Read the message
|
||||
received = await DirtyProtocol.read_message_async(reader)
|
||||
assert received == message
|
||||
received = await BinaryProtocol.read_message_async(reader)
|
||||
assert received["type"] == "request"
|
||||
assert received["app_path"] == "test:App"
|
||||
finally:
|
||||
if write_fd is not None:
|
||||
os.close(write_fd)
|
||||
@ -211,12 +308,11 @@ class TestDirtyProtocolAsync:
|
||||
async def test_async_read_incomplete_header(self):
|
||||
"""Test async read with incomplete header."""
|
||||
reader = asyncio.StreamReader()
|
||||
# Feed only 2 bytes instead of 4
|
||||
reader.feed_data(b"\x00\x00")
|
||||
reader.feed_data(MAGIC + b"\x01") # Only 3 bytes
|
||||
reader.feed_eof()
|
||||
|
||||
with pytest.raises((asyncio.IncompleteReadError, DirtyProtocolError)):
|
||||
await DirtyProtocol.read_message_async(reader)
|
||||
await BinaryProtocol.read_message_async(reader)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_read_empty_connection(self):
|
||||
@ -225,36 +321,33 @@ class TestDirtyProtocolAsync:
|
||||
reader.feed_eof()
|
||||
|
||||
with pytest.raises(asyncio.IncompleteReadError):
|
||||
await DirtyProtocol.read_message_async(reader)
|
||||
await BinaryProtocol.read_message_async(reader)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_read_invalid_magic(self):
|
||||
"""Test async read rejects invalid magic."""
|
||||
reader = asyncio.StreamReader()
|
||||
header = b"XX" + bytes([VERSION, MSG_TYPE_REQUEST]) + b"\x00" * 12
|
||||
reader.feed_data(header)
|
||||
reader.feed_eof()
|
||||
|
||||
with pytest.raises(DirtyProtocolError) as exc_info:
|
||||
await BinaryProtocol.read_message_async(reader)
|
||||
assert "magic" in str(exc_info.value).lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_read_message_too_large(self):
|
||||
"""Test async read rejects too-large messages."""
|
||||
reader = asyncio.StreamReader()
|
||||
# Create a header claiming an absurdly large message
|
||||
header = struct.pack(
|
||||
DirtyProtocol.HEADER_FORMAT,
|
||||
DirtyProtocol.MAX_MESSAGE_SIZE + 1000
|
||||
)
|
||||
header = struct.pack(HEADER_FORMAT, MAGIC, VERSION, MSG_TYPE_REQUEST,
|
||||
MAX_MESSAGE_SIZE + 1000, 0)
|
||||
reader.feed_data(header)
|
||||
reader.feed_eof()
|
||||
|
||||
with pytest.raises(DirtyProtocolError) as exc_info:
|
||||
await DirtyProtocol.read_message_async(reader)
|
||||
await BinaryProtocol.read_message_async(reader)
|
||||
assert "too large" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_read_empty_message(self):
|
||||
"""Test async read rejects empty messages."""
|
||||
reader = asyncio.StreamReader()
|
||||
header = struct.pack(DirtyProtocol.HEADER_FORMAT, 0)
|
||||
reader.feed_data(header)
|
||||
reader.feed_eof()
|
||||
|
||||
with pytest.raises(DirtyProtocolError) as exc_info:
|
||||
await DirtyProtocol.read_message_async(reader)
|
||||
assert "Empty message" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestMessageBuilders:
|
||||
"""Tests for message builder helper functions."""
|
||||
@ -340,9 +433,9 @@ class TestMessageBuilders:
|
||||
assert chunk["id"] == "req-456"
|
||||
assert chunk["data"] == data
|
||||
|
||||
def test_make_chunk_message_with_list_data(self):
|
||||
"""Test chunk message with list data."""
|
||||
data = [1, 2, 3, "token"]
|
||||
def test_make_chunk_message_with_binary_data(self):
|
||||
"""Test chunk message with binary data."""
|
||||
data = b"\x00\x01\x02\xff"
|
||||
chunk = make_chunk_message("req-789", data)
|
||||
assert chunk["data"] == data
|
||||
|
||||
@ -353,22 +446,22 @@ class TestMessageBuilders:
|
||||
assert end["id"] == "req-123"
|
||||
assert "data" not in end
|
||||
|
||||
def test_chunk_and_end_encode_decode(self):
|
||||
def test_chunk_and_end_roundtrip(self):
|
||||
"""Test that chunk and end messages can be encoded/decoded."""
|
||||
chunk = make_chunk_message("req-123", {"token": "hello"})
|
||||
end = make_end_message("req-123")
|
||||
chunk = make_chunk_message(12345, {"token": "hello"})
|
||||
end = make_end_message(12345)
|
||||
|
||||
# Test chunk roundtrip
|
||||
encoded_chunk = DirtyProtocol.encode(chunk)
|
||||
payload = encoded_chunk[DirtyProtocol.HEADER_SIZE:]
|
||||
decoded = DirtyProtocol.decode(payload)
|
||||
assert decoded == chunk
|
||||
encoded_chunk = BinaryProtocol._encode_from_dict(chunk)
|
||||
msg_type, req_id, payload = BinaryProtocol.decode_message(encoded_chunk)
|
||||
assert msg_type == "chunk"
|
||||
assert payload["data"] == {"token": "hello"}
|
||||
|
||||
# Test end roundtrip
|
||||
encoded_end = DirtyProtocol.encode(end)
|
||||
payload = encoded_end[DirtyProtocol.HEADER_SIZE:]
|
||||
decoded = DirtyProtocol.decode(payload)
|
||||
assert decoded == end
|
||||
encoded_end = BinaryProtocol._encode_from_dict(end)
|
||||
msg_type, req_id, payload = BinaryProtocol.decode_message(encoded_end)
|
||||
assert msg_type == "end"
|
||||
assert payload == {}
|
||||
|
||||
|
||||
class TestDirtyErrors:
|
||||
@ -417,3 +510,58 @@ class TestDirtyErrors:
|
||||
assert error.action == "run"
|
||||
assert error.traceback == "Traceback..."
|
||||
assert "myapp:App" in str(error)
|
||||
|
||||
|
||||
class TestBackwardsCompatibility:
|
||||
"""Tests for backwards compatibility with old JSON API."""
|
||||
|
||||
def test_dirty_protocol_alias(self):
|
||||
"""Test that DirtyProtocol is an alias for BinaryProtocol."""
|
||||
assert DirtyProtocol is BinaryProtocol
|
||||
|
||||
def test_header_size_attribute(self):
|
||||
"""Test HEADER_SIZE is accessible on class."""
|
||||
assert DirtyProtocol.HEADER_SIZE == 16
|
||||
|
||||
def test_msg_type_constants(self):
|
||||
"""Test message type constants are strings for compatibility."""
|
||||
assert DirtyProtocol.MSG_TYPE_REQUEST == "request"
|
||||
assert DirtyProtocol.MSG_TYPE_RESPONSE == "response"
|
||||
assert DirtyProtocol.MSG_TYPE_ERROR == "error"
|
||||
assert DirtyProtocol.MSG_TYPE_CHUNK == "chunk"
|
||||
assert DirtyProtocol.MSG_TYPE_END == "end"
|
||||
|
||||
def test_encode_decode_preserves_dict_format(self):
|
||||
"""Test that read_message returns dict compatible with old API."""
|
||||
server_sock, client_sock = socket.socketpair()
|
||||
try:
|
||||
message = {
|
||||
"type": "response",
|
||||
"id": 12345,
|
||||
"result": {"status": "ok"}
|
||||
}
|
||||
|
||||
DirtyProtocol.write_message(client_sock, message)
|
||||
received = DirtyProtocol.read_message(server_sock)
|
||||
|
||||
# Old API: access via dict keys
|
||||
assert received["type"] == "response"
|
||||
assert received["result"]["status"] == "ok"
|
||||
finally:
|
||||
server_sock.close()
|
||||
client_sock.close()
|
||||
|
||||
def test_string_request_id_handled(self):
|
||||
"""Test that string request IDs are handled (hashed to int)."""
|
||||
server_sock, client_sock = socket.socketpair()
|
||||
try:
|
||||
message = make_request("uuid-string-id", "test:App", "run")
|
||||
|
||||
DirtyProtocol.write_message(client_sock, message)
|
||||
received = DirtyProtocol.read_message(server_sock)
|
||||
|
||||
# Request ID should be converted to int
|
||||
assert isinstance(received["id"], int)
|
||||
finally:
|
||||
server_sock.close()
|
||||
client_sock.close()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user