japi/apps/jmidjourney/service.py

451 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import os
import io
import requests
import uuid
import urllib3
import logging
from pathlib import Path
from urllib.parse import urlparse
from PIL import Image
import asyncio
import httpx
from typing import Dict, Any, List, AsyncGenerator, Optional
from settings import settings
# Module-level logger used by every method of the service
logger = logging.getLogger("midjourney_service")
# Silence urllib3's InsecureRequestWarning so it does not flood the logs
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
class MidjourneyService:
def __init__(self):
    """Build the proxy mapping from settings and ensure the save directory exists."""
    # A scheme only gets an entry when its configured proxy is set and not blank
    configured_proxies = (
        ('http', settings.http_proxy),
        ('https', settings.https_proxy),
    )
    self.proxies = {
        scheme: proxy_url
        for scheme, proxy_url in configured_proxies
        if proxy_url and proxy_url.strip()
    }
    # Split images are written here later, so create the directory up front
    os.makedirs(settings.save_dir, exist_ok=True)
async def split_image(self, image_url: str) -> Optional[List[str]]:
    """
    Download an image, cut it into four equal quadrants, save them locally
    and return their download URLs.

    Both the HTTP download (``requests``) and the Pillow decode/crop/save
    work are blocking, so they run in the event loop's default executor
    instead of stalling every other coroutine while they complete.

    Args:
        image_url: URL of the image to split.

    Returns:
        Four URLs (``settings.download_url`` + file name) ordered top-left,
        top-right, bottom-left, bottom-right; ``[image_url]`` unchanged when
        the image is narrower or shorter than 500 px; ``None`` when the
        download, decoding or saving fails.
    """
    loop = asyncio.get_running_loop()
    try:
        response = await loop.run_in_executor(
            None,
            lambda: requests.get(
                image_url,
                proxies=self.proxies if self.proxies else None,
                timeout=30,
            ),
        )
        if response.status_code != 200:
            logger.error(f"下载图像失败,状态码: {response.status_code}")
            return None
        return await loop.run_in_executor(
            None, self._split_and_save, response.content, image_url
        )
    except Exception as e:
        logger.error(f"分割图像失败: {str(e)}")
        return None


def _split_and_save(self, image_data: bytes, image_url: str) -> Optional[List[str]]:
    """Blocking half of ``split_image``: decode the bytes, crop four quadrants, write them to disk."""
    img = Image.open(io.BytesIO(image_data))
    width, height = img.size
    original_format = img.format
    if not original_format:
        # Fall back to the URL's file extension, then to PNG
        extension = os.path.splitext(urlparse(image_url).path)[1][1:]
        original_format = extension.upper() or 'PNG'
    # Pillow registers the codec as "JPEG"; "JPG" (from a .jpg URL) is not a
    # valid format name and would make every save() call fail
    save_format = 'JPEG' if original_format == 'JPG' else original_format

    if width < 500 or height < 500:
        logger.warning(f"图像尺寸较小: {width}x{height},跳过分割")
        return [image_url]

    half_width = width // 2
    half_height = height // 2
    quadrants = [
        img.crop((0, 0, half_width, half_height)),
        img.crop((half_width, 0, width, half_height)),
        img.crop((0, half_height, half_width, height)),
        img.crop((half_width, half_height, width, height)),
    ]

    save_params = {"format": save_format}
    if save_format in ('PNG', 'JPEG'):
        save_params["optimize"] = True
    if save_format == 'JPEG':
        save_params["quality"] = 95

    image_id = uuid.uuid4().hex[:10]
    save_dir = os.path.abspath(settings.save_dir)
    os.makedirs(save_dir, exist_ok=True)

    written_paths = []  # every path we attempted, including a partially written one
    image_urls = []
    for i, quadrant in enumerate(quadrants, 1):
        filename = f"split_{image_id}_{i}.{original_format.lower()}"
        file_path = os.path.join(save_dir, filename)
        try:
            written_paths.append(file_path)
            quadrant.save(file_path, **save_params)
            if not os.path.exists(file_path):
                raise Exception(f"文件保存失败: {file_path}")
            if os.path.getsize(file_path) == 0:
                raise Exception(f"保存的文件大小为0: {file_path}")
            image_urls.append(f"{settings.download_url}/{filename}")
            logger.info(f"成功保存分割图片 {i}/4: {filename}")
        except Exception as e:
            logger.error(f"保存分割图片 {i}/4 失败: {str(e)}")
            # All-or-nothing: remove every file of this batch, including the
            # one that just failed and may be incomplete
            for stale_path in written_paths:
                try:
                    if os.path.exists(stale_path):
                        os.remove(stale_path)
                except OSError as del_e:
                    logger.error(f"删除失败的图片文件时出错: {str(del_e)}")
            return None

    logger.info("成功完成图片分割生成4张子图")
    return image_urls
