"""Midjourney image generation service.

Drives Midjourney either through the Discord REST API (``generate_image``) or
through the VectorEngine HTTP API (``ve_generate_image``), streams progress as
dicts, and can split a returned image into its four quadrants saved locally.
"""
import asyncio
import io
import json
import logging
import os
import random
import re
import sys
import urllib.request
import uuid
from pathlib import Path
from typing import Any, AsyncGenerator, Dict, List, Optional
from urllib.parse import urlparse

import httpx
import requests
import urllib3
from PIL import Image
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from settings import settings
from utils import get_new_image_url, is_valid_image_url

# Module logger.
logger = logging.getLogger("midjourney_service")

# Silence urllib3's InsecureRequestWarning.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

# Progress marker inside Discord message content, e.g. "(31%)" or the
# full-width "（31%）"; group 1 is always the integer percentage.
_PROGRESS_RE = re.compile(r"[(\uff08]?(\d+)%[)\uff09]?")

# Seconds to wait after posting the /imagine interaction before polling starts.
_INITIAL_WAIT_SECONDS = 20

# Per-request timeout (seconds) for Discord REST calls. requests applies no
# timeout by default, so without one a stalled connection would hang the job.
_DISCORD_TIMEOUT = 30

# Timeout (seconds) for downloading an image to split.
_DOWNLOAD_TIMEOUT = 30

# Midjourney flags emitted as a bare "--flag" with no value.
_NO_VALUE_OPTIONS = ('relax', 'fast', 'turbo', 'tile')

# Upper bound for randomly generated seeds (2**32 - 1).
_MAX_SEED = 4294967295

# VectorEngine task statuses that mean the job failed.
_VE_FAILURE_STATUSES = ('FAILURE', 'FAILED')


def _parse_progress(content):
    """Return the first percentage found in *content* as an int, or None."""
    match = _PROGRESS_RE.search(content or '')
    return int(match.group(1)) if match else None


