feat(ctl): add gunicornc control interface

Add a control socket server and CLI client for runtime management
of Gunicorn instances, similar to birdc for BIRD routing daemon.

Features:
- Control socket server running in arbiter process (asyncio/threaded)
- gunicornc CLI with interactive and single-command modes
- JSON protocol with length-prefixed framing
- Commands: show workers/stats/config/listeners/dirty, worker add/remove/kill,
  dirty add/remove, reload, reopen, shutdown
- Stats tracking (uptime, workers spawned/killed, reloads)
- Configurable socket path and permissions

New config options:
- control_socket: Unix socket path (default: gunicorn.ctl)
- control_socket_mode: Socket permissions (default: 0o600)
- --no-control-socket: Disable control socket
This commit is contained in:
Benoit Chesneau 2026-02-13 01:28:46 +01:00
parent 3cba17b84a
commit a57507c4e5
15 changed files with 2889 additions and 7 deletions

View File

@ -74,6 +74,17 @@ class Arbiter:
self.dirty_arbiter = None
self.dirty_pidfile = None # Well-known location for orphan detection
# Control socket server
self._control_server = None
# Stats tracking
self._stats = {
'start_time': None,
'workers_spawned': 0,
'workers_killed': 0,
'reloads': 0,
}
cwd = util.getcwd()
args = sys.argv[:]
@ -133,6 +144,9 @@ class Arbiter:
"""
self.log.info("Starting gunicorn %s", __version__)
# Initialize stats tracking
self._stats['start_time'] = time.time()
if 'GUNICORN_PID' in os.environ:
self.master_pid = int(os.environ.get('GUNICORN_PID'))
self.proc_name = self.proc_name + ".2"
@ -179,6 +193,9 @@ class Arbiter:
if self.cfg.dirty_workers > 0 and self.cfg.dirty_apps:
self.spawn_dirty_arbiter()
# Start control socket server
self._start_control_server()
self.cfg.when_ready(self)
def init_signals(self):
@ -351,6 +368,9 @@ class Arbiter:
def halt(self, reason=None, exit_status=0):
""" halt arbiter """
# Stop control socket server first
self._stop_control_server()
self.stop()
log_func = self.log.info if exit_status == 0 else self.log.error
@ -477,6 +497,9 @@ class Arbiter:
os.execvpe(self.START_CTX[0], self.START_CTX['args'], environ)
def reload(self):
# Track reload stats
self._stats['reloads'] += 1
old_address = self.cfg.address
# reset old environment
@ -667,6 +690,7 @@ class Arbiter:
if pid != 0:
worker.pid = pid
self.WORKERS[pid] = worker
self._stats['workers_spawned'] += 1
return pid
# Do not inherit the temporary files of other workers
@ -737,6 +761,9 @@ class Arbiter:
"""
try:
os.kill(pid, sig)
# Track kills only on SIGTERM/SIGKILL (actual termination signals)
if sig in (signal.SIGTERM, signal.SIGKILL):
self._stats['workers_killed'] += 1
except OSError as e:
if e.errno == errno.ESRCH:
try:
@ -906,3 +933,51 @@ class Arbiter:
if self.cfg.dirty_workers > 0 and self.cfg.dirty_apps:
self.log.info("Spawning dirty arbiter...")
self.spawn_dirty_arbiter()
# =========================================================================
# Control Socket Management
# =========================================================================
def _get_control_socket_path(self):
"""Get the control socket path, making relative paths absolute."""
socket_path = self.cfg.control_socket
if not os.path.isabs(socket_path):
socket_path = os.path.join(util.getcwd(), socket_path)
return socket_path
def _start_control_server(self):
"""\
Start the control socket server.
The server runs in a background thread and accepts commands
via Unix socket.
"""
if self.cfg.control_socket_disable:
self.log.debug("Control socket disabled")
return
# Lazy import to avoid circular imports and gevent compatibility
from gunicorn.ctl.server import ControlSocketServer
socket_path = self._get_control_socket_path()
socket_mode = self.cfg.control_socket_mode
try:
self._control_server = ControlSocketServer(
self, socket_path, socket_mode
)
self._control_server.start()
except Exception as e:
self.log.warning("Failed to start control socket: %s", e)
self._control_server = None
def _stop_control_server(self):
"""\
Stop the control socket server.
"""
if self._control_server:
try:
self._control_server.stop()
except Exception as e:
self.log.debug("Error stopping control server: %s", e)
self._control_server = None

View File

@ -3113,3 +3113,63 @@ class DirtyWorkerExit(Setting):
.. versionadded:: 25.0.0
"""
# Control Socket Settings
class ControlSocket(Setting):
name = "control_socket"
section = "Control"
cli = ["--control-socket"]
meta = "PATH"
validator = validate_string
default = "gunicorn.ctl"
desc = """\
Unix socket path for control interface.
The control socket allows runtime management of Gunicorn via the
``gunicornc`` command-line tool. Commands include viewing worker
status, adjusting worker count, and graceful reload/shutdown.
By default, creates ``gunicorn.ctl`` in the working directory.
Set an absolute path for a fixed location (e.g., ``/var/run/gunicorn.ctl``).
Use ``--no-control-socket`` to disable.
.. versionadded:: 25.1.0
"""
class ControlSocketMode(Setting):
name = "control_socket_mode"
section = "Control"
cli = ["--control-socket-mode"]
meta = "INT"
validator = validate_pos_int
type = auto_int
default = 0o600
desc = """\
Permission mode for control socket.
Restricts who can connect to the control socket. Default ``0600``
allows only the socket owner. Set to ``0660`` to allow group access.
.. versionadded:: 25.1.0
"""
class ControlSocketDisable(Setting):
name = "control_socket_disable"
section = "Control"
cli = ["--no-control-socket"]
validator = validate_bool
action = "store_true"
default = False
desc = """\
Disable control socket.
When set, no control socket is created and ``gunicornc`` cannot
connect to this Gunicorn instance.
.. versionadded:: 25.1.0
"""

16
gunicorn/ctl/__init__.py Normal file
View File

@ -0,0 +1,16 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
Gunicorn Control Interface
Provides a control socket server for runtime management and
a CLI client (gunicornc) for interacting with running Gunicorn instances.
"""
from gunicorn.ctl.server import ControlSocketServer
from gunicorn.ctl.client import ControlClient
from gunicorn.ctl.protocol import ControlProtocol
__all__ = ['ControlSocketServer', 'ControlClient', 'ControlProtocol']

385
gunicorn/ctl/cli.py Normal file
View File

