import asyncio
import gc
import io
import logging
import os
import tempfile
import time
import uuid
import warnings
from concurrent.futures import ThreadPoolExecutor
from urllib.parse import urlparse

import httpx
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

from settings import settings

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Silence UserWarning / FutureWarning output (e.g. from third-party libraries).
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
# Allow reduced-precision float32 matmuls (per PyTorch docs, e.g. TF32 on supporting GPUs).
torch.set_float32_matmul_precision("high")


class RmbgService:
    """Background-removal service built on an image-segmentation model.

    Blocking work (image decoding, model inference, PNG writes) is dispatched to a
    thread pool so the async entry points do not block the event loop. Results are
    saved as PNG files under ``settings.save_dir`` and returned as URLs built from
    ``settings.download_url``.
    """

    # Preprocessing: resize to the fixed 1024x1024 input used here, convert to a
    # tensor, and normalize with ImageNet mean/std. The pipeline is stateless, so it
    # is built once and shared instead of being rebuilt for every image.
    _transform = transforms.Compose([
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    def __init__(self, model_path="zhengpeng7/BiRefNet"):
        """Create the service: output directory, HTTP client, thread pool, and model.

        Args:
            model_path: model identifier/path passed to
                ``AutoModelForImageSegmentation.from_pretrained``.
        """
        self.model_path = model_path
        self.model = None
        self.device = None
        self.save_dir = settings.save_dir
        self.download_url = settings.download_url
        os.makedirs(self.save_dir, exist_ok=True)
        self.http_client = httpx.AsyncClient(
            timeout=30.0,
            limits=httpx.Limits(
                max_keepalive_connections=50,
                max_connections=100
            )
        )
        self.executor = ThreadPoolExecutor(max_workers=settings.max_workers)
        self._load_model()

    def _load_model(self):
        """Load the segmentation model onto CUDA if available (else CPU) in eval mode."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): trust_remote_code executes Python code shipped with the model
        # repository; only point model_path at sources you trust.
        self.model = AutoModelForImageSegmentation.from_pretrained(
            self.model_path, trust_remote_code=True
        )
        self.model = self.model.to(self.device)
        self.model.eval()

    def _process_image_sync(self, image):
        """Remove the background from ``image`` (blocking; meant for the executor).

        The predicted foreground mask is resized back to the original image size
        and attached as the alpha channel. ``image`` is modified in place via
        ``putalpha`` and the same object is returned.

        Args:
            image: PIL image (callers in this class pass RGB images).

        Returns:
            The same PIL image, now carrying an alpha channel.
        """
        image_size = image.size
        input_images = self._transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            # [-1]: last of the model's outputs (presumably the final prediction —
            # confirm against the model code); sigmoid maps logits to [0, 1].
            preds = self.model(input_images)[-1].sigmoid().cpu()
        pred = preds[0].squeeze()
        pred_pil = transforms.ToPILImage()(pred)
        mask = pred_pil.resize(image_size)
        image.putalpha(mask)

        # Free cached GPU memory and Python garbage after every image,
        # presumably to keep the long-running service's footprint bounded.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        return image

    async def process_image(self, image):
        """Run ``_process_image_sync`` on the thread pool and return its result."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self.executor, self._process_image_sync, image
        )

    @staticmethod
    def _open_rgb(source):
        """Open ``source`` (path or file-like), close it, and return an RGB copy."""
        # ``convert`` returns a fully loaded, independent image, so the underlying
        # file handle can be closed immediately instead of leaking until GC.
        with Image.open(source) as img:
            return img.convert("RGB")

    @staticmethod
    def _remove_file_quietly(path):
        """Best-effort delete of a temporary file; logs instead of raising."""
        try:
            os.unlink(path)
        except FileNotFoundError:
            pass
        except OSError as e:
            logger.warning("Failed to remove temp file %s: %s", path, e)

    def save_image_to_file(self, image):
        """Save ``image`` as a uniquely named PNG in ``save_dir`` and return its URL.

        Args:
            image: PIL image to save.

        Returns:
            ``"<download_url>/rmbg_<10 hex chars>.png"``.
        """
        filename = f"rmbg_{uuid.uuid4().hex[:10]}.png"
        file_path = os.path.join(self.save_dir, filename)
        image.save(file_path, format="PNG")
        # str(): download_url may be a URL object from settings; rstrip avoids a
        # double slash when it is configured with a trailing "/".
        base_url = str(self.download_url).rstrip("/")
        return f"{base_url}/{filename}"

    async def remove_background(self, image_path):
        """Remove the background from an image given by local path or URL.

        Args:
            image_path: a local file path, or a URL (anything ``is_valid_url``
                accepts), which is first downloaded to a temporary file.

        Returns:
            ``{"status": "success", "image_url": <URL of the saved PNG>}``.

        Raises:
            Exception: if the download fails (message prefixed "下载图片失败").
            FileNotFoundError: if the local path does not exist.
        """
        # NOTE(review): non-URL input is read from the server's filesystem; if
        # image_path comes from untrusted clients this allows arbitrary file reads.
        temp_file = None
        try:
            if self.is_valid_url(image_path):
                # download_image already raises with the "下载图片失败" prefix, so it
                # is not re-wrapped here (that duplicated the prefix).
                temp_file = await self.download_image(image_path)
                image_path = temp_file

            if not os.path.exists(image_path):
                raise FileNotFoundError(f"输入图像不存在: {image_path}")

            loop = asyncio.get_running_loop()
            image = await loop.run_in_executor(
                self.executor, self._open_rgb, image_path
            )
            image_no_bg = await self.process_image(image)
            image_url = await loop.run_in_executor(
                self.executor, self.save_image_to_file, image_no_bg
            )
            return {
                "status": "success",
                "image_url": image_url
            }
        finally:
            if temp_file:
                self._remove_file_quietly(temp_file)

    async def remove_background_from_file(self, file_content):
        """Remove the background from uploaded image bytes.

        Args:
            file_content: raw bytes of an encoded image.

        Returns:
            ``{"status": "success", "image_url": <URL of the saved PNG>}``.

        Raises:
            Exception: on any failure, message prefixed "处理图片失败"; the original
                error is chained as ``__cause__``.
        """
        try:
            loop = asyncio.get_running_loop()
            image = await loop.run_in_executor(
                self.executor, self._open_rgb, io.BytesIO(file_content)
            )
            image_no_bg = await self.process_image(image)
            image_url = await loop.run_in_executor(
                self.executor, self.save_image_to_file, image_no_bg
            )
            return {
                "status": "success",
                "image_url": image_url
            }
        except Exception as e:
            raise Exception(f"处理图片失败: {e}") from e

    async def process_batch(self, urls):
        """Process many images concurrently, yielding each result as it completes.

        Every item is started at once (download -> decode -> inference -> save) and
        results are yielded in completion order, not input order. Per-item failures
        are reported as ``"status": "error"`` entries instead of being raised.

        Args:
            urls: sized iterable of URLs or local paths (each passed through ``str``).

        Yields:
            One dict per item with ``index`` (1-based input position), ``total``,
            ``original_url``, ``status``, ``image_url`` or ``error``, ``message``,
            plus running ``success_count``, ``error_count``, ``completed_order``
            and ``batch_elapsed`` (seconds since the batch started).
        """
        total = len(urls)
        success_count = 0
        error_count = 0
        batch_start_time = time.monotonic()
        loop = asyncio.get_running_loop()

        async def download_and_process(index, url):
            """Download (if a URL) and process one image; errors are returned, not raised."""
            url_str = str(url)
            try:
                if self.is_valid_url(url_str):
                    temp_file = await self.download_image(url_str)
                    try:
                        image = await loop.run_in_executor(
                            self.executor, self._open_rgb, temp_file
                        )
                    finally:
                        # Remove the temp file even when decoding fails.
                        self._remove_file_quietly(temp_file)
                else:
                    # NOTE(review): local paths are read from the server filesystem.
                    image = await loop.run_in_executor(
                        self.executor, self._open_rgb, url_str
                    )
                processed_image = await self.process_image(image)
                image_url = await loop.run_in_executor(
                    self.executor, self.save_image_to_file, processed_image
                )
                return {
                    "index": index,
                    "total": total,
                    "original_url": url_str,
                    "status": "success",
                    "image_url": image_url,
                    "message": "处理成功"
                }
            except Exception as e:
                logger.error(f"处理失败 (index={index}): {str(e)}")
                return {
                    "index": index,
                    "total": total,
                    "original_url": url_str,
                    "status": "error",
                    "error": str(e),
                    "message": f"处理失败: {str(e)}"
                }

        tasks = [
            asyncio.ensure_future(download_and_process(i, url))
            for i, url in enumerate(urls, 1)
        ]
        completed_order = 0
        try:
            for next_done in asyncio.as_completed(tasks):
                result = await next_done
                completed_order += 1
                if result["status"] == "success":
                    success_count += 1
                else:
                    error_count += 1
                result["success_count"] = success_count
                result["error_count"] = error_count
                result["completed_order"] = completed_order
                result["batch_elapsed"] = round(time.monotonic() - batch_start_time, 2)
                yield result
        finally:
            # If the consumer stops iterating early (e.g. client disconnect), cancel
            # unfinished items instead of leaving them running orphaned.
            for task in tasks:
                task.cancel()

    def is_valid_url(self, url):
        """Return True if ``url`` parses with both a scheme and a network location."""
        try:
            result = urlparse(url)
            return all([result.scheme, result.netloc])
        except Exception:
            # Non-string input or malformed netloc (e.g. bad IPv6 brackets).
            return False

    async def download_image(self, url):
        """Download ``url`` into a new temporary ``.png`` file and return its path.

        The caller owns the returned file and must delete it.

        Raises:
            Exception: on any HTTP or I/O failure, message prefixed "下载图片失败";
                the original error is chained as ``__cause__``.
        """
        # NOTE(review): the whole body is buffered in memory and any host is
        # fetched; consider size limits / SSRF filtering for untrusted URLs.

        def write_temp_file(content):
            # The .png suffix is cosmetic; PIL detects the format from the content.
            tmp = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
            try:
                with tmp:
                    tmp.write(content)
            except BaseException:
                # delete=False: remove the partial file ourselves before re-raising.
                self._remove_file_quietly(tmp.name)
                raise
            return tmp.name

        try:
            response = await self.http_client.get(url)
            response.raise_for_status()
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(
                self.executor, write_temp_file, response.content
            )
        except Exception as e:
            raise Exception(f"下载图片失败: {e}") from e

    async def cleanup(self):
        """Release resources: HTTP client, thread pool, and cached GPU memory."""
        await self.http_client.aclose()
        loop = asyncio.get_running_loop()
        # shutdown(wait=True) blocks until queued jobs finish; run it off the event
        # loop (on the default executor) so other coroutines keep making progress.
        await loop.run_in_executor(None, lambda: self.executor.shutdown(wait=True))
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()