import json
import base64
import requests
import random
import websocket
import uuid
import urllib.request
import asyncio
import io
from typing import Dict, List, Generator, Optional, AsyncGenerator

# Fixed default configuration; callers may override any key per request.
DEFAULT_CONFIG = {
    "comfyui_server_address": "192.168.2.200:8188",
    "ckpt_name": "flux1-schnell-fp8.safetensors",
    "sampler_name": "euler",
    "scheduler": "normal",
    "steps": 4,
    "cfg": 1,
    "denoise": 1.0,
    "images_per_prompt": 1,
    "image_width": 1024,
    "image_height": 1024,
    "negative_prompt": (
        "blur, low quality, low resolution, artifacts, text, watermark, "
        "underexposed, bad anatomy, deformed body, extra limbs, missing limbs, "
        "noisy background, cluttered background, blurry background"
    ),
}

# Base workflow JSON template. Placeholders, in order: cfg, denoise,
# sampler_name, scheduler, steps, ckpt_name, height, width, negative prompt.
WORKFLOW_TEMPLATE = """
{
    "3": {
        "class_type": "KSampler",
        "inputs": {
            "cfg": %d,
            "denoise": %d,
            "latent_image": [ "5", 0 ],
            "model": [ "4", 0 ],
            "negative": [ "7", 0 ],
            "positive": [ "6", 0 ],
            "sampler_name": "%s",
            "scheduler": "%s",
            "seed": 8566257,
            "steps": %d
        }
    },
    "4": {
        "class_type": "CheckpointLoaderSimple",
        "inputs": {
            "ckpt_name": "%s"
        }
    },
    "5": {
        "class_type": "EmptyLatentImage",
        "inputs": {
            "batch_size": 1,
            "height": %d,
            "width": %d
        }
    },
    "6": {
        "class_type": "CLIPTextEncode",
        "inputs": {
            "clip": [ "4", 1 ],
            "text": "masterpiece best quality girl"
        }
    },
    "7": {
        "class_type": "CLIPTextEncode",
        "inputs": {
            "clip": [ "4", 1 ],
            "text": "%s"
        }
    },
    "8": {
        "class_type": "VAEDecode",
        "inputs": {
            "samples": [ "3", 0 ],
            "vae": [ "4", 2 ]
        }
    },
    "save_image_websocket_node": {
        "class_type": "SaveImageWebsocket",
        "inputs": {
            "images": [ "8", 0 ]
        }
    }
}
"""


