# task/app/api/v1/endpoints/rmbg2.py
#
# Repository viewer metadata (not part of the program): 198 lines, 6.7 KiB, Python.
#
# Viewer notice: this file contains Unicode characters that might be confused
# with other characters. If that is intentional, the notice can be ignored.


import os
import argparse
import requests
import tempfile
from urllib.parse import urlparse
from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
import gc
# Suppress warnings considered unnecessary here: every UserWarning and
# FutureWarning is ignored from this point on.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
# Set torch's float32 matmul precision mode to "high".
torch.set_float32_matmul_precision("high")
# Sample image URL; main() falls back to it when no --image argument is given.
image_url = "http://test001.jingrow.com/files/3T8A4426.JPG"
def is_valid_url(url):
    """Return True when *url* parses with both a scheme and a network location.

    Args:
        url: Candidate value, e.g. ``"http://host/img.png"`` or a local path.

    Returns:
        bool: True if ``urlparse`` yields a non-empty scheme and netloc;
        False otherwise, including when the value cannot be parsed at all.
    """
    try:
        parsed = urlparse(url)
    except (ValueError, TypeError, AttributeError):
        # ValueError: malformed URL (e.g. an unterminated IPv6 literal).
        # TypeError / AttributeError: non-string input rejected by urlparse.
        # Narrower than a bare ``except`` so KeyboardInterrupt/SystemExit
        # are no longer swallowed.
        return False
    return bool(parsed.scheme and parsed.netloc)
def download_image(url, timeout=30):
    """Download *url* into a temporary file and return the file's path.

    The response body is streamed to disk in 8 KiB chunks. The temporary
    file is created with ``delete=False`` and a ``.png`` suffix, so it
    outlives this function; the caller is responsible for removing it.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds passed to ``requests.get`` as its timeout, so a
            stalled server cannot block the caller indefinitely.

    Returns:
        str: Path of the temporary file holding the downloaded bytes.

    Raises:
        requests.RequestException: On connection errors, timeouts, or a
            non-success HTTP status (via ``raise_for_status``).
        OSError: If the temporary file cannot be created or written.
    """
    # The context manager releases the connection even when streaming fails.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        try:
            # Write through the handle we already own and close it on exit,
            # instead of leaking it and reopening the same path.
            with temp_file:
                for chunk in response.iter_content(chunk_size=8192):
                    temp_file.write(chunk)
        except BaseException:
            # Do not leave a partial download behind; then re-raise unchanged.
            try:
                os.unlink(temp_file.name)
            except OSError:
                pass
            raise
    return temp_file.name
def process_image(image, model, device):
    """Predict a foreground mask for *image* and attach it as its alpha channel.

    The image is resized to 1024x1024, converted to a tensor, normalized with
    mean (0.485, 0.456, 0.406) / std (0.229, 0.224, 0.225), and passed through
    *model*. The sigmoid of the model's last output is turned into a PIL mask,
    resized back to the original image size, and applied with ``putalpha``.
    Timing for each stage is printed.

    Note: *image* is modified in place (``putalpha``) and the same object is
    returned.

    Args:
        image: PIL image to process (the caller in this file passes RGB).
        model: Segmentation model; called as ``model(batch)`` and indexed
            with ``[-1]`` to obtain the prediction.
        device: Device the input batch is moved to before inference.

    Returns:
        The input image object, now carrying the predicted alpha channel.
    """
    image_size = image.size

    # Pre-processing
    t0 = time.time()
    transform_image = transforms.Compose([
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    input_images = transform_image(image).unsqueeze(0).to(device)
    print(f"图像预处理耗时: {time.time()-t0:.3f}")

    # Inference. CUDA autocast (mixed precision) is enabled only when the
    # batch actually lives on a CUDA device. torch.autocast replaces the
    # deprecated torch.cuda.amp.autocast(), which warns on CPU-only hosts.
    t0 = time.time()
    with torch.no_grad(), torch.autocast(
        device_type="cuda", enabled=input_images.is_cuda
    ):
        preds = model(input_images)[-1].sigmoid().cpu()
    print(f"模型推理耗时: {time.time()-t0:.3f}")

    # Post-processing: tensor -> PIL mask at the original resolution.
    t0 = time.time()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image_size)
    # Attach the mask as the alpha (transparency) channel, in place.
    image.putalpha(mask)
    print(f"后处理耗时: {time.time()-t0:.3f}")
    return image
def remove_background(image_path, output_dir=None, model=None, device=None):
    """Remove the background of one image and save the result as a PNG.

    If *image_path* is a URL it is first downloaded to a temporary file,
    which is deleted again before this function returns. The output file is
    named ``<input basename without extension>_nobg.png``; for URL inputs the
    basename is that of the temporary download file.

    Args:
        image_path: Path or URL of the input image.
        output_dir: Directory for the result; defaults to the current
            working directory and is created if missing.
        model: Model instance forwarded to ``process_image``.
        device: Device forwarded to ``process_image``.

    Returns:
        str: Path of the written PNG file.

    Raises:
        Exception: If downloading a URL input fails (original error chained).
        FileNotFoundError: If the (local or downloaded) input does not exist.
    """
    temp_file = None
    try:
        # URL input: fetch it to a temporary file and continue with that path.
        if is_valid_url(image_path):
            try:
                temp_file = download_image(image_path)
            except Exception as e:
                raise Exception(f"下载图片失败: {e}") from e
            image_path = temp_file

        # The input file must exist.
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"输入图像不存在: {image_path}")

        # Default to the current directory and make sure the target exists.
        if output_dir is None:
            output_dir = os.getcwd()
        os.makedirs(output_dir, exist_ok=True)

        # Load inside a context manager so the file handle is closed right
        # away; convert() returns a loaded copy that does not need the file.
        # An open handle would otherwise block the temp-file removal below on
        # platforms that refuse to delete open files.
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        image_no_bg = process_image(image, model, device)

        # Save the result next to the requested output directory.
        name, _ = os.path.splitext(os.path.basename(image_path))
        output_file_path = os.path.join(output_dir, f"{name}_nobg.png")
        image_no_bg.save(output_file_path)
        return output_file_path
    finally:
        # Best-effort removal of the temporary download.
        if temp_file:
            try:
                os.unlink(temp_file)
            except OSError:
                pass
