gunicorn/tests/dirty/test_worker_streaming.py
Benoit Chesneau 477b7479cc feat(dirty): update client for binary protocol
Update client and streaming tests to work with the binary protocol:
- Update MockStreamWriter/MockStreamReader to use BinaryProtocol
- Replace string request IDs with integers
- Update test assertions to decode binary protocol messages
- Use HEADER_SIZE and decode_header/decode_message instead of old API
2026-02-11 23:12:44 +01:00

423 lines
14 KiB
Python

#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""Tests for dirty worker streaming functionality."""
import asyncio
import struct
from unittest import mock
import pytest
from gunicorn.dirty.protocol import (
DirtyProtocol,
BinaryProtocol,
make_request,
make_chunk_message,
make_end_message,
HEADER_SIZE,
)
from gunicorn.dirty.worker import DirtyWorker
class FakeStreamWriter:
    """Stand-in for ``asyncio.StreamWriter`` that records decoded frames.

    Bytes passed to :meth:`write` are appended to an internal buffer and split
    into complete binary-protocol frames (``HEADER_SIZE`` header bytes followed
    by the payload length the header announces). Each complete frame is decoded
    with ``BinaryProtocol.decode_message`` and stored in :attr:`messages` as a
    dict holding the payload keys plus ``"type"`` and ``"id"``, so assertions
    can index it like the old dict-based protocol. A partial trailing frame
    stays buffered until the rest of its bytes arrive.
    """

    def __init__(self):
        # Decoded messages, in the order their frames were completed.
        self.messages = []
        # Bytes received that do not (yet) form a complete frame.
        self._buffer = bytearray()

    def write(self, data):
        """Buffer *data* and decode every frame it completes.

        Decoding here, not only in :meth:`drain`, mirrors a real
        StreamWriter: written bytes are sent whether or not the caller
        awaits ``drain()`` afterwards, so they must be captured either way.
        """
        self._buffer += data
        self._decode_available()

    async def drain(self):
        """Decode any complete frames still buffered (flow control is a no-op)."""
        self._decode_available()

    def _decode_available(self):
        """Move every complete frame from the buffer into :attr:`messages`."""
        while len(self._buffer) >= HEADER_SIZE:
            # The header alone tells us how many payload bytes follow it.
            _, _, length = BinaryProtocol.decode_header(
                bytes(self._buffer[:HEADER_SIZE])
            )
            total_size = HEADER_SIZE + length
            if len(self._buffer) < total_size:
                break  # wait for the rest of this frame
            frame = bytes(self._buffer[:total_size])
            del self._buffer[:total_size]
            # decode_message returns (msg_type_str, request_id, payload_dict)
            msg_type, request_id, payload = BinaryProtocol.decode_message(frame)
            # Header fields take precedence over same-named payload keys so the
            # recorded type/id always reflect what was framed on the wire.
            self.messages.append({**payload, "type": msg_type, "id": request_id})

    def close(self):
        """No-op: there is no underlying transport to close."""

    async def wait_closed(self):
        """No-op counterpart to ``StreamWriter.wait_closed``."""
def create_worker():
    """Return a ``DirtyWorker`` wired up with mocks for unit testing.

    The configuration and logger are ``mock.Mock`` objects, and ``WorkerTmp``
    is patched out while the worker is constructed. After construction the
    worker has an empty ``apps`` mapping, ``_executor`` set to ``None`` and a
    mocked ``tmp``.
    """
    cfg = mock.Mock(
        dirty_timeout=30,
        dirty_threads=1,
        env=None,
        uid=None,
        gid=None,
        initgroups=False,
        dirty_worker_init=mock.Mock(),
        umask=0o22,
    )
    with mock.patch('gunicorn.dirty.worker.WorkerTmp'):
        worker = DirtyWorker(
            age=1,
            ppid=1,
            app_paths=["test:App"],
            cfg=cfg,
            log=mock.Mock(),
            socket_path="/tmp/test.sock",
        )

    # Post-construction state the tests depend on.
    worker.apps = {}
    worker._executor = None  # None -> default executor (used by sync generator tests)
    worker.tmp = mock.Mock()
    return worker
