Add ASGI test suite enhancement with 134 new tests

New test files covering areas identified as gaps compared to
Daphne and Uvicorn test coverage:

- test_asgi_header_security.py: Header validation, normalization,
  injection prevention
- test_asgi_error_handling.py: Application errors, body receiver
  errors, graceful shutdown
- test_asgi_protocol_http.py: HTTP connection management, chunked
  encoding, methods, scope building
- test_asgi_websocket_enhanced.py: WebSocket message limits,
  connection rejection, subprotocols
- test_asgi_lifespan.py: Lifespan message formats and behavior
- test_asgi_forwarded_headers.py: X-Forwarded-* and proxy header
  handling
This commit is contained in:
Benoit Chesneau 2026-04-03 09:09:16 +02:00
parent 4e9db71aeb
commit 1c82d4b518
6 changed files with 2616 additions and 0 deletions

View File

@ -0,0 +1,394 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
ASGI error handling tests.
Tests for application error scenarios and graceful shutdown behavior
to ensure robust error handling in ASGI applications.
"""
import asyncio
from unittest import mock
import pytest
from gunicorn.config import Config
# ============================================================================
# Application Error Tests
# ============================================================================
class TestApplicationErrors:
"""Test handling of ASGI application errors."""
def _create_protocol(self):
"""Create an ASGIProtocol instance for testing."""
from gunicorn.asgi.protocol import ASGIProtocol
worker = mock.Mock()
worker.cfg = Config()
worker.log = mock.Mock()
worker.asgi = mock.Mock()
worker.nr_conns = 1
worker.loop = mock.Mock()
protocol = ASGIProtocol(worker)
protocol._closed = False
return protocol
def _create_mock_request(self):
"""Create a mock HTTP request."""
request = mock.Mock()
request.method = "GET"
request.path = "/"
request.raw_path = b"/"
request.query = ""
request.version = (1, 1)
request.scheme = "http"
request.headers = []
request.content_length = 0
request.chunked = False
return request
def test_protocol_tracks_closed_state(self):
"""Protocol should track closed state."""
protocol = self._create_protocol()
assert protocol._closed is False
protocol._closed = True
assert protocol._closed is True
def test_connection_lost_sets_closed(self):
"""connection_lost should set closed state."""
protocol = self._create_protocol()
protocol.reader = mock.Mock()
assert protocol._closed is False
protocol.connection_lost(None)
assert protocol._closed is True
def test_connection_lost_with_exception(self):
"""connection_lost handles exception argument gracefully."""
protocol = self._create_protocol()
protocol.reader = mock.Mock()
exc = ConnectionResetError("Connection reset")
protocol.connection_lost(exc)
assert protocol._closed is True
# ============================================================================
# Response Info Tests
# ============================================================================
class TestResponseInfo:
"""Test response info tracking."""
def test_response_info_initial(self):
"""Test initial ASGIResponseInfo values."""
from gunicorn.asgi.protocol import ASGIResponseInfo
info = ASGIResponseInfo(status=200, headers=[], sent=False)
assert info.status == 200
assert info.headers == []
assert info.sent is False
def test_response_info_with_headers(self):
"""Test ASGIResponseInfo with headers."""
from gunicorn.asgi.protocol import ASGIResponseInfo
headers = [
(b"content-type", b"text/plain"),
(b"content-length", b"5"),
]
info = ASGIResponseInfo(status=200, headers=headers, sent=True)
assert info.status == 200
assert len(info.headers) == 2
assert info.sent is True
# ============================================================================
# Body Receiver Error Tests
# ============================================================================
class TestBodyReceiverErrors:
"""Test error handling in BodyReceiver."""
def _create_protocol(self):
"""Create an ASGIProtocol instance for testing."""
from gunicorn.asgi.protocol import ASGIProtocol
worker = mock.Mock()
worker.cfg = Config()
worker.log = mock.Mock()
worker.asgi = mock.Mock()
worker.nr_conns = 1
worker.loop = mock.Mock()
protocol = ASGIProtocol(worker)
protocol._closed = False
return protocol
@pytest.mark.asyncio
async def test_body_receiver_handles_closed_protocol(self):
"""BodyReceiver should handle protocol being closed."""
from gunicorn.asgi.protocol import BodyReceiver
protocol = self._create_protocol()
mock_request = mock.Mock()
mock_request.content_length = 0
mock_request.chunked = False
body_receiver = BodyReceiver(mock_request, protocol)
# Consume the empty body
msg = await body_receiver.receive()
assert msg["type"] == "http.request"
assert msg["more_body"] is False
# Mark protocol as closed
protocol._closed = True
# Signal disconnect
body_receiver.signal_disconnect()
# Receive should return disconnect
msg = await body_receiver.receive()
assert msg == {"type": "http.disconnect"}
@pytest.mark.asyncio
async def test_body_receiver_multiple_signal_disconnect(self):
"""Multiple signal_disconnect calls should be safe."""
from gunicorn.asgi.protocol import BodyReceiver
protocol = self._create_protocol()
mock_request = mock.Mock()
mock_request.content_length = 0
mock_request.chunked = False
body_receiver = BodyReceiver(mock_request, protocol)
# Signal disconnect multiple times - should not raise
body_receiver.signal_disconnect()
body_receiver.signal_disconnect()
body_receiver.signal_disconnect()
assert body_receiver._closed is True
@pytest.mark.asyncio
async def test_body_receiver_feed_after_complete(self):
"""Feeding data after body is complete should be safe."""
from gunicorn.asgi.protocol import BodyReceiver
protocol = self._create_protocol()
mock_request = mock.Mock()
mock_request.content_length = 5
mock_request.chunked = False
body_receiver = BodyReceiver(mock_request, protocol)
# Feed the expected body
body_receiver.feed(b"hello")
body_receiver.set_complete()
# Consume the body
msg = await body_receiver.receive()
assert msg["body"] == b"hello"
assert msg["more_body"] is False
# Feeding more data after complete should be safe
body_receiver.feed(b"extra") # Should not raise
# ============================================================================
# Graceful Shutdown Tests
# ============================================================================
class TestGracefulShutdown:
"""Test graceful shutdown behavior."""
def _create_protocol(self):
"""Create an ASGIProtocol instance for testing."""
from gunicorn.asgi.protocol import ASGIProtocol
worker = mock.Mock()
worker.cfg = Config()
worker.log = mock.Mock()
worker.asgi = mock.Mock()
worker.nr_conns = 1
worker.loop = mock.Mock()
protocol = ASGIProtocol(worker)
protocol._closed = False
return protocol
def test_graceful_shutdown_schedules_cancel(self):
"""Graceful shutdown should schedule task cancellation."""
protocol = self._create_protocol()
# Create a mock task
mock_task = mock.Mock()
mock_task.done.return_value = False
protocol._task = mock_task
protocol.reader = mock.Mock()
# Simulate connection lost
protocol.connection_lost(None)
# Task should NOT be cancelled immediately
mock_task.cancel.assert_not_called()
# Cancellation should be scheduled
protocol.worker.loop.call_later.assert_called_once()
def test_completed_task_not_cancelled(self):
"""Completed tasks should not be cancelled."""
protocol = self._create_protocol()
# Create a mock task that's already done
mock_task = mock.Mock()
mock_task.done.return_value = True
protocol._task = mock_task
protocol.reader = mock.Mock()
# Simulate connection lost
protocol.connection_lost(None)
# Task should not be cancelled
mock_task.cancel.assert_not_called()
# ============================================================================
# Protocol Timeout Tests
# ============================================================================
class TestProtocolTimeouts:
"""Test timeout handling in protocol."""
def _create_protocol(self):
"""Create an ASGIProtocol instance for testing."""
from gunicorn.asgi.protocol import ASGIProtocol
worker = mock.Mock()
worker.cfg = Config()
worker.log = mock.Mock()
worker.asgi = mock.Mock()
worker.nr_conns = 1
worker.loop = mock.Mock()
protocol = ASGIProtocol(worker)
protocol._closed = False
return protocol
def test_keepalive_timer_can_be_armed(self):
"""Keepalive timer should be arm-able."""
protocol = self._create_protocol()
# Initially no timer handle
assert protocol._keepalive_handle is None
# Verify the method exists
assert hasattr(protocol, '_arm_keepalive_timer')
assert hasattr(protocol, '_cancel_keepalive_timer')
def test_cancel_keepalive_timer_handles_none(self):
"""Cancelling non-existent timer should be safe."""
protocol = self._create_protocol()
# Should not raise even with no timer
protocol._cancel_keepalive_timer()
protocol._cancel_keepalive_timer() # Multiple calls safe
# ============================================================================
# Request Time Tests
# ============================================================================
class TestRequestTime:
"""Test request time handling."""
def test_request_time_creation(self):
"""_RequestTime should track timing."""
from gunicorn.asgi.protocol import _RequestTime
request_time = _RequestTime(1.5)
# _RequestTime splits into seconds and microseconds
assert hasattr(request_time, 'seconds')
assert hasattr(request_time, 'microseconds')
def test_request_time_conversion(self):
"""_RequestTime should store time as seconds + microseconds."""
from gunicorn.asgi.protocol import _RequestTime
# 1.5 seconds = 1 second + 500000 microseconds
request_time = _RequestTime(1.5)
assert request_time.seconds == 1
assert request_time.microseconds == 500000
def test_request_time_with_zero(self):
"""_RequestTime with zero elapsed time."""
from gunicorn.asgi.protocol import _RequestTime
request_time = _RequestTime(0.0)
assert request_time.seconds == 0
assert request_time.microseconds == 0
# ============================================================================
# Message Validation Tests
# ============================================================================
class TestMessageValidation:
"""Test ASGI message validation."""
def test_response_start_requires_status(self):
"""http.response.start must have status."""
# Valid response start
valid_msg = {
"type": "http.response.start",
"status": 200,
"headers": [],
}
assert valid_msg["type"] == "http.response.start"
assert "status" in valid_msg
def test_response_body_message_format(self):
"""http.response.body format validation."""
# With body
msg_with_body = {
"type": "http.response.body",
"body": b"Hello",
"more_body": False,
}
assert isinstance(msg_with_body["body"], bytes)
# Empty body
msg_empty = {
"type": "http.response.body",
"body": b"",
"more_body": False,
}
assert msg_empty["body"] == b""
def test_disconnect_message_minimal(self):
"""http.disconnect message should be minimal."""
msg = {"type": "http.disconnect"}
assert msg == {"type": "http.disconnect"}
assert len(msg) == 1

View File

@ -0,0 +1,416 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
ASGI forwarded headers tests.
Tests for X-Forwarded-For, X-Forwarded-Proto, and related
proxy header handling in ASGI applications.
"""
from unittest import mock
import pytest
from gunicorn.config import Config
# ============================================================================
# X-Forwarded-For Header Tests
# ============================================================================
class TestXForwardedFor:
"""Test X-Forwarded-For header handling."""
def _create_protocol(self, forwarded_allow_ips=None):
"""Create an ASGIProtocol instance for testing."""
from gunicorn.asgi.protocol import ASGIProtocol
worker = mock.Mock()
worker.cfg = Config()
if forwarded_allow_ips is not None:
worker.cfg.forwarded_allow_ips = forwarded_allow_ips
worker.log = mock.Mock()
worker.asgi = mock.Mock()
return ASGIProtocol(worker)
def _create_mock_request(self, headers=None):
"""Create a mock HTTP request."""
request = mock.Mock()
request.method = "GET"
request.path = "/"
request.raw_path = b"/"
request.query = ""
request.version = (1, 1)
request.scheme = "http"
request.headers = headers or []
return request
def test_x_forwarded_for_in_headers(self):
"""X-Forwarded-For header should be passed through."""
protocol = self._create_protocol()
request = self._create_mock_request(
headers=[
("HOST", "localhost"),
("X-FORWARDED-FOR", "192.168.1.1, 10.0.0.1"),
]
)
scope = protocol._build_http_scope(request, None, None)
# Header should be in scope headers
header_names = [name for name, _ in scope["headers"]]
assert b"x-forwarded-for" in header_names
def test_x_forwarded_for_multiple_addresses(self):
"""X-Forwarded-For can contain multiple addresses."""
protocol = self._create_protocol()
request = self._create_mock_request(
headers=[
("HOST", "localhost"),
("X-FORWARDED-FOR", "203.0.113.195, 70.41.3.18, 150.172.238.178"),
]
)
scope = protocol._build_http_scope(request, None, None)
# Find the header value
xff_value = None
for name, value in scope["headers"]:
if name == b"x-forwarded-for":
xff_value = value
break
assert xff_value == b"203.0.113.195, 70.41.3.18, 150.172.238.178"
# ============================================================================
# X-Forwarded-Proto Header Tests
# ============================================================================
class TestXForwardedProto:
"""Test X-Forwarded-Proto header handling."""
def _create_protocol(self):
"""Create an ASGIProtocol instance for testing."""
from gunicorn.asgi.protocol import ASGIProtocol
worker = mock.Mock()
worker.cfg = Config()
worker.log = mock.Mock()
worker.asgi = mock.Mock()
return ASGIProtocol(worker)
def _create_mock_request(self, headers=None, scheme="http"):
"""Create a mock HTTP request."""
request = mock.Mock()
request.method = "GET"
request.path = "/"
request.raw_path = b"/"
request.query = ""
request.version = (1, 1)
request.scheme = scheme
request.headers = headers or []
return request
def test_x_forwarded_proto_http(self):
"""X-Forwarded-Proto: http should be passed through."""
protocol = self._create_protocol()
request = self._create_mock_request(
headers=[
("HOST", "localhost"),
("X-FORWARDED-PROTO", "http"),
]
)
scope = protocol._build_http_scope(request, None, None)
# Header should be in scope headers
header_dict = {name: value for name, value in scope["headers"]}
assert b"x-forwarded-proto" in header_dict
def test_x_forwarded_proto_https(self):
"""X-Forwarded-Proto: https should be passed through."""
protocol = self._create_protocol()
request = self._create_mock_request(
headers=[
("HOST", "localhost"),
("X-FORWARDED-PROTO", "https"),
]
)
scope = protocol._build_http_scope(request, None, None)
header_dict = {name: value for name, value in scope["headers"]}
assert header_dict[b"x-forwarded-proto"] == b"https"
# ============================================================================
# X-Forwarded-Host Header Tests
# ============================================================================
class TestXForwardedHost:
"""Test X-Forwarded-Host header handling."""
def _create_protocol(self):
"""Create an ASGIProtocol instance for testing."""
from gunicorn.asgi.protocol import ASGIProtocol
worker = mock.Mock()
worker.cfg = Config()
worker.log = mock.Mock()
worker.asgi = mock.Mock()
return ASGIProtocol(worker)
def _create_mock_request(self, headers=None):
"""Create a mock HTTP request."""
request = mock.Mock()
request.method = "GET"
request.path = "/"
request.raw_path = b"/"
request.query = ""
request.version = (1, 1)
request.scheme = "http"
request.headers = headers or []
return request
def test_x_forwarded_host_in_headers(self):
"""X-Forwarded-Host should be passed through."""
protocol = self._create_protocol()
request = self._create_mock_request(
headers=[
("HOST", "backend.internal"),
("X-FORWARDED-HOST", "www.example.com"),
]
)
scope = protocol._build_http_scope(request, None, None)
header_dict = {name: value for name, value in scope["headers"]}
assert b"x-forwarded-host" in header_dict
assert header_dict[b"x-forwarded-host"] == b"www.example.com"
# ============================================================================
# X-Forwarded-Port Header Tests
# ============================================================================
class TestXForwardedPort:
"""Test X-Forwarded-Port header handling."""
def _create_protocol(self):
"""Create an ASGIProtocol instance for testing."""
from gunicorn.asgi.protocol import ASGIProtocol
worker = mock.Mock()
worker.cfg = Config()
worker.log = mock.Mock()
worker.asgi = mock.Mock()
return ASGIProtocol(worker)
def _create_mock_request(self, headers=None):
"""Create a mock HTTP request."""
request = mock.Mock()
request.method = "GET"
request.path = "/"
request.raw_path = b"/"
request.query = ""
request.version = (1, 1)
request.scheme = "http"
request.headers = headers or []
return request
def test_x_forwarded_port_in_headers(self):
"""X-Forwarded-Port should be passed through."""
protocol = self._create_protocol()
request = self._create_mock_request(
headers=[
("HOST", "localhost:8000"),
("X-FORWARDED-PORT", "443"),
]
)
scope = protocol._build_http_scope(request, None, None)
header_dict = {name: value for name, value in scope["headers"]}
assert b"x-forwarded-port" in header_dict
assert header_dict[b"x-forwarded-port"] == b"443"
# ============================================================================
# Forwarded Header (RFC 7239) Tests
# ============================================================================
class TestForwardedHeader:
"""Test Forwarded header (RFC 7239) handling."""
def _create_protocol(self):
"""Create an ASGIProtocol instance for testing."""
from gunicorn.asgi.protocol import ASGIProtocol
worker = mock.Mock()
worker.cfg = Config()
worker.log = mock.Mock()
worker.asgi = mock.Mock()
return ASGIProtocol(worker)
def _create_mock_request(self, headers=None):
"""Create a mock HTTP request."""
request = mock.Mock()
request.method = "GET"
request.path = "/"
request.raw_path = b"/"
request.query = ""
request.version = (1, 1)
request.scheme = "http"
request.headers = headers or []
return request
def test_forwarded_header_in_scope(self):
"""Forwarded header should be passed through."""
protocol = self._create_protocol()
request = self._create_mock_request(
headers=[
("HOST", "localhost"),
("FORWARDED", "for=192.0.2.60;proto=http;by=203.0.113.43"),
]
)
scope = protocol._build_http_scope(request, None, None)
header_dict = {name: value for name, value in scope["headers"]}
assert b"forwarded" in header_dict
def test_forwarded_header_multiple_proxies(self):
"""Forwarded header with multiple proxies."""
protocol = self._create_protocol()
request = self._create_mock_request(
headers=[
("HOST", "localhost"),
("FORWARDED", "for=192.0.2.43, for=198.51.100.178"),
]
)
scope = protocol._build_http_scope(request, None, None)
header_dict = {name: value for name, value in scope["headers"]}
assert header_dict[b"forwarded"] == b"for=192.0.2.43, for=198.51.100.178"
# ============================================================================
# Trusted Proxy Tests
# ============================================================================
class TestTrustedProxy:
"""Test trusted proxy configuration."""
def test_check_trusted_proxy_function_exists(self):
"""_check_trusted_proxy function should exist."""
from gunicorn.asgi.protocol import _check_trusted_proxy
assert callable(_check_trusted_proxy)
def test_normalize_sockaddr_function_exists(self):
"""_normalize_sockaddr function should exist."""
from gunicorn.asgi.protocol import _normalize_sockaddr
assert callable(_normalize_sockaddr)
def test_normalize_sockaddr_ipv4(self):
"""IPv4 address should be normalized."""
from gunicorn.asgi.protocol import _normalize_sockaddr
result = _normalize_sockaddr(("192.168.1.1", 8000))
assert result == ("192.168.1.1", 8000)
def test_normalize_sockaddr_ipv6(self):
"""IPv6 address should be normalized."""
from gunicorn.asgi.protocol import _normalize_sockaddr
# IPv6 sockaddr is a 4-tuple
result = _normalize_sockaddr(("::1", 8000, 0, 0))
assert result == ("::1", 8000)
def test_normalize_sockaddr_none(self):
"""None sockaddr should return None."""
from gunicorn.asgi.protocol import _normalize_sockaddr
result = _normalize_sockaddr(None)
assert result is None
# ============================================================================
# Header Preservation Tests
# ============================================================================
class TestHeaderPreservation:
"""Test that proxy headers are preserved in scope."""
def _create_protocol(self):
"""Create an ASGIProtocol instance for testing."""
from gunicorn.asgi.protocol import ASGIProtocol
worker = mock.Mock()
worker.cfg = Config()
worker.log = mock.Mock()
worker.asgi = mock.Mock()
return ASGIProtocol(worker)
def _create_mock_request(self, headers=None):
"""Create a mock HTTP request."""
request = mock.Mock()
request.method = "GET"
request.path = "/"
request.raw_path = b"/"
request.query = ""
request.version = (1, 1)
request.scheme = "http"
request.headers = headers or []
return request
def test_all_proxy_headers_preserved(self):
"""All standard proxy headers should be preserved."""
protocol = self._create_protocol()
request = self._create_mock_request(
headers=[
("HOST", "localhost"),
("X-FORWARDED-FOR", "192.168.1.1"),
("X-FORWARDED-PROTO", "https"),
("X-FORWARDED-HOST", "example.com"),
("X-FORWARDED-PORT", "443"),
("X-REAL-IP", "10.0.0.1"),
]
)
scope = protocol._build_http_scope(request, None, None)
header_names = {name for name, _ in scope["headers"]}
assert b"x-forwarded-for" in header_names
assert b"x-forwarded-proto" in header_names
assert b"x-forwarded-host" in header_names
assert b"x-forwarded-port" in header_names
assert b"x-real-ip" in header_names
def test_header_values_as_bytes(self):
"""Proxy header values should be bytes."""
protocol = self._create_protocol()
request = self._create_mock_request(
headers=[
("HOST", "localhost"),
("X-FORWARDED-FOR", "192.168.1.1"),
]
)
scope = protocol._build_http_scope(request, None, None)
for name, value in scope["headers"]:
assert isinstance(name, bytes)
assert isinstance(value, bytes)

