351 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import json
import logging
from typing import Any, Dict, List
import uuid
import dramatiq
from dramatiq.brokers.redis import RedisBroker
from jingrow.ai.pagetype.local_ai_agent.executor import NodeExecutor
from jingrow.core.hooks.init_hooks import init_hooks
logger = logging.getLogger(__name__)
def _get_redis_url() -> str:
return os.getenv("REDIS_URL", "redis://localhost:6379/0")
# Process-wide flag: the Dramatiq broker is set up at most once per worker process.
_broker_initialized = False


def init_queue() -> None:
    """Initialize the Dramatiq Redis broker (idempotent).

    The flag is set even when initialization fails so that every subsequent
    call does not retry a broken broker setup.
    """
    global _broker_initialized
    if _broker_initialized:
        return
    try:
        redis_url = _get_redis_url()
        dramatiq.set_broker(RedisBroker(url=redis_url))
        _broker_initialized = True
        logger.info("Dramatiq broker initialized: %s", redis_url)
    except Exception as e:
        logger.error(f"Failed to initialize Dramatiq broker: {e}")
        _broker_initialized = True  # avoid retrying on every call
from jingrow.utils.jingrow_api import update_local_job
from jingrow.services.local_job_manager import local_job_manager
from jingrow.utils.jingrow_api import get_logged_user
async def _push_status_to_jingrow(job_id: str, data: Dict[str, Any]) -> None:
    """Read the job record from the local Redis store, merge in *data*, and push it to the Jingrow Local Job.

    Args:
        job_id: Local job identifier; a falsy value makes this a no-op.
        data: Partial job fields to merge into the stored record before pushing.

    Errors are logged and swallowed — status mirroring is best-effort.
    """
    if not job_id:
        return
    try:
        # Fetch the full record from the local Redis store.
        local_job = local_job_manager.get_job(job_id)
        if not local_job:
            logger.warning(f"Local job not found: {job_id}")
            return
        # Merge the update into both the in-memory copy and the store.
        updates = data or {}
        local_job.update(updates)
        local_job_manager.update_job(job_id, updates)
        # Build the payload pushed to Jingrow; each field falls back to a
        # neutral default when the local record lacks it.
        field_defaults = [
            ('name', job_id),
            ('queue', 'default'),
            ('job_name', job_id),
            ('status', 'queued'),
            ('started_at', ''),
            ('ended_at', ''),
            ('time_taken', ''),
            ('exc_info', ''),
            ('arguments', '{}'),
            ('timeout', ''),
            ('creation', ''),
            ('modified', ''),
            ('owner', 'system'),
            ('modified_by', 'system'),
        ]
        payload: Dict[str, Any] = {'job_id': job_id}
        for field, fallback in field_defaults:
            payload[field] = local_job.get(field, fallback)
        payload['session_cookie'] = local_job.get('session_cookie')
        update_local_job(payload)
    except Exception as e:
        logger.error(f"Failed to push job {job_id} to Jingrow: {e}")
