# task/app/api/v1/endpoints/remove_background.py
#
# Repository-viewer header captured together with the file (not part of the
# program): 146 lines | 4.3 KiB | Python | Raw / Blame / History
#
# Viewer warning: this file contains ambiguous Unicode characters, i.e.
# characters that might be confused with other characters. If that is
# intentional, the warning can safely be ignored; the viewer's Escape button
# reveals them.

import os
import argparse
import requests
import tempfile
from urllib.parse import urlparse
from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
import time
# Configure float32 matmul precision ("high" lets PyTorch use faster,
# slightly less precise kernels where the hardware supports them).
torch.set_float32_matmul_precision("high")

# Choose the inference device once. Falling back to CPU keeps this module
# importable on hosts without a CUDA-capable GPU instead of failing at import.
_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the RMBG-2.0 segmentation model.
# NOTE: trust_remote_code=True executes Python code shipped with the
# checkpoint, so the model repository must be trusted.
model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-2.0", trust_remote_code=True)
model.to(_DEVICE)
model.eval()  # switch to evaluation mode

# Preprocessing applied to each input image before inference: resize to
# 1024x1024, convert to a tensor, then normalize per channel (the mean/std
# values look like the standard ImageNet statistics).
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
def is_valid_url(url):
    """Return True if *url* parses with both a scheme and a network location.

    Args:
        url: Candidate string, e.g. ``"https://host/image.png"``. Plain file
            paths such as ``"/tmp/a.png"`` have no scheme/netloc and yield
            False.

    Returns:
        bool: True when both ``scheme`` and ``netloc`` are non-empty;
        False otherwise, including when the value cannot be parsed at all.
    """
    try:
        parsed = urlparse(url)
    except (ValueError, TypeError, AttributeError):
        # ValueError: malformed URL (e.g. unbalanced IPv6 brackets).
        # TypeError / AttributeError: the value is not str/bytes-like.
        # Anything else (KeyboardInterrupt, SystemExit, ...) must propagate.
        return False
    return bool(parsed.scheme and parsed.netloc)
def download_image(url, timeout=30):
    """Download the image at *url* into a temporary ``.png``-suffixed file.

    Args:
        url: HTTP(S) address of the image.
        timeout: Seconds passed to ``requests.get`` so that a stalled server
            cannot block the caller forever.

    Returns:
        Path of the temporary file. The caller owns it and must delete it.

    Raises:
        requests.RequestException: On connection errors, timeouts, or a
            non-success HTTP status (via ``raise_for_status``).
    """
    temp_path = None
    try:
        # The context manager releases the streamed connection on every path.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            # Write through the handle NamedTemporaryFile already opened
            # instead of reopening the file by name; delete=False keeps the
            # file on disk after the handle is closed.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
                temp_path = temp_file.name
                for chunk in response.iter_content(chunk_size=8192):
                    temp_file.write(chunk)
    except BaseException:
        # Do not leave a partially written temp file behind on failure.
        if temp_path is not None:
            try:
                os.unlink(temp_path)
            except OSError:
                pass
        raise
    return temp_path
def process_image(image):
    """Predict a foreground mask for *image* and attach it as the alpha channel.

    Args:
        image: PIL image to process. It is modified in place by ``putalpha``.
            Presumably it must be RGB, since the preprocessing normalizes
            three channels -- confirm against callers.

    Returns:
        The same image object, now carrying the predicted mask as alpha.
    """
    original_size = image.size

    # Run on whichever device the model actually lives on rather than
    # assuming CUDA is present.
    device = next(model.parameters()).device
    batch = transform_image(image).unsqueeze(0).to(device)

    # Mixed precision is only enabled on CUDA. torch.autocast replaces the
    # deprecated torch.cuda.amp.autocast(); with enabled=False it is a no-op.
    use_amp = device.type == "cuda"
    with torch.no_grad(), torch.autocast(device_type="cuda", enabled=use_amp):
        # The last element of the model output is used as the mask logits.
        preds = model(batch)[-1].sigmoid().cpu()

    # Turn the single-image prediction into a PIL mask at the input's size.
    mask_tensor = preds[0].squeeze()
    mask = transforms.ToPILImage()(mask_tensor).resize(original_size)

    # Attach the mask as the transparency channel (mutates `image`).
    image.putalpha(mask)
    return image
def remove_background(image_path, output_dir=None):
    """Remove the background from an image and save the result as a PNG.

    Args:
        image_path: Local path or URL of the input image. A URL is first
            downloaded to a temporary file, which is deleted before this
            function returns.
        output_dir: Directory for the result. Defaults to the current working
            directory; it is created if it does not exist.

    Returns:
        Path of the written file, ``<output_dir>/<name>_nobg.png``. For URL
        input, ``<name>`` is taken from the temporary download file, not from
        the URL.

    Raises:
        Exception: If downloading a URL fails (the original error is chained).
        FileNotFoundError: If the input path does not exist.
    """
    temp_file = None
    try:
        # URL input: fetch it and continue with the local temporary copy.
        if is_valid_url(image_path):
            try:
                temp_file = download_image(image_path)
            except Exception as e:
                raise Exception(f"下载图片失败: {e}") from e
            image_path = temp_file

        if not os.path.exists(image_path):
            raise FileNotFoundError(f"输入图像不存在: {image_path}")

        if output_dir is None:
            output_dir = os.getcwd()
        os.makedirs(output_dir, exist_ok=True)

        # convert() returns an independent, fully loaded image, so the file
        # handle can be released as soon as the `with` block ends.
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        image_no_bg = process_image(image)

        # Always save as PNG so the alpha channel is preserved.
        name, _ = os.path.splitext(os.path.basename(image_path))
        output_file_path = os.path.join(output_dir, f"{name}_nobg.png")
        image_no_bg.save(output_file_path)
        return output_file_path
    finally:
        # Best-effort cleanup of the downloaded copy; a failure to delete it
        # must not mask the real result or the real error.
        if temp_file is not None:
            try:
                os.unlink(temp_file)
            except OSError:
                pass
def main():
    """Command-line entry point: remove the background from one image.

    Options:
        --image: Path or URL of the input image. When omitted, a sample URL
            is used for a test run.
        --output-dir: Directory for the result (default: ``output``).

    Returns:
        int: 0 on success, 1 if processing failed, so the process exit status
        reflects the outcome.
    """
    parser = argparse.ArgumentParser(description='移除图像背景')
    parser.add_argument('--image', help='输入图像的路径或URL')
    parser.add_argument('--output-dir', default='output', help='输出目录')
    args = parser.parse_args()

    exit_code = 0
    try:
        # Fall back to the sample URL when no image argument was given.
        if args.image is None:
            print("未提供图片参数使用示例图片URL进行测试...")
            args.image = "http://test001.jingrow.com/files/3T8A4426.JPG"
        print(f"处理图片: {args.image}")
        t0 = time.time()
        output_path = remove_background(args.image, args.output_dir)
        print(f"处理完成,用时 {time.time()-t0:.1f}")
        print(f"输出文件: {output_path}")
    except Exception as e:
        # Top-level boundary: report the failure and signal it via exit code.
        print(f"错误: {e}")
        exit_code = 1
    finally:
        # Release cached GPU memory.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return exit_code


if __name__ == "__main__":
    raise SystemExit(main())