From 36f299403f3ae54aca233ff88514e3f750ed0613 Mon Sep 17 00:00:00 2001 From: jingrow Date: Tue, 16 Dec 2025 10:46:06 +0000 Subject: [PATCH] =?UTF-8?q?rmbg=E5=90=8C=E4=B8=80=E6=89=B9=E6=AC=A1?= =?UTF-8?q?=E5=86=85=E6=94=B9=E4=B8=BA=E6=B5=81=E5=BC=8F=E8=BF=94=E5=9B=9E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/rmbg/service.py | 147 +++++++++++++++++++++++-------------------- 1 file changed, 79 insertions(+), 68 deletions(-) diff --git a/apps/rmbg/service.py b/apps/rmbg/service.py index 90bd33c..043e7d7 100644 --- a/apps/rmbg/service.py +++ b/apps/rmbg/service.py @@ -607,79 +607,90 @@ class RmbgService: local_batch_items = valid_items[local_batch_start:local_batch_end] # 本地批次内并发调用 process_image,让全局队列有机会凑大 batch 并利用多 GPU - tasks = [ - self.process_image(image) - for image, _, _, _ in local_batch_items - ] - batch_results = await asyncio.gather(*tasks, return_exceptions=True) + # 改为按完成顺序流式返回,避免等待最慢的那张 + tasks = [] + task_meta = {} + for image, _, index, url_str in local_batch_items: + t = asyncio.create_task(self.process_image(image)) + tasks.append(t) + task_meta[t] = (index, url_str) + + # 使用 wait 循环而不是 as_completed,避免未等待的协程残留 + pending_tasks = set(tasks) + while pending_tasks: + done, pending_tasks = await asyncio.wait( + pending_tasks, return_when=asyncio.FIRST_COMPLETED + ) + for t in done: + index, url_str = task_meta[t] + try: + result_data = await t + except Exception as e: + error_count += 1 + completed_order += 1 + result = { + "index": index, + "total": total, + "original_url": url_str, + "status": "error", + "error": str(e), + "message": f"处理失败: {str(e)}", + "success_count": success_count, + "error_count": error_count, + "completed_order": completed_order, + "batch_elapsed": round(time.time() - batch_start_time, 2) + } + yield result + continue + + if isinstance(result_data, dict): + # 队列模式下 process_image 通常直接返回 {status, image_url} + status = result_data.get("status", "success") + image_url = result_data.get("image_url") + message = result_data.get("message", "处理成功" if status == "success" else "处理失败") + error_msg = result_data.get("error") + else: + # 兼容非 dict 返回:手动保存图片 + image_url = await loop.run_in_executor( + self.executor, self.save_image_to_file, result_data + ) + status = "success" + message = "处理成功" + error_msg = None - for (image, _, index, url_str), result_data in zip(local_batch_items, batch_results): - if isinstance(result_data, Exception): - error_count += 1 completed_order += 1 - result = { - "index": index, - "total": total, - "original_url": url_str, - "status": "error", - "error": str(result_data), - "message": f"处理失败: {str(result_data)}", - "success_count": success_count, - "error_count": error_count, - "completed_order": completed_order, - "batch_elapsed": round(time.time() - batch_start_time, 2) - } + if status == "success" and image_url: + success_count += 1 + result = { + "index": index, + "total": total, + "original_url": url_str, + "status": "success", + "image_url": image_url, + "message": message, + "success_count": success_count, + "error_count": error_count, + "completed_order": completed_order, + "batch_elapsed": round(time.time() - batch_start_time, 2) + } + else: + error_count += 1 + result = { + "index": index, + "total": total, + "original_url": url_str, + "status": "error", + "error": error_msg or "处理失败", + "message": message, + "success_count": success_count, + "error_count": error_count, + "completed_order": completed_order, + "batch_elapsed": round(time.time() - batch_start_time, 2) + } yield result - continue - - if isinstance(result_data, dict): - # 队列模式下 process_image 通常直接返回 {status, image_url} - status = result_data.get("status", "success") - image_url = result_data.get("image_url") - message = result_data.get("message", "处理成功" if status == "success" else "处理失败") - error_msg = result_data.get("error") - else: - # 兼容非 dict 返回:手动保存图片 - image_url = await loop.run_in_executor( - self.executor, self.save_image_to_file, result_data - ) - status = "success" - message = "处理成功" - error_msg = None - - completed_order += 1 - if status == "success" and image_url: - success_count += 1 - result = { - "index": index, - "total": total, - "original_url": url_str, - "status": "success", - "image_url": image_url, - "message": message, - "success_count": success_count, - "error_count": error_count, - "completed_order": completed_order, - "batch_elapsed": round(time.time() - batch_start_time, 2) - } - else: - error_count += 1 - result = { - "index": index, - "total": total, - "original_url": url_str, - "status": "error", - "error": error_msg or "处理失败", - "message": message, - "success_count": success_count, - "error_count": error_count, - "completed_order": completed_order, - "batch_elapsed": round(time.time() - batch_start_time, 2) - } - yield result except Exception as e: logger.error(f"批处理失败: {str(e)}") - for _, _, index, url_str, _ in valid_items: + for _, _, index, url_str in valid_items: error_count += 1 completed_order += 1 result = {