323 lines
11 KiB
Python
323 lines
11 KiB
Python
# Copyright (c) 2025, JINGROW and contributors
|
||
# For license information, please see license.txt
|
||
|
||
from typing import Dict, List, Any, Optional, Union
|
||
import functools
|
||
import inspect
|
||
import logging
|
||
import os
|
||
from contextvars import ContextVar
|
||
from jingrow.model.page import Page
|
||
from jingrow.config import Config
|
||
from jingrow.utils.jingrow_api import upload_file_to_jingrow
|
||
|
||
_local = {}
|
||
|
||
# 【通用执行上下文标记】用于标记当前调用链的来源,防止循环触发等场景
|
||
_update_source_context: ContextVar[Optional[str]] = ContextVar('update_source', default=None)
|
||
|
||
# 统一 Jingrow 日志记录器(仅为本模块及调用方提供最小可用输出,不修改全局 root logger)
|
||
_root_logger = logging.getLogger("jingrow")
|
||
|
||
# 简单高效:用 addLevelName 给级别上色;使用标准 Formatter 调整顺序为“时间在最后”
|
||
_ANSI = {
|
||
'DEBUG': '\x1b[36m', # cyan
|
||
'INFO': '\x1b[32m', # green
|
||
'WARNING': '\x1b[33m', # yellow
|
||
'ERROR': '\x1b[31m', # red
|
||
'CRITICAL': '\x1b[41m' # red background
|
||
}
|
||
for _lvl in list(_ANSI.keys()):
|
||
_code = _ANSI[_lvl]
|
||
logging.addLevelName(getattr(logging, _lvl), f"{_code}{_lvl}\x1b[0m")
|
||
|
||
|
||
# ====== High-level helpers to enforce Page lifecycle (hooks) ======
|
||
|
||
def get_pg(pagetype: str, name: str):
|
||
"""获取单条记录并转为可属性访问的对象,失败返回 None。"""
|
||
pg = Page(pagetype)
|
||
res = pg.get(name)
|
||
if not isinstance(res, dict) or not res.get('success'):
|
||
return None
|
||
data = res.get('data') or {}
|
||
return data
|
||
|
||
|
||
def create_pg(pagetype: str, data: Dict[str, Any]):
|
||
"""创建记录,返回创建后的数据对象或 None。"""
|
||
pg = Page(pagetype)
|
||
res = pg.create(data)
|
||
if not isinstance(res, dict) or not res.get('success'):
|
||
return None
|
||
created = res.get('data') or {}
|
||
return created
|
||
|
||
|
||
def update_pg(pagetype: str, name: str, data: Dict[str, Any]):
|
||
"""更新记录,成功返回更新后的数据对象或 True,失败返回 False。"""
|
||
pg = Page(pagetype)
|
||
res = pg.update(name, data)
|
||
if not isinstance(res, dict) or not res.get('success'):
|
||
return False
|
||
updated = res.get('data')
|
||
if updated is None:
|
||
return True
|
||
return updated
|
||
|
||
|
||
def skip_hooks(source: str = "agent"):
|
||
"""
|
||
设置上下文标记,表示当前调用链来自指定来源(如智能体执行),
|
||
用于跳过某些钩子防止循环触发等场景。
|
||
使用 contextvars.ContextVar,支持同步和异步上下文传播。
|
||
"""
|
||
try:
|
||
_update_source_context.set(source)
|
||
except Exception:
|
||
pass
|
||
|
||
|
||
def restore_hooks():
|
||
"""
|
||
清除上下文标记,恢复正常的钩子执行。
|
||
注意:contextvars 会自动管理上下文生命周期,通常在 try-finally 中使用。
|
||
"""
|
||
try:
|
||
_update_source_context.set(None)
|
||
except Exception:
|
||
pass
|
||
|
||
|
||
def get_hook_source() -> Optional[str]:
|
||
"""
|
||
获取当前调用链的来源标识,用于判断是否跳过某些钩子。
|
||
在异步环境中也能正确获取上下文值。
|
||
"""
|
||
try:
|
||
return _update_source_context.get()
|
||
except Exception:
|
||
return None
|
||
|
||
|
||
def delete_pg(pagetype: str, name: str) -> bool:
|
||
pg = Page(pagetype)
|
||
res = pg.delete(name)
|
||
return bool(isinstance(res, dict) and res.get('success'))
|
||
|
||
|
||
def get_list(pagetype: str, filters: List[List[Any]] = None, fields: List[str] = None, limit: int = None):
|
||
"""获取记录列表,返回对象列表;失败返回空列表。"""
|
||
pg = Page(pagetype)
|
||
res = pg.list(filters=filters, fields=fields, limit=limit)
|
||
if not isinstance(res, dict) or not res.get('success'):
|
||
return []
|
||
items = res.get('data') or []
|
||
return items
|
||
|
||
|
||
def get_all(pagetype: str, filters: List[List[Any]] = None, fields: List[str] = None):
|
||
return get_list(pagetype, filters=filters, fields=fields, limit=None)
|
||
|
||
|
||
def get_single(pagetype: str):
|
||
"""获取 single 类型 pagetype 配置,返回 {success, config|error} 结构。"""
|
||
return Page(pagetype).get_single()
|
||
|
||
|
||
def get_module_app(pagetype: str):
|
||
"""获取指定 pagetype 的模块应用信息,返回后端适配器的原始结果结构。"""
|
||
return Page(pagetype).get_module_app()
|
||
|
||
|
||
def get_pg_id(
|
||
pagetype: str,
|
||
filters: List[List[Any]] = None,
|
||
field: Optional[str] = None,
|
||
value: Optional[str] = None,
|
||
site: Optional[str] = None,
|
||
):
|
||
"""根据过滤条件或字段值获取文档 ID,返回后端适配器的原始结果结构。"""
|
||
return Page(pagetype).get_pg_id(filters=filters, field=field, value=value, site=site)
|
||
|
||
|
||
def get_meta(pagetype: str):
|
||
"""获取 pagetype 的元数据,返回后端适配器的原始结果结构。"""
|
||
return Page(pagetype).get_meta()
|
||
|
||
|
||
def get_field_mapping_from_jingrow(pagetype: str):
|
||
"""获取字段的label到fieldname的映射,返回 {label: fieldname} 的映射字典。"""
|
||
return Page(pagetype).get_field_mapping_from_jingrow()
|
||
|
||
|
||
def get_field_value_from_jingrow(pagetype: str, name: str, fieldname: str):
|
||
"""从Jingrow获取字段的当前值,返回字段的当前值,如果为空则返回None。"""
|
||
return Page(pagetype).get_field_value_from_jingrow(name, fieldname)
|
||
|
||
|
||
def upload_file(file_data: bytes, filename: str, attached_to_pagetype: Optional[str] = None, attached_to_name: Optional[str] = None, attached_to_field: Optional[str] = None):
|
||
"""
|
||
上传文件到Jingrow服务器
|
||
|
||
Args:
|
||
file_data (bytes): 文件二进制数据
|
||
filename (str): 文件名
|
||
attached_to_pagetype (str, optional): 关联的页面类型
|
||
attached_to_name (str, optional): 关联的记录名称
|
||
attached_to_field (str, optional): 关联的字段名称
|
||
|
||
Returns:
|
||
dict: 上传结果
|
||
{
|
||
'success': bool,
|
||
'file_url': str, # 成功时返回文件URL
|
||
'file_name': str, # 成功时返回文件名
|
||
'error': str # 失败时返回错误信息
|
||
}
|
||
"""
|
||
return upload_file_to_jingrow(file_data, filename, attached_to_pagetype, attached_to_name, attached_to_field)
|
||
|
||
|
||
def map_fields_by_labels(field_map: list, ai_outputs: dict, label_to_fieldname: dict) -> dict:
|
||
"""
|
||
根据 field_map 将 ai_outputs 映射为记录字段数据。
|
||
- field_map: [{ 'from': '来源字段', 'to': '目标label或字段名' }]
|
||
- ai_outputs: 上游节点聚合的输出
|
||
- label_to_fieldname: 来自 Jingrow 的 label 到字段名映射
|
||
返回: dict 记录字段数据
|
||
"""
|
||
record_data = {}
|
||
fieldname_set = set(label_to_fieldname.values())
|
||
for mapping in field_map or []:
|
||
source_key = mapping.get('from')
|
||
to_raw = mapping.get('to')
|
||
if not source_key or to_raw is None:
|
||
continue
|
||
value = ai_outputs.get(source_key)
|
||
if value is None:
|
||
continue
|
||
# 目标字段既支持label也支持字段名
|
||
to_field = label_to_fieldname.get(str(to_raw).strip())
|
||
if not to_field and str(to_raw).strip() in fieldname_set:
|
||
to_field = str(to_raw).strip()
|
||
if not to_field:
|
||
continue
|
||
record_data[to_field] = value
|
||
return record_data
|
||
|
||
|
||
def _ensure_logging_configured() -> None:
|
||
|
||
# 统一在 root logger 上配置输出与格式,这样通过名为 "jingrow" 的 logger 打印时,
|
||
# 也会以 "jingrow - ERROR - ... - 时间" 的格式输出,并带颜色的级别名。
|
||
root_logger = logging.getLogger()
|
||
if not root_logger.handlers:
|
||
handler = logging.StreamHandler()
|
||
# 统一样式:不含 logger 名称,只输出级别、消息、时间
|
||
formatter = logging.Formatter("%(levelname)s - %(message)s - %(asctime)s")
|
||
handler.setFormatter(formatter)
|
||
root_logger.addHandler(handler)
|
||
try:
|
||
level_name = str(getattr(Config, 'log_level', 'INFO')).upper()
|
||
except Exception:
|
||
level_name = 'INFO'
|
||
try:
|
||
level = getattr(logging, level_name, logging.INFO)
|
||
except Exception:
|
||
level = logging.INFO
|
||
root_logger.setLevel(level)
|
||
|
||
# 如果已存在其他 handler,但需要强制向控制台输出(例如运行在 API 模式),
|
||
# 则在设置了环境变量 JINGROW_STREAM_LOGGING 的情况下,补充一个 StreamHandler。
|
||
try:
|
||
force_stream = os.environ.get("JINGROW_STREAM_LOGGING")
|
||
except Exception:
|
||
force_stream = None
|
||
if force_stream and not any(isinstance(h, logging.StreamHandler) for h in root_logger.handlers):
|
||
_sh = logging.StreamHandler()
|
||
_sh.setFormatter(logging.Formatter("%(levelname)s - %(message)s - %(asctime)s"))
|
||
root_logger.addHandler(_sh)
|
||
|
||
# 使用 root 的处理器,避免 jingrow 自带 handler 导致重复输出
|
||
if _root_logger.handlers:
|
||
_root_logger.handlers.clear()
|
||
_root_logger.propagate = True
|
||
# 与 root 保持同级别,避免级别不一致导致丢日志
|
||
_root_logger.setLevel(root_logger.level)
|
||
|
||
def log_error(title: Optional[str] = None, message: Optional[str] = None, *, exc: Optional[BaseException] = None) -> None:
|
||
"""输出错误日志到终端。
|
||
|
||
调用约定:
|
||
- log_error(content) -> 仅内容
|
||
- log_error(title, content) -> 标题 + 内容
|
||
- 可选 exc=Exception(...) 传入异常以带上堆栈
|
||
"""
|
||
_ensure_logging_configured()
|
||
# 兼容仅传内容或传标题+内容两种形式
|
||
if message is None:
|
||
content = "" if title is None else str(title)
|
||
else:
|
||
content = str(message) if title is None else f"{title} - {message}"
|
||
_root_logger.error(content, exc_info=exc)
|
||
|
||
def _dict():
|
||
"""创建一个空字典"""
|
||
return {}
|
||
|
||
# =============== Whitelist 实现 ===============
|
||
_logger = logging.getLogger(__name__)
|
||
_whitelisted_functions: Dict[str, Dict[str, Any]] = {}
|
||
|
||
def whitelist(allow_guest: bool = False, methods: List[str] = None):
|
||
if methods is None:
|
||
methods = ['GET', 'POST']
|
||
|
||
def decorator(func):
|
||
module_name = func.__module__
|
||
function_name = func.__name__
|
||
api_path = f"{module_name}.{function_name}"
|
||
|
||
sig = inspect.signature(func)
|
||
parameters: Dict[str, Dict[str, Any]] = {}
|
||
for param_name, param in sig.parameters.items():
|
||
parameters[param_name] = {
|
||
'name': param_name,
|
||
'annotation': None if param.annotation is inspect._empty else param.annotation,
|
||
'default': None if param.default is inspect._empty else param.default,
|
||
'kind': param.kind,
|
||
}
|
||
|
||
_whitelisted_functions[api_path] = {
|
||
'function': func,
|
||
'module_path': module_name,
|
||
'function_name': function_name,
|
||
'allow_guest': allow_guest,
|
||
'methods': methods,
|
||
'parameters': parameters,
|
||
'docstring': func.__doc__ or "",
|
||
}
|
||
|
||
_logger.info(
|
||
f"Whitelisted function registered: {api_path} (methods={methods}, allow_guest={allow_guest})"
|
||
)
|
||
|
||
@functools.wraps(func)
|
||
def wrapper(*args, **kwargs):
|
||
return func(*args, **kwargs)
|
||
|
||
return wrapper
|
||
|
||
return decorator
|
||
|
||
def get_whitelisted_functions() -> Dict[str, Dict[str, Any]]:
|
||
return _whitelisted_functions.copy()
|
||
|
||
def is_whitelisted(api_path: str) -> bool:
|
||
return api_path in _whitelisted_functions
|
||
|
||
def get_whitelisted_function(api_path: str):
|
||
return _whitelisted_functions.get(api_path)
|
||
|