japi/apps/rmbg/service.py
2025-05-12 02:39:56 +08:00

225 lines
7.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import requests
import tempfile
from urllib.parse import urlparse
from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
import time
import warnings
import gc
import base64
import asyncio
import io
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor
# Ignore UserWarning and FutureWarning for the whole process (these filters
# are global, not scoped to this module). They are installed in the same
# order as two consecutive filterwarnings() calls would install them.
for _ignored_category in (UserWarning, FutureWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)
del _ignored_category  # keep the module namespace free of the loop variable

# Allow float32 matrix multiplications to use faster reduced-precision
# internal kernels (e.g. TF32) where the hardware supports them.
torch.set_float32_matmul_precision("high")
class RmbgService:
    """Background-removal service built on an image-segmentation model.

    The model is loaded once at construction time and reused for every
    request. Results are returned as base64-encoded PNG strings in which the
    predicted foreground mask has been applied as the alpha channel.
    """

    def __init__(self, model_path="zhengpeng7/BiRefNet"):
        """Create the service and eagerly load the segmentation model.

        Args:
            model_path: Model id or local directory passed to
                ``AutoModelForImageSegmentation.from_pretrained``.
        """
        self.model_path = model_path
        self.model = None
        self.device = None
        self._load_model()

    def _load_model(self):
        """Load the model onto CUDA when available (otherwise CPU) in eval mode."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # NOTE(review): trust_remote_code=True executes Python code shipped with
        # the model repository; only point model_path at sources you trust.
        self.model = AutoModelForImageSegmentation.from_pretrained(
            self.model_path, trust_remote_code=True
        )
        self.model = self.model.to(self.device)
        self.model.eval()  # evaluation behaviour for dropout / batch-norm layers

    def process_image(self, image):
        """Predict a foreground mask for ``image`` and attach it as alpha.

        The image is resized to 1024x1024 and normalized with the ImageNet
        mean/std before inference; the predicted mask is resized back to the
        image's original size.

        Args:
            image: RGB ``PIL.Image.Image``. It is modified in place
                (``putalpha`` adds the alpha channel).

        Returns:
            The same image object, now carrying the mask as its alpha channel.
        """
        original_size = image.size
        transform_image = transforms.Compose([
            transforms.Resize((1024, 1024)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        # Add a leading batch dimension of 1 and move to the model's device.
        input_images = transform_image(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            # Assumes the model returns a sequence of prediction maps whose last
            # element is the final one (BiRefNet convention) - confirm if
            # model_path is changed to a different architecture.
            preds = self.model(input_images)[-1].sigmoid().cpu()

        # Take the first (only) batch item and drop singleton dims, presumably
        # leaving a 2-D (H, W) probability map - TODO confirm for other models.
        pred = preds[0].squeeze()
        mask = transforms.ToPILImage()(pred).resize(original_size)
        image.putalpha(mask)

        # Release cached GPU memory and Python garbage after every image.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        return image

    def image_to_base64(self, image):
        """Encode a PIL image as PNG and return it as a base64 ``str``."""
        buffered = io.BytesIO()
        image.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode('utf-8')

    @staticmethod
    def _open_rgb(source):
        """Open an image path or file-like object and return an RGB copy.

        The image is fully loaded by ``convert`` before the context manager
        closes the underlying file, so no file handle is left open (which also
        lets temporary files be deleted right away on Windows).
        """
        with Image.open(source) as img:
            return img.convert("RGB")

    async def remove_background(self, image_path):
        """Remove the background from an image given by local path or URL.

        URLs (anything with a scheme and a network location) are first
        downloaded to a temporary file, which is always deleted afterwards.

        NOTE(review): despite being ``async``, this method performs the download
        and model inference synchronously and therefore blocks the event loop
        for the duration of the call.
        NOTE(review): non-URL input is opened as a server-side filesystem path
        and URLs are fetched as given; if ``image_path`` comes from untrusted
        clients, consider restricting both.

        Args:
            image_path: Local file path or image URL.

        Returns:
            ``{"status": "success", "image_content": <base64 PNG str>}``.

        Raises:
            Exception: if downloading a URL fails (the original error is
                chained as ``__cause__``).
            FileNotFoundError: if the local path does not exist.
        """
        temp_file = None
        try:
            if self.is_valid_url(image_path):
                try:
                    temp_file = self.download_image(image_path)
                except Exception as e:
                    raise Exception(f"下载图片失败: {e}") from e
                image_path = temp_file

            if not os.path.exists(image_path):
                raise FileNotFoundError(f"输入图像不存在: {image_path}")

            image = self._open_rgb(image_path)
            image_no_bg = self.process_image(image)
            image_content = self.image_to_base64(image_no_bg)
            return {
                "status": "success",
                "image_content": image_content
            }
        finally:
            # Best-effort removal of the downloaded temporary file.
            if temp_file and os.path.exists(temp_file):
                try:
                    os.unlink(temp_file)
                except OSError:
                    pass

    async def remove_background_from_file(self, file_content):
        """Remove the background from raw uploaded image bytes.

        Args:
            file_content: Encoded image bytes (any format PIL can decode).

        Returns:
            ``{"status": "success", "image_content": <base64 PNG str>}``.

        Raises:
            Exception: wrapping any decode/inference/encode failure; the
                original error is chained as ``__cause__``.
        """
        try:
            image = self._open_rgb(io.BytesIO(file_content))
            image_no_bg = self.process_image(image)
            image_content = self.image_to_base64(image_no_bg)
            return {
                "status": "success",
                "image_content": image_content
            }
        except Exception as e:
            raise Exception(f"处理图片失败: {e}") from e

    async def process_batch(self, urls):
        """Process image URLs one at a time, yielding one result per URL.

        A failure for one URL is reported in its own result dict and does not
        stop the batch.

        Args:
            urls: Iterable with ``len()`` of URLs (each converted with ``str``).

        Yields:
            Dicts with ``index`` (1-based), ``total``, ``original_url``,
            ``status`` ("success"/"error"), running ``success_count`` and
            ``error_count``, ``message``, plus ``image_content`` on success or
            ``error`` on failure.
        """
        total = len(urls)
        success_count = 0
        error_count = 0
        for i, url in enumerate(urls, 1):
            try:
                url_str = str(url)
                result = await self.remove_background(url_str)
                success_count += 1
                yield {
                    "index": i,
                    "total": total,
                    "original_url": url_str,
                    "status": "success",
                    "image_content": result["image_content"],
                    "success_count": success_count,
                    "error_count": error_count,
                    "message": "处理成功"
                }
            except Exception as e:
                error_count += 1
                yield {
                    "index": i,
                    "total": total,
                    "original_url": str(url),
                    "status": "error",
                    "error": str(e),
                    "success_count": success_count,
                    "error_count": error_count,
                    "message": f"处理失败: {str(e)}"
                }
            # Yield control to the event loop between items.
            await asyncio.sleep(0)

    def is_valid_url(self, url):
        """Return True if ``url`` parses with both a scheme and a netloc."""
        try:
            result = urlparse(url)
        except (TypeError, ValueError, AttributeError):
            # Non-string input or a malformed URL (e.g. bad IPv6 netloc).
            return False
        return all([result.scheme, result.netloc])

    def download_image(self, url, timeout=30):
        """Download ``url`` into a new temporary ``.png`` file.

        Args:
            url: Image URL.
            timeout: Seconds passed to ``requests.get``; bounds the connect and
                per-read waits so a stalled server cannot hang the call forever.

        Returns:
            Path of the temporary file. The caller owns it and must delete it.

        Raises:
            requests.RequestException: on network errors or a non-2xx status.
            OSError: if the temporary file cannot be written.
            Any partially written temporary file is removed before an error
            propagates.
        """
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            tmp = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
            try:
                with tmp:
                    for chunk in response.iter_content(chunk_size=8192):
                        tmp.write(chunk)
            except BaseException:
                # The caller never receives the path on failure, so clean up here.
                try:
                    os.unlink(tmp.name)
                except OSError:
                    pass
                raise
        return tmp.name

    def cleanup(self):
        """Release cached GPU memory and run the garbage collector."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        print("资源已清理")