commit e2575d5e61240b53bc890033f5e3b92f69afc7c7 Author: jingrow Date: Mon May 5 05:37:53 2025 +0000 初始提交 diff --git a/task/output/tmp1uqgmfm9_nobg.png b/task/output/tmp1uqgmfm9_nobg.png new file mode 100644 index 0000000..bef75db Binary files /dev/null and b/task/output/tmp1uqgmfm9_nobg.png differ diff --git a/task/output/tmp4hma0__o_nobg.png b/task/output/tmp4hma0__o_nobg.png new file mode 100644 index 0000000..bef75db Binary files /dev/null and b/task/output/tmp4hma0__o_nobg.png differ diff --git a/task/output/tmpn8r2g6do_nobg.png b/task/output/tmpn8r2g6do_nobg.png new file mode 100644 index 0000000..9045f12 Binary files /dev/null and b/task/output/tmpn8r2g6do_nobg.png differ diff --git a/task/remove_background.py b/task/remove_background.py new file mode 100644 index 0000000..abb9562 --- /dev/null +++ b/task/remove_background.py @@ -0,0 +1,145 @@ +import os +import argparse +import requests +import tempfile +from urllib.parse import urlparse +from PIL import Image +import torch +from torchvision import transforms +from transformers import AutoModelForImageSegmentation +import time + +# 设置torch精度 +torch.set_float32_matmul_precision("high") + +# 初始化模型 +model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-2.0", trust_remote_code=True) +model.to("cuda") +model.eval() # 设置为评估模式 + +# 定义图像转换 +transform_image = transforms.Compose([ + transforms.Resize((1024, 1024)), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), +]) + +def is_valid_url(url): + """验证URL是否有效""" + try: + result = urlparse(url) + return all([result.scheme, result.netloc]) + except: + return False + +def download_image(url): + """从URL下载图片到临时文件""" + response = requests.get(url, stream=True) + response.raise_for_status() + + # 创建临时文件 + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png') + with open(temp_file.name, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + return temp_file.name + +def process_image(image): + """处理图像,移除背景""" + image_size = image.size + # 转换图像 + input_images = transform_image(image).unsqueeze(0).to("cuda") + + # 使用半精度加速 + with torch.cuda.amp.autocast(): + # 预测 + with torch.no_grad(): + preds = model(input_images)[-1].sigmoid().cpu() + + # 处理预测结果 + pred = preds[0].squeeze() + pred_pil = transforms.ToPILImage()(pred) + mask = pred_pil.resize(image_size) + + # 添加透明通道 + image.putalpha(mask) + return image + +def remove_background(image_path, output_dir=None): + """ + 移除图像背景 + + Args: + image_path: 输入图像的路径或URL + output_dir: 输出目录,如果不指定则使用当前目录 + + Returns: + 处理后的图像路径 + """ + temp_file = None + try: + # 检查是否是URL + if is_valid_url(image_path): + try: + # 下载图片到临时文件 + temp_file = download_image(image_path) + image_path = temp_file + except Exception as e: + raise Exception(f"下载图片失败: {e}") + + # 验证输入文件是否存在 + if not os.path.exists(image_path): + raise FileNotFoundError(f"输入图像不存在: {image_path}") + + # 如果输出目录未指定,则使用当前目录 + if output_dir is None: + output_dir = os.getcwd() + + # 确保输出目录存在 + os.makedirs(output_dir, exist_ok=True) + + # 加载并处理图像 + image = Image.open(image_path).convert("RGB") + image_no_bg = process_image(image) + + # 保存结果 + filename = os.path.basename(image_path) + name, ext = os.path.splitext(filename) + output_file_path = os.path.join(output_dir, f"{name}_nobg.png") + image_no_bg.save(output_file_path) + + return output_file_path + finally: + # 清理临时文件 + if temp_file and os.path.exists(temp_file): + try: + os.unlink(temp_file) + except: + pass + +def main(): + parser = argparse.ArgumentParser(description='移除图像背景') + parser.add_argument('--image', help='输入图像的路径或URL') + parser.add_argument('--output-dir', default='output', help='输出目录') + args = parser.parse_args() + + try: + # 如果没有提供图片参数,使用示例URL + if args.image is None: + print("未提供图片参数,使用示例图片URL进行测试...") + args.image = "http://test001.jingrow.com/files/3T8A4426.JPG" + + print(f"处理图片: {args.image}") + t0 = time.time() + output_path = remove_background(args.image, args.output_dir) + print(f"处理完成,用时 {time.time()-t0:.1f} 秒") + print(f"输出文件: {output_path}") + except Exception as e: + print(f"错误: {e}") + finally: + # 清理显存 + if torch.cuda.is_available(): + torch.cuda.empty_cache() + +if __name__ == "__main__": + main() diff --git a/task/requirements.txt b/task/requirements.txt new file mode 100644 index 0000000..e846745 --- /dev/null +++ b/task/requirements.txt @@ -0,0 +1,4 @@ +requests>=2.31.0 +Pillow>=10.0.0 +torch>=2.0.0 +transformers>=4.30.0 \ No newline at end of file diff --git a/task/rmbg.py b/task/rmbg.py new file mode 100644 index 0000000..68188db --- /dev/null +++ b/task/rmbg.py @@ -0,0 +1,106 @@ +import os +import gradio as gr +from gradio_imageslider import ImageSlider +from loadimg import load_img +import spaces +from transformers import AutoModelForImageSegmentation +import torch +from torchvision import transforms + +torch.set_float32_matmul_precision(["high", "highest"][0]) + +birefnet = AutoModelForImageSegmentation.from_pretrained( + "briaai/RMBG-2.0", trust_remote_code=True +) +birefnet.to("cuda") +transform_image = transforms.Compose( + [ + transforms.Resize((1024, 1024)), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ] +) + +output_folder = 'output_images' +if not os.path.exists(output_folder): + os.makedirs(output_folder) + +def fn(image): + im = load_img(image, output_type="pil") + im = im.convert("RGB") + origin = im.copy() + image = process(im) + image_path = os.path.join(output_folder, "no_bg_image.png") + image.save(image_path) + return (image, origin), image_path + +@spaces.GPU +def process(image): + image_size = image.size + input_images = transform_image(image).unsqueeze(0).to("cuda") + # Prediction + with torch.no_grad(): + preds = birefnet(input_images)[-1].sigmoid().cpu() + pred = preds[0].squeeze() + pred_pil = transforms.ToPILImage()(pred) + mask = pred_pil.resize(image_size) + image.putalpha(mask) + return image + +def process_file(f): + name_path = f.rsplit(".",1)[0]+".png" + im = load_img(f, output_type="pil") + im = im.convert("RGB") + transparent = process(im) + transparent.save(name_path) + return name_path + +slider1 = ImageSlider(label="RMBG-2.0", type="pil") +slider2 = ImageSlider(label="RMBG-2.0", type="pil") +image = gr.Image(label="Upload an image") +image2 = gr.Image(label="Upload an image",type="filepath") +text = gr.Textbox(label="Paste an image URL") +png_file = gr.File(label="output png file") + + +chameleon = load_img("giraffe.jpg", output_type="pil") + +url = "http://farm9.staticflickr.com/8488/8228323072_76eeddfea3_z.jpg" + +tab1 = gr.Interface( + fn, inputs=image, outputs=[slider1, gr.File(label="output png file")], examples=[chameleon], api_name="image" +) + +tab2 = gr.Interface(fn, inputs=text, outputs=[slider2, gr.File(label="output png file")], examples=[url], api_name="text") +tab3 = gr.Interface(process_file, inputs=image2, outputs=png_file, examples=["giraffe.jpg"], api_name="png") + + +demo = gr.TabbedInterface( + [tab1, tab2], ["input image", "input url"], title = ( + "RMBG-2.0 for background removal
" + "" + "Background removal model developed by " + "BRIA.AI, trained on a carefully selected dataset,
" + "and is available as an open-source model for non-commercial use.

