refactor: 实现严格的流水线式方案,每GPU独立worker处理队列

- 架构重构:为每个GPU启动独立的队列处理worker,避免worker间竞争
- 单卡batch收集:每个worker只收集batch_size个请求,不再乘以GPU数量
- 设备绑定:每个worker固定绑定自己的model和device,不再轮询调度
- 处理逻辑:直接使用worker的model/device进行批处理,移除多GPU拆分逻辑
- 降级处理:OOM时使用当前worker的model/device进行单张处理
- 资源管理:更新cleanup方法,正确停止所有worker任务
- API更新:修复已弃用的PYTORCH_CUDA_ALLOC_CONF和torch_dtype参数

优势:
- 避免worker之间竞争和批次冲突
- 资源隔离,每个worker只使用自己的GPU
- 负载均衡,多worker并行处理提高吞吐量
- 易于扩展,GPU数量变化时自动调整worker数量
This commit is contained in:
jingrow 2025-12-16 16:36:41 +00:00
parent 0757566f0e
commit 57bfa17ac7
2 changed files with 69 additions and 44 deletions

View File

@ -69,9 +69,9 @@ class RmbgService:
) )
self.executor = ThreadPoolExecutor(max_workers=settings.max_workers) self.executor = ThreadPoolExecutor(max_workers=settings.max_workers)
# 队列聚合机制方案B # 队列聚合机制方案B:严格的流水线式,每 GPU 一个 worker
self.queue: asyncio.Queue = asyncio.Queue() self.queue: asyncio.Queue = asyncio.Queue()
self.queue_task: Optional[asyncio.Task] = None self.queue_tasks: list[asyncio.Task] = [] # 存储所有 worker 任务
self.queue_running = False self.queue_running = False
self._load_model() self._load_model()
@ -81,7 +81,7 @@ class RmbgService:
"""加载模型,支持多 GPU""" """加载模型,支持多 GPU"""
# 优化显存分配策略:减少碎片化(需要在加载前设置) # 优化显存分配策略:减少碎片化(需要在加载前设置)
if torch.cuda.is_available(): if torch.cuda.is_available():
os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True') os.environ.setdefault('PYTORCH_ALLOC_CONF', 'expandable_segments:True')
num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0 num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
use_half = torch.cuda.is_available() use_half = torch.cuda.is_available()
@ -92,7 +92,7 @@ class RmbgService:
model = AutoModelForImageSegmentation.from_pretrained( model = AutoModelForImageSegmentation.from_pretrained(
self.model_path, self.model_path,
trust_remote_code=True, trust_remote_code=True,
torch_dtype=torch.float16 if use_half else torch.float32, dtype=torch.float16 if use_half else torch.float32,
) )
model = model.to(device) model = model.to(device)
if use_half: if use_half:
@ -146,8 +146,12 @@ class RmbgService:
return self.models[idx], self.devices[idx] return self.models[idx], self.devices[idx]
def _process_image_sync(self, image): def _process_image_sync(self, image):
"""同步处理图像,移除背景(单张)""" """同步处理图像,移除背景(单张)- 兼容旧接口,使用轮询调度"""
model, device = self._get_model_and_device() model, device = self._get_model_and_device()
return self._process_single_image_on_device(model, device, image)
def _process_single_image_on_device(self, model, device, image):
"""在指定设备上处理单张图像(用于 worker 降级处理)"""
image_size = image.size image_size = image.size
transform_image = transforms.Compose([ transform_image = transforms.Compose([
transforms.Resize((1024, 1024)), transforms.Resize((1024, 1024)),
@ -309,36 +313,46 @@ class RmbgService:
) )
async def _start_queue_processor(self): async def _start_queue_processor(self):
"""启动队列批处理后台任务(异步方法,需要在事件循环中调用""" """启动队列批处理后台任务(严格的流水线式方案B每 GPU 一个 worker"""
if self.queue_running: if self.queue_running:
return return
self.queue_running = True self.queue_running = True
self.queue_task = asyncio.create_task(self._queue_processor())
async def _queue_processor(self): # 为每个 GPU 启动一个独立的 worker
"""后台队列批处理任务(核心逻辑)""" num_workers = len(self.models) if self.models else 1
logger.info(f"启动 {num_workers} 个队列处理 worker每 GPU 一个)")
for worker_id in range(num_workers):
task = asyncio.create_task(self._queue_processor(worker_id))
self.queue_tasks.append(task)
async def _queue_processor(self, worker_id: int):
"""后台队列批处理任务(核心逻辑)- 每个 worker 绑定一个 GPU"""
model = self.models[worker_id]
device = self.devices[worker_id]
logger.info(f"Worker {worker_id} 启动,绑定设备: {device}")
while self.queue_running: while self.queue_running:
try: try:
# 收集一批请求 # 收集一批请求(单卡 batch_size
batch_items = await self._collect_batch_items() batch_items = await self._collect_batch_items()
if not batch_items: if not batch_items:
continue continue
# 处理这批请求 # 处理这批请求(只使用当前 worker 的 model 和 device
await self._process_batch_queue_items(batch_items) await self._process_batch_queue_items(batch_items, model, device, worker_id)
except Exception as e: except Exception as e:
logger.error(f"队列批处理任务出错: {str(e)}", exc_info=True) logger.error(f"Worker {worker_id} 队列批处理任务出错: {str(e)}", exc_info=True)
await asyncio.sleep(0.1) # 出错后短暂等待 await asyncio.sleep(0.1) # 出错后短暂等待
async def _collect_batch_items(self): async def _collect_batch_items(self):
"""收集一批队列项,达到单卡batch_size×GPU数量或超时后返回支持跨用户合批""" """收集一批队列项,达到单卡 batch_size 或超时后返回(单卡 batch避免 worker 之间打架"""
batch_items = [] batch_items = []
per_device_batch = settings.batch_size # 单卡 batch_size target_batch_size = settings.batch_size # 单卡 batch_size不再乘以 GPU 数量)
device_count = max(1, getattr(self, "num_devices", len(self.devices) or 1))
target_batch_size = per_device_batch * device_count # 本次期望的全局 batch 上限
collect_interval = settings.batch_collect_interval collect_interval = settings.batch_collect_interval
collect_timeout = settings.batch_collect_timeout collect_timeout = settings.batch_collect_timeout
@ -379,8 +393,8 @@ class RmbgService:
return batch_items return batch_items
async def _process_batch_queue_items(self, batch_items): async def _process_batch_queue_items(self, batch_items, model, device, worker_id: int):
"""处理一批队列项(统一全局 batcher支持跨用户合批""" """处理一批队列项(单卡处理,使用指定的 model 和 device"""
if not batch_items: if not batch_items:
return return
@ -394,8 +408,14 @@ class RmbgService:
images_with_info.append((item.image, item.image_size, idx)) images_with_info.append((item.image, item.image_size, idx))
item_index_map[idx] = item item_index_map[idx] = item
# 执行批处理(直接调用,充分利用 GPU # 执行批处理(只使用当前 worker 的 model 和 device不再做多卡拆分
batch_results = await self.process_batch_images(images_with_info) batch_results = await loop.run_in_executor(
self.executor,
self._process_batch_on_device,
model,
device,
images_with_info
)
# 并行保存所有图片(关键优化:避免串行 IO 阻塞) # 并行保存所有图片(关键优化:避免串行 IO 阻塞)
save_tasks = [] save_tasks = []
@ -445,27 +465,30 @@ class RmbgService:
# CUDA OOM 错误,降级处理 # CUDA OOM 错误,降级处理
error_msg = str(e) error_msg = str(e)
if "CUDA out of memory" in error_msg or "out of memory" in error_msg.lower(): if "CUDA out of memory" in error_msg or "out of memory" in error_msg.lower():
logger.warning(f"批处理显存不足,降级到单张处理: {error_msg[:100]}") logger.warning(f"Worker {worker_id} 批处理显存不足,降级到单张处理: {error_msg[:100]}")
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
gc.collect() gc.collect()
# 降级:单张处理 # 降级:单张处理(使用当前 worker 的 model 和 device
for item in batch_items: for item in batch_items:
try: try:
result_data = await self.process_image(item.image) # 使用当前 worker 的 model 和 device 进行单张处理
if isinstance(result_data, dict): result_image = await loop.run_in_executor(
if not item.future.done(): self.executor,
item.future.set_result(result_data) self._process_single_image_on_device,
else: model,
image_url = await loop.run_in_executor( device,
self.executor, self.save_image_to_file, result_data item.image
) )
if not item.future.done(): image_url = await loop.run_in_executor(
item.future.set_result({ self.executor, self.save_image_to_file, result_image
"status": "success", )
"image_url": image_url if not item.future.done():
}) item.future.set_result({
"status": "success",
"image_url": image_url
})
except Exception as e2: except Exception as e2:
if not item.future.done(): if not item.future.done():
item.future.set_exception(Exception(f"降级处理失败: {str(e2)}")) item.future.set_exception(Exception(f"降级处理失败: {str(e2)}"))
@ -753,15 +776,17 @@ class RmbgService:
async def cleanup(self): async def cleanup(self):
"""清理资源""" """清理资源"""
# 停止队列处理任务 # 停止所有队列处理 worker 任务
if self.queue_running: if self.queue_running:
self.queue_running = False self.queue_running = False
if self.queue_task: # 取消所有 worker 任务
self.queue_task.cancel() for task in self.queue_tasks:
try: if task:
await self.queue_task task.cancel()
except asyncio.CancelledError: # 等待所有任务完成取消
pass if self.queue_tasks:
await asyncio.gather(*self.queue_tasks, return_exceptions=True)
self.queue_tasks.clear()
# 处理队列中剩余的请求 # 处理队列中剩余的请求
remaining_items = [] remaining_items = []

View File

@ -35,7 +35,7 @@ class Settings(BaseSettings):
# 并发控制配置(推理侧) # 并发控制配置(推理侧)
max_workers: int = 60 # 线程池最大工作线程数根据CPU核心数调整22核44线程可设置20-30 max_workers: int = 60 # 线程池最大工作线程数根据CPU核心数调整22核44线程可设置20-30
batch_size: int = 16 # GPU批处理大小模型显存占用较大8是安全值16会导致OOM batch_size: int = 8 # GPU批处理大小模型显存占用较大8是安全值16会导致OOM
# 队列聚合配置方案B批处理+队列模式) # 队列聚合配置方案B批处理+队列模式)
batch_collect_interval: float = 0.05 # 批处理收集间隔50ms收集一次平衡延迟和吞吐量 batch_collect_interval: float = 0.05 # 批处理收集间隔50ms收集一次平衡延迟和吞吐量