japi/apps/rmbg/service.py

272 lines
9.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import tempfile
from urllib.parse import urlparse
from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
import time
import warnings
import gc
import asyncio
import io
import uuid
import httpx
import logging
from concurrent.futures import ThreadPoolExecutor
from settings import settings
# Process-wide logging: INFO level with timestamped records on the root
# logger (basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Hide every UserWarning and FutureWarning raised anywhere in this process.
for _ignored_category in (UserWarning, FutureWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)
del _ignored_category

# "high" lets PyTorch pick faster float32 matmul kernels (e.g. TF32) where supported.
torch.set_float32_matmul_precision("high")
class RmbgService:
    """Background-removal service built on a Hugging Face segmentation model.

    The model is loaded once in ``__init__``. Blocking work (image decoding,
    model inference and file writes) runs on ``self.executor`` so coroutines
    do not block the event loop. Results are saved as PNG files in
    ``settings.save_dir`` and returned as ``f"{settings.download_url}/<name>"``.
    """

    def __init__(self, model_path="zhengpeng7/BiRefNet"):
        """Create the service and load the segmentation model.

        Args:
            model_path: Model id or local path passed to
                ``AutoModelForImageSegmentation.from_pretrained``.
        """
        self.model_path = model_path
        self.model = None
        self.device = None
        self.save_dir = settings.save_dir
        self.download_url = settings.download_url
        os.makedirs(self.save_dir, exist_ok=True)
        # One shared client so downloads reuse pooled connections.
        self.http_client = httpx.AsyncClient(
            timeout=30.0,
            limits=httpx.Limits(
                max_keepalive_connections=50,
                max_connections=100,
            ),
        )
        # Runs all blocking work: image decoding, inference and file I/O.
        self.executor = ThreadPoolExecutor(max_workers=settings.max_workers)
        # The preprocessing pipeline keeps no per-image state, so build it once
        # instead of on every call to _process_image_sync.
        self._transform = transforms.Compose([
            transforms.Resize((1024, 1024)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        self._load_model()

    def _load_model(self):
        """Load the model onto CUDA when available (otherwise CPU) in eval mode."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): trust_remote_code=True executes Python code shipped with
        # the model repository; pin a revision if model_path is configurable.
        self.model = AutoModelForImageSegmentation.from_pretrained(
            self.model_path, trust_remote_code=True
        )
        self.model = self.model.to(self.device)
        self.model.eval()

    def _process_image_sync(self, image):
        """Remove the background from *image*, blocking the calling thread.

        The image is resized and normalised by ``self._transform``; the model's
        last prediction is passed through a sigmoid to get a mask, which is
        resized back to the original size and attached as the alpha channel.

        Args:
            image: PIL image; every caller in this class passes an RGB image.

        Returns:
            The same image object, modified in place by ``putalpha``.
        """
        image_size = image.size
        input_images = self._transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            # The model output is indexed with [-1]: only its last prediction is used.
            preds = self.model(input_images)[-1].sigmoid().cpu()
        # Drop the device-side input so empty_cache() below can actually free it.
        del input_images
        mask = transforms.ToPILImage()(preds[0].squeeze()).resize(image_size)
        image.putalpha(mask)
        # Release cached GPU memory and collect garbage after every image.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        return image

    async def process_image(self, image):
        """Run :meth:`_process_image_sync` on ``self.executor`` and return its result."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self.executor, self._process_image_sync, image)

    def save_image_to_file(self, image):
        """Save *image* as a PNG in ``save_dir`` and return its download URL.

        The filename is ``rmbg_`` followed by the first 10 hex digits of a
        random UUID4.
        """
        filename = f"rmbg_{uuid.uuid4().hex[:10]}.png"
        file_path = os.path.join(self.save_dir, filename)
        image.save(file_path, format="PNG")
        return f"{self.download_url}/{filename}"

    @staticmethod
    def _open_rgb(source):
        """Open *source* (path or binary file object) with PIL and return an RGB copy.

        The ``with`` block closes the underlying file as soon as the converted
        copy exists, instead of leaving that to garbage collection.
        """
        with Image.open(source) as img:
            return img.convert("RGB")

    @staticmethod
    def _remove_temp_file(path):
        """Delete a temporary download, best effort.

        A file that is already gone is ignored; any other OS error is logged
        rather than raised so cleanup never masks the request's own outcome.
        """
        try:
            os.unlink(path)
        except FileNotFoundError:
            pass
        except OSError as e:
            logger.warning("Failed to delete temp file %s: %s", path, e)

    async def remove_background(self, image_path):
        """Remove the background from a local image file or an image URL.

        Args:
            image_path: Local file path, or a URL with scheme and host, which
                is downloaded to a temporary file first.

        Returns:
            dict with ``"status": "success"`` and ``"image_url"`` of the saved PNG.

        Raises:
            Exception: "下载图片失败: ..." when the download fails.
            FileNotFoundError: when the (local) image path does not exist.
        """
        temp_file = None
        try:
            if self.is_valid_url(image_path):
                # download_image already wraps failures as "下载图片失败: ...";
                # wrapping again here would duplicate that prefix.
                temp_file = await self.download_image(image_path)
                image_path = temp_file
            # NOTE(review): non-URL input is opened as a server-side path; confirm
            # this is intended if image_path comes from API clients.
            if not os.path.exists(image_path):
                raise FileNotFoundError(f"输入图像不存在: {image_path}")
            loop = asyncio.get_running_loop()
            image = await loop.run_in_executor(self.executor, self._open_rgb, image_path)
            image_no_bg = await self.process_image(image)
            image_url = await loop.run_in_executor(
                self.executor, self.save_image_to_file, image_no_bg
            )
            return {
                "status": "success",
                "image_url": image_url
            }
        finally:
            if temp_file:
                self._remove_temp_file(temp_file)

    async def remove_background_from_file(self, file_content):
        """Remove the background from an uploaded image given as raw bytes.

        Args:
            file_content: Encoded image bytes (any format PIL can open).

        Returns:
            dict with ``"status": "success"`` and ``"image_url"`` of the saved PNG.

        Raises:
            Exception: "处理图片失败: ..." wrapping any decode, inference or save
                error; the original error is kept as ``__cause__``.
        """
        try:
            loop = asyncio.get_running_loop()
            image = await loop.run_in_executor(
                self.executor, self._open_rgb, io.BytesIO(file_content)
            )
            image_no_bg = await self.process_image(image)
            image_url = await loop.run_in_executor(
                self.executor, self.save_image_to_file, image_no_bg
            )
        except Exception as e:
            raise Exception(f"处理图片失败: {e}") from e
        return {
            "status": "success",
            "image_url": image_url
        }

    async def process_batch(self, urls):
        """Process many images concurrently, yielding results as they finish.

        Each entry is downloaded (if it is a URL) or opened as a local path,
        processed and saved. All entries are scheduled at once; inference is
        bounded by the executor's worker count.

        Args:
            urls: Sized iterable of URLs or local paths; each entry is passed
                through ``str()``.

        Yields:
            One dict per entry in completion order (not input order), with
            ``index`` (1-based input position), ``total``, ``original_url``,
            ``status`` ("success"/"error"), ``message``, either ``image_url``
            or ``error``, plus running ``success_count``, ``error_count``,
            ``completed_order`` and ``batch_elapsed`` (seconds, 2 decimals).
        """
        total = len(urls)
        success_count = 0
        error_count = 0
        batch_start_time = time.monotonic()
        loop = asyncio.get_running_loop()

        async def open_entry(url_str):
            """Open a local path, or download a URL to a temp file and open that."""
            # NOTE(review): non-URL entries are opened as server-side paths;
            # confirm that is intended if entries come from API clients.
            if not self.is_valid_url(url_str):
                return await loop.run_in_executor(self.executor, self._open_rgb, url_str)
            temp_file = await self.download_image(url_str)
            try:
                return await loop.run_in_executor(self.executor, self._open_rgb, temp_file)
            finally:
                # Runs even when decoding fails, so bad images leave no temp files.
                self._remove_temp_file(temp_file)

        async def download_and_process(index, url):
            """Handle one entry; exceptions while processing become an "error" result."""
            url_str = str(url)
            try:
                image = await open_entry(url_str)
                processed_image = await self.process_image(image)
                image_url = await loop.run_in_executor(
                    self.executor, self.save_image_to_file, processed_image
                )
                return {
                    "index": index,
                    "total": total,
                    "original_url": url_str,
                    "status": "success",
                    "image_url": image_url,
                    "message": "处理成功"
                }
            except Exception as e:
                logger.error("处理失败 (index=%s): %s", index, e)
                return {
                    "index": index,
                    "total": total,
                    "original_url": url_str,
                    "status": "error",
                    "error": str(e),
                    "message": f"处理失败: {str(e)}"
                }

        tasks = [
            asyncio.ensure_future(download_and_process(index, url))
            for index, url in enumerate(urls, 1)
        ]
        try:
            for completed_order, next_done in enumerate(asyncio.as_completed(tasks), 1):
                result = await next_done
                if result["status"] == "success":
                    success_count += 1
                else:
                    error_count += 1
                result["success_count"] = success_count
                result["error_count"] = error_count
                result["completed_order"] = completed_order
                result["batch_elapsed"] = round(time.monotonic() - batch_start_time, 2)
                yield result
        finally:
            # If the consumer stops iterating early, don't leave the remaining
            # downloads/inferences running unobserved (no-op for finished tasks).
            for task in tasks:
                task.cancel()

    def is_valid_url(self, url):
        """Return True when *url* parses with both a scheme and a host part.

        Only the presence of both parts is checked, not which scheme it is.
        Input that ``urlparse`` rejects (e.g. a malformed IPv6 host or an
        unsupported type) yields False.
        """
        try:
            result = urlparse(url)
        except (AttributeError, TypeError, ValueError):
            return False
        return bool(result.scheme and result.netloc)

    @staticmethod
    def _write_temp_file(content):
        """Write *content* to a new ``.png`` temp file and return its path.

        The file is created with ``delete=False`` so it outlives the handle;
        if writing fails it is removed before the error propagates.
        """
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
        try:
            with temp_file:
                temp_file.write(content)
        except Exception:
            try:
                os.unlink(temp_file.name)
            except OSError:
                pass
            raise
        return temp_file.name

    async def download_image(self, url):
        """Download *url* into a temporary ``.png`` file and return its path.

        The caller owns the returned file and must delete it. The ``.png``
        suffix is cosmetic; PIL identifies the real format from the content.

        Raises:
            Exception: "下载图片失败: ..." wrapping any HTTP, status or write
                error; the original error is kept as ``__cause__``.
        """
        # NOTE(review): fetches any caller-supplied URL and buffers the whole
        # body in memory with no size cap; if URLs come from untrusted clients,
        # consider host allow-listing (SSRF) and a maximum response size.
        try:
            response = await self.http_client.get(url)
            response.raise_for_status()
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(
                self.executor, self._write_temp_file, response.content
            )
        except Exception as e:
            raise Exception(f"下载图片失败: {e}") from e

    async def cleanup(self):
        """Release the HTTP client, worker threads and cached GPU memory.

        The executor is shut down even if closing the HTTP client fails.
        """
        try:
            await self.http_client.aclose()
        finally:
            # shutdown(wait=True) blocks until submitted jobs finish; run it on
            # the loop's default executor so the event loop itself stays free.
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, lambda: self.executor.shutdown(wait=True))
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()