mirror of
https://github.com/frappe/gunicorn.git
synced 2026-07-01 10:11:30 +08:00
New test files covering areas identified as gaps compared to Daphne and Uvicorn test coverage:
- test_asgi_header_security.py: Header validation, normalization, injection prevention
- test_asgi_error_handling.py: Application errors, body receiver errors, graceful shutdown
- test_asgi_protocol_http.py: HTTP connection management, chunked encoding, methods, scope building
- test_asgi_websocket_enhanced.py: WebSocket message limits, connection rejection, subprotocols
- test_asgi_lifespan.py: Lifespan message formats and behavior
- test_asgi_forwarded_headers.py: X-Forwarded-* and proxy header handling
395 lines
12 KiB
Python
395 lines
12 KiB
Python
#
|
|
# This file is part of gunicorn released under the MIT license.
|
|
# See the NOTICE for more information.
|
|
|
|
"""
|
|
ASGI error handling tests.
|
|
|
|
Tests for application error scenarios and graceful shutdown behavior
|
|
to ensure robust error handling in ASGI applications.
|
|
"""
|
|
|
|
import asyncio
|
|
from unittest import mock
|
|
|
|
import pytest
|
|
|
|
from gunicorn.config import Config
|
|
|
|
|
|
# ============================================================================
|
|
# Application Error Tests
|
|
# ============================================================================
|
|
|
|
class TestApplicationErrors:
    """Checks on the protocol's closed-state bookkeeping around errors."""

    def _create_protocol(self):
        """Return an ASGIProtocol bound to a mocked worker, marked open."""
        from gunicorn.asgi.protocol import ASGIProtocol

        fake_worker = mock.Mock()
        fake_worker.configure_mock(
            cfg=Config(),
            log=mock.Mock(),
            asgi=mock.Mock(),
            nr_conns=1,
            loop=mock.Mock(),
        )

        proto = ASGIProtocol(fake_worker)
        proto._closed = False
        return proto

    def _create_mock_request(self):
        """Return a mock describing a body-less ``GET /`` HTTP/1.1 request."""
        attributes = {
            "method": "GET",
            "path": "/",
            "raw_path": b"/",
            "query": "",
            "version": (1, 1),
            "scheme": "http",
            "headers": [],
            "content_length": 0,
            "chunked": False,
        }

        fake_request = mock.Mock()
        for attr_name, attr_value in attributes.items():
            setattr(fake_request, attr_name, attr_value)
        return fake_request

    def test_protocol_tracks_closed_state(self):
        """The ``_closed`` flag reflects whatever was last stored on it."""
        proto = self._create_protocol()
        assert proto._closed is False

        proto._closed = True
        assert proto._closed is True

    def test_connection_lost_sets_closed(self):
        """Calling connection_lost(None) flips ``_closed`` to True."""
        proto = self._create_protocol()
        proto.reader = mock.Mock()
        assert proto._closed is False

        proto.connection_lost(None)
        assert proto._closed is True

    def test_connection_lost_with_exception(self):
        """connection_lost accepts an exception and still marks closed."""
        proto = self._create_protocol()
        proto.reader = mock.Mock()

        reset_error = ConnectionResetError("Connection reset")
        proto.connection_lost(reset_error)

        assert proto._closed is True
|
|
|
|
|
|
# ============================================================================
|
|
# Response Info Tests
|
|
# ============================================================================
|
|
|
|
class TestResponseInfo:
    """Checks that ASGIResponseInfo exposes the values it was built with."""

    def test_response_info_initial(self):
        """An info object built with no headers reports them back as given."""
        from gunicorn.asgi.protocol import ASGIResponseInfo

        fields = {"status": 200, "headers": [], "sent": False}
        response_info = ASGIResponseInfo(**fields)

        assert response_info.status == 200
        assert response_info.headers == []
        assert response_info.sent is False

    def test_response_info_with_headers(self):
        """An info object built with two headers keeps both of them."""
        from gunicorn.asgi.protocol import ASGIResponseInfo

        content_type = (b"content-type", b"text/plain")
        content_length = (b"content-length", b"5")
        fields = {
            "status": 200,
            "headers": [content_type, content_length],
            "sent": True,
        }
        response_info = ASGIResponseInfo(**fields)

        assert response_info.status == 200
        assert len(response_info.headers) == 2
        assert response_info.sent is True
