# japi/apps/jupscale/service.py
# 2025-05-12 02:39:56 +08:00
#
# 338 lines
# 13 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import base64
import requests
import websocket
import uuid
import urllib.request
import asyncio
import os
from typing import Dict, List, Generator, Optional, AsyncGenerator
from settings import settings
from PIL import Image
import io
import tempfile
import numpy as np
# Default settings; the per-call ``config`` dict accepted by the service may
# override individual keys.
default_config = dict(
    comfyui_server_address=settings.comfyui_server_address,
    upscale_model_name="4xNomos2_otf_esrgan.pth",
)
# Directory the upscaled images are written to, and the base URL that is
# prefixed to a saved file name to build the image URL handed back to callers.
save_dir, download_url = settings.save_dir, settings.download_url
# Base ComfyUI workflow, kept as a JSON template string.
# Node "13" loads the upscale model, node "15" loads the source image, node
# "14" applies the model to the image, and node "16" sends the result back
# over the websocket. The empty "model_name" (node "13") and "url_or_path"
# (node "15") inputs are filled in by upscale_image_sync before submission.
workflow_template = """
{
"13": {
"inputs": {
"model_name": ""
},
"class_type": "UpscaleModelLoader",
"_meta": {
"title": "Load Upscale Model"
}
},
"14": {
"inputs": {
"upscale_model": [
"13",
0
],
"image": [
"15",
0
]
},
"class_type": "ImageUpscaleWithModel",
"_meta": {
"title": "Upscale Image (using Model)"
}
},
"15": {
"inputs": {
"url_or_path": ""
},
"class_type": "LoadImageFromUrlOrPath",
"_meta": {
"title": "LoadImageFromUrlOrPath"
}
},
"16": {
"inputs": {
"images": [
"14",
0
]
},
"class_type": "SaveImageWebsocket",
"_meta": {
"title": "SaveImageWebsocket"
}
}
}
"""
class ImageUpscaleService:
    """Upscale images through a ComfyUI server while keeping PNG transparency.

    Opaque images are sent to ComfyUI as-is and stored as JPEG. Transparent
    PNGs are split into an RGB part (upscaled by the model) and an alpha
    channel (upscaled with bilinear interpolation), then recombined and
    stored as PNG.
    """

    # Seconds to wait on plain HTTP calls (image download, prompt submission)
    # so that an unresponsive host cannot block a worker thread forever.
    _HTTP_TIMEOUT = 30
    # Workflow node whose binary websocket frames carry the upscaled image.
    _OUTPUT_NODE_ID = '16'
    # Number of leading bytes dropped from every binary frame before it is
    # treated as image data (presumably ComfyUI's frame header -- confirm).
    _BINARY_HEADER_SIZE = 8

    def __init__(self):
        """Initialise the image upscale service (it holds no state)."""
        pass

    def check_image_transparency(self, image_url: str) -> tuple:
        """Download an image and report whether it is a PNG with alpha.

        Args:
            image_url: URL of the image to download and inspect.

        Returns:
            ``(image, has_transparency)``: the opened PIL image and ``True``
            only when the file is a PNG whose mode is ``RGBA`` or ``LA``.

        Raises:
            Exception: if the download or decoding fails. The underlying
                error is chained as ``__cause__``.
        """
        try:
            response = requests.get(image_url, timeout=self._HTTP_TIMEOUT)
            if response.status_code != 200:
                raise Exception(f"无法下载图片: {image_url}")
            img = Image.open(io.BytesIO(response.content))
            has_transparency = img.mode in ('RGBA', 'LA') and img.format == 'PNG'
            return img, has_transparency
        except Exception as e:
            raise Exception(f"图片处理失败: {str(e)}") from e

    def prepare_image_for_upscale(self, image_url: str) -> tuple:
        """Prepare the input that will be handed to the upscale workflow.

        Args:
            image_url: URL of the source image.

        Returns:
            ``(image_source, has_transparency, alpha_path)``. For opaque
            images this is ``(image_url, False, None)``. For transparent PNGs
            ``image_source`` is the path of a temporary JPEG holding the RGB
            data and ``alpha_path`` the path of a temporary PNG holding the
            alpha channel; the caller is responsible for deleting both.

        Raises:
            Exception: if the image cannot be downloaded or decoded.
        """
        img, has_transparency = self.check_image_transparency(image_url)
        if not has_transparency:
            # Opaque images are used directly from their original URL.
            return image_url, False, None

        # Transparent PNG: the model only handles colour data, so the alpha
        # channel is split off and processed separately.
        rgb_image = img.convert('RGB')
        alpha_channel = img.split()[-1]

        rgb_path = self._save_to_temp_file(rgb_image, '.jpg', 'JPEG', quality=95)
        try:
            alpha_path = self._save_to_temp_file(alpha_channel, '.png', 'PNG')
        except Exception:
            # Do not leave the RGB file behind when the alpha save fails.
            self._remove_quietly(rgb_path)
            raise
        return rgb_path, True, alpha_path

    def _save_to_temp_file(self, image, suffix: str, image_format: str, **save_options) -> str:
        """Write *image* to a new temporary file and return the file's path."""
        handle = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
        try:
            # Saving through the open handle and closing it right away avoids
            # leaking the descriptor and reopening a file that is still open.
            with handle:
                image.save(handle, image_format, **save_options)
        except Exception:
            self._remove_quietly(handle.name)
            raise
        return handle.name

    @staticmethod
    def _remove_quietly(path: Optional[str]) -> None:
        """Delete *path* if it exists; cleanup is best-effort by design."""
        if path and os.path.exists(path):
            try:
                os.unlink(path)
            except OSError:
                # A failed cleanup must not mask the error being handled.
                pass

    def upscale_alpha_channel(self, alpha_path: str, scale_factor: int = 4) -> "Image.Image":
        """Enlarge the alpha channel with bilinear interpolation.

        Args:
            alpha_path: Path of the image file holding the alpha channel.
            scale_factor: Integer factor applied to both width and height.

        Returns:
            A new PIL image of size ``(width * scale_factor,
            height * scale_factor)``.
        """
        # ``resize`` returns an independent image, so the source file can be
        # closed (and later deleted) as soon as the block exits.
        with Image.open(alpha_path) as alpha_img:
            width, height = alpha_img.size
            new_size = (width * scale_factor, height * scale_factor)
            return alpha_img.resize(new_size, Image.BILINEAR)

    def queue_prompt(self, prompt: Dict, comfyui_server_address: str, client_id: str) -> Dict:
        """Submit a workflow to the ComfyUI server's prompt queue.

        Args:
            prompt: The workflow graph to execute.
            comfyui_server_address: ``host:port`` of the ComfyUI server.
            client_id: Identifier tying the prompt to a websocket session.

        Returns:
            The decoded JSON body of the server's response.
        """
        payload = json.dumps({"prompt": prompt, "client_id": client_id}).encode('utf-8')
        req = urllib.request.Request(f"http://{comfyui_server_address}/prompt", data=payload)
        with urllib.request.urlopen(req, timeout=self._HTTP_TIMEOUT) as response:
            return json.loads(response.read())

    def get_images(self, ws: "websocket.WebSocket", workflow: Dict, comfyui_server_address: str, client_id: str) -> Dict:
        """Run a workflow and collect the images streamed back by ComfyUI.

        Args:
            ws: A websocket already connected with ``client_id``.
            workflow: The workflow graph to execute.
            comfyui_server_address: ``host:port`` of the ComfyUI server.
            client_id: Identifier shared by the websocket and the prompt.

        Returns:
            A dict mapping the output node id to the list of raw image
            payloads received for it. The dict is empty when the server's
            response carries no ``prompt_id``.
        """
        try:
            prompt_response = self.queue_prompt(workflow, comfyui_server_address, client_id)
            prompt_id = prompt_response['prompt_id']
        except KeyError:
            return {}

        output_images = {}
        current_node = ""
        while True:
            out = ws.recv()
            if isinstance(out, str):
                # Text frames are JSON status messages; only "executing"
                # messages for our own prompt are of interest.
                message = json.loads(out)
                if message.get('type') != 'executing':
                    continue
                data = message['data']
                if data.get('prompt_id') != prompt_id:
                    continue
                if data['node'] is None:
                    # A null node for our prompt ends the collection loop.
                    break
                current_node = data['node']
            elif current_node == self._OUTPUT_NODE_ID:
                # Binary frame received while the output node is executing.
                payload = out[self._BINARY_HEADER_SIZE:]
                output_images.setdefault(current_node, []).append(payload)
        return output_images

    def _save_with_alpha(self, image_data: bytes, rgb_source_path: str, alpha_path: str) -> str:
        """Recombine upscaled RGB data with its alpha channel; return the PNG file name."""
        # Decode from memory: no intermediate file is written to save_dir.
        upscaled_rgb = Image.open(io.BytesIO(image_data))
        with Image.open(rgb_source_path) as original_rgb:
            original_width = original_rgb.width
        # Clamp to 1 so an unexpectedly small result cannot request a 0x0 resize.
        scale_factor = max(1, upscaled_rgb.width // original_width)
        upscaled_alpha = self.upscale_alpha_channel(alpha_path, scale_factor=scale_factor)
        # Make sure both parts have exactly the same size before merging.
        if upscaled_rgb.size != upscaled_alpha.size:
            upscaled_alpha = upscaled_alpha.resize(upscaled_rgb.size, Image.BILINEAR)
        upscaled_rgba = upscaled_rgb.copy()
        upscaled_rgba.putalpha(upscaled_alpha)
        png_filename = f"upscaled_{uuid.uuid4().hex[:10]}.png"
        upscaled_rgba.save(os.path.join(save_dir, png_filename), "PNG")
        return png_filename

    def _save_as_jpeg(self, image_data: bytes) -> str:
        """Store an opaque upscaled image as JPEG (smaller files); return the file name."""
        jpg_filename = f"upscaled_{uuid.uuid4().hex[:10]}.jpg"
        img = Image.open(io.BytesIO(image_data)).convert('RGB')
        img.save(os.path.join(save_dir, jpg_filename), 'JPEG', quality=95)
        return jpg_filename

    def upscale_image_sync(self, image_url: str, config: Optional[Dict] = None) -> Generator[str, None, None]:
        """Upscale one image, save the result locally and yield its URL.

        Args:
            image_url: URL of the image to upscale.
            config: Optional overrides for ``default_config``.

        Yields:
            The download URL of every image produced by the workflow
            (``.png`` for transparent sources, ``.jpg`` otherwise).

        Raises:
            Exception: any failure while preparing the image, talking to
                ComfyUI or saving the result is propagated to the caller.
        """
        cfg = default_config.copy()
        if config:
            cfg.update(config)
        ws = websocket.WebSocket()
        client_id = str(uuid.uuid4())
        temp_file = None
        alpha_temp_file = None
        try:
            image_source, has_transparency, alpha_path = self.prepare_image_for_upscale(image_url)
            if image_source != image_url:
                # NOTE(review): a local temp path is passed to ComfyUI here,
                # which presumably requires a shared filesystem -- confirm.
                temp_file = image_source
            if has_transparency:
                alpha_temp_file = alpha_path

            ws.connect(f"ws://{cfg['comfyui_server_address']}/ws?clientId={client_id}")
            workflow = json.loads(workflow_template)
            workflow["13"]["inputs"]["model_name"] = cfg['upscale_model_name']
            workflow["15"]["inputs"]["url_or_path"] = image_source
            images_dict = self.get_images(ws, workflow, cfg['comfyui_server_address'], client_id)

            os.makedirs(save_dir, exist_ok=True)
            for image_list in images_dict.values():
                for image_data in image_list:
                    if has_transparency:
                        filename = self._save_with_alpha(image_data, temp_file, alpha_temp_file)
                    else:
                        filename = self._save_as_jpeg(image_data)
                    yield f"{download_url}/{filename}"
        finally:
            ws.close()
            # Remove the temporary RGB / alpha files of transparent sources.
            self._remove_quietly(temp_file)
            self._remove_quietly(alpha_temp_file)

    async def upscale_image(self, image_url: str, config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]:
        """Upscale one image without blocking the event loop.

        Args:
            image_url: URL of the image to upscale.
            config: Optional overrides for ``default_config``.

        Yields:
            One ``{"status": "success", "image_url": ..., "message": ...}``
            dict per produced image, or a single ``{"status": "error",
            "message": ...}`` dict when the upscale fails or produces no
            image at all.
        """
        try:
            # The synchronous generator is drained in a worker thread.
            loop = asyncio.get_running_loop()
            urls = await loop.run_in_executor(
                None, lambda: list(self.upscale_image_sync(image_url, config))
            )
        except Exception as e:
            yield {
                "status": "error",
                "message": f"图像放大失败: {str(e)}"
            }
            return

        if not urls:
            # Without this the image would silently vanish from the results.
            yield {
                "status": "error",
                "message": "图像放大失败: ComfyUI 未返回任何图像"
            }
            return

        for url in urls:
            yield {
                "status": "success",
                "image_url": url,
                "message": "图片已保存"
            }

    async def process_batch(self, image_urls: List[str], config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]:
        """Upscale several images in sequence, reporting progress per image.

        Args:
            image_urls: URLs of the images to upscale.
            config: Optional overrides for ``default_config``.

        Yields:
            One dict per result with ``index``, ``total``,
            ``original_image_url``, ``status``, the running ``success_count``
            and ``error_count`` and a ``message``. Successful results also
            carry ``image_url`` and ``transparency``.
        """
        total = len(image_urls)
        success_count = 0
        error_count = 0
        loop = asyncio.get_running_loop()
        for i, image_url in enumerate(image_urls, 1):
            try:
                # The transparency probe downloads the image, so it runs in a
                # worker thread; its failure only degrades the info text.
                try:
                    _, has_transparency = await loop.run_in_executor(
                        None, self.check_image_transparency, image_url
                    )
                    transparency_info = "PNG带透明通道" if has_transparency else "无透明通道"
                except Exception:
                    transparency_info = "未检测"

                async for result in self.upscale_image(image_url, config):
                    if result["status"] == "success":
                        success_count += 1
                        yield {
                            "index": i,
                            "total": total,
                            "original_image_url": image_url,
                            "status": "success",
                            "image_url": result["image_url"],
                            "success_count": success_count,
                            "error_count": error_count,
                            "transparency": transparency_info,
                            "message": result["message"]
                        }
                    else:
                        error_count += 1
                        yield {
                            "index": i,
                            "total": total,
                            "original_image_url": image_url,
                            "status": "error",
                            "success_count": success_count,
                            "error_count": error_count,
                            "message": result["message"]
                        }
            except Exception as e:
                error_count += 1
                yield {
                    "index": i,
                    "total": total,
                    "original_image_url": image_url,
                    "status": "error",
                    "success_count": success_count,
                    "error_count": error_count,
                    "message": f"处理图像时出错: {str(e)}"
                }