323 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright (c) 2025, JINGROW and contributors
# For license information, please see license.txt
from typing import Dict, List, Any, Optional, Union
import functools
import inspect
import logging
import os
from contextvars import ContextVar
from jingrow.model.page import Page
from jingrow.config import Config
from jingrow.utils.jingrow_api import upload_file_to_jingrow
_local = {}
# 【通用执行上下文标记】用于标记当前调用链的来源,防止循环触发等场景
_update_source_context: ContextVar[Optional[str]] = ContextVar('update_source', default=None)
# 统一 Jingrow 日志记录器(仅为本模块及调用方提供最小可用输出,不修改全局 root logger
_root_logger = logging.getLogger("jingrow")
# 简单高效:用 addLevelName 给级别上色;使用标准 Formatter 调整顺序为“时间在最后”
_ANSI = {
'DEBUG': '\x1b[36m', # cyan
'INFO': '\x1b[32m', # green
'WARNING': '\x1b[33m', # yellow
'ERROR': '\x1b[31m', # red
'CRITICAL': '\x1b[41m' # red background
}
for _lvl in list(_ANSI.keys()):
_code = _ANSI[_lvl]
logging.addLevelName(getattr(logging, _lvl), f"{_code}{_lvl}\x1b[0m")
# ====== High-level helpers to enforce Page lifecycle (hooks) ======
def get_pg(pagetype: str, name: str):
"""获取单条记录并转为可属性访问的对象,失败返回 None。"""
pg = Page(pagetype)
res = pg.get(name)
if not isinstance(res, dict) or not res.get('success'):
return None
data = res.get('data') or {}
return data
def create_pg(pagetype: str, data: Dict[str, Any]):
"""创建记录,返回创建后的数据对象或 None。"""
pg = Page(pagetype)
res = pg.create(data)
if not isinstance(res, dict) or not res.get('success'):
return None
created = res.get('data') or {}
return created
def update_pg(pagetype: str, name: str, data: Dict[str, Any]):
"""更新记录,成功返回更新后的数据对象或 True失败返回 False。"""
pg = Page(pagetype)
res = pg.update(name, data)
if not isinstance(res, dict) or not res.get('success'):
return False
updated = res.get('data')
if updated is None:
return True
return updated
def skip_hooks(source: str = "agent"):
"""
设置上下文标记,表示当前调用链来自指定来源(如智能体执行),
用于跳过某些钩子防止循环触发等场景。
使用 contextvars.ContextVar支持同步和异步上下文传播。
"""
try:
_update_source_context.set(source)
except Exception:
pass
def restore_hooks():
"""
清除上下文标记,恢复正常的钩子执行。
注意contextvars 会自动管理上下文生命周期,通常在 try-finally 中使用。
"""
try:
_update_source_context.set(None)
except Exception:
pass
def get_hook_source() -> Optional[str]:
"""
获取当前调用链的来源标识,用于判断是否跳过某些钩子。
在异步环境中也能正确获取上下文值。
"""
try:
return _update_source_context.get()
except Exception:
return None
def delete_pg(pagetype: str, name: str) -> bool:
pg = Page(pagetype)
res = pg.delete(name)
return bool(isinstance(res, dict) and res.get('success'))
def get_list(pagetype: str, filters: List[List[Any]] = None, fields: List[str] = None, limit: int = None):
"""获取记录列表,返回对象列表;失败返回空列表。"""
pg = Page(pagetype)
res = pg.list(filters=filters, fields=fields, limit=limit)
if not isinstance(res, dict) or not res.get('success'):
return []
items = res.get('data') or []
return items
def get_all(pagetype: str, filters: List[List[Any]] = None, fields: List[str] = None):
return get_list(pagetype, filters=filters, fields=fields, limit=None)
def get_single(pagetype: str):
"""获取 single 类型 pagetype 配置,返回 {success, config|error} 结构。"""
return Page(pagetype).get_single()
def get_module_app(pagetype: str):
"""获取指定 pagetype 的模块应用信息,返回后端适配器的原始结果结构。"""
return Page(pagetype).get_module_app()
def get_pg_id(
pagetype: str,
filters: List[List[Any]] = None,
field: Optional[str] = None,
value: Optional[str] = None,
site: Optional[str] = None,
):
"""根据过滤条件或字段值获取文档 ID返回后端适配器的原始结果结构。"""
return Page(pagetype).get_pg_id(filters=filters, field=field, value=value, site=site)
def get_meta(pagetype: str):
"""获取 pagetype 的元数据,返回后端适配器的原始结果结构。"""
return Page(pagetype).get_meta()
def get_field_mapping_from_jingrow(pagetype: str):
"""获取字段的label到fieldname的映射返回 {label: fieldname} 的映射字典。"""
return Page(pagetype).get_field_mapping_from_jingrow()
def get_field_value_from_jingrow(pagetype: str, name: str, fieldname: str):
"""从Jingrow获取字段的当前值返回字段的当前值如果为空则返回None。"""
return Page(pagetype).get_field_value_from_jingrow(name, fieldname)
def upload_file(file_data: bytes, filename: str, attached_to_pagetype: Optional[str] = None, attached_to_name: Optional[str] = None, attached_to_field: Optional[str] = None):
"""
上传文件到Jingrow服务器
Args:
file_data (bytes): 文件二进制数据
filename (str): 文件名
attached_to_pagetype (str, optional): 关联的页面类型
attached_to_name (str, optional): 关联的记录名称
attached_to_field (str, optional): 关联的字段名称
Returns:
dict: 上传结果
{
'success': bool,
'file_url': str, # 成功时返回文件URL
'file_name': str, # 成功时返回文件名
'error': str # 失败时返回错误信息
}
"""
return upload_file_to_jingrow(file_data, filename, attached_to_pagetype, attached_to_name, attached_to_field)
def map_fields_by_labels(field_map: list, ai_outputs: dict, label_to_fieldname: dict) -> dict:
"""
根据 field_map 将 ai_outputs 映射为记录字段数据。
- field_map: [{ 'from': '来源字段', 'to': '目标label或字段名' }]
- ai_outputs: 上游节点聚合的输出
- label_to_fieldname: 来自 Jingrow 的 label 到字段名映射
返回: dict 记录字段数据
"""
record_data = {}
fieldname_set = set(label_to_fieldname.values())
for mapping in field_map or []:
source_key = mapping.get('from')
to_raw = mapping.get('to')
if not source_key or to_raw is None:
continue
value = ai_outputs.get(source_key)
if value is None:
continue
# 目标字段既支持label也支持字段名
to_field = label_to_fieldname.get(str(to_raw).strip())
if not to_field and str(to_raw).strip() in fieldname_set:
to_field = str(to_raw).strip()
if not to_field:
continue
record_data[to_field] = value
return record_data
def _ensure_logging_configured() -> None:
# 统一在 root logger 上配置输出与格式,这样通过名为 "jingrow" 的 logger 打印时,
# 也会以 "jingrow - ERROR - ... - 时间" 的格式输出,并带颜色的级别名。
root_logger = logging.getLogger()
if not root_logger.handlers:
handler = logging.StreamHandler()
# 统一样式:不含 logger 名称,只输出级别、消息、时间
formatter = logging.Formatter("%(levelname)s - %(message)s - %(asctime)s")
handler.setFormatter(formatter)
root_logger.addHandler(handler)
try:
level_name = str(getattr(Config, 'log_level', 'INFO')).upper()
except Exception:
level_name = 'INFO'
try:
level = getattr(logging, level_name, logging.INFO)
except Exception:
level = logging.INFO
root_logger.setLevel(level)
# 如果已存在其他 handler但需要强制向控制台输出例如运行在 API 模式),
# 则在设置了环境变量 JINGROW_STREAM_LOGGING 的情况下,补充一个 StreamHandler。
try:
force_stream = os.environ.get("JINGROW_STREAM_LOGGING")
except Exception:
force_stream = None
if force_stream and not any(isinstance(h, logging.StreamHandler) for h in root_logger.handlers):
_sh = logging.StreamHandler()
_sh.setFormatter(logging.Formatter("%(levelname)s - %(message)s - %(asctime)s"))
root_logger.addHandler(_sh)
# 使用 root 的处理器,避免 jingrow 自带 handler 导致重复输出
if _root_logger.handlers:
_root_logger.handlers.clear()
_root_logger.propagate = True
# 与 root 保持同级别,避免级别不一致导致丢日志
_root_logger.setLevel(root_logger.level)
def log_error(title: Optional[str] = None, message: Optional[str] = None, *, exc: Optional[BaseException] = None) -> None:
"""输出错误日志到终端。
调用约定:
- log_error(content) -> 仅内容
- log_error(title, content) -> 标题 + 内容
- 可选 exc=Exception(...) 传入异常以带上堆栈
"""
_ensure_logging_configured()
# 兼容仅传内容或传标题+内容两种形式
if message is None:
content = "" if title is None else str(title)
else:
content = str(message) if title is None else f"{title} - {message}"
_root_logger.error(content, exc_info=exc)
def _dict():
"""创建一个空字典"""
return {}
# =============== Whitelist 实现 ===============
_logger = logging.getLogger(__name__)
_whitelisted_functions: Dict[str, Dict[str, Any]] = {}
def whitelist(allow_guest: bool = False, methods: List[str] = None):
if methods is None:
methods = ['GET', 'POST']
def decorator(func):
module_name = func.__module__
function_name = func.__name__
api_path = f"{module_name}.{function_name}"
sig = inspect.signature(func)
parameters: Dict[str, Dict[str, Any]] = {}
for param_name, param in sig.parameters.items():
parameters[param_name] = {
'name': param_name,
'annotation': None if param.annotation is inspect._empty else param.annotation,
'default': None if param.default is inspect._empty else param.default,
'kind': param.kind,
}
_whitelisted_functions[api_path] = {
'function': func,
'module_path': module_name,
'function_name': function_name,
'allow_guest': allow_guest,
'methods': methods,
'parameters': parameters,
'docstring': func.__doc__ or "",
}
_logger.info(
f"Whitelisted function registered: {api_path} (methods={methods}, allow_guest={allow_guest})"
)
@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return wrapper
return decorator
def get_whitelisted_functions() -> Dict[str, Dict[str, Any]]:
return _whitelisted_functions.copy()
def is_whitelisted(api_path: str) -> bool:
return api_path in _whitelisted_functions
def get_whitelisted_function(api_path: str):
return _whitelisted_functions.get(api_path)