增加jmidjourney微服务

This commit is contained in:
jingrow 2026-04-15 17:13:57 +08:00
parent 25a43042cf
commit 213078f57b
7 changed files with 1142 additions and 1 deletions

View File

@ -16,7 +16,7 @@ class Settings(BaseSettings):
upload_url: str = "http://images.jingrow.com:8080/api/v1/image"
# Jingrow Jcloud API 配置
jingrow_api_url: str = "https://cloud.jingrow.com"
jingrow_api_url: str = "https://console.jingrow.com"
jingrow_api_key: Optional[str] = None
jingrow_api_secret: Optional[str] = None

View File

67
apps/jmidjourney/api.py Normal file
View File

@ -0,0 +1,67 @@
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from service import MidjourneyService
from utils import jingrow_api_verify_and_billing
from settings import settings
import json
import asyncio
from typing import AsyncGenerator, List
router = APIRouter(prefix=settings.router_prefix)
service = MidjourneyService()
@router.post(settings.generate_route)
@jingrow_api_verify_and_billing(api_name=settings.api_name)
async def generate_image(data: dict, request: Request):
if "prompt" not in data:
raise HTTPException(status_code=400, detail="缺少prompt参数")
prompt = data["prompt"]
config = data.get("config", {})
async def generate() -> AsyncGenerator[str, None]:
async for result in service.generate_image(prompt, config):
yield json.dumps(result, ensure_ascii=False) + "\n"
return StreamingResponse(
generate(),
media_type="application/x-ndjson",
headers={"X-Content-Type-Options": "nosniff"}
)
@router.post(settings.ve_generate_route)
@jingrow_api_verify_and_billing(api_name=settings.api_name)
async def ve_generate_image(data: dict, request: Request):
"""
使用 VectorEngine API 生成图片
Args:
data: 请求数据包含:
- prompt: 提示词必需
- config: 配置参数可选
- base64_array: base64 编码的图片数组
- notify_hook: 回调地址
- state: 自定义状态
- bot_type: 机器人类型默认 MID_JOURNEY
- split_image: 是否分割图片默认 True
Returns:
流式响应包含生成状态和结果
"""
if "prompt" not in data:
raise HTTPException(status_code=400, detail="缺少prompt参数")
prompt = data["prompt"]
config = data.get("config", {})
async def generate() -> AsyncGenerator[str, None]:
async for result in service.ve_generate_image(prompt, config):
yield json.dumps(result, ensure_ascii=False) + "\n"
return StreamingResponse(
generate(),
media_type="application/x-ndjson",
headers={"X-Content-Type-Options": "nosniff"}
)

21
apps/jmidjourney/app.py Normal file
View File

@ -0,0 +1,21 @@
from fastapi import FastAPI
from settings import settings
from api import router
app = FastAPI(
title="Midjourney",
description="Midjourney绘画服务API",
version="1.0.0"
)
# 注册路由
app.include_router(router)
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"app:app",
host=settings.host,
port=settings.port,
reload=settings.debug
)

748
apps/jmidjourney/service.py Normal file
View File

