import os
|
||
import json
|
||
import logging
|
||
from typing import Any, Dict, List
|
||
import uuid
|
||
|
||
import dramatiq
|
||
from dramatiq.brokers.redis import RedisBroker
|
||
|
||
from jingrow.ai.pagetype.local_ai_agent.executor import NodeExecutor
|
||
from jingrow.core.hooks.init_hooks import init_hooks
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def _get_redis_url() -> str:
|
||
return os.getenv("REDIS_URL", "redis://localhost:6379/0")
|
||
|
||
|
||
# Process-wide guard so the Dramatiq broker is configured at most once
# (init_queue sets it True even on failure, to avoid repeated attempts).
_broker_initialized = False
||
def init_queue() -> None:
    """Configure the global Dramatiq Redis broker, at most once per process.

    Safe to call repeatedly: the module-level ``_broker_initialized`` flag
    short-circuits later calls. On failure the flag is still set so the
    process does not retry (and re-log) the connection on every call —
    this best-effort behavior is deliberate.
    """
    global _broker_initialized
    if _broker_initialized:
        return

    redis_url = _get_redis_url()
    try:
        dramatiq.set_broker(RedisBroker(url=redis_url))
        logger.info("Dramatiq broker initialized: %s", redis_url)
    except Exception:
        # logger.exception preserves the traceback, unlike the previous
        # f-string logger.error.
        logger.exception("Failed to initialize Dramatiq broker")
    finally:
        # Mark initialized either way to avoid repeated attempts.
        _broker_initialized = True
|
||
|
||
from jingrow.utils.jingrow_api import update_local_job
|
||
from jingrow.services.local_job_manager import local_job_manager
|
||
from jingrow.utils.jingrow_api import get_logged_user
|
||
|
||
|
||
async def _push_status_to_jingrow(job_id: str, data: Dict[str, Any]) -> None:
    """Merge *data* into the local Redis job record and push the full record
    to Jingrow via ``update_local_job``.

    No-op when *job_id* is falsy or the job is unknown locally. Any failure
    is logged and swallowed: status pushes are best-effort and must never
    break job execution.
    """
    if not job_id:
        return
    changes = data or {}
    try:
        # Read the complete record from the local store.
        local_job = local_job_manager.get_job(job_id)
        if not local_job:
            logger.warning("Local job not found: %s", job_id)
            return

        # Apply the incremental changes both in-memory and in the store.
        local_job.update(changes)
        local_job_manager.update_job(job_id, changes)

        # Push the full record to Jingrow; defaults mirror the original
        # payload field-by-field (order preserved).
        defaults = {
            'name': job_id,
            'queue': 'default',
            'job_name': job_id,
            'status': 'queued',
            'started_at': '',
            'ended_at': '',
            'time_taken': '',
            'exc_info': '',
            'arguments': '{}',
            'timeout': '',
            'creation': '',
            'modified': '',
            'owner': 'system',
            'modified_by': 'system',
        }
        payload = {'job_id': job_id}
        payload.update({key: local_job.get(key, dflt) for key, dflt in defaults.items()})
        payload['session_cookie'] = local_job.get('session_cookie')
        update_local_job(payload)
    except Exception:
        logger.exception("Failed to push job %s to Jingrow", job_id)
