236 lines
9.7 KiB
Python
236 lines
9.7 KiB
Python
import asyncio
import json
import logging
import os
import time
import uuid
from dataclasses import dataclass, field
from functools import wraps
from typing import Any, Callable, Dict, Optional, Tuple

import aiohttp
from fastapi import HTTPException
from fastapi.responses import StreamingResponse

from settings import settings
|
|
|
|
# Module-scoped logger, named after this module via __name__.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
@dataclass
|
|
class _BillingTask:
|
|
api_key: str
|
|
api_secret: str
|
|
api_name: str
|
|
usage_count: int
|
|
max_retries: int = 3
|
|
created_at: float = field(default_factory=asyncio.get_event_loop().time)
|
|
task_id: str = field(default_factory=lambda: hex(id(object()))[2:])
|
|
|
|
|
|
class BillingTaskManager:
    """Asynchronous usage-fee deduction queue with exponential-backoff retries.

    ``enqueue()`` only puts a task on an in-memory ``asyncio.Queue``; a single
    background worker started by ``start()`` sends each task, one at a time,
    to the platform's ``deduct_api_usage_fee`` endpoint. Because the queue lives
    only in memory, tasks still queued when the process exits (or when
    ``shutdown()`` times out) are not deducted.
    """

    def __init__(
        self,
        platform_url: str,
        platform_key: str,
        platform_secret: str,
        request_timeout: Optional[float] = 30.0,
    ):
        """Create the manager; the worker is not started until ``start()``.

        Args:
            platform_url: Platform base URL; a trailing "/" is stripped.
            platform_key: Platform API key sent in the Authorization header.
            platform_secret: Platform API secret sent in the Authorization header.
            request_timeout: Total timeout in seconds for one deduction HTTP
                request (``None`` disables it). Bounding it keeps one hung
                request from stalling the single worker for aiohttp's
                default total timeout of 5 minutes.
        """
        self._platform_url = platform_url.rstrip("/")
        self._platform_key = platform_key
        self._platform_secret = platform_secret
        self._request_timeout = request_timeout
        self._queue: asyncio.Queue = asyncio.Queue()
        self._worker_task: Optional[asyncio.Task] = None
        self._running = False

    @property
    def pending_count(self) -> int:
        """Number of tasks waiting in the queue (the one in flight is not counted)."""
        return self._queue.qsize()

    async def start(self):
        """Start the background worker task; does nothing if already running."""
        if self._running:
            return
        self._running = True
        self._worker_task = asyncio.create_task(self._worker_loop(), name="billing-worker")
        logger.info("BillingTaskManager worker 已启动")

    async def shutdown(self, timeout: float = 10.0):
        """Stop the worker after waiting up to ``timeout`` seconds for the queue to drain.

        Tasks still queued when the timeout expires are left undeducted; their
        number is logged.
        """
        if not self._running or self._worker_task is None:
            return
        self._running = False
        drained = False
        try:
            await asyncio.wait_for(self._queue.join(), timeout=timeout)
            drained = True
        except asyncio.TimeoutError:
            logger.warning("BillingTaskManager 关闭超时,剩余 %d 个任务", self._queue.qsize())
        finally:
            # The worker waits on queue.get() indefinitely, so it is always
            # stopped by cancellation; on timeout this also interrupts a
            # deduction that is still in flight.
            self._worker_task.cancel()
            await asyncio.gather(self._worker_task, return_exceptions=True)
            self._worker_task = None
        if drained:
            logger.info("BillingTaskManager 已优雅关闭")

    async def enqueue(self, api_key: str, api_secret: str, api_name: str, usage_count: int):
        """Queue a deduction of ``usage_count`` units of ``api_name`` without waiting for it."""
        if self._worker_task is None or self._worker_task.done():
            # Nothing consumes the queue until start() runs; say so instead of
            # letting deductions silently pile up in memory.
            logger.warning(
                "BillingTaskManager worker 未运行, 扣费任务暂存队列: api=%s count=%s",
                api_name, usage_count,
            )
        self._queue.put_nowait(_BillingTask(
            api_key=api_key,
            api_secret=api_secret,
            api_name=api_name,
            usage_count=usage_count,
        ))

    async def _worker_loop(self):
        """Process queued tasks one at a time until cancelled by ``shutdown()``."""
        while True:
            # Block on get() directly. shutdown() always cancels this task, so
            # the former 1-second wait_for() poll only caused needless wake-ups,
            # and this avoids depending on wait_for()'s cancellation semantics,
            # which differ across Python versions (reimplemented in 3.12).
            task = await self._queue.get()
            try:
                await self._execute_with_retry(task)
            except Exception as e:
                logger.error(
                    "扣费最终失败: [%s] api=%s count=%s error=%s",
                    task.task_id, task.api_name, task.usage_count, e,
                )
            finally:
                # Always mark the item done so queue.join() in shutdown() can finish.
                self._queue.task_done()

    async def _execute_with_retry(self, task: "_BillingTask"):
        """Attempt the deduction up to ``task.max_retries`` times.

        Waits 1s, 2s, 4s, ... between attempts. An attempt fails when
        ``_do_deduct`` raises or returns a result without a truthy ``success``.

        NOTE(review): retrying after a transport error can double-charge if the
        platform already applied the first request; consider sending
        ``task.task_id`` as an idempotency key if the platform supports one.

        Raises:
            RuntimeError: when every attempt failed; the message includes the
                last error.
        """
        last_error = None
        for attempt in range(1, task.max_retries + 1):
            try:
                result = await self._do_deduct(task)
                if result and result.get("success"):
                    logger.info(
                        "扣费成功: [%s] api=%s count=%d attempt=%d/%d",
                        task.task_id, task.api_name, task.usage_count, attempt, task.max_retries,
                    )
                    return
                last_error = result.get("message", "扣费接口返回失败") if result else "扣费接口无响应"
            except Exception as e:
                # Timeouts (e.g. asyncio.TimeoutError) have an empty str(); fall
                # back to the exception type so the final error stays informative.
                last_error = str(e) or type(e).__name__
            if attempt < task.max_retries:
                delay = 2 ** (attempt - 1)
                logger.warning(
                    "扣费失败, %ds 后重试: [%s] api=%s attempt=%d/%d error=%s",
                    delay, task.task_id, task.api_name, attempt, task.max_retries, last_error,
                )
                await asyncio.sleep(delay)
        raise RuntimeError(f"扣费最终失败: {last_error}")

    async def _do_deduct(self, task: "_BillingTask") -> dict:
        """POST one deduction to the platform and return its JSON payload.

        Raises:
            RuntimeError: on a non-200 status or a JSON body that is not an object.
            Transport errors raised by aiohttp (including timeouts) propagate.
        """
        timeout = aiohttp.ClientTimeout(total=self._request_timeout)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(
                f"{self._platform_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
                headers={"Authorization": f"token {self._platform_key}:{self._platform_secret}"},
                json={
                    "api_key": task.api_key,
                    "api_secret": task.api_secret,
                    "api_name": task.api_name,
                    "usage_count": task.usage_count,
                },
            ) as resp:
                if resp.status != 200:
                    raise RuntimeError(f"扣费服务返回 HTTP {resp.status}")
                result = await resp.json()
        return self._unwrap_message(result)

    @staticmethod
    def _unwrap_message(result: Any) -> dict:
        """Return the payload dict, unwrapping a ``{"message": {...}}`` envelope.

        Raises:
            RuntimeError: if ``result`` is not a dict.
        """
        if not isinstance(result, dict):
            raise RuntimeError(f"扣费服务返回了非对象 JSON: {type(result).__name__}")
        message = result.get("message")
        return message if isinstance(message, dict) else result
