import os import argparse import requests import tempfile from urllib.parse import urlparse from PIL import Image import torch from torchvision import transforms from transformers import AutoModelForImageSegmentation import time # 设置torch精度 torch.set_float32_matmul_precision("high") # 初始化模型 model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-2.0", trust_remote_code=True) model.to("cuda") model.eval() # 设置为评估模式 # 定义图像转换 transform_image = transforms.Compose([ transforms.Resize((1024, 1024)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ]) def is_valid_url(url): """验证URL是否有效""" try: result = urlparse(url) return all([result.scheme, result.netloc]) except: return False def download_image(url): """从URL下载图片到临时文件""" response = requests.get(url, stream=True) response.raise_for_status() # 创建临时文件 temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png') with open(temp_file.name, 'wb') as f: for chunk in response.iter_content(chunk_size=8192): f.write(chunk) return temp_file.name def process_image(image): """处理图像,移除背景""" image_size = image.size # 转换图像 input_images = transform_image(image).unsqueeze(0).to("cuda") # 使用半精度加速 with torch.cuda.amp.autocast(): # 预测 with torch.no_grad(): preds = model(input_images)[-1].sigmoid().cpu() # 处理预测结果 pred = preds[0].squeeze() pred_pil = transforms.ToPILImage()(pred) mask = pred_pil.resize(image_size) # 添加透明通道 image.putalpha(mask) return image def remove_background(image_path, output_dir=None): """ 移除图像背景 Args: image_path: 输入图像的路径或URL output_dir: 输出目录,如果不指定则使用当前目录 Returns: 处理后的图像路径 """ temp_file = None try: # 检查是否是URL if is_valid_url(image_path): try: # 下载图片到临时文件 temp_file = download_image(image_path) image_path = temp_file except Exception as e: raise Exception(f"下载图片失败: {e}") # 验证输入文件是否存在 if not os.path.exists(image_path): raise FileNotFoundError(f"输入图像不存在: {image_path}") # 如果输出目录未指定,则使用当前目录 if output_dir is None: output_dir = os.getcwd() # 确保输出目录存在 os.makedirs(output_dir, exist_ok=True) # 加载并处理图像 image = Image.open(image_path).convert("RGB") image_no_bg = process_image(image) # 保存结果 filename = os.path.basename(image_path) name, ext = os.path.splitext(filename) output_file_path = os.path.join(output_dir, f"{name}_nobg.png") image_no_bg.save(output_file_path) return output_file_path finally: # 清理临时文件 if temp_file and os.path.exists(temp_file): try: os.unlink(temp_file) except: pass def main(): parser = argparse.ArgumentParser(description='移除图像背景') parser.add_argument('--image', help='输入图像的路径或URL') parser.add_argument('--output-dir', default='output', help='输出目录') args = parser.parse_args() try: # 如果没有提供图片参数,使用示例URL if args.image is None: print("未提供图片参数,使用示例图片URL进行测试...") args.image = "http://test001.jingrow.com/files/3T8A4426.JPG" print(f"处理图片: {args.image}") t0 = time.time() output_path = remove_background(args.image, args.output_dir) print(f"处理完成,用时 {time.time()-t0:.1f} 秒") print(f"输出文件: {output_path}") except Exception as e: print(f"错误: {e}") finally: # 清理显存 if torch.cuda.is_available(): torch.cuda.empty_cache() if __name__ == "__main__": main()