class TestWorkerSyncGeneratorStreaming:
    """Streaming tests where the action returns a plain (sync) generator."""

    @staticmethod
    async def _stream(make_gen):
        """Send a ``generate`` request (id 123) whose execute() returns ``make_gen()``.

        Returns the messages the worker wrote back, as decoded by
        FakeStreamWriter.
        """
        worker = create_worker()
        writer = FakeStreamWriter()

        async def fake_execute(app_path, action, args, kwargs):
            return make_gen()

        with mock.patch.object(worker, 'execute', side_effect=fake_execute):
            request = make_request(123, "test:App", "generate")
            await worker.handle_request(request, writer)
        return writer.messages

    @pytest.mark.asyncio
    async def test_sync_generator_sends_chunks_and_end(self):
        """Each yielded item becomes a chunk message, followed by one end message."""
        def generate_tokens():
            yield "Hello"
            yield " "
            yield "World"

        messages = await self._stream(generate_tokens)

        # Exactly three chunks, then the end marker.
        assert [m["type"] for m in messages] == ["chunk", "chunk", "chunk", "end"]
        assert [m["data"] for m in messages[:3]] == ["Hello", " ", "World"]
        assert messages[0]["id"] == 123
        assert messages[3]["id"] == 123

    @pytest.mark.asyncio
    async def test_sync_generator_error_mid_stream(self):
        """An exception raised mid-stream is reported as an error message."""
        def generate_with_error():
            yield "First"
            raise ValueError("Something went wrong")

        messages = await self._stream(generate_with_error)

        # The chunk sent before the failure, then the error.
        assert [m["type"] for m in messages] == ["chunk", "error"]
        assert messages[0]["data"] == "First"
        assert "Something went wrong" in messages[1]["error"]["message"]
class TestWorkerAsyncGeneratorStreaming:
    """Streaming tests where the action returns an async generator."""

    @staticmethod
    async def _stream(make_agen):
        """Send a ``generate`` request (id 123) whose execute() returns ``make_agen()``.

        Returns the messages the worker wrote back, as decoded by
        FakeStreamWriter.
        """
        worker = create_worker()
        writer = FakeStreamWriter()

        async def fake_execute(app_path, action, args, kwargs):
            return make_agen()

        with mock.patch.object(worker, 'execute', side_effect=fake_execute):
            request = make_request(123, "test:App", "generate")
            await worker.handle_request(request, writer)
        return writer.messages

    @pytest.mark.asyncio
    async def test_async_generator_sends_chunks_and_end(self):
        """Each item yielded asynchronously becomes a chunk, then one end message."""
        async def async_generate_tokens():
            yield "Hello"
            yield " "
            yield "World"

        messages = await self._stream(async_generate_tokens)

        # Exactly three chunks, then the end marker.
        assert [m["type"] for m in messages] == ["chunk", "chunk", "chunk", "end"]
        assert [m["data"] for m in messages[:3]] == ["Hello", " ", "World"]
        assert messages[0]["id"] == 123
        assert messages[3]["id"] == 123

    @pytest.mark.asyncio
    async def test_async_generator_error_mid_stream(self):
        """An exception raised mid async-stream is reported as an error message."""
        async def async_generate_with_error():
            yield "First"
            raise ValueError("Async error")

        messages = await self._stream(async_generate_with_error)

        # The chunk sent before the failure, then the error.
        assert [m["type"] for m in messages] == ["chunk", "error"]
        assert messages[0]["data"] == "First"
        assert "Async error" in messages[1]["error"]["message"]
class TestWorkerNonStreamingBackwardCompat:
    """Non-generator results must still produce a single, regular reply."""

    @staticmethod
    async def _call(fake_execute, action, **request_kwargs):
        """Send request 123 for *action* with ``execute`` replaced by *fake_execute*.

        Extra keyword arguments are forwarded to ``make_request``. Returns the
        messages the worker wrote back, as decoded by FakeStreamWriter.
        """
        worker = create_worker()
        writer = FakeStreamWriter()
        with mock.patch.object(worker, 'execute', side_effect=fake_execute):
            request = make_request(123, "test:App", action, **request_kwargs)
            await worker.handle_request(request, writer)
        return writer.messages

    @pytest.mark.asyncio
    async def test_non_generator_returns_response(self):
        """A plain return value is sent back as one ``response`` message."""
        async def fake_execute(app_path, action, args, kwargs):
            return args[0] + args[1]

        messages = await self._call(fake_execute, "compute", args=(2, 3))

        assert len(messages) == 1
        reply = messages[0]
        assert reply["type"] == "response"
        assert reply["id"] == 123
        assert reply["result"] == 5

    @pytest.mark.asyncio
    async def test_list_result_not_treated_as_streaming(self):
        """A list is iterable but must be returned whole, not as chunks."""
        async def fake_execute(app_path, action, args, kwargs):
            return [1, 2, 3, 4, 5]

        messages = await self._call(fake_execute, "get_list")

        # One response carrying the full list, not five chunks.
        assert len(messages) == 1
        reply = messages[0]
        assert reply["type"] == "response"
        assert reply["result"] == [1, 2, 3, 4, 5]

    @pytest.mark.asyncio
    async def test_error_in_execute_sends_error(self):
        """An exception raised by execute() becomes a single error message."""
        async def fake_execute(app_path, action, args, kwargs):
            raise RuntimeError("Failed!")

        messages = await self._call(fake_execute, "fail")

        assert len(messages) == 1
        reply = messages[0]
        assert reply["type"] == "error"
        assert "Failed!" in reply["error"]["message"]

    @pytest.mark.asyncio
    async def test_none_result(self):
        """A ``None`` result is sent back as a response whose result is None."""
        async def fake_execute(app_path, action, args, kwargs):
            return None

        messages = await self._call(fake_execute, "void")

        assert len(messages) == 1
        reply = messages[0]
        assert reply["type"] == "response"
        assert reply["result"] is None
