451 lines
16 KiB
Python
451 lines
16 KiB
Python
import json
|
||
import os
|
||
import io
|
||
import requests
|
||
import uuid
|
||
import urllib3
|
||
import logging
|
||
from pathlib import Path
|
||
from urllib.parse import urlparse
|
||
from PIL import Image
|
||
import asyncio
|
||
import httpx
|
||
from typing import Dict, Any, List, AsyncGenerator, Optional
|
||
from settings import settings
|
||
|
||
|
||
# Name under which this module's log records are emitted.
_SERVICE_LOGGER_NAME = "midjourney_service"
logger = logging.getLogger(_SERVICE_LOGGER_NAME)

# Process-wide: silence urllib3's InsecureRequestWarning (emitted for
# unverified HTTPS requests) so it does not flood the logs.
urllib3.disable_warnings(category=urllib3.exceptions.InsecureRequestWarning)
|
||
|
||
|
||
class MidjourneyService:
    """Client for Midjourney image generation through the VectorEngine API.

    Besides the VectorEngine task endpoints, the service can split a returned
    2x2 grid image into four quadrant images stored under ``settings.save_dir``
    and exposed through ``settings.download_url``.
    """

    # Task statuses that terminate polling with an error result.
    _FAILED_STATUSES = ("FAILURE", "FAILED")
    # "CANCEL" is accepted in addition to "CANCELLED"; presumably some
    # midjourney-proxy compatible backends use the short form — TODO confirm.
    _CANCELLED_STATUSES = ("CANCELLED", "CANCEL")

    def __init__(self):
        """Collect proxy settings and make sure the local save directory exists."""
        # Only non-blank proxy URLs are kept; the dict is passed to ``requests``.
        self.proxies = {}
        if settings.http_proxy and settings.http_proxy.strip():
            self.proxies['http'] = settings.http_proxy
        if settings.https_proxy and settings.https_proxy.strip():
            self.proxies['https'] = settings.https_proxy

        os.makedirs(settings.save_dir, exist_ok=True)

    # ==================== Image splitting ====================

    async def split_image(self, image_url: str) -> Optional[List[str]]:
        """Split a grid image into four sub-images, save them locally and return their URLs.

        Downloading, decoding and writing files are blocking operations, so
        they run in the default thread-pool executor instead of stalling the
        event loop.

        Args:
            image_url: URL of the image to split.

        Returns:
            Four download URLs (``settings.download_url/<filename>``) on success,
            ``[image_url]`` unchanged when either side is smaller than 500 px,
            or ``None`` on any failure (files written by this call are removed).
        """
        try:
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(None, self._split_image_sync, image_url)
        except Exception as e:
            logger.error(f"分割图像失败: {str(e)}")
            return None

    def _split_image_sync(self, image_url: str) -> Optional[List[str]]:
        """Blocking body of ``split_image``; unexpected exceptions propagate to it."""
        response = requests.get(image_url, proxies=self.proxies or None, timeout=30)
        if response.status_code != 200:
            logger.error(f"下载图像失败,状态码: {response.status_code}")
            return None

        with Image.open(io.BytesIO(response.content)) as img:
            width, height = img.size
            save_format = self._resolve_save_format(img.format, image_url)

            if width < 500 or height < 500:
                logger.warning(f"图像尺寸较小: {width}x{height},跳过分割")
                return [image_url]

            half_width, half_height = width // 2, height // 2
            # Order: top-left, top-right, bottom-left, bottom-right. With odd
            # dimensions the right/bottom quadrants get the extra pixel.
            boxes = [
                (0, 0, half_width, half_height),
                (half_width, 0, width, half_height),
                (0, half_height, half_width, height),
                (half_width, half_height, width, height),
            ]
            quadrants = [img.crop(box) for box in boxes]
            split_urls = self._save_quadrants(quadrants, save_format)

        if split_urls is not None:
            logger.info("成功完成图片分割,生成4张子图")
        return split_urls

    def _save_quadrants(self, quadrants: List[Any], save_format: str) -> Optional[List[str]]:
        """Save each quadrant image under ``settings.save_dir`` and return their download URLs.

        On the first failure, every file written by this call is deleted and
        ``None`` is returned. This includes the failing quadrant's own,
        possibly empty, file.
        """
        image_id = uuid.uuid4().hex[:10]
        save_dir = os.path.abspath(settings.save_dir)
        os.makedirs(save_dir, exist_ok=True)

        save_params: Dict[str, Any] = {"format": save_format}
        if save_format in ("PNG", "JPEG"):
            save_params["optimize"] = True
        if save_format == "JPEG":
            save_params["quality"] = 95

        extension = save_format.lower()
        written_paths: List[str] = []
        urls: List[str] = []
        for index, quadrant in enumerate(quadrants, 1):
            filename = f"split_{image_id}_{index}.{extension}"
            file_path = os.path.join(save_dir, filename)
            try:
                quadrant.save(file_path, **save_params)
                if not os.path.exists(file_path):
                    raise OSError(f"文件保存失败: {file_path}")
                if os.path.getsize(file_path) == 0:
                    raise OSError(f"保存的文件大小为0: {file_path}")
            except Exception as e:
                logger.error(f"保存分割图片 {index}/4 失败: {str(e)}")
                # Include the current path: it may exist partially or as a 0-byte file.
                self._remove_files(written_paths + [file_path])
                return None

            written_paths.append(file_path)
            urls.append(f"{settings.download_url}/{filename}")
            logger.info(f"成功保存分割图片 {index}/4: {filename}")

        return urls

    @staticmethod
    def _resolve_save_format(detected_format: Optional[str], image_url: str) -> str:
        """Pick the Pillow save format.

        Resolution order: the detected format, then the URL path's extension,
        then "PNG". The alias "JPG" is mapped to "JPEG".
        """
        fmt = detected_format
        if not fmt:
            fmt = os.path.splitext(urlparse(image_url).path)[1][1:]
        fmt = (fmt or "PNG").upper()
        # Pillow registers its JPEG writer as "JPEG"; "JPG" is only a file extension.
        return "JPEG" if fmt == "JPG" else fmt

    @staticmethod
    def _remove_files(paths: List[str]) -> None:
        """Best-effort delete of ``paths``: missing files are skipped, OS errors only logged."""
        for path in paths:
            try:
                if os.path.exists(path):
                    os.remove(path)
            except OSError as del_e:
                logger.error(f"删除失败的图片文件时出错: {str(del_e)}")

    # ==================== VectorEngine API methods ====================

    @staticmethod
    def _ve_headers() -> Dict[str, str]:
        """Build the bearer-token JSON headers used by every VectorEngine request."""
        return {
            "Authorization": f"Bearer {settings.vectorengine_token}",
            "Content-Type": "application/json",
        }

    async def _ve_query(
        self,
        method: str,
        path: str,
        action: str,
        payload: Optional[Dict[str, Any]] = None,
    ) -> Optional[Any]:
        """Send a VectorEngine request and return its decoded JSON body.

        Args:
            method: HTTP method ("GET" / "POST").
            path: Path appended to ``settings.vectorengine_api_url``.
            action: Human-readable action name used in log messages.
            payload: Optional JSON body.

        Returns:
            The decoded JSON, or ``None`` on a non-200 status or any exception
            (both are logged).
        """
        try:
            async with httpx.AsyncClient(timeout=60.0) as client:
                response = await client.request(
                    method,
                    f"{settings.vectorengine_api_url}{path}",
                    json=payload,
                    headers=self._ve_headers(),
                )
                if response.status_code != 200:
                    logger.error(f"VectorEngine {action}失败: {response.status_code} - {response.text}")
                    return None
                return response.json()
        except Exception as e:
            logger.error(f"VectorEngine {action}异常: {str(e)}")
            return None

    async def ve_submit_imagine(
        self,
        prompt: str,
        base64_array: Optional[List[str]] = None,
        notify_hook: str = "",
        state: str = "",
        bot_type: str = "MID_JOURNEY"
    ) -> Dict[str, Any]:
        """Submit an imagine task to the VectorEngine API.

        API docs: https://vectorengine.apifox.cn/api-349239131

        Args:
            prompt: Prompt text.
            base64_array: Base64-encoded reference images (defaults to none).
            notify_hook: Callback URL.
            state: Custom state string.
            bot_type: Bot type, "MID_JOURNEY" by default.

        Returns:
            The decoded API response (documented as ``{"result": "<task id>", ...}``),
            or ``{"success": False, "message": "<error>"}`` on HTTP/network failure.
        """
        payload = {
            "base64Array": base64_array if base64_array is not None else [],
            "notifyHook": notify_hook,
            "prompt": prompt,
            "state": state,
            "botType": bot_type
        }

        try:
            async with httpx.AsyncClient(timeout=60.0) as client:
                response = await client.post(
                    f"{settings.vectorengine_api_url}/mj/submit/imagine",
                    json=payload,
                    headers=self._ve_headers()
                )
                if response.status_code != 200:
                    error_text = response.text
                    logger.error(f"VectorEngine 提交任务失败: {response.status_code} - {error_text}")
                    return {"success": False, "message": f"API请求失败: {response.status_code} - {error_text}"}

                result = response.json()
                logger.info(f"VectorEngine 提交任务成功: {result}")
                return result
        except Exception as e:
            logger.error(f"VectorEngine 提交任务异常: {str(e)}")
            return {"success": False, "message": str(e)}

    async def ve_get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Fetch a task's status by ID.

        API docs: https://vectorengine.apifox.cn/api-349239132

        Returns:
            The task object. Per the API docs it includes ``id``, ``action``,
            ``status`` (SUCCESS / IN_PROGRESS / FAILURE ...), ``progress``
            (e.g. "100%"), ``imageUrl``, ``buttons`` and ``failReason``.
            Returns ``None`` on failure.
        """
        return await self._ve_query("GET", f"/mj/task/{task_id}/fetch", "获取任务状态")

    async def ve_get_task_list(self, task_ids: List[str]) -> Optional[List[Dict[str, Any]]]:
        """Fetch several tasks by ID.

        API docs: https://vectorengine.apifox.cn/api-349239133

        Returns:
            A list of task objects, or ``None`` on failure.
        """
        return await self._ve_query(
            "POST", "/mj/task/list-by-condition", "批量查询任务", payload={"ids": task_ids}
        )

    async def ve_get_image_seed(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Fetch the seed of a task's image.

        API docs: https://vectorengine.apifox.cn/api-349239134

        Returns:
            The response object containing the seed, or ``None`` on failure.
        """
        return await self._ve_query("GET", f"/mj/task/{task_id}/image-seed", "获取seed")

    @staticmethod
    def _parse_progress(value: Any) -> int:
        """Convert a progress value into an integer percentage, defaulting to 0.

        Accepts numbers, "50%", " 12.5% ", and so on; unparseable input yields 0.
        """
        try:
            if isinstance(value, str):
                return int(float(value.strip().rstrip('%')))
            return int(value)
        except (ValueError, TypeError, OverflowError):
            return 0

    async def ve_generate_image(
        self,
        prompt: str,
        config: Optional[Dict[str, Any]] = None
    ) -> AsyncGenerator[Dict, None]:
        """Generate an image through VectorEngine and stream status updates.

        Args:
            prompt: Prompt text.
            config: Optional settings:
                - base64_array: base64-encoded reference images
                - notify_hook: callback URL
                - state: custom state string
                - bot_type: bot type (default "MID_JOURNEY")
                - split_image: whether to split the result grid (default True)

        Yields:
            - {"status": "submitted", "task_id": ...} once the task is accepted
            - {"status": "progress", "progress": int, "task_id": ...} when progress increases
            - {"status": "success", "task_id", "image_urls", "buttons", "properties"} on completion
            - {"status": "error", "message": ..., ["task_id": ...]} on failure, cancellation or timeout
        """
        config = config or {}

        if not settings.vectorengine_token:
            yield {
                "status": "error",
                "message": "未配置 VectorEngine Token"
            }
            return

        # 1. Submit the task.
        submit_result = await self.ve_submit_imagine(
            prompt=prompt,
            base64_array=config.get("base64_array", []),
            notify_hook=config.get("notify_hook", ""),
            state=config.get("state", ""),
            bot_type=config.get("bot_type", "MID_JOURNEY")
        )
        if not isinstance(submit_result, dict):
            submit_result = {}

        # Success is signalled by {"result": "<task id>"}.
        task_id = submit_result.get("result")
        if not task_id:
            # Locally built errors carry "message". Upstream API errors are
            # assumed to use "description" (midjourney-proxy style) — TODO confirm.
            message = (
                submit_result.get("message")
                or submit_result.get("description")
                or "提交任务失败,未获取到任务ID"
            )
            yield {
                "status": "error",
                "message": message
            }
            return

        logger.info(f"VectorEngine 任务已提交,任务ID: {task_id}")
        yield {
            "status": "submitted",
            "task_id": task_id
        }

        # 2. Poll the task status until it finishes or attempts run out.
        polling_count = 0
        last_progress = -1

        while polling_count < settings.vectorengine_max_polling_attempts:
            await asyncio.sleep(settings.vectorengine_polling_interval)
            polling_count += 1

            task_status = await self.ve_get_task_status(task_id)
            if not task_status or not isinstance(task_status, dict):
                logger.warning(f"第 {polling_count} 次轮询未获取到任务状态")
                continue

            status = task_status.get("status", "")
            progress_str = task_status.get("progress", "0%")
            progress_val = self._parse_progress(progress_str)

            # Only report progress when it actually increases.
            if progress_val > last_progress:
                last_progress = progress_val
                yield {
                    "status": "progress",
                    "progress": progress_val,
                    "task_id": task_id
                }

            if status == "SUCCESS":
                image_url = task_status.get("imageUrl")
                response = {
                    "status": "success",
                    "task_id": task_id,
                    "image_urls": [image_url] if image_url else [],
                    "buttons": task_status.get("buttons", []),
                    "properties": task_status.get("properties", {})
                }

                # split_image never raises; it returns None (or the original URL
                # for small images), in which case the original URL is kept.
                if config.get("split_image", True) and image_url:
                    split_urls = await self.split_image(image_url)
                    if split_urls and len(split_urls) == 4:
                        response["image_urls"] = split_urls
                        logger.info("图片分割成功,生成 4 张子图")

                yield response
                return

            if status in self._FAILED_STATUSES:
                yield {
                    "status": "error",
                    "message": task_status.get("failReason") or "任务失败",
                    "task_id": task_id
                }
                return

            if status in self._CANCELLED_STATUSES:
                yield {
                    "status": "error",
                    "message": "任务已取消",
                    "task_id": task_id
                }
                return

            # Any other status (e.g. IN_PROGRESS, SUBMITTED): keep polling.
            logger.debug(f"任务 {task_id} 状态: {status}, 进度: {progress_str}")

        # Polling attempts exhausted.
        yield {
            "status": "error",
            "message": f"任务超时,已轮询 {settings.vectorengine_max_polling_attempts} 次",
            "task_id": task_id
        }