feat(dirty): add streaming support and async client benchmarks

Add support for streaming responses when dirty app actions return
generators (sync or async). This enables real-time delivery of
incremental results for use cases like LLM token generation.

Features:
- Streaming protocol with chunk/end/error message types
- Worker support for sync and async generators
- Arbiter forwarding of streaming messages
- Deadline-based timeout handling
- Async client streaming API

Protocol:
- Chunk messages (type: "chunk") contain partial data
- End messages (type: "end") signal stream completion
- Error messages can occur mid-stream

New files:
- benchmarks/dirty_streaming.py: Streaming benchmark suite
- tests/dirty/test_*_streaming*.py: Streaming test coverage
- docs/content/dirty.md: Streaming documentation with examples
This commit is contained in:
Benoit Chesneau 2026-01-24 18:39:14 +01:00
parent 62a29bd0e1
commit f6418d4eb0
15 changed files with 3339 additions and 78 deletions

View File

@ -0,0 +1,755 @@
#!/usr/bin/env python
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
Benchmark suite for dirty worker streaming functionality.
This script benchmarks the streaming performance of dirty workers
to measure throughput, latency, and memory usage.
Usage:
python benchmarks/dirty_streaming.py [OPTIONS]
Options:
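--quick Run quick benchmarks only (this is the default when --full is not given)
--full Run full benchmark suite including stress tests
--output, -o PATH Write results to this JSON file (default: a timestamped file in the results/ directory next to this script)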
"""
import argparse
import asyncio
import gc
import json
import os
import struct
import sys
import time
import tracemalloc
from datetime import datetime
from unittest import mock
# Add parent directory to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from gunicorn.dirty.protocol import (
DirtyProtocol,
make_request,
make_chunk_message,
make_end_message,
make_response,
)
from gunicorn.dirty.worker import DirtyWorker
from gunicorn.dirty.arbiter import DirtyArbiter
from gunicorn.dirty.client import (
DirtyClient,
DirtyStreamIterator,
DirtyAsyncStreamIterator,
)
from gunicorn.config import Config
class MockStreamWriter:
    """In-memory stand-in for an asyncio StreamWriter.

    Bytes handed to :meth:`write` pile up in an internal buffer. Each
    :meth:`drain` call strips every complete length-prefixed frame from the
    front of that buffer, decodes it with ``DirtyProtocol.decode`` and
    appends the result to ``messages``.
    """

    def __init__(self):
        self.messages = []
        self._buffer = b""
        self.bytes_written = 0

    def write(self, data):
        """Append ``data`` to the pending buffer and count its bytes."""
        self._buffer += data
        self.bytes_written += len(data)

    async def drain(self):
        """Decode every complete frame currently held in the buffer."""
        header_size = DirtyProtocol.HEADER_SIZE
        while True:
            if len(self._buffer) < header_size:
                # Not even a full header yet.
                return
            payload_len = struct.unpack(
                DirtyProtocol.HEADER_FORMAT, self._buffer[:header_size]
            )[0]
            frame_end = header_size + payload_len
            if len(self._buffer) < frame_end:
                # Header present but the payload is still incomplete.
                return
            payload = self._buffer[header_size:frame_end]
            self._buffer = self._buffer[frame_end:]
            self.messages.append(DirtyProtocol.decode(payload))

    def close(self):
        """No-op; present only to satisfy the writer interface."""
        return None

    async def wait_closed(self):
        """No-op; present only to satisfy the writer interface."""
        return None
class MockStreamReader:
    """Mock StreamReader that replays a fixed list of protocol messages.

    All messages are encoded up front with ``DirtyProtocol.encode`` and
    served back byte-for-byte through :meth:`readexactly`.
    """

    def __init__(self, messages):
        # Join once instead of growing a bytes object message by message,
        # which copies the whole buffer on every append.
        self._data = b"".join(DirtyProtocol.encode(msg) for msg in messages)
        self._pos = 0

    async def readexactly(self, n):
        """Return the next ``n`` bytes.

        Raises:
            asyncio.IncompleteReadError: if fewer than ``n`` bytes remain;
                the exception carries the remaining bytes as ``partial``.
        """
        end = self._pos + n
        if end > len(self._data):
            raise asyncio.IncompleteReadError(self._data[self._pos:], n)
        result = self._data[self._pos:end]
        self._pos = end
        return result
class MockLog:
    """Logger replacement that discards everything.

    Keeps benchmark output clean and keeps logging cost out of the
    measurements. Every method accepts its arguments and returns ``None``.
    """

    def debug(self, msg, *args):
        """Discard a debug record."""
        return None

    def info(self, msg, *args):
        """Discard an info record."""
        return None

    def warning(self, msg, *args):
        """Discard a warning record."""
        return None

    def error(self, msg, *args):
        """Discard an error record."""
        return None

    def close_on_exec(self):
        """No-op: there are no file handles to flag."""
        return None

    def reopen_files(self):
        """No-op: there are no log files to reopen."""
        return None
def create_worker():
    """Build a DirtyWorker suitable for in-process benchmarking.

    ``WorkerTmp`` is patched out while the worker is constructed, and the
    worker's ``apps``, ``_executor`` and ``tmp`` attributes are replaced
    afterwards with an empty dict, ``None`` and a ``Mock`` respectively.

    Returns:
        The configured DirtyWorker instance.
    """
    config = Config()
    config.set("dirty_timeout", 300)
    silent_log = MockLog()

    with mock.patch('gunicorn.dirty.worker.WorkerTmp'):
        bench_worker = DirtyWorker(
            age=1,
            ppid=os.getpid(),
            app_paths=["benchmark:App"],
            cfg=config,
            log=silent_log,
            socket_path="/tmp/benchmark.sock",
        )

    bench_worker.apps = {}
    bench_worker._executor = None
    bench_worker.tmp = mock.Mock()
    return bench_worker
def create_arbiter():
    """Build a DirtyArbiter that believes it owns one worker (pid 1234).

    The arbiter is marked alive, and its ``workers`` and ``worker_sockets``
    tables are pre-populated with a mock worker and a placeholder socket
    path.

    Returns:
        The configured DirtyArbiter instance.
    """
    config = Config()
    config.set("dirty_timeout", 300)

    bench_arbiter = DirtyArbiter(cfg=config, log=MockLog())
    bench_arbiter.alive = True
    bench_arbiter.workers = {1234: mock.Mock()}
    bench_arbiter.worker_sockets = {1234: '/tmp/worker.sock'}
    return bench_arbiter
class BenchmarkResults:
    """Store benchmark measurements and render them as text or JSON."""

    # Bytes per mebibyte, used for every MB figure reported.
    _MIB = 1024 * 1024

    def __init__(self):
        self.results = []

    @staticmethod
    def _rate(amount, duration):
        """Return ``amount / duration``, or 0 when ``duration`` is not positive."""
        return amount / duration if duration > 0 else 0

    def add(self, name, iterations, duration, chunks=None, bytes_total=None,
            memory_start=None, memory_end=None):
        """Record one benchmark result.

        Args:
            name: Human-readable label for the benchmark.
            iterations: Number of iterations performed.
            duration: Elapsed time in seconds. A non-positive value yields
                a rate of 0 for every derived per-second figure.
            chunks: Optional chunk count; adds ``chunks_per_s`` when truthy.
            bytes_total: Optional byte count; adds ``mb_per_s`` when truthy.
            memory_start: Optional traced memory (bytes) before the run.
            memory_end: Optional traced memory (bytes) after the run. The
                memory fields are added only when both values are given.
        """
        result = {
            "name": name,
            "iterations": iterations,
            "duration_s": round(duration, 4),
            "throughput_per_s": round(self._rate(iterations, duration), 2),
        }
        # Every derived rate goes through _rate so a zero-length run cannot
        # raise ZeroDivisionError.
        if chunks:
            result["chunks_per_s"] = round(self._rate(chunks, duration), 2)
        if bytes_total:
            result["mb_per_s"] = round(
                self._rate(bytes_total / self._MIB, duration), 2
            )
        if memory_start is not None and memory_end is not None:
            result["memory_start_mb"] = round(memory_start / self._MIB, 2)
            result["memory_end_mb"] = round(memory_end / self._MIB, 2)
            result["memory_delta_mb"] = round(
                (memory_end - memory_start) / self._MIB, 2
            )
        self.results.append(result)

    def display(self):
        """Print all recorded results to stdout."""
        print("\n" + "=" * 70)
        print("BENCHMARK RESULTS")
        print("=" * 70)
        for result in self.results:
            print(f"\n{result['name']}")
            print("-" * 50)
            for key, value in result.items():
                if key != "name":
                    print(f" {key}: {value}")
        print("\n" + "=" * 70)

    def save_json(self, filepath):
        """Write the results, with an ISO timestamp, to ``filepath`` as JSON."""
        payload = {
            "timestamp": datetime.now().isoformat(),
            "results": self.results,
        }
        # Explicit encoding keeps the output independent of the locale.
        with open(filepath, 'w', encoding="utf-8") as f:
            json.dump(payload, f, indent=2)
        print(f"Results saved to {filepath}")
async def benchmark_worker_streaming_throughput(results, chunk_size=1024, num_chunks=1000):
    """Time a worker streaming ``num_chunks`` chunks of ``chunk_size`` characters.

    The worker's ``execute`` is patched to return an async generator, and
    the output is captured by a MockStreamWriter. Elapsed time and traced
    memory before and after are recorded in ``results``.
    """
    bench_worker = create_worker()
    sink = MockStreamWriter()
    payload = "x" * chunk_size

    async def emit_chunks():
        for _ in range(num_chunks):
            yield payload

    async def fake_execute(app_path, action, args, kwargs):
        return emit_chunks()

    gc.collect()
    tracemalloc.start()
    mem_before = tracemalloc.get_traced_memory()[0]
    began = time.perf_counter()

    with mock.patch.object(bench_worker, 'execute', side_effect=fake_execute):
        request = make_request("bench-1", "benchmark:App", "stream")
        await bench_worker.handle_request(request, sink)

    elapsed = time.perf_counter() - began
    mem_after = tracemalloc.get_traced_memory()[0]
    tracemalloc.stop()

    results.add(
        f"Worker streaming ({chunk_size}B chunks, {num_chunks} chunks)",
        iterations=1,
        duration=elapsed,
        chunks=num_chunks,
        bytes_total=chunk_size * num_chunks,
        memory_start=mem_before,
        memory_end=mem_after,
    )
async def benchmark_arbiter_forwarding(results, num_chunks=1000):
    """Time the arbiter relaying a canned worker stream to a client writer.

    The simulated worker connection yields ``num_chunks`` chunk messages
    followed by one end message. The byte count reported is whatever the
    arbiter wrote to the client-side mock writer.
    """
    arbiter = create_arbiter()

    frames = [
        make_chunk_message(f"bench-{idx}", f"data-{idx}")
        for idx in range(num_chunks)
    ]
    frames.append(make_end_message(f"bench-{num_chunks}"))
    worker_reader = MockStreamReader(frames)

    async def fake_connection(pid):
        return worker_reader, MockStreamWriter()

    arbiter._get_worker_connection = fake_connection
    client_sink = MockStreamWriter()

    gc.collect()
    began = time.perf_counter()
    request = make_request("bench-forward", "benchmark:App", "stream")
    await arbiter._execute_on_worker(1234, request, client_sink)
    elapsed = time.perf_counter() - began

    results.add(
        f"Arbiter forwarding ({num_chunks} chunks)",
        iterations=1,
        duration=elapsed,
        chunks=num_chunks,
        bytes_total=client_sink.bytes_written,
    )
    arbiter._cleanup_sync()
async def benchmark_streaming_latency(results, iterations=100):
    """Benchmark time-to-first-chunk and time-to-last-chunk.

    Each iteration streams three chunks through a worker into a mock
    writer. Time-to-first-chunk is the delay until the writer receives its
    first ``write`` call; time-to-last-chunk is the delay until
    ``handle_request`` returns.

    Args:
        results: BenchmarkResults collector that receives the summary.
        iterations: Number of three-chunk streams to run.
    """

    class _FirstWriteTimer(MockStreamWriter):
        """MockStreamWriter that records when its first write happened."""

        def __init__(self):
            super().__init__()
            self.first_write_at = None

        def write(self, data):
            # Timestamp before buffering so the mock's own work is excluded.
            # NOTE(review): assumes the worker's first write is the first
            # chunk frame -- confirm against DirtyWorker.handle_request.
            if self.first_write_at is None:
                self.first_write_at = time.perf_counter()
            super().write(data)

    async def gen_3_chunks():
        yield "first"
        yield "second"
        yield "third"

    async def mock_execute(app_path, action, args, kwargs):
        return gen_3_chunks()

    worker = create_worker()
    first_chunk_times = []
    total_times = []

    for _ in range(iterations):
        writer = _FirstWriteTimer()
        start = time.perf_counter()
        with mock.patch.object(worker, 'execute', side_effect=mock_execute):
            request = make_request("bench-latency", "benchmark:App", "stream")
            await worker.handle_request(request, writer)
        finished = time.perf_counter()

        # Sampling after handle_request returns would report total time as
        # first-chunk time; use the timestamp captured at the first write.
        if writer.first_write_at is not None:
            first_chunk_times.append(writer.first_write_at - start)
        total_times.append(finished - start)

    avg_first_chunk = (
        sum(first_chunk_times) / len(first_chunk_times) if first_chunk_times else 0
    )
    # Guarded so iterations=0 does not raise ZeroDivisionError.
    avg_total = sum(total_times) / len(total_times) if total_times else 0

    print(f"\nLatency Results ({iterations} iterations):")
    print(f" Avg time-to-first-chunk: {avg_first_chunk * 1000:.3f}ms")
    print(f" Avg time-to-last-chunk: {avg_total * 1000:.3f}ms")

    results.add(
        f"Streaming latency ({iterations} iterations)",
        iterations=iterations,
        duration=sum(total_times),
        chunks=iterations * 3
    )
async def benchmark_concurrent_streams(results, num_streams=10, chunks_per_stream=100):
    """Benchmark multiple streams forwarded concurrently by one arbiter.

    Each stream feeds the arbiter ``chunks_per_stream`` chunk messages plus
    one end message from its own mock reader. Only chunk messages that
    reach the client writer are counted, so the reported chunk rate is not
    inflated by the end message of each stream.

    Args:
        results: BenchmarkResults collector that receives the summary.
        num_streams: Number of streams run through ``asyncio.gather``.
        chunks_per_stream: Chunk messages produced per stream.
    """
    arbiter = create_arbiter()
    chunk_type = DirtyProtocol.MSG_TYPE_CHUNK

    async def run_stream(stream_id):
        messages = [
            make_chunk_message(f"stream-{stream_id}", f"chunk-{i}")
            for i in range(chunks_per_stream)
        ]
        messages.append(make_end_message(f"stream-{stream_id}"))
        mock_reader = MockStreamReader(messages)

        async def mock_get_connection(pid):
            return mock_reader, MockStreamWriter()

        # NOTE(review): every stream overwrites this attribute on the shared
        # arbiter. This is only safe if the arbiter fetches the connection
        # before yielding to the event loop -- confirm if the arbiter changes.
        arbiter._get_worker_connection = mock_get_connection
        client_writer = MockStreamWriter()
        request = make_request(f"bench-concurrent-{stream_id}", "benchmark:App", "stream")
        await arbiter._execute_on_worker(1234, request, client_writer)

        # Assumes decoded messages are dicts with a "type" key, as the
        # arbiter itself treats them.
        return sum(
            1 for message in client_writer.messages
            if message.get("type") == chunk_type
        )

    gc.collect()
    start = time.perf_counter()

    # Run streams concurrently
    tasks = [run_stream(i) for i in range(num_streams)]
    results_list = await asyncio.gather(*tasks)

    duration = time.perf_counter() - start
    total_chunks = sum(results_list)

    results.add(
        f"Concurrent streams ({num_streams} streams, {chunks_per_stream} chunks each)",
        iterations=num_streams,
        duration=duration,
        chunks=total_chunks
    )
    arbiter._cleanup_sync()
async def benchmark_memory_stability(results, iterations=10, chunks=1000):
    """Check memory stability over repeated streams.

    Streams ``chunks`` chunks through a worker ``iterations`` times,
    sampling traced memory after a forced garbage collection at the end of
    every iteration. Start, end, max and delta are printed, and start/end
    are recorded in ``results``.

    The recorded duration is the real elapsed time of the loop, including
    the per-iteration ``gc.collect()`` and tracemalloc overhead. Treat the
    throughput figure as indicative only.

    Args:
        results: BenchmarkResults collector that receives the summary.
        iterations: Number of streams to run.
        chunks: Chunks yielded per stream.
    """
    worker = create_worker()

    async def gen_chunks():
        for j in range(chunks):
            yield f"chunk-{j}"

    async def mock_execute(app_path, action, args, kwargs):
        return gen_chunks()

    gc.collect()
    tracemalloc.start()
    try:
        memory_samples = [tracemalloc.get_traced_memory()[0]]
        started = time.perf_counter()

        for i in range(iterations):
            writer = MockStreamWriter()
            with mock.patch.object(worker, 'execute', side_effect=mock_execute):
                request = make_request(f"bench-mem-{i}", "benchmark:App", "stream")
                await worker.handle_request(request, writer)
            gc.collect()
            memory_samples.append(tracemalloc.get_traced_memory()[0])

        elapsed = time.perf_counter() - started
    finally:
        # Stop tracing even on failure so later benchmarks are not slowed.
        tracemalloc.stop()

    memory_start = memory_samples[0]
    memory_end = memory_samples[-1]
    memory_max = max(memory_samples)

    print(f"\nMemory stability ({iterations} iterations of {chunks} chunks):")
    print(f" Start: {memory_start / 1024 / 1024:.2f}MB")
    print(f" End: {memory_end / 1024 / 1024:.2f}MB")
    print(f" Max: {memory_max / 1024 / 1024:.2f}MB")
    print(f" Delta: {(memory_end - memory_start) / 1024 / 1024:.2f}MB")

    results.add(
        f"Memory stability ({iterations} x {chunks} chunks)",
        iterations=iterations * chunks,
        # Real elapsed time replaces a fabricated constant that produced a
        # meaningless throughput figure.
        duration=elapsed,
        memory_start=memory_start,
        memory_end=memory_end
    )
class MockClientReader:
    """Mock async reader that replays a pre-encoded streaming response.

    The stream consists of ``num_chunks`` chunk messages carrying
    ``chunk_data`` followed by a single end message, all encoded up front
    with ``DirtyProtocol.encode``.
    """

    def __init__(self, num_chunks, chunk_data):
        self.num_chunks = num_chunks
        self.chunk_data = chunk_data
        self._chunk_idx = 0
        self._messages = []
        self._build_messages()
        self._pos = 0
        # Join once; growing a bytes object per message copies the whole
        # buffer on every append.
        self._data = b"".join(
            DirtyProtocol.encode(msg) for msg in self._messages
        )

    def _build_messages(self):
        """Fill ``_messages`` with the chunk messages and the end message."""
        self._messages.extend(
            make_chunk_message(f"bench-{i}", self.chunk_data)
            for i in range(self.num_chunks)
        )
        self._messages.append(make_end_message("bench-end"))

    async def readexactly(self, n):
        """Return the next ``n`` bytes of the encoded stream.

        Raises:
            asyncio.IncompleteReadError: if fewer than ``n`` bytes remain.
        """
        end = self._pos + n
        if end > len(self._data):
            raise asyncio.IncompleteReadError(self._data[self._pos:], n)
        result = self._data[self._pos:end]
        self._pos = end
        return result
class MockClientWriter:
    """Minimal async writer double for the client side of a connection.

    Written bytes are accumulated in ``_buffer`` and never parsed;
    ``close`` only flips the ``_closed`` flag.
    """

    def __init__(self):
        self._buffer = b""
        self._closed = False

    def write(self, data):
        """Append ``data`` to the captured output."""
        self._buffer += data

    async def drain(self):
        """Nothing to flush; returns immediately."""
        return None

    def close(self):
        """Mark the writer as closed."""
        self._closed = True

    async def wait_closed(self):
        """Nothing to wait for; returns immediately."""
        return None
async def benchmark_async_client_streaming(results, chunk_size=1024, num_chunks=1000):
    """
    Benchmark DirtyAsyncStreamIterator directly.

    Measures async iterator overhead when consuming a pre-encoded stream of
    ``num_chunks`` chunks of ``chunk_size`` characters from a mock reader.
    The request-sending step is skipped by priming the iterator's private
    state by hand.

    Args:
        results: BenchmarkResults collector that receives the summary.
        chunk_size: Length of each chunk payload, in characters.
        num_chunks: Number of chunk messages in the stream.
    """
    chunk_data = "x" * chunk_size

    # Create mock client with mock reader/writer
    client = DirtyClient("/tmp/benchmark.sock", timeout=30.0)
    client._reader = MockClientReader(num_chunks, chunk_data)
    client._writer = MockClientWriter()

    gc.collect()
    tracemalloc.start()
    memory_start = tracemalloc.get_traced_memory()[0]
    start = time.perf_counter()

    # Use the async stream iterator directly
    iterator = DirtyAsyncStreamIterator(client, "benchmark:App", "stream", (), {})
    iterator._started = True  # Skip the request sending
    iterator._request_id = "bench-async"
    # The sync iterator seeds its deadline from time.monotonic(); the async
    # one presumably does the same (confirm in client.py). perf_counter()
    # has an unrelated reference point, so it must not be used here.
    now = time.monotonic()
    iterator._deadline = now + 300  # 5 min deadline
    iterator._last_chunk_time = now

    chunks_received = 0
    bytes_received = 0
    async for chunk in iterator:
        chunks_received += 1
        bytes_received += len(chunk)

    duration = time.perf_counter() - start
    memory_end = tracemalloc.get_traced_memory()[0]
    tracemalloc.stop()

    results.add(
        f"Async client streaming ({chunk_size}B chunks, {num_chunks} chunks)",
        iterations=1,
        duration=duration,
        chunks=chunks_received,
        bytes_total=bytes_received,
        memory_start=memory_start,
        memory_end=memory_end
    )
async def benchmark_sync_client_streaming(results, chunk_size=1024, num_chunks=1000):
    """
    Benchmark DirtyStreamIterator directly (for comparison with async).

    Note: This runs the sync iterator within an async context for
    comparison, so the event loop is blocked while it iterates.

    Args:
        results: BenchmarkResults collector that receives the summary.
        chunk_size: Length of each chunk payload, in characters.
        num_chunks: Number of chunk messages in the stream.
    """
    chunk_data = "x" * chunk_size

    # Build raw message data in one join rather than repeated bytes +=,
    # which copies the growing buffer on every append.
    encoded = [
        DirtyProtocol.encode(make_chunk_message(f"bench-{i}", chunk_data))
        for i in range(num_chunks)
    ]
    encoded.append(DirtyProtocol.encode(make_end_message("bench-end")))
    messages_data = b"".join(encoded)

    # Create a mock socket-like object
    class MockSocket:
        """Socket double that serves ``data`` through ``recv``."""

        def __init__(self, data):
            self._data = data
            self._pos = 0
            self._timeout = None

        def recv(self, n, flags=0):
            # An empty result signals EOF, as with a real socket.
            if self._pos >= len(self._data):
                return b''
            result = self._data[self._pos:self._pos + n]
            self._pos += len(result)
            return result

        def settimeout(self, timeout):
            self._timeout = timeout

    # Create mock client
    client = DirtyClient("/tmp/benchmark.sock", timeout=30.0)
    client._sock = MockSocket(messages_data)

    gc.collect()
    tracemalloc.start()
    memory_start = tracemalloc.get_traced_memory()[0]
    start = time.perf_counter()

    # Use the sync stream iterator
    iterator = DirtyStreamIterator(client, "benchmark:App", "stream", (), {})
    iterator._started = True  # Skip the request sending
    iterator._request_id = "bench-sync"
    # DirtyStreamIterator seeds these from time.monotonic(); perf_counter()
    # has an unrelated reference point, so it must not be used here.
    now = time.monotonic()
    iterator._deadline = now + 300  # 5 min deadline
    iterator._last_chunk_time = now

    chunks_received = 0
    bytes_received = 0
    for chunk in iterator:
        chunks_received += 1
        bytes_received += len(chunk)

    duration = time.perf_counter() - start
    memory_end = tracemalloc.get_traced_memory()[0]
    tracemalloc.stop()

    results.add(
        f"Sync client streaming ({chunk_size}B chunks, {num_chunks} chunks)",
        iterations=1,
        duration=duration,
        chunks=chunks_received,
        bytes_total=bytes_received,
        memory_start=memory_start,
        memory_end=memory_end
    )
async def benchmark_async_vs_sync_client_streaming(results, chunk_size=1024, num_chunks=1000):
    """
    Compare stream() vs stream_async() performance with the same workload.

    Both iterators consume an identical pre-encoded stream of ``num_chunks``
    chunks of ``chunk_size`` characters. The timings are printed with their
    ratio and recorded in ``results`` as two separate entries.

    Args:
        results: BenchmarkResults collector that receives both summaries.
        chunk_size: Length of each chunk payload, in characters.
        num_chunks: Number of chunk messages in the stream.
    """
    chunk_data = "x" * chunk_size

    # --- Sync test ---
    # Encode once and join, avoiding quadratic bytes concatenation.
    encoded = [
        DirtyProtocol.encode(make_chunk_message(f"bench-{i}", chunk_data))
        for i in range(num_chunks)
    ]
    encoded.append(DirtyProtocol.encode(make_end_message("bench-end")))
    messages_data = b"".join(encoded)

    class MockSocket:
        """Socket double that serves ``data`` through ``recv``."""

        def __init__(self, data):
            self._data = data
            self._pos = 0
            self._timeout = None

        def recv(self, n, flags=0):
            # An empty result signals EOF, as with a real socket.
            if self._pos >= len(self._data):
                return b''
            result = self._data[self._pos:self._pos + n]
            self._pos += len(result)
            return result

        def settimeout(self, timeout):
            self._timeout = timeout

    sync_client = DirtyClient("/tmp/benchmark.sock", timeout=30.0)
    sync_client._sock = MockSocket(messages_data)

    gc.collect()
    sync_start = time.perf_counter()
    sync_iter = DirtyStreamIterator(sync_client, "benchmark:App", "stream", (), {})
    sync_iter._started = True
    sync_iter._request_id = "bench-sync"
    # Deadlines must use the clock the iterator uses (time.monotonic());
    # perf_counter() has an unrelated reference point.
    now = time.monotonic()
    sync_iter._deadline = now + 300  # 5 min deadline
    sync_iter._last_chunk_time = now

    sync_chunks = 0
    for _ in sync_iter:
        sync_chunks += 1
    sync_duration = time.perf_counter() - sync_start

    # --- Async test ---
    async_client = DirtyClient("/tmp/benchmark.sock", timeout=30.0)
    async_client._reader = MockClientReader(num_chunks, chunk_data)
    async_client._writer = MockClientWriter()

    gc.collect()
    async_start = time.perf_counter()
    async_iter = DirtyAsyncStreamIterator(async_client, "benchmark:App", "stream", (), {})
    async_iter._started = True
    async_iter._request_id = "bench-async"
    # Presumably the async iterator also uses time.monotonic() -- confirm.
    now = time.monotonic()
    async_iter._deadline = now + 300  # 5 min deadline
    async_iter._last_chunk_time = now

    async_chunks = 0
    async for _ in async_iter:
        async_chunks += 1
    async_duration = time.perf_counter() - async_start

    # Report comparison
    print(f"\nSync vs Async Client Streaming Comparison ({num_chunks} x {chunk_size}B chunks):")
    print(f" Sync: {sync_duration * 1000:.3f}ms ({sync_chunks} chunks)")
    print(f" Async: {async_duration * 1000:.3f}ms ({async_chunks} chunks)")
    if sync_duration > 0:
        ratio = async_duration / sync_duration
        print(f" Ratio (async/sync): {ratio:.3f}x")

    results.add(
        f"Sync client streaming comparison ({chunk_size}B, {num_chunks} chunks)",
        iterations=1,
        duration=sync_duration,
        chunks=sync_chunks,
        bytes_total=sync_chunks * chunk_size
    )
    results.add(
        f"Async client streaming comparison ({chunk_size}B, {num_chunks} chunks)",
        iterations=1,
        duration=async_duration,
        chunks=async_chunks,
        bytes_total=async_chunks * chunk_size
    )
async def run_quick_benchmarks():
    """Run the short benchmark selection and return the collected results.

    Covers worker throughput at two chunk sizes, arbiter forwarding,
    latency, async client streaming and the sync/async comparison.
    """
    results = BenchmarkResults()
    print("Running quick benchmarks...")

    # Each entry is (benchmark coroutine function, keyword arguments);
    # entries run strictly in order.
    plan = (
        (benchmark_worker_streaming_throughput, {"chunk_size": 64, "num_chunks": 1000}),
        (benchmark_worker_streaming_throughput, {"chunk_size": 1024, "num_chunks": 1000}),
        (benchmark_arbiter_forwarding, {"num_chunks": 1000}),
        (benchmark_streaming_latency, {"iterations": 50}),
        # Async client streaming benchmarks
        (benchmark_async_client_streaming, {"chunk_size": 1024, "num_chunks": 1000}),
        (benchmark_async_vs_sync_client_streaming, {"chunk_size": 1024, "num_chunks": 1000}),
    )
    for bench, params in plan:
        await bench(results, **params)

    return results
async def run_full_benchmarks():
    """Run every benchmark, including the heavier ones, and return the results.

    Adds larger chunk sizes, concurrent streams, the memory-stability
    check and the sync client benchmark to what the quick run covers.
    """
    results = BenchmarkResults()
    print("Running full benchmark suite...")

    # Ordered list of (benchmark coroutine function, keyword arguments).
    plan = []

    # Throughput tests with different chunk sizes
    plan.extend(
        (benchmark_worker_streaming_throughput, {"chunk_size": size, "num_chunks": 1000})
        for size in [1, 64, 1024, 65536]
    )
    # Arbiter forwarding
    plan.append((benchmark_arbiter_forwarding, {"num_chunks": 10000}))
    # Latency
    plan.append((benchmark_streaming_latency, {"iterations": 100}))
    # Concurrent streams
    plan.extend(
        (benchmark_concurrent_streams, {"num_streams": count, "chunks_per_stream": 100})
        for count in (10, 50)
    )
    # Memory stability
    plan.append((benchmark_memory_stability, {"iterations": 20, "chunks": 1000}))
    # Async client streaming benchmarks: async then sync for each size
    for size in [64, 1024, 65536]:
        plan.append((benchmark_async_client_streaming, {"chunk_size": size, "num_chunks": 1000}))
        plan.append((benchmark_sync_client_streaming, {"chunk_size": size, "num_chunks": 1000}))
    # Comparison benchmark
    plan.append(
        (benchmark_async_vs_sync_client_streaming, {"chunk_size": 1024, "num_chunks": 5000})
    )

    for bench, params in plan:
        await bench(results, **params)

    return results
def main():
    """Command-line entry point.

    Runs the full suite when ``--full`` is given and the quick one
    otherwise, prints the results, and saves them as JSON. The destination
    is ``--output`` when provided, else a timestamped file in a ``results``
    directory next to this script.
    """
    parser = argparse.ArgumentParser(description="Dirty streaming benchmarks")
    parser.add_argument("--quick", action="store_true", help="Run quick benchmarks only")
    parser.add_argument("--full", action="store_true", help="Run full benchmark suite")
    parser.add_argument("--output", "-o", help="Output JSON file path")
    args = parser.parse_args()

    suite = run_full_benchmarks if args.full else run_quick_benchmarks
    results = asyncio.run(suite())
    results.display()

    destination = args.output
    if not destination:
        # Save to default location
        script_dir = os.path.dirname(os.path.abspath(__file__))
        results_dir = os.path.join(script_dir, "results")
        os.makedirs(results_dir, exist_ok=True)
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        destination = os.path.join(results_dir, f"streaming_benchmark_{stamp}.json")

    results.save_json(destination)


if __name__ == "__main__":
    main()

View File

@ -200,6 +200,160 @@ async def my_view(request):
return result
```
## Streaming
Dirty Arbiters support streaming responses for use cases like LLM token generation, where data is produced incrementally. This enables real-time delivery of results without waiting for complete execution.
### Streaming with Generators
Any dirty app action that returns a generator (sync or async) automatically streams chunks to the client:
```python
# myapp/llm.py
from gunicorn.dirty import DirtyApp
class LLMApp(DirtyApp):
def init(self):
from transformers import pipeline
self.generator = pipeline("text-generation", model="gpt2")
def generate(self, prompt):
"""Sync streaming - yields tokens."""
for token in self.generator(prompt, stream=True):
yield token["generated_text"]
async def generate_async(self, prompt):
"""Async streaming - yields tokens."""
import openai
client = openai.AsyncOpenAI()
stream = await client.chat.completions.create(
model="gpt-4",
messages=[{"role": "user", "content": prompt}],
stream=True
)
async for chunk in stream:
if chunk.choices[0].delta.content:
yield chunk.choices[0].delta.content
def close(self):
pass
```
### Client Streaming API
Use `stream()` for sync workers and `stream_async()` for async workers:
**Sync Workers (sync, gthread):**
```python
from gunicorn.dirty import get_dirty_client
def generate_view(request):
client = get_dirty_client()
def generate_response():
for chunk in client.stream("myapp.llm:LLMApp", "generate", request.prompt):
yield chunk
return StreamingResponse(generate_response())
```
**Async Workers (ASGI):**
```python
from gunicorn.dirty import get_dirty_client_async
async def generate_view(request):
client = await get_dirty_client_async()
async def generate_response():
async for chunk in client.stream_async("myapp.llm:LLMApp", "generate", request.prompt):
yield chunk
return StreamingResponse(generate_response())
```
### Streaming Protocol
Streaming uses a simple protocol with three message types:
1. **Chunk** (`type: "chunk"`) - Contains partial data
2. **End** (`type: "end"`) - Signals stream completion
3. **Error** (`type: "error"`) - Signals error during streaming
Example message flow:
```
Client -> Arbiter -> Worker: request
Worker -> Arbiter -> Client: chunk (data: "Hello")
Worker -> Arbiter -> Client: chunk (data: " ")
Worker -> Arbiter -> Client: chunk (data: "World")
Worker -> Arbiter -> Client: end
```
### Error Handling in Streams
Errors during streaming are delivered as error messages:
```python
def generate_view(request):
client = get_dirty_client()
try:
for chunk in client.stream("myapp.llm:LLMApp", "generate", request.prompt):
yield chunk
except DirtyError as e:
# Error occurred mid-stream
yield f"\n[Error: {e.message}]"
```
### Best Practices for Streaming
1. **Use async generators for I/O-bound streaming** - e.g., API calls to external services
2. **Use sync generators for CPU-bound streaming** - e.g., local model inference
3. **Yield frequently** - Heartbeats are sent during streaming to keep workers alive
4. **Keep chunks small** - Smaller chunks provide better perceived latency
5. **Handle client disconnection** - Streams continue even if client disconnects; design accordingly
### Flask Example
```python
from flask import Flask, Response, request
from gunicorn.dirty import get_dirty_client
app = Flask(__name__)
@app.route("/chat", methods=["POST"])
def chat():
prompt = request.json.get("prompt")
client = get_dirty_client()
def stream():
for token in client.stream("myapp.llm:LLMApp", "generate", prompt):
yield f"data: {token}\n\n"
return Response(stream(), content_type="text/event-stream")
```
### FastAPI Example
```python
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from gunicorn.dirty import get_dirty_client_async
app = FastAPI()
@app.post("/chat")
async def chat(prompt: str):
client = await get_dirty_client_async()
async def stream():
async for token in client.stream_async("myapp.llm:LLMApp", "generate", prompt):
yield f"data: {token}\n\n"
return StreamingResponse(stream(), media_type="text/event-stream")
```
## Lifecycle Hooks
Dirty Arbiters provide hooks for customization:

View File

@ -199,6 +199,7 @@ class DirtyArbiter:
Handle a connection from an HTTP worker.
Routes requests to available dirty workers and returns responses.
Supports both regular responses and streaming (chunk-based) responses.
"""
self.log.debug("New client connection from HTTP worker")
@ -209,11 +210,8 @@ class DirtyArbiter:
except asyncio.IncompleteReadError:
break
# Route request to a dirty worker
response = await self.route_request(message)
# Send response back to HTTP worker
await DirtyProtocol.write_message_async(writer, response)
# Route request to a dirty worker - pass writer for streaming
await self.route_request(message, writer)
except Exception as e:
self.log.error("Client connection error: %s", e)
finally:
@ -223,28 +221,31 @@ class DirtyArbiter:
except Exception:
pass
async def route_request(self, request):
async def route_request(self, request, client_writer):
"""
Route a request to an available dirty worker via queue.
Each worker has a dedicated queue and consumer task. Requests are
submitted to the queue and processed sequentially by the consumer.
For streaming responses, messages (chunks) are forwarded directly
to the client_writer as they arrive from the worker.
Args:
request: Request message dict
Returns:
Response message dict
client_writer: StreamWriter to send responses to client
"""
request_id = request.get("id", "unknown")
# Find an available worker
worker_pid = await self._get_available_worker()
if worker_pid is None:
return make_error_response(
response = make_error_response(
request_id,
DirtyError("No dirty workers available")
)
await DirtyProtocol.write_message_async(client_writer, response)
return
# Get queue (start consumer if needed)
if worker_pid not in self.worker_queues:
@ -253,17 +254,18 @@ class DirtyArbiter:
queue = self.worker_queues[worker_pid]
future = asyncio.get_running_loop().create_future()
# Submit request to queue
await queue.put((request, future))
# Submit request to queue with client writer for streaming support
await queue.put((request, client_writer, future))
# Wait for response
# Wait for completion (streaming messages forwarded by consumer)
try:
return await future
await future
except Exception as e:
return make_error_response(
response = make_error_response(
request_id,
DirtyWorkerError(f"Request failed: {e}", worker_id=worker_pid)
)
await DirtyProtocol.write_message_async(client_writer, response)
async def _start_worker_consumer(self, worker_pid):
"""Start a consumer task for a worker's request queue."""
@ -273,11 +275,13 @@ class DirtyArbiter:
async def consumer():
while self.alive:
try:
request, future = await queue.get()
request, client_writer, future = await queue.get()
try:
response = await self._execute_on_worker(worker_pid, request)
await self._execute_on_worker(
worker_pid, request, client_writer
)
if not future.done():
future.set_result(response)
future.set_result(None)
except Exception as e:
if not future.done():
future.set_exception(e)
@ -289,32 +293,65 @@ class DirtyArbiter:
task = asyncio.create_task(consumer())
self.worker_consumers[worker_pid] = task
async def _execute_on_worker(self, worker_pid, request):
"""Execute request on a specific worker (called by consumer)."""
async def _execute_on_worker(self, worker_pid, request, client_writer):
"""
Execute request on a specific worker (called by consumer).
Handles both regular responses and streaming (chunk-based) responses.
For streaming, chunk and end messages are forwarded directly to the
client_writer as they arrive from the worker.
"""
request_id = request.get("id", "unknown")
try:
reader, writer = await self._get_worker_connection(worker_pid)
await DirtyProtocol.write_message_async(writer, request)
response = await asyncio.wait_for(
DirtyProtocol.read_message_async(reader),
timeout=self.cfg.dirty_timeout
)
return response
except asyncio.TimeoutError:
return make_error_response(
request_id,
DirtyTimeoutError("Worker timeout", self.cfg.dirty_timeout)
)
# Read messages until we get a response, end, or error
while True:
try:
message = await asyncio.wait_for(
DirtyProtocol.read_message_async(reader),
timeout=self.cfg.dirty_timeout
)
except asyncio.TimeoutError:
response = make_error_response(
request_id,
DirtyTimeoutError("Worker timeout", self.cfg.dirty_timeout)
)
await DirtyProtocol.write_message_async(client_writer, response)
return
msg_type = message.get("type")
# Forward chunk messages to client
if msg_type == DirtyProtocol.MSG_TYPE_CHUNK:
await DirtyProtocol.write_message_async(client_writer, message)
continue
# Forward end message and complete
if msg_type == DirtyProtocol.MSG_TYPE_END:
await DirtyProtocol.write_message_async(client_writer, message)
return
# Forward response or error and complete
if msg_type in (DirtyProtocol.MSG_TYPE_RESPONSE,
DirtyProtocol.MSG_TYPE_ERROR):
await DirtyProtocol.write_message_async(client_writer, message)
return
# Unknown message type - log and continue
self.log.warning("Unknown message type from worker: %s", msg_type)
except Exception as e:
self.log.error("Error executing on worker %s: %s", worker_pid, e)
self._close_worker_connection(worker_pid)
return make_error_response(
response = make_error_response(
request_id,
DirtyWorkerError(f"Worker communication failed: {e}",
worker_id=worker_pid)
)
await DirtyProtocol.write_message_async(client_writer, response)
async def _get_available_worker(self):
"""

View File

@ -14,6 +14,7 @@ import contextvars
import os
import socket
import threading
import time
import uuid
from .errors import (
@ -134,6 +135,34 @@ class DirtyClient:
raise
raise DirtyConnectionError(f"Communication error: {e}") from e
def stream(self, app_path, action, *args, **kwargs):
    """
    Open a synchronous stream over a dirty app action.

    Intended for actions that return generators. This call only builds
    the iterator; the request is sent when the iterator is first advanced.

    Args:
        app_path: Import path of the dirty app (e.g., 'myapp.ml:MLApp')
        action: Action to call on the app
        *args: Positional arguments forwarded to the action
        **kwargs: Keyword arguments forwarded to the action

    Returns:
        A DirtyStreamIterator yielding chunks of the streaming response

    Raises:
        DirtyConnectionError: If connection fails
        DirtyTimeoutError: If operation times out
        DirtyError: If execution fails

    Example::

        for chunk in client.stream("myapp.llm:LLMApp", "generate", prompt):
            print(chunk, end="", flush=True)
    """
    chunk_iterator = DirtyStreamIterator(self, app_path, action, args, kwargs)
    return chunk_iterator
def _handle_response(self, response):
"""Handle response message, extracting result or raising error."""
msg_type = response.get("type")
@ -247,6 +276,34 @@ class DirtyClient:
raise
raise DirtyConnectionError(f"Communication error: {e}") from e
def stream_async(self, app_path, action, *args, **kwargs):
    """
    Open an asynchronous stream over a dirty app action.

    Intended for actions that return generators. This is a plain (non
    coroutine) method: it returns an async iterator to be consumed with
    ``async for``.

    Args:
        app_path: Import path of the dirty app (e.g., 'myapp.ml:MLApp')
        action: Action to call on the app
        *args: Positional arguments forwarded to the action
        **kwargs: Keyword arguments forwarded to the action

    Returns:
        A DirtyAsyncStreamIterator yielding chunks of the streaming response

    Raises:
        DirtyConnectionError: If connection fails
        DirtyTimeoutError: If operation times out
        DirtyError: If execution fails

    Example::

        async for chunk in client.stream_async("myapp.llm:LLMApp", "generate", prompt):
            await response.write(chunk)
    """
    chunk_iterator = DirtyAsyncStreamIterator(self, app_path, action, args, kwargs)
    return chunk_iterator
async def _close_async(self):
"""Close the async connection."""
if self._writer is not None:
@ -281,6 +338,308 @@ class DirtyClient:
await self.close_async()
# =============================================================================
# Stream Iterator classes
# =============================================================================
class DirtyStreamIterator:
    """
    Iterator for streaming responses from dirty workers (sync).

    This class is returned by `DirtyClient.stream()` and yields chunks
    from a streaming response until the end message is received.

    Uses a deadline-based timeout approach:
    - Total stream timeout: limits entire stream duration
      (``client.timeout``, counted from the moment the request is sent)
    - Idle timeout: limits gap between chunks (defaults to the smaller of
      ``DEFAULT_IDLE_TIMEOUT`` and ``client.timeout``)

    When either timeout fires, the client socket is closed: the abandoned
    stream's remaining messages (or the unread rest of a partial message)
    would otherwise be picked up by the next reader of that socket.
    """

    # Default idle timeout between chunks (seconds)
    DEFAULT_IDLE_TIMEOUT = 30.0

    # No longer consulted when computing read timeouts; kept only so that
    # existing references to the attribute keep resolving.
    _TIMEOUT_THRESHOLD = 5.0

    def __init__(self, client, app_path, action, args, kwargs,
                 idle_timeout=None):
        """
        Args:
            client: DirtyClient used for the connection and total timeout
            app_path: Import path of the dirty app
            action: Action to call on the app
            args: Positional arguments for the action
            kwargs: Keyword arguments for the action
            idle_timeout: Max seconds between chunks; None selects
                min(DEFAULT_IDLE_TIMEOUT, client.timeout)
        """
        self.client = client
        self.app_path = app_path
        self.action = action
        self.args = args
        self.kwargs = kwargs
        self._started = False
        self._exhausted = False
        self._request_id = None
        self._deadline = None
        self._last_chunk_time = None
        # Idle timeout: max time between chunks
        self._idle_timeout = (
            idle_timeout if idle_timeout is not None
            else min(self.DEFAULT_IDLE_TIMEOUT, client.timeout)
        )

    def __iter__(self):
        return self

    def __next__(self):
        if self._exhausted:
            raise StopIteration

        # The request is sent lazily, on the first iteration step
        if not self._started:
            self._start_request()
            self._started = True

        return self._read_next_chunk()

    def _start_request(self):
        """Send the initial request to the arbiter."""
        with self.client._lock:
            if self.client._sock is None:
                self.client.connect()

            # Set deadline for entire stream
            now = time.monotonic()
            self._deadline = now + self.client.timeout
            self._last_chunk_time = now

            self._request_id = str(uuid.uuid4())
            request = make_request(
                self._request_id,
                self.app_path,
                self.action,
                args=self.args,
                kwargs=self.kwargs,
            )
            DirtyProtocol.write_message(self.client._sock, request)

    def _read_timeout(self, now):
        """
        Return the socket timeout for a read starting at ``now``.

        The read may wait for the idle timeout, but never past the total
        stream deadline.
        """
        return min(self._deadline - now, self._idle_timeout)

    def _timeout_error(self, now):
        """
        Finish the stream after a timeout and build the error to raise.

        Closes the client socket so leftover stream data cannot be read by
        a later request, then picks the message depending on whether the
        total deadline or only the idle timeout was hit.
        """
        self._exhausted = True
        self.client._close_socket()
        if now >= self._deadline:
            return DirtyTimeoutError(
                "Stream exceeded total timeout",
                timeout=self.client.timeout
            )
        idle_duration = now - self._last_chunk_time
        return DirtyTimeoutError(
            f"Timeout waiting for next chunk (idle {idle_duration:.1f}s)",
            timeout=self._idle_timeout
        )

    def _read_next_chunk(self):
        """
        Read the next message from the stream.

        Returns:
            The chunk payload, or the result of a regular response.

        Raises:
            StopIteration: When the end message is received
            DirtyTimeoutError: On total or idle timeout
            DirtyConnectionError: On any other read failure
            DirtyError: On an error message or unknown message type
        """
        with self.client._lock:
            # Check total stream deadline
            now = time.monotonic()
            if now >= self._deadline:
                raise self._timeout_error(now)

            try:
                self.client._sock.settimeout(self._read_timeout(now))
                response = DirtyProtocol.read_message(self.client._sock)
            except socket.timeout:
                # Re-read the clock to tell which timeout was hit
                raise self._timeout_error(time.monotonic())
            except Exception as e:
                self._exhausted = True
                self.client._close_socket()
                raise DirtyConnectionError(f"Communication error: {e}") from e

        # Update last chunk time for idle tracking
        self._last_chunk_time = time.monotonic()
        return self._handle_message(response)

    def _handle_message(self, response):
        """Turn one protocol message into a chunk, a stop, or an error."""
        msg_type = response.get("type")

        # Chunk message - return the data
        if msg_type == DirtyProtocol.MSG_TYPE_CHUNK:
            return response.get("data")

        # Every other message type terminates the stream
        self._exhausted = True

        # End message - stop iteration
        if msg_type == DirtyProtocol.MSG_TYPE_END:
            raise StopIteration

        # Error message - raise exception
        if msg_type == DirtyProtocol.MSG_TYPE_ERROR:
            error_info = response.get("error", {})
            raise DirtyError.from_dict(error_info)

        # Regular response - the action did not stream; hand its result
        # over as the only chunk (the next call raises StopIteration)
        if msg_type == DirtyProtocol.MSG_TYPE_RESPONSE:
            return response.get("result")

        # Unknown type
        raise DirtyError(f"Unknown message type: {msg_type}")
class DirtyAsyncStreamIterator:
    """
    Async iterator for streaming responses from dirty workers.

    This class is returned by `DirtyClient.stream_async()` and yields chunks
    from a streaming response until the end message is received.

    Uses a deadline-based timeout approach:
    - Total stream timeout: limits entire stream duration
      (``client.timeout``, counted from the moment the request is sent)
    - Idle timeout: limits gap between chunks (defaults to the smaller of
      ``DEFAULT_IDLE_TIMEOUT`` and ``client.timeout``)

    Every read is bounded by both limits, so a worker that stops sending
    cannot block the caller indefinitely. When a timeout fires, the
    connection is closed because the cancelled read may have consumed part
    of a message.
    """

    # Default idle timeout between chunks (seconds)
    DEFAULT_IDLE_TIMEOUT = 30.0

    # No longer consulted when computing read timeouts; kept only so that
    # existing references to the attribute keep resolving.
    _TIMEOUT_THRESHOLD = 5.0

    def __init__(self, client, app_path, action, args, kwargs,
                 idle_timeout=None):
        """
        Args:
            client: DirtyClient used for the connection and total timeout
            app_path: Import path of the dirty app
            action: Action to call on the app
            args: Positional arguments for the action
            kwargs: Keyword arguments for the action
            idle_timeout: Max seconds between chunks; None selects
                min(DEFAULT_IDLE_TIMEOUT, client.timeout)
        """
        self.client = client
        self.app_path = app_path
        self.action = action
        self.args = args
        self.kwargs = kwargs
        self._started = False
        self._exhausted = False
        self._request_id = None
        self._deadline = None
        self._last_chunk_time = None
        # Idle timeout: max time between chunks
        self._idle_timeout = (
            idle_timeout if idle_timeout is not None
            else min(self.DEFAULT_IDLE_TIMEOUT, client.timeout)
        )

    def __aiter__(self):
        return self

    async def __anext__(self):
        if self._exhausted:
            raise StopAsyncIteration

        # The request is sent lazily, on the first iteration step
        if not self._started:
            await self._start_request()
            self._started = True

        return await self._read_next_chunk()

    async def _start_request(self):
        """Send the initial request to the arbiter."""
        if self.client._writer is None:
            await self.client.connect_async()

        # Set deadline for entire stream
        now = time.monotonic()
        self._deadline = now + self.client.timeout
        self._last_chunk_time = now

        self._request_id = str(uuid.uuid4())
        request = make_request(
            self._request_id,
            self.app_path,
            self.action,
            args=self.args,
            kwargs=self.kwargs,
        )
        await DirtyProtocol.write_message_async(self.client._writer, request)

    def _read_timeout(self, now):
        """
        Return the timeout for a read starting at ``now``.

        The read may wait for the idle timeout, but never past the total
        stream deadline.
        """
        return min(self._deadline - now, self._idle_timeout)

    async def _timeout_error(self, now):
        """
        Finish the stream after a timeout and build the error to raise.

        Closes the connection so leftover stream data cannot be read by a
        later request, then picks the message depending on whether the
        total deadline or only the idle timeout was hit.
        """
        self._exhausted = True
        await self.client._close_async()
        if now >= self._deadline:
            return DirtyTimeoutError(
                "Stream exceeded total timeout",
                timeout=self.client.timeout
            )
        idle_duration = now - self._last_chunk_time
        return DirtyTimeoutError(
            f"Timeout waiting for next chunk (idle {idle_duration:.1f}s)",
            timeout=self._idle_timeout
        )

    async def _read_next_chunk(self):
        """
        Read the next message from the stream.

        Returns:
            The chunk payload, or the result of a regular response.

        Raises:
            StopAsyncIteration: When the end message is received
            DirtyTimeoutError: On total or idle timeout
            DirtyConnectionError: On any other read failure
            DirtyError: On an error message or unknown message type
        """
        # Check total stream deadline
        now = time.monotonic()
        if now >= self._deadline:
            error = await self._timeout_error(now)
            raise error

        try:
            # Always bound the read: without a timeout a stalled worker
            # would keep this coroutine waiting past both limits
            response = await asyncio.wait_for(
                DirtyProtocol.read_message_async(self.client._reader),
                timeout=self._read_timeout(now)
            )
        except asyncio.TimeoutError:
            # Re-read the clock to tell which timeout was hit
            error = await self._timeout_error(time.monotonic())
            raise error
        except Exception as e:
            self._exhausted = True
            await self.client._close_async()
            raise DirtyConnectionError(f"Communication error: {e}") from e

        # Update last chunk time for idle tracking
        self._last_chunk_time = time.monotonic()
        return self._handle_message(response)

    def _handle_message(self, response):
        """Turn one protocol message into a chunk, a stop, or an error."""
        msg_type = response.get("type")

        # Chunk message - return the data
        if msg_type == DirtyProtocol.MSG_TYPE_CHUNK:
            return response.get("data")

        # Every other message type terminates the stream
        self._exhausted = True

        # End message - stop iteration
        if msg_type == DirtyProtocol.MSG_TYPE_END:
            raise StopAsyncIteration

        # Error message - raise exception
        if msg_type == DirtyProtocol.MSG_TYPE_ERROR:
            error_info = response.get("error", {})
            raise DirtyError.from_dict(error_info)

        # Regular response - the action did not stream; hand its result
        # over as the only chunk (the next step raises StopAsyncIteration)
        if msg_type == DirtyProtocol.MSG_TYPE_RESPONSE:
            return response.get("result")

        # Unknown type
        raise DirtyError(f"Unknown message type: {msg_type}")
# =============================================================================
# Thread-local and context-local client management
# =============================================================================

View File

@ -312,3 +312,37 @@ def make_error_response(request_id: str, error) -> dict:
"id": request_id,
"error": error_dict,
}
def make_chunk_message(request_id: str, data) -> dict:
    """
    Wrap one piece of streamed output in a chunk message.

    Args:
        request_id: Identifier of the request the chunk belongs to
        data: Chunk payload (must be JSON-serializable)

    Returns:
        dict: Message with the chunk type, the request id and the payload
    """
    message = dict(
        type=DirtyProtocol.MSG_TYPE_CHUNK,
        id=request_id,
        data=data,
    )
    return message
def make_end_message(request_id: str) -> dict:
    """
    Build the message that marks the end of a stream.

    Args:
        request_id: Identifier of the request whose stream is finished

    Returns:
        dict: Message with the end type and the request id
    """
    message = dict(
        type=DirtyProtocol.MSG_TYPE_END,
        id=request_id,
    )
    return message

View File

@ -69,6 +69,7 @@ operation will continue until the worker is killed by the arbiter.
import asyncio
from concurrent.futures import ThreadPoolExecutor
import inspect
import os
import signal
import traceback
@ -88,6 +89,8 @@ from .protocol import (
DirtyProtocol,
make_response,
make_error_response,
make_chunk_message,
make_end_message,
)
@ -296,11 +299,8 @@ class DirtyWorker:
# Connection closed
break
# Handle the request
response = await self.handle_request(message)
# Send response
await DirtyProtocol.write_message_async(writer, response)
# Handle the request - pass writer for streaming support
await self.handle_request(message, writer)
except Exception as e:
self.log.error("Connection error: %s", e)
finally:
@ -310,24 +310,28 @@ class DirtyWorker:
except Exception:
pass
async def handle_request(self, message):
async def handle_request(self, message, writer):
"""
Handle a single request message.
Supports both regular (non-streaming) and streaming responses.
For streaming, detects if the result is a generator and sends
chunk messages followed by an end message.
Args:
message: Request dict from protocol
Returns:
Response dict to send back
writer: StreamWriter for sending responses
"""
request_id = message.get("id", str(uuid.uuid4()))
msg_type = message.get("type")
if msg_type != DirtyProtocol.MSG_TYPE_REQUEST:
return make_error_response(
response = make_error_response(
request_id,
DirtyWorkerError(f"Unknown message type: {msg_type}")
)
await DirtyProtocol.write_message_async(writer, response)
return
app_path = message.get("app_path")
action = message.get("action")
@ -339,16 +343,107 @@ class DirtyWorker:
try:
result = await self.execute(app_path, action, args, kwargs)
return make_response(request_id, result)
# Check if result is a generator (streaming)
if inspect.isgenerator(result):
await self._stream_sync_generator(request_id, result, writer)
elif inspect.isasyncgen(result):
await self._stream_async_generator(request_id, result, writer)
else:
# Regular non-streaming response
response = make_response(request_id, result)
await DirtyProtocol.write_message_async(writer, response)
except Exception as e:
tb = traceback.format_exc()
self.log.error("Error executing %s.%s: %s\n%s",
app_path, action, e, tb)
return make_error_response(
response = make_error_response(
request_id,
DirtyAppError(str(e), app_path=app_path, action=action,
traceback=tb)
)
await DirtyProtocol.write_message_async(writer, response)
async def _stream_sync_generator(self, request_id, gen, writer):
    """
    Send every item of a synchronous generator as a chunk message.

    Items are pulled through the executor, one chunk message is written
    per item, and an end message follows once the generator is exhausted.
    If anything raises along the way, the error is logged and an error
    message is written instead. The generator is always closed.

    Args:
        request_id: Request ID for the messages
        gen: Sync generator to iterate
        writer: StreamWriter for sending messages
    """
    # Marker returned when the generator is exhausted
    # (StopIteration cannot be raised into a Future in Python 3.7+)
    done = object()

    def advance():
        return next(gen, done)

    try:
        loop = asyncio.get_running_loop()

        # Run next() in executor to avoid blocking event loop
        item = await loop.run_in_executor(self._executor, advance)
        while item is not done:
            chunk_message = make_chunk_message(request_id, item)
            await DirtyProtocol.write_message_async(writer, chunk_message)

            # Update heartbeat during long streams
            self.notify()

            item = await loop.run_in_executor(self._executor, advance)

        end_message = make_end_message(request_id)
        await DirtyProtocol.write_message_async(writer, end_message)
    except Exception as exc:
        # Error during streaming - send error message
        trace = traceback.format_exc()
        self.log.error("Error during streaming: %s\n%s", exc, trace)
        failure = DirtyAppError(str(exc), traceback=trace)
        await DirtyProtocol.write_message_async(
            writer, make_error_response(request_id, failure)
        )
    finally:
        gen.close()
async def _stream_async_generator(self, request_id, gen, writer):
    """
    Send every item of an asynchronous generator as a chunk message.

    One chunk message is written per item, and an end message follows
    once the generator is exhausted. If anything raises along the way,
    the error is logged and an error message is written instead. The
    generator is always closed.

    Args:
        request_id: Request ID for the messages
        gen: Async generator to iterate
        writer: StreamWriter for sending messages
    """
    try:
        async for item in gen:
            chunk_message = make_chunk_message(request_id, item)
            await DirtyProtocol.write_message_async(writer, chunk_message)

            # Update heartbeat during long streams
            self.notify()

        end_message = make_end_message(request_id)
        await DirtyProtocol.write_message_async(writer, end_message)
    except Exception as exc:
        # Error during streaming - send error message
        trace = traceback.format_exc()
        self.log.error("Error during streaming: %s\n%s", exc, trace)
        failure = DirtyAppError(str(exc), traceback=trace)
        await DirtyProtocol.write_message_async(
            writer, make_error_response(request_id, failure)
        )
    finally:
        await gen.aclose()
async def execute(self, app_path, action, args, kwargs):
"""

5
tests/dirty/__init__.py Normal file
View File

@ -0,0 +1,5 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""Tests for dirty worker streaming functionality."""

View File

@ -0,0 +1,319 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""Tests for dirty arbiter streaming functionality."""
import asyncio
import struct
from unittest import mock
import pytest
from gunicorn.dirty.protocol import (
DirtyProtocol,
make_request,
make_response,
make_chunk_message,
make_end_message,
make_error_response,
)
from gunicorn.dirty.arbiter import DirtyArbiter
from gunicorn.dirty.errors import DirtyError
class MockStreamWriter:
    """
    Stand-in for a StreamWriter that records decoded messages.

    Written bytes are buffered; `drain()` splits the buffer into
    length-prefixed frames and appends each decoded frame to `messages`.
    """

    def __init__(self):
        self.messages = []
        self._buffer = b""
        self.closed = False

    def write(self, data):
        self._buffer = self._buffer + data

    async def drain(self):
        header_size = DirtyProtocol.HEADER_SIZE
        while len(self._buffer) >= header_size:
            length = struct.unpack(
                DirtyProtocol.HEADER_FORMAT,
                self._buffer[:header_size]
            )[0]
            frame_end = header_size + length
            if len(self._buffer) < frame_end:
                # Frame is not complete yet; keep the bytes buffered
                break
            payload = self._buffer[header_size:frame_end]
            self._buffer = self._buffer[frame_end:]
            self.messages.append(DirtyProtocol.decode(payload))

    def close(self):
        self.closed = True

    async def wait_closed(self):
        pass

    def get_extra_info(self, name):
        return None
class MockStreamReader:
    """
    Stand-in for a StreamReader that replays pre-encoded messages.

    All messages are encoded up front into one byte string that
    `readexactly()` serves sequentially.
    """

    def __init__(self, messages):
        self._data = b''.join(DirtyProtocol.encode(msg) for msg in messages)
        self._pos = 0

    async def readexactly(self, n):
        stop = self._pos + n
        if stop > len(self._data):
            # Not enough bytes left: report what remains, as a real reader does
            raise asyncio.IncompleteReadError(self._data[self._pos:], n)
        start, self._pos = self._pos, stop
        return self._data[start:stop]
def create_arbiter():
    """
    Build a DirtyArbiter for tests, backed by mock config and log objects.

    The arbiter is marked alive and given one fake worker (pid 1234) with
    a socket path, so requests can be routed without starting processes.
    """
    cfg = mock.Mock(
        dirty_timeout=30,
        dirty_workers=1,
        dirty_apps=[],
        dirty_graceful_timeout=30,
        on_dirty_starting=mock.Mock(),
        dirty_post_fork=mock.Mock(),
        dirty_worker_exit=mock.Mock(),
    )
    log = mock.Mock()

    # mkdtemp is patched so that constructing the arbiter cannot create a
    # real temporary directory
    with mock.patch('tempfile.mkdtemp', return_value='/tmp/test-dirty'):
        arbiter = DirtyArbiter(cfg, log)

    arbiter.alive = True
    arbiter.workers = {1234: mock.Mock()}  # Fake worker
    arbiter.worker_sockets = {1234: '/tmp/worker.sock'}
    return arbiter
class TestArbiterStreamingForwarding:
    """Tests for arbiter streaming message forwarding."""

    @staticmethod
    async def _relay(arbiter, worker_reader, action):
        """Run one request on worker 1234 and return the client's messages."""
        async def fake_connection(pid):
            return worker_reader, MockStreamWriter()

        arbiter._get_worker_connection = fake_connection
        client_writer = MockStreamWriter()
        request = make_request("req-123", "test:App", action)
        await arbiter._execute_on_worker(1234, request, client_writer)
        return client_writer.messages

    @pytest.mark.asyncio
    async def test_forwards_chunk_messages(self):
        """Test that arbiter forwards chunk messages to client."""
        # Worker connection that returns two chunks and the end marker
        worker_reader = MockStreamReader([
            make_chunk_message("req-123", "Hello"),
            make_chunk_message("req-123", " World"),
            make_end_message("req-123"),
        ])

        forwarded = await self._relay(
            create_arbiter(), worker_reader, "generate"
        )

        # Should have forwarded all messages
        assert len(forwarded) == 3
        first, second, last = forwarded
        assert first["type"] == "chunk"
        assert first["data"] == "Hello"
        assert second["type"] == "chunk"
        assert second["data"] == " World"
        assert last["type"] == "end"

    @pytest.mark.asyncio
    async def test_forwards_regular_response(self):
        """Test that arbiter forwards regular response to client."""
        worker_reader = MockStreamReader(
            [make_response("req-123", {"result": 42})]
        )

        forwarded = await self._relay(
            create_arbiter(), worker_reader, "compute"
        )

        assert len(forwarded) == 1
        assert forwarded[0]["type"] == "response"
        assert forwarded[0]["result"] == {"result": 42}

    @pytest.mark.asyncio
    async def test_forwards_error_mid_stream(self):
        """Test that arbiter forwards error during streaming."""
        worker_reader = MockStreamReader([
            make_chunk_message("req-123", "First"),
            make_error_response("req-123", DirtyError("Something broke")),
        ])

        forwarded = await self._relay(
            create_arbiter(), worker_reader, "generate"
        )

        assert len(forwarded) == 2
        assert forwarded[0]["type"] == "chunk"
        assert forwarded[1]["type"] == "error"

    @pytest.mark.asyncio
    async def test_timeout_during_streaming(self):
        """Test that timeout during streaming sends error."""
        # Reader that times out
        class StalledReader:
            async def readexactly(self, n):
                await asyncio.sleep(1)  # Longer than timeout

        arbiter = create_arbiter()
        arbiter.cfg.dirty_timeout = 0.01  # Very short timeout

        forwarded = await self._relay(arbiter, StalledReader(), "generate")

        assert len(forwarded) == 1
        assert forwarded[0]["type"] == "error"
        assert "timeout" in forwarded[0]["error"]["message"].lower()
class TestArbiterRouteRequestStreaming:
    """Tests for route_request with streaming support."""

    @pytest.mark.asyncio
    async def test_route_request_no_workers(self):
        """Test route_request when no workers available."""
        arbiter = create_arbiter()
        arbiter.workers = {}  # No workers
        sink = MockStreamWriter()

        await arbiter.route_request(
            make_request("req-123", "test:App", "generate"), sink
        )

        assert len(sink.messages) == 1
        reply = sink.messages[0]
        assert reply["type"] == "error"
        assert "No dirty workers" in reply["error"]["message"]

    @pytest.mark.asyncio
    async def test_route_request_starts_consumer(self):
        """Test that route_request starts consumer if needed."""
        arbiter = create_arbiter()

        # Replace _execute_on_worker with a stub that answers immediately
        async def instant_execute(pid, request, client_writer):
            await DirtyProtocol.write_message_async(
                client_writer, make_response("req-123", "result")
            )

        arbiter._execute_on_worker = instant_execute
        sink = MockStreamWriter()
        request = make_request("req-123", "test:App", "compute")

        # No queue exists for the worker before the first request
        assert 1234 not in arbiter.worker_queues

        await arbiter.route_request(request, sink)

        # Consumer should have been started
        assert 1234 in arbiter.worker_queues
        assert 1234 in arbiter.worker_consumers

        # Clean up
        arbiter.worker_consumers[1234].cancel()
class TestArbiterStreamingManyChunks:
    """Tests for streaming with many chunks."""

    @pytest.mark.asyncio
    async def test_forwards_many_chunks(self):
        """Test that arbiter forwards many chunks correctly."""
        arbiter = create_arbiter()
        sink = MockStreamWriter()

        # Generate 50 chunks + end
        stream = [
            make_chunk_message("req-123", f"chunk-{i}") for i in range(50)
        ]
        stream.append(make_end_message("req-123"))
        worker_reader = MockStreamReader(stream)

        async def fake_connection(pid):
            return worker_reader, MockStreamWriter()

        arbiter._get_worker_connection = fake_connection

        await arbiter._execute_on_worker(
            1234, make_request("req-123", "test:App", "generate"), sink
        )

        forwarded = sink.messages
        assert len(forwarded) == 51
        assert forwarded[0]["data"] == "chunk-0"
        assert forwarded[49]["data"] == "chunk-49"
        assert forwarded[50]["type"] == "end"
class TestArbiterBackwardCompatibility:
    """Tests for backward compatibility with non-streaming."""

    @staticmethod
    async def _relay(worker_message, action):
        """Send one request whose worker replies with a single message."""
        arbiter = create_arbiter()
        worker_reader = MockStreamReader([worker_message])

        async def fake_connection(pid):
            return worker_reader, MockStreamWriter()

        arbiter._get_worker_connection = fake_connection
        client_writer = MockStreamWriter()
        request = make_request("req-123", "test:App", action)
        await arbiter._execute_on_worker(1234, request, client_writer)
        return client_writer.messages

    @pytest.mark.asyncio
    async def test_handles_regular_response(self):
        """Test that regular (non-streaming) responses still work."""
        forwarded = await self._relay(
            make_response("req-123", [1, 2, 3, 4, 5]), "get_list"
        )

        assert len(forwarded) == 1
        assert forwarded[0]["type"] == "response"
        assert forwarded[0]["result"] == [1, 2, 3, 4, 5]

    @pytest.mark.asyncio
    async def test_handles_error_response(self):
        """Test that error responses still work."""
        forwarded = await self._relay(
            make_error_response("req-123", DirtyError("Something failed")),
            "fail",
        )

        assert len(forwarded) == 1
        assert forwarded[0]["type"] == "error"

View File

@ -0,0 +1,236 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""Tests for dirty client sync streaming functionality."""
import socket
import struct
import pytest
from unittest import mock
from gunicorn.dirty.protocol import (
DirtyProtocol,
make_chunk_message,
make_end_message,
make_response,
make_error_response,
)
from gunicorn.dirty.client import DirtyClient, DirtyStreamIterator
from gunicorn.dirty.errors import DirtyError, DirtyConnectionError
class MockSocket:
    """
    Stand-in for a socket that replays pre-encoded messages.

    Bytes passed to `sendall()` are recorded in `_sent`; `recv()` serves
    the encoded messages sequentially and returns b'' once they run out.
    """

    def __init__(self, messages):
        self._data = b''.join(DirtyProtocol.encode(msg) for msg in messages)
        self._pos = 0
        self._sent = []
        self.closed = False
        self._timeout = None

    def sendall(self, data):
        self._sent.append(data)

    def recv(self, n, flags=0):
        start = self._pos
        if start >= len(self._data):
            return b''
        # Serve at most n bytes, never reading past the recorded data
        self._pos = min(start + n, len(self._data))
        return self._data[start:self._pos]

    def settimeout(self, timeout):
        self._timeout = timeout

    def close(self):
        self.closed = True
def create_client_with_mock_socket(messages):
    """Return a DirtyClient whose socket is a MockSocket replaying messages."""
    fake_socket = MockSocket(messages)
    client = DirtyClient("/tmp/test.sock")
    client._sock = fake_socket
    return client
class TestDirtyStreamIterator:
    """Tests for DirtyStreamIterator."""

    def test_stream_returns_iterator(self):
        """Test that stream() returns an iterator."""
        iterator = DirtyClient("/tmp/test.sock").stream("test:App", "generate")
        assert isinstance(iterator, DirtyStreamIterator)

    def test_stream_iterator_yields_chunks(self):
        """Test that stream iterator yields chunks correctly."""
        client = create_client_with_mock_socket([
            make_chunk_message("req-123", "Hello"),
            make_chunk_message("req-123", " "),
            make_chunk_message("req-123", "World"),
            make_end_message("req-123"),
        ])

        received = list(client.stream("test:App", "generate"))

        assert received == ["Hello", " ", "World"]

    def test_stream_iterator_yields_complex_chunks(self):
        """Test that stream iterator yields complex data types."""
        client = create_client_with_mock_socket([
            make_chunk_message("req-123", {"token": "Hello", "score": 0.9}),
            make_chunk_message("req-123", {"token": "World", "score": 0.8}),
            make_end_message("req-123"),
        ])

        received = list(client.stream("test:App", "generate"))

        assert len(received) == 2
        assert received[0]["token"] == "Hello"
        assert received[1]["token"] == "World"

    def test_stream_iterator_handles_error(self):
        """Test that stream iterator raises on error message."""
        client = create_client_with_mock_socket([
            make_chunk_message("req-123", "First"),
            make_error_response("req-123", DirtyError("Something broke")),
        ])
        iterator = client.stream("test:App", "generate")

        # First chunk should work
        assert next(iterator) == "First"

        # Second should raise error
        with pytest.raises(DirtyError) as exc_info:
            next(iterator)
        assert "Something broke" in str(exc_info.value)

    def test_stream_iterator_empty_stream(self):
        """Test that empty stream (just end) works."""
        client = create_client_with_mock_socket([make_end_message("req-123")])

        received = list(client.stream("test:App", "generate"))

        assert received == []

    def test_stream_iterator_stops_after_exhausted(self):
        """Test that iterator stays exhausted after StopIteration."""
        client = create_client_with_mock_socket([
            make_chunk_message("req-123", "Only"),
            make_end_message("req-123"),
        ])
        iterator = client.stream("test:App", "generate")

        # Get the chunk
        assert next(iterator) == "Only"

        # Should stop
        with pytest.raises(StopIteration):
            next(iterator)

        # Should stay stopped
        with pytest.raises(StopIteration):
            next(iterator)

    def test_stream_iterator_with_for_loop(self):
        """Test stream iterator works in for loop."""
        client = create_client_with_mock_socket([
            make_chunk_message("req-123", "a"),
            make_chunk_message("req-123", "b"),
            make_chunk_message("req-123", "c"),
            make_end_message("req-123"),
        ])

        joined = ""
        for piece in client.stream("test:App", "generate"):
            joined += piece

        assert joined == "abc"

    def test_stream_sends_request_on_first_iteration(self):
        """Test that request is sent on first next() call."""
        client = create_client_with_mock_socket([
            make_chunk_message("req-123", "data"),
            make_end_message("req-123"),
        ])
        iterator = client.stream("test:App", "generate", "prompt_arg")

        # Before iteration, no request sent
        assert len(client._sock._sent) == 0

        # First iteration sends request
        next(iterator)
        assert len(client._sock._sent) == 1

        # Decode sent request: length header first, then the payload
        frame = client._sock._sent[0]
        header_size = DirtyProtocol.HEADER_SIZE
        length = struct.unpack(
            DirtyProtocol.HEADER_FORMAT, frame[:header_size]
        )[0]
        request = DirtyProtocol.decode(frame[header_size:header_size + length])

        assert request["type"] == "request"
        assert request["app_path"] == "test:App"
        assert request["action"] == "generate"
        assert request["args"] == ["prompt_arg"]
class TestDirtyStreamIteratorEdgeCases:
    """Edge cases for streaming."""

    def test_stream_many_chunks(self):
        """Test streaming with many chunks."""
        stream = [
            make_chunk_message("req-123", f"chunk-{i}") for i in range(100)
        ]
        stream.append(make_end_message("req-123"))
        client = create_client_with_mock_socket(stream)

        received = list(client.stream("test:App", "generate"))

        assert len(received) == 100
        assert received[0] == "chunk-0"
        assert received[99] == "chunk-99"

    def test_stream_with_kwargs(self):
        """Test streaming with keyword arguments."""
        client = create_client_with_mock_socket([
            make_chunk_message("req-123", "data"),
            make_end_message("req-123"),
        ])

        # Use kwargs
        list(client.stream("test:App", "generate", "arg1", key="value"))

        # Decode the sent request: length header first, then the payload
        frame = client._sock._sent[0]
        header_size = DirtyProtocol.HEADER_SIZE
        length = struct.unpack(
            DirtyProtocol.HEADER_FORMAT, frame[:header_size]
        )[0]
        request = DirtyProtocol.decode(frame[header_size:header_size + length])

        # Check the sent request includes kwargs
        assert request["args"] == ["arg1"]
        assert request["kwargs"] == {"key": "value"}

View File

@ -0,0 +1,267 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""Tests for dirty client async streaming functionality."""
import asyncio
import struct
import pytest
from gunicorn.dirty.protocol import (
DirtyProtocol,
make_chunk_message,
make_end_message,
make_error_response,
)
from gunicorn.dirty.client import DirtyClient, DirtyAsyncStreamIterator
from gunicorn.dirty.errors import DirtyError, DirtyTimeoutError
class MockAsyncReader:
    """
    Stand-in for an async reader that replays pre-encoded messages.

    All messages are encoded up front into one byte string that
    `readexactly()` serves sequentially.
    """

    def __init__(self, messages):
        self._data = b''.join(DirtyProtocol.encode(msg) for msg in messages)
        self._pos = 0

    async def readexactly(self, n):
        stop = self._pos + n
        if stop > len(self._data):
            # Not enough bytes left: report what remains, as a real reader does
            raise asyncio.IncompleteReadError(self._data[self._pos:], n)
        start, self._pos = self._pos, stop
        return self._data[start:stop]
class MockAsyncWriter:
    """
    Stand-in for an async writer that records what was written.

    Each `write()` call appends its argument to `_sent`; `drain()` and
    `wait_closed()` complete immediately.
    """

    def __init__(self):
        self.closed = False
        self._sent = []

    def write(self, data):
        self._sent += [data]

    async def drain(self):
        return None

    def close(self):
        self.closed = True

    async def wait_closed(self):
        return None
def create_async_client_with_mocks(messages):
    """Return a DirtyClient wired to a MockAsyncReader and MockAsyncWriter."""
    reader = MockAsyncReader(messages)
    writer = MockAsyncWriter()
    client = DirtyClient("/tmp/test.sock")
    client._reader, client._writer = reader, writer
    return client
class TestDirtyAsyncStreamIterator:
    """Behaviour of the iterator returned by ``DirtyClient.stream_async``."""

    def test_stream_async_returns_async_iterator(self):
        """stream_async() hands back a DirtyAsyncStreamIterator."""
        iterator = DirtyClient("/tmp/test.sock").stream_async(
            "test:App", "generate"
        )
        assert isinstance(iterator, DirtyAsyncStreamIterator)

    @pytest.mark.asyncio
    async def test_async_stream_yields_chunks(self):
        """Chunk payloads are yielded in order until the end message."""
        client = create_async_client_with_mocks([
            make_chunk_message("req-123", "Hello"),
            make_chunk_message("req-123", " "),
            make_chunk_message("req-123", "World"),
            make_end_message("req-123"),
        ])

        received = [
            item async for item in client.stream_async("test:App", "generate")
        ]

        assert received == ["Hello", " ", "World"]

    @pytest.mark.asyncio
    async def test_async_stream_yields_complex_chunks(self):
        """Dictionary payloads come through the iterator intact."""
        client = create_async_client_with_mocks([
            make_chunk_message("req-123", {"token": "Hello", "score": 0.9}),
            make_chunk_message("req-123", {"token": "World", "score": 0.8}),
            make_end_message("req-123"),
        ])

        received = [
            item async for item in client.stream_async("test:App", "generate")
        ]

        assert len(received) == 2
        first, second = received
        assert first["token"] == "Hello"
        assert second["token"] == "World"

    @pytest.mark.asyncio
    async def test_async_stream_handles_error(self):
        """An error message in the middle of the stream raises DirtyError."""
        client = create_async_client_with_mocks([
            make_chunk_message("req-123", "First"),
            make_error_response("req-123", DirtyError("Something broke")),
        ])
        iterator = client.stream_async("test:App", "generate")

        # The chunk that precedes the error is delivered normally.
        assert await iterator.__anext__() == "First"

        # The next message is the error and must surface as an exception.
        with pytest.raises(DirtyError) as exc_info:
            await iterator.__anext__()
        assert "Something broke" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_async_stream_empty_stream(self):
        """A stream made of only an end message yields nothing."""
        client = create_async_client_with_mocks([make_end_message("req-123")])

        received = [
            item async for item in client.stream_async("test:App", "generate")
        ]

        assert received == []

    @pytest.mark.asyncio
    async def test_async_stream_stops_after_exhausted(self):
        """Once exhausted, the iterator keeps raising StopAsyncIteration."""
        client = create_async_client_with_mocks([
            make_chunk_message("req-123", "Only"),
            make_end_message("req-123"),
        ])
        iterator = client.stream_async("test:App", "generate")

        assert await iterator.__anext__() == "Only"

        # First call hits the end message, second proves it stays stopped.
        for _ in range(2):
            with pytest.raises(StopAsyncIteration):
                await iterator.__anext__()

    @pytest.mark.asyncio
    async def test_async_stream_sends_request_on_first_iteration(self):
        """The request is written lazily, on the first __anext__() call."""
        client = create_async_client_with_mocks([
            make_chunk_message("req-123", "data"),
            make_end_message("req-123"),
        ])
        iterator = client.stream_async("test:App", "generate", "prompt_arg")
        sent = client._writer._sent

        # Creating the iterator alone must not write anything.
        assert len(sent) == 0

        await iterator.__anext__()
        assert len(sent) == 1

        # Decode the frame: length header first, then the payload.
        header_size = DirtyProtocol.HEADER_SIZE
        frame = sent[0]
        payload_len = struct.unpack(
            DirtyProtocol.HEADER_FORMAT, frame[:header_size]
        )[0]
        request = DirtyProtocol.decode(
            frame[header_size:header_size + payload_len]
        )

        assert request["type"] == "request"
        assert request["app_path"] == "test:App"
        assert request["action"] == "generate"
        assert request["args"] == ["prompt_arg"]
class TestDirtyAsyncStreamIteratorEdgeCases:
    """Edge cases for async streaming."""

    @pytest.mark.asyncio
    async def test_async_stream_many_chunks(self):
        """Test async streaming with many chunks."""
        messages = [
            make_chunk_message("req-123", f"chunk-{i}") for i in range(100)
        ]
        messages.append(make_end_message("req-123"))
        client = create_async_client_with_mocks(messages)

        chunks = [
            chunk async for chunk in client.stream_async("test:App", "generate")
        ]

        assert len(chunks) == 100
        assert chunks[0] == "chunk-0"
        assert chunks[99] == "chunk-99"

    @pytest.mark.asyncio
    async def test_async_stream_with_kwargs(self):
        """Test async streaming with keyword arguments."""
        messages = [
            make_chunk_message("req-123", "data"),
            make_end_message("req-123"),
        ]
        client = create_async_client_with_mocks(messages)

        chunks = [
            chunk
            async for chunk in client.stream_async(
                "test:App", "generate", "arg1", key="value"
            )
        ]

        # The collected chunks were previously never checked; verify that
        # passing kwargs does not disturb the streamed payload.
        assert chunks == ["data"]

        # Decode the single frame that was written: header, then payload.
        sent_data = client._writer._sent[0]
        header_size = DirtyProtocol.HEADER_SIZE
        length = struct.unpack(
            DirtyProtocol.HEADER_FORMAT, sent_data[:header_size]
        )[0]
        request = DirtyProtocol.decode(
            sent_data[header_size:header_size + length]
        )

        assert request["args"] == ["arg1"]
        assert request["kwargs"] == {"key": "value"}
class TestDirtyAsyncStreamTimeout:
    """Timeout handling while reading an async stream."""

    @pytest.mark.asyncio
    async def test_async_stream_timeout(self):
        """A read slower than the client timeout raises DirtyTimeoutError."""

        class SlowReader:
            """Reader whose reads take far longer than the client timeout."""

            async def readexactly(self, n):
                # 1 s sleep versus the 0.01 s timeout configured below.
                await asyncio.sleep(1)

        client = DirtyClient("/tmp/test.sock", timeout=0.01)
        client._reader = SlowReader()
        client._writer = MockAsyncWriter()

        stream = client.stream_async("test:App", "generate")
        with pytest.raises(DirtyTimeoutError):
            await stream.__anext__()

View File

@ -0,0 +1,469 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""Integration tests for dirty streaming functionality.
These tests verify the full streaming pipeline:
client -> arbiter -> worker -> generator -> chunks -> client
"""
import asyncio
import os
import struct
import tempfile
import pytest
from unittest import mock
from gunicorn.config import Config
from gunicorn.dirty.protocol import (
DirtyProtocol,
make_request,
make_chunk_message,
make_end_message,
make_response,
make_error_response,
)
from gunicorn.dirty.worker import DirtyWorker
from gunicorn.dirty.arbiter import DirtyArbiter
from gunicorn.dirty.client import DirtyClient
from gunicorn.dirty.errors import DirtyError
class MockLog:
    """Logger double that stores ``(level, text)`` tuples in ``messages``."""

    def __init__(self):
        self.messages = []

    def _record(self, level, msg, args):
        """Append the formatted message; %-format only when args are given."""
        text = msg % args if args else msg
        self.messages.append((level, text))

    def debug(self, msg, *args):
        self._record("debug", msg, args)

    def info(self, msg, *args):
        self._record("info", msg, args)

    def warning(self, msg, *args):
        self._record("warning", msg, args)

    def error(self, msg, *args):
        self._record("error", msg, args)

    def close_on_exec(self):
        """No-op placeholder."""
        return None

    def reopen_files(self):
        """No-op placeholder."""
        return None
class MockStreamWriter:
    """StreamWriter double that decodes what is written to it.

    ``write`` only buffers bytes; ``drain`` splits the buffer into
    length-prefixed frames and appends each decoded payload to ``messages``.
    """

    def __init__(self):
        self.closed = False
        self._buffer = b""
        self.messages = []

    def write(self, data):
        """Append *data* to the pending byte buffer."""
        self._buffer = self._buffer + data

    async def drain(self):
        """Decode every complete frame currently held in the buffer."""
        header_size = DirtyProtocol.HEADER_SIZE
        while len(self._buffer) >= header_size:
            payload_len = struct.unpack(
                DirtyProtocol.HEADER_FORMAT, self._buffer[:header_size]
            )[0]
            frame_end = header_size + payload_len
            if len(self._buffer) < frame_end:
                # Partial frame: leave it buffered for a later drain().
                break
            payload = self._buffer[header_size:frame_end]
            self._buffer = self._buffer[frame_end:]
            self.messages.append(DirtyProtocol.decode(payload))

    def close(self):
        """Mark the writer as closed."""
        self.closed = True

    async def wait_closed(self):
        """No-op."""
        return None

    def get_extra_info(self, name):
        """Always report that no extra transport info is available."""
        return None
class MockStreamReader:
    """StreamReader double that replays pre-encoded messages.

    All *messages* are encoded with ``DirtyProtocol.encode`` and joined into
    one byte string; ``readexactly`` then serves consecutive slices of it.
    """

    def __init__(self, messages):
        self._data = b''.join(DirtyProtocol.encode(msg) for msg in messages)
        self._pos = 0

    async def readexactly(self, n):
        """Return the next *n* bytes.

        Raises:
            asyncio.IncompleteReadError: fewer than *n* bytes remain; the
                exception carries the remaining bytes as its partial data.
        """
        stop = self._pos + n
        if stop > len(self._data):
            raise asyncio.IncompleteReadError(self._data[self._pos:], n)
        piece = self._data[self._pos:stop]
        self._pos = stop
        return piece
class TestStreamingEndToEnd:
    """End-to-end streaming tests using mocked components."""

    @staticmethod
    async def _forward_through_arbiter(worker_messages, request):
        """Run *request* through an arbiter backed by a mocked worker.

        The mocked worker connection replays *worker_messages*. Returns the
        list of decoded messages the arbiter wrote to the client writer.
        The arbiter's ``_cleanup_sync()`` runs in ``finally`` so it is no
        longer skipped when forwarding raises or a later assertion fails.
        """
        cfg = Config()
        cfg.set("dirty_timeout", 30)
        arbiter = DirtyArbiter(cfg=cfg, log=MockLog())
        try:
            arbiter.alive = True
            arbiter.workers = {1234: mock.Mock()}
            arbiter.worker_sockets = {1234: '/tmp/worker.sock'}

            # Every connection request gets the same pre-loaded reader.
            mock_reader = MockStreamReader(worker_messages)

            async def mock_get_connection(pid):
                return mock_reader, MockStreamWriter()

            arbiter._get_worker_connection = mock_get_connection

            client_writer = MockStreamWriter()
            await arbiter._execute_on_worker(1234, request, client_writer)
            return client_writer.messages
        finally:
            arbiter._cleanup_sync()

    @pytest.mark.asyncio
    async def test_sync_generator_end_to_end(self):
        """Test complete flow: sync generator -> worker -> arbiter -> client."""
        # Simulate what a worker would produce for a sync generator.
        worker_messages = [
            make_chunk_message("req-123", "Hello"),
            make_chunk_message("req-123", " "),
            make_chunk_message("req-123", "World"),
            make_end_message("req-123"),
        ]
        request = make_request("req-123", "test:App", "generate")

        forwarded = await self._forward_through_arbiter(
            worker_messages, request
        )

        # Verify all messages were forwarded.
        assert len(forwarded) == 4
        assert forwarded[0]["type"] == "chunk"
        assert forwarded[0]["data"] == "Hello"
        assert forwarded[1]["data"] == " "
        assert forwarded[2]["data"] == "World"
        assert forwarded[3]["type"] == "end"

    @pytest.mark.asyncio
    async def test_async_generator_end_to_end(self):
        """Test complete flow: async generator -> worker -> arbiter -> client."""
        worker_messages = [
            make_chunk_message("req-456", "Async"),
            make_chunk_message("req-456", " "),
            make_chunk_message("req-456", "Stream"),
            make_end_message("req-456"),
        ]
        request = make_request("req-456", "test:App", "async_generate")

        forwarded = await self._forward_through_arbiter(
            worker_messages, request
        )

        assert len(forwarded) == 4
        assert forwarded[0]["data"] == "Async"
        assert forwarded[3]["type"] == "end"
class TestStreamingErrorHandling:
    """Tests for error handling during streaming."""

    @pytest.mark.asyncio
    async def test_error_mid_stream(self):
        """Test that errors during streaming are properly forwarded."""
        worker_messages = [
            make_chunk_message("req-789", "First"),
            make_chunk_message("req-789", "Second"),
            make_error_response("req-789", DirtyError("Stream failed")),
        ]

        cfg = Config()
        cfg.set("dirty_timeout", 30)
        log = MockLog()
        arbiter = DirtyArbiter(cfg=cfg, log=log)
        # try/finally: the cleanup call used to be skipped whenever the
        # forwarding or one of the assertions below failed.
        try:
            arbiter.alive = True
            arbiter.workers = {1234: mock.Mock()}
            arbiter.worker_sockets = {1234: '/tmp/worker.sock'}

            mock_reader = MockStreamReader(worker_messages)

            async def mock_get_connection(pid):
                return mock_reader, MockStreamWriter()

            arbiter._get_worker_connection = mock_get_connection

            client_writer = MockStreamWriter()
            request = make_request(
                "req-789", "test:App", "generate_with_error"
            )
            await arbiter._execute_on_worker(1234, request, client_writer)

            # Should have 2 chunks + 1 error.
            forwarded = client_writer.messages
            assert len(forwarded) == 3
            assert forwarded[0]["type"] == "chunk"
            assert forwarded[1]["type"] == "chunk"
            assert forwarded[2]["type"] == "error"
            assert "Stream failed" in forwarded[2]["error"]["message"]
        finally:
            arbiter._cleanup_sync()
class TestStreamingBackwardCompatibility:
    """Tests for backward compatibility with non-streaming responses."""

    @staticmethod
    async def _forward_through_arbiter(worker_messages, request):
        """Run *request* through an arbiter backed by a mocked worker.

        The mocked worker connection replays *worker_messages*. Returns the
        list of decoded messages the arbiter wrote to the client writer.
        The arbiter's ``_cleanup_sync()`` runs in ``finally`` so it is no
        longer skipped when forwarding raises or a later assertion fails.
        """
        cfg = Config()
        cfg.set("dirty_timeout", 30)
        arbiter = DirtyArbiter(cfg=cfg, log=MockLog())
        try:
            arbiter.alive = True
            arbiter.workers = {1234: mock.Mock()}
            arbiter.worker_sockets = {1234: '/tmp/worker.sock'}

            mock_reader = MockStreamReader(worker_messages)

            async def mock_get_connection(pid):
                return mock_reader, MockStreamWriter()

            arbiter._get_worker_connection = mock_get_connection

            client_writer = MockStreamWriter()
            await arbiter._execute_on_worker(1234, request, client_writer)
            return client_writer.messages
        finally:
            arbiter._cleanup_sync()

    @pytest.mark.asyncio
    async def test_non_streaming_response_still_works(self):
        """Test that regular (non-streaming) responses still work."""
        worker_messages = [
            make_response("req-abc", {"result": 42, "data": [1, 2, 3]}),
        ]
        request = make_request("req-abc", "test:App", "compute")

        forwarded = await self._forward_through_arbiter(
            worker_messages, request
        )

        # Should have 1 response.
        assert len(forwarded) == 1
        assert forwarded[0]["type"] == "response"
        assert forwarded[0]["result"]["result"] == 42

    @pytest.mark.asyncio
    async def test_error_response_still_works(self):
        """Test that error responses still work."""
        worker_messages = [
            make_error_response("req-def", DirtyError("Something failed")),
        ]
        request = make_request("req-def", "test:App", "fail")

        forwarded = await self._forward_through_arbiter(
            worker_messages, request
        )

        assert len(forwarded) == 1
        assert forwarded[0]["type"] == "error"
class TestStreamingWorkerIntegration:
    """Worker-level streaming tests driven through ``handle_request``."""

    @staticmethod
    def _build_worker():
        """Create a DirtyWorker with WorkerTmp patched out during construction."""
        cfg = Config()
        cfg.set("dirty_timeout", 300)
        with mock.patch('gunicorn.dirty.worker.WorkerTmp'):
            built = DirtyWorker(
                age=1,
                ppid=os.getpid(),
                app_paths=["test:App"],
                cfg=cfg,
                log=MockLog(),
                socket_path="/tmp/test.sock"
            )
        built.apps = {}
        built._executor = None
        built.tmp = mock.Mock()
        return built

    @pytest.mark.asyncio
    async def test_worker_handles_sync_generator(self):
        """A sync generator returned by execute is streamed chunk by chunk."""
        worker = self._build_worker()
        writer = MockStreamWriter()

        def sync_gen():
            yield from ("one", "two", "three")

        # execute is replaced so that it returns the generator object itself.
        async def mock_execute(app_path, action, args, kwargs):
            return sync_gen()

        request = make_request("req-123", "test:App", "generate")
        with mock.patch.object(worker, 'execute', side_effect=mock_execute):
            await worker.handle_request(request, writer)

        # Three chunks followed by the end marker.
        sent = writer.messages
        assert len(sent) == 4
        for message, expected in zip(sent, ("one", "two", "three")):
            assert message["data"] == expected
        assert sent[3]["type"] == "end"

    @pytest.mark.asyncio
    async def test_worker_handles_async_generator(self):
        """An async generator returned by execute is streamed chunk by chunk."""
        worker = self._build_worker()
        writer = MockStreamWriter()

        async def async_gen():
            for item in ("async_one", "async_two"):
                yield item

        async def mock_execute(app_path, action, args, kwargs):
            return async_gen()

        request = make_request("req-456", "test:App", "async_generate")
        with mock.patch.object(worker, 'execute', side_effect=mock_execute):
            await worker.handle_request(request, writer)

        # Two chunks followed by the end marker.
        sent = writer.messages
        assert len(sent) == 3
        assert sent[0]["data"] == "async_one"
        assert sent[1]["data"] == "async_two"
        assert sent[2]["type"] == "end"
class TestStreamingMixedScenarios:
    """Tests for mixed streaming scenarios."""

    @staticmethod
    async def _forward_through_arbiter(worker_messages, request):
        """Run *request* through an arbiter backed by a mocked worker.

        The mocked worker connection replays *worker_messages*. Returns the
        list of decoded messages the arbiter wrote to the client writer.
        The arbiter's ``_cleanup_sync()`` runs in ``finally`` so it is no
        longer skipped when forwarding raises or a later assertion fails.
        """
        cfg = Config()
        cfg.set("dirty_timeout", 30)
        arbiter = DirtyArbiter(cfg=cfg, log=MockLog())
        try:
            arbiter.alive = True
            arbiter.workers = {1234: mock.Mock()}
            arbiter.worker_sockets = {1234: '/tmp/worker.sock'}

            mock_reader = MockStreamReader(worker_messages)

            async def mock_get_connection(pid):
                return mock_reader, MockStreamWriter()

            arbiter._get_worker_connection = mock_get_connection

            client_writer = MockStreamWriter()
            await arbiter._execute_on_worker(1234, request, client_writer)
            return client_writer.messages
        finally:
            arbiter._cleanup_sync()

    @pytest.mark.asyncio
    async def test_large_stream(self):
        """Test streaming with many chunks."""
        worker_messages = [
            make_chunk_message("req-large", f"chunk-{i}") for i in range(500)
        ]
        worker_messages.append(make_end_message("req-large"))
        request = make_request("req-large", "test:App", "large_stream")

        forwarded = await self._forward_through_arbiter(
            worker_messages, request
        )

        # Should have 500 chunks + 1 end.
        assert len(forwarded) == 501
        assert forwarded[0]["data"] == "chunk-0"
        assert forwarded[499]["data"] == "chunk-499"
        assert forwarded[500]["type"] == "end"

    @pytest.mark.asyncio
    async def test_stream_with_complex_data(self):
        """Test streaming with complex JSON-serializable data."""
        worker_messages = [
            make_chunk_message("req-complex", {
                "token": "Hello",
                "scores": [0.1, 0.2, 0.3],
                "metadata": {"position": 0}
            }),
            make_chunk_message("req-complex", {
                "token": "World",
                "scores": [0.4, 0.5],
                "metadata": {"position": 1}
            }),
            make_end_message("req-complex"),
        ]
        request = make_request("req-complex", "test:App", "complex_stream")

        forwarded = await self._forward_through_arbiter(
            worker_messages, request
        )

        assert len(forwarded) == 3
        assert forwarded[0]["data"]["token"] == "Hello"
        assert forwarded[0]["data"]["scores"] == [0.1, 0.2, 0.3]
        assert forwarded[1]["data"]["metadata"]["position"] == 1

View File

@ -0,0 +1,415 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""Tests for dirty worker streaming functionality."""
import asyncio
import struct
from unittest import mock
import pytest
from gunicorn.dirty.protocol import (
DirtyProtocol,
make_request,
make_chunk_message,
make_end_message,
)
from gunicorn.dirty.worker import DirtyWorker
class FakeStreamWriter:
    """StreamWriter double that decodes what is written to it.

    ``write`` only buffers bytes; ``drain`` splits the buffer into
    length-prefixed frames and appends each decoded payload to ``messages``.
    """

    def __init__(self):
        self._buffer = b""
        self.messages = []

    def write(self, data):
        """Append *data* to the pending byte buffer."""
        self._buffer = self._buffer + data

    async def drain(self):
        """Decode every complete frame currently held in the buffer."""
        header_size = DirtyProtocol.HEADER_SIZE
        while len(self._buffer) >= header_size:
            payload_len = struct.unpack(
                DirtyProtocol.HEADER_FORMAT, self._buffer[:header_size]
            )[0]
            frame_end = header_size + payload_len
            if len(self._buffer) < frame_end:
                # Partial frame: leave it buffered for a later drain().
                break
            payload = self._buffer[header_size:frame_end]
            self._buffer = self._buffer[frame_end:]
            self.messages.append(DirtyProtocol.decode(payload))

    def close(self):
        """No-op; this fake keeps no closed state."""
        return None

    async def wait_closed(self):
        """No-op."""
        return None
def create_worker():
    """Build a DirtyWorker for tests with its collaborators mocked.

    The config and logger are ``mock.Mock`` objects and ``WorkerTmp`` is
    patched while the worker is constructed. Afterwards the worker gets an
    empty ``apps`` mapping, no custom executor and a mocked ``tmp``.
    """
    cfg = mock.Mock(
        dirty_timeout=30,
        dirty_threads=1,
        env=None,
        uid=None,
        gid=None,
        initgroups=False,
        dirty_worker_init=mock.Mock(),
        umask=0o22,
    )

    with mock.patch('gunicorn.dirty.worker.WorkerTmp'):
        test_worker = DirtyWorker(
            age=1,
            ppid=1,
            app_paths=["test:App"],
            cfg=cfg,
            log=mock.Mock(),
            socket_path="/tmp/test.sock"
        )

    test_worker.apps = {}
    # None selects the default executor, which the sync generator tests use.
    test_worker._executor = None
    test_worker.tmp = mock.Mock()
    return test_worker
class TestWorkerSyncGeneratorStreaming:
    """Streaming when ``execute`` hands back a plain (sync) generator."""

    @pytest.mark.asyncio
    async def test_sync_generator_sends_chunks_and_end(self):
        """Each yielded value becomes a chunk, followed by one end message."""
        def generate_tokens():
            yield from ("Hello", " ", "World")

        # execute is replaced so that it returns the generator object itself.
        async def mock_execute(app_path, action, args, kwargs):
            return generate_tokens()

        worker = create_worker()
        writer = FakeStreamWriter()
        request = make_request("req-123", "test:App", "generate")
        with mock.patch.object(worker, 'execute', side_effect=mock_execute):
            await worker.handle_request(request, writer)

        sent = writer.messages
        # Three chunks plus the terminating end message.
        assert len(sent) == 4

        for message, expected in zip(sent, ("Hello", " ", "World")):
            assert message["type"] == "chunk"
            assert message["data"] == expected
        assert sent[0]["id"] == "req-123"

        final = sent[3]
        assert final["type"] == "end"
        assert final["id"] == "req-123"

    @pytest.mark.asyncio
    async def test_sync_generator_error_mid_stream(self):
        """An exception raised mid-stream is reported as an error message."""
        def generate_with_error():
            yield "First"
            raise ValueError("Something went wrong")

        async def mock_execute(app_path, action, args, kwargs):
            return generate_with_error()

        worker = create_worker()
        writer = FakeStreamWriter()
        request = make_request("req-123", "test:App", "generate")
        with mock.patch.object(worker, 'execute', side_effect=mock_execute):
            await worker.handle_request(request, writer)

        # One chunk was delivered before the failure, then the error.
        assert len(writer.messages) == 2
        chunk, failure = writer.messages
        assert chunk["type"] == "chunk"
        assert chunk["data"] == "First"
        assert failure["type"] == "error"
        assert "Something went wrong" in failure["error"]["message"]
class TestWorkerAsyncGeneratorStreaming:
    """Streaming when ``execute`` hands back an async generator."""

    @pytest.mark.asyncio
    async def test_async_generator_sends_chunks_and_end(self):
        """Each yielded value becomes a chunk, followed by one end message."""
        async def async_generate_tokens():
            for token in ("Hello", " ", "World"):
                yield token

        async def mock_execute(app_path, action, args, kwargs):
            return async_generate_tokens()

        worker = create_worker()
        writer = FakeStreamWriter()
        request = make_request("req-123", "test:App", "generate")
        with mock.patch.object(worker, 'execute', side_effect=mock_execute):
            await worker.handle_request(request, writer)

        sent = writer.messages
        # Three chunks plus the terminating end message.
        assert len(sent) == 4

        for message, expected in zip(sent, ("Hello", " ", "World")):
            assert message["type"] == "chunk"
            assert message["data"] == expected
        assert sent[0]["id"] == "req-123"

        final = sent[3]
        assert final["type"] == "end"
        assert final["id"] == "req-123"

    @pytest.mark.asyncio
    async def test_async_generator_error_mid_stream(self):
        """An exception raised mid-stream is reported as an error message."""
        async def async_generate_with_error():
            yield "First"
            raise ValueError("Async error")

        async def mock_execute(app_path, action, args, kwargs):
            return async_generate_with_error()

        worker = create_worker()
        writer = FakeStreamWriter()
        request = make_request("req-123", "test:App", "generate")
        with mock.patch.object(worker, 'execute', side_effect=mock_execute):
            await worker.handle_request(request, writer)

        # One chunk was delivered before the failure, then the error.
        assert len(writer.messages) == 2
        chunk, failure = writer.messages
        assert chunk["type"] == "chunk"
        assert chunk["data"] == "First"
        assert failure["type"] == "error"
        assert "Async error" in failure["error"]["message"]
class TestWorkerNonStreamingBackwardCompat:
    """Non-generator results must still produce a single plain response."""

    @pytest.mark.asyncio
    async def test_non_generator_returns_response(self):
        """A scalar result is sent back as one response message."""
        async def mock_execute(app_path, action, args, kwargs):
            left, right = args[0], args[1]
            return left + right

        worker = create_worker()
        writer = FakeStreamWriter()
        request = make_request("req-123", "test:App", "compute", args=(2, 3))
        with mock.patch.object(worker, 'execute', side_effect=mock_execute):
            await worker.handle_request(request, writer)

        assert len(writer.messages) == 1
        reply = writer.messages[0]
        assert reply["type"] == "response"
        assert reply["id"] == "req-123"
        assert reply["result"] == 5

    @pytest.mark.asyncio
    async def test_list_result_not_treated_as_streaming(self):
        """A list is an ordinary result, not a stream of chunks."""
        async def mock_execute(app_path, action, args, kwargs):
            return [1, 2, 3, 4, 5]

        worker = create_worker()
        writer = FakeStreamWriter()
        request = make_request("req-123", "test:App", "get_list")
        with mock.patch.object(worker, 'execute', side_effect=mock_execute):
            await worker.handle_request(request, writer)

        # Exactly one response, not five chunks.
        assert len(writer.messages) == 1
        reply = writer.messages[0]
        assert reply["type"] == "response"
        assert reply["result"] == [1, 2, 3, 4, 5]

    @pytest.mark.asyncio
    async def test_error_in_execute_sends_error(self):
        """An exception raised by execute is sent back as an error message."""
        async def mock_execute(app_path, action, args, kwargs):
            raise RuntimeError("Failed!")

        worker = create_worker()
        writer = FakeStreamWriter()
        request = make_request("req-123", "test:App", "fail")
        with mock.patch.object(worker, 'execute', side_effect=mock_execute):
            await worker.handle_request(request, writer)

        assert len(writer.messages) == 1
        reply = writer.messages[0]
        assert reply["type"] == "error"
        assert "Failed!" in reply["error"]["message"]

    @pytest.mark.asyncio
    async def test_none_result(self):
        """A None result is sent back as a response whose result is None."""
        async def mock_execute(app_path, action, args, kwargs):
            return None

        worker = create_worker()
        writer = FakeStreamWriter()
        request = make_request("req-123", "test:App", "void")
        with mock.patch.object(worker, 'execute', side_effect=mock_execute):
            await worker.handle_request(request, writer)

        assert len(writer.messages) == 1
        reply = writer.messages[0]
        assert reply["type"] == "response"
        assert reply["result"] is None
class TestWorkerStreamingComplexData:
    """Streaming of structured payloads, empty streams and long streams."""

    @pytest.mark.asyncio
    async def test_streaming_dict_chunks(self):
        """Dictionary chunks reach the writer with their fields intact."""
        async def generate_tokens():
            for token, score in (("Hello", 0.9), ("World", 0.8)):
                yield {"token": token, "score": score}

        async def mock_execute(app_path, action, args, kwargs):
            return generate_tokens()

        worker = create_worker()
        writer = FakeStreamWriter()
        request = make_request("req-123", "test:App", "generate")
        with mock.patch.object(worker, 'execute', side_effect=mock_execute):
            await worker.handle_request(request, writer)

        sent = writer.messages
        # Two chunks plus the end message.
        assert len(sent) == 3
        assert sent[0]["data"]["token"] == "Hello"
        assert sent[0]["data"]["score"] == 0.9
        assert sent[1]["data"]["token"] == "World"

    @pytest.mark.asyncio
    async def test_streaming_empty_generator(self):
        """A generator that yields nothing produces only the end message."""
        async def empty_generate():
            return
            yield  # Unreachable, but makes this an async generator.

        async def mock_execute(app_path, action, args, kwargs):
            return empty_generate()

        worker = create_worker()
        writer = FakeStreamWriter()
        request = make_request("req-123", "test:App", "generate")
        with mock.patch.object(worker, 'execute', side_effect=mock_execute):
            await worker.handle_request(request, writer)

        assert len(writer.messages) == 1
        assert writer.messages[0]["type"] == "end"

    @pytest.mark.asyncio
    async def test_streaming_many_chunks(self):
        """One hundred chunks are all delivered, then the end message."""
        async def generate_many():
            for index in range(100):
                yield f"chunk-{index}"

        async def mock_execute(app_path, action, args, kwargs):
            return generate_many()

        worker = create_worker()
        writer = FakeStreamWriter()
        request = make_request("req-123", "test:App", "generate")
        with mock.patch.object(worker, 'execute', side_effect=mock_execute):
            await worker.handle_request(request, writer)

        sent = writer.messages
        # 100 chunks plus the end message.
        assert len(sent) == 101
        assert sent[0]["data"] == "chunk-0"
        assert sent[99]["data"] == "chunk-99"
        assert sent[100]["type"] == "end"
class TestWorkerStreamingHeartbeat:
    """The worker must keep calling ``notify`` while it streams."""

    @pytest.mark.asyncio
    async def test_heartbeat_updated_during_streaming(self):
        """notify() is invoked at least twice while two chunks are streamed."""
        async def generate_tokens():
            for token in ("Hello", "World"):
                yield token

        async def mock_execute(app_path, action, args, kwargs):
            return generate_tokens()

        worker = create_worker()
        writer = FakeStreamWriter()

        # Wrap the existing notify so every call is counted and still
        # delegated to the original when that is callable.
        calls = 0
        wrapped_notify = worker.notify

        def counting_notify():
            nonlocal calls
            calls += 1
            if callable(wrapped_notify):
                return wrapped_notify()
            return None

        worker.notify = counting_notify

        request = make_request("req-123", "test:App", "generate")
        with mock.patch.object(worker, 'execute', side_effect=mock_execute):
            await worker.handle_request(request, writer)

        # Lower bound only: at least one notification per streamed chunk.
        assert calls >= 2
class TestWorkerMessageTypeValidation:
    """Validation of the ``type`` field of incoming messages."""

    @pytest.mark.asyncio
    async def test_unknown_message_type_sends_error(self):
        """A message with an unrecognised type is answered with an error."""
        worker = create_worker()
        writer = FakeStreamWriter()

        bogus = {"type": "unknown", "id": "req-123"}
        await worker.handle_request(bogus, writer)

        assert len(writer.messages) == 1
        reply = writer.messages[0]
        assert reply["type"] == "error"
        assert "Unknown message type" in reply["error"]["message"]

View File

@ -7,6 +7,7 @@
import asyncio
import os
import signal
import struct
import tempfile
import pytest
@ -16,6 +17,41 @@ from gunicorn.dirty.errors import DirtyError
from gunicorn.dirty.protocol import DirtyProtocol, make_request
class MockStreamWriter:
    """StreamWriter double that decodes what is written to it.

    ``write`` only buffers bytes; ``drain`` splits the buffer into
    length-prefixed frames and appends each decoded payload to ``messages``.
    """

    def __init__(self):
        self.closed = False
        self._buffer = b""
        self.messages = []

    def write(self, data):
        """Append *data* to the pending byte buffer."""
        self._buffer = self._buffer + data

    async def drain(self):
        """Decode every complete frame currently held in the buffer."""
        header_size = DirtyProtocol.HEADER_SIZE
        while len(self._buffer) >= header_size:
            payload_len = struct.unpack(
                DirtyProtocol.HEADER_FORMAT, self._buffer[:header_size]
            )[0]
            frame_end = header_size + payload_len
            if len(self._buffer) < frame_end:
                # Partial frame: leave it buffered for a later drain().
                break
            payload = self._buffer[header_size:frame_end]
            self._buffer = self._buffer[frame_end:]
            self.messages.append(DirtyProtocol.decode(payload))

    def close(self):
        """Mark the writer as closed."""
        self.closed = True

    async def wait_closed(self):
        """No-op."""
        return None

    def get_extra_info(self, name):
        """Always report that no extra transport info is available."""
        return None
class MockLog:
"""Mock logger for testing."""
@ -141,8 +177,11 @@ class TestDirtyArbiterRouteRequest:
action="test"
)
response = await arbiter.route_request(request)
writer = MockStreamWriter()
await arbiter.route_request(request, writer)
assert len(writer.messages) == 1
response = writer.messages[0]
assert response["type"] == DirtyProtocol.MSG_TYPE_ERROR
assert "No dirty workers available" in response["error"]["message"]
@ -430,8 +469,11 @@ class TestDirtyArbiterRouteTimeout:
)
# This should fail because socket doesn't exist
response = await arbiter.route_request(request)
writer = MockStreamWriter()
await arbiter.route_request(request, writer)
assert len(writer.messages) == 1
response = writer.messages[0]
assert response["type"] == DirtyProtocol.MSG_TYPE_ERROR
# Either "Worker communication failed" or "Worker socket not ready"
assert "error" in response
@ -893,7 +935,8 @@ class TestDirtyArbiterQueueBehavior:
)
# This will fail (no socket), but consumer should be started
await arbiter.route_request(request)
writer = MockStreamWriter()
await arbiter.route_request(request, writer)
assert fake_pid in arbiter.worker_queues
assert fake_pid in arbiter.worker_consumers

View File

@ -15,6 +15,8 @@ from gunicorn.dirty.protocol import (
make_request,
make_response,
make_error_response,
make_chunk_message,
make_end_message,
)
from gunicorn.dirty.errors import (
DirtyError,
@ -323,6 +325,51 @@ class TestMessageBuilders:
assert response["error"]["error_type"] == "ValueError"
assert response["error"]["message"] == "Invalid value"
def test_make_chunk_message(self):
    """make_chunk_message() sets the chunk type, request id and data."""
    message = make_chunk_message("req-123", "Hello, ")

    assert message["type"] == DirtyProtocol.MSG_TYPE_CHUNK
    assert message["id"] == "req-123"
    assert message["data"] == "Hello, "
def test_make_chunk_message_with_complex_data(self):
    """A dictionary payload is carried unchanged in the chunk message."""
    payload = {"token": "world", "score": 0.95, "index": 5}

    message = make_chunk_message("req-456", payload)

    assert message["type"] == DirtyProtocol.MSG_TYPE_CHUNK
    assert message["id"] == "req-456"
    assert message["data"] == payload
def test_make_chunk_message_with_list_data(self):
    """A list payload is carried unchanged in the chunk message."""
    payload = [1, 2, 3, "token"]

    message = make_chunk_message("req-789", payload)

    assert message["data"] == payload
def test_make_end_message(self):
    """An end message carries the end type and request id, and no data."""
    message = make_end_message("req-123")

    assert message["type"] == DirtyProtocol.MSG_TYPE_END
    assert message["id"] == "req-123"
    # End-of-stream markers must not carry a "data" key at all.
    assert "data" not in message
def test_chunk_and_end_encode_decode(self):
    """Chunk and end messages round-trip through encode/decode.

    Each message is encoded, the length header is sliced off, and the
    remaining payload is decoded and compared with the original message.
    """
    chunk_msg = make_chunk_message("req-123", {"token": "hello"})
    end_msg = make_end_message("req-123")

    # Both messages are built up front; the chunk is round-tripped first,
    # then the end marker.
    for original in (chunk_msg, end_msg):
        framed = DirtyProtocol.encode(original)
        body = framed[DirtyProtocol.HEADER_SIZE:]
        roundtripped = DirtyProtocol.decode(body)
        assert roundtripped == original
class TestDirtyErrors:
"""Tests for error classes."""

View File

@ -16,6 +16,9 @@ from gunicorn.dirty.protocol import DirtyProtocol, make_request
from gunicorn.dirty.errors import DirtyAppNotFoundError
import struct
class MockLog:
"""Mock logger for testing."""
@ -41,6 +44,42 @@ class MockLog:
pass
class MockStreamWriter:
    """Test double for a stream writer that records decoded messages.

    Bytes passed to ``write`` are accumulated in an internal buffer.
    ``drain`` splits that buffer into length-prefixed frames, decodes
    every complete frame with ``DirtyProtocol.decode`` and appends the
    result to ``messages``. A trailing partial frame stays buffered.
    """

    def __init__(self):
        self.messages = []
        self._buffer = b""
        self.closed = False

    def write(self, data):
        """Append *data* to the pending buffer without decoding it."""
        self._buffer = self._buffer + data

    async def drain(self):
        """Decode every complete frame currently held in the buffer."""
        while True:
            pending = self._buffer
            header_size = DirtyProtocol.HEADER_SIZE

            # Not even a full length header yet: nothing more to decode.
            if len(pending) < header_size:
                return

            body_length = struct.unpack(
                DirtyProtocol.HEADER_FORMAT,
                pending[:header_size],
            )[0]
            frame_end = header_size + body_length

            # Header present but body incomplete: wait for more data.
            if len(pending) < frame_end:
                return

            frame_body = pending[header_size:frame_end]
            # Consume the frame before decoding, as the buffer must
            # advance even if decoding raises.
            self._buffer = pending[frame_end:]
            self.messages.append(DirtyProtocol.decode(frame_body))

    def close(self):
        """Record that the writer was closed."""
        self.closed = True

    async def wait_closed(self):
        """Do nothing; present so the mock can be awaited like a writer."""
        return None

    def get_extra_info(self, name):
        """Return ``None`` for every requested transport attribute."""
        return None
class TestDirtyWorkerInit:
"""Tests for DirtyWorker initialization."""
@ -214,8 +253,11 @@ class TestDirtyWorkerHandleRequest:
kwargs={"operation": "multiply"}
)
response = await worker.handle_request(request)
writer = MockStreamWriter()
await worker.handle_request(request, writer)
assert len(writer.messages) == 1
response = writer.messages[0]
assert response["type"] == DirtyProtocol.MSG_TYPE_RESPONSE
assert response["id"] == "test-123"
assert response["result"] == 6
@ -247,8 +289,11 @@ class TestDirtyWorkerHandleRequest:
kwargs={"operation": "invalid"}
)
response = await worker.handle_request(request)
writer = MockStreamWriter()
await worker.handle_request(request, writer)
assert len(writer.messages) == 1
response = writer.messages[0]
assert response["type"] == DirtyProtocol.MSG_TYPE_ERROR
assert response["id"] == "test-456"
assert "Unknown operation" in response["error"]["message"]
@ -271,8 +316,11 @@ class TestDirtyWorkerHandleRequest:
)
request = {"type": "unknown", "id": "test-789"}
response = await worker.handle_request(request)
writer = MockStreamWriter()
await worker.handle_request(request, writer)
assert len(writer.messages) == 1
response = writer.messages[0]
assert response["type"] == DirtyProtocol.MSG_TYPE_ERROR
assert "Unknown message type" in response["error"]["message"]
@ -662,43 +710,21 @@ class TestDirtyWorkerRunAsync:
reader.feed_data(encoded_request)
reader.feed_eof()
class MockWriter:
def __init__(self):
self.closed = False
self.data = b""
def get_extra_info(self, name):
return None
def write(self, data):
self.data += data
async def drain(self):
pass
def close(self):
self.closed = True
async def wait_closed(self):
pass
writer = MockWriter()
writer = MockStreamWriter()
# Handle one message then exit
worker.alive = True
try:
message = await DirtyProtocol.read_message_async(reader)
response = await worker.handle_request(message)
await DirtyProtocol.write_message_async(writer, response)
await worker.handle_request(message, writer)
except asyncio.IncompleteReadError:
pass
# Decode response from writer
if writer.data:
payload = writer.data[DirtyProtocol.HEADER_SIZE:]
response = DirtyProtocol.decode(payload)
assert response["type"] == DirtyProtocol.MSG_TYPE_RESPONSE
assert response["result"] == 8
# Check response from writer
assert len(writer.messages) == 1
response = writer.messages[0]
assert response["type"] == DirtyProtocol.MSG_TYPE_RESPONSE
assert response["result"] == 8
worker._cleanup()