@ -0,0 +1,748 @@
import json
import sys
import os
import io
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
import random
import re
import uuid
import urllib.request
import urllib3
import logging
from pathlib import Path
from urllib.parse import urlparse
from PIL import Image
import asyncio
import httpx
from typing import Dict, Any, List, AsyncGenerator, Optional
from settings import settings
from utils import get_new_image_url, is_valid_image_url
# 设置日志记录器
logger = logging.getLogger("midjourney_service")
# 禁用不安全请求警告
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
class MidjourneyService:
def __init__(self):
"""初始化MidjourneyService"""
self.API_URL = settings.midjourney_api_url
self.APPLICATION_ID = settings.midjourney_application_id
self.DATA_ID = settings.midjourney_data_id
self.DATA_VERSION = settings.midjourney_data_version
self.SESSION_ID = settings.midjourney_session_id
# 设置代理
self.proxies = {}
if settings.http_proxy and settings.http_proxy.strip():
self.proxies['http'] = settings.http_proxy
if settings.https_proxy and settings.https_proxy.strip():
self.proxies['https'] = settings.https_proxy
# 确保保存目录存在
os.makedirs(settings.save_dir, exist_ok=True)
@staticmethod
def first_where(array, key, value=None):
"""在数组中找到第一个匹配条件的项"""
for item in array:
if callable(key) and key(item):
return item
if isinstance(key, str) and item[key].startswith(value):
return item
return None
async def initialize_client(self, oauth_token, channel_id):
"""初始化Discord客户端会话"""
client = requests.Session()
retry = Retry(total=3, backoff_factor=0.5)
adapter = HTTPAdapter(max_retries=retry)
client.mount('http://', adapter)
client.mount('https://', adapter)
client.headers.update({
'Authorization': oauth_token
})
if self.proxies:
client.proxies.update(self.proxies)
try:
response = client.get(f'{self.API_URL}/channels/{channel_id}')
data = response.json()
guild_id = data['guild_id']
response = client.get(f'{self.API_URL}/users/@me')
data = response.json()
user_id = data['id']
return client, guild_id, user_id
except Exception as e:
logger.error(f"初始化Discord客户端失败: {str(e)}")
raise Exception(f"初始化Discord客户端失败: {str(e)}")
async def imagine(self, client, guild_id, channel_id, prompt, seed=None):
"""发送imagine请求到Discord"""
params = {
'type': 2,
'application_id': self.APPLICATION_ID,
'guild_id': guild_id,
'channel_id': channel_id,
'session_id': self.SESSION_ID,
'data': {
'version': self.DATA_VERSION,
'id': self.DATA_ID,
'name': 'imagine',
'type': 1,
'options': [{
'type': 3,
'name': 'prompt',
'value': prompt
}],
'application_command': {
'id': self.DATA_ID,
'application_id': self.APPLICATION_ID,
'version': self.DATA_VERSION,
'default_member_permissions': None,
'type': 1,
'nsfw': False,
'name': 'imagine',
'description': 'Create images with Midjourney',
'dm_permission': True,
'options': [{
'type': 3,
'name': 'prompt',
'description': 'The prompt to imagine',
'required': True
}]
},
'attachments': []
}
}
try:
client.post(f'{self.API_URL}/interactions', json=params)
await asyncio.sleep(20)
imagine_message = None
count = 0
last_progress = 0
while count < settings.max_polling_attempts:
progress_msg = await self.get_progress_message(client, channel_id, prompt, seed)
if progress_msg:
content = progress_msg.get('content', '')
progress_match = re.search(r'(?:\(|)?(\d+)%(?:\)|)?', content)
if progress_match:
progress_value = int(progress_match.group(1))
if progress_value > last_progress:
last_progress = progress_value
logger.info(f"生成进度: {progress_value}%")
yield {
"status": "progress",
"progress": progress_value,
"message_id": progress_msg.get('id'),
"content": content
}
imagine_message = await self.get_imagine(client, channel_id, prompt, count, seed)
if imagine_message:
break
await asyncio.sleep(settings.polling_interval)
count += 1
if count >= settings.max_polling_attempts:
logger.error(f"轮询超过最大尝试次数: {settings.max_polling_attempts}")
yield {"status": "error", "message": "获取结果超时,超过最大轮询次数"}
return
if not imagine_message:
logger.error("轮询结束但没有获取到有效结果")
yield {"status": "error", "message": "没有获取到有效结果"}
return
logger.info(f"成功获取图像结果消息ID: {imagine_message.get('id')}")
yield {
"status": "success",
"message_id": imagine_message.get('id'),
"content": imagine_message.get('content', ''),
"images": self.extract_image_urls(imagine_message)
}
except Exception as e:
logger.error(f"发送Imagine请求失败: {str(e)}")
yield {"status": "error", "message": f"发送Imagine请求失败: {str(e)}"}
async def get_imagine(self, client, channel_id, prompt, count=0, seed=None):
"""获取生成图像的消息"""
try:
response = client.get(f'{self.API_URL}/channels/{channel_id}/messages?limit=10')
data = response.json()
def criteria(item):
content = item.get('content', '')
if seed is not None and f"--seed {seed}" in content:
progress_match = re.search(r'(?:\(|)?(\d+)%(?:\)|)?', content)
if progress_match:
progress_value = int(progress_match.group(1))
logger.info(f"任务进度: {progress_value}%")
if "%" in content:
return False
if "%" in content:
return False
if seed is not None:
seed_pattern = f"--seed {seed}"
if seed_pattern not in content:
return False
else:
if 'attachments' in item and len(item.get('attachments', [])) > 0:
return True
return False
return self.first_where(data, criteria)
except Exception as e:
logger.error(f"获取通道消息失败: {str(e)}")
return None
async def get_progress_message(self, client, channel_id, prompt, seed=None):
"""获取包含进度信息的消息"""
try:
response = client.get(f'{self.API_URL}/channels/{channel_id}/messages?limit=10')
data = response.json()
for item in data:
content = item.get('content', '')
if seed is not None and f"--seed {seed}" in content and "%" in content:
progress_match = re.search(r'(?:\(|)?(\d+)%(?:\)|)?', content)
if progress_match:
return item
return None
except Exception as e:
logger.error(f"获取进度消息失败: {str(e)}")
return None
def extract_image_urls(self, message):
"""从消息中提取图像URL"""
image_urls = []
if 'attachments' in message and message['attachments']:
for attachment in message['attachments']:
if 'url' in attachment:
image_urls.append(attachment['url'])
return image_urls
async def split_image(self, image_url):
"""将一张大图切割成四张子图保存到本地并返回URLs"""
try:
response = requests.get(image_url, proxies=self.proxies if self.proxies else None, timeout=30)
if response.status_code != 200:
logger.error(f"下载图像失败,状态码: {response.status_code}")
return None
image_data = response.content
img = Image.open(io.BytesIO(image_data))
width, height = img.size
original_format = img.format
if not original_format:
parsed_url = urlparse(image_url)
original_format = os.path.splitext(parsed_url.path)[1][1:].upper()
if not original_format:
original_format = 'PNG'
if width < 500 or height < 500:
logger.error(f"图像尺寸不符合预期: {width}x{height}, 应该接近2048x2048")
return None
half_width = width // 2
half_height = height // 2
quadrants = [
img.crop((0, 0, half_width, half_height)),
img.crop((half_width, 0, width, half_height)),
img.crop((0, half_height, half_width, height)),
img.crop((half_width, half_height, width, height))
]
image_id = uuid.uuid4().hex[:10]
save_dir = os.path.abspath(settings.save_dir)
os.makedirs(save_dir, exist_ok=True)
image_urls = []
for i, quadrant in enumerate(quadrants, 1):
try:
filename = f"split_{image_id}_{i}.{original_format.lower()}"
file_path = os.path.join(save_dir, filename)
save_params = {"format": original_format}
if original_format in ['PNG', 'JPEG', 'JPG']:
save_params["optimize"] = True
if original_format in ['JPEG', 'JPG']:
save_params["quality"] = 95
quadrant.save(file_path, **save_params)
if not os.path.exists(file_path):
raise Exception(f"文件保存失败: {file_path}")
file_size = os.path.getsize(file_path)
if file_size == 0:
raise Exception(f"保存的文件大小为0: {file_path}")
image_url = f"{settings.download_url}/{filename}"
image_urls.append(image_url)
logger.info(f"成功保存分割图片 {i}/4: {filename}")
except Exception as e:
logger.error(f"保存分割图片 {i}/4 失败: {str(e)}")
for url in image_urls:
try:
file_path = os.path.join(save_dir, os.path.basename(url))
if os.path.exists(file_path):
os.remove(file_path)
except Exception as del_e:
logger.error(f"删除失败的图片文件时出错: {str(del_e)}")
return None
if len(image_urls) != 4:
logger.error(f"分割图片数量不正确: 期望4张实际{len(image_urls)}")
return None
logger.info("成功完成图片分割生成4张子图")
return image_urls
except Exception as e:
logger.error(f"分割图像失败: {str(e)}")
return None
async def split_images(self, image_urls: List[str], config: Optional[Dict] = None):
"""批量处理多个图像URL"""
if not config:
config = {}
total = len(image_urls)
success_count = 0
for i, image_url in enumerate(image_urls, 1):
try:
if not is_valid_image_url(image_url):
yield {
"status": "error",
"index": i,
"total": total,
"success_count": success_count,
"message": "无效的图片URL"
}
continue
split_urls = await self.split_image(image_url)
if split_urls and len(split_urls) == 4:
success_count += 1
yield {
"status": "success",
"index": i,
"total": total,
"success_count": success_count,
"images": split_urls
}
else:
yield {
"status": "error",
"index": i,
"total": total,
"success_count": success_count,
"message": "分割图片失败"
}
except Exception as e:
yield {
"status": "error",
"index": i,
"total": total,
"success_count": success_count,
"message": str(e)
}
async def generate_image(self, prompt, config=None):
"""生成图像并以流式方式返回结果"""
if not config:
config = {}
oauth_token = config.get('oauth_token', settings.midjourney_oauth_token)
channel_id = config.get('channel_id', settings.midjourney_channel_id)
if not oauth_token or not channel_id:
yield {
"status": "error",
"message": "缺少Discord配置",
"success_count": 0
}
return
try:
client, guild_id, user_id = await self.initialize_client(oauth_token, channel_id)
except Exception as e:
yield {
"status": "error",
"message": str(e),
"success_count": 0
}
return
prompt = prompt.strip()
options = settings.midjourney_default_options.copy()
if 'options' in config:
options.update(config.get('options', {}))
if "seed" not in options or not options.get('seed'):
options['seed'] = random.randint(0, 4294967295)
if 'image_weight' in config and isinstance(config['image_weight'], (int, float)):
iw = max(0.1, min(config['image_weight'], 3))
options['iw'] = iw
seed = options.pop('seed', None)
parameter = ""
no_value_key = ['relax', 'fast', 'turbo', 'tile']
for key, value in options.items():
if key in no_value_key:
parameter += f" --{key}"
else:
parameter += f" --{key} {value}"
if seed:
parameter += f" --seed {seed}"
logger.info(f"使用的seed值: {seed}")
if 'image_urls' in config and config['image_urls']:
image_urls = config['image_urls']
if not isinstance(image_urls, list):
image_urls = [image_urls]
new_image_urls = []
for image_url in image_urls:
if is_valid_image_url(image_url):
try:
new_url = get_new_image_url(image_url)
if new_url:
new_image_urls.append(new_url)
except Exception as e:
logger.warning(f"转换图片URL失败: {str(e)}")
if new_image_urls:
prompt = " ".join(new_image_urls) + " " + prompt
if 'characters' in config and config['characters']:
char_urls = []
characters = config['characters'] if isinstance(config['characters'], list) else [config['characters']]
for char_url in characters:
if is_valid_image_url(char_url):
try:
new_url = get_new_image_url(char_url)
if new_url:
char_urls.append(new_url)
except Exception as e:
logger.warning(f"转换角色图片URL失败: {str(e)}")
if char_urls:
prompt = prompt + " --cref " + " ".join(char_urls)
if 'styles' in config and config['styles']:
style_urls = []
styles = config['styles'] if isinstance(config['styles'], list) else [config['styles']]
for style_url in styles:
if is_valid_image_url(style_url):
try:
new_url = get_new_image_url(style_url)
if new_url:
style_urls.append(new_url)
except Exception as e:
logger.warning(f"转换风格图片URL失败: {str(e)}")
if style_urls:
prompt = prompt + " --sref " + " ".join(style_urls)
prompt = prompt + parameter
success_count = 0
async for result in self.imagine(client, guild_id, channel_id, prompt, seed):
if result.get("status") == "progress":
yield {
"status": "progress",
"progress": result.get("progress", 0),
"seed": seed
}
elif result.get("status") == "success":
success_count += 1
response = {
"status": "success",
"success_count": success_count,
"image_urls": result.get("images", [])
}
if config.get("split_image", True) and result.get("images"):
try:
orig_image_url = result["images"][0]
split_urls = await self.split_image(orig_image_url)
if split_urls:
response["image_urls"] = split_urls
except Exception as e:
logger.error(f"分割图像失败: {str(e)}")
response["status"] = "error"
response["message"] = f"分割图像失败: {str(e)}"
yield response
else:
yield {
"status": "error",
"message": result.get("message", "未知错误"),
"success_count": success_count
}
async def ve_submit_imagine(self, prompt: str, base64_array: List[str] = None,
notify_hook: str = "", state: str = "",
bot_type: str = "MID_JOURNEY") -> Dict[str, Any]:
"""
提交 imagine 任务到 VectorEngine API
Args:
prompt: 提示词
base64_array: base64 编码的图片数组
notify_hook: 回调地址
state: 自定义状态
bot_type: 机器人类型默认 MID_JOURNEY
Returns:
包含任务ID的响应
"""
if base64_array is None:
base64_array = []
payload = {
"base64Array": base64_array,
"notifyHook": notify_hook,
"prompt": prompt,
"state": state,
"botType": bot_type
}
headers = {
"Authorization": f"Bearer {settings.vectorengine_token}",
"Content-Type": "application/json"
}
try:
async with httpx.AsyncClient(timeout=60.0) as client:
response = await client.post(
f"{settings.vectorengine_api_url}/mj/submit/imagine",
json=payload,
headers=headers
)
if response.status_code != 200:
logger.error(f"VectorEngine API 请求失败: {response.status_code} - {response.text}")
return {"success": False, "message": f"API请求失败: {response.status_code}"}
result = response.json()
logger.info(f"VectorEngine 提交任务成功: {result}")
return result
except Exception as e:
logger.error(f"VectorEngine 提交任务异常: {str(e)}")
return {"success": False, "message": str(e)}
async def ve_get_task_status(self, task_id: str) -> Dict[str, Any]:
"""
获取 VectorEngine 任务状态
Args:
task_id: 任务ID
Returns:
任务状态信息
"""
headers = {
"Authorization": f"Bearer {settings.vectorengine_token}",
"Content-Type": "application/json"
}
try:
async with httpx.AsyncClient(timeout=60.0) as client:
response = await client.get(
f"{settings.vectorengine_api_url}/mj/task/{task_id}/fetch",
headers=headers
)
if response.status_code != 200:
logger.error(f"VectorEngine 获取任务状态失败: {response.status_code} - {response.text}")
return {"success": False, "message": f"获取任务状态失败: {response.status_code}"}
result = response.json()
return result
except Exception as e:
logger.error(f"VectorEngine 获取任务状态异常: {str(e)}")
return {"success": False, "message": str(e)}
async def ve_generate_image(self, prompt: str, config: Dict = None) -> AsyncGenerator[Dict, None]:
"""
使用 VectorEngine API 生成图像并以流式方式返回结果
Args:
prompt: 提示词
config: 配置参数可包含:
- base64_array: base64 编码的图片数组
- notify_hook: 回调地址
- state: 自定义状态
- bot_type: 机器人类型
- split_image: 是否分割图片 (默认 True)
Yields:
生成状态和结果
"""
if not config:
config = {}
if not settings.vectorengine_token:
yield {
"status": "error",
"message": "未配置 VectorEngine Token",
"success_count": 0
}
return
# 提交任务
submit_result = await self.ve_submit_imagine(
prompt=prompt,
base64_array=config.get("base64_array", []),
notify_hook=config.get("notify_hook", ""),
state=config.get("state", ""),
bot_type=config.get("bot_type", "MID_JOURNEY")
)
if not submit_result.get("success", True) and "result" not in submit_result:
yield {
"status": "error",
"message": submit_result.get("message", "提交任务失败"),
"success_count": 0
}
return
# 获取任务ID
task_id = submit_result.get("result")
if not task_id:
yield {
"status": "error",
"message": "未获取到任务ID",
"success_count": 0
}
return
logger.info(f"VectorEngine 任务已提交任务ID: {task_id}")
yield {
"status": "submitted",
"task_id": task_id
}
# 轮询任务状态
polling_count = 0
last_progress = 0
while polling_count < settings.vectorengine_max_polling_attempts:
await asyncio.sleep(settings.vectorengine_polling_interval)
polling_count += 1
task_status = await self.ve_get_task_status(task_id)
if not task_status:
continue
status = task_status.get("status")
progress = task_status.get("progress", 0)
# 解析进度值(可能是数字或带%的字符串如 '0%'
try:
if isinstance(progress, str):
progress_val = int(progress.strip().rstrip('%'))
else:
progress_val = int(progress)
except (ValueError, TypeError):
progress_val = 0
# 发送进度更新
if progress_val > last_progress:
last_progress = progress_val
yield {
"status": "progress",
"progress": last_progress,
"task_id": task_id
}
# 检查任务完成状态
if status == "SUCCESS":
image_url = task_status.get("imageUrl")
image_urls = [image_url] if image_url else []
# 检查是否有多个图片URL对于网格图
if "buttons" in task_status:
# 可能包含 U1, U2, U3, U4 按钮对应的图片
pass
response = {
"status": "success",
"success_count": 1,
"task_id": task_id,
"image_urls": image_urls
}
# 如果配置了分割图片,则处理分割
if config.get("split_image", True) and image_url:
try:
split_urls = await self.split_image(image_url)
if split_urls:
response["image_urls"] = split_urls
logger.info(f"图片分割成功,生成 {len(split_urls)} 张子图")
except Exception as e:
logger.error(f"分割图像失败: {str(e)}")
yield response
return
elif status == "FAILURE" or status == "FAILED":
yield {
"status": "error",
"message": task_status.get("failReason", "任务失败"),
"task_id": task_id,
"success_count": 0
}
return
elif status == "CANCELLED":
yield {
"status": "error",
"message": "任务已取消",
"task_id": task_id,
"success_count": 0
}
return
# 超时
yield {
"status": "error",
"message": "任务超时",
"task_id": task_id,
"success_count": 0
}

