优化midjourney微服务

This commit is contained in:
jingrow 2025-05-20 16:04:57 +08:00
parent 87261dad9f
commit 0e5b27e422
4 changed files with 33 additions and 101 deletions

View File

@ -13,16 +13,7 @@ service = MidjourneyService()
@router.post(settings.generate_route) @router.post(settings.generate_route)
@jingrow_api_verify_and_billing(api_name=settings.api_name) @jingrow_api_verify_and_billing(api_name=settings.api_name)
async def generate_image(data: dict, request: Request): async def generate_image(data: dict, request: Request):
"""
根据文本提示生成图像
Args:
data: 包含文本提示和配置参数的字典
request: FastAPI 请求对象
Returns:
生成的图像内容
"""
if "prompt" not in data: if "prompt" not in data:
raise HTTPException(status_code=400, detail="缺少prompt参数") raise HTTPException(status_code=400, detail="缺少prompt参数")
@ -38,32 +29,3 @@ async def generate_image(data: dict, request: Request):
media_type="application/x-ndjson", media_type="application/x-ndjson",
headers={"X-Content-Type-Options": "nosniff"} headers={"X-Content-Type-Options": "nosniff"}
) )
@router.post(settings.batch_route)
@jingrow_api_verify_and_billing(api_name=settings.api_name)
async def batch_process_images(data: dict, request: Request):
"""
批量处理多个图像URL将每张图片分割成4张并保存
Args:
data: 包含图片URLs列表的字典
request: FastAPI 请求对象
Returns:
处理结果的流式响应
"""
if "image_urls" not in data or not isinstance(data["image_urls"], list):
raise HTTPException(status_code=400, detail="缺少有效的image_urls参数")
image_urls: List[str] = data["image_urls"]
config = data.get("config", {})
async def process() -> AsyncGenerator[str, None]:
async for result in service.process_batch(image_urls, config):
yield json.dumps(result, ensure_ascii=False) + "\n"
return StreamingResponse(
process(),
media_type="application/x-ndjson",
headers={"X-Content-Type-Options": "nosniff"}
)

View File

