From 0e1a99d975bd173c7720babc5d48a1e5c9568144 Mon Sep 17 00:00:00 2001 From: jingrow Date: Sun, 23 Nov 2025 16:22:31 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96GPU=E6=8E=A8=E7=90=86?= =?UTF-8?q?=E6=80=A7=E8=83=BD=EF=BC=9A=E5=A4=8D=E7=94=A8=E8=BD=AC=E6=8D=A2?= =?UTF-8?q?=E5=99=A8=E5=B9=B6=E5=87=8F=E5=B0=91=E4=B8=8D=E5=BF=85=E8=A6=81?= =?UTF-8?q?=E7=9A=84=E6=98=BE=E5=AD=98=E6=B8=85=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 复用 ToPILImage() 转换器,避免循环中重复创建对象 - 移除批处理前不必要的显存清理(批处理会重用显存) - 移除批处理后的 gc.collect(),减少阻塞开销 - 保留单张处理的 gc.collect(),确保及时释放内存 预期性能提升:3-8% --- apps/rmbg/service.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/apps/rmbg/service.py b/apps/rmbg/service.py index 9293950..7e96842 100644 --- a/apps/rmbg/service.py +++ b/apps/rmbg/service.py @@ -122,6 +122,7 @@ class RmbgService: mask = pred_pil.resize(image_size) image.putalpha(mask) + # 单张处理保留 gc.collect(),确保及时释放内存 if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() @@ -139,10 +140,6 @@ class RmbgService: transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ]) - # 批处理前清理显存 - if torch.cuda.is_available(): - torch.cuda.empty_cache() - batch_tensors = [] for image, image_size, index in images_with_info: batch_tensors.append(transform_image(image)) @@ -165,9 +162,9 @@ class RmbgService: del input_batch if isinstance(model_output, (list, tuple)): del model_output - if torch.cuda.is_available(): - torch.cuda.empty_cache() + # 复用 ToPILImage 转换器,避免重复创建对象 + to_pil = transforms.ToPILImage() results = [] for i, (image, image_size, index) in enumerate(images_with_info): if len(preds.shape) == 4: @@ -177,7 +174,7 @@ class RmbgService: else: pred = preds[i].squeeze() - pred_pil = transforms.ToPILImage()(pred) + pred_pil = to_pil(pred) mask = pred_pil.resize(image_size) result_image = image.copy() result_image.putalpha(mask) @@ -186,10 +183,9 @@ class RmbgService: # 释放 preds del preds - # 批处理后再次清理显存 + # 批处理后清理显存(移除 gc.collect(),减少阻塞) if torch.cuda.is_available(): torch.cuda.empty_cache() - gc.collect() return results