|
|
|
|
async def verify_api_credentials_and_balance(
    api_key: str,
    api_secret: str,
    api_name: str,
    *,
    timeout: float = 10.0,
) -> Dict[str, Any]:
    """Verify a caller's API credentials and team balance with the platform.

    Args:
        api_key: The caller's API key.
        api_secret: The caller's API secret.
        api_name: Name of the API being called.
        timeout: Total timeout in seconds for the platform request
            (keyword-only). Without it aiohttp's default total timeout of
            5 minutes would apply while the client request waits.

    Returns:
        The platform's result dict (unwrapped from a ``{"message": {...}}``
        envelope when present); its ``success`` field is truthy.

    Raises:
        HTTPException: 401 when the platform reports ``success`` as falsy;
            500 when the platform is unreachable, answers with a non-200
            status, or returns JSON that is not an object.
    """
    try:
        client_timeout = aiohttp.ClientTimeout(total=timeout)
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            async with session.post(
                f"{settings.jingrow_api_url}/api/action/jcloud.api.account.verify_api_credentials_and_balance",
                headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
                json={"api_key": api_key, "api_secret": api_secret, "api_name": api_name},
            ) as response:
                if response.status != 200:
                    logger.warning("验证服务返回 HTTP %d: api=%s", response.status, api_name)
                    raise HTTPException(status_code=500, detail="验证服务暂时不可用")
                result = await response.json()

        if not isinstance(result, dict):
            raise HTTPException(status_code=500, detail="验证服务暂时不可用")
        # The platform may wrap the payload as {"message": {...}}.
        if isinstance(result.get("message"), dict):
            result = result["message"]

        if not result.get("success"):
            raise HTTPException(status_code=401, detail=result.get("message", "验证失败"))

        return result

    except HTTPException:
        raise
    except Exception as e:
        # Log the real cause server-side only: the exception text can contain
        # internal details (such as the platform URL/host) that must not be
        # echoed back to API clients.
        logger.exception("调用验证服务失败: api=%s", api_name)
        raise HTTPException(status_code=500, detail="验证服务暂时不可用") from e
