feat(ctl): add message-based dirty worker management

Replace signal-based dirty add/remove with protocol messages:
- Add MSG_TYPE_MANAGE to dirty protocol for worker management
- Add MANAGE_OP_ADD and MANAGE_OP_REMOVE operation codes
- Add handle_manage_request() in DirtyArbiter
- Update handlers to send messages instead of SIGTTIN/SIGTTOU signals

New workers only load apps that haven't reached their worker limits.
When all apps are at their limits, the response includes the reason no worker was spawned.
Only increment num_workers when a worker is actually spawned.
This commit is contained in:
Benoit Chesneau 2026-02-13 02:25:37 +01:00
parent 7df260930c
commit e05e40d19b
4 changed files with 288 additions and 21 deletions

View File

@ -10,6 +10,7 @@ Provides handlers for all control commands with access to arbiter state.
import os
import signal
import socket
import time
@ -309,6 +310,8 @@ class CommandHandlers:
"""
Spawn additional dirty workers.
Sends a MANAGE message to the dirty arbiter to spawn workers.
Args:
count: Number of dirty workers to add (default 1)
@ -321,25 +324,15 @@ class CommandHandlers:
"error": "Dirty arbiter not running",
}
# Send TTIN signals to dirty arbiter
count = max(1, int(count))
try:
for _ in range(count):
os.kill(self.arbiter.dirty_arbiter_pid, signal.SIGTTIN)
return {
"success": True,
"added": count,
}
except OSError as e:
return {
"success": False,
"error": str(e),
}
return self._send_manage_message("add", count)
def dirty_remove(self, count: int = 1) -> dict:
"""
Remove dirty workers.
Sends a MANAGE message to the dirty arbiter to remove workers.
Args:
count: Number of dirty workers to remove (default 1)
@ -352,16 +345,73 @@ class CommandHandlers:
"error": "Dirty arbiter not running",
}
# Send TTOU signals to dirty arbiter
count = max(1, int(count))
try:
for _ in range(count):
os.kill(self.arbiter.dirty_arbiter_pid, signal.SIGTTOU)
return self._send_manage_message("remove", count)
def _send_manage_message(self, operation: str, count: int) -> dict:
"""
Send a worker management message to the dirty arbiter.
Args:
operation: "add" or "remove"
count: Number of workers to add/remove
Returns:
Dictionary with result or error
"""
# Get socket path from arbiter object or environment
dirty_socket_path = None
if hasattr(self.arbiter, 'dirty_arbiter') and self.arbiter.dirty_arbiter:
dirty_socket_path = getattr(
self.arbiter.dirty_arbiter, 'socket_path', None
)
if not dirty_socket_path:
dirty_socket_path = os.environ.get('GUNICORN_DIRTY_SOCKET')
if not dirty_socket_path:
return {
"success": True,
"removed": count,
"success": False,
"error": "Cannot find dirty arbiter socket path",
}
except OSError as e:
try:
from gunicorn.dirty.protocol import (
DirtyProtocol, MANAGE_OP_ADD, MANAGE_OP_REMOVE
)
op = MANAGE_OP_ADD if operation == "add" else MANAGE_OP_REMOVE
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.settimeout(10.0)
sock.connect(dirty_socket_path)
# Send manage request
request = {
"type": DirtyProtocol.MSG_TYPE_MANAGE,
"id": 1,
"op": op,
"count": count,
}
DirtyProtocol.write_message(sock, request)
# Read response
response = DirtyProtocol.read_message(sock)
sock.close()
if response.get("type") == DirtyProtocol.MSG_TYPE_RESPONSE:
return response.get("result", {"success": True})
elif response.get("type") == DirtyProtocol.MSG_TYPE_ERROR:
error = response.get("error", {})
return {
"success": False,
"error": error.get("message", str(error)),
}
else:
return {
"success": False,
"error": f"Unexpected response type: {response.get('type')}",
}
except Exception as e:
return {
"success": False,
"error": str(e),

View File

@ -41,6 +41,8 @@ from .protocol import (
STASH_OP_DELETE_TABLE,
STASH_OP_TABLES,
STASH_OP_EXISTS,
MANAGE_OP_ADD,
MANAGE_OP_REMOVE,
)
from .worker import DirtyWorker
@ -426,6 +428,9 @@ class DirtyArbiter:
# Handle status queries
elif msg_type == DirtyProtocol.MSG_TYPE_STATUS:
await self.handle_status_request(message, writer)
# Handle worker management (add/remove workers)
elif msg_type == DirtyProtocol.MSG_TYPE_MANAGE:
await self.handle_manage_request(message, writer)
else:
# Route request to a dirty worker - pass writer for streaming
await self.route_request(message, writer)
@ -690,6 +695,100 @@ class DirtyArbiter:
response = make_response(request_id, result)
await DirtyProtocol.write_message_async(client_writer, response)
async def handle_manage_request(self, message, client_writer):
    """
    Handle a worker management request.

    Adds or removes dirty workers in response to a protocol message and
    writes one response (result or error) back to the client.

    Args:
        message: Manage request message. Reads "id" (defaults to
            "unknown"), "op" (compared against MANAGE_OP_ADD and
            MANAGE_OP_REMOVE) and "count" (defaults to 1; values below 1
            are raised to 1).
        client_writer: StreamWriter to send response to client
    """
    request_id = message.get("id", "unknown")
    op = message.get("op")

    try:
        # Parse the count inside the try block so that a malformed value
        # (None, a non-numeric string, ...) is reported to the client as an
        # error response instead of escaping the handler unanswered.
        count = max(1, int(message.get("count", 1)))

        if op == MANAGE_OP_ADD:
            # Add workers - only loads apps that need more workers
            spawned = 0
            for _ in range(count):
                # Only raise the target when spawn_worker() reports a
                # worker (it returns None when nothing was spawned).
                if self.spawn_worker() is not None:
                    self.num_workers += 1
                    spawned += 1
                    await asyncio.sleep(0.1)

            result = {
                "success": True,
                "operation": "add",
                "requested": count,
                "spawned": spawned,
            }
            # Provide feedback about why no workers were spawned
            if spawned == 0:
                result["reason"] = "All apps have reached their worker limits"
            result["total_workers"] = len(self.workers)
            result["target_workers"] = self.num_workers

        elif op == MANAGE_OP_REMOVE:
            # Remove workers (similar to TTOU signal but via message)
            min_workers = self._get_minimum_workers()
            removed = 0
            for _ in range(count):
                # Never go below the configured minimum or the last worker.
                if self.num_workers <= min_workers:
                    break
                if len(self.workers) <= 1:
                    break
                self.num_workers -= 1

                # Kill oldest worker (the entry with the smallest `age`).
                # NOTE(review): this relies on the signalled worker leaving
                # self.workers before the next iteration - confirm against
                # the reaping logic.
                oldest_pid = min(self.workers.keys(),
                                 key=lambda p: self.workers[p].age)
                self.kill_worker(oldest_pid, signal.SIGTERM)
                removed += 1
                await asyncio.sleep(0.1)

            result = {
                "success": True,
                "operation": "remove",
                "requested": count,
                "removed": removed,
                "total_workers": len(self.workers),
                "target_workers": self.num_workers,
            }

        else:
            error = DirtyError(f"Unknown manage operation: {op}")
            response = make_error_response(request_id, error)
            await DirtyProtocol.write_message_async(client_writer, response)
            return

        self.log.info("Worker management: %s %d workers (spawned/removed: %d)",
                      "add" if op == MANAGE_OP_ADD else "remove",
                      count,
                      result.get("spawned", result.get("removed", 0)))

        response = make_response(request_id, result)
        await DirtyProtocol.write_message_async(client_writer, response)

    except Exception as e:
        # Any failure (bad count, spawn/kill error, ...) is logged and
        # reported to the client as an error response.
        self.log.error("Manage operation error: %s", e)
        response = make_error_response(request_id, DirtyError(str(e)))
        await DirtyProtocol.write_message_async(client_writer, response)
async def handle_stash_request(self, message, client_writer):
"""
Handle a stash operation directly in the arbiter.
@ -830,13 +929,17 @@ class DirtyArbiter:
self.kill_worker(oldest_pid, signal.SIGTERM)
await asyncio.sleep(0.1)
def spawn_worker(self):
def spawn_worker(self, force_all_apps=False):
"""
Spawn a new dirty worker.
Worker app assignment follows these priorities:
1. If there are pending respawns (from dead workers), use those apps
2. Otherwise, determine apps for a new worker based on allocation
3. If force_all_apps=True, spawn with all apps regardless of limits
Args:
force_all_apps: If True, spawn worker with all apps ignoring limits
Returns:
Worker PID in parent process, or None if no apps need workers
@ -844,6 +947,9 @@ class DirtyArbiter:
# Priority 1: Respawn dead worker with same apps
if self._pending_respawns:
app_paths = self._pending_respawns.pop(0)
elif force_all_apps:
# Force spawn with all apps (used by TTIN signal)
app_paths = list(self.app_specs.keys())
else:
# Priority 2: New worker for initial pool
app_paths = self._get_apps_for_new_worker()

View File

@ -44,6 +44,7 @@ MSG_TYPE_CHUNK = 0x04
MSG_TYPE_END = 0x05
MSG_TYPE_STASH = 0x10 # Stash operations (shared state between workers)
MSG_TYPE_STATUS = 0x11 # Status query for arbiter/workers
MSG_TYPE_MANAGE = 0x12 # Worker management (add/remove workers)
# Message type names (for backwards compatibility with old API)
MSG_TYPE_REQUEST_STR = "request"
@ -53,6 +54,7 @@ MSG_TYPE_CHUNK_STR = "chunk"
MSG_TYPE_END_STR = "end"
MSG_TYPE_STASH_STR = "stash"
MSG_TYPE_STATUS_STR = "status"
MSG_TYPE_MANAGE_STR = "manage"
# Map int types to string names
MSG_TYPE_TO_STR = {
@ -63,6 +65,7 @@ MSG_TYPE_TO_STR = {
MSG_TYPE_END: MSG_TYPE_END_STR,
MSG_TYPE_STASH: MSG_TYPE_STASH_STR,
MSG_TYPE_STATUS: MSG_TYPE_STATUS_STR,
MSG_TYPE_MANAGE: MSG_TYPE_MANAGE_STR,
}
# Map string names to int types
@ -80,6 +83,10 @@ STASH_OP_DELETE_TABLE = 8
STASH_OP_TABLES = 9
STASH_OP_EXISTS = 10
# Manage operation codes
MANAGE_OP_ADD = 1 # Add/spawn workers
MANAGE_OP_REMOVE = 2 # Remove/kill workers
# Header format: Magic (2) + Version (1) + Type (1) + Length (4) + RequestID (8) = 16
HEADER_FORMAT = ">2sBBIQ"
HEADER_SIZE = struct.calcsize(HEADER_FORMAT)
@ -102,6 +109,7 @@ class BinaryProtocol:
MSG_TYPE_END = MSG_TYPE_END_STR
MSG_TYPE_STASH = MSG_TYPE_STASH_STR
MSG_TYPE_STATUS = MSG_TYPE_STATUS_STR
MSG_TYPE_MANAGE = MSG_TYPE_MANAGE_STR
@staticmethod
def encode_header(msg_type: int, request_id: int, payload_length: int) -> bytes:
@ -292,6 +300,28 @@ class BinaryProtocol:
header = BinaryProtocol.encode_header(MSG_TYPE_STATUS, request_id, 0)
return header
@staticmethod
def encode_manage(request_id: int, op: int, count: int = 1) -> bytes:
    """
    Serialize a worker management message.

    The payload is the TLV encoding of a dict holding "op" and "count";
    it is prefixed with a MSG_TYPE_MANAGE header carrying the request id
    and the payload length.

    Args:
        request_id: Request identifier
        op: Management operation (MANAGE_OP_ADD or MANAGE_OP_REMOVE)
        count: Number of workers to add/remove

    Returns:
        bytes: Complete message (header + payload)
    """
    body = TLVEncoder.encode({"op": op, "count": count})
    prefix = BinaryProtocol.encode_header(
        MSG_TYPE_MANAGE, request_id, len(body)
    )
    return prefix + body
@staticmethod
def encode_stash(request_id: int, op: int, table: str,
key=None, value=None, pattern=None) -> bytes:
@ -603,6 +633,12 @@ class BinaryProtocol:
)
elif msg_type == MSG_TYPE_STATUS:
return BinaryProtocol.encode_status(request_id)
elif msg_type == MSG_TYPE_MANAGE:
return BinaryProtocol.encode_manage(
request_id,
message.get("op"),
message.get("count", 1)
)
else:
raise DirtyProtocolError(f"Unhandled message type: {msg_type}")
@ -752,3 +788,23 @@ def make_stash_message(request_id, op: int, table: str,
if pattern is not None:
msg["pattern"] = pattern
return msg
def make_manage_message(request_id, op: int, count: int = 1) -> dict:
    """
    Create the dict form of a worker management message.

    Args:
        request_id: Unique request identifier (int or str)
        op: Management operation (MANAGE_OP_ADD or MANAGE_OP_REMOVE)
        count: Number of workers to add/remove

    Returns:
        dict: Message with "type", "id", "op" and "count" keys
    """
    manage_msg = dict(
        type=DirtyProtocol.MSG_TYPE_MANAGE,
        id=request_id,
        op=op,
        count=count,
    )
    return manage_msg

View File

@ -4,6 +4,7 @@
"""Tests for control socket command handlers."""
import os
import signal
import time
from unittest.mock import MagicMock, patch
@ -298,6 +299,60 @@ class TestShowDirty:
assert result["pid"] is None
class TestDirtyAdd:
    """Failure-path checks for the ``dirty_add`` control command."""

    def test_dirty_add_not_running(self):
        """dirty_add fails when the dirty arbiter is not running."""
        outcome = CommandHandlers(MockArbiter()).dirty_add()

        assert outcome["success"] is False
        assert "not running" in outcome["error"]

    def test_dirty_add_no_socket(self):
        """dirty_add fails when no dirty socket path can be found."""
        mock_arbiter = MockArbiter()
        mock_arbiter.dirty_arbiter_pid = 2000
        command_handlers = CommandHandlers(mock_arbiter)

        # The environment is emptied so no socket path can come from it;
        # the mock arbiter is expected to expose no dirty_arbiter either.
        with patch.dict('os.environ', {}, clear=True):
            outcome = command_handlers.dirty_add()

        assert outcome["success"] is False
        error_text = outcome["error"].lower()
        assert "socket" in error_text
class TestDirtyRemove:
    """Failure-path checks for the ``dirty_remove`` control command."""

    def test_dirty_remove_not_running(self):
        """dirty_remove fails when the dirty arbiter is not running."""
        outcome = CommandHandlers(MockArbiter()).dirty_remove()

        assert outcome["success"] is False
        assert "not running" in outcome["error"]

    def test_dirty_remove_no_socket(self):
        """dirty_remove fails when no dirty socket path can be found."""
        mock_arbiter = MockArbiter()
        mock_arbiter.dirty_arbiter_pid = 2000
        command_handlers = CommandHandlers(mock_arbiter)

        # The environment is emptied so no socket path can come from it;
        # the mock arbiter is expected to expose no dirty_arbiter either.
        with patch.dict('os.environ', {}, clear=True):
            outcome = command_handlers.dirty_remove()

        assert outcome["success"] is False
        error_text = outcome["error"].lower()
        assert "socket" in error_text
class TestReload:
"""Tests for reload command."""