重构rmbg支持异步并发任务
|
Before Width: | Height: | Size: 905 KiB |
|
Before Width: | Height: | Size: 761 KiB |
|
Before Width: | Height: | Size: 732 KiB |
|
Before Width: | Height: | Size: 766 KiB |
|
Before Width: | Height: | Size: 696 KiB |
|
Before Width: | Height: | Size: 845 KiB |
|
Before Width: | Height: | Size: 753 KiB |
|
Before Width: | Height: | Size: 745 KiB |
|
Before Width: | Height: | Size: 914 KiB |
@ -1,5 +1,4 @@
|
|||||||
import os
|
import os
|
||||||
import requests
|
|
||||||
import tempfile
|
import tempfile
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
@ -12,9 +11,8 @@ import gc
|
|||||||
import base64
|
import base64
|
||||||
import asyncio
|
import asyncio
|
||||||
import io
|
import io
|
||||||
import multiprocessing as mp
|
|
||||||
from concurrent.futures import ProcessPoolExecutor
|
|
||||||
import uuid
|
import uuid
|
||||||
|
import httpx
|
||||||
from settings import settings
|
from settings import settings
|
||||||
|
|
||||||
# 关闭不必要的警告
|
# 关闭不必要的警告
|
||||||
@ -34,6 +32,8 @@ class RmbgService:
|
|||||||
self.download_url = settings.download_url
|
self.download_url = settings.download_url
|
||||||
# 确保保存目录存在
|
# 确保保存目录存在
|
||||||
os.makedirs(self.save_dir, exist_ok=True)
|
os.makedirs(self.save_dir, exist_ok=True)
|
||||||
|
# 创建异步HTTP客户端(复用连接,提高性能)
|
||||||
|
self.http_client = httpx.AsyncClient(timeout=30.0, limits=httpx.Limits(max_keepalive_connections=20))
|
||||||
self._load_model()
|
self._load_model()
|
||||||
|
|
||||||
def _load_model(self):
|
def _load_model(self):
|
||||||
@ -45,11 +45,10 @@ class RmbgService:
|
|||||||
self.model = self.model.to(self.device)
|
self.model = self.model.to(self.device)
|
||||||
self.model.eval() # 设置为评估模式
|
self.model.eval() # 设置为评估模式
|
||||||
|
|
||||||
def process_image(self, image):
|
def _process_image_sync(self, image):
|
||||||
"""处理图像,移除背景"""
|
"""同步处理图像,移除背景(内部方法,在线程池中执行)"""
|
||||||
image_size = image.size
|
image_size = image.size
|
||||||
# 转换图像
|
# 转换图像
|
||||||
t0 = time.time()
|
|
||||||
transform_image = transforms.Compose([
|
transform_image = transforms.Compose([
|
||||||
transforms.Resize((1024, 1024)),
|
transforms.Resize((1024, 1024)),
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
@ -58,12 +57,10 @@ class RmbgService:
|
|||||||
input_images = transform_image(image).unsqueeze(0).to(self.device)
|
input_images = transform_image(image).unsqueeze(0).to(self.device)
|
||||||
|
|
||||||
# 推理
|
# 推理
|
||||||
t0 = time.time()
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
preds = self.model(input_images)[-1].sigmoid().cpu()
|
preds = self.model(input_images)[-1].sigmoid().cpu()
|
||||||
|
|
||||||
# 处理预测结果
|
# 处理预测结果
|
||||||
t0 = time.time()
|
|
||||||
pred = preds[0].squeeze()
|
pred = preds[0].squeeze()
|
||||||
pred_pil = transforms.ToPILImage()(pred)
|
pred_pil = transforms.ToPILImage()(pred)
|
||||||
mask = pred_pil.resize(image_size)
|
mask = pred_pil.resize(image_size)
|
||||||
@ -77,6 +74,12 @@ class RmbgService:
|
|||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
async def process_image(self, image):
|
||||||
|
"""异步处理图像,移除背景(在线程池中执行同步操作)"""
|
||||||
|
# 将同步的GPU操作放到线程池中执行,避免阻塞事件循环
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
return await loop.run_in_executor(None, self._process_image_sync, image)
|
||||||
|
|
||||||
def image_to_base64(self, image):
|
def image_to_base64(self, image):
|
||||||
"""将PIL Image对象转换为base64字符串"""
|
"""将PIL Image对象转换为base64字符串"""
|
||||||
@ -120,8 +123,8 @@ class RmbgService:
|
|||||||
# 检查是否是URL
|
# 检查是否是URL
|
||||||
if self.is_valid_url(image_path):
|
if self.is_valid_url(image_path):
|
||||||
try:
|
try:
|
||||||
# 下载图片到临时文件
|
# 异步下载图片到临时文件
|
||||||
temp_file = self.download_image(image_path)
|
temp_file = await self.download_image(image_path)
|
||||||
image_path = temp_file
|
image_path = temp_file
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f"下载图片失败: {e}")
|
raise Exception(f"下载图片失败: {e}")
|
||||||
@ -130,12 +133,22 @@ class RmbgService:
|
|||||||
if not os.path.exists(image_path):
|
if not os.path.exists(image_path):
|
||||||
raise FileNotFoundError(f"输入图像不存在: {image_path}")
|
raise FileNotFoundError(f"输入图像不存在: {image_path}")
|
||||||
|
|
||||||
# 加载并处理图像
|
# 加载图像(IO操作,在线程池中执行)
|
||||||
image = Image.open(image_path).convert("RGB")
|
loop = asyncio.get_event_loop()
|
||||||
image_no_bg = self.process_image(image)
|
image = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: Image.open(image_path).convert("RGB")
|
||||||
|
)
|
||||||
|
|
||||||
# 保存图片到文件并获取URL
|
# 异步处理图像
|
||||||
image_url = self.save_image_to_file(image_no_bg)
|
image_no_bg = await self.process_image(image)
|
||||||
|
|
||||||
|
# 保存图片到文件并获取URL(IO操作,在线程池中执行)
|
||||||
|
image_url = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
self.save_image_to_file,
|
||||||
|
image_no_bg
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
@ -161,12 +174,22 @@ class RmbgService:
|
|||||||
处理后的图像内容
|
处理后的图像内容
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 从文件内容创建PIL Image对象
|
# 从文件内容创建PIL Image对象(IO操作,在线程池中执行)
|
||||||
image = Image.open(io.BytesIO(file_content)).convert("RGB")
|
loop = asyncio.get_event_loop()
|
||||||
image_no_bg = self.process_image(image)
|
image = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: Image.open(io.BytesIO(file_content)).convert("RGB")
|
||||||
|
)
|
||||||
|
|
||||||
# 保存图片到文件并获取URL
|
# 异步处理图像
|
||||||
image_url = self.save_image_to_file(image_no_bg)
|
image_no_bg = await self.process_image(image)
|
||||||
|
|
||||||
|
# 保存图片到文件并获取URL(IO操作,在线程池中执行)
|
||||||
|
image_url = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
self.save_image_to_file,
|
||||||
|
image_no_bg
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "success",
|
"status": "success",
|
||||||
@ -178,51 +201,61 @@ class RmbgService:
|
|||||||
|
|
||||||
async def process_batch(self, urls):
|
async def process_batch(self, urls):
|
||||||
"""
|
"""
|
||||||
批量处理多个URL图像,流式返回结果
|
批量处理多个URL图像,并发处理并流式返回结果
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
urls: 图片URL列表
|
urls: 图片URL列表
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
每个图片的处理结果
|
每个图片的处理结果(按完成顺序返回)
|
||||||
"""
|
"""
|
||||||
total = len(urls)
|
total = len(urls)
|
||||||
success_count = 0
|
success_count = 0
|
||||||
error_count = 0
|
error_count = 0
|
||||||
|
|
||||||
for i, url in enumerate(urls, 1):
|
# 创建并发任务
|
||||||
|
async def process_single_url(index, url):
|
||||||
|
"""处理单个URL的包装函数"""
|
||||||
try:
|
try:
|
||||||
url_str = str(url)
|
url_str = str(url)
|
||||||
result = await self.remove_background(url_str)
|
result = await self.remove_background(url_str)
|
||||||
success_count += 1
|
return {
|
||||||
|
"index": index,
|
||||||
# 确保返回正确的数据格式
|
|
||||||
yield {
|
|
||||||
"index": i,
|
|
||||||
"total": total,
|
"total": total,
|
||||||
"original_url": url_str,
|
"original_url": url_str,
|
||||||
"status": "success",
|
"status": "success",
|
||||||
"image_url": result["image_url"],
|
"image_url": result["image_url"],
|
||||||
"success_count": success_count,
|
|
||||||
"error_count": error_count,
|
|
||||||
"message": "处理成功"
|
"message": "处理成功"
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_count += 1
|
return {
|
||||||
yield {
|
"index": index,
|
||||||
"index": i,
|
|
||||||
"total": total,
|
"total": total,
|
||||||
"original_url": str(url),
|
"original_url": str(url),
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
"success_count": success_count,
|
|
||||||
"error_count": error_count,
|
|
||||||
"message": f"处理失败: {str(e)}"
|
"message": f"处理失败: {str(e)}"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# 创建所有任务
|
||||||
|
tasks = [
|
||||||
|
process_single_url(i, url)
|
||||||
|
for i, url in enumerate(urls, 1)
|
||||||
|
]
|
||||||
|
|
||||||
|
# 并发执行所有任务,使用as_completed按完成顺序返回
|
||||||
|
for coro in asyncio.as_completed(tasks):
|
||||||
|
result = await coro
|
||||||
|
if result["status"] == "success":
|
||||||
|
success_count += 1
|
||||||
|
else:
|
||||||
|
error_count += 1
|
||||||
|
|
||||||
# 让出控制权,避免阻塞
|
# 更新统计信息
|
||||||
await asyncio.sleep(0)
|
result["success_count"] = success_count
|
||||||
|
result["error_count"] = error_count
|
||||||
|
|
||||||
|
yield result
|
||||||
|
|
||||||
def is_valid_url(self, url):
|
def is_valid_url(self, url):
|
||||||
"""验证URL是否有效"""
|
"""验证URL是否有效"""
|
||||||
@ -232,20 +265,34 @@ class RmbgService:
|
|||||||
except:
|
except:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def download_image(self, url):
|
async def download_image(self, url):
|
||||||
"""从URL下载图片到临时文件"""
|
"""异步从URL下载图片到临时文件"""
|
||||||
response = requests.get(url, stream=True)
|
try:
|
||||||
response.raise_for_status()
|
response = await self.http_client.get(url)
|
||||||
|
response.raise_for_status()
|
||||||
# 创建临时文件
|
|
||||||
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
|
# 创建临时文件并写入内容(IO操作,在线程池中执行)
|
||||||
with open(temp_file.name, 'wb') as f:
|
def write_temp_file(content):
|
||||||
for chunk in response.iter_content(chunk_size=8192):
|
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
|
||||||
f.write(chunk)
|
temp_file.write(content)
|
||||||
return temp_file.name
|
temp_file.close()
|
||||||
|
return temp_file.name
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
temp_file_path = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
write_temp_file,
|
||||||
|
response.content
|
||||||
|
)
|
||||||
|
|
||||||
|
return temp_file_path
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"下载图片失败: {e}")
|
||||||
|
|
||||||
def cleanup(self):
|
async def cleanup(self):
|
||||||
"""清理资源"""
|
"""清理资源"""
|
||||||
|
# 关闭HTTP客户端
|
||||||
|
await self.http_client.aclose()
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|||||||