初始提交
This commit is contained in:
commit
e2575d5e61
BIN
task/output/tmp1uqgmfm9_nobg.png
Normal file
BIN
task/output/tmp1uqgmfm9_nobg.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1005 KiB |
BIN
task/output/tmp4hma0__o_nobg.png
Normal file
BIN
task/output/tmp4hma0__o_nobg.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1005 KiB |
BIN
task/output/tmpn8r2g6do_nobg.png
Normal file
BIN
task/output/tmpn8r2g6do_nobg.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1005 KiB |
145
task/remove_background.py
Normal file
145
task/remove_background.py
Normal file
@ -0,0 +1,145 @@
|
|||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import requests
|
||||||
|
import tempfile
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
from PIL import Image
|
||||||
|
import torch
|
||||||
|
from torchvision import transforms
|
||||||
|
from transformers import AutoModelForImageSegmentation
|
||||||
|
import time
|
||||||
|
|
||||||
|
# 设置torch精度
|
||||||
|
torch.set_float32_matmul_precision("high")
|
||||||
|
|
||||||
|
# 初始化模型
|
||||||
|
model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-2.0", trust_remote_code=True)
|
||||||
|
model.to("cuda")
|
||||||
|
model.eval() # 设置为评估模式
|
||||||
|
|
||||||
|
# 定义图像转换
|
||||||
|
transform_image = transforms.Compose([
|
||||||
|
transforms.Resize((1024, 1024)),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||||||
|
])
|
||||||
|
|
||||||
|
def is_valid_url(url):
|
||||||
|
"""验证URL是否有效"""
|
||||||
|
try:
|
||||||
|
result = urlparse(url)
|
||||||
|
return all([result.scheme, result.netloc])
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def download_image(url):
|
||||||
|
"""从URL下载图片到临时文件"""
|
||||||
|
response = requests.get(url, stream=True)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
# 创建临时文件
|
||||||
|
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
|
||||||
|
with open(temp_file.name, 'wb') as f:
|
||||||
|
for chunk in response.iter_content(chunk_size=8192):
|
||||||
|
f.write(chunk)
|
||||||
|
return temp_file.name
|
||||||
|
|
||||||
|
def process_image(image):
|
||||||
|
"""处理图像,移除背景"""
|
||||||
|
image_size = image.size
|
||||||
|
# 转换图像
|
||||||
|
input_images = transform_image(image).unsqueeze(0).to("cuda")
|
||||||
|
|
||||||
|
# 使用半精度加速
|
||||||
|
with torch.cuda.amp.autocast():
|
||||||
|
# 预测
|
||||||
|
with torch.no_grad():
|
||||||
|
preds = model(input_images)[-1].sigmoid().cpu()
|
||||||
|
|
||||||
|
# 处理预测结果
|
||||||
|
pred = preds[0].squeeze()
|
||||||
|
pred_pil = transforms.ToPILImage()(pred)
|
||||||
|
mask = pred_pil.resize(image_size)
|
||||||
|
|
||||||
|
# 添加透明通道
|
||||||
|
image.putalpha(mask)
|
||||||
|
return image
|
||||||
|
|
||||||
|
def remove_background(image_path, output_dir=None):
|
||||||
|
"""
|
||||||
|
移除图像背景
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_path: 输入图像的路径或URL
|
||||||
|
output_dir: 输出目录,如果不指定则使用当前目录
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理后的图像路径
|
||||||
|
"""
|
||||||
|
temp_file = None
|
||||||
|
try:
|
||||||
|
# 检查是否是URL
|
||||||
|
if is_valid_url(image_path):
|
||||||
|
try:
|
||||||
|
# 下载图片到临时文件
|
||||||
|
temp_file = download_image(image_path)
|
||||||
|
image_path = temp_file
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"下载图片失败: {e}")
|
||||||
|
|
||||||
|
# 验证输入文件是否存在
|
||||||
|
if not os.path.exists(image_path):
|
||||||
|
raise FileNotFoundError(f"输入图像不存在: {image_path}")
|
||||||
|
|
||||||
|
# 如果输出目录未指定,则使用当前目录
|
||||||
|
if output_dir is None:
|
||||||
|
output_dir = os.getcwd()
|
||||||
|
|
||||||
|
# 确保输出目录存在
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 加载并处理图像
|
||||||
|
image = Image.open(image_path).convert("RGB")
|
||||||
|
image_no_bg = process_image(image)
|
||||||
|
|
||||||
|
# 保存结果
|
||||||
|
filename = os.path.basename(image_path)
|
||||||
|
name, ext = os.path.splitext(filename)
|
||||||
|
output_file_path = os.path.join(output_dir, f"{name}_nobg.png")
|
||||||
|
image_no_bg.save(output_file_path)
|
||||||
|
|
||||||
|
return output_file_path
|
||||||
|
finally:
|
||||||
|
# 清理临时文件
|
||||||
|
if temp_file and os.path.exists(temp_file):
|
||||||
|
try:
|
||||||
|
os.unlink(temp_file)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description='移除图像背景')
|
||||||
|
parser.add_argument('--image', help='输入图像的路径或URL')
|
||||||
|
parser.add_argument('--output-dir', default='output', help='输出目录')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 如果没有提供图片参数,使用示例URL
|
||||||
|
if args.image is None:
|
||||||
|
print("未提供图片参数,使用示例图片URL进行测试...")
|
||||||
|
args.image = "http://test001.jingrow.com/files/3T8A4426.JPG"
|
||||||
|
|
||||||
|
print(f"处理图片: {args.image}")
|
||||||
|
t0 = time.time()
|
||||||
|
output_path = remove_background(args.image, args.output_dir)
|
||||||
|
print(f"处理完成,用时 {time.time()-t0:.1f} 秒")
|
||||||
|
print(f"输出文件: {output_path}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"错误: {e}")
|
||||||
|
finally:
|
||||||
|
# 清理显存
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
4
task/requirements.txt
Normal file
4
task/requirements.txt
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
requests>=2.31.0
|
||||||
|
Pillow>=10.0.0
|
||||||
|
torch>=2.0.0
|
||||||
|
transformers>=4.30.0
|
||||||
106
task/rmbg.py
Normal file
106
task/rmbg.py
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
import os
|
||||||
|
import gradio as gr
|
||||||
|
from gradio_imageslider import ImageSlider
|
||||||
|
from loadimg import load_img
|
||||||
|
import spaces
|
||||||
|
from transformers import AutoModelForImageSegmentation
|
||||||
|
import torch
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
torch.set_float32_matmul_precision(["high", "highest"][0])
|
||||||
|
|
||||||
|
birefnet = AutoModelForImageSegmentation.from_pretrained(
|
||||||
|
"briaai/RMBG-2.0", trust_remote_code=True
|
||||||
|
)
|
||||||
|
birefnet.to("cuda")
|
||||||
|
transform_image = transforms.Compose(
|
||||||
|
[
|
||||||
|
transforms.Resize((1024, 1024)),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
output_folder = 'output_images'
|
||||||
|
if not os.path.exists(output_folder):
|
||||||
|
os.makedirs(output_folder)
|
||||||
|
|
||||||
|
def fn(image):
|
||||||
|
im = load_img(image, output_type="pil")
|
||||||
|
im = im.convert("RGB")
|
||||||
|
origin = im.copy()
|
||||||
|
image = process(im)
|
||||||
|
image_path = os.path.join(output_folder, "no_bg_image.png")
|
||||||
|
image.save(image_path)
|
||||||
|
return (image, origin), image_path
|
||||||
|
|
||||||
|
@spaces.GPU
|
||||||
|
def process(image):
|
||||||
|
image_size = image.size
|
||||||
|
input_images = transform_image(image).unsqueeze(0).to("cuda")
|
||||||
|
# Prediction
|
||||||
|
with torch.no_grad():
|
||||||
|
preds = birefnet(input_images)[-1].sigmoid().cpu()
|
||||||
|
pred = preds[0].squeeze()
|
||||||
|
pred_pil = transforms.ToPILImage()(pred)
|
||||||
|
mask = pred_pil.resize(image_size)
|
||||||
|
image.putalpha(mask)
|
||||||
|
return image
|
||||||
|
|
||||||
|
def process_file(f):
|
||||||
|
name_path = f.rsplit(".",1)[0]+".png"
|
||||||
|
im = load_img(f, output_type="pil")
|
||||||
|
im = im.convert("RGB")
|
||||||
|
transparent = process(im)
|
||||||
|
transparent.save(name_path)
|
||||||
|
return name_path
|
||||||
|
|
||||||
|
slider1 = ImageSlider(label="RMBG-2.0", type="pil")
|
||||||
|
slider2 = ImageSlider(label="RMBG-2.0", type="pil")
|
||||||
|
image = gr.Image(label="Upload an image")
|
||||||
|
image2 = gr.Image(label="Upload an image",type="filepath")
|
||||||
|
text = gr.Textbox(label="Paste an image URL")
|
||||||
|
png_file = gr.File(label="output png file")
|
||||||
|
|
||||||
|
|
||||||
|
chameleon = load_img("giraffe.jpg", output_type="pil")
|
||||||
|
|
||||||
|
url = "http://farm9.staticflickr.com/8488/8228323072_76eeddfea3_z.jpg"
|
||||||
|
|
||||||
|
tab1 = gr.Interface(
|
||||||
|
fn, inputs=image, outputs=[slider1, gr.File(label="output png file")], examples=[chameleon], api_name="image"
|
||||||
|
)
|
||||||
|
|
||||||
|
tab2 = gr.Interface(fn, inputs=text, outputs=[slider2, gr.File(label="output png file")], examples=[url], api_name="text")
|
||||||
|
tab3 = gr.Interface(process_file, inputs=image2, outputs=png_file, examples=["giraffe.jpg"], api_name="png")
|
||||||
|
|
||||||
|
|
||||||
|
demo = gr.TabbedInterface(
|
||||||
|
[tab1, tab2], ["input image", "input url"], title = (
|
||||||
|
"RMBG-2.0 for background removal <br>"
|
||||||
|
"<span style='font-size:16px; font-weight:300;'>"
|
||||||
|
"Background removal model developed by "
|
||||||
|
"<a href='https://bria.ai' target='_blank'>BRIA.AI</a>, trained on a carefully selected dataset,<br> "
|
||||||
|
"and is available as an open-source model for non-commercial use.</span><br>"
|
||||||
|
"<span style='font-size:16px; font-weight:500;'> For testing upload your image and wait.<br>"
|
||||||
|
"<a href='https://go.bria.ai/3ZCBTLH' target='_blank'>Commercial use license</a> | "
|
||||||
|
"<a href='https://huggingface.co/briaai/RMBG-2.0' target='_blank'>Model card</a> | "
|
||||||
|
"<a href='https://blog.bria.ai/brias-new-state-of-the-art-remove-background-2.0-outperforms-the-competition' target='_blank'>Blog</a>"
|
||||||
|
"</span><br>"
|
||||||
|
"<span style='font-size:16px; font-weight:300;'>"
|
||||||
|
"API Endpoint available on: "
|
||||||
|
"<a href='https://platform.bria.ai/console/api/image-editing' target='_blank'>Bria.ai</a>, "
|
||||||
|
"<a href='https://fal.ai/models/fal-ai/bria/background/remove' target='_blank'>fal.ai</a><br>"
|
||||||
|
"ComfyUI node is available here: "
|
||||||
|
"<a href='https://github.com/Bria-AI/ComfyUI-BRIA-API' target='_blank'>ComfyUI Node</a><br>"
|
||||||
|
"Purchase commercial weigths for commercial use: "
|
||||||
|
"<a href='https://go.bria.ai/3D5EGp0' target='_blank'>here</a>"
|
||||||
|
"</span>"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
demo.launch(show_error=True)
|
||||||
195
task/rmbg2.py
Normal file
195
task/rmbg2.py
Normal file
@ -0,0 +1,195 @@
|
|||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import requests
|
||||||
|
import tempfile
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
from PIL import Image
|
||||||
|
import torch
|
||||||
|
from torchvision import transforms
|
||||||
|
from transformers import AutoModelForImageSegmentation
|
||||||
|
import time
|
||||||
|
import warnings
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
import gc
|
||||||
|
|
||||||
|
# 关闭不必要的警告
|
||||||
|
warnings.filterwarnings("ignore", category=UserWarning)
|
||||||
|
warnings.filterwarnings("ignore", category=FutureWarning)
|
||||||
|
|
||||||
|
# 设置torch精度和优化
|
||||||
|
torch.set_float32_matmul_precision("high")
|
||||||
|
torch.backends.cudnn.benchmark = True # 启用cuDNN自动调优
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True # 允许使用TF32
|
||||||
|
torch.backends.cudnn.allow_tf32 = True # 允许cuDNN使用TF32
|
||||||
|
|
||||||
|
# 示例图片URL
|
||||||
|
image_url = "http://test001.jingrow.com/files/3T8A4426.JPG"
|
||||||
|
|
||||||
|
def is_valid_url(url):
|
||||||
|
"""验证URL是否有效"""
|
||||||
|
try:
|
||||||
|
result = urlparse(url)
|
||||||
|
return all([result.scheme, result.netloc])
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def download_image(url):
|
||||||
|
"""从URL下载图片到临时文件"""
|
||||||
|
response = requests.get(url, stream=True)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
# 创建临时文件
|
||||||
|
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
|
||||||
|
with open(temp_file.name, 'wb') as f:
|
||||||
|
for chunk in response.iter_content(chunk_size=8192):
|
||||||
|
f.write(chunk)
|
||||||
|
return temp_file.name
|
||||||
|
|
||||||
|
def process_image(image, model, device):
|
||||||
|
"""处理图像,移除背景"""
|
||||||
|
image_size = image.size
|
||||||
|
# 转换图像
|
||||||
|
transform_image = transforms.Compose([
|
||||||
|
transforms.Resize((1024, 1024)),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||||||
|
])
|
||||||
|
|
||||||
|
# 使用半精度加速
|
||||||
|
with torch.cuda.amp.autocast():
|
||||||
|
input_images = transform_image(image).unsqueeze(0).to(device)
|
||||||
|
# 预测
|
||||||
|
with torch.no_grad():
|
||||||
|
preds = model(input_images)[-1].sigmoid().cpu()
|
||||||
|
|
||||||
|
# 处理预测结果
|
||||||
|
pred = preds[0].squeeze()
|
||||||
|
pred_pil = transforms.ToPILImage()(pred)
|
||||||
|
mask = pred_pil.resize(image_size)
|
||||||
|
|
||||||
|
# 添加透明通道
|
||||||
|
image.putalpha(mask)
|
||||||
|
return image
|
||||||
|
|
||||||
|
def remove_background(image_path, output_dir=None, model=None, device=None):
|
||||||
|
"""
|
||||||
|
移除图像背景
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_path: 输入图像的路径或URL
|
||||||
|
output_dir: 输出目录,如果不指定则使用当前目录
|
||||||
|
model: 模型实例
|
||||||
|
device: 设备
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理后的图像路径
|
||||||
|
"""
|
||||||
|
temp_file = None
|
||||||
|
try:
|
||||||
|
# 检查是否是URL
|
||||||
|
if is_valid_url(image_path):
|
||||||
|
try:
|
||||||
|
# 下载图片到临时文件
|
||||||
|
temp_file = download_image(image_path)
|
||||||
|
image_path = temp_file
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"下载图片失败: {e}")
|
||||||
|
|
||||||
|
# 验证输入文件是否存在
|
||||||
|
if not os.path.exists(image_path):
|
||||||
|
raise FileNotFoundError(f"输入图像不存在: {image_path}")
|
||||||
|
|
||||||
|
# 如果输出目录未指定,则使用当前目录
|
||||||
|
if output_dir is None:
|
||||||
|
output_dir = os.getcwd()
|
||||||
|
|
||||||
|
# 确保输出目录存在
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 加载并处理图像
|
||||||
|
image = Image.open(image_path).convert("RGB")
|
||||||
|
image_no_bg = process_image(image, model, device)
|
||||||
|
|
||||||
|
# 保存结果
|
||||||
|
filename = os.path.basename(image_path)
|
||||||
|
name, ext = os.path.splitext(filename)
|
||||||
|
output_file_path = os.path.join(output_dir, f"{name}_nobg.png")
|
||||||
|
image_no_bg.save(output_file_path)
|
||||||
|
|
||||||
|
return output_file_path
|
||||||
|
finally:
|
||||||
|
# 清理临时文件
|
||||||
|
if temp_file and os.path.exists(temp_file):
|
||||||
|
try:
|
||||||
|
os.unlink(temp_file)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def process_batch(image_paths, output_dir, model, device, max_workers=4):
|
||||||
|
"""批量处理图片"""
|
||||||
|
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||||
|
futures = []
|
||||||
|
for image_path in image_paths:
|
||||||
|
future = executor.submit(remove_background, image_path, output_dir, model, device)
|
||||||
|
futures.append((image_path, future))
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for image_path, future in futures:
|
||||||
|
try:
|
||||||
|
output_path = future.result()
|
||||||
|
results.append((image_path, output_path, None))
|
||||||
|
except Exception as e:
|
||||||
|
results.append((image_path, None, str(e)))
|
||||||
|
return results
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description='移除图像背景')
|
||||||
|
parser.add_argument('--image', nargs='+', help='输入图像的路径或URL(支持多个)')
|
||||||
|
parser.add_argument('--output-dir', default='output', help='输出目录')
|
||||||
|
parser.add_argument('--model-path', default='briaai/RMBG-2.0', help='模型路径')
|
||||||
|
parser.add_argument('--batch-size', type=int, default=4, help='批处理大小')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# 设置设备
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(f"使用设备: {device}")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
print(f"GPU: {torch.cuda.get_device_name(0)}")
|
||||||
|
print(f"显存总量: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f}GB")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 加载模型
|
||||||
|
print("正在加载模型...")
|
||||||
|
t0 = time.time()
|
||||||
|
model = AutoModelForImageSegmentation.from_pretrained(args.model_path, trust_remote_code=True)
|
||||||
|
model = model.to(device)
|
||||||
|
model.eval() # 设置为评估模式
|
||||||
|
print(f"模型加载完成,用时 {time.time()-t0:.1f} 秒")
|
||||||
|
|
||||||
|
# 处理图片
|
||||||
|
image_paths = args.image if args.image else [image_url]
|
||||||
|
if len(image_paths) > 1:
|
||||||
|
print(f"开始批量处理 {len(image_paths)} 张图片...")
|
||||||
|
results = process_batch(image_paths, args.output_dir, model, device, args.batch_size)
|
||||||
|
for image_path, output_path, error in results:
|
||||||
|
if error:
|
||||||
|
print(f"处理失败: {image_path}, 错误: {error}")
|
||||||
|
else:
|
||||||
|
print(f"处理成功: {image_path} -> {output_path}")
|
||||||
|
else:
|
||||||
|
print(f"处理图片: {image_paths[0]}")
|
||||||
|
t1 = time.time()
|
||||||
|
output_path = remove_background(image_paths[0], args.output_dir, model, device)
|
||||||
|
print(f"处理完成,用时 {time.time()-t1:.1f} 秒")
|
||||||
|
print(f"输出文件: {output_path}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"错误: {e}")
|
||||||
|
finally:
|
||||||
|
# 清理显存
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
0
utils/utils.py
Normal file
0
utils/utils.py
Normal file
Loading…
x
Reference in New Issue
Block a user