mirror of
https://github.com/frappe/gunicorn.git
synced 2026-07-01 10:11:30 +08:00
Replace JSON-based protocol with binary format using 16-byte header: - Magic bytes (GD), version, message type, payload length, request ID - TLV-encoded payloads for efficient binary data transfer - No base64 encoding needed for binary data - Backwards compatible API (DirtyProtocol alias, dict-based interface) Header format inspired by OpenBSD msgctl/msgsnd.
568 lines
20 KiB
Python
568 lines
20 KiB
Python
#
|
|
# This file is part of gunicorn released under the MIT license.
|
|
# See the NOTICE for more information.
|
|
|
|
"""Tests for dirty worker binary protocol module."""
|
|
|
|
import asyncio
|
|
import os
|
|
import socket
|
|
import struct
|
|
import pytest
|
|
|
|
from gunicorn.dirty.protocol import (
|
|
BinaryProtocol,
|
|
DirtyProtocol,
|
|
make_request,
|
|
make_response,
|
|
make_error_response,
|
|
make_chunk_message,
|
|
make_end_message,
|
|
MAGIC,
|
|
VERSION,
|
|
HEADER_SIZE,
|
|
HEADER_FORMAT,
|
|
MSG_TYPE_REQUEST,
|
|
MSG_TYPE_RESPONSE,
|
|
MSG_TYPE_ERROR,
|
|
MSG_TYPE_CHUNK,
|
|
MSG_TYPE_END,
|
|
MAX_MESSAGE_SIZE,
|
|
)
|
|
from gunicorn.dirty.errors import (
|
|
DirtyError,
|
|
DirtyProtocolError,
|
|
DirtyTimeoutError,
|
|
DirtyAppError,
|
|
)
|
|
|
|
|
|
class TestBinaryProtocolHeader:
    """Tests for encoding and decoding the fixed-size binary header."""

    def test_header_size(self):
        """The header occupies exactly 16 bytes."""
        assert HEADER_SIZE == 16

    def test_encode_header(self):
        """encode_header writes every header field, not only the prefix."""
        header = BinaryProtocol.encode_header(MSG_TYPE_REQUEST, 12345, 100)
        assert len(header) == HEADER_SIZE
        assert header[:2] == MAGIC
        assert header[2] == VERSION
        assert header[3] == MSG_TYPE_REQUEST
        # Also verify the length and request id that were passed in. The
        # field order (magic, version, type, length, request id) is the one
        # used by the struct.pack(HEADER_FORMAT, ...) call in
        # test_decode_header_too_large below.
        fields = struct.unpack(HEADER_FORMAT, header)
        assert fields == (MAGIC, VERSION, MSG_TYPE_REQUEST, 100, 12345)

    def test_decode_header(self):
        """decode_header returns (msg_type, request_id, length)."""
        header = BinaryProtocol.encode_header(MSG_TYPE_RESPONSE, 67890, 200)
        msg_type, request_id, length = BinaryProtocol.decode_header(header)
        assert msg_type == MSG_TYPE_RESPONSE
        assert request_id == 67890
        assert length == 200

    def test_decode_header_invalid_magic(self):
        """A header that does not start with MAGIC is rejected."""
        header = b"XX" + b"\x01\x01" + b"\x00" * 12
        with pytest.raises(DirtyProtocolError) as exc_info:
            BinaryProtocol.decode_header(header)
        assert "magic" in str(exc_info.value).lower()

    def test_decode_header_invalid_version(self):
        """A header carrying an unexpected version byte is rejected."""
        header = MAGIC + b"\x99\x01" + b"\x00" * 12
        with pytest.raises(DirtyProtocolError) as exc_info:
            BinaryProtocol.decode_header(header)
        assert "version" in str(exc_info.value).lower()

    def test_decode_header_invalid_type(self):
        """A header carrying an unknown message type byte is rejected."""
        header = MAGIC + bytes([VERSION, 0xFF]) + b"\x00" * 12
        with pytest.raises(DirtyProtocolError) as exc_info:
            BinaryProtocol.decode_header(header)
        assert "type" in str(exc_info.value).lower()

    def test_decode_header_too_large(self):
        """A declared payload length above MAX_MESSAGE_SIZE is rejected."""
        header = struct.pack(HEADER_FORMAT, MAGIC, VERSION, MSG_TYPE_REQUEST,
                             MAX_MESSAGE_SIZE + 1, 0)
        with pytest.raises(DirtyProtocolError) as exc_info:
            BinaryProtocol.decode_header(header)
        assert "too large" in str(exc_info.value).lower()

    def test_decode_header_too_short(self):
        """Fewer bytes than a full header are rejected."""
        header = MAGIC + b"\x01"
        with pytest.raises(DirtyProtocolError) as exc_info:
            BinaryProtocol.decode_header(header)
        assert "short" in str(exc_info.value).lower()
