import asyncio
import json
import logging
import os
import time
import uuid
from dataclasses import dataclass, field
from functools import wraps
from typing import Any, AsyncIterator, Callable, Dict, Optional, Tuple

import aiohttp
from fastapi import HTTPException
from fastapi.responses import StreamingResponse

from settings import settings

logger = logging.getLogger(__name__)

# Request-body keys whose list length determines how many units a call consumes.
_BATCH_KEYS = ("items", "urls", "images", "files")


def _new_task_id() -> str:
    """Return a random identifier used to correlate a billing task's log lines."""
    return uuid.uuid4().hex[:12]


@dataclass
class _BillingTask:
    """One pending usage-fee deduction processed by BillingTaskManager."""

    api_key: str
    # repr=False keeps the caller's secret out of any log line or traceback
    # that prints the task object.
    api_secret: str = field(repr=False)
    api_name: str
    usage_count: int
    max_retries: int = 3
    # Monotonic creation timestamp. The previous default,
    # asyncio.get_event_loop().time, was evaluated once at class-definition
    # (import) time, outside any running loop -- a use of get_event_loop()
    # that newer Python versions deprecate and eventually reject.
    created_at: float = field(default_factory=time.monotonic)
    # Must be unique per task; hex(id(object())) reused the address of a freed
    # temporary, so consecutive tasks frequently got identical IDs.
    task_id: str = field(default_factory=_new_task_id)


class BillingTaskManager:
    """Asynchronous usage-fee deduction manager with exponential-backoff retries.

    Deductions are queued with :meth:`enqueue` and POSTed to the platform's
    ``deduct_api_usage_fee`` endpoint by a single background worker created
    by :meth:`start`.
    """

    def __init__(self, platform_url: str, platform_key: str, platform_secret: str):
        self._platform_url = platform_url.rstrip("/")
        self._platform_key = platform_key
        self._platform_secret = platform_secret
        self._queue: asyncio.Queue = asyncio.Queue()
        self._worker_task: Optional[asyncio.Task] = None
        self._running = False

    @property
    def pending_count(self) -> int:
        """Number of deductions currently waiting in the queue."""
        return self._queue.qsize()

    async def start(self):
        """Start the background worker; a no-op if it is already running."""
        if self._running:
            return
        self._running = True
        self._worker_task = asyncio.create_task(self._worker_loop(), name="billing-worker")
        logger.info("BillingTaskManager worker 已启动")

    async def shutdown(self, timeout: float = 10.0):
        """Let the worker drain the queue for up to *timeout* seconds, then cancel it.

        Tasks still queued when the timeout expires are not processed; their
        number is logged.
        """
        if not self._running or self._worker_task is None:
            return
        # The worker loop exits once this is False and the queue is empty.
        self._running = False
        try:
            await asyncio.wait_for(self._queue.join(), timeout=timeout)
            logger.info("BillingTaskManager 已优雅关闭")
        except asyncio.TimeoutError:
            logger.warning("BillingTaskManager 关闭超时,剩余 %d 个任务", self._queue.qsize())
        finally:
            self._worker_task.cancel()
            await asyncio.gather(self._worker_task, return_exceptions=True)

    async def enqueue(self, api_key: str, api_secret: str, api_name: str, usage_count: int):
        """Queue a deduction for the worker; never blocks (the queue is unbounded)."""
        self._queue.put_nowait(_BillingTask(
            api_key=api_key,
            api_secret=api_secret,
            api_name=api_name,
            usage_count=usage_count,
        ))

    async def _worker_loop(self):
        """Process queued tasks until stopped and the queue has been drained."""
        while self._running or not self._queue.empty():
            try:
                # Poll with a timeout so a shutdown request is noticed even
                # when no new tasks arrive.
                task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
            except asyncio.TimeoutError:
                continue
            try:
                await self._execute_with_retry(task)
            except Exception as e:
                logger.error("扣费最终失败: [%s] api=%s error=%s", task.task_id, task.api_name, e)
            finally:
                self._queue.task_done()

    async def _execute_with_retry(self, task: _BillingTask):
        """Attempt a deduction up to ``task.max_retries`` times with 1s, 2s, 4s... backoff.

        Raises:
            RuntimeError: if every attempt fails; carries the last error text.
        """
        last_error = None
        for attempt in range(1, task.max_retries + 1):
            try:
                result = await self._do_deduct(task)
                if result and result.get("success"):
                    logger.info(
                        "扣费成功: [%s] api=%s count=%d attempt=%d/%d",
                        task.task_id, task.api_name, task.usage_count, attempt, task.max_retries,
                    )
                    return
                last_error = result.get("message", "扣费接口返回失败") if result else "扣费接口无响应"
            except Exception as e:
                last_error = str(e)
            if attempt < task.max_retries:
                await asyncio.sleep(2 ** (attempt - 1))
        raise RuntimeError(f"扣费最终失败: {last_error}")

    async def _do_deduct(self, task: _BillingTask) -> dict:
        """POST one deduction to the platform and return its (unwrapped) JSON payload.

        Raises:
            RuntimeError: on a non-200 HTTP status.
        """
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{self._platform_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
                headers={"Authorization": f"token {self._platform_key}:{self._platform_secret}"},
                json={
                    "api_key": task.api_key,
                    "api_secret": task.api_secret,
                    "api_name": task.api_name,
                    "usage_count": task.usage_count,
                },
            ) as resp:
                if resp.status != 200:
                    raise RuntimeError(f"扣费服务返回 HTTP {resp.status}")
                result = await resp.json()
                # The platform may wrap the real payload in a "message" object.
                if "message" in result and isinstance(result["message"], dict):
                    result = result["message"]
                return result


