重构rmbg为流水线模式
This commit is contained in:
parent
7f03cc24e3
commit
b0e889f839
@ -443,7 +443,7 @@ class RmbgService:
|
|||||||
raise Exception(f"处理图片失败: {e}")
|
raise Exception(f"处理图片失败: {e}")
|
||||||
|
|
||||||
async def process_batch(self, urls):
|
async def process_batch(self, urls):
|
||||||
"""批量处理多个URL图像,批处理模式(推荐方案)"""
|
"""批量处理多个URL图像,流水线批处理模式(下载和处理并行)"""
|
||||||
total = len(urls)
|
total = len(urls)
|
||||||
success_count = 0
|
success_count = 0
|
||||||
error_count = 0
|
error_count = 0
|
||||||
@ -459,6 +459,12 @@ class RmbgService:
|
|||||||
batch_sizes = []
|
batch_sizes = []
|
||||||
stats_printed = False
|
stats_printed = False
|
||||||
|
|
||||||
|
# 流水线队列:收集已下载的图片
|
||||||
|
download_queue = asyncio.Queue()
|
||||||
|
download_complete = asyncio.Event()
|
||||||
|
download_done_count = 0
|
||||||
|
download_error_count = 0
|
||||||
|
|
||||||
def print_stats():
|
def print_stats():
|
||||||
"""输出性能统计信息"""
|
"""输出性能统计信息"""
|
||||||
nonlocal stats_printed
|
nonlocal stats_printed
|
||||||
@ -470,7 +476,7 @@ class RmbgService:
|
|||||||
other_time = total_time - download_time - gpu_inference_time - save_time
|
other_time = total_time - download_time - gpu_inference_time - save_time
|
||||||
|
|
||||||
logger.info("=" * 60)
|
logger.info("=" * 60)
|
||||||
logger.info("📊 批处理性能统计")
|
logger.info("📊 批处理性能统计(流水线模式)")
|
||||||
logger.info("=" * 60)
|
logger.info("=" * 60)
|
||||||
logger.info(f"图片总数: {total}")
|
logger.info(f"图片总数: {total}")
|
||||||
logger.info(f"成功数量: {success_count}")
|
logger.info(f"成功数量: {success_count}")
|
||||||
@ -506,8 +512,10 @@ class RmbgService:
|
|||||||
logger.info(f"📈 每批平均耗时: {avg_batch_time:.3f}s")
|
logger.info(f"📈 每批平均耗时: {avg_batch_time:.3f}s")
|
||||||
|
|
||||||
async def download_image_async(index, url):
|
async def download_image_async(index, url):
|
||||||
"""异步下载图片"""
|
"""异步下载图片并放入队列"""
|
||||||
|
nonlocal download_done_count, download_error_count
|
||||||
url_str = str(url)
|
url_str = str(url)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if self.is_valid_url(url_str):
|
if self.is_valid_url(url_str):
|
||||||
temp_file = await self.download_image(url_str)
|
temp_file = await self.download_image(url_str)
|
||||||
@ -519,202 +527,106 @@ class RmbgService:
|
|||||||
image = await loop.run_in_executor(
|
image = await loop.run_in_executor(
|
||||||
self.executor, lambda: Image.open(url_str).convert("RGB")
|
self.executor, lambda: Image.open(url_str).convert("RGB")
|
||||||
)
|
)
|
||||||
return (image, image.size, index, url_str, None)
|
|
||||||
|
# 下载成功,放入队列
|
||||||
|
await download_queue.put((image, image.size, index, url_str, None))
|
||||||
|
download_done_count += 1
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return (None, None, index, url_str, str(e))
|
# 下载失败,也放入队列(标记为错误)
|
||||||
|
await download_queue.put((None, None, index, url_str, str(e)))
|
||||||
|
download_error_count += 1
|
||||||
|
download_done_count += 1
|
||||||
|
finally:
|
||||||
|
# 所有下载任务完成
|
||||||
|
if download_done_count >= total:
|
||||||
|
download_complete.set()
|
||||||
|
|
||||||
# 记录下载开始时间
|
# 启动所有下载任务(并行下载)
|
||||||
download_start_time = time.time()
|
download_start_time = time.time()
|
||||||
download_tasks = [download_image_async(i, url) for i, url in enumerate(urls, 1)]
|
download_tasks = [
|
||||||
downloaded_images = await asyncio.gather(*download_tasks)
|
asyncio.create_task(download_image_async(i, url))
|
||||||
download_time = time.time() - download_start_time
|
for i, url in enumerate(urls, 1)
|
||||||
|
]
|
||||||
|
|
||||||
valid_images = []
|
# 流水线批处理任务:收集队列中的图片,达到batch_size或超时后立即处理
|
||||||
failed_results = {}
|
completed_order = 0
|
||||||
|
pending_batch = []
|
||||||
|
batch_collect_timeout = 0.5 # 批处理收集超时(秒)
|
||||||
|
max_single_batch = batch_size * 2 # 允许最多2倍batch_size用于一次性处理
|
||||||
|
|
||||||
for item in downloaded_images:
|
async def process_pending_batch(force=False):
|
||||||
image, image_size, index, url_str, error = item
|
"""处理待处理的批次"""
|
||||||
if error:
|
nonlocal pending_batch, completed_order, success_count, error_count
|
||||||
failed_results[index] = {
|
nonlocal gpu_inference_time, save_time, batch_count, batch_sizes
|
||||||
|
|
||||||
|
if not pending_batch:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 分离成功和失败的图片
|
||||||
|
valid_items = []
|
||||||
|
failed_items = []
|
||||||
|
|
||||||
|
for item in pending_batch:
|
||||||
|
image, image_size, index, url_str, error = item
|
||||||
|
if error:
|
||||||
|
failed_items.append((index, url_str, error))
|
||||||
|
else:
|
||||||
|
valid_items.append((image, image_size, index, url_str))
|
||||||
|
|
||||||
|
# 先处理下载失败的
|
||||||
|
for index, url_str, error in failed_items:
|
||||||
|
error_count += 1
|
||||||
|
completed_order += 1
|
||||||
|
result = {
|
||||||
"index": index,
|
"index": index,
|
||||||
"total": total,
|
"total": total,
|
||||||
"original_url": url_str,
|
"original_url": url_str,
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"error": error,
|
"error": error,
|
||||||
"message": f"下载失败: {error}"
|
"message": f"下载失败: {error}",
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count,
|
||||||
|
"completed_order": completed_order,
|
||||||
|
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
||||||
}
|
}
|
||||||
else:
|
yield result
|
||||||
valid_images.append((image, image_size, index, url_str))
|
|
||||||
|
if not valid_items:
|
||||||
for index, result in failed_results.items():
|
pending_batch = []
|
||||||
error_count += 1
|
|
||||||
result["success_count"] = success_count
|
|
||||||
result["error_count"] = error_count
|
|
||||||
result["completed_order"] = len(failed_results)
|
|
||||||
result["batch_elapsed"] = round(time.time() - batch_start_time, 2)
|
|
||||||
yield result
|
|
||||||
|
|
||||||
completed_order = len(failed_results)
|
|
||||||
|
|
||||||
# 如果图片数量不太多(<= batch_size * 2),尝试一次性处理所有图片(避免分批,提升并发)
|
|
||||||
# 对于13张图片,batch_size=8,13 <= 16,会尝试一次性处理
|
|
||||||
# 如果显存不足,自动降级到分批处理
|
|
||||||
max_single_batch = batch_size * 2 # 允许最多2倍batch_size
|
|
||||||
use_single_batch = len(valid_images) <= max_single_batch
|
|
||||||
|
|
||||||
if use_single_batch:
|
|
||||||
try:
|
|
||||||
images_with_info = [(img, size, idx) for img, size, idx, _ in valid_images]
|
|
||||||
|
|
||||||
# 记录GPU推理开始时间
|
|
||||||
gpu_start_time = time.time()
|
|
||||||
batch_results = await self.process_batch_images(images_with_info)
|
|
||||||
gpu_inference_time += time.time() - gpu_start_time
|
|
||||||
batch_count += 1
|
|
||||||
batch_sizes.append(len(images_with_info))
|
|
||||||
|
|
||||||
# 并行保存所有图片
|
|
||||||
save_tasks = []
|
|
||||||
result_mapping = {}
|
|
||||||
|
|
||||||
for processed_image, index in batch_results:
|
|
||||||
url_str = next(url for _, _, idx, url in valid_images if idx == index)
|
|
||||||
result_mapping[index] = (processed_image, url_str)
|
|
||||||
|
|
||||||
save_task = loop.run_in_executor(
|
|
||||||
self.executor, self.save_image_to_file, processed_image
|
|
||||||
)
|
|
||||||
save_tasks.append((index, save_task))
|
|
||||||
|
|
||||||
# 记录保存开始时间
|
|
||||||
save_start_time = time.time()
|
|
||||||
save_results = await asyncio.gather(*[task for _, task in save_tasks], return_exceptions=True)
|
|
||||||
save_time += time.time() - save_start_time
|
|
||||||
|
|
||||||
for (index, _), image_url in zip(save_tasks, save_results):
|
|
||||||
if isinstance(image_url, Exception):
|
|
||||||
error_count += 1
|
|
||||||
completed_order += 1
|
|
||||||
result = {
|
|
||||||
"index": index,
|
|
||||||
"total": total,
|
|
||||||
"original_url": result_mapping[index][1],
|
|
||||||
"status": "error",
|
|
||||||
"error": str(image_url),
|
|
||||||
"message": f"保存图片失败: {str(image_url)}",
|
|
||||||
"success_count": success_count,
|
|
||||||
"error_count": error_count,
|
|
||||||
"completed_order": completed_order,
|
|
||||||
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
|
||||||
}
|
|
||||||
yield result
|
|
||||||
else:
|
|
||||||
completed_order += 1
|
|
||||||
success_count += 1
|
|
||||||
result = {
|
|
||||||
"index": index,
|
|
||||||
"total": total,
|
|
||||||
"original_url": result_mapping[index][1],
|
|
||||||
"status": "success",
|
|
||||||
"image_url": image_url,
|
|
||||||
"message": "处理成功",
|
|
||||||
"success_count": success_count,
|
|
||||||
"error_count": error_count,
|
|
||||||
"completed_order": completed_order,
|
|
||||||
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
|
||||||
}
|
|
||||||
yield result
|
|
||||||
except RuntimeError as e:
|
|
||||||
# 检测CUDA OOM错误,自动降级到分批处理
|
|
||||||
error_msg = str(e)
|
|
||||||
if "CUDA out of memory" in error_msg or "out of memory" in error_msg.lower():
|
|
||||||
logger.warning(f"一次性处理显存不足,自动降级到分批处理: {error_msg[:100]}")
|
|
||||||
# 清理显存
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
gc.collect()
|
|
||||||
# 继续执行分批处理逻辑(不return,继续到else分支)
|
|
||||||
use_single_batch = False
|
|
||||||
else:
|
|
||||||
# 其他错误,直接返回
|
|
||||||
logger.error(f"批处理失败: {error_msg}")
|
|
||||||
for _, _, index, url_str in valid_images:
|
|
||||||
completed_order += 1
|
|
||||||
error_count += 1
|
|
||||||
result = {
|
|
||||||
"index": index,
|
|
||||||
"total": total,
|
|
||||||
"original_url": url_str,
|
|
||||||
"status": "error",
|
|
||||||
"error": error_msg,
|
|
||||||
"message": f"批处理失败: {error_msg}",
|
|
||||||
"success_count": success_count,
|
|
||||||
"error_count": error_count,
|
|
||||||
"completed_order": completed_order,
|
|
||||||
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
|
||||||
}
|
|
||||||
yield result
|
|
||||||
print_stats()
|
|
||||||
return
|
|
||||||
except Exception as e:
|
|
||||||
# 其他异常,直接返回错误
|
|
||||||
logger.error(f"批处理失败: {str(e)}")
|
|
||||||
for _, _, index, url_str in valid_images:
|
|
||||||
completed_order += 1
|
|
||||||
error_count += 1
|
|
||||||
result = {
|
|
||||||
"index": index,
|
|
||||||
"total": total,
|
|
||||||
"original_url": url_str,
|
|
||||||
"status": "error",
|
|
||||||
"error": str(e),
|
|
||||||
"message": f"批处理失败: {str(e)}",
|
|
||||||
"success_count": success_count,
|
|
||||||
"error_count": error_count,
|
|
||||||
"completed_order": completed_order,
|
|
||||||
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
|
||||||
}
|
|
||||||
yield result
|
|
||||||
print_stats()
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# 如果一次性处理失败(显存不足)或图片数量太多,使用分批处理
|
# 处理成功的图片
|
||||||
if not use_single_batch:
|
try:
|
||||||
# 多批处理:串行处理批次,但每个批次内部并行保存
|
# 判断是否尝试一次性处理
|
||||||
for batch_start in range(0, len(valid_images), batch_size):
|
use_single_batch = len(valid_items) <= max_single_batch and force
|
||||||
batch_end = min(batch_start + batch_size, len(valid_images))
|
|
||||||
batch_images = valid_images[batch_start:batch_end]
|
|
||||||
|
|
||||||
try:
|
if use_single_batch:
|
||||||
images_with_info = [(img, size, idx) for img, size, idx, _ in batch_images]
|
# 尝试一次性处理所有图片
|
||||||
|
images_with_info = [(img, size, idx) for img, size, idx, _ in valid_items]
|
||||||
|
|
||||||
# 记录GPU推理开始时间
|
|
||||||
gpu_start_time = time.time()
|
gpu_start_time = time.time()
|
||||||
batch_results = await self.process_batch_images(images_with_info)
|
batch_results = await self.process_batch_images(images_with_info)
|
||||||
gpu_inference_time += time.time() - gpu_start_time
|
gpu_inference_time += time.time() - gpu_start_time
|
||||||
batch_count += 1
|
batch_count += 1
|
||||||
batch_sizes.append(len(images_with_info))
|
batch_sizes.append(len(images_with_info))
|
||||||
|
|
||||||
# 并行保存所有图片
|
# 并行保存
|
||||||
save_tasks = []
|
save_tasks = []
|
||||||
result_mapping = {}
|
result_mapping = {}
|
||||||
|
|
||||||
for processed_image, index in batch_results:
|
for processed_image, index in batch_results:
|
||||||
url_str = next(url for _, _, idx, url in batch_images if idx == index)
|
url_str = next(url for _, _, idx, url in valid_items if idx == index)
|
||||||
result_mapping[index] = (processed_image, url_str)
|
result_mapping[index] = (processed_image, url_str)
|
||||||
|
|
||||||
save_task = loop.run_in_executor(
|
save_task = loop.run_in_executor(
|
||||||
self.executor, self.save_image_to_file, processed_image
|
self.executor, self.save_image_to_file, processed_image
|
||||||
)
|
)
|
||||||
save_tasks.append((index, save_task))
|
save_tasks.append((index, save_task))
|
||||||
|
|
||||||
# 记录保存开始时间
|
|
||||||
save_start_time = time.time()
|
save_start_time = time.time()
|
||||||
# 并行执行所有保存任务
|
|
||||||
save_results = await asyncio.gather(*[task for _, task in save_tasks], return_exceptions=True)
|
save_results = await asyncio.gather(*[task for _, task in save_tasks], return_exceptions=True)
|
||||||
save_time += time.time() - save_start_time
|
save_time += time.time() - save_start_time
|
||||||
|
|
||||||
# 按顺序返回结果
|
|
||||||
for (index, _), image_url in zip(save_tasks, save_results):
|
for (index, _), image_url in zip(save_tasks, save_results):
|
||||||
if isinstance(image_url, Exception):
|
if isinstance(image_url, Exception):
|
||||||
error_count += 1
|
error_count += 1
|
||||||
@ -748,25 +660,202 @@ class RmbgService:
|
|||||||
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
||||||
}
|
}
|
||||||
yield result
|
yield result
|
||||||
|
else:
|
||||||
except Exception as e:
|
# 分批处理
|
||||||
logger.error(f"批处理失败: {str(e)}")
|
for batch_start in range(0, len(valid_items), batch_size):
|
||||||
for _, _, index, url_str in batch_images:
|
batch_end = min(batch_start + batch_size, len(valid_items))
|
||||||
completed_order += 1
|
batch_items = valid_items[batch_start:batch_end]
|
||||||
|
|
||||||
|
images_with_info = [(img, size, idx) for img, size, idx, _ in batch_items]
|
||||||
|
|
||||||
|
gpu_start_time = time.time()
|
||||||
|
batch_results = await self.process_batch_images(images_with_info)
|
||||||
|
gpu_inference_time += time.time() - gpu_start_time
|
||||||
|
batch_count += 1
|
||||||
|
batch_sizes.append(len(images_with_info))
|
||||||
|
|
||||||
|
# 并行保存
|
||||||
|
save_tasks = []
|
||||||
|
result_mapping = {}
|
||||||
|
|
||||||
|
for processed_image, index in batch_results:
|
||||||
|
url_str = next(url for _, _, idx, url in batch_items if idx == index)
|
||||||
|
result_mapping[index] = (processed_image, url_str)
|
||||||
|
save_task = loop.run_in_executor(
|
||||||
|
self.executor, self.save_image_to_file, processed_image
|
||||||
|
)
|
||||||
|
save_tasks.append((index, save_task))
|
||||||
|
|
||||||
|
save_start_time = time.time()
|
||||||
|
save_results = await asyncio.gather(*[task for _, task in save_tasks], return_exceptions=True)
|
||||||
|
save_time += time.time() - save_start_time
|
||||||
|
|
||||||
|
for (index, _), image_url in zip(save_tasks, save_results):
|
||||||
|
if isinstance(image_url, Exception):
|
||||||
|
error_count += 1
|
||||||
|
completed_order += 1
|
||||||
|
result = {
|
||||||
|
"index": index,
|
||||||
|
"total": total,
|
||||||
|
"original_url": result_mapping[index][1],
|
||||||
|
"status": "error",
|
||||||
|
"error": str(image_url),
|
||||||
|
"message": f"保存图片失败: {str(image_url)}",
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count,
|
||||||
|
"completed_order": completed_order,
|
||||||
|
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
||||||
|
}
|
||||||
|
yield result
|
||||||
|
else:
|
||||||
|
completed_order += 1
|
||||||
|
success_count += 1
|
||||||
|
result = {
|
||||||
|
"index": index,
|
||||||
|
"total": total,
|
||||||
|
"original_url": result_mapping[index][1],
|
||||||
|
"status": "success",
|
||||||
|
"image_url": image_url,
|
||||||
|
"message": "处理成功",
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count,
|
||||||
|
"completed_order": completed_order,
|
||||||
|
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
||||||
|
}
|
||||||
|
yield result
|
||||||
|
|
||||||
|
except RuntimeError as e:
|
||||||
|
# CUDA OOM错误,降级处理
|
||||||
|
error_msg = str(e)
|
||||||
|
if "CUDA out of memory" in error_msg or "out of memory" in error_msg.lower():
|
||||||
|
logger.warning(f"批处理显存不足,降级处理: {error_msg[:100]}")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
# 降级到单张处理
|
||||||
|
for image, image_size, index, url_str, _ in valid_items:
|
||||||
|
try:
|
||||||
|
processed_image = await self.process_image(image)
|
||||||
|
image_url = await loop.run_in_executor(
|
||||||
|
self.executor, self.save_image_to_file, processed_image
|
||||||
|
)
|
||||||
|
completed_order += 1
|
||||||
|
success_count += 1
|
||||||
|
result = {
|
||||||
|
"index": index,
|
||||||
|
"total": total,
|
||||||
|
"original_url": url_str,
|
||||||
|
"status": "success",
|
||||||
|
"image_url": image_url,
|
||||||
|
"message": "处理成功(降级模式)",
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count,
|
||||||
|
"completed_order": completed_order,
|
||||||
|
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
||||||
|
}
|
||||||
|
yield result
|
||||||
|
except Exception as e2:
|
||||||
|
error_count += 1
|
||||||
|
completed_order += 1
|
||||||
|
result = {
|
||||||
|
"index": index,
|
||||||
|
"total": total,
|
||||||
|
"original_url": url_str,
|
||||||
|
"status": "error",
|
||||||
|
"error": str(e2),
|
||||||
|
"message": f"处理失败: {str(e2)}",
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count,
|
||||||
|
"completed_order": completed_order,
|
||||||
|
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
||||||
|
}
|
||||||
|
yield result
|
||||||
|
else:
|
||||||
|
# 其他错误
|
||||||
|
logger.error(f"批处理失败: {error_msg}")
|
||||||
|
for _, _, index, url_str, _ in valid_items:
|
||||||
error_count += 1
|
error_count += 1
|
||||||
|
completed_order += 1
|
||||||
result = {
|
result = {
|
||||||
"index": index,
|
"index": index,
|
||||||
"total": total,
|
"total": total,
|
||||||
"original_url": url_str,
|
"original_url": url_str,
|
||||||
"status": "error",
|
"status": "error",
|
||||||
"error": str(e),
|
"error": error_msg,
|
||||||
"message": f"批处理失败: {str(e)}",
|
"message": f"批处理失败: {error_msg}",
|
||||||
"success_count": success_count,
|
"success_count": success_count,
|
||||||
"error_count": error_count,
|
"error_count": error_count,
|
||||||
"completed_order": completed_order,
|
"completed_order": completed_order,
|
||||||
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
||||||
}
|
}
|
||||||
yield result
|
yield result
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"批处理失败: {str(e)}")
|
||||||
|
for _, _, index, url_str, _ in valid_items:
|
||||||
|
error_count += 1
|
||||||
|
completed_order += 1
|
||||||
|
result = {
|
||||||
|
"index": index,
|
||||||
|
"total": total,
|
||||||
|
"original_url": url_str,
|
||||||
|
"status": "error",
|
||||||
|
"error": str(e),
|
||||||
|
"message": f"批处理失败: {str(e)}",
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count,
|
||||||
|
"completed_order": completed_order,
|
||||||
|
"batch_elapsed": round(time.time() - batch_start_time, 2)
|
||||||
|
}
|
||||||
|
yield result
|
||||||
|
|
||||||
|
pending_batch = []
|
||||||
|
|
||||||
|
# 流水线处理:收集队列中的图片,达到batch_size或超时后立即处理
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# 等待队列中有新图片,或超时
|
||||||
|
try:
|
||||||
|
item = await asyncio.wait_for(
|
||||||
|
download_queue.get(),
|
||||||
|
timeout=batch_collect_timeout
|
||||||
|
)
|
||||||
|
pending_batch.append(item)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
# 超时,处理当前批次
|
||||||
|
if pending_batch:
|
||||||
|
async for result in process_pending_batch():
|
||||||
|
yield result
|
||||||
|
# 检查是否所有下载都完成
|
||||||
|
if download_complete.is_set():
|
||||||
|
break
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 如果达到batch_size,立即处理
|
||||||
|
if len(pending_batch) >= batch_size:
|
||||||
|
async for result in process_pending_batch():
|
||||||
|
yield result
|
||||||
|
|
||||||
|
# 检查是否所有下载都完成
|
||||||
|
if download_complete.is_set() and download_queue.empty():
|
||||||
|
# 处理剩余的图片
|
||||||
|
if pending_batch:
|
||||||
|
async for result in process_pending_batch(force=True):
|
||||||
|
yield result
|
||||||
|
break
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"流水线处理出错: {str(e)}", exc_info=True)
|
||||||
|
break
|
||||||
|
|
||||||
|
# 等待所有下载任务完成
|
||||||
|
await asyncio.gather(*download_tasks, return_exceptions=True)
|
||||||
|
download_time = time.time() - download_start_time
|
||||||
|
|
||||||
|
# 确保所有结果都已处理
|
||||||
|
if pending_batch:
|
||||||
|
async for result in process_pending_batch(force=True):
|
||||||
|
yield result
|
||||||
|
|
||||||
# 输出性能统计信息
|
# 输出性能统计信息
|
||||||
print_stats()
|
print_stats()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user