mirror of
https://github.com/frappe/gunicorn.git
synced 2026-07-01 10:11:30 +08:00
Add callback-based HTTP parser for ASGI protocol
Add PythonProtocol class that mirrors H1CProtocol callback interface: - Callbacks: on_message_begin, on_url, on_header, on_headers_complete, on_body, on_message_complete - Properties: method, path, http_version, headers, content_length, is_chunked, should_keep_alive - Methods: feed(data), reset() - Supports Content-Length and chunked transfer encoding Add CallbackRequest adapter for building requests from parser state. Works with both H1CProtocol (C extension) and PythonProtocol. Add unit tests for PythonProtocol and CallbackRequest.
This commit is contained in:
parent
7818401182
commit
ae7653057f
@ -921,3 +921,424 @@ class FastAsyncRequest:
|
||||
data = await self.read_body(8192)
|
||||
if not data:
|
||||
break
|
||||
|
||||
|
||||
class ParseError(Exception):
    """Error raised during HTTP parsing."""


class PythonProtocol:
    """Callback-based HTTP/1.1 parser (pure Python fallback).

    Mirrors H1CProtocol interface for seamless switching between
    the C extension and pure Python implementations.

    Callbacks (all optional):
        on_message_begin: () -> None - Called when request starts
        on_url: (url: bytes) -> None - Called with request URL/path
        on_header: (name: bytes, value: bytes) -> None - Called for each header
        on_headers_complete: () -> bool - Called when headers done
            (return True to skip body)
        on_body: (chunk: bytes) -> None - Called with body data chunks
        on_message_complete: () -> None - Called when request is complete

    Public attributes populated while parsing:
        method, path: raw bytes taken from the request line
        http_version: (1, 0) or (1, 1)
        headers: list of (lower-cased name, stripped value) byte pairs
        content_length: int, or None when no Content-Length header was sent
        is_chunked: True when Transfer-Encoding lists the "chunked" coding
        should_keep_alive: False when the connection must close afterwards
        is_complete: True once the whole message has been consumed
    """

    __slots__ = (
        '_on_message_begin', '_on_url', '_on_header',
        '_on_headers_complete', '_on_body', '_on_message_complete',
        '_state', '_buffer', '_headers_list',
        'method', 'path', 'http_version', 'headers',
        'content_length', 'is_chunked', 'should_keep_alive', 'is_complete',
        '_body_remaining', '_skip_body',
        '_chunk_state', '_chunk_size', '_chunk_remaining',
    )

    # Byte values permitted in a chunk-size field. int(x, 16) alone is too
    # lenient: it also accepts signs, a "0x" prefix and underscores.
    _HEX_DIGITS = frozenset(b'0123456789abcdefABCDEF')

    def __init__(
        self,
        on_message_begin=None,
        on_url=None,
        on_header=None,
        on_headers_complete=None,
        on_body=None,
        on_message_complete=None,
    ):
        self._on_message_begin = on_message_begin
        self._on_url = on_url
        self._on_header = on_header
        self._on_headers_complete = on_headers_complete
        self._on_body = on_body
        self._on_message_complete = on_message_complete

        # The buffer must exist before reset() clears it; every other
        # per-request field is initialised by reset() itself.
        self._buffer = bytearray()
        self.reset()

    def feed(self, data):
        """Process data, fire callbacks synchronously.

        Args:
            data: bytes or bytearray of incoming data

        Raises:
            ParseError: If the HTTP request is malformed
        """
        self._buffer.extend(data)

        while self._buffer:
            state = self._state
            if state == 'request_line':
                progressed = self._parse_request_line()
            elif state == 'headers':
                progressed = self._parse_headers()
            elif state == 'body':
                progressed = self._parse_body()
            elif state == 'chunked':
                progressed = self._parse_chunked_body()
            else:
                # 'complete': leave any extra bytes buffered until reset().
                break

            if not progressed:
                # The current stage needs more data than is buffered.
                break

    def reset(self):
        """Reset for next request (keepalive).

        Discards any buffered bytes and restores every per-request field
        to its initial value. The callbacks are kept.
        """
        # Parser state: request_line, headers, body, chunked, complete
        self._state = 'request_line'
        self._buffer.clear()
        self._headers_list = []

        # Request info (populated during parsing)
        self.method = None
        self.path = None
        self.http_version = None
        self.headers = []
        self.content_length = None
        self.is_chunked = False
        self.should_keep_alive = True
        self.is_complete = False

        # Body state
        self._body_remaining = 0
        self._skip_body = False

        # Chunked transfer state: size, data, crlf, trailer
        self._chunk_state = 'size'
        self._chunk_size = 0
        self._chunk_remaining = 0

    @staticmethod
    def _tokens(value):
        """Split a comma-separated header value into lower-cased tokens."""
        return [token.strip().lower() for token in value.split(b',')]

    def _finish_message(self):
        """Mark the request complete and fire on_message_complete."""
        self._state = 'complete'
        self.is_complete = True
        if self._on_message_complete:
            self._on_message_complete()

    def _parse_request_line(self):
        """Parse request line, return True if complete.

        Raises:
            ParseError: If the line is not "METHOD TARGET HTTP/1.x" or the
                version is neither HTTP/1.0 nor HTTP/1.1.
        """
        idx = self._buffer.find(b'\r\n')
        if idx == -1:
            return False

        line = bytes(self._buffer[:idx])
        del self._buffer[:idx + 2]

        # Parse: METHOD PATH HTTP/x.y
        parts = line.split(b' ', 2)
        if len(parts) != 3 or not parts[0] or not parts[1]:
            raise ParseError("Invalid request line")

        method, target, version = parts
        if version == b'HTTP/1.1':
            http_version = (1, 1)
        elif version == b'HTTP/1.0':
            http_version = (1, 0)
        else:
            raise ParseError("Unsupported HTTP version")

        # Publish the fields only once the whole line has been validated.
        self.method = method
        self.path = target
        self.http_version = http_version

        if self._on_message_begin:
            self._on_message_begin()
        if self._on_url:
            self._on_url(self.path)

        self._state = 'headers'
        return True

    def _parse_headers(self):
        """Parse headers, return True if headers are complete.

        Raises:
            ParseError: If a header line has no colon or an empty name, or
                if the completed header block is invalid.
        """
        while True:
            idx = self._buffer.find(b'\r\n')
            if idx == -1:
                return False

            line = bytes(self._buffer[:idx])
            del self._buffer[:idx + 2]

            if not line:
                # Empty line = end of headers
                self._finalize_headers()
                return True

            raw_name, colon, raw_value = line.partition(b':')
            if not colon:
                raise ParseError("Invalid header")

            name = raw_name.strip().lower()
            if not name:
                raise ParseError("Invalid header")
            value = raw_value.strip()

            self._headers_list.append((name, value))

            if self._on_header:
                self._on_header(name, value)

    def _finalize_headers(self):
        """Called when all headers received.

        Derives content_length, is_chunked and should_keep_alive from the
        collected headers, fires on_headers_complete and selects the next
        parser state.

        Raises:
            ParseError: If Content-Length is not a plain decimal number or
                is sent more than once with different values.
        """
        self.headers = self._headers_list

        explicit_keepalive = False
        for name, value in self.headers:
            if name == b'content-length':
                # ASCII digits only: rejects signs, blanks, underscores and
                # lists, which int() would accept or report as ValueError.
                if not value.isdigit():
                    raise ParseError("Invalid Content-Length")
                length = int(value)
                if self.content_length is not None and self.content_length != length:
                    raise ParseError("Conflicting Content-Length headers")
                self.content_length = length
                self._body_remaining = length
            elif name == b'transfer-encoding':
                # Sticky: a later Transfer-Encoding header must not switch
                # chunked framing back off.
                if b'chunked' in self._tokens(value):
                    self.is_chunked = True
            elif name == b'connection':
                tokens = self._tokens(value)
                if b'close' in tokens:
                    self.should_keep_alive = False
                elif b'keep-alive' in tokens:
                    self.should_keep_alive = True
                    explicit_keepalive = True

        # HTTP/1.0 defaults to close: only keep-alive if explicitly requested
        if self.http_version == (1, 0) and not explicit_keepalive:
            self.should_keep_alive = False

        if self._on_headers_complete:
            self._skip_body = self._on_headers_complete()

        # Determine next state. Chunked framing takes precedence over any
        # Content-Length that was also sent.
        if self._skip_body:
            self._finish_message()
        elif self.is_chunked:
            self._state = 'chunked'
            self._chunk_state = 'size'
        elif self.content_length:
            self._state = 'body'
        else:
            # No body
            self._finish_message()

    def _parse_body(self):
        """Parse Content-Length delimited body.

        Returns:
            True if some body bytes were consumed, False if nothing could
            be consumed.
        """
        if not self._buffer or self._body_remaining <= 0:
            return False

        chunk_size = min(len(self._buffer), self._body_remaining)
        chunk = bytes(self._buffer[:chunk_size])
        del self._buffer[:chunk_size]
        self._body_remaining -= chunk_size

        if self._on_body:
            self._on_body(chunk)

        if self._body_remaining <= 0:
            self._finish_message()

        return True

    def _parse_chunked_body(self):
        """Parse chunked transfer encoding.

        Returns:
            True once the terminating chunk and trailers were consumed,
            False when more data is required.

        Raises:
            ParseError: If a chunk size is not a hexadecimal number or chunk
                data is not followed by CRLF.
        """
        while self._buffer:
            if self._chunk_state == 'size':
                # Looking for chunk size line
                idx = self._buffer.find(b'\r\n')
                if idx == -1:
                    return False

                size_line = bytes(self._buffer[:idx])
                del self._buffer[:idx + 2]

                # Drop chunk extensions (e.g., "5;ext=value")
                size_token = size_line.partition(b';')[0].strip()
                if not size_token or not self._HEX_DIGITS.issuperset(size_token):
                    raise ParseError("Invalid chunk size")
                self._chunk_size = int(size_token, 16)

                if self._chunk_size == 0:
                    # Final chunk - skip trailers
                    self._chunk_state = 'trailer'
                else:
                    self._chunk_remaining = self._chunk_size
                    self._chunk_state = 'data'

            elif self._chunk_state == 'data':
                # Reading chunk data
                to_read = min(len(self._buffer), self._chunk_remaining)
                chunk = bytes(self._buffer[:to_read])
                del self._buffer[:to_read]
                self._chunk_remaining -= to_read

                if self._on_body:
                    self._on_body(chunk)

                if self._chunk_remaining == 0:
                    # Need to consume trailing CRLF
                    self._chunk_state = 'crlf'

            elif self._chunk_state == 'crlf':
                if len(self._buffer) < 2:
                    return False
                # Check the terminator instead of blindly dropping two
                # bytes, so a body longer than its declared chunk size is
                # rejected rather than re-framed.
                if self._buffer[:2] != b'\r\n':
                    raise ParseError("Invalid chunk terminator")
                del self._buffer[:2]
                self._chunk_state = 'size'

            elif self._chunk_state == 'trailer':
                # Skip trailer headers
                idx = self._buffer.find(b'\r\n')
                if idx == -1:
                    return False

                line = bytes(self._buffer[:idx])
                del self._buffer[:idx + 2]

                if not line:
                    # Empty line = end of trailers
                    self._finish_message()
                    return True

        return False
