gunicorn/examples/dirty_example/test_protocol.py
Benoit Chesneau 95b7ffeeaa chore: prepare release 25.0.2
- Bump version to 25.0.2
- Update copyright year to 2026 in LICENSE and NOTICE
- Add license headers to all Python source files
- Add changelog entry for 25.0.2
2026-02-06 08:21:18 +01:00

208 lines
5.7 KiB
Python

#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
#!/usr/bin/env python
"""
Test script to demonstrate the Dirty Protocol layer.
Run with:
python examples/dirty_example/test_protocol.py
"""
import sys
import os
import asyncio
import socket
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from gunicorn.dirty.protocol import (
DirtyProtocol,
make_request,
make_response,
make_error_response,
)
from gunicorn.dirty.errors import DirtyError, DirtyTimeoutError
def test_protocol_encode_decode():
    """Round-trip a request message through the Dirty Protocol codec.

    Builds a request with ``make_request``, frames it with
    ``DirtyProtocol.encode``, strips the ``DirtyProtocol.HEADER_SIZE``-byte
    header and decodes the remaining payload with ``DirtyProtocol.decode``.

    Raises:
        AssertionError: if the decoded message does not equal the request.
    """
    print("=" * 60)
    print("Testing Protocol Encode/Decode")
    print("=" * 60)
    # Test request
    print("\n1. Creating a request message...")
    request = make_request(
        request_id="req-001",
        app_path="myapp.ml:MLApp",
        action="inference",
        args=("model1",),
        kwargs={"temperature": 0.7},
    )
    print(f" Request: {request}")
    # Encode
    print("\n2. Encoding message...")
    encoded = DirtyProtocol.encode(request)
    # Use the protocol's own constant instead of a hard-coded size, so the
    # displayed header is exactly the span stripped before decoding.
    header_size = DirtyProtocol.HEADER_SIZE
    print(f" Encoded length: {len(encoded)} bytes")
    print(f" Header ({header_size} bytes): {encoded[:header_size].hex()}")
    # Decode
    print("\n3. Decoding payload...")
    payload = encoded[header_size:]
    decoded = DirtyProtocol.decode(payload)
    matched = decoded == request
    print(f" Decoded: {decoded}")
    print(f" Match: {matched}")
    # Assert rather than only print, so a broken round-trip raises instead
    # of the script going on to report success.
    assert matched, f"decoded message {decoded!r} != request {request!r}"
def test_protocol_response():
    """Build and display one success response and one error response.

    Both responses use the request id ``"req-001"``. The error response wraps
    a ``DirtyTimeoutError`` created with ``timeout=30``.
    """
    separator = "=" * 60
    print("\n" + separator)
    print("Testing Response Messages")
    print(separator)

    # Success path: a response carrying a result dict.
    print("\n1. Creating success response...")
    result_payload = {"result": "Hello, World!", "confidence": 0.95}
    ok_message = make_response("req-001", result_payload)
    print(f" Response: {ok_message}")

    # Failure path: a response built from an exception instance.
    print("\n2. Creating error response...")
    timeout_exc = DirtyTimeoutError("Operation timed out", timeout=30)
    failure_message = make_error_response("req-001", timeout_exc)
    print(f" Error response: {failure_message}")
def test_socket_communication():
    """Exchange a request and a response over a connected socket pair.

    Uses the synchronous ``DirtyProtocol.write_message`` and
    ``DirtyProtocol.read_message`` helpers on the two ends of
    ``socket.socketpair()``. Both sockets are closed on exit.

    Raises:
        AssertionError: if either message is not received unchanged.
    """
    print("\n" + "=" * 60)
    print("Testing Socket Communication")
    print("=" * 60)
    # Create a socket pair; the sockets are context managers, so both ends
    # are closed even if a send, receive or assertion fails.
    server_sock, client_sock = socket.socketpair()
    with server_sock, client_sock:
        # Send a request
        print("\n1. Sending request over socket...")
        request = make_request(
            request_id="socket-req-001",
            app_path="test:App",
            action="compute",
            args=(1, 2, 3),
            kwargs={},
        )
        DirtyProtocol.write_message(client_sock, request)
        print(f" Sent: {request}")
        # Receive the request
        print("\n2. Receiving request...")
        received = DirtyProtocol.read_message(server_sock)
        print(f" Received: {received}")
        print(f" Match: {received == request}")
        assert received == request, f"request mismatch: {received!r} != {request!r}"
        # Send a response
        print("\n3. Sending response...")
        response = make_response("socket-req-001", {"sum": 6})
        DirtyProtocol.write_message(server_sock, response)
        print(f" Sent: {response}")
        # Receive the response
        print("\n4. Receiving response...")
        received = DirtyProtocol.read_message(client_sock)
        print(f" Received: {received}")
        print(f" Match: {received == response}")
        assert received == response, f"response mismatch: {received!r} != {response!r}"
