import os import tempfile from urllib.parse import urlparse from PIL import Image import torch from torchvision import transforms from transformers import AutoModelForImageSegmentation import time import warnings import gc import base64 import asyncio import io import uuid import httpx from settings import settings # 关闭不必要的警告 warnings.filterwarnings("ignore", category=UserWarning) warnings.filterwarnings("ignore", category=FutureWarning) # 设置torch精度 torch.set_float32_matmul_precision("high") class RmbgService: def __init__(self, model_path="zhengpeng7/BiRefNet"): """初始化背景移除服务""" self.model_path = model_path self.model = None self.device = None self.save_dir = settings.save_dir self.download_url = settings.download_url # 确保保存目录存在 os.makedirs(self.save_dir, exist_ok=True) # 创建异步HTTP客户端(复用连接,提高性能) self.http_client = httpx.AsyncClient(timeout=30.0, limits=httpx.Limits(max_keepalive_connections=20)) self._load_model() def _load_model(self): """加载模型""" # 设置设备 self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") t0 = time.time() self.model = AutoModelForImageSegmentation.from_pretrained(self.model_path, trust_remote_code=True) self.model = self.model.to(self.device) self.model.eval() # 设置为评估模式 def _process_image_sync(self, image): """同步处理图像,移除背景(内部方法,在线程池中执行)""" image_size = image.size # 转换图像 transform_image = transforms.Compose([ transforms.Resize((1024, 1024)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ]) input_images = transform_image(image).unsqueeze(0).to(self.device) # 推理 with torch.no_grad(): preds = self.model(input_images)[-1].sigmoid().cpu() # 处理预测结果 pred = preds[0].squeeze() pred_pil = transforms.ToPILImage()(pred) mask = pred_pil.resize(image_size) # 添加透明通道 image.putalpha(mask) # 清理显存 if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() return image async def process_image(self, image): """异步处理图像,移除背景(在线程池中执行同步操作)""" # 将同步的GPU操作放到线程池中执行,避免阻塞事件循环 loop = asyncio.get_event_loop() return await loop.run_in_executor(None, self._process_image_sync, image) def image_to_base64(self, image): """将PIL Image对象转换为base64字符串""" buffered = io.BytesIO() image.save(buffered, format="PNG") return base64.b64encode(buffered.getvalue()).decode('utf-8') def save_image_to_file(self, image): """ 保存图片到jfile/files目录并返回URL Args: image: PIL Image对象 Returns: 图片URL """ # 生成唯一文件名 filename = f"rmbg_{uuid.uuid4().hex[:10]}.png" file_path = os.path.join(self.save_dir, filename) # 保存图片 image.save(file_path, format="PNG") # 构建URL image_url = f"{self.download_url}/{filename}" return image_url async def remove_background(self, image_path): """ 移除图像背景 Args: image_path: 输入图像的路径或URL Returns: 处理后的图像内容 """ temp_file = None try: # 检查是否是URL if self.is_valid_url(image_path): try: # 异步下载图片到临时文件 temp_file = await self.download_image(image_path) image_path = temp_file except Exception as e: raise Exception(f"下载图片失败: {e}") # 验证输入文件是否存在 if not os.path.exists(image_path): raise FileNotFoundError(f"输入图像不存在: {image_path}") # 加载图像(IO操作,在线程池中执行) loop = asyncio.get_event_loop() image = await loop.run_in_executor( None, lambda: Image.open(image_path).convert("RGB") ) # 异步处理图像 image_no_bg = await self.process_image(image) # 保存图片到文件并获取URL(IO操作,在线程池中执行) image_url = await loop.run_in_executor( None, self.save_image_to_file, image_no_bg ) return { "status": "success", "image_url": image_url } finally: # 清理临时文件 if temp_file and os.path.exists(temp_file): try: os.unlink(temp_file) except: pass async def remove_background_from_file(self, file_content): """ 从上传的文件内容移除背景 Args: file_content: 上传的文件内容 Returns: 处理后的图像内容 """ try: # 从文件内容创建PIL Image对象(IO操作,在线程池中执行) loop = asyncio.get_event_loop() image = await loop.run_in_executor( None, lambda: Image.open(io.BytesIO(file_content)).convert("RGB") ) # 异步处理图像 image_no_bg = await self.process_image(image) # 保存图片到文件并获取URL(IO操作,在线程池中执行) image_url = await loop.run_in_executor( None, self.save_image_to_file, image_no_bg ) return { "status": "success", "image_url": image_url } except Exception as e: raise Exception(f"处理图片失败: {e}") async def process_batch(self, urls): """ 批量处理多个URL图像,并发处理并流式返回结果 Args: urls: 图片URL列表 Yields: 每个图片的处理结果(按完成顺序返回) """ total = len(urls) success_count = 0 error_count = 0 # 创建并发任务 async def process_single_url(index, url): """处理单个URL的包装函数""" try: url_str = str(url) result = await self.remove_background(url_str) return { "index": index, "total": total, "original_url": url_str, "status": "success", "image_url": result["image_url"], "message": "处理成功" } except Exception as e: return { "index": index, "total": total, "original_url": str(url), "status": "error", "error": str(e), "message": f"处理失败: {str(e)}" } # 创建所有任务 tasks = [ process_single_url(i, url) for i, url in enumerate(urls, 1) ] # 并发执行所有任务,使用as_completed按完成顺序返回 for coro in asyncio.as_completed(tasks): result = await coro if result["status"] == "success": success_count += 1 else: error_count += 1 # 更新统计信息 result["success_count"] = success_count result["error_count"] = error_count yield result def is_valid_url(self, url): """验证URL是否有效""" try: result = urlparse(url) return all([result.scheme, result.netloc]) except: return False async def download_image(self, url): """异步从URL下载图片到临时文件""" try: response = await self.http_client.get(url) response.raise_for_status() # 创建临时文件并写入内容(IO操作,在线程池中执行) def write_temp_file(content): temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png') temp_file.write(content) temp_file.close() return temp_file.name loop = asyncio.get_event_loop() temp_file_path = await loop.run_in_executor( None, write_temp_file, response.content ) return temp_file_path except Exception as e: raise Exception(f"下载图片失败: {e}") async def cleanup(self): """清理资源""" # 关闭HTTP客户端 await self.http_client.aclose() if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() print("资源已清理")