|
||||
|
||||
|
||||
class CallbackRequest:
    """Request object assembled from the state of a callback parser.

    Accepts either H1CProtocol (C extension) or PythonProtocol; only the
    parser's public attributes are read.
    """

    __slots__ = (
        'method', 'uri', 'path', 'query', 'fragment', 'version',
        'headers', 'headers_bytes', 'scheme',
        'content_length', 'chunked', 'must_close',
        'proxy_protocol_info', '_expect_100_continue',
    )

    def __init__(self):
        # Request line
        self.method = None
        self.uri = None
        self.path = None
        self.query = None
        self.fragment = None
        self.version = None

        # Headers: decoded (upper-case names) and raw byte pairs
        self.headers = []
        self.headers_bytes = []

        # Connection and body metadata
        self.scheme = "http"
        self.content_length = 0
        self.chunked = False
        self.must_close = False
        self.proxy_protocol_info = None
        self._expect_100_continue = False

    @classmethod
    def from_parser(cls, parser, is_ssl=False):
        """Create a request from a parser that has finished its headers.

        Args:
            parser: H1CProtocol or PythonProtocol instance
            is_ssl: Whether connection is SSL/TLS

        Returns:
            CallbackRequest instance
        """
        request = cls()
        request.method = parser.method.decode('ascii')

        # Everything before the first "?" is the path, the rest is the query.
        target = parser.path
        path_bytes, _, query_bytes = target.partition(b'?')
        request.path = path_bytes.decode('latin-1')
        request.query = query_bytes.decode('latin-1')
        request.uri = target.decode('latin-1')
        request.fragment = ''
        request.version = parser.http_version

        # Keep the raw byte pairs (for the ASGI scope) next to decoded
        # string pairs with upper-cased names (for compatibility).
        request.headers_bytes = list(parser.headers)
        decoded = []
        for raw_name, raw_value in parser.headers:
            decoded.append(
                (raw_name.decode('latin-1').upper(), raw_value.decode('latin-1'))
            )
        request.headers = decoded

        request.scheme = 'https' if is_ssl else 'http'
        request.content_length = parser.content_length or 0
        request.chunked = parser.is_chunked
        request.must_close = not parser.should_keep_alive

        # Expect: 100-continue
        if any(
            raw_name == b'expect' and raw_value.lower() == b'100-continue'
            for raw_name, raw_value in parser.headers
        ):
            request._expect_100_continue = True

        return request

    def should_close(self):
        """Tell whether the connection has to be closed after this request."""
        if self.must_close:
            return True

        # Only the first Connection header is consulted.
        connection = next(
            (value for key, value in self.headers if key == "CONNECTION"),
            None,
        )
        if connection is not None:
            token = connection.lower().strip(" \t")
            if token == "close":
                return True
            if token == "keep-alive":
                return False

        return self.version <= (1, 0)

    def get_header(self, name):
        """Return the first header value stored under *name*, ignoring case.

        Returns None when no such header exists.
        """
        wanted = name.upper()
        return next(
            (value for key, value in self.headers if key == wanted),
            None,
        )