def process_batch(image_paths, output_dir, model, device, max_workers=4):
    """Run ``remove_background`` for several images on a thread pool.

    Args:
        image_paths: Iterable of input paths or URLs.
        output_dir: Output directory forwarded to ``remove_background``.
        model: Model instance forwarded to ``remove_background``.
        device: Device forwarded to ``remove_background``.
        max_workers: Number of worker threads in the pool.

    Returns:
        list[tuple]: One ``(image_path, output_path, error)`` entry per
        input, in submission order. On success ``error`` is None; on failure
        ``output_path`` is None and ``error`` is ``str(exception)``.
    """
    outcomes = []
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Submit everything first so the jobs can run concurrently.
        pending = [
            (path, pool.submit(remove_background, path, output_dir, model, device))
            for path in image_paths
        ]
        # Collect in submission order; a failing job only affects its own entry.
        for path, job in pending:
            try:
                outcomes.append((path, job.result(), None))
            except Exception as exc:
                outcomes.append((path, None, str(exc)))
    return outcomes
def main():
    """Command-line entry point: load the model and process the given images.

    Parses ``--image`` (one or more paths/URLs), ``--output-dir``,
    ``--model-path`` and ``--batch-size`` (forwarded to ``process_batch`` as
    its ``max_workers`` argument). When no ``--image`` is given, the
    module-level sample ``image_url`` is processed. Progress and errors are
    printed to stdout.

    Returns:
        int: 0 when every image was processed, 1 when loading/processing
        raised or at least one image in a batch failed.
    """
    parser = argparse.ArgumentParser(description='移除图像背景')
    parser.add_argument('--image', nargs='+', help='输入图像的路径或URL支持多个')
    parser.add_argument('--output-dir', default='output', help='输出目录')
    parser.add_argument('--model-path', default='briaai/RMBG-2.0', help='模型路径')
    parser.add_argument('--batch-size', type=int, default=4, help='批处理大小')
    args = parser.parse_args()

    # Select the device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"显存总量: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f}GB")

    exit_code = 0
    try:
        # Load the model.
        print("正在加载模型...")
        t0 = time.time()
        model = AutoModelForImageSegmentation.from_pretrained(args.model_path, trust_remote_code=True)
        model = model.to(device)
        model.eval()  # evaluation mode
        print(f"模型加载完成,用时 {time.time()-t0:.1f}")

        # Process the images.
        image_paths = args.image if args.image else [image_url]
        if len(image_paths) > 1:
            print(f"开始批量处理 {len(image_paths)} 张图片...")
            results = process_batch(image_paths, args.output_dir, model, device, args.batch_size)
            for image_path, output_path, error in results:
                # Compare against None: an exception with an empty message
                # yields error == "" and must still count as a failure.
                if error is not None:
                    print(f"处理失败: {image_path}, 错误: {error}")
                    exit_code = 1
                else:
                    print(f"处理成功: {image_path} -> {output_path}")
        else:
            print(f"处理图片: {image_paths[0]}")
            t1 = time.time()
            output_path = remove_background(image_paths[0], args.output_dir, model, device)
            print(f"处理完成,用时 {time.time()-t1:.1f}")
            print(f"输出文件: {output_path}")
    except Exception as e:
        print(f"错误: {e}")
        exit_code = 1
    finally:
        # Collect garbage first so freed tensors are actually returned to the
        # allocator before the CUDA cache is emptied.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return exit_code


if __name__ == "__main__":
    # Propagate the result as the process exit status.
    raise SystemExit(main())