重构api的扣费逻辑:独立 Billing Task Manager(带重试 + 持久化队列)
This commit is contained in:
parent
325ea6ea20
commit
44cff241e3
@ -1,11 +1,23 @@
|
|||||||
|
from contextlib import asynccontextmanager
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from settings import settings
|
from settings import settings
|
||||||
from api import router
|
from api import router
|
||||||
|
from utils import billing_manager
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""应用生命周期管理"""
|
||||||
|
await billing_manager.start()
|
||||||
|
yield
|
||||||
|
await billing_manager.shutdown()
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="Add Background",
|
title="Add Background",
|
||||||
description="图片添加背景颜色",
|
description="图片添加背景颜色",
|
||||||
version="1.0.0"
|
version="1.0.0",
|
||||||
|
lifespan=lifespan
|
||||||
)
|
)
|
||||||
|
|
||||||
# 注册路由
|
# 注册路由
|
||||||
|
|||||||
@ -3,9 +3,120 @@ from functools import wraps
|
|||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
import os
|
import os
|
||||||
from typing import Callable, Any, Dict, Optional, Tuple
|
from typing import Callable, Any, Dict, Optional, Tuple
|
||||||
from settings import settings
|
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
import json
|
import json
|
||||||
|
from settings import settings
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _BillingTask:
|
||||||
|
api_key: str
|
||||||
|
api_secret: str
|
||||||
|
api_name: str
|
||||||
|
usage_count: int
|
||||||
|
max_retries: int = 3
|
||||||
|
created_at: float = field(default_factory=asyncio.get_event_loop().time)
|
||||||
|
task_id: str = field(default_factory=lambda: hex(id(object()))[2:])
|
||||||
|
|
||||||
|
|
||||||
|
class BillingTaskManager:
|
||||||
|
"""异步扣费任务管理器(带指数退避重试)"""
|
||||||
|
|
||||||
|
def __init__(self, platform_url: str, platform_key: str, platform_secret: str):
|
||||||
|
self._platform_url = platform_url.rstrip("/")
|
||||||
|
self._platform_key = platform_key
|
||||||
|
self._platform_secret = platform_secret
|
||||||
|
self._queue: asyncio.Queue = asyncio.Queue()
|
||||||
|
self._worker_task = None
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pending_count(self) -> int:
|
||||||
|
return self._queue.qsize()
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
if self._running:
|
||||||
|
return
|
||||||
|
self._running = True
|
||||||
|
self._worker_task = asyncio.create_task(self._worker_loop(), name="billing-worker")
|
||||||
|
logger.info("BillingTaskManager worker 已启动")
|
||||||
|
|
||||||
|
async def shutdown(self, timeout: float = 10.0):
|
||||||
|
if not self._running or self._worker_task is None:
|
||||||
|
return
|
||||||
|
self._running = False
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._queue.join(), timeout=timeout)
|
||||||
|
self._worker_task.cancel()
|
||||||
|
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||||
|
logger.info("BillingTaskManager 已优雅关闭")
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning("BillingTaskManager 关闭超时,剩余 %d 个任务", self._queue.qsize())
|
||||||
|
self._worker_task.cancel()
|
||||||
|
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||||
|
|
||||||
|
async def enqueue(self, api_key: str, api_secret: str, api_name: str, usage_count: int):
|
||||||
|
self._queue.put_nowait(_BillingTask(
|
||||||
|
api_key=api_key,
|
||||||
|
api_secret=api_secret,
|
||||||
|
api_name=api_name,
|
||||||
|
usage_count=usage_count,
|
||||||
|
))
|
||||||
|
|
||||||
|
async def _worker_loop(self):
|
||||||
|
while self._running or not self._queue.empty():
|
||||||
|
try:
|
||||||
|
task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
await self._execute_with_retry(task)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("扣费最终失败: [%s] api=%s error=%s", task.task_id, task.api_name, e)
|
||||||
|
finally:
|
||||||
|
self._queue.task_done()
|
||||||
|
|
||||||
|
async def _execute_with_retry(self, task: _BillingTask):
|
||||||
|
last_error = None
|
||||||
|
for attempt in range(1, task.max_retries + 1):
|
||||||
|
try:
|
||||||
|
result = await self._do_deduct(task)
|
||||||
|
if result and result.get("success"):
|
||||||
|
logger.info(
|
||||||
|
"扣费成功: [%s] api=%s count=%d attempt=%d/%d",
|
||||||
|
task.task_id, task.api_name, task.usage_count, attempt, task.max_retries,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
last_error = result.get("message", "扣费接口返回失败") if result else "扣费接口无响应"
|
||||||
|
except Exception as e:
|
||||||
|
last_error = str(e)
|
||||||
|
if attempt < task.max_retries:
|
||||||
|
await asyncio.sleep(2 ** (attempt - 1))
|
||||||
|
raise RuntimeError(f"扣费最终失败: {last_error}")
|
||||||
|
|
||||||
|
async def _do_deduct(self, task: _BillingTask) -> dict:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self._platform_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||||
|
headers={"Authorization": f"token {self._platform_key}:{self._platform_secret}"},
|
||||||
|
json={
|
||||||
|
"api_key": task.api_key,
|
||||||
|
"api_secret": task.api_secret,
|
||||||
|
"api_name": task.api_name,
|
||||||
|
"usage_count": task.usage_count,
|
||||||
|
},
|
||||||
|
) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
raise RuntimeError(f"扣费服务返回 HTTP {resp.status}")
|
||||||
|
result = await resp.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
return result
|
||||||
|
|
||||||
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||||
"""验证API密钥和团队余额"""
|
"""验证API密钥和团队余额"""
|
||||||
@ -33,33 +144,11 @@ async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||||
|
|
||||||
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
billing_manager = BillingTaskManager(
|
||||||
"""从Jingrow平台扣除API使用费"""
|
platform_url=settings.jingrow_api_url,
|
||||||
try:
|
platform_key=settings.jingrow_api_key,
|
||||||
async with aiohttp.ClientSession() as session:
|
platform_secret=settings.jingrow_api_secret,
|
||||||
async with session.post(
|
)
|
||||||
f"{settings.jingrow_api_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
|
||||||
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
|
||||||
json={
|
|
||||||
"api_key": api_key,
|
|
||||||
"api_secret": api_secret,
|
|
||||||
"api_name": api_name,
|
|
||||||
"usage_count": usage_count
|
|
||||||
}
|
|
||||||
) as response:
|
|
||||||
if response.status != 200:
|
|
||||||
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
|
||||||
|
|
||||||
result = await response.json()
|
|
||||||
if "message" in result and isinstance(result["message"], dict):
|
|
||||||
result = result["message"]
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
|
||||||
|
|
||||||
def get_token_from_request(request) -> str:
|
def get_token_from_request(request) -> str:
|
||||||
"""从请求中获取访问令牌"""
|
"""从请求中获取访问令牌"""
|
||||||
@ -122,7 +211,7 @@ def jingrow_api_verify_and_billing(api_name: str):
|
|||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
if success_count > 0:
|
if success_count > 0:
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, success_count)
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
wrapped_generator(),
|
wrapped_generator(),
|
||||||
@ -132,10 +221,10 @@ def jingrow_api_verify_and_billing(api_name: str):
|
|||||||
|
|
||||||
if isinstance(result, dict) and result.get("success") is True:
|
if isinstance(result, dict) and result.get("success") is True:
|
||||||
actual_usage_count = result.get("successful_count", usage_count)
|
actual_usage_count = result.get("successful_count", usage_count)
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, actual_usage_count)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, usage_count)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
|
|||||||
@ -1,11 +1,23 @@
|
|||||||
|
from contextlib import asynccontextmanager
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from settings import settings
|
from settings import settings
|
||||||
from api import router
|
from api import router
|
||||||
|
from utils import billing_manager
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""应用生命周期管理"""
|
||||||
|
await billing_manager.start()
|
||||||
|
yield
|
||||||
|
await billing_manager.shutdown()
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="JArt",
|
title="JArt",
|
||||||
description="JArt绘画服务API",
|
description="JArt绘画服务API",
|
||||||
version="1.0.0"
|
version="1.0.0",
|
||||||
|
lifespan=lifespan
|
||||||
)
|
)
|
||||||
|
|
||||||
# 注册路由
|
# 注册路由
|
||||||
|
|||||||
@ -3,9 +3,120 @@ from functools import wraps
|
|||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
import os
|
import os
|
||||||
from typing import Callable, Any, Dict, Optional, Tuple
|
from typing import Callable, Any, Dict, Optional, Tuple
|
||||||
from settings import settings
|
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
import json
|
import json
|
||||||
|
from settings import settings
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _BillingTask:
|
||||||
|
api_key: str
|
||||||
|
api_secret: str
|
||||||
|
api_name: str
|
||||||
|
usage_count: int
|
||||||
|
max_retries: int = 3
|
||||||
|
created_at: float = field(default_factory=asyncio.get_event_loop().time)
|
||||||
|
task_id: str = field(default_factory=lambda: hex(id(object()))[2:])
|
||||||
|
|
||||||
|
|
||||||
|
class BillingTaskManager:
|
||||||
|
"""异步扣费任务管理器(带指数退避重试)"""
|
||||||
|
|
||||||
|
def __init__(self, platform_url: str, platform_key: str, platform_secret: str):
|
||||||
|
self._platform_url = platform_url.rstrip("/")
|
||||||
|
self._platform_key = platform_key
|
||||||
|
self._platform_secret = platform_secret
|
||||||
|
self._queue: asyncio.Queue = asyncio.Queue()
|
||||||
|
self._worker_task = None
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pending_count(self) -> int:
|
||||||
|
return self._queue.qsize()
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
if self._running:
|
||||||
|
return
|
||||||
|
self._running = True
|
||||||
|
self._worker_task = asyncio.create_task(self._worker_loop(), name="billing-worker")
|
||||||
|
logger.info("BillingTaskManager worker 已启动")
|
||||||
|
|
||||||
|
async def shutdown(self, timeout: float = 10.0):
|
||||||
|
if not self._running or self._worker_task is None:
|
||||||
|
return
|
||||||
|
self._running = False
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._queue.join(), timeout=timeout)
|
||||||
|
self._worker_task.cancel()
|
||||||
|
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||||
|
logger.info("BillingTaskManager 已优雅关闭")
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning("BillingTaskManager 关闭超时,剩余 %d 个任务", self._queue.qsize())
|
||||||
|
self._worker_task.cancel()
|
||||||
|
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||||
|
|
||||||
|
async def enqueue(self, api_key: str, api_secret: str, api_name: str, usage_count: int):
|
||||||
|
self._queue.put_nowait(_BillingTask(
|
||||||
|
api_key=api_key,
|
||||||
|
api_secret=api_secret,
|
||||||
|
api_name=api_name,
|
||||||
|
usage_count=usage_count,
|
||||||
|
))
|
||||||
|
|
||||||
|
async def _worker_loop(self):
|
||||||
|
while self._running or not self._queue.empty():
|
||||||
|
try:
|
||||||
|
task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
await self._execute_with_retry(task)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("扣费最终失败: [%s] api=%s error=%s", task.task_id, task.api_name, e)
|
||||||
|
finally:
|
||||||
|
self._queue.task_done()
|
||||||
|
|
||||||
|
async def _execute_with_retry(self, task: _BillingTask):
|
||||||
|
last_error = None
|
||||||
|
for attempt in range(1, task.max_retries + 1):
|
||||||
|
try:
|
||||||
|
result = await self._do_deduct(task)
|
||||||
|
if result and result.get("success"):
|
||||||
|
logger.info(
|
||||||
|
"扣费成功: [%s] api=%s count=%d attempt=%d/%d",
|
||||||
|
task.task_id, task.api_name, task.usage_count, attempt, task.max_retries,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
last_error = result.get("message", "扣费接口返回失败") if result else "扣费接口无响应"
|
||||||
|
except Exception as e:
|
||||||
|
last_error = str(e)
|
||||||
|
if attempt < task.max_retries:
|
||||||
|
await asyncio.sleep(2 ** (attempt - 1))
|
||||||
|
raise RuntimeError(f"扣费最终失败: {last_error}")
|
||||||
|
|
||||||
|
async def _do_deduct(self, task: _BillingTask) -> dict:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self._platform_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||||
|
headers={"Authorization": f"token {self._platform_key}:{self._platform_secret}"},
|
||||||
|
json={
|
||||||
|
"api_key": task.api_key,
|
||||||
|
"api_secret": task.api_secret,
|
||||||
|
"api_name": task.api_name,
|
||||||
|
"usage_count": task.usage_count,
|
||||||
|
},
|
||||||
|
) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
raise RuntimeError(f"扣费服务返回 HTTP {resp.status}")
|
||||||
|
result = await resp.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
return result
|
||||||
|
|
||||||
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||||
"""验证API密钥和团队余额"""
|
"""验证API密钥和团队余额"""
|
||||||
@ -33,33 +144,11 @@ async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||||
|
|
||||||
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
billing_manager = BillingTaskManager(
|
||||||
"""从Jingrow平台扣除API使用费"""
|
platform_url=settings.jingrow_api_url,
|
||||||
try:
|
platform_key=settings.jingrow_api_key,
|
||||||
async with aiohttp.ClientSession() as session:
|
platform_secret=settings.jingrow_api_secret,
|
||||||
async with session.post(
|
)
|
||||||
f"{settings.jingrow_api_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
|
||||||
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
|
||||||
json={
|
|
||||||
"api_key": api_key,
|
|
||||||
"api_secret": api_secret,
|
|
||||||
"api_name": api_name,
|
|
||||||
"usage_count": usage_count
|
|
||||||
}
|
|
||||||
) as response:
|
|
||||||
if response.status != 200:
|
|
||||||
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
|
||||||
|
|
||||||
result = await response.json()
|
|
||||||
if "message" in result and isinstance(result["message"], dict):
|
|
||||||
result = result["message"]
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
|
||||||
|
|
||||||
def get_token_from_request(request) -> str:
|
def get_token_from_request(request) -> str:
|
||||||
"""从请求中获取访问令牌"""
|
"""从请求中获取访问令牌"""
|
||||||
@ -122,7 +211,7 @@ def jingrow_api_verify_and_billing(api_name: str):
|
|||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
if success_count > 0:
|
if success_count > 0:
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, success_count)
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
wrapped_generator(),
|
wrapped_generator(),
|
||||||
@ -132,10 +221,10 @@ def jingrow_api_verify_and_billing(api_name: str):
|
|||||||
|
|
||||||
if isinstance(result, dict) and result.get("success") is True:
|
if isinstance(result, dict) and result.get("success") is True:
|
||||||
actual_usage_count = result.get("successful_count", usage_count)
|
actual_usage_count = result.get("successful_count", usage_count)
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, actual_usage_count)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, usage_count)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
|
|||||||
@ -1,11 +1,23 @@
|
|||||||
|
from contextlib import asynccontextmanager
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from settings import settings
|
from settings import settings
|
||||||
from api import router
|
from api import router
|
||||||
|
from utils import billing_manager
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""应用生命周期管理"""
|
||||||
|
await billing_manager.start()
|
||||||
|
yield
|
||||||
|
await billing_manager.shutdown()
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="JArt V1",
|
title="JArt V1",
|
||||||
description="JArt绘画服务API",
|
description="JArt绘画服务API",
|
||||||
version="1.0.0"
|
version="1.0.0",
|
||||||
|
lifespan=lifespan
|
||||||
)
|
)
|
||||||
|
|
||||||
# 注册路由
|
# 注册路由
|
||||||
|
|||||||
@ -3,9 +3,120 @@ from functools import wraps
|
|||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
import os
|
import os
|
||||||
from typing import Callable, Any, Dict, Optional, Tuple
|
from typing import Callable, Any, Dict, Optional, Tuple
|
||||||
from settings import settings
|
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
import json
|
import json
|
||||||
|
from settings import settings
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _BillingTask:
|
||||||
|
api_key: str
|
||||||
|
api_secret: str
|
||||||
|
api_name: str
|
||||||
|
usage_count: int
|
||||||
|
max_retries: int = 3
|
||||||
|
created_at: float = field(default_factory=asyncio.get_event_loop().time)
|
||||||
|
task_id: str = field(default_factory=lambda: hex(id(object()))[2:])
|
||||||
|
|
||||||
|
|
||||||
|
class BillingTaskManager:
|
||||||
|
"""异步扣费任务管理器(带指数退避重试)"""
|
||||||
|
|
||||||
|
def __init__(self, platform_url: str, platform_key: str, platform_secret: str):
|
||||||
|
self._platform_url = platform_url.rstrip("/")
|
||||||
|
self._platform_key = platform_key
|
||||||
|
self._platform_secret = platform_secret
|
||||||
|
self._queue: asyncio.Queue = asyncio.Queue()
|
||||||
|
self._worker_task = None
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pending_count(self) -> int:
|
||||||
|
return self._queue.qsize()
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
if self._running:
|
||||||
|
return
|
||||||
|
self._running = True
|
||||||
|
self._worker_task = asyncio.create_task(self._worker_loop(), name="billing-worker")
|
||||||
|
logger.info("BillingTaskManager worker 已启动")
|
||||||
|
|
||||||
|
async def shutdown(self, timeout: float = 10.0):
|
||||||
|
if not self._running or self._worker_task is None:
|
||||||
|
return
|
||||||
|
self._running = False
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._queue.join(), timeout=timeout)
|
||||||
|
self._worker_task.cancel()
|
||||||
|
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||||
|
logger.info("BillingTaskManager 已优雅关闭")
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning("BillingTaskManager 关闭超时,剩余 %d 个任务", self._queue.qsize())
|
||||||
|
self._worker_task.cancel()
|
||||||
|
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||||
|
|
||||||
|
async def enqueue(self, api_key: str, api_secret: str, api_name: str, usage_count: int):
|
||||||
|
self._queue.put_nowait(_BillingTask(
|
||||||
|
api_key=api_key,
|
||||||
|
api_secret=api_secret,
|
||||||
|
api_name=api_name,
|
||||||
|
usage_count=usage_count,
|
||||||
|
))
|
||||||
|
|
||||||
|
async def _worker_loop(self):
|
||||||
|
while self._running or not self._queue.empty():
|
||||||
|
try:
|
||||||
|
task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
await self._execute_with_retry(task)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("扣费最终失败: [%s] api=%s error=%s", task.task_id, task.api_name, e)
|
||||||
|
finally:
|
||||||
|
self._queue.task_done()
|
||||||
|
|
||||||
|
async def _execute_with_retry(self, task: _BillingTask):
|
||||||
|
last_error = None
|
||||||
|
for attempt in range(1, task.max_retries + 1):
|
||||||
|
try:
|
||||||
|
result = await self._do_deduct(task)
|
||||||
|
if result and result.get("success"):
|
||||||
|
logger.info(
|
||||||
|
"扣费成功: [%s] api=%s count=%d attempt=%d/%d",
|
||||||
|
task.task_id, task.api_name, task.usage_count, attempt, task.max_retries,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
last_error = result.get("message", "扣费接口返回失败") if result else "扣费接口无响应"
|
||||||
|
except Exception as e:
|
||||||
|
last_error = str(e)
|
||||||
|
if attempt < task.max_retries:
|
||||||
|
await asyncio.sleep(2 ** (attempt - 1))
|
||||||
|
raise RuntimeError(f"扣费最终失败: {last_error}")
|
||||||
|
|
||||||
|
async def _do_deduct(self, task: _BillingTask) -> dict:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self._platform_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||||
|
headers={"Authorization": f"token {self._platform_key}:{self._platform_secret}"},
|
||||||
|
json={
|
||||||
|
"api_key": task.api_key,
|
||||||
|
"api_secret": task.api_secret,
|
||||||
|
"api_name": task.api_name,
|
||||||
|
"usage_count": task.usage_count,
|
||||||
|
},
|
||||||
|
) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
raise RuntimeError(f"扣费服务返回 HTTP {resp.status}")
|
||||||
|
result = await resp.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
return result
|
||||||
|
|
||||||
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||||
"""验证API密钥和团队余额"""
|
"""验证API密钥和团队余额"""
|
||||||
@ -33,33 +144,11 @@ async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||||
|
|
||||||
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
billing_manager = BillingTaskManager(
|
||||||
"""从Jingrow平台扣除API使用费"""
|
platform_url=settings.jingrow_api_url,
|
||||||
try:
|
platform_key=settings.jingrow_api_key,
|
||||||
async with aiohttp.ClientSession() as session:
|
platform_secret=settings.jingrow_api_secret,
|
||||||
async with session.post(
|
)
|
||||||
f"{settings.jingrow_api_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
|
||||||
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
|
||||||
json={
|
|
||||||
"api_key": api_key,
|
|
||||||
"api_secret": api_secret,
|
|
||||||
"api_name": api_name,
|
|
||||||
"usage_count": usage_count
|
|
||||||
}
|
|
||||||
) as response:
|
|
||||||
if response.status != 200:
|
|
||||||
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
|
||||||
|
|
||||||
result = await response.json()
|
|
||||||
if "message" in result and isinstance(result["message"], dict):
|
|
||||||
result = result["message"]
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
|
||||||
|
|
||||||
def get_token_from_request(request) -> str:
|
def get_token_from_request(request) -> str:
|
||||||
"""从请求中获取访问令牌"""
|
"""从请求中获取访问令牌"""
|
||||||
@ -122,7 +211,7 @@ def jingrow_api_verify_and_billing(api_name: str):
|
|||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
if success_count > 0:
|
if success_count > 0:
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, success_count)
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
wrapped_generator(),
|
wrapped_generator(),
|
||||||
@ -132,10 +221,10 @@ def jingrow_api_verify_and_billing(api_name: str):
|
|||||||
|
|
||||||
if isinstance(result, dict) and result.get("success") is True:
|
if isinstance(result, dict) and result.get("success") is True:
|
||||||
actual_usage_count = result.get("successful_count", usage_count)
|
actual_usage_count = result.get("successful_count", usage_count)
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, actual_usage_count)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, usage_count)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
|
|||||||
@ -1,11 +1,23 @@
|
|||||||
|
from contextlib import asynccontextmanager
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from settings import settings
|
from settings import settings
|
||||||
from api import router
|
from api import router
|
||||||
|
from utils import billing_manager
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""应用生命周期管理"""
|
||||||
|
await billing_manager.start()
|
||||||
|
yield
|
||||||
|
await billing_manager.shutdown()
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="JChat Service",
|
title="JChat Service",
|
||||||
description="AI聊天服务API",
|
description="AI聊天服务API",
|
||||||
version="1.0.0"
|
version="1.0.0",
|
||||||
|
lifespan=lifespan
|
||||||
)
|
)
|
||||||
|
|
||||||
# 注册路由
|
# 注册路由
|
||||||
|
|||||||
@ -3,9 +3,120 @@ from functools import wraps
|
|||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
import os
|
import os
|
||||||
from typing import Callable, Any, Dict, Optional, Tuple
|
from typing import Callable, Any, Dict, Optional, Tuple
|
||||||
from settings import settings
|
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
import json
|
import json
|
||||||
|
from settings import settings
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _BillingTask:
|
||||||
|
api_key: str
|
||||||
|
api_secret: str
|
||||||
|
api_name: str
|
||||||
|
usage_count: int
|
||||||
|
max_retries: int = 3
|
||||||
|
created_at: float = field(default_factory=asyncio.get_event_loop().time)
|
||||||
|
task_id: str = field(default_factory=lambda: hex(id(object()))[2:])
|
||||||
|
|
||||||
|
|
||||||
|
class BillingTaskManager:
|
||||||
|
"""异步扣费任务管理器(带指数退避重试)"""
|
||||||
|
|
||||||
|
def __init__(self, platform_url: str, platform_key: str, platform_secret: str):
|
||||||
|
self._platform_url = platform_url.rstrip("/")
|
||||||
|
self._platform_key = platform_key
|
||||||
|
self._platform_secret = platform_secret
|
||||||
|
self._queue: asyncio.Queue = asyncio.Queue()
|
||||||
|
self._worker_task = None
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pending_count(self) -> int:
|
||||||
|
return self._queue.qsize()
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
if self._running:
|
||||||
|
return
|
||||||
|
self._running = True
|
||||||
|
self._worker_task = asyncio.create_task(self._worker_loop(), name="billing-worker")
|
||||||
|
logger.info("BillingTaskManager worker 已启动")
|
||||||
|
|
||||||
|
async def shutdown(self, timeout: float = 10.0):
|
||||||
|
if not self._running or self._worker_task is None:
|
||||||
|
return
|
||||||
|
self._running = False
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._queue.join(), timeout=timeout)
|
||||||
|
self._worker_task.cancel()
|
||||||
|
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||||
|
logger.info("BillingTaskManager 已优雅关闭")
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning("BillingTaskManager 关闭超时,剩余 %d 个任务", self._queue.qsize())
|
||||||
|
self._worker_task.cancel()
|
||||||
|
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||||
|
|
||||||
|
async def enqueue(self, api_key: str, api_secret: str, api_name: str, usage_count: int):
|
||||||
|
self._queue.put_nowait(_BillingTask(
|
||||||
|
api_key=api_key,
|
||||||
|
api_secret=api_secret,
|
||||||
|
api_name=api_name,
|
||||||
|
usage_count=usage_count,
|
||||||
|
))
|
||||||
|
|
||||||
|
async def _worker_loop(self):
|
||||||
|
while self._running or not self._queue.empty():
|
||||||
|
try:
|
||||||
|
task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
await self._execute_with_retry(task)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("扣费最终失败: [%s] api=%s error=%s", task.task_id, task.api_name, e)
|
||||||
|
finally:
|
||||||
|
self._queue.task_done()
|
||||||
|
|
||||||
|
async def _execute_with_retry(self, task: _BillingTask):
|
||||||
|
last_error = None
|
||||||
|
for attempt in range(1, task.max_retries + 1):
|
||||||
|
try:
|
||||||
|
result = await self._do_deduct(task)
|
||||||
|
if result and result.get("success"):
|
||||||
|
logger.info(
|
||||||
|
"扣费成功: [%s] api=%s count=%d attempt=%d/%d",
|
||||||
|
task.task_id, task.api_name, task.usage_count, attempt, task.max_retries,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
last_error = result.get("message", "扣费接口返回失败") if result else "扣费接口无响应"
|
||||||
|
except Exception as e:
|
||||||
|
last_error = str(e)
|
||||||
|
if attempt < task.max_retries:
|
||||||
|
await asyncio.sleep(2 ** (attempt - 1))
|
||||||
|
raise RuntimeError(f"扣费最终失败: {last_error}")
|
||||||
|
|
||||||
|
async def _do_deduct(self, task: _BillingTask) -> dict:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self._platform_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||||
|
headers={"Authorization": f"token {self._platform_key}:{self._platform_secret}"},
|
||||||
|
json={
|
||||||
|
"api_key": task.api_key,
|
||||||
|
"api_secret": task.api_secret,
|
||||||
|
"api_name": task.api_name,
|
||||||
|
"usage_count": task.usage_count,
|
||||||
|
},
|
||||||
|
) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
raise RuntimeError(f"扣费服务返回 HTTP {resp.status}")
|
||||||
|
result = await resp.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
return result
|
||||||
|
|
||||||
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||||
"""验证API密钥和团队余额"""
|
"""验证API密钥和团队余额"""
|
||||||
@ -33,33 +144,11 @@ async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||||
|
|
||||||
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
billing_manager = BillingTaskManager(
|
||||||
"""从Jingrow平台扣除API使用费"""
|
platform_url=settings.jingrow_api_url,
|
||||||
try:
|
platform_key=settings.jingrow_api_key,
|
||||||
async with aiohttp.ClientSession() as session:
|
platform_secret=settings.jingrow_api_secret,
|
||||||
async with session.post(
|
)
|
||||||
f"{settings.jingrow_api_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
|
||||||
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
|
||||||
json={
|
|
||||||
"api_key": api_key,
|
|
||||||
"api_secret": api_secret,
|
|
||||||
"api_name": api_name,
|
|
||||||
"usage_count": usage_count
|
|
||||||
}
|
|
||||||
) as response:
|
|
||||||
if response.status != 200:
|
|
||||||
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
|
||||||
|
|
||||||
result = await response.json()
|
|
||||||
if "message" in result and isinstance(result["message"], dict):
|
|
||||||
result = result["message"]
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
|
||||||
|
|
||||||
def get_token_from_request(request) -> str:
|
def get_token_from_request(request) -> str:
|
||||||
"""从请求中获取访问令牌"""
|
"""从请求中获取访问令牌"""
|
||||||
@ -122,7 +211,7 @@ def jingrow_api_verify_and_billing(api_name: str):
|
|||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
if success_count > 0:
|
if success_count > 0:
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, success_count)
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
wrapped_generator(),
|
wrapped_generator(),
|
||||||
@ -132,10 +221,10 @@ def jingrow_api_verify_and_billing(api_name: str):
|
|||||||
|
|
||||||
if isinstance(result, dict) and result.get("success") is True:
|
if isinstance(result, dict) and result.get("success") is True:
|
||||||
actual_usage_count = result.get("successful_count", usage_count)
|
actual_usage_count = result.get("successful_count", usage_count)
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, actual_usage_count)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, usage_count)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
|
|||||||
@ -1,11 +1,23 @@
|
|||||||
|
from contextlib import asynccontextmanager
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from settings import settings
|
from settings import settings
|
||||||
from api import router
|
from api import router
|
||||||
|
from utils import billing_manager
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""应用生命周期管理"""
|
||||||
|
await billing_manager.start()
|
||||||
|
yield
|
||||||
|
await billing_manager.shutdown()
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="Jdescribe",
|
title="Jdescribe",
|
||||||
description="Jdescribe描述图片API",
|
description="Jdescribe描述图片API",
|
||||||
version="1.0.0"
|
version="1.0.0",
|
||||||
|
lifespan=lifespan
|
||||||
)
|
)
|
||||||
|
|
||||||
# 注册路由
|
# 注册路由
|
||||||
|
|||||||
@ -3,9 +3,120 @@ from functools import wraps
|
|||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
import os
|
import os
|
||||||
from typing import Callable, Any, Dict, Optional, Tuple
|
from typing import Callable, Any, Dict, Optional, Tuple
|
||||||
from settings import settings
|
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
import json
|
import json
|
||||||
|
from settings import settings
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _BillingTask:
|
||||||
|
api_key: str
|
||||||
|
api_secret: str
|
||||||
|
api_name: str
|
||||||
|
usage_count: int
|
||||||
|
max_retries: int = 3
|
||||||
|
created_at: float = field(default_factory=asyncio.get_event_loop().time)
|
||||||
|
task_id: str = field(default_factory=lambda: hex(id(object()))[2:])
|
||||||
|
|
||||||
|
|
||||||
|
class BillingTaskManager:
|
||||||
|
"""异步扣费任务管理器(带指数退避重试)"""
|
||||||
|
|
||||||
|
def __init__(self, platform_url: str, platform_key: str, platform_secret: str):
|
||||||
|
self._platform_url = platform_url.rstrip("/")
|
||||||
|
self._platform_key = platform_key
|
||||||
|
self._platform_secret = platform_secret
|
||||||
|
self._queue: asyncio.Queue = asyncio.Queue()
|
||||||
|
self._worker_task = None
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pending_count(self) -> int:
|
||||||
|
return self._queue.qsize()
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
if self._running:
|
||||||
|
return
|
||||||
|
self._running = True
|
||||||
|
self._worker_task = asyncio.create_task(self._worker_loop(), name="billing-worker")
|
||||||
|
logger.info("BillingTaskManager worker 已启动")
|
||||||
|
|
||||||
|
async def shutdown(self, timeout: float = 10.0):
|
||||||
|
if not self._running or self._worker_task is None:
|
||||||
|
return
|
||||||
|
self._running = False
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._queue.join(), timeout=timeout)
|
||||||
|
self._worker_task.cancel()
|
||||||
|
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||||
|
logger.info("BillingTaskManager 已优雅关闭")
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning("BillingTaskManager 关闭超时,剩余 %d 个任务", self._queue.qsize())
|
||||||
|
self._worker_task.cancel()
|
||||||
|
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||||
|
|
||||||
|
async def enqueue(self, api_key: str, api_secret: str, api_name: str, usage_count: int):
|
||||||
|
self._queue.put_nowait(_BillingTask(
|
||||||
|
api_key=api_key,
|
||||||
|
api_secret=api_secret,
|
||||||
|
api_name=api_name,
|
||||||
|
usage_count=usage_count,
|
||||||
|
))
|
||||||
|
|
||||||
|
async def _worker_loop(self):
|
||||||
|
while self._running or not self._queue.empty():
|
||||||
|
try:
|
||||||
|
task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
await self._execute_with_retry(task)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("扣费最终失败: [%s] api=%s error=%s", task.task_id, task.api_name, e)
|
||||||
|
finally:
|
||||||
|
self._queue.task_done()
|
||||||
|
|
||||||
|
async def _execute_with_retry(self, task: _BillingTask):
|
||||||
|
last_error = None
|
||||||
|
for attempt in range(1, task.max_retries + 1):
|
||||||
|
try:
|
||||||
|
result = await self._do_deduct(task)
|
||||||
|
if result and result.get("success"):
|
||||||
|
logger.info(
|
||||||
|
"扣费成功: [%s] api=%s count=%d attempt=%d/%d",
|
||||||
|
task.task_id, task.api_name, task.usage_count, attempt, task.max_retries,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
last_error = result.get("message", "扣费接口返回失败") if result else "扣费接口无响应"
|
||||||
|
except Exception as e:
|
||||||
|
last_error = str(e)
|
||||||
|
if attempt < task.max_retries:
|
||||||
|
await asyncio.sleep(2 ** (attempt - 1))
|
||||||
|
raise RuntimeError(f"扣费最终失败: {last_error}")
|
||||||
|
|
||||||
|
async def _do_deduct(self, task: _BillingTask) -> dict:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self._platform_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||||
|
headers={"Authorization": f"token {self._platform_key}:{self._platform_secret}"},
|
||||||
|
json={
|
||||||
|
"api_key": task.api_key,
|
||||||
|
"api_secret": task.api_secret,
|
||||||
|
"api_name": task.api_name,
|
||||||
|
"usage_count": task.usage_count,
|
||||||
|
},
|
||||||
|
) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
raise RuntimeError(f"扣费服务返回 HTTP {resp.status}")
|
||||||
|
result = await resp.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
return result
|
||||||
|
|
||||||
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||||
"""验证API密钥和团队余额"""
|
"""验证API密钥和团队余额"""
|
||||||
@ -33,33 +144,11 @@ async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||||
|
|
||||||
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
billing_manager = BillingTaskManager(
|
||||||
"""从Jingrow平台扣除API使用费"""
|
platform_url=settings.jingrow_api_url,
|
||||||
try:
|
platform_key=settings.jingrow_api_key,
|
||||||
async with aiohttp.ClientSession() as session:
|
platform_secret=settings.jingrow_api_secret,
|
||||||
async with session.post(
|
)
|
||||||
f"{settings.jingrow_api_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
|
||||||
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
|
||||||
json={
|
|
||||||
"api_key": api_key,
|
|
||||||
"api_secret": api_secret,
|
|
||||||
"api_name": api_name,
|
|
||||||
"usage_count": usage_count
|
|
||||||
}
|
|
||||||
) as response:
|
|
||||||
if response.status != 200:
|
|
||||||
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
|
||||||
|
|
||||||
result = await response.json()
|
|
||||||
if "message" in result and isinstance(result["message"], dict):
|
|
||||||
result = result["message"]
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
|
||||||
|
|
||||||
def get_token_from_request(request) -> str:
|
def get_token_from_request(request) -> str:
|
||||||
"""从请求中获取访问令牌"""
|
"""从请求中获取访问令牌"""
|
||||||
@ -122,7 +211,7 @@ def jingrow_api_verify_and_billing(api_name: str):
|
|||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
if success_count > 0:
|
if success_count > 0:
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, success_count)
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
wrapped_generator(),
|
wrapped_generator(),
|
||||||
@ -132,10 +221,10 @@ def jingrow_api_verify_and_billing(api_name: str):
|
|||||||
|
|
||||||
if isinstance(result, dict) and result.get("success") is True:
|
if isinstance(result, dict) and result.get("success") is True:
|
||||||
actual_usage_count = result.get("successful_count", usage_count)
|
actual_usage_count = result.get("successful_count", usage_count)
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, actual_usage_count)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, usage_count)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
|
|||||||
@ -1,11 +1,23 @@
|
|||||||
|
from contextlib import asynccontextmanager
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from settings import settings
|
from settings import settings
|
||||||
from api import router
|
from api import router
|
||||||
|
from utils import billing_manager
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""应用生命周期管理"""
|
||||||
|
await billing_manager.start()
|
||||||
|
yield
|
||||||
|
await billing_manager.shutdown()
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="JEmbedding",
|
title="JEmbedding",
|
||||||
description="文本向量化服务",
|
description="文本向量化服务",
|
||||||
version="1.0.0"
|
version="1.0.0",
|
||||||
|
lifespan=lifespan
|
||||||
)
|
)
|
||||||
|
|
||||||
app.include_router(router)
|
app.include_router(router)
|
||||||
|
|||||||
@ -2,9 +2,120 @@ import aiohttp
|
|||||||
from functools import wraps
|
from functools import wraps
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from typing import Callable, Any, Dict
|
from typing import Callable, Any, Dict
|
||||||
from settings import settings
|
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
import json
|
import json
|
||||||
|
from settings import settings
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _BillingTask:
|
||||||
|
api_key: str
|
||||||
|
api_secret: str
|
||||||
|
api_name: str
|
||||||
|
usage_count: int
|
||||||
|
max_retries: int = 3
|
||||||
|
created_at: float = field(default_factory=asyncio.get_event_loop().time)
|
||||||
|
task_id: str = field(default_factory=lambda: hex(id(object()))[2:])
|
||||||
|
|
||||||
|
|
||||||
|
class BillingTaskManager:
|
||||||
|
"""异步扣费任务管理器(带指数退避重试)"""
|
||||||
|
|
||||||
|
def __init__(self, platform_url: str, platform_key: str, platform_secret: str):
|
||||||
|
self._platform_url = platform_url.rstrip("/")
|
||||||
|
self._platform_key = platform_key
|
||||||
|
self._platform_secret = platform_secret
|
||||||
|
self._queue: asyncio.Queue = asyncio.Queue()
|
||||||
|
self._worker_task = None
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pending_count(self) -> int:
|
||||||
|
return self._queue.qsize()
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
if self._running:
|
||||||
|
return
|
||||||
|
self._running = True
|
||||||
|
self._worker_task = asyncio.create_task(self._worker_loop(), name="billing-worker")
|
||||||
|
logger.info("BillingTaskManager worker 已启动")
|
||||||
|
|
||||||
|
async def shutdown(self, timeout: float = 10.0):
|
||||||
|
if not self._running or self._worker_task is None:
|
||||||
|
return
|
||||||
|
self._running = False
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._queue.join(), timeout=timeout)
|
||||||
|
self._worker_task.cancel()
|
||||||
|
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||||
|
logger.info("BillingTaskManager 已优雅关闭")
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning("BillingTaskManager 关闭超时,剩余 %d 个任务", self._queue.qsize())
|
||||||
|
self._worker_task.cancel()
|
||||||
|
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||||
|
|
||||||
|
async def enqueue(self, api_key: str, api_secret: str, api_name: str, usage_count: int):
|
||||||
|
self._queue.put_nowait(_BillingTask(
|
||||||
|
api_key=api_key,
|
||||||
|
api_secret=api_secret,
|
||||||
|
api_name=api_name,
|
||||||
|
usage_count=usage_count,
|
||||||
|
))
|
||||||
|
|
||||||
|
async def _worker_loop(self):
|
||||||
|
while self._running or not self._queue.empty():
|
||||||
|
try:
|
||||||
|
task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
await self._execute_with_retry(task)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("扣费最终失败: [%s] api=%s error=%s", task.task_id, task.api_name, e)
|
||||||
|
finally:
|
||||||
|
self._queue.task_done()
|
||||||
|
|
||||||
|
async def _execute_with_retry(self, task: _BillingTask):
|
||||||
|
last_error = None
|
||||||
|
for attempt in range(1, task.max_retries + 1):
|
||||||
|
try:
|
||||||
|
result = await self._do_deduct(task)
|
||||||
|
if result and result.get("success"):
|
||||||
|
logger.info(
|
||||||
|
"扣费成功: [%s] api=%s count=%d attempt=%d/%d",
|
||||||
|
task.task_id, task.api_name, task.usage_count, attempt, task.max_retries,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
last_error = result.get("message", "扣费接口返回失败") if result else "扣费接口无响应"
|
||||||
|
except Exception as e:
|
||||||
|
last_error = str(e)
|
||||||
|
if attempt < task.max_retries:
|
||||||
|
await asyncio.sleep(2 ** (attempt - 1))
|
||||||
|
raise RuntimeError(f"扣费最终失败: {last_error}")
|
||||||
|
|
||||||
|
async def _do_deduct(self, task: _BillingTask) -> dict:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self._platform_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||||
|
headers={"Authorization": f"token {self._platform_key}:{self._platform_secret}"},
|
||||||
|
json={
|
||||||
|
"api_key": task.api_key,
|
||||||
|
"api_secret": task.api_secret,
|
||||||
|
"api_name": task.api_name,
|
||||||
|
"usage_count": task.usage_count,
|
||||||
|
},
|
||||||
|
) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
raise RuntimeError(f"扣费服务返回 HTTP {resp.status}")
|
||||||
|
result = await resp.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||||
@ -29,29 +140,11 @@ async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_
|
|||||||
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
billing_manager = BillingTaskManager(
|
||||||
try:
|
platform_url=settings.jingrow_api_url,
|
||||||
async with aiohttp.ClientSession() as session:
|
platform_key=settings.jingrow_api_key,
|
||||||
async with session.post(
|
platform_secret=settings.jingrow_api_secret,
|
||||||
f"{settings.jingrow_api_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
)
|
||||||
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
|
||||||
json={
|
|
||||||
"api_key": api_key,
|
|
||||||
"api_secret": api_secret,
|
|
||||||
"api_name": api_name,
|
|
||||||
"usage_count": usage_count
|
|
||||||
}
|
|
||||||
) as response:
|
|
||||||
if response.status != 200:
|
|
||||||
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
|
||||||
result = await response.json()
|
|
||||||
if "message" in result and isinstance(result["message"], dict):
|
|
||||||
result = result["message"]
|
|
||||||
return result
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
|
||||||
|
|
||||||
|
|
||||||
def get_token_from_request(request) -> str:
|
def get_token_from_request(request) -> str:
|
||||||
@ -105,7 +198,7 @@ def jingrow_api_verify_and_billing(api_name: str):
|
|||||||
pass
|
pass
|
||||||
yield chunk
|
yield chunk
|
||||||
if success_count > 0:
|
if success_count > 0:
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, success_count)
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
wrapped_generator(),
|
wrapped_generator(),
|
||||||
@ -115,10 +208,10 @@ def jingrow_api_verify_and_billing(api_name: str):
|
|||||||
|
|
||||||
if isinstance(result, dict) and result.get("success") is True:
|
if isinstance(result, dict) and result.get("success") is True:
|
||||||
actual_usage_count = result.get("successful_count", usage_count)
|
actual_usage_count = result.get("successful_count", usage_count)
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, actual_usage_count)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, usage_count)
|
||||||
return result
|
return result
|
||||||
return wrapper
|
return wrapper
|
||||||
return decorator
|
return decorator
|
||||||
|
|||||||
@ -1,11 +1,23 @@
|
|||||||
|
from contextlib import asynccontextmanager
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from settings import settings
|
from settings import settings
|
||||||
from api import router
|
from api import router
|
||||||
|
from utils import billing_manager
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""应用生命周期管理"""
|
||||||
|
await billing_manager.start()
|
||||||
|
yield
|
||||||
|
await billing_manager.shutdown()
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="Midjourney",
|
title="Midjourney",
|
||||||
description="Midjourney绘画服务API",
|
description="Midjourney绘画服务API",
|
||||||
version="1.0.0"
|
version="1.0.0",
|
||||||
|
lifespan=lifespan
|
||||||
)
|
)
|
||||||
|
|
||||||
# 注册路由
|
# 注册路由
|
||||||
|
|||||||
@ -3,7 +3,6 @@ from functools import wraps
|
|||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
import os
|
import os
|
||||||
from typing import Callable, Any, Dict, Optional, Tuple, List
|
from typing import Callable, Any, Dict, Optional, Tuple, List
|
||||||
from settings import settings
|
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
import json
|
import json
|
||||||
import requests
|
import requests
|
||||||
@ -12,6 +11,118 @@ import re
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from settings import settings
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _BillingTask:
|
||||||
|
api_key: str
|
||||||
|
api_secret: str
|
||||||
|
api_name: str
|
||||||
|
usage_count: int
|
||||||
|
max_retries: int = 3
|
||||||
|
created_at: float = field(default_factory=asyncio.get_event_loop().time)
|
||||||
|
task_id: str = field(default_factory=lambda: hex(id(object()))[2:])
|
||||||
|
|
||||||
|
|
||||||
|
class BillingTaskManager:
|
||||||
|
"""异步扣费任务管理器(带指数退避重试)"""
|
||||||
|
|
||||||
|
def __init__(self, platform_url: str, platform_key: str, platform_secret: str):
|
||||||
|
self._platform_url = platform_url.rstrip("/")
|
||||||
|
self._platform_key = platform_key
|
||||||
|
self._platform_secret = platform_secret
|
||||||
|
self._queue: asyncio.Queue = asyncio.Queue()
|
||||||
|
self._worker_task = None
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pending_count(self) -> int:
|
||||||
|
return self._queue.qsize()
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
if self._running:
|
||||||
|
return
|
||||||
|
self._running = True
|
||||||
|
self._worker_task = asyncio.create_task(self._worker_loop(), name="billing-worker")
|
||||||
|
logger.info("BillingTaskManager worker 已启动")
|
||||||
|
|
||||||
|
async def shutdown(self, timeout: float = 10.0):
|
||||||
|
if not self._running or self._worker_task is None:
|
||||||
|
return
|
||||||
|
self._running = False
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._queue.join(), timeout=timeout)
|
||||||
|
self._worker_task.cancel()
|
||||||
|
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||||
|
logger.info("BillingTaskManager 已优雅关闭")
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning("BillingTaskManager 关闭超时,剩余 %d 个任务", self._queue.qsize())
|
||||||
|
self._worker_task.cancel()
|
||||||
|
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||||
|
|
||||||
|
async def enqueue(self, api_key: str, api_secret: str, api_name: str, usage_count: int):
|
||||||
|
self._queue.put_nowait(_BillingTask(
|
||||||
|
api_key=api_key,
|
||||||
|
api_secret=api_secret,
|
||||||
|
api_name=api_name,
|
||||||
|
usage_count=usage_count,
|
||||||
|
))
|
||||||
|
|
||||||
|
async def _worker_loop(self):
|
||||||
|
while self._running or not self._queue.empty():
|
||||||
|
try:
|
||||||
|
task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
await self._execute_with_retry(task)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("扣费最终失败: [%s] api=%s error=%s", task.task_id, task.api_name, e)
|
||||||
|
finally:
|
||||||
|
self._queue.task_done()
|
||||||
|
|
||||||
|
async def _execute_with_retry(self, task: _BillingTask):
|
||||||
|
last_error = None
|
||||||
|
for attempt in range(1, task.max_retries + 1):
|
||||||
|
try:
|
||||||
|
result = await self._do_deduct(task)
|
||||||
|
if result and result.get("success"):
|
||||||
|
logger.info(
|
||||||
|
"扣费成功: [%s] api=%s count=%d attempt=%d/%d",
|
||||||
|
task.task_id, task.api_name, task.usage_count, attempt, task.max_retries,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
last_error = result.get("message", "扣费接口返回失败") if result else "扣费接口无响应"
|
||||||
|
except Exception as e:
|
||||||
|
last_error = str(e)
|
||||||
|
if attempt < task.max_retries:
|
||||||
|
await asyncio.sleep(2 ** (attempt - 1))
|
||||||
|
raise RuntimeError(f"扣费最终失败: {last_error}")
|
||||||
|
|
||||||
|
async def _do_deduct(self, task: _BillingTask) -> dict:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self._platform_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||||
|
headers={"Authorization": f"token {self._platform_key}:{self._platform_secret}"},
|
||||||
|
json={
|
||||||
|
"api_key": task.api_key,
|
||||||
|
"api_secret": task.api_secret,
|
||||||
|
"api_name": task.api_name,
|
||||||
|
"usage_count": task.usage_count,
|
||||||
|
},
|
||||||
|
) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
raise RuntimeError(f"扣费服务返回 HTTP {resp.status}")
|
||||||
|
result = await resp.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
return result
|
||||||
|
|
||||||
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||||
"""验证API密钥和团队余额"""
|
"""验证API密钥和团队余额"""
|
||||||
@ -39,33 +150,11 @@ async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||||
|
|
||||||
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
billing_manager = BillingTaskManager(
|
||||||
"""从Jingrow平台扣除API使用费"""
|
platform_url=settings.jingrow_api_url,
|
||||||
try:
|
platform_key=settings.jingrow_api_key,
|
||||||
async with aiohttp.ClientSession() as session:
|
platform_secret=settings.jingrow_api_secret,
|
||||||
async with session.post(
|
)
|
||||||
f"{settings.jingrow_api_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
|
||||||
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
|
||||||
json={
|
|
||||||
"api_key": api_key,
|
|
||||||
"api_secret": api_secret,
|
|
||||||
"api_name": api_name,
|
|
||||||
"usage_count": usage_count
|
|
||||||
}
|
|
||||||
) as response:
|
|
||||||
if response.status != 200:
|
|
||||||
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
|
||||||
|
|
||||||
result = await response.json()
|
|
||||||
if "message" in result and isinstance(result["message"], dict):
|
|
||||||
result = result["message"]
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
|
||||||
|
|
||||||
def get_token_from_request(request) -> str:
|
def get_token_from_request(request) -> str:
|
||||||
"""从请求中获取访问令牌"""
|
"""从请求中获取访问令牌"""
|
||||||
@ -128,7 +217,7 @@ def jingrow_api_verify_and_billing(api_name: str):
|
|||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
if success_count > 0:
|
if success_count > 0:
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, success_count)
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
wrapped_generator(),
|
wrapped_generator(),
|
||||||
@ -138,10 +227,10 @@ def jingrow_api_verify_and_billing(api_name: str):
|
|||||||
|
|
||||||
if isinstance(result, dict) and result.get("success") is True:
|
if isinstance(result, dict) and result.get("success") is True:
|
||||||
actual_usage_count = result.get("success_count", usage_count)
|
actual_usage_count = result.get("success_count", usage_count)
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, actual_usage_count)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, usage_count)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
|
|||||||
@ -1,11 +1,23 @@
|
|||||||
|
from contextlib import asynccontextmanager
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from settings import settings
|
from settings import settings
|
||||||
from api import router
|
from api import router
|
||||||
|
from utils import billing_manager
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""应用生命周期管理"""
|
||||||
|
await billing_manager.start()
|
||||||
|
yield
|
||||||
|
await billing_manager.shutdown()
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="Jtranslate",
|
title="Jtranslate",
|
||||||
description="Jtranslate翻译API",
|
description="Jtranslate翻译API",
|
||||||
version="1.0.0"
|
version="1.0.0",
|
||||||
|
lifespan=lifespan
|
||||||
)
|
)
|
||||||
|
|
||||||
# 注册路由
|
# 注册路由
|
||||||
|
|||||||
@ -3,9 +3,120 @@ from functools import wraps
|
|||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
import os
|
import os
|
||||||
from typing import Callable, Any, Dict, Optional, Tuple
|
from typing import Callable, Any, Dict, Optional, Tuple
|
||||||
from settings import settings
|
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
import json
|
import json
|
||||||
|
from settings import settings
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _BillingTask:
|
||||||
|
api_key: str
|
||||||
|
api_secret: str
|
||||||
|
api_name: str
|
||||||
|
usage_count: int
|
||||||
|
max_retries: int = 3
|
||||||
|
created_at: float = field(default_factory=asyncio.get_event_loop().time)
|
||||||
|
task_id: str = field(default_factory=lambda: hex(id(object()))[2:])
|
||||||
|
|
||||||
|
|
||||||
|
class BillingTaskManager:
|
||||||
|
"""异步扣费任务管理器(带指数退避重试)"""
|
||||||
|
|
||||||
|
def __init__(self, platform_url: str, platform_key: str, platform_secret: str):
|
||||||
|
self._platform_url = platform_url.rstrip("/")
|
||||||
|
self._platform_key = platform_key
|
||||||
|
self._platform_secret = platform_secret
|
||||||
|
self._queue: asyncio.Queue = asyncio.Queue()
|
||||||
|
self._worker_task = None
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pending_count(self) -> int:
|
||||||
|
return self._queue.qsize()
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
if self._running:
|
||||||
|
return
|
||||||
|
self._running = True
|
||||||
|
self._worker_task = asyncio.create_task(self._worker_loop(), name="billing-worker")
|
||||||
|
logger.info("BillingTaskManager worker 已启动")
|
||||||
|
|
||||||
|
async def shutdown(self, timeout: float = 10.0):
|
||||||
|
if not self._running or self._worker_task is None:
|
||||||
|
return
|
||||||
|
self._running = False
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._queue.join(), timeout=timeout)
|
||||||
|
self._worker_task.cancel()
|
||||||
|
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||||
|
logger.info("BillingTaskManager 已优雅关闭")
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning("BillingTaskManager 关闭超时,剩余 %d 个任务", self._queue.qsize())
|
||||||
|
self._worker_task.cancel()
|
||||||
|
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||||
|
|
||||||
|
async def enqueue(self, api_key: str, api_secret: str, api_name: str, usage_count: int):
|
||||||
|
self._queue.put_nowait(_BillingTask(
|
||||||
|
api_key=api_key,
|
||||||
|
api_secret=api_secret,
|
||||||
|
api_name=api_name,
|
||||||
|
usage_count=usage_count,
|
||||||
|
))
|
||||||
|
|
||||||
|
async def _worker_loop(self):
|
||||||
|
while self._running or not self._queue.empty():
|
||||||
|
try:
|
||||||
|
task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
await self._execute_with_retry(task)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("扣费最终失败: [%s] api=%s error=%s", task.task_id, task.api_name, e)
|
||||||
|
finally:
|
||||||
|
self._queue.task_done()
|
||||||
|
|
||||||
|
async def _execute_with_retry(self, task: _BillingTask):
|
||||||
|
last_error = None
|
||||||
|
for attempt in range(1, task.max_retries + 1):
|
||||||
|
try:
|
||||||
|
result = await self._do_deduct(task)
|
||||||
|
if result and result.get("success"):
|
||||||
|
logger.info(
|
||||||
|
"扣费成功: [%s] api=%s count=%d attempt=%d/%d",
|
||||||
|
task.task_id, task.api_name, task.usage_count, attempt, task.max_retries,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
last_error = result.get("message", "扣费接口返回失败") if result else "扣费接口无响应"
|
||||||
|
except Exception as e:
|
||||||
|
last_error = str(e)
|
||||||
|
if attempt < task.max_retries:
|
||||||
|
await asyncio.sleep(2 ** (attempt - 1))
|
||||||
|
raise RuntimeError(f"扣费最终失败: {last_error}")
|
||||||
|
|
||||||
|
async def _do_deduct(self, task: _BillingTask) -> dict:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self._platform_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||||
|
headers={"Authorization": f"token {self._platform_key}:{self._platform_secret}"},
|
||||||
|
json={
|
||||||
|
"api_key": task.api_key,
|
||||||
|
"api_secret": task.api_secret,
|
||||||
|
"api_name": task.api_name,
|
||||||
|
"usage_count": task.usage_count,
|
||||||
|
},
|
||||||
|
) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
raise RuntimeError(f"扣费服务返回 HTTP {resp.status}")
|
||||||
|
result = await resp.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
return result
|
||||||
|
|
||||||
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||||
"""验证API密钥和团队余额"""
|
"""验证API密钥和团队余额"""
|
||||||
@ -33,33 +144,11 @@ async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||||
|
|
||||||
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
billing_manager = BillingTaskManager(
|
||||||
"""从Jingrow平台扣除API使用费"""
|
platform_url=settings.jingrow_api_url,
|
||||||
try:
|
platform_key=settings.jingrow_api_key,
|
||||||
async with aiohttp.ClientSession() as session:
|
platform_secret=settings.jingrow_api_secret,
|
||||||
async with session.post(
|
)
|
||||||
f"{settings.jingrow_api_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
|
||||||
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
|
||||||
json={
|
|
||||||
"api_key": api_key,
|
|
||||||
"api_secret": api_secret,
|
|
||||||
"api_name": api_name,
|
|
||||||
"usage_count": usage_count
|
|
||||||
}
|
|
||||||
) as response:
|
|
||||||
if response.status != 200:
|
|
||||||
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
|
||||||
|
|
||||||
result = await response.json()
|
|
||||||
if "message" in result and isinstance(result["message"], dict):
|
|
||||||
result = result["message"]
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
|
||||||
|
|
||||||
def get_token_from_request(request) -> str:
|
def get_token_from_request(request) -> str:
|
||||||
"""从请求中获取访问令牌"""
|
"""从请求中获取访问令牌"""
|
||||||
@ -122,7 +211,7 @@ def jingrow_api_verify_and_billing(api_name: str):
|
|||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
if success_count > 0:
|
if success_count > 0:
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, success_count)
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
wrapped_generator(),
|
wrapped_generator(),
|
||||||
@ -132,10 +221,10 @@ def jingrow_api_verify_and_billing(api_name: str):
|
|||||||
|
|
||||||
if isinstance(result, dict) and result.get("success") is True:
|
if isinstance(result, dict) and result.get("success") is True:
|
||||||
actual_usage_count = result.get("successful_count", usage_count)
|
actual_usage_count = result.get("successful_count", usage_count)
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, actual_usage_count)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, usage_count)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
|
|||||||
@ -1,11 +1,23 @@
|
|||||||
|
from contextlib import asynccontextmanager
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from settings import settings
|
from settings import settings
|
||||||
from api import router
|
from api import router
|
||||||
|
from utils import billing_manager
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""应用生命周期管理"""
|
||||||
|
await billing_manager.start()
|
||||||
|
yield
|
||||||
|
await billing_manager.shutdown()
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="Jupscale",
|
title="Jupscale",
|
||||||
description="Jupscale放大图片API",
|
description="Jupscale放大图片API",
|
||||||
version="1.0.0"
|
version="1.0.0",
|
||||||
|
lifespan=lifespan
|
||||||
)
|
)
|
||||||
|
|
||||||
# 注册路由
|
# 注册路由
|
||||||
|
|||||||
@ -3,9 +3,120 @@ from functools import wraps
|
|||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
import os
|
import os
|
||||||
from typing import Callable, Any, Dict, Optional, Tuple
|
from typing import Callable, Any, Dict, Optional, Tuple
|
||||||
from settings import settings
|
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
import json
|
import json
|
||||||
|
from settings import settings
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _BillingTask:
|
||||||
|
api_key: str
|
||||||
|
api_secret: str
|
||||||
|
api_name: str
|
||||||
|
usage_count: int
|
||||||
|
max_retries: int = 3
|
||||||
|
created_at: float = field(default_factory=asyncio.get_event_loop().time)
|
||||||
|
task_id: str = field(default_factory=lambda: hex(id(object()))[2:])
|
||||||
|
|
||||||
|
|
||||||
|
class BillingTaskManager:
|
||||||
|
"""异步扣费任务管理器(带指数退避重试)"""
|
||||||
|
|
||||||
|
def __init__(self, platform_url: str, platform_key: str, platform_secret: str):
|
||||||
|
self._platform_url = platform_url.rstrip("/")
|
||||||
|
self._platform_key = platform_key
|
||||||
|
self._platform_secret = platform_secret
|
||||||
|
self._queue: asyncio.Queue = asyncio.Queue()
|
||||||
|
self._worker_task = None
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pending_count(self) -> int:
|
||||||
|
return self._queue.qsize()
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
if self._running:
|
||||||
|
return
|
||||||
|
self._running = True
|
||||||
|
self._worker_task = asyncio.create_task(self._worker_loop(), name="billing-worker")
|
||||||
|
logger.info("BillingTaskManager worker 已启动")
|
||||||
|
|
||||||
|
async def shutdown(self, timeout: float = 10.0):
|
||||||
|
if not self._running or self._worker_task is None:
|
||||||
|
return
|
||||||
|
self._running = False
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._queue.join(), timeout=timeout)
|
||||||
|
self._worker_task.cancel()
|
||||||
|
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||||
|
logger.info("BillingTaskManager 已优雅关闭")
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning("BillingTaskManager 关闭超时,剩余 %d 个任务", self._queue.qsize())
|
||||||
|
self._worker_task.cancel()
|
||||||
|
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||||
|
|
||||||
|
async def enqueue(self, api_key: str, api_secret: str, api_name: str, usage_count: int):
|
||||||
|
self._queue.put_nowait(_BillingTask(
|
||||||
|
api_key=api_key,
|
||||||
|
api_secret=api_secret,
|
||||||
|
api_name=api_name,
|
||||||
|
usage_count=usage_count,
|
||||||
|
))
|
||||||
|
|
||||||
|
async def _worker_loop(self):
|
||||||
|
while self._running or not self._queue.empty():
|
||||||
|
try:
|
||||||
|
task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
await self._execute_with_retry(task)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("扣费最终失败: [%s] api=%s error=%s", task.task_id, task.api_name, e)
|
||||||
|
finally:
|
||||||
|
self._queue.task_done()
|
||||||
|
|
||||||
|
async def _execute_with_retry(self, task: _BillingTask):
|
||||||
|
last_error = None
|
||||||
|
for attempt in range(1, task.max_retries + 1):
|
||||||
|
try:
|
||||||
|
result = await self._do_deduct(task)
|
||||||
|
if result and result.get("success"):
|
||||||
|
logger.info(
|
||||||
|
"扣费成功: [%s] api=%s count=%d attempt=%d/%d",
|
||||||
|
task.task_id, task.api_name, task.usage_count, attempt, task.max_retries,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
last_error = result.get("message", "扣费接口返回失败") if result else "扣费接口无响应"
|
||||||
|
except Exception as e:
|
||||||
|
last_error = str(e)
|
||||||
|
if attempt < task.max_retries:
|
||||||
|
await asyncio.sleep(2 ** (attempt - 1))
|
||||||
|
raise RuntimeError(f"扣费最终失败: {last_error}")
|
||||||
|
|
||||||
|
async def _do_deduct(self, task: _BillingTask) -> dict:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self._platform_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||||
|
headers={"Authorization": f"token {self._platform_key}:{self._platform_secret}"},
|
||||||
|
json={
|
||||||
|
"api_key": task.api_key,
|
||||||
|
"api_secret": task.api_secret,
|
||||||
|
"api_name": task.api_name,
|
||||||
|
"usage_count": task.usage_count,
|
||||||
|
},
|
||||||
|
) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
raise RuntimeError(f"扣费服务返回 HTTP {resp.status}")
|
||||||
|
result = await resp.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
return result
|
||||||
|
|
||||||
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||||
"""验证API密钥和团队余额"""
|
"""验证API密钥和团队余额"""
|
||||||
@ -33,33 +144,11 @@ async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||||
|
|
||||||
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
billing_manager = BillingTaskManager(
|
||||||
"""从Jingrow平台扣除API使用费"""
|
platform_url=settings.jingrow_api_url,
|
||||||
try:
|
platform_key=settings.jingrow_api_key,
|
||||||
async with aiohttp.ClientSession() as session:
|
platform_secret=settings.jingrow_api_secret,
|
||||||
async with session.post(
|
)
|
||||||
f"{settings.jingrow_api_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
|
||||||
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
|
||||||
json={
|
|
||||||
"api_key": api_key,
|
|
||||||
"api_secret": api_secret,
|
|
||||||
"api_name": api_name,
|
|
||||||
"usage_count": usage_count
|
|
||||||
}
|
|
||||||
) as response:
|
|
||||||
if response.status != 200:
|
|
||||||
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
|
||||||
|
|
||||||
result = await response.json()
|
|
||||||
if "message" in result and isinstance(result["message"], dict):
|
|
||||||
result = result["message"]
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
|
||||||
|
|
||||||
def get_token_from_request(request) -> str:
|
def get_token_from_request(request) -> str:
|
||||||
"""从请求中获取访问令牌"""
|
"""从请求中获取访问令牌"""
|
||||||
@ -122,7 +211,7 @@ def jingrow_api_verify_and_billing(api_name: str):
|
|||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
if success_count > 0:
|
if success_count > 0:
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, success_count)
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
wrapped_generator(),
|
wrapped_generator(),
|
||||||
@ -132,10 +221,10 @@ def jingrow_api_verify_and_billing(api_name: str):
|
|||||||
|
|
||||||
if isinstance(result, dict) and result.get("success") is True:
|
if isinstance(result, dict) and result.get("success") is True:
|
||||||
actual_usage_count = result.get("successful_count", usage_count)
|
actual_usage_count = result.get("successful_count", usage_count)
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, actual_usage_count)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, usage_count)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
|
|||||||
@ -1,11 +1,23 @@
|
|||||||
|
from contextlib import asynccontextmanager
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from settings import settings
|
from settings import settings
|
||||||
from api import router
|
from api import router
|
||||||
|
from utils import billing_manager
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""应用生命周期管理"""
|
||||||
|
await billing_manager.start()
|
||||||
|
yield
|
||||||
|
await billing_manager.shutdown()
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="Jvector",
|
title="Jvector",
|
||||||
description="Jvector转矢量图API",
|
description="Jvector转矢量图API",
|
||||||
version="1.0.0"
|
version="1.0.0",
|
||||||
|
lifespan=lifespan
|
||||||
)
|
)
|
||||||
|
|
||||||
# 注册路由
|
# 注册路由
|
||||||
|
|||||||
@ -3,9 +3,120 @@ from functools import wraps
|
|||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
import os
|
import os
|
||||||
from typing import Callable, Any, Dict, Optional, Tuple
|
from typing import Callable, Any, Dict, Optional, Tuple
|
||||||
from settings import settings
|
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
import json
|
import json
|
||||||
|
from settings import settings
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _BillingTask:
|
||||||
|
api_key: str
|
||||||
|
api_secret: str
|
||||||
|
api_name: str
|
||||||
|
usage_count: int
|
||||||
|
max_retries: int = 3
|
||||||
|
created_at: float = field(default_factory=asyncio.get_event_loop().time)
|
||||||
|
task_id: str = field(default_factory=lambda: hex(id(object()))[2:])
|
||||||
|
|
||||||
|
|
||||||
|
class BillingTaskManager:
|
||||||
|
"""异步扣费任务管理器(带指数退避重试)"""
|
||||||
|
|
||||||
|
def __init__(self, platform_url: str, platform_key: str, platform_secret: str):
|
||||||
|
self._platform_url = platform_url.rstrip("/")
|
||||||
|
self._platform_key = platform_key
|
||||||
|
self._platform_secret = platform_secret
|
||||||
|
self._queue: asyncio.Queue = asyncio.Queue()
|
||||||
|
self._worker_task = None
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pending_count(self) -> int:
|
||||||
|
return self._queue.qsize()
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
if self._running:
|
||||||
|
return
|
||||||
|
self._running = True
|
||||||
|
self._worker_task = asyncio.create_task(self._worker_loop(), name="billing-worker")
|
||||||
|
logger.info("BillingTaskManager worker 已启动")
|
||||||
|
|
||||||
|
async def shutdown(self, timeout: float = 10.0):
|
||||||
|
if not self._running or self._worker_task is None:
|
||||||
|
return
|
||||||
|
self._running = False
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._queue.join(), timeout=timeout)
|
||||||
|
self._worker_task.cancel()
|
||||||
|
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||||
|
logger.info("BillingTaskManager 已优雅关闭")
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning("BillingTaskManager 关闭超时,剩余 %d 个任务", self._queue.qsize())
|
||||||
|
self._worker_task.cancel()
|
||||||
|
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||||
|
|
||||||
|
async def enqueue(self, api_key: str, api_secret: str, api_name: str, usage_count: int):
|
||||||
|
self._queue.put_nowait(_BillingTask(
|
||||||
|
api_key=api_key,
|
||||||
|
api_secret=api_secret,
|
||||||
|
api_name=api_name,
|
||||||
|
usage_count=usage_count,
|
||||||
|
))
|
||||||
|
|
||||||
|
async def _worker_loop(self):
|
||||||
|
while self._running or not self._queue.empty():
|
||||||
|
try:
|
||||||
|
task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
await self._execute_with_retry(task)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("扣费最终失败: [%s] api=%s error=%s", task.task_id, task.api_name, e)
|
||||||
|
finally:
|
||||||
|
self._queue.task_done()
|
||||||
|
|
||||||
|
async def _execute_with_retry(self, task: _BillingTask):
|
||||||
|
last_error = None
|
||||||
|
for attempt in range(1, task.max_retries + 1):
|
||||||
|
try:
|
||||||
|
result = await self._do_deduct(task)
|
||||||
|
if result and result.get("success"):
|
||||||
|
logger.info(
|
||||||
|
"扣费成功: [%s] api=%s count=%d attempt=%d/%d",
|
||||||
|
task.task_id, task.api_name, task.usage_count, attempt, task.max_retries,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
last_error = result.get("message", "扣费接口返回失败") if result else "扣费接口无响应"
|
||||||
|
except Exception as e:
|
||||||
|
last_error = str(e)
|
||||||
|
if attempt < task.max_retries:
|
||||||
|
await asyncio.sleep(2 ** (attempt - 1))
|
||||||
|
raise RuntimeError(f"扣费最终失败: {last_error}")
|
||||||
|
|
||||||
|
async def _do_deduct(self, task: _BillingTask) -> dict:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self._platform_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||||
|
headers={"Authorization": f"token {self._platform_key}:{self._platform_secret}"},
|
||||||
|
json={
|
||||||
|
"api_key": task.api_key,
|
||||||
|
"api_secret": task.api_secret,
|
||||||
|
"api_name": task.api_name,
|
||||||
|
"usage_count": task.usage_count,
|
||||||
|
},
|
||||||
|
) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
raise RuntimeError(f"扣费服务返回 HTTP {resp.status}")
|
||||||
|
result = await resp.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
return result
|
||||||
|
|
||||||
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||||
"""验证API密钥和团队余额"""
|
"""验证API密钥和团队余额"""
|
||||||
@ -33,33 +144,11 @@ async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||||
|
|
||||||
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
billing_manager = BillingTaskManager(
|
||||||
"""从Jingrow平台扣除API使用费"""
|
platform_url=settings.jingrow_api_url,
|
||||||
try:
|
platform_key=settings.jingrow_api_key,
|
||||||
async with aiohttp.ClientSession() as session:
|
platform_secret=settings.jingrow_api_secret,
|
||||||
async with session.post(
|
)
|
||||||
f"{settings.jingrow_api_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
|
||||||
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
|
||||||
json={
|
|
||||||
"api_key": api_key,
|
|
||||||
"api_secret": api_secret,
|
|
||||||
"api_name": api_name,
|
|
||||||
"usage_count": usage_count
|
|
||||||
}
|
|
||||||
) as response:
|
|
||||||
if response.status != 200:
|
|
||||||
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
|
||||||
|
|
||||||
result = await response.json()
|
|
||||||
if "message" in result and isinstance(result["message"], dict):
|
|
||||||
result = result["message"]
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
|
||||||
|
|
||||||
def get_token_from_request(request) -> str:
|
def get_token_from_request(request) -> str:
|
||||||
"""从请求中获取访问令牌"""
|
"""从请求中获取访问令牌"""
|
||||||
@ -122,7 +211,7 @@ def jingrow_api_verify_and_billing(api_name: str):
|
|||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
if success_count > 0:
|
if success_count > 0:
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, success_count)
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
wrapped_generator(),
|
wrapped_generator(),
|
||||||
@ -132,10 +221,10 @@ def jingrow_api_verify_and_billing(api_name: str):
|
|||||||
|
|
||||||
if isinstance(result, dict) and result.get("success") is True:
|
if isinstance(result, dict) and result.get("success") is True:
|
||||||
actual_usage_count = result.get("successful_count", usage_count)
|
actual_usage_count = result.get("successful_count", usage_count)
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, actual_usage_count)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, usage_count)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
|
|||||||
@ -1,11 +1,23 @@
|
|||||||
|
from contextlib import asynccontextmanager
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from settings import settings
|
from settings import settings
|
||||||
from api import router
|
from api import router
|
||||||
|
from utils import billing_manager
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""应用生命周期管理"""
|
||||||
|
await billing_manager.start()
|
||||||
|
yield
|
||||||
|
await billing_manager.shutdown()
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="Midjourney",
|
title="Midjourney",
|
||||||
description="Midjourney绘画服务API",
|
description="Midjourney绘画服务API",
|
||||||
version="1.0.0"
|
version="1.0.0",
|
||||||
|
lifespan=lifespan
|
||||||
)
|
)
|
||||||
|
|
||||||
# 注册路由
|
# 注册路由
|
||||||
|
|||||||
@ -3,7 +3,6 @@ from functools import wraps
|
|||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
import os
|
import os
|
||||||
from typing import Callable, Any, Dict, Optional, Tuple, List
|
from typing import Callable, Any, Dict, Optional, Tuple, List
|
||||||
from settings import settings
|
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
import json
|
import json
|
||||||
import requests
|
import requests
|
||||||
@ -12,6 +11,118 @@ import re
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from settings import settings
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _BillingTask:
|
||||||
|
api_key: str
|
||||||
|
api_secret: str
|
||||||
|
api_name: str
|
||||||
|
usage_count: int
|
||||||
|
max_retries: int = 3
|
||||||
|
created_at: float = field(default_factory=asyncio.get_event_loop().time)
|
||||||
|
task_id: str = field(default_factory=lambda: hex(id(object()))[2:])
|
||||||
|
|
||||||
|
|
||||||
|
class BillingTaskManager:
|
||||||
|
"""异步扣费任务管理器(带指数退避重试)"""
|
||||||
|
|
||||||
|
def __init__(self, platform_url: str, platform_key: str, platform_secret: str):
|
||||||
|
self._platform_url = platform_url.rstrip("/")
|
||||||
|
self._platform_key = platform_key
|
||||||
|
self._platform_secret = platform_secret
|
||||||
|
self._queue: asyncio.Queue = asyncio.Queue()
|
||||||
|
self._worker_task = None
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pending_count(self) -> int:
|
||||||
|
return self._queue.qsize()
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
if self._running:
|
||||||
|
return
|
||||||
|
self._running = True
|
||||||
|
self._worker_task = asyncio.create_task(self._worker_loop(), name="billing-worker")
|
||||||
|
logger.info("BillingTaskManager worker 已启动")
|
||||||
|
|
||||||
|
async def shutdown(self, timeout: float = 10.0):
|
||||||
|
if not self._running or self._worker_task is None:
|
||||||
|
return
|
||||||
|
self._running = False
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._queue.join(), timeout=timeout)
|
||||||
|
self._worker_task.cancel()
|
||||||
|
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||||
|
logger.info("BillingTaskManager 已优雅关闭")
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning("BillingTaskManager 关闭超时,剩余 %d 个任务", self._queue.qsize())
|
||||||
|
self._worker_task.cancel()
|
||||||
|
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||||
|
|
||||||
|
async def enqueue(self, api_key: str, api_secret: str, api_name: str, usage_count: int):
|
||||||
|
self._queue.put_nowait(_BillingTask(
|
||||||
|
api_key=api_key,
|
||||||
|
api_secret=api_secret,
|
||||||
|
api_name=api_name,
|
||||||
|
usage_count=usage_count,
|
||||||
|
))
|
||||||
|
|
||||||
|
async def _worker_loop(self):
|
||||||
|
while self._running or not self._queue.empty():
|
||||||
|
try:
|
||||||
|
task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
await self._execute_with_retry(task)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("扣费最终失败: [%s] api=%s error=%s", task.task_id, task.api_name, e)
|
||||||
|
finally:
|
||||||
|
self._queue.task_done()
|
||||||
|
|
||||||
|
async def _execute_with_retry(self, task: _BillingTask):
|
||||||
|
last_error = None
|
||||||
|
for attempt in range(1, task.max_retries + 1):
|
||||||
|
try:
|
||||||
|
result = await self._do_deduct(task)
|
||||||
|
if result and result.get("success"):
|
||||||
|
logger.info(
|
||||||
|
"扣费成功: [%s] api=%s count=%d attempt=%d/%d",
|
||||||
|
task.task_id, task.api_name, task.usage_count, attempt, task.max_retries,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
last_error = result.get("message", "扣费接口返回失败") if result else "扣费接口无响应"
|
||||||
|
except Exception as e:
|
||||||
|
last_error = str(e)
|
||||||
|
if attempt < task.max_retries:
|
||||||
|
await asyncio.sleep(2 ** (attempt - 1))
|
||||||
|
raise RuntimeError(f"扣费最终失败: {last_error}")
|
||||||
|
|
||||||
|
async def _do_deduct(self, task: _BillingTask) -> dict:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self._platform_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||||
|
headers={"Authorization": f"token {self._platform_key}:{self._platform_secret}"},
|
||||||
|
json={
|
||||||
|
"api_key": task.api_key,
|
||||||
|
"api_secret": task.api_secret,
|
||||||
|
"api_name": task.api_name,
|
||||||
|
"usage_count": task.usage_count,
|
||||||
|
},
|
||||||
|
) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
raise RuntimeError(f"扣费服务返回 HTTP {resp.status}")
|
||||||
|
result = await resp.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
return result
|
||||||
|
|
||||||
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||||
"""验证API密钥和团队余额"""
|
"""验证API密钥和团队余额"""
|
||||||
@ -39,33 +150,11 @@ async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||||
|
|
||||||
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
billing_manager = BillingTaskManager(
|
||||||
"""从Jingrow平台扣除API使用费"""
|
platform_url=settings.jingrow_api_url,
|
||||||
try:
|
platform_key=settings.jingrow_api_key,
|
||||||
async with aiohttp.ClientSession() as session:
|
platform_secret=settings.jingrow_api_secret,
|
||||||
async with session.post(
|
)
|
||||||
f"{settings.jingrow_api_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
|
||||||
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
|
||||||
json={
|
|
||||||
"api_key": api_key,
|
|
||||||
"api_secret": api_secret,
|
|
||||||
"api_name": api_name,
|
|
||||||
"usage_count": usage_count
|
|
||||||
}
|
|
||||||
) as response:
|
|
||||||
if response.status != 200:
|
|
||||||
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
|
||||||
|
|
||||||
result = await response.json()
|
|
||||||
if "message" in result and isinstance(result["message"], dict):
|
|
||||||
result = result["message"]
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
|
||||||
|
|
||||||
def get_token_from_request(request) -> str:
|
def get_token_from_request(request) -> str:
|
||||||
"""从请求中获取访问令牌"""
|
"""从请求中获取访问令牌"""
|
||||||
@ -128,7 +217,7 @@ def jingrow_api_verify_and_billing(api_name: str):
|
|||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
if success_count > 0:
|
if success_count > 0:
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, success_count)
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
wrapped_generator(),
|
wrapped_generator(),
|
||||||
@ -138,10 +227,10 @@ def jingrow_api_verify_and_billing(api_name: str):
|
|||||||
|
|
||||||
if isinstance(result, dict) and result.get("success") is True:
|
if isinstance(result, dict) and result.get("success") is True:
|
||||||
actual_usage_count = result.get("success_count", usage_count)
|
actual_usage_count = result.get("success_count", usage_count)
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, actual_usage_count)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, usage_count)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
|
|||||||
@ -1,11 +1,23 @@
|
|||||||
|
from contextlib import asynccontextmanager
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from settings import settings
|
from settings import settings
|
||||||
from api import router
|
from api import router
|
||||||
|
from utils import billing_manager
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""应用生命周期管理"""
|
||||||
|
await billing_manager.start()
|
||||||
|
yield
|
||||||
|
await billing_manager.shutdown()
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="Pattern to Tshirt",
|
title="Pattern to Tshirt",
|
||||||
description="将图片中的花型添加到T恤上",
|
description="将图片中的花型添加到T恤上",
|
||||||
version="1.0.0"
|
version="1.0.0",
|
||||||
|
lifespan=lifespan
|
||||||
)
|
)
|
||||||
|
|
||||||
# 注册路由
|
# 注册路由
|
||||||
|
|||||||
@ -3,9 +3,120 @@ from functools import wraps
|
|||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
import os
|
import os
|
||||||
from typing import Callable, Any, Dict, Optional, Tuple
|
from typing import Callable, Any, Dict, Optional, Tuple
|
||||||
from settings import settings
|
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
import json
|
import json
|
||||||
|
from settings import settings
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _BillingTask:
|
||||||
|
api_key: str
|
||||||
|
api_secret: str
|
||||||
|
api_name: str
|
||||||
|
usage_count: int
|
||||||
|
max_retries: int = 3
|
||||||
|
created_at: float = field(default_factory=asyncio.get_event_loop().time)
|
||||||
|
task_id: str = field(default_factory=lambda: hex(id(object()))[2:])
|
||||||
|
|
||||||
|
|
||||||
|
class BillingTaskManager:
|
||||||
|
"""异步扣费任务管理器(带指数退避重试)"""
|
||||||
|
|
||||||
|
def __init__(self, platform_url: str, platform_key: str, platform_secret: str):
|
||||||
|
self._platform_url = platform_url.rstrip("/")
|
||||||
|
self._platform_key = platform_key
|
||||||
|
self._platform_secret = platform_secret
|
||||||
|
self._queue: asyncio.Queue = asyncio.Queue()
|
||||||
|
self._worker_task = None
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pending_count(self) -> int:
|
||||||
|
return self._queue.qsize()
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
if self._running:
|
||||||
|
return
|
||||||
|
self._running = True
|
||||||
|
self._worker_task = asyncio.create_task(self._worker_loop(), name="billing-worker")
|
||||||
|
logger.info("BillingTaskManager worker 已启动")
|
||||||
|
|
||||||
|
async def shutdown(self, timeout: float = 10.0):
|
||||||
|
if not self._running or self._worker_task is None:
|
||||||
|
return
|
||||||
|
self._running = False
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._queue.join(), timeout=timeout)
|
||||||
|
self._worker_task.cancel()
|
||||||
|
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||||
|
logger.info("BillingTaskManager 已优雅关闭")
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning("BillingTaskManager 关闭超时,剩余 %d 个任务", self._queue.qsize())
|
||||||
|
self._worker_task.cancel()
|
||||||
|
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||||
|
|
||||||
|
async def enqueue(self, api_key: str, api_secret: str, api_name: str, usage_count: int):
|
||||||
|
self._queue.put_nowait(_BillingTask(
|
||||||
|
api_key=api_key,
|
||||||
|
api_secret=api_secret,
|
||||||
|
api_name=api_name,
|
||||||
|
usage_count=usage_count,
|
||||||
|
))
|
||||||
|
|
||||||
|
async def _worker_loop(self):
|
||||||
|
while self._running or not self._queue.empty():
|
||||||
|
try:
|
||||||
|
task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
await self._execute_with_retry(task)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("扣费最终失败: [%s] api=%s error=%s", task.task_id, task.api_name, e)
|
||||||
|
finally:
|
||||||
|
self._queue.task_done()
|
||||||
|
|
||||||
|
async def _execute_with_retry(self, task: _BillingTask):
|
||||||
|
last_error = None
|
||||||
|
for attempt in range(1, task.max_retries + 1):
|
||||||
|
try:
|
||||||
|
result = await self._do_deduct(task)
|
||||||
|
if result and result.get("success"):
|
||||||
|
logger.info(
|
||||||
|
"扣费成功: [%s] api=%s count=%d attempt=%d/%d",
|
||||||
|
task.task_id, task.api_name, task.usage_count, attempt, task.max_retries,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
last_error = result.get("message", "扣费接口返回失败") if result else "扣费接口无响应"
|
||||||
|
except Exception as e:
|
||||||
|
last_error = str(e)
|
||||||
|
if attempt < task.max_retries:
|
||||||
|
await asyncio.sleep(2 ** (attempt - 1))
|
||||||
|
raise RuntimeError(f"扣费最终失败: {last_error}")
|
||||||
|
|
||||||
|
async def _do_deduct(self, task: _BillingTask) -> dict:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self._platform_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||||
|
headers={"Authorization": f"token {self._platform_key}:{self._platform_secret}"},
|
||||||
|
json={
|
||||||
|
"api_key": task.api_key,
|
||||||
|
"api_secret": task.api_secret,
|
||||||
|
"api_name": task.api_name,
|
||||||
|
"usage_count": task.usage_count,
|
||||||
|
},
|
||||||
|
) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
raise RuntimeError(f"扣费服务返回 HTTP {resp.status}")
|
||||||
|
result = await resp.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
return result
|
||||||
|
|
||||||
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||||
"""验证API密钥和团队余额"""
|
"""验证API密钥和团队余额"""
|
||||||
@ -33,33 +144,11 @@ async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||||
|
|
||||||
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
billing_manager = BillingTaskManager(
|
||||||
"""从Jingrow平台扣除API使用费"""
|
platform_url=settings.jingrow_api_url,
|
||||||
try:
|
platform_key=settings.jingrow_api_key,
|
||||||
async with aiohttp.ClientSession() as session:
|
platform_secret=settings.jingrow_api_secret,
|
||||||
async with session.post(
|
)
|
||||||
f"{settings.jingrow_api_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
|
||||||
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
|
||||||
json={
|
|
||||||
"api_key": api_key,
|
|
||||||
"api_secret": api_secret,
|
|
||||||
"api_name": api_name,
|
|
||||||
"usage_count": usage_count
|
|
||||||
}
|
|
||||||
) as response:
|
|
||||||
if response.status != 200:
|
|
||||||
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
|
||||||
|
|
||||||
result = await response.json()
|
|
||||||
if "message" in result and isinstance(result["message"], dict):
|
|
||||||
result = result["message"]
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
|
||||||
|
|
||||||
def get_token_from_request(request) -> str:
|
def get_token_from_request(request) -> str:
|
||||||
"""从请求中获取访问令牌"""
|
"""从请求中获取访问令牌"""
|
||||||
@ -122,7 +211,7 @@ def jingrow_api_verify_and_billing(api_name: str):
|
|||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
if success_count > 0:
|
if success_count > 0:
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, success_count)
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
wrapped_generator(),
|
wrapped_generator(),
|
||||||
@ -132,10 +221,10 @@ def jingrow_api_verify_and_billing(api_name: str):
|
|||||||
|
|
||||||
if isinstance(result, dict) and result.get("success") is True:
|
if isinstance(result, dict) and result.get("success") is True:
|
||||||
actual_usage_count = result.get("successful_count", usage_count)
|
actual_usage_count = result.get("successful_count", usage_count)
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, actual_usage_count)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, usage_count)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
|
|||||||
@ -2,16 +2,19 @@ from contextlib import asynccontextmanager
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from settings import settings
|
from settings import settings
|
||||||
from api import router, service
|
from api import router, service
|
||||||
|
from utils import billing_manager
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
"""应用生命周期管理"""
|
"""应用生命周期管理"""
|
||||||
# 启动时初始化
|
# 启动时初始化
|
||||||
|
await billing_manager.start()
|
||||||
if settings.enable_queue_batch:
|
if settings.enable_queue_batch:
|
||||||
await service._start_queue_processor()
|
await service._start_queue_processor()
|
||||||
yield
|
yield
|
||||||
# 关闭时清理资源
|
# 关闭时清理资源
|
||||||
|
await billing_manager.shutdown()
|
||||||
await service.cleanup()
|
await service.cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -3,9 +3,120 @@ from functools import wraps
|
|||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
import os
|
import os
|
||||||
from typing import Callable, Any, Dict, Optional, Tuple
|
from typing import Callable, Any, Dict, Optional, Tuple
|
||||||
from settings import settings
|
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
import json
|
import json
|
||||||
|
from settings import settings
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _BillingTask:
|
||||||
|
api_key: str
|
||||||
|
api_secret: str
|
||||||
|
api_name: str
|
||||||
|
usage_count: int
|
||||||
|
max_retries: int = 3
|
||||||
|
created_at: float = field(default_factory=asyncio.get_event_loop().time)
|
||||||
|
task_id: str = field(default_factory=lambda: hex(id(object()))[2:])
|
||||||
|
|
||||||
|
|
||||||
|
class BillingTaskManager:
|
||||||
|
"""异步扣费任务管理器(带指数退避重试)"""
|
||||||
|
|
||||||
|
def __init__(self, platform_url: str, platform_key: str, platform_secret: str):
|
||||||
|
self._platform_url = platform_url.rstrip("/")
|
||||||
|
self._platform_key = platform_key
|
||||||
|
self._platform_secret = platform_secret
|
||||||
|
self._queue: asyncio.Queue = asyncio.Queue()
|
||||||
|
self._worker_task = None
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pending_count(self) -> int:
|
||||||
|
return self._queue.qsize()
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
if self._running:
|
||||||
|
return
|
||||||
|
self._running = True
|
||||||
|
self._worker_task = asyncio.create_task(self._worker_loop(), name="billing-worker")
|
||||||
|
logger.info("BillingTaskManager worker 已启动")
|
||||||
|
|
||||||
|
async def shutdown(self, timeout: float = 10.0):
|
||||||
|
if not self._running or self._worker_task is None:
|
||||||
|
return
|
||||||
|
self._running = False
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._queue.join(), timeout=timeout)
|
||||||
|
self._worker_task.cancel()
|
||||||
|
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||||
|
logger.info("BillingTaskManager 已优雅关闭")
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning("BillingTaskManager 关闭超时,剩余 %d 个任务", self._queue.qsize())
|
||||||
|
self._worker_task.cancel()
|
||||||
|
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||||
|
|
||||||
|
async def enqueue(self, api_key: str, api_secret: str, api_name: str, usage_count: int):
|
||||||
|
self._queue.put_nowait(_BillingTask(
|
||||||
|
api_key=api_key,
|
||||||
|
api_secret=api_secret,
|
||||||
|
api_name=api_name,
|
||||||
|
usage_count=usage_count,
|
||||||
|
))
|
||||||
|
|
||||||
|
async def _worker_loop(self):
|
||||||
|
while self._running or not self._queue.empty():
|
||||||
|
try:
|
||||||
|
task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
await self._execute_with_retry(task)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("扣费最终失败: [%s] api=%s error=%s", task.task_id, task.api_name, e)
|
||||||
|
finally:
|
||||||
|
self._queue.task_done()
|
||||||
|
|
||||||
|
async def _execute_with_retry(self, task: _BillingTask):
|
||||||
|
last_error = None
|
||||||
|
for attempt in range(1, task.max_retries + 1):
|
||||||
|
try:
|
||||||
|
result = await self._do_deduct(task)
|
||||||
|
if result and result.get("success"):
|
||||||
|
logger.info(
|
||||||
|
"扣费成功: [%s] api=%s count=%d attempt=%d/%d",
|
||||||
|
task.task_id, task.api_name, task.usage_count, attempt, task.max_retries,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
last_error = result.get("message", "扣费接口返回失败") if result else "扣费接口无响应"
|
||||||
|
except Exception as e:
|
||||||
|
last_error = str(e)
|
||||||
|
if attempt < task.max_retries:
|
||||||
|
await asyncio.sleep(2 ** (attempt - 1))
|
||||||
|
raise RuntimeError(f"扣费最终失败: {last_error}")
|
||||||
|
|
||||||
|
async def _do_deduct(self, task: _BillingTask) -> dict:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self._platform_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||||
|
headers={"Authorization": f"token {self._platform_key}:{self._platform_secret}"},
|
||||||
|
json={
|
||||||
|
"api_key": task.api_key,
|
||||||
|
"api_secret": task.api_secret,
|
||||||
|
"api_name": task.api_name,
|
||||||
|
"usage_count": task.usage_count,
|
||||||
|
},
|
||||||
|
) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
raise RuntimeError(f"扣费服务返回 HTTP {resp.status}")
|
||||||
|
result = await resp.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
return result
|
||||||
|
|
||||||
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||||
"""验证API密钥和团队余额"""
|
"""验证API密钥和团队余额"""
|
||||||
@ -33,33 +144,11 @@ async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||||
|
|
||||||
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
billing_manager = BillingTaskManager(
|
||||||
"""从Jingrow平台扣除API使用费"""
|
platform_url=settings.jingrow_api_url,
|
||||||
try:
|
platform_key=settings.jingrow_api_key,
|
||||||
async with aiohttp.ClientSession() as session:
|
platform_secret=settings.jingrow_api_secret,
|
||||||
async with session.post(
|
)
|
||||||
f"{settings.jingrow_api_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
|
||||||
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
|
||||||
json={
|
|
||||||
"api_key": api_key,
|
|
||||||
"api_secret": api_secret,
|
|
||||||
"api_name": api_name,
|
|
||||||
"usage_count": usage_count
|
|
||||||
}
|
|
||||||
) as response:
|
|
||||||
if response.status != 200:
|
|
||||||
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
|
||||||
|
|
||||||
result = await response.json()
|
|
||||||
if "message" in result and isinstance(result["message"], dict):
|
|
||||||
result = result["message"]
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
|
||||||
|
|
||||||
def get_token_from_request(request) -> str:
|
def get_token_from_request(request) -> str:
|
||||||
"""从请求中获取访问令牌"""
|
"""从请求中获取访问令牌"""
|
||||||
@ -122,7 +211,7 @@ def jingrow_api_verify_and_billing(api_name: str):
|
|||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
if success_count > 0:
|
if success_count > 0:
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, success_count)
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
wrapped_generator(),
|
wrapped_generator(),
|
||||||
@ -132,10 +221,10 @@ def jingrow_api_verify_and_billing(api_name: str):
|
|||||||
|
|
||||||
if isinstance(result, dict) and result.get("success") is True:
|
if isinstance(result, dict) and result.get("success") is True:
|
||||||
actual_usage_count = result.get("successful_count", usage_count)
|
actual_usage_count = result.get("successful_count", usage_count)
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, actual_usage_count)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, usage_count)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
|
|||||||
@ -1,11 +1,23 @@
|
|||||||
|
from contextlib import asynccontextmanager
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from settings import settings
|
from settings import settings
|
||||||
from api import router
|
from api import router
|
||||||
|
from utils import billing_manager
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""应用生命周期管理"""
|
||||||
|
await billing_manager.start()
|
||||||
|
yield
|
||||||
|
await billing_manager.shutdown()
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="Tryon",
|
title="Tryon",
|
||||||
description="虚拟试穿",
|
description="虚拟试穿",
|
||||||
version="1.0.0"
|
version="1.0.0",
|
||||||
|
lifespan=lifespan
|
||||||
)
|
)
|
||||||
|
|
||||||
# 注册路由
|
# 注册路由
|
||||||
|
|||||||
@ -3,9 +3,120 @@ from functools import wraps
|
|||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
import os
|
import os
|
||||||
from typing import Callable, Any, Dict, Optional, Tuple
|
from typing import Callable, Any, Dict, Optional, Tuple
|
||||||
from settings import settings
|
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
import json
|
import json
|
||||||
|
from settings import settings
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _BillingTask:
|
||||||
|
api_key: str
|
||||||
|
api_secret: str
|
||||||
|
api_name: str
|
||||||
|
usage_count: int
|
||||||
|
max_retries: int = 3
|
||||||
|
created_at: float = field(default_factory=asyncio.get_event_loop().time)
|
||||||
|
task_id: str = field(default_factory=lambda: hex(id(object()))[2:])
|
||||||
|
|
||||||
|
|
||||||
|
class BillingTaskManager:
|
||||||
|
"""异步扣费任务管理器(带指数退避重试)"""
|
||||||
|
|
||||||
|
def __init__(self, platform_url: str, platform_key: str, platform_secret: str):
|
||||||
|
self._platform_url = platform_url.rstrip("/")
|
||||||
|
self._platform_key = platform_key
|
||||||
|
self._platform_secret = platform_secret
|
||||||
|
self._queue: asyncio.Queue = asyncio.Queue()
|
||||||
|
self._worker_task = None
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pending_count(self) -> int:
|
||||||
|
return self._queue.qsize()
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
if self._running:
|
||||||
|
return
|
||||||
|
self._running = True
|
||||||
|
self._worker_task = asyncio.create_task(self._worker_loop(), name="billing-worker")
|
||||||
|
logger.info("BillingTaskManager worker 已启动")
|
||||||
|
|
||||||
|
async def shutdown(self, timeout: float = 10.0):
|
||||||
|
if not self._running or self._worker_task is None:
|
||||||
|
return
|
||||||
|
self._running = False
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._queue.join(), timeout=timeout)
|
||||||
|
self._worker_task.cancel()
|
||||||
|
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||||
|
logger.info("BillingTaskManager 已优雅关闭")
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning("BillingTaskManager 关闭超时,剩余 %d 个任务", self._queue.qsize())
|
||||||
|
self._worker_task.cancel()
|
||||||
|
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||||
|
|
||||||
|
async def enqueue(self, api_key: str, api_secret: str, api_name: str, usage_count: int):
|
||||||
|
self._queue.put_nowait(_BillingTask(
|
||||||
|
api_key=api_key,
|
||||||
|
api_secret=api_secret,
|
||||||
|
api_name=api_name,
|
||||||
|
usage_count=usage_count,
|
||||||
|
))
|
||||||
|
|
||||||
|
async def _worker_loop(self):
|
||||||
|
while self._running or not self._queue.empty():
|
||||||
|
try:
|
||||||
|
task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
await self._execute_with_retry(task)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("扣费最终失败: [%s] api=%s error=%s", task.task_id, task.api_name, e)
|
||||||
|
finally:
|
||||||
|
self._queue.task_done()
|
||||||
|
|
||||||
|
async def _execute_with_retry(self, task: _BillingTask):
|
||||||
|
last_error = None
|
||||||
|
for attempt in range(1, task.max_retries + 1):
|
||||||
|
try:
|
||||||
|
result = await self._do_deduct(task)
|
||||||
|
if result and result.get("success"):
|
||||||
|
logger.info(
|
||||||
|
"扣费成功: [%s] api=%s count=%d attempt=%d/%d",
|
||||||
|
task.task_id, task.api_name, task.usage_count, attempt, task.max_retries,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
last_error = result.get("message", "扣费接口返回失败") if result else "扣费接口无响应"
|
||||||
|
except Exception as e:
|
||||||
|
last_error = str(e)
|
||||||
|
if attempt < task.max_retries:
|
||||||
|
await asyncio.sleep(2 ** (attempt - 1))
|
||||||
|
raise RuntimeError(f"扣费最终失败: {last_error}")
|
||||||
|
|
||||||
|
async def _do_deduct(self, task: _BillingTask) -> dict:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{self._platform_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||||
|
headers={"Authorization": f"token {self._platform_key}:{self._platform_secret}"},
|
||||||
|
json={
|
||||||
|
"api_key": task.api_key,
|
||||||
|
"api_secret": task.api_secret,
|
||||||
|
"api_name": task.api_name,
|
||||||
|
"usage_count": task.usage_count,
|
||||||
|
},
|
||||||
|
) as resp:
|
||||||
|
if resp.status != 200:
|
||||||
|
raise RuntimeError(f"扣费服务返回 HTTP {resp.status}")
|
||||||
|
result = await resp.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
return result
|
||||||
|
|
||||||
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||||
"""验证API密钥和团队余额"""
|
"""验证API密钥和团队余额"""
|
||||||
@ -33,33 +144,11 @@ async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||||
|
|
||||||
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
billing_manager = BillingTaskManager(
|
||||||
"""从Jingrow平台扣除API使用费"""
|
platform_url=settings.jingrow_api_url,
|
||||||
try:
|
platform_key=settings.jingrow_api_key,
|
||||||
async with aiohttp.ClientSession() as session:
|
platform_secret=settings.jingrow_api_secret,
|
||||||
async with session.post(
|
)
|
||||||
f"{settings.jingrow_api_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
|
||||||
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
|
||||||
json={
|
|
||||||
"api_key": api_key,
|
|
||||||
"api_secret": api_secret,
|
|
||||||
"api_name": api_name,
|
|
||||||
"usage_count": usage_count
|
|
||||||
}
|
|
||||||
) as response:
|
|
||||||
if response.status != 200:
|
|
||||||
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
|
||||||
|
|
||||||
result = await response.json()
|
|
||||||
if "message" in result and isinstance(result["message"], dict):
|
|
||||||
result = result["message"]
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
|
||||||
|
|
||||||
def get_token_from_request(request) -> str:
|
def get_token_from_request(request) -> str:
|
||||||
"""从请求中获取访问令牌"""
|
"""从请求中获取访问令牌"""
|
||||||
@ -122,7 +211,7 @@ def jingrow_api_verify_and_billing(api_name: str):
|
|||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
if success_count > 0:
|
if success_count > 0:
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, success_count)
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
wrapped_generator(),
|
wrapped_generator(),
|
||||||
@ -132,10 +221,10 @@ def jingrow_api_verify_and_billing(api_name: str):
|
|||||||
|
|
||||||
if isinstance(result, dict) and result.get("success") is True:
|
if isinstance(result, dict) and result.get("success") is True:
|
||||||
actual_usage_count = result.get("successful_count", usage_count)
|
actual_usage_count = result.get("successful_count", usage_count)
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, actual_usage_count)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
await billing_manager.enqueue(api_key, api_secret, api_name, usage_count)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user