diff --git a/apps/rmbg/service.py b/apps/rmbg/service.py
index ef1431e..90bd33c 100644
--- a/apps/rmbg/service.py
+++ b/apps/rmbg/service.py
@@ -87,7 +87,7 @@ class RmbgService:
         model = AutoModelForImageSegmentation.from_pretrained(
             self.model_path,
             trust_remote_code=True,
-            torch_dtype=torch.float16 if use_half else torch.float32,
+            dtype=torch.float16 if use_half else torch.float32,
         )
         model = model.to(device)
         if use_half:
@@ -558,10 +558,9 @@ class RmbgService:
         completed_order = 0
         pending_batch = []
         batch_collect_timeout = 0.5  # batch-collection timeout (seconds)
-        max_single_batch = batch_size * 2  # allow up to 2x batch_size for one-shot processing

         async def process_pending_batch(force=False):
-            """Process the pending batch."""
+            """Process the pending batch (feed each downloaded image into the global queue for batched inference)."""
             nonlocal pending_batch, completed_order, success_count, error_count

             if not pending_batch:
@@ -600,142 +599,56 @@ class RmbgService:
                 pending_batch = []
                 return

-            # Process the successfully downloaded images
+            # Process the successfully downloaded images: batch them through the global queue (dispatched concurrently within each local batch)
             try:
-                # Decide whether to attempt one-shot processing
-                use_single_batch = len(valid_items) <= max_single_batch and force
-
-                if use_single_batch:
-                    # Try to process all images in a single batch
-                    images_with_info = [(img, size, idx) for img, size, idx, _ in valid_items]
-
-                    batch_results = await self.process_batch_images(images_with_info)
-
-                    # Save in parallel
-                    save_tasks = []
-                    result_mapping = {}
-
-                    for processed_image, index in batch_results:
-                        url_str = next(url for _, _, idx, url in valid_items if idx == index)
-                        result_mapping[index] = (processed_image, url_str)
-                        save_task = loop.run_in_executor(
-                            self.executor, self.save_image_to_file, processed_image
-                        )
-                        save_tasks.append((index, save_task))
-
-                    save_results = await asyncio.gather(*[task for _, task in save_tasks], return_exceptions=True)
-
-                    for (index, _), image_url in zip(save_tasks, save_results):
-                        if isinstance(image_url, Exception):
+                # Still chunk valid_items by batch_size to bound the duration of a single pass
+                for local_batch_start in range(0, len(valid_items), batch_size):
+                    local_batch_end = min(local_batch_start + batch_size, len(valid_items))
+                    local_batch_items = valid_items[local_batch_start:local_batch_end]
+
+                    # Call process_image concurrently within the local batch so the global queue can assemble larger batches and use multiple GPUs
+                    tasks = [
+                        self.process_image(image)
+                        for image, _, _, _ in local_batch_items
+                    ]
+                    batch_results = await asyncio.gather(*tasks, return_exceptions=True)
+
+                    for (image, _, index, url_str), result_data in zip(local_batch_items, batch_results):
+                        if isinstance(result_data, Exception):
                             error_count += 1
                             completed_order += 1
                             result = {
                                 "index": index,
                                 "total": total,
-                                "original_url": result_mapping[index][1],
+                                "original_url": url_str,
                                 "status": "error",
-                                "error": str(image_url),
-                                "message": f"保存图片失败: {str(image_url)}",
+                                "error": str(result_data),
+                                "message": f"处理失败: {str(result_data)}",
                                 "success_count": success_count,
                                 "error_count": error_count,
                                 "completed_order": completed_order,
                                 "batch_elapsed": round(time.time() - batch_start_time, 2)
                             }
                             yield result
+                            continue
+
+                        if isinstance(result_data, dict):
+                            # In queue mode, process_image usually returns {status, image_url} directly
+                            status = result_data.get("status", "success")
+                            image_url = result_data.get("image_url")
+                            message = result_data.get("message", "处理成功" if status == "success" else "处理失败")
+                            error_msg = result_data.get("error")
                         else:
