diff --git a/apps/rmbg/service.py b/apps/rmbg/service.py
index afb4240..7be346d 100644
--- a/apps/rmbg/service.py
+++ b/apps/rmbg/service.py
@@ -69,9 +69,9 @@ class RmbgService:
         )
         self.executor = ThreadPoolExecutor(max_workers=settings.max_workers)
 
-        # 队列聚合机制(方案B)
+        # Queue aggregation (plan B: strict pipeline, one worker per GPU)
         self.queue: asyncio.Queue = asyncio.Queue()
-        self.queue_task: Optional[asyncio.Task] = None
+        self.queue_tasks: list[asyncio.Task] = []  # all worker tasks
         self.queue_running = False
 
         self._load_model()
@@ -81,7 +81,7 @@ class RmbgService:
         """加载模型,支持多 GPU"""
         # 优化显存分配策略:减少碎片化(需要在加载前设置)
         if torch.cuda.is_available():
-            os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')
+            os.environ.setdefault('PYTORCH_ALLOC_CONF', 'expandable_segments:True')
 
         num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
         use_half = torch.cuda.is_available()
@@ -92,7 +92,7 @@ class RmbgService:
                 model = AutoModelForImageSegmentation.from_pretrained(
                     self.model_path,
                     trust_remote_code=True,
-                    torch_dtype=torch.float16 if use_half else torch.float32,
+                    dtype=torch.float16 if use_half else torch.float32,
                 )
                 model = model.to(device)
                 if use_half:
@@ -146,8 +146,12 @@ class RmbgService:
         return self.models[idx], self.devices[idx]
 
     def _process_image_sync(self, image):
-        """同步处理图像,移除背景(单张)"""
+        """Synchronously remove the background of one image (legacy entry point, round-robin scheduling)."""
         model, device = self._get_model_and_device()
