重构rmbg支持异步并发任务

This commit is contained in:
jingrow 2025-11-23 03:23:44 +08:00
parent c99e20ff61
commit 474ce6f5db
10 changed files with 97 additions and 50 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 905 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 761 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 732 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 766 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 696 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 845 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 753 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 745 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 914 KiB

View File

@ -1,5 +1,4 @@
import os import os
import requests
import tempfile import tempfile
from urllib.parse import urlparse from urllib.parse import urlparse
from PIL import Image from PIL import Image
@ -12,9 +11,8 @@ import gc
import base64 import base64
import asyncio import asyncio
import io import io
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor
import uuid import uuid
import httpx
from settings import settings from settings import settings
# 关闭不必要的警告 # 关闭不必要的警告
@ -34,6 +32,8 @@ class RmbgService:
self.download_url = settings.download_url self.download_url = settings.download_url
# 确保保存目录存在 # 确保保存目录存在
os.makedirs(self.save_dir, exist_ok=True) os.makedirs(self.save_dir, exist_ok=True)
# 创建异步HTTP客户端复用连接提高性能
self.http_client = httpx.AsyncClient(timeout=30.0, limits=httpx.Limits(max_keepalive_connections=20))
self._load_model() self._load_model()
def _load_model(self): def _load_model(self):
@ -45,11 +45,10 @@ class RmbgService:
self.model = self.model.to(self.device) self.model = self.model.to(self.device)
self.model.eval() # 设置为评估模式 self.model.eval() # 设置为评估模式
def process_image(self, image): def _process_image_sync(self, image):
"""处理图像,移除背景""" """同步处理图像,移除背景(内部方法,在线程池中执行)"""
image_size = image.size image_size = image.size
# 转换图像 # 转换图像
t0 = time.time()
transform_image = transforms.Compose([ transform_image = transforms.Compose([
transforms.Resize((1024, 1024)), transforms.Resize((1024, 1024)),
transforms.ToTensor(), transforms.ToTensor(),
@ -58,12 +57,10 @@ class RmbgService:
input_images = transform_image(image).unsqueeze(0).to(self.device) input_images = transform_image(image).unsqueeze(0).to(self.device)
# 推理 # 推理
t0 = time.time()
with torch.no_grad(): with torch.no_grad():
preds = self.model(input_images)[-1].sigmoid().cpu() preds = self.model(input_images)[-1].sigmoid().cpu()
# 处理预测结果 # 处理预测结果
t0 = time.time()
pred = preds[0].squeeze() pred = preds[0].squeeze()
pred_pil = transforms.ToPILImage()(pred) pred_pil = transforms.ToPILImage()(pred)
mask = pred_pil.resize(image_size) mask = pred_pil.resize(image_size)
@ -77,6 +74,12 @@ class RmbgService:
gc.collect() gc.collect()
return image return image
async def process_image(self, image):
"""异步处理图像,移除背景(在线程池中执行同步操作)"""
# 将同步的GPU操作放到线程池中执行避免阻塞事件循环
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self._process_image_sync, image)
def image_to_base64(self, image): def image_to_base64(self, image):
"""将PIL Image对象转换为base64字符串""" """将PIL Image对象转换为base64字符串"""
@ -120,8 +123,8 @@ class RmbgService:
# 检查是否是URL # 检查是否是URL
if self.is_valid_url(image_path): if self.is_valid_url(image_path):
try: try:
# 下载图片到临时文件 # 异步下载图片到临时文件
temp_file = self.download_image(image_path) temp_file = await self.download_image(image_path)
image_path = temp_file image_path = temp_file
except Exception as e: except Exception as e:
raise Exception(f"下载图片失败: {e}") raise Exception(f"下载图片失败: {e}")
@ -130,12 +133,22 @@ class RmbgService:
if not os.path.exists(image_path): if not os.path.exists(image_path):
raise FileNotFoundError(f"输入图像不存在: {image_path}") raise FileNotFoundError(f"输入图像不存在: {image_path}")
# 加载并处理图像 # 加载图像IO操作在线程池中执行
image = Image.open(image_path).convert("RGB") loop = asyncio.get_event_loop()
image_no_bg = self.process_image(image) image = await loop.run_in_executor(
None,
lambda: Image.open(image_path).convert("RGB")
)
# 保存图片到文件并获取URL # 异步处理图像
image_url = self.save_image_to_file(image_no_bg) image_no_bg = await self.process_image(image)
# 保存图片到文件并获取URLIO操作在线程池中执行
image_url = await loop.run_in_executor(
None,
self.save_image_to_file,
image_no_bg
)
return { return {
"status": "success", "status": "success",
@ -161,12 +174,22 @@ class RmbgService:
处理后的图像内容 处理后的图像内容
""" """
try: try:
# 从文件内容创建PIL Image对象 # 从文件内容创建PIL Image对象IO操作在线程池中执行
image = Image.open(io.BytesIO(file_content)).convert("RGB") loop = asyncio.get_event_loop()
image_no_bg = self.process_image(image) image = await loop.run_in_executor(
None,
lambda: Image.open(io.BytesIO(file_content)).convert("RGB")
)
# 保存图片到文件并获取URL # 异步处理图像
image_url = self.save_image_to_file(image_no_bg) image_no_bg = await self.process_image(image)
# 保存图片到文件并获取URLIO操作在线程池中执行
image_url = await loop.run_in_executor(
None,
self.save_image_to_file,
image_no_bg
)
return { return {
"status": "success", "status": "success",
@ -178,51 +201,61 @@ class RmbgService:
async def process_batch(self, urls): async def process_batch(self, urls):
""" """
批量处理多个URL图像流式返回结果 批量处理多个URL图像并发处理并流式返回结果
Args: Args:
urls: 图片URL列表 urls: 图片URL列表
Yields: Yields:
每个图片的处理结果 每个图片的处理结果按完成顺序返回
""" """
total = len(urls) total = len(urls)
success_count = 0 success_count = 0
error_count = 0 error_count = 0
for i, url in enumerate(urls, 1): # 创建并发任务
async def process_single_url(index, url):
"""处理单个URL的包装函数"""
try: try:
url_str = str(url) url_str = str(url)
result = await self.remove_background(url_str) result = await self.remove_background(url_str)
success_count += 1 return {
"index": index,
# 确保返回正确的数据格式
yield {
"index": i,
"total": total, "total": total,
"original_url": url_str, "original_url": url_str,
"status": "success", "status": "success",
"image_url": result["image_url"], "image_url": result["image_url"],
"success_count": success_count,
"error_count": error_count,
"message": "处理成功" "message": "处理成功"
} }
except Exception as e: except Exception as e:
error_count += 1 return {
yield { "index": index,
"index": i,
"total": total, "total": total,
"original_url": str(url), "original_url": str(url),
"status": "error", "status": "error",
"error": str(e), "error": str(e),
"success_count": success_count,
"error_count": error_count,
"message": f"处理失败: {str(e)}" "message": f"处理失败: {str(e)}"
} }
# 创建所有任务
tasks = [
process_single_url(i, url)
for i, url in enumerate(urls, 1)
]
# 并发执行所有任务使用as_completed按完成顺序返回
for coro in asyncio.as_completed(tasks):
result = await coro
if result["status"] == "success":
success_count += 1
else:
error_count += 1
# 让出控制权,避免阻塞 # 更新统计信息
await asyncio.sleep(0) result["success_count"] = success_count
result["error_count"] = error_count
yield result
def is_valid_url(self, url): def is_valid_url(self, url):
"""验证URL是否有效""" """验证URL是否有效"""
@ -232,20 +265,34 @@ class RmbgService:
except: except:
return False return False
def download_image(self, url): async def download_image(self, url):
"""从URL下载图片到临时文件""" """异步从URL下载图片到临时文件"""
response = requests.get(url, stream=True) try:
response.raise_for_status() response = await self.http_client.get(url)
response.raise_for_status()
# 创建临时文件
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png') # 创建临时文件并写入内容IO操作在线程池中执行
with open(temp_file.name, 'wb') as f: def write_temp_file(content):
for chunk in response.iter_content(chunk_size=8192): temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
f.write(chunk) temp_file.write(content)
return temp_file.name temp_file.close()
return temp_file.name
loop = asyncio.get_event_loop()
temp_file_path = await loop.run_in_executor(
None,
write_temp_file,
response.content
)
return temp_file_path
except Exception as e:
raise Exception(f"下载图片失败: {e}")
def cleanup(self): async def cleanup(self):
"""清理资源""" """清理资源"""
# 关闭HTTP客户端
await self.http_client.aclose()
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect() gc.collect()