846 lines
34 KiB
Python
846 lines
34 KiB
Python
import os
|
||
import tempfile
|
||
from urllib.parse import urlparse
|
||
from PIL import Image
|
||
import torch
|
||
from torchvision import transforms
|
||
from transformers import AutoModelForImageSegmentation
|
||
import time
|
||
import warnings
|
||
import gc
|
||
import asyncio
|
||
import io
|
||
import uuid
|
||
import httpx
|
||
import logging
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
from dataclasses import dataclass
|
||
from typing import Optional, Dict, Any
|
||
from settings import settings
|
||
|
||
# Process-wide logging: INFO level, "time - logger - level - message" lines.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Ignore every UserWarning and FutureWarning for the whole process
# (filters are installed in this order, as before).
for _ignored_category in (UserWarning, FutureWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)
del _ignored_category

# Allow PyTorch to use faster reduced-precision internals (e.g. TF32)
# for float32 matrix multiplications.
torch.set_float32_matmul_precision("high")
|
||
|
||
|
||
@dataclass
class QueueItem:
    """One background-removal request waiting in the batching queue.

    The queue consumer completes ``future`` with the request's result dict,
    or sets an exception on it when processing fails.
    """

    # Decoded input image handed to the model.
    image: Image.Image
    # The image's original ``size``; the predicted mask is resized back to it.
    image_size: tuple
    # Short random identifier used in log messages.
    request_id: str
    # Awaited by the requester; resolved by the queue consumer.
    future: asyncio.Future
    # Enqueue time as returned by time.time().
    created_at: float
