japi/apps/jmidjourney/utils.py

323 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import io
import json
import logging
import os
import re
import time
import uuid
from dataclasses import dataclass, field
from functools import wraps
from pathlib import Path
from typing import Callable, Any, Dict, Optional, Tuple, List
from urllib.parse import urlparse

import aiohttp
import requests
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from PIL import Image

from settings import settings
# Module-level logger, registered under this module's dotted import name.
logger: logging.Logger = logging.getLogger(name=__name__)
@dataclass
class _BillingTask:
api_key: str
api_secret: str
api_name: str
usage_count: int
max_retries: int = 3
created_at: float = field(default_factory=asyncio.get_event_loop().time)
task_id: str = field(default_factory=lambda: hex(id(object()))[2:])
class BillingTaskManager:
    """Asynchronous usage-fee deduction queue with exponential-backoff retries.

    Request handlers hand work to :meth:`enqueue`.  A single background worker,
    created by :meth:`start`, drains the queue and POSTs each task to the
    platform's ``jcloud.api.account.deduct_api_usage_fee`` endpoint.

    Args:
        platform_url: Base URL of the billing platform; a trailing "/" is removed.
        platform_key: Platform API key used in the ``Authorization`` header.
        platform_secret: Platform API secret used in the ``Authorization`` header.
        request_timeout: Total timeout in seconds for one deduct HTTP attempt.
            This bound stops a single hung request from stalling the one
            worker, and every task queued behind it, for aiohttp's default
            of 5 minutes.
    """

    def __init__(
        self,
        platform_url: str,
        platform_key: str,
        platform_secret: str,
        request_timeout: float = 30.0,
    ):
        self._platform_url = platform_url.rstrip("/")
        self._platform_key = platform_key
        self._platform_secret = platform_secret
        self._request_timeout = request_timeout
        self._queue: asyncio.Queue = asyncio.Queue()
        self._worker_task: Optional[asyncio.Task] = None
        self._running = False

    @property
    def pending_count(self) -> int:
        """Number of tasks still waiting in the queue (excludes one being processed)."""
        return self._queue.qsize()

    async def start(self):
        """Spawn the background worker task; no-op if already running."""
        if self._running:
            return
        self._running = True
        self._worker_task = asyncio.create_task(self._worker_loop(), name="billing-worker")
        logger.info("BillingTaskManager worker 已启动")

    async def shutdown(self, timeout: float = 10.0):
        """Wait up to ``timeout`` seconds for queued tasks to finish, then stop the worker.

        Does nothing when the manager is not running.  If the queue is not
        drained in time, the worker is cancelled anyway and the number of
        tasks left in the queue is logged as a warning.
        """
        if not self._running or self._worker_task is None:
            return
        # Clearing the flag lets the worker loop exit once the queue is empty.
        self._running = False
        try:
            await asyncio.wait_for(self._queue.join(), timeout=timeout)
        except asyncio.TimeoutError:
            logger.warning("BillingTaskManager 关闭超时,剩余 %d 个任务", self._queue.qsize())
            drained = False
        else:
            drained = True
        self._worker_task.cancel()
        await asyncio.gather(self._worker_task, return_exceptions=True)
        if drained:
            logger.info("BillingTaskManager 已优雅关闭")

    async def enqueue(self, api_key: str, api_secret: str, api_name: str, usage_count: int):
        """Queue one deduction and return immediately without waiting for it to run."""
        if not self._running:
            # Without a worker nothing drains the queue, so billing would be
            # silently lost; make the misconfiguration visible.
            logger.warning(
                "BillingTaskManager worker 未运行,扣费任务将积压: api=%s count=%s",
                api_name, usage_count,
            )
        self._queue.put_nowait(_BillingTask(
            api_key=api_key,
            api_secret=api_secret,
            api_name=api_name,
            usage_count=usage_count,
        ))

    async def _worker_loop(self):
        """Process tasks until shutdown has been requested and the queue is empty."""
        while self._running or not self._queue.empty():
            try:
                # Short poll so the loop notices ``_running`` turning False
                # even when no new tasks arrive.
                task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
            except asyncio.TimeoutError:
                continue
            try:
                await self._execute_with_retry(task)
            except Exception as e:
                logger.error("扣费最终失败: [%s] api=%s error=%s", task.task_id, task.api_name, e)
            finally:
                # Always mark the item done so ``queue.join()`` in shutdown can finish.
                self._queue.task_done()

    async def _execute_with_retry(self, task: _BillingTask):
        """Try the deduction up to ``task.max_retries`` times.

        Sleeps 1 s, 2 s, 4 s, ... between attempts.  An attempt succeeds when
        the response has a truthy ``success``.

        Raises:
            RuntimeError: if no attempt succeeds; carries the last error text.
        """
        last_error = None
        for attempt in range(1, task.max_retries + 1):
            try:
                result = await self._do_deduct(task)
                if result and result.get("success"):
                    logger.info(
                        "扣费成功: [%s] api=%s count=%d attempt=%d/%d",
                        task.task_id, task.api_name, task.usage_count, attempt, task.max_retries,
                    )
                    return
                last_error = result.get("message", "扣费接口返回失败") if result else "扣费接口无响应"
            except Exception as e:
                # Network errors, timeouts, non-200 statuses and malformed
                # payloads are all treated as retryable.
                last_error = str(e)
            if attempt < task.max_retries:
                await asyncio.sleep(2 ** (attempt - 1))
        raise RuntimeError(f"扣费最终失败: {last_error}")

    async def _do_deduct(self, task: _BillingTask) -> dict:
        """POST one deduction to the platform and return the decoded JSON body.

        A dict found under the ``"message"`` key is returned in place of the
        outer object (presumably the platform's response envelope).

        Raises:
            RuntimeError: if the platform answers with a non-200 status.
        """
        client_timeout = aiohttp.ClientTimeout(total=self._request_timeout)
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            async with session.post(
                f"{self._platform_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
                headers={"Authorization": f"token {self._platform_key}:{self._platform_secret}"},
                json={
                    "api_key": task.api_key,
                    "api_secret": task.api_secret,
                    "api_name": task.api_name,
                    "usage_count": task.usage_count,
                },
            ) as resp:
                if resp.status != 200:
                    raise RuntimeError(f"扣费服务返回 HTTP {resp.status}")
                result = await resp.json()
        # Only unwrap when the body really is an object; other JSON values are
        # returned as-is and rejected by the retry loop.
        if isinstance(result, dict) and isinstance(result.get("message"), dict):
            result = result["message"]
        return result