|
||||
|
||||
518
tests/test_python_protocol.py
Normal file
518
tests/test_python_protocol.py
Normal file
@ -0,0 +1,518 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
Tests for PythonProtocol callback-based HTTP parser.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from gunicorn.asgi.parser import PythonProtocol, CallbackRequest, ParseError
|
||||
|
||||
|
||||
class TestPythonProtocolBasic:
    """Request-line and header parsing for requests without a body."""

    def test_simple_get_request(self):
        """A minimal GET fills in the parser fields and completes once."""
        headers_done = []
        message_done = []
        protocol = PythonProtocol(
            on_headers_complete=lambda: headers_done.append(True),
            on_message_complete=lambda: message_done.append(True),
        )

        protocol.feed(b"GET /path HTTP/1.1\r\nHost: example.com\r\n\r\n")

        assert protocol.method == b"GET"
        assert protocol.path == b"/path"
        assert protocol.http_version == (1, 1)
        assert len(protocol.headers) == 1
        assert protocol.headers[0] == (b"host", b"example.com")
        assert protocol.is_complete is True
        assert len(headers_done) == 1
        assert len(message_done) == 1

    def test_get_with_query_string(self):
        """The query string stays attached to the raw path."""
        protocol = PythonProtocol()

        request = b"GET /search?q=test&page=1 HTTP/1.1\r\nHost: example.com\r\n\r\n"
        protocol.feed(request)

        assert protocol.method == b"GET"
        assert protocol.path == b"/search?q=test&page=1"
        assert protocol.is_complete is True

    def test_http_10_request(self):
        """HTTP/1.0 is recognised and does not keep the connection alive."""
        protocol = PythonProtocol()

        protocol.feed(b"GET / HTTP/1.0\r\nHost: example.com\r\n\r\n")

        assert protocol.http_version == (1, 0)
        # HTTP/1.0 default
        assert protocol.should_keep_alive is False

    def test_http_10_with_keepalive(self):
        """An explicit Connection: keep-alive overrides the HTTP/1.0 default."""
        protocol = PythonProtocol()

        request = b"GET / HTTP/1.0\r\nHost: example.com\r\nConnection: keep-alive\r\n\r\n"
        protocol.feed(request)

        assert protocol.http_version == (1, 0)
        assert protocol.should_keep_alive is True

    def test_multiple_headers(self):
        """Every header line is collected under its lower-cased name."""
        protocol = PythonProtocol()

        request = (
            b"GET / HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"Accept: text/html\r\n"
            b"Accept-Language: en-US\r\n"
            b"User-Agent: Test/1.0\r\n"
            b"\r\n"
        )
        protocol.feed(request)

        assert len(protocol.headers) == 4
        seen_names = [pair[0] for pair in protocol.headers]
        assert b"host" in seen_names
        assert b"accept" in seen_names
        assert b"accept-language" in seen_names
        assert b"user-agent" in seen_names
