优化rmbg,删除未使用的变量
This commit is contained in:
parent
10fb6084f5
commit
5552b30958
@ -8,7 +8,6 @@ from transformers import AutoModelForImageSegmentation
|
||||
import time
|
||||
import warnings
|
||||
import gc
|
||||
import base64
|
||||
import asyncio
|
||||
import io
|
||||
import uuid
|
||||
@ -45,22 +44,7 @@ class RmbgService:
|
||||
)
|
||||
)
|
||||
self.executor = ThreadPoolExecutor(max_workers=settings.max_workers)
|
||||
self._gpu_semaphore = None
|
||||
self._max_gpu_concurrent = settings.max_gpu_concurrent
|
||||
self._load_model()
|
||||
|
||||
@property
|
||||
def gpu_semaphore(self):
|
||||
"""延迟初始化GPU信号量"""
|
||||
if self._gpu_semaphore is None:
|
||||
if self._max_gpu_concurrent == 0:
|
||||
return None
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
self._gpu_semaphore = asyncio.Semaphore(self._max_gpu_concurrent)
|
||||
except RuntimeError:
|
||||
self._gpu_semaphore = asyncio.Semaphore(self._max_gpu_concurrent)
|
||||
return self._gpu_semaphore
|
||||
|
||||
def _load_model(self):
|
||||
"""加载模型"""
|
||||
@ -93,95 +77,11 @@ class RmbgService:
|
||||
|
||||
return image
|
||||
|
||||
def _process_batch_images_sync(self, images_with_info):
|
||||
"""
|
||||
批量处理图像(充分利用GPU并行能力)
|
||||
|
||||
Args:
|
||||
images_with_info: [(image, image_size, index), ...] 图像和元信息列表
|
||||
|
||||
Returns:
|
||||
[(processed_image, index), ...] 处理后的图像和索引
|
||||
"""
|
||||
if not images_with_info:
|
||||
return []
|
||||
|
||||
try:
|
||||
transform_image = transforms.Compose([
|
||||
transforms.Resize((1024, 1024)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||||
])
|
||||
|
||||
batch_images = []
|
||||
for image, image_size, index in images_with_info:
|
||||
try:
|
||||
transformed = transform_image(image)
|
||||
batch_images.append(transformed)
|
||||
except Exception as e:
|
||||
logger.error(f"图片转换失败 (index={index}): {str(e)}")
|
||||
raise
|
||||
|
||||
if not batch_images:
|
||||
raise Exception("没有有效的图片可以处理")
|
||||
|
||||
input_batch = torch.stack(batch_images).to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
model_output = self.model(input_batch)
|
||||
if isinstance(model_output, (list, tuple)):
|
||||
preds = model_output[-1].sigmoid().cpu()
|
||||
else:
|
||||
preds = model_output.sigmoid().cpu()
|
||||
|
||||
results = []
|
||||
for i, (image, image_size, index) in enumerate(images_with_info):
|
||||
try:
|
||||
if len(preds.shape) == 4:
|
||||
pred = preds[i].squeeze()
|
||||
elif len(preds.shape) == 3:
|
||||
pred = preds[i]
|
||||
else:
|
||||
pred = preds[i].squeeze()
|
||||
|
||||
pred_pil = transforms.ToPILImage()(pred)
|
||||
mask = pred_pil.resize(image_size)
|
||||
result_image = image.copy()
|
||||
result_image.putalpha(mask)
|
||||
results.append((result_image, index))
|
||||
except Exception as e:
|
||||
logger.error(f"处理预测结果失败 (index={index}): {str(e)}")
|
||||
raise
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批处理失败: {str(e)}")
|
||||
raise
|
||||
|
||||
async def process_image(self, image):
|
||||
"""异步处理图像,移除背景"""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(self.executor, self._process_image_sync, image)
|
||||
|
||||
async def process_batch_images(self, images_with_info):
|
||||
"""异步批量处理图像"""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
self.executor,
|
||||
self._process_batch_images_sync,
|
||||
images_with_info
|
||||
return await asyncio.get_event_loop().run_in_executor(
|
||||
self.executor, self._process_image_sync, image
|
||||
)
|
||||
|
||||
def image_to_base64(self, image):
|
||||
"""将PIL Image对象转换为base64字符串"""
|
||||
buffered = io.BytesIO()
|
||||
image.save(buffered, format="PNG")
|
||||
return base64.b64encode(buffered.getvalue()).decode('utf-8')
|
||||
|
||||
def save_image_to_file(self, image):
|
||||
"""保存图片到文件并返回URL"""
|
||||
@ -215,16 +115,12 @@ class RmbgService:
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
image = await loop.run_in_executor(
|
||||
self.executor,
|
||||
lambda: Image.open(image_path).convert("RGB")
|
||||
self.executor, lambda: Image.open(image_path).convert("RGB")
|
||||
)
|
||||
|
||||
image_no_bg = await self.process_image(image)
|
||||
|
||||
image_url = await loop.run_in_executor(
|
||||
self.executor,
|
||||
self.save_image_to_file,
|
||||
image_no_bg
|
||||
self.executor, self.save_image_to_file, image_no_bg
|
||||
)
|
||||
|
||||
return {
|
||||
@ -252,16 +148,12 @@ class RmbgService:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
image = await loop.run_in_executor(
|
||||
self.executor,
|
||||
lambda: Image.open(io.BytesIO(file_content)).convert("RGB")
|
||||
self.executor, lambda: Image.open(io.BytesIO(file_content)).convert("RGB")
|
||||
)
|
||||
|
||||
image_no_bg = await self.process_image(image)
|
||||
|
||||
image_url = await loop.run_in_executor(
|
||||
self.executor,
|
||||
self.save_image_to_file,
|
||||
image_no_bg
|
||||
self.executor, self.save_image_to_file, image_no_bg
|
||||
)
|
||||
|
||||
return {
|
||||
@ -278,6 +170,7 @@ class RmbgService:
|
||||
success_count = 0
|
||||
error_count = 0
|
||||
batch_start_time = time.time()
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
async def download_and_process(index, url):
|
||||
"""下载并处理单张图片"""
|
||||
@ -285,24 +178,18 @@ class RmbgService:
|
||||
try:
|
||||
if self.is_valid_url(url_str):
|
||||
temp_file = await self.download_image(url_str)
|
||||
image = await asyncio.get_event_loop().run_in_executor(
|
||||
self.executor,
|
||||
lambda: Image.open(temp_file).convert("RGB")
|
||||
image = await loop.run_in_executor(
|
||||
self.executor, lambda: Image.open(temp_file).convert("RGB")
|
||||
)
|
||||
os.unlink(temp_file)
|
||||
else:
|
||||
image = await asyncio.get_event_loop().run_in_executor(
|
||||
self.executor,
|
||||
lambda: Image.open(url_str).convert("RGB")
|
||||
image = await loop.run_in_executor(
|
||||
self.executor, lambda: Image.open(url_str).convert("RGB")
|
||||
)
|
||||
|
||||
processed_image = await self.process_image(image)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
image_url = await loop.run_in_executor(
|
||||
self.executor,
|
||||
self.save_image_to_file,
|
||||
processed_image
|
||||
self.executor, self.save_image_to_file, processed_image
|
||||
)
|
||||
|
||||
return {
|
||||
@ -369,9 +256,7 @@ class RmbgService:
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
temp_file_path = await loop.run_in_executor(
|
||||
self.executor,
|
||||
write_temp_file,
|
||||
response.content
|
||||
self.executor, write_temp_file, response.content
|
||||
)
|
||||
|
||||
return temp_file_path
|
||||
|
||||
@ -28,7 +28,6 @@ class Settings(BaseSettings):
|
||||
|
||||
# 并发控制配置
|
||||
max_workers: int = 30 # 线程池最大工作线程数(根据CPU核心数调整,22核44线程可设置20-30)
|
||||
max_gpu_concurrent: int = 0 # GPU最大并发数(0表示不限制,根据显存大小设置,24GB显存建议10-15)
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user