diff --git a/tests/test_asgi_error_handling.py b/tests/test_asgi_error_handling.py new file mode 100644 index 00000000..1bc25da1 --- /dev/null +++ b/tests/test_asgi_error_handling.py @@ -0,0 +1,394 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +ASGI error handling tests. + +Tests for application error scenarios and graceful shutdown behavior +to ensure robust error handling in ASGI applications. +""" + +import asyncio +from unittest import mock + +import pytest + +from gunicorn.config import Config + + +# ============================================================================ +# Application Error Tests +# ============================================================================ + +class TestApplicationErrors: + """Test handling of ASGI application errors.""" + + def _create_protocol(self): + """Create an ASGIProtocol instance for testing.""" + from gunicorn.asgi.protocol import ASGIProtocol + + worker = mock.Mock() + worker.cfg = Config() + worker.log = mock.Mock() + worker.asgi = mock.Mock() + worker.nr_conns = 1 + worker.loop = mock.Mock() + + protocol = ASGIProtocol(worker) + protocol._closed = False + return protocol + + def _create_mock_request(self): + """Create a mock HTTP request.""" + request = mock.Mock() + request.method = "GET" + request.path = "/" + request.raw_path = b"/" + request.query = "" + request.version = (1, 1) + request.scheme = "http" + request.headers = [] + request.content_length = 0 + request.chunked = False + return request + + def test_protocol_tracks_closed_state(self): + """Protocol should track closed state.""" + protocol = self._create_protocol() + + assert protocol._closed is False + + protocol._closed = True + + assert protocol._closed is True + + def test_connection_lost_sets_closed(self): + """connection_lost should set closed state.""" + protocol = self._create_protocol() + protocol.reader = mock.Mock() + + assert protocol._closed is False + + protocol.connection_lost(None) + + assert protocol._closed is True + + def test_connection_lost_with_exception(self): + """connection_lost handles exception argument gracefully.""" + protocol = self._create_protocol() + protocol.reader = mock.Mock() + + exc = ConnectionResetError("Connection reset") + protocol.connection_lost(exc) + + assert protocol._closed is True + + +# ============================================================================ +# Response Info Tests +# ============================================================================ + +class TestResponseInfo: + """Test response info tracking.""" + + def test_response_info_initial(self): + """Test initial ASGIResponseInfo values.""" + from gunicorn.asgi.protocol import ASGIResponseInfo + + info = ASGIResponseInfo(status=200, headers=[], sent=False) + + assert info.status == 200 + assert info.headers == [] + assert info.sent is False + + def test_response_info_with_headers(self): + """Test ASGIResponseInfo with headers.""" + from gunicorn.asgi.protocol import ASGIResponseInfo + + headers = [ + (b"content-type", b"text/plain"), + (b"content-length", b"5"), + ] + info = ASGIResponseInfo(status=200, headers=headers, sent=True) + + assert info.status == 200 + assert len(info.headers) == 2 + assert info.sent is True + + +# ============================================================================ +# Body Receiver Error Tests +# ============================================================================ + +class TestBodyReceiverErrors: + """Test error handling in BodyReceiver.""" + + def _create_protocol(self): + """Create an ASGIProtocol instance for testing.""" + from gunicorn.asgi.protocol import ASGIProtocol + + worker = mock.Mock() + worker.cfg = Config() + worker.log = mock.Mock() + worker.asgi = mock.Mock() + worker.nr_conns = 1 + worker.loop = mock.Mock() + + protocol = ASGIProtocol(worker) + protocol._closed = False + return protocol + + @pytest.mark.asyncio + async def test_body_receiver_handles_closed_protocol(self): + """BodyReceiver should handle protocol being closed.""" + from gunicorn.asgi.protocol import BodyReceiver + + protocol = self._create_protocol() + + mock_request = mock.Mock() + mock_request.content_length = 0 + mock_request.chunked = False + + body_receiver = BodyReceiver(mock_request, protocol) + + # Consume the empty body + msg = await body_receiver.receive() + assert msg["type"] == "http.request" + assert msg["more_body"] is False + + # Mark protocol as closed + protocol._closed = True + + # Signal disconnect + body_receiver.signal_disconnect() + + # Receive should return disconnect + msg = await body_receiver.receive() + assert msg == {"type": "http.disconnect"} + + @pytest.mark.asyncio + async def test_body_receiver_multiple_signal_disconnect(self): + """Multiple signal_disconnect calls should be safe.""" + from gunicorn.asgi.protocol import BodyReceiver + + protocol = self._create_protocol() + + mock_request = mock.Mock() + mock_request.content_length = 0 + mock_request.chunked = False + + body_receiver = BodyReceiver(mock_request, protocol) + + # Signal disconnect multiple times - should not raise + body_receiver.signal_disconnect() + body_receiver.signal_disconnect() + body_receiver.signal_disconnect() + + assert body_receiver._closed is True + + @pytest.mark.asyncio + async def test_body_receiver_feed_after_complete(self): + """Feeding data after body is complete should be safe.""" + from gunicorn.asgi.protocol import BodyReceiver + + protocol = self._create_protocol() + + mock_request = mock.Mock() + mock_request.content_length = 5 + mock_request.chunked = False + + body_receiver = BodyReceiver(mock_request, protocol) + + # Feed the expected body + body_receiver.feed(b"hello") + body_receiver.set_complete() + + # Consume the body + msg = await body_receiver.receive() + assert msg["body"] == b"hello" + assert msg["more_body"] is False + + # Feeding more data after complete should be safe + body_receiver.feed(b"extra") # Should not raise + + +# ============================================================================ +# Graceful Shutdown Tests +# ============================================================================ + +class TestGracefulShutdown: + """Test graceful shutdown behavior.""" + + def _create_protocol(self): + """Create an ASGIProtocol instance for testing.""" + from gunicorn.asgi.protocol import ASGIProtocol + + worker = mock.Mock() + worker.cfg = Config() + worker.log = mock.Mock() + worker.asgi = mock.Mock() + worker.nr_conns = 1 + worker.loop = mock.Mock() + + protocol = ASGIProtocol(worker) + protocol._closed = False + return protocol + + def test_graceful_shutdown_schedules_cancel(self): + """Graceful shutdown should schedule task cancellation.""" + protocol = self._create_protocol() + + # Create a mock task + mock_task = mock.Mock() + mock_task.done.return_value = False + protocol._task = mock_task + protocol.reader = mock.Mock() + + # Simulate connection lost + protocol.connection_lost(None) + + # Task should NOT be cancelled immediately + mock_task.cancel.assert_not_called() + + # Cancellation should be scheduled + protocol.worker.loop.call_later.assert_called_once() + + def test_completed_task_not_cancelled(self): + """Completed tasks should not be cancelled.""" + protocol = self._create_protocol() + + # Create a mock task that's already done + mock_task = mock.Mock() + mock_task.done.return_value = True + protocol._task = mock_task + protocol.reader = mock.Mock() + + # Simulate connection lost + protocol.connection_lost(None) + + # Task should not be cancelled + mock_task.cancel.assert_not_called() + + +# ============================================================================ +# Protocol Timeout Tests +# ============================================================================ + +class TestProtocolTimeouts: + """Test timeout handling in protocol.""" + + def _create_protocol(self): + """Create an ASGIProtocol instance for testing.""" + from gunicorn.asgi.protocol import ASGIProtocol + + worker = mock.Mock() + worker.cfg = Config() + worker.log = mock.Mock() + worker.asgi = mock.Mock() + worker.nr_conns = 1 + worker.loop = mock.Mock() + + protocol = ASGIProtocol(worker) + protocol._closed = False + return protocol + + def test_keepalive_timer_can_be_armed(self): + """Keepalive timer should be arm-able.""" + protocol = self._create_protocol() + + # Initially no timer handle + assert protocol._keepalive_handle is None + + # Verify the method exists + assert hasattr(protocol, '_arm_keepalive_timer') + assert hasattr(protocol, '_cancel_keepalive_timer') + + def test_cancel_keepalive_timer_handles_none(self): + """Cancelling non-existent timer should be safe.""" + protocol = self._create_protocol() + + # Should not raise even with no timer + protocol._cancel_keepalive_timer() + protocol._cancel_keepalive_timer() # Multiple calls safe + + +# ============================================================================ +# Request Time Tests +# ============================================================================ + +class TestRequestTime: + """Test request time handling.""" + + def test_request_time_creation(self): + """_RequestTime should track timing.""" + from gunicorn.asgi.protocol import _RequestTime + + request_time = _RequestTime(1.5) + + # _RequestTime splits into seconds and microseconds + assert hasattr(request_time, 'seconds') + assert hasattr(request_time, 'microseconds') + + def test_request_time_conversion(self): + """_RequestTime should store time as seconds + microseconds.""" + from gunicorn.asgi.protocol import _RequestTime + + # 1.5 seconds = 1 second + 500000 microseconds + request_time = _RequestTime(1.5) + + assert request_time.seconds == 1 + assert request_time.microseconds == 500000 + + def test_request_time_with_zero(self): + """_RequestTime with zero elapsed time.""" + from gunicorn.asgi.protocol import _RequestTime + + request_time = _RequestTime(0.0) + + assert request_time.seconds == 0 + assert request_time.microseconds == 0 + + +# ============================================================================ +# Message Validation Tests +# ============================================================================ + +class TestMessageValidation: + """Test ASGI message validation.""" + + def test_response_start_requires_status(self): + """http.response.start must have status.""" + # Valid response start + valid_msg = { + "type": "http.response.start", + "status": 200, + "headers": [], + } + assert valid_msg["type"] == "http.response.start" + assert "status" in valid_msg + + def test_response_body_message_format(self): + """http.response.body format validation.""" + # With body + msg_with_body = { + "type": "http.response.body", + "body": b"Hello", + "more_body": False, + } + assert isinstance(msg_with_body["body"], bytes) + + # Empty body + msg_empty = { + "type": "http.response.body", + "body": b"", + "more_body": False, + } + assert msg_empty["body"] == b"" + + def test_disconnect_message_minimal(self): + """http.disconnect message should be minimal.""" + msg = {"type": "http.disconnect"} + + assert msg == {"type": "http.disconnect"} + assert len(msg) == 1 diff --git a/tests/test_asgi_forwarded_headers.py b/tests/test_asgi_forwarded_headers.py new file mode 100644 index 00000000..28f6cdef --- /dev/null +++ b/tests/test_asgi_forwarded_headers.py @@ -0,0 +1,416 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +ASGI forwarded headers tests. + +Tests for X-Forwarded-For, X-Forwarded-Proto, and related +proxy header handling in ASGI applications. +""" + +from unittest import mock + +import pytest + +from gunicorn.config import Config + + +# ============================================================================ +# X-Forwarded-For Header Tests +# ============================================================================ + +class TestXForwardedFor: + """Test X-Forwarded-For header handling.""" + + def _create_protocol(self, forwarded_allow_ips=None): + """Create an ASGIProtocol instance for testing.""" + from gunicorn.asgi.protocol import ASGIProtocol + + worker = mock.Mock() + worker.cfg = Config() + if forwarded_allow_ips is not None: + worker.cfg.forwarded_allow_ips = forwarded_allow_ips + worker.log = mock.Mock() + worker.asgi = mock.Mock() + + return ASGIProtocol(worker) + + def _create_mock_request(self, headers=None): + """Create a mock HTTP request.""" + request = mock.Mock() + request.method = "GET" + request.path = "/" + request.raw_path = b"/" + request.query = "" + request.version = (1, 1) + request.scheme = "http" + request.headers = headers or [] + return request + + def test_x_forwarded_for_in_headers(self): + """X-Forwarded-For header should be passed through.""" + protocol = self._create_protocol() + request = self._create_mock_request( + headers=[ + ("HOST", "localhost"), + ("X-FORWARDED-FOR", "192.168.1.1, 10.0.0.1"), + ] + ) + + scope = protocol._build_http_scope(request, None, None) + + # Header should be in scope headers + header_names = [name for name, _ in scope["headers"]] + assert b"x-forwarded-for" in header_names + + def test_x_forwarded_for_multiple_addresses(self): + """X-Forwarded-For can contain multiple addresses.""" + protocol = self._create_protocol() + request = self._create_mock_request( + headers=[ + ("HOST", "localhost"), + ("X-FORWARDED-FOR", "203.0.113.195, 70.41.3.18, 150.172.238.178"), + ] + ) + + scope = protocol._build_http_scope(request, None, None) + + # Find the header value + xff_value = None + for name, value in scope["headers"]: + if name == b"x-forwarded-for": + xff_value = value + break + + assert xff_value == b"203.0.113.195, 70.41.3.18, 150.172.238.178" + + +# ============================================================================ +# X-Forwarded-Proto Header Tests +# ============================================================================ + +class TestXForwardedProto: + """Test X-Forwarded-Proto header handling.""" + + def _create_protocol(self): + """Create an ASGIProtocol instance for testing.""" + from gunicorn.asgi.protocol import ASGIProtocol + + worker = mock.Mock() + worker.cfg = Config() + worker.log = mock.Mock() + worker.asgi = mock.Mock() + + return ASGIProtocol(worker) + + def _create_mock_request(self, headers=None, scheme="http"): + """Create a mock HTTP request.""" + request = mock.Mock() + request.method = "GET" + request.path = "/" + request.raw_path = b"/" + request.query = "" + request.version = (1, 1) + request.scheme = scheme + request.headers = headers or [] + return request + + def test_x_forwarded_proto_http(self): + """X-Forwarded-Proto: http should be passed through.""" + protocol = self._create_protocol() + request = self._create_mock_request( + headers=[ + ("HOST", "localhost"), + ("X-FORWARDED-PROTO", "http"), + ] + ) + + scope = protocol._build_http_scope(request, None, None) + + # Header should be in scope headers + header_dict = {name: value for name, value in scope["headers"]} + assert b"x-forwarded-proto" in header_dict + + def test_x_forwarded_proto_https(self): + """X-Forwarded-Proto: https should be passed through.""" + protocol = self._create_protocol() + request = self._create_mock_request( + headers=[ + ("HOST", "localhost"), + ("X-FORWARDED-PROTO", "https"), + ] + ) + + scope = protocol._build_http_scope(request, None, None) + + header_dict = {name: value for name, value in scope["headers"]} + assert header_dict[b"x-forwarded-proto"] == b"https" + + +# ============================================================================ +# X-Forwarded-Host Header Tests +# ============================================================================ + +class TestXForwardedHost: + """Test X-Forwarded-Host header handling.""" + + def _create_protocol(self): + """Create an ASGIProtocol instance for testing.""" + from gunicorn.asgi.protocol import ASGIProtocol + + worker = mock.Mock() + worker.cfg = Config() + worker.log = mock.Mock() + worker.asgi = mock.Mock() + + return ASGIProtocol(worker) + + def _create_mock_request(self, headers=None): + """Create a mock HTTP request.""" + request = mock.Mock() + request.method = "GET" + request.path = "/" + request.raw_path = b"/" + request.query = "" + request.version = (1, 1) + request.scheme = "http" + request.headers = headers or [] + return request + + def test_x_forwarded_host_in_headers(self): + """X-Forwarded-Host should be passed through.""" + protocol = self._create_protocol() + request = self._create_mock_request( + headers=[ + ("HOST", "backend.internal"), + ("X-FORWARDED-HOST", "www.example.com"), + ] + ) + + scope = protocol._build_http_scope(request, None, None) + + header_dict = {name: value for name, value in scope["headers"]} + assert b"x-forwarded-host" in header_dict + assert header_dict[b"x-forwarded-host"] == b"www.example.com" + + +# ============================================================================ +# X-Forwarded-Port Header Tests +# ============================================================================ + +class TestXForwardedPort: + """Test X-Forwarded-Port header handling.""" + + def _create_protocol(self): + """Create an ASGIProtocol instance for testing.""" + from gunicorn.asgi.protocol import ASGIProtocol + + worker = mock.Mock() + worker.cfg = Config() + worker.log = mock.Mock() + worker.asgi = mock.Mock() + + return ASGIProtocol(worker) + + def _create_mock_request(self, headers=None): + """Create a mock HTTP request.""" + request = mock.Mock() + request.method = "GET" + request.path = "/" + request.raw_path = b"/" + request.query = "" + request.version = (1, 1) + request.scheme = "http" + request.headers = headers or [] + return request + + def test_x_forwarded_port_in_headers(self): + """X-Forwarded-Port should be passed through.""" + protocol = self._create_protocol() + request = self._create_mock_request( + headers=[ + ("HOST", "localhost:8000"), + ("X-FORWARDED-PORT", "443"), + ] + ) + + scope = protocol._build_http_scope(request, None, None) + + header_dict = {name: value for name, value in scope["headers"]} + assert b"x-forwarded-port" in header_dict + assert header_dict[b"x-forwarded-port"] == b"443" + + +# ============================================================================ +# Forwarded Header (RFC 7239) Tests +# ============================================================================ + +class TestForwardedHeader: + """Test Forwarded header (RFC 7239) handling.""" + + def _create_protocol(self): + """Create an ASGIProtocol instance for testing.""" + from gunicorn.asgi.protocol import ASGIProtocol + + worker = mock.Mock() + worker.cfg = Config() + worker.log = mock.Mock() + worker.asgi = mock.Mock() + + return ASGIProtocol(worker) + + def _create_mock_request(self, headers=None): + """Create a mock HTTP request.""" + request = mock.Mock() + request.method = "GET" + request.path = "/" + request.raw_path = b"/" + request.query = "" + request.version = (1, 1) + request.scheme = "http" + request.headers = headers or [] + return request + + def test_forwarded_header_in_scope(self): + """Forwarded header should be passed through.""" + protocol = self._create_protocol() + request = self._create_mock_request( + headers=[ + ("HOST", "localhost"), + ("FORWARDED", "for=192.0.2.60;proto=http;by=203.0.113.43"), + ] + ) + + scope = protocol._build_http_scope(request, None, None) + + header_dict = {name: value for name, value in scope["headers"]} + assert b"forwarded" in header_dict + + def test_forwarded_header_multiple_proxies(self): + """Forwarded header with multiple proxies.""" + protocol = self._create_protocol() + request = self._create_mock_request( + headers=[ + ("HOST", "localhost"), + ("FORWARDED", "for=192.0.2.43, for=198.51.100.178"), + ] + ) + + scope = protocol._build_http_scope(request, None, None) + + header_dict = {name: value for name, value in scope["headers"]} + assert header_dict[b"forwarded"] == b"for=192.0.2.43, for=198.51.100.178" + + +# ============================================================================ +# Trusted Proxy Tests +# ============================================================================ + +class TestTrustedProxy: + """Test trusted proxy configuration.""" + + def test_check_trusted_proxy_function_exists(self): + """_check_trusted_proxy function should exist.""" + from gunicorn.asgi.protocol import _check_trusted_proxy + + assert callable(_check_trusted_proxy) + + def test_normalize_sockaddr_function_exists(self): + """_normalize_sockaddr function should exist.""" + from gunicorn.asgi.protocol import _normalize_sockaddr + + assert callable(_normalize_sockaddr) + + def test_normalize_sockaddr_ipv4(self): + """IPv4 address should be normalized.""" + from gunicorn.asgi.protocol import _normalize_sockaddr + + result = _normalize_sockaddr(("192.168.1.1", 8000)) + assert result == ("192.168.1.1", 8000) + + def test_normalize_sockaddr_ipv6(self): + """IPv6 address should be normalized.""" + from gunicorn.asgi.protocol import _normalize_sockaddr + + # IPv6 sockaddr is a 4-tuple + result = _normalize_sockaddr(("::1", 8000, 0, 0)) + assert result == ("::1", 8000) + + def test_normalize_sockaddr_none(self): + """None sockaddr should return None.""" + from gunicorn.asgi.protocol import _normalize_sockaddr + + result = _normalize_sockaddr(None) + assert result is None + + +# ============================================================================ +# Header Preservation Tests +# ============================================================================ + +class TestHeaderPreservation: + """Test that proxy headers are preserved in scope.""" + + def _create_protocol(self): + """Create an ASGIProtocol instance for testing.""" + from gunicorn.asgi.protocol import ASGIProtocol + + worker = mock.Mock() + worker.cfg = Config() + worker.log = mock.Mock() + worker.asgi = mock.Mock() + + return ASGIProtocol(worker) + + def _create_mock_request(self, headers=None): + """Create a mock HTTP request.""" + request = mock.Mock() + request.method = "GET" + request.path = "/" + request.raw_path = b"/" + request.query = "" + request.version = (1, 1) + request.scheme = "http" + request.headers = headers or [] + return request + + def test_all_proxy_headers_preserved(self): + """All standard proxy headers should be preserved.""" + protocol = self._create_protocol() + request = self._create_mock_request( + headers=[ + ("HOST", "localhost"), + ("X-FORWARDED-FOR", "192.168.1.1"), + ("X-FORWARDED-PROTO", "https"), + ("X-FORWARDED-HOST", "example.com"), + ("X-FORWARDED-PORT", "443"), + ("X-REAL-IP", "10.0.0.1"), + ] + ) + + scope = protocol._build_http_scope(request, None, None) + + header_names = {name for name, _ in scope["headers"]} + + assert b"x-forwarded-for" in header_names + assert b"x-forwarded-proto" in header_names + assert b"x-forwarded-host" in header_names + assert b"x-forwarded-port" in header_names + assert b"x-real-ip" in header_names + + def test_header_values_as_bytes(self): + """Proxy header values should be bytes.""" + protocol = self._create_protocol() + request = self._create_mock_request( + headers=[ + ("HOST", "localhost"), + ("X-FORWARDED-FOR", "192.168.1.1"), + ] + ) + + scope = protocol._build_http_scope(request, None, None) + + for name, value in scope["headers"]: + assert isinstance(name, bytes) + assert isinstance(value, bytes) diff --git a/tests/test_asgi_header_security.py b/tests/test_asgi_header_security.py new file mode 100644 index 00000000..60a1d3b7 --- /dev/null +++ b/tests/test_asgi_header_security.py @@ -0,0 +1,373 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +ASGI header security tests. + +Tests for header validation, normalization, and injection prevention +to ensure secure HTTP header handling per ASGI 3.0 and RFC 9110/9112. +""" + +import pytest + +from gunicorn.asgi.parser import ( + PythonProtocol, + InvalidHeader, + ParseError, +) + + +# ============================================================================ +# Header Name Validation Tests +# ============================================================================ + +class TestHeaderNameValidation: + """Test validation of HTTP header names.""" + + def test_valid_header_name_accepted(self): + """Valid header names should be accepted.""" + parser = PythonProtocol() + + parser.feed( + b"GET /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"X-Custom-Header: value\r\n" + b"Accept-Language: en-US\r\n" + b"\r\n" + ) + + assert parser.is_complete + + def test_header_name_with_null_rejected(self): + """Header name containing null byte must be rejected.""" + parser = PythonProtocol() + + with pytest.raises((InvalidHeader, ParseError)): + parser.feed( + b"GET /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"X-Bad\x00Header: value\r\n" + b"\r\n" + ) + + def test_header_name_with_cr_rejected(self): + """Header name containing CR must be rejected.""" + parser = PythonProtocol() + + with pytest.raises((InvalidHeader, ParseError)): + parser.feed( + b"GET /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"X-Bad\rHeader: value\r\n" + b"\r\n" + ) + + def test_header_name_with_lf_rejected(self): + """Header name containing LF must be rejected.""" + parser = PythonProtocol() + + with pytest.raises((InvalidHeader, ParseError)): + parser.feed( + b"GET /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"X-Bad\nHeader: value\r\n" + b"\r\n" + ) + + def test_empty_header_name_rejected(self): + """Empty header name must be rejected.""" + parser = PythonProtocol() + + with pytest.raises((InvalidHeader, ParseError)): + parser.feed( + b"GET /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b": value\r\n" + b"\r\n" + ) + + +# ============================================================================ +# Header Value Validation Tests +# ============================================================================ + +class TestHeaderValueValidation: + """Test validation of HTTP header values.""" + + def test_header_value_with_bare_cr_rejected(self): + """Header value containing bare CR must be rejected.""" + parser = PythonProtocol() + + # Bare CR (not followed by LF) in header value should be rejected + with pytest.raises((InvalidHeader, ParseError)): + parser.feed( + b"GET /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"X-Bad: value\rmore\r\n" + b"\r\n" + ) + + def test_header_value_with_bare_lf_rejected(self): + """Header value containing bare LF must be rejected.""" + parser = PythonProtocol() + + # Bare LF (not preceded by CR) in header value should be rejected + with pytest.raises((InvalidHeader, ParseError)): + parser.feed( + b"GET /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"X-Bad: value\nmore\r\n" + b"\r\n" + ) + + def test_header_value_special_characters_allowed(self): + """Header values may contain special printable characters.""" + parser = PythonProtocol() + + parser.feed( + b"GET /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Authorization: Bearer abc123!@#$%^&*()_+\r\n" + b"Cookie: session=abc; path=/; domain=.example.com\r\n" + b"\r\n" + ) + + assert parser.is_complete + + def test_header_value_with_tab_allowed(self): + """Horizontal tab in header value is allowed (OWS).""" + parser = PythonProtocol() + + parser.feed( + b"GET /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"X-Tabs: value1\tvalue2\r\n" + b"\r\n" + ) + + assert parser.is_complete + + +# ============================================================================ +# Header Normalization Tests +# ============================================================================ + +class TestHeaderNormalization: + """Test HTTP header normalization per ASGI spec.""" + + def _create_protocol(self): + """Create an ASGIProtocol instance for testing.""" + from gunicorn.asgi.protocol import ASGIProtocol + from gunicorn.config import Config + from unittest import mock + + worker = mock.Mock() + worker.cfg = Config() + worker.log = mock.Mock() + worker.asgi = mock.Mock() + + return ASGIProtocol(worker) + + def _create_mock_request(self, headers=None): + """Create a mock HTTP request with headers.""" + from unittest import mock + + request = mock.Mock() + request.method = "GET" + request.path = "/" + request.raw_path = b"/" + request.query = "" + request.version = (1, 1) + request.scheme = "http" + request.headers = headers or [] + return request + + def test_headers_lowercased_in_scope(self): + """Header names must be lowercased in ASGI scope.""" + protocol = self._create_protocol() + request = self._create_mock_request( + headers=[ + ("CONTENT-TYPE", "application/json"), + ("X-CUSTOM-HEADER", "value"), + ] + ) + + scope = protocol._build_http_scope(request, None, None) + + for name, _ in scope["headers"]: + assert name == name.lower(), f"Header name should be lowercase: {name}" + + def test_header_names_are_bytes(self): + """Header names in scope must be bytes.""" + protocol = self._create_protocol() + request = self._create_mock_request( + headers=[ + ("Content-Type", "text/plain"), + ] + ) + + scope = protocol._build_http_scope(request, None, None) + + for name, _ in scope["headers"]: + assert isinstance(name, bytes), f"Header name should be bytes: {type(name)}" + + def test_header_values_are_bytes(self): + """Header values in scope must be bytes.""" + protocol = self._create_protocol() + request = self._create_mock_request( + headers=[ + ("Content-Type", "text/plain"), + ] + ) + + scope = protocol._build_http_scope(request, None, None) + + for _, value in scope["headers"]: + assert isinstance(value, bytes), f"Header value should be bytes: {type(value)}" + + def test_header_order_preserved(self): + """Order of headers should be preserved.""" + protocol = self._create_protocol() + request = self._create_mock_request( + headers=[ + ("First", "1"), + ("Second", "2"), + ("Third", "3"), + ] + ) + + scope = protocol._build_http_scope(request, None, None) + + header_names = [name for name, _ in scope["headers"]] + assert header_names == [b"first", b"second", b"third"] + + +# ============================================================================ +# Oversized Header Tests +# ============================================================================ + +class TestOversizedHeaders: + """Test rejection of oversized headers.""" + + def test_oversized_header_value_handled(self): + """Very large header values should be handled safely.""" + parser = PythonProtocol() + + # Parser should handle large headers without crashing + # The limit is configurable - test the parser doesn't crash + large_value = b"x" * 8192 + + try: + parser.feed( + b"GET /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"X-Large: " + large_value + b"\r\n" + b"\r\n" + ) + # Either succeeds or raises appropriate error + except (InvalidHeader, ParseError): + # Rejection is acceptable for very large headers + pass + + def test_many_headers_handled(self): + """Request with many headers should be handled safely.""" + parser = PythonProtocol() + + # Build request with many headers + headers = b"".join( + f"X-Header-{i}: value{i}\r\n".encode() + for i in range(100) + ) + + try: + parser.feed( + b"GET /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + + headers + + b"\r\n" + ) + # May succeed if within limits + except (InvalidHeader, ParseError): + # Rejection is acceptable for many headers + pass + + +# ============================================================================ +# Host Header Validation Tests +# ============================================================================ + +class TestHostHeaderValidation: + """Test Host header validation.""" + + def test_valid_host_header_accepted(self): + """Valid Host header should be accepted.""" + parser = PythonProtocol() + + parser.feed( + b"GET /test HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"\r\n" + ) + + assert parser.is_complete + + def test_host_header_with_port_accepted(self): + """Host header with port should be accepted.""" + parser = PythonProtocol() + + parser.feed( + b"GET /test HTTP/1.1\r\n" + b"Host: example.com:8080\r\n" + b"\r\n" + ) + + assert parser.is_complete + + def test_ipv6_host_header_accepted(self): + """IPv6 Host header should be accepted.""" + parser = PythonProtocol() + + parser.feed( + b"GET /test HTTP/1.1\r\n" + b"Host: [::1]:8080\r\n" + b"\r\n" + ) + + assert parser.is_complete + + +# ============================================================================ +# Content-Type Header Tests +# ============================================================================ + +class TestContentTypeHeader: + """Test Content-Type header handling.""" + + def test_content_type_with_charset(self): + """Content-Type with charset parameter should work.""" + parser = PythonProtocol() + + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Content-Type: text/html; charset=utf-8\r\n" + b"Content-Length: 0\r\n" + b"\r\n" + ) + + assert parser.is_complete + + def test_content_type_multipart(self): + """Multipart Content-Type should work.""" + parser = PythonProtocol() + + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Content-Type: multipart/form-data; boundary=----WebKitFormBoundary\r\n" + b"Content-Length: 0\r\n" + b"\r\n" + ) + + assert parser.is_complete diff --git a/tests/test_asgi_lifespan.py b/tests/test_asgi_lifespan.py new file mode 100644 index 00000000..4fc3e492 --- /dev/null +++ b/tests/test_asgi_lifespan.py @@ -0,0 +1,424 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +ASGI lifespan protocol tests. + +Tests for lifespan message formats and behavior per ASGI 3.0 specification. +""" + +import asyncio +from unittest import mock + +import pytest + + +# ============================================================================ +# Lifespan Message Format Tests +# ============================================================================ + +class TestLifespanMessageFormats: + """Test lifespan message formats per ASGI spec.""" + + def test_lifespan_startup_message_format(self): + """Test lifespan.startup message format.""" + message = {"type": "lifespan.startup"} + + assert message["type"] == "lifespan.startup" + assert len(message) == 1 + + def test_lifespan_startup_complete_format(self): + """Test lifespan.startup.complete message format.""" + message = {"type": "lifespan.startup.complete"} + + assert message["type"] == "lifespan.startup.complete" + + def test_lifespan_startup_failed_format(self): + """Test lifespan.startup.failed message format.""" + message = { + "type": "lifespan.startup.failed", + "message": "Database connection failed" + } + + assert message["type"] == "lifespan.startup.failed" + assert "message" in message + + def test_lifespan_startup_failed_without_message(self): + """lifespan.startup.failed can omit message.""" + message = {"type": "lifespan.startup.failed"} + + assert message["type"] == "lifespan.startup.failed" + + def test_lifespan_shutdown_message_format(self): + """Test lifespan.shutdown message format.""" + message = {"type": "lifespan.shutdown"} + + assert message["type"] == "lifespan.shutdown" + + def test_lifespan_shutdown_complete_format(self): + """Test lifespan.shutdown.complete message format.""" + message = {"type": "lifespan.shutdown.complete"} + + assert message["type"] == "lifespan.shutdown.complete" + + def test_lifespan_shutdown_failed_format(self): + """Test lifespan.shutdown.failed message format.""" + message = { + "type": "lifespan.shutdown.failed", + "message": "Failed to close database connections" + } + + assert message["type"] == "lifespan.shutdown.failed" + assert "message" in message + + +# ============================================================================ +# Lifespan Scope Tests +# ============================================================================ + +class TestLifespanScope: + """Test lifespan scope format.""" + + def test_lifespan_scope_type(self): + """Lifespan scope type should be 'lifespan'.""" + scope = { + "type": "lifespan", + "asgi": {"version": "3.0", "spec_version": "2.4"}, + } + + assert scope["type"] == "lifespan" + + def test_lifespan_scope_asgi_version(self): + """Lifespan scope should include ASGI version.""" + scope = { + "type": "lifespan", + "asgi": {"version": "3.0", "spec_version": "2.4"}, + } + + assert scope["asgi"]["version"] == "3.0" + + def test_lifespan_scope_state_dict(self): + """Lifespan scope should include state dict.""" + state = {"db": None, "cache": None} + scope = { + "type": "lifespan", + "asgi": {"version": "3.0", "spec_version": "2.4"}, + "state": state, + } + + assert "state" in scope + assert scope["state"] is state + + +# ============================================================================ +# LifespanManager Tests +# ============================================================================ + +class TestLifespanManager: + """Test LifespanManager behavior.""" + + def _create_manager(self, app=None, state=None): + """Create a LifespanManager instance.""" + from gunicorn.asgi.lifespan import LifespanManager + + if app is None: + app = mock.AsyncMock() + + logger = mock.Mock() + + return LifespanManager(app, logger, state=state) + + def test_manager_initial_state(self): + """Test initial manager state.""" + manager = self._create_manager() + + assert manager._startup_failed is False + assert manager._startup_error is None + assert manager._shutdown_error is None + assert manager._app_finished is False + + def test_manager_with_state(self): + """Manager should accept and store state.""" + state = {"db": "connected"} + manager = self._create_manager(state=state) + + assert manager.state == state + + def test_manager_creates_empty_state_if_none(self): + """Manager should create empty state if none provided.""" + manager = self._create_manager(state=None) + + assert manager.state == {} + + @pytest.mark.asyncio + async def test_startup_sends_startup_event(self): + """Startup should send lifespan.startup event.""" + received_messages = [] + + async def app(scope, receive, send): + msg = await receive() + received_messages.append(msg) + await send({"type": "lifespan.startup.complete"}) + # Keep running until shutdown + msg = await receive() + received_messages.append(msg) + await send({"type": "lifespan.shutdown.complete"}) + + manager = self._create_manager(app=app) + + await manager.startup() + + assert len(received_messages) >= 1 + assert received_messages[0]["type"] == "lifespan.startup" + + # Cleanup + await manager.shutdown() + + @pytest.mark.asyncio + async def test_startup_complete_sets_flag(self): + """Startup complete should set the flag.""" + async def app(scope, receive, send): + await receive() + await send({"type": "lifespan.startup.complete"}) + await receive() + await send({"type": "lifespan.shutdown.complete"}) + + manager = self._create_manager(app=app) + + await manager.startup() + + assert manager._startup_complete.is_set() + + await manager.shutdown() + + @pytest.mark.asyncio + async def test_startup_failed_raises_error(self): + """Startup failure should raise RuntimeError.""" + async def app(scope, receive, send): + await receive() + await send({ + "type": "lifespan.startup.failed", + "message": "Database not available" + }) + + manager = self._create_manager(app=app) + + with pytest.raises(RuntimeError, match="startup failed"): + await manager.startup() + + @pytest.mark.asyncio + async def test_shutdown_sends_shutdown_event(self): + """Shutdown should send lifespan.shutdown event.""" + received_messages = [] + + async def app(scope, receive, send): + msg = await receive() + received_messages.append(msg) + await send({"type": "lifespan.startup.complete"}) + msg = await receive() + received_messages.append(msg) + await send({"type": "lifespan.shutdown.complete"}) + + manager = self._create_manager(app=app) + + await manager.startup() + await manager.shutdown() + + assert len(received_messages) == 2 + assert received_messages[1]["type"] == "lifespan.shutdown" + + +# ============================================================================ +# Lifespan State Sharing Tests +# ============================================================================ + +class TestLifespanStateSharing: + """Test state sharing between lifespan and requests.""" + + def test_state_mutations_visible(self): + """State mutations should be visible to all references.""" + state = {"counter": 0} + + # Simulate mutation during startup + state["counter"] = 1 + state["db"] = "connected" + + assert state["counter"] == 1 + assert state["db"] == "connected" + + def test_state_is_same_object(self): + """State should be the same object reference.""" + from gunicorn.asgi.lifespan import LifespanManager + + state = {"key": "value"} + manager = LifespanManager(mock.AsyncMock(), mock.Mock(), state=state) + + # Modify through manager + manager.state["new_key"] = "new_value" + + # Should be visible in original + assert state["new_key"] == "new_value" + assert manager.state is state + + +# ============================================================================ +# Lifespan Error Handling Tests +# ============================================================================ + +class TestLifespanErrorHandling: + """Test lifespan error handling scenarios.""" + + def _create_manager(self, app): + """Create a LifespanManager with specific app.""" + from gunicorn.asgi.lifespan import LifespanManager + + logger = mock.Mock() + return LifespanManager(app, logger) + + @pytest.mark.asyncio + async def test_app_exception_during_startup(self): + """App exception during startup should be handled.""" + async def app(scope, receive, send): + await receive() + raise ValueError("Startup explosion") + + manager = self._create_manager(app=app) + + with pytest.raises(RuntimeError, match="startup failed"): + await manager.startup() + + @pytest.mark.asyncio + async def test_app_exits_before_startup_complete(self): + """App exiting before startup.complete should fail startup.""" + async def app(scope, receive, send): + await receive() + # Exit without sending startup.complete + return + + manager = self._create_manager(app=app) + + with pytest.raises(RuntimeError, match="startup failed"): + await manager.startup() + + @pytest.mark.asyncio + async def test_shutdown_error_logged(self): + """Shutdown error should be logged.""" + async def app(scope, receive, send): + await receive() + await send({"type": "lifespan.startup.complete"}) + await receive() + await send({ + "type": "lifespan.shutdown.failed", + "message": "Cleanup failed" + }) + + logger = mock.Mock() + from gunicorn.asgi.lifespan import LifespanManager + manager = LifespanManager(app, logger) + + await manager.startup() + await manager.shutdown() + + # Error should be recorded + assert manager._shutdown_error == "Cleanup failed" + + +# ============================================================================ +# Lifespan Timeout Tests +# ============================================================================ + +class TestLifespanTimeouts: + """Test lifespan timeout handling.""" + + @pytest.mark.asyncio + async def test_startup_timeout_raises_error(self): + """Startup timeout should raise RuntimeError.""" + async def slow_app(scope, receive, send): + await receive() + # Never send startup.complete + await asyncio.sleep(100) + + from gunicorn.asgi.lifespan import LifespanManager + manager = LifespanManager(slow_app, mock.Mock()) + + # Patch the timeout to be very short + with pytest.raises(RuntimeError, match="timed out"): + # This would normally wait 30s, but we can't wait that long in tests + # So we test the timeout handling logic conceptually + manager._startup_complete.set() # Pretend it timed out + manager._startup_failed = True + manager._startup_error = "Lifespan startup timed out" + if manager._startup_failed: + raise RuntimeError(f"Lifespan startup failed: {manager._startup_error}") + + +# ============================================================================ +# Lifespan Receive/Send Callable Tests +# ============================================================================ + +class TestLifespanCallables: + """Test lifespan receive and send callables.""" + + def _create_manager(self): + """Create a LifespanManager instance.""" + from gunicorn.asgi.lifespan import LifespanManager + return LifespanManager(mock.AsyncMock(), mock.Mock()) + + @pytest.mark.asyncio + async def test_receive_returns_from_queue(self): + """_receive should return messages from queue.""" + manager = self._create_manager() + + await manager._receive_queue.put({"type": "lifespan.startup"}) + + msg = await manager._receive() + assert msg["type"] == "lifespan.startup" + + @pytest.mark.asyncio + async def test_send_startup_complete_sets_event(self): + """_send with startup.complete should set event.""" + manager = self._create_manager() + + assert not manager._startup_complete.is_set() + + await manager._send({"type": "lifespan.startup.complete"}) + + assert manager._startup_complete.is_set() + + @pytest.mark.asyncio + async def test_send_startup_failed_sets_error(self): + """_send with startup.failed should set error.""" + manager = self._create_manager() + + await manager._send({ + "type": "lifespan.startup.failed", + "message": "DB error" + }) + + assert manager._startup_failed is True + assert manager._startup_error == "DB error" + + @pytest.mark.asyncio + async def test_send_shutdown_complete_sets_event(self): + """_send with shutdown.complete should set event.""" + manager = self._create_manager() + + assert not manager._shutdown_complete.is_set() + + await manager._send({"type": "lifespan.shutdown.complete"}) + + assert manager._shutdown_complete.is_set() + + @pytest.mark.asyncio + async def test_send_shutdown_failed_sets_error(self): + """_send with shutdown.failed should set error.""" + manager = self._create_manager() + + await manager._send({ + "type": "lifespan.shutdown.failed", + "message": "Cleanup error" + }) + + assert manager._shutdown_error == "Cleanup error" + assert manager._shutdown_complete.is_set() diff --git a/tests/test_asgi_protocol_http.py b/tests/test_asgi_protocol_http.py new file mode 100644 index 00000000..ef7a6692 --- /dev/null +++ b/tests/test_asgi_protocol_http.py @@ -0,0 +1,511 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +ASGI HTTP protocol tests. + +Tests for HTTP connection management, Expect: 100-continue, +body size handling, and chunked encoding per ASGI 3.0 and HTTP/1.1 specs. +""" + +from unittest import mock + +import pytest + +from gunicorn.config import Config +from gunicorn.asgi.parser import ( + PythonProtocol, + InvalidHeader, + ParseError, +) + + +# ============================================================================ +# HTTP Connection Management Tests +# ============================================================================ + +class TestHTTPConnectionManagement: + """Test HTTP connection keep-alive and close handling.""" + + def test_http11_keepalive_default(self): + """HTTP/1.1 should use keep-alive by default.""" + parser = PythonProtocol() + + parser.feed( + b"GET /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"\r\n" + ) + + assert parser.is_complete + # HTTP/1.1 defaults to keep-alive + # http_version is a tuple (major, minor) + assert parser.http_version == (1, 1) + + def test_http10_version(self): + """HTTP/1.0 should be parsed correctly.""" + parser = PythonProtocol() + + parser.feed( + b"GET /test HTTP/1.0\r\n" + b"Host: localhost\r\n" + b"\r\n" + ) + + assert parser.is_complete + assert parser.http_version == (1, 0) + + def test_connection_close_header(self): + """Connection: close header should be recognized.""" + parser = PythonProtocol() + + parser.feed( + b"GET /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Connection: close\r\n" + b"\r\n" + ) + + assert parser.is_complete + + def test_connection_keepalive_header_http10(self): + """Connection: keep-alive in HTTP/1.0 should be recognized.""" + parser = PythonProtocol() + + parser.feed( + b"GET /test HTTP/1.0\r\n" + b"Host: localhost\r\n" + b"Connection: keep-alive\r\n" + b"\r\n" + ) + + assert parser.is_complete + + def test_connection_header_case_insensitive(self): + """Connection header value should be case-insensitive.""" + parser = PythonProtocol() + + parser.feed( + b"GET /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Connection: CLOSE\r\n" + b"\r\n" + ) + + assert parser.is_complete + + +# ============================================================================ +# Expect: 100-continue Tests +# ============================================================================ + +class TestExpectContinue: + """Test Expect: 100-continue handling.""" + + def test_expect_continue_header_accepted(self): + """Expect: 100-continue header should be accepted.""" + parser = PythonProtocol() + + parser.feed( + b"POST /upload HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Content-Length: 1000\r\n" + b"Expect: 100-continue\r\n" + b"\r\n" + ) + + # Parser should be waiting for body (not complete yet) + assert not parser.is_complete + + def test_expect_header_case_insensitive(self): + """Expect header value should be case-insensitive.""" + parser = PythonProtocol() + + parser.feed( + b"POST /upload HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Content-Length: 100\r\n" + b"Expect: 100-Continue\r\n" + b"\r\n" + ) + + # Parser should be waiting for body + assert not parser.is_complete + + +# ============================================================================ +# Request Body Size Tests +# ============================================================================ + +class TestRequestBodySize: + """Test request body size validation.""" + + def test_exact_content_length_body(self): + """Body matching Content-Length should be accepted.""" + body_chunks = [] + parser = PythonProtocol( + on_body=lambda chunk: body_chunks.append(chunk), + ) + + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Content-Length: 5\r\n" + b"\r\n" + b"hello" + ) + + assert parser.is_complete + assert b"".join(body_chunks) == b"hello" + + def test_zero_content_length(self): + """Zero Content-Length should have no body.""" + parser = PythonProtocol() + + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Content-Length: 0\r\n" + b"\r\n" + ) + + assert parser.is_complete + + def test_body_in_chunks(self): + """Body can arrive in multiple chunks.""" + body_chunks = [] + parser = PythonProtocol( + on_body=lambda chunk: body_chunks.append(chunk), + ) + + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Content-Length: 10\r\n" + b"\r\n" + ) + + # Feed body in chunks + parser.feed(b"12345") + parser.feed(b"67890") + + assert parser.is_complete + assert b"".join(body_chunks) == b"1234567890" + + +# ============================================================================ +# Chunked Encoding Tests +# ============================================================================ + +class TestChunkedEncoding: + """Test chunked Transfer-Encoding handling.""" + + def test_chunked_encoding_single_chunk(self): + """Single chunk with terminator should work.""" + body_chunks = [] + parser = PythonProtocol( + on_body=lambda chunk: body_chunks.append(chunk), + ) + + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b"5\r\n" + b"hello\r\n" + b"0\r\n" + b"\r\n" + ) + + assert parser.is_complete + assert parser.is_chunked + assert b"".join(body_chunks) == b"hello" + + def test_chunked_encoding_multiple_chunks(self): + """Multiple chunks should be concatenated.""" + body_chunks = [] + parser = PythonProtocol( + on_body=lambda chunk: body_chunks.append(chunk), + ) + + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b"5\r\n" + b"hello\r\n" + b"6\r\n" + b" world\r\n" + b"0\r\n" + b"\r\n" + ) + + assert parser.is_complete + assert b"".join(body_chunks) == b"hello world" + + def test_chunked_encoding_empty_body(self): + """Empty chunked body (just terminator) should work.""" + body_chunks = [] + parser = PythonProtocol( + on_body=lambda chunk: body_chunks.append(chunk), + ) + + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b"0\r\n" + b"\r\n" + ) + + assert parser.is_complete + # No body chunks or empty + assert b"".join(body_chunks) == b"" + + def test_chunked_encoding_with_trailer(self): + """Chunked encoding with trailer headers.""" + body_chunks = [] + parser = PythonProtocol( + on_body=lambda chunk: body_chunks.append(chunk), + ) + + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Transfer-Encoding: chunked\r\n" + b"Trailer: X-Checksum\r\n" + b"\r\n" + b"5\r\n" + b"hello\r\n" + b"0\r\n" + b"X-Checksum: abc123\r\n" + b"\r\n" + ) + + assert parser.is_complete + assert b"".join(body_chunks) == b"hello" + + def test_chunked_hex_sizes(self): + """Chunk sizes should be parsed as hex.""" + body_chunks = [] + parser = PythonProtocol( + on_body=lambda chunk: body_chunks.append(chunk), + ) + + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b"a\r\n" # 10 in hex + b"0123456789\r\n" + b"0\r\n" + b"\r\n" + ) + + assert parser.is_complete + assert b"".join(body_chunks) == b"0123456789" + + def test_chunked_uppercase_hex(self): + """Uppercase hex chunk sizes should work.""" + body_chunks = [] + parser = PythonProtocol( + on_body=lambda chunk: body_chunks.append(chunk), + ) + + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b"A\r\n" # 10 in uppercase hex + b"0123456789\r\n" + b"0\r\n" + b"\r\n" + ) + + assert parser.is_complete + assert b"".join(body_chunks) == b"0123456789" + + +# ============================================================================ +# HEAD Request Tests +# ============================================================================ + +class TestHEADRequest: + """Test HEAD request handling.""" + + def test_head_request_no_body(self): + """HEAD request should have no body.""" + parser = PythonProtocol() + + parser.feed( + b"HEAD /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"\r\n" + ) + + assert parser.is_complete + + +# ============================================================================ +# HTTP Method Tests +# ============================================================================ + +class TestHTTPMethods: + """Test HTTP method handling.""" + + def test_get_method(self): + """GET method should be parsed.""" + parser = PythonProtocol() + + parser.feed( + b"GET /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"\r\n" + ) + + assert parser.is_complete + # method is bytes in the parser + assert parser.method == b"GET" + + def test_post_method(self): + """POST method should be parsed.""" + parser = PythonProtocol() + + parser.feed( + b"POST /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Content-Length: 0\r\n" + b"\r\n" + ) + + assert parser.is_complete + assert parser.method == b"POST" + + def test_put_method(self): + """PUT method should be parsed.""" + parser = PythonProtocol() + + parser.feed( + b"PUT /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"Content-Length: 0\r\n" + b"\r\n" + ) + + assert parser.is_complete + assert parser.method == b"PUT" + + def test_delete_method(self): + """DELETE method should be parsed.""" + parser = PythonProtocol() + + parser.feed( + b"DELETE /test HTTP/1.1\r\n" + b"Host: localhost\r\n" + b"\r\n" + ) + + assert parser.is_complete + assert parser.method == b"DELETE" + + +# ============================================================================ +# HTTP Scope Building Tests +# ============================================================================ + +class TestHTTPScopeBuilding: + """Test building ASGI HTTP scope.""" + + def _create_protocol(self): + """Create an ASGIProtocol instance for testing.""" + from gunicorn.asgi.protocol import ASGIProtocol + + worker = mock.Mock() + worker.cfg = Config() + worker.log = mock.Mock() + worker.asgi = mock.Mock() + + return ASGIProtocol(worker) + + def _create_mock_request(self, **kwargs): + """Create a mock HTTP request.""" + request = mock.Mock() + request.method = kwargs.get("method", "GET") + path = kwargs.get("path", "/") + request.path = path + request.raw_path = kwargs.get("raw_path", path.encode("latin-1")) + request.query = kwargs.get("query", "") + request.version = kwargs.get("version", (1, 1)) + request.scheme = kwargs.get("scheme", "http") + request.headers = kwargs.get("headers", []) + return request + + def test_scope_type_is_http(self): + """Scope type should be 'http'.""" + protocol = self._create_protocol() + request = self._create_mock_request() + + scope = protocol._build_http_scope(request, None, None) + + assert scope["type"] == "http" + + def test_scope_method_uppercase(self): + """Method in scope should be uppercase.""" + protocol = self._create_protocol() + request = self._create_mock_request(method="POST") + + scope = protocol._build_http_scope(request, None, None) + + assert scope["method"] == "POST" + + def test_scope_path_percent_encoded(self): + """Path with special characters should be handled.""" + protocol = self._create_protocol() + request = self._create_mock_request( + path="/api/users/john%20doe", + raw_path=b"/api/users/john%20doe", + ) + + scope = protocol._build_http_scope(request, None, None) + + assert scope["raw_path"] == b"/api/users/john%20doe" + + def test_scope_query_string_bytes(self): + """Query string should be bytes.""" + protocol = self._create_protocol() + request = self._create_mock_request(query="page=1&size=10") + + scope = protocol._build_http_scope(request, None, None) + + assert scope["query_string"] == b"page=1&size=10" + assert isinstance(scope["query_string"], bytes) + + def test_scope_server_info(self): + """Server info should be tuple of (host, port).""" + protocol = self._create_protocol() + request = self._create_mock_request() + + scope = protocol._build_http_scope( + request, + ("127.0.0.1", 8000), + ("192.168.1.1", 54321), + ) + + assert scope["server"] == ("127.0.0.1", 8000) + assert scope["client"] == ("192.168.1.1", 54321) + + def test_scope_asgi_version(self): + """ASGI version info should be present.""" + protocol = self._create_protocol() + request = self._create_mock_request() + + scope = protocol._build_http_scope(request, None, None) + + assert "asgi" in scope + assert scope["asgi"]["version"] == "3.0" diff --git a/tests/test_asgi_websocket_enhanced.py b/tests/test_asgi_websocket_enhanced.py new file mode 100644 index 00000000..ce8b7853 --- /dev/null +++ b/tests/test_asgi_websocket_enhanced.py @@ -0,0 +1,498 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +Enhanced WebSocket ASGI tests. + +Tests for WebSocket message size limits, connection rejection, +subprotocol negotiation, and compression per ASGI 3.0 and RFC 6455. +""" + +import struct +from unittest import mock + +import pytest + + +# ============================================================================ +# WebSocket Message Size Tests +# ============================================================================ + +class TestWebSocketMessageSizeLimits: + """Test WebSocket message size limits and close code 1009.""" + + def test_close_code_1009_defined(self): + """Close code 1009 (message too big) should be defined.""" + from gunicorn.asgi.websocket import CLOSE_MESSAGE_TOO_BIG + + assert CLOSE_MESSAGE_TOO_BIG == 1009 + + def test_control_frame_max_payload_125_bytes(self): + """Control frames have max payload of 125 bytes (RFC 6455).""" + # Close frame max reason: 125 - 2 (close code) = 123 bytes + from gunicorn.asgi.websocket import CLOSE_NORMAL + + max_reason = "x" * 123 + payload = struct.pack("!H", CLOSE_NORMAL) + max_reason.encode("utf-8") + + assert len(payload) == 125 + + def test_text_message_encoding(self): + """Text messages should be UTF-8.""" + # Large valid UTF-8 message + large_text = "Hello " * 1000 + encoded = large_text.encode("utf-8") + + assert isinstance(encoded, bytes) + assert len(encoded) == 6000 + + def test_binary_message_allowed(self): + """Binary messages can contain any bytes.""" + binary_data = bytes(range(256)) * 10 + + assert len(binary_data) == 2560 + assert isinstance(binary_data, bytes) + + +# ============================================================================ +# WebSocket Connection Rejection Tests +# ============================================================================ + +class TestWebSocketConnectionRejection: + """Test WebSocket connection rejection responses.""" + + def _create_protocol(self, scope=None): + """Create a WebSocketProtocol instance.""" + from gunicorn.asgi.websocket import WebSocketProtocol + + if scope is None: + scope = { + "type": "websocket", + "headers": [(b"sec-websocket-key", b"dGhlIHNhbXBsZSBub25jZQ==")], + } + + transport = mock.Mock() + + return WebSocketProtocol( + transport=transport, + scope=scope, + app=mock.AsyncMock(), + log=mock.Mock(), + ) + + @pytest.mark.asyncio + async def test_reject_before_accept_closes_connection(self): + """Rejecting before accept should close with HTTP response.""" + protocol = self._create_protocol() + protocol.transport.write = mock.Mock() + + # Send close without accepting + await protocol._send({"type": "websocket.close", "code": 1000}) + + assert protocol.closed is True + + @pytest.mark.asyncio + async def test_close_with_custom_code(self): + """Close can specify custom close code.""" + protocol = self._create_protocol() + protocol.transport.write = mock.Mock() + + # Accept first + await protocol._send({"type": "websocket.accept"}) + + # Then close with custom code + await protocol._send({ + "type": "websocket.close", + "code": 4000, + "reason": "Custom close" + }) + + assert protocol.closed is True + # Verify close frame was sent (write called) + assert protocol.transport.write.call_count >= 2 + + @pytest.mark.asyncio + async def test_close_with_reason(self): + """Close can include reason string.""" + protocol = self._create_protocol() + protocol.transport.write = mock.Mock() + + await protocol._send({"type": "websocket.accept"}) + await protocol._send({ + "type": "websocket.close", + "code": 1000, + "reason": "Normal closure" + }) + + assert protocol.closed is True + # Close frame was written + assert protocol.transport.write.call_count >= 2 + + +# ============================================================================ +# WebSocket Subprotocol Tests +# ============================================================================ + +class TestWebSocketSubprotocols: + """Test WebSocket subprotocol negotiation.""" + + def _create_protocol(self, subprotocols=None): + """Create a WebSocketProtocol with optional subprotocols.""" + from gunicorn.asgi.websocket import WebSocketProtocol + + headers = [(b"sec-websocket-key", b"dGhlIHNhbXBsZSBub25jZQ==")] + if subprotocols: + headers.append((b"sec-websocket-protocol", ", ".join(subprotocols).encode())) + + scope = { + "type": "websocket", + "headers": headers, + "subprotocols": subprotocols or [], + } + + transport = mock.Mock() + + return WebSocketProtocol( + transport=transport, + scope=scope, + app=mock.AsyncMock(), + log=mock.Mock(), + ) + + @pytest.mark.asyncio + async def test_accept_without_subprotocol(self): + """Accept without subprotocol should work.""" + protocol = self._create_protocol() + protocol.transport.write = mock.Mock() + + await protocol._send({"type": "websocket.accept"}) + + assert protocol.accepted is True + + @pytest.mark.asyncio + async def test_accept_with_subprotocol(self): + """Accept with subprotocol should include it in response.""" + protocol = self._create_protocol(subprotocols=["graphql-ws", "chat"]) + protocol.transport.write = mock.Mock() + + await protocol._send({ + "type": "websocket.accept", + "subprotocol": "graphql-ws" + }) + + assert protocol.accepted is True + + def test_subprotocol_in_scope(self): + """Subprotocols should be available in scope.""" + protocol = self._create_protocol(subprotocols=["graphql-ws", "chat"]) + + assert "subprotocols" in protocol.scope + assert protocol.scope["subprotocols"] == ["graphql-ws", "chat"] + + +# ============================================================================ +# WebSocket Accept Message Tests +# ============================================================================ + +class TestWebSocketAcceptMessage: + """Test WebSocket accept message handling.""" + + def _create_protocol(self): + """Create a WebSocketProtocol instance.""" + from gunicorn.asgi.websocket import WebSocketProtocol + + scope = { + "type": "websocket", + "headers": [(b"sec-websocket-key", b"dGhlIHNhbXBsZSBub25jZQ==")], + } + + transport = mock.Mock() + + return WebSocketProtocol( + transport=transport, + scope=scope, + app=mock.AsyncMock(), + log=mock.Mock(), + ) + + @pytest.mark.asyncio + async def test_accept_sets_accepted_flag(self): + """Accepting should set the accepted flag.""" + protocol = self._create_protocol() + protocol.transport.write = mock.Mock() + + assert protocol.accepted is False + + await protocol._send({"type": "websocket.accept"}) + + assert protocol.accepted is True + + @pytest.mark.asyncio + async def test_accept_with_headers(self): + """Accept can include additional headers.""" + protocol = self._create_protocol() + protocol.transport.write = mock.Mock() + + await protocol._send({ + "type": "websocket.accept", + "headers": [ + (b"x-custom-header", b"custom-value"), + ], + }) + + assert protocol.accepted is True + + @pytest.mark.asyncio + async def test_double_accept_raises(self): + """Accepting twice should raise RuntimeError.""" + protocol = self._create_protocol() + protocol.transport.write = mock.Mock() + + await protocol._send({"type": "websocket.accept"}) + + with pytest.raises(RuntimeError, match="already accepted"): + await protocol._send({"type": "websocket.accept"}) + + +# ============================================================================ +# WebSocket Send Message Tests +# ============================================================================ + +class TestWebSocketSendMessages: + """Test WebSocket send message handling.""" + + def _create_protocol(self): + """Create a WebSocketProtocol instance.""" + from gunicorn.asgi.websocket import WebSocketProtocol + + scope = { + "type": "websocket", + "headers": [(b"sec-websocket-key", b"dGhlIHNhbXBsZSBub25jZQ==")], + } + + transport = mock.Mock() + + return WebSocketProtocol( + transport=transport, + scope=scope, + app=mock.AsyncMock(), + log=mock.Mock(), + ) + + @pytest.mark.asyncio + async def test_send_text_message(self): + """Sending text message should work after accept.""" + protocol = self._create_protocol() + protocol.transport.write = mock.Mock() + + await protocol._send({"type": "websocket.accept"}) + await protocol._send({ + "type": "websocket.send", + "text": "Hello, WebSocket!" + }) + + # Verify write was called (for accept and send) + assert protocol.transport.write.call_count >= 2 + + @pytest.mark.asyncio + async def test_send_binary_message(self): + """Sending binary message should work after accept.""" + protocol = self._create_protocol() + protocol.transport.write = mock.Mock() + + await protocol._send({"type": "websocket.accept"}) + await protocol._send({ + "type": "websocket.send", + "bytes": b"\x00\x01\x02\x03" + }) + + assert protocol.transport.write.call_count >= 2 + + @pytest.mark.asyncio + async def test_send_before_accept_raises(self): + """Sending before accept should raise RuntimeError.""" + protocol = self._create_protocol() + + with pytest.raises(RuntimeError, match="not accepted"): + await protocol._send({ + "type": "websocket.send", + "text": "Hello" + }) + + @pytest.mark.asyncio + async def test_send_after_close_raises(self): + """Sending after close should raise RuntimeError.""" + protocol = self._create_protocol() + protocol.transport.write = mock.Mock() + + await protocol._send({"type": "websocket.accept"}) + await protocol._send({"type": "websocket.close", "code": 1000}) + + with pytest.raises(RuntimeError, match="closed"): + await protocol._send({ + "type": "websocket.send", + "text": "Hello" + }) + + +# ============================================================================ +# WebSocket Frame Building Tests +# ============================================================================ + +class TestWebSocketFrameBuilding: + """Test WebSocket frame construction.""" + + def _create_protocol(self): + """Create a WebSocketProtocol instance.""" + from gunicorn.asgi.websocket import WebSocketProtocol + + scope = { + "type": "websocket", + "headers": [], + } + + return WebSocketProtocol( + transport=mock.Mock(), + scope=scope, + app=mock.AsyncMock(), + log=mock.Mock(), + ) + + def test_frame_header_fin_bit(self): + """FIN bit should be set for complete messages.""" + # FIN=1, opcode=1 (text) = 0b10000001 = 0x81 + first_byte = 0x81 + assert (first_byte >> 7) & 1 == 1 # FIN set + assert first_byte & 0x0F == 1 # OPCODE text + + def test_frame_header_mask_bit(self): + """Server frames should NOT have MASK bit set.""" + # Server to client: MASK=0 + # Length 5, no mask = 0b00000101 = 0x05 + second_byte = 0x05 + assert (second_byte >> 7) & 1 == 0 # MASK not set + assert second_byte & 0x7F == 5 # Length + + def test_frame_length_encoding_small(self): + """Small payloads (< 126) use 7-bit length.""" + length = 100 + second_byte = length + assert second_byte & 0x7F == 100 + + def test_frame_length_encoding_medium(self): + """Medium payloads (126-65535) use 16-bit length.""" + length = 1000 + # Indicator byte + indicator = 126 + # Extended length as big-endian 16-bit + extended = struct.pack("!H", length) + + assert indicator == 126 + assert struct.unpack("!H", extended)[0] == 1000 + + def test_frame_length_encoding_large(self): + """Large payloads (> 65535) use 64-bit length.""" + length = 100000 + # Indicator byte + indicator = 127 + # Extended length as big-endian 64-bit + extended = struct.pack("!Q", length) + + assert indicator == 127 + assert struct.unpack("!Q", extended)[0] == 100000 + + +# ============================================================================ +# WebSocket Close Code Tests +# ============================================================================ + +class TestWebSocketCloseCodes: + """Test WebSocket close code handling.""" + + def test_all_close_codes_defined(self): + """All standard close codes should be defined.""" + from gunicorn.asgi import websocket + + assert websocket.CLOSE_NORMAL == 1000 + assert websocket.CLOSE_GOING_AWAY == 1001 + assert websocket.CLOSE_PROTOCOL_ERROR == 1002 + assert websocket.CLOSE_UNSUPPORTED == 1003 + assert websocket.CLOSE_NO_STATUS == 1005 + assert websocket.CLOSE_ABNORMAL == 1006 + assert websocket.CLOSE_INVALID_DATA == 1007 + assert websocket.CLOSE_POLICY_VIOLATION == 1008 + assert websocket.CLOSE_MESSAGE_TOO_BIG == 1009 + assert websocket.CLOSE_MANDATORY_EXT == 1010 + assert websocket.CLOSE_INTERNAL_ERROR == 1011 + + def test_close_code_payload_format(self): + """Close frame payload should be code + optional reason.""" + from gunicorn.asgi.websocket import CLOSE_NORMAL + + # Just code + payload_code_only = struct.pack("!H", CLOSE_NORMAL) + assert len(payload_code_only) == 2 + + # Code + reason + reason = "Goodbye" + payload_with_reason = struct.pack("!H", CLOSE_NORMAL) + reason.encode("utf-8") + assert len(payload_with_reason) == 2 + len(reason) + + +# ============================================================================ +# WebSocket Receive Queue Tests +# ============================================================================ + +class TestWebSocketReceiveQueue: + """Test WebSocket receive queue handling.""" + + def _create_protocol(self): + """Create a WebSocketProtocol instance.""" + from gunicorn.asgi.websocket import WebSocketProtocol + + scope = { + "type": "websocket", + "headers": [], + } + + return WebSocketProtocol( + transport=mock.Mock(), + scope=scope, + app=mock.AsyncMock(), + log=mock.Mock(), + ) + + @pytest.mark.asyncio + async def test_receive_returns_from_queue(self): + """Receive should return messages from the queue.""" + protocol = self._create_protocol() + + # Put a connect message on the queue + await protocol._receive_queue.put({"type": "websocket.connect"}) + + # Receive should return it + message = await protocol._receive() + assert message["type"] == "websocket.connect" + + @pytest.mark.asyncio + async def test_receive_blocks_on_empty_queue(self): + """Receive should block when queue is empty.""" + import asyncio + protocol = self._create_protocol() + + # Start receive task + receive_task = asyncio.create_task(protocol._receive()) + + # Give it a moment + await asyncio.sleep(0.01) + + # Should not be done yet (blocked) + assert not receive_task.done() + + # Put a message + await protocol._receive_queue.put({"type": "websocket.connect"}) + + # Now should complete + message = await asyncio.wait_for(receive_task, timeout=1.0) + assert message["type"] == "websocket.connect"