323 lines
13 KiB
Python
323 lines
13 KiB
Python
import asyncio
import io
import json
import logging
import os
import re
import time
import uuid
from dataclasses import dataclass, field
from functools import wraps
from pathlib import Path
from typing import Callable, Any, Dict, Optional, Tuple, List
from urllib.parse import urlparse

import aiohttp
import requests
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from PIL import Image

from settings import settings
|
||
|
||
# Module-level logger; BillingTaskManager uses it to report worker lifecycle and billing outcomes.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@dataclass
class _BillingTask:
    """A single queued usage-fee deduction request.

    Instances are created by ``BillingTaskManager.enqueue`` and consumed by its worker.
    """

    api_key: str  # caller's API key, forwarded to the deduction endpoint
    api_secret: str  # caller's API secret, forwarded to the deduction endpoint
    api_name: str  # name of the API being billed
    usage_count: int  # number of units to deduct
    max_retries: int = 3  # total number of deduction attempts before giving up
    # Creation timestamp in seconds from the monotonic clock (the default asyncio loop's
    # loop.time() is also time.monotonic()). The previous factory evaluated
    # asyncio.get_event_loop() at import time, which is deprecated/erroring outside a
    # running loop and bound the factory to whatever loop existed at import.
    created_at: float = field(default_factory=time.monotonic)
    # Identifier used in log lines. hex(id(object())) collided constantly because the
    # temporary object is freed at once and CPython reuses its address; uuid4 does not.
    task_id: str = field(default_factory=lambda: uuid.uuid4().hex)
|
||
|
||
|
||
class BillingTaskManager:
    """Asynchronous usage-fee deduction manager with exponential-backoff retries.

    Deduction requests are queued with :meth:`enqueue` and processed one at a time by
    a single background worker started with :meth:`start`. Each request is POSTed to
    the platform's ``deduct_api_usage_fee`` action and attempted up to
    ``task.max_retries`` times, sleeping 1s, 2s, 4s, ... between attempts.
    :meth:`shutdown` drains the queue (up to a timeout), stops the worker and closes
    the shared HTTP session.

    Tasks enqueued while the worker is not running remain queued until :meth:`start`.
    """

    def __init__(
        self,
        platform_url: str,
        platform_key: str,
        platform_secret: str,
        request_timeout: float = 30.0,
    ):
        """
        Args:
            platform_url: Base URL of the billing platform; a trailing "/" is stripped.
            platform_key: Platform API key sent in the ``Authorization`` header.
            platform_secret: Platform API secret sent in the ``Authorization`` header.
            request_timeout: Total timeout in seconds for one deduction HTTP request
                (aiohttp's own default is 5 minutes, which would stall the worker).
        """
        self._platform_url = platform_url.rstrip("/")
        self._platform_key = platform_key
        self._platform_secret = platform_secret
        self._request_timeout = request_timeout
        # NOTE(review): on Python < 3.10 asyncio.Queue binds to the event loop current at
        # construction; a module-level instance is built at import time, so confirm the
        # runtime is 3.10+ or that the same loop later serves requests.
        self._queue: asyncio.Queue = asyncio.Queue()
        self._worker_task: Optional[asyncio.Task] = None
        self._running = False
        # Shared HTTP session, created lazily inside the running loop by the worker and
        # closed by shutdown(); one session per manager keeps connection pooling useful.
        self._session: Optional["aiohttp.ClientSession"] = None

    @property
    def pending_count(self) -> int:
        """Number of tasks currently waiting in the queue."""
        return self._queue.qsize()

    async def start(self):
        """Start the background worker; a no-op if it is already running."""
        if self._running:
            return
        self._running = True
        self._worker_task = asyncio.create_task(self._worker_loop(), name="billing-worker")
        logger.info("BillingTaskManager worker 已启动")

    async def shutdown(self, timeout: float = 10.0):
        """Drain the queue for up to *timeout* seconds, then stop the worker.

        Tasks still queued after the timeout are abandoned (their count is logged).
        The worker is cancelled and the HTTP session closed in every case, including
        when this coroutine itself is cancelled.
        """
        if not self._running or self._worker_task is None:
            return
        self._running = False
        drained = False
        try:
            await asyncio.wait_for(self._queue.join(), timeout=timeout)
            drained = True
        except asyncio.TimeoutError:
            logger.warning("BillingTaskManager 关闭超时,剩余 %d 个任务", self._queue.qsize())
        finally:
            self._worker_task.cancel()
            await asyncio.gather(self._worker_task, return_exceptions=True)
            await self._close_session()
        if drained:
            logger.info("BillingTaskManager 已优雅关闭")

    async def enqueue(self, api_key: str, api_secret: str, api_name: str, usage_count: int):
        """Queue a deduction of *usage_count* units of *api_name* for the given credentials.

        Returns immediately; the HTTP call is made later by the worker.
        """
        self._queue.put_nowait(_BillingTask(
            api_key=api_key,
            api_secret=api_secret,
            api_name=api_name,
            usage_count=usage_count,
        ))

    async def _worker_loop(self):
        """Process queued tasks until stopped and the queue is empty.

        The queue is polled with a 1s timeout so a cleared ``_running`` flag is noticed
        even while idle. Tasks that fail every attempt are logged and dropped.
        """
        while self._running or not self._queue.empty():
            try:
                task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
            except asyncio.TimeoutError:
                continue
            try:
                await self._execute_with_retry(task)
            except Exception as e:
                logger.error("扣费最终失败: [%s] api=%s error=%s", task.task_id, task.api_name, e)
            finally:
                # Always acknowledge the item so queue.join() in shutdown() can finish.
                self._queue.task_done()

    async def _execute_with_retry(self, task: _BillingTask):
        """Attempt the deduction up to ``task.max_retries`` times with exponential backoff.

        Raises:
            RuntimeError: if no attempt returned a payload with a truthy ``success``.
        """
        last_error = None
        for attempt in range(1, task.max_retries + 1):
            try:
                result = await self._do_deduct(task)
                if result and result.get("success"):
                    logger.info(
                        "扣费成功: [%s] api=%s count=%d attempt=%d/%d",
                        task.task_id, task.api_name, task.usage_count, attempt, task.max_retries,
                    )
                    return
                last_error = result.get("message", "扣费接口返回失败") if result else "扣费接口无响应"
            except Exception as e:
                last_error = str(e)
            if attempt < task.max_retries:
                # Backoff: 1s, 2s, 4s, ... between consecutive attempts.
                await asyncio.sleep(2 ** (attempt - 1))
        raise RuntimeError(f"扣费最终失败: {last_error}")

    def _get_session(self) -> "aiohttp.ClientSession":
        """Return the shared HTTP session, creating it (with the request timeout) if needed."""
        if self._session is None or self._session.closed:
            self._session = aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(total=self._request_timeout)
            )
        return self._session

    async def _close_session(self):
        """Close and forget the shared HTTP session, if one was opened."""
        session, self._session = self._session, None
        if session is not None and not session.closed:
            await session.close()

    async def _do_deduct(self, task: _BillingTask) -> dict:
        """POST one deduction request and return the platform's JSON payload.

        A ``message`` envelope holding a dict is unwrapped.

        Raises:
            RuntimeError: on a non-200 HTTP status. Network, timeout and JSON decoding
                errors from aiohttp propagate unchanged (the retry loop catches them).
        """
        session = self._get_session()
        async with session.post(
            f"{self._platform_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
            headers={"Authorization": f"token {self._platform_key}:{self._platform_secret}"},
            json={
                "api_key": task.api_key,
                "api_secret": task.api_secret,
                "api_name": task.api_name,
                "usage_count": task.usage_count,
            },
        ) as resp:
            if resp.status != 200:
                raise RuntimeError(f"扣费服务返回 HTTP {resp.status}")
            result = await resp.json()
        if "message" in result and isinstance(result["message"], dict):
            result = result["message"]
        return result
