优化rmbg2.py
BIN
output/3T8A4426_nobg.png
Normal file
|
After Width: | Height: | Size: 1005 KiB |
BIN
output/tmp0kh35900_nobg.png
Normal file
|
After Width: | Height: | Size: 1005 KiB |
BIN
output/tmp119tqook_nobg.png
Normal file
|
After Width: | Height: | Size: 1005 KiB |
BIN
output/tmp3v4i6y57_nobg.png
Normal file
|
After Width: | Height: | Size: 1005 KiB |
BIN
output/tmp4d66uuvp_nobg.png
Normal file
|
After Width: | Height: | Size: 1005 KiB |
BIN
output/tmp4kv2mhwg_nobg.png
Normal file
|
After Width: | Height: | Size: 1005 KiB |
BIN
output/tmp7e7ckapa_nobg.png
Normal file
|
After Width: | Height: | Size: 1005 KiB |
BIN
output/tmp972za3v4_nobg.png
Normal file
|
After Width: | Height: | Size: 1005 KiB |
BIN
output/tmp_v9_rdjz_nobg.png
Normal file
|
After Width: | Height: | Size: 1005 KiB |
BIN
output/tmpb2icryq6_nobg.png
Normal file
|
After Width: | Height: | Size: 1005 KiB |
BIN
output/tmpb4585ssd_nobg.png
Normal file
|
After Width: | Height: | Size: 1005 KiB |
BIN
output/tmpdyscjyi5_nobg.png
Normal file
|
After Width: | Height: | Size: 1005 KiB |
BIN
output/tmpfc5ytkx0_nobg.png
Normal file
|
After Width: | Height: | Size: 1005 KiB |
BIN
output/tmpisrqzyuh_nobg.png
Normal file
|
After Width: | Height: | Size: 1005 KiB |
BIN
output/tmpktvlg1nn_nobg.png
Normal file
|
After Width: | Height: | Size: 1005 KiB |
BIN
output/tmppdj3mnpp_nobg.png
Normal file
|
After Width: | Height: | Size: 1005 KiB |
BIN
output/tmpsrxe2mao_nobg.png
Normal file
|
After Width: | Height: | Size: 1005 KiB |
BIN
output/tmptv1q37s9_nobg.png
Normal file
|
After Width: | Height: | Size: 1005 KiB |
BIN
output/tmpxrzisy8i_nobg.png
Normal file
|
After Width: | Height: | Size: 1005 KiB |
BIN
output/tmpycn_udag_nobg.png
Normal file
|
After Width: | Height: | Size: 1005 KiB |
BIN
output/tmpym2v1swb_nobg.png
Normal file
|
After Width: | Height: | Size: 1005 KiB |
BIN
output/tmpz2kzshzf_nobg.png
Normal file
|
After Width: | Height: | Size: 1005 KiB |
@ -16,11 +16,8 @@ import gc
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
warnings.filterwarnings("ignore", category=FutureWarning)
|
||||
|
||||
# 设置torch精度和优化
|
||||
# 设置torch精度
|
||||
torch.set_float32_matmul_precision("high")
|
||||
torch.backends.cudnn.benchmark = True # 启用cuDNN自动调优
|
||||
torch.backends.cuda.matmul.allow_tf32 = True # 允许使用TF32
|
||||
torch.backends.cudnn.allow_tf32 = True # 允许cuDNN使用TF32
|
||||
|
||||
# 示例图片URL
|
||||
image_url = "http://test001.jingrow.com/files/3T8A4426.JPG"
|
||||
@ -49,26 +46,31 @@ def process_image(image, model, device):
|
||||
"""处理图像,移除背景"""
|
||||
image_size = image.size
|
||||
# 转换图像
|
||||
t0 = time.time()
|
||||
transform_image = transforms.Compose([
|
||||
transforms.Resize((1024, 1024)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||||
])
|
||||
input_images = transform_image(image).unsqueeze(0).to(device)
|
||||
print(f"图像预处理耗时: {time.time()-t0:.3f}秒")
|
||||
|
||||
# 使用半精度加速
|
||||
t0 = time.time()
|
||||
with torch.cuda.amp.autocast():
|
||||
input_images = transform_image(image).unsqueeze(0).to(device)
|
||||
# 预测
|
||||
with torch.no_grad():
|
||||
preds = model(input_images)[-1].sigmoid().cpu()
|
||||
print(f"模型推理耗时: {time.time()-t0:.3f}秒")
|
||||
|
||||
# 处理预测结果
|
||||
t0 = time.time()
|
||||
pred = preds[0].squeeze()
|
||||
pred_pil = transforms.ToPILImage()(pred)
|
||||
mask = pred_pil.resize(image_size)
|
||||
|
||||
# 添加透明通道
|
||||
image.putalpha(mask)
|
||||
print(f"后处理耗时: {time.time()-t0:.3f}秒")
|
||||
return image
|
||||
|
||||
def remove_background(image_path, output_dir=None, model=None, device=None):
|
||||
|
||||