diff --git a/apps/rmbg/service.py b/apps/rmbg/service.py index 7e96842..386cd58 100644 --- a/apps/rmbg/service.py +++ b/apps/rmbg/service.py @@ -443,7 +443,7 @@ class RmbgService: raise Exception(f"处理图片失败: {e}") async def process_batch(self, urls): - """批量处理多个URL图像,批处理模式(推荐方案)""" + """批量处理多个URL图像,流水线批处理模式(下载和处理并行)""" total = len(urls) success_count = 0 error_count = 0 @@ -459,6 +459,12 @@ class RmbgService: batch_sizes = [] stats_printed = False + # 流水线队列:收集已下载的图片 + download_queue = asyncio.Queue() + download_complete = asyncio.Event() + download_done_count = 0 + download_error_count = 0 + def print_stats(): """输出性能统计信息""" nonlocal stats_printed @@ -470,7 +476,7 @@ class RmbgService: other_time = total_time - download_time - gpu_inference_time - save_time logger.info("=" * 60) - logger.info("📊 批处理性能统计") + logger.info("📊 批处理性能统计(流水线模式)") logger.info("=" * 60) logger.info(f"图片总数: {total}") logger.info(f"成功数量: {success_count}") @@ -506,8 +512,10 @@ class RmbgService: logger.info(f"📈 每批平均耗时: {avg_batch_time:.3f}s") async def download_image_async(index, url): - """异步下载图片""" + """异步下载图片并放入队列""" + nonlocal download_done_count, download_error_count url_str = str(url) + try: if self.is_valid_url(url_str): temp_file = await self.download_image(url_str) @@ -519,202 +527,106 @@ class RmbgService: image = await loop.run_in_executor( self.executor, lambda: Image.open(url_str).convert("RGB") ) - return (image, image.size, index, url_str, None) + + # 下载成功,放入队列 + await download_queue.put((image, image.size, index, url_str, None)) + download_done_count += 1 + except Exception as e: - return (None, None, index, url_str, str(e)) + # 下载失败,也放入队列(标记为错误) + await download_queue.put((None, None, index, url_str, str(e))) + download_error_count += 1 + download_done_count += 1 + finally: + # 所有下载任务完成 + if download_done_count >= total: + download_complete.set() - # 记录下载开始时间 + # 启动所有下载任务(并行下载) download_start_time = time.time() - download_tasks = [download_image_async(i, url) for i, url in 
enumerate(urls, 1)] - downloaded_images = await asyncio.gather(*download_tasks) - download_time = time.time() - download_start_time + download_tasks = [ + asyncio.create_task(download_image_async(i, url)) + for i, url in enumerate(urls, 1) + ] - valid_images = [] - failed_results = {} + # 流水线批处理任务:收集队列中的图片,达到batch_size或超时后立即处理 + completed_order = 0 + pending_batch = [] + batch_collect_timeout = 0.5 # 批处理收集超时(秒) + max_single_batch = batch_size * 2 # 允许最多2倍batch_size用于一次性处理 - for item in downloaded_images: - image, image_size, index, url_str, error = item - if error: - failed_results[index] = { + async def process_pending_batch(force=False): + """处理待处理的批次""" + nonlocal pending_batch, completed_order, success_count, error_count + nonlocal gpu_inference_time, save_time, batch_count, batch_sizes + + if not pending_batch: + return + + # 分离成功和失败的图片 + valid_items = [] + failed_items = [] + + for item in pending_batch: + image, image_size, index, url_str, error = item + if error: + failed_items.append((index, url_str, error)) + else: + valid_items.append((image, image_size, index, url_str)) + + # 先处理下载失败的 + for index, url_str, error in failed_items: + error_count += 1 + completed_order += 1 + result = { "index": index, "total": total, "original_url": url_str, "status": "error", "error": error, - "message": f"下载失败: {error}" + "message": f"下载失败: {error}", + "success_count": success_count, + "error_count": error_count, + "completed_order": completed_order, + "batch_elapsed": round(time.time() - batch_start_time, 2) } - else: - valid_images.append((image, image_size, index, url_str)) - - for index, result in failed_results.items(): - error_count += 1 - result["success_count"] = success_count - result["error_count"] = error_count - result["completed_order"] = len(failed_results) - result["batch_elapsed"] = round(time.time() - batch_start_time, 2) - yield result - - completed_order = len(failed_results) - - # 如果图片数量不太多(<= batch_size * 2),尝试一次性处理所有图片(避免分批,提升并发) - # 
对于13张图片,batch_size=8,13 <= 16,会尝试一次性处理 - # 如果显存不足,自动降级到分批处理 - max_single_batch = batch_size * 2 # 允许最多2倍batch_size - use_single_batch = len(valid_images) <= max_single_batch - - if use_single_batch: - try: - images_with_info = [(img, size, idx) for img, size, idx, _ in valid_images] - - # 记录GPU推理开始时间 - gpu_start_time = time.time() - batch_results = await self.process_batch_images(images_with_info) - gpu_inference_time += time.time() - gpu_start_time - batch_count += 1 - batch_sizes.append(len(images_with_info)) - - # 并行保存所有图片 - save_tasks = [] - result_mapping = {} - - for processed_image, index in batch_results: - url_str = next(url for _, _, idx, url in valid_images if idx == index) - result_mapping[index] = (processed_image, url_str) - - save_task = loop.run_in_executor( - self.executor, self.save_image_to_file, processed_image - ) - save_tasks.append((index, save_task)) - - # 记录保存开始时间 - save_start_time = time.time() - save_results = await asyncio.gather(*[task for _, task in save_tasks], return_exceptions=True) - save_time += time.time() - save_start_time - - for (index, _), image_url in zip(save_tasks, save_results): - if isinstance(image_url, Exception): - error_count += 1 - completed_order += 1 - result = { - "index": index, - "total": total, - "original_url": result_mapping[index][1], - "status": "error", - "error": str(image_url), - "message": f"保存图片失败: {str(image_url)}", - "success_count": success_count, - "error_count": error_count, - "completed_order": completed_order, - "batch_elapsed": round(time.time() - batch_start_time, 2) - } - yield result - else: - completed_order += 1 - success_count += 1 - result = { - "index": index, - "total": total, - "original_url": result_mapping[index][1], - "status": "success", - "image_url": image_url, - "message": "处理成功", - "success_count": success_count, - "error_count": error_count, - "completed_order": completed_order, - "batch_elapsed": round(time.time() - batch_start_time, 2) - } - yield result - except 
RuntimeError as e: - # 检测CUDA OOM错误,自动降级到分批处理 - error_msg = str(e) - if "CUDA out of memory" in error_msg or "out of memory" in error_msg.lower(): - logger.warning(f"一次性处理显存不足,自动降级到分批处理: {error_msg[:100]}") - # 清理显存 - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - # 继续执行分批处理逻辑(不return,继续到else分支) - use_single_batch = False - else: - # 其他错误,直接返回 - logger.error(f"批处理失败: {error_msg}") - for _, _, index, url_str in valid_images: - completed_order += 1 - error_count += 1 - result = { - "index": index, - "total": total, - "original_url": url_str, - "status": "error", - "error": error_msg, - "message": f"批处理失败: {error_msg}", - "success_count": success_count, - "error_count": error_count, - "completed_order": completed_order, - "batch_elapsed": round(time.time() - batch_start_time, 2) - } - yield result - print_stats() - return - except Exception as e: - # 其他异常,直接返回错误 - logger.error(f"批处理失败: {str(e)}") - for _, _, index, url_str in valid_images: - completed_order += 1 - error_count += 1 - result = { - "index": index, - "total": total, - "original_url": url_str, - "status": "error", - "error": str(e), - "message": f"批处理失败: {str(e)}", - "success_count": success_count, - "error_count": error_count, - "completed_order": completed_order, - "batch_elapsed": round(time.time() - batch_start_time, 2) - } - yield result - print_stats() + yield result + + if not valid_items: + pending_batch = [] return - - # 如果一次性处理失败(显存不足)或图片数量太多,使用分批处理 - if not use_single_batch: - # 多批处理:串行处理批次,但每个批次内部并行保存 - for batch_start in range(0, len(valid_images), batch_size): - batch_end = min(batch_start + batch_size, len(valid_images)) - batch_images = valid_images[batch_start:batch_end] + + # 处理成功的图片 + try: + # 判断是否尝试一次性处理 + use_single_batch = len(valid_items) <= max_single_batch and force - try: - images_with_info = [(img, size, idx) for img, size, idx, _ in batch_images] + if use_single_batch: + # 尝试一次性处理所有图片 + images_with_info = [(img, size, idx) for img, size, idx, _ in 
valid_items] - # 记录GPU推理开始时间 gpu_start_time = time.time() batch_results = await self.process_batch_images(images_with_info) gpu_inference_time += time.time() - gpu_start_time batch_count += 1 batch_sizes.append(len(images_with_info)) - # 并行保存所有图片 + # 并行保存 save_tasks = [] result_mapping = {} for processed_image, index in batch_results: - url_str = next(url for _, _, idx, url in batch_images if idx == index) + url_str = next(url for _, _, idx, url in valid_items if idx == index) result_mapping[index] = (processed_image, url_str) - save_task = loop.run_in_executor( self.executor, self.save_image_to_file, processed_image ) save_tasks.append((index, save_task)) - # 记录保存开始时间 save_start_time = time.time() - # 并行执行所有保存任务 save_results = await asyncio.gather(*[task for _, task in save_tasks], return_exceptions=True) save_time += time.time() - save_start_time - # 按顺序返回结果 for (index, _), image_url in zip(save_tasks, save_results): if isinstance(image_url, Exception): error_count += 1 @@ -748,25 +660,202 @@ class RmbgService: "batch_elapsed": round(time.time() - batch_start_time, 2) } yield result - - except Exception as e: - logger.error(f"批处理失败: {str(e)}") - for _, _, index, url_str in batch_images: - completed_order += 1 + else: + # 分批处理 + for batch_start in range(0, len(valid_items), batch_size): + batch_end = min(batch_start + batch_size, len(valid_items)) + batch_items = valid_items[batch_start:batch_end] + + images_with_info = [(img, size, idx) for img, size, idx, _ in batch_items] + + gpu_start_time = time.time() + batch_results = await self.process_batch_images(images_with_info) + gpu_inference_time += time.time() - gpu_start_time + batch_count += 1 + batch_sizes.append(len(images_with_info)) + + # 并行保存 + save_tasks = [] + result_mapping = {} + + for processed_image, index in batch_results: + url_str = next(url for _, _, idx, url in batch_items if idx == index) + result_mapping[index] = (processed_image, url_str) + save_task = loop.run_in_executor( + self.executor, 
self.save_image_to_file, processed_image + ) + save_tasks.append((index, save_task)) + + save_start_time = time.time() + save_results = await asyncio.gather(*[task for _, task in save_tasks], return_exceptions=True) + save_time += time.time() - save_start_time + + for (index, _), image_url in zip(save_tasks, save_results): + if isinstance(image_url, Exception): + error_count += 1 + completed_order += 1 + result = { + "index": index, + "total": total, + "original_url": result_mapping[index][1], + "status": "error", + "error": str(image_url), + "message": f"保存图片失败: {str(image_url)}", + "success_count": success_count, + "error_count": error_count, + "completed_order": completed_order, + "batch_elapsed": round(time.time() - batch_start_time, 2) + } + yield result + else: + completed_order += 1 + success_count += 1 + result = { + "index": index, + "total": total, + "original_url": result_mapping[index][1], + "status": "success", + "image_url": image_url, + "message": "处理成功", + "success_count": success_count, + "error_count": error_count, + "completed_order": completed_order, + "batch_elapsed": round(time.time() - batch_start_time, 2) + } + yield result + + except RuntimeError as e: + # CUDA OOM错误,降级处理 + error_msg = str(e) + if "CUDA out of memory" in error_msg or "out of memory" in error_msg.lower(): + logger.warning(f"批处理显存不足,降级处理: {error_msg[:100]}") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + # 降级到单张处理 + for image, image_size, index, url_str in valid_items: + try: + processed_image = await self.process_image(image) + image_url = await loop.run_in_executor( + self.executor, self.save_image_to_file, processed_image + ) + completed_order += 1 + success_count += 1 + result = { + "index": index, + "total": total, + "original_url": url_str, + "status": "success", + "image_url": image_url, + "message": "处理成功(降级模式)", + "success_count": success_count, + "error_count": error_count, + "completed_order": completed_order, + "batch_elapsed": 
round(time.time() - batch_start_time, 2) + } + yield result + except Exception as e2: + error_count += 1 + completed_order += 1 + result = { + "index": index, + "total": total, + "original_url": url_str, + "status": "error", + "error": str(e2), + "message": f"处理失败: {str(e2)}", + "success_count": success_count, + "error_count": error_count, + "completed_order": completed_order, + "batch_elapsed": round(time.time() - batch_start_time, 2) + } + yield result + else: + # 其他错误 + logger.error(f"批处理失败: {error_msg}") + for _, _, index, url_str in valid_items: error_count += 1 + completed_order += 1 result = { "index": index, "total": total, "original_url": url_str, "status": "error", - "error": str(e), - "message": f"批处理失败: {str(e)}", + "error": error_msg, + "message": f"批处理失败: {error_msg}", "success_count": success_count, "error_count": error_count, "completed_order": completed_order, "batch_elapsed": round(time.time() - batch_start_time, 2) } yield result + except Exception as e: + logger.error(f"批处理失败: {str(e)}") + for _, _, index, url_str in valid_items: + error_count += 1 + completed_order += 1 + result = { + "index": index, + "total": total, + "original_url": url_str, + "status": "error", + "error": str(e), + "message": f"批处理失败: {str(e)}", + "success_count": success_count, + "error_count": error_count, + "completed_order": completed_order, + "batch_elapsed": round(time.time() - batch_start_time, 2) + } + yield result + + pending_batch = [] + + # 流水线处理:收集队列中的图片,达到batch_size或超时后立即处理 + while True: + try: + # 等待队列中有新图片,或超时 + try: + item = await asyncio.wait_for( + download_queue.get(), + timeout=batch_collect_timeout + ) + pending_batch.append(item) + except asyncio.TimeoutError: + # 超时,处理当前批次 + if pending_batch: + async for result in process_pending_batch(): + yield result + # 检查是否所有下载都完成 + if download_complete.is_set() and download_queue.empty(): + break + continue + + # 如果达到batch_size,立即处理 + if len(pending_batch) >= batch_size: + async for result in process_pending_batch(): + yield result 
+ + # 检查是否所有下载都完成 + if download_complete.is_set() and download_queue.empty(): + # 处理剩余的图片 + if pending_batch: + async for result in process_pending_batch(force=True): + yield result + break + + except Exception as e: + logger.error(f"流水线处理出错: {str(e)}", exc_info=True) + break + + # 等待所有下载任务完成 + await asyncio.gather(*download_tasks, return_exceptions=True) + download_time = time.time() - download_start_time + + # 确保所有结果都已处理 + if pending_batch: + async for result in process_pending_batch(force=True): + yield result # 输出性能统计信息 print_stats()