jcloude/press/runner.py

import json
import typing
from collections.abc import Callable
from dataclasses import dataclass
from enum import Enum
from typing import Literal

import jingrow
import wrapt
from ansible import constants, context
from ansible.executor.playbook_executor import PlaybookExecutor
from ansible.executor.task_executor import TaskExecutor
from ansible.inventory.manager import InventoryManager
from ansible.module_utils.common.collections import ImmutableDict
from ansible.parsing.dataloader import DataLoader
from ansible.playbook import Playbook
from ansible.plugins.action.async_status import ActionModule
from ansible.plugins.callback import CallbackBase
from ansible.utils.display import Display
from ansible.vars.manager import VariableManager
from jingrow.model.document import Document
from jingrow.utils import cstr
from jingrow.utils import now_datetime as now

from jcloude.jcloude.pagetype.ansible_play.ansible_play import AnsiblePlay

if typing.TYPE_CHECKING:
    from jcloude.jcloude.pagetype.agent_job.agent_job import AgentJob
    from jcloude.jcloude.pagetype.virtual_machine.virtual_machine import VirtualMachine


def reconnect_on_failure():
    @wrapt.decorator
    def wrapper(wrapped, instance, args, kwargs):
        try:
            return wrapped(*args, **kwargs)
        except Exception as e:
            # If the database connection dropped mid-callback, reconnect and retry once
            if jingrow.db.is_interface_error(e):
                jingrow.db.connect()
                return wrapped(*args, **kwargs)
            raise

    return wrapper


class AnsibleCallback(CallbackBase):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @reconnect_on_failure()
    def process_task_success(self, result):
        result, action = jingrow._dict(result._result), result._task.action
        if action == "user":
            server_type, server = jingrow.db.get_value("Ansible Play", self.play, ["server_type", "server"])
            server = jingrow.get_pg(server_type, server)
            if result.name == "root":
                server.root_public_key = result.ssh_public_key
            elif result.name == "jingrow":
                server.jingrow_public_key = result.ssh_public_key
            server.save()

    def v2_runner_on_ok(self, result, *args, **kwargs):
        self.update_task("Success", result)
        self.process_task_success(result)

    def v2_runner_on_failed(self, result, *args, **kwargs):
        self.update_task("Failure", result)

    def v2_runner_on_skipped(self, result):
        self.update_task("Skipped", result)

    def v2_runner_on_unreachable(self, result):
        self.update_task("Unreachable", result)

    def v2_playbook_on_task_start(self, task, is_conditional):
        self.update_task("Running", None, task)

    def v2_playbook_on_start(self, playbook):
        self.update_play("Running")

    def v2_playbook_on_stats(self, stats):
        self.update_play(None, stats)

    @reconnect_on_failure()
    def update_play(self, status=None, stats=None):
        play = jingrow.get_pg("Ansible Play", self.play)
        if stats:
            # Assume we're running on one host
            host = next(iter(stats.processed.keys()))
            play.update(stats.summarize(host))
            if play.failures or play.unreachable:
                play.status = "Failure"
            else:
                play.status = "Success"
            play.end = now()
            play.duration = play.end - play.start
        else:
            play.status = status
            play.start = now()
        play.save()
        jingrow.db.commit()

    @reconnect_on_failure()
    def update_task(self, status, result=None, task=None):
        if result:
            if not result._task._role:
                return
            task_name, result = self.parse_result(result)
        else:
            if not task._role:
                return
            task_name = self.tasks[task._role.get_name()][task.name]
        task = jingrow.get_pg("Ansible Task", task_name)
        task.status = status
        if result:
            task.output = result.stdout
            task.error = result.stderr
            task.exception = result.msg
            # Reduce clutter by removing keys already shown elsewhere
            for key in ("stdout", "stdout_lines", "stderr", "stderr_lines", "msg"):
                result.pop(key, None)
            task.result = json.dumps(result, indent=4)
            task.end = now()
            task.duration = task.end - task.start
        else:
            task.start = now()
        task.save()
        self.publish_play_progress(task.name)
        jingrow.db.commit()

    def publish_play_progress(self, task):
        jingrow.publish_realtime(
            "ansible_play_progress",
            {"progress": self.task_list.index(task), "total": len(self.task_list), "play": self.play},
            pagetype="Ansible Play",
            docname=self.play,
            user=jingrow.session.user,
        )

    def parse_result(self, result):
        task = result._task.name
        role = result._task._role.get_name()
        return self.tasks[role][task], jingrow._dict(result._result)

    @reconnect_on_failure()
    def on_async_start(self, role, task, job_id):
        task_name = self.tasks[role][task]
        task = jingrow.get_pg("Ansible Task", task_name)
        task.job_id = job_id
        task.save()
        jingrow.db.commit()

    @reconnect_on_failure()
    def on_async_poll(self, result):
        job_id = result["ansible_job_id"]
        task_name = jingrow.get_value("Ansible Task", {"play": self.play, "job_id": job_id}, "name")
        task = jingrow.get_pg("Ansible Task", task_name)
        task.result = json.dumps(result, indent=4)
        task.duration = now() - task.start
        task.save()
        jingrow.db.commit()


class Ansible:
    def __init__(self, server, playbook, user="root", variables=None, port=22):
        self.patch()
        self.server = server
        self.playbook = playbook
        self.playbook_path = jingrow.get_app_path("jcloude", "playbooks", self.playbook)
        self.host = f"{server.ip}:{port}"
        self.variables = variables or {}
        constants.HOST_KEY_CHECKING = False
        context.CLIARGS = ImmutableDict(
            become_method="sudo",
            check=False,
            connection="ssh",
            # This is the only way to pass variables that preserves newlines
            extra_vars=[f"{cstr(key)}='{cstr(value)}'" for key, value in self.variables.items()],
            remote_user=user,
            start_at_task=None,
            syntax=False,
            verbosity=1,
            ssh_common_args=self._get_ssh_proxy_command(server),
        )
        self.loader = DataLoader()
        self.passwords = {}
        self.sources = f"{self.host},"
        self.inventory = InventoryManager(loader=self.loader, sources=self.sources)
        self.variable_manager = VariableManager(loader=self.loader, inventory=self.inventory)
        self.callback = AnsibleCallback()
        self.display = Display()
        self.display.verbosity = 1
        self.create_ansible_play()

    def _get_ssh_proxy_command(self, server):
        # Note: ProxyCommand must be enclosed in double quotes
        # because it contains spaces
        # and the entire argument must be enclosed in single quotes
        # because it is passed via the CLI
        # See https://docs.ansible.com/ansible/latest/user_guide/connection_details.html#ssh-args
        # and https://unix.stackexchange.com/a/303717
        # for details
        proxy_command = None
        if hasattr(self.server, "bastion_host") and self.server.bastion_host:
            proxy_command = (
                f'-o ProxyCommand="ssh -W %h:%p '
                f"{server.bastion_host.ssh_user}@{server.bastion_host.ip} "
                f'-p {server.bastion_host.ssh_port}"'
            )
        return proxy_command
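
    # For illustration only (hypothetical values): with a bastion host at 10.0.0.1,
    # SSH user "root" and port 22, the argument produced above would be
    #   -o ProxyCommand="ssh -W %h:%p root@10.0.0.1 -p 22"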

    def patch(self):
        def modified_action_module_run(*args, **kwargs):
            result = self.action_module_run(*args, **kwargs)
            self.callback.on_async_poll(result)
            return result

        def modified_poll_async_result(executor, result, templar, task_vars=None):
            job_id = result["ansible_job_id"]
            task = executor._task
            self.callback.on_async_start(task._role.get_name(), task.name, job_id)
            return self._poll_async_result(executor, result, templar, task_vars=task_vars)

        if ActionModule.run.__module__ != "jcloude.runner":
            self.action_module_run = ActionModule.run
            ActionModule.run = modified_action_module_run
        if TaskExecutor.run.__module__ != "jcloude.runner":
            self._poll_async_result = TaskExecutor._poll_async_result
            TaskExecutor._poll_async_result = modified_poll_async_result

    def unpatch(self):
        TaskExecutor._poll_async_result = self._poll_async_result
        ActionModule.run = self.action_module_run

    def run(self) -> AnsiblePlay:
        self.executor = PlaybookExecutor(
            playbooks=[self.playbook_path],
            inventory=self.inventory,
            variable_manager=self.variable_manager,
            loader=self.loader,
            passwords=self.passwords,
        )
        # Use AnsibleCallback so we can receive updates during task execution
        self.executor._tqm._stdout_callback = self.callback
        self.callback.play = self.play
        self.callback.tasks = self.tasks
        self.callback.task_list = self.task_list
        self.executor.run()
        self.unpatch()
        return jingrow.get_pg("Ansible Play", self.play)

    def create_ansible_play(self):
        # Parse the playbook and create Ansible Tasks so we can show how many tasks are pending
        playbook = Playbook.load(
            self.playbook_path, variable_manager=self.variable_manager, loader=self.loader
        )
        # Assume we only have one play per playbook
        play = playbook.get_plays()[0]
        play_pg = jingrow.get_pg(
            {
                "pagetype": "Ansible Play",
                "server_type": self.server.pagetype,
                "server": self.server.name,
                "variables": json.dumps(self.variables, indent=4),
                "playbook": self.playbook,
                "play": play.get_name(),
            }
        ).insert()
        self.play = play_pg.name
        self.tasks = {}
        self.task_list = []
        for role in play.get_roles():
            for block in role.get_task_blocks():
                for task in block.block:
                    task_pg = jingrow.get_pg(
                        {
                            "pagetype": "Ansible Task",
                            "play": self.play,
                            "role": role.get_name(),
                            "task": task.name,
                        }
                    ).insert()
                    self.tasks.setdefault(role.get_name(), {})[task.name] = task_pg.name
                    self.task_list.append(task_pg.name)
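

# A minimal usage sketch (not part of the original module): the pagetype, docname and
# playbook below are hypothetical; any server document exposing `ip` (and optionally
# `bastion_host`) should work the same way.
#
#     server = jingrow.get_pg("Server", "example-server")  # hypothetical pagetype/docname
#     play = Ansible(server, "ping.yml", variables={"user": "jingrow"}).run()
#     print(play.status)  # "Success" or "Failure"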


class Status(str, Enum):
    Pending = "Pending"
    Running = "Running"
    Success = "Success"
    Skipped = "Skipped"
    Failure = "Failure"

    def __str__(self):
        return self.value


class GenericStep(Document):
    attempt: int
    job_type: Literal["Ansible Play", "Agent Job"]
    job: str | None
    status: Status
    method_name: str


@dataclass
class StepHandler:
    save: Callable
    reload: Callable
    pagetype: str
    name: str

    def handle_vm_status_job(
        self,
        step: GenericStep,
        virtual_machine: str,
        expected_status: str,
    ) -> None:
        step.attempt = 1 if not step.attempt else step.attempt + 1
        # Try to sync status in every attempt
        try:
            virtual_machine_pg: "VirtualMachine" = jingrow.get_pg("Virtual Machine", virtual_machine)
            virtual_machine_pg.sync()
        except Exception:
            pass
        machine_status = jingrow.db.get_value("Virtual Machine", virtual_machine, "status")
        step.status = Status.Running if machine_status != expected_status else Status.Success
        step.save()

    def handle_agent_job(self, step: GenericStep, job: str, poll: bool = False) -> None:
        if poll:
            job_pg: "AgentJob" = jingrow.get_pg("Agent Job", job)
            job_pg.get_status()
        job_status = jingrow.db.get_value("Agent Job", job, "status")
        status_map = {
            "Delivery Failure": Status.Failure,
            "Undelivered": Status.Pending,
        }
        job_status = status_map.get(job_status, job_status)
        step.attempt = 1 if not step.attempt else step.attempt + 1
        step.status = job_status
        step.save()
        if step.status == Status.Failure:
            # Abort step execution; _execute_steps marks the whole job as failed
            raise Exception(f"Agent Job {job} failed")

    def handle_ansible_play(self, step: GenericStep, ansible: Ansible) -> None:
        step.job_type = "Ansible Play"
        step.job = ansible.play
        step.save()
        ansible_play = ansible.run()
        step.status = ansible_play.status
        step.save()
        if step.status == Status.Failure:
            # Abort step execution; _execute_steps marks the whole job as failed
            raise Exception(f"Ansible Play {step.job} failed")

    def _fail_ansible_step(
        self,
        step: GenericStep,
        ansible: Ansible,
        e: Exception | None = None,
    ) -> None:
        step.job = getattr(ansible, "play", None)
        step.status = Status.Failure
        step.output = str(e)
        step.save()

    def _fail_job_step(self, step: GenericStep, e: Exception | None = None) -> None:
        step.status = Status.Failure
        step.output = str(e)
        step.save()

    def fail(self):
        self.status = Status.Failure
        self.save()
        jingrow.db.commit()

    def succeed(self):
        self.status = Status.Success
        self.save()
        jingrow.db.commit()

    def handle_step_failure(self):
        # can be implemented by the controller
        pass

    def get_steps(self, methods: list) -> list[dict]:
        """Generate the list of pending steps to execute, one per method."""
        return [
            {
                "step_name": method.__doc__,
                "method_name": method.__name__,
                "status": "Pending",
            }
            for method in methods
        ]

    def _get_method(self, method_name: str):
        """Retrieve a method object by name."""
        return getattr(self, method_name)

    def next_step(self, steps: list[GenericStep]) -> GenericStep | None:
        for step in steps:
            if step.status not in (Status.Success, Status.Failure, Status.Skipped):
                return step
        return None

    def _execute_steps(self, steps: list[GenericStep]):
        """Must be called via `enqueue_pg`; otherwise the first step executes in the web worker."""
        self.status = Status.Running
        self.save()
        jingrow.db.commit()
        step = self.next_step(steps)
        if not step:
            self.succeed()
            return
        # Run a single step in this job
        step = step.reload()
        method = self._get_method(step.method_name)
        try:
            method(step)
            jingrow.db.commit()
        except Exception:
            self.reload()
            self.fail()
            self.handle_step_failure()
            jingrow.db.commit()
            return
        # After the step completes, queue the next step
        jingrow.enqueue_pg(
            self.pagetype,
            self.name,
            "_execute_steps",
            steps=steps,
            timeout=18000,
            at_front=True,
            queue="long",
            enqueue_after_commit=True,
        )
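

# A minimal sketch of how a controller might use these helpers (all names below are
# hypothetical; the real controller is expected to store its steps as documents and
# enqueue `_execute_steps` itself, as the docstring above requires):
#
#     class VolumeAttachment(Document, StepHandler):
#         def wait_for_vm(self, step: GenericStep):
#             """Wait for the virtual machine to reach the Running state"""
#             self.handle_vm_status_job(step, self.virtual_machine, "Running")
#
#         def attach_volume(self, step: GenericStep):
#             """Attach the volume via an Ansible play"""
#             self.handle_ansible_play(step, Ansible(self.server_pg, "attach_volume.yml"))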