diff --git a/apps/rmbg/service.py b/apps/rmbg/service.py index 9ba22eb..0363b47 100644 --- a/apps/rmbg/service.py +++ b/apps/rmbg/service.py @@ -454,6 +454,60 @@ class RmbgService: batch_size = settings.batch_size loop = asyncio.get_event_loop() + # 性能统计变量 + download_time = 0.0 + gpu_inference_time = 0.0 + save_time = 0.0 + batch_count = 0 + batch_sizes = [] + stats_printed = False + + def print_stats(): + """输出性能统计信息""" + nonlocal stats_printed + if stats_printed: + return + stats_printed = True + + total_time = time.time() - batch_start_time + other_time = total_time - download_time - gpu_inference_time - save_time + + logger.info("=" * 60) + logger.info("📊 批处理性能统计") + logger.info("=" * 60) + logger.info(f"图片总数: {total}") + logger.info(f"成功数量: {success_count}") + logger.info(f"失败数量: {error_count}") + logger.info(f"批处理次数: {batch_count}") + logger.info(f"每批图片数: {batch_sizes}") + logger.info("-" * 60) + logger.info("⏱️ 各阶段耗时:") + + if total_time > 0: + download_pct = (download_time / total_time) * 100 + gpu_pct = (gpu_inference_time / total_time) * 100 + save_pct = (save_time / total_time) * 100 + other_pct = (other_time / total_time) * 100 + + logger.info(f" 1. 下载图片: {download_time:.3f}s ({download_pct:.1f}%)") + logger.info(f" 2. GPU推理: {gpu_inference_time:.3f}s ({gpu_pct:.1f}%)") + logger.info(f" 3. 保存图片: {save_time:.3f}s ({save_pct:.1f}%)") + logger.info(f" 4. 其他开销: {other_time:.3f}s ({other_pct:.1f}%)") + else: + logger.info(f" 1. 下载图片: {download_time:.3f}s") + logger.info(f" 2. GPU推理: {gpu_inference_time:.3f}s") + logger.info(f" 3. 保存图片: {save_time:.3f}s") + logger.info(f" 4. 其他开销: {other_time:.3f}s") + + logger.info("-" * 60) + logger.info(f"📈 总耗时: {total_time:.3f}s") + if total > 0: + avg_per_image = (total_time / total) * 1000 + logger.info(f"📈 平均每张: {avg_per_image:.1f}ms") + if batch_count > 0: + avg_batch_time = gpu_inference_time / batch_count + logger.info(f"📈 每批平均耗时: {avg_batch_time:.3f}s") + async def download_image_async(index, url): """异步下载图片""" url_str = str(url) @@ -472,8 +526,11 @@ class RmbgService: except Exception as e: return (None, None, index, url_str, str(e)) + # 记录下载开始时间 + download_start_time = time.time() download_tasks = [download_image_async(i, url) for i, url in enumerate(urls, 1)] downloaded_images = await asyncio.gather(*download_tasks) + download_time = time.time() - download_start_time valid_images = [] failed_results = {} @@ -511,7 +568,13 @@ class RmbgService: if use_single_batch: try: images_with_info = [(img, size, idx) for img, size, idx, _ in valid_images] + + # 记录GPU推理开始时间 + gpu_start_time = time.time() batch_results = await self.process_batch_images(images_with_info) + gpu_inference_time += time.time() - gpu_start_time + batch_count += 1 + batch_sizes.append(len(images_with_info)) # 并行保存所有图片 save_tasks = [] @@ -526,7 +589,10 @@ class RmbgService: ) save_tasks.append((index, save_task)) + # 记录保存开始时间 + save_start_time = time.time() save_results = await asyncio.gather(*[task for _, task in save_tasks], return_exceptions=True) + save_time += time.time() - save_start_time for (index, _), image_url in zip(save_tasks, save_results): if isinstance(image_url, Exception): @@ -591,6 +657,7 @@ class RmbgService: "batch_elapsed": round(time.time() - batch_start_time, 2) } yield result + print_stats() return except Exception as e: # 其他异常,直接返回错误 @@ -611,6 +678,7 @@ class RmbgService: "batch_elapsed": round(time.time() - batch_start_time, 2) } yield result + print_stats() return # 如果一次性处理失败(显存不足)或图片数量太多,使用分批处理 @@ -622,7 +690,13 @@ class RmbgService: try: images_with_info = [(img, size, idx) for img, size, idx, _ in batch_images] + + # 记录GPU推理开始时间 + gpu_start_time = time.time() batch_results = await self.process_batch_images(images_with_info) + gpu_inference_time += time.time() - gpu_start_time + batch_count += 1 + batch_sizes.append(len(images_with_info)) # 并行保存所有图片 save_tasks = [] @@ -637,8 +711,11 @@ class RmbgService: ) save_tasks.append((index, save_task)) + # 记录保存开始时间 + save_start_time = time.time() # 并行执行所有保存任务 save_results = await asyncio.gather(*[task for _, task in save_tasks], return_exceptions=True) + save_time += time.time() - save_start_time # 按顺序返回结果 for (index, _), image_url in zip(save_tasks, save_results): @@ -693,6 +770,9 @@ class RmbgService: "batch_elapsed": round(time.time() - batch_start_time, 2) } yield result + + # 输出性能统计信息 + print_stats() def is_valid_url(self, url): """验证URL是否有效"""