196 lines
6.7 KiB
Python
196 lines
6.7 KiB
Python
import os
|
||
import argparse
|
||
import requests
|
||
import tempfile
|
||
from urllib.parse import urlparse
|
||
from PIL import Image
|
||
import torch
|
||
from torchvision import transforms
|
||
from transformers import AutoModelForImageSegmentation
|
||
import time
|
||
import warnings
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
import gc
|
||
|
||
# Silence noisy library warnings (user hints and deprecation notices) so the
# CLI output stays readable. The FutureWarning filter ends up first in the list.
for _ignored_category in (UserWarning, FutureWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)
del _ignored_category

# Torch numeric/performance knobs: allow TF32 math and let cuDNN autotune.
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.benchmark = True  # cuDNN picks the fastest conv algorithms
torch.backends.cuda.matmul.allow_tf32 = True  # TF32 for CUDA matmuls
torch.backends.cudnn.allow_tf32 = True  # TF32 inside cuDNN kernels

# Default input used by main() when no --image argument is given.
image_url = "http://test001.jingrow.com/files/3T8A4426.JPG"
|
||
|
||
def is_valid_url(url):
    """Return True if *url* looks like an absolute URL (non-empty scheme and host).

    Local filesystem paths such as ``photo.jpg`` or ``/tmp/a.png`` have no
    network location and therefore return False, which lets callers route
    them to the local-file branch.

    Args:
        url: candidate string.

    Returns:
        bool: True when ``urlparse`` yields both a scheme and a netloc; False
        otherwise, including when the input cannot be parsed at all.
    """
    try:
        result = urlparse(url)
    except (ValueError, TypeError, AttributeError):
        # ValueError: malformed input such as an unbalanced IPv6 bracket
        # ("http://[::1"); AttributeError/TypeError: non-string input.
        # Deliberately narrower than a bare ``except`` so that
        # KeyboardInterrupt/SystemExit are not swallowed.
        return False
    return bool(result.scheme and result.netloc)
|
||
|
||
def download_image(url, timeout=30):
    """Download *url* into a new temporary file and return that file's path.

    The caller owns the returned file and is responsible for deleting it.

    Args:
        url: HTTP(S) URL to fetch.
        timeout: seconds passed to ``requests.get`` as the connect/read
            timeout (per socket operation, not a cap on total download time).
            Without it a stalled server would block the caller forever.

    Returns:
        str: path of the temporary file. It is created with ``delete=False``
        and a ``.png`` suffix regardless of the real image format.

    Raises:
        requests.HTTPError: on a non-2xx response.
        requests.RequestException / OSError: on network or disk failures; any
        partially written temporary file is removed before re-raising.
    """
    # ``with`` releases the streamed connection even if writing fails.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()

        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        try:
            # Write through the handle NamedTemporaryFile already opened: the
            # previous code reopened the file by name (not portable to Windows)
            # and never closed this handle (file-descriptor leak).
            with temp_file:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        temp_file.write(chunk)
        except BaseException:
            # Don't leave a half-written temp file behind on failure.
            try:
                os.unlink(temp_file.name)
            except OSError:
                pass
            raise

    return temp_file.name