async def verify_api_credentials_and_balance(
    api_key: str,
    api_secret: str,
    api_name: str,
    timeout: float = 30.0,
) -> Dict[str, Any]:
    """Ask the Jingrow platform to validate an API key pair and the team's balance.

    Args:
        api_key: Caller's API key, taken from the request.
        api_secret: Caller's API secret, taken from the request.
        api_name: Name of the API being invoked; forwarded to the platform.
        timeout: Total timeout in seconds for the verification HTTP call.
            This call sits on the request path, so it is bounded instead of
            relying on aiohttp's 5-minute default.

    Returns:
        The platform's result dict, unwrapped from a ``"message"`` envelope
        when one is present.  Its ``success`` field is truthy.

    Raises:
        HTTPException: 401 when the platform reports failure.  500 when the
            service is unreachable, times out, returns a non-200 status or
            returns a payload that is not a JSON object.
    """
    try:
        client_timeout = aiohttp.ClientTimeout(total=timeout)
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            async with session.post(
                f"{settings.jingrow_api_url}/api/action/jcloud.api.account.verify_api_credentials_and_balance",
                headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
                json={"api_key": api_key, "api_secret": api_secret, "api_name": api_name},
            ) as response:
                if response.status != 200:
                    raise HTTPException(status_code=500, detail="验证服务暂时不可用")
                result = await response.json()
        if isinstance(result, dict) and isinstance(result.get("message"), dict):
            result = result["message"]
        if not isinstance(result, dict):
            # Anything other than a JSON object is an unusable reply.
            raise HTTPException(status_code=500, detail="验证服务暂时不可用")
        if not result.get("success"):
            raise HTTPException(status_code=401, detail=result.get("message", "验证失败"))
        return result
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}") from e
# Process-wide billing queue shared by every decorated endpoint, configured
# from the same Jingrow platform credentials used for verification.
# Constructing it does not start the worker; BillingTaskManager.start()
# must be awaited separately.
billing_manager = BillingTaskManager(
    settings.jingrow_api_url,
    settings.jingrow_api_key,
    settings.jingrow_api_secret,
)
def get_token_from_request(request) -> str:
    """Extract the ``key:secret`` credential from a ``token`` Authorization header.

    Args:
        request: Request-like object exposing ``headers.get``.

    Returns:
        The text after the ``"token "`` prefix with surrounding whitespace
        stripped, e.g. ``"key:secret"``.  The secret may itself contain ``:``.

    Raises:
        HTTPException: 400 if ``request`` is falsy.  401 if the header is
            missing or lacks the ``token `` prefix.  401 if the credential is
            not a non-empty key and non-empty secret separated by ``:``.
    """
    if not request:
        raise HTTPException(status_code=400, detail="无法获取请求信息")
    auth_header = request.headers.get("Authorization", "")
    prefix = "token "
    if not auth_header or not auth_header.startswith(prefix):
        raise HTTPException(status_code=401, detail="无效的Authorization头格式")
    token = auth_header[len(prefix):].strip()
    api_key, separator, api_secret = token.partition(":")
    # Reject ":secret" and "key:" as well as tokens without any separator.
    if not separator or not api_key or not api_secret:
        raise HTTPException(status_code=401, detail="无效的令牌格式")
    return token
