diff --git a/gunicorn/config.py b/gunicorn/config.py index 522dcae9..1c36f987 100644 --- a/gunicorn/config.py +++ b/gunicorn/config.py @@ -2096,6 +2096,53 @@ class ProxyAllowFrom(Setting): """ +class Protocol(Setting): + name = "protocol" + section = "Server Mechanics" + cli = ["--protocol"] + meta = "STRING" + validator = validate_string + default = "http" + desc = """\ + The protocol for incoming connections. + + * ``http`` - Standard HTTP/1.x (default) + * ``uwsgi`` - uWSGI binary protocol (for nginx uwsgi_pass) + + When using the uWSGI protocol, Gunicorn can receive requests from + nginx using the uwsgi_pass directive:: + + upstream gunicorn { + server 127.0.0.1:8000; + } + location / { + uwsgi_pass gunicorn; + include uwsgi_params; + } + """ + + +class UWSGIAllowFrom(Setting): + name = "uwsgi_allow_ips" + section = "Server Mechanics" + cli = ["--uwsgi-allow-from"] + validator = validate_string_to_addr_list + default = "127.0.0.1,::1" + desc = """\ + IPs allowed to send uWSGI protocol requests (comma separated). + + Set to ``*`` to allow all IPs. This is useful for setups where you + don't know in advance the IP address of front-end, but instead have + ensured via other means that only your authorized front-ends can + access Gunicorn. + + .. note:: + + This option does not affect UNIX socket connections. Connections not associated with + an IP address are treated as allowed, unconditionally. + """ + + class KeyFile(Setting): name = "keyfile" section = "SSL" diff --git a/gunicorn/http/__init__.py b/gunicorn/http/__init__.py index 11473bb0..1d35b7c7 100644 --- a/gunicorn/http/__init__.py +++ b/gunicorn/http/__init__.py @@ -5,4 +5,23 @@ from gunicorn.http.message import Message, Request from gunicorn.http.parser import RequestParser -__all__ = ['Message', 'Request', 'RequestParser'] + +def get_parser(cfg, source, source_addr): + """Get appropriate parser based on protocol config. + + Args: + cfg: Gunicorn config object + source: Socket or iterable source + source_addr: Source address tuple or None + + Returns: + Parser instance (RequestParser or UWSGIParser) + """ + protocol = getattr(cfg, 'protocol', 'http') + if protocol == 'uwsgi': + from gunicorn.uwsgi.parser import UWSGIParser + return UWSGIParser(cfg, source, source_addr) + return RequestParser(cfg, source, source_addr) + + +__all__ = ['Message', 'Request', 'RequestParser', 'get_parser'] diff --git a/gunicorn/uwsgi/__init__.py b/gunicorn/uwsgi/__init__.py new file mode 100644 index 00000000..cdf4f60c --- /dev/null +++ b/gunicorn/uwsgi/__init__.py @@ -0,0 +1,21 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +from gunicorn.uwsgi.message import UWSGIRequest +from gunicorn.uwsgi.parser import UWSGIParser +from gunicorn.uwsgi.errors import ( + UWSGIParseException, + InvalidUWSGIHeader, + UnsupportedModifier, + ForbiddenUWSGIRequest, +) + +__all__ = [ + 'UWSGIRequest', + 'UWSGIParser', + 'UWSGIParseException', + 'InvalidUWSGIHeader', + 'UnsupportedModifier', + 'ForbiddenUWSGIRequest', +] diff --git a/gunicorn/uwsgi/errors.py b/gunicorn/uwsgi/errors.py new file mode 100644 index 00000000..cdbaee21 --- /dev/null +++ b/gunicorn/uwsgi/errors.py @@ -0,0 +1,46 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +# We don't need to call super() in __init__ methods of our +# BaseException and Exception classes because we also define +# our own __str__ methods so there is no need to pass 'message' +# to the base class to get a meaningful output from 'str(exc)'. +# pylint: disable=super-init-not-called + + +class UWSGIParseException(Exception): + """Base exception for uWSGI protocol parsing errors.""" + + +class InvalidUWSGIHeader(UWSGIParseException): + """Raised when the uWSGI header is malformed.""" + + def __init__(self, msg=""): + self.msg = msg + self.code = 400 + + def __str__(self): + return "Invalid uWSGI header: %s" % self.msg + + +class UnsupportedModifier(UWSGIParseException): + """Raised when modifier1 is not 0 (WSGI request).""" + + def __init__(self, modifier): + self.modifier = modifier + self.code = 501 + + def __str__(self): + return "Unsupported uWSGI modifier1: %d" % self.modifier + + +class ForbiddenUWSGIRequest(UWSGIParseException): + """Raised when source IP is not in the allow list.""" + + def __init__(self, host): + self.host = host + self.code = 403 + + def __str__(self): + return "uWSGI request from %r not allowed" % self.host diff --git a/gunicorn/uwsgi/message.py b/gunicorn/uwsgi/message.py new file mode 100644 index 00000000..a63172eb --- /dev/null +++ b/gunicorn/uwsgi/message.py @@ -0,0 +1,232 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +import io + +from gunicorn.http.body import LengthReader, Body +from gunicorn.uwsgi.errors import ( + InvalidUWSGIHeader, + UnsupportedModifier, + ForbiddenUWSGIRequest, +) + + +# Maximum number of variables to prevent DoS +MAX_UWSGI_VARS = 1000 + + +class UWSGIRequest: + """uWSGI protocol request parser. + + The uWSGI protocol uses a 4-byte binary header: + - Byte 0: modifier1 (packet type, 0 = WSGI request) + - Bytes 1-2: datasize (16-bit little-endian, size of vars block) + - Byte 3: modifier2 (additional flags, typically 0) + + After the header: + 1. Vars block (datasize bytes): Key-value pairs containing WSGI environ + - Each pair: 2-byte key_size (LE) + key + 2-byte val_size (LE) + value + 2. Request body (determined by CONTENT_LENGTH in vars) + """ + + def __init__(self, cfg, unreader, peer_addr, req_number=1): + self.cfg = cfg + self.unreader = unreader + self.peer_addr = peer_addr + self.remote_addr = peer_addr + self.req_number = req_number + + # Request attributes (compatible with HTTP Request interface) + self.method = None + self.uri = None + self.path = None + self.query = None + self.fragment = "" + self.version = (1, 1) # uWSGI is HTTP/1.1 compatible + self.headers = [] + self.trailers = [] + self.body = None + self.scheme = "https" if cfg.is_ssl else "http" + self.must_close = False + + # uWSGI specific + self.uwsgi_vars = {} + self.modifier1 = 0 + self.modifier2 = 0 + + # Proxy protocol compatibility + self.proxy_protocol_info = None + + # Check if the source IP is allowed + self._check_allowed_ip() + + # Parse the request + unused = self.parse(self.unreader) + self.unreader.unread(unused) + self.set_body_reader() + + def _check_allowed_ip(self): + """Verify source IP is in the allowed list.""" + allow_ips = getattr(self.cfg, 'uwsgi_allow_ips', ['127.0.0.1', '::1']) + + # UNIX sockets don't have IP addresses + if not isinstance(self.peer_addr, tuple): + return + + # Wildcard allows all + if '*' in allow_ips: + return + + if self.peer_addr[0] not in allow_ips: + raise ForbiddenUWSGIRequest(self.peer_addr[0]) + + def force_close(self): + """Force the connection to close after this request.""" + self.must_close = True + + def parse(self, unreader): + """Parse uWSGI packet header and vars block.""" + # Read the 4-byte header + header = self._read_exact(unreader, 4) + if len(header) < 4: + raise InvalidUWSGIHeader("incomplete header") + + self.modifier1 = header[0] + datasize = int.from_bytes(header[1:3], 'little') + self.modifier2 = header[3] + + # Only modifier1=0 (WSGI request) is supported + if self.modifier1 != 0: + raise UnsupportedModifier(self.modifier1) + + # Read the vars block + if datasize > 0: + vars_data = self._read_exact(unreader, datasize) + if len(vars_data) < datasize: + raise InvalidUWSGIHeader("incomplete vars block") + self._parse_vars(vars_data) + + # Extract HTTP request info from vars + self._extract_request_info() + + return b"" + + def _read_exact(self, unreader, size): + """Read exactly size bytes from the unreader.""" + buf = io.BytesIO() + remaining = size + + while remaining > 0: + data = unreader.read() + if not data: + break + buf.write(data) + remaining = size - buf.tell() + + result = buf.getvalue() + # Put back any extra bytes + if len(result) > size: + unreader.unread(result[size:]) + result = result[:size] + + return result + + def _parse_vars(self, data): + """Parse uWSGI vars block into key-value pairs. + + Format: key_size (2 bytes LE) + key + val_size (2 bytes LE) + value + """ + pos = 0 + var_count = 0 + + while pos < len(data): + if var_count >= MAX_UWSGI_VARS: + raise InvalidUWSGIHeader("too many variables") + + # Key size (2 bytes, little-endian) + if pos + 2 > len(data): + raise InvalidUWSGIHeader("truncated key size") + key_size = int.from_bytes(data[pos:pos + 2], 'little') + pos += 2 + + # Key + if pos + key_size > len(data): + raise InvalidUWSGIHeader("truncated key") + key = data[pos:pos + key_size].decode('latin-1') + pos += key_size + + # Value size (2 bytes, little-endian) + if pos + 2 > len(data): + raise InvalidUWSGIHeader("truncated value size") + val_size = int.from_bytes(data[pos:pos + 2], 'little') + pos += 2 + + # Value + if pos + val_size > len(data): + raise InvalidUWSGIHeader("truncated value") + value = data[pos:pos + val_size].decode('latin-1') + pos += val_size + + self.uwsgi_vars[key] = value + var_count += 1 + + def _extract_request_info(self): + """Extract HTTP request info from uWSGI vars.""" + # Method + self.method = self.uwsgi_vars.get('REQUEST_METHOD', 'GET') + + # URI and path + self.path = self.uwsgi_vars.get('PATH_INFO', '/') + self.query = self.uwsgi_vars.get('QUERY_STRING', '') + + # Build URI + if self.query: + self.uri = "%s?%s" % (self.path, self.query) + else: + self.uri = self.path + + # Scheme + if self.uwsgi_vars.get('HTTPS', '').lower() in ('on', '1', 'true'): + self.scheme = 'https' + elif 'wsgi.url_scheme' in self.uwsgi_vars: + self.scheme = self.uwsgi_vars['wsgi.url_scheme'] + + # Extract HTTP headers (HTTP_* vars) + for key, value in self.uwsgi_vars.items(): + if key.startswith('HTTP_'): + # Convert HTTP_HEADER_NAME to HEADER-NAME + header_name = key[5:].replace('_', '-') + self.headers.append((header_name, value)) + elif key == 'CONTENT_TYPE': + self.headers.append(('CONTENT-TYPE', value)) + elif key == 'CONTENT_LENGTH': + self.headers.append(('CONTENT-LENGTH', value)) + + def set_body_reader(self): + """Set up the body reader based on CONTENT_LENGTH.""" + content_length = 0 + + # Get content length from vars + if 'CONTENT_LENGTH' in self.uwsgi_vars: + try: + content_length = max(int(self.uwsgi_vars['CONTENT_LENGTH']), 0) + except ValueError: + content_length = 0 + + self.body = Body(LengthReader(self.unreader, content_length)) + + def should_close(self): + """Determine if the connection should be closed after this request.""" + if self.must_close: + return True + + # Check HTTP_CONNECTION header + connection = self.uwsgi_vars.get('HTTP_CONNECTION', '').lower() + if connection == 'close': + return True + elif connection == 'keep-alive': + return False + + # Default to keep-alive for HTTP/1.1 + return False diff --git a/gunicorn/uwsgi/parser.py b/gunicorn/uwsgi/parser.py new file mode 100644 index 00000000..fede8c56 --- /dev/null +++ b/gunicorn/uwsgi/parser.py @@ -0,0 +1,12 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +from gunicorn.http.parser import Parser +from gunicorn.uwsgi.message import UWSGIRequest + + +class UWSGIParser(Parser): + """Parser for uWSGI protocol requests.""" + + mesg_class = UWSGIRequest diff --git a/gunicorn/workers/base_async.py b/gunicorn/workers/base_async.py index 9466d6aa..22ea09ab 100644 --- a/gunicorn/workers/base_async.py +++ b/gunicorn/workers/base_async.py @@ -32,7 +32,7 @@ class AsyncWorker(base.Worker): def handle(self, listener, client, addr): req = None try: - parser = http.RequestParser(self.cfg, client, addr) + parser = http.get_parser(self.cfg, client, addr) try: listener_name = listener.getsockname() if not self.cfg.keepalive: diff --git a/gunicorn/workers/gthread.py b/gunicorn/workers/gthread.py index 47270725..7cab9920 100644 --- a/gunicorn/workers/gthread.py +++ b/gunicorn/workers/gthread.py @@ -58,7 +58,7 @@ class TConn: self.sock = sock.ssl_wrap_socket(self.sock, self.cfg) # initialize the parser - self.parser = http.RequestParser(self.cfg, self.sock, self.client) + self.parser = http.get_parser(self.cfg, self.sock, self.client) def set_timeout(self): # Use monotonic clock for reliability (time.time() can jump due to NTP) diff --git a/gunicorn/workers/sync.py b/gunicorn/workers/sync.py index 4c029f91..99dbdaac 100644 --- a/gunicorn/workers/sync.py +++ b/gunicorn/workers/sync.py @@ -129,7 +129,7 @@ class SyncWorker(base.Worker): try: if self.cfg.is_ssl: client = sock.ssl_wrap_socket(client, self.cfg) - parser = http.RequestParser(self.cfg, client, addr) + parser = http.get_parser(self.cfg, client, addr) req = next(parser) self.handle_request(listener, req, client, addr) except http.errors.NoMoreData as e: diff --git a/tests/test_uwsgi.py b/tests/test_uwsgi.py new file mode 100644 index 00000000..26ff09f5 --- /dev/null +++ b/tests/test_uwsgi.py @@ -0,0 +1,435 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +import io +import pytest +from unittest import mock + +from gunicorn.uwsgi import ( + UWSGIRequest, + UWSGIParser, + UWSGIParseException, + InvalidUWSGIHeader, + UnsupportedModifier, + ForbiddenUWSGIRequest, +) +from gunicorn.http.unreader import IterUnreader + + +def make_uwsgi_packet(vars_dict, modifier1=0, modifier2=0): + """Create uWSGI packet for testing. + + Args: + vars_dict: Dict of WSGI environ variables + modifier1: Packet type (0 = WSGI request) + modifier2: Additional flags + + Returns: + bytes: Complete uWSGI packet + """ + vars_data = b'' + for key, value in vars_dict.items(): + k = key.encode('latin-1') + v = value.encode('latin-1') + vars_data += len(k).to_bytes(2, 'little') + k + vars_data += len(v).to_bytes(2, 'little') + v + + header = bytes([modifier1]) + len(vars_data).to_bytes(2, 'little') + bytes([modifier2]) + return header + vars_data + + +def make_uwsgi_packet_with_body(vars_dict, body=b'', modifier1=0, modifier2=0): + """Create uWSGI packet with body for testing.""" + if body: + vars_dict = dict(vars_dict) + vars_dict['CONTENT_LENGTH'] = str(len(body)) + return make_uwsgi_packet(vars_dict, modifier1, modifier2) + body + + +class MockConfig: + """Mock config object for testing.""" + + def __init__(self, is_ssl=False, uwsgi_allow_ips=None): + self.is_ssl = is_ssl + self.uwsgi_allow_ips = uwsgi_allow_ips or ['127.0.0.1', '::1'] + + +class TestUWSGIPacketConstruction: + """Test the packet construction helper.""" + + def test_empty_vars(self): + packet = make_uwsgi_packet({}) + assert packet == b'\x00\x00\x00\x00' # modifier1=0, size=0, modifier2=0 + + def test_single_var(self): + packet = make_uwsgi_packet({'KEY': 'val'}) + # Header: modifier1(0) + size(10 in LE) + modifier2(0) + # Var: key_size(3 in LE) + 'KEY' + val_size(3 in LE) + 'val' + # Size = 2 + 3 + 2 + 3 = 10 bytes + expected_header = b'\x00\x0a\x00\x00' + expected_var = b'\x03\x00KEY\x03\x00val' + assert packet == expected_header + expected_var + + def test_multiple_vars(self): + packet = make_uwsgi_packet({'A': '1', 'B': '2'}) + assert len(packet) == 4 + (2 + 1 + 2 + 1) * 2 # header + 2 vars + + +class TestUWSGIRequest: + """Test UWSGIRequest parsing.""" + + def test_parse_simple_request(self): + """Test parsing a simple GET request.""" + packet = make_uwsgi_packet({ + 'REQUEST_METHOD': 'GET', + 'PATH_INFO': '/test', + 'QUERY_STRING': 'foo=bar', + }) + unreader = IterUnreader([packet]) + cfg = MockConfig() + + req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + + assert req.method == 'GET' + assert req.path == '/test' + assert req.query == 'foo=bar' + assert req.uri == '/test?foo=bar' + + def test_parse_post_request_with_body(self): + """Test parsing a POST request with body.""" + body = b'name=test&value=123' + packet = make_uwsgi_packet_with_body({ + 'REQUEST_METHOD': 'POST', + 'PATH_INFO': '/submit', + 'CONTENT_TYPE': 'application/x-www-form-urlencoded', + }, body) + unreader = IterUnreader([packet]) + cfg = MockConfig() + + req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + + assert req.method == 'POST' + assert req.path == '/submit' + assert req.body.read() == body + + def test_parse_headers(self): + """Test that HTTP_* vars become headers.""" + packet = make_uwsgi_packet({ + 'REQUEST_METHOD': 'GET', + 'PATH_INFO': '/', + 'HTTP_HOST': 'example.com', + 'HTTP_USER_AGENT': 'TestClient/1.0', + 'HTTP_ACCEPT': 'text/html', + }) + unreader = IterUnreader([packet]) + cfg = MockConfig() + + req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + + headers_dict = dict(req.headers) + assert headers_dict['HOST'] == 'example.com' + assert headers_dict['USER-AGENT'] == 'TestClient/1.0' + assert headers_dict['ACCEPT'] == 'text/html' + + def test_parse_content_type_header(self): + """Test that CONTENT_TYPE becomes a header.""" + packet = make_uwsgi_packet({ + 'REQUEST_METHOD': 'POST', + 'PATH_INFO': '/', + 'CONTENT_TYPE': 'application/json', + 'CONTENT_LENGTH': '0', + }) + unreader = IterUnreader([packet]) + cfg = MockConfig() + + req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + + headers_dict = dict(req.headers) + assert headers_dict['CONTENT-TYPE'] == 'application/json' + assert headers_dict['CONTENT-LENGTH'] == '0' + + def test_https_scheme(self): + """Test scheme detection from HTTPS variable.""" + packet = make_uwsgi_packet({ + 'REQUEST_METHOD': 'GET', + 'PATH_INFO': '/', + 'HTTPS': 'on', + }) + unreader = IterUnreader([packet]) + cfg = MockConfig() + + req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + + assert req.scheme == 'https' + + def test_wsgi_url_scheme(self): + """Test scheme from wsgi.url_scheme variable.""" + packet = make_uwsgi_packet({ + 'REQUEST_METHOD': 'GET', + 'PATH_INFO': '/', + 'wsgi.url_scheme': 'https', + }) + unreader = IterUnreader([packet]) + cfg = MockConfig() + + req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + + assert req.scheme == 'https' + + def test_default_values(self): + """Test default values when vars are missing.""" + packet = make_uwsgi_packet({}) + unreader = IterUnreader([packet]) + cfg = MockConfig() + + req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + + assert req.method == 'GET' + assert req.path == '/' + assert req.query == '' + assert req.uri == '/' + + def test_uwsgi_vars_preserved(self): + """Test that all vars are preserved in uwsgi_vars.""" + packet = make_uwsgi_packet({ + 'REQUEST_METHOD': 'GET', + 'PATH_INFO': '/', + 'SERVER_NAME': 'localhost', + 'SERVER_PORT': '8000', + 'CUSTOM_VAR': 'custom_value', + }) + unreader = IterUnreader([packet]) + cfg = MockConfig() + + req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + + assert req.uwsgi_vars['SERVER_NAME'] == 'localhost' + assert req.uwsgi_vars['SERVER_PORT'] == '8000' + assert req.uwsgi_vars['CUSTOM_VAR'] == 'custom_value' + + +class TestUWSGIRequestErrors: + """Test UWSGIRequest error handling.""" + + def test_incomplete_header(self): + """Test error on incomplete header.""" + unreader = IterUnreader([b'\x00\x00']) # Only 2 bytes + cfg = MockConfig() + + with pytest.raises(InvalidUWSGIHeader) as exc_info: + UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + assert 'incomplete header' in str(exc_info.value) + + def test_incomplete_vars_block(self): + """Test error on truncated vars block.""" + # Header says 100 bytes of vars, but we only provide 10 + header = b'\x00\x64\x00\x00' # modifier1=0, size=100, modifier2=0 + unreader = IterUnreader([header + b'1234567890']) + cfg = MockConfig() + + with pytest.raises(InvalidUWSGIHeader) as exc_info: + UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + assert 'incomplete vars block' in str(exc_info.value) + + def test_unsupported_modifier(self): + """Test error on non-zero modifier1.""" + packet = bytes([1]) + b'\x00\x00\x00' # modifier1=1 + unreader = IterUnreader([packet]) + cfg = MockConfig() + + with pytest.raises(UnsupportedModifier) as exc_info: + UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + assert exc_info.value.modifier == 1 + assert exc_info.value.code == 501 + + def test_truncated_key_size(self): + """Test error on truncated key size.""" + header = b'\x00\x01\x00\x00' # size=1, but need at least 2 bytes for key_size + unreader = IterUnreader([header + b'X']) + cfg = MockConfig() + + with pytest.raises(InvalidUWSGIHeader) as exc_info: + UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + assert 'truncated' in str(exc_info.value) + + def test_forbidden_ip(self): + """Test error when source IP not in allow list.""" + packet = make_uwsgi_packet({'REQUEST_METHOD': 'GET', 'PATH_INFO': '/'}) + unreader = IterUnreader([packet]) + cfg = MockConfig(uwsgi_allow_ips=['192.168.1.1']) + + with pytest.raises(ForbiddenUWSGIRequest) as exc_info: + UWSGIRequest(cfg, unreader, ('10.0.0.1', 12345)) + assert exc_info.value.code == 403 + assert '10.0.0.1' in str(exc_info.value) + + def test_allowed_ip_wildcard(self): + """Test that wildcard allows any IP.""" + packet = make_uwsgi_packet({'REQUEST_METHOD': 'GET', 'PATH_INFO': '/'}) + unreader = IterUnreader([packet]) + cfg = MockConfig(uwsgi_allow_ips=['*']) + + # Should not raise + req = UWSGIRequest(cfg, unreader, ('10.0.0.1', 12345)) + assert req.method == 'GET' + + def test_unix_socket_always_allowed(self): + """Test that UNIX socket connections are always allowed.""" + packet = make_uwsgi_packet({'REQUEST_METHOD': 'GET', 'PATH_INFO': '/'}) + unreader = IterUnreader([packet]) + cfg = MockConfig(uwsgi_allow_ips=['127.0.0.1']) + + # UNIX socket has non-tuple peer_addr + req = UWSGIRequest(cfg, unreader, None) + assert req.method == 'GET' + + +class TestUWSGIRequestConnection: + """Test connection handling.""" + + def test_should_close_default(self): + """Test default keep-alive behavior.""" + packet = make_uwsgi_packet({'REQUEST_METHOD': 'GET', 'PATH_INFO': '/'}) + unreader = IterUnreader([packet]) + cfg = MockConfig() + + req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + + assert req.should_close() is False + + def test_should_close_connection_close(self): + """Test Connection: close header.""" + packet = make_uwsgi_packet({ + 'REQUEST_METHOD': 'GET', + 'PATH_INFO': '/', + 'HTTP_CONNECTION': 'close', + }) + unreader = IterUnreader([packet]) + cfg = MockConfig() + + req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + + assert req.should_close() is True + + def test_should_close_connection_keepalive(self): + """Test Connection: keep-alive header.""" + packet = make_uwsgi_packet({ + 'REQUEST_METHOD': 'GET', + 'PATH_INFO': '/', + 'HTTP_CONNECTION': 'keep-alive', + }) + unreader = IterUnreader([packet]) + cfg = MockConfig() + + req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + + assert req.should_close() is False + + def test_force_close(self): + """Test force_close method.""" + packet = make_uwsgi_packet({'REQUEST_METHOD': 'GET', 'PATH_INFO': '/'}) + unreader = IterUnreader([packet]) + cfg = MockConfig() + + req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + req.force_close() + + assert req.should_close() is True + + +class TestUWSGIParser: + """Test UWSGIParser.""" + + def test_parser_iteration(self): + """Test iterating over parser for multiple requests.""" + packet = make_uwsgi_packet({ + 'REQUEST_METHOD': 'GET', + 'PATH_INFO': '/test', + 'HTTP_CONNECTION': 'close', # Single request + }) + cfg = MockConfig() + + # Parser expects an iterable source, not an unreader + parser = UWSGIParser(cfg, [packet], ('127.0.0.1', 12345)) + req = next(parser) + + assert req.method == 'GET' + assert req.path == '/test' + + def test_parser_mesg_class(self): + """Test that parser uses UWSGIRequest.""" + assert UWSGIParser.mesg_class is UWSGIRequest + + +class TestExceptionStrings: + """Test exception string representations.""" + + def test_invalid_uwsgi_header_str(self): + exc = InvalidUWSGIHeader("test message") + assert str(exc) == "Invalid uWSGI header: test message" + assert exc.code == 400 + + def test_unsupported_modifier_str(self): + exc = UnsupportedModifier(5) + assert str(exc) == "Unsupported uWSGI modifier1: 5" + assert exc.code == 501 + + def test_forbidden_uwsgi_request_str(self): + exc = ForbiddenUWSGIRequest("10.0.0.1") + assert str(exc) == "uWSGI request from '10.0.0.1' not allowed" + assert exc.code == 403 + + +class TestUWSGIBody: + """Test body reading.""" + + def test_read_body_in_chunks(self): + """Test reading body in multiple chunks.""" + body = b'A' * 1000 + packet = make_uwsgi_packet_with_body({ + 'REQUEST_METHOD': 'POST', + 'PATH_INFO': '/', + }, body) + unreader = IterUnreader([packet]) + cfg = MockConfig() + + req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + + result = b'' + chunk = req.body.read(100) + while chunk: + result += chunk + chunk = req.body.read(100) + + assert result == body + + def test_invalid_content_length(self): + """Test handling of invalid CONTENT_LENGTH.""" + packet = make_uwsgi_packet({ + 'REQUEST_METHOD': 'POST', + 'PATH_INFO': '/', + 'CONTENT_LENGTH': 'invalid', + }) + unreader = IterUnreader([packet]) + cfg = MockConfig() + + req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + + # Invalid content length should default to 0 + assert req.body.read() == b'' + + def test_negative_content_length(self): + """Test handling of negative CONTENT_LENGTH.""" + packet = make_uwsgi_packet({ + 'REQUEST_METHOD': 'POST', + 'PATH_INFO': '/', + 'CONTENT_LENGTH': '-5', + }) + unreader = IterUnreader([packet]) + cfg = MockConfig() + + req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345)) + + # Negative content length should default to 0 + assert req.body.read() == b''