gunicorn/tests/test_dirty_client.py
2026-01-25 10:21:18 +01:00

325 lines
10 KiB
Python

#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""Tests for dirty client module."""
import os
import socket
import tempfile
import threading
import pytest
from gunicorn.dirty.client import (
DirtyClient,
get_dirty_client,
get_dirty_socket_path,
set_dirty_socket_path,
close_dirty_client,
)
from gunicorn.dirty.errors import DirtyConnectionError, DirtyError
from gunicorn.dirty.protocol import DirtyProtocol, make_response
class TestDirtyClientInit:
    """Checks on the state of a newly constructed DirtyClient."""

    def test_init_attributes(self):
        """Constructor arguments are stored and the private handles are None."""
        path = "/tmp/test.sock"
        fresh = DirtyClient(path, timeout=60.0)

        assert fresh.socket_path == path
        assert fresh.timeout == 60.0
        # All three private handles are expected to start out as None.
        for handle in ("_sock", "_reader", "_writer"):
            assert getattr(fresh, handle) is None
class TestDirtyClientSync:
    """Tests for sync API."""

    def test_connect_nonexistent_socket(self):
        """connect() on a missing socket path raises DirtyConnectionError."""
        client = DirtyClient("/nonexistent/socket.sock")
        with pytest.raises(DirtyConnectionError) as exc_info:
            client.connect()

        assert "Failed to connect" in str(exc_info.value)

    def test_connect_success(self):
        """connect() against a listening Unix socket sets the client's socket."""
        with tempfile.TemporaryDirectory() as tmpdir:
            socket_path = os.path.join(tmpdir, "test.sock")

            # The context manager closes the listening socket even when
            # bind()/listen() or the assertions below fail.
            with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as server_sock:
                server_sock.bind(socket_path)
                server_sock.listen(1)

                client = DirtyClient(socket_path)
                try:
                    client.connect()
                    assert client._sock is not None
                finally:
                    # Always release the client socket; close() is exercised
                    # on an unconnected client in test_close_idempotent.
                    client.close()

    def test_close_idempotent(self):
        """close() may be called repeatedly, even without a connection."""
        client = DirtyClient("/tmp/test.sock")
        client.close()
        client.close()  # Should not raise
class TestDirtyClientAsync:
    """Checks on the coroutine-based client API."""

    @pytest.mark.asyncio
    async def test_connect_async_nonexistent_socket(self):
        """connect_async() on a missing socket path raises DirtyConnectionError."""
        unreachable = DirtyClient("/nonexistent/socket.sock", timeout=1.0)

        with pytest.raises(DirtyConnectionError):
            await unreachable.connect_async()

    @pytest.mark.asyncio
    async def test_close_async_idempotent(self):
        """close_async() can be awaited twice in a row without raising."""
        idle = DirtyClient("/tmp/test.sock")

        # The second pass is the actual check: it must not raise.
        for _ in range(2):
            await idle.close_async()
class TestDirtyClientContextManagers:
    """Checks on using DirtyClient in ``with`` / ``async with`` statements."""

    def test_sync_context_manager_connection_error(self):
        """Entering ``with client`` on a missing socket raises DirtyConnectionError."""
        unreachable = DirtyClient("/nonexistent/socket.sock")

        # The client is entered inside the raises() context, so the error
        # raised on entry is captured.
        with pytest.raises(DirtyConnectionError), unreachable:
            pass

    @pytest.mark.asyncio
    async def test_async_context_manager_connection_error(self):
        """Entering ``async with client`` on a missing socket raises DirtyConnectionError."""
        unreachable = DirtyClient("/nonexistent/socket.sock", timeout=1.0)

        with pytest.raises(DirtyConnectionError):
            async with unreachable:
                pass
