512 lines
20 KiB
Python
512 lines
20 KiB
Python
import json
|
||
import sys
|
||
import os
|
||
import io
|
||
import requests
|
||
from requests.adapters import HTTPAdapter
|
||
from urllib3.util.retry import Retry
|
||
import random
|
||
import re
|
||
import uuid
|
||
import urllib.request
|
||
import urllib3
|
||
import logging
|
||
from pathlib import Path
|
||
from urllib.parse import urlparse
|
||
from PIL import Image
|
||
import asyncio
|
||
from typing import Dict, Any, List, AsyncGenerator, Optional
|
||
from settings import settings
|
||
from utils import get_new_image_url, is_valid_image_url
|
||
|
||
|
||
# Module-level logger shared by MidjourneyService.
#
# NOTE(review): disable_warnings() silences urllib3's InsecureRequestWarning for
# the whole process, not only for this module; no request visible in this file
# passes verify=False, so confirm the suppression is still needed.
logger = logging.getLogger(name="midjourney_service")
urllib3.disable_warnings(category=urllib3.exceptions.InsecureRequestWarning)
|
||
|
||
class MidjourneyService:
    """Generate Midjourney images through the Discord HTTP API.

    The service posts an ``/imagine`` application-command interaction to a
    Discord channel, then polls the channel's latest messages. It reports
    progress, finds the finished message, and can split the finished 2x2 grid
    into four locally saved images.

    Blocking work (``requests`` calls, Pillow processing) runs in the default
    executor so these coroutines do not stall the event loop.
    """

    # Percentage inside a progress message such as "(45%)"; group 1 is the
    # number. Surrounding ASCII or full-width parentheses do not affect the
    # match. (The previous pattern made the whole expression optional, so
    # group 1 was always empty/None and int() raised.)
    _PROGRESS_RE = re.compile(r"(\d+)%")

    # Seconds before an HTTP call to Discord is abandoned; requests itself has
    # no default timeout and could hang forever.
    REQUEST_TIMEOUT = 30

    # Seconds before the download of an image to split is abandoned.
    DOWNLOAD_TIMEOUT = 30

    # Seconds to wait after posting the interaction before the first poll.
    INITIAL_WAIT = 20

    # Options rendered as bare flags ("--fast") instead of "--key value".
    NO_VALUE_KEYS = ('relax', 'fast', 'turbo', 'tile')

    # Inclusive upper bound of a randomly generated seed.
    MAX_SEED = 4294967295

    def __init__(self):
        """Read Discord/Midjourney settings and prepare proxies and the save directory."""
        self.API_URL = settings.midjourney_api_url
        self.APPLICATION_ID = settings.midjourney_application_id
        self.DATA_ID = settings.midjourney_data_id
        self.DATA_VERSION = settings.midjourney_data_version
        self.SESSION_ID = settings.midjourney_session_id

        # Only non-blank proxy settings are used.
        self.proxies = {}
        if settings.http_proxy and settings.http_proxy.strip():
            self.proxies['http'] = settings.http_proxy.strip()
        if settings.https_proxy and settings.https_proxy.strip():
            self.proxies['https'] = settings.https_proxy.strip()

        # Make sure the directory for split images exists.
        os.makedirs(settings.save_dir, exist_ok=True)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    async def _run_blocking(func, *args, **kwargs):
        """Run the blocking callable *func* in the default executor and return its result."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, lambda: func(*args, **kwargs))

    @staticmethod
    def first_where(array, key, value=None):
        """Return the first item of *array* that matches, or None.

        *key* is either a predicate called with each item, or a dict key whose
        string value must start with *value*.
        """
        for item in array:
            if callable(key) and key(item):
                return item
            if isinstance(key, str) and item[key].startswith(value):
                return item
        return None

    @classmethod
    def _parse_progress(cls, content):
        """Return the integer percentage found in *content*, or None if there is none."""
        match = cls._PROGRESS_RE.search(content or '')
        return int(match.group(1)) if match else None

    @staticmethod
    def _has_seed(content, seed):
        """Return True if *content* contains ``--seed <seed>`` not followed by another digit.

        A plain substring test would let seed 12 match "--seed 123".
        """
        pattern = r"--seed " + re.escape(str(seed)) + r"(?!\d)"
        return re.search(pattern, content or '') is not None

    @classmethod
    def _is_progress_message(cls, item, seed):
        """Return True if message *item* is a progress update for *seed*."""
        content = item.get('content', '') or ''
        return (
            seed is not None
            and '%' in content
            and cls._has_seed(content, seed)
            and cls._parse_progress(content) is not None
        )

    @classmethod
    def _is_final_message(cls, item, seed):
        """Return True if message *item* is the finished result for *seed*.

        Finished means: no '%' in the content, the seed token is present, and
        at least one attachment. With no seed nothing is matched.
        """
        content = item.get('content', '') or ''
        if seed is None or '%' in content:
            return False
        return cls._has_seed(content, seed) and bool(item.get('attachments'))

    @staticmethod
    def _resolve_format(detected, image_url):
        """Pick the Pillow format name used to save split quadrants.

        Uses Pillow's detected format, else the URL path's extension, else PNG.
        "JPG" is normalised to "JPEG" because Pillow has no "JPG" writer.
        """
        fmt = detected or os.path.splitext(urlparse(image_url).path)[1][1:].upper() or 'PNG'
        return 'JPEG' if fmt == 'JPG' else fmt

    @classmethod
    def _format_parameters(cls, options):
        """Render *options* as Midjourney command-line parameters.

        Keys in NO_VALUE_KEYS become bare flags. Other keys become
        ``--key value``, and keys whose value is None are skipped instead of
        emitting "--key None".
        """
        parts = []
        for key, value in options.items():
            if key in cls.NO_VALUE_KEYS:
                parts.append(f" --{key}")
            elif value is not None:
                parts.append(f" --{key} {value}")
        return "".join(parts)

    @staticmethod
    def _convert_image_urls(urls, error_label):
        """Convert valid image URLs with get_new_image_url; return the successful results.

        *urls* may be a single URL or a list. A conversion failure is logged as
        ``"<error_label>: <error>"`` and that URL is skipped.
        """
        if not urls:
            return []
        if not isinstance(urls, list):
            urls = [urls]
        converted = []
        for url in urls:
            if not is_valid_image_url(url):
                continue
            try:
                new_url = get_new_image_url(url)
            except Exception as e:
                logger.warning(f"{error_label}: {str(e)}")
                continue
            if new_url:
                converted.append(new_url)
        return converted

    @staticmethod
    def _remove_files(paths):
        """Best-effort deletion of *paths*; failures are logged, not raised."""
        for path in paths:
            try:
                if os.path.exists(path):
                    os.remove(path)
            except OSError as del_e:
                logger.error(f"删除失败的图片文件时出错: {str(del_e)}")

    async def _get_json(self, client, url):
        """GET *url* with *client* off the event loop and return the decoded JSON.

        Raises requests.HTTPError for 4xx/5xx responses so callers see the real
        failure instead of a KeyError on an error payload.
        """
        response = await self._run_blocking(client.get, url, timeout=self.REQUEST_TIMEOUT)
        response.raise_for_status()
        return response.json()

    async def _fetch_recent_messages(self, client, channel_id, limit=10):
        """Return up to *limit* recent messages of *channel_id*; a non-list payload yields []."""
        data = await self._get_json(
            client, f'{self.API_URL}/channels/{channel_id}/messages?limit={limit}'
        )
        return data if isinstance(data, list) else []

    def _build_imagine_payload(self, guild_id, channel_id, prompt):
        """Build the Discord interaction body for the ``/imagine`` command."""
        return {
            'type': 2,
            'application_id': self.APPLICATION_ID,
            'guild_id': guild_id,
            'channel_id': channel_id,
            'session_id': self.SESSION_ID,
            'data': {
                'version': self.DATA_VERSION,
                'id': self.DATA_ID,
                'name': 'imagine',
                'type': 1,
                'options': [{
                    'type': 3,
                    'name': 'prompt',
                    'value': prompt,
                }],
                'application_command': {
                    'id': self.DATA_ID,
                    'application_id': self.APPLICATION_ID,
                    'version': self.DATA_VERSION,
                    'default_member_permissions': None,
                    'type': 1,
                    'nsfw': False,
                    'name': 'imagine',
                    'description': 'Create images with Midjourney',
                    'dm_permission': True,
                    'options': [{
                        'type': 3,
                        'name': 'prompt',
                        'description': 'The prompt to imagine',
                        'required': True,
                    }],
                },
                'attachments': [],
            },
        }

    # --------------------------------------------------------------- public API

    async def initialize_client(self, oauth_token, channel_id):
        """Create an authenticated Discord session and resolve the guild and user ids.

        Returns:
            ``(session, guild_id, user_id)``; the caller owns the session and
            should close it.
        Raises:
            Exception: if either lookup fails (the session is closed first).
        """
        client = requests.Session()
        adapter = HTTPAdapter(max_retries=Retry(total=3, backoff_factor=0.5))
        client.mount('http://', adapter)
        client.mount('https://', adapter)
        client.headers.update({'Authorization': oauth_token})
        if self.proxies:
            client.proxies.update(self.proxies)

        try:
            channel = await self._get_json(client, f'{self.API_URL}/channels/{channel_id}')
            guild_id = channel['guild_id']
            me = await self._get_json(client, f'{self.API_URL}/users/@me')
            user_id = me['id']
            return client, guild_id, user_id
        except Exception as e:
            # Do not leak the session's connection pool on failure.
            client.close()
            logger.error(f"初始化Discord客户端失败: {str(e)}")
            raise Exception(f"初始化Discord客户端失败: {str(e)}") from e

    async def imagine(self, client, guild_id, channel_id, prompt, seed=None):
        """Post an ``/imagine`` interaction and stream its progress and result.

        Yields:
            ``{"status": "progress", "progress", "message_id", "content"}`` each
            time the percentage increases; then either
            ``{"status": "success", "message_id", "content", "images"}`` or
            ``{"status": "error", "message"}`` (request failure or polling timeout).
        """
        params = self._build_imagine_payload(guild_id, channel_id, prompt)

        try:
            response = await self._run_blocking(
                client.post,
                f'{self.API_URL}/interactions',
                json=params,
                timeout=self.REQUEST_TIMEOUT,
            )
            # A rejected interaction fails immediately instead of polling to timeout.
            response.raise_for_status()
            await asyncio.sleep(self.INITIAL_WAIT)

            imagine_message = None
            last_progress = 0
            attempts = 0
            while attempts < settings.max_polling_attempts:
                # One fetch per round serves both the progress and the result lookup.
                try:
                    messages = await self._fetch_recent_messages(client, channel_id)
                except Exception as e:
                    # A failed poll is not fatal; try again next round.
                    logger.error(f"获取通道消息失败: {str(e)}")
                    messages = []

                progress_msg = self.first_where(
                    messages, lambda item: self._is_progress_message(item, seed)
                )
                if progress_msg:
                    content = progress_msg.get('content', '')
                    progress_value = self._parse_progress(content)
                    if progress_value is not None and progress_value > last_progress:
                        last_progress = progress_value
                        logger.info(f"生成进度: {progress_value}%")
                        yield {
                            "status": "progress",
                            "progress": progress_value,
                            "message_id": progress_msg.get('id'),
                            "content": content,
                        }

                imagine_message = self.first_where(
                    messages, lambda item: self._is_final_message(item, seed)
                )
                if imagine_message:
                    break

                await asyncio.sleep(settings.polling_interval)
                attempts += 1

            if not imagine_message:
                logger.error(f"轮询超过最大尝试次数: {settings.max_polling_attempts}")
                yield {"status": "error", "message": "获取结果超时,超过最大轮询次数"}
                return

            logger.info(f"成功获取图像结果,消息ID: {imagine_message.get('id')}")
            yield {
                "status": "success",
                "message_id": imagine_message.get('id'),
                "content": imagine_message.get('content', ''),
                "images": self.extract_image_urls(imagine_message),
            }

        except Exception as e:
            logger.error(f"发送Imagine请求失败: {str(e)}")
            yield {"status": "error", "message": f"发送Imagine请求失败: {str(e)}"}

    async def get_imagine(self, client, channel_id, prompt, count=0, seed=None):
        """Return the finished message for *seed* among the channel's last 10 messages, or None.

        *prompt* and *count* are accepted for backward compatibility and unused.
        Fetch errors are logged and reported as None.
        """
        try:
            messages = await self._fetch_recent_messages(client, channel_id)
            return self.first_where(messages, lambda item: self._is_final_message(item, seed))
        except Exception as e:
            logger.error(f"获取通道消息失败: {str(e)}")
            return None

    async def get_progress_message(self, client, channel_id, prompt, seed=None):
        """Return the first progress message for *seed* among the last 10 messages, or None.

        *prompt* is accepted for backward compatibility and unused. Fetch errors
        are logged and reported as None.
        """
        try:
            messages = await self._fetch_recent_messages(client, channel_id)
            return self.first_where(messages, lambda item: self._is_progress_message(item, seed))
        except Exception as e:
            logger.error(f"获取进度消息失败: {str(e)}")
            return None

    def extract_image_urls(self, message):
        """Return the ``url`` of every attachment on *message* (possibly an empty list)."""
        return [
            attachment['url']
            for attachment in (message.get('attachments') or [])
            if 'url' in attachment
        ]

    def _split_image_sync(self, image_url):
        """Blocking implementation of split_image; returns four URLs or None.

        Files written before a failure, including a partially written one,
        are deleted.
        """
        response = requests.get(
            image_url, proxies=self.proxies or None, timeout=self.DOWNLOAD_TIMEOUT
        )
        if response.status_code != 200:
            logger.error(f"下载图像失败,状态码: {response.status_code}")
            return None

        with Image.open(io.BytesIO(response.content)) as img:
            width, height = img.size
            image_format = self._resolve_format(img.format, image_url)

            if width < 1500 or height < 1500:
                logger.error(f"图像尺寸不符合预期: {width}x{height}, 应该接近2048x2048")
                return None

            half_width, half_height = width // 2, height // 2
            # Order: top-left, top-right, bottom-left, bottom-right.
            boxes = [
                (0, 0, half_width, half_height),
                (half_width, 0, width, half_height),
                (0, half_height, half_width, height),
                (half_width, half_height, width, height),
            ]

            image_id = uuid.uuid4().hex[:10]
            save_dir = os.path.abspath(settings.save_dir)
            os.makedirs(save_dir, exist_ok=True)

            save_params = {"format": image_format}
            if image_format in ('PNG', 'JPEG'):
                save_params["optimize"] = True
            if image_format == 'JPEG':
                save_params["quality"] = 95

            saved_paths = []
            image_urls = []
            for i, box in enumerate(boxes, 1):
                filename = f"split_{image_id}_{i}.{image_format.lower()}"
                file_path = os.path.join(save_dir, filename)
                try:
                    quadrant = img.crop(box)
                    # JPEG cannot store alpha/palette modes.
                    if image_format == 'JPEG' and quadrant.mode not in ('RGB', 'L', 'CMYK'):
                        quadrant = quadrant.convert('RGB')
                    # Record the path before saving so a partial file is also cleaned up.
                    saved_paths.append(file_path)
                    quadrant.save(file_path, **save_params)

                    if not os.path.exists(file_path):
                        raise Exception(f"文件保存失败: {file_path}")
                    if os.path.getsize(file_path) == 0:
                        raise Exception(f"保存的文件大小为0: {file_path}")
                except Exception as e:
                    logger.error(f"保存分割图片 {i}/4 失败: {str(e)}")
                    self._remove_files(saved_paths)
                    return None

                image_urls.append(f"{settings.download_url}/{filename}")
                logger.info(f"成功保存分割图片 {i}/4: {filename}")

        logger.info("成功完成图片分割,生成4张子图")
        return image_urls

    async def split_image(self, image_url):
        """Split a 2x2 grid image into four saved quadrants and return their URLs.

        Download, crop and save run in a worker thread. The URLs are built as
        ``settings.download_url/<filename>`` in order top-left, top-right,
        bottom-left, bottom-right. Returns None on any failure.
        """
        try:
            return await self._run_blocking(self._split_image_sync, image_url)
        except Exception as e:
            logger.error(f"分割图像失败: {str(e)}")
            return None

    async def split_images(self, image_urls: List[str], config: Optional[Dict] = None):
        """Split each URL in *image_urls*, yielding one result dict per URL.

        Every dict carries ``index`` (1-based), ``total`` and the running
        ``success_count``. Successes add ``images`` (four URLs); failures add
        ``message``. *config* is accepted for compatibility and unused.
        """
        total = len(image_urls)
        success_count = 0

        for i, image_url in enumerate(image_urls, 1):
            try:
                if not is_valid_image_url(image_url):
                    yield {
                        "status": "error",
                        "index": i,
                        "total": total,
                        "success_count": success_count,
                        "message": "无效的图片URL",
                    }
                    continue

                split_urls = await self.split_image(image_url)
                if split_urls and len(split_urls) == 4:
                    success_count += 1
                    yield {
                        "status": "success",
                        "index": i,
                        "total": total,
                        "success_count": success_count,
                        "images": split_urls,
                    }
                else:
                    yield {
                        "status": "error",
                        "index": i,
                        "total": total,
                        "success_count": success_count,
                        "message": "分割图片失败",
                    }
            except Exception as e:
                yield {
                    "status": "error",
                    "index": i,
                    "total": total,
                    "success_count": success_count,
                    "message": str(e),
                }

    def _build_prompt(self, prompt, config):
        """Return ``(full_prompt, seed)`` built from *prompt*, default options and *config*.

        A missing or blank seed is replaced by a random one. An explicit seed
        of 0 is kept, and ``--seed`` is always appended so the result message
        can be matched.
        """
        prompt = prompt.strip()
        options = settings.midjourney_default_options.copy()
        options.update(config.get('options') or {})

        seed = options.pop('seed', None)
        if seed is None or seed == '':
            seed = random.randint(0, self.MAX_SEED)

        image_weight = config.get('image_weight')
        if isinstance(image_weight, (int, float)):
            options['iw'] = max(0.1, min(image_weight, 3))

        parameter = self._format_parameters(options) + f" --seed {seed}"
        logger.info(f"使用的seed值: {seed}")

        image_urls = self._convert_image_urls(config.get('image_urls'), "转换图片URL失败")
        if image_urls:
            prompt = " ".join(image_urls) + " " + prompt

        char_urls = self._convert_image_urls(config.get('characters'), "转换角色图片URL失败")
        if char_urls:
            prompt = prompt + " --cref " + " ".join(char_urls)

        style_urls = self._convert_image_urls(config.get('styles'), "转换风格图片URL失败")
        if style_urls:
            prompt = prompt + " --sref " + " ".join(style_urls)

        return prompt + parameter, seed

    async def generate_image(self, prompt, config=None):
        """Generate an image for *prompt* and stream status dicts.

        *config* keys read: ``oauth_token``, ``channel_id``, ``options``,
        ``image_weight``, ``image_urls``, ``characters``, ``styles``,
        ``split_image`` (default True).

        Yields progress dicts (``progress``, ``seed``), success dicts
        (``success_count``, ``image_urls``; split quadrant URLs when splitting
        succeeds) and error dicts (``message``, ``success_count``). The Discord
        session is closed when the generator finishes or is closed.
        """
        config = config or {}

        oauth_token = config.get('oauth_token', settings.midjourney_oauth_token)
        channel_id = config.get('channel_id', settings.midjourney_channel_id)
        if not oauth_token or not channel_id:
            yield {"status": "error", "message": "缺少Discord配置", "success_count": 0}
            return

        try:
            client, guild_id, _user_id = await self.initialize_client(oauth_token, channel_id)
        except Exception as e:
            yield {"status": "error", "message": str(e), "success_count": 0}
            return

        try:
            full_prompt, seed = self._build_prompt(prompt, config)
            success_count = 0

            async for result in self.imagine(client, guild_id, channel_id, full_prompt, seed):
                status = result.get("status")
                if status == "progress":
                    yield {
                        "status": "progress",
                        "progress": result.get("progress", 0),
                        "seed": seed,
                    }
                elif status == "success":
                    success_count += 1
                    response = {
                        "status": "success",
                        "success_count": success_count,
                        "image_urls": result.get("images", []),
                    }

                    images = result.get("images")
                    if config.get("split_image", True) and images:
                        try:
                            split_urls = await self.split_image(images[0])
                            # On split failure the original grid URLs are kept.
                            if split_urls:
                                response["image_urls"] = split_urls
                        except Exception as e:
                            logger.error(f"分割图像失败: {str(e)}")
                            response["status"] = "error"
                            response["message"] = f"分割图像失败: {str(e)}"

                    yield response
                else:
                    yield {
                        "status": "error",
                        "message": result.get("message", "未知错误"),
                        "success_count": success_count,
                    }
        finally:
            client.close()