diff --git a/apps/rmbg/app.py b/apps/rmbg/app.py index 4a1399d..2a7743e 100644 --- a/apps/rmbg/app.py +++ b/apps/rmbg/app.py @@ -1,6 +1,6 @@ from fastapi import FastAPI from settings import settings -from api import router +from api import router, service app = FastAPI( title="Remove Background", @@ -11,6 +11,19 @@ app = FastAPI( # 注册路由 app.include_router(router) + +@app.on_event("startup") +async def startup_event(): + """应用启动时初始化队列批处理机制""" + if settings.enable_queue_batch: + await service._start_queue_processor() + + +@app.on_event("shutdown") +async def shutdown_event(): + """应用关闭时清理资源""" + await service.cleanup() + if __name__ == "__main__": import uvicorn uvicorn.run( diff --git a/apps/rmbg/service.py b/apps/rmbg/service.py index 1608720..9ba22eb 100644 --- a/apps/rmbg/service.py +++ b/apps/rmbg/service.py @@ -14,6 +14,8 @@ import uuid import httpx import logging from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from typing import Optional, Dict, Any from settings import settings logging.basicConfig( @@ -26,6 +28,17 @@ warnings.filterwarnings("ignore", category=UserWarning) warnings.filterwarnings("ignore", category=FutureWarning) torch.set_float32_matmul_precision("high") + +@dataclass +class QueueItem: + """队列项数据结构""" + image: Image.Image + image_size: tuple + request_id: str + future: asyncio.Future + created_at: float + + class RmbgService: def __init__(self, model_path="zhengpeng7/BiRefNet"): """初始化背景移除服务""" @@ -44,17 +57,52 @@ class RmbgService: ) ) self.executor = ThreadPoolExecutor(max_workers=settings.max_workers) + + # 队列聚合机制(方案B) + self.queue: asyncio.Queue = asyncio.Queue() + self.queue_task: Optional[asyncio.Task] = None + self.queue_running = False + self._load_model() + # 队列任务将在 FastAPI startup 事件中启动 def _load_model(self): """加载模型""" self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.model = AutoModelForImageSegmentation.from_pretrained(self.model_path, trust_remote_code=True) - self.model = self.model.to(self.device) + + # 优化显存占用:使用半精度加载(如果支持) + # 注意:某些模型可能不支持半精度,需要测试 + try: + # 尝试使用半精度加载,可以减少约50%的显存占用 + self.model = AutoModelForImageSegmentation.from_pretrained( + self.model_path, + trust_remote_code=True, + torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32 + ) + self.model = self.model.to(self.device) + if torch.cuda.is_available(): + self.model = self.model.half() # 转换为半精度 + logger.info("模型已使用半精度(FP16)加载,显存占用减少约50%") + except Exception as e: + # 如果半精度加载失败,降级到全精度 + logger.warning(f"半精度加载失败,使用全精度: {str(e)}") + self.model = AutoModelForImageSegmentation.from_pretrained( + self.model_path, + trust_remote_code=True + ) + self.model = self.model.to(self.device) + self.model.eval() + + # 优化显存分配策略:减少碎片化 + if torch.cuda.is_available(): + # 设置显存分配器,减少碎片化 + os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True') + torch.cuda.empty_cache() + logger.info(f"模型加载完成,当前显存占用: {torch.cuda.memory_allocated() / 1024**3:.2f} GB") def _process_image_sync(self, image): - """同步处理图像,移除背景""" + """同步处理图像,移除背景(单张)""" image_size = image.size transform_image = transforms.Compose([ transforms.Resize((1024, 1024)), @@ -62,6 +110,9 @@ class RmbgService: transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ]) input_images = transform_image(image).unsqueeze(0).to(self.device) + # 如果模型是半精度,输入也转换为半精度 + if next(self.model.parameters()).dtype == torch.float16: + input_images = input_images.half() with torch.no_grad(): preds = self.model(input_images)[-1].sigmoid().cpu() @@ -77,11 +128,241 @@ class RmbgService: return image + def _process_batch_images_sync(self, images_with_info): + """批量处理图像(批处理模式,充分利用GPU并行能力)""" + if not images_with_info: + return [] + + transform_image = transforms.Compose([ + transforms.Resize((1024, 1024)), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ]) + + # 批处理前清理显存 + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + batch_tensors = [] + for image, image_size, index in images_with_info: + batch_tensors.append(transform_image(image)) + + input_batch = torch.stack(batch_tensors).to(self.device) + # 如果模型是半精度,输入也转换为半精度 + if next(self.model.parameters()).dtype == torch.float16: + input_batch = input_batch.half() + # 释放 batch_tensors 占用的 CPU 内存 + del batch_tensors + + with torch.no_grad(): + model_output = self.model(input_batch) + if isinstance(model_output, (list, tuple)): + preds = model_output[-1].sigmoid().cpu() + else: + preds = model_output.sigmoid().cpu() + + # 立即释放 GPU 上的 input_batch 和 model_output + del input_batch + if isinstance(model_output, (list, tuple)): + del model_output + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + results = [] + for i, (image, image_size, index) in enumerate(images_with_info): + if len(preds.shape) == 4: + pred = preds[i].squeeze() + elif len(preds.shape) == 3: + pred = preds[i] + else: + pred = preds[i].squeeze() + + pred_pil = transforms.ToPILImage()(pred) + mask = pred_pil.resize(image_size) + result_image = image.copy() + result_image.putalpha(mask) + results.append((result_image, index)) + + # 释放 preds + del preds + + # 批处理后再次清理显存 + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + return results + async def process_image(self, image): - """异步处理图像,移除背景""" - return await asyncio.get_event_loop().run_in_executor( - self.executor, self._process_image_sync, image + """异步处理图像,移除背景(单张)- 使用队列批处理模式""" + if settings.enable_queue_batch and self.queue_running: + return await self._process_image_via_queue(image) + else: + # 降级到单张处理 + return await asyncio.get_event_loop().run_in_executor( + self.executor, self._process_image_sync, image + ) + + async def _process_image_via_queue(self, image): + """通过队列批处理模式处理单张图像""" + request_id = uuid.uuid4().hex[:10] + future = asyncio.Future() + + queue_item = QueueItem( + image=image, + image_size=image.size, + request_id=request_id, + future=future, + created_at=time.time() ) + + try: + await self.queue.put(queue_item) + + # 等待处理结果,带超时 + try: + result = await asyncio.wait_for(future, timeout=settings.request_timeout) + return result + except asyncio.TimeoutError: + future.cancel() + raise Exception(f"处理超时(超过{settings.request_timeout}秒)") + + except Exception as e: + if not future.done(): + future.set_exception(e) + raise + + async def process_batch_images(self, images_with_info): + """异步批量处理图像(批处理模式)""" + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + self.executor, self._process_batch_images_sync, images_with_info + ) + + async def _start_queue_processor(self): + """启动队列批处理后台任务(异步方法,需要在事件循环中调用)""" + if self.queue_running: + return + + self.queue_running = True + self.queue_task = asyncio.create_task(self._queue_processor()) + logger.info("队列批处理机制已启动") + + async def _queue_processor(self): + """后台队列批处理任务(核心逻辑)""" + logger.info("队列批处理任务开始运行") + + while self.queue_running: + try: + # 收集一批请求 + batch_items = await self._collect_batch_items() + + if not batch_items: + continue + + # 处理这批请求 + await self._process_batch_queue_items(batch_items) + + except Exception as e: + logger.error(f"队列批处理任务出错: {str(e)}", exc_info=True) + await asyncio.sleep(0.1) # 出错后短暂等待 + + async def _collect_batch_items(self): + """收集一批队列项,达到batch_size或超时后返回""" + batch_items = [] + batch_size = settings.batch_size + collect_interval = settings.batch_collect_interval + collect_timeout = settings.batch_collect_timeout + + # 先尝试获取第一个请求(阻塞等待) + try: + first_item = await asyncio.wait_for( + self.queue.get(), + timeout=collect_timeout + ) + batch_items.append(first_item) + except asyncio.TimeoutError: + # 超时,返回空列表 + return [] + + # 继续收集更多请求,直到达到batch_size或超时 + start_time = time.time() + + while len(batch_items) < batch_size: + elapsed = time.time() - start_time + + # 如果已经超时,立即处理当前收集的请求 + if elapsed >= collect_timeout: + break + + # 尝试在剩余时间内获取更多请求 + remaining_time = min(collect_interval, collect_timeout - elapsed) + + try: + item = await asyncio.wait_for( + self.queue.get(), + timeout=remaining_time + ) + batch_items.append(item) + except asyncio.TimeoutError: + # 超时,处理已收集的请求 + break + + return batch_items + + async def _process_batch_queue_items(self, batch_items): + """处理一批队列项""" + if not batch_items: + return + + loop = asyncio.get_event_loop() + + try: + # 准备批处理数据 + images_with_info = [] + for idx, item in enumerate(batch_items): + images_with_info.append((item.image, item.image_size, idx)) + + # 执行批处理 + batch_results = await self.process_batch_images(images_with_info) + + # 将结果返回给对应的Future + for idx, (processed_image, _) in enumerate(batch_results): + if idx < len(batch_items): + item = batch_items[idx] + + # 保存图片并返回URL + try: + image_url = await loop.run_in_executor( + self.executor, self.save_image_to_file, processed_image + ) + + result = { + "status": "success", + "image_url": image_url + } + + if not item.future.done(): + item.future.set_result(result) + except Exception as e: + error_msg = f"处理图片失败: {str(e)}" + logger.error(f"队列项 {item.request_id} 处理失败: {error_msg}") + if not item.future.done(): + item.future.set_exception(Exception(error_msg)) + + # 处理任何未完成的Future(理论上不应该发生) + for item in batch_items: + if not item.future.done(): + item.future.set_exception(Exception("批处理结果不完整")) + + except Exception as e: + error_msg = f"批处理队列项失败: {str(e)}" + logger.error(error_msg, exc_info=True) + + # 所有请求都标记为失败 + for item in batch_items: + if not item.future.done(): + item.future.set_exception(Exception(error_msg)) def save_image_to_file(self, image): """保存图片到文件并返回URL""" @@ -165,15 +446,16 @@ class RmbgService: raise Exception(f"处理图片失败: {e}") async def process_batch(self, urls): - """批量处理多个URL图像,流水线并发模式""" + """批量处理多个URL图像,批处理模式(推荐方案)""" total = len(urls) success_count = 0 error_count = 0 batch_start_time = time.time() + batch_size = settings.batch_size loop = asyncio.get_event_loop() - async def download_and_process(index, url): - """下载并处理单张图片""" + async def download_image_async(index, url): + """异步下载图片""" url_str = str(url) try: if self.is_valid_url(url_str): @@ -186,53 +468,231 @@ class RmbgService: image = await loop.run_in_executor( self.executor, lambda: Image.open(url_str).convert("RGB") ) - - processed_image = await self.process_image(image) - image_url = await loop.run_in_executor( - self.executor, self.save_image_to_file, processed_image - ) - - return { - "index": index, - "total": total, - "original_url": url_str, - "status": "success", - "image_url": image_url, - "message": "处理成功" - } - + return (image, image.size, index, url_str, None) except Exception as e: - logger.error(f"处理失败 (index={index}): {str(e)}") - return { + return (None, None, index, url_str, str(e)) + + download_tasks = [download_image_async(i, url) for i, url in enumerate(urls, 1)] + downloaded_images = await asyncio.gather(*download_tasks) + + valid_images = [] + failed_results = {} + + for item in downloaded_images: + image, image_size, index, url_str, error = item + if error: + failed_results[index] = { "index": index, "total": total, "original_url": url_str, "status": "error", - "error": str(e), - "message": f"处理失败: {str(e)}" + "error": error, + "message": f"下载失败: {error}" } - - tasks = [ - download_and_process(i, url) - for i, url in enumerate(urls, 1) - ] - - completed_order = 0 - for coro in asyncio.as_completed(tasks): - result = await coro - completed_order += 1 - - if result["status"] == "success": - success_count += 1 else: - error_count += 1 - + valid_images.append((image, image_size, index, url_str)) + + for index, result in failed_results.items(): + error_count += 1 result["success_count"] = success_count result["error_count"] = error_count - result["completed_order"] = completed_order + result["completed_order"] = len(failed_results) result["batch_elapsed"] = round(time.time() - batch_start_time, 2) - yield result + + completed_order = len(failed_results) + + # 如果图片数量不太多(<= batch_size * 2),尝试一次性处理所有图片(避免分批,提升并发) + # 对于13张图片,batch_size=8,13 <= 16,会尝试一次性处理 + # 如果显存不足,自动降级到分批处理 + max_single_batch = batch_size * 2 # 允许最多2倍batch_size + use_single_batch = len(valid_images) <= max_single_batch + + if use_single_batch: + try: + images_with_info = [(img, size, idx) for img, size, idx, _ in valid_images] + batch_results = await self.process_batch_images(images_with_info) + + # 并行保存所有图片 + save_tasks = [] + result_mapping = {} + + for processed_image, index in batch_results: + url_str = next(url for _, _, idx, url in valid_images if idx == index) + result_mapping[index] = (processed_image, url_str) + + save_task = loop.run_in_executor( + self.executor, self.save_image_to_file, processed_image + ) + save_tasks.append((index, save_task)) + + save_results = await asyncio.gather(*[task for _, task in save_tasks], return_exceptions=True) + + for (index, _), image_url in zip(save_tasks, save_results): + if isinstance(image_url, Exception): + error_count += 1 + completed_order += 1 + result = { + "index": index, + "total": total, + "original_url": result_mapping[index][1], + "status": "error", + "error": str(image_url), + "message": f"保存图片失败: {str(image_url)}", + "success_count": success_count, + "error_count": error_count, + "completed_order": completed_order, + "batch_elapsed": round(time.time() - batch_start_time, 2) + } + yield result + else: + completed_order += 1 + success_count += 1 + result = { + "index": index, + "total": total, + "original_url": result_mapping[index][1], + "status": "success", + "image_url": image_url, + "message": "处理成功", + "success_count": success_count, + "error_count": error_count, + "completed_order": completed_order, + "batch_elapsed": round(time.time() - batch_start_time, 2) + } + yield result + except RuntimeError as e: + # 检测CUDA OOM错误,自动降级到分批处理 + error_msg = str(e) + if "CUDA out of memory" in error_msg or "out of memory" in error_msg.lower(): + logger.warning(f"一次性处理显存不足,自动降级到分批处理: {error_msg[:100]}") + # 清理显存 + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + # 继续执行分批处理逻辑(不return,继续到else分支) + use_single_batch = False + else: + # 其他错误,直接返回 + logger.error(f"批处理失败: {error_msg}") + for _, _, index, url_str in valid_images: + completed_order += 1 + error_count += 1 + result = { + "index": index, + "total": total, + "original_url": url_str, + "status": "error", + "error": error_msg, + "message": f"批处理失败: {error_msg}", + "success_count": success_count, + "error_count": error_count, + "completed_order": completed_order, + "batch_elapsed": round(time.time() - batch_start_time, 2) + } + yield result + return + except Exception as e: + # 其他异常,直接返回错误 + logger.error(f"批处理失败: {str(e)}") + for _, _, index, url_str in valid_images: + completed_order += 1 + error_count += 1 + result = { + "index": index, + "total": total, + "original_url": url_str, + "status": "error", + "error": str(e), + "message": f"批处理失败: {str(e)}", + "success_count": success_count, + "error_count": error_count, + "completed_order": completed_order, + "batch_elapsed": round(time.time() - batch_start_time, 2) + } + yield result + return + + # 如果一次性处理失败(显存不足)或图片数量太多,使用分批处理 + if not use_single_batch: + # 多批处理:串行处理批次,但每个批次内部并行保存 + for batch_start in range(0, len(valid_images), batch_size): + batch_end = min(batch_start + batch_size, len(valid_images)) + batch_images = valid_images[batch_start:batch_end] + + try: + images_with_info = [(img, size, idx) for img, size, idx, _ in batch_images] + batch_results = await self.process_batch_images(images_with_info) + + # 并行保存所有图片 + save_tasks = [] + result_mapping = {} + + for processed_image, index in batch_results: + url_str = next(url for _, _, idx, url in batch_images if idx == index) + result_mapping[index] = (processed_image, url_str) + + save_task = loop.run_in_executor( + self.executor, self.save_image_to_file, processed_image + ) + save_tasks.append((index, save_task)) + + # 并行执行所有保存任务 + save_results = await asyncio.gather(*[task for _, task in save_tasks], return_exceptions=True) + + # 按顺序返回结果 + for (index, _), image_url in zip(save_tasks, save_results): + if isinstance(image_url, Exception): + error_count += 1 + completed_order += 1 + result = { + "index": index, + "total": total, + "original_url": result_mapping[index][1], + "status": "error", + "error": str(image_url), + "message": f"保存图片失败: {str(image_url)}", + "success_count": success_count, + "error_count": error_count, + "completed_order": completed_order, + "batch_elapsed": round(time.time() - batch_start_time, 2) + } + yield result + else: + completed_order += 1 + success_count += 1 + result = { + "index": index, + "total": total, + "original_url": result_mapping[index][1], + "status": "success", + "image_url": image_url, + "message": "处理成功", + "success_count": success_count, + "error_count": error_count, + "completed_order": completed_order, + "batch_elapsed": round(time.time() - batch_start_time, 2) + } + yield result + + except Exception as e: + logger.error(f"批处理失败: {str(e)}") + for _, _, index, url_str in batch_images: + completed_order += 1 + error_count += 1 + result = { + "index": index, + "total": total, + "original_url": url_str, + "status": "error", + "error": str(e), + "message": f"批处理失败: {str(e)}", + "success_count": success_count, + "error_count": error_count, + "completed_order": completed_order, + "batch_elapsed": round(time.time() - batch_start_time, 2) + } + yield result def is_valid_url(self, url): """验证URL是否有效""" @@ -265,6 +725,31 @@ class RmbgService: async def cleanup(self): """清理资源""" + # 停止队列处理任务 + if self.queue_running: + self.queue_running = False + if self.queue_task: + self.queue_task.cancel() + try: + await self.queue_task + except asyncio.CancelledError: + pass + logger.info("队列批处理机制已停止") + + # 处理队列中剩余的请求 + remaining_items = [] + while not self.queue.empty(): + try: + item = self.queue.get_nowait() + remaining_items.append(item) + except asyncio.QueueEmpty: + break + + # 标记剩余请求为失败 + for item in remaining_items: + if not item.future.done(): + item.future.set_exception(Exception("服务关闭,请求被取消")) + await self.http_client.aclose() self.executor.shutdown(wait=True) if torch.cuda.is_available(): diff --git a/apps/rmbg/settings.py b/apps/rmbg/settings.py index e573c8d..403ab85 100644 --- a/apps/rmbg/settings.py +++ b/apps/rmbg/settings.py @@ -28,6 +28,13 @@ class Settings(BaseSettings): # 并发控制配置 max_workers: int = 30 # 线程池最大工作线程数(根据CPU核心数调整,22核44线程可设置20-30) + batch_size: int = 8 # GPU批处理大小(BiRefNet模型显存占用较大,8是安全值,16会导致OOM) + + # 队列聚合配置(方案B:批处理+队列模式) + batch_collect_interval: float = 0.05 # 批处理收集间隔(秒),50ms收集一次,平衡延迟和吞吐量 + batch_collect_timeout: float = 0.5 # 批处理收集超时(秒),即使未满batch_size,500ms后也处理 + request_timeout: float = 30.0 # 单个请求超时时间(秒) + enable_queue_batch: bool = True # 是否启用队列批处理模式(推荐开启) class Config: env_file = ".env"