清除jart_v1调试日志
This commit is contained in:
parent
61267eacdd
commit
e09aa09f14
@ -7,13 +7,8 @@ import websocket
|
|||||||
import uuid
|
import uuid
|
||||||
import urllib.request
|
import urllib.request
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
|
||||||
from typing import Dict, List, Optional, AsyncGenerator
|
from typing import Dict, List, Optional, AsyncGenerator
|
||||||
|
|
||||||
# 配置日志
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# 默认配置
|
# 默认配置
|
||||||
DEFAULT_CONFIG = {
|
DEFAULT_CONFIG = {
|
||||||
"comfyui_server_address": "192.168.2.200:8188",
|
"comfyui_server_address": "192.168.2.200:8188",
|
||||||
@ -131,19 +126,14 @@ WORKFLOW_TEMPLATE = """
|
|||||||
|
|
||||||
class TxtImgService:
|
class TxtImgService:
|
||||||
def __init__(self, config: Optional[Dict] = None):
|
def __init__(self, config: Optional[Dict] = None):
|
||||||
"""初始化文本生成图像服务"""
|
|
||||||
self.config = DEFAULT_CONFIG.copy()
|
self.config = DEFAULT_CONFIG.copy()
|
||||||
if config:
|
if config:
|
||||||
self.config.update(config)
|
self.config.update(config)
|
||||||
|
|
||||||
def queue_prompt(self, prompt: Dict, comfyui_server_address: str, client_id: str) -> Dict:
|
def queue_prompt(self, prompt: Dict, comfyui_server_address: str, client_id: str) -> Dict:
|
||||||
"""将提示词发送到 ComfyUI 服务器的队列中"""
|
|
||||||
try:
|
try:
|
||||||
p = {"prompt": prompt, "client_id": client_id}
|
p = {"prompt": prompt, "client_id": client_id}
|
||||||
data = json.dumps(p).encode('utf-8')
|
data = json.dumps(p).encode('utf-8')
|
||||||
logger.debug(f"Server address: {comfyui_server_address}")
|
|
||||||
logger.debug(f"Request data: {json.dumps(p, indent=2)}")
|
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
'Accept': 'application/json'
|
'Accept': 'application/json'
|
||||||
@ -153,45 +143,27 @@ class TxtImgService:
|
|||||||
data=data,
|
data=data,
|
||||||
headers=headers
|
headers=headers
|
||||||
)
|
)
|
||||||
|
|
||||||
response = urllib.request.urlopen(req)
|
response = urllib.request.urlopen(req)
|
||||||
response_data = response.read()
|
response_data = response.read()
|
||||||
logger.debug(f"Response status: {response.status}")
|
|
||||||
response_json = json.loads(response_data)
|
response_json = json.loads(response_data)
|
||||||
logger.debug(f"Server response: {json.dumps(response_json, indent=2)}")
|
|
||||||
return response_json
|
return response_json
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to queue prompt: {str(e)}")
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def get_images(self, ws: websocket.WebSocket, workflow: Dict, comfyui_server_address: str, client_id: str) -> Dict:
|
def get_images(self, ws: websocket.WebSocket, workflow: Dict, comfyui_server_address: str, client_id: str) -> Dict:
|
||||||
"""从 ComfyUI 获取生成的图像"""
|
|
||||||
try:
|
try:
|
||||||
# 确保工作流中的所有节点都有正确的格式
|
|
||||||
for node_id, node_data in workflow.items():
|
for node_id, node_data in workflow.items():
|
||||||
if "inputs" not in node_data:
|
if "inputs" not in node_data:
|
||||||
node_data["inputs"] = {}
|
node_data["inputs"] = {}
|
||||||
if "class_type" not in node_data:
|
if "class_type" not in node_data:
|
||||||
logger.error(f"Node {node_id} missing class_type")
|
|
||||||
raise ValueError(f"Node {node_id} missing class_type")
|
raise ValueError(f"Node {node_id} missing class_type")
|
||||||
|
|
||||||
logger.debug(f"Queuing prompt with workflow: {json.dumps(workflow, indent=2)}")
|
|
||||||
prompt_response = self.queue_prompt(workflow, comfyui_server_address, client_id)
|
prompt_response = self.queue_prompt(workflow, comfyui_server_address, client_id)
|
||||||
|
|
||||||
if not isinstance(prompt_response, dict):
|
if not isinstance(prompt_response, dict):
|
||||||
logger.error(f"Invalid response type: {type(prompt_response)}")
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
prompt_id = prompt_response.get('prompt_id')
|
prompt_id = prompt_response.get('prompt_id')
|
||||||
if not prompt_id:
|
if not prompt_id:
|
||||||
logger.error("No prompt_id in response")
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
logger.debug(f"Got prompt_id: {prompt_id}")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get prompt_id: {str(e)}")
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
output_images = {}
|
output_images = {}
|
||||||
@ -201,7 +173,6 @@ class TxtImgService:
|
|||||||
out = ws.recv()
|
out = ws.recv()
|
||||||
if isinstance(out, str):
|
if isinstance(out, str):
|
||||||
message = json.loads(out)
|
message = json.loads(out)
|
||||||
logger.debug(f"Received message: {message}")
|
|
||||||
if message['type'] == 'executing':
|
if message['type'] == 'executing':
|
||||||
data = message['data']
|
data = message['data']
|
||||||
if data.get('prompt_id') == prompt_id:
|
if data.get('prompt_id') == prompt_id:
|
||||||
@ -209,41 +180,26 @@ class TxtImgService:
|
|||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
current_node = data['node']
|
current_node = data['node']
|
||||||
logger.debug(f"Processing node: {current_node}")
|
|
||||||
else:
|
else:
|
||||||
if current_node == '9': # SaveImageWebsocket节点ID
|
if current_node == '9':
|
||||||
images_output = output_images.get(current_node, [])
|
images_output = output_images.get(current_node, [])
|
||||||
images_output.append(out[8:])
|
images_output.append(out[8:])
|
||||||
output_images[current_node] = images_output
|
output_images[current_node] = images_output
|
||||||
logger.debug(f"Saved image for node: {current_node}")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in websocket communication: {str(e)}")
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
return output_images
|
return output_images
|
||||||
|
|
||||||
async def generate_image(self, prompt: str, config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]:
|
async def generate_image(self, prompt: str, config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]:
|
||||||
"""异步生成图像"""
|
|
||||||
cfg = self.config.copy()
|
cfg = self.config.copy()
|
||||||
if config:
|
if config:
|
||||||
cfg.update(config)
|
cfg.update(config)
|
||||||
|
|
||||||
ws = websocket.WebSocket()
|
ws = websocket.WebSocket()
|
||||||
client_id = str(uuid.uuid4())
|
client_id = str(uuid.uuid4())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ws.connect(f"ws://{cfg['comfyui_server_address']}/ws?clientId={client_id}")
|
ws.connect(f"ws://{cfg['comfyui_server_address']}/ws?clientId={client_id}")
|
||||||
logger.info("WebSocket connected successfully")
|
|
||||||
|
|
||||||
for i in range(cfg['images_per_prompt']):
|
for i in range(cfg['images_per_prompt']):
|
||||||
logger.info(f"Processing image {i+1}/{cfg['images_per_prompt']}")
|
|
||||||
|
|
||||||
# 生成随机种子
|
|
||||||
seed = random.randint(1, 4294967295)
|
seed = random.randint(1, 4294967295)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 准备参数
|
|
||||||
params = (
|
params = (
|
||||||
cfg['ckpt_name'],
|
cfg['ckpt_name'],
|
||||||
cfg['clip_l_name'],
|
cfg['clip_l_name'],
|
||||||
@ -258,30 +214,19 @@ class TxtImgService:
|
|||||||
cfg['scheduler'],
|
cfg['scheduler'],
|
||||||
cfg['denoise']
|
cfg['denoise']
|
||||||
)
|
)
|
||||||
|
|
||||||
# 格式化工作流
|
|
||||||
workflow = json.loads(WORKFLOW_TEMPLATE % params)
|
workflow = json.loads(WORKFLOW_TEMPLATE % params)
|
||||||
|
|
||||||
# 设置提示词
|
|
||||||
workflow["16"]["inputs"]["text"] = prompt
|
workflow["16"]["inputs"]["text"] = prompt
|
||||||
workflow["40"]["inputs"]["text"] = cfg['negative_prompt']
|
workflow["40"]["inputs"]["text"] = cfg['negative_prompt']
|
||||||
|
|
||||||
# 移除空字段
|
|
||||||
for node in workflow.values():
|
for node in workflow.values():
|
||||||
if "widgets_values" in node:
|
if "widgets_values" in node:
|
||||||
del node["widgets_values"]
|
del node["widgets_values"]
|
||||||
|
|
||||||
# 获取生成的图像
|
|
||||||
images = self.get_images(ws, workflow, cfg['comfyui_server_address'], client_id)
|
images = self.get_images(ws, workflow, cfg['comfyui_server_address'], client_id)
|
||||||
|
|
||||||
if not images:
|
if not images:
|
||||||
yield {
|
yield {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": "No images generated"
|
"message": "No images generated"
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 处理生成的图像
|
|
||||||
for node_id, image_list in images.items():
|
for node_id, image_list in images.items():
|
||||||
for image_data in image_list:
|
for image_data in image_list:
|
||||||
base64_image = base64.b64encode(image_data).decode('utf-8')
|
base64_image = base64.b64encode(image_data).decode('utf-8')
|
||||||
@ -290,18 +235,13 @@ class TxtImgService:
|
|||||||
"image": f"data:image/png;base64,{base64_image}",
|
"image": f"data:image/png;base64,{base64_image}",
|
||||||
"message": f"成功生成第 {i+1} 张图片"
|
"message": f"成功生成第 {i+1} 张图片"
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error generating image: {str(e)}")
|
|
||||||
yield {
|
yield {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": f"生成图片失败: {str(e)}"
|
"message": f"生成图片失败: {str(e)}"
|
||||||
}
|
}
|
||||||
|
await asyncio.sleep(2)
|
||||||
await asyncio.sleep(2) # 避免请求过于频繁
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"WebSocket connection error: {str(e)}")
|
|
||||||
yield {
|
yield {
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"message": f"WebSocket连接失败: {str(e)}"
|
"message": f"WebSocket连接失败: {str(e)}"
|
||||||
@ -309,14 +249,11 @@ class TxtImgService:
|
|||||||
finally:
|
finally:
|
||||||
if ws:
|
if ws:
|
||||||
ws.close()
|
ws.close()
|
||||||
logger.info("WebSocket connection closed")
|
|
||||||
|
|
||||||
async def process_batch(self, prompts: List[str], config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]:
|
async def process_batch(self, prompts: List[str], config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]:
|
||||||
"""批量处理多个提示词"""
|
|
||||||
total = len(prompts)
|
total = len(prompts)
|
||||||
success_count = 0
|
success_count = 0
|
||||||
error_count = 0
|
error_count = 0
|
||||||
|
|
||||||
for i, prompt in enumerate(prompts, 1):
|
for i, prompt in enumerate(prompts, 1):
|
||||||
try:
|
try:
|
||||||
async for result in self.generate_image(prompt, config):
|
async for result in self.generate_image(prompt, config):
|
||||||
@ -343,7 +280,6 @@ class TxtImgService:
|
|||||||
"error_count": error_count,
|
"error_count": error_count,
|
||||||
"message": result["message"]
|
"message": result["message"]
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_count += 1
|
error_count += 1
|
||||||
yield {
|
yield {
|
||||||
@ -356,5 +292,4 @@ class TxtImgService:
|
|||||||
"error_count": error_count,
|
"error_count": error_count,
|
||||||
"message": f"处理失败: {str(e)}"
|
"message": f"处理失败: {str(e)}"
|
||||||
}
|
}
|
||||||
|
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
Loading…
x
Reference in New Issue
Block a user