mirror of
https://github.com/frappe/gunicorn.git
synced 2026-07-02 18:51:31 +08:00
feat(ctl): add message-based dirty worker management
Replace signal-based dirty add/remove with protocol messages: - Add MSG_TYPE_MANAGE to dirty protocol for worker management - Add MANAGE_OP_ADD and MANAGE_OP_REMOVE operation codes - Add handle_manage_request() in DirtyArbiter - Update handlers to send messages instead of SIGTTIN/SIGTTOU signals New workers only load apps that haven't reached their worker limits. When all apps are at their limits, returns reason in response. Only increment num_workers when a worker is actually spawned.
This commit is contained in:
parent
7df260930c
commit
e05e40d19b
@ -10,6 +10,7 @@ Provides handlers for all control commands with access to arbiter state.
|
||||
|
||||
import os
|
||||
import signal
|
||||
import socket
|
||||
import time
|
||||
|
||||
|
||||
@ -309,6 +310,8 @@ class CommandHandlers:
|
||||
"""
|
||||
Spawn additional dirty workers.
|
||||
|
||||
Sends a MANAGE message to the dirty arbiter to spawn workers.
|
||||
|
||||
Args:
|
||||
count: Number of dirty workers to add (default 1)
|
||||
|
||||
@ -321,25 +324,15 @@ class CommandHandlers:
|
||||
"error": "Dirty arbiter not running",
|
||||
}
|
||||
|
||||
# Send TTIN signals to dirty arbiter
|
||||
count = max(1, int(count))
|
||||
try:
|
||||
for _ in range(count):
|
||||
os.kill(self.arbiter.dirty_arbiter_pid, signal.SIGTTIN)
|
||||
return {
|
||||
"success": True,
|
||||
"added": count,
|
||||
}
|
||||
except OSError as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
}
|
||||
return self._send_manage_message("add", count)
|
||||
|
||||
def dirty_remove(self, count: int = 1) -> dict:
|
||||
"""
|
||||
Remove dirty workers.
|
||||
|
||||
Sends a MANAGE message to the dirty arbiter to remove workers.
|
||||
|
||||
Args:
|
||||
count: Number of dirty workers to remove (default 1)
|
||||
|
||||
@ -352,16 +345,73 @@ class CommandHandlers:
|
||||
"error": "Dirty arbiter not running",
|
||||
}
|
||||
|
||||
# Send TTOU signals to dirty arbiter
|
||||
count = max(1, int(count))
|
||||
try:
|
||||
for _ in range(count):
|
||||
os.kill(self.arbiter.dirty_arbiter_pid, signal.SIGTTOU)
|
||||
return self._send_manage_message("remove", count)
|
||||
|
||||
def _send_manage_message(self, operation: str, count: int) -> dict:
|
||||
"""
|
||||
Send a worker management message to the dirty arbiter.
|
||||
|
||||
Args:
|
||||
operation: "add" or "remove"
|
||||
count: Number of workers to add/remove
|
||||
|
||||
Returns:
|
||||
Dictionary with result or error
|
||||
"""
|
||||
# Get socket path from arbiter object or environment
|
||||
dirty_socket_path = None
|
||||
if hasattr(self.arbiter, 'dirty_arbiter') and self.arbiter.dirty_arbiter:
|
||||
dirty_socket_path = getattr(
|
||||
self.arbiter.dirty_arbiter, 'socket_path', None
|
||||
)
|
||||
if not dirty_socket_path:
|
||||
dirty_socket_path = os.environ.get('GUNICORN_DIRTY_SOCKET')
|
||||
if not dirty_socket_path:
|
||||
return {
|
||||
"success": True,
|
||||
"removed": count,
|
||||
"success": False,
|
||||
"error": "Cannot find dirty arbiter socket path",
|
||||
}
|
||||
except OSError as e:
|
||||
|
||||
try:
|
||||
from gunicorn.dirty.protocol import (
|
||||
DirtyProtocol, MANAGE_OP_ADD, MANAGE_OP_REMOVE
|
||||
)
|
||||
|
||||
op = MANAGE_OP_ADD if operation == "add" else MANAGE_OP_REMOVE
|
||||
|
||||
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
sock.settimeout(10.0)
|
||||
sock.connect(dirty_socket_path)
|
||||
|
||||
# Send manage request
|
||||
request = {
|
||||
"type": DirtyProtocol.MSG_TYPE_MANAGE,
|
||||
"id": 1,
|
||||
"op": op,
|
||||
"count": count,
|
||||
}
|
||||
DirtyProtocol.write_message(sock, request)
|
||||
|
||||
# Read response
|
||||
response = DirtyProtocol.read_message(sock)
|
||||
sock.close()
|
||||
|
||||
if response.get("type") == DirtyProtocol.MSG_TYPE_RESPONSE:
|
||||
return response.get("result", {"success": True})
|
||||
elif response.get("type") == DirtyProtocol.MSG_TYPE_ERROR:
|
||||
error = response.get("error", {})
|
||||
return {
|
||||
"success": False,
|
||||
"error": error.get("message", str(error)),
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Unexpected response type: {response.get('type')}",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
|
||||
@ -41,6 +41,8 @@ from .protocol import (
|
||||
STASH_OP_DELETE_TABLE,
|
||||
STASH_OP_TABLES,
|
||||
STASH_OP_EXISTS,
|
||||
MANAGE_OP_ADD,
|
||||
MANAGE_OP_REMOVE,
|
||||
)
|
||||
from .worker import DirtyWorker
|
||||
|
||||
@ -426,6 +428,9 @@ class DirtyArbiter:
|
||||
# Handle status queries
|
||||
elif msg_type == DirtyProtocol.MSG_TYPE_STATUS:
|
||||
await self.handle_status_request(message, writer)
|
||||
# Handle worker management (add/remove workers)
|
||||
elif msg_type == DirtyProtocol.MSG_TYPE_MANAGE:
|
||||
await self.handle_manage_request(message, writer)
|
||||
else:
|
||||
# Route request to a dirty worker - pass writer for streaming
|
||||
await self.route_request(message, writer)
|
||||
@ -690,6 +695,100 @@ class DirtyArbiter:
|
||||
response = make_response(request_id, result)
|
||||
await DirtyProtocol.write_message_async(client_writer, response)
|
||||
|
||||
async def handle_manage_request(self, message, client_writer):
|
||||
"""
|
||||
Handle a worker management request.
|
||||
|
||||
Supports adding or removing dirty workers via protocol messages.
|
||||
|
||||
Args:
|
||||
message: Manage request message
|
||||
client_writer: StreamWriter to send response to client
|
||||
"""
|
||||
request_id = message.get("id", "unknown")
|
||||
op = message.get("op")
|
||||
count = max(1, int(message.get("count", 1)))
|
||||
|
||||
try:
|
||||
if op == MANAGE_OP_ADD:
|
||||
# Add workers - only loads apps that need more workers
|
||||
spawned = 0
|
||||
for _ in range(count):
|
||||
result = self.spawn_worker()
|
||||
if result is not None:
|
||||
self.num_workers += 1
|
||||
spawned += 1
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Provide feedback about why no workers were spawned
|
||||
if spawned == 0:
|
||||
result = {
|
||||
"success": True,
|
||||
"operation": "add",
|
||||
"requested": count,
|
||||
"spawned": 0,
|
||||
"reason": "All apps have reached their worker limits",
|
||||
"total_workers": len(self.workers),
|
||||
"target_workers": self.num_workers,
|
||||
}
|
||||
else:
|
||||
result = {
|
||||
"success": True,
|
||||
"operation": "add",
|
||||
"requested": count,
|
||||
"spawned": spawned,
|
||||
"total_workers": len(self.workers),
|
||||
"target_workers": self.num_workers,
|
||||
}
|
||||
|
||||
elif op == MANAGE_OP_REMOVE:
|
||||
# Remove workers (similar to TTOU signal but via message)
|
||||
min_workers = self._get_minimum_workers()
|
||||
removed = 0
|
||||
|
||||
for _ in range(count):
|
||||
if self.num_workers <= min_workers:
|
||||
break
|
||||
if len(self.workers) <= 1:
|
||||
break
|
||||
|
||||
self.num_workers -= 1
|
||||
|
||||
# Kill oldest worker
|
||||
oldest_pid = min(self.workers.keys(),
|
||||
key=lambda p: self.workers[p].age)
|
||||
self.kill_worker(oldest_pid, signal.SIGTERM)
|
||||
removed += 1
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
result = {
|
||||
"success": True,
|
||||
"operation": "remove",
|
||||
"requested": count,
|
||||
"removed": removed,
|
||||
"total_workers": len(self.workers),
|
||||
"target_workers": self.num_workers,
|
||||
}
|
||||
|
||||
else:
|
||||
error = DirtyError(f"Unknown manage operation: {op}")
|
||||
response = make_error_response(request_id, error)
|
||||
await DirtyProtocol.write_message_async(client_writer, response)
|
||||
return
|
||||
|
||||
self.log.info("Worker management: %s %d workers (spawned/removed: %d)",
|
||||
"add" if op == MANAGE_OP_ADD else "remove",
|
||||
count,
|
||||
result.get("spawned", result.get("removed", 0)))
|
||||
|
||||
response = make_response(request_id, result)
|
||||
await DirtyProtocol.write_message_async(client_writer, response)
|
||||
|
||||
except Exception as e:
|
||||
self.log.error("Manage operation error: %s", e)
|
||||
response = make_error_response(request_id, DirtyError(str(e)))
|
||||
await DirtyProtocol.write_message_async(client_writer, response)
|
||||
|
||||
async def handle_stash_request(self, message, client_writer):
|
||||
"""
|
||||
Handle a stash operation directly in the arbiter.
|
||||
@ -830,13 +929,17 @@ class DirtyArbiter:
|
||||
self.kill_worker(oldest_pid, signal.SIGTERM)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
def spawn_worker(self):
|
||||
def spawn_worker(self, force_all_apps=False):
|
||||
"""
|
||||
Spawn a new dirty worker.
|
||||
|
||||
Worker app assignment follows these priorities:
|
||||
1. If there are pending respawns (from dead workers), use those apps
|
||||
2. Otherwise, determine apps for a new worker based on allocation
|
||||
3. If force_all_apps=True, spawn with all apps regardless of limits
|
||||
|
||||
Args:
|
||||
force_all_apps: If True, spawn worker with all apps ignoring limits
|
||||
|
||||
Returns:
|
||||
Worker PID in parent process, or None if no apps need workers
|
||||
@ -844,6 +947,9 @@ class DirtyArbiter:
|
||||
# Priority 1: Respawn dead worker with same apps
|
||||
if self._pending_respawns:
|
||||
app_paths = self._pending_respawns.pop(0)
|
||||
elif force_all_apps:
|
||||
# Force spawn with all apps (used by TTIN signal)
|
||||
app_paths = list(self.app_specs.keys())
|
||||
else:
|
||||
# Priority 2: New worker for initial pool
|
||||
app_paths = self._get_apps_for_new_worker()
|
||||
|
||||
@ -44,6 +44,7 @@ MSG_TYPE_CHUNK = 0x04
|
||||
MSG_TYPE_END = 0x05
|
||||
MSG_TYPE_STASH = 0x10 # Stash operations (shared state between workers)
|
||||
MSG_TYPE_STATUS = 0x11 # Status query for arbiter/workers
|
||||
MSG_TYPE_MANAGE = 0x12 # Worker management (add/remove workers)
|
||||
|
||||
# Message type names (for backwards compatibility with old API)
|
||||
MSG_TYPE_REQUEST_STR = "request"
|
||||
@ -53,6 +54,7 @@ MSG_TYPE_CHUNK_STR = "chunk"
|
||||
MSG_TYPE_END_STR = "end"
|
||||
MSG_TYPE_STASH_STR = "stash"
|
||||
MSG_TYPE_STATUS_STR = "status"
|
||||
MSG_TYPE_MANAGE_STR = "manage"
|
||||
|
||||
# Map int types to string names
|
||||
MSG_TYPE_TO_STR = {
|
||||
@ -63,6 +65,7 @@ MSG_TYPE_TO_STR = {
|
||||
MSG_TYPE_END: MSG_TYPE_END_STR,
|
||||
MSG_TYPE_STASH: MSG_TYPE_STASH_STR,
|
||||
MSG_TYPE_STATUS: MSG_TYPE_STATUS_STR,
|
||||
MSG_TYPE_MANAGE: MSG_TYPE_MANAGE_STR,
|
||||
}
|
||||
|
||||
# Map string names to int types
|
||||
@ -80,6 +83,10 @@ STASH_OP_DELETE_TABLE = 8
|
||||
STASH_OP_TABLES = 9
|
||||
STASH_OP_EXISTS = 10
|
||||
|
||||
# Manage operation codes
|
||||
MANAGE_OP_ADD = 1 # Add/spawn workers
|
||||
MANAGE_OP_REMOVE = 2 # Remove/kill workers
|
||||
|
||||
# Header format: Magic (2) + Version (1) + Type (1) + Length (4) + RequestID (8) = 16
|
||||
HEADER_FORMAT = ">2sBBIQ"
|
||||
HEADER_SIZE = struct.calcsize(HEADER_FORMAT)
|
||||
@ -102,6 +109,7 @@ class BinaryProtocol:
|
||||
MSG_TYPE_END = MSG_TYPE_END_STR
|
||||
MSG_TYPE_STASH = MSG_TYPE_STASH_STR
|
||||
MSG_TYPE_STATUS = MSG_TYPE_STATUS_STR
|
||||
MSG_TYPE_MANAGE = MSG_TYPE_MANAGE_STR
|
||||
|
||||
@staticmethod
|
||||
def encode_header(msg_type: int, request_id: int, payload_length: int) -> bytes:
|
||||
@ -292,6 +300,28 @@ class BinaryProtocol:
|
||||
header = BinaryProtocol.encode_header(MSG_TYPE_STATUS, request_id, 0)
|
||||
return header
|
||||
|
||||
@staticmethod
|
||||
def encode_manage(request_id: int, op: int, count: int = 1) -> bytes:
|
||||
"""
|
||||
Encode a worker management message.
|
||||
|
||||
Args:
|
||||
request_id: Request identifier
|
||||
op: Management operation (MANAGE_OP_ADD or MANAGE_OP_REMOVE)
|
||||
count: Number of workers to add/remove
|
||||
|
||||
Returns:
|
||||
bytes: Complete message (header + payload)
|
||||
"""
|
||||
payload_dict = {
|
||||
"op": op,
|
||||
"count": count,
|
||||
}
|
||||
payload = TLVEncoder.encode(payload_dict)
|
||||
header = BinaryProtocol.encode_header(MSG_TYPE_MANAGE, request_id,
|
||||
len(payload))
|
||||
return header + payload
|
||||
|
||||
@staticmethod
|
||||
def encode_stash(request_id: int, op: int, table: str,
|
||||
key=None, value=None, pattern=None) -> bytes:
|
||||
@ -603,6 +633,12 @@ class BinaryProtocol:
|
||||
)
|
||||
elif msg_type == MSG_TYPE_STATUS:
|
||||
return BinaryProtocol.encode_status(request_id)
|
||||
elif msg_type == MSG_TYPE_MANAGE:
|
||||
return BinaryProtocol.encode_manage(
|
||||
request_id,
|
||||
message.get("op"),
|
||||
message.get("count", 1)
|
||||
)
|
||||
else:
|
||||
raise DirtyProtocolError(f"Unhandled message type: {msg_type}")
|
||||
|
||||
@ -752,3 +788,23 @@ def make_stash_message(request_id, op: int, table: str,
|
||||
if pattern is not None:
|
||||
msg["pattern"] = pattern
|
||||
return msg
|
||||
|
||||
|
||||
def make_manage_message(request_id, op: int, count: int = 1) -> dict:
|
||||
"""
|
||||
Build a worker management message dict.
|
||||
|
||||
Args:
|
||||
request_id: Unique request identifier (int or str)
|
||||
op: Management operation (MANAGE_OP_ADD or MANAGE_OP_REMOVE)
|
||||
count: Number of workers to add/remove
|
||||
|
||||
Returns:
|
||||
dict: Manage message dict
|
||||
"""
|
||||
return {
|
||||
"type": DirtyProtocol.MSG_TYPE_MANAGE,
|
||||
"id": request_id,
|
||||
"op": op,
|
||||
"count": count,
|
||||
}
|
||||
|
||||
@ -4,6 +4,7 @@
|
||||
|
||||
"""Tests for control socket command handlers."""
|
||||
|
||||
import os
|
||||
import signal
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
@ -298,6 +299,60 @@ class TestShowDirty:
|
||||
assert result["pid"] is None
|
||||
|
||||
|
||||
class TestDirtyAdd:
|
||||
"""Tests for dirty add command."""
|
||||
|
||||
def test_dirty_add_not_running(self):
|
||||
"""Test dirty add when dirty arbiter not running."""
|
||||
arbiter = MockArbiter()
|
||||
handlers = CommandHandlers(arbiter)
|
||||
|
||||
result = handlers.dirty_add()
|
||||
|
||||
assert result["success"] is False
|
||||
assert "not running" in result["error"]
|
||||
|
||||
def test_dirty_add_no_socket(self):
|
||||
"""Test dirty add when socket path not available."""
|
||||
arbiter = MockArbiter()
|
||||
arbiter.dirty_arbiter_pid = 2000
|
||||
handlers = CommandHandlers(arbiter)
|
||||
|
||||
# No dirty_arbiter attribute and no env var
|
||||
with patch.dict('os.environ', {}, clear=True):
|
||||
result = handlers.dirty_add()
|
||||
|
||||
assert result["success"] is False
|
||||
assert "socket" in result["error"].lower()
|
||||
|
||||
|
||||
class TestDirtyRemove:
|
||||
"""Tests for dirty remove command."""
|
||||
|
||||
def test_dirty_remove_not_running(self):
|
||||
"""Test dirty remove when dirty arbiter not running."""
|
||||
arbiter = MockArbiter()
|
||||
handlers = CommandHandlers(arbiter)
|
||||
|
||||
result = handlers.dirty_remove()
|
||||
|
||||
assert result["success"] is False
|
||||
assert "not running" in result["error"]
|
||||
|
||||
def test_dirty_remove_no_socket(self):
|
||||
"""Test dirty remove when socket path not available."""
|
||||
arbiter = MockArbiter()
|
||||
arbiter.dirty_arbiter_pid = 2000
|
||||
handlers = CommandHandlers(arbiter)
|
||||
|
||||
# No dirty_arbiter attribute and no env var
|
||||
with patch.dict('os.environ', {}, clear=True):
|
||||
result = handlers.dirty_remove()
|
||||
|
||||
assert result["success"] is False
|
||||
assert "socket" in result["error"].lower()
|
||||
|
||||
|
||||
class TestReload:
|
||||
"""Tests for reload command."""
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user