View File

@ -0,0 +1,72 @@
from pydantic_settings import BaseSettings
from typing import Optional, Dict
from functools import lru_cache
class Settings(BaseSettings):
# Japi Server 配置
host: str = "0.0.0.0"
port: int = 8113
debug: bool = False
# API路由配置
router_prefix: str = "/jmidjourney"
generate_route: str = "/generate" # 生成图片的路由
ve_generate_route: str = "/ve/generate" # VectorEngine 生成图片的路由
api_name: str = "jmidjourney" # 默认API名称
upload_url: str = "http://images.jingrow.com:8080/api/v1/image"
# VectorEngine API 配置
vectorengine_api_url: str = "https://api.vectorengine.ai"
vectorengine_token: Optional[str] = None
vectorengine_max_polling_attempts: int = 120 # 最大轮询次数
vectorengine_polling_interval: int = 3 # 轮询间隔(秒)
# 图片保存配置
save_dir: str = "../jfile/files"
# Japi 静态资源下载URL
download_url: str = "https://api.jingrow.com/files"
# Jingrow Jcloud API 配置
jingrow_api_url: str = "https://console.jingrow.com"
jingrow_api_key: Optional[str] = None
jingrow_api_secret: Optional[str] = None
# Discord Midjourney配置
midjourney_api_url: str = "https://discord.com/api/v9"
midjourney_application_id: str = "936929561302675456"
midjourney_data_id: str = "938956540159881230"
midjourney_data_version: str = "1237876415471554623"
midjourney_session_id: str = "a64ede0f3ce497d949e2f6f195c19029"
midjourney_channel_id: str = "1259838588510670941"
midjourney_oauth_token: str = "MTA4NzQ0MDY0MTU5MzcxNjc0Ng.GVDauj.6Cwr5EpXOfN9FpQU0-VfteR56XQOwLLUGYovG0"
midjourney_suffix: str = "mj" # 图片文件名的后缀
# 代理配置
http_proxy: Optional[str] = "http://127.0.0.1:1080" # 默认HTTP代理
https_proxy: Optional[str] = "http://127.0.0.1:1080" # 默认HTTPS代理
# Midjourney默认选项
midjourney_default_options: Dict = {
"ar": "1:1",
"v": "6.1",
"quality": "1"
}
# 图像设置
add_title_to_image_name: bool = False # 是否将标题添加到图片名称中
# 超时设置(秒)
request_timeout: int = 30
max_polling_attempts: int = 60
polling_interval: int = 3
class Config:
env_file = ".env"
@lru_cache()
def get_settings() -> Settings:
return Settings()
# 创建全局配置实例
settings = get_settings()

