feat(rmbg): 添加批处理性能统计功能
- 记录下载图片、GPU推理、保存图片各阶段的耗时
- 输出详细的性能统计信息,包括:
  * 图片总数、成功/失败数量
  * 批处理次数和每批图片数
  * 各阶段耗时及占比
  * 总耗时、平均每张耗时、每批平均耗时
- 使用统一的日志格式输出统计信息
This commit is contained in:
parent
4a906d87fb
commit
cecd617104
@ -454,6 +454,60 @@ class RmbgService:
|
||||
batch_size = settings.batch_size
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# 性能统计变量
|
||||
download_time = 0.0
|
||||
gpu_inference_time = 0.0
|
||||
save_time = 0.0
|
||||
batch_count = 0
|
||||
batch_sizes = []
|
||||
stats_printed = False
|
||||
|
||||
def print_stats():
|
||||
"""输出性能统计信息"""
|
||||
nonlocal stats_printed
|
||||
if stats_printed:
|
||||
return
|
||||
stats_printed = True
|
||||
|
||||
total_time = time.time() - batch_start_time
|
||||
other_time = total_time - download_time - gpu_inference_time - save_time
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("📊 批处理性能统计")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"图片总数: {total}")
|
||||
logger.info(f"成功数量: {success_count}")
|
||||
logger.info(f"失败数量: {error_count}")
|
||||
logger.info(f"批处理次数: {batch_count}")
|
||||
logger.info(f"每批图片数: {batch_sizes}")
|
||||
logger.info("-" * 60)
|
||||
logger.info("⏱️ 各阶段耗时:")
|
||||
|
||||
if total_time > 0:
|
||||
download_pct = (download_time / total_time) * 100
|
||||
gpu_pct = (gpu_inference_time / total_time) * 100
|
||||
save_pct = (save_time / total_time) * 100
|
||||
other_pct = (other_time / total_time) * 100
|
||||
|
||||
logger.info(f" 1. 下载图片: {download_time:.3f}s ({download_pct:.1f}%)")
|
||||
logger.info(f" 2. GPU推理: {gpu_inference_time:.3f}s ({gpu_pct:.1f}%)")
|
||||
logger.info(f" 3. 保存图片: {save_time:.3f}s ({save_pct:.1f}%)")
|
||||
logger.info(f" 4. 其他开销: {other_time:.3f}s ({other_pct:.1f}%)")
|
||||
else:
|
||||
logger.info(f" 1. 下载图片: {download_time:.3f}s")
|
||||
logger.info(f" 2. GPU推理: {gpu_inference_time:.3f}s")
|
||||
logger.info(f" 3. 保存图片: {save_time:.3f}s")
|
||||
logger.info(f" 4. 其他开销: {other_time:.3f}s")
|
||||
|
||||
logger.info("-" * 60)
|
||||
logger.info(f"📈 总耗时: {total_time:.3f}s")
|
||||
if total > 0:
|
||||
avg_per_image = (total_time / total) * 1000
|
||||
logger.info(f"📈 平均每张: {avg_per_image:.1f}ms")
|
||||
if batch_count > 0:
|
||||
avg_batch_time = gpu_inference_time / batch_count
|
||||
logger.info(f"📈 每批平均耗时: {avg_batch_time:.3f}s")
|
||||
|
||||
async def download_image_async(index, url):
|
||||
"""异步下载图片"""
|
||||
url_str = str(url)
|
||||
@ -472,8 +526,11 @@ class RmbgService:
|
||||
except Exception as e:
|
||||
return (None, None, index, url_str, str(e))
|
||||
|
||||
# 记录下载开始时间
|
||||
download_start_time = time.time()
|
||||
download_tasks = [download_image_async(i, url) for i, url in enumerate(urls, 1)]
|
||||
downloaded_images = await asyncio.gather(*download_tasks)
|
||||
download_time = time.time() - download_start_time
|
||||
|
||||
valid_images = []
|
||||
failed_results = {}
|
||||
@ -511,7 +568,13 @@ class RmbgService:
|
||||
if use_single_batch:
|
||||
try:
|
||||
images_with_info = [(img, size, idx) for img, size, idx, _ in valid_images]
|
||||
|
||||
# 记录GPU推理开始时间
|
||||
gpu_start_time = time.time()
|
||||
batch_results = await self.process_batch_images(images_with_info)
|
||||
gpu_inference_time += time.time() - gpu_start_time
|
||||
batch_count += 1
|
||||
batch_sizes.append(len(images_with_info))
|
||||
|
||||
# 并行保存所有图片
|
||||
save_tasks = []
|
||||
@ -526,7 +589,10 @@ class RmbgService:
|
||||
)
|
||||
save_tasks.append((index, save_task))
|
||||
|
||||
# 记录保存开始时间
|
||||
save_start_time = time.time()
|
||||
save_results = await asyncio.gather(*[task for _, task in save_tasks], return_exceptions=True)
|
||||
save_time += time.time() - save_start_time
|
||||
|
||||
for (index, _), image_url in zip(save_tasks, save_results):
|
||||
if isinstance(image_url, Exception):
|
||||
@ -591,6 +657,7 @@ class RmbgService:
|
||||
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
||||
}
|
||||
yield result
|
||||
print_stats()
|
||||
return
|
||||
except Exception as e:
|
||||
# 其他异常,直接返回错误
|
||||
@ -611,6 +678,7 @@ class RmbgService:
|
||||
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
||||
}
|
||||
yield result
|
||||
print_stats()
|
||||
return
|
||||
|
||||
# 如果一次性处理失败(显存不足)或图片数量太多,使用分批处理
|
||||
@ -622,7 +690,13 @@ class RmbgService:
|
||||
|
||||
try:
|
||||
images_with_info = [(img, size, idx) for img, size, idx, _ in batch_images]
|
||||
|
||||
# 记录GPU推理开始时间
|
||||
gpu_start_time = time.time()
|
||||
batch_results = await self.process_batch_images(images_with_info)
|
||||
gpu_inference_time += time.time() - gpu_start_time
|
||||
batch_count += 1
|
||||
batch_sizes.append(len(images_with_info))
|
||||
|
||||
# 并行保存所有图片
|
||||
save_tasks = []
|
||||
@ -637,8 +711,11 @@ class RmbgService:
|
||||
)
|
||||
save_tasks.append((index, save_task))
|
||||
|
||||
# 记录保存开始时间
|
||||
save_start_time = time.time()
|
||||
# 并行执行所有保存任务
|
||||
save_results = await asyncio.gather(*[task for _, task in save_tasks], return_exceptions=True)
|
||||
save_time += time.time() - save_start_time
|
||||
|
||||
# 按顺序返回结果
|
||||
for (index, _), image_url in zip(save_tasks, save_results):
|
||||
@ -693,6 +770,9 @@ class RmbgService:
|
||||
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
||||
}
|
||||
yield result
|
||||
|
||||
# 输出性能统计信息
|
||||
print_stats()
|
||||
|
||||
def is_valid_url(self, url):
|
||||
"""验证URL是否有效"""
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user