引入 num_devices 缓存在模型加载时一次性确定 GPU 数量,并基于该值直接计算全局批大小(移除原先批次大小额外 ×2 的隐式放大逻辑,由配置与设备数共同控制整体批规模)。

This commit is contained in:
jingrow 2025-12-16 14:25:53 +00:00
parent fbcb614c73
commit 0757566f0e
2 changed files with 14 additions and 21 deletions

View File

@ -50,6 +50,8 @@ class RmbgService:
# 单机多 GPU维护模型和设备列表兼容旧字段
self.models = []
self.devices = []
# 设备数量缓存GPU 数量CPU 视作 1 个设备)
self.num_devices = 1
self.model = None
self.device = None
self._gpu_lock = Lock()
@ -126,6 +128,9 @@ class RmbgService:
self.device = self.devices[0]
self.model = self.models[0]
# 缓存设备数量(用于根据 GPU 数量自动放大 batch
self.num_devices = max(1, len(self.devices))
if torch.cuda.is_available():
torch.cuda.empty_cache()
@ -329,16 +334,15 @@ class RmbgService:
await asyncio.sleep(0.1) # 出错后短暂等待
async def _collect_batch_items(self):
"""收集一批队列项,达到batch_size或超时后返回支持跨用户合批"""
"""收集一批队列项,达到单卡batch_size×GPU数量或超时后返回(支持跨用户合批)"""
batch_items = []
batch_size = settings.batch_size
per_device_batch = settings.batch_size # 单卡 batch_size
device_count = max(1, getattr(self, "num_devices", len(self.devices) or 1))
target_batch_size = per_device_batch * device_count # 本次期望的全局 batch 上限
collect_interval = settings.batch_collect_interval
collect_timeout = settings.batch_collect_timeout
# 动态调整最大 batch 大小:高并发时允许更大的 batch充分利用 GPU
# 如果队列中有很多待处理项,可以收集更多
max_batch_size = batch_size * 2 # 最多收集 2 倍 batch_size
# 先尝试获取第一个请求(阻塞等待)
try:
first_item = await asyncio.wait_for(
@ -350,27 +354,16 @@ class RmbgService:
# 超时,返回空列表
return []
# 继续收集更多请求,直到达到max_batch_size或超时
# 继续收集更多请求,直到达到 target_batch_size 或超时
start_time = time.time()
while len(batch_items) < max_batch_size:
while len(batch_items) < target_batch_size:
elapsed = time.time() - start_time
# 如果已经超时,立即处理当前收集的请求
if elapsed >= collect_timeout:
break
# 如果已经达到最小 batch_size且队列中没有更多项可以提前处理
if len(batch_items) >= batch_size:
# 尝试非阻塞获取,如果没有立即返回,就处理当前批次
try:
item = self.queue.get_nowait()
batch_items.append(item)
continue
except asyncio.QueueEmpty:
# 队列为空,处理当前批次
break
# 尝试在剩余时间内获取更多请求
remaining_time = min(collect_interval, collect_timeout - elapsed)

View File

@ -34,8 +34,8 @@ class Settings(BaseSettings):
http_max_keepalive_connections: int = 100 # httpx 最大 keep-alive 空闲连接数
# 并发控制配置(推理侧)
max_workers: int = 30 # 线程池最大工作线程数根据CPU核心数调整22核44线程可设置20-30
batch_size: int = 8 # GPU批处理大小模型显存占用较大8是安全值16会导致OOM
max_workers: int = 60 # 线程池最大工作线程数根据CPU核心数调整22核44线程可设置20-30
batch_size: int = 16 # GPU批处理大小模型显存占用较大8是安全值16会导致OOM
# 队列聚合配置方案B批处理+队列模式)
batch_collect_interval: float = 0.05 # 批处理收集间隔50ms收集一次平衡延迟和吞吐量