class TestWorkerStreamingComplexData:
    """Streaming tests with structured, empty, and long chunk sequences."""

    @staticmethod
    async def _stream(make_agen):
        """Send a ``generate`` request (id 123) whose execute() returns ``make_agen()``.

        Returns the messages the worker wrote back, as decoded by
        FakeStreamWriter.
        """
        worker = create_worker()
        writer = FakeStreamWriter()

        async def fake_execute(app_path, action, args, kwargs):
            return make_agen()

        with mock.patch.object(worker, 'execute', side_effect=fake_execute):
            request = make_request(123, "test:App", "generate")
            await worker.handle_request(request, writer)
        return writer.messages

    @pytest.mark.asyncio
    async def test_streaming_dict_chunks(self):
        """Dict items survive the round trip as chunk payloads."""
        async def generate_tokens():
            yield {"token": "Hello", "score": 0.9}
            yield {"token": "World", "score": 0.8}

        messages = await self._stream(generate_tokens)

        assert len(messages) == 3  # 2 chunks + 1 end
        first, second = messages[0]["data"], messages[1]["data"]
        assert first["token"] == "Hello"
        assert first["score"] == 0.9
        assert second["token"] == "World"

    @pytest.mark.asyncio
    async def test_streaming_empty_generator(self):
        """A generator that yields nothing produces only the end message."""
        async def empty_generate():
            return
            yield  # unreachable; its presence makes this an async generator

        messages = await self._stream(empty_generate)

        assert [m["type"] for m in messages] == ["end"]

    @pytest.mark.asyncio
    async def test_streaming_many_chunks(self):
        """One hundred items arrive as one hundred chunks plus an end message."""
        async def generate_many():
            for index in range(100):
                yield f"chunk-{index}"

        messages = await self._stream(generate_many)

        assert len(messages) == 101
        assert messages[0]["data"] == "chunk-0"
        assert messages[99]["data"] == "chunk-99"
        assert messages[-1]["type"] == "end"
class TestWorkerStreamingHeartbeat:
    """Tests that the worker refreshes its heartbeat while streaming."""

    @pytest.mark.asyncio
    async def test_heartbeat_updated_during_streaming(self):
        """``worker.notify()`` is called while chunks are being streamed."""
        async def generate_tokens():
            yield "Hello"
            yield "World"

        worker = create_worker()
        writer = FakeStreamWriter()

        async def mock_execute(app_path, action, args, kwargs):
            return generate_tokens()

        # wraps= records each call and still runs the real notify(). This
        # replaces the hand-rolled counter, whose callable() guard was dead
        # code (a bound method is always callable). The patch is also undone
        # on exit instead of leaving worker.notify replaced.
        with mock.patch.object(worker, 'notify', wraps=worker.notify) as notify:
            with mock.patch.object(worker, 'execute', side_effect=mock_execute):
                request = make_request(123, "test:App", "generate")
                await worker.handle_request(request, writer)

        # Expect at least one notify() per streamed chunk (two chunks here).
        assert notify.call_count >= 2
class TestWorkerMessageTypeValidation:
    """Tests that the worker rejects messages it does not understand."""

    @pytest.mark.asyncio
    async def test_unknown_message_type_sends_error(self):
        """A message with an unrecognized ``type`` is answered with an error."""
        worker, writer = create_worker(), FakeStreamWriter()

        bogus = {"type": "unknown", "id": 123}
        await worker.handle_request(bogus, writer)

        assert len(writer.messages) == 1
        reply = writer.messages[0]
        assert reply["type"] == "error"
        assert "Unknown message type" in reply["error"]["message"]