-                            completed_order += 1
-                            success_count += 1
-                            result = {
-                                "index": index,
-                                "total": total,
-                                "original_url": result_mapping[index][1],
-                                "status": "success",
-                                "image_url": image_url,
-                                "message": "处理成功",
-                                "success_count": success_count,
-                                "error_count": error_count,
-                                "completed_order": completed_order,
-                                "batch_elapsed": round(time.time() - batch_start_time, 2)
-                            }
-                            yield result
-                else:
-                    # Process in batches
-                    for batch_start in range(0, len(valid_items), batch_size):
-                        batch_end = min(batch_start + batch_size, len(valid_items))
-                        batch_items = valid_items[batch_start:batch_end]
-
-                        images_with_info = [(img, size, idx) for img, size, idx, _ in batch_items]
-
-                        batch_results = await self.process_batch_images(images_with_info)
-
-                        # Save in parallel
-                        save_tasks = []
-                        result_mapping = {}
-
-                        for processed_image, index in batch_results:
-                            url_str = next(url for _, _, idx, url in batch_items if idx == index)
-                            result_mapping[index] = (processed_image, url_str)
-                            save_task = loop.run_in_executor(
-                                self.executor, self.save_image_to_file, processed_image
+                            # Compatibility with non-dict returns: save the image manually
+                            image_url = await loop.run_in_executor(
+                                self.executor, self.save_image_to_file, result_data
                             )
-                            save_tasks.append((index, save_task))
-
-                        save_results = await asyncio.gather(*[task for _, task in save_tasks], return_exceptions=True)
-
-                        for (index, _), image_url in zip(save_tasks, save_results):
-                            if isinstance(image_url, Exception):
-                                error_count += 1
-                                completed_order += 1
-                                result = {
-                                    "index": index,
-                                    "total": total,
-                                    "original_url": result_mapping[index][1],
-                                    "status": "error",
-                                    "error": str(image_url),
-                                    "message": f"保存图片失败: {str(image_url)}",
-                                    "success_count": success_count,
-                                    "error_count": error_count,
-                                    "completed_order": completed_order,
-                                    "batch_elapsed": round(time.time() - batch_start_time, 2)
-                                }
-                                yield result
-                            else:
-                                completed_order += 1
-                                success_count += 1
-                                result = {
-                                    "index": index,
-                                    "total": total,
-                                    "original_url": result_mapping[index][1],
-                                    "status": "success",
-                                    "image_url": image_url,
-                                    "message": "处理成功",
-                                    "success_count": success_count,
-                                    "error_count": error_count,
-                                    "completed_order": completed_order,
-                                    "batch_elapsed": round(time.time() - batch_start_time, 2)
-                                }
-                                yield result
-
-            except RuntimeError as e:
-                # CUDA OOM error: fall back to degraded processing
-                error_msg = str(e)
-                if "CUDA out of memory" in error_msg or "out of memory" in error_msg.lower():
-                    logger.warning(f"批处理显存不足,降级处理: {error_msg[:100]}")
-                    if torch.cuda.is_available():
-                        torch.cuda.empty_cache()
-                    gc.collect()
-
-                    # Fall back to single-image processing
-                    for image, image_size, index, url_str, _ in valid_items:
-                        try:
-                            result_data = await self.process_image(image)
-                            if isinstance(result_data, dict):
-                                image_url = result_data["image_url"]
-                            else:
-                                image_url = await loop.run_in_executor(
-                                    self.executor, self.save_image_to_file, result_data
-                                )
-                            completed_order += 1
+                            status = "success"
+                            message = "处理成功"
+                            error_msg = None
+
+                        completed_order += 1
+                        if status == "success" and image_url:
                             success_count += 1
                             result = {
                                 "index": index,
@@ -743,47 +656,26 @@ class RmbgService:
                                 "original_url": url_str,
                                 "status": "success",
                                 "image_url": image_url,
-                                "message": "处理成功(降级模式)",
+                                "message": message,
                                 "success_count": success_count,
                                 "error_count": error_count,
                                 "completed_order": completed_order,
                                 "batch_elapsed": round(time.time() - batch_start_time, 2)
                             }
-                            yield result
-                        except Exception as e2:
+                        else:
                             error_count += 1
-                            completed_order += 1
                             result = {
                                 "index": index,
                                 "total": total,
                                 "original_url": url_str,
                                 "status": "error",
-                                "error": str(e2),
-                                "message": f"处理失败: {str(e2)}",
+                                "error": error_msg or "处理失败",
+                                "message": message,
                                 "success_count": success_count,
                                 "error_count": error_count,
                                 "completed_order": completed_order,
                                 "batch_elapsed": round(time.time() - batch_start_time, 2)
                             }
-                            yield result
-                else:
-                    # Other errors
-                    logger.error(f"批处理失败: {error_msg}")
-                    for _, _, index, url_str, _ in valid_items:
-                        error_count += 1
-                        completed_order += 1
-                        result = {
-                            "index": index,
-                            "total": total,
-                            "original_url": url_str,
-                            "status": "error",
-                            "error": error_msg,
-                            "message": f"批处理失败: {error_msg}",
-                            "success_count": success_count,
-                            "error_count": error_count,
-                            "completed_order": completed_order,
-                            "batch_elapsed": round(time.time() - batch_start_time, 2)
-                        }
                         yield result
             except Exception as e:
                 logger.error(f"批处理失败: {str(e)}")