diff --git a/apps/rmbg/service.py b/apps/rmbg/service.py index 9127e4e..2d5fcd6 100644 --- a/apps/rmbg/service.py +++ b/apps/rmbg/service.py @@ -13,13 +13,18 @@ import asyncio import io import uuid import httpx +import logging +from concurrent.futures import ThreadPoolExecutor from settings import settings -# 关闭不必要的警告 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + warnings.filterwarnings("ignore", category=UserWarning) warnings.filterwarnings("ignore", category=FutureWarning) - -# 设置torch精度 torch.set_float32_matmul_precision("high") class RmbgService: @@ -30,25 +35,43 @@ class RmbgService: self.device = None self.save_dir = settings.save_dir self.download_url = settings.download_url - # 确保保存目录存在 os.makedirs(self.save_dir, exist_ok=True) - # 创建异步HTTP客户端(复用连接,提高性能) - self.http_client = httpx.AsyncClient(timeout=30.0, limits=httpx.Limits(max_keepalive_connections=20)) + + self.http_client = httpx.AsyncClient( + timeout=30.0, + limits=httpx.Limits( + max_keepalive_connections=50, + max_connections=100 + ) + ) + self.executor = ThreadPoolExecutor(max_workers=settings.max_workers) + self._gpu_semaphore = None + self._max_gpu_concurrent = settings.max_gpu_concurrent self._load_model() + + @property + def gpu_semaphore(self): + """延迟初始化GPU信号量""" + if self._gpu_semaphore is None: + if self._max_gpu_concurrent == 0: + return None + try: + loop = asyncio.get_event_loop() + self._gpu_semaphore = asyncio.Semaphore(self._max_gpu_concurrent) + except RuntimeError: + self._gpu_semaphore = asyncio.Semaphore(self._max_gpu_concurrent) + return self._gpu_semaphore def _load_model(self): """加载模型""" - # 设置设备 self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - t0 = time.time() self.model = AutoModelForImageSegmentation.from_pretrained(self.model_path, trust_remote_code=True) self.model = self.model.to(self.device) - self.model.eval() # 设置为评估模式 + self.model.eval() def _process_image_sync(self, image): - """同步处理图像,移除背景(内部方法,在线程池中执行)""" + """同步处理图像,移除背景""" image_size = image.size - # 转换图像 transform_image = transforms.Compose([ transforms.Resize((1024, 1024)), transforms.ToTensor(), @@ -56,30 +79,103 @@ class RmbgService: ]) input_images = transform_image(image).unsqueeze(0).to(self.device) - # 推理 with torch.no_grad(): preds = self.model(input_images)[-1].sigmoid().cpu() - # 处理预测结果 pred = preds[0].squeeze() pred_pil = transforms.ToPILImage()(pred) mask = pred_pil.resize(image_size) - - # 添加透明通道 image.putalpha(mask) - # 清理显存 if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect() return image + def _process_batch_images_sync(self, images_with_info): + """ + 批量处理图像(充分利用GPU并行能力) + + Args: + images_with_info: [(image, image_size, index), ...] 图像和元信息列表 + + Returns: + [(processed_image, index), ...] 处理后的图像和索引 + """ + if not images_with_info: + return [] + + try: + transform_image = transforms.Compose([ + transforms.Resize((1024, 1024)), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ]) + + batch_images = [] + for image, image_size, index in images_with_info: + try: + transformed = transform_image(image) + batch_images.append(transformed) + except Exception as e: + logger.error(f"图片转换失败 (index={index}): {str(e)}") + raise + + if not batch_images: + raise Exception("没有有效的图片可以处理") + + input_batch = torch.stack(batch_images).to(self.device) + + with torch.no_grad(): + model_output = self.model(input_batch) + if isinstance(model_output, (list, tuple)): + preds = model_output[-1].sigmoid().cpu() + else: + preds = model_output.sigmoid().cpu() + + results = [] + for i, (image, image_size, index) in enumerate(images_with_info): + try: + if len(preds.shape) == 4: + pred = preds[i].squeeze() + elif len(preds.shape) == 3: + pred = preds[i] + else: + pred = preds[i].squeeze() + + pred_pil = transforms.ToPILImage()(pred) + mask = pred_pil.resize(image_size) + result_image = image.copy() + result_image.putalpha(mask) + results.append((result_image, index)) + except Exception as e: + logger.error(f"处理预测结果失败 (index={index}): {str(e)}") + raise + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + return results + + except Exception as e: + logger.error(f"批处理失败: {str(e)}") + raise + async def process_image(self, image): - """异步处理图像,移除背景(在线程池中执行同步操作)""" - # 将同步的GPU操作放到线程池中执行,避免阻塞事件循环 + """异步处理图像,移除背景""" loop = asyncio.get_event_loop() - return await loop.run_in_executor(None, self._process_image_sync, image) + return await loop.run_in_executor(self.executor, self._process_image_sync, image) + + async def process_batch_images(self, images_with_info): + """异步批量处理图像""" + loop = asyncio.get_event_loop() + return await loop.run_in_executor( + self.executor, + self._process_batch_images_sync, + images_with_info + ) def image_to_base64(self, image): """将PIL Image对象转换为base64字符串""" @@ -88,23 +184,10 @@ class RmbgService: return base64.b64encode(buffered.getvalue()).decode('utf-8') def save_image_to_file(self, image): - """ - 保存图片到jfile/files目录并返回URL - - Args: - image: PIL Image对象 - - Returns: - 图片URL - """ - # 生成唯一文件名 + """保存图片到文件并返回URL""" filename = f"rmbg_{uuid.uuid4().hex[:10]}.png" file_path = os.path.join(self.save_dir, filename) - - # 保存图片 image.save(file_path, format="PNG") - - # 构建URL image_url = f"{self.download_url}/{filename}" return image_url @@ -120,32 +203,26 @@ class RmbgService: """ temp_file = None try: - # 检查是否是URL if self.is_valid_url(image_path): try: - # 异步下载图片到临时文件 temp_file = await self.download_image(image_path) image_path = temp_file except Exception as e: raise Exception(f"下载图片失败: {e}") - # 验证输入文件是否存在 if not os.path.exists(image_path): raise FileNotFoundError(f"输入图像不存在: {image_path}") - # 加载图像(IO操作,在线程池中执行) loop = asyncio.get_event_loop() image = await loop.run_in_executor( - None, + self.executor, lambda: Image.open(image_path).convert("RGB") ) - # 异步处理图像 image_no_bg = await self.process_image(image) - # 保存图片到文件并获取URL(IO操作,在线程池中执行) image_url = await loop.run_in_executor( - None, + self.executor, self.save_image_to_file, image_no_bg ) @@ -156,7 +233,6 @@ class RmbgService: } finally: - # 清理临时文件 if temp_file and os.path.exists(temp_file): try: os.unlink(temp_file) @@ -174,19 +250,16 @@ class RmbgService: 处理后的图像内容 """ try: - # 从文件内容创建PIL Image对象(IO操作,在线程池中执行) loop = asyncio.get_event_loop() image = await loop.run_in_executor( - None, + self.executor, lambda: Image.open(io.BytesIO(file_content)).convert("RGB") ) - # 异步处理图像 image_no_bg = await self.process_image(image) - # 保存图片到文件并获取URL(IO操作,在线程池中执行) image_url = await loop.run_in_executor( - None, + self.executor, self.save_image_to_file, image_no_bg ) @@ -200,60 +273,77 @@ class RmbgService: raise Exception(f"处理图片失败: {e}") async def process_batch(self, urls): - """ - 批量处理多个URL图像,并发处理并流式返回结果 - - Args: - urls: 图片URL列表 - - Yields: - 每个图片的处理结果(按完成顺序返回) - """ + """批量处理多个URL图像,流水线并发模式""" total = len(urls) success_count = 0 error_count = 0 + batch_start_time = time.time() - # 创建并发任务 - async def process_single_url(index, url): - """处理单个URL的包装函数""" + async def download_and_process(index, url): + """下载并处理单张图片""" + url_str = str(url) try: - url_str = str(url) - result = await self.remove_background(url_str) + if self.is_valid_url(url_str): + temp_file = await self.download_image(url_str) + image = await asyncio.get_event_loop().run_in_executor( + self.executor, + lambda: Image.open(temp_file).convert("RGB") + ) + os.unlink(temp_file) + else: + image = await asyncio.get_event_loop().run_in_executor( + self.executor, + lambda: Image.open(url_str).convert("RGB") + ) + + processed_image = await self.process_image(image) + + loop = asyncio.get_event_loop() + image_url = await loop.run_in_executor( + self.executor, + self.save_image_to_file, + processed_image + ) + return { "index": index, "total": total, "original_url": url_str, "status": "success", - "image_url": result["image_url"], + "image_url": image_url, "message": "处理成功" } + except Exception as e: + logger.error(f"处理失败 (index={index}): {str(e)}") return { "index": index, "total": total, - "original_url": str(url), + "original_url": url_str, "status": "error", "error": str(e), "message": f"处理失败: {str(e)}" } - # 创建所有任务 tasks = [ - process_single_url(i, url) + download_and_process(i, url) for i, url in enumerate(urls, 1) ] - # 并发执行所有任务,使用as_completed按完成顺序返回 + completed_order = 0 for coro in asyncio.as_completed(tasks): result = await coro + completed_order += 1 + if result["status"] == "success": success_count += 1 else: error_count += 1 - # 更新统计信息 result["success_count"] = success_count result["error_count"] = error_count + result["completed_order"] = completed_order + result["batch_elapsed"] = round(time.time() - batch_start_time, 2) yield result @@ -271,7 +361,6 @@ class RmbgService: response = await self.http_client.get(url) response.raise_for_status() - # 创建临时文件并写入内容(IO操作,在线程池中执行) def write_temp_file(content): temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png') temp_file.write(content) @@ -280,7 +369,7 @@ class RmbgService: loop = asyncio.get_event_loop() temp_file_path = await loop.run_in_executor( - None, + self.executor, write_temp_file, response.content ) @@ -291,9 +380,8 @@ class RmbgService: async def cleanup(self): """清理资源""" - # 关闭HTTP客户端 await self.http_client.aclose() + self.executor.shutdown(wait=True) if torch.cuda.is_available(): torch.cuda.empty_cache() - gc.collect() - print("资源已清理") \ No newline at end of file + gc.collect() \ No newline at end of file diff --git a/apps/rmbg/settings.py b/apps/rmbg/settings.py index 6a3b2bc..6c91f41 100644 --- a/apps/rmbg/settings.py +++ b/apps/rmbg/settings.py @@ -26,6 +26,10 @@ class Settings(BaseSettings): jingrow_api_key: Optional[str] = None jingrow_api_secret: Optional[str] = None + # 并发控制配置 + max_workers: int = 30 # 线程池最大工作线程数(根据CPU核心数调整,22核44线程可设置20-30) + max_gpu_concurrent: int = 0 # GPU最大并发数(0表示不限制,根据显存大小设置,24GB显存建议10-15) + class Config: env_file = ".env"