View File

@ -0,0 +1,373 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
ASGI header security tests.
Tests for header validation, normalization, and injection prevention
to ensure secure HTTP header handling per ASGI 3.0 and RFC 9110/9112.
"""
import pytest
from gunicorn.asgi.parser import (
PythonProtocol,
InvalidHeader,
ParseError,
)
# ============================================================================
# Header Name Validation Tests
# ============================================================================
class TestHeaderNameValidation:
"""Test validation of HTTP header names."""
def test_valid_header_name_accepted(self):
"""Valid header names should be accepted."""
parser = PythonProtocol()
parser.feed(
b"GET /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"X-Custom-Header: value\r\n"
b"Accept-Language: en-US\r\n"
b"\r\n"
)
assert parser.is_complete
def test_header_name_with_null_rejected(self):
"""Header name containing null byte must be rejected."""
parser = PythonProtocol()
with pytest.raises((InvalidHeader, ParseError)):
parser.feed(
b"GET /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"X-Bad\x00Header: value\r\n"
b"\r\n"
)
def test_header_name_with_cr_rejected(self):
"""Header name containing CR must be rejected."""
parser = PythonProtocol()
with pytest.raises((InvalidHeader, ParseError)):
parser.feed(
b"GET /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"X-Bad\rHeader: value\r\n"
b"\r\n"
)
def test_header_name_with_lf_rejected(self):
"""Header name containing LF must be rejected."""
parser = PythonProtocol()
with pytest.raises((InvalidHeader, ParseError)):
parser.feed(
b"GET /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"X-Bad\nHeader: value\r\n"
b"\r\n"
)
def test_empty_header_name_rejected(self):
"""Empty header name must be rejected."""
parser = PythonProtocol()
with pytest.raises((InvalidHeader, ParseError)):
parser.feed(
b"GET /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b": value\r\n"
b"\r\n"
)
# ============================================================================
# Header Value Validation Tests
# ============================================================================
class TestHeaderValueValidation:
"""Test validation of HTTP header values."""
def test_header_value_with_bare_cr_rejected(self):
"""Header value containing bare CR must be rejected."""
parser = PythonProtocol()
# Bare CR (not followed by LF) in header value should be rejected
with pytest.raises((InvalidHeader, ParseError)):
parser.feed(
b"GET /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"X-Bad: value\rmore\r\n"
b"\r\n"
)
def test_header_value_with_bare_lf_rejected(self):
"""Header value containing bare LF must be rejected."""
parser = PythonProtocol()
# Bare LF (not preceded by CR) in header value should be rejected
with pytest.raises((InvalidHeader, ParseError)):
parser.feed(
b"GET /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"X-Bad: value\nmore\r\n"
b"\r\n"
)
def test_header_value_special_characters_allowed(self):
"""Header values may contain special printable characters."""
parser = PythonProtocol()
parser.feed(
b"GET /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Authorization: Bearer abc123!@#$%^&*()_+\r\n"
b"Cookie: session=abc; path=/; domain=.example.com\r\n"
b"\r\n"
)
assert parser.is_complete
def test_header_value_with_tab_allowed(self):
"""Horizontal tab in header value is allowed (OWS)."""
parser = PythonProtocol()
parser.feed(
b"GET /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"X-Tabs: value1\tvalue2\r\n"
b"\r\n"
)
assert parser.is_complete
# ============================================================================
# Header Normalization Tests
# ============================================================================
class TestHeaderNormalization:
"""Test HTTP header normalization per ASGI spec."""
def _create_protocol(self):
"""Create an ASGIProtocol instance for testing."""
from gunicorn.asgi.protocol import ASGIProtocol
from gunicorn.config import Config
from unittest import mock
worker = mock.Mock()
worker.cfg = Config()
worker.log = mock.Mock()
worker.asgi = mock.Mock()
return ASGIProtocol(worker)
def _create_mock_request(self, headers=None):
"""Create a mock HTTP request with headers."""
from unittest import mock
request = mock.Mock()
request.method = "GET"
request.path = "/"
request.raw_path = b"/"
request.query = ""
request.version = (1, 1)
request.scheme = "http"
request.headers = headers or []
return request
def test_headers_lowercased_in_scope(self):
"""Header names must be lowercased in ASGI scope."""
protocol = self._create_protocol()
request = self._create_mock_request(
headers=[
("CONTENT-TYPE", "application/json"),
("X-CUSTOM-HEADER", "value"),
]
)
scope = protocol._build_http_scope(request, None, None)
for name, _ in scope["headers"]:
assert name == name.lower(), f"Header name should be lowercase: {name}"
def test_header_names_are_bytes(self):
"""Header names in scope must be bytes."""
protocol = self._create_protocol()
request = self._create_mock_request(
headers=[
("Content-Type", "text/plain"),
]
)
scope = protocol._build_http_scope(request, None, None)
for name, _ in scope["headers"]:
assert isinstance(name, bytes), f"Header name should be bytes: {type(name)}"
def test_header_values_are_bytes(self):
"""Header values in scope must be bytes."""
protocol = self._create_protocol()
request = self._create_mock_request(
headers=[
("Content-Type", "text/plain"),
]
)
scope = protocol._build_http_scope(request, None, None)
for _, value in scope["headers"]:
assert isinstance(value, bytes), f"Header value should be bytes: {type(value)}"
def test_header_order_preserved(self):
"""Order of headers should be preserved."""
protocol = self._create_protocol()
request = self._create_mock_request(
headers=[
("First", "1"),
("Second", "2"),
("Third", "3"),
]
)
scope = protocol._build_http_scope(request, None, None)
header_names = [name for name, _ in scope["headers"]]
assert header_names == [b"first", b"second", b"third"]
# ============================================================================
# Oversized Header Tests
# ============================================================================
class TestOversizedHeaders:
"""Test rejection of oversized headers."""
def test_oversized_header_value_handled(self):
"""Very large header values should be handled safely."""
parser = PythonProtocol()
# Parser should handle large headers without crashing
# The limit is configurable - test the parser doesn't crash
large_value = b"x" * 8192
try:
parser.feed(
b"GET /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"X-Large: " + large_value + b"\r\n"
b"\r\n"
)
# Either succeeds or raises appropriate error
except (InvalidHeader, ParseError):
# Rejection is acceptable for very large headers
pass
def test_many_headers_handled(self):
"""Request with many headers should be handled safely."""
parser = PythonProtocol()
# Build request with many headers
headers = b"".join(
f"X-Header-{i}: value{i}\r\n".encode()
for i in range(100)
)
try:
parser.feed(
b"GET /test HTTP/1.1\r\n"
b"Host: localhost\r\n" +
headers +
b"\r\n"
)
# May succeed if within limits
except (InvalidHeader, ParseError):
# Rejection is acceptable for many headers
pass
# ============================================================================
# Host Header Validation Tests
# ============================================================================
class TestHostHeaderValidation:
"""Test Host header validation."""
def test_valid_host_header_accepted(self):
"""Valid Host header should be accepted."""
parser = PythonProtocol()
parser.feed(
b"GET /test HTTP/1.1\r\n"
b"Host: example.com\r\n"
b"\r\n"
)
assert parser.is_complete
def test_host_header_with_port_accepted(self):
"""Host header with port should be accepted."""
parser = PythonProtocol()
parser.feed(
b"GET /test HTTP/1.1\r\n"
b"Host: example.com:8080\r\n"
b"\r\n"
)
assert parser.is_complete
def test_ipv6_host_header_accepted(self):
"""IPv6 Host header should be accepted."""
parser = PythonProtocol()
parser.feed(
b"GET /test HTTP/1.1\r\n"
b"Host: [::1]:8080\r\n"
b"\r\n"
)
assert parser.is_complete
# ============================================================================
# Content-Type Header Tests
# ============================================================================
class TestContentTypeHeader:
"""Test Content-Type header handling."""
def test_content_type_with_charset(self):
"""Content-Type with charset parameter should work."""
parser = PythonProtocol()
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Content-Type: text/html; charset=utf-8\r\n"
b"Content-Length: 0\r\n"
b"\r\n"
)
assert parser.is_complete
def test_content_type_multipart(self):
"""Multipart Content-Type should work."""
parser = PythonProtocol()
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Content-Type: multipart/form-data; boundary=----WebKitFormBoundary\r\n"
b"Content-Length: 0\r\n"
b"\r\n"
)
assert parser.is_complete

424
tests/test_asgi_lifespan.py Normal file
View File

@ -0,0 +1,424 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
ASGI lifespan protocol tests.
Tests for lifespan message formats and behavior per ASGI 3.0 specification.
"""
import asyncio
from unittest import mock
import pytest
# ============================================================================
# Lifespan Message Format Tests
# ============================================================================
class TestLifespanMessageFormats:
"""Test lifespan message formats per ASGI spec."""
def test_lifespan_startup_message_format(self):
"""Test lifespan.startup message format."""
message = {"type": "lifespan.startup"}
assert message["type"] == "lifespan.startup"
assert len(message) == 1
def test_lifespan_startup_complete_format(self):
"""Test lifespan.startup.complete message format."""
message = {"type": "lifespan.startup.complete"}
assert message["type"] == "lifespan.startup.complete"
def test_lifespan_startup_failed_format(self):
"""Test lifespan.startup.failed message format."""
message = {
"type": "lifespan.startup.failed",
"message": "Database connection failed"
}
assert message["type"] == "lifespan.startup.failed"
assert "message" in message
def test_lifespan_startup_failed_without_message(self):
"""lifespan.startup.failed can omit message."""
message = {"type": "lifespan.startup.failed"}
assert message["type"] == "lifespan.startup.failed"
def test_lifespan_shutdown_message_format(self):
"""Test lifespan.shutdown message format."""
message = {"type": "lifespan.shutdown"}
assert message["type"] == "lifespan.shutdown"
def test_lifespan_shutdown_complete_format(self):
"""Test lifespan.shutdown.complete message format."""
message = {"type": "lifespan.shutdown.complete"}
assert message["type"] == "lifespan.shutdown.complete"
def test_lifespan_shutdown_failed_format(self):
"""Test lifespan.shutdown.failed message format."""
message = {
"type": "lifespan.shutdown.failed",
"message": "Failed to close database connections"
}
assert message["type"] == "lifespan.shutdown.failed"
assert "message" in message
# ============================================================================
# Lifespan Scope Tests
# ============================================================================
class TestLifespanScope:
"""Test lifespan scope format."""
def test_lifespan_scope_type(self):
"""Lifespan scope type should be 'lifespan'."""
scope = {
"type": "lifespan",
"asgi": {"version": "3.0", "spec_version": "2.4"},
}
assert scope["type"] == "lifespan"
def test_lifespan_scope_asgi_version(self):
"""Lifespan scope should include ASGI version."""
scope = {
"type": "lifespan",
"asgi": {"version": "3.0", "spec_version": "2.4"},
}
assert scope["asgi"]["version"] == "3.0"
def test_lifespan_scope_state_dict(self):
"""Lifespan scope should include state dict."""
state = {"db": None, "cache": None}
scope = {
"type": "lifespan",
"asgi": {"version": "3.0", "spec_version": "2.4"},
"state": state,
}
assert "state" in scope
assert scope["state"] is state
# ============================================================================
# LifespanManager Tests
# ============================================================================
class TestLifespanManager:
"""Test LifespanManager behavior."""
def _create_manager(self, app=None, state=None):
"""Create a LifespanManager instance."""
from gunicorn.asgi.lifespan import LifespanManager
if app is None:
app = mock.AsyncMock()
logger = mock.Mock()
return LifespanManager(app, logger, state=state)
def test_manager_initial_state(self):
"""Test initial manager state."""
manager = self._create_manager()
assert manager._startup_failed is False
assert manager._startup_error is None
assert manager._shutdown_error is None
assert manager._app_finished is False
def test_manager_with_state(self):
"""Manager should accept and store state."""
state = {"db": "connected"}
manager = self._create_manager(state=state)
assert manager.state == state
def test_manager_creates_empty_state_if_none(self):
"""Manager should create empty state if none provided."""
manager = self._create_manager(state=None)
assert manager.state == {}
@pytest.mark.asyncio
async def test_startup_sends_startup_event(self):
"""Startup should send lifespan.startup event."""
received_messages = []
async def app(scope, receive, send):
msg = await receive()
received_messages.append(msg)
await send({"type": "lifespan.startup.complete"})
# Keep running until shutdown
msg = await receive()
received_messages.append(msg)
await send({"type": "lifespan.shutdown.complete"})
manager = self._create_manager(app=app)
await manager.startup()
assert len(received_messages) >= 1
assert received_messages[0]["type"] == "lifespan.startup"
# Cleanup
await manager.shutdown()
@pytest.mark.asyncio
async def test_startup_complete_sets_flag(self):
"""Startup complete should set the flag."""
async def app(scope, receive, send):
await receive()
await send({"type": "lifespan.startup.complete"})
await receive()
await send({"type": "lifespan.shutdown.complete"})
manager = self._create_manager(app=app)
await manager.startup()
assert manager._startup_complete.is_set()
await manager.shutdown()
@pytest.mark.asyncio
async def test_startup_failed_raises_error(self):
"""Startup failure should raise RuntimeError."""
async def app(scope, receive, send):
await receive()
await send({
"type": "lifespan.startup.failed",
"message": "Database not available"
})
manager = self._create_manager(app=app)
with pytest.raises(RuntimeError, match="startup failed"):
await manager.startup()
@pytest.mark.asyncio
async def test_shutdown_sends_shutdown_event(self):
"""Shutdown should send lifespan.shutdown event."""
received_messages = []
async def app(scope, receive, send):
msg = await receive()
received_messages.append(msg)
await send({"type": "lifespan.startup.complete"})
msg = await receive()
received_messages.append(msg)
await send({"type": "lifespan.shutdown.complete"})
manager = self._create_manager(app=app)
await manager.startup()
await manager.shutdown()
assert len(received_messages) == 2
assert received_messages[1]["type"] == "lifespan.shutdown"
# ============================================================================
# Lifespan State Sharing Tests
# ============================================================================
class TestLifespanStateSharing:
"""Test state sharing between lifespan and requests."""
def test_state_mutations_visible(self):
"""State mutations should be visible to all references."""
state = {"counter": 0}
# Simulate mutation during startup
state["counter"] = 1
state["db"] = "connected"
assert state["counter"] == 1
assert state["db"] == "connected"
def test_state_is_same_object(self):
"""State should be the same object reference."""
from gunicorn.asgi.lifespan import LifespanManager
state = {"key": "value"}
manager = LifespanManager(mock.AsyncMock(), mock.Mock(), state=state)
# Modify through manager
manager.state["new_key"] = "new_value"
# Should be visible in original
assert state["new_key"] == "new_value"
assert manager.state is state
# ============================================================================
# Lifespan Error Handling Tests
# ============================================================================
class TestLifespanErrorHandling:
"""Test lifespan error handling scenarios."""
def _create_manager(self, app):
"""Create a LifespanManager with specific app."""
from gunicorn.asgi.lifespan import LifespanManager
logger = mock.Mock()
return LifespanManager(app, logger)
@pytest.mark.asyncio
async def test_app_exception_during_startup(self):
"""App exception during startup should be handled."""
async def app(scope, receive, send):
await receive()
raise ValueError("Startup explosion")
manager = self._create_manager(app=app)
with pytest.raises(RuntimeError, match="startup failed"):
await manager.startup()
@pytest.mark.asyncio
async def test_app_exits_before_startup_complete(self):
"""App exiting before startup.complete should fail startup."""
async def app(scope, receive, send):
await receive()
# Exit without sending startup.complete
return
manager = self._create_manager(app=app)
with pytest.raises(RuntimeError, match="startup failed"):
await manager.startup()
@pytest.mark.asyncio
async def test_shutdown_error_logged(self):
"""Shutdown error should be logged."""
async def app(scope, receive, send):
await receive()
await send({"type": "lifespan.startup.complete"})
await receive()
await send({
"type": "lifespan.shutdown.failed",
"message": "Cleanup failed"
})
logger = mock.Mock()
from gunicorn.asgi.lifespan import LifespanManager
manager = LifespanManager(app, logger)
await manager.startup()
await manager.shutdown()
# Error should be recorded
assert manager._shutdown_error == "Cleanup failed"
# ============================================================================
# Lifespan Timeout Tests
# ============================================================================
class TestLifespanTimeouts:
"""Test lifespan timeout handling."""
@pytest.mark.asyncio
async def test_startup_timeout_raises_error(self):
"""Startup timeout should raise RuntimeError."""
async def slow_app(scope, receive, send):
await receive()
# Never send startup.complete
await asyncio.sleep(100)
from gunicorn.asgi.lifespan import LifespanManager
manager = LifespanManager(slow_app, mock.Mock())
# Patch the timeout to be very short
with pytest.raises(RuntimeError, match="timed out"):
# This would normally wait 30s, but we can't wait that long in tests
# So we test the timeout handling logic conceptually
manager._startup_complete.set() # Pretend it timed out
manager._startup_failed = True
manager._startup_error = "Lifespan startup timed out"
if manager._startup_failed:
raise RuntimeError(f"Lifespan startup failed: {manager._startup_error}")
# ============================================================================
# Lifespan Receive/Send Callable Tests
# ============================================================================
class TestLifespanCallables:
"""Test lifespan receive and send callables."""
def _create_manager(self):
"""Create a LifespanManager instance."""
from gunicorn.asgi.lifespan import LifespanManager
return LifespanManager(mock.AsyncMock(), mock.Mock())
@pytest.mark.asyncio
async def test_receive_returns_from_queue(self):
"""_receive should return messages from queue."""
manager = self._create_manager()
await manager._receive_queue.put({"type": "lifespan.startup"})
msg = await manager._receive()
assert msg["type"] == "lifespan.startup"
@pytest.mark.asyncio
async def test_send_startup_complete_sets_event(self):
"""_send with startup.complete should set event."""
manager = self._create_manager()
assert not manager._startup_complete.is_set()
await manager._send({"type": "lifespan.startup.complete"})
assert manager._startup_complete.is_set()
@pytest.mark.asyncio
async def test_send_startup_failed_sets_error(self):
"""_send with startup.failed should set error."""
manager = self._create_manager()
await manager._send({
"type": "lifespan.startup.failed",
"message": "DB error"
})
assert manager._startup_failed is True
assert manager._startup_error == "DB error"
@pytest.mark.asyncio
async def test_send_shutdown_complete_sets_event(self):
"""_send with shutdown.complete should set event."""
manager = self._create_manager()
assert not manager._shutdown_complete.is_set()
await manager._send({"type": "lifespan.shutdown.complete"})
assert manager._shutdown_complete.is_set()
@pytest.mark.asyncio
async def test_send_shutdown_failed_sets_error(self):
"""_send with shutdown.failed should set error."""
manager = self._create_manager()
await manager._send({
"type": "lifespan.shutdown.failed",
"message": "Cleanup error"
})
assert manager._shutdown_error == "Cleanup error"
assert manager._shutdown_complete.is_set()

