146 lines
4.3 KiB
Python
146 lines
4.3 KiB
Python
import os
|
||
import argparse
|
||
import requests
|
||
import tempfile
|
||
from urllib.parse import urlparse
|
||
from PIL import Image
|
||
import torch
|
||
from torchvision import transforms
|
||
from transformers import AutoModelForImageSegmentation
|
||
import time
|
||
|
||
# 设置torch精度
|
||
torch.set_float32_matmul_precision("high")
|
||
|
||
# 初始化模型
|
||
model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-2.0", trust_remote_code=True)
|
||
model.to("cuda")
|
||
model.eval() # 设置为评估模式
|
||
|
||
# 定义图像转换
|
||
transform_image = transforms.Compose([
|
||
transforms.Resize((1024, 1024)),
|
||
transforms.ToTensor(),
|
||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||
])
|
||
|
||
def is_valid_url(url):
|
||
"""验证URL是否有效"""
|
||
try:
|
||
result = urlparse(url)
|
||
return all([result.scheme, result.netloc])
|
||
except:
|
||
return False
|
||
|
||
def download_image(url):
|
||
"""从URL下载图片到临时文件"""
|
||
response = requests.get(url, stream=True)
|
||
response.raise_for_status()
|
||
|
||
# 创建临时文件
|
||
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
|
||
with open(temp_file.name, 'wb') as f:
|
||
for chunk in response.iter_content(chunk_size=8192):
|
||
f.write(chunk)
|
||
return temp_file.name
|
||
|
||
def process_image(image):
|
||
"""处理图像,移除背景"""
|
||
image_size = image.size
|
||
# 转换图像
|
||
input_images = transform_image(image).unsqueeze(0).to("cuda")
|
||
|
||
# 使用半精度加速
|
||
with torch.cuda.amp.autocast():
|
||
# 预测
|
||
with torch.no_grad():
|
||
preds = model(input_images)[-1].sigmoid().cpu()
|
||
|
||
# 处理预测结果
|
||
pred = preds[0].squeeze()
|
||
pred_pil = transforms.ToPILImage()(pred)
|
||
mask = pred_pil.resize(image_size)
|
||
|
||
# 添加透明通道
|
||
image.putalpha(mask)
|
||
return image
|
||
|
||
def remove_background(image_path, output_dir=None):
|
||
"""
|
||
移除图像背景
|
||
|
||
Args:
|
||
image_path: 输入图像的路径或URL
|
||
output_dir: 输出目录,如果不指定则使用当前目录
|
||
|
||
Returns:
|
||
处理后的图像路径
|
||
"""
|
||
temp_file = None
|
||
try:
|
||
# 检查是否是URL
|
||
if is_valid_url(image_path):
|
||
try:
|
||
# 下载图片到临时文件
|
||
temp_file = download_image(image_path)
|
||
image_path = temp_file
|
||
except Exception as e:
|
||
raise Exception(f"下载图片失败: {e}")
|
||
|
||
# 验证输入文件是否存在
|
||
if not os.path.exists(image_path):
|
||
raise FileNotFoundError(f"输入图像不存在: {image_path}")
|
||
|
||
# 如果输出目录未指定,则使用当前目录
|
||
if output_dir is None:
|
||
output_dir = os.getcwd()
|
||
|
||
# 确保输出目录存在
|
||
os.makedirs(output_dir, exist_ok=True)
|
||
|
||
# 加载并处理图像
|
||
image = Image.open(image_path).convert("RGB")
|
||
image_no_bg = process_image(image)
|
||
|
||
# 保存结果
|
||
filename = os.path.basename(image_path)
|
||
name, ext = os.path.splitext(filename)
|
||
output_file_path = os.path.join(output_dir, f"{name}_nobg.png")
|
||
image_no_bg.save(output_file_path)
|
||
|
||
return output_file_path
|
||
finally:
|
||
# 清理临时文件
|
||
if temp_file and os.path.exists(temp_file):
|
||
try:
|
||
os.unlink(temp_file)
|
||
except:
|
||
pass
|
||
|
||
def main():
|
||
parser = argparse.ArgumentParser(description='移除图像背景')
|
||
parser.add_argument('--image', help='输入图像的路径或URL')
|
||
parser.add_argument('--output-dir', default='output', help='输出目录')
|
||
args = parser.parse_args()
|
||
|
||
try:
|
||
# 如果没有提供图片参数,使用示例URL
|
||
if args.image is None:
|
||
print("未提供图片参数,使用示例图片URL进行测试...")
|
||
args.image = "http://test001.jingrow.com/files/3T8A4426.JPG"
|
||
|
||
print(f"处理图片: {args.image}")
|
||
t0 = time.time()
|
||
output_path = remove_background(args.image, args.output_dir)
|
||
print(f"处理完成,用时 {time.time()-t0:.1f} 秒")
|
||
print(f"输出文件: {output_path}")
|
||
except Exception as e:
|
||
print(f"错误: {e}")
|
||
finally:
|
||
# 清理显存
|
||
if torch.cuda.is_available():
|
||
torch.cuda.empty_cache()
|
||
|
||
if __name__ == "__main__":
|
||
main()
|