# japi/apps/tryon/service.py
# Last modified: 2025-05-12 02:39:56 +08:00 (456 lines, 18 KiB, Python)
#
# Note: the source viewer flagged this file as containing ambiguous Unicode
# characters (characters that might be confused with other characters). They
# appear to be in the Chinese comments and messages; review before
# normalizing any runtime strings.

import asyncio
import base64
import io
import json
import logging
import os
import tempfile
from typing import List
from urllib.parse import urlparse

import aiohttp
from gradio_client import Client, handle_file

from settings import settings
# Base URL of the Gradio try-on server, read once from application settings at import time.
tryon_server_url = getattr(settings, "tryon_server_url")
class TryonService:
    """Virtual try-on service backed by a remote Gradio ``/tryon`` endpoint.

    The Gradio client is created lazily from the module-level
    ``tryon_server_url`` and cached on the instance.
    """

    # Default configuration; caller-supplied values override these.
    DEFAULT_CONFIG = {
        'tryon_marker': "_tryon",
        'tryon_target_marker': "tshirt",
        'tryon_models_dir': "/files/models",
        'denoise_steps': 20,
        'seed': 42,
        'is_crop': False,
        'output_format': 'png'
    }

    # String spellings accepted for boolean config values (e.g. form/query params),
    # where plain bool() would turn "false" into True.
    _TRUE_STRINGS = frozenset({'1', 'true', 'yes', 'y', 'on'})
    _FALSE_STRINGS = frozenset({'0', 'false', 'no', 'n', 'off', ''})

    _logger = logging.getLogger(__name__)

    def __init__(self):
        """Create the service; the Gradio client is connected on first use."""
        self.client = None

    def get_gradio_client(self):
        """Return the cached Gradio client, creating it on first call.

        Returns:
            The ``Client`` instance, or ``None`` if connecting failed (the
            failure is logged and retried on the next call).
        """
        if self.client is None:
            try:
                self.client = Client(tryon_server_url)
            except Exception:
                self._logger.exception(
                    "Failed to connect to Gradio try-on server at %s", tryon_server_url)
                self.client = None
        return self.client

    @staticmethod
    def _to_int(value, default):
        """Convert ``value`` to int, falling back to ``default`` if it cannot be converted."""
        try:
            return int(value)
        except (ValueError, TypeError, OverflowError):
            return default

    @classmethod
    def _to_bool(cls, value, default):
        """Convert ``value`` to bool, parsing common string spellings.

        Unrecognized strings yield ``default``; non-strings use ``bool()``.
        """
        if isinstance(value, str):
            text = value.strip().lower()
            if text in cls._TRUE_STRINGS:
                return True
            if text in cls._FALSE_STRINGS:
                return False
            return default
        try:
            return bool(value)
        except (ValueError, TypeError):
            return default

    def _convert_config_types(self, config):
        """Coerce known config values to their expected types.

        ``denoise_steps`` and ``seed`` become ints and ``is_crop`` becomes a bool.
        Invalid values fall back to ``DEFAULT_CONFIG``. Other keys pass through.

        Returns:
            A new dict; ``{}`` when ``config`` is empty or ``None``.
        """
        if not config:
            return {}
        converted = {}
        for key, value in config.items():
            if key in ('denoise_steps', 'seed'):
                converted[key] = self._to_int(value, self.DEFAULT_CONFIG[key])
            elif key == 'is_crop':
                converted[key] = self._to_bool(value, self.DEFAULT_CONFIG['is_crop'])
            else:
                converted[key] = value
        return converted

    @staticmethod
    def _write_temp_image(directory, name, data):
        """Write ``data`` to ``directory/name`` and return the file path."""
        path = os.path.join(directory, name)
        with open(path, 'wb') as fh:
            fh.write(data)
        return path

    @staticmethod
    def _predict_sync(client, model_path, garment_path, config):
        """Blocking call to the Gradio ``/tryon`` endpoint.

        Returns:
            The bytes of the first output file returned by the endpoint.

        Raises:
            RuntimeError: if the response is malformed or the output file is missing.
        """
        result = client.predict(
            {"background": handle_file(model_path), "layers": [], "composite": None},
            garm_img=handle_file(garment_path),
            garment_des="",
            is_checked=True,
            is_checked_crop=config.get('is_crop', False),
            denoise_steps=config.get('denoise_steps', 20),
            seed=config.get('seed', 42),
            api_name="/tryon"
        )
        if not result or not isinstance(result, (tuple, list)):
            raise RuntimeError("虚拟试穿服务返回了无效的结果格式")
        output_path = result[0]  # the first output image is used as the result
        if not os.path.exists(output_path):
            raise RuntimeError(f"输出文件不存在: {output_path}")
        with open(output_path, 'rb') as f:
            return f.read()

    async def generate_virtual_tryon(self, tshirt_image_io: List[io.BytesIO], model_image_io: io.BytesIO, config=None):
        """Generate virtual try-on images of each garment on the model.

        Args:
            tshirt_image_io: list of garment (T-shirt) image buffers.
            model_image_io: model (person) image buffer.
            config: optional overrides for ``DEFAULT_CONFIG``.

        Returns:
            A single ``BytesIO`` (or ``None`` on failure) when exactly one
            garment is given; otherwise a list with one entry per garment.

        Raises:
            ValueError: if any input image is smaller than 1 KiB.
            RuntimeError: if the Gradio service is unavailable.
        """
        config = {**self.DEFAULT_CONFIG, **self._convert_config_types(config or {})}

        # Reject tiny payloads that are almost certainly not real images.
        min_image_size = 1024
        for tshirt_io in tshirt_image_io:
            size = len(tshirt_io.getvalue())
            if size < min_image_size:
                raise ValueError(f"T恤图片太小可能不是有效图片大小: {size} 字节")
        model_bytes = model_image_io.getvalue()
        if len(model_bytes) < min_image_size:
            raise ValueError(f"模特图片太小,可能不是有效图片,大小: {len(model_bytes)} 字节")

        client = self.get_gradio_client()
        if client is None:
            raise RuntimeError("Gradio API服务不可用无法进行虚拟试穿")

        loop = asyncio.get_running_loop()
        results = []
        # The temporary directory (and every file in it) is removed on exit.
        with tempfile.TemporaryDirectory() as temp_dir:
            model_path = self._write_temp_image(temp_dir, "model.png", model_bytes)
            garment_paths = [
                self._write_temp_image(temp_dir, f"tshirt_{idx}.png", tshirt_io.getvalue())
                for idx, tshirt_io in enumerate(tshirt_image_io)
            ]
            for garment_path in garment_paths:
                try:
                    # client.predict blocks on network I/O; run it in the default
                    # executor so the event loop keeps serving other requests.
                    data = await loop.run_in_executor(
                        None, self._predict_sync, client, model_path, garment_path, config)
                    results.append(io.BytesIO(data))
                except Exception:
                    # Best effort per garment: record the failure and continue.
                    self._logger.exception("Virtual try-on failed for %s", garment_path)
                    results.append(None)
        return results[0] if len(results) == 1 else results

    def is_valid_url(self, url):
        """Return True if ``url`` parses with both a scheme and a network location."""
        try:
            parsed = urlparse(url)
        except Exception:
            return False
        return bool(parsed.scheme and parsed.netloc)

    @staticmethod
    def _looks_like_image(content):
        """Check magic bytes for PNG, JPEG, GIF or WEBP."""
        header = content[:12]
        return (header.startswith((b'\x89PNG', b'\xff\xd8\xff', b'GIF8'))
                or (header.startswith(b'RIFF') and b'WEBP' in header))

    async def download_image(self, url):
        """Download an image from ``url``.

        Returns:
            A ``BytesIO`` with the body, or ``None`` if the request fails, the
            status is not 200, the body is under 100 bytes, or it is not an
            image (by Content-Type or magic bytes).
        """
        # NOTE(review): the URL is fetched as-is with no host allow-list or size
        # cap; if it comes from end users this is an SSRF / memory risk.
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url) as response:
                    if response.status != 200:
                        return None
                    content = await response.read()
                    content_type = response.headers.get('Content-Type', '')
        except Exception:
            self._logger.warning("Failed to download image from %s", url, exc_info=True)
            return None
        if len(content) < 100:  # too small to be a real image
            return None
        if 'image' not in content_type.lower() and not self._looks_like_image(content):
            return None
        return io.BytesIO(content)

    @classmethod
    def _to_data_url(cls, result, config):
        """Encode a result buffer as a ``data:image/<fmt>;base64,...`` URL.

        ``config`` may be ``None``; the format falls back to ``DEFAULT_CONFIG``.
        """
        fmt = (config or {}).get('output_format') or cls.DEFAULT_CONFIG['output_format']
        result.seek(0)
        encoded = base64.b64encode(result.read()).decode('ascii')
        return f"data:image/{fmt};base64,{encoded}"

    async def _tryon_outcome(self, tshirt_io, model_io, config, failure_message):
        """Run one try-on and return the status part of a progress event.

        Never raises: exceptions become ``{"status": "error", "message": str(e)}``.
        """
        try:
            result = await self.generate_virtual_tryon([tshirt_io], model_io, config)
            if result is None:
                return {"status": "error", "message": failure_message}
            return {"status": "success", "data": self._to_data_url(result, config)}
        except Exception as e:
            return {"status": "error", "message": str(e)}

    async def _download_and_tryon(self, tshirt_url, model_io, config):
        """Download one garment by URL and try it on; never raises."""
        try:
            tshirt_io = await self.download_image(tshirt_url)
        except Exception as e:
            return {"status": "error", "message": str(e)}
        if tshirt_io is None:
            return {"status": "error", "message": f"无法下载T恤图片: {tshirt_url}"}
        return await self._tryon_outcome(
            tshirt_io, model_io, config, f"处理T恤图片失败: {tshirt_url}")

    async def process_urls(self, tshirt_urls: List[str], model_url: str, config=None):
        """Try each garment URL on one model URL, streaming one event per garment.

        Args:
            tshirt_urls: garment (T-shirt) image URLs.
            model_url: model image URL.
            config: optional config overrides.

        Yields:
            Progress dicts with status, data or message, and running counts.
        """
        total = len(tshirt_urls)
        success_count = 0
        error_count = 0
        model_io = await self.download_image(model_url)
        if model_io is None:
            yield {
                "status": "error",
                "message": f"无法下载模特图片: {model_url}",
                "success_count": success_count,
                "error_count": error_count + 1,
                "total": total
            }
            return
        for i, tshirt_url in enumerate(tshirt_urls, 1):
            outcome = await self._download_and_tryon(tshirt_url, model_io, config)
            if outcome["status"] == "success":
                success_count += 1
            else:
                error_count += 1
            yield {
                "index": i,
                "total": total,
                "tshirt_url": tshirt_url,
                **outcome,
                "success_count": success_count,
                "error_count": error_count
            }
            # Yield control to the event loop between items.
            await asyncio.sleep(0)

    async def process_files(self, tshirt_contents: List[bytes], model_content: bytes, config=None):
        """Try each uploaded garment image on one model image, streaming results.

        Args:
            tshirt_contents: raw garment image bytes.
            model_content: raw model image bytes.
            config: optional config overrides.

        Yields:
            Progress dicts with status, data or message, and running counts.
        """
        total = len(tshirt_contents)
        success_count = 0
        error_count = 0
        model_io = io.BytesIO(model_content)
        for i, content in enumerate(tshirt_contents, 1):
            outcome = await self._tryon_outcome(
                io.BytesIO(content), model_io, config, "处理T恤图片失败")
            if outcome["status"] == "success":
                success_count += 1
            else:
                error_count += 1
            yield {
                "index": i,
                "total": total,
                **outcome,
                "success_count": success_count,
                "error_count": error_count
            }
            # Yield control to the event loop between items.
            await asyncio.sleep(0)

    async def process_batch(self, tshirt_urls: List[str], model_urls: List[str], config=None):
        """Try every garment URL on every model URL, streaming one event per pair.

        Args:
            tshirt_urls: garment image URLs.
            model_urls: model image URLs.
            config: optional config overrides.

        Yields:
            Progress dicts per (model, garment) pair with running counts.
        """
        total = len(tshirt_urls) * len(model_urls)
        success_count = 0
        error_count = 0
        current_index = 0
        for model_url in model_urls:
            try:
                model_io = await self.download_image(model_url)
                model_failure = f"无法下载模特图片: {model_url}"
            except Exception as e:
                model_io, model_failure = None, str(e)
            if model_io is None:
                # Every pairing with this model fails; count them all up front.
                error_count += len(tshirt_urls)
                for tshirt_url in tshirt_urls:
                    current_index += 1
                    yield {
                        "index": current_index,
                        "total": total,
                        "model_url": model_url,
                        "tshirt_url": tshirt_url,
                        "status": "error",
                        "message": model_failure,
                        "success_count": success_count,
                        "error_count": error_count
                    }
                continue
            for tshirt_url in tshirt_urls:
                current_index += 1
                outcome = await self._download_and_tryon(tshirt_url, model_io, config)
                if outcome["status"] == "success":
                    success_count += 1
                else:
                    error_count += 1
                yield {
                    "index": current_index,
                    "total": total,
                    "model_url": model_url,
                    "tshirt_url": tshirt_url,
                    **outcome,
                    "success_count": success_count,
                    "error_count": error_count
                }
                # Yield control to the event loop between items.
                await asyncio.sleep(0)

    def cleanup(self):
        """Drop the cached Gradio client so the next call reconnects."""
        self.client = None