|
||
|
||
async def verify_api_credentials_and_balance(
    api_key: str,
    api_secret: str,
    api_name: str,
    timeout: float = 30.0,
) -> Dict[str, Any]:
    """Verify the caller's API credentials and team balance with the Jingrow platform.

    Args:
        api_key: Caller's API key.
        api_secret: Caller's API secret.
        api_name: Name of the API about to be used.
        timeout: Total timeout in seconds for the verification request. Without an
            explicit value aiohttp's default (5 minutes) lets a stalled platform hold
            every incoming request.

    Returns:
        The platform's verification payload, unwrapped from a ``message`` envelope when
        that envelope is a dict; its ``success`` field is truthy.

    Raises:
        HTTPException: 401 when the platform reports failure (detail is its message);
            500 when the platform is unreachable, times out, returns a non-200 status
            or an unparsable body.
    """
    try:
        client_timeout = aiohttp.ClientTimeout(total=timeout)
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            async with session.post(
                f"{settings.jingrow_api_url}/api/action/jcloud.api.account.verify_api_credentials_and_balance",
                headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
                json={"api_key": api_key, "api_secret": api_secret, "api_name": api_name},
            ) as response:
                if response.status != 200:
                    raise HTTPException(status_code=500, detail="验证服务暂时不可用")

                result = await response.json()
                if "message" in result and isinstance(result["message"], dict):
                    result = result["message"]

                if not result.get("success"):
                    raise HTTPException(status_code=401, detail=result.get("message", "验证失败"))

                return result
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}") from e
|
||
|
||
# Process-wide billing manager using the Jingrow platform credentials from settings.
# NOTE(review): its worker only runs once start() is awaited — confirm the application
# calls billing_manager.start() / shutdown() (e.g. from lifespan hooks).
billing_manager = BillingTaskManager(
    settings.jingrow_api_url,
    settings.jingrow_api_key,
    settings.jingrow_api_secret,
)
|
||
|
||
def get_token_from_request(request) -> str:
    """Return the ``api_key:api_secret`` credentials from the request's Authorization header.

    The header must have the form ``token <api_key>:<api_secret>``; everything after the
    ``token `` prefix is returned unchanged (it must contain at least one ``:``).

    Raises:
        HTTPException: 400 if *request* is falsy; 401 if the header is missing, lacks
            the ``token `` prefix, or the credentials contain no ``:``.
    """
    if not request:
        raise HTTPException(status_code=400, detail="无法获取请求信息")

    prefix = "token "
    header_value = request.headers.get("Authorization", "")
    if not header_value or not header_value.startswith(prefix):
        raise HTTPException(status_code=401, detail="无效的Authorization头格式")

    credentials = header_value[len(prefix):]
    if ":" in credentials:
        return credentials
    raise HTTPException(status_code=401, detail="无效的令牌格式")
