# Try-on API routes: a URL-based batch endpoint and a file-upload endpoint.
import json
from typing import List, Optional

from fastapi import APIRouter, File, HTTPException, Request, UploadFile
from fastapi.responses import StreamingResponse

from service import TryonService
from settings import settings
from utils import jingrow_api_verify_and_billing
|
|
# Router for the try-on endpoints; each route path declared below is relative
# to the configured prefix.
router = APIRouter(
    prefix=settings.router_prefix,
)

# One service instance, constructed at import time and shared by the
# endpoints defined below.
service = TryonService()
|
|
@router.post(settings.batch_route)
@jingrow_api_verify_and_billing(api_name=settings.api_name)
async def tryon_batch(data: dict, request: Request):
    """Run try-on for every (T-shirt URL, model URL) pair and stream the results.

    The JSON body must contain ``tshirt_urls`` and ``model_urls`` (both lists)
    and may contain a ``config`` object, which is forwarded unchanged to
    ``service.process_batch``. Each item yielded by the service is written
    back as one line of newline-delimited JSON.

    ``request`` is not used in this body. NOTE(review): presumably it is
    consumed by the billing decorator; confirm before removing.

    Raises:
        HTTPException: 400 if ``tshirt_urls`` or ``model_urls`` is missing or
            not a list, or if ``config`` is present but not a JSON object.
    """
    if "tshirt_urls" not in data:
        raise HTTPException(status_code=400, detail="缺少tshirt_urls参数")

    # A missing key yields None, which fails the list check just like a bad type.
    model_urls = data.get("model_urls")
    if not isinstance(model_urls, list):
        raise HTTPException(status_code=400, detail="缺少model_urls参数或格式错误")

    tshirt_urls = data["tshirt_urls"]
    if not isinstance(tshirt_urls, list):
        raise HTTPException(status_code=400, detail="tshirt_urls必须是URL列表")

    # One "tshirt|model" entry per pair, in model-major order.
    # NOTE(review): written back into the request dict, presumably so the
    # billing decorator can count billable items; confirm before changing.
    data["urls"] = [f"{tshirt_url}|{model_url}" for model_url in model_urls for tshirt_url in tshirt_urls]

    # An absent or null config means "no options". Anything that is not a JSON
    # object is rejected here instead of failing later inside the service.
    config = data.get("config")
    if config is None:
        config = {}
    elif not isinstance(config, dict):
        raise HTTPException(status_code=400, detail="配置参数格式错误")

    async def process_and_stream():
        async for result in service.process_batch(tshirt_urls, model_urls, config):
            yield json.dumps(result) + "\n"

    # NDJSON lets clients consume each result as soon as it is produced.
    return StreamingResponse(
        process_and_stream(),
        media_type="application/x-ndjson",
    )
|
|
|
|
@router.post(settings.file_route)
@jingrow_api_verify_and_billing(api_name=settings.api_name)
async def tryon_file(
    tshirt_files: List[UploadFile] = File(...),
    model_file: UploadFile = File(...),
    config: Optional[str] = None,
    request: Request = None,
):
    """Run try-on for uploaded T-shirt files against one uploaded model file.

    All uploads are read fully into memory and passed to
    ``service.process_files``, together with the parsed ``config`` object.
    Each item the service yields is streamed back as one line of
    newline-delimited JSON.

    Args:
        tshirt_files: One or more uploaded T-shirt files.
        model_file: The uploaded model file.
        config: Optional JSON-object string of options for the service.
            NOTE(review): this is declared without ``Form(...)``, so FastAPI
            reads it from the query string, not the multipart body. Confirm
            this is how clients send it.
        request: The incoming request, injected by FastAPI.

    Raises:
        HTTPException: 400 if ``config`` is not valid JSON or is not a JSON object.
    """
    # Validate the cheap string parameter before pulling every upload into memory.
    config_dict = {}
    if config:
        try:
            config_dict = json.loads(config)
        except json.JSONDecodeError as err:
            raise HTTPException(status_code=400, detail="配置参数格式错误") from err
        if not isinstance(config_dict, dict):
            raise HTTPException(status_code=400, detail="配置参数格式错误")

    tshirt_contents = [await upload.read() for upload in tshirt_files]
    model_content = await model_file.read()

    # Starlette's Request is a Mapping and would be truthy anyway, so check
    # identity explicitly instead of relying on that.
    if request is not None:
        # Overwrite the cached request body with one placeholder per T-shirt file.
        # NOTE(review): presumably this lets the billing decorator count items the
        # same way it counts "urls" in the batch endpoint; confirm against utils.
        placeholders = [f"file_{i}" for i in range(len(tshirt_files))]
        request._body = json.dumps({"urls": placeholders}).encode()

    async def process_and_stream():
        async for result in service.process_files(tshirt_contents, model_content, config_dict):
            yield json.dumps(result) + "\n"

    # NDJSON lets clients consume each result as soon as it is produced.
    return StreamingResponse(
        process_and_stream(),
        media_type="application/x-ndjson",
    )
|