async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
    """Verify the caller's API key/secret and team balance with the platform.

    Returns:
        The platform's (unwrapped) verification payload; ``success`` is truthy.

    Raises:
        HTTPException: 401 if the platform rejects the credentials, 500 if the
        verification service is unreachable or misbehaves.
    """
    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{settings.jingrow_api_url}/api/action/jcloud.api.account.verify_api_credentials_and_balance",
                headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
                json={"api_key": api_key, "api_secret": api_secret, "api_name": api_name},
            ) as response:
                if response.status != 200:
                    raise HTTPException(status_code=500, detail="验证服务暂时不可用")
                result = await response.json()
                # The platform may wrap the real payload in a "message" object.
                if "message" in result and isinstance(result["message"], dict):
                    result = result["message"]
                if not result.get("success"):
                    raise HTTPException(status_code=401, detail=result.get("message", "验证失败"))
                return result
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}") from e


billing_manager = BillingTaskManager(
    platform_url=settings.jingrow_api_url,
    platform_key=settings.jingrow_api_key,
    platform_secret=settings.jingrow_api_secret,
)


def get_token_from_request(request) -> str:
    """Extract the ``<key>:<secret>`` token from an ``Authorization: token ...`` header.

    Raises:
        HTTPException: 400 if *request* is missing, 401 if the header or token
        is malformed.
    """
    if not request:
        raise HTTPException(status_code=400, detail="无法获取请求信息")
    auth_header = request.headers.get("Authorization", "")
    if not auth_header or not auth_header.startswith("token "):
        raise HTTPException(status_code=401, detail="无效的Authorization头格式")
    token = auth_header[6:]
    if ":" not in token:
        raise HTTPException(status_code=401, detail="无效的令牌格式")
    return token


async def _estimate_usage_count(request) -> int:
    """Return the length of the first list found under a batch key in the JSON body, else 1."""
    try:
        body_data = await request.json()
    except Exception:
        # Best effort: a missing or non-JSON body is billed as a single unit.
        logger.debug("无法解析请求体,按 1 次计费", exc_info=True)
        return 1
    if isinstance(body_data, dict):
        for key in _BATCH_KEYS:
            value = body_data.get(key)
            if isinstance(value, list):
                return len(value)
    return 1


