# # This file is part of gunicorn released under the MIT license. # See the NOTICE for more information. """Tests for dirty client sync streaming functionality.""" import socket import struct import pytest from unittest import mock from gunicorn.dirty.protocol import ( DirtyProtocol, make_chunk_message, make_end_message, make_response, make_error_response, ) from gunicorn.dirty.client import DirtyClient, DirtyStreamIterator from gunicorn.dirty.errors import DirtyError, DirtyConnectionError class MockSocket: """Mock socket that returns predefined messages.""" def __init__(self, messages): self._data = b'' for msg in messages: self._data += DirtyProtocol.encode(msg) self._pos = 0 self._sent = [] self.closed = False self._timeout = None def sendall(self, data): self._sent.append(data) def recv(self, n, flags=0): if self._pos >= len(self._data): return b'' end = min(self._pos + n, len(self._data)) result = self._data[self._pos:end] self._pos = end return result def settimeout(self, timeout): self._timeout = timeout def close(self): self.closed = True def create_client_with_mock_socket(messages): """Create a client with a mock socket returning the given messages.""" client = DirtyClient("/tmp/test.sock") client._sock = MockSocket(messages) return client class TestDirtyStreamIterator: """Tests for DirtyStreamIterator.""" def test_stream_returns_iterator(self): """Test that stream() returns an iterator.""" client = DirtyClient("/tmp/test.sock") result = client.stream("test:App", "generate") assert isinstance(result, DirtyStreamIterator) def test_stream_iterator_yields_chunks(self): """Test that stream iterator yields chunks correctly.""" messages = [ make_chunk_message("req-123", "Hello"), make_chunk_message("req-123", " "), make_chunk_message("req-123", "World"), make_end_message("req-123"), ] client = create_client_with_mock_socket(messages) chunks = list(client.stream("test:App", "generate")) assert chunks == ["Hello", " ", "World"] def test_stream_iterator_yields_complex_chunks(self): """Test that stream iterator yields complex data types.""" messages = [ make_chunk_message("req-123", {"token": "Hello", "score": 0.9}), make_chunk_message("req-123", {"token": "World", "score": 0.8}), make_end_message("req-123"), ] client = create_client_with_mock_socket(messages) chunks = list(client.stream("test:App", "generate")) assert len(chunks) == 2 assert chunks[0]["token"] == "Hello" assert chunks[1]["token"] == "World" def test_stream_iterator_handles_error(self): """Test that stream iterator raises on error message.""" messages = [ make_chunk_message("req-123", "First"), make_error_response("req-123", DirtyError("Something broke")), ] client = create_client_with_mock_socket(messages) iterator = client.stream("test:App", "generate") # First chunk should work chunk = next(iterator) assert chunk == "First" # Second should raise error with pytest.raises(DirtyError) as exc_info: next(iterator) assert "Something broke" in str(exc_info.value) def test_stream_iterator_empty_stream(self): """Test that empty stream (just end) works.""" messages = [make_end_message("req-123")] client = create_client_with_mock_socket(messages) chunks = list(client.stream("test:App", "generate")) assert chunks == [] def test_stream_iterator_stops_after_exhausted(self): """Test that iterator stays exhausted after StopIteration.""" messages = [ make_chunk_message("req-123", "Only"), make_end_message("req-123"), ] client = create_client_with_mock_socket(messages) iterator = client.stream("test:App", "generate") # Get the chunk chunk = next(iterator) assert chunk == "Only" # Should stop with pytest.raises(StopIteration): next(iterator) # Should stay stopped with pytest.raises(StopIteration): next(iterator) def test_stream_iterator_with_for_loop(self): """Test stream iterator works in for loop.""" messages = [ make_chunk_message("req-123", "a"), make_chunk_message("req-123", "b"), make_chunk_message("req-123", "c"), make_end_message("req-123"), ] client = create_client_with_mock_socket(messages) result = "" for chunk in client.stream("test:App", "generate"): result += chunk assert result == "abc" def test_stream_sends_request_on_first_iteration(self): """Test that request is sent on first next() call.""" messages = [ make_chunk_message("req-123", "data"), make_end_message("req-123"), ] client = create_client_with_mock_socket(messages) iterator = client.stream("test:App", "generate", "prompt_arg") # Before iteration, no request sent assert len(client._sock._sent) == 0 # First iteration sends request next(iterator) assert len(client._sock._sent) == 1 # Decode sent request sent_data = client._sock._sent[0] length = struct.unpack( DirtyProtocol.HEADER_FORMAT, sent_data[:DirtyProtocol.HEADER_SIZE] )[0] request = DirtyProtocol.decode( sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length] ) assert request["type"] == "request" assert request["app_path"] == "test:App" assert request["action"] == "generate" assert request["args"] == ["prompt_arg"] class TestDirtyStreamIteratorEdgeCases: """Edge cases for streaming.""" def test_stream_many_chunks(self): """Test streaming with many chunks.""" messages = [] for i in range(100): messages.append(make_chunk_message("req-123", f"chunk-{i}")) messages.append(make_end_message("req-123")) client = create_client_with_mock_socket(messages) chunks = list(client.stream("test:App", "generate")) assert len(chunks) == 100 assert chunks[0] == "chunk-0" assert chunks[99] == "chunk-99" def test_stream_with_kwargs(self): """Test streaming with keyword arguments.""" messages = [ make_chunk_message("req-123", "data"), make_end_message("req-123"), ] client = create_client_with_mock_socket(messages) # Use kwargs list(client.stream("test:App", "generate", "arg1", key="value")) # Check the sent request includes kwargs sent_data = client._sock._sent[0] length = struct.unpack( DirtyProtocol.HEADER_FORMAT, sent_data[:DirtyProtocol.HEADER_SIZE] )[0] request = DirtyProtocol.decode( sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length] ) assert request["args"] == ["arg1"] assert request["kwargs"] == {"key": "value"}