初始提交

This commit is contained in:
jingrow 2025-05-05 05:37:53 +00:00
commit e2575d5e61
8 changed files with 450 additions and 0 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 1005 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1005 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1005 KiB

145
task/remove_background.py Normal file
View File

@ -0,0 +1,145 @@
import os
import argparse
import requests
import tempfile
from urllib.parse import urlparse
from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
import time
# Set the float32 matmul precision mode to "high".
torch.set_float32_matmul_precision("high")
# Load the segmentation model once, at import time, and move it to the GPU.
# NOTE: trust_remote_code=True lets transformers run the custom model code
# shipped in the "briaai/RMBG-2.0" repository.
model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-2.0", trust_remote_code=True)
model.to("cuda")
model.eval()  # switch the model to evaluation mode
# Preprocessing pipeline applied to every input image: resize to 1024x1024,
# convert to a tensor, then normalize each of the three channels with the
# given mean / std values.
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
def is_valid_url(url):
    """Return True if ``url`` parses with both a scheme and a network location.

    Plain file-system paths (``/tmp/a.png``, ``images/a.png``) have no
    network location and therefore yield False.

    Args:
        url: Candidate value, normally a string.

    Returns:
        bool: True when both the parsed ``scheme`` and ``netloc`` are
        non-empty. Input that ``urlparse`` cannot handle yields False
        instead of raising.
    """
    try:
        result = urlparse(url)
    except (ValueError, TypeError, AttributeError):
        # ValueError: malformed URL (e.g. an unterminated IPv6 literal).
        # TypeError / AttributeError: non-string input.
        # A bare ``except`` would also swallow KeyboardInterrupt/SystemExit.
        return False
    return bool(result.scheme and result.netloc)
def download_image(url, timeout=30):
    """Download the resource at ``url`` into a temporary ``.png``-suffixed file.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds ``requests`` waits for the connection / for data
            from the server before giving up. Defaults to 30; without a
            timeout an unresponsive server would block forever.

    Returns:
        str: Path of the temporary file. The file is created with
        ``delete=False``, so the caller is responsible for removing it.

    Raises:
        requests.RequestException: On connection errors, timeouts, or an
            error HTTP status (via ``raise_for_status``).
        OSError: If the temporary file cannot be written.
    """
    # Context-manage the response so the connection is released even when
    # the status check or the streaming loop fails.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        # delete=False: the file has to outlive this function.
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        try:
            # Write through the handle we already own (and close it on exit)
            # instead of reopening the file by name and leaking this one.
            with temp_file:
                for chunk in response.iter_content(chunk_size=8192):
                    temp_file.write(chunk)
        except BaseException:
            # Do not leave a partial download behind; the handle is already
            # closed here, so the unlink also works on Windows.
            os.unlink(temp_file.name)
            raise
    return temp_file.name
def process_image(image):
    """Attach a predicted foreground mask to ``image`` as its alpha channel.

    Uses the module-level ``model`` and ``transform_image`` and runs the
    model on the ``"cuda"`` device.

    Args:
        image: PIL image; expected to have three channels, since a
            3-channel normalization is applied. It is modified in place by
            ``putalpha``.

    Returns:
        The same PIL image object that was passed in, now carrying the
        predicted mask as alpha channel.
    """
    image_size = image.size
    # Preprocess into a batch of one and move it to the GPU.
    input_images = transform_image(image).unsqueeze(0).to("cuda")
    # Mixed-precision inference without gradient tracking. ``torch.autocast``
    # is the supported spelling of the deprecated ``torch.cuda.amp.autocast``.
    with torch.no_grad(), torch.autocast(device_type="cuda"):
        # The last element of the model output is used as the mask logits.
        preds = model(input_images)[-1].sigmoid().cpu()
    # Under autocast the result may be a reduced-precision tensor; convert to
    # float32 so the PIL conversion does not depend on half-precision CPU ops.
    pred = preds[0].squeeze().float()
    pred_pil = transforms.ToPILImage()(pred)
    # Scale the mask back to the original image size.
    mask = pred_pil.resize(image_size)
    image.putalpha(mask)
    return image