|
|
|
|
# Process-wide billing queue talking to the platform configured in `settings`.
# Queued deductions are only processed after `await billing_manager.start()`.
billing_manager = BillingTaskManager(
    settings.jingrow_api_url,
    settings.jingrow_api_key,
    settings.jingrow_api_secret,
)
|
|
|
|
def get_token_from_request(request) -> str:
    """Extract the ``key:secret`` credential from a ``token key:secret`` Authorization header.

    Args:
        request: The incoming request; only its ``headers`` mapping is read.

    Returns:
        The credential text after the ``token `` prefix, unchanged (it may
        contain further ":" characters inside the secret).

    Raises:
        HTTPException: 400 if ``request`` is missing; 401 if the header is
            absent, lacks the ``token `` prefix, or does not contain a
            non-empty key and a non-empty secret separated by ":".
    """
    if not request:
        raise HTTPException(status_code=400, detail="无法获取请求信息")

    prefix = "token "
    auth_header = request.headers.get("Authorization", "")
    # An empty header also fails this check, so no separate emptiness test is needed.
    if not auth_header.startswith(prefix):
        raise HTTPException(status_code=401, detail="无效的Authorization头格式")

    token = auth_header[len(prefix):]
    api_key, separator, api_secret = token.partition(":")
    # Reject "token :", "token key:" and "token :secret" here instead of
    # forwarding empty credentials to the verification service.
    if not separator or not api_key or not api_secret:
        raise HTTPException(status_code=401, detail="无效的令牌格式")

    return token
|
|
|
|
def _count_usage_from_payload(body_data: Any) -> int:
    """Return the billable units implied by a parsed JSON request body.

    Uses the length of the first list found under "items", "urls", "images"
    or "files" (checked in that order); any other body counts as 1 unit.
    """
    if isinstance(body_data, dict):
        for key in ("items", "urls", "images", "files"):
            value = body_data.get(key)
            if isinstance(value, list):
                return len(value)
    return 1