@ -0,0 +1,385 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
gunicornc - Gunicorn control interface CLI
Interactive and single-command modes for controlling Gunicorn instances.
"""
import argparse
import json
import os
import sys
from gunicorn.ctl.client import ControlClient, ControlClientError, parse_command
def format_workers(data: dict) -> str:
"""Format workers output for display."""
workers = data.get("workers", [])
if not workers:
return "No workers running"
lines = []
lines.append(f"{'PID':<10} {'AGE':<6} {'BOOTED':<8} {'LAST_BEAT'}")
lines.append("-" * 40)
for w in workers:
pid = w.get("pid", "?")
age = w.get("age", "?")
booted = "yes" if w.get("booted") else "no"
hb = w.get("last_heartbeat")
hb_str = f"{hb}s ago" if hb is not None else "n/a"
lines.append(f"{pid:<10} {age:<6} {booted:<8} {hb_str}")
lines.append("")
lines.append(f"Total: {data.get('count', len(workers))} workers")
return "\n".join(lines)
def format_dirty(data: dict) -> str:
"""Format dirty workers output for display."""
if not data.get("enabled"):
return "Dirty arbiter not running"
lines = []
lines.append(f"Dirty arbiter PID: {data.get('pid')}")
lines.append("")
workers = data.get("workers", [])
if workers:
lines.append("DIRTY WORKERS:")
lines.append(f"{'PID':<10} {'AGE':<6} {'APPS':<30} {'LAST_BEAT'}")
lines.append("-" * 60)
for w in workers:
pid = w.get("pid", "?")
age = w.get("age", "?")
apps = ", ".join(w.get("apps", []))[:30]
hb = w.get("last_heartbeat")
hb_str = f"{hb}s ago" if hb is not None else "n/a"
lines.append(f"{pid:<10} {age:<6} {apps:<30} {hb_str}")
lines.append("")
apps = data.get("apps", [])
if apps:
lines.append("DIRTY APPS:")
lines.append(f"{'APP':<30} {'WORKERS':<10} {'LIMIT'}")
lines.append("-" * 50)
for app in apps:
path = app.get("import_path", "?")[:30]
current = app.get("current_workers", 0)
limit = app.get("worker_count")
limit_str = str(limit) if limit is not None else "none"
lines.append(f"{path:<30} {current:<10} {limit_str}")
return "\n".join(lines)
def format_stats(data: dict) -> str:
"""Format stats output for display."""
lines = []
uptime = data.get("uptime")
if uptime:
hours = int(uptime // 3600)
minutes = int((uptime % 3600) // 60)
seconds = int(uptime % 60)
if hours:
uptime_str = f"{hours}h {minutes}m {seconds}s"
elif minutes:
uptime_str = f"{minutes}m {seconds}s"
else:
uptime_str = f"{seconds}s"
else:
uptime_str = "unknown"
lines.append(f"Uptime: {uptime_str}")
lines.append(f"PID: {data.get('pid', 'unknown')}")
lines.append(f"Workers current: {data.get('workers_current', 0)}")
lines.append(f"Workers target: {data.get('workers_target', 0)}")
lines.append(f"Workers spawned: {data.get('workers_spawned', 0)}")
lines.append(f"Workers killed: {data.get('workers_killed', 0)}")
lines.append(f"Reloads: {data.get('reloads', 0)}")
dirty_pid = data.get("dirty_arbiter_pid")
if dirty_pid:
lines.append(f"Dirty arbiter: {dirty_pid}")
return "\n".join(lines)
def format_listeners(data: dict) -> str:
"""Format listeners output for display."""
listeners = data.get("listeners", [])
if not listeners:
return "No listeners bound"
lines = []
lines.append(f"{'ADDRESS':<40} {'TYPE':<8} {'FD'}")
lines.append("-" * 55)
for lnr in listeners:
addr = lnr.get("address", "?")
ltype = lnr.get("type", "?")
fd = lnr.get("fd", "?")
lines.append(f"{addr:<40} {ltype:<8} {fd}")
lines.append("")
lines.append(f"Total: {data.get('count', len(listeners))} listeners")
return "\n".join(lines)
def format_config(data: dict) -> str:
"""Format config output for display."""
lines = []
# Sort keys for consistent output
for key in sorted(data.keys()):
value = data[key]
if isinstance(value, list):
value = ", ".join(str(v) for v in value)
lines.append(f"{key}: {value}")
return "\n".join(lines)
def format_help(data: dict) -> str:
"""Format help output for display."""
commands = data.get("commands", {})
lines = []
lines.append("Available commands:")
lines.append("")
# Find max command length for alignment
max_len = max(len(cmd) for cmd in commands.keys()) if commands else 0
for cmd, desc in sorted(commands.items()):
lines.append(f" {cmd:<{max_len + 2}} {desc}")
return "\n".join(lines)
def format_response(command: str, data: dict) -> str:
"""
Format response data based on command.
Args:
command: Original command string
data: Response data dictionary
Returns:
Formatted string for display
"""
cmd_lower = command.lower().strip()
# Route to specific formatters
if cmd_lower == "show workers":
return format_workers(data)
elif cmd_lower == "show dirty":
return format_dirty(data)
elif cmd_lower == "show stats":
return format_stats(data)
elif cmd_lower == "show listeners":
return format_listeners(data)
elif cmd_lower == "show config":
return format_config(data)
elif cmd_lower == "help":
return format_help(data)
else:
# Generic JSON output for other commands
if data:
return json.dumps(data, indent=2)
return "OK"
def run_command(socket_path: str, command: str, json_output: bool = False) -> int:
"""
Execute single command and exit.
Args:
socket_path: Path to control socket
command: Command to execute
json_output: If True, output raw JSON
Returns:
Exit code (0 for success, 1 for error)
"""
try:
with ControlClient(socket_path) as client:
cmd, args = parse_command(command)
full_command = f"{cmd} {' '.join(args)}".strip() if args else cmd
result = client.send_command(full_command)
if json_output:
print(json.dumps(result, indent=2))
else:
output = format_response(cmd, result)
print(output)
return 0
except ControlClientError as e:
print(f"Error: {e}", file=sys.stderr)
return 1
except KeyboardInterrupt:
return 130
def run_interactive(socket_path: str, json_output: bool = False) -> int:
"""
Run interactive CLI with readline support.
Args:
socket_path: Path to control socket
json_output: If True, output raw JSON
Returns:
Exit code
"""
try:
import readline # noqa: F401 - imported for side effects
has_readline = True
except ImportError:
has_readline = False
try:
client = ControlClient(socket_path)
client.connect()
except ControlClientError as e:
print(f"Error: {e}", file=sys.stderr)
return 1
print(f"Connected to {socket_path}")
print("Type 'help' for available commands, 'quit' to exit.")
print()
# Set up readline history
history_file = os.path.expanduser("~/.gunicornc_history")
if has_readline:
try:
readline.read_history_file(history_file)
except FileNotFoundError:
pass
exit_code = 0
try:
while True:
try:
line = input("gunicorn> ").strip()
except EOFError:
print()
break
if not line:
continue
if line.lower() in ('quit', 'exit', 'q'):
break
try:
cmd, args = parse_command(line)
full_command = f"{cmd} {' '.join(args)}".strip() if args else cmd
result = client.send_command(full_command)
if json_output:
print(json.dumps(result, indent=2))
else:
output = format_response(cmd, result)
print(output)
except ControlClientError as e:
print(f"Error: {e}")
# Try to reconnect
try:
client.close()
client.connect()
except ControlClientError:
print("Connection lost. Exiting.")
exit_code = 1
break
print()
except KeyboardInterrupt:
print()
exit_code = 130
finally:
client.close()
if has_readline:
try:
readline.write_history_file(history_file)
except Exception:
pass
return exit_code
def main():
"""Main entry point for gunicornc CLI."""
parser = argparse.ArgumentParser(
description='Gunicorn control interface',
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
gunicornc # Interactive mode (default socket)
gunicornc -s /tmp/myapp.ctl # Interactive mode with custom socket
gunicornc -c "show workers" # Single command mode
gunicornc -c "worker add 2" # Add 2 workers
gunicornc -c "show stats" -j # Output stats as JSON
"""
)
parser.add_argument(
'-s', '--socket',
default='gunicorn.ctl',
help='Control socket path (default: gunicorn.ctl in current directory)'
)
parser.add_argument(
'-c', '--command',
help='Execute single command and exit'
)
parser.add_argument(
'-j', '--json',
action='store_true',
help='Output raw JSON (for scripting)'
)
parser.add_argument(
'-v', '--version',
action='store_true',
help='Show version and exit'
)
args = parser.parse_args()
if args.version:
from gunicorn import __version__
print(f"gunicornc (gunicorn {__version__})")
return 0
socket_path = args.socket
# Make relative paths absolute from cwd
if not os.path.isabs(socket_path):
socket_path = os.path.join(os.getcwd(), socket_path)
if args.command:
return run_command(socket_path, args.command, args.json)
else:
return run_interactive(socket_path, args.json)
if __name__ == '__main__':
sys.exit(main())

140
gunicorn/ctl/client.py Normal file
View File

