# japi/apps/jart_v1/service.py

import asyncio
import base64
import json
import logging
import random
import time
import urllib.error
import urllib.request
import uuid
from typing import Dict, List, Optional, AsyncGenerator

import requests
import websocket
# Module logger. basicConfig additionally configures the root logger at INFO
# level as an import-time side effect (a no-op if the root logger already has
# handlers).
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Default generation settings. TxtImgService copies this dict and layers
# per-instance and per-call overrides on top of the copy.
DEFAULT_CONFIG = dict(
    # ComfyUI server as host:port (used to build both http:// and ws:// URLs).
    comfyui_server_address="192.168.2.200:8188",
    # Model files for the workflow's checkpoint and triple-CLIP loader nodes.
    ckpt_name="sd3.5_large.safetensors",
    clip_l_name="clip_l.safetensors",
    clip_g_name="clip_g.safetensors",
    t5_name="t5xxl_fp16.safetensors",
    # KSampler settings.
    sampler_name="euler",
    scheduler="sgm_uniform",
    steps=30,
    cfg=5.5,
    denoise=1.0,
    # How many prompts (each with a fresh random seed) are queued per call.
    images_per_prompt=1,
    # Latent (output image) size in pixels.
    image_width=1024,
    image_height=1024,
    negative_prompt="blur, low quality, low resolution, artifacts, text, watermark, underexposed, bad anatomy, deformed body, extra limbs, missing limbs, noisy background, cluttered background, blurry background",
)
# Base ComfyUI workflow (API format: node id -> {class_type, inputs}) for SD3.5
# text-to-image, as a JSON string filled with %-formatting. Placeholders, in
# order:
#   1.     node "4"  CheckpointLoaderSimple.ckpt_name      (%s)
#   2-4.   node "43" TripleCLIPLoader.clip_name1..3        (%s, %s, %s)
#   5-6.   node "53" EmptySD3LatentImage.width, height     (%d, %d)
#   7-8.   node "3"  KSampler.seed, steps                  (%d, %d)
#   9.     node "3"  KSampler.cfg                          (%.2f)
#   10-11. node "3"  KSampler.sampler_name, scheduler      (%s, %s)
#   12.    node "3"  KSampler.denoise                      (%.2f)
# Graph: node 4 supplies the model (output 0) and the VAE (output 2); node 43
# supplies the text encoders to the positive (16) and negative (40)
# CLIPTextEncode nodes, whose "text" is left empty here and is set after
# json.loads. Node 3 samples the latent from 53, node 8 decodes it, and node 9
# (SaveImageWebsocket) sends the image back over the websocket.
# NOTE(review): %s values are inserted without JSON escaping, so a value
# containing '"' or '\' produces invalid JSON; cfg/denoise are rounded to two
# decimals by %.2f.
WORKFLOW_TEMPLATE = """
{
"4": {
"class_type": "CheckpointLoaderSimple",
"inputs": {
"ckpt_name": "%s"
}
},
"43": {
"class_type": "TripleCLIPLoader",
"inputs": {
"clip_name1": "%s",
"clip_name2": "%s",
"clip_name3": "%s"
}
},
"53": {
"class_type": "EmptySD3LatentImage",
"inputs": {
"width": %d,
"height": %d,
"batch_size": 1
}
},
"16": {
"class_type": "CLIPTextEncode",
"inputs": {
"clip": [
"43",
0
],
"text": ""
}
},
"40": {
"class_type": "CLIPTextEncode",
"inputs": {
"clip": [
"43",
0
],
"text": ""
}
},
"3": {
"class_type": "KSampler",
"inputs": {
"model": [
"4",
0
],
"positive": [
"16",
0
],
"negative": [
"40",
0
],
"latent_image": [
"53",
0
],
"seed": %d,
"steps": %d,
"cfg": %.2f,
"sampler_name": "%s",
"scheduler": "%s",
"denoise": %.2f
}
},
"8": {
"class_type": "VAEDecode",
"inputs": {
"samples": [
"3",
0
],
"vae": [
"4",
2
]
}
},
"9": {
"class_type": "SaveImageWebsocket",
"inputs": {
"images": [
"8",
0
]
}
}
}
"""
class TxtImgService:
    """Text-to-image service backed by a ComfyUI server.

    Each image is produced by instantiating ``WORKFLOW_TEMPLATE``, submitting
    it to ComfyUI's HTTP ``/prompt`` endpoint, and receiving the rendered
    image as a binary frame on a websocket opened with the same client id.
    """

    def __init__(self, config: Optional[Dict] = None):
        """Create the service.

        Args:
            config: Optional overrides layered on top of a copy of
                ``DEFAULT_CONFIG``; keys not given keep their defaults.
        """
        self.config = DEFAULT_CONFIG.copy()
        if config:
            self.config.update(config)

    def queue_prompt(self, prompt: Dict, comfyui_server_address: str, client_id: str) -> Dict:
        """Submit a workflow to the ComfyUI ``/prompt`` endpoint.

        Args:
            prompt: Workflow graph (node id -> node definition) to execute.
            comfyui_server_address: ``host:port`` of the ComfyUI server.
            client_id: Id the websocket was opened with, so that execution
                messages for this prompt are delivered to that connection.

        Returns:
            The decoded JSON response (``get_images`` reads its ``prompt_id``).

        Raises:
            Any error from serialisation, the HTTP request, or JSON decoding
            is logged and re-raised. For HTTP error statuses the response
            body (where ComfyUI describes why the prompt was rejected) is
            included in the log message.
        """
        try:
            payload = {"prompt": prompt, "client_id": client_id}
            data = json.dumps(payload).encode('utf-8')
            logger.debug(f"Server address: {comfyui_server_address}")
            logger.debug(f"Request data: {json.dumps(payload, indent=2)}")
            req = urllib.request.Request(
                f"http://{comfyui_server_address}/prompt",
                data=data,
                headers={
                    'Content-Type': 'application/json',
                    'Accept': 'application/json',
                },
            )
            # The context manager closes the HTTP response once it is read.
            with urllib.request.urlopen(req) as response:
                response_data = response.read()
                logger.debug(f"Response status: {response.status}")
            response_json = json.loads(response_data)
            logger.debug(f"Server response: {json.dumps(response_json, indent=2)}")
            return response_json
        except urllib.error.HTTPError as e:
            try:
                body = e.read().decode('utf-8', errors='replace')
            except Exception:
                body = ''
            logger.error(f"Failed to queue prompt: {str(e)} - {body}")
            raise
        except Exception as e:
            logger.error(f"Failed to queue prompt: {str(e)}")
            raise

    def get_images(self, ws: websocket.WebSocket, workflow: Dict, comfyui_server_address: str, client_id: str) -> Dict:
        """Queue ``workflow`` and collect the images streamed back over ``ws``.

        This call blocks until ComfyUI reports that the prompt has finished
        (an ``executing`` message with ``node`` set to None).

        Args:
            ws: Connected websocket opened with the same ``client_id``.
            workflow: Workflow graph. Nodes without ``inputs`` get an empty
                dict; a node without ``class_type`` is rejected.
            comfyui_server_address: ``host:port`` of the ComfyUI server.
            client_id: Client id shared by the websocket and the prompt.

        Returns:
            ``{node_id: [image_bytes, ...]}`` holding the binary frames
            received while node ``"9"`` (SaveImageWebsocket) was executing.
            Returns ``{}`` on any failure. Errors are logged, not raised.
        """
        try:
            # Normalise / validate the node definitions before queuing.
            for node_id, node_data in workflow.items():
                if "inputs" not in node_data:
                    node_data["inputs"] = {}
                if "class_type" not in node_data:
                    logger.error(f"Node {node_id} missing class_type")
                    raise ValueError(f"Node {node_id} missing class_type")
            logger.debug(f"Queuing prompt with workflow: {json.dumps(workflow, indent=2)}")
            prompt_response = self.queue_prompt(workflow, comfyui_server_address, client_id)
            if not isinstance(prompt_response, dict):
                logger.error(f"Invalid response type: {type(prompt_response)}")
                return {}
            prompt_id = prompt_response.get('prompt_id')
            if not prompt_id:
                logger.error("No prompt_id in response")
                return {}
            logger.debug(f"Got prompt_id: {prompt_id}")
        except Exception as e:
            logger.error(f"Failed to get prompt_id: {str(e)}")
            return {}

        output_images: Dict[str, List[bytes]] = {}
        current_node = ""
        try:
            while True:
                out = ws.recv()
                if isinstance(out, str):
                    # Text frames are JSON status messages.
                    message = json.loads(out)
                    logger.debug(f"Received message: {message}")
                    if message['type'] == 'executing':
                        data = message['data']
                        if data.get('prompt_id') == prompt_id:
                            if data['node'] is None:
                                break  # our prompt has finished executing
                            current_node = data['node']
                            logger.debug(f"Processing node: {current_node}")
                    elif message['type'] == 'execution_error':
                        data = message.get('data', {})
                        if data.get('prompt_id') == prompt_id:
                            logger.error(f"ComfyUI execution error: {data.get('exception_message', data)}")
                elif current_node == '9':  # SaveImageWebsocket node id
                    # Binary frame: the first 8 bytes are skipped (presumably
                    # ComfyUI's frame header); the rest is the encoded image.
                    output_images.setdefault(current_node, []).append(out[8:])
                    logger.debug(f"Saved image for node: {current_node}")
        except Exception as e:
            logger.error(f"Error in websocket communication: {str(e)}")
            return {}
        return output_images

    def _build_workflow(self, prompt: str, cfg: Dict, seed: int) -> Dict:
        """Instantiate ``WORKFLOW_TEMPLATE`` for one generation request.

        The template is %-formatted with neutral placeholders only to obtain
        the node graph. The real values are then assigned on the parsed dict.
        As a result, strings such as model paths containing backslashes or
        quotes cannot break the JSON, and ``cfg``/``denoise`` keep full
        precision instead of being rounded by the template's ``%.2f``.
        """
        placeholders = ("", "", "", "", 0, 0, 0, 0, 0.0, "", "", 0.0)
        workflow = json.loads(WORKFLOW_TEMPLATE % placeholders)

        workflow["4"]["inputs"]["ckpt_name"] = str(cfg['ckpt_name'])

        clip_inputs = workflow["43"]["inputs"]
        clip_inputs["clip_name1"] = str(cfg['clip_l_name'])
        clip_inputs["clip_name2"] = str(cfg['clip_g_name'])
        clip_inputs["clip_name3"] = str(cfg['t5_name'])

        latent_inputs = workflow["53"]["inputs"]
        latent_inputs["width"] = int(cfg['image_width'])
        latent_inputs["height"] = int(cfg['image_height'])

        sampler_inputs = workflow["3"]["inputs"]
        sampler_inputs["seed"] = int(seed)
        sampler_inputs["steps"] = int(cfg['steps'])
        sampler_inputs["cfg"] = float(cfg['cfg'])
        sampler_inputs["sampler_name"] = str(cfg['sampler_name'])
        sampler_inputs["scheduler"] = str(cfg['scheduler'])
        sampler_inputs["denoise"] = float(cfg['denoise'])

        # Positive / negative prompt text.
        workflow["16"]["inputs"]["text"] = prompt
        workflow["40"]["inputs"]["text"] = cfg['negative_prompt']
        return workflow

    async def generate_image(self, prompt: str, config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]:
        """Asynchronously generate ``images_per_prompt`` images for ``prompt``.

        Each iteration queues a separate prompt with a fresh random seed over
        one shared websocket connection.

        Args:
            prompt: Positive prompt text.
            config: Optional per-call overrides layered on ``self.config``.

        Yields:
            ``{"status": "success", "image": <PNG data URI>, "message": ...}``
            for each received image, or ``{"status": "error", "message": ...}``
            when an iteration produces no images, fails, or the websocket
            cannot be used.
        """
        cfg = self.config.copy()
        if config:
            cfg.update(config)
        loop = asyncio.get_running_loop()
        ws = websocket.WebSocket()
        client_id = str(uuid.uuid4())
        try:
            # websocket-client and urllib are blocking. Run them in the default
            # executor so the event loop is not stalled during generation.
            await loop.run_in_executor(
                None, ws.connect, f"ws://{cfg['comfyui_server_address']}/ws?clientId={client_id}"
            )
            logger.info("WebSocket connected successfully")
            for i in range(cfg['images_per_prompt']):
                logger.info(f"Processing image {i+1}/{cfg['images_per_prompt']}")
                seed = random.randint(1, 4294967295)
                try:
                    workflow = self._build_workflow(prompt, cfg, seed)
                    images = await loop.run_in_executor(
                        None, self.get_images, ws, workflow, cfg['comfyui_server_address'], client_id
                    )
                    if not images:
                        yield {
                            "status": "error",
                            "message": "No images generated"
                        }
                        continue  # skips the pacing sleep below
                    for image_list in images.values():
                        for image_data in image_list:
                            base64_image = base64.b64encode(image_data).decode('utf-8')
                            yield {
                                "status": "success",
                                "image": f"data:image/png;base64,{base64_image}",
                                "message": f"成功生成第 {i+1} 张图片"
                            }
                except Exception as e:
                    logger.error(f"Error generating image: {str(e)}")
                    yield {
                        "status": "error",
                        "message": f"生成图片失败: {str(e)}"
                    }
                await asyncio.sleep(2)  # pace successive requests to the server
        except Exception as e:
            logger.error(f"WebSocket connection error: {str(e)}")
            yield {
                "status": "error",
                "message": f"WebSocket连接失败: {str(e)}"
            }
        finally:
            ws.close()
            logger.info("WebSocket connection closed")

    async def process_batch(self, prompts: List[str], config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]:
        """Run ``generate_image`` for each prompt and annotate every result.

        Args:
            prompts: Prompts to process, in order.
            config: Optional per-call overrides forwarded to ``generate_image``.

        Yields:
            Per-result dicts carrying the 1-based ``index``, ``total``,
            ``original_prompt``, ``status``, and running
            ``success_count``/``error_count``. Success results add
            ``image_content``; unexpected exceptions add ``error``.
        """
        total = len(prompts)
        success_count = 0
        error_count = 0
        for i, prompt in enumerate(prompts, 1):
            try:
                async for result in self.generate_image(prompt, config):
                    if result["status"] == "success":
                        success_count += 1
                        yield {
                            "index": i,
                            "total": total,
                            "original_prompt": prompt,
                            "status": "success",
                            "image_content": result["image"],
                            "success_count": success_count,
                            "error_count": error_count,
                            "message": result["message"]
                        }
                    else:
                        error_count += 1
                        yield {
                            "index": i,
                            "total": total,
                            "original_prompt": prompt,
                            "status": "error",
                            "success_count": success_count,
                            "error_count": error_count,
                            "message": result["message"]
                        }
            except Exception as e:
                error_count += 1
                yield {
                    "index": i,
                    "total": total,
                    "original_prompt": prompt,
                    "status": "error",
                    "error": str(e),
                    "success_count": success_count,
                    "error_count": error_count,
                    "message": f"处理失败: {str(e)}"
                }
            await asyncio.sleep(0)