feat(rmbg): 添加批处理性能统计功能

- 记录下载图片、GPU推理、保存图片各阶段的耗时
- 输出详细的性能统计信息,包括:
  * 图片总数、成功/失败数量
  * 批处理次数和每批图片数
  * 各阶段耗时及占比
  * 总耗时、平均每张耗时、每批平均耗时
- 使用统一的日志格式输出统计信息
This commit is contained in:
jingrow 2025-11-23 15:48:33 +08:00
parent 4a906d87fb
commit cecd617104

View File

@ -454,6 +454,60 @@ class RmbgService:
batch_size = settings.batch_size
loop = asyncio.get_event_loop()
# 性能统计变量
download_time = 0.0
gpu_inference_time = 0.0
save_time = 0.0
batch_count = 0
batch_sizes = []
stats_printed = False
def print_stats():
"""输出性能统计信息"""
nonlocal stats_printed
if stats_printed:
return
stats_printed = True
total_time = time.time() - batch_start_time
other_time = total_time - download_time - gpu_inference_time - save_time
logger.info("=" * 60)
logger.info("📊 批处理性能统计")
logger.info("=" * 60)
logger.info(f"图片总数: {total}")
logger.info(f"成功数量: {success_count}")
logger.info(f"失败数量: {error_count}")
logger.info(f"批处理次数: {batch_count}")
logger.info(f"每批图片数: {batch_sizes}")
logger.info("-" * 60)
logger.info("⏱️ 各阶段耗时:")
if total_time > 0:
download_pct = (download_time / total_time) * 100
gpu_pct = (gpu_inference_time / total_time) * 100
save_pct = (save_time / total_time) * 100
other_pct = (other_time / total_time) * 100
logger.info(f" 1. 下载图片: {download_time:.3f}s ({download_pct:.1f}%)")
logger.info(f" 2. GPU推理: {gpu_inference_time:.3f}s ({gpu_pct:.1f}%)")
logger.info(f" 3. 保存图片: {save_time:.3f}s ({save_pct:.1f}%)")
logger.info(f" 4. 其他开销: {other_time:.3f}s ({other_pct:.1f}%)")
else:
logger.info(f" 1. 下载图片: {download_time:.3f}s")
logger.info(f" 2. GPU推理: {gpu_inference_time:.3f}s")
logger.info(f" 3. 保存图片: {save_time:.3f}s")
logger.info(f" 4. 其他开销: {other_time:.3f}s")
logger.info("-" * 60)
logger.info(f"📈 总耗时: {total_time:.3f}s")
if total > 0:
avg_per_image = (total_time / total) * 1000
logger.info(f"📈 平均每张: {avg_per_image:.1f}ms")
if batch_count > 0:
avg_batch_time = gpu_inference_time / batch_count
logger.info(f"📈 每批平均耗时: {avg_batch_time:.3f}s")
async def download_image_async(index, url):
"""异步下载图片"""
url_str = str(url)
@ -472,8 +526,11 @@ class RmbgService:
except Exception as e:
return (None, None, index, url_str, str(e))
# 记录下载开始时间
download_start_time = time.time()
download_tasks = [download_image_async(i, url) for i, url in enumerate(urls, 1)]
downloaded_images = await asyncio.gather(*download_tasks)
download_time = time.time() - download_start_time
valid_images = []
failed_results = {}
@ -511,7 +568,13 @@ class RmbgService:
if use_single_batch:
try:
images_with_info = [(img, size, idx) for img, size, idx, _ in valid_images]
# 记录GPU推理开始时间
gpu_start_time = time.time()
batch_results = await self.process_batch_images(images_with_info)
gpu_inference_time += time.time() - gpu_start_time
batch_count += 1
batch_sizes.append(len(images_with_info))
# 并行保存所有图片
save_tasks = []
@ -526,7 +589,10 @@ class RmbgService:
)
save_tasks.append((index, save_task))
# 记录保存开始时间
save_start_time = time.time()
save_results = await asyncio.gather(*[task for _, task in save_tasks], return_exceptions=True)
save_time += time.time() - save_start_time
for (index, _), image_url in zip(save_tasks, save_results):
if isinstance(image_url, Exception):
@ -591,6 +657,7 @@ class RmbgService:
"batch_elapsed": round(time.time() - batch_start_time, 2)
}
yield result
print_stats()
return
except Exception as e:
# 其他异常,直接返回错误
@ -611,6 +678,7 @@ class RmbgService:
"batch_elapsed": round(time.time() - batch_start_time, 2)
}
yield result
print_stats()
return
# 如果一次性处理失败(显存不足)或图片数量太多,使用分批处理
@ -622,7 +690,13 @@ class RmbgService:
try:
images_with_info = [(img, size, idx) for img, size, idx, _ in batch_images]
# 记录GPU推理开始时间
gpu_start_time = time.time()
batch_results = await self.process_batch_images(images_with_info)
gpu_inference_time += time.time() - gpu_start_time
batch_count += 1
batch_sizes.append(len(images_with_info))
# 并行保存所有图片
save_tasks = []
@ -637,8 +711,11 @@ class RmbgService:
)
save_tasks.append((index, save_task))
# 记录保存开始时间
save_start_time = time.time()
# 并行执行所有保存任务
save_results = await asyncio.gather(*[task for _, task in save_tasks], return_exceptions=True)
save_time += time.time() - save_start_time
# 按顺序返回结果
for (index, _), image_url in zip(save_tasks, save_results):
@ -693,6 +770,9 @@ class RmbgService:
"batch_elapsed": round(time.time() - batch_start_time, 2)
}
yield result
# 输出性能统计信息
print_stats()
def is_valid_url(self, url):
"""验证URL是否有效"""