diff --git a/apps/rmbg/service.py b/apps/rmbg/service.py index 15e9abb..387ee04 100644 --- a/apps/rmbg/service.py +++ b/apps/rmbg/service.py @@ -16,6 +16,7 @@ import logging from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from typing import Optional, Dict, Any +from threading import Lock from settings import settings logging.basicConfig( @@ -43,8 +44,13 @@ class RmbgService: def __init__(self, model_path=None): """初始化背景移除服务""" self.model_path = model_path or settings.model_path + # 单机多 GPU:维护模型和设备列表,兼容旧字段 + self.models = [] + self.devices = [] self.model = None self.device = None + self._gpu_lock = Lock() + self._next_gpu_index = 0 self.save_dir = settings.save_dir self.download_url = settings.download_url os.makedirs(self.save_dir, exist_ok=True) @@ -67,53 +73,86 @@ class RmbgService: # 队列任务将在 FastAPI startup 事件中启动 def _load_model(self): - """加载模型""" - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - # 优化显存占用:使用半精度加载(如果支持) - # 注意:某些模型可能不支持半精度,需要测试 - try: - # 尝试使用半精度加载,可以减少约50%的显存占用 - self.model = AutoModelForImageSegmentation.from_pretrained( - self.model_path, - trust_remote_code=True, - torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32 - ) - self.model = self.model.to(self.device) - if torch.cuda.is_available(): - self.model = self.model.half() # 转换为半精度 - except Exception as e: - # 如果半精度加载失败,降级到全精度 - logger.warning(f"半精度加载失败,使用全精度: {str(e)}") - self.model = AutoModelForImageSegmentation.from_pretrained( - self.model_path, - trust_remote_code=True - ) - self.model = self.model.to(self.device) - - self.model.eval() - - # 优化显存分配策略:减少碎片化 + """加载模型,支持多 GPU""" + # 优化显存分配策略:减少碎片化(需要在加载前设置) if torch.cuda.is_available(): - # 设置显存分配器,减少碎片化 os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True') + + num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0 + use_half = torch.cuda.is_available() + + def _load_single_model(device: torch.device): + """在指定 device 上加载一个模型实例""" + try: + model = AutoModelForImageSegmentation.from_pretrained( + self.model_path, + trust_remote_code=True, + torch_dtype=torch.float16 if use_half else torch.float32, + ) + model = model.to(device) + if use_half: + model = model.half() + except Exception as e: + # 如果半精度加载失败,降级到全精度 + logger.warning(f"设备 {device} 半精度加载失败,使用全精度: {str(e)}") + model = AutoModelForImageSegmentation.from_pretrained( + self.model_path, + trust_remote_code=True, + ) + model = model.to(device) + model.eval() + return model + + if num_gpus > 0: + # 为每张 GPU 加载一份模型,简单轮询调度 + for idx in range(num_gpus): + device = torch.device(f"cuda:{idx}") + model = _load_single_model(device) + self.devices.append(device) + self.models.append(model) + logger.info(f"检测到 {num_gpus} 张 GPU,已为每张 GPU 加载模型实例") + else: + # 仅 CPU + device = torch.device("cpu") + model = _load_single_model(device) + self.devices.append(device) + self.models.append(model) + logger.info("未检测到 GPU,使用 CPU 设备") + + # 兼容旧字段:默认指向第一个设备和模型 + self.device = self.devices[0] + self.model = self.models[0] + + if torch.cuda.is_available(): torch.cuda.empty_cache() + def _get_model_and_device(self): + """为一次推理选择一个模型和设备(轮询)""" + if not self.models or not self.devices: + raise RuntimeError("模型尚未加载") + if len(self.models) == 1: + return self.models[0], self.devices[0] + with self._gpu_lock: + idx = self._next_gpu_index + self._next_gpu_index = (self._next_gpu_index + 1) % len(self.models) + return self.models[idx], self.devices[idx] + def _process_image_sync(self, image): """同步处理图像,移除背景(单张)""" + model, device = self._get_model_and_device() image_size = image.size transform_image = transforms.Compose([ transforms.Resize((1024, 1024)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ]) - input_images = transform_image(image).unsqueeze(0).to(self.device) + input_images = transform_image(image).unsqueeze(0).to(device) # 如果模型是半精度,输入也转换为半精度 - if next(self.model.parameters()).dtype == torch.float16: + if next(model.parameters()).dtype == torch.float16: input_images = input_images.half() with torch.no_grad(): - preds = self.model(input_images)[-1].sigmoid().cpu() + preds = model(input_images)[-1].sigmoid().cpu() pred = preds[0].squeeze() pred_pil = transforms.ToPILImage()(pred) @@ -131,7 +170,35 @@ class RmbgService: """批量处理图像(批处理模式,充分利用GPU并行能力)""" if not images_with_info: return [] - + + # 单设备退化为原来的逻辑,多设备时按设备拆分子批次并行执行 + if len(self.models) == 1: + return self._process_batch_on_device(self.models[0], self.devices[0], images_with_info) + + # 简单均匀拆分到各个 GPU,上游调用会按 index 重新排序 + tasks = [] + for i, (model, device) in enumerate(zip(self.models, self.devices)): + sub_items = images_with_info[i::len(self.models)] + if not sub_items: + continue + tasks.append( + self.executor.submit(self._process_batch_on_device, model, device, sub_items) + ) + + all_results = [] + for fut in tasks: + try: + sub_res = fut.result() + all_results.extend(sub_res) + except Exception as e: + logger.error(f"多 GPU 子批次处理失败: {e}", exc_info=True) + + # 保证结果顺序与原始 index 一致 + all_results.sort(key=lambda x: x[1]) + return all_results + + def _process_batch_on_device(self, model, device, images_with_info): + """在指定 device 上批量处理图像""" transform_image = transforms.Compose([ transforms.Resize((1024, 1024)), transforms.ToTensor(), @@ -142,15 +209,15 @@ class RmbgService: for image, image_size, index in images_with_info: batch_tensors.append(transform_image(image)) - input_batch = torch.stack(batch_tensors).to(self.device) + input_batch = torch.stack(batch_tensors).to(device) # 如果模型是半精度,输入也转换为半精度 - if next(self.model.parameters()).dtype == torch.float16: + if next(model.parameters()).dtype == torch.float16: input_batch = input_batch.half() # 释放 batch_tensors 占用的 CPU 内存 del batch_tensors with torch.no_grad(): - model_output = self.model(input_batch) + model_output = model(input_batch) if isinstance(model_output, (list, tuple)): preds = model_output[-1].sigmoid().cpu() else: