From e05e40d19ba4e6e99bf92d7732b20880a15cd7fc Mon Sep 17 00:00:00 2001 From: Benoit Chesneau Date: Fri, 13 Feb 2026 02:25:37 +0100 Subject: [PATCH] feat(ctl): add message-based dirty worker management Replace signal-based dirty add/remove with protocol messages: - Add MSG_TYPE_MANAGE to dirty protocol for worker management - Add MANAGE_OP_ADD and MANAGE_OP_REMOVE operation codes - Add handle_manage_request() in DirtyArbiter - Update handlers to send messages instead of SIGTTIN/SIGTTOU signals New workers only load apps that haven't reached their worker limits. When all apps are at their limits, returns reason in response. Only increment num_workers when a worker is actually spawned. --- gunicorn/ctl/handlers.py | 90 ++++++++++++++++++++++++------- gunicorn/dirty/arbiter.py | 108 ++++++++++++++++++++++++++++++++++++- gunicorn/dirty/protocol.py | 56 +++++++++++++++++++ tests/ctl/test_handlers.py | 55 +++++++++++++++++++ 4 files changed, 288 insertions(+), 21 deletions(-) diff --git a/gunicorn/ctl/handlers.py b/gunicorn/ctl/handlers.py index 83480389..ffde7393 100644 --- a/gunicorn/ctl/handlers.py +++ b/gunicorn/ctl/handlers.py @@ -10,6 +10,7 @@ Provides handlers for all control commands with access to arbiter state. import os import signal +import socket import time @@ -309,6 +310,8 @@ class CommandHandlers: """ Spawn additional dirty workers. + Sends a MANAGE message to the dirty arbiter to spawn workers. + Args: count: Number of dirty workers to add (default 1) @@ -321,25 +324,15 @@ class CommandHandlers: "error": "Dirty arbiter not running", } - # Send TTIN signals to dirty arbiter count = max(1, int(count)) - try: - for _ in range(count): - os.kill(self.arbiter.dirty_arbiter_pid, signal.SIGTTIN) - return { - "success": True, - "added": count, - } - except OSError as e: - return { - "success": False, - "error": str(e), - } + return self._send_manage_message("add", count) def dirty_remove(self, count: int = 1) -> dict: """ Remove dirty workers. + Sends a MANAGE message to the dirty arbiter to remove workers. + Args: count: Number of dirty workers to remove (default 1) @@ -352,16 +345,73 @@ class CommandHandlers: "error": "Dirty arbiter not running", } - # Send TTOU signals to dirty arbiter count = max(1, int(count)) - try: - for _ in range(count): - os.kill(self.arbiter.dirty_arbiter_pid, signal.SIGTTOU) + return self._send_manage_message("remove", count) + + def _send_manage_message(self, operation: str, count: int) -> dict: + """ + Send a worker management message to the dirty arbiter. + + Args: + operation: "add" or "remove" + count: Number of workers to add/remove + + Returns: + Dictionary with result or error + """ + # Get socket path from arbiter object or environment + dirty_socket_path = None + if hasattr(self.arbiter, 'dirty_arbiter') and self.arbiter.dirty_arbiter: + dirty_socket_path = getattr( + self.arbiter.dirty_arbiter, 'socket_path', None + ) + if not dirty_socket_path: + dirty_socket_path = os.environ.get('GUNICORN_DIRTY_SOCKET') + if not dirty_socket_path: return { - "success": True, - "removed": count, + "success": False, + "error": "Cannot find dirty arbiter socket path", } - except OSError as e: + + try: + from gunicorn.dirty.protocol import ( + DirtyProtocol, MANAGE_OP_ADD, MANAGE_OP_REMOVE + ) + + op = MANAGE_OP_ADD if operation == "add" else MANAGE_OP_REMOVE + + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.settimeout(10.0) + sock.connect(dirty_socket_path) + + # Send manage request + request = { + "type": DirtyProtocol.MSG_TYPE_MANAGE, + "id": 1, + "op": op, + "count": count, + } + DirtyProtocol.write_message(sock, request) + + # Read response + response = DirtyProtocol.read_message(sock) + sock.close() + + if response.get("type") == DirtyProtocol.MSG_TYPE_RESPONSE: + return response.get("result", {"success": True}) + elif response.get("type") == DirtyProtocol.MSG_TYPE_ERROR: + error = response.get("error", {}) + return { + "success": False, + "error": error.get("message", str(error)), + } + else: + return { + "success": False, + "error": f"Unexpected response type: {response.get('type')}", + } + + except Exception as e: return { "success": False, "error": str(e), diff --git a/gunicorn/dirty/arbiter.py b/gunicorn/dirty/arbiter.py index 47f14dc5..fe7567ba 100644 --- a/gunicorn/dirty/arbiter.py +++ b/gunicorn/dirty/arbiter.py @@ -41,6 +41,8 @@ from .protocol import ( STASH_OP_DELETE_TABLE, STASH_OP_TABLES, STASH_OP_EXISTS, + MANAGE_OP_ADD, + MANAGE_OP_REMOVE, ) from .worker import DirtyWorker @@ -426,6 +428,9 @@ class DirtyArbiter: # Handle status queries elif msg_type == DirtyProtocol.MSG_TYPE_STATUS: await self.handle_status_request(message, writer) + # Handle worker management (add/remove workers) + elif msg_type == DirtyProtocol.MSG_TYPE_MANAGE: + await self.handle_manage_request(message, writer) else: # Route request to a dirty worker - pass writer for streaming await self.route_request(message, writer) @@ -690,6 +695,100 @@ class DirtyArbiter: response = make_response(request_id, result) await DirtyProtocol.write_message_async(client_writer, response) + async def handle_manage_request(self, message, client_writer): + """ + Handle a worker management request. + + Supports adding or removing dirty workers via protocol messages. + + Args: + message: Manage request message + client_writer: StreamWriter to send response to client + """ + request_id = message.get("id", "unknown") + op = message.get("op") + count = max(1, int(message.get("count", 1))) + + try: + if op == MANAGE_OP_ADD: + # Add workers - only loads apps that need more workers + spawned = 0 + for _ in range(count): + result = self.spawn_worker() + if result is not None: + self.num_workers += 1 + spawned += 1 + await asyncio.sleep(0.1) + + # Provide feedback about why no workers were spawned + if spawned == 0: + result = { + "success": True, + "operation": "add", + "requested": count, + "spawned": 0, + "reason": "All apps have reached their worker limits", + "total_workers": len(self.workers), + "target_workers": self.num_workers, + } + else: + result = { + "success": True, + "operation": "add", + "requested": count, + "spawned": spawned, + "total_workers": len(self.workers), + "target_workers": self.num_workers, + } + + elif op == MANAGE_OP_REMOVE: + # Remove workers (similar to TTOU signal but via message) + min_workers = self._get_minimum_workers() + removed = 0 + + for _ in range(count): + if self.num_workers <= min_workers: + break + if len(self.workers) <= 1: + break + + self.num_workers -= 1 + + # Kill oldest worker + oldest_pid = min(self.workers.keys(), + key=lambda p: self.workers[p].age) + self.kill_worker(oldest_pid, signal.SIGTERM) + removed += 1 + await asyncio.sleep(0.1) + + result = { + "success": True, + "operation": "remove", + "requested": count, + "removed": removed, + "total_workers": len(self.workers), + "target_workers": self.num_workers, + } + + else: + error = DirtyError(f"Unknown manage operation: {op}") + response = make_error_response(request_id, error) + await DirtyProtocol.write_message_async(client_writer, response) + return + + self.log.info("Worker management: %s %d workers (spawned/removed: %d)", + "add" if op == MANAGE_OP_ADD else "remove", + count, + result.get("spawned", result.get("removed", 0))) + + response = make_response(request_id, result) + await DirtyProtocol.write_message_async(client_writer, response) + + except Exception as e: + self.log.error("Manage operation error: %s", e) + response = make_error_response(request_id, DirtyError(str(e))) + await DirtyProtocol.write_message_async(client_writer, response) + async def handle_stash_request(self, message, client_writer): """ Handle a stash operation directly in the arbiter. @@ -830,13 +929,17 @@ class DirtyArbiter: self.kill_worker(oldest_pid, signal.SIGTERM) await asyncio.sleep(0.1) - def spawn_worker(self): + def spawn_worker(self, force_all_apps=False): """ Spawn a new dirty worker. Worker app assignment follows these priorities: 1. If there are pending respawns (from dead workers), use those apps 2. Otherwise, determine apps for a new worker based on allocation + 3. If force_all_apps=True, spawn with all apps regardless of limits + + Args: + force_all_apps: If True, spawn worker with all apps ignoring limits Returns: Worker PID in parent process, or None if no apps need workers @@ -844,6 +947,9 @@ class DirtyArbiter: # Priority 1: Respawn dead worker with same apps if self._pending_respawns: app_paths = self._pending_respawns.pop(0) + elif force_all_apps: + # Force spawn with all apps (used by TTIN signal) + app_paths = list(self.app_specs.keys()) else: # Priority 2: New worker for initial pool app_paths = self._get_apps_for_new_worker() diff --git a/gunicorn/dirty/protocol.py b/gunicorn/dirty/protocol.py index 8091b23f..3d216c54 100644 --- a/gunicorn/dirty/protocol.py +++ b/gunicorn/dirty/protocol.py @@ -44,6 +44,7 @@ MSG_TYPE_CHUNK = 0x04 MSG_TYPE_END = 0x05 MSG_TYPE_STASH = 0x10 # Stash operations (shared state between workers) MSG_TYPE_STATUS = 0x11 # Status query for arbiter/workers +MSG_TYPE_MANAGE = 0x12 # Worker management (add/remove workers) # Message type names (for backwards compatibility with old API) MSG_TYPE_REQUEST_STR = "request" @@ -53,6 +54,7 @@ MSG_TYPE_CHUNK_STR = "chunk" MSG_TYPE_END_STR = "end" MSG_TYPE_STASH_STR = "stash" MSG_TYPE_STATUS_STR = "status" +MSG_TYPE_MANAGE_STR = "manage" # Map int types to string names MSG_TYPE_TO_STR = { @@ -63,6 +65,7 @@ MSG_TYPE_TO_STR = { MSG_TYPE_END: MSG_TYPE_END_STR, MSG_TYPE_STASH: MSG_TYPE_STASH_STR, MSG_TYPE_STATUS: MSG_TYPE_STATUS_STR, + MSG_TYPE_MANAGE: MSG_TYPE_MANAGE_STR, } # Map string names to int types @@ -80,6 +83,10 @@ STASH_OP_DELETE_TABLE = 8 STASH_OP_TABLES = 9 STASH_OP_EXISTS = 10 +# Manage operation codes +MANAGE_OP_ADD = 1 # Add/spawn workers +MANAGE_OP_REMOVE = 2 # Remove/kill workers + # Header format: Magic (2) + Version (1) + Type (1) + Length (4) + RequestID (8) = 16 HEADER_FORMAT = ">2sBBIQ" HEADER_SIZE = struct.calcsize(HEADER_FORMAT) @@ -102,6 +109,7 @@ class BinaryProtocol: MSG_TYPE_END = MSG_TYPE_END_STR MSG_TYPE_STASH = MSG_TYPE_STASH_STR MSG_TYPE_STATUS = MSG_TYPE_STATUS_STR + MSG_TYPE_MANAGE = MSG_TYPE_MANAGE_STR @staticmethod def encode_header(msg_type: int, request_id: int, payload_length: int) -> bytes: @@ -292,6 +300,28 @@ class BinaryProtocol: header = BinaryProtocol.encode_header(MSG_TYPE_STATUS, request_id, 0) return header + @staticmethod + def encode_manage(request_id: int, op: int, count: int = 1) -> bytes: + """ + Encode a worker management message. + + Args: + request_id: Request identifier + op: Management operation (MANAGE_OP_ADD or MANAGE_OP_REMOVE) + count: Number of workers to add/remove + + Returns: + bytes: Complete message (header + payload) + """ + payload_dict = { + "op": op, + "count": count, + } + payload = TLVEncoder.encode(payload_dict) + header = BinaryProtocol.encode_header(MSG_TYPE_MANAGE, request_id, + len(payload)) + return header + payload + @staticmethod def encode_stash(request_id: int, op: int, table: str, key=None, value=None, pattern=None) -> bytes: @@ -603,6 +633,12 @@ class BinaryProtocol: ) elif msg_type == MSG_TYPE_STATUS: return BinaryProtocol.encode_status(request_id) + elif msg_type == MSG_TYPE_MANAGE: + return BinaryProtocol.encode_manage( + request_id, + message.get("op"), + message.get("count", 1) + ) else: raise DirtyProtocolError(f"Unhandled message type: {msg_type}") @@ -752,3 +788,23 @@ def make_stash_message(request_id, op: int, table: str, if pattern is not None: msg["pattern"] = pattern return msg + + +def make_manage_message(request_id, op: int, count: int = 1) -> dict: + """ + Build a worker management message dict. + + Args: + request_id: Unique request identifier (int or str) + op: Management operation (MANAGE_OP_ADD or MANAGE_OP_REMOVE) + count: Number of workers to add/remove + + Returns: + dict: Manage message dict + """ + return { + "type": DirtyProtocol.MSG_TYPE_MANAGE, + "id": request_id, + "op": op, + "count": count, + } diff --git a/tests/ctl/test_handlers.py b/tests/ctl/test_handlers.py index a6bcde28..4d279e94 100644 --- a/tests/ctl/test_handlers.py +++ b/tests/ctl/test_handlers.py @@ -4,6 +4,7 @@ """Tests for control socket command handlers.""" +import os import signal import time from unittest.mock import MagicMock, patch @@ -298,6 +299,60 @@ class TestShowDirty: assert result["pid"] is None +class TestDirtyAdd: + """Tests for dirty add command.""" + + def test_dirty_add_not_running(self): + """Test dirty add when dirty arbiter not running.""" + arbiter = MockArbiter() + handlers = CommandHandlers(arbiter) + + result = handlers.dirty_add() + + assert result["success"] is False + assert "not running" in result["error"] + + def test_dirty_add_no_socket(self): + """Test dirty add when socket path not available.""" + arbiter = MockArbiter() + arbiter.dirty_arbiter_pid = 2000 + handlers = CommandHandlers(arbiter) + + # No dirty_arbiter attribute and no env var + with patch.dict('os.environ', {}, clear=True): + result = handlers.dirty_add() + + assert result["success"] is False + assert "socket" in result["error"].lower() + + +class TestDirtyRemove: + """Tests for dirty remove command.""" + + def test_dirty_remove_not_running(self): + """Test dirty remove when dirty arbiter not running.""" + arbiter = MockArbiter() + handlers = CommandHandlers(arbiter) + + result = handlers.dirty_remove() + + assert result["success"] is False + assert "not running" in result["error"] + + def test_dirty_remove_no_socket(self): + """Test dirty remove when socket path not available.""" + arbiter = MockArbiter() + arbiter.dirty_arbiter_pid = 2000 + handlers = CommandHandlers(arbiter) + + # No dirty_arbiter attribute and no env var + with patch.dict('os.environ', {}, clear=True): + result = handlers.dirty_remove() + + assert result["success"] is False + assert "socket" in result["error"].lower() + + class TestReload: """Tests for reload command."""