japi/apps/rmbg/service.py
jingrow cecd617104 feat(rmbg): 添加批处理性能统计功能
- 记录下载图片、GPU推理、保存图片各阶段的耗时
- 输出详细的性能统计信息,包括:
  * 图片总数、成功/失败数量
  * 批处理次数和每批图片数
  * 各阶段耗时及占比
  * 总耗时、平均每张耗时、每批平均耗时
- 使用统一的日志格式输出统计信息
2025-11-23 15:48:33 +08:00

837 lines
34 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import tempfile
from urllib.parse import urlparse
from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
import time
import warnings
import gc
import asyncio
import io
import uuid
import httpx
import logging
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Optional, Dict, Any
from settings import settings
# Process-wide logging setup: INFO level with timestamp, logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Hide UserWarning / FutureWarning noise (registration order is kept as-is).
warnings.filterwarnings(action="ignore", category=UserWarning)
warnings.filterwarnings(action="ignore", category=FutureWarning)

# Select the "high" float32 matmul precision mode (see the PyTorch docs for
# the exact speed/precision trade-off of this setting).
torch.set_float32_matmul_precision("high")
@dataclass
class QueueItem:
    """One request waiting in the batching queue."""
    # Image submitted by the caller.
    image: Image.Image
    # ``image.size`` captured when the item was enqueued.
    image_size: tuple
    # Short identifier used in log messages about this request.
    request_id: str
    # Future through which the result (or the failure) is handed back to the
    # coroutine that enqueued the item.
    future: asyncio.Future
    # ``time.time()`` timestamp taken when the item was created.
    created_at: float
class RmbgService:
def __init__(self, model_path="zhengpeng7/BiRefNet"):
    """Prepare storage, HTTP client, worker pool and queue state, then load the model.

    Args:
        model_path: Identifier or path passed to ``from_pretrained`` when the
            model is loaded.
    """
    self.model_path = model_path
    self.model = None
    self.device = None

    # Where results are written, and the URL prefix under which they are served.
    self.save_dir = settings.save_dir
    self.download_url = settings.download_url
    os.makedirs(self.save_dir, exist_ok=True)

    connection_limits = httpx.Limits(
        max_keepalive_connections=50,
        max_connections=100,
    )
    self.http_client = httpx.AsyncClient(timeout=30.0, limits=connection_limits)
    self.executor = ThreadPoolExecutor(max_workers=settings.max_workers)

    # Queue-based request aggregation. Only the state is created here; the
    # consumer task is started separately by ``_start_queue_processor``
    # (the original note says this happens in the FastAPI startup event).
    self.queue: asyncio.Queue = asyncio.Queue()
    self.queue_task: Optional[asyncio.Task] = None
    self.queue_running = False

    self._load_model()
