mirror of
https://github.com/frappe/gunicorn.git
synced 2026-07-03 03:01:31 +08:00
Merge pull request #3500 from benoitc/feature/binary-dirty-protocol
feat(dirty): implement binary protocol for dirty worker IPC
commit
4b90e4ba16
@ -17,7 +17,7 @@ Dirty Arbiters solve this by providing:
|
||||
|
||||
- **Separate worker pool** - Completely separate from HTTP workers, can be killed/restarted independently
|
||||
- **Stateful workers** - Loaded resources persist in dirty worker memory
|
||||
-- **Message-passing IPC** - Communication via Unix sockets with JSON serialization
+- **Message-passing IPC** - Communication via Unix sockets with binary TLV protocol
|
||||
- **Explicit API** - Clear `execute()` calls (no hidden IPC)
|
||||
- **Asyncio-based** - Clean concurrent handling with streaming support
|
||||
|
||||
@ -476,6 +476,64 @@ Worker -> Arbiter -> Client: chunk (data: "World")
|
||||
Worker -> Arbiter -> Client: end
|
||||
```
|
||||
|
||||
## Binary Protocol
|
||||
|
||||
The dirty worker IPC uses a binary protocol inspired by OpenBSD msgctl/msgsnd for efficient data transfer. This eliminates base64 encoding overhead for binary data like images, audio, or model weights.
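
The base64 overhead mentioned above can be made concrete with a small standalone sketch. It uses only the Python standard library and does not touch gunicorn; the payload and the `data` key are illustrative assumptions, not part of the protocol.

```python
import base64
import json

# 256 raw bytes standing in for image, audio, or model data (illustrative only)
raw = bytes(range(256))

# JSON cannot carry bytes, so a JSON-based transport has to text-encode them first
b64_text = base64.b64encode(raw).decode("ascii")
json_payload = json.dumps({"data": b64_text}).encode("utf-8")

print(len(raw))       # 256
print(len(b64_text))  # 344, roughly one third larger than the raw bytes

# The receiver must also spend time decoding before it can use the data
assert base64.b64decode(b64_text) == raw
```

A binary framing avoids both the size increase and the extra encode/decode pass, which matters most for large payloads.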
|
||||
|
||||
### Header Format (16 bytes)

```
+--------+--------+--------+--------+--------+--------+--------+--------+
|    Magic (2B)   | Ver(1) | MType  |        Payload Length (4B)        |
+--------+--------+--------+--------+--------+--------+--------+--------+
|                         Request ID (8 bytes)                          |
+--------+--------+--------+--------+--------+--------+--------+--------+
```

- **Magic**: `0x47 0x44` ("GD" for Gunicorn Dirty)
- **Version**: `0x01`
- **MType**: Message type (`0x01`=REQUEST, `0x02`=RESPONSE, `0x03`=ERROR, `0x04`=CHUNK, `0x05`=END)
- **Length**: Payload size (big-endian uint32, max 64MB)
- **Request ID**: uint64 identifier
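
As a rough illustration of the layout above, the header can be packed and unpacked with Python's `struct` module. This is a standalone sketch rather than the gunicorn implementation: the format string `">2sBBIQ"` and the constant names mirror the ones in this commit's `protocol.py`, while `pack_header` and `unpack_header` are hypothetical helper names and `ValueError` stands in for the project's own error class.

```python
import struct

# Big-endian, no padding: 2-byte magic, 1-byte version, 1-byte type,
# 4-byte payload length, 8-byte request ID = 16 bytes
HEADER_FORMAT = ">2sBBIQ"
HEADER_SIZE = struct.calcsize(HEADER_FORMAT)
MAGIC = b"GD"
VERSION = 0x01
MSG_TYPE_REQUEST = 0x01
MAX_MESSAGE_SIZE = 64 * 1024 * 1024


def pack_header(msg_type, request_id, payload_length):
    """Build the 16-byte header (hypothetical helper for illustration)."""
    return struct.pack(HEADER_FORMAT, MAGIC, VERSION, msg_type,
                       payload_length, request_id)


def unpack_header(data):
    """Validate and split a header into (msg_type, request_id, payload_length)."""
    if len(data) < HEADER_SIZE:
        raise ValueError("header too short")
    magic, version, msg_type, length, request_id = struct.unpack(
        HEADER_FORMAT, data[:HEADER_SIZE])
    if magic != MAGIC or version != VERSION:
        raise ValueError("bad magic or version")
    if length > MAX_MESSAGE_SIZE:
        raise ValueError("payload too large")
    return msg_type, request_id, length


header = pack_header(MSG_TYPE_REQUEST, 12345, 42)
print(HEADER_SIZE)            # 16
print(header.hex())           # 474401010000002a0000000000003039
print(unpack_header(header))  # (1, 12345, 42)
```

Reading the fixed-size header first tells the receiver exactly how many payload bytes to read next, so messages can be framed without scanning for delimiters.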
|
||||
### TLV Payload Encoding

Payloads use Type-Length-Value encoding:

| Type | Code | Description |
|------|------|-------------|
| None | `0x00` | No value bytes |
| Bool | `0x01` | 1 byte (0x00/0x01) |
| Int64 | `0x05` | 8 bytes big-endian signed |
| Float64 | `0x06` | 8 bytes IEEE 754 |
| Bytes | `0x10` | 4-byte length + raw bytes |
| String | `0x11` | 4-byte length + UTF-8 |
| List | `0x20` | 4-byte count + elements |
| Dict | `0x21` | 4-byte count + key-value pairs |
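
The table can be read as a recursive encoding: one type byte, then a body whose shape depends on the type. The sketch below is a minimal illustration of that idea, not the code in `tlv.py`. It assumes that length and count fields are big-endian uint32 and that dict keys are encoded as ordinary TLV values; neither detail is stated in the table. The function names `encode` and `decode` are hypothetical.

```python
import struct

# Type codes from the table above
T_NONE, T_BOOL, T_INT64, T_FLOAT64 = 0x00, 0x01, 0x05, 0x06
T_BYTES, T_STRING, T_LIST, T_DICT = 0x10, 0x11, 0x20, 0x21


def encode(value):
    """Encode a Python value as one type byte followed by its body."""
    if value is None:
        return bytes([T_NONE])
    if isinstance(value, bool):  # must precede int: bool is a subclass of int
        return bytes([T_BOOL, 1 if value else 0])
    if isinstance(value, int):
        return bytes([T_INT64]) + struct.pack(">q", value)
    if isinstance(value, float):
        return bytes([T_FLOAT64]) + struct.pack(">d", value)
    if isinstance(value, bytes):
        return bytes([T_BYTES]) + struct.pack(">I", len(value)) + value
    if isinstance(value, str):
        raw = value.encode("utf-8")
        return bytes([T_STRING]) + struct.pack(">I", len(raw)) + raw
    if isinstance(value, (list, tuple)):
        body = b"".join(encode(item) for item in value)
        return bytes([T_LIST]) + struct.pack(">I", len(value)) + body
    if isinstance(value, dict):
        # Assumption: each key is written as a full TLV value before its value
        body = b"".join(encode(k) + encode(v) for k, v in value.items())
        return bytes([T_DICT]) + struct.pack(">I", len(value)) + body
    raise TypeError(f"unsupported type: {type(value).__name__}")


def decode(data, pos=0):
    """Decode one value starting at pos; return (value, next_pos)."""
    tag = data[pos]
    pos += 1
    if tag == T_NONE:
        return None, pos
    if tag == T_BOOL:
        return data[pos] != 0, pos + 1
    if tag == T_INT64:
        return struct.unpack_from(">q", data, pos)[0], pos + 8
    if tag == T_FLOAT64:
        return struct.unpack_from(">d", data, pos)[0], pos + 8
    if tag in (T_BYTES, T_STRING, T_LIST, T_DICT):
        (count,) = struct.unpack_from(">I", data, pos)
        pos += 4
        if tag == T_BYTES:
            return data[pos:pos + count], pos + count
        if tag == T_STRING:
            return data[pos:pos + count].decode("utf-8"), pos + count
        if tag == T_LIST:
            items = []
            for _ in range(count):
                item, pos = decode(data, pos)
                items.append(item)
            return items, pos
        result = {}
        for _ in range(count):
            key, pos = decode(data, pos)
            result[key], pos = decode(data, pos)
        return result, pos
    raise ValueError(f"unknown type code: 0x{tag:02x}")


message = {"image": b"\x00\xff", "size": 256, "ok": True, "tags": ["a", None]}
encoded = encode(message)
decoded, end = decode(encoded)
print(decoded == message, end == len(encoded))  # True True
```

Because bytes have their own type code, binary values travel as-is. Note that in this sketch a tuple is encoded as a list and therefore decodes as a list.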
|
||||
### Binary Data Benefits

The binary protocol passes raw bytes directly, with no intermediate text encoding:

```python
import io

from PIL import Image

# Image processing with binary data (a method on the dirty app class)
def resize(self, image_data, width, height):
    """Resize an image - image_data is raw bytes."""
    img = Image.open(io.BytesIO(image_data))
    resized = img.resize((width, height))
    buffer = io.BytesIO()
    resized.save(buffer, format='PNG')
    return buffer.getvalue()  # Returns raw bytes

# Called from HTTP worker
thumbnail = client.execute(
    "myapp.images:ImageApp",
    "resize",  # action name matches the method defined above
    raw_image_bytes,  # No base64 encoding needed
    width=256,
    height=256,
)
```
|
||||
|
||||
### Error Handling in Streams
|
||||
|
||||
Errors during streaming are delivered as error messages:
|
||||
@ -768,7 +826,7 @@ except DirtyConnectionError:
|
||||
2. **Set appropriate timeouts** based on your workload
|
||||
3. **Handle errors gracefully** - dirty workers may restart
|
||||
4. **Use meaningful action names** for easier debugging
|
||||
-5. **Keep responses JSON-serializable** - results are passed via IPC
+5. **Keep responses serializable** - results are passed via binary IPC (supports bytes directly)
|
||||
|
||||
## Monitoring
|
||||
|
||||
|
||||
@ -15,7 +15,7 @@ services:
|
||||
context: ../.. # Gunicorn repo root
|
||||
dockerfile: examples/celery_alternative/Dockerfile
|
||||
ports:
|
||||
-- "8000:8000"
+- "8003:8000"
|
||||
environment:
|
||||
- GUNICORN_WORKERS=4
|
||||
- DIRTY_WORKERS=9
|
||||
|
||||
25
examples/dirty_example/Dockerfile
Normal file
@ -0,0 +1,25 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
FROM python:3.12-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy gunicorn source
|
||||
COPY . /app/gunicorn-src
|
||||
|
||||
# Install gunicorn and dependencies
|
||||
# setproctitle is needed for process title changes
|
||||
RUN pip install --no-cache-dir /app/gunicorn-src setproctitle
|
||||
|
||||
# Copy example files
|
||||
COPY examples/dirty_example/ /app/examples/dirty_example/
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Expose the port
|
||||
EXPOSE 8000
|
||||
|
||||
# Default command - run the example tests
|
||||
CMD ["python", "-m", "pytest", "-v", "examples/dirty_example/"]
|
||||
54
examples/dirty_example/docker-compose.yml
Normal file
@ -0,0 +1,54 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
services:
|
||||
# Run the example tests (protocol, dirty app, worker integration)
|
||||
tests:
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: examples/dirty_example/Dockerfile
|
||||
command: >
|
||||
bash -c "
|
||||
echo '=== Running Protocol Tests ===' &&
|
||||
python examples/dirty_example/test_protocol.py &&
|
||||
echo '' &&
|
||||
echo '=== Running Dirty App Tests ===' &&
|
||||
python examples/dirty_example/test_dirty_app.py &&
|
||||
echo '' &&
|
||||
echo '=== Running Worker Integration Tests ===' &&
|
||||
python examples/dirty_example/test_worker_integration.py &&
|
||||
echo '' &&
|
||||
echo '=== All tests passed! ==='
|
||||
"
|
||||
|
||||
# Run the full gunicorn server with dirty workers
|
||||
server:
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: examples/dirty_example/Dockerfile
|
||||
ports:
|
||||
- "8001:8000"
|
||||
environment:
|
||||
- GUNICORN_BIND=0.0.0.0:8000
|
||||
command: >
|
||||
gunicorn examples.dirty_example.wsgi_app:app
|
||||
-c examples/dirty_example/gunicorn_conf.py
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/')"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
start_period: 10s
|
||||
|
||||
# Run integration test against the server
|
||||
integration-test:
|
||||
build:
|
||||
context: ../..
|
||||
dockerfile: examples/dirty_example/Dockerfile
|
||||
depends_on:
|
||||
server:
|
||||
condition: service_healthy
|
||||
environment:
|
||||
- TEST_BASE_URL=http://server:8000
|
||||
command: python examples/dirty_example/test_integration.py
|
||||
@ -11,7 +11,9 @@ Run with:
|
||||
"""
|
||||
|
||||
# Basic settings
|
||||
-bind = "127.0.0.1:8000"
+# Use 0.0.0.0 for Docker, override with GUNICORN_BIND env var if needed
+import os
+bind = os.environ.get("GUNICORN_BIND", "127.0.0.1:8000")
|
||||
workers = 2
|
||||
worker_class = "sync"
|
||||
timeout = 30
|
||||
|
||||
81
examples/dirty_example/test_integration.py
Normal file
@ -0,0 +1,81 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Integration test for the dirty example server.
|
||||
|
||||
This tests that the full gunicorn server with dirty workers responds
|
||||
correctly to HTTP requests.
|
||||
|
||||
Run with:
|
||||
python examples/dirty_example/test_integration.py [base_url]
|
||||
|
||||
Default base_url is http://localhost:8000
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
|
||||
|
||||
def test_endpoint(base, path, expected_key=None):
|
||||
"""Test an endpoint and check for expected key in response."""
|
||||
url = base + path
|
||||
print(f"Testing: {url}")
|
||||
try:
|
||||
with urllib.request.urlopen(url, timeout=10) as resp:
|
||||
data = json.loads(resp.read())
|
||||
print(f" Response: {str(data)[:200]}")
|
||||
if expected_key and expected_key not in data:
|
||||
print(f" ERROR: Expected key '{expected_key}' not found!")
|
||||
return False
|
||||
return True
|
||||
except urllib.error.HTTPError as e:
|
||||
print(f" HTTP ERROR {e.code}: {e.reason}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f" ERROR: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
# Get base URL from env or command line
|
||||
base = os.environ.get("TEST_BASE_URL", "http://localhost:8000")
|
||||
if len(sys.argv) > 1:
|
||||
base = sys.argv[1]
|
||||
|
||||
print(f"Testing dirty example server at: {base}")
|
||||
print("=" * 60)
|
||||
|
||||
# Define tests: (path, expected_key_in_response)
|
||||
tests = [
|
||||
("/", "endpoints"),
|
||||
("/models", "models"),
|
||||
("/load?name=test-model", "status"),
|
||||
("/inference?model=default&data=hello", "prediction"),
|
||||
("/fibonacci?n=20", "result"),
|
||||
("/prime?n=17", "is_prime"),
|
||||
("/stats", "ml_app"),
|
||||
("/unload?name=test-model", "status"),
|
||||
]
|
||||
|
||||
failed = 0
|
||||
for path, key in tests:
|
||||
if not test_endpoint(base, path, key):
|
||||
failed += 1
|
||||
print()
|
||||
|
||||
print("=" * 60)
|
||||
if failed:
|
||||
print(f"FAILED: {failed} tests failed")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print("SUCCESS: All integration tests passed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -4,7 +4,10 @@
|
||||
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
-Test script to demonstrate the Dirty Protocol layer.
+Test script to demonstrate the Dirty Binary Protocol layer.
|
||||
|
||||
The binary protocol uses a 16-byte header + TLV-encoded payloads for efficient
|
||||
binary data transfer without base64 encoding overhead.
|
||||
|
||||
Run with:
|
||||
python examples/dirty_example/test_protocol.py
|
||||
@ -18,10 +21,14 @@ import socket
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
|
||||
from gunicorn.dirty.protocol import (
|
||||
BinaryProtocol,
|
||||
DirtyProtocol,
|
||||
make_request,
|
||||
make_response,
|
||||
make_error_response,
|
||||
HEADER_SIZE,
|
||||
MAGIC,
|
||||
VERSION,
|
||||
)
|
||||
from gunicorn.dirty.errors import DirtyError, DirtyTimeoutError
|
||||
|
||||
@ -29,13 +36,13 @@ from gunicorn.dirty.errors import DirtyError, DirtyTimeoutError
|
||||
def test_protocol_encode_decode():
|
||||
"""Test protocol encoding and decoding."""
|
||||
print("=" * 60)
|
||||
print("Testing Protocol Encode/Decode")
|
||||
print("Testing Binary Protocol Encode/Decode")
|
||||
print("=" * 60)
|
||||
|
||||
# Test request
|
||||
# Test request with integer ID (recommended for binary protocol)
|
||||
print("\n1. Creating a request message...")
|
||||
request = make_request(
|
||||
request_id="req-001",
|
||||
request_id=12345, # Integer IDs are efficient
|
||||
app_path="myapp.ml:MLApp",
|
||||
action="inference",
|
||||
args=("model1",),
|
||||
@ -43,18 +50,53 @@ def test_protocol_encode_decode():
|
||||
)
|
||||
print(f" Request: {request}")
|
||||
|
||||
# Encode
|
||||
print("\n2. Encoding message...")
|
||||
encoded = DirtyProtocol.encode(request)
|
||||
# Encode using binary protocol
|
||||
print("\n2. Encoding message with binary protocol...")
|
||||
encoded = BinaryProtocol._encode_from_dict(request)
|
||||
print(f" Encoded length: {len(encoded)} bytes")
|
||||
print(f" Header (4 bytes): {encoded[:4].hex()}")
|
||||
print(f" Header ({HEADER_SIZE} bytes): {encoded[:HEADER_SIZE].hex()}")
|
||||
print(f" Magic: {MAGIC!r}")
|
||||
print(f" Version: {VERSION}")
|
||||
|
||||
# Decode
|
||||
print("\n3. Decoding payload...")
|
||||
payload = encoded[DirtyProtocol.HEADER_SIZE:]
|
||||
decoded = DirtyProtocol.decode(payload)
|
||||
print(f" Decoded: {decoded}")
|
||||
print(f" Match: {decoded == request}")
|
||||
# Decode header
|
||||
print("\n3. Decoding header...")
|
||||
msg_type, request_id, payload_len = BinaryProtocol.decode_header(encoded[:HEADER_SIZE])
|
||||
print(f" Message type: {msg_type} (0x{msg_type:02x})")
|
||||
print(f" Request ID: {request_id}")
|
||||
print(f" Payload length: {payload_len} bytes")
|
||||
|
||||
# Decode full message
|
||||
print("\n4. Decoding full message...")
|
||||
msg_type_str, req_id, payload = BinaryProtocol.decode_message(encoded)
|
||||
print(f" Type: {msg_type_str}")
|
||||
print(f" Request ID: {req_id}")
|
||||
print(f" Payload: {payload}")
|
||||
|
||||
|
||||
def test_binary_data_handling():
|
||||
"""Test binary data handling - the main advantage of binary protocol."""
|
||||
print("\n" + "=" * 60)
|
||||
print("Testing Binary Data Handling")
|
||||
print("=" * 60)
|
||||
|
||||
# Create binary data (e.g., image, audio, model weights)
|
||||
binary_data = bytes(range(256)) # All byte values
|
||||
print(f"\n1. Original binary data: {len(binary_data)} bytes")
|
||||
print(f" First 16 bytes: {binary_data[:16].hex()}")
|
||||
|
||||
# Create response with binary data (no base64 needed!)
|
||||
print("\n2. Encoding binary data in response...")
|
||||
response = make_response(67890, {"image_data": binary_data, "format": "raw"})
|
||||
encoded = BinaryProtocol._encode_from_dict(response)
|
||||
print(f" Encoded total size: {len(encoded)} bytes")
|
||||
|
||||
# Decode and verify
|
||||
print("\n3. Decoding binary data...")
|
||||
msg_type_str, req_id, payload = BinaryProtocol.decode_message(encoded)
|
||||
recovered_data = payload["result"]["image_data"]
|
||||
print(f" Recovered data size: {len(recovered_data)} bytes")
|
||||
print(f" Data matches: {recovered_data == binary_data}")
|
||||
print(f" First 16 bytes: {recovered_data[:16].hex()}")
|
||||
|
||||
|
||||
def test_protocol_response():
|
||||
@ -65,13 +107,13 @@ def test_protocol_response():
|
||||
|
||||
# Success response
|
||||
print("\n1. Creating success response...")
|
||||
response = make_response("req-001", {"result": "Hello, World!", "confidence": 0.95})
|
||||
response = make_response(12345, {"result": "Hello, World!", "confidence": 0.95})
|
||||
print(f" Response: {response}")
|
||||
|
||||
# Error response
|
||||
print("\n2. Creating error response...")
|
||||
error = DirtyTimeoutError("Operation timed out", timeout=30)
|
||||
error_response = make_error_response("req-001", error)
|
||||
error_response = make_error_response(12345, error)
|
||||
print(f" Error response: {error_response}")
|
||||
|
||||
|
||||
@ -88,7 +130,7 @@ def test_socket_communication():
|
||||
# Send a request
|
||||
print("\n1. Sending request over socket...")
|
||||
request = make_request(
|
||||
request_id="socket-req-001",
|
||||
request_id=100001,
|
||||
app_path="test:App",
|
||||
action="compute",
|
||||
args=(1, 2, 3),
|
||||
@ -101,19 +143,20 @@ def test_socket_communication():
|
||||
print("\n2. Receiving request...")
|
||||
received = DirtyProtocol.read_message(server_sock)
|
||||
print(f" Received: {received}")
|
||||
print(f" Match: {received == request}")
|
||||
print(f" Request ID: {received['id']}")
|
||||
|
||||
# Send a response
|
||||
print("\n3. Sending response...")
|
||||
response = make_response("socket-req-001", {"sum": 6})
|
||||
# Send a response with binary data
|
||||
print("\n3. Sending response with binary data...")
|
||||
binary_result = b"\x00\x01\x02\x03\xff\xfe\xfd\xfc"
|
||||
response = make_response(100001, {"data": binary_result, "sum": 6})
|
||||
DirtyProtocol.write_message(server_sock, response)
|
||||
print(f" Sent: {response}")
|
||||
print(f" Sent binary data: {binary_result.hex()}")
|
||||
|
||||
# Receive the response
|
||||
print("\n4. Receiving response...")
|
||||
received = DirtyProtocol.read_message(client_sock)
|
||||
print(f" Received: {received}")
|
||||
print(f" Match: {received == response}")
|
||||
print(f" Received binary data: {received['result']['data'].hex()}")
|
||||
print(f" Sum: {received['result']['sum']}")
|
||||
|
||||
finally:
|
||||
server_sock.close()
|
||||
@ -132,7 +175,7 @@ async def test_async_communication():
|
||||
try:
|
||||
# Create message
|
||||
request = make_request(
|
||||
request_id="async-req-001",
|
||||
request_id=200001,
|
||||
app_path="async:App",
|
||||
action="process",
|
||||
args=("data",),
|
||||
@ -141,7 +184,7 @@ async def test_async_communication():
|
||||
|
||||
# Write to pipe
|
||||
print("\n1. Writing async message...")
|
||||
encoded = DirtyProtocol.encode(request)
|
||||
encoded = BinaryProtocol._encode_from_dict(request)
|
||||
os.write(write_fd, encoded)
|
||||
os.close(write_fd)
|
||||
write_fd = None
|
||||
@ -156,7 +199,7 @@ async def test_async_communication():
|
||||
|
||||
received = await DirtyProtocol.read_message_async(reader)
|
||||
print(f" Received: {received}")
|
||||
print(f" Match: {received == request}")
|
||||
print(f" Request ID: {received['id']}")
|
||||
|
||||
finally:
|
||||
if write_fd is not None:
|
||||
@ -193,10 +236,11 @@ def test_error_serialization():
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("\n" + "#" * 60)
|
||||
print("# Dirty Protocol Demonstration")
|
||||
print("# Dirty Binary Protocol Demonstration")
|
||||
print("#" * 60)
|
||||
|
||||
test_protocol_encode_decode()
|
||||
test_binary_data_handling()
|
||||
test_protocol_response()
|
||||
test_socket_communication()
|
||||
asyncio.run(test_async_communication())
|
||||
|
||||
@ -22,7 +22,7 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspa
|
||||
|
||||
from gunicorn.config import Config
|
||||
from gunicorn.dirty.worker import DirtyWorker
|
||||
from gunicorn.dirty.protocol import DirtyProtocol, make_request
|
||||
from gunicorn.dirty.protocol import DirtyProtocol, BinaryProtocol, make_request, HEADER_SIZE
|
||||
|
||||
|
||||
class MockLog:
|
||||
@ -35,6 +35,36 @@ class MockLog:
|
||||
def reopen_files(self): pass
|
||||
|
||||
|
||||
class MockWriter:
|
||||
"""Mock StreamWriter that captures written responses."""
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
self._buffer = b""
|
||||
|
||||
def write(self, data):
|
||||
self._buffer += data
|
||||
|
||||
async def drain(self):
|
||||
# Decode messages from buffer using binary protocol
|
||||
while len(self._buffer) >= HEADER_SIZE:
|
||||
_, _, length = BinaryProtocol.decode_header(self._buffer[:HEADER_SIZE])
|
||||
total_size = HEADER_SIZE + length
|
||||
if len(self._buffer) >= total_size:
|
||||
msg_data = self._buffer[:total_size]
|
||||
self._buffer = self._buffer[total_size:]
|
||||
msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data)
|
||||
result = {"type": msg_type_str, "id": request_id}
|
||||
result.update(payload_dict)
|
||||
self.messages.append(result)
|
||||
else:
|
||||
break
|
||||
|
||||
def get_last_response(self):
|
||||
"""Get the last response message."""
|
||||
return self.messages[-1] if self.messages else None
|
||||
|
||||
|
||||
async def test_worker_request_handling():
|
||||
"""Test that a worker can load apps and handle requests."""
|
||||
print("=" * 60)
|
||||
@ -75,52 +105,60 @@ async def test_worker_request_handling():
|
||||
# Test handle_request with a proper request message
|
||||
print("\n3. Testing handle_request() - load_model...")
|
||||
request = make_request(
|
||||
request_id="test-001",
|
||||
request_id=1001,
|
||||
app_path="examples.dirty_example.dirty_app:MLApp",
|
||||
action="load_model",
|
||||
args=("gpt-4",),
|
||||
kwargs={}
|
||||
)
|
||||
response = await worker.handle_request(request)
|
||||
writer = MockWriter()
|
||||
await worker.handle_request(request, writer)
|
||||
response = writer.get_last_response()
|
||||
print(f" Response type: {response['type']}")
|
||||
print(f" Result: {response.get('result', response.get('error'))}")
|
||||
|
||||
# Test inference
|
||||
print("\n4. Testing handle_request() - inference...")
|
||||
request = make_request(
|
||||
request_id="test-002",
|
||||
request_id=1002,
|
||||
app_path="examples.dirty_example.dirty_app:MLApp",
|
||||
action="inference",
|
||||
args=("default", "Hello AI!"),
|
||||
kwargs={}
|
||||
)
|
||||
response = await worker.handle_request(request)
|
||||
writer = MockWriter()
|
||||
await worker.handle_request(request, writer)
|
||||
response = writer.get_last_response()
|
||||
print(f" Response type: {response['type']}")
|
||||
print(f" Result: {response.get('result', response.get('error'))}")
|
||||
|
||||
# Test error handling
|
||||
print("\n5. Testing error handling - unknown action...")
|
||||
request = make_request(
|
||||
request_id="test-003",
|
||||
request_id=1003,
|
||||
app_path="examples.dirty_example.dirty_app:MLApp",
|
||||
action="nonexistent_action",
|
||||
args=(),
|
||||
kwargs={}
|
||||
)
|
||||
response = await worker.handle_request(request)
|
||||
writer = MockWriter()
|
||||
await worker.handle_request(request, writer)
|
||||
response = writer.get_last_response()
|
||||
print(f" Response type: {response['type']}")
|
||||
print(f" Error: {response.get('error', {}).get('message')}")
|
||||
|
||||
# Test app not found
|
||||
print("\n6. Testing error handling - app not found...")
|
||||
request = make_request(
|
||||
request_id="test-004",
|
||||
request_id=1004,
|
||||
app_path="nonexistent:App",
|
||||
action="test",
|
||||
args=(),
|
||||
kwargs={}
|
||||
)
|
||||
response = await worker.handle_request(request)
|
||||
writer = MockWriter()
|
||||
await worker.handle_request(request, writer)
|
||||
response = writer.get_last_response()
|
||||
print(f" Response type: {response['type']}")
|
||||
print(f" Error type: {response.get('error', {}).get('error_type')}")
|
||||
|
||||
|
||||
@ -3,89 +3,304 @@
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
Dirty Arbiters Protocol
|
||||
Dirty Worker Binary Protocol
|
||||
|
||||
Length-prefixed JSON message framing over Unix sockets.
|
||||
Provides both async (primary) and sync (for HTTP workers) APIs.
|
||||
Binary message framing over Unix sockets, inspired by OpenBSD msgctl/msgsnd.
|
||||
Replaces JSON protocol for efficient binary data transfer.
|
||||
|
||||
Message Format:
|
||||
+----------------+------------------+
|
||||
| 4-byte length | JSON payload |
|
||||
+----------------+------------------+
|
||||
Header Format (16 bytes):
|
||||
+--------+--------+--------+--------+--------+--------+--------+--------+
|
||||
| Magic (2B) | Ver(1) | MType | Payload Length (4B) |
|
||||
+--------+--------+--------+--------+--------+--------+--------+--------+
|
||||
| Request ID (8 bytes) |
|
||||
+--------+--------+--------+--------+--------+--------+--------+--------+
|
||||
|
||||
The length field is a 4-byte unsigned integer in network byte order (big-endian).
|
||||
- Magic: 0x47 0x44 ("GD" for Gunicorn Dirty)
|
||||
- Version: 0x01
|
||||
- MType: Message type (REQUEST, RESPONSE, ERROR, CHUNK, END)
|
||||
- Length: Payload size (big-endian uint32, max 64MB)
|
||||
- Request ID: uint64 (replaces UUID string)
|
||||
|
||||
Payload is TLV-encoded (see tlv.py).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import struct
|
||||
import socket
|
||||
import struct
|
||||
|
||||
from .errors import DirtyProtocolError
|
||||
from .tlv import TLVEncoder
|
||||
|
||||
|
||||
class DirtyProtocol:
|
||||
"""Length-prefixed JSON messages over Unix sockets."""
|
||||
# Protocol constants
|
||||
MAGIC = b"GD" # 0x47 0x44
|
||||
VERSION = 0x01
|
||||
|
||||
# 4-byte unsigned int, network byte order (big-endian)
|
||||
HEADER_FORMAT = "!I"
|
||||
HEADER_SIZE = struct.calcsize(HEADER_FORMAT)
|
||||
# Message types (1 byte)
|
||||
MSG_TYPE_REQUEST = 0x01
|
||||
MSG_TYPE_RESPONSE = 0x02
|
||||
MSG_TYPE_ERROR = 0x03
|
||||
MSG_TYPE_CHUNK = 0x04
|
||||
MSG_TYPE_END = 0x05
|
||||
|
||||
# Maximum message size (64 MB)
|
||||
MAX_MESSAGE_SIZE = 64 * 1024 * 1024
|
||||
# Message type names (for backwards compatibility with old API)
|
||||
MSG_TYPE_REQUEST_STR = "request"
|
||||
MSG_TYPE_RESPONSE_STR = "response"
|
||||
MSG_TYPE_ERROR_STR = "error"
|
||||
MSG_TYPE_CHUNK_STR = "chunk"
|
||||
MSG_TYPE_END_STR = "end"
|
||||
|
||||
# Message types for future streaming support
|
||||
MSG_TYPE_REQUEST = "request"
|
||||
MSG_TYPE_RESPONSE = "response"
|
||||
MSG_TYPE_ERROR = "error"
|
||||
MSG_TYPE_CHUNK = "chunk"
|
||||
MSG_TYPE_END = "end"
|
||||
# Map int types to string names
|
||||
MSG_TYPE_TO_STR = {
|
||||
MSG_TYPE_REQUEST: MSG_TYPE_REQUEST_STR,
|
||||
MSG_TYPE_RESPONSE: MSG_TYPE_RESPONSE_STR,
|
||||
MSG_TYPE_ERROR: MSG_TYPE_ERROR_STR,
|
||||
MSG_TYPE_CHUNK: MSG_TYPE_CHUNK_STR,
|
||||
MSG_TYPE_END: MSG_TYPE_END_STR,
|
||||
}
|
||||
|
||||
# Map string names to int types
|
||||
MSG_TYPE_FROM_STR = {v: k for k, v in MSG_TYPE_TO_STR.items()}
|
||||
|
||||
# Header format: Magic (2) + Version (1) + Type (1) + Length (4) + RequestID (8) = 16
|
||||
HEADER_FORMAT = ">2sBBIQ"
|
||||
HEADER_SIZE = struct.calcsize(HEADER_FORMAT)
|
||||
|
||||
# Maximum message size (64 MB)
|
||||
MAX_MESSAGE_SIZE = 64 * 1024 * 1024
|
||||
|
||||
|
||||
class BinaryProtocol:
|
||||
"""Binary message protocol for dirty worker IPC."""
|
||||
|
||||
# Export constants for external use
|
||||
HEADER_SIZE = HEADER_SIZE
|
||||
MAX_MESSAGE_SIZE = MAX_MESSAGE_SIZE
|
||||
|
||||
MSG_TYPE_REQUEST = MSG_TYPE_REQUEST_STR
|
||||
MSG_TYPE_RESPONSE = MSG_TYPE_RESPONSE_STR
|
||||
MSG_TYPE_ERROR = MSG_TYPE_ERROR_STR
|
||||
MSG_TYPE_CHUNK = MSG_TYPE_CHUNK_STR
|
||||
MSG_TYPE_END = MSG_TYPE_END_STR
|
||||
|
||||
@staticmethod
|
||||
def encode(message: dict) -> bytes:
|
||||
def encode_header(msg_type: int, request_id: int, payload_length: int) -> bytes:
|
||||
"""
|
||||
Encode a message dict to length-prefixed bytes.
|
||||
Encode the 16-byte message header.
|
||||
|
||||
Args:
|
||||
message: Dictionary to encode as JSON
|
||||
msg_type: Message type (MSG_TYPE_REQUEST, etc.)
|
||||
request_id: Unique request identifier (uint64)
|
||||
payload_length: Length of the TLV-encoded payload
|
||||
|
||||
Returns:
|
||||
bytes: Length-prefixed encoded message
|
||||
bytes: 16-byte header
|
||||
"""
|
||||
return struct.pack(HEADER_FORMAT, MAGIC, VERSION, msg_type,
|
||||
payload_length, request_id)
|
||||
|
||||
@staticmethod
|
||||
def decode_header(data: bytes) -> tuple:
|
||||
"""
|
||||
Decode the 16-byte message header.
|
||||
|
||||
Args:
|
||||
data: 16 bytes of header data
|
||||
|
||||
Returns:
|
||||
tuple: (msg_type, request_id, payload_length)
|
||||
|
||||
Raises:
|
||||
DirtyProtocolError: If encoding fails
|
||||
DirtyProtocolError: If header is invalid
|
||||
"""
|
||||
try:
|
||||
payload = json.dumps(message).encode("utf-8")
|
||||
if len(payload) > DirtyProtocol.MAX_MESSAGE_SIZE:
|
||||
if len(data) < HEADER_SIZE:
|
||||
raise DirtyProtocolError(
|
||||
f"Header too short: {len(data)} bytes, expected {HEADER_SIZE}",
|
||||
raw_data=data
|
||||
)
|
||||
|
||||
magic, version, msg_type, length, request_id = struct.unpack(
|
||||
HEADER_FORMAT, data[:HEADER_SIZE]
|
||||
)
|
||||
|
||||
if magic != MAGIC:
|
||||
raise DirtyProtocolError(
|
||||
f"Invalid magic: {magic!r}, expected {MAGIC!r}",
|
||||
raw_data=data[:20]
|
||||
)
|
||||
|
||||
if version != VERSION:
|
||||
raise DirtyProtocolError(
|
||||
f"Unsupported protocol version: {version}, expected {VERSION}",
|
||||
raw_data=data[:20]
|
||||
)
|
||||
|
||||
if msg_type not in MSG_TYPE_TO_STR:
|
||||
raise DirtyProtocolError(
|
||||
f"Unknown message type: 0x{msg_type:02x}",
|
||||
raw_data=data[:20]
|
||||
)
|
||||
|
||||
if length > MAX_MESSAGE_SIZE:
|
||||
raise DirtyProtocolError(
|
||||
f"Message too large: {length} bytes (max: {MAX_MESSAGE_SIZE})"
|
||||
)
|
||||
|
||||
return msg_type, request_id, length
|
||||
|
||||
@staticmethod
|
||||
def encode_request(request_id: int, app_path: str, action: str,
|
||||
args: tuple = None, kwargs: dict = None) -> bytes:
|
||||
"""
|
||||
Encode a request message.
|
||||
|
||||
Args:
|
||||
request_id: Unique request identifier (uint64)
|
||||
app_path: Import path of the dirty app
|
||||
action: Action to call on the app
|
||||
args: Positional arguments
|
||||
kwargs: Keyword arguments
|
||||
|
||||
Returns:
|
||||
bytes: Complete message (header + payload)
|
||||
"""
|
||||
payload_dict = {
|
||||
"app_path": app_path,
|
||||
"action": action,
|
||||
"args": list(args) if args else [],
|
||||
"kwargs": kwargs or {},
|
||||
}
|
||||
payload = TLVEncoder.encode(payload_dict)
|
||||
header = BinaryProtocol.encode_header(MSG_TYPE_REQUEST, request_id,
|
||||
len(payload))
|
||||
return header + payload
|
||||
|
||||
@staticmethod
|
||||
def encode_response(request_id: int, result) -> bytes:
|
||||
"""
|
||||
Encode a success response message.
|
||||
|
||||
Args:
|
||||
request_id: Request identifier this responds to
|
||||
result: Result value (must be TLV-serializable)
|
||||
|
||||
Returns:
|
||||
bytes: Complete message (header + payload)
|
||||
"""
|
||||
payload_dict = {"result": result}
|
||||
payload = TLVEncoder.encode(payload_dict)
|
||||
header = BinaryProtocol.encode_header(MSG_TYPE_RESPONSE, request_id,
|
||||
len(payload))
|
||||
return header + payload
|
||||
|
||||
@staticmethod
|
||||
def encode_error(request_id: int, error) -> bytes:
|
||||
"""
|
||||
Encode an error response message.
|
||||
|
||||
Args:
|
||||
request_id: Request identifier this responds to
|
||||
error: DirtyError instance, dict, or Exception
|
||||
|
||||
Returns:
|
||||
bytes: Complete message (header + payload)
|
||||
"""
|
||||
from .errors import DirtyError
|
||||
|
||||
if isinstance(error, DirtyError):
|
||||
error_dict = error.to_dict()
|
||||
elif isinstance(error, dict):
|
||||
error_dict = error
|
||||
else:
|
||||
error_dict = {
|
||||
"error_type": type(error).__name__,
|
||||
"message": str(error),
|
||||
"details": {},
|
||||
}
|
||||
|
||||
payload_dict = {"error": error_dict}
|
||||
payload = TLVEncoder.encode(payload_dict)
|
||||
header = BinaryProtocol.encode_header(MSG_TYPE_ERROR, request_id,
|
||||
len(payload))
|
||||
return header + payload
|
||||
|
||||
@staticmethod
|
||||
def encode_chunk(request_id: int, data) -> bytes:
|
||||
"""
|
||||
Encode a chunk message for streaming responses.
|
||||
|
||||
Args:
|
||||
request_id: Request identifier this chunk belongs to
|
||||
data: Chunk data (must be TLV-serializable)
|
||||
|
||||
Returns:
|
||||
bytes: Complete message (header + payload)
|
||||
"""
|
||||
payload_dict = {"data": data}
|
||||
payload = TLVEncoder.encode(payload_dict)
|
||||
header = BinaryProtocol.encode_header(MSG_TYPE_CHUNK, request_id,
|
||||
len(payload))
|
||||
return header + payload
|
||||
|
||||
@staticmethod
|
||||
def encode_end(request_id: int) -> bytes:
|
||||
"""
|
||||
Encode an end-of-stream message.
|
||||
|
||||
Args:
|
||||
request_id: Request identifier this ends
|
||||
|
||||
Returns:
|
||||
bytes: Complete message (header + empty payload)
|
||||
"""
|
||||
# End message has empty payload
|
||||
header = BinaryProtocol.encode_header(MSG_TYPE_END, request_id, 0)
|
||||
return header
|
||||
|
||||
@staticmethod
|
||||
def decode_message(data: bytes) -> tuple:
|
||||
"""
|
||||
Decode a complete message (header + payload).
|
||||
|
||||
Args:
|
||||
data: Complete message bytes
|
||||
|
||||
Returns:
|
||||
tuple: (msg_type_str, request_id, payload_dict)
|
||||
msg_type_str is the string name (e.g., "request")
|
||||
payload_dict is the decoded TLV payload as a dict
|
||||
|
||||
Raises:
|
||||
DirtyProtocolError: If message is malformed
|
||||
"""
|
||||
msg_type, request_id, length = BinaryProtocol.decode_header(data)
|
||||
|
||||
if len(data) < HEADER_SIZE + length:
|
||||
raise DirtyProtocolError(
|
||||
f"Incomplete message: expected {HEADER_SIZE + length} bytes, "
|
||||
f"got {len(data)}",
|
||||
raw_data=data[:50]
|
||||
)
|
||||
|
||||
if length == 0:
|
||||
# End message has empty payload
|
||||
payload_dict = {}
|
||||
else:
|
||||
payload_data = data[HEADER_SIZE:HEADER_SIZE + length]
|
||||
try:
|
||||
payload_dict = TLVEncoder.decode_full(payload_data)
|
||||
except DirtyProtocolError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise DirtyProtocolError(
|
||||
f"Message too large: {len(payload)} bytes "
|
||||
f"(max: {DirtyProtocol.MAX_MESSAGE_SIZE})"
|
||||
f"Failed to decode TLV payload: {e}",
|
||||
raw_data=payload_data[:50]
|
||||
)
|
||||
length = struct.pack(DirtyProtocol.HEADER_FORMAT, len(payload))
|
||||
return length + payload
|
||||
except (TypeError, ValueError) as e:
|
||||
raise DirtyProtocolError(f"Failed to encode message: {e}")
|
||||
|
||||
@staticmethod
|
||||
def decode(data: bytes) -> dict:
|
||||
"""
|
||||
Decode bytes (without length prefix) to message dict.
|
||||
# Convert to dict format similar to old JSON protocol
|
||||
msg_type_str = MSG_TYPE_TO_STR[msg_type]
|
||||
|
||||
Args:
|
||||
data: JSON bytes to decode
|
||||
|
||||
Returns:
|
||||
dict: Decoded message
|
||||
|
||||
Raises:
|
||||
DirtyProtocolError: If decoding fails
|
||||
"""
|
||||
try:
|
||||
return json.loads(data.decode("utf-8"))
|
||||
except (json.JSONDecodeError, UnicodeDecodeError) as e:
|
||||
raise DirtyProtocolError(f"Failed to decode message: {e}",
|
||||
raw_data=data)
|
||||
return msg_type_str, request_id, payload_dict
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Async API (primary - for DirtyArbiter and DirtyWorker)
|
||||
@ -94,53 +309,62 @@ class DirtyProtocol:
|
||||
@staticmethod
|
||||
async def read_message_async(reader: asyncio.StreamReader) -> dict:
|
||||
"""
|
||||
Read a complete message from async stream.
|
||||
Read a complete binary message from async stream.
|
||||
|
||||
Args:
|
||||
reader: asyncio StreamReader
|
||||
|
||||
Returns:
|
||||
dict: Decoded message
|
||||
dict: Message dict with 'type', 'id', and payload fields
|
||||
|
||||
Raises:
|
||||
DirtyProtocolError: If read fails or message is malformed
|
||||
asyncio.IncompleteReadError: If connection closed mid-read
|
||||
"""
|
||||
# Read length header
|
||||
# Read header
|
||||
try:
|
||||
header = await reader.readexactly(DirtyProtocol.HEADER_SIZE)
|
||||
header = await reader.readexactly(HEADER_SIZE)
|
||||
except asyncio.IncompleteReadError as e:
|
||||
if len(e.partial) == 0:
|
||||
# Clean close - no data was read
|
||||
raise
|
||||
raise DirtyProtocolError(
|
||||
f"Incomplete header: got {len(e.partial)} bytes, "
|
||||
f"expected {DirtyProtocol.HEADER_SIZE}",
|
||||
f"expected {HEADER_SIZE}",
|
||||
raw_data=e.partial
|
||||
)
|
||||
|
||||
length = struct.unpack(DirtyProtocol.HEADER_FORMAT, header)[0]
|
||||
|
||||
if length > DirtyProtocol.MAX_MESSAGE_SIZE:
|
||||
raise DirtyProtocolError(
|
||||
f"Message too large: {length} bytes "
|
||||
f"(max: {DirtyProtocol.MAX_MESSAGE_SIZE})"
|
||||
)
|
||||
|
||||
if length == 0:
|
||||
raise DirtyProtocolError("Empty message received")
|
||||
msg_type, request_id, length = BinaryProtocol.decode_header(header)
|
||||
|
||||
# Read payload
|
||||
try:
|
||||
payload = await reader.readexactly(length)
|
||||
except asyncio.IncompleteReadError as e:
|
||||
raise DirtyProtocolError(
|
||||
f"Incomplete message: got {len(e.partial)} bytes, "
|
||||
f"expected {length}",
|
||||
raw_data=e.partial
|
||||
)
|
||||
if length > 0:
|
||||
try:
|
||||
payload_data = await reader.readexactly(length)
|
||||
except asyncio.IncompleteReadError as e:
|
||||
raise DirtyProtocolError(
|
||||
f"Incomplete payload: got {len(e.partial)} bytes, "
|
||||
f"expected {length}",
|
||||
raw_data=e.partial
|
||||
)
|
||||
|
||||
return DirtyProtocol.decode(payload)
|
||||
try:
|
||||
payload_dict = TLVEncoder.decode_full(payload_data)
|
||||
except DirtyProtocolError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise DirtyProtocolError(
|
||||
f"Failed to decode TLV payload: {e}",
|
||||
raw_data=payload_data[:50]
|
||||
)
|
||||
else:
|
||||
payload_dict = {}
|
||||
|
||||
# Build response dict
|
||||
msg_type_str = MSG_TYPE_TO_STR[msg_type]
|
||||
result = {"type": msg_type_str, "id": request_id}
|
||||
result.update(payload_dict)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
async def write_message_async(writer: asyncio.StreamWriter,
|
||||
@ -148,15 +372,17 @@ class DirtyProtocol:
|
||||
"""
|
||||
Write a message to async stream.
|
||||
|
||||
Accepts dict format for backwards compatibility.
|
||||
|
||||
Args:
|
||||
writer: asyncio StreamWriter
|
||||
message: Dictionary to send
|
||||
message: Message dict with 'type', 'id', and payload fields
|
||||
|
||||
Raises:
|
||||
DirtyProtocolError: If encoding fails
|
||||
ConnectionError: If write fails
|
||||
"""
|
||||
data = DirtyProtocol.encode(message)
|
||||
data = BinaryProtocol._encode_from_dict(message)
|
||||
writer.write(data)
|
||||
await writer.drain()
|
||||
|
||||
@ -201,27 +427,36 @@ class DirtyProtocol:
|
||||
sock: Socket to read from
|
||||
|
||||
Returns:
|
||||
dict: Decoded message
|
||||
dict: Message dict with 'type', 'id', and payload fields
|
||||
|
||||
Raises:
|
||||
DirtyProtocolError: If read fails or message is malformed
|
||||
"""
|
||||
# Read length header
|
||||
header = DirtyProtocol._recv_exactly(sock, DirtyProtocol.HEADER_SIZE)
|
||||
length = struct.unpack(DirtyProtocol.HEADER_FORMAT, header)[0]
|
||||
|
||||
if length > DirtyProtocol.MAX_MESSAGE_SIZE:
|
||||
raise DirtyProtocolError(
|
||||
f"Message too large: {length} bytes "
|
||||
f"(max: {DirtyProtocol.MAX_MESSAGE_SIZE})"
|
||||
)
|
||||
|
||||
if length == 0:
|
||||
raise DirtyProtocolError("Empty message received")
|
||||
# Read header
|
||||
header = BinaryProtocol._recv_exactly(sock, HEADER_SIZE)
|
||||
msg_type, request_id, length = BinaryProtocol.decode_header(header)
|
||||
|
||||
# Read payload
|
||||
payload = DirtyProtocol._recv_exactly(sock, length)
|
||||
return DirtyProtocol.decode(payload)
|
||||
if length > 0:
|
||||
payload_data = BinaryProtocol._recv_exactly(sock, length)
|
||||
try:
|
||||
payload_dict = TLVEncoder.decode_full(payload_data)
|
||||
except DirtyProtocolError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise DirtyProtocolError(
|
||||
f"Failed to decode TLV payload: {e}",
|
||||
raw_data=payload_data[:50]
|
||||
)
|
||||
else:
|
||||
payload_dict = {}
|
||||
|
||||
# Build response dict
|
||||
msg_type_str = MSG_TYPE_TO_STR[msg_type]
|
||||
result = {"type": msg_type_str, "id": request_id}
|
||||
result.update(payload_dict)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def write_message(sock: socket.socket, message: dict) -> None:
|
||||
@ -230,31 +465,92 @@ class DirtyProtocol:
|
||||
|
||||
Args:
|
||||
sock: Socket to write to
|
||||
message: Dictionary to send
|
||||
message: Message dict with 'type', 'id', and payload fields
|
||||
|
||||
Raises:
|
||||
DirtyProtocolError: If encoding fails
|
||||
OSError: If write fails
|
||||
"""
|
||||
data = DirtyProtocol.encode(message)
|
||||
data = BinaryProtocol._encode_from_dict(message)
|
||||
sock.sendall(data)
|
||||
|
||||
@staticmethod
|
||||
def _encode_from_dict(message: dict) -> bytes:
|
||||
"""
|
||||
Encode a message dict to binary format.
|
||||
|
||||
# Message builder helpers
|
||||
def make_request(request_id: str, app_path: str, action: str,
|
||||
Supports the old dict-based API for backwards compatibility.
|
||||
|
||||
Args:
|
||||
message: Message dict with 'type', 'id', and payload fields
|
||||
|
||||
Returns:
|
||||
bytes: Complete encoded message
|
||||
"""
|
||||
msg_type_str = message.get("type")
|
||||
request_id = message.get("id", 0)
|
||||
|
||||
# Handle string or int request IDs
|
||||
if isinstance(request_id, str):
|
||||
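# For backwards compat with UUID strings, hash to int.
# Note: str hashes are salted per interpreter process (see PYTHONHASHSEED),
# so the resulting integer is only stable within the sending process.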
|
||||
request_id = hash(request_id) & 0xFFFFFFFFFFFFFFFF
|
||||
|
||||
msg_type = MSG_TYPE_FROM_STR.get(msg_type_str)
|
||||
if msg_type is None:
|
||||
raise DirtyProtocolError(f"Unknown message type: {msg_type_str}")
|
||||
|
||||
if msg_type == MSG_TYPE_REQUEST:
|
||||
return BinaryProtocol.encode_request(
|
||||
request_id,
|
||||
message.get("app_path", ""),
|
||||
message.get("action", ""),
|
||||
message.get("args"),
|
||||
message.get("kwargs")
|
||||
)
|
||||
elif msg_type == MSG_TYPE_RESPONSE:
|
||||
return BinaryProtocol.encode_response(
|
||||
request_id,
|
||||
message.get("result")
|
||||
)
|
||||
elif msg_type == MSG_TYPE_ERROR:
|
||||
return BinaryProtocol.encode_error(
|
||||
request_id,
|
||||
message.get("error", {})
|
||||
)
|
||||
elif msg_type == MSG_TYPE_CHUNK:
|
||||
return BinaryProtocol.encode_chunk(
|
||||
request_id,
|
||||
message.get("data")
|
||||
)
|
||||
elif msg_type == MSG_TYPE_END:
|
||||
return BinaryProtocol.encode_end(request_id)
|
||||
else:
|
||||
raise DirtyProtocolError(f"Unhandled message type: {msg_type}")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Backwards Compatibility Aliases
|
||||
# =============================================================================
|
||||
|
||||
# Alias BinaryProtocol as DirtyProtocol for drop-in replacement
|
||||
DirtyProtocol = BinaryProtocol
|
||||
|
||||
|
||||
# Message builder helpers (backwards compatible with old API)
|
||||
def make_request(request_id, app_path: str, action: str,
|
||||
args: tuple = None, kwargs: dict = None) -> dict:
|
||||
"""
|
||||
Build a request message.
|
||||
Build a request message dict.
|
||||
|
||||
Args:
|
||||
request_id: Unique request identifier
|
||||
request_id: Unique request identifier (int or str)
|
||||
app_path: Import path of the dirty app (e.g., 'myapp.ml:MLApp')
|
||||
action: Action to call on the app
|
||||
args: Positional arguments
|
||||
kwargs: Keyword arguments
|
||||
|
||||
Returns:
|
||||
dict: Request message
|
||||
dict: Request message dict
|
||||
"""
|
||||
return {
|
||||
"type": DirtyProtocol.MSG_TYPE_REQUEST,
|
||||
@ -266,16 +562,16 @@ def make_request(request_id: str, app_path: str, action: str,
|
||||
}
|
||||
|
||||
|
||||
def make_response(request_id: str, result) -> dict:
|
||||
def make_response(request_id, result) -> dict:
|
||||
"""
|
||||
Build a success response message.
|
||||
Build a success response message dict.
|
||||
|
||||
Args:
|
||||
request_id: Request identifier this responds to
|
||||
result: Result value (must be JSON-serializable)
|
||||
result: Result value
|
||||
|
||||
Returns:
|
||||
dict: Response message
|
||||
dict: Response message dict
|
||||
"""
|
||||
return {
|
||||
"type": DirtyProtocol.MSG_TYPE_RESPONSE,
|
||||
@ -284,16 +580,16 @@ def make_response(request_id: str, result) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def make_error_response(request_id: str, error) -> dict:
|
||||
def make_error_response(request_id, error) -> dict:
|
||||
"""
|
||||
Build an error response message.
|
||||
Build an error response message dict.
|
||||
|
||||
Args:
|
||||
request_id: Request identifier this responds to
|
||||
error: DirtyError instance or dict with error info
|
||||
|
||||
Returns:
|
||||
dict: Error response message
|
||||
dict: Error response message dict
|
||||
"""
|
||||
from .errors import DirtyError
|
||||
if isinstance(error, DirtyError):
|
||||
@ -314,16 +610,16 @@ def make_error_response(request_id: str, error) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def make_chunk_message(request_id: str, data) -> dict:
|
||||
def make_chunk_message(request_id, data) -> dict:
|
||||
"""
|
||||
Build a chunk message for streaming responses.
|
||||
Build a chunk message dict for streaming responses.
|
||||
|
||||
Args:
|
||||
request_id: Request identifier this chunk belongs to
|
||||
data: Chunk data (must be JSON-serializable)
|
||||
data: Chunk data
|
||||
|
||||
Returns:
|
||||
dict: Chunk message
|
||||
dict: Chunk message dict
|
||||
"""
|
||||
return {
|
||||
"type": DirtyProtocol.MSG_TYPE_CHUNK,
|
||||
@ -332,15 +628,15 @@ def make_chunk_message(request_id: str, data) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def make_end_message(request_id: str) -> dict:
|
||||
def make_end_message(request_id) -> dict:
|
||||
"""
|
||||
Build an end-of-stream message.
|
||||
Build an end-of-stream message dict.
|
||||
|
||||
Args:
|
||||
request_id: Request identifier this ends
|
||||
|
||||
Returns:
|
||||
dict: End message
|
||||
dict: End message dict
|
||||
"""
|
||||
return {
|
||||
"type": DirtyProtocol.MSG_TYPE_END,
|
||||
|
||||
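The 16-byte header that `encode_header` and `decode_header` exchange can be sketched with a single `struct` format. This is a minimal illustration, assuming the layout documented for the binary protocol (magic `GD`, version `0x01`, message type, big-endian uint32 payload length, uint64 request ID). The names `HEADER_STRUCT`, `pack_header` and `unpack_header` are hypothetical and are not taken from `protocol.py`.

```python
import struct

# Assumed layout, taken from the documented 16-byte header:
# magic (2 bytes) | version (1) | message type (1) | payload length (4) | request id (8)
HEADER_STRUCT = struct.Struct(">2sBBIQ")  # hypothetical name, not from protocol.py
MAGIC = b"GD"
VERSION = 0x01
MSG_TYPE_CHUNK = 0x04


def pack_header(msg_type: int, request_id: int, length: int) -> bytes:
    """Pack the three variable fields behind the fixed magic and version."""
    return HEADER_STRUCT.pack(MAGIC, VERSION, msg_type, length, request_id)


def unpack_header(header: bytes) -> tuple:
    """Return (msg_type, request_id, length), rejecting a bad magic or version."""
    magic, version, msg_type, length, request_id = HEADER_STRUCT.unpack(header)
    if magic != MAGIC or version != VERSION:
        raise ValueError("not a dirty protocol header")
    return msg_type, request_id, length


header = pack_header(MSG_TYPE_CHUNK, 123, 5)
print(header.hex())           # 4744010400000005000000000000007b
print(unpack_header(header))  # (4, 123, 5)
```

Because the format string uses `>` (big-endian, no padding), the struct is exactly 16 bytes, which is why readers can call `readexactly(HEADER_SIZE)` before knowing anything else about the message.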
303
gunicorn/dirty/tlv.py
Normal file
@ -0,0 +1,303 @@
|
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.

"""
TLV (Type-Length-Value) Binary Encoder/Decoder

Provides efficient binary serialization for dirty worker protocol messages.
Inspired by OpenBSD msgctl/msgsnd message format.

Type Codes:
    0x00: None (no value bytes)
    0x01: bool (1 byte: 0x00 or 0x01)
    0x05: int64 (8 bytes big-endian signed)
    0x06: float64 (8 bytes IEEE 754)
    0x10: bytes (4-byte length + raw bytes)
    0x11: string (4-byte length + UTF-8 encoded)
    0x20: list (4-byte count + encoded elements)
    0x21: dict (4-byte count + encoded key-value pairs)
"""

import struct

from .errors import DirtyProtocolError


# Type codes
TYPE_NONE = 0x00
TYPE_BOOL = 0x01
TYPE_INT64 = 0x05
TYPE_FLOAT64 = 0x06
TYPE_BYTES = 0x10
TYPE_STRING = 0x11
TYPE_LIST = 0x20
TYPE_DICT = 0x21

# Maximum sizes for safety
MAX_STRING_SIZE = 64 * 1024 * 1024  # 64 MB
MAX_BYTES_SIZE = 64 * 1024 * 1024  # 64 MB
MAX_LIST_SIZE = 1024 * 1024  # 1 million items
MAX_DICT_SIZE = 1024 * 1024  # 1 million items


class TLVEncoder:
    """
    TLV binary encoder/decoder.

    Encodes Python values to binary TLV format and decodes back.
    Supports: None, bool, int, float, bytes, str, list, dict.
    """

    @staticmethod
    def encode(value) -> bytes:  # pylint: disable=too-many-return-statements
        """
        Encode a Python value to TLV binary format.

        Args:
            value: Python value to encode (None, bool, int, float,
                   bytes, str, list, or dict)

        Returns:
            bytes: TLV-encoded binary data

        Raises:
            DirtyProtocolError: If value type is not supported
        """
        if value is None:
            return bytes([TYPE_NONE])

        if isinstance(value, bool):
            # bool must come before int since bool is a subclass of int
            return bytes([TYPE_BOOL, 0x01 if value else 0x00])

        if isinstance(value, int):
            return bytes([TYPE_INT64]) + struct.pack(">q", value)

        if isinstance(value, float):
            return bytes([TYPE_FLOAT64]) + struct.pack(">d", value)

        if isinstance(value, bytes):
            if len(value) > MAX_BYTES_SIZE:
                raise DirtyProtocolError(
                    f"Bytes too large: {len(value)} bytes "
                    f"(max: {MAX_BYTES_SIZE})"
                )
            return bytes([TYPE_BYTES]) + struct.pack(">I", len(value)) + value

        if isinstance(value, str):
            encoded = value.encode("utf-8")
            if len(encoded) > MAX_STRING_SIZE:
                raise DirtyProtocolError(
                    f"String too large: {len(encoded)} bytes "
                    f"(max: {MAX_STRING_SIZE})"
                )
            return bytes([TYPE_STRING]) + struct.pack(">I", len(encoded)) + encoded

        if isinstance(value, (list, tuple)):
            if len(value) > MAX_LIST_SIZE:
                raise DirtyProtocolError(
                    f"List too large: {len(value)} items "
                    f"(max: {MAX_LIST_SIZE})"
                )
            parts = [bytes([TYPE_LIST]), struct.pack(">I", len(value))]
            for item in value:
                parts.append(TLVEncoder.encode(item))
            return b"".join(parts)

        if isinstance(value, dict):
            if len(value) > MAX_DICT_SIZE:
                raise DirtyProtocolError(
                    f"Dict too large: {len(value)} items "
                    f"(max: {MAX_DICT_SIZE})"
                )
            parts = [bytes([TYPE_DICT]), struct.pack(">I", len(value))]
            for k, v in value.items():
                # Convert keys to strings (like JSON)
                if not isinstance(k, str):
                    k = str(k)
                parts.append(TLVEncoder.encode(k))
                parts.append(TLVEncoder.encode(v))
            return b"".join(parts)

        raise DirtyProtocolError(
            f"Unsupported type for TLV encoding: {type(value).__name__}"
        )

    @staticmethod
    def decode(data: bytes, offset: int = 0) -> tuple:  # pylint: disable=too-many-return-statements
        """
        Decode a TLV-encoded value from binary data.

        Args:
            data: Binary data to decode
            offset: Starting offset in the data

        Returns:
            tuple: (decoded_value, new_offset)

        Raises:
            DirtyProtocolError: If data is malformed or truncated
        """
        if offset >= len(data):
            raise DirtyProtocolError(
                "Truncated TLV data: no type byte",
                raw_data=data[offset:offset + 20]
            )

        type_code = data[offset]
        offset += 1

        if type_code == TYPE_NONE:
            return None, offset

        if type_code == TYPE_BOOL:
            if offset >= len(data):
                raise DirtyProtocolError(
                    "Truncated TLV data: missing bool value",
                    raw_data=data[offset - 1:offset + 20]
                )
            value = data[offset] != 0x00
            return value, offset + 1

        if type_code == TYPE_INT64:
            if offset + 8 > len(data):
                raise DirtyProtocolError(
                    "Truncated TLV data: incomplete int64",
                    raw_data=data[offset - 1:offset + 20]
                )
            value = struct.unpack(">q", data[offset:offset + 8])[0]
            return value, offset + 8

        if type_code == TYPE_FLOAT64:
            if offset + 8 > len(data):
                raise DirtyProtocolError(
                    "Truncated TLV data: incomplete float64",
                    raw_data=data[offset - 1:offset + 20]
                )
            value = struct.unpack(">d", data[offset:offset + 8])[0]
            return value, offset + 8

        if type_code == TYPE_BYTES:
            if offset + 4 > len(data):
                raise DirtyProtocolError(
                    "Truncated TLV data: incomplete bytes length",
                    raw_data=data[offset - 1:offset + 20]
                )
            length = struct.unpack(">I", data[offset:offset + 4])[0]
            offset += 4

            if length > MAX_BYTES_SIZE:
                raise DirtyProtocolError(
                    f"Bytes too large: {length} bytes (max: {MAX_BYTES_SIZE})"
                )

            if offset + length > len(data):
                raise DirtyProtocolError(
                    f"Truncated TLV data: expected {length} bytes, "
                    f"got {len(data) - offset}",
                    raw_data=data[offset - 5:offset + 20]
                )
            value = data[offset:offset + length]
            return value, offset + length

        if type_code == TYPE_STRING:
            if offset + 4 > len(data):
                raise DirtyProtocolError(
                    "Truncated TLV data: incomplete string length",
                    raw_data=data[offset - 1:offset + 20]
                )
            length = struct.unpack(">I", data[offset:offset + 4])[0]
            offset += 4

            if length > MAX_STRING_SIZE:
                raise DirtyProtocolError(
                    f"String too large: {length} bytes (max: {MAX_STRING_SIZE})"
                )

            if offset + length > len(data):
                raise DirtyProtocolError(
                    f"Truncated TLV data: expected {length} bytes for string, "
                    f"got {len(data) - offset}",
                    raw_data=data[offset - 5:offset + 20]
                )
            try:
                value = data[offset:offset + length].decode("utf-8")
            except UnicodeDecodeError as e:
                raise DirtyProtocolError(
                    f"Invalid UTF-8 in string: {e}",
                    raw_data=data[offset:offset + min(length, 20)]
                )
            return value, offset + length

        if type_code == TYPE_LIST:
            if offset + 4 > len(data):
                raise DirtyProtocolError(
                    "Truncated TLV data: incomplete list count",
                    raw_data=data[offset - 1:offset + 20]
                )
            count = struct.unpack(">I", data[offset:offset + 4])[0]
            offset += 4

            if count > MAX_LIST_SIZE:
                raise DirtyProtocolError(
                    f"List too large: {count} items (max: {MAX_LIST_SIZE})"
                )

            items = []
            for _ in range(count):
                item, offset = TLVEncoder.decode(data, offset)
                items.append(item)
            return items, offset

        if type_code == TYPE_DICT:
            if offset + 4 > len(data):
                raise DirtyProtocolError(
                    "Truncated TLV data: incomplete dict count",
                    raw_data=data[offset - 1:offset + 20]
                )
            count = struct.unpack(">I", data[offset:offset + 4])[0]
            offset += 4

            if count > MAX_DICT_SIZE:
                raise DirtyProtocolError(
                    f"Dict too large: {count} items (max: {MAX_DICT_SIZE})"
                )

            result = {}
            for _ in range(count):
                key, offset = TLVEncoder.decode(data, offset)
                if not isinstance(key, str):
                    raise DirtyProtocolError(
                        f"Dict key must be string, got {type(key).__name__}"
                    )
                value, offset = TLVEncoder.decode(data, offset)
                result[key] = value
            return result, offset

        raise DirtyProtocolError(
            f"Unknown TLV type code: 0x{type_code:02x}",
            raw_data=data[offset - 1:offset + 20]
        )

    @staticmethod
    def decode_full(data: bytes):
        """
        Decode a complete TLV-encoded value, ensuring all data is consumed.

        Args:
            data: Binary data to decode

        Returns:
            Decoded Python value

        Raises:
            DirtyProtocolError: If data is malformed or has trailing bytes
        """
        value, offset = TLVEncoder.decode(data, 0)
        if offset != len(data):
            raise DirtyProtocolError(
                f"Trailing data after TLV: {len(data) - offset} bytes",
                raw_data=data[offset:offset + 20]
            )
        return value
|
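To make the wire layout concrete, the sketch below re-implements a reduced subset of the type table (int64, string, dict) and prints the bytes of a chunk-style payload. It is a standalone illustration rather than the module itself. The function name `encode` is chosen here for brevity, and it raises `TypeError` where the real encoder raises `DirtyProtocolError`.

```python
import struct

# Type codes copied from the table in tlv.py
TYPE_INT64, TYPE_STRING, TYPE_DICT = 0x05, 0x11, 0x21


def encode(value) -> bytes:
    """Encode an int, str, or dict of those: a reduced subset of the type table."""
    if isinstance(value, bool):
        raise TypeError("bool is left out of this reduced sketch")
    if isinstance(value, int):
        # type byte + 8 bytes big-endian signed
        return bytes([TYPE_INT64]) + struct.pack(">q", value)
    if isinstance(value, str):
        # type byte + 4-byte length + UTF-8 bytes
        raw = value.encode("utf-8")
        return bytes([TYPE_STRING]) + struct.pack(">I", len(raw)) + raw
    if isinstance(value, dict):
        # type byte + 4-byte pair count + alternating encoded keys and values
        parts = [bytes([TYPE_DICT]), struct.pack(">I", len(value))]
        for key, item in value.items():
            parts.append(encode(str(key)))
            parts.append(encode(item))
        return b"".join(parts)
    raise TypeError(f"unsupported type: {type(value).__name__}")


payload = encode({"data": "Hi"})
print(payload.hex(" "))
# 21 00 00 00 01 11 00 00 00 04 64 61 74 61 11 00 00 00 02 48 69
```

The 21 bytes read left to right as: dict marker and pair count, the key `data` as a length-prefixed string, then the value `Hi` in the same form. No delimiters or escaping are needed, which is what lets raw `bytes` values travel without base64.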
||||
@ -12,11 +12,13 @@ import pytest
|
||||
|
||||
from gunicorn.dirty.protocol import (
|
||||
DirtyProtocol,
|
||||
BinaryProtocol,
|
||||
make_request,
|
||||
make_response,
|
||||
make_chunk_message,
|
||||
make_end_message,
|
||||
make_error_response,
|
||||
HEADER_SIZE,
|
||||
)
|
||||
from gunicorn.dirty.arbiter import DirtyArbiter
|
||||
from gunicorn.dirty.errors import DirtyError
|
||||
@ -34,16 +36,22 @@ class MockStreamWriter:
|
||||
self._buffer += data
|
||||
|
||||
async def drain(self):
|
||||
while len(self._buffer) >= DirtyProtocol.HEADER_SIZE:
|
||||
length = struct.unpack(
|
||||
DirtyProtocol.HEADER_FORMAT,
|
||||
self._buffer[:DirtyProtocol.HEADER_SIZE]
|
||||
)[0]
|
||||
total_size = DirtyProtocol.HEADER_SIZE + length
|
||||
# Decode the buffer to extract messages using binary protocol
|
||||
while len(self._buffer) >= HEADER_SIZE:
|
||||
# Decode header to get payload length
|
||||
_, _, length = BinaryProtocol.decode_header(
|
||||
self._buffer[:HEADER_SIZE]
|
||||
)
|
||||
total_size = HEADER_SIZE + length
|
||||
if len(self._buffer) >= total_size:
|
||||
msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size]
|
||||
msg_data = self._buffer[:total_size]
|
||||
self._buffer = self._buffer[total_size:]
|
||||
self.messages.append(DirtyProtocol.decode(msg_data))
|
||||
# decode_message returns (msg_type_str, request_id, payload_dict)
|
||||
msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data)
|
||||
# Reconstruct the dict format for backwards compatibility
|
||||
result = {"type": msg_type_str, "id": request_id}
|
||||
result.update(payload_dict)
|
||||
self.messages.append(result)
|
||||
else:
|
||||
break
|
||||
|
||||
@ -63,7 +71,7 @@ class MockStreamReader:
|
||||
def __init__(self, messages):
|
||||
self._data = b''
|
||||
for msg in messages:
|
||||
self._data += DirtyProtocol.encode(msg)
|
||||
self._data += BinaryProtocol._encode_from_dict(msg)
|
||||
self._pos = 0
|
||||
|
||||
async def readexactly(self, n):
|
||||
@ -107,9 +115,9 @@ class TestArbiterStreamingForwarding:
|
||||
client_writer = MockStreamWriter()
|
||||
|
||||
# Mock worker connection that returns chunks
|
||||
chunk1 = make_chunk_message("req-123", "Hello")
|
||||
chunk2 = make_chunk_message("req-123", " World")
|
||||
end = make_end_message("req-123")
|
||||
chunk1 = make_chunk_message(123, "Hello")
|
||||
chunk2 = make_chunk_message(123, " World")
|
||||
end = make_end_message(123)
|
||||
|
||||
mock_reader = MockStreamReader([chunk1, chunk2, end])
|
||||
|
||||
@ -118,7 +126,7 @@ class TestArbiterStreamingForwarding:
|
||||
|
||||
arbiter._get_worker_connection = mock_get_connection
|
||||
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
request = make_request(123, "test:App", "generate")
|
||||
await arbiter._execute_on_worker(1234, request, client_writer)
|
||||
|
||||
# Should have forwarded all messages
|
||||
@ -135,7 +143,7 @@ class TestArbiterStreamingForwarding:
|
||||
arbiter = create_arbiter()
|
||||
client_writer = MockStreamWriter()
|
||||
|
||||
response = make_response("req-123", {"result": 42})
|
||||
response = make_response(123, {"result": 42})
|
||||
mock_reader = MockStreamReader([response])
|
||||
|
||||
async def mock_get_connection(pid):
|
||||
@ -143,7 +151,7 @@ class TestArbiterStreamingForwarding:
|
||||
|
||||
arbiter._get_worker_connection = mock_get_connection
|
||||
|
||||
request = make_request("req-123", "test:App", "compute")
|
||||
request = make_request(123, "test:App", "compute")
|
||||
await arbiter._execute_on_worker(1234, request, client_writer)
|
||||
|
||||
assert len(client_writer.messages) == 1
|
||||
@ -156,8 +164,8 @@ class TestArbiterStreamingForwarding:
|
||||
arbiter = create_arbiter()
|
||||
client_writer = MockStreamWriter()
|
||||
|
||||
chunk = make_chunk_message("req-123", "First")
|
||||
error = make_error_response("req-123", DirtyError("Something broke"))
|
||||
chunk = make_chunk_message(123, "First")
|
||||
error = make_error_response(123, DirtyError("Something broke"))
|
||||
|
||||
mock_reader = MockStreamReader([chunk, error])
|
||||
|
||||
@ -166,7 +174,7 @@ class TestArbiterStreamingForwarding:
|
||||
|
||||
arbiter._get_worker_connection = mock_get_connection
|
||||
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
request = make_request(123, "test:App", "generate")
|
||||
await arbiter._execute_on_worker(1234, request, client_writer)
|
||||
|
||||
assert len(client_writer.messages) == 2
|
||||
@ -190,7 +198,7 @@ class TestArbiterStreamingForwarding:
|
||||
|
||||
arbiter._get_worker_connection = mock_get_connection
|
||||
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
request = make_request(123, "test:App", "generate")
|
||||
await arbiter._execute_on_worker(1234, request, client_writer)
|
||||
|
||||
assert len(client_writer.messages) == 1
|
||||
@ -208,7 +216,7 @@ class TestArbiterRouteRequestStreaming:
|
||||
arbiter.workers = {} # No workers
|
||||
client_writer = MockStreamWriter()
|
||||
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
request = make_request(123, "test:App", "generate")
|
||||
await arbiter.route_request(request, client_writer)
|
||||
|
||||
assert len(client_writer.messages) == 1
|
||||
@ -222,13 +230,13 @@ class TestArbiterRouteRequestStreaming:
|
||||
|
||||
# Mock _execute_on_worker to complete immediately
|
||||
async def mock_execute(pid, request, client_writer):
|
||||
response = make_response("req-123", "result")
|
||||
response = make_response(123, "result")
|
||||
await DirtyProtocol.write_message_async(client_writer, response)
|
||||
|
||||
arbiter._execute_on_worker = mock_execute
|
||||
|
||||
client_writer = MockStreamWriter()
|
||||
request = make_request("req-123", "test:App", "compute")
|
||||
request = make_request(123, "test:App", "compute")
|
||||
|
||||
# Worker queue should be created
|
||||
assert 1234 not in arbiter.worker_queues
|
||||
@ -255,8 +263,8 @@ class TestArbiterStreamingManyChunks:
|
||||
# Generate 50 chunks + end
|
||||
messages = []
|
||||
for i in range(50):
|
||||
messages.append(make_chunk_message("req-123", f"chunk-{i}"))
|
||||
messages.append(make_end_message("req-123"))
|
||||
messages.append(make_chunk_message(123, f"chunk-{i}"))
|
||||
messages.append(make_end_message(123))
|
||||
|
||||
mock_reader = MockStreamReader(messages)
|
||||
|
||||
@ -265,7 +273,7 @@ class TestArbiterStreamingManyChunks:
|
||||
|
||||
arbiter._get_worker_connection = mock_get_connection
|
||||
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
request = make_request(123, "test:App", "generate")
|
||||
await arbiter._execute_on_worker(1234, request, client_writer)
|
||||
|
||||
assert len(client_writer.messages) == 51
|
||||
@ -283,7 +291,7 @@ class TestArbiterBackwardCompatibility:
|
||||
arbiter = create_arbiter()
|
||||
client_writer = MockStreamWriter()
|
||||
|
||||
response = make_response("req-123", [1, 2, 3, 4, 5])
|
||||
response = make_response(123, [1, 2, 3, 4, 5])
|
||||
mock_reader = MockStreamReader([response])
|
||||
|
||||
async def mock_get_connection(pid):
|
||||
@ -291,7 +299,7 @@ class TestArbiterBackwardCompatibility:
|
||||
|
||||
arbiter._get_worker_connection = mock_get_connection
|
||||
|
||||
request = make_request("req-123", "test:App", "get_list")
|
||||
request = make_request(123, "test:App", "get_list")
|
||||
await arbiter._execute_on_worker(1234, request, client_writer)
|
||||
|
||||
assert len(client_writer.messages) == 1
|
||||
@ -304,7 +312,7 @@ class TestArbiterBackwardCompatibility:
|
||||
arbiter = create_arbiter()
|
||||
client_writer = MockStreamWriter()
|
||||
|
||||
error = make_error_response("req-123", DirtyError("Something failed"))
|
||||
error = make_error_response(123, DirtyError("Something failed"))
|
||||
mock_reader = MockStreamReader([error])
|
||||
|
||||
async def mock_get_connection(pid):
|
||||
@ -312,7 +320,7 @@ class TestArbiterBackwardCompatibility:
|
||||
|
||||
arbiter._get_worker_connection = mock_get_connection
|
||||
|
||||
request = make_request("req-123", "test:App", "fail")
|
||||
request = make_request(123, "test:App", "fail")
|
||||
await arbiter._execute_on_worker(1234, request, client_writer)
|
||||
|
||||
assert len(client_writer.messages) == 1
|
||||
|
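The buffer handling in `MockStreamWriter.drain` follows a general length-prefixed framing loop: read a header, compute the total message size, and slice the message off only once the whole payload has arrived. The sketch below shows that loop in isolation. It assumes the documented 16-byte header layout, and `split_frames` and `frame` are hypothetical helper names that do not exist in the test suite.

```python
import struct

HEADER_SIZE = 16
# Assumed header layout: magic, version, message type, payload length, request id
HEADER_STRUCT = struct.Struct(">2sBBIQ")


def split_frames(buffer: bytes) -> tuple:
    """Return (complete_frames, leftover_bytes) for back-to-back messages."""
    frames = []
    while len(buffer) >= HEADER_SIZE:
        _, _, _, length, _ = HEADER_STRUCT.unpack(buffer[:HEADER_SIZE])
        total_size = HEADER_SIZE + length
        if len(buffer) < total_size:
            break  # payload not fully received yet; wait for more data
        frames.append(buffer[:total_size])
        buffer = buffer[total_size:]
    return frames, buffer


def frame(msg_type: int, request_id: int, payload: bytes = b"") -> bytes:
    """Hypothetical helper: prepend a header to an already-encoded payload."""
    header = HEADER_STRUCT.pack(b"GD", 1, msg_type, len(payload), request_id)
    return header + payload


chunk = frame(0x04, 123, b"\x00")  # CHUNK whose payload is a TLV None
end = frame(0x05, 123)             # END with an empty payload
frames, rest = split_frames(chunk + end + chunk[:10])
print(len(frames), len(rest))      # 2 10
```

The trailing 10 bytes are kept as leftover input instead of being treated as an error, which mirrors the `else: break` branch in the mock writer.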
||||
@ -11,10 +11,12 @@ from unittest import mock
|
||||
|
||||
from gunicorn.dirty.protocol import (
|
||||
DirtyProtocol,
|
||||
BinaryProtocol,
|
||||
make_chunk_message,
|
||||
make_end_message,
|
||||
make_response,
|
||||
make_error_response,
|
||||
HEADER_SIZE,
|
||||
)
|
||||
from gunicorn.dirty.client import DirtyClient, DirtyStreamIterator
|
||||
from gunicorn.dirty.errors import DirtyError, DirtyConnectionError
|
||||
@ -26,7 +28,7 @@ class MockSocket:
|
||||
def __init__(self, messages):
|
||||
self._data = b''
|
||||
for msg in messages:
|
||||
self._data += DirtyProtocol.encode(msg)
|
||||
self._data += BinaryProtocol._encode_from_dict(msg)
|
||||
self._pos = 0
|
||||
self._sent = []
|
||||
self.closed = False
|
||||
@ -69,10 +71,10 @@ class TestDirtyStreamIterator:
|
||||
def test_stream_iterator_yields_chunks(self):
|
||||
"""Test that stream iterator yields chunks correctly."""
|
||||
messages = [
|
||||
make_chunk_message("req-123", "Hello"),
|
||||
make_chunk_message("req-123", " "),
|
||||
make_chunk_message("req-123", "World"),
|
||||
make_end_message("req-123"),
|
||||
make_chunk_message(123, "Hello"),
|
||||
make_chunk_message(123, " "),
|
||||
make_chunk_message(123, "World"),
|
||||
make_end_message(123),
|
||||
]
|
||||
client = create_client_with_mock_socket(messages)
|
||||
|
||||
@@ -83,9 +85,9 @@ class TestDirtyStreamIterator:
     def test_stream_iterator_yields_complex_chunks(self):
         """Test that stream iterator yields complex data types."""
         messages = [
-            make_chunk_message("req-123", {"token": "Hello", "score": 0.9}),
-            make_chunk_message("req-123", {"token": "World", "score": 0.8}),
-            make_end_message("req-123"),
+            make_chunk_message(123, {"token": "Hello", "score": 0.9}),
+            make_chunk_message(123, {"token": "World", "score": 0.8}),
+            make_end_message(123),
         ]
         client = create_client_with_mock_socket(messages)
 
@@ -98,8 +100,8 @@ class TestDirtyStreamIterator:
     def test_stream_iterator_handles_error(self):
         """Test that stream iterator raises on error message."""
         messages = [
-            make_chunk_message("req-123", "First"),
-            make_error_response("req-123", DirtyError("Something broke")),
+            make_chunk_message(123, "First"),
+            make_error_response(123, DirtyError("Something broke")),
         ]
         client = create_client_with_mock_socket(messages)
 
@@ -116,7 +118,7 @@ class TestDirtyStreamIterator:
 
     def test_stream_iterator_empty_stream(self):
         """Test that empty stream (just end) works."""
-        messages = [make_end_message("req-123")]
+        messages = [make_end_message(123)]
         client = create_client_with_mock_socket(messages)
 
         chunks = list(client.stream("test:App", "generate"))
@@ -125,8 +127,8 @@ class TestDirtyStreamIterator:
     def test_stream_iterator_stops_after_exhausted(self):
         """Test that iterator stays exhausted after StopIteration."""
         messages = [
-            make_chunk_message("req-123", "Only"),
-            make_end_message("req-123"),
+            make_chunk_message(123, "Only"),
+            make_end_message(123),
         ]
         client = create_client_with_mock_socket(messages)
 
@@ -147,10 +149,10 @@ class TestDirtyStreamIterator:
     def test_stream_iterator_with_for_loop(self):
         """Test stream iterator works in for loop."""
         messages = [
-            make_chunk_message("req-123", "a"),
-            make_chunk_message("req-123", "b"),
-            make_chunk_message("req-123", "c"),
-            make_end_message("req-123"),
+            make_chunk_message(123, "a"),
+            make_chunk_message(123, "b"),
+            make_chunk_message(123, "c"),
+            make_end_message(123),
         ]
         client = create_client_with_mock_socket(messages)
 
@@ -163,8 +165,8 @@ class TestDirtyStreamIterator:
     def test_stream_sends_request_on_first_iteration(self):
         """Test that request is sent on first next() call."""
         messages = [
-            make_chunk_message("req-123", "data"),
-            make_end_message("req-123"),
+            make_chunk_message(123, "data"),
+            make_end_message(123),
         ]
         client = create_client_with_mock_socket(messages)
 
@@ -179,18 +181,15 @@ class TestDirtyStreamIterator:
 
         # Decode sent request
         sent_data = client._sock._sent[0]
-        length = struct.unpack(
-            DirtyProtocol.HEADER_FORMAT,
-            sent_data[:DirtyProtocol.HEADER_SIZE]
-        )[0]
-        request = DirtyProtocol.decode(
-            sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length]
+        _, _, length = BinaryProtocol.decode_header(sent_data[:HEADER_SIZE])
+        msg_type_str, request_id, payload = BinaryProtocol.decode_message(
+            sent_data[:HEADER_SIZE + length]
         )
 
-        assert request["type"] == "request"
-        assert request["app_path"] == "test:App"
-        assert request["action"] == "generate"
-        assert request["args"] == ["prompt_arg"]
+        assert msg_type_str == "request"
+        assert payload["app_path"] == "test:App"
+        assert payload["action"] == "generate"
+        assert payload["args"] == ["prompt_arg"]
 
 
 class TestDirtyStreamIteratorEdgeCases:
@@ -200,8 +199,8 @@ class TestDirtyStreamIteratorEdgeCases:
         """Test streaming with many chunks."""
         messages = []
         for i in range(100):
-            messages.append(make_chunk_message("req-123", f"chunk-{i}"))
-        messages.append(make_end_message("req-123"))
+            messages.append(make_chunk_message(123, f"chunk-{i}"))
+        messages.append(make_end_message(123))
 
         client = create_client_with_mock_socket(messages)
 
@@ -214,8 +213,8 @@ class TestDirtyStreamIteratorEdgeCases:
     def test_stream_with_kwargs(self):
         """Test streaming with keyword arguments."""
         messages = [
-            make_chunk_message("req-123", "data"),
-            make_end_message("req-123"),
+            make_chunk_message(123, "data"),
+            make_end_message(123),
         ]
         client = create_client_with_mock_socket(messages)
 
@@ -224,13 +223,10 @@ class TestDirtyStreamIteratorEdgeCases:
 
         # Check the sent request includes kwargs
         sent_data = client._sock._sent[0]
-        length = struct.unpack(
-            DirtyProtocol.HEADER_FORMAT,
-            sent_data[:DirtyProtocol.HEADER_SIZE]
-        )[0]
-        request = DirtyProtocol.decode(
-            sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length]
+        _, _, length = BinaryProtocol.decode_header(sent_data[:HEADER_SIZE])
+        msg_type_str, request_id, payload = BinaryProtocol.decode_message(
+            sent_data[:HEADER_SIZE + length]
         )
 
-        assert request["args"] == ["arg1"]
-        assert request["kwargs"] == {"key": "value"}
+        assert payload["args"] == ["arg1"]
+        assert payload["kwargs"] == {"key": "value"}
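The updated assertions read the payload length from a fixed-size header and then decode the whole message. As a rough sketch of what such a header could look like, the code below packs and unpacks the 16-byte layout documented in this commit (magic `GD`, version `0x01`, message type, big-endian uint32 length, uint64 request ID). The names `pack_header` and `unpack_header`, the `struct` format string, and the return order are assumptions for illustration; the tests above only show that the length is the third element returned by `BinaryProtocol.decode_header`.

```python
import struct

# Assumed layout: magic (2 bytes), version (1), message type (1),
# payload length (uint32), request ID (uint64), all big-endian.
HEADER_FORMAT = ">2sBBIQ"
HEADER_SIZE = struct.calcsize(HEADER_FORMAT)
MAGIC = b"GD"
VERSION = 0x01
MSG_TYPE_REQUEST = 0x01


def pack_header(msg_type, request_id, payload_length):
    """Build a 16-byte header (hypothetical helper)."""
    return struct.pack(
        HEADER_FORMAT, MAGIC, VERSION, msg_type, payload_length, request_id
    )


def unpack_header(data):
    """Parse a header and return (msg_type, request_id, payload_length)."""
    magic, version, msg_type, length, request_id = struct.unpack(
        HEADER_FORMAT, data[:HEADER_SIZE]
    )
    if magic != MAGIC:
        raise ValueError("bad magic bytes")
    if version != VERSION:
        raise ValueError("unsupported protocol version")
    return msg_type, request_id, length


header = pack_header(MSG_TYPE_REQUEST, 123, 5)
```

With a layout like this, a reader can fetch exactly `HEADER_SIZE` bytes, learn the payload length, and then read the rest of the message in one call.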
@@ -10,9 +10,11 @@ import pytest
 
 from gunicorn.dirty.protocol import (
     DirtyProtocol,
+    BinaryProtocol,
     make_chunk_message,
     make_end_message,
     make_error_response,
+    HEADER_SIZE,
 )
 from gunicorn.dirty.client import DirtyClient, DirtyAsyncStreamIterator
 from gunicorn.dirty.errors import DirtyError, DirtyTimeoutError
@@ -24,7 +26,7 @@ class MockAsyncReader:
     def __init__(self, messages):
         self._data = b''
         for msg in messages:
-            self._data += DirtyProtocol.encode(msg)
+            self._data += BinaryProtocol._encode_from_dict(msg)
         self._pos = 0
 
     async def readexactly(self, n):
@@ -76,10 +78,10 @@ class TestDirtyAsyncStreamIterator:
     async def test_async_stream_yields_chunks(self):
         """Test that async stream iterator yields chunks correctly."""
         messages = [
-            make_chunk_message("req-123", "Hello"),
-            make_chunk_message("req-123", " "),
-            make_chunk_message("req-123", "World"),
-            make_end_message("req-123"),
+            make_chunk_message(123, "Hello"),
+            make_chunk_message(123, " "),
+            make_chunk_message(123, "World"),
+            make_end_message(123),
         ]
         client = create_async_client_with_mocks(messages)
 
@@ -93,9 +95,9 @@ class TestDirtyAsyncStreamIterator:
     async def test_async_stream_yields_complex_chunks(self):
         """Test that async stream iterator yields complex data types."""
         messages = [
-            make_chunk_message("req-123", {"token": "Hello", "score": 0.9}),
-            make_chunk_message("req-123", {"token": "World", "score": 0.8}),
-            make_end_message("req-123"),
+            make_chunk_message(123, {"token": "Hello", "score": 0.9}),
+            make_chunk_message(123, {"token": "World", "score": 0.8}),
+            make_end_message(123),
         ]
         client = create_async_client_with_mocks(messages)
 
@@ -111,8 +113,8 @@ class TestDirtyAsyncStreamIterator:
     async def test_async_stream_handles_error(self):
         """Test that async stream iterator raises on error message."""
         messages = [
-            make_chunk_message("req-123", "First"),
-            make_error_response("req-123", DirtyError("Something broke")),
+            make_chunk_message(123, "First"),
+            make_error_response(123, DirtyError("Something broke")),
         ]
         client = create_async_client_with_mocks(messages)
 
@@ -130,7 +132,7 @@ class TestDirtyAsyncStreamIterator:
     @pytest.mark.asyncio
     async def test_async_stream_empty_stream(self):
         """Test that empty stream (just end) works."""
-        messages = [make_end_message("req-123")]
+        messages = [make_end_message(123)]
         client = create_async_client_with_mocks(messages)
 
         chunks = []
@@ -143,8 +145,8 @@ class TestDirtyAsyncStreamIterator:
     async def test_async_stream_stops_after_exhausted(self):
         """Test that async iterator stays exhausted after StopAsyncIteration."""
         messages = [
-            make_chunk_message("req-123", "Only"),
-            make_end_message("req-123"),
+            make_chunk_message(123, "Only"),
+            make_end_message(123),
         ]
         client = create_async_client_with_mocks(messages)
 
@@ -166,8 +168,8 @@ class TestDirtyAsyncStreamIterator:
     async def test_async_stream_sends_request_on_first_iteration(self):
         """Test that request is sent on first async iteration."""
         messages = [
-            make_chunk_message("req-123", "data"),
-            make_end_message("req-123"),
+            make_chunk_message(123, "data"),
+            make_end_message(123),
         ]
         client = create_async_client_with_mocks(messages)
 
@@ -182,18 +184,15 @@ class TestDirtyAsyncStreamIterator:
 
         # Decode sent request
         sent_data = client._writer._sent[0]
-        length = struct.unpack(
-            DirtyProtocol.HEADER_FORMAT,
-            sent_data[:DirtyProtocol.HEADER_SIZE]
-        )[0]
-        request = DirtyProtocol.decode(
-            sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length]
+        _, _, length = BinaryProtocol.decode_header(sent_data[:HEADER_SIZE])
+        msg_type_str, request_id, payload = BinaryProtocol.decode_message(
+            sent_data[:HEADER_SIZE + length]
         )
 
-        assert request["type"] == "request"
-        assert request["app_path"] == "test:App"
-        assert request["action"] == "generate"
-        assert request["args"] == ["prompt_arg"]
+        assert msg_type_str == "request"
+        assert payload["app_path"] == "test:App"
+        assert payload["action"] == "generate"
+        assert payload["args"] == ["prompt_arg"]
 
 
 class TestDirtyAsyncStreamIteratorEdgeCases:
@@ -204,8 +203,8 @@ class TestDirtyAsyncStreamIteratorEdgeCases:
         """Test async streaming with many chunks."""
         messages = []
         for i in range(100):
-            messages.append(make_chunk_message("req-123", f"chunk-{i}"))
-        messages.append(make_end_message("req-123"))
+            messages.append(make_chunk_message(123, f"chunk-{i}"))
+        messages.append(make_end_message(123))
 
         client = create_async_client_with_mocks(messages)
 
@@ -221,8 +220,8 @@ class TestDirtyAsyncStreamIteratorEdgeCases:
     async def test_async_stream_with_kwargs(self):
         """Test async streaming with keyword arguments."""
         messages = [
-            make_chunk_message("req-123", "data"),
-            make_end_message("req-123"),
+            make_chunk_message(123, "data"),
+            make_end_message(123),
         ]
         client = create_async_client_with_mocks(messages)
 
@@ -233,16 +232,13 @@ class TestDirtyAsyncStreamIteratorEdgeCases:
 
         # Check the sent request includes kwargs
         sent_data = client._writer._sent[0]
-        length = struct.unpack(
-            DirtyProtocol.HEADER_FORMAT,
-            sent_data[:DirtyProtocol.HEADER_SIZE]
-        )[0]
-        request = DirtyProtocol.decode(
-            sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length]
+        _, _, length = BinaryProtocol.decode_header(sent_data[:HEADER_SIZE])
+        msg_type_str, request_id, payload = BinaryProtocol.decode_message(
+            sent_data[:HEADER_SIZE + length]
         )
 
-        assert request["args"] == ["arg1"]
-        assert request["kwargs"] == {"key": "value"}
+        assert payload["args"] == ["arg1"]
+        assert payload["kwargs"] == {"key": "value"}
 
 
 class TestDirtyAsyncStreamTimeout:
@@ -19,7 +19,12 @@ from concurrent.futures import ThreadPoolExecutor
 from gunicorn.config import Config
 from gunicorn.dirty.worker import DirtyWorker
 from gunicorn.dirty.arbiter import DirtyArbiter
-from gunicorn.dirty.protocol import DirtyProtocol, make_request
+from gunicorn.dirty.protocol import (
+    DirtyProtocol,
+    BinaryProtocol,
+    make_request,
+    HEADER_SIZE,
+)
 from gunicorn.dirty.errors import DirtyAppNotFoundError
 
 
@@ -71,16 +76,22 @@ class MockStreamWriter:
         self._buffer += data
 
     async def drain(self):
-        while len(self._buffer) >= DirtyProtocol.HEADER_SIZE:
-            length = struct.unpack(
-                DirtyProtocol.HEADER_FORMAT,
-                self._buffer[:DirtyProtocol.HEADER_SIZE]
-            )[0]
-            total_size = DirtyProtocol.HEADER_SIZE + length
+        # Decode the buffer to extract messages using binary protocol
+        while len(self._buffer) >= HEADER_SIZE:
+            # Decode header to get payload length
+            _, _, length = BinaryProtocol.decode_header(
+                self._buffer[:HEADER_SIZE]
+            )
+            total_size = HEADER_SIZE + length
             if len(self._buffer) >= total_size:
-                msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size]
+                msg_data = self._buffer[:total_size]
                 self._buffer = self._buffer[total_size:]
-                self.messages.append(DirtyProtocol.decode(msg_data))
+                # decode_message returns (msg_type_str, request_id, payload_dict)
+                msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data)
+                # Reconstruct the dict format for backwards compatibility
+                result = {"type": msg_type_str, "id": request_id}
+                result.update(payload_dict)
+                self.messages.append(result)
             else:
                 break
 
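The `drain()` change keeps the same framing loop but now hands the decoder the full message (header included) instead of only the payload. The loop can be sketched on its own as follows. This is an illustration under the assumption of a 16-byte header with a big-endian uint32 length field, as documented in this commit; `make_frame` and `split_frames` are hypothetical names, not gunicorn functions.

```python
import struct

# Assumed header layout: magic, version, message type, length, request ID.
HEADER_FORMAT = ">2sBBIQ"
HEADER_SIZE = struct.calcsize(HEADER_FORMAT)


def make_frame(request_id, payload, msg_type=0x04):
    """Build header + payload for testing (hypothetical helper)."""
    header = struct.pack(
        HEADER_FORMAT, b"GD", 0x01, msg_type, len(payload), request_id
    )
    return header + payload


def split_frames(buffer):
    """Split a byte buffer into complete frames.

    Returns (frames, remainder), where remainder holds any trailing
    bytes that do not yet form a complete frame.
    """
    frames = []
    while len(buffer) >= HEADER_SIZE:
        _magic, _version, _msg_type, length, _request_id = struct.unpack(
            HEADER_FORMAT, buffer[:HEADER_SIZE]
        )
        total_size = HEADER_SIZE + length
        if len(buffer) < total_size:
            break  # wait for more data
        frames.append(buffer[:total_size])  # header stays attached
        buffer = buffer[total_size:]
    return frames, buffer


first = make_frame(1, b"hello")
second = make_frame(2, b"")
frames, remainder = split_frames(first + second + first[:10])
```

Keeping the header attached to each frame matches the new `msg_data = self._buffer[:total_size]` line, since the decoder needs the message type and request ID that live in the header.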
@@ -18,11 +18,13 @@ from unittest import mock
 from gunicorn.config import Config
 from gunicorn.dirty.protocol import (
     DirtyProtocol,
+    BinaryProtocol,
     make_request,
     make_chunk_message,
     make_end_message,
     make_response,
     make_error_response,
+    HEADER_SIZE,
 )
 from gunicorn.dirty.worker import DirtyWorker
 from gunicorn.dirty.arbiter import DirtyArbiter
@@ -67,16 +69,22 @@ class MockStreamWriter:
         self._buffer += data
 
     async def drain(self):
-        while len(self._buffer) >= DirtyProtocol.HEADER_SIZE:
-            length = struct.unpack(
-                DirtyProtocol.HEADER_FORMAT,
-                self._buffer[:DirtyProtocol.HEADER_SIZE]
-            )[0]
-            total_size = DirtyProtocol.HEADER_SIZE + length
+        # Decode the buffer to extract messages using binary protocol
+        while len(self._buffer) >= HEADER_SIZE:
+            # Decode header to get payload length
+            _, _, length = BinaryProtocol.decode_header(
+                self._buffer[:HEADER_SIZE]
+            )
+            total_size = HEADER_SIZE + length
             if len(self._buffer) >= total_size:
-                msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size]
+                msg_data = self._buffer[:total_size]
                 self._buffer = self._buffer[total_size:]
-                self.messages.append(DirtyProtocol.decode(msg_data))
+                # decode_message returns (msg_type_str, request_id, payload_dict)
+                msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data)
+                # Reconstruct the dict format for backwards compatibility
+                result = {"type": msg_type_str, "id": request_id}
+                result.update(payload_dict)
+                self.messages.append(result)
             else:
                 break
 
@@ -96,7 +104,7 @@ class MockStreamReader:
     def __init__(self, messages):
         self._data = b''
         for msg in messages:
-            self._data += DirtyProtocol.encode(msg)
+            self._data += BinaryProtocol._encode_from_dict(msg)
         self._pos = 0
 
     async def readexactly(self, n):
@@ -115,10 +123,10 @@ class TestStreamingEndToEnd:
         """Test complete flow: sync generator -> worker -> arbiter -> client."""
         # Simulate what a worker would produce for a sync generator
         worker_messages = [
-            make_chunk_message("req-123", "Hello"),
-            make_chunk_message("req-123", " "),
-            make_chunk_message("req-123", "World"),
-            make_end_message("req-123"),
+            make_chunk_message(123, "Hello"),
+            make_chunk_message(123, " "),
+            make_chunk_message(123, "World"),
+            make_end_message(123),
         ]
 
         # Create an arbiter with mocked worker connection
@@ -141,7 +149,7 @@ class TestStreamingEndToEnd:
         client_writer = MockStreamWriter()
 
         # Execute request through arbiter
-        request = make_request("req-123", "test:App", "generate")
+        request = make_request(123, "test:App", "generate")
         await arbiter._execute_on_worker(1234, request, client_writer)
 
         # Verify all messages were forwarded
@@ -158,10 +166,10 @@ class TestStreamingEndToEnd:
     async def test_async_generator_end_to_end(self):
         """Test complete flow: async generator -> worker -> arbiter -> client."""
         worker_messages = [
-            make_chunk_message("req-456", "Async"),
-            make_chunk_message("req-456", " "),
-            make_chunk_message("req-456", "Stream"),
-            make_end_message("req-456"),
+            make_chunk_message(456, "Async"),
+            make_chunk_message(456, " "),
+            make_chunk_message(456, "Stream"),
+            make_end_message(456),
         ]
 
         cfg = Config()
@@ -180,7 +188,7 @@ class TestStreamingEndToEnd:
 
         client_writer = MockStreamWriter()
 
-        request = make_request("req-456", "test:App", "async_generate")
+        request = make_request(456, "test:App", "async_generate")
         await arbiter._execute_on_worker(1234, request, client_writer)
 
         assert len(client_writer.messages) == 4
@@ -197,9 +205,9 @@ class TestStreamingErrorHandling:
     async def test_error_mid_stream(self):
         """Test that errors during streaming are properly forwarded."""
         worker_messages = [
-            make_chunk_message("req-789", "First"),
-            make_chunk_message("req-789", "Second"),
-            make_error_response("req-789", DirtyError("Stream failed")),
+            make_chunk_message(789, "First"),
+            make_chunk_message(789, "Second"),
+            make_error_response(789, DirtyError("Stream failed")),
         ]
 
         cfg = Config()
@@ -218,7 +226,7 @@ class TestStreamingErrorHandling:
 
         client_writer = MockStreamWriter()
 
-        request = make_request("req-789", "test:App", "generate_with_error")
+        request = make_request(789, "test:App", "generate_with_error")
         await arbiter._execute_on_worker(1234, request, client_writer)
 
         # Should have 2 chunks + 1 error
@@ -335,7 +343,7 @@ class TestStreamingWorkerIntegration:
             return sync_gen()
 
         with mock.patch.object(worker, 'execute', side_effect=mock_execute):
-            request = make_request("req-123", "test:App", "generate")
+            request = make_request(123, "test:App", "generate")
             await worker.handle_request(request, writer)
 
         # Should have 3 chunks + 1 end
@@ -377,7 +385,7 @@ class TestStreamingWorkerIntegration:
             return async_gen()
 
         with mock.patch.object(worker, 'execute', side_effect=mock_execute):
-            request = make_request("req-456", "test:App", "async_generate")
+            request = make_request(456, "test:App", "async_generate")
             await worker.handle_request(request, writer)
 
         # Should have 2 chunks + 1 end
@@ -12,9 +12,11 @@ import pytest
 
 from gunicorn.dirty.protocol import (
     DirtyProtocol,
+    BinaryProtocol,
     make_request,
     make_chunk_message,
     make_end_message,
+    HEADER_SIZE,
 )
 from gunicorn.dirty.worker import DirtyWorker
 
@@ -30,17 +32,22 @@ class FakeStreamWriter:
         self._buffer += data
 
     async def drain(self):
-        # Decode the buffer to extract messages
-        while len(self._buffer) >= DirtyProtocol.HEADER_SIZE:
-            length = struct.unpack(
-                DirtyProtocol.HEADER_FORMAT,
-                self._buffer[:DirtyProtocol.HEADER_SIZE]
-            )[0]
-            total_size = DirtyProtocol.HEADER_SIZE + length
+        # Decode the buffer to extract messages using binary protocol
+        while len(self._buffer) >= HEADER_SIZE:
+            # Decode header to get payload length
+            _, _, length = BinaryProtocol.decode_header(
+                self._buffer[:HEADER_SIZE]
+            )
+            total_size = HEADER_SIZE + length
             if len(self._buffer) >= total_size:
-                msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size]
+                msg_data = self._buffer[:total_size]
                 self._buffer = self._buffer[total_size:]
-                self.messages.append(DirtyProtocol.decode(msg_data))
+                # decode_message returns (msg_type_str, request_id, payload_dict)
+                msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data)
+                # Reconstruct the dict format for backwards compatibility
+                result = {"type": msg_type_str, "id": request_id}
+                result.update(payload_dict)
+                self.messages.append(result)
             else:
                 break
 
@@ -101,7 +108,7 @@ class TestWorkerSyncGeneratorStreaming:
             return generate_tokens()
 
         with mock.patch.object(worker, 'execute', side_effect=mock_execute):
-            request = make_request("req-123", "test:App", "generate")
+            request = make_request(123, "test:App", "generate")
             await worker.handle_request(request, writer)
 
         # Should have 3 chunks + 1 end message
@@ -109,7 +116,7 @@ class TestWorkerSyncGeneratorStreaming:
 
         # Check chunk messages
         assert writer.messages[0]["type"] == "chunk"
-        assert writer.messages[0]["id"] == "req-123"
+        assert writer.messages[0]["id"] == 123
         assert writer.messages[0]["data"] == "Hello"
 
         assert writer.messages[1]["type"] == "chunk"
@@ -120,7 +127,7 @@ class TestWorkerSyncGeneratorStreaming:
 
         # Check end message
         assert writer.messages[3]["type"] == "end"
-        assert writer.messages[3]["id"] == "req-123"
+        assert writer.messages[3]["id"] == 123
 
     @pytest.mark.asyncio
     async def test_sync_generator_error_mid_stream(self):
@@ -136,7 +143,7 @@ class TestWorkerSyncGeneratorStreaming:
             return generate_with_error()
 
         with mock.patch.object(worker, 'execute', side_effect=mock_execute):
-            request = make_request("req-123", "test:App", "generate")
+            request = make_request(123, "test:App", "generate")
             await worker.handle_request(request, writer)
 
         # Should have 1 chunk + 1 error message
@@ -167,7 +174,7 @@ class TestWorkerAsyncGeneratorStreaming:
             return async_generate_tokens()
 
         with mock.patch.object(worker, 'execute', side_effect=mock_execute):
-            request = make_request("req-123", "test:App", "generate")
+            request = make_request(123, "test:App", "generate")
             await worker.handle_request(request, writer)
 
         # Should have 3 chunks + 1 end message
@@ -175,7 +182,7 @@ class TestWorkerAsyncGeneratorStreaming:
 
         # Check chunk messages
         assert writer.messages[0]["type"] == "chunk"
-        assert writer.messages[0]["id"] == "req-123"
+        assert writer.messages[0]["id"] == 123
         assert writer.messages[0]["data"] == "Hello"
 
         assert writer.messages[1]["type"] == "chunk"
@@ -186,7 +193,7 @@ class TestWorkerAsyncGeneratorStreaming:
 
         # Check end message
         assert writer.messages[3]["type"] == "end"
-        assert writer.messages[3]["id"] == "req-123"
+        assert writer.messages[3]["id"] == 123
 
     @pytest.mark.asyncio
     async def test_async_generator_error_mid_stream(self):
@@ -202,7 +209,7 @@ class TestWorkerAsyncGeneratorStreaming:
             return async_generate_with_error()
 
         with mock.patch.object(worker, 'execute', side_effect=mock_execute):
-            request = make_request("req-123", "test:App", "generate")
+            request = make_request(123, "test:App", "generate")
             await worker.handle_request(request, writer)
 
         # Should have 1 chunk + 1 error message
@@ -228,13 +235,13 @@ class TestWorkerNonStreamingBackwardCompat:
             return args[0] + args[1]
 
         with mock.patch.object(worker, 'execute', side_effect=mock_execute):
-            request = make_request("req-123", "test:App", "compute", args=(2, 3))
+            request = make_request(123, "test:App", "compute", args=(2, 3))
             await worker.handle_request(request, writer)
 
         # Should have 1 response message
         assert len(writer.messages) == 1
         assert writer.messages[0]["type"] == "response"
-        assert writer.messages[0]["id"] == "req-123"
+        assert writer.messages[0]["id"] == 123
         assert writer.messages[0]["result"] == 5
 
     @pytest.mark.asyncio
@@ -247,7 +254,7 @@ class TestWorkerNonStreamingBackwardCompat:
             return [1, 2, 3, 4, 5]
 
         with mock.patch.object(worker, 'execute', side_effect=mock_execute):
-            request = make_request("req-123", "test:App", "get_list")
+            request = make_request(123, "test:App", "get_list")
             await worker.handle_request(request, writer)
 
         # Should have 1 response message (not 5 chunks)
@@ -265,7 +272,7 @@ class TestWorkerNonStreamingBackwardCompat:
             raise RuntimeError("Failed!")
 
         with mock.patch.object(worker, 'execute', side_effect=mock_execute):
-            request = make_request("req-123", "test:App", "fail")
+            request = make_request(123, "test:App", "fail")
             await worker.handle_request(request, writer)
 
         # Should have 1 error message
@@ -283,7 +290,7 @@ class TestWorkerNonStreamingBackwardCompat:
             return None
 
         with mock.patch.object(worker, 'execute', side_effect=mock_execute):
-            request = make_request("req-123", "test:App", "void")
+            request = make_request(123, "test:App", "void")
             await worker.handle_request(request, writer)
 
         # Should have 1 response message
@@ -309,7 +316,7 @@ class TestWorkerStreamingComplexData:
             return generate_tokens()
 
         with mock.patch.object(worker, 'execute', side_effect=mock_execute):
-            request = make_request("req-123", "test:App", "generate")
+            request = make_request(123, "test:App", "generate")
             await worker.handle_request(request, writer)
 
         assert len(writer.messages) == 3 # 2 chunks + 1 end
@@ -332,7 +339,7 @@ class TestWorkerStreamingComplexData:
             return empty_generate()
 
         with mock.patch.object(worker, 'execute', side_effect=mock_execute):
-            request = make_request("req-123", "test:App", "generate")
+            request = make_request(123, "test:App", "generate")
             await worker.handle_request(request, writer)
 
         # Should have just 1 end message
@@ -353,7 +360,7 @@ class TestWorkerStreamingComplexData:
             return generate_many()
 
         with mock.patch.object(worker, 'execute', side_effect=mock_execute):
-            request = make_request("req-123", "test:App", "generate")
+            request = make_request(123, "test:App", "generate")
             await worker.handle_request(request, writer)
 
         # Should have 100 chunks + 1 end message
@@ -390,7 +397,7 @@ class TestWorkerStreamingHeartbeat:
             return generate_tokens()
 
         with mock.patch.object(worker, 'execute', side_effect=mock_execute):
-            request = make_request("req-123", "test:App", "generate")
+            request = make_request(123, "test:App", "generate")
             await worker.handle_request(request, writer)
 
         # Should have been notified at least once per chunk + initial
@@ -407,7 +414,7 @@ class TestWorkerMessageTypeValidation:
         writer = FakeStreamWriter()
 
         # Send a message with unknown type
-        message = {"type": "unknown", "id": "req-123"}
+        message = {"type": "unknown", "id": 123}
         await worker.handle_request(message, writer)
 
         assert len(writer.messages) == 1
21
tests/docker/http2/certs/server.crt
Normal file
@ -0,0 +1,21 @@
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIDfDCCAmSgAwIBAgIUDxTarKRHe0FIyczGmoYwm377ZpcwDQYJKoZIhvcNAQEL
|
||||
BQAwOTESMBAGA1UEAwwJbG9jYWxob3N0MRYwFAYDVQQKDA1HdW5pY29ybiBUZXN0
|
||||
MQswCQYDVQQGEwJVUzAeFw0yNjAyMDUxMTE1MjJaFw0yNjAyMDYxMTE1MjJaMDkx
|
||||
EjAQBgNVBAMMCWxvY2FsaG9zdDEWMBQGA1UECgwNR3VuaWNvcm4gVGVzdDELMAkG
|
||||
A1UEBhMCVVMwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCRQTHakkqY
|
||||
6l6dMqfs4oiY98+rjvZubwjp0PH7UBuxXCi/4Ao78o0JhKcs+jgAGAXyb8eRjEKt
|
||||
z4rPoHZYE91D/eD0lWAz9r/LRoutDJd9IO0rfDtHlYXamciuxJJ8cckOrnuTXLtq
|
||||
AWqjKR3U9RIDD3eumCKG4l7Py0L67zTomwMRPfeIdlWBfxGjWMqOewdTc/O/cuK2
|
||||
HL5JP2ixy+iTufs0jhljI9cbu49J606f+TQH9eXRTD716q+KsHPJX1X5dVd7V7Lr
|
||||
FIp7wSUFdbiy56JfmrGmfJbZgFH67P0ZyiTpQBaVHt1YYRIcOUJZqM+0MAtrsySC
|
||||
TNA/LsI8tsybAgMBAAGjfDB6MB0GA1UdDgQWBBRK2VkAeM0hL4j/45ckkKbGrb/Q
|
||||
FjAfBgNVHSMEGDAWgBRK2VkAeM0hL4j/45ckkKbGrb/QFjAPBgNVHRMBAf8EBTAD
|
||||
AQH/MCcGA1UdEQQgMB6CCWxvY2FsaG9zdIILZ3VuaWNvcm4taDKHBH8AAAEwDQYJ
|
||||
KoZIhvcNAQELBQADggEBAAXwuw0KTQUC4UEFudQ1rceK6By9WCSJND7xJi+UQ50G
|
||||
Zrp5tJ2YB4ZWY+APadfuJo+zUxYVZ3jhs0mxgVeiGdDW6yZdHkeX8MlXBTLHR+/a
|
||||
A7DXn6wCw9NDeDtcY/bKg5iamvoGGTL6szPrqeuZPz4UdbsFlr0MdcjgSNOqnkjr
|
||||
YS4ukgZ71aWSjfraRRPjFMzkfnQ1xm96A1ngMH4DvU/t62D7r8+SvxQ8M6ERL84Z
|
||||
FBu4bTXDdYIjJ24ojmDDO2irTVW1FMGXQTPzMaTEbE1rvBYeEYhf10KiMynK9xfO
|
||||
5j8LWmCkgek0CqBrf3zbDEwu8QxcaxITAIUkSXLOZbo=
|
||||
-----END CERTIFICATE-----
|
||||
28
tests/docker/http2/certs/server.key
Normal file
@ -0,0 +1,28 @@
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQCRQTHakkqY6l6d
|
||||
Mqfs4oiY98+rjvZubwjp0PH7UBuxXCi/4Ao78o0JhKcs+jgAGAXyb8eRjEKtz4rP
|
||||
oHZYE91D/eD0lWAz9r/LRoutDJd9IO0rfDtHlYXamciuxJJ8cckOrnuTXLtqAWqj
|
||||
KR3U9RIDD3eumCKG4l7Py0L67zTomwMRPfeIdlWBfxGjWMqOewdTc/O/cuK2HL5J
|
||||
P2ixy+iTufs0jhljI9cbu49J606f+TQH9eXRTD716q+KsHPJX1X5dVd7V7LrFIp7
|
||||
wSUFdbiy56JfmrGmfJbZgFH67P0ZyiTpQBaVHt1YYRIcOUJZqM+0MAtrsySCTNA/
|
||||
LsI8tsybAgMBAAECggEANBhGOYZLI9G2sjlXOaG7bOU/wV9KKaw/7Z/HEaOW8wLD
|
||||
CKHg+cQRai79yCdLi1kSVPNbB2vfBDhRqAp8NzWUn0x/8ChcsvZVriF0edFwyWtU
|
||||
NErfddp+Absy2t9cTC6A9feFEYJqIug0JyVZciWc2qUi/ubIR0kLyQm00YuWFa/s
|
||||
GJou8Nhg70rqW+3FB1H8kAEXqob+PFW4xbTwexw1+MbHxN7UKLTzS8uzYGLo2UpB
|
||||
7bksumyD0o+lZtlx9HZ6CwrB6IPjgJ0HyaD8SrOc7/ozd7rR2LmvMmBCV1uC5VSO
|
||||
jhr0PScLoNv60fjkVOiF9uqaPY2kNKymsOzpZ7/mwQKBgQDMcz+ve8WGGbE+bbM7
|
||||
2uinQ5smm8rWPnfbHJIHQUetrEQKljRovybmjiiXN08uxlX6VA/Vnp4fmL5fzsTD
|
||||
xTeiCVPsR1huXIfMLGJ6crUgvlbiaB8XsxtVNBpfEEtBe27qjSIj3xtmwqM6+LD1
|
||||
FKLsYzgotHUH9JwyLA1RMKPBwQKBgQC14QWtI5YtZcTX46BqxlZ07iAAuy19Jywn
|
||||
UtgmTawkJuEcseewIjxtJkMz+aSy7V3PsLII8tY48oSjAVx84w50zLJ2OlJnFT1S
|
||||
zEmIOu9YDcGLZkYXJ2AwndRAIXpJVHwtFM9eDSMh+wVPBFeboYP1dO/VxmN6QV0W
|
||||
GqDaQfItWwKBgEb31mp2n0j+UB0ofSfQxCOTfx62w4D87CPd1f64tUXe3zuBii21
|
||||
9K3hOMvMwiqtZBjh5yEyzxaOsb6WCo0eP0J61GvXFCYy7lx8J67zdFYqXAR5OhnC
|
||||
7UD1NhY7lLPlQcofNXOYNW3FMF3/B4X7JNbDVjIi+eDKExIDYpgFN0LBAoGADGCf
|
||||
7kR5t+UxHDAVfq64u4RpESOr2NSNoK92nkSy7lLnBvjkd4wc6KCt+h+HIdYdiEDS
|
||||
HOHJyl5WwHEbRjR9i11S19DoQrOjVLsqVecM2sU04rO3GWRIm4ZiJ2sf01W4jajY
|
||||
4+Go/msC1XnKLIE1ZcLrf3Tc2DkSiKqPP8s1G/kCgYA8sCPAXedwhULhOBM45x4J
|
||||
vkwT1Icm5RHOwOr8t34IFozTLokba6pjhYua3nE+V3FglRct7NpX+Op4gUgHa80g
|
||||
5zoHboq5/pTUTclx41jndC1YGa3NLvthDWTWmyo/Qj7F/R7jGJf8E3KUDe0tFoSp
|
||||
JlfEuUHtKpFJReBnmWTFiQ==
|
||||
-----END PRIVATE KEY-----
|
||||
@@ -14,7 +14,12 @@ import pytest
 from gunicorn.config import Config
 from gunicorn.dirty.arbiter import DirtyArbiter
 from gunicorn.dirty.errors import DirtyError
-from gunicorn.dirty.protocol import DirtyProtocol, make_request
+from gunicorn.dirty.protocol import (
+    DirtyProtocol,
+    BinaryProtocol,
+    make_request,
+    HEADER_SIZE,
+)
 
 
 class MockStreamWriter:
@@ -29,16 +34,22 @@ class MockStreamWriter:
         self._buffer += data

     async def drain(self):
-        while len(self._buffer) >= DirtyProtocol.HEADER_SIZE:
-            length = struct.unpack(
-                DirtyProtocol.HEADER_FORMAT,
-                self._buffer[:DirtyProtocol.HEADER_SIZE]
-            )[0]
-            total_size = DirtyProtocol.HEADER_SIZE + length
+        # Decode the buffer to extract messages using binary protocol
+        while len(self._buffer) >= HEADER_SIZE:
+            # Decode header to get payload length
+            _, _, length = BinaryProtocol.decode_header(
+                self._buffer[:HEADER_SIZE]
+            )
+            total_size = HEADER_SIZE + length
             if len(self._buffer) >= total_size:
-                msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size]
+                msg_data = self._buffer[:total_size]
                 self._buffer = self._buffer[total_size:]
-                self.messages.append(DirtyProtocol.decode(msg_data))
+                # decode_message returns (msg_type_str, request_id, payload_dict)
+                msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data)
+                # Reconstruct the dict format for backwards compatibility
+                result = {"type": msg_type_str, "id": request_id}
+                result.update(payload_dict)
+                self.messages.append(result)
             else:
                 break

@@ -11,7 +11,7 @@ import pytest
 from gunicorn.arbiter import Arbiter
 from gunicorn.config import Config
 from gunicorn.app.base import BaseApplication
-from gunicorn.dirty.protocol import DirtyProtocol
+from gunicorn.dirty.protocol import DirtyProtocol, BinaryProtocol, HEADER_SIZE


 class MockStreamWriter:
@@ -26,16 +26,22 @@ class MockStreamWriter:
         self._buffer += data

     async def drain(self):
-        while len(self._buffer) >= DirtyProtocol.HEADER_SIZE:
-            length = struct.unpack(
-                DirtyProtocol.HEADER_FORMAT,
-                self._buffer[:DirtyProtocol.HEADER_SIZE]
-            )[0]
-            total_size = DirtyProtocol.HEADER_SIZE + length
+        # Decode the buffer to extract messages using binary protocol
+        while len(self._buffer) >= HEADER_SIZE:
+            # Decode header to get payload length
+            _, _, length = BinaryProtocol.decode_header(
+                self._buffer[:HEADER_SIZE]
+            )
+            total_size = HEADER_SIZE + length
             if len(self._buffer) >= total_size:
-                msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size]
+                msg_data = self._buffer[:total_size]
                 self._buffer = self._buffer[total_size:]
-                self.messages.append(DirtyProtocol.decode(msg_data))
+                # decode_message returns (msg_type_str, request_id, payload_dict)
+                msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data)
+                # Reconstruct the dict format for backwards compatibility
+                result = {"type": msg_type_str, "id": request_id}
+                result.update(payload_dict)
+                self.messages.append(result)
             else:
                 break

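The `drain()` loop in the two mock writers above splits an accumulating byte buffer into complete frames: read the fixed-size header, take the payload length from it, and consume the frame only once all of its bytes have arrived. A minimal standalone sketch of that loop follows. It does not import gunicorn; the `">2sBBIQ"` struct format is an assumption derived from the documented 16-byte header (magic, version, message type, payload length, request ID), and `make_frame` and `split_frames` are hypothetical helper names.

```python
import struct

# Assumed layout (not copied from gunicorn): 2-byte magic, 1-byte version,
# 1-byte message type, 4-byte big-endian payload length, 8-byte request ID.
HEADER_FORMAT = ">2sBBIQ"
HEADER_SIZE = struct.calcsize(HEADER_FORMAT)  # 16 bytes


def make_frame(msg_type, request_id, payload):
    """Hypothetical helper: prepend a header to an already-encoded payload."""
    header = struct.pack(HEADER_FORMAT, b"GD", 0x01, msg_type,
                         len(payload), request_id)
    return header + payload


def split_frames(buffer):
    """Return (complete_frames, remaining_bytes) for a byte buffer."""
    frames = []
    while len(buffer) >= HEADER_SIZE:
        # Only the payload length is needed to find the frame boundary.
        _magic, _version, _mtype, length, _rid = struct.unpack(
            HEADER_FORMAT, buffer[:HEADER_SIZE]
        )
        total_size = HEADER_SIZE + length
        if len(buffer) < total_size:
            break  # frame is incomplete; wait for more data
        frames.append(buffer[:total_size])
        buffer = buffer[total_size:]
    return frames, buffer
```

As in the updated mocks, each extracted frame keeps its header, so a decoder can still read the message type and request ID from it.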
@@ -2,7 +2,7 @@
 # This file is part of gunicorn released under the MIT license.
 # See the NOTICE for more information.

-"""Tests for dirty arbiter protocol module."""
+"""Tests for dirty worker binary protocol module."""

 import asyncio
 import os
@@ -11,12 +11,23 @@ import struct
 import pytest

 from gunicorn.dirty.protocol import (
+    BinaryProtocol,
     DirtyProtocol,
     make_request,
     make_response,
     make_error_response,
     make_chunk_message,
     make_end_message,
+    MAGIC,
+    VERSION,
+    HEADER_SIZE,
+    HEADER_FORMAT,
+    MSG_TYPE_REQUEST,
+    MSG_TYPE_RESPONSE,
+    MSG_TYPE_ERROR,
+    MSG_TYPE_CHUNK,
+    MSG_TYPE_END,
+    MAX_MESSAGE_SIZE,
 )
 from gunicorn.dirty.errors import (
     DirtyError,
@@ -26,118 +37,194 @@ from gunicorn.dirty.errors import (
 )


-class TestDirtyProtocolEncodeDecode:
-    """Tests for encode/decode functionality."""
+class TestBinaryProtocolHeader:
+    """Tests for header encoding/decoding."""

-    def test_encode_decode_roundtrip(self):
-        """Test basic encode/decode roundtrip."""
-        message = {"type": "request", "id": "123", "data": "hello"}
-        encoded = DirtyProtocol.encode(message)
+    def test_header_size(self):
+        """Test header size is 16 bytes."""
+        assert HEADER_SIZE == 16

-        # Check header format
-        assert len(encoded) > DirtyProtocol.HEADER_SIZE
-        length = struct.unpack(
-            DirtyProtocol.HEADER_FORMAT,
-            encoded[:DirtyProtocol.HEADER_SIZE]
-        )[0]
-        assert length == len(encoded) - DirtyProtocol.HEADER_SIZE
+    def test_encode_header(self):
+        """Test header encoding."""
+        header = BinaryProtocol.encode_header(MSG_TYPE_REQUEST, 12345, 100)
+        assert len(header) == HEADER_SIZE
+        assert header[:2] == MAGIC
+        assert header[2] == VERSION
+        assert header[3] == MSG_TYPE_REQUEST

-        # Decode payload
-        payload = encoded[DirtyProtocol.HEADER_SIZE:]
-        decoded = DirtyProtocol.decode(payload)
-        assert decoded == message
+    def test_decode_header(self):
+        """Test header decoding."""
+        header = BinaryProtocol.encode_header(MSG_TYPE_RESPONSE, 67890, 200)
+        msg_type, request_id, length = BinaryProtocol.decode_header(header)
+        assert msg_type == MSG_TYPE_RESPONSE
+        assert request_id == 67890
+        assert length == 200

-    def test_encode_decode_complex_data(self):
-        """Test with complex nested data structures."""
-        message = {
-            "type": "response",
-            "id": "456",
-            "result": {
-                "models": ["gpt-4", "claude-3"],
-                "config": {"temperature": 0.7, "max_tokens": 1000},
-                "metadata": None,
-            },
-            "args": [1, 2, 3],
-        }
-        encoded = DirtyProtocol.encode(message)
-        payload = encoded[DirtyProtocol.HEADER_SIZE:]
-        decoded = DirtyProtocol.decode(payload)
-        assert decoded == message
+    def test_decode_header_invalid_magic(self):
+        """Test header decoding with invalid magic."""
+        header = b"XX" + b"\x01\x01" + b"\x00" * 12
+        with pytest.raises(DirtyProtocolError) as exc_info:
+            BinaryProtocol.decode_header(header)
+        assert "magic" in str(exc_info.value).lower()

-    def test_encode_decode_unicode(self):
-        """Test with unicode characters."""
-        message = {
-            "type": "request",
-            "data": "Hello, world!"
-        }
-        encoded = DirtyProtocol.encode(message)
-        payload = encoded[DirtyProtocol.HEADER_SIZE:]
-        decoded = DirtyProtocol.decode(payload)
-        assert decoded == message
+    def test_decode_header_invalid_version(self):
+        """Test header decoding with invalid version."""
+        header = MAGIC + b"\x99\x01" + b"\x00" * 12
+        with pytest.raises(DirtyProtocolError) as exc_info:
+            BinaryProtocol.decode_header(header)
+        assert "version" in str(exc_info.value).lower()

-    def test_encode_large_message(self):
+    def test_decode_header_invalid_type(self):
+        """Test header decoding with invalid message type."""
+        header = MAGIC + bytes([VERSION, 0xFF]) + b"\x00" * 12
+        with pytest.raises(DirtyProtocolError) as exc_info:
+            BinaryProtocol.decode_header(header)
+        assert "type" in str(exc_info.value).lower()
+
+    def test_decode_header_too_large(self):
+        """Test header decoding rejects too-large messages."""
+        header = struct.pack(HEADER_FORMAT, MAGIC, VERSION, MSG_TYPE_REQUEST,
+                             MAX_MESSAGE_SIZE + 1, 0)
+        with pytest.raises(DirtyProtocolError) as exc_info:
+            BinaryProtocol.decode_header(header)
+        assert "too large" in str(exc_info.value).lower()
+
+    def test_decode_header_too_short(self):
+        """Test header decoding with too-short data."""
+        header = MAGIC + b"\x01"
+        with pytest.raises(DirtyProtocolError) as exc_info:
+            BinaryProtocol.decode_header(header)
+        assert "short" in str(exc_info.value).lower()
+
+
+class TestBinaryProtocolEncodeDecode:
+    """Tests for message encoding/decoding."""
+
+    def test_encode_decode_request(self):
+        """Test request encoding/decoding roundtrip."""
+        encoded = BinaryProtocol.encode_request(
+            request_id=12345,
+            app_path="myapp.ml:MLApp",
+            action="predict",
+            args=("data",),
+            kwargs={"temperature": 0.7}
+        )
+        assert len(encoded) > HEADER_SIZE
+
+        msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
+        assert msg_type_str == "request"
+        assert request_id == 12345
+        assert payload["app_path"] == "myapp.ml:MLApp"
+        assert payload["action"] == "predict"
+        assert payload["args"] == ["data"]
+        assert payload["kwargs"] == {"temperature": 0.7}
+
+    def test_encode_decode_response(self):
+        """Test response encoding/decoding roundtrip."""
+        result = {"predictions": [0.1, 0.9], "metadata": {"model": "v1"}}
+        encoded = BinaryProtocol.encode_response(request_id=67890, result=result)
+
+        msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
+        assert msg_type_str == "response"
+        assert request_id == 67890
+        assert payload["result"] == result
+
+    def test_encode_decode_error(self):
+        """Test error encoding/decoding roundtrip."""
+        error = DirtyTimeoutError("Timed out", timeout=30)
+        encoded = BinaryProtocol.encode_error(request_id=11111, error=error)
+
+        msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
+        assert msg_type_str == "error"
+        assert request_id == 11111
+        assert payload["error"]["error_type"] == "DirtyTimeoutError"
+        assert "Timed out" in payload["error"]["message"]
+
+    def test_encode_decode_chunk(self):
+        """Test chunk encoding/decoding roundtrip."""
+        chunk_data = {"token": "hello", "index": 5}
+        encoded = BinaryProtocol.encode_chunk(request_id=22222, data=chunk_data)
+
+        msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
+        assert msg_type_str == "chunk"
+        assert request_id == 22222
+        assert payload["data"] == chunk_data
+
+    def test_encode_decode_end(self):
+        """Test end message encoding/decoding roundtrip."""
+        encoded = BinaryProtocol.encode_end(request_id=33333)
+        assert len(encoded) == HEADER_SIZE  # End has no payload
+
+        msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
+        assert msg_type_str == "end"
+        assert request_id == 33333
+        assert payload == {}
+
+    def test_encode_decode_binary_data(self):
+        """Test binary data passes through without base64 encoding."""
+        binary_data = bytes(range(256))
+        encoded = BinaryProtocol.encode_response(
+            request_id=44444,
+            result={"data": binary_data}
+        )
+
+        msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
+        assert payload["result"]["data"] == binary_data
+
+    def test_encode_decode_large_message(self):
         """Test encoding a large message."""
-        large_data = "x" * (1024 * 1024)  # 1 MB of data
-        message = {"type": "request", "data": large_data}
-        encoded = DirtyProtocol.encode(message)
-        payload = encoded[DirtyProtocol.HEADER_SIZE:]
-        decoded = DirtyProtocol.decode(payload)
-        assert decoded == message
+        large_data = b"x" * (1024 * 1024)  # 1 MB
+        encoded = BinaryProtocol.encode_response(
+            request_id=55555,
+            result={"data": large_data}
+        )

-    def test_encode_empty_dict(self):
-        """Test encoding an empty dictionary."""
-        message = {}
-        encoded = DirtyProtocol.encode(message)
-        payload = encoded[DirtyProtocol.HEADER_SIZE:]
-        decoded = DirtyProtocol.decode(payload)
-        assert decoded == message
-
-    def test_encode_message_too_large(self):
-        """Test that encoding a message that's too large raises error."""
-        large_data = "x" * (DirtyProtocol.MAX_MESSAGE_SIZE + 1000)
-        message = {"data": large_data}
-        with pytest.raises(DirtyProtocolError) as exc_info:
-            DirtyProtocol.encode(message)
-        assert "too large" in str(exc_info.value)
-
-    def test_encode_non_serializable(self):
-        """Test that encoding non-JSON-serializable data raises error."""
-        message = {"func": lambda x: x}
-        with pytest.raises(DirtyProtocolError) as exc_info:
-            DirtyProtocol.encode(message)
-        assert "Failed to encode" in str(exc_info.value)
-
-    def test_decode_invalid_json(self):
-        """Test decoding invalid JSON raises error."""
-        invalid_data = b"not valid json"
-        with pytest.raises(DirtyProtocolError) as exc_info:
-            DirtyProtocol.decode(invalid_data)
-        assert "Failed to decode" in str(exc_info.value)
-
-    def test_decode_invalid_unicode(self):
-        """Test decoding invalid unicode raises error."""
-        invalid_data = b"\x80\x81\x82"
-        with pytest.raises(DirtyProtocolError) as exc_info:
-            DirtyProtocol.decode(invalid_data)
-        assert "Failed to decode" in str(exc_info.value)
+        msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded)
+        assert payload["result"]["data"] == large_data


-class TestDirtyProtocolSync:
+class TestBinaryProtocolSync:
     """Tests for synchronous socket operations."""

     def test_read_write_message(self):
         """Test read/write through socket pair."""
-        # Create a socket pair for testing
         server_sock, client_sock = socket.socketpair()
         try:
-            message = {"type": "request", "id": "123", "action": "test"}
+            message = make_request(
+                request_id=12345,
+                app_path="test:App",
+                action="run"
+            )

-            # Write message
-            DirtyProtocol.write_message(client_sock, message)
+            BinaryProtocol.write_message(client_sock, message)
+            received = BinaryProtocol.read_message(server_sock)

-            # Read message
-            received = DirtyProtocol.read_message(server_sock)
-            assert received == message
+            assert received["type"] == "request"
+            assert received["id"] == hash("12345") & 0xFFFFFFFFFFFFFFFF or \
+                received["id"] == 12345
+            assert received["app_path"] == "test:App"
+            assert received["action"] == "run"
+        finally:
+            server_sock.close()
+            client_sock.close()
+
+    def test_read_write_with_int_id(self):
+        """Test read/write with integer request ID."""
+        server_sock, client_sock = socket.socketpair()
+        try:
+            message = {
+                "type": "request",
+                "id": 999888777,
+                "app_path": "test:App",
+                "action": "run",
+                "args": [],
+                "kwargs": {}
+            }
+
+            BinaryProtocol.write_message(client_sock, message)
+            received = BinaryProtocol.read_message(server_sock)
+
+            assert received["id"] == 999888777
         finally:
             server_sock.close()
             client_sock.close()
@@ -147,19 +234,17 @@ class TestDirtyProtocolSync:
         server_sock, client_sock = socket.socketpair()
         try:
             messages = [
-                {"type": "request", "id": "1"},
-                {"type": "request", "id": "2"},
-                {"type": "request", "id": "3"},
+                make_request(i, f"app{i}:App", f"action{i}")
+                for i in range(1, 4)
             ]

-            # Write all messages
             for msg in messages:
-                DirtyProtocol.write_message(client_sock, msg)
+                BinaryProtocol.write_message(client_sock, msg)

-            # Read all messages
-            for expected in messages:
-                received = DirtyProtocol.read_message(server_sock)
-                assert received == expected
+            for i, _ in enumerate(messages, 1):
+                received = BinaryProtocol.read_message(server_sock)
+                assert received["app_path"] == f"app{i}:App"
+                assert received["action"] == f"action{i}"
         finally:
             server_sock.close()
             client_sock.close()
@@ -169,39 +254,51 @@ class TestDirtyProtocolSync:
         server_sock, client_sock = socket.socketpair()
         client_sock.close()
         with pytest.raises(DirtyProtocolError) as exc_info:
-            DirtyProtocol.read_message(server_sock)
+            BinaryProtocol.read_message(server_sock)
         assert "closed" in str(exc_info.value).lower()
         server_sock.close()

+    def test_binary_data_roundtrip(self):
+        """Test binary data roundtrip through socket."""
+        server_sock, client_sock = socket.socketpair()
+        try:
+            binary_payload = b"\x00\x01\x02\xff\xfe\xfd"
+            message = make_response(12345, {"binary": binary_payload})

-class TestDirtyProtocolAsync:
+            BinaryProtocol.write_message(client_sock, message)
+            received = BinaryProtocol.read_message(server_sock)
+
+            assert received["result"]["binary"] == binary_payload
+        finally:
+            server_sock.close()
+            client_sock.close()
+
+
+class TestBinaryProtocolAsync:
     """Tests for async stream operations."""

     @pytest.mark.asyncio
     async def test_async_read_write(self):
         """Test async read/write with mock streams."""
-        message = {"type": "request", "id": "123"}
+        message = make_request(12345, "test:App", "run")

-        # Create a pipe for testing
         read_fd, write_fd = os.pipe()
         try:
             reader = asyncio.StreamReader()
             _ = asyncio.StreamReaderProtocol(reader)

-            # Write the message to the pipe
-            encoded = DirtyProtocol.encode(message)
+            encoded = BinaryProtocol._encode_from_dict(message)
             os.write(write_fd, encoded)
             os.close(write_fd)
             write_fd = None

-            # Feed data to reader
             data = os.read(read_fd, len(encoded))
             reader.feed_data(data)
             reader.feed_eof()

-            # Read the message
-            received = await DirtyProtocol.read_message_async(reader)
-            assert received == message
+            received = await BinaryProtocol.read_message_async(reader)
+            assert received["type"] == "request"
+            assert received["app_path"] == "test:App"
         finally:
             if write_fd is not None:
                 os.close(write_fd)
@@ -211,12 +308,11 @@ class TestDirtyProtocolAsync:
     async def test_async_read_incomplete_header(self):
         """Test async read with incomplete header."""
         reader = asyncio.StreamReader()
-        # Feed only 2 bytes instead of 4
-        reader.feed_data(b"\x00\x00")
+        reader.feed_data(MAGIC + b"\x01")  # Only 3 bytes
         reader.feed_eof()

         with pytest.raises((asyncio.IncompleteReadError, DirtyProtocolError)):
-            await DirtyProtocol.read_message_async(reader)
+            await BinaryProtocol.read_message_async(reader)

     @pytest.mark.asyncio
     async def test_async_read_empty_connection(self):
@@ -225,36 +321,33 @@ class TestDirtyProtocolAsync:
         reader.feed_eof()

         with pytest.raises(asyncio.IncompleteReadError):
-            await DirtyProtocol.read_message_async(reader)
+            await BinaryProtocol.read_message_async(reader)

     @pytest.mark.asyncio
+    async def test_async_read_invalid_magic(self):
+        """Test async read rejects invalid magic."""
+        reader = asyncio.StreamReader()
+        header = b"XX" + bytes([VERSION, MSG_TYPE_REQUEST]) + b"\x00" * 12
+        reader.feed_data(header)
+        reader.feed_eof()
+
+        with pytest.raises(DirtyProtocolError) as exc_info:
+            await BinaryProtocol.read_message_async(reader)
+        assert "magic" in str(exc_info.value).lower()
+
+    @pytest.mark.asyncio
     async def test_async_read_message_too_large(self):
         """Test async read rejects too-large messages."""
         reader = asyncio.StreamReader()
-        # Create a header claiming an absurdly large message
-        header = struct.pack(
-            DirtyProtocol.HEADER_FORMAT,
-            DirtyProtocol.MAX_MESSAGE_SIZE + 1000
-        )
+        header = struct.pack(HEADER_FORMAT, MAGIC, VERSION, MSG_TYPE_REQUEST,
+                             MAX_MESSAGE_SIZE + 1000, 0)
         reader.feed_data(header)
         reader.feed_eof()

         with pytest.raises(DirtyProtocolError) as exc_info:
-            await DirtyProtocol.read_message_async(reader)
+            await BinaryProtocol.read_message_async(reader)
         assert "too large" in str(exc_info.value)

-    @pytest.mark.asyncio
-    async def test_async_read_empty_message(self):
-        """Test async read rejects empty messages."""
-        reader = asyncio.StreamReader()
-        header = struct.pack(DirtyProtocol.HEADER_FORMAT, 0)
-        reader.feed_data(header)
-        reader.feed_eof()
-
-        with pytest.raises(DirtyProtocolError) as exc_info:
-            await DirtyProtocol.read_message_async(reader)
-        assert "Empty message" in str(exc_info.value)
-

 class TestMessageBuilders:
     """Tests for message builder helper functions."""
@@ -340,9 +433,9 @@ class TestMessageBuilders:
         assert chunk["id"] == "req-456"
         assert chunk["data"] == data

-    def test_make_chunk_message_with_list_data(self):
-        """Test chunk message with list data."""
-        data = [1, 2, 3, "token"]
+    def test_make_chunk_message_with_binary_data(self):
+        """Test chunk message with binary data."""
+        data = b"\x00\x01\x02\xff"
         chunk = make_chunk_message("req-789", data)
         assert chunk["data"] == data

@@ -353,22 +446,22 @@ class TestMessageBuilders:
         assert end["id"] == "req-123"
         assert "data" not in end

-    def test_chunk_and_end_encode_decode(self):
+    def test_chunk_and_end_roundtrip(self):
         """Test that chunk and end messages can be encoded/decoded."""
-        chunk = make_chunk_message("req-123", {"token": "hello"})
-        end = make_end_message("req-123")
+        chunk = make_chunk_message(12345, {"token": "hello"})
+        end = make_end_message(12345)

         # Test chunk roundtrip
-        encoded_chunk = DirtyProtocol.encode(chunk)
-        payload = encoded_chunk[DirtyProtocol.HEADER_SIZE:]
-        decoded = DirtyProtocol.decode(payload)
-        assert decoded == chunk
+        encoded_chunk = BinaryProtocol._encode_from_dict(chunk)
+        msg_type, req_id, payload = BinaryProtocol.decode_message(encoded_chunk)
+        assert msg_type == "chunk"
+        assert payload["data"] == {"token": "hello"}

         # Test end roundtrip
-        encoded_end = DirtyProtocol.encode(end)
-        payload = encoded_end[DirtyProtocol.HEADER_SIZE:]
-        decoded = DirtyProtocol.decode(payload)
-        assert decoded == end
+        encoded_end = BinaryProtocol._encode_from_dict(end)
+        msg_type, req_id, payload = BinaryProtocol.decode_message(encoded_end)
+        assert msg_type == "end"
+        assert payload == {}


 class TestDirtyErrors:
@@ -417,3 +510,58 @@ class TestDirtyErrors:
         assert error.action == "run"
         assert error.traceback == "Traceback..."
         assert "myapp:App" in str(error)
+
+
+class TestBackwardsCompatibility:
+    """Tests for backwards compatibility with old JSON API."""
+
+    def test_dirty_protocol_alias(self):
+        """Test that DirtyProtocol is an alias for BinaryProtocol."""
+        assert DirtyProtocol is BinaryProtocol
+
+    def test_header_size_attribute(self):
+        """Test HEADER_SIZE is accessible on class."""
+        assert DirtyProtocol.HEADER_SIZE == 16
+
+    def test_msg_type_constants(self):
+        """Test message type constants are strings for compatibility."""
+        assert DirtyProtocol.MSG_TYPE_REQUEST == "request"
+        assert DirtyProtocol.MSG_TYPE_RESPONSE == "response"
+        assert DirtyProtocol.MSG_TYPE_ERROR == "error"
+        assert DirtyProtocol.MSG_TYPE_CHUNK == "chunk"
+        assert DirtyProtocol.MSG_TYPE_END == "end"
+
+    def test_encode_decode_preserves_dict_format(self):
+        """Test that read_message returns dict compatible with old API."""
+        server_sock, client_sock = socket.socketpair()
+        try:
+            message = {
+                "type": "response",
+                "id": 12345,
+                "result": {"status": "ok"}
+            }
+
+            DirtyProtocol.write_message(client_sock, message)
+            received = DirtyProtocol.read_message(server_sock)
+
+            # Old API: access via dict keys
+            assert received["type"] == "response"
+            assert received["result"]["status"] == "ok"
+        finally:
+            server_sock.close()
+            client_sock.close()
+
+    def test_string_request_id_handled(self):
+        """Test that string request IDs are handled (hashed to int)."""
+        server_sock, client_sock = socket.socketpair()
+        try:
+            message = make_request("uuid-string-id", "test:App", "run")
+
+            DirtyProtocol.write_message(client_sock, message)
+            received = DirtyProtocol.read_message(server_sock)
+
+            # Request ID should be converted to int
+            assert isinstance(received["id"], int)
+        finally:
+            server_sock.close()
+            client_sock.close()
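The header tests above pin down the observable behaviour of `encode_header` and `decode_header`: a 16-byte header, a `(msg_type, request_id, length)` return tuple, and rejection of short input, bad magic, bad version, unknown message type, and oversized payloads. The sketch below is one way such a codec could look, not the code in `gunicorn.dirty.protocol`. The struct format, the order of the checks, the error messages, and the `ProtocolError` class (a stand-in for `DirtyProtocolError`) are all assumptions; the field values come from the documented header format.

```python
import struct

# Assumed layout: magic (2s), version (B), type (B), length (I), request ID (Q).
HEADER_FORMAT = ">2sBBIQ"
HEADER_SIZE = struct.calcsize(HEADER_FORMAT)  # 16 bytes
MAGIC = b"GD"
VERSION = 0x01
MAX_MESSAGE_SIZE = 64 * 1024 * 1024  # documented 64 MB payload limit
MSG_TYPES = {0x01: "request", 0x02: "response", 0x03: "error",
             0x04: "chunk", 0x05: "end"}


class ProtocolError(Exception):
    """Hypothetical stand-in for DirtyProtocolError."""


def encode_header(msg_type, request_id, length):
    """Pack the fixed-size header that precedes every payload."""
    return struct.pack(HEADER_FORMAT, MAGIC, VERSION, msg_type,
                       length, request_id)


def decode_header(data):
    """Validate a header and return (msg_type, request_id, length)."""
    if len(data) < HEADER_SIZE:
        raise ProtocolError("Header too short")
    magic, version, msg_type, length, request_id = struct.unpack(
        HEADER_FORMAT, data[:HEADER_SIZE]
    )
    if magic != MAGIC:
        raise ProtocolError("Invalid magic")
    if version != VERSION:
        raise ProtocolError("Unsupported version")
    if msg_type not in MSG_TYPES:
        raise ProtocolError("Unknown message type")
    if length > MAX_MESSAGE_SIZE:
        raise ProtocolError("Message too large")
    return msg_type, request_id, length
```

Validating the length before reading the payload means a corrupt or hostile header cannot make the reader allocate an arbitrarily large buffer.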
554
tests/test_dirty_tlv.py
Normal file
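The tests in this file describe the TLV encoding from the outside: `None` is a single type byte, a bool is a type byte plus one value byte, ints and floats are a type byte plus eight bytes, tuples decode as lists, and `decode(data, offset)` returns the value together with the next offset. The sketch below reproduces those properties for the scalar, bytes, string, and list types only. It is not `gunicorn.dirty.tlv`: the numeric type codes, the 4-byte big-endian length and count prefixes, and the function names are assumptions chosen for illustration, and size limits and dict support are left out.

```python
import struct

# Hypothetical type codes; the real TYPE_* values are not shown in the diff.
T_NONE, T_BOOL, T_INT64, T_FLOAT64, T_BYTES, T_STRING, T_LIST = range(7)


def encode(value):
    """Encode a value as a type byte followed by its body."""
    if value is None:
        return bytes([T_NONE])
    if isinstance(value, bool):  # must precede int: bool is an int subclass
        return bytes([T_BOOL, 1 if value else 0])
    if isinstance(value, int):
        return bytes([T_INT64]) + struct.pack(">q", value)
    if isinstance(value, float):
        return bytes([T_FLOAT64]) + struct.pack(">d", value)
    if isinstance(value, bytes):
        return bytes([T_BYTES]) + struct.pack(">I", len(value)) + value
    if isinstance(value, str):
        raw = value.encode("utf-8")
        return bytes([T_STRING]) + struct.pack(">I", len(raw)) + raw
    if isinstance(value, (list, tuple)):  # tuples are written as lists
        body = b"".join(encode(item) for item in value)
        return bytes([T_LIST]) + struct.pack(">I", len(value)) + body
    raise TypeError(f"unsupported type: {type(value).__name__}")


def decode(data, offset=0):
    """Decode one value starting at offset; return (value, next_offset)."""
    tag = data[offset]
    offset += 1
    if tag == T_NONE:
        return None, offset
    if tag == T_BOOL:
        return data[offset] != 0, offset + 1
    if tag == T_INT64:
        return struct.unpack_from(">q", data, offset)[0], offset + 8
    if tag == T_FLOAT64:
        return struct.unpack_from(">d", data, offset)[0], offset + 8
    if tag in (T_BYTES, T_STRING):
        (size,) = struct.unpack_from(">I", data, offset)
        offset += 4
        raw = bytes(data[offset:offset + size])
        return (raw if tag == T_BYTES else raw.decode("utf-8")), offset + size
    if tag == T_LIST:
        (count,) = struct.unpack_from(">I", data, offset)
        offset += 4
        items = []
        for _ in range(count):
            item, offset = decode(data, offset)
            items.append(item)
        return items, offset
    raise ValueError(f"unknown type tag: {tag}")
```

Returning the next offset from `decode` is what lets containers be parsed recursively without copying the buffer.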
@ -0,0 +1,554 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""Tests for dirty TLV binary encoder/decoder."""
|
||||
|
||||
import math
|
||||
import struct
|
||||
import pytest
|
||||
|
||||
from gunicorn.dirty.tlv import (
|
||||
TLVEncoder,
|
||||
TYPE_NONE,
|
||||
TYPE_BOOL,
|
||||
TYPE_INT64,
|
||||
TYPE_FLOAT64,
|
||||
TYPE_BYTES,
|
||||
TYPE_STRING,
|
||||
TYPE_LIST,
|
||||
TYPE_DICT,
|
||||
MAX_STRING_SIZE,
|
||||
MAX_BYTES_SIZE,
|
||||
MAX_LIST_SIZE,
|
||||
MAX_DICT_SIZE,
|
||||
)
|
||||
from gunicorn.dirty.errors import DirtyProtocolError
|
||||
|
||||
|
||||
class TestTLVEncoderBasicTypes:
|
||||
"""Tests for basic type encoding/decoding."""
|
||||
|
||||
def test_encode_decode_none(self):
|
||||
"""Test None encoding/decoding."""
|
||||
encoded = TLVEncoder.encode(None)
|
||||
assert encoded == bytes([TYPE_NONE])
|
||||
|
||||
value, offset = TLVEncoder.decode(encoded, 0)
|
||||
assert value is None
|
||||
assert offset == 1
|
||||
|
||||
def test_encode_decode_true(self):
|
||||
"""Test True encoding/decoding."""
|
||||
encoded = TLVEncoder.encode(True)
|
||||
assert encoded == bytes([TYPE_BOOL, 0x01])
|
||||
|
||||
value, offset = TLVEncoder.decode(encoded, 0)
|
||||
assert value is True
|
||||
assert offset == 2
|
||||
|
||||
def test_encode_decode_false(self):
|
||||
"""Test False encoding/decoding."""
|
||||
encoded = TLVEncoder.encode(False)
|
||||
assert encoded == bytes([TYPE_BOOL, 0x00])
|
||||
|
||||
value, offset = TLVEncoder.decode(encoded, 0)
|
||||
assert value is False
|
||||
assert offset == 2
|
||||
|
||||
def test_encode_decode_positive_int(self):
|
||||
"""Test positive integer encoding/decoding."""
|
||||
encoded = TLVEncoder.encode(42)
|
||||
assert encoded[0] == TYPE_INT64
|
||||
assert len(encoded) == 9 # 1 type + 8 value
|
||||
|
||||
value, offset = TLVEncoder.decode(encoded, 0)
|
||||
assert value == 42
|
||||
assert offset == 9
|
||||
|
||||
def test_encode_decode_negative_int(self):
|
||||
"""Test negative integer encoding/decoding."""
|
||||
encoded = TLVEncoder.encode(-12345)
|
||||
|
||||
value, offset = TLVEncoder.decode(encoded, 0)
|
||||
assert value == -12345
|
||||
|
||||
def test_encode_decode_large_int(self):
|
||||
"""Test large integer encoding/decoding."""
|
||||
large_val = 2**62
|
||||
encoded = TLVEncoder.encode(large_val)
|
||||
|
||||
value, offset = TLVEncoder.decode(encoded, 0)
|
||||
assert value == large_val
|
||||
|
||||
def test_encode_decode_zero(self):
|
||||
"""Test zero encoding/decoding."""
|
||||
encoded = TLVEncoder.encode(0)
|
||||
|
||||
value, offset = TLVEncoder.decode(encoded, 0)
|
||||
assert value == 0
|
||||
|
||||
def test_encode_decode_float(self):
|
||||
"""Test float encoding/decoding."""
|
||||
encoded = TLVEncoder.encode(3.14159)
|
||||
assert encoded[0] == TYPE_FLOAT64
|
||||
assert len(encoded) == 9 # 1 type + 8 value
|
||||
|
||||
value, offset = TLVEncoder.decode(encoded, 0)
|
||||
assert abs(value - 3.14159) < 1e-10
|
||||
|
||||
def test_encode_decode_negative_float(self):
|
||||
"""Test negative float encoding/decoding."""
|
||||
encoded = TLVEncoder.encode(-273.15)
|
||||
|
||||
value, offset = TLVEncoder.decode(encoded, 0)
|
||||
assert abs(value - (-273.15)) < 1e-10
|
||||
|
||||
def test_encode_decode_float_infinity(self):
|
||||
"""Test infinity encoding/decoding."""
|
||||
encoded = TLVEncoder.encode(float('inf'))
|
||||
|
||||
value, offset = TLVEncoder.decode(encoded, 0)
|
||||
assert value == float('inf')
|
||||
|
||||
def test_encode_decode_float_nan(self):
|
||||
"""Test NaN encoding/decoding."""
|
||||
encoded = TLVEncoder.encode(float('nan'))
|
||||
|
||||
value, offset = TLVEncoder.decode(encoded, 0)
|
||||
assert math.isnan(value)
|
||||
|
||||
|
||||
class TestTLVEncoderBytes:
|
||||
"""Tests for bytes encoding/decoding."""
|
||||
|
||||
def test_encode_decode_empty_bytes(self):
|
||||
"""Test empty bytes encoding/decoding."""
|
||||
encoded = TLVEncoder.encode(b"")
|
||||
assert encoded[0] == TYPE_BYTES
|
||||
|
||||
value, offset = TLVEncoder.decode(encoded, 0)
|
||||
assert value == b""
|
||||
|
||||
def test_encode_decode_bytes(self):
|
||||
"""Test bytes encoding/decoding."""
|
||||
data = b"\x00\x01\x02\xff\xfe\xfd"
|
||||
encoded = TLVEncoder.encode(data)
|
||||
|
||||
value, offset = TLVEncoder.decode(encoded, 0)
|
||||
assert value == data
|
||||
|
||||
def test_encode_decode_large_bytes(self):
|
||||
"""Test large bytes encoding/decoding."""
|
||||
data = b"x" * 10000
|
||||
encoded = TLVEncoder.encode(data)
|
||||
|
||||
value, offset = TLVEncoder.decode(encoded, 0)
|
||||
assert value == data
|
||||
|
||||
def test_bytes_too_large(self):
|
||||
"""Test that bytes exceeding max size raises error."""
|
||||
# We won't actually allocate MAX_BYTES_SIZE, just check the encoding
|
||||
with pytest.raises(DirtyProtocolError) as exc_info:
|
||||
TLVEncoder.encode(b"x" * (MAX_BYTES_SIZE + 1))
|
||||
assert "too large" in str(exc_info.value).lower()
|
||||
|
||||
|
||||
class TestTLVEncoderString:
    """Tests for string encoding/decoding."""

    def test_encode_decode_empty_string(self):
        """Test empty string encoding/decoding."""
        encoded = TLVEncoder.encode("")
        assert encoded[0] == TYPE_STRING

        value, offset = TLVEncoder.decode(encoded, 0)
        assert value == ""

    def test_encode_decode_ascii_string(self):
        """Test ASCII string encoding/decoding."""
        encoded = TLVEncoder.encode("hello world")

        value, offset = TLVEncoder.decode(encoded, 0)
        assert value == "hello world"

    def test_encode_decode_unicode_string(self):
        """Test Unicode string encoding/decoding."""
        text = "Hello, world! \u00a9 \u2603 \U0001F600"
        encoded = TLVEncoder.encode(text)

        value, offset = TLVEncoder.decode(encoded, 0)
        assert value == text

    def test_encode_decode_chinese(self):
        """Test Chinese characters encoding/decoding."""
        text = "Hello, world!"
        encoded = TLVEncoder.encode(text)

        value, offset = TLVEncoder.decode(encoded, 0)
        assert value == text

    def test_encode_decode_emoji(self):
        """Test emoji encoding/decoding."""
        text = "Test emoji"
        encoded = TLVEncoder.encode(text)

        value, offset = TLVEncoder.decode(encoded, 0)
        assert value == text

    def test_encode_decode_large_string(self):
        """Test large string encoding/decoding."""
        text = "x" * 10000
        encoded = TLVEncoder.encode(text)

        value, offset = TLVEncoder.decode(encoded, 0)
        assert value == text


class TestTLVEncoderList:
    """Tests for list encoding/decoding."""

    def test_encode_decode_empty_list(self):
        """Test empty list encoding/decoding."""
        encoded = TLVEncoder.encode([])
        assert encoded[0] == TYPE_LIST

        value, offset = TLVEncoder.decode(encoded, 0)
        assert value == []

    def test_encode_decode_simple_list(self):
        """Test simple list encoding/decoding."""
        data = [1, 2, 3]
        encoded = TLVEncoder.encode(data)

        value, offset = TLVEncoder.decode(encoded, 0)
        assert value == data

    def test_encode_decode_mixed_list(self):
        """Test mixed type list encoding/decoding."""
        data = [1, "hello", 3.14, True, None, b"bytes"]
        encoded = TLVEncoder.encode(data)

        value, offset = TLVEncoder.decode(encoded, 0)
        assert value == data

    def test_encode_decode_nested_list(self):
        """Test nested list encoding/decoding."""
        data = [[1, 2], [3, [4, 5]], ["a", "b"]]
        encoded = TLVEncoder.encode(data)

        value, offset = TLVEncoder.decode(encoded, 0)
        assert value == data

    def test_encode_decode_tuple_as_list(self):
        """Test that tuples are encoded as lists."""
        data = (1, 2, 3)
        encoded = TLVEncoder.encode(data)

        value, offset = TLVEncoder.decode(encoded, 0)
        assert value == [1, 2, 3]  # Decoded as list

    def test_encode_decode_large_list(self):
        """Test large list encoding/decoding."""
        data = list(range(1000))
        encoded = TLVEncoder.encode(data)

        value, offset = TLVEncoder.decode(encoded, 0)
        assert value == data


class TestTLVEncoderDict:
    """Tests for dict encoding/decoding."""

    def test_encode_decode_empty_dict(self):
        """Test empty dict encoding/decoding."""
        encoded = TLVEncoder.encode({})
        assert encoded[0] == TYPE_DICT

        value, offset = TLVEncoder.decode(encoded, 0)
        assert value == {}

    def test_encode_decode_simple_dict(self):
        """Test simple dict encoding/decoding."""
        data = {"a": 1, "b": 2, "c": 3}
        encoded = TLVEncoder.encode(data)

        value, offset = TLVEncoder.decode(encoded, 0)
        assert value == data

    def test_encode_decode_mixed_values_dict(self):
        """Test dict with mixed value types."""
        data = {
            "int": 42,
            "float": 3.14,
            "string": "hello",
            "bool": True,
            "none": None,
            "bytes": b"data",
            "list": [1, 2, 3],
        }
        encoded = TLVEncoder.encode(data)

        value, offset = TLVEncoder.decode(encoded, 0)
        assert value == data

    def test_encode_decode_nested_dict(self):
        """Test nested dict encoding/decoding."""
        data = {
            "outer": {
                "inner": {
                    "value": 42
                },
                "list": [{"a": 1}, {"b": 2}]
            }
        }
        encoded = TLVEncoder.encode(data)

        value, offset = TLVEncoder.decode(encoded, 0)
        assert value == data

    def test_encode_dict_non_string_key_converted(self):
        """Test that non-string keys are converted to strings (like JSON)."""
        data = {1: "value", 2: "other"}
        encoded = TLVEncoder.encode(data)
        decoded, _ = TLVEncoder.decode(encoded, 0)
        # Keys should be converted to strings
        assert decoded == {"1": "value", "2": "other"}


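The key conversion checked in `test_encode_dict_non_string_key_converted` is described as JSON-like. For comparison only, the snippet below shows the same behaviour with Python's standard `json` module; it does not use `TLVEncoder` and says nothing about how the encoder implements the conversion.

```python
import json

data = {1: "value", 2: "other"}

# json.dumps coerces the int keys to strings; json.loads cannot restore
# the original key type, so the roundtrip is lossy for non-string keys.
roundtripped = json.loads(json.dumps(data))

print(roundtripped)  # {'1': 'value', '2': 'other'}
```

Callers that rely on integer keys would therefore need to convert them back after decoding.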
class TestTLVEncoderComplexStructures:
    """Tests for complex nested structures."""

    def test_encode_decode_request_like(self):
        """Test encoding/decoding a request-like structure."""
        data = {
            "id": 12345,
            "app_path": "myapp.ml:MLApp",
            "action": "predict",
            "args": [b"input_data", 0.7],
            "kwargs": {"temperature": 0.7, "max_tokens": 1000},
        }
        encoded = TLVEncoder.encode(data)

        value, offset = TLVEncoder.decode(encoded, 0)
        assert value == data

    def test_encode_decode_response_like(self):
        """Test encoding/decoding a response-like structure."""
        data = {
            "id": 12345,
            "result": {
                "predictions": [0.1, 0.2, 0.7],
                "metadata": {"model": "v1.0", "latency_ms": 42},
            }
        }
        encoded = TLVEncoder.encode(data)

        value, offset = TLVEncoder.decode(encoded, 0)
        assert value == data

    def test_encode_decode_deeply_nested(self):
        """Test deeply nested structures."""
        data = {"a": {"b": {"c": {"d": {"e": {"f": "deep"}}}}}}
        encoded = TLVEncoder.encode(data)

        value, offset = TLVEncoder.decode(encoded, 0)
        assert value == data


class TestTLVEncoderRoundtrip:
    """Tests for complete roundtrip using decode_full."""

    def test_decode_full_simple(self):
        """Test decode_full with simple value."""
        data = {"key": "value"}
        encoded = TLVEncoder.encode(data)

        value = TLVEncoder.decode_full(encoded)
        assert value == data

    def test_decode_full_trailing_data(self):
        """Test decode_full raises on trailing data."""
        encoded = TLVEncoder.encode(42) + b"extra"

        with pytest.raises(DirtyProtocolError) as exc_info:
            TLVEncoder.decode_full(encoded)
        assert "trailing" in str(exc_info.value).lower()


class TestTLVEncoderErrors:
    """Tests for error handling."""

    def test_decode_empty_data(self):
        """Test decoding empty data raises error."""
        with pytest.raises(DirtyProtocolError) as exc_info:
            TLVEncoder.decode(b"", 0)
        assert "truncated" in str(exc_info.value).lower()

    def test_decode_truncated_int(self):
        """Test decoding truncated int raises error."""
        # TYPE_INT64 followed by only 4 bytes instead of 8
        data = bytes([TYPE_INT64, 0, 0, 0, 0])
        with pytest.raises(DirtyProtocolError) as exc_info:
            TLVEncoder.decode(data, 0)
        assert "truncated" in str(exc_info.value).lower()

    def test_decode_truncated_float(self):
        """Test decoding truncated float raises error."""
        data = bytes([TYPE_FLOAT64, 0, 0, 0, 0])
        with pytest.raises(DirtyProtocolError) as exc_info:
            TLVEncoder.decode(data, 0)
        assert "truncated" in str(exc_info.value).lower()

    def test_decode_truncated_bytes_length(self):
        """Test decoding truncated bytes length raises error."""
        data = bytes([TYPE_BYTES, 0, 0])  # Only 2 bytes of length
        with pytest.raises(DirtyProtocolError) as exc_info:
            TLVEncoder.decode(data, 0)
        assert "truncated" in str(exc_info.value).lower()

    def test_decode_truncated_bytes_data(self):
        """Test decoding truncated bytes data raises error."""
        # Says 10 bytes but only provides 5
        data = bytes([TYPE_BYTES]) + struct.pack(">I", 10) + b"12345"
        with pytest.raises(DirtyProtocolError) as exc_info:
            TLVEncoder.decode(data, 0)
        assert "truncated" in str(exc_info.value).lower()

    def test_decode_truncated_string_length(self):
        """Test decoding truncated string length raises error."""
        data = bytes([TYPE_STRING, 0])
        with pytest.raises(DirtyProtocolError) as exc_info:
            TLVEncoder.decode(data, 0)
        assert "truncated" in str(exc_info.value).lower()

    def test_decode_truncated_string_data(self):
        """Test decoding truncated string data raises error."""
        data = bytes([TYPE_STRING]) + struct.pack(">I", 10) + b"hello"
        with pytest.raises(DirtyProtocolError) as exc_info:
            TLVEncoder.decode(data, 0)
        assert "truncated" in str(exc_info.value).lower()

    def test_decode_invalid_utf8(self):
        """Test decoding invalid UTF-8 raises error."""
        # Valid length, but invalid UTF-8 bytes
        data = bytes([TYPE_STRING]) + struct.pack(">I", 3) + b"\x80\x81\x82"
        with pytest.raises(DirtyProtocolError) as exc_info:
            TLVEncoder.decode(data, 0)
        assert "utf-8" in str(exc_info.value).lower()

    def test_decode_truncated_list_count(self):
        """Test decoding truncated list count raises error."""
        data = bytes([TYPE_LIST, 0])
        with pytest.raises(DirtyProtocolError) as exc_info:
            TLVEncoder.decode(data, 0)
        assert "truncated" in str(exc_info.value).lower()

    def test_decode_truncated_dict_count(self):
        """Test decoding truncated dict count raises error."""
        data = bytes([TYPE_DICT, 0])
        with pytest.raises(DirtyProtocolError) as exc_info:
            TLVEncoder.decode(data, 0)
        assert "truncated" in str(exc_info.value).lower()

    def test_decode_unknown_type(self):
        """Test decoding unknown type raises error."""
        data = bytes([0xFF])  # Unknown type
        with pytest.raises(DirtyProtocolError) as exc_info:
            TLVEncoder.decode(data, 0)
        assert "unknown" in str(exc_info.value).lower()

    def test_encode_unsupported_type(self):
        """Test encoding unsupported type raises error."""
        with pytest.raises(DirtyProtocolError) as exc_info:
            TLVEncoder.encode(object())
        assert "unsupported type" in str(exc_info.value).lower()

    def test_encode_function_raises_error(self):
        """Test encoding a function raises error."""
        with pytest.raises(DirtyProtocolError) as exc_info:
            TLVEncoder.encode(lambda x: x)
        assert "unsupported type" in str(exc_info.value).lower()

    def test_decode_dict_non_string_key_in_data(self):
        """Test decoding dict with non-string key raises error."""
        # Manually construct a dict with int key
        # TYPE_DICT, count=1, TYPE_INT64 key, TYPE_INT64 value
        data = (
            bytes([TYPE_DICT])
            + struct.pack(">I", 1)
            + bytes([TYPE_INT64])
            + struct.pack(">q", 1)  # Key (int, not string)
            + bytes([TYPE_INT64])
            + struct.pack(">q", 2)  # Value
        )
        with pytest.raises(DirtyProtocolError) as exc_info:
            TLVEncoder.decode(data, 0)
        assert "string" in str(exc_info.value).lower()


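The truncation tests above all exercise the same idea: check that enough bytes remain before reading a length field, and again before reading the value it announces. A minimal sketch of that check for a field with a big-endian uint32 length prefix follows. `ProtocolError` and `read_length_prefixed` are hypothetical names used for illustration, not part of `gunicorn.dirty.protocol`, and the real decoder may be structured differently.

```python
import struct
from typing import Tuple


class ProtocolError(Exception):
    """Hypothetical stand-in for DirtyProtocolError."""


def read_length_prefixed(data: bytes, offset: int) -> Tuple[bytes, int]:
    """Read a big-endian uint32 length, then that many bytes.

    Returns the value and the offset just past it.
    """
    # First bounds check: the 4-byte length field itself must be present.
    if offset + 4 > len(data):
        raise ProtocolError("Truncated length field")
    (length,) = struct.unpack_from(">I", data, offset)
    offset += 4
    # Second bounds check: the announced number of bytes must be present.
    if offset + length > len(data):
        raise ProtocolError("Truncated data: expected %d bytes" % length)
    return data[offset:offset + length], offset + length


value, next_offset = read_length_prefixed(struct.pack(">I", 5) + b"12345", 0)
print(value, next_offset)  # b'12345' 9
```

Checking the length field and the value separately is what lets a decoder report the two failure modes that `test_decode_truncated_bytes_length` and `test_decode_truncated_bytes_data` distinguish.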
class TestTLVEncoderOffset:
    """Tests for offset handling."""

    def test_decode_with_offset(self):
        """Test decoding from specific offset."""
        # Create data with prefix
        prefix = b"garbage"
        encoded = TLVEncoder.encode(42)
        data = prefix + encoded

        value, offset = TLVEncoder.decode(data, len(prefix))
        assert value == 42
        assert offset == len(prefix) + len(encoded)

    def test_decode_multiple_values(self):
        """Test decoding multiple consecutive values."""
        v1 = TLVEncoder.encode("hello")
        v2 = TLVEncoder.encode(42)
        v3 = TLVEncoder.encode([1, 2, 3])
        data = v1 + v2 + v3

        offset = 0
        val1, offset = TLVEncoder.decode(data, offset)
        assert val1 == "hello"

        val2, offset = TLVEncoder.decode(data, offset)
        assert val2 == 42

        val3, offset = TLVEncoder.decode(data, offset)
        assert val3 == [1, 2, 3]

        assert offset == len(data)


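Because `decode` returns the next offset along with the value, consecutive values can be consumed in a loop instead of one call per value. The sketch below shows that pattern with a toy decoder (a one-byte length prefix followed by UTF-8 text). Both `decode_short_string` and `iter_values` are hypothetical helpers written for this illustration; they only mimic the `(value, offset)` return convention seen in the tests.

```python
def decode_short_string(data, offset):
    """Toy decoder: one length byte followed by UTF-8 text."""
    length = data[offset]
    end = offset + 1 + length
    return data[offset + 1:end].decode("utf-8"), end


def iter_values(data, decode):
    """Yield every value in data by chaining the returned offsets."""
    offset = 0
    while offset < len(data):
        value, new_offset = decode(data, offset)
        if new_offset <= offset:
            # Guard against a decoder that fails to advance.
            raise ValueError("decoder did not advance")
        offset = new_offset
        yield value


values = list(iter_values(b"\x05hello\x02hi", decode_short_string))
print(values)  # ['hello', 'hi']
```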
class TestTLVEncoderBinaryData:
    """Tests for binary data handling (the main motivation for this protocol)."""

    def test_binary_data_no_encoding(self):
        """Test that binary data is passed through without encoding."""
        # This is the key advantage over JSON - binary data doesn't need base64
        binary_data = bytes(range(256))  # All byte values
        encoded = TLVEncoder.encode(binary_data)

        value, offset = TLVEncoder.decode(encoded, 0)
        assert value == binary_data

    def test_binary_with_null_bytes(self):
        """Test binary data with embedded null bytes."""
        binary_data = b"\x00\x00\xff\x00\x00"
        encoded = TLVEncoder.encode(binary_data)

        value, offset = TLVEncoder.decode(encoded, 0)
        assert value == binary_data

    def test_binary_in_nested_structure(self):
        """Test binary data inside nested structures."""
        data = {
            "image": b"\x89PNG\r\n\x1a\n" + b"\x00" * 100,
            "metadata": {"width": 640, "height": 480},
            "chunks": [b"chunk1", b"chunk2", b"chunk3"],
        }
        encoded = TLVEncoder.encode(data)

        value, offset = TLVEncoder.decode(encoded, 0)
        assert value == data
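The base64 overhead mentioned in `test_binary_data_no_encoding` can be measured with the standard library alone. This snippet is independent of `TLVEncoder`; it only shows what a JSON-based transport has to do with the same 256-byte input, since JSON has no raw bytes type.

```python
import base64
import json

binary_data = bytes(range(256))  # same input as the test above

# A JSON payload has to carry the bytes as base64 text.
as_base64 = base64.b64encode(binary_data)
json_payload = json.dumps({"data": as_base64.decode("ascii")})

print(len(binary_data))  # 256
print(len(as_base64))    # 344
```

Base64 emits 4 output characters per 3 input bytes, so the payload grows by roughly a third before any JSON framing is added.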
@@ -12,7 +12,13 @@ import pytest

 from gunicorn.config import Config
 from gunicorn.dirty.worker import DirtyWorker
-from gunicorn.dirty.protocol import DirtyProtocol, make_request
+from gunicorn.dirty.protocol import (
+    DirtyProtocol,
+    BinaryProtocol,
+    make_request,
+    HEADER_SIZE,
+    HEADER_FORMAT,
+)
 from gunicorn.dirty.errors import DirtyAppNotFoundError


@@ -56,17 +62,22 @@ class MockStreamWriter:
         self._buffer += data

     async def drain(self):
-        # Decode the buffer to extract messages
-        while len(self._buffer) >= DirtyProtocol.HEADER_SIZE:
-            length = struct.unpack(
-                DirtyProtocol.HEADER_FORMAT,
-                self._buffer[:DirtyProtocol.HEADER_SIZE]
-            )[0]
-            total_size = DirtyProtocol.HEADER_SIZE + length
+        # Decode the buffer to extract messages using binary protocol
+        while len(self._buffer) >= HEADER_SIZE:
+            # Decode header to get payload length
+            _, _, length = BinaryProtocol.decode_header(
+                self._buffer[:HEADER_SIZE]
+            )
+            total_size = HEADER_SIZE + length
             if len(self._buffer) >= total_size:
-                msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size]
+                msg_data = self._buffer[:total_size]
                 self._buffer = self._buffer[total_size:]
-                self.messages.append(DirtyProtocol.decode(msg_data))
+                # decode_message returns (msg_type_str, request_id, payload_dict)
+                msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data)
+                # Reconstruct the dict format for backwards compatibility
+                result = {"type": msg_type_str, "id": request_id}
+                result.update(payload_dict)
+                self.messages.append(result)
             else:
                 break

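The `drain()` loop above splits a byte buffer into complete messages by reading a fixed-size header first. As a rough sketch of the same framing logic using only `struct`, the code below assumes the 16-byte layout from the documentation's header diagram (2-byte magic `GD`, 1-byte version, 1-byte message type, 4-byte big-endian payload length, 8-byte request ID). The format string `">2sBBIQ"` and the `split_frames` helper are assumptions made for this illustration; the actual `HEADER_FORMAT` value and `BinaryProtocol.decode_header` implementation are not shown in this diff.

```python
import struct

# Assumed layout (from the header diagram): magic, version, message type,
# payload length, request ID. Big-endian, no padding.
HEADER_FORMAT = ">2sBBIQ"
HEADER_SIZE = struct.calcsize(HEADER_FORMAT)  # 16 bytes
MAGIC = b"GD"
VERSION = 0x01


def split_frames(buffer):
    """Return (frames, remainder).

    Each frame is a (msg_type, request_id, payload) tuple; remainder holds
    any trailing bytes that do not yet form a complete message.
    """
    frames = []
    while len(buffer) >= HEADER_SIZE:
        magic, version, msg_type, length, request_id = struct.unpack(
            HEADER_FORMAT, buffer[:HEADER_SIZE]
        )
        if magic != MAGIC or version != VERSION:
            raise ValueError("unexpected magic or version in header")
        total_size = HEADER_SIZE + length
        if len(buffer) < total_size:
            break  # payload incomplete, wait for more data
        frames.append((msg_type, request_id, buffer[HEADER_SIZE:total_size]))
        buffer = buffer[total_size:]
    return frames, buffer


# One complete RESPONSE (0x02) frame followed by half of a second header.
complete = struct.pack(HEADER_FORMAT, MAGIC, VERSION, 0x02, 3, 123) + b"abc"
partial = struct.pack(HEADER_FORMAT, MAGIC, VERSION, 0x02, 3, 124)[:8]
frames, remainder = split_frames(complete + partial)
print(frames)  # [(2, 123, b'abc')]
```

Returning the unconsumed remainder mirrors how `drain()` keeps partial data in `self._buffer` until the rest of the message arrives.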
@@ -246,7 +257,7 @@ class TestDirtyWorkerHandleRequest:
         worker.load_apps()

         request = make_request(
-            request_id="test-123",
+            request_id=123,
             app_path="tests.support_dirty_app:TestDirtyApp",
             action="compute",
             args=(2, 3),
@@ -259,7 +270,7 @@ class TestDirtyWorkerHandleRequest:
         assert len(writer.messages) == 1
         response = writer.messages[0]
         assert response["type"] == DirtyProtocol.MSG_TYPE_RESPONSE
-        assert response["id"] == "test-123"
+        assert response["id"] == 123
         assert response["result"] == 6

     @pytest.mark.asyncio
@@ -282,7 +293,7 @@ class TestDirtyWorkerHandleRequest:
         worker.load_apps()

         request = make_request(
-            request_id="test-456",
+            request_id=456,
             app_path="tests.support_dirty_app:TestDirtyApp",
             action="compute",
             args=(2, 3),
@@ -295,7 +306,7 @@ class TestDirtyWorkerHandleRequest:
         assert len(writer.messages) == 1
         response = writer.messages[0]
         assert response["type"] == DirtyProtocol.MSG_TYPE_ERROR
-        assert response["id"] == "test-456"
+        assert response["id"] == 456
         assert "Unknown operation" in response["error"]["message"]

     @pytest.mark.asyncio
@@ -315,7 +326,7 @@ class TestDirtyWorkerHandleRequest:
             socket_path=socket_path
         )

-        request = {"type": "unknown", "id": "test-789"}
+        request = {"type": "unknown", "id": 789}
         writer = MockStreamWriter()
         await worker.handle_request(request, writer)

@@ -697,7 +708,7 @@ class TestDirtyWorkerRunAsync:

         # Create a simple test using stream reader/writer
         request = make_request(
-            request_id="conn-test",
+            request_id=999,
             app_path="tests.support_dirty_app:TestDirtyApp",
             action="compute",
             args=(5, 3),
@@ -706,7 +717,7 @@ class TestDirtyWorkerRunAsync:

         # Mock reader and writer
         reader = asyncio.StreamReader()
-        encoded_request = DirtyProtocol.encode(request)
+        encoded_request = BinaryProtocol._encode_from_dict(request)
         reader.feed_data(encoded_request)
         reader.feed_eof()
