重构rmbg支持异步并发任务

This commit is contained in:
jingrow 2025-11-23 03:23:44 +08:00
parent c99e20ff61
commit 474ce6f5db
10 changed files with 97 additions and 50 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 905 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 761 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 732 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 766 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 696 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 845 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 753 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 745 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 914 KiB

View File

@ -1,5 +1,4 @@
import os
import requests
import tempfile
from urllib.parse import urlparse
from PIL import Image
@ -12,9 +11,8 @@ import gc
import base64
import asyncio
import io
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor
import uuid
import httpx
from settings import settings
# 关闭不必要的警告
@ -34,6 +32,8 @@ class RmbgService:
self.download_url = settings.download_url
# 确保保存目录存在
os.makedirs(self.save_dir, exist_ok=True)
# 创建异步HTTP客户端复用连接提高性能
self.http_client = httpx.AsyncClient(timeout=30.0, limits=httpx.Limits(max_keepalive_connections=20))
self._load_model()
def _load_model(self):
@ -45,11 +45,10 @@ class RmbgService:
self.model = self.model.to(self.device)
self.model.eval() # 设置为评估模式
def process_image(self, image):
"""处理图像,移除背景"""
def _process_image_sync(self, image):
"""同步处理图像,移除背景(内部方法,在线程池中执行)"""
image_size = image.size
# 转换图像
t0 = time.time()
transform_image = transforms.Compose([
transforms.Resize((1024, 1024)),
transforms.ToTensor(),
@ -58,12 +57,10 @@ class RmbgService:
input_images = transform_image(image).unsqueeze(0).to(self.device)
# 推理
t0 = time.time()
with torch.no_grad():
preds = self.model(input_images)[-1].sigmoid().cpu()
# 处理预测结果
t0 = time.time()
pred = preds[0].squeeze()
pred_pil = transforms.ToPILImage()(pred)
mask = pred_pil.resize(image_size)
@ -77,6 +74,12 @@ class RmbgService:
gc.collect()
return image
async def process_image(self, image):
"""异步处理图像,移除背景(在线程池中执行同步操作)"""
# 将同步的GPU操作放到线程池中执行避免阻塞事件循环
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self._process_image_sync, image)
def image_to_base64(self, image):
"""将PIL Image对象转换为base64字符串"""
@ -120,8 +123,8 @@ class RmbgService:
# 检查是否是URL
if self.is_valid_url(image_path):
try:
# 下载图片到临时文件
temp_file = self.download_image(image_path)
# 异步下载图片到临时文件
temp_file = await self.download_image(image_path)
image_path = temp_file
except Exception as e:
raise Exception(f"下载图片失败: {e}")
@ -130,12 +133,22 @@ class RmbgService:
if not os.path.exists(image_path):
raise FileNotFoundError(f"输入图像不存在: {image_path}")
# 加载并处理图像
image = Image.open(image_path).convert("RGB")
image_no_bg = self.process_image(image)
# 加载图像IO操作在线程池中执行
loop = asyncio.get_event_loop()
image = await loop.run_in_executor(
None,
lambda: Image.open(image_path).convert("RGB")
)
# 保存图片到文件并获取URL
image_url = self.save_image_to_file(image_no_bg)
# 异步处理图像
image_no_bg = await self.process_image(image)
# 保存图片到文件并获取URLIO操作在线程池中执行
image_url = await loop.run_in_executor(
None,
self.save_image_to_file,
image_no_bg
)
return {
"status": "success",
@ -161,12 +174,22 @@ class RmbgService:
处理后的图像内容
"""
try:
# 从文件内容创建PIL Image对象
image = Image.open(io.BytesIO(file_content)).convert("RGB")
image_no_bg = self.process_image(image)
# 从文件内容创建PIL Image对象IO操作在线程池中执行
loop = asyncio.get_event_loop()
image = await loop.run_in_executor(
None,
lambda: Image.open(io.BytesIO(file_content)).convert("RGB")
)
# 保存图片到文件并获取URL
image_url = self.save_image_to_file(image_no_bg)
# 异步处理图像
image_no_bg = await self.process_image(image)
# 保存图片到文件并获取URLIO操作在线程池中执行
image_url = await loop.run_in_executor(
None,
self.save_image_to_file,
image_no_bg
)
return {
"status": "success",
@ -178,51 +201,61 @@ class RmbgService:
async def process_batch(self, urls):
"""
批量处理多个URL图像流式返回结果
批量处理多个URL图像并发处理并流式返回结果
Args:
urls: 图片URL列表
Yields:
每个图片的处理结果
每个图片的处理结果按完成顺序返回
"""
total = len(urls)
success_count = 0
error_count = 0
for i, url in enumerate(urls, 1):
# 创建并发任务
async def process_single_url(index, url):
"""处理单个URL的包装函数"""
try:
url_str = str(url)
result = await self.remove_background(url_str)
success_count += 1
# 确保返回正确的数据格式
yield {
"index": i,
return {
"index": index,
"total": total,
"original_url": url_str,
"status": "success",
"image_url": result["image_url"],
"success_count": success_count,
"error_count": error_count,
"message": "处理成功"
}
except Exception as e:
error_count += 1
yield {
"index": i,
return {
"index": index,
"total": total,
"original_url": str(url),
"status": "error",
"error": str(e),
"success_count": success_count,
"error_count": error_count,
"message": f"处理失败: {str(e)}"
}
# 创建所有任务
tasks = [
process_single_url(i, url)
for i, url in enumerate(urls, 1)
]
# 并发执行所有任务使用as_completed按完成顺序返回
for coro in asyncio.as_completed(tasks):
result = await coro
if result["status"] == "success":
success_count += 1
else:
error_count += 1
# 让出控制权,避免阻塞
await asyncio.sleep(0)
# 更新统计信息
result["success_count"] = success_count
result["error_count"] = error_count
yield result
def is_valid_url(self, url):
"""验证URL是否有效"""
@ -232,20 +265,34 @@ class RmbgService:
except:
return False
def download_image(self, url):
"""从URL下载图片到临时文件"""
response = requests.get(url, stream=True)
response.raise_for_status()
# 创建临时文件
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
with open(temp_file.name, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
return temp_file.name
async def download_image(self, url):
"""异步从URL下载图片到临时文件"""
try:
response = await self.http_client.get(url)
response.raise_for_status()
# 创建临时文件并写入内容IO操作在线程池中执行
def write_temp_file(content):
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
temp_file.write(content)
temp_file.close()
return temp_file.name
loop = asyncio.get_event_loop()
temp_file_path = await loop.run_in_executor(
None,
write_temp_file,
response.content
)
return temp_file_path
except Exception as e:
raise Exception(f"下载图片失败: {e}")
def cleanup(self):
async def cleanup(self):
"""清理资源"""
# 关闭HTTP客户端
await self.http_client.aclose()
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()