From f6418d4eb02b27c71078ae317a2733f651481dfb Mon Sep 17 00:00:00 2001 From: Benoit Chesneau Date: Sat, 24 Jan 2026 18:39:14 +0100 Subject: [PATCH] feat(dirty): add streaming support and async client benchmarks Add support for streaming responses when dirty app actions return generators (sync or async). This enables real-time delivery of incremental results for use cases like LLM token generation. Features: - Streaming protocol with chunk/end/error message types - Worker support for sync and async generators - Arbiter forwarding of streaming messages - Deadline-based timeout handling - Async client streaming API Protocol: - Chunk messages (type: "chunk") contain partial data - End messages (type: "end") signal stream completion - Error messages can occur mid-stream New files: - benchmarks/dirty_streaming.py: Streaming benchmark suite - tests/dirty/test_*_streaming*.py: Streaming test coverage - docs/content/dirty.md: Streaming documentation with examples --- benchmarks/dirty_streaming.py | 755 +++++++++++++++++++++ docs/content/dirty.md | 154 +++++ gunicorn/dirty/arbiter.py | 99 ++- gunicorn/dirty/client.py | 359 ++++++++++ gunicorn/dirty/protocol.py | 34 + gunicorn/dirty/worker.py | 119 +++- tests/dirty/__init__.py | 5 + tests/dirty/test_arbiter_streaming.py | 319 +++++++++ tests/dirty/test_client_streaming.py | 236 +++++++ tests/dirty/test_client_streaming_async.py | 267 ++++++++ tests/dirty/test_streaming_integration.py | 469 +++++++++++++ tests/dirty/test_worker_streaming.py | 415 +++++++++++ tests/test_dirty_arbiter.py | 49 +- tests/test_dirty_protocol.py | 47 ++ tests/test_dirty_worker.py | 90 ++- 15 files changed, 3339 insertions(+), 78 deletions(-) create mode 100644 benchmarks/dirty_streaming.py create mode 100644 tests/dirty/__init__.py create mode 100644 tests/dirty/test_arbiter_streaming.py create mode 100644 tests/dirty/test_client_streaming.py create mode 100644 tests/dirty/test_client_streaming_async.py create mode 100644 tests/dirty/test_streaming_integration.py create mode 100644 tests/dirty/test_worker_streaming.py diff --git a/benchmarks/dirty_streaming.py b/benchmarks/dirty_streaming.py new file mode 100644 index 00000000..f5a27918 --- /dev/null +++ b/benchmarks/dirty_streaming.py @@ -0,0 +1,755 @@ +#!/usr/bin/env python +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +Benchmark suite for dirty worker streaming functionality. + +This script benchmarks the streaming performance of dirty workers +to measure throughput, latency, and memory usage. + +Usage: + python benchmarks/dirty_streaming.py [OPTIONS] + +Options: + --quick Run quick benchmarks only + --full Run full benchmark suite including stress tests +""" + +import argparse +import asyncio +import gc +import json +import os +import struct +import sys +import time +import tracemalloc +from datetime import datetime +from unittest import mock + +# Add parent directory to path +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from gunicorn.dirty.protocol import ( + DirtyProtocol, + make_request, + make_chunk_message, + make_end_message, + make_response, +) +from gunicorn.dirty.worker import DirtyWorker +from gunicorn.dirty.arbiter import DirtyArbiter +from gunicorn.dirty.client import ( + DirtyClient, + DirtyStreamIterator, + DirtyAsyncStreamIterator, +) +from gunicorn.config import Config + + +class MockStreamWriter: + """Mock StreamWriter that captures written messages.""" + + def __init__(self): + self.messages = [] + self._buffer = b"" + self.bytes_written = 0 + + def write(self, data): + self._buffer += data + self.bytes_written += len(data) + + async def drain(self): + while len(self._buffer) >= DirtyProtocol.HEADER_SIZE: + length = struct.unpack( + DirtyProtocol.HEADER_FORMAT, + self._buffer[:DirtyProtocol.HEADER_SIZE] + )[0] + total_size = DirtyProtocol.HEADER_SIZE + length + if len(self._buffer) >= total_size: + msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size] + self._buffer = self._buffer[total_size:] + self.messages.append(DirtyProtocol.decode(msg_data)) + else: + break + + def close(self): + pass + + async def wait_closed(self): + pass + + +class MockStreamReader: + """Mock StreamReader that yields predefined messages.""" + + def __init__(self, messages): + self._data = b'' + for msg in messages: + self._data += DirtyProtocol.encode(msg) + self._pos = 0 + + async def readexactly(self, n): + if self._pos + n > len(self._data): + raise asyncio.IncompleteReadError(self._data[self._pos:], n) + result = self._data[self._pos:self._pos + n] + self._pos += n + return result + + +class MockLog: + """Silent logger for benchmarks.""" + + def debug(self, msg, *args): + pass + + def info(self, msg, *args): + pass + + def warning(self, msg, *args): + pass + + def error(self, msg, *args): + pass + + def close_on_exec(self): + pass + + def reopen_files(self): + pass + + +def create_worker(): + """Create a test worker for benchmarks.""" + cfg = Config() + cfg.set("dirty_timeout", 300) + log = MockLog() + + with mock.patch('gunicorn.dirty.worker.WorkerTmp'): + worker = DirtyWorker( + age=1, + ppid=os.getpid(), + app_paths=["benchmark:App"], + cfg=cfg, + log=log, + socket_path="/tmp/benchmark.sock" + ) + + worker.apps = {} + worker._executor = None + worker.tmp = mock.Mock() + + return worker + + +def create_arbiter(): + """Create a test arbiter for benchmarks.""" + cfg = Config() + cfg.set("dirty_timeout", 300) + log = MockLog() + + arbiter = DirtyArbiter(cfg=cfg, log=log) + arbiter.alive = True + arbiter.workers = {1234: mock.Mock()} + arbiter.worker_sockets = {1234: '/tmp/worker.sock'} + + return arbiter + + +class BenchmarkResults: + """Store and display benchmark results.""" + + def __init__(self): + self.results = [] + + def add(self, name, iterations, duration, chunks=None, bytes_total=None, + memory_start=None, memory_end=None): + throughput = iterations / duration if duration > 0 else 0 + result = { + "name": name, + "iterations": iterations, + "duration_s": round(duration, 4), + "throughput_per_s": round(throughput, 2), + } + if chunks: + result["chunks_per_s"] = round(chunks / duration, 2) + if bytes_total: + result["mb_per_s"] = round(bytes_total / (1024 * 1024) / duration, 2) + if memory_start is not None and memory_end is not None: + result["memory_start_mb"] = round(memory_start / (1024 * 1024), 2) + result["memory_end_mb"] = round(memory_end / (1024 * 1024), 2) + result["memory_delta_mb"] = round((memory_end - memory_start) / (1024 * 1024), 2) + self.results.append(result) + + def display(self): + print("\n" + "=" * 70) + print("BENCHMARK RESULTS") + print("=" * 70) + for result in self.results: + print(f"\n{result['name']}") + print("-" * 50) + for key, value in result.items(): + if key != "name": + print(f" {key}: {value}") + print("\n" + "=" * 70) + + def save_json(self, filepath): + with open(filepath, 'w') as f: + json.dump({ + "timestamp": datetime.now().isoformat(), + "results": self.results + }, f, indent=2) + print(f"Results saved to {filepath}") + + +async def benchmark_worker_streaming_throughput(results, chunk_size=1024, num_chunks=1000): + """Benchmark worker streaming throughput with various chunk sizes.""" + worker = create_worker() + writer = MockStreamWriter() + + chunk_data = "x" * chunk_size + + async def sync_gen(): + for _ in range(num_chunks): + yield chunk_data + + async def mock_execute(app_path, action, args, kwargs): + return sync_gen() + + gc.collect() + tracemalloc.start() + memory_start = tracemalloc.get_traced_memory()[0] + + start = time.perf_counter() + + with mock.patch.object(worker, 'execute', side_effect=mock_execute): + request = make_request("bench-1", "benchmark:App", "stream") + await worker.handle_request(request, writer) + + duration = time.perf_counter() - start + memory_end = tracemalloc.get_traced_memory()[0] + tracemalloc.stop() + + total_bytes = chunk_size * num_chunks + + results.add( + f"Worker streaming ({chunk_size}B chunks, {num_chunks} chunks)", + iterations=1, + duration=duration, + chunks=num_chunks, + bytes_total=total_bytes, + memory_start=memory_start, + memory_end=memory_end + ) + + +async def benchmark_arbiter_forwarding(results, num_chunks=1000): + """Benchmark arbiter message forwarding throughput.""" + arbiter = create_arbiter() + + messages = [] + for i in range(num_chunks): + messages.append(make_chunk_message(f"bench-{i}", f"data-{i}")) + messages.append(make_end_message(f"bench-{num_chunks}")) + + mock_reader = MockStreamReader(messages) + + async def mock_get_connection(pid): + return mock_reader, MockStreamWriter() + + arbiter._get_worker_connection = mock_get_connection + + client_writer = MockStreamWriter() + + gc.collect() + start = time.perf_counter() + + request = make_request("bench-forward", "benchmark:App", "stream") + await arbiter._execute_on_worker(1234, request, client_writer) + + duration = time.perf_counter() - start + + results.add( + f"Arbiter forwarding ({num_chunks} chunks)", + iterations=1, + duration=duration, + chunks=num_chunks, + bytes_total=client_writer.bytes_written + ) + + arbiter._cleanup_sync() + + +async def benchmark_streaming_latency(results, iterations=100): + """Benchmark time-to-first-chunk and time-to-last-chunk.""" + worker = create_worker() + + first_chunk_times = [] + total_times = [] + + for _ in range(iterations): + writer = MockStreamWriter() + + async def gen_3_chunks(): + yield "first" + yield "second" + yield "third" + + async def mock_execute(app_path, action, args, kwargs): + return gen_3_chunks() + + start = time.perf_counter() + + with mock.patch.object(worker, 'execute', side_effect=mock_execute): + request = make_request("bench-latency", "benchmark:App", "stream") + await worker.handle_request(request, writer) + + # Find time when first chunk was received + if writer.messages: + first_chunk_times.append(time.perf_counter() - start) + + total_times.append(time.perf_counter() - start) + + avg_first_chunk = sum(first_chunk_times) / len(first_chunk_times) if first_chunk_times else 0 + avg_total = sum(total_times) / len(total_times) + + print(f"\nLatency Results ({iterations} iterations):") + print(f" Avg time-to-first-chunk: {avg_first_chunk * 1000:.3f}ms") + print(f" Avg time-to-last-chunk: {avg_total * 1000:.3f}ms") + + results.add( + f"Streaming latency ({iterations} iterations)", + iterations=iterations, + duration=sum(total_times), + chunks=iterations * 3 + ) + + +async def benchmark_concurrent_streams(results, num_streams=10, chunks_per_stream=100): + """Benchmark multiple concurrent streams.""" + arbiter = create_arbiter() + + async def run_stream(stream_id): + messages = [] + for i in range(chunks_per_stream): + messages.append(make_chunk_message(f"stream-{stream_id}", f"chunk-{i}")) + messages.append(make_end_message(f"stream-{stream_id}")) + + mock_reader = MockStreamReader(messages) + async def mock_get_connection(pid): + return mock_reader, MockStreamWriter() + + arbiter._get_worker_connection = mock_get_connection + client_writer = MockStreamWriter() + + request = make_request(f"bench-concurrent-{stream_id}", "benchmark:App", "stream") + await arbiter._execute_on_worker(1234, request, client_writer) + return len(client_writer.messages) + + gc.collect() + start = time.perf_counter() + + # Run streams concurrently + tasks = [run_stream(i) for i in range(num_streams)] + results_list = await asyncio.gather(*tasks) + + duration = time.perf_counter() - start + + total_chunks = sum(results_list) + + results.add( + f"Concurrent streams ({num_streams} streams, {chunks_per_stream} chunks each)", + iterations=num_streams, + duration=duration, + chunks=total_chunks + ) + + arbiter._cleanup_sync() + + +async def benchmark_memory_stability(results, iterations=10, chunks=1000): + """Check memory stability over many iterations.""" + worker = create_worker() + + gc.collect() + tracemalloc.start() + memory_samples = [tracemalloc.get_traced_memory()[0]] + + for i in range(iterations): + writer = MockStreamWriter() + + async def gen_chunks(): + for j in range(chunks): + yield f"chunk-{j}" + + async def mock_execute(app_path, action, args, kwargs): + return gen_chunks() + + with mock.patch.object(worker, 'execute', side_effect=mock_execute): + request = make_request(f"bench-mem-{i}", "benchmark:App", "stream") + await worker.handle_request(request, writer) + + gc.collect() + memory_samples.append(tracemalloc.get_traced_memory()[0]) + + tracemalloc.stop() + + memory_start = memory_samples[0] + memory_end = memory_samples[-1] + memory_max = max(memory_samples) + + print(f"\nMemory stability ({iterations} iterations of {chunks} chunks):") + print(f" Start: {memory_start / 1024 / 1024:.2f}MB") + print(f" End: {memory_end / 1024 / 1024:.2f}MB") + print(f" Max: {memory_max / 1024 / 1024:.2f}MB") + print(f" Delta: {(memory_end - memory_start) / 1024 / 1024:.2f}MB") + + results.add( + f"Memory stability ({iterations} x {chunks} chunks)", + iterations=iterations * chunks, + duration=0.001, # Use small non-zero value to avoid division by zero + memory_start=memory_start, + memory_end=memory_end + ) + + +class MockClientReader: + """Mock async reader that simulates receiving streaming messages.""" + + def __init__(self, num_chunks, chunk_data): + self.num_chunks = num_chunks + self.chunk_data = chunk_data + self._chunk_idx = 0 + self._messages = [] + self._build_messages() + self._pos = 0 + self._data = b'' + for msg in self._messages: + self._data += DirtyProtocol.encode(msg) + + def _build_messages(self): + for i in range(self.num_chunks): + self._messages.append(make_chunk_message(f"bench-{i}", self.chunk_data)) + self._messages.append(make_end_message(f"bench-end")) + + async def readexactly(self, n): + if self._pos + n > len(self._data): + raise asyncio.IncompleteReadError(self._data[self._pos:], n) + result = self._data[self._pos:self._pos + n] + self._pos += n + return result + + +class MockClientWriter: + """Mock async writer for client connection.""" + + def __init__(self): + self._buffer = b"" + self._closed = False + + def write(self, data): + self._buffer += data + + async def drain(self): + pass + + def close(self): + self._closed = True + + async def wait_closed(self): + pass + + +async def benchmark_async_client_streaming(results, chunk_size=1024, num_chunks=1000): + """ + Benchmark DirtyAsyncStreamIterator directly. + + Measures async iterator overhead vs raw message reading. + """ + chunk_data = "x" * chunk_size + + # Create mock client with mock reader/writer + client = DirtyClient("/tmp/benchmark.sock", timeout=30.0) + client._reader = MockClientReader(num_chunks, chunk_data) + client._writer = MockClientWriter() + + gc.collect() + tracemalloc.start() + memory_start = tracemalloc.get_traced_memory()[0] + + start = time.perf_counter() + + # Use the async stream iterator directly + iterator = DirtyAsyncStreamIterator(client, "benchmark:App", "stream", (), {}) + iterator._started = True # Skip the request sending + iterator._request_id = "bench-async" + iterator._deadline = time.perf_counter() + 300 # 5 min deadline + iterator._last_chunk_time = time.perf_counter() + + chunks_received = 0 + bytes_received = 0 + async for chunk in iterator: + chunks_received += 1 + bytes_received += len(chunk) + + duration = time.perf_counter() - start + memory_end = tracemalloc.get_traced_memory()[0] + tracemalloc.stop() + + results.add( + f"Async client streaming ({chunk_size}B chunks, {num_chunks} chunks)", + iterations=1, + duration=duration, + chunks=chunks_received, + bytes_total=bytes_received, + memory_start=memory_start, + memory_end=memory_end + ) + + +async def benchmark_sync_client_streaming(results, chunk_size=1024, num_chunks=1000): + """ + Benchmark DirtyStreamIterator directly (for comparison with async). + + Note: This runs the sync iterator within an async context for comparison. + """ + chunk_data = "x" * chunk_size + + # Build raw message data + messages_data = b'' + for i in range(num_chunks): + msg = make_chunk_message(f"bench-{i}", chunk_data) + messages_data += DirtyProtocol.encode(msg) + messages_data += DirtyProtocol.encode(make_end_message("bench-end")) + + # Create a mock socket-like object + class MockSocket: + def __init__(self, data): + self._data = data + self._pos = 0 + self._timeout = None + + def recv(self, n, flags=0): + if self._pos >= len(self._data): + return b'' + result = self._data[self._pos:self._pos + n] + self._pos += len(result) + return result + + def settimeout(self, timeout): + self._timeout = timeout + + # Create mock client + client = DirtyClient("/tmp/benchmark.sock", timeout=30.0) + client._sock = MockSocket(messages_data) + + gc.collect() + tracemalloc.start() + memory_start = tracemalloc.get_traced_memory()[0] + + start = time.perf_counter() + + # Use the sync stream iterator + iterator = DirtyStreamIterator(client, "benchmark:App", "stream", (), {}) + iterator._started = True # Skip the request sending + iterator._request_id = "bench-sync" + iterator._deadline = time.perf_counter() + 300 # 5 min deadline + iterator._last_chunk_time = time.perf_counter() + + chunks_received = 0 + bytes_received = 0 + for chunk in iterator: + chunks_received += 1 + bytes_received += len(chunk) + + duration = time.perf_counter() - start + memory_end = tracemalloc.get_traced_memory()[0] + tracemalloc.stop() + + results.add( + f"Sync client streaming ({chunk_size}B chunks, {num_chunks} chunks)", + iterations=1, + duration=duration, + chunks=chunks_received, + bytes_total=bytes_received, + memory_start=memory_start, + memory_end=memory_end + ) + + +async def benchmark_async_vs_sync_client_streaming(results, chunk_size=1024, num_chunks=1000): + """ + Compare stream() vs stream_async() performance with the same workload. + """ + chunk_data = "x" * chunk_size + + # --- Sync test --- + messages_data = b'' + for i in range(num_chunks): + msg = make_chunk_message(f"bench-{i}", chunk_data) + messages_data += DirtyProtocol.encode(msg) + messages_data += DirtyProtocol.encode(make_end_message("bench-end")) + + class MockSocket: + def __init__(self, data): + self._data = data + self._pos = 0 + self._timeout = None + + def recv(self, n, flags=0): + if self._pos >= len(self._data): + return b'' + result = self._data[self._pos:self._pos + n] + self._pos += len(result) + return result + + def settimeout(self, timeout): + self._timeout = timeout + + sync_client = DirtyClient("/tmp/benchmark.sock", timeout=30.0) + sync_client._sock = MockSocket(messages_data) + + gc.collect() + sync_start = time.perf_counter() + + sync_iter = DirtyStreamIterator(sync_client, "benchmark:App", "stream", (), {}) + sync_iter._started = True + sync_iter._request_id = "bench-sync" + sync_iter._deadline = time.perf_counter() + 300 # 5 min deadline + sync_iter._last_chunk_time = time.perf_counter() + + sync_chunks = 0 + for _ in sync_iter: + sync_chunks += 1 + + sync_duration = time.perf_counter() - sync_start + + # --- Async test --- + async_client = DirtyClient("/tmp/benchmark.sock", timeout=30.0) + async_client._reader = MockClientReader(num_chunks, chunk_data) + async_client._writer = MockClientWriter() + + gc.collect() + async_start = time.perf_counter() + + async_iter = DirtyAsyncStreamIterator(async_client, "benchmark:App", "stream", (), {}) + async_iter._started = True + async_iter._request_id = "bench-async" + async_iter._deadline = time.perf_counter() + 300 # 5 min deadline + async_iter._last_chunk_time = time.perf_counter() + + async_chunks = 0 + async for _ in async_iter: + async_chunks += 1 + + async_duration = time.perf_counter() - async_start + + # Report comparison + print(f"\nSync vs Async Client Streaming Comparison ({num_chunks} x {chunk_size}B chunks):") + print(f" Sync: {sync_duration * 1000:.3f}ms ({sync_chunks} chunks)") + print(f" Async: {async_duration * 1000:.3f}ms ({async_chunks} chunks)") + if sync_duration > 0: + ratio = async_duration / sync_duration + print(f" Ratio (async/sync): {ratio:.3f}x") + + results.add( + f"Sync client streaming comparison ({chunk_size}B, {num_chunks} chunks)", + iterations=1, + duration=sync_duration, + chunks=sync_chunks, + bytes_total=sync_chunks * chunk_size + ) + + results.add( + f"Async client streaming comparison ({chunk_size}B, {num_chunks} chunks)", + iterations=1, + duration=async_duration, + chunks=async_chunks, + bytes_total=async_chunks * chunk_size + ) + + +async def run_quick_benchmarks(): + """Run quick benchmarks.""" + results = BenchmarkResults() + + print("Running quick benchmarks...") + + await benchmark_worker_streaming_throughput(results, chunk_size=64, num_chunks=1000) + await benchmark_worker_streaming_throughput(results, chunk_size=1024, num_chunks=1000) + await benchmark_arbiter_forwarding(results, num_chunks=1000) + await benchmark_streaming_latency(results, iterations=50) + + # Async client streaming benchmarks + await benchmark_async_client_streaming(results, chunk_size=1024, num_chunks=1000) + await benchmark_async_vs_sync_client_streaming(results, chunk_size=1024, num_chunks=1000) + + return results + + +async def run_full_benchmarks(): + """Run full benchmark suite including stress tests.""" + results = BenchmarkResults() + + print("Running full benchmark suite...") + + # Throughput tests with different chunk sizes + for chunk_size in [1, 64, 1024, 65536]: + await benchmark_worker_streaming_throughput( + results, chunk_size=chunk_size, num_chunks=1000 + ) + + # Arbiter forwarding + await benchmark_arbiter_forwarding(results, num_chunks=10000) + + # Latency + await benchmark_streaming_latency(results, iterations=100) + + # Concurrent streams + await benchmark_concurrent_streams(results, num_streams=10, chunks_per_stream=100) + await benchmark_concurrent_streams(results, num_streams=50, chunks_per_stream=100) + + # Memory stability + await benchmark_memory_stability(results, iterations=20, chunks=1000) + + # Async client streaming benchmarks + for chunk_size in [64, 1024, 65536]: + await benchmark_async_client_streaming(results, chunk_size=chunk_size, num_chunks=1000) + await benchmark_sync_client_streaming(results, chunk_size=chunk_size, num_chunks=1000) + + # Comparison benchmark + await benchmark_async_vs_sync_client_streaming(results, chunk_size=1024, num_chunks=5000) + + return results + + +def main(): + parser = argparse.ArgumentParser(description="Dirty streaming benchmarks") + parser.add_argument("--quick", action="store_true", help="Run quick benchmarks only") + parser.add_argument("--full", action="store_true", help="Run full benchmark suite") + parser.add_argument("--output", "-o", help="Output JSON file path") + args = parser.parse_args() + + if args.full: + results = asyncio.run(run_full_benchmarks()) + else: + results = asyncio.run(run_quick_benchmarks()) + + results.display() + + if args.output: + results.save_json(args.output) + else: + # Save to default location + output_dir = os.path.dirname(os.path.abspath(__file__)) + results_dir = os.path.join(output_dir, "results") + os.makedirs(results_dir, exist_ok=True) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_file = os.path.join(results_dir, f"streaming_benchmark_{timestamp}.json") + results.save_json(output_file) + + +if __name__ == "__main__": + main() diff --git a/docs/content/dirty.md b/docs/content/dirty.md index 387aaf4a..c9ee5df3 100644 --- a/docs/content/dirty.md +++ b/docs/content/dirty.md @@ -200,6 +200,160 @@ async def my_view(request): return result ``` +## Streaming + +Dirty Arbiters support streaming responses for use cases like LLM token generation, where data is produced incrementally. This enables real-time delivery of results without waiting for complete execution. + +### Streaming with Generators + +Any dirty app action that returns a generator (sync or async) automatically streams chunks to the client: + +```python +# myapp/llm.py +from gunicorn.dirty import DirtyApp + +class LLMApp(DirtyApp): + def init(self): + from transformers import pipeline + self.generator = pipeline("text-generation", model="gpt2") + + def generate(self, prompt): + """Sync streaming - yields tokens.""" + for token in self.generator(prompt, stream=True): + yield token["generated_text"] + + async def generate_async(self, prompt): + """Async streaming - yields tokens.""" + import openai + client = openai.AsyncOpenAI() + stream = await client.chat.completions.create( + model="gpt-4", + messages=[{"role": "user", "content": prompt}], + stream=True + ) + async for chunk in stream: + if chunk.choices[0].delta.content: + yield chunk.choices[0].delta.content + + def close(self): + pass +``` + +### Client Streaming API + +Use `stream()` for sync workers and `stream_async()` for async workers: + +**Sync Workers (sync, gthread):** + +```python +from gunicorn.dirty import get_dirty_client + +def generate_view(request): + client = get_dirty_client() + + def generate_response(): + for chunk in client.stream("myapp.llm:LLMApp", "generate", request.prompt): + yield chunk + + return StreamingResponse(generate_response()) +``` + +**Async Workers (ASGI):** + +```python +from gunicorn.dirty import get_dirty_client_async + +async def generate_view(request): + client = await get_dirty_client_async() + + async def generate_response(): + async for chunk in client.stream_async("myapp.llm:LLMApp", "generate", request.prompt): + yield chunk + + return StreamingResponse(generate_response()) +``` + +### Streaming Protocol + +Streaming uses a simple protocol with three message types: + +1. **Chunk** (`type: "chunk"`) - Contains partial data +2. **End** (`type: "end"`) - Signals stream completion +3. **Error** (`type: "error"`) - Signals error during streaming + +Example message flow: +``` +Client -> Arbiter -> Worker: request +Worker -> Arbiter -> Client: chunk (data: "Hello") +Worker -> Arbiter -> Client: chunk (data: " ") +Worker -> Arbiter -> Client: chunk (data: "World") +Worker -> Arbiter -> Client: end +``` + +### Error Handling in Streams + +Errors during streaming are delivered as error messages: + +```python +def generate_view(request): + client = get_dirty_client() + + try: + for chunk in client.stream("myapp.llm:LLMApp", "generate", prompt): + yield chunk + except DirtyError as e: + # Error occurred mid-stream + yield f"\n[Error: {e.message}]" +``` + +### Best Practices for Streaming + +1. **Use async generators for I/O-bound streaming** - e.g., API calls to external services +2. **Use sync generators for CPU-bound streaming** - e.g., local model inference +3. **Yield frequently** - Heartbeats are sent during streaming to keep workers alive +4. **Keep chunks small** - Smaller chunks provide better perceived latency +5. **Handle client disconnection** - Streams continue even if client disconnects; design accordingly + +### Flask Example + +```python +from flask import Flask, Response +from gunicorn.dirty import get_dirty_client + +app = Flask(__name__) + +@app.route("/chat", methods=["POST"]) +def chat(): + prompt = request.json.get("prompt") + client = get_dirty_client() + + def stream(): + for token in client.stream("myapp.llm:LLMApp", "generate", prompt): + yield f"data: {token}\n\n" + + return Response(stream(), content_type="text/event-stream") +``` + +### FastAPI Example + +```python +from fastapi import FastAPI +from fastapi.responses import StreamingResponse +from gunicorn.dirty import get_dirty_client_async + +app = FastAPI() + +@app.post("/chat") +async def chat(prompt: str): + client = await get_dirty_client_async() + + async def stream(): + async for token in client.stream_async("myapp.llm:LLMApp", "generate", prompt): + yield f"data: {token}\n\n" + + return StreamingResponse(stream(), media_type="text/event-stream") +``` + ## Lifecycle Hooks Dirty Arbiters provide hooks for customization: diff --git a/gunicorn/dirty/arbiter.py b/gunicorn/dirty/arbiter.py index 63dd6269..4c2243fd 100644 --- a/gunicorn/dirty/arbiter.py +++ b/gunicorn/dirty/arbiter.py @@ -199,6 +199,7 @@ class DirtyArbiter: Handle a connection from an HTTP worker. Routes requests to available dirty workers and returns responses. + Supports both regular responses and streaming (chunk-based) responses. """ self.log.debug("New client connection from HTTP worker") @@ -209,11 +210,8 @@ class DirtyArbiter: except asyncio.IncompleteReadError: break - # Route request to a dirty worker - response = await self.route_request(message) - - # Send response back to HTTP worker - await DirtyProtocol.write_message_async(writer, response) + # Route request to a dirty worker - pass writer for streaming + await self.route_request(message, writer) except Exception as e: self.log.error("Client connection error: %s", e) finally: @@ -223,28 +221,31 @@ class DirtyArbiter: except Exception: pass - async def route_request(self, request): + async def route_request(self, request, client_writer): """ Route a request to an available dirty worker via queue. Each worker has a dedicated queue and consumer task. Requests are submitted to the queue and processed sequentially by the consumer. + For streaming responses, messages (chunks) are forwarded directly + to the client_writer as they arrive from the worker. + Args: request: Request message dict - - Returns: - Response message dict + client_writer: StreamWriter to send responses to client """ request_id = request.get("id", "unknown") # Find an available worker worker_pid = await self._get_available_worker() if worker_pid is None: - return make_error_response( + response = make_error_response( request_id, DirtyError("No dirty workers available") ) + await DirtyProtocol.write_message_async(client_writer, response) + return # Get queue (start consumer if needed) if worker_pid not in self.worker_queues: @@ -253,17 +254,18 @@ class DirtyArbiter: queue = self.worker_queues[worker_pid] future = asyncio.get_running_loop().create_future() - # Submit request to queue - await queue.put((request, future)) + # Submit request to queue with client writer for streaming support + await queue.put((request, client_writer, future)) - # Wait for response + # Wait for completion (streaming messages forwarded by consumer) try: - return await future + await future except Exception as e: - return make_error_response( + response = make_error_response( request_id, DirtyWorkerError(f"Request failed: {e}", worker_id=worker_pid) ) + await DirtyProtocol.write_message_async(client_writer, response) async def _start_worker_consumer(self, worker_pid): """Start a consumer task for a worker's request queue.""" @@ -273,11 +275,13 @@ class DirtyArbiter: async def consumer(): while self.alive: try: - request, future = await queue.get() + request, client_writer, future = await queue.get() try: - response = await self._execute_on_worker(worker_pid, request) + await self._execute_on_worker( + worker_pid, request, client_writer + ) if not future.done(): - future.set_result(response) + future.set_result(None) except Exception as e: if not future.done(): future.set_exception(e) @@ -289,32 +293,65 @@ class DirtyArbiter: task = asyncio.create_task(consumer()) self.worker_consumers[worker_pid] = task - async def _execute_on_worker(self, worker_pid, request): - """Execute request on a specific worker (called by consumer).""" + async def _execute_on_worker(self, worker_pid, request, client_writer): + """ + Execute request on a specific worker (called by consumer). + + Handles both regular responses and streaming (chunk-based) responses. + For streaming, chunk and end messages are forwarded directly to the + client_writer as they arrive from the worker. + """ request_id = request.get("id", "unknown") try: reader, writer = await self._get_worker_connection(worker_pid) await DirtyProtocol.write_message_async(writer, request) - response = await asyncio.wait_for( - DirtyProtocol.read_message_async(reader), - timeout=self.cfg.dirty_timeout - ) - return response - except asyncio.TimeoutError: - return make_error_response( - request_id, - DirtyTimeoutError("Worker timeout", self.cfg.dirty_timeout) - ) + # Read messages until we get a response, end, or error + while True: + try: + message = await asyncio.wait_for( + DirtyProtocol.read_message_async(reader), + timeout=self.cfg.dirty_timeout + ) + except asyncio.TimeoutError: + response = make_error_response( + request_id, + DirtyTimeoutError("Worker timeout", self.cfg.dirty_timeout) + ) + await DirtyProtocol.write_message_async(client_writer, response) + return + + msg_type = message.get("type") + + # Forward chunk messages to client + if msg_type == DirtyProtocol.MSG_TYPE_CHUNK: + await DirtyProtocol.write_message_async(client_writer, message) + continue + + # Forward end message and complete + if msg_type == DirtyProtocol.MSG_TYPE_END: + await DirtyProtocol.write_message_async(client_writer, message) + return + + # Forward response or error and complete + if msg_type in (DirtyProtocol.MSG_TYPE_RESPONSE, + DirtyProtocol.MSG_TYPE_ERROR): + await DirtyProtocol.write_message_async(client_writer, message) + return + + # Unknown message type - log and continue + self.log.warning("Unknown message type from worker: %s", msg_type) + except Exception as e: self.log.error("Error executing on worker %s: %s", worker_pid, e) self._close_worker_connection(worker_pid) - return make_error_response( + response = make_error_response( request_id, DirtyWorkerError(f"Worker communication failed: {e}", worker_id=worker_pid) ) + await DirtyProtocol.write_message_async(client_writer, response) async def _get_available_worker(self): """ diff --git a/gunicorn/dirty/client.py b/gunicorn/dirty/client.py index 39eac4b1..f487a29a 100644 --- a/gunicorn/dirty/client.py +++ b/gunicorn/dirty/client.py @@ -14,6 +14,7 @@ import contextvars import os import socket import threading +import time import uuid from .errors import ( @@ -134,6 +135,34 @@ class DirtyClient: raise raise DirtyConnectionError(f"Communication error: {e}") from e + def stream(self, app_path, action, *args, **kwargs): + """ + Stream results from a dirty app action (sync). + + This method returns an iterator that yields chunks from a streaming + response. Use this for actions that return generators. + + Args: + app_path: Import path of the dirty app (e.g., 'myapp.ml:MLApp') + action: Action to call on the app + *args: Positional arguments + **kwargs: Keyword arguments + + Yields: + Chunks of data from the streaming response + + Raises: + DirtyConnectionError: If connection fails + DirtyTimeoutError: If operation times out + DirtyError: If execution fails + + Example:: + + for chunk in client.stream("myapp.llm:LLMApp", "generate", prompt): + print(chunk, end="", flush=True) + """ + return DirtyStreamIterator(self, app_path, action, args, kwargs) + def _handle_response(self, response): """Handle response message, extracting result or raising error.""" msg_type = response.get("type") @@ -247,6 +276,34 @@ class DirtyClient: raise raise DirtyConnectionError(f"Communication error: {e}") from e + def stream_async(self, app_path, action, *args, **kwargs): + """ + Stream results from a dirty app action (async). + + This method returns an async iterator that yields chunks from a + streaming response. Use this for actions that return generators. + + Args: + app_path: Import path of the dirty app (e.g., 'myapp.ml:MLApp') + action: Action to call on the app + *args: Positional arguments + **kwargs: Keyword arguments + + Yields: + Chunks of data from the streaming response + + Raises: + DirtyConnectionError: If connection fails + DirtyTimeoutError: If operation times out + DirtyError: If execution fails + + Example:: + + async for chunk in client.stream_async("myapp.llm:LLMApp", "generate", prompt): + await response.write(chunk) + """ + return DirtyAsyncStreamIterator(self, app_path, action, args, kwargs) + async def _close_async(self): """Close the async connection.""" if self._writer is not None: @@ -281,6 +338,308 @@ class DirtyClient: await self.close_async() +# ============================================================================= +# Stream Iterator classes +# ============================================================================= + + +class DirtyStreamIterator: + """ + Iterator for streaming responses from dirty workers (sync). + + This class is returned by `DirtyClient.stream()` and yields chunks + from a streaming response until the end message is received. + + Uses a deadline-based timeout approach: + - Total stream timeout: limits entire stream duration + - Idle timeout: limits gap between chunks (defaults to total timeout) + """ + + # Default idle timeout between chunks (seconds) + DEFAULT_IDLE_TIMEOUT = 30.0 + + # Threshold for applying per-read timeout (seconds) + # When remaining time is above this, use a larger timeout for efficiency + _TIMEOUT_THRESHOLD = 5.0 + + def __init__(self, client, app_path, action, args, kwargs, + idle_timeout=None): + self.client = client + self.app_path = app_path + self.action = action + self.args = args + self.kwargs = kwargs + self._started = False + self._exhausted = False + self._request_id = None + self._deadline = None + self._last_chunk_time = None + # Idle timeout: max time between chunks + self._idle_timeout = ( + idle_timeout if idle_timeout is not None + else min(self.DEFAULT_IDLE_TIMEOUT, client.timeout) + ) + + def __iter__(self): + return self + + def __next__(self): + if self._exhausted: + raise StopIteration + + if not self._started: + self._start_request() + self._started = True + + return self._read_next_chunk() + + def _start_request(self): + """Send the initial request to the arbiter.""" + with self.client._lock: + if self.client._sock is None: + self.client.connect() + + # Set deadline for entire stream + now = time.monotonic() + self._deadline = now + self.client.timeout + self._last_chunk_time = now + + self._request_id = str(uuid.uuid4()) + request = make_request( + self._request_id, + self.app_path, + self.action, + args=self.args, + kwargs=self.kwargs, + ) + DirtyProtocol.write_message(self.client._sock, request) + + def _read_next_chunk(self): + """Read the next message from the stream.""" + with self.client._lock: + # Check total stream deadline + now = time.monotonic() + if now >= self._deadline: + self._exhausted = True + raise DirtyTimeoutError( + "Stream exceeded total timeout", + timeout=self.client.timeout + ) + + remaining = self._deadline - now + + # Set socket timeout based on remaining time + # Fast path: use larger timeout when plenty of time remains + if remaining > self._TIMEOUT_THRESHOLD: + read_timeout = self._TIMEOUT_THRESHOLD + else: + read_timeout = min(remaining, self._idle_timeout) + + try: + self.client._sock.settimeout(read_timeout) + response = DirtyProtocol.read_message(self.client._sock) + except socket.timeout: + # Check which timeout was hit + now = time.monotonic() + if now >= self._deadline: + self._exhausted = True + raise DirtyTimeoutError( + "Stream exceeded total timeout", + timeout=self.client.timeout + ) + else: + idle_duration = now - self._last_chunk_time + self._exhausted = True + raise DirtyTimeoutError( + f"Timeout waiting for next chunk (idle {idle_duration:.1f}s)", + timeout=self._idle_timeout + ) + except Exception as e: + self._exhausted = True + self.client._close_socket() + raise DirtyConnectionError(f"Communication error: {e}") from e + + # Update last chunk time for idle tracking + self._last_chunk_time = time.monotonic() + + msg_type = response.get("type") + + # Chunk message - return the data + if msg_type == DirtyProtocol.MSG_TYPE_CHUNK: + return response.get("data") + + # End message - stop iteration + if msg_type == DirtyProtocol.MSG_TYPE_END: + self._exhausted = True + raise StopIteration + + # Error message - raise exception + if msg_type == DirtyProtocol.MSG_TYPE_ERROR: + self._exhausted = True + error_info = response.get("error", {}) + raise DirtyError.from_dict(error_info) + + # Regular response - shouldn't happen for streaming, but handle it + if msg_type == DirtyProtocol.MSG_TYPE_RESPONSE: + self._exhausted = True + # Return the result as the only chunk then stop + raise StopIteration + + # Unknown type + self._exhausted = True + raise DirtyError(f"Unknown message type: {msg_type}") + + +class DirtyAsyncStreamIterator: + """ + Async iterator for streaming responses from dirty workers. + + This class is returned by `DirtyClient.stream_async()` and yields chunks + from a streaming response until the end message is received. + + Uses a deadline-based timeout approach for efficiency: + - Total stream timeout: limits entire stream duration + - Idle timeout: limits gap between chunks (defaults to total timeout) + + This avoids the overhead of asyncio.wait_for() on every chunk read. + """ + + # Default idle timeout between chunks (seconds) + DEFAULT_IDLE_TIMEOUT = 30.0 + + def __init__(self, client, app_path, action, args, kwargs, + idle_timeout=None): + self.client = client + self.app_path = app_path + self.action = action + self.args = args + self.kwargs = kwargs + self._started = False + self._exhausted = False + self._request_id = None + self._deadline = None + self._last_chunk_time = None + # Idle timeout: max time between chunks + self._idle_timeout = ( + idle_timeout if idle_timeout is not None + else min(self.DEFAULT_IDLE_TIMEOUT, client.timeout) + ) + + def __aiter__(self): + return self + + async def __anext__(self): + if self._exhausted: + raise StopAsyncIteration + + if not self._started: + await self._start_request() + self._started = True + + return await self._read_next_chunk() + + async def _start_request(self): + """Send the initial request to the arbiter.""" + if self.client._writer is None: + await self.client.connect_async() + + # Set deadline for entire stream + now = time.monotonic() + self._deadline = now + self.client.timeout + self._last_chunk_time = now + + self._request_id = str(uuid.uuid4()) + request = make_request( + self._request_id, + self.app_path, + self.action, + args=self.args, + kwargs=self.kwargs, + ) + await DirtyProtocol.write_message_async(self.client._writer, request) + + # Threshold for applying timeout wrapper (seconds) + # When remaining time is above this, skip timeout for performance + _TIMEOUT_THRESHOLD = 5.0 + + async def _read_next_chunk(self): + """Read the next message from the stream.""" + # Calculate remaining time until deadline + now = time.monotonic() + + # Check total stream deadline + if now >= self._deadline: + self._exhausted = True + raise DirtyTimeoutError( + "Stream exceeded total timeout", + timeout=self.client.timeout + ) + + remaining = self._deadline - now + + try: + # Fast path: skip timeout wrapper when we have plenty of time + # This avoids asyncio.wait_for() overhead for most chunks + if remaining > self._TIMEOUT_THRESHOLD: + response = await DirtyProtocol.read_message_async( + self.client._reader + ) + else: + # Near deadline: apply timeout protection + read_timeout = min(remaining, self._idle_timeout) + response = await asyncio.wait_for( + DirtyProtocol.read_message_async(self.client._reader), + timeout=read_timeout + ) + except asyncio.TimeoutError: + self._exhausted = True + now = time.monotonic() + if now >= self._deadline: + raise DirtyTimeoutError( + "Stream exceeded total timeout", + timeout=self.client.timeout + ) + else: + idle_duration = now - self._last_chunk_time + raise DirtyTimeoutError( + f"Timeout waiting for next chunk (idle {idle_duration:.1f}s)", + timeout=self._idle_timeout + ) + except Exception as e: + self._exhausted = True + await self.client._close_async() + raise DirtyConnectionError(f"Communication error: {e}") from e + + # Update last chunk time for idle tracking + self._last_chunk_time = time.monotonic() + + msg_type = response.get("type") + + # Chunk message - return the data + if msg_type == DirtyProtocol.MSG_TYPE_CHUNK: + return response.get("data") + + # End message - stop iteration + if msg_type == DirtyProtocol.MSG_TYPE_END: + self._exhausted = True + raise StopAsyncIteration + + # Error message - raise exception + if msg_type == DirtyProtocol.MSG_TYPE_ERROR: + self._exhausted = True + error_info = response.get("error", {}) + raise DirtyError.from_dict(error_info) + + # Regular response - shouldn't happen for streaming + if msg_type == DirtyProtocol.MSG_TYPE_RESPONSE: + self._exhausted = True + raise StopAsyncIteration + + # Unknown type + self._exhausted = True + raise DirtyError(f"Unknown message type: {msg_type}") + + # ============================================================================= # Thread-local and context-local client management # ============================================================================= diff --git a/gunicorn/dirty/protocol.py b/gunicorn/dirty/protocol.py index e0895037..e5ac6cfa 100644 --- a/gunicorn/dirty/protocol.py +++ b/gunicorn/dirty/protocol.py @@ -312,3 +312,37 @@ def make_error_response(request_id: str, error) -> dict: "id": request_id, "error": error_dict, } + + +def make_chunk_message(request_id: str, data) -> dict: + """ + Build a chunk message for streaming responses. + + Args: + request_id: Request identifier this chunk belongs to + data: Chunk data (must be JSON-serializable) + + Returns: + dict: Chunk message + """ + return { + "type": DirtyProtocol.MSG_TYPE_CHUNK, + "id": request_id, + "data": data, + } + + +def make_end_message(request_id: str) -> dict: + """ + Build an end-of-stream message. + + Args: + request_id: Request identifier this ends + + Returns: + dict: End message + """ + return { + "type": DirtyProtocol.MSG_TYPE_END, + "id": request_id, + } diff --git a/gunicorn/dirty/worker.py b/gunicorn/dirty/worker.py index 70e45047..62162a77 100644 --- a/gunicorn/dirty/worker.py +++ b/gunicorn/dirty/worker.py @@ -69,6 +69,7 @@ operation will continue until the worker is killed by the arbiter. import asyncio from concurrent.futures import ThreadPoolExecutor +import inspect import os import signal import traceback @@ -88,6 +89,8 @@ from .protocol import ( DirtyProtocol, make_response, make_error_response, + make_chunk_message, + make_end_message, ) @@ -296,11 +299,8 @@ class DirtyWorker: # Connection closed break - # Handle the request - response = await self.handle_request(message) - - # Send response - await DirtyProtocol.write_message_async(writer, response) + # Handle the request - pass writer for streaming support + await self.handle_request(message, writer) except Exception as e: self.log.error("Connection error: %s", e) finally: @@ -310,24 +310,28 @@ class DirtyWorker: except Exception: pass - async def handle_request(self, message): + async def handle_request(self, message, writer): """ Handle a single request message. + Supports both regular (non-streaming) and streaming responses. + For streaming, detects if the result is a generator and sends + chunk messages followed by an end message. + Args: message: Request dict from protocol - - Returns: - Response dict to send back + writer: StreamWriter for sending responses """ request_id = message.get("id", str(uuid.uuid4())) msg_type = message.get("type") if msg_type != DirtyProtocol.MSG_TYPE_REQUEST: - return make_error_response( + response = make_error_response( request_id, DirtyWorkerError(f"Unknown message type: {msg_type}") ) + await DirtyProtocol.write_message_async(writer, response) + return app_path = message.get("app_path") action = message.get("action") @@ -339,16 +343,107 @@ class DirtyWorker: try: result = await self.execute(app_path, action, args, kwargs) - return make_response(request_id, result) + + # Check if result is a generator (streaming) + if inspect.isgenerator(result): + await self._stream_sync_generator(request_id, result, writer) + elif inspect.isasyncgen(result): + await self._stream_async_generator(request_id, result, writer) + else: + # Regular non-streaming response + response = make_response(request_id, result) + await DirtyProtocol.write_message_async(writer, response) except Exception as e: tb = traceback.format_exc() self.log.error("Error executing %s.%s: %s\n%s", app_path, action, e, tb) - return make_error_response( + response = make_error_response( request_id, DirtyAppError(str(e), app_path=app_path, action=action, traceback=tb) ) + await DirtyProtocol.write_message_async(writer, response) + + async def _stream_sync_generator(self, request_id, gen, writer): + """ + Stream chunks from a synchronous generator. + + Args: + request_id: Request ID for the messages + gen: Sync generator to iterate + writer: StreamWriter for sending messages + """ + # Sentinel value to detect end of generator + # (StopIteration cannot be raised into a Future in Python 3.7+) + _EXHAUSTED = object() + + def _get_next(): + try: + return next(gen) + except StopIteration: + return _EXHAUSTED + + try: + loop = asyncio.get_running_loop() + while True: + # Run next() in executor to avoid blocking event loop + chunk = await loop.run_in_executor(self._executor, _get_next) + if chunk is _EXHAUSTED: + break + # Send chunk message + await DirtyProtocol.write_message_async( + writer, make_chunk_message(request_id, chunk) + ) + # Update heartbeat during long streams + self.notify() + # Send end message + await DirtyProtocol.write_message_async( + writer, make_end_message(request_id) + ) + except Exception as e: + # Error during streaming - send error message + tb = traceback.format_exc() + self.log.error("Error during streaming: %s\n%s", e, tb) + response = make_error_response( + request_id, + DirtyAppError(str(e), traceback=tb) + ) + await DirtyProtocol.write_message_async(writer, response) + finally: + gen.close() + + async def _stream_async_generator(self, request_id, gen, writer): + """ + Stream chunks from an asynchronous generator. + + Args: + request_id: Request ID for the messages + gen: Async generator to iterate + writer: StreamWriter for sending messages + """ + try: + async for chunk in gen: + # Send chunk message + await DirtyProtocol.write_message_async( + writer, make_chunk_message(request_id, chunk) + ) + # Update heartbeat during long streams + self.notify() + # Send end message + await DirtyProtocol.write_message_async( + writer, make_end_message(request_id) + ) + except Exception as e: + # Error during streaming - send error message + tb = traceback.format_exc() + self.log.error("Error during streaming: %s\n%s", e, tb) + response = make_error_response( + request_id, + DirtyAppError(str(e), traceback=tb) + ) + await DirtyProtocol.write_message_async(writer, response) + finally: + await gen.aclose() async def execute(self, app_path, action, args, kwargs): """ diff --git a/tests/dirty/__init__.py b/tests/dirty/__init__.py new file mode 100644 index 00000000..2e16acee --- /dev/null +++ b/tests/dirty/__init__.py @@ -0,0 +1,5 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +"""Tests for dirty worker streaming functionality.""" diff --git a/tests/dirty/test_arbiter_streaming.py b/tests/dirty/test_arbiter_streaming.py new file mode 100644 index 00000000..ef15c33a --- /dev/null +++ b/tests/dirty/test_arbiter_streaming.py @@ -0,0 +1,319 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +"""Tests for dirty arbiter streaming functionality.""" + +import asyncio +import struct +from unittest import mock + +import pytest + +from gunicorn.dirty.protocol import ( + DirtyProtocol, + make_request, + make_response, + make_chunk_message, + make_end_message, + make_error_response, +) +from gunicorn.dirty.arbiter import DirtyArbiter +from gunicorn.dirty.errors import DirtyError + + +class MockStreamWriter: + """Mock StreamWriter that captures written messages.""" + + def __init__(self): + self.messages = [] + self._buffer = b"" + self.closed = False + + def write(self, data): + self._buffer += data + + async def drain(self): + while len(self._buffer) >= DirtyProtocol.HEADER_SIZE: + length = struct.unpack( + DirtyProtocol.HEADER_FORMAT, + self._buffer[:DirtyProtocol.HEADER_SIZE] + )[0] + total_size = DirtyProtocol.HEADER_SIZE + length + if len(self._buffer) >= total_size: + msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size] + self._buffer = self._buffer[total_size:] + self.messages.append(DirtyProtocol.decode(msg_data)) + else: + break + + def close(self): + self.closed = True + + async def wait_closed(self): + pass + + def get_extra_info(self, name): + return None + + +class MockStreamReader: + """Mock StreamReader that yields predefined messages.""" + + def __init__(self, messages): + self._data = b'' + for msg in messages: + self._data += DirtyProtocol.encode(msg) + self._pos = 0 + + async def readexactly(self, n): + if self._pos + n > len(self._data): + raise asyncio.IncompleteReadError(self._data[self._pos:], n) + result = self._data[self._pos:self._pos + n] + self._pos += n + return result + + +def create_arbiter(): + """Create a test arbiter with mocked components.""" + cfg = mock.Mock() + cfg.dirty_timeout = 30 + cfg.dirty_workers = 1 + cfg.dirty_apps = [] + cfg.dirty_graceful_timeout = 30 + cfg.on_dirty_starting = mock.Mock() + cfg.dirty_post_fork = mock.Mock() + cfg.dirty_worker_exit = mock.Mock() + + log = mock.Mock() + + with mock.patch('tempfile.mkdtemp', return_value='/tmp/test-dirty'): + arbiter = DirtyArbiter(cfg, log) + + arbiter.alive = True + arbiter.workers = {1234: mock.Mock()} # Fake worker + arbiter.worker_sockets = {1234: '/tmp/worker.sock'} + + return arbiter + + +class TestArbiterStreamingForwarding: + """Tests for arbiter streaming message forwarding.""" + + @pytest.mark.asyncio + async def test_forwards_chunk_messages(self): + """Test that arbiter forwards chunk messages to client.""" + arbiter = create_arbiter() + client_writer = MockStreamWriter() + + # Mock worker connection that returns chunks + chunk1 = make_chunk_message("req-123", "Hello") + chunk2 = make_chunk_message("req-123", " World") + end = make_end_message("req-123") + + mock_reader = MockStreamReader([chunk1, chunk2, end]) + + async def mock_get_connection(pid): + return mock_reader, MockStreamWriter() + + arbiter._get_worker_connection = mock_get_connection + + request = make_request("req-123", "test:App", "generate") + await arbiter._execute_on_worker(1234, request, client_writer) + + # Should have forwarded all messages + assert len(client_writer.messages) == 3 + assert client_writer.messages[0]["type"] == "chunk" + assert client_writer.messages[0]["data"] == "Hello" + assert client_writer.messages[1]["type"] == "chunk" + assert client_writer.messages[1]["data"] == " World" + assert client_writer.messages[2]["type"] == "end" + + @pytest.mark.asyncio + async def test_forwards_regular_response(self): + """Test that arbiter forwards regular response to client.""" + arbiter = create_arbiter() + client_writer = MockStreamWriter() + + response = make_response("req-123", {"result": 42}) + mock_reader = MockStreamReader([response]) + + async def mock_get_connection(pid): + return mock_reader, MockStreamWriter() + + arbiter._get_worker_connection = mock_get_connection + + request = make_request("req-123", "test:App", "compute") + await arbiter._execute_on_worker(1234, request, client_writer) + + assert len(client_writer.messages) == 1 + assert client_writer.messages[0]["type"] == "response" + assert client_writer.messages[0]["result"] == {"result": 42} + + @pytest.mark.asyncio + async def test_forwards_error_mid_stream(self): + """Test that arbiter forwards error during streaming.""" + arbiter = create_arbiter() + client_writer = MockStreamWriter() + + chunk = make_chunk_message("req-123", "First") + error = make_error_response("req-123", DirtyError("Something broke")) + + mock_reader = MockStreamReader([chunk, error]) + + async def mock_get_connection(pid): + return mock_reader, MockStreamWriter() + + arbiter._get_worker_connection = mock_get_connection + + request = make_request("req-123", "test:App", "generate") + await arbiter._execute_on_worker(1234, request, client_writer) + + assert len(client_writer.messages) == 2 + assert client_writer.messages[0]["type"] == "chunk" + assert client_writer.messages[1]["type"] == "error" + + @pytest.mark.asyncio + async def test_timeout_during_streaming(self): + """Test that timeout during streaming sends error.""" + arbiter = create_arbiter() + arbiter.cfg.dirty_timeout = 0.01 # Very short timeout + client_writer = MockStreamWriter() + + # Reader that times out + class TimeoutReader: + async def readexactly(self, n): + await asyncio.sleep(1) # Longer than timeout + + async def mock_get_connection(pid): + return TimeoutReader(), MockStreamWriter() + + arbiter._get_worker_connection = mock_get_connection + + request = make_request("req-123", "test:App", "generate") + await arbiter._execute_on_worker(1234, request, client_writer) + + assert len(client_writer.messages) == 1 + assert client_writer.messages[0]["type"] == "error" + assert "timeout" in client_writer.messages[0]["error"]["message"].lower() + + +class TestArbiterRouteRequestStreaming: + """Tests for route_request with streaming support.""" + + @pytest.mark.asyncio + async def test_route_request_no_workers(self): + """Test route_request when no workers available.""" + arbiter = create_arbiter() + arbiter.workers = {} # No workers + client_writer = MockStreamWriter() + + request = make_request("req-123", "test:App", "generate") + await arbiter.route_request(request, client_writer) + + assert len(client_writer.messages) == 1 + assert client_writer.messages[0]["type"] == "error" + assert "No dirty workers" in client_writer.messages[0]["error"]["message"] + + @pytest.mark.asyncio + async def test_route_request_starts_consumer(self): + """Test that route_request starts consumer if needed.""" + arbiter = create_arbiter() + + # Mock _execute_on_worker to complete immediately + async def mock_execute(pid, request, client_writer): + response = make_response("req-123", "result") + await DirtyProtocol.write_message_async(client_writer, response) + + arbiter._execute_on_worker = mock_execute + + client_writer = MockStreamWriter() + request = make_request("req-123", "test:App", "compute") + + # Worker queue should be created + assert 1234 not in arbiter.worker_queues + + await arbiter.route_request(request, client_writer) + + # Consumer should have been started + assert 1234 in arbiter.worker_queues + assert 1234 in arbiter.worker_consumers + + # Clean up + arbiter.worker_consumers[1234].cancel() + + +class TestArbiterStreamingManyChunks: + """Tests for streaming with many chunks.""" + + @pytest.mark.asyncio + async def test_forwards_many_chunks(self): + """Test that arbiter forwards many chunks correctly.""" + arbiter = create_arbiter() + client_writer = MockStreamWriter() + + # Generate 50 chunks + end + messages = [] + for i in range(50): + messages.append(make_chunk_message("req-123", f"chunk-{i}")) + messages.append(make_end_message("req-123")) + + mock_reader = MockStreamReader(messages) + + async def mock_get_connection(pid): + return mock_reader, MockStreamWriter() + + arbiter._get_worker_connection = mock_get_connection + + request = make_request("req-123", "test:App", "generate") + await arbiter._execute_on_worker(1234, request, client_writer) + + assert len(client_writer.messages) == 51 + assert client_writer.messages[0]["data"] == "chunk-0" + assert client_writer.messages[49]["data"] == "chunk-49" + assert client_writer.messages[50]["type"] == "end" + + +class TestArbiterBackwardCompatibility: + """Tests for backward compatibility with non-streaming.""" + + @pytest.mark.asyncio + async def test_handles_regular_response(self): + """Test that regular (non-streaming) responses still work.""" + arbiter = create_arbiter() + client_writer = MockStreamWriter() + + response = make_response("req-123", [1, 2, 3, 4, 5]) + mock_reader = MockStreamReader([response]) + + async def mock_get_connection(pid): + return mock_reader, MockStreamWriter() + + arbiter._get_worker_connection = mock_get_connection + + request = make_request("req-123", "test:App", "get_list") + await arbiter._execute_on_worker(1234, request, client_writer) + + assert len(client_writer.messages) == 1 + assert client_writer.messages[0]["type"] == "response" + assert client_writer.messages[0]["result"] == [1, 2, 3, 4, 5] + + @pytest.mark.asyncio + async def test_handles_error_response(self): + """Test that error responses still work.""" + arbiter = create_arbiter() + client_writer = MockStreamWriter() + + error = make_error_response("req-123", DirtyError("Something failed")) + mock_reader = MockStreamReader([error]) + + async def mock_get_connection(pid): + return mock_reader, MockStreamWriter() + + arbiter._get_worker_connection = mock_get_connection + + request = make_request("req-123", "test:App", "fail") + await arbiter._execute_on_worker(1234, request, client_writer) + + assert len(client_writer.messages) == 1 + assert client_writer.messages[0]["type"] == "error" diff --git a/tests/dirty/test_client_streaming.py b/tests/dirty/test_client_streaming.py new file mode 100644 index 00000000..7bc13525 --- /dev/null +++ b/tests/dirty/test_client_streaming.py @@ -0,0 +1,236 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +"""Tests for dirty client sync streaming functionality.""" + +import socket +import struct +import pytest +from unittest import mock + +from gunicorn.dirty.protocol import ( + DirtyProtocol, + make_chunk_message, + make_end_message, + make_response, + make_error_response, +) +from gunicorn.dirty.client import DirtyClient, DirtyStreamIterator +from gunicorn.dirty.errors import DirtyError, DirtyConnectionError + + +class MockSocket: + """Mock socket that returns predefined messages.""" + + def __init__(self, messages): + self._data = b'' + for msg in messages: + self._data += DirtyProtocol.encode(msg) + self._pos = 0 + self._sent = [] + self.closed = False + self._timeout = None + + def sendall(self, data): + self._sent.append(data) + + def recv(self, n, flags=0): + if self._pos >= len(self._data): + return b'' + end = min(self._pos + n, len(self._data)) + result = self._data[self._pos:end] + self._pos = end + return result + + def settimeout(self, timeout): + self._timeout = timeout + + def close(self): + self.closed = True + + +def create_client_with_mock_socket(messages): + """Create a client with a mock socket returning the given messages.""" + client = DirtyClient("/tmp/test.sock") + client._sock = MockSocket(messages) + return client + + +class TestDirtyStreamIterator: + """Tests for DirtyStreamIterator.""" + + def test_stream_returns_iterator(self): + """Test that stream() returns an iterator.""" + client = DirtyClient("/tmp/test.sock") + result = client.stream("test:App", "generate") + assert isinstance(result, DirtyStreamIterator) + + def test_stream_iterator_yields_chunks(self): + """Test that stream iterator yields chunks correctly.""" + messages = [ + make_chunk_message("req-123", "Hello"), + make_chunk_message("req-123", " "), + make_chunk_message("req-123", "World"), + make_end_message("req-123"), + ] + client = create_client_with_mock_socket(messages) + + chunks = list(client.stream("test:App", "generate")) + + assert chunks == ["Hello", " ", "World"] + + def test_stream_iterator_yields_complex_chunks(self): + """Test that stream iterator yields complex data types.""" + messages = [ + make_chunk_message("req-123", {"token": "Hello", "score": 0.9}), + make_chunk_message("req-123", {"token": "World", "score": 0.8}), + make_end_message("req-123"), + ] + client = create_client_with_mock_socket(messages) + + chunks = list(client.stream("test:App", "generate")) + + assert len(chunks) == 2 + assert chunks[0]["token"] == "Hello" + assert chunks[1]["token"] == "World" + + def test_stream_iterator_handles_error(self): + """Test that stream iterator raises on error message.""" + messages = [ + make_chunk_message("req-123", "First"), + make_error_response("req-123", DirtyError("Something broke")), + ] + client = create_client_with_mock_socket(messages) + + iterator = client.stream("test:App", "generate") + + # First chunk should work + chunk = next(iterator) + assert chunk == "First" + + # Second should raise error + with pytest.raises(DirtyError) as exc_info: + next(iterator) + assert "Something broke" in str(exc_info.value) + + def test_stream_iterator_empty_stream(self): + """Test that empty stream (just end) works.""" + messages = [make_end_message("req-123")] + client = create_client_with_mock_socket(messages) + + chunks = list(client.stream("test:App", "generate")) + assert chunks == [] + + def test_stream_iterator_stops_after_exhausted(self): + """Test that iterator stays exhausted after StopIteration.""" + messages = [ + make_chunk_message("req-123", "Only"), + make_end_message("req-123"), + ] + client = create_client_with_mock_socket(messages) + + iterator = client.stream("test:App", "generate") + + # Get the chunk + chunk = next(iterator) + assert chunk == "Only" + + # Should stop + with pytest.raises(StopIteration): + next(iterator) + + # Should stay stopped + with pytest.raises(StopIteration): + next(iterator) + + def test_stream_iterator_with_for_loop(self): + """Test stream iterator works in for loop.""" + messages = [ + make_chunk_message("req-123", "a"), + make_chunk_message("req-123", "b"), + make_chunk_message("req-123", "c"), + make_end_message("req-123"), + ] + client = create_client_with_mock_socket(messages) + + result = "" + for chunk in client.stream("test:App", "generate"): + result += chunk + + assert result == "abc" + + def test_stream_sends_request_on_first_iteration(self): + """Test that request is sent on first next() call.""" + messages = [ + make_chunk_message("req-123", "data"), + make_end_message("req-123"), + ] + client = create_client_with_mock_socket(messages) + + iterator = client.stream("test:App", "generate", "prompt_arg") + + # Before iteration, no request sent + assert len(client._sock._sent) == 0 + + # First iteration sends request + next(iterator) + assert len(client._sock._sent) == 1 + + # Decode sent request + sent_data = client._sock._sent[0] + length = struct.unpack( + DirtyProtocol.HEADER_FORMAT, + sent_data[:DirtyProtocol.HEADER_SIZE] + )[0] + request = DirtyProtocol.decode( + sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length] + ) + + assert request["type"] == "request" + assert request["app_path"] == "test:App" + assert request["action"] == "generate" + assert request["args"] == ["prompt_arg"] + + +class TestDirtyStreamIteratorEdgeCases: + """Edge cases for streaming.""" + + def test_stream_many_chunks(self): + """Test streaming with many chunks.""" + messages = [] + for i in range(100): + messages.append(make_chunk_message("req-123", f"chunk-{i}")) + messages.append(make_end_message("req-123")) + + client = create_client_with_mock_socket(messages) + + chunks = list(client.stream("test:App", "generate")) + + assert len(chunks) == 100 + assert chunks[0] == "chunk-0" + assert chunks[99] == "chunk-99" + + def test_stream_with_kwargs(self): + """Test streaming with keyword arguments.""" + messages = [ + make_chunk_message("req-123", "data"), + make_end_message("req-123"), + ] + client = create_client_with_mock_socket(messages) + + # Use kwargs + list(client.stream("test:App", "generate", "arg1", key="value")) + + # Check the sent request includes kwargs + sent_data = client._sock._sent[0] + length = struct.unpack( + DirtyProtocol.HEADER_FORMAT, + sent_data[:DirtyProtocol.HEADER_SIZE] + )[0] + request = DirtyProtocol.decode( + sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length] + ) + + assert request["args"] == ["arg1"] + assert request["kwargs"] == {"key": "value"} diff --git a/tests/dirty/test_client_streaming_async.py b/tests/dirty/test_client_streaming_async.py new file mode 100644 index 00000000..651c73d1 --- /dev/null +++ b/tests/dirty/test_client_streaming_async.py @@ -0,0 +1,267 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +"""Tests for dirty client async streaming functionality.""" + +import asyncio +import struct +import pytest + +from gunicorn.dirty.protocol import ( + DirtyProtocol, + make_chunk_message, + make_end_message, + make_error_response, +) +from gunicorn.dirty.client import DirtyClient, DirtyAsyncStreamIterator +from gunicorn.dirty.errors import DirtyError, DirtyTimeoutError + + +class MockAsyncReader: + """Mock async reader that returns predefined messages.""" + + def __init__(self, messages): + self._data = b'' + for msg in messages: + self._data += DirtyProtocol.encode(msg) + self._pos = 0 + + async def readexactly(self, n): + if self._pos + n > len(self._data): + raise asyncio.IncompleteReadError(self._data[self._pos:], n) + result = self._data[self._pos:self._pos + n] + self._pos += n + return result + + +class MockAsyncWriter: + """Mock async writer that captures sent data.""" + + def __init__(self): + self._sent = [] + self.closed = False + + def write(self, data): + self._sent.append(data) + + async def drain(self): + pass + + def close(self): + self.closed = True + + async def wait_closed(self): + pass + + +def create_async_client_with_mocks(messages): + """Create a client with mock async reader/writer.""" + client = DirtyClient("/tmp/test.sock") + client._reader = MockAsyncReader(messages) + client._writer = MockAsyncWriter() + return client + + +class TestDirtyAsyncStreamIterator: + """Tests for DirtyAsyncStreamIterator.""" + + def test_stream_async_returns_async_iterator(self): + """Test that stream_async() returns an async iterator.""" + client = DirtyClient("/tmp/test.sock") + result = client.stream_async("test:App", "generate") + assert isinstance(result, DirtyAsyncStreamIterator) + + @pytest.mark.asyncio + async def test_async_stream_yields_chunks(self): + """Test that async stream iterator yields chunks correctly.""" + messages = [ + make_chunk_message("req-123", "Hello"), + make_chunk_message("req-123", " "), + make_chunk_message("req-123", "World"), + make_end_message("req-123"), + ] + client = create_async_client_with_mocks(messages) + + chunks = [] + async for chunk in client.stream_async("test:App", "generate"): + chunks.append(chunk) + + assert chunks == ["Hello", " ", "World"] + + @pytest.mark.asyncio + async def test_async_stream_yields_complex_chunks(self): + """Test that async stream iterator yields complex data types.""" + messages = [ + make_chunk_message("req-123", {"token": "Hello", "score": 0.9}), + make_chunk_message("req-123", {"token": "World", "score": 0.8}), + make_end_message("req-123"), + ] + client = create_async_client_with_mocks(messages) + + chunks = [] + async for chunk in client.stream_async("test:App", "generate"): + chunks.append(chunk) + + assert len(chunks) == 2 + assert chunks[0]["token"] == "Hello" + assert chunks[1]["token"] == "World" + + @pytest.mark.asyncio + async def test_async_stream_handles_error(self): + """Test that async stream iterator raises on error message.""" + messages = [ + make_chunk_message("req-123", "First"), + make_error_response("req-123", DirtyError("Something broke")), + ] + client = create_async_client_with_mocks(messages) + + iterator = client.stream_async("test:App", "generate") + + # First chunk should work + chunk = await iterator.__anext__() + assert chunk == "First" + + # Second should raise error + with pytest.raises(DirtyError) as exc_info: + await iterator.__anext__() + assert "Something broke" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_async_stream_empty_stream(self): + """Test that empty stream (just end) works.""" + messages = [make_end_message("req-123")] + client = create_async_client_with_mocks(messages) + + chunks = [] + async for chunk in client.stream_async("test:App", "generate"): + chunks.append(chunk) + + assert chunks == [] + + @pytest.mark.asyncio + async def test_async_stream_stops_after_exhausted(self): + """Test that async iterator stays exhausted after StopAsyncIteration.""" + messages = [ + make_chunk_message("req-123", "Only"), + make_end_message("req-123"), + ] + client = create_async_client_with_mocks(messages) + + iterator = client.stream_async("test:App", "generate") + + # Get the chunk + chunk = await iterator.__anext__() + assert chunk == "Only" + + # Should stop + with pytest.raises(StopAsyncIteration): + await iterator.__anext__() + + # Should stay stopped + with pytest.raises(StopAsyncIteration): + await iterator.__anext__() + + @pytest.mark.asyncio + async def test_async_stream_sends_request_on_first_iteration(self): + """Test that request is sent on first async iteration.""" + messages = [ + make_chunk_message("req-123", "data"), + make_end_message("req-123"), + ] + client = create_async_client_with_mocks(messages) + + iterator = client.stream_async("test:App", "generate", "prompt_arg") + + # Before iteration, no request sent + assert len(client._writer._sent) == 0 + + # First iteration sends request + await iterator.__anext__() + assert len(client._writer._sent) == 1 + + # Decode sent request + sent_data = client._writer._sent[0] + length = struct.unpack( + DirtyProtocol.HEADER_FORMAT, + sent_data[:DirtyProtocol.HEADER_SIZE] + )[0] + request = DirtyProtocol.decode( + sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length] + ) + + assert request["type"] == "request" + assert request["app_path"] == "test:App" + assert request["action"] == "generate" + assert request["args"] == ["prompt_arg"] + + +class TestDirtyAsyncStreamIteratorEdgeCases: + """Edge cases for async streaming.""" + + @pytest.mark.asyncio + async def test_async_stream_many_chunks(self): + """Test async streaming with many chunks.""" + messages = [] + for i in range(100): + messages.append(make_chunk_message("req-123", f"chunk-{i}")) + messages.append(make_end_message("req-123")) + + client = create_async_client_with_mocks(messages) + + chunks = [] + async for chunk in client.stream_async("test:App", "generate"): + chunks.append(chunk) + + assert len(chunks) == 100 + assert chunks[0] == "chunk-0" + assert chunks[99] == "chunk-99" + + @pytest.mark.asyncio + async def test_async_stream_with_kwargs(self): + """Test async streaming with keyword arguments.""" + messages = [ + make_chunk_message("req-123", "data"), + make_end_message("req-123"), + ] + client = create_async_client_with_mocks(messages) + + # Use kwargs + chunks = [] + async for chunk in client.stream_async("test:App", "generate", "arg1", key="value"): + chunks.append(chunk) + + # Check the sent request includes kwargs + sent_data = client._writer._sent[0] + length = struct.unpack( + DirtyProtocol.HEADER_FORMAT, + sent_data[:DirtyProtocol.HEADER_SIZE] + )[0] + request = DirtyProtocol.decode( + sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length] + ) + + assert request["args"] == ["arg1"] + assert request["kwargs"] == {"key": "value"} + + +class TestDirtyAsyncStreamTimeout: + """Tests for async streaming timeout handling.""" + + @pytest.mark.asyncio + async def test_async_stream_timeout(self): + """Test that timeout during async streaming raises DirtyTimeoutError.""" + client = DirtyClient("/tmp/test.sock", timeout=0.01) + + # Create a reader that times out + class SlowReader: + async def readexactly(self, n): + await asyncio.sleep(1) # Longer than timeout + + client._reader = SlowReader() + client._writer = MockAsyncWriter() + + iterator = client.stream_async("test:App", "generate") + + with pytest.raises(DirtyTimeoutError): + await iterator.__anext__() diff --git a/tests/dirty/test_streaming_integration.py b/tests/dirty/test_streaming_integration.py new file mode 100644 index 00000000..06b9645f --- /dev/null +++ b/tests/dirty/test_streaming_integration.py @@ -0,0 +1,469 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +"""Integration tests for dirty streaming functionality. + +These tests verify the full streaming pipeline: +client -> arbiter -> worker -> generator -> chunks -> client +""" + +import asyncio +import os +import struct +import tempfile +import pytest +from unittest import mock + +from gunicorn.config import Config +from gunicorn.dirty.protocol import ( + DirtyProtocol, + make_request, + make_chunk_message, + make_end_message, + make_response, + make_error_response, +) +from gunicorn.dirty.worker import DirtyWorker +from gunicorn.dirty.arbiter import DirtyArbiter +from gunicorn.dirty.client import DirtyClient +from gunicorn.dirty.errors import DirtyError + + +class MockLog: + """Mock logger for testing.""" + + def __init__(self): + self.messages = [] + + def debug(self, msg, *args): + self.messages.append(("debug", msg % args if args else msg)) + + def info(self, msg, *args): + self.messages.append(("info", msg % args if args else msg)) + + def warning(self, msg, *args): + self.messages.append(("warning", msg % args if args else msg)) + + def error(self, msg, *args): + self.messages.append(("error", msg % args if args else msg)) + + def close_on_exec(self): + pass + + def reopen_files(self): + pass + + +class MockStreamWriter: + """Mock StreamWriter that captures written messages.""" + + def __init__(self): + self.messages = [] + self._buffer = b"" + self.closed = False + + def write(self, data): + self._buffer += data + + async def drain(self): + while len(self._buffer) >= DirtyProtocol.HEADER_SIZE: + length = struct.unpack( + DirtyProtocol.HEADER_FORMAT, + self._buffer[:DirtyProtocol.HEADER_SIZE] + )[0] + total_size = DirtyProtocol.HEADER_SIZE + length + if len(self._buffer) >= total_size: + msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size] + self._buffer = self._buffer[total_size:] + self.messages.append(DirtyProtocol.decode(msg_data)) + else: + break + + def close(self): + self.closed = True + + async def wait_closed(self): + pass + + def get_extra_info(self, name): + return None + + +class MockStreamReader: + """Mock StreamReader that yields predefined messages.""" + + def __init__(self, messages): + self._data = b'' + for msg in messages: + self._data += DirtyProtocol.encode(msg) + self._pos = 0 + + async def readexactly(self, n): + if self._pos + n > len(self._data): + raise asyncio.IncompleteReadError(self._data[self._pos:], n) + result = self._data[self._pos:self._pos + n] + self._pos += n + return result + + +class TestStreamingEndToEnd: + """End-to-end streaming tests using mocked components.""" + + @pytest.mark.asyncio + async def test_sync_generator_end_to_end(self): + """Test complete flow: sync generator -> worker -> arbiter -> client.""" + # Simulate what a worker would produce for a sync generator + worker_messages = [ + make_chunk_message("req-123", "Hello"), + make_chunk_message("req-123", " "), + make_chunk_message("req-123", "World"), + make_end_message("req-123"), + ] + + # Create an arbiter with mocked worker connection + cfg = Config() + cfg.set("dirty_timeout", 30) + log = MockLog() + + arbiter = DirtyArbiter(cfg=cfg, log=log) + arbiter.alive = True + arbiter.workers = {1234: mock.Mock()} + arbiter.worker_sockets = {1234: '/tmp/worker.sock'} + + # Mock worker connection + mock_reader = MockStreamReader(worker_messages) + async def mock_get_connection(pid): + return mock_reader, MockStreamWriter() + arbiter._get_worker_connection = mock_get_connection + + # Create client writer to capture messages + client_writer = MockStreamWriter() + + # Execute request through arbiter + request = make_request("req-123", "test:App", "generate") + await arbiter._execute_on_worker(1234, request, client_writer) + + # Verify all messages were forwarded + assert len(client_writer.messages) == 4 + assert client_writer.messages[0]["type"] == "chunk" + assert client_writer.messages[0]["data"] == "Hello" + assert client_writer.messages[1]["data"] == " " + assert client_writer.messages[2]["data"] == "World" + assert client_writer.messages[3]["type"] == "end" + + arbiter._cleanup_sync() + + @pytest.mark.asyncio + async def test_async_generator_end_to_end(self): + """Test complete flow: async generator -> worker -> arbiter -> client.""" + worker_messages = [ + make_chunk_message("req-456", "Async"), + make_chunk_message("req-456", " "), + make_chunk_message("req-456", "Stream"), + make_end_message("req-456"), + ] + + cfg = Config() + cfg.set("dirty_timeout", 30) + log = MockLog() + + arbiter = DirtyArbiter(cfg=cfg, log=log) + arbiter.alive = True + arbiter.workers = {1234: mock.Mock()} + arbiter.worker_sockets = {1234: '/tmp/worker.sock'} + + mock_reader = MockStreamReader(worker_messages) + async def mock_get_connection(pid): + return mock_reader, MockStreamWriter() + arbiter._get_worker_connection = mock_get_connection + + client_writer = MockStreamWriter() + + request = make_request("req-456", "test:App", "async_generate") + await arbiter._execute_on_worker(1234, request, client_writer) + + assert len(client_writer.messages) == 4 + assert client_writer.messages[0]["data"] == "Async" + assert client_writer.messages[3]["type"] == "end" + + arbiter._cleanup_sync() + + +class TestStreamingErrorHandling: + """Tests for error handling during streaming.""" + + @pytest.mark.asyncio + async def test_error_mid_stream(self): + """Test that errors during streaming are properly forwarded.""" + worker_messages = [ + make_chunk_message("req-789", "First"), + make_chunk_message("req-789", "Second"), + make_error_response("req-789", DirtyError("Stream failed")), + ] + + cfg = Config() + cfg.set("dirty_timeout", 30) + log = MockLog() + + arbiter = DirtyArbiter(cfg=cfg, log=log) + arbiter.alive = True + arbiter.workers = {1234: mock.Mock()} + arbiter.worker_sockets = {1234: '/tmp/worker.sock'} + + mock_reader = MockStreamReader(worker_messages) + async def mock_get_connection(pid): + return mock_reader, MockStreamWriter() + arbiter._get_worker_connection = mock_get_connection + + client_writer = MockStreamWriter() + + request = make_request("req-789", "test:App", "generate_with_error") + await arbiter._execute_on_worker(1234, request, client_writer) + + # Should have 2 chunks + 1 error + assert len(client_writer.messages) == 3 + assert client_writer.messages[0]["type"] == "chunk" + assert client_writer.messages[1]["type"] == "chunk" + assert client_writer.messages[2]["type"] == "error" + assert "Stream failed" in client_writer.messages[2]["error"]["message"] + + arbiter._cleanup_sync() + + +class TestStreamingBackwardCompatibility: + """Tests for backward compatibility with non-streaming responses.""" + + @pytest.mark.asyncio + async def test_non_streaming_response_still_works(self): + """Test that regular (non-streaming) responses still work.""" + worker_messages = [ + make_response("req-abc", {"result": 42, "data": [1, 2, 3]}), + ] + + cfg = Config() + cfg.set("dirty_timeout", 30) + log = MockLog() + + arbiter = DirtyArbiter(cfg=cfg, log=log) + arbiter.alive = True + arbiter.workers = {1234: mock.Mock()} + arbiter.worker_sockets = {1234: '/tmp/worker.sock'} + + mock_reader = MockStreamReader(worker_messages) + async def mock_get_connection(pid): + return mock_reader, MockStreamWriter() + arbiter._get_worker_connection = mock_get_connection + + client_writer = MockStreamWriter() + + request = make_request("req-abc", "test:App", "compute") + await arbiter._execute_on_worker(1234, request, client_writer) + + # Should have 1 response + assert len(client_writer.messages) == 1 + assert client_writer.messages[0]["type"] == "response" + assert client_writer.messages[0]["result"]["result"] == 42 + + arbiter._cleanup_sync() + + @pytest.mark.asyncio + async def test_error_response_still_works(self): + """Test that error responses still work.""" + worker_messages = [ + make_error_response("req-def", DirtyError("Something failed")), + ] + + cfg = Config() + cfg.set("dirty_timeout", 30) + log = MockLog() + + arbiter = DirtyArbiter(cfg=cfg, log=log) + arbiter.alive = True + arbiter.workers = {1234: mock.Mock()} + arbiter.worker_sockets = {1234: '/tmp/worker.sock'} + + mock_reader = MockStreamReader(worker_messages) + async def mock_get_connection(pid): + return mock_reader, MockStreamWriter() + arbiter._get_worker_connection = mock_get_connection + + client_writer = MockStreamWriter() + + request = make_request("req-def", "test:App", "fail") + await arbiter._execute_on_worker(1234, request, client_writer) + + assert len(client_writer.messages) == 1 + assert client_writer.messages[0]["type"] == "error" + + arbiter._cleanup_sync() + + +class TestStreamingWorkerIntegration: + """Integration tests for worker streaming with execute.""" + + @pytest.mark.asyncio + async def test_worker_handles_sync_generator(self): + """Test worker properly handles sync generator from execute.""" + cfg = Config() + cfg.set("dirty_timeout", 300) + log = MockLog() + + with mock.patch('gunicorn.dirty.worker.WorkerTmp'): + worker = DirtyWorker( + age=1, + ppid=os.getpid(), + app_paths=["test:App"], + cfg=cfg, + log=log, + socket_path="/tmp/test.sock" + ) + + worker.apps = {} + worker._executor = None + worker.tmp = mock.Mock() + + writer = MockStreamWriter() + + # Mock execute to return a sync generator + def sync_gen(): + yield "one" + yield "two" + yield "three" + + async def mock_execute(app_path, action, args, kwargs): + return sync_gen() + + with mock.patch.object(worker, 'execute', side_effect=mock_execute): + request = make_request("req-123", "test:App", "generate") + await worker.handle_request(request, writer) + + # Should have 3 chunks + 1 end + assert len(writer.messages) == 4 + assert writer.messages[0]["data"] == "one" + assert writer.messages[1]["data"] == "two" + assert writer.messages[2]["data"] == "three" + assert writer.messages[3]["type"] == "end" + + @pytest.mark.asyncio + async def test_worker_handles_async_generator(self): + """Test worker properly handles async generator from execute.""" + cfg = Config() + cfg.set("dirty_timeout", 300) + log = MockLog() + + with mock.patch('gunicorn.dirty.worker.WorkerTmp'): + worker = DirtyWorker( + age=1, + ppid=os.getpid(), + app_paths=["test:App"], + cfg=cfg, + log=log, + socket_path="/tmp/test.sock" + ) + + worker.apps = {} + worker._executor = None + worker.tmp = mock.Mock() + + writer = MockStreamWriter() + + # Mock execute to return an async generator + async def async_gen(): + yield "async_one" + yield "async_two" + + async def mock_execute(app_path, action, args, kwargs): + return async_gen() + + with mock.patch.object(worker, 'execute', side_effect=mock_execute): + request = make_request("req-456", "test:App", "async_generate") + await worker.handle_request(request, writer) + + # Should have 2 chunks + 1 end + assert len(writer.messages) == 3 + assert writer.messages[0]["data"] == "async_one" + assert writer.messages[1]["data"] == "async_two" + assert writer.messages[2]["type"] == "end" + + +class TestStreamingMixedScenarios: + """Tests for mixed streaming scenarios.""" + + @pytest.mark.asyncio + async def test_large_stream(self): + """Test streaming with many chunks.""" + worker_messages = [] + for i in range(500): + worker_messages.append(make_chunk_message("req-large", f"chunk-{i}")) + worker_messages.append(make_end_message("req-large")) + + cfg = Config() + cfg.set("dirty_timeout", 30) + log = MockLog() + + arbiter = DirtyArbiter(cfg=cfg, log=log) + arbiter.alive = True + arbiter.workers = {1234: mock.Mock()} + arbiter.worker_sockets = {1234: '/tmp/worker.sock'} + + mock_reader = MockStreamReader(worker_messages) + async def mock_get_connection(pid): + return mock_reader, MockStreamWriter() + arbiter._get_worker_connection = mock_get_connection + + client_writer = MockStreamWriter() + + request = make_request("req-large", "test:App", "large_stream") + await arbiter._execute_on_worker(1234, request, client_writer) + + # Should have 500 chunks + 1 end + assert len(client_writer.messages) == 501 + assert client_writer.messages[0]["data"] == "chunk-0" + assert client_writer.messages[499]["data"] == "chunk-499" + assert client_writer.messages[500]["type"] == "end" + + arbiter._cleanup_sync() + + @pytest.mark.asyncio + async def test_stream_with_complex_data(self): + """Test streaming with complex JSON-serializable data.""" + worker_messages = [ + make_chunk_message("req-complex", { + "token": "Hello", + "scores": [0.1, 0.2, 0.3], + "metadata": {"position": 0} + }), + make_chunk_message("req-complex", { + "token": "World", + "scores": [0.4, 0.5], + "metadata": {"position": 1} + }), + make_end_message("req-complex"), + ] + + cfg = Config() + cfg.set("dirty_timeout", 30) + log = MockLog() + + arbiter = DirtyArbiter(cfg=cfg, log=log) + arbiter.alive = True + arbiter.workers = {1234: mock.Mock()} + arbiter.worker_sockets = {1234: '/tmp/worker.sock'} + + mock_reader = MockStreamReader(worker_messages) + async def mock_get_connection(pid): + return mock_reader, MockStreamWriter() + arbiter._get_worker_connection = mock_get_connection + + client_writer = MockStreamWriter() + + request = make_request("req-complex", "test:App", "complex_stream") + await arbiter._execute_on_worker(1234, request, client_writer) + + assert len(client_writer.messages) == 3 + assert client_writer.messages[0]["data"]["token"] == "Hello" + assert client_writer.messages[0]["data"]["scores"] == [0.1, 0.2, 0.3] + assert client_writer.messages[1]["data"]["metadata"]["position"] == 1 + + arbiter._cleanup_sync() diff --git a/tests/dirty/test_worker_streaming.py b/tests/dirty/test_worker_streaming.py new file mode 100644 index 00000000..bb674590 --- /dev/null +++ b/tests/dirty/test_worker_streaming.py @@ -0,0 +1,415 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +"""Tests for dirty worker streaming functionality.""" + +import asyncio +import struct +from unittest import mock + +import pytest + +from gunicorn.dirty.protocol import ( + DirtyProtocol, + make_request, + make_chunk_message, + make_end_message, +) +from gunicorn.dirty.worker import DirtyWorker + + +class FakeStreamWriter: + """Mock StreamWriter that captures written messages.""" + + def __init__(self): + self.messages = [] + self._buffer = b"" + + def write(self, data): + self._buffer += data + + async def drain(self): + # Decode the buffer to extract messages + while len(self._buffer) >= DirtyProtocol.HEADER_SIZE: + length = struct.unpack( + DirtyProtocol.HEADER_FORMAT, + self._buffer[:DirtyProtocol.HEADER_SIZE] + )[0] + total_size = DirtyProtocol.HEADER_SIZE + length + if len(self._buffer) >= total_size: + msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size] + self._buffer = self._buffer[total_size:] + self.messages.append(DirtyProtocol.decode(msg_data)) + else: + break + + def close(self): + pass + + async def wait_closed(self): + pass + + +def create_worker(): + """Create a test worker with mocked components.""" + cfg = mock.Mock() + cfg.dirty_timeout = 30 + cfg.dirty_threads = 1 + cfg.env = None + cfg.uid = None + cfg.gid = None + cfg.initgroups = False + cfg.dirty_worker_init = mock.Mock() + cfg.umask = 0o22 + + log = mock.Mock() + + with mock.patch('gunicorn.dirty.worker.WorkerTmp'): + worker = DirtyWorker( + age=1, + ppid=1, + app_paths=["test:App"], + cfg=cfg, + log=log, + socket_path="/tmp/test.sock" + ) + + worker.apps = {} + worker._executor = None # Use default executor for sync generator tests + worker.tmp = mock.Mock() + + return worker + + +class TestWorkerSyncGeneratorStreaming: + """Tests for sync generator streaming.""" + + @pytest.mark.asyncio + async def test_sync_generator_sends_chunks_and_end(self): + """Test that sync generator sends chunk messages then end message.""" + def generate_tokens(): + yield "Hello" + yield " " + yield "World" + + worker = create_worker() + writer = FakeStreamWriter() + + # Mock execute to return the sync generator directly + async def mock_execute(app_path, action, args, kwargs): + return generate_tokens() + + with mock.patch.object(worker, 'execute', side_effect=mock_execute): + request = make_request("req-123", "test:App", "generate") + await worker.handle_request(request, writer) + + # Should have 3 chunks + 1 end message + assert len(writer.messages) == 4 + + # Check chunk messages + assert writer.messages[0]["type"] == "chunk" + assert writer.messages[0]["id"] == "req-123" + assert writer.messages[0]["data"] == "Hello" + + assert writer.messages[1]["type"] == "chunk" + assert writer.messages[1]["data"] == " " + + assert writer.messages[2]["type"] == "chunk" + assert writer.messages[2]["data"] == "World" + + # Check end message + assert writer.messages[3]["type"] == "end" + assert writer.messages[3]["id"] == "req-123" + + @pytest.mark.asyncio + async def test_sync_generator_error_mid_stream(self): + """Test that error during streaming sends error message.""" + def generate_with_error(): + yield "First" + raise ValueError("Something went wrong") + + worker = create_worker() + writer = FakeStreamWriter() + + async def mock_execute(app_path, action, args, kwargs): + return generate_with_error() + + with mock.patch.object(worker, 'execute', side_effect=mock_execute): + request = make_request("req-123", "test:App", "generate") + await worker.handle_request(request, writer) + + # Should have 1 chunk + 1 error message + assert len(writer.messages) == 2 + + assert writer.messages[0]["type"] == "chunk" + assert writer.messages[0]["data"] == "First" + + assert writer.messages[1]["type"] == "error" + assert "Something went wrong" in writer.messages[1]["error"]["message"] + + +class TestWorkerAsyncGeneratorStreaming: + """Tests for async generator streaming.""" + + @pytest.mark.asyncio + async def test_async_generator_sends_chunks_and_end(self): + """Test that async generator sends chunk messages then end message.""" + async def async_generate_tokens(): + yield "Hello" + yield " " + yield "World" + + worker = create_worker() + writer = FakeStreamWriter() + + async def mock_execute(app_path, action, args, kwargs): + return async_generate_tokens() + + with mock.patch.object(worker, 'execute', side_effect=mock_execute): + request = make_request("req-123", "test:App", "generate") + await worker.handle_request(request, writer) + + # Should have 3 chunks + 1 end message + assert len(writer.messages) == 4 + + # Check chunk messages + assert writer.messages[0]["type"] == "chunk" + assert writer.messages[0]["id"] == "req-123" + assert writer.messages[0]["data"] == "Hello" + + assert writer.messages[1]["type"] == "chunk" + assert writer.messages[1]["data"] == " " + + assert writer.messages[2]["type"] == "chunk" + assert writer.messages[2]["data"] == "World" + + # Check end message + assert writer.messages[3]["type"] == "end" + assert writer.messages[3]["id"] == "req-123" + + @pytest.mark.asyncio + async def test_async_generator_error_mid_stream(self): + """Test that error during async streaming sends error message.""" + async def async_generate_with_error(): + yield "First" + raise ValueError("Async error") + + worker = create_worker() + writer = FakeStreamWriter() + + async def mock_execute(app_path, action, args, kwargs): + return async_generate_with_error() + + with mock.patch.object(worker, 'execute', side_effect=mock_execute): + request = make_request("req-123", "test:App", "generate") + await worker.handle_request(request, writer) + + # Should have 1 chunk + 1 error message + assert len(writer.messages) == 2 + + assert writer.messages[0]["type"] == "chunk" + assert writer.messages[0]["data"] == "First" + + assert writer.messages[1]["type"] == "error" + assert "Async error" in writer.messages[1]["error"]["message"] + + +class TestWorkerNonStreamingBackwardCompat: + """Tests for backward compatibility with non-streaming responses.""" + + @pytest.mark.asyncio + async def test_non_generator_returns_response(self): + """Test that non-generator method returns regular response.""" + worker = create_worker() + writer = FakeStreamWriter() + + async def mock_execute(app_path, action, args, kwargs): + return args[0] + args[1] + + with mock.patch.object(worker, 'execute', side_effect=mock_execute): + request = make_request("req-123", "test:App", "compute", args=(2, 3)) + await worker.handle_request(request, writer) + + # Should have 1 response message + assert len(writer.messages) == 1 + assert writer.messages[0]["type"] == "response" + assert writer.messages[0]["id"] == "req-123" + assert writer.messages[0]["result"] == 5 + + @pytest.mark.asyncio + async def test_list_result_not_treated_as_streaming(self): + """Test that list result is not treated as streaming.""" + worker = create_worker() + writer = FakeStreamWriter() + + async def mock_execute(app_path, action, args, kwargs): + return [1, 2, 3, 4, 5] + + with mock.patch.object(worker, 'execute', side_effect=mock_execute): + request = make_request("req-123", "test:App", "get_list") + await worker.handle_request(request, writer) + + # Should have 1 response message (not 5 chunks) + assert len(writer.messages) == 1 + assert writer.messages[0]["type"] == "response" + assert writer.messages[0]["result"] == [1, 2, 3, 4, 5] + + @pytest.mark.asyncio + async def test_error_in_execute_sends_error(self): + """Test that error in execute sends error response.""" + worker = create_worker() + writer = FakeStreamWriter() + + async def mock_execute(app_path, action, args, kwargs): + raise RuntimeError("Failed!") + + with mock.patch.object(worker, 'execute', side_effect=mock_execute): + request = make_request("req-123", "test:App", "fail") + await worker.handle_request(request, writer) + + # Should have 1 error message + assert len(writer.messages) == 1 + assert writer.messages[0]["type"] == "error" + assert "Failed!" in writer.messages[0]["error"]["message"] + + @pytest.mark.asyncio + async def test_none_result(self): + """Test that None result works correctly.""" + worker = create_worker() + writer = FakeStreamWriter() + + async def mock_execute(app_path, action, args, kwargs): + return None + + with mock.patch.object(worker, 'execute', side_effect=mock_execute): + request = make_request("req-123", "test:App", "void") + await worker.handle_request(request, writer) + + # Should have 1 response message + assert len(writer.messages) == 1 + assert writer.messages[0]["type"] == "response" + assert writer.messages[0]["result"] is None + + +class TestWorkerStreamingComplexData: + """Tests for streaming with complex data types.""" + + @pytest.mark.asyncio + async def test_streaming_dict_chunks(self): + """Test streaming chunks that are dictionaries.""" + async def generate_tokens(): + yield {"token": "Hello", "score": 0.9} + yield {"token": "World", "score": 0.8} + + worker = create_worker() + writer = FakeStreamWriter() + + async def mock_execute(app_path, action, args, kwargs): + return generate_tokens() + + with mock.patch.object(worker, 'execute', side_effect=mock_execute): + request = make_request("req-123", "test:App", "generate") + await worker.handle_request(request, writer) + + assert len(writer.messages) == 3 # 2 chunks + 1 end + + assert writer.messages[0]["data"]["token"] == "Hello" + assert writer.messages[0]["data"]["score"] == 0.9 + assert writer.messages[1]["data"]["token"] == "World" + + @pytest.mark.asyncio + async def test_streaming_empty_generator(self): + """Test streaming with empty generator.""" + async def empty_generate(): + return + yield # Make it a generator + + worker = create_worker() + writer = FakeStreamWriter() + + async def mock_execute(app_path, action, args, kwargs): + return empty_generate() + + with mock.patch.object(worker, 'execute', side_effect=mock_execute): + request = make_request("req-123", "test:App", "generate") + await worker.handle_request(request, writer) + + # Should have just 1 end message + assert len(writer.messages) == 1 + assert writer.messages[0]["type"] == "end" + + @pytest.mark.asyncio + async def test_streaming_many_chunks(self): + """Test streaming with many chunks.""" + async def generate_many(): + for i in range(100): + yield f"chunk-{i}" + + worker = create_worker() + writer = FakeStreamWriter() + + async def mock_execute(app_path, action, args, kwargs): + return generate_many() + + with mock.patch.object(worker, 'execute', side_effect=mock_execute): + request = make_request("req-123", "test:App", "generate") + await worker.handle_request(request, writer) + + # Should have 100 chunks + 1 end message + assert len(writer.messages) == 101 + assert writer.messages[0]["data"] == "chunk-0" + assert writer.messages[99]["data"] == "chunk-99" + assert writer.messages[100]["type"] == "end" + + +class TestWorkerStreamingHeartbeat: + """Tests for heartbeat updates during streaming.""" + + @pytest.mark.asyncio + async def test_heartbeat_updated_during_streaming(self): + """Test that heartbeat is updated during streaming.""" + async def generate_tokens(): + yield "Hello" + yield "World" + + worker = create_worker() + writer = FakeStreamWriter() + + # Track notify calls + notify_count = [0] + original_notify = worker.notify + + def counting_notify(): + notify_count[0] += 1 + return original_notify() if callable(original_notify) else None + + worker.notify = counting_notify + + async def mock_execute(app_path, action, args, kwargs): + return generate_tokens() + + with mock.patch.object(worker, 'execute', side_effect=mock_execute): + request = make_request("req-123", "test:App", "generate") + await worker.handle_request(request, writer) + + # Should have been notified at least once per chunk + initial + assert notify_count[0] >= 2 # At least one per chunk + + +class TestWorkerMessageTypeValidation: + """Tests for message type validation.""" + + @pytest.mark.asyncio + async def test_unknown_message_type_sends_error(self): + """Test that unknown message type sends error response.""" + worker = create_worker() + writer = FakeStreamWriter() + + # Send a message with unknown type + message = {"type": "unknown", "id": "req-123"} + await worker.handle_request(message, writer) + + assert len(writer.messages) == 1 + assert writer.messages[0]["type"] == "error" + assert "Unknown message type" in writer.messages[0]["error"]["message"] diff --git a/tests/test_dirty_arbiter.py b/tests/test_dirty_arbiter.py index c8c9444c..11bcd796 100644 --- a/tests/test_dirty_arbiter.py +++ b/tests/test_dirty_arbiter.py @@ -7,6 +7,7 @@ import asyncio import os import signal +import struct import tempfile import pytest @@ -16,6 +17,41 @@ from gunicorn.dirty.errors import DirtyError from gunicorn.dirty.protocol import DirtyProtocol, make_request +class MockStreamWriter: + """Mock StreamWriter that captures written messages.""" + + def __init__(self): + self.messages = [] + self._buffer = b"" + self.closed = False + + def write(self, data): + self._buffer += data + + async def drain(self): + while len(self._buffer) >= DirtyProtocol.HEADER_SIZE: + length = struct.unpack( + DirtyProtocol.HEADER_FORMAT, + self._buffer[:DirtyProtocol.HEADER_SIZE] + )[0] + total_size = DirtyProtocol.HEADER_SIZE + length + if len(self._buffer) >= total_size: + msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size] + self._buffer = self._buffer[total_size:] + self.messages.append(DirtyProtocol.decode(msg_data)) + else: + break + + def close(self): + self.closed = True + + async def wait_closed(self): + pass + + def get_extra_info(self, name): + return None + + class MockLog: """Mock logger for testing.""" @@ -141,8 +177,11 @@ class TestDirtyArbiterRouteRequest: action="test" ) - response = await arbiter.route_request(request) + writer = MockStreamWriter() + await arbiter.route_request(request, writer) + assert len(writer.messages) == 1 + response = writer.messages[0] assert response["type"] == DirtyProtocol.MSG_TYPE_ERROR assert "No dirty workers available" in response["error"]["message"] @@ -430,8 +469,11 @@ class TestDirtyArbiterRouteTimeout: ) # This should fail because socket doesn't exist - response = await arbiter.route_request(request) + writer = MockStreamWriter() + await arbiter.route_request(request, writer) + assert len(writer.messages) == 1 + response = writer.messages[0] assert response["type"] == DirtyProtocol.MSG_TYPE_ERROR # Either "Worker communication failed" or "Worker socket not ready" assert "error" in response @@ -893,7 +935,8 @@ class TestDirtyArbiterQueueBehavior: ) # This will fail (no socket), but consumer should be started - await arbiter.route_request(request) + writer = MockStreamWriter() + await arbiter.route_request(request, writer) assert fake_pid in arbiter.worker_queues assert fake_pid in arbiter.worker_consumers diff --git a/tests/test_dirty_protocol.py b/tests/test_dirty_protocol.py index 7ea0da14..dbabc51e 100644 --- a/tests/test_dirty_protocol.py +++ b/tests/test_dirty_protocol.py @@ -15,6 +15,8 @@ from gunicorn.dirty.protocol import ( make_request, make_response, make_error_response, + make_chunk_message, + make_end_message, ) from gunicorn.dirty.errors import ( DirtyError, @@ -323,6 +325,51 @@ class TestMessageBuilders: assert response["error"]["error_type"] == "ValueError" assert response["error"]["message"] == "Invalid value" + def test_make_chunk_message(self): + """Test chunk message builder.""" + chunk = make_chunk_message("req-123", "Hello, ") + assert chunk["type"] == DirtyProtocol.MSG_TYPE_CHUNK + assert chunk["id"] == "req-123" + assert chunk["data"] == "Hello, " + + def test_make_chunk_message_with_complex_data(self): + """Test chunk message with complex data.""" + data = {"token": "world", "score": 0.95, "index": 5} + chunk = make_chunk_message("req-456", data) + assert chunk["type"] == DirtyProtocol.MSG_TYPE_CHUNK + assert chunk["id"] == "req-456" + assert chunk["data"] == data + + def test_make_chunk_message_with_list_data(self): + """Test chunk message with list data.""" + data = [1, 2, 3, "token"] + chunk = make_chunk_message("req-789", data) + assert chunk["data"] == data + + def test_make_end_message(self): + """Test end message builder.""" + end = make_end_message("req-123") + assert end["type"] == DirtyProtocol.MSG_TYPE_END + assert end["id"] == "req-123" + assert "data" not in end + + def test_chunk_and_end_encode_decode(self): + """Test that chunk and end messages can be encoded/decoded.""" + chunk = make_chunk_message("req-123", {"token": "hello"}) + end = make_end_message("req-123") + + # Test chunk roundtrip + encoded_chunk = DirtyProtocol.encode(chunk) + payload = encoded_chunk[DirtyProtocol.HEADER_SIZE:] + decoded = DirtyProtocol.decode(payload) + assert decoded == chunk + + # Test end roundtrip + encoded_end = DirtyProtocol.encode(end) + payload = encoded_end[DirtyProtocol.HEADER_SIZE:] + decoded = DirtyProtocol.decode(payload) + assert decoded == end + class TestDirtyErrors: """Tests for error classes.""" diff --git a/tests/test_dirty_worker.py b/tests/test_dirty_worker.py index 6bfaa68f..f68a2276 100644 --- a/tests/test_dirty_worker.py +++ b/tests/test_dirty_worker.py @@ -16,6 +16,9 @@ from gunicorn.dirty.protocol import DirtyProtocol, make_request from gunicorn.dirty.errors import DirtyAppNotFoundError +import struct + + class MockLog: """Mock logger for testing.""" @@ -41,6 +44,42 @@ class MockLog: pass +class MockStreamWriter: + """Mock StreamWriter that captures written messages.""" + + def __init__(self): + self.messages = [] + self._buffer = b"" + self.closed = False + + def write(self, data): + self._buffer += data + + async def drain(self): + # Decode the buffer to extract messages + while len(self._buffer) >= DirtyProtocol.HEADER_SIZE: + length = struct.unpack( + DirtyProtocol.HEADER_FORMAT, + self._buffer[:DirtyProtocol.HEADER_SIZE] + )[0] + total_size = DirtyProtocol.HEADER_SIZE + length + if len(self._buffer) >= total_size: + msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size] + self._buffer = self._buffer[total_size:] + self.messages.append(DirtyProtocol.decode(msg_data)) + else: + break + + def close(self): + self.closed = True + + async def wait_closed(self): + pass + + def get_extra_info(self, name): + return None + + class TestDirtyWorkerInit: """Tests for DirtyWorker initialization.""" @@ -214,8 +253,11 @@ class TestDirtyWorkerHandleRequest: kwargs={"operation": "multiply"} ) - response = await worker.handle_request(request) + writer = MockStreamWriter() + await worker.handle_request(request, writer) + assert len(writer.messages) == 1 + response = writer.messages[0] assert response["type"] == DirtyProtocol.MSG_TYPE_RESPONSE assert response["id"] == "test-123" assert response["result"] == 6 @@ -247,8 +289,11 @@ class TestDirtyWorkerHandleRequest: kwargs={"operation": "invalid"} ) - response = await worker.handle_request(request) + writer = MockStreamWriter() + await worker.handle_request(request, writer) + assert len(writer.messages) == 1 + response = writer.messages[0] assert response["type"] == DirtyProtocol.MSG_TYPE_ERROR assert response["id"] == "test-456" assert "Unknown operation" in response["error"]["message"] @@ -271,8 +316,11 @@ class TestDirtyWorkerHandleRequest: ) request = {"type": "unknown", "id": "test-789"} - response = await worker.handle_request(request) + writer = MockStreamWriter() + await worker.handle_request(request, writer) + assert len(writer.messages) == 1 + response = writer.messages[0] assert response["type"] == DirtyProtocol.MSG_TYPE_ERROR assert "Unknown message type" in response["error"]["message"] @@ -662,43 +710,21 @@ class TestDirtyWorkerRunAsync: reader.feed_data(encoded_request) reader.feed_eof() - class MockWriter: - def __init__(self): - self.closed = False - self.data = b"" - - def get_extra_info(self, name): - return None - - def write(self, data): - self.data += data - - async def drain(self): - pass - - def close(self): - self.closed = True - - async def wait_closed(self): - pass - - writer = MockWriter() + writer = MockStreamWriter() # Handle one message then exit worker.alive = True try: message = await DirtyProtocol.read_message_async(reader) - response = await worker.handle_request(message) - await DirtyProtocol.write_message_async(writer, response) + await worker.handle_request(message, writer) except asyncio.IncompleteReadError: pass - # Decode response from writer - if writer.data: - payload = writer.data[DirtyProtocol.HEADER_SIZE:] - response = DirtyProtocol.decode(payload) - assert response["type"] == DirtyProtocol.MSG_TYPE_RESPONSE - assert response["result"] == 8 + # Check response from writer + assert len(writer.messages) == 1 + response = writer.messages[0] + assert response["type"] == DirtyProtocol.MSG_TYPE_RESPONSE + assert response["result"] == 8 worker._cleanup()