|
||||
|
||||
|
||||
class TestPythonProtocolBody:
    """Body delivery through on_body for both framing modes."""

    def test_post_with_content_length(self):
        """A Content-Length body is delivered in full and completes the message."""
        received = []
        protocol = PythonProtocol(
            on_body=lambda chunk: received.append(chunk),
        )

        request = (
            b"POST /submit HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"Content-Length: 13\r\n"
            b"\r\n"
            b"name=testuser"
        )
        protocol.feed(request)

        assert protocol.method == b"POST"
        assert protocol.content_length == 13
        assert protocol.is_complete is True
        assert b"".join(received) == b"name=testuser"

    def test_chunked_body(self):
        """Chunks are decoded and delivered in order."""
        received = []
        protocol = PythonProtocol(
            on_body=lambda chunk: received.append(chunk),
        )

        request = (
            b"POST / HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"Transfer-Encoding: chunked\r\n"
            b"\r\n"
            b"5\r\nhello\r\n"
            b"6\r\n world\r\n"
            b"0\r\n\r\n"
        )
        protocol.feed(request)

        assert protocol.is_chunked is True
        assert protocol.is_complete is True
        assert b"".join(received) == b"hello world"

    def test_chunked_with_extension(self):
        """A chunk extension after the size is ignored."""
        received = []
        protocol = PythonProtocol(
            on_body=lambda chunk: received.append(chunk),
        )

        request = (
            b"POST / HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"Transfer-Encoding: chunked\r\n"
            b"\r\n"
            b"5;ext=value\r\nhello\r\n"
            b"0\r\n\r\n"
        )
        protocol.feed(request)

        assert b"".join(received) == b"hello"
