#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.

"""
Tests for ASGI worker components.
"""

import asyncio
import io
import ipaddress
from unittest import mock

import pytest

from gunicorn.asgi.unreader import AsyncUnreader
from gunicorn.asgi.message import AsyncRequest


class MockStreamReader:
    """Mock asyncio.StreamReader for testing.

    Serves bytes from an in-memory buffer; ``pos`` is the read cursor into
    ``data``.
    """

    def __init__(self, data):
        self.data = data
        self.pos = 0

    async def read(self, size=-1):
        """Return up to ``size`` bytes, or everything left when ``size < 0``.

        Returns ``b""`` once the buffer is exhausted.
        """
        if self.pos >= len(self.data):
            return b""
        if size < 0:
            result = self.data[self.pos:]
        else:
            result = self.data[self.pos:self.pos + size]
        # Advance by what was actually returned so the cursor never runs
        # past the end of the buffer when ``size`` exceeds the remainder.
        self.pos += len(result)
        return result

    async def readexactly(self, n):
        """Return exactly ``n`` bytes.

        Raises:
            asyncio.IncompleteReadError: fewer than ``n`` bytes remain; the
                exception carries the remaining bytes as its partial data.
        """
        if self.pos + n > len(self.data):
            partial = self.data[self.pos:]
            # Like asyncio.StreamReader, a short read consumes the partial
            # data that is handed back on the exception.
            self.pos = len(self.data)
            raise asyncio.IncompleteReadError(partial, n)
        result = self.data[self.pos:self.pos + n]
        self.pos += n
        return result


class MockConfig:
    """Mock gunicorn config for testing.

    Stand-in for the config object that the tests below pass to
    ``AsyncRequest.parse``.
    """

    def __init__(self):
        self.is_ssl = False
        self.proxy_protocol = "off"
        self.proxy_allow_ips = ["127.0.0.1"]
        self.forwarded_allow_ips = ["127.0.0.1"]
        # Lazily-built caches for the *_allow_networks() methods.
        self._proxy_allow_networks = None
        self._forwarded_allow_networks = None
        self.secure_scheme_headers = {}
        self.forwarder_headers = []
        self.limit_request_line = 8190
        self.limit_request_fields = 100
        self.limit_request_field_size = 8190
        self.permit_unconventional_http_method = False
        self.permit_unconventional_http_version = False
        self.permit_obsolete_folding = False
        self.casefold_http_method = False
        self.strip_header_spaces = False
        self.header_map = "refuse"

    def forwarded_allow_networks(self):
        """Return ``forwarded_allow_ips`` as ``ipaddress`` networks.

        The ``"*"`` wildcard entry is skipped. The list is built on the
        first call and cached afterwards.
        """
        if self._forwarded_allow_networks is None:
            self._forwarded_allow_networks = [
                ipaddress.ip_network(addr)
                for addr in self.forwarded_allow_ips
                if addr != "*"
            ]
        return self._forwarded_allow_networks

    def proxy_allow_networks(self):
        """Return ``proxy_allow_ips`` as ``ipaddress`` networks.

        The ``"*"`` wildcard entry is skipped. The list is built on the
        first call and cached afterwards.
        """
        if self._proxy_allow_networks is None:
            self._proxy_allow_networks = [
                ipaddress.ip_network(addr)
                for addr in self.proxy_allow_ips
                if addr != "*"
            ]
        return self._proxy_allow_networks


# AsyncUnreader Tests

@pytest.mark.asyncio
async def test_async_unreader_read_chunk():
    """Test basic chunk reading."""
    reader = MockStreamReader(b"hello world")
    unreader = AsyncUnreader(reader)

    data = await unreader.read()
    assert data == b"hello world"


@pytest.mark.asyncio
async def test_async_unreader_read_size():
    """Test reading specific size."""
    reader = MockStreamReader(b"hello world")
    unreader = AsyncUnreader(reader)

    data = await unreader.read(5)
    assert data == b"hello"


@pytest.mark.asyncio
async def test_async_unreader_unread():
    """Test unread functionality."""
    reader = MockStreamReader(b"hello world")
    unreader = AsyncUnreader(reader)

    # Read all data
    data = await unreader.read()
    assert data == b"hello world"

    # Unread some data
    unreader.unread(b"world")

    # Read again should get unread data
    data = await unreader.read()
    assert data == b"world"


@pytest.mark.asyncio
async def test_async_unreader_read_zero():
    """Test reading zero bytes."""
    reader = MockStreamReader(b"hello")
    unreader = AsyncUnreader(reader)

    data = await unreader.read(0)
    assert data == b""


@pytest.mark.asyncio
async def test_async_unreader_read_empty():
    """Test reading from empty stream."""
    reader = MockStreamReader(b"")
    unreader = AsyncUnreader(reader)

    data = await unreader.read()
    assert data == b""


