# japi/apps/midjourney/service.py
# 2025-05-20 04:00:49 +08:00
#
# 637 lines
# 26 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import sys
import os
import io
import requests
import time
import random
import re
import uuid
import urllib.request
import urllib3
import traceback
import logging
import base64
from pathlib import Path
from urllib.parse import urlparse
from PIL import Image
import mimetypes
import asyncio
from typing import Dict, Any, List, AsyncGenerator, Optional
from settings import settings
from utils import get_new_image_url, is_valid_image_url
# Module-level logger shared by everything in this file.
logger = logging.getLogger("midjourney_service")
# Silence urllib3's InsecureRequestWarning for the whole process.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
class MidjourneyService:
    """Drive Midjourney through Discord's HTTP API and post-process the results.

    The service sends the ``/imagine`` slash-command interaction to a Discord
    channel, polls that channel's recent messages for progress and for the
    finished image, and can cut a finished image into four quadrants that are
    stored under ``settings.save_dir`` and exposed via ``settings.download_url``.

    All HTTP traffic uses the blocking ``requests`` library; it is executed in
    the event loop's default executor so the coroutines do not stall the loop.
    """

    # Seconds any single HTTP request may take before it is abandoned.
    REQUEST_TIMEOUT = 30
    # Seconds to wait after submitting the interaction before the first poll.
    INITIAL_POLL_DELAY = 30
    # Option keys that are emitted as bare ``--flag`` switches without a value.
    FLAG_OPTIONS = ('relax', 'fast', 'turbo', 'tile')
    # Upper bound (inclusive) used when a random seed has to be drawn.
    MAX_SEED = 4294967295
    # Finds a percentage such as "93%" or "(93%)"; group 1 holds the digits.
    _PROGRESS_RE = re.compile(r'(?:\(|)?(\d+)%(?:\)|)?')

    def __init__(self):
        """Copy the Midjourney/Discord settings and make sure the save directory exists."""
        self.API_URL = settings.midjourney_api_url
        self.APPLICATION_ID = settings.midjourney_application_id
        self.DATA_ID = settings.midjourney_data_id
        self.DATA_VERSION = settings.midjourney_data_version
        self.SESSION_ID = settings.midjourney_session_id

        # Only non-blank proxy settings are kept; surrounding whitespace is dropped.
        self.proxies = {}
        for scheme, proxy in (('http', settings.http_proxy), ('https', settings.https_proxy)):
            if proxy and proxy.strip():
                self.proxies[scheme] = proxy.strip()

        os.makedirs(settings.save_dir, exist_ok=True)

    # ------------------------------------------------------------------
    # Small private helpers
    # ------------------------------------------------------------------
    @staticmethod
    async def _run_blocking(func, *args, **kwargs):
        """Run a blocking callable in the default executor and return its result."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, lambda: func(*args, **kwargs))

    @staticmethod
    def _has_seed(content, seed):
        """Return True when *content* contains exactly ``--seed <seed>``.

        A plain substring test would let seed ``12`` match ``--seed 123``; the
        negative look-ahead rejects a seed that is merely a prefix.
        """
        pattern = rf"--seed {re.escape(str(seed))}(?!\d)"
        return re.search(pattern, content) is not None

    @staticmethod
    def _announce(response):
        """Print *response* as JSON on stdout and hand it back unchanged."""
        print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
        return response

    @staticmethod
    def _convert_image_urls(urls, failure_label):
        """Convert every valid image URL in *urls* with ``get_new_image_url``.

        Args:
            urls: A single URL or a list of URLs.
            failure_label: Log-message prefix used when a conversion raises.

        Returns:
            The converted URLs; invalid inputs, failed conversions and empty
            conversion results are left out.
        """
        if not isinstance(urls, list):
            urls = [urls]
        converted = []
        for url in urls:
            if not is_valid_image_url(url):
                continue
            try:
                new_url = get_new_image_url(url)
            except Exception as e:
                logger.warning(f"{failure_label}: {str(e)}")
                continue
            if new_url:
                converted.append(new_url)
        return converted

    async def _fetch_recent_messages(self, client, channel_id, limit=10):
        """Return the channel's most recent messages as a list of dicts.

        Raises:
            ValueError: If the endpoint answers with something other than a
                JSON list (for example an error object).
        """
        response = await self._run_blocking(
            client.get,
            f'{self.API_URL}/channels/{channel_id}/messages?limit={limit}',
            timeout=self.REQUEST_TIMEOUT,
        )
        data = response.json()
        if not isinstance(data, list):
            raise ValueError(f"Discord返回了非预期的响应: {str(data)[:200]}")
        return data

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    @staticmethod
    def first_where(array, key, value=None):
        """Return the first item of *array* that matches, or None.

        Args:
            array: Iterable of candidate items.
            key: Either a predicate called with each item, or the name of a
                field whose string value must start with *value*.
            value: Required prefix when *key* is a field name.

        Items that lack the field, or whose field is not a string, are skipped
        instead of raising; a field-name lookup without *value* matches nothing.
        """
        if callable(key):
            predicate = key
        elif isinstance(key, str) and value is not None:
            def predicate(item):
                field = item.get(key) if isinstance(item, dict) else None
                return isinstance(field, str) and field.startswith(value)
        else:
            return None

        for item in array or ():
            if predicate(item):
                return item
        return None

    async def initialize_client(self, oauth_token, channel_id):
        """Create an authorised Discord session and look up guild and user ids.

        Args:
            oauth_token: Value sent verbatim in the ``Authorization`` header.
            channel_id: Channel whose ``guild_id`` is resolved.

        Returns:
            A ``(client, guild_id, user_id)`` tuple; the caller owns *client*
            and is responsible for closing it.

        Raises:
            Exception: If either lookup fails. The original error is chained
                and the session is closed before raising.
        """
        client = requests.Session()
        client.headers.update({'Authorization': oauth_token})
        if self.proxies:
            client.proxies.update(self.proxies)

        try:
            response = await self._run_blocking(
                client.get,
                f'{self.API_URL}/channels/{channel_id}',
                timeout=self.REQUEST_TIMEOUT,
            )
            # Surface HTTP errors directly instead of as a KeyError on the body.
            response.raise_for_status()
            guild_id = response.json()['guild_id']

            response = await self._run_blocking(
                client.get,
                f'{self.API_URL}/users/@me',
                timeout=self.REQUEST_TIMEOUT,
            )
            response.raise_for_status()
            user_id = response.json()['id']
            return client, guild_id, user_id
        except Exception as e:
            client.close()
            logger.exception(f"初始化Discord客户端失败: {str(e)}")
            raise Exception(f"初始化Discord客户端失败: {str(e)}") from e

    async def imagine(self, client, guild_id, channel_id, prompt, seed=None):
        """Submit an ``imagine`` interaction and stream its progress and result.

        This is an async generator. It yields dicts whose ``status`` is one of:

        * ``"progress"`` - with ``progress`` (int percent), ``message_id``, ``content``;
        * ``"success"`` - with ``message_id``, ``content`` and ``images`` (URLs);
        * ``"error"`` - with ``message``.

        Progress can only be tracked when *seed* is given, because messages are
        correlated through the ``--seed`` text in their content.
        """
        params = {
            'type': 2,
            'application_id': self.APPLICATION_ID,
            'guild_id': guild_id,
            'channel_id': channel_id,
            'session_id': self.SESSION_ID,
            'data': {
                'version': self.DATA_VERSION,
                'id': self.DATA_ID,
                'name': 'imagine',
                'type': 1,
                'options': [{
                    'type': 3,
                    'name': 'prompt',
                    'value': prompt
                }],
                'application_command': {
                    'id': self.DATA_ID,
                    'application_id': self.APPLICATION_ID,
                    'version': self.DATA_VERSION,
                    'default_member_permissions': None,
                    'type': 1,
                    'nsfw': False,
                    'name': 'imagine',
                    'description': 'Create images with Midjourney',
                    'dm_permission': True,
                    'options': [{
                        'type': 3,
                        'name': 'prompt',
                        'description': 'The prompt to imagine',
                        'required': True
                    }]
                },
                'attachments': []
            }
        }
        try:
            r = await self._run_blocking(
                client.post,
                f'{self.API_URL}/interactions',
                json=params,
                timeout=self.REQUEST_TIMEOUT,
            )
            # A rejected interaction will never produce a message, so fail fast
            # instead of polling until the attempt limit is reached.
            if r.status_code >= 400:
                logger.error(f"发送Imagine请求失败: HTTP {r.status_code} {r.text[:200]}")
                yield {"status": "error", "message": f"发送Imagine请求失败: HTTP {r.status_code}"}
                return

            print(f"[生成] 已发送请求等待{self.INITIAL_POLL_DELAY}秒后开始轮询...")
            await asyncio.sleep(self.INITIAL_POLL_DELAY)

            imagine_message = None
            count = 0
            last_progress = 0
            while count < settings.max_polling_attempts:
                # 1. Progress message - only used to report percentage changes.
                progress_msg = await self.get_progress_message(client, channel_id, prompt, seed)
                if progress_msg:
                    content = progress_msg.get('content', '')
                    progress_match = self._PROGRESS_RE.search(content)
                    if progress_match:
                        progress_value = int(progress_match.group(1))
                        # Report only when the percentage actually advanced.
                        if progress_value > last_progress:
                            last_progress = progress_value
                            logger.info(f"生成进度: {progress_value}%")
                            yield {
                                "status": "progress",
                                "progress": progress_value,
                                "message_id": progress_msg.get('id'),
                                "content": content
                            }

                # 2. Finished message - carries the image attachments.
                imagine_message = await self.get_imagine(client, channel_id, prompt, count, seed)
                if imagine_message:
                    break

                logger.info(f"轮询尝试 {count+1}/{settings.max_polling_attempts}: 继续等待")
                await asyncio.sleep(settings.polling_interval)
                count += 1

            # The loop ends without a message only when the attempts ran out.
            if not imagine_message:
                logger.error(f"轮询超过最大尝试次数: {settings.max_polling_attempts}")
                yield {"status": "error", "message": "获取结果超时,超过最大轮询次数"}
                return

            logger.info(f"成功获取图像结果消息ID: {imagine_message.get('id')}")
            yield {
                "status": "success",
                "message_id": imagine_message.get('id'),
                "content": imagine_message.get('content', ''),
                "images": self.extract_image_urls(imagine_message)
            }
        except Exception as e:
            logger.exception(f"发送Imagine请求失败: {str(e)}")
            yield {"status": "error", "message": f"发送Imagine请求失败: {str(e)}"}

    async def get_imagine(self, client, channel_id, prompt, count=0, seed=None):
        """Return the finished generation message from the channel, or None.

        A message counts as finished when its content has no ``%`` (progress
        marker), it has at least one attachment and, if *seed* is given, its
        content contains exactly that ``--seed`` value. *prompt* and *count*
        are accepted for interface compatibility and are not used for matching.

        Returns:
            The raw message dict, or None when nothing matches or the lookup fails.
        """
        try:
            data = await self._fetch_recent_messages(client, channel_id)

            def criteria(item):
                content = item.get('content', '')
                if seed is not None and not self._has_seed(content, seed):
                    return False

                # A percentage means the job is still running.
                if "%" in content:
                    if seed is not None:
                        progress_match = self._PROGRESS_RE.search(content)
                        if progress_match:
                            progress_value = int(progress_match.group(1))
                            logger.info(f"任务进度: {progress_value}%")
                            print(f"[轮询] 消息ID: {item.get('id')} 内容: {content[:100]}... 包含进度信息{progress_value}%,继续等待完成")
                    return False

                if seed is not None:
                    print(f"[轮询] 消息ID: {item.get('id')} 内容: {content[:100]}... 匹配seed并且任务已完成")

                # Only a message that actually carries attachments is a result.
                return bool(item.get('attachments'))

            raw_message = self.first_where(data, criteria)
            if raw_message is None:
                print("[轮询] 没有找到完成的消息,继续等待")
                return None

            print(f"[轮询] 匹配到完成的消息ID: {raw_message.get('id')}")
            try:
                print(f"[轮询] 完整消息内容: {json.dumps(raw_message, indent=2, ensure_ascii=False)}")
            except Exception as e:
                # Fall back to the key fields when the message is not serialisable.
                print(f"[轮询] 无法打印完整消息内容: {str(e)}")
                print(f"[轮询] 消息内容: {raw_message.get('content', '无内容')}")
                print(f"[轮询] 附件数量: {len(raw_message.get('attachments', []))}")
                for i, attachment in enumerate(raw_message.get('attachments', [])):
                    print(f"[轮询] 附件 {i+1} URL: {attachment.get('url', '无URL')}")
            return raw_message
        except Exception as e:
            logger.error(f"获取通道消息失败: {str(e)}")
            return None

    async def get_progress_message(self, client, channel_id, prompt, seed=None):
        """Return the first recent message that reports progress for *seed*.

        A message matches when its content contains exactly ``--seed <seed>``
        and a percentage. Without a seed nothing can match, so no request is
        made. *prompt* is accepted for interface compatibility only.

        Returns:
            The raw message dict, or None when nothing matches or the lookup fails.
        """
        if seed is None:
            return None
        try:
            for item in await self._fetch_recent_messages(client, channel_id):
                content = item.get('content', '')
                if "%" not in content or not self._has_seed(content, seed):
                    continue
                if self._PROGRESS_RE.search(content):
                    return item
            return None
        except Exception as e:
            logger.error(f"获取进度消息失败: {str(e)}")
            return None

    def extract_image_urls(self, message):
        """Return the ``url`` of every attachment in *message* that has one."""
        attachments = message.get('attachments') or []
        return [attachment['url'] for attachment in attachments if 'url' in attachment]

    async def split_image(self, image_url):
        """Cut one image into four quadrants, save them locally and return their URLs.

        Args:
            image_url: URL of the source image.

        Returns:
            A list of four URLs built from ``settings.download_url`` in the order
            top-left, top-right, bottom-left, bottom-right, or None on any failure
            (download error, image smaller than 1500 px on a side, save error).
        """
        try:
            response = await self._run_blocking(
                requests.get,
                image_url,
                proxies=self.proxies if self.proxies else None,
                timeout=self.REQUEST_TIMEOUT,
            )
            if response.status_code != 200:
                logger.error(f"下载图像失败,状态码: {response.status_code}")
                return None

            img = Image.open(io.BytesIO(response.content))
            width, height = img.size

            # Prefer the format Pillow detected; otherwise fall back to the URL
            # extension and finally to PNG.
            original_format = img.format
            if not original_format:
                parsed_url = urlparse(image_url)
                original_format = os.path.splitext(parsed_url.path)[1][1:].upper()
                if not original_format:
                    original_format = 'PNG'
            # "JPG" is a file extension, not a Pillow format name.
            if original_format == 'JPG':
                original_format = 'JPEG'

            if width < 1500 or height < 1500:
                logger.error(f"图像尺寸不符合预期: {width}x{height}, 应该接近2048x2048")
                return None

            half_width = width // 2
            half_height = height // 2
            boxes = [
                (0, 0, half_width, half_height),           # top-left
                (half_width, 0, width, half_height),       # top-right
                (0, half_height, half_width, height),      # bottom-left
                (half_width, half_height, width, height),  # bottom-right
            ]
            quadrants = [img.crop(box) for box in boxes]

            image_id = uuid.uuid4().hex[:10]
            save_dir = os.path.abspath(settings.save_dir)
            os.makedirs(save_dir, exist_ok=True)

            save_params = {"format": original_format}
            if original_format in ['PNG', 'JPEG']:
                save_params["optimize"] = True
            if original_format == 'JPEG':
                save_params["quality"] = 95

            saved_paths = []
            image_urls = []
            for i, quadrant in enumerate(quadrants, 1):
                filename = f"split_{image_id}_{i}.{original_format.lower()}"
                file_path = os.path.join(save_dir, filename)
                try:
                    quadrant.save(file_path, **save_params)

                    # Verify the file really landed on disk and is not empty.
                    if not os.path.exists(file_path):
                        raise Exception(f"文件保存失败: {file_path}")
                    file_size = os.path.getsize(file_path)
                    if file_size == 0:
                        raise Exception(f"保存的文件大小为0: {file_path}")

                    saved_paths.append(file_path)
                    image_urls.append(f"{settings.download_url}/{filename}")
                    logger.info(f"成功保存分割图片 {i}/4: {filename} (大小: {file_size} 字节, 格式: {original_format})")
                except Exception as e:
                    logger.error(f"保存分割图片 {i}/4 失败: {str(e)}")
                    # All or nothing: remove earlier quadrants and any partial
                    # file left behind by the failed save.
                    for stale_path in saved_paths + [file_path]:
                        try:
                            if os.path.exists(stale_path):
                                os.remove(stale_path)
                        except Exception as del_e:
                            logger.error(f"删除失败的图片文件时出错: {str(del_e)}")
                    return None

            logger.info("成功完成图片分割生成4张子图")
            return image_urls
        except Exception as e:
            logger.exception(f"分割图像失败: {str(e)}")
            return None

    async def process_batch(self, image_urls: List[str], config: Optional[Dict] = None):
        """Split every image in *image_urls* and yield one status dict per image.

        This is an async generator. Each yielded dict has ``status``, ``index``
        (1-based), ``total`` and the running ``success_count`` / ``error_count``,
        plus ``images`` on success or ``message`` on failure. Every dict is also
        printed as JSON. *config* is accepted but currently unused.
        """
        total = len(image_urls)
        success_count = 0
        error_count = 0

        for i, image_url in enumerate(image_urls, 1):
            split_urls = None
            try:
                if not is_valid_image_url(image_url):
                    failure = "无效的图片URL"
                else:
                    failure = "分割图片失败"
                    split_urls = await self.split_image(image_url)
            except Exception as e:
                split_urls = None
                failure = str(e)

            succeeded = bool(split_urls) and len(split_urls) == 4
            if succeeded:
                success_count += 1
            else:
                error_count += 1

            response = {
                "status": "success" if succeeded else "error",
                "index": i,
                "total": total,
                "success_count": success_count,
                "error_count": error_count,
            }
            if succeeded:
                response["images"] = split_urls
            else:
                response["message"] = failure
            yield self._announce(response)

    async def generate_image(self, prompt, config=None):
        """Generate an image for *prompt* and stream status dicts to the caller.

        This is an async generator. Recognised *config* keys:

        * ``oauth_token`` / ``channel_id`` - override the values in ``settings``;
        * ``options`` - dict merged over ``settings.midjourney_default_options``;
          ``seed`` is taken from here, otherwise a random one is drawn;
        * ``reference_images`` - image prompt URL(s) placed before the prompt;
        * ``image_weight`` - number clamped to 0.1..3 and sent as ``--iw``
          (only when at least one reference image is usable);
        * ``characters`` / ``styles`` - URL(s) sent as ``--cref`` / ``--sref``;
        * ``split_image`` - when truthy, the result is cut into four images.

        Yields dicts with ``status`` (``progress`` / ``success`` / ``error``) and
        running ``success_count`` / ``error_count``. A success dict carries
        ``images``: the four split URLs when splitting was requested and worked,
        otherwise the URLs reported by Discord.
        """
        if not config:
            config = {}

        oauth_token = config.get('oauth_token', settings.midjourney_oauth_token)
        channel_id = config.get('channel_id', settings.midjourney_channel_id)
        if not oauth_token or not channel_id:
            yield self._announce({
                "status": "error",
                "message": "缺少Discord配置",
                "success_count": 0,
                "error_count": 1
            })
            return

        try:
            client, guild_id, _user_id = await self.initialize_client(oauth_token, channel_id)
        except Exception as e:
            yield self._announce({
                "status": "error",
                "message": str(e),
                "success_count": 0,
                "error_count": 1
            })
            return

        try:
            prompt = prompt.strip()

            options = settings.midjourney_default_options.copy()
            options.update(config.get('options') or {})

            # The seed is what ties Discord messages back to this request, so
            # one is always used. An explicit 0 is a legitimate seed and is kept.
            seed = options.pop('seed', None)
            if seed is None or seed == '' or isinstance(seed, bool):
                seed = random.randint(0, self.MAX_SEED)

            parts = []
            for key, value in options.items():
                if key in self.FLAG_OPTIONS:
                    parts.append(f" --{key}")
                elif value is not None:
                    # An unset option would otherwise be sent as "--key None".
                    parts.append(f" --{key} {value}")
            parts.append(f" --seed {seed}")
            parameter = "".join(parts)
            print(f"[生成] 使用的seed值: {seed}")

            # Image prompts go in front of the text prompt.
            if config.get('reference_images'):
                image_urls = self._convert_image_urls(config['reference_images'], "转换图片URL失败")
                if image_urls:
                    prompt = " ".join(image_urls) + " " + prompt
                    weight = config.get('image_weight')
                    # bool is a subclass of int, so exclude it explicitly.
                    if isinstance(weight, (int, float)) and not isinstance(weight, bool):
                        iw = max(0.1, min(weight, 3))
                        parameter += f" --iw {iw}"

            if config.get('characters'):
                char_urls = self._convert_image_urls(config['characters'], "转换角色图片URL失败")
                if char_urls:
                    prompt = prompt + " --cref " + " ".join(char_urls)

            if config.get('styles'):
                style_urls = self._convert_image_urls(config['styles'], "转换风格图片URL失败")
                if style_urls:
                    prompt = prompt + " --sref " + " ".join(style_urls)

            prompt = prompt + parameter

            success_count = 0
            error_count = 0
            async for result in self.imagine(client, guild_id, channel_id, prompt, seed):
                status = result.get("status")
                if status == "progress":
                    yield self._announce({
                        "status": "progress",
                        "progress": result.get("progress", 0),
                        "success_count": success_count,
                        "error_count": error_count
                    })
                elif status == "success":
                    success_count += 1
                    images = result.get("images") or []
                    if config.get("split_image", False) and images:
                        try:
                            split_urls = await self.split_image(images[0])
                            if split_urls:
                                images = split_urls
                            else:
                                logger.warning("分割图片失败")
                        except Exception as e:
                            error_count += 1
                            logger.error(f"分割图像失败: {str(e)}")
                    response = {
                        "status": "success",
                        "success_count": success_count,
                        "error_count": error_count
                    }
                    if images:
                        response["images"] = images
                    yield self._announce(response)
                else:
                    error_count += 1
                    yield self._announce({
                        "status": "error",
                        "message": result.get("message", "未知错误"),
                        "success_count": success_count,
                        "error_count": error_count
                    })
        finally:
            # The session was created for this request only; always release it.
            client.close()