def remove_background(image_path, output_dir=None):
    """Remove the background from one image and save the result as a PNG.

    Args:
        image_path: Local file path or URL of the input image. URLs are
            downloaded to a temporary file that is deleted before returning.
        output_dir: Directory for the result; defaults to the current
            working directory. It is created if it does not exist.

    Returns:
        str: Path of the written ``<name>_nobg.png`` file.

    Raises:
        RuntimeError: If a URL input could not be downloaded. The original
            error is chained as ``__cause__``.
        FileNotFoundError: If the input file does not exist.
    """
    temp_file = None
    try:
        if is_valid_url(image_path):
            try:
                temp_file = download_image(image_path)
            except Exception as e:
                # Chain the cause so the real network error stays visible.
                raise RuntimeError(f"下载图片失败: {e}") from e
            image_path = temp_file
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"输入图像不存在: {image_path}")
        if output_dir is None:
            output_dir = os.getcwd()
        os.makedirs(output_dir, exist_ok=True)
        # Close the source file as soon as the pixels are converted, so the
        # temporary file can be deleted afterwards (an open handle blocks
        # deletion on Windows).
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        image_no_bg = process_image(image)
        # NOTE(review): for URL inputs the output name is derived from the
        # temporary file name, not from the URL - confirm this is intended.
        name, _ = os.path.splitext(os.path.basename(image_path))
        output_file_path = os.path.join(output_dir, f"{name}_nobg.png")
        image_no_bg.save(output_file_path)
        return output_file_path
    finally:
        # Best-effort removal of the downloaded temporary file; a failure to
        # delete it must not mask the real result or error.
        if temp_file is not None:
            try:
                os.unlink(temp_file)
            except OSError:
                pass
def main():
    """Command-line entry point.

    Parses ``--image`` and ``--output-dir``, runs ``remove_background`` on
    the given image (or on a built-in sample URL when none is given) and
    prints the elapsed time and output path. Any error is printed rather
    than propagated; the CUDA cache is emptied on the way out.
    """
    arg_parser = argparse.ArgumentParser(description='移除图像背景')
    arg_parser.add_argument('--image', help='输入图像的路径或URL')
    arg_parser.add_argument('--output-dir', default='output', help='输出目录')
    options = arg_parser.parse_args()
    try:
        source = options.image
        if source is None:
            # No image supplied: fall back to the sample URL.
            print("未提供图片参数使用示例图片URL进行测试...")
            source = "http://test001.jingrow.com/files/3T8A4426.JPG"
        print(f"处理图片: {source}")
        started = time.time()
        result_path = remove_background(source, options.output_dir)
        elapsed = time.time() - started
        print(f"处理完成,用时 {elapsed:.1f}")
        print(f"输出文件: {result_path}")
    except Exception as exc:
        print(f"错误: {exc}")
    finally:
        # Release cached GPU memory.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


if __name__ == "__main__":
    main()

4
task/requirements.txt Normal file
View File

@ -0,0 +1,4 @@
requests>=2.31.0
Pillow>=10.0.0
torch>=2.0.0
transformers>=4.30.0

106
task/rmbg.py Normal file
View File

@ -0,0 +1,106 @@
import os
import gradio as gr
from gradio_imageslider import ImageSlider
from loadimg import load_img
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
# Set the float32 matmul precision mode; index 0 selects "high" from the list.
torch.set_float32_matmul_precision(["high", "highest"][0])
# Load the segmentation model at import time and move it to the GPU.
# NOTE: trust_remote_code=True lets transformers run the custom model code
# shipped in the "briaai/RMBG-2.0" repository.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "briaai/RMBG-2.0", trust_remote_code=True
)
birefnet.to("cuda")
# Preprocessing pipeline: resize to 1024x1024, convert to a tensor, then
# normalize each of the three channels with the given mean / std values.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)
# Folder that receives the processed images; created at import if missing.
output_folder = 'output_images'
if not os.path.exists(output_folder):
    os.makedirs(output_folder)
def fn(image):
    """Handler for the image / URL tabs: remove the background of ``image``.

    Args:
        image: Whatever the input component delivers; it is handed to
            ``load_img`` to obtain a PIL image.

    Returns:
        tuple: ``((processed, original), path)`` - the processed image paired
        with an untouched copy of the input, plus the path of the saved PNG.
    """
    rgb = load_img(image, output_type="pil").convert("RGB")
    # Keep an untouched copy: ``process`` adds the alpha channel to the very
    # image object it is given.
    original = rgb.copy()
    cutout = process(rgb)
    # NOTE(review): every call writes to the same file name, so overlapping
    # requests overwrite each other's output - confirm this is acceptable.
    saved_path = os.path.join(output_folder, "no_bg_image.png")
    cutout.save(saved_path)
    return (cutout, original), saved_path