|
||
|
||
async def _execute_node_job(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Execute a single node via ``NodeExecutor``, pushing job status updates
    ('started' / 'finished' / 'failed') to Jingrow around the run.

    Re-raises any execution error after reporting it, so the queue layer can
    apply its retry policy.
    """
    import time
    from datetime import datetime  # replaces the __import__('datetime') hack

    node_type = payload.get("node_type")
    flow_id = payload.get("flow_id", "unknown")
    context = payload.get("context", {})
    inputs = payload.get("inputs", {})
    config = payload.get("config", {})

    executor = NodeExecutor()
    # Pass the session cookie through, matching Jingrow-side execution.
    session_cookie = (context or {}).get("session_cookie")
    job_id = payload.get("job_id")
    started_iso = None
    try:
        started = time.time()
        started_iso = datetime.now().isoformat()
        modifier = get_logged_user(session_cookie) or 'system'
        await _push_status_to_jingrow(job_id, {
            'status': 'started',
            'started_at': started_iso,
            'modified_by': modifier
        })
        result = await executor.execute_node(node_type, flow_id, context, inputs, config, session_cookie)
        ended = time.time()
        await _push_status_to_jingrow(job_id, {
            'status': 'finished' if result.get('success') else 'failed',
            'exc_info': None if result.get('success') else result.get('error'),
            'started_at': started_iso,
            'ended_at': datetime.now().isoformat(),
            'time_taken': round(max(0.0, ended - started), 3),
            'modified_by': modifier,
        })
        return result
    except Exception as e:
        # Best-effort failure notification; the exception still propagates.
        modifier = get_logged_user(session_cookie) or 'system'
        await _push_status_to_jingrow(job_id, {
            'status': 'failed',
            'exc_info': str(e),
            'started_at': started_iso,
            'ended_at': datetime.now().isoformat(),
            'modified_by': modifier
        })
        raise
|
||
|
||
async def _execute_flow_job(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Lightweight backend flow execution: topological sort, then multi-round
    ready-node execution so nodes behind conditional branches are advanced in
    later rounds once one of their incoming edges has been activated.

    Returns ``{"success": bool, ...}`` with the accumulated execution context.
    """
    from collections import deque

    nodes: List[Dict[str, Any]] = payload.get("nodes", [])
    edges: List[Dict[str, Any]] = payload.get("edges", [])
    initial_context: Dict[str, Any] = payload.get("context", {})
    flow_id: str = payload.get("flow_id", "unknown")

    if not nodes:
        return {"success": False, "error": "No nodes provided"}

    # Build in-degrees and edge indices (used for reachability and branch routing).
    indegree = {n["id"]: 0 for n in nodes if "id" in n}
    incoming_edges_by_target: Dict[str, List[Dict[str, Any]]] = {}
    outgoing_edges_by_source: Dict[str, List[Dict[str, Any]]] = {}
    for e in edges:
        src = e.get("source")
        tgt = e.get("target")
        if tgt is None:
            continue
        indegree[tgt] = indegree.get(tgt, 0) + 1
        incoming_edges_by_target.setdefault(tgt, []).append(e)
        if src is not None:
            outgoing_edges_by_source.setdefault(src, []).append(e)

    # Kahn's algorithm; deque + the outgoing-edge index keep this O(V + E)
    # (the previous version scanned every edge per node with list.pop(0)).
    ready = deque(nid for nid, d in indegree.items() if d == 0)
    order: List[str] = []
    while ready:
        cur = ready.popleft()
        order.append(cur)
        for e in outgoing_edges_by_source.get(cur, []):
            t = e.get("target")
            indegree[t] = indegree.get(t, 0) - 1
            if indegree.get(t, 0) == 0:
                ready.append(t)

    if len(order) == 0:
        return {"success": False, "error": "Topological sort failed"}

    id_to_node = {n["id"]: n for n in nodes if "id" in n}
    context = {
        **initial_context,
        "node_results": {},
        "flow_data": {"nodes": nodes, "edges": edges},
    }

    # Make sure the session cookie reaches every node.
    session_cookie = initial_context.get("session_cookie")

    # Track activated edges and executed nodes; repeated rounds ensure that
    # conditional-branch downstream nodes become executable in later passes.
    active_edge_ids: set = set()
    executed_nodes: set = set()

    max_rounds = max(1, len(order))
    for _ in range(max_rounds):
        progress_made = False
        for node_id in order:
            if node_id in executed_nodes:
                continue
            node = id_to_node.get(node_id)
            if not node:
                continue
            node_type = node.get("type")
            # Map this node's declared inputs from upstream results.
            data_inputs = (node.get("data") or {}).get("inputs") or {}
            inputs = {}
            for key, ref in data_inputs.items():
                from_id = (ref or {}).get("from")
                field = (ref or {}).get("field")
                inputs[key] = (context["node_results"].get(from_id) or {}).get(field)
            # Entry nodes (no incoming edges) also receive the initial trigger
            # inputs from context["inputs"]; explicit wire mappings take priority.
            incoming_list = incoming_edges_by_target.get(node_id, [])
            if not incoming_list:
                initial_inputs = (context or {}).get("inputs") or {}
                if isinstance(initial_inputs, dict) and initial_inputs:
                    inputs = {**initial_inputs, **inputs}
            # Reachability: entry node, or at least one activated incoming edge.
            has_activated_incoming = any((e.get("id") in active_edge_ids) for e in incoming_list)
            if incoming_list and not has_activated_incoming:
                # Not reachable yet; retry in a later round.
                continue

            config = (node.get("data") or {}).get("config") or {}
            exec_payload = {
                "node_type": node_type,
                "flow_id": flow_id,
                "context": {**context, "current_node_id": node_id, "session_cookie": session_cookie},
                "inputs": inputs,
                "config": config,
            }
            result = await _execute_node_job(exec_payload)
            context["node_results"][node_id] = result
            executed_nodes.add(node_id)
            progress_made = True
            if not result.get("success"):
                return {"success": False, "error": result.get("error"), "context": context}

            # Decide which outgoing edges to activate: plain outputs always,
            # condition nodes only the branch that was taken.
            allowed_handles = {"output", None}
            outgoing_list = outgoing_edges_by_source.get(node_id, [])
            outgoing_handles = {e.get("sourceHandle") for e in outgoing_list}
            if ("true_output" in outgoing_handles) or ("false_output" in outgoing_handles):
                condition_met = result.get("condition_met")
                flow_path = (result.get("flow_path") or "").strip().lower()
                is_true = (condition_met is True) or (flow_path in {"true_path", "true"})
                if is_true and ("true_output" in outgoing_handles):
                    allowed_handles.update({"true_output", "true"})
                if (not is_true) and ("false_output" in outgoing_handles):
                    allowed_handles.update({"false_output", "false"})

            for e in outgoing_list:
                handle = e.get("sourceHandle")
                if handle in allowed_handles or (handle is None and None in allowed_handles):
                    if e.get("id"):
                        active_edge_ids.add(e["id"])

        if not progress_made:
            break

    return {"success": True, "context": context}
|
||
|
||
async def _execute_agent_job(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Look up an agent's ``agent_flow`` from Jingrow by id and execute it
    locally via ``_execute_flow_job``.

    Returns the flow result, or ``{"success": False, "error": ...}`` when the
    agent or its flow definition cannot be resolved.
    """
    agent_id = payload.get("agent_id") or payload.get("agent_name")
    session_cookie = payload.get("session_cookie")

    if not agent_id:
        return {"success": False, "error": "agent_id required"}

    try:
        from ..utils.jingrow_api import get_agent_detail
        agent = get_agent_detail(agent_id, session_cookie)

        if not agent or not agent.get("agent_flow"):
            return {"success": False, "error": "agent_flow not found"}

        flow_data = agent.get("agent_flow")
        if isinstance(flow_data, str):
            try:
                flow_data = json.loads(flow_data)
            except Exception:
                return {"success": False, "error": "invalid agent_flow json"}

        nodes = flow_data.get("nodes") or []
        edges = flow_data.get("edges") or []

        # Collect initial trigger inputs; entry nodes read context["inputs"].
        initial_inputs: Dict[str, Any] = {}
        incoming_inputs = payload.get("inputs") or {}
        if isinstance(incoming_inputs, dict):
            initial_inputs.update(incoming_inputs)
        ctx_inputs = (payload.get("context") or {}).get("inputs")
        if isinstance(ctx_inputs, dict):
            initial_inputs.update(ctx_inputs)
        # Back-compat: pagetype/name may be passed directly on the payload.
        if payload.get("pagetype") and payload.get("name"):
            initial_inputs.setdefault("pagetype", payload.get("pagetype"))
            initial_inputs.setdefault("name", payload.get("name"))

        context = {
            "agent_name": agent.get("agent_name") or agent_id,
            "agent_data": agent,
            "session_cookie": session_cookie,
            "inputs": initial_inputs,
        }
        return await _execute_flow_job({"nodes": nodes, "edges": edges, "context": context, "flow_id": agent_id})
    except Exception as e:
        return {"success": False, "error": f"Failed to get agent detail: {str(e)}"}
|
||
|
||
@dramatiq.actor(max_retries=3, time_limit=60_000)
def execute_local_scheduled_job(job_json: str) -> None:
    """Worker entry point.

    *job_json* is a JSON string carrying ``target_type`` plus its arguments.
    Supported target_type: ``agent``.
    """
    import asyncio
    import time
    from datetime import datetime  # replaces the __import__('datetime') hack

    try:
        init_hooks(clear_cache=False)
    except Exception as e:
        logger.warning(f"初始化钩子系统失败: {e}")

    try:
        payload = json.loads(job_json)
    except Exception:
        # Malformed payload: nothing sensible to retry, so log and drop
        # (the previous version returned silently, hiding the error).
        logger.exception("Invalid job payload, dropping job")
        return

    if not payload.get("job_id"):
        payload["job_id"] = str(uuid.uuid4())

    target_type = payload.get("target_type")

    if target_type != "agent":
        logger.warning(f"Unsupported target_type: {target_type}")
        return

    started = time.time()
    started_iso = datetime.now().isoformat()
    # Session cookie may live at the top level or inside the context.
    sc = payload.get("session_cookie") or (payload.get("context") or {}).get("session_cookie")
    modifier = get_logged_user(sc) or 'system'
    asyncio.run(_push_status_to_jingrow(payload["job_id"], {
        "status": "started",
        "started_at": started_iso,
        "session_cookie": sc,
        "modified_by": modifier
    }))
    result = asyncio.run(_execute_agent_job(payload))
    ended = time.time()
    asyncio.run(_push_status_to_jingrow(payload["job_id"], {
        "status": "finished" if result.get("success") else "failed",
        "ended_at": datetime.now().isoformat(),
        "time_taken": round(max(0.0, ended - started), 3),
        "session_cookie": sc,
        "modified_by": modifier,
    }))