@ -3,6 +3,8 @@ import sys
import os import os
import io import io
import requests import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
import time import time
import random import random
import re import re
@ -60,6 +62,13 @@ class MidjourneyService:
async def initialize_client(self, oauth_token, channel_id): async def initialize_client(self, oauth_token, channel_id):
"""初始化Discord客户端会话""" """初始化Discord客户端会话"""
client = requests.Session() client = requests.Session()
# 添加重试机制
retry = Retry(total=3, backoff_factor=0.5)
adapter = HTTPAdapter(max_retries=retry)
client.mount('http://', adapter)
client.mount('https://', adapter)
client.headers.update({ client.headers.update({
'Authorization': oauth_token 'Authorization': oauth_token
}) })
@ -128,8 +137,8 @@ class MidjourneyService:
# 发送请求 # 发送请求
r = client.post(f'{self.API_URL}/interactions', json=params) r = client.post(f'{self.API_URL}/interactions', json=params)
# 初始等待时间从5秒延长到30秒给Discord足够的时间开始处理请求 # 初始等待时间从5秒延长到30秒给Discord足够的时间开始处理请求
print(f"[生成] 已发送请求,等待30秒后开始轮询...") print(f"[生成] 已发送请求,等待20秒后开始轮询...")
await asyncio.sleep(30) await asyncio.sleep(20)
# 轮询获取结果 # 轮询获取结果
imagine_message = None imagine_message = None
@ -398,25 +407,22 @@ class MidjourneyService:
traceback.print_exc() traceback.print_exc()
return None return None
async def process_batch(self, image_urls: List[str], config: Optional[Dict] = None): async def split_images(self, image_urls: List[str], config: Optional[Dict] = None):
"""批量处理多个图像URL""" """批量处理多个图像URL"""
if not config: if not config:
config = {} config = {}
total = len(image_urls) total = len(image_urls)
success_count = 0 success_count = 0
error_count = 0
for i, image_url in enumerate(image_urls, 1): for i, image_url in enumerate(image_urls, 1):
try: try:
if not is_valid_image_url(image_url): if not is_valid_image_url(image_url):
error_count += 1
response = { response = {
"status": "error", "status": "error",
"index": i, "index": i,
"total": total, "total": total,
"success_count": success_count, "success_count": success_count,
"error_count": error_count,
"message": "无效的图片URL" "message": "无效的图片URL"
} }
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}") print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
@ -431,31 +437,26 @@ class MidjourneyService:
"index": i, "index": i,
"total": total, "total": total,
"success_count": success_count, "success_count": success_count,
"error_count": error_count,
"images": split_urls "images": split_urls
} }
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}") print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
yield response yield response
else: else:
error_count += 1
response = { response = {
"status": "error", "status": "error",
"index": i, "index": i,
"total": total, "total": total,
"success_count": success_count, "success_count": success_count,
"error_count": error_count,
"message": "分割图片失败" "message": "分割图片失败"
} }
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}") print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
yield response yield response
except Exception as e: except Exception as e:
error_count += 1
response = { response = {
"status": "error", "status": "error",
"index": i, "index": i,
"total": total, "total": total,
"success_count": success_count, "success_count": success_count,
"error_count": error_count,
"message": str(e) "message": str(e)
} }
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}") print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
@ -474,8 +475,7 @@ class MidjourneyService:
response = { response = {
"status": "error", "status": "error",
"message": "缺少Discord配置", "message": "缺少Discord配置",
"success_count": 0, "success_count": 0
"error_count": 1
} }
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}") print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
yield response yield response
@ -488,8 +488,7 @@ class MidjourneyService:
response = { response = {
"status": "error", "status": "error",
"message": str(e), "message": str(e),
"success_count": 0, "success_count": 0
"error_count": 1
} }
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}") print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
yield response yield response
@ -524,26 +523,26 @@ class MidjourneyService:
print(f"[生成] 使用的seed值: {seed}") print(f"[生成] 使用的seed值: {seed}")
# 处理参考图像 # 处理参考图像
if 'reference_images' in config and config['reference_images']: if 'image_urls' in config and config['image_urls']:
# 确保是列表格式 # 确保是列表格式
reference_images = config['reference_images'] image_urls = config['image_urls']
if not isinstance(reference_images, list): if not isinstance(image_urls, list):
reference_images = [reference_images] image_urls = [image_urls]
# 转换图片URL # 转换图片URL
image_urls = [] new_image_urls = []
for image_url in reference_images: for image_url in image_urls:
if is_valid_image_url(image_url): if is_valid_image_url(image_url):
try: try:
new_url = get_new_image_url(image_url) new_url = get_new_image_url(image_url)
if new_url: if new_url:
image_urls.append(new_url) new_image_urls.append(new_url)
except Exception as e: except Exception as e:
logger.warning(f"转换图片URL失败: {str(e)}") logger.warning(f"转换图片URL失败: {str(e)}")
# 添加到prompt前面 # 添加到prompt前面
if image_urls: if new_image_urls:
prompt = " ".join(image_urls) + " " + prompt prompt = " ".join(new_image_urls) + " " + prompt
# 设置图像权重 # 设置图像权重
if 'image_weight' in config and isinstance(config['image_weight'], (int, float)): if 'image_weight' in config and isinstance(config['image_weight'], (int, float)):
@ -589,7 +588,6 @@ class MidjourneyService:
# 流式返回结果 # 流式返回结果
success_count = 0 success_count = 0
error_count = 0
async for result in self.imagine(client, guild_id, channel_id, prompt, seed): async for result in self.imagine(client, guild_id, channel_id, prompt, seed):
if result.get("status") == "progress": if result.get("status") == "progress":
@ -606,8 +604,7 @@ class MidjourneyService:
response = { response = {
"status": "success", "status": "success",
"success_count": success_count, "success_count": success_count,
"error_count": error_count, "image_urls": result.get("images", []) # 添加image_urls字段
"images": result.get("images", []) # 添加images字段
} }
# 如果需要分割图片 # 如果需要分割图片
@ -616,22 +613,20 @@ class MidjourneyService:
orig_image_url = result["images"][0] orig_image_url = result["images"][0]
split_urls = await self.split_image(orig_image_url) split_urls = await self.split_image(orig_image_url)
if split_urls: if split_urls:
response["images"] = split_urls response["image_urls"] = split_urls
except Exception as e: except Exception as e:
error_count += 1
logger.error(f"分割图像失败: {str(e)}") logger.error(f"分割图像失败: {str(e)}")
response["error_count"] = error_count response["status"] = "error"
response["message"] = f"分割图像失败: {str(e)}"
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}") print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
yield response yield response
else: else:
# 错误信息保持统一格式 # 错误信息保持统一格式
error_count += 1
response = { response = {
"status": "error", "status": "error",
"message": result.get("message", "未知错误"), "message": result.get("message", "未知错误"),
"success_count": success_count, "success_count": success_count
"error_count": error_count
} }
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}") print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
yield response yield response

