import asyncio
import base64
import io
import json
import os
import tempfile
import urllib.request
import uuid
from typing import AsyncGenerator, Dict, Generator, List, Optional

import numpy as np
import requests
import websocket
from PIL import Image

from settings import settings

# Default configuration; individual keys can be overridden per call through
# the ``config`` argument of the service methods.
default_config = {
    "comfyui_server_address": settings.comfyui_server_address,
    "upscale_model_name": "4xNomos2_otf_esrgan.pth",
}

save_dir = settings.save_dir
download_url = settings.download_url

# Upper bound (seconds) for the plain HTTP calls made by this module, so a
# stalled peer cannot block a worker thread forever.
_HTTP_TIMEOUT_SECONDS = 60

# Base workflow JSON template. Node 13 loads the upscale model, node 15 loads
# the input image, node 14 runs the upscale and node 16 sends the result back
# over the websocket. The empty strings are filled in per request.
workflow_template = """
{
  "13": {
    "inputs": {
      "model_name": ""
    },
    "class_type": "UpscaleModelLoader"
  },
  "14": {
    "inputs": {
      "upscale_model": ["13", 0],
      "image": ["15", 0]
    },
    "class_type": "ImageUpscaleWithModel"
  },
  "15": {
    "inputs": {
      "url_or_path": ""
    },
    "class_type": "LoadImageFromUrlOrPath"
  },
  "16": {
    "inputs": {
      "images": ["14", 0]
    },
    "class_type": "SaveImageWebsocket"
  }
}
"""