@ -0,0 +1,140 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
Control Socket Client
Client library for connecting to gunicorn control socket.
"""
import shlex
import socket
from gunicorn.ctl.protocol import (
ControlProtocol,
make_request,
)
class ControlClientError(Exception):
"""Control client error."""
pass
class ControlClient:
"""
Client for connecting to gunicorn control socket.
Can be used as a context manager:
with ControlClient('/path/to/gunicorn.ctl') as client:
result = client.send_command('show workers')
"""
def __init__(self, socket_path: str, timeout: float = 30.0):
"""
Initialize control client.
Args:
socket_path: Path to the Unix socket
timeout: Socket timeout in seconds (default 30)
"""
self.socket_path = socket_path
self.timeout = timeout
self._sock = None
self._request_id = 0
def connect(self):
"""
Connect to control socket.
Raises:
ControlClientError: If connection fails
"""
if self._sock:
return
try:
self._sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self._sock.settimeout(self.timeout)
self._sock.connect(self.socket_path)
except socket.error as e:
self._sock = None
raise ControlClientError(f"Failed to connect to {self.socket_path}: {e}")
def close(self):
"""Close connection."""
if self._sock:
try:
self._sock.close()
except Exception:
pass
self._sock = None
def send_command(self, command: str, args: list = None) -> dict:
"""
Send command and wait for response.
Args:
command: Command string (e.g., "show workers")
args: Optional additional arguments
Returns:
Response data dictionary
Raises:
ControlClientError: If communication fails
"""
if not self._sock:
self.connect()
self._request_id += 1
request = make_request(self._request_id, command, args)
try:
ControlProtocol.write_message(self._sock, request)
response = ControlProtocol.read_message(self._sock)
except Exception as e:
self.close()
raise ControlClientError(f"Communication error: {e}")
if response.get("status") == "error":
raise ControlClientError(response.get("error", "Unknown error"))
return response.get("data", {})
def __enter__(self):
self.connect()
return self
def __exit__(self, *args):
self.close()
def parse_command(line: str) -> tuple:
"""
Parse a command line into command and args.
Args:
line: Command line string
Returns:
Tuple of (command_string, args_list)
"""
parts = shlex.split(line)
if not parts:
return "", []
# Find where numeric/value args start
command_parts = []
args = []
for part in parts:
# If we haven't hit args yet and this looks like a command word
if not args and not part.isdigit() and not part.startswith('-'):
command_parts.append(part)
else:
args.append(part)
return " ".join(command_parts), args

431
gunicorn/ctl/handlers.py Normal file
View File

@ -0,0 +1,431 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
Control Interface Command Handlers
Provides handlers for all control commands with access to arbiter state.
"""
import os
import signal
import time
class CommandHandlers:
"""
Command handlers with access to arbiter state.
All handler methods return dictionaries that will be sent
as the response data.
"""
def __init__(self, arbiter):
"""
Initialize handlers with arbiter reference.
Args:
arbiter: The Gunicorn arbiter instance
"""
self.arbiter = arbiter
def show_workers(self) -> dict:
"""
Return list of HTTP workers.
Returns:
Dictionary with workers list containing:
- pid: Worker process ID
- age: Worker age (spawn order)
- requests: Number of requests handled (if available)
- booted: Whether worker has finished booting
- last_heartbeat: Seconds since last heartbeat
"""
workers = []
now = time.monotonic()
for pid, worker in self.arbiter.WORKERS.items():
try:
last_update = worker.tmp.last_update()
last_heartbeat = round(now - last_update, 2)
except (OSError, ValueError):
last_heartbeat = None
workers.append({
"pid": pid,
"age": worker.age,
"booted": worker.booted,
"aborted": worker.aborted,
"last_heartbeat": last_heartbeat,
})
# Sort by age (oldest first)
workers.sort(key=lambda w: w["age"])
return {"workers": workers, "count": len(workers)}
def show_dirty(self) -> dict:
"""
Return dirty workers and apps information.
Returns:
Dictionary with:
- enabled: Whether dirty arbiter is running
- pid: Dirty arbiter PID
- workers: List of dirty worker info
- apps: List of dirty app specs
"""
if not self.arbiter.dirty_arbiter_pid:
return {
"enabled": False,
"pid": None,
"workers": [],
"apps": [],
}
# Get dirty arbiter reference if available
dirty_arbiter = getattr(self.arbiter, 'dirty_arbiter', None)
workers = []
apps = []
if dirty_arbiter and hasattr(dirty_arbiter, 'workers'):
now = time.monotonic()
for pid, worker in dirty_arbiter.workers.items():
try:
last_update = worker.tmp.last_update()
last_heartbeat = round(now - last_update, 2)
except (OSError, ValueError, AttributeError):
last_heartbeat = None
workers.append({
"pid": pid,
"age": worker.age,
"apps": getattr(worker, 'app_paths', []),
"booted": getattr(worker, 'booted', False),
"last_heartbeat": last_heartbeat,
})
# Get app specs
if hasattr(dirty_arbiter, 'app_specs'):
for path, spec in dirty_arbiter.app_specs.items():
worker_pids = list(dirty_arbiter.app_worker_map.get(path, []))
apps.append({
"import_path": path,
"worker_count": spec.get('worker_count'),
"current_workers": len(worker_pids),
"worker_pids": worker_pids,
})
return {
"enabled": True,
"pid": self.arbiter.dirty_arbiter_pid,
"workers": workers,
"apps": apps,
}
def show_config(self) -> dict:
"""
Return current effective configuration.
Returns:
Dictionary of configuration values
"""
cfg = self.arbiter.cfg
config = {}
# Get commonly needed config values
config_keys = [
'bind', 'workers', 'worker_class', 'threads', 'timeout',
'graceful_timeout', 'keepalive', 'max_requests',
'max_requests_jitter', 'worker_connections', 'preload_app',
'daemon', 'pidfile', 'proc_name', 'reload',
'dirty_workers', 'dirty_apps', 'dirty_timeout',
'control_socket', 'control_socket_disable',
]
for key in config_keys:
try:
value = getattr(cfg, key)
# Convert non-serializable types
if callable(value):
value = str(value)
elif hasattr(value, '__class__') and not isinstance(
value, (str, int, float, bool, list, dict, type(None))):
value = str(value)
config[key] = value
except AttributeError:
pass
return config
def show_stats(self) -> dict:
"""
Return server statistics.
Returns:
Dictionary with:
- uptime: Seconds since arbiter started
- pid: Arbiter PID
- workers_current: Current number of workers
- workers_spawned: Total workers spawned
- workers_killed: Total workers killed (if tracked)
- reloads: Number of reloads (if tracked)
"""
stats = getattr(self.arbiter, '_stats', {})
start_time = stats.get('start_time')
uptime = None
if start_time:
uptime = round(time.time() - start_time, 2)
return {
"uptime": uptime,
"pid": self.arbiter.pid,
"workers_current": len(self.arbiter.WORKERS),
"workers_target": self.arbiter.num_workers,
"workers_spawned": stats.get('workers_spawned', 0),
"workers_killed": stats.get('workers_killed', 0),
"reloads": stats.get('reloads', 0),
"dirty_arbiter_pid": self.arbiter.dirty_arbiter_pid or None,
}
def show_listeners(self) -> dict:
"""
Return bound socket information.
Returns:
Dictionary with listeners list
"""
listeners = []
for lnr in self.arbiter.LISTENERS:
addr = str(lnr)
listener_info = {
"address": addr,
"fd": lnr.fileno(),
}
# Try to get socket family
try:
import socket
sock = lnr.sock
if sock.family == socket.AF_UNIX:
listener_info["type"] = "unix"
elif sock.family == socket.AF_INET:
listener_info["type"] = "tcp"
elif sock.family == socket.AF_INET6:
listener_info["type"] = "tcp6"
except Exception:
listener_info["type"] = "unknown"
listeners.append(listener_info)
return {"listeners": listeners, "count": len(listeners)}
def worker_add(self, count: int = 1) -> dict:
"""
Increase worker count.
Args:
count: Number of workers to add (default 1)
Returns:
Dictionary with added count and new total
"""
count = max(1, int(count))
old_count = self.arbiter.num_workers
self.arbiter.num_workers += count
# Wake up the arbiter to spawn workers
self.arbiter.wakeup()
return {
"added": count,
"previous": old_count,
"total": self.arbiter.num_workers,
}
def worker_remove(self, count: int = 1) -> dict:
"""
Decrease worker count.
Args:
count: Number of workers to remove (default 1)
Returns:
Dictionary with removed count and new total
"""
count = max(1, int(count))
old_count = self.arbiter.num_workers
# Don't go below 1 worker
new_count = max(1, old_count - count)
actual_removed = old_count - new_count
self.arbiter.num_workers = new_count
# Wake up the arbiter to kill excess workers
self.arbiter.wakeup()
return {
"removed": actual_removed,
"previous": old_count,
"total": new_count,
}
def worker_kill(self, pid: int) -> dict:
"""
Gracefully terminate a specific worker.
Args:
pid: Worker process ID
Returns:
Dictionary with killed PID or error
"""
pid = int(pid)
if pid not in self.arbiter.WORKERS:
return {
"success": False,
"error": f"Worker {pid} not found",
}
try:
os.kill(pid, signal.SIGTERM)
return {
"success": True,
"killed": pid,
}
except OSError as e:
return {
"success": False,
"error": str(e),
}
def dirty_add(self, count: int = 1) -> dict:
"""
Spawn additional dirty workers.
Args:
count: Number of dirty workers to add (default 1)
Returns:
Dictionary with added count or error
"""
if not self.arbiter.dirty_arbiter_pid:
return {
"success": False,
"error": "Dirty arbiter not running",
}
# Send TTIN signals to dirty arbiter
count = max(1, int(count))
try:
for _ in range(count):
os.kill(self.arbiter.dirty_arbiter_pid, signal.SIGTTIN)
return {
"success": True,
"added": count,
}
except OSError as e:
return {
"success": False,
"error": str(e),
}
def dirty_remove(self, count: int = 1) -> dict:
"""
Remove dirty workers.
Args:
count: Number of dirty workers to remove (default 1)
Returns:
Dictionary with removed count or error
"""
if not self.arbiter.dirty_arbiter_pid:
return {
"success": False,
"error": "Dirty arbiter not running",
}
# Send TTOU signals to dirty arbiter
count = max(1, int(count))
try:
for _ in range(count):
os.kill(self.arbiter.dirty_arbiter_pid, signal.SIGTTOU)
return {
"success": True,
"removed": count,
}
except OSError as e:
return {
"success": False,
"error": str(e),
}
def reload(self) -> dict:
"""
Trigger graceful reload (equivalent to SIGHUP).
Returns:
Dictionary with status
"""
# Send HUP to self to trigger reload
os.kill(self.arbiter.pid, signal.SIGHUP)
return {"status": "reloading"}
def reopen(self) -> dict:
"""
Reopen log files (equivalent to SIGUSR1).
Returns:
Dictionary with status
"""
os.kill(self.arbiter.pid, signal.SIGUSR1)
return {"status": "reopening"}
def shutdown(self, mode: str = "graceful") -> dict:
"""
Initiate shutdown.
Args:
mode: "graceful" (SIGTERM) or "quick" (SIGINT)
Returns:
Dictionary with status
"""
if mode == "quick":
os.kill(self.arbiter.pid, signal.SIGINT)
else:
os.kill(self.arbiter.pid, signal.SIGTERM)
return {"status": "shutting_down", "mode": mode}
def help(self) -> dict:
"""
Return list of available commands.
Returns:
Dictionary with commands and descriptions
"""
commands = {
"show workers": "List HTTP workers with their status",
"show dirty": "List dirty workers and apps",
"show config": "Show current effective configuration",
"show stats": "Show server statistics",
"show listeners": "Show bound sockets",
"worker add [N]": "Spawn N workers (default 1)",
"worker remove [N]": "Remove N workers (default 1)",
"worker kill <PID>": "Gracefully terminate specific worker",
"dirty add [N]": "Spawn N dirty workers (default 1)",
"dirty remove [N]": "Remove N dirty workers (default 1)",
"reload": "Graceful reload (HUP)",
"reopen": "Reopen log files (USR1)",
"shutdown [graceful|quick]": "Shutdown server (TERM/INT)",
"help": "Show this help message",
}
return {"commands": commands}

