456 lines
18 KiB
Python
456 lines
18 KiB
Python
import asyncio
import base64
import io
import json
import logging
import os
import tempfile
from typing import List
from urllib.parse import urlparse

import aiohttp
from gradio_client import Client, handle_file

from settings import settings
|
||
|
||
# URL handed to the Gradio ``Client`` constructor; read once from the project settings at import time.
tryon_server_url = settings.tryon_server_url
|
||
|
||
class TryonService:
    """Virtual try-on service that forwards images to a Gradio ``/tryon`` endpoint.

    The ``process_*`` methods are async generators that stream one progress
    event (a plain ``dict``) per garment, carrying running success and error
    counters so a caller can relay progress while the batch is still running.
    """

    # Default configuration; caller-supplied values override these per request.
    DEFAULT_CONFIG = {
        'tryon_marker': "_tryon",
        'tryon_target_marker': "tshirt",
        'tryon_models_dir': "/files/models",
        'denoise_steps': 20,
        'seed': 42,
        'is_crop': False,
        'output_format': 'png'
    }

    # Inputs smaller than this many bytes are rejected before calling the server.
    _MIN_IMAGE_SIZE = 1024
    # Downloads smaller than this many bytes are treated as "not an image".
    _MIN_DOWNLOAD_SIZE = 100
    # String spellings that mean "off" when a flag arrives as text.
    _FALSE_STRINGS = frozenset({"", "0", "false", "no", "off", "none", "null"})

    _log = logging.getLogger(__name__)

    def __init__(self):
        """Create the service; the Gradio client is connected lazily on first use."""
        self.client = None

    def get_gradio_client(self):
        """Return the cached Gradio client, creating it on first use.

        Returns:
            The client, or ``None`` when the connection attempt failed. A
            failure is logged and not cached, so the next call retries.
        """
        if self.client is None:
            try:
                self.client = Client(tryon_server_url)
            except Exception:
                self._log.exception("Failed to create Gradio client for %s", tryon_server_url)
                self.client = None
        return self.client

    @classmethod
    def _to_bool(cls, value):
        """Interpret *value* as a flag; textual "false"/"0"/"no"/"off" count as False."""
        if isinstance(value, str):
            return value.strip().lower() not in cls._FALSE_STRINGS
        try:
            return bool(value)
        except (ValueError, TypeError):
            return False

    def _convert_config_types(self, config):
        """Coerce known configuration values to the types the try-on API expects.

        ``denoise_steps`` and ``seed`` become ``int`` (falling back to the value
        in ``DEFAULT_CONFIG`` when they cannot be parsed); ``is_crop`` becomes
        ``bool``. Unknown keys are passed through unchanged.

        Args:
            config: Mapping of raw configuration values, or a falsy value.

        Returns:
            A new dict with converted values; ``{}`` when *config* is falsy.
        """
        if not config:
            return {}

        converted = {}
        for key, value in config.items():
            if key in ('denoise_steps', 'seed'):
                try:
                    converted[key] = int(value)
                except (ValueError, TypeError):
                    converted[key] = self.DEFAULT_CONFIG[key]
            elif key == 'is_crop':
                # bool("false") is True, so textual flags need explicit parsing.
                converted[key] = self._to_bool(value)
            else:
                converted[key] = value
        return converted

    def _resolve_config(self, config):
        """Return ``DEFAULT_CONFIG`` overlaid with the type-converted *config* (which may be None)."""
        return {**self.DEFAULT_CONFIG, **self._convert_config_types(config)}

    @staticmethod
    def _write_temp_image(directory, payload):
        """Write *payload* to a new ``.png`` file inside *directory* and return its path."""
        with tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir=directory) as handle:
            handle.write(payload)
            return handle.name

    @staticmethod
    def _run_tryon(client, model_path, garment_path, config):
        """Call the blocking ``/tryon`` endpoint once and return the first output image as BytesIO."""
        result = client.predict(
            {"background": handle_file(model_path), "layers": [], "composite": None},
            garm_img=handle_file(garment_path),
            garment_des="",
            is_checked=True,
            is_checked_crop=config.get('is_crop', False),
            denoise_steps=config.get('denoise_steps', 20),
            seed=config.get('seed', 42),
            api_name="/tryon"
        )

        if not result or not isinstance(result, tuple) or len(result) < 1:
            raise RuntimeError("虚拟试穿服务返回了无效的结果格式")

        output_path = result[0]  # the first returned file is used as the result image
        if not os.path.exists(output_path):
            raise RuntimeError(f"输出文件不存在: {output_path}")

        with open(output_path, 'rb') as f:
            return io.BytesIO(f.read())

    async def generate_virtual_tryon(self, tshirt_image_io: List[io.BytesIO], model_image_io: io.BytesIO, config=None):
        """Generate try-on images for one model photo and one or more garments.

        Args:
            tshirt_image_io: Garment images, one ``BytesIO`` per garment.
            model_image_io: The model (person) image.
            config: Optional overrides for ``DEFAULT_CONFIG``.

        Returns:
            When exactly one garment is given: a ``BytesIO`` with the result,
            or ``None`` if that garment failed. Otherwise a list with one
            entry (``BytesIO`` or ``None``) per garment, in input order.

        Raises:
            ValueError: An input image is smaller than 1 KiB.
            RuntimeError: The Gradio client could not be created.
        """
        config = self._resolve_config(config)

        garment_payloads = [buf.getvalue() for buf in tshirt_image_io]
        for payload in garment_payloads:
            if len(payload) < self._MIN_IMAGE_SIZE:
                raise ValueError(f"T恤图片太小,可能不是有效图片,大小: {len(payload)} 字节")

        model_payload = model_image_io.getvalue()
        if len(model_payload) < self._MIN_IMAGE_SIZE:
            raise ValueError(f"模特图片太小,可能不是有效图片,大小: {len(model_payload)} 字节")

        client = self.get_gradio_client()
        if client is None:
            raise RuntimeError("Gradio API服务不可用,无法进行虚拟试穿")

        loop = asyncio.get_running_loop()
        results = []
        # The temporary directory removes every file written below when it exits.
        with tempfile.TemporaryDirectory() as temp_dir:
            garment_paths = [self._write_temp_image(temp_dir, payload) for payload in garment_payloads]
            model_path = self._write_temp_image(temp_dir, model_payload)

            for garment_path in garment_paths:
                try:
                    # client.predict blocks until the server answers; run it in the
                    # default executor so the event loop stays responsive.
                    outcome = await loop.run_in_executor(
                        None, self._run_tryon, client, model_path, garment_path, config
                    )
                except Exception:
                    # Best effort per garment: record the failure and keep going.
                    self._log.exception("Virtual try-on failed for %s", garment_path)
                    outcome = None
                results.append(outcome)

        return results[0] if len(results) == 1 else results

    def is_valid_url(self, url):
        """Return True when *url* parses with both a scheme and a network location."""
        try:
            parsed = urlparse(url)
            return all([parsed.scheme, parsed.netloc])
        except (AttributeError, TypeError, ValueError):
            # Non-string input or a malformed URL.
            return False

    @staticmethod
    def _looks_like_image(content):
        """Check the leading magic bytes for PNG, JPEG, GIF or WEBP."""
        header = content[:12]
        return (
            header.startswith((b'\x89PNG', b'\xff\xd8\xff', b'GIF8'))
            or (header.startswith(b'RIFF') and b'WEBP' in header)
        )

    async def download_image(self, url):
        """Download an image over HTTP.

        Args:
            url: Address of the image.

        Returns:
            A ``BytesIO`` with the body, or ``None`` when the request fails,
            the status is not 200, the body is shorter than 100 bytes, or the
            response neither declares an image Content-Type nor starts with a
            known image signature.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url) as response:
                    if response.status != 200:
                        return None
                    content = await response.read()
                    content_type = response.headers.get('Content-Type', '')
        except Exception:
            # Callers treat every download problem as "no image"; keep the cause in the log.
            self._log.warning("Failed to download image from %s", url, exc_info=True)
            return None

        if len(content) < self._MIN_DOWNLOAD_SIZE:
            return None
        if 'image' not in content_type.lower() and not self._looks_like_image(content):
            return None
        return io.BytesIO(content)

    @staticmethod
    def _encode_data_uri(image_io, output_format):
        """Return the bytes of *image_io* as a ``data:image/<format>;base64,...`` URI."""
        encoded = base64.b64encode(image_io.getvalue()).decode('ascii')
        return f"data:image/{output_format};base64,{encoded}"

    async def _render_data_uri(self, tshirt_io, model_io, config):
        """Run one try-on and return its data URI, or None when generation produced nothing."""
        result = await self.generate_virtual_tryon([tshirt_io], model_io, config)
        if result is None:
            return None
        return self._encode_data_uri(result, config.get('output_format', 'png'))

    async def _tryon_from_url(self, tshirt_url, model_io, config):
        """Download one garment and try it on; return ``(data_uri, None)`` or ``(None, error_message)``."""
        try:
            tshirt_io = await self.download_image(tshirt_url)
            if tshirt_io is None:
                return None, f"无法下载T恤图片: {tshirt_url}"
            data = await self._render_data_uri(tshirt_io, model_io, config)
        except Exception as e:
            return None, str(e)
        if data is None:
            return None, f"处理T恤图片失败: {tshirt_url}"
        return data, None

    @staticmethod
    def _event(index, total, success_count, error_count, *, data=None, message=None, **source):
        """Build one progress event; *source* carries the optional model_url / tshirt_url keys."""
        event = {"index": index, "total": total, **source}
        if data is not None:
            event["status"] = "success"
            event["data"] = data
        else:
            event["status"] = "error"
            event["message"] = message
        event["success_count"] = success_count
        event["error_count"] = error_count
        return event

    async def process_urls(self, tshirt_urls: List[str], model_url: str, config=None):
        """Try several garment URLs on one model image, streaming the results.

        Args:
            tshirt_urls: Garment image URLs.
            model_url: Model image URL.
            config: Optional overrides for ``DEFAULT_CONFIG``.

        Yields:
            One event dict per garment. If the model image cannot be
            downloaded, a single error event is yielded and iteration stops.
        """
        config = self._resolve_config(config)
        total = len(tshirt_urls)
        success_count = 0
        error_count = 0

        model_io = await self.download_image(model_url)
        if model_io is None:
            yield {
                "status": "error",
                "message": f"无法下载模特图片: {model_url}",
                "success_count": success_count,
                "error_count": error_count + 1,
                "total": total
            }
            return

        for i, tshirt_url in enumerate(tshirt_urls, 1):
            data, message = await self._tryon_from_url(tshirt_url, model_io, config)
            if data is not None:
                success_count += 1
            else:
                error_count += 1
            yield self._event(i, total, success_count, error_count,
                              data=data, message=message, tshirt_url=tshirt_url)

            # Hand control back to the event loop between garments.
            await asyncio.sleep(0)

    async def process_files(self, tshirt_contents: List[bytes], model_content: bytes, config=None):
        """Try several uploaded garment images on one model image, streaming the results.

        Args:
            tshirt_contents: Raw bytes of each garment image.
            model_content: Raw bytes of the model image.
            config: Optional overrides for ``DEFAULT_CONFIG``.

        Yields:
            One event dict per garment.
        """
        config = self._resolve_config(config)
        total = len(tshirt_contents)
        success_count = 0
        error_count = 0

        model_io = io.BytesIO(model_content)

        for i, content in enumerate(tshirt_contents, 1):
            try:
                data = await self._render_data_uri(io.BytesIO(content), model_io, config)
                message = None if data is not None else "处理T恤图片失败"
            except Exception as e:
                data, message = None, str(e)

            if data is not None:
                success_count += 1
            else:
                error_count += 1
            yield self._event(i, total, success_count, error_count, data=data, message=message)

            # Hand control back to the event loop between garments.
            await asyncio.sleep(0)

    async def process_batch(self, tshirt_urls: List[str], model_urls: List[str], config=None):
        """Try every garment URL on every model URL, streaming the results.

        Args:
            tshirt_urls: Garment image URLs.
            model_urls: Model image URLs.
            config: Optional overrides for ``DEFAULT_CONFIG``.

        Yields:
            One event dict per (model, garment) combination, models in the
            outer loop. When a model image cannot be downloaded, an error
            event is yielded for each garment of that model.
        """
        config = self._resolve_config(config)
        total = len(tshirt_urls) * len(model_urls)
        success_count = 0
        error_count = 0
        current_index = 0

        for model_url in model_urls:
            # Only the model download is guarded here, so a failure can never
            # be reported after some of this model's garments were already emitted.
            try:
                model_io = await self.download_image(model_url)
                model_error = None if model_io is not None else f"无法下载模特图片: {model_url}"
            except Exception as e:
                model_io, model_error = None, str(e)

            if model_error is not None:
                error_count += len(tshirt_urls)
                for tshirt_url in tshirt_urls:
                    current_index += 1
                    yield self._event(current_index, total, success_count, error_count,
                                      message=model_error,
                                      model_url=model_url, tshirt_url=tshirt_url)
                continue

            for tshirt_url in tshirt_urls:
                current_index += 1
                data, message = await self._tryon_from_url(tshirt_url, model_io, config)
                if data is not None:
                    success_count += 1
                else:
                    error_count += 1
                yield self._event(current_index, total, success_count, error_count,
                                  data=data, message=message,
                                  model_url=model_url, tshirt_url=tshirt_url)

                # Hand control back to the event loop between garments.
                await asyncio.sleep(0)

    def cleanup(self):
        """Drop the cached Gradio client so the next request reconnects."""
        self.client = None
|