优化批处理为全局队列合批以充分利用多 GPU
- 重构 process_batch,将批处理入口改为逐张通过全局队列的 process_image
- 在本地批次内并发调用 process_image,让全局队列能凑大 batch 并触发多 GPU 并行
- 保留原有流式返回结构和统计字段,对外接口兼容不变
This commit is contained in:
parent
530e7c8961
commit
5a06746fcb
@ -558,10 +558,9 @@ class RmbgService:
|
|||||||
completed_order = 0
|
completed_order = 0
|
||||||
pending_batch = []
|
pending_batch = []
|
||||||
batch_collect_timeout = 0.5 # 批处理收集超时(秒)
|
batch_collect_timeout = 0.5 # 批处理收集超时(秒)
|
||||||
max_single_batch = batch_size * 2 # 允许最多2倍batch_size用于一次性处理
|
|
||||||
|
|
||||||
async def process_pending_batch(force=False):
|
async def process_pending_batch(force=False):
|
||||||
"""处理待处理的批次"""
|
"""处理待处理的批次(将下载好的图片逐张送入全局队列进行批处理)"""
|
||||||
nonlocal pending_batch, completed_order, success_count, error_count
|
nonlocal pending_batch, completed_order, success_count, error_count
|
||||||
|
|
||||||
if not pending_batch:
|
if not pending_batch:
|
||||||
@ -600,142 +599,56 @@ class RmbgService:
|
|||||||
pending_batch = []
|
pending_batch = []
|
||||||
return
|
return
|
||||||
|
|
||||||
# 处理成功的图片
|
# 处理成功的图片:通过全局队列进行批处理(每个本地批次内并发调度)
|
||||||
try:
|
try:
|
||||||
# 判断是否尝试一次性处理
|
# 为了控制单次处理时长,仍按 batch_size 对 valid_items 做分块
|
||||||
use_single_batch = len(valid_items) <= max_single_batch and force
|
for local_batch_start in range(0, len(valid_items), batch_size):
|
||||||
|
local_batch_end = min(local_batch_start + batch_size, len(valid_items))
|
||||||
if use_single_batch:
|
local_batch_items = valid_items[local_batch_start:local_batch_end]
|
||||||
# 尝试一次性处理所有图片
|
|
||||||
images_with_info = [(img, size, idx) for img, size, idx, _ in valid_items]
|
# 本地批次内并发调用 process_image,让全局队列有机会凑大 batch 并利用多 GPU
|
||||||
|
tasks = [
|
||||||
batch_results = await self.process_batch_images(images_with_info)
|
self.process_image(image)
|
||||||
|
for image, _, _, _ in local_batch_items
|
||||||
# 并行保存
|
]
|
||||||
save_tasks = []
|
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
result_mapping = {}
|
|
||||||
|
for (image, _, index, url_str), result_data in zip(local_batch_items, batch_results):
|
||||||
for processed_image, index in batch_results:
|
if isinstance(result_data, Exception):
|
||||||
url_str = next(url for _, _, idx, url in valid_items if idx == index)
|
|
||||||
result_mapping[index] = (processed_image, url_str)
|
|
||||||
save_task = loop.run_in_executor(
|
|
||||||
self.executor, self.save_image_to_file, processed_image
|
|
||||||
)
|
|
||||||
save_tasks.append((index, save_task))
|
|
||||||
|
|
||||||
save_results = await asyncio.gather(*[task for _, task in save_tasks], return_exceptions=True)
|
|
||||||
|
|
||||||
for (index, _), image_url in zip(save_tasks, save_results):
|
|
||||||
if isinstance(image_url, Exception):
|
|
||||||
error_count += 1
|
error_count += 1
|
||||||
completed_order += 1
|
completed_order += 1
|
||||||
result = {
|
result = {
|
||||||
"index": index,
|
"index": index,
|
||||||
"total": total,
|
"total": total,
|
||||||
"original_url": result_mapping[index][1],
|
"original_url": url_str,
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"error": str(image_url),
|
"error": str(result_data),
|
||||||
"message": f"保存图片失败: {str(image_url)}",
|
"message": f"处理失败: {str(result_data)}",
|
||||||
"success_count": success_count,
|
"success_count": success_count,
|
||||||
"error_count": error_count,
|
"error_count": error_count,
|
||||||
"completed_order": completed_order,
|
"completed_order": completed_order,
|
||||||
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
||||||
}
|
}
|
||||||
yield result
|
yield result
|
||||||
|
continue
|
||||||
|
|
||||||
|
if isinstance(result_data, dict):
|
||||||
|
# 队列模式下 process_image 通常直接返回 {status, image_url}
|
||||||
|
status = result_data.get("status", "success")
|
||||||
|
image_url = result_data.get("image_url")
|
||||||
|
message = result_data.get("message", "处理成功" if status == "success" else "处理失败")
|
||||||
|
error_msg = result_data.get("error")
|
||||||
else:
|
else:
|
||||||
completed_order += 1
|
# 兼容非 dict 返回:手动保存图片
|
||||||
success_count += 1
|
image_url = await loop.run_in_executor(
|
||||||
result = {
|
self.executor, self.save_image_to_file, result_data
|
||||||
"index": index,
|
|
||||||
"total": total,
|
|
||||||
"original_url": result_mapping[index][1],
|
|
||||||
"status": "success",
|
|
||||||
"image_url": image_url,
|
|
||||||
"message": "处理成功",
|
|
||||||
"success_count": success_count,
|
|
||||||
"error_count": error_count,
|
|
||||||
"completed_order": completed_order,
|
|
||||||
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
|
||||||
}
|
|
||||||
yield result
|
|
||||||
else:
|
|
||||||
# 分批处理
|
|
||||||
for batch_start in range(0, len(valid_items), batch_size):
|
|
||||||
batch_end = min(batch_start + batch_size, len(valid_items))
|
|
||||||
batch_items = valid_items[batch_start:batch_end]
|
|
||||||
|
|
||||||
images_with_info = [(img, size, idx) for img, size, idx, _ in batch_items]
|
|
||||||
|
|
||||||
batch_results = await self.process_batch_images(images_with_info)
|
|
||||||
|
|
||||||
# 并行保存
|
|
||||||
save_tasks = []
|
|
||||||
result_mapping = {}
|
|
||||||
|
|
||||||
for processed_image, index in batch_results:
|
|
||||||
url_str = next(url for _, _, idx, url in batch_items if idx == index)
|
|
||||||
result_mapping[index] = (processed_image, url_str)
|
|
||||||
save_task = loop.run_in_executor(
|
|
||||||
self.executor, self.save_image_to_file, processed_image
|
|
||||||
)
|
)
|
||||||
save_tasks.append((index, save_task))
|
status = "success"
|
||||||
|
message = "处理成功"
|
||||||
save_results = await asyncio.gather(*[task for _, task in save_tasks], return_exceptions=True)
|
error_msg = None
|
||||||
|
|
||||||
for (index, _), image_url in zip(save_tasks, save_results):
|
completed_order += 1
|
||||||
if isinstance(image_url, Exception):
|
if status == "success" and image_url:
|
||||||
error_count += 1
|
|
||||||
completed_order += 1
|
|
||||||
result = {
|
|
||||||
"index": index,
|
|
||||||
"total": total,
|
|
||||||
"original_url": result_mapping[index][1],
|
|
||||||
"status": "error",
|
|
||||||
"error": str(image_url),
|
|
||||||
"message": f"保存图片失败: {str(image_url)}",
|
|
||||||
"success_count": success_count,
|
|
||||||
"error_count": error_count,
|
|
||||||
"completed_order": completed_order,
|
|
||||||
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
|
||||||
}
|
|
||||||
yield result
|
|
||||||
else:
|
|
||||||
completed_order += 1
|
|
||||||
success_count += 1
|
|
||||||
result = {
|
|
||||||
"index": index,
|
|
||||||
"total": total,
|
|
||||||
"original_url": result_mapping[index][1],
|
|
||||||
"status": "success",
|
|
||||||
"image_url": image_url,
|
|
||||||
"message": "处理成功",
|
|
||||||
"success_count": success_count,
|
|
||||||
"error_count": error_count,
|
|
||||||
"completed_order": completed_order,
|
|
||||||
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
|
||||||
}
|
|
||||||
yield result
|
|
||||||
|
|
||||||
except RuntimeError as e:
|
|
||||||
# CUDA OOM错误,降级处理
|
|
||||||
error_msg = str(e)
|
|
||||||
if "CUDA out of memory" in error_msg or "out of memory" in error_msg.lower():
|
|
||||||
logger.warning(f"批处理显存不足,降级处理: {error_msg[:100]}")
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
# 降级到单张处理
|
|
||||||
for image, image_size, index, url_str, _ in valid_items:
|
|
||||||
try:
|
|
||||||
result_data = await self.process_image(image)
|
|
||||||
if isinstance(result_data, dict):
|
|
||||||
image_url = result_data["image_url"]
|
|
||||||
else:
|
|
||||||
image_url = await loop.run_in_executor(
|
|
||||||
self.executor, self.save_image_to_file, result_data
|
|
||||||
)
|
|
||||||
completed_order += 1
|
|
||||||
success_count += 1
|
success_count += 1
|
||||||
result = {
|
result = {
|
||||||
"index": index,
|
"index": index,
|
||||||
@ -743,47 +656,26 @@ class RmbgService:
|
|||||||
"original_url": url_str,
|
"original_url": url_str,
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"image_url": image_url,
|
"image_url": image_url,
|
||||||
"message": "处理成功(降级模式)",
|
"message": message,
|
||||||
"success_count": success_count,
|
"success_count": success_count,
|
||||||
"error_count": error_count,
|
"error_count": error_count,
|
||||||
"completed_order": completed_order,
|
"completed_order": completed_order,
|
||||||
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
||||||
}
|
}
|
||||||
yield result
|
else:
|
||||||
except Exception as e2:
|
|
||||||
error_count += 1
|
error_count += 1
|
||||||
completed_order += 1
|
|
||||||
result = {
|
result = {
|
||||||
"index": index,
|
"index": index,
|
||||||
"total": total,
|
"total": total,
|
||||||
"original_url": url_str,
|
"original_url": url_str,
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"error": str(e2),
|
"error": error_msg or "处理失败",
|
||||||
"message": f"处理失败: {str(e2)}",
|
"message": message,
|
||||||
"success_count": success_count,
|
"success_count": success_count,
|
||||||
"error_count": error_count,
|
"error_count": error_count,
|
||||||
"completed_order": completed_order,
|
"completed_order": completed_order,
|
||||||
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
||||||
}
|
}
|
||||||
yield result
|
|
||||||
else:
|
|
||||||
# 其他错误
|
|
||||||
logger.error(f"批处理失败: {error_msg}")
|
|
||||||
for _, _, index, url_str, _ in valid_items:
|
|
||||||
error_count += 1
|
|
||||||
completed_order += 1
|
|
||||||
result = {
|
|
||||||
"index": index,
|
|
||||||
"total": total,
|
|
||||||
"original_url": url_str,
|
|
||||||
"status": "error",
|
|
||||||
"error": error_msg,
|
|
||||||
"message": f"批处理失败: {error_msg}",
|
|
||||||
"success_count": success_count,
|
|
||||||
"error_count": error_count,
|
|
||||||
"completed_order": completed_order,
|
|
||||||
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
|
||||||
}
|
|
||||||
yield result
|
yield result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"批处理失败: {str(e)}")
|
logger.error(f"批处理失败: {str(e)}")
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user