From 61267eacdd52506494999e2ec04713a2650bf95e Mon Sep 17 00:00:00 2001 From: jingrow Date: Fri, 20 Jun 2025 00:18:39 +0800 Subject: [PATCH] =?UTF-8?q?=E9=87=8D=E6=9E=84jart=5Fv1=E6=94=AF=E6=8C=81sd?= =?UTF-8?q?3.5=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/jart_v1/service.py | 338 +++++++++++++++++++++++++--------------- 1 file changed, 213 insertions(+), 125 deletions(-) diff --git a/apps/jart_v1/service.py b/apps/jart_v1/service.py index a3cbb32..ff9d7d1 100644 --- a/apps/jart_v1/service.py +++ b/apps/jart_v1/service.py @@ -2,21 +2,29 @@ import json import base64 import requests import random +import time import websocket import uuid import urllib.request import asyncio -import io -from typing import Dict, List, Generator, Optional, AsyncGenerator +import logging +from typing import Dict, List, Optional, AsyncGenerator -# 固定配置变量 +# 配置日志 +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# 默认配置 DEFAULT_CONFIG = { "comfyui_server_address": "192.168.2.200:8188", - "ckpt_name": "sd3_medium_incl_clips_t5xxlfp8.safetensors", + "ckpt_name": "sd3.5_large.safetensors", + "clip_l_name": "clip_l.safetensors", + "clip_g_name": "clip_g.safetensors", + "t5_name": "t5xxl_fp16.safetensors", "sampler_name": "euler", - "scheduler": "normal", - "steps": 20, - "cfg": 8, + "scheduler": "sgm_uniform", + "steps": 30, + "cfg": 5.5, "denoise": 1.0, "images_per_prompt": 1, "image_width": 1024, @@ -27,65 +35,73 @@ DEFAULT_CONFIG = { # 定义基础工作流 JSON 模板 WORKFLOW_TEMPLATE = """ { - "3": { - "class_type": "KSampler", - "inputs": { - "cfg": %d, - "denoise": %d, - "latent_image": [ - "5", - 0 - ], - "model": [ - "4", - 0 - ], - "negative": [ - "7", - 0 - ], - "positive": [ - "6", - 0 - ], - "sampler_name": "%s", - "scheduler": "%s", - "seed": 8566257, - "steps": %d - } - }, "4": { "class_type": "CheckpointLoaderSimple", "inputs": { "ckpt_name": "%s" } }, - "5": { - "class_type": "EmptyLatentImage", + "43": { + "class_type": "TripleCLIPLoader", "inputs": { - "batch_size": 1, + "clip_name1": "%s", + "clip_name2": "%s", + "clip_name3": "%s" + } + }, + "53": { + "class_type": "EmptySD3LatentImage", + "inputs": { + "width": %d, "height": %d, - "width": %d + "batch_size": 1 } }, - "6": { + "16": { "class_type": "CLIPTextEncode", "inputs": { "clip": [ - "4", - 1 + "43", + 0 ], - "text": "masterpiece best quality girl" + "text": "" } }, - "7": { + "40": { "class_type": "CLIPTextEncode", "inputs": { "clip": [ - "4", - 1 + "43", + 0 ], - "text": "%s" + "text": "" + } + }, + "3": { + "class_type": "KSampler", + "inputs": { + "model": [ + "4", + 0 + ], + "positive": [ + "16", + 0 + ], + "negative": [ + "40", + 0 + ], + "latent_image": [ + "53", + 0 + ], + "seed": %d, + "steps": %d, + "cfg": %.2f, + "sampler_name": "%s", + "scheduler": "%s", + "denoise": %.2f } }, "8": { @@ -101,7 +117,7 @@ WORKFLOW_TEMPLATE = """ ] } }, - "save_image_websocket_node": { + "9": { "class_type": "SaveImageWebsocket", "inputs": { "images": [ @@ -114,117 +130,189 @@ WORKFLOW_TEMPLATE = """ """ class TxtImgService: - def __init__(self): + def __init__(self, config: Optional[Dict] = None): """初始化文本生成图像服务""" - pass + self.config = DEFAULT_CONFIG.copy() + if config: + self.config.update(config) def queue_prompt(self, prompt: Dict, comfyui_server_address: str, client_id: str) -> Dict: """将提示词发送到 ComfyUI 服务器的队列中""" - p = {"prompt": prompt, "client_id": client_id} - data = json.dumps(p).encode('utf-8') - req = urllib.request.Request(f"http://{comfyui_server_address}/prompt", data=data) - response = json.loads(urllib.request.urlopen(req).read()) - return response + try: + p = {"prompt": prompt, "client_id": client_id} + data = json.dumps(p).encode('utf-8') + logger.debug(f"Server address: {comfyui_server_address}") + logger.debug(f"Request data: {json.dumps(p, indent=2)}") + + headers = { + 'Content-Type': 'application/json', + 'Accept': 'application/json' + } + req = urllib.request.Request( + f"http://{comfyui_server_address}/prompt", + data=data, + headers=headers + ) + + response = urllib.request.urlopen(req) + response_data = response.read() + logger.debug(f"Response status: {response.status}") + response_json = json.loads(response_data) + logger.debug(f"Server response: {json.dumps(response_json, indent=2)}") + return response_json + + except Exception as e: + logger.error(f"Failed to queue prompt: {str(e)}") + raise def get_images(self, ws: websocket.WebSocket, workflow: Dict, comfyui_server_address: str, client_id: str) -> Dict: """从 ComfyUI 获取生成的图像""" try: + # 确保工作流中的所有节点都有正确的格式 + for node_id, node_data in workflow.items(): + if "inputs" not in node_data: + node_data["inputs"] = {} + if "class_type" not in node_data: + logger.error(f"Node {node_id} missing class_type") + raise ValueError(f"Node {node_id} missing class_type") + + logger.debug(f"Queuing prompt with workflow: {json.dumps(workflow, indent=2)}") prompt_response = self.queue_prompt(workflow, comfyui_server_address, client_id) - prompt_id = prompt_response['prompt_id'] - except KeyError: + + if not isinstance(prompt_response, dict): + logger.error(f"Invalid response type: {type(prompt_response)}") + return {} + + prompt_id = prompt_response.get('prompt_id') + if not prompt_id: + logger.error("No prompt_id in response") + return {} + + logger.debug(f"Got prompt_id: {prompt_id}") + + except Exception as e: + logger.error(f"Failed to get prompt_id: {str(e)}") return {} output_images = {} current_node = "" - while True: - out = ws.recv() - if isinstance(out, str): - message = json.loads(out) - if message['type'] == 'executing': - data = message['data'] - if data.get('prompt_id') == prompt_id: - if data['node'] is None: - break - else: - current_node = data['node'] - else: - if current_node == 'save_image_websocket_node': - images_output = output_images.get(current_node, []) - images_output.append(out[8:]) - output_images[current_node] = images_output + try: + while True: + out = ws.recv() + if isinstance(out, str): + message = json.loads(out) + logger.debug(f"Received message: {message}") + if message['type'] == 'executing': + data = message['data'] + if data.get('prompt_id') == prompt_id: + if data['node'] is None: + break + else: + current_node = data['node'] + logger.debug(f"Processing node: {current_node}") + else: + if current_node == '9': # SaveImageWebsocket节点ID + images_output = output_images.get(current_node, []) + images_output.append(out[8:]) + output_images[current_node] = images_output + logger.debug(f"Saved image for node: {current_node}") + + except Exception as e: + logger.error(f"Error in websocket communication: {str(e)}") + return {} return output_images - def generate_image_sync(self, prompt: str, config: Optional[Dict] = None) -> Generator[str, None, None]: - """生成 Flux 模型的图片,流式返回 base64 编码的图片""" - cfg = DEFAULT_CONFIG.copy() + async def generate_image(self, prompt: str, config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]: + """异步生成图像""" + cfg = self.config.copy() if config: cfg.update(config) - + ws = websocket.WebSocket() client_id = str(uuid.uuid4()) try: ws.connect(f"ws://{cfg['comfyui_server_address']}/ws?clientId={client_id}") - images_count = int(cfg.get('images_per_prompt', 1)) + logger.info("WebSocket connected successfully") - for i in range(images_count): - workflow = json.loads(WORKFLOW_TEMPLATE % ( - cfg['cfg'], - cfg['denoise'], - cfg['sampler_name'], - cfg['scheduler'], - cfg['steps'], - cfg['ckpt_name'], - cfg['image_height'], - cfg['image_width'], - cfg['negative_prompt'] - )) + for i in range(cfg['images_per_prompt']): + logger.info(f"Processing image {i+1}/{cfg['images_per_prompt']}") - workflow["6"]["inputs"]["text"] = prompt + # 生成随机种子 seed = random.randint(1, 4294967295) - workflow["3"]["inputs"]["seed"] = seed - images_dict = self.get_images(ws, workflow, cfg['comfyui_server_address'], client_id) + try: + # 准备参数 + params = ( + cfg['ckpt_name'], + cfg['clip_l_name'], + cfg['clip_g_name'], + cfg['t5_name'], + cfg['image_width'], + cfg['image_height'], + seed, + cfg['steps'], + cfg['cfg'], + cfg['sampler_name'], + cfg['scheduler'], + cfg['denoise'] + ) + + # 格式化工作流 + workflow = json.loads(WORKFLOW_TEMPLATE % params) + + # 设置提示词 + workflow["16"]["inputs"]["text"] = prompt + workflow["40"]["inputs"]["text"] = cfg['negative_prompt'] + + # 移除空字段 + for node in workflow.values(): + if "widgets_values" in node: + del node["widgets_values"] + + # 获取生成的图像 + images = self.get_images(ws, workflow, cfg['comfyui_server_address'], client_id) + + if not images: + yield { + "status": "error", + "message": "No images generated" + } + continue + + # 处理生成的图像 + for node_id, image_list in images.items(): + for image_data in image_list: + base64_image = base64.b64encode(image_data).decode('utf-8') + yield { + "status": "success", + "image": f"data:image/png;base64,{base64_image}", + "message": f"成功生成第 {i+1} 张图片" + } + + except Exception as e: + logger.error(f"Error generating image: {str(e)}") + yield { + "status": "error", + "message": f"生成图片失败: {str(e)}" + } + + await asyncio.sleep(2) # 避免请求过于频繁 - for node_id, image_list in images_dict.items(): - for image_data in image_list: - base64_image = base64.b64encode(image_data).decode('utf-8') - yield base64_image - except Exception as e: - raise e - + logger.error(f"WebSocket connection error: {str(e)}") + yield { + "status": "error", + "message": f"WebSocket连接失败: {str(e)}" + } finally: if ws: ws.close() + logger.info("WebSocket connection closed") - async def generate_image(self, prompt: str, config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]: - """异步生成图像,流式返回结果""" - try: - loop = asyncio.get_event_loop() - - def sync_generator(): - for base64_image in self.generate_image_sync(prompt, config): - yield base64_image - - generator = await loop.run_in_executor(None, sync_generator) - - for base64_image in generator: - yield { - "status": "success", - "image": f"data:image/png;base64,{base64_image}", - "message": f"成功生成图片" - } - - except Exception as e: - yield { - "status": "error", - "message": f"图像生成失败: {str(e)}" - } - - async def process_batch(self, prompts: List[str], config: Optional[Dict] = None): - """批量处理多个文本提示,流式返回结果""" + async def process_batch(self, prompts: List[str], config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]: + """批量处理多个提示词""" total = len(prompts) success_count = 0 error_count = 0