async def _execute_node_job(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Execute a single flow node via NodeExecutor, mirroring status to Jingrow.

    Args:
        payload: Dict with keys ``node_type``, ``flow_id``, ``context``,
            ``inputs``, ``config`` and optionally ``job_id``.

    Returns:
        The node execution result dict (expected to carry a ``success`` flag
        and, on failure, an ``error`` message).

    Raises:
        Exception: re-raises whatever the executor raised, after pushing a
            'failed' status for the job.
    """
    # Fix: replaced the repeated __import__('datetime') hack with a normal
    # function-scope import, and hoisted `import time` out of the try body.
    import time
    from datetime import datetime

    node_type = payload.get("node_type")
    flow_id = payload.get("flow_id", "unknown")
    context = payload.get("context", {})
    inputs = payload.get("inputs", {})
    config = payload.get("config", {})
    executor = NodeExecutor()
    # Pass session_cookie through so execution stays consistent with Jingrow.
    session_cookie = (context or {}).get("session_cookie")
    job_id = payload.get("job_id")
    started_iso = None  # may still be None in the except path if we fail early
    try:
        started = time.time()
        started_iso = datetime.now().isoformat()
        modifier = get_logged_user(session_cookie) or 'system'
        await _push_status_to_jingrow(job_id, {
            'status': 'started',
            'started_at': started_iso,
            'modified_by': modifier
        })
        result = await executor.execute_node(node_type, flow_id, context, inputs, config, session_cookie)
        ended = time.time()
        ended_iso = datetime.now().isoformat()
        await _push_status_to_jingrow(job_id, {
            'status': 'finished' if result.get('success') else 'failed',
            'exc_info': None if result.get('success') else result.get('error'),
            'started_at': started_iso,
            'ended_at': ended_iso,
            'time_taken': round(max(0.0, ended - started), 3),
            'modified_by': modifier,
        })
        return result
    except Exception as e:
        ended_iso = datetime.now().isoformat()
        # Recompute the modifier: the exception may have fired before it was set.
        modifier = get_logged_user(session_cookie) or 'system'
        await _push_status_to_jingrow(job_id, {
            'status': 'failed',
            'exc_info': str(e),
            'started_at': started_iso,
            'ended_at': ended_iso,
            'modified_by': modifier
        })
        raise
async def _execute_flow_job(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Lightweight backend flow execution: topological sort + sequential node execution.

    Args:
        payload: Dict with ``nodes`` (list of node dicts with ``id``/``type``/``data``),
            ``edges`` (list of edge dicts with ``source``/``target``/``sourceHandle``/``id``),
            ``context`` (initial execution context) and ``flow_id``.

    Returns:
        ``{"success": True, "context": ...}`` on completion, or
        ``{"success": False, "error": ...}`` on the first node failure or on
        an invalid graph.
    """
    nodes: List[Dict[str, Any]] = payload.get("nodes", [])
    edges: List[Dict[str, Any]] = payload.get("edges", [])
    initial_context: Dict[str, Any] = payload.get("context", {})
    flow_id: str = payload.get("flow_id", "unknown")
    if not nodes:
        return {"success": False, "error": "No nodes provided"}
    # Build in-degrees and edge indexes (used for reachability and branch routing).
    indegree = {n["id"]: 0 for n in nodes if "id" in n}
    incoming_edges_by_target: Dict[str, List[Dict[str, Any]]] = {}
    outgoing_edges_by_source: Dict[str, List[Dict[str, Any]]] = {}
    for e in edges:
        src = e.get("source")
        tgt = e.get("target")
        if tgt is None:
            continue
        indegree[tgt] = indegree.get(tgt, 0) + 1
        incoming_edges_by_target.setdefault(tgt, []).append(e)
        if src is not None:
            outgoing_edges_by_source.setdefault(src, []).append(e)
    # Kahn topological sort.
    # NOTE(review): the inner `for e in edges` scan makes this O(V*E); it could
    # use outgoing_edges_by_source instead — confirm before changing.
    queue = [nid for nid, d in indegree.items() if d == 0]
    order: List[str] = []
    while queue:
        cur = queue.pop(0)
        order.append(cur)
        for e in edges:
            if e.get("source") == cur:
                t = e.get("target")
                indegree[t] = indegree.get(t, 0) - 1
                if indegree.get(t, 0) == 0:
                    queue.append(t)
    # NOTE(review): this only rejects graphs where *every* node is in a cycle;
    # a partial cycle leaves its members out of `order` and they are silently
    # skipped below — verify whether that is intended.
    if len(order) == 0:
        return {"success": False, "error": "Topological sort failed"}
    id_to_node = {n["id"]: n for n in nodes if "id" in n}
    context = {
        **initial_context,
        "node_results": {},
        "flow_data": {"nodes": nodes, "edges": edges},
    }
    # Ensure session_cookie is propagated to every node.
    session_cookie = initial_context.get("session_cookie")
    # Track activated edges and executed nodes; run multiple ready-rounds so
    # that downstream nodes of conditional branches get scheduled in later rounds.
    active_edge_ids: set = set()
    executed_nodes: set = set()
    max_rounds = max(1, len(order))
    for _ in range(max_rounds):
        progress_made = False
        for node_id in order:
            if node_id in executed_nodes:
                continue
            node = id_to_node.get(node_id)
            if not node:
                continue
            node_type = node.get("type")
            # Map this node's inputs from upstream results:
            # each ref is {"from": <node_id>, "field": <result key>}.
            data_inputs = (node.get("data") or {}).get("inputs") or {}
            inputs = {}
            for key, ref in data_inputs.items():
                from_id = (ref or {}).get("from")
                field = (ref or {}).get("field")
                inputs[key] = (context["node_results"].get(from_id) or {}).get(field)
            # Entry nodes (no incoming edges) merge the initial trigger inputs
            # from context["inputs"].
            incoming_list = incoming_edges_by_target.get(node_id, [])
            if not incoming_list:
                initial_inputs = (context or {}).get("inputs") or {}
                if isinstance(initial_inputs, dict) and initial_inputs:
                    # Explicit wired mappings win over the initial inputs.
                    merged = {**initial_inputs, **inputs}
                    inputs = merged
            # Reachability: a node runs if it has no incoming edges (a start
            # node) or at least one incoming edge has been activated.
            has_activated_incoming = any((e.get("id") in active_edge_ids) for e in incoming_list)
            if incoming_list and not has_activated_incoming:
                # Not reachable yet; retry in a later round.
                continue
            config = (node.get("data") or {}).get("config") or {}
            exec_payload = {
                "node_type": node_type,
                "flow_id": flow_id,
                "context": {**context, "current_node_id": node_id, "session_cookie": session_cookie},
                "inputs": inputs,
                "config": config,
            }
            result = await _execute_node_job(exec_payload)
            context["node_results"][node_id] = result
            executed_nodes.add(node_id)
            progress_made = True
            # Abort the whole flow on the first failed node.
            if not result.get("success"):
                return {"success": False, "error": result.get("error"), "context": context}
            # Branch routing: plain "output" (or unnamed) edges are always
            # activated; true/false handles only when the node is a condition
            # node (it exposed true_output/false_output handles).
            allowed_handles = {"output", None}
            outgoing_list = outgoing_edges_by_source.get(node_id, [])
            outgoing_handles = {e.get("sourceHandle") for e in outgoing_list}
            if ("true_output" in outgoing_handles) or ("false_output" in outgoing_handles):
                condition_met = result.get("condition_met")
                flow_path = (result.get("flow_path") or "").strip().lower()
                is_true = (condition_met is True) or (flow_path in {"true_path", "true"})
                if is_true and ("true_output" in outgoing_handles):
                    allowed_handles.update({"true_output", "true"})
                if (not is_true) and ("false_output" in outgoing_handles):
                    allowed_handles.update({"false_output", "false"})
            # Activate every outgoing edge whose handle is allowed, so its
            # target becomes reachable in subsequent rounds.
            for e in outgoing_list:
                handle = e.get("sourceHandle")
                if handle in allowed_handles or (handle is None and None in allowed_handles):
                    if e.get("id"):
                        active_edge_ids.add(e["id"])
        if not progress_made:
            # No node became ready this round; further rounds cannot help.
            break
    return {"success": True, "context": context}
async def _execute_agent_job(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Look up the agent_flow for ``agent_id`` from Jingrow and execute it locally.

    Args:
        payload: Dict with ``agent_id`` (or ``agent_name``), optional
            ``session_cookie``, optional ``inputs`` / ``context.inputs``
            trigger inputs, and optional ``pagetype``/``name`` shortcuts.

    Returns:
        The flow execution result dict, or ``{"success": False, "error": ...}``
        on any lookup/parse/execution failure.
    """
    agent_id = payload.get("agent_id") or payload.get("agent_name")
    session_cookie = payload.get("session_cookie")
    if not agent_id:
        return {"success": False, "error": "agent_id required"}
    # Fix: the whole body used to sit in one try whose error message was
    # "Failed to get agent detail" — so a failure during flow execution was
    # misreported. Each stage now produces an accurate error.
    try:
        # Absolute import, consistent with the file's other jingrow_api imports.
        from jingrow.utils.jingrow_api import get_agent_detail
        agent = get_agent_detail(agent_id, session_cookie)
    except Exception as e:
        return {"success": False, "error": f"Failed to get agent detail: {str(e)}"}
    if not agent or not agent.get("agent_flow"):
        return {"success": False, "error": "agent_flow not found"}
    flow_data = agent.get("agent_flow")
    if isinstance(flow_data, str):
        try:
            flow_data = json.loads(flow_data)
        except Exception:
            return {"success": False, "error": "invalid agent_flow json"}
    nodes = flow_data.get("nodes") or []
    edges = flow_data.get("edges") or []
    # Collect initial trigger inputs; entry nodes read them via context["inputs"].
    initial_inputs: Dict[str, Any] = {}
    incoming_inputs = payload.get("inputs") or {}
    if isinstance(incoming_inputs, dict):
        initial_inputs.update(incoming_inputs)
    ctx_inputs = (payload.get("context") or {}).get("inputs")
    if isinstance(ctx_inputs, dict):
        initial_inputs.update(ctx_inputs)
    # Compatibility: callers may pass pagetype/name directly.
    if payload.get("pagetype") and payload.get("name"):
        initial_inputs.setdefault("pagetype", payload.get("pagetype"))
        initial_inputs.setdefault("name", payload.get("name"))
    context = {
        "agent_name": agent.get("agent_name") or agent_id,
        "agent_data": agent,
        "session_cookie": session_cookie,
        "inputs": initial_inputs,
    }
    try:
        return await _execute_flow_job({"nodes": nodes, "edges": edges, "context": context, "flow_id": agent_id})
    except Exception as e:
        return {"success": False, "error": f"Flow execution failed: {str(e)}"}
@dramatiq.actor(max_retries=3, time_limit=60_000)
def execute_local_scheduled_job(job_json: str) -> None:
    """Worker entry point for locally scheduled jobs.

    Args:
        job_json: JSON string containing ``target_type`` plus its parameters.
            Currently supported target_type: ``agent``.
    """
    try:
        init_hooks(clear_cache=False)
    except Exception as e:
        logger.warning(f"初始化钩子系统失败: {e}")
    try:
        payload = json.loads(job_json)
    except Exception as e:
        # Fix: a malformed payload used to be dropped silently with no trace.
        logger.error("Invalid job payload, dropping job: %s", e)
        return
    import asyncio
    import time
    from datetime import datetime  # fix: replaces the __import__('datetime') hack

    if not payload.get("job_id"):
        payload["job_id"] = str(uuid.uuid4())
    target_type = payload.get("target_type")
    # Guard clause instead of a trailing else keeps the happy path flat.
    if target_type != "agent":
        logger.warning(f"Unsupported target_type: {target_type}")
        return
    started = time.time()
    started_iso = datetime.now().isoformat()
    sc = payload.get("session_cookie") or (payload.get("context") or {}).get("session_cookie")
    modifier = get_logged_user(sc) or 'system'
    asyncio.run(_push_status_to_jingrow(payload["job_id"], {
        "status": "started",
        "started_at": started_iso,
        "session_cookie": sc,
        "modified_by": modifier
    }))
    result = asyncio.run(_execute_agent_job(payload))
    ended = time.time()
    ended_iso = datetime.now().isoformat()
    asyncio.run(_push_status_to_jingrow(payload["job_id"], {
        "status": "finished" if result.get("success") else "failed",
        "ended_at": ended_iso,
        "time_taken": round(max(0.0, ended - started), 3),
        "session_cookie": sc,
        "modified_by": modifier,
    }))