diff --git a/apps/jmidjourney/service.py b/apps/jmidjourney/service.py index ba4038d..0ec6dcf 100644 --- a/apps/jmidjourney/service.py +++ b/apps/jmidjourney/service.py @@ -1,14 +1,8 @@ import json -import sys import os import io import requests -from requests.adapters import HTTPAdapter -from urllib3.util.retry import Retry -import random -import re import uuid -import urllib.request import urllib3 import logging from pathlib import Path @@ -18,7 +12,6 @@ import asyncio import httpx from typing import Dict, Any, List, AsyncGenerator, Optional from settings import settings -from utils import get_new_image_url, is_valid_image_url # 设置日志记录器 @@ -27,15 +20,10 @@ logger = logging.getLogger("midjourney_service") # 禁用不安全请求警告 urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + class MidjourneyService: def __init__(self): - """初始化MidjourneyService""" - self.API_URL = settings.midjourney_api_url - self.APPLICATION_ID = settings.midjourney_application_id - self.DATA_ID = settings.midjourney_data_id - self.DATA_VERSION = settings.midjourney_data_version - self.SESSION_ID = settings.midjourney_session_id - + """初始化 MidjourneyService""" # 设置代理 self.proxies = {} if settings.http_proxy and settings.http_proxy.strip(): @@ -46,204 +34,16 @@ class MidjourneyService: # 确保保存目录存在 os.makedirs(settings.save_dir, exist_ok=True) - @staticmethod - def first_where(array, key, value=None): - """在数组中找到第一个匹配条件的项""" - for item in array: - if callable(key) and key(item): - return item - if isinstance(key, str) and item[key].startswith(value): - return item - return None - - async def initialize_client(self, oauth_token, channel_id): - """初始化Discord客户端会话""" - client = requests.Session() + async def split_image(self, image_url: str) -> Optional[List[str]]: + """ + 将一张大图切割成四张子图,保存到本地并返回 URLs - retry = Retry(total=3, backoff_factor=0.5) - adapter = HTTPAdapter(max_retries=retry) - client.mount('http://', adapter) - client.mount('https://', adapter) - - client.headers.update({ - 'Authorization': oauth_token - }) - - if self.proxies: - client.proxies.update(self.proxies) - - try: - response = client.get(f'{self.API_URL}/channels/{channel_id}') - data = response.json() - guild_id = data['guild_id'] + Args: + image_url: 图片URL - response = client.get(f'{self.API_URL}/users/@me') - data = response.json() - user_id = data['id'] - - return client, guild_id, user_id - except Exception as e: - logger.error(f"初始化Discord客户端失败: {str(e)}") - raise Exception(f"初始化Discord客户端失败: {str(e)}") - - async def imagine(self, client, guild_id, channel_id, prompt, seed=None): - """发送imagine请求到Discord""" - params = { - 'type': 2, - 'application_id': self.APPLICATION_ID, - 'guild_id': guild_id, - 'channel_id': channel_id, - 'session_id': self.SESSION_ID, - 'data': { - 'version': self.DATA_VERSION, - 'id': self.DATA_ID, - 'name': 'imagine', - 'type': 1, - 'options': [{ - 'type': 3, - 'name': 'prompt', - 'value': prompt - }], - 'application_command': { - 'id': self.DATA_ID, - 'application_id': self.APPLICATION_ID, - 'version': self.DATA_VERSION, - 'default_member_permissions': None, - 'type': 1, - 'nsfw': False, - 'name': 'imagine', - 'description': 'Create images with Midjourney', - 'dm_permission': True, - 'options': [{ - 'type': 3, - 'name': 'prompt', - 'description': 'The prompt to imagine', - 'required': True - }] - }, - 'attachments': [] - } - } - - try: - client.post(f'{self.API_URL}/interactions', json=params) - await asyncio.sleep(20) - - imagine_message = None - count = 0 - last_progress = 0 - - while count < settings.max_polling_attempts: - progress_msg = await self.get_progress_message(client, channel_id, prompt, seed) - if progress_msg: - content = progress_msg.get('content', '') - progress_match = re.search(r'(?:\(|()?(\d+)%(?:\)|))?', content) - if progress_match: - progress_value = int(progress_match.group(1)) - if progress_value > last_progress: - last_progress = progress_value - logger.info(f"生成进度: {progress_value}%") - yield { - "status": "progress", - "progress": progress_value, - "message_id": progress_msg.get('id'), - "content": content - } - - imagine_message = await self.get_imagine(client, channel_id, prompt, count, seed) - if imagine_message: - break - - await asyncio.sleep(settings.polling_interval) - count += 1 - - if count >= settings.max_polling_attempts: - logger.error(f"轮询超过最大尝试次数: {settings.max_polling_attempts}") - yield {"status": "error", "message": "获取结果超时,超过最大轮询次数"} - return - - if not imagine_message: - logger.error("轮询结束但没有获取到有效结果") - yield {"status": "error", "message": "没有获取到有效结果"} - return - - logger.info(f"成功获取图像结果,消息ID: {imagine_message.get('id')}") - yield { - "status": "success", - "message_id": imagine_message.get('id'), - "content": imagine_message.get('content', ''), - "images": self.extract_image_urls(imagine_message) - } - - except Exception as e: - logger.error(f"发送Imagine请求失败: {str(e)}") - yield {"status": "error", "message": f"发送Imagine请求失败: {str(e)}"} - - async def get_imagine(self, client, channel_id, prompt, count=0, seed=None): - """获取生成图像的消息""" - try: - response = client.get(f'{self.API_URL}/channels/{channel_id}/messages?limit=10') - data = response.json() - - def criteria(item): - content = item.get('content', '') - - if seed is not None and f"--seed {seed}" in content: - progress_match = re.search(r'(?:\(|()?(\d+)%(?:\)|))?', content) - if progress_match: - progress_value = int(progress_match.group(1)) - logger.info(f"任务进度: {progress_value}%") - if "%" in content: - return False - - if "%" in content: - return False - - if seed is not None: - seed_pattern = f"--seed {seed}" - if seed_pattern not in content: - return False - else: - if 'attachments' in item and len(item.get('attachments', [])) > 0: - return True - return False - - return self.first_where(data, criteria) - - except Exception as e: - logger.error(f"获取通道消息失败: {str(e)}") - return None - - async def get_progress_message(self, client, channel_id, prompt, seed=None): - """获取包含进度信息的消息""" - try: - response = client.get(f'{self.API_URL}/channels/{channel_id}/messages?limit=10') - data = response.json() - - for item in data: - content = item.get('content', '') - if seed is not None and f"--seed {seed}" in content and "%" in content: - progress_match = re.search(r'(?:\(|()?(\d+)%(?:\)|))?', content) - if progress_match: - return item - - return None - - except Exception as e: - logger.error(f"获取进度消息失败: {str(e)}") - return None - - def extract_image_urls(self, message): - """从消息中提取图像URL""" - image_urls = [] - if 'attachments' in message and message['attachments']: - for attachment in message['attachments']: - if 'url' in attachment: - image_urls.append(attachment['url']) - return image_urls - - async def split_image(self, image_url): - """将一张大图切割成四张子图,保存到本地并返回URLs""" + Returns: + 分割后的图片URL列表,失败返回 None + """ try: response = requests.get(image_url, proxies=self.proxies if self.proxies else None, timeout=30) if response.status_code != 200: @@ -262,8 +62,8 @@ class MidjourneyService: original_format = 'PNG' if width < 500 or height < 500: - logger.error(f"图像尺寸不符合预期: {width}x{height}, 应该接近2048x2048") - return None + logger.warning(f"图像尺寸较小: {width}x{height},跳过分割") + return [image_url] half_width = width // 2 half_height = height // 2 @@ -306,6 +106,7 @@ class MidjourneyService: logger.info(f"成功保存分割图片 {i}/4: {filename}") except Exception as e: logger.error(f"保存分割图片 {i}/4 失败: {str(e)}") + # 清理已保存的文件 for url in image_urls: try: file_path = os.path.join(save_dir, os.path.basename(url)) @@ -326,207 +127,30 @@ class MidjourneyService: logger.error(f"分割图像失败: {str(e)}") return None - async def split_images(self, image_urls: List[str], config: Optional[Dict] = None): - """批量处理多个图像URL""" - if not config: - config = {} - - total = len(image_urls) - success_count = 0 - - for i, image_url in enumerate(image_urls, 1): - try: - if not is_valid_image_url(image_url): - yield { - "status": "error", - "index": i, - "total": total, - "success_count": success_count, - "message": "无效的图片URL" - } - continue - - split_urls = await self.split_image(image_url) - if split_urls and len(split_urls) == 4: - success_count += 1 - yield { - "status": "success", - "index": i, - "total": total, - "success_count": success_count, - "images": split_urls - } - else: - yield { - "status": "error", - "index": i, - "total": total, - "success_count": success_count, - "message": "分割图片失败" - } - except Exception as e: - yield { - "status": "error", - "index": i, - "total": total, - "success_count": success_count, - "message": str(e) - } + # ==================== VectorEngine API 方法 ==================== - async def generate_image(self, prompt, config=None): - """生成图像并以流式方式返回结果""" - if not config: - config = {} - - oauth_token = config.get('oauth_token', settings.midjourney_oauth_token) - channel_id = config.get('channel_id', settings.midjourney_channel_id) - - if not oauth_token or not channel_id: - yield { - "status": "error", - "message": "缺少Discord配置", - "success_count": 0 - } - return - - try: - client, guild_id, user_id = await self.initialize_client(oauth_token, channel_id) - except Exception as e: - yield { - "status": "error", - "message": str(e), - "success_count": 0 - } - return - - prompt = prompt.strip() - options = settings.midjourney_default_options.copy() - if 'options' in config: - options.update(config.get('options', {})) - - if "seed" not in options or not options.get('seed'): - options['seed'] = random.randint(0, 4294967295) - - if 'image_weight' in config and isinstance(config['image_weight'], (int, float)): - iw = max(0.1, min(config['image_weight'], 3)) - options['iw'] = iw - - seed = options.pop('seed', None) - - parameter = "" - no_value_key = ['relax', 'fast', 'turbo', 'tile'] - for key, value in options.items(): - if key in no_value_key: - parameter += f" --{key}" - else: - parameter += f" --{key} {value}" - - if seed: - parameter += f" --seed {seed}" - logger.info(f"使用的seed值: {seed}") - - if 'image_urls' in config and config['image_urls']: - image_urls = config['image_urls'] - if not isinstance(image_urls, list): - image_urls = [image_urls] - - new_image_urls = [] - for image_url in image_urls: - if is_valid_image_url(image_url): - try: - new_url = get_new_image_url(image_url) - if new_url: - new_image_urls.append(new_url) - except Exception as e: - logger.warning(f"转换图片URL失败: {str(e)}") - - if new_image_urls: - prompt = " ".join(new_image_urls) + " " + prompt - - if 'characters' in config and config['characters']: - char_urls = [] - characters = config['characters'] if isinstance(config['characters'], list) else [config['characters']] - - for char_url in characters: - if is_valid_image_url(char_url): - try: - new_url = get_new_image_url(char_url) - if new_url: - char_urls.append(new_url) - except Exception as e: - logger.warning(f"转换角色图片URL失败: {str(e)}") - - if char_urls: - prompt = prompt + " --cref " + " ".join(char_urls) - - if 'styles' in config and config['styles']: - style_urls = [] - styles = config['styles'] if isinstance(config['styles'], list) else [config['styles']] - - for style_url in styles: - if is_valid_image_url(style_url): - try: - new_url = get_new_image_url(style_url) - if new_url: - style_urls.append(new_url) - except Exception as e: - logger.warning(f"转换风格图片URL失败: {str(e)}") - - if style_urls: - prompt = prompt + " --sref " + " ".join(style_urls) - - prompt = prompt + parameter - success_count = 0 - - async for result in self.imagine(client, guild_id, channel_id, prompt, seed): - if result.get("status") == "progress": - yield { - "status": "progress", - "progress": result.get("progress", 0), - "seed": seed - } - elif result.get("status") == "success": - success_count += 1 - response = { - "status": "success", - "success_count": success_count, - "image_urls": result.get("images", []) - } - - if config.get("split_image", True) and result.get("images"): - try: - orig_image_url = result["images"][0] - split_urls = await self.split_image(orig_image_url) - if split_urls: - response["image_urls"] = split_urls - except Exception as e: - logger.error(f"分割图像失败: {str(e)}") - response["status"] = "error" - response["message"] = f"分割图像失败: {str(e)}" - - yield response - else: - yield { - "status": "error", - "message": result.get("message", "未知错误"), - "success_count": success_count - } - - async def ve_submit_imagine(self, prompt: str, base64_array: List[str] = None, - notify_hook: str = "", state: str = "", - bot_type: str = "MID_JOURNEY") -> Dict[str, Any]: + async def ve_submit_imagine( + self, + prompt: str, + base64_array: List[str] = None, + notify_hook: str = "", + state: str = "", + bot_type: str = "MID_JOURNEY" + ) -> Dict[str, Any]: """ 提交 imagine 任务到 VectorEngine API + API 文档: https://vectorengine.apifox.cn/api-349239131 + Args: prompt: 提示词 - base64_array: base64 编码的图片数组 + base64_array: base64 编码的图片数组(垫图) notify_hook: 回调地址 state: 自定义状态 bot_type: 机器人类型,默认 MID_JOURNEY Returns: - 包含任务ID的响应 + {"result": "任务ID"} 或 {"success": False, "message": "错误信息"} """ if base64_array is None: base64_array = [] @@ -553,8 +177,9 @@ class MidjourneyService: ) if response.status_code != 200: - logger.error(f"VectorEngine API 请求失败: {response.status_code} - {response.text}") - return {"success": False, "message": f"API请求失败: {response.status_code}"} + error_text = response.text + logger.error(f"VectorEngine 提交任务失败: {response.status_code} - {error_text}") + return {"success": False, "message": f"API请求失败: {response.status_code} - {error_text}"} result = response.json() logger.info(f"VectorEngine 提交任务成功: {result}") @@ -564,15 +189,25 @@ class MidjourneyService: logger.error(f"VectorEngine 提交任务异常: {str(e)}") return {"success": False, "message": str(e)} - async def ve_get_task_status(self, task_id: str) -> Dict[str, Any]: + async def ve_get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]: """ - 获取 VectorEngine 任务状态 + 根据任务ID查询任务状态 + + API 文档: https://vectorengine.apifox.cn/api-349239132 Args: task_id: 任务ID Returns: - 任务状态信息 + 任务状态对象,包含: + - id: 任务ID + - action: 动作类型 (如 IMAGINE) + - status: 任务状态 (SUCCESS, IN_PROGRESS, FAILURE 等) + - progress: 进度字符串 (如 "100%") + - imageUrl: 生成的图片URL + - buttons: 操作按钮列表 + - failReason: 失败原因 + 等字段,失败返回 None """ headers = { "Authorization": f"Bearer {settings.vectorengine_token}", @@ -588,30 +223,109 @@ class MidjourneyService: if response.status_code != 200: logger.error(f"VectorEngine 获取任务状态失败: {response.status_code} - {response.text}") - return {"success": False, "message": f"获取任务状态失败: {response.status_code}"} + return None result = response.json() return result except Exception as e: logger.error(f"VectorEngine 获取任务状态异常: {str(e)}") - return {"success": False, "message": str(e)} + return None - async def ve_generate_image(self, prompt: str, config: Dict = None) -> AsyncGenerator[Dict, None]: + async def ve_get_task_list(self, task_ids: List[str]) -> Optional[List[Dict[str, Any]]]: + """ + 根据ID列表查询任务 + + API 文档: https://vectorengine.apifox.cn/api-349239133 + + Args: + task_ids: 任务ID列表 + + Returns: + 任务对象列表,失败返回 None + """ + headers = { + "Authorization": f"Bearer {settings.vectorengine_token}", + "Content-Type": "application/json" + } + + try: + async with httpx.AsyncClient(timeout=60.0) as client: + response = await client.post( + f"{settings.vectorengine_api_url}/mj/task/list-by-condition", + json={"ids": task_ids}, + headers=headers + ) + + if response.status_code != 200: + logger.error(f"VectorEngine 批量查询任务失败: {response.status_code} - {response.text}") + return None + + result = response.json() + return result + + except Exception as e: + logger.error(f"VectorEngine 批量查询任务异常: {str(e)}") + return None + + async def ve_get_image_seed(self, task_id: str) -> Optional[Dict[str, Any]]: + """ + 获取任务图片的seed + + API 文档: https://vectorengine.apifox.cn/api-349239134 + + Args: + task_id: 任务ID + + Returns: + 包含 seed 信息的任务对象,失败返回 None + """ + headers = { + "Authorization": f"Bearer {settings.vectorengine_token}", + "Content-Type": "application/json" + } + + try: + async with httpx.AsyncClient(timeout=60.0) as client: + response = await client.get( + f"{settings.vectorengine_api_url}/mj/task/{task_id}/image-seed", + headers=headers + ) + + if response.status_code != 200: + logger.error(f"VectorEngine 获取seed失败: {response.status_code} - {response.text}") + return None + + result = response.json() + return result + + except Exception as e: + logger.error(f"VectorEngine 获取seed异常: {str(e)}") + return None + + async def ve_generate_image( + self, + prompt: str, + config: Dict = None + ) -> AsyncGenerator[Dict, None]: """ 使用 VectorEngine API 生成图像并以流式方式返回结果 Args: prompt: 提示词 config: 配置参数,可包含: - - base64_array: base64 编码的图片数组 + - base64_array: base64 编码的图片数组(垫图) - notify_hook: 回调地址 - state: 自定义状态 - bot_type: 机器人类型 - split_image: 是否分割图片 (默认 True) Yields: - 生成状态和结果 + 生成状态和结果: + - {"status": "submitted", "task_id": "..."} # 任务已提交 + - {"status": "progress", "progress": 50, "task_id": "..."} # 进度更新 + - {"status": "success", "image_urls": [...], ...} # 成功完成 + - {"status": "error", "message": "..."} # 错误 """ if not config: config = {} @@ -619,12 +333,11 @@ class MidjourneyService: if not settings.vectorengine_token: yield { "status": "error", - "message": "未配置 VectorEngine Token", - "success_count": 0 + "message": "未配置 VectorEngine Token" } return - - # 提交任务 + + # 1. 提交任务 submit_result = await self.ve_submit_imagine( prompt=prompt, base64_array=config.get("base64_array", []), @@ -633,21 +346,12 @@ class MidjourneyService: bot_type=config.get("bot_type", "MID_JOURNEY") ) - if not submit_result.get("success", True) and "result" not in submit_result: - yield { - "status": "error", - "message": submit_result.get("message", "提交任务失败"), - "success_count": 0 - } - return - - # 获取任务ID + # 检查提交结果 - VectorEngine 返回 {"result": "任务ID"} 表示成功 task_id = submit_result.get("result") if not task_id: yield { "status": "error", - "message": "未获取到任务ID", - "success_count": 0 + "message": submit_result.get("message", "提交任务失败,未获取到任务ID") } return @@ -656,10 +360,10 @@ class MidjourneyService: "status": "submitted", "task_id": task_id } - - # 轮询任务状态 + + # 2. 轮询任务状态 polling_count = 0 - last_progress = 0 + last_progress = -1 while polling_count < settings.vectorengine_max_polling_attempts: await asyncio.sleep(settings.vectorengine_polling_interval) @@ -668,26 +372,27 @@ class MidjourneyService: task_status = await self.ve_get_task_status(task_id) if not task_status: + logger.warning(f"第 {polling_count} 次轮询未获取到任务状态") continue - status = task_status.get("status") - progress = task_status.get("progress", 0) + status = task_status.get("status", "") + progress_str = task_status.get("progress", "0%") - # 解析进度值(可能是数字或带%的字符串如 '0%') + # 解析进度值(可能是数字或带%的字符串如 '0%', '100%') try: - if isinstance(progress, str): - progress_val = int(progress.strip().rstrip('%')) + if isinstance(progress_str, str): + progress_val = int(progress_str.strip().rstrip('%')) else: - progress_val = int(progress) + progress_val = int(progress_str) except (ValueError, TypeError): progress_val = 0 - # 发送进度更新 + # 发送进度更新(仅当进度有变化时) if progress_val > last_progress: last_progress = progress_val yield { "status": "progress", - "progress": last_progress, + "progress": progress_val, "task_id": task_id } @@ -696,37 +401,33 @@ class MidjourneyService: image_url = task_status.get("imageUrl") image_urls = [image_url] if image_url else [] - # 检查是否有多个图片URL(对于网格图) - if "buttons" in task_status: - # 可能包含 U1, U2, U3, U4 按钮对应的图片 - pass - response = { "status": "success", - "success_count": 1, "task_id": task_id, - "image_urls": image_urls + "image_urls": image_urls, + "buttons": task_status.get("buttons", []), + "properties": task_status.get("properties", {}) } - # 如果配置了分割图片,则处理分割 + # 如果配置了分割图片且有图片URL,则处理分割 if config.get("split_image", True) and image_url: try: split_urls = await self.split_image(image_url) - if split_urls: + if split_urls and len(split_urls) == 4: response["image_urls"] = split_urls - logger.info(f"图片分割成功,生成 {len(split_urls)} 张子图") + logger.info(f"图片分割成功,生成 4 张子图") except Exception as e: logger.error(f"分割图像失败: {str(e)}") yield response return - elif status == "FAILURE" or status == "FAILED": + elif status in ("FAILURE", "FAILED"): + fail_reason = task_status.get("failReason", "任务失败") yield { "status": "error", - "message": task_status.get("failReason", "任务失败"), - "task_id": task_id, - "success_count": 0 + "message": fail_reason, + "task_id": task_id } return @@ -734,15 +435,16 @@ class MidjourneyService: yield { "status": "error", "message": "任务已取消", - "task_id": task_id, - "success_count": 0 + "task_id": task_id } return - + + # 其他状态(如 IN_PROGRESS, SUBMITTED)继续轮询 + logger.debug(f"任务 {task_id} 状态: {status}, 进度: {progress_str}") + # 超时 yield { "status": "error", - "message": "任务超时", - "task_id": task_id, - "success_count": 0 - } \ No newline at end of file + "message": f"任务超时,已轮询 {settings.vectorengine_max_polling_attempts} 次", + "task_id": task_id + } diff --git a/apps/jmidjourney/settings.py b/apps/jmidjourney/settings.py index 54a69a0..e979c5f 100644 --- a/apps/jmidjourney/settings.py +++ b/apps/jmidjourney/settings.py @@ -1,7 +1,8 @@ from pydantic_settings import BaseSettings -from typing import Optional, Dict +from typing import Optional from functools import lru_cache + class Settings(BaseSettings): # Japi Server 配置 host: str = "0.0.0.0" @@ -10,7 +11,7 @@ class Settings(BaseSettings): # API路由配置 router_prefix: str = "/jmidjourney" - generate_route: str = "/generate" # 生成图片的路由 + generate_route: str = "/generate" # 保留用于兼容 ve_generate_route: str = "/ve/generate" # VectorEngine 生成图片的路由 api_name: str = "jmidjourney" # 默认API名称 @@ -32,41 +33,21 @@ class Settings(BaseSettings): jingrow_api_key: Optional[str] = None jingrow_api_secret: Optional[str] = None - # Discord Midjourney配置 - midjourney_api_url: str = "https://discord.com/api/v9" - midjourney_application_id: str = "936929561302675456" - midjourney_data_id: str = "938956540159881230" - midjourney_data_version: str = "1237876415471554623" - midjourney_session_id: str = "a64ede0f3ce497d949e2f6f195c19029" - midjourney_channel_id: str = "1259838588510670941" - midjourney_oauth_token: str = "MTA4NzQ0MDY0MTU5MzcxNjc0Ng.GVDauj.6Cwr5EpXOfN9FpQU0-VfteR56XQOwLLUGYovG0" - midjourney_suffix: str = "mj" # 图片文件名的后缀 - - # 代理配置 - http_proxy: Optional[str] = "http://127.0.0.1:1080" # 默认HTTP代理 - https_proxy: Optional[str] = "http://127.0.0.1:1080" # 默认HTTPS代理 - - # Midjourney默认选项 - midjourney_default_options: Dict = { - "ar": "1:1", - "v": "6.1", - "quality": "1" - } - - # 图像设置 - add_title_to_image_name: bool = False # 是否将标题添加到图片名称中 + # 代理配置(用于下载外部图片) + http_proxy: Optional[str] = None + https_proxy: Optional[str] = None # 超时设置(秒) request_timeout: int = 30 - max_polling_attempts: int = 60 - polling_interval: int = 3 class Config: env_file = ".env" + @lru_cache() def get_settings() -> Settings: return Settings() + # 创建全局配置实例 -settings = get_settings() \ No newline at end of file +settings = get_settings()