import asyncio
import base64
import functools
import io
import json
import logging
import os
import tempfile
from typing import Dict, List, Optional
from urllib.parse import urlparse

import aiohttp
from gradio_client import Client, handle_file

from settings import settings

logger = logging.getLogger(__name__)

tryon_server_url = settings.tryon_server_url


class TryonService:
    """Virtual try-on service backed by a remote Gradio ``/tryon`` endpoint."""

    # Default configuration; caller-supplied values override these.
    DEFAULT_CONFIG = {
        'tryon_marker': "_tryon",
        'tryon_target_marker': "tshirt",
        'tryon_models_dir': "/files/models",
        'denoise_steps': 20,
        'seed': 42,
        'is_crop': False,
        # NOTE: only used as the MIME subtype of the returned data URL; the
        # image bytes produced by the server are not converted.
        'output_format': 'png'
    }

    # String spellings accepted for boolean config values (case-insensitive).
    _TRUE_STRINGS = ('1', 'true', 'yes', 'on')
    _FALSE_STRINGS = ('0', 'false', 'no', 'off', '')

    def __init__(self):
        """Initialise the service; the Gradio client is created lazily."""
        self.client = None

    def get_gradio_client(self):
        """Return the cached Gradio client, creating it on first use.

        Returns:
            The ``Client`` instance, or ``None`` if the connection failed
            (the failure is logged; a later call will retry).
        """
        if self.client is None:
            try:
                self.client = Client(tryon_server_url)
            except Exception:
                logger.exception("Failed to create Gradio client for %s", tryon_server_url)
                self.client = None
        return self.client

    @classmethod
    def _parse_bool(cls, value, default=False):
        """Interpret ``value`` as a boolean.

        Strings are matched against common spellings ("true"/"false", "1"/"0",
        "yes"/"no", "on"/"off") because ``bool("false")`` is ``True``.
        Unrecognised strings yield ``default``; non-strings use ``bool()``.
        """
        if isinstance(value, str):
            normalized = value.strip().lower()
            if normalized in cls._TRUE_STRINGS:
                return True
            if normalized in cls._FALSE_STRINGS:
                return False
            return default
        return bool(value)

    def _convert_config_types(self, config):
        """Coerce known config values to the types the try-on API expects.

        ``denoise_steps`` and ``seed`` become ``int`` and ``is_crop`` becomes
        ``bool``. Unparseable values fall back to ``DEFAULT_CONFIG``. All other
        keys are passed through unchanged.

        Returns:
            A new dict (``{}`` when ``config`` is falsy).
        """
        if not config:
            return {}

        converted = {}
        for key, value in config.items():
            if key in ('denoise_steps', 'seed'):
                try:
                    converted[key] = int(value)
                except (ValueError, TypeError, OverflowError):
                    converted[key] = self.DEFAULT_CONFIG[key]
            elif key == 'is_crop':
                converted[key] = self._parse_bool(value, self.DEFAULT_CONFIG['is_crop'])
            else:
                converted[key] = value
        return converted

    @staticmethod
    def _predict_tryon(client, model_path, garment_path, config):
        """Run one synchronous ``/tryon`` prediction and return the output image bytes.

        Raises:
            RuntimeError: if the server response is malformed or the output
                file it points to does not exist.
        """
        result = client.predict(
            {"background": handle_file(model_path), "layers": [], "composite": None},
            garm_img=handle_file(garment_path),
            garment_des="",
            is_checked=True,
            is_checked_crop=config.get('is_crop', False),
            denoise_steps=config.get('denoise_steps', 20),
            seed=config.get('seed', 42),
            api_name="/tryon"
        )

        if not result or not isinstance(result, tuple) or len(result) < 1:
            raise RuntimeError("虚拟试穿服务返回了无效的结果格式")

        output_path = result[0]  # the first output is used as the try-on image
        if not os.path.exists(output_path):
            raise RuntimeError(f"输出文件不存在: {output_path}")

        with open(output_path, 'rb') as f:
            return f.read()

    async def generate_virtual_tryon(self, tshirt_image_io: List[io.BytesIO], model_image_io: io.BytesIO, config=None):
        """Generate virtual try-on images for one model and one or more garments.

        Args:
            tshirt_image_io: list of BytesIO objects holding T-shirt images.
            model_image_io: BytesIO holding the model (person) image.
            config: optional overrides for ``DEFAULT_CONFIG``.

        Returns:
            For a single garment: a ``BytesIO`` with the result, or ``None`` if
            that garment failed. Otherwise a list with one ``BytesIO``/``None``
            entry per garment. Per-garment failures are logged, not raised.

        Raises:
            ValueError: if any input image is smaller than 1 KB.
            RuntimeError: if the Gradio client is unavailable.
        """
        # Convert config types, then merge over the defaults.
        config = {**self.DEFAULT_CONFIG, **self._convert_config_types(config or {})}

        # Reject inputs too small to plausibly be images.
        min_image_size = 1024  # 1 KB
        for tshirt_io in tshirt_image_io:
            tshirt_size = len(tshirt_io.getvalue())
            if tshirt_size < min_image_size:
                raise ValueError(f"T恤图片太小,可能不是有效图片,大小: {tshirt_size} 字节")
        model_size = len(model_image_io.getvalue())
        if model_size < min_image_size:
            raise ValueError(f"模特图片太小,可能不是有效图片,大小: {model_size} 字节")

        client = self.get_gradio_client()
        if client is None:
            raise RuntimeError("Gradio API服务不可用,无法进行虚拟试穿")

        loop = asyncio.get_running_loop()
        results: List[Optional[io.BytesIO]] = []

        # TemporaryDirectory removes every file written below on exit.
        with tempfile.TemporaryDirectory() as temp_dir:
            model_path = os.path.join(temp_dir, "model.png")
            with open(model_path, 'wb') as f:
                f.write(model_image_io.getvalue())

            garment_paths = []
            for idx, tshirt_io in enumerate(tshirt_image_io):
                garment_path = os.path.join(temp_dir, f"tshirt_{idx}.png")
                with open(garment_path, 'wb') as f:
                    f.write(tshirt_io.getvalue())
                garment_paths.append(garment_path)

            for garment_path in garment_paths:
                try:
                    # client.predict is a synchronous call. Run it in the default
                    # executor so the event loop keeps serving other requests
                    # while the remote model runs.
                    output_bytes = await loop.run_in_executor(
                        None,
                        functools.partial(self._predict_tryon, client, model_path, garment_path, config),
                    )
                    results.append(io.BytesIO(output_bytes))
                except Exception:
                    # Best effort: one failed garment must not abort the rest.
                    logger.exception("Virtual try-on failed for garment %s", garment_path)
                    results.append(None)

        return results[0] if len(results) == 1 else results

    def is_valid_url(self, url):
        """Return True if ``url`` parses with both a scheme and a network location."""
        try:
            parsed = urlparse(url)
            return all([parsed.scheme, parsed.netloc])
        except Exception:
            return False

    @staticmethod
    def _looks_like_image(content: bytes) -> bool:
        """Sniff PNG/JPEG/GIF/WEBP magic numbers at the start of ``content``."""
        header = content[:12]
        return (
            header.startswith((b'\x89PNG', b'\xff\xd8\xff', b'GIF8'))
            or (header.startswith(b'RIFF') and b'WEBP' in header)
        )

    async def download_image(self, url):
        """Download an image.

        Returns:
            A ``BytesIO`` with the body, or ``None`` when the request fails,
            the status is not 200, the body is under 100 bytes, or the
            response is neither labelled ``image/*`` nor has a known image
            magic number.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url) as response:
                    if response.status != 200:
                        return None
                    content = await response.read()
                    content_type = response.headers.get('Content-Type', '')
        except Exception:
            logger.warning("Failed to download image from %s", url, exc_info=True)
            return None

        # Too small to plausibly be a real image.
        if len(content) < 100:
            return None

        # Fall back to magic-number sniffing when the Content-Type is not image/*.
        if 'image' not in content_type.lower() and not self._looks_like_image(content):
            return None

        return io.BytesIO(content)

    @staticmethod
    def _encode_data_url(result: io.BytesIO, config) -> str:
        """Encode ``result`` as a ``data:image/<fmt>;base64,...`` URL.

        ``config`` may be ``None``. ``output_format`` defaults to ``png``.
        """
        output_format = (config or {}).get('output_format', 'png')
        result.seek(0)
        encoded = base64.b64encode(result.read()).decode('ascii')
        return f"data:image/{output_format};base64,{encoded}"

    async def _run_single_tryon(self, tshirt_io, model_io, config, failure_message):
        """Try on a single garment and build the status fields of a progress event.

        Returns:
            ``(fields, succeeded)``. ``fields`` is either
            ``{"status": "success", "data": <data URL>}`` or
            ``{"status": "error", "message": failure_message}``. Exceptions
            from ``generate_virtual_tryon`` propagate to the caller.
        """
        result = await self.generate_virtual_tryon([tshirt_io], model_io, config)
        if result is None:
            return {"status": "error", "message": failure_message}, False
        # Encode before the caller updates its counters, so an encoding error
        # cannot leave an item counted as both success and error.
        return {"status": "success", "data": self._encode_data_url(result, config)}, True

    async def process_urls(self, tshirt_urls: List[str], model_url: str, config=None):
        """Try several T-shirt URLs on one model URL, streaming one event per T-shirt.

        Args:
            tshirt_urls: list of T-shirt image URLs.
            model_url: model image URL.
            config: optional config overrides.

        Yields:
            A dict per T-shirt with index/total/tshirt_url/status, ``data``
            (base64 data URL) or ``message``, and running success/error
            counts. If the model image cannot be downloaded, a single error
            event is yielded and the generator stops.
        """
        total = len(tshirt_urls)
        success_count = 0
        error_count = 0

        model_io = await self.download_image(model_url)
        if model_io is None:
            yield {
                "status": "error",
                "message": f"无法下载模特图片: {model_url}",
                "success_count": success_count,
                "error_count": error_count + 1,
                "total": total
            }
            return

        for i, tshirt_url in enumerate(tshirt_urls, 1):
            try:
                tshirt_io = await self.download_image(tshirt_url)
                if tshirt_io is None:
                    fields, succeeded = {"status": "error", "message": f"无法下载T恤图片: {tshirt_url}"}, False
                else:
                    fields, succeeded = await self._run_single_tryon(
                        tshirt_io, model_io, config, f"处理T恤图片失败: {tshirt_url}"
                    )
            except Exception as e:
                fields, succeeded = {"status": "error", "message": str(e)}, False

            if succeeded:
                success_count += 1
            else:
                error_count += 1

            yield {
                "index": i,
                "total": total,
                "tshirt_url": tshirt_url,
                **fields,
                "success_count": success_count,
                "error_count": error_count
            }

            # Yield control to the event loop between items.
            await asyncio.sleep(0)

    async def process_files(self, tshirt_contents: List[bytes], model_content: bytes, config=None):
        """Try several uploaded T-shirt images on one model image, streaming one event per T-shirt.

        Args:
            tshirt_contents: raw bytes of each T-shirt image.
            model_content: raw bytes of the model image.
            config: optional config overrides.

        Yields:
            A dict per T-shirt with index/total/status, ``data`` (base64 data
            URL) or ``message``, and running success/error counts.
        """
        total = len(tshirt_contents)
        success_count = 0
        error_count = 0
        model_io = io.BytesIO(model_content)

        for i, content in enumerate(tshirt_contents, 1):
            try:
                fields, succeeded = await self._run_single_tryon(
                    io.BytesIO(content), model_io, config, "处理T恤图片失败"
                )
            except Exception as e:
                fields, succeeded = {"status": "error", "message": str(e)}, False

            if succeeded:
                success_count += 1
            else:
                error_count += 1

            yield {
                "index": i,
                "total": total,
                **fields,
                "success_count": success_count,
                "error_count": error_count
            }

            # Yield control to the event loop between items.
            await asyncio.sleep(0)

    async def process_batch(self, tshirt_urls: List[str], model_urls: List[str], config=None):
        """Try every T-shirt URL on every model URL, streaming one event per pair.

        Args:
            tshirt_urls: list of T-shirt image URLs.
            model_urls: list of model image URLs.
            config: optional config overrides.

        Yields:
            A dict per (model, T-shirt) pair with index/total/model_url/
            tshirt_url/status, ``data`` (base64 data URL) or ``message``, and
            running success/error counts. Exactly
            ``len(tshirt_urls) * len(model_urls)`` events are produced.
        """
        total = len(tshirt_urls) * len(model_urls)
        success_count = 0
        error_count = 0
        current_index = 0
        # Successful T-shirt downloads are reused across models so each URL is
        # fetched once rather than once per model. Failed downloads are not
        # cached and are retried for the next model.
        tshirt_cache: Dict[str, io.BytesIO] = {}

        for model_url in model_urls:
            try:
                model_io = await self.download_image(model_url)
                model_error = None if model_io is not None else f"无法下载模特图片: {model_url}"
            except Exception as e:
                model_io, model_error = None, str(e)

            for tshirt_url in tshirt_urls:
                current_index += 1

                if model_error is not None:
                    fields, succeeded = {"status": "error", "message": model_error}, False
                else:
                    try:
                        tshirt_io = tshirt_cache.get(tshirt_url)
                        if tshirt_io is None:
                            tshirt_io = await self.download_image(tshirt_url)
                        if tshirt_io is None:
                            fields, succeeded = {"status": "error", "message": f"无法下载T恤图片: {tshirt_url}"}, False
                        else:
                            tshirt_cache[tshirt_url] = tshirt_io
                            fields, succeeded = await self._run_single_tryon(
                                tshirt_io, model_io, config, f"处理T恤图片失败: {tshirt_url}"
                            )
                    except Exception as e:
                        fields, succeeded = {"status": "error", "message": str(e)}, False

                if succeeded:
                    success_count += 1
                else:
                    error_count += 1

                yield {
                    "index": current_index,
                    "total": total,
                    "model_url": model_url,
                    "tshirt_url": tshirt_url,
                    **fields,
                    "success_count": success_count,
                    "error_count": error_count
                }

                # Yield control to the event loop between items.
                await asyncio.sleep(0)

    def cleanup(self):
        """Drop the cached Gradio client."""
        self.client = None