# Request-body keys whose list length is taken as the number of billable units.
_USAGE_LIST_KEYS = ("items", "urls", "images", "files")


async def _estimate_usage_count(request) -> int:
    """Return the length of the first list among ``_USAGE_LIST_KEYS`` in the JSON body, else 1."""
    try:
        body_data = await request.json()
    except Exception:
        # Deliberate best effort: a missing or non-JSON body bills as one unit.
        return 1
    if isinstance(body_data, dict):
        for key in _USAGE_LIST_KEYS:
            value = body_data.get(key)
            if isinstance(value, list):
                return len(value)
    return 1


def _is_success_chunk(chunk) -> bool:
    """Return True if a streamed chunk is a JSON object whose ``status`` is ``"success"``."""
    try:
        data = json.loads(chunk)
    except (TypeError, ValueError, RecursionError):
        # Non-JSON text, or a payload type json.loads cannot read, is passed
        # through without being counted.
        return False
    return isinstance(data, dict) and data.get("status") == "success"


async def _billed_stream(body_iterator, api_key: str, api_secret: str, api_name: str):
    """Re-yield a streaming body and then bill one unit per delivered success chunk.

    A chunk is counted only after the consumer asks for the next one, so a
    chunk whose send failed is not charged.  Billing runs in ``finally``, so
    chunks already delivered are still charged if the stream ends early.
    """
    success_count = 0
    try:
        async for chunk in body_iterator:
            is_success = _is_success_chunk(chunk)
            yield chunk
            if is_success:
                success_count += 1
    finally:
        if success_count > 0:
            await billing_manager.enqueue(api_key, api_secret, api_name, success_count)


def jingrow_api_verify_and_billing(api_name: str):
    """Decorate an async endpoint with credential verification and usage billing.

    The endpoint must receive the request as the ``request`` keyword argument.
    Before the call, the ``token key:secret`` Authorization header is verified
    with the platform; failures become HTTP 400/401.  After the call, units
    are billed depending on what the endpoint returned:

    * ``StreamingResponse``: one unit per streamed JSON chunk whose ``status``
      is ``"success"``.  The response object itself is returned, so its
      status code, headers and background tasks are kept.
    * ``dict`` with a falsy ``success``: nothing is billed.
    * ``dict`` with ``success is True``: ``success_count`` units, defaulting
      to the request-body estimate.
    * anything else: the request-body estimate.  This is the length of the
      first list among ``items``/``urls``/``images``/``files``, else 1.

    Non-positive integer counts are not enqueued.  Any other unexpected error
    is reported as HTTP 500.
    """
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def wrapper(*args, **kwargs):
            try:
                request = kwargs.get('request')
                if not request:
                    raise HTTPException(status_code=400, detail="无法获取请求信息")
                token = get_token_from_request(request)
                api_key, api_secret = token.split(":", 1)
                verify_result = await verify_api_credentials_and_balance(api_key, api_secret, api_name)
                if not verify_result.get("success"):
                    raise HTTPException(status_code=401, detail=verify_result.get("message", "验证失败"))

                result = await func(*args, **kwargs)

                if isinstance(result, StreamingResponse):
                    # Swap the iterator in place rather than building a new
                    # response, which would drop status_code and background.
                    result.body_iterator = _billed_stream(
                        result.body_iterator, api_key, api_secret, api_name
                    )
                    return result

                usage_count = await _estimate_usage_count(request)
                if isinstance(result, dict) and "success" in result:
                    if not result["success"]:
                        # The endpoint reported failure: nothing to charge.
                        return result
                    if result["success"] is True:
                        usage_count = result.get("success_count", usage_count)

                # Skip zero/negative counts, e.g. a batch where nothing succeeded.
                if not (isinstance(usage_count, int) and usage_count <= 0):
                    await billing_manager.enqueue(api_key, api_secret, api_name, usage_count)
                return result
            except HTTPException:
                raise
            except Exception as e:
                raise HTTPException(status_code=500, detail=f"API验证过程发生异常: {str(e)}") from e
        return wrapper
    return decorator
