引入 num_devices 缓存在模型加载时一次性确定 GPU 数量,并基于该值直接计算全局批大小(移除原先批次大小额外 ×2 的隐式放大逻辑,由配置与设备数共同控制整体批规模)。
This commit is contained in:
parent
fbcb614c73
commit
0757566f0e
@ -50,6 +50,8 @@ class RmbgService:
|
|||||||
# 单机多 GPU:维护模型和设备列表,兼容旧字段
|
# 单机多 GPU:维护模型和设备列表,兼容旧字段
|
||||||
self.models = []
|
self.models = []
|
||||||
self.devices = []
|
self.devices = []
|
||||||
|
# 设备数量缓存(GPU 数量,CPU 视作 1 个设备)
|
||||||
|
self.num_devices = 1
|
||||||
self.model = None
|
self.model = None
|
||||||
self.device = None
|
self.device = None
|
||||||
self._gpu_lock = Lock()
|
self._gpu_lock = Lock()
|
||||||
@ -126,6 +128,9 @@ class RmbgService:
|
|||||||
self.device = self.devices[0]
|
self.device = self.devices[0]
|
||||||
self.model = self.models[0]
|
self.model = self.models[0]
|
||||||
|
|
||||||
|
# 缓存设备数量(用于根据 GPU 数量自动放大 batch)
|
||||||
|
self.num_devices = max(1, len(self.devices))
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
@ -329,16 +334,15 @@ class RmbgService:
|
|||||||
await asyncio.sleep(0.1) # 出错后短暂等待
|
await asyncio.sleep(0.1) # 出错后短暂等待
|
||||||
|
|
||||||
async def _collect_batch_items(self):
|
async def _collect_batch_items(self):
|
||||||
"""收集一批队列项,达到batch_size或超时后返回(支持跨用户合批)"""
|
"""收集一批队列项,达到单卡batch_size×GPU数量或超时后返回(支持跨用户合批)"""
|
||||||
batch_items = []
|
batch_items = []
|
||||||
batch_size = settings.batch_size
|
per_device_batch = settings.batch_size # 单卡 batch_size
|
||||||
|
device_count = max(1, getattr(self, "num_devices", len(self.devices) or 1))
|
||||||
|
target_batch_size = per_device_batch * device_count # 本次期望的全局 batch 上限
|
||||||
|
|
||||||
collect_interval = settings.batch_collect_interval
|
collect_interval = settings.batch_collect_interval
|
||||||
collect_timeout = settings.batch_collect_timeout
|
collect_timeout = settings.batch_collect_timeout
|
||||||
|
|
||||||
# 动态调整最大 batch 大小:高并发时允许更大的 batch(充分利用 GPU)
|
|
||||||
# 如果队列中有很多待处理项,可以收集更多
|
|
||||||
max_batch_size = batch_size * 2 # 最多收集 2 倍 batch_size
|
|
||||||
|
|
||||||
# 先尝试获取第一个请求(阻塞等待)
|
# 先尝试获取第一个请求(阻塞等待)
|
||||||
try:
|
try:
|
||||||
first_item = await asyncio.wait_for(
|
first_item = await asyncio.wait_for(
|
||||||
@ -350,27 +354,16 @@ class RmbgService:
|
|||||||
# 超时,返回空列表
|
# 超时,返回空列表
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# 继续收集更多请求,直到达到max_batch_size或超时
|
# 继续收集更多请求,直到达到 target_batch_size 或超时
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
while len(batch_items) < max_batch_size:
|
while len(batch_items) < target_batch_size:
|
||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
|
|
||||||
# 如果已经超时,立即处理当前收集的请求
|
# 如果已经超时,立即处理当前收集的请求
|
||||||
if elapsed >= collect_timeout:
|
if elapsed >= collect_timeout:
|
||||||
break
|
break
|
||||||
|
|
||||||
# 如果已经达到最小 batch_size,且队列中没有更多项,可以提前处理
|
|
||||||
if len(batch_items) >= batch_size:
|
|
||||||
# 尝试非阻塞获取,如果没有立即返回,就处理当前批次
|
|
||||||
try:
|
|
||||||
item = self.queue.get_nowait()
|
|
||||||
batch_items.append(item)
|
|
||||||
continue
|
|
||||||
except asyncio.QueueEmpty:
|
|
||||||
# 队列为空,处理当前批次
|
|
||||||
break
|
|
||||||
|
|
||||||
# 尝试在剩余时间内获取更多请求
|
# 尝试在剩余时间内获取更多请求
|
||||||
remaining_time = min(collect_interval, collect_timeout - elapsed)
|
remaining_time = min(collect_interval, collect_timeout - elapsed)
|
||||||
|
|
||||||
|
|||||||
@ -34,8 +34,8 @@ class Settings(BaseSettings):
|
|||||||
http_max_keepalive_connections: int = 100 # httpx 最大 keep-alive 空闲连接数
|
http_max_keepalive_connections: int = 100 # httpx 最大 keep-alive 空闲连接数
|
||||||
|
|
||||||
# 并发控制配置(推理侧)
|
# 并发控制配置(推理侧)
|
||||||
max_workers: int = 30 # 线程池最大工作线程数(根据CPU核心数调整,22核44线程可设置20-30)
|
max_workers: int = 60 # 线程池最大工作线程数(根据CPU核心数调整,22核44线程下可设置40-60)
|
||||||
batch_size: int = 8 # GPU批处理大小(模型显存占用较大,8是安全值,16会导致OOM)
|
batch_size: int = 16 # 单卡GPU批处理大小(全局batch = 该值 × GPU数量;需按显存实测调整,单卡显存紧张时可回退到8)
|
||||||
|
|
||||||
# 队列聚合配置(方案B:批处理+队列模式)
|
# 队列聚合配置(方案B:批处理+队列模式)
|
||||||
batch_collect_interval: float = 0.05 # 批处理收集间隔(秒),50ms收集一次,平衡延迟和吞吐量
|
batch_collect_interval: float = 0.05 # 批处理收集间隔(秒),50ms收集一次,平衡延迟和吞吐量
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user