优化GPU推理性能:复用转换器并减少不必要的显存清理

- 复用 ToPILImage() 转换器,避免循环中重复创建对象
- 移除批处理前不必要的显存清理(批处理会重用显存)
- 移除批处理后的 gc.collect(),减少阻塞开销
- 保留单张处理的 gc.collect(),确保及时释放内存

预期性能提升:3-8%
This commit is contained in:
jingrow 2025-11-23 16:22:31 +08:00
parent 9e259e7344
commit 0e1a99d975

View File

@ -122,6 +122,7 @@ class RmbgService:
mask = pred_pil.resize(image_size)
image.putalpha(mask)
# 单张处理保留 gc.collect(),确保及时释放内存
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
@ -139,10 +140,6 @@ class RmbgService:
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
# 批处理前清理显存
if torch.cuda.is_available():
torch.cuda.empty_cache()
batch_tensors = []
for image, image_size, index in images_with_info:
batch_tensors.append(transform_image(image))
@ -165,9 +162,9 @@ class RmbgService:
del input_batch
if isinstance(model_output, (list, tuple)):
del model_output
if torch.cuda.is_available():
torch.cuda.empty_cache()
# 复用 ToPILImage 转换器,避免重复创建对象
to_pil = transforms.ToPILImage()
results = []
for i, (image, image_size, index) in enumerate(images_with_info):
if len(preds.shape) == 4:
@ -177,7 +174,7 @@ class RmbgService:
else:
pred = preds[i].squeeze()
- pred_pil = transforms.ToPILImage()(pred)
+ pred_pil = to_pil(pred)
mask = pred_pil.resize(image_size)
result_image = image.copy()
result_image.putalpha(mask)
@ -186,10 +183,9 @@ class RmbgService:
# 释放 preds
del preds
- # 批处理后再次清理显存
+ # 批处理后清理显存(移除 gc.collect(),减少阻塞)
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
return results