重构jart_v1支持sd3.5模型
This commit is contained in:
parent
dee2929268
commit
61267eacdd
@ -2,21 +2,29 @@ import json
|
||||
import base64
|
||||
import requests
|
||||
import random
|
||||
import time
|
||||
import websocket
|
||||
import uuid
|
||||
import urllib.request
|
||||
import asyncio
|
||||
import io
|
||||
from typing import Dict, List, Generator, Optional, AsyncGenerator
|
||||
import logging
|
||||
from typing import Dict, List, Optional, AsyncGenerator
|
||||
|
||||
# 固定配置变量
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 默认配置
|
||||
DEFAULT_CONFIG = {
|
||||
"comfyui_server_address": "192.168.2.200:8188",
|
||||
"ckpt_name": "sd3_medium_incl_clips_t5xxlfp8.safetensors",
|
||||
"ckpt_name": "sd3.5_large.safetensors",
|
||||
"clip_l_name": "clip_l.safetensors",
|
||||
"clip_g_name": "clip_g.safetensors",
|
||||
"t5_name": "t5xxl_fp16.safetensors",
|
||||
"sampler_name": "euler",
|
||||
"scheduler": "normal",
|
||||
"steps": 20,
|
||||
"cfg": 8,
|
||||
"scheduler": "sgm_uniform",
|
||||
"steps": 30,
|
||||
"cfg": 5.5,
|
||||
"denoise": 1.0,
|
||||
"images_per_prompt": 1,
|
||||
"image_width": 1024,
|
||||
@ -27,65 +35,73 @@ DEFAULT_CONFIG = {
|
||||
# 定义基础工作流 JSON 模板
|
||||
WORKFLOW_TEMPLATE = """
|
||||
{
|
||||
"3": {
|
||||
"class_type": "KSampler",
|
||||
"inputs": {
|
||||
"cfg": %d,
|
||||
"denoise": %d,
|
||||
"latent_image": [
|
||||
"5",
|
||||
0
|
||||
],
|
||||
"model": [
|
||||
"4",
|
||||
0
|
||||
],
|
||||
"negative": [
|
||||
"7",
|
||||
0
|
||||
],
|
||||
"positive": [
|
||||
"6",
|
||||
0
|
||||
],
|
||||
"sampler_name": "%s",
|
||||
"scheduler": "%s",
|
||||
"seed": 8566257,
|
||||
"steps": %d
|
||||
}
|
||||
},
|
||||
"4": {
|
||||
"class_type": "CheckpointLoaderSimple",
|
||||
"inputs": {
|
||||
"ckpt_name": "%s"
|
||||
}
|
||||
},
|
||||
"5": {
|
||||
"class_type": "EmptyLatentImage",
|
||||
"43": {
|
||||
"class_type": "TripleCLIPLoader",
|
||||
"inputs": {
|
||||
"batch_size": 1,
|
||||
"clip_name1": "%s",
|
||||
"clip_name2": "%s",
|
||||
"clip_name3": "%s"
|
||||
}
|
||||
},
|
||||
"53": {
|
||||
"class_type": "EmptySD3LatentImage",
|
||||
"inputs": {
|
||||
"width": %d,
|
||||
"height": %d,
|
||||
"width": %d
|
||||
"batch_size": 1
|
||||
}
|
||||
},
|
||||
"6": {
|
||||
"16": {
|
||||
"class_type": "CLIPTextEncode",
|
||||
"inputs": {
|
||||
"clip": [
|
||||
"4",
|
||||
1
|
||||
"43",
|
||||
0
|
||||
],
|
||||
"text": "masterpiece best quality girl"
|
||||
"text": ""
|
||||
}
|
||||
},
|
||||
"7": {
|
||||
"40": {
|
||||
"class_type": "CLIPTextEncode",
|
||||
"inputs": {
|
||||
"clip": [
|
||||
"4",
|
||||
1
|
||||
"43",
|
||||
0
|
||||
],
|
||||
"text": "%s"
|
||||
"text": ""
|
||||
}
|
||||
},
|
||||
"3": {
|
||||
"class_type": "KSampler",
|
||||
"inputs": {
|
||||
"model": [
|
||||
"4",
|
||||
0
|
||||
],
|
||||
"positive": [
|
||||
"16",
|
||||
0
|
||||
],
|
||||
"negative": [
|
||||
"40",
|
||||
0
|
||||
],
|
||||
"latent_image": [
|
||||
"53",
|
||||
0
|
||||
],
|
||||
"seed": %d,
|
||||
"steps": %d,
|
||||
"cfg": %.2f,
|
||||
"sampler_name": "%s",
|
||||
"scheduler": "%s",
|
||||
"denoise": %.2f
|
||||
}
|
||||
},
|
||||
"8": {
|
||||
@ -101,7 +117,7 @@ WORKFLOW_TEMPLATE = """
|
||||
]
|
||||
}
|
||||
},
|
||||
"save_image_websocket_node": {
|
||||
"9": {
|
||||
"class_type": "SaveImageWebsocket",
|
||||
"inputs": {
|
||||
"images": [
|
||||
@ -114,117 +130,189 @@ WORKFLOW_TEMPLATE = """
|
||||
"""
|
||||
|
||||
class TxtImgService:
|
||||
def __init__(self):
|
||||
def __init__(self, config: Optional[Dict] = None):
|
||||
"""初始化文本生成图像服务"""
|
||||
pass
|
||||
self.config = DEFAULT_CONFIG.copy()
|
||||
if config:
|
||||
self.config.update(config)
|
||||
|
||||
def queue_prompt(self, prompt: Dict, comfyui_server_address: str, client_id: str) -> Dict:
|
||||
"""将提示词发送到 ComfyUI 服务器的队列中"""
|
||||
p = {"prompt": prompt, "client_id": client_id}
|
||||
data = json.dumps(p).encode('utf-8')
|
||||
req = urllib.request.Request(f"http://{comfyui_server_address}/prompt", data=data)
|
||||
response = json.loads(urllib.request.urlopen(req).read())
|
||||
return response
|
||||
try:
|
||||
p = {"prompt": prompt, "client_id": client_id}
|
||||
data = json.dumps(p).encode('utf-8')
|
||||
logger.debug(f"Server address: {comfyui_server_address}")
|
||||
logger.debug(f"Request data: {json.dumps(p, indent=2)}")
|
||||
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': 'application/json'
|
||||
}
|
||||
req = urllib.request.Request(
|
||||
f"http://{comfyui_server_address}/prompt",
|
||||
data=data,
|
||||
headers=headers
|
||||
)
|
||||
|
||||
response = urllib.request.urlopen(req)
|
||||
response_data = response.read()
|
||||
logger.debug(f"Response status: {response.status}")
|
||||
response_json = json.loads(response_data)
|
||||
logger.debug(f"Server response: {json.dumps(response_json, indent=2)}")
|
||||
return response_json
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to queue prompt: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_images(self, ws: websocket.WebSocket, workflow: Dict, comfyui_server_address: str, client_id: str) -> Dict:
|
||||
"""从 ComfyUI 获取生成的图像"""
|
||||
try:
|
||||
# 确保工作流中的所有节点都有正确的格式
|
||||
for node_id, node_data in workflow.items():
|
||||
if "inputs" not in node_data:
|
||||
node_data["inputs"] = {}
|
||||
if "class_type" not in node_data:
|
||||
logger.error(f"Node {node_id} missing class_type")
|
||||
raise ValueError(f"Node {node_id} missing class_type")
|
||||
|
||||
logger.debug(f"Queuing prompt with workflow: {json.dumps(workflow, indent=2)}")
|
||||
prompt_response = self.queue_prompt(workflow, comfyui_server_address, client_id)
|
||||
prompt_id = prompt_response['prompt_id']
|
||||
except KeyError:
|
||||
|
||||
if not isinstance(prompt_response, dict):
|
||||
logger.error(f"Invalid response type: {type(prompt_response)}")
|
||||
return {}
|
||||
|
||||
prompt_id = prompt_response.get('prompt_id')
|
||||
if not prompt_id:
|
||||
logger.error("No prompt_id in response")
|
||||
return {}
|
||||
|
||||
logger.debug(f"Got prompt_id: {prompt_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get prompt_id: {str(e)}")
|
||||
return {}
|
||||
|
||||
output_images = {}
|
||||
current_node = ""
|
||||
while True:
|
||||
out = ws.recv()
|
||||
if isinstance(out, str):
|
||||
message = json.loads(out)
|
||||
if message['type'] == 'executing':
|
||||
data = message['data']
|
||||
if data.get('prompt_id') == prompt_id:
|
||||
if data['node'] is None:
|
||||
break
|
||||
else:
|
||||
current_node = data['node']
|
||||
else:
|
||||
if current_node == 'save_image_websocket_node':
|
||||
images_output = output_images.get(current_node, [])
|
||||
images_output.append(out[8:])
|
||||
output_images[current_node] = images_output
|
||||
try:
|
||||
while True:
|
||||
out = ws.recv()
|
||||
if isinstance(out, str):
|
||||
message = json.loads(out)
|
||||
logger.debug(f"Received message: {message}")
|
||||
if message['type'] == 'executing':
|
||||
data = message['data']
|
||||
if data.get('prompt_id') == prompt_id:
|
||||
if data['node'] is None:
|
||||
break
|
||||
else:
|
||||
current_node = data['node']
|
||||
logger.debug(f"Processing node: {current_node}")
|
||||
else:
|
||||
if current_node == '9': # SaveImageWebsocket节点ID
|
||||
images_output = output_images.get(current_node, [])
|
||||
images_output.append(out[8:])
|
||||
output_images[current_node] = images_output
|
||||
logger.debug(f"Saved image for node: {current_node}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in websocket communication: {str(e)}")
|
||||
return {}
|
||||
|
||||
return output_images
|
||||
|
||||
def generate_image_sync(self, prompt: str, config: Optional[Dict] = None) -> Generator[str, None, None]:
|
||||
"""生成 Flux 模型的图片,流式返回 base64 编码的图片"""
|
||||
cfg = DEFAULT_CONFIG.copy()
|
||||
async def generate_image(self, prompt: str, config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]:
|
||||
"""异步生成图像"""
|
||||
cfg = self.config.copy()
|
||||
if config:
|
||||
cfg.update(config)
|
||||
|
||||
|
||||
ws = websocket.WebSocket()
|
||||
client_id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
ws.connect(f"ws://{cfg['comfyui_server_address']}/ws?clientId={client_id}")
|
||||
images_count = int(cfg.get('images_per_prompt', 1))
|
||||
logger.info("WebSocket connected successfully")
|
||||
|
||||
for i in range(images_count):
|
||||
workflow = json.loads(WORKFLOW_TEMPLATE % (
|
||||
cfg['cfg'],
|
||||
cfg['denoise'],
|
||||
cfg['sampler_name'],
|
||||
cfg['scheduler'],
|
||||
cfg['steps'],
|
||||
cfg['ckpt_name'],
|
||||
cfg['image_height'],
|
||||
cfg['image_width'],
|
||||
cfg['negative_prompt']
|
||||
))
|
||||
for i in range(cfg['images_per_prompt']):
|
||||
logger.info(f"Processing image {i+1}/{cfg['images_per_prompt']}")
|
||||
|
||||
workflow["6"]["inputs"]["text"] = prompt
|
||||
# 生成随机种子
|
||||
seed = random.randint(1, 4294967295)
|
||||
workflow["3"]["inputs"]["seed"] = seed
|
||||
|
||||
images_dict = self.get_images(ws, workflow, cfg['comfyui_server_address'], client_id)
|
||||
try:
|
||||
# 准备参数
|
||||
params = (
|
||||
cfg['ckpt_name'],
|
||||
cfg['clip_l_name'],
|
||||
cfg['clip_g_name'],
|
||||
cfg['t5_name'],
|
||||
cfg['image_width'],
|
||||
cfg['image_height'],
|
||||
seed,
|
||||
cfg['steps'],
|
||||
cfg['cfg'],
|
||||
cfg['sampler_name'],
|
||||
cfg['scheduler'],
|
||||
cfg['denoise']
|
||||
)
|
||||
|
||||
# 格式化工作流
|
||||
workflow = json.loads(WORKFLOW_TEMPLATE % params)
|
||||
|
||||
# 设置提示词
|
||||
workflow["16"]["inputs"]["text"] = prompt
|
||||
workflow["40"]["inputs"]["text"] = cfg['negative_prompt']
|
||||
|
||||
# 移除空字段
|
||||
for node in workflow.values():
|
||||
if "widgets_values" in node:
|
||||
del node["widgets_values"]
|
||||
|
||||
# 获取生成的图像
|
||||
images = self.get_images(ws, workflow, cfg['comfyui_server_address'], client_id)
|
||||
|
||||
if not images:
|
||||
yield {
|
||||
"status": "error",
|
||||
"message": "No images generated"
|
||||
}
|
||||
continue
|
||||
|
||||
# 处理生成的图像
|
||||
for node_id, image_list in images.items():
|
||||
for image_data in image_list:
|
||||
base64_image = base64.b64encode(image_data).decode('utf-8')
|
||||
yield {
|
||||
"status": "success",
|
||||
"image": f"data:image/png;base64,{base64_image}",
|
||||
"message": f"成功生成第 {i+1} 张图片"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating image: {str(e)}")
|
||||
yield {
|
||||
"status": "error",
|
||||
"message": f"生成图片失败: {str(e)}"
|
||||
}
|
||||
|
||||
await asyncio.sleep(2) # 避免请求过于频繁
|
||||
|
||||
for node_id, image_list in images_dict.items():
|
||||
for image_data in image_list:
|
||||
base64_image = base64.b64encode(image_data).decode('utf-8')
|
||||
yield base64_image
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
logger.error(f"WebSocket connection error: {str(e)}")
|
||||
yield {
|
||||
"status": "error",
|
||||
"message": f"WebSocket连接失败: {str(e)}"
|
||||
}
|
||||
finally:
|
||||
if ws:
|
||||
ws.close()
|
||||
logger.info("WebSocket connection closed")
|
||||
|
||||
async def generate_image(self, prompt: str, config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]:
|
||||
"""异步生成图像,流式返回结果"""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
def sync_generator():
|
||||
for base64_image in self.generate_image_sync(prompt, config):
|
||||
yield base64_image
|
||||
|
||||
generator = await loop.run_in_executor(None, sync_generator)
|
||||
|
||||
for base64_image in generator:
|
||||
yield {
|
||||
"status": "success",
|
||||
"image": f"data:image/png;base64,{base64_image}",
|
||||
"message": f"成功生成图片"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
yield {
|
||||
"status": "error",
|
||||
"message": f"图像生成失败: {str(e)}"
|
||||
}
|
||||
|
||||
async def process_batch(self, prompts: List[str], config: Optional[Dict] = None):
|
||||
"""批量处理多个文本提示,流式返回结果"""
|
||||
async def process_batch(self, prompts: List[str], config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]:
|
||||
"""批量处理多个提示词"""
|
||||
total = len(prompts)
|
||||
success_count = 0
|
||||
error_count = 0
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user