225
gunicorn/ctl/protocol.py Normal file
View File

@ -0,0 +1,225 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
Control Socket Protocol
JSON-based protocol with length-prefixed framing for the control interface.
Message Format:
+----------------+------------------+
| Length (4B BE) | JSON Payload |
+----------------+------------------+
Request Format:
{"id": 1, "command": "show", "args": ["workers"]}
Response Format:
{"id": 1, "status": "ok", "data": {...}}
{"id": 1, "status": "error", "error": "message"}
"""
import json
import struct
class ProtocolError(Exception):
"""Protocol-level error."""
pass
class ControlProtocol:
"""
Protocol implementation for control socket communication.
Uses 4-byte big-endian length prefix followed by JSON payload.
"""
# Maximum message size (16 MB)
MAX_MESSAGE_SIZE = 16 * 1024 * 1024
@staticmethod
def encode_message(data: dict) -> bytes:
"""
Encode a message for transmission.
Args:
data: Dictionary to encode
Returns:
Length-prefixed JSON bytes
"""
payload = json.dumps(data).encode('utf-8')
length = struct.pack('>I', len(payload))
return length + payload
@staticmethod
def decode_message(data: bytes) -> dict:
"""
Decode a message from bytes.
Args:
data: Raw bytes (length prefix + JSON payload)
Returns:
Decoded dictionary
"""
if len(data) < 4:
raise ProtocolError("Message too short")
length = struct.unpack('>I', data[:4])[0]
if len(data) < 4 + length:
raise ProtocolError("Incomplete message")
payload = data[4:4 + length]
return json.loads(payload.decode('utf-8'))
@staticmethod
def read_message(sock) -> dict:
"""
Read one message from a socket.
Args:
sock: Socket to read from
Returns:
Decoded message dictionary
Raises:
ProtocolError: If message is malformed
ConnectionError: If connection is closed
"""
# Read length prefix
length_data = b''
while len(length_data) < 4:
chunk = sock.recv(4 - len(length_data))
if not chunk:
if not length_data:
raise ConnectionError("Connection closed")
raise ProtocolError("Incomplete length prefix")
length_data += chunk
length = struct.unpack('>I', length_data)[0]
if length > ControlProtocol.MAX_MESSAGE_SIZE:
raise ProtocolError(f"Message too large: {length}")
# Read payload
payload_data = b''
while len(payload_data) < length:
chunk = sock.recv(min(length - len(payload_data), 65536))
if not chunk:
raise ProtocolError("Incomplete payload")
payload_data += chunk
try:
return json.loads(payload_data.decode('utf-8'))
except json.JSONDecodeError as e:
raise ProtocolError(f"Invalid JSON: {e}")
@staticmethod
def write_message(sock, data: dict):
"""
Write one message to a socket.
Args:
sock: Socket to write to
data: Message dictionary to send
"""
message = ControlProtocol.encode_message(data)
sock.sendall(message)
@staticmethod
async def read_message_async(reader) -> dict:
"""
Read one message from an async reader.
Args:
reader: asyncio StreamReader
Returns:
Decoded message dictionary
"""
# Read length prefix
length_data = await reader.readexactly(4)
length = struct.unpack('>I', length_data)[0]
if length > ControlProtocol.MAX_MESSAGE_SIZE:
raise ProtocolError(f"Message too large: {length}")
# Read payload
payload_data = await reader.readexactly(length)
try:
return json.loads(payload_data.decode('utf-8'))
except json.JSONDecodeError as e:
raise ProtocolError(f"Invalid JSON: {e}")
@staticmethod
async def write_message_async(writer, data: dict):
"""
Write one message to an async writer.
Args:
writer: asyncio StreamWriter
data: Message dictionary to send
"""
message = ControlProtocol.encode_message(data)
writer.write(message)
await writer.drain()
def make_request(request_id: int, command: str, args: list = None) -> dict:
"""
Create a request message.
Args:
request_id: Unique request identifier
command: Command name (e.g., "show workers")
args: Optional list of arguments
Returns:
Request dictionary
"""
return {
"id": request_id,
"command": command,
"args": args or [],
}
def make_response(request_id: int, data: dict = None) -> dict:
"""
Create a success response message.
Args:
request_id: Request identifier being responded to
data: Response data
Returns:
Response dictionary
"""
return {
"id": request_id,
"status": "ok",
"data": data or {},
}
def make_error_response(request_id: int, error: str) -> dict:
"""
Create an error response message.
Args:
request_id: Request identifier being responded to
error: Error message
Returns:
Error response dictionary
"""
return {
"id": request_id,
"status": "error",
"error": error,
}

299
gunicorn/ctl/server.py Normal file
View File

@ -0,0 +1,299 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
Control Socket Server
Runs in the arbiter process and accepts commands via Unix socket.
Uses asyncio in a background thread to handle client connections.
"""
import asyncio
import os
import shlex
import threading
from gunicorn.ctl.handlers import CommandHandlers
from gunicorn.ctl.protocol import (
ControlProtocol,
make_response,
make_error_response,
)
class ControlSocketServer:
"""
Control socket server running in arbiter process.
The server runs an asyncio event loop in a background thread,
accepting connections and dispatching commands to handlers.
"""
def __init__(self, arbiter, socket_path, socket_mode=0o600):
"""
Initialize control socket server.
Args:
arbiter: The Gunicorn arbiter instance
socket_path: Path for the Unix socket
socket_mode: Permission mode for socket (default 0o600)
"""
self.arbiter = arbiter
self.socket_path = socket_path
self.socket_mode = socket_mode
self.handlers = CommandHandlers(arbiter)
self._server = None
self._loop = None
self._thread = None
self._running = False
def start(self):
"""Start server in background thread with asyncio event loop."""
if self._running:
return
self._running = True
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread.start()
def stop(self):
"""Stop server and cleanup socket."""
if not self._running:
return
self._running = False
if self._loop and self._server:
# Schedule server close in the loop
self._loop.call_soon_threadsafe(self._shutdown)
if self._thread:
self._thread.join(timeout=2.0)
self._thread = None
# Clean up socket file
if os.path.exists(self.socket_path):
try:
os.unlink(self.socket_path)
except OSError:
pass
def _shutdown(self):
"""Shutdown server (called from event loop thread)."""
if self._server:
self._server.close()
def _run_loop(self):
"""Run the asyncio event loop in background thread."""
try:
asyncio.run(self._serve())
except Exception as e:
if self.arbiter.log:
self.arbiter.log.error("Control server error: %s", e)
async def _serve(self):
"""Main async server loop."""
self._loop = asyncio.get_running_loop()
# Remove socket if it exists
if os.path.exists(self.socket_path):
os.unlink(self.socket_path)
# Create Unix socket server
self._server = await asyncio.start_unix_server(
self._handle_client,
path=self.socket_path
)
# Set socket permissions
os.chmod(self.socket_path, self.socket_mode)
if self.arbiter.log:
self.arbiter.log.info("Control socket listening at %s",
self.socket_path)
try:
async with self._server:
await self._server.serve_forever()
except asyncio.CancelledError:
pass
finally:
if os.path.exists(self.socket_path):
try:
os.unlink(self.socket_path)
except OSError:
pass
async def _handle_client(self, reader, writer):
"""
Handle client connection.
Args:
reader: asyncio StreamReader
writer: asyncio StreamWriter
"""
try:
while self._running:
try:
message = await asyncio.wait_for(
ControlProtocol.read_message_async(reader),
timeout=300.0 # 5 minute idle timeout
)
except asyncio.TimeoutError:
# Client idle too long, close connection
break
except asyncio.IncompleteReadError:
# Client disconnected
break
except Exception:
# Protocol error
break
# Process command
response = await self._dispatch(message)
# Send response
await ControlProtocol.write_message_async(writer, response)
except Exception as e:
if self.arbiter.log:
self.arbiter.log.debug("Control client error: %s", e)
finally:
writer.close()
try:
await writer.wait_closed()
except Exception:
pass
async def _dispatch(self, message: dict) -> dict:
"""
Dispatch command to appropriate handler.
Args:
message: Request message dict
Returns:
Response dictionary
"""
request_id = message.get("id", 0)
command = message.get("command", "").strip()
args = message.get("args", [])
if not command:
return make_error_response(request_id, "Empty command")
try:
# Parse command (e.g., "show workers" or "worker add 2")
parts = shlex.split(command)
if args:
parts.extend(str(a) for a in args)
if not parts:
return make_error_response(request_id, "Empty command")
# Route to handler
result = self._execute_command(parts)
return make_response(request_id, result)
except ValueError as e:
return make_error_response(request_id, f"Invalid argument: {e}")
except Exception as e:
if self.arbiter.log:
self.arbiter.log.exception("Command error")
return make_error_response(request_id, f"Command failed: {e}")
def _execute_command(self, parts: list) -> dict:
"""
Execute a parsed command.
Args:
parts: Command parts (e.g., ["show", "workers"])
Returns:
Handler result dictionary
"""
if not parts:
raise ValueError("Empty command")
cmd = parts[0].lower()
rest = parts[1:]
# Map commands to handlers
if cmd == "show":
return self._handle_show(rest)
elif cmd == "worker":
return self._handle_worker(rest)
elif cmd == "dirty":
return self._handle_dirty(rest)
elif cmd == "reload":
return self.handlers.reload()
elif cmd == "reopen":
return self.handlers.reopen()
elif cmd == "shutdown":
mode = rest[0] if rest else "graceful"
return self.handlers.shutdown(mode)
elif cmd == "help":
return self.handlers.help()
else:
raise ValueError(f"Unknown command: {cmd}")
def _handle_show(self, args: list) -> dict:
"""Handle 'show' commands."""
if not args:
raise ValueError("Missing show target (workers|dirty|config|stats|listeners)")
target = args[0].lower()
if target == "workers":
return self.handlers.show_workers()
elif target == "dirty":
return self.handlers.show_dirty()
elif target == "config":
return self.handlers.show_config()
elif target == "stats":
return self.handlers.show_stats()
elif target == "listeners":
return self.handlers.show_listeners()
else:
raise ValueError(f"Unknown show target: {target}")
def _handle_worker(self, args: list) -> dict:
"""Handle 'worker' commands."""
if not args:
raise ValueError("Missing worker action (add|remove|kill)")
action = args[0].lower()
action_args = args[1:]
if action == "add":
count = int(action_args[0]) if action_args else 1
return self.handlers.worker_add(count)
elif action == "remove":
count = int(action_args[0]) if action_args else 1
return self.handlers.worker_remove(count)
elif action == "kill":
if not action_args:
raise ValueError("Missing PID for worker kill")
pid = int(action_args[0])
return self.handlers.worker_kill(pid)
else:
raise ValueError(f"Unknown worker action: {action}")
def _handle_dirty(self, args: list) -> dict:
"""Handle 'dirty' commands."""
if not args:
raise ValueError("Missing dirty action (add|remove)")
action = args[0].lower()
action_args = args[1:]
if action == "add":
count = int(action_args[0]) if action_args else 1
return self.handlers.dirty_add(count)
elif action == "remove":
count = int(action_args[0]) if action_args else 1
return self.handlers.dirty_remove(count)
else:
raise ValueError(f"Unknown dirty action: {action}")

