优化midjourney微服务
This commit is contained in:
parent
87261dad9f
commit
0e5b27e422
@ -13,16 +13,7 @@ service = MidjourneyService()
|
|||||||
@router.post(settings.generate_route)
|
@router.post(settings.generate_route)
|
||||||
@jingrow_api_verify_and_billing(api_name=settings.api_name)
|
@jingrow_api_verify_and_billing(api_name=settings.api_name)
|
||||||
async def generate_image(data: dict, request: Request):
|
async def generate_image(data: dict, request: Request):
|
||||||
"""
|
|
||||||
根据文本提示生成图像
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data: 包含文本提示和配置参数的字典
|
|
||||||
request: FastAPI 请求对象
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
生成的图像内容
|
|
||||||
"""
|
|
||||||
if "prompt" not in data:
|
if "prompt" not in data:
|
||||||
raise HTTPException(status_code=400, detail="缺少prompt参数")
|
raise HTTPException(status_code=400, detail="缺少prompt参数")
|
||||||
|
|
||||||
@ -38,32 +29,3 @@ async def generate_image(data: dict, request: Request):
|
|||||||
media_type="application/x-ndjson",
|
media_type="application/x-ndjson",
|
||||||
headers={"X-Content-Type-Options": "nosniff"}
|
headers={"X-Content-Type-Options": "nosniff"}
|
||||||
)
|
)
|
||||||
|
|
||||||
@router.post(settings.batch_route)
|
|
||||||
@jingrow_api_verify_and_billing(api_name=settings.api_name)
|
|
||||||
async def batch_process_images(data: dict, request: Request):
|
|
||||||
"""
|
|
||||||
批量处理多个图像URL,将每张图片分割成4张并保存
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data: 包含图片URLs列表的字典
|
|
||||||
request: FastAPI 请求对象
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
处理结果的流式响应
|
|
||||||
"""
|
|
||||||
if "image_urls" not in data or not isinstance(data["image_urls"], list):
|
|
||||||
raise HTTPException(status_code=400, detail="缺少有效的image_urls参数")
|
|
||||||
|
|
||||||
image_urls: List[str] = data["image_urls"]
|
|
||||||
config = data.get("config", {})
|
|
||||||
|
|
||||||
async def process() -> AsyncGenerator[str, None]:
|
|
||||||
async for result in service.process_batch(image_urls, config):
|
|
||||||
yield json.dumps(result, ensure_ascii=False) + "\n"
|
|
||||||
|
|
||||||
return StreamingResponse(
|
|
||||||
process(),
|
|
||||||
media_type="application/x-ndjson",
|
|
||||||
headers={"X-Content-Type-Options": "nosniff"}
|
|
||||||
)
|
|
||||||
@ -3,6 +3,8 @@ import sys
|
|||||||
import os
|
import os
|
||||||
import io
|
import io
|
||||||
import requests
|
import requests
|
||||||
|
from requests.adapters import HTTPAdapter
|
||||||
|
from urllib3.util.retry import Retry
|
||||||
import time
|
import time
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
@ -60,6 +62,13 @@ class MidjourneyService:
|
|||||||
async def initialize_client(self, oauth_token, channel_id):
|
async def initialize_client(self, oauth_token, channel_id):
|
||||||
"""初始化Discord客户端会话"""
|
"""初始化Discord客户端会话"""
|
||||||
client = requests.Session()
|
client = requests.Session()
|
||||||
|
|
||||||
|
# 添加重试机制
|
||||||
|
retry = Retry(total=3, backoff_factor=0.5)
|
||||||
|
adapter = HTTPAdapter(max_retries=retry)
|
||||||
|
client.mount('http://', adapter)
|
||||||
|
client.mount('https://', adapter)
|
||||||
|
|
||||||
client.headers.update({
|
client.headers.update({
|
||||||
'Authorization': oauth_token
|
'Authorization': oauth_token
|
||||||
})
|
})
|
||||||
@ -128,8 +137,8 @@ class MidjourneyService:
|
|||||||
# 发送请求
|
# 发送请求
|
||||||
r = client.post(f'{self.API_URL}/interactions', json=params)
|
r = client.post(f'{self.API_URL}/interactions', json=params)
|
||||||
# 初始等待时间从5秒延长到30秒,给Discord足够的时间开始处理请求
|
# 初始等待时间从5秒延长到30秒,给Discord足够的时间开始处理请求
|
||||||
print(f"[生成] 已发送请求,等待30秒后开始轮询...")
|
print(f"[生成] 已发送请求,等待20秒后开始轮询...")
|
||||||
await asyncio.sleep(30)
|
await asyncio.sleep(20)
|
||||||
|
|
||||||
# 轮询获取结果
|
# 轮询获取结果
|
||||||
imagine_message = None
|
imagine_message = None
|
||||||
@ -398,25 +407,22 @@ class MidjourneyService:
|
|||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def process_batch(self, image_urls: List[str], config: Optional[Dict] = None):
|
async def split_images(self, image_urls: List[str], config: Optional[Dict] = None):
|
||||||
"""批量处理多个图像URL"""
|
"""批量处理多个图像URL"""
|
||||||
if not config:
|
if not config:
|
||||||
config = {}
|
config = {}
|
||||||
|
|
||||||
total = len(image_urls)
|
total = len(image_urls)
|
||||||
success_count = 0
|
success_count = 0
|
||||||
error_count = 0
|
|
||||||
|
|
||||||
for i, image_url in enumerate(image_urls, 1):
|
for i, image_url in enumerate(image_urls, 1):
|
||||||
try:
|
try:
|
||||||
if not is_valid_image_url(image_url):
|
if not is_valid_image_url(image_url):
|
||||||
error_count += 1
|
|
||||||
response = {
|
response = {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"index": i,
|
"index": i,
|
||||||
"total": total,
|
"total": total,
|
||||||
"success_count": success_count,
|
"success_count": success_count,
|
||||||
"error_count": error_count,
|
|
||||||
"message": "无效的图片URL"
|
"message": "无效的图片URL"
|
||||||
}
|
}
|
||||||
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
||||||
@ -431,31 +437,26 @@ class MidjourneyService:
|
|||||||
"index": i,
|
"index": i,
|
||||||
"total": total,
|
"total": total,
|
||||||
"success_count": success_count,
|
"success_count": success_count,
|
||||||
"error_count": error_count,
|
|
||||||
"images": split_urls
|
"images": split_urls
|
||||||
}
|
}
|
||||||
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
||||||
yield response
|
yield response
|
||||||
else:
|
else:
|
||||||
error_count += 1
|
|
||||||
response = {
|
response = {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"index": i,
|
"index": i,
|
||||||
"total": total,
|
"total": total,
|
||||||
"success_count": success_count,
|
"success_count": success_count,
|
||||||
"error_count": error_count,
|
|
||||||
"message": "分割图片失败"
|
"message": "分割图片失败"
|
||||||
}
|
}
|
||||||
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
||||||
yield response
|
yield response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_count += 1
|
|
||||||
response = {
|
response = {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"index": i,
|
"index": i,
|
||||||
"total": total,
|
"total": total,
|
||||||
"success_count": success_count,
|
"success_count": success_count,
|
||||||
"error_count": error_count,
|
|
||||||
"message": str(e)
|
"message": str(e)
|
||||||
}
|
}
|
||||||
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
||||||
@ -474,8 +475,7 @@ class MidjourneyService:
|
|||||||
response = {
|
response = {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "缺少Discord配置",
|
"message": "缺少Discord配置",
|
||||||
"success_count": 0,
|
"success_count": 0
|
||||||
"error_count": 1
|
|
||||||
}
|
}
|
||||||
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
||||||
yield response
|
yield response
|
||||||
@ -488,8 +488,7 @@ class MidjourneyService:
|
|||||||
response = {
|
response = {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": str(e),
|
"message": str(e),
|
||||||
"success_count": 0,
|
"success_count": 0
|
||||||
"error_count": 1
|
|
||||||
}
|
}
|
||||||
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
||||||
yield response
|
yield response
|
||||||
@ -524,26 +523,26 @@ class MidjourneyService:
|
|||||||
print(f"[生成] 使用的seed值: {seed}")
|
print(f"[生成] 使用的seed值: {seed}")
|
||||||
|
|
||||||
# 处理参考图像
|
# 处理参考图像
|
||||||
if 'reference_images' in config and config['reference_images']:
|
if 'image_urls' in config and config['image_urls']:
|
||||||
# 确保是列表格式
|
# 确保是列表格式
|
||||||
reference_images = config['reference_images']
|
image_urls = config['image_urls']
|
||||||
if not isinstance(reference_images, list):
|
if not isinstance(image_urls, list):
|
||||||
reference_images = [reference_images]
|
image_urls = [image_urls]
|
||||||
|
|
||||||
# 转换图片URL
|
# 转换图片URL
|
||||||
image_urls = []
|
new_image_urls = []
|
||||||
for image_url in reference_images:
|
for image_url in image_urls:
|
||||||
if is_valid_image_url(image_url):
|
if is_valid_image_url(image_url):
|
||||||
try:
|
try:
|
||||||
new_url = get_new_image_url(image_url)
|
new_url = get_new_image_url(image_url)
|
||||||
if new_url:
|
if new_url:
|
||||||
image_urls.append(new_url)
|
new_image_urls.append(new_url)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"转换图片URL失败: {str(e)}")
|
logger.warning(f"转换图片URL失败: {str(e)}")
|
||||||
|
|
||||||
# 添加到prompt前面
|
# 添加到prompt前面
|
||||||
if image_urls:
|
if new_image_urls:
|
||||||
prompt = " ".join(image_urls) + " " + prompt
|
prompt = " ".join(new_image_urls) + " " + prompt
|
||||||
|
|
||||||
# 设置图像权重
|
# 设置图像权重
|
||||||
if 'image_weight' in config and isinstance(config['image_weight'], (int, float)):
|
if 'image_weight' in config and isinstance(config['image_weight'], (int, float)):
|
||||||
@ -589,7 +588,6 @@ class MidjourneyService:
|
|||||||
|
|
||||||
# 流式返回结果
|
# 流式返回结果
|
||||||
success_count = 0
|
success_count = 0
|
||||||
error_count = 0
|
|
||||||
|
|
||||||
async for result in self.imagine(client, guild_id, channel_id, prompt, seed):
|
async for result in self.imagine(client, guild_id, channel_id, prompt, seed):
|
||||||
if result.get("status") == "progress":
|
if result.get("status") == "progress":
|
||||||
@ -606,8 +604,7 @@ class MidjourneyService:
|
|||||||
response = {
|
response = {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"success_count": success_count,
|
"success_count": success_count,
|
||||||
"error_count": error_count,
|
"image_urls": result.get("images", []) # 添加image_urls字段
|
||||||
"images": result.get("images", []) # 添加images字段
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# 如果需要分割图片
|
# 如果需要分割图片
|
||||||
@ -616,22 +613,20 @@ class MidjourneyService:
|
|||||||
orig_image_url = result["images"][0]
|
orig_image_url = result["images"][0]
|
||||||
split_urls = await self.split_image(orig_image_url)
|
split_urls = await self.split_image(orig_image_url)
|
||||||
if split_urls:
|
if split_urls:
|
||||||
response["images"] = split_urls
|
response["image_urls"] = split_urls
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_count += 1
|
|
||||||
logger.error(f"分割图像失败: {str(e)}")
|
logger.error(f"分割图像失败: {str(e)}")
|
||||||
response["error_count"] = error_count
|
response["status"] = "error"
|
||||||
|
response["message"] = f"分割图像失败: {str(e)}"
|
||||||
|
|
||||||
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
||||||
yield response
|
yield response
|
||||||
else:
|
else:
|
||||||
# 错误信息保持统一格式
|
# 错误信息保持统一格式
|
||||||
error_count += 1
|
|
||||||
response = {
|
response = {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": result.get("message", "未知错误"),
|
"message": result.get("message", "未知错误"),
|
||||||
"success_count": success_count,
|
"success_count": success_count
|
||||||
"error_count": error_count
|
|
||||||
}
|
}
|
||||||
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
||||||
yield response
|
yield response
|
||||||
@ -11,7 +11,6 @@ class Settings(BaseSettings):
|
|||||||
# API路由配置
|
# API路由配置
|
||||||
router_prefix: str = "/midjourney"
|
router_prefix: str = "/midjourney"
|
||||||
generate_route: str = "/generate" # 生成图片的路由
|
generate_route: str = "/generate" # 生成图片的路由
|
||||||
batch_route: str = "/batch" # 批量处理图片的路由
|
|
||||||
api_name: str = "midjourney" # 默认API名称
|
api_name: str = "midjourney" # 默认API名称
|
||||||
|
|
||||||
upload_url: str = "http://images.jingrow.com:8080/api/v1/image"
|
upload_url: str = "http://images.jingrow.com:8080/api/v1/image"
|
||||||
|
|||||||
@ -137,7 +137,7 @@ def jingrow_api_verify_and_billing(api_name: str):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(result, dict) and result.get("success") is True:
|
if isinstance(result, dict) and result.get("success") is True:
|
||||||
actual_usage_count = result.get("successful_count", usage_count)
|
actual_usage_count = result.get("success_count", usage_count)
|
||||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@ -152,14 +152,7 @@ def jingrow_api_verify_and_billing(api_name: str):
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
def is_valid_image_url(url: str) -> bool:
|
def is_valid_image_url(url: str) -> bool:
|
||||||
"""验证图片URL是否有效
|
|
||||||
|
|
||||||
Args:
|
|
||||||
url: 要验证的URL
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: URL是否有效
|
|
||||||
"""
|
|
||||||
if not url or not isinstance(url, str):
|
if not url or not isinstance(url, str):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -177,17 +170,7 @@ def is_valid_image_url(url: str) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def get_new_image_url(image_url: str) -> str:
|
def get_new_image_url(image_url: str) -> str:
|
||||||
"""将图片URL转换为新的存储URL
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image_url: 原始图片URL
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: 新的图片URL
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException: 当图片处理失败时抛出
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# 使用settings中的upload_url
|
# 使用settings中的upload_url
|
||||||
upload_url = settings.upload_url
|
upload_url = settings.upload_url
|
||||||
@ -235,14 +218,7 @@ def get_new_image_url(image_url: str) -> str:
|
|||||||
raise HTTPException(status_code=500, detail=f"图片URL转换异常: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"图片URL转换异常: {str(e)}")
|
||||||
|
|
||||||
def sanitize_filename(filename: str) -> str:
|
def sanitize_filename(filename: str) -> str:
|
||||||
"""清理文件名,移除非法字符
|
|
||||||
|
|
||||||
Args:
|
|
||||||
filename: 原始文件名
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: 清理后的文件名
|
|
||||||
"""
|
|
||||||
# 移除路径分隔符和空字符
|
# 移除路径分隔符和空字符
|
||||||
filename = re.sub(r'[\\/:*?"<>|\x00-\x1f]', '', filename)
|
filename = re.sub(r'[\\/:*?"<>|\x00-\x1f]', '', filename)
|
||||||
# 移除首尾空白字符
|
# 移除首尾空白字符
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user