gunicorn/tests/dirty/test_streaming_integration.py
Benoit Chesneau f6418d4eb0 feat(dirty): add streaming support and async client benchmarks
Add support for streaming responses when dirty app actions return
generators (sync or async). This enables real-time delivery of
incremental results for use cases like LLM token generation.

Features:
- Streaming protocol with chunk/end/error message types
- Worker support for sync and async generators
- Arbiter forwarding of streaming messages
- Deadline-based timeout handling
- Async client streaming API

Protocol:
- Chunk messages (type: "chunk") contain partial data
- End messages (type: "end") signal stream completion
- Error messages can occur mid-stream

New files:
- benchmarks/dirty_streaming.py: Streaming benchmark suite
- tests/dirty/test_*_streaming*.py: Streaming test coverage
- docs/content/dirty.md: Streaming documentation with examples
2026-01-25 10:23:25 +01:00

470 lines
15 KiB
Python

#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""Integration tests for dirty streaming functionality.
These tests verify the full streaming pipeline:
client -> arbiter -> worker -> generator -> chunks -> client
"""
import asyncio
import os
import struct
import tempfile
import pytest
from unittest import mock
from gunicorn.config import Config
from gunicorn.dirty.protocol import (
DirtyProtocol,
make_request,
make_chunk_message,
make_end_message,
make_response,
make_error_response,
)
from gunicorn.dirty.worker import DirtyWorker
from gunicorn.dirty.arbiter import DirtyArbiter
from gunicorn.dirty.client import DirtyClient
from gunicorn.dirty.errors import DirtyError
class MockLog:
    """Mock logger for testing.

    Every call is recorded in ``messages`` as a ``(level, text)`` tuple,
    where ``level`` is the name of the method that was called.

    The logging methods accept ``**kwargs`` (for example ``exc_info=True``),
    which gunicorn's ``Logger`` also accepts. Code under test that passes
    such keyword arguments, or calls ``exception``/``critical``, therefore
    does not crash on the mock. Keyword arguments are accepted and ignored.
    """

    def __init__(self):
        self.messages = []

    def _record(self, level, msg, args):
        # %-format the message like the logging module does. With no args,
        # the message is stored verbatim, so a literal '%' is left untouched.
        self.messages.append((level, msg % args if args else msg))

    def debug(self, msg, *args, **kwargs):
        self._record("debug", msg, args)

    def info(self, msg, *args, **kwargs):
        self._record("info", msg, args)

    def warning(self, msg, *args, **kwargs):
        self._record("warning", msg, args)

    def error(self, msg, *args, **kwargs):
        self._record("error", msg, args)

    def exception(self, msg, *args, **kwargs):
        self._record("exception", msg, args)

    def critical(self, msg, *args, **kwargs):
        self._record("critical", msg, args)

    def close_on_exec(self):
        pass

    def reopen_files(self):
        pass
class MockStreamWriter:
    """Mock ``asyncio.StreamWriter`` that captures written messages.

    Bytes passed to :meth:`write` are buffered. Each :meth:`drain` call
    decodes every complete length-prefixed frame currently in the buffer
    with ``DirtyProtocol.decode`` and appends the result to ``messages``.
    A trailing partial frame stays buffered until more data arrives.
    """

    def __init__(self):
        self.messages = []
        # A bytearray lets us append and trim from the front without
        # copying the whole buffer each time. A plain bytes ``+=`` / slice
        # is quadratic over long streams of many frames.
        self._buffer = bytearray()
        self.closed = False

    def write(self, data):
        self._buffer.extend(data)

    async def drain(self):
        header_size = DirtyProtocol.HEADER_SIZE
        while len(self._buffer) >= header_size:
            # The first field of the header is taken as the payload length.
            length = struct.unpack(
                DirtyProtocol.HEADER_FORMAT,
                self._buffer[:header_size]
            )[0]
            total_size = header_size + length
            if len(self._buffer) < total_size:
                break  # Incomplete frame; wait for more writes.
            payload = bytes(self._buffer[header_size:total_size])
            del self._buffer[:total_size]
            self.messages.append(DirtyProtocol.decode(payload))

    def close(self):
        self.closed = True

    async def wait_closed(self):
        pass

    def get_extra_info(self, name):
        return None
class MockStreamReader:
    """Mock ``asyncio.StreamReader`` that replays predefined messages.

    Every message is encoded with ``DirtyProtocol.encode`` up front and the
    frames are concatenated. :meth:`readexactly` then serves that byte
    string sequentially.
    """

    def __init__(self, messages):
        # One join instead of repeated ``bytes +=`` avoids quadratic
        # copying when replaying long streams.
        self._data = b"".join(DirtyProtocol.encode(msg) for msg in messages)
        self._pos = 0

    async def readexactly(self, n):
        """Return the next *n* bytes.

        Raises ``asyncio.IncompleteReadError`` (carrying the remaining bytes
        as ``partial``) if fewer than *n* bytes are left. The read position
        is not advanced in that case.
        """
        end = self._pos + n
        if end > len(self._data):
            raise asyncio.IncompleteReadError(self._data[self._pos:], n)
        chunk = self._data[self._pos:end]
        self._pos = end
        return chunk
class TestStreamingEndToEnd:
    """End-to-end streaming tests using mocked components.

    The worker side is simulated by replaying pre-built protocol frames.
    These tests check what ``DirtyArbiter._execute_on_worker`` forwards to
    the client writer.
    """

    @staticmethod
    async def _forward(request, worker_messages):
        """Send *request* through a fresh arbiter; return the client's messages.

        Worker 1234's connection replays *worker_messages*.
        ``arbiter._cleanup_sync()`` runs in a ``finally``, so it still runs
        when ``_execute_on_worker`` raises.
        """
        cfg = Config()
        cfg.set("dirty_timeout", 30)
        arbiter = DirtyArbiter(cfg=cfg, log=MockLog())
        try:
            arbiter.alive = True
            arbiter.workers = {1234: mock.Mock()}
            arbiter.worker_sockets = {1234: '/tmp/worker.sock'}
            mock_reader = MockStreamReader(worker_messages)

            async def mock_get_connection(pid):
                return mock_reader, MockStreamWriter()

            arbiter._get_worker_connection = mock_get_connection
            client_writer = MockStreamWriter()
            await arbiter._execute_on_worker(1234, request, client_writer)
        finally:
            arbiter._cleanup_sync()
        return client_writer.messages

    @pytest.mark.asyncio
    async def test_sync_generator_end_to_end(self):
        """Test complete flow: sync generator -> worker -> arbiter -> client."""
        # Simulate what a worker would produce for a sync generator
        worker_messages = [
            make_chunk_message("req-123", "Hello"),
            make_chunk_message("req-123", " "),
            make_chunk_message("req-123", "World"),
            make_end_message("req-123"),
        ]
        request = make_request("req-123", "test:App", "generate")
        messages = await self._forward(request, worker_messages)

        # Verify all messages were forwarded, in order
        assert len(messages) == 4
        assert [m["type"] for m in messages[:3]] == ["chunk"] * 3
        assert [m["data"] for m in messages[:3]] == ["Hello", " ", "World"]
        assert messages[3]["type"] == "end"

    @pytest.mark.asyncio
    async def test_async_generator_end_to_end(self):
        """Test complete flow: async generator -> worker -> arbiter -> client."""
        worker_messages = [
            make_chunk_message("req-456", "Async"),
            make_chunk_message("req-456", " "),
            make_chunk_message("req-456", "Stream"),
            make_end_message("req-456"),
        ]
        request = make_request("req-456", "test:App", "async_generate")
        messages = await self._forward(request, worker_messages)

        assert len(messages) == 4
        assert [m["type"] for m in messages[:3]] == ["chunk"] * 3
        assert [m["data"] for m in messages[:3]] == ["Async", " ", "Stream"]
        assert messages[3]["type"] == "end"
class TestStreamingErrorHandling:
    """Tests for error handling during streaming."""

    @staticmethod
    async def _forward(request, worker_messages):
        """Send *request* through a fresh arbiter; return the client's messages.

        Worker 1234's connection replays *worker_messages*.
        ``arbiter._cleanup_sync()`` runs in a ``finally``, so it still runs
        when ``_execute_on_worker`` raises.
        """
        cfg = Config()
        cfg.set("dirty_timeout", 30)
        arbiter = DirtyArbiter(cfg=cfg, log=MockLog())
        try:
            arbiter.alive = True
            arbiter.workers = {1234: mock.Mock()}
            arbiter.worker_sockets = {1234: '/tmp/worker.sock'}
            mock_reader = MockStreamReader(worker_messages)

            async def mock_get_connection(pid):
                return mock_reader, MockStreamWriter()

            arbiter._get_worker_connection = mock_get_connection
            client_writer = MockStreamWriter()
            await arbiter._execute_on_worker(1234, request, client_writer)
        finally:
            arbiter._cleanup_sync()
        return client_writer.messages

    @pytest.mark.asyncio
    async def test_error_mid_stream(self):
        """Test that errors during streaming are properly forwarded."""
        worker_messages = [
            make_chunk_message("req-789", "First"),
            make_chunk_message("req-789", "Second"),
            make_error_response("req-789", DirtyError("Stream failed")),
        ]
        request = make_request("req-789", "test:App", "generate_with_error")
        messages = await self._forward(request, worker_messages)

        # Should have 2 chunks + 1 error, with chunks delivered before the error
        assert len(messages) == 3
        assert messages[0]["type"] == "chunk"
        assert messages[0]["data"] == "First"
        assert messages[1]["type"] == "chunk"
        assert messages[1]["data"] == "Second"
        assert messages[2]["type"] == "error"
        assert "Stream failed" in messages[2]["error"]["message"]
class TestStreamingBackwardCompatibility:
    """Tests for backward compatibility with non-streaming responses."""

    @staticmethod
    async def _forward(request, worker_messages):
        """Send *request* through a fresh arbiter; return the client's messages.

        Worker 1234's connection replays *worker_messages*.
        ``arbiter._cleanup_sync()`` runs in a ``finally``, so it still runs
        when ``_execute_on_worker`` raises.
        """
        cfg = Config()
        cfg.set("dirty_timeout", 30)
        arbiter = DirtyArbiter(cfg=cfg, log=MockLog())
        try:
            arbiter.alive = True
            arbiter.workers = {1234: mock.Mock()}
            arbiter.worker_sockets = {1234: '/tmp/worker.sock'}
            mock_reader = MockStreamReader(worker_messages)

            async def mock_get_connection(pid):
                return mock_reader, MockStreamWriter()

            arbiter._get_worker_connection = mock_get_connection
            client_writer = MockStreamWriter()
            await arbiter._execute_on_worker(1234, request, client_writer)
        finally:
            arbiter._cleanup_sync()
        return client_writer.messages

    @pytest.mark.asyncio
    async def test_non_streaming_response_still_works(self):
        """Test that regular (non-streaming) responses still work."""
        worker_messages = [
            make_response("req-abc", {"result": 42, "data": [1, 2, 3]}),
        ]
        request = make_request("req-abc", "test:App", "compute")
        messages = await self._forward(request, worker_messages)

        # Should have exactly 1 response, forwarded unchanged
        assert len(messages) == 1
        assert messages[0]["type"] == "response"
        assert messages[0]["result"]["result"] == 42
        assert messages[0]["result"]["data"] == [1, 2, 3]

    @pytest.mark.asyncio
    async def test_error_response_still_works(self):
        """Test that error responses still work."""
        worker_messages = [
            make_error_response("req-def", DirtyError("Something failed")),
        ]
        request = make_request("req-def", "test:App", "fail")
        messages = await self._forward(request, worker_messages)

        assert len(messages) == 1
        assert messages[0]["type"] == "error"
        assert "Something failed" in messages[0]["error"]["message"]
class TestStreamingWorkerIntegration:
    """Integration tests for worker streaming with execute."""

    @staticmethod
    def _make_worker():
        """Build a DirtyWorker with WorkerTmp patched out and no loaded apps."""
        cfg = Config()
        cfg.set("dirty_timeout", 300)
        with mock.patch('gunicorn.dirty.worker.WorkerTmp'):
            worker = DirtyWorker(
                age=1,
                ppid=os.getpid(),
                app_paths=["test:App"],
                cfg=cfg,
                log=MockLog(),
                socket_path="/tmp/test.sock"
            )
        worker.apps = {}
        worker._executor = None
        worker.tmp = mock.Mock()
        return worker

    @pytest.mark.asyncio
    async def test_worker_handles_sync_generator(self):
        """Test worker properly handles sync generator from execute."""
        worker = self._make_worker()
        writer = MockStreamWriter()

        # Mock execute to return a sync generator
        def sync_gen():
            yield "one"
            yield "two"
            yield "three"

        async def mock_execute(app_path, action, args, kwargs):
            return sync_gen()

        with mock.patch.object(worker, 'execute', side_effect=mock_execute):
            request = make_request("req-123", "test:App", "generate")
            await worker.handle_request(request, writer)

        # Should have 3 chunk frames (in yield order) + 1 end frame
        assert len(writer.messages) == 4
        assert [m["type"] for m in writer.messages[:3]] == ["chunk"] * 3
        assert [m["data"] for m in writer.messages[:3]] == ["one", "two", "three"]
        assert writer.messages[3]["type"] == "end"

    @pytest.mark.asyncio
    async def test_worker_handles_async_generator(self):
        """Test worker properly handles async generator from execute."""
        worker = self._make_worker()
        writer = MockStreamWriter()

        # Mock execute to return an async generator
        async def async_gen():
            yield "async_one"
            yield "async_two"

        async def mock_execute(app_path, action, args, kwargs):
            return async_gen()

        with mock.patch.object(worker, 'execute', side_effect=mock_execute):
            request = make_request("req-456", "test:App", "async_generate")
            await worker.handle_request(request, writer)

        # Should have 2 chunk frames (in yield order) + 1 end frame
        assert len(writer.messages) == 3
        assert [m["type"] for m in writer.messages[:2]] == ["chunk"] * 2
        assert [m["data"] for m in writer.messages[:2]] == ["async_one", "async_two"]
        assert writer.messages[2]["type"] == "end"
class TestStreamingMixedScenarios:
    """Tests for mixed streaming scenarios."""

    @staticmethod
    async def _forward(request, worker_messages):
        """Send *request* through a fresh arbiter; return the client's messages.

        Worker 1234's connection replays *worker_messages*.
        ``arbiter._cleanup_sync()`` runs in a ``finally``, so it still runs
        when ``_execute_on_worker`` raises.
        """
        cfg = Config()
        cfg.set("dirty_timeout", 30)
        arbiter = DirtyArbiter(cfg=cfg, log=MockLog())
        try:
            arbiter.alive = True
            arbiter.workers = {1234: mock.Mock()}
            arbiter.worker_sockets = {1234: '/tmp/worker.sock'}
            mock_reader = MockStreamReader(worker_messages)

            async def mock_get_connection(pid):
                return mock_reader, MockStreamWriter()

            arbiter._get_worker_connection = mock_get_connection
            client_writer = MockStreamWriter()
            await arbiter._execute_on_worker(1234, request, client_writer)
        finally:
            arbiter._cleanup_sync()
        return client_writer.messages

    @pytest.mark.asyncio
    async def test_large_stream(self):
        """Test streaming with many chunks."""
        chunk_count = 500
        expected = [f"chunk-{i}" for i in range(chunk_count)]
        worker_messages = [
            make_chunk_message("req-large", data) for data in expected
        ]
        worker_messages.append(make_end_message("req-large"))

        request = make_request("req-large", "test:App", "large_stream")
        messages = await self._forward(request, worker_messages)

        # All chunks forwarded in order, followed by exactly one end frame
        assert len(messages) == chunk_count + 1
        assert [m["data"] for m in messages[:chunk_count]] == expected
        assert messages[chunk_count]["type"] == "end"

    @pytest.mark.asyncio
    async def test_stream_with_complex_data(self):
        """Test streaming with complex JSON-serializable data."""
        worker_messages = [
            make_chunk_message("req-complex", {
                "token": "Hello",
                "scores": [0.1, 0.2, 0.3],
                "metadata": {"position": 0}
            }),
            make_chunk_message("req-complex", {
                "token": "World",
                "scores": [0.4, 0.5],
                "metadata": {"position": 1}
            }),
            make_end_message("req-complex"),
        ]
        request = make_request("req-complex", "test:App", "complex_stream")
        messages = await self._forward(request, worker_messages)

        assert len(messages) == 3
        assert messages[0]["data"]["token"] == "Hello"
        assert messages[0]["data"]["scores"] == [0.1, 0.2, 0.3]
        assert messages[1]["data"]["token"] == "World"
        assert messages[1]["data"]["metadata"]["position"] == 1
        assert messages[2]["type"] == "end"