|
|
|
|
|
|
# ============================================================================
|
|
# Body Receiver Error Tests
|
|
# ============================================================================
|
|
|
|
class TestBodyReceiverErrors:
    """Checks on BodyReceiver behaviour around disconnects and late data."""

    def _create_protocol(self):
        """Return an ASGIProtocol bound to a mocked worker, marked open."""
        from gunicorn.asgi.protocol import ASGIProtocol

        fake_worker = mock.Mock()
        fake_worker.configure_mock(
            cfg=Config(),
            log=mock.Mock(),
            asgi=mock.Mock(),
            nr_conns=1,
            loop=mock.Mock(),
        )

        proto = ASGIProtocol(fake_worker)
        proto._closed = False
        return proto

    @pytest.mark.asyncio
    async def test_body_receiver_handles_closed_protocol(self):
        """After a disconnect is signalled, receive() yields http.disconnect."""
        from gunicorn.asgi.protocol import BodyReceiver

        proto = self._create_protocol()
        empty_request = mock.Mock(content_length=0, chunked=False)
        receiver = BodyReceiver(empty_request, proto)

        # First receive drains the (empty) request body.
        first = await receiver.receive()
        assert first["type"] == "http.request"
        assert first["more_body"] is False

        # Close the protocol, then tell the receiver the peer went away.
        proto._closed = True
        receiver.signal_disconnect()

        # The next receive must report the disconnect.
        second = await receiver.receive()
        assert second == {"type": "http.disconnect"}

    @pytest.mark.asyncio
    async def test_body_receiver_multiple_signal_disconnect(self):
        """Repeated signal_disconnect() calls leave the receiver closed."""
        from gunicorn.asgi.protocol import BodyReceiver

        proto = self._create_protocol()
        empty_request = mock.Mock(content_length=0, chunked=False)
        receiver = BodyReceiver(empty_request, proto)

        # Three back-to-back signals; none of them may raise.
        for _ in range(3):
            receiver.signal_disconnect()

        assert receiver._closed is True

    @pytest.mark.asyncio
    async def test_body_receiver_feed_after_complete(self):
        """feed() after the body was completed and consumed must not raise."""
        from gunicorn.asgi.protocol import BodyReceiver

        proto = self._create_protocol()
        sized_request = mock.Mock(content_length=5, chunked=False)
        receiver = BodyReceiver(sized_request, proto)

        # Deliver exactly the announced body and mark it finished.
        receiver.feed(b"hello")
        receiver.set_complete()

        # Drain it through the ASGI receive channel.
        message = await receiver.receive()
        assert message["body"] == b"hello"
        assert message["more_body"] is False

        # Late data must be tolerated without an exception.
        receiver.feed(b"extra")
|
|
|
|
|
|
# ============================================================================
|
|
# Graceful Shutdown Tests
|
|
# ============================================================================
|
|
|
|
class TestGracefulShutdown:
    """Checks on how connection_lost treats the in-flight request task."""

    def _create_protocol(self):
        """Return an ASGIProtocol bound to a mocked worker, marked open."""
        from gunicorn.asgi.protocol import ASGIProtocol

        fake_worker = mock.Mock()
        fake_worker.configure_mock(
            cfg=Config(),
            log=mock.Mock(),
            asgi=mock.Mock(),
            nr_conns=1,
            loop=mock.Mock(),
        )

        proto = ASGIProtocol(fake_worker)
        proto._closed = False
        return proto

    def test_graceful_shutdown_schedules_cancel(self):
        """A pending task is not cancelled at once; call_later is used."""
        proto = self._create_protocol()

        # Stand-in for a request task that is still running.
        pending_task = mock.Mock(**{"done.return_value": False})
        proto._task = pending_task
        proto.reader = mock.Mock()

        # Drop the connection without an error.
        proto.connection_lost(None)

        # No immediate cancel on the task itself ...
        pending_task.cancel.assert_not_called()

        # ... but exactly one deferred callback registered on the loop.
        proto.worker.loop.call_later.assert_called_once()

    def test_completed_task_not_cancelled(self):
        """A task that already finished is never cancelled."""
        proto = self._create_protocol()

        # Stand-in for a request task that has already completed.
        finished_task = mock.Mock(**{"done.return_value": True})
        proto._task = finished_task
        proto.reader = mock.Mock()

        # Drop the connection without an error.
        proto.connection_lost(None)

        # cancel() must not have been invoked on the finished task.
        finished_task.cancel.assert_not_called()
