diff --git a/gunicorn/asgi/parser.py b/gunicorn/asgi/parser.py index ac5b0d0f..fbc74d51 100644 --- a/gunicorn/asgi/parser.py +++ b/gunicorn/asgi/parser.py @@ -921,3 +921,424 @@ class FastAsyncRequest: data = await self.read_body(8192) if not data: break + + +class ParseError(Exception): + """Error raised during HTTP parsing.""" + pass + + +class PythonProtocol: + """Callback-based HTTP/1.1 parser (pure Python fallback). + + Mirrors H1CProtocol interface for seamless switching between + the C extension and pure Python implementations. + + Callbacks: + on_message_begin: () -> None - Called when request starts + on_url: (url: bytes) -> None - Called with request URL/path + on_header: (name: bytes, value: bytes) -> None - Called for each header + on_headers_complete: () -> bool - Called when headers done (return True to skip body) + on_body: (chunk: bytes) -> None - Called with body data chunks + on_message_complete: () -> None - Called when request is complete + """ + + __slots__ = ( + '_on_message_begin', '_on_url', '_on_header', + '_on_headers_complete', '_on_body', '_on_message_complete', + '_state', '_buffer', '_headers_list', + 'method', 'path', 'http_version', 'headers', + 'content_length', 'is_chunked', 'should_keep_alive', 'is_complete', + '_body_remaining', '_skip_body', + '_chunk_state', '_chunk_size', '_chunk_remaining', + ) + + def __init__( + self, + on_message_begin=None, + on_url=None, + on_header=None, + on_headers_complete=None, + on_body=None, + on_message_complete=None, + ): + self._on_message_begin = on_message_begin + self._on_url = on_url + self._on_header = on_header + self._on_headers_complete = on_headers_complete + self._on_body = on_body + self._on_message_complete = on_message_complete + + # Parser state: request_line, headers, body, chunked_size, chunked_data, complete + self._state = 'request_line' + self._buffer = bytearray() + self._headers_list = [] + + # Request info (populated during parsing) + self.method = None + self.path = None + self.http_version = None + self.headers = [] + self.content_length = None + self.is_chunked = False + self.should_keep_alive = True + self.is_complete = False + + # Body state + self._body_remaining = 0 + self._skip_body = False + + # Chunked transfer state + self._chunk_state = 'size' # size, data, trailer + self._chunk_size = 0 + self._chunk_remaining = 0 + + def feed(self, data): + """Process data, fire callbacks synchronously. + + Args: + data: bytes or bytearray of incoming data + + Raises: + ParseError: If the HTTP request is malformed + """ + self._buffer.extend(data) + + while self._buffer: + if self._state == 'request_line': + if not self._parse_request_line(): + break + elif self._state == 'headers': + if not self._parse_headers(): + break + elif self._state == 'body': + if not self._parse_body(): + break + elif self._state == 'chunked': + if not self._parse_chunked_body(): + break + else: + break + + def reset(self): + """Reset for next request (keepalive).""" + self._state = 'request_line' + self._buffer.clear() + self._headers_list = [] + self.method = None + self.path = None + self.http_version = None + self.headers = [] + self.content_length = None + self.is_chunked = False + self.should_keep_alive = True + self.is_complete = False + self._body_remaining = 0 + self._skip_body = False + self._chunk_state = 'size' + self._chunk_size = 0 + self._chunk_remaining = 0 + + def _parse_request_line(self): + """Parse request line, return True if complete.""" + idx = self._buffer.find(b'\r\n') + if idx == -1: + return False + + line = bytes(self._buffer[:idx]) + del self._buffer[:idx + 2] + + # Parse: METHOD PATH HTTP/x.y + parts = line.split(b' ', 2) + if len(parts) != 3: + raise ParseError("Invalid request line") + + self.method = parts[0] + self.path = parts[1] + + # Parse version + version = parts[2] + if version == b'HTTP/1.1': + self.http_version = (1, 1) + elif version == b'HTTP/1.0': + self.http_version = (1, 0) + else: + raise ParseError("Unsupported HTTP version") + + if self._on_message_begin: + self._on_message_begin() + if self._on_url: + self._on_url(self.path) + + self._state = 'headers' + return True + + def _parse_headers(self): + """Parse headers, return True if headers are complete.""" + while True: + idx = self._buffer.find(b'\r\n') + if idx == -1: + return False + + line = bytes(self._buffer[:idx]) + del self._buffer[:idx + 2] + + if not line: + # Empty line = end of headers + self._finalize_headers() + return True + + # Parse header + colon = line.find(b':') + if colon == -1: + raise ParseError("Invalid header") + + name = line[:colon].strip().lower() + value = line[colon + 1:].strip() + + self._headers_list.append((name, value)) + + if self._on_header: + self._on_header(name, value) + + def _finalize_headers(self): + """Called when all headers received.""" + self.headers = self._headers_list + + # Extract content-length and chunked + for name, value in self.headers: + if name == b'content-length': + self.content_length = int(value) + self._body_remaining = self.content_length + elif name == b'transfer-encoding': + self.is_chunked = b'chunked' in value.lower() + elif name == b'connection': + val = value.lower() + if b'close' in val: + self.should_keep_alive = False + elif b'keep-alive' in val: + self.should_keep_alive = True + + # HTTP/1.0 defaults to close + if self.http_version == (1, 0) and self.should_keep_alive: + # Only keep-alive if explicitly requested + has_keepalive = any( + name == b'connection' and b'keep-alive' in value.lower() + for name, value in self.headers + ) + if not has_keepalive: + self.should_keep_alive = False + + if self._on_headers_complete: + self._skip_body = self._on_headers_complete() + + # Determine next state + if self._skip_body: + self._state = 'complete' + self.is_complete = True + if self._on_message_complete: + self._on_message_complete() + elif self.is_chunked: + self._state = 'chunked' + self._chunk_state = 'size' + elif self.content_length and self.content_length > 0: + self._state = 'body' + else: + # No body + self._state = 'complete' + self.is_complete = True + if self._on_message_complete: + self._on_message_complete() + + def _parse_body(self): + """Parse Content-Length delimited body.""" + if not self._buffer or self._body_remaining <= 0: + return False + + chunk_size = min(len(self._buffer), self._body_remaining) + chunk = bytes(self._buffer[:chunk_size]) + del self._buffer[:chunk_size] + self._body_remaining -= chunk_size + + if self._on_body: + self._on_body(chunk) + + if self._body_remaining <= 0: + self._state = 'complete' + self.is_complete = True + if self._on_message_complete: + self._on_message_complete() + + return True + + def _parse_chunked_body(self): + """Parse chunked transfer encoding.""" + while self._buffer: + if self._chunk_state == 'size': + # Looking for chunk size line + idx = self._buffer.find(b'\r\n') + if idx == -1: + return False + + size_line = bytes(self._buffer[:idx]) + del self._buffer[:idx + 2] + + # Handle chunk extensions (e.g., "5;ext=value") + semicolon = size_line.find(b';') + if semicolon != -1: + size_line = size_line[:semicolon].strip() + + try: + self._chunk_size = int(size_line, 16) + except ValueError: + raise ParseError("Invalid chunk size") + + if self._chunk_size == 0: + # Final chunk - skip trailers + self._chunk_state = 'trailer' + else: + self._chunk_remaining = self._chunk_size + self._chunk_state = 'data' + + elif self._chunk_state == 'data': + # Reading chunk data + if not self._buffer: + return False + + to_read = min(len(self._buffer), self._chunk_remaining) + chunk = bytes(self._buffer[:to_read]) + del self._buffer[:to_read] + self._chunk_remaining -= to_read + + if self._on_body: + self._on_body(chunk) + + if self._chunk_remaining == 0: + # Need to consume trailing CRLF + self._chunk_state = 'crlf' + + elif self._chunk_state == 'crlf': + # Skip CRLF after chunk data + if len(self._buffer) < 2: + return False + del self._buffer[:2] # Skip \r\n + self._chunk_state = 'size' + + elif self._chunk_state == 'trailer': + # Skip trailer headers + idx = self._buffer.find(b'\r\n') + if idx == -1: + return False + + line = bytes(self._buffer[:idx]) + del self._buffer[:idx + 2] + + if not line: + # Empty line = end of trailers + self._state = 'complete' + self.is_complete = True + if self._on_message_complete: + self._on_message_complete() + return True + + return False + + +class CallbackRequest: + """Request object built from callback parser state. + + Works with both H1CProtocol (C extension) and PythonProtocol. + """ + + __slots__ = ( + 'method', 'uri', 'path', 'query', 'fragment', 'version', + 'headers', 'headers_bytes', 'scheme', + 'content_length', 'chunked', 'must_close', + 'proxy_protocol_info', '_expect_100_continue', + ) + + def __init__(self): + self.method = None + self.uri = None + self.path = None + self.query = None + self.fragment = None + self.version = None + self.headers = [] + self.headers_bytes = [] + self.scheme = "http" + self.content_length = 0 + self.chunked = False + self.must_close = False + self.proxy_protocol_info = None + self._expect_100_continue = False + + @classmethod + def from_parser(cls, parser, is_ssl=False): + """Build request from callback parser state. + + Args: + parser: H1CProtocol or PythonProtocol instance + is_ssl: Whether connection is SSL/TLS + + Returns: + CallbackRequest instance + """ + req = cls() + req.method = parser.method.decode('ascii') + + # Parse path and query from URL + raw_path = parser.path + if b'?' in raw_path: + path_part, query_part = raw_path.split(b'?', 1) + req.path = path_part.decode('latin-1') + req.query = query_part.decode('latin-1') + else: + req.path = raw_path.decode('latin-1') + req.query = '' + + req.uri = raw_path.decode('latin-1') + req.fragment = '' + req.version = parser.http_version + + # Headers - store both bytes (for ASGI scope) and strings (for compatibility) + req.headers_bytes = list(parser.headers) + req.headers = [ + (n.decode('latin-1').upper(), v.decode('latin-1')) + for n, v in parser.headers + ] + + req.scheme = 'https' if is_ssl else 'http' + req.content_length = parser.content_length or 0 + req.chunked = parser.is_chunked + req.must_close = not parser.should_keep_alive + + # Check for Expect: 100-continue + for name, value in parser.headers: + if name == b'expect' and value.lower() == b'100-continue': + req._expect_100_continue = True + break + + return req + + def should_close(self): + """Check if connection should be closed after this request.""" + if self.must_close: + return True + for name, value in self.headers: + if name == "CONNECTION": + v = value.lower().strip(" \t") + if v == "close": + return True + elif v == "keep-alive": + return False + break + return self.version <= (1, 0) + + def get_header(self, name): + """Get a header value by name (case-insensitive).""" + name = name.upper() + for h, v in self.headers: + if h == name: + return v + return None diff --git a/tests/test_python_protocol.py b/tests/test_python_protocol.py new file mode 100644 index 00000000..a8c6b8a8 --- /dev/null +++ b/tests/test_python_protocol.py @@ -0,0 +1,518 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +Tests for PythonProtocol callback-based HTTP parser. +""" + +import pytest +from gunicorn.asgi.parser import PythonProtocol, CallbackRequest, ParseError + + +class TestPythonProtocolBasic: + """Test basic request parsing.""" + + def test_simple_get_request(self): + """Test parsing a simple GET request.""" + headers_complete = [] + message_complete = [] + + parser = PythonProtocol( + on_headers_complete=lambda: headers_complete.append(True), + on_message_complete=lambda: message_complete.append(True), + ) + + data = b"GET /path HTTP/1.1\r\nHost: example.com\r\n\r\n" + parser.feed(data) + + assert parser.method == b"GET" + assert parser.path == b"/path" + assert parser.http_version == (1, 1) + assert len(parser.headers) == 1 + assert parser.headers[0] == (b"host", b"example.com") + assert parser.is_complete is True + assert len(headers_complete) == 1 + assert len(message_complete) == 1 + + def test_get_with_query_string(self): + """Test parsing GET with query string.""" + parser = PythonProtocol() + + data = b"GET /search?q=test&page=1 HTTP/1.1\r\nHost: example.com\r\n\r\n" + parser.feed(data) + + assert parser.method == b"GET" + assert parser.path == b"/search?q=test&page=1" + assert parser.is_complete is True + + def test_http_10_request(self): + """Test parsing HTTP/1.0 request.""" + parser = PythonProtocol() + + data = b"GET / HTTP/1.0\r\nHost: example.com\r\n\r\n" + parser.feed(data) + + assert parser.http_version == (1, 0) + assert parser.should_keep_alive is False # HTTP/1.0 default + + def test_http_10_with_keepalive(self): + """Test HTTP/1.0 with explicit keep-alive.""" + parser = PythonProtocol() + + data = b"GET / HTTP/1.0\r\nHost: example.com\r\nConnection: keep-alive\r\n\r\n" + parser.feed(data) + + assert parser.http_version == (1, 0) + assert parser.should_keep_alive is True + + def test_multiple_headers(self): + """Test parsing multiple headers.""" + parser = PythonProtocol() + + data = ( + b"GET / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Accept: text/html\r\n" + b"Accept-Language: en-US\r\n" + b"User-Agent: Test/1.0\r\n" + b"\r\n" + ) + parser.feed(data) + + assert len(parser.headers) == 4 + header_names = [h[0] for h in parser.headers] + assert b"host" in header_names + assert b"accept" in header_names + assert b"accept-language" in header_names + assert b"user-agent" in header_names + + +class TestPythonProtocolBody: + """Test request body parsing.""" + + def test_post_with_content_length(self): + """Test POST with Content-Length body.""" + body_chunks = [] + + parser = PythonProtocol( + on_body=lambda chunk: body_chunks.append(chunk), + ) + + data = ( + b"POST /submit HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Content-Length: 13\r\n" + b"\r\n" + b"name=testuser" + ) + parser.feed(data) + + assert parser.method == b"POST" + assert parser.content_length == 13 + assert parser.is_complete is True + assert b"".join(body_chunks) == b"name=testuser" + + def test_chunked_body(self): + """Test chunked transfer encoding.""" + body_chunks = [] + + parser = PythonProtocol( + on_body=lambda chunk: body_chunks.append(chunk), + ) + + data = ( + b"POST / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b"5\r\nhello\r\n" + b"6\r\n world\r\n" + b"0\r\n\r\n" + ) + parser.feed(data) + + assert parser.is_chunked is True + assert parser.is_complete is True + assert b"".join(body_chunks) == b"hello world" + + def test_chunked_with_extension(self): + """Test chunked with chunk extension.""" + body_chunks = [] + + parser = PythonProtocol( + on_body=lambda chunk: body_chunks.append(chunk), + ) + + data = ( + b"POST / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b"5;ext=value\r\nhello\r\n" + b"0\r\n\r\n" + ) + parser.feed(data) + + assert b"".join(body_chunks) == b"hello" + + +class TestPythonProtocolIncremental: + """Test incremental/partial data feeding.""" + + def test_partial_request_line(self): + """Test feeding partial request line.""" + parser = PythonProtocol() + + # Feed partial request line + parser.feed(b"GET /path ") + assert parser.method is None + assert parser.is_complete is False + + # Complete the request line and headers + parser.feed(b"HTTP/1.1\r\nHost: example.com\r\n\r\n") + assert parser.method == b"GET" + assert parser.is_complete is True + + def test_partial_headers(self): + """Test feeding partial headers.""" + parser = PythonProtocol() + + parser.feed(b"GET / HTTP/1.1\r\n") + parser.feed(b"Host: exa") + assert parser.is_complete is False + + parser.feed(b"mple.com\r\n\r\n") + assert parser.is_complete is True + assert parser.headers[0] == (b"host", b"example.com") + + def test_partial_body(self): + """Test feeding partial body.""" + body_chunks = [] + + parser = PythonProtocol( + on_body=lambda chunk: body_chunks.append(chunk), + ) + + parser.feed( + b"POST / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Content-Length: 10\r\n" + b"\r\n" + b"hello" + ) + assert parser.is_complete is False + + parser.feed(b"world") + assert parser.is_complete is True + assert b"".join(body_chunks) == b"helloworld" + + def test_partial_chunked_body(self): + """Test feeding partial chunked body.""" + body_chunks = [] + + parser = PythonProtocol( + on_body=lambda chunk: body_chunks.append(chunk), + ) + + parser.feed( + b"POST / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b"5\r\nhel" + ) + assert parser.is_complete is False + + parser.feed(b"lo\r\n0\r\n\r\n") + assert parser.is_complete is True + assert b"".join(body_chunks) == b"hello" + + +class TestPythonProtocolErrors: + """Test error handling.""" + + def test_invalid_request_line(self): + """Test invalid request line.""" + parser = PythonProtocol() + + with pytest.raises(ParseError): + parser.feed(b"INVALID\r\n") + + def test_invalid_header(self): + """Test invalid header (no colon).""" + parser = PythonProtocol() + + with pytest.raises(ParseError): + parser.feed(b"GET / HTTP/1.1\r\nBadHeader\r\n\r\n") + + def test_unsupported_http_version(self): + """Test unsupported HTTP version.""" + parser = PythonProtocol() + + with pytest.raises(ParseError): + parser.feed(b"GET / HTTP/2.0\r\n\r\n") + + def test_invalid_chunk_size(self): + """Test invalid chunk size.""" + parser = PythonProtocol() + + with pytest.raises(ParseError): + parser.feed( + b"POST / HTTP/1.1\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + b"XYZ\r\n" # Invalid hex + ) + + +class TestPythonProtocolReset: + """Test parser reset for keepalive.""" + + def test_reset_clears_state(self): + """Test that reset clears all state.""" + parser = PythonProtocol() + + parser.feed(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") + assert parser.is_complete is True + assert parser.method == b"GET" + + parser.reset() + + assert parser.method is None + assert parser.path is None + assert parser.http_version is None + assert parser.headers == [] + assert parser.content_length is None + assert parser.is_chunked is False + assert parser.is_complete is False + + def test_multiple_requests_keepalive(self): + """Test handling multiple requests on keepalive connection.""" + parser = PythonProtocol() + + # First request + parser.feed(b"GET /first HTTP/1.1\r\nHost: example.com\r\n\r\n") + assert parser.path == b"/first" + assert parser.is_complete is True + + parser.reset() + + # Second request + parser.feed(b"GET /second HTTP/1.1\r\nHost: example.com\r\n\r\n") + assert parser.path == b"/second" + assert parser.is_complete is True + + +class TestPythonProtocolCallbacks: + """Test callback firing.""" + + def test_all_callbacks(self): + """Test all callbacks fire in correct order.""" + events = [] + + parser = PythonProtocol( + on_message_begin=lambda: events.append("begin"), + on_url=lambda url: events.append(("url", url)), + on_header=lambda n, v: events.append(("header", n, v)), + on_headers_complete=lambda: events.append("headers_complete"), + on_body=lambda chunk: events.append(("body", chunk)), + on_message_complete=lambda: events.append("complete"), + ) + + parser.feed( + b"POST / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Content-Length: 5\r\n" + b"\r\n" + b"hello" + ) + + assert events[0] == "begin" + assert events[1] == ("url", b"/") + assert events[2] == ("header", b"host", b"example.com") + assert events[3] == ("header", b"content-length", b"5") + assert events[4] == "headers_complete" + assert events[5] == ("body", b"hello") + assert events[6] == "complete" + + def test_skip_body_callback(self): + """Test on_headers_complete returning True skips body.""" + body_chunks = [] + + parser = PythonProtocol( + on_headers_complete=lambda: True, # Skip body + on_body=lambda chunk: body_chunks.append(chunk), + ) + + parser.feed( + b"POST / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Content-Length: 5\r\n" + b"\r\n" + b"hello" + ) + + # Body should be skipped + assert parser.is_complete is True + assert len(body_chunks) == 0 + + +class TestCallbackRequest: + """Test CallbackRequest adapter.""" + + def test_from_parser_simple(self): + """Test creating request from parser state.""" + parser = PythonProtocol() + parser.feed(b"GET /path?query=value HTTP/1.1\r\nHost: example.com\r\n\r\n") + + request = CallbackRequest.from_parser(parser) + + assert request.method == "GET" + assert request.path == "/path" + assert request.query == "query=value" + assert request.uri == "/path?query=value" + assert request.version == (1, 1) + assert request.scheme == "http" + + def test_from_parser_ssl(self): + """Test SSL scheme detection.""" + parser = PythonProtocol() + parser.feed(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") + + request = CallbackRequest.from_parser(parser, is_ssl=True) + + assert request.scheme == "https" + + def test_from_parser_headers(self): + """Test header conversion.""" + parser = PythonProtocol() + parser.feed( + b"GET / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Content-Type: text/plain\r\n" + b"\r\n" + ) + + request = CallbackRequest.from_parser(parser) + + # String headers (uppercase) + assert ("HOST", "example.com") in request.headers + assert ("CONTENT-TYPE", "text/plain") in request.headers + + # Bytes headers (lowercase) + assert (b"host", b"example.com") in request.headers_bytes + assert (b"content-type", b"text/plain") in request.headers_bytes + + def test_from_parser_body_info(self): + """Test body info extraction.""" + parser = PythonProtocol() + parser.feed( + b"POST / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Content-Length: 10\r\n" + b"\r\n" + ) + + request = CallbackRequest.from_parser(parser) + + assert request.content_length == 10 + assert request.chunked is False + + def test_from_parser_chunked(self): + """Test chunked transfer detection.""" + parser = PythonProtocol() + parser.feed( + b"POST / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Transfer-Encoding: chunked\r\n" + b"\r\n" + ) + + request = CallbackRequest.from_parser(parser) + + assert request.chunked is True + + def test_should_close(self): + """Test should_close method.""" + # HTTP/1.1 with Connection: close + parser = PythonProtocol() + parser.feed(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n") + request = CallbackRequest.from_parser(parser) + assert request.should_close() is True + + # HTTP/1.1 keep-alive (default) + parser = PythonProtocol() + parser.feed(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") + request = CallbackRequest.from_parser(parser) + assert request.should_close() is False + + # HTTP/1.0 (default close) + parser = PythonProtocol() + parser.feed(b"GET / HTTP/1.0\r\n\r\n") + request = CallbackRequest.from_parser(parser) + assert request.should_close() is True + + def test_get_header(self): + """Test get_header method.""" + parser = PythonProtocol() + parser.feed( + b"GET / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"X-Custom: value\r\n" + b"\r\n" + ) + + request = CallbackRequest.from_parser(parser) + + assert request.get_header("Host") == "example.com" + assert request.get_header("x-custom") == "value" + assert request.get_header("X-CUSTOM") == "value" + assert request.get_header("X-Missing") is None + + def test_expect_100_continue(self): + """Test Expect: 100-continue detection.""" + parser = PythonProtocol() + parser.feed( + b"POST / HTTP/1.1\r\n" + b"Host: example.com\r\n" + b"Expect: 100-continue\r\n" + b"Content-Length: 10\r\n" + b"\r\n" + ) + + request = CallbackRequest.from_parser(parser) + + assert request._expect_100_continue is True + + +class TestPythonProtocolConnectionClose: + """Test connection close handling.""" + + def test_connection_close_header(self): + """Test Connection: close header.""" + parser = PythonProtocol() + parser.feed(b"GET / HTTP/1.1\r\nConnection: close\r\n\r\n") + + assert parser.should_keep_alive is False + + def test_connection_keepalive_header(self): + """Test Connection: keep-alive header.""" + parser = PythonProtocol() + parser.feed(b"GET / HTTP/1.1\r\nConnection: keep-alive\r\n\r\n") + + assert parser.should_keep_alive is True + + def test_http11_default_keepalive(self): + """Test HTTP/1.1 defaults to keep-alive.""" + parser = PythonProtocol() + parser.feed(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") + + assert parser.should_keep_alive is True + + def test_http10_default_close(self): + """Test HTTP/1.0 defaults to close.""" + parser = PythonProtocol() + parser.feed(b"GET / HTTP/1.0\r\nHost: example.com\r\n\r\n") + + assert parser.should_keep_alive is False