# ComfyUI text-to-image service client.
import json
|
|
import base64
|
|
import requests
|
|
import random
|
|
import websocket
|
|
import uuid
|
|
import urllib.request
|
|
import asyncio
|
|
import io
|
|
from typing import Dict, List, Generator, Optional, AsyncGenerator
|
|
|
|
# Default generation settings; callers may override any key via the
# `config` argument of TxtImgService.generate_image / generate_image_sync.
DEFAULT_CONFIG = {
    "comfyui_server_address": "192.168.2.200:8188",  # host:port of the ComfyUI backend
    "ckpt_name": "sd3_medium_incl_clips_t5xxlfp8.safetensors",  # checkpoint loaded by node "4"
    "sampler_name": "euler",
    "scheduler": "normal",
    "steps": 20,           # sampling steps for the KSampler node
    "cfg": 8,              # classifier-free guidance scale
    "denoise": 1.0,        # 1.0 = full txt2img denoise
    "images_per_prompt": 1,
    "image_width": 1024,
    "image_height": 1024,
    "negative_prompt": "blur, low quality, low resolution, artifacts, text, watermark, underexposed, bad anatomy, deformed body, extra limbs, missing limbs, noisy background, cluttered background, blurry background"
}
|
|
|
|
# Base workflow JSON template for a simple txt2img pipeline:
# CheckpointLoaderSimple(4) -> CLIPTextEncode(6 pos / 7 neg) -> KSampler(3)
# -> VAEDecode(8) -> SaveImageWebsocket.
#
# %-format placeholders, in order:
#   cfg, denoise, sampler_name, scheduler, steps, ckpt_name,
#   image_height, image_width, negative_prompt
#
# NOTE: "cfg" and "denoise" use %s (not %d) so that float values such as
# denoise=0.5 are preserved — %d would silently truncate them to an int.
# Both int and float repr() forms are valid JSON numbers.
WORKFLOW_TEMPLATE = """
{
    "3": {
        "class_type": "KSampler",
        "inputs": {
            "cfg": %s,
            "denoise": %s,
            "latent_image": [
                "5",
                0
            ],
            "model": [
                "4",
                0
            ],
            "negative": [
                "7",
                0
            ],
            "positive": [
                "6",
                0
            ],
            "sampler_name": "%s",
            "scheduler": "%s",
            "seed": 8566257,
            "steps": %d
        }
    },
    "4": {
        "class_type": "CheckpointLoaderSimple",
        "inputs": {
            "ckpt_name": "%s"
        }
    },
    "5": {
        "class_type": "EmptyLatentImage",
        "inputs": {
            "batch_size": 1,
            "height": %d,
            "width": %d
        }
    },
    "6": {
        "class_type": "CLIPTextEncode",
        "inputs": {
            "clip": [
                "4",
                1
            ],
            "text": "masterpiece best quality girl"
        }
    },
    "7": {
        "class_type": "CLIPTextEncode",
        "inputs": {
            "clip": [
                "4",
                1
            ],
            "text": "%s"
        }
    },
    "8": {
        "class_type": "VAEDecode",
        "inputs": {
            "samples": [
                "3",
                0
            ],
            "vae": [
                "4",
                2
            ]
        }
    },
    "save_image_websocket_node": {
        "class_type": "SaveImageWebsocket",
        "inputs": {
            "images": [
                "8",
                0
            ]
        }
    }
}
"""
|
|
|
|
class TxtImgService:
    """Text-to-image service backed by a ComfyUI server.

    Queues a txt2img workflow over HTTP, listens on the server's websocket
    for execution progress, and streams each generated image back to the
    caller as a base64-encoded string.
    """

    def __init__(self):
        """Initialize the service (stateless — nothing to configure here)."""
        pass

    def queue_prompt(self, prompt: Dict, comfyui_server_address: str, client_id: str) -> Dict:
        """POST a workflow to the ComfyUI ``/prompt`` endpoint.

        Args:
            prompt: The workflow graph (node-id -> node description).
            comfyui_server_address: ``host:port`` of the ComfyUI server.
            client_id: Client UUID; ties websocket messages to this request.

        Returns:
            The server's JSON response (contains ``prompt_id`` on success).

        Raises:
            urllib.error.URLError: If the server is unreachable.
        """
        payload = {"prompt": prompt, "client_id": client_id}
        body = json.dumps(payload).encode('utf-8')
        request = urllib.request.Request(f"http://{comfyui_server_address}/prompt", data=body)
        return json.loads(urllib.request.urlopen(request).read())

    def get_images(self, ws: websocket.WebSocket, workflow: Dict, comfyui_server_address: str, client_id: str) -> Dict:
        """Queue *workflow* and collect the raw image bytes it produces.

        Listens on *ws* until the server reports the prompt finished
        (``executing`` message with ``node is None``). Binary frames received
        while the ``save_image_websocket_node`` is active carry image data.

        Returns:
            Mapping of node id -> list of raw image byte strings; empty dict
            if the server response contained no ``prompt_id``.
        """
        try:
            prompt_id = self.queue_prompt(workflow, comfyui_server_address, client_id)['prompt_id']
        except KeyError:
            # Server rejected the workflow — nothing queued, nothing to wait for.
            return {}

        output_images: Dict[str, List[bytes]] = {}
        current_node = ""
        while True:
            frame = ws.recv()
            if isinstance(frame, str):
                # Text frames are JSON status messages.
                message = json.loads(frame)
                if message['type'] == 'executing':
                    data = message['data']
                    if data.get('prompt_id') == prompt_id:
                        if data['node'] is None:
                            break  # node == None signals execution finished
                        current_node = data['node']
            elif current_node == 'save_image_websocket_node':
                # Binary frame: first 8 bytes are a ComfyUI header, rest is PNG data.
                output_images.setdefault(current_node, []).append(frame[8:])

        return output_images

    def generate_image_sync(self, prompt: str, config: Optional[Dict] = None) -> Generator[str, None, None]:
        """Synchronously generate images for *prompt*, yielding base64 strings.

        Args:
            prompt: Positive text prompt.
            config: Optional overrides for DEFAULT_CONFIG keys.

        Yields:
            One base64-encoded PNG string per generated image.
        """
        cfg = DEFAULT_CONFIG.copy()
        if config:
            cfg.update(config)

        ws = websocket.WebSocket()
        client_id = str(uuid.uuid4())

        try:
            ws.connect(f"ws://{cfg['comfyui_server_address']}/ws?clientId={client_id}")

            for _ in range(int(cfg.get('images_per_prompt', 1))):
                # The negative prompt is passed as an empty placeholder and
                # assigned on the parsed dict below: interpolating free text
                # into the JSON template would break on quotes/backslashes.
                workflow = json.loads(WORKFLOW_TEMPLATE % (
                    cfg['cfg'],
                    cfg['denoise'],
                    cfg['sampler_name'],
                    cfg['scheduler'],
                    cfg['steps'],
                    cfg['ckpt_name'],
                    cfg['image_height'],
                    cfg['image_width'],
                    ""
                ))

                workflow["6"]["inputs"]["text"] = prompt
                workflow["7"]["inputs"]["text"] = cfg['negative_prompt']
                # Fresh random seed per image so repeated images differ.
                workflow["3"]["inputs"]["seed"] = random.randint(1, 4294967295)

                images_dict = self.get_images(ws, workflow, cfg['comfyui_server_address'], client_id)

                for image_list in images_dict.values():
                    for image_data in image_list:
                        yield base64.b64encode(image_data).decode('utf-8')
        finally:
            ws.close()

    async def generate_image(self, prompt: str, config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]:
        """Asynchronously generate images, streaming result dicts.

        Each blocking pull from the synchronous generator is dispatched to the
        default executor so the event loop is never blocked by websocket I/O.
        (The previous implementation only *created* the generator in the
        executor and then iterated it on the loop thread.)

        Yields:
            ``{"status": "success", "image": <data URI>, "message": ...}`` per
            image, or a single ``{"status": "error", ...}`` dict on failure.
        """
        try:
            loop = asyncio.get_event_loop()
            iterator = iter(self.generate_image_sync(prompt, config))
            # Sentinel avoids StopIteration propagating through the future.
            done = object()

            while True:
                base64_image = await loop.run_in_executor(None, next, iterator, done)
                if base64_image is done:
                    break
                yield {
                    "status": "success",
                    "image": f"data:image/png;base64,{base64_image}",
                    "message": "成功生成图片"
                }

        except Exception as e:
            yield {
                "status": "error",
                "message": f"图像生成失败: {str(e)}"
            }

    async def process_batch(self, prompts: List[str], config: Optional[Dict] = None):
        """Process multiple prompts, streaming a progress dict per result.

        Args:
            prompts: Positive text prompts to render, in order.
            config: Optional overrides for DEFAULT_CONFIG keys (shared by all).

        Yields:
            Progress dicts carrying the 1-based index, running success/error
            counts, and either the image data URI or an error message.
        """
        total = len(prompts)
        success_count = 0
        error_count = 0

        for index, prompt in enumerate(prompts, 1):
            try:
                async for result in self.generate_image(prompt, config):
                    if result["status"] == "success":
                        success_count += 1
                        yield {
                            "index": index,
                            "total": total,
                            "original_prompt": prompt,
                            "status": "success",
                            "image_content": result["image"],
                            "success_count": success_count,
                            "error_count": error_count,
                            "message": result["message"]
                        }
                    else:
                        error_count += 1
                        yield {
                            "index": index,
                            "total": total,
                            "original_prompt": prompt,
                            "status": "error",
                            "success_count": success_count,
                            "error_count": error_count,
                            "message": result["message"]
                        }

            except Exception as e:
                error_count += 1
                yield {
                    "index": index,
                    "total": total,
                    "original_prompt": prompt,
                    "status": "error",
                    "error": str(e),
                    "success_count": success_count,
                    "error_count": error_count,
                    "message": f"处理失败: {str(e)}"
                }

            # Yield control to the event loop between prompts.
            await asyncio.sleep(0)