# ==================== VectorEngine API methods ====================
async def ve_submit_imagine(
    self,
    prompt: str,
    base64_array: List[str] = None,
    notify_hook: str = "",
    state: str = "",
    bot_type: str = "MID_JOURNEY"
) -> Dict[str, Any]:
    """
    Submit an imagine task to the VectorEngine API.

    API docs: https://vectorengine.apifox.cn/api-349239131

    Args:
        prompt: Text prompt.
        base64_array: Base64-encoded reference images; ``None`` is sent as
            an empty list.
        notify_hook: Callback URL.
        state: Caller-defined state string.
        bot_type: Bot type, ``MID_JOURNEY`` by default.

    Returns:
        The decoded JSON body on HTTP 200 (per the API docs linked above,
        ``{"result": "<task id>"}``); otherwise
        ``{"success": False, "message": "<error text>"}``.
    """
    request_body = {
        "base64Array": [] if base64_array is None else base64_array,
        "notifyHook": notify_hook,
        "prompt": prompt,
        "state": state,
        "botType": bot_type,
    }
    auth_headers = {
        "Authorization": f"Bearer {settings.vectorengine_token}",
        "Content-Type": "application/json",
    }
    try:
        async with httpx.AsyncClient(timeout=60.0) as client:
            endpoint = f"{settings.vectorengine_api_url}/mj/submit/imagine"
            response = await client.post(
                endpoint, json=request_body, headers=auth_headers
            )
            if response.status_code == 200:
                result = response.json()
                logger.info(f"VectorEngine 提交任务成功: {result}")
                return result
            # Any non-200 status is reported back as a failure dict
            error_text = response.text
            logger.error(f"VectorEngine 提交任务失败: {response.status_code} - {error_text}")
            return {"success": False, "message": f"API请求失败: {response.status_code} - {error_text}"}
    except Exception as exc:
        logger.error(f"VectorEngine 提交任务异常: {str(exc)}")
        return {"success": False, "message": str(exc)}
async def ve_get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
    """
    Fetch the current state of a task by its ID.

    API docs: https://vectorengine.apifox.cn/api-349239132

    Args:
        task_id: Task ID.

    Returns:
        The decoded JSON body on HTTP 200, ``None`` on any other status or
        on an exception. Per the API docs linked above the body carries
        fields such as ``id``, ``action`` (e.g. IMAGINE), ``status``
        (SUCCESS, IN_PROGRESS, FAILURE, ...), ``progress`` (e.g. "100%"),
        ``imageUrl``, ``buttons`` and ``failReason``.
    """
    auth_headers = {
        "Authorization": f"Bearer {settings.vectorengine_token}",
        "Content-Type": "application/json",
    }
    try:
        async with httpx.AsyncClient(timeout=60.0) as client:
            endpoint = f"{settings.vectorengine_api_url}/mj/task/{task_id}/fetch"
            response = await client.get(endpoint, headers=auth_headers)
            if response.status_code == 200:
                return response.json()
            logger.error(f"VectorEngine 获取任务状态失败: {response.status_code} - {response.text}")
            return None
    except Exception as exc:
        logger.error(f"VectorEngine 获取任务状态异常: {str(exc)}")
        return None
async def ve_get_task_list(self, task_ids: List[str]) -> Optional[List[Dict[str, Any]]]:
    """
    Look up several tasks at once by their IDs.

    API docs: https://vectorengine.apifox.cn/api-349239133

    Args:
        task_ids: Task IDs to query.

    Returns:
        The decoded JSON body on HTTP 200 (a list of task objects per the
        API docs linked above), ``None`` on any other status or on an
        exception.
    """
    auth_headers = {
        "Authorization": f"Bearer {settings.vectorengine_token}",
        "Content-Type": "application/json",
    }
    try:
        async with httpx.AsyncClient(timeout=60.0) as client:
            endpoint = f"{settings.vectorengine_api_url}/mj/task/list-by-condition"
            response = await client.post(
                endpoint, json={"ids": task_ids}, headers=auth_headers
            )
            if response.status_code == 200:
                return response.json()
            logger.error(f"VectorEngine 批量查询任务失败: {response.status_code} - {response.text}")
            return None
    except Exception as exc:
        logger.error(f"VectorEngine 批量查询任务异常: {str(exc)}")
        return None
async def ve_get_image_seed(self, task_id: str) -> Optional[Dict[str, Any]]:
    """
    Fetch the seed of a task's generated image.

    API docs: https://vectorengine.apifox.cn/api-349239134

    Args:
        task_id: Task ID.

    Returns:
        The decoded JSON body on HTTP 200 (a task object carrying the seed
        per the API docs linked above), ``None`` on any other status or on
        an exception.
    """
    auth_headers = {
        "Authorization": f"Bearer {settings.vectorengine_token}",
        "Content-Type": "application/json",
    }
    try:
        async with httpx.AsyncClient(timeout=60.0) as client:
            endpoint = f"{settings.vectorengine_api_url}/mj/task/{task_id}/image-seed"
            response = await client.get(endpoint, headers=auth_headers)
            if response.status_code == 200:
                return response.json()
            logger.error(f"VectorEngine 获取seed失败: {response.status_code} - {response.text}")
            return None
    except Exception as exc:
        logger.error(f"VectorEngine 获取seed异常: {str(exc)}")
        return None
