引入 num_devices 缓存在模型加载时一次性确定 GPU 数量,并基于该值直接计算全局批大小(移除原先批次大小额外 ×2 的隐式放大逻辑,由配置与设备数共同控制整体批规模)。
This commit is contained in:
parent
fbcb614c73
commit
0757566f0e
@ -50,6 +50,8 @@ class RmbgService:
|
||||
# 单机多 GPU:维护模型和设备列表,兼容旧字段
|
||||
self.models = []
|
||||
self.devices = []
|
||||
# 设备数量缓存(GPU 数量,CPU 视作 1 个设备)
|
||||
self.num_devices = 1
|
||||
self.model = None
|
||||
self.device = None
|
||||
self._gpu_lock = Lock()
|
||||
@ -126,6 +128,9 @@ class RmbgService:
|
||||
self.device = self.devices[0]
|
||||
self.model = self.models[0]
|
||||
|
||||
# 缓存设备数量(用于根据 GPU 数量自动放大 batch)
|
||||
self.num_devices = max(1, len(self.devices))
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@ -329,16 +334,15 @@ class RmbgService:
|
||||
await asyncio.sleep(0.1) # 出错后短暂等待
|
||||
|
||||
async def _collect_batch_items(self):
|
||||
"""收集一批队列项,达到batch_size或超时后返回(支持跨用户合批)"""
|
||||
"""收集一批队列项,达到单卡batch_size×GPU数量或超时后返回(支持跨用户合批)"""
|
||||
batch_items = []
|
||||
batch_size = settings.batch_size
|
||||
per_device_batch = settings.batch_size # 单卡 batch_size
|
||||
device_count = max(1, getattr(self, "num_devices", len(self.devices) or 1))
|
||||
target_batch_size = per_device_batch * device_count # 本次期望的全局 batch 上限
|
||||
|
||||
collect_interval = settings.batch_collect_interval
|
||||
collect_timeout = settings.batch_collect_timeout
|
||||
|
||||
# 动态调整最大 batch 大小:高并发时允许更大的 batch(充分利用 GPU)
|
||||
# 如果队列中有很多待处理项,可以收集更多
|
||||
max_batch_size = batch_size * 2 # 最多收集 2 倍 batch_size
|
||||
|
||||
# 先尝试获取第一个请求(阻塞等待)
|
||||
try:
|
||||
first_item = await asyncio.wait_for(
|
||||
@ -350,27 +354,16 @@ class RmbgService:
|
||||
# 超时,返回空列表
|
||||
return []
|
||||
|
||||
# 继续收集更多请求,直到达到max_batch_size或超时
|
||||
# 继续收集更多请求,直到达到 target_batch_size 或超时
|
||||
start_time = time.time()
|
||||
|
||||
while len(batch_items) < max_batch_size:
|
||||
while len(batch_items) < target_batch_size:
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
# 如果已经超时,立即处理当前收集的请求
|
||||
if elapsed >= collect_timeout:
|
||||
break
|
||||
|
||||
# 如果已经达到最小 batch_size,且队列中没有更多项,可以提前处理
|
||||
if len(batch_items) >= batch_size:
|
||||
# 尝试非阻塞获取,如果没有立即返回,就处理当前批次
|
||||
try:
|
||||
item = self.queue.get_nowait()
|
||||
batch_items.append(item)
|
||||
continue
|
||||
except asyncio.QueueEmpty:
|
||||
# 队列为空,处理当前批次
|
||||
break
|
||||
|
||||
# 尝试在剩余时间内获取更多请求
|
||||
remaining_time = min(collect_interval, collect_timeout - elapsed)
|
||||
|
||||
|
||||
@ -34,8 +34,8 @@ class Settings(BaseSettings):
|
||||
http_max_keepalive_connections: int = 100 # httpx 最大 keep-alive 空闲连接数
|
||||
|
||||
# 并发控制配置(推理侧)
|
||||
max_workers: int = 30 # 线程池最大工作线程数(根据CPU核心数调整,22核44线程可设置20-30)
|
||||
batch_size: int = 8 # GPU批处理大小(模型显存占用较大,8是安全值,16会导致OOM)
|
||||
max_workers: int = 60  # 线程池最大工作线程数(本次由 30 上调至 60;原建议值 20-30 基于 22核44线程的经验,上调后需压测确认 CPU 调度开销可接受)
|
||||
batch_size: int = 16  # 单卡 GPU 批处理大小(由 8 上调至 16;旧注释称 16 会 OOM,需在目标显卡上实测显存占用确认不会溢出)
|
||||
|
||||
# 队列聚合配置(方案B:批处理+队列模式)
|
||||
batch_collect_interval: float = 0.05 # 批处理收集间隔(秒),50ms收集一次,平衡延迟和吞吐量
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user