japi 微服务版
This commit is contained in:
commit
4be051e459
15
.gitignore
vendored
Normal file
15
.gitignore
vendored
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
# 忽略名为 test 的文件夹
|
||||||
|
test/
|
||||||
|
.cursor/
|
||||||
|
|
||||||
|
|
||||||
|
# 忽略所有 文件夹
|
||||||
|
**/www/files/
|
||||||
|
**/output/
|
||||||
|
**/__pycache__/
|
||||||
|
|
||||||
|
*.py[cod]
|
||||||
|
|
||||||
|
.env
|
||||||
|
|
||||||
|
|
||||||
0
apps/add_bg/__init__.py
Normal file
0
apps/add_bg/__init__.py
Normal file
80
apps/add_bg/api.py
Normal file
80
apps/add_bg/api.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
from fastapi import APIRouter, UploadFile, File, HTTPException, Request
|
||||||
|
from fastapi.responses import StreamingResponse, JSONResponse
|
||||||
|
from service import AddBgService
|
||||||
|
from utils import jingrow_api_verify_and_billing
|
||||||
|
from settings import settings
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
router = APIRouter(prefix=settings.router_prefix)
|
||||||
|
service = AddBgService()
|
||||||
|
|
||||||
|
@router.post(settings.batch_route)
|
||||||
|
@jingrow_api_verify_and_billing(api_name=settings.api_name)
|
||||||
|
async def add_background_batch(data: dict, request: Request):
|
||||||
|
"""
|
||||||
|
批量处理多个URL图片
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 包含图片URL列表和配置参数的字典
|
||||||
|
request: FastAPI 请求对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
流式响应,包含每个图片的处理结果
|
||||||
|
"""
|
||||||
|
if "urls" not in data:
|
||||||
|
raise HTTPException(status_code=400, detail="缺少urls参数")
|
||||||
|
|
||||||
|
config = data.get("config", {})
|
||||||
|
|
||||||
|
async def process_and_stream():
|
||||||
|
total = len(data["urls"])
|
||||||
|
for index, url in enumerate(data["urls"], 1):
|
||||||
|
try:
|
||||||
|
result = await service.add_background(url, config)
|
||||||
|
result.update({
|
||||||
|
"index": index,
|
||||||
|
"total": total,
|
||||||
|
"original_url": url
|
||||||
|
})
|
||||||
|
yield json.dumps(result) + "\n"
|
||||||
|
except Exception as e:
|
||||||
|
yield json.dumps({
|
||||||
|
"status": "error",
|
||||||
|
"message": str(e),
|
||||||
|
"index": index,
|
||||||
|
"total": total,
|
||||||
|
"original_url": url
|
||||||
|
}) + "\n"
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
process_and_stream(),
|
||||||
|
media_type="application/x-ndjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.post(settings.file_route)
|
||||||
|
@jingrow_api_verify_and_billing(api_name=settings.api_name)
|
||||||
|
async def add_background_file(file: UploadFile = File(...), config: str = None, request: Request = None):
|
||||||
|
"""
|
||||||
|
为上传的文件添加背景
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file: 上传的图片文件
|
||||||
|
config: JSON格式的配置参数
|
||||||
|
request: FastAPI 请求对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理后的图片内容
|
||||||
|
"""
|
||||||
|
content = await file.read()
|
||||||
|
|
||||||
|
# 解析配置参数
|
||||||
|
config_dict = {}
|
||||||
|
if config:
|
||||||
|
try:
|
||||||
|
config_dict = json.loads(config)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise HTTPException(status_code=400, detail="配置参数格式错误")
|
||||||
|
|
||||||
|
result = await service.add_background_from_file(content, config_dict)
|
||||||
|
return result
|
||||||
21
apps/add_bg/app.py
Normal file
21
apps/add_bg/app.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
from fastapi import FastAPI
|
||||||
|
from settings import settings
|
||||||
|
from api import router
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="Add Background",
|
||||||
|
description="图片添加背景颜色",
|
||||||
|
version="1.0.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 注册路由
|
||||||
|
app.include_router(router)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
uvicorn.run(
|
||||||
|
"app:app",
|
||||||
|
host=settings.host,
|
||||||
|
port=settings.port,
|
||||||
|
reload=settings.debug
|
||||||
|
)
|
||||||
311
apps/add_bg/service.py
Normal file
311
apps/add_bg/service.py
Normal file
@ -0,0 +1,311 @@
|
|||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import io
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image, ImageFilter, ImageDraw, ImageChops
|
||||||
|
import uuid
|
||||||
|
import urllib.request
|
||||||
|
import urllib3
|
||||||
|
import requests
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Optional
|
||||||
|
from colorthief import ColorThief
|
||||||
|
import tempfile
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
import torch
|
||||||
|
import time
|
||||||
|
import warnings
|
||||||
|
import gc
|
||||||
|
import base64
|
||||||
|
import asyncio
|
||||||
|
import multiprocessing as mp
|
||||||
|
from concurrent.futures import ProcessPoolExecutor
|
||||||
|
import colorsys
|
||||||
|
|
||||||
|
# 关闭不必要的警告
|
||||||
|
warnings.filterwarnings("ignore", category=UserWarning)
|
||||||
|
warnings.filterwarnings("ignore", category=FutureWarning)
|
||||||
|
|
||||||
|
class AddBgService:
|
||||||
|
# 默认配置
|
||||||
|
DEFAULT_CONFIG = {
|
||||||
|
'added_background_marker': "_added_background",
|
||||||
|
'enable_texture_effect': False,
|
||||||
|
'texture_type': 'noise',
|
||||||
|
'texture_blend_mode': 'multiply',
|
||||||
|
'enable_depth_of_field': False,
|
||||||
|
'blur_intensity': 15,
|
||||||
|
'output_format': 'png',
|
||||||
|
'enable_lighting_effect': False,
|
||||||
|
'light_intensity': 0.1,
|
||||||
|
'light_position': [0.5, 0.3],
|
||||||
|
'light_radius_ratio': [0.4, 0.25],
|
||||||
|
'light_angle': 45,
|
||||||
|
'light_blur': 91,
|
||||||
|
'light_shape': 'ellipse',
|
||||||
|
'alpha_background': 0.8,
|
||||||
|
'design_rotation': 0
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""初始化添加背景服务"""
|
||||||
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
def apply_lighting_effect(self, image, config):
|
||||||
|
"""应用光照效果"""
|
||||||
|
light_intensity = config.get('light_intensity', self.DEFAULT_CONFIG['light_intensity'])
|
||||||
|
light_position = config.get('light_position', self.DEFAULT_CONFIG['light_position'])
|
||||||
|
light_radius_ratio = config.get('light_radius_ratio', self.DEFAULT_CONFIG['light_radius_ratio'])
|
||||||
|
light_angle = config.get('light_angle', self.DEFAULT_CONFIG['light_angle'])
|
||||||
|
light_blur = config.get('light_blur', self.DEFAULT_CONFIG['light_blur'])
|
||||||
|
light_shape = config.get('light_shape', self.DEFAULT_CONFIG['light_shape'])
|
||||||
|
|
||||||
|
height, width = image.shape[:2]
|
||||||
|
light_position = (int(light_position[0] * width), int(light_position[1] * height))
|
||||||
|
light_radius = (int(light_radius_ratio[0] * width), int(light_radius_ratio[1] * height))
|
||||||
|
mask = np.zeros((height, width), dtype=np.uint8)
|
||||||
|
|
||||||
|
if light_shape == 'ellipse':
|
||||||
|
cv2.ellipse(mask, light_position, light_radius, light_angle, 0, 360, 255, -1)
|
||||||
|
elif light_shape == 'circle':
|
||||||
|
cv2.circle(mask, light_position, min(light_radius), 255, -1)
|
||||||
|
elif light_shape == 'rect':
|
||||||
|
rect_top_left = (light_position[0] - light_radius[0] // 2, light_position[1] - light_radius[1] // 2)
|
||||||
|
rect_bottom_right = (light_position[0] + light_radius[0] // 2, light_position[1] + light_radius[1] // 2)
|
||||||
|
cv2.rectangle(mask, rect_top_left, rect_bottom_right, 255, -1)
|
||||||
|
|
||||||
|
mask = cv2.GaussianBlur(mask, (light_blur, light_blur), 0)
|
||||||
|
mask = mask.astype(np.float32) / 255
|
||||||
|
result = image.astype(np.float32)
|
||||||
|
for i in range(3):
|
||||||
|
result[:, :, i] = result[:, :, i] * (1 - light_intensity + mask * light_intensity)
|
||||||
|
return result.astype(np.uint8)
|
||||||
|
|
||||||
|
def generate_noise_texture(self, size, intensity=64):
|
||||||
|
"""生成噪点纹理"""
|
||||||
|
noise = np.random.randint(0, intensity, (size, size, 4), dtype=np.uint8)
|
||||||
|
noise[..., 3] = 255 # 设置 alpha 通道为不透明
|
||||||
|
return Image.fromarray(noise)
|
||||||
|
|
||||||
|
def generate_line_texture(self, size, line_width=4, spacing=20, color=(0, 0, 0, 255)):
|
||||||
|
"""生成线条纹理"""
|
||||||
|
texture = Image.new('RGBA', (size, size), (255, 255, 255, 0))
|
||||||
|
draw = ImageDraw.Draw(texture)
|
||||||
|
for y in range(0, size, spacing):
|
||||||
|
draw.line([(0, y), (size, y)], fill=color, width=line_width)
|
||||||
|
for x in range(0, size, spacing):
|
||||||
|
draw.line([(x, 0), (x, size)], fill=color, width=line_width)
|
||||||
|
return texture
|
||||||
|
|
||||||
|
def add_texture(self, image, config):
|
||||||
|
"""添加纹理效果"""
|
||||||
|
texture_type = config.get('texture_type', self.DEFAULT_CONFIG['texture_type'])
|
||||||
|
texture_blend_mode = config.get('texture_blend_mode', self.DEFAULT_CONFIG['texture_blend_mode'])
|
||||||
|
|
||||||
|
if texture_type == 'noise':
|
||||||
|
texture = self.generate_noise_texture(image.size[0])
|
||||||
|
elif texture_type == 'lines':
|
||||||
|
texture = self.generate_line_texture(image.size[0])
|
||||||
|
else:
|
||||||
|
return image
|
||||||
|
|
||||||
|
if texture_blend_mode == 'multiply':
|
||||||
|
return ImageChops.multiply(image, texture)
|
||||||
|
elif texture_blend_mode == 'overlay':
|
||||||
|
return ImageChops.overlay(image, texture)
|
||||||
|
else:
|
||||||
|
return image
|
||||||
|
|
||||||
|
def calculate_dominant_color(self, image):
|
||||||
|
"""计算图像的主色调"""
|
||||||
|
try:
|
||||||
|
# 将PIL Image转换为BytesIO对象
|
||||||
|
img_byte_arr = io.BytesIO()
|
||||||
|
image.save(img_byte_arr, format='PNG')
|
||||||
|
img_byte_arr.seek(0)
|
||||||
|
|
||||||
|
# 使用ColorThief计算主色调
|
||||||
|
color_thief = ColorThief(img_byte_arr)
|
||||||
|
dominant_color = color_thief.get_color(quality=1)
|
||||||
|
return dominant_color
|
||||||
|
except Exception as e:
|
||||||
|
print(f"计算主色调失败: {str(e)}")
|
||||||
|
# 如果计算失败,返回默认的白色
|
||||||
|
return (255, 255, 255)
|
||||||
|
|
||||||
|
def rgb_to_hex(self, rgb):
|
||||||
|
"""RGB转HEX"""
|
||||||
|
return '#{:02x}{:02x}{:02x}'.format(rgb[0], rgb[1], rgb[2])
|
||||||
|
|
||||||
|
def calculate_light_color(self, dominant_color, target_lightness=0.92):
|
||||||
|
r, g, b = [x / 255.0 for x in dominant_color]
|
||||||
|
h, l, s = colorsys.rgb_to_hls(r, g, b)
|
||||||
|
l = target_lightness
|
||||||
|
r2, g2, b2 = colorsys.hls_to_rgb(h, l, s)
|
||||||
|
return (int(r2 * 255), int(g2 * 255), int(b2 * 255))
|
||||||
|
|
||||||
|
def calculate_monochrome_color(self, dominant_color, alpha):
|
||||||
|
# alpha参数可忽略
|
||||||
|
return self.calculate_light_color(dominant_color, target_lightness=0.92)
|
||||||
|
|
||||||
|
def apply_depth_of_field(self, background_image, blur_intensity):
|
||||||
|
"""应用景深效果"""
|
||||||
|
background_image_pil = Image.fromarray(background_image).convert("RGBA")
|
||||||
|
blurred_background = background_image_pil.filter(ImageFilter.GaussianBlur(blur_intensity))
|
||||||
|
return np.array(blurred_background)
|
||||||
|
|
||||||
|
def rotate_image_with_transparency(self, image, angle):
|
||||||
|
"""旋转图像"""
|
||||||
|
rotated_image = image.rotate(angle, expand=True)
|
||||||
|
return rotated_image
|
||||||
|
|
||||||
|
def process_image(self, image, config):
|
||||||
|
"""处理图像,添加背景"""
|
||||||
|
try:
|
||||||
|
# 合并默认配置和用户配置
|
||||||
|
config = {**self.DEFAULT_CONFIG, **config}
|
||||||
|
|
||||||
|
# 计算主色并设置背景颜色
|
||||||
|
try:
|
||||||
|
dominant_color = self.calculate_dominant_color(image)
|
||||||
|
background_color = self.calculate_monochrome_color(dominant_color, config['alpha_background'])
|
||||||
|
except Exception as e:
|
||||||
|
# 使用默认的白色背景
|
||||||
|
background_color = (255, 255, 255)
|
||||||
|
|
||||||
|
# 创建背景图像
|
||||||
|
background_image = np.full((image.height, image.width, 4), (*background_color, 255), dtype=np.uint8)
|
||||||
|
|
||||||
|
# 应用景深效果
|
||||||
|
if config['enable_depth_of_field']:
|
||||||
|
background_image = self.apply_depth_of_field(background_image, config['blur_intensity'])
|
||||||
|
|
||||||
|
# 应用纹理效果
|
||||||
|
if config['enable_texture_effect']:
|
||||||
|
background_image_pil = Image.fromarray(background_image).convert("RGBA")
|
||||||
|
background_image_pil = self.add_texture(background_image_pil, config)
|
||||||
|
background_image = np.array(background_image_pil)
|
||||||
|
|
||||||
|
# 将前景图像转换为numpy数组
|
||||||
|
foreground = np.array(image)
|
||||||
|
|
||||||
|
# 旋转前景图像
|
||||||
|
if config['design_rotation'] != 0:
|
||||||
|
foreground_pil = Image.fromarray(foreground)
|
||||||
|
foreground_pil = self.rotate_image_with_transparency(foreground_pil, config['design_rotation'])
|
||||||
|
foreground = np.array(foreground_pil)
|
||||||
|
|
||||||
|
# 合并前景和背景
|
||||||
|
alpha = foreground[:, :, 3] / 255.0
|
||||||
|
for c in range(3):
|
||||||
|
background_image[:, :, c] = background_image[:, :, c] * (1 - alpha) + foreground[:, :, c] * alpha
|
||||||
|
|
||||||
|
# 确保最终图像不透明
|
||||||
|
background_image[:, :, 3] = 255
|
||||||
|
|
||||||
|
# 应用光照效果
|
||||||
|
if config['enable_lighting_effect']:
|
||||||
|
background_image = self.apply_lighting_effect(background_image, config)
|
||||||
|
|
||||||
|
# 转换回PIL图像
|
||||||
|
return Image.fromarray(background_image)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"处理图片失败: {str(e)}")
|
||||||
|
|
||||||
|
def image_to_base64(self, image, config):
|
||||||
|
"""将图片转换为base64格式"""
|
||||||
|
try:
|
||||||
|
output_format = config.get('output_format', self.DEFAULT_CONFIG['output_format'])
|
||||||
|
buffered = io.BytesIO()
|
||||||
|
image.save(buffered, format=output_format.upper())
|
||||||
|
img_str = base64.b64encode(buffered.getvalue()).decode()
|
||||||
|
return f"data:image/{output_format.lower()};base64,{img_str}"
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"转换图片为base64失败: {str(e)}")
|
||||||
|
|
||||||
|
async def add_background(self, image_path, config=None):
|
||||||
|
"""为图片添加背景"""
|
||||||
|
try:
|
||||||
|
# 下载图片
|
||||||
|
if self.is_valid_url(image_path):
|
||||||
|
image_content = self.download_image(image_path)
|
||||||
|
image = Image.open(io.BytesIO(image_content))
|
||||||
|
else:
|
||||||
|
image = Image.open(image_path)
|
||||||
|
|
||||||
|
# 确保图片是RGBA模式
|
||||||
|
if image.mode != 'RGBA':
|
||||||
|
image = image.convert('RGBA')
|
||||||
|
|
||||||
|
# 处理图片
|
||||||
|
processed_image = self.process_image(image, config or {})
|
||||||
|
|
||||||
|
# 转换为base64
|
||||||
|
result = self.image_to_base64(processed_image, config or {})
|
||||||
|
|
||||||
|
return {
|
||||||
|
'status': 'success',
|
||||||
|
'image_content': result
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
'status': 'error',
|
||||||
|
'message': f"处理图片失败: {str(e)}"
|
||||||
|
}
|
||||||
|
|
||||||
|
async def add_background_from_file(self, file_content, config=None):
|
||||||
|
"""
|
||||||
|
从上传的文件内容添加背景
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_content: 上传的文件内容
|
||||||
|
config: 配置参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理后的图像内容
|
||||||
|
"""
|
||||||
|
if config is None:
|
||||||
|
config = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 从文件内容创建PIL Image对象
|
||||||
|
image = Image.open(io.BytesIO(file_content)).convert("RGBA")
|
||||||
|
image_with_bg = self.process_image(image, config)
|
||||||
|
|
||||||
|
# 转换为base64
|
||||||
|
image_content = self.image_to_base64(image_with_bg, config)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"image_content": image_content
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"处理图片失败: {e}")
|
||||||
|
|
||||||
|
def is_valid_url(self, url):
|
||||||
|
"""验证URL是否有效"""
|
||||||
|
try:
|
||||||
|
result = urlparse(url)
|
||||||
|
return all([result.scheme, result.netloc])
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def download_image(self, url):
|
||||||
|
"""下载图片并返回内容"""
|
||||||
|
try:
|
||||||
|
response = requests.get(url, timeout=30)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.content
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"下载图片失败: {str(e)}")
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""清理资源"""
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
gc.collect()
|
||||||
32
apps/add_bg/settings.py
Normal file
32
apps/add_bg/settings.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
from typing import Optional
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
# Japi Server 配置
|
||||||
|
host: str = "0.0.0.0"
|
||||||
|
port: int = 8105
|
||||||
|
debug: bool = False
|
||||||
|
|
||||||
|
# API路由配置
|
||||||
|
router_prefix: str = "/add_bg"
|
||||||
|
file_route: str = "/file"
|
||||||
|
batch_route: str = "/batch"
|
||||||
|
api_name: str = "add_background"
|
||||||
|
|
||||||
|
upload_url: str = "http://173.255.202.68/imgurl/upload"
|
||||||
|
|
||||||
|
# Jingrow Jcloud API 配置
|
||||||
|
jingrow_api_url: str = "https://cloud.jingrow.com"
|
||||||
|
jingrow_api_key: Optional[str] = None
|
||||||
|
jingrow_api_secret: Optional[str] = None
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
env_file = ".env"
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def get_settings() -> Settings:
|
||||||
|
return Settings()
|
||||||
|
|
||||||
|
# 创建全局配置实例
|
||||||
|
settings = get_settings()
|
||||||
146
apps/add_bg/utils.py
Normal file
146
apps/add_bg/utils.py
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
import aiohttp
|
||||||
|
from functools import wraps
|
||||||
|
from fastapi import HTTPException
|
||||||
|
import os
|
||||||
|
from typing import Callable, Any, Dict, Optional, Tuple
|
||||||
|
from settings import settings
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
import json
|
||||||
|
|
||||||
|
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||||
|
"""验证API密钥和团队余额"""
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{settings.jingrow_api_url}/api/method/jcloud.api.account.verify_api_credentials_and_balance",
|
||||||
|
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||||
|
json={"api_key": api_key, "api_secret": api_secret, "api_name": api_name}
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise HTTPException(status_code=500, detail="验证服务暂时不可用")
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
|
||||||
|
if not result.get("success"):
|
||||||
|
raise HTTPException(status_code=401, detail=result.get("message", "验证失败"))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||||
|
|
||||||
|
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
||||||
|
"""从Jingrow平台扣除API使用费"""
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{settings.jingrow_api_url}/api/method/jcloud.api.account.deduct_api_usage_fee",
|
||||||
|
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||||
|
json={
|
||||||
|
"api_key": api_key,
|
||||||
|
"api_secret": api_secret,
|
||||||
|
"api_name": api_name,
|
||||||
|
"usage_count": usage_count
|
||||||
|
}
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
||||||
|
|
||||||
|
def get_token_from_request(request) -> str:
|
||||||
|
"""从请求中获取访问令牌"""
|
||||||
|
if not request:
|
||||||
|
raise HTTPException(status_code=400, detail="无法获取请求信息")
|
||||||
|
|
||||||
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
if not auth_header or not auth_header.startswith("token "):
|
||||||
|
raise HTTPException(status_code=401, detail="无效的Authorization头格式")
|
||||||
|
|
||||||
|
token = auth_header[6:]
|
||||||
|
if ":" not in token:
|
||||||
|
raise HTTPException(status_code=401, detail="无效的令牌格式")
|
||||||
|
|
||||||
|
return token
|
||||||
|
|
||||||
|
def jingrow_api_verify_and_billing(api_name: str):
|
||||||
|
"""Jingrow API 验证装饰器(带余额检查和扣费)"""
|
||||||
|
def decorator(func: Callable) -> Callable:
|
||||||
|
@wraps(func)
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
try:
|
||||||
|
request = kwargs.get('request')
|
||||||
|
if not request:
|
||||||
|
raise HTTPException(status_code=400, detail="无法获取请求信息")
|
||||||
|
|
||||||
|
token = get_token_from_request(request)
|
||||||
|
api_key, api_secret = token.split(":", 1)
|
||||||
|
|
||||||
|
verify_result = await verify_api_credentials_and_balance(api_key, api_secret, api_name)
|
||||||
|
if not verify_result.get("success"):
|
||||||
|
raise HTTPException(status_code=401, detail=verify_result.get("message", "验证失败"))
|
||||||
|
|
||||||
|
result = await func(*args, **kwargs)
|
||||||
|
|
||||||
|
usage_count = 1
|
||||||
|
try:
|
||||||
|
body_data = await request.json()
|
||||||
|
if isinstance(body_data, dict):
|
||||||
|
for key in ["items", "urls", "images", "files"]:
|
||||||
|
if key in body_data and isinstance(body_data[key], list):
|
||||||
|
usage_count = len(body_data[key])
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if isinstance(result, StreamingResponse):
|
||||||
|
original_generator = result.body_iterator
|
||||||
|
success_count = 0
|
||||||
|
|
||||||
|
async def wrapped_generator():
|
||||||
|
nonlocal success_count
|
||||||
|
async for chunk in original_generator:
|
||||||
|
try:
|
||||||
|
data = json.loads(chunk)
|
||||||
|
if isinstance(data, dict) and data.get("status") == "success":
|
||||||
|
success_count += 1
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
if success_count > 0:
|
||||||
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
wrapped_generator(),
|
||||||
|
media_type=result.media_type,
|
||||||
|
headers=result.headers
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(result, dict) and result.get("success") is True:
|
||||||
|
actual_usage_count = result.get("successful_count", usage_count)
|
||||||
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
||||||
|
return result
|
||||||
|
|
||||||
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"API验证过程发生异常: {str(e)}")
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
||||||
0
apps/jart/__init__.py
Normal file
0
apps/jart/__init__.py
Normal file
57
apps/jart/api.py
Normal file
57
apps/jart/api.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from service import TxtImgService
|
||||||
|
from utils import jingrow_api_verify_and_billing
|
||||||
|
from settings import settings
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
router = APIRouter(prefix=settings.router_prefix)
|
||||||
|
service = TxtImgService()
|
||||||
|
|
||||||
|
@router.post(settings.generate_route)
|
||||||
|
@jingrow_api_verify_and_billing(api_name=settings.api_name)
|
||||||
|
async def generate_image(data: dict, request: Request):
|
||||||
|
"""
|
||||||
|
根据文本提示生成图像
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 包含文本提示和配置参数的字典
|
||||||
|
request: FastAPI 请求对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
生成的图像内容
|
||||||
|
"""
|
||||||
|
if "prompt" not in data:
|
||||||
|
raise HTTPException(status_code=400, detail="缺少prompt参数")
|
||||||
|
|
||||||
|
config = data.get("config", {})
|
||||||
|
result = await service.generate_image(data["prompt"], config)
|
||||||
|
return result
|
||||||
|
|
||||||
|
@router.post(settings.batch_route)
|
||||||
|
@jingrow_api_verify_and_billing(api_name=settings.api_name)
|
||||||
|
async def generate_image_batch(data: dict, request: Request):
|
||||||
|
"""
|
||||||
|
批量处理多个文本提示
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 包含文本提示列表和配置参数的字典
|
||||||
|
request: FastAPI 请求对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
流式响应,包含每个提示的处理结果
|
||||||
|
"""
|
||||||
|
if "prompts" not in data:
|
||||||
|
raise HTTPException(status_code=400, detail="缺少prompts参数")
|
||||||
|
|
||||||
|
config = data.get("config", {})
|
||||||
|
|
||||||
|
async def process_and_stream():
|
||||||
|
async for result in service.process_batch(data["prompts"], config):
|
||||||
|
yield json.dumps(result) + "\n"
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
process_and_stream(),
|
||||||
|
media_type="application/x-ndjson"
|
||||||
|
)
|
||||||
21
apps/jart/app.py
Normal file
21
apps/jart/app.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
from fastapi import FastAPI
|
||||||
|
from settings import settings
|
||||||
|
from api import router
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="JArt",
|
||||||
|
description="JArt绘画服务API",
|
||||||
|
version="1.0.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 注册路由
|
||||||
|
app.include_router(router)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
uvicorn.run(
|
||||||
|
"app:app",
|
||||||
|
host=settings.host,
|
||||||
|
port=settings.port,
|
||||||
|
reload=settings.debug
|
||||||
|
)
|
||||||
272
apps/jart/service.py
Normal file
272
apps/jart/service.py
Normal file
@ -0,0 +1,272 @@
|
|||||||
|
import json
|
||||||
|
import base64
|
||||||
|
import requests
|
||||||
|
import random
|
||||||
|
import websocket
|
||||||
|
import uuid
|
||||||
|
import urllib.request
|
||||||
|
import asyncio
|
||||||
|
import io
|
||||||
|
from typing import Dict, List, Generator, Optional, AsyncGenerator
|
||||||
|
|
||||||
|
# 固定配置变量
|
||||||
|
DEFAULT_CONFIG = {
|
||||||
|
"comfyui_server_address": "192.168.2.200:8188",
|
||||||
|
"ckpt_name": "flux1-schnell-fp8.safetensors",
|
||||||
|
"sampler_name": "euler",
|
||||||
|
"scheduler": "normal",
|
||||||
|
"steps": 4,
|
||||||
|
"cfg": 1,
|
||||||
|
"denoise": 1.0,
|
||||||
|
"images_per_prompt": 1,
|
||||||
|
"image_width": 1024,
|
||||||
|
"image_height": 1024,
|
||||||
|
"negative_prompt": "blur, low quality, low resolution, artifacts, text, watermark, underexposed, bad anatomy, deformed body, extra limbs, missing limbs, noisy background, cluttered background, blurry background"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 定义基础工作流 JSON 模板
|
||||||
|
WORKFLOW_TEMPLATE = """
|
||||||
|
{
|
||||||
|
"3": {
|
||||||
|
"class_type": "KSampler",
|
||||||
|
"inputs": {
|
||||||
|
"cfg": %d,
|
||||||
|
"denoise": %d,
|
||||||
|
"latent_image": [
|
||||||
|
"5",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"model": [
|
||||||
|
"4",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"negative": [
|
||||||
|
"7",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"positive": [
|
||||||
|
"6",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"sampler_name": "%s",
|
||||||
|
"scheduler": "%s",
|
||||||
|
"seed": 8566257,
|
||||||
|
"steps": %d
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"4": {
|
||||||
|
"class_type": "CheckpointLoaderSimple",
|
||||||
|
"inputs": {
|
||||||
|
"ckpt_name": "%s"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"5": {
|
||||||
|
"class_type": "EmptyLatentImage",
|
||||||
|
"inputs": {
|
||||||
|
"batch_size": 1,
|
||||||
|
"height": %d,
|
||||||
|
"width": %d
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"6": {
|
||||||
|
"class_type": "CLIPTextEncode",
|
||||||
|
"inputs": {
|
||||||
|
"clip": [
|
||||||
|
"4",
|
||||||
|
1
|
||||||
|
],
|
||||||
|
"text": "masterpiece best quality girl"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"7": {
|
||||||
|
"class_type": "CLIPTextEncode",
|
||||||
|
"inputs": {
|
||||||
|
"clip": [
|
||||||
|
"4",
|
||||||
|
1
|
||||||
|
],
|
||||||
|
"text": "%s"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"class_type": "VAEDecode",
|
||||||
|
"inputs": {
|
||||||
|
"samples": [
|
||||||
|
"3",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"vae": [
|
||||||
|
"4",
|
||||||
|
2
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"save_image_websocket_node": {
|
||||||
|
"class_type": "SaveImageWebsocket",
|
||||||
|
"inputs": {
|
||||||
|
"images": [
|
||||||
|
"8",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
class TxtImgService:
|
||||||
|
def __init__(self):
|
||||||
|
"""初始化文本生成图像服务"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def queue_prompt(self, prompt: Dict, comfyui_server_address: str, client_id: str) -> Dict:
|
||||||
|
"""将提示词发送到 ComfyUI 服务器的队列中"""
|
||||||
|
p = {"prompt": prompt, "client_id": client_id}
|
||||||
|
data = json.dumps(p).encode('utf-8')
|
||||||
|
req = urllib.request.Request(f"http://{comfyui_server_address}/prompt", data=data)
|
||||||
|
response = json.loads(urllib.request.urlopen(req).read())
|
||||||
|
return response
|
||||||
|
|
||||||
|
def get_images(self, ws: websocket.WebSocket, workflow: Dict, comfyui_server_address: str, client_id: str) -> Dict:
|
||||||
|
"""从 ComfyUI 获取生成的图像"""
|
||||||
|
try:
|
||||||
|
prompt_response = self.queue_prompt(workflow, comfyui_server_address, client_id)
|
||||||
|
prompt_id = prompt_response['prompt_id']
|
||||||
|
except KeyError:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
output_images = {}
|
||||||
|
current_node = ""
|
||||||
|
while True:
|
||||||
|
out = ws.recv()
|
||||||
|
if isinstance(out, str):
|
||||||
|
message = json.loads(out)
|
||||||
|
if message['type'] == 'executing':
|
||||||
|
data = message['data']
|
||||||
|
if data.get('prompt_id') == prompt_id:
|
||||||
|
if data['node'] is None:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
current_node = data['node']
|
||||||
|
else:
|
||||||
|
if current_node == 'save_image_websocket_node':
|
||||||
|
images_output = output_images.get(current_node, [])
|
||||||
|
images_output.append(out[8:])
|
||||||
|
output_images[current_node] = images_output
|
||||||
|
|
||||||
|
return output_images
|
||||||
|
|
||||||
|
def generate_image_sync(self, prompt: str, config: Optional[Dict] = None) -> Generator[str, None, None]:
|
||||||
|
"""生成 Flux 模型的图片,流式返回 base64 编码的图片"""
|
||||||
|
cfg = DEFAULT_CONFIG.copy()
|
||||||
|
if config:
|
||||||
|
cfg.update(config)
|
||||||
|
|
||||||
|
ws = websocket.WebSocket()
|
||||||
|
client_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
try:
|
||||||
|
ws.connect(f"ws://{cfg['comfyui_server_address']}/ws?clientId={client_id}")
|
||||||
|
images_count = int(cfg.get('images_per_prompt', 1))
|
||||||
|
|
||||||
|
for i in range(images_count):
|
||||||
|
workflow = json.loads(WORKFLOW_TEMPLATE % (
|
||||||
|
cfg['cfg'],
|
||||||
|
cfg['denoise'],
|
||||||
|
cfg['sampler_name'],
|
||||||
|
cfg['scheduler'],
|
||||||
|
cfg['steps'],
|
||||||
|
cfg['ckpt_name'],
|
||||||
|
cfg['image_height'],
|
||||||
|
cfg['image_width'],
|
||||||
|
cfg['negative_prompt']
|
||||||
|
))
|
||||||
|
|
||||||
|
workflow["6"]["inputs"]["text"] = prompt
|
||||||
|
seed = random.randint(1, 4294967295)
|
||||||
|
workflow["3"]["inputs"]["seed"] = seed
|
||||||
|
|
||||||
|
images_dict = self.get_images(ws, workflow, cfg['comfyui_server_address'], client_id)
|
||||||
|
|
||||||
|
for node_id, image_list in images_dict.items():
|
||||||
|
for image_data in image_list:
|
||||||
|
base64_image = base64.b64encode(image_data).decode('utf-8')
|
||||||
|
yield base64_image
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if ws:
|
||||||
|
ws.close()
|
||||||
|
|
||||||
|
async def generate_image(self, prompt: str, config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]:
|
||||||
|
"""异步生成图像,流式返回结果"""
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
def sync_generator():
|
||||||
|
for base64_image in self.generate_image_sync(prompt, config):
|
||||||
|
yield base64_image
|
||||||
|
|
||||||
|
generator = await loop.run_in_executor(None, sync_generator)
|
||||||
|
|
||||||
|
for base64_image in generator:
|
||||||
|
yield {
|
||||||
|
"status": "success",
|
||||||
|
"image": f"data:image/png;base64,{base64_image}",
|
||||||
|
"message": f"成功生成图片"
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
yield {
|
||||||
|
"status": "error",
|
||||||
|
"message": f"图像生成失败: {str(e)}"
|
||||||
|
}
|
||||||
|
|
||||||
|
async def process_batch(self, prompts: List[str], config: Optional[Dict] = None):
|
||||||
|
"""批量处理多个文本提示,流式返回结果"""
|
||||||
|
total = len(prompts)
|
||||||
|
success_count = 0
|
||||||
|
error_count = 0
|
||||||
|
|
||||||
|
for i, prompt in enumerate(prompts, 1):
|
||||||
|
try:
|
||||||
|
async for result in self.generate_image(prompt, config):
|
||||||
|
if result["status"] == "success":
|
||||||
|
success_count += 1
|
||||||
|
yield {
|
||||||
|
"index": i,
|
||||||
|
"total": total,
|
||||||
|
"original_prompt": prompt,
|
||||||
|
"status": "success",
|
||||||
|
"image_content": result["image"],
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count,
|
||||||
|
"message": result["message"]
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
error_count += 1
|
||||||
|
yield {
|
||||||
|
"index": i,
|
||||||
|
"total": total,
|
||||||
|
"original_prompt": prompt,
|
||||||
|
"status": "error",
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count,
|
||||||
|
"message": result["message"]
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_count += 1
|
||||||
|
yield {
|
||||||
|
"index": i,
|
||||||
|
"total": total,
|
||||||
|
"original_prompt": prompt,
|
||||||
|
"status": "error",
|
||||||
|
"error": str(e),
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count,
|
||||||
|
"message": f"处理失败: {str(e)}"
|
||||||
|
}
|
||||||
|
|
||||||
|
await asyncio.sleep(0)
|
||||||
35
apps/jart/settings.py
Normal file
35
apps/jart/settings.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
from typing import Optional
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
# Japi Server 配置
|
||||||
|
host: str = "0.0.0.0"
|
||||||
|
port: int = 8102
|
||||||
|
debug: bool = False
|
||||||
|
|
||||||
|
# API路由配置
|
||||||
|
router_prefix: str = "/jart"
|
||||||
|
generate_route: str = "/generate" # 生成图片的路由
|
||||||
|
batch_route: str = "/batch" # 批量生成图片的路由
|
||||||
|
api_name: str = "jart" # 默认API名称
|
||||||
|
|
||||||
|
upload_url: str = "http://173.255.202.68/imgurl/upload"
|
||||||
|
|
||||||
|
# Jingrow Jcloud API 配置
|
||||||
|
jingrow_api_url: str = "https://cloud.jingrow.com"
|
||||||
|
jingrow_api_key: Optional[str] = None
|
||||||
|
jingrow_api_secret: Optional[str] = None
|
||||||
|
|
||||||
|
# Stable Diffusion配置
|
||||||
|
comfyui_server_address: str = "comfyui.jingrow.com:8188"
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
env_file = ".env"
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def get_settings() -> Settings:
|
||||||
|
return Settings()
|
||||||
|
|
||||||
|
# 创建全局配置实例
|
||||||
|
settings = get_settings()
|
||||||
146
apps/jart/utils.py
Normal file
146
apps/jart/utils.py
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
import aiohttp
|
||||||
|
from functools import wraps
|
||||||
|
from fastapi import HTTPException
|
||||||
|
import os
|
||||||
|
from typing import Callable, Any, Dict, Optional, Tuple
|
||||||
|
from settings import settings
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
import json
|
||||||
|
|
||||||
|
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||||
|
"""验证API密钥和团队余额"""
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{settings.jingrow_api_url}/api/method/jcloud.api.account.verify_api_credentials_and_balance",
|
||||||
|
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||||
|
json={"api_key": api_key, "api_secret": api_secret, "api_name": api_name}
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise HTTPException(status_code=500, detail="验证服务暂时不可用")
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
|
||||||
|
if not result.get("success"):
|
||||||
|
raise HTTPException(status_code=401, detail=result.get("message", "验证失败"))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||||
|
|
||||||
|
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
||||||
|
"""从Jingrow平台扣除API使用费"""
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{settings.jingrow_api_url}/api/method/jcloud.api.account.deduct_api_usage_fee",
|
||||||
|
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||||
|
json={
|
||||||
|
"api_key": api_key,
|
||||||
|
"api_secret": api_secret,
|
||||||
|
"api_name": api_name,
|
||||||
|
"usage_count": usage_count
|
||||||
|
}
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
||||||
|
|
||||||
|
def get_token_from_request(request) -> str:
|
||||||
|
"""从请求中获取访问令牌"""
|
||||||
|
if not request:
|
||||||
|
raise HTTPException(status_code=400, detail="无法获取请求信息")
|
||||||
|
|
||||||
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
if not auth_header or not auth_header.startswith("token "):
|
||||||
|
raise HTTPException(status_code=401, detail="无效的Authorization头格式")
|
||||||
|
|
||||||
|
token = auth_header[6:]
|
||||||
|
if ":" not in token:
|
||||||
|
raise HTTPException(status_code=401, detail="无效的令牌格式")
|
||||||
|
|
||||||
|
return token
|
||||||
|
|
||||||
|
def jingrow_api_verify_and_billing(api_name: str):
|
||||||
|
"""Jingrow API 验证装饰器(带余额检查和扣费)"""
|
||||||
|
def decorator(func: Callable) -> Callable:
|
||||||
|
@wraps(func)
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
try:
|
||||||
|
request = kwargs.get('request')
|
||||||
|
if not request:
|
||||||
|
raise HTTPException(status_code=400, detail="无法获取请求信息")
|
||||||
|
|
||||||
|
token = get_token_from_request(request)
|
||||||
|
api_key, api_secret = token.split(":", 1)
|
||||||
|
|
||||||
|
verify_result = await verify_api_credentials_and_balance(api_key, api_secret, api_name)
|
||||||
|
if not verify_result.get("success"):
|
||||||
|
raise HTTPException(status_code=401, detail=verify_result.get("message", "验证失败"))
|
||||||
|
|
||||||
|
result = await func(*args, **kwargs)
|
||||||
|
|
||||||
|
usage_count = 1
|
||||||
|
try:
|
||||||
|
body_data = await request.json()
|
||||||
|
if isinstance(body_data, dict):
|
||||||
|
for key in ["items", "urls", "images", "files"]:
|
||||||
|
if key in body_data and isinstance(body_data[key], list):
|
||||||
|
usage_count = len(body_data[key])
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if isinstance(result, StreamingResponse):
|
||||||
|
original_generator = result.body_iterator
|
||||||
|
success_count = 0
|
||||||
|
|
||||||
|
async def wrapped_generator():
|
||||||
|
nonlocal success_count
|
||||||
|
async for chunk in original_generator:
|
||||||
|
try:
|
||||||
|
data = json.loads(chunk)
|
||||||
|
if isinstance(data, dict) and data.get("status") == "success":
|
||||||
|
success_count += 1
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
if success_count > 0:
|
||||||
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
wrapped_generator(),
|
||||||
|
media_type=result.media_type,
|
||||||
|
headers=result.headers
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(result, dict) and result.get("success") is True:
|
||||||
|
actual_usage_count = result.get("successful_count", usage_count)
|
||||||
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
||||||
|
return result
|
||||||
|
|
||||||
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"API验证过程发生异常: {str(e)}")
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
||||||
0
apps/jart_v1/__init__.py
Normal file
0
apps/jart_v1/__init__.py
Normal file
57
apps/jart_v1/api.py
Normal file
57
apps/jart_v1/api.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from service import TxtImgService
|
||||||
|
from utils import jingrow_api_verify_and_billing
|
||||||
|
from settings import settings
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
router = APIRouter(prefix=settings.router_prefix)
|
||||||
|
service = TxtImgService()
|
||||||
|
|
||||||
|
@router.post(settings.generate_route)
|
||||||
|
@jingrow_api_verify_and_billing(api_name=settings.api_name)
|
||||||
|
async def generate_image(data: dict, request: Request):
|
||||||
|
"""
|
||||||
|
根据文本提示生成图像
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 包含文本提示和配置参数的字典
|
||||||
|
request: FastAPI 请求对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
生成的图像内容
|
||||||
|
"""
|
||||||
|
if "prompt" not in data:
|
||||||
|
raise HTTPException(status_code=400, detail="缺少prompt参数")
|
||||||
|
|
||||||
|
config = data.get("config", {})
|
||||||
|
result = await service.generate_image(data["prompt"], config)
|
||||||
|
return result
|
||||||
|
|
||||||
|
@router.post(settings.batch_route)
|
||||||
|
@jingrow_api_verify_and_billing(api_name=settings.api_name)
|
||||||
|
async def generate_image_batch(data: dict, request: Request):
|
||||||
|
"""
|
||||||
|
批量处理多个文本提示
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 包含文本提示列表和配置参数的字典
|
||||||
|
request: FastAPI 请求对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
流式响应,包含每个提示的处理结果
|
||||||
|
"""
|
||||||
|
if "prompts" not in data:
|
||||||
|
raise HTTPException(status_code=400, detail="缺少prompts参数")
|
||||||
|
|
||||||
|
config = data.get("config", {})
|
||||||
|
|
||||||
|
async def process_and_stream():
|
||||||
|
async for result in service.process_batch(data["prompts"], config):
|
||||||
|
yield json.dumps(result) + "\n"
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
process_and_stream(),
|
||||||
|
media_type="application/x-ndjson"
|
||||||
|
)
|
||||||
21
apps/jart_v1/app.py
Normal file
21
apps/jart_v1/app.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
from fastapi import FastAPI
|
||||||
|
from settings import settings
|
||||||
|
from api import router
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="JArt V1",
|
||||||
|
description="JArt绘画服务API",
|
||||||
|
version="1.0.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 注册路由
|
||||||
|
app.include_router(router)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
uvicorn.run(
|
||||||
|
"app:app",
|
||||||
|
host=settings.host,
|
||||||
|
port=settings.port,
|
||||||
|
reload=settings.debug
|
||||||
|
)
|
||||||
272
apps/jart_v1/service.py
Normal file
272
apps/jart_v1/service.py
Normal file
@ -0,0 +1,272 @@
|
|||||||
|
import json
|
||||||
|
import base64
|
||||||
|
import requests
|
||||||
|
import random
|
||||||
|
import websocket
|
||||||
|
import uuid
|
||||||
|
import urllib.request
|
||||||
|
import asyncio
|
||||||
|
import io
|
||||||
|
from typing import Dict, List, Generator, Optional, AsyncGenerator
|
||||||
|
|
||||||
|
# 固定配置变量
|
||||||
|
DEFAULT_CONFIG = {
|
||||||
|
"comfyui_server_address": "192.168.2.200:8188",
|
||||||
|
"ckpt_name": "sd3_medium_incl_clips_t5xxlfp8.safetensors",
|
||||||
|
"sampler_name": "euler",
|
||||||
|
"scheduler": "normal",
|
||||||
|
"steps": 20,
|
||||||
|
"cfg": 8,
|
||||||
|
"denoise": 1.0,
|
||||||
|
"images_per_prompt": 1,
|
||||||
|
"image_width": 1024,
|
||||||
|
"image_height": 1024,
|
||||||
|
"negative_prompt": "blur, low quality, low resolution, artifacts, text, watermark, underexposed, bad anatomy, deformed body, extra limbs, missing limbs, noisy background, cluttered background, blurry background"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 定义基础工作流 JSON 模板
|
||||||
|
WORKFLOW_TEMPLATE = """
|
||||||
|
{
|
||||||
|
"3": {
|
||||||
|
"class_type": "KSampler",
|
||||||
|
"inputs": {
|
||||||
|
"cfg": %d,
|
||||||
|
"denoise": %d,
|
||||||
|
"latent_image": [
|
||||||
|
"5",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"model": [
|
||||||
|
"4",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"negative": [
|
||||||
|
"7",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"positive": [
|
||||||
|
"6",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"sampler_name": "%s",
|
||||||
|
"scheduler": "%s",
|
||||||
|
"seed": 8566257,
|
||||||
|
"steps": %d
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"4": {
|
||||||
|
"class_type": "CheckpointLoaderSimple",
|
||||||
|
"inputs": {
|
||||||
|
"ckpt_name": "%s"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"5": {
|
||||||
|
"class_type": "EmptyLatentImage",
|
||||||
|
"inputs": {
|
||||||
|
"batch_size": 1,
|
||||||
|
"height": %d,
|
||||||
|
"width": %d
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"6": {
|
||||||
|
"class_type": "CLIPTextEncode",
|
||||||
|
"inputs": {
|
||||||
|
"clip": [
|
||||||
|
"4",
|
||||||
|
1
|
||||||
|
],
|
||||||
|
"text": "masterpiece best quality girl"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"7": {
|
||||||
|
"class_type": "CLIPTextEncode",
|
||||||
|
"inputs": {
|
||||||
|
"clip": [
|
||||||
|
"4",
|
||||||
|
1
|
||||||
|
],
|
||||||
|
"text": "%s"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"class_type": "VAEDecode",
|
||||||
|
"inputs": {
|
||||||
|
"samples": [
|
||||||
|
"3",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"vae": [
|
||||||
|
"4",
|
||||||
|
2
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"save_image_websocket_node": {
|
||||||
|
"class_type": "SaveImageWebsocket",
|
||||||
|
"inputs": {
|
||||||
|
"images": [
|
||||||
|
"8",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
class TxtImgService:
|
||||||
|
def __init__(self):
|
||||||
|
"""初始化文本生成图像服务"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def queue_prompt(self, prompt: Dict, comfyui_server_address: str, client_id: str) -> Dict:
|
||||||
|
"""将提示词发送到 ComfyUI 服务器的队列中"""
|
||||||
|
p = {"prompt": prompt, "client_id": client_id}
|
||||||
|
data = json.dumps(p).encode('utf-8')
|
||||||
|
req = urllib.request.Request(f"http://{comfyui_server_address}/prompt", data=data)
|
||||||
|
response = json.loads(urllib.request.urlopen(req).read())
|
||||||
|
return response
|
||||||
|
|
||||||
|
def get_images(self, ws: websocket.WebSocket, workflow: Dict, comfyui_server_address: str, client_id: str) -> Dict:
|
||||||
|
"""从 ComfyUI 获取生成的图像"""
|
||||||
|
try:
|
||||||
|
prompt_response = self.queue_prompt(workflow, comfyui_server_address, client_id)
|
||||||
|
prompt_id = prompt_response['prompt_id']
|
||||||
|
except KeyError:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
output_images = {}
|
||||||
|
current_node = ""
|
||||||
|
while True:
|
||||||
|
out = ws.recv()
|
||||||
|
if isinstance(out, str):
|
||||||
|
message = json.loads(out)
|
||||||
|
if message['type'] == 'executing':
|
||||||
|
data = message['data']
|
||||||
|
if data.get('prompt_id') == prompt_id:
|
||||||
|
if data['node'] is None:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
current_node = data['node']
|
||||||
|
else:
|
||||||
|
if current_node == 'save_image_websocket_node':
|
||||||
|
images_output = output_images.get(current_node, [])
|
||||||
|
images_output.append(out[8:])
|
||||||
|
output_images[current_node] = images_output
|
||||||
|
|
||||||
|
return output_images
|
||||||
|
|
||||||
|
def generate_image_sync(self, prompt: str, config: Optional[Dict] = None) -> Generator[str, None, None]:
|
||||||
|
"""生成 Flux 模型的图片,流式返回 base64 编码的图片"""
|
||||||
|
cfg = DEFAULT_CONFIG.copy()
|
||||||
|
if config:
|
||||||
|
cfg.update(config)
|
||||||
|
|
||||||
|
ws = websocket.WebSocket()
|
||||||
|
client_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
try:
|
||||||
|
ws.connect(f"ws://{cfg['comfyui_server_address']}/ws?clientId={client_id}")
|
||||||
|
images_count = int(cfg.get('images_per_prompt', 1))
|
||||||
|
|
||||||
|
for i in range(images_count):
|
||||||
|
workflow = json.loads(WORKFLOW_TEMPLATE % (
|
||||||
|
cfg['cfg'],
|
||||||
|
cfg['denoise'],
|
||||||
|
cfg['sampler_name'],
|
||||||
|
cfg['scheduler'],
|
||||||
|
cfg['steps'],
|
||||||
|
cfg['ckpt_name'],
|
||||||
|
cfg['image_height'],
|
||||||
|
cfg['image_width'],
|
||||||
|
cfg['negative_prompt']
|
||||||
|
))
|
||||||
|
|
||||||
|
workflow["6"]["inputs"]["text"] = prompt
|
||||||
|
seed = random.randint(1, 4294967295)
|
||||||
|
workflow["3"]["inputs"]["seed"] = seed
|
||||||
|
|
||||||
|
images_dict = self.get_images(ws, workflow, cfg['comfyui_server_address'], client_id)
|
||||||
|
|
||||||
|
for node_id, image_list in images_dict.items():
|
||||||
|
for image_data in image_list:
|
||||||
|
base64_image = base64.b64encode(image_data).decode('utf-8')
|
||||||
|
yield base64_image
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if ws:
|
||||||
|
ws.close()
|
||||||
|
|
||||||
|
async def generate_image(self, prompt: str, config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]:
|
||||||
|
"""异步生成图像,流式返回结果"""
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
def sync_generator():
|
||||||
|
for base64_image in self.generate_image_sync(prompt, config):
|
||||||
|
yield base64_image
|
||||||
|
|
||||||
|
generator = await loop.run_in_executor(None, sync_generator)
|
||||||
|
|
||||||
|
for base64_image in generator:
|
||||||
|
yield {
|
||||||
|
"status": "success",
|
||||||
|
"image": f"data:image/png;base64,{base64_image}",
|
||||||
|
"message": f"成功生成图片"
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
yield {
|
||||||
|
"status": "error",
|
||||||
|
"message": f"图像生成失败: {str(e)}"
|
||||||
|
}
|
||||||
|
|
||||||
|
async def process_batch(self, prompts: List[str], config: Optional[Dict] = None):
|
||||||
|
"""批量处理多个文本提示,流式返回结果"""
|
||||||
|
total = len(prompts)
|
||||||
|
success_count = 0
|
||||||
|
error_count = 0
|
||||||
|
|
||||||
|
for i, prompt in enumerate(prompts, 1):
|
||||||
|
try:
|
||||||
|
async for result in self.generate_image(prompt, config):
|
||||||
|
if result["status"] == "success":
|
||||||
|
success_count += 1
|
||||||
|
yield {
|
||||||
|
"index": i,
|
||||||
|
"total": total,
|
||||||
|
"original_prompt": prompt,
|
||||||
|
"status": "success",
|
||||||
|
"image_content": result["image"],
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count,
|
||||||
|
"message": result["message"]
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
error_count += 1
|
||||||
|
yield {
|
||||||
|
"index": i,
|
||||||
|
"total": total,
|
||||||
|
"original_prompt": prompt,
|
||||||
|
"status": "error",
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count,
|
||||||
|
"message": result["message"]
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_count += 1
|
||||||
|
yield {
|
||||||
|
"index": i,
|
||||||
|
"total": total,
|
||||||
|
"original_prompt": prompt,
|
||||||
|
"status": "error",
|
||||||
|
"error": str(e),
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count,
|
||||||
|
"message": f"处理失败: {str(e)}"
|
||||||
|
}
|
||||||
|
|
||||||
|
await asyncio.sleep(0)
|
||||||
35
apps/jart_v1/settings.py
Normal file
35
apps/jart_v1/settings.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
from typing import Optional
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
# Japi Server 配置
|
||||||
|
host: str = "0.0.0.0"
|
||||||
|
port: int = 8103
|
||||||
|
debug: bool = False
|
||||||
|
|
||||||
|
# API路由配置
|
||||||
|
router_prefix: str = "/jart_v1"
|
||||||
|
generate_route: str = "/generate" # 生成图片的路由
|
||||||
|
batch_route: str = "/batch" # 批量生成图片的路由
|
||||||
|
api_name: str = "jart_v1" # 默认API名称
|
||||||
|
|
||||||
|
upload_url: str = "http://173.255.202.68/imgurl/upload"
|
||||||
|
|
||||||
|
# Jingrow Jcloud API 配置
|
||||||
|
jingrow_api_url: str = "https://cloud.jingrow.com"
|
||||||
|
jingrow_api_key: Optional[str] = None
|
||||||
|
jingrow_api_secret: Optional[str] = None
|
||||||
|
|
||||||
|
# Stable Diffusion配置
|
||||||
|
comfyui_server_address: str = "comfyui.jingrow.com:8188"
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
env_file = ".env"
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def get_settings() -> Settings:
|
||||||
|
return Settings()
|
||||||
|
|
||||||
|
# 创建全局配置实例
|
||||||
|
settings = get_settings()
|
||||||
146
apps/jart_v1/utils.py
Normal file
146
apps/jart_v1/utils.py
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
import aiohttp
|
||||||
|
from functools import wraps
|
||||||
|
from fastapi import HTTPException
|
||||||
|
import os
|
||||||
|
from typing import Callable, Any, Dict, Optional, Tuple
|
||||||
|
from settings import settings
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
import json
|
||||||
|
|
||||||
|
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||||
|
"""验证API密钥和团队余额"""
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{settings.jingrow_api_url}/api/method/jcloud.api.account.verify_api_credentials_and_balance",
|
||||||
|
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||||
|
json={"api_key": api_key, "api_secret": api_secret, "api_name": api_name}
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise HTTPException(status_code=500, detail="验证服务暂时不可用")
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
|
||||||
|
if not result.get("success"):
|
||||||
|
raise HTTPException(status_code=401, detail=result.get("message", "验证失败"))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||||
|
|
||||||
|
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
||||||
|
"""从Jingrow平台扣除API使用费"""
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{settings.jingrow_api_url}/api/method/jcloud.api.account.deduct_api_usage_fee",
|
||||||
|
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||||
|
json={
|
||||||
|
"api_key": api_key,
|
||||||
|
"api_secret": api_secret,
|
||||||
|
"api_name": api_name,
|
||||||
|
"usage_count": usage_count
|
||||||
|
}
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
||||||
|
|
||||||
|
def get_token_from_request(request) -> str:
|
||||||
|
"""从请求中获取访问令牌"""
|
||||||
|
if not request:
|
||||||
|
raise HTTPException(status_code=400, detail="无法获取请求信息")
|
||||||
|
|
||||||
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
if not auth_header or not auth_header.startswith("token "):
|
||||||
|
raise HTTPException(status_code=401, detail="无效的Authorization头格式")
|
||||||
|
|
||||||
|
token = auth_header[6:]
|
||||||
|
if ":" not in token:
|
||||||
|
raise HTTPException(status_code=401, detail="无效的令牌格式")
|
||||||
|
|
||||||
|
return token
|
||||||
|
|
||||||
|
def jingrow_api_verify_and_billing(api_name: str):
|
||||||
|
"""Jingrow API 验证装饰器(带余额检查和扣费)"""
|
||||||
|
def decorator(func: Callable) -> Callable:
|
||||||
|
@wraps(func)
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
try:
|
||||||
|
request = kwargs.get('request')
|
||||||
|
if not request:
|
||||||
|
raise HTTPException(status_code=400, detail="无法获取请求信息")
|
||||||
|
|
||||||
|
token = get_token_from_request(request)
|
||||||
|
api_key, api_secret = token.split(":", 1)
|
||||||
|
|
||||||
|
verify_result = await verify_api_credentials_and_balance(api_key, api_secret, api_name)
|
||||||
|
if not verify_result.get("success"):
|
||||||
|
raise HTTPException(status_code=401, detail=verify_result.get("message", "验证失败"))
|
||||||
|
|
||||||
|
result = await func(*args, **kwargs)
|
||||||
|
|
||||||
|
usage_count = 1
|
||||||
|
try:
|
||||||
|
body_data = await request.json()
|
||||||
|
if isinstance(body_data, dict):
|
||||||
|
for key in ["items", "urls", "images", "files"]:
|
||||||
|
if key in body_data and isinstance(body_data[key], list):
|
||||||
|
usage_count = len(body_data[key])
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if isinstance(result, StreamingResponse):
|
||||||
|
original_generator = result.body_iterator
|
||||||
|
success_count = 0
|
||||||
|
|
||||||
|
async def wrapped_generator():
|
||||||
|
nonlocal success_count
|
||||||
|
async for chunk in original_generator:
|
||||||
|
try:
|
||||||
|
data = json.loads(chunk)
|
||||||
|
if isinstance(data, dict) and data.get("status") == "success":
|
||||||
|
success_count += 1
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
if success_count > 0:
|
||||||
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
wrapped_generator(),
|
||||||
|
media_type=result.media_type,
|
||||||
|
headers=result.headers
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(result, dict) and result.get("success") is True:
|
||||||
|
actual_usage_count = result.get("successful_count", usage_count)
|
||||||
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
||||||
|
return result
|
||||||
|
|
||||||
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"API验证过程发生异常: {str(e)}")
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
||||||
1
apps/jchat/__init__.py
Normal file
1
apps/jchat/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
# 使jchat目录成为Python包
|
||||||
59
apps/jchat/api.py
Normal file
59
apps/jchat/api.py
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
from service import ChatService
|
||||||
|
from utils import jingrow_api_verify_and_billing
|
||||||
|
from settings import settings
|
||||||
|
import json
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
router = APIRouter(prefix=settings.router_prefix)
|
||||||
|
service = ChatService()
|
||||||
|
|
||||||
|
def dynamic_billing_wrapper(func):
|
||||||
|
"""动态API扣费装饰器,使用模型名称作为API名称"""
|
||||||
|
@wraps(func)
|
||||||
|
async def wrapper(data: dict, request: Request):
|
||||||
|
api_name = settings.default_api_name # 使用settings中的默认API名称
|
||||||
|
if "model" in data:
|
||||||
|
api_name = data["model"]
|
||||||
|
|
||||||
|
dynamic_decorator = jingrow_api_verify_and_billing(api_name=api_name)
|
||||||
|
decorated_func = dynamic_decorator(func)
|
||||||
|
return await decorated_func(**{"data": data, "request": request})
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
@router.post(settings.chat_route)
|
||||||
|
@dynamic_billing_wrapper
|
||||||
|
async def chat_api(data: dict, request: Request):
|
||||||
|
"""
|
||||||
|
通用文本聊天API,支持OpenAI和豆包等模型的请求格式
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 包含以下字段的字典:
|
||||||
|
- messages: 消息列表,每个消息包含 role 和 content(必需)
|
||||||
|
- model: 选择使用的模型(可选,默认为配置的默认模型)
|
||||||
|
- temperature: 温度参数(可选,默认为0.7)
|
||||||
|
- top_p: top_p参数(可选,默认为0.9)
|
||||||
|
- max_tokens: 最大生成token数(可选,默认为2048)
|
||||||
|
request: FastAPI 请求对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AI生成的回复内容
|
||||||
|
"""
|
||||||
|
if "messages" not in data:
|
||||||
|
raise HTTPException(status_code=400, detail="缺少messages参数")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if "model" in data:
|
||||||
|
service.model = data["model"]
|
||||||
|
if "temperature" in data:
|
||||||
|
service.temperature = data["temperature"]
|
||||||
|
if "top_p" in data:
|
||||||
|
service.top_p = data["top_p"]
|
||||||
|
if "max_tokens" in data:
|
||||||
|
service.max_tokens = data["max_tokens"]
|
||||||
|
|
||||||
|
result = await service.chat(data["messages"])
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
21
apps/jchat/app.py
Normal file
21
apps/jchat/app.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
from fastapi import FastAPI
|
||||||
|
from settings import settings
|
||||||
|
from api import router
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="JChat Service",
|
||||||
|
description="AI聊天服务API",
|
||||||
|
version="1.0.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 注册路由
|
||||||
|
app.include_router(router)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
uvicorn.run(
|
||||||
|
"app:app",
|
||||||
|
host=settings.host,
|
||||||
|
port=settings.port,
|
||||||
|
reload=settings.debug
|
||||||
|
)
|
||||||
223
apps/jchat/service.py
Normal file
223
apps/jchat/service.py
Normal file
@ -0,0 +1,223 @@
|
|||||||
|
import json
|
||||||
|
import requests
|
||||||
|
import asyncio
|
||||||
|
from typing import Dict, Optional, List, Union
|
||||||
|
from settings import settings
|
||||||
|
|
||||||
|
# 默认模型配置
|
||||||
|
default_model = "deepseek" # 默认使用的模型,可选值为"gpt"、"deepseek"或"doubao"
|
||||||
|
gpt_api_model = "gpt-4o" # ChatGPT模型名称
|
||||||
|
deepseek_api_model = "deepseek-chat" # DeepSeek模型名称
|
||||||
|
doubao_api_model = "doubao-1-5-thinking-pro-250415" # Doubao模型名称
|
||||||
|
|
||||||
|
# 模型映射配置
|
||||||
|
model_mapping = {
|
||||||
|
"jingrow-chat": {
|
||||||
|
"type": "deepseek",
|
||||||
|
"model": "deepseek-chat"
|
||||||
|
},
|
||||||
|
"jingrow-chat-lite": {
|
||||||
|
"type": "doubao",
|
||||||
|
"model": "doubao-1-5-lite-32k-250115"
|
||||||
|
},
|
||||||
|
"jingrow-chat-think": {
|
||||||
|
"type": "doubao",
|
||||||
|
"model": "doubao-1-5-thinking-pro-250415"
|
||||||
|
},
|
||||||
|
"jingrow-chat-vision": {
|
||||||
|
"type": "doubao",
|
||||||
|
"model": "doubao-1.5-vision-pro-250328"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# 默认系统提示词
|
||||||
|
default_system_message = """
|
||||||
|
你是一个有用的AI助手,请根据用户的问题提供清晰、准确的回答。
|
||||||
|
"""
|
||||||
|
|
||||||
|
class ChatService:
|
||||||
|
def __init__(self, model: str = None, temperature: float = 0.7, top_p: float = 0.9, max_tokens: int = 2048):
|
||||||
|
"""初始化聊天服务
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: 选择使用的模型
|
||||||
|
temperature: 温度参数
|
||||||
|
top_p: top_p参数
|
||||||
|
max_tokens: 最大生成token数
|
||||||
|
"""
|
||||||
|
self.model = model
|
||||||
|
self.temperature = temperature
|
||||||
|
self.top_p = top_p
|
||||||
|
self.max_tokens = max_tokens
|
||||||
|
|
||||||
|
def _get_model_config(self, model: str) -> Dict:
|
||||||
|
"""获取模型配置
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: 模型名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含模型类型和具体模型名称的字典
|
||||||
|
"""
|
||||||
|
# 检查是否在映射表中
|
||||||
|
if model in model_mapping:
|
||||||
|
return model_mapping[model]
|
||||||
|
|
||||||
|
# 根据模型名称判断类型
|
||||||
|
model_lower = model.lower()
|
||||||
|
if "deepseek" in model_lower:
|
||||||
|
return {"type": "deepseek", "model": model}
|
||||||
|
elif "doubao" in model_lower:
|
||||||
|
return {"type": "doubao", "model": model}
|
||||||
|
else:
|
||||||
|
return {"type": "gpt", "model": model}
|
||||||
|
|
||||||
|
def _get_api_config(self, model_type: str) -> Dict:
|
||||||
|
"""获取API配置
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_type: 模型类型(gpt/deepseek/doubao)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含API配置的字典
|
||||||
|
"""
|
||||||
|
config = {
|
||||||
|
"gpt": {
|
||||||
|
"url": settings.chatgpt_api_url,
|
||||||
|
"key": settings.chatgpt_api_key,
|
||||||
|
"model": settings.chatgpt_api_model
|
||||||
|
},
|
||||||
|
"deepseek": {
|
||||||
|
"url": settings.deepseek_api_url,
|
||||||
|
"key": settings.deepseek_api_key,
|
||||||
|
"model": settings.deepseek_api_model
|
||||||
|
},
|
||||||
|
"doubao": {
|
||||||
|
"url": settings.doubao_api_url,
|
||||||
|
"key": settings.doubao_api_key,
|
||||||
|
"model": settings.doubao_api_model
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return config.get(model_type, config["gpt"])
|
||||||
|
|
||||||
|
def _prepare_payload(self, messages: List[Dict], model_type: str, model_name: str) -> Dict:
|
||||||
|
"""准备请求payload
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: 消息列表
|
||||||
|
model_type: 模型类型
|
||||||
|
model_name: 具体模型名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
请求payload
|
||||||
|
"""
|
||||||
|
api_config = self._get_api_config(model_type)
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": model_name, # 使用映射后的具体模型名称
|
||||||
|
"messages": messages,
|
||||||
|
"temperature": self.temperature,
|
||||||
|
"top_p": self.top_p,
|
||||||
|
"max_tokens": self.max_tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
return payload
|
||||||
|
|
||||||
|
def _send_request(self, messages: List[Dict], model_type: str, model_name: str) -> Optional[Dict]:
|
||||||
|
"""发送API请求
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: 消息列表
|
||||||
|
model_type: 模型类型
|
||||||
|
model_name: 具体模型名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
API响应
|
||||||
|
"""
|
||||||
|
api_config = self._get_api_config(model_type)
|
||||||
|
payload = self._prepare_payload(messages, model_type, model_name)
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {api_config['key']}"
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.post(
|
||||||
|
api_config["url"],
|
||||||
|
headers=headers,
|
||||||
|
json=payload,
|
||||||
|
timeout=(10, 300)
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
return None
|
||||||
|
return response.json()
|
||||||
|
except Exception as e:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def chat_sync(self, messages: List[Dict]) -> Dict:
|
||||||
|
"""同步处理聊天请求
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: 消息列表,每个消息包含 role 和 content
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理结果
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
model_config = self._get_model_config(self.model or default_model)
|
||||||
|
model_type = model_config["type"]
|
||||||
|
model_name = model_config["model"]
|
||||||
|
|
||||||
|
ai_response = self._send_request(messages, model_type, model_name)
|
||||||
|
|
||||||
|
if ai_response is None:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "AI服务请求失败"
|
||||||
|
}
|
||||||
|
|
||||||
|
choices = ai_response.get("choices", [])
|
||||||
|
if not choices:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "AI响应无效"
|
||||||
|
}
|
||||||
|
|
||||||
|
message = choices[0].get("message", {}).get("content", "")
|
||||||
|
if not message:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "AI响应内容为空"
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"data": message
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": f"处理聊天任务时发生错误: {str(e)}"
|
||||||
|
}
|
||||||
|
|
||||||
|
async def chat(self, messages: List[Dict]) -> Dict:
|
||||||
|
"""异步处理聊天请求
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: 消息列表,每个消息包含 role 和 content
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理结果
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
result = await loop.run_in_executor(None, self.chat_sync, messages)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": f"聊天请求失败: {str(e)}"
|
||||||
|
}
|
||||||
50
apps/jchat/settings.py
Normal file
50
apps/jchat/settings.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
from typing import Optional
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
# Japi Server 配置
|
||||||
|
host: str = "0.0.0.0"
|
||||||
|
port: int = 8101
|
||||||
|
debug: bool = False
|
||||||
|
|
||||||
|
# API路由配置
|
||||||
|
router_prefix: str = "/jchat"
|
||||||
|
chat_route: str = "/chat"
|
||||||
|
default_api_name: str = "jingrow-chat" # 默认API名称
|
||||||
|
|
||||||
|
upload_url: str = "http://173.255.202.68/imgurl/upload"
|
||||||
|
|
||||||
|
# Jingrow Jcloud API 配置
|
||||||
|
jingrow_api_url: str = "https://cloud.jingrow.com"
|
||||||
|
jingrow_api_key: Optional[str] = None
|
||||||
|
jingrow_api_secret: Optional[str] = None
|
||||||
|
|
||||||
|
# DeepSeek配置
|
||||||
|
deepseek_api_url: str = "https://api.deepseek.com/v1/chat/completions"
|
||||||
|
deepseek_api_key: Optional[str] = None
|
||||||
|
deepseek_api_model: str = "deepseek-chat"
|
||||||
|
|
||||||
|
# Doubao配置
|
||||||
|
doubao_api_url: str = "https://ark.cn-beijing.volces.com/api/v3/chat/completions"
|
||||||
|
doubao_api_key: Optional[str] = None
|
||||||
|
doubao_api_model: str = "doubao-1-5-pro-32k-250115"
|
||||||
|
|
||||||
|
# ChatGPT配置
|
||||||
|
chatgpt_api_url: str = "https://api.openai.com/v1/chat/completions"
|
||||||
|
chatgpt_api_key: Optional[str] = None
|
||||||
|
chatgpt_api_model: str = "gpt-4"
|
||||||
|
|
||||||
|
# 默认服务模型配置
|
||||||
|
translation_model: str = "Doubao"
|
||||||
|
image_to_text_model: str = "Doubao"
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
env_file = ".env"
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def get_settings() -> Settings:
|
||||||
|
return Settings()
|
||||||
|
|
||||||
|
# 创建全局配置实例
|
||||||
|
settings = get_settings()
|
||||||
146
apps/jchat/utils.py
Normal file
146
apps/jchat/utils.py
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
import aiohttp
|
||||||
|
from functools import wraps
|
||||||
|
from fastapi import HTTPException
|
||||||
|
import os
|
||||||
|
from typing import Callable, Any, Dict, Optional, Tuple
|
||||||
|
from settings import settings
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
import json
|
||||||
|
|
||||||
|
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||||
|
"""验证API密钥和团队余额"""
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{settings.jingrow_api_url}/api/method/jcloud.api.account.verify_api_credentials_and_balance",
|
||||||
|
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||||
|
json={"api_key": api_key, "api_secret": api_secret, "api_name": api_name}
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise HTTPException(status_code=500, detail="验证服务暂时不可用")
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
|
||||||
|
if not result.get("success"):
|
||||||
|
raise HTTPException(status_code=401, detail=result.get("message", "验证失败"))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||||
|
|
||||||
|
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
||||||
|
"""从Jingrow平台扣除API使用费"""
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{settings.jingrow_api_url}/api/method/jcloud.api.account.deduct_api_usage_fee",
|
||||||
|
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||||
|
json={
|
||||||
|
"api_key": api_key,
|
||||||
|
"api_secret": api_secret,
|
||||||
|
"api_name": api_name,
|
||||||
|
"usage_count": usage_count
|
||||||
|
}
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
||||||
|
|
||||||
|
def get_token_from_request(request) -> str:
|
||||||
|
"""从请求中获取访问令牌"""
|
||||||
|
if not request:
|
||||||
|
raise HTTPException(status_code=400, detail="无法获取请求信息")
|
||||||
|
|
||||||
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
if not auth_header or not auth_header.startswith("token "):
|
||||||
|
raise HTTPException(status_code=401, detail="无效的Authorization头格式")
|
||||||
|
|
||||||
|
token = auth_header[6:]
|
||||||
|
if ":" not in token:
|
||||||
|
raise HTTPException(status_code=401, detail="无效的令牌格式")
|
||||||
|
|
||||||
|
return token
|
||||||
|
|
||||||
|
def jingrow_api_verify_and_billing(api_name: str):
|
||||||
|
"""Jingrow API 验证装饰器(带余额检查和扣费)"""
|
||||||
|
def decorator(func: Callable) -> Callable:
|
||||||
|
@wraps(func)
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
try:
|
||||||
|
request = kwargs.get('request')
|
||||||
|
if not request:
|
||||||
|
raise HTTPException(status_code=400, detail="无法获取请求信息")
|
||||||
|
|
||||||
|
token = get_token_from_request(request)
|
||||||
|
api_key, api_secret = token.split(":", 1)
|
||||||
|
|
||||||
|
verify_result = await verify_api_credentials_and_balance(api_key, api_secret, api_name)
|
||||||
|
if not verify_result.get("success"):
|
||||||
|
raise HTTPException(status_code=401, detail=verify_result.get("message", "验证失败"))
|
||||||
|
|
||||||
|
result = await func(*args, **kwargs)
|
||||||
|
|
||||||
|
usage_count = 1
|
||||||
|
try:
|
||||||
|
body_data = await request.json()
|
||||||
|
if isinstance(body_data, dict):
|
||||||
|
for key in ["items", "urls", "images", "files"]:
|
||||||
|
if key in body_data and isinstance(body_data[key], list):
|
||||||
|
usage_count = len(body_data[key])
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if isinstance(result, StreamingResponse):
|
||||||
|
original_generator = result.body_iterator
|
||||||
|
success_count = 0
|
||||||
|
|
||||||
|
async def wrapped_generator():
|
||||||
|
nonlocal success_count
|
||||||
|
async for chunk in original_generator:
|
||||||
|
try:
|
||||||
|
data = json.loads(chunk)
|
||||||
|
if isinstance(data, dict) and data.get("status") == "success":
|
||||||
|
success_count += 1
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
if success_count > 0:
|
||||||
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
wrapped_generator(),
|
||||||
|
media_type=result.media_type,
|
||||||
|
headers=result.headers
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(result, dict) and result.get("success") is True:
|
||||||
|
actual_usage_count = result.get("successful_count", usage_count)
|
||||||
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
||||||
|
return result
|
||||||
|
|
||||||
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"API验证过程发生异常: {str(e)}")
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
||||||
1
apps/jdescribe/__init__.py
Normal file
1
apps/jdescribe/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
# 使jchat目录成为Python包
|
||||||
36
apps/jdescribe/api.py
Normal file
36
apps/jdescribe/api.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
from service import ImageDescribeService
|
||||||
|
from utils import jingrow_api_verify_and_billing
|
||||||
|
from settings import settings
|
||||||
|
import json
|
||||||
|
|
||||||
|
router = APIRouter(prefix=settings.router_prefix)
|
||||||
|
service = ImageDescribeService()
|
||||||
|
|
||||||
|
@router.post(settings.get_route)
|
||||||
|
@jingrow_api_verify_and_billing(api_name=settings.api_name)
|
||||||
|
async def describe_image_api(data: dict, request: Request):
|
||||||
|
"""
|
||||||
|
根据图像URL生成中英文描述
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 包含以下字段的字典:
|
||||||
|
- image_url: 图片URL(必需)
|
||||||
|
- system_message: 自定义系统消息(可选)
|
||||||
|
- user_content: 自定义用户消息(可选)
|
||||||
|
request: FastAPI 请求对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
图像的中英文描述
|
||||||
|
"""
|
||||||
|
if "image_url" not in data:
|
||||||
|
raise HTTPException(status_code=400, detail="缺少image_url参数")
|
||||||
|
|
||||||
|
# 如果提供了自定义消息,则更新service实例的消息
|
||||||
|
if "system_message" in data:
|
||||||
|
service.system_message = data["system_message"]
|
||||||
|
if "user_content" in data:
|
||||||
|
service.user_content = data["user_content"]
|
||||||
|
|
||||||
|
result = await service.describe_image(data["image_url"])
|
||||||
|
return result
|
||||||
21
apps/jdescribe/app.py
Normal file
21
apps/jdescribe/app.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
from fastapi import FastAPI
|
||||||
|
from settings import settings
|
||||||
|
from api import router
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="Jdescribe",
|
||||||
|
description="Jdescribe描述图片API",
|
||||||
|
version="1.0.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 注册路由
|
||||||
|
app.include_router(router)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
uvicorn.run(
|
||||||
|
"app:app",
|
||||||
|
host=settings.host,
|
||||||
|
port=settings.port,
|
||||||
|
reload=settings.debug
|
||||||
|
)
|
||||||
193
apps/jdescribe/service.py
Normal file
193
apps/jdescribe/service.py
Normal file
@ -0,0 +1,193 @@
|
|||||||
|
import json
|
||||||
|
import requests
|
||||||
|
from pathlib import Path
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
from PIL import Image
|
||||||
|
import io
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from typing import Dict, Optional
|
||||||
|
from settings import settings
|
||||||
|
|
||||||
|
image_to_text_model = "Doubao"
|
||||||
|
deepseek_api_model = "deepseek-chat"
|
||||||
|
doubao_api_model = "doubao-1.5-vision-pro-250328"
|
||||||
|
chatgpt_api_model = "gpt-4o"
|
||||||
|
|
||||||
|
default_system_message = """
|
||||||
|
请用中英文分别描述该图片,使用结构化描述,描述的内容用于ai绘画,因此请优化内容,不要用这是开头,使之适合用作ai绘画prompts。
|
||||||
|
输出格式为:
|
||||||
|
{
|
||||||
|
"中文描述": "中文内容",
|
||||||
|
"英文描述": "英文内容"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
default_user_content = "请用中英文分别生成该图片的内容描述。"
|
||||||
|
|
||||||
|
class ImageDescribeService:
|
||||||
|
def __init__(self, system_message: str = None, user_content: str = None):
|
||||||
|
"""初始化图像描述服务
|
||||||
|
|
||||||
|
Args:
|
||||||
|
system_message: 自定义系统提示词
|
||||||
|
user_content: 自定义用户提示词
|
||||||
|
"""
|
||||||
|
self.system_message = system_message or default_system_message
|
||||||
|
self.user_content = user_content or default_user_content
|
||||||
|
|
||||||
|
def send_to_chatgpt(self, image_url: str) -> Optional[Dict]:
|
||||||
|
"""向ChatGPT发送图像描述请求"""
|
||||||
|
payload = {
|
||||||
|
"model": chatgpt_api_model,
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": self.system_message
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": f"{self.user_content}\n\n图片链接: {image_url}"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"temperature": 0.9,
|
||||||
|
"top_p": 0.9
|
||||||
|
}
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {settings.chatgpt_api_key}"
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(settings.chatgpt_api_url, headers=headers, json=payload)
|
||||||
|
if response.status_code != 200:
|
||||||
|
print(f"Error: {response.status_code}, {response.text}")
|
||||||
|
return None
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
def send_to_deepseek(self, image_url: str) -> Optional[Dict]:
|
||||||
|
"""向DeepSeek发送图像描述请求"""
|
||||||
|
payload = {
|
||||||
|
"model": deepseek_api_model,
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": self.system_message
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": f"{self.user_content}\n\n图片链接: {image_url}"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"temperature": 0.9,
|
||||||
|
"top_p": 0.9
|
||||||
|
}
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {settings.deepseek_api_key}"
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(settings.deepseek_api_url, headers=headers, json=payload)
|
||||||
|
if response.status_code != 200:
|
||||||
|
print(f"Error: {response.status_code}, {response.text}")
|
||||||
|
return None
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
def send_to_doubao(self, image_url: str) -> Optional[Dict]:
|
||||||
|
"""向Doubao发送图像描述请求"""
|
||||||
|
payload = {
|
||||||
|
"model": doubao_api_model,
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": self.system_message
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": self.user_content
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": image_url
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"temperature": 0.9,
|
||||||
|
"top_p": 0.9
|
||||||
|
}
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {settings.doubao_api_key}"
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(settings.doubao_api_url, headers=headers, json=payload)
|
||||||
|
if response.status_code != 200:
|
||||||
|
print(f"Error: {response.status_code}, {response.text}")
|
||||||
|
return None
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
def describe_image_sync(self, image_url: str) -> Dict:
|
||||||
|
"""同步处理图像描述请求"""
|
||||||
|
try:
|
||||||
|
# 选择合适的AI模型处理请求
|
||||||
|
if image_to_text_model == "DeepSeek":
|
||||||
|
ai_response = self.send_to_deepseek(image_url)
|
||||||
|
elif image_to_text_model == "Doubao":
|
||||||
|
ai_response = self.send_to_doubao(image_url)
|
||||||
|
else:
|
||||||
|
ai_response = self.send_to_chatgpt(image_url)
|
||||||
|
|
||||||
|
if ai_response is None:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "AI服务请求失败"
|
||||||
|
}
|
||||||
|
|
||||||
|
choices = ai_response.get("choices", [])
|
||||||
|
if not choices:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "AI响应无效"
|
||||||
|
}
|
||||||
|
|
||||||
|
message = choices[0].get("message", {}).get("content", "")
|
||||||
|
response_data = json.loads(message)
|
||||||
|
cn_description = response_data.get("中文描述", "")
|
||||||
|
en_description = response_data.get("英文描述", "")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"data": {
|
||||||
|
"cn_description": cn_description,
|
||||||
|
"en_description": en_description
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"描述任务处理失败: {str(e)}")
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": f"处理描述任务时发生错误: {str(e)}"
|
||||||
|
}
|
||||||
|
|
||||||
|
async def describe_image(self, image_url: str) -> Dict:
|
||||||
|
"""异步处理图像描述请求"""
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
result = await loop.run_in_executor(None, self.describe_image_sync, image_url)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": f"图像描述失败: {str(e)}"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
50
apps/jdescribe/settings.py
Normal file
50
apps/jdescribe/settings.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
from typing import Optional
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
# Japi Server 配置
|
||||||
|
host: str = "0.0.0.0"
|
||||||
|
port: int = 8107
|
||||||
|
debug: bool = False
|
||||||
|
|
||||||
|
# API路由配置
|
||||||
|
router_prefix: str = "/jdescribe"
|
||||||
|
get_route: str = "/get"
|
||||||
|
api_name: str = "jdescribe" # 默认API名称
|
||||||
|
|
||||||
|
upload_url: str = "http://173.255.202.68/imgurl/upload"
|
||||||
|
|
||||||
|
# Jingrow Jcloud API 配置
|
||||||
|
jingrow_api_url: str = "https://cloud.jingrow.com"
|
||||||
|
jingrow_api_key: Optional[str] = None
|
||||||
|
jingrow_api_secret: Optional[str] = None
|
||||||
|
|
||||||
|
# DeepSeek配置
|
||||||
|
deepseek_api_url: str = "https://api.deepseek.com/v1/chat/completions"
|
||||||
|
deepseek_api_key: Optional[str] = None
|
||||||
|
deepseek_api_model: str = "deepseek-chat"
|
||||||
|
|
||||||
|
# Doubao配置
|
||||||
|
doubao_api_url: str = "https://ark.cn-beijing.volces.com/api/v3/chat/completions"
|
||||||
|
doubao_api_key: Optional[str] = None
|
||||||
|
doubao_api_model: str = "doubao-1-5-pro-32k-250115"
|
||||||
|
|
||||||
|
# ChatGPT配置
|
||||||
|
chatgpt_api_url: str = "https://api.openai.com/v1/chat/completions"
|
||||||
|
chatgpt_api_key: Optional[str] = None
|
||||||
|
chatgpt_api_model: str = "gpt-4"
|
||||||
|
|
||||||
|
# 默认服务模型配置
|
||||||
|
translation_model: str = "Doubao"
|
||||||
|
image_to_text_model: str = "Doubao"
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
env_file = ".env"
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def get_settings() -> Settings:
|
||||||
|
return Settings()
|
||||||
|
|
||||||
|
# 创建全局配置实例
|
||||||
|
settings = get_settings()
|
||||||
146
apps/jdescribe/utils.py
Normal file
146
apps/jdescribe/utils.py
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
import aiohttp
|
||||||
|
from functools import wraps
|
||||||
|
from fastapi import HTTPException
|
||||||
|
import os
|
||||||
|
from typing import Callable, Any, Dict, Optional, Tuple
|
||||||
|
from settings import settings
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
import json
|
||||||
|
|
||||||
|
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||||
|
"""验证API密钥和团队余额"""
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{settings.jingrow_api_url}/api/method/jcloud.api.account.verify_api_credentials_and_balance",
|
||||||
|
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||||
|
json={"api_key": api_key, "api_secret": api_secret, "api_name": api_name}
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise HTTPException(status_code=500, detail="验证服务暂时不可用")
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
|
||||||
|
if not result.get("success"):
|
||||||
|
raise HTTPException(status_code=401, detail=result.get("message", "验证失败"))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||||
|
|
||||||
|
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
||||||
|
"""从Jingrow平台扣除API使用费"""
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{settings.jingrow_api_url}/api/method/jcloud.api.account.deduct_api_usage_fee",
|
||||||
|
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||||
|
json={
|
||||||
|
"api_key": api_key,
|
||||||
|
"api_secret": api_secret,
|
||||||
|
"api_name": api_name,
|
||||||
|
"usage_count": usage_count
|
||||||
|
}
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
||||||
|
|
||||||
|
def get_token_from_request(request) -> str:
|
||||||
|
"""从请求中获取访问令牌"""
|
||||||
|
if not request:
|
||||||
|
raise HTTPException(status_code=400, detail="无法获取请求信息")
|
||||||
|
|
||||||
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
if not auth_header or not auth_header.startswith("token "):
|
||||||
|
raise HTTPException(status_code=401, detail="无效的Authorization头格式")
|
||||||
|
|
||||||
|
token = auth_header[6:]
|
||||||
|
if ":" not in token:
|
||||||
|
raise HTTPException(status_code=401, detail="无效的令牌格式")
|
||||||
|
|
||||||
|
return token
|
||||||
|
|
||||||
|
def jingrow_api_verify_and_billing(api_name: str):
|
||||||
|
"""Jingrow API 验证装饰器(带余额检查和扣费)"""
|
||||||
|
def decorator(func: Callable) -> Callable:
|
||||||
|
@wraps(func)
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
try:
|
||||||
|
request = kwargs.get('request')
|
||||||
|
if not request:
|
||||||
|
raise HTTPException(status_code=400, detail="无法获取请求信息")
|
||||||
|
|
||||||
|
token = get_token_from_request(request)
|
||||||
|
api_key, api_secret = token.split(":", 1)
|
||||||
|
|
||||||
|
verify_result = await verify_api_credentials_and_balance(api_key, api_secret, api_name)
|
||||||
|
if not verify_result.get("success"):
|
||||||
|
raise HTTPException(status_code=401, detail=verify_result.get("message", "验证失败"))
|
||||||
|
|
||||||
|
result = await func(*args, **kwargs)
|
||||||
|
|
||||||
|
usage_count = 1
|
||||||
|
try:
|
||||||
|
body_data = await request.json()
|
||||||
|
if isinstance(body_data, dict):
|
||||||
|
for key in ["items", "urls", "images", "files"]:
|
||||||
|
if key in body_data and isinstance(body_data[key], list):
|
||||||
|
usage_count = len(body_data[key])
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if isinstance(result, StreamingResponse):
|
||||||
|
original_generator = result.body_iterator
|
||||||
|
success_count = 0
|
||||||
|
|
||||||
|
async def wrapped_generator():
|
||||||
|
nonlocal success_count
|
||||||
|
async for chunk in original_generator:
|
||||||
|
try:
|
||||||
|
data = json.loads(chunk)
|
||||||
|
if isinstance(data, dict) and data.get("status") == "success":
|
||||||
|
success_count += 1
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
if success_count > 0:
|
||||||
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
wrapped_generator(),
|
||||||
|
media_type=result.media_type,
|
||||||
|
headers=result.headers
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(result, dict) and result.get("success") is True:
|
||||||
|
actual_usage_count = result.get("successful_count", usage_count)
|
||||||
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
||||||
|
return result
|
||||||
|
|
||||||
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"API验证过程发生异常: {str(e)}")
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
||||||
35
apps/jfile/app.py
Normal file
35
apps/jfile/app.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
import uvicorn
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
from file_cleaner import FileCleaner
|
||||||
|
from settings import settings
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="www",
|
||||||
|
description="公共静态资源访问服务",
|
||||||
|
version="1.0.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 挂载静态文件目录
|
||||||
|
app.mount("/files", StaticFiles(directory="files"), name="files")
|
||||||
|
|
||||||
|
|
||||||
|
# 注册文件定时清理任务
|
||||||
|
save_dir = "files"
|
||||||
|
file_prefix = "upscaled_"
|
||||||
|
retention_hours = settings.file_retention_hours
|
||||||
|
cleaner = FileCleaner(save_dir, file_prefix, retention_hours)
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
async def startup_event():
|
||||||
|
asyncio.create_task(cleaner.periodic_cleanup())
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
uvicorn.run(
|
||||||
|
"app:app",
|
||||||
|
host=settings.host,
|
||||||
|
port=settings.port,
|
||||||
|
reload=settings.debug
|
||||||
|
)
|
||||||
33
apps/jfile/file_cleaner.py
Normal file
33
apps/jfile/file_cleaner.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
class FileCleaner:
|
||||||
|
def __init__(self, target_dir, prefix, retention_hours):
|
||||||
|
self.target_dir = target_dir
|
||||||
|
self.prefix = prefix
|
||||||
|
self.retention_hours = retention_hours
|
||||||
|
|
||||||
|
async def periodic_cleanup(self):
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
self.cleanup_old_files()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"清理文件时出错: {str(e)}")
|
||||||
|
await asyncio.sleep(3600)
|
||||||
|
|
||||||
|
def cleanup_old_files(self):
|
||||||
|
if not os.path.exists(self.target_dir):
|
||||||
|
return
|
||||||
|
cutoff_time = datetime.now() - timedelta(hours=self.retention_hours)
|
||||||
|
for filename in os.listdir(self.target_dir):
|
||||||
|
if not filename.startswith(self.prefix):
|
||||||
|
continue
|
||||||
|
file_path = os.path.join(self.target_dir, filename)
|
||||||
|
file_time = datetime.fromtimestamp(os.path.getctime(file_path))
|
||||||
|
if file_time < cutoff_time:
|
||||||
|
try:
|
||||||
|
os.remove(file_path)
|
||||||
|
print(f"已删除过期文件: {filename}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"删除文件失败 {filename}: {str(e)}")
|
||||||
BIN
apps/jfile/files/upscaled_2fead90ea9.jpg
Normal file
BIN
apps/jfile/files/upscaled_2fead90ea9.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 2.2 MiB |
BIN
apps/jfile/files/upscaled_466670e1cb.jpg
Normal file
BIN
apps/jfile/files/upscaled_466670e1cb.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1.8 MiB |
21
apps/jfile/settings.py
Normal file
21
apps/jfile/settings.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
# Japi Server 配置
|
||||||
|
host: str = "0.0.0.0"
|
||||||
|
port: int = 8100
|
||||||
|
debug: bool = False
|
||||||
|
|
||||||
|
# Japi 静态资源下载URL
|
||||||
|
download_url: str = "http://api.jingrow.com:9080/files"
|
||||||
|
|
||||||
|
# 文件保留时间(小时)
|
||||||
|
file_retention_hours: int = 1
|
||||||
|
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
env_file = ".env"
|
||||||
|
|
||||||
|
# 创建全局配置实例
|
||||||
|
settings = Settings()
|
||||||
1
apps/jtranslate/__init__.py
Normal file
1
apps/jtranslate/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
# 使jchat目录成为Python包
|
||||||
35
apps/jtranslate/api.py
Normal file
35
apps/jtranslate/api.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
from service import TranslateService
|
||||||
|
from utils import jingrow_api_verify_and_billing
|
||||||
|
from settings import settings
|
||||||
|
|
||||||
|
router = APIRouter(prefix=settings.router_prefix)
|
||||||
|
service = TranslateService()
|
||||||
|
|
||||||
|
@router.post(settings.get_route)
|
||||||
|
@jingrow_api_verify_and_billing(api_name=settings.api_name)
|
||||||
|
async def translate_text_api(data: dict, request: Request):
|
||||||
|
"""
|
||||||
|
将中文文本翻译成英文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 包含以下字段的字典:
|
||||||
|
- source_text: 源文本(必需)
|
||||||
|
- system_message: 自定义系统消息(可选)
|
||||||
|
- user_content: 自定义用户消息(可选)
|
||||||
|
request: FastAPI 请求对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
翻译后的英文文本
|
||||||
|
"""
|
||||||
|
if "source_text" not in data:
|
||||||
|
raise HTTPException(status_code=400, detail="缺少source_text参数")
|
||||||
|
|
||||||
|
# 如果提供了自定义消息,则更新service实例的消息
|
||||||
|
if "system_message" in data:
|
||||||
|
service.system_message = data["system_message"]
|
||||||
|
if "user_content" in data:
|
||||||
|
service.user_content = data["user_content"]
|
||||||
|
|
||||||
|
result = await service.translate_text(data["source_text"])
|
||||||
|
return result
|
||||||
21
apps/jtranslate/app.py
Normal file
21
apps/jtranslate/app.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
from fastapi import FastAPI
|
||||||
|
from settings import settings
|
||||||
|
from api import router
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="Jtranslate",
|
||||||
|
description="Jtranslate翻译API",
|
||||||
|
version="1.0.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 注册路由
|
||||||
|
app.include_router(router)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
uvicorn.run(
|
||||||
|
"app:app",
|
||||||
|
host=settings.host,
|
||||||
|
port=settings.port,
|
||||||
|
reload=settings.debug
|
||||||
|
)
|
||||||
179
apps/jtranslate/service.py
Normal file
179
apps/jtranslate/service.py
Normal file
@ -0,0 +1,179 @@
|
|||||||
|
import json
|
||||||
|
import requests
|
||||||
|
import os
|
||||||
|
from typing import Dict, Optional
|
||||||
|
from settings import settings
|
||||||
|
|
||||||
|
# 模型配置
|
||||||
|
translation_model = "Doubao"
|
||||||
|
deepseek_api_model = "deepseek-chat"
|
||||||
|
doubao_api_model = "doubao-1-5-pro-32k-250115"
|
||||||
|
chatgpt_api_model = "gpt-4o"
|
||||||
|
|
||||||
|
|
||||||
|
# 自定义提示词配置
|
||||||
|
default_system_message = """
|
||||||
|
你是一位专业的中译英翻译专家。请将提供的中文内容翻译成地道、流畅的英文,确保保留原文的风格和语境。
|
||||||
|
翻译时要注意原文的专业术语和表达方式,使翻译结果符合英语的最佳实践。
|
||||||
|
只需返回翻译后的英文内容,不要包含任何其他说明或注释。
|
||||||
|
"""
|
||||||
|
|
||||||
|
default_user_content = "请将以下中文内容翻译成英文:\n\n{source_text}"
|
||||||
|
|
||||||
|
class TranslateService:
|
||||||
|
def __init__(self, system_message: str = None, user_content: str = None):
|
||||||
|
"""初始化翻译服务
|
||||||
|
|
||||||
|
Args:
|
||||||
|
system_message: 自定义系统提示词
|
||||||
|
user_content: 自定义用户提示词
|
||||||
|
"""
|
||||||
|
self.system_message = system_message or default_system_message
|
||||||
|
self.user_content = user_content or default_user_content
|
||||||
|
|
||||||
|
def send_to_chatgpt(self, source_text: str) -> Optional[Dict]:
|
||||||
|
"""向ChatGPT发送翻译请求"""
|
||||||
|
payload = {
|
||||||
|
"model": chatgpt_api_model,
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": self.system_message
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": self.user_content.format(source_text=source_text)
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"temperature": 0.3,
|
||||||
|
"top_p": 0.9
|
||||||
|
}
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {settings.chatgpt_api_key}"
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(settings.chatgpt_api_url, headers=headers, json=payload)
|
||||||
|
if response.status_code != 200:
|
||||||
|
print(f"Error: {response.status_code}, {response.text}")
|
||||||
|
return None
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
def send_to_deepseek(self, source_text: str) -> Optional[Dict]:
|
||||||
|
"""向DeepSeek发送翻译请求"""
|
||||||
|
payload = {
|
||||||
|
"model": deepseek_api_model,
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": self.system_message
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": self.user_content.format(source_text=source_text)
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"temperature": 0.3,
|
||||||
|
"top_p": 0.9
|
||||||
|
}
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {settings.deepseek_api_key}"
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(settings.deepseek_api_url, headers=headers, json=payload)
|
||||||
|
if response.status_code != 200:
|
||||||
|
print(f"Error: {response.status_code}, {response.text}")
|
||||||
|
return None
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
def send_to_doubao(self, source_text: str) -> Optional[Dict]:
|
||||||
|
"""向Doubao发送翻译请求"""
|
||||||
|
payload = {
|
||||||
|
"model": doubao_api_model,
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": self.system_message
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": self.user_content.format(source_text=source_text)
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"temperature": 0.3,
|
||||||
|
"top_p": 0.9
|
||||||
|
}
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {settings.doubao_api_key}"
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(settings.doubao_api_url, headers=headers, json=payload)
|
||||||
|
if response.status_code != 200:
|
||||||
|
print(f"Error: {response.status_code}, {response.text}")
|
||||||
|
return None
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
def translate_text_sync(self, source_text: str) -> Dict:
|
||||||
|
"""同步处理翻译请求"""
|
||||||
|
try:
|
||||||
|
if not source_text:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "未提供翻译文本"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 选择合适的AI模型处理请求
|
||||||
|
if translation_model == "DeepSeek":
|
||||||
|
ai_response = self.send_to_deepseek(source_text)
|
||||||
|
elif translation_model == "Doubao":
|
||||||
|
ai_response = self.send_to_doubao(source_text)
|
||||||
|
else:
|
||||||
|
ai_response = self.send_to_chatgpt(source_text)
|
||||||
|
|
||||||
|
if ai_response is None:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "AI服务请求失败"
|
||||||
|
}
|
||||||
|
|
||||||
|
choices = ai_response.get("choices", [])
|
||||||
|
if not choices:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": "AI响应无效"
|
||||||
|
}
|
||||||
|
|
||||||
|
english_translation = choices[0].get("message", {}).get("content", "").strip()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"data": {
|
||||||
|
"english_translation": english_translation
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"翻译任务处理失败: {str(e)}")
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": f"处理翻译任务时发生错误: {str(e)}"
|
||||||
|
}
|
||||||
|
|
||||||
|
async def translate_text(self, source_text: str) -> Dict:
|
||||||
|
"""异步处理翻译请求"""
|
||||||
|
try:
|
||||||
|
import asyncio
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
result = await loop.run_in_executor(None, self.translate_text_sync, source_text)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"status": "error",
|
||||||
|
"message": f"翻译失败: {str(e)}"
|
||||||
|
}
|
||||||
|
|
||||||
50
apps/jtranslate/settings.py
Normal file
50
apps/jtranslate/settings.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
from typing import Optional
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
# Japi Server 配置
|
||||||
|
host: str = "0.0.0.0"
|
||||||
|
port: int = 8108
|
||||||
|
debug: bool = False
|
||||||
|
|
||||||
|
# API路由配置
|
||||||
|
router_prefix: str = "/jtranslate"
|
||||||
|
get_route: str = "/get"
|
||||||
|
api_name: str = "jtranslate" # 默认API名称
|
||||||
|
|
||||||
|
upload_url: str = "http://173.255.202.68/imgurl/upload"
|
||||||
|
|
||||||
|
# Jingrow Jcloud API 配置
|
||||||
|
jingrow_api_url: str = "https://cloud.jingrow.com"
|
||||||
|
jingrow_api_key: Optional[str] = None
|
||||||
|
jingrow_api_secret: Optional[str] = None
|
||||||
|
|
||||||
|
# DeepSeek配置
|
||||||
|
deepseek_api_url: str = "https://api.deepseek.com/v1/chat/completions"
|
||||||
|
deepseek_api_key: Optional[str] = None
|
||||||
|
deepseek_api_model: str = "deepseek-chat"
|
||||||
|
|
||||||
|
# Doubao配置
|
||||||
|
doubao_api_url: str = "https://ark.cn-beijing.volces.com/api/v3/chat/completions"
|
||||||
|
doubao_api_key: Optional[str] = None
|
||||||
|
doubao_api_model: str = "doubao-1-5-pro-32k-250115"
|
||||||
|
|
||||||
|
# ChatGPT配置
|
||||||
|
chatgpt_api_url: str = "https://api.openai.com/v1/chat/completions"
|
||||||
|
chatgpt_api_key: Optional[str] = None
|
||||||
|
chatgpt_api_model: str = "gpt-4"
|
||||||
|
|
||||||
|
# 默认服务模型配置
|
||||||
|
translation_model: str = "Doubao"
|
||||||
|
image_to_text_model: str = "Doubao"
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
env_file = ".env"
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def get_settings() -> Settings:
|
||||||
|
return Settings()
|
||||||
|
|
||||||
|
# 创建全局配置实例
|
||||||
|
settings = get_settings()
|
||||||
146
apps/jtranslate/utils.py
Normal file
146
apps/jtranslate/utils.py
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
import aiohttp
|
||||||
|
from functools import wraps
|
||||||
|
from fastapi import HTTPException
|
||||||
|
import os
|
||||||
|
from typing import Callable, Any, Dict, Optional, Tuple
|
||||||
|
from settings import settings
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
import json
|
||||||
|
|
||||||
|
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||||
|
"""验证API密钥和团队余额"""
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{settings.jingrow_api_url}/api/method/jcloud.api.account.verify_api_credentials_and_balance",
|
||||||
|
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||||
|
json={"api_key": api_key, "api_secret": api_secret, "api_name": api_name}
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise HTTPException(status_code=500, detail="验证服务暂时不可用")
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
|
||||||
|
if not result.get("success"):
|
||||||
|
raise HTTPException(status_code=401, detail=result.get("message", "验证失败"))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||||
|
|
||||||
|
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
||||||
|
"""从Jingrow平台扣除API使用费"""
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{settings.jingrow_api_url}/api/method/jcloud.api.account.deduct_api_usage_fee",
|
||||||
|
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||||
|
json={
|
||||||
|
"api_key": api_key,
|
||||||
|
"api_secret": api_secret,
|
||||||
|
"api_name": api_name,
|
||||||
|
"usage_count": usage_count
|
||||||
|
}
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
||||||
|
|
||||||
|
def get_token_from_request(request) -> str:
|
||||||
|
"""从请求中获取访问令牌"""
|
||||||
|
if not request:
|
||||||
|
raise HTTPException(status_code=400, detail="无法获取请求信息")
|
||||||
|
|
||||||
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
if not auth_header or not auth_header.startswith("token "):
|
||||||
|
raise HTTPException(status_code=401, detail="无效的Authorization头格式")
|
||||||
|
|
||||||
|
token = auth_header[6:]
|
||||||
|
if ":" not in token:
|
||||||
|
raise HTTPException(status_code=401, detail="无效的令牌格式")
|
||||||
|
|
||||||
|
return token
|
||||||
|
|
||||||
|
def jingrow_api_verify_and_billing(api_name: str):
|
||||||
|
"""Jingrow API 验证装饰器(带余额检查和扣费)"""
|
||||||
|
def decorator(func: Callable) -> Callable:
|
||||||
|
@wraps(func)
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
try:
|
||||||
|
request = kwargs.get('request')
|
||||||
|
if not request:
|
||||||
|
raise HTTPException(status_code=400, detail="无法获取请求信息")
|
||||||
|
|
||||||
|
token = get_token_from_request(request)
|
||||||
|
api_key, api_secret = token.split(":", 1)
|
||||||
|
|
||||||
|
verify_result = await verify_api_credentials_and_balance(api_key, api_secret, api_name)
|
||||||
|
if not verify_result.get("success"):
|
||||||
|
raise HTTPException(status_code=401, detail=verify_result.get("message", "验证失败"))
|
||||||
|
|
||||||
|
result = await func(*args, **kwargs)
|
||||||
|
|
||||||
|
usage_count = 1
|
||||||
|
try:
|
||||||
|
body_data = await request.json()
|
||||||
|
if isinstance(body_data, dict):
|
||||||
|
for key in ["items", "urls", "images", "files"]:
|
||||||
|
if key in body_data and isinstance(body_data[key], list):
|
||||||
|
usage_count = len(body_data[key])
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if isinstance(result, StreamingResponse):
|
||||||
|
original_generator = result.body_iterator
|
||||||
|
success_count = 0
|
||||||
|
|
||||||
|
async def wrapped_generator():
|
||||||
|
nonlocal success_count
|
||||||
|
async for chunk in original_generator:
|
||||||
|
try:
|
||||||
|
data = json.loads(chunk)
|
||||||
|
if isinstance(data, dict) and data.get("status") == "success":
|
||||||
|
success_count += 1
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
if success_count > 0:
|
||||||
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
wrapped_generator(),
|
||||||
|
media_type=result.media_type,
|
||||||
|
headers=result.headers
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(result, dict) and result.get("success") is True:
|
||||||
|
actual_usage_count = result.get("successful_count", usage_count)
|
||||||
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
||||||
|
return result
|
||||||
|
|
||||||
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"API验证过程发生异常: {str(e)}")
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
||||||
0
apps/jupscale/__init__.py
Normal file
0
apps/jupscale/__init__.py
Normal file
57
apps/jupscale/api.py
Normal file
57
apps/jupscale/api.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
from fastapi import APIRouter, HTTPException, Request
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from service import ImageUpscaleService
|
||||||
|
from utils import jingrow_api_verify_and_billing
|
||||||
|
from settings import settings
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
router = APIRouter(prefix=settings.router_prefix)
|
||||||
|
service = ImageUpscaleService()
|
||||||
|
|
||||||
|
@router.post(settings.upscale_route)
|
||||||
|
@jingrow_api_verify_and_billing(api_name=settings.api_name)
|
||||||
|
async def upscale_image_api(data: dict, request: Request):
|
||||||
|
"""
|
||||||
|
根据图像URL放大图像
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 包含图像URL的字典
|
||||||
|
request: FastAPI 请求对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
放大后的图片URL
|
||||||
|
"""
|
||||||
|
if "image_url" not in data:
|
||||||
|
raise HTTPException(status_code=400, detail="缺少image_url参数")
|
||||||
|
|
||||||
|
result = await service.upscale_image(data["image_url"])
|
||||||
|
return result
|
||||||
|
|
||||||
|
@router.post(settings.batch_route)
|
||||||
|
@jingrow_api_verify_and_billing(api_name=settings.api_name)
|
||||||
|
async def upscale_image_batch(data: dict, request: Request):
|
||||||
|
"""
|
||||||
|
批量处理多个图像URL
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 包含图像URL列表的字典
|
||||||
|
request: FastAPI 请求对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
流式响应,包含每个图像的处理结果(图片URL)
|
||||||
|
"""
|
||||||
|
if "image_urls" not in data:
|
||||||
|
raise HTTPException(status_code=400, detail="缺少image_urls参数")
|
||||||
|
|
||||||
|
async def process_and_stream():
|
||||||
|
async for result in service.process_batch(data["image_urls"]):
|
||||||
|
yield json.dumps(result) + "\n"
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
process_and_stream(),
|
||||||
|
media_type="application/x-ndjson"
|
||||||
|
)
|
||||||
21
apps/jupscale/app.py
Normal file
21
apps/jupscale/app.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
from fastapi import FastAPI
|
||||||
|
from settings import settings
|
||||||
|
from api import router
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="Jupscale",
|
||||||
|
description="Jupscale放大图片API",
|
||||||
|
version="1.0.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 注册路由
|
||||||
|
app.include_router(router)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
uvicorn.run(
|
||||||
|
"app:app",
|
||||||
|
host=settings.host,
|
||||||
|
port=settings.port,
|
||||||
|
reload=settings.debug
|
||||||
|
)
|
||||||
337
apps/jupscale/service.py
Normal file
337
apps/jupscale/service.py
Normal file
@ -0,0 +1,337 @@
|
|||||||
|
import json
|
||||||
|
import base64
|
||||||
|
import requests
|
||||||
|
import websocket
|
||||||
|
import uuid
|
||||||
|
import urllib.request
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from typing import Dict, List, Generator, Optional, AsyncGenerator
|
||||||
|
from settings import settings
|
||||||
|
from PIL import Image
|
||||||
|
import io
|
||||||
|
import tempfile
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# 默认配置
|
||||||
|
default_config = {
|
||||||
|
"comfyui_server_address": settings.comfyui_server_address,
|
||||||
|
"upscale_model_name": "4xNomos2_otf_esrgan.pth",
|
||||||
|
}
|
||||||
|
|
||||||
|
save_dir = settings.save_dir
|
||||||
|
download_url = settings.download_url
|
||||||
|
|
||||||
|
# 定义基础工作流 JSON 模板
|
||||||
|
workflow_template = """
|
||||||
|
{
|
||||||
|
"13": {
|
||||||
|
"inputs": {
|
||||||
|
"model_name": ""
|
||||||
|
},
|
||||||
|
"class_type": "UpscaleModelLoader",
|
||||||
|
"_meta": {
|
||||||
|
"title": "Load Upscale Model"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"14": {
|
||||||
|
"inputs": {
|
||||||
|
"upscale_model": [
|
||||||
|
"13",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"image": [
|
||||||
|
"15",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "ImageUpscaleWithModel",
|
||||||
|
"_meta": {
|
||||||
|
"title": "Upscale Image (using Model)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"15": {
|
||||||
|
"inputs": {
|
||||||
|
"url_or_path": ""
|
||||||
|
},
|
||||||
|
"class_type": "LoadImageFromUrlOrPath",
|
||||||
|
"_meta": {
|
||||||
|
"title": "LoadImageFromUrlOrPath"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"16": {
|
||||||
|
"inputs": {
|
||||||
|
"images": [
|
||||||
|
"14",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "SaveImageWebsocket",
|
||||||
|
"_meta": {
|
||||||
|
"title": "SaveImageWebsocket"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
class ImageUpscaleService:
|
||||||
|
def __init__(self):
|
||||||
|
"""初始化图像放大服务"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def check_image_transparency(self, image_url: str) -> tuple:
|
||||||
|
"""检查图像是否有透明通道,返回图像和是否透明的标志"""
|
||||||
|
try:
|
||||||
|
# 下载图片
|
||||||
|
response = requests.get(image_url)
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise Exception(f"无法下载图片: {image_url}")
|
||||||
|
|
||||||
|
# 使用PIL打开图片
|
||||||
|
img = Image.open(io.BytesIO(response.content))
|
||||||
|
|
||||||
|
# 检查图像是否有透明通道
|
||||||
|
has_transparency = img.mode in ('RGBA', 'LA') and img.format == 'PNG'
|
||||||
|
|
||||||
|
return img, has_transparency
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"图片处理失败: {str(e)}")
|
||||||
|
|
||||||
|
def prepare_image_for_upscale(self, image_url: str) -> tuple:
|
||||||
|
"""根据图像类型准备图像用于放大,返回处理后的图像URL和透明标志"""
|
||||||
|
img, has_transparency = self.check_image_transparency(image_url)
|
||||||
|
|
||||||
|
if not has_transparency:
|
||||||
|
# 非透明图像直接使用原图
|
||||||
|
return image_url, False, None
|
||||||
|
|
||||||
|
# 对于透明PNG,我们需要分离RGB和Alpha通道
|
||||||
|
rgb_image = img.convert('RGB')
|
||||||
|
alpha_channel = img.split()[-1]
|
||||||
|
|
||||||
|
# 保存RGB图像到临时文件
|
||||||
|
rgb_temp_file = tempfile.NamedTemporaryFile(suffix='.jpg', delete=False)
|
||||||
|
rgb_image.save(rgb_temp_file.name, 'JPEG', quality=95)
|
||||||
|
|
||||||
|
# 保存Alpha通道到临时文件
|
||||||
|
alpha_temp_file = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
|
||||||
|
alpha_channel.save(alpha_temp_file.name, 'PNG')
|
||||||
|
|
||||||
|
return rgb_temp_file.name, True, alpha_temp_file.name
|
||||||
|
|
||||||
|
def upscale_alpha_channel(self, alpha_path: str, scale_factor: int = 4) -> Image.Image:
|
||||||
|
"""使用双线性插值放大Alpha通道"""
|
||||||
|
alpha_img = Image.open(alpha_path)
|
||||||
|
width, height = alpha_img.size
|
||||||
|
new_width, new_height = width * scale_factor, height * scale_factor
|
||||||
|
return alpha_img.resize((new_width, new_height), Image.BILINEAR)
|
||||||
|
|
||||||
|
def queue_prompt(self, prompt: Dict, comfyui_server_address: str, client_id: str) -> Dict:
|
||||||
|
"""将提示词发送到 ComfyUI 服务器的队列中"""
|
||||||
|
p = {"prompt": prompt, "client_id": client_id}
|
||||||
|
data = json.dumps(p).encode('utf-8')
|
||||||
|
req = urllib.request.Request(f"http://{comfyui_server_address}/prompt", data=data)
|
||||||
|
response = json.loads(urllib.request.urlopen(req).read())
|
||||||
|
return response
|
||||||
|
|
||||||
|
def get_images(self, ws: websocket.WebSocket, workflow: Dict, comfyui_server_address: str, client_id: str) -> Dict:
|
||||||
|
"""从 ComfyUI 获取生成的图像"""
|
||||||
|
try:
|
||||||
|
prompt_response = self.queue_prompt(workflow, comfyui_server_address, client_id)
|
||||||
|
prompt_id = prompt_response['prompt_id']
|
||||||
|
except KeyError:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
output_images = {}
|
||||||
|
current_node = ""
|
||||||
|
while True:
|
||||||
|
out = ws.recv()
|
||||||
|
if isinstance(out, str):
|
||||||
|
message = json.loads(out)
|
||||||
|
if message['type'] == 'executing':
|
||||||
|
data = message['data']
|
||||||
|
if data.get('prompt_id') == prompt_id:
|
||||||
|
if data['node'] is None:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
current_node = data['node']
|
||||||
|
else:
|
||||||
|
if current_node == '16': # 放大图像节点
|
||||||
|
images_output = output_images.get(current_node, [])
|
||||||
|
images_output.append(out[8:])
|
||||||
|
output_images[current_node] = images_output
|
||||||
|
|
||||||
|
return output_images
|
||||||
|
|
||||||
|
def upscale_image_sync(self, image_url: str, config: Optional[Dict] = None) -> Generator[str, None, None]:
|
||||||
|
"""放大图像,保存到本地并返回图片URL"""
|
||||||
|
cfg = default_config.copy()
|
||||||
|
if config:
|
||||||
|
cfg.update(config)
|
||||||
|
ws = websocket.WebSocket()
|
||||||
|
client_id = str(uuid.uuid4())
|
||||||
|
temp_file = None
|
||||||
|
alpha_temp_file = None
|
||||||
|
has_transparency = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 准备图像用于放大
|
||||||
|
image_path, has_transparency, alpha_path = self.prepare_image_for_upscale(image_url)
|
||||||
|
if image_path != image_url:
|
||||||
|
temp_file = image_path
|
||||||
|
image_url = image_path
|
||||||
|
if has_transparency:
|
||||||
|
alpha_temp_file = alpha_path
|
||||||
|
|
||||||
|
ws.connect(f"ws://{cfg['comfyui_server_address']}/ws?clientId={client_id}")
|
||||||
|
workflow = json.loads(workflow_template)
|
||||||
|
workflow["13"]["inputs"]["model_name"] = cfg['upscale_model_name']
|
||||||
|
workflow["15"]["inputs"]["url_or_path"] = image_url
|
||||||
|
images_dict = self.get_images(ws, workflow, cfg['comfyui_server_address'], client_id)
|
||||||
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
|
|
||||||
|
for node_id, image_list in images_dict.items():
|
||||||
|
for image_data in image_list:
|
||||||
|
if has_transparency:
|
||||||
|
# 处理带透明通道的图像
|
||||||
|
# 保存放大后的RGB图像
|
||||||
|
upscaled_rgb_path = os.path.join(save_dir, f"upscaled_rgb_{uuid.uuid4().hex[:10]}.png")
|
||||||
|
with open(upscaled_rgb_path, "wb") as f:
|
||||||
|
f.write(image_data)
|
||||||
|
|
||||||
|
# 打开放大后的RGB图像
|
||||||
|
upscaled_rgb = Image.open(upscaled_rgb_path)
|
||||||
|
|
||||||
|
# 放大Alpha通道
|
||||||
|
upscaled_alpha = self.upscale_alpha_channel(alpha_temp_file,
|
||||||
|
scale_factor=upscaled_rgb.width//Image.open(temp_file).width)
|
||||||
|
|
||||||
|
# 确保尺寸匹配
|
||||||
|
if upscaled_rgb.size != upscaled_alpha.size:
|
||||||
|
upscaled_alpha = upscaled_alpha.resize(upscaled_rgb.size, Image.BILINEAR)
|
||||||
|
|
||||||
|
# 合并通道
|
||||||
|
upscaled_rgba = upscaled_rgb.copy()
|
||||||
|
upscaled_rgba.putalpha(upscaled_alpha)
|
||||||
|
|
||||||
|
# 保存最终的RGBA图像
|
||||||
|
png_filename = f"upscaled_{uuid.uuid4().hex[:10]}.png"
|
||||||
|
png_file_path = os.path.join(save_dir, png_filename)
|
||||||
|
upscaled_rgba.save(png_file_path, "PNG")
|
||||||
|
|
||||||
|
# 删除临时RGB文件
|
||||||
|
os.remove(upscaled_rgb_path)
|
||||||
|
|
||||||
|
# 返回PNG URL
|
||||||
|
image_url = f"{download_url}/{png_filename}"
|
||||||
|
else:
|
||||||
|
# 处理没有透明通道的图像
|
||||||
|
# 保存为JPG以减小文件大小
|
||||||
|
png_filename = f"upscaled_{uuid.uuid4().hex[:10]}.png"
|
||||||
|
png_file_path = os.path.join(save_dir, png_filename)
|
||||||
|
with open(png_file_path, "wb") as f:
|
||||||
|
f.write(image_data)
|
||||||
|
|
||||||
|
# 打开图像并转换为JPG
|
||||||
|
img = Image.open(png_file_path)
|
||||||
|
jpg_filename = png_filename.replace('.png', '.jpg')
|
||||||
|
jpg_file_path = os.path.join(save_dir, jpg_filename)
|
||||||
|
img = img.convert('RGB')
|
||||||
|
img.save(jpg_file_path, 'JPEG', quality=95)
|
||||||
|
|
||||||
|
# 删除PNG临时文件
|
||||||
|
os.remove(png_file_path)
|
||||||
|
|
||||||
|
# 返回JPG URL
|
||||||
|
image_url = f"{download_url}/{jpg_filename}"
|
||||||
|
|
||||||
|
yield image_url
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
if ws:
|
||||||
|
ws.close()
|
||||||
|
# 清理临时文件
|
||||||
|
if temp_file and os.path.exists(temp_file):
|
||||||
|
os.unlink(temp_file)
|
||||||
|
if alpha_temp_file and os.path.exists(alpha_temp_file):
|
||||||
|
os.unlink(alpha_temp_file)
|
||||||
|
|
||||||
|
async def upscale_image(self, image_url: str, config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]:
|
||||||
|
"""异步放大图像,返回图片URL"""
|
||||||
|
try:
|
||||||
|
# 在这种情况下,我们需要手动运行同步生成器并收集结果
|
||||||
|
urls = []
|
||||||
|
|
||||||
|
# 在执行器中运行同步代码
|
||||||
|
def run_sync():
|
||||||
|
return list(self.upscale_image_sync(image_url, config))
|
||||||
|
|
||||||
|
# 获取所有URL
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
urls = await loop.run_in_executor(None, run_sync)
|
||||||
|
|
||||||
|
# 逐个返回结果
|
||||||
|
for url in urls:
|
||||||
|
yield {
|
||||||
|
"status": "success",
|
||||||
|
"image_url": url,
|
||||||
|
"message": "图片已保存"
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
yield {
|
||||||
|
"status": "error",
|
||||||
|
"message": f"图像放大失败: {str(e)}"
|
||||||
|
}
|
||||||
|
|
||||||
|
async def process_batch(self, image_urls: List[str], config: Optional[Dict] = None):
|
||||||
|
"""批量处理多个图像URL,返回图片URL"""
|
||||||
|
total = len(image_urls)
|
||||||
|
success_count = 0
|
||||||
|
error_count = 0
|
||||||
|
for i, image_url in enumerate(image_urls, 1):
|
||||||
|
try:
|
||||||
|
# 获取图片透明度信息
|
||||||
|
try:
|
||||||
|
_, has_transparency = self.check_image_transparency(image_url)
|
||||||
|
transparency_info = "PNG带透明通道" if has_transparency else "无透明通道"
|
||||||
|
except:
|
||||||
|
transparency_info = "未检测"
|
||||||
|
|
||||||
|
async for result in self.upscale_image(image_url, config):
|
||||||
|
if result["status"] == "success":
|
||||||
|
success_count += 1
|
||||||
|
yield {
|
||||||
|
"index": i,
|
||||||
|
"total": total,
|
||||||
|
"original_image_url": image_url,
|
||||||
|
"status": "success",
|
||||||
|
"image_url": result["image_url"],
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count,
|
||||||
|
"transparency": transparency_info,
|
||||||
|
"message": result["message"]
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
error_count += 1
|
||||||
|
yield {
|
||||||
|
"index": i,
|
||||||
|
"total": total,
|
||||||
|
"original_image_url": image_url,
|
||||||
|
"status": "error",
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count,
|
||||||
|
"message": result["message"]
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
error_count += 1
|
||||||
|
yield {
|
||||||
|
"index": i,
|
||||||
|
"total": total,
|
||||||
|
"original_image_url": image_url,
|
||||||
|
"status": "error",
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count,
|
||||||
|
"message": f"处理图像时出错: {str(e)}"
|
||||||
|
}
|
||||||
|
|
||||||
39
apps/jupscale/settings.py
Normal file
39
apps/jupscale/settings.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
from typing import Optional
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
# Japi Server 配置
|
||||||
|
host: str = "0.0.0.0"
|
||||||
|
port: int = 8109
|
||||||
|
debug: bool = False
|
||||||
|
|
||||||
|
# API路由配置
|
||||||
|
router_prefix: str = "/jupscale"
|
||||||
|
upscale_route: str = "/upscale" # 放大图片的路由
|
||||||
|
batch_route: str = "/batch" # 批量放大图片的路由
|
||||||
|
api_name: str = "jupscale" # 默认API名称
|
||||||
|
save_dir: str = "../jfile/files"
|
||||||
|
# Japi 静态资源下载URL
|
||||||
|
download_url: str = "http://api.jingrow.com:9080/files"
|
||||||
|
|
||||||
|
# 中转图床服务上传URL
|
||||||
|
upload_url: str = "http://173.255.202.68/imgurl/upload"
|
||||||
|
|
||||||
|
# Jingrow Jcloud API 配置
|
||||||
|
jingrow_api_url: str = "https://cloud.jingrow.com"
|
||||||
|
jingrow_api_key: Optional[str] = None
|
||||||
|
jingrow_api_secret: Optional[str] = None
|
||||||
|
|
||||||
|
# Stable Diffusion配置
|
||||||
|
comfyui_server_address: str = "192.168.2.200:8188"
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
env_file = ".env"
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def get_settings() -> Settings:
|
||||||
|
return Settings()
|
||||||
|
|
||||||
|
# 创建全局配置实例
|
||||||
|
settings = get_settings()
|
||||||
146
apps/jupscale/utils.py
Normal file
146
apps/jupscale/utils.py
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
import aiohttp
|
||||||
|
from functools import wraps
|
||||||
|
from fastapi import HTTPException
|
||||||
|
import os
|
||||||
|
from typing import Callable, Any, Dict, Optional, Tuple
|
||||||
|
from settings import settings
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
import json
|
||||||
|
|
||||||
|
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||||
|
"""验证API密钥和团队余额"""
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{settings.jingrow_api_url}/api/method/jcloud.api.account.verify_api_credentials_and_balance",
|
||||||
|
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||||
|
json={"api_key": api_key, "api_secret": api_secret, "api_name": api_name}
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise HTTPException(status_code=500, detail="验证服务暂时不可用")
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
|
||||||
|
if not result.get("success"):
|
||||||
|
raise HTTPException(status_code=401, detail=result.get("message", "验证失败"))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||||
|
|
||||||
|
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
||||||
|
"""从Jingrow平台扣除API使用费"""
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{settings.jingrow_api_url}/api/method/jcloud.api.account.deduct_api_usage_fee",
|
||||||
|
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||||
|
json={
|
||||||
|
"api_key": api_key,
|
||||||
|
"api_secret": api_secret,
|
||||||
|
"api_name": api_name,
|
||||||
|
"usage_count": usage_count
|
||||||
|
}
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
||||||
|
|
||||||
|
def get_token_from_request(request) -> str:
|
||||||
|
"""从请求中获取访问令牌"""
|
||||||
|
if not request:
|
||||||
|
raise HTTPException(status_code=400, detail="无法获取请求信息")
|
||||||
|
|
||||||
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
if not auth_header or not auth_header.startswith("token "):
|
||||||
|
raise HTTPException(status_code=401, detail="无效的Authorization头格式")
|
||||||
|
|
||||||
|
token = auth_header[6:]
|
||||||
|
if ":" not in token:
|
||||||
|
raise HTTPException(status_code=401, detail="无效的令牌格式")
|
||||||
|
|
||||||
|
return token
|
||||||
|
|
||||||
|
def jingrow_api_verify_and_billing(api_name: str):
|
||||||
|
"""Jingrow API 验证装饰器(带余额检查和扣费)"""
|
||||||
|
def decorator(func: Callable) -> Callable:
|
||||||
|
@wraps(func)
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
try:
|
||||||
|
request = kwargs.get('request')
|
||||||
|
if not request:
|
||||||
|
raise HTTPException(status_code=400, detail="无法获取请求信息")
|
||||||
|
|
||||||
|
token = get_token_from_request(request)
|
||||||
|
api_key, api_secret = token.split(":", 1)
|
||||||
|
|
||||||
|
verify_result = await verify_api_credentials_and_balance(api_key, api_secret, api_name)
|
||||||
|
if not verify_result.get("success"):
|
||||||
|
raise HTTPException(status_code=401, detail=verify_result.get("message", "验证失败"))
|
||||||
|
|
||||||
|
result = await func(*args, **kwargs)
|
||||||
|
|
||||||
|
usage_count = 1
|
||||||
|
try:
|
||||||
|
body_data = await request.json()
|
||||||
|
if isinstance(body_data, dict):
|
||||||
|
for key in ["items", "urls", "images", "files"]:
|
||||||
|
if key in body_data and isinstance(body_data[key], list):
|
||||||
|
usage_count = len(body_data[key])
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if isinstance(result, StreamingResponse):
|
||||||
|
original_generator = result.body_iterator
|
||||||
|
success_count = 0
|
||||||
|
|
||||||
|
async def wrapped_generator():
|
||||||
|
nonlocal success_count
|
||||||
|
async for chunk in original_generator:
|
||||||
|
try:
|
||||||
|
data = json.loads(chunk)
|
||||||
|
if isinstance(data, dict) and data.get("status") == "success":
|
||||||
|
success_count += 1
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
if success_count > 0:
|
||||||
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
wrapped_generator(),
|
||||||
|
media_type=result.media_type,
|
||||||
|
headers=result.headers
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(result, dict) and result.get("success") is True:
|
||||||
|
actual_usage_count = result.get("successful_count", usage_count)
|
||||||
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
||||||
|
return result
|
||||||
|
|
||||||
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"API验证过程发生异常: {str(e)}")
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
||||||
0
apps/jvector/__init__.py
Normal file
0
apps/jvector/__init__.py
Normal file
53
apps/jvector/api.py
Normal file
53
apps/jvector/api.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
from fastapi import APIRouter, UploadFile, File, HTTPException, Request
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from service import JvectorService
|
||||||
|
from utils import jingrow_api_verify_and_billing
|
||||||
|
from settings import settings
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
router = APIRouter(prefix=settings.router_prefix)
|
||||||
|
service = JvectorService()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(settings.file_route)
|
||||||
|
@jingrow_api_verify_and_billing(api_name=settings.api_name)
|
||||||
|
async def vectorize_image_file(file: UploadFile = File(...), request: Request = None):
|
||||||
|
"""
|
||||||
|
将上传的文件转换为矢量图
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file: 上传的图片文件
|
||||||
|
request: FastAPI 请求对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理后的矢量图内容
|
||||||
|
"""
|
||||||
|
content = await file.read()
|
||||||
|
result = await service.vectorize_from_file(content)
|
||||||
|
return result
|
||||||
|
|
||||||
|
@router.post(settings.batch_route)
|
||||||
|
@jingrow_api_verify_and_billing(api_name=settings.api_name)
|
||||||
|
async def vectorize_image_batch(data: dict, request: Request):
|
||||||
|
"""
|
||||||
|
批量处理多个URL图片转换为矢量图
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 包含图片URL列表的字典
|
||||||
|
request: FastAPI 请求对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
流式响应,包含每个图片的处理结果
|
||||||
|
"""
|
||||||
|
if "urls" not in data:
|
||||||
|
raise HTTPException(status_code=400, detail="缺少urls参数")
|
||||||
|
|
||||||
|
async def process_and_stream():
|
||||||
|
async for result in service.process_batch(data["urls"]):
|
||||||
|
yield json.dumps(result) + "\n"
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
process_and_stream(),
|
||||||
|
media_type="application/x-ndjson"
|
||||||
|
)
|
||||||
21
apps/jvector/app.py
Normal file
21
apps/jvector/app.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
from fastapi import FastAPI
|
||||||
|
from settings import settings
|
||||||
|
from api import router
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="Jvector",
|
||||||
|
description="Jvector转矢量图API",
|
||||||
|
version="1.0.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 注册路由
|
||||||
|
app.include_router(router)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
uvicorn.run(
|
||||||
|
"app:app",
|
||||||
|
host=settings.host,
|
||||||
|
port=settings.port,
|
||||||
|
reload=settings.debug
|
||||||
|
)
|
||||||
228
apps/jvector/service.py
Normal file
228
apps/jvector/service.py
Normal file
@ -0,0 +1,228 @@
|
|||||||
|
import io
|
||||||
|
import os
|
||||||
|
import requests
|
||||||
|
import traceback
|
||||||
|
import tempfile
|
||||||
|
import base64
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
from pathlib import Path
|
||||||
|
from PIL import Image
|
||||||
|
import asyncio
|
||||||
|
from settings import settings
|
||||||
|
|
||||||
|
class JvectorService:
|
||||||
|
def __init__(self):
|
||||||
|
"""初始化矢量图转换服务"""
|
||||||
|
# 获取配置变量
|
||||||
|
self.upload_url = settings.upload_url
|
||||||
|
self.vector_api_id = settings.vector_api_id
|
||||||
|
self.vector_api_secret = settings.vector_api_secret
|
||||||
|
self.vector_mode = settings.vector_mode
|
||||||
|
|
||||||
|
def _get_config(self, key):
|
||||||
|
"""获取配置值,从环境变量读取"""
|
||||||
|
if key == "upload_url":
|
||||||
|
return settings.upload_url
|
||||||
|
|
||||||
|
# 其他配置项的处理方式
|
||||||
|
config_map = {}
|
||||||
|
return config_map.get(key, "")
|
||||||
|
|
||||||
|
def upload_image_to_intermediate_server(self, image_url):
|
||||||
|
"""上传图片到中转服务器的函数"""
|
||||||
|
try:
|
||||||
|
response = requests.get(image_url, verify=False)
|
||||||
|
response.raise_for_status()
|
||||||
|
image_data = response.content
|
||||||
|
|
||||||
|
parsed_url = urlparse(image_url)
|
||||||
|
file_name = Path(parsed_url.path).name
|
||||||
|
file_ext = Path(file_name).suffix
|
||||||
|
|
||||||
|
# 如果图片是webp格式,转换为png格式
|
||||||
|
if file_ext.lower() == '.webp':
|
||||||
|
image = Image.open(io.BytesIO(image_data))
|
||||||
|
png_buffer = io.BytesIO()
|
||||||
|
image.save(png_buffer, format='PNG')
|
||||||
|
image_data = png_buffer.getvalue()
|
||||||
|
file_name = file_name.replace('.webp', '.png')
|
||||||
|
|
||||||
|
files = {"file": (file_name, image_data)}
|
||||||
|
|
||||||
|
upload_response = requests.post(self.upload_url, files=files, verify=False)
|
||||||
|
|
||||||
|
if upload_response.status_code == 200:
|
||||||
|
return upload_response.json()["file_url"]
|
||||||
|
else:
|
||||||
|
error_msg = f"上传失败. 状态码: {upload_response.status_code}, {upload_response.text}"
|
||||||
|
print(error_msg)
|
||||||
|
raise Exception(error_msg)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"上传图像到中间服务器失败: {str(e)}, URL: {image_url}"
|
||||||
|
print(error_msg)
|
||||||
|
traceback.print_exc()
|
||||||
|
raise Exception(error_msg)
|
||||||
|
|
||||||
|
def convert_image_to_vector(self, image_url):
|
||||||
|
"""将图片转换为矢量图的函数"""
|
||||||
|
try:
|
||||||
|
url = "https://vectorizer.ai/api/v1/vectorize"
|
||||||
|
data = {
|
||||||
|
'image.url': image_url,
|
||||||
|
'mode': self.vector_mode
|
||||||
|
}
|
||||||
|
auth = (self.vector_api_id, self.vector_api_secret)
|
||||||
|
response = requests.post(url, data=data, auth=auth)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.content
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"转换图像为矢量图失败: {str(e)}, URL: {image_url}"
|
||||||
|
print(error_msg)
|
||||||
|
traceback.print_exc()
|
||||||
|
raise Exception(error_msg)
|
||||||
|
|
||||||
|
def svg_to_base64(self, svg_content):
|
||||||
|
"""将SVG内容转换为base64字符串"""
|
||||||
|
return base64.b64encode(svg_content).decode('utf-8')
|
||||||
|
|
||||||
|
async def vectorize_image(self, image_url):
|
||||||
|
"""
|
||||||
|
将图片转换为矢量图
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_url: 输入图像的URL
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理后的矢量图内容
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 转换为矢量图
|
||||||
|
vector_content = self.convert_image_to_vector(image_url)
|
||||||
|
|
||||||
|
# 转换为base64
|
||||||
|
svg_content = self.svg_to_base64(vector_content)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"svg_content": svg_content
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"矢量图转换失败: {str(e)}")
|
||||||
|
|
||||||
|
async def vectorize_from_file(self, file_content):
|
||||||
|
"""
|
||||||
|
从上传的文件内容创建矢量图
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_content: 上传的文件内容
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理后的矢量图内容
|
||||||
|
"""
|
||||||
|
temp_file = None
|
||||||
|
try:
|
||||||
|
# 创建临时文件
|
||||||
|
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
|
||||||
|
with open(temp_file.name, 'wb') as f:
|
||||||
|
f.write(file_content)
|
||||||
|
|
||||||
|
# 上传到中转服务器
|
||||||
|
with open(temp_file.name, 'rb') as f:
|
||||||
|
files = {"file": ("image.png", f)}
|
||||||
|
upload_response = requests.post(self.upload_url, files=files, verify=False)
|
||||||
|
|
||||||
|
if upload_response.status_code == 200:
|
||||||
|
intermediate_url = upload_response.json()["file_url"]
|
||||||
|
else:
|
||||||
|
raise Exception(f"上传失败. 状态码: {upload_response.status_code}")
|
||||||
|
|
||||||
|
# 转换为矢量图
|
||||||
|
vector_content = self.convert_image_to_vector(intermediate_url)
|
||||||
|
|
||||||
|
# 转换为base64
|
||||||
|
svg_content = self.svg_to_base64(vector_content)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"svg_content": svg_content
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"处理文件失败: {str(e)}")
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# 清理临时文件
|
||||||
|
if temp_file and os.path.exists(temp_file.name):
|
||||||
|
try:
|
||||||
|
os.unlink(temp_file.name)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def process_batch(self, urls):
|
||||||
|
"""
|
||||||
|
批量处理多个URL图像,流式返回结果
|
||||||
|
|
||||||
|
Args:
|
||||||
|
urls: 图片URL列表
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
每个图片的处理结果
|
||||||
|
"""
|
||||||
|
total = len(urls)
|
||||||
|
success_count = 0
|
||||||
|
error_count = 0
|
||||||
|
|
||||||
|
for i, url in enumerate(urls, 1):
|
||||||
|
try:
|
||||||
|
url_str = str(url)
|
||||||
|
result = await self.vectorize_image(url_str)
|
||||||
|
success_count += 1
|
||||||
|
|
||||||
|
# 确保返回正确的数据格式
|
||||||
|
yield {
|
||||||
|
"index": i,
|
||||||
|
"total": total,
|
||||||
|
"original_url": url_str,
|
||||||
|
"status": "success",
|
||||||
|
"svg_content": result["svg_content"],
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count,
|
||||||
|
"message": "处理成功"
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_count += 1
|
||||||
|
yield {
|
||||||
|
"index": i,
|
||||||
|
"total": total,
|
||||||
|
"original_url": str(url),
|
||||||
|
"status": "error",
|
||||||
|
"error": str(e),
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count,
|
||||||
|
"message": f"处理失败: {str(e)}"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 让出控制权,避免阻塞
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
def is_valid_url(self, url):
|
||||||
|
"""验证URL是否有效"""
|
||||||
|
try:
|
||||||
|
result = urlparse(url)
|
||||||
|
return all([result.scheme, result.netloc])
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def download_file(self, url, filename):
|
||||||
|
"""下载文件到本地"""
|
||||||
|
response = requests.get(url, stream=True)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
with open(filename, 'wb') as f:
|
||||||
|
for chunk in response.iter_content(chunk_size=8192):
|
||||||
|
f.write(chunk)
|
||||||
|
|
||||||
|
return filename
|
||||||
41
apps/jvector/settings.py
Normal file
41
apps/jvector/settings.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
from typing import Optional
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
# Japi Server 配置
|
||||||
|
host: str = "0.0.0.0"
|
||||||
|
port: int = 8110
|
||||||
|
debug: bool = False
|
||||||
|
|
||||||
|
# API路由配置
|
||||||
|
router_prefix: str = "/jvector"
|
||||||
|
file_route: str = "/file" # 转矢量图的路由
|
||||||
|
batch_route: str = "/batch" # 批量转矢量图的路由
|
||||||
|
api_name: str = "jvector" # 默认API名称
|
||||||
|
save_dir: str = "../jfile/files"
|
||||||
|
# Japi 静态资源下载URL
|
||||||
|
download_url: str = "http://api.jingrow.com:9080/files"
|
||||||
|
|
||||||
|
# 中转图床服务上传URL
|
||||||
|
upload_url: str = "http://173.255.202.68/imgurl/upload"
|
||||||
|
|
||||||
|
# Jingrow Jcloud API 配置
|
||||||
|
jingrow_api_url: str = "https://cloud.jingrow.com"
|
||||||
|
jingrow_api_key: Optional[str] = None
|
||||||
|
jingrow_api_secret: Optional[str] = None
|
||||||
|
|
||||||
|
# 矢量图转换服务配置
|
||||||
|
vector_api_id: Optional[str] = None
|
||||||
|
vector_api_secret: Optional[str] = None
|
||||||
|
vector_mode: str = "production" # 'test' 或 'production' 或 'preview'
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
env_file = ".env"
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def get_settings() -> Settings:
|
||||||
|
return Settings()
|
||||||
|
|
||||||
|
# 创建全局配置实例
|
||||||
|
settings = get_settings()
|
||||||
146
apps/jvector/utils.py
Normal file
146
apps/jvector/utils.py
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
import aiohttp
|
||||||
|
from functools import wraps
|
||||||
|
from fastapi import HTTPException
|
||||||
|
import os
|
||||||
|
from typing import Callable, Any, Dict, Optional, Tuple
|
||||||
|
from settings import settings
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
import json
|
||||||
|
|
||||||
|
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||||
|
"""验证API密钥和团队余额"""
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{settings.jingrow_api_url}/api/method/jcloud.api.account.verify_api_credentials_and_balance",
|
||||||
|
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||||
|
json={"api_key": api_key, "api_secret": api_secret, "api_name": api_name}
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise HTTPException(status_code=500, detail="验证服务暂时不可用")
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
|
||||||
|
if not result.get("success"):
|
||||||
|
raise HTTPException(status_code=401, detail=result.get("message", "验证失败"))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||||
|
|
||||||
|
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
||||||
|
"""从Jingrow平台扣除API使用费"""
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{settings.jingrow_api_url}/api/method/jcloud.api.account.deduct_api_usage_fee",
|
||||||
|
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||||
|
json={
|
||||||
|
"api_key": api_key,
|
||||||
|
"api_secret": api_secret,
|
||||||
|
"api_name": api_name,
|
||||||
|
"usage_count": usage_count
|
||||||
|
}
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
||||||
|
|
||||||
|
def get_token_from_request(request) -> str:
|
||||||
|
"""从请求中获取访问令牌"""
|
||||||
|
if not request:
|
||||||
|
raise HTTPException(status_code=400, detail="无法获取请求信息")
|
||||||
|
|
||||||
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
if not auth_header or not auth_header.startswith("token "):
|
||||||
|
raise HTTPException(status_code=401, detail="无效的Authorization头格式")
|
||||||
|
|
||||||
|
token = auth_header[6:]
|
||||||
|
if ":" not in token:
|
||||||
|
raise HTTPException(status_code=401, detail="无效的令牌格式")
|
||||||
|
|
||||||
|
return token
|
||||||
|
|
||||||
|
def jingrow_api_verify_and_billing(api_name: str):
|
||||||
|
"""Jingrow API 验证装饰器(带余额检查和扣费)"""
|
||||||
|
def decorator(func: Callable) -> Callable:
|
||||||
|
@wraps(func)
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
try:
|
||||||
|
request = kwargs.get('request')
|
||||||
|
if not request:
|
||||||
|
raise HTTPException(status_code=400, detail="无法获取请求信息")
|
||||||
|
|
||||||
|
token = get_token_from_request(request)
|
||||||
|
api_key, api_secret = token.split(":", 1)
|
||||||
|
|
||||||
|
verify_result = await verify_api_credentials_and_balance(api_key, api_secret, api_name)
|
||||||
|
if not verify_result.get("success"):
|
||||||
|
raise HTTPException(status_code=401, detail=verify_result.get("message", "验证失败"))
|
||||||
|
|
||||||
|
result = await func(*args, **kwargs)
|
||||||
|
|
||||||
|
usage_count = 1
|
||||||
|
try:
|
||||||
|
body_data = await request.json()
|
||||||
|
if isinstance(body_data, dict):
|
||||||
|
for key in ["items", "urls", "images", "files"]:
|
||||||
|
if key in body_data and isinstance(body_data[key], list):
|
||||||
|
usage_count = len(body_data[key])
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if isinstance(result, StreamingResponse):
|
||||||
|
original_generator = result.body_iterator
|
||||||
|
success_count = 0
|
||||||
|
|
||||||
|
async def wrapped_generator():
|
||||||
|
nonlocal success_count
|
||||||
|
async for chunk in original_generator:
|
||||||
|
try:
|
||||||
|
data = json.loads(chunk)
|
||||||
|
if isinstance(data, dict) and data.get("status") == "success":
|
||||||
|
success_count += 1
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
if success_count > 0:
|
||||||
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
wrapped_generator(),
|
||||||
|
media_type=result.media_type,
|
||||||
|
headers=result.headers
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(result, dict) and result.get("success") is True:
|
||||||
|
actual_usage_count = result.get("successful_count", usage_count)
|
||||||
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
||||||
|
return result
|
||||||
|
|
||||||
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"API验证过程发生异常: {str(e)}")
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
||||||
0
apps/ptn_to_tshirt/__init__.py
Normal file
0
apps/ptn_to_tshirt/__init__.py
Normal file
105
apps/ptn_to_tshirt/api.py
Normal file
105
apps/ptn_to_tshirt/api.py
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
from fastapi import APIRouter, UploadFile, File, Form, HTTPException, Request
|
||||||
|
from fastapi.responses import StreamingResponse, JSONResponse
|
||||||
|
from service import PtnToTshirtService
|
||||||
|
from utils import jingrow_api_verify_and_billing
|
||||||
|
from settings import settings
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
from typing import List, Optional
|
||||||
|
import io
|
||||||
|
|
||||||
|
router = APIRouter(prefix=settings.router_prefix)
|
||||||
|
service = PtnToTshirtService()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(settings.batch_route)
|
||||||
|
@jingrow_api_verify_and_billing(api_name=settings.api_name)
|
||||||
|
async def pattern_to_tshirt_batch(data: dict, request: Request):
|
||||||
|
"""
|
||||||
|
批量处理多个URL花型图片
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 包含花型图片URL列表和配置参数的字典
|
||||||
|
request: FastAPI 请求对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
流式响应,包含每个图片的处理结果
|
||||||
|
"""
|
||||||
|
if "urls" not in data:
|
||||||
|
raise HTTPException(status_code=400, detail="缺少urls参数")
|
||||||
|
|
||||||
|
config = data.get("config", {})
|
||||||
|
|
||||||
|
# 支持传入T恤图片URL列表
|
||||||
|
if "tshirt_urls" in data and isinstance(data["tshirt_urls"], list):
|
||||||
|
if not config:
|
||||||
|
config = {}
|
||||||
|
config["tshirt_urls"] = data["tshirt_urls"]
|
||||||
|
|
||||||
|
async def process_and_stream():
|
||||||
|
total = len(data["urls"])
|
||||||
|
for index, url in enumerate(data["urls"], 1):
|
||||||
|
try:
|
||||||
|
result = await service.pattern_to_tshirt(url, config)
|
||||||
|
result.update({
|
||||||
|
"index": index,
|
||||||
|
"total": total,
|
||||||
|
"original_url": url
|
||||||
|
})
|
||||||
|
yield json.dumps(result) + "\n"
|
||||||
|
except Exception as e:
|
||||||
|
yield json.dumps({
|
||||||
|
"status": "error",
|
||||||
|
"message": str(e),
|
||||||
|
"index": index,
|
||||||
|
"total": total,
|
||||||
|
"original_url": url
|
||||||
|
}) + "\n"
|
||||||
|
|
||||||
|
try:
|
||||||
|
return StreamingResponse(
|
||||||
|
process_and_stream(),
|
||||||
|
media_type="application/x-ndjson"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
@router.post(settings.file_route)
|
||||||
|
@jingrow_api_verify_and_billing(api_name=settings.api_name)
|
||||||
|
async def pattern_to_tshirt_file(file: UploadFile = File(...), config: str = Form("{}"), tshirt_urls: str = None, request: Request = None):
|
||||||
|
"""
|
||||||
|
将上传的花型文件添加到T恤上
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file: 上传的花型图片文件
|
||||||
|
config: JSON格式的配置参数
|
||||||
|
tshirt_urls: JSON格式的T恤图片URL列表
|
||||||
|
request: FastAPI 请求对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理后的T恤图片内容
|
||||||
|
"""
|
||||||
|
content = await file.read()
|
||||||
|
|
||||||
|
# 解析配置参数
|
||||||
|
config_dict = {}
|
||||||
|
if config:
|
||||||
|
try:
|
||||||
|
config_dict = json.loads(config)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise HTTPException(status_code=400, detail="配置参数格式错误")
|
||||||
|
|
||||||
|
# 解析T恤图片URL列表
|
||||||
|
if tshirt_urls:
|
||||||
|
try:
|
||||||
|
urls_list = json.loads(tshirt_urls)
|
||||||
|
if isinstance(urls_list, list):
|
||||||
|
config_dict["tshirt_urls"] = urls_list
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise HTTPException(status_code=400, detail="T恤图片URL列表格式错误")
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await service.pattern_to_tshirt_from_file(content, config_dict)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"处理图像失败: {str(e)}")
|
||||||
21
apps/ptn_to_tshirt/app.py
Normal file
21
apps/ptn_to_tshirt/app.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
from fastapi import FastAPI
|
||||||
|
from settings import settings
|
||||||
|
from api import router
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="Pattern to Tshirt",
|
||||||
|
description="将图片中的花型添加到T恤上",
|
||||||
|
version="1.0.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 注册路由
|
||||||
|
app.include_router(router)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
uvicorn.run(
|
||||||
|
"app:app",
|
||||||
|
host=settings.host,
|
||||||
|
port=settings.port,
|
||||||
|
reload=settings.debug
|
||||||
|
)
|
||||||
523
apps/ptn_to_tshirt/service.py
Normal file
523
apps/ptn_to_tshirt/service.py
Normal file
@ -0,0 +1,523 @@
|
|||||||
|
# pattern_to_tshirt.py
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import io
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image, ImageFilter, ImageDraw, ImageChops
|
||||||
|
import uuid
|
||||||
|
import urllib.request
|
||||||
|
import urllib3
|
||||||
|
import requests
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Optional
|
||||||
|
import base64
|
||||||
|
import asyncio
|
||||||
|
import warnings
|
||||||
|
import tempfile
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
import torch
|
||||||
|
import time
|
||||||
|
import gc
|
||||||
|
|
||||||
|
# 关闭不必要的警告
|
||||||
|
warnings.filterwarnings("ignore", category=UserWarning)
|
||||||
|
warnings.filterwarnings("ignore", category=FutureWarning)
|
||||||
|
|
||||||
|
class PtnToTshirtService:
|
||||||
|
# 默认配置
|
||||||
|
DEFAULT_CONFIG = {
|
||||||
|
'background_removed_marker': "_rmbg",
|
||||||
|
'ptt_exclude_markers': ["_upscaled", "_vector", "_processed", "_tshirt", "_tryon"],
|
||||||
|
'tshirt_marker': "_tshirt",
|
||||||
|
'processed_pattern_marker': "_processed",
|
||||||
|
'tshirt_image_path': 'home/tshirt',
|
||||||
|
'tshirt_urls': [], # 新增: T恤图片URL列表
|
||||||
|
'alpha': 1, # 透明度,0表示全透明,1表示全不透明
|
||||||
|
'ptt_design_size_ratio': 0.4, # 设计图像占T恤图像的比例
|
||||||
|
'ptt_design_offset': [0.5, 0.45], # 设计图像在T恤图像中的相对位置 [x, y]
|
||||||
|
'ptt_design_rotation': 0, # 设计图案旋转角度
|
||||||
|
|
||||||
|
'enable_gradient_effect': True, # 是否启用渐变效果
|
||||||
|
'gradient_width': 512, # 渐变宽度
|
||||||
|
'gradient_direction': 'outward', # 渐变方向: 'outward', 'inward'
|
||||||
|
'gradient_type': 'linear', # 渐变类型: 'linear', 'radial'
|
||||||
|
'gradient_max_alpha': 150, # 渐变的最大透明度值,0-255
|
||||||
|
'gradient_start_alpha': 0, # 渐变起始处的透明度,0-255
|
||||||
|
'gradient_color': [255, 255, 255, 255], # 渐变颜色
|
||||||
|
'gradient_blur_intensity': 10, # 渐变模糊强度
|
||||||
|
'gradient_center': [0.5, 0.5], # 渐变中心位置,相对于设计图案的 [x, y]
|
||||||
|
'gradient_repeat_count': 1, # 渐变重复次数
|
||||||
|
|
||||||
|
'ptt_enable_texture_effect': False, # 是否启用纹理效果
|
||||||
|
'ptt_texture_type': 'lines', # 纹理类型: 'noise', 'lines'
|
||||||
|
'ptt_texture_blend_mode': 'multiply', # 纹理混合模式
|
||||||
|
|
||||||
|
'enable_save_processed_design': True, # 是否单独保存处理后的设计图案
|
||||||
|
'ptt_design_output_format': 'png', # 设计图案保存格式: 'png' 或 'tiff'
|
||||||
|
|
||||||
|
'ptt_enable_color_matching': True, # 是否启用颜色匹配
|
||||||
|
'ptt_enable_lighting_effect': False, # 是否启用光效
|
||||||
|
'ptt_enable_monochrome': False,
|
||||||
|
'ptt_light_intensity': 0.5, # 光照强度
|
||||||
|
'ptt_light_position': [0.5, 0.3], # 光源位置 [相对位置 x, y]
|
||||||
|
'ptt_light_radius_ratio': [0.4, 0.25], # 光源半径相对比例 [宽, 高]
|
||||||
|
'ptt_light_angle': 45, # 光源角度(度)
|
||||||
|
'ptt_light_blur': 91, # 光源模糊程度
|
||||||
|
'ptt_light_shape': 'ellipse', # 光源形状: 'ellipse', 'circle', 'rect'
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""初始化图案到T恤服务"""
|
||||||
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
print(f"使用设备: {self.device}")
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
print(f"GPU: {torch.cuda.get_device_name(0)}")
|
||||||
|
print(f"显存总量: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f}GB")
|
||||||
|
|
||||||
|
def overlay_image_alpha(self, img, img_overlay, pos, alpha_mask):
|
||||||
|
"""在图像上叠加另一个具有透明度的图像"""
|
||||||
|
x, y = pos
|
||||||
|
y1, y2 = max(0, y), min(img.shape[0], y + img_overlay.shape[0])
|
||||||
|
x1, x2 = max(0, x), min(img.shape[1], x + img_overlay.shape[1])
|
||||||
|
y1o, y2o = max(0, -y), min(img_overlay.shape[0], img.shape[0] - y)
|
||||||
|
x1o, x2o = max(0, -x), min(img_overlay.shape[1], img.shape[1] - x)
|
||||||
|
|
||||||
|
if y1 >= y2 or x1 >= x2 or y1o >= y2o or x1o >= x2o:
|
||||||
|
return
|
||||||
|
|
||||||
|
img_crop = img[y1:y2, x1:x2]
|
||||||
|
img_overlay_crop = img_overlay[y1o:y2o, x1o:x2o]
|
||||||
|
alpha = alpha_mask[y1o:y2o, x1o:x2o, np.newaxis]
|
||||||
|
img_crop[:] = alpha * img_overlay_crop + (1 - alpha) * img_crop
|
||||||
|
|
||||||
|
def color_transfer(self, source, target):
|
||||||
|
"""颜色匹配:将源图像的颜色转换为目标图像的颜色风格"""
|
||||||
|
source = cv2.cvtColor(source, cv2.COLOR_BGR2LAB)
|
||||||
|
target = cv2.cvtColor(target, cv2.COLOR_BGR2LAB)
|
||||||
|
|
||||||
|
src_mean, src_std = cv2.meanStdDev(source)
|
||||||
|
tgt_mean, tgt_std = cv2.meanStdDev(target)
|
||||||
|
|
||||||
|
src_mean = src_mean.reshape(1, 1, 3)
|
||||||
|
src_std = src_std.reshape(1, 1, 3)
|
||||||
|
tgt_mean = tgt_mean.reshape(1, 1, 3)
|
||||||
|
tgt_std = tgt_std.reshape(1, 1, 3)
|
||||||
|
|
||||||
|
result = (source - src_mean) * (tgt_std / src_std) + tgt_mean
|
||||||
|
result = np.clip(result, 0, 255)
|
||||||
|
result = result.astype(np.uint8)
|
||||||
|
|
||||||
|
return cv2.cvtColor(result, cv2.COLOR_LAB2BGR)
|
||||||
|
|
||||||
|
def apply_lighting_effect(self, image, light_intensity=0.5, light_position=[0.5, 0.3],
|
||||||
|
light_radius_ratio=[0.4, 0.25], light_angle=45, light_blur=91, light_shape='ellipse'):
|
||||||
|
"""应用光照效果到图像"""
|
||||||
|
height, width = image.shape[:2]
|
||||||
|
light_position = (int(light_position[0] * width), int(light_position[1] * height))
|
||||||
|
light_radius = (int(light_radius_ratio[0] * width), int(light_radius_ratio[1] * height))
|
||||||
|
mask = np.zeros((height, width), dtype=np.uint8)
|
||||||
|
|
||||||
|
if light_shape == 'ellipse':
|
||||||
|
cv2.ellipse(mask, light_position, light_radius, light_angle, 0, 360, 255, -1)
|
||||||
|
elif light_shape == 'circle':
|
||||||
|
cv2.circle(mask, light_position, min(light_radius), 255, -1)
|
||||||
|
elif light_shape == 'rect':
|
||||||
|
rect_top_left = (light_position[0] - light_radius[0] // 2, light_position[1] - light_radius[1] // 2)
|
||||||
|
rect_bottom_right = (light_position[0] + light_radius[0] // 2, light_position[1] + light_radius[1] // 2)
|
||||||
|
cv2.rectangle(mask, rect_top_left, rect_bottom_right, 255, -1)
|
||||||
|
|
||||||
|
mask = cv2.GaussianBlur(mask, (light_blur, light_blur), 0)
|
||||||
|
mask = mask.astype(np.float32) / 255
|
||||||
|
result = image.astype(np.float32)
|
||||||
|
for i in range(3):
|
||||||
|
result[:, :, i] = result[:, :, i] * (1 - light_intensity + mask * light_intensity)
|
||||||
|
return result.astype(np.uint8)
|
||||||
|
|
||||||
|
def apply_monochrome(self, image):
|
||||||
|
"""将图像转换为单色"""
|
||||||
|
gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||||
|
monochrome_image = cv2.cvtColor(gray_image, cv2.COLOR_GRAY2BGR)
|
||||||
|
return monochrome_image
|
||||||
|
|
||||||
|
def enhance_design(self, design_image, tshirt_image, config):
|
||||||
|
"""增强设计图像,应用各种效果"""
|
||||||
|
if config.get('ptt_enable_color_matching', self.DEFAULT_CONFIG['ptt_enable_color_matching']):
|
||||||
|
design_image = self.color_transfer(design_image, tshirt_image)
|
||||||
|
|
||||||
|
if config.get('ptt_enable_lighting_effect', self.DEFAULT_CONFIG['ptt_enable_lighting_effect']):
|
||||||
|
design_image = self.apply_lighting_effect(
|
||||||
|
design_image,
|
||||||
|
config.get('ptt_light_intensity', self.DEFAULT_CONFIG['ptt_light_intensity']),
|
||||||
|
config.get('ptt_light_position', self.DEFAULT_CONFIG['ptt_light_position']),
|
||||||
|
config.get('ptt_light_radius_ratio', self.DEFAULT_CONFIG['ptt_light_radius_ratio']),
|
||||||
|
config.get('ptt_light_angle', self.DEFAULT_CONFIG['ptt_light_angle']),
|
||||||
|
config.get('ptt_light_blur', self.DEFAULT_CONFIG['ptt_light_blur']),
|
||||||
|
config.get('ptt_light_shape', self.DEFAULT_CONFIG['ptt_light_shape'])
|
||||||
|
)
|
||||||
|
|
||||||
|
if config.get('ptt_enable_monochrome', self.DEFAULT_CONFIG['ptt_enable_monochrome']):
|
||||||
|
design_image = self.apply_monochrome(design_image)
|
||||||
|
|
||||||
|
return design_image
|
||||||
|
|
||||||
|
def add_edge_gradient(self, image, gradient_width, gradient_direction, gradient_type,
|
||||||
|
gradient_max_alpha, gradient_start_alpha, gradient_color,
|
||||||
|
gradient_blur_intensity, gradient_center):
|
||||||
|
"""添加边缘渐变效果"""
|
||||||
|
alpha = image.getchannel('A')
|
||||||
|
width, height = alpha.size
|
||||||
|
mask = Image.new('L', (width, height), 0)
|
||||||
|
draw = ImageDraw.Draw(mask)
|
||||||
|
|
||||||
|
# 确保渐变宽度不超过图像尺寸的一半
|
||||||
|
gradient_width = min(gradient_width, width // 2, height // 2)
|
||||||
|
|
||||||
|
if gradient_type == 'linear':
|
||||||
|
if gradient_direction == 'outward':
|
||||||
|
for i in range(gradient_width):
|
||||||
|
if i >= width // 2 or i >= height // 2:
|
||||||
|
break # 避免无效的矩形坐标
|
||||||
|
fill_value = int(gradient_start_alpha + (gradient_max_alpha - gradient_start_alpha) * (i / gradient_width))
|
||||||
|
draw.rectangle([i, i, width - i - 1, height - i - 1], fill=fill_value)
|
||||||
|
elif gradient_direction == 'inward':
|
||||||
|
for i in range(gradient_width):
|
||||||
|
if i >= width // 2 or i >= height // 2:
|
||||||
|
break # 避免无效的矩形坐标
|
||||||
|
fill_value = int(gradient_start_alpha + (gradient_max_alpha - gradient_start_alpha) * ((gradient_width - i) / gradient_width))
|
||||||
|
draw.rectangle([i, i, width - i - 1, height - i - 1], fill=fill_value)
|
||||||
|
elif gradient_type == 'radial':
|
||||||
|
center_x = int(width * gradient_center[0])
|
||||||
|
center_y = int(height * gradient_center[1])
|
||||||
|
max_radius = min(center_x, center_y, width - center_x, height - center_y)
|
||||||
|
for i in range(gradient_width):
|
||||||
|
radius = max_radius * (i / gradient_width)
|
||||||
|
if radius <= 0:
|
||||||
|
continue
|
||||||
|
fill_value = int(gradient_start_alpha + (gradient_max_alpha - gradient_start_alpha) * (i / gradient_width))
|
||||||
|
draw.ellipse([center_x - radius, center_y - radius, center_x + radius, center_y + radius], fill=fill_value)
|
||||||
|
|
||||||
|
if gradient_blur_intensity < 1:
|
||||||
|
gradient_blur_intensity = 1
|
||||||
|
mask = mask.filter(ImageFilter.GaussianBlur(gradient_blur_intensity))
|
||||||
|
alpha = ImageChops.multiply(alpha, mask)
|
||||||
|
image.putalpha(alpha)
|
||||||
|
|
||||||
|
if gradient_color != [255, 255, 255, 255]:
|
||||||
|
colored_mask = Image.new('RGBA', image.size, tuple(gradient_color))
|
||||||
|
colored_mask.putalpha(mask)
|
||||||
|
image = Image.alpha_composite(image, colored_mask)
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
def add_gradient_repeat(self, image, gradient_repeat_count, *args, **kwargs):
|
||||||
|
"""重复应用渐变效果"""
|
||||||
|
for _ in range(max(gradient_repeat_count, 1)): # 确保至少执行一次
|
||||||
|
image = self.add_edge_gradient(image, *args, **kwargs)
|
||||||
|
return image
|
||||||
|
|
||||||
|
def generate_noise_texture(self, size, intensity=64):
|
||||||
|
"""生成噪点纹理"""
|
||||||
|
noise = np.random.randint(0, intensity, (size, size, 4), dtype=np.uint8)
|
||||||
|
noise[..., 3] = 255 # 设置 alpha 通道为不透明
|
||||||
|
return Image.fromarray(noise)
|
||||||
|
|
||||||
|
def generate_line_texture(self, size, line_width=4, spacing=20, color=(0, 0, 0, 255)):
|
||||||
|
"""生成线条纹理"""
|
||||||
|
texture = Image.new('RGBA', (size, size), (255, 255, 255, 0))
|
||||||
|
draw = ImageDraw.Draw(texture)
|
||||||
|
for y in range(0, size, spacing):
|
||||||
|
draw.line([(0, y), (size, y)], fill=color, width=line_width)
|
||||||
|
for x in range(0, size, spacing):
|
||||||
|
draw.line([(x, 0), (x, size)], fill=color, width=line_width)
|
||||||
|
return texture
|
||||||
|
|
||||||
|
def add_texture(self, image, texture_type, texture_blend_mode):
|
||||||
|
"""添加纹理效果到图像"""
|
||||||
|
if texture_type == 'noise':
|
||||||
|
texture = self.generate_noise_texture(image.size[0])
|
||||||
|
elif texture_type == 'lines':
|
||||||
|
texture = self.generate_line_texture(image.size[0])
|
||||||
|
else:
|
||||||
|
return image
|
||||||
|
|
||||||
|
if texture_blend_mode == 'multiply':
|
||||||
|
return ImageChops.multiply(image, texture)
|
||||||
|
elif texture_blend_mode == 'overlay':
|
||||||
|
return ImageChops.overlay(image, texture)
|
||||||
|
else:
|
||||||
|
return image
|
||||||
|
|
||||||
|
def rotate_image_with_transparency(self, image, angle):
|
||||||
|
"""旋转带有透明度的图像"""
|
||||||
|
rotated_image = image.rotate(angle, expand=True)
|
||||||
|
return rotated_image
|
||||||
|
|
||||||
|
def save_processed_design_image(self, design_image, output_format='png'):
|
||||||
|
"""保存处理后的设计图像"""
|
||||||
|
try:
|
||||||
|
img_bytes = io.BytesIO()
|
||||||
|
# 确保使用包含透明背景的BGRA格式
|
||||||
|
design_image_pil = Image.fromarray(cv2.cvtColor(design_image, cv2.COLOR_BGRA2RGBA)).convert('RGBA')
|
||||||
|
|
||||||
|
if output_format == 'tiff':
|
||||||
|
design_image_pil.save(img_bytes, format='TIFF', save_all=True, compression='tiff_deflate')
|
||||||
|
else:
|
||||||
|
design_image_pil.save(img_bytes, format='PNG')
|
||||||
|
|
||||||
|
img_bytes.seek(0)
|
||||||
|
return img_bytes
|
||||||
|
except Exception as e:
|
||||||
|
print(f"保存处理后的设计图像时发生错误: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def generate_tshirt_image(self, design_image, tshirt_image, config):
|
||||||
|
"""将花型图案合成到T恤图像上"""
|
||||||
|
# 合并默认配置和用户配置
|
||||||
|
config = {**self.DEFAULT_CONFIG, **config}
|
||||||
|
|
||||||
|
# 将设计图像从RGBA转换为BGRA(如果需要)
|
||||||
|
if isinstance(design_image, np.ndarray) and design_image.shape[2] == 4:
|
||||||
|
if design_image.dtype != np.uint8:
|
||||||
|
design_image = design_image.astype(np.uint8)
|
||||||
|
else:
|
||||||
|
# 如果输入是PIL图像,转换为OpenCV格式
|
||||||
|
if isinstance(design_image, Image.Image):
|
||||||
|
design_image = cv2.cvtColor(np.array(design_image), cv2.COLOR_RGBA2BGRA)
|
||||||
|
else:
|
||||||
|
raise ValueError("设计图像必须是PIL Image或带Alpha通道的NumPy数组")
|
||||||
|
|
||||||
|
# 对设计图像应用渐变效果
|
||||||
|
if config.get('enable_gradient_effect', self.DEFAULT_CONFIG['enable_gradient_effect']):
|
||||||
|
design_image_pil = Image.fromarray(cv2.cvtColor(design_image, cv2.COLOR_BGRA2RGBA)).convert("RGBA")
|
||||||
|
design_image_pil = self.add_gradient_repeat(
|
||||||
|
design_image_pil,
|
||||||
|
config.get('gradient_repeat_count', self.DEFAULT_CONFIG['gradient_repeat_count']),
|
||||||
|
config.get('gradient_width', self.DEFAULT_CONFIG['gradient_width']),
|
||||||
|
config.get('gradient_direction', self.DEFAULT_CONFIG['gradient_direction']),
|
||||||
|
config.get('gradient_type', self.DEFAULT_CONFIG['gradient_type']),
|
||||||
|
config.get('gradient_max_alpha', self.DEFAULT_CONFIG['gradient_max_alpha']),
|
||||||
|
config.get('gradient_start_alpha', self.DEFAULT_CONFIG['gradient_start_alpha']),
|
||||||
|
config.get('gradient_color', self.DEFAULT_CONFIG['gradient_color']),
|
||||||
|
config.get('gradient_blur_intensity', self.DEFAULT_CONFIG['gradient_blur_intensity']),
|
||||||
|
config.get('gradient_center', self.DEFAULT_CONFIG['gradient_center'])
|
||||||
|
)
|
||||||
|
design_image = cv2.cvtColor(np.array(design_image_pil), cv2.COLOR_RGBA2BGRA)
|
||||||
|
|
||||||
|
# 应用纹理效果到设计图案
|
||||||
|
if config.get('ptt_enable_texture_effect', self.DEFAULT_CONFIG['ptt_enable_texture_effect']):
|
||||||
|
design_image_pil = Image.fromarray(cv2.cvtColor(design_image, cv2.COLOR_BGRA2RGBA)).convert("RGBA")
|
||||||
|
design_image_pil = self.add_texture(
|
||||||
|
design_image_pil,
|
||||||
|
config.get('ptt_texture_type', self.DEFAULT_CONFIG['ptt_texture_type']),
|
||||||
|
config.get('ptt_texture_blend_mode', self.DEFAULT_CONFIG['ptt_texture_blend_mode'])
|
||||||
|
)
|
||||||
|
design_image = cv2.cvtColor(np.array(design_image_pil), cv2.COLOR_RGBA2BGRA)
|
||||||
|
|
||||||
|
# 进行设计图像增强处理
|
||||||
|
design_image_enhanced = self.enhance_design(design_image[:, :, :3], tshirt_image, config)
|
||||||
|
|
||||||
|
# 应用旋转效果到设计图案
|
||||||
|
ptt_design_rotation = config.get('ptt_design_rotation', self.DEFAULT_CONFIG['ptt_design_rotation'])
|
||||||
|
if ptt_design_rotation != 0:
|
||||||
|
design_image_pil = Image.fromarray(cv2.cvtColor(design_image, cv2.COLOR_BGRA2RGBA)).convert("RGBA")
|
||||||
|
design_image_pil = self.rotate_image_with_transparency(design_image_pil, ptt_design_rotation)
|
||||||
|
processed_design_image_with_alpha = cv2.cvtColor(np.array(design_image_pil), cv2.COLOR_RGBA2BGRA)
|
||||||
|
else:
|
||||||
|
processed_design_image_with_alpha = cv2.merge((design_image_enhanced, design_image[:, :, 3]))
|
||||||
|
|
||||||
|
# 保存处理后的设计图像
|
||||||
|
processed_design_io = None
|
||||||
|
if config.get('enable_save_processed_design', self.DEFAULT_CONFIG['enable_save_processed_design']):
|
||||||
|
processed_design_io = self.save_processed_design_image(
|
||||||
|
processed_design_image_with_alpha,
|
||||||
|
config.get('ptt_design_output_format', self.DEFAULT_CONFIG['ptt_design_output_format'])
|
||||||
|
)
|
||||||
|
|
||||||
|
# 调整设计图像大小
|
||||||
|
tshirt_height, tshirt_width = tshirt_image.shape[:2]
|
||||||
|
design_width = int(tshirt_width * config.get('ptt_design_size_ratio', self.DEFAULT_CONFIG['ptt_design_size_ratio']))
|
||||||
|
aspect_ratio = processed_design_image_with_alpha.shape[0] / processed_design_image_with_alpha.shape[1]
|
||||||
|
design_height = int(design_width * aspect_ratio)
|
||||||
|
design_image_resized = cv2.resize(processed_design_image_with_alpha, (design_width, design_height))
|
||||||
|
|
||||||
|
# 提取Alpha通道
|
||||||
|
alpha_channel = design_image_resized[:, :, 3] / 255.0
|
||||||
|
|
||||||
|
# 计算设计图像在T恤上的位置
|
||||||
|
ptt_design_offset = config.get('ptt_design_offset', self.DEFAULT_CONFIG['ptt_design_offset'])
|
||||||
|
design_position = (
|
||||||
|
int((tshirt_width - design_width) * ptt_design_offset[0]),
|
||||||
|
int((tshirt_height - design_height) * ptt_design_offset[1])
|
||||||
|
)
|
||||||
|
|
||||||
|
# 将设计图像叠加到T恤图像上
|
||||||
|
result_image = tshirt_image.copy()
|
||||||
|
self.overlay_image_alpha(result_image, design_image_resized[:, :, :3], design_position, alpha_channel)
|
||||||
|
|
||||||
|
# 返回结果图像和处理后的设计图像
|
||||||
|
return result_image, processed_design_io
|
||||||
|
|
||||||
|
def image_to_base64(self, image, format='png'):
|
||||||
|
"""将图像转换为base64字符串"""
|
||||||
|
try:
|
||||||
|
if isinstance(image, np.ndarray):
|
||||||
|
# 如果是OpenCV图像(NumPy数组),转换为PIL图像
|
||||||
|
if image.shape[2] == 3:
|
||||||
|
# BGR转RGB
|
||||||
|
image_pil = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
|
||||||
|
else:
|
||||||
|
# BGRA转RGBA
|
||||||
|
image_pil = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA))
|
||||||
|
else:
|
||||||
|
# 已经是PIL图像
|
||||||
|
image_pil = image
|
||||||
|
|
||||||
|
# 保存为BytesIO对象
|
||||||
|
buffered = io.BytesIO()
|
||||||
|
image_pil.save(buffered, format=format.upper())
|
||||||
|
img_str = base64.b64encode(buffered.getvalue()).decode()
|
||||||
|
return img_str
|
||||||
|
except Exception as e:
|
||||||
|
print(f"将图像转换为base64时发生错误: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def download_tshirt_images(self, config):
|
||||||
|
"""下载T恤图像列表"""
|
||||||
|
try:
|
||||||
|
# 首先检查是否提供了T恤图片URL列表
|
||||||
|
tshirt_urls = config.get('tshirt_urls', self.DEFAULT_CONFIG['tshirt_urls'])
|
||||||
|
if tshirt_urls and isinstance(tshirt_urls, list) and len(tshirt_urls) > 0:
|
||||||
|
tshirt_images = []
|
||||||
|
for url in tshirt_urls:
|
||||||
|
if self.is_valid_url(url):
|
||||||
|
tshirt_io = self.download_image(url)
|
||||||
|
if tshirt_io:
|
||||||
|
tshirt_image = cv2.imdecode(np.frombuffer(tshirt_io.getvalue(), np.uint8), cv2.IMREAD_COLOR)
|
||||||
|
if tshirt_image is not None:
|
||||||
|
tshirt_images.append(tshirt_image)
|
||||||
|
|
||||||
|
if tshirt_images:
|
||||||
|
return tshirt_images
|
||||||
|
|
||||||
|
# 如果没有提供URL或URL下载失败,则尝试使用本地模板
|
||||||
|
sample_tshirt_path = os.path.join(config.get('tshirt_image_path', self.DEFAULT_CONFIG['tshirt_image_path']), 'sample_tshirt.jpg')
|
||||||
|
if os.path.exists(sample_tshirt_path):
|
||||||
|
tshirt_image = cv2.imread(sample_tshirt_path)
|
||||||
|
return [tshirt_image]
|
||||||
|
else:
|
||||||
|
# 创建一个纯白色的示例T恤图像作为最后的备选
|
||||||
|
tshirt_image = np.ones((800, 600, 3), dtype=np.uint8) * 255
|
||||||
|
return [tshirt_image]
|
||||||
|
except Exception as e:
|
||||||
|
print(f"下载T恤图像时发生错误: {e}")
|
||||||
|
# 创建一个纯白色的示例T恤图像作为最后的备选
|
||||||
|
tshirt_image = np.ones((800, 600, 3), dtype=np.uint8) * 255
|
||||||
|
return [tshirt_image]
|
||||||
|
|
||||||
|
def is_valid_url(self, url):
|
||||||
|
"""检查URL是否有效"""
|
||||||
|
try:
|
||||||
|
result = urlparse(url)
|
||||||
|
return all([result.scheme, result.netloc])
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def download_image(self, url):
|
||||||
|
"""下载图像"""
|
||||||
|
try:
|
||||||
|
if self.is_valid_url(url):
|
||||||
|
response = requests.get(url, verify=False, timeout=10)
|
||||||
|
if response.status_code == 200:
|
||||||
|
return io.BytesIO(response.content)
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
print(f"下载图像失败: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def pattern_to_tshirt(self, image_url, config=None):
|
||||||
|
"""将花型图案添加到T恤上(URL输入)"""
|
||||||
|
if not config:
|
||||||
|
config = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 下载花型图案
|
||||||
|
design_io = self.download_image(image_url)
|
||||||
|
if not design_io:
|
||||||
|
return {"status": "error", "message": "无法下载图像"}
|
||||||
|
|
||||||
|
return await self.pattern_to_tshirt_from_file(design_io.getvalue(), config)
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
error_trace = traceback.format_exc()
|
||||||
|
print(f"处理图像时发生错误: {str(e)}\n{error_trace}")
|
||||||
|
return {"status": "error", "message": f"处理图像失败: {str(e)}"}
|
||||||
|
|
||||||
|
async def pattern_to_tshirt_from_file(self, file_content, config=None):
|
||||||
|
"""将花型图案添加到T恤上(文件输入)"""
|
||||||
|
if not config:
|
||||||
|
config = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 加载花型图案
|
||||||
|
design_io = io.BytesIO(file_content)
|
||||||
|
design_image = Image.open(design_io).convert("RGBA")
|
||||||
|
|
||||||
|
# 获取T恤图像列表
|
||||||
|
tshirt_images = self.download_tshirt_images(config)
|
||||||
|
if not tshirt_images:
|
||||||
|
return {"status": "error", "message": "无法获取T恤图像模板"}
|
||||||
|
|
||||||
|
results = []
|
||||||
|
processed_design_base64 = None
|
||||||
|
|
||||||
|
# 处理每个T恤图像
|
||||||
|
for tshirt_image in tshirt_images:
|
||||||
|
try:
|
||||||
|
# 生成合成图像
|
||||||
|
result_image, processed_design_io = self.generate_tshirt_image(design_image, tshirt_image, config)
|
||||||
|
|
||||||
|
# 转换为base64
|
||||||
|
result_base64 = self.image_to_base64(result_image)
|
||||||
|
|
||||||
|
# 如果有处理后的设计图像,也转换为base64
|
||||||
|
if processed_design_io and processed_design_base64 is None:
|
||||||
|
processed_design_image = Image.open(processed_design_io)
|
||||||
|
processed_design_base64 = self.image_to_base64(processed_design_image)
|
||||||
|
|
||||||
|
results.append({
|
||||||
|
"tshirt_image": result_base64
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
error_trace = traceback.format_exc()
|
||||||
|
print(f"处理单个T恤图像时发生错误: {str(e)}\n{error_trace}")
|
||||||
|
# 继续处理下一个T恤图像
|
||||||
|
|
||||||
|
if not results:
|
||||||
|
return {"status": "error", "message": "所有T恤图像处理均失败"}
|
||||||
|
|
||||||
|
response = {
|
||||||
|
"status": "success",
|
||||||
|
"results": results
|
||||||
|
}
|
||||||
|
|
||||||
|
# 如果有处理后的设计图像,添加到响应中
|
||||||
|
if processed_design_base64:
|
||||||
|
response["processed_design"] = processed_design_base64
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
error_trace = traceback.format_exc()
|
||||||
|
print(f"处理图像时发生错误: {str(e)}\n{error_trace}")
|
||||||
|
return {"status": "error", "message": f"处理图像失败: {str(e)}"}
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""清理资源"""
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
gc.collect()
|
||||||
32
apps/ptn_to_tshirt/settings.py
Normal file
32
apps/ptn_to_tshirt/settings.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
from typing import Optional
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
# Japi Server 配置
|
||||||
|
host: str = "0.0.0.0"
|
||||||
|
port: int = 8111
|
||||||
|
debug: bool = False
|
||||||
|
|
||||||
|
# API路由配置
|
||||||
|
router_prefix: str = "/ptn_to_tshirt"
|
||||||
|
file_route: str = "/file"
|
||||||
|
batch_route: str = "/batch"
|
||||||
|
api_name: str = "ptn_to_tshirt"
|
||||||
|
|
||||||
|
upload_url: str = "http://173.255.202.68/imgurl/upload"
|
||||||
|
|
||||||
|
# Jingrow Jcloud API 配置
|
||||||
|
jingrow_api_url: str = "https://cloud.jingrow.com"
|
||||||
|
jingrow_api_key: Optional[str] = None
|
||||||
|
jingrow_api_secret: Optional[str] = None
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
env_file = ".env"
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def get_settings() -> Settings:
|
||||||
|
return Settings()
|
||||||
|
|
||||||
|
# 创建全局配置实例
|
||||||
|
settings = get_settings()
|
||||||
146
apps/ptn_to_tshirt/utils.py
Normal file
146
apps/ptn_to_tshirt/utils.py
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
import aiohttp
|
||||||
|
from functools import wraps
|
||||||
|
from fastapi import HTTPException
|
||||||
|
import os
|
||||||
|
from typing import Callable, Any, Dict, Optional, Tuple
|
||||||
|
from settings import settings
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
import json
|
||||||
|
|
||||||
|
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||||
|
"""验证API密钥和团队余额"""
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{settings.jingrow_api_url}/api/method/jcloud.api.account.verify_api_credentials_and_balance",
|
||||||
|
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||||
|
json={"api_key": api_key, "api_secret": api_secret, "api_name": api_name}
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise HTTPException(status_code=500, detail="验证服务暂时不可用")
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
|
||||||
|
if not result.get("success"):
|
||||||
|
raise HTTPException(status_code=401, detail=result.get("message", "验证失败"))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||||
|
|
||||||
|
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
||||||
|
"""从Jingrow平台扣除API使用费"""
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{settings.jingrow_api_url}/api/method/jcloud.api.account.deduct_api_usage_fee",
|
||||||
|
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||||
|
json={
|
||||||
|
"api_key": api_key,
|
||||||
|
"api_secret": api_secret,
|
||||||
|
"api_name": api_name,
|
||||||
|
"usage_count": usage_count
|
||||||
|
}
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
||||||
|
|
||||||
|
def get_token_from_request(request) -> str:
|
||||||
|
"""从请求中获取访问令牌"""
|
||||||
|
if not request:
|
||||||
|
raise HTTPException(status_code=400, detail="无法获取请求信息")
|
||||||
|
|
||||||
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
if not auth_header or not auth_header.startswith("token "):
|
||||||
|
raise HTTPException(status_code=401, detail="无效的Authorization头格式")
|
||||||
|
|
||||||
|
token = auth_header[6:]
|
||||||
|
if ":" not in token:
|
||||||
|
raise HTTPException(status_code=401, detail="无效的令牌格式")
|
||||||
|
|
||||||
|
return token
|
||||||
|
|
||||||
|
def jingrow_api_verify_and_billing(api_name: str):
|
||||||
|
"""Jingrow API 验证装饰器(带余额检查和扣费)"""
|
||||||
|
def decorator(func: Callable) -> Callable:
|
||||||
|
@wraps(func)
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
try:
|
||||||
|
request = kwargs.get('request')
|
||||||
|
if not request:
|
||||||
|
raise HTTPException(status_code=400, detail="无法获取请求信息")
|
||||||
|
|
||||||
|
token = get_token_from_request(request)
|
||||||
|
api_key, api_secret = token.split(":", 1)
|
||||||
|
|
||||||
|
verify_result = await verify_api_credentials_and_balance(api_key, api_secret, api_name)
|
||||||
|
if not verify_result.get("success"):
|
||||||
|
raise HTTPException(status_code=401, detail=verify_result.get("message", "验证失败"))
|
||||||
|
|
||||||
|
result = await func(*args, **kwargs)
|
||||||
|
|
||||||
|
usage_count = 1
|
||||||
|
try:
|
||||||
|
body_data = await request.json()
|
||||||
|
if isinstance(body_data, dict):
|
||||||
|
for key in ["items", "urls", "images", "files"]:
|
||||||
|
if key in body_data and isinstance(body_data[key], list):
|
||||||
|
usage_count = len(body_data[key])
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if isinstance(result, StreamingResponse):
|
||||||
|
original_generator = result.body_iterator
|
||||||
|
success_count = 0
|
||||||
|
|
||||||
|
async def wrapped_generator():
|
||||||
|
nonlocal success_count
|
||||||
|
async for chunk in original_generator:
|
||||||
|
try:
|
||||||
|
data = json.loads(chunk)
|
||||||
|
if isinstance(data, dict) and data.get("status") == "success":
|
||||||
|
success_count += 1
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
if success_count > 0:
|
||||||
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
wrapped_generator(),
|
||||||
|
media_type=result.media_type,
|
||||||
|
headers=result.headers
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(result, dict) and result.get("success") is True:
|
||||||
|
actual_usage_count = result.get("successful_count", usage_count)
|
||||||
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
||||||
|
return result
|
||||||
|
|
||||||
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"API验证过程发生异常: {str(e)}")
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
||||||
0
apps/rmbg/__init__.py
Normal file
0
apps/rmbg/__init__.py
Normal file
52
apps/rmbg/api.py
Normal file
52
apps/rmbg/api.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
from fastapi import APIRouter, UploadFile, File, HTTPException, Request
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from service import RmbgService
|
||||||
|
from utils import jingrow_api_verify_and_billing
|
||||||
|
from settings import settings
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
router = APIRouter(prefix=settings.router_prefix)
|
||||||
|
service = RmbgService()
|
||||||
|
|
||||||
|
@router.post(settings.batch_route)
|
||||||
|
@jingrow_api_verify_and_billing(api_name=settings.api_name)
|
||||||
|
async def remove_background_batch(data: dict, request: Request):
|
||||||
|
"""
|
||||||
|
批量处理多个URL图片
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 包含图片URL列表的字典
|
||||||
|
request: FastAPI 请求对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
流式响应,包含每个图片的处理结果
|
||||||
|
"""
|
||||||
|
if "urls" not in data:
|
||||||
|
raise HTTPException(status_code=400, detail="缺少urls参数")
|
||||||
|
|
||||||
|
async def process_and_stream():
|
||||||
|
async for result in service.process_batch(data["urls"]):
|
||||||
|
yield json.dumps(result) + "\n"
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
process_and_stream(),
|
||||||
|
media_type="application/x-ndjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.post(settings.file_route)
|
||||||
|
@jingrow_api_verify_and_billing(api_name=settings.api_name)
|
||||||
|
async def remove_background_file(file: UploadFile = File(...), request: Request = None):
|
||||||
|
"""
|
||||||
|
从上传的文件移除背景
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file: 上传的图片文件
|
||||||
|
request: FastAPI 请求对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理后的图片内容
|
||||||
|
"""
|
||||||
|
content = await file.read()
|
||||||
|
result = await service.remove_background_from_file(content)
|
||||||
|
return result
|
||||||
21
apps/rmbg/app.py
Normal file
21
apps/rmbg/app.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
from fastapi import FastAPI
|
||||||
|
from settings import settings
|
||||||
|
from api import router
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="Remove Background",
|
||||||
|
description="图片去背景",
|
||||||
|
version="1.0.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 注册路由
|
||||||
|
app.include_router(router)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
uvicorn.run(
|
||||||
|
"app:app",
|
||||||
|
host=settings.host,
|
||||||
|
port=settings.port,
|
||||||
|
reload=settings.debug
|
||||||
|
)
|
||||||
225
apps/rmbg/service.py
Normal file
225
apps/rmbg/service.py
Normal file
@ -0,0 +1,225 @@
|
|||||||
|
import os
|
||||||
|
import requests
|
||||||
|
import tempfile
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
from PIL import Image
|
||||||
|
import torch
|
||||||
|
from torchvision import transforms
|
||||||
|
from transformers import AutoModelForImageSegmentation
|
||||||
|
import time
|
||||||
|
import warnings
|
||||||
|
import gc
|
||||||
|
import base64
|
||||||
|
import asyncio
|
||||||
|
import io
|
||||||
|
import multiprocessing as mp
|
||||||
|
from concurrent.futures import ProcessPoolExecutor
|
||||||
|
|
||||||
|
# 关闭不必要的警告
|
||||||
|
warnings.filterwarnings("ignore", category=UserWarning)
|
||||||
|
warnings.filterwarnings("ignore", category=FutureWarning)
|
||||||
|
|
||||||
|
# 设置torch精度
|
||||||
|
torch.set_float32_matmul_precision("high")
|
||||||
|
|
||||||
|
class RmbgService:
|
||||||
|
def __init__(self, model_path="zhengpeng7/BiRefNet"):
|
||||||
|
"""初始化背景移除服务"""
|
||||||
|
self.model_path = model_path
|
||||||
|
self.model = None
|
||||||
|
self.device = None
|
||||||
|
self._load_model()
|
||||||
|
|
||||||
|
def _load_model(self):
|
||||||
|
"""加载模型"""
|
||||||
|
# 设置设备
|
||||||
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
t0 = time.time()
|
||||||
|
self.model = AutoModelForImageSegmentation.from_pretrained(self.model_path, trust_remote_code=True)
|
||||||
|
self.model = self.model.to(self.device)
|
||||||
|
self.model.eval() # 设置为评估模式
|
||||||
|
|
||||||
|
def process_image(self, image):
|
||||||
|
"""处理图像,移除背景"""
|
||||||
|
image_size = image.size
|
||||||
|
# 转换图像
|
||||||
|
t0 = time.time()
|
||||||
|
transform_image = transforms.Compose([
|
||||||
|
transforms.Resize((1024, 1024)),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
|
||||||
|
])
|
||||||
|
input_images = transform_image(image).unsqueeze(0).to(self.device)
|
||||||
|
|
||||||
|
# 推理
|
||||||
|
t0 = time.time()
|
||||||
|
with torch.no_grad():
|
||||||
|
preds = self.model(input_images)[-1].sigmoid().cpu()
|
||||||
|
|
||||||
|
# 处理预测结果
|
||||||
|
t0 = time.time()
|
||||||
|
pred = preds[0].squeeze()
|
||||||
|
pred_pil = transforms.ToPILImage()(pred)
|
||||||
|
mask = pred_pil.resize(image_size)
|
||||||
|
|
||||||
|
# 添加透明通道
|
||||||
|
image.putalpha(mask)
|
||||||
|
|
||||||
|
# 清理显存
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
def image_to_base64(self, image):
|
||||||
|
"""将PIL Image对象转换为base64字符串"""
|
||||||
|
buffered = io.BytesIO()
|
||||||
|
image.save(buffered, format="PNG")
|
||||||
|
return base64.b64encode(buffered.getvalue()).decode('utf-8')
|
||||||
|
|
||||||
|
async def remove_background(self, image_path):
|
||||||
|
"""
|
||||||
|
移除图像背景
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_path: 输入图像的路径或URL
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理后的图像内容
|
||||||
|
"""
|
||||||
|
temp_file = None
|
||||||
|
try:
|
||||||
|
# 检查是否是URL
|
||||||
|
if self.is_valid_url(image_path):
|
||||||
|
try:
|
||||||
|
# 下载图片到临时文件
|
||||||
|
temp_file = self.download_image(image_path)
|
||||||
|
image_path = temp_file
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"下载图片失败: {e}")
|
||||||
|
|
||||||
|
# 验证输入文件是否存在
|
||||||
|
if not os.path.exists(image_path):
|
||||||
|
raise FileNotFoundError(f"输入图像不存在: {image_path}")
|
||||||
|
|
||||||
|
# 加载并处理图像
|
||||||
|
image = Image.open(image_path).convert("RGB")
|
||||||
|
image_no_bg = self.process_image(image)
|
||||||
|
|
||||||
|
# 转换为base64
|
||||||
|
image_content = self.image_to_base64(image_no_bg)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"image_content": image_content
|
||||||
|
}
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# 清理临时文件
|
||||||
|
if temp_file and os.path.exists(temp_file):
|
||||||
|
try:
|
||||||
|
os.unlink(temp_file)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def remove_background_from_file(self, file_content):
|
||||||
|
"""
|
||||||
|
从上传的文件内容移除背景
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_content: 上传的文件内容
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理后的图像内容
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 从文件内容创建PIL Image对象
|
||||||
|
image = Image.open(io.BytesIO(file_content)).convert("RGB")
|
||||||
|
image_no_bg = self.process_image(image)
|
||||||
|
|
||||||
|
# 转换为base64
|
||||||
|
image_content = self.image_to_base64(image_no_bg)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "success",
|
||||||
|
"image_content": image_content
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"处理图片失败: {e}")
|
||||||
|
|
||||||
|
async def process_batch(self, urls):
|
||||||
|
"""
|
||||||
|
批量处理多个URL图像,流式返回结果
|
||||||
|
|
||||||
|
Args:
|
||||||
|
urls: 图片URL列表
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
每个图片的处理结果
|
||||||
|
"""
|
||||||
|
total = len(urls)
|
||||||
|
success_count = 0
|
||||||
|
error_count = 0
|
||||||
|
|
||||||
|
for i, url in enumerate(urls, 1):
|
||||||
|
try:
|
||||||
|
url_str = str(url)
|
||||||
|
result = await self.remove_background(url_str)
|
||||||
|
success_count += 1
|
||||||
|
|
||||||
|
# 确保返回正确的数据格式
|
||||||
|
yield {
|
||||||
|
"index": i,
|
||||||
|
"total": total,
|
||||||
|
"original_url": url_str,
|
||||||
|
"status": "success",
|
||||||
|
"image_content": result["image_content"],
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count,
|
||||||
|
"message": "处理成功"
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_count += 1
|
||||||
|
yield {
|
||||||
|
"index": i,
|
||||||
|
"total": total,
|
||||||
|
"original_url": str(url),
|
||||||
|
"status": "error",
|
||||||
|
"error": str(e),
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count,
|
||||||
|
"message": f"处理失败: {str(e)}"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 让出控制权,避免阻塞
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
def is_valid_url(self, url):
|
||||||
|
"""验证URL是否有效"""
|
||||||
|
try:
|
||||||
|
result = urlparse(url)
|
||||||
|
return all([result.scheme, result.netloc])
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def download_image(self, url):
|
||||||
|
"""从URL下载图片到临时文件"""
|
||||||
|
response = requests.get(url, stream=True)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
# 创建临时文件
|
||||||
|
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
|
||||||
|
with open(temp_file.name, 'wb') as f:
|
||||||
|
for chunk in response.iter_content(chunk_size=8192):
|
||||||
|
f.write(chunk)
|
||||||
|
return temp_file.name
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""清理资源"""
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
gc.collect()
|
||||||
|
print("资源已清理")
|
||||||
32
apps/rmbg/settings.py
Normal file
32
apps/rmbg/settings.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
from typing import Optional
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
# Japi Server 配置
|
||||||
|
host: str = "0.0.0.0"
|
||||||
|
port: int = 8106
|
||||||
|
debug: bool = False
|
||||||
|
|
||||||
|
# API路由配置
|
||||||
|
router_prefix: str = "/rmbg"
|
||||||
|
file_route: str = "/file"
|
||||||
|
batch_route: str = "/batch"
|
||||||
|
api_name: str = "remove_background"
|
||||||
|
|
||||||
|
upload_url: str = "http://173.255.202.68/imgurl/upload"
|
||||||
|
|
||||||
|
# Jingrow Jcloud API 配置
|
||||||
|
jingrow_api_url: str = "https://cloud.jingrow.com"
|
||||||
|
jingrow_api_key: Optional[str] = None
|
||||||
|
jingrow_api_secret: Optional[str] = None
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
env_file = ".env"
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def get_settings() -> Settings:
|
||||||
|
return Settings()
|
||||||
|
|
||||||
|
# 创建全局配置实例
|
||||||
|
settings = get_settings()
|
||||||
146
apps/rmbg/utils.py
Normal file
146
apps/rmbg/utils.py
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
import aiohttp
|
||||||
|
from functools import wraps
|
||||||
|
from fastapi import HTTPException
|
||||||
|
import os
|
||||||
|
from typing import Callable, Any, Dict, Optional, Tuple
|
||||||
|
from settings import settings
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
import json
|
||||||
|
|
||||||
|
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||||
|
"""验证API密钥和团队余额"""
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{settings.jingrow_api_url}/api/method/jcloud.api.account.verify_api_credentials_and_balance",
|
||||||
|
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||||
|
json={"api_key": api_key, "api_secret": api_secret, "api_name": api_name}
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise HTTPException(status_code=500, detail="验证服务暂时不可用")
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
|
||||||
|
if not result.get("success"):
|
||||||
|
raise HTTPException(status_code=401, detail=result.get("message", "验证失败"))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||||
|
|
||||||
|
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
||||||
|
"""从Jingrow平台扣除API使用费"""
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{settings.jingrow_api_url}/api/method/jcloud.api.account.deduct_api_usage_fee",
|
||||||
|
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||||
|
json={
|
||||||
|
"api_key": api_key,
|
||||||
|
"api_secret": api_secret,
|
||||||
|
"api_name": api_name,
|
||||||
|
"usage_count": usage_count
|
||||||
|
}
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
||||||
|
|
||||||
|
def get_token_from_request(request) -> str:
|
||||||
|
"""从请求中获取访问令牌"""
|
||||||
|
if not request:
|
||||||
|
raise HTTPException(status_code=400, detail="无法获取请求信息")
|
||||||
|
|
||||||
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
if not auth_header or not auth_header.startswith("token "):
|
||||||
|
raise HTTPException(status_code=401, detail="无效的Authorization头格式")
|
||||||
|
|
||||||
|
token = auth_header[6:]
|
||||||
|
if ":" not in token:
|
||||||
|
raise HTTPException(status_code=401, detail="无效的令牌格式")
|
||||||
|
|
||||||
|
return token
|
||||||
|
|
||||||
|
def jingrow_api_verify_and_billing(api_name: str):
|
||||||
|
"""Jingrow API 验证装饰器(带余额检查和扣费)"""
|
||||||
|
def decorator(func: Callable) -> Callable:
|
||||||
|
@wraps(func)
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
try:
|
||||||
|
request = kwargs.get('request')
|
||||||
|
if not request:
|
||||||
|
raise HTTPException(status_code=400, detail="无法获取请求信息")
|
||||||
|
|
||||||
|
token = get_token_from_request(request)
|
||||||
|
api_key, api_secret = token.split(":", 1)
|
||||||
|
|
||||||
|
verify_result = await verify_api_credentials_and_balance(api_key, api_secret, api_name)
|
||||||
|
if not verify_result.get("success"):
|
||||||
|
raise HTTPException(status_code=401, detail=verify_result.get("message", "验证失败"))
|
||||||
|
|
||||||
|
result = await func(*args, **kwargs)
|
||||||
|
|
||||||
|
usage_count = 1
|
||||||
|
try:
|
||||||
|
body_data = await request.json()
|
||||||
|
if isinstance(body_data, dict):
|
||||||
|
for key in ["items", "urls", "images", "files"]:
|
||||||
|
if key in body_data and isinstance(body_data[key], list):
|
||||||
|
usage_count = len(body_data[key])
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if isinstance(result, StreamingResponse):
|
||||||
|
original_generator = result.body_iterator
|
||||||
|
success_count = 0
|
||||||
|
|
||||||
|
async def wrapped_generator():
|
||||||
|
nonlocal success_count
|
||||||
|
async for chunk in original_generator:
|
||||||
|
try:
|
||||||
|
data = json.loads(chunk)
|
||||||
|
if isinstance(data, dict) and data.get("status") == "success":
|
||||||
|
success_count += 1
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
if success_count > 0:
|
||||||
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
wrapped_generator(),
|
||||||
|
media_type=result.media_type,
|
||||||
|
headers=result.headers
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(result, dict) and result.get("success") is True:
|
||||||
|
actual_usage_count = result.get("successful_count", usage_count)
|
||||||
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
||||||
|
return result
|
||||||
|
|
||||||
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"API验证过程发生异常: {str(e)}")
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
||||||
0
apps/tryon/__init__.py
Normal file
0
apps/tryon/__init__.py
Normal file
71
apps/tryon/api.py
Normal file
71
apps/tryon/api.py
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
from fastapi import APIRouter, UploadFile, File, HTTPException, Request
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from service import TryonService
|
||||||
|
from utils import jingrow_api_verify_and_billing
|
||||||
|
from settings import settings
|
||||||
|
import json
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
router = APIRouter(prefix=settings.router_prefix)
|
||||||
|
service = TryonService()
|
||||||
|
|
||||||
|
@router.post(settings.batch_route)
|
||||||
|
@jingrow_api_verify_and_billing(api_name=settings.api_name)
|
||||||
|
async def tryon_batch(data: dict, request: Request):
|
||||||
|
if "tshirt_urls" not in data:
|
||||||
|
raise HTTPException(status_code=400, detail="缺少tshirt_urls参数")
|
||||||
|
if "model_urls" not in data or not isinstance(data["model_urls"], list):
|
||||||
|
raise HTTPException(status_code=400, detail="缺少model_urls参数或格式错误")
|
||||||
|
|
||||||
|
tshirt_urls = data["tshirt_urls"]
|
||||||
|
if not isinstance(tshirt_urls, list):
|
||||||
|
raise HTTPException(status_code=400, detail="tshirt_urls必须是URL列表")
|
||||||
|
|
||||||
|
combinations = []
|
||||||
|
for model_url in data["model_urls"]:
|
||||||
|
for tshirt_url in tshirt_urls:
|
||||||
|
combinations.append(f"{tshirt_url}|{model_url}")
|
||||||
|
|
||||||
|
data["urls"] = combinations
|
||||||
|
config = data.get("config", {})
|
||||||
|
|
||||||
|
async def process_and_stream():
|
||||||
|
async for result in service.process_batch(tshirt_urls, data["model_urls"], config):
|
||||||
|
yield json.dumps(result) + "\n"
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
process_and_stream(),
|
||||||
|
media_type="application/x-ndjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(settings.file_route)
|
||||||
|
@jingrow_api_verify_and_billing(api_name=settings.api_name)
|
||||||
|
async def tryon_file(
|
||||||
|
tshirt_files: List[UploadFile] = File(...),
|
||||||
|
model_file: UploadFile = File(...),
|
||||||
|
config: str = None,
|
||||||
|
request: Request = None
|
||||||
|
):
|
||||||
|
tshirt_contents = [await file.read() for file in tshirt_files]
|
||||||
|
model_content = await model_file.read()
|
||||||
|
|
||||||
|
if request:
|
||||||
|
request._body = json.dumps({"urls": [f"file_{i}" for i in range(len(tshirt_files))]}).encode()
|
||||||
|
|
||||||
|
config_dict = {}
|
||||||
|
if config:
|
||||||
|
try:
|
||||||
|
config_dict = json.loads(config)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise HTTPException(status_code=400, detail="配置参数格式错误")
|
||||||
|
|
||||||
|
async def process_and_stream():
|
||||||
|
async for result in service.process_files(tshirt_contents, model_content, config_dict):
|
||||||
|
yield json.dumps(result) + "\n"
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
process_and_stream(),
|
||||||
|
media_type="application/x-ndjson"
|
||||||
|
)
|
||||||
|
|
||||||
21
apps/tryon/app.py
Normal file
21
apps/tryon/app.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
from fastapi import FastAPI
|
||||||
|
from settings import settings
|
||||||
|
from api import router
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="Tryon",
|
||||||
|
description="虚拟试穿",
|
||||||
|
version="1.0.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 注册路由
|
||||||
|
app.include_router(router)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
uvicorn.run(
|
||||||
|
"app:app",
|
||||||
|
host=settings.host,
|
||||||
|
port=settings.port,
|
||||||
|
reload=settings.debug
|
||||||
|
)
|
||||||
455
apps/tryon/service.py
Normal file
455
apps/tryon/service.py
Normal file
@ -0,0 +1,455 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
import io
|
||||||
|
import aiohttp
|
||||||
|
from typing import List
|
||||||
|
import tempfile
|
||||||
|
from gradio_client import Client, handle_file
|
||||||
|
import asyncio
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
from settings import settings
|
||||||
|
|
||||||
|
tryon_server_url = settings.tryon_server_url
|
||||||
|
|
||||||
|
class TryonService:
|
||||||
|
# 默认配置
|
||||||
|
DEFAULT_CONFIG = {
|
||||||
|
'tryon_marker': "_tryon",
|
||||||
|
'tryon_target_marker': "tshirt",
|
||||||
|
'tryon_models_dir': "/files/models",
|
||||||
|
'denoise_steps': 20,
|
||||||
|
'seed': 42,
|
||||||
|
'is_crop': False,
|
||||||
|
'output_format': 'png'
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""初始化虚拟试穿服务"""
|
||||||
|
self.client = None
|
||||||
|
|
||||||
|
def get_gradio_client(self):
|
||||||
|
"""获取或初始化Gradio客户端"""
|
||||||
|
if self.client is None:
|
||||||
|
try:
|
||||||
|
self.client = Client(tryon_server_url)
|
||||||
|
except Exception:
|
||||||
|
self.client = None
|
||||||
|
return self.client
|
||||||
|
|
||||||
|
def _convert_config_types(self, config):
|
||||||
|
"""转换配置参数类型"""
|
||||||
|
if not config:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
converted = {}
|
||||||
|
for key, value in config.items():
|
||||||
|
if key == 'denoise_steps':
|
||||||
|
try:
|
||||||
|
converted[key] = int(value)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
converted[key] = 20
|
||||||
|
elif key == 'seed':
|
||||||
|
try:
|
||||||
|
converted[key] = int(value)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
converted[key] = 42
|
||||||
|
elif key == 'is_crop':
|
||||||
|
try:
|
||||||
|
converted[key] = bool(value)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
converted[key] = False
|
||||||
|
else:
|
||||||
|
converted[key] = value
|
||||||
|
return converted
|
||||||
|
|
||||||
|
async def generate_virtual_tryon(self, tshirt_image_io: List[io.BytesIO], model_image_io: io.BytesIO, config=None):
|
||||||
|
"""生成虚拟试穿结果
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tshirt_image_io: T恤图片IO对象列表
|
||||||
|
model_image_io: 模特图片IO对象
|
||||||
|
config: 配置参数
|
||||||
|
"""
|
||||||
|
if config is None:
|
||||||
|
config = {}
|
||||||
|
|
||||||
|
# 转换配置参数类型并合并默认配置
|
||||||
|
config = {**self.DEFAULT_CONFIG, **self._convert_config_types(config)}
|
||||||
|
|
||||||
|
# 检查图片大小
|
||||||
|
min_image_size = 1024 # 最小1KB
|
||||||
|
for tshirt_io in tshirt_image_io:
|
||||||
|
if len(tshirt_io.getvalue()) < min_image_size:
|
||||||
|
raise ValueError(f"T恤图片太小,可能不是有效图片,大小: {len(tshirt_io.getvalue())} 字节")
|
||||||
|
|
||||||
|
if len(model_image_io.getvalue()) < min_image_size:
|
||||||
|
raise ValueError(f"模特图片太小,可能不是有效图片,大小: {len(model_image_io.getvalue())} 字节")
|
||||||
|
|
||||||
|
client = self.get_gradio_client()
|
||||||
|
if client is None:
|
||||||
|
raise RuntimeError("Gradio API服务不可用,无法进行虚拟试穿")
|
||||||
|
|
||||||
|
# 创建临时目录
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
# 保存所有T-shirt图片为临时文件
|
||||||
|
temp_tshirt_files = []
|
||||||
|
for tshirt_io in tshirt_image_io:
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir=temp_dir) as temp_tshirt_file:
|
||||||
|
temp_tshirt_file.write(tshirt_io.getvalue())
|
||||||
|
temp_tshirt_files.append(temp_tshirt_file.name)
|
||||||
|
|
||||||
|
# 保存模特图片为临时文件
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir=temp_dir) as temp_model_file:
|
||||||
|
temp_model_file.write(model_image_io.getvalue())
|
||||||
|
temp_model_file_path = temp_model_file.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
results = []
|
||||||
|
# 对每件T恤进行虚拟试穿
|
||||||
|
for temp_tshirt_file_path in temp_tshirt_files:
|
||||||
|
try:
|
||||||
|
# 调用API进行虚拟试穿
|
||||||
|
result = client.predict(
|
||||||
|
dict({"background": handle_file(temp_model_file_path), "layers": [], "composite": None}),
|
||||||
|
garm_img=handle_file(temp_tshirt_file_path),
|
||||||
|
garment_des="",
|
||||||
|
is_checked=True,
|
||||||
|
is_checked_crop=config.get('is_crop', False),
|
||||||
|
denoise_steps=config.get('denoise_steps', 20),
|
||||||
|
seed=config.get('seed', 42),
|
||||||
|
api_name="/tryon"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 处理返回结果
|
||||||
|
if not result or not isinstance(result, tuple) or len(result) < 1:
|
||||||
|
raise RuntimeError("虚拟试穿服务返回了无效的结果格式")
|
||||||
|
|
||||||
|
output_path = result[0] # 使用第一个图片作为结果
|
||||||
|
if not os.path.exists(output_path):
|
||||||
|
raise RuntimeError(f"输出文件不存在: {output_path}")
|
||||||
|
|
||||||
|
with open(output_path, 'rb') as f:
|
||||||
|
result_data = f.read()
|
||||||
|
results.append(io.BytesIO(result_data))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
results.append(None)
|
||||||
|
|
||||||
|
return results[0] if len(results) == 1 else results
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
raise
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# 清理临时文件
|
||||||
|
for temp_tshirt_file_path in temp_tshirt_files:
|
||||||
|
if os.path.exists(temp_tshirt_file_path):
|
||||||
|
os.remove(temp_tshirt_file_path)
|
||||||
|
if os.path.exists(temp_model_file_path):
|
||||||
|
os.remove(temp_model_file_path)
|
||||||
|
|
||||||
|
def is_valid_url(self, url):
|
||||||
|
"""检查URL是否有效"""
|
||||||
|
try:
|
||||||
|
result = urlparse(url)
|
||||||
|
return all([result.scheme, result.netloc])
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def download_image(self, url):
|
||||||
|
"""下载图片"""
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(url) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
content = await response.read()
|
||||||
|
image_io = io.BytesIO(content)
|
||||||
|
|
||||||
|
# 检查内容长度
|
||||||
|
if len(content) < 100: # 太小可能不是有效图片
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 检查内容类型
|
||||||
|
content_type = response.headers.get('Content-Type', '')
|
||||||
|
if 'image' not in content_type.lower():
|
||||||
|
# 检查文件头部魔术数字
|
||||||
|
header = content[:12]
|
||||||
|
is_image = any([
|
||||||
|
header.startswith(b'\x89PNG'), # PNG
|
||||||
|
header.startswith(b'\xff\xd8\xff'), # JPEG
|
||||||
|
header.startswith(b'GIF8'), # GIF
|
||||||
|
header.startswith(b'RIFF') and b'WEBP' in header # WEBP
|
||||||
|
])
|
||||||
|
if not is_image:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return image_io
|
||||||
|
return None
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def process_urls(self, tshirt_urls: List[str], model_url: str, config=None):
|
||||||
|
"""
|
||||||
|
处理多个T恤URL,流式返回结果
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tshirt_urls: T恤图片URL列表
|
||||||
|
model_url: 模特图片URL
|
||||||
|
config: 配置参数
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
每个T恤的处理结果
|
||||||
|
"""
|
||||||
|
total = len(tshirt_urls)
|
||||||
|
success_count = 0
|
||||||
|
error_count = 0
|
||||||
|
|
||||||
|
# 下载模特图片
|
||||||
|
model_io = await self.download_image(model_url)
|
||||||
|
if model_io is None:
|
||||||
|
yield {
|
||||||
|
"status": "error",
|
||||||
|
"message": f"无法下载模特图片: {model_url}",
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count + 1,
|
||||||
|
"total": total
|
||||||
|
}
|
||||||
|
return
|
||||||
|
|
||||||
|
for i, tshirt_url in enumerate(tshirt_urls, 1):
|
||||||
|
try:
|
||||||
|
# 下载T恤图片
|
||||||
|
tshirt_io = await self.download_image(tshirt_url)
|
||||||
|
if tshirt_io is None:
|
||||||
|
error_count += 1
|
||||||
|
yield {
|
||||||
|
"index": i,
|
||||||
|
"total": total,
|
||||||
|
"tshirt_url": tshirt_url,
|
||||||
|
"status": "error",
|
||||||
|
"message": f"无法下载T恤图片: {tshirt_url}",
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 处理图片
|
||||||
|
result = await self.generate_virtual_tryon([tshirt_io], model_io, config)
|
||||||
|
if result is None:
|
||||||
|
error_count += 1
|
||||||
|
yield {
|
||||||
|
"index": i,
|
||||||
|
"total": total,
|
||||||
|
"tshirt_url": tshirt_url,
|
||||||
|
"status": "error",
|
||||||
|
"message": f"处理T恤图片失败: {tshirt_url}",
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
success_count += 1
|
||||||
|
result.seek(0)
|
||||||
|
base64_data = f"data:image/{config.get('output_format', 'png')};base64," + \
|
||||||
|
json.dumps(result.read().hex())
|
||||||
|
yield {
|
||||||
|
"index": i,
|
||||||
|
"total": total,
|
||||||
|
"tshirt_url": tshirt_url,
|
||||||
|
"status": "success",
|
||||||
|
"data": base64_data,
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_count += 1
|
||||||
|
yield {
|
||||||
|
"index": i,
|
||||||
|
"total": total,
|
||||||
|
"tshirt_url": tshirt_url,
|
||||||
|
"status": "error",
|
||||||
|
"message": str(e),
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count
|
||||||
|
}
|
||||||
|
|
||||||
|
# 让出控制权,避免阻塞
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
async def process_files(self, tshirt_contents: List[bytes], model_content: bytes, config=None):
|
||||||
|
"""
|
||||||
|
处理多个T恤文件,流式返回结果
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tshirt_contents: T恤图片内容列表
|
||||||
|
model_content: 模特图片内容
|
||||||
|
config: 配置参数
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
每个T恤的处理结果
|
||||||
|
"""
|
||||||
|
total = len(tshirt_contents)
|
||||||
|
success_count = 0
|
||||||
|
error_count = 0
|
||||||
|
|
||||||
|
model_io = io.BytesIO(model_content)
|
||||||
|
|
||||||
|
for i, content in enumerate(tshirt_contents, 1):
|
||||||
|
try:
|
||||||
|
tshirt_io = io.BytesIO(content)
|
||||||
|
result = await self.generate_virtual_tryon([tshirt_io], model_io, config)
|
||||||
|
if result is None:
|
||||||
|
error_count += 1
|
||||||
|
yield {
|
||||||
|
"index": i,
|
||||||
|
"total": total,
|
||||||
|
"status": "error",
|
||||||
|
"message": "处理T恤图片失败",
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
success_count += 1
|
||||||
|
result.seek(0)
|
||||||
|
base64_data = f"data:image/{config.get('output_format', 'png')};base64," + \
|
||||||
|
json.dumps(result.read().hex())
|
||||||
|
yield {
|
||||||
|
"index": i,
|
||||||
|
"total": total,
|
||||||
|
"status": "success",
|
||||||
|
"data": base64_data,
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_count += 1
|
||||||
|
yield {
|
||||||
|
"index": i,
|
||||||
|
"total": total,
|
||||||
|
"status": "error",
|
||||||
|
"message": str(e),
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count
|
||||||
|
}
|
||||||
|
|
||||||
|
# 让出控制权,避免阻塞
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
async def process_batch(self, tshirt_urls: List[str], model_urls: List[str], config=None):
|
||||||
|
"""
|
||||||
|
批量处理多个T恤和模特图片,流式返回结果
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tshirt_urls: T恤图片URL列表
|
||||||
|
model_urls: 模特图片URL列表
|
||||||
|
config: 配置参数
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
每个组合的处理结果
|
||||||
|
"""
|
||||||
|
total = len(tshirt_urls) * len(model_urls)
|
||||||
|
success_count = 0
|
||||||
|
error_count = 0
|
||||||
|
current_index = 0
|
||||||
|
|
||||||
|
for model_url in model_urls:
|
||||||
|
try:
|
||||||
|
model_io = await self.download_image(model_url)
|
||||||
|
if model_io is None:
|
||||||
|
error_count += len(tshirt_urls)
|
||||||
|
for tshirt_url in tshirt_urls:
|
||||||
|
current_index += 1
|
||||||
|
yield {
|
||||||
|
"index": current_index,
|
||||||
|
"total": total,
|
||||||
|
"model_url": model_url,
|
||||||
|
"tshirt_url": tshirt_url,
|
||||||
|
"status": "error",
|
||||||
|
"message": f"无法下载模特图片: {model_url}",
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
|
||||||
|
for tshirt_url in tshirt_urls:
|
||||||
|
current_index += 1
|
||||||
|
try:
|
||||||
|
tshirt_io = await self.download_image(tshirt_url)
|
||||||
|
if tshirt_io is None:
|
||||||
|
error_count += 1
|
||||||
|
yield {
|
||||||
|
"index": current_index,
|
||||||
|
"total": total,
|
||||||
|
"model_url": model_url,
|
||||||
|
"tshirt_url": tshirt_url,
|
||||||
|
"status": "error",
|
||||||
|
"message": f"无法下载T恤图片: {tshirt_url}",
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
|
||||||
|
result = await self.generate_virtual_tryon([tshirt_io], model_io, config)
|
||||||
|
if result is None:
|
||||||
|
error_count += 1
|
||||||
|
yield {
|
||||||
|
"index": current_index,
|
||||||
|
"total": total,
|
||||||
|
"model_url": model_url,
|
||||||
|
"tshirt_url": tshirt_url,
|
||||||
|
"status": "error",
|
||||||
|
"message": f"处理T恤图片失败: {tshirt_url}",
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
success_count += 1
|
||||||
|
result.seek(0)
|
||||||
|
base64_data = f"data:image/{config.get('output_format', 'png')};base64," + \
|
||||||
|
json.dumps(result.read().hex())
|
||||||
|
yield {
|
||||||
|
"index": current_index,
|
||||||
|
"total": total,
|
||||||
|
"model_url": model_url,
|
||||||
|
"tshirt_url": tshirt_url,
|
||||||
|
"status": "success",
|
||||||
|
"data": base64_data,
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_count += 1
|
||||||
|
yield {
|
||||||
|
"index": current_index,
|
||||||
|
"total": total,
|
||||||
|
"model_url": model_url,
|
||||||
|
"tshirt_url": tshirt_url,
|
||||||
|
"status": "error",
|
||||||
|
"message": str(e),
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count
|
||||||
|
}
|
||||||
|
|
||||||
|
# 让出控制权,避免阻塞
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_count += len(tshirt_urls)
|
||||||
|
for tshirt_url in tshirt_urls:
|
||||||
|
current_index += 1
|
||||||
|
yield {
|
||||||
|
"index": current_index,
|
||||||
|
"total": total,
|
||||||
|
"model_url": model_url,
|
||||||
|
"tshirt_url": tshirt_url,
|
||||||
|
"status": "error",
|
||||||
|
"message": str(e),
|
||||||
|
"success_count": success_count,
|
||||||
|
"error_count": error_count
|
||||||
|
}
|
||||||
|
|
||||||
|
def cleanup(self):
|
||||||
|
"""清理资源"""
|
||||||
|
self.client = None
|
||||||
35
apps/tryon/settings.py
Normal file
35
apps/tryon/settings.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
from typing import Optional
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
# Japi Server 配置
|
||||||
|
host: str = "0.0.0.0"
|
||||||
|
port: int = 8112
|
||||||
|
debug: bool = False
|
||||||
|
|
||||||
|
# API路由配置
|
||||||
|
router_prefix: str = "/tryon"
|
||||||
|
file_route: str = "/file"
|
||||||
|
batch_route: str = "/batch"
|
||||||
|
api_name: str = "tryon"
|
||||||
|
|
||||||
|
upload_url: str = "http://173.255.202.68/imgurl/upload"
|
||||||
|
|
||||||
|
# 虚拟试穿Tryon服务器URL
|
||||||
|
tryon_server_url: str = "http://192.168.2.200:7860"
|
||||||
|
|
||||||
|
# Jingrow Jcloud API 配置
|
||||||
|
jingrow_api_url: str = "https://cloud.jingrow.com"
|
||||||
|
jingrow_api_key: Optional[str] = None
|
||||||
|
jingrow_api_secret: Optional[str] = None
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
env_file = ".env"
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def get_settings() -> Settings:
|
||||||
|
return Settings()
|
||||||
|
|
||||||
|
# 创建全局配置实例
|
||||||
|
settings = get_settings()
|
||||||
146
apps/tryon/utils.py
Normal file
146
apps/tryon/utils.py
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
import aiohttp
|
||||||
|
from functools import wraps
|
||||||
|
from fastapi import HTTPException
|
||||||
|
import os
|
||||||
|
from typing import Callable, Any, Dict, Optional, Tuple
|
||||||
|
from settings import settings
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
import json
|
||||||
|
|
||||||
|
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||||
|
"""验证API密钥和团队余额"""
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{settings.jingrow_api_url}/api/method/jcloud.api.account.verify_api_credentials_and_balance",
|
||||||
|
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||||
|
json={"api_key": api_key, "api_secret": api_secret, "api_name": api_name}
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise HTTPException(status_code=500, detail="验证服务暂时不可用")
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
|
||||||
|
if not result.get("success"):
|
||||||
|
raise HTTPException(status_code=401, detail=result.get("message", "验证失败"))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||||
|
|
||||||
|
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
||||||
|
"""从Jingrow平台扣除API使用费"""
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(
|
||||||
|
f"{settings.jingrow_api_url}/api/method/jcloud.api.account.deduct_api_usage_fee",
|
||||||
|
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||||
|
json={
|
||||||
|
"api_key": api_key,
|
||||||
|
"api_secret": api_secret,
|
||||||
|
"api_name": api_name,
|
||||||
|
"usage_count": usage_count
|
||||||
|
}
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if "message" in result and isinstance(result["message"], dict):
|
||||||
|
result = result["message"]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
||||||
|
|
||||||
|
def get_token_from_request(request) -> str:
|
||||||
|
"""从请求中获取访问令牌"""
|
||||||
|
if not request:
|
||||||
|
raise HTTPException(status_code=400, detail="无法获取请求信息")
|
||||||
|
|
||||||
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
if not auth_header or not auth_header.startswith("token "):
|
||||||
|
raise HTTPException(status_code=401, detail="无效的Authorization头格式")
|
||||||
|
|
||||||
|
token = auth_header[6:]
|
||||||
|
if ":" not in token:
|
||||||
|
raise HTTPException(status_code=401, detail="无效的令牌格式")
|
||||||
|
|
||||||
|
return token
|
||||||
|
|
||||||
|
def jingrow_api_verify_and_billing(api_name: str):
|
||||||
|
"""Jingrow API 验证装饰器(带余额检查和扣费)"""
|
||||||
|
def decorator(func: Callable) -> Callable:
|
||||||
|
@wraps(func)
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
try:
|
||||||
|
request = kwargs.get('request')
|
||||||
|
if not request:
|
||||||
|
raise HTTPException(status_code=400, detail="无法获取请求信息")
|
||||||
|
|
||||||
|
token = get_token_from_request(request)
|
||||||
|
api_key, api_secret = token.split(":", 1)
|
||||||
|
|
||||||
|
verify_result = await verify_api_credentials_and_balance(api_key, api_secret, api_name)
|
||||||
|
if not verify_result.get("success"):
|
||||||
|
raise HTTPException(status_code=401, detail=verify_result.get("message", "验证失败"))
|
||||||
|
|
||||||
|
result = await func(*args, **kwargs)
|
||||||
|
|
||||||
|
usage_count = 1
|
||||||
|
try:
|
||||||
|
body_data = await request.json()
|
||||||
|
if isinstance(body_data, dict):
|
||||||
|
for key in ["items", "urls", "images", "files"]:
|
||||||
|
if key in body_data and isinstance(body_data[key], list):
|
||||||
|
usage_count = len(body_data[key])
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if isinstance(result, StreamingResponse):
|
||||||
|
original_generator = result.body_iterator
|
||||||
|
success_count = 0
|
||||||
|
|
||||||
|
async def wrapped_generator():
|
||||||
|
nonlocal success_count
|
||||||
|
async for chunk in original_generator:
|
||||||
|
try:
|
||||||
|
data = json.loads(chunk)
|
||||||
|
if isinstance(data, dict) and data.get("status") == "success":
|
||||||
|
success_count += 1
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
if success_count > 0:
|
||||||
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
wrapped_generator(),
|
||||||
|
media_type=result.media_type,
|
||||||
|
headers=result.headers
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(result, dict) and result.get("success") is True:
|
||||||
|
actual_usage_count = result.get("successful_count", usage_count)
|
||||||
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
||||||
|
return result
|
||||||
|
|
||||||
|
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"API验证过程发生异常: {str(e)}")
|
||||||
|
return wrapper
|
||||||
|
return decorator
|
||||||
Loading…
x
Reference in New Issue
Block a user