View File

@ -367,7 +367,8 @@ class DirtyArbiter:
try:
async with self._server:
await self._server.serve_forever()
except asyncio.CancelledError:
except (asyncio.CancelledError, RuntimeError):
# RuntimeError raised when server.close() is called during serve_forever()
pass
finally:
monitor_task.cancel()
@ -836,19 +837,19 @@ class DirtyArbiter:
pid, app_paths)
return pid
# Child process
# Child process - use os._exit() to avoid asyncio cleanup issues
worker.pid = os.getpid()
try:
util._setproctitle(f"dirty-worker [{self.cfg.proc_name}]")
worker.init_process()
sys.exit(0)
except SystemExit:
raise
os._exit(0)
except SystemExit as e:
os._exit(e.code if e.code is not None else 0)
except Exception:
self.log.exception("Exception in dirty worker process")
if not worker.booted:
sys.exit(self.WORKER_BOOT_ERROR)
sys.exit(-1)
os._exit(self.WORKER_BOOT_ERROR)
os._exit(1)
def kill_worker(self, pid, sig):
"""Kill a worker by PID."""

View File

@ -68,6 +68,7 @@ testing = [
[project.scripts]
# duplicates "python -m gunicorn" handling in __main__.py
gunicorn = "gunicorn.app.wsgiapp:run"
gunicornc = "gunicorn.ctl.cli:main"
# note the quotes around "paste.server_runner" to escape the dot
[project.entry-points."paste.server_runner"]

3
tests/ctl/__init__.py Normal file
View File

@ -0,0 +1,3 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.

275
tests/ctl/test_client.py Normal file
View File

@ -0,0 +1,275 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""Tests for control socket client."""
import os
import socket
import tempfile
import threading
import pytest
from gunicorn.ctl.client import (
ControlClient,
ControlClientError,
parse_command,
)
from gunicorn.ctl.protocol import ControlProtocol, make_response
class TestControlClientInit:
"""Tests for ControlClient initialization."""
def test_init_attributes(self):
"""Test that client is initialized with correct attributes."""
client = ControlClient("/tmp/test.sock", timeout=60.0)
assert client.socket_path == "/tmp/test.sock"
assert client.timeout == 60.0
assert client._sock is None
assert client._request_id == 0
class TestControlClientConnect:
"""Tests for ControlClient connection."""
def test_connect_nonexistent_socket(self):
"""Test connecting to non-existent socket."""
client = ControlClient("/nonexistent/socket.sock")
with pytest.raises(ControlClientError) as exc_info:
client.connect()
assert "Failed to connect" in str(exc_info.value)
def test_connect_success(self):
"""Test successful connection."""
with tempfile.TemporaryDirectory() as tmpdir:
socket_path = os.path.join(tmpdir, "test.sock")
# Create a listening socket
server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
server_sock.bind(socket_path)
server_sock.listen(1)
try:
client = ControlClient(socket_path)
client.connect()
assert client._sock is not None
client.close()
finally:
server_sock.close()
def test_connect_already_connected(self):
"""Test that connect is idempotent."""
with tempfile.TemporaryDirectory() as tmpdir:
socket_path = os.path.join(tmpdir, "test.sock")
server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
server_sock.bind(socket_path)
server_sock.listen(1)
try:
client = ControlClient(socket_path)
client.connect()
first_sock = client._sock
client.connect() # Should not create new connection
assert client._sock is first_sock
client.close()
finally:
server_sock.close()
class TestControlClientClose:
"""Tests for ControlClient close."""
def test_close_idempotent(self):
"""Test that close can be called multiple times."""
client = ControlClient("/tmp/test.sock")
client.close()
client.close() # Should not raise
def test_close_clears_socket(self):
"""Test that close clears the socket."""
client = ControlClient("/tmp/test.sock")
client._sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
client.close()
assert client._sock is None
class TestControlClientContextManager:
"""Tests for context manager functionality."""
def test_context_manager_connection_error(self):
"""Test context manager with connection error."""
client = ControlClient("/nonexistent/socket.sock")
with pytest.raises(ControlClientError):
with client:
pass
def test_context_manager_success(self):
"""Test successful context manager usage."""
with tempfile.TemporaryDirectory() as tmpdir:
socket_path = os.path.join(tmpdir, "test.sock")
server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
server_sock.bind(socket_path)
server_sock.listen(1)
try:
with ControlClient(socket_path) as client:
assert client._sock is not None
# After context manager exits, socket should be closed
assert client._sock is None
finally:
server_sock.close()
class TestControlClientSendCommand:
"""Tests for send_command functionality."""
def test_send_command_success(self):
"""Test successful command send."""
with tempfile.TemporaryDirectory() as tmpdir:
socket_path = os.path.join(tmpdir, "test.sock")
server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
server_sock.bind(socket_path)
server_sock.listen(1)
response_data = {"workers": [], "count": 0}
response_sent = threading.Event()
def server_handler():
conn, _ = server_sock.accept()
try:
msg = ControlProtocol.read_message(conn)
resp = make_response(msg["id"], response_data)
ControlProtocol.write_message(conn, resp)
response_sent.set()
finally:
conn.close()
server_thread = threading.Thread(target=server_handler)
server_thread.start()
try:
client = ControlClient(socket_path, timeout=5.0)
result = client.send_command("show workers")
assert result == response_data
client.close()
finally:
response_sent.wait(timeout=2.0)
server_thread.join(timeout=2.0)
server_sock.close()
def test_send_command_error_response(self):
"""Test handling error response."""
with tempfile.TemporaryDirectory() as tmpdir:
socket_path = os.path.join(tmpdir, "test.sock")
server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
server_sock.bind(socket_path)
server_sock.listen(1)
def server_handler():
conn, _ = server_sock.accept()
try:
msg = ControlProtocol.read_message(conn)
resp = {
"id": msg["id"],
"status": "error",
"error": "Unknown command",
}
ControlProtocol.write_message(conn, resp)
finally:
conn.close()
server_thread = threading.Thread(target=server_handler)
server_thread.start()
try:
client = ControlClient(socket_path, timeout=5.0)
with pytest.raises(ControlClientError) as exc_info:
client.send_command("invalid command")
assert "Unknown command" in str(exc_info.value)
client.close()
finally:
server_thread.join(timeout=2.0)
server_sock.close()
def test_send_command_auto_connect(self):
"""Test that send_command auto-connects if not connected."""
with tempfile.TemporaryDirectory() as tmpdir:
socket_path = os.path.join(tmpdir, "test.sock")
server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
server_sock.bind(socket_path)
server_sock.listen(1)
def server_handler():
conn, _ = server_sock.accept()
try:
msg = ControlProtocol.read_message(conn)
resp = make_response(msg["id"], {})
ControlProtocol.write_message(conn, resp)
finally:
conn.close()
server_thread = threading.Thread(target=server_handler)
server_thread.start()
try:
client = ControlClient(socket_path, timeout=5.0)
# Don't call connect() explicitly
result = client.send_command("help")
assert isinstance(result, dict)
client.close()
finally:
server_thread.join(timeout=2.0)
server_sock.close()
class TestParseCommand:
"""Tests for command parsing."""
def test_parse_simple_command(self):
"""Test parsing simple command."""
cmd, args = parse_command("show workers")
assert cmd == "show workers"
assert args == []
def test_parse_command_with_args(self):
"""Test parsing command with arguments."""
cmd, args = parse_command("worker add 2")
assert cmd == "worker add"
assert args == ["2"]
def test_parse_command_with_multiple_args(self):
"""Test parsing command with multiple arguments."""
cmd, args = parse_command("worker kill 12345")
assert cmd == "worker kill"
assert args == ["12345"]
def test_parse_empty_command(self):
"""Test parsing empty command."""
cmd, args = parse_command("")
assert cmd == ""
assert args == []
def test_parse_command_quoted(self):
"""Test parsing command with quoted arguments."""
cmd, args = parse_command('worker kill "12345"')
assert cmd == "worker kill"
assert args == ["12345"]

374
tests/ctl/test_handlers.py Normal file
View File

@ -0,0 +1,374 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""Tests for control socket command handlers."""
import signal
import time
from unittest.mock import MagicMock, patch
import pytest
from gunicorn.ctl.handlers import CommandHandlers
class MockWorker:
"""Mock worker for testing."""
def __init__(self, pid, age, booted=True, aborted=False):
self.pid = pid
self.age = age
self.booted = booted
self.aborted = aborted
self.tmp = MagicMock()
self.tmp.last_update.return_value = time.monotonic()
class MockListener:
"""Mock listener for testing."""
def __init__(self, address, fd=3):
self._address = address
self._fd = fd
self.sock = MagicMock()
self.sock.family = 2 # AF_INET
def __str__(self):
return self._address
def fileno(self):
return self._fd
class MockConfig:
"""Mock config for testing."""
def __init__(self):
self.bind = ['127.0.0.1:8000']
self.workers = 4
self.worker_class = 'sync'
self.threads = 1
self.timeout = 30
self.graceful_timeout = 30
self.keepalive = 2
self.max_requests = 0
self.max_requests_jitter = 0
self.worker_connections = 1000
self.preload_app = False
self.daemon = False
self.pidfile = None
self.proc_name = 'test_app'
self.reload = False
self.dirty_workers = 0
self.dirty_apps = []
self.dirty_timeout = 30
self.control_socket = 'gunicorn.ctl'
self.control_socket_disable = False
class MockArbiter:
"""Mock arbiter for testing."""
def __init__(self):
self.cfg = MockConfig()
self.pid = 12345
self.WORKERS = {}
self.LISTENERS = []
self.dirty_arbiter_pid = 0
self.dirty_arbiter = None
self.num_workers = 4
self._stats = {
'start_time': time.time() - 3600, # 1 hour ago
'workers_spawned': 10,
'workers_killed': 5,
'reloads': 2,
}
def wakeup(self):
pass
class TestShowWorkers:
"""Tests for show workers command."""
def test_show_workers_empty(self):
"""Test showing workers when none exist."""
arbiter = MockArbiter()
handlers = CommandHandlers(arbiter)
result = handlers.show_workers()
assert result["workers"] == []
assert result["count"] == 0
def test_show_workers_with_workers(self):
"""Test showing workers."""
arbiter = MockArbiter()
arbiter.WORKERS = {
1001: MockWorker(1001, 1),
1002: MockWorker(1002, 2),
1003: MockWorker(1003, 3),
}
handlers = CommandHandlers(arbiter)
result = handlers.show_workers()
assert result["count"] == 3
assert len(result["workers"]) == 3
# Verify sorted by age
ages = [w["age"] for w in result["workers"]]
assert ages == sorted(ages)
# Verify worker data
worker = result["workers"][0]
assert "pid" in worker
assert "age" in worker
assert "booted" in worker
assert "last_heartbeat" in worker
class TestShowStats:
"""Tests for show stats command."""
def test_show_stats(self):
"""Test showing stats."""
arbiter = MockArbiter()
arbiter.WORKERS = {
1001: MockWorker(1001, 1),
1002: MockWorker(1002, 2),
}
handlers = CommandHandlers(arbiter)
result = handlers.show_stats()
assert result["pid"] == 12345
assert result["workers_current"] == 2
assert result["workers_target"] == 4
assert result["workers_spawned"] == 10
assert result["workers_killed"] == 5
assert result["reloads"] == 2
assert result["uptime"] is not None
assert result["uptime"] > 0
class TestShowConfig:
"""Tests for show config command."""
def test_show_config(self):
"""Test showing config."""
arbiter = MockArbiter()
handlers = CommandHandlers(arbiter)
result = handlers.show_config()
assert result["workers"] == 4
assert result["timeout"] == 30
assert result["bind"] == ['127.0.0.1:8000']
class TestShowListeners:
"""Tests for show listeners command."""
def test_show_listeners_empty(self):
"""Test showing listeners when none exist."""
arbiter = MockArbiter()
handlers = CommandHandlers(arbiter)
result = handlers.show_listeners()
assert result["listeners"] == []
assert result["count"] == 0
def test_show_listeners(self):
"""Test showing listeners."""
arbiter = MockArbiter()
arbiter.LISTENERS = [
MockListener("127.0.0.1:8000", fd=3),
MockListener("127.0.0.1:8001", fd=4),
]
handlers = CommandHandlers(arbiter)
result = handlers.show_listeners()
assert result["count"] == 2
assert len(result["listeners"]) == 2
assert result["listeners"][0]["address"] == "127.0.0.1:8000"
class TestWorkerAdd:
"""Tests for worker add command."""
def test_worker_add_default(self):
"""Test adding one worker (default)."""
arbiter = MockArbiter()
arbiter.wakeup = MagicMock()
handlers = CommandHandlers(arbiter)
result = handlers.worker_add()
assert result["added"] == 1
assert result["previous"] == 4
assert result["total"] == 5
assert arbiter.num_workers == 5
arbiter.wakeup.assert_called_once()
def test_worker_add_multiple(self):
"""Test adding multiple workers."""
arbiter = MockArbiter()
arbiter.wakeup = MagicMock()
handlers = CommandHandlers(arbiter)
result = handlers.worker_add(3)
assert result["added"] == 3
assert result["total"] == 7
class TestWorkerRemove:
"""Tests for worker remove command."""
def test_worker_remove_default(self):
"""Test removing one worker (default)."""
arbiter = MockArbiter()
arbiter.wakeup = MagicMock()
handlers = CommandHandlers(arbiter)
result = handlers.worker_remove()
assert result["removed"] == 1
assert result["previous"] == 4
assert result["total"] == 3
assert arbiter.num_workers == 3
arbiter.wakeup.assert_called_once()
def test_worker_remove_cannot_go_below_one(self):
"""Test that worker count cannot go below 1."""
arbiter = MockArbiter()
arbiter.num_workers = 2
arbiter.wakeup = MagicMock()
handlers = CommandHandlers(arbiter)
result = handlers.worker_remove(5)
assert result["removed"] == 1
assert result["total"] == 1
assert arbiter.num_workers == 1
class TestWorkerKill:
"""Tests for worker kill command."""
def test_worker_kill_success(self):
"""Test killing a worker."""
arbiter = MockArbiter()
arbiter.WORKERS = {1001: MockWorker(1001, 1)}
handlers = CommandHandlers(arbiter)
with patch('os.kill') as mock_kill:
result = handlers.worker_kill(1001)
assert result["success"] is True
assert result["killed"] == 1001
mock_kill.assert_called_once_with(1001, signal.SIGTERM)
def test_worker_kill_not_found(self):
"""Test killing a non-existent worker."""
arbiter = MockArbiter()
handlers = CommandHandlers(arbiter)
result = handlers.worker_kill(9999)
assert result["success"] is False
assert "not found" in result["error"]
class TestShowDirty:
"""Tests for show dirty command."""
def test_show_dirty_disabled(self):
"""Test showing dirty when disabled."""
arbiter = MockArbiter()
handlers = CommandHandlers(arbiter)
result = handlers.show_dirty()
assert result["enabled"] is False
assert result["pid"] is None
class TestReload:
"""Tests for reload command."""
def test_reload(self):
"""Test reload command."""
arbiter = MockArbiter()
handlers = CommandHandlers(arbiter)
with patch('os.kill') as mock_kill:
result = handlers.reload()
assert result["status"] == "reloading"
mock_kill.assert_called_once_with(12345, signal.SIGHUP)
class TestReopen:
"""Tests for reopen command."""
def test_reopen(self):
"""Test reopen command."""
arbiter = MockArbiter()
handlers = CommandHandlers(arbiter)
with patch('os.kill') as mock_kill:
result = handlers.reopen()
assert result["status"] == "reopening"
mock_kill.assert_called_once_with(12345, signal.SIGUSR1)
class TestShutdown:
"""Tests for shutdown command."""
def test_shutdown_graceful(self):
"""Test graceful shutdown."""
arbiter = MockArbiter()
handlers = CommandHandlers(arbiter)
with patch('os.kill') as mock_kill:
result = handlers.shutdown()
assert result["status"] == "shutting_down"
assert result["mode"] == "graceful"
mock_kill.assert_called_once_with(12345, signal.SIGTERM)
def test_shutdown_quick(self):
"""Test quick shutdown."""
arbiter = MockArbiter()
handlers = CommandHandlers(arbiter)
with patch('os.kill') as mock_kill:
result = handlers.shutdown("quick")
assert result["status"] == "shutting_down"
assert result["mode"] == "quick"
mock_kill.assert_called_once_with(12345, signal.SIGINT)
class TestHelp:
"""Tests for help command."""
def test_help(self):
"""Test help command."""
arbiter = MockArbiter()
handlers = CommandHandlers(arbiter)
result = handlers.help()
assert "commands" in result
commands = result["commands"]
assert "show workers" in commands
assert "worker add [N]" in commands
assert "reload" in commands
assert "shutdown [graceful|quick]" in commands

249
tests/ctl/test_protocol.py Normal file
View File

@ -0,0 +1,249 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""Tests for control socket protocol."""
import json
import struct
import pytest
from gunicorn.ctl.protocol import (
ControlProtocol,
ProtocolError,
make_request,
make_response,
make_error_response,
)
class TestControlProtocolEncoding:
"""Tests for message encoding/decoding."""
def test_encode_message_simple(self):
"""Test encoding a simple message."""
data = {"command": "test"}
result = ControlProtocol.encode_message(data)
# First 4 bytes are length
length = struct.unpack('>I', result[:4])[0]
payload = result[4:]
assert length == len(payload)
assert json.loads(payload.decode('utf-8')) == data
def test_encode_message_unicode(self):
"""Test encoding message with unicode characters."""
data = {"message": "Hello \u4e16\u754c"}
result = ControlProtocol.encode_message(data)
length = struct.unpack('>I', result[:4])[0]
payload = result[4:]
assert length == len(payload)
assert json.loads(payload.decode('utf-8')) == data
def test_decode_message_simple(self):
"""Test decoding a simple message."""
data = {"command": "test", "args": [1, 2, 3]}
payload = json.dumps(data).encode('utf-8')
length = struct.pack('>I', len(payload))
raw = length + payload
result = ControlProtocol.decode_message(raw)
assert result == data
def test_decode_message_too_short(self):
"""Test decoding message that's too short."""
with pytest.raises(ProtocolError) as exc_info:
ControlProtocol.decode_message(b'\x00\x00')
assert "too short" in str(exc_info.value)
def test_decode_message_incomplete(self):
"""Test decoding incomplete message."""
# Length says 100 bytes but only 4 bytes provided
raw = struct.pack('>I', 100) + b'test'
with pytest.raises(ProtocolError) as exc_info:
ControlProtocol.decode_message(raw)
assert "Incomplete" in str(exc_info.value)
def test_roundtrip(self):
"""Test encode/decode roundtrip."""
original = {
"id": 42,
"command": "show workers",
"args": ["arg1", 123, True, None],
"nested": {"a": 1, "b": [1, 2, 3]},
}
encoded = ControlProtocol.encode_message(original)
decoded = ControlProtocol.decode_message(encoded)
assert decoded == original
class TestMakeRequest:
"""Tests for request creation."""
def test_make_request_simple(self):
"""Test creating a simple request."""
result = make_request(1, "show workers")
assert result["id"] == 1
assert result["command"] == "show workers"
assert result["args"] == []
def test_make_request_with_args(self):
"""Test creating a request with arguments."""
result = make_request(42, "worker add", [2])
assert result["id"] == 42
assert result["command"] == "worker add"
assert result["args"] == [2]
class TestMakeResponse:
"""Tests for response creation."""
def test_make_response_simple(self):
"""Test creating a simple response."""
result = make_response(1, {"count": 5})
assert result["id"] == 1
assert result["status"] == "ok"
assert result["data"] == {"count": 5}
def test_make_response_empty_data(self):
"""Test creating response with no data."""
result = make_response(1)
assert result["id"] == 1
assert result["status"] == "ok"
assert result["data"] == {}
class TestMakeErrorResponse:
"""Tests for error response creation."""
def test_make_error_response(self):
"""Test creating an error response."""
result = make_error_response(1, "Unknown command")
assert result["id"] == 1
assert result["status"] == "error"
assert result["error"] == "Unknown command"
class TestControlProtocolSocket:
"""Tests for socket reading/writing."""
def test_read_write_message(self):
"""Test read/write through socket pair."""
import socket
import threading
data = {"id": 1, "command": "test"}
received = []
# Create socket pair
server, client = socket.socketpair()
def reader():
received.append(ControlProtocol.read_message(server))
t = threading.Thread(target=reader)
t.start()
ControlProtocol.write_message(client, data)
t.join(timeout=2.0)
client.close()
server.close()
assert len(received) == 1
assert received[0] == data
def test_read_connection_closed(self):
"""Test reading from closed connection."""
import socket
server, client = socket.socketpair()
client.close()
with pytest.raises(ConnectionError):
ControlProtocol.read_message(server)
server.close()
def test_read_message_too_large(self):
"""Test reading message exceeding max size."""
import socket
server, client = socket.socketpair()
# Send a length that exceeds MAX_MESSAGE_SIZE
huge_length = ControlProtocol.MAX_MESSAGE_SIZE + 1
client.send(struct.pack('>I', huge_length))
with pytest.raises(ProtocolError) as exc_info:
ControlProtocol.read_message(server)
assert "too large" in str(exc_info.value)
client.close()
server.close()
class TestControlProtocolAsync:
"""Tests for async protocol methods."""
@pytest.mark.asyncio
async def test_async_read_write(self):
"""Test async read/write using a unix server."""
import asyncio
import tempfile
import os
data = {"id": 1, "command": "async test"}
received = []
with tempfile.TemporaryDirectory() as tmpdir:
socket_path = os.path.join(tmpdir, "test.sock")
async def handler(reader, writer):
msg = await ControlProtocol.read_message_async(reader)
received.append(msg)
await ControlProtocol.write_message_async(writer, data)
writer.close()
await writer.wait_closed()
server = await asyncio.start_unix_server(handler, path=socket_path)
async with server:
reader, writer = await asyncio.open_unix_connection(socket_path)
await ControlProtocol.write_message_async(writer, data)
response = await ControlProtocol.read_message_async(reader)
writer.close()
await writer.wait_closed()
assert len(received) == 1
assert received[0] == data
assert response == data
class TestProtocolMaxSize:
"""Tests for protocol size limits."""
def test_max_message_size_constant(self):
"""Test that MAX_MESSAGE_SIZE is set to a reasonable value."""
# Should be 16 MB
assert ControlProtocol.MAX_MESSAGE_SIZE == 16 * 1024 * 1024
def test_encode_large_message(self):
"""Test encoding a large (but valid) message."""
# Create a message with ~1MB of data
data = {"data": "x" * (1024 * 1024)}
encoded = ControlProtocol.encode_message(data)
# Should succeed and be decodable
decoded = ControlProtocol.decode_message(encoded)
assert decoded == data

348
tests/ctl/test_server.py Normal file
View File

@ -0,0 +1,348 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""Tests for control socket server."""
import os
import tempfile
import time
from unittest.mock import MagicMock
import pytest
from gunicorn.ctl.server import ControlSocketServer
from gunicorn.ctl.client import ControlClient
class MockWorker:
"""Mock worker for testing."""
def __init__(self, pid, age, booted=True, aborted=False):
self.pid = pid
self.age = age
self.booted = booted
self.aborted = aborted
self.tmp = MagicMock()
self.tmp.last_update.return_value = time.monotonic()
class MockConfig:
"""Mock config for testing."""
def __init__(self):
self.bind = ['127.0.0.1:8000']
self.workers = 4
self.worker_class = 'sync'
self.threads = 1
self.timeout = 30
self.graceful_timeout = 30
self.keepalive = 2
self.max_requests = 0
self.max_requests_jitter = 0
self.worker_connections = 1000
self.preload_app = False
self.daemon = False
self.pidfile = None
self.proc_name = 'test_app'
self.reload = False
self.dirty_workers = 0
self.dirty_apps = []
self.dirty_timeout = 30
self.control_socket = 'gunicorn.ctl'
self.control_socket_disable = False
class MockLog:
"""Mock logger for testing."""
def debug(self, msg, *args):
pass
def info(self, msg, *args):
pass
def warning(self, msg, *args):
pass
def error(self, msg, *args):
pass
def exception(self, msg, *args):
pass
class MockArbiter:
"""Mock arbiter for testing."""
def __init__(self):
self.cfg = MockConfig()
self.log = MockLog()
self.pid = 12345
self.WORKERS = {}
self.LISTENERS = []
self.dirty_arbiter_pid = 0
self.dirty_arbiter = None
self.num_workers = 4
self._stats = {
'start_time': time.time() - 3600,
'workers_spawned': 10,
'workers_killed': 5,
'reloads': 2,
}
def wakeup(self):
pass
class TestControlSocketServerInit:
"""Tests for server initialization."""
def test_init(self):
"""Test server initialization."""
arbiter = MockArbiter()
server = ControlSocketServer(arbiter, "/tmp/test.sock", 0o600)
assert server.arbiter is arbiter
assert server.socket_path == "/tmp/test.sock"
assert server.socket_mode == 0o600
assert server._running is False
class TestControlSocketServerLifecycle:
"""Tests for server start/stop."""
def test_start_stop(self):
"""Test starting and stopping the server."""
with tempfile.TemporaryDirectory() as tmpdir:
socket_path = os.path.join(tmpdir, "test.sock")
arbiter = MockArbiter()
server = ControlSocketServer(arbiter, socket_path)
server.start()
# Wait for server to start
for _ in range(20):
if os.path.exists(socket_path):
break
time.sleep(0.1)
assert os.path.exists(socket_path)
server.stop()
# Wait for cleanup
time.sleep(0.2)
# Socket should be cleaned up
assert not os.path.exists(socket_path)
def test_start_already_running(self):
"""Test that start is idempotent."""
with tempfile.TemporaryDirectory() as tmpdir:
socket_path = os.path.join(tmpdir, "test.sock")
arbiter = MockArbiter()
server = ControlSocketServer(arbiter, socket_path)
server.start()
first_thread = server._thread
server.start()
assert server._thread is first_thread
server.stop()
def test_stop_not_running(self):
"""Test stopping a non-running server."""
arbiter = MockArbiter()
server = ControlSocketServer(arbiter, "/tmp/test.sock")
# Should not raise
server.stop()
class TestControlSocketServerIntegration:
"""Integration tests for server with client."""
def test_show_workers(self):
"""Test show workers command."""
with tempfile.TemporaryDirectory() as tmpdir:
socket_path = os.path.join(tmpdir, "test.sock")
arbiter = MockArbiter()
arbiter.WORKERS = {
1001: MockWorker(1001, 1),
1002: MockWorker(1002, 2),
}
server = ControlSocketServer(arbiter, socket_path)
server.start()
# Wait for server to start
for _ in range(20):
if os.path.exists(socket_path):
break
time.sleep(0.1)
try:
with ControlClient(socket_path, timeout=5.0) as client:
result = client.send_command("show workers")
assert result["count"] == 2
assert len(result["workers"]) == 2
finally:
server.stop()
def test_show_stats(self):
"""Test show stats command."""
with tempfile.TemporaryDirectory() as tmpdir:
socket_path = os.path.join(tmpdir, "test.sock")
arbiter = MockArbiter()
server = ControlSocketServer(arbiter, socket_path)
server.start()
for _ in range(20):
if os.path.exists(socket_path):
break
time.sleep(0.1)
try:
with ControlClient(socket_path, timeout=5.0) as client:
result = client.send_command("show stats")
assert result["pid"] == 12345
assert result["workers_spawned"] == 10
finally:
server.stop()
def test_help_command(self):
"""Test help command."""
with tempfile.TemporaryDirectory() as tmpdir:
socket_path = os.path.join(tmpdir, "test.sock")
arbiter = MockArbiter()
server = ControlSocketServer(arbiter, socket_path)
server.start()
for _ in range(20):
if os.path.exists(socket_path):
break
time.sleep(0.1)
try:
with ControlClient(socket_path, timeout=5.0) as client:
result = client.send_command("help")
assert "commands" in result
assert "show workers" in result["commands"]
finally:
server.stop()
def test_worker_add(self):
"""Test worker add command."""
with tempfile.TemporaryDirectory() as tmpdir:
socket_path = os.path.join(tmpdir, "test.sock")
arbiter = MockArbiter()
arbiter.wakeup = MagicMock()
server = ControlSocketServer(arbiter, socket_path)
server.start()
for _ in range(20):
if os.path.exists(socket_path):
break
time.sleep(0.1)
try:
with ControlClient(socket_path, timeout=5.0) as client:
result = client.send_command("worker add 2")
assert result["added"] == 2
assert result["total"] == 6
assert arbiter.num_workers == 6
arbiter.wakeup.assert_called()
finally:
server.stop()
def test_invalid_command(self):
"""Test handling invalid command."""
with tempfile.TemporaryDirectory() as tmpdir:
socket_path = os.path.join(tmpdir, "test.sock")
arbiter = MockArbiter()
server = ControlSocketServer(arbiter, socket_path)
server.start()
for _ in range(20):
if os.path.exists(socket_path):
break
time.sleep(0.1)
try:
with ControlClient(socket_path, timeout=5.0) as client:
with pytest.raises(Exception) as exc_info:
client.send_command("invalid_command")
assert "Unknown command" in str(exc_info.value)
finally:
server.stop()
def test_multiple_commands(self):
"""Test sending multiple commands on same connection."""
with tempfile.TemporaryDirectory() as tmpdir:
socket_path = os.path.join(tmpdir, "test.sock")
arbiter = MockArbiter()
arbiter.WORKERS = {1001: MockWorker(1001, 1)}
server = ControlSocketServer(arbiter, socket_path)
server.start()
for _ in range(20):
if os.path.exists(socket_path):
break
time.sleep(0.1)
try:
with ControlClient(socket_path, timeout=5.0) as client:
result1 = client.send_command("show workers")
result2 = client.send_command("show stats")
result3 = client.send_command("help")
assert result1["count"] == 1
assert result2["pid"] == 12345
assert "commands" in result3
finally:
server.stop()
class TestControlSocketServerPermissions:
"""Tests for socket permissions."""
def test_socket_permissions(self):
"""Test that socket is created with correct permissions."""
with tempfile.TemporaryDirectory() as tmpdir:
socket_path = os.path.join(tmpdir, "test.sock")
arbiter = MockArbiter()
server = ControlSocketServer(arbiter, socket_path, 0o660)
server.start()
for _ in range(20):
if os.path.exists(socket_path):
break
time.sleep(0.1)
try:
mode = os.stat(socket_path).st_mode & 0o777
assert mode == 0o660
finally:
server.stop()