feat(ctl): add message-based dirty worker management

Replace signal-based dirty add/remove with protocol messages:
- Add MSG_TYPE_MANAGE to dirty protocol for worker management
- Add MANAGE_OP_ADD and MANAGE_OP_REMOVE operation codes
- Add handle_manage_request() in DirtyArbiter
- Update handlers to send messages instead of SIGTTIN/SIGTTOU signals

New workers only load apps that haven't reached their worker limits.
When all apps are at their limits, returns reason in response.
Only increment num_workers when a worker is actually spawned.
This commit is contained in:
Benoit Chesneau 2026-02-13 02:25:37 +01:00
parent 7df260930c
commit e05e40d19b
4 changed files with 288 additions and 21 deletions

View File

@ -10,6 +10,7 @@ Provides handlers for all control commands with access to arbiter state.
import os import os
import signal import signal
import socket
import time import time
@ -309,6 +310,8 @@ class CommandHandlers:
""" """
Spawn additional dirty workers. Spawn additional dirty workers.
Sends a MANAGE message to the dirty arbiter to spawn workers.
Args: Args:
count: Number of dirty workers to add (default 1) count: Number of dirty workers to add (default 1)
@ -321,25 +324,15 @@ class CommandHandlers:
"error": "Dirty arbiter not running", "error": "Dirty arbiter not running",
} }
# Send TTIN signals to dirty arbiter
count = max(1, int(count)) count = max(1, int(count))
try: return self._send_manage_message("add", count)
for _ in range(count):
os.kill(self.arbiter.dirty_arbiter_pid, signal.SIGTTIN)
return {
"success": True,
"added": count,
}
except OSError as e:
return {
"success": False,
"error": str(e),
}
def dirty_remove(self, count: int = 1) -> dict: def dirty_remove(self, count: int = 1) -> dict:
""" """
Remove dirty workers. Remove dirty workers.
Sends a MANAGE message to the dirty arbiter to remove workers.
Args: Args:
count: Number of dirty workers to remove (default 1) count: Number of dirty workers to remove (default 1)
@ -352,16 +345,73 @@ class CommandHandlers:
"error": "Dirty arbiter not running", "error": "Dirty arbiter not running",
} }
# Send TTOU signals to dirty arbiter
count = max(1, int(count)) count = max(1, int(count))
try: return self._send_manage_message("remove", count)
for _ in range(count):
os.kill(self.arbiter.dirty_arbiter_pid, signal.SIGTTOU) def _send_manage_message(self, operation: str, count: int) -> dict:
"""
Send a worker management message to the dirty arbiter.
Args:
operation: "add" or "remove"
count: Number of workers to add/remove
Returns:
Dictionary with result or error
"""
# Get socket path from arbiter object or environment
dirty_socket_path = None
if hasattr(self.arbiter, 'dirty_arbiter') and self.arbiter.dirty_arbiter:
dirty_socket_path = getattr(
self.arbiter.dirty_arbiter, 'socket_path', None
)
if not dirty_socket_path:
dirty_socket_path = os.environ.get('GUNICORN_DIRTY_SOCKET')
if not dirty_socket_path:
return { return {
"success": True, "success": False,
"removed": count, "error": "Cannot find dirty arbiter socket path",
} }
except OSError as e:
try:
from gunicorn.dirty.protocol import (
DirtyProtocol, MANAGE_OP_ADD, MANAGE_OP_REMOVE
)
op = MANAGE_OP_ADD if operation == "add" else MANAGE_OP_REMOVE
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.settimeout(10.0)
sock.connect(dirty_socket_path)
# Send manage request
request = {
"type": DirtyProtocol.MSG_TYPE_MANAGE,
"id": 1,
"op": op,
"count": count,
}
DirtyProtocol.write_message(sock, request)
# Read response
response = DirtyProtocol.read_message(sock)
sock.close()
if response.get("type") == DirtyProtocol.MSG_TYPE_RESPONSE:
return response.get("result", {"success": True})
elif response.get("type") == DirtyProtocol.MSG_TYPE_ERROR:
error = response.get("error", {})
return {
"success": False,
"error": error.get("message", str(error)),
}
else:
return {
"success": False,
"error": f"Unexpected response type: {response.get('type')}",
}
except Exception as e:
return { return {
"success": False, "success": False,
"error": str(e), "error": str(e),

View File

@ -41,6 +41,8 @@ from .protocol import (
STASH_OP_DELETE_TABLE, STASH_OP_DELETE_TABLE,
STASH_OP_TABLES, STASH_OP_TABLES,
STASH_OP_EXISTS, STASH_OP_EXISTS,
MANAGE_OP_ADD,
MANAGE_OP_REMOVE,
) )
from .worker import DirtyWorker from .worker import DirtyWorker
@ -426,6 +428,9 @@ class DirtyArbiter:
# Handle status queries # Handle status queries
elif msg_type == DirtyProtocol.MSG_TYPE_STATUS: elif msg_type == DirtyProtocol.MSG_TYPE_STATUS:
await self.handle_status_request(message, writer) await self.handle_status_request(message, writer)
# Handle worker management (add/remove workers)
elif msg_type == DirtyProtocol.MSG_TYPE_MANAGE:
await self.handle_manage_request(message, writer)
else: else:
# Route request to a dirty worker - pass writer for streaming # Route request to a dirty worker - pass writer for streaming
await self.route_request(message, writer) await self.route_request(message, writer)
@ -690,6 +695,100 @@ class DirtyArbiter:
response = make_response(request_id, result) response = make_response(request_id, result)
await DirtyProtocol.write_message_async(client_writer, response) await DirtyProtocol.write_message_async(client_writer, response)
async def handle_manage_request(self, message, client_writer):
"""
Handle a worker management request.
Supports adding or removing dirty workers via protocol messages.
Args:
message: Manage request message
client_writer: StreamWriter to send response to client
"""
request_id = message.get("id", "unknown")
op = message.get("op")
count = max(1, int(message.get("count", 1)))
try:
if op == MANAGE_OP_ADD:
# Add workers - only loads apps that need more workers
spawned = 0
for _ in range(count):
result = self.spawn_worker()
if result is not None:
self.num_workers += 1
spawned += 1
await asyncio.sleep(0.1)
# Provide feedback about why no workers were spawned
if spawned == 0:
result = {
"success": True,
"operation": "add",
"requested": count,
"spawned": 0,
"reason": "All apps have reached their worker limits",
"total_workers": len(self.workers),
"target_workers": self.num_workers,
}
else:
result = {
"success": True,
"operation": "add",
"requested": count,
"spawned": spawned,
"total_workers": len(self.workers),
"target_workers": self.num_workers,
}
elif op == MANAGE_OP_REMOVE:
# Remove workers (similar to TTOU signal but via message)
min_workers = self._get_minimum_workers()
removed = 0
for _ in range(count):
if self.num_workers <= min_workers:
break
if len(self.workers) <= 1:
break
self.num_workers -= 1
# Kill oldest worker
oldest_pid = min(self.workers.keys(),
key=lambda p: self.workers[p].age)
self.kill_worker(oldest_pid, signal.SIGTERM)
removed += 1
await asyncio.sleep(0.1)
result = {
"success": True,
"operation": "remove",
"requested": count,
"removed": removed,
"total_workers": len(self.workers),
"target_workers": self.num_workers,
}
else:
error = DirtyError(f"Unknown manage operation: {op}")
response = make_error_response(request_id, error)
await DirtyProtocol.write_message_async(client_writer, response)
return
self.log.info("Worker management: %s %d workers (spawned/removed: %d)",
"add" if op == MANAGE_OP_ADD else "remove",
count,
result.get("spawned", result.get("removed", 0)))
response = make_response(request_id, result)
await DirtyProtocol.write_message_async(client_writer, response)
except Exception as e:
self.log.error("Manage operation error: %s", e)
response = make_error_response(request_id, DirtyError(str(e)))
await DirtyProtocol.write_message_async(client_writer, response)
async def handle_stash_request(self, message, client_writer): async def handle_stash_request(self, message, client_writer):
""" """
Handle a stash operation directly in the arbiter. Handle a stash operation directly in the arbiter.
@ -830,13 +929,17 @@ class DirtyArbiter:
self.kill_worker(oldest_pid, signal.SIGTERM) self.kill_worker(oldest_pid, signal.SIGTERM)
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
def spawn_worker(self): def spawn_worker(self, force_all_apps=False):
""" """
Spawn a new dirty worker. Spawn a new dirty worker.
Worker app assignment follows these priorities: Worker app assignment follows these priorities:
1. If there are pending respawns (from dead workers), use those apps 1. If there are pending respawns (from dead workers), use those apps
2. Otherwise, determine apps for a new worker based on allocation 2. Otherwise, determine apps for a new worker based on allocation
3. If force_all_apps=True, spawn with all apps regardless of limits
Args:
force_all_apps: If True, spawn worker with all apps ignoring limits
Returns: Returns:
Worker PID in parent process, or None if no apps need workers Worker PID in parent process, or None if no apps need workers
@ -844,6 +947,9 @@ class DirtyArbiter:
# Priority 1: Respawn dead worker with same apps # Priority 1: Respawn dead worker with same apps
if self._pending_respawns: if self._pending_respawns:
app_paths = self._pending_respawns.pop(0) app_paths = self._pending_respawns.pop(0)
elif force_all_apps:
# Force spawn with all apps (used by TTIN signal)
app_paths = list(self.app_specs.keys())
else: else:
# Priority 2: New worker for initial pool # Priority 2: New worker for initial pool
app_paths = self._get_apps_for_new_worker() app_paths = self._get_apps_for_new_worker()

View File

@ -44,6 +44,7 @@ MSG_TYPE_CHUNK = 0x04
MSG_TYPE_END = 0x05 MSG_TYPE_END = 0x05
MSG_TYPE_STASH = 0x10 # Stash operations (shared state between workers) MSG_TYPE_STASH = 0x10 # Stash operations (shared state between workers)
MSG_TYPE_STATUS = 0x11 # Status query for arbiter/workers MSG_TYPE_STATUS = 0x11 # Status query for arbiter/workers
MSG_TYPE_MANAGE = 0x12 # Worker management (add/remove workers)
# Message type names (for backwards compatibility with old API) # Message type names (for backwards compatibility with old API)
MSG_TYPE_REQUEST_STR = "request" MSG_TYPE_REQUEST_STR = "request"
@ -53,6 +54,7 @@ MSG_TYPE_CHUNK_STR = "chunk"
MSG_TYPE_END_STR = "end" MSG_TYPE_END_STR = "end"
MSG_TYPE_STASH_STR = "stash" MSG_TYPE_STASH_STR = "stash"
MSG_TYPE_STATUS_STR = "status" MSG_TYPE_STATUS_STR = "status"
MSG_TYPE_MANAGE_STR = "manage"
# Map int types to string names # Map int types to string names
MSG_TYPE_TO_STR = { MSG_TYPE_TO_STR = {
@ -63,6 +65,7 @@ MSG_TYPE_TO_STR = {
MSG_TYPE_END: MSG_TYPE_END_STR, MSG_TYPE_END: MSG_TYPE_END_STR,
MSG_TYPE_STASH: MSG_TYPE_STASH_STR, MSG_TYPE_STASH: MSG_TYPE_STASH_STR,
MSG_TYPE_STATUS: MSG_TYPE_STATUS_STR, MSG_TYPE_STATUS: MSG_TYPE_STATUS_STR,
MSG_TYPE_MANAGE: MSG_TYPE_MANAGE_STR,
} }
# Map string names to int types # Map string names to int types
@ -80,6 +83,10 @@ STASH_OP_DELETE_TABLE = 8
STASH_OP_TABLES = 9 STASH_OP_TABLES = 9
STASH_OP_EXISTS = 10 STASH_OP_EXISTS = 10
# Manage operation codes
MANAGE_OP_ADD = 1 # Add/spawn workers
MANAGE_OP_REMOVE = 2 # Remove/kill workers
# Header format: Magic (2) + Version (1) + Type (1) + Length (4) + RequestID (8) = 16 # Header format: Magic (2) + Version (1) + Type (1) + Length (4) + RequestID (8) = 16
HEADER_FORMAT = ">2sBBIQ" HEADER_FORMAT = ">2sBBIQ"
HEADER_SIZE = struct.calcsize(HEADER_FORMAT) HEADER_SIZE = struct.calcsize(HEADER_FORMAT)
@ -102,6 +109,7 @@ class BinaryProtocol:
MSG_TYPE_END = MSG_TYPE_END_STR MSG_TYPE_END = MSG_TYPE_END_STR
MSG_TYPE_STASH = MSG_TYPE_STASH_STR MSG_TYPE_STASH = MSG_TYPE_STASH_STR
MSG_TYPE_STATUS = MSG_TYPE_STATUS_STR MSG_TYPE_STATUS = MSG_TYPE_STATUS_STR
MSG_TYPE_MANAGE = MSG_TYPE_MANAGE_STR
@staticmethod @staticmethod
def encode_header(msg_type: int, request_id: int, payload_length: int) -> bytes: def encode_header(msg_type: int, request_id: int, payload_length: int) -> bytes:
@ -292,6 +300,28 @@ class BinaryProtocol:
header = BinaryProtocol.encode_header(MSG_TYPE_STATUS, request_id, 0) header = BinaryProtocol.encode_header(MSG_TYPE_STATUS, request_id, 0)
return header return header
@staticmethod
def encode_manage(request_id: int, op: int, count: int = 1) -> bytes:
"""
Encode a worker management message.
Args:
request_id: Request identifier
op: Management operation (MANAGE_OP_ADD or MANAGE_OP_REMOVE)
count: Number of workers to add/remove
Returns:
bytes: Complete message (header + payload)
"""
payload_dict = {
"op": op,
"count": count,
}
payload = TLVEncoder.encode(payload_dict)
header = BinaryProtocol.encode_header(MSG_TYPE_MANAGE, request_id,
len(payload))
return header + payload
@staticmethod @staticmethod
def encode_stash(request_id: int, op: int, table: str, def encode_stash(request_id: int, op: int, table: str,
key=None, value=None, pattern=None) -> bytes: key=None, value=None, pattern=None) -> bytes:
@ -603,6 +633,12 @@ class BinaryProtocol:
) )
elif msg_type == MSG_TYPE_STATUS: elif msg_type == MSG_TYPE_STATUS:
return BinaryProtocol.encode_status(request_id) return BinaryProtocol.encode_status(request_id)
elif msg_type == MSG_TYPE_MANAGE:
return BinaryProtocol.encode_manage(
request_id,
message.get("op"),
message.get("count", 1)
)
else: else:
raise DirtyProtocolError(f"Unhandled message type: {msg_type}") raise DirtyProtocolError(f"Unhandled message type: {msg_type}")
@ -752,3 +788,23 @@ def make_stash_message(request_id, op: int, table: str,
if pattern is not None: if pattern is not None:
msg["pattern"] = pattern msg["pattern"] = pattern
return msg return msg
def make_manage_message(request_id, op: int, count: int = 1) -> dict:
"""
Build a worker management message dict.
Args:
request_id: Unique request identifier (int or str)
op: Management operation (MANAGE_OP_ADD or MANAGE_OP_REMOVE)
count: Number of workers to add/remove
Returns:
dict: Manage message dict
"""
return {
"type": DirtyProtocol.MSG_TYPE_MANAGE,
"id": request_id,
"op": op,
"count": count,
}

View File

@ -4,6 +4,7 @@
"""Tests for control socket command handlers.""" """Tests for control socket command handlers."""
import os
import signal import signal
import time import time
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
@ -298,6 +299,60 @@ class TestShowDirty:
assert result["pid"] is None assert result["pid"] is None
class TestDirtyAdd:
"""Tests for dirty add command."""
def test_dirty_add_not_running(self):
"""Test dirty add when dirty arbiter not running."""
arbiter = MockArbiter()
handlers = CommandHandlers(arbiter)
result = handlers.dirty_add()
assert result["success"] is False
assert "not running" in result["error"]
def test_dirty_add_no_socket(self):
"""Test dirty add when socket path not available."""
arbiter = MockArbiter()
arbiter.dirty_arbiter_pid = 2000
handlers = CommandHandlers(arbiter)
# No dirty_arbiter attribute and no env var
with patch.dict('os.environ', {}, clear=True):
result = handlers.dirty_add()
assert result["success"] is False
assert "socket" in result["error"].lower()
class TestDirtyRemove:
"""Tests for dirty remove command."""
def test_dirty_remove_not_running(self):
"""Test dirty remove when dirty arbiter not running."""
arbiter = MockArbiter()
handlers = CommandHandlers(arbiter)
result = handlers.dirty_remove()
assert result["success"] is False
assert "not running" in result["error"]
def test_dirty_remove_no_socket(self):
"""Test dirty remove when socket path not available."""
arbiter = MockArbiter()
arbiter.dirty_arbiter_pid = 2000
handlers = CommandHandlers(arbiter)
# No dirty_arbiter attribute and no env var
with patch.dict('os.environ', {}, clear=True):
result = handlers.dirty_remove()
assert result["success"] is False
assert "socket" in result["error"].lower()
class TestReload: class TestReload:
"""Tests for reload command.""" """Tests for reload command."""