233
apps/jmidjourney/utils.py Normal file
View File

@ -0,0 +1,233 @@
import aiohttp
from functools import wraps
from fastapi import HTTPException
import os
from typing import Callable, Any, Dict, Optional, Tuple, List
from settings import settings
from fastapi.responses import StreamingResponse
import json
import requests
import io
import re
from pathlib import Path
from urllib.parse import urlparse
from PIL import Image
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
"""验证API密钥和团队余额"""
try:
async with aiohttp.ClientSession() as session:
async with session.post(
f"{settings.jingrow_api_url}/api/action/jcloud.api.account.verify_api_credentials_and_balance",
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
json={"api_key": api_key, "api_secret": api_secret, "api_name": api_name}
) as response:
if response.status != 200:
raise HTTPException(status_code=500, detail="验证服务暂时不可用")
result = await response.json()
if "message" in result and isinstance(result["message"], dict):
result = result["message"]
if not result.get("success"):
raise HTTPException(status_code=401, detail=result.get("message", "验证失败"))
return result
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
"""从Jingrow平台扣除API使用费"""
try:
async with aiohttp.ClientSession() as session:
async with session.post(
f"{settings.jingrow_api_url}/api/action/jcloud.api.account.deduct_api_usage_fee",
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
json={
"api_key": api_key,
"api_secret": api_secret,
"api_name": api_name,
"usage_count": usage_count
}
) as response:
if response.status != 200:
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
result = await response.json()
if "message" in result and isinstance(result["message"], dict):
result = result["message"]
return result
except HTTPException:
raise
except Exception as e:
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
def get_token_from_request(request) -> str:
"""从请求中获取访问令牌"""
if not request:
raise HTTPException(status_code=400, detail="无法获取请求信息")
auth_header = request.headers.get("Authorization", "")
if not auth_header or not auth_header.startswith("token "):
raise HTTPException(status_code=401, detail="无效的Authorization头格式")
token = auth_header[6:]
if ":" not in token:
raise HTTPException(status_code=401, detail="无效的令牌格式")
return token
def jingrow_api_verify_and_billing(api_name: str):
"""Jingrow API 验证装饰器(带余额检查和扣费)"""
def decorator(func: Callable) -> Callable:
@wraps(func)
async def wrapper(*args, **kwargs):
try:
request = kwargs.get('request')
if not request:
raise HTTPException(status_code=400, detail="无法获取请求信息")
token = get_token_from_request(request)
api_key, api_secret = token.split(":", 1)
verify_result = await verify_api_credentials_and_balance(api_key, api_secret, api_name)
if not verify_result.get("success"):
raise HTTPException(status_code=401, detail=verify_result.get("message", "验证失败"))
result = await func(*args, **kwargs)
usage_count = 1
try:
body_data = await request.json()
if isinstance(body_data, dict):
for key in ["items", "urls", "images", "files"]:
if key in body_data and isinstance(body_data[key], list):
usage_count = len(body_data[key])
break
except Exception:
pass
if isinstance(result, StreamingResponse):
original_generator = result.body_iterator
success_count = 0
async def wrapped_generator():
nonlocal success_count
async for chunk in original_generator:
try:
data = json.loads(chunk)
if isinstance(data, dict) and data.get("status") == "success":
success_count += 1
except:
pass
yield chunk
if success_count > 0:
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
return StreamingResponse(
wrapped_generator(),
media_type=result.media_type,
headers=result.headers
)
if isinstance(result, dict) and result.get("success") is True:
actual_usage_count = result.get("success_count", usage_count)
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
return result
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
return result
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"API验证过程发生异常: {str(e)}")
return wrapper
return decorator
def is_valid_image_url(url: str) -> bool:
if not url or not isinstance(url, str):
return False
try:
parsed = urlparse(url)
if not parsed.scheme or not parsed.netloc:
return False
# 检查文件扩展名
path = parsed.path.lower()
valid_extensions = ['.jpg', '.jpeg', '.png', '.webp', '.gif']
return any(path.endswith(ext) for ext in valid_extensions)
except:
return False
def get_new_image_url(image_url: str) -> str:
try:
# 使用settings中的upload_url
upload_url = settings.upload_url
if not upload_url:
raise HTTPException(status_code=500, detail="未配置上传URL")
# 下载图片
response = requests.get(image_url, verify=False, timeout=30)
if response.status_code != 200:
raise HTTPException(status_code=400, detail=f"无法下载图片: HTTP {response.status_code}")
image_data = response.content
# 解析文件名和扩展名
parsed_url = urlparse(image_url)
file_name = Path(parsed_url.path).name
file_name = sanitize_filename(file_name)
file_ext = Path(file_name).suffix.lower()
# 如果图片是webp格式转换为png格式
if file_ext == '.webp':
image = Image.open(io.BytesIO(image_data))
png_buffer = io.BytesIO()
image.save(png_buffer, format='PNG')
image_data = png_buffer.getvalue()
file_name = file_name.replace('.webp', '.png')
# 准备文件上传
files = {"file": (file_name, image_data)}
# 上传图片
upload_response = requests.post(upload_url, files=files, verify=False, timeout=30)
if upload_response.status_code != 200:
error_message = f"图片URL转换失败: 状态码 {upload_response.status_code}, 响应: {upload_response.text[:200]}"
raise HTTPException(status_code=500, detail=error_message)
result = upload_response.json()
new_url = result.get("url")
if not new_url:
raise HTTPException(status_code=500, detail="上传成功但未返回URL")
return new_url
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"图片URL转换异常: {str(e)}")
def sanitize_filename(filename: str) -> str:
# 移除路径分隔符和空字符
filename = re.sub(r'[\\/:*?"<>|\x00-\x1f]', '', filename)
# 移除首尾空白字符
filename = filename.strip()
# 如果文件名为空,使用默认名称
if not filename:
filename = "untitled"
# 限制文件名长度
if len(filename) > 255:
name, ext = os.path.splitext(filename)
filename = name[:255-len(ext)] + ext
return filename