- Refactor process_batch: route the batch entry point through the global queue by calling process_image per image
- Call process_image concurrently within each local batch so the global queue can aggregate larger batches and trigger multi-GPU parallelism
- Keep the original streaming response structure and statistics fields; the external interface stays compatible (see the usage sketch appended at the end of the file)
import os
import tempfile
from urllib.parse import urlparse
from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
import time
import warnings
import gc
import asyncio
import io
import uuid
import httpx
import logging
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Optional, Dict, Any
from threading import Lock
from settings import settings

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
torch.set_float32_matmul_precision("high")


@dataclass
class QueueItem:
    """Data structure for one queue item"""
    image: Image.Image
    image_size: tuple
    request_id: str
    future: asyncio.Future
    created_at: float


class RmbgService:
    def __init__(self, model_path=None):
        """Initialize the background-removal service"""
        self.model_path = model_path or settings.model_path
        # Single-node multi-GPU: keep lists of models and devices;
        # the old single-model fields are kept for backward compatibility
        self.models = []
        self.devices = []
        self.model = None
        self.device = None
        self._gpu_lock = Lock()
        self._next_gpu_index = 0
        self.save_dir = settings.save_dir
        self.download_url = settings.download_url
        os.makedirs(self.save_dir, exist_ok=True)

        self.http_client = httpx.AsyncClient(
            timeout=30.0,
            limits=httpx.Limits(
                max_keepalive_connections=50,
                max_connections=100
            )
        )
        self.executor = ThreadPoolExecutor(max_workers=settings.max_workers)

        # Queue aggregation mechanism (plan B)
        self.queue: asyncio.Queue = asyncio.Queue()
        self.queue_task: Optional[asyncio.Task] = None
        self.queue_running = False

        self._load_model()
        # The queue task is started in the FastAPI startup event

    def _load_model(self):
        """Load the model, with multi-GPU support"""
        # Tune the CUDA allocator to reduce fragmentation
        # (must be set before the model is loaded)
        if torch.cuda.is_available():
            os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')

        num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
        use_half = torch.cuda.is_available()

        def _load_single_model(device: torch.device):
            """Load one model instance on the given device"""
            try:
                model = AutoModelForImageSegmentation.from_pretrained(
                    self.model_path,
                    trust_remote_code=True,
                    torch_dtype=torch.float16 if use_half else torch.float32,
                )
                model = model.to(device)
                if use_half:
                    model = model.half()
            except Exception as e:
                # If half-precision loading fails, fall back to full precision
                logger.warning(f"Half-precision load failed on device {device}, falling back to full precision: {str(e)}")
                model = AutoModelForImageSegmentation.from_pretrained(
                    self.model_path,
                    trust_remote_code=True,
                )
                model = model.to(device)
            model.eval()
            return model

        if num_gpus > 0:
            # Load one model copy per GPU; scheduling is a simple round-robin
            for idx in range(num_gpus):
                device = torch.device(f"cuda:{idx}")
                model = _load_single_model(device)
                self.devices.append(device)
                self.models.append(model)
            logger.info(f"Detected {num_gpus} GPU(s); loaded one model instance per GPU")
        else:
            # CPU only
            device = torch.device("cpu")
            model = _load_single_model(device)
            self.devices.append(device)
            self.models.append(model)
            logger.info("No GPU detected; using the CPU")

        # Backward-compatible fields: point at the first device and model
        self.device = self.devices[0]
        self.model = self.models[0]

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _get_model_and_device(self):
        """Pick a model and device for one inference call (round-robin)"""
        if not self.models or not self.devices:
            raise RuntimeError("Model has not been loaded yet")
        if len(self.models) == 1:
            return self.models[0], self.devices[0]
        with self._gpu_lock:
            idx = self._next_gpu_index
            self._next_gpu_index = (self._next_gpu_index + 1) % len(self.models)
            return self.models[idx], self.devices[idx]

    def _process_image_sync(self, image):
        """Synchronously remove the background from a single image"""
        model, device = self._get_model_and_device()
        image_size = image.size
        transform_image = transforms.Compose([
            transforms.Resize((1024, 1024)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        input_images = transform_image(image).unsqueeze(0).to(device)
        # If the model runs in half precision, cast the input to match
        if next(model.parameters()).dtype == torch.float16:
            input_images = input_images.half()

        with torch.no_grad():
            preds = model(input_images)[-1].sigmoid().cpu()

        pred = preds[0].squeeze()
        pred_pil = transforms.ToPILImage()(pred)
        mask = pred_pil.resize(image_size)
        image.putalpha(mask)

        # Keep gc.collect() on the single-image path so memory is freed promptly
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

        return image

    def _process_batch_images_sync(self, images_with_info):
        """Process images as a batch (batching mode, to exploit GPU parallelism)"""
        if not images_with_info:
            return []

        # A single device falls back to the original logic; with multiple
        # devices, split into per-device sub-batches and run them in parallel
        if len(self.models) == 1:
            return self._process_batch_on_device(self.models[0], self.devices[0], images_with_info)

        # Split evenly across the GPUs; the caller re-sorts results by index
        tasks = []
        for i, (model, device) in enumerate(zip(self.models, self.devices)):
            sub_items = images_with_info[i::len(self.models)]
            if not sub_items:
                continue
            tasks.append(
                self.executor.submit(self._process_batch_on_device, model, device, sub_items)
            )

        all_results = []
        for fut in tasks:
            try:
                sub_res = fut.result()
                all_results.extend(sub_res)
            except Exception as e:
                logger.error(f"Multi-GPU sub-batch processing failed: {e}", exc_info=True)

        # Restore the original index order
        all_results.sort(key=lambda x: x[1])
        return all_results

    def _process_batch_on_device(self, model, device, images_with_info):
        """Run one batch of images on the given device"""
        transform_image = transforms.Compose([
            transforms.Resize((1024, 1024)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

        batch_tensors = []
        for image, image_size, index in images_with_info:
            batch_tensors.append(transform_image(image))

        input_batch = torch.stack(batch_tensors).to(device)
        # If the model runs in half precision, cast the input to match
        if next(model.parameters()).dtype == torch.float16:
            input_batch = input_batch.half()
        # Free the CPU memory held by batch_tensors
        del batch_tensors

        with torch.no_grad():
            model_output = model(input_batch)
            if isinstance(model_output, (list, tuple)):
                preds = model_output[-1].sigmoid().cpu()
            else:
                preds = model_output.sigmoid().cpu()

        # Release input_batch and model_output on the GPU right away
        del input_batch
        if isinstance(model_output, (list, tuple)):
            del model_output

        # Reuse one ToPILImage converter instead of re-creating it per image
        to_pil = transforms.ToPILImage()
        results = []
        for i, (image, image_size, index) in enumerate(images_with_info):
            if len(preds.shape) == 4:
                pred = preds[i].squeeze()
            elif len(preds.shape) == 3:
                pred = preds[i]
            else:
                pred = preds[i].squeeze()

            pred_pil = to_pil(pred)
            mask = pred_pil.resize(image_size)
            result_image = image.copy()
            result_image.putalpha(mask)
            results.append((result_image, index))

        # Release preds
        del preds

        # Clear the CUDA cache after the batch
        # (gc.collect() removed here to avoid blocking)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return results

    async def process_image(self, image):
        """Asynchronously remove the background from a single image, using the queue batching mode"""
        if settings.enable_queue_batch and self.queue_running:
            return await self._process_image_via_queue(image)
        else:
            # Fall back to single-image processing
            return await asyncio.get_event_loop().run_in_executor(
                self.executor, self._process_image_sync, image
            )

    async def _process_image_via_queue(self, image):
        """Process a single image through the queue batching path"""
        request_id = uuid.uuid4().hex[:10]
        future = asyncio.Future()

        queue_item = QueueItem(
            image=image,
            image_size=image.size,
            request_id=request_id,
            future=future,
            created_at=time.time()
        )

        try:
            await self.queue.put(queue_item)

            # Wait for the result, with a timeout
            try:
                result = await asyncio.wait_for(future, timeout=settings.request_timeout)
                return result
            except asyncio.TimeoutError:
                future.cancel()
                raise Exception(f"Processing timed out (exceeded {settings.request_timeout}s)")

        except Exception as e:
            if not future.done():
                future.set_exception(e)
            raise

    async def process_batch_images(self, images_with_info):
        """Asynchronously process a batch of images (batching mode)"""
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(
            self.executor, self._process_batch_images_sync, images_with_info
        )

    async def _start_queue_processor(self):
        """Start the background queue-batching task (async; must be called inside an event loop)"""
        if self.queue_running:
            return

        self.queue_running = True
        self.queue_task = asyncio.create_task(self._queue_processor())

    async def _queue_processor(self):
        """Background queue-batching task (core loop)"""
        while self.queue_running:
            try:
                # Collect one batch of requests
                batch_items = await self._collect_batch_items()

                if not batch_items:
                    continue

                # Process the batch
                await self._process_batch_queue_items(batch_items)

            except Exception as e:
                logger.error(f"Queue batching task failed: {str(e)}", exc_info=True)
                await asyncio.sleep(0.1)  # brief pause after an error

    async def _collect_batch_items(self):
        """Collect queue items until batch_size is reached or the collection window expires"""
        batch_items = []
        batch_size = settings.batch_size
        collect_interval = settings.batch_collect_interval
        collect_timeout = settings.batch_collect_timeout

        # Block waiting for the first request
        try:
            first_item = await asyncio.wait_for(
                self.queue.get(),
                timeout=collect_timeout
            )
            batch_items.append(first_item)
        except asyncio.TimeoutError:
            # Timed out: return an empty list
            return []

        # Keep collecting until batch_size is reached or the window expires
        start_time = time.time()

        while len(batch_items) < batch_size:
            elapsed = time.time() - start_time

            # Window already expired: process what has been collected
            if elapsed >= collect_timeout:
                break

            # Try to fetch more requests within the remaining time
            remaining_time = min(collect_interval, collect_timeout - elapsed)

            try:
                item = await asyncio.wait_for(
                    self.queue.get(),
                    timeout=remaining_time
                )
                batch_items.append(item)
            except asyncio.TimeoutError:
                # Timed out: process what has been collected
                break

        return batch_items

    async def _process_batch_queue_items(self, batch_items):
        """Process one batch of queue items"""
        if not batch_items:
            return

        loop = asyncio.get_event_loop()

        try:
            # Prepare the batch input
            images_with_info = []
            for idx, item in enumerate(batch_items):
                images_with_info.append((item.image, item.image_size, idx))

            # Run the batch
            batch_results = await self.process_batch_images(images_with_info)

            # Deliver each result to its Future, matching on the index carried
            # in the result tuple rather than list position, so a dropped
            # sub-batch cannot shift results onto the wrong request
            for processed_image, idx in batch_results:
                if idx < len(batch_items):
                    item = batch_items[idx]

                    # Save the image and hand back its URL
                    try:
                        image_url = await loop.run_in_executor(
                            self.executor, self.save_image_to_file, processed_image
                        )

                        result = {
                            "status": "success",
                            "image_url": image_url
                        }

                        if not item.future.done():
                            item.future.set_result(result)
                    except Exception as e:
                        error_msg = f"Failed to process image: {str(e)}"
                        logger.error(f"Queue item {item.request_id} failed: {error_msg}")
                        if not item.future.done():
                            item.future.set_exception(Exception(error_msg))

            # Fail any Future that is still pending (should not happen in theory)
            for item in batch_items:
                if not item.future.done():
                    item.future.set_exception(Exception("Incomplete batch results"))

        except Exception as e:
            error_msg = f"Failed to process the batch of queue items: {str(e)}"
            logger.error(error_msg, exc_info=True)

            # Mark every request as failed
            for item in batch_items:
                if not item.future.done():
                    item.future.set_exception(Exception(error_msg))

    def save_image_to_file(self, image):
        """Save an image to disk and return its URL"""
        # Save as PNG at the standard compression level (Pillow's default compress_level=6)
        filename = f"rmbg_{uuid.uuid4().hex[:10]}.png"
        file_path = os.path.join(self.save_dir, filename)
        image.save(file_path, format="PNG")
        image_url = f"{self.download_url}/{filename}"
        return image_url

    async def remove_background(self, image_path):
        """
        Remove the background from an image.

        Args:
            image_path: path or URL of the input image

        Returns:
            The processing result
        """
        temp_file = None
        try:
            if self.is_valid_url(image_path):
                try:
                    temp_file = await self.download_image(image_path)
                    image_path = temp_file
                except Exception as e:
                    raise Exception(f"Failed to download image: {e}")

            if not os.path.exists(image_path):
                raise FileNotFoundError(f"Input image does not exist: {image_path}")

            loop = asyncio.get_event_loop()
            image = await loop.run_in_executor(
                self.executor, lambda: Image.open(image_path).convert("RGB")
            )

            result = await self.process_image(image)
            if isinstance(result, dict):
                return result
            image_url = await loop.run_in_executor(
                self.executor, self.save_image_to_file, result
            )
            return {"status": "success", "image_url": image_url}

        finally:
            if temp_file and os.path.exists(temp_file):
                try:
                    os.unlink(temp_file)
                except OSError:
                    pass

    async def remove_background_from_file(self, file_content):
        """
        Remove the background from an uploaded file.

        Args:
            file_content: raw bytes of the uploaded file

        Returns:
            The processing result
        """
        try:
            loop = asyncio.get_event_loop()
            image = await loop.run_in_executor(
                self.executor, lambda: Image.open(io.BytesIO(file_content)).convert("RGB")
            )

            result = await self.process_image(image)
            if isinstance(result, dict):
                return result
            image_url = await loop.run_in_executor(
                self.executor, self.save_image_to_file, result
            )
            return {"status": "success", "image_url": image_url}

        except Exception as e:
            raise Exception(f"Failed to process image: {e}")

    async def process_batch(self, urls):
        """Batch-process multiple image URLs in a pipelined batching mode (download and inference run in parallel)"""
        total = len(urls)
        success_count = 0
        error_count = 0
        batch_start_time = time.time()
        batch_size = settings.batch_size
        loop = asyncio.get_event_loop()

        # Pipeline queue: collects downloaded images
        download_queue = asyncio.Queue()
        download_complete = asyncio.Event()
        download_done_count = 0
        download_error_count = 0

        async def download_image_async(index, url):
            """Download one image asynchronously and put it on the queue"""
            nonlocal download_done_count, download_error_count
            url_str = str(url)

            try:
                if self.is_valid_url(url_str):
                    temp_file = await self.download_image(url_str)
                    image = await loop.run_in_executor(
                        self.executor, lambda: Image.open(temp_file).convert("RGB")
                    )
                    os.unlink(temp_file)
                else:
                    image = await loop.run_in_executor(
                        self.executor, lambda: Image.open(url_str).convert("RGB")
                    )

                # Download succeeded: enqueue the image
                await download_queue.put((image, image.size, index, url_str, None))
                download_done_count += 1

            except Exception as e:
                # Download failed: enqueue anyway, marked as an error
                await download_queue.put((None, None, index, url_str, str(e)))
                download_error_count += 1
                download_done_count += 1
            finally:
                # Signal once every download task has finished
                if download_done_count >= total:
                    download_complete.set()

        # Start all download tasks (downloads run in parallel)
        download_tasks = [
            asyncio.create_task(download_image_async(i, url))
            for i, url in enumerate(urls, 1)
        ]

        # Pipelined batching: collect images from the queue and process them
        # as soon as batch_size is reached or the collection window expires
        completed_order = 0
        pending_batch = []
        batch_collect_timeout = 0.5  # batch collection timeout in seconds

        async def process_pending_batch(force=False):
            """Process the pending batch (feed downloaded images one by one into the global queue for batching)"""
            nonlocal pending_batch, completed_order, success_count, error_count

            if not pending_batch:
                return

            # Separate successfully downloaded images from failed ones
            valid_items = []
            failed_items = []

            for item in pending_batch:
                image, image_size, index, url_str, error = item
                if error:
                    failed_items.append((index, url_str, error))
                else:
                    valid_items.append((image, image_size, index, url_str))

            # Report download failures first
            for index, url_str, error in failed_items:
                error_count += 1
                completed_order += 1
                result = {
                    "index": index,
                    "total": total,
                    "original_url": url_str,
                    "status": "error",
                    "error": error,
                    "message": f"Download failed: {error}",
                    "success_count": success_count,
                    "error_count": error_count,
                    "completed_order": completed_order,
                    "batch_elapsed": round(time.time() - batch_start_time, 2)
                }
                yield result

            if not valid_items:
                pending_batch = []
                return

            # Process the successfully downloaded images through the global
            # queue (scheduled concurrently within each local batch)
            try:
                # To bound the duration of a single pass, still chunk
                # valid_items by batch_size
                for local_batch_start in range(0, len(valid_items), batch_size):
                    local_batch_end = min(local_batch_start + batch_size, len(valid_items))
                    local_batch_items = valid_items[local_batch_start:local_batch_end]

                    # Call process_image concurrently within the local batch so the
                    # global queue can aggregate a larger batch and use multiple GPUs
                    tasks = [
                        self.process_image(image)
                        for image, _, _, _ in local_batch_items
                    ]
                    batch_results = await asyncio.gather(*tasks, return_exceptions=True)

                    for (image, _, index, url_str), result_data in zip(local_batch_items, batch_results):
                        if isinstance(result_data, Exception):
                            error_count += 1
                            completed_order += 1
                            result = {
                                "index": index,
                                "total": total,
                                "original_url": url_str,
                                "status": "error",
                                "error": str(result_data),
                                "message": f"Processing failed: {str(result_data)}",
                                "success_count": success_count,
                                "error_count": error_count,
                                "completed_order": completed_order,
                                "batch_elapsed": round(time.time() - batch_start_time, 2)
                            }
                            yield result
                            continue

                        if isinstance(result_data, dict):
                            # In queue mode process_image normally returns {status, image_url} directly
                            status = result_data.get("status", "success")
                            image_url = result_data.get("image_url")
                            message = result_data.get("message", "Processed successfully" if status == "success" else "Processing failed")
                            error_msg = result_data.get("error")
                        else:
                            # Fallback for non-dict returns: save the image manually
                            image_url = await loop.run_in_executor(
                                self.executor, self.save_image_to_file, result_data
                            )
                            status = "success"
                            message = "Processed successfully"
                            error_msg = None

                        completed_order += 1
                        if status == "success" and image_url:
                            success_count += 1
                            result = {
                                "index": index,
                                "total": total,
                                "original_url": url_str,
                                "status": "success",
                                "image_url": image_url,
                                "message": message,
                                "success_count": success_count,
                                "error_count": error_count,
                                "completed_order": completed_order,
                                "batch_elapsed": round(time.time() - batch_start_time, 2)
                            }
                        else:
                            error_count += 1
                            result = {
                                "index": index,
                                "total": total,
                                "original_url": url_str,
                                "status": "error",
                                "error": error_msg or "Processing failed",
                                "message": message,
                                "success_count": success_count,
                                "error_count": error_count,
                                "completed_order": completed_order,
                                "batch_elapsed": round(time.time() - batch_start_time, 2)
                            }
                        yield result
            except Exception as e:
                logger.error(f"Batch processing failed: {str(e)}")
                # valid_items holds 4-tuples (image, image_size, index, url_str)
                for _, _, index, url_str in valid_items:
                    error_count += 1
                    completed_order += 1
                    result = {
                        "index": index,
                        "total": total,
                        "original_url": url_str,
                        "status": "error",
                        "error": str(e),
                        "message": f"Batch processing failed: {str(e)}",
                        "success_count": success_count,
                        "error_count": error_count,
                        "completed_order": completed_order,
                        "batch_elapsed": round(time.time() - batch_start_time, 2)
                    }
                    yield result

            pending_batch = []

        # Pipelined processing: collect images from the queue and process them
        # as soon as batch_size is reached or the collection window expires
        while True:
            try:
                # Wait for a new image on the queue, or time out
                try:
                    item = await asyncio.wait_for(
                        download_queue.get(),
                        timeout=batch_collect_timeout
                    )
                    pending_batch.append(item)
                except asyncio.TimeoutError:
                    # Timed out: process the current batch
                    if pending_batch:
                        async for result in process_pending_batch():
                            yield result
                    # Check whether all downloads have finished
                    if download_complete.is_set():
                        break
                    continue

                # Process immediately once batch_size is reached
                if len(pending_batch) >= batch_size:
                    async for result in process_pending_batch():
                        yield result

                # Check whether all downloads have finished
                if download_complete.is_set() and download_queue.empty():
                    # Process whatever is left
                    if pending_batch:
                        async for result in process_pending_batch(force=True):
                            yield result
                    break

            except Exception as e:
                logger.error(f"Pipeline processing failed: {str(e)}", exc_info=True)
                break

        # Wait for every download task to finish
        await asyncio.gather(*download_tasks, return_exceptions=True)

        # Make sure all remaining results are emitted
        if pending_batch:
            async for result in process_pending_batch(force=True):
                yield result

    def is_valid_url(self, url):
        """Check whether a URL is well-formed"""
        try:
            result = urlparse(url)
            return all([result.scheme, result.netloc])
        except ValueError:
            return False

    async def download_image(self, url):
        """Asynchronously download an image from a URL into a temporary file"""
        try:
            response = await self.http_client.get(url)
            response.raise_for_status()

            def write_temp_file(content):
                temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
                temp_file.write(content)
                temp_file.close()
                return temp_file.name

            loop = asyncio.get_event_loop()
            temp_file_path = await loop.run_in_executor(
                self.executor, write_temp_file, response.content
            )

            return temp_file_path
        except Exception as e:
            raise Exception(f"Failed to download image: {e}")

    async def cleanup(self):
        """Release resources"""
        # Stop the queue-processing task
        if self.queue_running:
            self.queue_running = False
            if self.queue_task:
                self.queue_task.cancel()
                try:
                    await self.queue_task
                except asyncio.CancelledError:
                    pass

        # Drain any requests still sitting in the queue
        remaining_items = []
        while not self.queue.empty():
            try:
                item = self.queue.get_nowait()
                remaining_items.append(item)
            except asyncio.QueueEmpty:
                break

        # Mark the remaining requests as failed
        for item in remaining_items:
            if not item.future.done():
                item.future.set_exception(Exception("Service shutting down; request cancelled"))

        await self.http_client.aclose()
        self.executor.shutdown(wait=True)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
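

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative assumption, not part of the service: in
# production the queue task is started from the FastAPI startup event rather
# than a standalone entry point, and the URLs below are placeholders).
# It demonstrates the flow described in the summary above: process_batch
# streams per-image results while each image passes through the global queue
# via process_image, so concurrent requests can be aggregated into larger
# GPU batches.
if __name__ == "__main__":
    async def _demo():
        service = RmbgService()
        # Normally done in the FastAPI startup hook
        await service._start_queue_processor()
        urls = [
            "https://example.com/a.jpg",  # placeholder URL
            "https://example.com/b.jpg",  # placeholder URL
        ]
        async for result in service.process_batch(urls):
            logger.info("batch result: %s", result)
        await service.cleanup()

    asyncio.run(_demo())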