async def _enqueue_billing(api_key: str, api_secret: str, api_name: str, usage_count: Any) -> None:
    """Queue a deduction unless there is nothing to bill (``None`` or a non-positive int)."""
    if usage_count is None or (isinstance(usage_count, int) and usage_count <= 0):
        return
    await billing_manager.enqueue(api_key, api_secret, api_name, usage_count)


def _is_success_chunk(chunk: Any) -> bool:
    """Return True if *chunk* parses as a JSON object whose ``status`` is ``"success"``."""
    try:
        data = json.loads(chunk)
    except (TypeError, ValueError):
        # Non-JSON / non-text chunks (ValueError covers JSON and Unicode decode errors).
        return False
    return isinstance(data, dict) and data.get("status") == "success"


async def _billed_stream(chunks: AsyncIterator[Any], api_key: str, api_secret: str,
                         api_name: str) -> AsyncIterator[Any]:
    """Re-yield *chunks* unchanged while counting successful results, then bill them."""
    success_count = 0
    try:
        async for chunk in chunks:
            if _is_success_chunk(chunk):
                success_count += 1
            yield chunk
    finally:
        # Runs on normal completion, client disconnect (generator closed at
        # the yield) and upstream errors, so results already produced are
        # always billed. enqueue() never suspends, so awaiting it here is safe
        # even while the generator is being closed.
        await _enqueue_billing(api_key, api_secret, api_name, success_count)


def jingrow_api_verify_and_billing(api_name: str):
    """Decorator: verify Jingrow API credentials/balance, run the endpoint, then queue billing.

    The decorated endpoint must receive the incoming request as the ``request``
    keyword argument; callers authenticate with
    ``Authorization: token <api_key>:<api_secret>``.

    Billing rules:
      * ``StreamingResponse``: one unit per streamed JSON chunk whose ``status``
        is ``"success"``, billed when the stream ends (including early
        termination). The endpoint's response object, with its status code,
        headers and background tasks, is returned as-is.
      * ``dict`` with ``success is True``: ``successful_count`` units, falling
        back to the request-derived count.
      * ``dict`` whose ``success`` key is present but falsy: not billed.
      * anything else: the request-derived count (length of the first list
        under ``items``/``urls``/``images``/``files``, else 1).

    Raises:
        HTTPException: raised by verification or by the endpoint, propagated
        unchanged; any other exception is converted into a 500.
    """
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def wrapper(*args, **kwargs):
            try:
                request = kwargs.get('request')
                if not request:
                    raise HTTPException(status_code=400, detail="无法获取请求信息")

                token = get_token_from_request(request)
                api_key, api_secret = token.split(":", 1)

                verify_result = await verify_api_credentials_and_balance(api_key, api_secret, api_name)
                if not verify_result.get("success"):
                    raise HTTPException(status_code=401, detail=verify_result.get("message", "验证失败"))

                result = await func(*args, **kwargs)

                if isinstance(result, StreamingResponse):
                    # Swap only the body so status code, headers and background
                    # tasks set by the endpoint are preserved.
                    result.body_iterator = _billed_stream(
                        result.body_iterator, api_key, api_secret, api_name
                    )
                    return result

                if isinstance(result, dict) and "success" in result and not result["success"]:
                    # The endpoint explicitly reported failure: do not charge.
                    return result

                usage_count = await _estimate_usage_count(request)
                if isinstance(result, dict) and result.get("success") is True:
                    usage_count = result.get("successful_count", usage_count)
                await _enqueue_billing(api_key, api_secret, api_name, usage_count)
                return result
            except HTTPException:
                raise
            except Exception as e:
                raise HTTPException(status_code=500, detail=f"API验证过程发生异常: {str(e)}") from e
        return wrapper
    return decorator