优化rmbg并发逻辑,实测并发生效

This commit is contained in:
jingrow 2025-11-23 04:58:58 +08:00
parent 474ce6f5db
commit 10fb6084f5
2 changed files with 164 additions and 72 deletions

View File

@ -13,13 +13,18 @@ import asyncio
import io
import uuid
import httpx
import logging
from concurrent.futures import ThreadPoolExecutor
from settings import settings
# 关闭不必要的警告
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
# 设置torch精度
torch.set_float32_matmul_precision("high")
class RmbgService:
@ -30,25 +35,43 @@ class RmbgService:
self.device = None
self.save_dir = settings.save_dir
self.download_url = settings.download_url
# 确保保存目录存在
os.makedirs(self.save_dir, exist_ok=True)
# 创建异步HTTP客户端复用连接提高性能
self.http_client = httpx.AsyncClient(timeout=30.0, limits=httpx.Limits(max_keepalive_connections=20))
self.http_client = httpx.AsyncClient(
timeout=30.0,
limits=httpx.Limits(
max_keepalive_connections=50,
max_connections=100
)
)
self.executor = ThreadPoolExecutor(max_workers=settings.max_workers)
self._gpu_semaphore = None
self._max_gpu_concurrent = settings.max_gpu_concurrent
self._load_model()
@property
def gpu_semaphore(self):
"""延迟初始化GPU信号量"""
if self._gpu_semaphore is None:
if self._max_gpu_concurrent == 0:
return None
try:
loop = asyncio.get_event_loop()
self._gpu_semaphore = asyncio.Semaphore(self._max_gpu_concurrent)
except RuntimeError:
self._gpu_semaphore = asyncio.Semaphore(self._max_gpu_concurrent)
return self._gpu_semaphore
def _load_model(self):
"""加载模型"""
# 设置设备
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
t0 = time.time()
self.model = AutoModelForImageSegmentation.from_pretrained(self.model_path, trust_remote_code=True)
self.model = self.model.to(self.device)
self.model.eval() # 设置为评估模式
self.model.eval()
def _process_image_sync(self, image):
"""同步处理图像,移除背景(内部方法,在线程池中执行)"""
"""同步处理图像,移除背景"""
image_size = image.size
# 转换图像
transform_image = transforms.Compose([
transforms.Resize((1024, 1024)),
transforms.ToTensor(),
@ -56,30 +79,103 @@ class RmbgService:
])
input_images = transform_image(image).unsqueeze(0).to(self.device)
# 推理
with torch.no_grad():
preds = self.model(input_images)[-1].sigmoid().cpu()
# 处理预测结果
pred = preds[0].squeeze()
pred_pil = transforms.ToPILImage()(pred)
mask = pred_pil.resize(image_size)
# 添加透明通道
image.putalpha(mask)
# 清理显存
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
return image
def _process_batch_images_sync(self, images_with_info):
"""
批量处理图像充分利用GPU并行能力
Args:
images_with_info: [(image, image_size, index), ...] 图像和元信息列表
Returns:
[(processed_image, index), ...] 处理后的图像和索引
"""
if not images_with_info:
return []
try:
transform_image = transforms.Compose([
transforms.Resize((1024, 1024)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
batch_images = []
for image, image_size, index in images_with_info:
try:
transformed = transform_image(image)
batch_images.append(transformed)
except Exception as e:
logger.error(f"图片转换失败 (index={index}): {str(e)}")
raise
if not batch_images:
raise Exception("没有有效的图片可以处理")
input_batch = torch.stack(batch_images).to(self.device)
with torch.no_grad():
model_output = self.model(input_batch)
if isinstance(model_output, (list, tuple)):
preds = model_output[-1].sigmoid().cpu()
else:
preds = model_output.sigmoid().cpu()
results = []
for i, (image, image_size, index) in enumerate(images_with_info):
try:
if len(preds.shape) == 4:
pred = preds[i].squeeze()
elif len(preds.shape) == 3:
pred = preds[i]
else:
pred = preds[i].squeeze()
pred_pil = transforms.ToPILImage()(pred)
mask = pred_pil.resize(image_size)
result_image = image.copy()
result_image.putalpha(mask)
results.append((result_image, index))
except Exception as e:
logger.error(f"处理预测结果失败 (index={index}): {str(e)}")
raise
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
return results
except Exception as e:
logger.error(f"批处理失败: {str(e)}")
raise
async def process_image(self, image):
"""异步处理图像,移除背景(在线程池中执行同步操作)"""
# 将同步的GPU操作放到线程池中执行避免阻塞事件循环
"""异步处理图像,移除背景"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, self._process_image_sync, image)
return await loop.run_in_executor(self.executor, self._process_image_sync, image)
async def process_batch_images(self, images_with_info):
"""异步批量处理图像"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor,
self._process_batch_images_sync,
images_with_info
)
def image_to_base64(self, image):
"""将PIL Image对象转换为base64字符串"""
@ -88,23 +184,10 @@ class RmbgService:
return base64.b64encode(buffered.getvalue()).decode('utf-8')
def save_image_to_file(self, image):
"""
保存图片到jfile/files目录并返回URL
Args:
image: PIL Image对象
Returns:
图片URL
"""
# 生成唯一文件名
"""保存图片到文件并返回URL"""
filename = f"rmbg_{uuid.uuid4().hex[:10]}.png"
file_path = os.path.join(self.save_dir, filename)
# 保存图片
image.save(file_path, format="PNG")
# 构建URL
image_url = f"{self.download_url}/{filename}"
return image_url
@ -120,32 +203,26 @@ class RmbgService:
"""
temp_file = None
try:
# 检查是否是URL
if self.is_valid_url(image_path):
try:
# 异步下载图片到临时文件
temp_file = await self.download_image(image_path)
image_path = temp_file
except Exception as e:
raise Exception(f"下载图片失败: {e}")
# 验证输入文件是否存在
if not os.path.exists(image_path):
raise FileNotFoundError(f"输入图像不存在: {image_path}")
# 加载图像IO操作在线程池中执行
loop = asyncio.get_event_loop()
image = await loop.run_in_executor(
None,
self.executor,
lambda: Image.open(image_path).convert("RGB")
)
# 异步处理图像
image_no_bg = await self.process_image(image)
# 保存图片到文件并获取URLIO操作在线程池中执行
image_url = await loop.run_in_executor(
None,
self.executor,
self.save_image_to_file,
image_no_bg
)
@ -156,7 +233,6 @@ class RmbgService:
}
finally:
# 清理临时文件
if temp_file and os.path.exists(temp_file):
try:
os.unlink(temp_file)
@ -174,19 +250,16 @@ class RmbgService:
处理后的图像内容
"""
try:
# 从文件内容创建PIL Image对象IO操作在线程池中执行
loop = asyncio.get_event_loop()
image = await loop.run_in_executor(
None,
self.executor,
lambda: Image.open(io.BytesIO(file_content)).convert("RGB")
)
# 异步处理图像
image_no_bg = await self.process_image(image)
# 保存图片到文件并获取URLIO操作在线程池中执行
image_url = await loop.run_in_executor(
None,
self.executor,
self.save_image_to_file,
image_no_bg
)
@ -200,60 +273,77 @@ class RmbgService:
raise Exception(f"处理图片失败: {e}")
async def process_batch(self, urls):
"""
批量处理多个URL图像并发处理并流式返回结果
Args:
urls: 图片URL列表
Yields:
每个图片的处理结果按完成顺序返回
"""
"""批量处理多个URL图像流水线并发模式"""
total = len(urls)
success_count = 0
error_count = 0
batch_start_time = time.time()
# 创建并发任务
async def process_single_url(index, url):
"""处理单个URL的包装函数"""
async def download_and_process(index, url):
"""下载并处理单张图片"""
url_str = str(url)
try:
url_str = str(url)
result = await self.remove_background(url_str)
if self.is_valid_url(url_str):
temp_file = await self.download_image(url_str)
image = await asyncio.get_event_loop().run_in_executor(
self.executor,
lambda: Image.open(temp_file).convert("RGB")
)
os.unlink(temp_file)
else:
image = await asyncio.get_event_loop().run_in_executor(
self.executor,
lambda: Image.open(url_str).convert("RGB")
)
processed_image = await self.process_image(image)
loop = asyncio.get_event_loop()
image_url = await loop.run_in_executor(
self.executor,
self.save_image_to_file,
processed_image
)
return {
"index": index,
"total": total,
"original_url": url_str,
"status": "success",
"image_url": result["image_url"],
"image_url": image_url,
"message": "处理成功"
}
except Exception as e:
logger.error(f"处理失败 (index={index}): {str(e)}")
return {
"index": index,
"total": total,
"original_url": str(url),
"original_url": url_str,
"status": "error",
"error": str(e),
"message": f"处理失败: {str(e)}"
}
# 创建所有任务
tasks = [
process_single_url(i, url)
download_and_process(i, url)
for i, url in enumerate(urls, 1)
]
# 并发执行所有任务使用as_completed按完成顺序返回
completed_order = 0
for coro in asyncio.as_completed(tasks):
result = await coro
completed_order += 1
if result["status"] == "success":
success_count += 1
else:
error_count += 1
# 更新统计信息
result["success_count"] = success_count
result["error_count"] = error_count
result["completed_order"] = completed_order
result["batch_elapsed"] = round(time.time() - batch_start_time, 2)
yield result
@ -271,7 +361,6 @@ class RmbgService:
response = await self.http_client.get(url)
response.raise_for_status()
# 创建临时文件并写入内容IO操作在线程池中执行
def write_temp_file(content):
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
temp_file.write(content)
@ -280,7 +369,7 @@ class RmbgService:
loop = asyncio.get_event_loop()
temp_file_path = await loop.run_in_executor(
None,
self.executor,
write_temp_file,
response.content
)
@ -291,9 +380,8 @@ class RmbgService:
async def cleanup(self):
"""清理资源"""
# 关闭HTTP客户端
await self.http_client.aclose()
self.executor.shutdown(wait=True)
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
print("资源已清理")
gc.collect()

View File

@ -26,6 +26,10 @@ class Settings(BaseSettings):
jingrow_api_key: Optional[str] = None
jingrow_api_secret: Optional[str] = None
# 并发控制配置
max_workers: int = 30 # 线程池最大工作线程数根据CPU核心数调整22核44线程可设置20-30
max_gpu_concurrent: int = 0 # GPU最大并发数0表示不限制根据显存大小设置24GB显存建议10-15
class Config:
env_file = ".env"