|
|
|
|
|
|
class TestBinaryProtocolEncodeDecode:
    """Roundtrip tests for whole-message encoding and decoding."""

    def test_encode_decode_request(self):
        """A request survives an encode/decode roundtrip."""
        encoded = BinaryProtocol.encode_request(
            request_id=12345,
            app_path="myapp.ml:MLApp",
            action="predict",
            args=("data",),
            kwargs={"temperature": 0.7}
        )
        assert len(encoded) > HEADER_SIZE

        msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
        assert msg_type_str == "request"
        assert request_id == 12345
        assert payload["app_path"] == "myapp.ml:MLApp"
        assert payload["action"] == "predict"
        # The tuple passed as ``args`` is expected back as a list.
        assert payload["args"] == ["data"]
        assert payload["kwargs"] == {"temperature": 0.7}

    def test_encode_decode_response(self):
        """A response with a nested result survives a roundtrip."""
        result = {"predictions": [0.1, 0.9], "metadata": {"model": "v1"}}
        encoded = BinaryProtocol.encode_response(request_id=67890, result=result)

        msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
        assert msg_type_str == "response"
        assert request_id == 67890
        assert payload["result"] == result

    def test_encode_decode_error(self):
        """An error message carries the exception's type name and message."""
        error = DirtyTimeoutError("Timed out", timeout=30)
        encoded = BinaryProtocol.encode_error(request_id=11111, error=error)

        msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
        assert msg_type_str == "error"
        assert request_id == 11111
        assert payload["error"]["error_type"] == "DirtyTimeoutError"
        assert "Timed out" in payload["error"]["message"]

    def test_encode_decode_chunk(self):
        """A chunk message survives a roundtrip."""
        chunk_data = {"token": "hello", "index": 5}
        encoded = BinaryProtocol.encode_chunk(request_id=22222, data=chunk_data)

        msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
        assert msg_type_str == "chunk"
        assert request_id == 22222
        assert payload["data"] == chunk_data

    def test_encode_decode_end(self):
        """An end message is header-only and decodes to an empty payload."""
        encoded = BinaryProtocol.encode_end(request_id=33333)
        assert len(encoded) == HEADER_SIZE  # End has no payload

        msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
        assert msg_type_str == "end"
        assert request_id == 33333
        assert payload == {}

    def test_encode_decode_binary_data(self):
        """Raw bytes covering every byte value roundtrip unchanged."""
        binary_data = bytes(range(256))
        encoded = BinaryProtocol.encode_response(
            request_id=44444,
            result={"data": binary_data}
        )

        msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
        # Check the envelope too, so the unpacked fields are not left unused.
        assert msg_type_str == "response"
        assert request_id == 44444
        assert payload["result"]["data"] == binary_data

    def test_encode_decode_large_message(self):
        """A 1 MB bytes payload roundtrips unchanged."""
        large_data = b"x" * (1024 * 1024)  # 1 MB
        encoded = BinaryProtocol.encode_response(
            request_id=55555,
            result={"data": large_data}
        )

        msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
        # Check the envelope too, so the unpacked fields are not left unused.
        assert msg_type_str == "response"
        assert request_id == 55555
        assert payload["result"]["data"] == large_data
