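# japi/apps/rmbg/service.py
# Last commit: jingrow 565115369d, 2025-12-16 08:48:25 +00:00
#   优化批处理为全局队列合批以充分利用多 GPU
#   (Optimize batch processing by merging work through the global queue so that
#   multiple GPUs are fully used.)
#   - 重构 process_batch,将批处理入口改为逐张通过全局队列的 process_image
#     (Refactor process_batch so each image goes through process_image and the
#     global queue.)
#   - 在本地批次内并发调用 process_image,让全局队列能凑大 batch 并触发多 GPU 并行
#     (Call process_image concurrently within each local batch so the global queue
#     can form large batches and run them on several GPUs in parallel.)
#   - 保留原有流式返回结构和统计字段,对外接口兼容不变
#     (Keep the existing streaming result structure and statistics fields; the
#     public interface stays backward compatible.)
# Source listing: 805 lines, 31 KiB, Python. The repository viewer warned that the
# file contains ambiguous Unicode characters that may be confused with others.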

import os
import tempfile
from urllib.parse import urlparse
from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
import time
import warnings
import gc
import asyncio
import io
import uuid
import httpx
import logging
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Optional, Dict, Any
from threading import Lock
from settings import settings
# Process-wide logging: INFO level, timestamped "name - level - message" lines.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)
logger = logging.getLogger(__name__)

# Suppress every UserWarning and FutureWarning for the whole process
# (filters are installed in the same order as before).
for _ignored_category in (UserWarning, FutureWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)
del _ignored_category

# Let PyTorch trade some float32 matmul precision for speed
# (see torch.set_float32_matmul_precision for the exact semantics of "high").
torch.set_float32_matmul_precision("high")
@dataclass
class QueueItem:
    """A single request waiting in the batching queue.

    Attributes (documented as comments below):
        image: decoded input image to process.
        image_size: ``image.size`` captured at enqueue time; the predicted mask
            is resized back to this size.
        request_id: short random id used in log messages.
        future: resolved by the queue processor with a result dict or an
            exception.
        created_at: enqueue time as returned by ``time.time()``.
    """

    image: "Image.Image"          # PIL image; string annotation avoids evaluating PIL here
    image_size: tuple             # (width, height) of ``image``
    request_id: str               # identifier for logging
    future: asyncio.Future        # completion handle awaited by the requester
    created_at: float             # enqueue timestamp (seconds since the epoch)
class RmbgService:
    """Background-removal service built on an image-segmentation model.

    One model replica is loaded per visible CUDA device (or a single replica on
    CPU). Images are processed either one at a time on a thread pool, or, when
    ``settings.enable_queue_batch`` is set and the queue processor is running,
    through an asyncio queue that merges concurrent requests into batches which
    are split across all replicas.
    """

    def __init__(self, model_path=None):
        """Load the model(s) and set up the HTTP client, thread pools and queue.

        Args:
            model_path: model location passed to ``from_pretrained``; defaults
                to ``settings.model_path``.
        """
        self.model_path = model_path or settings.model_path
        # Single machine, multiple GPUs: parallel lists of replicas and devices.
        # ``model``/``device`` are kept as legacy aliases of the first entry.
        self.models = []
        self.devices = []
        self.model = None
        self.device = None
        self._gpu_lock = Lock()
        self._next_gpu_index = 0
        self.save_dir = settings.save_dir
        self.download_url = settings.download_url
        os.makedirs(self.save_dir, exist_ok=True)
        self.http_client = httpx.AsyncClient(
            timeout=30.0,
            limits=httpx.Limits(
                max_keepalive_connections=50,
                max_connections=100,
            ),
        )
        # Shared pool for decoding, encoding, file I/O and inference calls.
        self.executor = ThreadPoolExecutor(max_workers=settings.max_workers)
        # Preprocessing pipeline, built once and shared by every inference path.
        self._transform = transforms.Compose([
            transforms.Resize((1024, 1024)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        # Queue-aggregation mode: requests are merged into batches by _queue_processor.
        # NOTE(review): on Python < 3.10 an asyncio.Queue created outside the
        # running event loop can bind to a different loop - confirm this object
        # is constructed inside the application's loop.
        self.queue: asyncio.Queue = asyncio.Queue()
        self.queue_task: Optional[asyncio.Task] = None
        self.queue_running = False
        self._load_model()
        # Dedicated pool for per-GPU sub-batches. It must not be ``self.executor``:
        # the batch entry point already runs on a ``self.executor`` worker, and
        # blocking that worker on tasks queued to the same pool deadlocks once
        # every worker is busy (e.g. max_workers == 1).
        self._gpu_executor = ThreadPoolExecutor(max_workers=max(1, len(self.models)))
        # The queue processor is started later from the FastAPI startup event.

    def _load_model(self):
        """Load one model replica per CUDA device, or a single replica on CPU.

        With CUDA available, replicas are loaded in float16; if that fails for a
        device, that replica is reloaded without a dtype override (intended as
        full precision). Afterwards ``self.model``/``self.device`` alias the
        first replica.
        """
        cuda_available = torch.cuda.is_available()
        if cuda_available:
            # Reduce allocator fragmentation; must be set before models are loaded.
            os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')
        num_gpus = torch.cuda.device_count() if cuda_available else 0
        use_half = cuda_available

        def _load_single_model(device):
            """Load one replica onto ``device`` and put it in eval mode."""
            try:
                model = AutoModelForImageSegmentation.from_pretrained(
                    self.model_path,
                    trust_remote_code=True,
                    dtype=torch.float16 if use_half else torch.float32,
                )
                model = model.to(device)
                if use_half:
                    model = model.half()
            except Exception as e:
                # Half-precision loading failed: fall back to full precision.
                logger.warning(f"设备 {device} 半精度加载失败,使用全精度: {str(e)}")
                model = AutoModelForImageSegmentation.from_pretrained(
                    self.model_path,
                    trust_remote_code=True,
                )
                model = model.to(device)
            model.eval()
            return model

        if num_gpus > 0:
            # One replica per GPU; requests are spread round-robin.
            for idx in range(num_gpus):
                device = torch.device(f"cuda:{idx}")
                model = _load_single_model(device)
                self.devices.append(device)
                self.models.append(model)
            logger.info(f"检测到 {num_gpus} 张 GPU已为每张 GPU 加载模型实例")
        else:
            device = torch.device("cpu")
            model = _load_single_model(device)
            self.devices.append(device)
            self.models.append(model)
            logger.info("未检测到 GPU使用 CPU 设备")
        # Legacy aliases: default to the first replica.
        self.device = self.devices[0]
        self.model = self.models[0]
        if cuda_available:
            torch.cuda.empty_cache()

    def _get_model_and_device(self):
        """Return the next ``(model, device)`` pair, rotating round-robin.

        Raises:
            RuntimeError: if no model has been loaded.
        """
        if not self.models or not self.devices:
            raise RuntimeError("模型尚未加载")
        if len(self.models) == 1:
            return self.models[0], self.devices[0]
        # The lock keeps the counter consistent across executor threads.
        with self._gpu_lock:
            idx = self._next_gpu_index
            self._next_gpu_index = (self._next_gpu_index + 1) % len(self.models)
        return self.models[idx], self.devices[idx]

    @staticmethod
    def _extract_preds(model_output):
        """Return the final prediction tensor from a raw model output.

        A list/tuple output yields its last entry; any other output is returned
        unchanged. Both the single-image and batch paths use this helper.
        """
        if isinstance(model_output, (list, tuple)):
            return model_output[-1]
        return model_output

    def _process_image_sync(self, image):
        """Remove the background of one image synchronously.

        The predicted mask is resized to the input size and written into the
        image's alpha channel *in place*; the same image object is returned.
        """
        model, device = self._get_model_and_device()
        image_size = image.size
        input_images = self._transform(image).unsqueeze(0).to(device)
        # Match the model's precision (half-precision replicas need half inputs).
        if next(model.parameters()).dtype == torch.float16:
            input_images = input_images.half()
        with torch.no_grad():
            preds = self._extract_preds(model(input_images)).sigmoid().cpu()
        pred = preds[0].squeeze()
        mask = transforms.ToPILImage()(pred).resize(image_size)
        image.putalpha(mask)
        # The single-image path keeps gc.collect() so memory is released promptly.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        return image

    def _process_batch_images_sync(self, images_with_info):
        """Process a batch synchronously, splitting it across replicas if there are several.

        Args:
            images_with_info: list of ``(image, image_size, index)`` tuples.

        Returns:
            ``(result_image, index)`` tuples sorted by ``index``. If a per-device
            sub-batch fails it is logged and its items are omitted, so callers
            must match results by ``index``, not by position.
        """
        if not images_with_info:
            return []
        num_replicas = len(self.models)
        if num_replicas == 1:
            return self._process_batch_on_device(self.models[0], self.devices[0], images_with_info)

        # Round-robin split: item i goes to replica i % num_replicas.
        sub_batch_futures = []
        for replica_idx, (model, device) in enumerate(zip(self.models, self.devices)):
            sub_items = images_with_info[replica_idx::num_replicas]
            if sub_items:
                sub_batch_futures.append(
                    self._gpu_executor.submit(self._process_batch_on_device, model, device, sub_items)
                )

        all_results = []
        for fut in sub_batch_futures:
            try:
                all_results.extend(fut.result())
            except Exception as e:
                logger.error(f"多 GPU 子批次处理失败: {e}", exc_info=True)
        # Restore the caller's original ordering.
        all_results.sort(key=lambda pair: pair[1])
        return all_results

    def _process_batch_on_device(self, model, device, images_with_info):
        """Run one stacked forward pass on ``device`` for all given images.

        Args:
            model: replica living on ``device``.
            device: torch device to run on.
            images_with_info: list of ``(image, image_size, index)`` tuples.

        Returns:
            List of ``(result_image, index)``; each result is a copy of the input
            with the predicted mask as its alpha channel.
        """
        input_batch = torch.stack(
            [self._transform(image) for image, _, _ in images_with_info]
        ).to(device)
        if next(model.parameters()).dtype == torch.float16:
            input_batch = input_batch.half()
        with torch.no_grad():
            preds = self._extract_preds(model(input_batch)).sigmoid().cpu()
        # Release the device-side input as early as possible.
        del input_batch

        to_pil = transforms.ToPILImage()
        results = []
        for i, (image, image_size, index) in enumerate(images_with_info):
            # 3-D preds already give one (H, W) map per item; otherwise squeeze.
            pred = preds[i] if preds.dim() == 3 else preds[i].squeeze()
            mask = to_pil(pred).resize(image_size)
            result_image = image.copy()
            result_image.putalpha(mask)
            results.append((result_image, index))
        del preds
        # gc.collect() is deliberately skipped here to avoid blocking the batch path.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return results

    async def process_image(self, image):
        """Remove the background of one image.

        Uses the batching queue when enabled and running; otherwise runs the
        single-image path on the thread pool.

        Returns:
            In queue mode, a ``{"status", "image_url"}`` dict; otherwise the
            processed image.
        """
        if settings.enable_queue_batch and self.queue_running:
            return await self._process_image_via_queue(image)
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self.executor, self._process_image_sync, image)

    async def _process_image_via_queue(self, image):
        """Enqueue one image for batched processing and wait for its result.

        Raises:
            Exception: on timeout (``settings.request_timeout``) or when the
                batch processor reports a failure for this request.
        """
        loop = asyncio.get_running_loop()
        future = loop.create_future()
        queue_item = QueueItem(
            image=image,
            image_size=image.size,
            request_id=uuid.uuid4().hex[:10],
            future=future,
            created_at=time.time(),
        )
        try:
            await self.queue.put(queue_item)
            try:
                return await asyncio.wait_for(future, timeout=settings.request_timeout)
            except asyncio.TimeoutError:
                # Marks the item as done so the processor skips it.
                future.cancel()
                raise Exception(f"处理超时(超过{settings.request_timeout}秒)")
        except Exception:
            # Cancel rather than set_exception: nobody else awaits this future,
            # and the processor only checks whether it is done.
            if not future.done():
                future.cancel()
            raise

    async def process_batch_images(self, images_with_info):
        """Run ``_process_batch_images_sync`` on the thread pool and await its result."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self.executor, self._process_batch_images_sync, images_with_info
        )

    async def _start_queue_processor(self):
        """Start the background batching task; must be called inside the event loop."""
        if self.queue_running:
            return
        self.queue_running = True
        self.queue_task = asyncio.create_task(self._queue_processor())

    async def _queue_processor(self):
        """Background loop: collect queued requests into batches and process them."""
        while self.queue_running:
            try:
                batch_items = await self._collect_batch_items()
                if not batch_items:
                    continue
                await self._process_batch_queue_items(batch_items)
            except Exception as e:
                logger.error(f"队列批处理任务出错: {str(e)}", exc_info=True)
                await asyncio.sleep(0.1)  # back off briefly after an error

    async def _collect_batch_items(self):
        """Collect up to ``settings.batch_size`` queued items.

        Waits up to ``batch_collect_timeout`` for the first item (returning an
        empty list if none arrives). It then keeps collecting until the batch is
        full, the same timeout has elapsed since the first item, or no new item
        arrives within ``batch_collect_interval``.
        """
        batch_size = settings.batch_size
        collect_interval = settings.batch_collect_interval
        collect_timeout = settings.batch_collect_timeout

        try:
            first_item = await asyncio.wait_for(self.queue.get(), timeout=collect_timeout)
        except asyncio.TimeoutError:
            return []

        batch_items = [first_item]
        # Monotonic clock: elapsed-time checks must not jump with wall-clock changes.
        deadline = time.monotonic() + collect_timeout
        while len(batch_items) < batch_size:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            try:
                item = await asyncio.wait_for(
                    self.queue.get(), timeout=min(collect_interval, remaining)
                )
            except asyncio.TimeoutError:
                break
            batch_items.append(item)
        return batch_items

    async def _process_batch_queue_items(self, batch_items):
        """Process one collected batch and resolve every item's future.

        Each successful item receives ``{"status": "success", "image_url": ...}``.
        An item without a result, or whose save failed, receives an exception.
        """
        # Requests whose waiter already timed out or was cancelled have a done
        # future; don't spend GPU time on them.
        batch_items = [item for item in batch_items if not item.future.done()]
        if not batch_items:
            return
        loop = asyncio.get_running_loop()

        async def deliver(item, processed_image):
            """Save one processed image and resolve its request future."""
            try:
                image_url = await loop.run_in_executor(
                    self.executor, self.save_image_to_file, processed_image
                )
            except Exception as e:
                error_msg = f"处理图片失败: {str(e)}"
                logger.error(f"队列项 {item.request_id} 处理失败: {error_msg}")
                if not item.future.done():
                    item.future.set_exception(Exception(error_msg))
                return
            if not item.future.done():
                item.future.set_result({"status": "success", "image_url": image_url})

        try:
            images_with_info = [
                (item.image, item.image_size, idx) for idx, item in enumerate(batch_items)
            ]
            batch_results = await self.process_batch_images(images_with_info)
            # Match results by the index carried with each image, not by list
            # position: a failed GPU sub-batch is dropped from batch_results, so
            # positions no longer line up with batch_items.
            await asyncio.gather(*(
                deliver(batch_items[idx], processed_image)
                for processed_image, idx in batch_results
                if 0 <= idx < len(batch_items)
            ))
            # Anything still unresolved got no result (e.g. its sub-batch failed).
            for item in batch_items:
                if not item.future.done():
                    item.future.set_exception(Exception("批处理结果不完整"))
        except asyncio.CancelledError:
            # Shutting down: fail in-flight requests instead of leaving them to time out.
            for item in batch_items:
                if not item.future.done():
                    item.future.set_exception(Exception("服务关闭,请求被取消"))
            raise
        except Exception as e:
            error_msg = f"批处理队列项失败: {str(e)}"
            logger.error(error_msg, exc_info=True)
            for item in batch_items:
                if not item.future.done():
                    item.future.set_exception(Exception(error_msg))

    def save_image_to_file(self, image):
        """Save ``image`` as PNG under ``save_dir`` and return its download URL."""
        # PNG preserves the alpha mask; Pillow's default compress_level (6) is used.
        filename = f"rmbg_{uuid.uuid4().hex[:10]}.png"
        image.save(os.path.join(self.save_dir, filename), format="PNG")
        return f"{self.download_url}/{filename}"

    @staticmethod
    def _open_rgb(source):
        """Decode ``source`` (a path or file object) into an RGB image.

        The file is closed before returning; ``convert`` produces an
        independent, fully loaded image.
        """
        with Image.open(source) as img:
            return img.convert("RGB")

    @staticmethod
    def _remove_temp_file(path):
        """Delete a downloaded temp file on a best-effort basis, logging failures."""
        if not path:
            return
        try:
            os.unlink(path)
        except FileNotFoundError:
            pass
        except OSError as e:
            logger.warning(f"删除临时文件失败 {path}: {e}")

    async def remove_background(self, image_path):
        """Remove the background of an image given by local path or URL.

        Args:
            image_path: local file path or http(s) URL of the input image.

        Returns:
            ``{"status": "success", "image_url": ...}`` (or the dict returned by
            the queue path).

        Raises:
            FileNotFoundError: if the local path does not exist.
            Exception: if downloading the URL fails.
        """
        temp_file = None
        try:
            if self.is_valid_url(image_path):
                try:
                    temp_file = await self.download_image(image_path)
                except Exception as e:
                    raise Exception(f"下载图片失败: {e}") from e
                image_path = temp_file
            if not os.path.exists(image_path):
                raise FileNotFoundError(f"输入图像不存在: {image_path}")
            loop = asyncio.get_running_loop()
            image = await loop.run_in_executor(self.executor, self._open_rgb, image_path)
            result = await self.process_image(image)
            if isinstance(result, dict):
                return result
            image_url = await loop.run_in_executor(
                self.executor, self.save_image_to_file, result
            )
            return {"status": "success", "image_url": image_url}
        finally:
            self._remove_temp_file(temp_file)

    async def remove_background_from_file(self, file_content):
        """Remove the background of an uploaded image.

        Args:
            file_content: raw bytes of the uploaded file.

        Returns:
            ``{"status": "success", "image_url": ...}``.

        Raises:
            Exception: wrapping any decode or processing error.
        """
        try:
            loop = asyncio.get_running_loop()
            image = await loop.run_in_executor(
                self.executor, self._open_rgb, io.BytesIO(file_content)
            )
            result = await self.process_image(image)
            if isinstance(result, dict):
                return result
            image_url = await loop.run_in_executor(
                self.executor, self.save_image_to_file, result
            )
            return {"status": "success", "image_url": image_url}
        except Exception as e:
            raise Exception(f"处理图片失败: {e}") from e

    async def process_batch(self, urls):
        """Process many images as a pipeline, streaming one result dict per input.

        All inputs are fetched/opened concurrently. Finished downloads are
        grouped locally (flushed when ``settings.batch_size`` is reached, when no
        new download arrives for 0.5 s, or when all downloads are done). Each
        image is then sent through ``process_image`` concurrently, so in queue
        mode the global queue can form large multi-GPU batches.

        Args:
            urls: sequence of URLs or local paths (each converted with ``str``).

        Yields:
            Dicts with ``index`` (1-based), ``total``, ``original_url``,
            ``status`` ("success"/"error"), ``image_url`` or ``error``,
            ``message``, running ``success_count``/``error_count``,
            ``completed_order`` and ``batch_elapsed`` (seconds since start).
        """
        total = len(urls)
        if total == 0:
            # No download would ever signal completion; the wait loop would spin forever.
            return

        batch_start_time = time.time()
        batch_size = max(1, settings.batch_size)
        batch_collect_timeout = 0.5  # seconds to wait for more downloads before flushing
        loop = asyncio.get_running_loop()

        # Items are (image, index, url_str, error); image is None iff error is set.
        download_queue = asyncio.Queue()
        download_complete = asyncio.Event()
        downloads_finished = 0
        success_count = 0
        error_count = 0
        completed_order = 0

        def make_result(index, url_str, ok, **fields):
            """Advance the running counters and build one streamed result dict."""
            nonlocal success_count, error_count, completed_order
            if ok:
                success_count += 1
            else:
                error_count += 1
            completed_order += 1
            return {
                "index": index,
                "total": total,
                "original_url": url_str,
                "status": "success" if ok else "error",
                **fields,
                "success_count": success_count,
                "error_count": error_count,
                "completed_order": completed_order,
                "batch_elapsed": round(time.time() - batch_start_time, 2),
            }

        async def download_one(index, url):
            """Fetch/open one input and enqueue it (or its error)."""
            nonlocal downloads_finished
            url_str = str(url)
            try:
                if self.is_valid_url(url_str):
                    temp_file = await self.download_image(url_str)
                    try:
                        image = await loop.run_in_executor(self.executor, self._open_rgb, temp_file)
                    finally:
                        # Delete the temp file even when decoding fails.
                        self._remove_temp_file(temp_file)
                else:
                    image = await loop.run_in_executor(self.executor, self._open_rgb, url_str)
            except Exception as e:
                await download_queue.put((None, index, url_str, str(e)))
            else:
                await download_queue.put((image, index, url_str, None))
            finally:
                downloads_finished += 1
                if downloads_finished >= total:
                    download_complete.set()

        async def outcome_to_result(index, url_str, outcome):
            """Convert one ``process_image`` outcome into a streamed result dict."""
            if isinstance(outcome, BaseException):
                return make_result(
                    index, url_str, False,
                    error=str(outcome),
                    message=f"处理失败: {str(outcome)}",
                )
            if isinstance(outcome, dict):
                # Queue mode returns {"status", "image_url"} directly.
                status = outcome.get("status", "success")
                image_url = outcome.get("image_url")
                message = outcome.get("message", "处理成功" if status == "success" else "处理失败")
                error_msg = outcome.get("error")
            else:
                # Direct mode returns the processed image; save it here.
                try:
                    image_url = await loop.run_in_executor(
                        self.executor, self.save_image_to_file, outcome
                    )
                except Exception as e:
                    logger.error(f"批处理失败: {str(e)}")
                    return make_result(
                        index, url_str, False,
                        error=str(e),
                        message=f"批处理失败: {str(e)}",
                    )
                status, message, error_msg = "success", "处理成功", None
            if status == "success" and image_url:
                return make_result(index, url_str, True, image_url=image_url, message=message)
            return make_result(
                index, url_str, False,
                error=error_msg or "处理失败",
                message=message,
            )

        async def emit(batch):
            """Yield results for one locally collected group of downloads."""
            valid = []
            for image, index, url_str, error in batch:
                # Test against None: an exception with an empty message is still a failure.
                if error is not None:
                    yield make_result(
                        index, url_str, False,
                        error=error,
                        message=f"下载失败: {error}",
                    )
                else:
                    valid.append((image, index, url_str))
            for start in range(0, len(valid), batch_size):
                chunk = valid[start:start + batch_size]
                # Submit the whole chunk concurrently so the global queue can
                # merge it into large batches spread over all GPUs.
                outcomes = await asyncio.gather(
                    *(self.process_image(image) for image, _, _ in chunk),
                    return_exceptions=True,
                )
                for (_, index, url_str), outcome in zip(chunk, outcomes):
                    yield await outcome_to_result(index, url_str, outcome)

        download_tasks = [
            asyncio.create_task(download_one(i, url))
            for i, url in enumerate(urls, 1)
        ]
        pending = []
        try:
            try:
                while True:
                    try:
                        pending.append(
                            await asyncio.wait_for(download_queue.get(), timeout=batch_collect_timeout)
                        )
                        timed_out = False
                    except asyncio.TimeoutError:
                        timed_out = True
                    drained = download_complete.is_set() and download_queue.empty()
                    if pending and (timed_out or drained or len(pending) >= batch_size):
                        batch, pending = pending, []
                        async for result in emit(batch):
                            yield result
                    # Re-check after processing: more downloads may have landed meanwhile.
                    if download_complete.is_set() and download_queue.empty() and not pending:
                        break
            except Exception as e:
                logger.error(f"流水线处理出错: {str(e)}", exc_info=True)

            await asyncio.gather(*download_tasks, return_exceptions=True)
            # Flush anything still queued (only possible after the error path above).
            while not download_queue.empty():
                pending.append(download_queue.get_nowait())
            if pending:
                async for result in emit(pending):
                    yield result
        finally:
            # If the consumer stops early, don't leave downloads running in the background.
            for task in download_tasks:
                if not task.done():
                    task.cancel()

    def is_valid_url(self, url):
        """Return True when ``url`` parses with both a scheme and a network location."""
        try:
            parsed = urlparse(url)
        except Exception:
            return False
        return bool(parsed.scheme and parsed.netloc)

    async def download_image(self, url):
        """Download ``url`` into a new temporary ``.png`` file and return its path.

        The caller owns the returned file and must delete it.

        Raises:
            Exception: wrapping any HTTP or file error.
        """
        try:
            response = await self.http_client.get(url)
            response.raise_for_status()

            def write_temp_file(content):
                with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
                    temp_file.write(content)
                    return temp_file.name

            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(self.executor, write_temp_file, response.content)
        except Exception as e:
            raise Exception(f"下载图片失败: {e}") from e

    async def cleanup(self):
        """Stop the queue processor, fail pending requests and release resources."""
        if self.queue_running:
            self.queue_running = False
            if self.queue_task:
                self.queue_task.cancel()
                try:
                    await self.queue_task
                except asyncio.CancelledError:
                    pass

        # Fail every request still waiting in the queue.
        while True:
            try:
                item = self.queue.get_nowait()
            except asyncio.QueueEmpty:
                break
            if not item.future.done():
                item.future.set_exception(Exception("服务关闭,请求被取消"))

        await self.http_client.aclose()
        # Shut down the general pool first: its tasks may wait on GPU sub-batches.
        self.executor.shutdown(wait=True)
        self._gpu_executor.shutdown(wait=True)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()