修复jmidjourney无法更新进度及下载图片的问题
This commit is contained in:
parent
213078f57b
commit
5b17573c46
@ -1,14 +1,8 @@
|
|||||||
import json
|
import json
|
||||||
import sys
|
|
||||||
import os
|
import os
|
||||||
import io
|
import io
|
||||||
import requests
|
import requests
|
||||||
from requests.adapters import HTTPAdapter
|
|
||||||
from urllib3.util.retry import Retry
|
|
||||||
import random
|
|
||||||
import re
|
|
||||||
import uuid
|
import uuid
|
||||||
import urllib.request
|
|
||||||
import urllib3
|
import urllib3
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -18,7 +12,6 @@ import asyncio
|
|||||||
import httpx
|
import httpx
|
||||||
from typing import Dict, Any, List, AsyncGenerator, Optional
|
from typing import Dict, Any, List, AsyncGenerator, Optional
|
||||||
from settings import settings
|
from settings import settings
|
||||||
from utils import get_new_image_url, is_valid_image_url
|
|
||||||
|
|
||||||
|
|
||||||
# 设置日志记录器
|
# 设置日志记录器
|
||||||
@ -27,15 +20,10 @@ logger = logging.getLogger("midjourney_service")
|
|||||||
# 禁用不安全请求警告
|
# 禁用不安全请求警告
|
||||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||||
|
|
||||||
|
|
||||||
class MidjourneyService:
|
class MidjourneyService:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""初始化MidjourneyService"""
|
"""初始化 MidjourneyService"""
|
||||||
self.API_URL = settings.midjourney_api_url
|
|
||||||
self.APPLICATION_ID = settings.midjourney_application_id
|
|
||||||
self.DATA_ID = settings.midjourney_data_id
|
|
||||||
self.DATA_VERSION = settings.midjourney_data_version
|
|
||||||
self.SESSION_ID = settings.midjourney_session_id
|
|
||||||
|
|
||||||
# 设置代理
|
# 设置代理
|
||||||
self.proxies = {}
|
self.proxies = {}
|
||||||
if settings.http_proxy and settings.http_proxy.strip():
|
if settings.http_proxy and settings.http_proxy.strip():
|
||||||
@ -46,204 +34,16 @@ class MidjourneyService:
|
|||||||
# 确保保存目录存在
|
# 确保保存目录存在
|
||||||
os.makedirs(settings.save_dir, exist_ok=True)
|
os.makedirs(settings.save_dir, exist_ok=True)
|
||||||
|
|
||||||
@staticmethod
|
async def split_image(self, image_url: str) -> Optional[List[str]]:
|
||||||
def first_where(array, key, value=None):
|
"""
|
||||||
"""在数组中找到第一个匹配条件的项"""
|
将一张大图切割成四张子图,保存到本地并返回 URLs
|
||||||
for item in array:
|
|
||||||
if callable(key) and key(item):
|
|
||||||
return item
|
|
||||||
if isinstance(key, str) and item[key].startswith(value):
|
|
||||||
return item
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def initialize_client(self, oauth_token, channel_id):
|
|
||||||
"""初始化Discord客户端会话"""
|
|
||||||
client = requests.Session()
|
|
||||||
|
|
||||||
retry = Retry(total=3, backoff_factor=0.5)
|
Args:
|
||||||
adapter = HTTPAdapter(max_retries=retry)
|
image_url: 图片URL
|
||||||
client.mount('http://', adapter)
|
|
||||||
client.mount('https://', adapter)
|
|
||||||
|
|
||||||
client.headers.update({
|
|
||||||
'Authorization': oauth_token
|
|
||||||
})
|
|
||||||
|
|
||||||
if self.proxies:
|
|
||||||
client.proxies.update(self.proxies)
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = client.get(f'{self.API_URL}/channels/{channel_id}')
|
|
||||||
data = response.json()
|
|
||||||
guild_id = data['guild_id']
|
|
||||||
|
|
||||||
response = client.get(f'{self.API_URL}/users/@me')
|
Returns:
|
||||||
data = response.json()
|
分割后的图片URL列表,失败返回 None
|
||||||
user_id = data['id']
|
"""
|
||||||
|
|
||||||
return client, guild_id, user_id
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"初始化Discord客户端失败: {str(e)}")
|
|
||||||
raise Exception(f"初始化Discord客户端失败: {str(e)}")
|
|
||||||
|
|
||||||
async def imagine(self, client, guild_id, channel_id, prompt, seed=None):
|
|
||||||
"""发送imagine请求到Discord"""
|
|
||||||
params = {
|
|
||||||
'type': 2,
|
|
||||||
'application_id': self.APPLICATION_ID,
|
|
||||||
'guild_id': guild_id,
|
|
||||||
'channel_id': channel_id,
|
|
||||||
'session_id': self.SESSION_ID,
|
|
||||||
'data': {
|
|
||||||
'version': self.DATA_VERSION,
|
|
||||||
'id': self.DATA_ID,
|
|
||||||
'name': 'imagine',
|
|
||||||
'type': 1,
|
|
||||||
'options': [{
|
|
||||||
'type': 3,
|
|
||||||
'name': 'prompt',
|
|
||||||
'value': prompt
|
|
||||||
}],
|
|
||||||
'application_command': {
|
|
||||||
'id': self.DATA_ID,
|
|
||||||
'application_id': self.APPLICATION_ID,
|
|
||||||
'version': self.DATA_VERSION,
|
|
||||||
'default_member_permissions': None,
|
|
||||||
'type': 1,
|
|
||||||
'nsfw': False,
|
|
||||||
'name': 'imagine',
|
|
||||||
'description': 'Create images with Midjourney',
|
|
||||||
'dm_permission': True,
|
|
||||||
'options': [{
|
|
||||||
'type': 3,
|
|
||||||
'name': 'prompt',
|
|
||||||
'description': 'The prompt to imagine',
|
|
||||||
'required': True
|
|
||||||
}]
|
|
||||||
},
|
|
||||||
'attachments': []
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
client.post(f'{self.API_URL}/interactions', json=params)
|
|
||||||
await asyncio.sleep(20)
|
|
||||||
|
|
||||||
imagine_message = None
|
|
||||||
count = 0
|
|
||||||
last_progress = 0
|
|
||||||
|
|
||||||
while count < settings.max_polling_attempts:
|
|
||||||
progress_msg = await self.get_progress_message(client, channel_id, prompt, seed)
|
|
||||||
if progress_msg:
|
|
||||||
content = progress_msg.get('content', '')
|
|
||||||
progress_match = re.search(r'(?:\(|()?(\d+)%(?:\)|))?', content)
|
|
||||||
if progress_match:
|
|
||||||
progress_value = int(progress_match.group(1))
|
|
||||||
if progress_value > last_progress:
|
|
||||||
last_progress = progress_value
|
|
||||||
logger.info(f"生成进度: {progress_value}%")
|
|
||||||
yield {
|
|
||||||
"status": "progress",
|
|
||||||
"progress": progress_value,
|
|
||||||
"message_id": progress_msg.get('id'),
|
|
||||||
"content": content
|
|
||||||
}
|
|
||||||
|
|
||||||
imagine_message = await self.get_imagine(client, channel_id, prompt, count, seed)
|
|
||||||
if imagine_message:
|
|
||||||
break
|
|
||||||
|
|
||||||
await asyncio.sleep(settings.polling_interval)
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
if count >= settings.max_polling_attempts:
|
|
||||||
logger.error(f"轮询超过最大尝试次数: {settings.max_polling_attempts}")
|
|
||||||
yield {"status": "error", "message": "获取结果超时,超过最大轮询次数"}
|
|
||||||
return
|
|
||||||
|
|
||||||
if not imagine_message:
|
|
||||||
logger.error("轮询结束但没有获取到有效结果")
|
|
||||||
yield {"status": "error", "message": "没有获取到有效结果"}
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info(f"成功获取图像结果,消息ID: {imagine_message.get('id')}")
|
|
||||||
yield {
|
|
||||||
"status": "success",
|
|
||||||
"message_id": imagine_message.get('id'),
|
|
||||||
"content": imagine_message.get('content', ''),
|
|
||||||
"images": self.extract_image_urls(imagine_message)
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"发送Imagine请求失败: {str(e)}")
|
|
||||||
yield {"status": "error", "message": f"发送Imagine请求失败: {str(e)}"}
|
|
||||||
|
|
||||||
async def get_imagine(self, client, channel_id, prompt, count=0, seed=None):
|
|
||||||
"""获取生成图像的消息"""
|
|
||||||
try:
|
|
||||||
response = client.get(f'{self.API_URL}/channels/{channel_id}/messages?limit=10')
|
|
||||||
data = response.json()
|
|
||||||
|
|
||||||
def criteria(item):
|
|
||||||
content = item.get('content', '')
|
|
||||||
|
|
||||||
if seed is not None and f"--seed {seed}" in content:
|
|
||||||
progress_match = re.search(r'(?:\(|()?(\d+)%(?:\)|))?', content)
|
|
||||||
if progress_match:
|
|
||||||
progress_value = int(progress_match.group(1))
|
|
||||||
logger.info(f"任务进度: {progress_value}%")
|
|
||||||
if "%" in content:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if "%" in content:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if seed is not None:
|
|
||||||
seed_pattern = f"--seed {seed}"
|
|
||||||
if seed_pattern not in content:
|
|
||||||
return False
|
|
||||||
else:
|
|
||||||
if 'attachments' in item and len(item.get('attachments', [])) > 0:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
return self.first_where(data, criteria)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取通道消息失败: {str(e)}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def get_progress_message(self, client, channel_id, prompt, seed=None):
|
|
||||||
"""获取包含进度信息的消息"""
|
|
||||||
try:
|
|
||||||
response = client.get(f'{self.API_URL}/channels/{channel_id}/messages?limit=10')
|
|
||||||
data = response.json()
|
|
||||||
|
|
||||||
for item in data:
|
|
||||||
content = item.get('content', '')
|
|
||||||
if seed is not None and f"--seed {seed}" in content and "%" in content:
|
|
||||||
progress_match = re.search(r'(?:\(|()?(\d+)%(?:\)|))?', content)
|
|
||||||
if progress_match:
|
|
||||||
return item
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取进度消息失败: {str(e)}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def extract_image_urls(self, message):
|
|
||||||
"""从消息中提取图像URL"""
|
|
||||||
image_urls = []
|
|
||||||
if 'attachments' in message and message['attachments']:
|
|
||||||
for attachment in message['attachments']:
|
|
||||||
if 'url' in attachment:
|
|
||||||
image_urls.append(attachment['url'])
|
|
||||||
return image_urls
|
|
||||||
|
|
||||||
async def split_image(self, image_url):
|
|
||||||
"""将一张大图切割成四张子图,保存到本地并返回URLs"""
|
|
||||||
try:
|
try:
|
||||||
response = requests.get(image_url, proxies=self.proxies if self.proxies else None, timeout=30)
|
response = requests.get(image_url, proxies=self.proxies if self.proxies else None, timeout=30)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
@ -262,8 +62,8 @@ class MidjourneyService:
|
|||||||
original_format = 'PNG'
|
original_format = 'PNG'
|
||||||
|
|
||||||
if width < 500 or height < 500:
|
if width < 500 or height < 500:
|
||||||
logger.error(f"图像尺寸不符合预期: {width}x{height}, 应该接近2048x2048")
|
logger.warning(f"图像尺寸较小: {width}x{height},跳过分割")
|
||||||
return None
|
return [image_url]
|
||||||
|
|
||||||
half_width = width // 2
|
half_width = width // 2
|
||||||
half_height = height // 2
|
half_height = height // 2
|
||||||
@ -306,6 +106,7 @@ class MidjourneyService:
|
|||||||
logger.info(f"成功保存分割图片 {i}/4: {filename}")
|
logger.info(f"成功保存分割图片 {i}/4: {filename}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存分割图片 {i}/4 失败: {str(e)}")
|
logger.error(f"保存分割图片 {i}/4 失败: {str(e)}")
|
||||||
|
# 清理已保存的文件
|
||||||
for url in image_urls:
|
for url in image_urls:
|
||||||
try:
|
try:
|
||||||
file_path = os.path.join(save_dir, os.path.basename(url))
|
file_path = os.path.join(save_dir, os.path.basename(url))
|
||||||
@ -326,207 +127,30 @@ class MidjourneyService:
|
|||||||
logger.error(f"分割图像失败: {str(e)}")
|
logger.error(f"分割图像失败: {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def split_images(self, image_urls: List[str], config: Optional[Dict] = None):
|
# ==================== VectorEngine API 方法 ====================
|
||||||
"""批量处理多个图像URL"""
|
|
||||||
if not config:
|
|
||||||
config = {}
|
|
||||||
|
|
||||||
total = len(image_urls)
|
|
||||||
success_count = 0
|
|
||||||
|
|
||||||
for i, image_url in enumerate(image_urls, 1):
|
|
||||||
try:
|
|
||||||
if not is_valid_image_url(image_url):
|
|
||||||
yield {
|
|
||||||
"status": "error",
|
|
||||||
"index": i,
|
|
||||||
"total": total,
|
|
||||||
"success_count": success_count,
|
|
||||||
"message": "无效的图片URL"
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
|
|
||||||
split_urls = await self.split_image(image_url)
|
|
||||||
if split_urls and len(split_urls) == 4:
|
|
||||||
success_count += 1
|
|
||||||
yield {
|
|
||||||
"status": "success",
|
|
||||||
"index": i,
|
|
||||||
"total": total,
|
|
||||||
"success_count": success_count,
|
|
||||||
"images": split_urls
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
yield {
|
|
||||||
"status": "error",
|
|
||||||
"index": i,
|
|
||||||
"total": total,
|
|
||||||
"success_count": success_count,
|
|
||||||
"message": "分割图片失败"
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
yield {
|
|
||||||
"status": "error",
|
|
||||||
"index": i,
|
|
||||||
"total": total,
|
|
||||||
"success_count": success_count,
|
|
||||||
"message": str(e)
|
|
||||||
}
|
|
||||||
|
|
||||||
async def generate_image(self, prompt, config=None):
|
async def ve_submit_imagine(
|
||||||
"""生成图像并以流式方式返回结果"""
|
self,
|
||||||
if not config:
|
prompt: str,
|
||||||
config = {}
|
base64_array: List[str] = None,
|
||||||
|
notify_hook: str = "",
|
||||||
oauth_token = config.get('oauth_token', settings.midjourney_oauth_token)
|
state: str = "",
|
||||||
channel_id = config.get('channel_id', settings.midjourney_channel_id)
|
bot_type: str = "MID_JOURNEY"
|
||||||
|
) -> Dict[str, Any]:
|
||||||
if not oauth_token or not channel_id:
|
|
||||||
yield {
|
|
||||||
"status": "error",
|
|
||||||
"message": "缺少Discord配置",
|
|
||||||
"success_count": 0
|
|
||||||
}
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
client, guild_id, user_id = await self.initialize_client(oauth_token, channel_id)
|
|
||||||
except Exception as e:
|
|
||||||
yield {
|
|
||||||
"status": "error",
|
|
||||||
"message": str(e),
|
|
||||||
"success_count": 0
|
|
||||||
}
|
|
||||||
return
|
|
||||||
|
|
||||||
prompt = prompt.strip()
|
|
||||||
options = settings.midjourney_default_options.copy()
|
|
||||||
if 'options' in config:
|
|
||||||
options.update(config.get('options', {}))
|
|
||||||
|
|
||||||
if "seed" not in options or not options.get('seed'):
|
|
||||||
options['seed'] = random.randint(0, 4294967295)
|
|
||||||
|
|
||||||
if 'image_weight' in config and isinstance(config['image_weight'], (int, float)):
|
|
||||||
iw = max(0.1, min(config['image_weight'], 3))
|
|
||||||
options['iw'] = iw
|
|
||||||
|
|
||||||
seed = options.pop('seed', None)
|
|
||||||
|
|
||||||
parameter = ""
|
|
||||||
no_value_key = ['relax', 'fast', 'turbo', 'tile']
|
|
||||||
for key, value in options.items():
|
|
||||||
if key in no_value_key:
|
|
||||||
parameter += f" --{key}"
|
|
||||||
else:
|
|
||||||
parameter += f" --{key} {value}"
|
|
||||||
|
|
||||||
if seed:
|
|
||||||
parameter += f" --seed {seed}"
|
|
||||||
logger.info(f"使用的seed值: {seed}")
|
|
||||||
|
|
||||||
if 'image_urls' in config and config['image_urls']:
|
|
||||||
image_urls = config['image_urls']
|
|
||||||
if not isinstance(image_urls, list):
|
|
||||||
image_urls = [image_urls]
|
|
||||||
|
|
||||||
new_image_urls = []
|
|
||||||
for image_url in image_urls:
|
|
||||||
if is_valid_image_url(image_url):
|
|
||||||
try:
|
|
||||||
new_url = get_new_image_url(image_url)
|
|
||||||
if new_url:
|
|
||||||
new_image_urls.append(new_url)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"转换图片URL失败: {str(e)}")
|
|
||||||
|
|
||||||
if new_image_urls:
|
|
||||||
prompt = " ".join(new_image_urls) + " " + prompt
|
|
||||||
|
|
||||||
if 'characters' in config and config['characters']:
|
|
||||||
char_urls = []
|
|
||||||
characters = config['characters'] if isinstance(config['characters'], list) else [config['characters']]
|
|
||||||
|
|
||||||
for char_url in characters:
|
|
||||||
if is_valid_image_url(char_url):
|
|
||||||
try:
|
|
||||||
new_url = get_new_image_url(char_url)
|
|
||||||
if new_url:
|
|
||||||
char_urls.append(new_url)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"转换角色图片URL失败: {str(e)}")
|
|
||||||
|
|
||||||
if char_urls:
|
|
||||||
prompt = prompt + " --cref " + " ".join(char_urls)
|
|
||||||
|
|
||||||
if 'styles' in config and config['styles']:
|
|
||||||
style_urls = []
|
|
||||||
styles = config['styles'] if isinstance(config['styles'], list) else [config['styles']]
|
|
||||||
|
|
||||||
for style_url in styles:
|
|
||||||
if is_valid_image_url(style_url):
|
|
||||||
try:
|
|
||||||
new_url = get_new_image_url(style_url)
|
|
||||||
if new_url:
|
|
||||||
style_urls.append(new_url)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"转换风格图片URL失败: {str(e)}")
|
|
||||||
|
|
||||||
if style_urls:
|
|
||||||
prompt = prompt + " --sref " + " ".join(style_urls)
|
|
||||||
|
|
||||||
prompt = prompt + parameter
|
|
||||||
success_count = 0
|
|
||||||
|
|
||||||
async for result in self.imagine(client, guild_id, channel_id, prompt, seed):
|
|
||||||
if result.get("status") == "progress":
|
|
||||||
yield {
|
|
||||||
"status": "progress",
|
|
||||||
"progress": result.get("progress", 0),
|
|
||||||
"seed": seed
|
|
||||||
}
|
|
||||||
elif result.get("status") == "success":
|
|
||||||
success_count += 1
|
|
||||||
response = {
|
|
||||||
"status": "success",
|
|
||||||
"success_count": success_count,
|
|
||||||
"image_urls": result.get("images", [])
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.get("split_image", True) and result.get("images"):
|
|
||||||
try:
|
|
||||||
orig_image_url = result["images"][0]
|
|
||||||
split_urls = await self.split_image(orig_image_url)
|
|
||||||
if split_urls:
|
|
||||||
response["image_urls"] = split_urls
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"分割图像失败: {str(e)}")
|
|
||||||
response["status"] = "error"
|
|
||||||
response["message"] = f"分割图像失败: {str(e)}"
|
|
||||||
|
|
||||||
yield response
|
|
||||||
else:
|
|
||||||
yield {
|
|
||||||
"status": "error",
|
|
||||||
"message": result.get("message", "未知错误"),
|
|
||||||
"success_count": success_count
|
|
||||||
}
|
|
||||||
|
|
||||||
async def ve_submit_imagine(self, prompt: str, base64_array: List[str] = None,
|
|
||||||
notify_hook: str = "", state: str = "",
|
|
||||||
bot_type: str = "MID_JOURNEY") -> Dict[str, Any]:
|
|
||||||
"""
|
"""
|
||||||
提交 imagine 任务到 VectorEngine API
|
提交 imagine 任务到 VectorEngine API
|
||||||
|
|
||||||
|
API 文档: https://vectorengine.apifox.cn/api-349239131
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompt: 提示词
|
prompt: 提示词
|
||||||
base64_array: base64 编码的图片数组
|
base64_array: base64 编码的图片数组(垫图)
|
||||||
notify_hook: 回调地址
|
notify_hook: 回调地址
|
||||||
state: 自定义状态
|
state: 自定义状态
|
||||||
bot_type: 机器人类型,默认 MID_JOURNEY
|
bot_type: 机器人类型,默认 MID_JOURNEY
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
包含任务ID的响应
|
{"result": "任务ID"} 或 {"success": False, "message": "错误信息"}
|
||||||
"""
|
"""
|
||||||
if base64_array is None:
|
if base64_array is None:
|
||||||
base64_array = []
|
base64_array = []
|
||||||
@ -553,8 +177,9 @@ class MidjourneyService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
logger.error(f"VectorEngine API 请求失败: {response.status_code} - {response.text}")
|
error_text = response.text
|
||||||
return {"success": False, "message": f"API请求失败: {response.status_code}"}
|
logger.error(f"VectorEngine 提交任务失败: {response.status_code} - {error_text}")
|
||||||
|
return {"success": False, "message": f"API请求失败: {response.status_code} - {error_text}"}
|
||||||
|
|
||||||
result = response.json()
|
result = response.json()
|
||||||
logger.info(f"VectorEngine 提交任务成功: {result}")
|
logger.info(f"VectorEngine 提交任务成功: {result}")
|
||||||
@ -564,15 +189,25 @@ class MidjourneyService:
|
|||||||
logger.error(f"VectorEngine 提交任务异常: {str(e)}")
|
logger.error(f"VectorEngine 提交任务异常: {str(e)}")
|
||||||
return {"success": False, "message": str(e)}
|
return {"success": False, "message": str(e)}
|
||||||
|
|
||||||
async def ve_get_task_status(self, task_id: str) -> Dict[str, Any]:
|
async def ve_get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
获取 VectorEngine 任务状态
|
根据任务ID查询任务状态
|
||||||
|
|
||||||
|
API 文档: https://vectorengine.apifox.cn/api-349239132
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
task_id: 任务ID
|
task_id: 任务ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
任务状态信息
|
任务状态对象,包含:
|
||||||
|
- id: 任务ID
|
||||||
|
- action: 动作类型 (如 IMAGINE)
|
||||||
|
- status: 任务状态 (SUCCESS, IN_PROGRESS, FAILURE 等)
|
||||||
|
- progress: 进度字符串 (如 "100%")
|
||||||
|
- imageUrl: 生成的图片URL
|
||||||
|
- buttons: 操作按钮列表
|
||||||
|
- failReason: 失败原因
|
||||||
|
等字段,失败返回 None
|
||||||
"""
|
"""
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {settings.vectorengine_token}",
|
"Authorization": f"Bearer {settings.vectorengine_token}",
|
||||||
@ -588,30 +223,109 @@ class MidjourneyService:
|
|||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
logger.error(f"VectorEngine 获取任务状态失败: {response.status_code} - {response.text}")
|
logger.error(f"VectorEngine 获取任务状态失败: {response.status_code} - {response.text}")
|
||||||
return {"success": False, "message": f"获取任务状态失败: {response.status_code}"}
|
return None
|
||||||
|
|
||||||
result = response.json()
|
result = response.json()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"VectorEngine 获取任务状态异常: {str(e)}")
|
logger.error(f"VectorEngine 获取任务状态异常: {str(e)}")
|
||||||
return {"success": False, "message": str(e)}
|
return None
|
||||||
|
|
||||||
async def ve_generate_image(self, prompt: str, config: Dict = None) -> AsyncGenerator[Dict, None]:
|
async def ve_get_task_list(self, task_ids: List[str]) -> Optional[List[Dict[str, Any]]]:
|
||||||
|
"""
|
||||||
|
根据ID列表查询任务
|
||||||
|
|
||||||
|
API 文档: https://vectorengine.apifox.cn/api-349239133
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_ids: 任务ID列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
任务对象列表,失败返回 None
|
||||||
|
"""
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {settings.vectorengine_token}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||||
|
response = await client.post(
|
||||||
|
f"{settings.vectorengine_api_url}/mj/task/list-by-condition",
|
||||||
|
json={"ids": task_ids},
|
||||||
|
headers=headers
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
logger.error(f"VectorEngine 批量查询任务失败: {response.status_code} - {response.text}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
result = response.json()
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"VectorEngine 批量查询任务异常: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def ve_get_image_seed(self, task_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
获取任务图片的seed
|
||||||
|
|
||||||
|
API 文档: https://vectorengine.apifox.cn/api-349239134
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: 任务ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含 seed 信息的任务对象,失败返回 None
|
||||||
|
"""
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {settings.vectorengine_token}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||||
|
response = await client.get(
|
||||||
|
f"{settings.vectorengine_api_url}/mj/task/{task_id}/image-seed",
|
||||||
|
headers=headers
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
logger.error(f"VectorEngine 获取seed失败: {response.status_code} - {response.text}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
result = response.json()
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"VectorEngine 获取seed异常: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def ve_generate_image(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
config: Dict = None
|
||||||
|
) -> AsyncGenerator[Dict, None]:
|
||||||
"""
|
"""
|
||||||
使用 VectorEngine API 生成图像并以流式方式返回结果
|
使用 VectorEngine API 生成图像并以流式方式返回结果
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompt: 提示词
|
prompt: 提示词
|
||||||
config: 配置参数,可包含:
|
config: 配置参数,可包含:
|
||||||
- base64_array: base64 编码的图片数组
|
- base64_array: base64 编码的图片数组(垫图)
|
||||||
- notify_hook: 回调地址
|
- notify_hook: 回调地址
|
||||||
- state: 自定义状态
|
- state: 自定义状态
|
||||||
- bot_type: 机器人类型
|
- bot_type: 机器人类型
|
||||||
- split_image: 是否分割图片 (默认 True)
|
- split_image: 是否分割图片 (默认 True)
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
生成状态和结果
|
生成状态和结果:
|
||||||
|
- {"status": "submitted", "task_id": "..."} # 任务已提交
|
||||||
|
- {"status": "progress", "progress": 50, "task_id": "..."} # 进度更新
|
||||||
|
- {"status": "success", "image_urls": [...], ...} # 成功完成
|
||||||
|
- {"status": "error", "message": "..."} # 错误
|
||||||
"""
|
"""
|
||||||
if not config:
|
if not config:
|
||||||
config = {}
|
config = {}
|
||||||
@ -619,12 +333,11 @@ class MidjourneyService:
|
|||||||
if not settings.vectorengine_token:
|
if not settings.vectorengine_token:
|
||||||
yield {
|
yield {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "未配置 VectorEngine Token",
|
"message": "未配置 VectorEngine Token"
|
||||||
"success_count": 0
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|
||||||
# 提交任务
|
# 1. 提交任务
|
||||||
submit_result = await self.ve_submit_imagine(
|
submit_result = await self.ve_submit_imagine(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
base64_array=config.get("base64_array", []),
|
base64_array=config.get("base64_array", []),
|
||||||
@ -633,21 +346,12 @@ class MidjourneyService:
|
|||||||
bot_type=config.get("bot_type", "MID_JOURNEY")
|
bot_type=config.get("bot_type", "MID_JOURNEY")
|
||||||
)
|
)
|
||||||
|
|
||||||
if not submit_result.get("success", True) and "result" not in submit_result:
|
# 检查提交结果 - VectorEngine 返回 {"result": "任务ID"} 表示成功
|
||||||
yield {
|
|
||||||
"status": "error",
|
|
||||||
"message": submit_result.get("message", "提交任务失败"),
|
|
||||||
"success_count": 0
|
|
||||||
}
|
|
||||||
return
|
|
||||||
|
|
||||||
# 获取任务ID
|
|
||||||
task_id = submit_result.get("result")
|
task_id = submit_result.get("result")
|
||||||
if not task_id:
|
if not task_id:
|
||||||
yield {
|
yield {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "未获取到任务ID",
|
"message": submit_result.get("message", "提交任务失败,未获取到任务ID")
|
||||||
"success_count": 0
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -656,10 +360,10 @@ class MidjourneyService:
|
|||||||
"status": "submitted",
|
"status": "submitted",
|
||||||
"task_id": task_id
|
"task_id": task_id
|
||||||
}
|
}
|
||||||
|
|
||||||
# 轮询任务状态
|
# 2. 轮询任务状态
|
||||||
polling_count = 0
|
polling_count = 0
|
||||||
last_progress = 0
|
last_progress = -1
|
||||||
|
|
||||||
while polling_count < settings.vectorengine_max_polling_attempts:
|
while polling_count < settings.vectorengine_max_polling_attempts:
|
||||||
await asyncio.sleep(settings.vectorengine_polling_interval)
|
await asyncio.sleep(settings.vectorengine_polling_interval)
|
||||||
@ -668,26 +372,27 @@ class MidjourneyService:
|
|||||||
task_status = await self.ve_get_task_status(task_id)
|
task_status = await self.ve_get_task_status(task_id)
|
||||||
|
|
||||||
if not task_status:
|
if not task_status:
|
||||||
|
logger.warning(f"第 {polling_count} 次轮询未获取到任务状态")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
status = task_status.get("status")
|
status = task_status.get("status", "")
|
||||||
progress = task_status.get("progress", 0)
|
progress_str = task_status.get("progress", "0%")
|
||||||
|
|
||||||
# 解析进度值(可能是数字或带%的字符串如 '0%')
|
# 解析进度值(可能是数字或带%的字符串如 '0%', '100%')
|
||||||
try:
|
try:
|
||||||
if isinstance(progress, str):
|
if isinstance(progress_str, str):
|
||||||
progress_val = int(progress.strip().rstrip('%'))
|
progress_val = int(progress_str.strip().rstrip('%'))
|
||||||
else:
|
else:
|
||||||
progress_val = int(progress)
|
progress_val = int(progress_str)
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
progress_val = 0
|
progress_val = 0
|
||||||
|
|
||||||
# 发送进度更新
|
# 发送进度更新(仅当进度有变化时)
|
||||||
if progress_val > last_progress:
|
if progress_val > last_progress:
|
||||||
last_progress = progress_val
|
last_progress = progress_val
|
||||||
yield {
|
yield {
|
||||||
"status": "progress",
|
"status": "progress",
|
||||||
"progress": last_progress,
|
"progress": progress_val,
|
||||||
"task_id": task_id
|
"task_id": task_id
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -696,37 +401,33 @@ class MidjourneyService:
|
|||||||
image_url = task_status.get("imageUrl")
|
image_url = task_status.get("imageUrl")
|
||||||
image_urls = [image_url] if image_url else []
|
image_urls = [image_url] if image_url else []
|
||||||
|
|
||||||
# 检查是否有多个图片URL(对于网格图)
|
|
||||||
if "buttons" in task_status:
|
|
||||||
# 可能包含 U1, U2, U3, U4 按钮对应的图片
|
|
||||||
pass
|
|
||||||
|
|
||||||
response = {
|
response = {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"success_count": 1,
|
|
||||||
"task_id": task_id,
|
"task_id": task_id,
|
||||||
"image_urls": image_urls
|
"image_urls": image_urls,
|
||||||
|
"buttons": task_status.get("buttons", []),
|
||||||
|
"properties": task_status.get("properties", {})
|
||||||
}
|
}
|
||||||
|
|
||||||
# 如果配置了分割图片,则处理分割
|
# 如果配置了分割图片且有图片URL,则处理分割
|
||||||
if config.get("split_image", True) and image_url:
|
if config.get("split_image", True) and image_url:
|
||||||
try:
|
try:
|
||||||
split_urls = await self.split_image(image_url)
|
split_urls = await self.split_image(image_url)
|
||||||
if split_urls:
|
if split_urls and len(split_urls) == 4:
|
||||||
response["image_urls"] = split_urls
|
response["image_urls"] = split_urls
|
||||||
logger.info(f"图片分割成功,生成 {len(split_urls)} 张子图")
|
logger.info(f"图片分割成功,生成 4 张子图")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"分割图像失败: {str(e)}")
|
logger.error(f"分割图像失败: {str(e)}")
|
||||||
|
|
||||||
yield response
|
yield response
|
||||||
return
|
return
|
||||||
|
|
||||||
elif status == "FAILURE" or status == "FAILED":
|
elif status in ("FAILURE", "FAILED"):
|
||||||
|
fail_reason = task_status.get("failReason", "任务失败")
|
||||||
yield {
|
yield {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": task_status.get("failReason", "任务失败"),
|
"message": fail_reason,
|
||||||
"task_id": task_id,
|
"task_id": task_id
|
||||||
"success_count": 0
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -734,15 +435,16 @@ class MidjourneyService:
|
|||||||
yield {
|
yield {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "任务已取消",
|
"message": "任务已取消",
|
||||||
"task_id": task_id,
|
"task_id": task_id
|
||||||
"success_count": 0
|
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# 其他状态(如 IN_PROGRESS, SUBMITTED)继续轮询
|
||||||
|
logger.debug(f"任务 {task_id} 状态: {status}, 进度: {progress_str}")
|
||||||
|
|
||||||
# 超时
|
# 超时
|
||||||
yield {
|
yield {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "任务超时",
|
"message": f"任务超时,已轮询 {settings.vectorengine_max_polling_attempts} 次",
|
||||||
"task_id": task_id,
|
"task_id": task_id
|
||||||
"success_count": 0
|
}
|
||||||
}
|
|
||||||
|
|||||||
@ -1,7 +1,8 @@
|
|||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
from typing import Optional, Dict
|
from typing import Optional
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
# Japi Server 配置
|
# Japi Server 配置
|
||||||
host: str = "0.0.0.0"
|
host: str = "0.0.0.0"
|
||||||
@ -10,7 +11,7 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
# API路由配置
|
# API路由配置
|
||||||
router_prefix: str = "/jmidjourney"
|
router_prefix: str = "/jmidjourney"
|
||||||
generate_route: str = "/generate" # 生成图片的路由
|
generate_route: str = "/generate" # 保留用于兼容
|
||||||
ve_generate_route: str = "/ve/generate" # VectorEngine 生成图片的路由
|
ve_generate_route: str = "/ve/generate" # VectorEngine 生成图片的路由
|
||||||
api_name: str = "jmidjourney" # 默认API名称
|
api_name: str = "jmidjourney" # 默认API名称
|
||||||
|
|
||||||
@ -32,41 +33,21 @@ class Settings(BaseSettings):
|
|||||||
jingrow_api_key: Optional[str] = None
|
jingrow_api_key: Optional[str] = None
|
||||||
jingrow_api_secret: Optional[str] = None
|
jingrow_api_secret: Optional[str] = None
|
||||||
|
|
||||||
# Discord Midjourney配置
|
# 代理配置(用于下载外部图片)
|
||||||
midjourney_api_url: str = "https://discord.com/api/v9"
|
http_proxy: Optional[str] = None
|
||||||
midjourney_application_id: str = "936929561302675456"
|
https_proxy: Optional[str] = None
|
||||||
midjourney_data_id: str = "938956540159881230"
|
|
||||||
midjourney_data_version: str = "1237876415471554623"
|
|
||||||
midjourney_session_id: str = "a64ede0f3ce497d949e2f6f195c19029"
|
|
||||||
midjourney_channel_id: str = "1259838588510670941"
|
|
||||||
midjourney_oauth_token: str = "MTA4NzQ0MDY0MTU5MzcxNjc0Ng.GVDauj.6Cwr5EpXOfN9FpQU0-VfteR56XQOwLLUGYovG0"
|
|
||||||
midjourney_suffix: str = "mj" # 图片文件名的后缀
|
|
||||||
|
|
||||||
# 代理配置
|
|
||||||
http_proxy: Optional[str] = "http://127.0.0.1:1080" # 默认HTTP代理
|
|
||||||
https_proxy: Optional[str] = "http://127.0.0.1:1080" # 默认HTTPS代理
|
|
||||||
|
|
||||||
# Midjourney默认选项
|
|
||||||
midjourney_default_options: Dict = {
|
|
||||||
"ar": "1:1",
|
|
||||||
"v": "6.1",
|
|
||||||
"quality": "1"
|
|
||||||
}
|
|
||||||
|
|
||||||
# 图像设置
|
|
||||||
add_title_to_image_name: bool = False # 是否将标题添加到图片名称中
|
|
||||||
|
|
||||||
# 超时设置(秒)
|
# 超时设置(秒)
|
||||||
request_timeout: int = 30
|
request_timeout: int = 30
|
||||||
max_polling_attempts: int = 60
|
|
||||||
polling_interval: int = 3
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
env_file = ".env"
|
env_file = ".env"
|
||||||
|
|
||||||
|
|
||||||
@lru_cache()
|
@lru_cache()
|
||||||
def get_settings() -> Settings:
|
def get_settings() -> Settings:
|
||||||
return Settings()
|
return Settings()
|
||||||
|
|
||||||
|
|
||||||
# 创建全局配置实例
|
# 创建全局配置实例
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user