更新midjourney适合生产环境
This commit is contained in:
parent
0e5b27e422
commit
b523b6975b
@ -5,19 +5,15 @@ import io
|
|||||||
import requests
|
import requests
|
||||||
from requests.adapters import HTTPAdapter
|
from requests.adapters import HTTPAdapter
|
||||||
from urllib3.util.retry import Retry
|
from urllib3.util.retry import Retry
|
||||||
import time
|
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
import urllib.request
|
import urllib.request
|
||||||
import urllib3
|
import urllib3
|
||||||
import traceback
|
|
||||||
import logging
|
import logging
|
||||||
import base64
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import mimetypes
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Dict, Any, List, AsyncGenerator, Optional
|
from typing import Dict, Any, List, AsyncGenerator, Optional
|
||||||
from settings import settings
|
from settings import settings
|
||||||
@ -63,7 +59,6 @@ class MidjourneyService:
|
|||||||
"""初始化Discord客户端会话"""
|
"""初始化Discord客户端会话"""
|
||||||
client = requests.Session()
|
client = requests.Session()
|
||||||
|
|
||||||
# 添加重试机制
|
|
||||||
retry = Retry(total=3, backoff_factor=0.5)
|
retry = Retry(total=3, backoff_factor=0.5)
|
||||||
adapter = HTTPAdapter(max_retries=retry)
|
adapter = HTTPAdapter(max_retries=retry)
|
||||||
client.mount('http://', adapter)
|
client.mount('http://', adapter)
|
||||||
@ -73,17 +68,14 @@ class MidjourneyService:
|
|||||||
'Authorization': oauth_token
|
'Authorization': oauth_token
|
||||||
})
|
})
|
||||||
|
|
||||||
# 设置代理
|
|
||||||
if self.proxies:
|
if self.proxies:
|
||||||
client.proxies.update(self.proxies)
|
client.proxies.update(self.proxies)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 获取频道信息
|
|
||||||
response = client.get(f'{self.API_URL}/channels/{channel_id}')
|
response = client.get(f'{self.API_URL}/channels/{channel_id}')
|
||||||
data = response.json()
|
data = response.json()
|
||||||
guild_id = data['guild_id']
|
guild_id = data['guild_id']
|
||||||
|
|
||||||
# 获取用户信息
|
|
||||||
response = client.get(f'{self.API_URL}/users/@me')
|
response = client.get(f'{self.API_URL}/users/@me')
|
||||||
data = response.json()
|
data = response.json()
|
||||||
user_id = data['id']
|
user_id = data['id']
|
||||||
@ -91,7 +83,6 @@ class MidjourneyService:
|
|||||||
return client, guild_id, user_id
|
return client, guild_id, user_id
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"初始化Discord客户端失败: {str(e)}")
|
logger.error(f"初始化Discord客户端失败: {str(e)}")
|
||||||
traceback.print_exc()
|
|
||||||
raise Exception(f"初始化Discord客户端失败: {str(e)}")
|
raise Exception(f"初始化Discord客户端失败: {str(e)}")
|
||||||
|
|
||||||
async def imagine(self, client, guild_id, channel_id, prompt, seed=None):
|
async def imagine(self, client, guild_id, channel_id, prompt, seed=None):
|
||||||
@ -134,28 +125,20 @@ class MidjourneyService:
|
|||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 发送请求
|
client.post(f'{self.API_URL}/interactions', json=params)
|
||||||
r = client.post(f'{self.API_URL}/interactions', json=params)
|
|
||||||
# 初始等待时间从5秒延长到30秒,给Discord足够的时间开始处理请求
|
|
||||||
print(f"[生成] 已发送请求,等待20秒后开始轮询...")
|
|
||||||
await asyncio.sleep(20)
|
await asyncio.sleep(20)
|
||||||
|
|
||||||
# 轮询获取结果
|
|
||||||
imagine_message = None
|
imagine_message = None
|
||||||
count = 0
|
count = 0
|
||||||
last_progress = 0
|
last_progress = 0
|
||||||
|
|
||||||
# 轮询直到获取完整的结果或达到最大次数
|
|
||||||
while count < settings.max_polling_attempts:
|
while count < settings.max_polling_attempts:
|
||||||
# 轮询两种消息:
|
|
||||||
# 1. 进度消息 - 用于更新进度
|
|
||||||
progress_msg = await self.get_progress_message(client, channel_id, prompt, seed)
|
progress_msg = await self.get_progress_message(client, channel_id, prompt, seed)
|
||||||
if progress_msg:
|
if progress_msg:
|
||||||
content = progress_msg.get('content', '')
|
content = progress_msg.get('content', '')
|
||||||
progress_match = re.search(r'(?:\(|()?(\d+)%(?:\)|))?', content)
|
progress_match = re.search(r'(?:\(|()?(\d+)%(?:\)|))?', content)
|
||||||
if progress_match:
|
if progress_match:
|
||||||
progress_value = int(progress_match.group(1))
|
progress_value = int(progress_match.group(1))
|
||||||
# 只有进度有变化时才发送更新
|
|
||||||
if progress_value > last_progress:
|
if progress_value > last_progress:
|
||||||
last_progress = progress_value
|
last_progress = progress_value
|
||||||
logger.info(f"生成进度: {progress_value}%")
|
logger.info(f"生成进度: {progress_value}%")
|
||||||
@ -166,30 +149,23 @@ class MidjourneyService:
|
|||||||
"content": content
|
"content": content
|
||||||
}
|
}
|
||||||
|
|
||||||
# 2. 完成的消息 - 包含图像结果
|
|
||||||
imagine_message = await self.get_imagine(client, channel_id, prompt, count, seed)
|
imagine_message = await self.get_imagine(client, channel_id, prompt, count, seed)
|
||||||
if imagine_message:
|
if imagine_message:
|
||||||
# 找到最终结果
|
|
||||||
break
|
break
|
||||||
|
|
||||||
# 没有找到结果,继续等待
|
|
||||||
logger.info(f"轮询尝试 {count+1}/{settings.max_polling_attempts}: 继续等待")
|
|
||||||
await asyncio.sleep(settings.polling_interval)
|
await asyncio.sleep(settings.polling_interval)
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
# 检查是否超过最大轮询次数
|
|
||||||
if count >= settings.max_polling_attempts:
|
if count >= settings.max_polling_attempts:
|
||||||
logger.error(f"轮询超过最大尝试次数: {settings.max_polling_attempts}")
|
logger.error(f"轮询超过最大尝试次数: {settings.max_polling_attempts}")
|
||||||
yield {"status": "error", "message": "获取结果超时,超过最大轮询次数"}
|
yield {"status": "error", "message": "获取结果超时,超过最大轮询次数"}
|
||||||
return
|
return
|
||||||
|
|
||||||
# 检查是否有有效的最终结果
|
|
||||||
if not imagine_message:
|
if not imagine_message:
|
||||||
logger.error("轮询结束但没有获取到有效结果")
|
logger.error("轮询结束但没有获取到有效结果")
|
||||||
yield {"status": "error", "message": "没有获取到有效结果"}
|
yield {"status": "error", "message": "没有获取到有效结果"}
|
||||||
return
|
return
|
||||||
|
|
||||||
# 返回最终结果
|
|
||||||
logger.info(f"成功获取图像结果,消息ID: {imagine_message.get('id')}")
|
logger.info(f"成功获取图像结果,消息ID: {imagine_message.get('id')}")
|
||||||
yield {
|
yield {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
@ -200,68 +176,38 @@ class MidjourneyService:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"发送Imagine请求失败: {str(e)}")
|
logger.error(f"发送Imagine请求失败: {str(e)}")
|
||||||
traceback.print_exc()
|
|
||||||
yield {"status": "error", "message": f"发送Imagine请求失败: {str(e)}"}
|
yield {"status": "error", "message": f"发送Imagine请求失败: {str(e)}"}
|
||||||
|
|
||||||
async def get_imagine(self, client, channel_id, prompt, count=0, seed=None):
|
async def get_imagine(self, client, channel_id, prompt, count=0, seed=None):
|
||||||
"""获取生成图像的消息"""
|
"""获取生成图像的消息"""
|
||||||
try:
|
try:
|
||||||
# 获取最近的消息
|
|
||||||
response = client.get(f'{self.API_URL}/channels/{channel_id}/messages?limit=10')
|
response = client.get(f'{self.API_URL}/channels/{channel_id}/messages?limit=10')
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
||||||
def criteria(item):
|
def criteria(item):
|
||||||
content = item.get('content', '')
|
content = item.get('content', '')
|
||||||
|
|
||||||
# 检查进度信息并更新状态
|
|
||||||
if seed is not None and f"--seed {seed}" in content:
|
if seed is not None and f"--seed {seed}" in content:
|
||||||
# 匹配百分比,支持多种格式:(93%) 或 93%
|
|
||||||
progress_match = re.search(r'(?:\(|()?(\d+)%(?:\)|))?', content)
|
progress_match = re.search(r'(?:\(|()?(\d+)%(?:\)|))?', content)
|
||||||
if progress_match:
|
if progress_match:
|
||||||
progress_value = int(progress_match.group(1))
|
progress_value = int(progress_match.group(1))
|
||||||
# 记录进度信息
|
|
||||||
logger.info(f"任务进度: {progress_value}%")
|
logger.info(f"任务进度: {progress_value}%")
|
||||||
# 如果消息包含百分比,说明任务还在进行中,返回False继续轮询
|
|
||||||
if "%" in content:
|
if "%" in content:
|
||||||
print(f"[轮询] 消息ID: {item.get('id')} 内容: {content[:100]}... 包含进度信息{progress_value}%,继续等待完成")
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 排除进行中消息,只匹配完成的消息
|
|
||||||
if "%" in content:
|
if "%" in content:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# seed 匹配
|
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
seed_pattern = f"--seed {seed}"
|
seed_pattern = f"--seed {seed}"
|
||||||
# 检查是否包含指定的seed
|
|
||||||
if seed_pattern not in content:
|
if seed_pattern not in content:
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
print(f"[轮询] 消息ID: {item.get('id')} 内容: {content[:100]}... 匹配seed并且任务已完成")
|
|
||||||
# 检查是否包含图像附件
|
|
||||||
if 'attachments' in item and len(item.get('attachments', [])) > 0:
|
if 'attachments' in item and len(item.get('attachments', [])) > 0:
|
||||||
return True
|
return True
|
||||||
return False # 默认不匹配
|
return False
|
||||||
|
|
||||||
# 查找匹配完成状态的消息
|
return self.first_where(data, criteria)
|
||||||
raw_message = self.first_where(data, criteria)
|
|
||||||
if raw_message is None:
|
|
||||||
print("[轮询] 没有找到完成的消息,继续等待")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 打印匹配到的完整消息内容
|
|
||||||
print(f"[轮询] 匹配到完成的消息ID: {raw_message.get('id')}")
|
|
||||||
try:
|
|
||||||
print(f"[轮询] 完整消息内容: {json.dumps(raw_message, indent=2, ensure_ascii=False)}")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[轮询] 无法打印完整消息内容: {str(e)}")
|
|
||||||
# 尝试打印关键字段
|
|
||||||
print(f"[轮询] 消息内容: {raw_message.get('content', '无内容')}")
|
|
||||||
print(f"[轮询] 附件数量: {len(raw_message.get('attachments', []))}")
|
|
||||||
for i, attachment in enumerate(raw_message.get('attachments', [])):
|
|
||||||
print(f"[轮询] 附件 {i+1} URL: {attachment.get('url', '无URL')}")
|
|
||||||
|
|
||||||
return raw_message
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取通道消息失败: {str(e)}")
|
logger.error(f"获取通道消息失败: {str(e)}")
|
||||||
@ -270,13 +216,11 @@ class MidjourneyService:
|
|||||||
async def get_progress_message(self, client, channel_id, prompt, seed=None):
|
async def get_progress_message(self, client, channel_id, prompt, seed=None):
|
||||||
"""获取包含进度信息的消息"""
|
"""获取包含进度信息的消息"""
|
||||||
try:
|
try:
|
||||||
# 获取最近的消息
|
|
||||||
response = client.get(f'{self.API_URL}/channels/{channel_id}/messages?limit=10')
|
response = client.get(f'{self.API_URL}/channels/{channel_id}/messages?limit=10')
|
||||||
data = response.json()
|
data = response.json()
|
||||||
|
|
||||||
for item in data:
|
for item in data:
|
||||||
content = item.get('content', '')
|
content = item.get('content', '')
|
||||||
# 检查是否包含seed和进度信息
|
|
||||||
if seed is not None and f"--seed {seed}" in content and "%" in content:
|
if seed is not None and f"--seed {seed}" in content and "%" in content:
|
||||||
progress_match = re.search(r'(?:\(|()?(\d+)%(?:\)|))?', content)
|
progress_match = re.search(r'(?:\(|()?(\d+)%(?:\)|))?', content)
|
||||||
if progress_match:
|
if progress_match:
|
||||||
@ -298,69 +242,48 @@ class MidjourneyService:
|
|||||||
return image_urls
|
return image_urls
|
||||||
|
|
||||||
async def split_image(self, image_url):
|
async def split_image(self, image_url):
|
||||||
"""将一张大图切割成四张子图,保存到本地并返回URLs
|
"""将一张大图切割成四张子图,保存到本地并返回URLs"""
|
||||||
|
|
||||||
Args:
|
|
||||||
image_url: 原始图像的URL
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
包含四个子图像URL的列表,如果处理失败则返回None
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# 下载图像
|
|
||||||
response = requests.get(image_url, proxies=self.proxies if self.proxies else None, timeout=30)
|
response = requests.get(image_url, proxies=self.proxies if self.proxies else None, timeout=30)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
logger.error(f"下载图像失败,状态码: {response.status_code}")
|
logger.error(f"下载图像失败,状态码: {response.status_code}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
image_data = response.content
|
image_data = response.content
|
||||||
|
|
||||||
# 从二进制数据创建图像对象
|
|
||||||
img = Image.open(io.BytesIO(image_data))
|
img = Image.open(io.BytesIO(image_data))
|
||||||
width, height = img.size
|
width, height = img.size
|
||||||
|
|
||||||
# 获取原始图片格式
|
|
||||||
original_format = img.format
|
original_format = img.format
|
||||||
if not original_format:
|
if not original_format:
|
||||||
# 如果无法获取格式,从URL中提取
|
|
||||||
parsed_url = urlparse(image_url)
|
parsed_url = urlparse(image_url)
|
||||||
original_format = os.path.splitext(parsed_url.path)[1][1:].upper()
|
original_format = os.path.splitext(parsed_url.path)[1][1:].upper()
|
||||||
if not original_format:
|
if not original_format:
|
||||||
original_format = 'PNG' # 默认使用PNG
|
original_format = 'PNG'
|
||||||
|
|
||||||
# 确认图像尺寸约为2048x2048
|
|
||||||
if width < 1500 or height < 1500:
|
if width < 1500 or height < 1500:
|
||||||
logger.error(f"图像尺寸不符合预期: {width}x{height}, 应该接近2048x2048")
|
logger.error(f"图像尺寸不符合预期: {width}x{height}, 应该接近2048x2048")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 计算每个象限的尺寸
|
|
||||||
half_width = width // 2
|
half_width = width // 2
|
||||||
half_height = height // 2
|
half_height = height // 2
|
||||||
|
|
||||||
# 分割图像为四个象限
|
|
||||||
quadrants = [
|
quadrants = [
|
||||||
img.crop((0, 0, half_width, half_height)), # 左上
|
img.crop((0, 0, half_width, half_height)),
|
||||||
img.crop((half_width, 0, width, half_height)), # 右上
|
img.crop((half_width, 0, width, half_height)),
|
||||||
img.crop((0, half_height, half_width, height)), # 左下
|
img.crop((0, half_height, half_width, height)),
|
||||||
img.crop((half_width, half_height, width, height)) # 右下
|
img.crop((half_width, half_height, width, height))
|
||||||
]
|
]
|
||||||
|
|
||||||
# 生成唯一的图片名称前缀
|
|
||||||
image_id = uuid.uuid4().hex[:10]
|
image_id = uuid.uuid4().hex[:10]
|
||||||
|
|
||||||
# 确保保存目录存在(使用绝对路径)
|
|
||||||
save_dir = os.path.abspath(settings.save_dir)
|
save_dir = os.path.abspath(settings.save_dir)
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
|
|
||||||
# 保存图片到本地并生成URLs
|
|
||||||
image_urls = []
|
image_urls = []
|
||||||
for i, quadrant in enumerate(quadrants, 1):
|
for i, quadrant in enumerate(quadrants, 1):
|
||||||
try:
|
try:
|
||||||
# 使用原始格式保存
|
|
||||||
filename = f"split_{image_id}_{i}.{original_format.lower()}"
|
filename = f"split_{image_id}_{i}.{original_format.lower()}"
|
||||||
file_path = os.path.join(save_dir, filename)
|
file_path = os.path.join(save_dir, filename)
|
||||||
|
|
||||||
# 保存图片,保持原始格式
|
|
||||||
save_params = {"format": original_format}
|
save_params = {"format": original_format}
|
||||||
if original_format in ['PNG', 'JPEG', 'JPG']:
|
if original_format in ['PNG', 'JPEG', 'JPG']:
|
||||||
save_params["optimize"] = True
|
save_params["optimize"] = True
|
||||||
@ -369,23 +292,19 @@ class MidjourneyService:
|
|||||||
|
|
||||||
quadrant.save(file_path, **save_params)
|
quadrant.save(file_path, **save_params)
|
||||||
|
|
||||||
# 验证文件是否成功保存
|
|
||||||
if not os.path.exists(file_path):
|
if not os.path.exists(file_path):
|
||||||
raise Exception(f"文件保存失败: {file_path}")
|
raise Exception(f"文件保存失败: {file_path}")
|
||||||
|
|
||||||
# 验证文件大小
|
|
||||||
file_size = os.path.getsize(file_path)
|
file_size = os.path.getsize(file_path)
|
||||||
if file_size == 0:
|
if file_size == 0:
|
||||||
raise Exception(f"保存的文件大小为0: {file_path}")
|
raise Exception(f"保存的文件大小为0: {file_path}")
|
||||||
|
|
||||||
# 构建图片URL
|
|
||||||
image_url = f"{settings.download_url}/{filename}"
|
image_url = f"{settings.download_url}/{filename}"
|
||||||
image_urls.append(image_url)
|
image_urls.append(image_url)
|
||||||
|
|
||||||
logger.info(f"成功保存分割图片 {i}/4: {filename} (大小: {file_size} 字节, 格式: {original_format})")
|
logger.info(f"成功保存分割图片 {i}/4: {filename}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存分割图片 {i}/4 失败: {str(e)}")
|
logger.error(f"保存分割图片 {i}/4 失败: {str(e)}")
|
||||||
# 如果保存失败,删除已保存的图片
|
|
||||||
for url in image_urls:
|
for url in image_urls:
|
||||||
try:
|
try:
|
||||||
file_path = os.path.join(save_dir, os.path.basename(url))
|
file_path = os.path.join(save_dir, os.path.basename(url))
|
||||||
@ -399,12 +318,11 @@ class MidjourneyService:
|
|||||||
logger.error(f"分割图片数量不正确: 期望4张,实际{len(image_urls)}张")
|
logger.error(f"分割图片数量不正确: 期望4张,实际{len(image_urls)}张")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
logger.info(f"成功完成图片分割,生成4张子图")
|
logger.info("成功完成图片分割,生成4张子图")
|
||||||
return image_urls
|
return image_urls
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"分割图像失败: {str(e)}")
|
logger.error(f"分割图像失败: {str(e)}")
|
||||||
traceback.print_exc()
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def split_images(self, image_urls: List[str], config: Optional[Dict] = None):
|
async def split_images(self, image_urls: List[str], config: Optional[Dict] = None):
|
||||||
@ -418,95 +336,76 @@ class MidjourneyService:
|
|||||||
for i, image_url in enumerate(image_urls, 1):
|
for i, image_url in enumerate(image_urls, 1):
|
||||||
try:
|
try:
|
||||||
if not is_valid_image_url(image_url):
|
if not is_valid_image_url(image_url):
|
||||||
response = {
|
yield {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"index": i,
|
"index": i,
|
||||||
"total": total,
|
"total": total,
|
||||||
"success_count": success_count,
|
"success_count": success_count,
|
||||||
"message": "无效的图片URL"
|
"message": "无效的图片URL"
|
||||||
}
|
}
|
||||||
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
|
||||||
yield response
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
split_urls = await self.split_image(image_url)
|
split_urls = await self.split_image(image_url)
|
||||||
if split_urls and len(split_urls) == 4:
|
if split_urls and len(split_urls) == 4:
|
||||||
success_count += 1
|
success_count += 1
|
||||||
response = {
|
yield {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"index": i,
|
"index": i,
|
||||||
"total": total,
|
"total": total,
|
||||||
"success_count": success_count,
|
"success_count": success_count,
|
||||||
"images": split_urls
|
"images": split_urls
|
||||||
}
|
}
|
||||||
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
|
||||||
yield response
|
|
||||||
else:
|
else:
|
||||||
response = {
|
yield {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"index": i,
|
"index": i,
|
||||||
"total": total,
|
"total": total,
|
||||||
"success_count": success_count,
|
"success_count": success_count,
|
||||||
"message": "分割图片失败"
|
"message": "分割图片失败"
|
||||||
}
|
}
|
||||||
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
|
||||||
yield response
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
response = {
|
yield {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"index": i,
|
"index": i,
|
||||||
"total": total,
|
"total": total,
|
||||||
"success_count": success_count,
|
"success_count": success_count,
|
||||||
"message": str(e)
|
"message": str(e)
|
||||||
}
|
}
|
||||||
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
|
||||||
yield response
|
|
||||||
|
|
||||||
async def generate_image(self, prompt, config=None):
|
async def generate_image(self, prompt, config=None):
|
||||||
"""生成图像并以流式方式返回结果"""
|
"""生成图像并以流式方式返回结果"""
|
||||||
if not config:
|
if not config:
|
||||||
config = {}
|
config = {}
|
||||||
|
|
||||||
# 获取必要的认证信息
|
|
||||||
oauth_token = config.get('oauth_token', settings.midjourney_oauth_token)
|
oauth_token = config.get('oauth_token', settings.midjourney_oauth_token)
|
||||||
channel_id = config.get('channel_id', settings.midjourney_channel_id)
|
channel_id = config.get('channel_id', settings.midjourney_channel_id)
|
||||||
|
|
||||||
if not oauth_token or not channel_id:
|
if not oauth_token or not channel_id:
|
||||||
response = {
|
yield {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "缺少Discord配置",
|
"message": "缺少Discord配置",
|
||||||
"success_count": 0
|
"success_count": 0
|
||||||
}
|
}
|
||||||
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
|
||||||
yield response
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# 初始化客户端
|
|
||||||
try:
|
try:
|
||||||
client, guild_id, user_id = await self.initialize_client(oauth_token, channel_id)
|
client, guild_id, user_id = await self.initialize_client(oauth_token, channel_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
response = {
|
yield {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": str(e),
|
"message": str(e),
|
||||||
"success_count": 0
|
"success_count": 0
|
||||||
}
|
}
|
||||||
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
|
||||||
yield response
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# 解析和准备提示词
|
|
||||||
prompt = prompt.strip()
|
prompt = prompt.strip()
|
||||||
|
|
||||||
# 获取或设置默认选项
|
|
||||||
options = settings.midjourney_default_options.copy()
|
options = settings.midjourney_default_options.copy()
|
||||||
if 'options' in config:
|
if 'options' in config:
|
||||||
options.update(config.get('options', {}))
|
options.update(config.get('options', {}))
|
||||||
|
|
||||||
# 强制设置seed如果没有
|
|
||||||
if "seed" not in options or not options.get('seed'):
|
if "seed" not in options or not options.get('seed'):
|
||||||
options['seed'] = random.randint(0, 4294967295)
|
options['seed'] = random.randint(0, 4294967295)
|
||||||
|
|
||||||
# 处理选项,构建参数字符串
|
|
||||||
seed = options.pop('seed', None)
|
seed = options.pop('seed', None)
|
||||||
parameter = ""
|
parameter = ""
|
||||||
no_value_key = ['relax', 'fast', 'turbo', 'tile']
|
no_value_key = ['relax', 'fast', 'turbo', 'tile']
|
||||||
@ -516,20 +415,15 @@ class MidjourneyService:
|
|||||||
else:
|
else:
|
||||||
parameter += f" --{key} {value}"
|
parameter += f" --{key} {value}"
|
||||||
|
|
||||||
# 添加seed
|
|
||||||
if seed:
|
if seed:
|
||||||
parameter += f" --seed {seed}"
|
parameter += f" --seed {seed}"
|
||||||
# 打印使用的seed值
|
logger.info(f"使用的seed值: {seed}")
|
||||||
print(f"[生成] 使用的seed值: {seed}")
|
|
||||||
|
|
||||||
# 处理参考图像
|
|
||||||
if 'image_urls' in config and config['image_urls']:
|
if 'image_urls' in config and config['image_urls']:
|
||||||
# 确保是列表格式
|
|
||||||
image_urls = config['image_urls']
|
image_urls = config['image_urls']
|
||||||
if not isinstance(image_urls, list):
|
if not isinstance(image_urls, list):
|
||||||
image_urls = [image_urls]
|
image_urls = [image_urls]
|
||||||
|
|
||||||
# 转换图片URL
|
|
||||||
new_image_urls = []
|
new_image_urls = []
|
||||||
for image_url in image_urls:
|
for image_url in image_urls:
|
||||||
if is_valid_image_url(image_url):
|
if is_valid_image_url(image_url):
|
||||||
@ -540,16 +434,13 @@ class MidjourneyService:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"转换图片URL失败: {str(e)}")
|
logger.warning(f"转换图片URL失败: {str(e)}")
|
||||||
|
|
||||||
# 添加到prompt前面
|
|
||||||
if new_image_urls:
|
if new_image_urls:
|
||||||
prompt = " ".join(new_image_urls) + " " + prompt
|
prompt = " ".join(new_image_urls) + " " + prompt
|
||||||
|
|
||||||
# 设置图像权重
|
|
||||||
if 'image_weight' in config and isinstance(config['image_weight'], (int, float)):
|
if 'image_weight' in config and isinstance(config['image_weight'], (int, float)):
|
||||||
iw = max(0.1, min(config['image_weight'], 3)) # 限制在0.1到3之间
|
iw = max(0.1, min(config['image_weight'], 3))
|
||||||
parameter += f" --iw {iw}"
|
parameter += f" --iw {iw}"
|
||||||
|
|
||||||
# 添加字符引用
|
|
||||||
if 'characters' in config and config['characters']:
|
if 'characters' in config and config['characters']:
|
||||||
char_urls = []
|
char_urls = []
|
||||||
characters = config['characters'] if isinstance(config['characters'], list) else [config['characters']]
|
characters = config['characters'] if isinstance(config['characters'], list) else [config['characters']]
|
||||||
@ -566,7 +457,6 @@ class MidjourneyService:
|
|||||||
if char_urls:
|
if char_urls:
|
||||||
prompt = prompt + " --cref " + " ".join(char_urls)
|
prompt = prompt + " --cref " + " ".join(char_urls)
|
||||||
|
|
||||||
# 添加风格引用
|
|
||||||
if 'styles' in config and config['styles']:
|
if 'styles' in config and config['styles']:
|
||||||
style_urls = []
|
style_urls = []
|
||||||
styles = config['styles'] if isinstance(config['styles'], list) else [config['styles']]
|
styles = config['styles'] if isinstance(config['styles'], list) else [config['styles']]
|
||||||
@ -583,31 +473,24 @@ class MidjourneyService:
|
|||||||
if style_urls:
|
if style_urls:
|
||||||
prompt = prompt + " --sref " + " ".join(style_urls)
|
prompt = prompt + " --sref " + " ".join(style_urls)
|
||||||
|
|
||||||
# 添加参数
|
|
||||||
prompt = prompt + parameter
|
prompt = prompt + parameter
|
||||||
|
|
||||||
# 流式返回结果
|
|
||||||
success_count = 0
|
success_count = 0
|
||||||
|
|
||||||
async for result in self.imagine(client, guild_id, channel_id, prompt, seed):
|
async for result in self.imagine(client, guild_id, channel_id, prompt, seed):
|
||||||
if result.get("status") == "progress":
|
if result.get("status") == "progress":
|
||||||
# 进度信息只包含必要字段
|
yield {
|
||||||
response = {
|
|
||||||
"status": "progress",
|
"status": "progress",
|
||||||
"progress": result.get("progress", 0),
|
"progress": result.get("progress", 0),
|
||||||
"seed": seed # 添加seed字段
|
"seed": seed
|
||||||
}
|
}
|
||||||
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
|
||||||
yield response
|
|
||||||
elif result.get("status") == "success":
|
elif result.get("status") == "success":
|
||||||
success_count += 1
|
success_count += 1
|
||||||
response = {
|
response = {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"success_count": success_count,
|
"success_count": success_count,
|
||||||
"image_urls": result.get("images", []) # 添加image_urls字段
|
"image_urls": result.get("images", [])
|
||||||
}
|
}
|
||||||
|
|
||||||
# 如果需要分割图片
|
|
||||||
if config.get("split_image", True) and result.get("images"):
|
if config.get("split_image", True) and result.get("images"):
|
||||||
try:
|
try:
|
||||||
orig_image_url = result["images"][0]
|
orig_image_url = result["images"][0]
|
||||||
@ -619,14 +502,10 @@ class MidjourneyService:
|
|||||||
response["status"] = "error"
|
response["status"] = "error"
|
||||||
response["message"] = f"分割图像失败: {str(e)}"
|
response["message"] = f"分割图像失败: {str(e)}"
|
||||||
|
|
||||||
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
|
||||||
yield response
|
yield response
|
||||||
else:
|
else:
|
||||||
# 错误信息保持统一格式
|
yield {
|
||||||
response = {
|
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": result.get("message", "未知错误"),
|
"message": result.get("message", "未知错误"),
|
||||||
"success_count": success_count
|
"success_count": success_count
|
||||||
}
|
}
|
||||||
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
|
||||||
yield response
|
|
||||||
@ -6,7 +6,7 @@ class Settings(BaseSettings):
|
|||||||
# Japi Server 配置
|
# Japi Server 配置
|
||||||
host: str = "0.0.0.0"
|
host: str = "0.0.0.0"
|
||||||
port: int = 8113
|
port: int = 8113
|
||||||
debug: bool = True
|
debug: bool = False
|
||||||
|
|
||||||
# API路由配置
|
# API路由配置
|
||||||
router_prefix: str = "/midjourney"
|
router_prefix: str = "/midjourney"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user