优化批处理为全局队列合批以充分利用多 GPU

- 重构 process_batch,将批处理入口改为逐张通过全局队列的 process_image
- 在本地批次内并发调用 process_image,让全局队列能凑大 batch 并触发多 GPU 并行
- 保留原有流式返回结构和统计字段,对外接口兼容不变
This commit is contained in:
jingrow 2025-12-15 19:08:13 +00:00
parent 530e7c8961
commit 565115369d

View File

@@ -87,7 +87,7 @@ class RmbgService:
model = AutoModelForImageSegmentation.from_pretrained(
self.model_path,
trust_remote_code=True,
torch_dtype=torch.float16 if use_half else torch.float32,
dtype=torch.float16 if use_half else torch.float32,
)
model = model.to(device)
if use_half:
@@ -558,10 +558,9 @@ class RmbgService:
completed_order = 0
pending_batch = []
batch_collect_timeout = 0.5 # 批处理收集超时(秒)
max_single_batch = batch_size * 2 # 允许最多2倍batch_size用于一次性处理
async def process_pending_batch(force=False):
"""处理待处理的批次"""
"""处理待处理的批次(将下载好的图片逐张送入全局队列进行批处理)"""
nonlocal pending_batch, completed_order, success_count, error_count
if not pending_batch:
@@ -600,142 +599,56 @@ class RmbgService:
pending_batch = []
return
# 处理成功的图片
# 处理成功的图片:通过全局队列进行批处理(每个本地批次内并发调度)
try:
# 判断是否尝试一次性处理
use_single_batch = len(valid_items) <= max_single_batch and force
if use_single_batch:
# 尝试一次性处理所有图片
images_with_info = [(img, size, idx) for img, size, idx, _ in valid_items]
batch_results = await self.process_batch_images(images_with_info)
# 并行保存
save_tasks = []
result_mapping = {}
for processed_image, index in batch_results:
url_str = next(url for _, _, idx, url in valid_items if idx == index)
result_mapping[index] = (processed_image, url_str)
save_task = loop.run_in_executor(
self.executor, self.save_image_to_file, processed_image
)
save_tasks.append((index, save_task))
save_results = await asyncio.gather(*[task for _, task in save_tasks], return_exceptions=True)
for (index, _), image_url in zip(save_tasks, save_results):
if isinstance(image_url, Exception):
# 为了控制单次处理时长,仍按 batch_size 对 valid_items 做分块
for local_batch_start in range(0, len(valid_items), batch_size):
local_batch_end = min(local_batch_start + batch_size, len(valid_items))
local_batch_items = valid_items[local_batch_start:local_batch_end]
# 本地批次内并发调用 process_image,让全局队列有机会凑大 batch 并利用多 GPU
tasks = [
self.process_image(image)
for image, _, _, _ in local_batch_items
]
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
for (image, _, index, url_str), result_data in zip(local_batch_items, batch_results):
if isinstance(result_data, Exception):
error_count += 1
completed_order += 1
result = {
"index": index,
"total": total,
"original_url": result_mapping[index][1],
"original_url": url_str,
"status": "error",
"error": str(image_url),
"message": f"保存图片失败: {str(image_url)}",
"error": str(result_data),
"message": f"处理失败: {str(result_data)}",
"success_count": success_count,
"error_count": error_count,
"completed_order": completed_order,
"batch_elapsed": round(time.time() - batch_start_time, 2)
}
yield result
continue
if isinstance(result_data, dict):
# 队列模式下 process_image 通常直接返回 {status, image_url}
status = result_data.get("status", "success")
image_url = result_data.get("image_url")
message = result_data.get("message", "处理成功" if status == "success" else "处理失败")
error_msg = result_data.get("error")
else:
completed_order += 1
success_count += 1
result = {
"index": index,
"total": total,
"original_url": result_mapping[index][1],
"status": "success",
"image_url": image_url,
"message": "处理成功",
"success_count": success_count,
"error_count": error_count,
"completed_order": completed_order,
"batch_elapsed": round(time.time() - batch_start_time, 2)
}
yield result
else:
# 分批处理
for batch_start in range(0, len(valid_items), batch_size):
batch_end = min(batch_start + batch_size, len(valid_items))
batch_items = valid_items[batch_start:batch_end]
images_with_info = [(img, size, idx) for img, size, idx, _ in batch_items]
batch_results = await self.process_batch_images(images_with_info)
# 并行保存
save_tasks = []
result_mapping = {}
for processed_image, index in batch_results:
url_str = next(url for _, _, idx, url in batch_items if idx == index)
result_mapping[index] = (processed_image, url_str)
save_task = loop.run_in_executor(
self.executor, self.save_image_to_file, processed_image
# 兼容非 dict 返回:手动保存图片
image_url = await loop.run_in_executor(
self.executor, self.save_image_to_file, result_data
)
save_tasks.append((index, save_task))
save_results = await asyncio.gather(*[task for _, task in save_tasks], return_exceptions=True)
for (index, _), image_url in zip(save_tasks, save_results):
if isinstance(image_url, Exception):
error_count += 1
completed_order += 1
result = {
"index": index,
"total": total,
"original_url": result_mapping[index][1],
"status": "error",
"error": str(image_url),
"message": f"保存图片失败: {str(image_url)}",
"success_count": success_count,
"error_count": error_count,
"completed_order": completed_order,
"batch_elapsed": round(time.time() - batch_start_time, 2)
}
yield result
else:
completed_order += 1
success_count += 1
result = {
"index": index,
"total": total,
"original_url": result_mapping[index][1],
"status": "success",
"image_url": image_url,
"message": "处理成功",
"success_count": success_count,
"error_count": error_count,
"completed_order": completed_order,
"batch_elapsed": round(time.time() - batch_start_time, 2)
}
yield result
except RuntimeError as e:
# CUDA OOM错误降级处理
error_msg = str(e)
if "CUDA out of memory" in error_msg or "out of memory" in error_msg.lower():
logger.warning(f"批处理显存不足,降级处理: {error_msg[:100]}")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
# 降级到单张处理
for image, image_size, index, url_str, _ in valid_items:
try:
result_data = await self.process_image(image)
if isinstance(result_data, dict):
image_url = result_data["image_url"]
else:
image_url = await loop.run_in_executor(
self.executor, self.save_image_to_file, result_data
)
completed_order += 1
status = "success"
message = "处理成功"
error_msg = None
completed_order += 1
if status == "success" and image_url:
success_count += 1
result = {
"index": index,
@@ -743,47 +656,26 @@ class RmbgService:
"original_url": url_str,
"status": "success",
"image_url": image_url,
"message": "处理成功(降级模式)",
"message": message,
"success_count": success_count,
"error_count": error_count,
"completed_order": completed_order,
"batch_elapsed": round(time.time() - batch_start_time, 2)
}
yield result
except Exception as e2:
else:
error_count += 1
completed_order += 1
result = {
"index": index,
"total": total,
"original_url": url_str,
"status": "error",
"error": str(e2),
"message": f"处理失败: {str(e2)}",
"error": error_msg or "处理失败",
"message": message,
"success_count": success_count,
"error_count": error_count,
"completed_order": completed_order,
"batch_elapsed": round(time.time() - batch_start_time, 2)
}
yield result
else:
# 其他错误
logger.error(f"批处理失败: {error_msg}")
for _, _, index, url_str, _ in valid_items:
error_count += 1
completed_order += 1
result = {
"index": index,
"total": total,
"original_url": url_str,
"status": "error",
"error": error_msg,
"message": f"批处理失败: {error_msg}",
"success_count": success_count,
"error_count": error_count,
"completed_order": completed_order,
"batch_elapsed": round(time.time() - batch_start_time, 2)
}
yield result
except Exception as e:
logger.error(f"批处理失败: {str(e)}")