diff --git a/docs/design/companion-process-manager.md b/docs/design/companion-process-manager.md index 2f8cbcd6..94a0213d 100644 --- a/docs/design/companion-process-manager.md +++ b/docs/design/companion-process-manager.md @@ -681,8 +681,8 @@ No per-companion logic in Arbiter. - [x] Implement `stop_process`. - [x] Implement `restart_process`. - [x] Preserve and clear `manual_stop` correctly. -- [ ] Add Unix control socket. -- [ ] Implement JSON command protocol. +- [x] Add Unix control socket. +- [x] Implement JSON command protocol. - [ ] Implement `status`. - [ ] Implement `start`. - [ ] Implement `stop`. diff --git a/gunicorn/companion/control.py b/gunicorn/companion/control.py new file mode 100644 index 00000000..1409eb91 --- /dev/null +++ b/gunicorn/companion/control.py @@ -0,0 +1,117 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + + +import json +import os +import socket + + +class CommandError(Exception): + """A control request the manager understood but had to reject. + + Raised for malformed input (bad JSON, missing ``cmd``). It is turned into + an ``{"ok": false, "error": ...}`` response rather than crashing the + manager, so a buggy or hostile client can never take the socket down. + """ + + +def decode_command(line): + """Parse one request line into a command dict. + + The wire protocol is newline-delimited JSON: each request is a single JSON + object on its own line, e.g. ``{"cmd": "status"}``. Every request must be a + JSON object carrying a string ``cmd``; anything else is a ``CommandError``. + """ + try: + obj = json.loads(line) + except (ValueError, TypeError): + raise CommandError("invalid JSON") + if not isinstance(obj, dict): + raise CommandError("request must be a JSON object") + if not isinstance(obj.get("cmd"), str): + raise CommandError("missing 'cmd'") + return obj + + +def encode_response(obj): + """Encode a response dict as one newline-terminated JSON line of bytes.""" + return (json.dumps(obj) + "\n").encode("utf-8") + + +class ControlServer: + """The manager's Unix-socket control endpoint. + + Owns the listening socket and the request framing only. Turning a decoded + command into an action is delegated to ``dispatch`` (wired to the manager's + command handlers in a later task); this class just decodes each line, runs + it through ``dispatch``, and writes back the encoded reply. + + The socket is created with mode 0o600 and owned by the (non-root) user + gunicorn runs as. There is no group-ownership switching. + """ + + def __init__(self, dispatch, path, mode=0o600, log=None, backlog=64): + self.dispatch = dispatch + self.path = path + self.mode = mode + self.log = log + self.backlog = backlog + self.sock = None + + def create(self): + """Bind and listen on the Unix socket, replacing any stale one. + + A leftover socket file from a previous manager would make ``bind`` + fail, so it is unlinked first. Called once before the manager enters + its run loop, as clients expect the socket to exist by then. + """ + if os.path.exists(self.path): + os.unlink(self.path) + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.bind(self.path) + os.chmod(self.path, self.mode) + sock.listen(self.backlog) + self.sock = sock + return sock + + def close(self): + """Close the listening socket and remove its file.""" + if self.sock is not None: + self.sock.close() + self.sock = None + if os.path.exists(self.path): + os.unlink(self.path) + + def handle_line(self, line): + """Run one request line and return the encoded response bytes. + + Both decoding and dispatch failures are caught and rendered as an + error response, so one bad request never breaks the connection or the + manager. + """ + try: + response = self.dispatch(decode_command(line)) + except CommandError as e: + response = {"ok": False, "error": str(e)} + return encode_response(response) + + def serve_connection(self, conn): + """Serve newline-delimited requests on one accepted connection. + + Reads until the client hangs up, buffering partial reads and answering + each complete line as it arrives. A trailing fragment without a newline + is ignored. + """ + buf = b"" + with conn: + while True: + chunk = conn.recv(65536) + if not chunk: + break + buf += chunk + while b"\n" in buf: + line, buf = buf.split(b"\n", 1) + if line.strip(): + conn.sendall(self.handle_line(line)) diff --git a/tests/test_companion_control.py b/tests/test_companion_control.py new file mode 100644 index 00000000..069e9124 --- /dev/null +++ b/tests/test_companion_control.py @@ -0,0 +1,85 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +import json +from unittest import mock + +import pytest + +from gunicorn.companion.control import ( + CommandError, + ControlServer, + decode_command, + encode_response, +) + + +def test_decode_command_valid(): + assert decode_command('{"cmd": "status"}') == {"cmd": "status"} + + +def test_decode_command_bad_json(): + with pytest.raises(CommandError): + decode_command("{not json") + + +def test_decode_command_not_object(): + with pytest.raises(CommandError): + decode_command("[1, 2, 3]") + + +def test_decode_command_missing_cmd(): + with pytest.raises(CommandError): + decode_command('{"name": "rq"}') + + +def test_encode_response_newline_terminated(): + out = encode_response({"ok": True}) + assert out.endswith(b"\n") + assert json.loads(out) == {"ok": True} + + +def test_handle_line_dispatches(): + server = ControlServer(dispatch=lambda obj: {"ok": True, "echo": obj["cmd"]}, + path="/tmp/x.sock") + out = server.handle_line('{"cmd": "status"}') + assert json.loads(out) == {"ok": True, "echo": "status"} + + +def test_handle_line_bad_json_error_envelope(): + server = ControlServer(dispatch=lambda obj: {"ok": True}, path="/tmp/x.sock") + out = json.loads(server.handle_line("garbage")) + assert out["ok"] is False and "JSON" in out["error"] + + +def test_handle_line_dispatch_command_error(): + def dispatch(obj): + raise CommandError("unknown command") + server = ControlServer(dispatch=dispatch, path="/tmp/x.sock") + out = json.loads(server.handle_line('{"cmd": "bogus"}')) + assert out["ok"] is False and out["error"] == "unknown command" + + +def test_create_unlinks_stale_and_chmods(): + server = ControlServer(dispatch=lambda o: {}, path="/tmp/x.sock", mode=0o600) + sock = mock.Mock() + with mock.patch("os.path.exists", return_value=True), \ + mock.patch("os.unlink") as unlink, \ + mock.patch("socket.socket", return_value=sock), \ + mock.patch("os.chmod") as chmod: + server.create() + unlink.assert_called_once_with("/tmp/x.sock") + sock.bind.assert_called_once_with("/tmp/x.sock") + chmod.assert_called_once_with("/tmp/x.sock", 0o600) + sock.listen.assert_called_once() + + +def test_close_unlinks(): + server = ControlServer(dispatch=lambda o: {}, path="/tmp/x.sock") + server.sock = mock.Mock() + with mock.patch("os.path.exists", return_value=True), \ + mock.patch("os.unlink") as unlink: + server.close() + unlink.assert_called_once_with("/tmp/x.sock") + assert server.sock is None