From 0e0dc669c81c25c15ddc0487efa51e74826d8bed Mon Sep 17 00:00:00 2001 From: Benoit Chesneau Date: Wed, 11 Feb 2026 22:55:03 +0100 Subject: [PATCH 01/11] feat(dirty): add TLV binary encoder/decoder Implement TLV (Type-Length-Value) serialization layer for the binary dirty worker protocol. This enables efficient binary data transfer without base64 encoding overhead. Supported types: - None, bool, int64, float64 - bytes (raw binary, no encoding needed) - string (UTF-8) - list, dict (nested structures) Inspired by OpenBSD msgctl/msgsnd message format. --- gunicorn/dirty/tlv.py | 305 ++++++++++++++++++++++ tests/test_dirty_tlv.py | 553 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 858 insertions(+) create mode 100644 gunicorn/dirty/tlv.py create mode 100644 tests/test_dirty_tlv.py diff --git a/gunicorn/dirty/tlv.py b/gunicorn/dirty/tlv.py new file mode 100644 index 00000000..5682b0c6 --- /dev/null +++ b/gunicorn/dirty/tlv.py @@ -0,0 +1,305 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +TLV (Type-Length-Value) Binary Encoder/Decoder + +Provides efficient binary serialization for dirty worker protocol messages. +Inspired by OpenBSD msgctl/msgsnd message format. + +Type Codes: + 0x00: None (no value bytes) + 0x01: bool (1 byte: 0x00 or 0x01) + 0x05: int64 (8 bytes big-endian signed) + 0x06: float64 (8 bytes IEEE 754) + 0x10: bytes (4-byte length + raw bytes) + 0x11: string (4-byte length + UTF-8 encoded) + 0x20: list (4-byte count + encoded elements) + 0x21: dict (4-byte count + encoded key-value pairs) +""" + +import struct + +from .errors import DirtyProtocolError + + +# Type codes +TYPE_NONE = 0x00 +TYPE_BOOL = 0x01 +TYPE_INT64 = 0x05 +TYPE_FLOAT64 = 0x06 +TYPE_BYTES = 0x10 +TYPE_STRING = 0x11 +TYPE_LIST = 0x20 +TYPE_DICT = 0x21 + +# Maximum sizes for safety +MAX_STRING_SIZE = 64 * 1024 * 1024 # 64 MB +MAX_BYTES_SIZE = 64 * 1024 * 1024 # 64 MB +MAX_LIST_SIZE = 1024 * 1024 # 1 million items +MAX_DICT_SIZE = 1024 * 1024 # 1 million items + + +class TLVEncoder: + """ + TLV binary encoder/decoder. + + Encodes Python values to binary TLV format and decodes back. + Supports: None, bool, int, float, bytes, str, list, dict. + """ + + @staticmethod + def encode(value) -> bytes: + """ + Encode a Python value to TLV binary format. + + Args: + value: Python value to encode (None, bool, int, float, + bytes, str, list, or dict) + + Returns: + bytes: TLV-encoded binary data + + Raises: + DirtyProtocolError: If value type is not supported + """ + if value is None: + return bytes([TYPE_NONE]) + + if isinstance(value, bool): + # bool must come before int since bool is a subclass of int + return bytes([TYPE_BOOL, 0x01 if value else 0x00]) + + if isinstance(value, int): + return bytes([TYPE_INT64]) + struct.pack(">q", value) + + if isinstance(value, float): + return bytes([TYPE_FLOAT64]) + struct.pack(">d", value) + + if isinstance(value, bytes): + if len(value) > MAX_BYTES_SIZE: + raise DirtyProtocolError( + f"Bytes too large: {len(value)} bytes " + f"(max: {MAX_BYTES_SIZE})" + ) + return bytes([TYPE_BYTES]) + struct.pack(">I", len(value)) + value + + if isinstance(value, str): + encoded = value.encode("utf-8") + if len(encoded) > MAX_STRING_SIZE: + raise DirtyProtocolError( + f"String too large: {len(encoded)} bytes " + f"(max: {MAX_STRING_SIZE})" + ) + return bytes([TYPE_STRING]) + struct.pack(">I", len(encoded)) + encoded + + if isinstance(value, (list, tuple)): + if len(value) > MAX_LIST_SIZE: + raise DirtyProtocolError( + f"List too large: {len(value)} items " + f"(max: {MAX_LIST_SIZE})" + ) + parts = [bytes([TYPE_LIST]), struct.pack(">I", len(value))] + for item in value: + parts.append(TLVEncoder.encode(item)) + return b"".join(parts) + + if isinstance(value, dict): + if len(value) > MAX_DICT_SIZE: + raise DirtyProtocolError( + f"Dict too large: {len(value)} items " + f"(max: {MAX_DICT_SIZE})" + ) + parts = [bytes([TYPE_DICT]), struct.pack(">I", len(value))] + for k, v in value.items(): + # Keys must be strings + if not isinstance(k, str): + raise DirtyProtocolError( + f"Dict keys must be strings, got {type(k).__name__}" + ) + parts.append(TLVEncoder.encode(k)) + parts.append(TLVEncoder.encode(v)) + return b"".join(parts) + + raise DirtyProtocolError( + f"Unsupported type for TLV encoding: {type(value).__name__}" + ) + + @staticmethod + def decode(data: bytes, offset: int = 0) -> tuple: + """ + Decode a TLV-encoded value from binary data. + + Args: + data: Binary data to decode + offset: Starting offset in the data + + Returns: + tuple: (decoded_value, new_offset) + + Raises: + DirtyProtocolError: If data is malformed or truncated + """ + if offset >= len(data): + raise DirtyProtocolError( + "Truncated TLV data: no type byte", + raw_data=data[offset:offset + 20] + ) + + type_code = data[offset] + offset += 1 + + if type_code == TYPE_NONE: + return None, offset + + if type_code == TYPE_BOOL: + if offset >= len(data): + raise DirtyProtocolError( + "Truncated TLV data: missing bool value", + raw_data=data[offset - 1:offset + 20] + ) + value = data[offset] != 0x00 + return value, offset + 1 + + if type_code == TYPE_INT64: + if offset + 8 > len(data): + raise DirtyProtocolError( + "Truncated TLV data: incomplete int64", + raw_data=data[offset - 1:offset + 20] + ) + value = struct.unpack(">q", data[offset:offset + 8])[0] + return value, offset + 8 + + if type_code == TYPE_FLOAT64: + if offset + 8 > len(data): + raise DirtyProtocolError( + "Truncated TLV data: incomplete float64", + raw_data=data[offset - 1:offset + 20] + ) + value = struct.unpack(">d", data[offset:offset + 8])[0] + return value, offset + 8 + + if type_code == TYPE_BYTES: + if offset + 4 > len(data): + raise DirtyProtocolError( + "Truncated TLV data: incomplete bytes length", + raw_data=data[offset - 1:offset + 20] + ) + length = struct.unpack(">I", data[offset:offset + 4])[0] + offset += 4 + + if length > MAX_BYTES_SIZE: + raise DirtyProtocolError( + f"Bytes too large: {length} bytes (max: {MAX_BYTES_SIZE})" + ) + + if offset + length > len(data): + raise DirtyProtocolError( + f"Truncated TLV data: expected {length} bytes, " + f"got {len(data) - offset}", + raw_data=data[offset - 5:offset + 20] + ) + value = data[offset:offset + length] + return value, offset + length + + if type_code == TYPE_STRING: + if offset + 4 > len(data): + raise DirtyProtocolError( + "Truncated TLV data: incomplete string length", + raw_data=data[offset - 1:offset + 20] + ) + length = struct.unpack(">I", data[offset:offset + 4])[0] + offset += 4 + + if length > MAX_STRING_SIZE: + raise DirtyProtocolError( + f"String too large: {length} bytes (max: {MAX_STRING_SIZE})" + ) + + if offset + length > len(data): + raise DirtyProtocolError( + f"Truncated TLV data: expected {length} bytes for string, " + f"got {len(data) - offset}", + raw_data=data[offset - 5:offset + 20] + ) + try: + value = data[offset:offset + length].decode("utf-8") + except UnicodeDecodeError as e: + raise DirtyProtocolError( + f"Invalid UTF-8 in string: {e}", + raw_data=data[offset:offset + min(length, 20)] + ) + return value, offset + length + + if type_code == TYPE_LIST: + if offset + 4 > len(data): + raise DirtyProtocolError( + "Truncated TLV data: incomplete list count", + raw_data=data[offset - 1:offset + 20] + ) + count = struct.unpack(">I", data[offset:offset + 4])[0] + offset += 4 + + if count > MAX_LIST_SIZE: + raise DirtyProtocolError( + f"List too large: {count} items (max: {MAX_LIST_SIZE})" + ) + + items = [] + for _ in range(count): + item, offset = TLVEncoder.decode(data, offset) + items.append(item) + return items, offset + + if type_code == TYPE_DICT: + if offset + 4 > len(data): + raise DirtyProtocolError( + "Truncated TLV data: incomplete dict count", + raw_data=data[offset - 1:offset + 20] + ) + count = struct.unpack(">I", data[offset:offset + 4])[0] + offset += 4 + + if count > MAX_DICT_SIZE: + raise DirtyProtocolError( + f"Dict too large: {count} items (max: {MAX_DICT_SIZE})" + ) + + result = {} + for _ in range(count): + key, offset = TLVEncoder.decode(data, offset) + if not isinstance(key, str): + raise DirtyProtocolError( + f"Dict key must be string, got {type(key).__name__}" + ) + value, offset = TLVEncoder.decode(data, offset) + result[key] = value + return result, offset + + raise DirtyProtocolError( + f"Unknown TLV type code: 0x{type_code:02x}", + raw_data=data[offset - 1:offset + 20] + ) + + @staticmethod + def decode_full(data: bytes): + """ + Decode a complete TLV-encoded value, ensuring all data is consumed. + + Args: + data: Binary data to decode + + Returns: + Decoded Python value + + Raises: + DirtyProtocolError: If data is malformed or has trailing bytes + """ + value, offset = TLVEncoder.decode(data, 0) + if offset != len(data): + raise DirtyProtocolError( + f"Trailing data after TLV: {len(data) - offset} bytes", + raw_data=data[offset:offset + 20] + ) + return value diff --git a/tests/test_dirty_tlv.py b/tests/test_dirty_tlv.py new file mode 100644 index 00000000..87727203 --- /dev/null +++ b/tests/test_dirty_tlv.py @@ -0,0 +1,553 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +"""Tests for dirty TLV binary encoder/decoder.""" + +import math +import struct +import pytest + +from gunicorn.dirty.tlv import ( + TLVEncoder, + TYPE_NONE, + TYPE_BOOL, + TYPE_INT64, + TYPE_FLOAT64, + TYPE_BYTES, + TYPE_STRING, + TYPE_LIST, + TYPE_DICT, + MAX_STRING_SIZE, + MAX_BYTES_SIZE, + MAX_LIST_SIZE, + MAX_DICT_SIZE, +) +from gunicorn.dirty.errors import DirtyProtocolError + + +class TestTLVEncoderBasicTypes: + """Tests for basic type encoding/decoding.""" + + def test_encode_decode_none(self): + """Test None encoding/decoding.""" + encoded = TLVEncoder.encode(None) + assert encoded == bytes([TYPE_NONE]) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value is None + assert offset == 1 + + def test_encode_decode_true(self): + """Test True encoding/decoding.""" + encoded = TLVEncoder.encode(True) + assert encoded == bytes([TYPE_BOOL, 0x01]) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value is True + assert offset == 2 + + def test_encode_decode_false(self): + """Test False encoding/decoding.""" + encoded = TLVEncoder.encode(False) + assert encoded == bytes([TYPE_BOOL, 0x00]) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value is False + assert offset == 2 + + def test_encode_decode_positive_int(self): + """Test positive integer encoding/decoding.""" + encoded = TLVEncoder.encode(42) + assert encoded[0] == TYPE_INT64 + assert len(encoded) == 9 # 1 type + 8 value + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == 42 + assert offset == 9 + + def test_encode_decode_negative_int(self): + """Test negative integer encoding/decoding.""" + encoded = TLVEncoder.encode(-12345) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == -12345 + + def test_encode_decode_large_int(self): + """Test large integer encoding/decoding.""" + large_val = 2**62 + encoded = TLVEncoder.encode(large_val) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == large_val + + def test_encode_decode_zero(self): + """Test zero encoding/decoding.""" + encoded = TLVEncoder.encode(0) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == 0 + + def test_encode_decode_float(self): + """Test float encoding/decoding.""" + encoded = TLVEncoder.encode(3.14159) + assert encoded[0] == TYPE_FLOAT64 + assert len(encoded) == 9 # 1 type + 8 value + + value, offset = TLVEncoder.decode(encoded, 0) + assert abs(value - 3.14159) < 1e-10 + + def test_encode_decode_negative_float(self): + """Test negative float encoding/decoding.""" + encoded = TLVEncoder.encode(-273.15) + + value, offset = TLVEncoder.decode(encoded, 0) + assert abs(value - (-273.15)) < 1e-10 + + def test_encode_decode_float_infinity(self): + """Test infinity encoding/decoding.""" + encoded = TLVEncoder.encode(float('inf')) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == float('inf') + + def test_encode_decode_float_nan(self): + """Test NaN encoding/decoding.""" + encoded = TLVEncoder.encode(float('nan')) + + value, offset = TLVEncoder.decode(encoded, 0) + assert math.isnan(value) + + +class TestTLVEncoderBytes: + """Tests for bytes encoding/decoding.""" + + def test_encode_decode_empty_bytes(self): + """Test empty bytes encoding/decoding.""" + encoded = TLVEncoder.encode(b"") + assert encoded[0] == TYPE_BYTES + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == b"" + + def test_encode_decode_bytes(self): + """Test bytes encoding/decoding.""" + data = b"\x00\x01\x02\xff\xfe\xfd" + encoded = TLVEncoder.encode(data) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == data + + def test_encode_decode_large_bytes(self): + """Test large bytes encoding/decoding.""" + data = b"x" * 10000 + encoded = TLVEncoder.encode(data) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == data + + def test_bytes_too_large(self): + """Test that bytes exceeding max size raises error.""" + # We won't actually allocate MAX_BYTES_SIZE, just check the encoding + with pytest.raises(DirtyProtocolError) as exc_info: + TLVEncoder.encode(b"x" * (MAX_BYTES_SIZE + 1)) + assert "too large" in str(exc_info.value).lower() + + +class TestTLVEncoderString: + """Tests for string encoding/decoding.""" + + def test_encode_decode_empty_string(self): + """Test empty string encoding/decoding.""" + encoded = TLVEncoder.encode("") + assert encoded[0] == TYPE_STRING + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == "" + + def test_encode_decode_ascii_string(self): + """Test ASCII string encoding/decoding.""" + encoded = TLVEncoder.encode("hello world") + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == "hello world" + + def test_encode_decode_unicode_string(self): + """Test Unicode string encoding/decoding.""" + text = "Hello, world! \u00a9 \u2603 \U0001F600" + encoded = TLVEncoder.encode(text) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == text + + def test_encode_decode_chinese(self): + """Test Chinese characters encoding/decoding.""" + text = "Hello, world!" + encoded = TLVEncoder.encode(text) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == text + + def test_encode_decode_emoji(self): + """Test emoji encoding/decoding.""" + text = "Test emoji" + encoded = TLVEncoder.encode(text) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == text + + def test_encode_decode_large_string(self): + """Test large string encoding/decoding.""" + text = "x" * 10000 + encoded = TLVEncoder.encode(text) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == text + + +class TestTLVEncoderList: + """Tests for list encoding/decoding.""" + + def test_encode_decode_empty_list(self): + """Test empty list encoding/decoding.""" + encoded = TLVEncoder.encode([]) + assert encoded[0] == TYPE_LIST + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == [] + + def test_encode_decode_simple_list(self): + """Test simple list encoding/decoding.""" + data = [1, 2, 3] + encoded = TLVEncoder.encode(data) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == data + + def test_encode_decode_mixed_list(self): + """Test mixed type list encoding/decoding.""" + data = [1, "hello", 3.14, True, None, b"bytes"] + encoded = TLVEncoder.encode(data) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == data + + def test_encode_decode_nested_list(self): + """Test nested list encoding/decoding.""" + data = [[1, 2], [3, [4, 5]], ["a", "b"]] + encoded = TLVEncoder.encode(data) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == data + + def test_encode_decode_tuple_as_list(self): + """Test that tuples are encoded as lists.""" + data = (1, 2, 3) + encoded = TLVEncoder.encode(data) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == [1, 2, 3] # Decoded as list + + def test_encode_decode_large_list(self): + """Test large list encoding/decoding.""" + data = list(range(1000)) + encoded = TLVEncoder.encode(data) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == data + + +class TestTLVEncoderDict: + """Tests for dict encoding/decoding.""" + + def test_encode_decode_empty_dict(self): + """Test empty dict encoding/decoding.""" + encoded = TLVEncoder.encode({}) + assert encoded[0] == TYPE_DICT + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == {} + + def test_encode_decode_simple_dict(self): + """Test simple dict encoding/decoding.""" + data = {"a": 1, "b": 2, "c": 3} + encoded = TLVEncoder.encode(data) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == data + + def test_encode_decode_mixed_values_dict(self): + """Test dict with mixed value types.""" + data = { + "int": 42, + "float": 3.14, + "string": "hello", + "bool": True, + "none": None, + "bytes": b"data", + "list": [1, 2, 3], + } + encoded = TLVEncoder.encode(data) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == data + + def test_encode_decode_nested_dict(self): + """Test nested dict encoding/decoding.""" + data = { + "outer": { + "inner": { + "value": 42 + }, + "list": [{"a": 1}, {"b": 2}] + } + } + encoded = TLVEncoder.encode(data) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == data + + def test_encode_dict_non_string_key(self): + """Test that non-string keys raise error.""" + data = {1: "value"} + with pytest.raises(DirtyProtocolError) as exc_info: + TLVEncoder.encode(data) + assert "keys must be strings" in str(exc_info.value).lower() + + +class TestTLVEncoderComplexStructures: + """Tests for complex nested structures.""" + + def test_encode_decode_request_like(self): + """Test encoding/decoding a request-like structure.""" + data = { + "id": 12345, + "app_path": "myapp.ml:MLApp", + "action": "predict", + "args": [b"input_data", 0.7], + "kwargs": {"temperature": 0.7, "max_tokens": 1000}, + } + encoded = TLVEncoder.encode(data) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == data + + def test_encode_decode_response_like(self): + """Test encoding/decoding a response-like structure.""" + data = { + "id": 12345, + "result": { + "predictions": [0.1, 0.2, 0.7], + "metadata": {"model": "v1.0", "latency_ms": 42}, + } + } + encoded = TLVEncoder.encode(data) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == data + + def test_encode_decode_deeply_nested(self): + """Test deeply nested structures.""" + data = {"a": {"b": {"c": {"d": {"e": {"f": "deep"}}}}}} + encoded = TLVEncoder.encode(data) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == data + + +class TestTLVEncoderRoundtrip: + """Tests for complete roundtrip using decode_full.""" + + def test_decode_full_simple(self): + """Test decode_full with simple value.""" + data = {"key": "value"} + encoded = TLVEncoder.encode(data) + + value = TLVEncoder.decode_full(encoded) + assert value == data + + def test_decode_full_trailing_data(self): + """Test decode_full raises on trailing data.""" + encoded = TLVEncoder.encode(42) + b"extra" + + with pytest.raises(DirtyProtocolError) as exc_info: + TLVEncoder.decode_full(encoded) + assert "trailing" in str(exc_info.value).lower() + + +class TestTLVEncoderErrors: + """Tests for error handling.""" + + def test_decode_empty_data(self): + """Test decoding empty data raises error.""" + with pytest.raises(DirtyProtocolError) as exc_info: + TLVEncoder.decode(b"", 0) + assert "truncated" in str(exc_info.value).lower() + + def test_decode_truncated_int(self): + """Test decoding truncated int raises error.""" + # TYPE_INT64 followed by only 4 bytes instead of 8 + data = bytes([TYPE_INT64, 0, 0, 0, 0]) + with pytest.raises(DirtyProtocolError) as exc_info: + TLVEncoder.decode(data, 0) + assert "truncated" in str(exc_info.value).lower() + + def test_decode_truncated_float(self): + """Test decoding truncated float raises error.""" + data = bytes([TYPE_FLOAT64, 0, 0, 0, 0]) + with pytest.raises(DirtyProtocolError) as exc_info: + TLVEncoder.decode(data, 0) + assert "truncated" in str(exc_info.value).lower() + + def test_decode_truncated_bytes_length(self): + """Test decoding truncated bytes length raises error.""" + data = bytes([TYPE_BYTES, 0, 0]) # Only 2 bytes of length + with pytest.raises(DirtyProtocolError) as exc_info: + TLVEncoder.decode(data, 0) + assert "truncated" in str(exc_info.value).lower() + + def test_decode_truncated_bytes_data(self): + """Test decoding truncated bytes data raises error.""" + # Says 10 bytes but only provides 5 + data = bytes([TYPE_BYTES]) + struct.pack(">I", 10) + b"12345" + with pytest.raises(DirtyProtocolError) as exc_info: + TLVEncoder.decode(data, 0) + assert "truncated" in str(exc_info.value).lower() + + def test_decode_truncated_string_length(self): + """Test decoding truncated string length raises error.""" + data = bytes([TYPE_STRING, 0]) + with pytest.raises(DirtyProtocolError) as exc_info: + TLVEncoder.decode(data, 0) + assert "truncated" in str(exc_info.value).lower() + + def test_decode_truncated_string_data(self): + """Test decoding truncated string data raises error.""" + data = bytes([TYPE_STRING]) + struct.pack(">I", 10) + b"hello" + with pytest.raises(DirtyProtocolError) as exc_info: + TLVEncoder.decode(data, 0) + assert "truncated" in str(exc_info.value).lower() + + def test_decode_invalid_utf8(self): + """Test decoding invalid UTF-8 raises error.""" + # Valid length, but invalid UTF-8 bytes + data = bytes([TYPE_STRING]) + struct.pack(">I", 3) + b"\x80\x81\x82" + with pytest.raises(DirtyProtocolError) as exc_info: + TLVEncoder.decode(data, 0) + assert "utf-8" in str(exc_info.value).lower() + + def test_decode_truncated_list_count(self): + """Test decoding truncated list count raises error.""" + data = bytes([TYPE_LIST, 0]) + with pytest.raises(DirtyProtocolError) as exc_info: + TLVEncoder.decode(data, 0) + assert "truncated" in str(exc_info.value).lower() + + def test_decode_truncated_dict_count(self): + """Test decoding truncated dict count raises error.""" + data = bytes([TYPE_DICT, 0]) + with pytest.raises(DirtyProtocolError) as exc_info: + TLVEncoder.decode(data, 0) + assert "truncated" in str(exc_info.value).lower() + + def test_decode_unknown_type(self): + """Test decoding unknown type raises error.""" + data = bytes([0xFF]) # Unknown type + with pytest.raises(DirtyProtocolError) as exc_info: + TLVEncoder.decode(data, 0) + assert "unknown" in str(exc_info.value).lower() + + def test_encode_unsupported_type(self): + """Test encoding unsupported type raises error.""" + with pytest.raises(DirtyProtocolError) as exc_info: + TLVEncoder.encode(object()) + assert "unsupported type" in str(exc_info.value).lower() + + def test_encode_function_raises_error(self): + """Test encoding a function raises error.""" + with pytest.raises(DirtyProtocolError) as exc_info: + TLVEncoder.encode(lambda x: x) + assert "unsupported type" in str(exc_info.value).lower() + + def test_decode_dict_non_string_key_in_data(self): + """Test decoding dict with non-string key raises error.""" + # Manually construct a dict with int key + # TYPE_DICT, count=1, TYPE_INT64 key, TYPE_INT64 value + data = ( + bytes([TYPE_DICT]) + + struct.pack(">I", 1) + + bytes([TYPE_INT64]) + + struct.pack(">q", 1) # Key (int, not string) + + bytes([TYPE_INT64]) + + struct.pack(">q", 2) # Value + ) + with pytest.raises(DirtyProtocolError) as exc_info: + TLVEncoder.decode(data, 0) + assert "string" in str(exc_info.value).lower() + + +class TestTLVEncoderOffset: + """Tests for offset handling.""" + + def test_decode_with_offset(self): + """Test decoding from specific offset.""" + # Create data with prefix + prefix = b"garbage" + encoded = TLVEncoder.encode(42) + data = prefix + encoded + + value, offset = TLVEncoder.decode(data, len(prefix)) + assert value == 42 + assert offset == len(prefix) + len(encoded) + + def test_decode_multiple_values(self): + """Test decoding multiple consecutive values.""" + v1 = TLVEncoder.encode("hello") + v2 = TLVEncoder.encode(42) + v3 = TLVEncoder.encode([1, 2, 3]) + data = v1 + v2 + v3 + + offset = 0 + val1, offset = TLVEncoder.decode(data, offset) + assert val1 == "hello" + + val2, offset = TLVEncoder.decode(data, offset) + assert val2 == 42 + + val3, offset = TLVEncoder.decode(data, offset) + assert val3 == [1, 2, 3] + + assert offset == len(data) + + +class TestTLVEncoderBinaryData: + """Tests for binary data handling (the main motivation for this protocol).""" + + def test_binary_data_no_encoding(self): + """Test that binary data is passed through without encoding.""" + # This is the key advantage over JSON - binary data doesn't need base64 + binary_data = bytes(range(256)) # All byte values + encoded = TLVEncoder.encode(binary_data) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == binary_data + + def test_binary_with_null_bytes(self): + """Test binary data with embedded null bytes.""" + binary_data = b"\x00\x00\xff\x00\x00" + encoded = TLVEncoder.encode(binary_data) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == binary_data + + def test_binary_in_nested_structure(self): + """Test binary data inside nested structures.""" + data = { + "image": b"\x89PNG\r\n\x1a\n" + b"\x00" * 100, + "metadata": {"width": 640, "height": 480}, + "chunks": [b"chunk1", b"chunk2", b"chunk3"], + } + encoded = TLVEncoder.encode(data) + + value, offset = TLVEncoder.decode(encoded, 0) + assert value == data From 1665857c0e4ee59cf4c68469753078ed2d3d6e15 Mon Sep 17 00:00:00 2001 From: Benoit Chesneau Date: Wed, 11 Feb 2026 22:58:43 +0100 Subject: [PATCH 02/11] feat(dirty): implement binary protocol Replace JSON-based protocol with binary format using 16-byte header: - Magic bytes (GD), version, message type, payload length, request ID - TLV-encoded payloads for efficient binary data transfer - No base64 encoding needed for binary data - Backwards compatible API (DirtyProtocol alias, dict-based interface) Header format inspired by OpenBSD msgctl/msgsnd. --- gunicorn/dirty/protocol.py | 530 +++++++++++++++++++++++++++-------- tests/test_dirty_protocol.py | 448 +++++++++++++++++++---------- 2 files changed, 711 insertions(+), 267 deletions(-) diff --git a/gunicorn/dirty/protocol.py b/gunicorn/dirty/protocol.py index e5ac6cfa..15fab29a 100644 --- a/gunicorn/dirty/protocol.py +++ b/gunicorn/dirty/protocol.py @@ -3,89 +3,304 @@ # See the NOTICE for more information. """ -Dirty Arbiters Protocol +Dirty Worker Binary Protocol -Length-prefixed JSON message framing over Unix sockets. -Provides both async (primary) and sync (for HTTP workers) APIs. +Binary message framing over Unix sockets, inspired by OpenBSD msgctl/msgsnd. +Replaces JSON protocol for efficient binary data transfer. -Message Format: -+----------------+------------------+ -| 4-byte length | JSON payload | -+----------------+------------------+ +Header Format (16 bytes): ++--------+--------+--------+--------+--------+--------+--------+--------+ +| Magic (2B) | Ver(1) | MType | Payload Length (4B) | ++--------+--------+--------+--------+--------+--------+--------+--------+ +| Request ID (8 bytes) | ++--------+--------+--------+--------+--------+--------+--------+--------+ -The length field is a 4-byte unsigned integer in network byte order (big-endian). +- Magic: 0x47 0x44 ("GD" for Gunicorn Dirty) +- Version: 0x01 +- MType: Message type (REQUEST, RESPONSE, ERROR, CHUNK, END) +- Length: Payload size (big-endian uint32, max 64MB) +- Request ID: uint64 (replaces UUID string) + +Payload is TLV-encoded (see tlv.py). """ import asyncio -import json -import struct import socket +import struct from .errors import DirtyProtocolError +from .tlv import TLVEncoder -class DirtyProtocol: - """Length-prefixed JSON messages over Unix sockets.""" +# Protocol constants +MAGIC = b"GD" # 0x47 0x44 +VERSION = 0x01 - # 4-byte unsigned int, network byte order (big-endian) - HEADER_FORMAT = "!I" - HEADER_SIZE = struct.calcsize(HEADER_FORMAT) +# Message types (1 byte) +MSG_TYPE_REQUEST = 0x01 +MSG_TYPE_RESPONSE = 0x02 +MSG_TYPE_ERROR = 0x03 +MSG_TYPE_CHUNK = 0x04 +MSG_TYPE_END = 0x05 - # Maximum message size (64 MB) - MAX_MESSAGE_SIZE = 64 * 1024 * 1024 +# Message type names (for backwards compatibility with old API) +MSG_TYPE_REQUEST_STR = "request" +MSG_TYPE_RESPONSE_STR = "response" +MSG_TYPE_ERROR_STR = "error" +MSG_TYPE_CHUNK_STR = "chunk" +MSG_TYPE_END_STR = "end" - # Message types for future streaming support - MSG_TYPE_REQUEST = "request" - MSG_TYPE_RESPONSE = "response" - MSG_TYPE_ERROR = "error" - MSG_TYPE_CHUNK = "chunk" - MSG_TYPE_END = "end" +# Map int types to string names +MSG_TYPE_TO_STR = { + MSG_TYPE_REQUEST: MSG_TYPE_REQUEST_STR, + MSG_TYPE_RESPONSE: MSG_TYPE_RESPONSE_STR, + MSG_TYPE_ERROR: MSG_TYPE_ERROR_STR, + MSG_TYPE_CHUNK: MSG_TYPE_CHUNK_STR, + MSG_TYPE_END: MSG_TYPE_END_STR, +} + +# Map string names to int types +MSG_TYPE_FROM_STR = {v: k for k, v in MSG_TYPE_TO_STR.items()} + +# Header format: Magic (2) + Version (1) + Type (1) + Length (4) + RequestID (8) = 16 +HEADER_FORMAT = ">2sBBIQ" +HEADER_SIZE = struct.calcsize(HEADER_FORMAT) + +# Maximum message size (64 MB) +MAX_MESSAGE_SIZE = 64 * 1024 * 1024 + + +class BinaryProtocol: + """Binary message protocol for dirty worker IPC.""" + + # Export constants for external use + HEADER_SIZE = HEADER_SIZE + MAX_MESSAGE_SIZE = MAX_MESSAGE_SIZE + + MSG_TYPE_REQUEST = MSG_TYPE_REQUEST_STR + MSG_TYPE_RESPONSE = MSG_TYPE_RESPONSE_STR + MSG_TYPE_ERROR = MSG_TYPE_ERROR_STR + MSG_TYPE_CHUNK = MSG_TYPE_CHUNK_STR + MSG_TYPE_END = MSG_TYPE_END_STR @staticmethod - def encode(message: dict) -> bytes: + def encode_header(msg_type: int, request_id: int, payload_length: int) -> bytes: """ - Encode a message dict to length-prefixed bytes. + Encode the 16-byte message header. Args: - message: Dictionary to encode as JSON + msg_type: Message type (MSG_TYPE_REQUEST, etc.) + request_id: Unique request identifier (uint64) + payload_length: Length of the TLV-encoded payload Returns: - bytes: Length-prefixed encoded message + bytes: 16-byte header + """ + return struct.pack(HEADER_FORMAT, MAGIC, VERSION, msg_type, + payload_length, request_id) + + @staticmethod + def decode_header(data: bytes) -> tuple: + """ + Decode the 16-byte message header. + + Args: + data: 16 bytes of header data + + Returns: + tuple: (msg_type, request_id, payload_length) Raises: - DirtyProtocolError: If encoding fails + DirtyProtocolError: If header is invalid """ - try: - payload = json.dumps(message).encode("utf-8") - if len(payload) > DirtyProtocol.MAX_MESSAGE_SIZE: + if len(data) < HEADER_SIZE: + raise DirtyProtocolError( + f"Header too short: {len(data)} bytes, expected {HEADER_SIZE}", + raw_data=data + ) + + magic, version, msg_type, length, request_id = struct.unpack( + HEADER_FORMAT, data[:HEADER_SIZE] + ) + + if magic != MAGIC: + raise DirtyProtocolError( + f"Invalid magic: {magic!r}, expected {MAGIC!r}", + raw_data=data[:20] + ) + + if version != VERSION: + raise DirtyProtocolError( + f"Unsupported protocol version: {version}, expected {VERSION}", + raw_data=data[:20] + ) + + if msg_type not in MSG_TYPE_TO_STR: + raise DirtyProtocolError( + f"Unknown message type: 0x{msg_type:02x}", + raw_data=data[:20] + ) + + if length > MAX_MESSAGE_SIZE: + raise DirtyProtocolError( + f"Message too large: {length} bytes (max: {MAX_MESSAGE_SIZE})" + ) + + return msg_type, request_id, length + + @staticmethod + def encode_request(request_id: int, app_path: str, action: str, + args: tuple = None, kwargs: dict = None) -> bytes: + """ + Encode a request message. + + Args: + request_id: Unique request identifier (uint64) + app_path: Import path of the dirty app + action: Action to call on the app + args: Positional arguments + kwargs: Keyword arguments + + Returns: + bytes: Complete message (header + payload) + """ + payload_dict = { + "app_path": app_path, + "action": action, + "args": list(args) if args else [], + "kwargs": kwargs or {}, + } + payload = TLVEncoder.encode(payload_dict) + header = BinaryProtocol.encode_header(MSG_TYPE_REQUEST, request_id, + len(payload)) + return header + payload + + @staticmethod + def encode_response(request_id: int, result) -> bytes: + """ + Encode a success response message. + + Args: + request_id: Request identifier this responds to + result: Result value (must be TLV-serializable) + + Returns: + bytes: Complete message (header + payload) + """ + payload_dict = {"result": result} + payload = TLVEncoder.encode(payload_dict) + header = BinaryProtocol.encode_header(MSG_TYPE_RESPONSE, request_id, + len(payload)) + return header + payload + + @staticmethod + def encode_error(request_id: int, error) -> bytes: + """ + Encode an error response message. + + Args: + request_id: Request identifier this responds to + error: DirtyError instance, dict, or Exception + + Returns: + bytes: Complete message (header + payload) + """ + from .errors import DirtyError + + if isinstance(error, DirtyError): + error_dict = error.to_dict() + elif isinstance(error, dict): + error_dict = error + else: + error_dict = { + "error_type": type(error).__name__, + "message": str(error), + "details": {}, + } + + payload_dict = {"error": error_dict} + payload = TLVEncoder.encode(payload_dict) + header = BinaryProtocol.encode_header(MSG_TYPE_ERROR, request_id, + len(payload)) + return header + payload + + @staticmethod + def encode_chunk(request_id: int, data) -> bytes: + """ + Encode a chunk message for streaming responses. + + Args: + request_id: Request identifier this chunk belongs to + data: Chunk data (must be TLV-serializable) + + Returns: + bytes: Complete message (header + payload) + """ + payload_dict = {"data": data} + payload = TLVEncoder.encode(payload_dict) + header = BinaryProtocol.encode_header(MSG_TYPE_CHUNK, request_id, + len(payload)) + return header + payload + + @staticmethod + def encode_end(request_id: int) -> bytes: + """ + Encode an end-of-stream message. + + Args: + request_id: Request identifier this ends + + Returns: + bytes: Complete message (header + empty payload) + """ + # End message has empty payload + header = BinaryProtocol.encode_header(MSG_TYPE_END, request_id, 0) + return header + + @staticmethod + def decode_message(data: bytes) -> tuple: + """ + Decode a complete message (header + payload). + + Args: + data: Complete message bytes + + Returns: + tuple: (msg_type_str, request_id, payload_dict) + msg_type_str is the string name (e.g., "request") + payload_dict is the decoded TLV payload as a dict + + Raises: + DirtyProtocolError: If message is malformed + """ + msg_type, request_id, length = BinaryProtocol.decode_header(data) + + if len(data) < HEADER_SIZE + length: + raise DirtyProtocolError( + f"Incomplete message: expected {HEADER_SIZE + length} bytes, " + f"got {len(data)}", + raw_data=data[:50] + ) + + if length == 0: + # End message has empty payload + payload_dict = {} + else: + payload_data = data[HEADER_SIZE:HEADER_SIZE + length] + try: + payload_dict = TLVEncoder.decode_full(payload_data) + except DirtyProtocolError: + raise + except Exception as e: raise DirtyProtocolError( - f"Message too large: {len(payload)} bytes " - f"(max: {DirtyProtocol.MAX_MESSAGE_SIZE})" + f"Failed to decode TLV payload: {e}", + raw_data=payload_data[:50] ) - length = struct.pack(DirtyProtocol.HEADER_FORMAT, len(payload)) - return length + payload - except (TypeError, ValueError) as e: - raise DirtyProtocolError(f"Failed to encode message: {e}") - @staticmethod - def decode(data: bytes) -> dict: - """ - Decode bytes (without length prefix) to message dict. + # Convert to dict format similar to old JSON protocol + msg_type_str = MSG_TYPE_TO_STR[msg_type] - Args: - data: JSON bytes to decode - - Returns: - dict: Decoded message - - Raises: - DirtyProtocolError: If decoding fails - """ - try: - return json.loads(data.decode("utf-8")) - except (json.JSONDecodeError, UnicodeDecodeError) as e: - raise DirtyProtocolError(f"Failed to decode message: {e}", - raw_data=data) + return msg_type_str, request_id, payload_dict # ------------------------------------------------------------------------- # Async API (primary - for DirtyArbiter and DirtyWorker) @@ -94,53 +309,62 @@ class DirtyProtocol: @staticmethod async def read_message_async(reader: asyncio.StreamReader) -> dict: """ - Read a complete message from async stream. + Read a complete binary message from async stream. Args: reader: asyncio StreamReader Returns: - dict: Decoded message + dict: Message dict with 'type', 'id', and payload fields Raises: DirtyProtocolError: If read fails or message is malformed asyncio.IncompleteReadError: If connection closed mid-read """ - # Read length header + # Read header try: - header = await reader.readexactly(DirtyProtocol.HEADER_SIZE) + header = await reader.readexactly(HEADER_SIZE) except asyncio.IncompleteReadError as e: if len(e.partial) == 0: # Clean close - no data was read raise raise DirtyProtocolError( f"Incomplete header: got {len(e.partial)} bytes, " - f"expected {DirtyProtocol.HEADER_SIZE}", + f"expected {HEADER_SIZE}", raw_data=e.partial ) - length = struct.unpack(DirtyProtocol.HEADER_FORMAT, header)[0] - - if length > DirtyProtocol.MAX_MESSAGE_SIZE: - raise DirtyProtocolError( - f"Message too large: {length} bytes " - f"(max: {DirtyProtocol.MAX_MESSAGE_SIZE})" - ) - - if length == 0: - raise DirtyProtocolError("Empty message received") + msg_type, request_id, length = BinaryProtocol.decode_header(header) # Read payload - try: - payload = await reader.readexactly(length) - except asyncio.IncompleteReadError as e: - raise DirtyProtocolError( - f"Incomplete message: got {len(e.partial)} bytes, " - f"expected {length}", - raw_data=e.partial - ) + if length > 0: + try: + payload_data = await reader.readexactly(length) + except asyncio.IncompleteReadError as e: + raise DirtyProtocolError( + f"Incomplete payload: got {len(e.partial)} bytes, " + f"expected {length}", + raw_data=e.partial + ) - return DirtyProtocol.decode(payload) + try: + payload_dict = TLVEncoder.decode_full(payload_data) + except DirtyProtocolError: + raise + except Exception as e: + raise DirtyProtocolError( + f"Failed to decode TLV payload: {e}", + raw_data=payload_data[:50] + ) + else: + payload_dict = {} + + # Build response dict + msg_type_str = MSG_TYPE_TO_STR[msg_type] + result = {"type": msg_type_str, "id": request_id} + result.update(payload_dict) + + return result @staticmethod async def write_message_async(writer: asyncio.StreamWriter, @@ -148,15 +372,17 @@ class DirtyProtocol: """ Write a message to async stream. + Accepts dict format for backwards compatibility. + Args: writer: asyncio StreamWriter - message: Dictionary to send + message: Message dict with 'type', 'id', and payload fields Raises: DirtyProtocolError: If encoding fails ConnectionError: If write fails """ - data = DirtyProtocol.encode(message) + data = BinaryProtocol._encode_from_dict(message) writer.write(data) await writer.drain() @@ -201,27 +427,36 @@ class DirtyProtocol: sock: Socket to read from Returns: - dict: Decoded message + dict: Message dict with 'type', 'id', and payload fields Raises: DirtyProtocolError: If read fails or message is malformed """ - # Read length header - header = DirtyProtocol._recv_exactly(sock, DirtyProtocol.HEADER_SIZE) - length = struct.unpack(DirtyProtocol.HEADER_FORMAT, header)[0] - - if length > DirtyProtocol.MAX_MESSAGE_SIZE: - raise DirtyProtocolError( - f"Message too large: {length} bytes " - f"(max: {DirtyProtocol.MAX_MESSAGE_SIZE})" - ) - - if length == 0: - raise DirtyProtocolError("Empty message received") + # Read header + header = BinaryProtocol._recv_exactly(sock, HEADER_SIZE) + msg_type, request_id, length = BinaryProtocol.decode_header(header) # Read payload - payload = DirtyProtocol._recv_exactly(sock, length) - return DirtyProtocol.decode(payload) + if length > 0: + payload_data = BinaryProtocol._recv_exactly(sock, length) + try: + payload_dict = TLVEncoder.decode_full(payload_data) + except DirtyProtocolError: + raise + except Exception as e: + raise DirtyProtocolError( + f"Failed to decode TLV payload: {e}", + raw_data=payload_data[:50] + ) + else: + payload_dict = {} + + # Build response dict + msg_type_str = MSG_TYPE_TO_STR[msg_type] + result = {"type": msg_type_str, "id": request_id} + result.update(payload_dict) + + return result @staticmethod def write_message(sock: socket.socket, message: dict) -> None: @@ -230,31 +465,92 @@ class DirtyProtocol: Args: sock: Socket to write to - message: Dictionary to send + message: Message dict with 'type', 'id', and payload fields Raises: DirtyProtocolError: If encoding fails OSError: If write fails """ - data = DirtyProtocol.encode(message) + data = BinaryProtocol._encode_from_dict(message) sock.sendall(data) + @staticmethod + def _encode_from_dict(message: dict) -> bytes: + """ + Encode a message dict to binary format. -# Message builder helpers -def make_request(request_id: str, app_path: str, action: str, + Supports the old dict-based API for backwards compatibility. + + Args: + message: Message dict with 'type', 'id', and payload fields + + Returns: + bytes: Complete encoded message + """ + msg_type_str = message.get("type") + request_id = message.get("id", 0) + + # Handle string or int request IDs + if isinstance(request_id, str): + # For backwards compat with UUID strings, hash to int + request_id = hash(request_id) & 0xFFFFFFFFFFFFFFFF + + msg_type = MSG_TYPE_FROM_STR.get(msg_type_str) + if msg_type is None: + raise DirtyProtocolError(f"Unknown message type: {msg_type_str}") + + if msg_type == MSG_TYPE_REQUEST: + return BinaryProtocol.encode_request( + request_id, + message.get("app_path", ""), + message.get("action", ""), + message.get("args"), + message.get("kwargs") + ) + elif msg_type == MSG_TYPE_RESPONSE: + return BinaryProtocol.encode_response( + request_id, + message.get("result") + ) + elif msg_type == MSG_TYPE_ERROR: + return BinaryProtocol.encode_error( + request_id, + message.get("error", {}) + ) + elif msg_type == MSG_TYPE_CHUNK: + return BinaryProtocol.encode_chunk( + request_id, + message.get("data") + ) + elif msg_type == MSG_TYPE_END: + return BinaryProtocol.encode_end(request_id) + else: + raise DirtyProtocolError(f"Unhandled message type: {msg_type}") + + +# ============================================================================= +# Backwards Compatibility Aliases +# ============================================================================= + +# Alias BinaryProtocol as DirtyProtocol for drop-in replacement +DirtyProtocol = BinaryProtocol + + +# Message builder helpers (backwards compatible with old API) +def make_request(request_id, app_path: str, action: str, args: tuple = None, kwargs: dict = None) -> dict: """ - Build a request message. + Build a request message dict. Args: - request_id: Unique request identifier + request_id: Unique request identifier (int or str) app_path: Import path of the dirty app (e.g., 'myapp.ml:MLApp') action: Action to call on the app args: Positional arguments kwargs: Keyword arguments Returns: - dict: Request message + dict: Request message dict """ return { "type": DirtyProtocol.MSG_TYPE_REQUEST, @@ -266,16 +562,16 @@ def make_request(request_id: str, app_path: str, action: str, } -def make_response(request_id: str, result) -> dict: +def make_response(request_id, result) -> dict: """ - Build a success response message. + Build a success response message dict. Args: request_id: Request identifier this responds to - result: Result value (must be JSON-serializable) + result: Result value Returns: - dict: Response message + dict: Response message dict """ return { "type": DirtyProtocol.MSG_TYPE_RESPONSE, @@ -284,16 +580,16 @@ def make_response(request_id: str, result) -> dict: } -def make_error_response(request_id: str, error) -> dict: +def make_error_response(request_id, error) -> dict: """ - Build an error response message. + Build an error response message dict. Args: request_id: Request identifier this responds to error: DirtyError instance or dict with error info Returns: - dict: Error response message + dict: Error response message dict """ from .errors import DirtyError if isinstance(error, DirtyError): @@ -314,16 +610,16 @@ def make_error_response(request_id: str, error) -> dict: } -def make_chunk_message(request_id: str, data) -> dict: +def make_chunk_message(request_id, data) -> dict: """ - Build a chunk message for streaming responses. + Build a chunk message dict for streaming responses. Args: request_id: Request identifier this chunk belongs to - data: Chunk data (must be JSON-serializable) + data: Chunk data Returns: - dict: Chunk message + dict: Chunk message dict """ return { "type": DirtyProtocol.MSG_TYPE_CHUNK, @@ -332,15 +628,15 @@ def make_chunk_message(request_id: str, data) -> dict: } -def make_end_message(request_id: str) -> dict: +def make_end_message(request_id) -> dict: """ - Build an end-of-stream message. + Build an end-of-stream message dict. Args: request_id: Request identifier this ends Returns: - dict: End message + dict: End message dict """ return { "type": DirtyProtocol.MSG_TYPE_END, diff --git a/tests/test_dirty_protocol.py b/tests/test_dirty_protocol.py index dbabc51e..48fe3333 100644 --- a/tests/test_dirty_protocol.py +++ b/tests/test_dirty_protocol.py @@ -2,7 +2,7 @@ # This file is part of gunicorn released under the MIT license. # See the NOTICE for more information. -"""Tests for dirty arbiter protocol module.""" +"""Tests for dirty worker binary protocol module.""" import asyncio import os @@ -11,12 +11,23 @@ import struct import pytest from gunicorn.dirty.protocol import ( + BinaryProtocol, DirtyProtocol, make_request, make_response, make_error_response, make_chunk_message, make_end_message, + MAGIC, + VERSION, + HEADER_SIZE, + HEADER_FORMAT, + MSG_TYPE_REQUEST, + MSG_TYPE_RESPONSE, + MSG_TYPE_ERROR, + MSG_TYPE_CHUNK, + MSG_TYPE_END, + MAX_MESSAGE_SIZE, ) from gunicorn.dirty.errors import ( DirtyError, @@ -26,118 +37,194 @@ from gunicorn.dirty.errors import ( ) -class TestDirtyProtocolEncodeDecode: - """Tests for encode/decode functionality.""" +class TestBinaryProtocolHeader: + """Tests for header encoding/decoding.""" - def test_encode_decode_roundtrip(self): - """Test basic encode/decode roundtrip.""" - message = {"type": "request", "id": "123", "data": "hello"} - encoded = DirtyProtocol.encode(message) + def test_header_size(self): + """Test header size is 16 bytes.""" + assert HEADER_SIZE == 16 - # Check header format - assert len(encoded) > DirtyProtocol.HEADER_SIZE - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - encoded[:DirtyProtocol.HEADER_SIZE] - )[0] - assert length == len(encoded) - DirtyProtocol.HEADER_SIZE + def test_encode_header(self): + """Test header encoding.""" + header = BinaryProtocol.encode_header(MSG_TYPE_REQUEST, 12345, 100) + assert len(header) == HEADER_SIZE + assert header[:2] == MAGIC + assert header[2] == VERSION + assert header[3] == MSG_TYPE_REQUEST - # Decode payload - payload = encoded[DirtyProtocol.HEADER_SIZE:] - decoded = DirtyProtocol.decode(payload) - assert decoded == message + def test_decode_header(self): + """Test header decoding.""" + header = BinaryProtocol.encode_header(MSG_TYPE_RESPONSE, 67890, 200) + msg_type, request_id, length = BinaryProtocol.decode_header(header) + assert msg_type == MSG_TYPE_RESPONSE + assert request_id == 67890 + assert length == 200 - def test_encode_decode_complex_data(self): - """Test with complex nested data structures.""" - message = { - "type": "response", - "id": "456", - "result": { - "models": ["gpt-4", "claude-3"], - "config": {"temperature": 0.7, "max_tokens": 1000}, - "metadata": None, - }, - "args": [1, 2, 3], - } - encoded = DirtyProtocol.encode(message) - payload = encoded[DirtyProtocol.HEADER_SIZE:] - decoded = DirtyProtocol.decode(payload) - assert decoded == message + def test_decode_header_invalid_magic(self): + """Test header decoding with invalid magic.""" + header = b"XX" + b"\x01\x01" + b"\x00" * 12 + with pytest.raises(DirtyProtocolError) as exc_info: + BinaryProtocol.decode_header(header) + assert "magic" in str(exc_info.value).lower() - def test_encode_decode_unicode(self): - """Test with unicode characters.""" - message = { - "type": "request", - "data": "Hello, world!" - } - encoded = DirtyProtocol.encode(message) - payload = encoded[DirtyProtocol.HEADER_SIZE:] - decoded = DirtyProtocol.decode(payload) - assert decoded == message + def test_decode_header_invalid_version(self): + """Test header decoding with invalid version.""" + header = MAGIC + b"\x99\x01" + b"\x00" * 12 + with pytest.raises(DirtyProtocolError) as exc_info: + BinaryProtocol.decode_header(header) + assert "version" in str(exc_info.value).lower() - def test_encode_large_message(self): + def test_decode_header_invalid_type(self): + """Test header decoding with invalid message type.""" + header = MAGIC + bytes([VERSION, 0xFF]) + b"\x00" * 12 + with pytest.raises(DirtyProtocolError) as exc_info: + BinaryProtocol.decode_header(header) + assert "type" in str(exc_info.value).lower() + + def test_decode_header_too_large(self): + """Test header decoding rejects too-large messages.""" + header = struct.pack(HEADER_FORMAT, MAGIC, VERSION, MSG_TYPE_REQUEST, + MAX_MESSAGE_SIZE + 1, 0) + with pytest.raises(DirtyProtocolError) as exc_info: + BinaryProtocol.decode_header(header) + assert "too large" in str(exc_info.value).lower() + + def test_decode_header_too_short(self): + """Test header decoding with too-short data.""" + header = MAGIC + b"\x01" + with pytest.raises(DirtyProtocolError) as exc_info: + BinaryProtocol.decode_header(header) + assert "short" in str(exc_info.value).lower() + + +class TestBinaryProtocolEncodeDecode: + """Tests for message encoding/decoding.""" + + def test_encode_decode_request(self): + """Test request encoding/decoding roundtrip.""" + encoded = BinaryProtocol.encode_request( + request_id=12345, + app_path="myapp.ml:MLApp", + action="predict", + args=("data",), + kwargs={"temperature": 0.7} + ) + assert len(encoded) > HEADER_SIZE + + msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded) + assert msg_type_str == "request" + assert request_id == 12345 + assert payload["app_path"] == "myapp.ml:MLApp" + assert payload["action"] == "predict" + assert payload["args"] == ["data"] + assert payload["kwargs"] == {"temperature": 0.7} + + def test_encode_decode_response(self): + """Test response encoding/decoding roundtrip.""" + result = {"predictions": [0.1, 0.9], "metadata": {"model": "v1"}} + encoded = BinaryProtocol.encode_response(request_id=67890, result=result) + + msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded) + assert msg_type_str == "response" + assert request_id == 67890 + assert payload["result"] == result + + def test_encode_decode_error(self): + """Test error encoding/decoding roundtrip.""" + error = DirtyTimeoutError("Timed out", timeout=30) + encoded = BinaryProtocol.encode_error(request_id=11111, error=error) + + msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded) + assert msg_type_str == "error" + assert request_id == 11111 + assert payload["error"]["error_type"] == "DirtyTimeoutError" + assert "Timed out" in payload["error"]["message"] + + def test_encode_decode_chunk(self): + """Test chunk encoding/decoding roundtrip.""" + chunk_data = {"token": "hello", "index": 5} + encoded = BinaryProtocol.encode_chunk(request_id=22222, data=chunk_data) + + msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded) + assert msg_type_str == "chunk" + assert request_id == 22222 + assert payload["data"] == chunk_data + + def test_encode_decode_end(self): + """Test end message encoding/decoding roundtrip.""" + encoded = BinaryProtocol.encode_end(request_id=33333) + assert len(encoded) == HEADER_SIZE # End has no payload + + msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded) + assert msg_type_str == "end" + assert request_id == 33333 + assert payload == {} + + def test_encode_decode_binary_data(self): + """Test binary data passes through without base64 encoding.""" + binary_data = bytes(range(256)) + encoded = BinaryProtocol.encode_response( + request_id=44444, + result={"data": binary_data} + ) + + msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded) + assert payload["result"]["data"] == binary_data + + def test_encode_decode_large_message(self): """Test encoding a large message.""" - large_data = "x" * (1024 * 1024) # 1 MB of data - message = {"type": "request", "data": large_data} - encoded = DirtyProtocol.encode(message) - payload = encoded[DirtyProtocol.HEADER_SIZE:] - decoded = DirtyProtocol.decode(payload) - assert decoded == message + large_data = b"x" * (1024 * 1024) # 1 MB + encoded = BinaryProtocol.encode_response( + request_id=55555, + result={"data": large_data} + ) - def test_encode_empty_dict(self): - """Test encoding an empty dictionary.""" - message = {} - encoded = DirtyProtocol.encode(message) - payload = encoded[DirtyProtocol.HEADER_SIZE:] - decoded = DirtyProtocol.decode(payload) - assert decoded == message - - def test_encode_message_too_large(self): - """Test that encoding a message that's too large raises error.""" - large_data = "x" * (DirtyProtocol.MAX_MESSAGE_SIZE + 1000) - message = {"data": large_data} - with pytest.raises(DirtyProtocolError) as exc_info: - DirtyProtocol.encode(message) - assert "too large" in str(exc_info.value) - - def test_encode_non_serializable(self): - """Test that encoding non-JSON-serializable data raises error.""" - message = {"func": lambda x: x} - with pytest.raises(DirtyProtocolError) as exc_info: - DirtyProtocol.encode(message) - assert "Failed to encode" in str(exc_info.value) - - def test_decode_invalid_json(self): - """Test decoding invalid JSON raises error.""" - invalid_data = b"not valid json" - with pytest.raises(DirtyProtocolError) as exc_info: - DirtyProtocol.decode(invalid_data) - assert "Failed to decode" in str(exc_info.value) - - def test_decode_invalid_unicode(self): - """Test decoding invalid unicode raises error.""" - invalid_data = b"\x80\x81\x82" - with pytest.raises(DirtyProtocolError) as exc_info: - DirtyProtocol.decode(invalid_data) - assert "Failed to decode" in str(exc_info.value) + msg_type_str, request_id, payload = BinaryProtocol.decode_message(encoded) + assert payload["result"]["data"] == large_data -class TestDirtyProtocolSync: +class TestBinaryProtocolSync: """Tests for synchronous socket operations.""" def test_read_write_message(self): """Test read/write through socket pair.""" - # Create a socket pair for testing server_sock, client_sock = socket.socketpair() try: - message = {"type": "request", "id": "123", "action": "test"} + message = make_request( + request_id=12345, + app_path="test:App", + action="run" + ) - # Write message - DirtyProtocol.write_message(client_sock, message) + BinaryProtocol.write_message(client_sock, message) + received = BinaryProtocol.read_message(server_sock) - # Read message - received = DirtyProtocol.read_message(server_sock) - assert received == message + assert received["type"] == "request" + assert received["id"] == hash("12345") & 0xFFFFFFFFFFFFFFFF or \ + received["id"] == 12345 + assert received["app_path"] == "test:App" + assert received["action"] == "run" + finally: + server_sock.close() + client_sock.close() + + def test_read_write_with_int_id(self): + """Test read/write with integer request ID.""" + server_sock, client_sock = socket.socketpair() + try: + message = { + "type": "request", + "id": 999888777, + "app_path": "test:App", + "action": "run", + "args": [], + "kwargs": {} + } + + BinaryProtocol.write_message(client_sock, message) + received = BinaryProtocol.read_message(server_sock) + + assert received["id"] == 999888777 finally: server_sock.close() client_sock.close() @@ -147,19 +234,17 @@ class TestDirtyProtocolSync: server_sock, client_sock = socket.socketpair() try: messages = [ - {"type": "request", "id": "1"}, - {"type": "request", "id": "2"}, - {"type": "request", "id": "3"}, + make_request(i, f"app{i}:App", f"action{i}") + for i in range(1, 4) ] - # Write all messages for msg in messages: - DirtyProtocol.write_message(client_sock, msg) + BinaryProtocol.write_message(client_sock, msg) - # Read all messages - for expected in messages: - received = DirtyProtocol.read_message(server_sock) - assert received == expected + for i, _ in enumerate(messages, 1): + received = BinaryProtocol.read_message(server_sock) + assert received["app_path"] == f"app{i}:App" + assert received["action"] == f"action{i}" finally: server_sock.close() client_sock.close() @@ -169,39 +254,51 @@ class TestDirtyProtocolSync: server_sock, client_sock = socket.socketpair() client_sock.close() with pytest.raises(DirtyProtocolError) as exc_info: - DirtyProtocol.read_message(server_sock) + BinaryProtocol.read_message(server_sock) assert "closed" in str(exc_info.value).lower() server_sock.close() + def test_binary_data_roundtrip(self): + """Test binary data roundtrip through socket.""" + server_sock, client_sock = socket.socketpair() + try: + binary_payload = b"\x00\x01\x02\xff\xfe\xfd" + message = make_response(12345, {"binary": binary_payload}) -class TestDirtyProtocolAsync: + BinaryProtocol.write_message(client_sock, message) + received = BinaryProtocol.read_message(server_sock) + + assert received["result"]["binary"] == binary_payload + finally: + server_sock.close() + client_sock.close() + + +class TestBinaryProtocolAsync: """Tests for async stream operations.""" @pytest.mark.asyncio async def test_async_read_write(self): """Test async read/write with mock streams.""" - message = {"type": "request", "id": "123"} + message = make_request(12345, "test:App", "run") - # Create a pipe for testing read_fd, write_fd = os.pipe() try: reader = asyncio.StreamReader() _ = asyncio.StreamReaderProtocol(reader) - # Write the message to the pipe - encoded = DirtyProtocol.encode(message) + encoded = BinaryProtocol._encode_from_dict(message) os.write(write_fd, encoded) os.close(write_fd) write_fd = None - # Feed data to reader data = os.read(read_fd, len(encoded)) reader.feed_data(data) reader.feed_eof() - # Read the message - received = await DirtyProtocol.read_message_async(reader) - assert received == message + received = await BinaryProtocol.read_message_async(reader) + assert received["type"] == "request" + assert received["app_path"] == "test:App" finally: if write_fd is not None: os.close(write_fd) @@ -211,12 +308,11 @@ class TestDirtyProtocolAsync: async def test_async_read_incomplete_header(self): """Test async read with incomplete header.""" reader = asyncio.StreamReader() - # Feed only 2 bytes instead of 4 - reader.feed_data(b"\x00\x00") + reader.feed_data(MAGIC + b"\x01") # Only 3 bytes reader.feed_eof() with pytest.raises((asyncio.IncompleteReadError, DirtyProtocolError)): - await DirtyProtocol.read_message_async(reader) + await BinaryProtocol.read_message_async(reader) @pytest.mark.asyncio async def test_async_read_empty_connection(self): @@ -225,36 +321,33 @@ class TestDirtyProtocolAsync: reader.feed_eof() with pytest.raises(asyncio.IncompleteReadError): - await DirtyProtocol.read_message_async(reader) + await BinaryProtocol.read_message_async(reader) + + @pytest.mark.asyncio + async def test_async_read_invalid_magic(self): + """Test async read rejects invalid magic.""" + reader = asyncio.StreamReader() + header = b"XX" + bytes([VERSION, MSG_TYPE_REQUEST]) + b"\x00" * 12 + reader.feed_data(header) + reader.feed_eof() + + with pytest.raises(DirtyProtocolError) as exc_info: + await BinaryProtocol.read_message_async(reader) + assert "magic" in str(exc_info.value).lower() @pytest.mark.asyncio async def test_async_read_message_too_large(self): """Test async read rejects too-large messages.""" reader = asyncio.StreamReader() - # Create a header claiming an absurdly large message - header = struct.pack( - DirtyProtocol.HEADER_FORMAT, - DirtyProtocol.MAX_MESSAGE_SIZE + 1000 - ) + header = struct.pack(HEADER_FORMAT, MAGIC, VERSION, MSG_TYPE_REQUEST, + MAX_MESSAGE_SIZE + 1000, 0) reader.feed_data(header) reader.feed_eof() with pytest.raises(DirtyProtocolError) as exc_info: - await DirtyProtocol.read_message_async(reader) + await BinaryProtocol.read_message_async(reader) assert "too large" in str(exc_info.value) - @pytest.mark.asyncio - async def test_async_read_empty_message(self): - """Test async read rejects empty messages.""" - reader = asyncio.StreamReader() - header = struct.pack(DirtyProtocol.HEADER_FORMAT, 0) - reader.feed_data(header) - reader.feed_eof() - - with pytest.raises(DirtyProtocolError) as exc_info: - await DirtyProtocol.read_message_async(reader) - assert "Empty message" in str(exc_info.value) - class TestMessageBuilders: """Tests for message builder helper functions.""" @@ -340,9 +433,9 @@ class TestMessageBuilders: assert chunk["id"] == "req-456" assert chunk["data"] == data - def test_make_chunk_message_with_list_data(self): - """Test chunk message with list data.""" - data = [1, 2, 3, "token"] + def test_make_chunk_message_with_binary_data(self): + """Test chunk message with binary data.""" + data = b"\x00\x01\x02\xff" chunk = make_chunk_message("req-789", data) assert chunk["data"] == data @@ -353,22 +446,22 @@ class TestMessageBuilders: assert end["id"] == "req-123" assert "data" not in end - def test_chunk_and_end_encode_decode(self): + def test_chunk_and_end_roundtrip(self): """Test that chunk and end messages can be encoded/decoded.""" - chunk = make_chunk_message("req-123", {"token": "hello"}) - end = make_end_message("req-123") + chunk = make_chunk_message(12345, {"token": "hello"}) + end = make_end_message(12345) # Test chunk roundtrip - encoded_chunk = DirtyProtocol.encode(chunk) - payload = encoded_chunk[DirtyProtocol.HEADER_SIZE:] - decoded = DirtyProtocol.decode(payload) - assert decoded == chunk + encoded_chunk = BinaryProtocol._encode_from_dict(chunk) + msg_type, req_id, payload = BinaryProtocol.decode_message(encoded_chunk) + assert msg_type == "chunk" + assert payload["data"] == {"token": "hello"} # Test end roundtrip - encoded_end = DirtyProtocol.encode(end) - payload = encoded_end[DirtyProtocol.HEADER_SIZE:] - decoded = DirtyProtocol.decode(payload) - assert decoded == end + encoded_end = BinaryProtocol._encode_from_dict(end) + msg_type, req_id, payload = BinaryProtocol.decode_message(encoded_end) + assert msg_type == "end" + assert payload == {} class TestDirtyErrors: @@ -417,3 +510,58 @@ class TestDirtyErrors: assert error.action == "run" assert error.traceback == "Traceback..." assert "myapp:App" in str(error) + + +class TestBackwardsCompatibility: + """Tests for backwards compatibility with old JSON API.""" + + def test_dirty_protocol_alias(self): + """Test that DirtyProtocol is an alias for BinaryProtocol.""" + assert DirtyProtocol is BinaryProtocol + + def test_header_size_attribute(self): + """Test HEADER_SIZE is accessible on class.""" + assert DirtyProtocol.HEADER_SIZE == 16 + + def test_msg_type_constants(self): + """Test message type constants are strings for compatibility.""" + assert DirtyProtocol.MSG_TYPE_REQUEST == "request" + assert DirtyProtocol.MSG_TYPE_RESPONSE == "response" + assert DirtyProtocol.MSG_TYPE_ERROR == "error" + assert DirtyProtocol.MSG_TYPE_CHUNK == "chunk" + assert DirtyProtocol.MSG_TYPE_END == "end" + + def test_encode_decode_preserves_dict_format(self): + """Test that read_message returns dict compatible with old API.""" + server_sock, client_sock = socket.socketpair() + try: + message = { + "type": "response", + "id": 12345, + "result": {"status": "ok"} + } + + DirtyProtocol.write_message(client_sock, message) + received = DirtyProtocol.read_message(server_sock) + + # Old API: access via dict keys + assert received["type"] == "response" + assert received["result"]["status"] == "ok" + finally: + server_sock.close() + client_sock.close() + + def test_string_request_id_handled(self): + """Test that string request IDs are handled (hashed to int).""" + server_sock, client_sock = socket.socketpair() + try: + message = make_request("uuid-string-id", "test:App", "run") + + DirtyProtocol.write_message(client_sock, message) + received = DirtyProtocol.read_message(server_sock) + + # Request ID should be converted to int + assert isinstance(received["id"], int) + finally: + server_sock.close() + client_sock.close() From 6d2139bb6cf1ad08fb721a1e4c4b7b1bf792315b Mon Sep 17 00:00:00 2001 From: Benoit Chesneau Date: Wed, 11 Feb 2026 23:01:21 +0100 Subject: [PATCH 03/11] feat(dirty): update worker for binary protocol Update worker tests to work with the binary protocol: - Use integer request IDs instead of strings - Update MockStreamWriter to decode binary messages - Import binary protocol constants from module level --- tests/test_dirty_worker.py | 45 ++++++++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 17 deletions(-) diff --git a/tests/test_dirty_worker.py b/tests/test_dirty_worker.py index f68a2276..e50e7c41 100644 --- a/tests/test_dirty_worker.py +++ b/tests/test_dirty_worker.py @@ -12,7 +12,13 @@ import pytest from gunicorn.config import Config from gunicorn.dirty.worker import DirtyWorker -from gunicorn.dirty.protocol import DirtyProtocol, make_request +from gunicorn.dirty.protocol import ( + DirtyProtocol, + BinaryProtocol, + make_request, + HEADER_SIZE, + HEADER_FORMAT, +) from gunicorn.dirty.errors import DirtyAppNotFoundError @@ -56,17 +62,22 @@ class MockStreamWriter: self._buffer += data async def drain(self): - # Decode the buffer to extract messages - while len(self._buffer) >= DirtyProtocol.HEADER_SIZE: - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - self._buffer[:DirtyProtocol.HEADER_SIZE] - )[0] - total_size = DirtyProtocol.HEADER_SIZE + length + # Decode the buffer to extract messages using binary protocol + while len(self._buffer) >= HEADER_SIZE: + # Decode header to get payload length + _, _, length = BinaryProtocol.decode_header( + self._buffer[:HEADER_SIZE] + ) + total_size = HEADER_SIZE + length if len(self._buffer) >= total_size: - msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size] + msg_data = self._buffer[:total_size] self._buffer = self._buffer[total_size:] - self.messages.append(DirtyProtocol.decode(msg_data)) + # decode_message returns (msg_type_str, request_id, payload_dict) + msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data) + # Reconstruct the dict format for backwards compatibility + result = {"type": msg_type_str, "id": request_id} + result.update(payload_dict) + self.messages.append(result) else: break @@ -246,7 +257,7 @@ class TestDirtyWorkerHandleRequest: worker.load_apps() request = make_request( - request_id="test-123", + request_id=123, app_path="tests.support_dirty_app:TestDirtyApp", action="compute", args=(2, 3), @@ -259,7 +270,7 @@ class TestDirtyWorkerHandleRequest: assert len(writer.messages) == 1 response = writer.messages[0] assert response["type"] == DirtyProtocol.MSG_TYPE_RESPONSE - assert response["id"] == "test-123" + assert response["id"] == 123 assert response["result"] == 6 @pytest.mark.asyncio @@ -282,7 +293,7 @@ class TestDirtyWorkerHandleRequest: worker.load_apps() request = make_request( - request_id="test-456", + request_id=456, app_path="tests.support_dirty_app:TestDirtyApp", action="compute", args=(2, 3), @@ -295,7 +306,7 @@ class TestDirtyWorkerHandleRequest: assert len(writer.messages) == 1 response = writer.messages[0] assert response["type"] == DirtyProtocol.MSG_TYPE_ERROR - assert response["id"] == "test-456" + assert response["id"] == 456 assert "Unknown operation" in response["error"]["message"] @pytest.mark.asyncio @@ -315,7 +326,7 @@ class TestDirtyWorkerHandleRequest: socket_path=socket_path ) - request = {"type": "unknown", "id": "test-789"} + request = {"type": "unknown", "id": 789} writer = MockStreamWriter() await worker.handle_request(request, writer) @@ -697,7 +708,7 @@ class TestDirtyWorkerRunAsync: # Create a simple test using stream reader/writer request = make_request( - request_id="conn-test", + request_id=999, app_path="tests.support_dirty_app:TestDirtyApp", action="compute", args=(5, 3), @@ -706,7 +717,7 @@ class TestDirtyWorkerRunAsync: # Mock reader and writer reader = asyncio.StreamReader() - encoded_request = DirtyProtocol.encode(request) + encoded_request = BinaryProtocol._encode_from_dict(request) reader.feed_data(encoded_request) reader.feed_eof() From 98b1b649c20bb0de7ab1f1eae967ac6396420d98 Mon Sep 17 00:00:00 2001 From: Benoit Chesneau Date: Wed, 11 Feb 2026 23:03:40 +0100 Subject: [PATCH 04/11] feat(dirty): update arbiter for binary protocol Update arbiter tests to work with the binary protocol: - Update MockStreamWriter to decode binary messages - Import binary protocol constants from module level --- tests/test_dirty_arbiter.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/tests/test_dirty_arbiter.py b/tests/test_dirty_arbiter.py index 40abb504..05f35cb0 100644 --- a/tests/test_dirty_arbiter.py +++ b/tests/test_dirty_arbiter.py @@ -14,7 +14,12 @@ import pytest from gunicorn.config import Config from gunicorn.dirty.arbiter import DirtyArbiter from gunicorn.dirty.errors import DirtyError -from gunicorn.dirty.protocol import DirtyProtocol, make_request +from gunicorn.dirty.protocol import ( + DirtyProtocol, + BinaryProtocol, + make_request, + HEADER_SIZE, +) class MockStreamWriter: @@ -29,16 +34,22 @@ class MockStreamWriter: self._buffer += data async def drain(self): - while len(self._buffer) >= DirtyProtocol.HEADER_SIZE: - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - self._buffer[:DirtyProtocol.HEADER_SIZE] - )[0] - total_size = DirtyProtocol.HEADER_SIZE + length + # Decode the buffer to extract messages using binary protocol + while len(self._buffer) >= HEADER_SIZE: + # Decode header to get payload length + _, _, length = BinaryProtocol.decode_header( + self._buffer[:HEADER_SIZE] + ) + total_size = HEADER_SIZE + length if len(self._buffer) >= total_size: - msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size] + msg_data = self._buffer[:total_size] self._buffer = self._buffer[total_size:] - self.messages.append(DirtyProtocol.decode(msg_data)) + # decode_message returns (msg_type_str, request_id, payload_dict) + msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data) + # Reconstruct the dict format for backwards compatibility + result = {"type": msg_type_str, "id": request_id} + result.update(payload_dict) + self.messages.append(result) else: break From 477b7479cc200fc0aad49c31a6dd77a265b0f9eb Mon Sep 17 00:00:00 2001 From: Benoit Chesneau Date: Wed, 11 Feb 2026 23:12:44 +0100 Subject: [PATCH 05/11] feat(dirty): update client for binary protocol Update client and streaming tests to work with the binary protocol: - Update MockStreamWriter/MockStreamReader to use BinaryProtocol - Replace string request IDs with integers - Update test assertions to decode binary protocol messages - Use HEADER_SIZE and decode_header/decode_message instead of old API --- tests/dirty/test_arbiter_streaming.py | 66 ++++++++++-------- tests/dirty/test_client_streaming.py | 78 ++++++++++------------ tests/dirty/test_client_streaming_async.py | 70 +++++++++---------- tests/dirty/test_multi_app_routing.py | 29 +++++--- tests/dirty/test_streaming_integration.py | 58 +++++++++------- tests/dirty/test_worker_streaming.py | 61 +++++++++-------- tests/docker/http2/certs/server.crt | 21 ++++++ tests/docker/http2/certs/server.key | 28 ++++++++ tests/test_dirty_integration.py | 24 ++++--- 9 files changed, 258 insertions(+), 177 deletions(-) create mode 100644 tests/docker/http2/certs/server.crt create mode 100644 tests/docker/http2/certs/server.key diff --git a/tests/dirty/test_arbiter_streaming.py b/tests/dirty/test_arbiter_streaming.py index ef15c33a..a722f2af 100644 --- a/tests/dirty/test_arbiter_streaming.py +++ b/tests/dirty/test_arbiter_streaming.py @@ -12,11 +12,13 @@ import pytest from gunicorn.dirty.protocol import ( DirtyProtocol, + BinaryProtocol, make_request, make_response, make_chunk_message, make_end_message, make_error_response, + HEADER_SIZE, ) from gunicorn.dirty.arbiter import DirtyArbiter from gunicorn.dirty.errors import DirtyError @@ -34,16 +36,22 @@ class MockStreamWriter: self._buffer += data async def drain(self): - while len(self._buffer) >= DirtyProtocol.HEADER_SIZE: - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - self._buffer[:DirtyProtocol.HEADER_SIZE] - )[0] - total_size = DirtyProtocol.HEADER_SIZE + length + # Decode the buffer to extract messages using binary protocol + while len(self._buffer) >= HEADER_SIZE: + # Decode header to get payload length + _, _, length = BinaryProtocol.decode_header( + self._buffer[:HEADER_SIZE] + ) + total_size = HEADER_SIZE + length if len(self._buffer) >= total_size: - msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size] + msg_data = self._buffer[:total_size] self._buffer = self._buffer[total_size:] - self.messages.append(DirtyProtocol.decode(msg_data)) + # decode_message returns (msg_type_str, request_id, payload_dict) + msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data) + # Reconstruct the dict format for backwards compatibility + result = {"type": msg_type_str, "id": request_id} + result.update(payload_dict) + self.messages.append(result) else: break @@ -63,7 +71,7 @@ class MockStreamReader: def __init__(self, messages): self._data = b'' for msg in messages: - self._data += DirtyProtocol.encode(msg) + self._data += BinaryProtocol._encode_from_dict(msg) self._pos = 0 async def readexactly(self, n): @@ -107,9 +115,9 @@ class TestArbiterStreamingForwarding: client_writer = MockStreamWriter() # Mock worker connection that returns chunks - chunk1 = make_chunk_message("req-123", "Hello") - chunk2 = make_chunk_message("req-123", " World") - end = make_end_message("req-123") + chunk1 = make_chunk_message(123, "Hello") + chunk2 = make_chunk_message(123, " World") + end = make_end_message(123) mock_reader = MockStreamReader([chunk1, chunk2, end]) @@ -118,7 +126,7 @@ class TestArbiterStreamingForwarding: arbiter._get_worker_connection = mock_get_connection - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await arbiter._execute_on_worker(1234, request, client_writer) # Should have forwarded all messages @@ -135,7 +143,7 @@ class TestArbiterStreamingForwarding: arbiter = create_arbiter() client_writer = MockStreamWriter() - response = make_response("req-123", {"result": 42}) + response = make_response(123, {"result": 42}) mock_reader = MockStreamReader([response]) async def mock_get_connection(pid): @@ -143,7 +151,7 @@ class TestArbiterStreamingForwarding: arbiter._get_worker_connection = mock_get_connection - request = make_request("req-123", "test:App", "compute") + request = make_request(123, "test:App", "compute") await arbiter._execute_on_worker(1234, request, client_writer) assert len(client_writer.messages) == 1 @@ -156,8 +164,8 @@ class TestArbiterStreamingForwarding: arbiter = create_arbiter() client_writer = MockStreamWriter() - chunk = make_chunk_message("req-123", "First") - error = make_error_response("req-123", DirtyError("Something broke")) + chunk = make_chunk_message(123, "First") + error = make_error_response(123, DirtyError("Something broke")) mock_reader = MockStreamReader([chunk, error]) @@ -166,7 +174,7 @@ class TestArbiterStreamingForwarding: arbiter._get_worker_connection = mock_get_connection - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await arbiter._execute_on_worker(1234, request, client_writer) assert len(client_writer.messages) == 2 @@ -190,7 +198,7 @@ class TestArbiterStreamingForwarding: arbiter._get_worker_connection = mock_get_connection - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await arbiter._execute_on_worker(1234, request, client_writer) assert len(client_writer.messages) == 1 @@ -208,7 +216,7 @@ class TestArbiterRouteRequestStreaming: arbiter.workers = {} # No workers client_writer = MockStreamWriter() - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await arbiter.route_request(request, client_writer) assert len(client_writer.messages) == 1 @@ -222,13 +230,13 @@ class TestArbiterRouteRequestStreaming: # Mock _execute_on_worker to complete immediately async def mock_execute(pid, request, client_writer): - response = make_response("req-123", "result") + response = make_response(123, "result") await DirtyProtocol.write_message_async(client_writer, response) arbiter._execute_on_worker = mock_execute client_writer = MockStreamWriter() - request = make_request("req-123", "test:App", "compute") + request = make_request(123, "test:App", "compute") # Worker queue should be created assert 1234 not in arbiter.worker_queues @@ -255,8 +263,8 @@ class TestArbiterStreamingManyChunks: # Generate 50 chunks + end messages = [] for i in range(50): - messages.append(make_chunk_message("req-123", f"chunk-{i}")) - messages.append(make_end_message("req-123")) + messages.append(make_chunk_message(123, f"chunk-{i}")) + messages.append(make_end_message(123)) mock_reader = MockStreamReader(messages) @@ -265,7 +273,7 @@ class TestArbiterStreamingManyChunks: arbiter._get_worker_connection = mock_get_connection - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await arbiter._execute_on_worker(1234, request, client_writer) assert len(client_writer.messages) == 51 @@ -283,7 +291,7 @@ class TestArbiterBackwardCompatibility: arbiter = create_arbiter() client_writer = MockStreamWriter() - response = make_response("req-123", [1, 2, 3, 4, 5]) + response = make_response(123, [1, 2, 3, 4, 5]) mock_reader = MockStreamReader([response]) async def mock_get_connection(pid): @@ -291,7 +299,7 @@ class TestArbiterBackwardCompatibility: arbiter._get_worker_connection = mock_get_connection - request = make_request("req-123", "test:App", "get_list") + request = make_request(123, "test:App", "get_list") await arbiter._execute_on_worker(1234, request, client_writer) assert len(client_writer.messages) == 1 @@ -304,7 +312,7 @@ class TestArbiterBackwardCompatibility: arbiter = create_arbiter() client_writer = MockStreamWriter() - error = make_error_response("req-123", DirtyError("Something failed")) + error = make_error_response(123, DirtyError("Something failed")) mock_reader = MockStreamReader([error]) async def mock_get_connection(pid): @@ -312,7 +320,7 @@ class TestArbiterBackwardCompatibility: arbiter._get_worker_connection = mock_get_connection - request = make_request("req-123", "test:App", "fail") + request = make_request(123, "test:App", "fail") await arbiter._execute_on_worker(1234, request, client_writer) assert len(client_writer.messages) == 1 diff --git a/tests/dirty/test_client_streaming.py b/tests/dirty/test_client_streaming.py index 7bc13525..eca76e98 100644 --- a/tests/dirty/test_client_streaming.py +++ b/tests/dirty/test_client_streaming.py @@ -11,10 +11,12 @@ from unittest import mock from gunicorn.dirty.protocol import ( DirtyProtocol, + BinaryProtocol, make_chunk_message, make_end_message, make_response, make_error_response, + HEADER_SIZE, ) from gunicorn.dirty.client import DirtyClient, DirtyStreamIterator from gunicorn.dirty.errors import DirtyError, DirtyConnectionError @@ -26,7 +28,7 @@ class MockSocket: def __init__(self, messages): self._data = b'' for msg in messages: - self._data += DirtyProtocol.encode(msg) + self._data += BinaryProtocol._encode_from_dict(msg) self._pos = 0 self._sent = [] self.closed = False @@ -69,10 +71,10 @@ class TestDirtyStreamIterator: def test_stream_iterator_yields_chunks(self): """Test that stream iterator yields chunks correctly.""" messages = [ - make_chunk_message("req-123", "Hello"), - make_chunk_message("req-123", " "), - make_chunk_message("req-123", "World"), - make_end_message("req-123"), + make_chunk_message(123, "Hello"), + make_chunk_message(123, " "), + make_chunk_message(123, "World"), + make_end_message(123), ] client = create_client_with_mock_socket(messages) @@ -83,9 +85,9 @@ class TestDirtyStreamIterator: def test_stream_iterator_yields_complex_chunks(self): """Test that stream iterator yields complex data types.""" messages = [ - make_chunk_message("req-123", {"token": "Hello", "score": 0.9}), - make_chunk_message("req-123", {"token": "World", "score": 0.8}), - make_end_message("req-123"), + make_chunk_message(123, {"token": "Hello", "score": 0.9}), + make_chunk_message(123, {"token": "World", "score": 0.8}), + make_end_message(123), ] client = create_client_with_mock_socket(messages) @@ -98,8 +100,8 @@ class TestDirtyStreamIterator: def test_stream_iterator_handles_error(self): """Test that stream iterator raises on error message.""" messages = [ - make_chunk_message("req-123", "First"), - make_error_response("req-123", DirtyError("Something broke")), + make_chunk_message(123, "First"), + make_error_response(123, DirtyError("Something broke")), ] client = create_client_with_mock_socket(messages) @@ -116,7 +118,7 @@ class TestDirtyStreamIterator: def test_stream_iterator_empty_stream(self): """Test that empty stream (just end) works.""" - messages = [make_end_message("req-123")] + messages = [make_end_message(123)] client = create_client_with_mock_socket(messages) chunks = list(client.stream("test:App", "generate")) @@ -125,8 +127,8 @@ class TestDirtyStreamIterator: def test_stream_iterator_stops_after_exhausted(self): """Test that iterator stays exhausted after StopIteration.""" messages = [ - make_chunk_message("req-123", "Only"), - make_end_message("req-123"), + make_chunk_message(123, "Only"), + make_end_message(123), ] client = create_client_with_mock_socket(messages) @@ -147,10 +149,10 @@ class TestDirtyStreamIterator: def test_stream_iterator_with_for_loop(self): """Test stream iterator works in for loop.""" messages = [ - make_chunk_message("req-123", "a"), - make_chunk_message("req-123", "b"), - make_chunk_message("req-123", "c"), - make_end_message("req-123"), + make_chunk_message(123, "a"), + make_chunk_message(123, "b"), + make_chunk_message(123, "c"), + make_end_message(123), ] client = create_client_with_mock_socket(messages) @@ -163,8 +165,8 @@ class TestDirtyStreamIterator: def test_stream_sends_request_on_first_iteration(self): """Test that request is sent on first next() call.""" messages = [ - make_chunk_message("req-123", "data"), - make_end_message("req-123"), + make_chunk_message(123, "data"), + make_end_message(123), ] client = create_client_with_mock_socket(messages) @@ -179,18 +181,15 @@ class TestDirtyStreamIterator: # Decode sent request sent_data = client._sock._sent[0] - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - sent_data[:DirtyProtocol.HEADER_SIZE] - )[0] - request = DirtyProtocol.decode( - sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length] + _, _, length = BinaryProtocol.decode_header(sent_data[:HEADER_SIZE]) + msg_type_str, request_id, payload = BinaryProtocol.decode_message( + sent_data[:HEADER_SIZE + length] ) - assert request["type"] == "request" - assert request["app_path"] == "test:App" - assert request["action"] == "generate" - assert request["args"] == ["prompt_arg"] + assert msg_type_str == "request" + assert payload["app_path"] == "test:App" + assert payload["action"] == "generate" + assert payload["args"] == ["prompt_arg"] class TestDirtyStreamIteratorEdgeCases: @@ -200,8 +199,8 @@ class TestDirtyStreamIteratorEdgeCases: """Test streaming with many chunks.""" messages = [] for i in range(100): - messages.append(make_chunk_message("req-123", f"chunk-{i}")) - messages.append(make_end_message("req-123")) + messages.append(make_chunk_message(123, f"chunk-{i}")) + messages.append(make_end_message(123)) client = create_client_with_mock_socket(messages) @@ -214,8 +213,8 @@ class TestDirtyStreamIteratorEdgeCases: def test_stream_with_kwargs(self): """Test streaming with keyword arguments.""" messages = [ - make_chunk_message("req-123", "data"), - make_end_message("req-123"), + make_chunk_message(123, "data"), + make_end_message(123), ] client = create_client_with_mock_socket(messages) @@ -224,13 +223,10 @@ class TestDirtyStreamIteratorEdgeCases: # Check the sent request includes kwargs sent_data = client._sock._sent[0] - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - sent_data[:DirtyProtocol.HEADER_SIZE] - )[0] - request = DirtyProtocol.decode( - sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length] + _, _, length = BinaryProtocol.decode_header(sent_data[:HEADER_SIZE]) + msg_type_str, request_id, payload = BinaryProtocol.decode_message( + sent_data[:HEADER_SIZE + length] ) - assert request["args"] == ["arg1"] - assert request["kwargs"] == {"key": "value"} + assert payload["args"] == ["arg1"] + assert payload["kwargs"] == {"key": "value"} diff --git a/tests/dirty/test_client_streaming_async.py b/tests/dirty/test_client_streaming_async.py index 651c73d1..b38eff6c 100644 --- a/tests/dirty/test_client_streaming_async.py +++ b/tests/dirty/test_client_streaming_async.py @@ -10,9 +10,11 @@ import pytest from gunicorn.dirty.protocol import ( DirtyProtocol, + BinaryProtocol, make_chunk_message, make_end_message, make_error_response, + HEADER_SIZE, ) from gunicorn.dirty.client import DirtyClient, DirtyAsyncStreamIterator from gunicorn.dirty.errors import DirtyError, DirtyTimeoutError @@ -24,7 +26,7 @@ class MockAsyncReader: def __init__(self, messages): self._data = b'' for msg in messages: - self._data += DirtyProtocol.encode(msg) + self._data += BinaryProtocol._encode_from_dict(msg) self._pos = 0 async def readexactly(self, n): @@ -76,10 +78,10 @@ class TestDirtyAsyncStreamIterator: async def test_async_stream_yields_chunks(self): """Test that async stream iterator yields chunks correctly.""" messages = [ - make_chunk_message("req-123", "Hello"), - make_chunk_message("req-123", " "), - make_chunk_message("req-123", "World"), - make_end_message("req-123"), + make_chunk_message(123, "Hello"), + make_chunk_message(123, " "), + make_chunk_message(123, "World"), + make_end_message(123), ] client = create_async_client_with_mocks(messages) @@ -93,9 +95,9 @@ class TestDirtyAsyncStreamIterator: async def test_async_stream_yields_complex_chunks(self): """Test that async stream iterator yields complex data types.""" messages = [ - make_chunk_message("req-123", {"token": "Hello", "score": 0.9}), - make_chunk_message("req-123", {"token": "World", "score": 0.8}), - make_end_message("req-123"), + make_chunk_message(123, {"token": "Hello", "score": 0.9}), + make_chunk_message(123, {"token": "World", "score": 0.8}), + make_end_message(123), ] client = create_async_client_with_mocks(messages) @@ -111,8 +113,8 @@ class TestDirtyAsyncStreamIterator: async def test_async_stream_handles_error(self): """Test that async stream iterator raises on error message.""" messages = [ - make_chunk_message("req-123", "First"), - make_error_response("req-123", DirtyError("Something broke")), + make_chunk_message(123, "First"), + make_error_response(123, DirtyError("Something broke")), ] client = create_async_client_with_mocks(messages) @@ -130,7 +132,7 @@ class TestDirtyAsyncStreamIterator: @pytest.mark.asyncio async def test_async_stream_empty_stream(self): """Test that empty stream (just end) works.""" - messages = [make_end_message("req-123")] + messages = [make_end_message(123)] client = create_async_client_with_mocks(messages) chunks = [] @@ -143,8 +145,8 @@ class TestDirtyAsyncStreamIterator: async def test_async_stream_stops_after_exhausted(self): """Test that async iterator stays exhausted after StopAsyncIteration.""" messages = [ - make_chunk_message("req-123", "Only"), - make_end_message("req-123"), + make_chunk_message(123, "Only"), + make_end_message(123), ] client = create_async_client_with_mocks(messages) @@ -166,8 +168,8 @@ class TestDirtyAsyncStreamIterator: async def test_async_stream_sends_request_on_first_iteration(self): """Test that request is sent on first async iteration.""" messages = [ - make_chunk_message("req-123", "data"), - make_end_message("req-123"), + make_chunk_message(123, "data"), + make_end_message(123), ] client = create_async_client_with_mocks(messages) @@ -182,18 +184,15 @@ class TestDirtyAsyncStreamIterator: # Decode sent request sent_data = client._writer._sent[0] - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - sent_data[:DirtyProtocol.HEADER_SIZE] - )[0] - request = DirtyProtocol.decode( - sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length] + _, _, length = BinaryProtocol.decode_header(sent_data[:HEADER_SIZE]) + msg_type_str, request_id, payload = BinaryProtocol.decode_message( + sent_data[:HEADER_SIZE + length] ) - assert request["type"] == "request" - assert request["app_path"] == "test:App" - assert request["action"] == "generate" - assert request["args"] == ["prompt_arg"] + assert msg_type_str == "request" + assert payload["app_path"] == "test:App" + assert payload["action"] == "generate" + assert payload["args"] == ["prompt_arg"] class TestDirtyAsyncStreamIteratorEdgeCases: @@ -204,8 +203,8 @@ class TestDirtyAsyncStreamIteratorEdgeCases: """Test async streaming with many chunks.""" messages = [] for i in range(100): - messages.append(make_chunk_message("req-123", f"chunk-{i}")) - messages.append(make_end_message("req-123")) + messages.append(make_chunk_message(123, f"chunk-{i}")) + messages.append(make_end_message(123)) client = create_async_client_with_mocks(messages) @@ -221,8 +220,8 @@ class TestDirtyAsyncStreamIteratorEdgeCases: async def test_async_stream_with_kwargs(self): """Test async streaming with keyword arguments.""" messages = [ - make_chunk_message("req-123", "data"), - make_end_message("req-123"), + make_chunk_message(123, "data"), + make_end_message(123), ] client = create_async_client_with_mocks(messages) @@ -233,16 +232,13 @@ class TestDirtyAsyncStreamIteratorEdgeCases: # Check the sent request includes kwargs sent_data = client._writer._sent[0] - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - sent_data[:DirtyProtocol.HEADER_SIZE] - )[0] - request = DirtyProtocol.decode( - sent_data[DirtyProtocol.HEADER_SIZE:DirtyProtocol.HEADER_SIZE + length] + _, _, length = BinaryProtocol.decode_header(sent_data[:HEADER_SIZE]) + msg_type_str, request_id, payload = BinaryProtocol.decode_message( + sent_data[:HEADER_SIZE + length] ) - assert request["args"] == ["arg1"] - assert request["kwargs"] == {"key": "value"} + assert payload["args"] == ["arg1"] + assert payload["kwargs"] == {"key": "value"} class TestDirtyAsyncStreamTimeout: diff --git a/tests/dirty/test_multi_app_routing.py b/tests/dirty/test_multi_app_routing.py index c113bab1..4e01b711 100644 --- a/tests/dirty/test_multi_app_routing.py +++ b/tests/dirty/test_multi_app_routing.py @@ -19,7 +19,12 @@ from concurrent.futures import ThreadPoolExecutor from gunicorn.config import Config from gunicorn.dirty.worker import DirtyWorker from gunicorn.dirty.arbiter import DirtyArbiter -from gunicorn.dirty.protocol import DirtyProtocol, make_request +from gunicorn.dirty.protocol import ( + DirtyProtocol, + BinaryProtocol, + make_request, + HEADER_SIZE, +) from gunicorn.dirty.errors import DirtyAppNotFoundError @@ -71,16 +76,22 @@ class MockStreamWriter: self._buffer += data async def drain(self): - while len(self._buffer) >= DirtyProtocol.HEADER_SIZE: - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - self._buffer[:DirtyProtocol.HEADER_SIZE] - )[0] - total_size = DirtyProtocol.HEADER_SIZE + length + # Decode the buffer to extract messages using binary protocol + while len(self._buffer) >= HEADER_SIZE: + # Decode header to get payload length + _, _, length = BinaryProtocol.decode_header( + self._buffer[:HEADER_SIZE] + ) + total_size = HEADER_SIZE + length if len(self._buffer) >= total_size: - msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size] + msg_data = self._buffer[:total_size] self._buffer = self._buffer[total_size:] - self.messages.append(DirtyProtocol.decode(msg_data)) + # decode_message returns (msg_type_str, request_id, payload_dict) + msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data) + # Reconstruct the dict format for backwards compatibility + result = {"type": msg_type_str, "id": request_id} + result.update(payload_dict) + self.messages.append(result) else: break diff --git a/tests/dirty/test_streaming_integration.py b/tests/dirty/test_streaming_integration.py index 06b9645f..b23fee38 100644 --- a/tests/dirty/test_streaming_integration.py +++ b/tests/dirty/test_streaming_integration.py @@ -18,11 +18,13 @@ from unittest import mock from gunicorn.config import Config from gunicorn.dirty.protocol import ( DirtyProtocol, + BinaryProtocol, make_request, make_chunk_message, make_end_message, make_response, make_error_response, + HEADER_SIZE, ) from gunicorn.dirty.worker import DirtyWorker from gunicorn.dirty.arbiter import DirtyArbiter @@ -67,16 +69,22 @@ class MockStreamWriter: self._buffer += data async def drain(self): - while len(self._buffer) >= DirtyProtocol.HEADER_SIZE: - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - self._buffer[:DirtyProtocol.HEADER_SIZE] - )[0] - total_size = DirtyProtocol.HEADER_SIZE + length + # Decode the buffer to extract messages using binary protocol + while len(self._buffer) >= HEADER_SIZE: + # Decode header to get payload length + _, _, length = BinaryProtocol.decode_header( + self._buffer[:HEADER_SIZE] + ) + total_size = HEADER_SIZE + length if len(self._buffer) >= total_size: - msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size] + msg_data = self._buffer[:total_size] self._buffer = self._buffer[total_size:] - self.messages.append(DirtyProtocol.decode(msg_data)) + # decode_message returns (msg_type_str, request_id, payload_dict) + msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data) + # Reconstruct the dict format for backwards compatibility + result = {"type": msg_type_str, "id": request_id} + result.update(payload_dict) + self.messages.append(result) else: break @@ -96,7 +104,7 @@ class MockStreamReader: def __init__(self, messages): self._data = b'' for msg in messages: - self._data += DirtyProtocol.encode(msg) + self._data += BinaryProtocol._encode_from_dict(msg) self._pos = 0 async def readexactly(self, n): @@ -115,10 +123,10 @@ class TestStreamingEndToEnd: """Test complete flow: sync generator -> worker -> arbiter -> client.""" # Simulate what a worker would produce for a sync generator worker_messages = [ - make_chunk_message("req-123", "Hello"), - make_chunk_message("req-123", " "), - make_chunk_message("req-123", "World"), - make_end_message("req-123"), + make_chunk_message(123, "Hello"), + make_chunk_message(123, " "), + make_chunk_message(123, "World"), + make_end_message(123), ] # Create an arbiter with mocked worker connection @@ -141,7 +149,7 @@ class TestStreamingEndToEnd: client_writer = MockStreamWriter() # Execute request through arbiter - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await arbiter._execute_on_worker(1234, request, client_writer) # Verify all messages were forwarded @@ -158,10 +166,10 @@ class TestStreamingEndToEnd: async def test_async_generator_end_to_end(self): """Test complete flow: async generator -> worker -> arbiter -> client.""" worker_messages = [ - make_chunk_message("req-456", "Async"), - make_chunk_message("req-456", " "), - make_chunk_message("req-456", "Stream"), - make_end_message("req-456"), + make_chunk_message(456, "Async"), + make_chunk_message(456, " "), + make_chunk_message(456, "Stream"), + make_end_message(456), ] cfg = Config() @@ -180,7 +188,7 @@ class TestStreamingEndToEnd: client_writer = MockStreamWriter() - request = make_request("req-456", "test:App", "async_generate") + request = make_request(456, "test:App", "async_generate") await arbiter._execute_on_worker(1234, request, client_writer) assert len(client_writer.messages) == 4 @@ -197,9 +205,9 @@ class TestStreamingErrorHandling: async def test_error_mid_stream(self): """Test that errors during streaming are properly forwarded.""" worker_messages = [ - make_chunk_message("req-789", "First"), - make_chunk_message("req-789", "Second"), - make_error_response("req-789", DirtyError("Stream failed")), + make_chunk_message(789, "First"), + make_chunk_message(789, "Second"), + make_error_response(789, DirtyError("Stream failed")), ] cfg = Config() @@ -218,7 +226,7 @@ class TestStreamingErrorHandling: client_writer = MockStreamWriter() - request = make_request("req-789", "test:App", "generate_with_error") + request = make_request(789, "test:App", "generate_with_error") await arbiter._execute_on_worker(1234, request, client_writer) # Should have 2 chunks + 1 error @@ -335,7 +343,7 @@ class TestStreamingWorkerIntegration: return sync_gen() with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await worker.handle_request(request, writer) # Should have 3 chunks + 1 end @@ -377,7 +385,7 @@ class TestStreamingWorkerIntegration: return async_gen() with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-456", "test:App", "async_generate") + request = make_request(456, "test:App", "async_generate") await worker.handle_request(request, writer) # Should have 2 chunks + 1 end diff --git a/tests/dirty/test_worker_streaming.py b/tests/dirty/test_worker_streaming.py index bb674590..6efc471d 100644 --- a/tests/dirty/test_worker_streaming.py +++ b/tests/dirty/test_worker_streaming.py @@ -12,9 +12,11 @@ import pytest from gunicorn.dirty.protocol import ( DirtyProtocol, + BinaryProtocol, make_request, make_chunk_message, make_end_message, + HEADER_SIZE, ) from gunicorn.dirty.worker import DirtyWorker @@ -30,17 +32,22 @@ class FakeStreamWriter: self._buffer += data async def drain(self): - # Decode the buffer to extract messages - while len(self._buffer) >= DirtyProtocol.HEADER_SIZE: - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - self._buffer[:DirtyProtocol.HEADER_SIZE] - )[0] - total_size = DirtyProtocol.HEADER_SIZE + length + # Decode the buffer to extract messages using binary protocol + while len(self._buffer) >= HEADER_SIZE: + # Decode header to get payload length + _, _, length = BinaryProtocol.decode_header( + self._buffer[:HEADER_SIZE] + ) + total_size = HEADER_SIZE + length if len(self._buffer) >= total_size: - msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size] + msg_data = self._buffer[:total_size] self._buffer = self._buffer[total_size:] - self.messages.append(DirtyProtocol.decode(msg_data)) + # decode_message returns (msg_type_str, request_id, payload_dict) + msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data) + # Reconstruct the dict format for backwards compatibility + result = {"type": msg_type_str, "id": request_id} + result.update(payload_dict) + self.messages.append(result) else: break @@ -101,7 +108,7 @@ class TestWorkerSyncGeneratorStreaming: return generate_tokens() with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await worker.handle_request(request, writer) # Should have 3 chunks + 1 end message @@ -109,7 +116,7 @@ class TestWorkerSyncGeneratorStreaming: # Check chunk messages assert writer.messages[0]["type"] == "chunk" - assert writer.messages[0]["id"] == "req-123" + assert writer.messages[0]["id"] == 123 assert writer.messages[0]["data"] == "Hello" assert writer.messages[1]["type"] == "chunk" @@ -120,7 +127,7 @@ class TestWorkerSyncGeneratorStreaming: # Check end message assert writer.messages[3]["type"] == "end" - assert writer.messages[3]["id"] == "req-123" + assert writer.messages[3]["id"] == 123 @pytest.mark.asyncio async def test_sync_generator_error_mid_stream(self): @@ -136,7 +143,7 @@ class TestWorkerSyncGeneratorStreaming: return generate_with_error() with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await worker.handle_request(request, writer) # Should have 1 chunk + 1 error message @@ -167,7 +174,7 @@ class TestWorkerAsyncGeneratorStreaming: return async_generate_tokens() with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await worker.handle_request(request, writer) # Should have 3 chunks + 1 end message @@ -175,7 +182,7 @@ class TestWorkerAsyncGeneratorStreaming: # Check chunk messages assert writer.messages[0]["type"] == "chunk" - assert writer.messages[0]["id"] == "req-123" + assert writer.messages[0]["id"] == 123 assert writer.messages[0]["data"] == "Hello" assert writer.messages[1]["type"] == "chunk" @@ -186,7 +193,7 @@ class TestWorkerAsyncGeneratorStreaming: # Check end message assert writer.messages[3]["type"] == "end" - assert writer.messages[3]["id"] == "req-123" + assert writer.messages[3]["id"] == 123 @pytest.mark.asyncio async def test_async_generator_error_mid_stream(self): @@ -202,7 +209,7 @@ class TestWorkerAsyncGeneratorStreaming: return async_generate_with_error() with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await worker.handle_request(request, writer) # Should have 1 chunk + 1 error message @@ -228,13 +235,13 @@ class TestWorkerNonStreamingBackwardCompat: return args[0] + args[1] with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "compute", args=(2, 3)) + request = make_request(123, "test:App", "compute", args=(2, 3)) await worker.handle_request(request, writer) # Should have 1 response message assert len(writer.messages) == 1 assert writer.messages[0]["type"] == "response" - assert writer.messages[0]["id"] == "req-123" + assert writer.messages[0]["id"] == 123 assert writer.messages[0]["result"] == 5 @pytest.mark.asyncio @@ -247,7 +254,7 @@ class TestWorkerNonStreamingBackwardCompat: return [1, 2, 3, 4, 5] with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "get_list") + request = make_request(123, "test:App", "get_list") await worker.handle_request(request, writer) # Should have 1 response message (not 5 chunks) @@ -265,7 +272,7 @@ class TestWorkerNonStreamingBackwardCompat: raise RuntimeError("Failed!") with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "fail") + request = make_request(123, "test:App", "fail") await worker.handle_request(request, writer) # Should have 1 error message @@ -283,7 +290,7 @@ class TestWorkerNonStreamingBackwardCompat: return None with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "void") + request = make_request(123, "test:App", "void") await worker.handle_request(request, writer) # Should have 1 response message @@ -309,7 +316,7 @@ class TestWorkerStreamingComplexData: return generate_tokens() with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await worker.handle_request(request, writer) assert len(writer.messages) == 3 # 2 chunks + 1 end @@ -332,7 +339,7 @@ class TestWorkerStreamingComplexData: return empty_generate() with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await worker.handle_request(request, writer) # Should have just 1 end message @@ -353,7 +360,7 @@ class TestWorkerStreamingComplexData: return generate_many() with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await worker.handle_request(request, writer) # Should have 100 chunks + 1 end message @@ -390,7 +397,7 @@ class TestWorkerStreamingHeartbeat: return generate_tokens() with mock.patch.object(worker, 'execute', side_effect=mock_execute): - request = make_request("req-123", "test:App", "generate") + request = make_request(123, "test:App", "generate") await worker.handle_request(request, writer) # Should have been notified at least once per chunk + initial @@ -407,7 +414,7 @@ class TestWorkerMessageTypeValidation: writer = FakeStreamWriter() # Send a message with unknown type - message = {"type": "unknown", "id": "req-123"} + message = {"type": "unknown", "id": 123} await worker.handle_request(message, writer) assert len(writer.messages) == 1 diff --git a/tests/docker/http2/certs/server.crt b/tests/docker/http2/certs/server.crt new file mode 100644 index 00000000..b4056d76 --- /dev/null +++ b/tests/docker/http2/certs/server.crt @@ -0,0 +1,21 @@ +-----BEGIN CERTIFICATE----- +MIIDfDCCAmSgAwIBAgIUDxTarKRHe0FIyczGmoYwm377ZpcwDQYJKoZIhvcNAQEL +BQAwOTESMBAGA1UEAwwJbG9jYWxob3N0MRYwFAYDVQQKDA1HdW5pY29ybiBUZXN0 +MQswCQYDVQQGEwJVUzAeFw0yNjAyMDUxMTE1MjJaFw0yNjAyMDYxMTE1MjJaMDkx +EjAQBgNVBAMMCWxvY2FsaG9zdDEWMBQGA1UECgwNR3VuaWNvcm4gVGVzdDELMAkG +A1UEBhMCVVMwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCRQTHakkqY +6l6dMqfs4oiY98+rjvZubwjp0PH7UBuxXCi/4Ao78o0JhKcs+jgAGAXyb8eRjEKt +z4rPoHZYE91D/eD0lWAz9r/LRoutDJd9IO0rfDtHlYXamciuxJJ8cckOrnuTXLtq +AWqjKR3U9RIDD3eumCKG4l7Py0L67zTomwMRPfeIdlWBfxGjWMqOewdTc/O/cuK2 +HL5JP2ixy+iTufs0jhljI9cbu49J606f+TQH9eXRTD716q+KsHPJX1X5dVd7V7Lr +FIp7wSUFdbiy56JfmrGmfJbZgFH67P0ZyiTpQBaVHt1YYRIcOUJZqM+0MAtrsySC +TNA/LsI8tsybAgMBAAGjfDB6MB0GA1UdDgQWBBRK2VkAeM0hL4j/45ckkKbGrb/Q +FjAfBgNVHSMEGDAWgBRK2VkAeM0hL4j/45ckkKbGrb/QFjAPBgNVHRMBAf8EBTAD +AQH/MCcGA1UdEQQgMB6CCWxvY2FsaG9zdIILZ3VuaWNvcm4taDKHBH8AAAEwDQYJ +KoZIhvcNAQELBQADggEBAAXwuw0KTQUC4UEFudQ1rceK6By9WCSJND7xJi+UQ50G +Zrp5tJ2YB4ZWY+APadfuJo+zUxYVZ3jhs0mxgVeiGdDW6yZdHkeX8MlXBTLHR+/a +A7DXn6wCw9NDeDtcY/bKg5iamvoGGTL6szPrqeuZPz4UdbsFlr0MdcjgSNOqnkjr +YS4ukgZ71aWSjfraRRPjFMzkfnQ1xm96A1ngMH4DvU/t62D7r8+SvxQ8M6ERL84Z +FBu4bTXDdYIjJ24ojmDDO2irTVW1FMGXQTPzMaTEbE1rvBYeEYhf10KiMynK9xfO +5j8LWmCkgek0CqBrf3zbDEwu8QxcaxITAIUkSXLOZbo= +-----END CERTIFICATE----- diff --git a/tests/docker/http2/certs/server.key b/tests/docker/http2/certs/server.key new file mode 100644 index 00000000..3d472c7c --- /dev/null +++ b/tests/docker/http2/certs/server.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQCRQTHakkqY6l6d +Mqfs4oiY98+rjvZubwjp0PH7UBuxXCi/4Ao78o0JhKcs+jgAGAXyb8eRjEKtz4rP +oHZYE91D/eD0lWAz9r/LRoutDJd9IO0rfDtHlYXamciuxJJ8cckOrnuTXLtqAWqj +KR3U9RIDD3eumCKG4l7Py0L67zTomwMRPfeIdlWBfxGjWMqOewdTc/O/cuK2HL5J +P2ixy+iTufs0jhljI9cbu49J606f+TQH9eXRTD716q+KsHPJX1X5dVd7V7LrFIp7 +wSUFdbiy56JfmrGmfJbZgFH67P0ZyiTpQBaVHt1YYRIcOUJZqM+0MAtrsySCTNA/ +LsI8tsybAgMBAAECggEANBhGOYZLI9G2sjlXOaG7bOU/wV9KKaw/7Z/HEaOW8wLD +CKHg+cQRai79yCdLi1kSVPNbB2vfBDhRqAp8NzWUn0x/8ChcsvZVriF0edFwyWtU +NErfddp+Absy2t9cTC6A9feFEYJqIug0JyVZciWc2qUi/ubIR0kLyQm00YuWFa/s +GJou8Nhg70rqW+3FB1H8kAEXqob+PFW4xbTwexw1+MbHxN7UKLTzS8uzYGLo2UpB +7bksumyD0o+lZtlx9HZ6CwrB6IPjgJ0HyaD8SrOc7/ozd7rR2LmvMmBCV1uC5VSO +jhr0PScLoNv60fjkVOiF9uqaPY2kNKymsOzpZ7/mwQKBgQDMcz+ve8WGGbE+bbM7 +2uinQ5smm8rWPnfbHJIHQUetrEQKljRovybmjiiXN08uxlX6VA/Vnp4fmL5fzsTD +xTeiCVPsR1huXIfMLGJ6crUgvlbiaB8XsxtVNBpfEEtBe27qjSIj3xtmwqM6+LD1 +FKLsYzgotHUH9JwyLA1RMKPBwQKBgQC14QWtI5YtZcTX46BqxlZ07iAAuy19Jywn +UtgmTawkJuEcseewIjxtJkMz+aSy7V3PsLII8tY48oSjAVx84w50zLJ2OlJnFT1S +zEmIOu9YDcGLZkYXJ2AwndRAIXpJVHwtFM9eDSMh+wVPBFeboYP1dO/VxmN6QV0W +GqDaQfItWwKBgEb31mp2n0j+UB0ofSfQxCOTfx62w4D87CPd1f64tUXe3zuBii21 +9K3hOMvMwiqtZBjh5yEyzxaOsb6WCo0eP0J61GvXFCYy7lx8J67zdFYqXAR5OhnC +7UD1NhY7lLPlQcofNXOYNW3FMF3/B4X7JNbDVjIi+eDKExIDYpgFN0LBAoGADGCf +7kR5t+UxHDAVfq64u4RpESOr2NSNoK92nkSy7lLnBvjkd4wc6KCt+h+HIdYdiEDS +HOHJyl5WwHEbRjR9i11S19DoQrOjVLsqVecM2sU04rO3GWRIm4ZiJ2sf01W4jajY +4+Go/msC1XnKLIE1ZcLrf3Tc2DkSiKqPP8s1G/kCgYA8sCPAXedwhULhOBM45x4J +vkwT1Icm5RHOwOr8t34IFozTLokba6pjhYua3nE+V3FglRct7NpX+Op4gUgHa80g +5zoHboq5/pTUTclx41jndC1YGa3NLvthDWTWmyo/Qj7F/R7jGJf8E3KUDe0tFoSp +JlfEuUHtKpFJReBnmWTFiQ== +-----END PRIVATE KEY----- diff --git a/tests/test_dirty_integration.py b/tests/test_dirty_integration.py index f24c6894..a841cf2c 100644 --- a/tests/test_dirty_integration.py +++ b/tests/test_dirty_integration.py @@ -11,7 +11,7 @@ import pytest from gunicorn.arbiter import Arbiter from gunicorn.config import Config from gunicorn.app.base import BaseApplication -from gunicorn.dirty.protocol import DirtyProtocol +from gunicorn.dirty.protocol import DirtyProtocol, BinaryProtocol, HEADER_SIZE class MockStreamWriter: @@ -26,16 +26,22 @@ class MockStreamWriter: self._buffer += data async def drain(self): - while len(self._buffer) >= DirtyProtocol.HEADER_SIZE: - length = struct.unpack( - DirtyProtocol.HEADER_FORMAT, - self._buffer[:DirtyProtocol.HEADER_SIZE] - )[0] - total_size = DirtyProtocol.HEADER_SIZE + length + # Decode the buffer to extract messages using binary protocol + while len(self._buffer) >= HEADER_SIZE: + # Decode header to get payload length + _, _, length = BinaryProtocol.decode_header( + self._buffer[:HEADER_SIZE] + ) + total_size = HEADER_SIZE + length if len(self._buffer) >= total_size: - msg_data = self._buffer[DirtyProtocol.HEADER_SIZE:total_size] + msg_data = self._buffer[:total_size] self._buffer = self._buffer[total_size:] - self.messages.append(DirtyProtocol.decode(msg_data)) + # decode_message returns (msg_type_str, request_id, payload_dict) + msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data) + # Reconstruct the dict format for backwards compatibility + result = {"type": msg_type_str, "id": request_id} + result.update(payload_dict) + self.messages.append(result) else: break From 00da70292fe192fc570e67dcc751cf19a5d0d5fe Mon Sep 17 00:00:00 2001 From: Benoit Chesneau Date: Wed, 11 Feb 2026 23:15:14 +0100 Subject: [PATCH 06/11] docs(dirty): update examples and docs for binary protocol - Update test_protocol.py example to use binary protocol - Add test_binary_data_handling example showing raw bytes transfer - Update dirty.md to document binary TLV protocol format - Replace JSON references with binary protocol - Add Binary Protocol section with header and TLV encoding details --- docs/content/dirty.md | 62 ++++++++++++++- examples/dirty_example/test_protocol.py | 100 +++++++++++++++++------- 2 files changed, 132 insertions(+), 30 deletions(-) diff --git a/docs/content/dirty.md b/docs/content/dirty.md index 8afede06..bb98f6ee 100644 --- a/docs/content/dirty.md +++ b/docs/content/dirty.md @@ -17,7 +17,7 @@ Dirty Arbiters solve this by providing: - **Separate worker pool** - Completely separate from HTTP workers, can be killed/restarted independently - **Stateful workers** - Loaded resources persist in dirty worker memory -- **Message-passing IPC** - Communication via Unix sockets with JSON serialization +- **Message-passing IPC** - Communication via Unix sockets with binary TLV protocol - **Explicit API** - Clear `execute()` calls (no hidden IPC) - **Asyncio-based** - Clean concurrent handling with streaming support @@ -476,6 +476,64 @@ Worker -> Arbiter -> Client: chunk (data: "World") Worker -> Arbiter -> Client: end ``` +## Binary Protocol + +The dirty worker IPC uses a binary protocol inspired by OpenBSD msgctl/msgsnd for efficient data transfer. This eliminates base64 encoding overhead for binary data like images, audio, or model weights. + +### Header Format (16 bytes) + +``` ++--------+--------+--------+--------+--------+--------+--------+--------+ +| Magic (2B) | Ver(1) | MType | Payload Length (4B) | ++--------+--------+--------+--------+--------+--------+--------+--------+ +| Request ID (8 bytes) | ++--------+--------+--------+--------+--------+--------+--------+--------+ +``` + +- **Magic**: `0x47 0x44` ("GD" for Gunicorn Dirty) +- **Version**: `0x01` +- **MType**: Message type (`0x01`=REQUEST, `0x02`=RESPONSE, `0x03`=ERROR, `0x04`=CHUNK, `0x05`=END) +- **Length**: Payload size (big-endian uint32, max 64MB) +- **Request ID**: uint64 identifier + +### TLV Payload Encoding + +Payloads use Type-Length-Value encoding: + +| Type | Code | Description | +|------|------|-------------| +| None | `0x00` | No value bytes | +| Bool | `0x01` | 1 byte (0x00/0x01) | +| Int64 | `0x05` | 8 bytes big-endian signed | +| Float64 | `0x06` | 8 bytes IEEE 754 | +| Bytes | `0x10` | 4-byte length + raw bytes | +| String | `0x11` | 4-byte length + UTF-8 | +| List | `0x20` | 4-byte count + elements | +| Dict | `0x21` | 4-byte count + key-value pairs | + +### Binary Data Benefits + +The binary protocol allows passing raw bytes directly without encoding: + +```python +# Image processing with binary data +def resize(self, image_data, width, height): + """Resize an image - image_data is raw bytes.""" + img = Image.open(io.BytesIO(image_data)) + resized = img.resize((width, height)) + buffer = io.BytesIO() + resized.save(buffer, format='PNG') + return buffer.getvalue() # Returns raw bytes + +# Called from HTTP worker +thumbnail = client.execute( + "myapp.images:ImageApp", + "thumbnail", + raw_image_bytes, # No base64 encoding needed + size=256 +) +``` + ### Error Handling in Streams Errors during streaming are delivered as error messages: @@ -768,7 +826,7 @@ except DirtyConnectionError: 2. **Set appropriate timeouts** based on your workload 3. **Handle errors gracefully** - dirty workers may restart 4. **Use meaningful action names** for easier debugging -5. **Keep responses JSON-serializable** - results are passed via IPC +5. **Keep responses serializable** - results are passed via binary IPC (supports bytes directly) ## Monitoring diff --git a/examples/dirty_example/test_protocol.py b/examples/dirty_example/test_protocol.py index 31c45b92..8a677d03 100644 --- a/examples/dirty_example/test_protocol.py +++ b/examples/dirty_example/test_protocol.py @@ -4,7 +4,10 @@ #!/usr/bin/env python """ -Test script to demonstrate the Dirty Protocol layer. +Test script to demonstrate the Dirty Binary Protocol layer. + +The binary protocol uses a 16-byte header + TLV-encoded payloads for efficient +binary data transfer without base64 encoding overhead. Run with: python examples/dirty_example/test_protocol.py @@ -18,10 +21,14 @@ import socket sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) from gunicorn.dirty.protocol import ( + BinaryProtocol, DirtyProtocol, make_request, make_response, make_error_response, + HEADER_SIZE, + MAGIC, + VERSION, ) from gunicorn.dirty.errors import DirtyError, DirtyTimeoutError @@ -29,13 +36,13 @@ from gunicorn.dirty.errors import DirtyError, DirtyTimeoutError def test_protocol_encode_decode(): """Test protocol encoding and decoding.""" print("=" * 60) - print("Testing Protocol Encode/Decode") + print("Testing Binary Protocol Encode/Decode") print("=" * 60) - # Test request + # Test request with integer ID (recommended for binary protocol) print("\n1. Creating a request message...") request = make_request( - request_id="req-001", + request_id=12345, # Integer IDs are efficient app_path="myapp.ml:MLApp", action="inference", args=("model1",), @@ -43,18 +50,53 @@ def test_protocol_encode_decode(): ) print(f" Request: {request}") - # Encode - print("\n2. Encoding message...") - encoded = DirtyProtocol.encode(request) + # Encode using binary protocol + print("\n2. Encoding message with binary protocol...") + encoded = BinaryProtocol._encode_from_dict(request) print(f" Encoded length: {len(encoded)} bytes") - print(f" Header (4 bytes): {encoded[:4].hex()}") + print(f" Header ({HEADER_SIZE} bytes): {encoded[:HEADER_SIZE].hex()}") + print(f" Magic: {MAGIC!r}") + print(f" Version: {VERSION}") - # Decode - print("\n3. Decoding payload...") - payload = encoded[DirtyProtocol.HEADER_SIZE:] - decoded = DirtyProtocol.decode(payload) - print(f" Decoded: {decoded}") - print(f" Match: {decoded == request}") + # Decode header + print("\n3. Decoding header...") + msg_type, request_id, payload_len = BinaryProtocol.decode_header(encoded[:HEADER_SIZE]) + print(f" Message type: {msg_type} (0x{msg_type:02x})") + print(f" Request ID: {request_id}") + print(f" Payload length: {payload_len} bytes") + + # Decode full message + print("\n4. Decoding full message...") + msg_type_str, req_id, payload = BinaryProtocol.decode_message(encoded) + print(f" Type: {msg_type_str}") + print(f" Request ID: {req_id}") + print(f" Payload: {payload}") + + +def test_binary_data_handling(): + """Test binary data handling - the main advantage of binary protocol.""" + print("\n" + "=" * 60) + print("Testing Binary Data Handling") + print("=" * 60) + + # Create binary data (e.g., image, audio, model weights) + binary_data = bytes(range(256)) # All byte values + print(f"\n1. Original binary data: {len(binary_data)} bytes") + print(f" First 16 bytes: {binary_data[:16].hex()}") + + # Create response with binary data (no base64 needed!) + print("\n2. Encoding binary data in response...") + response = make_response(67890, {"image_data": binary_data, "format": "raw"}) + encoded = BinaryProtocol._encode_from_dict(response) + print(f" Encoded total size: {len(encoded)} bytes") + + # Decode and verify + print("\n3. Decoding binary data...") + msg_type_str, req_id, payload = BinaryProtocol.decode_message(encoded) + recovered_data = payload["result"]["image_data"] + print(f" Recovered data size: {len(recovered_data)} bytes") + print(f" Data matches: {recovered_data == binary_data}") + print(f" First 16 bytes: {recovered_data[:16].hex()}") def test_protocol_response(): @@ -65,13 +107,13 @@ def test_protocol_response(): # Success response print("\n1. Creating success response...") - response = make_response("req-001", {"result": "Hello, World!", "confidence": 0.95}) + response = make_response(12345, {"result": "Hello, World!", "confidence": 0.95}) print(f" Response: {response}") # Error response print("\n2. Creating error response...") error = DirtyTimeoutError("Operation timed out", timeout=30) - error_response = make_error_response("req-001", error) + error_response = make_error_response(12345, error) print(f" Error response: {error_response}") @@ -88,7 +130,7 @@ def test_socket_communication(): # Send a request print("\n1. Sending request over socket...") request = make_request( - request_id="socket-req-001", + request_id=100001, app_path="test:App", action="compute", args=(1, 2, 3), @@ -101,19 +143,20 @@ def test_socket_communication(): print("\n2. Receiving request...") received = DirtyProtocol.read_message(server_sock) print(f" Received: {received}") - print(f" Match: {received == request}") + print(f" Request ID: {received['id']}") - # Send a response - print("\n3. Sending response...") - response = make_response("socket-req-001", {"sum": 6}) + # Send a response with binary data + print("\n3. Sending response with binary data...") + binary_result = b"\x00\x01\x02\x03\xff\xfe\xfd\xfc" + response = make_response(100001, {"data": binary_result, "sum": 6}) DirtyProtocol.write_message(server_sock, response) - print(f" Sent: {response}") + print(f" Sent binary data: {binary_result.hex()}") # Receive the response print("\n4. Receiving response...") received = DirtyProtocol.read_message(client_sock) - print(f" Received: {received}") - print(f" Match: {received == response}") + print(f" Received binary data: {received['result']['data'].hex()}") + print(f" Sum: {received['result']['sum']}") finally: server_sock.close() @@ -132,7 +175,7 @@ async def test_async_communication(): try: # Create message request = make_request( - request_id="async-req-001", + request_id=200001, app_path="async:App", action="process", args=("data",), @@ -141,7 +184,7 @@ async def test_async_communication(): # Write to pipe print("\n1. Writing async message...") - encoded = DirtyProtocol.encode(request) + encoded = BinaryProtocol._encode_from_dict(request) os.write(write_fd, encoded) os.close(write_fd) write_fd = None @@ -156,7 +199,7 @@ async def test_async_communication(): received = await DirtyProtocol.read_message_async(reader) print(f" Received: {received}") - print(f" Match: {received == request}") + print(f" Request ID: {received['id']}") finally: if write_fd is not None: @@ -193,10 +236,11 @@ def test_error_serialization(): if __name__ == "__main__": print("\n" + "#" * 60) - print("# Dirty Protocol Demonstration") + print("# Dirty Binary Protocol Demonstration") print("#" * 60) test_protocol_encode_decode() + test_binary_data_handling() test_protocol_response() test_socket_communication() asyncio.run(test_async_communication()) From c0cc8c0de056d3306e44dee1f2d53bbc054276af Mon Sep 17 00:00:00 2001 From: Benoit Chesneau Date: Wed, 11 Feb 2026 23:30:48 +0100 Subject: [PATCH 07/11] test(dirty): add Docker setup for dirty example integration tests - Add Dockerfile and docker-compose.yml for running examples in containers - Add test_integration.py for HTTP-level integration testing - Update test_worker_integration.py to use MockWriter for handle_request - Use integer request IDs for binary protocol compatibility - Add GUNICORN_BIND env var support in gunicorn_conf.py for Docker --- examples/dirty_example/Dockerfile | 25 ++++++ examples/dirty_example/docker-compose.yml | 54 +++++++++++++ examples/dirty_example/gunicorn_conf.py | 4 +- examples/dirty_example/test_integration.py | 81 +++++++++++++++++++ .../dirty_example/test_worker_integration.py | 56 ++++++++++--- 5 files changed, 210 insertions(+), 10 deletions(-) create mode 100644 examples/dirty_example/Dockerfile create mode 100644 examples/dirty_example/docker-compose.yml create mode 100644 examples/dirty_example/test_integration.py diff --git a/examples/dirty_example/Dockerfile b/examples/dirty_example/Dockerfile new file mode 100644 index 00000000..302578dc --- /dev/null +++ b/examples/dirty_example/Dockerfile @@ -0,0 +1,25 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +FROM python:3.12-slim + +WORKDIR /app + +# Copy gunicorn source +COPY . /app/gunicorn-src + +# Install gunicorn and dependencies +# setproctitle is needed for process title changes +RUN pip install --no-cache-dir /app/gunicorn-src setproctitle + +# Copy example files +COPY examples/dirty_example/ /app/examples/dirty_example/ + +WORKDIR /app + +# Expose the port +EXPOSE 8000 + +# Default command - run the example tests +CMD ["python", "-m", "pytest", "-v", "examples/dirty_example/"] diff --git a/examples/dirty_example/docker-compose.yml b/examples/dirty_example/docker-compose.yml new file mode 100644 index 00000000..c55669a3 --- /dev/null +++ b/examples/dirty_example/docker-compose.yml @@ -0,0 +1,54 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +services: + # Run the example tests (protocol, dirty app, worker integration) + tests: + build: + context: ../.. + dockerfile: examples/dirty_example/Dockerfile + command: > + bash -c " + echo '=== Running Protocol Tests ===' && + python examples/dirty_example/test_protocol.py && + echo '' && + echo '=== Running Dirty App Tests ===' && + python examples/dirty_example/test_dirty_app.py && + echo '' && + echo '=== Running Worker Integration Tests ===' && + python examples/dirty_example/test_worker_integration.py && + echo '' && + echo '=== All tests passed! ===' + " + + # Run the full gunicorn server with dirty workers + server: + build: + context: ../.. + dockerfile: examples/dirty_example/Dockerfile + ports: + - "8001:8000" + environment: + - GUNICORN_BIND=0.0.0.0:8000 + command: > + gunicorn examples.dirty_example.wsgi_app:app + -c examples/dirty_example/gunicorn_conf.py + healthcheck: + test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/')"] + interval: 5s + timeout: 5s + retries: 5 + start_period: 10s + + # Run integration test against the server + integration-test: + build: + context: ../.. + dockerfile: examples/dirty_example/Dockerfile + depends_on: + server: + condition: service_healthy + environment: + - TEST_BASE_URL=http://server:8000 + command: python examples/dirty_example/test_integration.py diff --git a/examples/dirty_example/gunicorn_conf.py b/examples/dirty_example/gunicorn_conf.py index a7877c3e..ba3d28e4 100644 --- a/examples/dirty_example/gunicorn_conf.py +++ b/examples/dirty_example/gunicorn_conf.py @@ -11,7 +11,9 @@ Run with: """ # Basic settings -bind = "127.0.0.1:8000" +# Use 0.0.0.0 for Docker, override with GUNICORN_BIND env var if needed +import os +bind = os.environ.get("GUNICORN_BIND", "127.0.0.1:8000") workers = 2 worker_class = "sync" timeout = 30 diff --git a/examples/dirty_example/test_integration.py b/examples/dirty_example/test_integration.py new file mode 100644 index 00000000..20008522 --- /dev/null +++ b/examples/dirty_example/test_integration.py @@ -0,0 +1,81 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +#!/usr/bin/env python +""" +Integration test for the dirty example server. + +This tests that the full gunicorn server with dirty workers responds +correctly to HTTP requests. + +Run with: + python examples/dirty_example/test_integration.py [base_url] + +Default base_url is http://localhost:8000 +""" + +import sys +import os +import json +import urllib.request +import urllib.error + + +def test_endpoint(base, path, expected_key=None): + """Test an endpoint and check for expected key in response.""" + url = base + path + print(f"Testing: {url}") + try: + with urllib.request.urlopen(url, timeout=10) as resp: + data = json.loads(resp.read()) + print(f" Response: {str(data)[:200]}") + if expected_key and expected_key not in data: + print(f" ERROR: Expected key '{expected_key}' not found!") + return False + return True + except urllib.error.HTTPError as e: + print(f" HTTP ERROR {e.code}: {e.reason}") + return False + except Exception as e: + print(f" ERROR: {e}") + return False + + +def main(): + # Get base URL from env or command line + base = os.environ.get("TEST_BASE_URL", "http://localhost:8000") + if len(sys.argv) > 1: + base = sys.argv[1] + + print(f"Testing dirty example server at: {base}") + print("=" * 60) + + # Define tests: (path, expected_key_in_response) + tests = [ + ("/", "endpoints"), + ("/models", "models"), + ("/load?name=test-model", "status"), + ("/inference?model=default&data=hello", "prediction"), + ("/fibonacci?n=20", "result"), + ("/prime?n=17", "is_prime"), + ("/stats", "ml_app"), + ("/unload?name=test-model", "status"), + ] + + failed = 0 + for path, key in tests: + if not test_endpoint(base, path, key): + failed += 1 + print() + + print("=" * 60) + if failed: + print(f"FAILED: {failed} tests failed") + sys.exit(1) + else: + print("SUCCESS: All integration tests passed!") + + +if __name__ == "__main__": + main() diff --git a/examples/dirty_example/test_worker_integration.py b/examples/dirty_example/test_worker_integration.py index 1711ee3d..acca9961 100644 --- a/examples/dirty_example/test_worker_integration.py +++ b/examples/dirty_example/test_worker_integration.py @@ -22,7 +22,7 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspa from gunicorn.config import Config from gunicorn.dirty.worker import DirtyWorker -from gunicorn.dirty.protocol import DirtyProtocol, make_request +from gunicorn.dirty.protocol import DirtyProtocol, BinaryProtocol, make_request, HEADER_SIZE class MockLog: @@ -35,6 +35,36 @@ class MockLog: def reopen_files(self): pass +class MockWriter: + """Mock StreamWriter that captures written responses.""" + + def __init__(self): + self.messages = [] + self._buffer = b"" + + def write(self, data): + self._buffer += data + + async def drain(self): + # Decode messages from buffer using binary protocol + while len(self._buffer) >= HEADER_SIZE: + _, _, length = BinaryProtocol.decode_header(self._buffer[:HEADER_SIZE]) + total_size = HEADER_SIZE + length + if len(self._buffer) >= total_size: + msg_data = self._buffer[:total_size] + self._buffer = self._buffer[total_size:] + msg_type_str, request_id, payload_dict = BinaryProtocol.decode_message(msg_data) + result = {"type": msg_type_str, "id": request_id} + result.update(payload_dict) + self.messages.append(result) + else: + break + + def get_last_response(self): + """Get the last response message.""" + return self.messages[-1] if self.messages else None + + async def test_worker_request_handling(): """Test that a worker can load apps and handle requests.""" print("=" * 60) @@ -75,52 +105,60 @@ async def test_worker_request_handling(): # Test handle_request with a proper request message print("\n3. Testing handle_request() - load_model...") request = make_request( - request_id="test-001", + request_id=1001, app_path="examples.dirty_example.dirty_app:MLApp", action="load_model", args=("gpt-4",), kwargs={} ) - response = await worker.handle_request(request) + writer = MockWriter() + await worker.handle_request(request, writer) + response = writer.get_last_response() print(f" Response type: {response['type']}") print(f" Result: {response.get('result', response.get('error'))}") # Test inference print("\n4. Testing handle_request() - inference...") request = make_request( - request_id="test-002", + request_id=1002, app_path="examples.dirty_example.dirty_app:MLApp", action="inference", args=("default", "Hello AI!"), kwargs={} ) - response = await worker.handle_request(request) + writer = MockWriter() + await worker.handle_request(request, writer) + response = writer.get_last_response() print(f" Response type: {response['type']}") print(f" Result: {response.get('result', response.get('error'))}") # Test error handling print("\n5. Testing error handling - unknown action...") request = make_request( - request_id="test-003", + request_id=1003, app_path="examples.dirty_example.dirty_app:MLApp", action="nonexistent_action", args=(), kwargs={} ) - response = await worker.handle_request(request) + writer = MockWriter() + await worker.handle_request(request, writer) + response = writer.get_last_response() print(f" Response type: {response['type']}") print(f" Error: {response.get('error', {}).get('message')}") # Test app not found print("\n6. Testing error handling - app not found...") request = make_request( - request_id="test-004", + request_id=1004, app_path="nonexistent:App", action="test", args=(), kwargs={} ) - response = await worker.handle_request(request) + writer = MockWriter() + await worker.handle_request(request, writer) + response = writer.get_last_response() print(f" Response type: {response['type']}") print(f" Error type: {response.get('error', {}).get('error_type')}") From 68ce658f5d651c62bf0b92da5bdb080ce63cd002 Mon Sep 17 00:00:00 2001 From: Benoit Chesneau Date: Wed, 11 Feb 2026 23:39:53 +0100 Subject: [PATCH 08/11] fix(dirty): convert dict int keys to strings in TLV encoder JSON serializes all dict keys as strings, so for compatibility the TLV encoder should do the same. This fixes an error when tasks return dicts with integer keys (e.g., aggregation results grouped by numeric ID). --- gunicorn/dirty/tlv.py | 6 ++---- tests/test_dirty_tlv.py | 13 +++++++------ 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/gunicorn/dirty/tlv.py b/gunicorn/dirty/tlv.py index 5682b0c6..e0b1e2f4 100644 --- a/gunicorn/dirty/tlv.py +++ b/gunicorn/dirty/tlv.py @@ -113,11 +113,9 @@ class TLVEncoder: ) parts = [bytes([TYPE_DICT]), struct.pack(">I", len(value))] for k, v in value.items(): - # Keys must be strings + # Convert keys to strings (like JSON) if not isinstance(k, str): - raise DirtyProtocolError( - f"Dict keys must be strings, got {type(k).__name__}" - ) + k = str(k) parts.append(TLVEncoder.encode(k)) parts.append(TLVEncoder.encode(v)) return b"".join(parts) diff --git a/tests/test_dirty_tlv.py b/tests/test_dirty_tlv.py index 87727203..c36b839a 100644 --- a/tests/test_dirty_tlv.py +++ b/tests/test_dirty_tlv.py @@ -307,12 +307,13 @@ class TestTLVEncoderDict: value, offset = TLVEncoder.decode(encoded, 0) assert value == data - def test_encode_dict_non_string_key(self): - """Test that non-string keys raise error.""" - data = {1: "value"} - with pytest.raises(DirtyProtocolError) as exc_info: - TLVEncoder.encode(data) - assert "keys must be strings" in str(exc_info.value).lower() + def test_encode_dict_non_string_key_converted(self): + """Test that non-string keys are converted to strings (like JSON).""" + data = {1: "value", 2: "other"} + encoded = TLVEncoder.encode(data) + decoded, _ = TLVEncoder.decode(encoded, 0) + # Keys should be converted to strings + assert decoded == {"1": "value", "2": "other"} class TestTLVEncoderComplexStructures: From 6d691b30e1f8218d4b2401f47d15cb13169a601a Mon Sep 17 00:00:00 2001 From: Benoit Chesneau Date: Wed, 11 Feb 2026 23:40:10 +0100 Subject: [PATCH 09/11] chore: use different ports in example docker-compose files Avoid port conflicts when running multiple examples: - dirty_example: 8001 - embedding_service: 8002 - celery_alternative: 8003 --- examples/celery_alternative/docker-compose.yml | 2 +- examples/embedding_service/docker-compose.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/celery_alternative/docker-compose.yml b/examples/celery_alternative/docker-compose.yml index ebf8d303..66ce282e 100644 --- a/examples/celery_alternative/docker-compose.yml +++ b/examples/celery_alternative/docker-compose.yml @@ -15,7 +15,7 @@ services: context: ../.. # Gunicorn repo root dockerfile: examples/celery_alternative/Dockerfile ports: - - "8000:8000" + - "8003:8000" environment: - GUNICORN_WORKERS=4 - DIRTY_WORKERS=9 diff --git a/examples/embedding_service/docker-compose.yml b/examples/embedding_service/docker-compose.yml index 6b956fdc..d3408cf9 100644 --- a/examples/embedding_service/docker-compose.yml +++ b/examples/embedding_service/docker-compose.yml @@ -4,7 +4,7 @@ services: context: ../.. dockerfile: examples/embedding_service/Dockerfile ports: - - "8000:8000" + - "8002:8000" healthcheck: test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://127.0.0.1:8000/health', timeout=5)"] interval: 10s From f4e219716f44c901ea5feb9dbc82220d2751a507 Mon Sep 17 00:00:00 2001 From: Benoit Chesneau Date: Wed, 11 Feb 2026 23:48:07 +0100 Subject: [PATCH 10/11] fix(dirty): disable pylint too-many-return-statements in TLV --- gunicorn/dirty/tlv.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gunicorn/dirty/tlv.py b/gunicorn/dirty/tlv.py index e0b1e2f4..ec18cd76 100644 --- a/gunicorn/dirty/tlv.py +++ b/gunicorn/dirty/tlv.py @@ -50,7 +50,7 @@ class TLVEncoder: """ @staticmethod - def encode(value) -> bytes: + def encode(value) -> bytes: # pylint: disable=too-many-return-statements """ Encode a Python value to TLV binary format. @@ -125,7 +125,7 @@ class TLVEncoder: ) @staticmethod - def decode(data: bytes, offset: int = 0) -> tuple: + def decode(data: bytes, offset: int = 0) -> tuple: # pylint: disable=too-many-return-statements """ Decode a TLV-encoded value from binary data. From 415aa77343eddc46475511ceb80898ac08bfd84f Mon Sep 17 00:00:00 2001 From: Benoit Chesneau Date: Wed, 11 Feb 2026 23:54:36 +0100 Subject: [PATCH 11/11] fix: revert embedding_service port to 8000 for CI --- examples/embedding_service/docker-compose.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/embedding_service/docker-compose.yml b/examples/embedding_service/docker-compose.yml index d3408cf9..6b956fdc 100644 --- a/examples/embedding_service/docker-compose.yml +++ b/examples/embedding_service/docker-compose.yml @@ -4,7 +4,7 @@ services: context: ../.. dockerfile: examples/embedding_service/Dockerfile ports: - - "8002:8000" + - "8000:8000" healthcheck: test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://127.0.0.1:8000/health', timeout=5)"] interval: 10s