" + " For testing upload your image and wait.
" + "Commercial use license | " + "Model card | " + "Blog" + "

" + "" + "API Endpoint available on: " + "Bria.ai, " + "fal.ai
" + "ComfyUI node is available here: " + "ComfyUI Node
" + "Purchase commercial weigths for commercial use: " + "here" + "
" +) + + + +) + +if __name__ == "__main__": + demo.launch(show_error=True) \ No newline at end of file diff --git a/task/rmbg2.py b/task/rmbg2.py new file mode 100644 index 0000000..03f165f --- /dev/null +++ b/task/rmbg2.py @@ -0,0 +1,195 @@ +import os +import argparse +import requests +import tempfile +from urllib.parse import urlparse +from PIL import Image +import torch +from torchvision import transforms +from transformers import AutoModelForImageSegmentation +import time +import warnings +from concurrent.futures import ThreadPoolExecutor +import gc + +# 关闭不必要的警告 +warnings.filterwarnings("ignore", category=UserWarning) +warnings.filterwarnings("ignore", category=FutureWarning) + +# 设置torch精度和优化 +torch.set_float32_matmul_precision("high") +torch.backends.cudnn.benchmark = True # 启用cuDNN自动调优 +torch.backends.cuda.matmul.allow_tf32 = True # 允许使用TF32 +torch.backends.cudnn.allow_tf32 = True # 允许cuDNN使用TF32 + +# 示例图片URL +image_url = "http://test001.jingrow.com/files/3T8A4426.JPG" + +def is_valid_url(url): + """验证URL是否有效""" + try: + result = urlparse(url) + return all([result.scheme, result.netloc]) + except: + return False + +def download_image(url): + """从URL下载图片到临时文件""" + response = requests.get(url, stream=True) + response.raise_for_status() + + # 创建临时文件 + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png') + with open(temp_file.name, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + return temp_file.name + +def process_image(image, model, device): + """处理图像,移除背景""" + image_size = image.size + # 转换图像 + transform_image = transforms.Compose([ + transforms.Resize((1024, 1024)), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ]) + + # 使用半精度加速 + with torch.cuda.amp.autocast(): + input_images = transform_image(image).unsqueeze(0).to(device) + # 预测 + with torch.no_grad(): + preds = model(input_images)[-1].sigmoid().cpu() + + # 处理预测结果 + pred = preds[0].squeeze() + pred_pil = transforms.ToPILImage()(pred) + mask = pred_pil.resize(image_size) + + # 添加透明通道 + image.putalpha(mask) + return image + +def remove_background(image_path, output_dir=None, model=None, device=None): + """ + 移除图像背景 + + Args: + image_path: 输入图像的路径或URL + output_dir: 输出目录,如果不指定则使用当前目录 + model: 模型实例 + device: 设备 + + Returns: + 处理后的图像路径 + """ + temp_file = None + try: + # 检查是否是URL + if is_valid_url(image_path): + try: + # 下载图片到临时文件 + temp_file = download_image(image_path) + image_path = temp_file + except Exception as e: + raise Exception(f"下载图片失败: {e}") + + # 验证输入文件是否存在 + if not os.path.exists(image_path): + raise FileNotFoundError(f"输入图像不存在: {image_path}") + + # 如果输出目录未指定,则使用当前目录 + if output_dir is None: + output_dir = os.getcwd() + + # 确保输出目录存在 + os.makedirs(output_dir, exist_ok=True) + + # 加载并处理图像 + image = Image.open(image_path).convert("RGB") + image_no_bg = process_image(image, model, device) + + # 保存结果 + filename = os.path.basename(image_path) + name, ext = os.path.splitext(filename) + output_file_path = os.path.join(output_dir, f"{name}_nobg.png") + image_no_bg.save(output_file_path) + + return output_file_path + finally: + # 清理临时文件 + if temp_file and os.path.exists(temp_file): + try: + os.unlink(temp_file) + except: + pass + +def process_batch(image_paths, output_dir, model, device, max_workers=4): + """批量处理图片""" + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [] + for image_path in image_paths: + future = executor.submit(remove_background, image_path, output_dir, model, device) + futures.append((image_path, future)) + + results = [] + for image_path, future in futures: + try: + output_path = future.result() + results.append((image_path, output_path, None)) + except Exception as e: + results.append((image_path, None, str(e))) + return results + +def main(): + parser = argparse.ArgumentParser(description='移除图像背景') + parser.add_argument('--image', nargs='+', help='输入图像的路径或URL(支持多个)') + parser.add_argument('--output-dir', default='output', help='输出目录') + parser.add_argument('--model-path', default='briaai/RMBG-2.0', help='模型路径') + parser.add_argument('--batch-size', type=int, default=4, help='批处理大小') + args = parser.parse_args() + + # 设置设备 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"使用设备: {device}") + if torch.cuda.is_available(): + print(f"GPU: {torch.cuda.get_device_name(0)}") + print(f"显存总量: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f}GB") + + try: + # 加载模型 + print("正在加载模型...") + t0 = time.time() + model = AutoModelForImageSegmentation.from_pretrained(args.model_path, trust_remote_code=True) + model = model.to(device) + model.eval() # 设置为评估模式 + print(f"模型加载完成,用时 {time.time()-t0:.1f} 秒") + + # 处理图片 + image_paths = args.image if args.image else [image_url] + if len(image_paths) > 1: + print(f"开始批量处理 {len(image_paths)} 张图片...") + results = process_batch(image_paths, args.output_dir, model, device, args.batch_size) + for image_path, output_path, error in results: + if error: + print(f"处理失败: {image_path}, 错误: {error}") + else: + print(f"处理成功: {image_path} -> {output_path}") + else: + print(f"处理图片: {image_paths[0]}") + t1 = time.time() + output_path = remove_background(image_paths[0], args.output_dir, model, device) + print(f"处理完成,用时 {time.time()-t1:.1f} 秒") + print(f"输出文件: {output_path}") + + except Exception as e: + print(f"错误: {e}") + finally: + # 清理显存 + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + +if __name__ == "__main__": + main() diff --git a/utils/utils.py b/utils/utils.py new file mode 100644 index 0000000..e69de29