|
|
|
|
|
|
class TestBinaryProtocolSync:
    """Tests for the blocking read/write helpers over real sockets."""

    def test_read_write_message(self):
        """A request written on one end of a socket pair is read on the other."""
        server_sock, client_sock = socket.socketpair()
        try:
            message = make_request(
                request_id=12345,
                app_path="test:App",
                action="run"
            )

            BinaryProtocol.write_message(client_sock, message)
            received = BinaryProtocol.read_message(server_sock)

            assert received["type"] == "request"
            # Accept either the integer id as-is or the masked hash of its
            # string form. This is the same condition as the original
            # ``== ... & ... or == ...`` expression, with the precedence made
            # explicit. NOTE(review): the hashed alternative presumably
            # covers an implementation that stringifies ids -- confirm.
            assert received["id"] in (
                hash("12345") & 0xFFFFFFFFFFFFFFFF,
                12345,
            )
            assert received["app_path"] == "test:App"
            assert received["action"] == "run"
        finally:
            server_sock.close()
            client_sock.close()

    def test_read_write_with_int_id(self):
        """An integer request id in a hand-built dict roundtrips unchanged."""
        server_sock, client_sock = socket.socketpair()
        try:
            message = {
                "type": "request",
                "id": 999888777,
                "app_path": "test:App",
                "action": "run",
                "args": [],
                "kwargs": {}
            }

            BinaryProtocol.write_message(client_sock, message)
            received = BinaryProtocol.read_message(server_sock)

            assert received["id"] == 999888777
        finally:
            server_sock.close()
            client_sock.close()

    def test_multiple_messages(self):
        """Several messages written back-to-back are read in the same order."""
        server_sock, client_sock = socket.socketpair()
        try:
            messages = [
                make_request(i, f"app{i}:App", f"action{i}")
                for i in range(1, 4)
            ]

            for msg in messages:
                BinaryProtocol.write_message(client_sock, msg)

            for i, _ in enumerate(messages, 1):
                received = BinaryProtocol.read_message(server_sock)
                assert received["app_path"] == f"app{i}:App"
                assert received["action"] == f"action{i}"
        finally:
            server_sock.close()
            client_sock.close()

    def test_read_connection_closed(self):
        """Reading after the peer has closed raises DirtyProtocolError."""
        server_sock, client_sock = socket.socketpair()
        try:
            client_sock.close()
            with pytest.raises(DirtyProtocolError) as exc_info:
                BinaryProtocol.read_message(server_sock)
            assert "closed" in str(exc_info.value).lower()
        finally:
            # Close in ``finally`` so a failing assertion cannot leak the
            # socket, matching the cleanup used by the sibling tests.
            server_sock.close()

    def test_binary_data_roundtrip(self):
        """Raw bytes inside a response survive a socket roundtrip."""
        server_sock, client_sock = socket.socketpair()
        try:
            binary_payload = b"\x00\x01\x02\xff\xfe\xfd"
            message = make_response(12345, {"binary": binary_payload})

            BinaryProtocol.write_message(client_sock, message)
            received = BinaryProtocol.read_message(server_sock)

            assert received["result"]["binary"] == binary_payload
        finally:
            server_sock.close()
            client_sock.close()
|
|
|
|
|
|
class TestBinaryProtocolAsync:
    """Tests for the asyncio stream reader helper."""

    @pytest.mark.asyncio
    async def test_async_read_write(self):
        """An encoded request fed into a StreamReader is decoded back."""
        message = make_request(12345, "test:App", "run")

        read_fd, write_fd = os.pipe()
        try:
            reader = asyncio.StreamReader()
            _ = asyncio.StreamReaderProtocol(reader)

            encoded = BinaryProtocol._encode_from_dict(message)
            os.write(write_fd, encoded)
            os.close(write_fd)
            # Mark the write end as already closed so ``finally`` does not
            # close it a second time.
            write_fd = None

            data = os.read(read_fd, len(encoded))
            reader.feed_data(data)
            reader.feed_eof()

            received = await BinaryProtocol.read_message_async(reader)
            assert received["type"] == "request"
            assert received["app_path"] == "test:App"
        finally:
            if write_fd is not None:
                os.close(write_fd)
            os.close(read_fd)

    @pytest.mark.asyncio
    async def test_async_read_incomplete_header(self):
        """EOF in the middle of the header raises an error."""
        reader = asyncio.StreamReader()
        reader.feed_data(MAGIC + b"\x01")  # Only 3 bytes
        reader.feed_eof()

        with pytest.raises((asyncio.IncompleteReadError, DirtyProtocolError)):
            await BinaryProtocol.read_message_async(reader)

    @pytest.mark.asyncio
    async def test_async_read_empty_connection(self):
        """EOF before any byte arrives raises IncompleteReadError."""
        reader = asyncio.StreamReader()
        reader.feed_eof()

        with pytest.raises(asyncio.IncompleteReadError):
            await BinaryProtocol.read_message_async(reader)

    @pytest.mark.asyncio
    async def test_async_read_invalid_magic(self):
        """A header with the wrong magic bytes is rejected."""
        reader = asyncio.StreamReader()
        header = b"XX" + bytes([VERSION, MSG_TYPE_REQUEST]) + b"\x00" * 12
        reader.feed_data(header)
        reader.feed_eof()

        with pytest.raises(DirtyProtocolError) as exc_info:
            await BinaryProtocol.read_message_async(reader)
        assert "magic" in str(exc_info.value).lower()

    @pytest.mark.asyncio
    async def test_async_read_message_too_large(self):
        """A declared length above MAX_MESSAGE_SIZE is rejected."""
        reader = asyncio.StreamReader()
        header = struct.pack(HEADER_FORMAT, MAGIC, VERSION, MSG_TYPE_REQUEST,
                             MAX_MESSAGE_SIZE + 1000, 0)
        reader.feed_data(header)
        reader.feed_eof()

        with pytest.raises(DirtyProtocolError) as exc_info:
            await BinaryProtocol.read_message_async(reader)
        # Case-insensitive match, consistent with the synchronous
        # too-large header test.
        assert "too large" in str(exc_info.value).lower()
