gunicorn/tests/dirty/test_client_streaming_async.py
Benoit Chesneau 477b7479cc feat(dirty): update client for binary protocol
Update client and streaming tests to work with the binary protocol:
- Update MockStreamWriter/MockStreamReader to use BinaryProtocol
- Replace string request IDs with integers
- Update test assertions to decode binary protocol messages
- Use HEADER_SIZE and decode_header/decode_message instead of old API
2026-02-11 23:12:44 +01:00

264 lines
8.3 KiB
Python

#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""Tests for dirty client async streaming functionality."""
import asyncio
import struct
import pytest
from gunicorn.dirty.protocol import (
DirtyProtocol,
BinaryProtocol,
make_chunk_message,
make_end_message,
make_error_response,
HEADER_SIZE,
)
from gunicorn.dirty.client import DirtyClient, DirtyAsyncStreamIterator
from gunicorn.dirty.errors import DirtyError, DirtyTimeoutError
class MockAsyncReader:
    """Mock async reader that replays predefined messages as one byte stream.

    Every entry of ``messages`` is encoded with
    ``BinaryProtocol._encode_from_dict`` and the encoded frames are
    concatenated in order. ``readexactly`` then serves consecutive slices
    of that buffer.
    """

    def __init__(self, messages):
        # Join once instead of growing a bytes object with ``+=`` for every
        # message, which re-copies the accumulated buffer each time.
        self._data = b"".join(
            BinaryProtocol._encode_from_dict(msg) for msg in messages
        )
        self._pos = 0

    async def readexactly(self, n):
        """Return the next ``n`` bytes of the buffer and advance past them.

        Raises:
            asyncio.IncompleteReadError: if fewer than ``n`` bytes remain.
                The unread remainder is passed as the partial data and the
                read position is left unchanged.
        """
        end = self._pos + n
        if end > len(self._data):
            raise asyncio.IncompleteReadError(self._data[self._pos:], n)
        result = self._data[self._pos:end]
        self._pos = end
        return result
class MockAsyncWriter:
    """In-memory stand-in for an async stream writer.

    Each payload handed to ``write`` is recorded in ``_sent`` in call order;
    ``closed`` becomes ``True`` once ``close`` has been called.
    """

    def __init__(self) -> None:
        self.closed = False
        self._sent = []

    def write(self, data) -> None:
        """Record ``data`` as sent."""
        self._sent += [data]

    async def drain(self) -> None:
        """Do nothing; writes are recorded immediately."""
        return None

    def close(self) -> None:
        """Mark the writer as closed."""
        self.closed = True

    async def wait_closed(self) -> None:
        """Return immediately; there is nothing to wait for."""
        return None
def create_async_client_with_mocks(messages):
    """Return a ``DirtyClient`` wired to in-memory mock streams.

    The client's ``_reader`` is a ``MockAsyncReader`` built from
    ``messages`` and its ``_writer`` is a fresh ``MockAsyncWriter``.
    """
    mock_client = DirtyClient("/tmp/test.sock")
    mock_client._reader, mock_client._writer = (
        MockAsyncReader(messages),
        MockAsyncWriter(),
    )
    return mock_client