async def test_async_communication():
    """Decode a framed request with ``DirtyProtocol.read_message_async``.

    The encoded request is pushed through an OS pipe, read back in full, fed
    into an ``asyncio.StreamReader`` (followed by EOF) and decoded with the
    async reader helper.

    Raises:
        AssertionError: if the decoded message does not equal the request.
    """
    print("\n" + "=" * 60)
    print("Testing Async Communication")
    print("=" * 60)
    # Use a pipe for async testing
    read_fd, write_fd = os.pipe()
    try:
        # Create message
        request = make_request(
            request_id="async-req-001",
            app_path="async:App",
            action="process",
            args=("data",),
            kwargs={"async": True},
        )
        # Write to pipe
        print("\n1. Writing async message...")
        encoded = DirtyProtocol.encode(request)
        # os.write may perform a short write; keep writing until every byte
        # has gone out.
        remaining = memoryview(encoded)
        while remaining:
            written = os.write(write_fd, remaining)
            remaining = remaining[written:]
        # Closing the write end lets the reader below see EOF.
        os.close(write_fd)
        write_fd = None
        print(f" Wrote {len(encoded)} bytes")
        # Read from pipe using async reader
        print("\n2. Reading async message...")
        reader = asyncio.StreamReader()
        # os.read may return fewer bytes than requested, so read until EOF
        # to hand the reader the complete frame.
        chunks = []
        while True:
            chunk = os.read(read_fd, 65536)
            if not chunk:
                break
            chunks.append(chunk)
        reader.feed_data(b"".join(chunks))
        reader.feed_eof()
        received = await DirtyProtocol.read_message_async(reader)
        print(f" Received: {received}")
        print(f" Match: {received == request}")
        assert received == request, f"async mismatch: {received!r} != {request!r}"
    finally:
        if write_fd is not None:
            os.close(write_fd)
        os.close(read_fd)
def test_error_serialization():
    """Round-trip Dirty errors through ``to_dict`` / ``DirtyError.from_dict``.

    For a generic ``DirtyError`` and a ``DirtyTimeoutError``, serialize with
    ``to_dict()``, restore with ``DirtyError.from_dict()`` and check that the
    restored object has exactly the original's class.

    Raises:
        AssertionError: if a restored error's type differs from the original's.
    """
    print("\n" + "=" * 60)
    print("Testing Error Serialization")
    print("=" * 60)
    # Create various errors
    errors = [
        DirtyError("Generic error", {"code": 500}),
        DirtyTimeoutError("Timeout!", timeout=60),
    ]
    for error in errors:
        print(f"\n1. Original error: {error}")
        print(f" Type: {type(error).__name__}")
        # Serialize
        error_dict = error.to_dict()
        print(f"2. Serialized: {error_dict}")
        # Deserialize
        restored = DirtyError.from_dict(error_dict)
        print(f"3. Restored: {restored}")
        print(f" Type: {type(restored).__name__}")
        # Exact class identity: a subclass must not come back as its base.
        same_type = type(restored) is type(error)
        print(f" Match type: {same_type}")
        assert same_type, (
            f"restored {type(restored).__name__}, expected {type(error).__name__}"
        )
if __name__ == "__main__":
    banner = "#" * 60
    print("\n" + banner)
    print("# Dirty Protocol Demonstration")
    print(banner)

    # Synchronous demos first, then the asyncio one, then error round-trips.
    # Any exception raised by a demo stops the run before the summary line.
    for demo in (
        test_protocol_encode_decode,
        test_protocol_response,
        test_socket_communication,
    ):
        demo()
    asyncio.run(test_async_communication())
    test_error_serialization()

    print("\n" + banner)
    print("# All protocol tests passed!")
    print(banner + "\n")