diff --git a/apps/rmbg/service.py b/apps/rmbg/service.py index 9293950..7e96842 100644 --- a/apps/rmbg/service.py +++ b/apps/rmbg/service.py @@ -122,6 +122,7 @@ class RmbgService: mask = pred_pil.resize(image_size) image.putalpha(mask) + # 单张处理保留 gc.collect(),确保及时释放内存 if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() @@ -139,10 +140,6 @@ class RmbgService: transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ]) - # 批处理前清理显存 - if torch.cuda.is_available(): - torch.cuda.empty_cache() - batch_tensors = [] for image, image_size, index in images_with_info: batch_tensors.append(transform_image(image)) @@ -165,9 +162,9 @@ class RmbgService: del input_batch if isinstance(model_output, (list, tuple)): del model_output - if torch.cuda.is_available(): - torch.cuda.empty_cache() + # 复用 ToPILImage 转换器,避免重复创建对象 + to_pil = transforms.ToPILImage() results = [] for i, (image, image_size, index) in enumerate(images_with_info): if len(preds.shape) == 4: @@ -177,7 +174,7 @@ class RmbgService: else: pred = preds[i].squeeze() - pred_pil = transforms.ToPILImage()(pred) + pred_pil = to_pil(pred) mask = pred_pil.resize(image_size) result_image = image.copy() result_image.putalpha(mask) @@ -186,10 +183,9 @@ class RmbgService: # 释放 preds del preds - # 批处理后再次清理显存 + # 批处理后清理显存(移除 gc.collect(),减少阻塞) if torch.cuda.is_available(): torch.cuda.empty_cache() - gc.collect() return results