|
|
|
|
|
|
# ============================================================================
|
|
# Protocol Timeout Tests
|
|
# ============================================================================
|
|
|
|
class TestProtocolTimeouts:
    """Checks on the protocol's keepalive timer attributes and helpers."""

    def _create_protocol(self):
        """Return an ASGIProtocol bound to a mocked worker, marked open."""
        from gunicorn.asgi.protocol import ASGIProtocol

        fake_worker = mock.Mock()
        fake_worker.configure_mock(
            cfg=Config(),
            log=mock.Mock(),
            asgi=mock.Mock(),
            nr_conns=1,
            loop=mock.Mock(),
        )

        proto = ASGIProtocol(fake_worker)
        proto._closed = False
        return proto

    def test_keepalive_timer_can_be_armed(self):
        """A fresh protocol has no timer handle but offers arm/cancel helpers."""
        proto = self._create_protocol()

        # Nothing is scheduled on a newly created protocol.
        assert proto._keepalive_handle is None

        # Both timer helpers must be present on the instance.
        for helper_name in ('_arm_keepalive_timer', '_cancel_keepalive_timer'):
            assert hasattr(proto, helper_name)

    def test_cancel_keepalive_timer_handles_none(self):
        """Cancelling with no timer set, even twice, must not raise."""
        proto = self._create_protocol()

        # Two consecutive cancels with no timer armed.
        for _ in range(2):
            proto._cancel_keepalive_timer()
|
|
|
|
|
|
# ============================================================================
|
|
# Request Time Tests
|
|
# ============================================================================
|
|
|
|
class TestRequestTime:
    """Checks on the seconds/microseconds fields of _RequestTime."""

    def test_request_time_creation(self):
        """A _RequestTime instance exposes both timing attributes."""
        from gunicorn.asgi.protocol import _RequestTime

        elapsed = _RequestTime(1.5)

        # Both components must exist on the object.
        for field_name in ('seconds', 'microseconds'):
            assert hasattr(elapsed, field_name)

    def test_request_time_conversion(self):
        """1.5 elapsed seconds is stored as 1 s plus 500000 us."""
        from gunicorn.asgi.protocol import _RequestTime

        elapsed = _RequestTime(1.5)

        assert elapsed.seconds == 1
        assert elapsed.microseconds == 500000

    def test_request_time_with_zero(self):
        """Zero elapsed time yields zero in both components."""
        from gunicorn.asgi.protocol import _RequestTime

        elapsed = _RequestTime(0.0)

        assert elapsed.seconds == 0
        assert elapsed.microseconds == 0
|
|
|
|
|
|
# ============================================================================
|
|
# Message Validation Tests
|
|
# ============================================================================
|
|
|
|
class TestMessageValidation:
    """Shape checks on hand-built ASGI HTTP message dictionaries."""

    def test_response_start_requires_status(self):
        """A well-formed http.response.start message carries a status key."""
        # Reference example of a valid start message.
        start_message = dict(
            type="http.response.start",
            status=200,
            headers=[],
        )

        assert start_message["type"] == "http.response.start"
        assert "status" in start_message

    def test_response_body_message_format(self):
        """http.response.body carries bytes, whether populated or empty."""
        # Message with a non-empty payload.
        populated = dict(
            type="http.response.body",
            body=b"Hello",
            more_body=False,
        )
        assert isinstance(populated["body"], bytes)

        # Message with an empty payload.
        blank = dict(
            type="http.response.body",
            body=b"",
            more_body=False,
        )
        assert blank["body"] == b""

    def test_disconnect_message_minimal(self):
        """http.disconnect consists of the type key and nothing else."""
        disconnect = dict(type="http.disconnect")

        assert disconnect == {"type": "http.disconnect"}
        assert len(disconnect) == 1
|