diff --git a/apps/rmbg/service.py b/apps/rmbg/service.py
index a1ba363..afb4240 100644
--- a/apps/rmbg/service.py
+++ b/apps/rmbg/service.py
@@ -50,6 +50,8 @@ class RmbgService:
         # Single-node multi-GPU: keep lists of models and devices, stay compatible with the old single fields
         self.models = []
         self.devices = []
+        # Cached device count (number of GPUs; CPU counts as one device)
+        self.num_devices = 1
         self.model = None
         self.device = None
         self._gpu_lock = Lock()
@@ -126,6 +128,9 @@ class RmbgService:
         self.device = self.devices[0]
         self.model = self.models[0]
 
+        # Cache the device count (used to scale the batch up with the number of GPUs)
+        self.num_devices = max(1, len(self.devices))
+
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
 
@@ -329,16 +334,15 @@ class RmbgService:
                 await asyncio.sleep(0.1)  # brief pause after an error
 
     async def _collect_batch_items(self):
-        """Collect a batch of queue items; return once batch_size is reached or the timeout fires (requests from different users are merged)."""
+        """Collect a batch of queue items; return once per-device batch_size × GPU count is reached or the timeout fires (requests from different users are merged)."""
         batch_items = []
-        batch_size = settings.batch_size
+        per_device_batch = settings.batch_size  # per-device batch_size
+        device_count = max(1, getattr(self, "num_devices", len(self.devices) or 1))
+        target_batch_size = per_device_batch * device_count  # global batch cap for this round
+
        collect_interval = settings.batch_collect_interval
         collect_timeout = settings.batch_collect_timeout
 
-        # Dynamically adjust the max batch size: allow a larger batch under high concurrency (keep the GPU busy)
-        # If many items are already waiting in the queue, collect more of them
-        max_batch_size = batch_size * 2  # collect at most 2x batch_size
-
         # First, try to get one request (blocking wait)
         try:
             first_item = await asyncio.wait_for(
@@ -350,27 +354,16 @@ class RmbgService:
             # Timed out, return an empty list
             return []
 
-        # Keep collecting requests until max_batch_size is reached or the timeout fires
+        # Keep collecting requests until target_batch_size is reached or the timeout fires
         start_time = time.time()
 
-        while len(batch_items) < max_batch_size:
+        while len(batch_items) < target_batch_size:
             elapsed = time.time() - start_time
 
             # If the collection timeout has fired, process what has been collected right away
             if elapsed >= collect_timeout:
                 break
 
-            # If the minimum batch_size has been reached and the queue has nothing left, process early
-            if len(batch_items) >= batch_size:
-                # Try a non-blocking get; if nothing is available right away, process the current batch
-                try:
-                    item = self.queue.get_nowait()
-                    batch_items.append(item)
-                    continue
-                except asyncio.QueueEmpty:
-                    # Queue is empty, process the current batch
-                    break
-
             # Try to pick up more requests within the remaining time
             remaining_time = min(collect_interval, collect_timeout - elapsed)
 
diff --git a/apps/rmbg/settings.py b/apps/rmbg/settings.py
index e28c46c..057fef6 100644
--- a/apps/rmbg/settings.py
+++ b/apps/rmbg/settings.py
@@ -34,8 +34,8 @@ class Settings(BaseSettings):
     http_max_keepalive_connections: int = 100  # max idle keep-alive connections for httpx
 
     # Concurrency control (inference side)
-    max_workers: int = 30  # thread pool max worker threads (tune to CPU core count; 20-30 for a 22-core/44-thread machine)
-    batch_size: int = 8  # GPU batch size (the model's VRAM footprint is large; 8 is safe, 16 causes OOM)
+    max_workers: int = 60  # thread pool max worker threads (tune to CPU core count)
+    batch_size: int = 16  # per-GPU batch size (raise only if VRAM allows; too large a batch will OOM)
 
     # Queue aggregation (plan B: batching + queue mode)
     batch_collect_interval: float = 0.05  # batch collection interval (seconds); collect every 50ms to balance latency and throughput
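
With this change, one collection round aims for a global batch of `batch_size × num_devices` items instead of `2 × batch_size`, and the early-exit path for partially filled batches is gone: the loop now stops only on the size cap or the collection timeout. The sketch below restates that policy as a standalone coroutine, assuming an `asyncio.Queue` of work items; the names `per_device_batch`, `target_batch_size`, `collect_interval`, and `collect_timeout` mirror the diff, while the function signature and default values are illustrative only.

```python
# Minimal sketch of the new collection policy (not the service's actual code):
# gather up to per_device_batch * device_count items, but never wait longer
# than collect_timeout after the first item arrives.
import asyncio
import time


async def collect_batch(queue: asyncio.Queue,
                        per_device_batch: int,
                        device_count: int,
                        collect_interval: float = 0.05,
                        collect_timeout: float = 0.2) -> list:
    target_batch_size = per_device_batch * max(1, device_count)
    batch: list = []

    # Block for the first item so an idle service does not busy-loop.
    try:
        batch.append(await asyncio.wait_for(queue.get(), timeout=1.0))
    except asyncio.TimeoutError:
        return []

    start = time.time()
    while len(batch) < target_batch_size:
        elapsed = time.time() - start
        if elapsed >= collect_timeout:
            break  # hard deadline reached: ship whatever we have
        remaining = min(collect_interval, collect_timeout - elapsed)
        try:
            batch.append(await asyncio.wait_for(queue.get(), timeout=remaining))
        except asyncio.TimeoutError:
            continue  # nothing arrived in this slice; re-check the deadline
    return batch
```

Dropping the `get_nowait()` early exit means a lightly loaded round now waits out `collect_timeout` before dispatching, trading a little latency for fuller batches.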
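The diff only computes the larger global target; how the collected items are then spread over `self.models` / `self.devices` is not shown here. One plausible shape, purely as an assumption, is to chunk the global batch by the per-device size and run each chunk on its own GPU in a worker thread; `run_on_device` below is a hypothetical callable standing in for the service's per-device inference path.

```python
# Hypothetical fan-out of a collected global batch across multiple devices
# (an assumption, not code from the service).
import asyncio
from typing import Any, Callable, Sequence


async def dispatch_global_batch(items: Sequence[Any],
                                devices: Sequence[str],
                                per_device_batch: int,
                                run_on_device: Callable[[str, Sequence[Any]], Any]) -> list:
    # Split the global batch into per-device chunks of at most per_device_batch items.
    chunks = [items[i:i + per_device_batch]
              for i in range(0, len(items), per_device_batch)]
    tasks = []
    for idx, chunk in enumerate(chunks):
        device = devices[idx % len(devices)]  # round-robin over the available GPUs
        # asyncio.to_thread keeps the blocking inference call off the event loop.
        tasks.append(asyncio.to_thread(run_on_device, device, chunk))
    return await asyncio.gather(*tasks)
```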