rmbg同一批次内改为流式返回
This commit is contained in:
parent
565115369d
commit
36f299403f
@ -607,79 +607,90 @@ class RmbgService:
|
|||||||
local_batch_items = valid_items[local_batch_start:local_batch_end]
|
local_batch_items = valid_items[local_batch_start:local_batch_end]
|
||||||
|
|
||||||
# 本地批次内并发调用 process_image,让全局队列有机会凑大 batch 并利用多 GPU
|
# 本地批次内并发调用 process_image,让全局队列有机会凑大 batch 并利用多 GPU
|
||||||
tasks = [
|
# 改为按完成顺序流式返回,避免等待最慢的那张
|
||||||
self.process_image(image)
|
tasks = []
|
||||||
for image, _, _, _ in local_batch_items
|
task_meta = {}
|
||||||
]
|
for image, _, index, url_str in local_batch_items:
|
||||||
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
|
t = asyncio.create_task(self.process_image(image))
|
||||||
|
tasks.append(t)
|
||||||
|
task_meta[t] = (index, url_str)
|
||||||
|
|
||||||
|
# 使用 wait 循环而不是 as_completed,避免未等待的协程残留
|
||||||
|
pending_tasks = set(tasks)
|
||||||
|
while pending_tasks:
|
||||||
|
done, pending_tasks = await asyncio.wait(
|
||||||
|
pending_tasks, return_when=asyncio.FIRST_COMPLETED
|
||||||
|
)
|
||||||
|
for t in done:
|
||||||
|
index, url_str = task_meta[t]
|
||||||
|
try:
|
||||||
|
result_data = await t
|
||||||
|
except Exception as e:
|
||||||
|
error_count += 1
|
||||||
|
completed_order += 1
|
||||||
|
result = {
|
||||||
|
"index": index,
|
||||||
|
"total": total,
|
||||||
|
"original_url": url_str,
|
||||||
|
"status": "error",
|
||||||
|
"error": str(e),
|
||||||
|
"message": f"处理失败: {str(e)}",
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count,
|
||||||
|
"completed_order": completed_order,
|
||||||
|
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
||||||
|
}
|
||||||
|
yield result
|
||||||
|
continue
|
||||||
|
|
||||||
|
if isinstance(result_data, dict):
|
||||||
|
# 队列模式下 process_image 通常直接返回 {status, image_url}
|
||||||
|
status = result_data.get("status", "success")
|
||||||
|
image_url = result_data.get("image_url")
|
||||||
|
message = result_data.get("message", "处理成功" if status == "success" else "处理失败")
|
||||||
|
error_msg = result_data.get("error")
|
||||||
|
else:
|
||||||
|
# 兼容非 dict 返回:手动保存图片
|
||||||
|
image_url = await loop.run_in_executor(
|
||||||
|
self.executor, self.save_image_to_file, result_data
|
||||||
|
)
|
||||||
|
status = "success"
|
||||||
|
message = "处理成功"
|
||||||
|
error_msg = None
|
||||||
|
|
||||||
for (image, _, index, url_str), result_data in zip(local_batch_items, batch_results):
|
|
||||||
if isinstance(result_data, Exception):
|
|
||||||
error_count += 1
|
|
||||||
completed_order += 1
|
completed_order += 1
|
||||||
result = {
|
if status == "success" and image_url:
|
||||||
"index": index,
|
success_count += 1
|
||||||
"total": total,
|
result = {
|
||||||
"original_url": url_str,
|
"index": index,
|
||||||
"status": "error",
|
"total": total,
|
||||||
"error": str(result_data),
|
"original_url": url_str,
|
||||||
"message": f"处理失败: {str(result_data)}",
|
"status": "success",
|
||||||
"success_count": success_count,
|
"image_url": image_url,
|
||||||
"error_count": error_count,
|
"message": message,
|
||||||
"completed_order": completed_order,
|
"success_count": success_count,
|
||||||
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
"error_count": error_count,
|
||||||
}
|
"completed_order": completed_order,
|
||||||
|
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
error_count += 1
|
||||||
|
result = {
|
||||||
|
"index": index,
|
||||||
|
"total": total,
|
||||||
|
"original_url": url_str,
|
||||||
|
"status": "error",
|
||||||
|
"error": error_msg or "处理失败",
|
||||||
|
"message": message,
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count,
|
||||||
|
"completed_order": completed_order,
|
||||||
|
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
||||||
|
}
|
||||||
yield result
|
yield result
|
||||||
continue
|
|
||||||
|
|
||||||
if isinstance(result_data, dict):
|
|
||||||
# 队列模式下 process_image 通常直接返回 {status, image_url}
|
|
||||||
status = result_data.get("status", "success")
|
|
||||||
image_url = result_data.get("image_url")
|
|
||||||
message = result_data.get("message", "处理成功" if status == "success" else "处理失败")
|
|
||||||
error_msg = result_data.get("error")
|
|
||||||
else:
|
|
||||||
# 兼容非 dict 返回:手动保存图片
|
|
||||||
image_url = await loop.run_in_executor(
|
|
||||||
self.executor, self.save_image_to_file, result_data
|
|
||||||
)
|
|
||||||
status = "success"
|
|
||||||
message = "处理成功"
|
|
||||||
error_msg = None
|
|
||||||
|
|
||||||
completed_order += 1
|
|
||||||
if status == "success" and image_url:
|
|
||||||
success_count += 1
|
|
||||||
result = {
|
|
||||||
"index": index,
|
|
||||||
"total": total,
|
|
||||||
"original_url": url_str,
|
|
||||||
"status": "success",
|
|
||||||
"image_url": image_url,
|
|
||||||
"message": message,
|
|
||||||
"success_count": success_count,
|
|
||||||
"error_count": error_count,
|
|
||||||
"completed_order": completed_order,
|
|
||||||
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
error_count += 1
|
|
||||||
result = {
|
|
||||||
"index": index,
|
|
||||||
"total": total,
|
|
||||||
"original_url": url_str,
|
|
||||||
"status": "error",
|
|
||||||
"error": error_msg or "处理失败",
|
|
||||||
"message": message,
|
|
||||||
"success_count": success_count,
|
|
||||||
"error_count": error_count,
|
|
||||||
"completed_order": completed_order,
|
|
||||||
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
|
||||||
}
|
|
||||||
yield result
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"批处理失败: {str(e)}")
|
logger.error(f"批处理失败: {str(e)}")
|
||||||
for _, _, index, url_str, _ in valid_items:
|
for _, _, index, url_str in valid_items:
|
||||||
error_count += 1
|
error_count += 1
|
||||||
completed_order += 1
|
completed_order += 1
|
||||||
result = {
|
result = {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user