From 0757566f0e3687865b623b304953282d9aa11de9 Mon Sep 17 00:00:00 2001
From: jingrow
Date: Tue, 16 Dec 2025 14:25:53 +0000
Subject: [PATCH] Cache num_devices once at model load and compute the global
 batch size from it (drop the old implicit x2 batch amplification; overall
 batch scale is now controlled jointly by config and device count)

---
 apps/rmbg/service.py  | 31 ++++++++++++-------------------
 apps/rmbg/settings.py |  4 ++--
 2 files changed, 14 insertions(+), 21 deletions(-)

diff --git a/apps/rmbg/service.py b/apps/rmbg/service.py
index a1ba363..afb4240 100644
--- a/apps/rmbg/service.py
+++ b/apps/rmbg/service.py
@@ -50,6 +50,8 @@ class RmbgService:
         # Single-node multi-GPU: keep model and device lists; the old single-model fields stay for compatibility
         self.models = []
         self.devices = []
+        # Cached device count (number of GPUs; CPU counts as one device)
+        self.num_devices = 1
         self.model = None
         self.device = None
         self._gpu_lock = Lock()
@@ -126,6 +128,9 @@ class RmbgService:
         self.device = self.devices[0]
         self.model = self.models[0]
 
+        # Cache the device count (used to scale the batch with the number of GPUs)
+        self.num_devices = max(1, len(self.devices))
+
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
 
@@ -329,16 +334,15 @@ class RmbgService:
                 await asyncio.sleep(0.1)  # brief pause after an error
 
     async def _collect_batch_items(self):
-        """Collect a batch of queue items; return once batch_size is reached or on timeout (supports cross-user batching)"""
+        """Collect a batch of queue items; return once per-GPU batch_size x GPU count is reached or on timeout (supports cross-user batching)"""
         batch_items = []
-        batch_size = settings.batch_size
+        per_device_batch = settings.batch_size  # per-GPU batch_size
+        device_count = max(1, getattr(self, "num_devices", len(self.devices) or 1))
+        target_batch_size = per_device_batch * device_count  # global batch cap for this round
+
         collect_interval = settings.batch_collect_interval
         collect_timeout = settings.batch_collect_timeout
 
-        # Dynamically raise the max batch size: allow larger batches under high concurrency (to keep the GPU fully utilized)
-        # If the queue holds many pending items, collect more of them
-        max_batch_size = batch_size * 2  # collect at most 2x batch_size
-
         # First try to fetch the initial request (blocking wait)
         try:
             first_item = await asyncio.wait_for(
@@ -350,27 +354,16 @@ class RmbgService:
             # Timed out; return an empty list
             return []
 
-        # Keep collecting requests until max_batch_size is reached or the timeout expires
+        # Keep collecting requests until target_batch_size is reached or the timeout expires
         start_time = time.time()
 
-        while len(batch_items) < max_batch_size:
+        while len(batch_items) < target_batch_size:
             elapsed = time.time() - start_time
 
             # If the timeout has expired, immediately process what has been collected
             if elapsed >= collect_timeout:
                 break
 
-            # If we already hold at least batch_size items and the queue has no more, process early
-            if len(batch_items) >= batch_size:
-                # Try a non-blocking get; if nothing is immediately available, process the current batch
-                try:
-                    item = self.queue.get_nowait()
-                    batch_items.append(item)
-                    continue
-                except asyncio.QueueEmpty:
-                    # Queue is empty; process the current batch
-                    break
-
             # Try to fetch more requests within the remaining time
             remaining_time = min(collect_interval, collect_timeout - elapsed)
 
diff --git a/apps/rmbg/settings.py b/apps/rmbg/settings.py
index e28c46c..057fef6 100644
--- a/apps/rmbg/settings.py
+++ b/apps/rmbg/settings.py
@@ -34,8 +34,8 @@ class Settings(BaseSettings):
     http_max_keepalive_connections: int = 100  # max idle keep-alive connections for httpx
 
     # Concurrency control (inference side)
-    max_workers: int = 30  # max worker threads in the thread pool (tune to CPU cores; 20-30 suits a 22-core/44-thread host)
-    batch_size: int = 8  # GPU batch size (the model's VRAM footprint is large; 8 is a safe value, 16 can cause OOM)
+    max_workers: int = 60  # max worker threads in the thread pool (tune to the CPU core count)
+    batch_size: int = 16  # per-GPU batch size; the global collection cap is batch_size x device count
 
     # Queue aggregation settings (plan B: batching + queue mode)
     batch_collect_interval: float = 0.05  # batch collection interval in seconds; polling every 50 ms balances latency and throughput
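
For reviewers, below is a minimal standalone sketch of the collection policy
this patch lands on: a global cap of batch_size x num_devices with a hard
collection deadline, and no implicit x2 headroom. The bare asyncio.Queue and
the constants are illustrative stand-ins for the service's real objects, and
COLLECT_TIMEOUT is an assumed value, since batch_collect_timeout's default is
not shown in this diff.

import asyncio
import time

# Stand-ins for apps/rmbg/settings.py values; BATCH_SIZE and COLLECT_INTERVAL
# mirror the patched defaults, while COLLECT_TIMEOUT is assumed.
BATCH_SIZE = 16          # per-GPU batch size
COLLECT_INTERVAL = 0.05  # per-poll wait while filling a batch
COLLECT_TIMEOUT = 0.2    # assumed overall collection deadline


async def collect_batch(queue: asyncio.Queue, num_devices: int) -> list:
    """Collect up to BATCH_SIZE * num_devices items, or whatever arrived
    before COLLECT_TIMEOUT elapsed."""
    target = BATCH_SIZE * max(1, num_devices)  # global cap, no extra x2
    try:
        # Block briefly for the first item; an empty list means an idle queue.
        items = [await asyncio.wait_for(queue.get(), timeout=1.0)]
    except asyncio.TimeoutError:
        return []

    start = time.time()
    while len(items) < target:
        elapsed = time.time() - start
        if elapsed >= COLLECT_TIMEOUT:
            break  # deadline hit: ship a partial batch rather than keep waiting
        remaining = min(COLLECT_INTERVAL, COLLECT_TIMEOUT - elapsed)
        try:
            items.append(await asyncio.wait_for(queue.get(), timeout=remaining))
        except asyncio.TimeoutError:
            continue  # nothing new this tick; the loop re-checks the deadline
    return items


async def _demo() -> None:
    # With two devices the cap is 16 * 2 = 32; five queued items come back as
    # a partial batch of five once the deadline expires.
    q: asyncio.Queue = asyncio.Queue()
    for i in range(5):
        q.put_nowait(i)
    print(await collect_batch(q, num_devices=2))  # [0, 1, 2, 3, 4]


if __name__ == "__main__":
    asyncio.run(_demo())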