|
||
|
||
def _usage_count_from_body(body_data: Any) -> int:
    """Return the billable unit count described by a parsed JSON request body.

    The first of the ``items``/``urls``/``images``/``files`` keys whose value is a list
    gives the count (its length); any other body counts as a single unit.
    """
    if isinstance(body_data, dict):
        for key in ("items", "urls", "images", "files"):
            value = body_data.get(key)
            if isinstance(value, list):
                return len(value)
    return 1


def _is_success_chunk(chunk: Any) -> bool:
    """Return True if a streamed chunk parses as a JSON object whose ``status`` is "success"."""
    try:
        data = json.loads(chunk)
    except (TypeError, ValueError):
        # Not str/bytes, not valid JSON or not decodable: not a countable success record.
        return False
    return isinstance(data, dict) and data.get("status") == "success"


def jingrow_api_verify_and_billing(api_name: str):
    """Decorator factory: verify Jingrow credentials and balance, run the endpoint, bill usage.

    The wrapped coroutine must receive the incoming request as the ``request`` keyword
    argument. Its ``Authorization: token <api_key>:<api_secret>`` header is checked with
    verify_api_credentials_and_balance before the endpoint runs. Billing is queued on
    ``billing_manager`` after the endpoint returns:

    - StreamingResponse: one unit per streamed chunk that is a JSON object with
      ``"status": "success"``. The count is queued when the stream finishes, including
      early termination (client disconnect or error) for the successes already streamed.
    - dict with ``success is True``: its ``success_count`` if present, else the body count.
    - dict with a present but falsy ``success``: not billed.
    - anything else: the count derived from the request body (length of the first list
      under items/urls/images/files, else 1).

    Counts that are numbers <= 0 are not billed.

    Args:
        api_name: Name of the billed API, passed to verification and billing.

    Raises (from the wrapper):
        HTTPException: 400 if no ``request`` kwarg is given; verification and endpoint
            HTTPExceptions propagate unchanged; any other error becomes a 500.
    """
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def wrapper(*args, **kwargs):
            try:
                request = kwargs.get('request')
                if not request:
                    raise HTTPException(status_code=400, detail="无法获取请求信息")

                token = get_token_from_request(request)
                api_key, api_secret = token.split(":", 1)

                verify_result = await verify_api_credentials_and_balance(api_key, api_secret, api_name)
                if not verify_result.get("success"):
                    raise HTTPException(status_code=401, detail=verify_result.get("message", "验证失败"))

                result = await func(*args, **kwargs)

                usage_count = 1
                try:
                    usage_count = _usage_count_from_body(await request.json())
                except Exception:
                    # Best effort: an absent or non-JSON body is billed as a single unit.
                    logger.debug("请求体不是 JSON, 按 %d 次计费", usage_count)

                if isinstance(result, StreamingResponse):
                    original_iterator = result.body_iterator

                    async def billed_iterator():
                        success_count = 0
                        try:
                            async for chunk in original_iterator:
                                if _is_success_chunk(chunk):
                                    success_count += 1
                                yield chunk
                        finally:
                            # Bill the successes already streamed even if the stream is cut
                            # short (client disconnect, upstream error).
                            if success_count > 0:
                                await billing_manager.enqueue(api_key, api_secret, api_name, success_count)

                    # Swap the iterator in place: building a new StreamingResponse dropped
                    # the original status_code and background tasks.
                    result.body_iterator = billed_iterator()
                    return result

                billable = usage_count
                if isinstance(result, dict):
                    if result.get("success") is True:
                        # Prefer the endpoint's own count of successfully processed units.
                        billable = result.get("success_count", usage_count)
                    elif "success" in result and not result["success"]:
                        # The endpoint explicitly reported failure: nothing to charge.
                        billable = 0

                # Skip no-op deductions (failed requests, empty batches, zero successes).
                if isinstance(billable, (int, float)) and billable <= 0:
                    return result
                await billing_manager.enqueue(api_key, api_secret, api_name, billable)
                return result

            except HTTPException:
                raise
            except Exception as e:
                raise HTTPException(status_code=500, detail=f"API验证过程发生异常: {str(e)}") from e
        return wrapper
    return decorator
