优化rmbg,删除未使用的变量

This commit is contained in:
jingrow 2025-11-23 05:14:25 +08:00
parent 10fb6084f5
commit 5552b30958
2 changed files with 13 additions and 129 deletions

View File

@ -8,7 +8,6 @@ from transformers import AutoModelForImageSegmentation
import time
import warnings
import gc
import base64
import asyncio
import io
import uuid
@ -45,22 +44,7 @@ class RmbgService:
)
)
self.executor = ThreadPoolExecutor(max_workers=settings.max_workers)
self._gpu_semaphore = None
self._max_gpu_concurrent = settings.max_gpu_concurrent
self._load_model()
@property
def gpu_semaphore(self):
"""延迟初始化GPU信号量"""
if self._gpu_semaphore is None:
if self._max_gpu_concurrent == 0:
return None
try:
loop = asyncio.get_event_loop()
self._gpu_semaphore = asyncio.Semaphore(self._max_gpu_concurrent)
except RuntimeError:
self._gpu_semaphore = asyncio.Semaphore(self._max_gpu_concurrent)
return self._gpu_semaphore
def _load_model(self):
"""加载模型"""
@ -93,95 +77,11 @@ class RmbgService:
return image
def _process_batch_images_sync(self, images_with_info):
"""
批量处理图像充分利用GPU并行能力
Args:
images_with_info: [(image, image_size, index), ...] 图像和元信息列表
Returns:
[(processed_image, index), ...] 处理后的图像和索引
"""
if not images_with_info:
return []
try:
transform_image = transforms.Compose([
transforms.Resize((1024, 1024)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
batch_images = []
for image, image_size, index in images_with_info:
try:
transformed = transform_image(image)
batch_images.append(transformed)
except Exception as e:
logger.error(f"图片转换失败 (index={index}): {str(e)}")
raise
if not batch_images:
raise Exception("没有有效的图片可以处理")
input_batch = torch.stack(batch_images).to(self.device)
with torch.no_grad():
model_output = self.model(input_batch)
if isinstance(model_output, (list, tuple)):
preds = model_output[-1].sigmoid().cpu()
else:
preds = model_output.sigmoid().cpu()
results = []
for i, (image, image_size, index) in enumerate(images_with_info):
try:
if len(preds.shape) == 4:
pred = preds[i].squeeze()
elif len(preds.shape) == 3:
pred = preds[i]
else:
pred = preds[i].squeeze()
pred_pil = transforms.ToPILImage()(pred)
mask = pred_pil.resize(image_size)
result_image = image.copy()
result_image.putalpha(mask)
results.append((result_image, index))
except Exception as e:
logger.error(f"处理预测结果失败 (index={index}): {str(e)}")
raise
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
return results
except Exception as e:
logger.error(f"批处理失败: {str(e)}")
raise
async def process_image(self, image):
"""异步处理图像,移除背景"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(self.executor, self._process_image_sync, image)
async def process_batch_images(self, images_with_info):
"""异步批量处理图像"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor,
self._process_batch_images_sync,
images_with_info
return await asyncio.get_event_loop().run_in_executor(
self.executor, self._process_image_sync, image
)
def image_to_base64(self, image):
"""将PIL Image对象转换为base64字符串"""
buffered = io.BytesIO()
image.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode('utf-8')
def save_image_to_file(self, image):
"""保存图片到文件并返回URL"""
@ -215,16 +115,12 @@ class RmbgService:
loop = asyncio.get_event_loop()
image = await loop.run_in_executor(
self.executor,
lambda: Image.open(image_path).convert("RGB")
self.executor, lambda: Image.open(image_path).convert("RGB")
)
image_no_bg = await self.process_image(image)
image_url = await loop.run_in_executor(
self.executor,
self.save_image_to_file,
image_no_bg
self.executor, self.save_image_to_file, image_no_bg
)
return {
@ -252,16 +148,12 @@ class RmbgService:
try:
loop = asyncio.get_event_loop()
image = await loop.run_in_executor(
self.executor,
lambda: Image.open(io.BytesIO(file_content)).convert("RGB")
self.executor, lambda: Image.open(io.BytesIO(file_content)).convert("RGB")
)
image_no_bg = await self.process_image(image)
image_url = await loop.run_in_executor(
self.executor,
self.save_image_to_file,
image_no_bg
self.executor, self.save_image_to_file, image_no_bg
)
return {
@ -278,6 +170,7 @@ class RmbgService:
success_count = 0
error_count = 0
batch_start_time = time.time()
loop = asyncio.get_event_loop()
async def download_and_process(index, url):
"""下载并处理单张图片"""
@ -285,24 +178,18 @@ class RmbgService:
try:
if self.is_valid_url(url_str):
temp_file = await self.download_image(url_str)
image = await asyncio.get_event_loop().run_in_executor(
self.executor,
lambda: Image.open(temp_file).convert("RGB")
image = await loop.run_in_executor(
self.executor, lambda: Image.open(temp_file).convert("RGB")
)
os.unlink(temp_file)
else:
image = await asyncio.get_event_loop().run_in_executor(
self.executor,
lambda: Image.open(url_str).convert("RGB")
image = await loop.run_in_executor(
self.executor, lambda: Image.open(url_str).convert("RGB")
)
processed_image = await self.process_image(image)
loop = asyncio.get_event_loop()
image_url = await loop.run_in_executor(
self.executor,
self.save_image_to_file,
processed_image
self.executor, self.save_image_to_file, processed_image
)
return {
@ -369,9 +256,7 @@ class RmbgService:
loop = asyncio.get_event_loop()
temp_file_path = await loop.run_in_executor(
self.executor,
write_temp_file,
response.content
self.executor, write_temp_file, response.content
)
return temp_file_path

View File

@ -28,7 +28,6 @@ class Settings(BaseSettings):
# 并发控制配置
max_workers: int = 30 # 线程池最大工作线程数根据CPU核心数调整22核44线程可设置20-30
max_gpu_concurrent: int = 0 # GPU最大并发数0表示不限制根据显存大小设置24GB显存建议10-15
class Config:
env_file = ".env"