|
|
|
|
|
|
class TestMessageBuilders:
    """Checks for the dict-building helpers (make_request and friends)."""

    def test_make_request(self):
        """make_request fills in every field of the request dict."""
        built = make_request(
            request_id="abc123",
            app_path="myapp.ml:MLApp",
            action="inference",
            args=("model1",),
            kwargs={"temperature": 0.7},
        )
        assert built["type"] == DirtyProtocol.MSG_TYPE_REQUEST
        assert built["id"] == "abc123"
        assert built["app_path"] == "myapp.ml:MLApp"
        assert built["action"] == "inference"
        # Positional args come back as a list, not the tuple passed in.
        assert built["args"] == ["model1"]
        assert built["kwargs"] == {"temperature": 0.7}

    def test_make_request_minimal(self):
        """Omitted args/kwargs default to an empty list and dict."""
        built = make_request(request_id="abc", app_path="app:App", action="run")
        assert built["args"] == []
        assert built["kwargs"] == {}

    def test_make_response(self):
        """make_response stores the id and the result payload."""
        expected_result = {"status": "ok", "data": [1, 2, 3]}
        built = make_response(
            request_id="abc123",
            result={"status": "ok", "data": [1, 2, 3]},
        )
        assert built["type"] == DirtyProtocol.MSG_TYPE_RESPONSE
        assert built["id"] == "abc123"
        assert built["result"] == expected_result

    def test_make_error_response_with_exception(self):
        """A DirtyError instance is serialized with type, message and details."""
        exc = DirtyTimeoutError("Operation timed out", timeout=30)
        built = make_error_response("abc123", exc)

        assert built["type"] == DirtyProtocol.MSG_TYPE_ERROR
        assert built["id"] == "abc123"
        serialized = built["error"]
        assert serialized["error_type"] == "DirtyTimeoutError"
        assert serialized["message"] == "Operation timed out"
        assert serialized["details"]["timeout"] == 30

    def test_make_error_response_with_dict(self):
        """An error that is already a dict is stored as-is."""
        error_dict = {
            "error_type": "CustomError",
            "message": "Something went wrong",
            "details": {"code": 500},
        }
        built = make_error_response("abc123", error_dict)

        assert built["error"] == error_dict

    def test_make_error_response_with_generic_exception(self):
        """A non-Dirty exception is reported by class name and message."""
        exc = ValueError("Invalid value")
        built = make_error_response("abc123", exc)

        serialized = built["error"]
        assert serialized["error_type"] == "ValueError"
        assert serialized["message"] == "Invalid value"

    def test_make_chunk_message(self):
        """make_chunk_message wraps a plain string payload."""
        built = make_chunk_message("req-123", "Hello, ")
        assert built["type"] == DirtyProtocol.MSG_TYPE_CHUNK
        assert built["id"] == "req-123"
        assert built["data"] == "Hello, "

    def test_make_chunk_message_with_complex_data(self):
        """make_chunk_message accepts a dict payload."""
        body = {"token": "world", "score": 0.95, "index": 5}
        built = make_chunk_message("req-456", body)
        assert built["type"] == DirtyProtocol.MSG_TYPE_CHUNK
        assert built["id"] == "req-456"
        assert built["data"] == body

    def test_make_chunk_message_with_binary_data(self):
        """make_chunk_message accepts a bytes payload."""
        body = b"\x00\x01\x02\xff"
        built = make_chunk_message("req-789", body)
        assert built["data"] == body

    def test_make_end_message(self):
        """make_end_message carries a type and id but no data field."""
        built = make_end_message("req-123")
        assert built["type"] == DirtyProtocol.MSG_TYPE_END
        assert built["id"] == "req-123"
        assert "data" not in built

    def test_chunk_and_end_roundtrip(self):
        """Builder output for chunk and end messages encodes and decodes."""
        chunk_msg = make_chunk_message(12345, {"token": "hello"})
        end_msg = make_end_message(12345)

        # Chunk: the payload must come back intact.
        wire = BinaryProtocol._encode_from_dict(chunk_msg)
        kind, _, body = BinaryProtocol.decode_message(wire)
        assert kind == "chunk"
        assert body["data"] == {"token": "hello"}

        # End: decodes to an empty payload.
        wire = BinaryProtocol._encode_from_dict(end_msg)
        kind, _, body = BinaryProtocol.decode_message(wire)
        assert kind == "end"
        assert body == {}
