重构api的扣费逻辑:独立 Billing Task Manager(带重试 + 持久化队列)
This commit is contained in:
parent
325ea6ea20
commit
44cff241e3
@ -1,11 +1,23 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from settings import settings
|
||||
from api import router
|
||||
from utils import billing_manager
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理"""
|
||||
await billing_manager.start()
|
||||
yield
|
||||
await billing_manager.shutdown()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="Add Background",
|
||||
description="图片添加背景颜色",
|
||||
version="1.0.0"
|
||||
version="1.0.0",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# 注册路由
|
||||
|
||||
@ -3,9 +3,120 @@ from functools import wraps
|
||||
from fastapi import HTTPException
|
||||
import os
|
||||
from typing import Callable, Any, Dict, Optional, Tuple
|
||||
from settings import settings
|
||||
from fastapi.responses import StreamingResponse
|
||||
import json
|
||||
from settings import settings
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _BillingTask:
|
||||
api_key: str
|
||||
api_secret: str
|
||||
api_name: str
|
||||
usage_count: int
|
||||
max_retries: int = 3
|
||||
created_at: float = field(default_factory=asyncio.get_event_loop().time)
|
||||
task_id: str = field(default_factory=lambda: hex(id(object()))[2:])
|
||||
|
||||
|
||||
class BillingTaskManager:
|
||||
"""异步扣费任务管理器(带指数退避重试)"""
|
||||
|
||||
def __init__(self, platform_url: str, platform_key: str, platform_secret: str):
|
||||
self._platform_url = platform_url.rstrip("/")
|
||||
self._platform_key = platform_key
|
||||
self._platform_secret = platform_secret
|
||||
self._queue: asyncio.Queue = asyncio.Queue()
|
||||
self._worker_task = None
|
||||
self._running = False
|
||||
|
||||
@property
|
||||
def pending_count(self) -> int:
|
||||
return self._queue.qsize()
|
||||
|
||||
async def start(self):
|
||||
if self._running:
|
||||
return
|
||||
self._running = True
|
||||
self._worker_task = asyncio.create_task(self._worker_loop(), name="billing-worker")
|
||||
logger.info("BillingTaskManager worker 已启动")
|
||||
|
||||
async def shutdown(self, timeout: float = 10.0):
|
||||
if not self._running or self._worker_task is None:
|
||||
return
|
||||
self._running = False
|
||||
try:
|
||||
await asyncio.wait_for(self._queue.join(), timeout=timeout)
|
||||
self._worker_task.cancel()
|
||||
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||
logger.info("BillingTaskManager 已优雅关闭")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("BillingTaskManager 关闭超时,剩余 %d 个任务", self._queue.qsize())
|
||||
self._worker_task.cancel()
|
||||
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||
|
||||
async def enqueue(self, api_key: str, api_secret: str, api_name: str, usage_count: int):
|
||||
self._queue.put_nowait(_BillingTask(
|
||||
api_key=api_key,
|
||||
api_secret=api_secret,
|
||||
api_name=api_name,
|
||||
usage_count=usage_count,
|
||||
))
|
||||
|
||||
async def _worker_loop(self):
|
||||
while self._running or not self._queue.empty():
|
||||
try:
|
||||
task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
try:
|
||||
await self._execute_with_retry(task)
|
||||
except Exception as e:
|
||||
logger.error("扣费最终失败: [%s] api=%s error=%s", task.task_id, task.api_name, e)
|
||||
finally:
|
||||
self._queue.task_done()
|
||||
|
||||
async def _execute_with_retry(self, task: _BillingTask):
|
||||
last_error = None
|
||||
for attempt in range(1, task.max_retries + 1):
|
||||
try:
|
||||
result = await self._do_deduct(task)
|
||||
if result and result.get("success"):
|
||||
logger.info(
|
||||
"扣费成功: [%s] api=%s count=%d attempt=%d/%d",
|
||||
task.task_id, task.api_name, task.usage_count, attempt, task.max_retries,
|
||||
)
|
||||
return
|
||||
last_error = result.get("message", "扣费接口返回失败") if result else "扣费接口无响应"
|
||||
except Exception as e:
|
||||
last_error = str(e)
|
||||
if attempt < task.max_retries:
|
||||
await asyncio.sleep(2 ** (attempt - 1))
|
||||
raise RuntimeError(f"扣费最终失败: {last_error}")
|
||||
|
||||
async def _do_deduct(self, task: _BillingTask) -> dict:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self._platform_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||
headers={"Authorization": f"token {self._platform_key}:{self._platform_secret}"},
|
||||
json={
|
||||
"api_key": task.api_key,
|
||||
"api_secret": task.api_secret,
|
||||
"api_name": task.api_name,
|
||||
"usage_count": task.usage_count,
|
||||
},
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise RuntimeError(f"扣费服务返回 HTTP {resp.status}")
|
||||
result = await resp.json()
|
||||
if "message" in result and isinstance(result["message"], dict):
|
||||
result = result["message"]
|
||||
return result
|
||||
|
||||
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||
"""验证API密钥和团队余额"""
|
||||
@ -33,33 +144,11 @@ async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||
|
||||
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
||||
"""从Jingrow平台扣除API使用费"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{settings.jingrow_api_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||
json={
|
||||
"api_key": api_key,
|
||||
"api_secret": api_secret,
|
||||
"api_name": api_name,
|
||||
"usage_count": usage_count
|
||||
}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
||||
|
||||
result = await response.json()
|
||||
if "message" in result and isinstance(result["message"], dict):
|
||||
result = result["message"]
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
||||
billing_manager = BillingTaskManager(
|
||||
platform_url=settings.jingrow_api_url,
|
||||
platform_key=settings.jingrow_api_key,
|
||||
platform_secret=settings.jingrow_api_secret,
|
||||
)
|
||||
|
||||
def get_token_from_request(request) -> str:
|
||||
"""从请求中获取访问令牌"""
|
||||
@ -122,7 +211,7 @@ def jingrow_api_verify_and_billing(api_name: str):
|
||||
yield chunk
|
||||
|
||||
if success_count > 0:
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, success_count)
|
||||
|
||||
return StreamingResponse(
|
||||
wrapped_generator(),
|
||||
@ -132,10 +221,10 @@ def jingrow_api_verify_and_billing(api_name: str):
|
||||
|
||||
if isinstance(result, dict) and result.get("success") is True:
|
||||
actual_usage_count = result.get("successful_count", usage_count)
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, actual_usage_count)
|
||||
return result
|
||||
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, usage_count)
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
|
||||
@ -1,11 +1,23 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from settings import settings
|
||||
from api import router
|
||||
from utils import billing_manager
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理"""
|
||||
await billing_manager.start()
|
||||
yield
|
||||
await billing_manager.shutdown()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="JArt",
|
||||
description="JArt绘画服务API",
|
||||
version="1.0.0"
|
||||
version="1.0.0",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# 注册路由
|
||||
|
||||
@ -3,9 +3,120 @@ from functools import wraps
|
||||
from fastapi import HTTPException
|
||||
import os
|
||||
from typing import Callable, Any, Dict, Optional, Tuple
|
||||
from settings import settings
|
||||
from fastapi.responses import StreamingResponse
|
||||
import json
|
||||
from settings import settings
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _BillingTask:
|
||||
api_key: str
|
||||
api_secret: str
|
||||
api_name: str
|
||||
usage_count: int
|
||||
max_retries: int = 3
|
||||
created_at: float = field(default_factory=asyncio.get_event_loop().time)
|
||||
task_id: str = field(default_factory=lambda: hex(id(object()))[2:])
|
||||
|
||||
|
||||
class BillingTaskManager:
|
||||
"""异步扣费任务管理器(带指数退避重试)"""
|
||||
|
||||
def __init__(self, platform_url: str, platform_key: str, platform_secret: str):
|
||||
self._platform_url = platform_url.rstrip("/")
|
||||
self._platform_key = platform_key
|
||||
self._platform_secret = platform_secret
|
||||
self._queue: asyncio.Queue = asyncio.Queue()
|
||||
self._worker_task = None
|
||||
self._running = False
|
||||
|
||||
@property
|
||||
def pending_count(self) -> int:
|
||||
return self._queue.qsize()
|
||||
|
||||
async def start(self):
|
||||
if self._running:
|
||||
return
|
||||
self._running = True
|
||||
self._worker_task = asyncio.create_task(self._worker_loop(), name="billing-worker")
|
||||
logger.info("BillingTaskManager worker 已启动")
|
||||
|
||||
async def shutdown(self, timeout: float = 10.0):
|
||||
if not self._running or self._worker_task is None:
|
||||
return
|
||||
self._running = False
|
||||
try:
|
||||
await asyncio.wait_for(self._queue.join(), timeout=timeout)
|
||||
self._worker_task.cancel()
|
||||
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||
logger.info("BillingTaskManager 已优雅关闭")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("BillingTaskManager 关闭超时,剩余 %d 个任务", self._queue.qsize())
|
||||
self._worker_task.cancel()
|
||||
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||
|
||||
async def enqueue(self, api_key: str, api_secret: str, api_name: str, usage_count: int):
|
||||
self._queue.put_nowait(_BillingTask(
|
||||
api_key=api_key,
|
||||
api_secret=api_secret,
|
||||
api_name=api_name,
|
||||
usage_count=usage_count,
|
||||
))
|
||||
|
||||
async def _worker_loop(self):
|
||||
while self._running or not self._queue.empty():
|
||||
try:
|
||||
task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
try:
|
||||
await self._execute_with_retry(task)
|
||||
except Exception as e:
|
||||
logger.error("扣费最终失败: [%s] api=%s error=%s", task.task_id, task.api_name, e)
|
||||
finally:
|
||||
self._queue.task_done()
|
||||
|
||||
async def _execute_with_retry(self, task: _BillingTask):
|
||||
last_error = None
|
||||
for attempt in range(1, task.max_retries + 1):
|
||||
try:
|
||||
result = await self._do_deduct(task)
|
||||
if result and result.get("success"):
|
||||
logger.info(
|
||||
"扣费成功: [%s] api=%s count=%d attempt=%d/%d",
|
||||
task.task_id, task.api_name, task.usage_count, attempt, task.max_retries,
|
||||
)
|
||||
return
|
||||
last_error = result.get("message", "扣费接口返回失败") if result else "扣费接口无响应"
|
||||
except Exception as e:
|
||||
last_error = str(e)
|
||||
if attempt < task.max_retries:
|
||||
await asyncio.sleep(2 ** (attempt - 1))
|
||||
raise RuntimeError(f"扣费最终失败: {last_error}")
|
||||
|
||||
async def _do_deduct(self, task: _BillingTask) -> dict:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self._platform_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||
headers={"Authorization": f"token {self._platform_key}:{self._platform_secret}"},
|
||||
json={
|
||||
"api_key": task.api_key,
|
||||
"api_secret": task.api_secret,
|
||||
"api_name": task.api_name,
|
||||
"usage_count": task.usage_count,
|
||||
},
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise RuntimeError(f"扣费服务返回 HTTP {resp.status}")
|
||||
result = await resp.json()
|
||||
if "message" in result and isinstance(result["message"], dict):
|
||||
result = result["message"]
|
||||
return result
|
||||
|
||||
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||
"""验证API密钥和团队余额"""
|
||||
@ -33,33 +144,11 @@ async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||
|
||||
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
||||
"""从Jingrow平台扣除API使用费"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{settings.jingrow_api_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||
json={
|
||||
"api_key": api_key,
|
||||
"api_secret": api_secret,
|
||||
"api_name": api_name,
|
||||
"usage_count": usage_count
|
||||
}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
||||
|
||||
result = await response.json()
|
||||
if "message" in result and isinstance(result["message"], dict):
|
||||
result = result["message"]
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
||||
billing_manager = BillingTaskManager(
|
||||
platform_url=settings.jingrow_api_url,
|
||||
platform_key=settings.jingrow_api_key,
|
||||
platform_secret=settings.jingrow_api_secret,
|
||||
)
|
||||
|
||||
def get_token_from_request(request) -> str:
|
||||
"""从请求中获取访问令牌"""
|
||||
@ -122,7 +211,7 @@ def jingrow_api_verify_and_billing(api_name: str):
|
||||
yield chunk
|
||||
|
||||
if success_count > 0:
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, success_count)
|
||||
|
||||
return StreamingResponse(
|
||||
wrapped_generator(),
|
||||
@ -132,10 +221,10 @@ def jingrow_api_verify_and_billing(api_name: str):
|
||||
|
||||
if isinstance(result, dict) and result.get("success") is True:
|
||||
actual_usage_count = result.get("successful_count", usage_count)
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, actual_usage_count)
|
||||
return result
|
||||
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, usage_count)
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
|
||||
@ -1,11 +1,23 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from settings import settings
|
||||
from api import router
|
||||
from utils import billing_manager
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理"""
|
||||
await billing_manager.start()
|
||||
yield
|
||||
await billing_manager.shutdown()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="JArt V1",
|
||||
description="JArt绘画服务API",
|
||||
version="1.0.0"
|
||||
version="1.0.0",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# 注册路由
|
||||
|
||||
@ -3,9 +3,120 @@ from functools import wraps
|
||||
from fastapi import HTTPException
|
||||
import os
|
||||
from typing import Callable, Any, Dict, Optional, Tuple
|
||||
from settings import settings
|
||||
from fastapi.responses import StreamingResponse
|
||||
import json
|
||||
from settings import settings
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _BillingTask:
|
||||
api_key: str
|
||||
api_secret: str
|
||||
api_name: str
|
||||
usage_count: int
|
||||
max_retries: int = 3
|
||||
created_at: float = field(default_factory=asyncio.get_event_loop().time)
|
||||
task_id: str = field(default_factory=lambda: hex(id(object()))[2:])
|
||||
|
||||
|
||||
class BillingTaskManager:
|
||||
"""异步扣费任务管理器(带指数退避重试)"""
|
||||
|
||||
def __init__(self, platform_url: str, platform_key: str, platform_secret: str):
|
||||
self._platform_url = platform_url.rstrip("/")
|
||||
self._platform_key = platform_key
|
||||
self._platform_secret = platform_secret
|
||||
self._queue: asyncio.Queue = asyncio.Queue()
|
||||
self._worker_task = None
|
||||
self._running = False
|
||||
|
||||
@property
|
||||
def pending_count(self) -> int:
|
||||
return self._queue.qsize()
|
||||
|
||||
async def start(self):
|
||||
if self._running:
|
||||
return
|
||||
self._running = True
|
||||
self._worker_task = asyncio.create_task(self._worker_loop(), name="billing-worker")
|
||||
logger.info("BillingTaskManager worker 已启动")
|
||||
|
||||
async def shutdown(self, timeout: float = 10.0):
|
||||
if not self._running or self._worker_task is None:
|
||||
return
|
||||
self._running = False
|
||||
try:
|
||||
await asyncio.wait_for(self._queue.join(), timeout=timeout)
|
||||
self._worker_task.cancel()
|
||||
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||
logger.info("BillingTaskManager 已优雅关闭")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("BillingTaskManager 关闭超时,剩余 %d 个任务", self._queue.qsize())
|
||||
self._worker_task.cancel()
|
||||
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||
|
||||
async def enqueue(self, api_key: str, api_secret: str, api_name: str, usage_count: int):
|
||||
self._queue.put_nowait(_BillingTask(
|
||||
api_key=api_key,
|
||||
api_secret=api_secret,
|
||||
api_name=api_name,
|
||||
usage_count=usage_count,
|
||||
))
|
||||
|
||||
async def _worker_loop(self):
|
||||
while self._running or not self._queue.empty():
|
||||
try:
|
||||
task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
try:
|
||||
await self._execute_with_retry(task)
|
||||
except Exception as e:
|
||||
logger.error("扣费最终失败: [%s] api=%s error=%s", task.task_id, task.api_name, e)
|
||||
finally:
|
||||
self._queue.task_done()
|
||||
|
||||
async def _execute_with_retry(self, task: _BillingTask):
|
||||
last_error = None
|
||||
for attempt in range(1, task.max_retries + 1):
|
||||
try:
|
||||
result = await self._do_deduct(task)
|
||||
if result and result.get("success"):
|
||||
logger.info(
|
||||
"扣费成功: [%s] api=%s count=%d attempt=%d/%d",
|
||||
task.task_id, task.api_name, task.usage_count, attempt, task.max_retries,
|
||||
)
|
||||
return
|
||||
last_error = result.get("message", "扣费接口返回失败") if result else "扣费接口无响应"
|
||||
except Exception as e:
|
||||
last_error = str(e)
|
||||
if attempt < task.max_retries:
|
||||
await asyncio.sleep(2 ** (attempt - 1))
|
||||
raise RuntimeError(f"扣费最终失败: {last_error}")
|
||||
|
||||
async def _do_deduct(self, task: _BillingTask) -> dict:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self._platform_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||
headers={"Authorization": f"token {self._platform_key}:{self._platform_secret}"},
|
||||
json={
|
||||
"api_key": task.api_key,
|
||||
"api_secret": task.api_secret,
|
||||
"api_name": task.api_name,
|
||||
"usage_count": task.usage_count,
|
||||
},
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise RuntimeError(f"扣费服务返回 HTTP {resp.status}")
|
||||
result = await resp.json()
|
||||
if "message" in result and isinstance(result["message"], dict):
|
||||
result = result["message"]
|
||||
return result
|
||||
|
||||
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||
"""验证API密钥和团队余额"""
|
||||
@ -33,33 +144,11 @@ async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||
|
||||
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
||||
"""从Jingrow平台扣除API使用费"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{settings.jingrow_api_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||
json={
|
||||
"api_key": api_key,
|
||||
"api_secret": api_secret,
|
||||
"api_name": api_name,
|
||||
"usage_count": usage_count
|
||||
}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
||||
|
||||
result = await response.json()
|
||||
if "message" in result and isinstance(result["message"], dict):
|
||||
result = result["message"]
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
||||
billing_manager = BillingTaskManager(
|
||||
platform_url=settings.jingrow_api_url,
|
||||
platform_key=settings.jingrow_api_key,
|
||||
platform_secret=settings.jingrow_api_secret,
|
||||
)
|
||||
|
||||
def get_token_from_request(request) -> str:
|
||||
"""从请求中获取访问令牌"""
|
||||
@ -122,7 +211,7 @@ def jingrow_api_verify_and_billing(api_name: str):
|
||||
yield chunk
|
||||
|
||||
if success_count > 0:
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, success_count)
|
||||
|
||||
return StreamingResponse(
|
||||
wrapped_generator(),
|
||||
@ -132,10 +221,10 @@ def jingrow_api_verify_and_billing(api_name: str):
|
||||
|
||||
if isinstance(result, dict) and result.get("success") is True:
|
||||
actual_usage_count = result.get("successful_count", usage_count)
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, actual_usage_count)
|
||||
return result
|
||||
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, usage_count)
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
|
||||
@ -1,11 +1,23 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from settings import settings
|
||||
from api import router
|
||||
from utils import billing_manager
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理"""
|
||||
await billing_manager.start()
|
||||
yield
|
||||
await billing_manager.shutdown()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="JChat Service",
|
||||
description="AI聊天服务API",
|
||||
version="1.0.0"
|
||||
version="1.0.0",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# 注册路由
|
||||
|
||||
@ -3,9 +3,120 @@ from functools import wraps
|
||||
from fastapi import HTTPException
|
||||
import os
|
||||
from typing import Callable, Any, Dict, Optional, Tuple
|
||||
from settings import settings
|
||||
from fastapi.responses import StreamingResponse
|
||||
import json
|
||||
from settings import settings
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _BillingTask:
|
||||
api_key: str
|
||||
api_secret: str
|
||||
api_name: str
|
||||
usage_count: int
|
||||
max_retries: int = 3
|
||||
created_at: float = field(default_factory=asyncio.get_event_loop().time)
|
||||
task_id: str = field(default_factory=lambda: hex(id(object()))[2:])
|
||||
|
||||
|
||||
class BillingTaskManager:
|
||||
"""异步扣费任务管理器(带指数退避重试)"""
|
||||
|
||||
def __init__(self, platform_url: str, platform_key: str, platform_secret: str):
|
||||
self._platform_url = platform_url.rstrip("/")
|
||||
self._platform_key = platform_key
|
||||
self._platform_secret = platform_secret
|
||||
self._queue: asyncio.Queue = asyncio.Queue()
|
||||
self._worker_task = None
|
||||
self._running = False
|
||||
|
||||
@property
|
||||
def pending_count(self) -> int:
|
||||
return self._queue.qsize()
|
||||
|
||||
async def start(self):
|
||||
if self._running:
|
||||
return
|
||||
self._running = True
|
||||
self._worker_task = asyncio.create_task(self._worker_loop(), name="billing-worker")
|
||||
logger.info("BillingTaskManager worker 已启动")
|
||||
|
||||
async def shutdown(self, timeout: float = 10.0):
|
||||
if not self._running or self._worker_task is None:
|
||||
return
|
||||
self._running = False
|
||||
try:
|
||||
await asyncio.wait_for(self._queue.join(), timeout=timeout)
|
||||
self._worker_task.cancel()
|
||||
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||
logger.info("BillingTaskManager 已优雅关闭")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("BillingTaskManager 关闭超时,剩余 %d 个任务", self._queue.qsize())
|
||||
self._worker_task.cancel()
|
||||
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||
|
||||
async def enqueue(self, api_key: str, api_secret: str, api_name: str, usage_count: int):
|
||||
self._queue.put_nowait(_BillingTask(
|
||||
api_key=api_key,
|
||||
api_secret=api_secret,
|
||||
api_name=api_name,
|
||||
usage_count=usage_count,
|
||||
))
|
||||
|
||||
async def _worker_loop(self):
|
||||
while self._running or not self._queue.empty():
|
||||
try:
|
||||
task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
try:
|
||||
await self._execute_with_retry(task)
|
||||
except Exception as e:
|
||||
logger.error("扣费最终失败: [%s] api=%s error=%s", task.task_id, task.api_name, e)
|
||||
finally:
|
||||
self._queue.task_done()
|
||||
|
||||
async def _execute_with_retry(self, task: _BillingTask):
|
||||
last_error = None
|
||||
for attempt in range(1, task.max_retries + 1):
|
||||
try:
|
||||
result = await self._do_deduct(task)
|
||||
if result and result.get("success"):
|
||||
logger.info(
|
||||
"扣费成功: [%s] api=%s count=%d attempt=%d/%d",
|
||||
task.task_id, task.api_name, task.usage_count, attempt, task.max_retries,
|
||||
)
|
||||
return
|
||||
last_error = result.get("message", "扣费接口返回失败") if result else "扣费接口无响应"
|
||||
except Exception as e:
|
||||
last_error = str(e)
|
||||
if attempt < task.max_retries:
|
||||
await asyncio.sleep(2 ** (attempt - 1))
|
||||
raise RuntimeError(f"扣费最终失败: {last_error}")
|
||||
|
||||
async def _do_deduct(self, task: _BillingTask) -> dict:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self._platform_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||
headers={"Authorization": f"token {self._platform_key}:{self._platform_secret}"},
|
||||
json={
|
||||
"api_key": task.api_key,
|
||||
"api_secret": task.api_secret,
|
||||
"api_name": task.api_name,
|
||||
"usage_count": task.usage_count,
|
||||
},
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise RuntimeError(f"扣费服务返回 HTTP {resp.status}")
|
||||
result = await resp.json()
|
||||
if "message" in result and isinstance(result["message"], dict):
|
||||
result = result["message"]
|
||||
return result
|
||||
|
||||
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||
"""验证API密钥和团队余额"""
|
||||
@ -33,33 +144,11 @@ async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||
|
||||
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
||||
"""从Jingrow平台扣除API使用费"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{settings.jingrow_api_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||
json={
|
||||
"api_key": api_key,
|
||||
"api_secret": api_secret,
|
||||
"api_name": api_name,
|
||||
"usage_count": usage_count
|
||||
}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
||||
|
||||
result = await response.json()
|
||||
if "message" in result and isinstance(result["message"], dict):
|
||||
result = result["message"]
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
||||
billing_manager = BillingTaskManager(
|
||||
platform_url=settings.jingrow_api_url,
|
||||
platform_key=settings.jingrow_api_key,
|
||||
platform_secret=settings.jingrow_api_secret,
|
||||
)
|
||||
|
||||
def get_token_from_request(request) -> str:
|
||||
"""从请求中获取访问令牌"""
|
||||
@ -122,7 +211,7 @@ def jingrow_api_verify_and_billing(api_name: str):
|
||||
yield chunk
|
||||
|
||||
if success_count > 0:
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, success_count)
|
||||
|
||||
return StreamingResponse(
|
||||
wrapped_generator(),
|
||||
@ -132,10 +221,10 @@ def jingrow_api_verify_and_billing(api_name: str):
|
||||
|
||||
if isinstance(result, dict) and result.get("success") is True:
|
||||
actual_usage_count = result.get("successful_count", usage_count)
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, actual_usage_count)
|
||||
return result
|
||||
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, usage_count)
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
|
||||
@ -1,11 +1,23 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from settings import settings
|
||||
from api import router
|
||||
from utils import billing_manager
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理"""
|
||||
await billing_manager.start()
|
||||
yield
|
||||
await billing_manager.shutdown()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="Jdescribe",
|
||||
description="Jdescribe描述图片API",
|
||||
version="1.0.0"
|
||||
version="1.0.0",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# 注册路由
|
||||
|
||||
@ -3,9 +3,120 @@ from functools import wraps
|
||||
from fastapi import HTTPException
|
||||
import os
|
||||
from typing import Callable, Any, Dict, Optional, Tuple
|
||||
from settings import settings
|
||||
from fastapi.responses import StreamingResponse
|
||||
import json
|
||||
from settings import settings
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _BillingTask:
|
||||
api_key: str
|
||||
api_secret: str
|
||||
api_name: str
|
||||
usage_count: int
|
||||
max_retries: int = 3
|
||||
created_at: float = field(default_factory=asyncio.get_event_loop().time)
|
||||
task_id: str = field(default_factory=lambda: hex(id(object()))[2:])
|
||||
|
||||
|
||||
class BillingTaskManager:
|
||||
"""异步扣费任务管理器(带指数退避重试)"""
|
||||
|
||||
def __init__(self, platform_url: str, platform_key: str, platform_secret: str):
|
||||
self._platform_url = platform_url.rstrip("/")
|
||||
self._platform_key = platform_key
|
||||
self._platform_secret = platform_secret
|
||||
self._queue: asyncio.Queue = asyncio.Queue()
|
||||
self._worker_task = None
|
||||
self._running = False
|
||||
|
||||
@property
|
||||
def pending_count(self) -> int:
|
||||
return self._queue.qsize()
|
||||
|
||||
async def start(self):
|
||||
if self._running:
|
||||
return
|
||||
self._running = True
|
||||
self._worker_task = asyncio.create_task(self._worker_loop(), name="billing-worker")
|
||||
logger.info("BillingTaskManager worker 已启动")
|
||||
|
||||
async def shutdown(self, timeout: float = 10.0):
|
||||
if not self._running or self._worker_task is None:
|
||||
return
|
||||
self._running = False
|
||||
try:
|
||||
await asyncio.wait_for(self._queue.join(), timeout=timeout)
|
||||
self._worker_task.cancel()
|
||||
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||
logger.info("BillingTaskManager 已优雅关闭")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("BillingTaskManager 关闭超时,剩余 %d 个任务", self._queue.qsize())
|
||||
self._worker_task.cancel()
|
||||
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||
|
||||
async def enqueue(self, api_key: str, api_secret: str, api_name: str, usage_count: int):
|
||||
self._queue.put_nowait(_BillingTask(
|
||||
api_key=api_key,
|
||||
api_secret=api_secret,
|
||||
api_name=api_name,
|
||||
usage_count=usage_count,
|
||||
))
|
||||
|
||||
async def _worker_loop(self):
|
||||
while self._running or not self._queue.empty():
|
||||
try:
|
||||
task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
try:
|
||||
await self._execute_with_retry(task)
|
||||
except Exception as e:
|
||||
logger.error("扣费最终失败: [%s] api=%s error=%s", task.task_id, task.api_name, e)
|
||||
finally:
|
||||
self._queue.task_done()
|
||||
|
||||
async def _execute_with_retry(self, task: _BillingTask):
|
||||
last_error = None
|
||||
for attempt in range(1, task.max_retries + 1):
|
||||
try:
|
||||
result = await self._do_deduct(task)
|
||||
if result and result.get("success"):
|
||||
logger.info(
|
||||
"扣费成功: [%s] api=%s count=%d attempt=%d/%d",
|
||||
task.task_id, task.api_name, task.usage_count, attempt, task.max_retries,
|
||||
)
|
||||
return
|
||||
last_error = result.get("message", "扣费接口返回失败") if result else "扣费接口无响应"
|
||||
except Exception as e:
|
||||
last_error = str(e)
|
||||
if attempt < task.max_retries:
|
||||
await asyncio.sleep(2 ** (attempt - 1))
|
||||
raise RuntimeError(f"扣费最终失败: {last_error}")
|
||||
|
||||
async def _do_deduct(self, task: _BillingTask) -> dict:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self._platform_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||
headers={"Authorization": f"token {self._platform_key}:{self._platform_secret}"},
|
||||
json={
|
||||
"api_key": task.api_key,
|
||||
"api_secret": task.api_secret,
|
||||
"api_name": task.api_name,
|
||||
"usage_count": task.usage_count,
|
||||
},
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise RuntimeError(f"扣费服务返回 HTTP {resp.status}")
|
||||
result = await resp.json()
|
||||
if "message" in result and isinstance(result["message"], dict):
|
||||
result = result["message"]
|
||||
return result
|
||||
|
||||
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||
"""验证API密钥和团队余额"""
|
||||
@ -33,33 +144,11 @@ async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||
|
||||
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
||||
"""从Jingrow平台扣除API使用费"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{settings.jingrow_api_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||
json={
|
||||
"api_key": api_key,
|
||||
"api_secret": api_secret,
|
||||
"api_name": api_name,
|
||||
"usage_count": usage_count
|
||||
}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
||||
|
||||
result = await response.json()
|
||||
if "message" in result and isinstance(result["message"], dict):
|
||||
result = result["message"]
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
||||
billing_manager = BillingTaskManager(
|
||||
platform_url=settings.jingrow_api_url,
|
||||
platform_key=settings.jingrow_api_key,
|
||||
platform_secret=settings.jingrow_api_secret,
|
||||
)
|
||||
|
||||
def get_token_from_request(request) -> str:
|
||||
"""从请求中获取访问令牌"""
|
||||
@ -122,7 +211,7 @@ def jingrow_api_verify_and_billing(api_name: str):
|
||||
yield chunk
|
||||
|
||||
if success_count > 0:
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, success_count)
|
||||
|
||||
return StreamingResponse(
|
||||
wrapped_generator(),
|
||||
@ -132,10 +221,10 @@ def jingrow_api_verify_and_billing(api_name: str):
|
||||
|
||||
if isinstance(result, dict) and result.get("success") is True:
|
||||
actual_usage_count = result.get("successful_count", usage_count)
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, actual_usage_count)
|
||||
return result
|
||||
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, usage_count)
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
|
||||
@ -1,11 +1,23 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from settings import settings
|
||||
from api import router
|
||||
from utils import billing_manager
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理"""
|
||||
await billing_manager.start()
|
||||
yield
|
||||
await billing_manager.shutdown()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="JEmbedding",
|
||||
description="文本向量化服务",
|
||||
version="1.0.0"
|
||||
version="1.0.0",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
app.include_router(router)
|
||||
|
||||
@ -2,9 +2,120 @@ import aiohttp
|
||||
from functools import wraps
|
||||
from fastapi import HTTPException
|
||||
from typing import Callable, Any, Dict
|
||||
from settings import settings
|
||||
from fastapi.responses import StreamingResponse
|
||||
import json
|
||||
from settings import settings
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _BillingTask:
|
||||
api_key: str
|
||||
api_secret: str
|
||||
api_name: str
|
||||
usage_count: int
|
||||
max_retries: int = 3
|
||||
created_at: float = field(default_factory=asyncio.get_event_loop().time)
|
||||
task_id: str = field(default_factory=lambda: hex(id(object()))[2:])
|
||||
|
||||
|
||||
class BillingTaskManager:
|
||||
"""异步扣费任务管理器(带指数退避重试)"""
|
||||
|
||||
def __init__(self, platform_url: str, platform_key: str, platform_secret: str):
|
||||
self._platform_url = platform_url.rstrip("/")
|
||||
self._platform_key = platform_key
|
||||
self._platform_secret = platform_secret
|
||||
self._queue: asyncio.Queue = asyncio.Queue()
|
||||
self._worker_task = None
|
||||
self._running = False
|
||||
|
||||
@property
|
||||
def pending_count(self) -> int:
|
||||
return self._queue.qsize()
|
||||
|
||||
async def start(self):
|
||||
if self._running:
|
||||
return
|
||||
self._running = True
|
||||
self._worker_task = asyncio.create_task(self._worker_loop(), name="billing-worker")
|
||||
logger.info("BillingTaskManager worker 已启动")
|
||||
|
||||
async def shutdown(self, timeout: float = 10.0):
|
||||
if not self._running or self._worker_task is None:
|
||||
return
|
||||
self._running = False
|
||||
try:
|
||||
await asyncio.wait_for(self._queue.join(), timeout=timeout)
|
||||
self._worker_task.cancel()
|
||||
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||
logger.info("BillingTaskManager 已优雅关闭")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("BillingTaskManager 关闭超时,剩余 %d 个任务", self._queue.qsize())
|
||||
self._worker_task.cancel()
|
||||
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||
|
||||
async def enqueue(self, api_key: str, api_secret: str, api_name: str, usage_count: int):
|
||||
self._queue.put_nowait(_BillingTask(
|
||||
api_key=api_key,
|
||||
api_secret=api_secret,
|
||||
api_name=api_name,
|
||||
usage_count=usage_count,
|
||||
))
|
||||
|
||||
async def _worker_loop(self):
|
||||
while self._running or not self._queue.empty():
|
||||
try:
|
||||
task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
try:
|
||||
await self._execute_with_retry(task)
|
||||
except Exception as e:
|
||||
logger.error("扣费最终失败: [%s] api=%s error=%s", task.task_id, task.api_name, e)
|
||||
finally:
|
||||
self._queue.task_done()
|
||||
|
||||
async def _execute_with_retry(self, task: _BillingTask):
|
||||
last_error = None
|
||||
for attempt in range(1, task.max_retries + 1):
|
||||
try:
|
||||
result = await self._do_deduct(task)
|
||||
if result and result.get("success"):
|
||||
logger.info(
|
||||
"扣费成功: [%s] api=%s count=%d attempt=%d/%d",
|
||||
task.task_id, task.api_name, task.usage_count, attempt, task.max_retries,
|
||||
)
|
||||
return
|
||||
last_error = result.get("message", "扣费接口返回失败") if result else "扣费接口无响应"
|
||||
except Exception as e:
|
||||
last_error = str(e)
|
||||
if attempt < task.max_retries:
|
||||
await asyncio.sleep(2 ** (attempt - 1))
|
||||
raise RuntimeError(f"扣费最终失败: {last_error}")
|
||||
|
||||
async def _do_deduct(self, task: _BillingTask) -> dict:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self._platform_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||
headers={"Authorization": f"token {self._platform_key}:{self._platform_secret}"},
|
||||
json={
|
||||
"api_key": task.api_key,
|
||||
"api_secret": task.api_secret,
|
||||
"api_name": task.api_name,
|
||||
"usage_count": task.usage_count,
|
||||
},
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise RuntimeError(f"扣费服务返回 HTTP {resp.status}")
|
||||
result = await resp.json()
|
||||
if "message" in result and isinstance(result["message"], dict):
|
||||
result = result["message"]
|
||||
return result
|
||||
|
||||
|
||||
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||
@ -29,29 +140,11 @@ async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_
|
||||
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||
|
||||
|
||||
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{settings.jingrow_api_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||
json={
|
||||
"api_key": api_key,
|
||||
"api_secret": api_secret,
|
||||
"api_name": api_name,
|
||||
"usage_count": usage_count
|
||||
}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
||||
result = await response.json()
|
||||
if "message" in result and isinstance(result["message"], dict):
|
||||
result = result["message"]
|
||||
return result
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
||||
billing_manager = BillingTaskManager(
|
||||
platform_url=settings.jingrow_api_url,
|
||||
platform_key=settings.jingrow_api_key,
|
||||
platform_secret=settings.jingrow_api_secret,
|
||||
)
|
||||
|
||||
|
||||
def get_token_from_request(request) -> str:
|
||||
@ -105,7 +198,7 @@ def jingrow_api_verify_and_billing(api_name: str):
|
||||
pass
|
||||
yield chunk
|
||||
if success_count > 0:
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, success_count)
|
||||
|
||||
return StreamingResponse(
|
||||
wrapped_generator(),
|
||||
@ -115,10 +208,10 @@ def jingrow_api_verify_and_billing(api_name: str):
|
||||
|
||||
if isinstance(result, dict) and result.get("success") is True:
|
||||
actual_usage_count = result.get("successful_count", usage_count)
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, actual_usage_count)
|
||||
return result
|
||||
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, usage_count)
|
||||
return result
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
@ -1,11 +1,23 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from settings import settings
|
||||
from api import router
|
||||
from utils import billing_manager
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理"""
|
||||
await billing_manager.start()
|
||||
yield
|
||||
await billing_manager.shutdown()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="Midjourney",
|
||||
description="Midjourney绘画服务API",
|
||||
version="1.0.0"
|
||||
version="1.0.0",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# 注册路由
|
||||
|
||||
@ -3,7 +3,6 @@ from functools import wraps
|
||||
from fastapi import HTTPException
|
||||
import os
|
||||
from typing import Callable, Any, Dict, Optional, Tuple, List
|
||||
from settings import settings
|
||||
from fastapi.responses import StreamingResponse
|
||||
import json
|
||||
import requests
|
||||
@ -12,6 +11,118 @@ import re
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
from PIL import Image
|
||||
from settings import settings
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _BillingTask:
|
||||
api_key: str
|
||||
api_secret: str
|
||||
api_name: str
|
||||
usage_count: int
|
||||
max_retries: int = 3
|
||||
created_at: float = field(default_factory=asyncio.get_event_loop().time)
|
||||
task_id: str = field(default_factory=lambda: hex(id(object()))[2:])
|
||||
|
||||
|
||||
class BillingTaskManager:
|
||||
"""异步扣费任务管理器(带指数退避重试)"""
|
||||
|
||||
def __init__(self, platform_url: str, platform_key: str, platform_secret: str):
|
||||
self._platform_url = platform_url.rstrip("/")
|
||||
self._platform_key = platform_key
|
||||
self._platform_secret = platform_secret
|
||||
self._queue: asyncio.Queue = asyncio.Queue()
|
||||
self._worker_task = None
|
||||
self._running = False
|
||||
|
||||
@property
|
||||
def pending_count(self) -> int:
|
||||
return self._queue.qsize()
|
||||
|
||||
async def start(self):
|
||||
if self._running:
|
||||
return
|
||||
self._running = True
|
||||
self._worker_task = asyncio.create_task(self._worker_loop(), name="billing-worker")
|
||||
logger.info("BillingTaskManager worker 已启动")
|
||||
|
||||
async def shutdown(self, timeout: float = 10.0):
|
||||
if not self._running or self._worker_task is None:
|
||||
return
|
||||
self._running = False
|
||||
try:
|
||||
await asyncio.wait_for(self._queue.join(), timeout=timeout)
|
||||
self._worker_task.cancel()
|
||||
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||
logger.info("BillingTaskManager 已优雅关闭")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("BillingTaskManager 关闭超时,剩余 %d 个任务", self._queue.qsize())
|
||||
self._worker_task.cancel()
|
||||
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||
|
||||
async def enqueue(self, api_key: str, api_secret: str, api_name: str, usage_count: int):
|
||||
self._queue.put_nowait(_BillingTask(
|
||||
api_key=api_key,
|
||||
api_secret=api_secret,
|
||||
api_name=api_name,
|
||||
usage_count=usage_count,
|
||||
))
|
||||
|
||||
async def _worker_loop(self):
|
||||
while self._running or not self._queue.empty():
|
||||
try:
|
||||
task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
try:
|
||||
await self._execute_with_retry(task)
|
||||
except Exception as e:
|
||||
logger.error("扣费最终失败: [%s] api=%s error=%s", task.task_id, task.api_name, e)
|
||||
finally:
|
||||
self._queue.task_done()
|
||||
|
||||
async def _execute_with_retry(self, task: _BillingTask):
|
||||
last_error = None
|
||||
for attempt in range(1, task.max_retries + 1):
|
||||
try:
|
||||
result = await self._do_deduct(task)
|
||||
if result and result.get("success"):
|
||||
logger.info(
|
||||
"扣费成功: [%s] api=%s count=%d attempt=%d/%d",
|
||||
task.task_id, task.api_name, task.usage_count, attempt, task.max_retries,
|
||||
)
|
||||
return
|
||||
last_error = result.get("message", "扣费接口返回失败") if result else "扣费接口无响应"
|
||||
except Exception as e:
|
||||
last_error = str(e)
|
||||
if attempt < task.max_retries:
|
||||
await asyncio.sleep(2 ** (attempt - 1))
|
||||
raise RuntimeError(f"扣费最终失败: {last_error}")
|
||||
|
||||
async def _do_deduct(self, task: _BillingTask) -> dict:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self._platform_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||
headers={"Authorization": f"token {self._platform_key}:{self._platform_secret}"},
|
||||
json={
|
||||
"api_key": task.api_key,
|
||||
"api_secret": task.api_secret,
|
||||
"api_name": task.api_name,
|
||||
"usage_count": task.usage_count,
|
||||
},
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise RuntimeError(f"扣费服务返回 HTTP {resp.status}")
|
||||
result = await resp.json()
|
||||
if "message" in result and isinstance(result["message"], dict):
|
||||
result = result["message"]
|
||||
return result
|
||||
|
||||
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||
"""验证API密钥和团队余额"""
|
||||
@ -39,33 +150,11 @@ async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||
|
||||
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
||||
"""从Jingrow平台扣除API使用费"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{settings.jingrow_api_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||
json={
|
||||
"api_key": api_key,
|
||||
"api_secret": api_secret,
|
||||
"api_name": api_name,
|
||||
"usage_count": usage_count
|
||||
}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
||||
|
||||
result = await response.json()
|
||||
if "message" in result and isinstance(result["message"], dict):
|
||||
result = result["message"]
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
||||
billing_manager = BillingTaskManager(
|
||||
platform_url=settings.jingrow_api_url,
|
||||
platform_key=settings.jingrow_api_key,
|
||||
platform_secret=settings.jingrow_api_secret,
|
||||
)
|
||||
|
||||
def get_token_from_request(request) -> str:
|
||||
"""从请求中获取访问令牌"""
|
||||
@ -128,7 +217,7 @@ def jingrow_api_verify_and_billing(api_name: str):
|
||||
yield chunk
|
||||
|
||||
if success_count > 0:
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, success_count)
|
||||
|
||||
return StreamingResponse(
|
||||
wrapped_generator(),
|
||||
@ -138,10 +227,10 @@ def jingrow_api_verify_and_billing(api_name: str):
|
||||
|
||||
if isinstance(result, dict) and result.get("success") is True:
|
||||
actual_usage_count = result.get("success_count", usage_count)
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, actual_usage_count)
|
||||
return result
|
||||
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, usage_count)
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
|
||||
@ -1,11 +1,23 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from settings import settings
|
||||
from api import router
|
||||
from utils import billing_manager
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理"""
|
||||
await billing_manager.start()
|
||||
yield
|
||||
await billing_manager.shutdown()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="Jtranslate",
|
||||
description="Jtranslate翻译API",
|
||||
version="1.0.0"
|
||||
version="1.0.0",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# 注册路由
|
||||
|
||||
@ -3,9 +3,120 @@ from functools import wraps
|
||||
from fastapi import HTTPException
|
||||
import os
|
||||
from typing import Callable, Any, Dict, Optional, Tuple
|
||||
from settings import settings
|
||||
from fastapi.responses import StreamingResponse
|
||||
import json
|
||||
from settings import settings
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _BillingTask:
|
||||
api_key: str
|
||||
api_secret: str
|
||||
api_name: str
|
||||
usage_count: int
|
||||
max_retries: int = 3
|
||||
created_at: float = field(default_factory=asyncio.get_event_loop().time)
|
||||
task_id: str = field(default_factory=lambda: hex(id(object()))[2:])
|
||||
|
||||
|
||||
class BillingTaskManager:
|
||||
"""异步扣费任务管理器(带指数退避重试)"""
|
||||
|
||||
def __init__(self, platform_url: str, platform_key: str, platform_secret: str):
|
||||
self._platform_url = platform_url.rstrip("/")
|
||||
self._platform_key = platform_key
|
||||
self._platform_secret = platform_secret
|
||||
self._queue: asyncio.Queue = asyncio.Queue()
|
||||
self._worker_task = None
|
||||
self._running = False
|
||||
|
||||
@property
|
||||
def pending_count(self) -> int:
|
||||
return self._queue.qsize()
|
||||
|
||||
async def start(self):
|
||||
if self._running:
|
||||
return
|
||||
self._running = True
|
||||
self._worker_task = asyncio.create_task(self._worker_loop(), name="billing-worker")
|
||||
logger.info("BillingTaskManager worker 已启动")
|
||||
|
||||
async def shutdown(self, timeout: float = 10.0):
|
||||
if not self._running or self._worker_task is None:
|
||||
return
|
||||
self._running = False
|
||||
try:
|
||||
await asyncio.wait_for(self._queue.join(), timeout=timeout)
|
||||
self._worker_task.cancel()
|
||||
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||
logger.info("BillingTaskManager 已优雅关闭")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("BillingTaskManager 关闭超时,剩余 %d 个任务", self._queue.qsize())
|
||||
self._worker_task.cancel()
|
||||
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||
|
||||
async def enqueue(self, api_key: str, api_secret: str, api_name: str, usage_count: int):
|
||||
self._queue.put_nowait(_BillingTask(
|
||||
api_key=api_key,
|
||||
api_secret=api_secret,
|
||||
api_name=api_name,
|
||||
usage_count=usage_count,
|
||||
))
|
||||
|
||||
async def _worker_loop(self):
|
||||
while self._running or not self._queue.empty():
|
||||
try:
|
||||
task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
try:
|
||||
await self._execute_with_retry(task)
|
||||
except Exception as e:
|
||||
logger.error("扣费最终失败: [%s] api=%s error=%s", task.task_id, task.api_name, e)
|
||||
finally:
|
||||
self._queue.task_done()
|
||||
|
||||
async def _execute_with_retry(self, task: _BillingTask):
|
||||
last_error = None
|
||||
for attempt in range(1, task.max_retries + 1):
|
||||
try:
|
||||
result = await self._do_deduct(task)
|
||||
if result and result.get("success"):
|
||||
logger.info(
|
||||
"扣费成功: [%s] api=%s count=%d attempt=%d/%d",
|
||||
task.task_id, task.api_name, task.usage_count, attempt, task.max_retries,
|
||||
)
|
||||
return
|
||||
last_error = result.get("message", "扣费接口返回失败") if result else "扣费接口无响应"
|
||||
except Exception as e:
|
||||
last_error = str(e)
|
||||
if attempt < task.max_retries:
|
||||
await asyncio.sleep(2 ** (attempt - 1))
|
||||
raise RuntimeError(f"扣费最终失败: {last_error}")
|
||||
|
||||
async def _do_deduct(self, task: _BillingTask) -> dict:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self._platform_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||
headers={"Authorization": f"token {self._platform_key}:{self._platform_secret}"},
|
||||
json={
|
||||
"api_key": task.api_key,
|
||||
"api_secret": task.api_secret,
|
||||
"api_name": task.api_name,
|
||||
"usage_count": task.usage_count,
|
||||
},
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise RuntimeError(f"扣费服务返回 HTTP {resp.status}")
|
||||
result = await resp.json()
|
||||
if "message" in result and isinstance(result["message"], dict):
|
||||
result = result["message"]
|
||||
return result
|
||||
|
||||
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||
"""验证API密钥和团队余额"""
|
||||
@ -33,33 +144,11 @@ async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||
|
||||
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
||||
"""从Jingrow平台扣除API使用费"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{settings.jingrow_api_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||
json={
|
||||
"api_key": api_key,
|
||||
"api_secret": api_secret,
|
||||
"api_name": api_name,
|
||||
"usage_count": usage_count
|
||||
}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
||||
|
||||
result = await response.json()
|
||||
if "message" in result and isinstance(result["message"], dict):
|
||||
result = result["message"]
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
||||
billing_manager = BillingTaskManager(
|
||||
platform_url=settings.jingrow_api_url,
|
||||
platform_key=settings.jingrow_api_key,
|
||||
platform_secret=settings.jingrow_api_secret,
|
||||
)
|
||||
|
||||
def get_token_from_request(request) -> str:
|
||||
"""从请求中获取访问令牌"""
|
||||
@ -122,7 +211,7 @@ def jingrow_api_verify_and_billing(api_name: str):
|
||||
yield chunk
|
||||
|
||||
if success_count > 0:
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, success_count)
|
||||
|
||||
return StreamingResponse(
|
||||
wrapped_generator(),
|
||||
@ -132,10 +221,10 @@ def jingrow_api_verify_and_billing(api_name: str):
|
||||
|
||||
if isinstance(result, dict) and result.get("success") is True:
|
||||
actual_usage_count = result.get("successful_count", usage_count)
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, actual_usage_count)
|
||||
return result
|
||||
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, usage_count)
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
|
||||
@ -1,11 +1,23 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from settings import settings
|
||||
from api import router
|
||||
from utils import billing_manager
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理"""
|
||||
await billing_manager.start()
|
||||
yield
|
||||
await billing_manager.shutdown()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="Jupscale",
|
||||
description="Jupscale放大图片API",
|
||||
version="1.0.0"
|
||||
version="1.0.0",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# 注册路由
|
||||
|
||||
@ -3,9 +3,120 @@ from functools import wraps
|
||||
from fastapi import HTTPException
|
||||
import os
|
||||
from typing import Callable, Any, Dict, Optional, Tuple
|
||||
from settings import settings
|
||||
from fastapi.responses import StreamingResponse
|
||||
import json
|
||||
from settings import settings
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _BillingTask:
|
||||
api_key: str
|
||||
api_secret: str
|
||||
api_name: str
|
||||
usage_count: int
|
||||
max_retries: int = 3
|
||||
created_at: float = field(default_factory=asyncio.get_event_loop().time)
|
||||
task_id: str = field(default_factory=lambda: hex(id(object()))[2:])
|
||||
|
||||
|
||||
class BillingTaskManager:
|
||||
"""异步扣费任务管理器(带指数退避重试)"""
|
||||
|
||||
def __init__(self, platform_url: str, platform_key: str, platform_secret: str):
|
||||
self._platform_url = platform_url.rstrip("/")
|
||||
self._platform_key = platform_key
|
||||
self._platform_secret = platform_secret
|
||||
self._queue: asyncio.Queue = asyncio.Queue()
|
||||
self._worker_task = None
|
||||
self._running = False
|
||||
|
||||
@property
|
||||
def pending_count(self) -> int:
|
||||
return self._queue.qsize()
|
||||
|
||||
async def start(self):
|
||||
if self._running:
|
||||
return
|
||||
self._running = True
|
||||
self._worker_task = asyncio.create_task(self._worker_loop(), name="billing-worker")
|
||||
logger.info("BillingTaskManager worker 已启动")
|
||||
|
||||
async def shutdown(self, timeout: float = 10.0):
|
||||
if not self._running or self._worker_task is None:
|
||||
return
|
||||
self._running = False
|
||||
try:
|
||||
await asyncio.wait_for(self._queue.join(), timeout=timeout)
|
||||
self._worker_task.cancel()
|
||||
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||
logger.info("BillingTaskManager 已优雅关闭")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("BillingTaskManager 关闭超时,剩余 %d 个任务", self._queue.qsize())
|
||||
self._worker_task.cancel()
|
||||
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||
|
||||
async def enqueue(self, api_key: str, api_secret: str, api_name: str, usage_count: int):
|
||||
self._queue.put_nowait(_BillingTask(
|
||||
api_key=api_key,
|
||||
api_secret=api_secret,
|
||||
api_name=api_name,
|
||||
usage_count=usage_count,
|
||||
))
|
||||
|
||||
async def _worker_loop(self):
|
||||
while self._running or not self._queue.empty():
|
||||
try:
|
||||
task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
try:
|
||||
await self._execute_with_retry(task)
|
||||
except Exception as e:
|
||||
logger.error("扣费最终失败: [%s] api=%s error=%s", task.task_id, task.api_name, e)
|
||||
finally:
|
||||
self._queue.task_done()
|
||||
|
||||
async def _execute_with_retry(self, task: _BillingTask):
|
||||
last_error = None
|
||||
for attempt in range(1, task.max_retries + 1):
|
||||
try:
|
||||
result = await self._do_deduct(task)
|
||||
if result and result.get("success"):
|
||||
logger.info(
|
||||
"扣费成功: [%s] api=%s count=%d attempt=%d/%d",
|
||||
task.task_id, task.api_name, task.usage_count, attempt, task.max_retries,
|
||||
)
|
||||
return
|
||||
last_error = result.get("message", "扣费接口返回失败") if result else "扣费接口无响应"
|
||||
except Exception as e:
|
||||
last_error = str(e)
|
||||
if attempt < task.max_retries:
|
||||
await asyncio.sleep(2 ** (attempt - 1))
|
||||
raise RuntimeError(f"扣费最终失败: {last_error}")
|
||||
|
||||
async def _do_deduct(self, task: _BillingTask) -> dict:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self._platform_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||
headers={"Authorization": f"token {self._platform_key}:{self._platform_secret}"},
|
||||
json={
|
||||
"api_key": task.api_key,
|
||||
"api_secret": task.api_secret,
|
||||
"api_name": task.api_name,
|
||||
"usage_count": task.usage_count,
|
||||
},
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise RuntimeError(f"扣费服务返回 HTTP {resp.status}")
|
||||
result = await resp.json()
|
||||
if "message" in result and isinstance(result["message"], dict):
|
||||
result = result["message"]
|
||||
return result
|
||||
|
||||
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||
"""验证API密钥和团队余额"""
|
||||
@ -33,33 +144,11 @@ async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||
|
||||
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
||||
"""从Jingrow平台扣除API使用费"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{settings.jingrow_api_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||
json={
|
||||
"api_key": api_key,
|
||||
"api_secret": api_secret,
|
||||
"api_name": api_name,
|
||||
"usage_count": usage_count
|
||||
}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
||||
|
||||
result = await response.json()
|
||||
if "message" in result and isinstance(result["message"], dict):
|
||||
result = result["message"]
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
||||
billing_manager = BillingTaskManager(
|
||||
platform_url=settings.jingrow_api_url,
|
||||
platform_key=settings.jingrow_api_key,
|
||||
platform_secret=settings.jingrow_api_secret,
|
||||
)
|
||||
|
||||
def get_token_from_request(request) -> str:
|
||||
"""从请求中获取访问令牌"""
|
||||
@ -122,7 +211,7 @@ def jingrow_api_verify_and_billing(api_name: str):
|
||||
yield chunk
|
||||
|
||||
if success_count > 0:
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, success_count)
|
||||
|
||||
return StreamingResponse(
|
||||
wrapped_generator(),
|
||||
@ -132,10 +221,10 @@ def jingrow_api_verify_and_billing(api_name: str):
|
||||
|
||||
if isinstance(result, dict) and result.get("success") is True:
|
||||
actual_usage_count = result.get("successful_count", usage_count)
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, actual_usage_count)
|
||||
return result
|
||||
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, usage_count)
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
|
||||
@ -1,11 +1,23 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from settings import settings
|
||||
from api import router
|
||||
from utils import billing_manager
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理"""
|
||||
await billing_manager.start()
|
||||
yield
|
||||
await billing_manager.shutdown()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="Jvector",
|
||||
description="Jvector转矢量图API",
|
||||
version="1.0.0"
|
||||
version="1.0.0",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# 注册路由
|
||||
|
||||
@ -3,9 +3,120 @@ from functools import wraps
|
||||
from fastapi import HTTPException
|
||||
import os
|
||||
from typing import Callable, Any, Dict, Optional, Tuple
|
||||
from settings import settings
|
||||
from fastapi.responses import StreamingResponse
|
||||
import json
|
||||
from settings import settings
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _BillingTask:
|
||||
api_key: str
|
||||
api_secret: str
|
||||
api_name: str
|
||||
usage_count: int
|
||||
max_retries: int = 3
|
||||
created_at: float = field(default_factory=asyncio.get_event_loop().time)
|
||||
task_id: str = field(default_factory=lambda: hex(id(object()))[2:])
|
||||
|
||||
|
||||
class BillingTaskManager:
|
||||
"""异步扣费任务管理器(带指数退避重试)"""
|
||||
|
||||
def __init__(self, platform_url: str, platform_key: str, platform_secret: str):
|
||||
self._platform_url = platform_url.rstrip("/")
|
||||
self._platform_key = platform_key
|
||||
self._platform_secret = platform_secret
|
||||
self._queue: asyncio.Queue = asyncio.Queue()
|
||||
self._worker_task = None
|
||||
self._running = False
|
||||
|
||||
@property
|
||||
def pending_count(self) -> int:
|
||||
return self._queue.qsize()
|
||||
|
||||
async def start(self):
|
||||
if self._running:
|
||||
return
|
||||
self._running = True
|
||||
self._worker_task = asyncio.create_task(self._worker_loop(), name="billing-worker")
|
||||
logger.info("BillingTaskManager worker 已启动")
|
||||
|
||||
async def shutdown(self, timeout: float = 10.0):
|
||||
if not self._running or self._worker_task is None:
|
||||
return
|
||||
self._running = False
|
||||
try:
|
||||
await asyncio.wait_for(self._queue.join(), timeout=timeout)
|
||||
self._worker_task.cancel()
|
||||
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||
logger.info("BillingTaskManager 已优雅关闭")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("BillingTaskManager 关闭超时,剩余 %d 个任务", self._queue.qsize())
|
||||
self._worker_task.cancel()
|
||||
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||
|
||||
async def enqueue(self, api_key: str, api_secret: str, api_name: str, usage_count: int):
|
||||
self._queue.put_nowait(_BillingTask(
|
||||
api_key=api_key,
|
||||
api_secret=api_secret,
|
||||
api_name=api_name,
|
||||
usage_count=usage_count,
|
||||
))
|
||||
|
||||
async def _worker_loop(self):
|
||||
while self._running or not self._queue.empty():
|
||||
try:
|
||||
task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
try:
|
||||
await self._execute_with_retry(task)
|
||||
except Exception as e:
|
||||
logger.error("扣费最终失败: [%s] api=%s error=%s", task.task_id, task.api_name, e)
|
||||
finally:
|
||||
self._queue.task_done()
|
||||
|
||||
async def _execute_with_retry(self, task: _BillingTask):
|
||||
last_error = None
|
||||
for attempt in range(1, task.max_retries + 1):
|
||||
try:
|
||||
result = await self._do_deduct(task)
|
||||
if result and result.get("success"):
|
||||
logger.info(
|
||||
"扣费成功: [%s] api=%s count=%d attempt=%d/%d",
|
||||
task.task_id, task.api_name, task.usage_count, attempt, task.max_retries,
|
||||
)
|
||||
return
|
||||
last_error = result.get("message", "扣费接口返回失败") if result else "扣费接口无响应"
|
||||
except Exception as e:
|
||||
last_error = str(e)
|
||||
if attempt < task.max_retries:
|
||||
await asyncio.sleep(2 ** (attempt - 1))
|
||||
raise RuntimeError(f"扣费最终失败: {last_error}")
|
||||
|
||||
async def _do_deduct(self, task: _BillingTask) -> dict:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self._platform_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||
headers={"Authorization": f"token {self._platform_key}:{self._platform_secret}"},
|
||||
json={
|
||||
"api_key": task.api_key,
|
||||
"api_secret": task.api_secret,
|
||||
"api_name": task.api_name,
|
||||
"usage_count": task.usage_count,
|
||||
},
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise RuntimeError(f"扣费服务返回 HTTP {resp.status}")
|
||||
result = await resp.json()
|
||||
if "message" in result and isinstance(result["message"], dict):
|
||||
result = result["message"]
|
||||
return result
|
||||
|
||||
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||
"""验证API密钥和团队余额"""
|
||||
@ -33,33 +144,11 @@ async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||
|
||||
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
||||
"""从Jingrow平台扣除API使用费"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{settings.jingrow_api_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||
json={
|
||||
"api_key": api_key,
|
||||
"api_secret": api_secret,
|
||||
"api_name": api_name,
|
||||
"usage_count": usage_count
|
||||
}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
||||
|
||||
result = await response.json()
|
||||
if "message" in result and isinstance(result["message"], dict):
|
||||
result = result["message"]
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
||||
billing_manager = BillingTaskManager(
|
||||
platform_url=settings.jingrow_api_url,
|
||||
platform_key=settings.jingrow_api_key,
|
||||
platform_secret=settings.jingrow_api_secret,
|
||||
)
|
||||
|
||||
def get_token_from_request(request) -> str:
|
||||
"""从请求中获取访问令牌"""
|
||||
@ -122,7 +211,7 @@ def jingrow_api_verify_and_billing(api_name: str):
|
||||
yield chunk
|
||||
|
||||
if success_count > 0:
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, success_count)
|
||||
|
||||
return StreamingResponse(
|
||||
wrapped_generator(),
|
||||
@ -132,10 +221,10 @@ def jingrow_api_verify_and_billing(api_name: str):
|
||||
|
||||
if isinstance(result, dict) and result.get("success") is True:
|
||||
actual_usage_count = result.get("successful_count", usage_count)
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, actual_usage_count)
|
||||
return result
|
||||
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, usage_count)
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
|
||||
@ -1,11 +1,23 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from settings import settings
|
||||
from api import router
|
||||
from utils import billing_manager
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理"""
|
||||
await billing_manager.start()
|
||||
yield
|
||||
await billing_manager.shutdown()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="Midjourney",
|
||||
description="Midjourney绘画服务API",
|
||||
version="1.0.0"
|
||||
version="1.0.0",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# 注册路由
|
||||
|
||||
@ -3,7 +3,6 @@ from functools import wraps
|
||||
from fastapi import HTTPException
|
||||
import os
|
||||
from typing import Callable, Any, Dict, Optional, Tuple, List
|
||||
from settings import settings
|
||||
from fastapi.responses import StreamingResponse
|
||||
import json
|
||||
import requests
|
||||
@ -12,6 +11,118 @@ import re
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
from PIL import Image
|
||||
from settings import settings
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _BillingTask:
|
||||
api_key: str
|
||||
api_secret: str
|
||||
api_name: str
|
||||
usage_count: int
|
||||
max_retries: int = 3
|
||||
created_at: float = field(default_factory=asyncio.get_event_loop().time)
|
||||
task_id: str = field(default_factory=lambda: hex(id(object()))[2:])
|
||||
|
||||
|
||||
class BillingTaskManager:
|
||||
"""异步扣费任务管理器(带指数退避重试)"""
|
||||
|
||||
def __init__(self, platform_url: str, platform_key: str, platform_secret: str):
|
||||
self._platform_url = platform_url.rstrip("/")
|
||||
self._platform_key = platform_key
|
||||
self._platform_secret = platform_secret
|
||||
self._queue: asyncio.Queue = asyncio.Queue()
|
||||
self._worker_task = None
|
||||
self._running = False
|
||||
|
||||
@property
|
||||
def pending_count(self) -> int:
|
||||
return self._queue.qsize()
|
||||
|
||||
async def start(self):
|
||||
if self._running:
|
||||
return
|
||||
self._running = True
|
||||
self._worker_task = asyncio.create_task(self._worker_loop(), name="billing-worker")
|
||||
logger.info("BillingTaskManager worker 已启动")
|
||||
|
||||
async def shutdown(self, timeout: float = 10.0):
|
||||
if not self._running or self._worker_task is None:
|
||||
return
|
||||
self._running = False
|
||||
try:
|
||||
await asyncio.wait_for(self._queue.join(), timeout=timeout)
|
||||
self._worker_task.cancel()
|
||||
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||
logger.info("BillingTaskManager 已优雅关闭")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("BillingTaskManager 关闭超时,剩余 %d 个任务", self._queue.qsize())
|
||||
self._worker_task.cancel()
|
||||
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||
|
||||
async def enqueue(self, api_key: str, api_secret: str, api_name: str, usage_count: int):
|
||||
self._queue.put_nowait(_BillingTask(
|
||||
api_key=api_key,
|
||||
api_secret=api_secret,
|
||||
api_name=api_name,
|
||||
usage_count=usage_count,
|
||||
))
|
||||
|
||||
async def _worker_loop(self):
|
||||
while self._running or not self._queue.empty():
|
||||
try:
|
||||
task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
try:
|
||||
await self._execute_with_retry(task)
|
||||
except Exception as e:
|
||||
logger.error("扣费最终失败: [%s] api=%s error=%s", task.task_id, task.api_name, e)
|
||||
finally:
|
||||
self._queue.task_done()
|
||||
|
||||
async def _execute_with_retry(self, task: _BillingTask):
|
||||
last_error = None
|
||||
for attempt in range(1, task.max_retries + 1):
|
||||
try:
|
||||
result = await self._do_deduct(task)
|
||||
if result and result.get("success"):
|
||||
logger.info(
|
||||
"扣费成功: [%s] api=%s count=%d attempt=%d/%d",
|
||||
task.task_id, task.api_name, task.usage_count, attempt, task.max_retries,
|
||||
)
|
||||
return
|
||||
last_error = result.get("message", "扣费接口返回失败") if result else "扣费接口无响应"
|
||||
except Exception as e:
|
||||
last_error = str(e)
|
||||
if attempt < task.max_retries:
|
||||
await asyncio.sleep(2 ** (attempt - 1))
|
||||
raise RuntimeError(f"扣费最终失败: {last_error}")
|
||||
|
||||
async def _do_deduct(self, task: _BillingTask) -> dict:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self._platform_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||
headers={"Authorization": f"token {self._platform_key}:{self._platform_secret}"},
|
||||
json={
|
||||
"api_key": task.api_key,
|
||||
"api_secret": task.api_secret,
|
||||
"api_name": task.api_name,
|
||||
"usage_count": task.usage_count,
|
||||
},
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise RuntimeError(f"扣费服务返回 HTTP {resp.status}")
|
||||
result = await resp.json()
|
||||
if "message" in result and isinstance(result["message"], dict):
|
||||
result = result["message"]
|
||||
return result
|
||||
|
||||
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||
"""验证API密钥和团队余额"""
|
||||
@ -39,33 +150,11 @@ async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||
|
||||
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
||||
"""从Jingrow平台扣除API使用费"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{settings.jingrow_api_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||
json={
|
||||
"api_key": api_key,
|
||||
"api_secret": api_secret,
|
||||
"api_name": api_name,
|
||||
"usage_count": usage_count
|
||||
}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
||||
|
||||
result = await response.json()
|
||||
if "message" in result and isinstance(result["message"], dict):
|
||||
result = result["message"]
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
||||
billing_manager = BillingTaskManager(
|
||||
platform_url=settings.jingrow_api_url,
|
||||
platform_key=settings.jingrow_api_key,
|
||||
platform_secret=settings.jingrow_api_secret,
|
||||
)
|
||||
|
||||
def get_token_from_request(request) -> str:
|
||||
"""从请求中获取访问令牌"""
|
||||
@ -128,7 +217,7 @@ def jingrow_api_verify_and_billing(api_name: str):
|
||||
yield chunk
|
||||
|
||||
if success_count > 0:
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, success_count)
|
||||
|
||||
return StreamingResponse(
|
||||
wrapped_generator(),
|
||||
@ -138,10 +227,10 @@ def jingrow_api_verify_and_billing(api_name: str):
|
||||
|
||||
if isinstance(result, dict) and result.get("success") is True:
|
||||
actual_usage_count = result.get("success_count", usage_count)
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, actual_usage_count)
|
||||
return result
|
||||
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, usage_count)
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
|
||||
@ -1,11 +1,23 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from settings import settings
|
||||
from api import router
|
||||
from utils import billing_manager
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理"""
|
||||
await billing_manager.start()
|
||||
yield
|
||||
await billing_manager.shutdown()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="Pattern to Tshirt",
|
||||
description="将图片中的花型添加到T恤上",
|
||||
version="1.0.0"
|
||||
version="1.0.0",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# 注册路由
|
||||
|
||||
@ -3,9 +3,120 @@ from functools import wraps
|
||||
from fastapi import HTTPException
|
||||
import os
|
||||
from typing import Callable, Any, Dict, Optional, Tuple
|
||||
from settings import settings
|
||||
from fastapi.responses import StreamingResponse
|
||||
import json
|
||||
from settings import settings
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _BillingTask:
|
||||
api_key: str
|
||||
api_secret: str
|
||||
api_name: str
|
||||
usage_count: int
|
||||
max_retries: int = 3
|
||||
created_at: float = field(default_factory=asyncio.get_event_loop().time)
|
||||
task_id: str = field(default_factory=lambda: hex(id(object()))[2:])
|
||||
|
||||
|
||||
class BillingTaskManager:
|
||||
"""异步扣费任务管理器(带指数退避重试)"""
|
||||
|
||||
def __init__(self, platform_url: str, platform_key: str, platform_secret: str):
|
||||
self._platform_url = platform_url.rstrip("/")
|
||||
self._platform_key = platform_key
|
||||
self._platform_secret = platform_secret
|
||||
self._queue: asyncio.Queue = asyncio.Queue()
|
||||
self._worker_task = None
|
||||
self._running = False
|
||||
|
||||
@property
|
||||
def pending_count(self) -> int:
|
||||
return self._queue.qsize()
|
||||
|
||||
async def start(self):
|
||||
if self._running:
|
||||
return
|
||||
self._running = True
|
||||
self._worker_task = asyncio.create_task(self._worker_loop(), name="billing-worker")
|
||||
logger.info("BillingTaskManager worker 已启动")
|
||||
|
||||
async def shutdown(self, timeout: float = 10.0):
|
||||
if not self._running or self._worker_task is None:
|
||||
return
|
||||
self._running = False
|
||||
try:
|
||||
await asyncio.wait_for(self._queue.join(), timeout=timeout)
|
||||
self._worker_task.cancel()
|
||||
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||
logger.info("BillingTaskManager 已优雅关闭")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("BillingTaskManager 关闭超时,剩余 %d 个任务", self._queue.qsize())
|
||||
self._worker_task.cancel()
|
||||
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||
|
||||
async def enqueue(self, api_key: str, api_secret: str, api_name: str, usage_count: int):
|
||||
self._queue.put_nowait(_BillingTask(
|
||||
api_key=api_key,
|
||||
api_secret=api_secret,
|
||||
api_name=api_name,
|
||||
usage_count=usage_count,
|
||||
))
|
||||
|
||||
async def _worker_loop(self):
|
||||
while self._running or not self._queue.empty():
|
||||
try:
|
||||
task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
try:
|
||||
await self._execute_with_retry(task)
|
||||
except Exception as e:
|
||||
logger.error("扣费最终失败: [%s] api=%s error=%s", task.task_id, task.api_name, e)
|
||||
finally:
|
||||
self._queue.task_done()
|
||||
|
||||
async def _execute_with_retry(self, task: _BillingTask):
|
||||
last_error = None
|
||||
for attempt in range(1, task.max_retries + 1):
|
||||
try:
|
||||
result = await self._do_deduct(task)
|
||||
if result and result.get("success"):
|
||||
logger.info(
|
||||
"扣费成功: [%s] api=%s count=%d attempt=%d/%d",
|
||||
task.task_id, task.api_name, task.usage_count, attempt, task.max_retries,
|
||||
)
|
||||
return
|
||||
last_error = result.get("message", "扣费接口返回失败") if result else "扣费接口无响应"
|
||||
except Exception as e:
|
||||
last_error = str(e)
|
||||
if attempt < task.max_retries:
|
||||
await asyncio.sleep(2 ** (attempt - 1))
|
||||
raise RuntimeError(f"扣费最终失败: {last_error}")
|
||||
|
||||
async def _do_deduct(self, task: _BillingTask) -> dict:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self._platform_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||
headers={"Authorization": f"token {self._platform_key}:{self._platform_secret}"},
|
||||
json={
|
||||
"api_key": task.api_key,
|
||||
"api_secret": task.api_secret,
|
||||
"api_name": task.api_name,
|
||||
"usage_count": task.usage_count,
|
||||
},
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise RuntimeError(f"扣费服务返回 HTTP {resp.status}")
|
||||
result = await resp.json()
|
||||
if "message" in result and isinstance(result["message"], dict):
|
||||
result = result["message"]
|
||||
return result
|
||||
|
||||
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||
"""验证API密钥和团队余额"""
|
||||
@ -33,33 +144,11 @@ async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||
|
||||
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
||||
"""从Jingrow平台扣除API使用费"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{settings.jingrow_api_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||
json={
|
||||
"api_key": api_key,
|
||||
"api_secret": api_secret,
|
||||
"api_name": api_name,
|
||||
"usage_count": usage_count
|
||||
}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
||||
|
||||
result = await response.json()
|
||||
if "message" in result and isinstance(result["message"], dict):
|
||||
result = result["message"]
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
||||
billing_manager = BillingTaskManager(
|
||||
platform_url=settings.jingrow_api_url,
|
||||
platform_key=settings.jingrow_api_key,
|
||||
platform_secret=settings.jingrow_api_secret,
|
||||
)
|
||||
|
||||
def get_token_from_request(request) -> str:
|
||||
"""从请求中获取访问令牌"""
|
||||
@ -122,7 +211,7 @@ def jingrow_api_verify_and_billing(api_name: str):
|
||||
yield chunk
|
||||
|
||||
if success_count > 0:
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, success_count)
|
||||
|
||||
return StreamingResponse(
|
||||
wrapped_generator(),
|
||||
@ -132,10 +221,10 @@ def jingrow_api_verify_and_billing(api_name: str):
|
||||
|
||||
if isinstance(result, dict) and result.get("success") is True:
|
||||
actual_usage_count = result.get("successful_count", usage_count)
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, actual_usage_count)
|
||||
return result
|
||||
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, usage_count)
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
|
||||
@ -2,16 +2,19 @@ from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from settings import settings
|
||||
from api import router, service
|
||||
from utils import billing_manager
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理"""
|
||||
# 启动时初始化
|
||||
await billing_manager.start()
|
||||
if settings.enable_queue_batch:
|
||||
await service._start_queue_processor()
|
||||
yield
|
||||
# 关闭时清理资源
|
||||
await billing_manager.shutdown()
|
||||
await service.cleanup()
|
||||
|
||||
|
||||
|
||||
@ -3,9 +3,120 @@ from functools import wraps
|
||||
from fastapi import HTTPException
|
||||
import os
|
||||
from typing import Callable, Any, Dict, Optional, Tuple
|
||||
from settings import settings
|
||||
from fastapi.responses import StreamingResponse
|
||||
import json
|
||||
from settings import settings
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _BillingTask:
|
||||
api_key: str
|
||||
api_secret: str
|
||||
api_name: str
|
||||
usage_count: int
|
||||
max_retries: int = 3
|
||||
created_at: float = field(default_factory=asyncio.get_event_loop().time)
|
||||
task_id: str = field(default_factory=lambda: hex(id(object()))[2:])
|
||||
|
||||
|
||||
class BillingTaskManager:
|
||||
"""异步扣费任务管理器(带指数退避重试)"""
|
||||
|
||||
def __init__(self, platform_url: str, platform_key: str, platform_secret: str):
|
||||
self._platform_url = platform_url.rstrip("/")
|
||||
self._platform_key = platform_key
|
||||
self._platform_secret = platform_secret
|
||||
self._queue: asyncio.Queue = asyncio.Queue()
|
||||
self._worker_task = None
|
||||
self._running = False
|
||||
|
||||
@property
|
||||
def pending_count(self) -> int:
|
||||
return self._queue.qsize()
|
||||
|
||||
async def start(self):
|
||||
if self._running:
|
||||
return
|
||||
self._running = True
|
||||
self._worker_task = asyncio.create_task(self._worker_loop(), name="billing-worker")
|
||||
logger.info("BillingTaskManager worker 已启动")
|
||||
|
||||
async def shutdown(self, timeout: float = 10.0):
|
||||
if not self._running or self._worker_task is None:
|
||||
return
|
||||
self._running = False
|
||||
try:
|
||||
await asyncio.wait_for(self._queue.join(), timeout=timeout)
|
||||
self._worker_task.cancel()
|
||||
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||
logger.info("BillingTaskManager 已优雅关闭")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("BillingTaskManager 关闭超时,剩余 %d 个任务", self._queue.qsize())
|
||||
self._worker_task.cancel()
|
||||
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||
|
||||
async def enqueue(self, api_key: str, api_secret: str, api_name: str, usage_count: int):
|
||||
self._queue.put_nowait(_BillingTask(
|
||||
api_key=api_key,
|
||||
api_secret=api_secret,
|
||||
api_name=api_name,
|
||||
usage_count=usage_count,
|
||||
))
|
||||
|
||||
async def _worker_loop(self):
|
||||
while self._running or not self._queue.empty():
|
||||
try:
|
||||
task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
try:
|
||||
await self._execute_with_retry(task)
|
||||
except Exception as e:
|
||||
logger.error("扣费最终失败: [%s] api=%s error=%s", task.task_id, task.api_name, e)
|
||||
finally:
|
||||
self._queue.task_done()
|
||||
|
||||
async def _execute_with_retry(self, task: _BillingTask):
|
||||
last_error = None
|
||||
for attempt in range(1, task.max_retries + 1):
|
||||
try:
|
||||
result = await self._do_deduct(task)
|
||||
if result and result.get("success"):
|
||||
logger.info(
|
||||
"扣费成功: [%s] api=%s count=%d attempt=%d/%d",
|
||||
task.task_id, task.api_name, task.usage_count, attempt, task.max_retries,
|
||||
)
|
||||
return
|
||||
last_error = result.get("message", "扣费接口返回失败") if result else "扣费接口无响应"
|
||||
except Exception as e:
|
||||
last_error = str(e)
|
||||
if attempt < task.max_retries:
|
||||
await asyncio.sleep(2 ** (attempt - 1))
|
||||
raise RuntimeError(f"扣费最终失败: {last_error}")
|
||||
|
||||
async def _do_deduct(self, task: _BillingTask) -> dict:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self._platform_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||
headers={"Authorization": f"token {self._platform_key}:{self._platform_secret}"},
|
||||
json={
|
||||
"api_key": task.api_key,
|
||||
"api_secret": task.api_secret,
|
||||
"api_name": task.api_name,
|
||||
"usage_count": task.usage_count,
|
||||
},
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise RuntimeError(f"扣费服务返回 HTTP {resp.status}")
|
||||
result = await resp.json()
|
||||
if "message" in result and isinstance(result["message"], dict):
|
||||
result = result["message"]
|
||||
return result
|
||||
|
||||
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||
"""验证API密钥和团队余额"""
|
||||
@ -33,33 +144,11 @@ async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||
|
||||
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
||||
"""从Jingrow平台扣除API使用费"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{settings.jingrow_api_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||
json={
|
||||
"api_key": api_key,
|
||||
"api_secret": api_secret,
|
||||
"api_name": api_name,
|
||||
"usage_count": usage_count
|
||||
}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
||||
|
||||
result = await response.json()
|
||||
if "message" in result and isinstance(result["message"], dict):
|
||||
result = result["message"]
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
||||
billing_manager = BillingTaskManager(
|
||||
platform_url=settings.jingrow_api_url,
|
||||
platform_key=settings.jingrow_api_key,
|
||||
platform_secret=settings.jingrow_api_secret,
|
||||
)
|
||||
|
||||
def get_token_from_request(request) -> str:
|
||||
"""从请求中获取访问令牌"""
|
||||
@ -122,7 +211,7 @@ def jingrow_api_verify_and_billing(api_name: str):
|
||||
yield chunk
|
||||
|
||||
if success_count > 0:
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, success_count)
|
||||
|
||||
return StreamingResponse(
|
||||
wrapped_generator(),
|
||||
@ -132,10 +221,10 @@ def jingrow_api_verify_and_billing(api_name: str):
|
||||
|
||||
if isinstance(result, dict) and result.get("success") is True:
|
||||
actual_usage_count = result.get("successful_count", usage_count)
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, actual_usage_count)
|
||||
return result
|
||||
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, usage_count)
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
|
||||
@ -1,11 +1,23 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from settings import settings
|
||||
from api import router
|
||||
from utils import billing_manager
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""应用生命周期管理"""
|
||||
await billing_manager.start()
|
||||
yield
|
||||
await billing_manager.shutdown()
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="Tryon",
|
||||
description="虚拟试穿",
|
||||
version="1.0.0"
|
||||
version="1.0.0",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# 注册路由
|
||||
|
||||
@ -3,9 +3,120 @@ from functools import wraps
|
||||
from fastapi import HTTPException
|
||||
import os
|
||||
from typing import Callable, Any, Dict, Optional, Tuple
|
||||
from settings import settings
|
||||
from fastapi.responses import StreamingResponse
|
||||
import json
|
||||
from settings import settings
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _BillingTask:
|
||||
api_key: str
|
||||
api_secret: str
|
||||
api_name: str
|
||||
usage_count: int
|
||||
max_retries: int = 3
|
||||
created_at: float = field(default_factory=asyncio.get_event_loop().time)
|
||||
task_id: str = field(default_factory=lambda: hex(id(object()))[2:])
|
||||
|
||||
|
||||
class BillingTaskManager:
|
||||
"""异步扣费任务管理器(带指数退避重试)"""
|
||||
|
||||
def __init__(self, platform_url: str, platform_key: str, platform_secret: str):
|
||||
self._platform_url = platform_url.rstrip("/")
|
||||
self._platform_key = platform_key
|
||||
self._platform_secret = platform_secret
|
||||
self._queue: asyncio.Queue = asyncio.Queue()
|
||||
self._worker_task = None
|
||||
self._running = False
|
||||
|
||||
@property
|
||||
def pending_count(self) -> int:
|
||||
return self._queue.qsize()
|
||||
|
||||
async def start(self):
|
||||
if self._running:
|
||||
return
|
||||
self._running = True
|
||||
self._worker_task = asyncio.create_task(self._worker_loop(), name="billing-worker")
|
||||
logger.info("BillingTaskManager worker 已启动")
|
||||
|
||||
async def shutdown(self, timeout: float = 10.0):
|
||||
if not self._running or self._worker_task is None:
|
||||
return
|
||||
self._running = False
|
||||
try:
|
||||
await asyncio.wait_for(self._queue.join(), timeout=timeout)
|
||||
self._worker_task.cancel()
|
||||
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||
logger.info("BillingTaskManager 已优雅关闭")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("BillingTaskManager 关闭超时,剩余 %d 个任务", self._queue.qsize())
|
||||
self._worker_task.cancel()
|
||||
await asyncio.gather(self._worker_task, return_exceptions=True)
|
||||
|
||||
async def enqueue(self, api_key: str, api_secret: str, api_name: str, usage_count: int):
|
||||
self._queue.put_nowait(_BillingTask(
|
||||
api_key=api_key,
|
||||
api_secret=api_secret,
|
||||
api_name=api_name,
|
||||
usage_count=usage_count,
|
||||
))
|
||||
|
||||
async def _worker_loop(self):
|
||||
while self._running or not self._queue.empty():
|
||||
try:
|
||||
task = await asyncio.wait_for(self._queue.get(), timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
try:
|
||||
await self._execute_with_retry(task)
|
||||
except Exception as e:
|
||||
logger.error("扣费最终失败: [%s] api=%s error=%s", task.task_id, task.api_name, e)
|
||||
finally:
|
||||
self._queue.task_done()
|
||||
|
||||
async def _execute_with_retry(self, task: _BillingTask):
|
||||
last_error = None
|
||||
for attempt in range(1, task.max_retries + 1):
|
||||
try:
|
||||
result = await self._do_deduct(task)
|
||||
if result and result.get("success"):
|
||||
logger.info(
|
||||
"扣费成功: [%s] api=%s count=%d attempt=%d/%d",
|
||||
task.task_id, task.api_name, task.usage_count, attempt, task.max_retries,
|
||||
)
|
||||
return
|
||||
last_error = result.get("message", "扣费接口返回失败") if result else "扣费接口无响应"
|
||||
except Exception as e:
|
||||
last_error = str(e)
|
||||
if attempt < task.max_retries:
|
||||
await asyncio.sleep(2 ** (attempt - 1))
|
||||
raise RuntimeError(f"扣费最终失败: {last_error}")
|
||||
|
||||
async def _do_deduct(self, task: _BillingTask) -> dict:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self._platform_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||
headers={"Authorization": f"token {self._platform_key}:{self._platform_secret}"},
|
||||
json={
|
||||
"api_key": task.api_key,
|
||||
"api_secret": task.api_secret,
|
||||
"api_name": task.api_name,
|
||||
"usage_count": task.usage_count,
|
||||
},
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise RuntimeError(f"扣费服务返回 HTTP {resp.status}")
|
||||
result = await resp.json()
|
||||
if "message" in result and isinstance(result["message"], dict):
|
||||
result = result["message"]
|
||||
return result
|
||||
|
||||
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||
"""验证API密钥和团队余额"""
|
||||
@ -33,33 +144,11 @@ async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||
|
||||
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
||||
"""从Jingrow平台扣除API使用费"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{settings.jingrow_api_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
|
||||
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||
json={
|
||||
"api_key": api_key,
|
||||
"api_secret": api_secret,
|
||||
"api_name": api_name,
|
||||
"usage_count": usage_count
|
||||
}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
||||
|
||||
result = await response.json()
|
||||
if "message" in result and isinstance(result["message"], dict):
|
||||
result = result["message"]
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
||||
billing_manager = BillingTaskManager(
|
||||
platform_url=settings.jingrow_api_url,
|
||||
platform_key=settings.jingrow_api_key,
|
||||
platform_secret=settings.jingrow_api_secret,
|
||||
)
|
||||
|
||||
def get_token_from_request(request) -> str:
|
||||
"""从请求中获取访问令牌"""
|
||||
@ -122,7 +211,7 @@ def jingrow_api_verify_and_billing(api_name: str):
|
||||
yield chunk
|
||||
|
||||
if success_count > 0:
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, success_count)
|
||||
|
||||
return StreamingResponse(
|
||||
wrapped_generator(),
|
||||
@ -132,10 +221,10 @@ def jingrow_api_verify_and_billing(api_name: str):
|
||||
|
||||
if isinstance(result, dict) and result.get("success") is True:
|
||||
actual_usage_count = result.get("successful_count", usage_count)
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, actual_usage_count)
|
||||
return result
|
||||
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
||||
await billing_manager.enqueue(api_key, api_secret, api_name, usage_count)
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user