View File

@ -11,7 +11,6 @@ class Settings(BaseSettings):
# API路由配置 # API路由配置
router_prefix: str = "/midjourney" router_prefix: str = "/midjourney"
generate_route: str = "/generate" # 生成图片的路由 generate_route: str = "/generate" # 生成图片的路由
batch_route: str = "/batch" # 批量处理图片的路由
api_name: str = "midjourney" # 默认API名称 api_name: str = "midjourney" # 默认API名称
upload_url: str = "http://images.jingrow.com:8080/api/v1/image" upload_url: str = "http://images.jingrow.com:8080/api/v1/image"

View File

@ -137,7 +137,7 @@ def jingrow_api_verify_and_billing(api_name: str):
) )
if isinstance(result, dict) and result.get("success") is True: if isinstance(result, dict) and result.get("success") is True:
actual_usage_count = result.get("successful_count", usage_count) actual_usage_count = result.get("success_count", usage_count)
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count) await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
return result return result
@ -152,14 +152,7 @@ def jingrow_api_verify_and_billing(api_name: str):
return decorator return decorator
def is_valid_image_url(url: str) -> bool: def is_valid_image_url(url: str) -> bool:
"""验证图片URL是否有效
Args:
url: 要验证的URL
Returns:
bool: URL是否有效
"""
if not url or not isinstance(url, str): if not url or not isinstance(url, str):
return False return False
@ -177,17 +170,7 @@ def is_valid_image_url(url: str) -> bool:
def get_new_image_url(image_url: str) -> str: def get_new_image_url(image_url: str) -> str:
"""将图片URL转换为新的存储URL
Args:
image_url: 原始图片URL
Returns:
str: 新的图片URL
Raises:
HTTPException: 当图片处理失败时抛出
"""
try: try:
# 使用settings中的upload_url # 使用settings中的upload_url
upload_url = settings.upload_url upload_url = settings.upload_url
@ -235,14 +218,7 @@ def get_new_image_url(image_url: str) -> str:
raise HTTPException(status_code=500, detail=f"图片URL转换异常: {str(e)}") raise HTTPException(status_code=500, detail=f"图片URL转换异常: {str(e)}")
def sanitize_filename(filename: str) -> str: def sanitize_filename(filename: str) -> str:
"""清理文件名,移除非法字符
Args:
filename: 原始文件名
Returns:
str: 清理后的文件名
"""
# 移除路径分隔符和空字符 # 移除路径分隔符和空字符
filename = re.sub(r'[\\/:*?"<>|\x00-\x1f]', '', filename) filename = re.sub(r'[\\/:*?"<>|\x00-\x1f]', '', filename)
# 移除首尾空白字符 # 移除首尾空白字符