178 lines
6.3 KiB
Python
178 lines
6.3 KiB
Python
# Copyright (c) 2025, JINGROW and contributors
|
||
# For license information, please see license.txt
|
||
|
||
"""
|
||
Jingrow 白名单路由服务
|
||
"""
|
||
|
||
from fastapi import APIRouter, Request, HTTPException
|
||
from fastapi.responses import JSONResponse
|
||
import importlib
|
||
import sys
|
||
from pathlib import Path
|
||
from typing import Any, Dict
|
||
import logging
|
||
from jingrow import get_whitelisted_function, is_whitelisted
|
||
from jingrow.utils.auth import get_jingrow_api_headers, get_session_api_headers
|
||
from jingrow.utils.jingrow_api import get_logged_user
|
||
|
||
# Logger shared by every function in this module.
logger = logging.getLogger(__name__)

# FastAPI router on which the whitelist endpoints in this file are registered.
router = APIRouter()

# One-shot guard for _ensure_apps_on_sys_path(): once True, the apps/ scan that
# puts each app root on sys.path is not repeated for the life of the process.
_apps_path_initialized: bool = False
|
||
|
||
def _ensure_apps_on_sys_path():
    """Prepend each listed app's root directory to ``sys.path`` (runs once).

    Reads ``<project_root>/apps/apps.txt``, one app name per line with blank
    lines ignored. For every listed app whose directory exists, inserts
    ``<project_root>/apps/<app>`` at the front of ``sys.path`` unless it is
    already present. This lets modules of one app import modules of another.

    The scan is attempted at most once per process: ``_apps_path_initialized``
    is set in ``finally`` even when the scan fails. Failures are logged rather
    than raised, so this best-effort setup never breaks request handling.
    """
    global _apps_path_initialized
    if _apps_path_initialized:
        return

    try:
        # NOTE(review): assumes this file sits exactly four directory levels
        # below the project root; confirm if the package layout changes.
        project_root = Path(__file__).resolve().parents[4]
        apps_dir = project_root / "apps"
        apps_txt = apps_dir / "apps.txt"
        if not apps_txt.exists():
            return  # the finally clause still marks the scan as done

        # Explicit encoding: the platform default is locale-dependent.
        for line in apps_txt.read_text(encoding="utf-8").splitlines():
            app_name = line.strip()
            if not app_name:
                continue
            app_root_dir = apps_dir / app_name
            if app_root_dir.exists() and str(app_root_dir) not in sys.path:
                sys.path.insert(0, str(app_root_dir))
    except Exception:
        # Best-effort: do not fail the caller, but make the problem visible
        # instead of swallowing it.
        logger.warning("Failed to add app directories to sys.path", exc_info=True)
    finally:
        _apps_path_initialized = True
|
||
|
||
async def authenticate_request(request: Request, allow_guest: bool) -> bool:
    """Decide whether *request* may call a non-guest whitelisted function.

    Two mechanisms are tried in order:

    1. Session cookie: the ``sid`` cookie is passed to ``get_logged_user``.
       Any truthy result counts as authenticated.
    2. API key: the ``Authorization`` header must have the form
       ``token <key>:<secret>``. It must also exactly match the
       ``Authorization`` value returned by ``get_jingrow_api_headers()``.

    Args:
        request: The incoming request.
        allow_guest: If true, no check is done and ``True`` is returned.

    Returns:
        ``True`` if the request is authenticated (or guests are allowed),
        otherwise ``False``. Exceptions raised inside either mechanism are
        logged and treated as a failed attempt.
    """
    import hmac  # local import: only needed for the constant-time comparison

    if allow_guest:
        return True

    # Method 1: session cookie authentication
    session_cookie = request.cookies.get('sid')
    if session_cookie:
        try:
            user = get_logged_user(session_cookie)
            if user:
                logger.info("Session认证成功: %s", user)
                return True
        except Exception as e:
            logger.warning("Session认证失败: %s", e)

    # Method 2: API key authentication, header format "token key:secret"
    auth_header = request.headers.get('Authorization')
    if auth_header and auth_header.startswith('token ') and ':' in auth_header[6:]:
        try:
            expected_headers = get_jingrow_api_headers()
            expected = expected_headers.get('Authorization') if expected_headers else None
            # Constant-time comparison, so response timing does not reveal how
            # much of the secret matched. Compare bytes because
            # hmac.compare_digest rejects non-ASCII str arguments.
            if isinstance(expected, str) and hmac.compare_digest(
                expected.encode('utf-8'), auth_header.encode('utf-8')
            ):
                logger.info("API Key认证成功")
                return True
        except Exception as e:
            logger.warning("API Key认证失败: %s", e)

    logger.warning("认证失败: 未提供有效的认证信息")
    return False
|
||
|
||
async def _get_request_data(req: Request) -> Dict[str, Any]:
    """Return the keyword arguments for a whitelisted call.

    GET requests use the query-string parameters. Every other method uses the
    JSON body, and a missing or unparsable body yields ``{}``.

    Raises:
        HTTPException: 400 if the body is valid JSON but not an object.
            Only an object can be expanded into keyword arguments.
    """
    if req.method == 'GET':
        return dict(req.query_params)
    try:
        data = await req.json()
    except Exception:
        return {}
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Request body must be a JSON object")
    return data