async def ve_generate_image(
    self,
    prompt: str,
    config: Dict = None
) -> AsyncGenerator[Dict, None]:
    """
    Generate an image through the VectorEngine API and stream status events.

    Submits the task, then polls its status every
    ``settings.vectorengine_polling_interval`` seconds for at most
    ``settings.vectorengine_max_polling_attempts`` attempts.

    Args:
        prompt: Text prompt.
        config: Optional settings; recognised keys:
            - base64_array: base64-encoded reference images
            - notify_hook: callback URL
            - state: caller-defined state string
            - bot_type: bot type (default "MID_JOURNEY")
            - split_image: split the result into four images (default True)

    Yields:
        - {"status": "submitted", "task_id": ...} once the task is accepted
        - {"status": "progress", "progress": <int>, "task_id": ...} whenever
          the reported progress increases
        - {"status": "success", "task_id": ..., "image_urls": [...],
          "buttons": [...], "properties": {...}} on completion
        - {"status": "error", "message": ...} on a missing token, a failed
          submission, a failed or cancelled task, or a polling timeout
          (``task_id`` is included once it is known)
    """
    if not config:
        config = {}
    if not settings.vectorengine_token:
        yield {
            "status": "error",
            "message": "未配置 VectorEngine Token"
        }
        return

    # 1. Submit the task
    submit_result = await self.ve_submit_imagine(
        prompt=prompt,
        base64_array=config.get("base64_array", []),
        notify_hook=config.get("notify_hook", ""),
        state=config.get("state", ""),
        bot_type=config.get("bot_type", "MID_JOURNEY")
    )
    # The submit call returns whatever JSON the server sent; a list, string
    # or null cannot carry a task id, so treat it as a failed submission
    # instead of letting .get() raise and kill the stream
    if not isinstance(submit_result, dict):
        submit_result = {}
    task_id = submit_result.get("result")
    if not task_id:
        yield {
            "status": "error",
            # NOTE(review): "description" is assumed to be where the API puts
            # its failure text when no "message" is present — confirm against
            # the VectorEngine docs
            "message": submit_result.get("message")
            or submit_result.get("description")
            or "提交任务失败未获取到任务ID"
        }
        return
    logger.info(f"VectorEngine 任务已提交任务ID: {task_id}")
    yield {
        "status": "submitted",
        "task_id": task_id
    }

    # 2. Poll the task status
    polling_count = 0
    last_progress = -1
    while polling_count < settings.vectorengine_max_polling_attempts:
        await asyncio.sleep(settings.vectorengine_polling_interval)
        polling_count += 1
        task_status = await self.ve_get_task_status(task_id)
        # Missing, empty or non-dict payloads count as a failed poll; retry
        if not task_status or not isinstance(task_status, dict):
            logger.warning(f"{polling_count} 次轮询未获取到任务状态")
            continue
        status = task_status.get("status", "")
        progress_str = task_status.get("progress", "0%")
        # Progress may be a number or a string such as "0%" / "100%"
        try:
            if isinstance(progress_str, str):
                progress_val = int(progress_str.strip().rstrip('%'))
            else:
                progress_val = int(progress_str)
        except (ValueError, TypeError):
            progress_val = 0
        # Emit a progress event only when the value has increased
        if progress_val > last_progress:
            last_progress = progress_val
            yield {
                "status": "progress",
                "progress": progress_val,
                "task_id": task_id
            }

        if status == "SUCCESS":
            image_url = task_status.get("imageUrl")
            response = {
                "status": "success",
                "task_id": task_id,
                "image_urls": [image_url] if image_url else [],
                "buttons": task_status.get("buttons", []),
                "properties": task_status.get("properties", {})
            }
            # Splitting is best-effort: on any failure the original URL is kept
            if config.get("split_image", True) and image_url:
                try:
                    split_urls = await self.split_image(image_url)
                    if split_urls and len(split_urls) == 4:
                        response["image_urls"] = split_urls
                        logger.info("图片分割成功,生成 4 张子图")
                except Exception as e:
                    logger.error(f"分割图像失败: {str(e)}")
            yield response
            return
        if status in ("FAILURE", "FAILED"):
            yield {
                "status": "error",
                # "or" also covers a failReason that is present but null/empty
                "message": task_status.get("failReason") or "任务失败",
                "task_id": task_id
            }
            return
        if status == "CANCELLED":
            yield {
                "status": "error",
                "message": "任务已取消",
                "task_id": task_id
            }
            return
        # Any other status (e.g. IN_PROGRESS, SUBMITTED): keep polling
        logger.debug(f"任务 {task_id} 状态: {status}, 进度: {progress_str}")

    # Polling budget exhausted without reaching a terminal status
    yield {
        "status": "error",
        "message": f"任务超时,已轮询 {settings.vectorengine_max_polling_attempts}",
        "task_id": task_id
    }