优化rmbg,删除未使用的变量
This commit is contained in:
parent
10fb6084f5
commit
5552b30958
@ -8,7 +8,6 @@ from transformers import AutoModelForImageSegmentation
|
|||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
import gc
|
import gc
|
||||||
import base64
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import io
|
import io
|
||||||
import uuid
|
import uuid
|
||||||
@ -45,23 +44,8 @@ class RmbgService:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
self.executor = ThreadPoolExecutor(max_workers=settings.max_workers)
|
self.executor = ThreadPoolExecutor(max_workers=settings.max_workers)
|
||||||
self._gpu_semaphore = None
|
|
||||||
self._max_gpu_concurrent = settings.max_gpu_concurrent
|
|
||||||
self._load_model()
|
self._load_model()
|
||||||
|
|
||||||
@property
|
|
||||||
def gpu_semaphore(self):
|
|
||||||
"""延迟初始化GPU信号量"""
|
|
||||||
if self._gpu_semaphore is None:
|
|
||||||
if self._max_gpu_concurrent == 0:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
self._gpu_semaphore = asyncio.Semaphore(self._max_gpu_concurrent)
|
|
||||||
except RuntimeError:
|
|
||||||
self._gpu_semaphore = asyncio.Semaphore(self._max_gpu_concurrent)
|
|
||||||
return self._gpu_semaphore
|
|
||||||
|
|
||||||
def _load_model(self):
|
def _load_model(self):
|
||||||
"""加载模型"""
|
"""加载模型"""
|
||||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
@ -93,96 +77,12 @@ class RmbgService:
|
|||||||
|
|
||||||
return image
|
return image
|
||||||
|
|
||||||
def _process_batch_images_sync(self, images_with_info):
|
|
||||||
"""
|
|
||||||
批量处理图像(充分利用GPU并行能力)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
images_with_info: [(image, image_size, index), ...] 图像和元信息列表
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
[(processed_image, index), ...] 处理后的图像和索引
|
|
||||||
"""
|
|
||||||
if not images_with_info:
|
|
||||||
return []
|
|
||||||
|
|
||||||
try:
|
|
||||||
transform_image = transforms.Compose([
|
|
||||||
transforms.Resize((1024, 1024)),
|
|
||||||
transforms.ToTensor(),
|
|
||||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
|
||||||
])
|
|
||||||
|
|
||||||
batch_images = []
|
|
||||||
for image, image_size, index in images_with_info:
|
|
||||||
try:
|
|
||||||
transformed = transform_image(image)
|
|
||||||
batch_images.append(transformed)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"图片转换失败 (index={index}): {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
if not batch_images:
|
|
||||||
raise Exception("没有有效的图片可以处理")
|
|
||||||
|
|
||||||
input_batch = torch.stack(batch_images).to(self.device)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
model_output = self.model(input_batch)
|
|
||||||
if isinstance(model_output, (list, tuple)):
|
|
||||||
preds = model_output[-1].sigmoid().cpu()
|
|
||||||
else:
|
|
||||||
preds = model_output.sigmoid().cpu()
|
|
||||||
|
|
||||||
results = []
|
|
||||||
for i, (image, image_size, index) in enumerate(images_with_info):
|
|
||||||
try:
|
|
||||||
if len(preds.shape) == 4:
|
|
||||||
pred = preds[i].squeeze()
|
|
||||||
elif len(preds.shape) == 3:
|
|
||||||
pred = preds[i]
|
|
||||||
else:
|
|
||||||
pred = preds[i].squeeze()
|
|
||||||
|
|
||||||
pred_pil = transforms.ToPILImage()(pred)
|
|
||||||
mask = pred_pil.resize(image_size)
|
|
||||||
result_image = image.copy()
|
|
||||||
result_image.putalpha(mask)
|
|
||||||
results.append((result_image, index))
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"处理预测结果失败 (index={index}): {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"批处理失败: {str(e)}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def process_image(self, image):
|
async def process_image(self, image):
|
||||||
"""异步处理图像,移除背景"""
|
"""异步处理图像,移除背景"""
|
||||||
loop = asyncio.get_event_loop()
|
return await asyncio.get_event_loop().run_in_executor(
|
||||||
return await loop.run_in_executor(self.executor, self._process_image_sync, image)
|
self.executor, self._process_image_sync, image
|
||||||
|
|
||||||
async def process_batch_images(self, images_with_info):
|
|
||||||
"""异步批量处理图像"""
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
return await loop.run_in_executor(
|
|
||||||
self.executor,
|
|
||||||
self._process_batch_images_sync,
|
|
||||||
images_with_info
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def image_to_base64(self, image):
|
|
||||||
"""将PIL Image对象转换为base64字符串"""
|
|
||||||
buffered = io.BytesIO()
|
|
||||||
image.save(buffered, format="PNG")
|
|
||||||
return base64.b64encode(buffered.getvalue()).decode('utf-8')
|
|
||||||
|
|
||||||
def save_image_to_file(self, image):
|
def save_image_to_file(self, image):
|
||||||
"""保存图片到文件并返回URL"""
|
"""保存图片到文件并返回URL"""
|
||||||
filename = f"rmbg_{uuid.uuid4().hex[:10]}.png"
|
filename = f"rmbg_{uuid.uuid4().hex[:10]}.png"
|
||||||
@ -215,16 +115,12 @@ class RmbgService:
|
|||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
image = await loop.run_in_executor(
|
image = await loop.run_in_executor(
|
||||||
self.executor,
|
self.executor, lambda: Image.open(image_path).convert("RGB")
|
||||||
lambda: Image.open(image_path).convert("RGB")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
image_no_bg = await self.process_image(image)
|
image_no_bg = await self.process_image(image)
|
||||||
|
|
||||||
image_url = await loop.run_in_executor(
|
image_url = await loop.run_in_executor(
|
||||||
self.executor,
|
self.executor, self.save_image_to_file, image_no_bg
|
||||||
self.save_image_to_file,
|
|
||||||
image_no_bg
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@ -252,16 +148,12 @@ class RmbgService:
|
|||||||
try:
|
try:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
image = await loop.run_in_executor(
|
image = await loop.run_in_executor(
|
||||||
self.executor,
|
self.executor, lambda: Image.open(io.BytesIO(file_content)).convert("RGB")
|
||||||
lambda: Image.open(io.BytesIO(file_content)).convert("RGB")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
image_no_bg = await self.process_image(image)
|
image_no_bg = await self.process_image(image)
|
||||||
|
|
||||||
image_url = await loop.run_in_executor(
|
image_url = await loop.run_in_executor(
|
||||||
self.executor,
|
self.executor, self.save_image_to_file, image_no_bg
|
||||||
self.save_image_to_file,
|
|
||||||
image_no_bg
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@ -278,6 +170,7 @@ class RmbgService:
|
|||||||
success_count = 0
|
success_count = 0
|
||||||
error_count = 0
|
error_count = 0
|
||||||
batch_start_time = time.time()
|
batch_start_time = time.time()
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
async def download_and_process(index, url):
|
async def download_and_process(index, url):
|
||||||
"""下载并处理单张图片"""
|
"""下载并处理单张图片"""
|
||||||
@ -285,24 +178,18 @@ class RmbgService:
|
|||||||
try:
|
try:
|
||||||
if self.is_valid_url(url_str):
|
if self.is_valid_url(url_str):
|
||||||
temp_file = await self.download_image(url_str)
|
temp_file = await self.download_image(url_str)
|
||||||
image = await asyncio.get_event_loop().run_in_executor(
|
image = await loop.run_in_executor(
|
||||||
self.executor,
|
self.executor, lambda: Image.open(temp_file).convert("RGB")
|
||||||
lambda: Image.open(temp_file).convert("RGB")
|
|
||||||
)
|
)
|
||||||
os.unlink(temp_file)
|
os.unlink(temp_file)
|
||||||
else:
|
else:
|
||||||
image = await asyncio.get_event_loop().run_in_executor(
|
image = await loop.run_in_executor(
|
||||||
self.executor,
|
self.executor, lambda: Image.open(url_str).convert("RGB")
|
||||||
lambda: Image.open(url_str).convert("RGB")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
processed_image = await self.process_image(image)
|
processed_image = await self.process_image(image)
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
image_url = await loop.run_in_executor(
|
image_url = await loop.run_in_executor(
|
||||||
self.executor,
|
self.executor, self.save_image_to_file, processed_image
|
||||||
self.save_image_to_file,
|
|
||||||
processed_image
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@ -369,9 +256,7 @@ class RmbgService:
|
|||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
temp_file_path = await loop.run_in_executor(
|
temp_file_path = await loop.run_in_executor(
|
||||||
self.executor,
|
self.executor, write_temp_file, response.content
|
||||||
write_temp_file,
|
|
||||||
response.content
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return temp_file_path
|
return temp_file_path
|
||||||
|
|||||||
@ -28,7 +28,6 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
# 并发控制配置
|
# 并发控制配置
|
||||||
max_workers: int = 30 # 线程池最大工作线程数(根据CPU核心数调整,22核44线程可设置20-30)
|
max_workers: int = 30 # 线程池最大工作线程数(根据CPU核心数调整,22核44线程可设置20-30)
|
||||||
max_gpu_concurrent: int = 0 # GPU最大并发数(0表示不限制,根据显存大小设置,24GB显存建议10-15)
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
env_file = ".env"
|
env_file = ".env"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user