@spaces.GPU
def process(image):
    """Predict a mask for ``image`` and attach it as the alpha channel.

    Args:
        image: PIL image; it is modified in place by ``putalpha``.

    Returns:
        The same PIL image object, now with the predicted mask as alpha.
    """
    original_size = image.size
    batch = transform_image(image).unsqueeze(0).to("cuda")
    # Inference only: no gradient tracking needed.
    with torch.no_grad():
        outputs = birefnet(batch)
        # The last element of the model output is used as the mask logits.
        probabilities = outputs[-1].sigmoid().cpu()
    mask_tensor = probabilities[0].squeeze()
    # Convert to a PIL image and scale back to the input size.
    alpha = transforms.ToPILImage()(mask_tensor).resize(original_size)
    image.putalpha(alpha)
    return image
def process_file(f):
    """Remove the background of the image file ``f`` and save it as PNG.

    The result is written next to the input, with the input's extension
    replaced by ``.png``.

    Args:
        f: Path of the input image file.

    Returns:
        str: Path of the written PNG file.
    """
    # ``splitext`` only inspects the final path component, so a dot in a
    # directory name (".../tmp.x/photo") can no longer truncate the path the
    # way ``f.rsplit(".", 1)`` did.
    # NOTE(review): when ``f`` already ends in ".png" the output path equals
    # the input path and the input file is overwritten.
    name_path = os.path.splitext(f)[0] + ".png"
    im = load_img(f, output_type="pil")
    im = im.convert("RGB")
    transparent = process(im)
    transparent.save(name_path)
    return name_path
# Output sliders for the two tabs; each receives the (processed, original)
# image pair returned by ``fn``.
slider1 = ImageSlider(label="RMBG-2.0", type="pil")
slider2 = ImageSlider(label="RMBG-2.0", type="pil")
# Input components: an image upload, an image upload that delivers a file
# path, and a text box for an image URL.
image = gr.Image(label="Upload an image")
image2 = gr.Image(label="Upload an image",type="filepath")
text = gr.Textbox(label="Paste an image URL")
png_file = gr.File(label="output png file")
# Example inputs. Note: despite its name, ``chameleon`` holds "giraffe.jpg".
chameleon = load_img("giraffe.jpg", output_type="pil")
url = "http://farm9.staticflickr.com/8488/8228323072_76eeddfea3_z.jpg"
# tab1: image upload -> fn; tab2: URL text -> fn; tab3: file path -> process_file.
tab1 = gr.Interface(
    fn, inputs=image, outputs=[slider1, gr.File(label="output png file")], examples=[chameleon], api_name="image"
)
tab2 = gr.Interface(fn, inputs=text, outputs=[slider2, gr.File(label="output png file")], examples=[url], api_name="text")
tab3 = gr.Interface(process_file, inputs=image2, outputs=png_file, examples=["giraffe.jpg"], api_name="png")
# Only tab1 and tab2 are placed in the tabbed interface; tab3 is constructed
# above but not included here. The title is an HTML snippet built from
# adjacent string literals.
demo = gr.TabbedInterface(
    [tab1, tab2], ["input image", "input url"], title = (
        "RMBG-2.0 for background removal <br>"
        "<span style='font-size:16px; font-weight:300;'>"
        "Background removal model developed by "
        "<a href='https://bria.ai' target='_blank'>BRIA.AI</a>, trained on a carefully selected dataset,<br> "
        "and is available as an open-source model for non-commercial use.</span><br>"
        "<span style='font-size:16px; font-weight:500;'> For testing upload your image and wait.<br>"
        "<a href='https://go.bria.ai/3ZCBTLH' target='_blank'>Commercial use license</a> | "
        "<a href='https://huggingface.co/briaai/RMBG-2.0' target='_blank'>Model card</a> | "
        "<a href='https://blog.bria.ai/brias-new-state-of-the-art-remove-background-2.0-outperforms-the-competition' target='_blank'>Blog</a>"
        "</span><br>"
        "<span style='font-size:16px; font-weight:300;'>"
        "API Endpoint available on: "
        "<a href='https://platform.bria.ai/console/api/image-editing' target='_blank'>Bria.ai</a>, "
        "<a href='https://fal.ai/models/fal-ai/bria/background/remove' target='_blank'>fal.ai</a><br>"
        "ComfyUI node is available here: "
        "<a href='https://github.com/Bria-AI/ComfyUI-BRIA-API' target='_blank'>ComfyUI Node</a><br>"
        "Purchase commercial weigths for commercial use: "
        "<a href='https://go.bria.ai/3D5EGp0' target='_blank'>here</a>"
        "</span>"
    )
)
# Start the web app when run as a script.
if __name__ == "__main__":
    demo.launch(show_error=True)

195
task/rmbg2.py Normal file
View File

