mirror of
https://github.com/frappe/gunicorn.git
synced 2026-07-03 11:11:30 +08:00
uwsgi: Add native uWSGI binary protocol support
Add support for the uWSGI binary protocol, enabling gunicorn to work
with nginx's uwsgi_pass directive.
New module gunicorn/uwsgi/ with:
- UWSGIRequest: Parses 4-byte binary header and key-value vars block
- UWSGIParser: Protocol parser following existing Parser pattern
- Error classes: InvalidUWSGIHeader, UnsupportedModifier, ForbiddenUWSGIRequest
New configuration options:
- --protocol: Select 'http' (default) or 'uwsgi' protocol
- --uwsgi-allow-from: IP allowlist for uWSGI requests (default: localhost)
Worker integration via get_parser() factory in gunicorn/http/__init__.py,
updates to sync, gthread, and base_async workers.
Example nginx config:
upstream gunicorn {
server 127.0.0.1:8000;
}
location / {
uwsgi_pass gunicorn;
include uwsgi_params;
}
This commit is contained in:
parent
903a1fdf3c
commit
ac7296ec49
@ -2096,6 +2096,53 @@ class ProxyAllowFrom(Setting):
|
||||
"""
|
||||
|
||||
|
||||
class Protocol(Setting):
|
||||
name = "protocol"
|
||||
section = "Server Mechanics"
|
||||
cli = ["--protocol"]
|
||||
meta = "STRING"
|
||||
validator = validate_string
|
||||
default = "http"
|
||||
desc = """\
|
||||
The protocol for incoming connections.
|
||||
|
||||
* ``http`` - Standard HTTP/1.x (default)
|
||||
* ``uwsgi`` - uWSGI binary protocol (for nginx uwsgi_pass)
|
||||
|
||||
When using the uWSGI protocol, Gunicorn can receive requests from
|
||||
nginx using the uwsgi_pass directive::
|
||||
|
||||
upstream gunicorn {
|
||||
server 127.0.0.1:8000;
|
||||
}
|
||||
location / {
|
||||
uwsgi_pass gunicorn;
|
||||
include uwsgi_params;
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
class UWSGIAllowFrom(Setting):
|
||||
name = "uwsgi_allow_ips"
|
||||
section = "Server Mechanics"
|
||||
cli = ["--uwsgi-allow-from"]
|
||||
validator = validate_string_to_addr_list
|
||||
default = "127.0.0.1,::1"
|
||||
desc = """\
|
||||
IPs allowed to send uWSGI protocol requests (comma separated).
|
||||
|
||||
Set to ``*`` to allow all IPs. This is useful for setups where you
|
||||
don't know in advance the IP address of front-end, but instead have
|
||||
ensured via other means that only your authorized front-ends can
|
||||
access Gunicorn.
|
||||
|
||||
.. note::
|
||||
|
||||
This option does not affect UNIX socket connections. Connections not associated with
|
||||
an IP address are treated as allowed, unconditionally.
|
||||
"""
|
||||
|
||||
|
||||
class KeyFile(Setting):
|
||||
name = "keyfile"
|
||||
section = "SSL"
|
||||
|
||||
@ -5,4 +5,23 @@
|
||||
from gunicorn.http.message import Message, Request
|
||||
from gunicorn.http.parser import RequestParser
|
||||
|
||||
__all__ = ['Message', 'Request', 'RequestParser']
|
||||
|
||||
def get_parser(cfg, source, source_addr):
|
||||
"""Get appropriate parser based on protocol config.
|
||||
|
||||
Args:
|
||||
cfg: Gunicorn config object
|
||||
source: Socket or iterable source
|
||||
source_addr: Source address tuple or None
|
||||
|
||||
Returns:
|
||||
Parser instance (RequestParser or UWSGIParser)
|
||||
"""
|
||||
protocol = getattr(cfg, 'protocol', 'http')
|
||||
if protocol == 'uwsgi':
|
||||
from gunicorn.uwsgi.parser import UWSGIParser
|
||||
return UWSGIParser(cfg, source, source_addr)
|
||||
return RequestParser(cfg, source, source_addr)
|
||||
|
||||
|
||||
__all__ = ['Message', 'Request', 'RequestParser', 'get_parser']
|
||||
|
||||
21
gunicorn/uwsgi/__init__.py
Normal file
21
gunicorn/uwsgi/__init__.py
Normal file
@ -0,0 +1,21 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
from gunicorn.uwsgi.message import UWSGIRequest
|
||||
from gunicorn.uwsgi.parser import UWSGIParser
|
||||
from gunicorn.uwsgi.errors import (
|
||||
UWSGIParseException,
|
||||
InvalidUWSGIHeader,
|
||||
UnsupportedModifier,
|
||||
ForbiddenUWSGIRequest,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'UWSGIRequest',
|
||||
'UWSGIParser',
|
||||
'UWSGIParseException',
|
||||
'InvalidUWSGIHeader',
|
||||
'UnsupportedModifier',
|
||||
'ForbiddenUWSGIRequest',
|
||||
]
|
||||
46
gunicorn/uwsgi/errors.py
Normal file
46
gunicorn/uwsgi/errors.py
Normal file
@ -0,0 +1,46 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
# We don't need to call super() in __init__ methods of our
|
||||
# BaseException and Exception classes because we also define
|
||||
# our own __str__ methods so there is no need to pass 'message'
|
||||
# to the base class to get a meaningful output from 'str(exc)'.
|
||||
# pylint: disable=super-init-not-called
|
||||
|
||||
|
||||
class UWSGIParseException(Exception):
|
||||
"""Base exception for uWSGI protocol parsing errors."""
|
||||
|
||||
|
||||
class InvalidUWSGIHeader(UWSGIParseException):
|
||||
"""Raised when the uWSGI header is malformed."""
|
||||
|
||||
def __init__(self, msg=""):
|
||||
self.msg = msg
|
||||
self.code = 400
|
||||
|
||||
def __str__(self):
|
||||
return "Invalid uWSGI header: %s" % self.msg
|
||||
|
||||
|
||||
class UnsupportedModifier(UWSGIParseException):
|
||||
"""Raised when modifier1 is not 0 (WSGI request)."""
|
||||
|
||||
def __init__(self, modifier):
|
||||
self.modifier = modifier
|
||||
self.code = 501
|
||||
|
||||
def __str__(self):
|
||||
return "Unsupported uWSGI modifier1: %d" % self.modifier
|
||||
|
||||
|
||||
class ForbiddenUWSGIRequest(UWSGIParseException):
|
||||
"""Raised when source IP is not in the allow list."""
|
||||
|
||||
def __init__(self, host):
|
||||
self.host = host
|
||||
self.code = 403
|
||||
|
||||
def __str__(self):
|
||||
return "uWSGI request from %r not allowed" % self.host
|
||||
232
gunicorn/uwsgi/message.py
Normal file
232
gunicorn/uwsgi/message.py
Normal file
@ -0,0 +1,232 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
import io
|
||||
|
||||
from gunicorn.http.body import LengthReader, Body
|
||||
from gunicorn.uwsgi.errors import (
|
||||
InvalidUWSGIHeader,
|
||||
UnsupportedModifier,
|
||||
ForbiddenUWSGIRequest,
|
||||
)
|
||||
|
||||
|
||||
# Maximum number of variables to prevent DoS
|
||||
MAX_UWSGI_VARS = 1000
|
||||
|
||||
|
||||
class UWSGIRequest:
|
||||
"""uWSGI protocol request parser.
|
||||
|
||||
The uWSGI protocol uses a 4-byte binary header:
|
||||
- Byte 0: modifier1 (packet type, 0 = WSGI request)
|
||||
- Bytes 1-2: datasize (16-bit little-endian, size of vars block)
|
||||
- Byte 3: modifier2 (additional flags, typically 0)
|
||||
|
||||
After the header:
|
||||
1. Vars block (datasize bytes): Key-value pairs containing WSGI environ
|
||||
- Each pair: 2-byte key_size (LE) + key + 2-byte val_size (LE) + value
|
||||
2. Request body (determined by CONTENT_LENGTH in vars)
|
||||
"""
|
||||
|
||||
def __init__(self, cfg, unreader, peer_addr, req_number=1):
|
||||
self.cfg = cfg
|
||||
self.unreader = unreader
|
||||
self.peer_addr = peer_addr
|
||||
self.remote_addr = peer_addr
|
||||
self.req_number = req_number
|
||||
|
||||
# Request attributes (compatible with HTTP Request interface)
|
||||
self.method = None
|
||||
self.uri = None
|
||||
self.path = None
|
||||
self.query = None
|
||||
self.fragment = ""
|
||||
self.version = (1, 1) # uWSGI is HTTP/1.1 compatible
|
||||
self.headers = []
|
||||
self.trailers = []
|
||||
self.body = None
|
||||
self.scheme = "https" if cfg.is_ssl else "http"
|
||||
self.must_close = False
|
||||
|
||||
# uWSGI specific
|
||||
self.uwsgi_vars = {}
|
||||
self.modifier1 = 0
|
||||
self.modifier2 = 0
|
||||
|
||||
# Proxy protocol compatibility
|
||||
self.proxy_protocol_info = None
|
||||
|
||||
# Check if the source IP is allowed
|
||||
self._check_allowed_ip()
|
||||
|
||||
# Parse the request
|
||||
unused = self.parse(self.unreader)
|
||||
self.unreader.unread(unused)
|
||||
self.set_body_reader()
|
||||
|
||||
def _check_allowed_ip(self):
|
||||
"""Verify source IP is in the allowed list."""
|
||||
allow_ips = getattr(self.cfg, 'uwsgi_allow_ips', ['127.0.0.1', '::1'])
|
||||
|
||||
# UNIX sockets don't have IP addresses
|
||||
if not isinstance(self.peer_addr, tuple):
|
||||
return
|
||||
|
||||
# Wildcard allows all
|
||||
if '*' in allow_ips:
|
||||
return
|
||||
|
||||
if self.peer_addr[0] not in allow_ips:
|
||||
raise ForbiddenUWSGIRequest(self.peer_addr[0])
|
||||
|
||||
def force_close(self):
|
||||
"""Force the connection to close after this request."""
|
||||
self.must_close = True
|
||||
|
||||
def parse(self, unreader):
|
||||
"""Parse uWSGI packet header and vars block."""
|
||||
# Read the 4-byte header
|
||||
header = self._read_exact(unreader, 4)
|
||||
if len(header) < 4:
|
||||
raise InvalidUWSGIHeader("incomplete header")
|
||||
|
||||
self.modifier1 = header[0]
|
||||
datasize = int.from_bytes(header[1:3], 'little')
|
||||
self.modifier2 = header[3]
|
||||
|
||||
# Only modifier1=0 (WSGI request) is supported
|
||||
if self.modifier1 != 0:
|
||||
raise UnsupportedModifier(self.modifier1)
|
||||
|
||||
# Read the vars block
|
||||
if datasize > 0:
|
||||
vars_data = self._read_exact(unreader, datasize)
|
||||
if len(vars_data) < datasize:
|
||||
raise InvalidUWSGIHeader("incomplete vars block")
|
||||
self._parse_vars(vars_data)
|
||||
|
||||
# Extract HTTP request info from vars
|
||||
self._extract_request_info()
|
||||
|
||||
return b""
|
||||
|
||||
def _read_exact(self, unreader, size):
|
||||
"""Read exactly size bytes from the unreader."""
|
||||
buf = io.BytesIO()
|
||||
remaining = size
|
||||
|
||||
while remaining > 0:
|
||||
data = unreader.read()
|
||||
if not data:
|
||||
break
|
||||
buf.write(data)
|
||||
remaining = size - buf.tell()
|
||||
|
||||
result = buf.getvalue()
|
||||
# Put back any extra bytes
|
||||
if len(result) > size:
|
||||
unreader.unread(result[size:])
|
||||
result = result[:size]
|
||||
|
||||
return result
|
||||
|
||||
def _parse_vars(self, data):
|
||||
"""Parse uWSGI vars block into key-value pairs.
|
||||
|
||||
Format: key_size (2 bytes LE) + key + val_size (2 bytes LE) + value
|
||||
"""
|
||||
pos = 0
|
||||
var_count = 0
|
||||
|
||||
while pos < len(data):
|
||||
if var_count >= MAX_UWSGI_VARS:
|
||||
raise InvalidUWSGIHeader("too many variables")
|
||||
|
||||
# Key size (2 bytes, little-endian)
|
||||
if pos + 2 > len(data):
|
||||
raise InvalidUWSGIHeader("truncated key size")
|
||||
key_size = int.from_bytes(data[pos:pos + 2], 'little')
|
||||
pos += 2
|
||||
|
||||
# Key
|
||||
if pos + key_size > len(data):
|
||||
raise InvalidUWSGIHeader("truncated key")
|
||||
key = data[pos:pos + key_size].decode('latin-1')
|
||||
pos += key_size
|
||||
|
||||
# Value size (2 bytes, little-endian)
|
||||
if pos + 2 > len(data):
|
||||
raise InvalidUWSGIHeader("truncated value size")
|
||||
val_size = int.from_bytes(data[pos:pos + 2], 'little')
|
||||
pos += 2
|
||||
|
||||
# Value
|
||||
if pos + val_size > len(data):
|
||||
raise InvalidUWSGIHeader("truncated value")
|
||||
value = data[pos:pos + val_size].decode('latin-1')
|
||||
pos += val_size
|
||||
|
||||
self.uwsgi_vars[key] = value
|
||||
var_count += 1
|
||||
|
||||
def _extract_request_info(self):
|
||||
"""Extract HTTP request info from uWSGI vars."""
|
||||
# Method
|
||||
self.method = self.uwsgi_vars.get('REQUEST_METHOD', 'GET')
|
||||
|
||||
# URI and path
|
||||
self.path = self.uwsgi_vars.get('PATH_INFO', '/')
|
||||
self.query = self.uwsgi_vars.get('QUERY_STRING', '')
|
||||
|
||||
# Build URI
|
||||
if self.query:
|
||||
self.uri = "%s?%s" % (self.path, self.query)
|
||||
else:
|
||||
self.uri = self.path
|
||||
|
||||
# Scheme
|
||||
if self.uwsgi_vars.get('HTTPS', '').lower() in ('on', '1', 'true'):
|
||||
self.scheme = 'https'
|
||||
elif 'wsgi.url_scheme' in self.uwsgi_vars:
|
||||
self.scheme = self.uwsgi_vars['wsgi.url_scheme']
|
||||
|
||||
# Extract HTTP headers (HTTP_* vars)
|
||||
for key, value in self.uwsgi_vars.items():
|
||||
if key.startswith('HTTP_'):
|
||||
# Convert HTTP_HEADER_NAME to HEADER-NAME
|
||||
header_name = key[5:].replace('_', '-')
|
||||
self.headers.append((header_name, value))
|
||||
elif key == 'CONTENT_TYPE':
|
||||
self.headers.append(('CONTENT-TYPE', value))
|
||||
elif key == 'CONTENT_LENGTH':
|
||||
self.headers.append(('CONTENT-LENGTH', value))
|
||||
|
||||
def set_body_reader(self):
|
||||
"""Set up the body reader based on CONTENT_LENGTH."""
|
||||
content_length = 0
|
||||
|
||||
# Get content length from vars
|
||||
if 'CONTENT_LENGTH' in self.uwsgi_vars:
|
||||
try:
|
||||
content_length = max(int(self.uwsgi_vars['CONTENT_LENGTH']), 0)
|
||||
except ValueError:
|
||||
content_length = 0
|
||||
|
||||
self.body = Body(LengthReader(self.unreader, content_length))
|
||||
|
||||
def should_close(self):
|
||||
"""Determine if the connection should be closed after this request."""
|
||||
if self.must_close:
|
||||
return True
|
||||
|
||||
# Check HTTP_CONNECTION header
|
||||
connection = self.uwsgi_vars.get('HTTP_CONNECTION', '').lower()
|
||||
if connection == 'close':
|
||||
return True
|
||||
elif connection == 'keep-alive':
|
||||
return False
|
||||
|
||||
# Default to keep-alive for HTTP/1.1
|
||||
return False
|
||||
12
gunicorn/uwsgi/parser.py
Normal file
12
gunicorn/uwsgi/parser.py
Normal file
@ -0,0 +1,12 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
from gunicorn.http.parser import Parser
|
||||
from gunicorn.uwsgi.message import UWSGIRequest
|
||||
|
||||
|
||||
class UWSGIParser(Parser):
|
||||
"""Parser for uWSGI protocol requests."""
|
||||
|
||||
mesg_class = UWSGIRequest
|
||||
@ -32,7 +32,7 @@ class AsyncWorker(base.Worker):
|
||||
def handle(self, listener, client, addr):
|
||||
req = None
|
||||
try:
|
||||
parser = http.RequestParser(self.cfg, client, addr)
|
||||
parser = http.get_parser(self.cfg, client, addr)
|
||||
try:
|
||||
listener_name = listener.getsockname()
|
||||
if not self.cfg.keepalive:
|
||||
|
||||
@ -58,7 +58,7 @@ class TConn:
|
||||
self.sock = sock.ssl_wrap_socket(self.sock, self.cfg)
|
||||
|
||||
# initialize the parser
|
||||
self.parser = http.RequestParser(self.cfg, self.sock, self.client)
|
||||
self.parser = http.get_parser(self.cfg, self.sock, self.client)
|
||||
|
||||
def set_timeout(self):
|
||||
# Use monotonic clock for reliability (time.time() can jump due to NTP)
|
||||
|
||||
@ -129,7 +129,7 @@ class SyncWorker(base.Worker):
|
||||
try:
|
||||
if self.cfg.is_ssl:
|
||||
client = sock.ssl_wrap_socket(client, self.cfg)
|
||||
parser = http.RequestParser(self.cfg, client, addr)
|
||||
parser = http.get_parser(self.cfg, client, addr)
|
||||
req = next(parser)
|
||||
self.handle_request(listener, req, client, addr)
|
||||
except http.errors.NoMoreData as e:
|
||||
|
||||
435
tests/test_uwsgi.py
Normal file
435
tests/test_uwsgi.py
Normal file
@ -0,0 +1,435 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
import io
|
||||
import pytest
|
||||
from unittest import mock
|
||||
|
||||
from gunicorn.uwsgi import (
|
||||
UWSGIRequest,
|
||||
UWSGIParser,
|
||||
UWSGIParseException,
|
||||
InvalidUWSGIHeader,
|
||||
UnsupportedModifier,
|
||||
ForbiddenUWSGIRequest,
|
||||
)
|
||||
from gunicorn.http.unreader import IterUnreader
|
||||
|
||||
|
||||
def make_uwsgi_packet(vars_dict, modifier1=0, modifier2=0):
|
||||
"""Create uWSGI packet for testing.
|
||||
|
||||
Args:
|
||||
vars_dict: Dict of WSGI environ variables
|
||||
modifier1: Packet type (0 = WSGI request)
|
||||
modifier2: Additional flags
|
||||
|
||||
Returns:
|
||||
bytes: Complete uWSGI packet
|
||||
"""
|
||||
vars_data = b''
|
||||
for key, value in vars_dict.items():
|
||||
k = key.encode('latin-1')
|
||||
v = value.encode('latin-1')
|
||||
vars_data += len(k).to_bytes(2, 'little') + k
|
||||
vars_data += len(v).to_bytes(2, 'little') + v
|
||||
|
||||
header = bytes([modifier1]) + len(vars_data).to_bytes(2, 'little') + bytes([modifier2])
|
||||
return header + vars_data
|
||||
|
||||
|
||||
def make_uwsgi_packet_with_body(vars_dict, body=b'', modifier1=0, modifier2=0):
|
||||
"""Create uWSGI packet with body for testing."""
|
||||
if body:
|
||||
vars_dict = dict(vars_dict)
|
||||
vars_dict['CONTENT_LENGTH'] = str(len(body))
|
||||
return make_uwsgi_packet(vars_dict, modifier1, modifier2) + body
|
||||
|
||||
|
||||
class MockConfig:
|
||||
"""Mock config object for testing."""
|
||||
|
||||
def __init__(self, is_ssl=False, uwsgi_allow_ips=None):
|
||||
self.is_ssl = is_ssl
|
||||
self.uwsgi_allow_ips = uwsgi_allow_ips or ['127.0.0.1', '::1']
|
||||
|
||||
|
||||
class TestUWSGIPacketConstruction:
|
||||
"""Test the packet construction helper."""
|
||||
|
||||
def test_empty_vars(self):
|
||||
packet = make_uwsgi_packet({})
|
||||
assert packet == b'\x00\x00\x00\x00' # modifier1=0, size=0, modifier2=0
|
||||
|
||||
def test_single_var(self):
|
||||
packet = make_uwsgi_packet({'KEY': 'val'})
|
||||
# Header: modifier1(0) + size(10 in LE) + modifier2(0)
|
||||
# Var: key_size(3 in LE) + 'KEY' + val_size(3 in LE) + 'val'
|
||||
# Size = 2 + 3 + 2 + 3 = 10 bytes
|
||||
expected_header = b'\x00\x0a\x00\x00'
|
||||
expected_var = b'\x03\x00KEY\x03\x00val'
|
||||
assert packet == expected_header + expected_var
|
||||
|
||||
def test_multiple_vars(self):
|
||||
packet = make_uwsgi_packet({'A': '1', 'B': '2'})
|
||||
assert len(packet) == 4 + (2 + 1 + 2 + 1) * 2 # header + 2 vars
|
||||
|
||||
|
||||
class TestUWSGIRequest:
|
||||
"""Test UWSGIRequest parsing."""
|
||||
|
||||
def test_parse_simple_request(self):
|
||||
"""Test parsing a simple GET request."""
|
||||
packet = make_uwsgi_packet({
|
||||
'REQUEST_METHOD': 'GET',
|
||||
'PATH_INFO': '/test',
|
||||
'QUERY_STRING': 'foo=bar',
|
||||
})
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig()
|
||||
|
||||
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
|
||||
assert req.method == 'GET'
|
||||
assert req.path == '/test'
|
||||
assert req.query == 'foo=bar'
|
||||
assert req.uri == '/test?foo=bar'
|
||||
|
||||
def test_parse_post_request_with_body(self):
|
||||
"""Test parsing a POST request with body."""
|
||||
body = b'name=test&value=123'
|
||||
packet = make_uwsgi_packet_with_body({
|
||||
'REQUEST_METHOD': 'POST',
|
||||
'PATH_INFO': '/submit',
|
||||
'CONTENT_TYPE': 'application/x-www-form-urlencoded',
|
||||
}, body)
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig()
|
||||
|
||||
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
|
||||
assert req.method == 'POST'
|
||||
assert req.path == '/submit'
|
||||
assert req.body.read() == body
|
||||
|
||||
def test_parse_headers(self):
|
||||
"""Test that HTTP_* vars become headers."""
|
||||
packet = make_uwsgi_packet({
|
||||
'REQUEST_METHOD': 'GET',
|
||||
'PATH_INFO': '/',
|
||||
'HTTP_HOST': 'example.com',
|
||||
'HTTP_USER_AGENT': 'TestClient/1.0',
|
||||
'HTTP_ACCEPT': 'text/html',
|
||||
})
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig()
|
||||
|
||||
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
|
||||
headers_dict = dict(req.headers)
|
||||
assert headers_dict['HOST'] == 'example.com'
|
||||
assert headers_dict['USER-AGENT'] == 'TestClient/1.0'
|
||||
assert headers_dict['ACCEPT'] == 'text/html'
|
||||
|
||||
def test_parse_content_type_header(self):
|
||||
"""Test that CONTENT_TYPE becomes a header."""
|
||||
packet = make_uwsgi_packet({
|
||||
'REQUEST_METHOD': 'POST',
|
||||
'PATH_INFO': '/',
|
||||
'CONTENT_TYPE': 'application/json',
|
||||
'CONTENT_LENGTH': '0',
|
||||
})
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig()
|
||||
|
||||
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
|
||||
headers_dict = dict(req.headers)
|
||||
assert headers_dict['CONTENT-TYPE'] == 'application/json'
|
||||
assert headers_dict['CONTENT-LENGTH'] == '0'
|
||||
|
||||
def test_https_scheme(self):
|
||||
"""Test scheme detection from HTTPS variable."""
|
||||
packet = make_uwsgi_packet({
|
||||
'REQUEST_METHOD': 'GET',
|
||||
'PATH_INFO': '/',
|
||||
'HTTPS': 'on',
|
||||
})
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig()
|
||||
|
||||
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
|
||||
assert req.scheme == 'https'
|
||||
|
||||
def test_wsgi_url_scheme(self):
|
||||
"""Test scheme from wsgi.url_scheme variable."""
|
||||
packet = make_uwsgi_packet({
|
||||
'REQUEST_METHOD': 'GET',
|
||||
'PATH_INFO': '/',
|
||||
'wsgi.url_scheme': 'https',
|
||||
})
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig()
|
||||
|
||||
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
|
||||
assert req.scheme == 'https'
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values when vars are missing."""
|
||||
packet = make_uwsgi_packet({})
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig()
|
||||
|
||||
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
|
||||
assert req.method == 'GET'
|
||||
assert req.path == '/'
|
||||
assert req.query == ''
|
||||
assert req.uri == '/'
|
||||
|
||||
def test_uwsgi_vars_preserved(self):
|
||||
"""Test that all vars are preserved in uwsgi_vars."""
|
||||
packet = make_uwsgi_packet({
|
||||
'REQUEST_METHOD': 'GET',
|
||||
'PATH_INFO': '/',
|
||||
'SERVER_NAME': 'localhost',
|
||||
'SERVER_PORT': '8000',
|
||||
'CUSTOM_VAR': 'custom_value',
|
||||
})
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig()
|
||||
|
||||
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
|
||||
assert req.uwsgi_vars['SERVER_NAME'] == 'localhost'
|
||||
assert req.uwsgi_vars['SERVER_PORT'] == '8000'
|
||||
assert req.uwsgi_vars['CUSTOM_VAR'] == 'custom_value'
|
||||
|
||||
|
||||
class TestUWSGIRequestErrors:
|
||||
"""Test UWSGIRequest error handling."""
|
||||
|
||||
def test_incomplete_header(self):
|
||||
"""Test error on incomplete header."""
|
||||
unreader = IterUnreader([b'\x00\x00']) # Only 2 bytes
|
||||
cfg = MockConfig()
|
||||
|
||||
with pytest.raises(InvalidUWSGIHeader) as exc_info:
|
||||
UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
assert 'incomplete header' in str(exc_info.value)
|
||||
|
||||
def test_incomplete_vars_block(self):
|
||||
"""Test error on truncated vars block."""
|
||||
# Header says 100 bytes of vars, but we only provide 10
|
||||
header = b'\x00\x64\x00\x00' # modifier1=0, size=100, modifier2=0
|
||||
unreader = IterUnreader([header + b'1234567890'])
|
||||
cfg = MockConfig()
|
||||
|
||||
with pytest.raises(InvalidUWSGIHeader) as exc_info:
|
||||
UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
assert 'incomplete vars block' in str(exc_info.value)
|
||||
|
||||
def test_unsupported_modifier(self):
|
||||
"""Test error on non-zero modifier1."""
|
||||
packet = bytes([1]) + b'\x00\x00\x00' # modifier1=1
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig()
|
||||
|
||||
with pytest.raises(UnsupportedModifier) as exc_info:
|
||||
UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
assert exc_info.value.modifier == 1
|
||||
assert exc_info.value.code == 501
|
||||
|
||||
def test_truncated_key_size(self):
|
||||
"""Test error on truncated key size."""
|
||||
header = b'\x00\x01\x00\x00' # size=1, but need at least 2 bytes for key_size
|
||||
unreader = IterUnreader([header + b'X'])
|
||||
cfg = MockConfig()
|
||||
|
||||
with pytest.raises(InvalidUWSGIHeader) as exc_info:
|
||||
UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
assert 'truncated' in str(exc_info.value)
|
||||
|
||||
def test_forbidden_ip(self):
|
||||
"""Test error when source IP not in allow list."""
|
||||
packet = make_uwsgi_packet({'REQUEST_METHOD': 'GET', 'PATH_INFO': '/'})
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig(uwsgi_allow_ips=['192.168.1.1'])
|
||||
|
||||
with pytest.raises(ForbiddenUWSGIRequest) as exc_info:
|
||||
UWSGIRequest(cfg, unreader, ('10.0.0.1', 12345))
|
||||
assert exc_info.value.code == 403
|
||||
assert '10.0.0.1' in str(exc_info.value)
|
||||
|
||||
def test_allowed_ip_wildcard(self):
|
||||
"""Test that wildcard allows any IP."""
|
||||
packet = make_uwsgi_packet({'REQUEST_METHOD': 'GET', 'PATH_INFO': '/'})
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig(uwsgi_allow_ips=['*'])
|
||||
|
||||
# Should not raise
|
||||
req = UWSGIRequest(cfg, unreader, ('10.0.0.1', 12345))
|
||||
assert req.method == 'GET'
|
||||
|
||||
def test_unix_socket_always_allowed(self):
|
||||
"""Test that UNIX socket connections are always allowed."""
|
||||
packet = make_uwsgi_packet({'REQUEST_METHOD': 'GET', 'PATH_INFO': '/'})
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig(uwsgi_allow_ips=['127.0.0.1'])
|
||||
|
||||
# UNIX socket has non-tuple peer_addr
|
||||
req = UWSGIRequest(cfg, unreader, None)
|
||||
assert req.method == 'GET'
|
||||
|
||||
|
||||
class TestUWSGIRequestConnection:
|
||||
"""Test connection handling."""
|
||||
|
||||
def test_should_close_default(self):
|
||||
"""Test default keep-alive behavior."""
|
||||
packet = make_uwsgi_packet({'REQUEST_METHOD': 'GET', 'PATH_INFO': '/'})
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig()
|
||||
|
||||
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
|
||||
assert req.should_close() is False
|
||||
|
||||
def test_should_close_connection_close(self):
|
||||
"""Test Connection: close header."""
|
||||
packet = make_uwsgi_packet({
|
||||
'REQUEST_METHOD': 'GET',
|
||||
'PATH_INFO': '/',
|
||||
'HTTP_CONNECTION': 'close',
|
||||
})
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig()
|
||||
|
||||
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
|
||||
assert req.should_close() is True
|
||||
|
||||
def test_should_close_connection_keepalive(self):
|
||||
"""Test Connection: keep-alive header."""
|
||||
packet = make_uwsgi_packet({
|
||||
'REQUEST_METHOD': 'GET',
|
||||
'PATH_INFO': '/',
|
||||
'HTTP_CONNECTION': 'keep-alive',
|
||||
})
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig()
|
||||
|
||||
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
|
||||
assert req.should_close() is False
|
||||
|
||||
def test_force_close(self):
|
||||
"""Test force_close method."""
|
||||
packet = make_uwsgi_packet({'REQUEST_METHOD': 'GET', 'PATH_INFO': '/'})
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig()
|
||||
|
||||
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
req.force_close()
|
||||
|
||||
assert req.should_close() is True
|
||||
|
||||
|
||||
class TestUWSGIParser:
|
||||
"""Test UWSGIParser."""
|
||||
|
||||
def test_parser_iteration(self):
|
||||
"""Test iterating over parser for multiple requests."""
|
||||
packet = make_uwsgi_packet({
|
||||
'REQUEST_METHOD': 'GET',
|
||||
'PATH_INFO': '/test',
|
||||
'HTTP_CONNECTION': 'close', # Single request
|
||||
})
|
||||
cfg = MockConfig()
|
||||
|
||||
# Parser expects an iterable source, not an unreader
|
||||
parser = UWSGIParser(cfg, [packet], ('127.0.0.1', 12345))
|
||||
req = next(parser)
|
||||
|
||||
assert req.method == 'GET'
|
||||
assert req.path == '/test'
|
||||
|
||||
def test_parser_mesg_class(self):
|
||||
"""Test that parser uses UWSGIRequest."""
|
||||
assert UWSGIParser.mesg_class is UWSGIRequest
|
||||
|
||||
|
||||
class TestExceptionStrings:
|
||||
"""Test exception string representations."""
|
||||
|
||||
def test_invalid_uwsgi_header_str(self):
|
||||
exc = InvalidUWSGIHeader("test message")
|
||||
assert str(exc) == "Invalid uWSGI header: test message"
|
||||
assert exc.code == 400
|
||||
|
||||
def test_unsupported_modifier_str(self):
|
||||
exc = UnsupportedModifier(5)
|
||||
assert str(exc) == "Unsupported uWSGI modifier1: 5"
|
||||
assert exc.code == 501
|
||||
|
||||
def test_forbidden_uwsgi_request_str(self):
|
||||
exc = ForbiddenUWSGIRequest("10.0.0.1")
|
||||
assert str(exc) == "uWSGI request from '10.0.0.1' not allowed"
|
||||
assert exc.code == 403
|
||||
|
||||
|
||||
class TestUWSGIBody:
|
||||
"""Test body reading."""
|
||||
|
||||
def test_read_body_in_chunks(self):
|
||||
"""Test reading body in multiple chunks."""
|
||||
body = b'A' * 1000
|
||||
packet = make_uwsgi_packet_with_body({
|
||||
'REQUEST_METHOD': 'POST',
|
||||
'PATH_INFO': '/',
|
||||
}, body)
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig()
|
||||
|
||||
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
|
||||
result = b''
|
||||
chunk = req.body.read(100)
|
||||
while chunk:
|
||||
result += chunk
|
||||
chunk = req.body.read(100)
|
||||
|
||||
assert result == body
|
||||
|
||||
def test_invalid_content_length(self):
|
||||
"""Test handling of invalid CONTENT_LENGTH."""
|
||||
packet = make_uwsgi_packet({
|
||||
'REQUEST_METHOD': 'POST',
|
||||
'PATH_INFO': '/',
|
||||
'CONTENT_LENGTH': 'invalid',
|
||||
})
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig()
|
||||
|
||||
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
|
||||
# Invalid content length should default to 0
|
||||
assert req.body.read() == b''
|
||||
|
||||
def test_negative_content_length(self):
|
||||
"""Test handling of negative CONTENT_LENGTH."""
|
||||
packet = make_uwsgi_packet({
|
||||
'REQUEST_METHOD': 'POST',
|
||||
'PATH_INFO': '/',
|
||||
'CONTENT_LENGTH': '-5',
|
||||
})
|
||||
unreader = IterUnreader([packet])
|
||||
cfg = MockConfig()
|
||||
|
||||
req = UWSGIRequest(cfg, unreader, ('127.0.0.1', 12345))
|
||||
|
||||
# Negative content length should default to 0
|
||||
assert req.body.read() == b''
|
||||
Loading…
x
Reference in New Issue
Block a user