async def _run_blocking(func, *args, **kwargs):
    """Run a blocking callable in the default executor and await its result.

    The Discord client is a synchronous ``requests.Session`` and image
    splitting uses PIL; running them off the event loop keeps other
    coroutines responsive while a request or image operation is in flight.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, lambda: func(*args, **kwargs))


class MidjourneyService:
    """Generate Midjourney images via Discord or VectorEngine and post-process them."""

    def __init__(self):
        """Load Discord/Midjourney identifiers and proxies from ``settings``.

        Also creates ``settings.save_dir`` if it does not exist.
        """
        self.API_URL = settings.midjourney_api_url
        self.APPLICATION_ID = settings.midjourney_application_id
        self.DATA_ID = settings.midjourney_data_id
        self.DATA_VERSION = settings.midjourney_data_version
        self.SESSION_ID = settings.midjourney_session_id

        # Only non-blank proxy settings are used.
        self.proxies = {}
        if settings.http_proxy and settings.http_proxy.strip():
            self.proxies['http'] = settings.http_proxy
        if settings.https_proxy and settings.https_proxy.strip():
            self.proxies['https'] = settings.https_proxy

        # Split images are written to this directory.
        os.makedirs(settings.save_dir, exist_ok=True)

    @staticmethod
    def first_where(array, key, value=None):
        """Return the first item of *array* that matches, or None.

        *key* is either a predicate called with each item, or a field name
        whose string value must start with *value*.
        """
        for item in array:
            if callable(key) and key(item):
                return item
            if isinstance(key, str) and item[key].startswith(value):
                return item
        return None

    async def initialize_client(self, oauth_token, channel_id):
        """Create an authenticated Discord session and resolve guild/user IDs.

        Args:
            oauth_token: Value sent as the ``Authorization`` header.
            channel_id: Discord channel whose guild is looked up.

        Returns:
            ``(session, guild_id, user_id)``. The caller owns the session and
            should close it when done.

        Raises:
            Exception: if either lookup fails; the message starts with
                "初始化Discord客户端失败".
        """
        client = requests.Session()
        retry = Retry(total=3, backoff_factor=0.5)
        adapter = HTTPAdapter(max_retries=retry)
        client.mount('http://', adapter)
        client.mount('https://', adapter)
        client.headers.update({'Authorization': oauth_token})
        if self.proxies:
            client.proxies.update(self.proxies)

        try:
            response = await _run_blocking(
                client.get, f'{self.API_URL}/channels/{channel_id}', timeout=_DISCORD_TIMEOUT
            )
            # Surface HTTP errors directly instead of a confusing KeyError below.
            response.raise_for_status()
            guild_id = response.json()['guild_id']

            response = await _run_blocking(
                client.get, f'{self.API_URL}/users/@me', timeout=_DISCORD_TIMEOUT
            )
            response.raise_for_status()
            user_id = response.json()['id']
        except Exception as e:
            client.close()  # don't leak pooled connections on failure
            logger.error(f"初始化Discord客户端失败: {str(e)}")
            raise Exception(f"初始化Discord客户端失败: {str(e)}") from e
        return client, guild_id, user_id

    async def imagine(self, client, guild_id, channel_id, prompt, seed=None):
        """Send a Midjourney /imagine interaction to Discord and stream its progress.

        Waits ``_INITIAL_WAIT_SECONDS`` after submitting, then polls up to
        ``settings.max_polling_attempts`` times, ``settings.polling_interval``
        seconds apart.

        Yields:
            ``{"status": "progress", "progress", "message_id", "content"}`` each
            time the percentage increases, then exactly one of
            ``{"status": "success", "message_id", "content", "images"}`` or
            ``{"status": "error", "message"}``.
        """
        params = {
            'type': 2,
            'application_id': self.APPLICATION_ID,
            'guild_id': guild_id,
            'channel_id': channel_id,
            'session_id': self.SESSION_ID,
            'data': {
                'version': self.DATA_VERSION,
                'id': self.DATA_ID,
                'name': 'imagine',
                'type': 1,
                'options': [{
                    'type': 3,
                    'name': 'prompt',
                    'value': prompt,
                }],
                'application_command': {
                    'id': self.DATA_ID,
                    'application_id': self.APPLICATION_ID,
                    'version': self.DATA_VERSION,
                    'default_member_permissions': None,
                    'type': 1,
                    'nsfw': False,
                    'name': 'imagine',
                    'description': 'Create images with Midjourney',
                    'dm_permission': True,
                    'options': [{
                        'type': 3,
                        'name': 'prompt',
                        'description': 'The prompt to imagine',
                        'required': True,
                    }],
                },
                'attachments': [],
            },
        }

        try:
            response = await _run_blocking(
                client.post,
                f'{self.API_URL}/interactions',
                json=params,
                timeout=_DISCORD_TIMEOUT,
            )
            # Fail fast when Discord rejects the interaction instead of polling
            # until the timeout for a job that was never started.
            if response.status_code >= 400:
                logger.error(f"发送Imagine请求失败: HTTP {response.status_code} - {response.text}")
                yield {"status": "error", "message": f"发送Imagine请求失败: HTTP {response.status_code}"}
                return

            await asyncio.sleep(_INITIAL_WAIT_SECONDS)

            imagine_message = None
            attempt = 0
            last_progress = 0
            while attempt < settings.max_polling_attempts:
                progress_msg = await self.get_progress_message(client, channel_id, prompt, seed)
                if progress_msg:
                    content = progress_msg.get('content') or ''
                    progress_value = _parse_progress(content)
                    # Only report forward progress; consecutive polls often see the same value.
                    if progress_value is not None and progress_value > last_progress:
                        last_progress = progress_value
                        logger.info(f"生成进度: {progress_value}%")
                        yield {
                            "status": "progress",
                            "progress": progress_value,
                            "message_id": progress_msg.get('id'),
                            "content": content,
                        }

                imagine_message = await self.get_imagine(client, channel_id, prompt, attempt, seed)
                if imagine_message:
                    break
                await asyncio.sleep(settings.polling_interval)
                attempt += 1

            if not imagine_message:
                logger.error(f"轮询超过最大尝试次数: {settings.max_polling_attempts}")
                yield {"status": "error", "message": "获取结果超时,超过最大轮询次数"}
                return

            logger.info(f"成功获取图像结果,消息ID: {imagine_message.get('id')}")
            yield {
                "status": "success",
                "message_id": imagine_message.get('id'),
                "content": imagine_message.get('content', ''),
                "images": self.extract_image_urls(imagine_message),
            }
        except Exception as e:
            logger.error(f"发送Imagine请求失败: {str(e)}")
            yield {"status": "error", "message": f"发送Imagine请求失败: {str(e)}"}

    async def _fetch_recent_messages(self, client, channel_id, error_label):
        """Return the channel's 10 most recent messages as dicts; [] on any failure.

        *error_label* prefixes the log line so each caller keeps its own message.
        """
        try:
            response = await _run_blocking(
                client.get,
                f'{self.API_URL}/channels/{channel_id}/messages?limit=10',
                timeout=_DISCORD_TIMEOUT,
            )
            data = response.json()
        except Exception as e:
            logger.error(f"{error_label}: {str(e)}")
            return []
        if not isinstance(data, list):
            # A non-list payload (e.g. a Discord error object) holds no messages.
            logger.error(f"{error_label}: {data}")
            return []
        return [item for item in data if isinstance(item, dict)]

    async def get_imagine(self, client, channel_id, prompt, count=0, seed=None):
        """Return the finished result message for this job, or None.

        A message qualifies when its content has no "%" (i.e. it is not a
        progress update), contains ``--seed {seed}`` when *seed* is given, and
        it has at least one attachment. *prompt* and *count* are accepted for
        interface compatibility and are not used.

        NOTE(review): a prompt that itself contains "%" will never match here.
        """
        seed_pattern = f"--seed {seed}" if seed is not None else None

        def criteria(item):
            content = item.get('content') or ''
            if "%" in content:
                return False  # still rendering
            if seed_pattern is not None and seed_pattern not in content:
                return False
            return bool(item.get('attachments'))

        messages = await self._fetch_recent_messages(client, channel_id, "获取通道消息失败")
        return self.first_where(messages, criteria)

    async def get_progress_message(self, client, channel_id, prompt, seed=None):
        """Return the newest message showing a render percentage for this job, or None.

        Messages are matched by ``--seed {seed}``, so without a seed nothing
        can be matched and None is returned. *prompt* is unused.
        """
        if seed is None:
            return None
        seed_pattern = f"--seed {seed}"
        for item in await self._fetch_recent_messages(client, channel_id, "获取进度消息失败"):
            content = item.get('content') or ''
            if seed_pattern in content and _parse_progress(content) is not None:
                return item
        return None

    def extract_image_urls(self, message):
        """Return the ``url`` of every attachment on a Discord message dict (possibly empty)."""
        return [
            attachment['url']
            for attachment in (message.get('attachments') or [])
            if 'url' in attachment
        ]

    @staticmethod
    def _remove_files(paths):
        """Best-effort delete of *paths*; failures are logged, not raised."""
        for path in paths:
            try:
                if os.path.exists(path):
                    os.remove(path)
            except Exception as del_e:
                logger.error(f"删除失败的图片文件时出错: {str(del_e)}")

    def _split_image_bytes(self, image_data, image_url):
        """Cut downloaded image bytes into four quadrants and save them (blocking).

        Returns the four download URLs, or None if the image is smaller than
        500px on either side or any quadrant fails to save. On a save failure
        every file already written for this image is removed.
        """
        with Image.open(io.BytesIO(image_data)) as img:
            width, height = img.size

            # Prefer the decoded format; fall back to the URL extension, then PNG.
            original_format = img.format
            if not original_format:
                original_format = os.path.splitext(urlparse(image_url).path)[1][1:].upper()
            if not original_format:
                original_format = 'PNG'
            # Pillow's save() looks formats up by ID ("JPEG"); "JPG" is only a file extension.
            save_format = 'JPEG' if original_format == 'JPG' else original_format

            if width < 500 or height < 500:
                logger.error(f"图像尺寸不符合预期: {width}x{height}, 应该接近2048x2048")
                return None

            half_width = width // 2
            half_height = height // 2
            # Order: top-left, top-right, bottom-left, bottom-right.
            boxes = [
                (0, 0, half_width, half_height),
                (half_width, 0, width, half_height),
                (0, half_height, half_width, height),
                (half_width, half_height, width, height),
            ]

            save_params = {"format": save_format}
            if save_format in ('PNG', 'JPEG'):
                save_params["optimize"] = True
            if save_format == 'JPEG':
                save_params["quality"] = 95

            image_id = uuid.uuid4().hex[:10]
            save_dir = os.path.abspath(settings.save_dir)
            os.makedirs(save_dir, exist_ok=True)

            saved_paths = []
            image_urls = []
            for i, box in enumerate(boxes, 1):
                filename = f"split_{image_id}_{i}.{original_format.lower()}"
                file_path = os.path.join(save_dir, filename)
                try:
                    img.crop(box).save(file_path, **save_params)
                    if not os.path.exists(file_path):
                        raise Exception(f"文件保存失败: {file_path}")
                    if os.path.getsize(file_path) == 0:
                        raise Exception(f"保存的文件大小为0: {file_path}")
                except Exception as e:
                    logger.error(f"保存分割图片 {i}/4 失败: {str(e)}")
                    # Include the current path: a failed save may leave an empty/partial file.
                    self._remove_files(saved_paths + [file_path])
                    return None
                saved_paths.append(file_path)
                image_urls.append(f"{settings.download_url}/{filename}")
                logger.info(f"成功保存分割图片 {i}/4: {filename}")

        logger.info("成功完成图片分割,生成4张子图")
        return image_urls

    async def split_image(self, image_url):
        """Download an image and split it into four locally saved quadrants.

        Returns:
            Four URLs of the form ``f"{settings.download_url}/{filename}"``
            (top-left, top-right, bottom-left, bottom-right), or None on any
            failure (all failures are logged, none are raised).
        """
        try:
            response = await _run_blocking(
                requests.get,
                image_url,
                proxies=self.proxies if self.proxies else None,
                timeout=_DOWNLOAD_TIMEOUT,
            )
            if response.status_code != 200:
                logger.error(f"下载图像失败,状态码: {response.status_code}")
                return None
            # Decoding, cropping and encoding are CPU/disk bound; keep them off the event loop.
            return await _run_blocking(self._split_image_bytes, response.content, image_url)
        except Exception as e:
            logger.error(f"分割图像失败: {str(e)}")
            return None

    async def split_images(self, image_urls: List[str], config: Optional[Dict] = None):
        """Split several images, yielding one status dict per input URL.

        Each dict has ``status`` ("success"/"error"), the 1-based ``index``,
        ``total``, the running ``success_count``, and either ``images`` (four
        URLs) or ``message``. *config* is accepted for compatibility and unused.
        """
        total = len(image_urls)
        success_count = 0
        for index, image_url in enumerate(image_urls, 1):
            try:
                if not is_valid_image_url(image_url):
                    yield {
                        "status": "error",
                        "index": index,
                        "total": total,
                        "success_count": success_count,
                        "message": "无效的图片URL",
                    }
                    continue

                split_urls = await self.split_image(image_url)
                if split_urls and len(split_urls) == 4:
                    success_count += 1
                    yield {
                        "status": "success",
                        "index": index,
                        "total": total,
                        "success_count": success_count,
                        "images": split_urls,
                    }
                else:
                    yield {
                        "status": "error",
                        "index": index,
                        "total": total,
                        "success_count": success_count,
                        "message": "分割图片失败",
                    }
            except Exception as e:
                yield {
                    "status": "error",
                    "index": index,
                    "total": total,
                    "success_count": success_count,
                    "message": str(e),
                }

    @staticmethod
    def _convert_image_urls(urls, failure_label):
        """Convert valid image URLs with ``get_new_image_url``.

        *urls* may be a single URL or a list. Invalid URLs, falsy conversion
        results and conversion errors are skipped; errors are logged as
        warnings prefixed with *failure_label*.
        """
        if not isinstance(urls, list):
            urls = [urls]
        converted = []
        for url in urls:
            if not is_valid_image_url(url):
                continue
            try:
                new_url = get_new_image_url(url)
            except Exception as e:
                logger.warning(f"{failure_label}: {str(e)}")
                continue
            if new_url:
                converted.append(new_url)
        return converted

    def _build_prompt(self, prompt, config):
        """Assemble the full Midjourney prompt text; return ``(prompt_text, seed)``.

        Layout: ``[image urls] prompt [--cref urls] [--sref urls] --opt ... --seed N``.
        """
        prompt = prompt.strip()

        options = settings.midjourney_default_options.copy()
        if 'options' in config:
            options.update(config.get('options') or {})

        # 0 is a valid Midjourney seed, so only genuinely missing values get a random one.
        seed = options.pop('seed', None)
        if seed is None or seed == '' or seed is False:
            seed = random.randint(0, _MAX_SEED)

        image_weight = config.get('image_weight')
        if isinstance(image_weight, (int, float)):
            options['iw'] = max(0.1, min(image_weight, 3))

        parameter = "".join(
            f" --{key}" if key in _NO_VALUE_OPTIONS else f" --{key} {value}"
            for key, value in options.items()
        )
        # The seed is always sent: polling identifies this job's messages by it.
        parameter += f" --seed {seed}"
        logger.info(f"使用的seed值: {seed}")

        if config.get('image_urls'):
            reference_urls = self._convert_image_urls(config['image_urls'], "转换图片URL失败")
            if reference_urls:
                prompt = " ".join(reference_urls) + " " + prompt

        if config.get('characters'):
            char_urls = self._convert_image_urls(config['characters'], "转换角色图片URL失败")
            if char_urls:
                prompt = prompt + " --cref " + " ".join(char_urls)

        if config.get('styles'):
            style_urls = self._convert_image_urls(config['styles'], "转换风格图片URL失败")
            if style_urls:
                prompt = prompt + " --sref " + " ".join(style_urls)

        return prompt + parameter, seed

    async def generate_image(self, prompt, config=None):
        """Generate images through Discord and stream the results.

        Args:
            prompt: Text prompt; surrounding whitespace is stripped.
            config: Optional dict. Keys read: "oauth_token" and "channel_id"
                (default to settings), "options" (merged over
                ``settings.midjourney_default_options``; may carry "seed"),
                "image_weight" (clamped to [0.1, 3], sent as --iw),
                "image_urls" (prepended to the prompt), "characters" (--cref),
                "styles" (--sref), "split_image" (default True).

        Yields:
            ``{"status": "progress", "progress", "seed"}`` while rendering, then
            ``{"status": "success", "success_count", "image_urls"}`` (status is
            "error" with a "message" if splitting raises) or
            ``{"status": "error", "message", "success_count"}``.
        """
        if not config:
            config = {}

        oauth_token = config.get('oauth_token', settings.midjourney_oauth_token)
        channel_id = config.get('channel_id', settings.midjourney_channel_id)
        if not oauth_token or not channel_id:
            yield {
                "status": "error",
                "message": "缺少Discord配置",
                "success_count": 0,
            }
            return

        try:
            client, guild_id, _user_id = await self.initialize_client(oauth_token, channel_id)
        except Exception as e:
            yield {
                "status": "error",
                "message": str(e),
                "success_count": 0,
            }
            return

        try:
            full_prompt, seed = self._build_prompt(prompt, config)

            success_count = 0
            async for result in self.imagine(client, guild_id, channel_id, full_prompt, seed):
                status = result.get("status")
                if status == "progress":
                    yield {
                        "status": "progress",
                        "progress": result.get("progress", 0),
                        "seed": seed,
                    }
                elif status == "success":
                    success_count += 1
                    response = {
                        "status": "success",
                        "success_count": success_count,
                        "image_urls": result.get("images", []),
                    }
                    if config.get("split_image", True) and result.get("images"):
                        try:
                            split_urls = await self.split_image(result["images"][0])
                            if split_urls:
                                response["image_urls"] = split_urls
                        except Exception as e:
                            logger.error(f"分割图像失败: {str(e)}")
                            response["status"] = "error"
                            response["message"] = f"分割图像失败: {str(e)}"
                    yield response
                else:
                    yield {
                        "status": "error",
                        "message": result.get("message", "未知错误"),
                        "success_count": success_count,
                    }
        finally:
            # The session holds pooled connections; release them when the job ends.
            client.close()

    @staticmethod
    def _ve_headers():
        """Build the auth/content-type headers for VectorEngine requests."""
        return {
            "Authorization": f"Bearer {settings.vectorengine_token}",
            "Content-Type": "application/json",
        }

    async def ve_submit_imagine(self, prompt: str, base64_array: Optional[List[str]] = None,
                                notify_hook: str = "", state: str = "",
                                bot_type: str = "MID_JOURNEY") -> Dict[str, Any]:
        """Submit an imagine task to the VectorEngine API.

        Args:
            prompt: Prompt text.
            base64_array: Base64-encoded reference images (default: none).
            notify_hook: Callback URL.
            state: Caller-defined state string.
            bot_type: Bot type, "MID_JOURNEY" by default.

        Returns:
            The decoded JSON response (the task ID is read from its "result"
            key by ``ve_generate_image``), or ``{"success": False, "message": ...}``
            on a non-200 status or any exception.
        """
        if base64_array is None:
            base64_array = []

        payload = {
            "base64Array": base64_array,
            "notifyHook": notify_hook,
            "prompt": prompt,
            "state": state,
            "botType": bot_type,
        }

        try:
            async with httpx.AsyncClient(timeout=60.0) as client:
                response = await client.post(
                    f"{settings.vectorengine_api_url}/mj/submit/imagine",
                    json=payload,
                    headers=self._ve_headers(),
                )
                if response.status_code != 200:
                    logger.error(f"VectorEngine API 请求失败: {response.status_code} - {response.text}")
                    return {"success": False, "message": f"API请求失败: {response.status_code}"}

                result = response.json()
                logger.info(f"VectorEngine 提交任务成功: {result}")
                return result
        except Exception as e:
            logger.error(f"VectorEngine 提交任务异常: {str(e)}")
            return {"success": False, "message": str(e)}

    async def ve_get_task_status(self, task_id: str) -> Dict[str, Any]:
        """Fetch the status of a VectorEngine task.

        Returns:
            The decoded JSON task object, or ``{"success": False, "message": ...}``
            on a non-200 status or any exception.
        """
        try:
            async with httpx.AsyncClient(timeout=60.0) as client:
                response = await client.get(
                    f"{settings.vectorengine_api_url}/mj/task/{task_id}/fetch",
                    headers=self._ve_headers(),
                )
                if response.status_code != 200:
                    logger.error(f"VectorEngine 获取任务状态失败: {response.status_code} - {response.text}")
                    return {"success": False, "message": f"获取任务状态失败: {response.status_code}"}
                return response.json()
        except Exception as e:
            logger.error(f"VectorEngine 获取任务状态异常: {str(e)}")
            return {"success": False, "message": str(e)}

    @staticmethod
    def _parse_ve_progress(progress):
        """Convert a VectorEngine progress value (number or string like '42%') to an int.

        Unparseable values (None, '', 'nan', inf, ...) yield 0.
        """
        try:
            if isinstance(progress, str):
                return int(float(progress.strip().rstrip('%')))
            return int(progress)
        except (ValueError, TypeError, OverflowError):
            return 0

    async def ve_generate_image(self, prompt: str, config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]:
        """Generate images through the VectorEngine API and stream the results.

        Args:
            prompt: Prompt text.
            config: Optional dict with:
                - base64_array: base64-encoded reference images
                - notify_hook: callback URL
                - state: caller-defined state
                - bot_type: bot type
                - split_image: split the result into quadrants (default True)

        Yields:
            ``{"status": "submitted", "task_id"}``, then progress dicts, then a
            final success or error dict. Polls up to
            ``settings.vectorengine_max_polling_attempts`` times,
            ``settings.vectorengine_polling_interval`` seconds apart.
        """
        if not config:
            config = {}

        if not settings.vectorengine_token:
            yield {
                "status": "error",
                "message": "未配置 VectorEngine Token",
                "success_count": 0,
            }
            return

        submit_result = await self.ve_submit_imagine(
            prompt=prompt,
            base64_array=config.get("base64_array", []),
            notify_hook=config.get("notify_hook", ""),
            state=config.get("state", ""),
            bot_type=config.get("bot_type", "MID_JOURNEY"),
        )

        if not submit_result.get("success", True) and "result" not in submit_result:
            yield {
                "status": "error",
                "message": submit_result.get("message", "提交任务失败"),
                "success_count": 0,
            }
            return

        task_id = submit_result.get("result")
        if not task_id:
            yield {
                "status": "error",
                "message": "未获取到任务ID",
                "success_count": 0,
            }
            return

        logger.info(f"VectorEngine 任务已提交,任务ID: {task_id}")
        yield {
            "status": "submitted",
            "task_id": task_id,
        }

        polling_count = 0
        last_progress = 0
        while polling_count < settings.vectorengine_max_polling_attempts:
            await asyncio.sleep(settings.vectorengine_polling_interval)
            polling_count += 1

            task_status = await self.ve_get_task_status(task_id)
            if not task_status:
                continue

            status = task_status.get("status")
            progress_val = self._parse_ve_progress(task_status.get("progress", 0))
            if progress_val > last_progress:
                last_progress = progress_val
                yield {
                    "status": "progress",
                    "progress": last_progress,
                    "task_id": task_id,
                }

            if status == "SUCCESS":
                image_url = task_status.get("imageUrl")
                response = {
                    "status": "success",
                    "success_count": 1,
                    "task_id": task_id,
                    "image_urls": [image_url] if image_url else [],
                }
                if config.get("split_image", True) and image_url:
                    try:
                        split_urls = await self.split_image(image_url)
                        if split_urls:
                            response["image_urls"] = split_urls
                            logger.info(f"图片分割成功,生成 {len(split_urls)} 张子图")
                    except Exception as e:
                        # Splitting is best-effort here: keep the original image URL.
                        logger.error(f"分割图像失败: {str(e)}")
                yield response
                return

            if status in _VE_FAILURE_STATUSES:
                yield {
                    "status": "error",
                    # failReason may be present but null/empty; fall back to the default text.
                    "message": task_status.get("failReason") or "任务失败",
                    "task_id": task_id,
                    "success_count": 0,
                }
                return

            if status == "CANCELLED":
                yield {
                    "status": "error",
                    "message": "任务已取消",
                    "task_id": task_id,
                    "success_count": 0,
                }
                return

        yield {
            "status": "error",
            "message": "任务超时",
            "task_id": task_id,
            "success_count": 0,
        }