更新midjourney适合生产环境

This commit is contained in:
jingrow 2025-05-20 17:01:34 +08:00
parent 0e5b27e422
commit b523b6975b
2 changed files with 25 additions and 146 deletions

View File

@ -5,19 +5,15 @@ import io
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
import time
import random
import re
import uuid
import urllib.request
import urllib3
import traceback
import logging
import base64
from pathlib import Path
from urllib.parse import urlparse
from PIL import Image
import mimetypes
import asyncio
from typing import Dict, Any, List, AsyncGenerator, Optional
from settings import settings
@ -63,7 +59,6 @@ class MidjourneyService:
"""初始化Discord客户端会话"""
client = requests.Session()
# 添加重试机制
retry = Retry(total=3, backoff_factor=0.5)
adapter = HTTPAdapter(max_retries=retry)
client.mount('http://', adapter)
@ -73,17 +68,14 @@ class MidjourneyService:
'Authorization': oauth_token
})
# 设置代理
if self.proxies:
client.proxies.update(self.proxies)
try:
# 获取频道信息
response = client.get(f'{self.API_URL}/channels/{channel_id}')
data = response.json()
guild_id = data['guild_id']
# 获取用户信息
response = client.get(f'{self.API_URL}/users/@me')
data = response.json()
user_id = data['id']
@ -91,7 +83,6 @@ class MidjourneyService:
return client, guild_id, user_id
except Exception as e:
logger.error(f"初始化Discord客户端失败: {str(e)}")
traceback.print_exc()
raise Exception(f"初始化Discord客户端失败: {str(e)}")
async def imagine(self, client, guild_id, channel_id, prompt, seed=None):
@ -134,28 +125,20 @@ class MidjourneyService:
}
try:
# 发送请求
r = client.post(f'{self.API_URL}/interactions', json=params)
# 初始等待时间从5秒延长到30秒给Discord足够的时间开始处理请求
print(f"[生成] 已发送请求等待20秒后开始轮询...")
client.post(f'{self.API_URL}/interactions', json=params)
await asyncio.sleep(20)
# 轮询获取结果
imagine_message = None
count = 0
last_progress = 0
# 轮询直到获取完整的结果或达到最大次数
while count < settings.max_polling_attempts:
# 轮询两种消息:
# 1. 进度消息 - 用于更新进度
progress_msg = await self.get_progress_message(client, channel_id, prompt, seed)
if progress_msg:
content = progress_msg.get('content', '')
progress_match = re.search(r'(?:\(|)?(\d+)%(?:\)|)?', content)
if progress_match:
progress_value = int(progress_match.group(1))
# 只有进度有变化时才发送更新
if progress_value > last_progress:
last_progress = progress_value
logger.info(f"生成进度: {progress_value}%")
@ -166,30 +149,23 @@ class MidjourneyService:
"content": content
}
# 2. 完成的消息 - 包含图像结果
imagine_message = await self.get_imagine(client, channel_id, prompt, count, seed)
if imagine_message:
# 找到最终结果
break
# 没有找到结果,继续等待
logger.info(f"轮询尝试 {count+1}/{settings.max_polling_attempts}: 继续等待")
await asyncio.sleep(settings.polling_interval)
count += 1
# 检查是否超过最大轮询次数
if count >= settings.max_polling_attempts:
logger.error(f"轮询超过最大尝试次数: {settings.max_polling_attempts}")
yield {"status": "error", "message": "获取结果超时,超过最大轮询次数"}
return
# 检查是否有有效的最终结果
if not imagine_message:
logger.error("轮询结束但没有获取到有效结果")
yield {"status": "error", "message": "没有获取到有效结果"}
return
# 返回最终结果
logger.info(f"成功获取图像结果消息ID: {imagine_message.get('id')}")
yield {
"status": "success",
@ -200,68 +176,38 @@ class MidjourneyService:
except Exception as e:
logger.error(f"发送Imagine请求失败: {str(e)}")
traceback.print_exc()
yield {"status": "error", "message": f"发送Imagine请求失败: {str(e)}"}
async def get_imagine(self, client, channel_id, prompt, count=0, seed=None):
"""获取生成图像的消息"""
try:
# 获取最近的消息
response = client.get(f'{self.API_URL}/channels/{channel_id}/messages?limit=10')
data = response.json()
def criteria(item):
content = item.get('content', '')
# 检查进度信息并更新状态
if seed is not None and f"--seed {seed}" in content:
# 匹配百分比,支持多种格式:(93%) 或 93%
progress_match = re.search(r'(?:\(|)?(\d+)%(?:\)|)?', content)
if progress_match:
progress_value = int(progress_match.group(1))
# 记录进度信息
logger.info(f"任务进度: {progress_value}%")
# 如果消息包含百分比说明任务还在进行中返回False继续轮询
if "%" in content:
print(f"[轮询] 消息ID: {item.get('id')} 内容: {content[:100]}... 包含进度信息{progress_value}%,继续等待完成")
return False
# 排除进行中消息,只匹配完成的消息
if "%" in content:
return False
# seed 匹配
if seed is not None:
seed_pattern = f"--seed {seed}"
# 检查是否包含指定的seed
if seed_pattern not in content:
return False
else:
print(f"[轮询] 消息ID: {item.get('id')} 内容: {content[:100]}... 匹配seed并且任务已完成")
# 检查是否包含图像附件
if 'attachments' in item and len(item.get('attachments', [])) > 0:
return True
return False # 默认不匹配
return False
# 查找匹配完成状态的消息
raw_message = self.first_where(data, criteria)
if raw_message is None:
print("[轮询] 没有找到完成的消息,继续等待")
return None
# 打印匹配到的完整消息内容
print(f"[轮询] 匹配到完成的消息ID: {raw_message.get('id')}")
try:
print(f"[轮询] 完整消息内容: {json.dumps(raw_message, indent=2, ensure_ascii=False)}")
except Exception as e:
print(f"[轮询] 无法打印完整消息内容: {str(e)}")
# 尝试打印关键字段
print(f"[轮询] 消息内容: {raw_message.get('content', '无内容')}")
print(f"[轮询] 附件数量: {len(raw_message.get('attachments', []))}")
for i, attachment in enumerate(raw_message.get('attachments', [])):
print(f"[轮询] 附件 {i+1} URL: {attachment.get('url', '无URL')}")
return raw_message
return self.first_where(data, criteria)
except Exception as e:
logger.error(f"获取通道消息失败: {str(e)}")
@ -270,13 +216,11 @@ class MidjourneyService:
async def get_progress_message(self, client, channel_id, prompt, seed=None):
"""获取包含进度信息的消息"""
try:
# 获取最近的消息
response = client.get(f'{self.API_URL}/channels/{channel_id}/messages?limit=10')
data = response.json()
for item in data:
content = item.get('content', '')
# 检查是否包含seed和进度信息
if seed is not None and f"--seed {seed}" in content and "%" in content:
progress_match = re.search(r'(?:\(|)?(\d+)%(?:\)|)?', content)
if progress_match:
@ -298,69 +242,48 @@ class MidjourneyService:
return image_urls
async def split_image(self, image_url):
"""将一张大图切割成四张子图保存到本地并返回URLs
Args:
image_url: 原始图像的URL
Returns:
包含四个子图像URL的列表如果处理失败则返回None
"""
"""将一张大图切割成四张子图保存到本地并返回URLs"""
try:
# 下载图像
response = requests.get(image_url, proxies=self.proxies if self.proxies else None, timeout=30)
if response.status_code != 200:
logger.error(f"下载图像失败,状态码: {response.status_code}")
return None
image_data = response.content
# 从二进制数据创建图像对象
img = Image.open(io.BytesIO(image_data))
width, height = img.size
# 获取原始图片格式
original_format = img.format
if not original_format:
# 如果无法获取格式从URL中提取
parsed_url = urlparse(image_url)
original_format = os.path.splitext(parsed_url.path)[1][1:].upper()
if not original_format:
original_format = 'PNG' # 默认使用PNG
original_format = 'PNG'
# 确认图像尺寸约为2048x2048
if width < 1500 or height < 1500:
logger.error(f"图像尺寸不符合预期: {width}x{height}, 应该接近2048x2048")
return None
# 计算每个象限的尺寸
half_width = width // 2
half_height = height // 2
# 分割图像为四个象限
quadrants = [
img.crop((0, 0, half_width, half_height)), # 左上
img.crop((half_width, 0, width, half_height)), # 右上
img.crop((0, half_height, half_width, height)), # 左下
img.crop((half_width, half_height, width, height)) # 右下
img.crop((0, 0, half_width, half_height)),
img.crop((half_width, 0, width, half_height)),
img.crop((0, half_height, half_width, height)),
img.crop((half_width, half_height, width, height))
]
# 生成唯一的图片名称前缀
image_id = uuid.uuid4().hex[:10]
# 确保保存目录存在(使用绝对路径)
save_dir = os.path.abspath(settings.save_dir)
os.makedirs(save_dir, exist_ok=True)
# 保存图片到本地并生成URLs
image_urls = []
for i, quadrant in enumerate(quadrants, 1):
try:
# 使用原始格式保存
filename = f"split_{image_id}_{i}.{original_format.lower()}"
file_path = os.path.join(save_dir, filename)
# 保存图片,保持原始格式
save_params = {"format": original_format}
if original_format in ['PNG', 'JPEG', 'JPG']:
save_params["optimize"] = True
@ -369,23 +292,19 @@ class MidjourneyService:
quadrant.save(file_path, **save_params)
# 验证文件是否成功保存
if not os.path.exists(file_path):
raise Exception(f"文件保存失败: {file_path}")
# 验证文件大小
file_size = os.path.getsize(file_path)
if file_size == 0:
raise Exception(f"保存的文件大小为0: {file_path}")
# 构建图片URL
image_url = f"{settings.download_url}/{filename}"
image_urls.append(image_url)
logger.info(f"成功保存分割图片 {i}/4: {filename} (大小: {file_size} 字节, 格式: {original_format})")
logger.info(f"成功保存分割图片 {i}/4: {filename}")
except Exception as e:
logger.error(f"保存分割图片 {i}/4 失败: {str(e)}")
# 如果保存失败,删除已保存的图片
for url in image_urls:
try:
file_path = os.path.join(save_dir, os.path.basename(url))
@ -399,12 +318,11 @@ class MidjourneyService:
logger.error(f"分割图片数量不正确: 期望4张实际{len(image_urls)}")
return None
logger.info(f"成功完成图片分割生成4张子图")
logger.info("成功完成图片分割生成4张子图")
return image_urls
except Exception as e:
logger.error(f"分割图像失败: {str(e)}")
traceback.print_exc()
return None
async def split_images(self, image_urls: List[str], config: Optional[Dict] = None):
@ -418,95 +336,76 @@ class MidjourneyService:
for i, image_url in enumerate(image_urls, 1):
try:
if not is_valid_image_url(image_url):
response = {
yield {
"status": "error",
"index": i,
"total": total,
"success_count": success_count,
"message": "无效的图片URL"
}
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
yield response
continue
split_urls = await self.split_image(image_url)
if split_urls and len(split_urls) == 4:
success_count += 1
response = {
yield {
"status": "success",
"index": i,
"total": total,
"success_count": success_count,
"images": split_urls
}
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
yield response
else:
response = {
yield {
"status": "error",
"index": i,
"total": total,
"success_count": success_count,
"message": "分割图片失败"
}
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
yield response
except Exception as e:
response = {
yield {
"status": "error",
"index": i,
"total": total,
"success_count": success_count,
"message": str(e)
}
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
yield response
async def generate_image(self, prompt, config=None):
"""生成图像并以流式方式返回结果"""
if not config:
config = {}
# 获取必要的认证信息
oauth_token = config.get('oauth_token', settings.midjourney_oauth_token)
channel_id = config.get('channel_id', settings.midjourney_channel_id)
if not oauth_token or not channel_id:
response = {
yield {
"status": "error",
"message": "缺少Discord配置",
"success_count": 0
}
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
yield response
return
# 初始化客户端
try:
client, guild_id, user_id = await self.initialize_client(oauth_token, channel_id)
except Exception as e:
response = {
yield {
"status": "error",
"message": str(e),
"success_count": 0
}
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
yield response
return
# 解析和准备提示词
prompt = prompt.strip()
# 获取或设置默认选项
options = settings.midjourney_default_options.copy()
if 'options' in config:
options.update(config.get('options', {}))
# 强制设置seed如果没有
if "seed" not in options or not options.get('seed'):
options['seed'] = random.randint(0, 4294967295)
# 处理选项,构建参数字符串
seed = options.pop('seed', None)
parameter = ""
no_value_key = ['relax', 'fast', 'turbo', 'tile']
@ -516,20 +415,15 @@ class MidjourneyService:
else:
parameter += f" --{key} {value}"
# 添加seed
if seed:
parameter += f" --seed {seed}"
# 打印使用的seed值
print(f"[生成] 使用的seed值: {seed}")
logger.info(f"使用的seed值: {seed}")
# 处理参考图像
if 'image_urls' in config and config['image_urls']:
# 确保是列表格式
image_urls = config['image_urls']
if not isinstance(image_urls, list):
image_urls = [image_urls]
# 转换图片URL
new_image_urls = []
for image_url in image_urls:
if is_valid_image_url(image_url):
@ -540,16 +434,13 @@ class MidjourneyService:
except Exception as e:
logger.warning(f"转换图片URL失败: {str(e)}")
# 添加到prompt前面
if new_image_urls:
prompt = " ".join(new_image_urls) + " " + prompt
# 设置图像权重
if 'image_weight' in config and isinstance(config['image_weight'], (int, float)):
iw = max(0.1, min(config['image_weight'], 3)) # 限制在0.1到3之间
iw = max(0.1, min(config['image_weight'], 3))
parameter += f" --iw {iw}"
# 添加字符引用
if 'characters' in config and config['characters']:
char_urls = []
characters = config['characters'] if isinstance(config['characters'], list) else [config['characters']]
@ -566,7 +457,6 @@ class MidjourneyService:
if char_urls:
prompt = prompt + " --cref " + " ".join(char_urls)
# 添加风格引用
if 'styles' in config and config['styles']:
style_urls = []
styles = config['styles'] if isinstance(config['styles'], list) else [config['styles']]
@ -583,31 +473,24 @@ class MidjourneyService:
if style_urls:
prompt = prompt + " --sref " + " ".join(style_urls)
# 添加参数
prompt = prompt + parameter
# 流式返回结果
success_count = 0
async for result in self.imagine(client, guild_id, channel_id, prompt, seed):
if result.get("status") == "progress":
# 进度信息只包含必要字段
response = {
yield {
"status": "progress",
"progress": result.get("progress", 0),
"seed": seed # 添加seed字段
"seed": seed
}
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
yield response
elif result.get("status") == "success":
success_count += 1
response = {
"status": "success",
"success_count": success_count,
"image_urls": result.get("images", []) # 添加image_urls字段
"image_urls": result.get("images", [])
}
# 如果需要分割图片
if config.get("split_image", True) and result.get("images"):
try:
orig_image_url = result["images"][0]
@ -619,14 +502,10 @@ class MidjourneyService:
response["status"] = "error"
response["message"] = f"分割图像失败: {str(e)}"
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
yield response
else:
# 错误信息保持统一格式
response = {
yield {
"status": "error",
"message": result.get("message", "未知错误"),
"success_count": success_count
}
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
yield response
}

View File

@ -6,7 +6,7 @@ class Settings(BaseSettings):
# Japi Server 配置
host: str = "0.0.0.0"
port: int = 8113
debug: bool = True
debug: bool = False
# API路由配置
router_prefix: str = "/midjourney"