|
||
|
||
|
||
class RmbgService:
    """Background-removal service built around an image-segmentation model.

    Single images are processed either directly in a thread pool or, when
    ``settings.enable_queue_batch`` is on and the queue consumer is running,
    aggregated through an asyncio queue into GPU batches. ``process_batch``
    runs a download -> batch-inference -> save pipeline over many inputs.
    """

    def __init__(self, model_path=None):
        """Initialise clients, the worker pool and the request queue, then load the model.

        Args:
            model_path: Path/identifier passed to
                ``AutoModelForImageSegmentation.from_pretrained``; falls back to
                ``settings.model_path``.
        """
        self.model_path = model_path or settings.model_path
        self.model = None
        self.device = None
        self.save_dir = settings.save_dir
        self.download_url = settings.download_url
        os.makedirs(self.save_dir, exist_ok=True)

        self.http_client = httpx.AsyncClient(
            timeout=30.0,
            limits=httpx.Limits(
                max_keepalive_connections=50,
                max_connections=100,
            ),
        )
        self.executor = ThreadPoolExecutor(max_workers=settings.max_workers)

        # Request-aggregation queue. The consumer task is started separately via
        # _start_queue_processor() (originally from the FastAPI startup event).
        # NOTE(review): on Python < 3.10 asyncio.Queue binds to the event loop
        # current at construction time; construct this service inside the
        # serving loop on older interpreters.
        self.queue: asyncio.Queue = asyncio.Queue()
        self.queue_task: Optional[asyncio.Task] = None
        self.queue_running = False

        self._load_model()

    # ------------------------------------------------------------------ model

    def _load_model(self):
        """Load the model onto CUDA (in float16) if available, otherwise onto the CPU.

        If the half-precision load fails, the model is reloaded with the
        checkpoint's default precision.
        """
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")

        if use_cuda:
            # The CUDA caching allocator reads this setting when it initialises,
            # so it must be set before the first GPU allocation (moving the model
            # to the device). Setting it after loading, as before, had no effect.
            os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')

        try:
            # Half precision roughly halves the memory used by the weights.
            model = AutoModelForImageSegmentation.from_pretrained(
                self.model_path,
                trust_remote_code=True,
                torch_dtype=torch.float16 if use_cuda else torch.float32,
            )
            model = model.to(self.device)
            if use_cuda:
                model = model.half()  # make sure every parameter is fp16
        except Exception as e:
            model = None  # drop any partially loaded copy before retrying
            logger.warning(f"半精度加载失败,使用全精度: {str(e)}")
            model = AutoModelForImageSegmentation.from_pretrained(
                self.model_path,
                trust_remote_code=True,
            )
            model = model.to(self.device)

        self.model = model
        self.model.eval()

        if use_cuda:
            torch.cuda.empty_cache()

    @staticmethod
    def _build_transform():
        """Return the preprocessing pipeline: resize to 1024x1024, to tensor, normalize."""
        return transforms.Compose([
            transforms.Resize((1024, 1024)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def _run_model(self, input_batch):
        """Run the model on a batched input tensor and return sigmoid outputs on the CPU.

        If the model returns a tensor it is used directly; otherwise (list,
        tuple or other indexable output) its last element is used.
        """
        input_batch = input_batch.to(self.device)
        # Match the model's precision when it was loaded in half precision.
        if next(self.model.parameters()).dtype == torch.float16:
            input_batch = input_batch.half()

        with torch.no_grad():
            model_output = self.model(input_batch)
            if isinstance(model_output, torch.Tensor):
                logits = model_output
            else:
                logits = model_output[-1]
            preds = logits.sigmoid().cpu()

        # Drop device-side references promptly so empty_cache() can free them.
        del input_batch, model_output, logits
        return preds

    @staticmethod
    def _select_pred(preds, i):
        """Return the mask tensor of batch element ``i`` with singleton dims removed."""
        if preds.dim() == 3:
            return preds[i]
        return preds[i].squeeze()

    def _process_image_sync(self, image):
        """Remove the background of one image synchronously.

        The predicted mask is written in place as the alpha channel of
        ``image``; the same image object is returned.
        """
        image_size = image.size
        input_batch = self._build_transform()(image).unsqueeze(0)
        preds = self._run_model(input_batch)
        del input_batch

        mask = transforms.ToPILImage()(self._select_pred(preds, 0)).resize(image_size)
        image.putalpha(mask)
        del preds

        # The single-image path keeps gc.collect() to release memory promptly.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        return image

    def _process_batch_images_sync(self, images_with_info):
        """Remove backgrounds for several images with one forward pass.

        Args:
            images_with_info: list of ``(image, image_size, index)``; each mask
                is resized to ``image_size``.

        Returns:
            List of ``(result_image, index)`` in input order, where
            ``result_image`` is a copy of the input with the mask as alpha.
        """
        if not images_with_info:
            return []

        transform = self._build_transform()
        input_batch = torch.stack([transform(image) for image, _, _ in images_with_info])
        preds = self._run_model(input_batch)
        del input_batch

        to_pil = transforms.ToPILImage()  # reused for every image in the batch
        results = []
        for i, (image, image_size, index) in enumerate(images_with_info):
            mask = to_pil(self._select_pred(preds, i)).resize(image_size)
            result_image = image.copy()
            result_image.putalpha(mask)
            results.append((result_image, index))
        del preds

        # No gc.collect() here to keep the batch path from blocking.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return results

    # ------------------------------------------------------------ async entry

    async def process_image(self, image):
        """Remove the background of one image asynchronously.

        Returns:
            When queue batching is enabled and running, the result dict
            ``{"status": "success", "image_url": ...}``; otherwise the processed
            PIL image from the single-image path.
        """
        if settings.enable_queue_batch and self.queue_running:
            return await self._process_image_via_queue(image)
        # Fall back to single-image processing in the worker pool.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self.executor, self._process_image_sync, image)

    async def _process_image_via_queue(self, image):
        """Enqueue one image for batched processing and wait for its result.

        Raises:
            Exception: on processing failure, or when no result arrives within
                ``settings.request_timeout`` seconds.
        """
        future = asyncio.get_running_loop().create_future()
        queue_item = QueueItem(
            image=image,
            image_size=image.size,
            request_id=uuid.uuid4().hex[:10],
            future=future,
            created_at=time.time(),
        )
        await self.queue.put(queue_item)

        try:
            return await asyncio.wait_for(future, timeout=settings.request_timeout)
        except asyncio.TimeoutError:
            # A cancelled future tells the consumer to skip this item.
            future.cancel()
            raise Exception(f"处理超时(超过{settings.request_timeout}秒)") from None

    async def process_batch_images(self, images_with_info):
        """Run ``_process_batch_images_sync`` in the worker pool."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            self.executor, self._process_batch_images_sync, images_with_info
        )

    # ------------------------------------------------------------ batch queue

    async def _start_queue_processor(self):
        """Start the background queue consumer (no-op if already running).

        Must be awaited from within the running event loop.
        """
        if self.queue_running:
            return
        self.queue_running = True
        self.queue_task = asyncio.create_task(self._queue_processor())

    async def _queue_processor(self):
        """Consumer loop: collect a batch, process it, repeat while running."""
        while self.queue_running:
            try:
                batch_items = await self._collect_batch_items()
                if batch_items:
                    await self._process_batch_queue_items(batch_items)
            except Exception as e:
                logger.error(f"队列批处理任务出错: {str(e)}", exc_info=True)
                await asyncio.sleep(0.1)  # brief back-off after an error

    @staticmethod
    def _fail_future(future, exc):
        """Set ``exc`` on ``future`` unless it is already done or cancelled."""
        if not future.done():
            future.set_exception(exc)

    async def _collect_batch_items(self):
        """Collect up to ``settings.batch_size`` queued items.

        Waits up to ``batch_collect_timeout`` seconds for the first item (and
        returns ``[]`` if none arrives). It then keeps collecting, waiting at
        most ``batch_collect_interval`` per item, until the batch is full or
        ``batch_collect_timeout`` has elapsed since the first item.
        """
        batch_size = settings.batch_size
        collect_interval = settings.batch_collect_interval
        collect_timeout = settings.batch_collect_timeout

        try:
            first_item = await asyncio.wait_for(self.queue.get(), timeout=collect_timeout)
        except asyncio.TimeoutError:
            return []

        batch_items = [first_item]
        # monotonic() is unaffected by wall-clock adjustments.
        deadline = time.monotonic() + collect_timeout
        try:
            while len(batch_items) < batch_size:
                remaining_time = deadline - time.monotonic()
                if remaining_time <= 0:
                    break
                try:
                    item = await asyncio.wait_for(
                        self.queue.get(),
                        timeout=min(collect_interval, remaining_time),
                    )
                except asyncio.TimeoutError:
                    break
                batch_items.append(item)
        except asyncio.CancelledError:
            # Items already dequeued would otherwise leave callers waiting
            # until their timeout when the consumer is cancelled on shutdown.
            for item in batch_items:
                self._fail_future(item.future, Exception("服务关闭,请求被取消"))
            raise

        return batch_items

    async def _process_batch_queue_items(self, batch_items):
        """Run one batch of queued requests and resolve each request's future.

        Requests whose future is already done (timed out or cancelled) are
        skipped. Successes receive ``{"status": "success", "image_url": ...}``;
        failures are delivered as exceptions on the future.
        """
        live_items = [item for item in batch_items if not item.future.done()]
        if not live_items:
            return

        loop = asyncio.get_running_loop()
        try:
            images_with_info = [
                (item.image, item.image_size, idx) for idx, item in enumerate(live_items)
            ]
            batch_results = await self.process_batch_images(images_with_info)

            # Save all outputs of this batch concurrently in the worker pool.
            save_results = await asyncio.gather(
                *(
                    loop.run_in_executor(self.executor, self.save_image_to_file, processed_image)
                    for processed_image, _ in batch_results
                ),
                return_exceptions=True,
            )

            for (_, idx), outcome in zip(batch_results, save_results):
                item = live_items[idx]
                if isinstance(outcome, BaseException):
                    error_msg = f"处理图片失败: {str(outcome)}"
                    logger.error(f"队列项 {item.request_id} 处理失败: {error_msg}")
                    self._fail_future(item.future, Exception(error_msg))
                elif not item.future.done():
                    item.future.set_result({
                        "status": "success",
                        "image_url": outcome,
                    })

            # Any future still pending means the batch returned too few results.
            for item in live_items:
                self._fail_future(item.future, Exception("批处理结果不完整"))

        except asyncio.CancelledError:
            # Shutting down: don't leave callers waiting until their timeout.
            for item in live_items:
                self._fail_future(item.future, Exception("服务关闭,请求被取消"))
            raise
        except Exception as e:
            error_msg = f"批处理队列项失败: {str(e)}"
            logger.error(error_msg, exc_info=True)
            for item in live_items:
                self._fail_future(item.future, Exception(error_msg))

    # ------------------------------------------------------------- file / IO

    def save_image_to_file(self, image):
        """Save ``image`` as WebP (quality 85) in ``save_dir`` and return its download URL."""
        filename = f"rmbg_{uuid.uuid4().hex[:10]}.webp"
        file_path = os.path.join(self.save_dir, filename)
        image.save(file_path, format="WEBP", quality=85, method=6)
        return f"{self.download_url}/{filename}"

    @staticmethod
    def _open_rgb(source):
        """Decode an image (path or file object) into an RGB image and close the source."""
        with Image.open(source) as img:
            return img.convert("RGB")

    @staticmethod
    def _remove_file_quietly(path):
        """Delete ``path``; a missing file is ignored and other OS errors are only logged."""
        try:
            os.unlink(path)
        except FileNotFoundError:
            pass
        except OSError as e:
            logger.warning(f"删除临时文件失败 {path}: {e}")

    async def remove_background(self, image_path):
        """Remove the background of an image given by local path or URL.

        SECURITY NOTE(review): any local path is opened and any URL is fetched.
        If ``image_path`` comes from untrusted clients, this allows reading
        arbitrary server files or making server-side requests; confirm that
        callers validate it.

        Args:
            image_path: Local file path or http(s) URL of the input image.

        Returns:
            ``{"status": "success", "image_url": ...}``

        Raises:
            FileNotFoundError: if the local file does not exist.
            Exception: if downloading or processing fails.
        """
        temp_file = None
        try:
            if self.is_valid_url(image_path):
                # download_image already raises "下载图片失败: ..."; wrapping it
                # again here used to duplicate the prefix.
                temp_file = await self.download_image(image_path)
                image_path = temp_file

            if not os.path.exists(image_path):
                raise FileNotFoundError(f"输入图像不存在: {image_path}")

            loop = asyncio.get_running_loop()
            image = await loop.run_in_executor(self.executor, self._open_rgb, image_path)

            result = await self.process_image(image)
            if isinstance(result, dict):
                return result
            image_url = await loop.run_in_executor(
                self.executor, self.save_image_to_file, result
            )
            return {"status": "success", "image_url": image_url}
        finally:
            if temp_file:
                self._remove_file_quietly(temp_file)

    async def remove_background_from_file(self, file_content):
        """Remove the background of an uploaded image given as raw bytes.

        Args:
            file_content: Encoded image bytes.

        Returns:
            ``{"status": "success", "image_url": ...}``

        Raises:
            Exception: "处理图片失败: ..." wrapping any decode or processing error.
        """
        try:
            loop = asyncio.get_running_loop()
            image = await loop.run_in_executor(
                self.executor, self._open_rgb, io.BytesIO(file_content)
            )

            result = await self.process_image(image)
            if isinstance(result, dict):
                return result
            image_url = await loop.run_in_executor(
                self.executor, self.save_image_to_file, result
            )
            return {"status": "success", "image_url": image_url}
        except Exception as e:
            raise Exception(f"处理图片失败: {e}") from e

    # -------------------------------------------------------------- pipeline

    async def process_batch(self, urls):
        """Process many images (URLs or local paths) as a download/inference pipeline.

        All inputs are downloaded/decoded concurrently. Decoded images are
        grouped into model batches as they arrive (flushed when ``batch_size``
        is reached, or after 0.5 s without new arrivals). If the GPU runs out
        of memory, unreported images are processed one at a time.

        This is an async generator. It yields one result dict per input, in
        completion order. Each dict carries ``index`` (1-based position in
        ``urls``), ``status`` and running counters (``success_count``,
        ``error_count``, ``completed_order``, ``batch_elapsed`` seconds).
        """
        total = len(urls)
        if total == 0:
            # Without this the collector loop below would never terminate,
            # because the "all downloads finished" event is never set.
            return

        success_count = 0
        error_count = 0
        completed_order = 0
        batch_start_time = time.monotonic()
        batch_size = max(1, settings.batch_size)
        loop = asyncio.get_running_loop()

        # Pipeline queue of (image, image_size, index, url_str, error) tuples.
        download_queue = asyncio.Queue()
        download_complete = asyncio.Event()
        download_done_count = 0

        pending_batch = []
        batch_collect_timeout = 0.5  # seconds to wait for more downloads before flushing
        max_single_batch = batch_size * 2  # final flush may run as one batch up to this size

        def make_result(index, url_str, message, image_url=None, error=None):
            """Update the running counters and build one result dict (error when ``error`` is set)."""
            nonlocal success_count, error_count, completed_order
            completed_order += 1
            result = {"index": index, "total": total, "original_url": url_str}
            if error is None:
                success_count += 1
                result.update(status="success", image_url=image_url, message=message)
            else:
                error_count += 1
                result.update(status="error", error=error, message=message)
            result.update(
                success_count=success_count,
                error_count=error_count,
                completed_order=completed_order,
                batch_elapsed=round(time.monotonic() - batch_start_time, 2),
            )
            return result

        async def download_image_async(index, url):
            """Fetch and decode one input, then enqueue it (or its error)."""
            nonlocal download_done_count
            url_str = str(url)
            try:
                if self.is_valid_url(url_str):
                    temp_file = await self.download_image(url_str)
                    try:
                        image = await loop.run_in_executor(self.executor, self._open_rgb, temp_file)
                    finally:
                        # Remove the temp file even when decoding fails.
                        self._remove_file_quietly(temp_file)
                else:
                    image = await loop.run_in_executor(self.executor, self._open_rgb, url_str)
                item = (image, image.size, index, url_str, None)
            except Exception as e:
                # Always record a non-empty error so the item is never treated as a success.
                item = (None, None, index, url_str, str(e) or type(e).__name__)

            await download_queue.put(item)
            download_done_count += 1
            if download_done_count >= total:
                download_complete.set()

        async def process_pending_batch(force=False):
            """Yield results for everything currently in ``pending_batch``."""
            nonlocal pending_batch
            if not pending_batch:
                return
            # Take ownership up front so an error mid-way can never cause the
            # same items to be processed (and reported) twice.
            items, pending_batch = pending_batch, []

            valid_items = []
            for image, image_size, index, url_str, error in items:
                if error is not None:
                    yield make_result(index, url_str, f"下载失败: {error}", error=error)
                else:
                    valid_items.append((image, image_size, index, url_str))

            if not valid_items:
                return

            # The final flush may run as a single batch; otherwise chunk by batch_size.
            if force and len(valid_items) <= max_single_batch:
                chunk_size = len(valid_items)
            else:
                chunk_size = batch_size

            done_indices = set()
            try:
                for start in range(0, len(valid_items), chunk_size):
                    chunk = valid_items[start:start + chunk_size]
                    url_by_index = {index: url_str for _, _, index, url_str in chunk}
                    batch_results = await self.process_batch_images(
                        [(image, image_size, index) for image, image_size, index, _ in chunk]
                    )

                    # Save this chunk's outputs concurrently.
                    save_results = await asyncio.gather(
                        *(
                            loop.run_in_executor(self.executor, self.save_image_to_file, processed)
                            for processed, _ in batch_results
                        ),
                        return_exceptions=True,
                    )

                    for (_, index), outcome in zip(batch_results, save_results):
                        done_indices.add(index)
                        if isinstance(outcome, Exception):
                            yield make_result(
                                index, url_by_index[index],
                                f"保存图片失败: {str(outcome)}", error=str(outcome),
                            )
                        else:
                            yield make_result(
                                index, url_by_index[index], "处理成功", image_url=outcome,
                            )

            except Exception as e:
                error_msg = str(e)
                # Only items not yet reported are retried or marked as failed.
                remaining = [item for item in valid_items if item[2] not in done_indices]

                if isinstance(e, RuntimeError) and "out of memory" in error_msg.lower():
                    logger.warning(f"批处理显存不足,降级处理: {error_msg[:100]}")
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    gc.collect()

                    # Fall back to one image at a time.
                    for image, _, index, url_str in remaining:
                        try:
                            result_data = await self.process_image(image)
                            if isinstance(result_data, dict):
                                image_url = result_data["image_url"]
                            else:
                                image_url = await loop.run_in_executor(
                                    self.executor, self.save_image_to_file, result_data
                                )
                        except Exception as e2:
                            yield make_result(
                                index, url_str, f"处理失败: {str(e2)}", error=str(e2),
                            )
                        else:
                            yield make_result(
                                index, url_str, "处理成功(降级模式)", image_url=image_url,
                            )
                else:
                    logger.error(f"批处理失败: {error_msg}")
                    for _, _, index, url_str in remaining:
                        yield make_result(
                            index, url_str, f"批处理失败: {error_msg}", error=error_msg,
                        )

        # Start all downloads in parallel.
        download_tasks = [
            asyncio.create_task(download_image_async(i, url))
            for i, url in enumerate(urls, 1)
        ]

        try:
            while True:
                try:
                    try:
                        item = await asyncio.wait_for(
                            download_queue.get(),
                            timeout=batch_collect_timeout,
                        )
                    except asyncio.TimeoutError:
                        # Nothing new arrived in time: flush what is pending.
                        async for result in process_pending_batch():
                            yield result
                        if download_complete.is_set():
                            break
                        continue

                    pending_batch.append(item)

                    if len(pending_batch) >= batch_size:
                        async for result in process_pending_batch():
                            yield result

                    if download_complete.is_set() and download_queue.empty():
                        async for result in process_pending_batch(force=True):
                            yield result
                        break

                except Exception as e:
                    logger.error(f"流水线处理出错: {str(e)}", exc_info=True)
                    break

            # Let downloads finish, then pick up anything still queued (e.g. if
            # the loop above stopped early because of an error).
            await asyncio.gather(*download_tasks, return_exceptions=True)
            while not download_queue.empty():
                pending_batch.append(download_queue.get_nowait())
            async for result in process_pending_batch(force=True):
                yield result
        finally:
            # If the consumer stops iterating early, don't leave downloads running.
            for task in download_tasks:
                if not task.done():
                    task.cancel()

    # ----------------------------------------------------------------- misc

    def is_valid_url(self, url):
        """Return True when ``url`` parses with both a scheme and a network location."""
        try:
            parsed = urlparse(url)
        except Exception:
            return False
        return bool(parsed.scheme and parsed.netloc)

    @staticmethod
    def _write_temp_file(content):
        """Write ``content`` to a new persistent ``.png`` temp file and return its path."""
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        try:
            with temp_file:
                temp_file.write(content)
        except BaseException:
            # Don't leave a half-written file on disk.
            os.unlink(temp_file.name)
            raise
        return temp_file.name

    async def download_image(self, url):
        """Download ``url`` into a temporary file and return its path (caller deletes it).

        NOTE(review): the whole response body is buffered in memory with no
        size cap; consider limits if URLs come from untrusted clients.

        Raises:
            Exception: "下载图片失败: ..." for any network, HTTP-status or write error.
        """
        try:
            response = await self.http_client.get(url)
            response.raise_for_status()
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(
                self.executor, self._write_temp_file, response.content
            )
        except Exception as e:
            raise Exception(f"下载图片失败: {e}") from e

    async def cleanup(self):
        """Stop the queue consumer, fail queued requests, and release clients, threads and GPU cache."""
        if self.queue_running:
            self.queue_running = False
            if self.queue_task:
                self.queue_task.cancel()
                try:
                    await self.queue_task
                except asyncio.CancelledError:
                    pass

        # Fail whatever is still waiting in the queue.
        remaining_items = []
        while True:
            try:
                remaining_items.append(self.queue.get_nowait())
            except asyncio.QueueEmpty:
                break
        for item in remaining_items:
            self._fail_future(item.future, Exception("服务关闭,请求被取消"))

        await self.http_client.aclose()
        # NOTE(review): shutdown(wait=True) blocks the event loop until running
        # pool jobs (e.g. an in-flight inference) finish.
        self.executor.shutdown(wait=True)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()