# AsyncRequest Tests

@pytest.mark.asyncio
async def test_async_request_simple_get():
    """Test parsing a simple GET request."""
    request_data = b"GET /path HTTP/1.1\r\nHost: localhost\r\n\r\n"
    reader = MockStreamReader(request_data)
    unreader = AsyncUnreader(reader)
    cfg = MockConfig()

    request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))

    assert request.method == "GET"
    assert request.path == "/path"
    assert request.version == (1, 1)
    assert ("HOST", "localhost") in request.headers


@pytest.mark.asyncio
async def test_async_request_with_query():
    """Test parsing request with query string."""
    request_data = b"GET /search?q=test&page=1 HTTP/1.1\r\nHost: localhost\r\n\r\n"
    reader = MockStreamReader(request_data)
    unreader = AsyncUnreader(reader)
    cfg = MockConfig()

    request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))

    assert request.method == "GET"
    assert request.path == "/search"
    assert request.query == "q=test&page=1"


@pytest.mark.asyncio
async def test_async_request_post_with_body():
    """Test parsing POST request with body."""
    request_data = (
        b"POST /submit HTTP/1.1\r\n"
        b"Host: localhost\r\n"
        b"Content-Length: 11\r\n"
        b"\r\n"
        b"hello=world"
    )
    reader = MockStreamReader(request_data)
    unreader = AsyncUnreader(reader)
    cfg = MockConfig()

    request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))

    assert request.method == "POST"
    assert request.path == "/submit"
    assert request.content_length == 11

    # Read body
    body = await request.read_body(100)
    assert body == b"hello=world"


@pytest.mark.asyncio
async def test_async_request_multiple_headers():
    """Test parsing request with multiple headers."""
    request_data = (
        b"GET / HTTP/1.1\r\n"
        b"Host: localhost\r\n"
        b"Accept: text/html\r\n"
        b"Accept-Language: en-US\r\n"
        b"Connection: keep-alive\r\n"
        b"\r\n"
    )
    reader = MockStreamReader(request_data)
    unreader = AsyncUnreader(reader)
    cfg = MockConfig()

    request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))

    assert len(request.headers) == 4
    assert request.get_header("HOST") == "localhost"
    assert request.get_header("ACCEPT") == "text/html"


@pytest.mark.asyncio
async def test_async_request_should_close_http10():
    """Test connection close detection for HTTP/1.0."""
    request_data = b"GET / HTTP/1.0\r\nHost: localhost\r\n\r\n"
    reader = MockStreamReader(request_data)
    unreader = AsyncUnreader(reader)
    cfg = MockConfig()

    request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))

    assert request.version == (1, 0)
    assert request.should_close() is True


@pytest.mark.asyncio
async def test_async_request_should_close_connection_header():
    """Test connection close detection with Connection header."""
    request_data = b"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"
    reader = MockStreamReader(request_data)
    unreader = AsyncUnreader(reader)
    cfg = MockConfig()

    request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))

    assert request.should_close() is True


@pytest.mark.asyncio
async def test_async_request_keepalive():
    """Test keepalive detection."""
    request_data = b"GET / HTTP/1.1\r\nHost: localhost\r\nConnection: keep-alive\r\n\r\n"
    reader = MockStreamReader(request_data)
    unreader = AsyncUnreader(reader)
    cfg = MockConfig()

    request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))

    assert request.should_close() is False


@pytest.mark.asyncio
async def test_async_request_no_body_for_get():
    """Test that GET requests have no body by default."""
    request_data = b"GET / HTTP/1.1\r\nHost: localhost\r\n\r\n"
    reader = MockStreamReader(request_data)
    unreader = AsyncUnreader(reader)
    cfg = MockConfig()

    request = await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))

    assert request.content_length == 0
    body = await request.read_body()
    assert body == b""


# Error handling tests

@pytest.mark.asyncio
async def test_async_request_invalid_method():
    """Test invalid HTTP method detection."""
    from gunicorn.http.errors import InvalidRequestMethod

    request_data = b"ge!t / HTTP/1.1\r\nHost: localhost\r\n\r\n"
    reader = MockStreamReader(request_data)
    unreader = AsyncUnreader(reader)
    cfg = MockConfig()

    with pytest.raises(InvalidRequestMethod):
        await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))


@pytest.mark.asyncio
async def test_async_request_invalid_http_version():
    """Test invalid HTTP version detection."""
    from gunicorn.http.errors import InvalidHTTPVersion

    request_data = b"GET / HTTP/2.0\r\nHost: localhost\r\n\r\n"
    reader = MockStreamReader(request_data)
    unreader = AsyncUnreader(reader)
    cfg = MockConfig()

    with pytest.raises(InvalidHTTPVersion):
        await AsyncRequest.parse(cfg, unreader, ("127.0.0.1", 8000))