+        return self._process_single_image_on_device(model, device, image)
+
+    def _process_single_image_on_device(self, model, device, image):
+        """Process a single image on the given device (used by the worker fallback path)."""
         image_size = image.size
         transform_image = transforms.Compose([
             transforms.Resize((1024, 1024)),
@@ -309,36 +313,55 @@ class RmbgService:
         )
 
     async def _start_queue_processor(self):
-        """启动队列批处理后台任务(异步方法,需要在事件循环中调用)"""
+        """Start the queue workers: strict pipeline (plan B), one worker per GPU.
+
+        Raises:
+            RuntimeError: if no model is loaded, so no worker can be bound.
+        """
         if self.queue_running:
             return
 
         self.queue_running = True
-        self.queue_task = asyncio.create_task(self._queue_processor())
+
+        # Start one independent worker per loaded model (one per GPU).
+        num_workers = len(self.models) if self.models else 0
+        if num_workers == 0:
+            # Each worker reads self.models[worker_id]; with nothing loaded it would
+            # die with IndexError and nothing would ever drain the queue.
+            self.queue_running = False
+            raise RuntimeError("未加载任何模型,无法启动队列处理 worker")
+        logger.info(f"启动 {num_workers} 个队列处理 worker(每 GPU 一个)")
+
+        for worker_id in range(num_workers):
+            task = asyncio.create_task(self._queue_processor(worker_id))
+            self.queue_tasks.append(task)
 
-    async def _queue_processor(self):
-        """后台队列批处理任务(核心逻辑)"""
+    async def _queue_processor(self, worker_id: int):
+        """Background queue batch-processing loop (core logic); each worker is bound to one GPU."""
+        model = self.models[worker_id]
+        device = self.devices[worker_id]
+
+        logger.info(f"Worker {worker_id} 启动,绑定设备: {device}")
+
         while self.queue_running:
             try:
-                # 收集一批请求
+                # Collect one batch of requests (single-device batch_size)
                 batch_items = await self._collect_batch_items()
 
                 if not batch_items:
                     continue
 
-                # 处理这批请求
-                await self._process_batch_queue_items(batch_items)
+                # Process the batch using only this worker's model and device
+                await self._process_batch_queue_items(batch_items, model, device, worker_id)
 
             except Exception as e:
-                logger.error(f"队列批处理任务出错: {str(e)}", exc_info=True)
+                logger.error(f"Worker {worker_id} 队列批处理任务出错: {str(e)}", exc_info=True)
                 await asyncio.sleep(0.1)  # 出错后短暂等待
 
     async def _collect_batch_items(self):
-        """收集一批队列项,达到单卡batch_size×GPU数量或超时后返回(支持跨用户合批)"""
+        """Collect queue items until the single-device batch_size is reached or the timeout expires (per-device batches keep workers from contending)."""
         batch_items = []
-        per_device_batch = settings.batch_size # 单卡 batch_size
-        device_count = max(1, getattr(self, "num_devices", len(self.devices) or 1))
-        target_batch_size = per_device_batch * device_count # 本次期望的全局 batch 上限
+        target_batch_size = settings.batch_size  # single-device batch_size (no longer multiplied by the GPU count)
         collect_interval = settings.batch_collect_interval
         collect_timeout = settings.batch_collect_timeout
 
@@ -379,8 +402,8 @@ class RmbgService:
 
         return batch_items
 
-    async def _process_batch_queue_items(self, batch_items):
-        """处理一批队列项(统一全局 batcher,支持跨用户合批)"""
+    async def _process_batch_queue_items(self, batch_items, model, device, worker_id: int):
+        """Process one batch of queue items on a single device with the given model and device."""
         if not batch_items:
             return
 
@@ -394,8 +417,14 @@ class RmbgService:
                 images_with_info.append((item.image, item.image_size, idx))
                 item_index_map[idx] = item
 
-            # 执行批处理(直接调用,充分利用 GPU)
-            batch_results = await self.process_batch_images(images_with_info)
+            # Run the batch on this worker's model and device only (no multi-GPU split)
+            batch_results = await loop.run_in_executor(
+                self.executor,
+                self._process_batch_on_device,
+                model,
+                device,
+                images_with_info
+            )
 
             # 并行保存所有图片(关键优化:避免串行 IO 阻塞)
             save_tasks = []
@@ -445,27 +474,30 @@ class RmbgService:
             # CUDA OOM 错误,降级处理
             error_msg = str(e)
             if "CUDA out of memory" in error_msg or "out of memory" in error_msg.lower():
-                logger.warning(f"批处理显存不足,降级到单张处理: {error_msg[:100]}")
+                logger.warning(f"Worker {worker_id} 批处理显存不足,降级到单张处理: {error_msg[:100]}")
                 if torch.cuda.is_available():
                     torch.cuda.empty_cache()
                 gc.collect()
 
-                # 降级:单张处理
+                # Fallback: process one image at a time on this worker's model and device
                 for item in batch_items:
                     try:
-                        result_data = await self.process_image(item.image)
-                        if isinstance(result_data, dict):
-                            if not item.future.done():
-                                item.future.set_result(result_data)
-                        else:
-                            image_url = await loop.run_in_executor(
-                                self.executor, self.save_image_to_file, result_data
-                            )
-                            if not item.future.done():
-                                item.future.set_result({
-                                    "status": "success",
-                                    "image_url": image_url
-                                })
+                        # Single-image inference on this worker's model and device
+                        result_image = await loop.run_in_executor(
+                            self.executor,
+                            self._process_single_image_on_device,
+                            model,
+                            device,
+                            item.image
+                        )
+                        image_url = await loop.run_in_executor(
+                            self.executor, self.save_image_to_file, result_image
+                        )
+                        if not item.future.done():
+                            item.future.set_result({
+                                "status": "success",
+                                "image_url": image_url
+                            })
                     except Exception as e2:
                         if not item.future.done():
                             item.future.set_exception(Exception(f"降级处理失败: {str(e2)}"))
@@ -753,15 +785,17 @@ class RmbgService:
 
     async def cleanup(self):
         """清理资源"""
-        # 停止队列处理任务
+        # Stop every queue worker task
         if self.queue_running:
             self.queue_running = False
-            if self.queue_task:
-                self.queue_task.cancel()
-                try:
-                    await self.queue_task
-                except asyncio.CancelledError:
-                    pass
+            # Cancel all worker tasks
+            for task in self.queue_tasks:
+                if task:
+                    task.cancel()
+            # Wait until every task has finished cancelling
+            if self.queue_tasks:
+                await asyncio.gather(*self.queue_tasks, return_exceptions=True)
+                self.queue_tasks.clear()
 
         # 处理队列中剩余的请求
         remaining_items = []
diff --git a/apps/rmbg/settings.py b/apps/rmbg/settings.py
index 057fef6..6607497 100644
--- a/apps/rmbg/settings.py
+++ b/apps/rmbg/settings.py
@@ -35,7 +35,7 @@ class Settings(BaseSettings):
 
     # 并发控制配置(推理侧)
     max_workers: int = 60  # 线程池最大工作线程数(根据CPU核心数调整,22核44线程可设置20-30)
-    batch_size: int = 16  # GPU批处理大小(模型显存占用较大,8是安全值,16会导致OOM)
+    batch_size: int = 8  # GPU batch size (the model uses a lot of VRAM; 8 is the safe value, 16 causes OOM)
 
     # 队列聚合配置(方案B:批处理+队列模式)
     batch_collect_interval: float = 0.05  # 批处理收集间隔(秒),50ms收集一次,平衡延迟和吞吐量