# japi/apps/jart_v1/service.py
# (File-viewer paste residue — "295 lines / 9.7 KiB / Python" — converted to
# this comment; the bare lines were not valid Python.)
import json
import base64
import requests
import random
import time
import websocket
import uuid
import urllib.request
import asyncio
from typing import Dict, List, Optional, AsyncGenerator
# Default configuration for the ComfyUI SD3.5 text-to-image pipeline
# (original comment "默认配置" = "default configuration").
# Every key can be overridden per-instance via TxtImgService(config=...)
# or per-call via generate_image(..., config=...).
DEFAULT_CONFIG = {
    # host:port of the ComfyUI server (HTTP /prompt and /ws websocket)
    "comfyui_server_address": "192.168.2.200:8188",
    # checkpoint and text-encoder filenames as known to the ComfyUI server
    "ckpt_name": "sd3.5_large.safetensors",
    "clip_l_name": "clip_l.safetensors",
    "clip_g_name": "clip_g.safetensors",
    "t5_name": "t5xxl_fp16.safetensors",
    # KSampler parameters (see nodes "3" in WORKFLOW_TEMPLATE)
    "sampler_name": "euler",
    "scheduler": "sgm_uniform",
    "steps": 30,
    "cfg": 5.5,
    "denoise": 1.0,
    # output settings
    "images_per_prompt": 1,
    "image_width": 1024,
    "image_height": 1024,
    "negative_prompt": "blur, low quality, low resolution, artifacts, text, watermark, underexposed, bad anatomy, deformed body, extra limbs, missing limbs, noisy background, cluttered background, blurry background"
}
# Base workflow JSON template (original comment "定义基础工作流 JSON 模板" =
# "define the base workflow JSON template").
# Filled with the % operator in generate_image(), placeholders in order:
#   ckpt_name, clip_l_name, clip_g_name, t5_name, width, height,
#   seed, steps, cfg, sampler_name, scheduler, denoise.
# Node ids used elsewhere in this module:
#   "4"  CheckpointLoaderSimple   "43" TripleCLIPLoader
#   "53" EmptySD3LatentImage      "16" positive CLIPTextEncode
#   "40" negative CLIPTextEncode  "3"  KSampler
#   "8"  VAEDecode               "9"  SaveImageWebsocket (streams PNGs back)
# NOTE: the string body is runtime data parsed by json.loads — kept verbatim.
WORKFLOW_TEMPLATE = """
{
"4": {
"class_type": "CheckpointLoaderSimple",
"inputs": {
"ckpt_name": "%s"
}
},
"43": {
"class_type": "TripleCLIPLoader",
"inputs": {
"clip_name1": "%s",
"clip_name2": "%s",
"clip_name3": "%s"
}
},
"53": {
"class_type": "EmptySD3LatentImage",
"inputs": {
"width": %d,
"height": %d,
"batch_size": 1
}
},
"16": {
"class_type": "CLIPTextEncode",
"inputs": {
"clip": [
"43",
0
],
"text": ""
}
},
"40": {
"class_type": "CLIPTextEncode",
"inputs": {
"clip": [
"43",
0
],
"text": ""
}
},
"3": {
"class_type": "KSampler",
"inputs": {
"model": [
"4",
0
],
"positive": [
"16",
0
],
"negative": [
"40",
0
],
"latent_image": [
"53",
0
],
"seed": %d,
"steps": %d,
"cfg": %.2f,
"sampler_name": "%s",
"scheduler": "%s",
"denoise": %.2f
}
},
"8": {
"class_type": "VAEDecode",
"inputs": {
"samples": [
"3",
0
],
"vae": [
"4",
2
]
}
},
"9": {
"class_type": "SaveImageWebsocket",
"inputs": {
"images": [
"8",
0
]
}
}
}
"""
class TxtImgService:
    """Text-to-image service driving a remote ComfyUI server (SD3.5 workflow).

    The service renders images by:
      1. filling ``WORKFLOW_TEMPLATE`` with the effective configuration,
      2. POSTing the graph to ComfyUI's ``/prompt`` HTTP endpoint, and
      3. collecting the rendered PNG bytes from the ComfyUI websocket.
    """

    def __init__(self, config: Optional[Dict] = None):
        """Build the effective config: DEFAULT_CONFIG overlaid with *config*."""
        self.config = DEFAULT_CONFIG.copy()
        if config:
            self.config.update(config)

    def queue_prompt(self, prompt: Dict, comfyui_server_address: str, client_id: str) -> Dict:
        """POST *prompt* (a workflow graph) to ComfyUI and return its JSON reply.

        The reply normally contains ``prompt_id``, which ties websocket
        progress events back to this submission.  Network/JSON errors
        propagate to the caller (``get_images`` converts them to ``{}``).
        """
        payload = json.dumps({"prompt": prompt, "client_id": client_id}).encode('utf-8')
        req = urllib.request.Request(
            f"http://{comfyui_server_address}/prompt",
            data=payload,
            headers={
                'Content-Type': 'application/json',
                'Accept': 'application/json',
            },
        )
        # Fix: close the HTTP response deterministically — the previous
        # version never closed the object returned by urlopen (leak).
        with urllib.request.urlopen(req) as response:
            return json.loads(response.read())

    def get_images(self, ws: websocket.WebSocket, workflow: Dict,
                   comfyui_server_address: str, client_id: str) -> Dict:
        """Queue *workflow* and collect its output images over the websocket.

        Returns a mapping of node id -> list of raw PNG byte strings, or an
        empty dict on any failure (best-effort contract kept from the
        original implementation).
        """
        try:
            # Minimal structural validation before submitting the graph.
            for node_id, node_data in workflow.items():
                node_data.setdefault("inputs", {})
                if "class_type" not in node_data:
                    raise ValueError(f"Node {node_id} missing class_type")
            prompt_response = self.queue_prompt(workflow, comfyui_server_address, client_id)
            if not isinstance(prompt_response, dict):
                return {}
            prompt_id = prompt_response.get('prompt_id')
            if not prompt_id:
                return {}
        except Exception:
            # Best-effort: an empty dict signals failure to the caller.
            return {}

        output_images: Dict[str, List[bytes]] = {}
        current_node = ""
        try:
            while True:
                out = ws.recv()
                if isinstance(out, str):
                    # Text frames are JSON status events.
                    message = json.loads(out)
                    if message['type'] == 'executing':
                        data = message['data']
                        if data.get('prompt_id') == prompt_id:
                            if data['node'] is None:
                                break  # node == None marks end of execution
                            current_node = data['node']
                else:
                    # Binary frames: an 8-byte ComfyUI event header followed
                    # by image bytes.  "9" is the SaveImageWebsocket node.
                    if current_node == '9':
                        output_images.setdefault(current_node, []).append(out[8:])
        except Exception:
            return {}
        return output_images

    async def generate_image(self, prompt: str, config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]:
        """Yield one result dict per generated image for *prompt*.

        Each yielded dict has ``status`` ("success"/"error") and ``message``;
        on success it also carries a base64 ``data:image/png`` URI under
        ``image``.  ``images_per_prompt`` images are attempted per call.

        NOTE(review): the websocket calls are blocking and stall the event
        loop while an image renders — consider ``run_in_executor``.
        """
        cfg = self.config.copy()
        if config:
            cfg.update(config)
        client_id = str(uuid.uuid4())
        ws = websocket.WebSocket()
        try:
            ws.connect(f"ws://{cfg['comfyui_server_address']}/ws?clientId={client_id}")
            for i in range(cfg['images_per_prompt']):
                # Fresh random seed per image (full unsigned 32-bit range).
                seed = random.randint(1, 4294967295)
                try:
                    workflow = json.loads(WORKFLOW_TEMPLATE % (
                        cfg['ckpt_name'],
                        cfg['clip_l_name'],
                        cfg['clip_g_name'],
                        cfg['t5_name'],
                        cfg['image_width'],
                        cfg['image_height'],
                        seed,
                        cfg['steps'],
                        cfg['cfg'],
                        cfg['sampler_name'],
                        cfg['scheduler'],
                        cfg['denoise'],
                    ))
                    # Node 16 = positive prompt, node 40 = negative prompt.
                    workflow["16"]["inputs"]["text"] = prompt
                    workflow["40"]["inputs"]["text"] = cfg['negative_prompt']
                    # The API format must not carry UI-only widget values.
                    for node in workflow.values():
                        node.pop("widgets_values", None)
                    images = self.get_images(ws, workflow, cfg['comfyui_server_address'], client_id)
                    if not images:
                        yield {
                            "status": "error",
                            "message": "No images generated"
                        }
                        continue
                    for image_list in images.values():
                        for image_data in image_list:
                            base64_image = base64.b64encode(image_data).decode('utf-8')
                            yield {
                                "status": "success",
                                "image": f"data:image/png;base64,{base64_image}",
                                "message": f"成功生成第 {i+1} 张图片"
                            }
                except Exception as e:
                    yield {
                        "status": "error",
                        "message": f"生成图片失败: {str(e)}"
                    }
                # Brief pause between images to avoid hammering the server.
                await asyncio.sleep(2)
        except Exception as e:
            yield {
                "status": "error",
                "message": f"WebSocket连接失败: {str(e)}"
            }
        finally:
            # ws is always bound here; close() is safe even if connect failed
            # (the original guarded with `if ws:`, which was always truthy).
            ws.close()

    async def process_batch(self, prompts: List[str], config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]:
        """Run generate_image over *prompts*, yielding per-result progress.

        Each yielded dict carries the 1-based prompt ``index``, the ``total``
        prompt count, the ``original_prompt``, running ``success_count`` /
        ``error_count``, and either ``image_content`` (success) or an error
        ``message``.
        """
        total = len(prompts)
        success_count = 0
        error_count = 0
        for i, prompt in enumerate(prompts, 1):
            try:
                async for result in self.generate_image(prompt, config):
                    if result["status"] == "success":
                        success_count += 1
                        yield {
                            "index": i,
                            "total": total,
                            "original_prompt": prompt,
                            "status": "success",
                            "image_content": result["image"],
                            "success_count": success_count,
                            "error_count": error_count,
                            "message": result["message"]
                        }
                    else:
                        error_count += 1
                        yield {
                            "index": i,
                            "total": total,
                            "original_prompt": prompt,
                            "status": "error",
                            "success_count": success_count,
                            "error_count": error_count,
                            "message": result["message"]
                        }
            except Exception as e:
                error_count += 1
                yield {
                    "index": i,
                    "total": total,
                    "original_prompt": prompt,
                    "status": "error",
                    "error": str(e),
                    "success_count": success_count,
                    "error_count": error_count,
                    "message": f"处理失败: {str(e)}"
                }
            # Yield control to the event loop between prompts.
            await asyncio.sleep(0)