async def _usage_count_from_request(request) -> int:
    """Best-effort unit count from the request's JSON body (1 if it cannot be read)."""
    try:
        body_data = await request.json()
    except Exception:
        # Deliberately best-effort: a missing or non-JSON body bills one unit.
        logger.debug("无法解析请求体, 按 1 次计费", exc_info=True)
        return 1
    return _count_usage_from_payload(body_data)


def _wrap_stream_with_billing(response, api_key: str, api_secret: str, api_name: str):
    """Make ``response`` bill one unit per streamed ``{"status": "success"}`` chunk.

    The response object is modified in place (only ``body_iterator`` is
    replaced), so its status code, headers, media type and background task
    are kept. The deduction is queued when the stream ends *or* is closed
    early (e.g. client disconnect), so items already produced are still billed.
    """
    original_iterator = response.body_iterator

    async def billed_iterator():
        success_count = 0
        try:
            async for chunk in original_iterator:
                try:
                    data = json.loads(chunk)
                except (ValueError, TypeError):
                    # Not a standalone JSON document (or not str/bytes): not counted.
                    data = None
                if isinstance(data, dict) and data.get("status") == "success":
                    success_count += 1
                yield chunk
        finally:
            if success_count > 0:
                await billing_manager.enqueue(api_key, api_secret, api_name, success_count)

    response.body_iterator = billed_iterator()
    return response


def jingrow_api_verify_and_billing(api_name: str):
    """Decorate an async endpoint with credential/balance verification and usage billing.

    The wrapped endpoint must receive the incoming request as the ``request``
    keyword argument. Per call:

    1. Read ``key:secret`` from the ``Authorization: token key:secret`` header.
    2. Verify credentials and balance with the platform (401 on rejection).
    3. Await the endpoint.
    4. Queue a deduction on ``billing_manager``:

       * ``StreamingResponse``: one unit per chunk that parses as a JSON
         object with ``"status": "success"``, queued when the stream ends or
         is closed;
       * dict with a ``success`` key whose value is falsy: nothing is billed;
       * dict with ``success is True``: ``successful_count`` units, falling
         back to the body-derived count;
       * anything else: the body-derived count (length of the first list
         under items/urls/images/files, otherwise 1).

       Zero-unit deductions are not queued.

    Args:
        api_name: Name under which usage is verified and billed.

    Raises:
        HTTPException: passed through unchanged; any other exception (including
            ones raised by the endpoint) is logged and re-raised as a 500.
    """
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def wrapper(*args, **kwargs):
            try:
                request = kwargs.get('request')
                if not request:
                    raise HTTPException(status_code=400, detail="无法获取请求信息")

                token = get_token_from_request(request)
                api_key, api_secret = token.split(":", 1)

                verify_result = await verify_api_credentials_and_balance(api_key, api_secret, api_name)
                if not verify_result.get("success"):
                    raise HTTPException(status_code=401, detail=verify_result.get("message", "验证失败"))

                result = await func(*args, **kwargs)

                if isinstance(result, StreamingResponse):
                    return _wrap_stream_with_billing(result, api_key, api_secret, api_name)

                if isinstance(result, dict) and "success" in result and not result["success"]:
                    # The endpoint reported an explicit failure; previously this
                    # still fell through and billed the body-derived count.
                    return result

                usage_count = await _usage_count_from_request(request)
                if isinstance(result, dict) and result.get("success") is True:
                    usage_count = result.get("successful_count", usage_count)

                # Skip empty deductions (e.g. successful_count == 0), matching the
                # streaming path, which only bills when something succeeded.
                if usage_count:
                    await billing_manager.enqueue(api_key, api_secret, api_name, usage_count)

                return result

            except HTTPException:
                raise
            except Exception as e:
                logger.exception("API验证或调用过程发生异常: api=%s", api_name)
                raise HTTPException(status_code=500, detail=f"API验证过程发生异常: {str(e)}") from e

        return wrapper

    return decorator
|