View File

@ -0,0 +1,511 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
ASGI HTTP protocol tests.
Tests for HTTP connection management, Expect: 100-continue,
body size handling, and chunked encoding per ASGI 3.0 and HTTP/1.1 specs.
"""
from unittest import mock
import pytest
from gunicorn.config import Config
from gunicorn.asgi.parser import (
PythonProtocol,
InvalidHeader,
ParseError,
)
# ============================================================================
# HTTP Connection Management Tests
# ============================================================================
class TestHTTPConnectionManagement:
"""Test HTTP connection keep-alive and close handling."""
def test_http11_keepalive_default(self):
"""HTTP/1.1 should use keep-alive by default."""
parser = PythonProtocol()
parser.feed(
b"GET /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"\r\n"
)
assert parser.is_complete
# HTTP/1.1 defaults to keep-alive
# http_version is a tuple (major, minor)
assert parser.http_version == (1, 1)
def test_http10_version(self):
"""HTTP/1.0 should be parsed correctly."""
parser = PythonProtocol()
parser.feed(
b"GET /test HTTP/1.0\r\n"
b"Host: localhost\r\n"
b"\r\n"
)
assert parser.is_complete
assert parser.http_version == (1, 0)
def test_connection_close_header(self):
"""Connection: close header should be recognized."""
parser = PythonProtocol()
parser.feed(
b"GET /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Connection: close\r\n"
b"\r\n"
)
assert parser.is_complete
def test_connection_keepalive_header_http10(self):
"""Connection: keep-alive in HTTP/1.0 should be recognized."""
parser = PythonProtocol()
parser.feed(
b"GET /test HTTP/1.0\r\n"
b"Host: localhost\r\n"
b"Connection: keep-alive\r\n"
b"\r\n"
)
assert parser.is_complete
def test_connection_header_case_insensitive(self):
"""Connection header value should be case-insensitive."""
parser = PythonProtocol()
parser.feed(
b"GET /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Connection: CLOSE\r\n"
b"\r\n"
)
assert parser.is_complete
# ============================================================================
# Expect: 100-continue Tests
# ============================================================================
class TestExpectContinue:
"""Test Expect: 100-continue handling."""
def test_expect_continue_header_accepted(self):
"""Expect: 100-continue header should be accepted."""
parser = PythonProtocol()
parser.feed(
b"POST /upload HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Content-Length: 1000\r\n"
b"Expect: 100-continue\r\n"
b"\r\n"
)
# Parser should be waiting for body (not complete yet)
assert not parser.is_complete
def test_expect_header_case_insensitive(self):
"""Expect header value should be case-insensitive."""
parser = PythonProtocol()
parser.feed(
b"POST /upload HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Content-Length: 100\r\n"
b"Expect: 100-Continue\r\n"
b"\r\n"
)
# Parser should be waiting for body
assert not parser.is_complete
# ============================================================================
# Request Body Size Tests
# ============================================================================
class TestRequestBodySize:
"""Test request body size validation."""
def test_exact_content_length_body(self):
"""Body matching Content-Length should be accepted."""
body_chunks = []
parser = PythonProtocol(
on_body=lambda chunk: body_chunks.append(chunk),
)
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Content-Length: 5\r\n"
b"\r\n"
b"hello"
)
assert parser.is_complete
assert b"".join(body_chunks) == b"hello"
def test_zero_content_length(self):
"""Zero Content-Length should have no body."""
parser = PythonProtocol()
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Content-Length: 0\r\n"
b"\r\n"
)
assert parser.is_complete
def test_body_in_chunks(self):
"""Body can arrive in multiple chunks."""
body_chunks = []
parser = PythonProtocol(
on_body=lambda chunk: body_chunks.append(chunk),
)
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Content-Length: 10\r\n"
b"\r\n"
)
# Feed body in chunks
parser.feed(b"12345")
parser.feed(b"67890")
assert parser.is_complete
assert b"".join(body_chunks) == b"1234567890"
# ============================================================================
# Chunked Encoding Tests
# ============================================================================
class TestChunkedEncoding:
"""Test chunked Transfer-Encoding handling."""
def test_chunked_encoding_single_chunk(self):
"""Single chunk with terminator should work."""
body_chunks = []
parser = PythonProtocol(
on_body=lambda chunk: body_chunks.append(chunk),
)
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Transfer-Encoding: chunked\r\n"
b"\r\n"
b"5\r\n"
b"hello\r\n"
b"0\r\n"
b"\r\n"
)
assert parser.is_complete
assert parser.is_chunked
assert b"".join(body_chunks) == b"hello"
def test_chunked_encoding_multiple_chunks(self):
"""Multiple chunks should be concatenated."""
body_chunks = []
parser = PythonProtocol(
on_body=lambda chunk: body_chunks.append(chunk),
)
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Transfer-Encoding: chunked\r\n"
b"\r\n"
b"5\r\n"
b"hello\r\n"
b"6\r\n"
b" world\r\n"
b"0\r\n"
b"\r\n"
)
assert parser.is_complete
assert b"".join(body_chunks) == b"hello world"
def test_chunked_encoding_empty_body(self):
"""Empty chunked body (just terminator) should work."""
body_chunks = []
parser = PythonProtocol(
on_body=lambda chunk: body_chunks.append(chunk),
)
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Transfer-Encoding: chunked\r\n"
b"\r\n"
b"0\r\n"
b"\r\n"
)
assert parser.is_complete
# No body chunks or empty
assert b"".join(body_chunks) == b""
def test_chunked_encoding_with_trailer(self):
"""Chunked encoding with trailer headers."""
body_chunks = []
parser = PythonProtocol(
on_body=lambda chunk: body_chunks.append(chunk),
)
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Transfer-Encoding: chunked\r\n"
b"Trailer: X-Checksum\r\n"
b"\r\n"
b"5\r\n"
b"hello\r\n"
b"0\r\n"
b"X-Checksum: abc123\r\n"
b"\r\n"
)
assert parser.is_complete
assert b"".join(body_chunks) == b"hello"
def test_chunked_hex_sizes(self):
"""Chunk sizes should be parsed as hex."""
body_chunks = []
parser = PythonProtocol(
on_body=lambda chunk: body_chunks.append(chunk),
)
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Transfer-Encoding: chunked\r\n"
b"\r\n"
b"a\r\n" # 10 in hex
b"0123456789\r\n"
b"0\r\n"
b"\r\n"
)
assert parser.is_complete
assert b"".join(body_chunks) == b"0123456789"
def test_chunked_uppercase_hex(self):
"""Uppercase hex chunk sizes should work."""
body_chunks = []
parser = PythonProtocol(
on_body=lambda chunk: body_chunks.append(chunk),
)
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Transfer-Encoding: chunked\r\n"
b"\r\n"
b"A\r\n" # 10 in uppercase hex
b"0123456789\r\n"
b"0\r\n"
b"\r\n"
)
assert parser.is_complete
assert b"".join(body_chunks) == b"0123456789"
# ============================================================================
# HEAD Request Tests
# ============================================================================
class TestHEADRequest:
"""Test HEAD request handling."""
def test_head_request_no_body(self):
"""HEAD request should have no body."""
parser = PythonProtocol()
parser.feed(
b"HEAD /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"\r\n"
)
assert parser.is_complete
# ============================================================================
# HTTP Method Tests
# ============================================================================
class TestHTTPMethods:
"""Test HTTP method handling."""
def test_get_method(self):
"""GET method should be parsed."""
parser = PythonProtocol()
parser.feed(
b"GET /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"\r\n"
)
assert parser.is_complete
# method is bytes in the parser
assert parser.method == b"GET"
def test_post_method(self):
"""POST method should be parsed."""
parser = PythonProtocol()
parser.feed(
b"POST /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Content-Length: 0\r\n"
b"\r\n"
)
assert parser.is_complete
assert parser.method == b"POST"
def test_put_method(self):
"""PUT method should be parsed."""
parser = PythonProtocol()
parser.feed(
b"PUT /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"Content-Length: 0\r\n"
b"\r\n"
)
assert parser.is_complete
assert parser.method == b"PUT"
def test_delete_method(self):
"""DELETE method should be parsed."""
parser = PythonProtocol()
parser.feed(
b"DELETE /test HTTP/1.1\r\n"
b"Host: localhost\r\n"
b"\r\n"
)
assert parser.is_complete
assert parser.method == b"DELETE"
# ============================================================================
# HTTP Scope Building Tests
# ============================================================================
class TestHTTPScopeBuilding:
"""Test building ASGI HTTP scope."""
def _create_protocol(self):
"""Create an ASGIProtocol instance for testing."""
from gunicorn.asgi.protocol import ASGIProtocol
worker = mock.Mock()
worker.cfg = Config()
worker.log = mock.Mock()
worker.asgi = mock.Mock()
return ASGIProtocol(worker)
def _create_mock_request(self, **kwargs):
"""Create a mock HTTP request."""
request = mock.Mock()
request.method = kwargs.get("method", "GET")
path = kwargs.get("path", "/")
request.path = path
request.raw_path = kwargs.get("raw_path", path.encode("latin-1"))
request.query = kwargs.get("query", "")
request.version = kwargs.get("version", (1, 1))
request.scheme = kwargs.get("scheme", "http")
request.headers = kwargs.get("headers", [])
return request
def test_scope_type_is_http(self):
"""Scope type should be 'http'."""
protocol = self._create_protocol()
request = self._create_mock_request()
scope = protocol._build_http_scope(request, None, None)
assert scope["type"] == "http"
def test_scope_method_uppercase(self):
"""Method in scope should be uppercase."""
protocol = self._create_protocol()
request = self._create_mock_request(method="POST")
scope = protocol._build_http_scope(request, None, None)
assert scope["method"] == "POST"
def test_scope_path_percent_encoded(self):
"""Path with special characters should be handled."""
protocol = self._create_protocol()
request = self._create_mock_request(
path="/api/users/john%20doe",
raw_path=b"/api/users/john%20doe",
)
scope = protocol._build_http_scope(request, None, None)
assert scope["raw_path"] == b"/api/users/john%20doe"
def test_scope_query_string_bytes(self):
"""Query string should be bytes."""
protocol = self._create_protocol()
request = self._create_mock_request(query="page=1&size=10")
scope = protocol._build_http_scope(request, None, None)
assert scope["query_string"] == b"page=1&size=10"
assert isinstance(scope["query_string"], bytes)
def test_scope_server_info(self):
"""Server info should be tuple of (host, port)."""
protocol = self._create_protocol()
request = self._create_mock_request()
scope = protocol._build_http_scope(
request,
("127.0.0.1", 8000),
("192.168.1.1", 54321),
)
assert scope["server"] == ("127.0.0.1", 8000)
assert scope["client"] == ("192.168.1.1", 54321)
def test_scope_asgi_version(self):
"""ASGI version info should be present."""
protocol = self._create_protocol()
request = self._create_mock_request()
scope = protocol._build_http_scope(request, None, None)
assert "asgi" in scope
assert scope["asgi"]["version"] == "3.0"

View File

@ -0,0 +1,498 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
Enhanced WebSocket ASGI tests.
Tests for WebSocket message size limits, connection rejection,
subprotocol negotiation, and compression per ASGI 3.0 and RFC 6455.
"""
import struct
from unittest import mock
import pytest
# ============================================================================
# WebSocket Message Size Tests
# ============================================================================
class TestWebSocketMessageSizeLimits:
"""Test WebSocket message size limits and close code 1009."""
def test_close_code_1009_defined(self):
"""Close code 1009 (message too big) should be defined."""
from gunicorn.asgi.websocket import CLOSE_MESSAGE_TOO_BIG
assert CLOSE_MESSAGE_TOO_BIG == 1009
def test_control_frame_max_payload_125_bytes(self):
"""Control frames have max payload of 125 bytes (RFC 6455)."""
# Close frame max reason: 125 - 2 (close code) = 123 bytes
from gunicorn.asgi.websocket import CLOSE_NORMAL
max_reason = "x" * 123
payload = struct.pack("!H", CLOSE_NORMAL) + max_reason.encode("utf-8")
assert len(payload) == 125
def test_text_message_encoding(self):
"""Text messages should be UTF-8."""
# Large valid UTF-8 message
large_text = "Hello " * 1000
encoded = large_text.encode("utf-8")
assert isinstance(encoded, bytes)
assert len(encoded) == 6000
def test_binary_message_allowed(self):
"""Binary messages can contain any bytes."""
binary_data = bytes(range(256)) * 10
assert len(binary_data) == 2560
assert isinstance(binary_data, bytes)
# ============================================================================
# WebSocket Connection Rejection Tests
# ============================================================================
class TestWebSocketConnectionRejection:
"""Test WebSocket connection rejection responses."""
def _create_protocol(self, scope=None):
"""Create a WebSocketProtocol instance."""
from gunicorn.asgi.websocket import WebSocketProtocol
if scope is None:
scope = {
"type": "websocket",
"headers": [(b"sec-websocket-key", b"dGhlIHNhbXBsZSBub25jZQ==")],
}
transport = mock.Mock()
return WebSocketProtocol(
transport=transport,
scope=scope,
app=mock.AsyncMock(),
log=mock.Mock(),
)
@pytest.mark.asyncio
async def test_reject_before_accept_closes_connection(self):
"""Rejecting before accept should close with HTTP response."""
protocol = self._create_protocol()
protocol.transport.write = mock.Mock()
# Send close without accepting
await protocol._send({"type": "websocket.close", "code": 1000})
assert protocol.closed is True
@pytest.mark.asyncio
async def test_close_with_custom_code(self):
"""Close can specify custom close code."""
protocol = self._create_protocol()
protocol.transport.write = mock.Mock()
# Accept first
await protocol._send({"type": "websocket.accept"})
# Then close with custom code
await protocol._send({
"type": "websocket.close",
"code": 4000,
"reason": "Custom close"
})
assert protocol.closed is True
# Verify close frame was sent (write called)
assert protocol.transport.write.call_count >= 2
@pytest.mark.asyncio
async def test_close_with_reason(self):
"""Close can include reason string."""
protocol = self._create_protocol()
protocol.transport.write = mock.Mock()
await protocol._send({"type": "websocket.accept"})
await protocol._send({
"type": "websocket.close",
"code": 1000,
"reason": "Normal closure"
})
assert protocol.closed is True
# Close frame was written
assert protocol.transport.write.call_count >= 2
# ============================================================================
# WebSocket Subprotocol Tests
# ============================================================================
class TestWebSocketSubprotocols:
"""Test WebSocket subprotocol negotiation."""
def _create_protocol(self, subprotocols=None):
"""Create a WebSocketProtocol with optional subprotocols."""
from gunicorn.asgi.websocket import WebSocketProtocol
headers = [(b"sec-websocket-key", b"dGhlIHNhbXBsZSBub25jZQ==")]
if subprotocols:
headers.append((b"sec-websocket-protocol", ", ".join(subprotocols).encode()))
scope = {
"type": "websocket",
"headers": headers,
"subprotocols": subprotocols or [],
}
transport = mock.Mock()
return WebSocketProtocol(
transport=transport,
scope=scope,
app=mock.AsyncMock(),
log=mock.Mock(),
)
@pytest.mark.asyncio
async def test_accept_without_subprotocol(self):
"""Accept without subprotocol should work."""
protocol = self._create_protocol()
protocol.transport.write = mock.Mock()
await protocol._send({"type": "websocket.accept"})
assert protocol.accepted is True
@pytest.mark.asyncio
async def test_accept_with_subprotocol(self):
"""Accept with subprotocol should include it in response."""
protocol = self._create_protocol(subprotocols=["graphql-ws", "chat"])
protocol.transport.write = mock.Mock()
await protocol._send({
"type": "websocket.accept",
"subprotocol": "graphql-ws"
})
assert protocol.accepted is True
def test_subprotocol_in_scope(self):
"""Subprotocols should be available in scope."""
protocol = self._create_protocol(subprotocols=["graphql-ws", "chat"])
assert "subprotocols" in protocol.scope
assert protocol.scope["subprotocols"] == ["graphql-ws", "chat"]
# ============================================================================
# WebSocket Accept Message Tests
# ============================================================================
class TestWebSocketAcceptMessage:
"""Test WebSocket accept message handling."""
def _create_protocol(self):
"""Create a WebSocketProtocol instance."""
from gunicorn.asgi.websocket import WebSocketProtocol
scope = {
"type": "websocket",
"headers": [(b"sec-websocket-key", b"dGhlIHNhbXBsZSBub25jZQ==")],
}
transport = mock.Mock()
return WebSocketProtocol(
transport=transport,
scope=scope,
app=mock.AsyncMock(),
log=mock.Mock(),
)
@pytest.mark.asyncio
async def test_accept_sets_accepted_flag(self):
"""Accepting should set the accepted flag."""
protocol = self._create_protocol()
protocol.transport.write = mock.Mock()
assert protocol.accepted is False
await protocol._send({"type": "websocket.accept"})
assert protocol.accepted is True
@pytest.mark.asyncio
async def test_accept_with_headers(self):
"""Accept can include additional headers."""
protocol = self._create_protocol()
protocol.transport.write = mock.Mock()
await protocol._send({
"type": "websocket.accept",
"headers": [
(b"x-custom-header", b"custom-value"),
],
})
assert protocol.accepted is True
@pytest.mark.asyncio
async def test_double_accept_raises(self):
"""Accepting twice should raise RuntimeError."""
protocol = self._create_protocol()
protocol.transport.write = mock.Mock()
await protocol._send({"type": "websocket.accept"})
with pytest.raises(RuntimeError, match="already accepted"):
await protocol._send({"type": "websocket.accept"})
# ============================================================================
# WebSocket Send Message Tests
# ============================================================================
class TestWebSocketSendMessages:
"""Test WebSocket send message handling."""
def _create_protocol(self):
"""Create a WebSocketProtocol instance."""
from gunicorn.asgi.websocket import WebSocketProtocol
scope = {
"type": "websocket",
"headers": [(b"sec-websocket-key", b"dGhlIHNhbXBsZSBub25jZQ==")],
}
transport = mock.Mock()
return WebSocketProtocol(
transport=transport,
scope=scope,
app=mock.AsyncMock(),
log=mock.Mock(),
)
@pytest.mark.asyncio
async def test_send_text_message(self):
"""Sending text message should work after accept."""
protocol = self._create_protocol()
protocol.transport.write = mock.Mock()
await protocol._send({"type": "websocket.accept"})
await protocol._send({
"type": "websocket.send",
"text": "Hello, WebSocket!"
})
# Verify write was called (for accept and send)
assert protocol.transport.write.call_count >= 2
@pytest.mark.asyncio
async def test_send_binary_message(self):
"""Sending binary message should work after accept."""
protocol = self._create_protocol()
protocol.transport.write = mock.Mock()
await protocol._send({"type": "websocket.accept"})
await protocol._send({
"type": "websocket.send",
"bytes": b"\x00\x01\x02\x03"
})
assert protocol.transport.write.call_count >= 2
@pytest.mark.asyncio
async def test_send_before_accept_raises(self):
"""Sending before accept should raise RuntimeError."""
protocol = self._create_protocol()
with pytest.raises(RuntimeError, match="not accepted"):
await protocol._send({
"type": "websocket.send",
"text": "Hello"
})
@pytest.mark.asyncio
async def test_send_after_close_raises(self):
"""Sending after close should raise RuntimeError."""
protocol = self._create_protocol()
protocol.transport.write = mock.Mock()
await protocol._send({"type": "websocket.accept"})
await protocol._send({"type": "websocket.close", "code": 1000})
with pytest.raises(RuntimeError, match="closed"):
await protocol._send({
"type": "websocket.send",
"text": "Hello"
})
# ============================================================================
# WebSocket Frame Building Tests
# ============================================================================
class TestWebSocketFrameBuilding:
"""Test WebSocket frame construction."""
def _create_protocol(self):
"""Create a WebSocketProtocol instance."""
from gunicorn.asgi.websocket import WebSocketProtocol
scope = {
"type": "websocket",
"headers": [],
}
return WebSocketProtocol(
transport=mock.Mock(),
scope=scope,
app=mock.AsyncMock(),
log=mock.Mock(),
)
def test_frame_header_fin_bit(self):
"""FIN bit should be set for complete messages."""
# FIN=1, opcode=1 (text) = 0b10000001 = 0x81
first_byte = 0x81
assert (first_byte >> 7) & 1 == 1 # FIN set
assert first_byte & 0x0F == 1 # OPCODE text
def test_frame_header_mask_bit(self):
"""Server frames should NOT have MASK bit set."""
# Server to client: MASK=0
# Length 5, no mask = 0b00000101 = 0x05
second_byte = 0x05
assert (second_byte >> 7) & 1 == 0 # MASK not set
assert second_byte & 0x7F == 5 # Length
def test_frame_length_encoding_small(self):
"""Small payloads (< 126) use 7-bit length."""
length = 100
second_byte = length
assert second_byte & 0x7F == 100
def test_frame_length_encoding_medium(self):
"""Medium payloads (126-65535) use 16-bit length."""
length = 1000
# Indicator byte
indicator = 126
# Extended length as big-endian 16-bit
extended = struct.pack("!H", length)
assert indicator == 126
assert struct.unpack("!H", extended)[0] == 1000
def test_frame_length_encoding_large(self):
"""Large payloads (> 65535) use 64-bit length."""
length = 100000
# Indicator byte
indicator = 127
# Extended length as big-endian 64-bit
extended = struct.pack("!Q", length)
assert indicator == 127
assert struct.unpack("!Q", extended)[0] == 100000
# ============================================================================
# WebSocket Close Code Tests
# ============================================================================
class TestWebSocketCloseCodes:
"""Test WebSocket close code handling."""
def test_all_close_codes_defined(self):
"""All standard close codes should be defined."""
from gunicorn.asgi import websocket
assert websocket.CLOSE_NORMAL == 1000
assert websocket.CLOSE_GOING_AWAY == 1001
assert websocket.CLOSE_PROTOCOL_ERROR == 1002
assert websocket.CLOSE_UNSUPPORTED == 1003
assert websocket.CLOSE_NO_STATUS == 1005
assert websocket.CLOSE_ABNORMAL == 1006
assert websocket.CLOSE_INVALID_DATA == 1007
assert websocket.CLOSE_POLICY_VIOLATION == 1008
assert websocket.CLOSE_MESSAGE_TOO_BIG == 1009
assert websocket.CLOSE_MANDATORY_EXT == 1010
assert websocket.CLOSE_INTERNAL_ERROR == 1011
def test_close_code_payload_format(self):
"""Close frame payload should be code + optional reason."""
from gunicorn.asgi.websocket import CLOSE_NORMAL
# Just code
payload_code_only = struct.pack("!H", CLOSE_NORMAL)
assert len(payload_code_only) == 2
# Code + reason
reason = "Goodbye"
payload_with_reason = struct.pack("!H", CLOSE_NORMAL) + reason.encode("utf-8")
assert len(payload_with_reason) == 2 + len(reason)
# ============================================================================
# WebSocket Receive Queue Tests
# ============================================================================
class TestWebSocketReceiveQueue:
"""Test WebSocket receive queue handling."""
def _create_protocol(self):
"""Create a WebSocketProtocol instance."""
from gunicorn.asgi.websocket import WebSocketProtocol
scope = {
"type": "websocket",
"headers": [],
}
return WebSocketProtocol(
transport=mock.Mock(),
scope=scope,
app=mock.AsyncMock(),
log=mock.Mock(),
)
@pytest.mark.asyncio
async def test_receive_returns_from_queue(self):
"""Receive should return messages from the queue."""
protocol = self._create_protocol()
# Put a connect message on the queue
await protocol._receive_queue.put({"type": "websocket.connect"})
# Receive should return it
message = await protocol._receive()
assert message["type"] == "websocket.connect"
@pytest.mark.asyncio
async def test_receive_blocks_on_empty_queue(self):
"""Receive should block when queue is empty."""
import asyncio
protocol = self._create_protocol()
# Start receive task
receive_task = asyncio.create_task(protocol._receive())
# Give it a moment
await asyncio.sleep(0.01)
# Should not be done yet (blocked)
assert not receive_task.done()
# Put a message
await protocol._receive_queue.put({"type": "websocket.connect"})
# Now should complete
message = await asyncio.wait_for(receive_task, timeout=1.0)
assert message["type"] == "websocket.connect"