优化批处理为全局队列合批以充分利用多 GPU
- 重构 process_batch,将批处理入口改为逐张通过全局队列的 process_image - 在本地批次内并发调用 process_image,让全局队列能凑大 batch 并触发多 GPU 并行 - 保留原有流式返回结构和统计字段,对外接口兼容不变
This commit is contained in:
parent
530e7c8961
commit
565115369d
@ -87,7 +87,7 @@ class RmbgService:
|
||||
model = AutoModelForImageSegmentation.from_pretrained(
|
||||
self.model_path,
|
||||
trust_remote_code=True,
|
||||
torch_dtype=torch.float16 if use_half else torch.float32,
|
||||
dtype=torch.float16 if use_half else torch.float32,
|
||||
)
|
||||
model = model.to(device)
|
||||
if use_half:
|
||||
@ -558,10 +558,9 @@ class RmbgService:
|
||||
completed_order = 0
|
||||
pending_batch = []
|
||||
batch_collect_timeout = 0.5 # 批处理收集超时(秒)
|
||||
max_single_batch = batch_size * 2 # 允许最多2倍batch_size用于一次性处理
|
||||
|
||||
async def process_pending_batch(force=False):
|
||||
"""处理待处理的批次"""
|
||||
"""处理待处理的批次(将下载好的图片逐张送入全局队列进行批处理)"""
|
||||
nonlocal pending_batch, completed_order, success_count, error_count
|
||||
|
||||
if not pending_batch:
|
||||
@ -600,142 +599,56 @@ class RmbgService:
|
||||
pending_batch = []
|
||||
return
|
||||
|
||||
# 处理成功的图片
|
||||
# 处理成功的图片:通过全局队列进行批处理(每个本地批次内并发调度)
|
||||
try:
|
||||
# 判断是否尝试一次性处理
|
||||
use_single_batch = len(valid_items) <= max_single_batch and force
|
||||
|
||||
if use_single_batch:
|
||||
# 尝试一次性处理所有图片
|
||||
images_with_info = [(img, size, idx) for img, size, idx, _ in valid_items]
|
||||
|
||||
batch_results = await self.process_batch_images(images_with_info)
|
||||
|
||||
# 并行保存
|
||||
save_tasks = []
|
||||
result_mapping = {}
|
||||
|
||||
for processed_image, index in batch_results:
|
||||
url_str = next(url for _, _, idx, url in valid_items if idx == index)
|
||||
result_mapping[index] = (processed_image, url_str)
|
||||
save_task = loop.run_in_executor(
|
||||
self.executor, self.save_image_to_file, processed_image
|
||||
)
|
||||
save_tasks.append((index, save_task))
|
||||
|
||||
save_results = await asyncio.gather(*[task for _, task in save_tasks], return_exceptions=True)
|
||||
|
||||
for (index, _), image_url in zip(save_tasks, save_results):
|
||||
if isinstance(image_url, Exception):
|
||||
# 为了控制单次处理时长,仍按 batch_size 对 valid_items 做分块
|
||||
for local_batch_start in range(0, len(valid_items), batch_size):
|
||||
local_batch_end = min(local_batch_start + batch_size, len(valid_items))
|
||||
local_batch_items = valid_items[local_batch_start:local_batch_end]
|
||||
|
||||
# 本地批次内并发调用 process_image,让全局队列有机会凑大 batch 并利用多 GPU
|
||||
tasks = [
|
||||
self.process_image(image)
|
||||
for image, _, _, _ in local_batch_items
|
||||
]
|
||||
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
for (image, _, index, url_str), result_data in zip(local_batch_items, batch_results):
|
||||
if isinstance(result_data, Exception):
|
||||
error_count += 1
|
||||
completed_order += 1
|
||||
result = {
|
||||
"index": index,
|
||||
"total": total,
|
||||
"original_url": result_mapping[index][1],
|
||||
"original_url": url_str,
|
||||
"status": "error",
|
||||
"error": str(image_url),
|
||||
"message": f"保存图片失败: {str(image_url)}",
|
||||
"error": str(result_data),
|
||||
"message": f"处理失败: {str(result_data)}",
|
||||
"success_count": success_count,
|
||||
"error_count": error_count,
|
||||
"completed_order": completed_order,
|
||||
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
||||
}
|
||||
yield result
|
||||
continue
|
||||
|
||||
if isinstance(result_data, dict):
|
||||
# 队列模式下 process_image 通常直接返回 {status, image_url}
|
||||
status = result_data.get("status", "success")
|
||||
image_url = result_data.get("image_url")
|
||||
message = result_data.get("message", "处理成功" if status == "success" else "处理失败")
|
||||
error_msg = result_data.get("error")
|
||||
else:
|
||||
completed_order += 1
|
||||
success_count += 1
|
||||
result = {
|
||||
"index": index,
|
||||
"total": total,
|
||||
"original_url": result_mapping[index][1],
|
||||
"status": "success",
|
||||
"image_url": image_url,
|
||||
"message": "处理成功",
|
||||
"success_count": success_count,
|
||||
"error_count": error_count,
|
||||
"completed_order": completed_order,
|
||||
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
||||
}
|
||||
yield result
|
||||
else:
|
||||
# 分批处理
|
||||
for batch_start in range(0, len(valid_items), batch_size):
|
||||
batch_end = min(batch_start + batch_size, len(valid_items))
|
||||
batch_items = valid_items[batch_start:batch_end]
|
||||
|
||||
images_with_info = [(img, size, idx) for img, size, idx, _ in batch_items]
|
||||
|
||||
batch_results = await self.process_batch_images(images_with_info)
|
||||
|
||||
# 并行保存
|
||||
save_tasks = []
|
||||
result_mapping = {}
|
||||
|
||||
for processed_image, index in batch_results:
|
||||
url_str = next(url for _, _, idx, url in batch_items if idx == index)
|
||||
result_mapping[index] = (processed_image, url_str)
|
||||
save_task = loop.run_in_executor(
|
||||
self.executor, self.save_image_to_file, processed_image
|
||||
# 兼容非 dict 返回:手动保存图片
|
||||
image_url = await loop.run_in_executor(
|
||||
self.executor, self.save_image_to_file, result_data
|
||||
)
|
||||
save_tasks.append((index, save_task))
|
||||
|
||||
save_results = await asyncio.gather(*[task for _, task in save_tasks], return_exceptions=True)
|
||||
|
||||
for (index, _), image_url in zip(save_tasks, save_results):
|
||||
if isinstance(image_url, Exception):
|
||||
error_count += 1
|
||||
completed_order += 1
|
||||
result = {
|
||||
"index": index,
|
||||
"total": total,
|
||||
"original_url": result_mapping[index][1],
|
||||
"status": "error",
|
||||
"error": str(image_url),
|
||||
"message": f"保存图片失败: {str(image_url)}",
|
||||
"success_count": success_count,
|
||||
"error_count": error_count,
|
||||
"completed_order": completed_order,
|
||||
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
||||
}
|
||||
yield result
|
||||
else:
|
||||
completed_order += 1
|
||||
success_count += 1
|
||||
result = {
|
||||
"index": index,
|
||||
"total": total,
|
||||
"original_url": result_mapping[index][1],
|
||||
"status": "success",
|
||||
"image_url": image_url,
|
||||
"message": "处理成功",
|
||||
"success_count": success_count,
|
||||
"error_count": error_count,
|
||||
"completed_order": completed_order,
|
||||
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
||||
}
|
||||
yield result
|
||||
|
||||
except RuntimeError as e:
|
||||
# CUDA OOM错误,降级处理
|
||||
error_msg = str(e)
|
||||
if "CUDA out of memory" in error_msg or "out of memory" in error_msg.lower():
|
||||
logger.warning(f"批处理显存不足,降级处理: {error_msg[:100]}")
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
# 降级到单张处理
|
||||
for image, image_size, index, url_str, _ in valid_items:
|
||||
try:
|
||||
result_data = await self.process_image(image)
|
||||
if isinstance(result_data, dict):
|
||||
image_url = result_data["image_url"]
|
||||
else:
|
||||
image_url = await loop.run_in_executor(
|
||||
self.executor, self.save_image_to_file, result_data
|
||||
)
|
||||
completed_order += 1
|
||||
status = "success"
|
||||
message = "处理成功"
|
||||
error_msg = None
|
||||
|
||||
completed_order += 1
|
||||
if status == "success" and image_url:
|
||||
success_count += 1
|
||||
result = {
|
||||
"index": index,
|
||||
@ -743,47 +656,26 @@ class RmbgService:
|
||||
"original_url": url_str,
|
||||
"status": "success",
|
||||
"image_url": image_url,
|
||||
"message": "处理成功(降级模式)",
|
||||
"message": message,
|
||||
"success_count": success_count,
|
||||
"error_count": error_count,
|
||||
"completed_order": completed_order,
|
||||
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
||||
}
|
||||
yield result
|
||||
except Exception as e2:
|
||||
else:
|
||||
error_count += 1
|
||||
completed_order += 1
|
||||
result = {
|
||||
"index": index,
|
||||
"total": total,
|
||||
"original_url": url_str,
|
||||
"status": "error",
|
||||
"error": str(e2),
|
||||
"message": f"处理失败: {str(e2)}",
|
||||
"error": error_msg or "处理失败",
|
||||
"message": message,
|
||||
"success_count": success_count,
|
||||
"error_count": error_count,
|
||||
"completed_order": completed_order,
|
||||
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
||||
}
|
||||
yield result
|
||||
else:
|
||||
# 其他错误
|
||||
logger.error(f"批处理失败: {error_msg}")
|
||||
for _, _, index, url_str, _ in valid_items:
|
||||
error_count += 1
|
||||
completed_order += 1
|
||||
result = {
|
||||
"index": index,
|
||||
"total": total,
|
||||
"original_url": url_str,
|
||||
"status": "error",
|
||||
"error": error_msg,
|
||||
"message": f"批处理失败: {error_msg}",
|
||||
"success_count": success_count,
|
||||
"error_count": error_count,
|
||||
"completed_order": completed_order,
|
||||
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
||||
}
|
||||
yield result
|
||||
except Exception as e:
|
||||
logger.error(f"批处理失败: {str(e)}")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user