def is_valid_image_url(
    url: str,
    *,
    valid_extensions: Tuple[str, ...] = ('.jpg', '.jpeg', '.png', '.webp', '.gif'),
) -> bool:
    """Return True if ``url`` looks like an absolute URL pointing at an image file.

    Only the URL's shape is checked and nothing is fetched.  A scheme and a
    host must be present, and the path must end in one of
    ``valid_extensions``.  The match is case-insensitive, and the query
    string and fragment are ignored.

    Args:
        url: Candidate URL.  Non-string or empty values return False.
        valid_extensions: Accepted file extensions, each including the leading dot.
    """
    if not url or not isinstance(url, str):
        return False
    try:
        parsed = urlparse(url)
    except ValueError:
        # urlparse rejects malformed netlocs such as an unclosed IPv6 "[".
        return False
    if not parsed.scheme or not parsed.netloc:
        return False
    suffixes = tuple(ext.lower() for ext in valid_extensions)
    return parsed.path.lower().endswith(suffixes)
def _webp_bytes_to_png(image_data: bytes) -> bytes:
    """Decode WebP ``image_data`` with Pillow and return it re-encoded as PNG bytes."""
    with Image.open(io.BytesIO(image_data)) as image:
        png_buffer = io.BytesIO()
        image.save(png_buffer, format='PNG')
    return png_buffer.getvalue()


def get_new_image_url(image_url: str) -> str:
    """Copy a remote image to the configured upload service and return its new URL.

    The image at ``image_url`` is downloaded and converted to PNG if it is
    WebP.  It is then POSTed as multipart field ``file`` to
    ``settings.upload_url``, and the ``url`` field of the JSON reply is
    returned.  The upload is named after the last path segment of
    ``image_url``, passed through ``sanitize_filename``.

    Raises:
        HTTPException: 400 if the download returns a non-200 status.  500 if
            no upload URL is configured, the upload fails or returns no URL,
            or any other error occurs.
    """
    try:
        upload_url = settings.upload_url
        if not upload_url:
            raise HTTPException(status_code=500, detail="未配置上传URL")

        # NOTE(security): ``image_url`` is fetched server-side with TLS
        # verification disabled.  If it can come from API callers, consider
        # restricting schemes/hosts (SSRF) and enabling certificate checks.
        # NOTE(review): ``requests`` is blocking; if this is called directly
        # from async code it stalls the event loop — confirm it runs in a thread.
        response = requests.get(image_url, verify=False, timeout=30)
        if response.status_code != 200:
            raise HTTPException(status_code=400, detail=f"无法下载图片: HTTP {response.status_code}")
        image_data = response.content

        file_name = sanitize_filename(Path(urlparse(image_url).path).name)
        if Path(file_name).suffix.lower() == '.webp':
            image_data = _webp_bytes_to_png(image_data)
            # Replace only the trailing extension, whatever its case; a plain
            # str.replace would also rewrite ".webp" elsewhere in the name and
            # would miss ".WEBP".
            file_name = Path(file_name).stem + '.png'

        upload_response = requests.post(
            upload_url, files={"file": (file_name, image_data)}, verify=False, timeout=30
        )
        if upload_response.status_code != 200:
            error_message = f"图片URL转换失败: 状态码 {upload_response.status_code}, 响应: {upload_response.text[:200]}"
            raise HTTPException(status_code=500, detail=error_message)
        new_url = upload_response.json().get("url")
        if not new_url:
            raise HTTPException(status_code=500, detail="上传成功但未返回URL")
        return new_url
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"图片URL转换异常: {str(e)}") from e
def sanitize_filename(filename: str, max_length: int = 255) -> str:
    """Make ``filename`` safe to use as a single path component.

    Steps:

    * remove path separators, the Windows-reserved characters ``:*?"<>|``
      and ASCII control characters ``\\x00``-``\\x1f``;
    * trim surrounding whitespace;
    * replace names that are empty or consist only of dots (``.``, ``..``)
      with ``"untitled"``;
    * truncate to ``max_length`` characters, keeping the extension when it
      fits.

    Note: the limit counts characters, not encoded bytes.  Filesystems that
    cap names at 255 bytes may still reject long non-ASCII names.
    """
    filename = re.sub(r'[\\/:*?"<>|\x00-\x1f]', '', filename)
    filename = filename.strip()
    # "." and ".." are directory references, not usable file names.
    if not filename.strip('.'):
        filename = "untitled"
    if len(filename) > max_length:
        stem, ext = os.path.splitext(filename)
        if len(ext) < max_length:
            filename = stem[:max_length - len(ext)] + ext
        else:
            # The extension alone exceeds the limit; a plain cut is all we can do.
            filename = filename[:max_length]
    return filename