@ -0,0 +1,195 @@
import os
import argparse
import requests
import tempfile
from urllib.parse import urlparse
from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
import gc
# Silence UserWarning and FutureWarning messages.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
# Torch precision and performance settings.
torch.set_float32_matmul_precision("high")
torch.backends.cudnn.benchmark = True  # enable cuDNN auto-tuning
torch.backends.cuda.matmul.allow_tf32 = True  # allow TF32 for CUDA matmuls
torch.backends.cudnn.allow_tf32 = True  # allow TF32 inside cuDNN
# Sample image URL, used by main() when no --image argument is given.
image_url = "http://test001.jingrow.com/files/3T8A4426.JPG"
def is_valid_url(url):
    """Return True if ``url`` parses with both a scheme and a network location.

    Plain file-system paths (``/tmp/a.png``, ``images/a.png``) have no
    network location and therefore yield False.

    Args:
        url: Candidate value, normally a string.

    Returns:
        bool: True when both the parsed ``scheme`` and ``netloc`` are
        non-empty. Input that ``urlparse`` cannot handle yields False
        instead of raising.
    """
    try:
        result = urlparse(url)
    except (ValueError, TypeError, AttributeError):
        # ValueError: malformed URL (e.g. an unterminated IPv6 literal).
        # TypeError / AttributeError: non-string input.
        # A bare ``except`` would also swallow KeyboardInterrupt/SystemExit.
        return False
    return bool(result.scheme and result.netloc)
def download_image(url, timeout=30):
    """Download the resource at ``url`` into a temporary ``.png``-suffixed file.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds ``requests`` waits for the connection / for data
            from the server before giving up. Defaults to 30; without a
            timeout an unresponsive server would block forever.

    Returns:
        str: Path of the temporary file. The file is created with
        ``delete=False``, so the caller is responsible for removing it.

    Raises:
        requests.RequestException: On connection errors, timeouts, or an
            error HTTP status (via ``raise_for_status``).
        OSError: If the temporary file cannot be written.
    """
    # Context-manage the response so the connection is released even when
    # the status check or the streaming loop fails.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        # delete=False: the file has to outlive this function.
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        try:
            # Write through the handle we already own (and close it on exit)
            # instead of reopening the file by name and leaking this one.
            with temp_file:
                for chunk in response.iter_content(chunk_size=8192):
                    temp_file.write(chunk)
        except BaseException:
            # Do not leave a partial download behind; the handle is already
            # closed here, so the unlink also works on Windows.
            os.unlink(temp_file.name)
            raise
    return temp_file.name
