refactor: implement a strict pipelined design, one independent worker per GPU processing the queue

- Architecture refactor: start an independent queue-processing worker for each GPU, avoiding contention between workers (a minimal illustrative sketch follows the commit metadata below)
- Per-GPU batch collection: each worker collects only batch_size requests, no longer multiplied by the number of GPUs
- Device binding: each worker is pinned to its own model and device, replacing round-robin scheduling
- Processing logic: batches run directly on the worker's model/device; the multi-GPU splitting logic is removed
- Fallback handling: on OOM, fall back to single-image processing on the current worker's model/device
- Resource management: update the cleanup method to correctly stop all worker tasks
- API update: fix the deprecated PYTORCH_CUDA_ALLOC_CONF env var and torch_dtype argument

Benefits:
- No contention or batch conflicts between workers
- Resource isolation: each worker uses only its own GPU
- Load balancing: multiple workers process in parallel, raising throughput
- Easy to scale: the number of workers adjusts automatically when the GPU count changes
jingrow 2025-12-16 16:36:41 +00:00
parent 0757566f0e
commit 57bfa17ac7
2 changed files with 69 additions and 44 deletions
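As a rough illustration of the per-GPU worker layout described in the commit message, here is a minimal, self-contained asyncio sketch. Everything in it (the worker loop, fake_infer, DEVICES, the constants) is hypothetical and only mirrors the idea of one worker per device draining a shared queue in per-GPU batches; it is not the service's actual code.

    import asyncio

    # Hypothetical constants, standing in for settings.batch_size etc.
    BATCH_SIZE = 2                   # per-GPU batch size (not multiplied by GPU count)
    COLLECT_TIMEOUT = 0.05           # stop topping up a batch after 50 ms
    DEVICES = ["cuda:0", "cuda:1"]   # one worker is bound to each device

    async def fake_infer(device: str, batch: list) -> list:
        # Stand-in for batched model inference pinned to one device.
        await asyncio.sleep(0.01)
        return [f"{item}@{device}" for item in batch]

    async def worker(worker_id: int, device: str, queue: asyncio.Queue) -> None:
        # Each worker owns one device and drains the shared queue in small batches.
        while True:
            item, fut = await queue.get()            # block for the first request
            batch, futures = [item], [fut]
            while len(batch) < BATCH_SIZE:           # opportunistically top up the batch
                try:
                    item, fut = await asyncio.wait_for(queue.get(), COLLECT_TIMEOUT)
                except asyncio.TimeoutError:
                    break
                batch.append(item)
                futures.append(fut)
            for f, result in zip(futures, await fake_infer(device, batch)):
                f.set_result(result)

    async def main() -> None:
        queue: asyncio.Queue = asyncio.Queue()
        tasks = [asyncio.create_task(worker(i, d, queue)) for i, d in enumerate(DEVICES)]
        loop = asyncio.get_running_loop()
        futures = []
        for i in range(6):                           # submit six fake requests
            fut = loop.create_future()
            await queue.put((f"img{i}", fut))
            futures.append(fut)
        print(await asyncio.gather(*futures))
        for t in tasks:                              # shut the workers down
            t.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)

    asyncio.run(main())

Because each worker holds its own model/device pair, no cross-worker coordination is needed: the shared queue is the only point of contact, which is the property the commit relies on.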


@@ -69,9 +69,9 @@ class RmbgService:
)
self.executor = ThreadPoolExecutor(max_workers=settings.max_workers)
# Queue aggregation mechanism (plan B)
# Queue aggregation mechanism (plan B): strict pipelining, one worker per GPU
self.queue: asyncio.Queue = asyncio.Queue()
self.queue_task: Optional[asyncio.Task] = None
self.queue_tasks: list[asyncio.Task] = [] # holds all worker tasks
self.queue_running = False
self._load_model()
@@ -81,7 +81,7 @@ class RmbgService:
"""Load the model, with multi-GPU support"""
# Optimize the GPU memory allocation strategy to reduce fragmentation (must be set before loading)
if torch.cuda.is_available():
os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')
os.environ.setdefault('PYTORCH_ALLOC_CONF', 'expandable_segments:True')
num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
use_half = torch.cuda.is_available()
@@ -92,7 +92,7 @@ class RmbgService:
model = AutoModelForImageSegmentation.from_pretrained(
self.model_path,
trust_remote_code=True,
torch_dtype=torch.float16 if use_half else torch.float32,
dtype=torch.float16 if use_half else torch.float32,
)
model = model.to(device)
if use_half:
@@ -146,8 +146,12 @@ class RmbgService:
return self.models[idx], self.devices[idx]
def _process_image_sync(self, image):
"""同步处理图像,移除背景(单张)"""
"""同步处理图像,移除背景(单张)- 兼容旧接口,使用轮询调度"""
model, device = self._get_model_and_device()
return self._process_single_image_on_device(model, device, image)
def _process_single_image_on_device(self, model, device, image):
"""在指定设备上处理单张图像(用于 worker 降级处理)"""
image_size = image.size
transform_image = transforms.Compose([
transforms.Resize((1024, 1024)),
@@ -309,36 +313,46 @@ class RmbgService:
)
async def _start_queue_processor(self):
"""启动队列批处理后台任务(异步方法,需要在事件循环中调用"""
"""启动队列批处理后台任务(严格的流水线式方案B每 GPU 一个 worker"""
if self.queue_running:
return
self.queue_running = True
self.queue_task = asyncio.create_task(self._queue_processor())
# Start an independent worker for each GPU
num_workers = len(self.models) if self.models else 1
logger.info(f"启动 {num_workers} 个队列处理 worker每 GPU 一个)")
for worker_id in range(num_workers):
task = asyncio.create_task(self._queue_processor(worker_id))
self.queue_tasks.append(task)
async def _queue_processor(self):
"""后台队列批处理任务(核心逻辑)"""
async def _queue_processor(self, worker_id: int):
"""后台队列批处理任务(核心逻辑)- 每个 worker 绑定一个 GPU"""
model = self.models[worker_id]
device = self.devices[worker_id]
logger.info(f"Worker {worker_id} 启动,绑定设备: {device}")
while self.queue_running:
try:
# Collect a batch of requests
# Collect a batch of requests (per-GPU batch_size)
batch_items = await self._collect_batch_items()
if not batch_items:
continue
# Process this batch of requests
await self._process_batch_queue_items(batch_items)
# Process this batch of requests (using only this worker's model and device)
await self._process_batch_queue_items(batch_items, model, device, worker_id)
except Exception as e:
logger.error(f"队列批处理任务出错: {str(e)}", exc_info=True)
logger.error(f"Worker {worker_id} 队列批处理任务出错: {str(e)}", exc_info=True)
await asyncio.sleep(0.1) # 出错后短暂等待
async def _collect_batch_items(self):
"""收集一批队列项,达到单卡batch_size×GPU数量或超时后返回支持跨用户合批"""
"""收集一批队列项,达到单卡 batch_size 或超时后返回(单卡 batch避免 worker 之间打架"""
batch_items = []
per_device_batch = settings.batch_size # per-GPU batch_size
device_count = max(1, getattr(self, "num_devices", len(self.devices) or 1))
target_batch_size = per_device_batch * device_count # expected global batch cap for this round
target_batch_size = settings.batch_size # per-GPU batch_size (no longer multiplied by the GPU count)
collect_interval = settings.batch_collect_interval
collect_timeout = settings.batch_collect_timeout
@@ -379,8 +393,8 @@ class RmbgService:
return batch_items
async def _process_batch_queue_items(self, batch_items):
"""处理一批队列项(统一全局 batcher支持跨用户合批"""
async def _process_batch_queue_items(self, batch_items, model, device, worker_id: int):
"""处理一批队列项(单卡处理,使用指定的 model 和 device"""
if not batch_items:
return
@@ -394,8 +408,14 @@ class RmbgService:
images_with_info.append((item.image, item.image_size, idx))
item_index_map[idx] = item
# Run the batch (direct call, to fully utilize the GPUs)
batch_results = await self.process_batch_images(images_with_info)
# Run the batch (using only this worker's model and device, no more multi-GPU splitting)
batch_results = await loop.run_in_executor(
self.executor,
self._process_batch_on_device,
model,
device,
images_with_info
)
# Save all images in parallel (key optimization: avoids blocking on serial IO)
save_tasks = []
@@ -445,27 +465,30 @@ class RmbgService:
# CUDA OOM error, fall back
error_msg = str(e)
if "CUDA out of memory" in error_msg or "out of memory" in error_msg.lower():
logger.warning(f"批处理显存不足,降级到单张处理: {error_msg[:100]}")
logger.warning(f"Worker {worker_id} 批处理显存不足,降级到单张处理: {error_msg[:100]}")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
# Fallback: single-image processing
# Fallback: single-image processing (using this worker's model and device)
for item in batch_items:
try:
result_data = await self.process_image(item.image)
if isinstance(result_data, dict):
if not item.future.done():
item.future.set_result(result_data)
else:
image_url = await loop.run_in_executor(
self.executor, self.save_image_to_file, result_data
)
if not item.future.done():
item.future.set_result({
"status": "success",
"image_url": image_url
})
# Process the single image on this worker's model and device
result_image = await loop.run_in_executor(
self.executor,
self._process_single_image_on_device,
model,
device,
item.image
)
image_url = await loop.run_in_executor(
self.executor, self.save_image_to_file, result_image
)
if not item.future.done():
item.future.set_result({
"status": "success",
"image_url": image_url
})
except Exception as e2:
if not item.future.done():
item.future.set_exception(Exception(f"Fallback processing failed: {str(e2)}"))
@@ -753,15 +776,17 @@ class RmbgService:
async def cleanup(self):
"""清理资源"""
# 停止队列处理任务
# 停止所有队列处理 worker 任务
if self.queue_running:
self.queue_running = False
if self.queue_task:
self.queue_task.cancel()
try:
await self.queue_task
except asyncio.CancelledError:
pass
# Cancel all worker tasks
for task in self.queue_tasks:
if task:
task.cancel()
# Wait for all tasks to finish being cancelled
if self.queue_tasks:
await asyncio.gather(*self.queue_tasks, return_exceptions=True)
self.queue_tasks.clear()
# Handle any requests still left in the queue
remaining_items = []

View File

@@ -35,7 +35,7 @@ class Settings(BaseSettings):
# Concurrency control settings (inference side)
max_workers: int = 60 # max thread-pool worker threads; tune to the CPU core count (22 cores / 44 threads can use 20-30)
batch_size: int = 16 # GPU batch size; the model's GPU memory footprint is large, 8 is safe, 16 causes OOM
batch_size: int = 8 # GPU batch size; the model's GPU memory footprint is large, 8 is safe, 16 causes OOM
# Queue aggregation settings (plan B: batching + queue mode)
batch_collect_interval: float = 0.05 # batch collection interval; collect every 50ms, balancing latency and throughput
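For context on how these settings interact, below is a minimal sketch of a collection loop driven by batch_size, batch_collect_interval, and batch_collect_timeout. The DemoSettings dataclass and collect_batch helper are assumptions for illustration, not the service's real _collect_batch_items implementation.

    import asyncio
    import time
    from dataclasses import dataclass

    @dataclass
    class DemoSettings:
        # Assumed shape, mirroring only the fields referenced in the diff above.
        batch_size: int = 8                   # per-GPU batch size
        batch_collect_interval: float = 0.05  # poll the queue every 50 ms
        batch_collect_timeout: float = 0.2    # stop waiting for more items after 200 ms

    async def collect_batch(queue: asyncio.Queue, settings: DemoSettings) -> list:
        # Return up to batch_size items, or whatever arrived before the timeout.
        items = [await queue.get()]           # always wait for at least one item
        deadline = time.monotonic() + settings.batch_collect_timeout
        while len(items) < settings.batch_size and time.monotonic() < deadline:
            try:
                items.append(queue.get_nowait())
            except asyncio.QueueEmpty:
                await asyncio.sleep(settings.batch_collect_interval)
        return items

With batch_size = 8 and a 50 ms polling interval, a worker submits at most 8 images per forward pass, and a lone request waits no longer than roughly batch_collect_timeout before being processed on its own.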