class TxtImgService:
    """Text-to-image service that drives a ComfyUI server over HTTP and websocket."""

    def __init__(self):
        """Initialise the text-to-image service (no state is kept)."""
        pass

    def queue_prompt(self, prompt: Dict, comfyui_server_address: str, client_id: str) -> Dict:
        """Submit a workflow to the ComfyUI server's prompt queue.

        Args:
            prompt: Workflow graph to execute.
            comfyui_server_address: ``host:port`` of the ComfyUI server.
            client_id: Client identifier sent along with the prompt.

        Returns:
            The decoded JSON body of the server's response.
        """
        payload = json.dumps({"prompt": prompt, "client_id": client_id}).encode('utf-8')
        req = urllib.request.Request(f"http://{comfyui_server_address}/prompt", data=payload)
        # Close the HTTP response deterministically instead of leaking it.
        with urllib.request.urlopen(req) as resp:
            return json.loads(resp.read())

    def get_images(self, ws: websocket.WebSocket, workflow: Dict,
                   comfyui_server_address: str, client_id: str) -> Dict:
        """Queue ``workflow`` and collect the images streamed back over ``ws``.

        Args:
            ws: An already connected websocket for ``client_id``.
            workflow: Workflow graph to execute.
            comfyui_server_address: ``host:port`` of the ComfyUI server.
            client_id: Client identifier used for the websocket and the prompt.

        Returns:
            Mapping of node id to the list of raw image byte strings received
            while that node was executing; ``{}`` when the queue response has
            no ``prompt_id``.

        Raises:
            RuntimeError: If the server reports an ``execution_error`` for
                this prompt.
        """
        try:
            prompt_response = self.queue_prompt(workflow, comfyui_server_address, client_id)
            prompt_id = prompt_response['prompt_id']
        except KeyError:
            return {}

        output_images: Dict = {}
        current_node = ""
        while True:
            out = ws.recv()
            if isinstance(out, str):
                message = json.loads(out)
                msg_type = message.get('type')
                data = message.get('data') or {}
                if msg_type == 'executing':
                    if data.get('prompt_id') == prompt_id:
                        if data['node'] is None:
                            # A null node marks the end of this prompt's execution.
                            break
                        current_node = data['node']
                elif msg_type == 'execution_error' and data.get('prompt_id') == prompt_id:
                    # Without this, a failed prompt would never send the
                    # terminating message and recv() would block forever.
                    raise RuntimeError(
                        f"ComfyUI execution error: {data.get('exception_message', 'unknown error')}"
                    )
            elif current_node == 'save_image_websocket_node':
                # Binary frame: drop the first 8 bytes (presumably the frame
                # header ComfyUI prepends) and keep the image payload.
                output_images.setdefault(current_node, []).append(out[8:])
        return output_images

    @staticmethod
    def _build_workflow(cfg: Dict, prompt: str) -> Dict:
        """Build a workflow dict from ``cfg`` and ``prompt`` with a fresh random seed."""
        # Parse the template with dummy values and assign the real ones on the
        # dict: ``%d`` would truncate float cfg/denoise values, and ``%s``
        # would splice unescaped text into JSON (breaking on quotes or
        # backslashes in the prompt text).
        workflow = json.loads(WORKFLOW_TEMPLATE % (0, 0, "", "", 0, "", 0, 0, ""))

        sampler_inputs = workflow["3"]["inputs"]
        sampler_inputs["cfg"] = float(cfg['cfg'])
        sampler_inputs["denoise"] = float(cfg['denoise'])
        sampler_inputs["sampler_name"] = str(cfg['sampler_name'])
        sampler_inputs["scheduler"] = str(cfg['scheduler'])
        sampler_inputs["steps"] = int(cfg['steps'])
        sampler_inputs["seed"] = random.randint(1, 4294967295)

        workflow["4"]["inputs"]["ckpt_name"] = str(cfg['ckpt_name'])
        workflow["5"]["inputs"]["height"] = int(cfg['image_height'])
        workflow["5"]["inputs"]["width"] = int(cfg['image_width'])
        workflow["6"]["inputs"]["text"] = prompt
        workflow["7"]["inputs"]["text"] = str(cfg['negative_prompt'])
        return workflow

    def generate_image_sync(self, prompt: str,
                            config: Optional[Dict] = None) -> Generator[str, None, None]:
        """Generate images for ``prompt`` and yield each one base64-encoded.

        Args:
            prompt: Positive prompt text.
            config: Optional overrides merged on top of ``DEFAULT_CONFIG``.

        Yields:
            Base64 text of each image received from the server.
        """
        cfg = DEFAULT_CONFIG.copy()
        if config:
            cfg.update(config)

        server = cfg['comfyui_server_address']
        images_count = int(cfg.get('images_per_prompt', 1))
        client_id = str(uuid.uuid4())
        ws = websocket.WebSocket()
        try:
            ws.connect(f"ws://{server}/ws?clientId={client_id}")
            for _ in range(images_count):
                workflow = self._build_workflow(cfg, prompt)
                images_dict = self.get_images(ws, workflow, server, client_id)
                for image_list in images_dict.values():
                    for image_data in image_list:
                        yield base64.b64encode(image_data).decode('utf-8')
        finally:
            ws.close()

    async def generate_image(self, prompt: str,
                             config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]:
        """Asynchronously generate images, streaming one result dict per image.

        Args:
            prompt: Positive prompt text.
            config: Optional overrides merged on top of ``DEFAULT_CONFIG``.

        Yields:
            ``{"status": "success", "image": <data URI>, "message": ...}`` per
            image, or a single ``{"status": "error", "message": ...}`` if
            generation fails.
        """
        sync_gen = self.generate_image_sync(prompt, config)
        # StopIteration cannot be propagated through a Future, so exhaustion
        # is signalled with a sentinel default passed to next().
        exhausted = object()
        try:
            loop = asyncio.get_running_loop()
            while True:
                # Run each blocking step (websocket/HTTP I/O) in the default
                # executor so the event loop is not stalled while waiting.
                base64_image = await loop.run_in_executor(None, next, sync_gen, exhausted)
                if base64_image is exhausted:
                    break
                yield {
                    "status": "success",
                    "image": f"data:image/png;base64,{base64_image}",
                    "message": "成功生成图片"
                }
        except Exception as e:
            yield {
                "status": "error",
                "message": f"图像生成失败: {str(e)}"
            }
        finally:
            try:
                # Releases the websocket if the consumer stopped early.
                sync_gen.close()
            except ValueError:
                # Generator is still running in a worker thread (we were
                # cancelled mid-step); its own ``finally`` closes the socket
                # once that step ends and the generator is collected.
                pass

    async def process_batch(self, prompts: List[str],
                            config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]:
        """Process several prompts in order, streaming a result dict per outcome.

        Args:
            prompts: Prompt texts to process sequentially.
            config: Optional overrides forwarded to ``generate_image``.

        Yields:
            A dict per generated image or failure, carrying the 1-based prompt
            ``index``, ``total``, the ``original_prompt``, ``status`` and the
            running ``success_count`` / ``error_count``.
        """
        total = len(prompts)
        success_count = 0
        error_count = 0

        for i, prompt in enumerate(prompts, 1):
            try:
                async for result in self.generate_image(prompt, config):
                    if result["status"] == "success":
                        success_count += 1
                        yield {
                            "index": i,
                            "total": total,
                            "original_prompt": prompt,
                            "status": "success",
                            "image_content": result["image"],
                            "success_count": success_count,
                            "error_count": error_count,
                            "message": result["message"]
                        }
                    else:
                        error_count += 1
                        yield {
                            "index": i,
                            "total": total,
                            "original_prompt": prompt,
                            "status": "error",
                            "success_count": success_count,
                            "error_count": error_count,
                            "message": result["message"]
                        }
            except Exception as e:
                error_count += 1
                yield {
                    "index": i,
                    "total": total,
                    "original_prompt": prompt,
                    "status": "error",
                    "error": str(e),
                    "success_count": success_count,
                    "error_count": error_count,
                    "message": f"处理失败: {str(e)}"
                }
            # Give other tasks a chance to run between prompts.
            await asyncio.sleep(0)