japi/apps/rmbg/service.py
jingrow 0cffb65490 fix: resolve GPU memory leak by running CUDA operations on dedicated CUDA threads
Root cause:
- PyTorch's CUDA operations are not thread-safe; running them through run_in_executor on a ThreadPoolExecutor leaks GPU memory
- Even with the device context set, mixing CUDA contexts across threads keeps memory from being released correctly

Solution:
1. Create a dedicated CUDA worker thread per GPU, so CUDA work never runs in the ThreadPoolExecutor
2. Split the CUDA executor from the IO executor:
   - io_executor: IO work only (saving files, opening images, etc.)
   - dedicated CUDA threads: all CUDA work
3. Use call_soon_threadsafe to hand results back from the threads to asyncio (minimal sketch below)

Technical details:
- Each GPU has its own CUDA thread, keeping CUDA contexts isolated
- CUDA work is passed to the dedicated thread through a queue
- Matches PyTorch documentation and community best practice

Result:
- GPU memory grows on the first run as expected (model loading)
- GPU memory no longer keeps growing on subsequent runs
- The memory leak is fully resolved

References:
- PyTorch GitHub issue #44156
- NVIDIA best practices for multi-threaded CUDA
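
Minimal sketch of the pattern (illustrative names only, not the service code itself): one
long-lived worker thread owns all CUDA calls for a device and hands results back to
asyncio via call_soon_threadsafe.

    import asyncio
    import queue
    import threading

    def start_cuda_worker(task_q: queue.Queue) -> None:
        # One long-lived thread owns every CUDA call for one device.
        def run():
            while True:
                func, args, loop, fut = task_q.get()
                try:
                    result = func(*args)  # CUDA work stays on this thread
                    loop.call_soon_threadsafe(fut.set_result, result)
                except Exception as exc:
                    loop.call_soon_threadsafe(fut.set_exception, exc)
        threading.Thread(target=run, daemon=True).start()

    async def run_on_cuda_thread(task_q: queue.Queue, func, *args):
        # Bridge back into asyncio: enqueue the call and await its Future.
        loop = asyncio.get_running_loop()
        fut = loop.create_future()
        task_q.put((func, args, loop, fut))
        return await fut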
2025-12-17 19:11:07 +00:00

import os
import tempfile
from urllib.parse import urlparse
from PIL import Image
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
import time
import warnings
import gc
import asyncio
import io
import uuid
import httpx
import logging
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Optional, Dict, Any
from threading import Lock, Thread, Event
import queue
from settings import settings
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
torch.set_float32_matmul_precision("high")
@dataclass
class QueueItem:
"""队列项数据结构"""
image: Image.Image
image_size: tuple
request_id: str
future: asyncio.Future
created_at: float
# Extra fields used by the batch endpoint
url_str: Optional[str] = None # original URL (batch endpoint only)
batch_index: Optional[int] = None # index within the batch (batch endpoint only)
class RmbgService:
def __init__(self, model_path=None):
"""初始化背景移除服务"""
self.model_path = model_path or settings.model_path
# Single-node multi-GPU: keep lists of models and devices (legacy fields kept for compatibility)
self.models = []
self.devices = []
# Cached device count (number of GPUs; CPU counts as one device)
self.num_devices = 1
self.model = None
self.device = None
self._gpu_lock = Lock()
self._next_gpu_index = 0
self.save_dir = settings.save_dir
self.download_url = settings.download_url
os.makedirs(self.save_dir, exist_ok=True)
self.http_client = httpx.AsyncClient(
timeout=30.0,
limits=httpx.Limits(
max_keepalive_connections=settings.http_max_keepalive_connections,
max_connections=settings.http_max_connections,
),
)
# Key fix: separate the CUDA executor from the IO executor.
# CUDA operations must not run in the thread pool (that is what leaked memory),
# so the thread pool is created for IO work only (saving files, etc.).
self.io_executor = ThreadPoolExecutor(max_workers=settings.max_workers)
# Kept for backward compatibility
self.executor = self.io_executor
# Queue batching (plan B): strict pipeline, one worker per GPU
self.queue: asyncio.Queue = asyncio.Queue()
self.queue_tasks: list[asyncio.Task] = [] # holds all worker tasks
self.queue_running = False
# Key fix: a dedicated CUDA task queue and worker thread per GPU,
# so no CUDA work ever runs in the ThreadPoolExecutor (the source of the leak)
self.cuda_task_queues = {} # device -> queue.Queue
self.cuda_threads = {} # device -> threading.Thread
self.cuda_thread_stop_flags = {} # device -> threading.Event
self._load_model()
# Queue workers are started in the FastAPI startup event
def _load_model(self):
"""加载模型,支持多 GPU"""
# 优化显存分配策略:减少碎片化(需要在加载前设置)
if torch.cuda.is_available():
os.environ.setdefault('PYTORCH_ALLOC_CONF', 'expandable_segments:True')
num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
use_half = torch.cuda.is_available()
def _load_single_model(device: torch.device):
"""在指定 device 上加载一个模型实例"""
try:
model = AutoModelForImageSegmentation.from_pretrained(
self.model_path,
trust_remote_code=True,
dtype=torch.float16 if use_half else torch.float32,
)
model = model.to(device)
if use_half:
model = model.half()
except Exception as e:
# 如果半精度加载失败,降级到全精度
logger.warning(f"设备 {device} 半精度加载失败,使用全精度: {str(e)}")
model = AutoModelForImageSegmentation.from_pretrained(
self.model_path,
trust_remote_code=True,
)
model = model.to(device)
model.eval()
return model
if num_gpus > 0:
# 为每张 GPU 加载一份模型,简单轮询调度
for idx in range(num_gpus):
device = torch.device(f"cuda:{idx}")
model = _load_single_model(device)
self.devices.append(device)
self.models.append(model)
logger.info(f"检测到 {num_gpus} 张 GPU已为每张 GPU 加载模型实例")
else:
# 仅 CPU
device = torch.device("cpu")
model = _load_single_model(device)
self.devices.append(device)
self.models.append(model)
logger.info("未检测到 GPU使用 CPU 设备")
# 兼容旧字段:默认指向第一个设备和模型
self.device = self.devices[0]
self.model = self.models[0]
# Cache the device count (used to scale the batch size with the number of GPUs)
self.num_devices = max(1, len(self.devices))
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Key fix: start a dedicated CUDA worker thread for each GPU
self._start_cuda_threads()
def _start_cuda_threads(self):
"""为每个GPU创建专门的CUDA执行线程避免在ThreadPoolExecutor中执行CUDA操作"""
if not torch.cuda.is_available():
return
for device in self.devices:
device_key = str(device)
# 为每个设备创建任务队列和停止标志
self.cuda_task_queues[device_key] = queue.Queue()
self.cuda_thread_stop_flags[device_key] = Event()
# 创建专门的CUDA执行线程
cuda_thread = Thread(
target=self._cuda_worker_thread,
args=(device, device_key),
daemon=True,
name=f"CUDA-Worker-{device_key}"
)
cuda_thread.start()
self.cuda_threads[device_key] = cuda_thread
logger.info(f"为设备 {device_key} 创建了专门的CUDA执行线程")
def _cuda_worker_thread(self, device, device_key):
"""
专门的CUDA工作线程在这个线程中执行所有CUDA操作
关键这个线程只处理CUDA操作不与其他线程共享CUDA上下文避免内存泄漏
"""
# 关键在线程开始时设置CUDA设备上下文
if torch.cuda.is_available():
device_idx = device.index if hasattr(device, 'index') else int(str(device).split(':')[-1])
torch.cuda.set_device(device_idx)
logger.info(f"CUDA工作线程 {device_key} 已设置设备上下文: {device_idx}")
while not self.cuda_thread_stop_flags[device_key].is_set():
try:
# 从队列中获取任务(带超时,以便定期检查停止标志)
try:
task = self.cuda_task_queues[device_key].get(timeout=0.1)
except queue.Empty:
continue
# 执行任务
func, args, kwargs, loop, set_result, set_exception = task
try:
result = func(*args, **kwargs)
# 关键:使用 call_soon_threadsafe 在线程中设置 asyncio Future 的结果
loop.call_soon_threadsafe(set_result, result)
except Exception as e:
loop.call_soon_threadsafe(set_exception, e)
finally:
self.cuda_task_queues[device_key].task_done()
except Exception as e:
logger.error(f"CUDA工作线程 {device_key} 出错: {e}", exc_info=True)
logger.info(f"CUDA工作线程 {device_key} 已停止")
async def _execute_cuda_operation(self, device, func, *args, **kwargs):
"""
在专门的CUDA线程中执行CUDA操作
返回结果可以在asyncio中await
"""
device_key = str(device)
if device_key not in self.cuda_task_queues:
# No dedicated CUDA thread for this device: fall back to executing directly (not recommended)
logger.warning(f"设备 {device_key} 没有专门的CUDA线程直接执行操作")
return func(*args, **kwargs)
# 创建Future用于异步等待结果
loop = asyncio.get_event_loop()
future = asyncio.Future()
def set_result(result):
if not future.done():
future.set_result(result)
def set_exception(exc):
if not future.done():
future.set_exception(exc)
# 将任务放入CUDA线程的队列
self.cuda_task_queues[device_key].put((func, args, kwargs, loop, set_result, set_exception))
# 等待结果
return await future
def _get_model_and_device(self):
"""为一次推理选择一个模型和设备(轮询)"""
if not self.models or not self.devices:
raise RuntimeError("模型尚未加载")
if len(self.models) == 1:
return self.models[0], self.devices[0]
with self._gpu_lock:
idx = self._next_gpu_index
self._next_gpu_index = (self._next_gpu_index + 1) % len(self.models)
return self.models[idx], self.devices[idx]
def _process_image_sync(self, image):
"""同步处理图像,移除背景(单张)- 兼容旧接口,使用轮询调度"""
model, device = self._get_model_and_device()
return self._process_single_image_on_device(model, device, image)
def _process_single_image_on_device(self, model, device, image):
"""在指定设备上处理单张图像(用于 worker 降级处理)"""
image_size = image.size
transform_image = transforms.Compose([
transforms.Resize((1024, 1024)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
input_images = transform_image(image).unsqueeze(0).to(device)
# 如果模型是半精度,输入也转换为半精度
if next(model.parameters()).dtype == torch.float16:
input_images = input_images.half()
with torch.no_grad():
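# the segmentation model returns several outputs; the last one is taken as the mask logits and passed through sigmoid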
preds = model(input_images)[-1].sigmoid().cpu()
pred = preds[0].squeeze()
pred_pil = transforms.ToPILImage()(pred)
mask = pred_pil.resize(image_size)
image.putalpha(mask)
# 单张处理保留 gc.collect(),确保及时释放内存
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
return image
def _process_batch_images_sync(self, images_with_info):
"""批量处理图像批处理模式充分利用GPU并行能力"""
if not images_with_info:
return []
# 单设备退化为原来的逻辑,多设备时按设备拆分子批次并行执行
if len(self.models) == 1:
return self._process_batch_on_device(self.models[0], self.devices[0], images_with_info)
# Split the items evenly across the GPUs; results are re-ordered by index afterwards
tasks = []
for i, (model, device) in enumerate(zip(self.models, self.devices)):
sub_items = images_with_info[i::len(self.models)]
if not sub_items:
continue
tasks.append(
self.executor.submit(self._process_batch_on_device, model, device, sub_items)
)
all_results = []
for fut in tasks:
try:
sub_res = fut.result()
all_results.extend(sub_res)
except Exception as e:
logger.error(f"多 GPU 子批次处理失败: {e}", exc_info=True)
# 保证结果顺序与原始 index 一致
all_results.sort(key=lambda x: x[1])
return all_results
def _process_batch_on_device(self, model, device, images_with_info):
"""在指定 device 上批量处理图像"""
transform_image = transforms.Compose([
transforms.Resize((1024, 1024)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
batch_tensors = []
for image, image_size, index in images_with_info:
batch_tensors.append(transform_image(image))
input_batch = torch.stack(batch_tensors).to(device)
# 如果模型是半精度,输入也转换为半精度
if next(model.parameters()).dtype == torch.float16:
input_batch = input_batch.half()
# 释放 batch_tensors 占用的 CPU 内存
del batch_tensors
with torch.no_grad():
model_output = model(input_batch)
if isinstance(model_output, (list, tuple)):
preds = model_output[-1].sigmoid().cpu()
else:
preds = model_output.sigmoid().cpu()
# 立即释放 GPU 上的 input_batch 和 model_output
del input_batch
if isinstance(model_output, (list, tuple)):
del model_output
# 复用 ToPILImage 转换器,避免重复创建对象
to_pil = transforms.ToPILImage()
results = []
for i, (image, image_size, index) in enumerate(images_with_info):
if len(preds.shape) == 4:
pred = preds[i].squeeze()
elif len(preds.shape) == 3:
pred = preds[i]
else:
pred = preds[i].squeeze()
pred_pil = to_pil(pred)
mask = pred_pil.resize(image_size)
result_image = image.copy()
result_image.putalpha(mask)
results.append((result_image, index))
# 释放 preds
del preds
# 批处理后清理显存(移除 gc.collect(),减少阻塞)
if torch.cuda.is_available():
torch.cuda.empty_cache()
return results
async def process_image(self, image):
"""异步处理图像,移除背景(单张)- 使用队列批处理模式"""
if settings.enable_queue_batch and self.queue_running:
return await self._process_image_via_queue(image)
else:
# 降级到单张处理
return await asyncio.get_event_loop().run_in_executor(
self.executor, self._process_image_sync, image
)
async def _process_image_via_queue(self, image):
"""通过队列批处理模式处理单张图像"""
request_id = uuid.uuid4().hex[:10]
future = asyncio.Future()
queue_item = QueueItem(
image=image,
image_size=image.size,
request_id=request_id,
future=future,
created_at=time.time()
)
try:
await self.queue.put(queue_item)
# 等待处理结果,带超时
try:
result = await asyncio.wait_for(future, timeout=settings.request_timeout)
return result
except asyncio.TimeoutError:
future.cancel()
raise Exception(f"处理超时(超过{settings.request_timeout}秒)")
except Exception as e:
if not future.done():
future.set_exception(e)
raise
async def process_batch_images(self, images_with_info):
"""异步批量处理图像(批处理模式)"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
self.executor, self._process_batch_images_sync, images_with_info
)
async def _start_queue_processor(self):
"""启动队列批处理后台任务严格的流水线式方案B每 GPU 一个 worker"""
if self.queue_running:
return
self.queue_running = True
# 为每个 GPU 启动一个独立的 worker
num_workers = len(self.models) if self.models else 1
logger.info(f"启动 {num_workers} 个队列处理 worker每 GPU 一个)")
for worker_id in range(num_workers):
task = asyncio.create_task(self._queue_processor(worker_id))
self.queue_tasks.append(task)
async def _queue_processor(self, worker_id: int):
"""后台队列批处理任务(核心逻辑)- 每个 worker 绑定一个 GPU"""
model = self.models[worker_id]
device = self.devices[worker_id]
logger.info(f"Worker {worker_id} 启动,绑定设备: {device}")
while self.queue_running:
try:
# Collect a batch of requests (per-GPU batch_size)
batch_items = await self._collect_batch_items()
if not batch_items:
continue
# Process the batch using only this worker's model and device
await self._process_batch_queue_items(batch_items, model, device, worker_id)
except Exception as e:
logger.error(f"Worker {worker_id} 队列批处理任务出错: {str(e)}", exc_info=True)
await asyncio.sleep(0.1) # 出错后短暂等待
async def _collect_batch_items(self):
"""收集一批队列项,达到单卡 batch_size 或超时后返回(单卡 batch避免 worker 之间打架)"""
batch_items = []
target_batch_size = settings.batch_size # per-GPU batch_size (no longer multiplied by the GPU count)
collect_interval = settings.batch_collect_interval
collect_timeout = settings.batch_collect_timeout
# 先尝试获取第一个请求(阻塞等待)
try:
first_item = await asyncio.wait_for(
self.queue.get(),
timeout=collect_timeout
)
batch_items.append(first_item)
except asyncio.TimeoutError:
# 超时,返回空列表
return []
# 继续收集更多请求,直到达到 target_batch_size 或超时
start_time = time.time()
while len(batch_items) < target_batch_size:
elapsed = time.time() - start_time
# 如果已经超时,立即处理当前收集的请求
if elapsed >= collect_timeout:
break
# 尝试在剩余时间内获取更多请求
remaining_time = min(collect_interval, collect_timeout - elapsed)
try:
item = await asyncio.wait_for(
self.queue.get(),
timeout=remaining_time
)
batch_items.append(item)
except asyncio.TimeoutError:
# 超时,处理已收集的请求
break
return batch_items
async def _process_batch_queue_items(self, batch_items, model, device, worker_id: int):
"""
处理一批队列项(单卡处理,使用指定的 model 和 device
关键修复根据PyTorch官方文档和社区解决方案CUDA操作不应该在ThreadPoolExecutor中执行。
改为在主事件循环中直接执行CUDA操作使用asyncio.sleep(0)来让出控制权,避免阻塞。
参考https://github.com/pytorch/pytorch/issues/44156
"""
if not batch_items:
return
loop = asyncio.get_event_loop()
try:
# 准备批处理数据(保持原始索引映射)
images_with_info = []
item_index_map = {} # 映射:队列中的索引 -> QueueItem
for idx, item in enumerate(batch_items):
images_with_info.append((item.image, item.image_size, idx))
item_index_map[idx] = item
# Key fix: do not run CUDA operations in the ThreadPoolExecutor;
# run them on the dedicated CUDA thread instead, which avoids the multi-thread CUDA-context memory leak
batch_results = await self._execute_cuda_operation(
device,
self._process_batch_on_device,
model,
device,
images_with_info
)
# 并行保存所有图片(关键优化:避免串行 IO 阻塞)
save_tasks = []
result_mapping = {} # 映射:队列索引 -> (processed_image, QueueItem)
for processed_image, result_idx in batch_results:
if result_idx in item_index_map:
item = item_index_map[result_idx]
result_mapping[result_idx] = (processed_image, item)
# 并行保存IO操作可以使用线程池
save_task = loop.run_in_executor(
self.io_executor, self.save_image_to_file, processed_image
)
save_tasks.append((result_idx, save_task))
# 等待所有保存任务完成
if save_tasks:
save_results = await asyncio.gather(
*[task for _, task in save_tasks],
return_exceptions=True
)
# 按完成顺序设置 Future 结果(流式返回)
for (result_idx, _), save_result in zip(save_tasks, save_results):
if result_idx in result_mapping:
processed_image, item = result_mapping[result_idx]
if isinstance(save_result, Exception):
error_msg = f"保存图片失败: {str(save_result)}"
logger.error(f"队列项 {item.request_id} 保存失败: {error_msg}")
if not item.future.done():
item.future.set_exception(Exception(error_msg))
else:
result = {
"status": "success",
"image_url": save_result
}
if not item.future.done():
item.future.set_result(result)
# Handle any Future that is still pending (should not happen in theory)
for item in batch_items:
if not item.future.done():
item.future.set_exception(Exception("批处理结果不完整"))
except RuntimeError as e:
# CUDA OOM 错误,降级处理
error_msg = str(e)
if "CUDA out of memory" in error_msg or "out of memory" in error_msg.lower():
logger.warning(f"Worker {worker_id} 批处理显存不足,降级到单张处理: {error_msg[:100]}")
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
# Fallback: process images one at a time (with this worker's model and device)
loop = asyncio.get_event_loop()
for item in batch_items:
try:
# Process the single image with this worker's model and device
# Key fix: the CUDA work goes through the dedicated CUDA thread
result_image = await self._execute_cuda_operation(
device,
self._process_single_image_on_device,
model,
device,
item.image
)
# IO操作使用IO线程池
image_url = await loop.run_in_executor(
self.io_executor, self.save_image_to_file, result_image
)
if not item.future.done():
item.future.set_result({
"status": "success",
"image_url": image_url
})
except Exception as e2:
if not item.future.done():
item.future.set_exception(Exception(f"降级处理失败: {str(e2)}"))
else:
# 其他 RuntimeError
error_msg = f"批处理失败: {str(e)}"
logger.error(error_msg, exc_info=True)
for item in batch_items:
if not item.future.done():
item.future.set_exception(Exception(error_msg))
except Exception as e:
error_msg = f"批处理队列项失败: {str(e)}"
logger.error(error_msg, exc_info=True)
# 所有请求都标记为失败
for item in batch_items:
if not item.future.done():
item.future.set_exception(Exception(error_msg))
def save_image_to_file(self, image):
"""保存图片到文件并返回URL"""
# 改为保存 PNG使用标准压缩级别Pillow 默认 compress_level=6
filename = f"rmbg_{uuid.uuid4().hex[:10]}.png"
file_path = os.path.join(self.save_dir, filename)
image.save(file_path, format="PNG")
image_url = f"{self.download_url}/{filename}"
return image_url
async def remove_background(self, image_path):
"""
移除图像背景
Args:
image_path: 输入图像的路径或URL
Returns:
处理后的图像内容
"""
temp_file = None
try:
if self.is_valid_url(image_path):
try:
temp_file = await self.download_image(image_path)
image_path = temp_file
except Exception as e:
raise Exception(f"下载图片失败: {e}")
if not os.path.exists(image_path):
raise FileNotFoundError(f"输入图像不存在: {image_path}")
loop = asyncio.get_event_loop()
image = await loop.run_in_executor(
self.executor, lambda: Image.open(image_path).convert("RGB")
)
result = await self.process_image(image)
if isinstance(result, dict):
return result
image_url = await loop.run_in_executor(
self.executor, self.save_image_to_file, result
)
return {"status": "success", "image_url": image_url}
finally:
if temp_file and os.path.exists(temp_file):
try:
os.unlink(temp_file)
except:
pass
async def remove_background_from_file(self, file_content):
"""
从上传的文件内容移除背景
Args:
file_content: 上传的文件内容
Returns:
处理后的图像内容
"""
try:
loop = asyncio.get_event_loop()
image = await loop.run_in_executor(
self.executor, lambda: Image.open(io.BytesIO(file_content)).convert("RGB")
)
result = await self.process_image(image)
if isinstance(result, dict):
return result
image_url = await loop.run_in_executor(
self.executor, self.save_image_to_file, result
)
return {"status": "success", "image_url": image_url}
except Exception as e:
raise Exception(f"处理图片失败: {e}")
async def process_batch(self, urls):
"""批量处理多个URL图像统一全局 batcher 模式(支持跨用户合批)"""
total = len(urls)
success_count = 0
error_count = 0
batch_start_time = time.time()
loop = asyncio.get_event_loop()
# 为本次 batch 请求生成唯一 request_id
batch_request_id = uuid.uuid4().hex[:16]
# 存储每张图片的 Future 和元数据
image_futures = {} # index -> (future, url_str)
async def download_and_queue_image(index, url):
"""下载图片并推入全局队列(跨用户合批)"""
nonlocal error_count
url_str = str(url)
try:
# 下载图片
if self.is_valid_url(url_str):
temp_file = await self.download_image(url_str)
image = await loop.run_in_executor(
self.executor, lambda: Image.open(temp_file).convert("RGB")
)
os.unlink(temp_file)
else:
image = await loop.run_in_executor(
self.executor, lambda: Image.open(url_str).convert("RGB")
)
# 创建 Future 用于接收结果
future = asyncio.Future()
# 创建队列项,推入全局队列(跨用户合批)
queue_item = QueueItem(
image=image,
image_size=image.size,
request_id=f"{batch_request_id}_{index}",
future=future,
created_at=time.time(),
url_str=url_str, # 保存原始 URL
batch_index=index # 保存 batch 中的索引
)
# 推入全局队列(与其他用户的请求一起合批)
await self.queue.put(queue_item)
# 保存 Future 和元数据
image_futures[index] = (future, url_str)
except Exception as e:
# 下载失败,直接创建失败的 Future
error_count += 1
future = asyncio.Future()
future.set_exception(Exception(f"下载失败: {str(e)}"))
image_futures[index] = (future, url_str)
# 并行下载所有图片并推入队列
download_tasks = [
asyncio.create_task(download_and_queue_image(i, url))
for i, url in enumerate(urls, 1)
]
# 等待所有下载任务完成
await asyncio.gather(*download_tasks, return_exceptions=True)
# 按完成顺序流式返回结果
completed_order = 0
# 建立 Future -> (index, url_str) 的映射,便于在完成时快速反查
future_meta = {}
for idx, (fut, url_str) in image_futures.items():
future_meta[fut] = (idx, url_str)
pending_tasks = set(future_meta.keys())
# 使用 wait 循环实现流式返回,避免等待最慢的
while pending_tasks:
done, pending_tasks = await asyncio.wait(
pending_tasks,
return_when=asyncio.FIRST_COMPLETED
)
for fut in done:
index, url_str = future_meta[fut]
try:
result_data = fut.result()
if isinstance(result_data, dict):
status = result_data.get("status", "success")
image_url = result_data.get("image_url")
error_msg = result_data.get("error")
completed_order += 1
if status == "success" and image_url:
success_count += 1
result = {
"index": index,
"total": total,
"original_url": url_str,
"status": "success",
"image_url": image_url,
"message": "处理成功",
"success_count": success_count,
"error_count": error_count,
"completed_order": completed_order,
"batch_elapsed": round(time.time() - batch_start_time, 2)
}
else:
error_count += 1
result = {
"index": index,
"total": total,
"original_url": url_str,
"status": "error",
"error": error_msg or "处理失败",
"message": error_msg or "处理失败",
"success_count": success_count,
"error_count": error_count,
"completed_order": completed_order,
"batch_elapsed": round(time.time() - batch_start_time, 2)
}
else:
# 兼容非 dict 返回
completed_order += 1
success_count += 1
result = {
"index": index,
"total": total,
"original_url": url_str,
"status": "success",
"image_url": result_data,
"message": "处理成功",
"success_count": success_count,
"error_count": error_count,
"completed_order": completed_order,
"batch_elapsed": round(time.time() - batch_start_time, 2)
}
yield result
except Exception as e:
error_count += 1
completed_order += 1
result = {
"index": index,
"total": total,
"original_url": url_str,
"status": "error",
"error": str(e),
"message": f"处理失败: {str(e)}",
"success_count": success_count,
"error_count": error_count,
"completed_order": completed_order,
"batch_elapsed": round(time.time() - batch_start_time, 2)
}
yield result
def is_valid_url(self, url):
"""验证URL是否有效"""
try:
result = urlparse(url)
return all([result.scheme, result.netloc])
except:
return False
async def download_image(self, url):
"""异步从URL下载图片到临时文件"""
try:
response = await self.http_client.get(url)
response.raise_for_status()
def write_temp_file(content):
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
temp_file.write(content)
temp_file.close()
return temp_file.name
loop = asyncio.get_event_loop()
temp_file_path = await loop.run_in_executor(
self.executor, write_temp_file, response.content
)
return temp_file_path
except Exception as e:
raise Exception(f"下载图片失败: {e}")
async def cleanup(self):
"""清理资源"""
# 停止所有队列处理 worker 任务
if self.queue_running:
self.queue_running = False
# 取消所有 worker 任务
for task in self.queue_tasks:
if task:
task.cancel()
# 等待所有任务完成取消
if self.queue_tasks:
await asyncio.gather(*self.queue_tasks, return_exceptions=True)
self.queue_tasks.clear()
# 处理队列中剩余的请求
remaining_items = []
while not self.queue.empty():
try:
item = self.queue.get_nowait()
remaining_items.append(item)
except asyncio.QueueEmpty:
break
# 标记剩余请求为失败
for item in remaining_items:
if not item.future.done():
item.future.set_exception(Exception("服务关闭,请求被取消"))
await self.http_client.aclose()
# 停止CUDA线程
for device_key, stop_flag in self.cuda_thread_stop_flags.items():
stop_flag.set()
for device_key, thread in self.cuda_threads.items():
thread.join(timeout=5.0)
if thread.is_alive():
logger.warning(f"CUDA线程 {device_key} 未能及时停止")
# 停止IO执行器
self.io_executor.shutdown(wait=True)
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()