diff --git a/apps/midjourney/service.py b/apps/midjourney/service.py index 8599a08..1688e50 100644 --- a/apps/midjourney/service.py +++ b/apps/midjourney/service.py @@ -5,19 +5,15 @@ import io import requests from requests.adapters import HTTPAdapter from urllib3.util.retry import Retry -import time import random import re import uuid import urllib.request import urllib3 -import traceback import logging -import base64 from pathlib import Path from urllib.parse import urlparse from PIL import Image -import mimetypes import asyncio from typing import Dict, Any, List, AsyncGenerator, Optional from settings import settings @@ -63,7 +59,6 @@ class MidjourneyService: """初始化Discord客户端会话""" client = requests.Session() - # 添加重试机制 retry = Retry(total=3, backoff_factor=0.5) adapter = HTTPAdapter(max_retries=retry) client.mount('http://', adapter) @@ -73,17 +68,14 @@ class MidjourneyService: 'Authorization': oauth_token }) - # 设置代理 if self.proxies: client.proxies.update(self.proxies) try: - # 获取频道信息 response = client.get(f'{self.API_URL}/channels/{channel_id}') data = response.json() guild_id = data['guild_id'] - # 获取用户信息 response = client.get(f'{self.API_URL}/users/@me') data = response.json() user_id = data['id'] @@ -91,7 +83,6 @@ class MidjourneyService: return client, guild_id, user_id except Exception as e: logger.error(f"初始化Discord客户端失败: {str(e)}") - traceback.print_exc() raise Exception(f"初始化Discord客户端失败: {str(e)}") async def imagine(self, client, guild_id, channel_id, prompt, seed=None): @@ -134,28 +125,20 @@ class MidjourneyService: } try: - # 发送请求 - r = client.post(f'{self.API_URL}/interactions', json=params) - # 初始等待时间从5秒延长到30秒,给Discord足够的时间开始处理请求 - print(f"[生成] 已发送请求,等待20秒后开始轮询...") + client.post(f'{self.API_URL}/interactions', json=params) await asyncio.sleep(20) - # 轮询获取结果 imagine_message = None count = 0 last_progress = 0 - # 轮询直到获取完整的结果或达到最大次数 while count < settings.max_polling_attempts: - # 轮询两种消息: - # 1. 进度消息 - 用于更新进度 progress_msg = await self.get_progress_message(client, channel_id, prompt, seed) if progress_msg: content = progress_msg.get('content', '') progress_match = re.search(r'(?:\(|()?(\d+)%(?:\)|))?', content) if progress_match: progress_value = int(progress_match.group(1)) - # 只有进度有变化时才发送更新 if progress_value > last_progress: last_progress = progress_value logger.info(f"生成进度: {progress_value}%") @@ -166,30 +149,23 @@ class MidjourneyService: "content": content } - # 2. 完成的消息 - 包含图像结果 imagine_message = await self.get_imagine(client, channel_id, prompt, count, seed) if imagine_message: - # 找到最终结果 break - # 没有找到结果,继续等待 - logger.info(f"轮询尝试 {count+1}/{settings.max_polling_attempts}: 继续等待") await asyncio.sleep(settings.polling_interval) count += 1 - # 检查是否超过最大轮询次数 if count >= settings.max_polling_attempts: logger.error(f"轮询超过最大尝试次数: {settings.max_polling_attempts}") yield {"status": "error", "message": "获取结果超时,超过最大轮询次数"} return - # 检查是否有有效的最终结果 if not imagine_message: logger.error("轮询结束但没有获取到有效结果") yield {"status": "error", "message": "没有获取到有效结果"} return - # 返回最终结果 logger.info(f"成功获取图像结果,消息ID: {imagine_message.get('id')}") yield { "status": "success", @@ -200,68 +176,38 @@ class MidjourneyService: except Exception as e: logger.error(f"发送Imagine请求失败: {str(e)}") - traceback.print_exc() yield {"status": "error", "message": f"发送Imagine请求失败: {str(e)}"} async def get_imagine(self, client, channel_id, prompt, count=0, seed=None): """获取生成图像的消息""" try: - # 获取最近的消息 response = client.get(f'{self.API_URL}/channels/{channel_id}/messages?limit=10') data = response.json() def criteria(item): content = item.get('content', '') - # 检查进度信息并更新状态 if seed is not None and f"--seed {seed}" in content: - # 匹配百分比,支持多种格式:(93%) 或 93% progress_match = re.search(r'(?:\(|()?(\d+)%(?:\)|))?', content) if progress_match: progress_value = int(progress_match.group(1)) - # 记录进度信息 logger.info(f"任务进度: {progress_value}%") - # 如果消息包含百分比,说明任务还在进行中,返回False继续轮询 if "%" in content: - print(f"[轮询] 消息ID: {item.get('id')} 内容: {content[:100]}... 包含进度信息{progress_value}%,继续等待完成") return False - # 排除进行中消息,只匹配完成的消息 if "%" in content: return False - # seed 匹配 if seed is not None: seed_pattern = f"--seed {seed}" - # 检查是否包含指定的seed if seed_pattern not in content: return False else: - print(f"[轮询] 消息ID: {item.get('id')} 内容: {content[:100]}... 匹配seed并且任务已完成") - # 检查是否包含图像附件 if 'attachments' in item and len(item.get('attachments', [])) > 0: return True - return False # 默认不匹配 + return False - # 查找匹配完成状态的消息 - raw_message = self.first_where(data, criteria) - if raw_message is None: - print("[轮询] 没有找到完成的消息,继续等待") - return None - - # 打印匹配到的完整消息内容 - print(f"[轮询] 匹配到完成的消息ID: {raw_message.get('id')}") - try: - print(f"[轮询] 完整消息内容: {json.dumps(raw_message, indent=2, ensure_ascii=False)}") - except Exception as e: - print(f"[轮询] 无法打印完整消息内容: {str(e)}") - # 尝试打印关键字段 - print(f"[轮询] 消息内容: {raw_message.get('content', '无内容')}") - print(f"[轮询] 附件数量: {len(raw_message.get('attachments', []))}") - for i, attachment in enumerate(raw_message.get('attachments', [])): - print(f"[轮询] 附件 {i+1} URL: {attachment.get('url', '无URL')}") - - return raw_message + return self.first_where(data, criteria) except Exception as e: logger.error(f"获取通道消息失败: {str(e)}") @@ -270,13 +216,11 @@ class MidjourneyService: async def get_progress_message(self, client, channel_id, prompt, seed=None): """获取包含进度信息的消息""" try: - # 获取最近的消息 response = client.get(f'{self.API_URL}/channels/{channel_id}/messages?limit=10') data = response.json() for item in data: content = item.get('content', '') - # 检查是否包含seed和进度信息 if seed is not None and f"--seed {seed}" in content and "%" in content: progress_match = re.search(r'(?:\(|()?(\d+)%(?:\)|))?', content) if progress_match: @@ -298,69 +242,48 @@ class MidjourneyService: return image_urls async def split_image(self, image_url): - """将一张大图切割成四张子图,保存到本地并返回URLs - - Args: - image_url: 原始图像的URL - - Returns: - 包含四个子图像URL的列表,如果处理失败则返回None - """ + """将一张大图切割成四张子图,保存到本地并返回URLs""" try: - # 下载图像 response = requests.get(image_url, proxies=self.proxies if self.proxies else None, timeout=30) if response.status_code != 200: logger.error(f"下载图像失败,状态码: {response.status_code}") return None image_data = response.content - - # 从二进制数据创建图像对象 img = Image.open(io.BytesIO(image_data)) width, height = img.size - # 获取原始图片格式 original_format = img.format if not original_format: - # 如果无法获取格式,从URL中提取 parsed_url = urlparse(image_url) original_format = os.path.splitext(parsed_url.path)[1][1:].upper() if not original_format: - original_format = 'PNG' # 默认使用PNG + original_format = 'PNG' - # 确认图像尺寸约为2048x2048 if width < 1500 or height < 1500: logger.error(f"图像尺寸不符合预期: {width}x{height}, 应该接近2048x2048") return None - # 计算每个象限的尺寸 half_width = width // 2 half_height = height // 2 - # 分割图像为四个象限 quadrants = [ - img.crop((0, 0, half_width, half_height)), # 左上 - img.crop((half_width, 0, width, half_height)), # 右上 - img.crop((0, half_height, half_width, height)), # 左下 - img.crop((half_width, half_height, width, height)) # 右下 + img.crop((0, 0, half_width, half_height)), + img.crop((half_width, 0, width, half_height)), + img.crop((0, half_height, half_width, height)), + img.crop((half_width, half_height, width, height)) ] - # 生成唯一的图片名称前缀 image_id = uuid.uuid4().hex[:10] - - # 确保保存目录存在(使用绝对路径) save_dir = os.path.abspath(settings.save_dir) os.makedirs(save_dir, exist_ok=True) - # 保存图片到本地并生成URLs image_urls = [] for i, quadrant in enumerate(quadrants, 1): try: - # 使用原始格式保存 filename = f"split_{image_id}_{i}.{original_format.lower()}" file_path = os.path.join(save_dir, filename) - # 保存图片,保持原始格式 save_params = {"format": original_format} if original_format in ['PNG', 'JPEG', 'JPG']: save_params["optimize"] = True @@ -369,23 +292,19 @@ class MidjourneyService: quadrant.save(file_path, **save_params) - # 验证文件是否成功保存 if not os.path.exists(file_path): raise Exception(f"文件保存失败: {file_path}") - # 验证文件大小 file_size = os.path.getsize(file_path) if file_size == 0: raise Exception(f"保存的文件大小为0: {file_path}") - # 构建图片URL image_url = f"{settings.download_url}/{filename}" image_urls.append(image_url) - logger.info(f"成功保存分割图片 {i}/4: {filename} (大小: {file_size} 字节, 格式: {original_format})") + logger.info(f"成功保存分割图片 {i}/4: {filename}") except Exception as e: logger.error(f"保存分割图片 {i}/4 失败: {str(e)}") - # 如果保存失败,删除已保存的图片 for url in image_urls: try: file_path = os.path.join(save_dir, os.path.basename(url)) @@ -399,12 +318,11 @@ class MidjourneyService: logger.error(f"分割图片数量不正确: 期望4张,实际{len(image_urls)}张") return None - logger.info(f"成功完成图片分割,生成4张子图") + logger.info("成功完成图片分割,生成4张子图") return image_urls except Exception as e: logger.error(f"分割图像失败: {str(e)}") - traceback.print_exc() return None async def split_images(self, image_urls: List[str], config: Optional[Dict] = None): @@ -418,95 +336,76 @@ class MidjourneyService: for i, image_url in enumerate(image_urls, 1): try: if not is_valid_image_url(image_url): - response = { + yield { "status": "error", "index": i, "total": total, "success_count": success_count, "message": "无效的图片URL" } - print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}") - yield response continue split_urls = await self.split_image(image_url) if split_urls and len(split_urls) == 4: success_count += 1 - response = { + yield { "status": "success", "index": i, "total": total, "success_count": success_count, "images": split_urls } - print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}") - yield response else: - response = { + yield { "status": "error", "index": i, "total": total, "success_count": success_count, "message": "分割图片失败" } - print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}") - yield response except Exception as e: - response = { + yield { "status": "error", "index": i, "total": total, "success_count": success_count, "message": str(e) } - print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}") - yield response async def generate_image(self, prompt, config=None): """生成图像并以流式方式返回结果""" if not config: config = {} - # 获取必要的认证信息 oauth_token = config.get('oauth_token', settings.midjourney_oauth_token) channel_id = config.get('channel_id', settings.midjourney_channel_id) if not oauth_token or not channel_id: - response = { + yield { "status": "error", "message": "缺少Discord配置", "success_count": 0 } - print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}") - yield response return - # 初始化客户端 try: client, guild_id, user_id = await self.initialize_client(oauth_token, channel_id) except Exception as e: - response = { + yield { "status": "error", "message": str(e), "success_count": 0 } - print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}") - yield response return - # 解析和准备提示词 prompt = prompt.strip() - - # 获取或设置默认选项 options = settings.midjourney_default_options.copy() if 'options' in config: options.update(config.get('options', {})) - # 强制设置seed如果没有 if "seed" not in options or not options.get('seed'): options['seed'] = random.randint(0, 4294967295) - # 处理选项,构建参数字符串 seed = options.pop('seed', None) parameter = "" no_value_key = ['relax', 'fast', 'turbo', 'tile'] @@ -516,20 +415,15 @@ class MidjourneyService: else: parameter += f" --{key} {value}" - # 添加seed if seed: parameter += f" --seed {seed}" - # 打印使用的seed值 - print(f"[生成] 使用的seed值: {seed}") + logger.info(f"使用的seed值: {seed}") - # 处理参考图像 if 'image_urls' in config and config['image_urls']: - # 确保是列表格式 image_urls = config['image_urls'] if not isinstance(image_urls, list): image_urls = [image_urls] - # 转换图片URL new_image_urls = [] for image_url in image_urls: if is_valid_image_url(image_url): @@ -540,16 +434,13 @@ class MidjourneyService: except Exception as e: logger.warning(f"转换图片URL失败: {str(e)}") - # 添加到prompt前面 if new_image_urls: prompt = " ".join(new_image_urls) + " " + prompt - # 设置图像权重 if 'image_weight' in config and isinstance(config['image_weight'], (int, float)): - iw = max(0.1, min(config['image_weight'], 3)) # 限制在0.1到3之间 + iw = max(0.1, min(config['image_weight'], 3)) parameter += f" --iw {iw}" - # 添加字符引用 if 'characters' in config and config['characters']: char_urls = [] characters = config['characters'] if isinstance(config['characters'], list) else [config['characters']] @@ -566,7 +457,6 @@ class MidjourneyService: if char_urls: prompt = prompt + " --cref " + " ".join(char_urls) - # 添加风格引用 if 'styles' in config and config['styles']: style_urls = [] styles = config['styles'] if isinstance(config['styles'], list) else [config['styles']] @@ -583,31 +473,24 @@ class MidjourneyService: if style_urls: prompt = prompt + " --sref " + " ".join(style_urls) - # 添加参数 prompt = prompt + parameter - - # 流式返回结果 success_count = 0 async for result in self.imagine(client, guild_id, channel_id, prompt, seed): if result.get("status") == "progress": - # 进度信息只包含必要字段 - response = { + yield { "status": "progress", "progress": result.get("progress", 0), - "seed": seed # 添加seed字段 + "seed": seed } - print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}") - yield response elif result.get("status") == "success": success_count += 1 response = { "status": "success", "success_count": success_count, - "image_urls": result.get("images", []) # 添加image_urls字段 + "image_urls": result.get("images", []) } - # 如果需要分割图片 if config.get("split_image", True) and result.get("images"): try: orig_image_url = result["images"][0] @@ -619,14 +502,10 @@ class MidjourneyService: response["status"] = "error" response["message"] = f"分割图像失败: {str(e)}" - print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}") yield response else: - # 错误信息保持统一格式 - response = { + yield { "status": "error", "message": result.get("message", "未知错误"), "success_count": success_count - } - print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}") - yield response \ No newline at end of file + } \ No newline at end of file diff --git a/apps/midjourney/settings.py b/apps/midjourney/settings.py index 63661db..d1e613f 100644 --- a/apps/midjourney/settings.py +++ b/apps/midjourney/settings.py @@ -6,7 +6,7 @@ class Settings(BaseSettings): # Japi Server 配置 host: str = "0.0.0.0" port: int = 8113 - debug: bool = True + debug: bool = False # API路由配置 router_prefix: str = "/midjourney"