mirror of
https://github.com/frappe/gunicorn.git
synced 2026-07-01 18:21:30 +08:00
feat(ctl): add gunicornc control interface
Add a control socket server and CLI client for runtime management of Gunicorn instances, similar to birdc for BIRD routing daemon. Features: - Control socket server running in arbiter process (asyncio/threaded) - gunicornc CLI with interactive and single-command modes - JSON protocol with length-prefixed framing - Commands: show workers/stats/config/listeners/dirty, worker add/remove/kill, dirty add/remove, reload, reopen, shutdown - Stats tracking (uptime, workers spawned/killed, reloads) - Configurable socket path and permissions New config options: - control_socket: Unix socket path (default: gunicorn.ctl) - control_socket_mode: Socket permissions (default: 0o600) - --no-control-socket: Disable control socket
This commit is contained in:
parent
3cba17b84a
commit
a57507c4e5
@ -74,6 +74,17 @@ class Arbiter:
|
||||
self.dirty_arbiter = None
|
||||
self.dirty_pidfile = None # Well-known location for orphan detection
|
||||
|
||||
# Control socket server
|
||||
self._control_server = None
|
||||
|
||||
# Stats tracking
|
||||
self._stats = {
|
||||
'start_time': None,
|
||||
'workers_spawned': 0,
|
||||
'workers_killed': 0,
|
||||
'reloads': 0,
|
||||
}
|
||||
|
||||
cwd = util.getcwd()
|
||||
|
||||
args = sys.argv[:]
|
||||
@ -133,6 +144,9 @@ class Arbiter:
|
||||
"""
|
||||
self.log.info("Starting gunicorn %s", __version__)
|
||||
|
||||
# Initialize stats tracking
|
||||
self._stats['start_time'] = time.time()
|
||||
|
||||
if 'GUNICORN_PID' in os.environ:
|
||||
self.master_pid = int(os.environ.get('GUNICORN_PID'))
|
||||
self.proc_name = self.proc_name + ".2"
|
||||
@ -179,6 +193,9 @@ class Arbiter:
|
||||
if self.cfg.dirty_workers > 0 and self.cfg.dirty_apps:
|
||||
self.spawn_dirty_arbiter()
|
||||
|
||||
# Start control socket server
|
||||
self._start_control_server()
|
||||
|
||||
self.cfg.when_ready(self)
|
||||
|
||||
def init_signals(self):
|
||||
@ -351,6 +368,9 @@ class Arbiter:
|
||||
|
||||
def halt(self, reason=None, exit_status=0):
|
||||
""" halt arbiter """
|
||||
# Stop control socket server first
|
||||
self._stop_control_server()
|
||||
|
||||
self.stop()
|
||||
|
||||
log_func = self.log.info if exit_status == 0 else self.log.error
|
||||
@ -477,6 +497,9 @@ class Arbiter:
|
||||
os.execvpe(self.START_CTX[0], self.START_CTX['args'], environ)
|
||||
|
||||
def reload(self):
|
||||
# Track reload stats
|
||||
self._stats['reloads'] += 1
|
||||
|
||||
old_address = self.cfg.address
|
||||
|
||||
# reset old environment
|
||||
@ -667,6 +690,7 @@ class Arbiter:
|
||||
if pid != 0:
|
||||
worker.pid = pid
|
||||
self.WORKERS[pid] = worker
|
||||
self._stats['workers_spawned'] += 1
|
||||
return pid
|
||||
|
||||
# Do not inherit the temporary files of other workers
|
||||
@ -737,6 +761,9 @@ class Arbiter:
|
||||
"""
|
||||
try:
|
||||
os.kill(pid, sig)
|
||||
# Track kills only on SIGTERM/SIGKILL (actual termination signals)
|
||||
if sig in (signal.SIGTERM, signal.SIGKILL):
|
||||
self._stats['workers_killed'] += 1
|
||||
except OSError as e:
|
||||
if e.errno == errno.ESRCH:
|
||||
try:
|
||||
@ -906,3 +933,51 @@ class Arbiter:
|
||||
if self.cfg.dirty_workers > 0 and self.cfg.dirty_apps:
|
||||
self.log.info("Spawning dirty arbiter...")
|
||||
self.spawn_dirty_arbiter()
|
||||
|
||||
# =========================================================================
|
||||
# Control Socket Management
|
||||
# =========================================================================
|
||||
|
||||
def _get_control_socket_path(self):
|
||||
"""Get the control socket path, making relative paths absolute."""
|
||||
socket_path = self.cfg.control_socket
|
||||
if not os.path.isabs(socket_path):
|
||||
socket_path = os.path.join(util.getcwd(), socket_path)
|
||||
return socket_path
|
||||
|
||||
def _start_control_server(self):
|
||||
"""\
|
||||
Start the control socket server.
|
||||
|
||||
The server runs in a background thread and accepts commands
|
||||
via Unix socket.
|
||||
"""
|
||||
if self.cfg.control_socket_disable:
|
||||
self.log.debug("Control socket disabled")
|
||||
return
|
||||
|
||||
# Lazy import to avoid circular imports and gevent compatibility
|
||||
from gunicorn.ctl.server import ControlSocketServer
|
||||
|
||||
socket_path = self._get_control_socket_path()
|
||||
socket_mode = self.cfg.control_socket_mode
|
||||
|
||||
try:
|
||||
self._control_server = ControlSocketServer(
|
||||
self, socket_path, socket_mode
|
||||
)
|
||||
self._control_server.start()
|
||||
except Exception as e:
|
||||
self.log.warning("Failed to start control socket: %s", e)
|
||||
self._control_server = None
|
||||
|
||||
def _stop_control_server(self):
|
||||
"""\
|
||||
Stop the control socket server.
|
||||
"""
|
||||
if self._control_server:
|
||||
try:
|
||||
self._control_server.stop()
|
||||
except Exception as e:
|
||||
self.log.debug("Error stopping control server: %s", e)
|
||||
self._control_server = None
|
||||
|
||||
@ -3113,3 +3113,63 @@ class DirtyWorkerExit(Setting):
|
||||
|
||||
.. versionadded:: 25.0.0
|
||||
"""
|
||||
|
||||
|
||||
# Control Socket Settings
|
||||
|
||||
class ControlSocket(Setting):
|
||||
name = "control_socket"
|
||||
section = "Control"
|
||||
cli = ["--control-socket"]
|
||||
meta = "PATH"
|
||||
validator = validate_string
|
||||
default = "gunicorn.ctl"
|
||||
desc = """\
|
||||
Unix socket path for control interface.
|
||||
|
||||
The control socket allows runtime management of Gunicorn via the
|
||||
``gunicornc`` command-line tool. Commands include viewing worker
|
||||
status, adjusting worker count, and graceful reload/shutdown.
|
||||
|
||||
By default, creates ``gunicorn.ctl`` in the working directory.
|
||||
Set an absolute path for a fixed location (e.g., ``/var/run/gunicorn.ctl``).
|
||||
|
||||
Use ``--no-control-socket`` to disable.
|
||||
|
||||
.. versionadded:: 25.1.0
|
||||
"""
|
||||
|
||||
|
||||
class ControlSocketMode(Setting):
|
||||
name = "control_socket_mode"
|
||||
section = "Control"
|
||||
cli = ["--control-socket-mode"]
|
||||
meta = "INT"
|
||||
validator = validate_pos_int
|
||||
type = auto_int
|
||||
default = 0o600
|
||||
desc = """\
|
||||
Permission mode for control socket.
|
||||
|
||||
Restricts who can connect to the control socket. Default ``0600``
|
||||
allows only the socket owner. Set to ``0660`` to allow group access.
|
||||
|
||||
.. versionadded:: 25.1.0
|
||||
"""
|
||||
|
||||
|
||||
class ControlSocketDisable(Setting):
|
||||
name = "control_socket_disable"
|
||||
section = "Control"
|
||||
cli = ["--no-control-socket"]
|
||||
validator = validate_bool
|
||||
action = "store_true"
|
||||
default = False
|
||||
desc = """\
|
||||
Disable control socket.
|
||||
|
||||
When set, no control socket is created and ``gunicornc`` cannot
|
||||
connect to this Gunicorn instance.
|
||||
|
||||
.. versionadded:: 25.1.0
|
||||
"""
|
||||
|
||||
16
gunicorn/ctl/__init__.py
Normal file
16
gunicorn/ctl/__init__.py
Normal file
@ -0,0 +1,16 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
Gunicorn Control Interface
|
||||
|
||||
Provides a control socket server for runtime management and
|
||||
a CLI client (gunicornc) for interacting with running Gunicorn instances.
|
||||
"""
|
||||
|
||||
from gunicorn.ctl.server import ControlSocketServer
|
||||
from gunicorn.ctl.client import ControlClient
|
||||
from gunicorn.ctl.protocol import ControlProtocol
|
||||
|
||||
__all__ = ['ControlSocketServer', 'ControlClient', 'ControlProtocol']
|
||||
385
gunicorn/ctl/cli.py
Normal file
385
gunicorn/ctl/cli.py
Normal file
@ -0,0 +1,385 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
gunicornc - Gunicorn control interface CLI
|
||||
|
||||
Interactive and single-command modes for controlling Gunicorn instances.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
from gunicorn.ctl.client import ControlClient, ControlClientError, parse_command
|
||||
|
||||
|
||||
def format_workers(data: dict) -> str:
|
||||
"""Format workers output for display."""
|
||||
workers = data.get("workers", [])
|
||||
if not workers:
|
||||
return "No workers running"
|
||||
|
||||
lines = []
|
||||
lines.append(f"{'PID':<10} {'AGE':<6} {'BOOTED':<8} {'LAST_BEAT'}")
|
||||
lines.append("-" * 40)
|
||||
|
||||
for w in workers:
|
||||
pid = w.get("pid", "?")
|
||||
age = w.get("age", "?")
|
||||
booted = "yes" if w.get("booted") else "no"
|
||||
hb = w.get("last_heartbeat")
|
||||
hb_str = f"{hb}s ago" if hb is not None else "n/a"
|
||||
|
||||
lines.append(f"{pid:<10} {age:<6} {booted:<8} {hb_str}")
|
||||
|
||||
lines.append("")
|
||||
lines.append(f"Total: {data.get('count', len(workers))} workers")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def format_dirty(data: dict) -> str:
|
||||
"""Format dirty workers output for display."""
|
||||
if not data.get("enabled"):
|
||||
return "Dirty arbiter not running"
|
||||
|
||||
lines = []
|
||||
lines.append(f"Dirty arbiter PID: {data.get('pid')}")
|
||||
lines.append("")
|
||||
|
||||
workers = data.get("workers", [])
|
||||
if workers:
|
||||
lines.append("DIRTY WORKERS:")
|
||||
lines.append(f"{'PID':<10} {'AGE':<6} {'APPS':<30} {'LAST_BEAT'}")
|
||||
lines.append("-" * 60)
|
||||
|
||||
for w in workers:
|
||||
pid = w.get("pid", "?")
|
||||
age = w.get("age", "?")
|
||||
apps = ", ".join(w.get("apps", []))[:30]
|
||||
hb = w.get("last_heartbeat")
|
||||
hb_str = f"{hb}s ago" if hb is not None else "n/a"
|
||||
|
||||
lines.append(f"{pid:<10} {age:<6} {apps:<30} {hb_str}")
|
||||
lines.append("")
|
||||
|
||||
apps = data.get("apps", [])
|
||||
if apps:
|
||||
lines.append("DIRTY APPS:")
|
||||
lines.append(f"{'APP':<30} {'WORKERS':<10} {'LIMIT'}")
|
||||
lines.append("-" * 50)
|
||||
|
||||
for app in apps:
|
||||
path = app.get("import_path", "?")[:30]
|
||||
current = app.get("current_workers", 0)
|
||||
limit = app.get("worker_count")
|
||||
limit_str = str(limit) if limit is not None else "none"
|
||||
|
||||
lines.append(f"{path:<30} {current:<10} {limit_str}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def format_stats(data: dict) -> str:
|
||||
"""Format stats output for display."""
|
||||
lines = []
|
||||
|
||||
uptime = data.get("uptime")
|
||||
if uptime:
|
||||
hours = int(uptime // 3600)
|
||||
minutes = int((uptime % 3600) // 60)
|
||||
seconds = int(uptime % 60)
|
||||
if hours:
|
||||
uptime_str = f"{hours}h {minutes}m {seconds}s"
|
||||
elif minutes:
|
||||
uptime_str = f"{minutes}m {seconds}s"
|
||||
else:
|
||||
uptime_str = f"{seconds}s"
|
||||
else:
|
||||
uptime_str = "unknown"
|
||||
|
||||
lines.append(f"Uptime: {uptime_str}")
|
||||
lines.append(f"PID: {data.get('pid', 'unknown')}")
|
||||
lines.append(f"Workers current: {data.get('workers_current', 0)}")
|
||||
lines.append(f"Workers target: {data.get('workers_target', 0)}")
|
||||
lines.append(f"Workers spawned: {data.get('workers_spawned', 0)}")
|
||||
lines.append(f"Workers killed: {data.get('workers_killed', 0)}")
|
||||
lines.append(f"Reloads: {data.get('reloads', 0)}")
|
||||
|
||||
dirty_pid = data.get("dirty_arbiter_pid")
|
||||
if dirty_pid:
|
||||
lines.append(f"Dirty arbiter: {dirty_pid}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def format_listeners(data: dict) -> str:
|
||||
"""Format listeners output for display."""
|
||||
listeners = data.get("listeners", [])
|
||||
if not listeners:
|
||||
return "No listeners bound"
|
||||
|
||||
lines = []
|
||||
lines.append(f"{'ADDRESS':<40} {'TYPE':<8} {'FD'}")
|
||||
lines.append("-" * 55)
|
||||
|
||||
for lnr in listeners:
|
||||
addr = lnr.get("address", "?")
|
||||
ltype = lnr.get("type", "?")
|
||||
fd = lnr.get("fd", "?")
|
||||
lines.append(f"{addr:<40} {ltype:<8} {fd}")
|
||||
|
||||
lines.append("")
|
||||
lines.append(f"Total: {data.get('count', len(listeners))} listeners")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def format_config(data: dict) -> str:
|
||||
"""Format config output for display."""
|
||||
lines = []
|
||||
|
||||
# Sort keys for consistent output
|
||||
for key in sorted(data.keys()):
|
||||
value = data[key]
|
||||
if isinstance(value, list):
|
||||
value = ", ".join(str(v) for v in value)
|
||||
lines.append(f"{key}: {value}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def format_help(data: dict) -> str:
|
||||
"""Format help output for display."""
|
||||
commands = data.get("commands", {})
|
||||
lines = []
|
||||
lines.append("Available commands:")
|
||||
lines.append("")
|
||||
|
||||
# Find max command length for alignment
|
||||
max_len = max(len(cmd) for cmd in commands.keys()) if commands else 0
|
||||
|
||||
for cmd, desc in sorted(commands.items()):
|
||||
lines.append(f" {cmd:<{max_len + 2}} {desc}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def format_response(command: str, data: dict) -> str:
|
||||
"""
|
||||
Format response data based on command.
|
||||
|
||||
Args:
|
||||
command: Original command string
|
||||
data: Response data dictionary
|
||||
|
||||
Returns:
|
||||
Formatted string for display
|
||||
"""
|
||||
cmd_lower = command.lower().strip()
|
||||
|
||||
# Route to specific formatters
|
||||
if cmd_lower == "show workers":
|
||||
return format_workers(data)
|
||||
elif cmd_lower == "show dirty":
|
||||
return format_dirty(data)
|
||||
elif cmd_lower == "show stats":
|
||||
return format_stats(data)
|
||||
elif cmd_lower == "show listeners":
|
||||
return format_listeners(data)
|
||||
elif cmd_lower == "show config":
|
||||
return format_config(data)
|
||||
elif cmd_lower == "help":
|
||||
return format_help(data)
|
||||
else:
|
||||
# Generic JSON output for other commands
|
||||
if data:
|
||||
return json.dumps(data, indent=2)
|
||||
return "OK"
|
||||
|
||||
|
||||
def run_command(socket_path: str, command: str, json_output: bool = False) -> int:
|
||||
"""
|
||||
Execute single command and exit.
|
||||
|
||||
Args:
|
||||
socket_path: Path to control socket
|
||||
command: Command to execute
|
||||
json_output: If True, output raw JSON
|
||||
|
||||
Returns:
|
||||
Exit code (0 for success, 1 for error)
|
||||
"""
|
||||
try:
|
||||
with ControlClient(socket_path) as client:
|
||||
cmd, args = parse_command(command)
|
||||
full_command = f"{cmd} {' '.join(args)}".strip() if args else cmd
|
||||
result = client.send_command(full_command)
|
||||
|
||||
if json_output:
|
||||
print(json.dumps(result, indent=2))
|
||||
else:
|
||||
output = format_response(cmd, result)
|
||||
print(output)
|
||||
|
||||
return 0
|
||||
|
||||
except ControlClientError as e:
|
||||
print(f"Error: {e}", file=sys.stderr)
|
||||
return 1
|
||||
except KeyboardInterrupt:
|
||||
return 130
|
||||
|
||||
|
||||
def run_interactive(socket_path: str, json_output: bool = False) -> int:
|
||||
"""
|
||||
Run interactive CLI with readline support.
|
||||
|
||||
Args:
|
||||
socket_path: Path to control socket
|
||||
json_output: If True, output raw JSON
|
||||
|
||||
Returns:
|
||||
Exit code
|
||||
"""
|
||||
try:
|
||||
import readline # noqa: F401 - imported for side effects
|
||||
has_readline = True
|
||||
except ImportError:
|
||||
has_readline = False
|
||||
|
||||
try:
|
||||
client = ControlClient(socket_path)
|
||||
client.connect()
|
||||
except ControlClientError as e:
|
||||
print(f"Error: {e}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
print(f"Connected to {socket_path}")
|
||||
print("Type 'help' for available commands, 'quit' to exit.")
|
||||
print()
|
||||
|
||||
# Set up readline history
|
||||
history_file = os.path.expanduser("~/.gunicornc_history")
|
||||
if has_readline:
|
||||
try:
|
||||
readline.read_history_file(history_file)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
exit_code = 0
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
line = input("gunicorn> ").strip()
|
||||
except EOFError:
|
||||
print()
|
||||
break
|
||||
|
||||
if not line:
|
||||
continue
|
||||
|
||||
if line.lower() in ('quit', 'exit', 'q'):
|
||||
break
|
||||
|
||||
try:
|
||||
cmd, args = parse_command(line)
|
||||
full_command = f"{cmd} {' '.join(args)}".strip() if args else cmd
|
||||
result = client.send_command(full_command)
|
||||
|
||||
if json_output:
|
||||
print(json.dumps(result, indent=2))
|
||||
else:
|
||||
output = format_response(cmd, result)
|
||||
print(output)
|
||||
|
||||
except ControlClientError as e:
|
||||
print(f"Error: {e}")
|
||||
# Try to reconnect
|
||||
try:
|
||||
client.close()
|
||||
client.connect()
|
||||
except ControlClientError:
|
||||
print("Connection lost. Exiting.")
|
||||
exit_code = 1
|
||||
break
|
||||
|
||||
print()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print()
|
||||
exit_code = 130
|
||||
finally:
|
||||
client.close()
|
||||
if has_readline:
|
||||
try:
|
||||
readline.write_history_file(history_file)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return exit_code
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point for gunicornc CLI."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Gunicorn control interface',
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
gunicornc # Interactive mode (default socket)
|
||||
gunicornc -s /tmp/myapp.ctl # Interactive mode with custom socket
|
||||
gunicornc -c "show workers" # Single command mode
|
||||
gunicornc -c "worker add 2" # Add 2 workers
|
||||
gunicornc -c "show stats" -j # Output stats as JSON
|
||||
"""
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-s', '--socket',
|
||||
default='gunicorn.ctl',
|
||||
help='Control socket path (default: gunicorn.ctl in current directory)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-c', '--command',
|
||||
help='Execute single command and exit'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-j', '--json',
|
||||
action='store_true',
|
||||
help='Output raw JSON (for scripting)'
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'-v', '--version',
|
||||
action='store_true',
|
||||
help='Show version and exit'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.version:
|
||||
from gunicorn import __version__
|
||||
print(f"gunicornc (gunicorn {__version__})")
|
||||
return 0
|
||||
|
||||
socket_path = args.socket
|
||||
|
||||
# Make relative paths absolute from cwd
|
||||
if not os.path.isabs(socket_path):
|
||||
socket_path = os.path.join(os.getcwd(), socket_path)
|
||||
|
||||
if args.command:
|
||||
return run_command(socket_path, args.command, args.json)
|
||||
else:
|
||||
return run_interactive(socket_path, args.json)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main())
|
||||
140
gunicorn/ctl/client.py
Normal file
140
gunicorn/ctl/client.py
Normal file
@ -0,0 +1,140 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
Control Socket Client
|
||||
|
||||
Client library for connecting to gunicorn control socket.
|
||||
"""
|
||||
|
||||
import shlex
|
||||
import socket
|
||||
|
||||
from gunicorn.ctl.protocol import (
|
||||
ControlProtocol,
|
||||
make_request,
|
||||
)
|
||||
|
||||
|
||||
class ControlClientError(Exception):
|
||||
"""Control client error."""
|
||||
pass
|
||||
|
||||
|
||||
class ControlClient:
|
||||
"""
|
||||
Client for connecting to gunicorn control socket.
|
||||
|
||||
Can be used as a context manager:
|
||||
|
||||
with ControlClient('/path/to/gunicorn.ctl') as client:
|
||||
result = client.send_command('show workers')
|
||||
"""
|
||||
|
||||
def __init__(self, socket_path: str, timeout: float = 30.0):
|
||||
"""
|
||||
Initialize control client.
|
||||
|
||||
Args:
|
||||
socket_path: Path to the Unix socket
|
||||
timeout: Socket timeout in seconds (default 30)
|
||||
"""
|
||||
self.socket_path = socket_path
|
||||
self.timeout = timeout
|
||||
self._sock = None
|
||||
self._request_id = 0
|
||||
|
||||
def connect(self):
|
||||
"""
|
||||
Connect to control socket.
|
||||
|
||||
Raises:
|
||||
ControlClientError: If connection fails
|
||||
"""
|
||||
if self._sock:
|
||||
return
|
||||
|
||||
try:
|
||||
self._sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
self._sock.settimeout(self.timeout)
|
||||
self._sock.connect(self.socket_path)
|
||||
except socket.error as e:
|
||||
self._sock = None
|
||||
raise ControlClientError(f"Failed to connect to {self.socket_path}: {e}")
|
||||
|
||||
def close(self):
|
||||
"""Close connection."""
|
||||
if self._sock:
|
||||
try:
|
||||
self._sock.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._sock = None
|
||||
|
||||
def send_command(self, command: str, args: list = None) -> dict:
|
||||
"""
|
||||
Send command and wait for response.
|
||||
|
||||
Args:
|
||||
command: Command string (e.g., "show workers")
|
||||
args: Optional additional arguments
|
||||
|
||||
Returns:
|
||||
Response data dictionary
|
||||
|
||||
Raises:
|
||||
ControlClientError: If communication fails
|
||||
"""
|
||||
if not self._sock:
|
||||
self.connect()
|
||||
|
||||
self._request_id += 1
|
||||
request = make_request(self._request_id, command, args)
|
||||
|
||||
try:
|
||||
ControlProtocol.write_message(self._sock, request)
|
||||
response = ControlProtocol.read_message(self._sock)
|
||||
except Exception as e:
|
||||
self.close()
|
||||
raise ControlClientError(f"Communication error: {e}")
|
||||
|
||||
if response.get("status") == "error":
|
||||
raise ControlClientError(response.get("error", "Unknown error"))
|
||||
|
||||
return response.get("data", {})
|
||||
|
||||
def __enter__(self):
|
||||
self.connect()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.close()
|
||||
|
||||
|
||||
def parse_command(line: str) -> tuple:
|
||||
"""
|
||||
Parse a command line into command and args.
|
||||
|
||||
Args:
|
||||
line: Command line string
|
||||
|
||||
Returns:
|
||||
Tuple of (command_string, args_list)
|
||||
"""
|
||||
parts = shlex.split(line)
|
||||
if not parts:
|
||||
return "", []
|
||||
|
||||
# Find where numeric/value args start
|
||||
command_parts = []
|
||||
args = []
|
||||
|
||||
for part in parts:
|
||||
# If we haven't hit args yet and this looks like a command word
|
||||
if not args and not part.isdigit() and not part.startswith('-'):
|
||||
command_parts.append(part)
|
||||
else:
|
||||
args.append(part)
|
||||
|
||||
return " ".join(command_parts), args
|
||||
431
gunicorn/ctl/handlers.py
Normal file
431
gunicorn/ctl/handlers.py
Normal file
@ -0,0 +1,431 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
Control Interface Command Handlers
|
||||
|
||||
Provides handlers for all control commands with access to arbiter state.
|
||||
"""
|
||||
|
||||
import os
|
||||
import signal
|
||||
import time
|
||||
|
||||
|
||||
class CommandHandlers:
|
||||
"""
|
||||
Command handlers with access to arbiter state.
|
||||
|
||||
All handler methods return dictionaries that will be sent
|
||||
as the response data.
|
||||
"""
|
||||
|
||||
def __init__(self, arbiter):
|
||||
"""
|
||||
Initialize handlers with arbiter reference.
|
||||
|
||||
Args:
|
||||
arbiter: The Gunicorn arbiter instance
|
||||
"""
|
||||
self.arbiter = arbiter
|
||||
|
||||
def show_workers(self) -> dict:
|
||||
"""
|
||||
Return list of HTTP workers.
|
||||
|
||||
Returns:
|
||||
Dictionary with workers list containing:
|
||||
- pid: Worker process ID
|
||||
- age: Worker age (spawn order)
|
||||
- requests: Number of requests handled (if available)
|
||||
- booted: Whether worker has finished booting
|
||||
- last_heartbeat: Seconds since last heartbeat
|
||||
"""
|
||||
workers = []
|
||||
now = time.monotonic()
|
||||
|
||||
for pid, worker in self.arbiter.WORKERS.items():
|
||||
try:
|
||||
last_update = worker.tmp.last_update()
|
||||
last_heartbeat = round(now - last_update, 2)
|
||||
except (OSError, ValueError):
|
||||
last_heartbeat = None
|
||||
|
||||
workers.append({
|
||||
"pid": pid,
|
||||
"age": worker.age,
|
||||
"booted": worker.booted,
|
||||
"aborted": worker.aborted,
|
||||
"last_heartbeat": last_heartbeat,
|
||||
})
|
||||
|
||||
# Sort by age (oldest first)
|
||||
workers.sort(key=lambda w: w["age"])
|
||||
|
||||
return {"workers": workers, "count": len(workers)}
|
||||
|
||||
def show_dirty(self) -> dict:
|
||||
"""
|
||||
Return dirty workers and apps information.
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
- enabled: Whether dirty arbiter is running
|
||||
- pid: Dirty arbiter PID
|
||||
- workers: List of dirty worker info
|
||||
- apps: List of dirty app specs
|
||||
"""
|
||||
if not self.arbiter.dirty_arbiter_pid:
|
||||
return {
|
||||
"enabled": False,
|
||||
"pid": None,
|
||||
"workers": [],
|
||||
"apps": [],
|
||||
}
|
||||
|
||||
# Get dirty arbiter reference if available
|
||||
dirty_arbiter = getattr(self.arbiter, 'dirty_arbiter', None)
|
||||
|
||||
workers = []
|
||||
apps = []
|
||||
|
||||
if dirty_arbiter and hasattr(dirty_arbiter, 'workers'):
|
||||
now = time.monotonic()
|
||||
for pid, worker in dirty_arbiter.workers.items():
|
||||
try:
|
||||
last_update = worker.tmp.last_update()
|
||||
last_heartbeat = round(now - last_update, 2)
|
||||
except (OSError, ValueError, AttributeError):
|
||||
last_heartbeat = None
|
||||
|
||||
workers.append({
|
||||
"pid": pid,
|
||||
"age": worker.age,
|
||||
"apps": getattr(worker, 'app_paths', []),
|
||||
"booted": getattr(worker, 'booted', False),
|
||||
"last_heartbeat": last_heartbeat,
|
||||
})
|
||||
|
||||
# Get app specs
|
||||
if hasattr(dirty_arbiter, 'app_specs'):
|
||||
for path, spec in dirty_arbiter.app_specs.items():
|
||||
worker_pids = list(dirty_arbiter.app_worker_map.get(path, []))
|
||||
apps.append({
|
||||
"import_path": path,
|
||||
"worker_count": spec.get('worker_count'),
|
||||
"current_workers": len(worker_pids),
|
||||
"worker_pids": worker_pids,
|
||||
})
|
||||
|
||||
return {
|
||||
"enabled": True,
|
||||
"pid": self.arbiter.dirty_arbiter_pid,
|
||||
"workers": workers,
|
||||
"apps": apps,
|
||||
}
|
||||
|
||||
def show_config(self) -> dict:
|
||||
"""
|
||||
Return current effective configuration.
|
||||
|
||||
Returns:
|
||||
Dictionary of configuration values
|
||||
"""
|
||||
cfg = self.arbiter.cfg
|
||||
config = {}
|
||||
|
||||
# Get commonly needed config values
|
||||
config_keys = [
|
||||
'bind', 'workers', 'worker_class', 'threads', 'timeout',
|
||||
'graceful_timeout', 'keepalive', 'max_requests',
|
||||
'max_requests_jitter', 'worker_connections', 'preload_app',
|
||||
'daemon', 'pidfile', 'proc_name', 'reload',
|
||||
'dirty_workers', 'dirty_apps', 'dirty_timeout',
|
||||
'control_socket', 'control_socket_disable',
|
||||
]
|
||||
|
||||
for key in config_keys:
|
||||
try:
|
||||
value = getattr(cfg, key)
|
||||
# Convert non-serializable types
|
||||
if callable(value):
|
||||
value = str(value)
|
||||
elif hasattr(value, '__class__') and not isinstance(
|
||||
value, (str, int, float, bool, list, dict, type(None))):
|
||||
value = str(value)
|
||||
config[key] = value
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
return config
|
||||
|
||||
def show_stats(self) -> dict:
|
||||
"""
|
||||
Return server statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
- uptime: Seconds since arbiter started
|
||||
- pid: Arbiter PID
|
||||
- workers_current: Current number of workers
|
||||
- workers_spawned: Total workers spawned
|
||||
- workers_killed: Total workers killed (if tracked)
|
||||
- reloads: Number of reloads (if tracked)
|
||||
"""
|
||||
stats = getattr(self.arbiter, '_stats', {})
|
||||
start_time = stats.get('start_time')
|
||||
|
||||
uptime = None
|
||||
if start_time:
|
||||
uptime = round(time.time() - start_time, 2)
|
||||
|
||||
return {
|
||||
"uptime": uptime,
|
||||
"pid": self.arbiter.pid,
|
||||
"workers_current": len(self.arbiter.WORKERS),
|
||||
"workers_target": self.arbiter.num_workers,
|
||||
"workers_spawned": stats.get('workers_spawned', 0),
|
||||
"workers_killed": stats.get('workers_killed', 0),
|
||||
"reloads": stats.get('reloads', 0),
|
||||
"dirty_arbiter_pid": self.arbiter.dirty_arbiter_pid or None,
|
||||
}
|
||||
|
||||
def show_listeners(self) -> dict:
|
||||
"""
|
||||
Return bound socket information.
|
||||
|
||||
Returns:
|
||||
Dictionary with listeners list
|
||||
"""
|
||||
listeners = []
|
||||
|
||||
for lnr in self.arbiter.LISTENERS:
|
||||
addr = str(lnr)
|
||||
listener_info = {
|
||||
"address": addr,
|
||||
"fd": lnr.fileno(),
|
||||
}
|
||||
|
||||
# Try to get socket family
|
||||
try:
|
||||
import socket
|
||||
sock = lnr.sock
|
||||
if sock.family == socket.AF_UNIX:
|
||||
listener_info["type"] = "unix"
|
||||
elif sock.family == socket.AF_INET:
|
||||
listener_info["type"] = "tcp"
|
||||
elif sock.family == socket.AF_INET6:
|
||||
listener_info["type"] = "tcp6"
|
||||
except Exception:
|
||||
listener_info["type"] = "unknown"
|
||||
|
||||
listeners.append(listener_info)
|
||||
|
||||
return {"listeners": listeners, "count": len(listeners)}
|
||||
|
||||
def worker_add(self, count: int = 1) -> dict:
|
||||
"""
|
||||
Increase worker count.
|
||||
|
||||
Args:
|
||||
count: Number of workers to add (default 1)
|
||||
|
||||
Returns:
|
||||
Dictionary with added count and new total
|
||||
"""
|
||||
count = max(1, int(count))
|
||||
old_count = self.arbiter.num_workers
|
||||
self.arbiter.num_workers += count
|
||||
|
||||
# Wake up the arbiter to spawn workers
|
||||
self.arbiter.wakeup()
|
||||
|
||||
return {
|
||||
"added": count,
|
||||
"previous": old_count,
|
||||
"total": self.arbiter.num_workers,
|
||||
}
|
||||
|
||||
def worker_remove(self, count: int = 1) -> dict:
|
||||
"""
|
||||
Decrease worker count.
|
||||
|
||||
Args:
|
||||
count: Number of workers to remove (default 1)
|
||||
|
||||
Returns:
|
||||
Dictionary with removed count and new total
|
||||
"""
|
||||
count = max(1, int(count))
|
||||
old_count = self.arbiter.num_workers
|
||||
|
||||
# Don't go below 1 worker
|
||||
new_count = max(1, old_count - count)
|
||||
actual_removed = old_count - new_count
|
||||
|
||||
self.arbiter.num_workers = new_count
|
||||
|
||||
# Wake up the arbiter to kill excess workers
|
||||
self.arbiter.wakeup()
|
||||
|
||||
return {
|
||||
"removed": actual_removed,
|
||||
"previous": old_count,
|
||||
"total": new_count,
|
||||
}
|
||||
|
||||
def worker_kill(self, pid: int) -> dict:
|
||||
"""
|
||||
Gracefully terminate a specific worker.
|
||||
|
||||
Args:
|
||||
pid: Worker process ID
|
||||
|
||||
Returns:
|
||||
Dictionary with killed PID or error
|
||||
"""
|
||||
pid = int(pid)
|
||||
|
||||
if pid not in self.arbiter.WORKERS:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Worker {pid} not found",
|
||||
}
|
||||
|
||||
try:
|
||||
os.kill(pid, signal.SIGTERM)
|
||||
return {
|
||||
"success": True,
|
||||
"killed": pid,
|
||||
}
|
||||
except OSError as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
def dirty_add(self, count: int = 1) -> dict:
|
||||
"""
|
||||
Spawn additional dirty workers.
|
||||
|
||||
Args:
|
||||
count: Number of dirty workers to add (default 1)
|
||||
|
||||
Returns:
|
||||
Dictionary with added count or error
|
||||
"""
|
||||
if not self.arbiter.dirty_arbiter_pid:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Dirty arbiter not running",
|
||||
}
|
||||
|
||||
# Send TTIN signals to dirty arbiter
|
||||
count = max(1, int(count))
|
||||
try:
|
||||
for _ in range(count):
|
||||
os.kill(self.arbiter.dirty_arbiter_pid, signal.SIGTTIN)
|
||||
return {
|
||||
"success": True,
|
||||
"added": count,
|
||||
}
|
||||
except OSError as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
def dirty_remove(self, count: int = 1) -> dict:
|
||||
"""
|
||||
Remove dirty workers.
|
||||
|
||||
Args:
|
||||
count: Number of dirty workers to remove (default 1)
|
||||
|
||||
Returns:
|
||||
Dictionary with removed count or error
|
||||
"""
|
||||
if not self.arbiter.dirty_arbiter_pid:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Dirty arbiter not running",
|
||||
}
|
||||
|
||||
# Send TTOU signals to dirty arbiter
|
||||
count = max(1, int(count))
|
||||
try:
|
||||
for _ in range(count):
|
||||
os.kill(self.arbiter.dirty_arbiter_pid, signal.SIGTTOU)
|
||||
return {
|
||||
"success": True,
|
||||
"removed": count,
|
||||
}
|
||||
except OSError as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
def reload(self) -> dict:
|
||||
"""
|
||||
Trigger graceful reload (equivalent to SIGHUP).
|
||||
|
||||
Returns:
|
||||
Dictionary with status
|
||||
"""
|
||||
# Send HUP to self to trigger reload
|
||||
os.kill(self.arbiter.pid, signal.SIGHUP)
|
||||
return {"status": "reloading"}
|
||||
|
||||
def reopen(self) -> dict:
|
||||
"""
|
||||
Reopen log files (equivalent to SIGUSR1).
|
||||
|
||||
Returns:
|
||||
Dictionary with status
|
||||
"""
|
||||
os.kill(self.arbiter.pid, signal.SIGUSR1)
|
||||
return {"status": "reopening"}
|
||||
|
||||
def shutdown(self, mode: str = "graceful") -> dict:
|
||||
"""
|
||||
Initiate shutdown.
|
||||
|
||||
Args:
|
||||
mode: "graceful" (SIGTERM) or "quick" (SIGINT)
|
||||
|
||||
Returns:
|
||||
Dictionary with status
|
||||
"""
|
||||
if mode == "quick":
|
||||
os.kill(self.arbiter.pid, signal.SIGINT)
|
||||
else:
|
||||
os.kill(self.arbiter.pid, signal.SIGTERM)
|
||||
|
||||
return {"status": "shutting_down", "mode": mode}
|
||||
|
||||
def help(self) -> dict:
|
||||
"""
|
||||
Return list of available commands.
|
||||
|
||||
Returns:
|
||||
Dictionary with commands and descriptions
|
||||
"""
|
||||
commands = {
|
||||
"show workers": "List HTTP workers with their status",
|
||||
"show dirty": "List dirty workers and apps",
|
||||
"show config": "Show current effective configuration",
|
||||
"show stats": "Show server statistics",
|
||||
"show listeners": "Show bound sockets",
|
||||
"worker add [N]": "Spawn N workers (default 1)",
|
||||
"worker remove [N]": "Remove N workers (default 1)",
|
||||
"worker kill <PID>": "Gracefully terminate specific worker",
|
||||
"dirty add [N]": "Spawn N dirty workers (default 1)",
|
||||
"dirty remove [N]": "Remove N dirty workers (default 1)",
|
||||
"reload": "Graceful reload (HUP)",
|
||||
"reopen": "Reopen log files (USR1)",
|
||||
"shutdown [graceful|quick]": "Shutdown server (TERM/INT)",
|
||||
"help": "Show this help message",
|
||||
}
|
||||
return {"commands": commands}
|
||||
225
gunicorn/ctl/protocol.py
Normal file
225
gunicorn/ctl/protocol.py
Normal file
@ -0,0 +1,225 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
Control Socket Protocol
|
||||
|
||||
JSON-based protocol with length-prefixed framing for the control interface.
|
||||
|
||||
Message Format:
|
||||
+----------------+------------------+
|
||||
| Length (4B BE) | JSON Payload |
|
||||
+----------------+------------------+
|
||||
|
||||
Request Format:
|
||||
{"id": 1, "command": "show", "args": ["workers"]}
|
||||
|
||||
Response Format:
|
||||
{"id": 1, "status": "ok", "data": {...}}
|
||||
{"id": 1, "status": "error", "error": "message"}
|
||||
"""
|
||||
|
||||
import json
|
||||
import struct
|
||||
|
||||
|
||||
class ProtocolError(Exception):
|
||||
"""Protocol-level error."""
|
||||
pass
|
||||
|
||||
|
||||
class ControlProtocol:
|
||||
"""
|
||||
Protocol implementation for control socket communication.
|
||||
|
||||
Uses 4-byte big-endian length prefix followed by JSON payload.
|
||||
"""
|
||||
|
||||
# Maximum message size (16 MB)
|
||||
MAX_MESSAGE_SIZE = 16 * 1024 * 1024
|
||||
|
||||
@staticmethod
|
||||
def encode_message(data: dict) -> bytes:
|
||||
"""
|
||||
Encode a message for transmission.
|
||||
|
||||
Args:
|
||||
data: Dictionary to encode
|
||||
|
||||
Returns:
|
||||
Length-prefixed JSON bytes
|
||||
"""
|
||||
payload = json.dumps(data).encode('utf-8')
|
||||
length = struct.pack('>I', len(payload))
|
||||
return length + payload
|
||||
|
||||
@staticmethod
|
||||
def decode_message(data: bytes) -> dict:
|
||||
"""
|
||||
Decode a message from bytes.
|
||||
|
||||
Args:
|
||||
data: Raw bytes (length prefix + JSON payload)
|
||||
|
||||
Returns:
|
||||
Decoded dictionary
|
||||
"""
|
||||
if len(data) < 4:
|
||||
raise ProtocolError("Message too short")
|
||||
|
||||
length = struct.unpack('>I', data[:4])[0]
|
||||
if len(data) < 4 + length:
|
||||
raise ProtocolError("Incomplete message")
|
||||
|
||||
payload = data[4:4 + length]
|
||||
return json.loads(payload.decode('utf-8'))
|
||||
|
||||
@staticmethod
|
||||
def read_message(sock) -> dict:
|
||||
"""
|
||||
Read one message from a socket.
|
||||
|
||||
Args:
|
||||
sock: Socket to read from
|
||||
|
||||
Returns:
|
||||
Decoded message dictionary
|
||||
|
||||
Raises:
|
||||
ProtocolError: If message is malformed
|
||||
ConnectionError: If connection is closed
|
||||
"""
|
||||
# Read length prefix
|
||||
length_data = b''
|
||||
while len(length_data) < 4:
|
||||
chunk = sock.recv(4 - len(length_data))
|
||||
if not chunk:
|
||||
if not length_data:
|
||||
raise ConnectionError("Connection closed")
|
||||
raise ProtocolError("Incomplete length prefix")
|
||||
length_data += chunk
|
||||
|
||||
length = struct.unpack('>I', length_data)[0]
|
||||
|
||||
if length > ControlProtocol.MAX_MESSAGE_SIZE:
|
||||
raise ProtocolError(f"Message too large: {length}")
|
||||
|
||||
# Read payload
|
||||
payload_data = b''
|
||||
while len(payload_data) < length:
|
||||
chunk = sock.recv(min(length - len(payload_data), 65536))
|
||||
if not chunk:
|
||||
raise ProtocolError("Incomplete payload")
|
||||
payload_data += chunk
|
||||
|
||||
try:
|
||||
return json.loads(payload_data.decode('utf-8'))
|
||||
except json.JSONDecodeError as e:
|
||||
raise ProtocolError(f"Invalid JSON: {e}")
|
||||
|
||||
@staticmethod
|
||||
def write_message(sock, data: dict):
|
||||
"""
|
||||
Write one message to a socket.
|
||||
|
||||
Args:
|
||||
sock: Socket to write to
|
||||
data: Message dictionary to send
|
||||
"""
|
||||
message = ControlProtocol.encode_message(data)
|
||||
sock.sendall(message)
|
||||
|
||||
@staticmethod
|
||||
async def read_message_async(reader) -> dict:
|
||||
"""
|
||||
Read one message from an async reader.
|
||||
|
||||
Args:
|
||||
reader: asyncio StreamReader
|
||||
|
||||
Returns:
|
||||
Decoded message dictionary
|
||||
"""
|
||||
# Read length prefix
|
||||
length_data = await reader.readexactly(4)
|
||||
length = struct.unpack('>I', length_data)[0]
|
||||
|
||||
if length > ControlProtocol.MAX_MESSAGE_SIZE:
|
||||
raise ProtocolError(f"Message too large: {length}")
|
||||
|
||||
# Read payload
|
||||
payload_data = await reader.readexactly(length)
|
||||
|
||||
try:
|
||||
return json.loads(payload_data.decode('utf-8'))
|
||||
except json.JSONDecodeError as e:
|
||||
raise ProtocolError(f"Invalid JSON: {e}")
|
||||
|
||||
@staticmethod
|
||||
async def write_message_async(writer, data: dict):
|
||||
"""
|
||||
Write one message to an async writer.
|
||||
|
||||
Args:
|
||||
writer: asyncio StreamWriter
|
||||
data: Message dictionary to send
|
||||
"""
|
||||
message = ControlProtocol.encode_message(data)
|
||||
writer.write(message)
|
||||
await writer.drain()
|
||||
|
||||
|
||||
def make_request(request_id: int, command: str, args: list = None) -> dict:
|
||||
"""
|
||||
Create a request message.
|
||||
|
||||
Args:
|
||||
request_id: Unique request identifier
|
||||
command: Command name (e.g., "show workers")
|
||||
args: Optional list of arguments
|
||||
|
||||
Returns:
|
||||
Request dictionary
|
||||
"""
|
||||
return {
|
||||
"id": request_id,
|
||||
"command": command,
|
||||
"args": args or [],
|
||||
}
|
||||
|
||||
|
||||
def make_response(request_id: int, data: dict = None) -> dict:
|
||||
"""
|
||||
Create a success response message.
|
||||
|
||||
Args:
|
||||
request_id: Request identifier being responded to
|
||||
data: Response data
|
||||
|
||||
Returns:
|
||||
Response dictionary
|
||||
"""
|
||||
return {
|
||||
"id": request_id,
|
||||
"status": "ok",
|
||||
"data": data or {},
|
||||
}
|
||||
|
||||
|
||||
def make_error_response(request_id: int, error: str) -> dict:
|
||||
"""
|
||||
Create an error response message.
|
||||
|
||||
Args:
|
||||
request_id: Request identifier being responded to
|
||||
error: Error message
|
||||
|
||||
Returns:
|
||||
Error response dictionary
|
||||
"""
|
||||
return {
|
||||
"id": request_id,
|
||||
"status": "error",
|
||||
"error": error,
|
||||
}
|
||||
299
gunicorn/ctl/server.py
Normal file
299
gunicorn/ctl/server.py
Normal file
@ -0,0 +1,299 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
Control Socket Server
|
||||
|
||||
Runs in the arbiter process and accepts commands via Unix socket.
|
||||
Uses asyncio in a background thread to handle client connections.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import shlex
|
||||
import threading
|
||||
|
||||
from gunicorn.ctl.handlers import CommandHandlers
|
||||
from gunicorn.ctl.protocol import (
|
||||
ControlProtocol,
|
||||
make_response,
|
||||
make_error_response,
|
||||
)
|
||||
|
||||
|
||||
class ControlSocketServer:
|
||||
"""
|
||||
Control socket server running in arbiter process.
|
||||
|
||||
The server runs an asyncio event loop in a background thread,
|
||||
accepting connections and dispatching commands to handlers.
|
||||
"""
|
||||
|
||||
def __init__(self, arbiter, socket_path, socket_mode=0o600):
|
||||
"""
|
||||
Initialize control socket server.
|
||||
|
||||
Args:
|
||||
arbiter: The Gunicorn arbiter instance
|
||||
socket_path: Path for the Unix socket
|
||||
socket_mode: Permission mode for socket (default 0o600)
|
||||
"""
|
||||
self.arbiter = arbiter
|
||||
self.socket_path = socket_path
|
||||
self.socket_mode = socket_mode
|
||||
|
||||
self.handlers = CommandHandlers(arbiter)
|
||||
self._server = None
|
||||
self._loop = None
|
||||
self._thread = None
|
||||
self._running = False
|
||||
|
||||
def start(self):
|
||||
"""Start server in background thread with asyncio event loop."""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._thread = threading.Thread(target=self._run_loop, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def stop(self):
|
||||
"""Stop server and cleanup socket."""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
|
||||
if self._loop and self._server:
|
||||
# Schedule server close in the loop
|
||||
self._loop.call_soon_threadsafe(self._shutdown)
|
||||
|
||||
if self._thread:
|
||||
self._thread.join(timeout=2.0)
|
||||
self._thread = None
|
||||
|
||||
# Clean up socket file
|
||||
if os.path.exists(self.socket_path):
|
||||
try:
|
||||
os.unlink(self.socket_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def _shutdown(self):
|
||||
"""Shutdown server (called from event loop thread)."""
|
||||
if self._server:
|
||||
self._server.close()
|
||||
|
||||
def _run_loop(self):
|
||||
"""Run the asyncio event loop in background thread."""
|
||||
try:
|
||||
asyncio.run(self._serve())
|
||||
except Exception as e:
|
||||
if self.arbiter.log:
|
||||
self.arbiter.log.error("Control server error: %s", e)
|
||||
|
||||
async def _serve(self):
|
||||
"""Main async server loop."""
|
||||
self._loop = asyncio.get_running_loop()
|
||||
|
||||
# Remove socket if it exists
|
||||
if os.path.exists(self.socket_path):
|
||||
os.unlink(self.socket_path)
|
||||
|
||||
# Create Unix socket server
|
||||
self._server = await asyncio.start_unix_server(
|
||||
self._handle_client,
|
||||
path=self.socket_path
|
||||
)
|
||||
|
||||
# Set socket permissions
|
||||
os.chmod(self.socket_path, self.socket_mode)
|
||||
|
||||
if self.arbiter.log:
|
||||
self.arbiter.log.info("Control socket listening at %s",
|
||||
self.socket_path)
|
||||
|
||||
try:
|
||||
async with self._server:
|
||||
await self._server.serve_forever()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
if os.path.exists(self.socket_path):
|
||||
try:
|
||||
os.unlink(self.socket_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
async def _handle_client(self, reader, writer):
|
||||
"""
|
||||
Handle client connection.
|
||||
|
||||
Args:
|
||||
reader: asyncio StreamReader
|
||||
writer: asyncio StreamWriter
|
||||
"""
|
||||
try:
|
||||
while self._running:
|
||||
try:
|
||||
message = await asyncio.wait_for(
|
||||
ControlProtocol.read_message_async(reader),
|
||||
timeout=300.0 # 5 minute idle timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
# Client idle too long, close connection
|
||||
break
|
||||
except asyncio.IncompleteReadError:
|
||||
# Client disconnected
|
||||
break
|
||||
except Exception:
|
||||
# Protocol error
|
||||
break
|
||||
|
||||
# Process command
|
||||
response = await self._dispatch(message)
|
||||
|
||||
# Send response
|
||||
await ControlProtocol.write_message_async(writer, response)
|
||||
|
||||
except Exception as e:
|
||||
if self.arbiter.log:
|
||||
self.arbiter.log.debug("Control client error: %s", e)
|
||||
finally:
|
||||
writer.close()
|
||||
try:
|
||||
await writer.wait_closed()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _dispatch(self, message: dict) -> dict:
|
||||
"""
|
||||
Dispatch command to appropriate handler.
|
||||
|
||||
Args:
|
||||
message: Request message dict
|
||||
|
||||
Returns:
|
||||
Response dictionary
|
||||
"""
|
||||
request_id = message.get("id", 0)
|
||||
command = message.get("command", "").strip()
|
||||
args = message.get("args", [])
|
||||
|
||||
if not command:
|
||||
return make_error_response(request_id, "Empty command")
|
||||
|
||||
try:
|
||||
# Parse command (e.g., "show workers" or "worker add 2")
|
||||
parts = shlex.split(command)
|
||||
if args:
|
||||
parts.extend(str(a) for a in args)
|
||||
|
||||
if not parts:
|
||||
return make_error_response(request_id, "Empty command")
|
||||
|
||||
# Route to handler
|
||||
result = self._execute_command(parts)
|
||||
return make_response(request_id, result)
|
||||
|
||||
except ValueError as e:
|
||||
return make_error_response(request_id, f"Invalid argument: {e}")
|
||||
except Exception as e:
|
||||
if self.arbiter.log:
|
||||
self.arbiter.log.exception("Command error")
|
||||
return make_error_response(request_id, f"Command failed: {e}")
|
||||
|
||||
def _execute_command(self, parts: list) -> dict:
|
||||
"""
|
||||
Execute a parsed command.
|
||||
|
||||
Args:
|
||||
parts: Command parts (e.g., ["show", "workers"])
|
||||
|
||||
Returns:
|
||||
Handler result dictionary
|
||||
"""
|
||||
if not parts:
|
||||
raise ValueError("Empty command")
|
||||
|
||||
cmd = parts[0].lower()
|
||||
rest = parts[1:]
|
||||
|
||||
# Map commands to handlers
|
||||
if cmd == "show":
|
||||
return self._handle_show(rest)
|
||||
elif cmd == "worker":
|
||||
return self._handle_worker(rest)
|
||||
elif cmd == "dirty":
|
||||
return self._handle_dirty(rest)
|
||||
elif cmd == "reload":
|
||||
return self.handlers.reload()
|
||||
elif cmd == "reopen":
|
||||
return self.handlers.reopen()
|
||||
elif cmd == "shutdown":
|
||||
mode = rest[0] if rest else "graceful"
|
||||
return self.handlers.shutdown(mode)
|
||||
elif cmd == "help":
|
||||
return self.handlers.help()
|
||||
else:
|
||||
raise ValueError(f"Unknown command: {cmd}")
|
||||
|
||||
def _handle_show(self, args: list) -> dict:
|
||||
"""Handle 'show' commands."""
|
||||
if not args:
|
||||
raise ValueError("Missing show target (workers|dirty|config|stats|listeners)")
|
||||
|
||||
target = args[0].lower()
|
||||
|
||||
if target == "workers":
|
||||
return self.handlers.show_workers()
|
||||
elif target == "dirty":
|
||||
return self.handlers.show_dirty()
|
||||
elif target == "config":
|
||||
return self.handlers.show_config()
|
||||
elif target == "stats":
|
||||
return self.handlers.show_stats()
|
||||
elif target == "listeners":
|
||||
return self.handlers.show_listeners()
|
||||
else:
|
||||
raise ValueError(f"Unknown show target: {target}")
|
||||
|
||||
def _handle_worker(self, args: list) -> dict:
|
||||
"""Handle 'worker' commands."""
|
||||
if not args:
|
||||
raise ValueError("Missing worker action (add|remove|kill)")
|
||||
|
||||
action = args[0].lower()
|
||||
action_args = args[1:]
|
||||
|
||||
if action == "add":
|
||||
count = int(action_args[0]) if action_args else 1
|
||||
return self.handlers.worker_add(count)
|
||||
elif action == "remove":
|
||||
count = int(action_args[0]) if action_args else 1
|
||||
return self.handlers.worker_remove(count)
|
||||
elif action == "kill":
|
||||
if not action_args:
|
||||
raise ValueError("Missing PID for worker kill")
|
||||
pid = int(action_args[0])
|
||||
return self.handlers.worker_kill(pid)
|
||||
else:
|
||||
raise ValueError(f"Unknown worker action: {action}")
|
||||
|
||||
def _handle_dirty(self, args: list) -> dict:
|
||||
"""Handle 'dirty' commands."""
|
||||
if not args:
|
||||
raise ValueError("Missing dirty action (add|remove)")
|
||||
|
||||
action = args[0].lower()
|
||||
action_args = args[1:]
|
||||
|
||||
if action == "add":
|
||||
count = int(action_args[0]) if action_args else 1
|
||||
return self.handlers.dirty_add(count)
|
||||
elif action == "remove":
|
||||
count = int(action_args[0]) if action_args else 1
|
||||
return self.handlers.dirty_remove(count)
|
||||
else:
|
||||
raise ValueError(f"Unknown dirty action: {action}")
|
||||
@ -367,7 +367,8 @@ class DirtyArbiter:
|
||||
try:
|
||||
async with self._server:
|
||||
await self._server.serve_forever()
|
||||
except asyncio.CancelledError:
|
||||
except (asyncio.CancelledError, RuntimeError):
|
||||
# RuntimeError raised when server.close() is called during serve_forever()
|
||||
pass
|
||||
finally:
|
||||
monitor_task.cancel()
|
||||
@ -836,19 +837,19 @@ class DirtyArbiter:
|
||||
pid, app_paths)
|
||||
return pid
|
||||
|
||||
# Child process
|
||||
# Child process - use os._exit() to avoid asyncio cleanup issues
|
||||
worker.pid = os.getpid()
|
||||
try:
|
||||
util._setproctitle(f"dirty-worker [{self.cfg.proc_name}]")
|
||||
worker.init_process()
|
||||
sys.exit(0)
|
||||
except SystemExit:
|
||||
raise
|
||||
os._exit(0)
|
||||
except SystemExit as e:
|
||||
os._exit(e.code if e.code is not None else 0)
|
||||
except Exception:
|
||||
self.log.exception("Exception in dirty worker process")
|
||||
if not worker.booted:
|
||||
sys.exit(self.WORKER_BOOT_ERROR)
|
||||
sys.exit(-1)
|
||||
os._exit(self.WORKER_BOOT_ERROR)
|
||||
os._exit(1)
|
||||
|
||||
def kill_worker(self, pid, sig):
|
||||
"""Kill a worker by PID."""
|
||||
|
||||
@ -68,6 +68,7 @@ testing = [
|
||||
[project.scripts]
|
||||
# duplicates "python -m gunicorn" handling in __main__.py
|
||||
gunicorn = "gunicorn.app.wsgiapp:run"
|
||||
gunicornc = "gunicorn.ctl.cli:main"
|
||||
|
||||
# note the quotes around "paste.server_runner" to escape the dot
|
||||
[project.entry-points."paste.server_runner"]
|
||||
|
||||
3
tests/ctl/__init__.py
Normal file
3
tests/ctl/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
275
tests/ctl/test_client.py
Normal file
275
tests/ctl/test_client.py
Normal file
@ -0,0 +1,275 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""Tests for control socket client."""
|
||||
|
||||
import os
|
||||
import socket
|
||||
import tempfile
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
|
||||
from gunicorn.ctl.client import (
|
||||
ControlClient,
|
||||
ControlClientError,
|
||||
parse_command,
|
||||
)
|
||||
from gunicorn.ctl.protocol import ControlProtocol, make_response
|
||||
|
||||
|
||||
class TestControlClientInit:
|
||||
"""Tests for ControlClient initialization."""
|
||||
|
||||
def test_init_attributes(self):
|
||||
"""Test that client is initialized with correct attributes."""
|
||||
client = ControlClient("/tmp/test.sock", timeout=60.0)
|
||||
|
||||
assert client.socket_path == "/tmp/test.sock"
|
||||
assert client.timeout == 60.0
|
||||
assert client._sock is None
|
||||
assert client._request_id == 0
|
||||
|
||||
|
||||
class TestControlClientConnect:
|
||||
"""Tests for ControlClient connection."""
|
||||
|
||||
def test_connect_nonexistent_socket(self):
|
||||
"""Test connecting to non-existent socket."""
|
||||
client = ControlClient("/nonexistent/socket.sock")
|
||||
|
||||
with pytest.raises(ControlClientError) as exc_info:
|
||||
client.connect()
|
||||
|
||||
assert "Failed to connect" in str(exc_info.value)
|
||||
|
||||
def test_connect_success(self):
|
||||
"""Test successful connection."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
socket_path = os.path.join(tmpdir, "test.sock")
|
||||
|
||||
# Create a listening socket
|
||||
server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
server_sock.bind(socket_path)
|
||||
server_sock.listen(1)
|
||||
|
||||
try:
|
||||
client = ControlClient(socket_path)
|
||||
client.connect()
|
||||
|
||||
assert client._sock is not None
|
||||
client.close()
|
||||
finally:
|
||||
server_sock.close()
|
||||
|
||||
def test_connect_already_connected(self):
|
||||
"""Test that connect is idempotent."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
socket_path = os.path.join(tmpdir, "test.sock")
|
||||
|
||||
server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
server_sock.bind(socket_path)
|
||||
server_sock.listen(1)
|
||||
|
||||
try:
|
||||
client = ControlClient(socket_path)
|
||||
client.connect()
|
||||
first_sock = client._sock
|
||||
client.connect() # Should not create new connection
|
||||
|
||||
assert client._sock is first_sock
|
||||
client.close()
|
||||
finally:
|
||||
server_sock.close()
|
||||
|
||||
|
||||
class TestControlClientClose:
|
||||
"""Tests for ControlClient close."""
|
||||
|
||||
def test_close_idempotent(self):
|
||||
"""Test that close can be called multiple times."""
|
||||
client = ControlClient("/tmp/test.sock")
|
||||
client.close()
|
||||
client.close() # Should not raise
|
||||
|
||||
def test_close_clears_socket(self):
|
||||
"""Test that close clears the socket."""
|
||||
client = ControlClient("/tmp/test.sock")
|
||||
client._sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
client.close()
|
||||
|
||||
assert client._sock is None
|
||||
|
||||
|
||||
class TestControlClientContextManager:
|
||||
"""Tests for context manager functionality."""
|
||||
|
||||
def test_context_manager_connection_error(self):
|
||||
"""Test context manager with connection error."""
|
||||
client = ControlClient("/nonexistent/socket.sock")
|
||||
|
||||
with pytest.raises(ControlClientError):
|
||||
with client:
|
||||
pass
|
||||
|
||||
def test_context_manager_success(self):
|
||||
"""Test successful context manager usage."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
socket_path = os.path.join(tmpdir, "test.sock")
|
||||
|
||||
server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
server_sock.bind(socket_path)
|
||||
server_sock.listen(1)
|
||||
|
||||
try:
|
||||
with ControlClient(socket_path) as client:
|
||||
assert client._sock is not None
|
||||
|
||||
# After context manager exits, socket should be closed
|
||||
assert client._sock is None
|
||||
finally:
|
||||
server_sock.close()
|
||||
|
||||
|
||||
class TestControlClientSendCommand:
|
||||
"""Tests for send_command functionality."""
|
||||
|
||||
def test_send_command_success(self):
|
||||
"""Test successful command send."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
socket_path = os.path.join(tmpdir, "test.sock")
|
||||
|
||||
server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
server_sock.bind(socket_path)
|
||||
server_sock.listen(1)
|
||||
|
||||
response_data = {"workers": [], "count": 0}
|
||||
response_sent = threading.Event()
|
||||
|
||||
def server_handler():
|
||||
conn, _ = server_sock.accept()
|
||||
try:
|
||||
msg = ControlProtocol.read_message(conn)
|
||||
resp = make_response(msg["id"], response_data)
|
||||
ControlProtocol.write_message(conn, resp)
|
||||
response_sent.set()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
server_thread = threading.Thread(target=server_handler)
|
||||
server_thread.start()
|
||||
|
||||
try:
|
||||
client = ControlClient(socket_path, timeout=5.0)
|
||||
result = client.send_command("show workers")
|
||||
|
||||
assert result == response_data
|
||||
client.close()
|
||||
finally:
|
||||
response_sent.wait(timeout=2.0)
|
||||
server_thread.join(timeout=2.0)
|
||||
server_sock.close()
|
||||
|
||||
def test_send_command_error_response(self):
|
||||
"""Test handling error response."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
socket_path = os.path.join(tmpdir, "test.sock")
|
||||
|
||||
server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
server_sock.bind(socket_path)
|
||||
server_sock.listen(1)
|
||||
|
||||
def server_handler():
|
||||
conn, _ = server_sock.accept()
|
||||
try:
|
||||
msg = ControlProtocol.read_message(conn)
|
||||
resp = {
|
||||
"id": msg["id"],
|
||||
"status": "error",
|
||||
"error": "Unknown command",
|
||||
}
|
||||
ControlProtocol.write_message(conn, resp)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
server_thread = threading.Thread(target=server_handler)
|
||||
server_thread.start()
|
||||
|
||||
try:
|
||||
client = ControlClient(socket_path, timeout=5.0)
|
||||
|
||||
with pytest.raises(ControlClientError) as exc_info:
|
||||
client.send_command("invalid command")
|
||||
|
||||
assert "Unknown command" in str(exc_info.value)
|
||||
client.close()
|
||||
finally:
|
||||
server_thread.join(timeout=2.0)
|
||||
server_sock.close()
|
||||
|
||||
def test_send_command_auto_connect(self):
|
||||
"""Test that send_command auto-connects if not connected."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
socket_path = os.path.join(tmpdir, "test.sock")
|
||||
|
||||
server_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
server_sock.bind(socket_path)
|
||||
server_sock.listen(1)
|
||||
|
||||
def server_handler():
|
||||
conn, _ = server_sock.accept()
|
||||
try:
|
||||
msg = ControlProtocol.read_message(conn)
|
||||
resp = make_response(msg["id"], {})
|
||||
ControlProtocol.write_message(conn, resp)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
server_thread = threading.Thread(target=server_handler)
|
||||
server_thread.start()
|
||||
|
||||
try:
|
||||
client = ControlClient(socket_path, timeout=5.0)
|
||||
# Don't call connect() explicitly
|
||||
result = client.send_command("help")
|
||||
|
||||
assert isinstance(result, dict)
|
||||
client.close()
|
||||
finally:
|
||||
server_thread.join(timeout=2.0)
|
||||
server_sock.close()
|
||||
|
||||
|
||||
class TestParseCommand:
|
||||
"""Tests for command parsing."""
|
||||
|
||||
def test_parse_simple_command(self):
|
||||
"""Test parsing simple command."""
|
||||
cmd, args = parse_command("show workers")
|
||||
assert cmd == "show workers"
|
||||
assert args == []
|
||||
|
||||
def test_parse_command_with_args(self):
|
||||
"""Test parsing command with arguments."""
|
||||
cmd, args = parse_command("worker add 2")
|
||||
assert cmd == "worker add"
|
||||
assert args == ["2"]
|
||||
|
||||
def test_parse_command_with_multiple_args(self):
|
||||
"""Test parsing command with multiple arguments."""
|
||||
cmd, args = parse_command("worker kill 12345")
|
||||
assert cmd == "worker kill"
|
||||
assert args == ["12345"]
|
||||
|
||||
def test_parse_empty_command(self):
|
||||
"""Test parsing empty command."""
|
||||
cmd, args = parse_command("")
|
||||
assert cmd == ""
|
||||
assert args == []
|
||||
|
||||
def test_parse_command_quoted(self):
|
||||
"""Test parsing command with quoted arguments."""
|
||||
cmd, args = parse_command('worker kill "12345"')
|
||||
assert cmd == "worker kill"
|
||||
assert args == ["12345"]
|
||||
374
tests/ctl/test_handlers.py
Normal file
374
tests/ctl/test_handlers.py
Normal file
@ -0,0 +1,374 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""Tests for control socket command handlers."""
|
||||
|
||||
import signal
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gunicorn.ctl.handlers import CommandHandlers
|
||||
|
||||
|
||||
class MockWorker:
|
||||
"""Mock worker for testing."""
|
||||
|
||||
def __init__(self, pid, age, booted=True, aborted=False):
|
||||
self.pid = pid
|
||||
self.age = age
|
||||
self.booted = booted
|
||||
self.aborted = aborted
|
||||
self.tmp = MagicMock()
|
||||
self.tmp.last_update.return_value = time.monotonic()
|
||||
|
||||
|
||||
class MockListener:
|
||||
"""Mock listener for testing."""
|
||||
|
||||
def __init__(self, address, fd=3):
|
||||
self._address = address
|
||||
self._fd = fd
|
||||
self.sock = MagicMock()
|
||||
self.sock.family = 2 # AF_INET
|
||||
|
||||
def __str__(self):
|
||||
return self._address
|
||||
|
||||
def fileno(self):
|
||||
return self._fd
|
||||
|
||||
|
||||
class MockConfig:
|
||||
"""Mock config for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self.bind = ['127.0.0.1:8000']
|
||||
self.workers = 4
|
||||
self.worker_class = 'sync'
|
||||
self.threads = 1
|
||||
self.timeout = 30
|
||||
self.graceful_timeout = 30
|
||||
self.keepalive = 2
|
||||
self.max_requests = 0
|
||||
self.max_requests_jitter = 0
|
||||
self.worker_connections = 1000
|
||||
self.preload_app = False
|
||||
self.daemon = False
|
||||
self.pidfile = None
|
||||
self.proc_name = 'test_app'
|
||||
self.reload = False
|
||||
self.dirty_workers = 0
|
||||
self.dirty_apps = []
|
||||
self.dirty_timeout = 30
|
||||
self.control_socket = 'gunicorn.ctl'
|
||||
self.control_socket_disable = False
|
||||
|
||||
|
||||
class MockArbiter:
|
||||
"""Mock arbiter for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self.cfg = MockConfig()
|
||||
self.pid = 12345
|
||||
self.WORKERS = {}
|
||||
self.LISTENERS = []
|
||||
self.dirty_arbiter_pid = 0
|
||||
self.dirty_arbiter = None
|
||||
self.num_workers = 4
|
||||
self._stats = {
|
||||
'start_time': time.time() - 3600, # 1 hour ago
|
||||
'workers_spawned': 10,
|
||||
'workers_killed': 5,
|
||||
'reloads': 2,
|
||||
}
|
||||
|
||||
def wakeup(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestShowWorkers:
|
||||
"""Tests for show workers command."""
|
||||
|
||||
def test_show_workers_empty(self):
|
||||
"""Test showing workers when none exist."""
|
||||
arbiter = MockArbiter()
|
||||
handlers = CommandHandlers(arbiter)
|
||||
|
||||
result = handlers.show_workers()
|
||||
|
||||
assert result["workers"] == []
|
||||
assert result["count"] == 0
|
||||
|
||||
def test_show_workers_with_workers(self):
|
||||
"""Test showing workers."""
|
||||
arbiter = MockArbiter()
|
||||
arbiter.WORKERS = {
|
||||
1001: MockWorker(1001, 1),
|
||||
1002: MockWorker(1002, 2),
|
||||
1003: MockWorker(1003, 3),
|
||||
}
|
||||
handlers = CommandHandlers(arbiter)
|
||||
|
||||
result = handlers.show_workers()
|
||||
|
||||
assert result["count"] == 3
|
||||
assert len(result["workers"]) == 3
|
||||
|
||||
# Verify sorted by age
|
||||
ages = [w["age"] for w in result["workers"]]
|
||||
assert ages == sorted(ages)
|
||||
|
||||
# Verify worker data
|
||||
worker = result["workers"][0]
|
||||
assert "pid" in worker
|
||||
assert "age" in worker
|
||||
assert "booted" in worker
|
||||
assert "last_heartbeat" in worker
|
||||
|
||||
|
||||
class TestShowStats:
|
||||
"""Tests for show stats command."""
|
||||
|
||||
def test_show_stats(self):
|
||||
"""Test showing stats."""
|
||||
arbiter = MockArbiter()
|
||||
arbiter.WORKERS = {
|
||||
1001: MockWorker(1001, 1),
|
||||
1002: MockWorker(1002, 2),
|
||||
}
|
||||
handlers = CommandHandlers(arbiter)
|
||||
|
||||
result = handlers.show_stats()
|
||||
|
||||
assert result["pid"] == 12345
|
||||
assert result["workers_current"] == 2
|
||||
assert result["workers_target"] == 4
|
||||
assert result["workers_spawned"] == 10
|
||||
assert result["workers_killed"] == 5
|
||||
assert result["reloads"] == 2
|
||||
assert result["uptime"] is not None
|
||||
assert result["uptime"] > 0
|
||||
|
||||
|
||||
class TestShowConfig:
|
||||
"""Tests for show config command."""
|
||||
|
||||
def test_show_config(self):
|
||||
"""Test showing config."""
|
||||
arbiter = MockArbiter()
|
||||
handlers = CommandHandlers(arbiter)
|
||||
|
||||
result = handlers.show_config()
|
||||
|
||||
assert result["workers"] == 4
|
||||
assert result["timeout"] == 30
|
||||
assert result["bind"] == ['127.0.0.1:8000']
|
||||
|
||||
|
||||
class TestShowListeners:
|
||||
"""Tests for show listeners command."""
|
||||
|
||||
def test_show_listeners_empty(self):
|
||||
"""Test showing listeners when none exist."""
|
||||
arbiter = MockArbiter()
|
||||
handlers = CommandHandlers(arbiter)
|
||||
|
||||
result = handlers.show_listeners()
|
||||
|
||||
assert result["listeners"] == []
|
||||
assert result["count"] == 0
|
||||
|
||||
def test_show_listeners(self):
|
||||
"""Test showing listeners."""
|
||||
arbiter = MockArbiter()
|
||||
arbiter.LISTENERS = [
|
||||
MockListener("127.0.0.1:8000", fd=3),
|
||||
MockListener("127.0.0.1:8001", fd=4),
|
||||
]
|
||||
handlers = CommandHandlers(arbiter)
|
||||
|
||||
result = handlers.show_listeners()
|
||||
|
||||
assert result["count"] == 2
|
||||
assert len(result["listeners"]) == 2
|
||||
assert result["listeners"][0]["address"] == "127.0.0.1:8000"
|
||||
|
||||
|
||||
class TestWorkerAdd:
|
||||
"""Tests for worker add command."""
|
||||
|
||||
def test_worker_add_default(self):
|
||||
"""Test adding one worker (default)."""
|
||||
arbiter = MockArbiter()
|
||||
arbiter.wakeup = MagicMock()
|
||||
handlers = CommandHandlers(arbiter)
|
||||
|
||||
result = handlers.worker_add()
|
||||
|
||||
assert result["added"] == 1
|
||||
assert result["previous"] == 4
|
||||
assert result["total"] == 5
|
||||
assert arbiter.num_workers == 5
|
||||
arbiter.wakeup.assert_called_once()
|
||||
|
||||
def test_worker_add_multiple(self):
|
||||
"""Test adding multiple workers."""
|
||||
arbiter = MockArbiter()
|
||||
arbiter.wakeup = MagicMock()
|
||||
handlers = CommandHandlers(arbiter)
|
||||
|
||||
result = handlers.worker_add(3)
|
||||
|
||||
assert result["added"] == 3
|
||||
assert result["total"] == 7
|
||||
|
||||
|
||||
class TestWorkerRemove:
|
||||
"""Tests for worker remove command."""
|
||||
|
||||
def test_worker_remove_default(self):
|
||||
"""Test removing one worker (default)."""
|
||||
arbiter = MockArbiter()
|
||||
arbiter.wakeup = MagicMock()
|
||||
handlers = CommandHandlers(arbiter)
|
||||
|
||||
result = handlers.worker_remove()
|
||||
|
||||
assert result["removed"] == 1
|
||||
assert result["previous"] == 4
|
||||
assert result["total"] == 3
|
||||
assert arbiter.num_workers == 3
|
||||
arbiter.wakeup.assert_called_once()
|
||||
|
||||
def test_worker_remove_cannot_go_below_one(self):
|
||||
"""Test that worker count cannot go below 1."""
|
||||
arbiter = MockArbiter()
|
||||
arbiter.num_workers = 2
|
||||
arbiter.wakeup = MagicMock()
|
||||
handlers = CommandHandlers(arbiter)
|
||||
|
||||
result = handlers.worker_remove(5)
|
||||
|
||||
assert result["removed"] == 1
|
||||
assert result["total"] == 1
|
||||
assert arbiter.num_workers == 1
|
||||
|
||||
|
||||
class TestWorkerKill:
|
||||
"""Tests for worker kill command."""
|
||||
|
||||
def test_worker_kill_success(self):
|
||||
"""Test killing a worker."""
|
||||
arbiter = MockArbiter()
|
||||
arbiter.WORKERS = {1001: MockWorker(1001, 1)}
|
||||
handlers = CommandHandlers(arbiter)
|
||||
|
||||
with patch('os.kill') as mock_kill:
|
||||
result = handlers.worker_kill(1001)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["killed"] == 1001
|
||||
mock_kill.assert_called_once_with(1001, signal.SIGTERM)
|
||||
|
||||
def test_worker_kill_not_found(self):
|
||||
"""Test killing a non-existent worker."""
|
||||
arbiter = MockArbiter()
|
||||
handlers = CommandHandlers(arbiter)
|
||||
|
||||
result = handlers.worker_kill(9999)
|
||||
|
||||
assert result["success"] is False
|
||||
assert "not found" in result["error"]
|
||||
|
||||
|
||||
class TestShowDirty:
|
||||
"""Tests for show dirty command."""
|
||||
|
||||
def test_show_dirty_disabled(self):
|
||||
"""Test showing dirty when disabled."""
|
||||
arbiter = MockArbiter()
|
||||
handlers = CommandHandlers(arbiter)
|
||||
|
||||
result = handlers.show_dirty()
|
||||
|
||||
assert result["enabled"] is False
|
||||
assert result["pid"] is None
|
||||
|
||||
|
||||
class TestReload:
|
||||
"""Tests for reload command."""
|
||||
|
||||
def test_reload(self):
|
||||
"""Test reload command."""
|
||||
arbiter = MockArbiter()
|
||||
handlers = CommandHandlers(arbiter)
|
||||
|
||||
with patch('os.kill') as mock_kill:
|
||||
result = handlers.reload()
|
||||
|
||||
assert result["status"] == "reloading"
|
||||
mock_kill.assert_called_once_with(12345, signal.SIGHUP)
|
||||
|
||||
|
||||
class TestReopen:
|
||||
"""Tests for reopen command."""
|
||||
|
||||
def test_reopen(self):
|
||||
"""Test reopen command."""
|
||||
arbiter = MockArbiter()
|
||||
handlers = CommandHandlers(arbiter)
|
||||
|
||||
with patch('os.kill') as mock_kill:
|
||||
result = handlers.reopen()
|
||||
|
||||
assert result["status"] == "reopening"
|
||||
mock_kill.assert_called_once_with(12345, signal.SIGUSR1)
|
||||
|
||||
|
||||
class TestShutdown:
|
||||
"""Tests for shutdown command."""
|
||||
|
||||
def test_shutdown_graceful(self):
|
||||
"""Test graceful shutdown."""
|
||||
arbiter = MockArbiter()
|
||||
handlers = CommandHandlers(arbiter)
|
||||
|
||||
with patch('os.kill') as mock_kill:
|
||||
result = handlers.shutdown()
|
||||
|
||||
assert result["status"] == "shutting_down"
|
||||
assert result["mode"] == "graceful"
|
||||
mock_kill.assert_called_once_with(12345, signal.SIGTERM)
|
||||
|
||||
def test_shutdown_quick(self):
|
||||
"""Test quick shutdown."""
|
||||
arbiter = MockArbiter()
|
||||
handlers = CommandHandlers(arbiter)
|
||||
|
||||
with patch('os.kill') as mock_kill:
|
||||
result = handlers.shutdown("quick")
|
||||
|
||||
assert result["status"] == "shutting_down"
|
||||
assert result["mode"] == "quick"
|
||||
mock_kill.assert_called_once_with(12345, signal.SIGINT)
|
||||
|
||||
|
||||
class TestHelp:
|
||||
"""Tests for help command."""
|
||||
|
||||
def test_help(self):
|
||||
"""Test help command."""
|
||||
arbiter = MockArbiter()
|
||||
handlers = CommandHandlers(arbiter)
|
||||
|
||||
result = handlers.help()
|
||||
|
||||
assert "commands" in result
|
||||
commands = result["commands"]
|
||||
assert "show workers" in commands
|
||||
assert "worker add [N]" in commands
|
||||
assert "reload" in commands
|
||||
assert "shutdown [graceful|quick]" in commands
|
||||
249
tests/ctl/test_protocol.py
Normal file
249
tests/ctl/test_protocol.py
Normal file
@ -0,0 +1,249 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""Tests for control socket protocol."""
|
||||
|
||||
import json
|
||||
import struct
|
||||
import pytest
|
||||
|
||||
from gunicorn.ctl.protocol import (
|
||||
ControlProtocol,
|
||||
ProtocolError,
|
||||
make_request,
|
||||
make_response,
|
||||
make_error_response,
|
||||
)
|
||||
|
||||
|
||||
class TestControlProtocolEncoding:
|
||||
"""Tests for message encoding/decoding."""
|
||||
|
||||
def test_encode_message_simple(self):
|
||||
"""Test encoding a simple message."""
|
||||
data = {"command": "test"}
|
||||
result = ControlProtocol.encode_message(data)
|
||||
|
||||
# First 4 bytes are length
|
||||
length = struct.unpack('>I', result[:4])[0]
|
||||
payload = result[4:]
|
||||
|
||||
assert length == len(payload)
|
||||
assert json.loads(payload.decode('utf-8')) == data
|
||||
|
||||
def test_encode_message_unicode(self):
|
||||
"""Test encoding message with unicode characters."""
|
||||
data = {"message": "Hello \u4e16\u754c"}
|
||||
result = ControlProtocol.encode_message(data)
|
||||
|
||||
length = struct.unpack('>I', result[:4])[0]
|
||||
payload = result[4:]
|
||||
|
||||
assert length == len(payload)
|
||||
assert json.loads(payload.decode('utf-8')) == data
|
||||
|
||||
def test_decode_message_simple(self):
|
||||
"""Test decoding a simple message."""
|
||||
data = {"command": "test", "args": [1, 2, 3]}
|
||||
payload = json.dumps(data).encode('utf-8')
|
||||
length = struct.pack('>I', len(payload))
|
||||
raw = length + payload
|
||||
|
||||
result = ControlProtocol.decode_message(raw)
|
||||
assert result == data
|
||||
|
||||
def test_decode_message_too_short(self):
|
||||
"""Test decoding message that's too short."""
|
||||
with pytest.raises(ProtocolError) as exc_info:
|
||||
ControlProtocol.decode_message(b'\x00\x00')
|
||||
assert "too short" in str(exc_info.value)
|
||||
|
||||
def test_decode_message_incomplete(self):
|
||||
"""Test decoding incomplete message."""
|
||||
# Length says 100 bytes but only 4 bytes provided
|
||||
raw = struct.pack('>I', 100) + b'test'
|
||||
with pytest.raises(ProtocolError) as exc_info:
|
||||
ControlProtocol.decode_message(raw)
|
||||
assert "Incomplete" in str(exc_info.value)
|
||||
|
||||
def test_roundtrip(self):
|
||||
"""Test encode/decode roundtrip."""
|
||||
original = {
|
||||
"id": 42,
|
||||
"command": "show workers",
|
||||
"args": ["arg1", 123, True, None],
|
||||
"nested": {"a": 1, "b": [1, 2, 3]},
|
||||
}
|
||||
|
||||
encoded = ControlProtocol.encode_message(original)
|
||||
decoded = ControlProtocol.decode_message(encoded)
|
||||
|
||||
assert decoded == original
|
||||
|
||||
|
||||
class TestMakeRequest:
|
||||
"""Tests for request creation."""
|
||||
|
||||
def test_make_request_simple(self):
|
||||
"""Test creating a simple request."""
|
||||
result = make_request(1, "show workers")
|
||||
|
||||
assert result["id"] == 1
|
||||
assert result["command"] == "show workers"
|
||||
assert result["args"] == []
|
||||
|
||||
def test_make_request_with_args(self):
|
||||
"""Test creating a request with arguments."""
|
||||
result = make_request(42, "worker add", [2])
|
||||
|
||||
assert result["id"] == 42
|
||||
assert result["command"] == "worker add"
|
||||
assert result["args"] == [2]
|
||||
|
||||
|
||||
class TestMakeResponse:
|
||||
"""Tests for response creation."""
|
||||
|
||||
def test_make_response_simple(self):
|
||||
"""Test creating a simple response."""
|
||||
result = make_response(1, {"count": 5})
|
||||
|
||||
assert result["id"] == 1
|
||||
assert result["status"] == "ok"
|
||||
assert result["data"] == {"count": 5}
|
||||
|
||||
def test_make_response_empty_data(self):
|
||||
"""Test creating response with no data."""
|
||||
result = make_response(1)
|
||||
|
||||
assert result["id"] == 1
|
||||
assert result["status"] == "ok"
|
||||
assert result["data"] == {}
|
||||
|
||||
|
||||
class TestMakeErrorResponse:
|
||||
"""Tests for error response creation."""
|
||||
|
||||
def test_make_error_response(self):
|
||||
"""Test creating an error response."""
|
||||
result = make_error_response(1, "Unknown command")
|
||||
|
||||
assert result["id"] == 1
|
||||
assert result["status"] == "error"
|
||||
assert result["error"] == "Unknown command"
|
||||
|
||||
|
||||
class TestControlProtocolSocket:
|
||||
"""Tests for socket reading/writing."""
|
||||
|
||||
def test_read_write_message(self):
|
||||
"""Test read/write through socket pair."""
|
||||
import socket
|
||||
import threading
|
||||
|
||||
data = {"id": 1, "command": "test"}
|
||||
received = []
|
||||
|
||||
# Create socket pair
|
||||
server, client = socket.socketpair()
|
||||
|
||||
def reader():
|
||||
received.append(ControlProtocol.read_message(server))
|
||||
|
||||
t = threading.Thread(target=reader)
|
||||
t.start()
|
||||
|
||||
ControlProtocol.write_message(client, data)
|
||||
t.join(timeout=2.0)
|
||||
|
||||
client.close()
|
||||
server.close()
|
||||
|
||||
assert len(received) == 1
|
||||
assert received[0] == data
|
||||
|
||||
def test_read_connection_closed(self):
|
||||
"""Test reading from closed connection."""
|
||||
import socket
|
||||
|
||||
server, client = socket.socketpair()
|
||||
client.close()
|
||||
|
||||
with pytest.raises(ConnectionError):
|
||||
ControlProtocol.read_message(server)
|
||||
|
||||
server.close()
|
||||
|
||||
def test_read_message_too_large(self):
|
||||
"""Test reading message exceeding max size."""
|
||||
import socket
|
||||
|
||||
server, client = socket.socketpair()
|
||||
|
||||
# Send a length that exceeds MAX_MESSAGE_SIZE
|
||||
huge_length = ControlProtocol.MAX_MESSAGE_SIZE + 1
|
||||
client.send(struct.pack('>I', huge_length))
|
||||
|
||||
with pytest.raises(ProtocolError) as exc_info:
|
||||
ControlProtocol.read_message(server)
|
||||
assert "too large" in str(exc_info.value)
|
||||
|
||||
client.close()
|
||||
server.close()
|
||||
|
||||
|
||||
class TestControlProtocolAsync:
|
||||
"""Tests for async protocol methods."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_read_write(self):
|
||||
"""Test async read/write using a unix server."""
|
||||
import asyncio
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
data = {"id": 1, "command": "async test"}
|
||||
received = []
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
socket_path = os.path.join(tmpdir, "test.sock")
|
||||
|
||||
async def handler(reader, writer):
|
||||
msg = await ControlProtocol.read_message_async(reader)
|
||||
received.append(msg)
|
||||
await ControlProtocol.write_message_async(writer, data)
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
|
||||
server = await asyncio.start_unix_server(handler, path=socket_path)
|
||||
|
||||
async with server:
|
||||
reader, writer = await asyncio.open_unix_connection(socket_path)
|
||||
await ControlProtocol.write_message_async(writer, data)
|
||||
response = await ControlProtocol.read_message_async(reader)
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
|
||||
assert len(received) == 1
|
||||
assert received[0] == data
|
||||
assert response == data
|
||||
|
||||
|
||||
class TestProtocolMaxSize:
|
||||
"""Tests for protocol size limits."""
|
||||
|
||||
def test_max_message_size_constant(self):
|
||||
"""Test that MAX_MESSAGE_SIZE is set to a reasonable value."""
|
||||
# Should be 16 MB
|
||||
assert ControlProtocol.MAX_MESSAGE_SIZE == 16 * 1024 * 1024
|
||||
|
||||
def test_encode_large_message(self):
|
||||
"""Test encoding a large (but valid) message."""
|
||||
# Create a message with ~1MB of data
|
||||
data = {"data": "x" * (1024 * 1024)}
|
||||
encoded = ControlProtocol.encode_message(data)
|
||||
|
||||
# Should succeed and be decodable
|
||||
decoded = ControlProtocol.decode_message(encoded)
|
||||
assert decoded == data
|
||||
348
tests/ctl/test_server.py
Normal file
348
tests/ctl/test_server.py
Normal file
@ -0,0 +1,348 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""Tests for control socket server."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gunicorn.ctl.server import ControlSocketServer
|
||||
from gunicorn.ctl.client import ControlClient
|
||||
|
||||
|
||||
class MockWorker:
|
||||
"""Mock worker for testing."""
|
||||
|
||||
def __init__(self, pid, age, booted=True, aborted=False):
|
||||
self.pid = pid
|
||||
self.age = age
|
||||
self.booted = booted
|
||||
self.aborted = aborted
|
||||
self.tmp = MagicMock()
|
||||
self.tmp.last_update.return_value = time.monotonic()
|
||||
|
||||
|
||||
class MockConfig:
|
||||
"""Mock config for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self.bind = ['127.0.0.1:8000']
|
||||
self.workers = 4
|
||||
self.worker_class = 'sync'
|
||||
self.threads = 1
|
||||
self.timeout = 30
|
||||
self.graceful_timeout = 30
|
||||
self.keepalive = 2
|
||||
self.max_requests = 0
|
||||
self.max_requests_jitter = 0
|
||||
self.worker_connections = 1000
|
||||
self.preload_app = False
|
||||
self.daemon = False
|
||||
self.pidfile = None
|
||||
self.proc_name = 'test_app'
|
||||
self.reload = False
|
||||
self.dirty_workers = 0
|
||||
self.dirty_apps = []
|
||||
self.dirty_timeout = 30
|
||||
self.control_socket = 'gunicorn.ctl'
|
||||
self.control_socket_disable = False
|
||||
|
||||
|
||||
class MockLog:
|
||||
"""Mock logger for testing."""
|
||||
|
||||
def debug(self, msg, *args):
|
||||
pass
|
||||
|
||||
def info(self, msg, *args):
|
||||
pass
|
||||
|
||||
def warning(self, msg, *args):
|
||||
pass
|
||||
|
||||
def error(self, msg, *args):
|
||||
pass
|
||||
|
||||
def exception(self, msg, *args):
|
||||
pass
|
||||
|
||||
|
||||
class MockArbiter:
|
||||
"""Mock arbiter for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self.cfg = MockConfig()
|
||||
self.log = MockLog()
|
||||
self.pid = 12345
|
||||
self.WORKERS = {}
|
||||
self.LISTENERS = []
|
||||
self.dirty_arbiter_pid = 0
|
||||
self.dirty_arbiter = None
|
||||
self.num_workers = 4
|
||||
self._stats = {
|
||||
'start_time': time.time() - 3600,
|
||||
'workers_spawned': 10,
|
||||
'workers_killed': 5,
|
||||
'reloads': 2,
|
||||
}
|
||||
|
||||
def wakeup(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestControlSocketServerInit:
|
||||
"""Tests for server initialization."""
|
||||
|
||||
def test_init(self):
|
||||
"""Test server initialization."""
|
||||
arbiter = MockArbiter()
|
||||
server = ControlSocketServer(arbiter, "/tmp/test.sock", 0o600)
|
||||
|
||||
assert server.arbiter is arbiter
|
||||
assert server.socket_path == "/tmp/test.sock"
|
||||
assert server.socket_mode == 0o600
|
||||
assert server._running is False
|
||||
|
||||
|
||||
class TestControlSocketServerLifecycle:
|
||||
"""Tests for server start/stop."""
|
||||
|
||||
def test_start_stop(self):
|
||||
"""Test starting and stopping the server."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
socket_path = os.path.join(tmpdir, "test.sock")
|
||||
|
||||
arbiter = MockArbiter()
|
||||
server = ControlSocketServer(arbiter, socket_path)
|
||||
|
||||
server.start()
|
||||
|
||||
# Wait for server to start
|
||||
for _ in range(20):
|
||||
if os.path.exists(socket_path):
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
||||
assert os.path.exists(socket_path)
|
||||
|
||||
server.stop()
|
||||
|
||||
# Wait for cleanup
|
||||
time.sleep(0.2)
|
||||
|
||||
# Socket should be cleaned up
|
||||
assert not os.path.exists(socket_path)
|
||||
|
||||
def test_start_already_running(self):
|
||||
"""Test that start is idempotent."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
socket_path = os.path.join(tmpdir, "test.sock")
|
||||
|
||||
arbiter = MockArbiter()
|
||||
server = ControlSocketServer(arbiter, socket_path)
|
||||
|
||||
server.start()
|
||||
first_thread = server._thread
|
||||
server.start()
|
||||
|
||||
assert server._thread is first_thread
|
||||
|
||||
server.stop()
|
||||
|
||||
def test_stop_not_running(self):
|
||||
"""Test stopping a non-running server."""
|
||||
arbiter = MockArbiter()
|
||||
server = ControlSocketServer(arbiter, "/tmp/test.sock")
|
||||
|
||||
# Should not raise
|
||||
server.stop()
|
||||
|
||||
|
||||
class TestControlSocketServerIntegration:
|
||||
"""Integration tests for server with client."""
|
||||
|
||||
def test_show_workers(self):
|
||||
"""Test show workers command."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
socket_path = os.path.join(tmpdir, "test.sock")
|
||||
|
||||
arbiter = MockArbiter()
|
||||
arbiter.WORKERS = {
|
||||
1001: MockWorker(1001, 1),
|
||||
1002: MockWorker(1002, 2),
|
||||
}
|
||||
server = ControlSocketServer(arbiter, socket_path)
|
||||
|
||||
server.start()
|
||||
|
||||
# Wait for server to start
|
||||
for _ in range(20):
|
||||
if os.path.exists(socket_path):
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
||||
try:
|
||||
with ControlClient(socket_path, timeout=5.0) as client:
|
||||
result = client.send_command("show workers")
|
||||
|
||||
assert result["count"] == 2
|
||||
assert len(result["workers"]) == 2
|
||||
finally:
|
||||
server.stop()
|
||||
|
||||
def test_show_stats(self):
|
||||
"""Test show stats command."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
socket_path = os.path.join(tmpdir, "test.sock")
|
||||
|
||||
arbiter = MockArbiter()
|
||||
server = ControlSocketServer(arbiter, socket_path)
|
||||
|
||||
server.start()
|
||||
|
||||
for _ in range(20):
|
||||
if os.path.exists(socket_path):
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
||||
try:
|
||||
with ControlClient(socket_path, timeout=5.0) as client:
|
||||
result = client.send_command("show stats")
|
||||
|
||||
assert result["pid"] == 12345
|
||||
assert result["workers_spawned"] == 10
|
||||
finally:
|
||||
server.stop()
|
||||
|
||||
def test_help_command(self):
|
||||
"""Test help command."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
socket_path = os.path.join(tmpdir, "test.sock")
|
||||
|
||||
arbiter = MockArbiter()
|
||||
server = ControlSocketServer(arbiter, socket_path)
|
||||
|
||||
server.start()
|
||||
|
||||
for _ in range(20):
|
||||
if os.path.exists(socket_path):
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
||||
try:
|
||||
with ControlClient(socket_path, timeout=5.0) as client:
|
||||
result = client.send_command("help")
|
||||
|
||||
assert "commands" in result
|
||||
assert "show workers" in result["commands"]
|
||||
finally:
|
||||
server.stop()
|
||||
|
||||
def test_worker_add(self):
|
||||
"""Test worker add command."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
socket_path = os.path.join(tmpdir, "test.sock")
|
||||
|
||||
arbiter = MockArbiter()
|
||||
arbiter.wakeup = MagicMock()
|
||||
server = ControlSocketServer(arbiter, socket_path)
|
||||
|
||||
server.start()
|
||||
|
||||
for _ in range(20):
|
||||
if os.path.exists(socket_path):
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
||||
try:
|
||||
with ControlClient(socket_path, timeout=5.0) as client:
|
||||
result = client.send_command("worker add 2")
|
||||
|
||||
assert result["added"] == 2
|
||||
assert result["total"] == 6
|
||||
assert arbiter.num_workers == 6
|
||||
arbiter.wakeup.assert_called()
|
||||
finally:
|
||||
server.stop()
|
||||
|
||||
def test_invalid_command(self):
|
||||
"""Test handling invalid command."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
socket_path = os.path.join(tmpdir, "test.sock")
|
||||
|
||||
arbiter = MockArbiter()
|
||||
server = ControlSocketServer(arbiter, socket_path)
|
||||
|
||||
server.start()
|
||||
|
||||
for _ in range(20):
|
||||
if os.path.exists(socket_path):
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
||||
try:
|
||||
with ControlClient(socket_path, timeout=5.0) as client:
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
client.send_command("invalid_command")
|
||||
|
||||
assert "Unknown command" in str(exc_info.value)
|
||||
finally:
|
||||
server.stop()
|
||||
|
||||
def test_multiple_commands(self):
|
||||
"""Test sending multiple commands on same connection."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
socket_path = os.path.join(tmpdir, "test.sock")
|
||||
|
||||
arbiter = MockArbiter()
|
||||
arbiter.WORKERS = {1001: MockWorker(1001, 1)}
|
||||
server = ControlSocketServer(arbiter, socket_path)
|
||||
|
||||
server.start()
|
||||
|
||||
for _ in range(20):
|
||||
if os.path.exists(socket_path):
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
||||
try:
|
||||
with ControlClient(socket_path, timeout=5.0) as client:
|
||||
result1 = client.send_command("show workers")
|
||||
result2 = client.send_command("show stats")
|
||||
result3 = client.send_command("help")
|
||||
|
||||
assert result1["count"] == 1
|
||||
assert result2["pid"] == 12345
|
||||
assert "commands" in result3
|
||||
finally:
|
||||
server.stop()
|
||||
|
||||
|
||||
class TestControlSocketServerPermissions:
|
||||
"""Tests for socket permissions."""
|
||||
|
||||
def test_socket_permissions(self):
|
||||
"""Test that socket is created with correct permissions."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
socket_path = os.path.join(tmpdir, "test.sock")
|
||||
|
||||
arbiter = MockArbiter()
|
||||
server = ControlSocketServer(arbiter, socket_path, 0o660)
|
||||
|
||||
server.start()
|
||||
|
||||
for _ in range(20):
|
||||
if os.path.exists(socket_path):
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
||||
try:
|
||||
mode = os.stat(socket_path).st_mode & 0o777
|
||||
assert mode == 0o660
|
||||
finally:
|
||||
server.stop()
|
||||
Loading…
x
Reference in New Issue
Block a user