重构jart_v1支持sd3.5模型

This commit is contained in:
jingrow 2025-06-20 00:18:39 +08:00
parent dee2929268
commit 61267eacdd

View File

@ -2,21 +2,29 @@ import json
import base64
import requests
import random
import time
import websocket
import uuid
import urllib.request
import asyncio
import io
from typing import Dict, List, Generator, Optional, AsyncGenerator
import logging
from typing import Dict, List, Optional, AsyncGenerator
# 固定配置变量
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 默认配置
DEFAULT_CONFIG = {
"comfyui_server_address": "192.168.2.200:8188",
"ckpt_name": "sd3_medium_incl_clips_t5xxlfp8.safetensors",
"ckpt_name": "sd3.5_large.safetensors",
"clip_l_name": "clip_l.safetensors",
"clip_g_name": "clip_g.safetensors",
"t5_name": "t5xxl_fp16.safetensors",
"sampler_name": "euler",
"scheduler": "normal",
"steps": 20,
"cfg": 8,
"scheduler": "sgm_uniform",
"steps": 30,
"cfg": 5.5,
"denoise": 1.0,
"images_per_prompt": 1,
"image_width": 1024,
@ -27,65 +35,73 @@ DEFAULT_CONFIG = {
# 定义基础工作流 JSON 模板
WORKFLOW_TEMPLATE = """
{
"3": {
"class_type": "KSampler",
"inputs": {
"cfg": %d,
"denoise": %d,
"latent_image": [
"5",
0
],
"model": [
"4",
0
],
"negative": [
"7",
0
],
"positive": [
"6",
0
],
"sampler_name": "%s",
"scheduler": "%s",
"seed": 8566257,
"steps": %d
}
},
"4": {
"class_type": "CheckpointLoaderSimple",
"inputs": {
"ckpt_name": "%s"
}
},
"5": {
"class_type": "EmptyLatentImage",
"43": {
"class_type": "TripleCLIPLoader",
"inputs": {
"batch_size": 1,
"clip_name1": "%s",
"clip_name2": "%s",
"clip_name3": "%s"
}
},
"53": {
"class_type": "EmptySD3LatentImage",
"inputs": {
"width": %d,
"height": %d,
"width": %d
"batch_size": 1
}
},
"6": {
"16": {
"class_type": "CLIPTextEncode",
"inputs": {
"clip": [
"4",
1
"43",
0
],
"text": "masterpiece best quality girl"
"text": ""
}
},
"7": {
"40": {
"class_type": "CLIPTextEncode",
"inputs": {
"clip": [
"4",
1
"43",
0
],
"text": "%s"
"text": ""
}
},
"3": {
"class_type": "KSampler",
"inputs": {
"model": [
"4",
0
],
"positive": [
"16",
0
],
"negative": [
"40",
0
],
"latent_image": [
"53",
0
],
"seed": %d,
"steps": %d,
"cfg": %.2f,
"sampler_name": "%s",
"scheduler": "%s",
"denoise": %.2f
}
},
"8": {
@ -101,7 +117,7 @@ WORKFLOW_TEMPLATE = """
]
}
},
"save_image_websocket_node": {
"9": {
"class_type": "SaveImageWebsocket",
"inputs": {
"images": [
@ -114,117 +130,189 @@ WORKFLOW_TEMPLATE = """
"""
class TxtImgService:
def __init__(self):
def __init__(self, config: Optional[Dict] = None):
"""初始化文本生成图像服务"""
pass
self.config = DEFAULT_CONFIG.copy()
if config:
self.config.update(config)
def queue_prompt(self, prompt: Dict, comfyui_server_address: str, client_id: str) -> Dict:
"""将提示词发送到 ComfyUI 服务器的队列中"""
p = {"prompt": prompt, "client_id": client_id}
data = json.dumps(p).encode('utf-8')
req = urllib.request.Request(f"http://{comfyui_server_address}/prompt", data=data)
response = json.loads(urllib.request.urlopen(req).read())
return response
try:
p = {"prompt": prompt, "client_id": client_id}
data = json.dumps(p).encode('utf-8')
logger.debug(f"Server address: {comfyui_server_address}")
logger.debug(f"Request data: {json.dumps(p, indent=2)}")
headers = {
'Content-Type': 'application/json',
'Accept': 'application/json'
}
req = urllib.request.Request(
f"http://{comfyui_server_address}/prompt",
data=data,
headers=headers
)
response = urllib.request.urlopen(req)
response_data = response.read()
logger.debug(f"Response status: {response.status}")
response_json = json.loads(response_data)
logger.debug(f"Server response: {json.dumps(response_json, indent=2)}")
return response_json
except Exception as e:
logger.error(f"Failed to queue prompt: {str(e)}")
raise
def get_images(self, ws: websocket.WebSocket, workflow: Dict, comfyui_server_address: str, client_id: str) -> Dict:
"""从 ComfyUI 获取生成的图像"""
try:
# 确保工作流中的所有节点都有正确的格式
for node_id, node_data in workflow.items():
if "inputs" not in node_data:
node_data["inputs"] = {}
if "class_type" not in node_data:
logger.error(f"Node {node_id} missing class_type")
raise ValueError(f"Node {node_id} missing class_type")
logger.debug(f"Queuing prompt with workflow: {json.dumps(workflow, indent=2)}")
prompt_response = self.queue_prompt(workflow, comfyui_server_address, client_id)
prompt_id = prompt_response['prompt_id']
except KeyError:
if not isinstance(prompt_response, dict):
logger.error(f"Invalid response type: {type(prompt_response)}")
return {}
prompt_id = prompt_response.get('prompt_id')
if not prompt_id:
logger.error("No prompt_id in response")
return {}
logger.debug(f"Got prompt_id: {prompt_id}")
except Exception as e:
logger.error(f"Failed to get prompt_id: {str(e)}")
return {}
output_images = {}
current_node = ""
while True:
out = ws.recv()
if isinstance(out, str):
message = json.loads(out)
if message['type'] == 'executing':
data = message['data']
if data.get('prompt_id') == prompt_id:
if data['node'] is None:
break
else:
current_node = data['node']
else:
if current_node == 'save_image_websocket_node':
images_output = output_images.get(current_node, [])
images_output.append(out[8:])
output_images[current_node] = images_output
try:
while True:
out = ws.recv()
if isinstance(out, str):
message = json.loads(out)
logger.debug(f"Received message: {message}")
if message['type'] == 'executing':
data = message['data']
if data.get('prompt_id') == prompt_id:
if data['node'] is None:
break
else:
current_node = data['node']
logger.debug(f"Processing node: {current_node}")
else:
if current_node == '9': # SaveImageWebsocket节点ID
images_output = output_images.get(current_node, [])
images_output.append(out[8:])
output_images[current_node] = images_output
logger.debug(f"Saved image for node: {current_node}")
except Exception as e:
logger.error(f"Error in websocket communication: {str(e)}")
return {}
return output_images
def generate_image_sync(self, prompt: str, config: Optional[Dict] = None) -> Generator[str, None, None]:
"""生成 Flux 模型的图片,流式返回 base64 编码的图片"""
cfg = DEFAULT_CONFIG.copy()
async def generate_image(self, prompt: str, config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]:
"""异步生成图像"""
cfg = self.config.copy()
if config:
cfg.update(config)
ws = websocket.WebSocket()
client_id = str(uuid.uuid4())
try:
ws.connect(f"ws://{cfg['comfyui_server_address']}/ws?clientId={client_id}")
images_count = int(cfg.get('images_per_prompt', 1))
logger.info("WebSocket connected successfully")
for i in range(images_count):
workflow = json.loads(WORKFLOW_TEMPLATE % (
cfg['cfg'],
cfg['denoise'],
cfg['sampler_name'],
cfg['scheduler'],
cfg['steps'],
cfg['ckpt_name'],
cfg['image_height'],
cfg['image_width'],
cfg['negative_prompt']
))
for i in range(cfg['images_per_prompt']):
logger.info(f"Processing image {i+1}/{cfg['images_per_prompt']}")
workflow["6"]["inputs"]["text"] = prompt
# 生成随机种子
seed = random.randint(1, 4294967295)
workflow["3"]["inputs"]["seed"] = seed
images_dict = self.get_images(ws, workflow, cfg['comfyui_server_address'], client_id)
try:
# 准备参数
params = (
cfg['ckpt_name'],
cfg['clip_l_name'],
cfg['clip_g_name'],
cfg['t5_name'],
cfg['image_width'],
cfg['image_height'],
seed,
cfg['steps'],
cfg['cfg'],
cfg['sampler_name'],
cfg['scheduler'],
cfg['denoise']
)
# 格式化工作流
workflow = json.loads(WORKFLOW_TEMPLATE % params)
# 设置提示词
workflow["16"]["inputs"]["text"] = prompt
workflow["40"]["inputs"]["text"] = cfg['negative_prompt']
# 移除空字段
for node in workflow.values():
if "widgets_values" in node:
del node["widgets_values"]
# 获取生成的图像
images = self.get_images(ws, workflow, cfg['comfyui_server_address'], client_id)
if not images:
yield {
"status": "error",
"message": "No images generated"
}
continue
# 处理生成的图像
for node_id, image_list in images.items():
for image_data in image_list:
base64_image = base64.b64encode(image_data).decode('utf-8')
yield {
"status": "success",
"image": f"data:image/png;base64,{base64_image}",
"message": f"成功生成第 {i+1} 张图片"
}
except Exception as e:
logger.error(f"Error generating image: {str(e)}")
yield {
"status": "error",
"message": f"生成图片失败: {str(e)}"
}
await asyncio.sleep(2) # 避免请求过于频繁
for node_id, image_list in images_dict.items():
for image_data in image_list:
base64_image = base64.b64encode(image_data).decode('utf-8')
yield base64_image
except Exception as e:
raise e
logger.error(f"WebSocket connection error: {str(e)}")
yield {
"status": "error",
"message": f"WebSocket连接失败: {str(e)}"
}
finally:
if ws:
ws.close()
logger.info("WebSocket connection closed")
async def generate_image(self, prompt: str, config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]:
"""异步生成图像,流式返回结果"""
try:
loop = asyncio.get_event_loop()
def sync_generator():
for base64_image in self.generate_image_sync(prompt, config):
yield base64_image
generator = await loop.run_in_executor(None, sync_generator)
for base64_image in generator:
yield {
"status": "success",
"image": f"data:image/png;base64,{base64_image}",
"message": f"成功生成图片"
}
except Exception as e:
yield {
"status": "error",
"message": f"图像生成失败: {str(e)}"
}
async def process_batch(self, prompts: List[str], config: Optional[Dict] = None):
"""批量处理多个文本提示,流式返回结果"""
async def process_batch(self, prompts: List[str], config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]:
"""批量处理多个提示词"""
total = len(prompts)
success_count = 0
error_count = 0