import os
import tempfile
import time
import warnings
import gc
import base64
import asyncio
import io
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor
from urllib.parse import urlparse

import requests
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

# Silence noisy, non-actionable warnings emitted by the model/runtime libraries.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# Let float32 matmuls use faster reduced-precision internal math (e.g. TF32)
# where the hardware supports it.
torch.set_float32_matmul_precision("high")


class RmbgService:
    """Background-removal service wrapping a Hugging Face segmentation model.

    The model is loaded once at construction time. Each processed image gets
    the predicted mask attached as its alpha channel and is returned as a
    base64-encoded PNG.

    NOTE(review): the ``async`` methods below run the HTTP download and the
    model inference synchronously, so each image blocks the event loop while
    it is processed. If this is served from an asyncio web server, consider
    offloading that work to an executor (and serializing access to the model).
    """

    # Fixed input resolution fed to the model; the predicted mask is resized
    # back to the original image size afterwards.
    INPUT_SIZE = (1024, 1024)

    # Preprocessing: resize, convert to a [0, 1] tensor, then normalize with
    # the standard ImageNet channel mean/std. The pipeline is stateless, so it
    # is built once here instead of on every call.
    _transform = transforms.Compose([
        transforms.Resize(INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    def __init__(self, model_path="zhengpeng7/BiRefNet"):
        """Create the service and load the segmentation model.

        Args:
            model_path: Hub id or local directory passed to
                ``AutoModelForImageSegmentation.from_pretrained``.
        """
        self.model_path = model_path
        self.model = None
        self.device = None
        self._load_model()

    def _load_model(self):
        """Load the model onto CUDA when available (else CPU), in eval mode."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # trust_remote_code executes model code shipped with the checkpoint;
        # only point model_path at repositories you trust.
        self.model = AutoModelForImageSegmentation.from_pretrained(
            self.model_path, trust_remote_code=True
        )
        self.model = self.model.to(self.device)
        self.model.eval()  # inference mode: disables dropout / batch-norm updates

    def process_image(self, image):
        """Attach the predicted foreground mask to ``image`` as its alpha channel.

        Args:
            image: PIL image; the callers in this class pass RGB images.

        Returns:
            The same ``image`` object, modified in place by ``putalpha``.
        """
        original_size = image.size
        input_images = self._transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            # The model returns a sequence of outputs; the last one is turned
            # into per-pixel probabilities with a sigmoid.
            preds = self.model(input_images)[-1].sigmoid().cpu()
        # Drop the device-side input before empty_cache(): while this local
        # still references the tensor, its memory cannot be released.
        del input_images

        pred = preds[0].squeeze()
        mask = transforms.ToPILImage()(pred).resize(original_size)

        # Add the transparency channel.
        image.putalpha(mask)

        # Release cached GPU memory and collect garbage between images.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        return image

    def image_to_base64(self, image):
        """Encode a PIL image as PNG and return it as a base64 ``str``."""
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode('utf-8')

    async def remove_background(self, image_path):
        """Remove the background of an image given by local path or URL.

        URLs (anything with a scheme and host) are downloaded to a temporary
        file first. The temporary file is always deleted afterwards.

        NOTE(review): local filesystem paths are accepted as-is; if
        ``image_path`` can come from untrusted clients, consider restricting
        this to URLs.

        Args:
            image_path: Local file path or URL of the input image.

        Returns:
            ``{"status": "success", "image_content": <base64 PNG str>}``.

        Raises:
            Exception: if the download fails (original error chained).
            FileNotFoundError: if the local path does not exist.
        """
        temp_file = None
        try:
            if self.is_valid_url(image_path):
                try:
                    temp_file = self.download_image(image_path)
                except Exception as e:
                    raise Exception(f"下载图片失败: {e}") from e
                image_path = temp_file

            if not os.path.exists(image_path):
                raise FileNotFoundError(f"输入图像不存在: {image_path}")

            # Close the source file as soon as the pixels are decoded; convert()
            # always returns a new, independent image.
            with Image.open(image_path) as source:
                image = source.convert("RGB")
            image_no_bg = self.process_image(image)

            image_content = self.image_to_base64(image_no_bg)
            return {
                "status": "success",
                "image_content": image_content
            }
        finally:
            # Best-effort removal of the downloaded temporary file.
            if temp_file and os.path.exists(temp_file):
                try:
                    os.unlink(temp_file)
                except OSError:
                    pass

    async def remove_background_from_file(self, file_content):
        """Remove the background of an image supplied as raw bytes.

        Args:
            file_content: Encoded image bytes (e.g. an uploaded file body).

        Returns:
            ``{"status": "success", "image_content": <base64 PNG str>}``.

        Raises:
            Exception: wrapping any decoding/processing failure (cause chained).
        """
        try:
            with Image.open(io.BytesIO(file_content)) as source:
                image = source.convert("RGB")
            image_no_bg = self.process_image(image)

            image_content = self.image_to_base64(image_no_bg)
            return {
                "status": "success",
                "image_content": image_content
            }
        except Exception as e:
            raise Exception(f"处理图片失败: {e}") from e

    async def process_batch(self, urls):
        """Process several image URLs sequentially, yielding one result per URL.

        Failures are reported as ``"status": "error"`` items instead of
        aborting the batch. Every item carries running success/error counts.

        Args:
            urls: Iterable of sized image URLs (each item is ``str()``-ed).

        Yields:
            One dict per input, in order, with a 1-based ``index``.
        """
        total = len(urls)
        success_count = 0
        error_count = 0

        for i, url in enumerate(urls, 1):
            try:
                url_str = str(url)
                result = await self.remove_background(url_str)
                success_count += 1

                yield {
                    "index": i,
                    "total": total,
                    "original_url": url_str,
                    "status": "success",
                    "image_content": result["image_content"],
                    "success_count": success_count,
                    "error_count": error_count,
                    "message": "处理成功"
                }
            except Exception as e:
                error_count += 1
                yield {
                    "index": i,
                    "total": total,
                    "original_url": str(url),
                    "status": "error",
                    "error": str(e),
                    "success_count": success_count,
                    "error_count": error_count,
                    "message": f"处理失败: {str(e)}"
                }

            # Yield control to the event loop between items.
            await asyncio.sleep(0)

    def is_valid_url(self, url):
        """Return True if ``url`` parses with both a scheme and a network location.

        Inputs ``urlparse`` cannot handle (non-string objects, malformed IPv6
        brackets, ...) are reported as "not a URL" instead of raising.
        """
        try:
            result = urlparse(url)
        except (TypeError, ValueError, AttributeError):
            return False
        return bool(result.scheme and result.netloc)

    def download_image(self, url, timeout=30):
        """Download ``url`` into a new temporary file and return its path.

        The caller owns the returned file and must delete it. On any failure
        the partially written file is removed before the error propagates.

        Args:
            url: HTTP(S) URL of the image.
            timeout: Seconds passed to ``requests`` as connect/read timeout, so
                a stalled server cannot hang the service indefinitely.

        Returns:
            Path of the temporary file (suffix ``.png``; the actual format is
            detected from content when the image is opened).

        Raises:
            requests.RequestException: on connection errors, timeouts or
                non-2xx responses.
        """
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()

            # delete=False: the file must outlive this handle so the caller can
            # reopen it by name (required on Windows, where an open
            # NamedTemporaryFile cannot be opened a second time).
            tmp = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
            try:
                with tmp:
                    for chunk in response.iter_content(chunk_size=8192):
                        tmp.write(chunk)
            except BaseException:
                # Do not leave a partially written file behind.
                try:
                    os.unlink(tmp.name)
                except OSError:
                    pass
                raise

        return tmp.name

    def cleanup(self):
        """Release cached GPU memory and run garbage collection."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        print("资源已清理")