japi/apps/jart_v1/service.py
2025-05-12 02:39:56 +08:00

272 lines
8.6 KiB
Python

import json
import base64
import requests
import random
import websocket
import uuid
import urllib.request
import asyncio
import io
from typing import Dict, List, Generator, Optional, AsyncGenerator
# Fixed configuration defaults. Callers may override any key per request by
# passing a ``config`` dict, which is merged on top of a copy of this mapping.
DEFAULT_CONFIG = dict(
    # host:port of the ComfyUI server (used for both HTTP and websocket URLs)
    comfyui_server_address="192.168.2.200:8188",
    # checkpoint file name handed to the CheckpointLoaderSimple node
    ckpt_name="sd3_medium_incl_clips_t5xxlfp8.safetensors",
    # KSampler settings
    sampler_name="euler",
    scheduler="normal",
    steps=20,
    cfg=8,
    denoise=1.0,
    # number of workflows queued per prompt
    images_per_prompt=1,
    # latent image size in pixels
    image_width=1024,
    image_height=1024,
    # text fed to the negative CLIPTextEncode node
    negative_prompt=(
        "blur, low quality, low resolution, artifacts, text, watermark, "
        "underexposed, bad anatomy, deformed body, extra limbs, missing limbs, "
        "noisy background, cluttered background, blurry background"
    ),
)
# Base ComfyUI workflow, stored as a %-format JSON template.
#
# Positional placeholders, in the order they must be supplied:
#   1. cfg             (number)
#   2. denoise         (number)
#   3. sampler_name    (string)
#   4. scheduler       (string)
#   5. steps           (integer)
#   6. ckpt_name       (string)
#   7. image height    (integer)
#   8. image width     (integer)
#   9. negative prompt (string)
#
# ``cfg`` and ``denoise`` use ``%s`` rather than ``%d`` so fractional values
# such as 7.5 or 0.6 are written as-is instead of being truncated to an int.
# The positive prompt (node "6") and the seed (node "3") are only placeholders
# here; the service overwrites them on the parsed workflow.
WORKFLOW_TEMPLATE = """
{
    "3": {
        "class_type": "KSampler",
        "inputs": {
            "cfg": %s,
            "denoise": %s,
            "latent_image": ["5", 0],
            "model": ["4", 0],
            "negative": ["7", 0],
            "positive": ["6", 0],
            "sampler_name": "%s",
            "scheduler": "%s",
            "seed": 8566257,
            "steps": %d
        }
    },
    "4": {
        "class_type": "CheckpointLoaderSimple",
        "inputs": {
            "ckpt_name": "%s"
        }
    },
    "5": {
        "class_type": "EmptyLatentImage",
        "inputs": {
            "batch_size": 1,
            "height": %d,
            "width": %d
        }
    },
    "6": {
        "class_type": "CLIPTextEncode",
        "inputs": {
            "clip": ["4", 1],
            "text": "masterpiece best quality girl"
        }
    },
    "7": {
        "class_type": "CLIPTextEncode",
        "inputs": {
            "clip": ["4", 1],
            "text": "%s"
        }
    },
    "8": {
        "class_type": "VAEDecode",
        "inputs": {
            "samples": ["3", 0],
            "vae": ["4", 2]
        }
    },
    "save_image_websocket_node": {
        "class_type": "SaveImageWebsocket",
        "inputs": {
            "images": ["8", 0]
        }
    }
}
"""
class TxtImgService:
    """Text-to-image service backed by a ComfyUI server.

    A workflow is built from ``WORKFLOW_TEMPLATE`` and the merged
    configuration, queued over HTTP, and the resulting images are collected
    from the websocket connection opened for the same client id.
    """

    def __init__(self):
        """Initialise the service; no instance state is kept."""

    def queue_prompt(self, prompt: Dict, comfyui_server_address: str, client_id: str) -> Dict:
        """Submit a workflow to the ComfyUI server's ``/prompt`` endpoint.

        Args:
            prompt: The workflow graph to execute.
            comfyui_server_address: ``host:port`` of the ComfyUI server.
            client_id: Identifier sent along with the workflow; the same id
                is used for the websocket connection.

        Returns:
            The decoded JSON body of the server response.

        Raises:
            urllib.error.URLError: If the HTTP request fails.
        """
        payload = json.dumps({"prompt": prompt, "client_id": client_id}).encode('utf-8')
        request = urllib.request.Request(f"http://{comfyui_server_address}/prompt", data=payload)
        # Use the response as a context manager so the HTTP connection is
        # always released, even if decoding the body fails.
        with urllib.request.urlopen(request) as response:
            return json.loads(response.read())

    def get_images(self, ws: "websocket.WebSocket", workflow: Dict, comfyui_server_address: str, client_id: str) -> Dict:
        """Queue a workflow and collect the images it streams back.

        Args:
            ws: A websocket already connected with ``client_id``.
            workflow: The workflow graph to execute.
            comfyui_server_address: ``host:port`` of the ComfyUI server.
            client_id: Client identifier shared by the HTTP and websocket sides.

        Returns:
            A dict mapping the node id ``'save_image_websocket_node'`` to the
            list of raw image payloads received while that node was running.
            Empty if the server response carries no ``prompt_id`` or no image
            frame arrived.
        """
        prompt_response = self.queue_prompt(workflow, comfyui_server_address, client_id)
        try:
            prompt_id = prompt_response['prompt_id']
        except KeyError:
            # Without a prompt id there is nothing to wait for.
            return {}

        output_images: Dict = {}
        current_node = ""
        while True:
            out = ws.recv()
            if isinstance(out, str):
                # Text frames are JSON status messages; only 'executing'
                # events for our own prompt matter here.
                message = json.loads(out)
                if message.get('type') != 'executing':
                    continue
                data = message['data']
                if data.get('prompt_id') != prompt_id:
                    continue
                if data['node'] is None:
                    # A null node marks the end of execution for this prompt.
                    break
                current_node = data['node']
            elif current_node == 'save_image_websocket_node':
                # Binary frame: the first 8 bytes are dropped (presumably a
                # frame header - confirm against the ComfyUI protocol).
                output_images.setdefault(current_node, []).append(out[8:])
        return output_images

    @staticmethod
    def _build_workflow(cfg: Dict, prompt: str, seed: int, template: Optional[str] = None) -> Dict:
        """Render the workflow template into a ready-to-queue workflow dict.

        Args:
            cfg: Merged configuration (same keys as ``DEFAULT_CONFIG``).
            prompt: Positive prompt text, stored on node ``"6"``.
            seed: Sampler seed, stored on node ``"3"``.
            template: %-format JSON template; defaults to ``WORKFLOW_TEMPLATE``.

        Returns:
            The parsed workflow with prompt, seed, cfg and denoise applied.
        """
        if template is None:
            template = WORKFLOW_TEMPLATE

        def escaped(value) -> str:
            # The template supplies the surrounding quotes, so emit only the
            # JSON-escaped content. Without this a quote, backslash or newline
            # in a config string would produce invalid JSON.
            return json.dumps(str(value))[1:-1]

        workflow = json.loads(template % (
            cfg['cfg'],
            cfg['denoise'],
            escaped(cfg['sampler_name']),
            escaped(cfg['scheduler']),
            cfg['steps'],
            escaped(cfg['ckpt_name']),
            cfg['image_height'],
            cfg['image_width'],
            escaped(cfg['negative_prompt']),
        ))

        sampler_inputs = workflow["3"]["inputs"]
        # Assign the numeric settings on the parsed dict so fractional values
        # survive even when the template formats them with an integer
        # placeholder.
        sampler_inputs["cfg"] = cfg['cfg']
        sampler_inputs["denoise"] = cfg['denoise']
        sampler_inputs["seed"] = seed
        workflow["6"]["inputs"]["text"] = prompt
        return workflow

    def generate_image_sync(self, prompt: str, config: Optional[Dict] = None) -> Generator[str, None, None]:
        """Generate images for one prompt, yielding each as a base64 string.

        Args:
            prompt: Positive prompt text.
            config: Optional overrides merged on top of ``DEFAULT_CONFIG``.

        Yields:
            Base64-encoded image data, one string per received image.

        Raises:
            Exception: Connection, HTTP and decoding errors propagate to the
                caller; the websocket is closed in every case.
        """
        cfg = DEFAULT_CONFIG.copy()
        if config:
            cfg.update(config)

        ws = websocket.WebSocket()
        client_id = str(uuid.uuid4())
        try:
            ws.connect(f"ws://{cfg['comfyui_server_address']}/ws?clientId={client_id}")
            for _ in range(int(cfg.get('images_per_prompt', 1))):
                # A fresh random seed per run so repeated runs differ.
                seed = random.randint(1, 4294967295)
                workflow = self._build_workflow(cfg, prompt, seed)
                images_dict = self.get_images(ws, workflow, cfg['comfyui_server_address'], client_id)
                for image_list in images_dict.values():
                    for image_data in image_list:
                        yield base64.b64encode(image_data).decode('utf-8')
        finally:
            ws.close()

    async def generate_image(self, prompt: str, config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]:
        """Asynchronously generate images, streaming one result dict per image.

        The blocking synchronous generator is advanced inside the default
        executor, so network waits do not stall the event loop.

        Args:
            prompt: Positive prompt text.
            config: Optional overrides merged on top of ``DEFAULT_CONFIG``.

        Yields:
            ``{"status": "success", "image": <data URL>, "message": ...}`` for
            each image, or a single ``{"status": "error", "message": ...}``
            if generation fails.
        """
        loop = asyncio.get_running_loop()
        sync_gen = self.generate_image_sync(prompt, config)
        # Sentinel default for next(): StopIteration cannot be propagated
        # through an executor future, so exhaustion is signalled by identity.
        exhausted = object()
        try:
            while True:
                base64_image = await loop.run_in_executor(None, next, sync_gen, exhausted)
                if base64_image is exhausted:
                    break
                yield {
                    "status": "success",
                    "image": f"data:image/png;base64,{base64_image}",
                    "message": "成功生成图片"
                }
        except Exception as e:
            yield {
                "status": "error",
                "message": f"图像生成失败: {str(e)}"
            }
        finally:
            # Closing the generator runs its cleanup (closing the websocket)
            # when the consumer stops early.
            try:
                sync_gen.close()
            except ValueError:
                # Raised if a worker thread is still executing the generator
                # (e.g. after cancellation); its own cleanup runs once it is
                # finalised.
                pass

    async def process_batch(self, prompts: List[str], config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]:
        """Process several prompts in order, streaming a progress dict per result.

        Args:
            prompts: Prompt texts to process sequentially.
            config: Optional overrides forwarded to ``generate_image``.

        Yields:
            One dict per generated image or failure, carrying the 1-based
            prompt ``index``, the ``total`` prompt count, the original prompt,
            a ``status`` of ``"success"`` or ``"error"``, running
            success/error counters and a message. Successful results also
            carry ``image_content``.
        """
        total = len(prompts)
        success_count = 0
        error_count = 0
        for index, prompt in enumerate(prompts, 1):
            try:
                async for result in self.generate_image(prompt, config):
                    if result["status"] == "success":
                        success_count += 1
                        yield {
                            "index": index,
                            "total": total,
                            "original_prompt": prompt,
                            "status": "success",
                            "image_content": result["image"],
                            "success_count": success_count,
                            "error_count": error_count,
                            "message": result["message"]
                        }
                    else:
                        error_count += 1
                        yield {
                            "index": index,
                            "total": total,
                            "original_prompt": prompt,
                            "status": "error",
                            "success_count": success_count,
                            "error_count": error_count,
                            "message": result["message"]
                        }
            except Exception as e:
                error_count += 1
                yield {
                    "index": index,
                    "total": total,
                    "original_prompt": prompt,
                    "status": "error",
                    "error": str(e),
                    "success_count": success_count,
                    "error_count": error_count,
                    "message": f"处理失败: {str(e)}"
                }
            # Give other tasks a chance to run between prompts.
            await asyncio.sleep(0)