diff --git a/gunicorn/arbiter.py b/gunicorn/arbiter.py index 6200ac3a..cd838d00 100644 --- a/gunicorn/arbiter.py +++ b/gunicorn/arbiter.py @@ -74,6 +74,17 @@ class Arbiter: self.dirty_arbiter = None self.dirty_pidfile = None # Well-known location for orphan detection + # Control socket server + self._control_server = None + + # Stats tracking + self._stats = { + 'start_time': None, + 'workers_spawned': 0, + 'workers_killed': 0, + 'reloads': 0, + } + cwd = util.getcwd() args = sys.argv[:] @@ -133,6 +144,9 @@ class Arbiter: """ self.log.info("Starting gunicorn %s", __version__) + # Initialize stats tracking + self._stats['start_time'] = time.time() + if 'GUNICORN_PID' in os.environ: self.master_pid = int(os.environ.get('GUNICORN_PID')) self.proc_name = self.proc_name + ".2" @@ -179,6 +193,9 @@ class Arbiter: if self.cfg.dirty_workers > 0 and self.cfg.dirty_apps: self.spawn_dirty_arbiter() + # Start control socket server + self._start_control_server() + self.cfg.when_ready(self) def init_signals(self): @@ -351,6 +368,9 @@ class Arbiter: def halt(self, reason=None, exit_status=0): """ halt arbiter """ + # Stop control socket server first + self._stop_control_server() + self.stop() log_func = self.log.info if exit_status == 0 else self.log.error @@ -477,6 +497,9 @@ class Arbiter: os.execvpe(self.START_CTX[0], self.START_CTX['args'], environ) def reload(self): + # Track reload stats + self._stats['reloads'] += 1 + old_address = self.cfg.address # reset old environment @@ -667,6 +690,7 @@ class Arbiter: if pid != 0: worker.pid = pid self.WORKERS[pid] = worker + self._stats['workers_spawned'] += 1 return pid # Do not inherit the temporary files of other workers @@ -737,6 +761,9 @@ class Arbiter: """ try: os.kill(pid, sig) + # Track kills only on SIGTERM/SIGKILL (actual termination signals) + if sig in (signal.SIGTERM, signal.SIGKILL): + self._stats['workers_killed'] += 1 except OSError as e: if e.errno == errno.ESRCH: try: @@ -906,3 +933,51 @@ class Arbiter: if self.cfg.dirty_workers > 0 and self.cfg.dirty_apps: self.log.info("Spawning dirty arbiter...") self.spawn_dirty_arbiter() + + # ========================================================================= + # Control Socket Management + # ========================================================================= + + def _get_control_socket_path(self): + """Get the control socket path, making relative paths absolute.""" + socket_path = self.cfg.control_socket + if not os.path.isabs(socket_path): + socket_path = os.path.join(util.getcwd(), socket_path) + return socket_path + + def _start_control_server(self): + """\ + Start the control socket server. + + The server runs in a background thread and accepts commands + via Unix socket. + """ + if self.cfg.control_socket_disable: + self.log.debug("Control socket disabled") + return + + # Lazy import to avoid circular imports and gevent compatibility + from gunicorn.ctl.server import ControlSocketServer + + socket_path = self._get_control_socket_path() + socket_mode = self.cfg.control_socket_mode + + try: + self._control_server = ControlSocketServer( + self, socket_path, socket_mode + ) + self._control_server.start() + except Exception as e: + self.log.warning("Failed to start control socket: %s", e) + self._control_server = None + + def _stop_control_server(self): + """\ + Stop the control socket server. + """ + if self._control_server: + try: + self._control_server.stop() + except Exception as e: + self.log.debug("Error stopping control server: %s", e) + self._control_server = None diff --git a/gunicorn/config.py b/gunicorn/config.py index c391ae41..fa399f03 100644 --- a/gunicorn/config.py +++ b/gunicorn/config.py @@ -3113,3 +3113,63 @@ class DirtyWorkerExit(Setting): .. versionadded:: 25.0.0 """ + + +# Control Socket Settings + +class ControlSocket(Setting): + name = "control_socket" + section = "Control" + cli = ["--control-socket"] + meta = "PATH" + validator = validate_string + default = "gunicorn.ctl" + desc = """\ + Unix socket path for control interface. + + The control socket allows runtime management of Gunicorn via the + ``gunicornc`` command-line tool. Commands include viewing worker + status, adjusting worker count, and graceful reload/shutdown. + + By default, creates ``gunicorn.ctl`` in the working directory. + Set an absolute path for a fixed location (e.g., ``/var/run/gunicorn.ctl``). + + Use ``--no-control-socket`` to disable. + + .. versionadded:: 25.1.0 + """ + + +class ControlSocketMode(Setting): + name = "control_socket_mode" + section = "Control" + cli = ["--control-socket-mode"] + meta = "INT" + validator = validate_pos_int + type = auto_int + default = 0o600 + desc = """\ + Permission mode for control socket. + + Restricts who can connect to the control socket. Default ``0600`` + allows only the socket owner. Set to ``0660`` to allow group access. + + .. versionadded:: 25.1.0 + """ + + +class ControlSocketDisable(Setting): + name = "control_socket_disable" + section = "Control" + cli = ["--no-control-socket"] + validator = validate_bool + action = "store_true" + default = False + desc = """\ + Disable control socket. + + When set, no control socket is created and ``gunicornc`` cannot + connect to this Gunicorn instance. + + .. versionadded:: 25.1.0 + """ diff --git a/gunicorn/ctl/__init__.py b/gunicorn/ctl/__init__.py new file mode 100644 index 00000000..968d9d31 --- /dev/null +++ b/gunicorn/ctl/__init__.py @@ -0,0 +1,16 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +Gunicorn Control Interface + +Provides a control socket server for runtime management and +a CLI client (gunicornc) for interacting with running Gunicorn instances. +""" + +from gunicorn.ctl.server import ControlSocketServer +from gunicorn.ctl.client import ControlClient +from gunicorn.ctl.protocol import ControlProtocol + +__all__ = ['ControlSocketServer', 'ControlClient', 'ControlProtocol'] diff --git a/gunicorn/ctl/cli.py b/gunicorn/ctl/cli.py new file mode 100644 index 00000000..d110bd0e --- /dev/null +++ b/gunicorn/ctl/cli.py @@ -0,0 +1,385 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +gunicornc - Gunicorn control interface CLI + +Interactive and single-command modes for controlling Gunicorn instances. +""" + +import argparse +import json +import os +import sys + +from gunicorn.ctl.client import ControlClient, ControlClientError, parse_command + + +def format_workers(data: dict) -> str: + """Format workers output for display.""" + workers = data.get("workers", []) + if not workers: + return "No workers running" + + lines = [] + lines.append(f"{'PID':<10} {'AGE':<6} {'BOOTED':<8} {'LAST_BEAT'}") + lines.append("-" * 40) + + for w in workers: + pid = w.get("pid", "?") + age = w.get("age", "?") + booted = "yes" if w.get("booted") else "no" + hb = w.get("last_heartbeat") + hb_str = f"{hb}s ago" if hb is not None else "n/a" + + lines.append(f"{pid:<10} {age:<6} {booted:<8} {hb_str}") + + lines.append("") + lines.append(f"Total: {data.get('count', len(workers))} workers") + + return "\n".join(lines) + + +def format_dirty(data: dict) -> str: + """Format dirty workers output for display.""" + if not data.get("enabled"): + return "Dirty arbiter not running" + + lines = [] + lines.append(f"Dirty arbiter PID: {data.get('pid')}") + lines.append("") + + workers = data.get("workers", []) + if workers: + lines.append("DIRTY WORKERS:") + lines.append(f"{'PID':<10} {'AGE':<6} {'APPS':<30} {'LAST_BEAT'}") + lines.append("-" * 60) + + for w in workers: + pid = w.get("pid", "?") + age = w.get("age", "?") + apps = ", ".join(w.get("apps", []))[:30] + hb = w.get("last_heartbeat") + hb_str = f"{hb}s ago" if hb is not None else "n/a" + + lines.append(f"{pid:<10} {age:<6} {apps:<30} {hb_str}") + lines.append("") + + apps = data.get("apps", []) + if apps: + lines.append("DIRTY APPS:") + lines.append(f"{'APP':<30} {'WORKERS':<10} {'LIMIT'}") + lines.append("-" * 50) + + for app in apps: + path = app.get("import_path", "?")[:30] + current = app.get("current_workers", 0) + limit = app.get("worker_count") + limit_str = str(limit) if limit is not None else "none" + + lines.append(f"{path:<30} {current:<10} {limit_str}") + + return "\n".join(lines) + + +def format_stats(data: dict) -> str: + """Format stats output for display.""" + lines = [] + + uptime = data.get("uptime") + if uptime: + hours = int(uptime // 3600) + minutes = int((uptime % 3600) // 60) + seconds = int(uptime % 60) + if hours: + uptime_str = f"{hours}h {minutes}m {seconds}s" + elif minutes: + uptime_str = f"{minutes}m {seconds}s" + else: + uptime_str = f"{seconds}s" + else: + uptime_str = "unknown" + + lines.append(f"Uptime: {uptime_str}") + lines.append(f"PID: {data.get('pid', 'unknown')}") + lines.append(f"Workers current: {data.get('workers_current', 0)}") + lines.append(f"Workers target: {data.get('workers_target', 0)}") + lines.append(f"Workers spawned: {data.get('workers_spawned', 0)}") + lines.append(f"Workers killed: {data.get('workers_killed', 0)}") + lines.append(f"Reloads: {data.get('reloads', 0)}") + + dirty_pid = data.get("dirty_arbiter_pid") + if dirty_pid: + lines.append(f"Dirty arbiter: {dirty_pid}") + + return "\n".join(lines) + + +def format_listeners(data: dict) -> str: + """Format listeners output for display.""" + listeners = data.get("listeners", []) + if not listeners: + return "No listeners bound" + + lines = [] + lines.append(f"{'ADDRESS':<40} {'TYPE':<8} {'FD'}") + lines.append("-" * 55) + + for lnr in listeners: + addr = lnr.get("address", "?") + ltype = lnr.get("type", "?") + fd = lnr.get("fd", "?") + lines.append(f"{addr:<40} {ltype:<8} {fd}") + + lines.append("") + lines.append(f"Total: {data.get('count', len(listeners))} listeners") + + return "\n".join(lines) + + +def format_config(data: dict) -> str: + """Format config output for display.""" + lines = [] + + # Sort keys for consistent output + for key in sorted(data.keys()): + value = data[key] + if isinstance(value, list): + value = ", ".join(str(v) for v in value) + lines.append(f"{key}: {value}") + + return "\n".join(lines) + + +def format_help(data: dict) -> str: + """Format help output for display.""" + commands = data.get("commands", {}) + lines = [] + lines.append("Available commands:") + lines.append("") + + # Find max command length for alignment + max_len = max(len(cmd) for cmd in commands.keys()) if commands else 0 + + for cmd, desc in sorted(commands.items()): + lines.append(f" {cmd:<{max_len + 2}} {desc}") + + return "\n".join(lines) + + +def format_response(command: str, data: dict) -> str: + """ + Format response data based on command. + + Args: + command: Original command string + data: Response data dictionary + + Returns: + Formatted string for display + """ + cmd_lower = command.lower().strip() + + # Route to specific formatters + if cmd_lower == "show workers": + return format_workers(data) + elif cmd_lower == "show dirty": + return format_dirty(data) + elif cmd_lower == "show stats": + return format_stats(data) + elif cmd_lower == "show listeners": + return format_listeners(data) + elif cmd_lower == "show config": + return format_config(data) + elif cmd_lower == "help": + return format_help(data) + else: + # Generic JSON output for other commands + if data: + return json.dumps(data, indent=2) + return "OK" + + +def run_command(socket_path: str, command: str, json_output: bool = False) -> int: + """ + Execute single command and exit. + + Args: + socket_path: Path to control socket + command: Command to execute + json_output: If True, output raw JSON + + Returns: + Exit code (0 for success, 1 for error) + """ + try: + with ControlClient(socket_path) as client: + cmd, args = parse_command(command) + full_command = f"{cmd} {' '.join(args)}".strip() if args else cmd + result = client.send_command(full_command) + + if json_output: + print(json.dumps(result, indent=2)) + else: + output = format_response(cmd, result) + print(output) + + return 0 + + except ControlClientError as e: + print(f"Error: {e}", file=sys.stderr) + return 1 + except KeyboardInterrupt: + return 130 + + +def run_interactive(socket_path: str, json_output: bool = False) -> int: + """ + Run interactive CLI with readline support. + + Args: + socket_path: Path to control socket + json_output: If True, output raw JSON + + Returns: + Exit code + """ + try: + import readline # noqa: F401 - imported for side effects + has_readline = True + except ImportError: + has_readline = False + + try: + client = ControlClient(socket_path) + client.connect() + except ControlClientError as e: + print(f"Error: {e}", file=sys.stderr) + return 1 + + print(f"Connected to {socket_path}") + print("Type 'help' for available commands, 'quit' to exit.") + print() + + # Set up readline history + history_file = os.path.expanduser("~/.gunicornc_history") + if has_readline: + try: + readline.read_history_file(history_file) + except FileNotFoundError: + pass + + exit_code = 0 + + try: + while True: + try: + line = input("gunicorn> ").strip() + except EOFError: + print() + break + + if not line: + continue + + if line.lower() in ('quit', 'exit', 'q'): + break + + try: + cmd, args = parse_command(line) + full_command = f"{cmd} {' '.join(args)}".strip() if args else cmd + result = client.send_command(full_command) + + if json_output: + print(json.dumps(result, indent=2)) + else: + output = format_response(cmd, result) + print(output) + + except ControlClientError as e: + print(f"Error: {e}") + # Try to reconnect + try: + client.close() + client.connect() + except ControlClientError: + print("Connection lost. Exiting.") + exit_code = 1 + break + + print() + + except KeyboardInterrupt: + print() + exit_code = 130 + finally: + client.close() + if has_readline: + try: + readline.write_history_file(history_file) + except Exception: + pass + + return exit_code + + +def main(): + """Main entry point for gunicornc CLI.""" + parser = argparse.ArgumentParser( + description='Gunicorn control interface', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + gunicornc # Interactive mode (default socket) + gunicornc -s /tmp/myapp.ctl # Interactive mode with custom socket + gunicornc -c "show workers" # Single command mode + gunicornc -c "worker add 2" # Add 2 workers + gunicornc -c "show stats" -j # Output stats as JSON + """ + ) + + parser.add_argument( + '-s', '--socket', + default='gunicorn.ctl', + help='Control socket path (default: gunicorn.ctl in current directory)' + ) + + parser.add_argument( + '-c', '--command', + help='Execute single command and exit' + ) + + parser.add_argument( + '-j', '--json', + action='store_true', + help='Output raw JSON (for scripting)' + ) + + parser.add_argument( + '-v', '--version', + action='store_true', + help='Show version and exit' + ) + + args = parser.parse_args() + + if args.version: + from gunicorn import __version__ + print(f"gunicornc (gunicorn {__version__})") + return 0 + + socket_path = args.socket + + # Make relative paths absolute from cwd + if not os.path.isabs(socket_path): + socket_path = os.path.join(os.getcwd(), socket_path) + + if args.command: + return run_command(socket_path, args.command, args.json) + else: + return run_interactive(socket_path, args.json) + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/gunicorn/ctl/client.py b/gunicorn/ctl/client.py new file mode 100644 index 00000000..e75ec713 --- /dev/null +++ b/gunicorn/ctl/client.py @@ -0,0 +1,140 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +Control Socket Client + +Client library for connecting to gunicorn control socket. +""" + +import shlex +import socket + +from gunicorn.ctl.protocol import ( + ControlProtocol, + make_request, +) + + +class ControlClientError(Exception): + """Control client error.""" + pass + + +class ControlClient: + """ + Client for connecting to gunicorn control socket. + + Can be used as a context manager: + + with ControlClient('/path/to/gunicorn.ctl') as client: + result = client.send_command('show workers') + """ + + def __init__(self, socket_path: str, timeout: float = 30.0): + """ + Initialize control client. + + Args: + socket_path: Path to the Unix socket + timeout: Socket timeout in seconds (default 30) + """ + self.socket_path = socket_path + self.timeout = timeout + self._sock = None + self._request_id = 0 + + def connect(self): + """ + Connect to control socket. + + Raises: + ControlClientError: If connection fails + """ + if self._sock: + return + + try: + self._sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self._sock.settimeout(self.timeout) + self._sock.connect(self.socket_path) + except socket.error as e: + self._sock = None + raise ControlClientError(f"Failed to connect to {self.socket_path}: {e}") + + def close(self): + """Close connection.""" + if self._sock: + try: + self._sock.close() + except Exception: + pass + self._sock = None + + def send_command(self, command: str, args: list = None) -> dict: + """ + Send command and wait for response. + + Args: + command: Command string (e.g., "show workers") + args: Optional additional arguments + + Returns: + Response data dictionary + + Raises: + ControlClientError: If communication fails + """ + if not self._sock: + self.connect() + + self._request_id += 1 + request = make_request(self._request_id, command, args) + + try: + ControlProtocol.write_message(self._sock, request) + response = ControlProtocol.read_message(self._sock) + except Exception as e: + self.close() + raise ControlClientError(f"Communication error: {e}") + + if response.get("status") == "error": + raise ControlClientError(response.get("error", "Unknown error")) + + return response.get("data", {}) + + def __enter__(self): + self.connect() + return self + + def __exit__(self, *args): + self.close() + + +def parse_command(line: str) -> tuple: + """ + Parse a command line into command and args. + + Args: + line: Command line string + + Returns: + Tuple of (command_string, args_list) + """ + parts = shlex.split(line) + if not parts: + return "", [] + + # Find where numeric/value args start + command_parts = [] + args = [] + + for part in parts: + # If we haven't hit args yet and this looks like a command word + if not args and not part.isdigit() and not part.startswith('-'): + command_parts.append(part) + else: + args.append(part) + + return " ".join(command_parts), args diff --git a/gunicorn/ctl/handlers.py b/gunicorn/ctl/handlers.py new file mode 100644 index 00000000..7005b5bd --- /dev/null +++ b/gunicorn/ctl/handlers.py @@ -0,0 +1,431 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +Control Interface Command Handlers + +Provides handlers for all control commands with access to arbiter state. +""" + +import os +import signal +import time + + +class CommandHandlers: + """ + Command handlers with access to arbiter state. + + All handler methods return dictionaries that will be sent + as the response data. + """ + + def __init__(self, arbiter): + """ + Initialize handlers with arbiter reference. + + Args: + arbiter: The Gunicorn arbiter instance + """ + self.arbiter = arbiter + + def show_workers(self) -> dict: + """ + Return list of HTTP workers. + + Returns: + Dictionary with workers list containing: + - pid: Worker process ID + - age: Worker age (spawn order) + - requests: Number of requests handled (if available) + - booted: Whether worker has finished booting + - last_heartbeat: Seconds since last heartbeat + """ + workers = [] + now = time.monotonic() + + for pid, worker in self.arbiter.WORKERS.items(): + try: + last_update = worker.tmp.last_update() + last_heartbeat = round(now - last_update, 2) + except (OSError, ValueError): + last_heartbeat = None + + workers.append({ + "pid": pid, + "age": worker.age, + "booted": worker.booted, + "aborted": worker.aborted, + "last_heartbeat": last_heartbeat, + }) + + # Sort by age (oldest first) + workers.sort(key=lambda w: w["age"]) + + return {"workers": workers, "count": len(workers)} + + def show_dirty(self) -> dict: + """ + Return dirty workers and apps information. + + Returns: + Dictionary with: + - enabled: Whether dirty arbiter is running + - pid: Dirty arbiter PID + - workers: List of dirty worker info + - apps: List of dirty app specs + """ + if not self.arbiter.dirty_arbiter_pid: + return { + "enabled": False, + "pid": None, + "workers": [], + "apps": [], + } + + # Get dirty arbiter reference if available + dirty_arbiter = getattr(self.arbiter, 'dirty_arbiter', None) + + workers = [] + apps = [] + + if dirty_arbiter and hasattr(dirty_arbiter, 'workers'): + now = time.monotonic() + for pid, worker in dirty_arbiter.workers.items(): + try: + last_update = worker.tmp.last_update() + last_heartbeat = round(now - last_update, 2) + except (OSError, ValueError, AttributeError): + last_heartbeat = None + + workers.append({ + "pid": pid, + "age": worker.age, + "apps": getattr(worker, 'app_paths', []), + "booted": getattr(worker, 'booted', False), + "last_heartbeat": last_heartbeat, + }) + + # Get app specs + if hasattr(dirty_arbiter, 'app_specs'): + for path, spec in dirty_arbiter.app_specs.items(): + worker_pids = list(dirty_arbiter.app_worker_map.get(path, [])) + apps.append({ + "import_path": path, + "worker_count": spec.get('worker_count'), + "current_workers": len(worker_pids), + "worker_pids": worker_pids, + }) + + return { + "enabled": True, + "pid": self.arbiter.dirty_arbiter_pid, + "workers": workers, + "apps": apps, + } + + def show_config(self) -> dict: + """ + Return current effective configuration. + + Returns: + Dictionary of configuration values + """ + cfg = self.arbiter.cfg + config = {} + + # Get commonly needed config values + config_keys = [ + 'bind', 'workers', 'worker_class', 'threads', 'timeout', + 'graceful_timeout', 'keepalive', 'max_requests', + 'max_requests_jitter', 'worker_connections', 'preload_app', + 'daemon', 'pidfile', 'proc_name', 'reload', + 'dirty_workers', 'dirty_apps', 'dirty_timeout', + 'control_socket', 'control_socket_disable', + ] + + for key in config_keys: + try: + value = getattr(cfg, key) + # Convert non-serializable types + if callable(value): + value = str(value) + elif hasattr(value, '__class__') and not isinstance( + value, (str, int, float, bool, list, dict, type(None))): + value = str(value) + config[key] = value + except AttributeError: + pass + + return config + + def show_stats(self) -> dict: + """ + Return server statistics. + + Returns: + Dictionary with: + - uptime: Seconds since arbiter started + - pid: Arbiter PID + - workers_current: Current number of workers + - workers_spawned: Total workers spawned + - workers_killed: Total workers killed (if tracked) + - reloads: Number of reloads (if tracked) + """ + stats = getattr(self.arbiter, '_stats', {}) + start_time = stats.get('start_time') + + uptime = None + if start_time: + uptime = round(time.time() - start_time, 2) + + return { + "uptime": uptime, + "pid": self.arbiter.pid, + "workers_current": len(self.arbiter.WORKERS), + "workers_target": self.arbiter.num_workers, + "workers_spawned": stats.get('workers_spawned', 0), + "workers_killed": stats.get('workers_killed', 0), + "reloads": stats.get('reloads', 0), + "dirty_arbiter_pid": self.arbiter.dirty_arbiter_pid or None, + } + + def show_listeners(self) -> dict: + """ + Return bound socket information. + + Returns: + Dictionary with listeners list + """ + listeners = [] + + for lnr in self.arbiter.LISTENERS: + addr = str(lnr) + listener_info = { + "address": addr, + "fd": lnr.fileno(), + } + + # Try to get socket family + try: + import socket + sock = lnr.sock + if sock.family == socket.AF_UNIX: + listener_info["type"] = "unix" + elif sock.family == socket.AF_INET: + listener_info["type"] = "tcp" + elif sock.family == socket.AF_INET6: + listener_info["type"] = "tcp6" + except Exception: + listener_info["type"] = "unknown" + + listeners.append(listener_info) + + return {"listeners": listeners, "count": len(listeners)} + + def worker_add(self, count: int = 1) -> dict: + """ + Increase worker count. + + Args: + count: Number of workers to add (default 1) + + Returns: + Dictionary with added count and new total + """ + count = max(1, int(count)) + old_count = self.arbiter.num_workers + self.arbiter.num_workers += count + + # Wake up the arbiter to spawn workers + self.arbiter.wakeup() + + return { + "added": count, + "previous": old_count, + "total": self.arbiter.num_workers, + } + + def worker_remove(self, count: int = 1) -> dict: + """ + Decrease worker count. + + Args: + count: Number of workers to remove (default 1) + + Returns: + Dictionary with removed count and new total + """ + count = max(1, int(count)) + old_count = self.arbiter.num_workers + + # Don't go below 1 worker + new_count = max(1, old_count - count) + actual_removed = old_count - new_count + + self.arbiter.num_workers = new_count + + # Wake up the arbiter to kill excess workers + self.arbiter.wakeup() + + return { + "removed": actual_removed, + "previous": old_count, + "total": new_count, + } + + def worker_kill(self, pid: int) -> dict: + """ + Gracefully terminate a specific worker. + + Args: + pid: Worker process ID + + Returns: + Dictionary with killed PID or error + """ + pid = int(pid) + + if pid not in self.arbiter.WORKERS: + return { + "success": False, + "error": f"Worker {pid} not found", + } + + try: + os.kill(pid, signal.SIGTERM) + return { + "success": True, + "killed": pid, + } + except OSError as e: + return { + "success": False, + "error": str(e), + } + + def dirty_add(self, count: int = 1) -> dict: + """ + Spawn additional dirty workers. + + Args: + count: Number of dirty workers to add (default 1) + + Returns: + Dictionary with added count or error + """ + if not self.arbiter.dirty_arbiter_pid: + return { + "success": False, + "error": "Dirty arbiter not running", + } + + # Send TTIN signals to dirty arbiter + count = max(1, int(count)) + try: + for _ in range(count): + os.kill(self.arbiter.dirty_arbiter_pid, signal.SIGTTIN) + return { + "success": True, + "added": count, + } + except OSError as e: + return { + "success": False, + "error": str(e), + } + + def dirty_remove(self, count: int = 1) -> dict: + """ + Remove dirty workers. + + Args: + count: Number of dirty workers to remove (default 1) + + Returns: + Dictionary with removed count or error + """ + if not self.arbiter.dirty_arbiter_pid: + return { + "success": False, + "error": "Dirty arbiter not running", + } + + # Send TTOU signals to dirty arbiter + count = max(1, int(count)) + try: + for _ in range(count): + os.kill(self.arbiter.dirty_arbiter_pid, signal.SIGTTOU) + return { + "success": True, + "removed": count, + } + except OSError as e: + return { + "success": False, + "error": str(e), + } + + def reload(self) -> dict: + """ + Trigger graceful reload (equivalent to SIGHUP). + + Returns: + Dictionary with status + """ + # Send HUP to self to trigger reload + os.kill(self.arbiter.pid, signal.SIGHUP) + return {"status": "reloading"} + + def reopen(self) -> dict: + """ + Reopen log files (equivalent to SIGUSR1). + + Returns: + Dictionary with status + """ + os.kill(self.arbiter.pid, signal.SIGUSR1) + return {"status": "reopening"} + + def shutdown(self, mode: str = "graceful") -> dict: + """ + Initiate shutdown. + + Args: + mode: "graceful" (SIGTERM) or "quick" (SIGINT) + + Returns: + Dictionary with status + """ + if mode == "quick": + os.kill(self.arbiter.pid, signal.SIGINT) + else: + os.kill(self.arbiter.pid, signal.SIGTERM) + + return {"status": "shutting_down", "mode": mode} + + def help(self) -> dict: + """ + Return list of available commands. + + Returns: + Dictionary with commands and descriptions + """ + commands = { + "show workers": "List HTTP workers with their status", + "show dirty": "List dirty workers and apps", + "show config": "Show current effective configuration", + "show stats": "Show server statistics", + "show listeners": "Show bound sockets", + "worker add [N]": "Spawn N workers (default 1)", + "worker remove [N]": "Remove N workers (default 1)", + "worker kill ": "Gracefully terminate specific worker", + "dirty add [N]": "Spawn N dirty workers (default 1)", + "dirty remove [N]": "Remove N dirty workers (default 1)", + "reload": "Graceful reload (HUP)", + "reopen": "Reopen log files (USR1)", + "shutdown [graceful|quick]": "Shutdown server (TERM/INT)", + "help": "Show this help message", + } + return {"commands": commands} diff --git a/gunicorn/ctl/protocol.py b/gunicorn/ctl/protocol.py new file mode 100644 index 00000000..b00e4910 --- /dev/null +++ b/gunicorn/ctl/protocol.py @@ -0,0 +1,225 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +Control Socket Protocol + +JSON-based protocol with length-prefixed framing for the control interface. + +Message Format: + +----------------+------------------+ + | Length (4B BE) | JSON Payload | + +----------------+------------------+ + +Request Format: + {"id": 1, "command": "show", "args": ["workers"]} + +Response Format: + {"id": 1, "status": "ok", "data": {...}} + {"id": 1, "status": "error", "error": "message"} +""" + +import json +import struct + + +class ProtocolError(Exception): + """Protocol-level error.""" + pass + + +class ControlProtocol: + """ + Protocol implementation for control socket communication. + + Uses 4-byte big-endian length prefix followed by JSON payload. + """ + + # Maximum message size (16 MB) + MAX_MESSAGE_SIZE = 16 * 1024 * 1024 + + @staticmethod + def encode_message(data: dict) -> bytes: + """ + Encode a message for transmission. + + Args: + data: Dictionary to encode + + Returns: + Length-prefixed JSON bytes + """ + payload = json.dumps(data).encode('utf-8') + length = struct.pack('>I', len(payload)) + return length + payload + + @staticmethod + def decode_message(data: bytes) -> dict: + """ + Decode a message from bytes. + + Args: + data: Raw bytes (length prefix + JSON payload) + + Returns: + Decoded dictionary + """ + if len(data) < 4: + raise ProtocolError("Message too short") + + length = struct.unpack('>I', data[:4])[0] + if len(data) < 4 + length: + raise ProtocolError("Incomplete message") + + payload = data[4:4 + length] + return json.loads(payload.decode('utf-8')) + + @staticmethod + def read_message(sock) -> dict: + """ + Read one message from a socket. + + Args: + sock: Socket to read from + + Returns: + Decoded message dictionary + + Raises: + ProtocolError: If message is malformed + ConnectionError: If connection is closed + """ + # Read length prefix + length_data = b'' + while len(length_data) < 4: + chunk = sock.recv(4 - len(length_data)) + if not chunk: + if not length_data: + raise ConnectionError("Connection closed") + raise ProtocolError("Incomplete length prefix") + length_data += chunk + + length = struct.unpack('>I', length_data)[0] + + if length > ControlProtocol.MAX_MESSAGE_SIZE: + raise ProtocolError(f"Message too large: {length}") + + # Read payload + payload_data = b'' + while len(payload_data) < length: + chunk = sock.recv(min(length - len(payload_data), 65536)) + if not chunk: + raise ProtocolError("Incomplete payload") + payload_data += chunk + + try: + return json.loads(payload_data.decode('utf-8')) + except json.JSONDecodeError as e: + raise ProtocolError(f"Invalid JSON: {e}") + + @staticmethod + def write_message(sock, data: dict): + """ + Write one message to a socket. + + Args: + sock: Socket to write to + data: Message dictionary to send + """ + message = ControlProtocol.encode_message(data) + sock.sendall(message) + + @staticmethod + async def read_message_async(reader) -> dict: + """ + Read one message from an async reader. + + Args: + reader: asyncio StreamReader + + Returns: + Decoded message dictionary + """ + # Read length prefix + length_data = await reader.readexactly(4) + length = struct.unpack('>I', length_data)[0] + + if length > ControlProtocol.MAX_MESSAGE_SIZE: + raise ProtocolError(f"Message too large: {length}") + + # Read payload + payload_data = await reader.readexactly(length) + + try: + return json.loads(payload_data.decode('utf-8')) + except json.JSONDecodeError as e: + raise ProtocolError(f"Invalid JSON: {e}") + + @staticmethod + async def write_message_async(writer, data: dict): + """ + Write one message to an async writer. + + Args: + writer: asyncio StreamWriter + data: Message dictionary to send + """ + message = ControlProtocol.encode_message(data) + writer.write(message) + await writer.drain() + + +def make_request(request_id: int, command: str, args: list = None) -> dict: + """ + Create a request message. + + Args: + request_id: Unique request identifier + command: Command name (e.g., "show workers") + args: Optional list of arguments + + Returns: + Request dictionary + """ + return { + "id": request_id, + "command": command, + "args": args or [], + } + + +def make_response(request_id: int, data: dict = None) -> dict: + """ + Create a success response message. + + Args: + request_id: Request identifier being responded to + data: Response data + + Returns: + Response dictionary + """ + return { + "id": request_id, + "status": "ok", + "data": data or {}, + } + + +def make_error_response(request_id: int, error: str) -> dict: + """ + Create an error response message. + + Args: + request_id: Request identifier being responded to + error: Error message + + Returns: + Error response dictionary + """ + return { + "id": request_id, + "status": "error", + "error": error, + } diff --git a/gunicorn/ctl/server.py b/gunicorn/ctl/server.py new file mode 100644 index 00000000..3558d6e4 --- /dev/null +++ b/gunicorn/ctl/server.py @@ -0,0 +1,299 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +Control Socket Server + +Runs in the arbiter process and accepts commands via Unix socket. +Uses asyncio in a background thread to handle client connections. +""" + +import asyncio +import os +import shlex +import threading + +from gunicorn.ctl.handlers import CommandHandlers +from gunicorn.ctl.protocol import ( + ControlProtocol, + make_response, + make_error_response, +) + + +class ControlSocketServer: + """ + Control socket server running in arbiter process. + + The server runs an asyncio event loop in a background thread, + accepting connections and dispatching commands to handlers. + """ + + def __init__(self, arbiter, socket_path, socket_mode=0o600): + """ + Initialize control socket server. + + Args: + arbiter: The Gunicorn arbiter instance + socket_path: Path for the Unix socket + socket_mode: Permission mode for socket (default 0o600) + """ + self.arbiter = arbiter + self.socket_path = socket_path + self.socket_mode = socket_mode + + self.handlers = CommandHandlers(arbiter) + self._server = None + self._loop = None + self._thread = None + self._running = False + + def start(self): + """Start server in background thread with asyncio event loop.""" + if self._running: + return + + self._running = True + self._thread = threading.Thread(target=self._run_loop, daemon=True) + self._thread.start() + + def stop(self): + """Stop server and cleanup socket.""" + if not self._running: + return + + self._running = False + + if self._loop and self._server: + # Schedule server close in the loop + self._loop.call_soon_threadsafe(self._shutdown) + + if self._thread: + self._thread.join(timeout=2.0) + self._thread = None + + # Clean up socket file + if os.path.exists(self.socket_path): + try: + os.unlink(self.socket_path) + except OSError: + pass + + def _shutdown(self): + """Shutdown server (called from event loop thread).""" + if self._server: + self._server.close() + + def _run_loop(self): + """Run the asyncio event loop in background thread.""" + try: + asyncio.run(self._serve()) + except Exception as e: + if self.arbiter.log: + self.arbiter.log.error("Control server error: %s", e) + + async def _serve(self): + """Main async server loop.""" + self._loop = asyncio.get_running_loop() + + # Remove socket if it exists + if os.path.exists(self.socket_path): + os.unlink(self.socket_path) + + # Create Unix socket server + self._server = await asyncio.start_unix_server( + self._handle_client, + path=self.socket_path + ) + + # Set socket permissions + os.chmod(self.socket_path, self.socket_mode) + + if self.arbiter.log: + self.arbiter.log.info("Control socket listening at %s", + self.socket_path) + + try: + async with self._server: + await self._server.serve_forever() + except asyncio.CancelledError: + pass + finally: + if os.path.exists(self.socket_path): + try: + os.unlink(self.socket_path) + except OSError: + pass + + async def _handle_client(self, reader, writer): + """ + Handle client connection. + + Args: + reader: asyncio StreamReader + writer: asyncio StreamWriter + """ + try: + while self._running: + try: + message = await asyncio.wait_for( + ControlProtocol.read_message_async(reader), + timeout=300.0 # 5 minute idle timeout + ) + except asyncio.TimeoutError: + # Client idle too long, close connection + break + except asyncio.IncompleteReadError: + # Client disconnected + break + except Exception: + # Protocol error + break + + # Process command + response = await self._dispatch(message) + + # Send response + await ControlProtocol.write_message_async(writer, response) + + except Exception as e: + if self.arbiter.log: + self.arbiter.log.debug("Control client error: %s", e) + finally: + writer.close() + try: + await writer.wait_closed() + except Exception: + pass + + async def _dispatch(self, message: dict) -> dict: + """ + Dispatch command to appropriate handler. + + Args: + message: Request message dict + + Returns: + Response dictionary + """ + request_id = message.get("id", 0) + command = message.get("command", "").strip() + args = message.get("args", []) + + if not command: + return make_error_response(request_id, "Empty command") + + try: + # Parse command (e.g., "show workers" or "worker add 2") + parts = shlex.split(command) + if args: + parts.extend(str(a) for a in args) + + if not parts: + return make_error_response(request_id, "Empty command") + + # Route to handler + result = self._execute_command(parts) + return make_response(request_id, result) + + except ValueError as e: + return make_error_response(request_id, f"Invalid argument: {e}") + except Exception as e: + if self.arbiter.log: + self.arbiter.log.exception("Command error") + return make_error_response(request_id, f"Command failed: {e}") + + def _execute_command(self, parts: list) -> dict: + """ + Execute a parsed command. + + Args: + parts: Command parts (e.g., ["show", "workers"]) + + Returns: + Handler result dictionary + """ + if not parts: + raise ValueError("Empty command") + + cmd = parts[0].lower() + rest = parts[1:] + + # Map commands to handlers + if cmd == "show": + return self._handle_show(rest) + elif cmd == "worker": + return self._handle_worker(rest) + elif cmd == "dirty": + return self._handle_dirty(rest) + elif cmd == "reload": + return self.handlers.reload() + elif cmd == "reopen": + return self.handlers.reopen() + elif cmd == "shutdown": + mode = rest[0] if rest else "graceful" + return self.handlers.shutdown(mode) + elif cmd == "help": + return self.handlers.help() + else: + raise ValueError(f"Unknown command: {cmd}") + + def _handle_show(self, args: list) -> dict: + """Handle 'show' commands.""" + if not args: + raise ValueError("Missing show target (workers|dirty|config|stats|listeners)") + + target = args[0].lower() + + if target == "workers": + return self.handlers.show_workers() + elif target == "dirty": + return self.handlers.show_dirty() + elif target == "config": + return self.handlers.show_config() + elif target == "stats": + return self.handlers.show_stats() + elif target == "listeners": + return self.handlers.show_listeners() + else: + raise ValueError(f"Unknown show target: {target}") + + def _handle_worker(self, args: list) -> dict: + """Handle 'worker' commands.""" + if not args: + raise ValueError("Missing worker action (add|remove|kill)") + + action = args[0].lower() + action_args = args[1:] + + if action == "add": + count = int(action_args[0]) if action_args else 1 + return self.handlers.worker_add(count) + elif action == "remove": + count = int(action_args[0]) if action_args else 1 + return self.handlers.worker_remove(count) + elif action == "kill": + if not action_args: + raise ValueError("Missing PID for worker kill") + pid = int(action_args[0]) + return self.handlers.worker_kill(pid) + else: + raise ValueError(f"Unknown worker action: {action}") + + def _handle_dirty(self, args: list) -> dict: + """Handle 'dirty' commands.""" + if not args: + raise ValueError("Missing dirty action (add|remove)") + + action = args[0].lower() + action_args = args[1:] + + if action == "add": + count = int(action_args[0]) if action_args else 1 + return self.handlers.dirty_add(count) + elif action == "remove": + count = int(action_args[0]) if action_args else 1 + return self.handlers.dirty_remove(count) + else: + raise ValueError(f"Unknown dirty action: {action}") diff --git a/gunicorn/dirty/arbiter.py b/gunicorn/dirty/arbiter.py index 44b23329..dcf2ae9e 100644 --- a/gunicorn/dirty/arbiter.py +++ b/gunicorn/dirty/arbiter.py @@ -367,7 +367,8 @@ class DirtyArbiter: try: async with self._server: await self._server.serve_forever() - except asyncio.CancelledError: + except (asyncio.CancelledError, RuntimeError): + # RuntimeError raised when server.close() is called during serve_forever() pass finally: monitor_task.cancel() @@ -836,19 +837,19 @@ class DirtyArbiter: pid, app_paths) return pid - # Child process + # Child process - use os._exit() to avoid asyncio cleanup issues worker.pid = os.getpid() try: util._setproctitle(f"dirty-worker [{self.cfg.proc_name}]") worker.init_process() - sys.exit(0) - except SystemExit: - raise + os._exit(0) + except SystemExit as e: + os._exit(e.code if e.code is not None else 0) except Exception: self.log.exception("Exception in dirty worker process") if not worker.booted: - sys.exit(self.WORKER_BOOT_ERROR) - sys.exit(-1) + os._exit(self.WORKER_BOOT_ERROR) + os._exit(1) def kill_worker(self, pid, sig): """Kill a worker by PID.""" diff --git a/pyproject.toml b/pyproject.toml index 852cdfe6..c23d6622 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,6 +68,7 @@ testing = [ [project.scripts] # duplicates "python -m gunicorn" handling in __main__.py gunicorn = "gunicorn.app.wsgiapp:run" +gunicornc = "gunicorn.ctl.cli:main" # note the quotes around "paste.server_runner" to escape the dot [project.entry-points."paste.server_runner"] diff --git a/tests/ctl/__init__.py b/tests/ctl/__init__.py new file mode 100644 index 00000000..530e35ca --- /dev/null +++ b/tests/ctl/__init__.py @@ -0,0 +1,3 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. diff --git a/tests/ctl/test_client.py b/tests/ctl/test_client.py new file mode 100644 index 00000000..7f7b770a --- /dev/null +++ b/tests/ctl/test_client.py @@ -0,0 +1,275 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +"""Tests for control socket client.""" + +import os +import socket +import tempfile +import threading + +import pytest + +from gunicorn.ctl.client import ( + ControlClient, + ControlClientError, + parse_command, +) +from gunicorn.ctl.protocol import ControlProtocol, make_response + + +class TestControlClientInit: + """Tests for ControlClient initialization.""" + + def test_init_attributes(self): + """Test that client is initialized with correct attributes.""" + client = ControlClient("/tmp/test.sock", timeout=60.0) + + assert client.socket_path == "/tmp/test.sock" + assert client.timeout == 60.0 + assert client._sock is None + assert client._request_id == 0 + + +class TestControlClientConnect: + """Tests for ControlClient connection.""" + + def test_connect_nonexistent_socket(self): + """Test connecting to non-existent socket.""" + client = ControlClient("/nonexistent/socket.sock") + + with pytest.raises(ControlClientError) as exc_info: + client.connect() + + assert "Failed to connect" in str(exc_info.value) + + def test_connect_success(self): + """Test successful connection.""" + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "test.sock") + + # Create a listening socket + server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + server_sock.bind(socket_path) + server_sock.listen(1) + + try: + client = ControlClient(socket_path) + client.connect() + + assert client._sock is not None + client.close() + finally: + server_sock.close() + + def test_connect_already_connected(self): + """Test that connect is idempotent.""" + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "test.sock") + + server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + server_sock.bind(socket_path) + server_sock.listen(1) + + try: + client = ControlClient(socket_path) + client.connect() + first_sock = client._sock + client.connect() # Should not create new connection + + assert client._sock is first_sock + client.close() + finally: + server_sock.close() + + +class TestControlClientClose: + """Tests for ControlClient close.""" + + def test_close_idempotent(self): + """Test that close can be called multiple times.""" + client = ControlClient("/tmp/test.sock") + client.close() + client.close() # Should not raise + + def test_close_clears_socket(self): + """Test that close clears the socket.""" + client = ControlClient("/tmp/test.sock") + client._sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + client.close() + + assert client._sock is None + + +class TestControlClientContextManager: + """Tests for context manager functionality.""" + + def test_context_manager_connection_error(self): + """Test context manager with connection error.""" + client = ControlClient("/nonexistent/socket.sock") + + with pytest.raises(ControlClientError): + with client: + pass + + def test_context_manager_success(self): + """Test successful context manager usage.""" + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "test.sock") + + server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + server_sock.bind(socket_path) + server_sock.listen(1) + + try: + with ControlClient(socket_path) as client: + assert client._sock is not None + + # After context manager exits, socket should be closed + assert client._sock is None + finally: + server_sock.close() + + +class TestControlClientSendCommand: + """Tests for send_command functionality.""" + + def test_send_command_success(self): + """Test successful command send.""" + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "test.sock") + + server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + server_sock.bind(socket_path) + server_sock.listen(1) + + response_data = {"workers": [], "count": 0} + response_sent = threading.Event() + + def server_handler(): + conn, _ = server_sock.accept() + try: + msg = ControlProtocol.read_message(conn) + resp = make_response(msg["id"], response_data) + ControlProtocol.write_message(conn, resp) + response_sent.set() + finally: + conn.close() + + server_thread = threading.Thread(target=server_handler) + server_thread.start() + + try: + client = ControlClient(socket_path, timeout=5.0) + result = client.send_command("show workers") + + assert result == response_data + client.close() + finally: + response_sent.wait(timeout=2.0) + server_thread.join(timeout=2.0) + server_sock.close() + + def test_send_command_error_response(self): + """Test handling error response.""" + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "test.sock") + + server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + server_sock.bind(socket_path) + server_sock.listen(1) + + def server_handler(): + conn, _ = server_sock.accept() + try: + msg = ControlProtocol.read_message(conn) + resp = { + "id": msg["id"], + "status": "error", + "error": "Unknown command", + } + ControlProtocol.write_message(conn, resp) + finally: + conn.close() + + server_thread = threading.Thread(target=server_handler) + server_thread.start() + + try: + client = ControlClient(socket_path, timeout=5.0) + + with pytest.raises(ControlClientError) as exc_info: + client.send_command("invalid command") + + assert "Unknown command" in str(exc_info.value) + client.close() + finally: + server_thread.join(timeout=2.0) + server_sock.close() + + def test_send_command_auto_connect(self): + """Test that send_command auto-connects if not connected.""" + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "test.sock") + + server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + server_sock.bind(socket_path) + server_sock.listen(1) + + def server_handler(): + conn, _ = server_sock.accept() + try: + msg = ControlProtocol.read_message(conn) + resp = make_response(msg["id"], {}) + ControlProtocol.write_message(conn, resp) + finally: + conn.close() + + server_thread = threading.Thread(target=server_handler) + server_thread.start() + + try: + client = ControlClient(socket_path, timeout=5.0) + # Don't call connect() explicitly + result = client.send_command("help") + + assert isinstance(result, dict) + client.close() + finally: + server_thread.join(timeout=2.0) + server_sock.close() + + +class TestParseCommand: + """Tests for command parsing.""" + + def test_parse_simple_command(self): + """Test parsing simple command.""" + cmd, args = parse_command("show workers") + assert cmd == "show workers" + assert args == [] + + def test_parse_command_with_args(self): + """Test parsing command with arguments.""" + cmd, args = parse_command("worker add 2") + assert cmd == "worker add" + assert args == ["2"] + + def test_parse_command_with_multiple_args(self): + """Test parsing command with multiple arguments.""" + cmd, args = parse_command("worker kill 12345") + assert cmd == "worker kill" + assert args == ["12345"] + + def test_parse_empty_command(self): + """Test parsing empty command.""" + cmd, args = parse_command("") + assert cmd == "" + assert args == [] + + def test_parse_command_quoted(self): + """Test parsing command with quoted arguments.""" + cmd, args = parse_command('worker kill "12345"') + assert cmd == "worker kill" + assert args == ["12345"] diff --git a/tests/ctl/test_handlers.py b/tests/ctl/test_handlers.py new file mode 100644 index 00000000..f18f75ce --- /dev/null +++ b/tests/ctl/test_handlers.py @@ -0,0 +1,374 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +"""Tests for control socket command handlers.""" + +import signal +import time +from unittest.mock import MagicMock, patch + +import pytest + +from gunicorn.ctl.handlers import CommandHandlers + + +class MockWorker: + """Mock worker for testing.""" + + def __init__(self, pid, age, booted=True, aborted=False): + self.pid = pid + self.age = age + self.booted = booted + self.aborted = aborted + self.tmp = MagicMock() + self.tmp.last_update.return_value = time.monotonic() + + +class MockListener: + """Mock listener for testing.""" + + def __init__(self, address, fd=3): + self._address = address + self._fd = fd + self.sock = MagicMock() + self.sock.family = 2 # AF_INET + + def __str__(self): + return self._address + + def fileno(self): + return self._fd + + +class MockConfig: + """Mock config for testing.""" + + def __init__(self): + self.bind = ['127.0.0.1:8000'] + self.workers = 4 + self.worker_class = 'sync' + self.threads = 1 + self.timeout = 30 + self.graceful_timeout = 30 + self.keepalive = 2 + self.max_requests = 0 + self.max_requests_jitter = 0 + self.worker_connections = 1000 + self.preload_app = False + self.daemon = False + self.pidfile = None + self.proc_name = 'test_app' + self.reload = False + self.dirty_workers = 0 + self.dirty_apps = [] + self.dirty_timeout = 30 + self.control_socket = 'gunicorn.ctl' + self.control_socket_disable = False + + +class MockArbiter: + """Mock arbiter for testing.""" + + def __init__(self): + self.cfg = MockConfig() + self.pid = 12345 + self.WORKERS = {} + self.LISTENERS = [] + self.dirty_arbiter_pid = 0 + self.dirty_arbiter = None + self.num_workers = 4 + self._stats = { + 'start_time': time.time() - 3600, # 1 hour ago + 'workers_spawned': 10, + 'workers_killed': 5, + 'reloads': 2, + } + + def wakeup(self): + pass + + +class TestShowWorkers: + """Tests for show workers command.""" + + def test_show_workers_empty(self): + """Test showing workers when none exist.""" + arbiter = MockArbiter() + handlers = CommandHandlers(arbiter) + + result = handlers.show_workers() + + assert result["workers"] == [] + assert result["count"] == 0 + + def test_show_workers_with_workers(self): + """Test showing workers.""" + arbiter = MockArbiter() + arbiter.WORKERS = { + 1001: MockWorker(1001, 1), + 1002: MockWorker(1002, 2), + 1003: MockWorker(1003, 3), + } + handlers = CommandHandlers(arbiter) + + result = handlers.show_workers() + + assert result["count"] == 3 + assert len(result["workers"]) == 3 + + # Verify sorted by age + ages = [w["age"] for w in result["workers"]] + assert ages == sorted(ages) + + # Verify worker data + worker = result["workers"][0] + assert "pid" in worker + assert "age" in worker + assert "booted" in worker + assert "last_heartbeat" in worker + + +class TestShowStats: + """Tests for show stats command.""" + + def test_show_stats(self): + """Test showing stats.""" + arbiter = MockArbiter() + arbiter.WORKERS = { + 1001: MockWorker(1001, 1), + 1002: MockWorker(1002, 2), + } + handlers = CommandHandlers(arbiter) + + result = handlers.show_stats() + + assert result["pid"] == 12345 + assert result["workers_current"] == 2 + assert result["workers_target"] == 4 + assert result["workers_spawned"] == 10 + assert result["workers_killed"] == 5 + assert result["reloads"] == 2 + assert result["uptime"] is not None + assert result["uptime"] > 0 + + +class TestShowConfig: + """Tests for show config command.""" + + def test_show_config(self): + """Test showing config.""" + arbiter = MockArbiter() + handlers = CommandHandlers(arbiter) + + result = handlers.show_config() + + assert result["workers"] == 4 + assert result["timeout"] == 30 + assert result["bind"] == ['127.0.0.1:8000'] + + +class TestShowListeners: + """Tests for show listeners command.""" + + def test_show_listeners_empty(self): + """Test showing listeners when none exist.""" + arbiter = MockArbiter() + handlers = CommandHandlers(arbiter) + + result = handlers.show_listeners() + + assert result["listeners"] == [] + assert result["count"] == 0 + + def test_show_listeners(self): + """Test showing listeners.""" + arbiter = MockArbiter() + arbiter.LISTENERS = [ + MockListener("127.0.0.1:8000", fd=3), + MockListener("127.0.0.1:8001", fd=4), + ] + handlers = CommandHandlers(arbiter) + + result = handlers.show_listeners() + + assert result["count"] == 2 + assert len(result["listeners"]) == 2 + assert result["listeners"][0]["address"] == "127.0.0.1:8000" + + +class TestWorkerAdd: + """Tests for worker add command.""" + + def test_worker_add_default(self): + """Test adding one worker (default).""" + arbiter = MockArbiter() + arbiter.wakeup = MagicMock() + handlers = CommandHandlers(arbiter) + + result = handlers.worker_add() + + assert result["added"] == 1 + assert result["previous"] == 4 + assert result["total"] == 5 + assert arbiter.num_workers == 5 + arbiter.wakeup.assert_called_once() + + def test_worker_add_multiple(self): + """Test adding multiple workers.""" + arbiter = MockArbiter() + arbiter.wakeup = MagicMock() + handlers = CommandHandlers(arbiter) + + result = handlers.worker_add(3) + + assert result["added"] == 3 + assert result["total"] == 7 + + +class TestWorkerRemove: + """Tests for worker remove command.""" + + def test_worker_remove_default(self): + """Test removing one worker (default).""" + arbiter = MockArbiter() + arbiter.wakeup = MagicMock() + handlers = CommandHandlers(arbiter) + + result = handlers.worker_remove() + + assert result["removed"] == 1 + assert result["previous"] == 4 + assert result["total"] == 3 + assert arbiter.num_workers == 3 + arbiter.wakeup.assert_called_once() + + def test_worker_remove_cannot_go_below_one(self): + """Test that worker count cannot go below 1.""" + arbiter = MockArbiter() + arbiter.num_workers = 2 + arbiter.wakeup = MagicMock() + handlers = CommandHandlers(arbiter) + + result = handlers.worker_remove(5) + + assert result["removed"] == 1 + assert result["total"] == 1 + assert arbiter.num_workers == 1 + + +class TestWorkerKill: + """Tests for worker kill command.""" + + def test_worker_kill_success(self): + """Test killing a worker.""" + arbiter = MockArbiter() + arbiter.WORKERS = {1001: MockWorker(1001, 1)} + handlers = CommandHandlers(arbiter) + + with patch('os.kill') as mock_kill: + result = handlers.worker_kill(1001) + + assert result["success"] is True + assert result["killed"] == 1001 + mock_kill.assert_called_once_with(1001, signal.SIGTERM) + + def test_worker_kill_not_found(self): + """Test killing a non-existent worker.""" + arbiter = MockArbiter() + handlers = CommandHandlers(arbiter) + + result = handlers.worker_kill(9999) + + assert result["success"] is False + assert "not found" in result["error"] + + +class TestShowDirty: + """Tests for show dirty command.""" + + def test_show_dirty_disabled(self): + """Test showing dirty when disabled.""" + arbiter = MockArbiter() + handlers = CommandHandlers(arbiter) + + result = handlers.show_dirty() + + assert result["enabled"] is False + assert result["pid"] is None + + +class TestReload: + """Tests for reload command.""" + + def test_reload(self): + """Test reload command.""" + arbiter = MockArbiter() + handlers = CommandHandlers(arbiter) + + with patch('os.kill') as mock_kill: + result = handlers.reload() + + assert result["status"] == "reloading" + mock_kill.assert_called_once_with(12345, signal.SIGHUP) + + +class TestReopen: + """Tests for reopen command.""" + + def test_reopen(self): + """Test reopen command.""" + arbiter = MockArbiter() + handlers = CommandHandlers(arbiter) + + with patch('os.kill') as mock_kill: + result = handlers.reopen() + + assert result["status"] == "reopening" + mock_kill.assert_called_once_with(12345, signal.SIGUSR1) + + +class TestShutdown: + """Tests for shutdown command.""" + + def test_shutdown_graceful(self): + """Test graceful shutdown.""" + arbiter = MockArbiter() + handlers = CommandHandlers(arbiter) + + with patch('os.kill') as mock_kill: + result = handlers.shutdown() + + assert result["status"] == "shutting_down" + assert result["mode"] == "graceful" + mock_kill.assert_called_once_with(12345, signal.SIGTERM) + + def test_shutdown_quick(self): + """Test quick shutdown.""" + arbiter = MockArbiter() + handlers = CommandHandlers(arbiter) + + with patch('os.kill') as mock_kill: + result = handlers.shutdown("quick") + + assert result["status"] == "shutting_down" + assert result["mode"] == "quick" + mock_kill.assert_called_once_with(12345, signal.SIGINT) + + +class TestHelp: + """Tests for help command.""" + + def test_help(self): + """Test help command.""" + arbiter = MockArbiter() + handlers = CommandHandlers(arbiter) + + result = handlers.help() + + assert "commands" in result + commands = result["commands"] + assert "show workers" in commands + assert "worker add [N]" in commands + assert "reload" in commands + assert "shutdown [graceful|quick]" in commands diff --git a/tests/ctl/test_protocol.py b/tests/ctl/test_protocol.py new file mode 100644 index 00000000..899ea580 --- /dev/null +++ b/tests/ctl/test_protocol.py @@ -0,0 +1,249 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +"""Tests for control socket protocol.""" + +import json +import struct +import pytest + +from gunicorn.ctl.protocol import ( + ControlProtocol, + ProtocolError, + make_request, + make_response, + make_error_response, +) + + +class TestControlProtocolEncoding: + """Tests for message encoding/decoding.""" + + def test_encode_message_simple(self): + """Test encoding a simple message.""" + data = {"command": "test"} + result = ControlProtocol.encode_message(data) + + # First 4 bytes are length + length = struct.unpack('>I', result[:4])[0] + payload = result[4:] + + assert length == len(payload) + assert json.loads(payload.decode('utf-8')) == data + + def test_encode_message_unicode(self): + """Test encoding message with unicode characters.""" + data = {"message": "Hello \u4e16\u754c"} + result = ControlProtocol.encode_message(data) + + length = struct.unpack('>I', result[:4])[0] + payload = result[4:] + + assert length == len(payload) + assert json.loads(payload.decode('utf-8')) == data + + def test_decode_message_simple(self): + """Test decoding a simple message.""" + data = {"command": "test", "args": [1, 2, 3]} + payload = json.dumps(data).encode('utf-8') + length = struct.pack('>I', len(payload)) + raw = length + payload + + result = ControlProtocol.decode_message(raw) + assert result == data + + def test_decode_message_too_short(self): + """Test decoding message that's too short.""" + with pytest.raises(ProtocolError) as exc_info: + ControlProtocol.decode_message(b'\x00\x00') + assert "too short" in str(exc_info.value) + + def test_decode_message_incomplete(self): + """Test decoding incomplete message.""" + # Length says 100 bytes but only 4 bytes provided + raw = struct.pack('>I', 100) + b'test' + with pytest.raises(ProtocolError) as exc_info: + ControlProtocol.decode_message(raw) + assert "Incomplete" in str(exc_info.value) + + def test_roundtrip(self): + """Test encode/decode roundtrip.""" + original = { + "id": 42, + "command": "show workers", + "args": ["arg1", 123, True, None], + "nested": {"a": 1, "b": [1, 2, 3]}, + } + + encoded = ControlProtocol.encode_message(original) + decoded = ControlProtocol.decode_message(encoded) + + assert decoded == original + + +class TestMakeRequest: + """Tests for request creation.""" + + def test_make_request_simple(self): + """Test creating a simple request.""" + result = make_request(1, "show workers") + + assert result["id"] == 1 + assert result["command"] == "show workers" + assert result["args"] == [] + + def test_make_request_with_args(self): + """Test creating a request with arguments.""" + result = make_request(42, "worker add", [2]) + + assert result["id"] == 42 + assert result["command"] == "worker add" + assert result["args"] == [2] + + +class TestMakeResponse: + """Tests for response creation.""" + + def test_make_response_simple(self): + """Test creating a simple response.""" + result = make_response(1, {"count": 5}) + + assert result["id"] == 1 + assert result["status"] == "ok" + assert result["data"] == {"count": 5} + + def test_make_response_empty_data(self): + """Test creating response with no data.""" + result = make_response(1) + + assert result["id"] == 1 + assert result["status"] == "ok" + assert result["data"] == {} + + +class TestMakeErrorResponse: + """Tests for error response creation.""" + + def test_make_error_response(self): + """Test creating an error response.""" + result = make_error_response(1, "Unknown command") + + assert result["id"] == 1 + assert result["status"] == "error" + assert result["error"] == "Unknown command" + + +class TestControlProtocolSocket: + """Tests for socket reading/writing.""" + + def test_read_write_message(self): + """Test read/write through socket pair.""" + import socket + import threading + + data = {"id": 1, "command": "test"} + received = [] + + # Create socket pair + server, client = socket.socketpair() + + def reader(): + received.append(ControlProtocol.read_message(server)) + + t = threading.Thread(target=reader) + t.start() + + ControlProtocol.write_message(client, data) + t.join(timeout=2.0) + + client.close() + server.close() + + assert len(received) == 1 + assert received[0] == data + + def test_read_connection_closed(self): + """Test reading from closed connection.""" + import socket + + server, client = socket.socketpair() + client.close() + + with pytest.raises(ConnectionError): + ControlProtocol.read_message(server) + + server.close() + + def test_read_message_too_large(self): + """Test reading message exceeding max size.""" + import socket + + server, client = socket.socketpair() + + # Send a length that exceeds MAX_MESSAGE_SIZE + huge_length = ControlProtocol.MAX_MESSAGE_SIZE + 1 + client.send(struct.pack('>I', huge_length)) + + with pytest.raises(ProtocolError) as exc_info: + ControlProtocol.read_message(server) + assert "too large" in str(exc_info.value) + + client.close() + server.close() + + +class TestControlProtocolAsync: + """Tests for async protocol methods.""" + + @pytest.mark.asyncio + async def test_async_read_write(self): + """Test async read/write using a unix server.""" + import asyncio + import tempfile + import os + + data = {"id": 1, "command": "async test"} + received = [] + + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "test.sock") + + async def handler(reader, writer): + msg = await ControlProtocol.read_message_async(reader) + received.append(msg) + await ControlProtocol.write_message_async(writer, data) + writer.close() + await writer.wait_closed() + + server = await asyncio.start_unix_server(handler, path=socket_path) + + async with server: + reader, writer = await asyncio.open_unix_connection(socket_path) + await ControlProtocol.write_message_async(writer, data) + response = await ControlProtocol.read_message_async(reader) + writer.close() + await writer.wait_closed() + + assert len(received) == 1 + assert received[0] == data + assert response == data + + +class TestProtocolMaxSize: + """Tests for protocol size limits.""" + + def test_max_message_size_constant(self): + """Test that MAX_MESSAGE_SIZE is set to a reasonable value.""" + # Should be 16 MB + assert ControlProtocol.MAX_MESSAGE_SIZE == 16 * 1024 * 1024 + + def test_encode_large_message(self): + """Test encoding a large (but valid) message.""" + # Create a message with ~1MB of data + data = {"data": "x" * (1024 * 1024)} + encoded = ControlProtocol.encode_message(data) + + # Should succeed and be decodable + decoded = ControlProtocol.decode_message(encoded) + assert decoded == data diff --git a/tests/ctl/test_server.py b/tests/ctl/test_server.py new file mode 100644 index 00000000..d52ae2ee --- /dev/null +++ b/tests/ctl/test_server.py @@ -0,0 +1,348 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +"""Tests for control socket server.""" + +import os +import tempfile +import time +from unittest.mock import MagicMock + +import pytest + +from gunicorn.ctl.server import ControlSocketServer +from gunicorn.ctl.client import ControlClient + + +class MockWorker: + """Mock worker for testing.""" + + def __init__(self, pid, age, booted=True, aborted=False): + self.pid = pid + self.age = age + self.booted = booted + self.aborted = aborted + self.tmp = MagicMock() + self.tmp.last_update.return_value = time.monotonic() + + +class MockConfig: + """Mock config for testing.""" + + def __init__(self): + self.bind = ['127.0.0.1:8000'] + self.workers = 4 + self.worker_class = 'sync' + self.threads = 1 + self.timeout = 30 + self.graceful_timeout = 30 + self.keepalive = 2 + self.max_requests = 0 + self.max_requests_jitter = 0 + self.worker_connections = 1000 + self.preload_app = False + self.daemon = False + self.pidfile = None + self.proc_name = 'test_app' + self.reload = False + self.dirty_workers = 0 + self.dirty_apps = [] + self.dirty_timeout = 30 + self.control_socket = 'gunicorn.ctl' + self.control_socket_disable = False + + +class MockLog: + """Mock logger for testing.""" + + def debug(self, msg, *args): + pass + + def info(self, msg, *args): + pass + + def warning(self, msg, *args): + pass + + def error(self, msg, *args): + pass + + def exception(self, msg, *args): + pass + + +class MockArbiter: + """Mock arbiter for testing.""" + + def __init__(self): + self.cfg = MockConfig() + self.log = MockLog() + self.pid = 12345 + self.WORKERS = {} + self.LISTENERS = [] + self.dirty_arbiter_pid = 0 + self.dirty_arbiter = None + self.num_workers = 4 + self._stats = { + 'start_time': time.time() - 3600, + 'workers_spawned': 10, + 'workers_killed': 5, + 'reloads': 2, + } + + def wakeup(self): + pass + + +class TestControlSocketServerInit: + """Tests for server initialization.""" + + def test_init(self): + """Test server initialization.""" + arbiter = MockArbiter() + server = ControlSocketServer(arbiter, "/tmp/test.sock", 0o600) + + assert server.arbiter is arbiter + assert server.socket_path == "/tmp/test.sock" + assert server.socket_mode == 0o600 + assert server._running is False + + +class TestControlSocketServerLifecycle: + """Tests for server start/stop.""" + + def test_start_stop(self): + """Test starting and stopping the server.""" + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "test.sock") + + arbiter = MockArbiter() + server = ControlSocketServer(arbiter, socket_path) + + server.start() + + # Wait for server to start + for _ in range(20): + if os.path.exists(socket_path): + break + time.sleep(0.1) + + assert os.path.exists(socket_path) + + server.stop() + + # Wait for cleanup + time.sleep(0.2) + + # Socket should be cleaned up + assert not os.path.exists(socket_path) + + def test_start_already_running(self): + """Test that start is idempotent.""" + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "test.sock") + + arbiter = MockArbiter() + server = ControlSocketServer(arbiter, socket_path) + + server.start() + first_thread = server._thread + server.start() + + assert server._thread is first_thread + + server.stop() + + def test_stop_not_running(self): + """Test stopping a non-running server.""" + arbiter = MockArbiter() + server = ControlSocketServer(arbiter, "/tmp/test.sock") + + # Should not raise + server.stop() + + +class TestControlSocketServerIntegration: + """Integration tests for server with client.""" + + def test_show_workers(self): + """Test show workers command.""" + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "test.sock") + + arbiter = MockArbiter() + arbiter.WORKERS = { + 1001: MockWorker(1001, 1), + 1002: MockWorker(1002, 2), + } + server = ControlSocketServer(arbiter, socket_path) + + server.start() + + # Wait for server to start + for _ in range(20): + if os.path.exists(socket_path): + break + time.sleep(0.1) + + try: + with ControlClient(socket_path, timeout=5.0) as client: + result = client.send_command("show workers") + + assert result["count"] == 2 + assert len(result["workers"]) == 2 + finally: + server.stop() + + def test_show_stats(self): + """Test show stats command.""" + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "test.sock") + + arbiter = MockArbiter() + server = ControlSocketServer(arbiter, socket_path) + + server.start() + + for _ in range(20): + if os.path.exists(socket_path): + break + time.sleep(0.1) + + try: + with ControlClient(socket_path, timeout=5.0) as client: + result = client.send_command("show stats") + + assert result["pid"] == 12345 + assert result["workers_spawned"] == 10 + finally: + server.stop() + + def test_help_command(self): + """Test help command.""" + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "test.sock") + + arbiter = MockArbiter() + server = ControlSocketServer(arbiter, socket_path) + + server.start() + + for _ in range(20): + if os.path.exists(socket_path): + break + time.sleep(0.1) + + try: + with ControlClient(socket_path, timeout=5.0) as client: + result = client.send_command("help") + + assert "commands" in result + assert "show workers" in result["commands"] + finally: + server.stop() + + def test_worker_add(self): + """Test worker add command.""" + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "test.sock") + + arbiter = MockArbiter() + arbiter.wakeup = MagicMock() + server = ControlSocketServer(arbiter, socket_path) + + server.start() + + for _ in range(20): + if os.path.exists(socket_path): + break + time.sleep(0.1) + + try: + with ControlClient(socket_path, timeout=5.0) as client: + result = client.send_command("worker add 2") + + assert result["added"] == 2 + assert result["total"] == 6 + assert arbiter.num_workers == 6 + arbiter.wakeup.assert_called() + finally: + server.stop() + + def test_invalid_command(self): + """Test handling invalid command.""" + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "test.sock") + + arbiter = MockArbiter() + server = ControlSocketServer(arbiter, socket_path) + + server.start() + + for _ in range(20): + if os.path.exists(socket_path): + break + time.sleep(0.1) + + try: + with ControlClient(socket_path, timeout=5.0) as client: + with pytest.raises(Exception) as exc_info: + client.send_command("invalid_command") + + assert "Unknown command" in str(exc_info.value) + finally: + server.stop() + + def test_multiple_commands(self): + """Test sending multiple commands on same connection.""" + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "test.sock") + + arbiter = MockArbiter() + arbiter.WORKERS = {1001: MockWorker(1001, 1)} + server = ControlSocketServer(arbiter, socket_path) + + server.start() + + for _ in range(20): + if os.path.exists(socket_path): + break + time.sleep(0.1) + + try: + with ControlClient(socket_path, timeout=5.0) as client: + result1 = client.send_command("show workers") + result2 = client.send_command("show stats") + result3 = client.send_command("help") + + assert result1["count"] == 1 + assert result2["pid"] == 12345 + assert "commands" in result3 + finally: + server.stop() + + +class TestControlSocketServerPermissions: + """Tests for socket permissions.""" + + def test_socket_permissions(self): + """Test that socket is created with correct permissions.""" + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "test.sock") + + arbiter = MockArbiter() + server = ControlSocketServer(arbiter, socket_path, 0o660) + + server.start() + + for _ in range(20): + if os.path.exists(socket_path): + break + time.sleep(0.1) + + try: + mode = os.stat(socket_path).st_mode & 0o777 + assert mode == 0o660 + finally: + server.stop()