diff --git a/apps/rmbg/service.py b/apps/rmbg/service.py index 2d5fcd6..1608720 100644 --- a/apps/rmbg/service.py +++ b/apps/rmbg/service.py @@ -8,7 +8,6 @@ from transformers import AutoModelForImageSegmentation import time import warnings import gc -import base64 import asyncio import io import uuid @@ -45,22 +44,7 @@ class RmbgService: ) ) self.executor = ThreadPoolExecutor(max_workers=settings.max_workers) - self._gpu_semaphore = None - self._max_gpu_concurrent = settings.max_gpu_concurrent self._load_model() - - @property - def gpu_semaphore(self): - """延迟初始化GPU信号量""" - if self._gpu_semaphore is None: - if self._max_gpu_concurrent == 0: - return None - try: - loop = asyncio.get_event_loop() - self._gpu_semaphore = asyncio.Semaphore(self._max_gpu_concurrent) - except RuntimeError: - self._gpu_semaphore = asyncio.Semaphore(self._max_gpu_concurrent) - return self._gpu_semaphore def _load_model(self): """加载模型""" @@ -93,95 +77,11 @@ class RmbgService: return image - def _process_batch_images_sync(self, images_with_info): - """ - 批量处理图像(充分利用GPU并行能力) - - Args: - images_with_info: [(image, image_size, index), ...] 图像和元信息列表 - - Returns: - [(processed_image, index), ...] 处理后的图像和索引 - """ - if not images_with_info: - return [] - - try: - transform_image = transforms.Compose([ - transforms.Resize((1024, 1024)), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), - ]) - - batch_images = [] - for image, image_size, index in images_with_info: - try: - transformed = transform_image(image) - batch_images.append(transformed) - except Exception as e: - logger.error(f"图片转换失败 (index={index}): {str(e)}") - raise - - if not batch_images: - raise Exception("没有有效的图片可以处理") - - input_batch = torch.stack(batch_images).to(self.device) - - with torch.no_grad(): - model_output = self.model(input_batch) - if isinstance(model_output, (list, tuple)): - preds = model_output[-1].sigmoid().cpu() - else: - preds = model_output.sigmoid().cpu() - - results = [] - for i, (image, image_size, index) in enumerate(images_with_info): - try: - if len(preds.shape) == 4: - pred = preds[i].squeeze() - elif len(preds.shape) == 3: - pred = preds[i] - else: - pred = preds[i].squeeze() - - pred_pil = transforms.ToPILImage()(pred) - mask = pred_pil.resize(image_size) - result_image = image.copy() - result_image.putalpha(mask) - results.append((result_image, index)) - except Exception as e: - logger.error(f"处理预测结果失败 (index={index}): {str(e)}") - raise - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - gc.collect() - - return results - - except Exception as e: - logger.error(f"批处理失败: {str(e)}") - raise - async def process_image(self, image): """异步处理图像,移除背景""" - loop = asyncio.get_event_loop() - return await loop.run_in_executor(self.executor, self._process_image_sync, image) - - async def process_batch_images(self, images_with_info): - """异步批量处理图像""" - loop = asyncio.get_event_loop() - return await loop.run_in_executor( - self.executor, - self._process_batch_images_sync, - images_with_info + return await asyncio.get_event_loop().run_in_executor( + self.executor, self._process_image_sync, image ) - - def image_to_base64(self, image): - """将PIL Image对象转换为base64字符串""" - buffered = io.BytesIO() - image.save(buffered, format="PNG") - return base64.b64encode(buffered.getvalue()).decode('utf-8') def save_image_to_file(self, image): """保存图片到文件并返回URL""" @@ -215,16 +115,12 @@ class RmbgService: loop = asyncio.get_event_loop() image = await loop.run_in_executor( - self.executor, - lambda: Image.open(image_path).convert("RGB") + self.executor, lambda: Image.open(image_path).convert("RGB") ) image_no_bg = await self.process_image(image) - image_url = await loop.run_in_executor( - self.executor, - self.save_image_to_file, - image_no_bg + self.executor, self.save_image_to_file, image_no_bg ) return { @@ -252,16 +148,12 @@ class RmbgService: try: loop = asyncio.get_event_loop() image = await loop.run_in_executor( - self.executor, - lambda: Image.open(io.BytesIO(file_content)).convert("RGB") + self.executor, lambda: Image.open(io.BytesIO(file_content)).convert("RGB") ) image_no_bg = await self.process_image(image) - image_url = await loop.run_in_executor( - self.executor, - self.save_image_to_file, - image_no_bg + self.executor, self.save_image_to_file, image_no_bg ) return { @@ -278,6 +170,7 @@ class RmbgService: success_count = 0 error_count = 0 batch_start_time = time.time() + loop = asyncio.get_event_loop() async def download_and_process(index, url): """下载并处理单张图片""" @@ -285,24 +178,18 @@ class RmbgService: try: if self.is_valid_url(url_str): temp_file = await self.download_image(url_str) - image = await asyncio.get_event_loop().run_in_executor( - self.executor, - lambda: Image.open(temp_file).convert("RGB") + image = await loop.run_in_executor( + self.executor, lambda: Image.open(temp_file).convert("RGB") ) os.unlink(temp_file) else: - image = await asyncio.get_event_loop().run_in_executor( - self.executor, - lambda: Image.open(url_str).convert("RGB") + image = await loop.run_in_executor( + self.executor, lambda: Image.open(url_str).convert("RGB") ) processed_image = await self.process_image(image) - - loop = asyncio.get_event_loop() image_url = await loop.run_in_executor( - self.executor, - self.save_image_to_file, - processed_image + self.executor, self.save_image_to_file, processed_image ) return { @@ -369,9 +256,7 @@ class RmbgService: loop = asyncio.get_event_loop() temp_file_path = await loop.run_in_executor( - self.executor, - write_temp_file, - response.content + self.executor, write_temp_file, response.content ) return temp_file_path diff --git a/apps/rmbg/settings.py b/apps/rmbg/settings.py index 6c91f41..e573c8d 100644 --- a/apps/rmbg/settings.py +++ b/apps/rmbg/settings.py @@ -28,7 +28,6 @@ class Settings(BaseSettings): # 并发控制配置 max_workers: int = 30 # 线程池最大工作线程数(根据CPU核心数调整,22核44线程可设置20-30) - max_gpu_concurrent: int = 0 # GPU最大并发数(0表示不限制,根据显存大小设置,24GB显存建议10-15) class Config: env_file = ".env"