增加多GPU并发支持

This commit is contained in:
jingrow 2025-12-15 15:59:42 +00:00
parent c5df876f4a
commit cc0198c9b0

View File

@ -16,6 +16,7 @@ import logging
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Dict, Any from typing import Optional, Dict, Any
from threading import Lock
from settings import settings from settings import settings
logging.basicConfig( logging.basicConfig(
@ -43,8 +44,13 @@ class RmbgService:
def __init__(self, model_path=None): def __init__(self, model_path=None):
"""初始化背景移除服务""" """初始化背景移除服务"""
self.model_path = model_path or settings.model_path self.model_path = model_path or settings.model_path
# 单机多 GPU维护模型和设备列表兼容旧字段
self.models = []
self.devices = []
self.model = None self.model = None
self.device = None self.device = None
self._gpu_lock = Lock()
self._next_gpu_index = 0
self.save_dir = settings.save_dir self.save_dir = settings.save_dir
self.download_url = settings.download_url self.download_url = settings.download_url
os.makedirs(self.save_dir, exist_ok=True) os.makedirs(self.save_dir, exist_ok=True)
@ -67,53 +73,86 @@ class RmbgService:
# 队列任务将在 FastAPI startup 事件中启动 # 队列任务将在 FastAPI startup 事件中启动
def _load_model(self): def _load_model(self):
"""加载模型""" """加载模型,支持多 GPU"""
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 优化显存分配策略:减少碎片化(需要在加载前设置)
# 优化显存占用:使用半精度加载(如果支持)
# 注意:某些模型可能不支持半精度,需要测试
try:
# 尝试使用半精度加载可以减少约50%的显存占用
self.model = AutoModelForImageSegmentation.from_pretrained(
self.model_path,
trust_remote_code=True,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)
self.model = self.model.to(self.device)
if torch.cuda.is_available():
self.model = self.model.half() # 转换为半精度
except Exception as e:
# 如果半精度加载失败,降级到全精度
logger.warning(f"半精度加载失败,使用全精度: {str(e)}")
self.model = AutoModelForImageSegmentation.from_pretrained(
self.model_path,
trust_remote_code=True
)
self.model = self.model.to(self.device)
self.model.eval()
# 优化显存分配策略:减少碎片化
if torch.cuda.is_available(): if torch.cuda.is_available():
# 设置显存分配器,减少碎片化
os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True') os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')
num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
use_half = torch.cuda.is_available()
def _load_single_model(device: torch.device):
"""在指定 device 上加载一个模型实例"""
try:
model = AutoModelForImageSegmentation.from_pretrained(
self.model_path,
trust_remote_code=True,
torch_dtype=torch.float16 if use_half else torch.float32,
)
model = model.to(device)
if use_half:
model = model.half()
except Exception as e:
# 如果半精度加载失败,降级到全精度
logger.warning(f"设备 {device} 半精度加载失败,使用全精度: {str(e)}")
model = AutoModelForImageSegmentation.from_pretrained(
self.model_path,
trust_remote_code=True,
)
model = model.to(device)
model.eval()
return model
if num_gpus > 0:
# 为每张 GPU 加载一份模型,简单轮询调度
for idx in range(num_gpus):
device = torch.device(f"cuda:{idx}")
model = _load_single_model(device)
self.devices.append(device)
self.models.append(model)
logger.info(f"检测到 {num_gpus} 张 GPU已为每张 GPU 加载模型实例")
else:
# 仅 CPU
device = torch.device("cpu")
model = _load_single_model(device)
self.devices.append(device)
self.models.append(model)
logger.info("未检测到 GPU使用 CPU 设备")
# 兼容旧字段:默认指向第一个设备和模型
self.device = self.devices[0]
self.model = self.models[0]
if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
def _get_model_and_device(self):
"""为一次推理选择一个模型和设备(轮询)"""
if not self.models or not self.devices:
raise RuntimeError("模型尚未加载")
if len(self.models) == 1:
return self.models[0], self.devices[0]
with self._gpu_lock:
idx = self._next_gpu_index
self._next_gpu_index = (self._next_gpu_index + 1) % len(self.models)
return self.models[idx], self.devices[idx]
def _process_image_sync(self, image): def _process_image_sync(self, image):
"""同步处理图像,移除背景(单张)""" """同步处理图像,移除背景(单张)"""
model, device = self._get_model_and_device()
image_size = image.size image_size = image.size
transform_image = transforms.Compose([ transform_image = transforms.Compose([
transforms.Resize((1024, 1024)), transforms.Resize((1024, 1024)),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]) ])
input_images = transform_image(image).unsqueeze(0).to(self.device) input_images = transform_image(image).unsqueeze(0).to(device)
# 如果模型是半精度,输入也转换为半精度 # 如果模型是半精度,输入也转换为半精度
if next(self.model.parameters()).dtype == torch.float16: if next(model.parameters()).dtype == torch.float16:
input_images = input_images.half() input_images = input_images.half()
with torch.no_grad(): with torch.no_grad():
preds = self.model(input_images)[-1].sigmoid().cpu() preds = model(input_images)[-1].sigmoid().cpu()
pred = preds[0].squeeze() pred = preds[0].squeeze()
pred_pil = transforms.ToPILImage()(pred) pred_pil = transforms.ToPILImage()(pred)
@ -132,6 +171,34 @@ class RmbgService:
if not images_with_info: if not images_with_info:
return [] return []
# 单设备退化为原来的逻辑,多设备时按设备拆分子批次并行执行
if len(self.models) == 1:
return self._process_batch_on_device(self.models[0], self.devices[0], images_with_info)
# 简单均匀拆分到各个 GPU上游调用会按 index 重新排序
tasks = []
for i, (model, device) in enumerate(zip(self.models, self.devices)):
sub_items = images_with_info[i::len(self.models)]
if not sub_items:
continue
tasks.append(
self.executor.submit(self._process_batch_on_device, model, device, sub_items)
)
all_results = []
for fut in tasks:
try:
sub_res = fut.result()
all_results.extend(sub_res)
except Exception as e:
logger.error(f"多 GPU 子批次处理失败: {e}", exc_info=True)
# 保证结果顺序与原始 index 一致
all_results.sort(key=lambda x: x[1])
return all_results
def _process_batch_on_device(self, model, device, images_with_info):
"""在指定 device 上批量处理图像"""
transform_image = transforms.Compose([ transform_image = transforms.Compose([
transforms.Resize((1024, 1024)), transforms.Resize((1024, 1024)),
transforms.ToTensor(), transforms.ToTensor(),
@ -142,15 +209,15 @@ class RmbgService:
for image, image_size, index in images_with_info: for image, image_size, index in images_with_info:
batch_tensors.append(transform_image(image)) batch_tensors.append(transform_image(image))
input_batch = torch.stack(batch_tensors).to(self.device) input_batch = torch.stack(batch_tensors).to(device)
# 如果模型是半精度,输入也转换为半精度 # 如果模型是半精度,输入也转换为半精度
if next(self.model.parameters()).dtype == torch.float16: if next(model.parameters()).dtype == torch.float16:
input_batch = input_batch.half() input_batch = input_batch.half()
# 释放 batch_tensors 占用的 CPU 内存 # 释放 batch_tensors 占用的 CPU 内存
del batch_tensors del batch_tensors
with torch.no_grad(): with torch.no_grad():
model_output = self.model(input_batch) model_output = model(input_batch)
if isinstance(model_output, (list, tuple)): if isinstance(model_output, (list, tuple)):
preds = model_output[-1].sigmoid().cpu() preds = model_output[-1].sigmoid().cpu()
else: else: