japi/apps/rmbg/service.py

299 lines
9.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import tempfile
from urllib.parse import urlparse
from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
import time
import warnings
import gc
import base64
import asyncio
import io
import uuid
import httpx
from settings import settings
# 关闭不必要的警告
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
# 设置torch精度
torch.set_float32_matmul_precision("high")
class RmbgService:
def __init__(self, model_path="zhengpeng7/BiRefNet"):
"""初始化背景移除服务"""
self.model_path = model_path
self.model = None
self.device = None
self.save_dir = settings.save_dir
self.download_url = settings.download_url
# 确保保存目录存在
os.makedirs(self.save_dir, exist_ok=True)
# 创建异步HTTP客户端复用连接提高性能
self.http_client = httpx.AsyncClient(timeout=30.0, limits=httpx.Limits(max_keepalive_connections=20))
self._load_model()
def _load_model(self):
"""加载模型"""
# 设置设备
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
t0 = time.time()
self.model = AutoModelForImageSegmentation.from_pretrained(self.model_path, trust_remote_code=True)
self.model = self.model.to(self.device)
self.model.eval() # 设置为评估模式
def _process_image_sync(self, image):
"""同步处理图像,移除背景(内部方法,在线程池中执行)"""
image_size = image.size
# 转换图像
transform_image = transforms.Compose([
transforms.Resize((1024, 1024)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
input_images = transform_image(image).unsqueeze(0).to(self.device)
# 推理
with torch.no_grad():
preds = self.model(input_images)[-1].sigmoid().cpu()
# 处理预测结果
pred = preds[0].squeeze()
pred_pil = transforms.ToPILImage()(pred)
mask = pred_pil.resize(image_size)
# 添加透明通道
image.putalpha(mask)
# 清理显存
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
return image
async def process_image(self, image):
"""异步处理图像,移除背景(在线程池中执行同步操作)"""
# 将同步的GPU操作放到线程池中执行避免阻塞事件循环
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self._process_image_sync, image)
def image_to_base64(self, image):
"""将PIL Image对象转换为base64字符串"""
buffered = io.BytesIO()
image.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode('utf-8')
def save_image_to_file(self, image):
"""
保存图片到jfile/files目录并返回URL
Args:
image: PIL Image对象
Returns:
图片URL
"""
# 生成唯一文件名
filename = f"rmbg_{uuid.uuid4().hex[:10]}.png"
file_path = os.path.join(self.save_dir, filename)
# 保存图片
image.save(file_path, format="PNG")
# 构建URL
image_url = f"{self.download_url}/{filename}"
return image_url
async def remove_background(self, image_path):
"""
移除图像背景
Args:
image_path: 输入图像的路径或URL
Returns:
处理后的图像内容
"""
temp_file = None
try:
# 检查是否是URL
if self.is_valid_url(image_path):
try:
# 异步下载图片到临时文件
temp_file = await self.download_image(image_path)
image_path = temp_file
except Exception as e:
raise Exception(f"下载图片失败: {e}")
# 验证输入文件是否存在
if not os.path.exists(image_path):
raise FileNotFoundError(f"输入图像不存在: {image_path}")
# 加载图像IO操作在线程池中执行
loop = asyncio.get_event_loop()
image = await loop.run_in_executor(
None,
lambda: Image.open(image_path).convert("RGB")
)
# 异步处理图像
image_no_bg = await self.process_image(image)
# 保存图片到文件并获取URLIO操作在线程池中执行
image_url = await loop.run_in_executor(
None,
self.save_image_to_file,
image_no_bg
)
return {
"status": "success",
"image_url": image_url
}
finally:
# 清理临时文件
if temp_file and os.path.exists(temp_file):
try:
os.unlink(temp_file)
except:
pass
async def remove_background_from_file(self, file_content):
"""
从上传的文件内容移除背景
Args:
file_content: 上传的文件内容
Returns:
处理后的图像内容
"""
try:
# 从文件内容创建PIL Image对象IO操作在线程池中执行
loop = asyncio.get_event_loop()
image = await loop.run_in_executor(
None,
lambda: Image.open(io.BytesIO(file_content)).convert("RGB")
)
# 异步处理图像
image_no_bg = await self.process_image(image)
# 保存图片到文件并获取URLIO操作在线程池中执行
image_url = await loop.run_in_executor(
None,
self.save_image_to_file,
image_no_bg
)
return {
"status": "success",
"image_url": image_url
}
except Exception as e:
raise Exception(f"处理图片失败: {e}")
async def process_batch(self, urls):
"""
批量处理多个URL图像并发处理并流式返回结果
Args:
urls: 图片URL列表
Yields:
每个图片的处理结果(按完成顺序返回)
"""
total = len(urls)
success_count = 0
error_count = 0
# 创建并发任务
async def process_single_url(index, url):
"""处理单个URL的包装函数"""
try:
url_str = str(url)
result = await self.remove_background(url_str)
return {
"index": index,
"total": total,
"original_url": url_str,
"status": "success",
"image_url": result["image_url"],
"message": "处理成功"
}
except Exception as e:
return {
"index": index,
"total": total,
"original_url": str(url),
"status": "error",
"error": str(e),
"message": f"处理失败: {str(e)}"
}
# 创建所有任务
tasks = [
process_single_url(i, url)
for i, url in enumerate(urls, 1)
]
# 并发执行所有任务使用as_completed按完成顺序返回
for coro in asyncio.as_completed(tasks):
result = await coro
if result["status"] == "success":
success_count += 1
else:
error_count += 1
# 更新统计信息
result["success_count"] = success_count
result["error_count"] = error_count
yield result
def is_valid_url(self, url):
"""验证URL是否有效"""
try:
result = urlparse(url)
return all([result.scheme, result.netloc])
except:
return False
async def download_image(self, url):
"""异步从URL下载图片到临时文件"""
try:
response = await self.http_client.get(url)
response.raise_for_status()
# 创建临时文件并写入内容IO操作在线程池中执行
def write_temp_file(content):
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
temp_file.write(content)
temp_file.close()
return temp_file.name
loop = asyncio.get_event_loop()
temp_file_path = await loop.run_in_executor(
None,
write_temp_file,
response.content
)
return temp_file_path
except Exception as e:
raise Exception(f"下载图片失败: {e}")
async def cleanup(self):
"""清理资源"""
# 关闭HTTP客户端
await self.http_client.aclose()
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
print("资源已清理")