async def _process_whitelist_call(request: Request, full_module_path: str):
    """Resolve a dotted ``'<package.module.function>'`` path and call it.

    Processing order:

    1. Make the app roots importable.
    2. Import the module and look up the function.
    3. Require the function to be registered in the whitelist.
    4. Enforce the HTTP methods and authentication from its whitelist entry.
    5. Call it with the request arguments as keyword arguments.

    ``/`` separators are treated like ``.``, so ``app/module.function``
    resolves to ``app.module.function``. If the function returns an
    awaitable, it is awaited.

    Returns:
        ``JSONResponse`` with body ``{"success": True, "data": <result>}``.

    Raises:
        HTTPException with one of these status codes:

        - 404: malformed path, unknown module, or missing function.
        - 403: the function is not whitelisted.
        - 405: the HTTP method is not allowed.
        - 401: authentication failed.
        - 400: the JSON body is not an object.
        - 500: any other error; the message is passed through.
    """
    import inspect  # local import: used to support async whitelisted functions

    try:
        # Make sure app roots are on sys.path (supports cross-app imports)
        _ensure_apps_on_sys_path()

        # Parse "<module.path>.<function>"; reject empty segments such as
        # "a.", ".b" or "a..b" before they reach importlib.
        normalized = full_module_path.strip('/').replace('/', '.')
        try:
            parsed = parse_module_path(normalized)
        except ValueError:
            parsed = None
        if parsed is None or not all(normalized.split('.')):
            raise HTTPException(status_code=404, detail=f"Invalid module path: {full_module_path}")
        modulename = parsed['module_path']
        methodname = parsed['function_name']

        # NOTE(review): the module must be imported before the whitelist can be
        # consulted, because registration appears to happen at import time via
        # the decorator. Any importable module's top-level code can therefore
        # run for an arbitrary request path.
        module = import_module(modulename)
        func = getattr(module, methodname, None)
        if func is None:
            raise HTTPException(
                status_code=404,
                detail=f"Function {methodname} not found in module {modulename}",
            )

        # Whitelist lookup; the decorator registers entries under func.__module__
        actual_whitelist_path = f"{module.__name__}.{methodname}"
        whitelist_info = get_whitelisted_function(actual_whitelist_path)
        if not whitelist_info:
            # Without this check any importable callable (e.g. os.system) could
            # be invoked with caller-supplied keyword arguments.
            logger.warning("Blocked call to non-whitelisted function: %s", actual_whitelist_path)
            raise HTTPException(
                status_code=403,
                detail=f"Function {actual_whitelist_path} is not whitelisted",
            )

        # Check the HTTP method
        if request.method not in whitelist_info['methods']:
            raise HTTPException(status_code=405, detail=f"Method {request.method} not allowed")

        # Check permissions (authenticate_request short-circuits for guest-allowed functions)
        if not await authenticate_request(request, whitelist_info['allow_guest']):
            raise HTTPException(status_code=401, detail="Authentication required")

        # Call the function
        request_data = await _get_request_data(request)
        result = func(**request_data)
        if inspect.isawaitable(result):
            result = await result

        return JSONResponse(content={"success": True, "data": result})

    except HTTPException:
        raise
    except Exception as e:
        logger.error("Request handler error: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
# Registration order matters: Starlette uses the first route that matches.
# "/{module_path:path}" matches every path, including ones with "/", so the
# app-prefixed route must be registered before it or it is never reached.
@router.api_route("/{app}/{module_path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def handle_request_with_app_prefix(request: Request, app: str, module_path: str):
    """Handle the app-prefixed form, e.g. ``/app/module.function``.

    The first URL segment is the app, and the rest is joined onto it with
    dots. ``/app/module.function`` therefore becomes ``app.module.function``.
    Any further ``/`` separators are also treated as dots.
    """
    full_module_path = f"{app}.{module_path.replace('/', '.')}"
    return await _process_whitelist_call(request, full_module_path)


@router.api_route("/{module_path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def handle_request(request: Request, module_path: str):
    """Handle the legacy form: the full dotted path in one URL segment.

    Example: ``/package.module.function``.
    """
    return await _process_whitelist_call(request, module_path)
|
||
|
||
def parse_module_path(module_path: str) -> Dict[str, str]:
    """Split ``'<package.module>.<function>'`` into module path and function name.

    Args:
        module_path: Dotted path whose last segment is the function name.

    Returns:
        ``{'module_path': <everything before the last dot>,
        'function_name': <last segment>}``.

    Raises:
        ValueError: if there are fewer than two dot-separated segments, or any
            segment is empty (e.g. ``"a."``, ``".b"``, ``"a..b"``). An empty
            segment can never name a module or function.
    """
    parts = module_path.split('.')
    if len(parts) < 2 or not all(parts):
        raise ValueError("Invalid module path format")

    return {
        'module_path': '.'.join(parts[:-1]),
        'function_name': parts[-1],
    }
|
||
|
||
def import_module(module_path: str):
    """Import and return the module named by the dotted *module_path*.

    Raises:
        HTTPException: 404 in either of these cases:

            - The name is empty or relative (leading ``.``). importlib would
              raise ValueError / TypeError for these rather than ImportError.
            - The import raises ``ImportError``. This includes a module that
              exists but fails to import one of its own dependencies; the
              underlying error is logged and chained.
    """
    if not module_path or module_path.startswith('.'):
        raise HTTPException(status_code=404, detail=f"Module {module_path} not found")
    try:
        return importlib.import_module(module_path)
    except ImportError as e:
        logger.error("Failed to import module %s: %s", module_path, e)
        raise HTTPException(status_code=404, detail=f"Module {module_path} not found") from e
|
||
|