class ImageUpscaleService:
    """Upscale images through a ComfyUI server and store the results locally."""

    def __init__(self):
        """Initialise the image upscale service (no state is kept)."""
        pass

    def check_image_transparency(self, image_url: str) -> tuple:
        """Download an image and report whether it is a PNG with an alpha channel.

        Args:
            image_url: HTTP(S) URL of the image to inspect.

        Returns:
            ``(img, has_transparency)`` where ``img`` is the opened PIL image
            and ``has_transparency`` is True only for PNG images whose mode
            is ``RGBA`` or ``LA``.

        Raises:
            Exception: if the download or decoding fails; the original error
                is attached as ``__cause__``.
        """
        try:
            # Download the image; the timeout prevents an indefinite hang.
            response = requests.get(image_url, timeout=_HTTP_TIMEOUT_SECONDS)
            if response.status_code != 200:
                raise Exception(f"无法下载图片: {image_url}")

            # Decode with PIL straight from memory.
            img = Image.open(io.BytesIO(response.content))

            has_transparency = img.mode in ('RGBA', 'LA') and img.format == 'PNG'
            return img, has_transparency
        except Exception as e:
            raise Exception(f"图片处理失败: {str(e)}") from e

    @staticmethod
    def _save_to_temp_file(image: Image.Image, suffix: str, image_format: str,
                           **save_kwargs) -> str:
        """Save ``image`` to a new, closed temporary file and return its path."""
        fd, path = tempfile.mkstemp(suffix=suffix)
        # Close the descriptor right away: PIL reopens the file by name, and
        # keeping it open would leak the handle (and fail on Windows).
        os.close(fd)
        try:
            image.save(path, image_format, **save_kwargs)
        except Exception:
            os.unlink(path)
            raise
        return path

    def prepare_image_for_upscale(self, image_url: str) -> tuple:
        """Prepare an image for upscaling, splitting off the alpha channel if any.

        Args:
            image_url: URL of the source image.

        Returns:
            ``(image_ref, has_transparency, alpha_path)``. For images without
            transparency this is ``(image_url, False, None)``. For transparent
            PNGs ``image_ref`` is the path of a temporary JPEG holding the RGB
            data and ``alpha_path`` is the path of a temporary PNG holding the
            alpha channel; the caller is responsible for deleting both.
        """
        img, has_transparency = self.check_image_transparency(image_url)

        if not has_transparency:
            # Non-transparent images are used as they are.
            return image_url, False, None

        # For transparent PNGs the RGB data and alpha channel are processed
        # separately.
        rgb_image = img.convert('RGB')
        alpha_channel = img.split()[-1]

        rgb_path = self._save_to_temp_file(rgb_image, '.jpg', 'JPEG', quality=95)
        try:
            alpha_path = self._save_to_temp_file(alpha_channel, '.png', 'PNG')
        except Exception:
            # Do not leave the RGB file behind when the alpha file fails.
            os.unlink(rgb_path)
            raise

        return rgb_path, True, alpha_path

    def upscale_alpha_channel(self, alpha_path: str, scale_factor: int = 4) -> Image.Image:
        """Enlarge the alpha channel image with bilinear interpolation.

        Args:
            alpha_path: path of the image file holding the alpha channel.
            scale_factor: integer multiplier applied to width and height.

        Returns:
            A new image of size ``(width * scale_factor, height * scale_factor)``.
        """
        with Image.open(alpha_path) as alpha_img:
            width, height = alpha_img.size
            new_size = (width * scale_factor, height * scale_factor)
            # resize() returns an independent image, so it stays valid after
            # the source file is closed.
            return alpha_img.resize(new_size, Image.BILINEAR)

    def queue_prompt(self, prompt: Dict, comfyui_server_address: str, client_id: str) -> Dict:
        """Submit a workflow to the ComfyUI server queue.

        Args:
            prompt: workflow dictionary to execute.
            comfyui_server_address: ``host:port`` of the ComfyUI server.
            client_id: identifier that ties the prompt to a websocket session.

        Returns:
            The decoded JSON response of the ``/prompt`` endpoint.
        """
        payload = json.dumps({"prompt": prompt, "client_id": client_id}).encode('utf-8')
        req = urllib.request.Request(f"http://{comfyui_server_address}/prompt", data=payload)
        with urllib.request.urlopen(req, timeout=_HTTP_TIMEOUT_SECONDS) as response:
            return json.loads(response.read())

    def get_images(self, ws: websocket.WebSocket, workflow: Dict,
                   comfyui_server_address: str, client_id: str) -> Dict:
        """Queue the workflow and collect the images sent back over ``ws``.

        Args:
            ws: an already connected websocket for ``client_id``.
            workflow: workflow dictionary to execute.
            comfyui_server_address: ``host:port`` of the ComfyUI server.
            client_id: identifier used when the websocket was opened.

        Returns:
            ``{node_id: [image_bytes, ...]}`` for the output node ``"16"``,
            or ``{}`` when the server response carries no ``prompt_id``.

        Raises:
            RuntimeError: if the server reports ``execution_error`` for this
                prompt.
        """
        try:
            prompt_response = self.queue_prompt(workflow, comfyui_server_address, client_id)
            prompt_id = prompt_response['prompt_id']
        except KeyError:
            return {}

        output_images: Dict[str, list] = {}
        current_node = ""

        while True:
            out = ws.recv()
            if isinstance(out, str):
                message = json.loads(out)
                message_type = message.get('type')
                data = message.get('data') or {}
                if data.get('prompt_id') != prompt_id:
                    # Status frames and frames that belong to other prompts.
                    continue
                if message_type == 'executing':
                    if data['node'] is None:
                        # A null node marks the end of this prompt's execution.
                        break
                    current_node = data['node']
                elif message_type == 'execution_error':
                    # Surface the failure instead of silently returning nothing.
                    # NOTE(review): field name follows ComfyUI's error payload;
                    # confirm against the deployed server version.
                    detail = data.get('exception_message', data)
                    raise RuntimeError(f"ComfyUI 执行失败: {detail}")
            else:
                if current_node == '16':  # websocket image output node
                    # The first 8 bytes are skipped; presumably the binary
                    # frame header that precedes the image payload.
                    output_images.setdefault(current_node, []).append(out[8:])

        return output_images

    def _save_transparent_result(self, image_data: bytes, rgb_source_path: str,
                                 alpha_path: str) -> str:
        """Re-attach the upscaled alpha channel and save an RGBA PNG.

        Returns the file name (inside ``save_dir``) of the saved PNG.
        """
        with Image.open(io.BytesIO(image_data)) as upscaled_rgb:
            with Image.open(rgb_source_path) as source_rgb:
                source_width = source_rgb.width

            # Derive the model's scale from the actual output; never below 1.
            scale_factor = max(1, upscaled_rgb.width // source_width)
            upscaled_alpha = self.upscale_alpha_channel(alpha_path, scale_factor=scale_factor)

            # Make sure both layers have exactly the same size.
            if upscaled_rgb.size != upscaled_alpha.size:
                upscaled_alpha = upscaled_alpha.resize(upscaled_rgb.size, Image.BILINEAR)

            upscaled_rgba = upscaled_rgb.copy()

        upscaled_rgba.putalpha(upscaled_alpha)

        png_filename = f"upscaled_{uuid.uuid4().hex[:10]}.png"
        upscaled_rgba.save(os.path.join(save_dir, png_filename), "PNG")
        return png_filename

    def _save_opaque_result(self, image_data: bytes) -> str:
        """Save an upscaled image without transparency as JPEG.

        Returns the file name (inside ``save_dir``) of the saved JPEG.
        """
        with Image.open(io.BytesIO(image_data)) as img:
            rgb_image = img.convert('RGB')

        # JPEG keeps the file size down for images that need no alpha.
        jpg_filename = f"upscaled_{uuid.uuid4().hex[:10]}.jpg"
        rgb_image.save(os.path.join(save_dir, jpg_filename), 'JPEG', quality=95)
        return jpg_filename

    def upscale_image_sync(self, image_url: str,
                           config: Optional[Dict] = None) -> Generator[str, None, None]:
        """Upscale an image, save the result locally and yield its download URL.

        Args:
            image_url: URL of the image to upscale.
            config: optional overrides for ``default_config``.

        Yields:
            ``f"{download_url}/<file name>"`` for every image returned by the
            server: a PNG for transparent inputs, a JPEG otherwise.
        """
        cfg = default_config.copy()
        if config:
            cfg.update(config)

        ws = websocket.WebSocket()
        client_id = str(uuid.uuid4())
        temp_file = None
        alpha_temp_file = None

        try:
            # Prepare the image for upscaling.
            image_path, has_transparency, alpha_path = self.prepare_image_for_upscale(image_url)
            if image_path != image_url:
                # NOTE(review): a local temp path is handed to ComfyUI here,
                # which assumes the server can read this machine's filesystem.
                temp_file = image_path
                image_url = image_path
            if has_transparency:
                alpha_temp_file = alpha_path

            ws.connect(f"ws://{cfg['comfyui_server_address']}/ws?clientId={client_id}")

            workflow = json.loads(workflow_template)
            workflow["13"]["inputs"]["model_name"] = cfg['upscale_model_name']
            workflow["15"]["inputs"]["url_or_path"] = image_url

            images_dict = self.get_images(ws, workflow, cfg['comfyui_server_address'], client_id)

            os.makedirs(save_dir, exist_ok=True)

            for image_list in images_dict.values():
                for image_data in image_list:
                    if has_transparency:
                        filename = self._save_transparent_result(
                            image_data, temp_file, alpha_temp_file)
                    else:
                        filename = self._save_opaque_result(image_data)
                    yield f"{download_url}/{filename}"
        finally:
            ws.close()
            # Clean up the temporary input files.
            for path in (temp_file, alpha_temp_file):
                if path and os.path.exists(path):
                    os.unlink(path)

    async def upscale_image(self, image_url: str,
                            config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]:
        """Upscale an image without blocking the event loop.

        Args:
            image_url: URL of the image to upscale.
            config: optional overrides for ``default_config``.

        Yields:
            One ``{"status": "success", "image_url": ..., "message": ...}``
            dict per result, or a single ``{"status": "error", "message": ...}``
            dict when the upscale fails.
        """
        try:
            # The synchronous generator is drained in a worker thread and its
            # results are collected before anything is yielded.
            loop = asyncio.get_event_loop()
            urls = await loop.run_in_executor(
                None, lambda: list(self.upscale_image_sync(image_url, config)))
        except Exception as e:
            yield {
                "status": "error",
                "message": f"图像放大失败: {str(e)}"
            }
            return

        # Hand the results back one by one.
        for url in urls:
            yield {
                "status": "success",
                "image_url": url,
                "message": "图片已保存"
            }

    async def process_batch(self, image_urls: List[str],
                            config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]:
        """Upscale several images in sequence and report progress for each.

        Args:
            image_urls: URLs of the images to process.
            config: optional overrides for ``default_config``.

        Yields:
            A progress dict per result with ``index`` (1-based), ``total``,
            ``original_image_url``, ``status``, running ``success_count`` and
            ``error_count`` and ``message``; successful entries also carry
            ``image_url`` and ``transparency``.
        """
        total = len(image_urls)
        success_count = 0
        error_count = 0
        loop = asyncio.get_event_loop()

        for i, image_url in enumerate(image_urls, 1):
            try:
                # Best-effort transparency info; a failure here must not stop
                # the actual upscale. The download blocks, so it runs in the
                # executor rather than on the event loop.
                try:
                    _, has_transparency = await loop.run_in_executor(
                        None, self.check_image_transparency, image_url)
                    transparency_info = "PNG带透明通道" if has_transparency else "无透明通道"
                except Exception:
                    transparency_info = "未检测"

                async for result in self.upscale_image(image_url, config):
                    if result["status"] == "success":
                        success_count += 1
                        yield {
                            "index": i,
                            "total": total,
                            "original_image_url": image_url,
                            "status": "success",
                            "image_url": result["image_url"],
                            "success_count": success_count,
                            "error_count": error_count,
                            "transparency": transparency_info,
                            "message": result["message"]
                        }
                    else:
                        error_count += 1
                        yield {
                            "index": i,
                            "total": total,
                            "original_image_url": image_url,
                            "status": "error",
                            "success_count": success_count,
                            "error_count": error_count,
                            "message": result["message"]
                        }
            except Exception as e:
                error_count += 1
                yield {
                    "index": i,
                    "total": total,
                    "original_image_url": image_url,
                    "status": "error",
                    "success_count": success_count,
                    "error_count": error_count,
                    "message": f"处理图像时出错: {str(e)}"
                }