diff --git a/docs/content/2026-news.md b/docs/content/2026-news.md index 59a3dcd6..56476e9c 100644 --- a/docs/content/2026-news.md +++ b/docs/content/2026-news.md @@ -1,6 +1,23 @@ # Changelog - 2026 +## 25.2.0 - 2026-02-13 + +### New Features + +- **Control Interface (gunicornc)**: Add interactive control interface for managing + running Gunicorn instances, similar to birdc for BIRD routing daemon + - Unix socket-based communication with JSON protocol + - Interactive mode with readline support and command history + - Commands: `show all/workers/dirty/config/stats/listeners` + - Worker management: `worker add/remove/kill`, `dirty add/remove` + - Server control: `reload`, `reopen`, `shutdown` + - New settings: `--control-socket`, `--control-socket-mode`, `--no-control-socket` + - New CLI tool: `gunicornc` for connecting to control socket + - See [Control Interface Guide](guides/gunicornc.md) for details + +--- + ## 25.1.0 - 2026-02-12 ### New Features diff --git a/docs/content/guides/gunicornc.md b/docs/content/guides/gunicornc.md new file mode 100644 index 00000000..012f0791 --- /dev/null +++ b/docs/content/guides/gunicornc.md @@ -0,0 +1,306 @@ +--- +title: Control Interface (gunicornc) +menu: + guides: + weight: 15 +--- + +# Control Interface (gunicornc) + +Gunicorn provides a control interface similar to [birdc](https://bird.network.cz/?get_doc&v=20&f=bird-3.html) for the BIRD routing daemon. This allows you to inspect and manage a running Gunicorn instance via a Unix socket. + +## Overview + +The control interface consists of two parts: + +1. **Control Socket Server** - Runs in the arbiter process, accepts commands via Unix socket +2. **gunicornc CLI** - Interactive client that connects to the control socket + +## Quick Start + +### Start Gunicorn with Control Socket + +By default, Gunicorn creates a control socket at `gunicorn.ctl` in the current directory: + +```bash +gunicorn -w 4 myapp:app +``` + +Or specify a custom path: + +```bash +gunicorn --control-socket /tmp/myapp.ctl -w 4 myapp:app +``` + +### Connect with gunicornc + +```bash +# Connect to default socket (./gunicorn.ctl) +gunicornc + +# Connect to custom socket +gunicornc -s /tmp/myapp.ctl + +# Run a single command +gunicornc -c "show workers" + +# Output as JSON (for scripting) +gunicornc -c "show stats" -j +``` + +## Interactive Mode + +When run without the `-c` flag, gunicornc enters interactive mode with readline support: + +``` +$ gunicornc +Connected to gunicorn.ctl +Type 'help' for available commands, 'quit' to exit. + +gunicorn> show workers +PID AGE BOOTED LAST_BEAT +---------------------------------------- +12345 1 yes 0.2s ago +12346 2 yes 0.1s ago +12347 3 yes 0.3s ago + +Total: 3 workers + +gunicorn> worker add 2 +{ + "added": 2, + "previous": 3, + "total": 5 +} + +gunicorn> quit +``` + +## Commands + +### Show Commands + +| Command | Description | +|---------|-------------| +| `show all` | Overview of all processes (arbiter, web workers, dirty workers) | +| `show workers` | List HTTP workers with status | +| `show dirty` | List dirty workers and apps | +| `show config` | Show current effective configuration | +| `show stats` | Show server statistics | +| `show listeners` | Show bound sockets | +| `help` | Show available commands | + +### Worker Management + +| Command | Description | +|---------|-------------| +| `worker add [N]` | Spawn N workers (default 1) | +| `worker remove [N]` | Remove N workers (default 1) | +| `worker kill ` | Gracefully terminate specific worker | + +### Dirty Worker Management + +| Command | Description | +|---------|-------------| +| `dirty add [N]` | Spawn N dirty workers (default 1) | +| `dirty remove [N]` | Remove N dirty workers (default 1) | + +!!! note "Per-App Worker Limits" + When using `dirty add`, workers only load apps that haven't reached their + worker limits. If all apps are at their limits, no new workers will be spawned. + The response will include a `reason` field explaining this. + +### Server Control + +| Command | Description | +|---------|-------------| +| `reload` | Graceful reload (equivalent to SIGHUP) | +| `reopen` | Reopen log files (equivalent to SIGUSR1) | +| `shutdown [graceful\|quick]` | Shutdown server (SIGTERM or SIGINT) | + +## Example Session + +``` +$ gunicornc +Connected to gunicorn.ctl +Type 'help' for available commands, 'quit' to exit. + +gunicorn> show all +ARBITER (master) + PID: 12345 + +WEB WORKERS (4) + PID AGE BOOTED LAST_BEAT + -------------------------------------- + 12346 1 yes 0.05s ago + 12347 2 yes 0.04s ago + 12348 3 yes 0.03s ago + 12349 4 yes 0.02s ago + +DIRTY ARBITER + PID: 12350 + +DIRTY WORKERS (2) + PID AGE APPS + -------------------------------------------------- + 12351 1 MLModel + ImageProcessor + 12352 2 MLModel + +gunicorn> show stats +Uptime: 2h 15m 30s +PID: 12345 +Workers current: 4 +Workers target: 4 +Workers spawned: 6 +Workers killed: 2 +Reloads: 1 + +gunicorn> worker add +{ + "added": 1, + "previous": 4, + "total": 5 +} + +gunicorn> dirty add 1 +{ + "success": true, + "operation": "add", + "requested": 1, + "spawned": 1, + "total_workers": 3, + "target_workers": 3 +} + +gunicorn> quit +``` + +## Configuration + +### Settings + +| Setting | CLI Flag | Default | Description | +|---------|----------|---------|-------------| +| `control_socket` | `--control-socket` | `gunicorn.ctl` | Unix socket path | +| `control_socket_mode` | `--control-socket-mode` | `0o600` | Socket file permissions | +| `control_socket_disable` | `--no-control-socket` | `False` | Disable control socket | + +### Example Configuration + +```python +# gunicorn.conf.py +bind = "0.0.0.0:8000" +workers = 4 + +# Control socket settings +control_socket = "/var/run/gunicorn/myapp.ctl" +control_socket_mode = 0o660 # Allow group access +``` + +## Scripting + +Use the `-j` flag for JSON output when scripting: + +```bash +#!/bin/bash + +# Get current worker count +workers=$(gunicornc -c "show stats" -j | jq -r '.workers_current') +echo "Current workers: $workers" + +# Scale up if needed +if [ "$workers" -lt 8 ]; then + gunicornc -c "worker add $((8 - workers))" +fi +``` + +## Security + +The control socket uses filesystem permissions for access control: + +- **Default mode**: `0o600` (owner only) +- **No authentication**: Relies on filesystem permissions +- **Unix socket only**: No TCP/remote access + +To allow group access: + +```python +control_socket_mode = 0o660 +``` + +To disable the control socket entirely: + +```bash +gunicorn --no-control-socket myapp:app +``` + +## Protocol + +The control interface uses a JSON-based protocol with length-prefixed framing: + +``` ++----------------+------------------+ +| Length (4B BE) | JSON Payload | ++----------------+------------------+ +``` + +### Request Format + +```json +{ + "id": 1, + "command": "show workers" +} +``` + +### Response Format + +```json +{ + "id": 1, + "status": "ok", + "data": { ... } +} +``` + +### Error Response + +```json +{ + "id": 1, + "status": "error", + "error": "Unknown command: foo" +} +``` + +## Troubleshooting + +### Cannot connect to socket + +``` +Error: Connection refused +``` + +- Check that Gunicorn is running +- Verify the socket path is correct +- Check socket file permissions + +### Permission denied + +``` +Error: Permission denied +``` + +- Check that you have read/write access to the socket file +- The socket is created with mode `0o600` by default (owner only) + +### Socket not found + +``` +Error: No such file or directory +``` + +- Gunicorn creates the socket relative to the working directory by default +- Use an absolute path with `--control-socket /path/to/socket.ctl` +- Check if `--no-control-socket` was specified diff --git a/docs/content/news.md b/docs/content/news.md index 8125086e..15592160 100644 --- a/docs/content/news.md +++ b/docs/content/news.md @@ -1,6 +1,23 @@ # Changelog +## 25.2.0 - 2026-02-13 + +### New Features + +- **Control Interface (gunicornc)**: Add interactive control interface for managing + running Gunicorn instances, similar to birdc for BIRD routing daemon + - Unix socket-based communication with JSON protocol + - Interactive mode with readline support and command history + - Commands: `show all/workers/dirty/config/stats/listeners` + - Worker management: `worker add/remove/kill`, `dirty add/remove` + - Server control: `reload`, `reopen`, `shutdown` + - New settings: `--control-socket`, `--control-socket-mode`, `--no-control-socket` + - New CLI tool: `gunicornc` for connecting to control socket + - See [Control Interface Guide](guides/gunicornc.md) for details + +--- + ## 25.1.0 - 2026-02-12 ### New Features diff --git a/docs/content/reference/settings.md b/docs/content/reference/settings.md index 18023593..bd9aaa3f 100644 --- a/docs/content/reference/settings.md +++ b/docs/content/reference/settings.md @@ -48,6 +48,53 @@ A WSGI application path in pattern ``$(MODULE_NAME):$(VARIABLE_NAME)``. !!! info "Added in 20.1.0" +## Control + +### `control_socket` + +**Command line:** `--control-socket PATH` + +**Default:** `'gunicorn.ctl'` + +Unix socket path for control interface. + +The control socket allows runtime management of Gunicorn via the +``gunicornc`` command-line tool. Commands include viewing worker +status, adjusting worker count, and graceful reload/shutdown. + +By default, creates ``gunicorn.ctl`` in the working directory. +Set an absolute path for a fixed location (e.g., ``/var/run/gunicorn.ctl``). + +Use ``--no-control-socket`` to disable. + +!!! info "Added in 25.1.0" + +### `control_socket_mode` + +**Command line:** `--control-socket-mode INT` + +**Default:** `384` + +Permission mode for control socket. + +Restricts who can connect to the control socket. Default ``0600`` +allows only the socket owner. Set to ``0660`` to allow group access. + +!!! info "Added in 25.1.0" + +### `control_socket_disable` + +**Command line:** `--no-control-socket` + +**Default:** `False` + +Disable control socket. + +When set, no control socket is created and ``gunicornc`` cannot +connect to this Gunicorn instance. + +!!! info "Added in 25.1.0" + ## Debugging ### `reload` diff --git a/gunicorn/arbiter.py b/gunicorn/arbiter.py index 6200ac3a..cd838d00 100644 --- a/gunicorn/arbiter.py +++ b/gunicorn/arbiter.py @@ -74,6 +74,17 @@ class Arbiter: self.dirty_arbiter = None self.dirty_pidfile = None # Well-known location for orphan detection + # Control socket server + self._control_server = None + + # Stats tracking + self._stats = { + 'start_time': None, + 'workers_spawned': 0, + 'workers_killed': 0, + 'reloads': 0, + } + cwd = util.getcwd() args = sys.argv[:] @@ -133,6 +144,9 @@ class Arbiter: """ self.log.info("Starting gunicorn %s", __version__) + # Initialize stats tracking + self._stats['start_time'] = time.time() + if 'GUNICORN_PID' in os.environ: self.master_pid = int(os.environ.get('GUNICORN_PID')) self.proc_name = self.proc_name + ".2" @@ -179,6 +193,9 @@ class Arbiter: if self.cfg.dirty_workers > 0 and self.cfg.dirty_apps: self.spawn_dirty_arbiter() + # Start control socket server + self._start_control_server() + self.cfg.when_ready(self) def init_signals(self): @@ -351,6 +368,9 @@ class Arbiter: def halt(self, reason=None, exit_status=0): """ halt arbiter """ + # Stop control socket server first + self._stop_control_server() + self.stop() log_func = self.log.info if exit_status == 0 else self.log.error @@ -477,6 +497,9 @@ class Arbiter: os.execvpe(self.START_CTX[0], self.START_CTX['args'], environ) def reload(self): + # Track reload stats + self._stats['reloads'] += 1 + old_address = self.cfg.address # reset old environment @@ -667,6 +690,7 @@ class Arbiter: if pid != 0: worker.pid = pid self.WORKERS[pid] = worker + self._stats['workers_spawned'] += 1 return pid # Do not inherit the temporary files of other workers @@ -737,6 +761,9 @@ class Arbiter: """ try: os.kill(pid, sig) + # Track kills only on SIGTERM/SIGKILL (actual termination signals) + if sig in (signal.SIGTERM, signal.SIGKILL): + self._stats['workers_killed'] += 1 except OSError as e: if e.errno == errno.ESRCH: try: @@ -906,3 +933,51 @@ class Arbiter: if self.cfg.dirty_workers > 0 and self.cfg.dirty_apps: self.log.info("Spawning dirty arbiter...") self.spawn_dirty_arbiter() + + # ========================================================================= + # Control Socket Management + # ========================================================================= + + def _get_control_socket_path(self): + """Get the control socket path, making relative paths absolute.""" + socket_path = self.cfg.control_socket + if not os.path.isabs(socket_path): + socket_path = os.path.join(util.getcwd(), socket_path) + return socket_path + + def _start_control_server(self): + """\ + Start the control socket server. + + The server runs in a background thread and accepts commands + via Unix socket. + """ + if self.cfg.control_socket_disable: + self.log.debug("Control socket disabled") + return + + # Lazy import to avoid circular imports and gevent compatibility + from gunicorn.ctl.server import ControlSocketServer + + socket_path = self._get_control_socket_path() + socket_mode = self.cfg.control_socket_mode + + try: + self._control_server = ControlSocketServer( + self, socket_path, socket_mode + ) + self._control_server.start() + except Exception as e: + self.log.warning("Failed to start control socket: %s", e) + self._control_server = None + + def _stop_control_server(self): + """\ + Stop the control socket server. + """ + if self._control_server: + try: + self._control_server.stop() + except Exception as e: + self.log.debug("Error stopping control server: %s", e) + self._control_server = None diff --git a/gunicorn/config.py b/gunicorn/config.py index c391ae41..fa399f03 100644 --- a/gunicorn/config.py +++ b/gunicorn/config.py @@ -3113,3 +3113,63 @@ class DirtyWorkerExit(Setting): .. versionadded:: 25.0.0 """ + + +# Control Socket Settings + +class ControlSocket(Setting): + name = "control_socket" + section = "Control" + cli = ["--control-socket"] + meta = "PATH" + validator = validate_string + default = "gunicorn.ctl" + desc = """\ + Unix socket path for control interface. + + The control socket allows runtime management of Gunicorn via the + ``gunicornc`` command-line tool. Commands include viewing worker + status, adjusting worker count, and graceful reload/shutdown. + + By default, creates ``gunicorn.ctl`` in the working directory. + Set an absolute path for a fixed location (e.g., ``/var/run/gunicorn.ctl``). + + Use ``--no-control-socket`` to disable. + + .. versionadded:: 25.1.0 + """ + + +class ControlSocketMode(Setting): + name = "control_socket_mode" + section = "Control" + cli = ["--control-socket-mode"] + meta = "INT" + validator = validate_pos_int + type = auto_int + default = 0o600 + desc = """\ + Permission mode for control socket. + + Restricts who can connect to the control socket. Default ``0600`` + allows only the socket owner. Set to ``0660`` to allow group access. + + .. versionadded:: 25.1.0 + """ + + +class ControlSocketDisable(Setting): + name = "control_socket_disable" + section = "Control" + cli = ["--no-control-socket"] + validator = validate_bool + action = "store_true" + default = False + desc = """\ + Disable control socket. + + When set, no control socket is created and ``gunicornc`` cannot + connect to this Gunicorn instance. + + .. versionadded:: 25.1.0 + """ diff --git a/gunicorn/ctl/__init__.py b/gunicorn/ctl/__init__.py new file mode 100644 index 00000000..968d9d31 --- /dev/null +++ b/gunicorn/ctl/__init__.py @@ -0,0 +1,16 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +Gunicorn Control Interface + +Provides a control socket server for runtime management and +a CLI client (gunicornc) for interacting with running Gunicorn instances. +""" + +from gunicorn.ctl.server import ControlSocketServer +from gunicorn.ctl.client import ControlClient +from gunicorn.ctl.protocol import ControlProtocol + +__all__ = ['ControlSocketServer', 'ControlClient', 'ControlProtocol'] diff --git a/gunicorn/ctl/cli.py b/gunicorn/ctl/cli.py new file mode 100644 index 00000000..6e8f0783 --- /dev/null +++ b/gunicorn/ctl/cli.py @@ -0,0 +1,449 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +gunicornc - Gunicorn control interface CLI + +Interactive and single-command modes for controlling Gunicorn instances. +""" + +import argparse +import json +import os +import sys + +from gunicorn.ctl.client import ControlClient, ControlClientError, parse_command + + +def format_workers(data: dict) -> str: + """Format workers output for display.""" + workers = data.get("workers", []) + if not workers: + return "No workers running" + + lines = [] + lines.append(f"{'PID':<10} {'AGE':<6} {'BOOTED':<8} {'LAST_BEAT'}") + lines.append("-" * 40) + + for w in workers: + pid = w.get("pid", "?") + age = w.get("age", "?") + booted = "yes" if w.get("booted") else "no" + hb = w.get("last_heartbeat") + hb_str = f"{hb}s ago" if hb is not None else "n/a" + + lines.append(f"{pid:<10} {age:<6} {booted:<8} {hb_str}") + + lines.append("") + lines.append(f"Total: {data.get('count', len(workers))} workers") + + return "\n".join(lines) + + +def format_dirty(data: dict) -> str: + """Format dirty workers output for display.""" + if not data.get("enabled"): + return "Dirty arbiter not running" + + lines = [] + lines.append(f"Dirty arbiter PID: {data.get('pid')}") + lines.append("") + + workers = data.get("workers", []) + if workers: + lines.append("DIRTY WORKERS:") + lines.append(f"{'PID':<10} {'AGE':<6} {'APPS':<30} {'LAST_BEAT'}") + lines.append("-" * 60) + + for w in workers: + pid = w.get("pid", "?") + age = w.get("age", "?") + apps = ", ".join(w.get("apps", []))[:30] + hb = w.get("last_heartbeat") + hb_str = f"{hb}s ago" if hb is not None else "n/a" + + lines.append(f"{pid:<10} {age:<6} {apps:<30} {hb_str}") + lines.append("") + + apps = data.get("apps", []) + if apps: + lines.append("DIRTY APPS:") + lines.append(f"{'APP':<30} {'WORKERS':<10} {'LIMIT'}") + lines.append("-" * 50) + + for app in apps: + path = app.get("import_path", "?")[:30] + current = app.get("current_workers", 0) + limit = app.get("worker_count") + limit_str = str(limit) if limit is not None else "none" + + lines.append(f"{path:<30} {current:<10} {limit_str}") + + return "\n".join(lines) + + +def format_stats(data: dict) -> str: + """Format stats output for display.""" + lines = [] + + uptime = data.get("uptime") + if uptime: + hours = int(uptime // 3600) + minutes = int((uptime % 3600) // 60) + seconds = int(uptime % 60) + if hours: + uptime_str = f"{hours}h {minutes}m {seconds}s" + elif minutes: + uptime_str = f"{minutes}m {seconds}s" + else: + uptime_str = f"{seconds}s" + else: + uptime_str = "unknown" + + lines.append(f"Uptime: {uptime_str}") + lines.append(f"PID: {data.get('pid', 'unknown')}") + lines.append(f"Workers current: {data.get('workers_current', 0)}") + lines.append(f"Workers target: {data.get('workers_target', 0)}") + lines.append(f"Workers spawned: {data.get('workers_spawned', 0)}") + lines.append(f"Workers killed: {data.get('workers_killed', 0)}") + lines.append(f"Reloads: {data.get('reloads', 0)}") + + dirty_pid = data.get("dirty_arbiter_pid") + if dirty_pid: + lines.append(f"Dirty arbiter: {dirty_pid}") + + return "\n".join(lines) + + +def format_listeners(data: dict) -> str: + """Format listeners output for display.""" + listeners = data.get("listeners", []) + if not listeners: + return "No listeners bound" + + lines = [] + lines.append(f"{'ADDRESS':<40} {'TYPE':<8} {'FD'}") + lines.append("-" * 55) + + for lnr in listeners: + addr = lnr.get("address", "?") + ltype = lnr.get("type", "?") + fd = lnr.get("fd", "?") + lines.append(f"{addr:<40} {ltype:<8} {fd}") + + lines.append("") + lines.append(f"Total: {data.get('count', len(listeners))} listeners") + + return "\n".join(lines) + + +def format_config(data: dict) -> str: + """Format config output for display.""" + lines = [] + + # Sort keys for consistent output + for key in sorted(data.keys()): + value = data[key] + if isinstance(value, list): + value = ", ".join(str(v) for v in value) + lines.append(f"{key}: {value}") + + return "\n".join(lines) + + +def format_help(data: dict) -> str: + """Format help output for display.""" + commands = data.get("commands", {}) + lines = [] + lines.append("Available commands:") + lines.append("") + + # Find max command length for alignment + max_len = max(len(cmd) for cmd in commands.keys()) if commands else 0 + + for cmd, desc in sorted(commands.items()): + lines.append(f" {cmd:<{max_len + 2}} {desc}") + + return "\n".join(lines) + + +def format_all(data: dict) -> str: + """Format show all output for display.""" + lines = [] + + # Arbiter + arbiter = data.get("arbiter", {}) + lines.append("ARBITER (master)") + lines.append(f" PID: {arbiter.get('pid', '?')}") + lines.append("") + + # Web workers + web_workers = data.get("web_workers", []) + lines.append(f"WEB WORKERS ({data.get('web_worker_count', 0)})") + if web_workers: + lines.append(f" {'PID':<10} {'AGE':<6} {'BOOTED':<8} {'LAST_BEAT'}") + lines.append(f" {'-' * 38}") + for w in web_workers: + pid = w.get("pid", "?") + age = w.get("age", "?") + booted = "yes" if w.get("booted") else "no" + hb = w.get("last_heartbeat") + hb_str = f"{hb}s ago" if hb is not None else "n/a" + lines.append(f" {pid:<10} {age:<6} {booted:<8} {hb_str}") + else: + lines.append(" (none)") + lines.append("") + + # Dirty arbiter + dirty_arbiter = data.get("dirty_arbiter") + if dirty_arbiter: + lines.append("DIRTY ARBITER") + lines.append(f" PID: {dirty_arbiter.get('pid', '?')}") + lines.append("") + + # Dirty workers + dirty_workers = data.get("dirty_workers", []) + lines.append(f"DIRTY WORKERS ({data.get('dirty_worker_count', 0)})") + if dirty_workers: + lines.append(f" {'PID':<10} {'AGE':<6} {'APPS'}") + lines.append(f" {'-' * 50}") + for w in dirty_workers: + pid = w.get("pid", "?") + age = w.get("age", "?") + apps = w.get("apps", []) + # Show each app on its own line if multiple + if apps: + first_app = apps[0].split(":")[-1] # Just the class name + lines.append(f" {pid:<10} {age:<6} {first_app}") + for app in apps[1:]: + app_name = app.split(":")[-1] + lines.append(f" {'':<10} {'':<6} {app_name}") + else: + lines.append(f" {pid:<10} {age:<6} (no apps)") + else: + lines.append(" (none)") + else: + lines.append("DIRTY ARBITER") + lines.append(" (not running)") + + return "\n".join(lines) + + +def format_response(command: str, data: dict) -> str: # pylint: disable=too-many-return-statements + """ + Format response data based on command. + + Args: + command: Original command string + data: Response data dictionary + + Returns: + Formatted string for display + """ + cmd_lower = command.lower().strip() + + # Route to specific formatters + if cmd_lower == "show all": + return format_all(data) + elif cmd_lower == "show workers": + return format_workers(data) + elif cmd_lower == "show dirty": + return format_dirty(data) + elif cmd_lower == "show stats": + return format_stats(data) + elif cmd_lower == "show listeners": + return format_listeners(data) + elif cmd_lower == "show config": + return format_config(data) + elif cmd_lower == "help": + return format_help(data) + else: + # Generic JSON output for other commands + if data: + return json.dumps(data, indent=2) + return "OK" + + +def run_command(socket_path: str, command: str, json_output: bool = False) -> int: + """ + Execute single command and exit. + + Args: + socket_path: Path to control socket + command: Command to execute + json_output: If True, output raw JSON + + Returns: + Exit code (0 for success, 1 for error) + """ + try: + with ControlClient(socket_path) as client: + cmd, args = parse_command(command) + full_command = f"{cmd} {' '.join(args)}".strip() if args else cmd + result = client.send_command(full_command) + + if json_output: + print(json.dumps(result, indent=2)) + else: + output = format_response(cmd, result) + print(output) + + return 0 + + except ControlClientError as e: + print(f"Error: {e}", file=sys.stderr) + return 1 + except KeyboardInterrupt: + return 130 + + +def run_interactive(socket_path: str, json_output: bool = False) -> int: + """ + Run interactive CLI with readline support. + + Args: + socket_path: Path to control socket + json_output: If True, output raw JSON + + Returns: + Exit code + """ + try: + import readline # noqa: F401 - imported for side effects + has_readline = True + except ImportError: + has_readline = False + + try: + client = ControlClient(socket_path) + client.connect() + except ControlClientError as e: + print(f"Error: {e}", file=sys.stderr) + return 1 + + print(f"Connected to {socket_path}") + print("Type 'help' for available commands, 'quit' to exit.") + print() + + # Set up readline history + history_file = os.path.expanduser("~/.gunicornc_history") + if has_readline: + try: + readline.read_history_file(history_file) + except FileNotFoundError: + pass + + exit_code = 0 + + try: + while True: + try: + line = input("gunicorn> ").strip() + except EOFError: + print() + break + + if not line: + continue + + if line.lower() in ('quit', 'exit', 'q'): + break + + try: + cmd, args = parse_command(line) + full_command = f"{cmd} {' '.join(args)}".strip() if args else cmd + result = client.send_command(full_command) + + if json_output: + print(json.dumps(result, indent=2)) + else: + output = format_response(cmd, result) + print(output) + + except ControlClientError as e: + print(f"Error: {e}") + # Try to reconnect + try: + client.close() + client.connect() + except ControlClientError: + print("Connection lost. Exiting.") + exit_code = 1 + break + + print() + + except KeyboardInterrupt: + print() + exit_code = 130 + finally: + client.close() + if has_readline: + try: + readline.write_history_file(history_file) + except Exception: + pass + + return exit_code + + +def main(): + """Main entry point for gunicornc CLI.""" + parser = argparse.ArgumentParser( + description='Gunicorn control interface', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + gunicornc # Interactive mode (default socket) + gunicornc -s /tmp/myapp.ctl # Interactive mode with custom socket + gunicornc -c "show workers" # Single command mode + gunicornc -c "worker add 2" # Add 2 workers + gunicornc -c "show stats" -j # Output stats as JSON + """ + ) + + parser.add_argument( + '-s', '--socket', + default='gunicorn.ctl', + help='Control socket path (default: gunicorn.ctl in current directory)' + ) + + parser.add_argument( + '-c', '--command', + help='Execute single command and exit' + ) + + parser.add_argument( + '-j', '--json', + action='store_true', + help='Output raw JSON (for scripting)' + ) + + parser.add_argument( + '-v', '--version', + action='store_true', + help='Show version and exit' + ) + + args = parser.parse_args() + + if args.version: + from gunicorn import __version__ + print(f"gunicornc (gunicorn {__version__})") + return 0 + + socket_path = args.socket + + # Make relative paths absolute from cwd + if not os.path.isabs(socket_path): + socket_path = os.path.join(os.getcwd(), socket_path) + + if args.command: + return run_command(socket_path, args.command, args.json) + else: + return run_interactive(socket_path, args.json) + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/gunicorn/ctl/client.py b/gunicorn/ctl/client.py new file mode 100644 index 00000000..cc9d39fa --- /dev/null +++ b/gunicorn/ctl/client.py @@ -0,0 +1,139 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +Control Socket Client + +Client library for connecting to gunicorn control socket. +""" + +import shlex +import socket + +from gunicorn.ctl.protocol import ( + ControlProtocol, + make_request, +) + + +class ControlClientError(Exception): + """Control client error.""" + + +class ControlClient: + """ + Client for connecting to gunicorn control socket. + + Can be used as a context manager: + + with ControlClient('/path/to/gunicorn.ctl') as client: + result = client.send_command('show workers') + """ + + def __init__(self, socket_path: str, timeout: float = 30.0): + """ + Initialize control client. + + Args: + socket_path: Path to the Unix socket + timeout: Socket timeout in seconds (default 30) + """ + self.socket_path = socket_path + self.timeout = timeout + self._sock = None + self._request_id = 0 + + def connect(self): + """ + Connect to control socket. + + Raises: + ControlClientError: If connection fails + """ + if self._sock: + return + + try: + self._sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self._sock.settimeout(self.timeout) + self._sock.connect(self.socket_path) + except socket.error as e: + self._sock = None + raise ControlClientError(f"Failed to connect to {self.socket_path}: {e}") + + def close(self): + """Close connection.""" + if self._sock: + try: + self._sock.close() + except Exception: + pass + self._sock = None + + def send_command(self, command: str, args: list = None) -> dict: + """ + Send command and wait for response. + + Args: + command: Command string (e.g., "show workers") + args: Optional additional arguments + + Returns: + Response data dictionary + + Raises: + ControlClientError: If communication fails + """ + if not self._sock: + self.connect() + + self._request_id += 1 + request = make_request(self._request_id, command, args) + + try: + ControlProtocol.write_message(self._sock, request) + response = ControlProtocol.read_message(self._sock) + except Exception as e: + self.close() + raise ControlClientError(f"Communication error: {e}") + + if response.get("status") == "error": + raise ControlClientError(response.get("error", "Unknown error")) + + return response.get("data", {}) + + def __enter__(self): + self.connect() + return self + + def __exit__(self, *args): + self.close() + + +def parse_command(line: str) -> tuple: + """ + Parse a command line into command and args. + + Args: + line: Command line string + + Returns: + Tuple of (command_string, args_list) + """ + parts = shlex.split(line) + if not parts: + return "", [] + + # Find where numeric/value args start + command_parts = [] + args = [] + + for part in parts: + # If we haven't hit args yet and this looks like a command word + if not args and not part.isdigit() and not part.startswith('-'): + command_parts.append(part) + else: + args.append(part) + + return " ".join(command_parts), args diff --git a/gunicorn/ctl/handlers.py b/gunicorn/ctl/handlers.py new file mode 100644 index 00000000..827f62ee --- /dev/null +++ b/gunicorn/ctl/handlers.py @@ -0,0 +1,585 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +Control Interface Command Handlers + +Provides handlers for all control commands with access to arbiter state. +""" + +import os +import signal +import socket +import time + + +class CommandHandlers: + """ + Command handlers with access to arbiter state. + + All handler methods return dictionaries that will be sent + as the response data. + """ + + def __init__(self, arbiter): + """ + Initialize handlers with arbiter reference. + + Args: + arbiter: The Gunicorn arbiter instance + """ + self.arbiter = arbiter + + def show_workers(self) -> dict: + """ + Return list of HTTP workers. + + Returns: + Dictionary with workers list containing: + - pid: Worker process ID + - age: Worker age (spawn order) + - requests: Number of requests handled (if available) + - booted: Whether worker has finished booting + - last_heartbeat: Seconds since last heartbeat + """ + workers = [] + now = time.monotonic() + + for pid, worker in self.arbiter.WORKERS.items(): + try: + last_update = worker.tmp.last_update() + last_heartbeat = round(now - last_update, 2) + except (OSError, ValueError): + last_heartbeat = None + + workers.append({ + "pid": pid, + "age": worker.age, + "booted": worker.booted, + "aborted": worker.aborted, + "last_heartbeat": last_heartbeat, + }) + + # Sort by age (oldest first) + workers.sort(key=lambda w: w["age"]) + + return {"workers": workers, "count": len(workers)} + + def show_dirty(self) -> dict: + """ + Return dirty workers and apps information. + + Returns: + Dictionary with: + - enabled: Whether dirty arbiter is running + - pid: Dirty arbiter PID + - workers: List of dirty worker info + - apps: List of dirty app specs + """ + if not self.arbiter.dirty_arbiter_pid: + return { + "enabled": False, + "pid": None, + "workers": [], + "apps": [], + } + + # Get dirty arbiter reference if available + dirty_arbiter = getattr(self.arbiter, 'dirty_arbiter', None) + + workers = [] + apps = [] + + if dirty_arbiter and hasattr(dirty_arbiter, 'workers'): + now = time.monotonic() + for pid, worker in dirty_arbiter.workers.items(): + try: + last_update = worker.tmp.last_update() + last_heartbeat = round(now - last_update, 2) + except (OSError, ValueError, AttributeError): + last_heartbeat = None + + workers.append({ + "pid": pid, + "age": worker.age, + "apps": getattr(worker, 'app_paths', []), + "booted": getattr(worker, 'booted', False), + "last_heartbeat": last_heartbeat, + }) + + # Get app specs + if hasattr(dirty_arbiter, 'app_specs'): + for path, spec in dirty_arbiter.app_specs.items(): + worker_pids = list(dirty_arbiter.app_worker_map.get(path, [])) + apps.append({ + "import_path": path, + "worker_count": spec.get('worker_count'), + "current_workers": len(worker_pids), + "worker_pids": worker_pids, + }) + + return { + "enabled": True, + "pid": self.arbiter.dirty_arbiter_pid, + "workers": workers, + "apps": apps, + } + + def show_config(self) -> dict: + """ + Return current effective configuration. + + Returns: + Dictionary of configuration values + """ + cfg = self.arbiter.cfg + config = {} + + # Get commonly needed config values + config_keys = [ + 'bind', 'workers', 'worker_class', 'threads', 'timeout', + 'graceful_timeout', 'keepalive', 'max_requests', + 'max_requests_jitter', 'worker_connections', 'preload_app', + 'daemon', 'pidfile', 'proc_name', 'reload', + 'dirty_workers', 'dirty_apps', 'dirty_timeout', + 'control_socket', 'control_socket_disable', + ] + + for key in config_keys: + try: + value = getattr(cfg, key) + # Convert non-serializable types + if callable(value): + value = str(value) + elif hasattr(value, '__class__') and not isinstance( + value, (str, int, float, bool, list, dict, type(None))): + value = str(value) + config[key] = value + except AttributeError: + pass + + return config + + def show_stats(self) -> dict: + """ + Return server statistics. + + Returns: + Dictionary with: + - uptime: Seconds since arbiter started + - pid: Arbiter PID + - workers_current: Current number of workers + - workers_spawned: Total workers spawned + - workers_killed: Total workers killed (if tracked) + - reloads: Number of reloads (if tracked) + """ + stats = getattr(self.arbiter, '_stats', {}) + start_time = stats.get('start_time') + + uptime = None + if start_time: + uptime = round(time.time() - start_time, 2) + + return { + "uptime": uptime, + "pid": self.arbiter.pid, + "workers_current": len(self.arbiter.WORKERS), + "workers_target": self.arbiter.num_workers, + "workers_spawned": stats.get('workers_spawned', 0), + "workers_killed": stats.get('workers_killed', 0), + "reloads": stats.get('reloads', 0), + "dirty_arbiter_pid": self.arbiter.dirty_arbiter_pid or None, + } + + def show_listeners(self) -> dict: + """ + Return bound socket information. + + Returns: + Dictionary with listeners list + """ + listeners = [] + + for lnr in self.arbiter.LISTENERS: + addr = str(lnr) + listener_info = { + "address": addr, + "fd": lnr.fileno(), + } + + # Try to get socket family + try: + sock = lnr.sock + if sock.family == socket.AF_UNIX: + listener_info["type"] = "unix" + elif sock.family == socket.AF_INET: + listener_info["type"] = "tcp" + elif sock.family == socket.AF_INET6: + listener_info["type"] = "tcp6" + except Exception: + listener_info["type"] = "unknown" + + listeners.append(listener_info) + + return {"listeners": listeners, "count": len(listeners)} + + def worker_add(self, count: int = 1) -> dict: + """ + Increase worker count. + + Args: + count: Number of workers to add (default 1) + + Returns: + Dictionary with added count and new total + """ + count = max(1, int(count)) + old_count = self.arbiter.num_workers + self.arbiter.num_workers += count + + # Wake up the arbiter to spawn workers + self.arbiter.wakeup() + + return { + "added": count, + "previous": old_count, + "total": self.arbiter.num_workers, + } + + def worker_remove(self, count: int = 1) -> dict: + """ + Decrease worker count. + + Args: + count: Number of workers to remove (default 1) + + Returns: + Dictionary with removed count and new total + """ + count = max(1, int(count)) + old_count = self.arbiter.num_workers + + # Don't go below 1 worker + new_count = max(1, old_count - count) + actual_removed = old_count - new_count + + self.arbiter.num_workers = new_count + + # Wake up the arbiter to kill excess workers + self.arbiter.wakeup() + + return { + "removed": actual_removed, + "previous": old_count, + "total": new_count, + } + + def worker_kill(self, pid: int) -> dict: + """ + Gracefully terminate a specific worker. + + Args: + pid: Worker process ID + + Returns: + Dictionary with killed PID or error + """ + pid = int(pid) + + if pid not in self.arbiter.WORKERS: + return { + "success": False, + "error": f"Worker {pid} not found", + } + + try: + os.kill(pid, signal.SIGTERM) + return { + "success": True, + "killed": pid, + } + except OSError as e: + return { + "success": False, + "error": str(e), + } + + def dirty_add(self, count: int = 1) -> dict: + """ + Spawn additional dirty workers. + + Sends a MANAGE message to the dirty arbiter to spawn workers. + + Args: + count: Number of dirty workers to add (default 1) + + Returns: + Dictionary with added count or error + """ + if not self.arbiter.dirty_arbiter_pid: + return { + "success": False, + "error": "Dirty arbiter not running", + } + + count = max(1, int(count)) + return self._send_manage_message("add", count) + + def dirty_remove(self, count: int = 1) -> dict: + """ + Remove dirty workers. + + Sends a MANAGE message to the dirty arbiter to remove workers. + + Args: + count: Number of dirty workers to remove (default 1) + + Returns: + Dictionary with removed count or error + """ + if not self.arbiter.dirty_arbiter_pid: + return { + "success": False, + "error": "Dirty arbiter not running", + } + + count = max(1, int(count)) + return self._send_manage_message("remove", count) + + def _send_manage_message(self, operation: str, count: int) -> dict: + """ + Send a worker management message to the dirty arbiter. + + Args: + operation: "add" or "remove" + count: Number of workers to add/remove + + Returns: + Dictionary with result or error + """ + # Get socket path from arbiter object or environment + dirty_socket_path = None + if hasattr(self.arbiter, 'dirty_arbiter') and self.arbiter.dirty_arbiter: + dirty_socket_path = getattr( + self.arbiter.dirty_arbiter, 'socket_path', None + ) + if not dirty_socket_path: + dirty_socket_path = os.environ.get('GUNICORN_DIRTY_SOCKET') + if not dirty_socket_path: + return { + "success": False, + "error": "Cannot find dirty arbiter socket path", + } + + try: + from gunicorn.dirty.protocol import ( + DirtyProtocol, MANAGE_OP_ADD, MANAGE_OP_REMOVE + ) + + op = MANAGE_OP_ADD if operation == "add" else MANAGE_OP_REMOVE + + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.settimeout(10.0) + sock.connect(dirty_socket_path) + + # Send manage request + request = { + "type": DirtyProtocol.MSG_TYPE_MANAGE, + "id": 1, + "op": op, + "count": count, + } + DirtyProtocol.write_message(sock, request) + + # Read response + response = DirtyProtocol.read_message(sock) + sock.close() + + if response.get("type") == DirtyProtocol.MSG_TYPE_RESPONSE: + return response.get("result", {"success": True}) + elif response.get("type") == DirtyProtocol.MSG_TYPE_ERROR: + error = response.get("error", {}) + return { + "success": False, + "error": error.get("message", str(error)), + } + else: + return { + "success": False, + "error": f"Unexpected response type: {response.get('type')}", + } + + except Exception as e: + return { + "success": False, + "error": str(e), + } + + def reload(self) -> dict: + """ + Trigger graceful reload (equivalent to SIGHUP). + + Returns: + Dictionary with status + """ + # Send HUP to self to trigger reload + os.kill(self.arbiter.pid, signal.SIGHUP) + return {"status": "reloading"} + + def reopen(self) -> dict: + """ + Reopen log files (equivalent to SIGUSR1). + + Returns: + Dictionary with status + """ + os.kill(self.arbiter.pid, signal.SIGUSR1) + return {"status": "reopening"} + + def shutdown(self, mode: str = "graceful") -> dict: + """ + Initiate shutdown. + + Args: + mode: "graceful" (SIGTERM) or "quick" (SIGINT) + + Returns: + Dictionary with status + """ + if mode == "quick": + os.kill(self.arbiter.pid, signal.SIGINT) + else: + os.kill(self.arbiter.pid, signal.SIGTERM) + + return {"status": "shutting_down", "mode": mode} + + def show_all(self) -> dict: + """ + Return overview of all processes (arbiter, web workers, dirty arbiter, dirty workers). + + Returns: + Dictionary with complete process hierarchy + """ + now = time.monotonic() + + # Arbiter info + arbiter_info = { + "pid": self.arbiter.pid, + "type": "arbiter", + "role": "master", + } + + # Web workers (HTTP workers) + web_workers = [] + for pid, worker in self.arbiter.WORKERS.items(): + try: + last_update = worker.tmp.last_update() + last_heartbeat = round(now - last_update, 2) + except (OSError, ValueError): + last_heartbeat = None + + web_workers.append({ + "pid": pid, + "type": "web", + "age": worker.age, + "booted": worker.booted, + "last_heartbeat": last_heartbeat, + }) + + # Sort by age + web_workers.sort(key=lambda w: w["age"]) + + # Dirty arbiter info (runs in separate process) + dirty_arbiter_info = None + dirty_workers = [] + + if self.arbiter.dirty_arbiter_pid: + dirty_arbiter_info = { + "pid": self.arbiter.dirty_arbiter_pid, + "type": "dirty_arbiter", + "role": "dirty master", + } + + # Query dirty arbiter for worker info via its socket + dirty_workers = self._query_dirty_workers() + + return { + "arbiter": arbiter_info, + "web_workers": web_workers, + "web_worker_count": len(web_workers), + "dirty_arbiter": dirty_arbiter_info, + "dirty_workers": dirty_workers, + "dirty_worker_count": len(dirty_workers), + } + + def _query_dirty_workers(self) -> list: + """ + Query the dirty arbiter for worker information. + + Connects to the dirty arbiter socket and sends a status request. + + Returns: + List of dirty worker info dicts, or empty list on error + """ + # Get socket path from arbiter object or environment + dirty_socket_path = None + if hasattr(self.arbiter, 'dirty_arbiter') and self.arbiter.dirty_arbiter: + dirty_socket_path = getattr(self.arbiter.dirty_arbiter, 'socket_path', None) + if not dirty_socket_path: + dirty_socket_path = os.environ.get('GUNICORN_DIRTY_SOCKET') + if not dirty_socket_path: + return [] + + try: + from gunicorn.dirty.protocol import DirtyProtocol + + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.settimeout(2.0) + sock.connect(dirty_socket_path) + + # Send status request + request = { + "type": DirtyProtocol.MSG_TYPE_STATUS, + "id": "ctl-status-1", + } + DirtyProtocol.write_message(sock, request) + + # Read response + response = DirtyProtocol.read_message(sock) + sock.close() + + if response.get("type") == DirtyProtocol.MSG_TYPE_RESPONSE: + result = response.get("result", {}) + return result.get("workers", []) + + except Exception: + pass + + return [] + + def help(self) -> dict: + """ + Return list of available commands. + + Returns: + Dictionary with commands and descriptions + """ + commands = { + "show all": "Show all processes (arbiter, web workers, dirty workers)", + "show workers": "List HTTP workers with their status", + "show dirty": "List dirty workers and apps", + "show config": "Show current effective configuration", + "show stats": "Show server statistics", + "show listeners": "Show bound sockets", + "worker add [N]": "Spawn N workers (default 1)", + "worker remove [N]": "Remove N workers (default 1)", + "worker kill ": "Gracefully terminate specific worker", + "dirty add [N]": "Spawn N dirty workers (default 1)", + "dirty remove [N]": "Remove N dirty workers (default 1)", + "reload": "Graceful reload (HUP)", + "reopen": "Reopen log files (USR1)", + "shutdown [graceful|quick]": "Shutdown server (TERM/INT)", + "help": "Show this help message", + } + return {"commands": commands} diff --git a/gunicorn/ctl/protocol.py b/gunicorn/ctl/protocol.py new file mode 100644 index 00000000..36d2fe78 --- /dev/null +++ b/gunicorn/ctl/protocol.py @@ -0,0 +1,224 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +Control Socket Protocol + +JSON-based protocol with length-prefixed framing for the control interface. + +Message Format: + +----------------+------------------+ + | Length (4B BE) | JSON Payload | + +----------------+------------------+ + +Request Format: + {"id": 1, "command": "show", "args": ["workers"]} + +Response Format: + {"id": 1, "status": "ok", "data": {...}} + {"id": 1, "status": "error", "error": "message"} +""" + +import json +import struct + + +class ProtocolError(Exception): + """Protocol-level error.""" + + +class ControlProtocol: + """ + Protocol implementation for control socket communication. + + Uses 4-byte big-endian length prefix followed by JSON payload. + """ + + # Maximum message size (16 MB) + MAX_MESSAGE_SIZE = 16 * 1024 * 1024 + + @staticmethod + def encode_message(data: dict) -> bytes: + """ + Encode a message for transmission. + + Args: + data: Dictionary to encode + + Returns: + Length-prefixed JSON bytes + """ + payload = json.dumps(data).encode('utf-8') + length = struct.pack('>I', len(payload)) + return length + payload + + @staticmethod + def decode_message(data: bytes) -> dict: + """ + Decode a message from bytes. + + Args: + data: Raw bytes (length prefix + JSON payload) + + Returns: + Decoded dictionary + """ + if len(data) < 4: + raise ProtocolError("Message too short") + + length = struct.unpack('>I', data[:4])[0] + if len(data) < 4 + length: + raise ProtocolError("Incomplete message") + + payload = data[4:4 + length] + return json.loads(payload.decode('utf-8')) + + @staticmethod + def read_message(sock) -> dict: + """ + Read one message from a socket. + + Args: + sock: Socket to read from + + Returns: + Decoded message dictionary + + Raises: + ProtocolError: If message is malformed + ConnectionError: If connection is closed + """ + # Read length prefix + length_data = b'' + while len(length_data) < 4: + chunk = sock.recv(4 - len(length_data)) + if not chunk: + if not length_data: + raise ConnectionError("Connection closed") + raise ProtocolError("Incomplete length prefix") + length_data += chunk + + length = struct.unpack('>I', length_data)[0] + + if length > ControlProtocol.MAX_MESSAGE_SIZE: + raise ProtocolError(f"Message too large: {length}") + + # Read payload + payload_data = b'' + while len(payload_data) < length: + chunk = sock.recv(min(length - len(payload_data), 65536)) + if not chunk: + raise ProtocolError("Incomplete payload") + payload_data += chunk + + try: + return json.loads(payload_data.decode('utf-8')) + except json.JSONDecodeError as e: + raise ProtocolError(f"Invalid JSON: {e}") + + @staticmethod + def write_message(sock, data: dict): + """ + Write one message to a socket. + + Args: + sock: Socket to write to + data: Message dictionary to send + """ + message = ControlProtocol.encode_message(data) + sock.sendall(message) + + @staticmethod + async def read_message_async(reader) -> dict: + """ + Read one message from an async reader. + + Args: + reader: asyncio StreamReader + + Returns: + Decoded message dictionary + """ + # Read length prefix + length_data = await reader.readexactly(4) + length = struct.unpack('>I', length_data)[0] + + if length > ControlProtocol.MAX_MESSAGE_SIZE: + raise ProtocolError(f"Message too large: {length}") + + # Read payload + payload_data = await reader.readexactly(length) + + try: + return json.loads(payload_data.decode('utf-8')) + except json.JSONDecodeError as e: + raise ProtocolError(f"Invalid JSON: {e}") + + @staticmethod + async def write_message_async(writer, data: dict): + """ + Write one message to an async writer. + + Args: + writer: asyncio StreamWriter + data: Message dictionary to send + """ + message = ControlProtocol.encode_message(data) + writer.write(message) + await writer.drain() + + +def make_request(request_id: int, command: str, args: list = None) -> dict: + """ + Create a request message. + + Args: + request_id: Unique request identifier + command: Command name (e.g., "show workers") + args: Optional list of arguments + + Returns: + Request dictionary + """ + return { + "id": request_id, + "command": command, + "args": args or [], + } + + +def make_response(request_id: int, data: dict = None) -> dict: + """ + Create a success response message. + + Args: + request_id: Request identifier being responded to + data: Response data + + Returns: + Response dictionary + """ + return { + "id": request_id, + "status": "ok", + "data": data or {}, + } + + +def make_error_response(request_id: int, error: str) -> dict: + """ + Create an error response message. + + Args: + request_id: Request identifier being responded to + error: Error message + + Returns: + Error response dictionary + """ + return { + "id": request_id, + "status": "error", + "error": error, + } diff --git a/gunicorn/ctl/server.py b/gunicorn/ctl/server.py new file mode 100644 index 00000000..b585b438 --- /dev/null +++ b/gunicorn/ctl/server.py @@ -0,0 +1,301 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +""" +Control Socket Server + +Runs in the arbiter process and accepts commands via Unix socket. +Uses asyncio in a background thread to handle client connections. +""" + +import asyncio +import os +import shlex +import threading + +from gunicorn.ctl.handlers import CommandHandlers +from gunicorn.ctl.protocol import ( + ControlProtocol, + make_response, + make_error_response, +) + + +class ControlSocketServer: + """ + Control socket server running in arbiter process. + + The server runs an asyncio event loop in a background thread, + accepting connections and dispatching commands to handlers. + """ + + def __init__(self, arbiter, socket_path, socket_mode=0o600): + """ + Initialize control socket server. + + Args: + arbiter: The Gunicorn arbiter instance + socket_path: Path for the Unix socket + socket_mode: Permission mode for socket (default 0o600) + """ + self.arbiter = arbiter + self.socket_path = socket_path + self.socket_mode = socket_mode + + self.handlers = CommandHandlers(arbiter) + self._server = None + self._loop = None + self._thread = None + self._running = False + + def start(self): + """Start server in background thread with asyncio event loop.""" + if self._running: + return + + self._running = True + self._thread = threading.Thread(target=self._run_loop, daemon=True) + self._thread.start() + + def stop(self): + """Stop server and cleanup socket.""" + if not self._running: + return + + self._running = False + + if self._loop and self._server: + # Schedule server close in the loop + self._loop.call_soon_threadsafe(self._shutdown) + + if self._thread: + self._thread.join(timeout=2.0) + self._thread = None + + # Clean up socket file + if os.path.exists(self.socket_path): + try: + os.unlink(self.socket_path) + except OSError: + pass + + def _shutdown(self): + """Shutdown server (called from event loop thread).""" + if self._server: + self._server.close() + + def _run_loop(self): + """Run the asyncio event loop in background thread.""" + try: + asyncio.run(self._serve()) + except Exception as e: + if self.arbiter.log: + self.arbiter.log.error("Control server error: %s", e) + + async def _serve(self): + """Main async server loop.""" + self._loop = asyncio.get_running_loop() + + # Remove socket if it exists + if os.path.exists(self.socket_path): + os.unlink(self.socket_path) + + # Create Unix socket server + self._server = await asyncio.start_unix_server( + self._handle_client, + path=self.socket_path + ) + + # Set socket permissions + os.chmod(self.socket_path, self.socket_mode) + + if self.arbiter.log: + self.arbiter.log.info("Control socket listening at %s", + self.socket_path) + + try: + async with self._server: + await self._server.serve_forever() + except asyncio.CancelledError: + pass + finally: + if os.path.exists(self.socket_path): + try: + os.unlink(self.socket_path) + except OSError: + pass + + async def _handle_client(self, reader, writer): + """ + Handle client connection. + + Args: + reader: asyncio StreamReader + writer: asyncio StreamWriter + """ + try: + while self._running: + try: + message = await asyncio.wait_for( + ControlProtocol.read_message_async(reader), + timeout=300.0 # 5 minute idle timeout + ) + except asyncio.TimeoutError: + # Client idle too long, close connection + break + except asyncio.IncompleteReadError: + # Client disconnected + break + except Exception: + # Protocol error + break + + # Process command + response = await self._dispatch(message) + + # Send response + await ControlProtocol.write_message_async(writer, response) + + except Exception as e: + if self.arbiter.log: + self.arbiter.log.debug("Control client error: %s", e) + finally: + writer.close() + try: + await writer.wait_closed() + except Exception: + pass + + async def _dispatch(self, message: dict) -> dict: + """ + Dispatch command to appropriate handler. + + Args: + message: Request message dict + + Returns: + Response dictionary + """ + request_id = message.get("id", 0) + command = message.get("command", "").strip() + args = message.get("args", []) + + if not command: + return make_error_response(request_id, "Empty command") + + try: + # Parse command (e.g., "show workers" or "worker add 2") + parts = shlex.split(command) + if args: + parts.extend(str(a) for a in args) + + if not parts: + return make_error_response(request_id, "Empty command") + + # Route to handler + result = self._execute_command(parts) + return make_response(request_id, result) + + except ValueError as e: + return make_error_response(request_id, f"Invalid argument: {e}") + except Exception as e: + if self.arbiter.log: + self.arbiter.log.exception("Command error") + return make_error_response(request_id, f"Command failed: {e}") + + def _execute_command(self, parts: list) -> dict: # pylint: disable=too-many-return-statements + """ + Execute a parsed command. + + Args: + parts: Command parts (e.g., ["show", "workers"]) + + Returns: + Handler result dictionary + """ + if not parts: + raise ValueError("Empty command") + + cmd = parts[0].lower() + rest = parts[1:] + + # Map commands to handlers + if cmd == "show": + return self._handle_show(rest) + elif cmd == "worker": + return self._handle_worker(rest) + elif cmd == "dirty": + return self._handle_dirty(rest) + elif cmd == "reload": + return self.handlers.reload() + elif cmd == "reopen": + return self.handlers.reopen() + elif cmd == "shutdown": + mode = rest[0] if rest else "graceful" + return self.handlers.shutdown(mode) + elif cmd == "help": + return self.handlers.help() + else: + raise ValueError(f"Unknown command: {cmd}") + + def _handle_show(self, args: list) -> dict: + """Handle 'show' commands.""" + if not args: + raise ValueError("Missing show target (all|workers|dirty|config|stats|listeners)") + + target = args[0].lower() + + if target == "all": + return self.handlers.show_all() + elif target == "workers": + return self.handlers.show_workers() + elif target == "dirty": + return self.handlers.show_dirty() + elif target == "config": + return self.handlers.show_config() + elif target == "stats": + return self.handlers.show_stats() + elif target == "listeners": + return self.handlers.show_listeners() + else: + raise ValueError(f"Unknown show target: {target}") + + def _handle_worker(self, args: list) -> dict: + """Handle 'worker' commands.""" + if not args: + raise ValueError("Missing worker action (add|remove|kill)") + + action = args[0].lower() + action_args = args[1:] + + if action == "add": + count = int(action_args[0]) if action_args else 1 + return self.handlers.worker_add(count) + elif action == "remove": + count = int(action_args[0]) if action_args else 1 + return self.handlers.worker_remove(count) + elif action == "kill": + if not action_args: + raise ValueError("Missing PID for worker kill") + pid = int(action_args[0]) + return self.handlers.worker_kill(pid) + else: + raise ValueError(f"Unknown worker action: {action}") + + def _handle_dirty(self, args: list) -> dict: + """Handle 'dirty' commands.""" + if not args: + raise ValueError("Missing dirty action (add|remove)") + + action = args[0].lower() + action_args = args[1:] + + if action == "add": + count = int(action_args[0]) if action_args else 1 + return self.handlers.dirty_add(count) + elif action == "remove": + count = int(action_args[0]) if action_args else 1 + return self.handlers.dirty_remove(count) + else: + raise ValueError(f"Unknown dirty action: {action}") diff --git a/gunicorn/dirty/arbiter.py b/gunicorn/dirty/arbiter.py index 44b23329..63962ace 100644 --- a/gunicorn/dirty/arbiter.py +++ b/gunicorn/dirty/arbiter.py @@ -14,7 +14,6 @@ import errno import fnmatch import os import signal -import sys import tempfile import time @@ -41,6 +40,8 @@ from .protocol import ( STASH_OP_DELETE_TABLE, STASH_OP_TABLES, STASH_OP_EXISTS, + MANAGE_OP_ADD, + MANAGE_OP_REMOVE, ) from .worker import DirtyWorker @@ -367,7 +368,8 @@ class DirtyArbiter: try: async with self._server: await self._server.serve_forever() - except asyncio.CancelledError: + except (asyncio.CancelledError, RuntimeError): + # RuntimeError raised when server.close() is called during serve_forever() pass finally: monitor_task.cancel() @@ -422,6 +424,12 @@ class DirtyArbiter: # Handle stash operations if msg_type == DirtyProtocol.MSG_TYPE_STASH: await self.handle_stash_request(message, writer) + # Handle status queries + elif msg_type == DirtyProtocol.MSG_TYPE_STATUS: + await self.handle_status_request(message, writer) + # Handle worker management (add/remove workers) + elif msg_type == DirtyProtocol.MSG_TYPE_MANAGE: + await self.handle_manage_request(message, writer) else: # Route request to a dirty worker - pass writer for streaming await self.route_request(message, writer) @@ -645,6 +653,141 @@ class DirtyArbiter: # Stash (shared state) operations - handled directly in arbiter # ------------------------------------------------------------------------- + async def handle_status_request(self, message, client_writer): + """ + Handle a status query request. + + Returns information about the dirty arbiter and its workers. + + Args: + message: Status request message + client_writer: StreamWriter to send response to client + """ + request_id = message.get("id", "unknown") + now = time.monotonic() + + workers_info = [] + for pid, worker in self.workers.items(): + try: + last_update = worker.tmp.last_update() + last_heartbeat = round(now - last_update, 2) + except (OSError, ValueError, AttributeError): + last_heartbeat = None + + workers_info.append({ + "pid": pid, + "age": worker.age, + "apps": getattr(worker, 'app_paths', []), + "booted": getattr(worker, 'booted', False), + "last_heartbeat": last_heartbeat, + }) + + workers_info.sort(key=lambda w: w["age"]) + + result = { + "arbiter_pid": self.pid, + "workers": workers_info, + "worker_count": len(workers_info), + "apps": list(self.app_specs.keys()) if self.app_specs else [], + } + + response = make_response(request_id, result) + await DirtyProtocol.write_message_async(client_writer, response) + + async def handle_manage_request(self, message, client_writer): + """ + Handle a worker management request. + + Supports adding or removing dirty workers via protocol messages. + + Args: + message: Manage request message + client_writer: StreamWriter to send response to client + """ + request_id = message.get("id", "unknown") + op = message.get("op") + count = max(1, int(message.get("count", 1))) + + try: + if op == MANAGE_OP_ADD: + # Add workers - only loads apps that need more workers + spawned = 0 + for _ in range(count): + result = self.spawn_worker() + if result is not None: + self.num_workers += 1 + spawned += 1 + await asyncio.sleep(0.1) + + # Provide feedback about why no workers were spawned + if spawned == 0: + result = { + "success": True, + "operation": "add", + "requested": count, + "spawned": 0, + "reason": "All apps have reached their worker limits", + "total_workers": len(self.workers), + "target_workers": self.num_workers, + } + else: + result = { + "success": True, + "operation": "add", + "requested": count, + "spawned": spawned, + "total_workers": len(self.workers), + "target_workers": self.num_workers, + } + + elif op == MANAGE_OP_REMOVE: + # Remove workers (similar to TTOU signal but via message) + min_workers = self._get_minimum_workers() + removed = 0 + + for _ in range(count): + if self.num_workers <= min_workers: + break + if len(self.workers) <= 1: + break + + self.num_workers -= 1 + + # Kill oldest worker + oldest_pid = min(self.workers.keys(), + key=lambda p: self.workers[p].age) + self.kill_worker(oldest_pid, signal.SIGTERM) + removed += 1 + await asyncio.sleep(0.1) + + result = { + "success": True, + "operation": "remove", + "requested": count, + "removed": removed, + "total_workers": len(self.workers), + "target_workers": self.num_workers, + } + + else: + error = DirtyError(f"Unknown manage operation: {op}") + response = make_error_response(request_id, error) + await DirtyProtocol.write_message_async(client_writer, response) + return + + self.log.info("Worker management: %s %d workers (spawned/removed: %d)", + "add" if op == MANAGE_OP_ADD else "remove", + count, + result.get("spawned", result.get("removed", 0))) + + response = make_response(request_id, result) + await DirtyProtocol.write_message_async(client_writer, response) + + except Exception as e: + self.log.error("Manage operation error: %s", e) + response = make_error_response(request_id, DirtyError(str(e))) + await DirtyProtocol.write_message_async(client_writer, response) + async def handle_stash_request(self, message, client_writer): """ Handle a stash operation directly in the arbiter. @@ -785,13 +928,17 @@ class DirtyArbiter: self.kill_worker(oldest_pid, signal.SIGTERM) await asyncio.sleep(0.1) - def spawn_worker(self): + def spawn_worker(self, force_all_apps=False): """ Spawn a new dirty worker. Worker app assignment follows these priorities: 1. If there are pending respawns (from dead workers), use those apps 2. Otherwise, determine apps for a new worker based on allocation + 3. If force_all_apps=True, spawn with all apps regardless of limits + + Args: + force_all_apps: If True, spawn worker with all apps ignoring limits Returns: Worker PID in parent process, or None if no apps need workers @@ -799,12 +946,15 @@ class DirtyArbiter: # Priority 1: Respawn dead worker with same apps if self._pending_respawns: app_paths = self._pending_respawns.pop(0) + elif force_all_apps: + # Force spawn with all apps (used by TTIN signal) + app_paths = list(self.app_specs.keys()) else: # Priority 2: New worker for initial pool app_paths = self._get_apps_for_new_worker() if not app_paths: - self.log.warning("No apps need more workers, skipping spawn") + self.log.debug("No apps need more workers, skipping spawn") return None self.worker_age += 1 @@ -836,19 +986,19 @@ class DirtyArbiter: pid, app_paths) return pid - # Child process + # Child process - use os._exit() to avoid asyncio cleanup issues worker.pid = os.getpid() try: util._setproctitle(f"dirty-worker [{self.cfg.proc_name}]") worker.init_process() - sys.exit(0) - except SystemExit: - raise + os._exit(0) + except SystemExit as e: + os._exit(e.code if e.code is not None else 0) except Exception: self.log.exception("Exception in dirty worker process") if not worker.booted: - sys.exit(self.WORKER_BOOT_ERROR) - sys.exit(-1) + os._exit(self.WORKER_BOOT_ERROR) + os._exit(1) def kill_worker(self, pid, sig): """Kill a worker by PID.""" diff --git a/gunicorn/dirty/protocol.py b/gunicorn/dirty/protocol.py index 8a5f7b61..5f3f9d04 100644 --- a/gunicorn/dirty/protocol.py +++ b/gunicorn/dirty/protocol.py @@ -43,6 +43,8 @@ MSG_TYPE_ERROR = 0x03 MSG_TYPE_CHUNK = 0x04 MSG_TYPE_END = 0x05 MSG_TYPE_STASH = 0x10 # Stash operations (shared state between workers) +MSG_TYPE_STATUS = 0x11 # Status query for arbiter/workers +MSG_TYPE_MANAGE = 0x12 # Worker management (add/remove workers) # Message type names (for backwards compatibility with old API) MSG_TYPE_REQUEST_STR = "request" @@ -51,6 +53,8 @@ MSG_TYPE_ERROR_STR = "error" MSG_TYPE_CHUNK_STR = "chunk" MSG_TYPE_END_STR = "end" MSG_TYPE_STASH_STR = "stash" +MSG_TYPE_STATUS_STR = "status" +MSG_TYPE_MANAGE_STR = "manage" # Map int types to string names MSG_TYPE_TO_STR = { @@ -60,6 +64,8 @@ MSG_TYPE_TO_STR = { MSG_TYPE_CHUNK: MSG_TYPE_CHUNK_STR, MSG_TYPE_END: MSG_TYPE_END_STR, MSG_TYPE_STASH: MSG_TYPE_STASH_STR, + MSG_TYPE_STATUS: MSG_TYPE_STATUS_STR, + MSG_TYPE_MANAGE: MSG_TYPE_MANAGE_STR, } # Map string names to int types @@ -77,6 +83,10 @@ STASH_OP_DELETE_TABLE = 8 STASH_OP_TABLES = 9 STASH_OP_EXISTS = 10 +# Manage operation codes +MANAGE_OP_ADD = 1 # Add/spawn workers +MANAGE_OP_REMOVE = 2 # Remove/kill workers + # Header format: Magic (2) + Version (1) + Type (1) + Length (4) + RequestID (8) = 16 HEADER_FORMAT = ">2sBBIQ" HEADER_SIZE = struct.calcsize(HEADER_FORMAT) @@ -98,6 +108,8 @@ class BinaryProtocol: MSG_TYPE_CHUNK = MSG_TYPE_CHUNK_STR MSG_TYPE_END = MSG_TYPE_END_STR MSG_TYPE_STASH = MSG_TYPE_STASH_STR + MSG_TYPE_STATUS = MSG_TYPE_STATUS_STR + MSG_TYPE_MANAGE = MSG_TYPE_MANAGE_STR @staticmethod def encode_header(msg_type: int, request_id: int, payload_length: int) -> bytes: @@ -273,6 +285,43 @@ class BinaryProtocol: header = BinaryProtocol.encode_header(MSG_TYPE_END, request_id, 0) return header + @staticmethod + def encode_status(request_id: int) -> bytes: + """ + Encode a status query message. + + Args: + request_id: Request identifier + + Returns: + bytes: Complete message (header + empty payload) + """ + # Status query has empty payload + header = BinaryProtocol.encode_header(MSG_TYPE_STATUS, request_id, 0) + return header + + @staticmethod + def encode_manage(request_id: int, op: int, count: int = 1) -> bytes: + """ + Encode a worker management message. + + Args: + request_id: Request identifier + op: Management operation (MANAGE_OP_ADD or MANAGE_OP_REMOVE) + count: Number of workers to add/remove + + Returns: + bytes: Complete message (header + payload) + """ + payload_dict = { + "op": op, + "count": count, + } + payload = TLVEncoder.encode(payload_dict) + header = BinaryProtocol.encode_header(MSG_TYPE_MANAGE, request_id, + len(payload)) + return header + payload + @staticmethod def encode_stash(request_id: int, op: int, table: str, key=None, value=None, pattern=None) -> bytes: @@ -524,7 +573,7 @@ class BinaryProtocol: sock.sendall(data) @staticmethod - def _encode_from_dict(message: dict) -> bytes: + def _encode_from_dict(message: dict) -> bytes: # pylint: disable=too-many-return-statements """ Encode a message dict to binary format. @@ -582,6 +631,14 @@ class BinaryProtocol: message.get("value"), message.get("pattern") ) + elif msg_type == MSG_TYPE_STATUS: + return BinaryProtocol.encode_status(request_id) + elif msg_type == MSG_TYPE_MANAGE: + return BinaryProtocol.encode_manage( + request_id, + message.get("op"), + message.get("count", 1) + ) else: raise DirtyProtocolError(f"Unhandled message type: {msg_type}") @@ -731,3 +788,23 @@ def make_stash_message(request_id, op: int, table: str, if pattern is not None: msg["pattern"] = pattern return msg + + +def make_manage_message(request_id, op: int, count: int = 1) -> dict: + """ + Build a worker management message dict. + + Args: + request_id: Unique request identifier (int or str) + op: Management operation (MANAGE_OP_ADD or MANAGE_OP_REMOVE) + count: Number of workers to add/remove + + Returns: + dict: Manage message dict + """ + return { + "type": DirtyProtocol.MSG_TYPE_MANAGE, + "id": request_id, + "op": op, + "count": count, + } diff --git a/mkdocs.yml b/mkdocs.yml index 03193260..f1c935f4 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -18,6 +18,7 @@ nav: - HTTP/2: guides/http2.md - ASGI Worker: asgi.md - Dirty Arbiters: dirty.md + - Control Interface: guides/gunicornc.md - uWSGI Protocol: uwsgi.md - Signals: signals.md - Instrumentation: instrumentation.md diff --git a/pyproject.toml b/pyproject.toml index 852cdfe6..c23d6622 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,6 +68,7 @@ testing = [ [project.scripts] # duplicates "python -m gunicorn" handling in __main__.py gunicorn = "gunicorn.app.wsgiapp:run" +gunicornc = "gunicorn.ctl.cli:main" # note the quotes around "paste.server_runner" to escape the dot [project.entry-points."paste.server_runner"] diff --git a/tests/ctl/__init__.py b/tests/ctl/__init__.py new file mode 100644 index 00000000..530e35ca --- /dev/null +++ b/tests/ctl/__init__.py @@ -0,0 +1,3 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. diff --git a/tests/ctl/test_client.py b/tests/ctl/test_client.py new file mode 100644 index 00000000..7f7b770a --- /dev/null +++ b/tests/ctl/test_client.py @@ -0,0 +1,275 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +"""Tests for control socket client.""" + +import os +import socket +import tempfile +import threading + +import pytest + +from gunicorn.ctl.client import ( + ControlClient, + ControlClientError, + parse_command, +) +from gunicorn.ctl.protocol import ControlProtocol, make_response + + +class TestControlClientInit: + """Tests for ControlClient initialization.""" + + def test_init_attributes(self): + """Test that client is initialized with correct attributes.""" + client = ControlClient("/tmp/test.sock", timeout=60.0) + + assert client.socket_path == "/tmp/test.sock" + assert client.timeout == 60.0 + assert client._sock is None + assert client._request_id == 0 + + +class TestControlClientConnect: + """Tests for ControlClient connection.""" + + def test_connect_nonexistent_socket(self): + """Test connecting to non-existent socket.""" + client = ControlClient("/nonexistent/socket.sock") + + with pytest.raises(ControlClientError) as exc_info: + client.connect() + + assert "Failed to connect" in str(exc_info.value) + + def test_connect_success(self): + """Test successful connection.""" + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "test.sock") + + # Create a listening socket + server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + server_sock.bind(socket_path) + server_sock.listen(1) + + try: + client = ControlClient(socket_path) + client.connect() + + assert client._sock is not None + client.close() + finally: + server_sock.close() + + def test_connect_already_connected(self): + """Test that connect is idempotent.""" + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "test.sock") + + server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + server_sock.bind(socket_path) + server_sock.listen(1) + + try: + client = ControlClient(socket_path) + client.connect() + first_sock = client._sock + client.connect() # Should not create new connection + + assert client._sock is first_sock + client.close() + finally: + server_sock.close() + + +class TestControlClientClose: + """Tests for ControlClient close.""" + + def test_close_idempotent(self): + """Test that close can be called multiple times.""" + client = ControlClient("/tmp/test.sock") + client.close() + client.close() # Should not raise + + def test_close_clears_socket(self): + """Test that close clears the socket.""" + client = ControlClient("/tmp/test.sock") + client._sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + client.close() + + assert client._sock is None + + +class TestControlClientContextManager: + """Tests for context manager functionality.""" + + def test_context_manager_connection_error(self): + """Test context manager with connection error.""" + client = ControlClient("/nonexistent/socket.sock") + + with pytest.raises(ControlClientError): + with client: + pass + + def test_context_manager_success(self): + """Test successful context manager usage.""" + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "test.sock") + + server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + server_sock.bind(socket_path) + server_sock.listen(1) + + try: + with ControlClient(socket_path) as client: + assert client._sock is not None + + # After context manager exits, socket should be closed + assert client._sock is None + finally: + server_sock.close() + + +class TestControlClientSendCommand: + """Tests for send_command functionality.""" + + def test_send_command_success(self): + """Test successful command send.""" + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "test.sock") + + server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + server_sock.bind(socket_path) + server_sock.listen(1) + + response_data = {"workers": [], "count": 0} + response_sent = threading.Event() + + def server_handler(): + conn, _ = server_sock.accept() + try: + msg = ControlProtocol.read_message(conn) + resp = make_response(msg["id"], response_data) + ControlProtocol.write_message(conn, resp) + response_sent.set() + finally: + conn.close() + + server_thread = threading.Thread(target=server_handler) + server_thread.start() + + try: + client = ControlClient(socket_path, timeout=5.0) + result = client.send_command("show workers") + + assert result == response_data + client.close() + finally: + response_sent.wait(timeout=2.0) + server_thread.join(timeout=2.0) + server_sock.close() + + def test_send_command_error_response(self): + """Test handling error response.""" + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "test.sock") + + server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + server_sock.bind(socket_path) + server_sock.listen(1) + + def server_handler(): + conn, _ = server_sock.accept() + try: + msg = ControlProtocol.read_message(conn) + resp = { + "id": msg["id"], + "status": "error", + "error": "Unknown command", + } + ControlProtocol.write_message(conn, resp) + finally: + conn.close() + + server_thread = threading.Thread(target=server_handler) + server_thread.start() + + try: + client = ControlClient(socket_path, timeout=5.0) + + with pytest.raises(ControlClientError) as exc_info: + client.send_command("invalid command") + + assert "Unknown command" in str(exc_info.value) + client.close() + finally: + server_thread.join(timeout=2.0) + server_sock.close() + + def test_send_command_auto_connect(self): + """Test that send_command auto-connects if not connected.""" + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "test.sock") + + server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + server_sock.bind(socket_path) + server_sock.listen(1) + + def server_handler(): + conn, _ = server_sock.accept() + try: + msg = ControlProtocol.read_message(conn) + resp = make_response(msg["id"], {}) + ControlProtocol.write_message(conn, resp) + finally: + conn.close() + + server_thread = threading.Thread(target=server_handler) + server_thread.start() + + try: + client = ControlClient(socket_path, timeout=5.0) + # Don't call connect() explicitly + result = client.send_command("help") + + assert isinstance(result, dict) + client.close() + finally: + server_thread.join(timeout=2.0) + server_sock.close() + + +class TestParseCommand: + """Tests for command parsing.""" + + def test_parse_simple_command(self): + """Test parsing simple command.""" + cmd, args = parse_command("show workers") + assert cmd == "show workers" + assert args == [] + + def test_parse_command_with_args(self): + """Test parsing command with arguments.""" + cmd, args = parse_command("worker add 2") + assert cmd == "worker add" + assert args == ["2"] + + def test_parse_command_with_multiple_args(self): + """Test parsing command with multiple arguments.""" + cmd, args = parse_command("worker kill 12345") + assert cmd == "worker kill" + assert args == ["12345"] + + def test_parse_empty_command(self): + """Test parsing empty command.""" + cmd, args = parse_command("") + assert cmd == "" + assert args == [] + + def test_parse_command_quoted(self): + """Test parsing command with quoted arguments.""" + cmd, args = parse_command('worker kill "12345"') + assert cmd == "worker kill" + assert args == ["12345"] diff --git a/tests/ctl/test_handlers.py b/tests/ctl/test_handlers.py new file mode 100644 index 00000000..8b2771da --- /dev/null +++ b/tests/ctl/test_handlers.py @@ -0,0 +1,468 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +"""Tests for control socket command handlers.""" + +import signal +import time +from unittest.mock import MagicMock, patch + +from gunicorn.ctl.handlers import CommandHandlers + + +class MockWorker: + """Mock worker for testing.""" + + def __init__(self, pid, age, booted=True, aborted=False): + self.pid = pid + self.age = age + self.booted = booted + self.aborted = aborted + self.tmp = MagicMock() + self.tmp.last_update.return_value = time.monotonic() + + +class MockListener: + """Mock listener for testing.""" + + def __init__(self, address, fd=3): + self._address = address + self._fd = fd + self.sock = MagicMock() + self.sock.family = 2 # AF_INET + + def __str__(self): + return self._address + + def fileno(self): + return self._fd + + +class MockConfig: + """Mock config for testing.""" + + def __init__(self): + self.bind = ['127.0.0.1:8000'] + self.workers = 4 + self.worker_class = 'sync' + self.threads = 1 + self.timeout = 30 + self.graceful_timeout = 30 + self.keepalive = 2 + self.max_requests = 0 + self.max_requests_jitter = 0 + self.worker_connections = 1000 + self.preload_app = False + self.daemon = False + self.pidfile = None + self.proc_name = 'test_app' + self.reload = False + self.dirty_workers = 0 + self.dirty_apps = [] + self.dirty_timeout = 30 + self.control_socket = 'gunicorn.ctl' + self.control_socket_disable = False + + +class MockArbiter: + """Mock arbiter for testing.""" + + def __init__(self): + self.cfg = MockConfig() + self.pid = 12345 + self.WORKERS = {} + self.LISTENERS = [] + self.dirty_arbiter_pid = 0 + self.dirty_arbiter = None + self.num_workers = 4 + self._stats = { + 'start_time': time.time() - 3600, # 1 hour ago + 'workers_spawned': 10, + 'workers_killed': 5, + 'reloads': 2, + } + + def wakeup(self): + pass + + +class TestShowWorkers: + """Tests for show workers command.""" + + def test_show_workers_empty(self): + """Test showing workers when none exist.""" + arbiter = MockArbiter() + handlers = CommandHandlers(arbiter) + + result = handlers.show_workers() + + assert result["workers"] == [] + assert result["count"] == 0 + + def test_show_workers_with_workers(self): + """Test showing workers.""" + arbiter = MockArbiter() + arbiter.WORKERS = { + 1001: MockWorker(1001, 1), + 1002: MockWorker(1002, 2), + 1003: MockWorker(1003, 3), + } + handlers = CommandHandlers(arbiter) + + result = handlers.show_workers() + + assert result["count"] == 3 + assert len(result["workers"]) == 3 + + # Verify sorted by age + ages = [w["age"] for w in result["workers"]] + assert ages == sorted(ages) + + # Verify worker data + worker = result["workers"][0] + assert "pid" in worker + assert "age" in worker + assert "booted" in worker + assert "last_heartbeat" in worker + + +class TestShowStats: + """Tests for show stats command.""" + + def test_show_stats(self): + """Test showing stats.""" + arbiter = MockArbiter() + arbiter.WORKERS = { + 1001: MockWorker(1001, 1), + 1002: MockWorker(1002, 2), + } + handlers = CommandHandlers(arbiter) + + result = handlers.show_stats() + + assert result["pid"] == 12345 + assert result["workers_current"] == 2 + assert result["workers_target"] == 4 + assert result["workers_spawned"] == 10 + assert result["workers_killed"] == 5 + assert result["reloads"] == 2 + assert result["uptime"] is not None + assert result["uptime"] > 0 + + +class TestShowConfig: + """Tests for show config command.""" + + def test_show_config(self): + """Test showing config.""" + arbiter = MockArbiter() + handlers = CommandHandlers(arbiter) + + result = handlers.show_config() + + assert result["workers"] == 4 + assert result["timeout"] == 30 + assert result["bind"] == ['127.0.0.1:8000'] + + +class TestShowListeners: + """Tests for show listeners command.""" + + def test_show_listeners_empty(self): + """Test showing listeners when none exist.""" + arbiter = MockArbiter() + handlers = CommandHandlers(arbiter) + + result = handlers.show_listeners() + + assert result["listeners"] == [] + assert result["count"] == 0 + + def test_show_listeners(self): + """Test showing listeners.""" + arbiter = MockArbiter() + arbiter.LISTENERS = [ + MockListener("127.0.0.1:8000", fd=3), + MockListener("127.0.0.1:8001", fd=4), + ] + handlers = CommandHandlers(arbiter) + + result = handlers.show_listeners() + + assert result["count"] == 2 + assert len(result["listeners"]) == 2 + assert result["listeners"][0]["address"] == "127.0.0.1:8000" + + +class TestWorkerAdd: + """Tests for worker add command.""" + + def test_worker_add_default(self): + """Test adding one worker (default).""" + arbiter = MockArbiter() + arbiter.wakeup = MagicMock() + handlers = CommandHandlers(arbiter) + + result = handlers.worker_add() + + assert result["added"] == 1 + assert result["previous"] == 4 + assert result["total"] == 5 + assert arbiter.num_workers == 5 + arbiter.wakeup.assert_called_once() + + def test_worker_add_multiple(self): + """Test adding multiple workers.""" + arbiter = MockArbiter() + arbiter.wakeup = MagicMock() + handlers = CommandHandlers(arbiter) + + result = handlers.worker_add(3) + + assert result["added"] == 3 + assert result["total"] == 7 + + +class TestWorkerRemove: + """Tests for worker remove command.""" + + def test_worker_remove_default(self): + """Test removing one worker (default).""" + arbiter = MockArbiter() + arbiter.wakeup = MagicMock() + handlers = CommandHandlers(arbiter) + + result = handlers.worker_remove() + + assert result["removed"] == 1 + assert result["previous"] == 4 + assert result["total"] == 3 + assert arbiter.num_workers == 3 + arbiter.wakeup.assert_called_once() + + def test_worker_remove_cannot_go_below_one(self): + """Test that worker count cannot go below 1.""" + arbiter = MockArbiter() + arbiter.num_workers = 2 + arbiter.wakeup = MagicMock() + handlers = CommandHandlers(arbiter) + + result = handlers.worker_remove(5) + + assert result["removed"] == 1 + assert result["total"] == 1 + assert arbiter.num_workers == 1 + + +class TestWorkerKill: + """Tests for worker kill command.""" + + def test_worker_kill_success(self): + """Test killing a worker.""" + arbiter = MockArbiter() + arbiter.WORKERS = {1001: MockWorker(1001, 1)} + handlers = CommandHandlers(arbiter) + + with patch('os.kill') as mock_kill: + result = handlers.worker_kill(1001) + + assert result["success"] is True + assert result["killed"] == 1001 + mock_kill.assert_called_once_with(1001, signal.SIGTERM) + + def test_worker_kill_not_found(self): + """Test killing a non-existent worker.""" + arbiter = MockArbiter() + handlers = CommandHandlers(arbiter) + + result = handlers.worker_kill(9999) + + assert result["success"] is False + assert "not found" in result["error"] + + +class TestShowDirty: + """Tests for show dirty command.""" + + def test_show_dirty_disabled(self): + """Test showing dirty when disabled.""" + arbiter = MockArbiter() + handlers = CommandHandlers(arbiter) + + result = handlers.show_dirty() + + assert result["enabled"] is False + assert result["pid"] is None + + +class TestDirtyAdd: + """Tests for dirty add command.""" + + def test_dirty_add_not_running(self): + """Test dirty add when dirty arbiter not running.""" + arbiter = MockArbiter() + handlers = CommandHandlers(arbiter) + + result = handlers.dirty_add() + + assert result["success"] is False + assert "not running" in result["error"] + + def test_dirty_add_no_socket(self): + """Test dirty add when socket path not available.""" + arbiter = MockArbiter() + arbiter.dirty_arbiter_pid = 2000 + handlers = CommandHandlers(arbiter) + + # No dirty_arbiter attribute and no env var + with patch.dict('os.environ', {}, clear=True): + result = handlers.dirty_add() + + assert result["success"] is False + assert "socket" in result["error"].lower() + + +class TestDirtyRemove: + """Tests for dirty remove command.""" + + def test_dirty_remove_not_running(self): + """Test dirty remove when dirty arbiter not running.""" + arbiter = MockArbiter() + handlers = CommandHandlers(arbiter) + + result = handlers.dirty_remove() + + assert result["success"] is False + assert "not running" in result["error"] + + def test_dirty_remove_no_socket(self): + """Test dirty remove when socket path not available.""" + arbiter = MockArbiter() + arbiter.dirty_arbiter_pid = 2000 + handlers = CommandHandlers(arbiter) + + # No dirty_arbiter attribute and no env var + with patch.dict('os.environ', {}, clear=True): + result = handlers.dirty_remove() + + assert result["success"] is False + assert "socket" in result["error"].lower() + + +class TestReload: + """Tests for reload command.""" + + def test_reload(self): + """Test reload command.""" + arbiter = MockArbiter() + handlers = CommandHandlers(arbiter) + + with patch('os.kill') as mock_kill: + result = handlers.reload() + + assert result["status"] == "reloading" + mock_kill.assert_called_once_with(12345, signal.SIGHUP) + + +class TestReopen: + """Tests for reopen command.""" + + def test_reopen(self): + """Test reopen command.""" + arbiter = MockArbiter() + handlers = CommandHandlers(arbiter) + + with patch('os.kill') as mock_kill: + result = handlers.reopen() + + assert result["status"] == "reopening" + mock_kill.assert_called_once_with(12345, signal.SIGUSR1) + + +class TestShutdown: + """Tests for shutdown command.""" + + def test_shutdown_graceful(self): + """Test graceful shutdown.""" + arbiter = MockArbiter() + handlers = CommandHandlers(arbiter) + + with patch('os.kill') as mock_kill: + result = handlers.shutdown() + + assert result["status"] == "shutting_down" + assert result["mode"] == "graceful" + mock_kill.assert_called_once_with(12345, signal.SIGTERM) + + def test_shutdown_quick(self): + """Test quick shutdown.""" + arbiter = MockArbiter() + handlers = CommandHandlers(arbiter) + + with patch('os.kill') as mock_kill: + result = handlers.shutdown("quick") + + assert result["status"] == "shutting_down" + assert result["mode"] == "quick" + mock_kill.assert_called_once_with(12345, signal.SIGINT) + + +class TestShowAll: + """Tests for show all command.""" + + def test_show_all_basic(self): + """Test show all command.""" + arbiter = MockArbiter() + arbiter.WORKERS = { + 1001: MockWorker(1001, 1), + 1002: MockWorker(1002, 2), + } + handlers = CommandHandlers(arbiter) + + result = handlers.show_all() + + assert "arbiter" in result + assert result["arbiter"]["pid"] == 12345 + assert result["arbiter"]["type"] == "arbiter" + + assert "web_workers" in result + assert result["web_worker_count"] == 2 + assert len(result["web_workers"]) == 2 + + assert "dirty_arbiter" in result + assert result["dirty_arbiter"] is None + + # No dirty workers when no dirty arbiter + assert result["dirty_worker_count"] == 0 + + def test_show_all_with_dirty(self): + """Test show all with dirty arbiter running.""" + arbiter = MockArbiter() + arbiter.dirty_arbiter_pid = 2000 + handlers = CommandHandlers(arbiter) + + result = handlers.show_all() + + assert result["dirty_arbiter"] is not None + assert result["dirty_arbiter"]["pid"] == 2000 + assert result["dirty_arbiter"]["type"] == "dirty_arbiter" + + +class TestHelp: + """Tests for help command.""" + + def test_help(self): + """Test help command.""" + arbiter = MockArbiter() + handlers = CommandHandlers(arbiter) + + result = handlers.help() + + assert "commands" in result + commands = result["commands"] + assert "show all" in commands + assert "show workers" in commands + assert "worker add [N]" in commands + assert "reload" in commands + assert "shutdown [graceful|quick]" in commands diff --git a/tests/ctl/test_protocol.py b/tests/ctl/test_protocol.py new file mode 100644 index 00000000..899ea580 --- /dev/null +++ b/tests/ctl/test_protocol.py @@ -0,0 +1,249 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +"""Tests for control socket protocol.""" + +import json +import struct +import pytest + +from gunicorn.ctl.protocol import ( + ControlProtocol, + ProtocolError, + make_request, + make_response, + make_error_response, +) + + +class TestControlProtocolEncoding: + """Tests for message encoding/decoding.""" + + def test_encode_message_simple(self): + """Test encoding a simple message.""" + data = {"command": "test"} + result = ControlProtocol.encode_message(data) + + # First 4 bytes are length + length = struct.unpack('>I', result[:4])[0] + payload = result[4:] + + assert length == len(payload) + assert json.loads(payload.decode('utf-8')) == data + + def test_encode_message_unicode(self): + """Test encoding message with unicode characters.""" + data = {"message": "Hello \u4e16\u754c"} + result = ControlProtocol.encode_message(data) + + length = struct.unpack('>I', result[:4])[0] + payload = result[4:] + + assert length == len(payload) + assert json.loads(payload.decode('utf-8')) == data + + def test_decode_message_simple(self): + """Test decoding a simple message.""" + data = {"command": "test", "args": [1, 2, 3]} + payload = json.dumps(data).encode('utf-8') + length = struct.pack('>I', len(payload)) + raw = length + payload + + result = ControlProtocol.decode_message(raw) + assert result == data + + def test_decode_message_too_short(self): + """Test decoding message that's too short.""" + with pytest.raises(ProtocolError) as exc_info: + ControlProtocol.decode_message(b'\x00\x00') + assert "too short" in str(exc_info.value) + + def test_decode_message_incomplete(self): + """Test decoding incomplete message.""" + # Length says 100 bytes but only 4 bytes provided + raw = struct.pack('>I', 100) + b'test' + with pytest.raises(ProtocolError) as exc_info: + ControlProtocol.decode_message(raw) + assert "Incomplete" in str(exc_info.value) + + def test_roundtrip(self): + """Test encode/decode roundtrip.""" + original = { + "id": 42, + "command": "show workers", + "args": ["arg1", 123, True, None], + "nested": {"a": 1, "b": [1, 2, 3]}, + } + + encoded = ControlProtocol.encode_message(original) + decoded = ControlProtocol.decode_message(encoded) + + assert decoded == original + + +class TestMakeRequest: + """Tests for request creation.""" + + def test_make_request_simple(self): + """Test creating a simple request.""" + result = make_request(1, "show workers") + + assert result["id"] == 1 + assert result["command"] == "show workers" + assert result["args"] == [] + + def test_make_request_with_args(self): + """Test creating a request with arguments.""" + result = make_request(42, "worker add", [2]) + + assert result["id"] == 42 + assert result["command"] == "worker add" + assert result["args"] == [2] + + +class TestMakeResponse: + """Tests for response creation.""" + + def test_make_response_simple(self): + """Test creating a simple response.""" + result = make_response(1, {"count": 5}) + + assert result["id"] == 1 + assert result["status"] == "ok" + assert result["data"] == {"count": 5} + + def test_make_response_empty_data(self): + """Test creating response with no data.""" + result = make_response(1) + + assert result["id"] == 1 + assert result["status"] == "ok" + assert result["data"] == {} + + +class TestMakeErrorResponse: + """Tests for error response creation.""" + + def test_make_error_response(self): + """Test creating an error response.""" + result = make_error_response(1, "Unknown command") + + assert result["id"] == 1 + assert result["status"] == "error" + assert result["error"] == "Unknown command" + + +class TestControlProtocolSocket: + """Tests for socket reading/writing.""" + + def test_read_write_message(self): + """Test read/write through socket pair.""" + import socket + import threading + + data = {"id": 1, "command": "test"} + received = [] + + # Create socket pair + server, client = socket.socketpair() + + def reader(): + received.append(ControlProtocol.read_message(server)) + + t = threading.Thread(target=reader) + t.start() + + ControlProtocol.write_message(client, data) + t.join(timeout=2.0) + + client.close() + server.close() + + assert len(received) == 1 + assert received[0] == data + + def test_read_connection_closed(self): + """Test reading from closed connection.""" + import socket + + server, client = socket.socketpair() + client.close() + + with pytest.raises(ConnectionError): + ControlProtocol.read_message(server) + + server.close() + + def test_read_message_too_large(self): + """Test reading message exceeding max size.""" + import socket + + server, client = socket.socketpair() + + # Send a length that exceeds MAX_MESSAGE_SIZE + huge_length = ControlProtocol.MAX_MESSAGE_SIZE + 1 + client.send(struct.pack('>I', huge_length)) + + with pytest.raises(ProtocolError) as exc_info: + ControlProtocol.read_message(server) + assert "too large" in str(exc_info.value) + + client.close() + server.close() + + +class TestControlProtocolAsync: + """Tests for async protocol methods.""" + + @pytest.mark.asyncio + async def test_async_read_write(self): + """Test async read/write using a unix server.""" + import asyncio + import tempfile + import os + + data = {"id": 1, "command": "async test"} + received = [] + + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "test.sock") + + async def handler(reader, writer): + msg = await ControlProtocol.read_message_async(reader) + received.append(msg) + await ControlProtocol.write_message_async(writer, data) + writer.close() + await writer.wait_closed() + + server = await asyncio.start_unix_server(handler, path=socket_path) + + async with server: + reader, writer = await asyncio.open_unix_connection(socket_path) + await ControlProtocol.write_message_async(writer, data) + response = await ControlProtocol.read_message_async(reader) + writer.close() + await writer.wait_closed() + + assert len(received) == 1 + assert received[0] == data + assert response == data + + +class TestProtocolMaxSize: + """Tests for protocol size limits.""" + + def test_max_message_size_constant(self): + """Test that MAX_MESSAGE_SIZE is set to a reasonable value.""" + # Should be 16 MB + assert ControlProtocol.MAX_MESSAGE_SIZE == 16 * 1024 * 1024 + + def test_encode_large_message(self): + """Test encoding a large (but valid) message.""" + # Create a message with ~1MB of data + data = {"data": "x" * (1024 * 1024)} + encoded = ControlProtocol.encode_message(data) + + # Should succeed and be decodable + decoded = ControlProtocol.decode_message(encoded) + assert decoded == data diff --git a/tests/ctl/test_server.py b/tests/ctl/test_server.py new file mode 100644 index 00000000..dc70e54e --- /dev/null +++ b/tests/ctl/test_server.py @@ -0,0 +1,363 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +"""Tests for control socket server.""" + +import os +import tempfile +import time +from unittest.mock import MagicMock + +import pytest + +from gunicorn.ctl.server import ControlSocketServer +from gunicorn.ctl.client import ControlClient + + +class MockWorker: + """Mock worker for testing.""" + + def __init__(self, pid, age, booted=True, aborted=False): + self.pid = pid + self.age = age + self.booted = booted + self.aborted = aborted + self.tmp = MagicMock() + self.tmp.last_update.return_value = time.monotonic() + + +class MockConfig: + """Mock config for testing.""" + + def __init__(self): + self.bind = ['127.0.0.1:8000'] + self.workers = 4 + self.worker_class = 'sync' + self.threads = 1 + self.timeout = 30 + self.graceful_timeout = 30 + self.keepalive = 2 + self.max_requests = 0 + self.max_requests_jitter = 0 + self.worker_connections = 1000 + self.preload_app = False + self.daemon = False + self.pidfile = None + self.proc_name = 'test_app' + self.reload = False + self.dirty_workers = 0 + self.dirty_apps = [] + self.dirty_timeout = 30 + self.control_socket = 'gunicorn.ctl' + self.control_socket_disable = False + + +class MockLog: + """Mock logger for testing.""" + + def debug(self, msg, *args): + pass + + def info(self, msg, *args): + pass + + def warning(self, msg, *args): + pass + + def error(self, msg, *args): + pass + + def exception(self, msg, *args): + pass + + +class MockArbiter: + """Mock arbiter for testing.""" + + def __init__(self): + self.cfg = MockConfig() + self.log = MockLog() + self.pid = 12345 + self.WORKERS = {} + self.LISTENERS = [] + self.dirty_arbiter_pid = 0 + self.dirty_arbiter = None + self.num_workers = 4 + self._stats = { + 'start_time': time.time() - 3600, + 'workers_spawned': 10, + 'workers_killed': 5, + 'reloads': 2, + } + + def wakeup(self): + pass + + +class TestControlSocketServerInit: + """Tests for server initialization.""" + + def test_init(self): + """Test server initialization.""" + arbiter = MockArbiter() + server = ControlSocketServer(arbiter, "/tmp/test.sock", 0o600) + + assert server.arbiter is arbiter + assert server.socket_path == "/tmp/test.sock" + assert server.socket_mode == 0o600 + assert server._running is False + + +class TestControlSocketServerLifecycle: + """Tests for server start/stop.""" + + def test_start_stop(self): + """Test starting and stopping the server.""" + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "test.sock") + + arbiter = MockArbiter() + server = ControlSocketServer(arbiter, socket_path) + + server.start() + + # Wait for server to start + for _ in range(50): + if os.path.exists(socket_path): + break + time.sleep(0.1) + time.sleep(0.2) # Extra wait for server to be fully ready + + assert os.path.exists(socket_path) + + server.stop() + + # Wait for cleanup + time.sleep(0.2) + + # Socket should be cleaned up + assert not os.path.exists(socket_path) + + def test_start_already_running(self): + """Test that start is idempotent.""" + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "test.sock") + + arbiter = MockArbiter() + server = ControlSocketServer(arbiter, socket_path) + + server.start() + first_thread = server._thread + server.start() + + assert server._thread is first_thread + + server.stop() + + def test_stop_not_running(self): + """Test stopping a non-running server.""" + arbiter = MockArbiter() + server = ControlSocketServer(arbiter, "/tmp/test.sock") + + # Should not raise + server.stop() + + +class TestControlSocketServerIntegration: + """Integration tests for server with client.""" + + def test_show_workers(self): + """Test show workers command.""" + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "test.sock") + + arbiter = MockArbiter() + arbiter.WORKERS = { + 1001: MockWorker(1001, 1), + 1002: MockWorker(1002, 2), + } + server = ControlSocketServer(arbiter, socket_path) + + server.start() + + # Wait for server to start + for _ in range(50): + if os.path.exists(socket_path): + break + time.sleep(0.1) + time.sleep(0.2) # Extra wait for server to be fully ready + + try: + with ControlClient(socket_path, timeout=5.0) as client: + result = client.send_command("show workers") + + assert result["count"] == 2 + assert len(result["workers"]) == 2 + finally: + server.stop() + + def test_show_stats(self): + """Test show stats command.""" + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "test.sock") + + arbiter = MockArbiter() + server = ControlSocketServer(arbiter, socket_path) + + server.start() + + for _ in range(50): + if os.path.exists(socket_path): + break + time.sleep(0.1) + time.sleep(0.2) # Extra wait for server to be fully ready + + try: + with ControlClient(socket_path, timeout=5.0) as client: + result = client.send_command("show stats") + + assert result["pid"] == 12345 + assert result["workers_spawned"] == 10 + finally: + server.stop() + + def test_help_command(self): + """Test help command.""" + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "test.sock") + + arbiter = MockArbiter() + server = ControlSocketServer(arbiter, socket_path) + + server.start() + + for _ in range(50): + if os.path.exists(socket_path): + break + time.sleep(0.1) + time.sleep(0.2) # Extra wait for server to be fully ready + + try: + with ControlClient(socket_path, timeout=5.0) as client: + result = client.send_command("help") + + assert "commands" in result + assert "show workers" in result["commands"] + finally: + server.stop() + + def test_worker_add(self): + """Test worker add command.""" + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "test.sock") + + arbiter = MockArbiter() + arbiter.wakeup = MagicMock() + server = ControlSocketServer(arbiter, socket_path) + + server.start() + + for _ in range(50): + if os.path.exists(socket_path): + break + time.sleep(0.1) + time.sleep(0.2) # Extra wait for server to be fully ready + + try: + with ControlClient(socket_path, timeout=5.0) as client: + result = client.send_command("worker add 2") + + assert result["added"] == 2 + assert result["total"] == 6 + assert arbiter.num_workers == 6 + arbiter.wakeup.assert_called() + finally: + server.stop() + + def test_invalid_command(self): + """Test handling invalid command.""" + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "test.sock") + + arbiter = MockArbiter() + server = ControlSocketServer(arbiter, socket_path) + + server.start() + + for _ in range(50): + if os.path.exists(socket_path): + break + time.sleep(0.1) + time.sleep(0.2) # Extra wait for server to be fully ready + + try: + with ControlClient(socket_path, timeout=5.0) as client: + with pytest.raises(Exception) as exc_info: + client.send_command("invalid_command") + + assert "Unknown command" in str(exc_info.value) + finally: + server.stop() + + def test_multiple_commands(self): + """Test sending multiple commands on same connection.""" + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "test.sock") + + arbiter = MockArbiter() + arbiter.WORKERS = {1001: MockWorker(1001, 1)} + server = ControlSocketServer(arbiter, socket_path) + + server.start() + + for _ in range(50): + if os.path.exists(socket_path): + break + time.sleep(0.1) + time.sleep(0.2) # Extra wait for server to be fully ready + + try: + with ControlClient(socket_path, timeout=5.0) as client: + result1 = client.send_command("show workers") + result2 = client.send_command("show stats") + result3 = client.send_command("help") + + assert result1["count"] == 1 + assert result2["pid"] == 12345 + assert "commands" in result3 + finally: + server.stop() + + +class TestControlSocketServerPermissions: + """Tests for socket permissions.""" + + @pytest.mark.skipif( + os.uname().sysname == "FreeBSD", + reason="FreeBSD socket permissions behavior differs" + ) + def test_socket_permissions(self): + """Test that socket is created with correct permissions.""" + with tempfile.TemporaryDirectory() as tmpdir: + socket_path = os.path.join(tmpdir, "test.sock") + + arbiter = MockArbiter() + server = ControlSocketServer(arbiter, socket_path, 0o660) + + server.start() + + # Wait for socket to exist + for _ in range(50): + if os.path.exists(socket_path): + break + time.sleep(0.1) + + # Extra wait for chmod to complete + time.sleep(0.2) + + try: + mode = os.stat(socket_path).st_mode & 0o777 + assert mode == 0o660 + finally: + server.stop() diff --git a/tests/test_signal_integration.py b/tests/test_signal_integration.py index 6622c4fb..f7975f78 100644 --- a/tests/test_signal_integration.py +++ b/tests/test_signal_integration.py @@ -93,15 +93,19 @@ def gunicorn_server(app_module): '--access-logfile', '-', '--error-logfile', '-', '--log-level', 'info', + '--timeout', '30', + '--graceful-timeout', '30', app_name ] + # Use setsid to create new process group for proper signal handling proc = subprocess.Popen( cmd, cwd=app_dir, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - env={**os.environ, 'PYTHONPATH': app_dir} + env={**os.environ, 'PYTHONPATH': app_dir}, + preexec_fn=os.setsid ) # Wait for server to start @@ -113,13 +117,19 @@ def gunicorn_server(app_module): yield proc, port - # Cleanup + # Cleanup - use process group kill for better cleanup if proc.poll() is None: - proc.terminate() + try: + os.killpg(os.getpgid(proc.pid), signal.SIGTERM) + except (ProcessLookupError, OSError): + pass try: proc.wait(timeout=5) except subprocess.TimeoutExpired: - proc.kill() + try: + os.killpg(os.getpgid(proc.pid), signal.SIGKILL) + except (ProcessLookupError, OSError): + pass proc.wait() @@ -141,8 +151,11 @@ class TestSignalHandlingIntegration: response = make_request('127.0.0.1', port) assert b'Hello, World!' in response - # Send SIGTERM - proc.send_signal(signal.SIGTERM) + # Send SIGTERM to the process group for reliable signal delivery + try: + os.killpg(os.getpgid(proc.pid), signal.SIGTERM) + except (ProcessLookupError, OSError): + proc.send_signal(signal.SIGTERM) # Wait for process to exit try: @@ -160,8 +173,11 @@ class TestSignalHandlingIntegration: response = make_request('127.0.0.1', port) assert b'Hello, World!' in response - # Send SIGINT - proc.send_signal(signal.SIGINT) + # Send SIGINT to the process group for reliable signal delivery + try: + os.killpg(os.getpgid(proc.pid), signal.SIGINT) + except (ProcessLookupError, OSError): + proc.send_signal(signal.SIGINT) # Wait for process to exit try: @@ -179,7 +195,7 @@ class TestSignalHandlingIntegration: response = make_request('127.0.0.1', port) assert b'Hello, World!' in response - # Send SIGHUP + # Send SIGHUP to the master process (not process group - only master handles reload) proc.send_signal(signal.SIGHUP) # Wait a moment for reload