class TestDirtyClientHelpers:
    """Tests for helper functions."""

    # Environment variable the tests below read and modify.
    _ENV_VAR = 'GUNICORN_DIRTY_SOCKET'

    @pytest.fixture(autouse=True)
    def _restore_socket_config(self):
        """Snapshot the socket configuration and restore it after every test.

        Runs around each test in this class. After the test, the explicitly
        set socket path is cleared and the environment variable is put back
        exactly as it was: removed if it was absent, and restored even when
        its original value was an empty string.
        """
        original = os.environ.get(self._ENV_VAR)
        yield
        set_dirty_socket_path(None)
        if original is None:
            os.environ.pop(self._ENV_VAR, None)
        else:
            os.environ[self._ENV_VAR] = original

    def test_set_get_socket_path(self):
        """A path given to set_dirty_socket_path is returned by the getter."""
        set_dirty_socket_path("/tmp/dirty.sock")
        assert get_dirty_socket_path() == "/tmp/dirty.sock"

    def test_get_socket_path_from_env(self):
        """With no explicit path set, the environment variable is used."""
        set_dirty_socket_path(None)
        os.environ[self._ENV_VAR] = "/env/dirty.sock"
        assert get_dirty_socket_path() == "/env/dirty.sock"

    def test_get_socket_path_not_configured(self):
        """DirtyError is raised when neither a path nor the env var is set."""
        set_dirty_socket_path(None)
        os.environ.pop(self._ENV_VAR, None)

        with pytest.raises(DirtyError) as exc_info:
            get_dirty_socket_path()

        assert "not configured" in str(exc_info.value)

    def test_get_dirty_client_thread_local(self):
        """Repeated get_dirty_client() calls in one thread return one object."""
        set_dirty_socket_path("/tmp/test.sock")
        # Clean up any existing client
        close_dirty_client()
        try:
            client1 = get_dirty_client()
            client2 = get_dirty_client()
            # Should return same instance in same thread
            assert client1 is client2
        finally:
            # Drop the client even if the assertion fails.
            close_dirty_client()

    def test_get_dirty_client_different_threads(self):
        """Two threads calling get_dirty_client() receive distinct objects."""
        clients = []
        set_dirty_socket_path("/tmp/test.sock")

        def get_client():
            clients.append(get_dirty_client())
            close_dirty_client()

        # Clean up main thread client
        close_dirty_client()

        workers = [threading.Thread(target=get_client) for _ in range(2)]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

        # Different threads should get different clients
        assert len(clients) == 2
        assert clients[0] is not clients[1]

    def test_close_dirty_client(self):
        """After close_dirty_client(), get_dirty_client() returns a new object."""
        set_dirty_socket_path("/tmp/test.sock")
        try:
            client = get_dirty_client()
            close_dirty_client()
            # Should be able to get a new client
            client2 = get_dirty_client()
            assert client2 is not client
        finally:
            # Drop the client even if the assertion fails.
            close_dirty_client()
class TestDirtyClientResponseHandling:
    """Checks on DirtyClient._handle_response."""

    def test_handle_response_success(self):
        """A message built by make_response yields its result payload."""
        message = make_response("test-id", {"data": "value"})
        handler = DirtyClient("/tmp/test.sock")

        assert handler._handle_response(message) == {"data": "value"}

    def test_handle_response_error(self):
        """An error-typed message is raised as DirtyError carrying its text."""
        failure = {
            "error_type": "DirtyError",
            "message": "Test error",
            "details": {},
        }
        message = {
            "type": DirtyProtocol.MSG_TYPE_ERROR,
            "id": "test-id",
            "error": failure,
        }
        handler = DirtyClient("/tmp/test.sock")

        with pytest.raises(DirtyError) as caught:
            handler._handle_response(message)

        assert "Test error" in str(caught.value)

    def test_handle_response_unknown_type(self):
        """A message with an unrecognised type is rejected with DirtyError."""
        message = {"type": "unknown", "id": "test-id"}
        handler = DirtyClient("/tmp/test.sock")

        with pytest.raises(DirtyError) as caught:
            handler._handle_response(message)

        assert "Unknown response type" in str(caught.value)
class TestDirtyClientExecute:
    """Tests for execute functionality against a local Unix socket server."""

    def test_execute_with_socket_pair(self):
        """execute() returns the result sent back by a one-shot server.

        A listening Unix socket in a temporary directory plays the server.
        A helper thread accepts one connection, reads one request with
        DirtyProtocol and replies with a response for the same message id.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            socket_path = os.path.join(tmpdir, "test.sock")

            # The context manager closes the listening socket on every path.
            with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as server_sock:
                server_sock.bind(socket_path)
                server_sock.listen(1)
                # Bound accept() so the handler thread cannot block forever
                # when the client fails before connecting.
                server_sock.settimeout(5.0)

                response_sent = threading.Event()

                def server_handler():
                    try:
                        conn, _ = server_sock.accept()
                    except OSError:
                        # Timed out or the listening socket was closed:
                        # no client arrived, nothing to serve.
                        return
                    try:
                        # Read request
                        msg = DirtyProtocol.read_message(conn)
                        # Send response
                        resp = make_response(msg["id"], {"result": "success"})
                        DirtyProtocol.write_message(conn, resp)
                        response_sent.set()
                    finally:
                        conn.close()

                # Daemon thread: a stuck handler must never keep the test
                # process alive at interpreter exit.
                server_thread = threading.Thread(
                    target=server_handler, daemon=True
                )
                server_thread.start()

                client = DirtyClient(socket_path, timeout=5.0)
                try:
                    result = client.execute(
                        "test:App", "action", "arg1", key="value"
                    )
                    assert result == {"result": "success"}
                finally:
                    # Close the client first so a handler still blocked on
                    # reading sees the connection go away.
                    client.close()
                    response_sent.wait(timeout=2.0)
                    server_thread.join(timeout=2.0)

    def test_close_socket_clears_sock(self):
        """_close_socket() resets the client's _sock attribute to None."""
        client = DirtyClient("/tmp/test.sock")
        # Simulate having a socket
        client._sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        client._close_socket()
        assert client._sock is None