Optimize GPU inference performance: reuse the transform and cut unnecessary GPU memory cleanup

- Reuse the ToPILImage() transform instead of re-creating the object on every loop iteration (pattern sketched below)
- Remove the unnecessary GPU memory cleanup before batch processing (the batch path reuses that memory)
- Remove gc.collect() after batch processing to cut blocking overhead
- Keep gc.collect() in the single-image path so memory is released promptly

Expected performance gain: 3-8%
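
A minimal sketch of the batch post-processing pattern these changes converge on. Variable names (`preds`, `images_with_info`, `image_size`) are taken from the diff below; the function name `postprocess_batch` and the `(index, image)` return pairing are illustrative assumptions, not part of RmbgService:

```python
import torch
from torchvision import transforms

# Create the transform once, outside the loop, instead of per iteration.
to_pil = transforms.ToPILImage()

def postprocess_batch(preds, images_with_info):
    """Convert each predicted mask tensor to a PIL alpha mask and apply it."""
    results = []
    for i, (image, image_size, index) in enumerate(images_with_info):
        pred = preds[i].squeeze()                 # HxW mask tensor for this image
        mask = to_pil(pred).resize(image_size)    # reuse the single ToPILImage instance
        out = image.copy()
        out.putalpha(mask)
        results.append((index, out))              # return pairing assumed for illustration
    del preds
    # One empty_cache() after the whole batch is enough; no gc.collect() here,
    # since a forced full collection would block the serving thread.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return results
```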
jingrow 2025-11-23 16:22:31 +08:00
parent 9e259e7344
commit 0e1a99d975

@@ -122,6 +122,7 @@ class RmbgService:
mask = pred_pil.resize(image_size)
image.putalpha(mask)
# The single-image path keeps gc.collect() to release memory promptly
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
@@ -139,10 +140,6 @@ class RmbgService:
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
# Clean up GPU memory before batch processing
if torch.cuda.is_available():
torch.cuda.empty_cache()
batch_tensors = []
for image, image_size, index in images_with_info:
batch_tensors.append(transform_image(image))
@@ -165,9 +162,9 @@ class RmbgService:
del input_batch
if isinstance(model_output, (list, tuple)):
del model_output
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Reuse the ToPILImage transform to avoid re-creating the object
to_pil = transforms.ToPILImage()
results = []
for i, (image, image_size, index) in enumerate(images_with_info):
if len(preds.shape) == 4:
@@ -177,7 +174,7 @@ class RmbgService:
else:
pred = preds[i].squeeze()
pred_pil = transforms.ToPILImage()(pred)
pred_pil = to_pil(pred)
mask = pred_pil.resize(image_size)
result_image = image.copy()
result_image.putalpha(mask)
@@ -186,10 +183,9 @@ class RmbgService:
# Release preds
del preds
# Clean up GPU memory again after batch processing
# Clean up GPU memory after batch processing (gc.collect() removed to reduce blocking)
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
return results
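
For context on the removed cleanup calls: torch.cuda.empty_cache() only returns cached blocks from PyTorch's caching allocator to the driver, and gc.collect() forces a full Python garbage collection; neither frees memory that is still referenced. The following standalone measurement sketch (not part of the service) shows why an empty_cache() before the batch adds little value: tensors freed after the previous request stay reserved in the caching allocator and are reused by the next batch.

```python
import gc
import torch

def report(tag: str) -> None:
    """Print allocator stats: allocated = live tensors, reserved = cached pool."""
    alloc = torch.cuda.memory_allocated() / 2**20
    reserved = torch.cuda.memory_reserved() / 2**20
    print(f"{tag}: allocated={alloc:.1f} MiB, reserved={reserved:.1f} MiB")

if torch.cuda.is_available():
    x = torch.randn(1024, 1024, 256, device="cuda")   # ~1 GiB of activations
    report("after allocation")
    del x                                              # memory returns to the cache
    report("after del (still reserved, reusable by the next batch)")
    torch.cuda.empty_cache()                           # hand cached blocks back to the driver
    gc.collect()                                       # only useful if Python references linger
    report("after empty_cache + gc.collect")
```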