338 lines
13 KiB
Python
338 lines
13 KiB
Python
import json
|
||
import base64
|
||
import requests
|
||
import websocket
|
||
import uuid
|
||
import urllib.request
|
||
import asyncio
|
||
import os
|
||
from typing import Dict, List, Generator, Optional, AsyncGenerator
|
||
from settings import settings
|
||
from PIL import Image
|
||
import io
|
||
import tempfile
|
||
import numpy as np
|
||
|
||
# Default service configuration. Per-call `config` dicts are merged on top of a
# copy of this mapping by ImageUpscaleService.upscale_image_sync.
default_config = dict(
    comfyui_server_address=settings.comfyui_server_address,
    upscale_model_name="4xNomos2_otf_esrgan.pth",
)

# Local directory that receives the final upscaled files.
save_dir = settings.save_dir
# Prefix used to build returned result URLs as f"{download_url}/{filename}"
# (presumably the public location where save_dir is served — confirm).
download_url = settings.download_url
|
||
|
||
# Base ComfyUI workflow (API format) as a JSON string; it is parsed afresh for
# every request and then has its empty inputs filled in.
#   "13" UpscaleModelLoader      -> inputs.model_name is filled per request
#   "15" LoadImageFromUrlOrPath  -> inputs.url_or_path is filled per request
#   "14" ImageUpscaleWithModel   -> upscales image from "15" with model from "13"
#   "16" SaveImageWebsocket      -> sends the result of "14" over the websocket
workflow_template = """
{
    "13": {
        "inputs": {
            "model_name": ""
        },
        "class_type": "UpscaleModelLoader",
        "_meta": {
            "title": "Load Upscale Model"
        }
    },
    "14": {
        "inputs": {
            "upscale_model": [
                "13",
                0
            ],
            "image": [
                "15",
                0
            ]
        },
        "class_type": "ImageUpscaleWithModel",
        "_meta": {
            "title": "Upscale Image (using Model)"
        }
    },
    "15": {
        "inputs": {
            "url_or_path": ""
        },
        "class_type": "LoadImageFromUrlOrPath",
        "_meta": {
            "title": "LoadImageFromUrlOrPath"
        }
    },
    "16": {
        "inputs": {
            "images": [
                "14",
                0
            ]
        },
        "class_type": "SaveImageWebsocket",
        "_meta": {
            "title": "SaveImageWebsocket"
        }
    }
}
"""
|
||
|
||
class ImageUpscaleService:
    """Upscale images through a ComfyUI server.

    Opaque images are handed to ComfyUI by URL and the result is stored in
    ``save_dir`` as JPEG. Transparent PNGs are split into RGB + alpha: the RGB
    part is upscaled by the ComfyUI model, the alpha channel is resized locally
    with bilinear interpolation, and both are recombined into a PNG.
    """

    def __init__(self):
        """Initialize the image upscale service (it holds no state)."""
        pass

    def check_image_transparency(self, image_url: str, timeout: Optional[float] = 60) -> tuple:
        """Download an image and report whether it is a PNG with transparency.

        Args:
            image_url: HTTP(S) URL of the image.
            timeout: connect/read timeout in seconds passed to ``requests.get``;
                ``None`` waits indefinitely.

        Returns:
            ``(img, has_transparency)`` where ``img`` is the opened PIL image.

        Raises:
            Exception: wrapping any download or decoding failure.
        """
        try:
            # Without a timeout a stalled server would block this call forever.
            response = requests.get(image_url, timeout=timeout)
            if response.status_code != 200:
                raise Exception(f"无法下载图片: {image_url}")

            img = Image.open(io.BytesIO(response.content))

            # Alpha may come from an explicit alpha band (RGBA/LA/PA) or from a
            # tRNS chunk on palette/greyscale/RGB PNGs, which Pillow exposes as
            # img.info["transparency"].
            has_alpha = img.mode in ('RGBA', 'LA', 'PA') or 'transparency' in img.info
            has_transparency = has_alpha and img.format == 'PNG'

            return img, has_transparency
        except Exception as e:
            raise Exception(f"图片处理失败: {str(e)}") from e

    def prepare_image_for_upscale(self, image_url: str) -> tuple:
        """Prepare an image for upscaling according to its transparency.

        Returns:
            ``(image_path_or_url, has_transparency, alpha_path)``. Opaque images
            give ``(image_url, False, None)``. For transparent PNGs the RGB data
            is written to a temporary JPEG and the alpha band to a temporary
            PNG; the caller owns both files and must delete them.
        """
        img, has_transparency = self.check_image_transparency(image_url)

        if not has_transparency:
            # Opaque images are passed to ComfyUI by their original URL.
            return image_url, False, None

        # Normalize to RGBA first so we get a real alpha band even for
        # palette ("P"/"PA") or tRNS-based images.
        rgba = img.convert('RGBA')
        rgb_image = rgba.convert('RGB')
        alpha_channel = rgba.getchannel('A')

        rgb_path = None
        alpha_path = None
        try:
            # Closing the handles (via `with`) keeps files readable by other
            # processes and avoids leaking descriptors; delete=False keeps them.
            with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as rgb_tmp:
                rgb_path = rgb_tmp.name
                rgb_image.save(rgb_tmp, 'JPEG', quality=95)
            with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as alpha_tmp:
                alpha_path = alpha_tmp.name
                alpha_channel.save(alpha_tmp, 'PNG')
        except Exception:
            # Don't leave half-written temp files behind on failure.
            for path in (rgb_path, alpha_path):
                if path and os.path.exists(path):
                    os.unlink(path)
            raise

        return rgb_path, True, alpha_path

    def upscale_alpha_channel(self, alpha_path: str, scale_factor: int = 4,
                              target_size: Optional[tuple] = None) -> Image.Image:
        """Resize an alpha-channel image with bilinear interpolation.

        Args:
            alpha_path: path of the alpha image written by
                ``prepare_image_for_upscale``.
            scale_factor: multiplier for width and height, used when
                ``target_size`` is not given.
            target_size: optional explicit ``(width, height)``; takes priority
                over ``scale_factor``.

        Returns:
            The resized image (fully loaded, independent of the source file).
        """
        with Image.open(alpha_path) as alpha_img:
            if target_size is None:
                width, height = alpha_img.size
                target_size = (width * scale_factor, height * scale_factor)
            return alpha_img.resize(target_size, Image.BILINEAR)

    def queue_prompt(self, prompt: Dict, comfyui_server_address: str, client_id: str) -> Dict:
        """POST a workflow to ComfyUI's ``/prompt`` endpoint and return the decoded JSON reply."""
        payload = json.dumps({"prompt": prompt, "client_id": client_id}).encode('utf-8')
        req = urllib.request.Request(f"http://{comfyui_server_address}/prompt", data=payload)
        # Using the response as a context manager releases the HTTP connection.
        with urllib.request.urlopen(req) as resp:
            return json.loads(resp.read())

    def get_images(self, ws: websocket.WebSocket, workflow: Dict, comfyui_server_address: str, client_id: str) -> Dict:
        """Queue ``workflow`` and collect the images streamed back over ``ws``.

        Binary frames received while node "16" (SaveImageWebsocket in
        ``workflow_template``) is executing are collected until an
        ``executing`` message for this prompt reports ``node is None``.

        Returns:
            ``{node_id: [image_bytes, ...]}``, or ``{}`` if the queue response
            has no ``prompt_id``.
        """
        try:
            prompt_id = self.queue_prompt(workflow, comfyui_server_address, client_id)['prompt_id']
        except KeyError:
            return {}

        output_images: Dict[str, List[bytes]] = {}
        current_node = ""
        # NOTE(review): ws.recv() blocks without a timeout; a server that never
        # reports completion would hang this call.
        while True:
            out = ws.recv()
            if isinstance(out, str):
                message = json.loads(out)
                if message['type'] == 'executing':
                    data = message['data']
                    if data.get('prompt_id') == prompt_id:
                        if data['node'] is None:
                            break
                        current_node = data['node']
            elif current_node == '16':  # upscaled-image output node
                # The first 8 bytes of each binary frame are stripped
                # (presumably ComfyUI's event-type/format header — confirm).
                output_images.setdefault(current_node, []).append(out[8:])

        return output_images

    def _save_transparent_result(self, image_data: bytes, alpha_path: str) -> str:
        """Merge upscaled RGB bytes with the resized alpha, save as PNG, return its URL."""
        with Image.open(io.BytesIO(image_data)) as upscaled_rgb:
            # Resize alpha straight to the model's output size so a
            # non-integer scale ratio cannot leave the channels mismatched.
            upscaled_alpha = self.upscale_alpha_channel(alpha_path, target_size=upscaled_rgb.size)
            upscaled_rgba = upscaled_rgb.copy()
        upscaled_rgba.putalpha(upscaled_alpha)

        png_filename = f"upscaled_{uuid.uuid4().hex[:10]}.png"
        upscaled_rgba.save(os.path.join(save_dir, png_filename), "PNG")
        return f"{download_url}/{png_filename}"

    def _save_opaque_result(self, image_data: bytes) -> str:
        """Re-encode upscaled image bytes as JPEG (quality 95) to keep files small; return its URL."""
        jpg_filename = f"upscaled_{uuid.uuid4().hex[:10]}.jpg"
        with Image.open(io.BytesIO(image_data)) as img:
            img.convert('RGB').save(os.path.join(save_dir, jpg_filename), 'JPEG', quality=95)
        return f"{download_url}/{jpg_filename}"

    def upscale_image_sync(self, image_url: str, config: Optional[Dict] = None) -> Generator[str, None, None]:
        """Upscale one image, save the result under ``save_dir`` and yield its URL.

        Args:
            image_url: URL of the source image.
            config: optional overrides for ``default_config``.

        Yields:
            ``f"{download_url}/{filename}"`` for each image ComfyUI returns.

        Raises:
            RuntimeError: if ComfyUI returns no images.
        """
        cfg = default_config.copy()
        if config:
            cfg.update(config)

        ws = websocket.WebSocket()
        client_id = str(uuid.uuid4())
        temp_file = None
        alpha_temp_file = None

        try:
            image_path, has_transparency, alpha_path = self.prepare_image_for_upscale(image_url)
            if image_path != image_url:
                temp_file = image_path
                image_url = image_path
            if has_transparency:
                alpha_temp_file = alpha_path

            ws.connect(f"ws://{cfg['comfyui_server_address']}/ws?clientId={client_id}")
            workflow = json.loads(workflow_template)
            workflow["13"]["inputs"]["model_name"] = cfg['upscale_model_name']
            # NOTE(review): for transparent PNGs this is a local temp-file path,
            # which only works if ComfyUI can read this machine's filesystem.
            workflow["15"]["inputs"]["url_or_path"] = image_url
            images_dict = self.get_images(ws, workflow, cfg['comfyui_server_address'], client_id)

            # Yielding nothing would make callers silently drop this image.
            if not any(images_dict.values()):
                raise RuntimeError("ComfyUI 未返回任何放大后的图像")

            os.makedirs(save_dir, exist_ok=True)
            for image_list in images_dict.values():
                for image_data in image_list:
                    if has_transparency:
                        yield self._save_transparent_result(image_data, alpha_temp_file)
                    else:
                        yield self._save_opaque_result(image_data)
        finally:
            ws.close()
            # Remove the temporary RGB/alpha files created for transparent PNGs.
            for path in (temp_file, alpha_temp_file):
                if path and os.path.exists(path):
                    os.unlink(path)

    async def upscale_image(self, image_url: str, config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]:
        """Asynchronously upscale one image, yielding a status dict per result.

        The blocking pipeline runs in the default executor. Any failure is
        reported as a single ``{"status": "error", ...}`` dict instead of raising.
        """
        try:
            loop = asyncio.get_running_loop()
            # Collect every URL in the worker thread, then yield them here.
            urls = await loop.run_in_executor(
                None, lambda: list(self.upscale_image_sync(image_url, config))
            )
            for url in urls:
                yield {
                    "status": "success",
                    "image_url": url,
                    "message": "图片已保存"
                }
        except Exception as e:
            yield {
                "status": "error",
                "message": f"图像放大失败: {str(e)}"
            }

    async def process_batch(self, image_urls: List[str], config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]:
        """Upscale several images in order, yielding a progress dict per result.

        Each dict carries ``index``/``total``, the running ``success_count`` and
        ``error_count``, the original URL, a ``transparency`` label and either
        ``image_url`` (on success) or an error ``message``.
        """
        total = len(image_urls)
        success_count = 0
        error_count = 0
        loop = asyncio.get_running_loop()

        for i, image_url in enumerate(image_urls, 1):
            transparency_info = "未检测"
            try:
                # The transparency probe downloads the image with blocking
                # requests, so keep it off the event loop. It is informational
                # only: its failure must not stop the upscale itself.
                try:
                    _, has_transparency = await loop.run_in_executor(
                        None, self.check_image_transparency, image_url
                    )
                    transparency_info = "PNG带透明通道" if has_transparency else "无透明通道"
                except Exception:
                    transparency_info = "未检测"

                async for result in self.upscale_image(image_url, config):
                    if result["status"] == "success":
                        success_count += 1
                        yield {
                            "index": i,
                            "total": total,
                            "original_image_url": image_url,
                            "status": "success",
                            "image_url": result["image_url"],
                            "success_count": success_count,
                            "error_count": error_count,
                            "transparency": transparency_info,
                            "message": result["message"]
                        }
                    else:
                        error_count += 1
                        yield {
                            "index": i,
                            "total": total,
                            "original_image_url": image_url,
                            "status": "error",
                            "success_count": success_count,
                            "error_count": error_count,
                            "transparency": transparency_info,
                            "message": result["message"]
                        }
            except Exception as e:
                error_count += 1
                yield {
                    "index": i,
                    "total": total,
                    "original_image_url": image_url,
                    "status": "error",
                    "success_count": success_count,
                    "error_count": error_count,
                    "transparency": transparency_info,
                    "message": f"处理图像时出错: {str(e)}"
                }
|
||
|