def _load_model(self):
    """Load the segmentation model onto CUDA when available, otherwise CPU.

    On CUDA the model is loaded and converted to FP16. If anything in that
    attempt raises, the model is loaded again with the library's default
    precision. The model is always left in eval mode.
    """
    use_cuda = torch.cuda.is_available()
    self.device = torch.device("cuda" if use_cuda else "cpu")

    if use_cuda:
        # Allocator options are read when the CUDA caching allocator is first
        # initialised, so this must be in place before the model is moved to
        # the GPU. ``setdefault`` keeps any value exported by the operator.
        # NOTE(review): confirm the read-on-first-use behaviour for the
        # PyTorch version deployed.
        os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')

    try:
        # Half precision roughly halves the memory needed for the weights.
        self.model = AutoModelForImageSegmentation.from_pretrained(
            self.model_path,
            trust_remote_code=True,
            torch_dtype=torch.float16 if use_cuda else torch.float32,
        )
        self.model = self.model.to(self.device)
        if use_cuda:
            self.model = self.model.half()
            logger.info("模型已使用半精度FP16加载显存占用减少约50%")
    except Exception as e:
        logger.warning(f"半精度加载失败,使用全精度: {str(e)}")
        # Drop the partially loaded model first so its memory can be reused
        # by the full-precision copy.
        self.model = None
        if use_cuda:
            torch.cuda.empty_cache()
        self.model = AutoModelForImageSegmentation.from_pretrained(
            self.model_path,
            trust_remote_code=True,
        )
        self.model = self.model.to(self.device)

    self.model.eval()

    if use_cuda:
        torch.cuda.empty_cache()
        logger.info(f"模型加载完成,当前显存占用: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
def _process_image_sync(self, image):
    """Remove the background from one image (blocking call).

    Args:
        image: PIL image to process. It is left unmodified.

    Returns:
        A copy of ``image`` whose alpha channel is the predicted mask,
        resized back to the original image size.
    """
    image_size = image.size
    transform_image = transforms.Compose([
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    input_images = transform_image(image).unsqueeze(0).to(self.device)
    # Feed FP16 input when the model weights are FP16.
    if next(self.model.parameters()).dtype == torch.float16:
        input_images = input_images.half()

    with torch.no_grad():
        model_output = self.model(input_images)
        # Same handling as the batch path: take the last element when the
        # model returns a list/tuple of outputs.
        if isinstance(model_output, (list, tuple)):
            model_output = model_output[-1]
        preds = model_output.sigmoid().cpu()

    # Drop device tensors before emptying the cache so they can be released.
    del input_images, model_output

    pred = preds[0].squeeze()
    mask = transforms.ToPILImage()(pred).resize(image_size)

    # Work on a copy, like the batch path, so the caller's image is untouched.
    result_image = image.copy()
    result_image.putalpha(mask)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    return result_image
def _process_batch_images_sync(self, images_with_info):
    """Remove the background from several images in one forward pass (blocking).

    Args:
        images_with_info: Sequence of ``(image, image_size, index)`` tuples.
            ``index`` is opaque to this method and is echoed back.

    Returns:
        List of ``(result_image, index)`` in input order. Each result is a
        copy of the input image with the predicted mask as alpha channel.
        An empty input yields an empty list.
    """
    if not images_with_info:
        return []

    transform_image = transforms.Compose([
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # Release cached blocks before the batch is allocated.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    input_batch = torch.stack(
        [transform_image(image) for image, _, _ in images_with_info]
    ).to(self.device)
    # Feed FP16 input when the model weights are FP16.
    if next(self.model.parameters()).dtype == torch.float16:
        input_batch = input_batch.half()

    with torch.no_grad():
        model_output = self.model(input_batch)
        if isinstance(model_output, (list, tuple)):
            model_output = model_output[-1]
        preds = model_output.sigmoid().cpu()

    # Drop every device-side reference (whatever the output type was) so the
    # cache flush below can actually release the memory.
    del input_batch, model_output
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    to_pil = transforms.ToPILImage()
    results = []
    for i, (image, image_size, index) in enumerate(images_with_info):
        # A 3-D ``preds`` already holds one 2-D mask per item; otherwise
        # squeeze away the singleton dimensions.
        pred = preds[i] if preds.dim() == 3 else preds[i].squeeze()
        mask = to_pil(pred).resize(image_size)
        result_image = image.copy()
        result_image.putalpha(mask)
        results.append((result_image, index))

    del preds
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    return results
async def process_image(self, image):
    """Process one image, through the batching queue when it is active.

    Args:
        image: PIL image to process.

    Returns:
        Whatever the selected path produces. The direct path returns the
        image from ``_process_image_sync``. The queue path returns the value
        the queue processor stored on the request's future (in this file
        that is a ``{"status", "image_url"}`` dict), so callers must be
        prepared for either.
    """
    queue_active = settings.enable_queue_batch and self.queue_running
    if not queue_active:
        # Fallback: run the single-image path in the worker pool.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self.executor, self._process_image_sync, image
        )
    return await self._process_image_via_queue(image)
async def _process_image_via_queue(self, image):
    """Enqueue one image for batched processing and wait for its result.

    Args:
        image: PIL image to process.

    Returns:
        The value the queue processor sets on this request's future.

    Raises:
        Exception: If no result arrives within ``settings.request_timeout``
            seconds, or if the request fails while queued or processed.
    """
    future = asyncio.Future()
    queue_item = QueueItem(
        image=image,
        image_size=image.size,
        request_id=uuid.uuid4().hex[:10],
        future=future,
        created_at=time.time(),
    )
    try:
        await self.queue.put(queue_item)
        try:
            # Bounded wait for the queue processor to resolve the future.
            return await asyncio.wait_for(future, timeout=settings.request_timeout)
        except asyncio.TimeoutError:
            future.cancel()
            raise Exception(f"处理超时(超过{settings.request_timeout}秒)")
    except Exception as e:
        # Ensure the future is never left pending after a failure.
        if not future.done():
            future.set_exception(e)
        raise
async def process_batch_images(self, images_with_info):
    """Run ``_process_batch_images_sync`` in the worker pool and await it.

    Args:
        images_with_info: Forwarded unchanged to the synchronous method.

    Returns:
        The synchronous method's return value.
    """
    return await asyncio.get_running_loop().run_in_executor(
        self.executor, self._process_batch_images_sync, images_with_info
    )
async def _start_queue_processor(self):
    """Start the background queue consumer once.

    Must be awaited inside a running event loop because it creates a task.
    Does nothing when the consumer is already marked as running.
    """
    if not self.queue_running:
        self.queue_running = True
        self.queue_task = asyncio.create_task(self._queue_processor())
        logger.info("队列批处理机制已启动")
async def _queue_processor(self):
    """Background loop: collect a batch of queued requests, then process it.

    Runs while ``self.queue_running`` is true. Any ``Exception`` from one
    iteration is logged and followed by a short pause, so a single bad batch
    does not end the loop.
    """
    logger.info("队列批处理任务开始运行")
    while self.queue_running:
        try:
            batch_items = await self._collect_batch_items()
            # An empty batch means the collection wait timed out; poll again.
            if batch_items:
                await self._process_batch_queue_items(batch_items)
        except Exception as e:
            logger.error(f"队列批处理任务出错: {str(e)}", exc_info=True)
            await asyncio.sleep(0.1)  # brief back-off after a failure
async def _collect_batch_items(self):
    """Collect up to ``settings.batch_size`` queued items for one batch.

    Waits up to ``settings.batch_collect_timeout`` seconds for the first
    item. After that, keeps taking items until the batch is full, the
    collection window (another ``batch_collect_timeout`` seconds) is over,
    or no item arrives within ``settings.batch_collect_interval`` seconds.

    Returns:
        The collected items, or an empty list when nothing arrived in time.
    """
    batch_size = settings.batch_size
    collect_interval = settings.batch_collect_interval
    collect_timeout = settings.batch_collect_timeout

    # Wait for the first request.
    try:
        first_item = await asyncio.wait_for(
            self.queue.get(),
            timeout=collect_timeout,
        )
    except asyncio.TimeoutError:
        return []
    batch_items = [first_item]

    # Use a monotonic deadline so wall-clock adjustments cannot stretch or
    # cut short the collection window.
    deadline = time.monotonic() + collect_timeout
    while len(batch_items) < batch_size:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        try:
            item = await asyncio.wait_for(
                self.queue.get(),
                timeout=min(collect_interval, remaining),
            )
        except asyncio.TimeoutError:
            # Nothing new arrived in time: process what we have.
            break
        batch_items.append(item)
    return batch_items
async def _process_batch_queue_items(self, batch_items):
    """Run one batch of queued requests and resolve each request's future.

    Requests whose future is already finished (for example the caller timed
    out and cancelled it) are skipped, so no inference or file write is
    spent on them. Each remaining future receives either
    ``{"status": "success", "image_url": ...}`` or an exception.

    Args:
        batch_items: Queue items exposing ``image``, ``image_size``,
            ``request_id`` and ``future``.
    """
    pending_items = [item for item in batch_items if not item.future.done()]
    if not pending_items:
        return

    loop = asyncio.get_running_loop()
    try:
        images_with_info = [
            (item.image, item.image_size, idx)
            for idx, item in enumerate(pending_items)
        ]
        batch_results = await self.process_batch_images(images_with_info)

        # Match results to requests by the index echoed back by the batch
        # call rather than by position in the result list.
        for processed_image, idx in batch_results:
            if not 0 <= idx < len(pending_items):
                continue
            item = pending_items[idx]
            if item.future.done():
                # The caller gave up while the batch was running.
                continue
            try:
                image_url = await loop.run_in_executor(
                    self.executor, self.save_image_to_file, processed_image
                )
                result = {
                    "status": "success",
                    "image_url": image_url
                }
                if not item.future.done():
                    item.future.set_result(result)
            except Exception as e:
                error_msg = f"处理图片失败: {str(e)}"
                logger.error(f"队列项 {item.request_id} 处理失败: {error_msg}")
                if not item.future.done():
                    item.future.set_exception(Exception(error_msg))

        # Safety net: a request that got no result must not stay pending.
        for item in pending_items:
            if not item.future.done():
                item.future.set_exception(Exception("批处理结果不完整"))
    except Exception as e:
        error_msg = f"批处理队列项失败: {str(e)}"
        logger.error(error_msg, exc_info=True)
        # The whole batch failed: fail every request that is still waiting.
        for item in pending_items:
            if not item.future.done():
                item.future.set_exception(Exception(error_msg))
def save_image_to_file(self, image):
    """Save an image as PNG under ``self.save_dir`` and return its URL.

    Args:
        image: Object with a PIL-style ``save(path, format=...)`` method.

    Returns:
        ``self.download_url`` joined to the generated file name with exactly
        one ``/``, even when the configured prefix ends with a slash.
    """
    filename = f"rmbg_{uuid.uuid4().hex[:10]}.png"
    file_path = os.path.join(self.save_dir, filename)
    image.save(file_path, format="PNG")
    # Strip trailing slashes from the prefix to avoid "//" in the URL.
    base_url = str(self.download_url).rstrip("/")
    return f"{base_url}/{filename}"
async def remove_background(self, image_path):
    """Remove the background of an image given by local path or URL.

    Args:
        image_path: Local file path or URL of the input image. URLs are
            downloaded to a temporary file that is removed afterwards.

    Returns:
        dict: ``{"status": "success", "image_url": <url of the saved PNG>}``.

    Raises:
        FileNotFoundError: If the (local or downloaded) file does not exist.
        Exception: If the download or the processing fails.
    """
    def load_rgb(path):
        # Close the source file deterministically so the temporary file can
        # be removed afterwards.
        with Image.open(path) as source:
            return source.convert("RGB")

    temp_file = None
    try:
        if self.is_valid_url(image_path):
            try:
                temp_file = await self.download_image(image_path)
            except Exception as e:
                raise Exception(f"下载图片失败: {e}") from e
            image_path = temp_file

        if not os.path.exists(image_path):
            raise FileNotFoundError(f"输入图像不存在: {image_path}")

        loop = asyncio.get_running_loop()
        image = await loop.run_in_executor(self.executor, load_rgb, image_path)
        processed = await self.process_image(image)

        # In queue mode ``process_image`` returns the finished result dict
        # (the image was already saved by the queue processor). Saving it
        # again would fail because a dict has no ``save`` method.
        if isinstance(processed, dict):
            return processed

        image_url = await loop.run_in_executor(
            self.executor, self.save_image_to_file, processed
        )
        return {
            "status": "success",
            "image_url": image_url
        }
    finally:
        if temp_file and os.path.exists(temp_file):
            try:
                os.unlink(temp_file)
            except OSError:
                # Best-effort cleanup; a leftover temp file is not fatal.
                pass
async def remove_background_from_file(self, file_content):
    """Remove the background of an uploaded image.

    Args:
        file_content: Raw bytes of the uploaded image file.

    Returns:
        dict: ``{"status": "success", "image_url": <url of the saved PNG>}``.

    Raises:
        Exception: If decoding, processing or saving fails. The original
            error is attached as ``__cause__``.
    """
    try:
        loop = asyncio.get_running_loop()
        image = await loop.run_in_executor(
            self.executor, lambda: Image.open(io.BytesIO(file_content)).convert("RGB")
        )
        processed = await self.process_image(image)

        # In queue mode ``process_image`` returns the finished result dict
        # (the image was already saved by the queue processor). Saving it
        # again would fail because a dict has no ``save`` method.
        if isinstance(processed, dict):
            return processed

        image_url = await loop.run_in_executor(
            self.executor, self.save_image_to_file, processed
        )
        return {
            "status": "success",
            "image_url": image_url
        }
    except Exception as e:
        raise Exception(f"处理图片失败: {e}") from e
async def process_batch(self, urls):
    """Process many images and stream one result dict per image.

    All inputs are downloaded (or opened, for local paths) concurrently.
    Download failures are reported first. The remaining images go through
    the model in one batch when there are at most ``2 * settings.batch_size``
    of them; otherwise, or after an out-of-memory error in that single
    attempt, they are processed in chunks of ``settings.batch_size``.
    Timing statistics are logged when the generator finishes or is closed.

    Args:
        urls: Sequence of image URLs or local paths.

    Yields:
        dict with ``index`` (1-based position in ``urls``), ``total``,
        ``original_url``, ``status`` ("success" or "error"), ``image_url``
        or ``error``, ``message``, running ``success_count`` /
        ``error_count``, ``completed_order`` and ``batch_elapsed`` (seconds).
    """
    total = len(urls)
    batch_size = settings.batch_size
    loop = asyncio.get_running_loop()
    batch_start_time = time.time()

    # Running counters shared by the helpers below.
    success_count = 0
    error_count = 0
    completed_order = 0

    # Timing statistics.
    download_time = 0.0
    gpu_inference_time = 0.0
    save_time = 0.0
    batch_count = 0
    batch_sizes = []

    def print_stats():
        """Log the per-stage timing statistics."""
        total_time = time.time() - batch_start_time
        other_time = total_time - download_time - gpu_inference_time - save_time
        logger.info("=" * 60)
        logger.info("📊 批处理性能统计")
        logger.info("=" * 60)
        logger.info(f"图片总数: {total}")
        logger.info(f"成功数量: {success_count}")
        logger.info(f"失败数量: {error_count}")
        logger.info(f"批处理次数: {batch_count}")
        logger.info(f"每批图片数: {batch_sizes}")
        logger.info("-" * 60)
        logger.info("⏱️ 各阶段耗时:")
        if total_time > 0:
            download_pct = (download_time / total_time) * 100
            gpu_pct = (gpu_inference_time / total_time) * 100
            save_pct = (save_time / total_time) * 100
            other_pct = (other_time / total_time) * 100
            logger.info(f" 1. 下载图片: {download_time:.3f}s ({download_pct:.1f}%)")
            logger.info(f" 2. GPU推理: {gpu_inference_time:.3f}s ({gpu_pct:.1f}%)")
            logger.info(f" 3. 保存图片: {save_time:.3f}s ({save_pct:.1f}%)")
            logger.info(f" 4. 其他开销: {other_time:.3f}s ({other_pct:.1f}%)")
        else:
            logger.info(f" 1. 下载图片: {download_time:.3f}s")
            logger.info(f" 2. GPU推理: {gpu_inference_time:.3f}s")
            logger.info(f" 3. 保存图片: {save_time:.3f}s")
            logger.info(f" 4. 其他开销: {other_time:.3f}s")
        logger.info("-" * 60)
        logger.info(f"📈 总耗时: {total_time:.3f}s")
        if total > 0:
            avg_per_image = (total_time / total) * 1000
            logger.info(f"📈 平均每张: {avg_per_image:.1f}ms")
        if batch_count > 0:
            avg_batch_time = gpu_inference_time / batch_count
            logger.info(f"📈 每批平均耗时: {avg_batch_time:.3f}s")

    def make_result(index, url_str, message, *, error=None, image_url=None):
        """Update the counters and build the result dict for one image."""
        nonlocal success_count, error_count, completed_order
        completed_order += 1
        result = {
            "index": index,
            "total": total,
            "original_url": url_str,
        }
        if error is None:
            success_count += 1
            result["status"] = "success"
            result["image_url"] = image_url
        else:
            error_count += 1
            result["status"] = "error"
            result["error"] = error
        result["message"] = message
        result["success_count"] = success_count
        result["error_count"] = error_count
        result["completed_order"] = completed_order
        result["batch_elapsed"] = round(time.time() - batch_start_time, 2)
        return result

    def save_outcome_result(index, url_str, outcome):
        """Turn one save outcome (URL or exception) into a result dict."""
        if isinstance(outcome, Exception):
            return make_result(
                index, url_str, f"保存图片失败: {str(outcome)}", error=str(outcome)
            )
        return make_result(index, url_str, "处理成功", image_url=outcome)

    def open_rgb(path):
        # Close the source file deterministically; the converted image is an
        # independent in-memory copy.
        with Image.open(path) as source:
            return source.convert("RGB")

    async def download_image_async(index, url):
        """Fetch one input; return (image, size, index, url, error-or-None)."""
        url_str = str(url)
        try:
            if self.is_valid_url(url_str):
                temp_file = await self.download_image(url_str)
                try:
                    image = await loop.run_in_executor(self.executor, open_rgb, temp_file)
                finally:
                    # Remove the temp file even when decoding failed.
                    try:
                        os.unlink(temp_file)
                    except OSError:
                        pass
            else:
                image = await loop.run_in_executor(self.executor, open_rgb, url_str)
            return (image, image.size, index, url_str, None)
        except Exception as e:
            return (None, None, index, url_str, str(e))

    async def run_chunk(chunk):
        """Infer and save one chunk; return [(index, url, url-or-exception)]."""
        nonlocal gpu_inference_time, save_time, batch_count
        images_with_info = [(img, size, idx) for img, size, idx, _ in chunk]
        url_by_index = {idx: url for _, _, idx, url in chunk}

        gpu_start_time = time.time()
        batch_results = await self.process_batch_images(images_with_info)
        gpu_inference_time += time.time() - gpu_start_time
        batch_count += 1
        batch_sizes.append(len(images_with_info))

        # Save all images of the chunk in parallel in the worker pool.
        indices = []
        save_futures = []
        for processed_image, index in batch_results:
            url_by_index[index]  # fail fast on an unknown index
            indices.append(index)
            save_futures.append(
                loop.run_in_executor(
                    self.executor, self.save_image_to_file, processed_image
                )
            )
        save_start_time = time.time()
        outcomes = await asyncio.gather(*save_futures, return_exceptions=True)
        save_time += time.time() - save_start_time
        return [
            (index, url_by_index[index], outcome)
            for index, outcome in zip(indices, outcomes)
        ]

    try:
        download_start_time = time.time()
        downloaded_images = await asyncio.gather(
            *(download_image_async(i, url) for i, url in enumerate(urls, 1))
        )
        download_time = time.time() - download_start_time

        # Report download failures first, keep the rest for inference.
        # ``is not None`` (not truthiness) so that an exception with an empty
        # message is still treated as a failure.
        valid_images = []
        for image, image_size, index, url_str, error in downloaded_images:
            if error is not None:
                yield make_result(index, url_str, f"下载失败: {error}", error=error)
            else:
                valid_images.append((image, image_size, index, url_str))

        # Small jobs (up to 2x batch_size) are tried in a single pass; an
        # out-of-memory error falls back to chunked processing below.
        max_single_batch = batch_size * 2
        if 0 < len(valid_images) <= max_single_batch:
            try:
                outcomes = await run_chunk(valid_images)
            except Exception as e:
                error_msg = str(e)
                is_oom = isinstance(e, RuntimeError) and "out of memory" in error_msg.lower()
                if not is_oom:
                    logger.error(f"批处理失败: {error_msg}")
                    for _, _, index, url_str in valid_images:
                        yield make_result(
                            index, url_str, f"批处理失败: {error_msg}", error=error_msg
                        )
                    return
                logger.warning(f"一次性处理显存不足,自动降级到分批处理: {error_msg[:100]}")
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                gc.collect()
            else:
                for index, url_str, outcome in outcomes:
                    yield save_outcome_result(index, url_str, outcome)
                return

        # Chunked processing: chunks run one after another, saves inside a
        # chunk run in parallel. A failing chunk only fails its own images.
        for batch_start in range(0, len(valid_images), batch_size):
            batch_images = valid_images[batch_start:batch_start + batch_size]
            try:
                outcomes = await run_chunk(batch_images)
            except Exception as e:
                error_msg = str(e)
                logger.error(f"批处理失败: {error_msg}")
                for _, _, index, url_str in batch_images:
                    yield make_result(
                        index, url_str, f"批处理失败: {error_msg}", error=error_msg
                    )
                continue
            for index, url_str, outcome in outcomes:
                yield save_outcome_result(index, url_str, outcome)
    finally:
        # Also runs when the consumer stops iterating early.
        print_stats()
def is_valid_url(self, url):
    """Return True when ``url`` parses with both a scheme and a network location.

    Args:
        url: Candidate string. Plain file paths return False.

    Returns:
        bool: False as well for values ``urlparse`` cannot handle.
    """
    try:
        parsed = urlparse(url)
    except (ValueError, TypeError, AttributeError):
        # Malformed URL (e.g. broken IPv6 literal) or a non-string argument.
        return False
    return bool(parsed.scheme and parsed.netloc)
async def download_image(self, url):
    """Download ``url`` into a new temporary ``.png`` file.

    The caller owns the returned file and must delete it.

    Args:
        url: Address fetched with ``self.http_client``.

    Returns:
        str: Path of the temporary file holding the response body.

    Raises:
        Exception: On any request, HTTP-status or write failure. The
            original error is attached as ``__cause__``.
    """
    # NOTE(review): the URL is fetched as given and the whole body is held
    # in memory; if callers pass untrusted URLs, consider host allow-listing
    # and a size limit.
    def write_temp_file(content):
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        try:
            # The context manager closes the handle even if the write fails.
            with temp_file:
                temp_file.write(content)
        except Exception:
            # Do not leave a partial file behind.
            try:
                os.unlink(temp_file.name)
            except OSError:
                pass
            raise
        return temp_file.name

    try:
        response = await self.http_client.get(url)
        response.raise_for_status()
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self.executor, write_temp_file, response.content
        )
    except Exception as e:
        raise Exception(f"下载图片失败: {e}") from e
async def cleanup(self):
    """Stop the queue consumer and release the service's resources.

    Cancels the background task when it is running, fails every request
    still waiting in the queue, closes the HTTP client, shuts the worker
    pool down (waiting for running jobs) and frees cached GPU memory.
    """
    if self.queue_running:
        self.queue_running = False
        if self.queue_task:
            self.queue_task.cancel()
            try:
                await self.queue_task
            except asyncio.CancelledError:
                pass
        logger.info("队列批处理机制已停止")

    # Drain whatever is still queued.
    abandoned = []
    while True:
        try:
            abandoned.append(self.queue.get_nowait())
        except asyncio.QueueEmpty:
            break

    # Fail the drained requests so their callers stop waiting.
    for entry in abandoned:
        if entry.future.done():
            continue
        entry.future.set_exception(Exception("服务关闭,请求被取消"))

    await self.http_client.aclose()
    self.executor.shutdown(wait=True)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()