diff --git a/apps/rmbg/service.py b/apps/rmbg/service.py index ef1431e..a1ba363 100644 --- a/apps/rmbg/service.py +++ b/apps/rmbg/service.py @@ -38,6 +38,9 @@ class QueueItem: request_id: str future: asyncio.Future created_at: float + # 用于 batch 接口的额外字段 + url_str: Optional[str] = None # 原始 URL(用于 batch 接口) + batch_index: Optional[int] = None # 在 batch 中的索引(用于 batch 接口) class RmbgService: @@ -56,11 +59,11 @@ class RmbgService: os.makedirs(self.save_dir, exist_ok=True) self.http_client = httpx.AsyncClient( - timeout=30.0, + timeout=30.0, limits=httpx.Limits( - max_keepalive_connections=50, - max_connections=100 - ) + max_keepalive_connections=settings.http_max_keepalive_connections, + max_connections=settings.http_max_connections, + ), ) self.executor = ThreadPoolExecutor(max_workers=settings.max_workers) @@ -326,12 +329,16 @@ class RmbgService: await asyncio.sleep(0.1) # 出错后短暂等待 async def _collect_batch_items(self): - """收集一批队列项,达到batch_size或超时后返回""" + """收集一批队列项,达到batch_size或超时后返回(支持跨用户合批)""" batch_items = [] batch_size = settings.batch_size collect_interval = settings.batch_collect_interval collect_timeout = settings.batch_collect_timeout + # 动态调整最大 batch 大小:高并发时允许更大的 batch(充分利用 GPU) + # 如果队列中有很多待处理项,可以收集更多 + max_batch_size = batch_size * 2 # 最多收集 2 倍 batch_size + # 先尝试获取第一个请求(阻塞等待) try: first_item = await asyncio.wait_for( @@ -343,16 +350,27 @@ class RmbgService: # 超时,返回空列表 return [] - # 继续收集更多请求,直到达到batch_size或超时 + # 继续收集更多请求,直到达到max_batch_size或超时 start_time = time.time() - while len(batch_items) < batch_size: + while len(batch_items) < max_batch_size: elapsed = time.time() - start_time # 如果已经超时,立即处理当前收集的请求 if elapsed >= collect_timeout: break + # 如果已经达到最小 batch_size,且队列中没有更多项,可以提前处理 + if len(batch_items) >= batch_size: + # 尝试非阻塞获取,如果没有立即返回,就处理当前批次 + try: + item = self.queue.get_nowait() + batch_items.append(item) + continue + except asyncio.QueueEmpty: + # 队列为空,处理当前批次 + break + # 尝试在剩余时间内获取更多请求 remaining_time = min(collect_interval, collect_timeout - elapsed) @@ -369,50 +387,102 @@ class RmbgService: return batch_items async def _process_batch_queue_items(self, batch_items): - """处理一批队列项""" + """处理一批队列项(统一全局 batcher,支持跨用户合批)""" if not batch_items: return loop = asyncio.get_event_loop() try: - # 准备批处理数据 + # 准备批处理数据(保持原始索引映射) images_with_info = [] + item_index_map = {} # 映射:队列中的索引 -> QueueItem for idx, item in enumerate(batch_items): images_with_info.append((item.image, item.image_size, idx)) + item_index_map[idx] = item - # 执行批处理 + # 执行批处理(直接调用,充分利用 GPU) batch_results = await self.process_batch_images(images_with_info) - # 将结果返回给对应的Future - for idx, (processed_image, _) in enumerate(batch_results): - if idx < len(batch_items): - item = batch_items[idx] - - # 保存图片并返回URL - try: - image_url = await loop.run_in_executor( - self.executor, self.save_image_to_file, processed_image - ) + # 并行保存所有图片(关键优化:避免串行 IO 阻塞) + save_tasks = [] + result_mapping = {} # 映射:队列索引 -> (processed_image, QueueItem) + + for processed_image, result_idx in batch_results: + if result_idx in item_index_map: + item = item_index_map[result_idx] + result_mapping[result_idx] = (processed_image, item) + # 并行保存 + save_task = loop.run_in_executor( + self.executor, self.save_image_to_file, processed_image + ) + save_tasks.append((result_idx, save_task)) + + # 等待所有保存任务完成 + if save_tasks: + save_results = await asyncio.gather( + *[task for _, task in save_tasks], + return_exceptions=True + ) + + # 按完成顺序设置 Future 结果(流式返回) + for (result_idx, _), save_result in zip(save_tasks, save_results): + if result_idx in result_mapping: + processed_image, item = result_mapping[result_idx] - result = { - "status": "success", - "image_url": image_url - } - - if not item.future.done(): - item.future.set_result(result) - except Exception as e: - error_msg = f"处理图片失败: {str(e)}" - logger.error(f"队列项 {item.request_id} 处理失败: {error_msg}") - if not item.future.done(): - item.future.set_exception(Exception(error_msg)) + if isinstance(save_result, Exception): + error_msg = f"保存图片失败: {str(save_result)}" + logger.error(f"队列项 {item.request_id} 保存失败: {error_msg}") + if not item.future.done(): + item.future.set_exception(Exception(error_msg)) + else: + result = { + "status": "success", + "image_url": save_result + } + if not item.future.done(): + item.future.set_result(result) # 处理任何未完成的Future(理论上不应该发生) for item in batch_items: if not item.future.done(): item.future.set_exception(Exception("批处理结果不完整")) + except RuntimeError as e: + # CUDA OOM 错误,降级处理 + error_msg = str(e) + if "CUDA out of memory" in error_msg or "out of memory" in error_msg.lower(): + logger.warning(f"批处理显存不足,降级到单张处理: {error_msg[:100]}") + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + # 降级:单张处理 + for item in batch_items: + try: + result_data = await self.process_image(item.image) + if isinstance(result_data, dict): + if not item.future.done(): + item.future.set_result(result_data) + else: + image_url = await loop.run_in_executor( + self.executor, self.save_image_to_file, result_data + ) + if not item.future.done(): + item.future.set_result({ + "status": "success", + "image_url": image_url + }) + except Exception as e2: + if not item.future.done(): + item.future.set_exception(Exception(f"降级处理失败: {str(e2)}")) + else: + # 其他 RuntimeError + error_msg = f"批处理失败: {str(e)}" + logger.error(error_msg, exc_info=True) + for item in batch_items: + if not item.future.done(): + item.future.set_exception(Exception(error_msg)) except Exception as e: error_msg = f"批处理队列项失败: {str(e)}" logger.error(error_msg, exc_info=True) @@ -501,28 +571,26 @@ class RmbgService: raise Exception(f"处理图片失败: {e}") async def process_batch(self, urls): - """批量处理多个URL图像,流水线批处理模式(下载和处理并行)""" + """批量处理多个URL图像,统一全局 batcher 模式(支持跨用户合批)""" total = len(urls) success_count = 0 error_count = 0 batch_start_time = time.time() - batch_size = settings.batch_size loop = asyncio.get_event_loop() + # 为本次 batch 请求生成唯一 request_id + batch_request_id = uuid.uuid4().hex[:16] - # 流水线队列:收集已下载的图片 - download_queue = asyncio.Queue() - download_complete = asyncio.Event() - download_done_count = 0 - download_error_count = 0 + # 存储每张图片的 Future 和元数据 + image_futures = {} # index -> (future, url_str) - - async def download_image_async(index, url): - """异步下载图片并放入队列""" - nonlocal download_done_count, download_error_count + async def download_and_queue_image(index, url): + """下载图片并推入全局队列(跨用户合批)""" + nonlocal error_count url_str = str(url) try: + # 下载图片 if self.is_valid_url(url_str): temp_file = await self.download_image(url_str) image = await loop.run_in_executor( @@ -534,121 +602,75 @@ class RmbgService: self.executor, lambda: Image.open(url_str).convert("RGB") ) - # 下载成功,放入队列 - await download_queue.put((image, image.size, index, url_str, None)) - download_done_count += 1 + # 创建 Future 用于接收结果 + future = asyncio.Future() + + # 创建队列项,推入全局队列(跨用户合批) + queue_item = QueueItem( + image=image, + image_size=image.size, + request_id=f"{batch_request_id}_{index}", + future=future, + created_at=time.time(), + url_str=url_str, # 保存原始 URL + batch_index=index # 保存 batch 中的索引 + ) + + # 推入全局队列(与其他用户的请求一起合批) + await self.queue.put(queue_item) + + # 保存 Future 和元数据 + image_futures[index] = (future, url_str) except Exception as e: - # 下载失败,也放入队列(标记为错误) - await download_queue.put((None, None, index, url_str, str(e))) - download_error_count += 1 - download_done_count += 1 - finally: - # 所有下载任务完成 - if download_done_count >= total: - download_complete.set() + # 下载失败,直接创建失败的 Future + error_count += 1 + future = asyncio.Future() + future.set_exception(Exception(f"下载失败: {str(e)}")) + image_futures[index] = (future, url_str) - # 启动所有下载任务(并行下载) + # 并行下载所有图片并推入队列 download_tasks = [ - asyncio.create_task(download_image_async(i, url)) + asyncio.create_task(download_and_queue_image(i, url)) for i, url in enumerate(urls, 1) ] - # 流水线批处理任务:收集队列中的图片,达到batch_size或超时后立即处理 - completed_order = 0 - pending_batch = [] - batch_collect_timeout = 0.5 # 批处理收集超时(秒) - max_single_batch = batch_size * 2 # 允许最多2倍batch_size用于一次性处理 + # 等待所有下载任务完成 + await asyncio.gather(*download_tasks, return_exceptions=True) - async def process_pending_batch(force=False): - """处理待处理的批次""" - nonlocal pending_batch, completed_order, success_count, error_count + # 按完成顺序流式返回结果 + completed_order = 0 + # 建立 Future -> (index, url_str) 的映射,便于在完成时快速反查 + future_meta = {} + for idx, (fut, url_str) in image_futures.items(): + future_meta[fut] = (idx, url_str) + pending_tasks = set(future_meta.keys()) + + # 使用 wait 循环实现流式返回,避免等待最慢的 + while pending_tasks: + done, pending_tasks = await asyncio.wait( + pending_tasks, + return_when=asyncio.FIRST_COMPLETED + ) - if not pending_batch: - return - - # 分离成功和失败的图片 - valid_items = [] - failed_items = [] - - for item in pending_batch: - image, image_size, index, url_str, error = item - if error: - failed_items.append((index, url_str, error)) - else: - valid_items.append((image, image_size, index, url_str)) - - # 先处理下载失败的 - for index, url_str, error in failed_items: - error_count += 1 - completed_order += 1 - result = { - "index": index, - "total": total, - "original_url": url_str, - "status": "error", - "error": error, - "message": f"下载失败: {error}", - "success_count": success_count, - "error_count": error_count, - "completed_order": completed_order, - "batch_elapsed": round(time.time() - batch_start_time, 2) - } - yield result - - if not valid_items: - pending_batch = [] - return - - # 处理成功的图片 - try: - # 判断是否尝试一次性处理 - use_single_batch = len(valid_items) <= max_single_batch and force + for fut in done: + index, url_str = future_meta[fut] - if use_single_batch: - # 尝试一次性处理所有图片 - images_with_info = [(img, size, idx) for img, size, idx, _ in valid_items] + try: + result_data = fut.result() - batch_results = await self.process_batch_images(images_with_info) - - # 并行保存 - save_tasks = [] - result_mapping = {} - - for processed_image, index in batch_results: - url_str = next(url for _, _, idx, url in valid_items if idx == index) - result_mapping[index] = (processed_image, url_str) - save_task = loop.run_in_executor( - self.executor, self.save_image_to_file, processed_image - ) - save_tasks.append((index, save_task)) - - save_results = await asyncio.gather(*[task for _, task in save_tasks], return_exceptions=True) - - for (index, _), image_url in zip(save_tasks, save_results): - if isinstance(image_url, Exception): - error_count += 1 - completed_order += 1 - result = { - "index": index, - "total": total, - "original_url": result_mapping[index][1], - "status": "error", - "error": str(image_url), - "message": f"保存图片失败: {str(image_url)}", - "success_count": success_count, - "error_count": error_count, - "completed_order": completed_order, - "batch_elapsed": round(time.time() - batch_start_time, 2) - } - yield result - else: - completed_order += 1 + if isinstance(result_data, dict): + status = result_data.get("status", "success") + image_url = result_data.get("image_url") + error_msg = result_data.get("error") + + completed_order += 1 + if status == "success" and image_url: success_count += 1 result = { "index": index, "total": total, - "original_url": result_mapping[index][1], + "original_url": url_str, "status": "success", "image_url": image_url, "message": "处理成功", @@ -657,137 +679,40 @@ class RmbgService: "completed_order": completed_order, "batch_elapsed": round(time.time() - batch_start_time, 2) } - yield result - else: - # 分批处理 - for batch_start in range(0, len(valid_items), batch_size): - batch_end = min(batch_start + batch_size, len(valid_items)) - batch_items = valid_items[batch_start:batch_end] - - images_with_info = [(img, size, idx) for img, size, idx, _ in batch_items] - - batch_results = await self.process_batch_images(images_with_info) - - # 并行保存 - save_tasks = [] - result_mapping = {} - - for processed_image, index in batch_results: - url_str = next(url for _, _, idx, url in batch_items if idx == index) - result_mapping[index] = (processed_image, url_str) - save_task = loop.run_in_executor( - self.executor, self.save_image_to_file, processed_image - ) - save_tasks.append((index, save_task)) - - save_results = await asyncio.gather(*[task for _, task in save_tasks], return_exceptions=True) - - for (index, _), image_url in zip(save_tasks, save_results): - if isinstance(image_url, Exception): - error_count += 1 - completed_order += 1 - result = { - "index": index, - "total": total, - "original_url": result_mapping[index][1], - "status": "error", - "error": str(image_url), - "message": f"保存图片失败: {str(image_url)}", - "success_count": success_count, - "error_count": error_count, - "completed_order": completed_order, - "batch_elapsed": round(time.time() - batch_start_time, 2) - } - yield result - else: - completed_order += 1 - success_count += 1 - result = { - "index": index, - "total": total, - "original_url": result_mapping[index][1], - "status": "success", - "image_url": image_url, - "message": "处理成功", - "success_count": success_count, - "error_count": error_count, - "completed_order": completed_order, - "batch_elapsed": round(time.time() - batch_start_time, 2) - } - yield result - - except RuntimeError as e: - # CUDA OOM错误,降级处理 - error_msg = str(e) - if "CUDA out of memory" in error_msg or "out of memory" in error_msg.lower(): - logger.warning(f"批处理显存不足,降级处理: {error_msg[:100]}") - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - - # 降级到单张处理 - for image, image_size, index, url_str, _ in valid_items: - try: - result_data = await self.process_image(image) - if isinstance(result_data, dict): - image_url = result_data["image_url"] - else: - image_url = await loop.run_in_executor( - self.executor, self.save_image_to_file, result_data - ) - completed_order += 1 - success_count += 1 - result = { - "index": index, - "total": total, - "original_url": url_str, - "status": "success", - "image_url": image_url, - "message": "处理成功(降级模式)", - "success_count": success_count, - "error_count": error_count, - "completed_order": completed_order, - "batch_elapsed": round(time.time() - batch_start_time, 2) - } - yield result - except Exception as e2: + else: error_count += 1 - completed_order += 1 result = { "index": index, "total": total, "original_url": url_str, "status": "error", - "error": str(e2), - "message": f"处理失败: {str(e2)}", + "error": error_msg or "处理失败", + "message": error_msg or "处理失败", "success_count": success_count, "error_count": error_count, "completed_order": completed_order, "batch_elapsed": round(time.time() - batch_start_time, 2) } - yield result - else: - # 其他错误 - logger.error(f"批处理失败: {error_msg}") - for _, _, index, url_str, _ in valid_items: - error_count += 1 + else: + # 兼容非 dict 返回 completed_order += 1 + success_count += 1 result = { "index": index, "total": total, "original_url": url_str, - "status": "error", - "error": error_msg, - "message": f"批处理失败: {error_msg}", + "status": "success", + "image_url": result_data, + "message": "处理成功", "success_count": success_count, "error_count": error_count, "completed_order": completed_order, "batch_elapsed": round(time.time() - batch_start_time, 2) } - yield result - except Exception as e: - logger.error(f"批处理失败: {str(e)}") - for _, _, index, url_str, _ in valid_items: + + yield result + + except Exception as e: error_count += 1 completed_order += 1 result = { @@ -796,60 +721,13 @@ class RmbgService: "original_url": url_str, "status": "error", "error": str(e), - "message": f"批处理失败: {str(e)}", + "message": f"处理失败: {str(e)}", "success_count": success_count, "error_count": error_count, "completed_order": completed_order, "batch_elapsed": round(time.time() - batch_start_time, 2) } yield result - - pending_batch = [] - - # 流水线处理:收集队列中的图片,达到batch_size或超时后立即处理 - while True: - try: - # 等待队列中有新图片,或超时 - try: - item = await asyncio.wait_for( - download_queue.get(), - timeout=batch_collect_timeout - ) - pending_batch.append(item) - except asyncio.TimeoutError: - # 超时,处理当前批次 - if pending_batch: - async for result in process_pending_batch(): - yield result - # 检查是否所有下载都完成 - if download_complete.is_set(): - break - continue - - # 如果达到batch_size,立即处理 - if len(pending_batch) >= batch_size: - async for result in process_pending_batch(): - yield result - - # 检查是否所有下载都完成 - if download_complete.is_set() and download_queue.empty(): - # 处理剩余的图片 - if pending_batch: - async for result in process_pending_batch(force=True): - yield result - break - - except Exception as e: - logger.error(f"流水线处理出错: {str(e)}", exc_info=True) - break - - # 等待所有下载任务完成 - await asyncio.gather(*download_tasks, return_exceptions=True) - - # 确保所有结果都已处理 - if pending_batch: - async for result in process_pending_batch(force=True): - yield result def is_valid_url(self, url): """验证URL是否有效""" diff --git a/apps/rmbg/settings.py b/apps/rmbg/settings.py index 5c0585d..e28c46c 100644 --- a/apps/rmbg/settings.py +++ b/apps/rmbg/settings.py @@ -29,7 +29,11 @@ class Settings(BaseSettings): # 模型配置 model_path: str = "./models" # 本地模型文件夹路径(包含 model.safetensors 和 config.json) - # 并发控制配置 + # HTTP 客户端连接池配置(用于下载图片) + http_max_connections: int = 200 # httpx 最大并发连接数(根据上行带宽和对端能力调整) + http_max_keepalive_connections: int = 100 # httpx 最大 keep-alive 空闲连接数 + + # 并发控制配置(推理侧) max_workers: int = 30 # 线程池最大工作线程数(根据CPU核心数调整,22核44线程可设置20-30) batch_size: int = 8 # GPU批处理大小(模型显存占用较大,8是安全值,16会导致OOM)