|
||||
|
||||
|
||||
class TestPythonProtocolIncremental:
    """Requests that arrive split across several feed() calls."""

    def test_partial_request_line(self):
        """Nothing is parsed until the request line is terminated."""
        protocol = PythonProtocol()

        # Only the start of the request line is available
        protocol.feed(b"GET /path ")
        assert protocol.method is None
        assert protocol.is_complete is False

        # The rest of the request line plus the headers arrive
        protocol.feed(b"HTTP/1.1\r\nHost: example.com\r\n\r\n")
        assert protocol.method == b"GET"
        assert protocol.is_complete is True

    def test_partial_headers(self):
        """A header line cut in the middle is completed by the next feed."""
        protocol = PythonProtocol()

        protocol.feed(b"GET / HTTP/1.1\r\n")
        protocol.feed(b"Host: exa")
        assert protocol.is_complete is False

        protocol.feed(b"mple.com\r\n\r\n")
        assert protocol.is_complete is True
        assert protocol.headers[0] == (b"host", b"example.com")

    def test_partial_body(self):
        """A Content-Length body may arrive in several pieces."""
        received = []
        protocol = PythonProtocol(
            on_body=lambda chunk: received.append(chunk),
        )

        first_part = (
            b"POST / HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"Content-Length: 10\r\n"
            b"\r\n"
            b"hello"
        )
        protocol.feed(first_part)
        assert protocol.is_complete is False

        protocol.feed(b"world")
        assert protocol.is_complete is True
        assert b"".join(received) == b"helloworld"

    def test_partial_chunked_body(self):
        """A chunk may be split across two feeds."""
        received = []
        protocol = PythonProtocol(
            on_body=lambda chunk: received.append(chunk),
        )

        first_part = (
            b"POST / HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"Transfer-Encoding: chunked\r\n"
            b"\r\n"
            b"5\r\nhel"
        )
        protocol.feed(first_part)
        assert protocol.is_complete is False

        protocol.feed(b"lo\r\n0\r\n\r\n")
        assert protocol.is_complete is True
        assert b"".join(received) == b"hello"
