Add multi-GPU concurrency support
parent c5df876f4a
commit cc0198c9b0
```diff
@@ -16,6 +16,7 @@ import logging
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
 from typing import Optional, Dict, Any
+from threading import Lock
 from settings import settings
 
 logging.basicConfig(
```
```diff
@@ -43,8 +44,13 @@ class RmbgService:
     def __init__(self, model_path=None):
         """Initialize the background-removal service"""
         self.model_path = model_path or settings.model_path
+        # Single-node multi-GPU: keep model and device lists; legacy fields stay for compatibility
+        self.models = []
+        self.devices = []
         self.model = None
         self.device = None
+        self._gpu_lock = Lock()
+        self._next_gpu_index = 0
         self.save_dir = settings.save_dir
         self.download_url = settings.download_url
         os.makedirs(self.save_dir, exist_ok=True)
```
```diff
@@ -67,53 +73,86 @@ class RmbgService:
         # Queue tasks are started from the FastAPI startup event
 
     def _load_model(self):
-        """Load the model"""
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        # Reduce GPU memory use by loading in half precision (if supported)
-        # Note: some models may not support half precision; this needs testing
+        """Load the model, with multi-GPU support"""
+        # Tune the CUDA allocator to reduce fragmentation (must be set before loading)
+        if torch.cuda.is_available():
+            os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')
+
+        num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
+        use_half = torch.cuda.is_available()
 
+        def _load_single_model(device: torch.device):
+            """Load one model instance on the given device"""
             try:
-            # Try half-precision loading; it cuts GPU memory use by roughly 50%
-            self.model = AutoModelForImageSegmentation.from_pretrained(
+                model = AutoModelForImageSegmentation.from_pretrained(
                     self.model_path,
                     trust_remote_code=True,
-                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+                    torch_dtype=torch.float16 if use_half else torch.float32,
                 )
-            self.model = self.model.to(self.device)
-            if torch.cuda.is_available():
-                self.model = self.model.half()  # convert to half precision
+                model = model.to(device)
+                if use_half:
+                    model = model.half()
             except Exception as e:
                 # If half-precision loading fails, fall back to full precision
-            logger.warning(f"Half-precision load failed, using full precision: {str(e)}")
-            self.model = AutoModelForImageSegmentation.from_pretrained(
+                logger.warning(f"Half-precision load failed on device {device}, using full precision: {str(e)}")
+                model = AutoModelForImageSegmentation.from_pretrained(
                     self.model_path,
-                trust_remote_code=True
+                    trust_remote_code=True,
                 )
-            self.model = self.model.to(self.device)
+                model = model.to(device)
+            model.eval()
+            return model
 
-        self.model.eval()
-        # Optimize the GPU memory allocation strategy: reduce fragmentation
+        if num_gpus > 0:
+            # Load one model replica per GPU and dispatch with simple round-robin
+            for idx in range(num_gpus):
+                device = torch.device(f"cuda:{idx}")
+                model = _load_single_model(device)
+                self.devices.append(device)
+                self.models.append(model)
+            logger.info(f"Detected {num_gpus} GPU(s); loaded one model instance per GPU")
+        else:
+            # CPU only
+            device = torch.device("cpu")
+            model = _load_single_model(device)
+            self.devices.append(device)
+            self.models.append(model)
+            logger.info("No GPU detected; using the CPU")
+
+        # Legacy fields: default to the first device and model
+        self.device = self.devices[0]
+        self.model = self.models[0]
+
         if torch.cuda.is_available():
-            # Configure the memory allocator to reduce fragmentation
-            os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')
             torch.cuda.empty_cache()
 
+    def _get_model_and_device(self):
+        """Pick a model and device for one inference call (round-robin)"""
+        if not self.models or not self.devices:
+            raise RuntimeError("Model not loaded yet")
+        if len(self.models) == 1:
+            return self.models[0], self.devices[0]
+        with self._gpu_lock:
+            idx = self._next_gpu_index
+            self._next_gpu_index = (self._next_gpu_index + 1) % len(self.models)
+            return self.models[idx], self.devices[idx]
 
     def _process_image_sync(self, image):
         """Synchronously process a single image and remove its background"""
+        model, device = self._get_model_and_device()
         image_size = image.size
         transform_image = transforms.Compose([
             transforms.Resize((1024, 1024)),
             transforms.ToTensor(),
             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
         ])
-        input_images = transform_image(image).unsqueeze(0).to(self.device)
+        input_images = transform_image(image).unsqueeze(0).to(device)
         # If the model is half precision, convert the input as well
-        if next(self.model.parameters()).dtype == torch.float16:
+        if next(model.parameters()).dtype == torch.float16:
             input_images = input_images.half()
 
         with torch.no_grad():
-            preds = self.model(input_images)[-1].sigmoid().cpu()
+            preds = model(input_images)[-1].sigmoid().cpu()
 
         pred = preds[0].squeeze()
         pred_pil = transforms.ToPILImage()(pred)
```
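The scheduling core of the new `_get_model_and_device` is nothing more than a lock-protected counter. Below is a minimal, self-contained sketch of the same round-robin pattern; `RoundRobinPool` and its method names are illustrative only and are not part of this repo:

```python
# Minimal sketch of the round-robin dispatch pattern used by
# _get_model_and_device. RoundRobinPool is a hypothetical name.
from threading import Lock


class RoundRobinPool:
    """Hand out pool slots (e.g. per-GPU model replicas) in rotation, thread-safely."""

    def __init__(self, items):
        if not items:
            raise ValueError("pool must not be empty")
        self._items = list(items)
        self._lock = Lock()
        self._next = 0

    def pick(self):
        # The lock makes the read-increment-wrap step atomic, so concurrent
        # callers never race on the counter and load spreads evenly.
        with self._lock:
            idx = self._next
            self._next = (self._next + 1) % len(self._items)
        return self._items[idx]


pool = RoundRobinPool(["cuda:0", "cuda:1"])
assert [pool.pick() for _ in range(4)] == ["cuda:0", "cuda:1", "cuda:0", "cuda:1"]
```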
```diff
@@ -132,6 +171,34 @@ class RmbgService:
         if not images_with_info:
             return []
 
+        # A single device keeps the original logic; with several devices, split into per-device sub-batches run in parallel
+        if len(self.models) == 1:
+            return self._process_batch_on_device(self.models[0], self.devices[0], images_with_info)
+
+        # Split evenly across the GPUs; the upstream caller re-orders results by index
+        tasks = []
+        for i, (model, device) in enumerate(zip(self.models, self.devices)):
+            sub_items = images_with_info[i::len(self.models)]
+            if not sub_items:
+                continue
+            tasks.append(
+                self.executor.submit(self._process_batch_on_device, model, device, sub_items)
+            )
+
+        all_results = []
+        for fut in tasks:
+            try:
+                sub_res = fut.result()
+                all_results.extend(sub_res)
+            except Exception as e:
+                logger.error(f"Multi-GPU sub-batch processing failed: {e}", exc_info=True)
+
+        # Keep the result order consistent with the original indices
+        all_results.sort(key=lambda x: x[1])
+        return all_results
+
+    def _process_batch_on_device(self, model, device, images_with_info):
+        """Process a batch of images on the given device"""
         transform_image = transforms.Compose([
             transforms.Resize((1024, 1024)),
             transforms.ToTensor(),
```
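The new batch path leans on two invariants: stride slicing with `images_with_info[i::len(self.models)]` partitions the work with no overlap and no loss, and the final sort restores the original order because every item carries its index. A quick standalone check of both, with made-up sample tuples shaped like the `(image, image_size, index)` items above:

```python
# Standalone check of the interleaved split and the index re-sort.
# The tuples mimic (image, image_size, index); values are made up.
items = [(f"img{i}", (64, 64), i) for i in range(7)]
n_devices = 3

# Stride slicing: device 0 gets indices 0,3,6; device 1 gets 1,4; device 2 gets 2,5.
sub_batches = [items[i::n_devices] for i in range(n_devices)]
assert sum(len(s) for s in sub_batches) == len(items)

# Per-device results come back out of order; sorting by the carried index
# restores the original order (the diff sorts its results by x[1]).
merged = [item for sub in sub_batches for item in sub]
merged.sort(key=lambda x: x[2])
assert [x[2] for x in merged] == list(range(7))
```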
```diff
@@ -142,15 +209,15 @@ class RmbgService:
         for image, image_size, index in images_with_info:
             batch_tensors.append(transform_image(image))
 
-        input_batch = torch.stack(batch_tensors).to(self.device)
+        input_batch = torch.stack(batch_tensors).to(device)
         # If the model is half precision, convert the input as well
-        if next(self.model.parameters()).dtype == torch.float16:
+        if next(model.parameters()).dtype == torch.float16:
             input_batch = input_batch.half()
         # Free the CPU memory held by batch_tensors
         del batch_tensors
 
         with torch.no_grad():
-            model_output = self.model(input_batch)
+            model_output = model(input_batch)
             if isinstance(model_output, (list, tuple)):
                 preds = model_output[-1].sigmoid().cpu()
             else:
```
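Both inference paths repeat the same dtype guard before the forward pass, since half-precision weights reject float32 inputs at the first CUDA kernel. As a sketch, the guard could be factored into a helper; `match_input_dtype` is a hypothetical name, not something the diff introduces:

```python
# Sketch of the dtype guard used before both forward passes above:
# inputs must match the model's parameter dtype, or half-precision
# models raise a dtype-mismatch error. match_input_dtype is hypothetical.
import torch


def match_input_dtype(model: torch.nn.Module, batch: torch.Tensor) -> torch.Tensor:
    param_dtype = next(model.parameters()).dtype
    return batch.half() if param_dtype == torch.float16 else batch
```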