|
||
|
||
def process_image(image, model, device, input_size=(1024, 1024)):
    """Predict a foreground mask for *image* and attach it as the alpha channel.

    Args:
        image: PIL image. It is modified in place (``putalpha``) and returned.
        model: segmentation model; ``model(batch)[-1]`` is taken as the mask
            logits (presumably shape (1, 1, H, W) -- TODO confirm against the
            model card).
        device: torch device or device string for inference; with ``None`` the
            input tensor stays where it was created (CPU).
        input_size: (height, width) the image is resized to before inference;
            defaults to the 1024x1024 previously hard-coded.

    Returns:
        The same PIL image, now carrying the predicted mask as alpha.
    """
    original_size = image.size  # PIL (width, height); the mask is resized back to this

    transform_image = transforms.Compose([
        transforms.Resize(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # Normalize expects exactly three channels; feed an RGB copy so RGBA/L/P
    # inputs do not break the transform.
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    input_images = transform_image(rgb_image).unsqueeze(0).to(device)

    # Mixed precision only applies on CUDA. torch.cuda.amp.autocast() is
    # deprecated and merely warns (then does nothing) on CPU-only machines.
    use_cuda = device is not None and torch.device(device).type == "cuda"
    with torch.no_grad(), torch.autocast(device_type="cuda", enabled=use_cuda):
        preds = model(input_images)[-1].sigmoid()

    # Autocast output may be float16; cast to float32 before ToPILImage, which
    # scales float tensors by 255.
    pred = preds[0].squeeze().float().cpu()
    mask = transforms.ToPILImage()(pred).resize(original_size)

    # Attach the mask as the alpha (transparency) channel.
    image.putalpha(mask)
    return image
|
||
|
||
def remove_background(image_path, output_dir=None, model=None, device=None):
    """Remove the background of an image and save the result as a PNG.

    Args:
        image_path: local file path, or a URL (as judged by ``is_valid_url``)
            that is first downloaded to a temporary file.
        output_dir: directory for the output file; defaults to the current
            working directory and is created if missing.
        model: model instance forwarded to ``process_image``.
        device: device forwarded to ``process_image``.

    Returns:
        str: path of the written ``<name>_nobg.png``. For URLs, ``<name>`` is
        the last segment of the URL path, e.g. ``.../3T8A4426.JPG`` gives
        ``3T8A4426_nobg.png``, rather than the random temp-file name.

    Raises:
        Exception: when downloading a URL fails; chained to the original error.
        FileNotFoundError: when the input path does not exist.
    """
    from urllib.parse import unquote  # file-level import only provides urlparse

    temp_file = None
    try:
        if is_valid_url(image_path):
            try:
                temp_file = download_image(image_path)
            except Exception as e:
                # Keep the message but chain the cause so the traceback still
                # shows what actually failed.
                raise Exception(f"下载图片失败: {e}") from e
            # Name the output after the URL's file name, not the random temp
            # file. Fall back to the temp name when the URL path has none
            # (e.g. "http://host/").
            url_name = os.path.basename(unquote(urlparse(image_path).path))
            source_name = url_name or os.path.basename(temp_file)
            image_path = temp_file
        else:
            source_name = os.path.basename(image_path)

        if not os.path.exists(image_path):
            raise FileNotFoundError(f"输入图像不存在: {image_path}")

        if output_dir is None:
            output_dir = os.getcwd()
        os.makedirs(output_dir, exist_ok=True)

        # Close the source file promptly: the temp download is deleted in
        # ``finally``, which fails on Windows while a handle is still open.
        with Image.open(image_path) as source_image:
            image = source_image.convert("RGB")
        image_no_bg = process_image(image, model, device)

        name, _ext = os.path.splitext(source_name)
        output_file_path = os.path.join(output_dir, f"{name}_nobg.png")
        image_no_bg.save(output_file_path)
        return output_file_path
    finally:
        if temp_file:
            try:
                os.unlink(temp_file)
            except OSError:
                # Best-effort cleanup; must not mask the real result or error.
                pass
|
||
|
||
def process_batch(image_paths, output_dir, model, device, max_workers=4):
    """Run ``remove_background`` over several inputs on a thread pool.

    Args:
        image_paths: iterable of local paths and/or URLs.
        output_dir: forwarded to ``remove_background``.
        model: model instance shared by every worker thread.
        device: forwarded to ``remove_background``.
        max_workers: thread-pool size.

    Returns:
        list of ``(image_path, output_path, error)`` tuples in input order.
        On success ``error`` is None; on failure ``output_path`` is None and
        ``error`` is the exception's ``str()``.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = [
            (path, pool.submit(remove_background, path, output_dir, model, device))
            for path in image_paths
        ]

        outcomes = []
        for path, job in pending:
            try:
                outcomes.append((path, job.result(), None))
            except Exception as exc:
                outcomes.append((path, None, str(exc)))
    return outcomes
|
||
|
||
def main():
    """Command-line entry point: parse arguments, load the model, process images.

    With more than one ``--image`` the inputs are processed concurrently via
    ``process_batch``; otherwise a single image is processed (default:
    ``image_url``). Errors are printed to stderr. If model loading or any
    image fails, the process exits with status 1 (``SystemExit``) so shell
    scripts can detect the failure. GPU cache is released on the way out.
    """
    import sys  # only the CLI entry point needs sys (stderr)

    parser = argparse.ArgumentParser(description='移除图像背景')
    parser.add_argument('--image', nargs='+', help='输入图像的路径或URL(支持多个)')
    parser.add_argument('--output-dir', default='output', help='输出目录')
    parser.add_argument('--model-path', default='briaai/RMBG-2.0', help='模型路径')
    # NOTE: used as the thread-pool size in process_batch, not a tensor batch size.
    parser.add_argument('--batch-size', type=int, default=4, help='批处理大小')
    args = parser.parse_args()
    # Reject invalid worker counts up front instead of crashing inside
    # ThreadPoolExecutor after the (slow) model load.
    if args.batch_size < 1:
        parser.error('--batch-size 必须大于等于 1')

    # Select the device and report GPU details when available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"显存总量: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f}GB")

    failed = False
    model = None
    try:
        print("正在加载模型...")
        t0 = time.time()
        model = AutoModelForImageSegmentation.from_pretrained(args.model_path, trust_remote_code=True)
        model = model.to(device)
        model.eval()  # inference mode for dropout/batch-norm layers
        print(f"模型加载完成,用时 {time.time()-t0:.1f} 秒")

        image_paths = args.image if args.image else [image_url]
        if len(image_paths) > 1:
            print(f"开始批量处理 {len(image_paths)} 张图片...")
            results = process_batch(image_paths, args.output_dir, model, device, args.batch_size)
            for image_path, output_path, error in results:
                # ``error`` is str(exception) and may be "" for message-less
                # exceptions, so test for None rather than truthiness.
                if error is not None:
                    failed = True
                    print(f"处理失败: {image_path}, 错误: {error}", file=sys.stderr)
                else:
                    print(f"处理成功: {image_path} -> {output_path}")
        else:
            print(f"处理图片: {image_paths[0]}")
            t1 = time.time()
            output_path = remove_background(image_paths[0], args.output_dir, model, device)
            print(f"处理完成,用时 {time.time()-t1:.1f} 秒")
            print(f"输出文件: {output_path}")
    except Exception as e:
        failed = True
        print(f"错误: {e}", file=sys.stderr)
    finally:
        # Drop the model reference and collect garbage *before* emptying the
        # CUDA cache; otherwise the model's memory is still in use.
        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    if failed:
        raise SystemExit(1)
|
||
|
||
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|