|
|
|
|
|
|
class TestDirtyErrors:
    """Checks for the DirtyError hierarchy and its dict (de)serialization."""

    def test_dirty_error_to_dict(self):
        """to_dict exposes the error type name, message and details."""
        serialized = DirtyError("Test error", {"key": "value"}).to_dict()
        assert serialized["error_type"] == "DirtyError"
        assert serialized["message"] == "Test error"
        assert serialized["details"] == {"key": "value"}

    def test_dirty_error_from_dict(self):
        """from_dict rebuilds the subclass named by ``error_type``."""
        restored = DirtyError.from_dict({
            "error_type": "DirtyTimeoutError",
            "message": "Timed out",
            "details": {"timeout": 30},
        })
        assert isinstance(restored, DirtyTimeoutError)
        assert restored.message == "Timed out"
        assert restored.details["timeout"] == 30

    def test_dirty_error_from_dict_unknown_type(self):
        """An unrecognized ``error_type`` yields a DirtyError, not a subclass."""
        restored = DirtyError.from_dict({
            "error_type": "UnknownError",
            "message": "Unknown",
            "details": {},
        })
        assert isinstance(restored, DirtyError)
        assert not isinstance(restored, DirtyTimeoutError)

    def test_dirty_app_error(self):
        """DirtyAppError keeps app_path, action and traceback as attributes."""
        exc = DirtyAppError(
            "App failed",
            app_path="myapp:App",
            action="run",
            traceback="Traceback...",
        )
        assert exc.app_path == "myapp:App"
        assert exc.action == "run"
        assert exc.traceback == "Traceback..."
        # The app path also shows up in the string form of the error.
        assert "myapp:App" in str(exc)
|
|
|
|
|
|
class TestBackwardsCompatibility:
    """Checks that the dict-based API of the old JSON protocol still works."""

    def test_dirty_protocol_alias(self):
        """DirtyProtocol is the very same object as BinaryProtocol."""
        assert DirtyProtocol is BinaryProtocol

    def test_header_size_attribute(self):
        """HEADER_SIZE is readable as a class attribute."""
        assert DirtyProtocol.HEADER_SIZE == 16

    def test_msg_type_constants(self):
        """The class-level message type constants are plain strings."""
        expected = {
            "MSG_TYPE_REQUEST": "request",
            "MSG_TYPE_RESPONSE": "response",
            "MSG_TYPE_ERROR": "error",
            "MSG_TYPE_CHUNK": "chunk",
            "MSG_TYPE_END": "end",
        }
        for attr_name, value in expected.items():
            assert getattr(DirtyProtocol, attr_name) == value

    def test_encode_decode_preserves_dict_format(self):
        """read_message returns a dict that is indexed by the old keys."""
        server_sock, client_sock = socket.socketpair()
        try:
            outgoing = {
                "type": "response",
                "id": 12345,
                "result": {"status": "ok"},
            }

            DirtyProtocol.write_message(client_sock, outgoing)
            incoming = DirtyProtocol.read_message(server_sock)

            # Old API: fields are reached through plain dict keys.
            assert incoming["type"] == "response"
            assert incoming["result"]["status"] == "ok"
        finally:
            server_sock.close()
            client_sock.close()

    def test_string_request_id_handled(self):
        """A string request id is accepted and comes back as an int."""
        server_sock, client_sock = socket.socketpair()
        try:
            outgoing = make_request("uuid-string-id", "test:App", "run")

            DirtyProtocol.write_message(client_sock, outgoing)
            incoming = DirtyProtocol.read_message(server_sock)

            # The id no longer has to equal the original string; it only
            # has to be an integer after the roundtrip.
            assert isinstance(incoming["id"], int)
        finally:
            server_sock.close()
            client_sock.close()
|