mirror of
https://github.com/frappe/gunicorn.git
synced 2026-07-03 19:21:29 +08:00
feat(dirty): add streaming support and async client benchmarks
Add support for streaming responses when dirty app actions return generators (sync or async). This enables real-time delivery of incremental results for use cases like LLM token generation. Features: - Streaming protocol with chunk/end/error message types - Worker support for sync and async generators - Arbiter forwarding of streaming messages - Deadline-based timeout handling - Async client streaming API Protocol: - Chunk messages (type: "chunk") contain partial data - End messages (type: "end") signal stream completion - Error messages can occur mid-stream New files: - benchmarks/dirty_streaming.py: Streaming benchmark suite - tests/dirty/test_*_streaming*.py: Streaming test coverage - docs/content/dirty.md: Streaming documentation with examples
This commit is contained in:
parent
62a29bd0e1
commit
f6418d4eb0
755
benchmarks/dirty_streaming.py
Normal file
755
benchmarks/dirty_streaming.py
Normal file
@ -0,0 +1,755 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
Benchmark suite for dirty worker streaming functionality.
|
||||
|
||||
This script benchmarks the streaming performance of dirty workers
|
||||
to measure throughput, latency, and memory usage.
|
||||
|
||||
Usage:
|
||||
python benchmarks/dirty_streaming.py [OPTIONS]
|
||||
|
||||
Options:
|
||||
--quick Run quick benchmarks only
|
||||
--full Run full benchmark suite including stress tests
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import struct
|
||||
import sys
|
||||
import time
|
||||
import tracemalloc
|
||||
from datetime import datetime
|
||||
from unittest import mock
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from gunicorn.dirty.protocol import (
|
||||
DirtyProtocol,
|
||||
make_request,
|
||||
make_chunk_message,
|
||||
make_end_message,
|
||||
make_response,
|
||||
)
|
||||
from gunicorn.dirty.worker import DirtyWorker
|
||||
from gunicorn.dirty.arbiter import DirtyArbiter
|
||||
from gunicorn.dirty.client import (
|
||||
DirtyClient,
|
||||
DirtyStreamIterator,
|
||||
DirtyAsyncStreamIterator,
|
||||
)
|
||||
from gunicorn.config import Config
|
||||
|
||||
|
||||
class MockStreamWriter:
|
||||
"""Mock StreamWriter that captures written messages."""
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
self._buffer = b""
|
||||
self.bytes_written = 0
|
||||
|
||||
def write(self, data):
|
||||
self._buffer += data
|
||||
self.bytes_written += len(data)
|
||||
|
||||
async def drain(self):
|
||||
while len(self._buffer) >= DirtyProtocol.HEADER_SIZE:
|
||||
length = struct.unpack(
|
||||
DirtyProtocol.HEADER_FORMAT,
|
||||
self._buffer[:DirtyProtocol.HEADER_SIZE]
|
||||
)[0]
|
||||
total_size = DirtyProtocol.HEADER_SIZE + length
|
||||
if len(self._buffer) >= total_size:
|
||||
msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size]
|
||||
self._buffer = self._buffer[total_size:]
|
||||
self.messages.append(DirtyProtocol.decode(msg_data))
|
||||
else:
|
||||
break
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
async def wait_closed(self):
|
||||
pass
|
||||
|
||||
|
||||
class MockStreamReader:
|
||||
"""Mock StreamReader that yields predefined messages."""
|
||||
|
||||
def __init__(self, messages):
|
||||
self._data = b''
|
||||
for msg in messages:
|
||||
self._data += DirtyProtocol.encode(msg)
|
||||
self._pos = 0
|
||||
|
||||
async def readexactly(self, n):
|
||||
if self._pos + n > len(self._data):
|
||||
raise asyncio.IncompleteReadError(self._data[self._pos:], n)
|
||||
result = self._data[self._pos:self._pos + n]
|
||||
self._pos += n
|
||||
return result
|
||||
|
||||
|
||||
class MockLog:
|
||||
"""Silent logger for benchmarks."""
|
||||
|
||||
def debug(self, msg, *args):
|
||||
pass
|
||||
|
||||
def info(self, msg, *args):
|
||||
pass
|
||||
|
||||
def warning(self, msg, *args):
|
||||
pass
|
||||
|
||||
def error(self, msg, *args):
|
||||
pass
|
||||
|
||||
def close_on_exec(self):
|
||||
pass
|
||||
|
||||
def reopen_files(self):
|
||||
pass
|
||||
|
||||
|
||||
def create_worker():
|
||||
"""Create a test worker for benchmarks."""
|
||||
cfg = Config()
|
||||
cfg.set("dirty_timeout", 300)
|
||||
log = MockLog()
|
||||
|
||||
with mock.patch('gunicorn.dirty.worker.WorkerTmp'):
|
||||
worker = DirtyWorker(
|
||||
age=1,
|
||||
ppid=os.getpid(),
|
||||
app_paths=["benchmark:App"],
|
||||
cfg=cfg,
|
||||
log=log,
|
||||
socket_path="/tmp/benchmark.sock"
|
||||
)
|
||||
|
||||
worker.apps = {}
|
||||
worker._executor = None
|
||||
worker.tmp = mock.Mock()
|
||||
|
||||
return worker
|
||||
|
||||
|
||||
def create_arbiter():
|
||||
"""Create a test arbiter for benchmarks."""
|
||||
cfg = Config()
|
||||
cfg.set("dirty_timeout", 300)
|
||||
log = MockLog()
|
||||
|
||||
arbiter = DirtyArbiter(cfg=cfg, log=log)
|
||||
arbiter.alive = True
|
||||
arbiter.workers = {1234: mock.Mock()}
|
||||
arbiter.worker_sockets = {1234: '/tmp/worker.sock'}
|
||||
|
||||
return arbiter
|
||||
|
||||
|
||||
class BenchmarkResults:
|
||||
"""Store and display benchmark results."""
|
||||
|
||||
def __init__(self):
|
||||
self.results = []
|
||||
|
||||
def add(self, name, iterations, duration, chunks=None, bytes_total=None,
|
||||
memory_start=None, memory_end=None):
|
||||
throughput = iterations / duration if duration > 0 else 0
|
||||
result = {
|
||||
"name": name,
|
||||
"iterations": iterations,
|
||||
"duration_s": round(duration, 4),
|
||||
"throughput_per_s": round(throughput, 2),
|
||||
}
|
||||
if chunks:
|
||||
result["chunks_per_s"] = round(chunks / duration, 2)
|
||||
if bytes_total:
|
||||
result["mb_per_s"] = round(bytes_total / (1024 * 1024) / duration, 2)
|
||||
if memory_start is not None and memory_end is not None:
|
||||
result["memory_start_mb"] = round(memory_start / (1024 * 1024), 2)
|
||||
result["memory_end_mb"] = round(memory_end / (1024 * 1024), 2)
|
||||
result["memory_delta_mb"] = round((memory_end - memory_start) / (1024 * 1024), 2)
|
||||
self.results.append(result)
|
||||
|
||||
def display(self):
|
||||
print("\n" + "=" * 70)
|
||||
print("BENCHMARK RESULTS")
|
||||
print("=" * 70)
|
||||
for result in self.results:
|
||||
print(f"\n{result['name']}")
|
||||
print("-" * 50)
|
||||
for key, value in result.items():
|
||||
if key != "name":
|
||||
print(f" {key}: {value}")
|
||||
print("\n" + "=" * 70)
|
||||
|
||||
def save_json(self, filepath):
|
||||
with open(filepath, 'w') as f:
|
||||
json.dump({
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"results": self.results
|
||||
}, f, indent=2)
|
||||
print(f"Results saved to {filepath}")
|
||||
|
||||
|
||||
async def benchmark_worker_streaming_throughput(results, chunk_size=1024, num_chunks=1000):
|
||||
"""Benchmark worker streaming throughput with various chunk sizes."""
|
||||
worker = create_worker()
|
||||
writer = MockStreamWriter()
|
||||
|
||||
chunk_data = "x" * chunk_size
|
||||
|
||||
async def sync_gen():
|
||||
for _ in range(num_chunks):
|
||||
yield chunk_data
|
||||
|
||||
async def mock_execute(app_path, action, args, kwargs):
|
||||
return sync_gen()
|
||||
|
||||
gc.collect()
|
||||
tracemalloc.start()
|
||||
memory_start = tracemalloc.get_traced_memory()[0]
|
||||
|
||||
start = time.perf_counter()
|
||||
|
||||
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
|
||||
request = make_request("bench-1", "benchmark:App", "stream")
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
duration = time.perf_counter() - start
|
||||
memory_end = tracemalloc.get_traced_memory()[0]
|
||||
tracemalloc.stop()
|
||||
|
||||
total_bytes = chunk_size * num_chunks
|
||||
|
||||
results.add(
|
||||
f"Worker streaming ({chunk_size}B chunks, {num_chunks} chunks)",
|
||||
iterations=1,
|
||||
duration=duration,
|
||||
chunks=num_chunks,
|
||||
bytes_total=total_bytes,
|
||||
memory_start=memory_start,
|
||||
memory_end=memory_end
|
||||
)
|
||||
|
||||
|
||||
async def benchmark_arbiter_forwarding(results, num_chunks=1000):
|
||||
"""Benchmark arbiter message forwarding throughput."""
|
||||
arbiter = create_arbiter()
|
||||
|
||||
messages = []
|
||||
for i in range(num_chunks):
|
||||
messages.append(make_chunk_message(f"bench-{i}", f"data-{i}"))
|
||||
messages.append(make_end_message(f"bench-{num_chunks}"))
|
||||
|
||||
mock_reader = MockStreamReader(messages)
|
||||
|
||||
async def mock_get_connection(pid):
|
||||
return mock_reader, MockStreamWriter()
|
||||
|
||||
arbiter._get_worker_connection = mock_get_connection
|
||||
|
||||
client_writer = MockStreamWriter()
|
||||
|
||||
gc.collect()
|
||||
start = time.perf_counter()
|
||||
|
||||
request = make_request("bench-forward", "benchmark:App", "stream")
|
||||
await arbiter._execute_on_worker(1234, request, client_writer)
|
||||
|
||||
duration = time.perf_counter() - start
|
||||
|
||||
results.add(
|
||||
f"Arbiter forwarding ({num_chunks} chunks)",
|
||||
iterations=1,
|
||||
duration=duration,
|
||||
chunks=num_chunks,
|
||||
bytes_total=client_writer.bytes_written
|
||||
)
|
||||
|
||||
arbiter._cleanup_sync()
|
||||
|
||||
|
||||
async def benchmark_streaming_latency(results, iterations=100):
|
||||
"""Benchmark time-to-first-chunk and time-to-last-chunk."""
|
||||
worker = create_worker()
|
||||
|
||||
first_chunk_times = []
|
||||
total_times = []
|
||||
|
||||
for _ in range(iterations):
|
||||
writer = MockStreamWriter()
|
||||
|
||||
async def gen_3_chunks():
|
||||
yield "first"
|
||||
yield "second"
|
||||
yield "third"
|
||||
|
||||
async def mock_execute(app_path, action, args, kwargs):
|
||||
return gen_3_chunks()
|
||||
|
||||
start = time.perf_counter()
|
||||
|
||||
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
|
||||
request = make_request("bench-latency", "benchmark:App", "stream")
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
# Find time when first chunk was received
|
||||
if writer.messages:
|
||||
first_chunk_times.append(time.perf_counter() - start)
|
||||
|
||||
total_times.append(time.perf_counter() - start)
|
||||
|
||||
avg_first_chunk = sum(first_chunk_times) / len(first_chunk_times) if first_chunk_times else 0
|
||||
avg_total = sum(total_times) / len(total_times)
|
||||
|
||||
print(f"\nLatency Results ({iterations} iterations):")
|
||||
print(f" Avg time-to-first-chunk: {avg_first_chunk * 1000:.3f}ms")
|
||||
print(f" Avg time-to-last-chunk: {avg_total * 1000:.3f}ms")
|
||||
|
||||
results.add(
|
||||
f"Streaming latency ({iterations} iterations)",
|
||||
iterations=iterations,
|
||||
duration=sum(total_times),
|
||||
chunks=iterations * 3
|
||||
)
|
||||
|
||||
|
||||
async def benchmark_concurrent_streams(results, num_streams=10, chunks_per_stream=100):
|
||||
"""Benchmark multiple concurrent streams."""
|
||||
arbiter = create_arbiter()
|
||||
|
||||
async def run_stream(stream_id):
|
||||
messages = []
|
||||
for i in range(chunks_per_stream):
|
||||
messages.append(make_chunk_message(f"stream-{stream_id}", f"chunk-{i}"))
|
||||
messages.append(make_end_message(f"stream-{stream_id}"))
|
||||
|
||||
mock_reader = MockStreamReader(messages)
|
||||
async def mock_get_connection(pid):
|
||||
return mock_reader, MockStreamWriter()
|
||||
|
||||
arbiter._get_worker_connection = mock_get_connection
|
||||
client_writer = MockStreamWriter()
|
||||
|
||||
request = make_request(f"bench-concurrent-{stream_id}", "benchmark:App", "stream")
|
||||
await arbiter._execute_on_worker(1234, request, client_writer)
|
||||
return len(client_writer.messages)
|
||||
|
||||
gc.collect()
|
||||
start = time.perf_counter()
|
||||
|
||||
# Run streams concurrently
|
||||
tasks = [run_stream(i) for i in range(num_streams)]
|
||||
results_list = await asyncio.gather(*tasks)
|
||||
|
||||
duration = time.perf_counter() - start
|
||||
|
||||
total_chunks = sum(results_list)
|
||||
|
||||
results.add(
|
||||
f"Concurrent streams ({num_streams} streams, {chunks_per_stream} chunks each)",
|
||||
iterations=num_streams,
|
||||
duration=duration,
|
||||
chunks=total_chunks
|
||||
)
|
||||
|
||||
arbiter._cleanup_sync()
|
||||
|
||||
|
||||
async def benchmark_memory_stability(results, iterations=10, chunks=1000):
|
||||
"""Check memory stability over many iterations."""
|
||||
worker = create_worker()
|
||||
|
||||
gc.collect()
|
||||
tracemalloc.start()
|
||||
memory_samples = [tracemalloc.get_traced_memory()[0]]
|
||||
|
||||
for i in range(iterations):
|
||||
writer = MockStreamWriter()
|
||||
|
||||
async def gen_chunks():
|
||||
for j in range(chunks):
|
||||
yield f"chunk-{j}"
|
||||
|
||||
async def mock_execute(app_path, action, args, kwargs):
|
||||
return gen_chunks()
|
||||
|
||||
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
|
||||
request = make_request(f"bench-mem-{i}", "benchmark:App", "stream")
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
gc.collect()
|
||||
memory_samples.append(tracemalloc.get_traced_memory()[0])
|
||||
|
||||
tracemalloc.stop()
|
||||
|
||||
memory_start = memory_samples[0]
|
||||
memory_end = memory_samples[-1]
|
||||
memory_max = max(memory_samples)
|
||||
|
||||
print(f"\nMemory stability ({iterations} iterations of {chunks} chunks):")
|
||||
print(f" Start: {memory_start / 1024 / 1024:.2f}MB")
|
||||
print(f" End: {memory_end / 1024 / 1024:.2f}MB")
|
||||
print(f" Max: {memory_max / 1024 / 1024:.2f}MB")
|
||||
print(f" Delta: {(memory_end - memory_start) / 1024 / 1024:.2f}MB")
|
||||
|
||||
results.add(
|
||||
f"Memory stability ({iterations} x {chunks} chunks)",
|
||||
iterations=iterations * chunks,
|
||||
duration=0.001, # Use small non-zero value to avoid division by zero
|
||||
memory_start=memory_start,
|
||||
memory_end=memory_end
|
||||
)
|
||||
|
||||
|
||||
class MockClientReader:
|
||||
"""Mock async reader that simulates receiving streaming messages."""
|
||||
|
||||
def __init__(self, num_chunks, chunk_data):
|
||||
self.num_chunks = num_chunks
|
||||
self.chunk_data = chunk_data
|
||||
self._chunk_idx = 0
|
||||
self._messages = []
|
||||
self._build_messages()
|
||||
self._pos = 0
|
||||
self._data = b''
|
||||
for msg in self._messages:
|
||||
self._data += DirtyProtocol.encode(msg)
|
||||
|
||||
def _build_messages(self):
|
||||
for i in range(self.num_chunks):
|
||||
self._messages.append(make_chunk_message(f"bench-{i}", self.chunk_data))
|
||||
self._messages.append(make_end_message(f"bench-end"))
|
||||
|
||||
async def readexactly(self, n):
|
||||
if self._pos + n > len(self._data):
|
||||
raise asyncio.IncompleteReadError(self._data[self._pos:], n)
|
||||
result = self._data[self._pos:self._pos + n]
|
||||
self._pos += n
|
||||
return result
|
||||
|
||||
|
||||
class MockClientWriter:
|
||||
"""Mock async writer for client connection."""
|
||||
|
||||
def __init__(self):
|
||||
self._buffer = b""
|
||||
self._closed = False
|
||||
|
||||
def write(self, data):
|
||||
self._buffer += data
|
||||
|
||||
async def drain(self):
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
self._closed = True
|
||||
|
||||
async def wait_closed(self):
|
||||
pass
|
||||
|
||||
|
||||
async def benchmark_async_client_streaming(results, chunk_size=1024, num_chunks=1000):
|
||||
"""
|
||||
Benchmark DirtyAsyncStreamIterator directly.
|
||||
|
||||
Measures async iterator overhead vs raw message reading.
|
||||
"""
|
||||
chunk_data = "x" * chunk_size
|
||||
|
||||
# Create mock client with mock reader/writer
|
||||
client = DirtyClient("/tmp/benchmark.sock", timeout=30.0)
|
||||
client._reader = MockClientReader(num_chunks, chunk_data)
|
||||
client._writer = MockClientWriter()
|
||||
|
||||
gc.collect()
|
||||
tracemalloc.start()
|
||||
memory_start = tracemalloc.get_traced_memory()[0]
|
||||
|
||||
start = time.perf_counter()
|
||||
|
||||
# Use the async stream iterator directly
|
||||
iterator = DirtyAsyncStreamIterator(client, "benchmark:App", "stream", (), {})
|
||||
iterator._started = True # Skip the request sending
|
||||
iterator._request_id = "bench-async"
|
||||
iterator._deadline = time.perf_counter() + 300 # 5 min deadline
|
||||
iterator._last_chunk_time = time.perf_counter()
|
||||
|
||||
chunks_received = 0
|
||||
bytes_received = 0
|
||||
async for chunk in iterator:
|
||||
chunks_received += 1
|
||||
bytes_received += len(chunk)
|
||||
|
||||
duration = time.perf_counter() - start
|
||||
memory_end = tracemalloc.get_traced_memory()[0]
|
||||
tracemalloc.stop()
|
||||
|
||||
results.add(
|
||||
f"Async client streaming ({chunk_size}B chunks, {num_chunks} chunks)",
|
||||
iterations=1,
|
||||
duration=duration,
|
||||
chunks=chunks_received,
|
||||
bytes_total=bytes_received,
|
||||
memory_start=memory_start,
|
||||
memory_end=memory_end
|
||||
)
|
||||
|
||||
|
||||
async def benchmark_sync_client_streaming(results, chunk_size=1024, num_chunks=1000):
|
||||
"""
|
||||
Benchmark DirtyStreamIterator directly (for comparison with async).
|
||||
|
||||
Note: This runs the sync iterator within an async context for comparison.
|
||||
"""
|
||||
chunk_data = "x" * chunk_size
|
||||
|
||||
# Build raw message data
|
||||
messages_data = b''
|
||||
for i in range(num_chunks):
|
||||
msg = make_chunk_message(f"bench-{i}", chunk_data)
|
||||
messages_data += DirtyProtocol.encode(msg)
|
||||
messages_data += DirtyProtocol.encode(make_end_message("bench-end"))
|
||||
|
||||
# Create a mock socket-like object
|
||||
class MockSocket:
|
||||
def __init__(self, data):
|
||||
self._data = data
|
||||
self._pos = 0
|
||||
self._timeout = None
|
||||
|
||||
def recv(self, n, flags=0):
|
||||
if self._pos >= len(self._data):
|
||||
return b''
|
||||
result = self._data[self._pos:self._pos + n]
|
||||
self._pos += len(result)
|
||||
return result
|
||||
|
||||
def settimeout(self, timeout):
|
||||
self._timeout = timeout
|
||||
|
||||
# Create mock client
|
||||
client = DirtyClient("/tmp/benchmark.sock", timeout=30.0)
|
||||
client._sock = MockSocket(messages_data)
|
||||
|
||||
gc.collect()
|
||||
tracemalloc.start()
|
||||
memory_start = tracemalloc.get_traced_memory()[0]
|
||||
|
||||
start = time.perf_counter()
|
||||
|
||||
# Use the sync stream iterator
|
||||
iterator = DirtyStreamIterator(client, "benchmark:App", "stream", (), {})
|
||||
iterator._started = True # Skip the request sending
|
||||
iterator._request_id = "bench-sync"
|
||||
iterator._deadline = time.perf_counter() + 300 # 5 min deadline
|
||||
iterator._last_chunk_time = time.perf_counter()
|
||||
|
||||
chunks_received = 0
|
||||
bytes_received = 0
|
||||
for chunk in iterator:
|
||||
chunks_received += 1
|
||||
bytes_received += len(chunk)
|
||||
|
||||
duration = time.perf_counter() - start
|
||||
memory_end = tracemalloc.get_traced_memory()[0]
|
||||
tracemalloc.stop()
|
||||
|
||||
results.add(
|
||||
f"Sync client streaming ({chunk_size}B chunks, {num_chunks} chunks)",
|
||||
iterations=1,
|
||||
duration=duration,
|
||||
chunks=chunks_received,
|
||||
bytes_total=bytes_received,
|
||||
memory_start=memory_start,
|
||||
memory_end=memory_end
|
||||
)
|
||||
|
||||
|
||||
async def benchmark_async_vs_sync_client_streaming(results, chunk_size=1024, num_chunks=1000):
|
||||
"""
|
||||
Compare stream() vs stream_async() performance with the same workload.
|
||||
"""
|
||||
chunk_data = "x" * chunk_size
|
||||
|
||||
# --- Sync test ---
|
||||
messages_data = b''
|
||||
for i in range(num_chunks):
|
||||
msg = make_chunk_message(f"bench-{i}", chunk_data)
|
||||
messages_data += DirtyProtocol.encode(msg)
|
||||
messages_data += DirtyProtocol.encode(make_end_message("bench-end"))
|
||||
|
||||
class MockSocket:
|
||||
def __init__(self, data):
|
||||
self._data = data
|
||||
self._pos = 0
|
||||
self._timeout = None
|
||||
|
||||
def recv(self, n, flags=0):
|
||||
if self._pos >= len(self._data):
|
||||
return b''
|
||||
result = self._data[self._pos:self._pos + n]
|
||||
self._pos += len(result)
|
||||
return result
|
||||
|
||||
def settimeout(self, timeout):
|
||||
self._timeout = timeout
|
||||
|
||||
sync_client = DirtyClient("/tmp/benchmark.sock", timeout=30.0)
|
||||
sync_client._sock = MockSocket(messages_data)
|
||||
|
||||
gc.collect()
|
||||
sync_start = time.perf_counter()
|
||||
|
||||
sync_iter = DirtyStreamIterator(sync_client, "benchmark:App", "stream", (), {})
|
||||
sync_iter._started = True
|
||||
sync_iter._request_id = "bench-sync"
|
||||
sync_iter._deadline = time.perf_counter() + 300 # 5 min deadline
|
||||
sync_iter._last_chunk_time = time.perf_counter()
|
||||
|
||||
sync_chunks = 0
|
||||
for _ in sync_iter:
|
||||
sync_chunks += 1
|
||||
|
||||
sync_duration = time.perf_counter() - sync_start
|
||||
|
||||
# --- Async test ---
|
||||
async_client = DirtyClient("/tmp/benchmark.sock", timeout=30.0)
|
||||
async_client._reader = MockClientReader(num_chunks, chunk_data)
|
||||
async_client._writer = MockClientWriter()
|
||||
|
||||
gc.collect()
|
||||
async_start = time.perf_counter()
|
||||
|
||||
async_iter = DirtyAsyncStreamIterator(async_client, "benchmark:App", "stream", (), {})
|
||||
async_iter._started = True
|
||||
async_iter._request_id = "bench-async"
|
||||
async_iter._deadline = time.perf_counter() + 300 # 5 min deadline
|
||||
async_iter._last_chunk_time = time.perf_counter()
|
||||
|
||||
async_chunks = 0
|
||||
async for _ in async_iter:
|
||||
async_chunks += 1
|
||||
|
||||
async_duration = time.perf_counter() - async_start
|
||||
|
||||
# Report comparison
|
||||
print(f"\nSync vs Async Client Streaming Comparison ({num_chunks} x {chunk_size}B chunks):")
|
||||
print(f" Sync: {sync_duration * 1000:.3f}ms ({sync_chunks} chunks)")
|
||||
print(f" Async: {async_duration * 1000:.3f}ms ({async_chunks} chunks)")
|
||||
if sync_duration > 0:
|
||||
ratio = async_duration / sync_duration
|
||||
print(f" Ratio (async/sync): {ratio:.3f}x")
|
||||
|
||||
results.add(
|
||||
f"Sync client streaming comparison ({chunk_size}B, {num_chunks} chunks)",
|
||||
iterations=1,
|
||||
duration=sync_duration,
|
||||
chunks=sync_chunks,
|
||||
bytes_total=sync_chunks * chunk_size
|
||||
)
|
||||
|
||||
results.add(
|
||||
f"Async client streaming comparison ({chunk_size}B, {num_chunks} chunks)",
|
||||
iterations=1,
|
||||
duration=async_duration,
|
||||
chunks=async_chunks,
|
||||
bytes_total=async_chunks * chunk_size
|
||||
)
|
||||
|
||||
|
||||
async def run_quick_benchmarks():
|
||||
"""Run quick benchmarks."""
|
||||
results = BenchmarkResults()
|
||||
|
||||
print("Running quick benchmarks...")
|
||||
|
||||
await benchmark_worker_streaming_throughput(results, chunk_size=64, num_chunks=1000)
|
||||
await benchmark_worker_streaming_throughput(results, chunk_size=1024, num_chunks=1000)
|
||||
await benchmark_arbiter_forwarding(results, num_chunks=1000)
|
||||
await benchmark_streaming_latency(results, iterations=50)
|
||||
|
||||
# Async client streaming benchmarks
|
||||
await benchmark_async_client_streaming(results, chunk_size=1024, num_chunks=1000)
|
||||
await benchmark_async_vs_sync_client_streaming(results, chunk_size=1024, num_chunks=1000)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def run_full_benchmarks():
|
||||
"""Run full benchmark suite including stress tests."""
|
||||
results = BenchmarkResults()
|
||||
|
||||
print("Running full benchmark suite...")
|
||||
|
||||
# Throughput tests with different chunk sizes
|
||||
for chunk_size in [1, 64, 1024, 65536]:
|
||||
await benchmark_worker_streaming_throughput(
|
||||
results, chunk_size=chunk_size, num_chunks=1000
|
||||
)
|
||||
|
||||
# Arbiter forwarding
|
||||
await benchmark_arbiter_forwarding(results, num_chunks=10000)
|
||||
|
||||
# Latency
|
||||
await benchmark_streaming_latency(results, iterations=100)
|
||||
|
||||
# Concurrent streams
|
||||
await benchmark_concurrent_streams(results, num_streams=10, chunks_per_stream=100)
|
||||
await benchmark_concurrent_streams(results, num_streams=50, chunks_per_stream=100)
|
||||
|
||||
# Memory stability
|
||||
await benchmark_memory_stability(results, iterations=20, chunks=1000)
|
||||
|
||||
# Async client streaming benchmarks
|
||||
for chunk_size in [64, 1024, 65536]:
|
||||
await benchmark_async_client_streaming(results, chunk_size=chunk_size, num_chunks=1000)
|
||||
await benchmark_sync_client_streaming(results, chunk_size=chunk_size, num_chunks=1000)
|
||||
|
||||
# Comparison benchmark
|
||||
await benchmark_async_vs_sync_client_streaming(results, chunk_size=1024, num_chunks=5000)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Dirty streaming benchmarks")
|
||||
parser.add_argument("--quick", action="store_true", help="Run quick benchmarks only")
|
||||
parser.add_argument("--full", action="store_true", help="Run full benchmark suite")
|
||||
parser.add_argument("--output", "-o", help="Output JSON file path")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.full:
|
||||
results = asyncio.run(run_full_benchmarks())
|
||||
else:
|
||||
results = asyncio.run(run_quick_benchmarks())
|
||||
|
||||
results.display()
|
||||
|
||||
if args.output:
|
||||
results.save_json(args.output)
|
||||
else:
|
||||
# Save to default location
|
||||
output_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
results_dir = os.path.join(output_dir, "results")
|
||||
os.makedirs(results_dir, exist_ok=True)
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
output_file = os.path.join(results_dir, f"streaming_benchmark_{timestamp}.json")
|
||||
results.save_json(output_file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -200,6 +200,160 @@ async def my_view(request):
|
||||
return result
|
||||
```
|
||||
|
||||
## Streaming
|
||||
|
||||
Dirty Arbiters support streaming responses for use cases like LLM token generation, where data is produced incrementally. This enables real-time delivery of results without waiting for complete execution.
|
||||
|
||||
### Streaming with Generators
|
||||
|
||||
Any dirty app action that returns a generator (sync or async) automatically streams chunks to the client:
|
||||
|
||||
```python
|
||||
# myapp/llm.py
|
||||
from gunicorn.dirty import DirtyApp
|
||||
|
||||
class LLMApp(DirtyApp):
|
||||
def init(self):
|
||||
from transformers import pipeline
|
||||
self.generator = pipeline("text-generation", model="gpt2")
|
||||
|
||||
def generate(self, prompt):
|
||||
"""Sync streaming - yields tokens."""
|
||||
for token in self.generator(prompt, stream=True):
|
||||
yield token["generated_text"]
|
||||
|
||||
async def generate_async(self, prompt):
|
||||
"""Async streaming - yields tokens."""
|
||||
import openai
|
||||
client = openai.AsyncOpenAI()
|
||||
stream = await client.chat.completions.create(
|
||||
model="gpt-4",
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
stream=True
|
||||
)
|
||||
async for chunk in stream:
|
||||
if chunk.choices[0].delta.content:
|
||||
yield chunk.choices[0].delta.content
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
```
|
||||
|
||||
### Client Streaming API
|
||||
|
||||
Use `stream()` for sync workers and `stream_async()` for async workers:
|
||||
|
||||
**Sync Workers (sync, gthread):**
|
||||
|
||||
```python
|
||||
from gunicorn.dirty import get_dirty_client
|
||||
|
||||
def generate_view(request):
|
||||
client = get_dirty_client()
|
||||
|
||||
def generate_response():
|
||||
for chunk in client.stream("myapp.llm:LLMApp", "generate", request.prompt):
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(generate_response())
|
||||
```
|
||||
|
||||
**Async Workers (ASGI):**
|
||||
|
||||
```python
|
||||
from gunicorn.dirty import get_dirty_client_async
|
||||
|
||||
async def generate_view(request):
|
||||
client = await get_dirty_client_async()
|
||||
|
||||
async def generate_response():
|
||||
async for chunk in client.stream_async("myapp.llm:LLMApp", "generate", request.prompt):
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(generate_response())
|
||||
```
|
||||
|
||||
### Streaming Protocol
|
||||
|
||||
Streaming uses a simple protocol with three message types:
|
||||
|
||||
1. **Chunk** (`type: "chunk"`) - Contains partial data
|
||||
2. **End** (`type: "end"`) - Signals stream completion
|
||||
3. **Error** (`type: "error"`) - Signals error during streaming
|
||||
|
||||
Example message flow:
|
||||
```
|
||||
Client -> Arbiter -> Worker: request
|
||||
Worker -> Arbiter -> Client: chunk (data: "Hello")
|
||||
Worker -> Arbiter -> Client: chunk (data: " ")
|
||||
Worker -> Arbiter -> Client: chunk (data: "World")
|
||||
Worker -> Arbiter -> Client: end
|
||||
```
|
||||
|
||||
### Error Handling in Streams
|
||||
|
||||
Errors during streaming are delivered as error messages:
|
||||
|
||||
```python
|
||||
def generate_view(request):
|
||||
client = get_dirty_client()
|
||||
|
||||
try:
|
||||
for chunk in client.stream("myapp.llm:LLMApp", "generate", prompt):
|
||||
yield chunk
|
||||
except DirtyError as e:
|
||||
# Error occurred mid-stream
|
||||
yield f"\n[Error: {e.message}]"
|
||||
```
|
||||
|
||||
### Best Practices for Streaming
|
||||
|
||||
1. **Use async generators for I/O-bound streaming** - e.g., API calls to external services
|
||||
2. **Use sync generators for CPU-bound streaming** - e.g., local model inference
|
||||
3. **Yield frequently** - Heartbeats are sent during streaming to keep workers alive
|
||||
4. **Keep chunks small** - Smaller chunks provide better perceived latency
|
||||
5. **Handle client disconnection** - Streams continue even if client disconnects; design accordingly
|
||||
|
||||
### Flask Example
|
||||
|
||||
```python
|
||||
from flask import Flask, Response
|
||||
from gunicorn.dirty import get_dirty_client
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
@app.route("/chat", methods=["POST"])
|
||||
def chat():
|
||||
prompt = request.json.get("prompt")
|
||||
client = get_dirty_client()
|
||||
|
||||
def stream():
|
||||
for token in client.stream("myapp.llm:LLMApp", "generate", prompt):
|
||||
yield f"data: {token}\n\n"
|
||||
|
||||
return Response(stream(), content_type="text/event-stream")
|
||||
```
|
||||
|
||||
### FastAPI Example
|
||||
|
||||
```python
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import StreamingResponse
|
||||
from gunicorn.dirty import get_dirty_client_async
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.post("/chat")
|
||||
async def chat(prompt: str):
|
||||
client = await get_dirty_client_async()
|
||||
|
||||
async def stream():
|
||||
async for token in client.stream_async("myapp.llm:LLMApp", "generate", prompt):
|
||||
yield f"data: {token}\n\n"
|
||||
|
||||
return StreamingResponse(stream(), media_type="text/event-stream")
|
||||
```
|
||||
|
||||
## Lifecycle Hooks
|
||||
|
||||
Dirty Arbiters provide hooks for customization:
|
||||
|
||||
@ -199,6 +199,7 @@ class DirtyArbiter:
|
||||
Handle a connection from an HTTP worker.
|
||||
|
||||
Routes requests to available dirty workers and returns responses.
|
||||
Supports both regular responses and streaming (chunk-based) responses.
|
||||
"""
|
||||
self.log.debug("New client connection from HTTP worker")
|
||||
|
||||
@ -209,11 +210,8 @@ class DirtyArbiter:
|
||||
except asyncio.IncompleteReadError:
|
||||
break
|
||||
|
||||
# Route request to a dirty worker
|
||||
response = await self.route_request(message)
|
||||
|
||||
# Send response back to HTTP worker
|
||||
await DirtyProtocol.write_message_async(writer, response)
|
||||
# Route request to a dirty worker - pass writer for streaming
|
||||
await self.route_request(message, writer)
|
||||
except Exception as e:
|
||||
self.log.error("Client connection error: %s", e)
|
||||
finally:
|
||||
@ -223,28 +221,31 @@ class DirtyArbiter:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def route_request(self, request):
|
||||
async def route_request(self, request, client_writer):
|
||||
"""
|
||||
Route a request to an available dirty worker via queue.
|
||||
|
||||
Each worker has a dedicated queue and consumer task. Requests are
|
||||
submitted to the queue and processed sequentially by the consumer.
|
||||
|
||||
For streaming responses, messages (chunks) are forwarded directly
|
||||
to the client_writer as they arrive from the worker.
|
||||
|
||||
Args:
|
||||
request: Request message dict
|
||||
|
||||
Returns:
|
||||
Response message dict
|
||||
client_writer: StreamWriter to send responses to client
|
||||
"""
|
||||
request_id = request.get("id", "unknown")
|
||||
|
||||
# Find an available worker
|
||||
worker_pid = await self._get_available_worker()
|
||||
if worker_pid is None:
|
||||
return make_error_response(
|
||||
response = make_error_response(
|
||||
request_id,
|
||||
DirtyError("No dirty workers available")
|
||||
)
|
||||
await DirtyProtocol.write_message_async(client_writer, response)
|
||||
return
|
||||
|
||||
# Get queue (start consumer if needed)
|
||||
if worker_pid not in self.worker_queues:
|
||||
@ -253,17 +254,18 @@ class DirtyArbiter:
|
||||
queue = self.worker_queues[worker_pid]
|
||||
future = asyncio.get_running_loop().create_future()
|
||||
|
||||
# Submit request to queue
|
||||
await queue.put((request, future))
|
||||
# Submit request to queue with client writer for streaming support
|
||||
await queue.put((request, client_writer, future))
|
||||
|
||||
# Wait for response
|
||||
# Wait for completion (streaming messages forwarded by consumer)
|
||||
try:
|
||||
return await future
|
||||
await future
|
||||
except Exception as e:
|
||||
return make_error_response(
|
||||
response = make_error_response(
|
||||
request_id,
|
||||
DirtyWorkerError(f"Request failed: {e}", worker_id=worker_pid)
|
||||
)
|
||||
await DirtyProtocol.write_message_async(client_writer, response)
|
||||
|
||||
async def _start_worker_consumer(self, worker_pid):
|
||||
"""Start a consumer task for a worker's request queue."""
|
||||
@ -273,11 +275,13 @@ class DirtyArbiter:
|
||||
async def consumer():
|
||||
while self.alive:
|
||||
try:
|
||||
request, future = await queue.get()
|
||||
request, client_writer, future = await queue.get()
|
||||
try:
|
||||
response = await self._execute_on_worker(worker_pid, request)
|
||||
await self._execute_on_worker(
|
||||
worker_pid, request, client_writer
|
||||
)
|
||||
if not future.done():
|
||||
future.set_result(response)
|
||||
future.set_result(None)
|
||||
except Exception as e:
|
||||
if not future.done():
|
||||
future.set_exception(e)
|
||||
@ -289,32 +293,65 @@ class DirtyArbiter:
|
||||
task = asyncio.create_task(consumer())
|
||||
self.worker_consumers[worker_pid] = task
|
||||
|
||||
async def _execute_on_worker(self, worker_pid, request):
|
||||
"""Execute request on a specific worker (called by consumer)."""
|
||||
async def _execute_on_worker(self, worker_pid, request, client_writer):
|
||||
"""
|
||||
Execute request on a specific worker (called by consumer).
|
||||
|
||||
Handles both regular responses and streaming (chunk-based) responses.
|
||||
For streaming, chunk and end messages are forwarded directly to the
|
||||
client_writer as they arrive from the worker.
|
||||
"""
|
||||
request_id = request.get("id", "unknown")
|
||||
|
||||
try:
|
||||
reader, writer = await self._get_worker_connection(worker_pid)
|
||||
await DirtyProtocol.write_message_async(writer, request)
|
||||
|
||||
response = await asyncio.wait_for(
|
||||
DirtyProtocol.read_message_async(reader),
|
||||
timeout=self.cfg.dirty_timeout
|
||||
)
|
||||
return response
|
||||
except asyncio.TimeoutError:
|
||||
return make_error_response(
|
||||
request_id,
|
||||
DirtyTimeoutError("Worker timeout", self.cfg.dirty_timeout)
|
||||
)
|
||||
# Read messages until we get a response, end, or error
|
||||
while True:
|
||||
try:
|
||||
message = await asyncio.wait_for(
|
||||
DirtyProtocol.read_message_async(reader),
|
||||
timeout=self.cfg.dirty_timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
response = make_error_response(
|
||||
request_id,
|
||||
DirtyTimeoutError("Worker timeout", self.cfg.dirty_timeout)
|
||||
)
|
||||
await DirtyProtocol.write_message_async(client_writer, response)
|
||||
return
|
||||
|
||||
msg_type = message.get("type")
|
||||
|
||||
# Forward chunk messages to client
|
||||
if msg_type == DirtyProtocol.MSG_TYPE_CHUNK:
|
||||
await DirtyProtocol.write_message_async(client_writer, message)
|
||||
continue
|
||||
|
||||
# Forward end message and complete
|
||||
if msg_type == DirtyProtocol.MSG_TYPE_END:
|
||||
await DirtyProtocol.write_message_async(client_writer, message)
|
||||
return
|
||||
|
||||
# Forward response or error and complete
|
||||
if msg_type in (DirtyProtocol.MSG_TYPE_RESPONSE,
|
||||
DirtyProtocol.MSG_TYPE_ERROR):
|
||||
await DirtyProtocol.write_message_async(client_writer, message)
|
||||
return
|
||||
|
||||
# Unknown message type - log and continue
|
||||
self.log.warning("Unknown message type from worker: %s", msg_type)
|
||||
|
||||
except Exception as e:
|
||||
self.log.error("Error executing on worker %s: %s", worker_pid, e)
|
||||
self._close_worker_connection(worker_pid)
|
||||
return make_error_response(
|
||||
response = make_error_response(
|
||||
request_id,
|
||||
DirtyWorkerError(f"Worker communication failed: {e}",
|
||||
worker_id=worker_pid)
|
||||
)
|
||||
await DirtyProtocol.write_message_async(client_writer, response)
|
||||
|
||||
async def _get_available_worker(self):
|
||||
"""
|
||||
|
||||
@ -14,6 +14,7 @@ import contextvars
|
||||
import os
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from .errors import (
|
||||
@ -134,6 +135,34 @@ class DirtyClient:
|
||||
raise
|
||||
raise DirtyConnectionError(f"Communication error: {e}") from e
|
||||
|
||||
def stream(self, app_path, action, *args, **kwargs):
|
||||
"""
|
||||
Stream results from a dirty app action (sync).
|
||||
|
||||
This method returns an iterator that yields chunks from a streaming
|
||||
response. Use this for actions that return generators.
|
||||
|
||||
Args:
|
||||
app_path: Import path of the dirty app (e.g., 'myapp.ml:MLApp')
|
||||
action: Action to call on the app
|
||||
*args: Positional arguments
|
||||
**kwargs: Keyword arguments
|
||||
|
||||
Yields:
|
||||
Chunks of data from the streaming response
|
||||
|
||||
Raises:
|
||||
DirtyConnectionError: If connection fails
|
||||
DirtyTimeoutError: If operation times out
|
||||
DirtyError: If execution fails
|
||||
|
||||
Example::
|
||||
|
||||
for chunk in client.stream("myapp.llm:LLMApp", "generate", prompt):
|
||||
print(chunk, end="", flush=True)
|
||||
"""
|
||||
return DirtyStreamIterator(self, app_path, action, args, kwargs)
|
||||
|
||||
def _handle_response(self, response):
|
||||
"""Handle response message, extracting result or raising error."""
|
||||
msg_type = response.get("type")
|
||||
@ -247,6 +276,34 @@ class DirtyClient:
|
||||
raise
|
||||
raise DirtyConnectionError(f"Communication error: {e}") from e
|
||||
|
||||
def stream_async(self, app_path, action, *args, **kwargs):
|
||||
"""
|
||||
Stream results from a dirty app action (async).
|
||||
|
||||
This method returns an async iterator that yields chunks from a
|
||||
streaming response. Use this for actions that return generators.
|
||||
|
||||
Args:
|
||||
app_path: Import path of the dirty app (e.g., 'myapp.ml:MLApp')
|
||||
action: Action to call on the app
|
||||
*args: Positional arguments
|
||||
**kwargs: Keyword arguments
|
||||
|
||||
Yields:
|
||||
Chunks of data from the streaming response
|
||||
|
||||
Raises:
|
||||
DirtyConnectionError: If connection fails
|
||||
DirtyTimeoutError: If operation times out
|
||||
DirtyError: If execution fails
|
||||
|
||||
Example::
|
||||
|
||||
async for chunk in client.stream_async("myapp.llm:LLMApp", "generate", prompt):
|
||||
await response.write(chunk)
|
||||
"""
|
||||
return DirtyAsyncStreamIterator(self, app_path, action, args, kwargs)
|
||||
|
||||
async def _close_async(self):
|
||||
"""Close the async connection."""
|
||||
if self._writer is not None:
|
||||
@ -281,6 +338,308 @@ class DirtyClient:
|
||||
await self.close_async()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Stream Iterator classes
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class DirtyStreamIterator:
|
||||
"""
|
||||
Iterator for streaming responses from dirty workers (sync).
|
||||
|
||||
This class is returned by `DirtyClient.stream()` and yields chunks
|
||||
from a streaming response until the end message is received.
|
||||
|
||||
Uses a deadline-based timeout approach:
|
||||
- Total stream timeout: limits entire stream duration
|
||||
- Idle timeout: limits gap between chunks (defaults to total timeout)
|
||||
"""
|
||||
|
||||
# Default idle timeout between chunks (seconds)
|
||||
DEFAULT_IDLE_TIMEOUT = 30.0
|
||||
|
||||
# Threshold for applying per-read timeout (seconds)
|
||||
# When remaining time is above this, use a larger timeout for efficiency
|
||||
_TIMEOUT_THRESHOLD = 5.0
|
||||
|
||||
def __init__(self, client, app_path, action, args, kwargs,
|
||||
idle_timeout=None):
|
||||
self.client = client
|
||||
self.app_path = app_path
|
||||
self.action = action
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self._started = False
|
||||
self._exhausted = False
|
||||
self._request_id = None
|
||||
self._deadline = None
|
||||
self._last_chunk_time = None
|
||||
# Idle timeout: max time between chunks
|
||||
self._idle_timeout = (
|
||||
idle_timeout if idle_timeout is not None
|
||||
else min(self.DEFAULT_IDLE_TIMEOUT, client.timeout)
|
||||
)
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
if self._exhausted:
|
||||
raise StopIteration
|
||||
|
||||
if not self._started:
|
||||
self._start_request()
|
||||
self._started = True
|
||||
|
||||
return self._read_next_chunk()
|
||||
|
||||
def _start_request(self):
|
||||
"""Send the initial request to the arbiter."""
|
||||
with self.client._lock:
|
||||
if self.client._sock is None:
|
||||
self.client.connect()
|
||||
|
||||
# Set deadline for entire stream
|
||||
now = time.monotonic()
|
||||
self._deadline = now + self.client.timeout
|
||||
self._last_chunk_time = now
|
||||
|
||||
self._request_id = str(uuid.uuid4())
|
||||
request = make_request(
|
||||
self._request_id,
|
||||
self.app_path,
|
||||
self.action,
|
||||
args=self.args,
|
||||
kwargs=self.kwargs,
|
||||
)
|
||||
DirtyProtocol.write_message(self.client._sock, request)
|
||||
|
||||
def _read_next_chunk(self):
|
||||
"""Read the next message from the stream."""
|
||||
with self.client._lock:
|
||||
# Check total stream deadline
|
||||
now = time.monotonic()
|
||||
if now >= self._deadline:
|
||||
self._exhausted = True
|
||||
raise DirtyTimeoutError(
|
||||
"Stream exceeded total timeout",
|
||||
timeout=self.client.timeout
|
||||
)
|
||||
|
||||
remaining = self._deadline - now
|
||||
|
||||
# Set socket timeout based on remaining time
|
||||
# Fast path: use larger timeout when plenty of time remains
|
||||
if remaining > self._TIMEOUT_THRESHOLD:
|
||||
read_timeout = self._TIMEOUT_THRESHOLD
|
||||
else:
|
||||
read_timeout = min(remaining, self._idle_timeout)
|
||||
|
||||
try:
|
||||
self.client._sock.settimeout(read_timeout)
|
||||
response = DirtyProtocol.read_message(self.client._sock)
|
||||
except socket.timeout:
|
||||
# Check which timeout was hit
|
||||
now = time.monotonic()
|
||||
if now >= self._deadline:
|
||||
self._exhausted = True
|
||||
raise DirtyTimeoutError(
|
||||
"Stream exceeded total timeout",
|
||||
timeout=self.client.timeout
|
||||
)
|
||||
else:
|
||||
idle_duration = now - self._last_chunk_time
|
||||
self._exhausted = True
|
||||
raise DirtyTimeoutError(
|
||||
f"Timeout waiting for next chunk (idle {idle_duration:.1f}s)",
|
||||
timeout=self._idle_timeout
|
||||
)
|
||||
except Exception as e:
|
||||
self._exhausted = True
|
||||
self.client._close_socket()
|
||||
raise DirtyConnectionError(f"Communication error: {e}") from e
|
||||
|
||||
# Update last chunk time for idle tracking
|
||||
self._last_chunk_time = time.monotonic()
|
||||
|
||||
msg_type = response.get("type")
|
||||
|
||||
# Chunk message - return the data
|
||||
if msg_type == DirtyProtocol.MSG_TYPE_CHUNK:
|
||||
return response.get("data")
|
||||
|
||||
# End message - stop iteration
|
||||
if msg_type == DirtyProtocol.MSG_TYPE_END:
|
||||
self._exhausted = True
|
||||
raise StopIteration
|
||||
|
||||
# Error message - raise exception
|
||||
if msg_type == DirtyProtocol.MSG_TYPE_ERROR:
|
||||
self._exhausted = True
|
||||
error_info = response.get("error", {})
|
||||
raise DirtyError.from_dict(error_info)
|
||||
|
||||
# Regular response - shouldn't happen for streaming, but handle it
|
||||
if msg_type == DirtyProtocol.MSG_TYPE_RESPONSE:
|
||||
self._exhausted = True
|
||||
# Return the result as the only chunk then stop
|
||||
raise StopIteration
|
||||
|
||||
# Unknown type
|
||||
self._exhausted = True
|
||||
raise DirtyError(f"Unknown message type: {msg_type}")
|
||||
|
||||
|
||||
class DirtyAsyncStreamIterator:
|
||||
"""
|
||||
Async iterator for streaming responses from dirty workers.
|
||||
|
||||
This class is returned by `DirtyClient.stream_async()` and yields chunks
|
||||
from a streaming response until the end message is received.
|
||||
|
||||
Uses a deadline-based timeout approach for efficiency:
|
||||
- Total stream timeout: limits entire stream duration
|
||||
- Idle timeout: limits gap between chunks (defaults to total timeout)
|
||||
|
||||
This avoids the overhead of asyncio.wait_for() on every chunk read.
|
||||
"""
|
||||
|
||||
# Default idle timeout between chunks (seconds)
|
||||
DEFAULT_IDLE_TIMEOUT = 30.0
|
||||
|
||||
def __init__(self, client, app_path, action, args, kwargs,
|
||||
idle_timeout=None):
|
||||
self.client = client
|
||||
self.app_path = app_path
|
||||
self.action = action
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self._started = False
|
||||
self._exhausted = False
|
||||
self._request_id = None
|
||||
self._deadline = None
|
||||
self._last_chunk_time = None
|
||||
# Idle timeout: max time between chunks
|
||||
self._idle_timeout = (
|
||||
idle_timeout if idle_timeout is not None
|
||||
else min(self.DEFAULT_IDLE_TIMEOUT, client.timeout)
|
||||
)
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
if self._exhausted:
|
||||
raise StopAsyncIteration
|
||||
|
||||
if not self._started:
|
||||
await self._start_request()
|
||||
self._started = True
|
||||
|
||||
return await self._read_next_chunk()
|
||||
|
||||
async def _start_request(self):
|
||||
"""Send the initial request to the arbiter."""
|
||||
if self.client._writer is None:
|
||||
await self.client.connect_async()
|
||||
|
||||
# Set deadline for entire stream
|
||||
now = time.monotonic()
|
||||
self._deadline = now + self.client.timeout
|
||||
self._last_chunk_time = now
|
||||
|
||||
self._request_id = str(uuid.uuid4())
|
||||
request = make_request(
|
||||
self._request_id,
|
||||
self.app_path,
|
||||
self.action,
|
||||
args=self.args,
|
||||
kwargs=self.kwargs,
|
||||
)
|
||||
await DirtyProtocol.write_message_async(self.client._writer, request)
|
||||
|
||||
# Threshold for applying timeout wrapper (seconds)
|
||||
# When remaining time is above this, skip timeout for performance
|
||||
_TIMEOUT_THRESHOLD = 5.0
|
||||
|
||||
async def _read_next_chunk(self):
|
||||
"""Read the next message from the stream."""
|
||||
# Calculate remaining time until deadline
|
||||
now = time.monotonic()
|
||||
|
||||
# Check total stream deadline
|
||||
if now >= self._deadline:
|
||||
self._exhausted = True
|
||||
raise DirtyTimeoutError(
|
||||
"Stream exceeded total timeout",
|
||||
timeout=self.client.timeout
|
||||
)
|
||||
|
||||
remaining = self._deadline - now
|
||||
|
||||
try:
|
||||
# Fast path: skip timeout wrapper when we have plenty of time
|
||||
# This avoids asyncio.wait_for() overhead for most chunks
|
||||
if remaining > self._TIMEOUT_THRESHOLD:
|
||||
response = await DirtyProtocol.read_message_async(
|
||||
self.client._reader
|
||||
)
|
||||
else:
|
||||
# Near deadline: apply timeout protection
|
||||
read_timeout = min(remaining, self._idle_timeout)
|
||||
response = await asyncio.wait_for(
|
||||
DirtyProtocol.read_message_async(self.client._reader),
|
||||
timeout=read_timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
self._exhausted = True
|
||||
now = time.monotonic()
|
||||
if now >= self._deadline:
|
||||
raise DirtyTimeoutError(
|
||||
"Stream exceeded total timeout",
|
||||
timeout=self.client.timeout
|
||||
)
|
||||
else:
|
||||
idle_duration = now - self._last_chunk_time
|
||||
raise DirtyTimeoutError(
|
||||
f"Timeout waiting for next chunk (idle {idle_duration:.1f}s)",
|
||||
timeout=self._idle_timeout
|
||||
)
|
||||
except Exception as e:
|
||||
self._exhausted = True
|
||||
await self.client._close_async()
|
||||
raise DirtyConnectionError(f"Communication error: {e}") from e
|
||||
|
||||
# Update last chunk time for idle tracking
|
||||
self._last_chunk_time = time.monotonic()
|
||||
|
||||
msg_type = response.get("type")
|
||||
|
||||
# Chunk message - return the data
|
||||
if msg_type == DirtyProtocol.MSG_TYPE_CHUNK:
|
||||
return response.get("data")
|
||||
|
||||
# End message - stop iteration
|
||||
if msg_type == DirtyProtocol.MSG_TYPE_END:
|
||||
self._exhausted = True
|
||||
raise StopAsyncIteration
|
||||
|
||||
# Error message - raise exception
|
||||
if msg_type == DirtyProtocol.MSG_TYPE_ERROR:
|
||||
self._exhausted = True
|
||||
error_info = response.get("error", {})
|
||||
raise DirtyError.from_dict(error_info)
|
||||
|
||||
# Regular response - shouldn't happen for streaming
|
||||
if msg_type == DirtyProtocol.MSG_TYPE_RESPONSE:
|
||||
self._exhausted = True
|
||||
raise StopAsyncIteration
|
||||
|
||||
# Unknown type
|
||||
self._exhausted = True
|
||||
raise DirtyError(f"Unknown message type: {msg_type}")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Thread-local and context-local client management
|
||||
# =============================================================================
|
||||
|
||||
@ -312,3 +312,37 @@ def make_error_response(request_id: str, error) -> dict:
|
||||
"id": request_id,
|
||||
"error": error_dict,
|
||||
}
|
||||
|
||||
|
||||
def make_chunk_message(request_id: str, data) -> dict:
|
||||
"""
|
||||
Build a chunk message for streaming responses.
|
||||
|
||||
Args:
|
||||
request_id: Request identifier this chunk belongs to
|
||||
data: Chunk data (must be JSON-serializable)
|
||||
|
||||
Returns:
|
||||
dict: Chunk message
|
||||
"""
|
||||
return {
|
||||
"type": DirtyProtocol.MSG_TYPE_CHUNK,
|
||||
"id": request_id,
|
||||
"data": data,
|
||||
}
|
||||
|
||||
|
||||
def make_end_message(request_id: str) -> dict:
|
||||
"""
|
||||
Build an end-of-stream message.
|
||||
|
||||
Args:
|
||||
request_id: Request identifier this ends
|
||||
|
||||
Returns:
|
||||
dict: End message
|
||||
"""
|
||||
return {
|
||||
"type": DirtyProtocol.MSG_TYPE_END,
|
||||
"id": request_id,
|
||||
}
|
||||
|
||||
@ -69,6 +69,7 @@ operation will continue until the worker is killed by the arbiter.
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import inspect
|
||||
import os
|
||||
import signal
|
||||
import traceback
|
||||
@ -88,6 +89,8 @@ from .protocol import (
|
||||
DirtyProtocol,
|
||||
make_response,
|
||||
make_error_response,
|
||||
make_chunk_message,
|
||||
make_end_message,
|
||||
)
|
||||
|
||||
|
||||
@ -296,11 +299,8 @@ class DirtyWorker:
|
||||
# Connection closed
|
||||
break
|
||||
|
||||
# Handle the request
|
||||
response = await self.handle_request(message)
|
||||
|
||||
# Send response
|
||||
await DirtyProtocol.write_message_async(writer, response)
|
||||
# Handle the request - pass writer for streaming support
|
||||
await self.handle_request(message, writer)
|
||||
except Exception as e:
|
||||
self.log.error("Connection error: %s", e)
|
||||
finally:
|
||||
@ -310,24 +310,28 @@ class DirtyWorker:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def handle_request(self, message):
|
||||
async def handle_request(self, message, writer):
|
||||
"""
|
||||
Handle a single request message.
|
||||
|
||||
Supports both regular (non-streaming) and streaming responses.
|
||||
For streaming, detects if the result is a generator and sends
|
||||
chunk messages followed by an end message.
|
||||
|
||||
Args:
|
||||
message: Request dict from protocol
|
||||
|
||||
Returns:
|
||||
Response dict to send back
|
||||
writer: StreamWriter for sending responses
|
||||
"""
|
||||
request_id = message.get("id", str(uuid.uuid4()))
|
||||
msg_type = message.get("type")
|
||||
|
||||
if msg_type != DirtyProtocol.MSG_TYPE_REQUEST:
|
||||
return make_error_response(
|
||||
response = make_error_response(
|
||||
request_id,
|
||||
DirtyWorkerError(f"Unknown message type: {msg_type}")
|
||||
)
|
||||
await DirtyProtocol.write_message_async(writer, response)
|
||||
return
|
||||
|
||||
app_path = message.get("app_path")
|
||||
action = message.get("action")
|
||||
@ -339,16 +343,107 @@ class DirtyWorker:
|
||||
|
||||
try:
|
||||
result = await self.execute(app_path, action, args, kwargs)
|
||||
return make_response(request_id, result)
|
||||
|
||||
# Check if result is a generator (streaming)
|
||||
if inspect.isgenerator(result):
|
||||
await self._stream_sync_generator(request_id, result, writer)
|
||||
elif inspect.isasyncgen(result):
|
||||
await self._stream_async_generator(request_id, result, writer)
|
||||
else:
|
||||
# Regular non-streaming response
|
||||
response = make_response(request_id, result)
|
||||
await DirtyProtocol.write_message_async(writer, response)
|
||||
except Exception as e:
|
||||
tb = traceback.format_exc()
|
||||
self.log.error("Error executing %s.%s: %s\n%s",
|
||||
app_path, action, e, tb)
|
||||
return make_error_response(
|
||||
response = make_error_response(
|
||||
request_id,
|
||||
DirtyAppError(str(e), app_path=app_path, action=action,
|
||||
traceback=tb)
|
||||
)
|
||||
await DirtyProtocol.write_message_async(writer, response)
|
||||
|
||||
async def _stream_sync_generator(self, request_id, gen, writer):
|
||||
"""
|
||||
Stream chunks from a synchronous generator.
|
||||
|
||||
Args:
|
||||
request_id: Request ID for the messages
|
||||
gen: Sync generator to iterate
|
||||
writer: StreamWriter for sending messages
|
||||
"""
|
||||
# Sentinel value to detect end of generator
|
||||
# (StopIteration cannot be raised into a Future in Python 3.7+)
|
||||
_EXHAUSTED = object()
|
||||
|
||||
def _get_next():
|
||||
try:
|
||||
return next(gen)
|
||||
except StopIteration:
|
||||
return _EXHAUSTED
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
while True:
|
||||
# Run next() in executor to avoid blocking event loop
|
||||
chunk = await loop.run_in_executor(self._executor, _get_next)
|
||||
if chunk is _EXHAUSTED:
|
||||
break
|
||||
# Send chunk message
|
||||
await DirtyProtocol.write_message_async(
|
||||
writer, make_chunk_message(request_id, chunk)
|
||||
)
|
||||
# Update heartbeat during long streams
|
||||
self.notify()
|
||||
# Send end message
|
||||
await DirtyProtocol.write_message_async(
|
||||
writer, make_end_message(request_id)
|
||||
)
|
||||
except Exception as e:
|
||||
# Error during streaming - send error message
|
||||
tb = traceback.format_exc()
|
||||
self.log.error("Error during streaming: %s\n%s", e, tb)
|
||||
response = make_error_response(
|
||||
request_id,
|
||||
DirtyAppError(str(e), traceback=tb)
|
||||
)
|
||||
await DirtyProtocol.write_message_async(writer, response)
|
||||
finally:
|
||||
gen.close()
|
||||
|
||||
async def _stream_async_generator(self, request_id, gen, writer):
|
||||
"""
|
||||
Stream chunks from an asynchronous generator.
|
||||
|
||||
Args:
|
||||
request_id: Request ID for the messages
|
||||
gen: Async generator to iterate
|
||||
writer: StreamWriter for sending messages
|
||||
"""
|
||||
try:
|
||||
async for chunk in gen:
|
||||
# Send chunk message
|
||||
await DirtyProtocol.write_message_async(
|
||||
writer, make_chunk_message(request_id, chunk)
|
||||
)
|
||||
# Update heartbeat during long streams
|
||||
self.notify()
|
||||
# Send end message
|
||||
await DirtyProtocol.write_message_async(
|
||||
writer, make_end_message(request_id)
|
||||
)
|
||||
except Exception as e:
|
||||
# Error during streaming - send error message
|
||||
tb = traceback.format_exc()
|
||||
self.log.error("Error during streaming: %s\n%s", e, tb)
|
||||
response = make_error_response(
|
||||
request_id,
|
||||
DirtyAppError(str(e), traceback=tb)
|
||||
)
|
||||
await DirtyProtocol.write_message_async(writer, response)
|
||||
finally:
|
||||
await gen.aclose()
|
||||
|
||||
async def execute(self, app_path, action, args, kwargs):
|
||||
"""
|
||||
|
||||
5
tests/dirty/__init__.py
Normal file
5
tests/dirty/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""Tests for dirty worker streaming functionality."""
|
||||
319
tests/dirty/test_arbiter_streaming.py
Normal file
319
tests/dirty/test_arbiter_streaming.py
Normal file
@ -0,0 +1,319 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""Tests for dirty arbiter streaming functionality."""
|
||||
|
||||
import asyncio
|
||||
import struct
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from gunicorn.dirty.protocol import (
|
||||
DirtyProtocol,
|
||||
make_request,
|
||||
make_response,
|
||||
make_chunk_message,
|
||||
make_end_message,
|
||||
make_error_response,
|
||||
)
|
||||
from gunicorn.dirty.arbiter import DirtyArbiter
|
||||
from gunicorn.dirty.errors import DirtyError
|
||||
|
||||
|
||||
class MockStreamWriter:
|
||||
"""Mock StreamWriter that captures written messages."""
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
self._buffer = b""
|
||||
self.closed = False
|
||||
|
||||
def write(self, data):
|
||||
self._buffer += data
|
||||
|
||||
async def drain(self):
|
||||
while len(self._buffer) >= DirtyProtocol.HEADER_SIZE:
|
||||
length = struct.unpack(
|
||||
DirtyProtocol.HEADER_FORMAT,
|
||||
self._buffer[:DirtyProtocol.HEADER_SIZE]
|
||||
)[0]
|
||||
total_size = DirtyProtocol.HEADER_SIZE + length
|
||||
if len(self._buffer) >= total_size:
|
||||
msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size]
|
||||
self._buffer = self._buffer[total_size:]
|
||||
self.messages.append(DirtyProtocol.decode(msg_data))
|
||||
else:
|
||||
break
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
async def wait_closed(self):
|
||||
pass
|
||||
|
||||
def get_extra_info(self, name):
|
||||
return None
|
||||
|
||||
|
||||
class MockStreamReader:
|
||||
"""Mock StreamReader that yields predefined messages."""
|
||||
|
||||
def __init__(self, messages):
|
||||
self._data = b''
|
||||
for msg in messages:
|
||||
self._data += DirtyProtocol.encode(msg)
|
||||
self._pos = 0
|
||||
|
||||
async def readexactly(self, n):
|
||||
if self._pos + n > len(self._data):
|
||||
raise asyncio.IncompleteReadError(self._data[self._pos:], n)
|
||||
result = self._data[self._pos:self._pos + n]
|
||||
self._pos += n
|
||||
return result
|
||||
|
||||
|
||||
def create_arbiter():
|
||||
"""Create a test arbiter with mocked components."""
|
||||
cfg = mock.Mock()
|
||||
cfg.dirty_timeout = 30
|
||||
cfg.dirty_workers = 1
|
||||
cfg.dirty_apps = []
|
||||
cfg.dirty_graceful_timeout = 30
|
||||
cfg.on_dirty_starting = mock.Mock()
|
||||
cfg.dirty_post_fork = mock.Mock()
|
||||
cfg.dirty_worker_exit = mock.Mock()
|
||||
|
||||
log = mock.Mock()
|
||||
|
||||
with mock.patch('tempfile.mkdtemp', return_value='/tmp/test-dirty'):
|
||||
arbiter = DirtyArbiter(cfg, log)
|
||||
|
||||
arbiter.alive = True
|
||||
arbiter.workers = {1234: mock.Mock()} # Fake worker
|
||||
arbiter.worker_sockets = {1234: '/tmp/worker.sock'}
|
||||
|
||||
return arbiter
|
||||
|
||||
|
||||
class TestArbiterStreamingForwarding:
|
||||
"""Tests for arbiter streaming message forwarding."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forwards_chunk_messages(self):
|
||||
"""Test that arbiter forwards chunk messages to client."""
|
||||
arbiter = create_arbiter()
|
||||
client_writer = MockStreamWriter()
|
||||
|
||||
# Mock worker connection that returns chunks
|
||||
chunk1 = make_chunk_message("req-123", "Hello")
|
||||
chunk2 = make_chunk_message("req-123", " World")
|
||||
end = make_end_message("req-123")
|
||||
|
||||
mock_reader = MockStreamReader([chunk1, chunk2, end])
|
||||
|
||||
async def mock_get_connection(pid):
|
||||
return mock_reader, MockStreamWriter()
|
||||
|
||||
arbiter._get_worker_connection = mock_get_connection
|
||||
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
await arbiter._execute_on_worker(1234, request, client_writer)
|
||||
|
||||
# Should have forwarded all messages
|
||||
assert len(client_writer.messages) == 3
|
||||
assert client_writer.messages[0]["type"] == "chunk"
|
||||
assert client_writer.messages[0]["data"] == "Hello"
|
||||
assert client_writer.messages[1]["type"] == "chunk"
|
||||
assert client_writer.messages[1]["data"] == " World"
|
||||
assert client_writer.messages[2]["type"] == "end"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forwards_regular_response(self):
|
||||
"""Test that arbiter forwards regular response to client."""
|
||||
arbiter = create_arbiter()
|
||||
client_writer = MockStreamWriter()
|
||||
|
||||
response = make_response("req-123", {"result": 42})
|
||||
mock_reader = MockStreamReader([response])
|
||||
|
||||
async def mock_get_connection(pid):
|
||||
return mock_reader, MockStreamWriter()
|
||||
|
||||
arbiter._get_worker_connection = mock_get_connection
|
||||
|
||||
request = make_request("req-123", "test:App", "compute")
|
||||
await arbiter._execute_on_worker(1234, request, client_writer)
|
||||
|
||||
assert len(client_writer.messages) == 1
|
||||
assert client_writer.messages[0]["type"] == "response"
|
||||
assert client_writer.messages[0]["result"] == {"result": 42}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forwards_error_mid_stream(self):
|
||||
"""Test that arbiter forwards error during streaming."""
|
||||
arbiter = create_arbiter()
|
||||
client_writer = MockStreamWriter()
|
||||
|
||||
chunk = make_chunk_message("req-123", "First")
|
||||
error = make_error_response("req-123", DirtyError("Something broke"))
|
||||
|
||||
mock_reader = MockStreamReader([chunk, error])
|
||||
|
||||
async def mock_get_connection(pid):
|
||||
return mock_reader, MockStreamWriter()
|
||||
|
||||
arbiter._get_worker_connection = mock_get_connection
|
||||
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
await arbiter._execute_on_worker(1234, request, client_writer)
|
||||
|
||||
assert len(client_writer.messages) == 2
|
||||
assert client_writer.messages[0]["type"] == "chunk"
|
||||
assert client_writer.messages[1]["type"] == "error"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_during_streaming(self):
|
||||
"""Test that timeout during streaming sends error."""
|
||||
arbiter = create_arbiter()
|
||||
arbiter.cfg.dirty_timeout = 0.01 # Very short timeout
|
||||
client_writer = MockStreamWriter()
|
||||
|
||||
# Reader that times out
|
||||
class TimeoutReader:
|
||||
async def readexactly(self, n):
|
||||
await asyncio.sleep(1) # Longer than timeout
|
||||
|
||||
async def mock_get_connection(pid):
|
||||
return TimeoutReader(), MockStreamWriter()
|
||||
|
||||
arbiter._get_worker_connection = mock_get_connection
|
||||
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
await arbiter._execute_on_worker(1234, request, client_writer)
|
||||
|
||||
assert len(client_writer.messages) == 1
|
||||
assert client_writer.messages[0]["type"] == "error"
|
||||
assert "timeout" in client_writer.messages[0]["error"]["message"].lower()
|
||||
|
||||
|
||||
class TestArbiterRouteRequestStreaming:
|
||||
"""Tests for route_request with streaming support."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_route_request_no_workers(self):
|
||||
"""Test route_request when no workers available."""
|
||||
arbiter = create_arbiter()
|
||||
arbiter.workers = {} # No workers
|
||||
client_writer = MockStreamWriter()
|
||||
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
await arbiter.route_request(request, client_writer)
|
||||
|
||||
assert len(client_writer.messages) == 1
|
||||
assert client_writer.messages[0]["type"] == "error"
|
||||
assert "No dirty workers" in client_writer.messages[0]["error"]["message"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_route_request_starts_consumer(self):
|
||||
"""Test that route_request starts consumer if needed."""
|
||||
arbiter = create_arbiter()
|
||||
|
||||
# Mock _execute_on_worker to complete immediately
|
||||
async def mock_execute(pid, request, client_writer):
|
||||
response = make_response("req-123", "result")
|
||||
await DirtyProtocol.write_message_async(client_writer, response)
|
||||
|
||||
arbiter._execute_on_worker = mock_execute
|
||||
|
||||
client_writer = MockStreamWriter()
|
||||
request = make_request("req-123", "test:App", "compute")
|
||||
|
||||
# Worker queue should be created
|
||||
assert 1234 not in arbiter.worker_queues
|
||||
|
||||
await arbiter.route_request(request, client_writer)
|
||||
|
||||
# Consumer should have been started
|
||||
assert 1234 in arbiter.worker_queues
|
||||
assert 1234 in arbiter.worker_consumers
|
||||
|
||||
# Clean up
|
||||
arbiter.worker_consumers[1234].cancel()
|
||||
|
||||
|
||||
class TestArbiterStreamingManyChunks:
|
||||
"""Tests for streaming with many chunks."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forwards_many_chunks(self):
|
||||
"""Test that arbiter forwards many chunks correctly."""
|
||||
arbiter = create_arbiter()
|
||||
client_writer = MockStreamWriter()
|
||||
|
||||
# Generate 50 chunks + end
|
||||
messages = []
|
||||
for i in range(50):
|
||||
messages.append(make_chunk_message("req-123", f"chunk-{i}"))
|
||||
messages.append(make_end_message("req-123"))
|
||||
|
||||
mock_reader = MockStreamReader(messages)
|
||||
|
||||
async def mock_get_connection(pid):
|
||||
return mock_reader, MockStreamWriter()
|
||||
|
||||
arbiter._get_worker_connection = mock_get_connection
|
||||
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
await arbiter._execute_on_worker(1234, request, client_writer)
|
||||
|
||||
assert len(client_writer.messages) == 51
|
||||
assert client_writer.messages[0]["data"] == "chunk-0"
|
||||
assert client_writer.messages[49]["data"] == "chunk-49"
|
||||
assert client_writer.messages[50]["type"] == "end"
|
||||
|
||||
|
||||
class TestArbiterBackwardCompatibility:
|
||||
"""Tests for backward compatibility with non-streaming."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_regular_response(self):
|
||||
"""Test that regular (non-streaming) responses still work."""
|
||||
arbiter = create_arbiter()
|
||||
client_writer = MockStreamWriter()
|
||||
|
||||
response = make_response("req-123", [1, 2, 3, 4, 5])
|
||||
mock_reader = MockStreamReader([response])
|
||||
|
||||
async def mock_get_connection(pid):
|
||||
return mock_reader, MockStreamWriter()
|
||||
|
||||
arbiter._get_worker_connection = mock_get_connection
|
||||
|
||||
request = make_request("req-123", "test:App", "get_list")
|
||||
await arbiter._execute_on_worker(1234, request, client_writer)
|
||||
|
||||
assert len(client_writer.messages) == 1
|
||||
assert client_writer.messages[0]["type"] == "response"
|
||||
assert client_writer.messages[0]["result"] == [1, 2, 3, 4, 5]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_error_response(self):
|
||||
"""Test that error responses still work."""
|
||||
arbiter = create_arbiter()
|
||||
client_writer = MockStreamWriter()
|
||||
|
||||
error = make_error_response("req-123", DirtyError("Something failed"))
|
||||
mock_reader = MockStreamReader([error])
|
||||
|
||||
async def mock_get_connection(pid):
|
||||
return mock_reader, MockStreamWriter()
|
||||
|
||||
arbiter._get_worker_connection = mock_get_connection
|
||||
|
||||
request = make_request("req-123", "test:App", "fail")
|
||||
await arbiter._execute_on_worker(1234, request, client_writer)
|
||||
|
||||
assert len(client_writer.messages) == 1
|
||||
assert client_writer.messages[0]["type"] == "error"
|
||||
236
tests/dirty/test_client_streaming.py
Normal file
236
tests/dirty/test_client_streaming.py
Normal file
@ -0,0 +1,236 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""Tests for dirty client sync streaming functionality."""
|
||||
|
||||
import socket
|
||||
import struct
|
||||
import pytest
|
||||
from unittest import mock
|
||||
|
||||
from gunicorn.dirty.protocol import (
|
||||
DirtyProtocol,
|
||||
make_chunk_message,
|
||||
make_end_message,
|
||||
make_response,
|
||||
make_error_response,
|
||||
)
|
||||
from gunicorn.dirty.client import DirtyClient, DirtyStreamIterator
|
||||
from gunicorn.dirty.errors import DirtyError, DirtyConnectionError
|
||||
|
||||
|
||||
class MockSocket:
|
||||
"""Mock socket that returns predefined messages."""
|
||||
|
||||
def __init__(self, messages):
|
||||
self._data = b''
|
||||
for msg in messages:
|
||||
self._data += DirtyProtocol.encode(msg)
|
||||
self._pos = 0
|
||||
self._sent = []
|
||||
self.closed = False
|
||||
self._timeout = None
|
||||
|
||||
def sendall(self, data):
|
||||
self._sent.append(data)
|
||||
|
||||
def recv(self, n, flags=0):
|
||||
if self._pos >= len(self._data):
|
||||
return b''
|
||||
end = min(self._pos + n, len(self._data))
|
||||
result = self._data[self._pos:end]
|
||||
self._pos = end
|
||||
return result
|
||||
|
||||
def settimeout(self, timeout):
|
||||
self._timeout = timeout
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
|
||||
def create_client_with_mock_socket(messages):
|
||||
"""Create a client with a mock socket returning the given messages."""
|
||||
client = DirtyClient("/tmp/test.sock")
|
||||
client._sock = MockSocket(messages)
|
||||
return client
|
||||
|
||||
|
||||
class TestDirtyStreamIterator:
|
||||
"""Tests for DirtyStreamIterator."""
|
||||
|
||||
def test_stream_returns_iterator(self):
|
||||
"""Test that stream() returns an iterator."""
|
||||
client = DirtyClient("/tmp/test.sock")
|
||||
result = client.stream("test:App", "generate")
|
||||
assert isinstance(result, DirtyStreamIterator)
|
||||
|
||||
def test_stream_iterator_yields_chunks(self):
|
||||
"""Test that stream iterator yields chunks correctly."""
|
||||
messages = [
|
||||
make_chunk_message("req-123", "Hello"),
|
||||
make_chunk_message("req-123", " "),
|
||||
make_chunk_message("req-123", "World"),
|
||||
make_end_message("req-123"),
|
||||
]
|
||||
client = create_client_with_mock_socket(messages)
|
||||
|
||||
chunks = list(client.stream("test:App", "generate"))
|
||||
|
||||
assert chunks == ["Hello", " ", "World"]
|
||||
|
||||
def test_stream_iterator_yields_complex_chunks(self):
|
||||
"""Test that stream iterator yields complex data types."""
|
||||
messages = [
|
||||
make_chunk_message("req-123", {"token": "Hello", "score": 0.9}),
|
||||
make_chunk_message("req-123", {"token": "World", "score": 0.8}),
|
||||
make_end_message("req-123"),
|
||||
]
|
||||
client = create_client_with_mock_socket(messages)
|
||||
|
||||
chunks = list(client.stream("test:App", "generate"))
|
||||
|
||||
assert len(chunks) == 2
|
||||
assert chunks[0]["token"] == "Hello"
|
||||
assert chunks[1]["token"] == "World"
|
||||
|
||||
def test_stream_iterator_handles_error(self):
|
||||
"""Test that stream iterator raises on error message."""
|
||||
messages = [
|
||||
make_chunk_message("req-123", "First"),
|
||||
make_error_response("req-123", DirtyError("Something broke")),
|
||||
]
|
||||
client = create_client_with_mock_socket(messages)
|
||||
|
||||
iterator = client.stream("test:App", "generate")
|
||||
|
||||
# First chunk should work
|
||||
chunk = next(iterator)
|
||||
assert chunk == "First"
|
||||
|
||||
# Second should raise error
|
||||
with pytest.raises(DirtyError) as exc_info:
|
||||
next(iterator)
|
||||
assert "Something broke" in str(exc_info.value)
|
||||
|
||||
def test_stream_iterator_empty_stream(self):
|
||||
"""Test that empty stream (just end) works."""
|
||||
messages = [make_end_message("req-123")]
|
||||
client = create_client_with_mock_socket(messages)
|
||||
|
||||
chunks = list(client.stream("test:App", "generate"))
|
||||
assert chunks == []
|
||||
|
||||
def test_stream_iterator_stops_after_exhausted(self):
|
||||
"""Test that iterator stays exhausted after StopIteration."""
|
||||
messages = [
|
||||
make_chunk_message("req-123", "Only"),
|
||||
make_end_message("req-123"),
|
||||
]
|
||||
client = create_client_with_mock_socket(messages)
|
||||
|
||||
iterator = client.stream("test:App", "generate")
|
||||
|
||||
# Get the chunk
|
||||
chunk = next(iterator)
|
||||
assert chunk == "Only"
|
||||
|
||||
# Should stop
|
||||
with pytest.raises(StopIteration):
|
||||
next(iterator)
|
||||
|
||||
# Should stay stopped
|
||||
with pytest.raises(StopIteration):
|
||||
next(iterator)
|
||||
|
||||
def test_stream_iterator_with_for_loop(self):
|
||||
"""Test stream iterator works in for loop."""
|
||||
messages = [
|
||||
make_chunk_message("req-123", "a"),
|
||||
make_chunk_message("req-123", "b"),
|
||||
make_chunk_message("req-123", "c"),
|
||||
make_end_message("req-123"),
|
||||
]
|
||||
client = create_client_with_mock_socket(messages)
|
||||
|
||||
result = ""
|
||||
for chunk in client.stream("test:App", "generate"):
|
||||
result += chunk
|
||||
|
||||
assert result == "abc"
|
||||
|
||||
def test_stream_sends_request_on_first_iteration(self):
|
||||
"""Test that request is sent on first next() call."""
|
||||
messages = [
|
||||
make_chunk_message("req-123", "data"),
|
||||
make_end_message("req-123"),
|
||||
]
|
||||
client = create_client_with_mock_socket(messages)
|
||||
|
||||
iterator = client.stream("test:App", "generate", "prompt_arg")
|
||||
|
||||
# Before iteration, no request sent
|
||||
assert len(client._sock._sent) == 0
|
||||
|
||||
# First iteration sends request
|
||||
next(iterator)
|
||||
assert len(client._sock._sent) == 1
|
||||
|
||||
# Decode sent request
|
||||
sent_data = client._sock._sent[0]
|
||||
length = struct.unpack(
|
||||
DirtyProtocol.HEADER_FORMAT,
|
||||
sent_data[:DirtyProtocol.HEADER_SIZE]
|
||||
)[0]
|
||||
request = DirtyProtocol.decode(
|
||||
sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length]
|
||||
)
|
||||
|
||||
assert request["type"] == "request"
|
||||
assert request["app_path"] == "test:App"
|
||||
assert request["action"] == "generate"
|
||||
assert request["args"] == ["prompt_arg"]
|
||||
|
||||
|
||||
class TestDirtyStreamIteratorEdgeCases:
|
||||
"""Edge cases for streaming."""
|
||||
|
||||
def test_stream_many_chunks(self):
|
||||
"""Test streaming with many chunks."""
|
||||
messages = []
|
||||
for i in range(100):
|
||||
messages.append(make_chunk_message("req-123", f"chunk-{i}"))
|
||||
messages.append(make_end_message("req-123"))
|
||||
|
||||
client = create_client_with_mock_socket(messages)
|
||||
|
||||
chunks = list(client.stream("test:App", "generate"))
|
||||
|
||||
assert len(chunks) == 100
|
||||
assert chunks[0] == "chunk-0"
|
||||
assert chunks[99] == "chunk-99"
|
||||
|
||||
def test_stream_with_kwargs(self):
|
||||
"""Test streaming with keyword arguments."""
|
||||
messages = [
|
||||
make_chunk_message("req-123", "data"),
|
||||
make_end_message("req-123"),
|
||||
]
|
||||
client = create_client_with_mock_socket(messages)
|
||||
|
||||
# Use kwargs
|
||||
list(client.stream("test:App", "generate", "arg1", key="value"))
|
||||
|
||||
# Check the sent request includes kwargs
|
||||
sent_data = client._sock._sent[0]
|
||||
length = struct.unpack(
|
||||
DirtyProtocol.HEADER_FORMAT,
|
||||
sent_data[:DirtyProtocol.HEADER_SIZE]
|
||||
)[0]
|
||||
request = DirtyProtocol.decode(
|
||||
sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length]
|
||||
)
|
||||
|
||||
assert request["args"] == ["arg1"]
|
||||
assert request["kwargs"] == {"key": "value"}
|
||||
267
tests/dirty/test_client_streaming_async.py
Normal file
267
tests/dirty/test_client_streaming_async.py
Normal file
@ -0,0 +1,267 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""Tests for dirty client async streaming functionality."""
|
||||
|
||||
import asyncio
|
||||
import struct
|
||||
import pytest
|
||||
|
||||
from gunicorn.dirty.protocol import (
|
||||
DirtyProtocol,
|
||||
make_chunk_message,
|
||||
make_end_message,
|
||||
make_error_response,
|
||||
)
|
||||
from gunicorn.dirty.client import DirtyClient, DirtyAsyncStreamIterator
|
||||
from gunicorn.dirty.errors import DirtyError, DirtyTimeoutError
|
||||
|
||||
|
||||
class MockAsyncReader:
|
||||
"""Mock async reader that returns predefined messages."""
|
||||
|
||||
def __init__(self, messages):
|
||||
self._data = b''
|
||||
for msg in messages:
|
||||
self._data += DirtyProtocol.encode(msg)
|
||||
self._pos = 0
|
||||
|
||||
async def readexactly(self, n):
|
||||
if self._pos + n > len(self._data):
|
||||
raise asyncio.IncompleteReadError(self._data[self._pos:], n)
|
||||
result = self._data[self._pos:self._pos + n]
|
||||
self._pos += n
|
||||
return result
|
||||
|
||||
|
||||
class MockAsyncWriter:
|
||||
"""Mock async writer that captures sent data."""
|
||||
|
||||
def __init__(self):
|
||||
self._sent = []
|
||||
self.closed = False
|
||||
|
||||
def write(self, data):
|
||||
self._sent.append(data)
|
||||
|
||||
async def drain(self):
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
async def wait_closed(self):
|
||||
pass
|
||||
|
||||
|
||||
def create_async_client_with_mocks(messages):
|
||||
"""Create a client with mock async reader/writer."""
|
||||
client = DirtyClient("/tmp/test.sock")
|
||||
client._reader = MockAsyncReader(messages)
|
||||
client._writer = MockAsyncWriter()
|
||||
return client
|
||||
|
||||
|
||||
class TestDirtyAsyncStreamIterator:
|
||||
"""Tests for DirtyAsyncStreamIterator."""
|
||||
|
||||
def test_stream_async_returns_async_iterator(self):
|
||||
"""Test that stream_async() returns an async iterator."""
|
||||
client = DirtyClient("/tmp/test.sock")
|
||||
result = client.stream_async("test:App", "generate")
|
||||
assert isinstance(result, DirtyAsyncStreamIterator)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_stream_yields_chunks(self):
|
||||
"""Test that async stream iterator yields chunks correctly."""
|
||||
messages = [
|
||||
make_chunk_message("req-123", "Hello"),
|
||||
make_chunk_message("req-123", " "),
|
||||
make_chunk_message("req-123", "World"),
|
||||
make_end_message("req-123"),
|
||||
]
|
||||
client = create_async_client_with_mocks(messages)
|
||||
|
||||
chunks = []
|
||||
async for chunk in client.stream_async("test:App", "generate"):
|
||||
chunks.append(chunk)
|
||||
|
||||
assert chunks == ["Hello", " ", "World"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_stream_yields_complex_chunks(self):
|
||||
"""Test that async stream iterator yields complex data types."""
|
||||
messages = [
|
||||
make_chunk_message("req-123", {"token": "Hello", "score": 0.9}),
|
||||
make_chunk_message("req-123", {"token": "World", "score": 0.8}),
|
||||
make_end_message("req-123"),
|
||||
]
|
||||
client = create_async_client_with_mocks(messages)
|
||||
|
||||
chunks = []
|
||||
async for chunk in client.stream_async("test:App", "generate"):
|
||||
chunks.append(chunk)
|
||||
|
||||
assert len(chunks) == 2
|
||||
assert chunks[0]["token"] == "Hello"
|
||||
assert chunks[1]["token"] == "World"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_stream_handles_error(self):
|
||||
"""Test that async stream iterator raises on error message."""
|
||||
messages = [
|
||||
make_chunk_message("req-123", "First"),
|
||||
make_error_response("req-123", DirtyError("Something broke")),
|
||||
]
|
||||
client = create_async_client_with_mocks(messages)
|
||||
|
||||
iterator = client.stream_async("test:App", "generate")
|
||||
|
||||
# First chunk should work
|
||||
chunk = await iterator.__anext__()
|
||||
assert chunk == "First"
|
||||
|
||||
# Second should raise error
|
||||
with pytest.raises(DirtyError) as exc_info:
|
||||
await iterator.__anext__()
|
||||
assert "Something broke" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_stream_empty_stream(self):
|
||||
"""Test that empty stream (just end) works."""
|
||||
messages = [make_end_message("req-123")]
|
||||
client = create_async_client_with_mocks(messages)
|
||||
|
||||
chunks = []
|
||||
async for chunk in client.stream_async("test:App", "generate"):
|
||||
chunks.append(chunk)
|
||||
|
||||
assert chunks == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_stream_stops_after_exhausted(self):
|
||||
"""Test that async iterator stays exhausted after StopAsyncIteration."""
|
||||
messages = [
|
||||
make_chunk_message("req-123", "Only"),
|
||||
make_end_message("req-123"),
|
||||
]
|
||||
client = create_async_client_with_mocks(messages)
|
||||
|
||||
iterator = client.stream_async("test:App", "generate")
|
||||
|
||||
# Get the chunk
|
||||
chunk = await iterator.__anext__()
|
||||
assert chunk == "Only"
|
||||
|
||||
# Should stop
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
await iterator.__anext__()
|
||||
|
||||
# Should stay stopped
|
||||
with pytest.raises(StopAsyncIteration):
|
||||
await iterator.__anext__()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_stream_sends_request_on_first_iteration(self):
|
||||
"""Test that request is sent on first async iteration."""
|
||||
messages = [
|
||||
make_chunk_message("req-123", "data"),
|
||||
make_end_message("req-123"),
|
||||
]
|
||||
client = create_async_client_with_mocks(messages)
|
||||
|
||||
iterator = client.stream_async("test:App", "generate", "prompt_arg")
|
||||
|
||||
# Before iteration, no request sent
|
||||
assert len(client._writer._sent) == 0
|
||||
|
||||
# First iteration sends request
|
||||
await iterator.__anext__()
|
||||
assert len(client._writer._sent) == 1
|
||||
|
||||
# Decode sent request
|
||||
sent_data = client._writer._sent[0]
|
||||
length = struct.unpack(
|
||||
DirtyProtocol.HEADER_FORMAT,
|
||||
sent_data[:DirtyProtocol.HEADER_SIZE]
|
||||
)[0]
|
||||
request = DirtyProtocol.decode(
|
||||
sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length]
|
||||
)
|
||||
|
||||
assert request["type"] == "request"
|
||||
assert request["app_path"] == "test:App"
|
||||
assert request["action"] == "generate"
|
||||
assert request["args"] == ["prompt_arg"]
|
||||
|
||||
|
||||
class TestDirtyAsyncStreamIteratorEdgeCases:
|
||||
"""Edge cases for async streaming."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_stream_many_chunks(self):
|
||||
"""Test async streaming with many chunks."""
|
||||
messages = []
|
||||
for i in range(100):
|
||||
messages.append(make_chunk_message("req-123", f"chunk-{i}"))
|
||||
messages.append(make_end_message("req-123"))
|
||||
|
||||
client = create_async_client_with_mocks(messages)
|
||||
|
||||
chunks = []
|
||||
async for chunk in client.stream_async("test:App", "generate"):
|
||||
chunks.append(chunk)
|
||||
|
||||
assert len(chunks) == 100
|
||||
assert chunks[0] == "chunk-0"
|
||||
assert chunks[99] == "chunk-99"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_stream_with_kwargs(self):
|
||||
"""Test async streaming with keyword arguments."""
|
||||
messages = [
|
||||
make_chunk_message("req-123", "data"),
|
||||
make_end_message("req-123"),
|
||||
]
|
||||
client = create_async_client_with_mocks(messages)
|
||||
|
||||
# Use kwargs
|
||||
chunks = []
|
||||
async for chunk in client.stream_async("test:App", "generate", "arg1", key="value"):
|
||||
chunks.append(chunk)
|
||||
|
||||
# Check the sent request includes kwargs
|
||||
sent_data = client._writer._sent[0]
|
||||
length = struct.unpack(
|
||||
DirtyProtocol.HEADER_FORMAT,
|
||||
sent_data[:DirtyProtocol.HEADER_SIZE]
|
||||
)[0]
|
||||
request = DirtyProtocol.decode(
|
||||
sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length]
|
||||
)
|
||||
|
||||
assert request["args"] == ["arg1"]
|
||||
assert request["kwargs"] == {"key": "value"}
|
||||
|
||||
|
||||
class TestDirtyAsyncStreamTimeout:
|
||||
"""Tests for async streaming timeout handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_stream_timeout(self):
|
||||
"""Test that timeout during async streaming raises DirtyTimeoutError."""
|
||||
client = DirtyClient("/tmp/test.sock", timeout=0.01)
|
||||
|
||||
# Create a reader that times out
|
||||
class SlowReader:
|
||||
async def readexactly(self, n):
|
||||
await asyncio.sleep(1) # Longer than timeout
|
||||
|
||||
client._reader = SlowReader()
|
||||
client._writer = MockAsyncWriter()
|
||||
|
||||
iterator = client.stream_async("test:App", "generate")
|
||||
|
||||
with pytest.raises(DirtyTimeoutError):
|
||||
await iterator.__anext__()
|
||||
469
tests/dirty/test_streaming_integration.py
Normal file
469
tests/dirty/test_streaming_integration.py
Normal file
@ -0,0 +1,469 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""Integration tests for dirty streaming functionality.
|
||||
|
||||
These tests verify the full streaming pipeline:
|
||||
client -> arbiter -> worker -> generator -> chunks -> client
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import struct
|
||||
import tempfile
|
||||
import pytest
|
||||
from unittest import mock
|
||||
|
||||
from gunicorn.config import Config
|
||||
from gunicorn.dirty.protocol import (
|
||||
DirtyProtocol,
|
||||
make_request,
|
||||
make_chunk_message,
|
||||
make_end_message,
|
||||
make_response,
|
||||
make_error_response,
|
||||
)
|
||||
from gunicorn.dirty.worker import DirtyWorker
|
||||
from gunicorn.dirty.arbiter import DirtyArbiter
|
||||
from gunicorn.dirty.client import DirtyClient
|
||||
from gunicorn.dirty.errors import DirtyError
|
||||
|
||||
|
||||
class MockLog:
|
||||
"""Mock logger for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
|
||||
def debug(self, msg, *args):
|
||||
self.messages.append(("debug", msg % args if args else msg))
|
||||
|
||||
def info(self, msg, *args):
|
||||
self.messages.append(("info", msg % args if args else msg))
|
||||
|
||||
def warning(self, msg, *args):
|
||||
self.messages.append(("warning", msg % args if args else msg))
|
||||
|
||||
def error(self, msg, *args):
|
||||
self.messages.append(("error", msg % args if args else msg))
|
||||
|
||||
def close_on_exec(self):
|
||||
pass
|
||||
|
||||
def reopen_files(self):
|
||||
pass
|
||||
|
||||
|
||||
class MockStreamWriter:
|
||||
"""Mock StreamWriter that captures written messages."""
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
self._buffer = b""
|
||||
self.closed = False
|
||||
|
||||
def write(self, data):
|
||||
self._buffer += data
|
||||
|
||||
async def drain(self):
|
||||
while len(self._buffer) >= DirtyProtocol.HEADER_SIZE:
|
||||
length = struct.unpack(
|
||||
DirtyProtocol.HEADER_FORMAT,
|
||||
self._buffer[:DirtyProtocol.HEADER_SIZE]
|
||||
)[0]
|
||||
total_size = DirtyProtocol.HEADER_SIZE + length
|
||||
if len(self._buffer) >= total_size:
|
||||
msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size]
|
||||
self._buffer = self._buffer[total_size:]
|
||||
self.messages.append(DirtyProtocol.decode(msg_data))
|
||||
else:
|
||||
break
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
async def wait_closed(self):
|
||||
pass
|
||||
|
||||
def get_extra_info(self, name):
|
||||
return None
|
||||
|
||||
|
||||
class MockStreamReader:
|
||||
"""Mock StreamReader that yields predefined messages."""
|
||||
|
||||
def __init__(self, messages):
|
||||
self._data = b''
|
||||
for msg in messages:
|
||||
self._data += DirtyProtocol.encode(msg)
|
||||
self._pos = 0
|
||||
|
||||
async def readexactly(self, n):
|
||||
if self._pos + n > len(self._data):
|
||||
raise asyncio.IncompleteReadError(self._data[self._pos:], n)
|
||||
result = self._data[self._pos:self._pos + n]
|
||||
self._pos += n
|
||||
return result
|
||||
|
||||
|
||||
class TestStreamingEndToEnd:
|
||||
"""End-to-end streaming tests using mocked components."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_generator_end_to_end(self):
|
||||
"""Test complete flow: sync generator -> worker -> arbiter -> client."""
|
||||
# Simulate what a worker would produce for a sync generator
|
||||
worker_messages = [
|
||||
make_chunk_message("req-123", "Hello"),
|
||||
make_chunk_message("req-123", " "),
|
||||
make_chunk_message("req-123", "World"),
|
||||
make_end_message("req-123"),
|
||||
]
|
||||
|
||||
# Create an arbiter with mocked worker connection
|
||||
cfg = Config()
|
||||
cfg.set("dirty_timeout", 30)
|
||||
log = MockLog()
|
||||
|
||||
arbiter = DirtyArbiter(cfg=cfg, log=log)
|
||||
arbiter.alive = True
|
||||
arbiter.workers = {1234: mock.Mock()}
|
||||
arbiter.worker_sockets = {1234: '/tmp/worker.sock'}
|
||||
|
||||
# Mock worker connection
|
||||
mock_reader = MockStreamReader(worker_messages)
|
||||
async def mock_get_connection(pid):
|
||||
return mock_reader, MockStreamWriter()
|
||||
arbiter._get_worker_connection = mock_get_connection
|
||||
|
||||
# Create client writer to capture messages
|
||||
client_writer = MockStreamWriter()
|
||||
|
||||
# Execute request through arbiter
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
await arbiter._execute_on_worker(1234, request, client_writer)
|
||||
|
||||
# Verify all messages were forwarded
|
||||
assert len(client_writer.messages) == 4
|
||||
assert client_writer.messages[0]["type"] == "chunk"
|
||||
assert client_writer.messages[0]["data"] == "Hello"
|
||||
assert client_writer.messages[1]["data"] == " "
|
||||
assert client_writer.messages[2]["data"] == "World"
|
||||
assert client_writer.messages[3]["type"] == "end"
|
||||
|
||||
arbiter._cleanup_sync()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_generator_end_to_end(self):
|
||||
"""Test complete flow: async generator -> worker -> arbiter -> client."""
|
||||
worker_messages = [
|
||||
make_chunk_message("req-456", "Async"),
|
||||
make_chunk_message("req-456", " "),
|
||||
make_chunk_message("req-456", "Stream"),
|
||||
make_end_message("req-456"),
|
||||
]
|
||||
|
||||
cfg = Config()
|
||||
cfg.set("dirty_timeout", 30)
|
||||
log = MockLog()
|
||||
|
||||
arbiter = DirtyArbiter(cfg=cfg, log=log)
|
||||
arbiter.alive = True
|
||||
arbiter.workers = {1234: mock.Mock()}
|
||||
arbiter.worker_sockets = {1234: '/tmp/worker.sock'}
|
||||
|
||||
mock_reader = MockStreamReader(worker_messages)
|
||||
async def mock_get_connection(pid):
|
||||
return mock_reader, MockStreamWriter()
|
||||
arbiter._get_worker_connection = mock_get_connection
|
||||
|
||||
client_writer = MockStreamWriter()
|
||||
|
||||
request = make_request("req-456", "test:App", "async_generate")
|
||||
await arbiter._execute_on_worker(1234, request, client_writer)
|
||||
|
||||
assert len(client_writer.messages) == 4
|
||||
assert client_writer.messages[0]["data"] == "Async"
|
||||
assert client_writer.messages[3]["type"] == "end"
|
||||
|
||||
arbiter._cleanup_sync()
|
||||
|
||||
|
||||
class TestStreamingErrorHandling:
|
||||
"""Tests for error handling during streaming."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_mid_stream(self):
|
||||
"""Test that errors during streaming are properly forwarded."""
|
||||
worker_messages = [
|
||||
make_chunk_message("req-789", "First"),
|
||||
make_chunk_message("req-789", "Second"),
|
||||
make_error_response("req-789", DirtyError("Stream failed")),
|
||||
]
|
||||
|
||||
cfg = Config()
|
||||
cfg.set("dirty_timeout", 30)
|
||||
log = MockLog()
|
||||
|
||||
arbiter = DirtyArbiter(cfg=cfg, log=log)
|
||||
arbiter.alive = True
|
||||
arbiter.workers = {1234: mock.Mock()}
|
||||
arbiter.worker_sockets = {1234: '/tmp/worker.sock'}
|
||||
|
||||
mock_reader = MockStreamReader(worker_messages)
|
||||
async def mock_get_connection(pid):
|
||||
return mock_reader, MockStreamWriter()
|
||||
arbiter._get_worker_connection = mock_get_connection
|
||||
|
||||
client_writer = MockStreamWriter()
|
||||
|
||||
request = make_request("req-789", "test:App", "generate_with_error")
|
||||
await arbiter._execute_on_worker(1234, request, client_writer)
|
||||
|
||||
# Should have 2 chunks + 1 error
|
||||
assert len(client_writer.messages) == 3
|
||||
assert client_writer.messages[0]["type"] == "chunk"
|
||||
assert client_writer.messages[1]["type"] == "chunk"
|
||||
assert client_writer.messages[2]["type"] == "error"
|
||||
assert "Stream failed" in client_writer.messages[2]["error"]["message"]
|
||||
|
||||
arbiter._cleanup_sync()
|
||||
|
||||
|
||||
class TestStreamingBackwardCompatibility:
|
||||
"""Tests for backward compatibility with non-streaming responses."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_streaming_response_still_works(self):
|
||||
"""Test that regular (non-streaming) responses still work."""
|
||||
worker_messages = [
|
||||
make_response("req-abc", {"result": 42, "data": [1, 2, 3]}),
|
||||
]
|
||||
|
||||
cfg = Config()
|
||||
cfg.set("dirty_timeout", 30)
|
||||
log = MockLog()
|
||||
|
||||
arbiter = DirtyArbiter(cfg=cfg, log=log)
|
||||
arbiter.alive = True
|
||||
arbiter.workers = {1234: mock.Mock()}
|
||||
arbiter.worker_sockets = {1234: '/tmp/worker.sock'}
|
||||
|
||||
mock_reader = MockStreamReader(worker_messages)
|
||||
async def mock_get_connection(pid):
|
||||
return mock_reader, MockStreamWriter()
|
||||
arbiter._get_worker_connection = mock_get_connection
|
||||
|
||||
client_writer = MockStreamWriter()
|
||||
|
||||
request = make_request("req-abc", "test:App", "compute")
|
||||
await arbiter._execute_on_worker(1234, request, client_writer)
|
||||
|
||||
# Should have 1 response
|
||||
assert len(client_writer.messages) == 1
|
||||
assert client_writer.messages[0]["type"] == "response"
|
||||
assert client_writer.messages[0]["result"]["result"] == 42
|
||||
|
||||
arbiter._cleanup_sync()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_response_still_works(self):
|
||||
"""Test that error responses still work."""
|
||||
worker_messages = [
|
||||
make_error_response("req-def", DirtyError("Something failed")),
|
||||
]
|
||||
|
||||
cfg = Config()
|
||||
cfg.set("dirty_timeout", 30)
|
||||
log = MockLog()
|
||||
|
||||
arbiter = DirtyArbiter(cfg=cfg, log=log)
|
||||
arbiter.alive = True
|
||||
arbiter.workers = {1234: mock.Mock()}
|
||||
arbiter.worker_sockets = {1234: '/tmp/worker.sock'}
|
||||
|
||||
mock_reader = MockStreamReader(worker_messages)
|
||||
async def mock_get_connection(pid):
|
||||
return mock_reader, MockStreamWriter()
|
||||
arbiter._get_worker_connection = mock_get_connection
|
||||
|
||||
client_writer = MockStreamWriter()
|
||||
|
||||
request = make_request("req-def", "test:App", "fail")
|
||||
await arbiter._execute_on_worker(1234, request, client_writer)
|
||||
|
||||
assert len(client_writer.messages) == 1
|
||||
assert client_writer.messages[0]["type"] == "error"
|
||||
|
||||
arbiter._cleanup_sync()
|
||||
|
||||
|
||||
class TestStreamingWorkerIntegration:
|
||||
"""Integration tests for worker streaming with execute."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_worker_handles_sync_generator(self):
|
||||
"""Test worker properly handles sync generator from execute."""
|
||||
cfg = Config()
|
||||
cfg.set("dirty_timeout", 300)
|
||||
log = MockLog()
|
||||
|
||||
with mock.patch('gunicorn.dirty.worker.WorkerTmp'):
|
||||
worker = DirtyWorker(
|
||||
age=1,
|
||||
ppid=os.getpid(),
|
||||
app_paths=["test:App"],
|
||||
cfg=cfg,
|
||||
log=log,
|
||||
socket_path="/tmp/test.sock"
|
||||
)
|
||||
|
||||
worker.apps = {}
|
||||
worker._executor = None
|
||||
worker.tmp = mock.Mock()
|
||||
|
||||
writer = MockStreamWriter()
|
||||
|
||||
# Mock execute to return a sync generator
|
||||
def sync_gen():
|
||||
yield "one"
|
||||
yield "two"
|
||||
yield "three"
|
||||
|
||||
async def mock_execute(app_path, action, args, kwargs):
|
||||
return sync_gen()
|
||||
|
||||
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
# Should have 3 chunks + 1 end
|
||||
assert len(writer.messages) == 4
|
||||
assert writer.messages[0]["data"] == "one"
|
||||
assert writer.messages[1]["data"] == "two"
|
||||
assert writer.messages[2]["data"] == "three"
|
||||
assert writer.messages[3]["type"] == "end"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_worker_handles_async_generator(self):
|
||||
"""Test worker properly handles async generator from execute."""
|
||||
cfg = Config()
|
||||
cfg.set("dirty_timeout", 300)
|
||||
log = MockLog()
|
||||
|
||||
with mock.patch('gunicorn.dirty.worker.WorkerTmp'):
|
||||
worker = DirtyWorker(
|
||||
age=1,
|
||||
ppid=os.getpid(),
|
||||
app_paths=["test:App"],
|
||||
cfg=cfg,
|
||||
log=log,
|
||||
socket_path="/tmp/test.sock"
|
||||
)
|
||||
|
||||
worker.apps = {}
|
||||
worker._executor = None
|
||||
worker.tmp = mock.Mock()
|
||||
|
||||
writer = MockStreamWriter()
|
||||
|
||||
# Mock execute to return an async generator
|
||||
async def async_gen():
|
||||
yield "async_one"
|
||||
yield "async_two"
|
||||
|
||||
async def mock_execute(app_path, action, args, kwargs):
|
||||
return async_gen()
|
||||
|
||||
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
|
||||
request = make_request("req-456", "test:App", "async_generate")
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
# Should have 2 chunks + 1 end
|
||||
assert len(writer.messages) == 3
|
||||
assert writer.messages[0]["data"] == "async_one"
|
||||
assert writer.messages[1]["data"] == "async_two"
|
||||
assert writer.messages[2]["type"] == "end"
|
||||
|
||||
|
||||
class TestStreamingMixedScenarios:
|
||||
"""Tests for mixed streaming scenarios."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_stream(self):
|
||||
"""Test streaming with many chunks."""
|
||||
worker_messages = []
|
||||
for i in range(500):
|
||||
worker_messages.append(make_chunk_message("req-large", f"chunk-{i}"))
|
||||
worker_messages.append(make_end_message("req-large"))
|
||||
|
||||
cfg = Config()
|
||||
cfg.set("dirty_timeout", 30)
|
||||
log = MockLog()
|
||||
|
||||
arbiter = DirtyArbiter(cfg=cfg, log=log)
|
||||
arbiter.alive = True
|
||||
arbiter.workers = {1234: mock.Mock()}
|
||||
arbiter.worker_sockets = {1234: '/tmp/worker.sock'}
|
||||
|
||||
mock_reader = MockStreamReader(worker_messages)
|
||||
async def mock_get_connection(pid):
|
||||
return mock_reader, MockStreamWriter()
|
||||
arbiter._get_worker_connection = mock_get_connection
|
||||
|
||||
client_writer = MockStreamWriter()
|
||||
|
||||
request = make_request("req-large", "test:App", "large_stream")
|
||||
await arbiter._execute_on_worker(1234, request, client_writer)
|
||||
|
||||
# Should have 500 chunks + 1 end
|
||||
assert len(client_writer.messages) == 501
|
||||
assert client_writer.messages[0]["data"] == "chunk-0"
|
||||
assert client_writer.messages[499]["data"] == "chunk-499"
|
||||
assert client_writer.messages[500]["type"] == "end"
|
||||
|
||||
arbiter._cleanup_sync()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_with_complex_data(self):
|
||||
"""Test streaming with complex JSON-serializable data."""
|
||||
worker_messages = [
|
||||
make_chunk_message("req-complex", {
|
||||
"token": "Hello",
|
||||
"scores": [0.1, 0.2, 0.3],
|
||||
"metadata": {"position": 0}
|
||||
}),
|
||||
make_chunk_message("req-complex", {
|
||||
"token": "World",
|
||||
"scores": [0.4, 0.5],
|
||||
"metadata": {"position": 1}
|
||||
}),
|
||||
make_end_message("req-complex"),
|
||||
]
|
||||
|
||||
cfg = Config()
|
||||
cfg.set("dirty_timeout", 30)
|
||||
log = MockLog()
|
||||
|
||||
arbiter = DirtyArbiter(cfg=cfg, log=log)
|
||||
arbiter.alive = True
|
||||
arbiter.workers = {1234: mock.Mock()}
|
||||
arbiter.worker_sockets = {1234: '/tmp/worker.sock'}
|
||||
|
||||
mock_reader = MockStreamReader(worker_messages)
|
||||
async def mock_get_connection(pid):
|
||||
return mock_reader, MockStreamWriter()
|
||||
arbiter._get_worker_connection = mock_get_connection
|
||||
|
||||
client_writer = MockStreamWriter()
|
||||
|
||||
request = make_request("req-complex", "test:App", "complex_stream")
|
||||
await arbiter._execute_on_worker(1234, request, client_writer)
|
||||
|
||||
assert len(client_writer.messages) == 3
|
||||
assert client_writer.messages[0]["data"]["token"] == "Hello"
|
||||
assert client_writer.messages[0]["data"]["scores"] == [0.1, 0.2, 0.3]
|
||||
assert client_writer.messages[1]["data"]["metadata"]["position"] == 1
|
||||
|
||||
arbiter._cleanup_sync()
|
||||
415
tests/dirty/test_worker_streaming.py
Normal file
415
tests/dirty/test_worker_streaming.py
Normal file
@ -0,0 +1,415 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""Tests for dirty worker streaming functionality."""
|
||||
|
||||
import asyncio
|
||||
import struct
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from gunicorn.dirty.protocol import (
|
||||
DirtyProtocol,
|
||||
make_request,
|
||||
make_chunk_message,
|
||||
make_end_message,
|
||||
)
|
||||
from gunicorn.dirty.worker import DirtyWorker
|
||||
|
||||
|
||||
class FakeStreamWriter:
|
||||
"""Mock StreamWriter that captures written messages."""
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
self._buffer = b""
|
||||
|
||||
def write(self, data):
|
||||
self._buffer += data
|
||||
|
||||
async def drain(self):
|
||||
# Decode the buffer to extract messages
|
||||
while len(self._buffer) >= DirtyProtocol.HEADER_SIZE:
|
||||
length = struct.unpack(
|
||||
DirtyProtocol.HEADER_FORMAT,
|
||||
self._buffer[:DirtyProtocol.HEADER_SIZE]
|
||||
)[0]
|
||||
total_size = DirtyProtocol.HEADER_SIZE + length
|
||||
if len(self._buffer) >= total_size:
|
||||
msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size]
|
||||
self._buffer = self._buffer[total_size:]
|
||||
self.messages.append(DirtyProtocol.decode(msg_data))
|
||||
else:
|
||||
break
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
async def wait_closed(self):
|
||||
pass
|
||||
|
||||
|
||||
def create_worker():
|
||||
"""Create a test worker with mocked components."""
|
||||
cfg = mock.Mock()
|
||||
cfg.dirty_timeout = 30
|
||||
cfg.dirty_threads = 1
|
||||
cfg.env = None
|
||||
cfg.uid = None
|
||||
cfg.gid = None
|
||||
cfg.initgroups = False
|
||||
cfg.dirty_worker_init = mock.Mock()
|
||||
cfg.umask = 0o22
|
||||
|
||||
log = mock.Mock()
|
||||
|
||||
with mock.patch('gunicorn.dirty.worker.WorkerTmp'):
|
||||
worker = DirtyWorker(
|
||||
age=1,
|
||||
ppid=1,
|
||||
app_paths=["test:App"],
|
||||
cfg=cfg,
|
||||
log=log,
|
||||
socket_path="/tmp/test.sock"
|
||||
)
|
||||
|
||||
worker.apps = {}
|
||||
worker._executor = None # Use default executor for sync generator tests
|
||||
worker.tmp = mock.Mock()
|
||||
|
||||
return worker
|
||||
|
||||
|
||||
class TestWorkerSyncGeneratorStreaming:
|
||||
"""Tests for sync generator streaming."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_generator_sends_chunks_and_end(self):
|
||||
"""Test that sync generator sends chunk messages then end message."""
|
||||
def generate_tokens():
|
||||
yield "Hello"
|
||||
yield " "
|
||||
yield "World"
|
||||
|
||||
worker = create_worker()
|
||||
writer = FakeStreamWriter()
|
||||
|
||||
# Mock execute to return the sync generator directly
|
||||
async def mock_execute(app_path, action, args, kwargs):
|
||||
return generate_tokens()
|
||||
|
||||
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
# Should have 3 chunks + 1 end message
|
||||
assert len(writer.messages) == 4
|
||||
|
||||
# Check chunk messages
|
||||
assert writer.messages[0]["type"] == "chunk"
|
||||
assert writer.messages[0]["id"] == "req-123"
|
||||
assert writer.messages[0]["data"] == "Hello"
|
||||
|
||||
assert writer.messages[1]["type"] == "chunk"
|
||||
assert writer.messages[1]["data"] == " "
|
||||
|
||||
assert writer.messages[2]["type"] == "chunk"
|
||||
assert writer.messages[2]["data"] == "World"
|
||||
|
||||
# Check end message
|
||||
assert writer.messages[3]["type"] == "end"
|
||||
assert writer.messages[3]["id"] == "req-123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_generator_error_mid_stream(self):
|
||||
"""Test that error during streaming sends error message."""
|
||||
def generate_with_error():
|
||||
yield "First"
|
||||
raise ValueError("Something went wrong")
|
||||
|
||||
worker = create_worker()
|
||||
writer = FakeStreamWriter()
|
||||
|
||||
async def mock_execute(app_path, action, args, kwargs):
|
||||
return generate_with_error()
|
||||
|
||||
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
# Should have 1 chunk + 1 error message
|
||||
assert len(writer.messages) == 2
|
||||
|
||||
assert writer.messages[0]["type"] == "chunk"
|
||||
assert writer.messages[0]["data"] == "First"
|
||||
|
||||
assert writer.messages[1]["type"] == "error"
|
||||
assert "Something went wrong" in writer.messages[1]["error"]["message"]
|
||||
|
||||
|
||||
class TestWorkerAsyncGeneratorStreaming:
|
||||
"""Tests for async generator streaming."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_generator_sends_chunks_and_end(self):
|
||||
"""Test that async generator sends chunk messages then end message."""
|
||||
async def async_generate_tokens():
|
||||
yield "Hello"
|
||||
yield " "
|
||||
yield "World"
|
||||
|
||||
worker = create_worker()
|
||||
writer = FakeStreamWriter()
|
||||
|
||||
async def mock_execute(app_path, action, args, kwargs):
|
||||
return async_generate_tokens()
|
||||
|
||||
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
# Should have 3 chunks + 1 end message
|
||||
assert len(writer.messages) == 4
|
||||
|
||||
# Check chunk messages
|
||||
assert writer.messages[0]["type"] == "chunk"
|
||||
assert writer.messages[0]["id"] == "req-123"
|
||||
assert writer.messages[0]["data"] == "Hello"
|
||||
|
||||
assert writer.messages[1]["type"] == "chunk"
|
||||
assert writer.messages[1]["data"] == " "
|
||||
|
||||
assert writer.messages[2]["type"] == "chunk"
|
||||
assert writer.messages[2]["data"] == "World"
|
||||
|
||||
# Check end message
|
||||
assert writer.messages[3]["type"] == "end"
|
||||
assert writer.messages[3]["id"] == "req-123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_generator_error_mid_stream(self):
|
||||
"""Test that error during async streaming sends error message."""
|
||||
async def async_generate_with_error():
|
||||
yield "First"
|
||||
raise ValueError("Async error")
|
||||
|
||||
worker = create_worker()
|
||||
writer = FakeStreamWriter()
|
||||
|
||||
async def mock_execute(app_path, action, args, kwargs):
|
||||
return async_generate_with_error()
|
||||
|
||||
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
# Should have 1 chunk + 1 error message
|
||||
assert len(writer.messages) == 2
|
||||
|
||||
assert writer.messages[0]["type"] == "chunk"
|
||||
assert writer.messages[0]["data"] == "First"
|
||||
|
||||
assert writer.messages[1]["type"] == "error"
|
||||
assert "Async error" in writer.messages[1]["error"]["message"]
|
||||
|
||||
|
||||
class TestWorkerNonStreamingBackwardCompat:
|
||||
"""Tests for backward compatibility with non-streaming responses."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_generator_returns_response(self):
|
||||
"""Test that non-generator method returns regular response."""
|
||||
worker = create_worker()
|
||||
writer = FakeStreamWriter()
|
||||
|
||||
async def mock_execute(app_path, action, args, kwargs):
|
||||
return args[0] + args[1]
|
||||
|
||||
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
|
||||
request = make_request("req-123", "test:App", "compute", args=(2, 3))
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
# Should have 1 response message
|
||||
assert len(writer.messages) == 1
|
||||
assert writer.messages[0]["type"] == "response"
|
||||
assert writer.messages[0]["id"] == "req-123"
|
||||
assert writer.messages[0]["result"] == 5
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_result_not_treated_as_streaming(self):
|
||||
"""Test that list result is not treated as streaming."""
|
||||
worker = create_worker()
|
||||
writer = FakeStreamWriter()
|
||||
|
||||
async def mock_execute(app_path, action, args, kwargs):
|
||||
return [1, 2, 3, 4, 5]
|
||||
|
||||
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
|
||||
request = make_request("req-123", "test:App", "get_list")
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
# Should have 1 response message (not 5 chunks)
|
||||
assert len(writer.messages) == 1
|
||||
assert writer.messages[0]["type"] == "response"
|
||||
assert writer.messages[0]["result"] == [1, 2, 3, 4, 5]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_in_execute_sends_error(self):
|
||||
"""Test that error in execute sends error response."""
|
||||
worker = create_worker()
|
||||
writer = FakeStreamWriter()
|
||||
|
||||
async def mock_execute(app_path, action, args, kwargs):
|
||||
raise RuntimeError("Failed!")
|
||||
|
||||
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
|
||||
request = make_request("req-123", "test:App", "fail")
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
# Should have 1 error message
|
||||
assert len(writer.messages) == 1
|
||||
assert writer.messages[0]["type"] == "error"
|
||||
assert "Failed!" in writer.messages[0]["error"]["message"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_result(self):
|
||||
"""Test that None result works correctly."""
|
||||
worker = create_worker()
|
||||
writer = FakeStreamWriter()
|
||||
|
||||
async def mock_execute(app_path, action, args, kwargs):
|
||||
return None
|
||||
|
||||
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
|
||||
request = make_request("req-123", "test:App", "void")
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
# Should have 1 response message
|
||||
assert len(writer.messages) == 1
|
||||
assert writer.messages[0]["type"] == "response"
|
||||
assert writer.messages[0]["result"] is None
|
||||
|
||||
|
||||
class TestWorkerStreamingComplexData:
|
||||
"""Tests for streaming with complex data types."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_dict_chunks(self):
|
||||
"""Test streaming chunks that are dictionaries."""
|
||||
async def generate_tokens():
|
||||
yield {"token": "Hello", "score": 0.9}
|
||||
yield {"token": "World", "score": 0.8}
|
||||
|
||||
worker = create_worker()
|
||||
writer = FakeStreamWriter()
|
||||
|
||||
async def mock_execute(app_path, action, args, kwargs):
|
||||
return generate_tokens()
|
||||
|
||||
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
assert len(writer.messages) == 3 # 2 chunks + 1 end
|
||||
|
||||
assert writer.messages[0]["data"]["token"] == "Hello"
|
||||
assert writer.messages[0]["data"]["score"] == 0.9
|
||||
assert writer.messages[1]["data"]["token"] == "World"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_empty_generator(self):
|
||||
"""Test streaming with empty generator."""
|
||||
async def empty_generate():
|
||||
return
|
||||
yield # Make it a generator
|
||||
|
||||
worker = create_worker()
|
||||
writer = FakeStreamWriter()
|
||||
|
||||
async def mock_execute(app_path, action, args, kwargs):
|
||||
return empty_generate()
|
||||
|
||||
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
# Should have just 1 end message
|
||||
assert len(writer.messages) == 1
|
||||
assert writer.messages[0]["type"] == "end"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_many_chunks(self):
|
||||
"""Test streaming with many chunks."""
|
||||
async def generate_many():
|
||||
for i in range(100):
|
||||
yield f"chunk-{i}"
|
||||
|
||||
worker = create_worker()
|
||||
writer = FakeStreamWriter()
|
||||
|
||||
async def mock_execute(app_path, action, args, kwargs):
|
||||
return generate_many()
|
||||
|
||||
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
# Should have 100 chunks + 1 end message
|
||||
assert len(writer.messages) == 101
|
||||
assert writer.messages[0]["data"] == "chunk-0"
|
||||
assert writer.messages[99]["data"] == "chunk-99"
|
||||
assert writer.messages[100]["type"] == "end"
|
||||
|
||||
|
||||
class TestWorkerStreamingHeartbeat:
|
||||
"""Tests for heartbeat updates during streaming."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_heartbeat_updated_during_streaming(self):
|
||||
"""Test that heartbeat is updated during streaming."""
|
||||
async def generate_tokens():
|
||||
yield "Hello"
|
||||
yield "World"
|
||||
|
||||
worker = create_worker()
|
||||
writer = FakeStreamWriter()
|
||||
|
||||
# Track notify calls
|
||||
notify_count = [0]
|
||||
original_notify = worker.notify
|
||||
|
||||
def counting_notify():
|
||||
notify_count[0] += 1
|
||||
return original_notify() if callable(original_notify) else None
|
||||
|
||||
worker.notify = counting_notify
|
||||
|
||||
async def mock_execute(app_path, action, args, kwargs):
|
||||
return generate_tokens()
|
||||
|
||||
with mock.patch.object(worker, 'execute', side_effect=mock_execute):
|
||||
request = make_request("req-123", "test:App", "generate")
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
# Should have been notified at least once per chunk + initial
|
||||
assert notify_count[0] >= 2 # At least one per chunk
|
||||
|
||||
|
||||
class TestWorkerMessageTypeValidation:
|
||||
"""Tests for message type validation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_message_type_sends_error(self):
|
||||
"""Test that unknown message type sends error response."""
|
||||
worker = create_worker()
|
||||
writer = FakeStreamWriter()
|
||||
|
||||
# Send a message with unknown type
|
||||
message = {"type": "unknown", "id": "req-123"}
|
||||
await worker.handle_request(message, writer)
|
||||
|
||||
assert len(writer.messages) == 1
|
||||
assert writer.messages[0]["type"] == "error"
|
||||
assert "Unknown message type" in writer.messages[0]["error"]["message"]
|
||||
@ -7,6 +7,7 @@
|
||||
import asyncio
|
||||
import os
|
||||
import signal
|
||||
import struct
|
||||
import tempfile
|
||||
import pytest
|
||||
|
||||
@ -16,6 +17,41 @@ from gunicorn.dirty.errors import DirtyError
|
||||
from gunicorn.dirty.protocol import DirtyProtocol, make_request
|
||||
|
||||
|
||||
class MockStreamWriter:
|
||||
"""Mock StreamWriter that captures written messages."""
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
self._buffer = b""
|
||||
self.closed = False
|
||||
|
||||
def write(self, data):
|
||||
self._buffer += data
|
||||
|
||||
async def drain(self):
|
||||
while len(self._buffer) >= DirtyProtocol.HEADER_SIZE:
|
||||
length = struct.unpack(
|
||||
DirtyProtocol.HEADER_FORMAT,
|
||||
self._buffer[:DirtyProtocol.HEADER_SIZE]
|
||||
)[0]
|
||||
total_size = DirtyProtocol.HEADER_SIZE + length
|
||||
if len(self._buffer) >= total_size:
|
||||
msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size]
|
||||
self._buffer = self._buffer[total_size:]
|
||||
self.messages.append(DirtyProtocol.decode(msg_data))
|
||||
else:
|
||||
break
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
async def wait_closed(self):
|
||||
pass
|
||||
|
||||
def get_extra_info(self, name):
|
||||
return None
|
||||
|
||||
|
||||
class MockLog:
|
||||
"""Mock logger for testing."""
|
||||
|
||||
@ -141,8 +177,11 @@ class TestDirtyArbiterRouteRequest:
|
||||
action="test"
|
||||
)
|
||||
|
||||
response = await arbiter.route_request(request)
|
||||
writer = MockStreamWriter()
|
||||
await arbiter.route_request(request, writer)
|
||||
|
||||
assert len(writer.messages) == 1
|
||||
response = writer.messages[0]
|
||||
assert response["type"] == DirtyProtocol.MSG_TYPE_ERROR
|
||||
assert "No dirty workers available" in response["error"]["message"]
|
||||
|
||||
@ -430,8 +469,11 @@ class TestDirtyArbiterRouteTimeout:
|
||||
)
|
||||
|
||||
# This should fail because socket doesn't exist
|
||||
response = await arbiter.route_request(request)
|
||||
writer = MockStreamWriter()
|
||||
await arbiter.route_request(request, writer)
|
||||
|
||||
assert len(writer.messages) == 1
|
||||
response = writer.messages[0]
|
||||
assert response["type"] == DirtyProtocol.MSG_TYPE_ERROR
|
||||
# Either "Worker communication failed" or "Worker socket not ready"
|
||||
assert "error" in response
|
||||
@ -893,7 +935,8 @@ class TestDirtyArbiterQueueBehavior:
|
||||
)
|
||||
|
||||
# This will fail (no socket), but consumer should be started
|
||||
await arbiter.route_request(request)
|
||||
writer = MockStreamWriter()
|
||||
await arbiter.route_request(request, writer)
|
||||
|
||||
assert fake_pid in arbiter.worker_queues
|
||||
assert fake_pid in arbiter.worker_consumers
|
||||
|
||||
@ -15,6 +15,8 @@ from gunicorn.dirty.protocol import (
|
||||
make_request,
|
||||
make_response,
|
||||
make_error_response,
|
||||
make_chunk_message,
|
||||
make_end_message,
|
||||
)
|
||||
from gunicorn.dirty.errors import (
|
||||
DirtyError,
|
||||
@ -323,6 +325,51 @@ class TestMessageBuilders:
|
||||
assert response["error"]["error_type"] == "ValueError"
|
||||
assert response["error"]["message"] == "Invalid value"
|
||||
|
||||
def test_make_chunk_message(self):
|
||||
"""Test chunk message builder."""
|
||||
chunk = make_chunk_message("req-123", "Hello, ")
|
||||
assert chunk["type"] == DirtyProtocol.MSG_TYPE_CHUNK
|
||||
assert chunk["id"] == "req-123"
|
||||
assert chunk["data"] == "Hello, "
|
||||
|
||||
def test_make_chunk_message_with_complex_data(self):
|
||||
"""Test chunk message with complex data."""
|
||||
data = {"token": "world", "score": 0.95, "index": 5}
|
||||
chunk = make_chunk_message("req-456", data)
|
||||
assert chunk["type"] == DirtyProtocol.MSG_TYPE_CHUNK
|
||||
assert chunk["id"] == "req-456"
|
||||
assert chunk["data"] == data
|
||||
|
||||
def test_make_chunk_message_with_list_data(self):
|
||||
"""Test chunk message with list data."""
|
||||
data = [1, 2, 3, "token"]
|
||||
chunk = make_chunk_message("req-789", data)
|
||||
assert chunk["data"] == data
|
||||
|
||||
def test_make_end_message(self):
|
||||
"""Test end message builder."""
|
||||
end = make_end_message("req-123")
|
||||
assert end["type"] == DirtyProtocol.MSG_TYPE_END
|
||||
assert end["id"] == "req-123"
|
||||
assert "data" not in end
|
||||
|
||||
def test_chunk_and_end_encode_decode(self):
|
||||
"""Test that chunk and end messages can be encoded/decoded."""
|
||||
chunk = make_chunk_message("req-123", {"token": "hello"})
|
||||
end = make_end_message("req-123")
|
||||
|
||||
# Test chunk roundtrip
|
||||
encoded_chunk = DirtyProtocol.encode(chunk)
|
||||
payload = encoded_chunk[DirtyProtocol.HEADER_SIZE:]
|
||||
decoded = DirtyProtocol.decode(payload)
|
||||
assert decoded == chunk
|
||||
|
||||
# Test end roundtrip
|
||||
encoded_end = DirtyProtocol.encode(end)
|
||||
payload = encoded_end[DirtyProtocol.HEADER_SIZE:]
|
||||
decoded = DirtyProtocol.decode(payload)
|
||||
assert decoded == end
|
||||
|
||||
|
||||
class TestDirtyErrors:
|
||||
"""Tests for error classes."""
|
||||
|
||||
@ -16,6 +16,9 @@ from gunicorn.dirty.protocol import DirtyProtocol, make_request
|
||||
from gunicorn.dirty.errors import DirtyAppNotFoundError
|
||||
|
||||
|
||||
import struct
|
||||
|
||||
|
||||
class MockLog:
|
||||
"""Mock logger for testing."""
|
||||
|
||||
@ -41,6 +44,42 @@ class MockLog:
|
||||
pass
|
||||
|
||||
|
||||
class MockStreamWriter:
|
||||
"""Mock StreamWriter that captures written messages."""
|
||||
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
self._buffer = b""
|
||||
self.closed = False
|
||||
|
||||
def write(self, data):
|
||||
self._buffer += data
|
||||
|
||||
async def drain(self):
|
||||
# Decode the buffer to extract messages
|
||||
while len(self._buffer) >= DirtyProtocol.HEADER_SIZE:
|
||||
length = struct.unpack(
|
||||
DirtyProtocol.HEADER_FORMAT,
|
||||
self._buffer[:DirtyProtocol.HEADER_SIZE]
|
||||
)[0]
|
||||
total_size = DirtyProtocol.HEADER_SIZE + length
|
||||
if len(self._buffer) >= total_size:
|
||||
msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size]
|
||||
self._buffer = self._buffer[total_size:]
|
||||
self.messages.append(DirtyProtocol.decode(msg_data))
|
||||
else:
|
||||
break
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
async def wait_closed(self):
|
||||
pass
|
||||
|
||||
def get_extra_info(self, name):
|
||||
return None
|
||||
|
||||
|
||||
class TestDirtyWorkerInit:
|
||||
"""Tests for DirtyWorker initialization."""
|
||||
|
||||
@ -214,8 +253,11 @@ class TestDirtyWorkerHandleRequest:
|
||||
kwargs={"operation": "multiply"}
|
||||
)
|
||||
|
||||
response = await worker.handle_request(request)
|
||||
writer = MockStreamWriter()
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
assert len(writer.messages) == 1
|
||||
response = writer.messages[0]
|
||||
assert response["type"] == DirtyProtocol.MSG_TYPE_RESPONSE
|
||||
assert response["id"] == "test-123"
|
||||
assert response["result"] == 6
|
||||
@ -247,8 +289,11 @@ class TestDirtyWorkerHandleRequest:
|
||||
kwargs={"operation": "invalid"}
|
||||
)
|
||||
|
||||
response = await worker.handle_request(request)
|
||||
writer = MockStreamWriter()
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
assert len(writer.messages) == 1
|
||||
response = writer.messages[0]
|
||||
assert response["type"] == DirtyProtocol.MSG_TYPE_ERROR
|
||||
assert response["id"] == "test-456"
|
||||
assert "Unknown operation" in response["error"]["message"]
|
||||
@ -271,8 +316,11 @@ class TestDirtyWorkerHandleRequest:
|
||||
)
|
||||
|
||||
request = {"type": "unknown", "id": "test-789"}
|
||||
response = await worker.handle_request(request)
|
||||
writer = MockStreamWriter()
|
||||
await worker.handle_request(request, writer)
|
||||
|
||||
assert len(writer.messages) == 1
|
||||
response = writer.messages[0]
|
||||
assert response["type"] == DirtyProtocol.MSG_TYPE_ERROR
|
||||
assert "Unknown message type" in response["error"]["message"]
|
||||
|
||||
@ -662,43 +710,21 @@ class TestDirtyWorkerRunAsync:
|
||||
reader.feed_data(encoded_request)
|
||||
reader.feed_eof()
|
||||
|
||||
class MockWriter:
|
||||
def __init__(self):
|
||||
self.closed = False
|
||||
self.data = b""
|
||||
|
||||
def get_extra_info(self, name):
|
||||
return None
|
||||
|
||||
def write(self, data):
|
||||
self.data += data
|
||||
|
||||
async def drain(self):
|
||||
pass
|
||||
|
||||
def close(self):
|
||||
self.closed = True
|
||||
|
||||
async def wait_closed(self):
|
||||
pass
|
||||
|
||||
writer = MockWriter()
|
||||
writer = MockStreamWriter()
|
||||
|
||||
# Handle one message then exit
|
||||
worker.alive = True
|
||||
try:
|
||||
message = await DirtyProtocol.read_message_async(reader)
|
||||
response = await worker.handle_request(message)
|
||||
await DirtyProtocol.write_message_async(writer, response)
|
||||
await worker.handle_request(message, writer)
|
||||
except asyncio.IncompleteReadError:
|
||||
pass
|
||||
|
||||
# Decode response from writer
|
||||
if writer.data:
|
||||
payload = writer.data[DirtyProtocol.HEADER_SIZE:]
|
||||
response = DirtyProtocol.decode(payload)
|
||||
assert response["type"] == DirtyProtocol.MSG_TYPE_RESPONSE
|
||||
assert response["result"] == 8
|
||||
# Check response from writer
|
||||
assert len(writer.messages) == 1
|
||||
response = writer.messages[0]
|
||||
assert response["type"] == DirtyProtocol.MSG_TYPE_RESPONSE
|
||||
assert response["result"] == 8
|
||||
|
||||
worker._cleanup()
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user