272 lines
9.0 KiB
Python
272 lines
9.0 KiB
Python
import os
|
||
import tempfile
|
||
from urllib.parse import urlparse
|
||
from PIL import Image
|
||
import torch
|
||
from torchvision import transforms
|
||
from transformers import AutoModelForImageSegmentation
|
||
import time
|
||
import warnings
|
||
import gc
|
||
import asyncio
|
||
import io
|
||
import uuid
|
||
import httpx
|
||
import logging
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
from settings import settings
|
||
|
||
# Module-wide logging: INFO level, timestamped "time - logger - level - message" lines.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

# Suppress UserWarning and FutureWarning process-wide (order of insertion
# matches the original two filterwarnings calls).
for _ignored_category in (UserWarning, FutureWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)
del _ignored_category

# Let PyTorch use faster, slightly lower-precision float32 matmul kernels
# (e.g. TF32 where the hardware supports it).
torch.set_float32_matmul_precision("high")
|
||
|
||
class RmbgService:
    """Background-removal service built on an image-segmentation model.

    The model is loaded once at construction time. The async entry points
    offload blocking work (image decoding, model inference, PNG encoding,
    temp-file writes) to a private thread pool. Results are saved as PNG files
    under ``settings.save_dir`` and returned as URLs rooted at
    ``settings.download_url``.
    """

    def __init__(self, model_path="zhengpeng7/BiRefNet"):
        """Initialise the background-removal service.

        Args:
            model_path: Model identifier or local path passed to
                ``AutoModelForImageSegmentation.from_pretrained``.
        """
        self.model_path = model_path
        self.model = None
        self.device = None
        self.save_dir = settings.save_dir
        self.download_url = settings.download_url
        os.makedirs(self.save_dir, exist_ok=True)

        # Shared, pooled HTTP client used to download input images.
        self.http_client = httpx.AsyncClient(
            timeout=30.0,
            limits=httpx.Limits(
                max_keepalive_connections=50,
                max_connections=100,
            ),
        )
        # All blocking work goes through this pool, so its size bounds how
        # many images are decoded / inferred / saved at the same time.
        self.executor = ThreadPoolExecutor(max_workers=settings.max_workers)
        self._load_model()

    def _load_model(self):
        """Load the segmentation model onto CUDA if available, else CPU, in eval mode."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = AutoModelForImageSegmentation.from_pretrained(self.model_path, trust_remote_code=True)
        self.model = self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _open_rgb(source):
        """Decode *source* (a path or binary file object) into an RGB image.

        The image is opened in a ``with`` block so the underlying file handle
        is released immediately. ``convert`` returns a fully loaded,
        independent copy, so the result stays valid after the source closes.
        """
        with Image.open(source) as img:
            return img.convert("RGB")

    @staticmethod
    def _remove_temp_file(path):
        """Best-effort deletion of a temporary download; failures are logged, not raised."""
        try:
            os.unlink(path)
        except FileNotFoundError:
            pass
        except OSError as exc:
            logger.warning("Failed to remove temp file %s: %s", path, exc)

    @staticmethod
    def _write_temp_file(content):
        """Write *content* to a new ``.png`` temp file and return its path.

        The file is removed again if writing fails, so no partial file is
        left behind.
        """
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        try:
            with tmp:
                tmp.write(content)
        except BaseException:
            os.unlink(tmp.name)
            raise
        return tmp.name

    def _process_image_sync(self, image):
        """Remove the background of *image* synchronously (runs in a worker thread).

        The image is resized to 1024x1024 and normalised with the ImageNet
        mean/std, then run through the model. The sigmoid of the model's last
        output is resized back to the original size and attached as the alpha
        channel.

        Args:
            image: A PIL image; callers in this class pass RGB images.

        Returns:
            The same image object, mutated in place by ``putalpha`` to carry
            the predicted mask as alpha.
        """
        image_size = image.size
        transform_image = transforms.Compose([
            transforms.Resize((1024, 1024)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        input_images = transform_image(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            # NOTE(review): [-1] presumably selects the final-stage prediction
            # of the model's output list; confirm against the model's docs.
            preds = self.model(input_images)[-1].sigmoid().cpu()

        pred = preds[0].squeeze()
        pred_pil = transforms.ToPILImage()(pred)
        mask = pred_pil.resize(image_size)
        image.putalpha(mask)

        # Release cached GPU memory and Python garbage after every image.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        return image

    async def process_image(self, image):
        """Asynchronously remove the background by running inference in the thread pool."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self.executor, self._process_image_sync, image)

    def save_image_to_file(self, image):
        """Save *image* as a uniquely named PNG in ``save_dir`` and return its public URL."""
        filename = f"rmbg_{uuid.uuid4().hex[:10]}.png"
        file_path = os.path.join(self.save_dir, filename)
        image.save(file_path, format="PNG")
        image_url = f"{self.download_url}/{filename}"
        return image_url

    async def remove_background(self, image_path):
        """Remove the background of an image given by local path or URL.

        Args:
            image_path: Local file path or http(s) URL of the input image.
                Values that are not str/bytes/path-like (e.g. URL objects) are
                converted with ``str``.

        Returns:
            dict: ``{"status": "success", "image_url": <URL of saved PNG>}``.

        Raises:
            FileNotFoundError: If a non-URL path does not exist.
            Exception: If downloading (message prefixed "下载图片失败"),
                decoding, or processing fails.
        """
        if not isinstance(image_path, (str, bytes, os.PathLike)):
            image_path = str(image_path)
        temp_file = None
        try:
            if self.is_valid_url(image_path):
                # download_image already raises with the "下载图片失败" prefix;
                # re-wrapping here used to duplicate it.
                temp_file = await self.download_image(image_path)
                image_path = temp_file

            # NOTE(review): non-URL inputs are read from the local filesystem.
            # If image_path can come from untrusted callers, restrict it to an
            # allowed directory.
            if not os.path.exists(image_path):
                raise FileNotFoundError(f"输入图像不存在: {image_path}")

            loop = asyncio.get_running_loop()
            image = await loop.run_in_executor(self.executor, self._open_rgb, image_path)
            image_no_bg = await self.process_image(image)
            image_url = await loop.run_in_executor(
                self.executor, self.save_image_to_file, image_no_bg
            )
            return {
                "status": "success",
                "image_url": image_url
            }
        finally:
            if temp_file:
                self._remove_temp_file(temp_file)

    async def remove_background_from_file(self, file_content):
        """Remove the background of an uploaded image.

        Args:
            file_content: Raw bytes of the uploaded image file.

        Returns:
            dict: ``{"status": "success", "image_url": <URL of saved PNG>}``.

        Raises:
            Exception: Any failure, re-raised with a "处理图片失败" prefix and
                the original exception chained as ``__cause__``.
        """
        try:
            loop = asyncio.get_running_loop()
            image = await loop.run_in_executor(
                self.executor, self._open_rgb, io.BytesIO(file_content)
            )
            image_no_bg = await self.process_image(image)
            image_url = await loop.run_in_executor(
                self.executor, self.save_image_to_file, image_no_bg
            )
            return {
                "status": "success",
                "image_url": image_url
            }
        except Exception as e:
            raise Exception(f"处理图片失败: {e}") from e

    async def process_batch(self, urls):
        """Process many images concurrently, yielding each result as it completes.

        Every item is downloaded (or read from a local path), processed and
        saved independently. Failures are reported as ``"status": "error"``
        results instead of aborting the batch. Each yielded dict also carries
        running ``success_count`` / ``error_count``, its ``completed_order``
        and the seconds elapsed since the batch started. If the consumer stops
        iterating early, unfinished items are cancelled.

        Args:
            urls: Sized sequence of URLs / local paths (converted with ``str``).
        """
        total = len(urls)
        success_count = 0
        error_count = 0
        batch_start_time = time.time()
        loop = asyncio.get_running_loop()

        async def download_and_process(index, url):
            """Download (or open) and process one image; never raises ``Exception``."""
            url_str = str(url)
            try:
                if self.is_valid_url(url_str):
                    temp_file = await self.download_image(url_str)
                    try:
                        image = await loop.run_in_executor(
                            self.executor, self._open_rgb, temp_file
                        )
                    finally:
                        # Remove the download even if decoding fails.
                        self._remove_temp_file(temp_file)
                else:
                    # NOTE(review): local paths are read as-is; restrict them
                    # if batch input can come from untrusted callers.
                    image = await loop.run_in_executor(
                        self.executor, self._open_rgb, url_str
                    )

                processed_image = await self.process_image(image)
                image_url = await loop.run_in_executor(
                    self.executor, self.save_image_to_file, processed_image
                )
                return {
                    "index": index,
                    "total": total,
                    "original_url": url_str,
                    "status": "success",
                    "image_url": image_url,
                    "message": "处理成功"
                }
            except Exception as e:
                logger.error("处理失败 (index=%s): %s", index, e)
                return {
                    "index": index,
                    "total": total,
                    "original_url": url_str,
                    "status": "error",
                    "error": str(e),
                    "message": f"处理失败: {str(e)}"
                }

        tasks = [
            loop.create_task(download_and_process(i, url))
            for i, url in enumerate(urls, 1)
        ]
        try:
            for completed_order, next_done in enumerate(asyncio.as_completed(tasks), 1):
                result = await next_done
                if result["status"] == "success":
                    success_count += 1
                else:
                    error_count += 1

                result["success_count"] = success_count
                result["error_count"] = error_count
                result["completed_order"] = completed_order
                result["batch_elapsed"] = round(time.time() - batch_start_time, 2)
                yield result
        finally:
            # Consumer stopped early (e.g. client disconnect): don't leave
            # orphaned tasks downloading / processing in the background.
            for task in tasks:
                if not task.done():
                    task.cancel()

    def is_valid_url(self, url):
        """Return True if *url* parses with both a scheme and a network location."""
        try:
            result = urlparse(url)
            return all([result.scheme, result.netloc])
        except Exception:
            return False

    async def download_image(self, url):
        """Download *url* into a temporary ``.png`` file and return the file's path.

        The caller owns the returned file and must delete it.

        Raises:
            Exception: Any download or write failure, with a "下载图片失败"
                prefix and the original exception chained as ``__cause__``.
        """
        try:
            response = await self.http_client.get(url)
            response.raise_for_status()
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(
                self.executor, self._write_temp_file, response.content
            )
        except Exception as e:
            raise Exception(f"下载图片失败: {e}") from e

    async def cleanup(self):
        """Close the HTTP client, stop the worker pool and release cached memory."""
        await self.http_client.aclose()
        # Waiting for in-flight jobs can take a while; wait off the event loop.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, lambda: self.executor.shutdown(wait=True))
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()