class TestDirtyAsyncStreamIterator:
    """Tests for DirtyAsyncStreamIterator."""

    def test_stream_async_returns_async_iterator(self):
        """stream_async() hands back a DirtyAsyncStreamIterator."""
        iterator = DirtyClient("/tmp/test.sock").stream_async(
            "test:App", "generate"
        )
        assert isinstance(iterator, DirtyAsyncStreamIterator)

    @pytest.mark.asyncio
    async def test_async_stream_yields_chunks(self):
        """Chunk payloads are yielded in order until the end message."""
        frames = [
            make_chunk_message(123, "Hello"),
            make_chunk_message(123, " "),
            make_chunk_message(123, "World"),
            make_end_message(123),
        ]
        client = create_async_client_with_mocks(frames)

        received = [
            chunk async for chunk in client.stream_async("test:App", "generate")
        ]

        assert received == ["Hello", " ", "World"]

    @pytest.mark.asyncio
    async def test_async_stream_yields_complex_chunks(self):
        """Dict payloads are yielded as chunks."""
        frames = [
            make_chunk_message(123, {"token": "Hello", "score": 0.9}),
            make_chunk_message(123, {"token": "World", "score": 0.8}),
            make_end_message(123),
        ]
        client = create_async_client_with_mocks(frames)

        received = [
            chunk async for chunk in client.stream_async("test:App", "generate")
        ]

        assert len(received) == 2
        assert received[0]["token"] == "Hello"
        assert received[1]["token"] == "World"

    @pytest.mark.asyncio
    async def test_async_stream_handles_error(self):
        """An error message in the stream is raised as DirtyError."""
        frames = [
            make_chunk_message(123, "First"),
            make_error_response(123, DirtyError("Something broke")),
        ]
        client = create_async_client_with_mocks(frames)
        stream = client.stream_async("test:App", "generate")

        # The chunk that precedes the error is delivered normally.
        first = await stream.__anext__()
        assert first == "First"

        # The error message is raised on the following step.
        with pytest.raises(DirtyError) as exc_info:
            await stream.__anext__()
        assert "Something broke" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_async_stream_empty_stream(self):
        """A stream consisting only of the end message yields nothing."""
        client = create_async_client_with_mocks([make_end_message(123)])

        received = [
            chunk async for chunk in client.stream_async("test:App", "generate")
        ]

        assert received == []

    @pytest.mark.asyncio
    async def test_async_stream_stops_after_exhausted(self):
        """Once exhausted, the iterator keeps raising StopAsyncIteration."""
        frames = [
            make_chunk_message(123, "Only"),
            make_end_message(123),
        ]
        client = create_async_client_with_mocks(frames)
        stream = client.stream_async("test:App", "generate")

        only = await stream.__anext__()
        assert only == "Only"

        # First call hits the end message; the second checks it stays stopped.
        with pytest.raises(StopAsyncIteration):
            await stream.__anext__()
        with pytest.raises(StopAsyncIteration):
            await stream.__anext__()

    @pytest.mark.asyncio
    async def test_async_stream_sends_request_on_first_iteration(self):
        """The request is written only when iteration starts."""
        frames = [
            make_chunk_message(123, "data"),
            make_end_message(123),
        ]
        client = create_async_client_with_mocks(frames)
        stream = client.stream_async("test:App", "generate", "prompt_arg")

        # Creating the iterator alone must not write anything.
        assert len(client._writer._sent) == 0

        # The first step sends exactly one write.
        await stream.__anext__()
        assert len(client._writer._sent) == 1

        # Decode the captured bytes back into a protocol message.
        raw = client._writer._sent[0]
        _, _, length = BinaryProtocol.decode_header(raw[:HEADER_SIZE])
        msg_type, _request_id, payload = BinaryProtocol.decode_message(
            raw[:HEADER_SIZE + length]
        )

        assert msg_type == "request"
        assert payload["app_path"] == "test:App"
        assert payload["action"] == "generate"
        assert payload["args"] == ["prompt_arg"]
class TestDirtyAsyncStreamIteratorEdgeCases:
    """Edge cases for async streaming."""

    @pytest.mark.asyncio
    async def test_async_stream_many_chunks(self):
        """Test async streaming with many chunks."""
        expected = [f"chunk-{i}" for i in range(100)]
        messages = [make_chunk_message(123, text) for text in expected]
        messages.append(make_end_message(123))
        client = create_async_client_with_mocks(messages)

        chunks = []
        async for chunk in client.stream_async("test:App", "generate"):
            chunks.append(chunk)

        assert len(chunks) == 100
        # Compare the whole sequence so a reordered or duplicated chunk in
        # the middle is caught, not only the first and last element.
        assert chunks == expected

    @pytest.mark.asyncio
    async def test_async_stream_with_kwargs(self):
        """Test async streaming with keyword arguments."""
        messages = [
            make_chunk_message(123, "data"),
            make_end_message(123),
        ]
        client = create_async_client_with_mocks(messages)

        chunks = []
        async for chunk in client.stream_async(
            "test:App", "generate", "arg1", key="value"
        ):
            chunks.append(chunk)

        # The stream itself must still deliver its payload; previously the
        # collected chunks were never checked.
        assert chunks == ["data"]

        # Decode the captured request and check that it carries both the
        # positional and the keyword arguments.
        sent_data = client._writer._sent[0]
        _, _, length = BinaryProtocol.decode_header(sent_data[:HEADER_SIZE])
        msg_type_str, _request_id, payload = BinaryProtocol.decode_message(
            sent_data[:HEADER_SIZE + length]
        )

        assert msg_type_str == "request"
        assert payload["args"] == ["arg1"]
        assert payload["kwargs"] == {"key": "value"}
class TestDirtyAsyncStreamTimeout:
    """Tests for async streaming timeout handling."""

    @pytest.mark.asyncio
    async def test_async_stream_timeout(self):
        """A read that outlasts the client timeout raises DirtyTimeoutError."""

        class SlowReader:
            """Reader whose reads take longer than the client timeout."""

            async def readexactly(self, n):
                # 1 s sleep against the 0.01 s timeout configured below.
                await asyncio.sleep(1)

        client = DirtyClient("/tmp/test.sock", timeout=0.01)
        client._reader, client._writer = SlowReader(), MockAsyncWriter()

        stream = client.stream_async("test:App", "generate")
        with pytest.raises(DirtyTimeoutError):
            await stream.__anext__()