|
||||
|
||||
|
||||
class TestPythonProtocolErrors:
    """Malformed input must surface as ParseError."""

    def test_invalid_request_line(self):
        """A request line without three parts is rejected."""
        protocol = PythonProtocol()
        malformed = b"INVALID\r\n"

        with pytest.raises(ParseError):
            protocol.feed(malformed)

    def test_invalid_header(self):
        """A header line without a colon is rejected."""
        protocol = PythonProtocol()
        malformed = b"GET / HTTP/1.1\r\nBadHeader\r\n\r\n"

        with pytest.raises(ParseError):
            protocol.feed(malformed)

    def test_unsupported_http_version(self):
        """Versions other than HTTP/1.0 and HTTP/1.1 are rejected."""
        protocol = PythonProtocol()
        malformed = b"GET / HTTP/2.0\r\n\r\n"

        with pytest.raises(ParseError):
            protocol.feed(malformed)

    def test_invalid_chunk_size(self):
        """A chunk size that is not hexadecimal is rejected."""
        protocol = PythonProtocol()
        malformed = (
            b"POST / HTTP/1.1\r\n"
            b"Transfer-Encoding: chunked\r\n"
            b"\r\n"
            b"XYZ\r\n"  # Invalid hex
        )

        with pytest.raises(ParseError):
            protocol.feed(malformed)
