diff --git a/apps/jart_v1/service.py b/apps/jart_v1/service.py index ff9d7d1..b7afab7 100644 --- a/apps/jart_v1/service.py +++ b/apps/jart_v1/service.py @@ -7,13 +7,8 @@ import websocket import uuid import urllib.request import asyncio -import logging from typing import Dict, List, Optional, AsyncGenerator -# 配置日志 -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - # 默认配置 DEFAULT_CONFIG = { "comfyui_server_address": "192.168.2.200:8188", @@ -131,19 +126,14 @@ WORKFLOW_TEMPLATE = """ class TxtImgService: def __init__(self, config: Optional[Dict] = None): - """初始化文本生成图像服务""" self.config = DEFAULT_CONFIG.copy() if config: self.config.update(config) def queue_prompt(self, prompt: Dict, comfyui_server_address: str, client_id: str) -> Dict: - """将提示词发送到 ComfyUI 服务器的队列中""" try: p = {"prompt": prompt, "client_id": client_id} data = json.dumps(p).encode('utf-8') - logger.debug(f"Server address: {comfyui_server_address}") - logger.debug(f"Request data: {json.dumps(p, indent=2)}") - headers = { 'Content-Type': 'application/json', 'Accept': 'application/json' @@ -153,45 +143,27 @@ class TxtImgService: data=data, headers=headers ) - response = urllib.request.urlopen(req) response_data = response.read() - logger.debug(f"Response status: {response.status}") response_json = json.loads(response_data) - logger.debug(f"Server response: {json.dumps(response_json, indent=2)}") return response_json - except Exception as e: - logger.error(f"Failed to queue prompt: {str(e)}") raise def get_images(self, ws: websocket.WebSocket, workflow: Dict, comfyui_server_address: str, client_id: str) -> Dict: - """从 ComfyUI 获取生成的图像""" try: - # 确保工作流中的所有节点都有正确的格式 for node_id, node_data in workflow.items(): if "inputs" not in node_data: node_data["inputs"] = {} if "class_type" not in node_data: - logger.error(f"Node {node_id} missing class_type") raise ValueError(f"Node {node_id} missing class_type") - - logger.debug(f"Queuing prompt with workflow: {json.dumps(workflow, indent=2)}") prompt_response = self.queue_prompt(workflow, comfyui_server_address, client_id) - if not isinstance(prompt_response, dict): - logger.error(f"Invalid response type: {type(prompt_response)}") return {} - prompt_id = prompt_response.get('prompt_id') if not prompt_id: - logger.error("No prompt_id in response") return {} - - logger.debug(f"Got prompt_id: {prompt_id}") - except Exception as e: - logger.error(f"Failed to get prompt_id: {str(e)}") return {} output_images = {} @@ -201,7 +173,6 @@ class TxtImgService: out = ws.recv() if isinstance(out, str): message = json.loads(out) - logger.debug(f"Received message: {message}") if message['type'] == 'executing': data = message['data'] if data.get('prompt_id') == prompt_id: @@ -209,41 +180,26 @@ class TxtImgService: break else: current_node = data['node'] - logger.debug(f"Processing node: {current_node}") else: - if current_node == '9': # SaveImageWebsocket节点ID + if current_node == '9': images_output = output_images.get(current_node, []) images_output.append(out[8:]) output_images[current_node] = images_output - logger.debug(f"Saved image for node: {current_node}") - except Exception as e: - logger.error(f"Error in websocket communication: {str(e)}") return {} - return output_images async def generate_image(self, prompt: str, config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]: - """异步生成图像""" cfg = self.config.copy() if config: cfg.update(config) - ws = websocket.WebSocket() client_id = str(uuid.uuid4()) - try: ws.connect(f"ws://{cfg['comfyui_server_address']}/ws?clientId={client_id}") - logger.info("WebSocket connected successfully") - for i in range(cfg['images_per_prompt']): - logger.info(f"Processing image {i+1}/{cfg['images_per_prompt']}") - - # 生成随机种子 seed = random.randint(1, 4294967295) - try: - # 准备参数 params = ( cfg['ckpt_name'], cfg['clip_l_name'], @@ -258,30 +214,19 @@ class TxtImgService: cfg['scheduler'], cfg['denoise'] ) - - # 格式化工作流 workflow = json.loads(WORKFLOW_TEMPLATE % params) - - # 设置提示词 workflow["16"]["inputs"]["text"] = prompt workflow["40"]["inputs"]["text"] = cfg['negative_prompt'] - - # 移除空字段 for node in workflow.values(): if "widgets_values" in node: del node["widgets_values"] - - # 获取生成的图像 images = self.get_images(ws, workflow, cfg['comfyui_server_address'], client_id) - if not images: yield { "status": "error", "message": "No images generated" } continue - - # 处理生成的图像 for node_id, image_list in images.items(): for image_data in image_list: base64_image = base64.b64encode(image_data).decode('utf-8') @@ -290,18 +235,13 @@ class TxtImgService: "image": f"data:image/png;base64,{base64_image}", "message": f"成功生成第 {i+1} 张图片" } - except Exception as e: - logger.error(f"Error generating image: {str(e)}") yield { "status": "error", "message": f"生成图片失败: {str(e)}" } - - await asyncio.sleep(2) # 避免请求过于频繁 - + await asyncio.sleep(2) except Exception as e: - logger.error(f"WebSocket connection error: {str(e)}") yield { "status": "error", "message": f"WebSocket连接失败: {str(e)}" @@ -309,14 +249,11 @@ class TxtImgService: finally: if ws: ws.close() - logger.info("WebSocket connection closed") async def process_batch(self, prompts: List[str], config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]: - """批量处理多个提示词""" total = len(prompts) success_count = 0 error_count = 0 - for i, prompt in enumerate(prompts, 1): try: async for result in self.generate_image(prompt, config): @@ -343,7 +280,6 @@ class TxtImgService: "error_count": error_count, "message": result["message"] } - except Exception as e: error_count += 1 yield { @@ -356,5 +292,4 @@ class TxtImgService: "error_count": error_count, "message": f"处理失败: {str(e)}" } - await asyncio.sleep(0) \ No newline at end of file