|
||
|
||
def is_valid_image_url(url: str) -> bool:
    """Return True if *url* is an absolute URL whose path ends with a known image extension.

    The check is purely syntactic and fetches nothing. It requires a scheme and a
    network location, and compares the case-insensitive path suffix against
    .jpg/.jpeg/.png/.webp/.gif. Query strings and fragments are ignored.
    """
    if not url or not isinstance(url, str):
        return False

    try:
        parsed = urlparse(url)
    except ValueError:
        # urlparse rejects malformed input such as an unbalanced IPv6 bracket.
        return False

    if not parsed.scheme or not parsed.netloc:
        return False

    return parsed.path.lower().endswith(('.jpg', '.jpeg', '.png', '.webp', '.gif'))
|
||
|
||
|
||
def get_new_image_url(image_url: str) -> str:
    """Re-host a remote image on the configured upload service and return its new URL.

    The steps are:

    1. Download the image.
    2. Sanitize the file name taken from the URL path.
    3. Convert the image to PNG when the file extension is .webp (any case).
    4. POST the bytes as multipart field ``file`` to ``settings.upload_url``.
    5. Return the ``url`` field of the JSON response.

    Raises:
        HTTPException: 400 when the download returns a non-200 status; 500 when no upload
            URL is configured, the upload returns a non-200 status or no URL, or any
            other error occurs.

    Note: this performs blocking ``requests`` I/O; calling it directly from a coroutine
    blocks the event loop.
    """
    try:
        upload_url = settings.upload_url
        if not upload_url:
            raise HTTPException(status_code=500, detail="未配置上传URL")

        # SECURITY(review): verify=False disables TLS certificate checks for both the
        # download and the upload, and image_url is fetched as given. Confirm this is
        # intended; otherwise enable verification and restrict allowed hosts.
        response = requests.get(image_url, verify=False, timeout=30)
        if response.status_code != 200:
            raise HTTPException(status_code=400, detail=f"无法下载图片: HTTP {response.status_code}")
        image_data = response.content

        file_name = sanitize_filename(Path(urlparse(image_url).path).name)
        file_path = Path(file_name)

        if file_path.suffix.lower() == '.webp':
            with Image.open(io.BytesIO(image_data)) as image:
                png_buffer = io.BytesIO()
                image.save(png_buffer, format='PNG')
            image_data = png_buffer.getvalue()
            # with_suffix swaps only the final extension, whatever its case; the former
            # str.replace missed ".WEBP" and also rewrote ".webp" inside the name.
            file_name = file_path.with_suffix('.png').name

        files = {"file": (file_name, image_data)}
        upload_response = requests.post(upload_url, files=files, verify=False, timeout=30)
        if upload_response.status_code != 200:
            error_message = f"图片URL转换失败: 状态码 {upload_response.status_code}, 响应: {upload_response.text[:200]}"
            raise HTTPException(status_code=500, detail=error_message)

        result = upload_response.json()
        new_url = result.get("url")
        if not new_url:
            raise HTTPException(status_code=500, detail="上传成功但未返回URL")
        return new_url

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"图片URL转换异常: {str(e)}") from e
|
||
|
||
def sanitize_filename(filename: str, max_length: int = 255) -> str:
    """Make *filename* safe to use as a single file name.

    The function:

    - removes path separators, the Windows-reserved characters ``:*?"<>|`` and ASCII
      control characters;
    - trims surrounding whitespace;
    - replaces an empty result or a bare ``.``/``..`` with ``"untitled"``;
    - truncates to *max_length* characters, keeping the extension when it fits.

    Args:
        filename: Raw file name, e.g. the last path segment of a URL.
        max_length: Maximum length of the result in characters (not bytes).

    Returns:
        The sanitized file name.
    """
    filename = re.sub(r'[\\/:*?"<>|\x00-\x1f]', '', filename).strip()

    # "." and ".." survive separator removal but name directories, not files.
    if not filename or filename in ('.', '..'):
        return "untitled"

    if len(filename) > max_length:
        name, ext = os.path.splitext(filename)
        if len(ext) < max_length:
            filename = name[:max_length - len(ext)] + ext
        else:
            # The "extension" alone is too long to keep (e.g. "a." followed by 300 chars);
            # the old slice arithmetic went negative here and left the name over the limit.
            filename = filename[:max_length]
    return filename
|