|
||||
|
||||
|
||||
class TestPythonProtocolReset:
    """Reusing one parser instance across keepalive requests."""

    def test_reset_clears_state(self):
        """reset() returns every public field to its initial value."""
        protocol = PythonProtocol()

        protocol.feed(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
        assert protocol.is_complete is True
        assert protocol.method == b"GET"

        protocol.reset()

        assert protocol.method is None
        assert protocol.path is None
        assert protocol.http_version is None
        assert protocol.headers == []
        assert protocol.content_length is None
        assert protocol.is_chunked is False
        assert protocol.is_complete is False

    def test_multiple_requests_keepalive(self):
        """After reset() the same parser handles a second request."""
        protocol = PythonProtocol()

        # First request
        first = b"GET /first HTTP/1.1\r\nHost: example.com\r\n\r\n"
        protocol.feed(first)
        assert protocol.path == b"/first"
        assert protocol.is_complete is True

        protocol.reset()

        # Second request
        second = b"GET /second HTTP/1.1\r\nHost: example.com\r\n\r\n"
        protocol.feed(second)
        assert protocol.path == b"/second"
        assert protocol.is_complete is True
|
||||
|
||||
|
||||
class TestPythonProtocolCallbacks:
    """Which callbacks fire, and in what order."""

    def test_all_callbacks(self):
        """Callbacks fire in protocol order for a request with a body."""
        trace = []
        protocol = PythonProtocol(
            on_message_begin=lambda: trace.append("begin"),
            on_url=lambda url: trace.append(("url", url)),
            on_header=lambda n, v: trace.append(("header", n, v)),
            on_headers_complete=lambda: trace.append("headers_complete"),
            on_body=lambda chunk: trace.append(("body", chunk)),
            on_message_complete=lambda: trace.append("complete"),
        )

        request = (
            b"POST / HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"Content-Length: 5\r\n"
            b"\r\n"
            b"hello"
        )
        protocol.feed(request)

        expected_prefix = [
            "begin",
            ("url", b"/"),
            ("header", b"host", b"example.com"),
            ("header", b"content-length", b"5"),
            "headers_complete",
            ("body", b"hello"),
            "complete",
        ]
        for position, expected in enumerate(expected_prefix):
            assert trace[position] == expected

    def test_skip_body_callback(self):
        """Returning True from on_headers_complete suppresses on_body."""
        received = []
        protocol = PythonProtocol(
            on_headers_complete=lambda: True,  # Skip body
            on_body=lambda chunk: received.append(chunk),
        )

        request = (
            b"POST / HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"Content-Length: 5\r\n"
            b"\r\n"
            b"hello"
        )
        protocol.feed(request)

        # Body should be skipped
        assert protocol.is_complete is True
        assert len(received) == 0
|
||||
|
||||
|
||||
class TestCallbackRequest:
    """CallbackRequest built from a PythonProtocol that parsed a request."""

    def test_from_parser_simple(self):
        """Request-line fields are decoded and the query is split off."""
        protocol = PythonProtocol()
        protocol.feed(b"GET /path?query=value HTTP/1.1\r\nHost: example.com\r\n\r\n")

        built = CallbackRequest.from_parser(protocol)

        assert built.method == "GET"
        assert built.path == "/path"
        assert built.query == "query=value"
        assert built.uri == "/path?query=value"
        assert built.version == (1, 1)
        assert built.scheme == "http"

    def test_from_parser_ssl(self):
        """is_ssl=True selects the https scheme."""
        protocol = PythonProtocol()
        protocol.feed(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")

        built = CallbackRequest.from_parser(protocol, is_ssl=True)

        assert built.scheme == "https"

    def test_from_parser_headers(self):
        """Headers are exposed both as strings and as raw bytes."""
        protocol = PythonProtocol()
        raw = (
            b"GET / HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"Content-Type: text/plain\r\n"
            b"\r\n"
        )
        protocol.feed(raw)

        built = CallbackRequest.from_parser(protocol)

        # String headers (uppercase)
        assert ("HOST", "example.com") in built.headers
        assert ("CONTENT-TYPE", "text/plain") in built.headers

        # Bytes headers (lowercase)
        assert (b"host", b"example.com") in built.headers_bytes
        assert (b"content-type", b"text/plain") in built.headers_bytes

    def test_from_parser_body_info(self):
        """Content-Length is carried over to the request."""
        protocol = PythonProtocol()
        raw = (
            b"POST / HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"Content-Length: 10\r\n"
            b"\r\n"
        )
        protocol.feed(raw)

        built = CallbackRequest.from_parser(protocol)

        assert built.content_length == 10
        assert built.chunked is False

    def test_from_parser_chunked(self):
        """Chunked framing is carried over to the request."""
        protocol = PythonProtocol()
        raw = (
            b"POST / HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"Transfer-Encoding: chunked\r\n"
            b"\r\n"
        )
        protocol.feed(raw)

        built = CallbackRequest.from_parser(protocol)

        assert built.chunked is True

    def test_should_close(self):
        """should_close() follows the Connection header and HTTP version."""
        scenarios = [
            # HTTP/1.1 with Connection: close
            (b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n", True),
            # HTTP/1.1 keep-alive (default)
            (b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n", False),
            # HTTP/1.0 (default close)
            (b"GET / HTTP/1.0\r\n\r\n", True),
        ]
        for raw, expected in scenarios:
            protocol = PythonProtocol()
            protocol.feed(raw)
            built = CallbackRequest.from_parser(protocol)
            assert built.should_close() is expected

    def test_get_header(self):
        """Header lookup ignores case and returns None when absent."""
        protocol = PythonProtocol()
        raw = (
            b"GET / HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"X-Custom: value\r\n"
            b"\r\n"
        )
        protocol.feed(raw)

        built = CallbackRequest.from_parser(protocol)

        assert built.get_header("Host") == "example.com"
        assert built.get_header("x-custom") == "value"
        assert built.get_header("X-CUSTOM") == "value"
        assert built.get_header("X-Missing") is None

    def test_expect_100_continue(self):
        """An Expect: 100-continue header sets the private flag."""
        protocol = PythonProtocol()
        raw = (
            b"POST / HTTP/1.1\r\n"
            b"Host: example.com\r\n"
            b"Expect: 100-continue\r\n"
            b"Content-Length: 10\r\n"
            b"\r\n"
        )
        protocol.feed(raw)

        built = CallbackRequest.from_parser(protocol)

        assert built._expect_100_continue is True
|
||||
|
||||
|
||||
class TestPythonProtocolConnectionClose:
    """How should_keep_alive is derived from version and Connection header."""

    def test_connection_close_header(self):
        """Connection: close turns keep-alive off."""
        protocol = PythonProtocol()
        raw = b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n"

        protocol.feed(raw)

        assert protocol.should_keep_alive is False

    def test_connection_keepalive_header(self):
        """Connection: keep-alive leaves keep-alive on."""
        protocol = PythonProtocol()
        raw = b"GET / HTTP/1.1\r\nConnection: keep-alive\r\n\r\n"

        protocol.feed(raw)

        assert protocol.should_keep_alive is True

    def test_http11_default_keepalive(self):
        """HTTP/1.1 without a Connection header keeps the connection."""
        protocol = PythonProtocol()
        raw = b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"

        protocol.feed(raw)

        assert protocol.should_keep_alive is True

    def test_http10_default_close(self):
        """HTTP/1.0 without a Connection header closes the connection."""
        protocol = PythonProtocol()
        raw = b"GET / HTTP/1.0\r\nHost: example.com\r\n\r\n"

        protocol.feed(raw)

        assert protocol.should_keep_alive is False
|
||||
Loading…
x
Reference in New Issue
Block a user