Initial commit
commit e2575d5e61
BIN  task/output/tmp1uqgmfm9_nobg.png  Normal file
Binary file not shown (1005 KiB).
BIN  task/output/tmp4hma0__o_nobg.png  Normal file
Binary file not shown (1005 KiB).
BIN  task/output/tmpn8r2g6do_nobg.png  Normal file
Binary file not shown (1005 KiB).
145  task/remove_background.py  Normal file
@@ -0,0 +1,145 @@
import os
import argparse
import requests
import tempfile
from urllib.parse import urlparse
from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
import time

# Set torch matmul precision
torch.set_float32_matmul_precision("high")

# Initialize the model
model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-2.0", trust_remote_code=True)
model.to("cuda")
model.eval()  # Switch to evaluation mode

# Define the image transforms
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

def is_valid_url(url):
    """Check whether a URL is valid."""
    try:
        result = urlparse(url)
        return all([result.scheme, result.netloc])
    except Exception:
        return False

def download_image(url):
    """Download an image from a URL into a temporary file."""
    response = requests.get(url, stream=True)
    response.raise_for_status()

    # Create a temporary file
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
    with open(temp_file.name, 'wb') as f:
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)
    return temp_file.name

def process_image(image):
    """Process an image and remove its background."""
    image_size = image.size
    # Transform the image
    input_images = transform_image(image).unsqueeze(0).to("cuda")

    # Use mixed precision for speed
    with torch.cuda.amp.autocast():
        # Predict
        with torch.no_grad():
            preds = model(input_images)[-1].sigmoid().cpu()

    # Post-process the prediction
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image_size)

    # Add the alpha channel
    image.putalpha(mask)
    return image

def remove_background(image_path, output_dir=None):
    """
    Remove the background from an image.

    Args:
        image_path: Path or URL of the input image.
        output_dir: Output directory; the current directory is used if not given.

    Returns:
        Path of the processed image.
    """
    temp_file = None
    try:
        # Check whether the input is a URL
        if is_valid_url(image_path):
            try:
                # Download the image into a temporary file
                temp_file = download_image(image_path)
                image_path = temp_file
            except Exception as e:
                raise Exception(f"Failed to download image: {e}")

        # Verify that the input file exists
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Input image does not exist: {image_path}")

        # Fall back to the current directory if no output directory was given
        if output_dir is None:
            output_dir = os.getcwd()

        # Make sure the output directory exists
        os.makedirs(output_dir, exist_ok=True)

        # Load and process the image
        image = Image.open(image_path).convert("RGB")
        image_no_bg = process_image(image)

        # Save the result
        filename = os.path.basename(image_path)
        name, ext = os.path.splitext(filename)
        output_file_path = os.path.join(output_dir, f"{name}_nobg.png")
        image_no_bg.save(output_file_path)

        return output_file_path
    finally:
        # Clean up the temporary file
        if temp_file and os.path.exists(temp_file):
            try:
                os.unlink(temp_file)
            except OSError:
                pass

def main():
    parser = argparse.ArgumentParser(description='Remove image backgrounds')
    parser.add_argument('--image', help='Path or URL of the input image')
    parser.add_argument('--output-dir', default='output', help='Output directory')
    args = parser.parse_args()

    try:
        # Fall back to the example URL if no image argument was given
        if args.image is None:
            print("No --image given, testing with the example image URL...")
            args.image = "http://test001.jingrow.com/files/3T8A4426.JPG"

        print(f"Processing image: {args.image}")
        t0 = time.time()
        output_path = remove_background(args.image, args.output_dir)
        print(f"Done in {time.time()-t0:.1f} s")
        print(f"Output file: {output_path}")
    except Exception as e:
        print(f"Error: {e}")
    finally:
        # Free GPU memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

if __name__ == "__main__":
    main()
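
A minimal usage sketch for the script above, calling it from Python rather than the CLI. It assumes the repository root is on the Python path (so task.remove_background is importable) and that a CUDA device is available, since the module loads RMBG-2.0 onto "cuda" at import time; the input filename and output directory below are placeholders, not files from this commit:

    # Sketch only: reuse remove_background() from another script.
    from task.remove_background import remove_background

    # "photo.jpg" and "results" are placeholder names used for illustration.
    output_path = remove_background("photo.jpg", output_dir="results")
    print(output_path)  # e.g. results/photo_nobg.png

Equivalently, the script can be run directly: python task/remove_background.py --image <path-or-URL> --output-dir output.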
4  task/requirements.txt  Normal file
@@ -0,0 +1,4 @@
requests>=2.31.0
Pillow>=10.0.0
torch>=2.0.0
transformers>=4.30.0
106  task/rmbg.py  Normal file
@@ -0,0 +1,106 @@
import os
import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms

torch.set_float32_matmul_precision(["high", "highest"][0])

birefnet = AutoModelForImageSegmentation.from_pretrained(
    "briaai/RMBG-2.0", trust_remote_code=True
)
birefnet.to("cuda")
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

output_folder = 'output_images'
if not os.path.exists(output_folder):
    os.makedirs(output_folder)

def fn(image):
    im = load_img(image, output_type="pil")
    im = im.convert("RGB")
    origin = im.copy()
    image = process(im)
    image_path = os.path.join(output_folder, "no_bg_image.png")
    image.save(image_path)
    return (image, origin), image_path

@spaces.GPU
def process(image):
    image_size = image.size
    input_images = transform_image(image).unsqueeze(0).to("cuda")
    # Prediction
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image_size)
    image.putalpha(mask)
    return image

def process_file(f):
    name_path = f.rsplit(".", 1)[0] + ".png"
    im = load_img(f, output_type="pil")
    im = im.convert("RGB")
    transparent = process(im)
    transparent.save(name_path)
    return name_path

slider1 = ImageSlider(label="RMBG-2.0", type="pil")
slider2 = ImageSlider(label="RMBG-2.0", type="pil")
image = gr.Image(label="Upload an image")
image2 = gr.Image(label="Upload an image", type="filepath")
text = gr.Textbox(label="Paste an image URL")
png_file = gr.File(label="output png file")


chameleon = load_img("giraffe.jpg", output_type="pil")

url = "http://farm9.staticflickr.com/8488/8228323072_76eeddfea3_z.jpg"

tab1 = gr.Interface(
    fn, inputs=image, outputs=[slider1, gr.File(label="output png file")], examples=[chameleon], api_name="image"
)

tab2 = gr.Interface(fn, inputs=text, outputs=[slider2, gr.File(label="output png file")], examples=[url], api_name="text")
tab3 = gr.Interface(process_file, inputs=image2, outputs=png_file, examples=["giraffe.jpg"], api_name="png")


demo = gr.TabbedInterface(
    [tab1, tab2], ["input image", "input url"], title=(
        "RMBG-2.0 for background removal <br>"
        "<span style='font-size:16px; font-weight:300;'>"
        "Background removal model developed by "
        "<a href='https://bria.ai' target='_blank'>BRIA.AI</a>, trained on a carefully selected dataset,<br> "
        "and is available as an open-source model for non-commercial use.</span><br>"
        "<span style='font-size:16px; font-weight:500;'> For testing upload your image and wait.<br>"
        "<a href='https://go.bria.ai/3ZCBTLH' target='_blank'>Commercial use license</a> | "
        "<a href='https://huggingface.co/briaai/RMBG-2.0' target='_blank'>Model card</a> | "
        "<a href='https://blog.bria.ai/brias-new-state-of-the-art-remove-background-2.0-outperforms-the-competition' target='_blank'>Blog</a>"
        "</span><br>"
        "<span style='font-size:16px; font-weight:300;'>"
        "API Endpoint available on: "
        "<a href='https://platform.bria.ai/console/api/image-editing' target='_blank'>Bria.ai</a>, "
        "<a href='https://fal.ai/models/fal-ai/bria/background/remove' target='_blank'>fal.ai</a><br>"
        "ComfyUI node is available here: "
        "<a href='https://github.com/Bria-AI/ComfyUI-BRIA-API' target='_blank'>ComfyUI Node</a><br>"
        "Purchase commercial weights for commercial use: "
        "<a href='https://go.bria.ai/3D5EGp0' target='_blank'>here</a>"
        "</span>"
    )
)

if __name__ == "__main__":
    demo.launch(show_error=True)
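
Once demo.launch() is running (by default on the local address Gradio prints, typically http://127.0.0.1:7860), the app can also be queried from another process. A hedged sketch using gradio_client, which is an assumed extra package and not a dependency declared in this commit; the endpoint name follows from tab2's api_name="text", and the image URL is the example already used above:

    # Sketch only: call the running Gradio demo over its HTTP API.
    from gradio_client import Client

    client = Client("http://127.0.0.1:7860/")  # address printed by demo.launch()
    result = client.predict(
        "http://farm9.staticflickr.com/8488/8228323072_76eeddfea3_z.jpg",  # URL input of tab2
        api_name="/text",
    )
    # fn() has two outputs, so `result` is a pair: the before/after images and the saved PNG path;
    # the exact shape of the image-slider output depends on the installed gradio version.
    print(result[-1])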
195  task/rmbg2.py  Normal file
@@ -0,0 +1,195 @@
import os
import argparse
import requests
import tempfile
from urllib.parse import urlparse
from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
import gc

# Silence unnecessary warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# Torch precision and optimization settings
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.benchmark = True  # Enable cuDNN auto-tuning
torch.backends.cuda.matmul.allow_tf32 = True  # Allow TF32 for matmul
torch.backends.cudnn.allow_tf32 = True  # Allow TF32 in cuDNN

# Example image URL
image_url = "http://test001.jingrow.com/files/3T8A4426.JPG"

def is_valid_url(url):
    """Check whether a URL is valid."""
    try:
        result = urlparse(url)
        return all([result.scheme, result.netloc])
    except Exception:
        return False

def download_image(url):
    """Download an image from a URL into a temporary file."""
    response = requests.get(url, stream=True)
    response.raise_for_status()

    # Create a temporary file
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
    with open(temp_file.name, 'wb') as f:
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)
    return temp_file.name

def process_image(image, model, device):
    """Process an image and remove its background."""
    image_size = image.size
    # Transform the image
    transform_image = transforms.Compose([
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # Use mixed precision for speed
    with torch.cuda.amp.autocast():
        input_images = transform_image(image).unsqueeze(0).to(device)
        # Predict
        with torch.no_grad():
            preds = model(input_images)[-1].sigmoid().cpu()

    # Post-process the prediction
    pred = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred)
    mask = pred_pil.resize(image_size)

    # Add the alpha channel
    image.putalpha(mask)
    return image

def remove_background(image_path, output_dir=None, model=None, device=None):
    """
    Remove the background from an image.

    Args:
        image_path: Path or URL of the input image.
        output_dir: Output directory; the current directory is used if not given.
        model: Model instance.
        device: Device to run on.

    Returns:
        Path of the processed image.
    """
    temp_file = None
    try:
        # Check whether the input is a URL
        if is_valid_url(image_path):
            try:
                # Download the image into a temporary file
                temp_file = download_image(image_path)
                image_path = temp_file
            except Exception as e:
                raise Exception(f"Failed to download image: {e}")

        # Verify that the input file exists
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Input image does not exist: {image_path}")

        # Fall back to the current directory if no output directory was given
        if output_dir is None:
            output_dir = os.getcwd()

        # Make sure the output directory exists
        os.makedirs(output_dir, exist_ok=True)

        # Load and process the image
        image = Image.open(image_path).convert("RGB")
        image_no_bg = process_image(image, model, device)

        # Save the result
        filename = os.path.basename(image_path)
        name, ext = os.path.splitext(filename)
        output_file_path = os.path.join(output_dir, f"{name}_nobg.png")
        image_no_bg.save(output_file_path)

        return output_file_path
    finally:
        # Clean up the temporary file
        if temp_file and os.path.exists(temp_file):
            try:
                os.unlink(temp_file)
            except OSError:
                pass

def process_batch(image_paths, output_dir, model, device, max_workers=4):
    """Process a batch of images."""
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = []
        for image_path in image_paths:
            future = executor.submit(remove_background, image_path, output_dir, model, device)
            futures.append((image_path, future))

        results = []
        for image_path, future in futures:
            try:
                output_path = future.result()
                results.append((image_path, output_path, None))
            except Exception as e:
                results.append((image_path, None, str(e)))
        return results

def main():
    parser = argparse.ArgumentParser(description='Remove image backgrounds')
    parser.add_argument('--image', nargs='+', help='Path(s) or URL(s) of the input images (multiple allowed)')
    parser.add_argument('--output-dir', default='output', help='Output directory')
    parser.add_argument('--model-path', default='briaai/RMBG-2.0', help='Model path')
    parser.add_argument('--batch-size', type=int, default=4, help='Batch size (worker threads)')
    args = parser.parse_args()

    # Select the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f}GB")

    try:
        # Load the model
        print("Loading model...")
        t0 = time.time()
        model = AutoModelForImageSegmentation.from_pretrained(args.model_path, trust_remote_code=True)
        model = model.to(device)
        model.eval()  # Switch to evaluation mode
        print(f"Model loaded in {time.time()-t0:.1f} s")

        # Process the images
        image_paths = args.image if args.image else [image_url]
        if len(image_paths) > 1:
            print(f"Batch processing {len(image_paths)} images...")
            results = process_batch(image_paths, args.output_dir, model, device, args.batch_size)
            for image_path, output_path, error in results:
                if error:
                    print(f"Failed: {image_path}, error: {error}")
                else:
                    print(f"Succeeded: {image_path} -> {output_path}")
        else:
            print(f"Processing image: {image_paths[0]}")
            t1 = time.time()
            output_path = remove_background(image_paths[0], args.output_dir, model, device)
            print(f"Done in {time.time()-t1:.1f} s")
            print(f"Output file: {output_path}")

    except Exception as e:
        print(f"Error: {e}")
    finally:
        # Free GPU memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

if __name__ == "__main__":
    main()
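
The batch path can also be driven programmatically instead of through the CLI. A minimal sketch, assuming the repository root is on the Python path (so task.rmbg2 is importable); the image paths/URLs below are placeholders, while the model id matches the script's default:

    # Sketch: load the model once, then reuse it across a batch of images.
    import torch
    from transformers import AutoModelForImageSegmentation
    from task.rmbg2 import process_batch  # importing only sets warning filters and precision flags

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-2.0", trust_remote_code=True)
    model = model.to(device).eval()

    sources = ["photo1.jpg", "http://example.com/photo2.jpg"]  # placeholder inputs
    for src, out, err in process_batch(sources, "output", model, device, max_workers=2):
        print(f"{src} -> {out if out else err}")

The same run maps onto the CLI as python task/rmbg2.py --image <path1> <path2> --output-dir output --batch-size 2.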
0  utils/utils.py  Normal file