def process_image(image, model, device):
    """Attach a predicted foreground mask to ``image`` as its alpha channel.

    Args:
        image: PIL image; expected to have three channels, since a
            3-channel normalization is applied. It is modified in place by
            ``putalpha``.
        model: Callable segmentation model; the last element of its output
            is used as the mask logits.
        device: Device the input tensor is moved to (``torch.device`` or
            device string).

    Returns:
        The same PIL image object that was passed in, now carrying the
        predicted mask as alpha channel.
    """
    image_size = image.size
    # Preprocessing: resize, convert to tensor, normalize per channel.
    transform_image = transforms.Compose([
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    input_images = transform_image(image).unsqueeze(0).to(device)
    # Mixed precision only makes sense when the model actually runs on CUDA.
    use_amp = device is not None and torch.device(device).type == "cuda"
    # ``torch.autocast`` is the supported spelling of the deprecated
    # ``torch.cuda.amp.autocast``; gradients are not needed for inference.
    with torch.no_grad(), torch.autocast(device_type="cuda", enabled=use_amp):
        preds = model(input_images)[-1].sigmoid().cpu()
    # Under autocast the result may be a reduced-precision tensor; convert to
    # float32 so the PIL conversion does not depend on half-precision CPU ops.
    pred = preds[0].squeeze().float()
    pred_pil = transforms.ToPILImage()(pred)
    # Scale the mask back to the original image size.
    mask = pred_pil.resize(image_size)
    image.putalpha(mask)
    return image
def remove_background(image_path, output_dir=None, model=None, device=None):
    """Remove the background from one image and save the result as a PNG.

    Args:
        image_path: Local file path or URL of the input image. URLs are
            downloaded to a temporary file that is deleted before returning.
        output_dir: Directory for the result; defaults to the current
            working directory. It is created if it does not exist.
        model: Segmentation model, passed straight to ``process_image``.
            Despite the ``None`` default a model must be supplied.
        device: Device passed straight to ``process_image``.

    Returns:
        str: Path of the written ``<name>_nobg.png`` file.

    Raises:
        RuntimeError: If a URL input could not be downloaded. The original
            error is chained as ``__cause__``.
        FileNotFoundError: If the input file does not exist.
    """
    temp_file = None
    try:
        if is_valid_url(image_path):
            try:
                temp_file = download_image(image_path)
            except Exception as e:
                # Chain the cause so the real network error stays visible.
                raise RuntimeError(f"下载图片失败: {e}") from e
            image_path = temp_file
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"输入图像不存在: {image_path}")
        if output_dir is None:
            output_dir = os.getcwd()
        os.makedirs(output_dir, exist_ok=True)
        # Close the source file as soon as the pixels are converted, so the
        # temporary file can be deleted afterwards (an open handle blocks
        # deletion on Windows).
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        image_no_bg = process_image(image, model, device)
        # NOTE(review): for URL inputs the output name is derived from the
        # temporary file name, not from the URL - confirm this is intended.
        name, _ = os.path.splitext(os.path.basename(image_path))
        output_file_path = os.path.join(output_dir, f"{name}_nobg.png")
        image_no_bg.save(output_file_path)
        return output_file_path
    finally:
        # Best-effort removal of the downloaded temporary file; a failure to
        # delete it must not mask the real result or error.
        if temp_file is not None:
            try:
                os.unlink(temp_file)
            except OSError:
                pass
def process_batch(image_paths, output_dir, model, device, max_workers=4):
    """Run ``remove_background`` over several images using a thread pool.

    Args:
        image_paths: Iterable of input paths or URLs.
        output_dir: Directory handed to ``remove_background``.
        model: Model handed to ``remove_background``.
        device: Device handed to ``remove_background``.
        max_workers: Number of worker threads in the pool.

    Returns:
        list: One ``(image_path, output_path, error)`` tuple per input, in
        input order. On success ``error`` is None; on failure
        ``output_path`` is None and ``error`` is the exception text.
    """
    outcomes = []
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # Submit everything first so the jobs can run concurrently.
        pending = [
            (path, pool.submit(remove_background, path, output_dir, model, device))
            for path in image_paths
        ]
        # Collect in submission order; a failing job is recorded, not raised.
        for path, task in pending:
            try:
                outcomes.append((path, task.result(), None))
            except Exception as exc:
                outcomes.append((path, None, str(exc)))
    return outcomes
def main():
    """Command-line entry point.

    Parses the arguments, loads the model onto CUDA when available (CPU
    otherwise) and processes the given image(s): several inputs go through
    ``process_batch``, a single input through ``remove_background``. With no
    ``--image`` argument the module-level sample ``image_url`` is used.
    Errors are printed rather than propagated.
    """
    cli = argparse.ArgumentParser(description='移除图像背景')
    cli.add_argument('--image', nargs='+', help='输入图像的路径或URL支持多个')
    cli.add_argument('--output-dir', default='output', help='输出目录')
    cli.add_argument('--model-path', default='briaai/RMBG-2.0', help='模型路径')
    # Note: this value is passed to process_batch as its max_workers argument.
    cli.add_argument('--batch-size', type=int, default=4, help='批处理大小')
    options = cli.parse_args()

    # Pick the device and report it.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")
    if torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        total_bytes = torch.cuda.get_device_properties(0).total_memory
        print(f"GPU: {gpu_name}")
        print(f"显存总量: {total_bytes / 1024**3:.1f}GB")

    try:
        # Load the model.
        print("正在加载模型...")
        load_started = time.time()
        model = AutoModelForImageSegmentation.from_pretrained(options.model_path, trust_remote_code=True)
        model = model.to(device)
        model.eval()  # evaluation mode
        print(f"模型加载完成,用时 {time.time()-load_started:.1f}")

        # Process the image(s).
        sources = options.image if options.image else [image_url]
        if len(sources) > 1:
            print(f"开始批量处理 {len(sources)} 张图片...")
            outcomes = process_batch(sources, options.output_dir, model, device, options.batch_size)
            for src, dst, failure in outcomes:
                if failure:
                    print(f"处理失败: {src}, 错误: {failure}")
                else:
                    print(f"处理成功: {src} -> {dst}")
        else:
            print(f"处理图片: {sources[0]}")
            run_started = time.time()
            dst = remove_background(sources[0], options.output_dir, model, device)
            print(f"处理完成,用时 {time.time()-run_started:.1f}")
            print(f"输出文件: {dst}")
    except Exception as exc:
        print(f"错误: {exc}")
    finally:
        # Release cached GPU memory, then run a garbage collection pass.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()


if __name__ == "__main__":
    main()

0
utils/utils.py Normal file
View File