Merge pull request #3500 from benoitc/feature/binary-dirty-protocol

feat(dirty): implement binary protocol for dirty worker IPC
This commit is contained in:
Benoit Chesneau 2026-02-12 17:26:10 +01:00 committed by GitHub
commit 4b90e4ba16
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 2217 additions and 511 deletions

View File

@ -17,7 +17,7 @@ Dirty Arbiters solve this by providing:
- **Separate worker pool** - Completely separate from HTTP workers, can be killed/restarted independently
- **Stateful workers** - Loaded resources persist in dirty worker memory
- **Message-passing IPC** - Communication via Unix sockets with JSON serialization
- **Message-passing IPC** - Communication via Unix sockets with binary TLV protocol
- **Explicit API** - Clear `execute()` calls (no hidden IPC)
- **Asyncio-based** - Clean concurrent handling with streaming support
@ -476,6 +476,64 @@ Worker -> Arbiter -> Client: chunk (data: "World")
Worker -> Arbiter -> Client: end
```
## Binary Protocol
The dirty worker IPC uses a binary protocol inspired by OpenBSD msgctl/msgsnd for efficient data transfer. This eliminates base64 encoding overhead for binary data like images, audio, or model weights.
### Header Format (16 bytes)
```
+--------+--------+--------+--------+--------+--------+--------+--------+
| Magic (2B) | Ver(1) | MType | Payload Length (4B) |
+--------+--------+--------+--------+--------+--------+--------+--------+
| Request ID (8 bytes) |
+--------+--------+--------+--------+--------+--------+--------+--------+
```
- **Magic**: `0x47 0x44` ("GD" for Gunicorn Dirty)
- **Version**: `0x01`
- **MType**: Message type (`0x01`=REQUEST, `0x02`=RESPONSE, `0x03`=ERROR, `0x04`=CHUNK, `0x05`=END)
- **Length**: Payload size (big-endian uint32, max 64MB)
- **Request ID**: uint64 identifier
### TLV Payload Encoding
Payloads use Type-Length-Value encoding:
| Type | Code | Description |
|------|------|-------------|
| None | `0x00` | No value bytes |
| Bool | `0x01` | 1 byte (0x00/0x01) |
| Int64 | `0x05` | 8 bytes big-endian signed |
| Float64 | `0x06` | 8 bytes IEEE 754 |
| Bytes | `0x10` | 4-byte length + raw bytes |
| String | `0x11` | 4-byte length + UTF-8 |
| List | `0x20` | 4-byte count + elements |
| Dict | `0x21` | 4-byte count + key-value pairs |
### Binary Data Benefits
The binary protocol allows passing raw bytes directly without encoding:
```python
# Image processing with binary data
def resize(self, image_data, width, height):
"""Resize an image - image_data is raw bytes."""
img = Image.open(io.BytesIO(image_data))
resized = img.resize((width, height))
buffer = io.BytesIO()
resized.save(buffer, format='PNG')
return buffer.getvalue() # Returns raw bytes
# Called from HTTP worker
thumbnail = client.execute(
"myapp.images:ImageApp",
"thumbnail",
raw_image_bytes, # No base64 encoding needed
size=256
)
```
### Error Handling in Streams
Errors during streaming are delivered as error messages:
@ -768,7 +826,7 @@ except DirtyConnectionError:
2. **Set appropriate timeouts** based on your workload
3. **Handle errors gracefully** - dirty workers may restart
4. **Use meaningful action names** for easier debugging
5. **Keep responses JSON-serializable** - results are passed via IPC
5. **Keep responses serializable** - results are passed via binary IPC (supports bytes directly)
## Monitoring

View File

@ -15,7 +15,7 @@ services:
context: ../.. # Gunicorn repo root
dockerfile: examples/celery_alternative/Dockerfile
ports:
- "8000:8000"
- "8003:8000"
environment:
- GUNICORN_WORKERS=4
- DIRTY_WORKERS=9

View File

@ -0,0 +1,25 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
FROM python:3.12-slim
WORKDIR /app
# Copy gunicorn source
COPY . /app/gunicorn-src
# Install gunicorn and dependencies
# setproctitle is needed for process title changes
RUN pip install --no-cache-dir /app/gunicorn-src setproctitle
# Copy example files
COPY examples/dirty_example/ /app/examples/dirty_example/
WORKDIR /app
# Expose the port
EXPOSE 8000
# Default command - run the example tests
CMD ["python", "-m", "pytest", "-v", "examples/dirty_example/"]

View File

@ -0,0 +1,54 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
services:
# Run the example tests (protocol, dirty app, worker integration)
tests:
build:
context: ../..
dockerfile: examples/dirty_example/Dockerfile
command: >
bash -c "
echo '=== Running Protocol Tests ===' &&
python examples/dirty_example/test_protocol.py &&
echo '' &&
echo '=== Running Dirty App Tests ===' &&
python examples/dirty_example/test_dirty_app.py &&
echo '' &&
echo '=== Running Worker Integration Tests ===' &&
python examples/dirty_example/test_worker_integration.py &&
echo '' &&
echo '=== All tests passed! ==='
"
# Run the full gunicorn server with dirty workers
server:
build:
context: ../..
dockerfile: examples/dirty_example/Dockerfile
ports:
- "8001:8000"
environment:
- GUNICORN_BIND=0.0.0.0:8000
command: >
gunicorn examples.dirty_example.wsgi_app:app
-c examples/dirty_example/gunicorn_conf.py
healthcheck:
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/')"]
interval: 5s
timeout: 5s
retries: 5
start_period: 10s
# Run integration test against the server
integration-test:
build:
context: ../..
dockerfile: examples/dirty_example/Dockerfile
depends_on:
server:
condition: service_healthy
environment:
- TEST_BASE_URL=http://server:8000
command: python examples/dirty_example/test_integration.py

View File

@ -11,7 +11,9 @@ Run with:
"""
# Basic settings
bind = "127.0.0.1:8000"
# Use 0.0.0.0 for Docker, override with GUNICORN_BIND env var if needed
import os
bind = os.environ.get("GUNICORN_BIND", "127.0.0.1:8000")
workers = 2
worker_class = "sync"
timeout = 30

View File

@ -0,0 +1,81 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
#!/usr/bin/env python
"""
Integration test for the dirty example server.
This tests that the full gunicorn server with dirty workers responds
correctly to HTTP requests.
Run with:
python examples/dirty_example/test_integration.py [base_url]
Default base_url is http://localhost:8000
"""
import sys
import os
import json
import urllib.request
import urllib.error
def test_endpoint(base, path, expected_key=None):
"""Test an endpoint and check for expected key in response."""
url = base + path
print(f"Testing: {url}")
try:
with urllib.request.urlopen(url, timeout=10) as resp:
data = json.loads(resp.read())
print(f" Response: {str(data)[:200]}")
if expected_key and expected_key not in data:
print(f" ERROR: Expected key '{expected_key}' not found!")
return False
return True
except urllib.error.HTTPError as e:
print(f" HTTP ERROR {e.code}: {e.reason}")
return False
except Exception as e:
print(f" ERROR: {e}")
return False
def main():
# Get base URL from env or command line
base = os.environ.get("TEST_BASE_URL", "http://localhost:8000")
if len(sys.argv) > 1:
base = sys.argv[1]
print(f"Testing dirty example server at: {base}")
print("=" * 60)
# Define tests: (path, expected_key_in_response)
tests = [
("/", "endpoints"),
("/models", "models"),
("/load?name=test-model", "status"),
("/inference?model=default&data=hello", "prediction"),
("/fibonacci?n=20", "result"),
("/prime?n=17", "is_prime"),
("/stats", "ml_app"),
("/unload?name=test-model", "status"),
]
failed = 0
for path, key in tests:
if not test_endpoint(base, path, key):
failed += 1
print()
print("=" * 60)
if failed:
print(f"FAILED: {failed} tests failed")
sys.exit(1)
else:
print("SUCCESS: All integration tests passed!")
if __name__ == "__main__":
main()

View File

@ -4,7 +4,10 @@
#!/usr/bin/env python
"""
Test script to demonstrate the Dirty Protocol layer.
Test script to demonstrate the Dirty Binary Protocol layer.
The binary protocol uses a 16-byte header + TLV-encoded payloads for efficient
binary data transfer without base64 encoding overhead.
Run with:
python examples/dirty_example/test_protocol.py
@ -18,10 +21,14 @@ import socket
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from gunicorn.dirty.protocol import (
BinaryProtocol,
DirtyProtocol,
make_request,
make_response,
make_error_response,
HEADER_SIZE,
MAGIC,
VERSION,
)
from gunicorn.dirty.errors import DirtyError, DirtyTimeoutError
@ -29,13 +36,13 @@ from gunicorn.dirty.errors import DirtyError, DirtyTimeoutError
def test_protocol_encode_decode():
"""Test protocol encoding and decoding."""
print("=" * 60)
print("Testing Protocol Encode/Decode")
print("Testing Binary Protocol Encode/Decode")
print("=" * 60)
# Test request
# Test request with integer ID (recommended for binary protocol)
print("\n1. Creating a request message...")
request = make_request(
request_id="req-001",
request_id=12345, # Integer IDs are efficient
app_path="myapp.ml:MLApp",
action="inference",
args=("model1",),
@ -43,18 +50,53 @@ def test_protocol_encode_decode():
)
print(f" Request: {request}")
# Encode
print("\n2. Encoding message...")
encoded = DirtyProtocol.encode(request)
# Encode using binary protocol
print("\n2. Encoding message with binary protocol...")
encoded = BinaryProtocol._encode_from_dict(request)
print(f" Encoded length: {len(encoded)} bytes")
print(f" Header (4 bytes): {encoded[:4].hex()}")
print(f" Header ({HEADER_SIZE} bytes): {encoded[:HEADER_SIZE].hex()}")
print(f" Magic: {MAGIC!r}")
print(f" Version: {VERSION}")
# Decode
print("\n3. Decoding payload...")
payload = encoded[DirtyProtocol.HEADER_SIZE:]
decoded = DirtyProtocol.decode(payload)
print(f" Decoded: {decoded}")
print(f" Match: {decoded == request}")
# Decode header
print("\n3. Decoding header...")
msg_type, request_id, payload_len = BinaryProtocol.decode_header(encoded[:HEADER_SIZE])
print(f" Message type: {msg_type} (0x{msg_type:02x})")
print(f" Request ID: {request_id}")
print(f" Payload length: {payload_len} bytes")
# Decode full message
print("\n4. Decoding full message...")
msg_type_str, req_id, payload = BinaryProtocol.decode_message(encoded)
print(f" Type: {msg_type_str}")
print(f" Request ID: {req_id}")
print(f" Payload: {payload}")
def test_binary_data_handling():
"""Test binary data handling - the main advantage of binary protocol."""
print("\n" + "=" * 60)
print("Testing Binary Data Handling")
print("=" * 60)
# Create binary data (e.g., image, audio, model weights)
binary_data = bytes(range(256)) # All byte values
print(f"\n1. Original binary data: {len(binary_data)} bytes")
print(f" First 16 bytes: {binary_data[:16].hex()}")
# Create response with binary data (no base64 needed!)
print("\n2. Encoding binary data in response...")
response = make_response(67890, {"image_data": binary_data, "format": "raw"})
encoded = BinaryProtocol._encode_from_dict(response)
print(f" Encoded total size: {len(encoded)} bytes")
# Decode and verify
print("\n3. Decoding binary data...")
msg_type_str, req_id, payload = BinaryProtocol.decode_message(encoded)
recovered_data = payload["result"]["image_data"]
print(f" Recovered data size: {len(recovered_data)} bytes")
print(f" Data matches: {recovered_data == binary_data}")
print(f" First 16 bytes: {recovered_data[:16].hex()}")
def test_protocol_response():
@ -65,13 +107,13 @@ def test_protocol_response():
# Success response
print("\n1. Creating success response...")
response = make_response("req-001", {"result": "Hello, World!", "confidence": 0.95})
response = make_response(12345, {"result": "Hello, World!", "confidence": 0.95})
print(f" Response: {response}")
# Error response
print("\n2. Creating error response...")
error = DirtyTimeoutError("Operation timed out", timeout=30)
error_response = make_error_response("req-001", error)
error_response = make_error_response(12345, error)
print(f" Error response: {error_response}")
@ -88,7 +130,7 @@ def test_socket_communication():
# Send a request
print("\n1. Sending request over socket...")
request = make_request(
request_id="socket-req-001",
request_id=100001,
app_path="test:App",
action="compute",
args=(1, 2, 3),
@ -101,19 +143,20 @@ def test_socket_communication():
print("\n2. Receiving request...")
received = DirtyProtocol.read_message(server_sock)
print(f" Received: {received}")
print(f" Match: {received == request}")
print(f" Request ID: {received['id']}")
# Send a response
print("\n3. Sending response...")
response = make_response("socket-req-001", {"sum": 6})
# Send a response with binary data
print("\n3. Sending response with binary data...")
binary_result = b"\x00\x01\x02\x03\xff\xfe\xfd\xfc"
response = make_response(100001, {"data": binary_result, "sum": 6})
DirtyProtocol.write_message(server_sock, response)
print(f" Sent: {response}")
print(f" Sent binary data: {binary_result.hex()}")
# Receive the response
print("\n4. Receiving response...")
received = DirtyProtocol.read_message(client_sock)
print(f" Received: {received}")
print(f" Match: {received == response}")
print(f" Received binary data: {received['result']['data'].hex()}")
print(f" Sum: {received['result']['sum']}")
finally:
server_sock.close()
@ -132,7 +175,7 @@ async def test_async_communication():
try:
# Create message
request = make_request(
request_id="async-req-001",
request_id=200001,
app_path="async:App",
action="process",
args=("data",),
@ -141,7 +184,7 @@ async def test_async_communication():
# Write to pipe
print("\n1. Writing async message...")
encoded = DirtyProtocol.encode(request)
encoded = BinaryProtocol._encode_from_dict(request)
os.write(write_fd, encoded)
os.close(write_fd)
write_fd = None
@ -156,7 +199,7 @@ async def test_async_communication():
received = await DirtyProtocol.read_message_async(reader)
print(f" Received: {received}")
print(f" Match: {received == request}")
print(f" Request ID: {received['id']}")
finally:
if write_fd is not None:
@ -193,10 +236,11 @@ def test_error_serialization():
if __name__ == "__main__":
print("\n" + "#" * 60)
print("# Dirty Protocol Demonstration")
print("# Dirty Binary Protocol Demonstration")
print("#" * 60)
test_protocol_encode_decode()
test_binary_data_handling()
test_protocol_response()
test_socket_communication()
asyncio.run(test_async_communication())

View File

@ -22,7 +22,7 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspa
from gunicorn.config import Config
from gunicorn.dirty.worker import DirtyWorker
from gunicorn.dirty.protocol import DirtyProtocol, make_request
from gunicorn.dirty.protocol import DirtyProtocol, BinaryProtocol, make_request, HEADER_SIZE
class MockLog:
@ -35,6 +35,36 @@ class MockLog:
def reopen_files(self): pass
class MockWriter:
"""Mock StreamWriter that captures written responses."""
def __init__(self):
self.messages = []
self._buffer = b""
def write(self, data):
self._buffer += data
async def drain(self):
# Decode messages from buffer using binary protocol
while len(self._buffer) >= HEADER_SIZE:
_, _, length = BinaryProtocol.decode_header(self._buffer[:HEADER_SIZE])
total_size = HEADER_SIZE + length
if len(self._buffer) >= total_size:
msg_data = self._buffer[:total_size]
self._buffer = self._buffer[total_size:]
msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data)
result = {"type": msg_type_str, "id": request_id}
result.update(payload_dict)
self.messages.append(result)
else:
break
def get_last_response(self):
"""Get the last response message."""
return self.messages[-1] if self.messages else None
async def test_worker_request_handling():
"""Test that a worker can load apps and handle requests."""
print("=" * 60)
@ -75,52 +105,60 @@ async def test_worker_request_handling():
# Test handle_request with a proper request message
print("\n3. Testing handle_request() - load_model...")
request = make_request(
request_id="test-001",
request_id=1001,
app_path="examples.dirty_example.dirty_app:MLApp",
action="load_model",
args=("gpt-4",),
kwargs={}
)
response = await worker.handle_request(request)
writer = MockWriter()
await worker.handle_request(request, writer)
response = writer.get_last_response()
print(f" Response type: {response['type']}")
print(f" Result: {response.get('result', response.get('error'))}")
# Test inference
print("\n4. Testing handle_request() - inference...")
request = make_request(
request_id="test-002",
request_id=1002,
app_path="examples.dirty_example.dirty_app:MLApp",
action="inference",
args=("default", "Hello AI!"),
kwargs={}
)
response = await worker.handle_request(request)
writer = MockWriter()
await worker.handle_request(request, writer)
response = writer.get_last_response()
print(f" Response type: {response['type']}")
print(f" Result: {response.get('result', response.get('error'))}")
# Test error handling
print("\n5. Testing error handling - unknown action...")
request = make_request(
request_id="test-003",
request_id=1003,
app_path="examples.dirty_example.dirty_app:MLApp",
action="nonexistent_action",
args=(),
kwargs={}
)
response = await worker.handle_request(request)
writer = MockWriter()
await worker.handle_request(request, writer)
response = writer.get_last_response()
print(f" Response type: {response['type']}")
print(f" Error: {response.get('error', {}).get('message')}")
# Test app not found
print("\n6. Testing error handling - app not found...")
request = make_request(
request_id="test-004",
request_id=1004,
app_path="nonexistent:App",
action="test",
args=(),
kwargs={}
)
response = await worker.handle_request(request)
writer = MockWriter()
await worker.handle_request(request, writer)
response = writer.get_last_response()
print(f" Response type: {response['type']}")
print(f" Error type: {response.get('error', {}).get('error_type')}")

View File

@ -3,89 +3,304 @@
# See the NOTICE for more information.
"""
Dirty Arbiters Protocol
Dirty Worker Binary Protocol
Length-prefixed JSON message framing over Unix sockets.
Provides both async (primary) and sync (for HTTP workers) APIs.
Binary message framing over Unix sockets, inspired by OpenBSD msgctl/msgsnd.
Replaces JSON protocol for efficient binary data transfer.
Message Format:
+----------------+------------------+
| 4-byte length | JSON payload |
+----------------+------------------+
Header Format (16 bytes):
+--------+--------+--------+--------+--------+--------+--------+--------+
| Magic (2B) | Ver(1) | MType | Payload Length (4B) |
+--------+--------+--------+--------+--------+--------+--------+--------+
| Request ID (8 bytes) |
+--------+--------+--------+--------+--------+--------+--------+--------+
The length field is a 4-byte unsigned integer in network byte order (big-endian).
- Magic: 0x47 0x44 ("GD" for Gunicorn Dirty)
- Version: 0x01
- MType: Message type (REQUEST, RESPONSE, ERROR, CHUNK, END)
- Length: Payload size (big-endian uint32, max 64MB)
- Request ID: uint64 (replaces UUID string)
Payload is TLV-encoded (see tlv.py).
"""
import asyncio
import json
import struct
import socket
import struct
from .errors import DirtyProtocolError
from .tlv import TLVEncoder
class DirtyProtocol:
"""Length-prefixed JSON messages over Unix sockets."""
# Protocol constants
MAGIC = b"GD" # 0x47 0x44
VERSION = 0x01
# 4-byte unsigned int, network byte order (big-endian)
HEADER_FORMAT = "!I"
HEADER_SIZE = struct.calcsize(HEADER_FORMAT)
# Message types (1 byte)
MSG_TYPE_REQUEST = 0x01
MSG_TYPE_RESPONSE = 0x02
MSG_TYPE_ERROR = 0x03
MSG_TYPE_CHUNK = 0x04
MSG_TYPE_END = 0x05
# Maximum message size (64 MB)
MAX_MESSAGE_SIZE = 64 * 1024 * 1024
# Message type names (for backwards compatibility with old API)
MSG_TYPE_REQUEST_STR = "request"
MSG_TYPE_RESPONSE_STR = "response"
MSG_TYPE_ERROR_STR = "error"
MSG_TYPE_CHUNK_STR = "chunk"
MSG_TYPE_END_STR = "end"
# Message types for future streaming support
MSG_TYPE_REQUEST = "request"
MSG_TYPE_RESPONSE = "response"
MSG_TYPE_ERROR = "error"
MSG_TYPE_CHUNK = "chunk"
MSG_TYPE_END = "end"
# Map int types to string names
MSG_TYPE_TO_STR = {
MSG_TYPE_REQUEST: MSG_TYPE_REQUEST_STR,
MSG_TYPE_RESPONSE: MSG_TYPE_RESPONSE_STR,
MSG_TYPE_ERROR: MSG_TYPE_ERROR_STR,
MSG_TYPE_CHUNK: MSG_TYPE_CHUNK_STR,
MSG_TYPE_END: MSG_TYPE_END_STR,
}
# Map string names to int types
MSG_TYPE_FROM_STR = {v: k for k, v in MSG_TYPE_TO_STR.items()}
# Header format: Magic (2) + Version (1) + Type (1) + Length (4) + RequestID (8) = 16
HEADER_FORMAT = ">2sBBIQ"
HEADER_SIZE = struct.calcsize(HEADER_FORMAT)
# Maximum message size (64 MB)
MAX_MESSAGE_SIZE = 64 * 1024 * 1024
class BinaryProtocol:
"""Binary message protocol for dirty worker IPC."""
# Export constants for external use
HEADER_SIZE = HEADER_SIZE
MAX_MESSAGE_SIZE = MAX_MESSAGE_SIZE
MSG_TYPE_REQUEST = MSG_TYPE_REQUEST_STR
MSG_TYPE_RESPONSE = MSG_TYPE_RESPONSE_STR
MSG_TYPE_ERROR = MSG_TYPE_ERROR_STR
MSG_TYPE_CHUNK = MSG_TYPE_CHUNK_STR
MSG_TYPE_END = MSG_TYPE_END_STR
@staticmethod
def encode(message: dict) -> bytes:
def encode_header(msg_type: int, request_id: int, payload_length: int) -> bytes:
"""
Encode a message dict to length-prefixed bytes.
Encode the 16-byte message header.
Args:
message: Dictionary to encode as JSON
msg_type: Message type (MSG_TYPE_REQUEST, etc.)
request_id: Unique request identifier (uint64)
payload_length: Length of the TLV-encoded payload
Returns:
bytes: Length-prefixed encoded message
bytes: 16-byte header
"""
return struct.pack(HEADER_FORMAT, MAGIC, VERSION, msg_type,
payload_length, request_id)
@staticmethod
def decode_header(data: bytes) -> tuple:
"""
Decode the 16-byte message header.
Args:
data: 16 bytes of header data
Returns:
tuple: (msg_type, request_id, payload_length)
Raises:
DirtyProtocolError: If encoding fails
DirtyProtocolError: If header is invalid
"""
try:
payload = json.dumps(message).encode("utf-8")
if len(payload) > DirtyProtocol.MAX_MESSAGE_SIZE:
if len(data) < HEADER_SIZE:
raise DirtyProtocolError(
f"Header too short: {len(data)} bytes, expected {HEADER_SIZE}",
raw_data=data
)
magic, version, msg_type, length, request_id = struct.unpack(
HEADER_FORMAT, data[:HEADER_SIZE]
)
if magic != MAGIC:
raise DirtyProtocolError(
f"Invalid magic: {magic!r}, expected {MAGIC!r}",
raw_data=data[:20]
)
if version != VERSION:
raise DirtyProtocolError(
f"Unsupported protocol version: {version}, expected {VERSION}",
raw_data=data[:20]
)
if msg_type not in MSG_TYPE_TO_STR:
raise DirtyProtocolError(
f"Unknown message type: 0x{msg_type:02x}",
raw_data=data[:20]
)
if length > MAX_MESSAGE_SIZE:
raise DirtyProtocolError(
f"Message too large: {length} bytes (max: {MAX_MESSAGE_SIZE})"
)
return msg_type, request_id, length
@staticmethod
def encode_request(request_id: int, app_path: str, action: str,
args: tuple = None, kwargs: dict = None) -> bytes:
"""
Encode a request message.
Args:
request_id: Unique request identifier (uint64)
app_path: Import path of the dirty app
action: Action to call on the app
args: Positional arguments
kwargs: Keyword arguments
Returns:
bytes: Complete message (header + payload)
"""
payload_dict = {
"app_path": app_path,
"action": action,
"args": list(args) if args else [],
"kwargs": kwargs or {},
}
payload = TLVEncoder.encode(payload_dict)
header = BinaryProtocol.encode_header(MSG_TYPE_REQUEST, request_id,
len(payload))
return header + payload
@staticmethod
def encode_response(request_id: int, result) -> bytes:
"""
Encode a success response message.
Args:
request_id: Request identifier this responds to
result: Result value (must be TLV-serializable)
Returns:
bytes: Complete message (header + payload)
"""
payload_dict = {"result": result}
payload = TLVEncoder.encode(payload_dict)
header = BinaryProtocol.encode_header(MSG_TYPE_RESPONSE, request_id,
len(payload))
return header + payload
@staticmethod
def encode_error(request_id: int, error) -> bytes:
"""
Encode an error response message.
Args:
request_id: Request identifier this responds to
error: DirtyError instance, dict, or Exception
Returns:
bytes: Complete message (header + payload)
"""
from .errors import DirtyError
if isinstance(error, DirtyError):
error_dict = error.to_dict()
elif isinstance(error, dict):
error_dict = error
else:
error_dict = {
"error_type": type(error).__name__,
"message": str(error),
"details": {},
}
payload_dict = {"error": error_dict}
payload = TLVEncoder.encode(payload_dict)
header = BinaryProtocol.encode_header(MSG_TYPE_ERROR, request_id,
len(payload))
return header + payload
@staticmethod
def encode_chunk(request_id: int, data) -> bytes:
"""
Encode a chunk message for streaming responses.
Args:
request_id: Request identifier this chunk belongs to
data: Chunk data (must be TLV-serializable)
Returns:
bytes: Complete message (header + payload)
"""
payload_dict = {"data": data}
payload = TLVEncoder.encode(payload_dict)
header = BinaryProtocol.encode_header(MSG_TYPE_CHUNK, request_id,
len(payload))
return header + payload
@staticmethod
def encode_end(request_id: int) -> bytes:
"""
Encode an end-of-stream message.
Args:
request_id: Request identifier this ends
Returns:
bytes: Complete message (header + empty payload)
"""
# End message has empty payload
header = BinaryProtocol.encode_header(MSG_TYPE_END, request_id, 0)
return header
@staticmethod
def decode_message(data: bytes) -> tuple:
"""
Decode a complete message (header + payload).
Args:
data: Complete message bytes
Returns:
tuple: (msg_type_str, request_id, payload_dict)
msg_type_str is the string name (e.g., "request")
payload_dict is the decoded TLV payload as a dict
Raises:
DirtyProtocolError: If message is malformed
"""
msg_type, request_id, length = BinaryProtocol.decode_header(data)
if len(data) < HEADER_SIZE + length:
raise DirtyProtocolError(
f"Incomplete message: expected {HEADER_SIZE + length} bytes, "
f"got {len(data)}",
raw_data=data[:50]
)
if length == 0:
# End message has empty payload
payload_dict = {}
else:
payload_data = data[HEADER_SIZE:HEADER_SIZE + length]
try:
payload_dict = TLVEncoder.decode_full(payload_data)
except DirtyProtocolError:
raise
except Exception as e:
raise DirtyProtocolError(
f"Message too large: {len(payload)} bytes "
f"(max: {DirtyProtocol.MAX_MESSAGE_SIZE})"
f"Failed to decode TLV payload: {e}",
raw_data=payload_data[:50]
)
length = struct.pack(DirtyProtocol.HEADER_FORMAT, len(payload))
return length + payload
except (TypeError, ValueError) as e:
raise DirtyProtocolError(f"Failed to encode message: {e}")
@staticmethod
def decode(data: bytes) -> dict:
"""
Decode bytes (without length prefix) to message dict.
# Convert to dict format similar to old JSON protocol
msg_type_str = MSG_TYPE_TO_STR[msg_type]
Args:
data: JSON bytes to decode
Returns:
dict: Decoded message
Raises:
DirtyProtocolError: If decoding fails
"""
try:
return json.loads(data.decode("utf-8"))
except (json.JSONDecodeError, UnicodeDecodeError) as e:
raise DirtyProtocolError(f"Failed to decode message: {e}",
raw_data=data)
return msg_type_str, request_id, payload_dict
# -------------------------------------------------------------------------
# Async API (primary - for DirtyArbiter and DirtyWorker)
@ -94,53 +309,62 @@ class DirtyProtocol:
@staticmethod
async def read_message_async(reader: asyncio.StreamReader) -> dict:
"""
Read a complete message from async stream.
Read a complete binary message from async stream.
Args:
reader: asyncio StreamReader
Returns:
dict: Decoded message
dict: Message dict with 'type', 'id', and payload fields
Raises:
DirtyProtocolError: If read fails or message is malformed
asyncio.IncompleteReadError: If connection closed mid-read
"""
# Read length header
# Read header
try:
header = await reader.readexactly(DirtyProtocol.HEADER_SIZE)
header = await reader.readexactly(HEADER_SIZE)
except asyncio.IncompleteReadError as e:
if len(e.partial) == 0:
# Clean close - no data was read
raise
raise DirtyProtocolError(
f"Incomplete header: got {len(e.partial)} bytes, "
f"expected {DirtyProtocol.HEADER_SIZE}",
f"expected {HEADER_SIZE}",
raw_data=e.partial
)
length = struct.unpack(DirtyProtocol.HEADER_FORMAT, header)[0]
if length > DirtyProtocol.MAX_MESSAGE_SIZE:
raise DirtyProtocolError(
f"Message too large: {length} bytes "
f"(max: {DirtyProtocol.MAX_MESSAGE_SIZE})"
)
if length == 0:
raise DirtyProtocolError("Empty message received")
msg_type, request_id, length = BinaryProtocol.decode_header(header)
# Read payload
try:
payload = await reader.readexactly(length)
except asyncio.IncompleteReadError as e:
raise DirtyProtocolError(
f"Incomplete message: got {len(e.partial)} bytes, "
f"expected {length}",
raw_data=e.partial
)
if length > 0:
try:
payload_data = await reader.readexactly(length)
except asyncio.IncompleteReadError as e:
raise DirtyProtocolError(
f"Incomplete payload: got {len(e.partial)} bytes, "
f"expected {length}",
raw_data=e.partial
)
return DirtyProtocol.decode(payload)
try:
payload_dict = TLVEncoder.decode_full(payload_data)
except DirtyProtocolError:
raise
except Exception as e:
raise DirtyProtocolError(
f"Failed to decode TLV payload: {e}",
raw_data=payload_data[:50]
)
else:
payload_dict = {}
# Build response dict
msg_type_str = MSG_TYPE_TO_STR[msg_type]
result = {"type": msg_type_str, "id": request_id}
result.update(payload_dict)
return result
@staticmethod
async def write_message_async(writer: asyncio.StreamWriter,
@ -148,15 +372,17 @@ class DirtyProtocol:
"""
Write a message to async stream.
Accepts dict format for backwards compatibility.
Args:
writer: asyncio StreamWriter
message: Dictionary to send
message: Message dict with 'type', 'id', and payload fields
Raises:
DirtyProtocolError: If encoding fails
ConnectionError: If write fails
"""
data = DirtyProtocol.encode(message)
data = BinaryProtocol._encode_from_dict(message)
writer.write(data)
await writer.drain()
@ -201,27 +427,36 @@ class DirtyProtocol:
sock: Socket to read from
Returns:
dict: Decoded message
dict: Message dict with 'type', 'id', and payload fields
Raises:
DirtyProtocolError: If read fails or message is malformed
"""
# Read length header
header = DirtyProtocol._recv_exactly(sock, DirtyProtocol.HEADER_SIZE)
length = struct.unpack(DirtyProtocol.HEADER_FORMAT, header)[0]
if length > DirtyProtocol.MAX_MESSAGE_SIZE:
raise DirtyProtocolError(
f"Message too large: {length} bytes "
f"(max: {DirtyProtocol.MAX_MESSAGE_SIZE})"
)
if length == 0:
raise DirtyProtocolError("Empty message received")
# Read header
header = BinaryProtocol._recv_exactly(sock, HEADER_SIZE)
msg_type, request_id, length = BinaryProtocol.decode_header(header)
# Read payload
payload = DirtyProtocol._recv_exactly(sock, length)
return DirtyProtocol.decode(payload)
if length > 0:
payload_data = BinaryProtocol._recv_exactly(sock, length)
try:
payload_dict = TLVEncoder.decode_full(payload_data)
except DirtyProtocolError:
raise
except Exception as e:
raise DirtyProtocolError(
f"Failed to decode TLV payload: {e}",
raw_data=payload_data[:50]
)
else:
payload_dict = {}
# Build response dict
msg_type_str = MSG_TYPE_TO_STR[msg_type]
result = {"type": msg_type_str, "id": request_id}
result.update(payload_dict)
return result
@staticmethod
def write_message(sock: socket.socket, message: dict) -> None:
@ -230,31 +465,92 @@ class DirtyProtocol:
Args:
sock: Socket to write to
message: Dictionary to send
message: Message dict with 'type', 'id', and payload fields
Raises:
DirtyProtocolError: If encoding fails
OSError: If write fails
"""
data = DirtyProtocol.encode(message)
data = BinaryProtocol._encode_from_dict(message)
sock.sendall(data)
@staticmethod
def _encode_from_dict(message: dict) -> bytes:
"""
Encode a message dict to binary format.
# Message builder helpers
def make_request(request_id: str, app_path: str, action: str,
Supports the old dict-based API for backwards compatibility.
Args:
message: Message dict with 'type', 'id', and payload fields
Returns:
bytes: Complete encoded message
"""
msg_type_str = message.get("type")
request_id = message.get("id", 0)
# Handle string or int request IDs
if isinstance(request_id, str):
# For backwards compat with UUID strings, hash to int
request_id = hash(request_id) & 0xFFFFFFFFFFFFFFFF
msg_type = MSG_TYPE_FROM_STR.get(msg_type_str)
if msg_type is None:
raise DirtyProtocolError(f"Unknown message type: {msg_type_str}")
if msg_type == MSG_TYPE_REQUEST:
return BinaryProtocol.encode_request(
request_id,
message.get("app_path", ""),
message.get("action", ""),
message.get("args"),
message.get("kwargs")
)
elif msg_type == MSG_TYPE_RESPONSE:
return BinaryProtocol.encode_response(
request_id,
message.get("result")
)
elif msg_type == MSG_TYPE_ERROR:
return BinaryProtocol.encode_error(
request_id,
message.get("error", {})
)
elif msg_type == MSG_TYPE_CHUNK:
return BinaryProtocol.encode_chunk(
request_id,
message.get("data")
)
elif msg_type == MSG_TYPE_END:
return BinaryProtocol.encode_end(request_id)
else:
raise DirtyProtocolError(f"Unhandled message type: {msg_type}")
# =============================================================================
# Backwards Compatibility Aliases
# =============================================================================
# Alias BinaryProtocol as DirtyProtocol for drop-in replacement
DirtyProtocol = BinaryProtocol
# Message builder helpers (backwards compatible with old API)
def make_request(request_id, app_path: str, action: str,
args: tuple = None, kwargs: dict = None) -> dict:
"""
Build a request message.
Build a request message dict.
Args:
request_id: Unique request identifier
request_id: Unique request identifier (int or str)
app_path: Import path of the dirty app (e.g., 'myapp.ml:MLApp')
action: Action to call on the app
args: Positional arguments
kwargs: Keyword arguments
Returns:
dict: Request message
dict: Request message dict
"""
return {
"type": DirtyProtocol.MSG_TYPE_REQUEST,
@ -266,16 +562,16 @@ def make_request(request_id: str, app_path: str, action: str,
}
def make_response(request_id: str, result) -> dict:
def make_response(request_id, result) -> dict:
"""
Build a success response message.
Build a success response message dict.
Args:
request_id: Request identifier this responds to
result: Result value (must be JSON-serializable)
result: Result value
Returns:
dict: Response message
dict: Response message dict
"""
return {
"type": DirtyProtocol.MSG_TYPE_RESPONSE,
@ -284,16 +580,16 @@ def make_response(request_id: str, result) -> dict:
}
def make_error_response(request_id: str, error) -> dict:
def make_error_response(request_id, error) -> dict:
"""
Build an error response message.
Build an error response message dict.
Args:
request_id: Request identifier this responds to
error: DirtyError instance or dict with error info
Returns:
dict: Error response message
dict: Error response message dict
"""
from .errors import DirtyError
if isinstance(error, DirtyError):
@ -314,16 +610,16 @@ def make_error_response(request_id: str, error) -> dict:
}
def make_chunk_message(request_id: str, data) -> dict:
def make_chunk_message(request_id, data) -> dict:
"""
Build a chunk message for streaming responses.
Build a chunk message dict for streaming responses.
Args:
request_id: Request identifier this chunk belongs to
data: Chunk data (must be JSON-serializable)
data: Chunk data
Returns:
dict: Chunk message
dict: Chunk message dict
"""
return {
"type": DirtyProtocol.MSG_TYPE_CHUNK,
@ -332,15 +628,15 @@ def make_chunk_message(request_id: str, data) -> dict:
}
def make_end_message(request_id: str) -> dict:
def make_end_message(request_id) -> dict:
"""
Build an end-of-stream message.
Build an end-of-stream message dict.
Args:
request_id: Request identifier this ends
Returns:
dict: End message
dict: End message dict
"""
return {
"type": DirtyProtocol.MSG_TYPE_END,

303
gunicorn/dirty/tlv.py Normal file
View File

@ -0,0 +1,303 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
TLV (Type-Length-Value) Binary Encoder/Decoder
Provides efficient binary serialization for dirty worker protocol messages.
Inspired by OpenBSD msgctl/msgsnd message format.
Type Codes:
0x00: None (no value bytes)
0x01: bool (1 byte: 0x00 or 0x01)
0x05: int64 (8 bytes big-endian signed)
0x06: float64 (8 bytes IEEE 754)
0x10: bytes (4-byte length + raw bytes)
0x11: string (4-byte length + UTF-8 encoded)
0x20: list (4-byte count + encoded elements)
0x21: dict (4-byte count + encoded key-value pairs)
"""
import struct
from .errors import DirtyProtocolError
# Type codes
TYPE_NONE = 0x00
TYPE_BOOL = 0x01
TYPE_INT64 = 0x05
TYPE_FLOAT64 = 0x06
TYPE_BYTES = 0x10
TYPE_STRING = 0x11
TYPE_LIST = 0x20
TYPE_DICT = 0x21
# Maximum sizes for safety
MAX_STRING_SIZE = 64 * 1024 * 1024 # 64 MB
MAX_BYTES_SIZE = 64 * 1024 * 1024 # 64 MB
MAX_LIST_SIZE = 1024 * 1024 # 1 million items
MAX_DICT_SIZE = 1024 * 1024 # 1 million items
class TLVEncoder:
"""
TLV binary encoder/decoder.
Encodes Python values to binary TLV format and decodes back.
Supports: None, bool, int, float, bytes, str, list, dict.
"""
@staticmethod
def encode(value) -> bytes: # pylint: disable=too-many-return-statements
"""
Encode a Python value to TLV binary format.
Args:
value: Python value to encode (None, bool, int, float,
bytes, str, list, or dict)
Returns:
bytes: TLV-encoded binary data
Raises:
DirtyProtocolError: If value type is not supported
"""
if value is None:
return bytes([TYPE_NONE])
if isinstance(value, bool):
# bool must come before int since bool is a subclass of int
return bytes([TYPE_BOOL, 0x01 if value else 0x00])
if isinstance(value, int):
return bytes([TYPE_INT64]) + struct.pack(">q", value)
if isinstance(value, float):
return bytes([TYPE_FLOAT64]) + struct.pack(">d", value)
if isinstance(value, bytes):
if len(value) > MAX_BYTES_SIZE:
raise DirtyProtocolError(
f"Bytes too large: {len(value)} bytes "
f"(max: {MAX_BYTES_SIZE})"
)
return bytes([TYPE_BYTES]) + struct.pack(">I", len(value)) + value
if isinstance(value, str):
encoded = value.encode("utf-8")
if len(encoded) > MAX_STRING_SIZE:
raise DirtyProtocolError(
f"String too large: {len(encoded)} bytes "
f"(max: {MAX_STRING_SIZE})"
)
return bytes([TYPE_STRING]) + struct.pack(">I", len(encoded)) + encoded
if isinstance(value, (list, tuple)):
if len(value) > MAX_LIST_SIZE:
raise DirtyProtocolError(
f"List too large: {len(value)} items "
f"(max: {MAX_LIST_SIZE})"
)
parts = [bytes([TYPE_LIST]), struct.pack(">I", len(value))]
for item in value:
parts.append(TLVEncoder.encode(item))
return b"".join(parts)
if isinstance(value, dict):
if len(value) > MAX_DICT_SIZE:
raise DirtyProtocolError(
f"Dict too large: {len(value)} items "
f"(max: {MAX_DICT_SIZE})"
)
parts = [bytes([TYPE_DICT]), struct.pack(">I", len(value))]
for k, v in value.items():
# Convert keys to strings (like JSON)
if not isinstance(k, str):
k = str(k)
parts.append(TLVEncoder.encode(k))
parts.append(TLVEncoder.encode(v))
return b"".join(parts)
raise DirtyProtocolError(
f"Unsupported type for TLV encoding: {type(value).__name__}"
)
@staticmethod
def decode(data: bytes, offset: int = 0) -> tuple: # pylint: disable=too-many-return-statements
"""
Decode a TLV-encoded value from binary data.
Args:
data: Binary data to decode
offset: Starting offset in the data
Returns:
tuple: (decoded_value, new_offset)
Raises:
DirtyProtocolError: If data is malformed or truncated
"""
if offset >= len(data):
raise DirtyProtocolError(
"Truncated TLV data: no type byte",
raw_data=data[offset:offset + 20]
)
type_code = data[offset]
offset += 1
if type_code == TYPE_NONE:
return None, offset
if type_code == TYPE_BOOL:
if offset >= len(data):
raise DirtyProtocolError(
"Truncated TLV data: missing bool value",
raw_data=data[offset - 1:offset + 20]
)
value = data[offset] != 0x00
return value, offset + 1
if type_code == TYPE_INT64:
if offset + 8 > len(data):
raise DirtyProtocolError(
"Truncated TLV data: incomplete int64",
raw_data=data[offset - 1:offset + 20]
)
value = struct.unpack(">q", data[offset:offset + 8])[0]
return value, offset + 8
if type_code == TYPE_FLOAT64:
if offset + 8 > len(data):
raise DirtyProtocolError(
"Truncated TLV data: incomplete float64",
raw_data=data[offset - 1:offset + 20]
)
value = struct.unpack(">d", data[offset:offset + 8])[0]
return value, offset + 8
if type_code == TYPE_BYTES:
if offset + 4 > len(data):
raise DirtyProtocolError(
"Truncated TLV data: incomplete bytes length",
raw_data=data[offset - 1:offset + 20]
)
length = struct.unpack(">I", data[offset:offset + 4])[0]
offset += 4
if length > MAX_BYTES_SIZE:
raise DirtyProtocolError(
f"Bytes too large: {length} bytes (max: {MAX_BYTES_SIZE})"
)
if offset + length > len(data):
raise DirtyProtocolError(
f"Truncated TLV data: expected {length} bytes, "
f"got {len(data) - offset}",
raw_data=data[offset - 5:offset + 20]
)
value = data[offset:offset + length]
return value, offset + length
if type_code == TYPE_STRING:
if offset + 4 > len(data):
raise DirtyProtocolError(
"Truncated TLV data: incomplete string length",
raw_data=data[offset - 1:offset + 20]
)
length = struct.unpack(">I", data[offset:offset + 4])[0]
offset += 4
if length > MAX_STRING_SIZE:
raise DirtyProtocolError(
f"String too large: {length} bytes (max: {MAX_STRING_SIZE})"
)
if offset + length > len(data):
raise DirtyProtocolError(
f"Truncated TLV data: expected {length} bytes for string, "
f"got {len(data) - offset}",
raw_data=data[offset - 5:offset + 20]
)
try:
value = data[offset:offset + length].decode("utf-8")
except UnicodeDecodeError as e:
raise DirtyProtocolError(
f"Invalid UTF-8 in string: {e}",
raw_data=data[offset:offset + min(length, 20)]
)
return value, offset + length
if type_code == TYPE_LIST:
if offset + 4 > len(data):
raise DirtyProtocolError(
"Truncated TLV data: incomplete list count",
raw_data=data[offset - 1:offset + 20]
)
count = struct.unpack(">I", data[offset:offset + 4])[0]
offset += 4
if count > MAX_LIST_SIZE:
raise DirtyProtocolError(
f"List too large: {count} items (max: {MAX_LIST_SIZE})"
)
items = []
for _ in range(count):
item, offset = TLVEncoder.decode(data, offset)
items.append(item)
return items, offset
if type_code == TYPE_DICT:
if offset + 4 > len(data):
raise DirtyProtocolError(
"Truncated TLV data: incomplete dict count",
raw_data=data[offset - 1:offset + 20]
)
count = struct.unpack(">I", data[offset:offset + 4])[0]
offset += 4
if count > MAX_DICT_SIZE:
raise DirtyProtocolError(
f"Dict too large: {count} items (max: {MAX_DICT_SIZE})"
)
result = {}
for _ in range(count):
key, offset = TLVEncoder.decode(data, offset)
if not isinstance(key, str):
raise DirtyProtocolError(
f"Dict key must be string, got {type(key).__name__}"
)
value, offset = TLVEncoder.decode(data, offset)
result[key] = value
return result, offset
raise DirtyProtocolError(
f"Unknown TLV type code: 0x{type_code:02x}",
raw_data=data[offset - 1:offset + 20]
)
@staticmethod
def decode_full(data: bytes):
"""
Decode a complete TLV-encoded value, ensuring all data is consumed.
Args:
data: Binary data to decode
Returns:
Decoded Python value
Raises:
DirtyProtocolError: If data is malformed or has trailing bytes
"""
value, offset = TLVEncoder.decode(data, 0)
if offset != len(data):
raise DirtyProtocolError(
f"Trailing data after TLV: {len(data) - offset} bytes",
raw_data=data[offset:offset + 20]
)
return value

View File

@ -12,11 +12,13 @@ import pytest
from gunicorn.dirty.protocol import (
DirtyProtocol,
BinaryProtocol,
make_request,
make_response,
make_chunk_message,
make_end_message,
make_error_response,
HEADER_SIZE,
)
from gunicorn.dirty.arbiter import DirtyArbiter
from gunicorn.dirty.errors import DirtyError
@ -34,16 +36,22 @@ class MockStreamWriter:
self._buffer += data
async def drain(self):
while len(self._buffer) >= DirtyProtocol.HEADER_SIZE:
length = struct.unpack(
DirtyProtocol.HEADER_FORMAT,
self._buffer[:DirtyProtocol.HEADER_SIZE]
)[0]
total_size = DirtyProtocol.HEADER_SIZE + length
# Decode the buffer to extract messages using binary protocol
while len(self._buffer) >= HEADER_SIZE:
# Decode header to get payload length
_, _, length = BinaryProtocol.decode_header(
self._buffer[:HEADER_SIZE]
)
total_size = HEADER_SIZE + length
if len(self._buffer) >= total_size:
msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size]
msg_data = self._buffer[:total_size]
self._buffer = self._buffer[total_size:]
self.messages.append(DirtyProtocol.decode(msg_data))
# decode_message returns (msg_type_str, request_id, payload_dict)
msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data)
# Reconstruct the dict format for backwards compatibility
result = {"type": msg_type_str, "id": request_id}
result.update(payload_dict)
self.messages.append(result)
else:
break
@ -63,7 +71,7 @@ class MockStreamReader:
def __init__(self, messages):
self._data = b''
for msg in messages:
self._data += DirtyProtocol.encode(msg)
self._data += BinaryProtocol._encode_from_dict(msg)
self._pos = 0
async def readexactly(self, n):
@ -107,9 +115,9 @@ class TestArbiterStreamingForwarding:
client_writer = MockStreamWriter()
# Mock worker connection that returns chunks
chunk1 = make_chunk_message("req-123", "Hello")
chunk2 = make_chunk_message("req-123", " World")
end = make_end_message("req-123")
chunk1 = make_chunk_message(123, "Hello")
chunk2 = make_chunk_message(123, " World")
end = make_end_message(123)
mock_reader = MockStreamReader([chunk1, chunk2, end])
@ -118,7 +126,7 @@ class TestArbiterStreamingForwarding:
arbiter._get_worker_connection = mock_get_connection
request = make_request("req-123", "test:App", "generate")
request = make_request(123, "test:App", "generate")
await arbiter._execute_on_worker(1234, request, client_writer)
# Should have forwarded all messages
@ -135,7 +143,7 @@ class TestArbiterStreamingForwarding:
arbiter = create_arbiter()
client_writer = MockStreamWriter()
response = make_response("req-123", {"result": 42})
response = make_response(123, {"result": 42})
mock_reader = MockStreamReader([response])
async def mock_get_connection(pid):
@ -143,7 +151,7 @@ class TestArbiterStreamingForwarding:
arbiter._get_worker_connection = mock_get_connection
request = make_request("req-123", "test:App", "compute")
request = make_request(123, "test:App", "compute")
await arbiter._execute_on_worker(1234, request, client_writer)
assert len(client_writer.messages) == 1
@ -156,8 +164,8 @@ class TestArbiterStreamingForwarding:
arbiter = create_arbiter()
client_writer = MockStreamWriter()
chunk = make_chunk_message("req-123", "First")
error = make_error_response("req-123", DirtyError("Something broke"))
chunk = make_chunk_message(123, "First")
error = make_error_response(123, DirtyError("Something broke"))
mock_reader = MockStreamReader([chunk, error])
@ -166,7 +174,7 @@ class TestArbiterStreamingForwarding:
arbiter._get_worker_connection = mock_get_connection
request = make_request("req-123", "test:App", "generate")
request = make_request(123, "test:App", "generate")
await arbiter._execute_on_worker(1234, request, client_writer)
assert len(client_writer.messages) == 2
@ -190,7 +198,7 @@ class TestArbiterStreamingForwarding:
arbiter._get_worker_connection = mock_get_connection
request = make_request("req-123", "test:App", "generate")
request = make_request(123, "test:App", "generate")
await arbiter._execute_on_worker(1234, request, client_writer)
assert len(client_writer.messages) == 1
@ -208,7 +216,7 @@ class TestArbiterRouteRequestStreaming:
arbiter.workers = {} # No workers
client_writer = MockStreamWriter()
request = make_request("req-123", "test:App", "generate")
request = make_request(123, "test:App", "generate")
await arbiter.route_request(request, client_writer)
assert len(client_writer.messages) == 1
@ -222,13 +230,13 @@ class TestArbiterRouteRequestStreaming:
# Mock _execute_on_worker to complete immediately
async def mock_execute(pid, request, client_writer):
response = make_response("req-123", "result")
response = make_response(123, "result")
await DirtyProtocol.write_message_async(client_writer, response)
arbiter._execute_on_worker = mock_execute
client_writer = MockStreamWriter()
request = make_request("req-123", "test:App", "compute")
request = make_request(123, "test:App", "compute")
# Worker queue should be created
assert 1234 not in arbiter.worker_queues
@ -255,8 +263,8 @@ class TestArbiterStreamingManyChunks:
# Generate 50 chunks + end
messages = []
for i in range(50):
messages.append(make_chunk_message("req-123", f"chunk-{i}"))
messages.append(make_end_message("req-123"))
messages.append(make_chunk_message(123, f"chunk-{i}"))
messages.append(make_end_message(123))
mock_reader = MockStreamReader(messages)
@ -265,7 +273,7 @@ class TestArbiterStreamingManyChunks:
arbiter._get_worker_connection = mock_get_connection
request = make_request("req-123", "test:App", "generate")
request = make_request(123, "test:App", "generate")
await arbiter._execute_on_worker(1234, request, client_writer)
assert len(client_writer.messages) == 51
@ -283,7 +291,7 @@ class TestArbiterBackwardCompatibility:
arbiter = create_arbiter()
client_writer = MockStreamWriter()
response = make_response("req-123", [1, 2, 3, 4, 5])
response = make_response(123, [1, 2, 3, 4, 5])
mock_reader = MockStreamReader([response])
async def mock_get_connection(pid):
@ -291,7 +299,7 @@ class TestArbiterBackwardCompatibility:
arbiter._get_worker_connection = mock_get_connection
request = make_request("req-123", "test:App", "get_list")
request = make_request(123, "test:App", "get_list")
await arbiter._execute_on_worker(1234, request, client_writer)
assert len(client_writer.messages) == 1
@ -304,7 +312,7 @@ class TestArbiterBackwardCompatibility:
arbiter = create_arbiter()
client_writer = MockStreamWriter()
error = make_error_response("req-123", DirtyError("Something failed"))
error = make_error_response(123, DirtyError("Something failed"))
mock_reader = MockStreamReader([error])
async def mock_get_connection(pid):
@ -312,7 +320,7 @@ class TestArbiterBackwardCompatibility:
arbiter._get_worker_connection = mock_get_connection
request = make_request("req-123", "test:App", "fail")
request = make_request(123, "test:App", "fail")
await arbiter._execute_on_worker(1234, request, client_writer)
assert len(client_writer.messages) == 1

View File

@ -11,10 +11,12 @@ from unittest import mock
from gunicorn.dirty.protocol import (
DirtyProtocol,
BinaryProtocol,
make_chunk_message,
make_end_message,
make_response,
make_error_response,
HEADER_SIZE,
)
from gunicorn.dirty.client import DirtyClient, DirtyStreamIterator
from gunicorn.dirty.errors import DirtyError, DirtyConnectionError
@ -26,7 +28,7 @@ class MockSocket:
def __init__(self, messages):
self._data = b''
for msg in messages:
self._data += DirtyProtocol.encode(msg)
self._data += BinaryProtocol._encode_from_dict(msg)
self._pos = 0
self._sent = []
self.closed = False
@ -69,10 +71,10 @@ class TestDirtyStreamIterator:
def test_stream_iterator_yields_chunks(self):
"""Test that stream iterator yields chunks correctly."""
messages = [
make_chunk_message("req-123", "Hello"),
make_chunk_message("req-123", " "),
make_chunk_message("req-123", "World"),
make_end_message("req-123"),
make_chunk_message(123, "Hello"),
make_chunk_message(123, " "),
make_chunk_message(123, "World"),
make_end_message(123),
]
client = create_client_with_mock_socket(messages)
@ -83,9 +85,9 @@ class TestDirtyStreamIterator:
def test_stream_iterator_yields_complex_chunks(self):
"""Test that stream iterator yields complex data types."""
messages = [
make_chunk_message("req-123", {"token": "Hello", "score": 0.9}),
make_chunk_message("req-123", {"token": "World", "score": 0.8}),
make_end_message("req-123"),
make_chunk_message(123, {"token": "Hello", "score": 0.9}),
make_chunk_message(123, {"token": "World", "score": 0.8}),
make_end_message(123),
]
client = create_client_with_mock_socket(messages)
@ -98,8 +100,8 @@ class TestDirtyStreamIterator:
def test_stream_iterator_handles_error(self):
"""Test that stream iterator raises on error message."""
messages = [
make_chunk_message("req-123", "First"),
make_error_response("req-123", DirtyError("Something broke")),
make_chunk_message(123, "First"),
make_error_response(123, DirtyError("Something broke")),
]
client = create_client_with_mock_socket(messages)
@ -116,7 +118,7 @@ class TestDirtyStreamIterator:
def test_stream_iterator_empty_stream(self):
"""Test that empty stream (just end) works."""
messages = [make_end_message("req-123")]
messages = [make_end_message(123)]
client = create_client_with_mock_socket(messages)
chunks = list(client.stream("test:App", "generate"))
@ -125,8 +127,8 @@ class TestDirtyStreamIterator:
def test_stream_iterator_stops_after_exhausted(self):
"""Test that iterator stays exhausted after StopIteration."""
messages = [
make_chunk_message("req-123", "Only"),
make_end_message("req-123"),
make_chunk_message(123, "Only"),
make_end_message(123),
]
client = create_client_with_mock_socket(messages)
@ -147,10 +149,10 @@ class TestDirtyStreamIterator:
def test_stream_iterator_with_for_loop(self):
"""Test stream iterator works in for loop."""
messages = [
make_chunk_message("req-123", "a"),
make_chunk_message("req-123", "b"),
make_chunk_message("req-123", "c"),
make_end_message("req-123"),
make_chunk_message(123, "a"),
make_chunk_message(123, "b"),
make_chunk_message(123, "c"),
make_end_message(123),
]
client = create_client_with_mock_socket(messages)
@ -163,8 +165,8 @@ class TestDirtyStreamIterator:
def test_stream_sends_request_on_first_iteration(self):
"""Test that request is sent on first next() call."""
messages = [
make_chunk_message("req-123", "data"),
make_end_message("req-123"),
make_chunk_message(123, "data"),
make_end_message(123),
]
client = create_client_with_mock_socket(messages)
@ -179,18 +181,15 @@ class TestDirtyStreamIterator:
# Decode sent request
sent_data = client._sock._sent[0]
length = struct.unpack(
DirtyProtocol.HEADER_FORMAT,
sent_data[:DirtyProtocol.HEADER_SIZE]
)[0]
request = DirtyProtocol.decode(
sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length]
_, _, length = BinaryProtocol.decode_header(sent_data[:HEADER_SIZE])
msg_type_str, request_id, payload = BinaryProtocol.decode_message(
sent_data[:HEADER_SIZE + length]
)
assert request["type"] == "request"
assert request["app_path"] == "test:App"
assert request["action"] == "generate"
assert request["args"] == ["prompt_arg"]
assert msg_type_str == "request"
assert payload["app_path"] == "test:App"
assert payload["action"] == "generate"
assert payload["args"] == ["prompt_arg"]
class TestDirtyStreamIteratorEdgeCases:
@ -200,8 +199,8 @@ class TestDirtyStreamIteratorEdgeCases:
"""Test streaming with many chunks."""
messages = []
for i in range(100):
messages.append(make_chunk_message("req-123", f"chunk-{i}"))
messages.append(make_end_message("req-123"))
messages.append(make_chunk_message(123, f"chunk-{i}"))
messages.append(make_end_message(123))
client = create_client_with_mock_socket(messages)
@ -214,8 +213,8 @@ class TestDirtyStreamIteratorEdgeCases:
def test_stream_with_kwargs(self):
"""Test streaming with keyword arguments."""
messages = [
make_chunk_message("req-123", "data"),
make_end_message("req-123"),
make_chunk_message(123, "data"),
make_end_message(123),
]
client = create_client_with_mock_socket(messages)
@ -224,13 +223,10 @@ class TestDirtyStreamIteratorEdgeCases:
# Check the sent request includes kwargs
sent_data = client._sock._sent[0]
length = struct.unpack(
DirtyProtocol.HEADER_FORMAT,
sent_data[:DirtyProtocol.HEADER_SIZE]
)[0]
request = DirtyProtocol.decode(
sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length]
_, _, length = BinaryProtocol.decode_header(sent_data[:HEADER_SIZE])
msg_type_str, request_id, payload = BinaryProtocol.decode_message(
sent_data[:HEADER_SIZE + length]
)
assert request["args"] == ["arg1"]
assert request["kwargs"] == {"key": "value"}
assert payload["args"] == ["arg1"]
assert payload["kwargs"] == {"key": "value"}

View File

@ -10,9 +10,11 @@ import pytest
from gunicorn.dirty.protocol import (
DirtyProtocol,
BinaryProtocol,
make_chunk_message,
make_end_message,
make_error_response,
HEADER_SIZE,
)
from gunicorn.dirty.client import DirtyClient, DirtyAsyncStreamIterator
from gunicorn.dirty.errors import DirtyError, DirtyTimeoutError
@ -24,7 +26,7 @@ class MockAsyncReader:
def __init__(self, messages):
self._data = b''
for msg in messages:
self._data += DirtyProtocol.encode(msg)
self._data += BinaryProtocol._encode_from_dict(msg)
self._pos = 0
async def readexactly(self, n):
@ -76,10 +78,10 @@ class TestDirtyAsyncStreamIterator:
async def test_async_stream_yields_chunks(self):
"""Test that async stream iterator yields chunks correctly."""
messages = [
make_chunk_message("req-123", "Hello"),
make_chunk_message("req-123", " "),
make_chunk_message("req-123", "World"),
make_end_message("req-123"),
make_chunk_message(123, "Hello"),
make_chunk_message(123, " "),
make_chunk_message(123, "World"),
make_end_message(123),
]
client = create_async_client_with_mocks(messages)
@ -93,9 +95,9 @@ class TestDirtyAsyncStreamIterator:
async def test_async_stream_yields_complex_chunks(self):
"""Test that async stream iterator yields complex data types."""
messages = [
make_chunk_message("req-123", {"token": "Hello", "score": 0.9}),
make_chunk_message("req-123", {"token": "World", "score": 0.8}),
make_end_message("req-123"),
make_chunk_message(123, {"token": "Hello", "score": 0.9}),
make_chunk_message(123, {"token": "World", "score": 0.8}),
make_end_message(123),
]
client = create_async_client_with_mocks(messages)
@ -111,8 +113,8 @@ class TestDirtyAsyncStreamIterator:
async def test_async_stream_handles_error(self):
"""Test that async stream iterator raises on error message."""
messages = [
make_chunk_message("req-123", "First"),
make_error_response("req-123", DirtyError("Something broke")),
make_chunk_message(123, "First"),
make_error_response(123, DirtyError("Something broke")),
]
client = create_async_client_with_mocks(messages)
@ -130,7 +132,7 @@ class TestDirtyAsyncStreamIterator:
@pytest.mark.asyncio
async def test_async_stream_empty_stream(self):
"""Test that empty stream (just end) works."""
messages = [make_end_message("req-123")]
messages = [make_end_message(123)]
client = create_async_client_with_mocks(messages)
chunks = []
@ -143,8 +145,8 @@ class TestDirtyAsyncStreamIterator:
async def test_async_stream_stops_after_exhausted(self):
"""Test that async iterator stays exhausted after StopAsyncIteration."""
messages = [
make_chunk_message("req-123", "Only"),
make_end_message("req-123"),
make_chunk_message(123, "Only"),
make_end_message(123),
]
client = create_async_client_with_mocks(messages)
@ -166,8 +168,8 @@ class TestDirtyAsyncStreamIterator:
async def test_async_stream_sends_request_on_first_iteration(self):
"""Test that request is sent on first async iteration."""
messages = [
make_chunk_message("req-123", "data"),
make_end_message("req-123"),
make_chunk_message(123, "data"),
make_end_message(123),
]
client = create_async_client_with_mocks(messages)
@ -182,18 +184,15 @@ class TestDirtyAsyncStreamIterator:
# Decode sent request
sent_data = client._writer._sent[0]
length = struct.unpack(
DirtyProtocol.HEADER_FORMAT,
sent_data[:DirtyProtocol.HEADER_SIZE]
)[0]
request = DirtyProtocol.decode(
sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length]
_, _, length = BinaryProtocol.decode_header(sent_data[:HEADER_SIZE])
msg_type_str, request_id, payload = BinaryProtocol.decode_message(
sent_data[:HEADER_SIZE + length]
)
assert request["type"] == "request"
assert request["app_path"] == "test:App"
assert request["action"] == "generate"
assert request["args"] == ["prompt_arg"]
assert msg_type_str == "request"
assert payload["app_path"] == "test:App"
assert payload["action"] == "generate"
assert payload["args"] == ["prompt_arg"]
class TestDirtyAsyncStreamIteratorEdgeCases:
@ -204,8 +203,8 @@ class TestDirtyAsyncStreamIteratorEdgeCases:
"""Test async streaming with many chunks."""
messages = []
for i in range(100):
messages.append(make_chunk_message("req-123", f"chunk-{i}"))
messages.append(make_end_message("req-123"))
messages.append(make_chunk_message(123, f"chunk-{i}"))
messages.append(make_end_message(123))
client = create_async_client_with_mocks(messages)
@ -221,8 +220,8 @@ class TestDirtyAsyncStreamIteratorEdgeCases:
async def test_async_stream_with_kwargs(self):
"""Test async streaming with keyword arguments."""
messages = [
make_chunk_message("req-123", "data"),
make_end_message("req-123"),
make_chunk_message(123, "data"),
make_end_message(123),
]
client = create_async_client_with_mocks(messages)
@ -233,16 +232,13 @@ class TestDirtyAsyncStreamIteratorEdgeCases:
# Check the sent request includes kwargs
sent_data = client._writer._sent[0]
length = struct.unpack(
DirtyProtocol.HEADER_FORMAT,
sent_data[:DirtyProtocol.HEADER_SIZE]
)[0]
request = DirtyProtocol.decode(
sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length]
_, _, length = BinaryProtocol.decode_header(sent_data[:HEADER_SIZE])
msg_type_str, request_id, payload = BinaryProtocol.decode_message(
sent_data[:HEADER_SIZE + length]
)
assert request["args"] == ["arg1"]
assert request["kwargs"] == {"key": "value"}
assert payload["args"] == ["arg1"]
assert payload["kwargs"] == {"key": "value"}
class TestDirtyAsyncStreamTimeout:

View File

@ -19,7 +19,12 @@ from concurrent.futures import ThreadPoolExecutor
from gunicorn.config import Config
from gunicorn.dirty.worker import DirtyWorker
from gunicorn.dirty.arbiter import DirtyArbiter
from gunicorn.dirty.protocol import DirtyProtocol, make_request
from gunicorn.dirty.protocol import (
DirtyProtocol,
BinaryProtocol,
make_request,
HEADER_SIZE,
)
from gunicorn.dirty.errors import DirtyAppNotFoundError
@ -71,16 +76,22 @@ class MockStreamWriter:
self._buffer += data
async def drain(self):
while len(self._buffer) >= DirtyProtocol.HEADER_SIZE:
length = struct.unpack(
DirtyProtocol.HEADER_FORMAT,
self._buffer[:DirtyProtocol.HEADER_SIZE]
)[0]
total_size = DirtyProtocol.HEADER_SIZE + length
# Decode the buffer to extract messages using binary protocol
while len(self._buffer) >= HEADER_SIZE:
# Decode header to get payload length
_, _, length = BinaryProtocol.decode_header(
self._buffer[:HEADER_SIZE]
)
total_size = HEADER_SIZE + length
if len(self._buffer) >= total_size:
msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size]
msg_data = self._buffer[:total_size]
self._buffer = self._buffer[total_size:]
self.messages.append(DirtyProtocol.decode(msg_data))
# decode_message returns (msg_type_str, request_id, payload_dict)
msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data)
# Reconstruct the dict format for backwards compatibility
result = {"type": msg_type_str, "id": request_id}
result.update(payload_dict)
self.messages.append(result)
else:
break

View File

@ -18,11 +18,13 @@ from unittest import mock
from gunicorn.config import Config
from gunicorn.dirty.protocol import (
DirtyProtocol,
BinaryProtocol,
make_request,
make_chunk_message,
make_end_message,
make_response,
make_error_response,
HEADER_SIZE,
)
from gunicorn.dirty.worker import DirtyWorker
from gunicorn.dirty.arbiter import DirtyArbiter
@ -67,16 +69,22 @@ class MockStreamWriter:
self._buffer += data
async def drain(self):
while len(self._buffer) >= DirtyProtocol.HEADER_SIZE:
length = struct.unpack(
DirtyProtocol.HEADER_FORMAT,
self._buffer[:DirtyProtocol.HEADER_SIZE]
)[0]
total_size = DirtyProtocol.HEADER_SIZE + length
# Decode the buffer to extract messages using binary protocol
while len(self._buffer) >= HEADER_SIZE:
# Decode header to get payload length
_, _, length = BinaryProtocol.decode_header(
self._buffer[:HEADER_SIZE]
)
total_size = HEADER_SIZE + length
if len(self._buffer) >= total_size:
msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size]
msg_data = self._buffer[:total_size]
self._buffer = self._buffer[total_size:]
self.messages.append(DirtyProtocol.decode(msg_data))
# decode_message returns (msg_type_str, request_id, payload_dict)
msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data)
# Reconstruct the dict format for backwards compatibility
result = {"type": msg_type_str, "id": request_id}
result.update(payload_dict)
self.messages.append(result)
else:
break
@ -96,7 +104,7 @@ class MockStreamReader:
def __init__(self, messages):
self._data = b''
for msg in messages:
self._data += DirtyProtocol.encode(msg)
self._data += BinaryProtocol._encode_from_dict(msg)
self._pos = 0
async def readexactly(self, n):
@ -115,10 +123,10 @@ class TestStreamingEndToEnd:
"""Test complete flow: sync generator -> worker -> arbiter -> client."""
# Simulate what a worker would produce for a sync generator
worker_messages = [
make_chunk_message("req-123", "Hello"),
make_chunk_message("req-123", " "),
make_chunk_message("req-123", "World"),
make_end_message("req-123"),
make_chunk_message(123, "Hello"),
make_chunk_message(123, " "),
make_chunk_message(123, "World"),
make_end_message(123),
]
# Create an arbiter with mocked worker connection
@ -141,7 +149,7 @@ class TestStreamingEndToEnd:
client_writer = MockStreamWriter()
# Execute request through arbiter
request = make_request("req-123", "test:App", "generate")
request = make_request(123, "test:App", "generate")
await arbiter._execute_on_worker(1234, request, client_writer)
# Verify all messages were forwarded
@ -158,10 +166,10 @@ class TestStreamingEndToEnd:
async def test_async_generator_end_to_end(self):
"""Test complete flow: async generator -> worker -> arbiter -> client."""
worker_messages = [
make_chunk_message("req-456", "Async"),
make_chunk_message("req-456", " "),
make_chunk_message("req-456", "Stream"),
make_end_message("req-456"),
make_chunk_message(456, "Async"),
make_chunk_message(456, " "),
make_chunk_message(456, "Stream"),
make_end_message(456),
]
cfg = Config()
@ -180,7 +188,7 @@ class TestStreamingEndToEnd:
client_writer = MockStreamWriter()
request = make_request("req-456", "test:App", "async_generate")
request = make_request(456, "test:App", "async_generate")
await arbiter._execute_on_worker(1234, request, client_writer)
assert len(client_writer.messages) == 4
@ -197,9 +205,9 @@ class TestStreamingErrorHandling:
async def test_error_mid_stream(self):
"""Test that errors during streaming are properly forwarded."""
worker_messages = [
make_chunk_message("req-789", "First"),
make_chunk_message("req-789", "Second"),
make_error_response("req-789", DirtyError("Stream failed")),
make_chunk_message(789, "First"),
make_chunk_message(789, "Second"),
make_error_response(789, DirtyError("Stream failed")),
]
cfg = Config()
@ -218,7 +226,7 @@ class TestStreamingErrorHandling:
client_writer = MockStreamWriter()
request = make_request("req-789", "test:App", "generate_with_error")
request = make_request(789, "test:App", "generate_with_error")
await arbiter._execute_on_worker(1234, request, client_writer)
# Should have 2 chunks + 1 error
@ -335,7 +343,7 @@ class TestStreamingWorkerIntegration:
return sync_gen()
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "generate")
request = make_request(123, "test:App", "generate")
await worker.handle_request(request, writer)
# Should have 3 chunks + 1 end
@ -377,7 +385,7 @@ class TestStreamingWorkerIntegration:
return async_gen()
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-456", "test:App", "async_generate")
request = make_request(456, "test:App", "async_generate")
await worker.handle_request(request, writer)
# Should have 2 chunks + 1 end

View File

@ -12,9 +12,11 @@ import pytest
from gunicorn.dirty.protocol import (
DirtyProtocol,
BinaryProtocol,
make_request,
make_chunk_message,
make_end_message,
HEADER_SIZE,
)
from gunicorn.dirty.worker import DirtyWorker
@ -30,17 +32,22 @@ class FakeStreamWriter:
self._buffer += data
async def drain(self):
# Decode the buffer to extract messages
while len(self._buffer) >= DirtyProtocol.HEADER_SIZE:
length = struct.unpack(
DirtyProtocol.HEADER_FORMAT,
self._buffer[:DirtyProtocol.HEADER_SIZE]
)[0]
total_size = DirtyProtocol.HEADER_SIZE + length
# Decode the buffer to extract messages using binary protocol
while len(self._buffer) >= HEADER_SIZE:
# Decode header to get payload length
_, _, length = BinaryProtocol.decode_header(
self._buffer[:HEADER_SIZE]
)
total_size = HEADER_SIZE + length
if len(self._buffer) >= total_size:
msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size]
msg_data = self._buffer[:total_size]
self._buffer = self._buffer[total_size:]
self.messages.append(DirtyProtocol.decode(msg_data))
# decode_message returns (msg_type_str, request_id, payload_dict)
msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data)
# Reconstruct the dict format for backwards compatibility
result = {"type": msg_type_str, "id": request_id}
result.update(payload_dict)
self.messages.append(result)
else:
break
@ -101,7 +108,7 @@ class TestWorkerSyncGeneratorStreaming:
return generate_tokens()
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "generate")
request = make_request(123, "test:App", "generate")
await worker.handle_request(request, writer)
# Should have 3 chunks + 1 end message
@ -109,7 +116,7 @@ class TestWorkerSyncGeneratorStreaming:
# Check chunk messages
assert writer.messages[0]["type"] == "chunk"
assert writer.messages[0]["id"] == "req-123"
assert writer.messages[0]["id"] == 123
assert writer.messages[0]["data"] == "Hello"
assert writer.messages[1]["type"] == "chunk"
@ -120,7 +127,7 @@ class TestWorkerSyncGeneratorStreaming:
# Check end message
assert writer.messages[3]["type"] == "end"
assert writer.messages[3]["id"] == "req-123"
assert writer.messages[3]["id"] == 123
@pytest.mark.asyncio
async def test_sync_generator_error_mid_stream(self):
@ -136,7 +143,7 @@ class TestWorkerSyncGeneratorStreaming:
return generate_with_error()
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "generate")
request = make_request(123, "test:App", "generate")
await worker.handle_request(request, writer)
# Should have 1 chunk + 1 error message
@ -167,7 +174,7 @@ class TestWorkerAsyncGeneratorStreaming:
return async_generate_tokens()
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "generate")
request = make_request(123, "test:App", "generate")
await worker.handle_request(request, writer)
# Should have 3 chunks + 1 end message
@ -175,7 +182,7 @@ class TestWorkerAsyncGeneratorStreaming:
# Check chunk messages
assert writer.messages[0]["type"] == "chunk"
assert writer.messages[0]["id"] == "req-123"
assert writer.messages[0]["id"] == 123
assert writer.messages[0]["data"] == "Hello"
assert writer.messages[1]["type"] == "chunk"
@ -186,7 +193,7 @@ class TestWorkerAsyncGeneratorStreaming:
# Check end message
assert writer.messages[3]["type"] == "end"
assert writer.messages[3]["id"] == "req-123"
assert writer.messages[3]["id"] == 123
@pytest.mark.asyncio
async def test_async_generator_error_mid_stream(self):
@ -202,7 +209,7 @@ class TestWorkerAsyncGeneratorStreaming:
return async_generate_with_error()
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "generate")
request = make_request(123, "test:App", "generate")
await worker.handle_request(request, writer)
# Should have 1 chunk + 1 error message
@ -228,13 +235,13 @@ class TestWorkerNonStreamingBackwardCompat:
return args[0] + args[1]
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "compute", args=(2, 3))
request = make_request(123, "test:App", "compute", args=(2, 3))
await worker.handle_request(request, writer)
# Should have 1 response message
assert len(writer.messages) == 1
assert writer.messages[0]["type"] == "response"
assert writer.messages[0]["id"] == "req-123"
assert writer.messages[0]["id"] == 123
assert writer.messages[0]["result"] == 5
@pytest.mark.asyncio
@ -247,7 +254,7 @@ class TestWorkerNonStreamingBackwardCompat:
return [1, 2, 3, 4, 5]
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "get_list")
request = make_request(123, "test:App", "get_list")
await worker.handle_request(request, writer)
# Should have 1 response message (not 5 chunks)
@ -265,7 +272,7 @@ class TestWorkerNonStreamingBackwardCompat:
raise RuntimeError("Failed!")
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "fail")
request = make_request(123, "test:App", "fail")
await worker.handle_request(request, writer)
# Should have 1 error message
@ -283,7 +290,7 @@ class TestWorkerNonStreamingBackwardCompat:
return None
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "void")
request = make_request(123, "test:App", "void")
await worker.handle_request(request, writer)
# Should have 1 response message
@ -309,7 +316,7 @@ class TestWorkerStreamingComplexData:
return generate_tokens()
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "generate")
request = make_request(123, "test:App", "generate")
await worker.handle_request(request, writer)
assert len(writer.messages) == 3 # 2 chunks + 1 end
@ -332,7 +339,7 @@ class TestWorkerStreamingComplexData:
return empty_generate()
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "generate")
request = make_request(123, "test:App", "generate")
await worker.handle_request(request, writer)
# Should have just 1 end message
@ -353,7 +360,7 @@ class TestWorkerStreamingComplexData:
return generate_many()
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "generate")
request = make_request(123, "test:App", "generate")
await worker.handle_request(request, writer)
# Should have 100 chunks + 1 end message
@ -390,7 +397,7 @@ class TestWorkerStreamingHeartbeat:
return generate_tokens()
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
request = make_request("req-123", "test:App", "generate")
request = make_request(123, "test:App", "generate")
await worker.handle_request(request, writer)
# Should have been notified at least once per chunk + initial
@ -407,7 +414,7 @@ class TestWorkerMessageTypeValidation:
writer = FakeStreamWriter()
# Send a message with unknown type
message = {"type": "unknown", "id": "req-123"}
message = {"type": "unknown", "id": 123}
await worker.handle_request(message, writer)
assert len(writer.messages) == 1

View File

@ -0,0 +1,21 @@
-----BEGIN CERTIFICATE-----
MIIDfDCCAmSgAwIBAgIUDxTarKRHe0FIyczGmoYwm377ZpcwDQYJKoZIhvcNAQEL
BQAwOTESMBAGA1UEAwwJbG9jYWxob3N0MRYwFAYDVQQKDA1HdW5pY29ybiBUZXN0
MQswCQYDVQQGEwJVUzAeFw0yNjAyMDUxMTE1MjJaFw0yNjAyMDYxMTE1MjJaMDkx
EjAQBgNVBAMMCWxvY2FsaG9zdDEWMBQGA1UECgwNR3VuaWNvcm4gVGVzdDELMAkG
A1UEBhMCVVMwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCRQTHakkqY
6l6dMqfs4oiY98+rjvZubwjp0PH7UBuxXCi/4Ao78o0JhKcs+jgAGAXyb8eRjEKt
z4rPoHZYE91D/eD0lWAz9r/LRoutDJd9IO0rfDtHlYXamciuxJJ8cckOrnuTXLtq
AWqjKR3U9RIDD3eumCKG4l7Py0L67zTomwMRPfeIdlWBfxGjWMqOewdTc/O/cuK2
HL5JP2ixy+iTufs0jhljI9cbu49J606f+TQH9eXRTD716q+KsHPJX1X5dVd7V7Lr
FIp7wSUFdbiy56JfmrGmfJbZgFH67P0ZyiTpQBaVHt1YYRIcOUJZqM+0MAtrsySC
TNA/LsI8tsybAgMBAAGjfDB6MB0GA1UdDgQWBBRK2VkAeM0hL4j/45ckkKbGrb/Q
FjAfBgNVHSMEGDAWgBRK2VkAeM0hL4j/45ckkKbGrb/QFjAPBgNVHRMBAf8EBTAD
AQH/MCcGA1UdEQQgMB6CCWxvY2FsaG9zdIILZ3VuaWNvcm4taDKHBH8AAAEwDQYJ
KoZIhvcNAQELBQADggEBAAXwuw0KTQUC4UEFudQ1rceK6By9WCSJND7xJi+UQ50G
Zrp5tJ2YB4ZWY+APadfuJo+zUxYVZ3jhs0mxgVeiGdDW6yZdHkeX8MlXBTLHR+/a
A7DXn6wCw9NDeDtcY/bKg5iamvoGGTL6szPrqeuZPz4UdbsFlr0MdcjgSNOqnkjr
YS4ukgZ71aWSjfraRRPjFMzkfnQ1xm96A1ngMH4DvU/t62D7r8+SvxQ8M6ERL84Z
FBu4bTXDdYIjJ24ojmDDO2irTVW1FMGXQTPzMaTEbE1rvBYeEYhf10KiMynK9xfO
5j8LWmCkgek0CqBrf3zbDEwu8QxcaxITAIUkSXLOZbo=
-----END CERTIFICATE-----

View File

@ -0,0 +1,28 @@
-----BEGIN PRIVATE KEY-----
MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQCRQTHakkqY6l6d
Mqfs4oiY98+rjvZubwjp0PH7UBuxXCi/4Ao78o0JhKcs+jgAGAXyb8eRjEKtz4rP
oHZYE91D/eD0lWAz9r/LRoutDJd9IO0rfDtHlYXamciuxJJ8cckOrnuTXLtqAWqj
KR3U9RIDD3eumCKG4l7Py0L67zTomwMRPfeIdlWBfxGjWMqOewdTc/O/cuK2HL5J
P2ixy+iTufs0jhljI9cbu49J606f+TQH9eXRTD716q+KsHPJX1X5dVd7V7LrFIp7
wSUFdbiy56JfmrGmfJbZgFH67P0ZyiTpQBaVHt1YYRIcOUJZqM+0MAtrsySCTNA/
LsI8tsybAgMBAAECggEANBhGOYZLI9G2sjlXOaG7bOU/wV9KKaw/7Z/HEaOW8wLD
CKHg+cQRai79yCdLi1kSVPNbB2vfBDhRqAp8NzWUn0x/8ChcsvZVriF0edFwyWtU
NErfddp+Absy2t9cTC6A9feFEYJqIug0JyVZciWc2qUi/ubIR0kLyQm00YuWFa/s
GJou8Nhg70rqW+3FB1H8kAEXqob+PFW4xbTwexw1+MbHxN7UKLTzS8uzYGLo2UpB
7bksumyD0o+lZtlx9HZ6CwrB6IPjgJ0HyaD8SrOc7/ozd7rR2LmvMmBCV1uC5VSO
jhr0PScLoNv60fjkVOiF9uqaPY2kNKymsOzpZ7/mwQKBgQDMcz+ve8WGGbE+bbM7
2uinQ5smm8rWPnfbHJIHQUetrEQKljRovybmjiiXN08uxlX6VA/Vnp4fmL5fzsTD
xTeiCVPsR1huXIfMLGJ6crUgvlbiaB8XsxtVNBpfEEtBe27qjSIj3xtmwqM6+LD1
FKLsYzgotHUH9JwyLA1RMKPBwQKBgQC14QWtI5YtZcTX46BqxlZ07iAAuy19Jywn
UtgmTawkJuEcseewIjxtJkMz+aSy7V3PsLII8tY48oSjAVx84w50zLJ2OlJnFT1S
zEmIOu9YDcGLZkYXJ2AwndRAIXpJVHwtFM9eDSMh+wVPBFeboYP1dO/VxmN6QV0W
GqDaQfItWwKBgEb31mp2n0j+UB0ofSfQxCOTfx62w4D87CPd1f64tUXe3zuBii21
9K3hOMvMwiqtZBjh5yEyzxaOsb6WCo0eP0J61GvXFCYy7lx8J67zdFYqXAR5OhnC
7UD1NhY7lLPlQcofNXOYNW3FMF3/B4X7JNbDVjIi+eDKExIDYpgFN0LBAoGADGCf
7kR5t+UxHDAVfq64u4RpESOr2NSNoK92nkSy7lLnBvjkd4wc6KCt+h+HIdYdiEDS
HOHJyl5WwHEbRjR9i11S19DoQrOjVLsqVecM2sU04rO3GWRIm4ZiJ2sf01W4jajY
4+Go/msC1XnKLIE1ZcLrf3Tc2DkSiKqPP8s1G/kCgYA8sCPAXedwhULhOBM45x4J
vkwT1Icm5RHOwOr8t34IFozTLokba6pjhYua3nE+V3FglRct7NpX+Op4gUgHa80g
5zoHboq5/pTUTclx41jndC1YGa3NLvthDWTWmyo/Qj7F/R7jGJf8E3KUDe0tFoSp
JlfEuUHtKpFJReBnmWTFiQ==
-----END PRIVATE KEY-----

View File

@ -14,7 +14,12 @@ import pytest
from gunicorn.config import Config
from gunicorn.dirty.arbiter import DirtyArbiter
from gunicorn.dirty.errors import DirtyError
from gunicorn.dirty.protocol import DirtyProtocol, make_request
from gunicorn.dirty.protocol import (
DirtyProtocol,
BinaryProtocol,
make_request,
HEADER_SIZE,
)
class MockStreamWriter:
@ -29,16 +34,22 @@ class MockStreamWriter:
self._buffer += data
async def drain(self):
while len(self._buffer) >= DirtyProtocol.HEADER_SIZE:
length = struct.unpack(
DirtyProtocol.HEADER_FORMAT,
self._buffer[:DirtyProtocol.HEADER_SIZE]
)[0]
total_size = DirtyProtocol.HEADER_SIZE + length
# Decode the buffer to extract messages using binary protocol
while len(self._buffer) >= HEADER_SIZE:
# Decode header to get payload length
_, _, length = BinaryProtocol.decode_header(
self._buffer[:HEADER_SIZE]
)
total_size = HEADER_SIZE + length
if len(self._buffer) >= total_size:
msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size]
msg_data = self._buffer[:total_size]
self._buffer = self._buffer[total_size:]
self.messages.append(DirtyProtocol.decode(msg_data))
# decode_message returns (msg_type_str, request_id, payload_dict)
msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data)
# Reconstruct the dict format for backwards compatibility
result = {"type": msg_type_str, "id": request_id}
result.update(payload_dict)
self.messages.append(result)
else:
break

View File

@ -11,7 +11,7 @@ import pytest
from gunicorn.arbiter import Arbiter
from gunicorn.config import Config
from gunicorn.app.base import BaseApplication
from gunicorn.dirty.protocol import DirtyProtocol
from gunicorn.dirty.protocol import DirtyProtocol, BinaryProtocol, HEADER_SIZE
class MockStreamWriter:
@ -26,16 +26,22 @@ class MockStreamWriter:
self._buffer += data
async def drain(self):
while len(self._buffer) >= DirtyProtocol.HEADER_SIZE:
length = struct.unpack(
DirtyProtocol.HEADER_FORMAT,
self._buffer[:DirtyProtocol.HEADER_SIZE]
)[0]
total_size = DirtyProtocol.HEADER_SIZE + length
# Decode the buffer to extract messages using binary protocol
while len(self._buffer) >= HEADER_SIZE:
# Decode header to get payload length
_, _, length = BinaryProtocol.decode_header(
self._buffer[:HEADER_SIZE]
)
total_size = HEADER_SIZE + length
if len(self._buffer) >= total_size:
msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size]
msg_data = self._buffer[:total_size]
self._buffer = self._buffer[total_size:]
self.messages.append(DirtyProtocol.decode(msg_data))
# decode_message returns (msg_type_str, request_id, payload_dict)
msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data)
# Reconstruct the dict format for backwards compatibility
result = {"type": msg_type_str, "id": request_id}
result.update(payload_dict)
self.messages.append(result)
else:
break

View File

@ -2,7 +2,7 @@
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""Tests for dirty arbiter protocol module."""
"""Tests for dirty worker binary protocol module."""
import asyncio
import os
@ -11,12 +11,23 @@ import struct
import pytest
from gunicorn.dirty.protocol import (
BinaryProtocol,
DirtyProtocol,
make_request,
make_response,
make_error_response,
make_chunk_message,
make_end_message,
MAGIC,
VERSION,
HEADER_SIZE,
HEADER_FORMAT,
MSG_TYPE_REQUEST,
MSG_TYPE_RESPONSE,
MSG_TYPE_ERROR,
MSG_TYPE_CHUNK,
MSG_TYPE_END,
MAX_MESSAGE_SIZE,
)
from gunicorn.dirty.errors import (
DirtyError,
@ -26,118 +37,194 @@ from gunicorn.dirty.errors import (
)
class TestDirtyProtocolEncodeDecode:
"""Tests for encode/decode functionality."""
class TestBinaryProtocolHeader:
"""Tests for header encoding/decoding."""
def test_encode_decode_roundtrip(self):
"""Test basic encode/decode roundtrip."""
message = {"type": "request", "id": "123", "data": "hello"}
encoded = DirtyProtocol.encode(message)
def test_header_size(self):
"""Test header size is 16 bytes."""
assert HEADER_SIZE == 16
# Check header format
assert len(encoded) > DirtyProtocol.HEADER_SIZE
length = struct.unpack(
DirtyProtocol.HEADER_FORMAT,
encoded[:DirtyProtocol.HEADER_SIZE]
)[0]
assert length == len(encoded) - DirtyProtocol.HEADER_SIZE
def test_encode_header(self):
"""Test header encoding."""
header = BinaryProtocol.encode_header(MSG_TYPE_REQUEST, 12345, 100)
assert len(header) == HEADER_SIZE
assert header[:2] == MAGIC
assert header[2] == VERSION
assert header[3] == MSG_TYPE_REQUEST
# Decode payload
payload = encoded[DirtyProtocol.HEADER_SIZE:]
decoded = DirtyProtocol.decode(payload)
assert decoded == message
def test_decode_header(self):
"""Test header decoding."""
header = BinaryProtocol.encode_header(MSG_TYPE_RESPONSE, 67890, 200)
msg_type, request_id, length = BinaryProtocol.decode_header(header)
assert msg_type == MSG_TYPE_RESPONSE
assert request_id == 67890
assert length == 200
def test_encode_decode_complex_data(self):
"""Test with complex nested data structures."""
message = {
"type": "response",
"id": "456",
"result": {
"models": ["gpt-4", "claude-3"],
"config": {"temperature": 0.7, "max_tokens": 1000},
"metadata": None,
},
"args": [1, 2, 3],
}
encoded = DirtyProtocol.encode(message)
payload = encoded[DirtyProtocol.HEADER_SIZE:]
decoded = DirtyProtocol.decode(payload)
assert decoded == message
def test_decode_header_invalid_magic(self):
"""Test header decoding with invalid magic."""
header = b"XX" + b"\x01\x01" + b"\x00" * 12
with pytest.raises(DirtyProtocolError) as exc_info:
BinaryProtocol.decode_header(header)
assert "magic" in str(exc_info.value).lower()
def test_encode_decode_unicode(self):
"""Test with unicode characters."""
message = {
"type": "request",
"data": "Hello, world!"
}
encoded = DirtyProtocol.encode(message)
payload = encoded[DirtyProtocol.HEADER_SIZE:]
decoded = DirtyProtocol.decode(payload)
assert decoded == message
def test_decode_header_invalid_version(self):
"""Test header decoding with invalid version."""
header = MAGIC + b"\x99\x01" + b"\x00" * 12
with pytest.raises(DirtyProtocolError) as exc_info:
BinaryProtocol.decode_header(header)
assert "version" in str(exc_info.value).lower()
def test_encode_large_message(self):
def test_decode_header_invalid_type(self):
"""Test header decoding with invalid message type."""
header = MAGIC + bytes([VERSION, 0xFF]) + b"\x00" * 12
with pytest.raises(DirtyProtocolError) as exc_info:
BinaryProtocol.decode_header(header)
assert "type" in str(exc_info.value).lower()
def test_decode_header_too_large(self):
"""Test header decoding rejects too-large messages."""
header = struct.pack(HEADER_FORMAT, MAGIC, VERSION, MSG_TYPE_REQUEST,
MAX_MESSAGE_SIZE + 1, 0)
with pytest.raises(DirtyProtocolError) as exc_info:
BinaryProtocol.decode_header(header)
assert "too large" in str(exc_info.value).lower()
def test_decode_header_too_short(self):
"""Test header decoding with too-short data."""
header = MAGIC + b"\x01"
with pytest.raises(DirtyProtocolError) as exc_info:
BinaryProtocol.decode_header(header)
assert "short" in str(exc_info.value).lower()
class TestBinaryProtocolEncodeDecode:
"""Tests for message encoding/decoding."""
def test_encode_decode_request(self):
"""Test request encoding/decoding roundtrip."""
encoded = BinaryProtocol.encode_request(
request_id=12345,
app_path="myapp.ml:MLApp",
action="predict",
args=("data",),
kwargs={"temperature": 0.7}
)
assert len(encoded) > HEADER_SIZE
msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
assert msg_type_str == "request"
assert request_id == 12345
assert payload["app_path"] == "myapp.ml:MLApp"
assert payload["action"] == "predict"
assert payload["args"] == ["data"]
assert payload["kwargs"] == {"temperature": 0.7}
def test_encode_decode_response(self):
"""Test response encoding/decoding roundtrip."""
result = {"predictions": [0.1, 0.9], "metadata": {"model": "v1"}}
encoded = BinaryProtocol.encode_response(request_id=67890, result=result)
msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
assert msg_type_str == "response"
assert request_id == 67890
assert payload["result"] == result
def test_encode_decode_error(self):
"""Test error encoding/decoding roundtrip."""
error = DirtyTimeoutError("Timed out", timeout=30)
encoded = BinaryProtocol.encode_error(request_id=11111, error=error)
msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
assert msg_type_str == "error"
assert request_id == 11111
assert payload["error"]["error_type"] == "DirtyTimeoutError"
assert "Timed out" in payload["error"]["message"]
def test_encode_decode_chunk(self):
"""Test chunk encoding/decoding roundtrip."""
chunk_data = {"token": "hello", "index": 5}
encoded = BinaryProtocol.encode_chunk(request_id=22222, data=chunk_data)
msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
assert msg_type_str == "chunk"
assert request_id == 22222
assert payload["data"] == chunk_data
def test_encode_decode_end(self):
"""Test end message encoding/decoding roundtrip."""
encoded = BinaryProtocol.encode_end(request_id=33333)
assert len(encoded) == HEADER_SIZE # End has no payload
msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
assert msg_type_str == "end"
assert request_id == 33333
assert payload == {}
def test_encode_decode_binary_data(self):
"""Test binary data passes through without base64 encoding."""
binary_data = bytes(range(256))
encoded = BinaryProtocol.encode_response(
request_id=44444,
result={"data": binary_data}
)
msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
assert payload["result"]["data"] == binary_data
def test_encode_decode_large_message(self):
"""Test encoding a large message."""
large_data = "x" * (1024 * 1024) # 1 MB of data
message = {"type": "request", "data": large_data}
encoded = DirtyProtocol.encode(message)
payload = encoded[DirtyProtocol.HEADER_SIZE:]
decoded = DirtyProtocol.decode(payload)
assert decoded == message
large_data = b"x" * (1024 * 1024) # 1 MB
encoded = BinaryProtocol.encode_response(
request_id=55555,
result={"data": large_data}
)
def test_encode_empty_dict(self):
"""Test encoding an empty dictionary."""
message = {}
encoded = DirtyProtocol.encode(message)
payload = encoded[DirtyProtocol.HEADER_SIZE:]
decoded = DirtyProtocol.decode(payload)
assert decoded == message
def test_encode_message_too_large(self):
"""Test that encoding a message that's too large raises error."""
large_data = "x" * (DirtyProtocol.MAX_MESSAGE_SIZE + 1000)
message = {"data": large_data}
with pytest.raises(DirtyProtocolError) as exc_info:
DirtyProtocol.encode(message)
assert "too large" in str(exc_info.value)
def test_encode_non_serializable(self):
"""Test that encoding non-JSON-serializable data raises error."""
message = {"func": lambda x: x}
with pytest.raises(DirtyProtocolError) as exc_info:
DirtyProtocol.encode(message)
assert "Failed to encode" in str(exc_info.value)
def test_decode_invalid_json(self):
"""Test decoding invalid JSON raises error."""
invalid_data = b"not valid json"
with pytest.raises(DirtyProtocolError) as exc_info:
DirtyProtocol.decode(invalid_data)
assert "Failed to decode" in str(exc_info.value)
def test_decode_invalid_unicode(self):
"""Test decoding invalid unicode raises error."""
invalid_data = b"\x80\x81\x82"
with pytest.raises(DirtyProtocolError) as exc_info:
DirtyProtocol.decode(invalid_data)
assert "Failed to decode" in str(exc_info.value)
msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
assert payload["result"]["data"] == large_data
class TestDirtyProtocolSync:
class TestBinaryProtocolSync:
"""Tests for synchronous socket operations."""
def test_read_write_message(self):
"""Test read/write through socket pair."""
# Create a socket pair for testing
server_sock, client_sock = socket.socketpair()
try:
message = {"type": "request", "id": "123", "action": "test"}
message = make_request(
request_id=12345,
app_path="test:App",
action="run"
)
# Write message
DirtyProtocol.write_message(client_sock, message)
BinaryProtocol.write_message(client_sock, message)
received = BinaryProtocol.read_message(server_sock)
# Read message
received = DirtyProtocol.read_message(server_sock)
assert received == message
assert received["type"] == "request"
assert received["id"] == hash("12345") & 0xFFFFFFFFFFFFFFFF or \
received["id"] == 12345
assert received["app_path"] == "test:App"
assert received["action"] == "run"
finally:
server_sock.close()
client_sock.close()
def test_read_write_with_int_id(self):
"""Test read/write with integer request ID."""
server_sock, client_sock = socket.socketpair()
try:
message = {
"type": "request",
"id": 999888777,
"app_path": "test:App",
"action": "run",
"args": [],
"kwargs": {}
}
BinaryProtocol.write_message(client_sock, message)
received = BinaryProtocol.read_message(server_sock)
assert received["id"] == 999888777
finally:
server_sock.close()
client_sock.close()
@ -147,19 +234,17 @@ class TestDirtyProtocolSync:
server_sock, client_sock = socket.socketpair()
try:
messages = [
{"type": "request", "id": "1"},
{"type": "request", "id": "2"},
{"type": "request", "id": "3"},
make_request(i, f"app{i}:App", f"action{i}")
for i in range(1, 4)
]
# Write all messages
for msg in messages:
DirtyProtocol.write_message(client_sock, msg)
BinaryProtocol.write_message(client_sock, msg)
# Read all messages
for expected in messages:
received = DirtyProtocol.read_message(server_sock)
assert received == expected
for i, _ in enumerate(messages, 1):
received = BinaryProtocol.read_message(server_sock)
assert received["app_path"] == f"app{i}:App"
assert received["action"] == f"action{i}"
finally:
server_sock.close()
client_sock.close()
@ -169,39 +254,51 @@ class TestDirtyProtocolSync:
server_sock, client_sock = socket.socketpair()
client_sock.close()
with pytest.raises(DirtyProtocolError) as exc_info:
DirtyProtocol.read_message(server_sock)
BinaryProtocol.read_message(server_sock)
assert "closed" in str(exc_info.value).lower()
server_sock.close()
def test_binary_data_roundtrip(self):
"""Test binary data roundtrip through socket."""
server_sock, client_sock = socket.socketpair()
try:
binary_payload = b"\x00\x01\x02\xff\xfe\xfd"
message = make_response(12345, {"binary": binary_payload})
class TestDirtyProtocolAsync:
BinaryProtocol.write_message(client_sock, message)
received = BinaryProtocol.read_message(server_sock)
assert received["result"]["binary"] == binary_payload
finally:
server_sock.close()
client_sock.close()
class TestBinaryProtocolAsync:
"""Tests for async stream operations."""
@pytest.mark.asyncio
async def test_async_read_write(self):
"""Test async read/write with mock streams."""
message = {"type": "request", "id": "123"}
message = make_request(12345, "test:App", "run")
# Create a pipe for testing
read_fd, write_fd = os.pipe()
try:
reader = asyncio.StreamReader()
_ = asyncio.StreamReaderProtocol(reader)
# Write the message to the pipe
encoded = DirtyProtocol.encode(message)
encoded = BinaryProtocol._encode_from_dict(message)
os.write(write_fd, encoded)
os.close(write_fd)
write_fd = None
# Feed data to reader
data = os.read(read_fd, len(encoded))
reader.feed_data(data)
reader.feed_eof()
# Read the message
received = await DirtyProtocol.read_message_async(reader)
assert received == message
received = await BinaryProtocol.read_message_async(reader)
assert received["type"] == "request"
assert received["app_path"] == "test:App"
finally:
if write_fd is not None:
os.close(write_fd)
@ -211,12 +308,11 @@ class TestDirtyProtocolAsync:
async def test_async_read_incomplete_header(self):
"""Test async read with incomplete header."""
reader = asyncio.StreamReader()
# Feed only 2 bytes instead of 4
reader.feed_data(b"\x00\x00")
reader.feed_data(MAGIC + b"\x01") # Only 3 bytes
reader.feed_eof()
with pytest.raises((asyncio.IncompleteReadError, DirtyProtocolError)):
await DirtyProtocol.read_message_async(reader)
await BinaryProtocol.read_message_async(reader)
@pytest.mark.asyncio
async def test_async_read_empty_connection(self):
@ -225,36 +321,33 @@ class TestDirtyProtocolAsync:
reader.feed_eof()
with pytest.raises(asyncio.IncompleteReadError):
await DirtyProtocol.read_message_async(reader)
await BinaryProtocol.read_message_async(reader)
@pytest.mark.asyncio
async def test_async_read_invalid_magic(self):
"""Test async read rejects invalid magic."""
reader = asyncio.StreamReader()
header = b"XX" + bytes([VERSION, MSG_TYPE_REQUEST]) + b"\x00" * 12
reader.feed_data(header)
reader.feed_eof()
with pytest.raises(DirtyProtocolError) as exc_info:
await BinaryProtocol.read_message_async(reader)
assert "magic" in str(exc_info.value).lower()
@pytest.mark.asyncio
async def test_async_read_message_too_large(self):
"""Test async read rejects too-large messages."""
reader = asyncio.StreamReader()
# Create a header claiming an absurdly large message
header = struct.pack(
DirtyProtocol.HEADER_FORMAT,
DirtyProtocol.MAX_MESSAGE_SIZE + 1000
)
header = struct.pack(HEADER_FORMAT, MAGIC, VERSION, MSG_TYPE_REQUEST,
MAX_MESSAGE_SIZE + 1000, 0)
reader.feed_data(header)
reader.feed_eof()
with pytest.raises(DirtyProtocolError) as exc_info:
await DirtyProtocol.read_message_async(reader)
await BinaryProtocol.read_message_async(reader)
assert "too large" in str(exc_info.value)
@pytest.mark.asyncio
async def test_async_read_empty_message(self):
"""Test async read rejects empty messages."""
reader = asyncio.StreamReader()
header = struct.pack(DirtyProtocol.HEADER_FORMAT, 0)
reader.feed_data(header)
reader.feed_eof()
with pytest.raises(DirtyProtocolError) as exc_info:
await DirtyProtocol.read_message_async(reader)
assert "Empty message" in str(exc_info.value)
class TestMessageBuilders:
"""Tests for message builder helper functions."""
@ -340,9 +433,9 @@ class TestMessageBuilders:
assert chunk["id"] == "req-456"
assert chunk["data"] == data
def test_make_chunk_message_with_list_data(self):
"""Test chunk message with list data."""
data = [1, 2, 3, "token"]
def test_make_chunk_message_with_binary_data(self):
"""Test chunk message with binary data."""
data = b"\x00\x01\x02\xff"
chunk = make_chunk_message("req-789", data)
assert chunk["data"] == data
@ -353,22 +446,22 @@ class TestMessageBuilders:
assert end["id"] == "req-123"
assert "data" not in end
def test_chunk_and_end_encode_decode(self):
def test_chunk_and_end_roundtrip(self):
"""Test that chunk and end messages can be encoded/decoded."""
chunk = make_chunk_message("req-123", {"token": "hello"})
end = make_end_message("req-123")
chunk = make_chunk_message(12345, {"token": "hello"})
end = make_end_message(12345)
# Test chunk roundtrip
encoded_chunk = DirtyProtocol.encode(chunk)
payload = encoded_chunk[DirtyProtocol.HEADER_SIZE:]
decoded = DirtyProtocol.decode(payload)
assert decoded == chunk
encoded_chunk = BinaryProtocol._encode_from_dict(chunk)
msg_type, req_id, payload = BinaryProtocol.decode_message(encoded_chunk)
assert msg_type == "chunk"
assert payload["data"] == {"token": "hello"}
# Test end roundtrip
encoded_end = DirtyProtocol.encode(end)
payload = encoded_end[DirtyProtocol.HEADER_SIZE:]
decoded = DirtyProtocol.decode(payload)
assert decoded == end
encoded_end = BinaryProtocol._encode_from_dict(end)
msg_type, req_id, payload = BinaryProtocol.decode_message(encoded_end)
assert msg_type == "end"
assert payload == {}
class TestDirtyErrors:
@ -417,3 +510,58 @@ class TestDirtyErrors:
assert error.action == "run"
assert error.traceback == "Traceback..."
assert "myapp:App" in str(error)
class TestBackwardsCompatibility:
"""Tests for backwards compatibility with old JSON API."""
def test_dirty_protocol_alias(self):
"""Test that DirtyProtocol is an alias for BinaryProtocol."""
assert DirtyProtocol is BinaryProtocol
def test_header_size_attribute(self):
"""Test HEADER_SIZE is accessible on class."""
assert DirtyProtocol.HEADER_SIZE == 16
def test_msg_type_constants(self):
"""Test message type constants are strings for compatibility."""
assert DirtyProtocol.MSG_TYPE_REQUEST == "request"
assert DirtyProtocol.MSG_TYPE_RESPONSE == "response"
assert DirtyProtocol.MSG_TYPE_ERROR == "error"
assert DirtyProtocol.MSG_TYPE_CHUNK == "chunk"
assert DirtyProtocol.MSG_TYPE_END == "end"
def test_encode_decode_preserves_dict_format(self):
"""Test that read_message returns dict compatible with old API."""
server_sock, client_sock = socket.socketpair()
try:
message = {
"type": "response",
"id": 12345,
"result": {"status": "ok"}
}
DirtyProtocol.write_message(client_sock, message)
received = DirtyProtocol.read_message(server_sock)
# Old API: access via dict keys
assert received["type"] == "response"
assert received["result"]["status"] == "ok"
finally:
server_sock.close()
client_sock.close()
def test_string_request_id_handled(self):
"""Test that string request IDs are handled (hashed to int)."""
server_sock, client_sock = socket.socketpair()
try:
message = make_request("uuid-string-id", "test:App", "run")
DirtyProtocol.write_message(client_sock, message)
received = DirtyProtocol.read_message(server_sock)
# Request ID should be converted to int
assert isinstance(received["id"], int)
finally:
server_sock.close()
client_sock.close()

554
tests/test_dirty_tlv.py Normal file
View File

@ -0,0 +1,554 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""Tests for dirty TLV binary encoder/decoder."""
import math
import struct
import pytest
from gunicorn.dirty.tlv import (
TLVEncoder,
TYPE_NONE,
TYPE_BOOL,
TYPE_INT64,
TYPE_FLOAT64,
TYPE_BYTES,
TYPE_STRING,
TYPE_LIST,
TYPE_DICT,
MAX_STRING_SIZE,
MAX_BYTES_SIZE,
MAX_LIST_SIZE,
MAX_DICT_SIZE,
)
from gunicorn.dirty.errors import DirtyProtocolError
class TestTLVEncoderBasicTypes:
"""Tests for basic type encoding/decoding."""
def test_encode_decode_none(self):
"""Test None encoding/decoding."""
encoded = TLVEncoder.encode(None)
assert encoded == bytes([TYPE_NONE])
value, offset = TLVEncoder.decode(encoded, 0)
assert value is None
assert offset == 1
def test_encode_decode_true(self):
"""Test True encoding/decoding."""
encoded = TLVEncoder.encode(True)
assert encoded == bytes([TYPE_BOOL, 0x01])
value, offset = TLVEncoder.decode(encoded, 0)
assert value is True
assert offset == 2
def test_encode_decode_false(self):
"""Test False encoding/decoding."""
encoded = TLVEncoder.encode(False)
assert encoded == bytes([TYPE_BOOL, 0x00])
value, offset = TLVEncoder.decode(encoded, 0)
assert value is False
assert offset == 2
def test_encode_decode_positive_int(self):
"""Test positive integer encoding/decoding."""
encoded = TLVEncoder.encode(42)
assert encoded[0] == TYPE_INT64
assert len(encoded) == 9 # 1 type + 8 value
value, offset = TLVEncoder.decode(encoded, 0)
assert value == 42
assert offset == 9
def test_encode_decode_negative_int(self):
"""Test negative integer encoding/decoding."""
encoded = TLVEncoder.encode(-12345)
value, offset = TLVEncoder.decode(encoded, 0)
assert value == -12345
def test_encode_decode_large_int(self):
"""Test large integer encoding/decoding."""
large_val = 2**62
encoded = TLVEncoder.encode(large_val)
value, offset = TLVEncoder.decode(encoded, 0)
assert value == large_val
def test_encode_decode_zero(self):
"""Test zero encoding/decoding."""
encoded = TLVEncoder.encode(0)
value, offset = TLVEncoder.decode(encoded, 0)
assert value == 0
def test_encode_decode_float(self):
"""Test float encoding/decoding."""
encoded = TLVEncoder.encode(3.14159)
assert encoded[0] == TYPE_FLOAT64
assert len(encoded) == 9 # 1 type + 8 value
value, offset = TLVEncoder.decode(encoded, 0)
assert abs(value - 3.14159) < 1e-10
def test_encode_decode_negative_float(self):
"""Test negative float encoding/decoding."""
encoded = TLVEncoder.encode(-273.15)
value, offset = TLVEncoder.decode(encoded, 0)
assert abs(value - (-273.15)) < 1e-10
def test_encode_decode_float_infinity(self):
"""Test infinity encoding/decoding."""
encoded = TLVEncoder.encode(float('inf'))
value, offset = TLVEncoder.decode(encoded, 0)
assert value == float('inf')
def test_encode_decode_float_nan(self):
"""Test NaN encoding/decoding."""
encoded = TLVEncoder.encode(float('nan'))
value, offset = TLVEncoder.decode(encoded, 0)
assert math.isnan(value)
class TestTLVEncoderBytes:
"""Tests for bytes encoding/decoding."""
def test_encode_decode_empty_bytes(self):
"""Test empty bytes encoding/decoding."""
encoded = TLVEncoder.encode(b"")
assert encoded[0] == TYPE_BYTES
value, offset = TLVEncoder.decode(encoded, 0)
assert value == b""
def test_encode_decode_bytes(self):
"""Test bytes encoding/decoding."""
data = b"\x00\x01\x02\xff\xfe\xfd"
encoded = TLVEncoder.encode(data)
value, offset = TLVEncoder.decode(encoded, 0)
assert value == data
def test_encode_decode_large_bytes(self):
"""Test large bytes encoding/decoding."""
data = b"x" * 10000
encoded = TLVEncoder.encode(data)
value, offset = TLVEncoder.decode(encoded, 0)
assert value == data
def test_bytes_too_large(self):
"""Test that bytes exceeding max size raises error."""
# We won't actually allocate MAX_BYTES_SIZE, just check the encoding
with pytest.raises(DirtyProtocolError) as exc_info:
TLVEncoder.encode(b"x" * (MAX_BYTES_SIZE + 1))
assert "too large" in str(exc_info.value).lower()
class TestTLVEncoderString:
"""Tests for string encoding/decoding."""
def test_encode_decode_empty_string(self):
"""Test empty string encoding/decoding."""
encoded = TLVEncoder.encode("")
assert encoded[0] == TYPE_STRING
value, offset = TLVEncoder.decode(encoded, 0)
assert value == ""
def test_encode_decode_ascii_string(self):
"""Test ASCII string encoding/decoding."""
encoded = TLVEncoder.encode("hello world")
value, offset = TLVEncoder.decode(encoded, 0)
assert value == "hello world"
def test_encode_decode_unicode_string(self):
"""Test Unicode string encoding/decoding."""
text = "Hello, world! \u00a9 \u2603 \U0001F600"
encoded = TLVEncoder.encode(text)
value, offset = TLVEncoder.decode(encoded, 0)
assert value == text
def test_encode_decode_chinese(self):
"""Test Chinese characters encoding/decoding."""
text = "Hello, world!"
encoded = TLVEncoder.encode(text)
value, offset = TLVEncoder.decode(encoded, 0)
assert value == text
def test_encode_decode_emoji(self):
"""Test emoji encoding/decoding."""
text = "Test emoji"
encoded = TLVEncoder.encode(text)
value, offset = TLVEncoder.decode(encoded, 0)
assert value == text
def test_encode_decode_large_string(self):
"""Test large string encoding/decoding."""
text = "x" * 10000
encoded = TLVEncoder.encode(text)
value, offset = TLVEncoder.decode(encoded, 0)
assert value == text
class TestTLVEncoderList:
"""Tests for list encoding/decoding."""
def test_encode_decode_empty_list(self):
"""Test empty list encoding/decoding."""
encoded = TLVEncoder.encode([])
assert encoded[0] == TYPE_LIST
value, offset = TLVEncoder.decode(encoded, 0)
assert value == []
def test_encode_decode_simple_list(self):
"""Test simple list encoding/decoding."""
data = [1, 2, 3]
encoded = TLVEncoder.encode(data)
value, offset = TLVEncoder.decode(encoded, 0)
assert value == data
def test_encode_decode_mixed_list(self):
"""Test mixed type list encoding/decoding."""
data = [1, "hello", 3.14, True, None, b"bytes"]
encoded = TLVEncoder.encode(data)
value, offset = TLVEncoder.decode(encoded, 0)
assert value == data
def test_encode_decode_nested_list(self):
"""Test nested list encoding/decoding."""
data = [[1, 2], [3, [4, 5]], ["a", "b"]]
encoded = TLVEncoder.encode(data)
value, offset = TLVEncoder.decode(encoded, 0)
assert value == data
def test_encode_decode_tuple_as_list(self):
"""Test that tuples are encoded as lists."""
data = (1, 2, 3)
encoded = TLVEncoder.encode(data)
value, offset = TLVEncoder.decode(encoded, 0)
assert value == [1, 2, 3] # Decoded as list
def test_encode_decode_large_list(self):
"""Test large list encoding/decoding."""
data = list(range(1000))
encoded = TLVEncoder.encode(data)
value, offset = TLVEncoder.decode(encoded, 0)
assert value == data
class TestTLVEncoderDict:
"""Tests for dict encoding/decoding."""
def test_encode_decode_empty_dict(self):
"""Test empty dict encoding/decoding."""
encoded = TLVEncoder.encode({})
assert encoded[0] == TYPE_DICT
value, offset = TLVEncoder.decode(encoded, 0)
assert value == {}
def test_encode_decode_simple_dict(self):
"""Test simple dict encoding/decoding."""
data = {"a": 1, "b": 2, "c": 3}
encoded = TLVEncoder.encode(data)
value, offset = TLVEncoder.decode(encoded, 0)
assert value == data
def test_encode_decode_mixed_values_dict(self):
"""Test dict with mixed value types."""
data = {
"int": 42,
"float": 3.14,
"string": "hello",
"bool": True,
"none": None,
"bytes": b"data",
"list": [1, 2, 3],
}
encoded = TLVEncoder.encode(data)
value, offset = TLVEncoder.decode(encoded, 0)
assert value == data
def test_encode_decode_nested_dict(self):
"""Test nested dict encoding/decoding."""
data = {
"outer": {
"inner": {
"value": 42
},
"list": [{"a": 1}, {"b": 2}]
}
}
encoded = TLVEncoder.encode(data)
value, offset = TLVEncoder.decode(encoded, 0)
assert value == data
def test_encode_dict_non_string_key_converted(self):
"""Test that non-string keys are converted to strings (like JSON)."""
data = {1: "value", 2: "other"}
encoded = TLVEncoder.encode(data)
decoded, _ = TLVEncoder.decode(encoded, 0)
# Keys should be converted to strings
assert decoded == {"1": "value", "2": "other"}
class TestTLVEncoderComplexStructures:
"""Tests for complex nested structures."""
def test_encode_decode_request_like(self):
"""Test encoding/decoding a request-like structure."""
data = {
"id": 12345,
"app_path": "myapp.ml:MLApp",
"action": "predict",
"args": [b"input_data", 0.7],
"kwargs": {"temperature": 0.7, "max_tokens": 1000},
}
encoded = TLVEncoder.encode(data)
value, offset = TLVEncoder.decode(encoded, 0)
assert value == data
def test_encode_decode_response_like(self):
"""Test encoding/decoding a response-like structure."""
data = {
"id": 12345,
"result": {
"predictions": [0.1, 0.2, 0.7],
"metadata": {"model": "v1.0", "latency_ms": 42},
}
}
encoded = TLVEncoder.encode(data)
value, offset = TLVEncoder.decode(encoded, 0)
assert value == data
def test_encode_decode_deeply_nested(self):
"""Test deeply nested structures."""
data = {"a": {"b": {"c": {"d": {"e": {"f": "deep"}}}}}}
encoded = TLVEncoder.encode(data)
value, offset = TLVEncoder.decode(encoded, 0)
assert value == data
class TestTLVEncoderRoundtrip:
"""Tests for complete roundtrip using decode_full."""
def test_decode_full_simple(self):
"""Test decode_full with simple value."""
data = {"key": "value"}
encoded = TLVEncoder.encode(data)
value = TLVEncoder.decode_full(encoded)
assert value == data
def test_decode_full_trailing_data(self):
"""Test decode_full raises on trailing data."""
encoded = TLVEncoder.encode(42) + b"extra"
with pytest.raises(DirtyProtocolError) as exc_info:
TLVEncoder.decode_full(encoded)
assert "trailing" in str(exc_info.value).lower()
class TestTLVEncoderErrors:
"""Tests for error handling."""
def test_decode_empty_data(self):
"""Test decoding empty data raises error."""
with pytest.raises(DirtyProtocolError) as exc_info:
TLVEncoder.decode(b"", 0)
assert "truncated" in str(exc_info.value).lower()
def test_decode_truncated_int(self):
"""Test decoding truncated int raises error."""
# TYPE_INT64 followed by only 4 bytes instead of 8
data = bytes([TYPE_INT64, 0, 0, 0, 0])
with pytest.raises(DirtyProtocolError) as exc_info:
TLVEncoder.decode(data, 0)
assert "truncated" in str(exc_info.value).lower()
def test_decode_truncated_float(self):
"""Test decoding truncated float raises error."""
data = bytes([TYPE_FLOAT64, 0, 0, 0, 0])
with pytest.raises(DirtyProtocolError) as exc_info:
TLVEncoder.decode(data, 0)
assert "truncated" in str(exc_info.value).lower()
def test_decode_truncated_bytes_length(self):
"""Test decoding truncated bytes length raises error."""
data = bytes([TYPE_BYTES, 0, 0]) # Only 2 bytes of length
with pytest.raises(DirtyProtocolError) as exc_info:
TLVEncoder.decode(data, 0)
assert "truncated" in str(exc_info.value).lower()
def test_decode_truncated_bytes_data(self):
"""Test decoding truncated bytes data raises error."""
# Says 10 bytes but only provides 5
data = bytes([TYPE_BYTES]) + struct.pack(">I", 10) + b"12345"
with pytest.raises(DirtyProtocolError) as exc_info:
TLVEncoder.decode(data, 0)
assert "truncated" in str(exc_info.value).lower()
def test_decode_truncated_string_length(self):
"""Test decoding truncated string length raises error."""
data = bytes([TYPE_STRING, 0])
with pytest.raises(DirtyProtocolError) as exc_info:
TLVEncoder.decode(data, 0)
assert "truncated" in str(exc_info.value).lower()
def test_decode_truncated_string_data(self):
"""Test decoding truncated string data raises error."""
data = bytes([TYPE_STRING]) + struct.pack(">I", 10) + b"hello"
with pytest.raises(DirtyProtocolError) as exc_info:
TLVEncoder.decode(data, 0)
assert "truncated" in str(exc_info.value).lower()
def test_decode_invalid_utf8(self):
"""Test decoding invalid UTF-8 raises error."""
# Valid length, but invalid UTF-8 bytes
data = bytes([TYPE_STRING]) + struct.pack(">I", 3) + b"\x80\x81\x82"
with pytest.raises(DirtyProtocolError) as exc_info:
TLVEncoder.decode(data, 0)
assert "utf-8" in str(exc_info.value).lower()
def test_decode_truncated_list_count(self):
"""Test decoding truncated list count raises error."""
data = bytes([TYPE_LIST, 0])
with pytest.raises(DirtyProtocolError) as exc_info:
TLVEncoder.decode(data, 0)
assert "truncated" in str(exc_info.value).lower()
def test_decode_truncated_dict_count(self):
"""Test decoding truncated dict count raises error."""
data = bytes([TYPE_DICT, 0])
with pytest.raises(DirtyProtocolError) as exc_info:
TLVEncoder.decode(data, 0)
assert "truncated" in str(exc_info.value).lower()
def test_decode_unknown_type(self):
"""Test decoding unknown type raises error."""
data = bytes([0xFF]) # Unknown type
with pytest.raises(DirtyProtocolError) as exc_info:
TLVEncoder.decode(data, 0)
assert "unknown" in str(exc_info.value).lower()
def test_encode_unsupported_type(self):
"""Test encoding unsupported type raises error."""
with pytest.raises(DirtyProtocolError) as exc_info:
TLVEncoder.encode(object())
assert "unsupported type" in str(exc_info.value).lower()
def test_encode_function_raises_error(self):
"""Test encoding a function raises error."""
with pytest.raises(DirtyProtocolError) as exc_info:
TLVEncoder.encode(lambda x: x)
assert "unsupported type" in str(exc_info.value).lower()
def test_decode_dict_non_string_key_in_data(self):
"""Test decoding dict with non-string key raises error."""
# Manually construct a dict with int key
# TYPE_DICT, count=1, TYPE_INT64 key, TYPE_INT64 value
data = (
bytes([TYPE_DICT])
+ struct.pack(">I", 1)
+ bytes([TYPE_INT64])
+ struct.pack(">q", 1) # Key (int, not string)
+ bytes([TYPE_INT64])
+ struct.pack(">q", 2) # Value
)
with pytest.raises(DirtyProtocolError) as exc_info:
TLVEncoder.decode(data, 0)
assert "string" in str(exc_info.value).lower()
class TestTLVEncoderOffset:
"""Tests for offset handling."""
def test_decode_with_offset(self):
"""Test decoding from specific offset."""
# Create data with prefix
prefix = b"garbage"
encoded = TLVEncoder.encode(42)
data = prefix + encoded
value, offset = TLVEncoder.decode(data, len(prefix))
assert value == 42
assert offset == len(prefix) + len(encoded)
def test_decode_multiple_values(self):
"""Test decoding multiple consecutive values."""
v1 = TLVEncoder.encode("hello")
v2 = TLVEncoder.encode(42)
v3 = TLVEncoder.encode([1, 2, 3])
data = v1 + v2 + v3
offset = 0
val1, offset = TLVEncoder.decode(data, offset)
assert val1 == "hello"
val2, offset = TLVEncoder.decode(data, offset)
assert val2 == 42
val3, offset = TLVEncoder.decode(data, offset)
assert val3 == [1, 2, 3]
assert offset == len(data)
class TestTLVEncoderBinaryData:
"""Tests for binary data handling (the main motivation for this protocol)."""
def test_binary_data_no_encoding(self):
"""Test that binary data is passed through without encoding."""
# This is the key advantage over JSON - binary data doesn't need base64
binary_data = bytes(range(256)) # All byte values
encoded = TLVEncoder.encode(binary_data)
value, offset = TLVEncoder.decode(encoded, 0)
assert value == binary_data
def test_binary_with_null_bytes(self):
"""Test binary data with embedded null bytes."""
binary_data = b"\x00\x00\xff\x00\x00"
encoded = TLVEncoder.encode(binary_data)
value, offset = TLVEncoder.decode(encoded, 0)
assert value == binary_data
def test_binary_in_nested_structure(self):
"""Test binary data inside nested structures."""
data = {
"image": b"\x89PNG\r\n\x1a\n" + b"\x00" * 100,
"metadata": {"width": 640, "height": 480},
"chunks": [b"chunk1", b"chunk2", b"chunk3"],
}
encoded = TLVEncoder.encode(data)
value, offset = TLVEncoder.decode(encoded, 0)
assert value == data

View File

@ -12,7 +12,13 @@ import pytest
from gunicorn.config import Config
from gunicorn.dirty.worker import DirtyWorker
from gunicorn.dirty.protocol import DirtyProtocol, make_request
from gunicorn.dirty.protocol import (
DirtyProtocol,
BinaryProtocol,
make_request,
HEADER_SIZE,
HEADER_FORMAT,
)
from gunicorn.dirty.errors import DirtyAppNotFoundError
@ -56,17 +62,22 @@ class MockStreamWriter:
self._buffer += data
async def drain(self):
# Decode the buffer to extract messages
while len(self._buffer) >= DirtyProtocol.HEADER_SIZE:
length = struct.unpack(
DirtyProtocol.HEADER_FORMAT,
self._buffer[:DirtyProtocol.HEADER_SIZE]
)[0]
total_size = DirtyProtocol.HEADER_SIZE + length
# Decode the buffer to extract messages using binary protocol
while len(self._buffer) >= HEADER_SIZE:
# Decode header to get payload length
_, _, length = BinaryProtocol.decode_header(
self._buffer[:HEADER_SIZE]
)
total_size = HEADER_SIZE + length
if len(self._buffer) >= total_size:
msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size]
msg_data = self._buffer[:total_size]
self._buffer = self._buffer[total_size:]
self.messages.append(DirtyProtocol.decode(msg_data))
# decode_message returns (msg_type_str, request_id, payload_dict)
msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data)
# Reconstruct the dict format for backwards compatibility
result = {"type": msg_type_str, "id": request_id}
result.update(payload_dict)
self.messages.append(result)
else:
break
@ -246,7 +257,7 @@ class TestDirtyWorkerHandleRequest:
worker.load_apps()
request = make_request(
request_id="test-123",
request_id=123,
app_path="tests.support_dirty_app:TestDirtyApp",
action="compute",
args=(2, 3),
@ -259,7 +270,7 @@ class TestDirtyWorkerHandleRequest:
assert len(writer.messages) == 1
response = writer.messages[0]
assert response["type"] == DirtyProtocol.MSG_TYPE_RESPONSE
assert response["id"] == "test-123"
assert response["id"] == 123
assert response["result"] == 6
@pytest.mark.asyncio
@ -282,7 +293,7 @@ class TestDirtyWorkerHandleRequest:
worker.load_apps()
request = make_request(
request_id="test-456",
request_id=456,
app_path="tests.support_dirty_app:TestDirtyApp",
action="compute",
args=(2, 3),
@ -295,7 +306,7 @@ class TestDirtyWorkerHandleRequest:
assert len(writer.messages) == 1
response = writer.messages[0]
assert response["type"] == DirtyProtocol.MSG_TYPE_ERROR
assert response["id"] == "test-456"
assert response["id"] == 456
assert "Unknown operation" in response["error"]["message"]
@pytest.mark.asyncio
@ -315,7 +326,7 @@ class TestDirtyWorkerHandleRequest:
socket_path=socket_path
)
request = {"type": "unknown", "id": "test-789"}
request = {"type": "unknown", "id": 789}
writer = MockStreamWriter()
await worker.handle_request(request, writer)
@ -697,7 +708,7 @@ class TestDirtyWorkerRunAsync:
# Create a simple test using stream reader/writer
request = make_request(
request_id="conn-test",
request_id=999,
app_path="tests.support_dirty_app:TestDirtyApp",
action="compute",
args=(5, 3),
@ -706,7 +717,7 @@ class TestDirtyWorkerRunAsync:
# Mock reader and writer
reader = asyncio.StreamReader()
encoded_request = DirtyProtocol.encode(request)
encoded_request = BinaryProtocol._encode_from_dict(request)
reader.feed_data(encoded_request)
reader.feed_eof()