commit 4be051e4592180129987d533c2a1d7ef68bd8160 Author: jingrow Date: Mon May 12 02:39:56 2025 +0800 japi 微服务版 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a58a423 --- /dev/null +++ b/.gitignore @@ -0,0 +1,15 @@ +# 忽略名为 test 的文件夹 +test/ +.cursor/ + + +# 忽略所有 文件夹 +**/www/files/ +**/output/ +**/__pycache__/ + +*.py[cod] + +.env + + diff --git a/apps/add_bg/__init__.py b/apps/add_bg/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/add_bg/api.py b/apps/add_bg/api.py new file mode 100644 index 0000000..d10ecab --- /dev/null +++ b/apps/add_bg/api.py @@ -0,0 +1,80 @@ +from fastapi import APIRouter, UploadFile, File, HTTPException, Request +from fastapi.responses import StreamingResponse, JSONResponse +from service import AddBgService +from utils import jingrow_api_verify_and_billing +from settings import settings +import json +import asyncio + +router = APIRouter(prefix=settings.router_prefix) +service = AddBgService() + +@router.post(settings.batch_route) +@jingrow_api_verify_and_billing(api_name=settings.api_name) +async def add_background_batch(data: dict, request: Request): + """ + 批量处理多个URL图片 + + Args: + data: 包含图片URL列表和配置参数的字典 + request: FastAPI 请求对象 + + Returns: + 流式响应,包含每个图片的处理结果 + """ + if "urls" not in data: + raise HTTPException(status_code=400, detail="缺少urls参数") + + config = data.get("config", {}) + + async def process_and_stream(): + total = len(data["urls"]) + for index, url in enumerate(data["urls"], 1): + try: + result = await service.add_background(url, config) + result.update({ + "index": index, + "total": total, + "original_url": url + }) + yield json.dumps(result) + "\n" + except Exception as e: + yield json.dumps({ + "status": "error", + "message": str(e), + "index": index, + "total": total, + "original_url": url + }) + "\n" + + return StreamingResponse( + process_and_stream(), + media_type="application/x-ndjson" + ) + +@router.post(settings.file_route) +@jingrow_api_verify_and_billing(api_name=settings.api_name) +async def add_background_file(file: UploadFile = File(...), config: str = None, request: Request = None): + """ + 为上传的文件添加背景 + + Args: + file: 上传的图片文件 + config: JSON格式的配置参数 + request: FastAPI 请求对象 + + Returns: + 处理后的图片内容 + """ + content = await file.read() + + # 解析配置参数 + config_dict = {} + if config: + try: + config_dict = json.loads(config) + except json.JSONDecodeError: + raise HTTPException(status_code=400, detail="配置参数格式错误") + + result = await service.add_background_from_file(content, config_dict) + return result diff --git a/apps/add_bg/app.py b/apps/add_bg/app.py new file mode 100644 index 0000000..5168e20 --- /dev/null +++ b/apps/add_bg/app.py @@ -0,0 +1,21 @@ +from fastapi import FastAPI +from settings import settings +from api import router + +app = FastAPI( + title="Add Background", + description="图片添加背景颜色", + version="1.0.0" +) + +# 注册路由 +app.include_router(router) + +if __name__ == "__main__": + import uvicorn + uvicorn.run( + "app:app", + host=settings.host, + port=settings.port, + reload=settings.debug + ) \ No newline at end of file diff --git a/apps/add_bg/service.py b/apps/add_bg/service.py new file mode 100644 index 0000000..2c81065 --- /dev/null +++ b/apps/add_bg/service.py @@ -0,0 +1,311 @@ +import sys +import os +import json +import io +import cv2 +import numpy as np +from PIL import Image, ImageFilter, ImageDraw, ImageChops +import uuid +import urllib.request +import urllib3 +import requests +from pydantic import BaseModel +from typing import Optional +from colorthief import ColorThief +import tempfile +from urllib.parse import urlparse +import torch +import time +import warnings +import gc +import base64 +import asyncio +import multiprocessing as mp +from concurrent.futures import ProcessPoolExecutor +import colorsys + +# 关闭不必要的警告 +warnings.filterwarnings("ignore", category=UserWarning) +warnings.filterwarnings("ignore", category=FutureWarning) + +class AddBgService: + # 默认配置 + DEFAULT_CONFIG = { + 'added_background_marker': "_added_background", + 'enable_texture_effect': False, + 'texture_type': 'noise', + 'texture_blend_mode': 'multiply', + 'enable_depth_of_field': False, + 'blur_intensity': 15, + 'output_format': 'png', + 'enable_lighting_effect': False, + 'light_intensity': 0.1, + 'light_position': [0.5, 0.3], + 'light_radius_ratio': [0.4, 0.25], + 'light_angle': 45, + 'light_blur': 91, + 'light_shape': 'ellipse', + 'alpha_background': 0.8, + 'design_rotation': 0 + } + + def __init__(self): + """初始化添加背景服务""" + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + def apply_lighting_effect(self, image, config): + """应用光照效果""" + light_intensity = config.get('light_intensity', self.DEFAULT_CONFIG['light_intensity']) + light_position = config.get('light_position', self.DEFAULT_CONFIG['light_position']) + light_radius_ratio = config.get('light_radius_ratio', self.DEFAULT_CONFIG['light_radius_ratio']) + light_angle = config.get('light_angle', self.DEFAULT_CONFIG['light_angle']) + light_blur = config.get('light_blur', self.DEFAULT_CONFIG['light_blur']) + light_shape = config.get('light_shape', self.DEFAULT_CONFIG['light_shape']) + + height, width = image.shape[:2] + light_position = (int(light_position[0] * width), int(light_position[1] * height)) + light_radius = (int(light_radius_ratio[0] * width), int(light_radius_ratio[1] * height)) + mask = np.zeros((height, width), dtype=np.uint8) + + if light_shape == 'ellipse': + cv2.ellipse(mask, light_position, light_radius, light_angle, 0, 360, 255, -1) + elif light_shape == 'circle': + cv2.circle(mask, light_position, min(light_radius), 255, -1) + elif light_shape == 'rect': + rect_top_left = (light_position[0] - light_radius[0] // 2, light_position[1] - light_radius[1] // 2) + rect_bottom_right = (light_position[0] + light_radius[0] // 2, light_position[1] + light_radius[1] // 2) + cv2.rectangle(mask, rect_top_left, rect_bottom_right, 255, -1) + + mask = cv2.GaussianBlur(mask, (light_blur, light_blur), 0) + mask = mask.astype(np.float32) / 255 + result = image.astype(np.float32) + for i in range(3): + result[:, :, i] = result[:, :, i] * (1 - light_intensity + mask * light_intensity) + return result.astype(np.uint8) + + def generate_noise_texture(self, size, intensity=64): + """生成噪点纹理""" + noise = np.random.randint(0, intensity, (size, size, 4), dtype=np.uint8) + noise[..., 3] = 255 # 设置 alpha 通道为不透明 + return Image.fromarray(noise) + + def generate_line_texture(self, size, line_width=4, spacing=20, color=(0, 0, 0, 255)): + """生成线条纹理""" + texture = Image.new('RGBA', (size, size), (255, 255, 255, 0)) + draw = ImageDraw.Draw(texture) + for y in range(0, size, spacing): + draw.line([(0, y), (size, y)], fill=color, width=line_width) + for x in range(0, size, spacing): + draw.line([(x, 0), (x, size)], fill=color, width=line_width) + return texture + + def add_texture(self, image, config): + """添加纹理效果""" + texture_type = config.get('texture_type', self.DEFAULT_CONFIG['texture_type']) + texture_blend_mode = config.get('texture_blend_mode', self.DEFAULT_CONFIG['texture_blend_mode']) + + if texture_type == 'noise': + texture = self.generate_noise_texture(image.size[0]) + elif texture_type == 'lines': + texture = self.generate_line_texture(image.size[0]) + else: + return image + + if texture_blend_mode == 'multiply': + return ImageChops.multiply(image, texture) + elif texture_blend_mode == 'overlay': + return ImageChops.overlay(image, texture) + else: + return image + + def calculate_dominant_color(self, image): + """计算图像的主色调""" + try: + # 将PIL Image转换为BytesIO对象 + img_byte_arr = io.BytesIO() + image.save(img_byte_arr, format='PNG') + img_byte_arr.seek(0) + + # 使用ColorThief计算主色调 + color_thief = ColorThief(img_byte_arr) + dominant_color = color_thief.get_color(quality=1) + return dominant_color + except Exception as e: + print(f"计算主色调失败: {str(e)}") + # 如果计算失败,返回默认的白色 + return (255, 255, 255) + + def rgb_to_hex(self, rgb): + """RGB转HEX""" + return '#{:02x}{:02x}{:02x}'.format(rgb[0], rgb[1], rgb[2]) + + def calculate_light_color(self, dominant_color, target_lightness=0.92): + r, g, b = [x / 255.0 for x in dominant_color] + h, l, s = colorsys.rgb_to_hls(r, g, b) + l = target_lightness + r2, g2, b2 = colorsys.hls_to_rgb(h, l, s) + return (int(r2 * 255), int(g2 * 255), int(b2 * 255)) + + def calculate_monochrome_color(self, dominant_color, alpha): + # alpha参数可忽略 + return self.calculate_light_color(dominant_color, target_lightness=0.92) + + def apply_depth_of_field(self, background_image, blur_intensity): + """应用景深效果""" + background_image_pil = Image.fromarray(background_image).convert("RGBA") + blurred_background = background_image_pil.filter(ImageFilter.GaussianBlur(blur_intensity)) + return np.array(blurred_background) + + def rotate_image_with_transparency(self, image, angle): + """旋转图像""" + rotated_image = image.rotate(angle, expand=True) + return rotated_image + + def process_image(self, image, config): + """处理图像,添加背景""" + try: + # 合并默认配置和用户配置 + config = {**self.DEFAULT_CONFIG, **config} + + # 计算主色并设置背景颜色 + try: + dominant_color = self.calculate_dominant_color(image) + background_color = self.calculate_monochrome_color(dominant_color, config['alpha_background']) + except Exception as e: + # 使用默认的白色背景 + background_color = (255, 255, 255) + + # 创建背景图像 + background_image = np.full((image.height, image.width, 4), (*background_color, 255), dtype=np.uint8) + + # 应用景深效果 + if config['enable_depth_of_field']: + background_image = self.apply_depth_of_field(background_image, config['blur_intensity']) + + # 应用纹理效果 + if config['enable_texture_effect']: + background_image_pil = Image.fromarray(background_image).convert("RGBA") + background_image_pil = self.add_texture(background_image_pil, config) + background_image = np.array(background_image_pil) + + # 将前景图像转换为numpy数组 + foreground = np.array(image) + + # 旋转前景图像 + if config['design_rotation'] != 0: + foreground_pil = Image.fromarray(foreground) + foreground_pil = self.rotate_image_with_transparency(foreground_pil, config['design_rotation']) + foreground = np.array(foreground_pil) + + # 合并前景和背景 + alpha = foreground[:, :, 3] / 255.0 + for c in range(3): + background_image[:, :, c] = background_image[:, :, c] * (1 - alpha) + foreground[:, :, c] * alpha + + # 确保最终图像不透明 + background_image[:, :, 3] = 255 + + # 应用光照效果 + if config['enable_lighting_effect']: + background_image = self.apply_lighting_effect(background_image, config) + + # 转换回PIL图像 + return Image.fromarray(background_image) + + except Exception as e: + raise Exception(f"处理图片失败: {str(e)}") + + def image_to_base64(self, image, config): + """将图片转换为base64格式""" + try: + output_format = config.get('output_format', self.DEFAULT_CONFIG['output_format']) + buffered = io.BytesIO() + image.save(buffered, format=output_format.upper()) + img_str = base64.b64encode(buffered.getvalue()).decode() + return f"data:image/{output_format.lower()};base64,{img_str}" + except Exception as e: + raise Exception(f"转换图片为base64失败: {str(e)}") + + async def add_background(self, image_path, config=None): + """为图片添加背景""" + try: + # 下载图片 + if self.is_valid_url(image_path): + image_content = self.download_image(image_path) + image = Image.open(io.BytesIO(image_content)) + else: + image = Image.open(image_path) + + # 确保图片是RGBA模式 + if image.mode != 'RGBA': + image = image.convert('RGBA') + + # 处理图片 + processed_image = self.process_image(image, config or {}) + + # 转换为base64 + result = self.image_to_base64(processed_image, config or {}) + + return { + 'status': 'success', + 'image_content': result + } + + except Exception as e: + return { + 'status': 'error', + 'message': f"处理图片失败: {str(e)}" + } + + async def add_background_from_file(self, file_content, config=None): + """ + 从上传的文件内容添加背景 + + Args: + file_content: 上传的文件内容 + config: 配置参数 + + Returns: + 处理后的图像内容 + """ + if config is None: + config = {} + + try: + # 从文件内容创建PIL Image对象 + image = Image.open(io.BytesIO(file_content)).convert("RGBA") + image_with_bg = self.process_image(image, config) + + # 转换为base64 + image_content = self.image_to_base64(image_with_bg, config) + + return { + "status": "success", + "image_content": image_content + } + + except Exception as e: + raise Exception(f"处理图片失败: {e}") + + def is_valid_url(self, url): + """验证URL是否有效""" + try: + result = urlparse(url) + return all([result.scheme, result.netloc]) + except: + return False + + def download_image(self, url): + """下载图片并返回内容""" + try: + response = requests.get(url, timeout=30) + response.raise_for_status() + return response.content + except Exception as e: + raise Exception(f"下载图片失败: {str(e)}") + + def cleanup(self): + """清理资源""" + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() diff --git a/apps/add_bg/settings.py b/apps/add_bg/settings.py new file mode 100644 index 0000000..500aa51 --- /dev/null +++ b/apps/add_bg/settings.py @@ -0,0 +1,32 @@ +from pydantic_settings import BaseSettings +from typing import Optional +from functools import lru_cache + +class Settings(BaseSettings): + # Japi Server 配置 + host: str = "0.0.0.0" + port: int = 8105 + debug: bool = False + + # API路由配置 + router_prefix: str = "/add_bg" + file_route: str = "/file" + batch_route: str = "/batch" + api_name: str = "add_background" + + upload_url: str = "http://173.255.202.68/imgurl/upload" + + # Jingrow Jcloud API 配置 + jingrow_api_url: str = "https://cloud.jingrow.com" + jingrow_api_key: Optional[str] = None + jingrow_api_secret: Optional[str] = None + + class Config: + env_file = ".env" + +@lru_cache() +def get_settings() -> Settings: + return Settings() + +# 创建全局配置实例 +settings = get_settings() \ No newline at end of file diff --git a/apps/add_bg/utils.py b/apps/add_bg/utils.py new file mode 100644 index 0000000..8592041 --- /dev/null +++ b/apps/add_bg/utils.py @@ -0,0 +1,146 @@ +import aiohttp +from functools import wraps +from fastapi import HTTPException +import os +from typing import Callable, Any, Dict, Optional, Tuple +from settings import settings +from fastapi.responses import StreamingResponse +import json + +async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]: + """验证API密钥和团队余额""" + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f"{settings.jingrow_api_url}/api/method/jcloud.api.account.verify_api_credentials_and_balance", + headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"}, + json={"api_key": api_key, "api_secret": api_secret, "api_name": api_name} + ) as response: + if response.status != 200: + raise HTTPException(status_code=500, detail="验证服务暂时不可用") + + result = await response.json() + if "message" in result and isinstance(result["message"], dict): + result = result["message"] + + if not result.get("success"): + raise HTTPException(status_code=401, detail=result.get("message", "验证失败")) + + return result + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}") + +async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]: + """从Jingrow平台扣除API使用费""" + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f"{settings.jingrow_api_url}/api/method/jcloud.api.account.deduct_api_usage_fee", + headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"}, + json={ + "api_key": api_key, + "api_secret": api_secret, + "api_name": api_name, + "usage_count": usage_count + } + ) as response: + if response.status != 200: + raise HTTPException(status_code=500, detail="扣费服务暂时不可用") + + result = await response.json() + if "message" in result and isinstance(result["message"], dict): + result = result["message"] + + return result + + except HTTPException: + raise + except Exception as e: + return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"} + +def get_token_from_request(request) -> str: + """从请求中获取访问令牌""" + if not request: + raise HTTPException(status_code=400, detail="无法获取请求信息") + + auth_header = request.headers.get("Authorization", "") + if not auth_header or not auth_header.startswith("token "): + raise HTTPException(status_code=401, detail="无效的Authorization头格式") + + token = auth_header[6:] + if ":" not in token: + raise HTTPException(status_code=401, detail="无效的令牌格式") + + return token + +def jingrow_api_verify_and_billing(api_name: str): + """Jingrow API 验证装饰器(带余额检查和扣费)""" + def decorator(func: Callable) -> Callable: + @wraps(func) + async def wrapper(*args, **kwargs): + try: + request = kwargs.get('request') + if not request: + raise HTTPException(status_code=400, detail="无法获取请求信息") + + token = get_token_from_request(request) + api_key, api_secret = token.split(":", 1) + + verify_result = await verify_api_credentials_and_balance(api_key, api_secret, api_name) + if not verify_result.get("success"): + raise HTTPException(status_code=401, detail=verify_result.get("message", "验证失败")) + + result = await func(*args, **kwargs) + + usage_count = 1 + try: + body_data = await request.json() + if isinstance(body_data, dict): + for key in ["items", "urls", "images", "files"]: + if key in body_data and isinstance(body_data[key], list): + usage_count = len(body_data[key]) + break + except Exception: + pass + + if isinstance(result, StreamingResponse): + original_generator = result.body_iterator + success_count = 0 + + async def wrapped_generator(): + nonlocal success_count + async for chunk in original_generator: + try: + data = json.loads(chunk) + if isinstance(data, dict) and data.get("status") == "success": + success_count += 1 + except: + pass + yield chunk + + if success_count > 0: + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count) + + return StreamingResponse( + wrapped_generator(), + media_type=result.media_type, + headers=result.headers + ) + + if isinstance(result, dict) and result.get("success") is True: + actual_usage_count = result.get("successful_count", usage_count) + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count) + return result + + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count) + return result + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"API验证过程发生异常: {str(e)}") + return wrapper + return decorator diff --git a/apps/jart/__init__.py b/apps/jart/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/jart/api.py b/apps/jart/api.py new file mode 100644 index 0000000..83548f4 --- /dev/null +++ b/apps/jart/api.py @@ -0,0 +1,57 @@ +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import StreamingResponse +from service import TxtImgService +from utils import jingrow_api_verify_and_billing +from settings import settings +import json +import asyncio + +router = APIRouter(prefix=settings.router_prefix) +service = TxtImgService() + +@router.post(settings.generate_route) +@jingrow_api_verify_and_billing(api_name=settings.api_name) +async def generate_image(data: dict, request: Request): + """ + 根据文本提示生成图像 + + Args: + data: 包含文本提示和配置参数的字典 + request: FastAPI 请求对象 + + Returns: + 生成的图像内容 + """ + if "prompt" not in data: + raise HTTPException(status_code=400, detail="缺少prompt参数") + + config = data.get("config", {}) + result = await service.generate_image(data["prompt"], config) + return result + +@router.post(settings.batch_route) +@jingrow_api_verify_and_billing(api_name=settings.api_name) +async def generate_image_batch(data: dict, request: Request): + """ + 批量处理多个文本提示 + + Args: + data: 包含文本提示列表和配置参数的字典 + request: FastAPI 请求对象 + + Returns: + 流式响应,包含每个提示的处理结果 + """ + if "prompts" not in data: + raise HTTPException(status_code=400, detail="缺少prompts参数") + + config = data.get("config", {}) + + async def process_and_stream(): + async for result in service.process_batch(data["prompts"], config): + yield json.dumps(result) + "\n" + + return StreamingResponse( + process_and_stream(), + media_type="application/x-ndjson" + ) diff --git a/apps/jart/app.py b/apps/jart/app.py new file mode 100644 index 0000000..fa973b2 --- /dev/null +++ b/apps/jart/app.py @@ -0,0 +1,21 @@ +from fastapi import FastAPI +from settings import settings +from api import router + +app = FastAPI( + title="JArt", + description="JArt绘画服务API", + version="1.0.0" +) + +# 注册路由 +app.include_router(router) + +if __name__ == "__main__": + import uvicorn + uvicorn.run( + "app:app", + host=settings.host, + port=settings.port, + reload=settings.debug + ) \ No newline at end of file diff --git a/apps/jart/service.py b/apps/jart/service.py new file mode 100644 index 0000000..8f92c98 --- /dev/null +++ b/apps/jart/service.py @@ -0,0 +1,272 @@ +import json +import base64 +import requests +import random +import websocket +import uuid +import urllib.request +import asyncio +import io +from typing import Dict, List, Generator, Optional, AsyncGenerator + +# 固定配置变量 +DEFAULT_CONFIG = { + "comfyui_server_address": "192.168.2.200:8188", + "ckpt_name": "flux1-schnell-fp8.safetensors", + "sampler_name": "euler", + "scheduler": "normal", + "steps": 4, + "cfg": 1, + "denoise": 1.0, + "images_per_prompt": 1, + "image_width": 1024, + "image_height": 1024, + "negative_prompt": "blur, low quality, low resolution, artifacts, text, watermark, underexposed, bad anatomy, deformed body, extra limbs, missing limbs, noisy background, cluttered background, blurry background" +} + +# 定义基础工作流 JSON 模板 +WORKFLOW_TEMPLATE = """ +{ + "3": { + "class_type": "KSampler", + "inputs": { + "cfg": %d, + "denoise": %d, + "latent_image": [ + "5", + 0 + ], + "model": [ + "4", + 0 + ], + "negative": [ + "7", + 0 + ], + "positive": [ + "6", + 0 + ], + "sampler_name": "%s", + "scheduler": "%s", + "seed": 8566257, + "steps": %d + } + }, + "4": { + "class_type": "CheckpointLoaderSimple", + "inputs": { + "ckpt_name": "%s" + } + }, + "5": { + "class_type": "EmptyLatentImage", + "inputs": { + "batch_size": 1, + "height": %d, + "width": %d + } + }, + "6": { + "class_type": "CLIPTextEncode", + "inputs": { + "clip": [ + "4", + 1 + ], + "text": "masterpiece best quality girl" + } + }, + "7": { + "class_type": "CLIPTextEncode", + "inputs": { + "clip": [ + "4", + 1 + ], + "text": "%s" + } + }, + "8": { + "class_type": "VAEDecode", + "inputs": { + "samples": [ + "3", + 0 + ], + "vae": [ + "4", + 2 + ] + } + }, + "save_image_websocket_node": { + "class_type": "SaveImageWebsocket", + "inputs": { + "images": [ + "8", + 0 + ] + } + } +} +""" + +class TxtImgService: + def __init__(self): + """初始化文本生成图像服务""" + pass + + def queue_prompt(self, prompt: Dict, comfyui_server_address: str, client_id: str) -> Dict: + """将提示词发送到 ComfyUI 服务器的队列中""" + p = {"prompt": prompt, "client_id": client_id} + data = json.dumps(p).encode('utf-8') + req = urllib.request.Request(f"http://{comfyui_server_address}/prompt", data=data) + response = json.loads(urllib.request.urlopen(req).read()) + return response + + def get_images(self, ws: websocket.WebSocket, workflow: Dict, comfyui_server_address: str, client_id: str) -> Dict: + """从 ComfyUI 获取生成的图像""" + try: + prompt_response = self.queue_prompt(workflow, comfyui_server_address, client_id) + prompt_id = prompt_response['prompt_id'] + except KeyError: + return {} + + output_images = {} + current_node = "" + while True: + out = ws.recv() + if isinstance(out, str): + message = json.loads(out) + if message['type'] == 'executing': + data = message['data'] + if data.get('prompt_id') == prompt_id: + if data['node'] is None: + break + else: + current_node = data['node'] + else: + if current_node == 'save_image_websocket_node': + images_output = output_images.get(current_node, []) + images_output.append(out[8:]) + output_images[current_node] = images_output + + return output_images + + def generate_image_sync(self, prompt: str, config: Optional[Dict] = None) -> Generator[str, None, None]: + """生成 Flux 模型的图片,流式返回 base64 编码的图片""" + cfg = DEFAULT_CONFIG.copy() + if config: + cfg.update(config) + + ws = websocket.WebSocket() + client_id = str(uuid.uuid4()) + + try: + ws.connect(f"ws://{cfg['comfyui_server_address']}/ws?clientId={client_id}") + images_count = int(cfg.get('images_per_prompt', 1)) + + for i in range(images_count): + workflow = json.loads(WORKFLOW_TEMPLATE % ( + cfg['cfg'], + cfg['denoise'], + cfg['sampler_name'], + cfg['scheduler'], + cfg['steps'], + cfg['ckpt_name'], + cfg['image_height'], + cfg['image_width'], + cfg['negative_prompt'] + )) + + workflow["6"]["inputs"]["text"] = prompt + seed = random.randint(1, 4294967295) + workflow["3"]["inputs"]["seed"] = seed + + images_dict = self.get_images(ws, workflow, cfg['comfyui_server_address'], client_id) + + for node_id, image_list in images_dict.items(): + for image_data in image_list: + base64_image = base64.b64encode(image_data).decode('utf-8') + yield base64_image + + except Exception as e: + raise e + + finally: + if ws: + ws.close() + + async def generate_image(self, prompt: str, config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]: + """异步生成图像,流式返回结果""" + try: + loop = asyncio.get_event_loop() + + def sync_generator(): + for base64_image in self.generate_image_sync(prompt, config): + yield base64_image + + generator = await loop.run_in_executor(None, sync_generator) + + for base64_image in generator: + yield { + "status": "success", + "image": f"data:image/png;base64,{base64_image}", + "message": f"成功生成图片" + } + + except Exception as e: + yield { + "status": "error", + "message": f"图像生成失败: {str(e)}" + } + + async def process_batch(self, prompts: List[str], config: Optional[Dict] = None): + """批量处理多个文本提示,流式返回结果""" + total = len(prompts) + success_count = 0 + error_count = 0 + + for i, prompt in enumerate(prompts, 1): + try: + async for result in self.generate_image(prompt, config): + if result["status"] == "success": + success_count += 1 + yield { + "index": i, + "total": total, + "original_prompt": prompt, + "status": "success", + "image_content": result["image"], + "success_count": success_count, + "error_count": error_count, + "message": result["message"] + } + else: + error_count += 1 + yield { + "index": i, + "total": total, + "original_prompt": prompt, + "status": "error", + "success_count": success_count, + "error_count": error_count, + "message": result["message"] + } + + except Exception as e: + error_count += 1 + yield { + "index": i, + "total": total, + "original_prompt": prompt, + "status": "error", + "error": str(e), + "success_count": success_count, + "error_count": error_count, + "message": f"处理失败: {str(e)}" + } + + await asyncio.sleep(0) \ No newline at end of file diff --git a/apps/jart/settings.py b/apps/jart/settings.py new file mode 100644 index 0000000..d59f308 --- /dev/null +++ b/apps/jart/settings.py @@ -0,0 +1,35 @@ +from pydantic_settings import BaseSettings +from typing import Optional +from functools import lru_cache + +class Settings(BaseSettings): + # Japi Server 配置 + host: str = "0.0.0.0" + port: int = 8102 + debug: bool = False + + # API路由配置 + router_prefix: str = "/jart" + generate_route: str = "/generate" # 生成图片的路由 + batch_route: str = "/batch" # 批量生成图片的路由 + api_name: str = "jart" # 默认API名称 + + upload_url: str = "http://173.255.202.68/imgurl/upload" + + # Jingrow Jcloud API 配置 + jingrow_api_url: str = "https://cloud.jingrow.com" + jingrow_api_key: Optional[str] = None + jingrow_api_secret: Optional[str] = None + + # Stable Diffusion配置 + comfyui_server_address: str = "comfyui.jingrow.com:8188" + + class Config: + env_file = ".env" + +@lru_cache() +def get_settings() -> Settings: + return Settings() + +# 创建全局配置实例 +settings = get_settings() \ No newline at end of file diff --git a/apps/jart/utils.py b/apps/jart/utils.py new file mode 100644 index 0000000..8592041 --- /dev/null +++ b/apps/jart/utils.py @@ -0,0 +1,146 @@ +import aiohttp +from functools import wraps +from fastapi import HTTPException +import os +from typing import Callable, Any, Dict, Optional, Tuple +from settings import settings +from fastapi.responses import StreamingResponse +import json + +async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]: + """验证API密钥和团队余额""" + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f"{settings.jingrow_api_url}/api/method/jcloud.api.account.verify_api_credentials_and_balance", + headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"}, + json={"api_key": api_key, "api_secret": api_secret, "api_name": api_name} + ) as response: + if response.status != 200: + raise HTTPException(status_code=500, detail="验证服务暂时不可用") + + result = await response.json() + if "message" in result and isinstance(result["message"], dict): + result = result["message"] + + if not result.get("success"): + raise HTTPException(status_code=401, detail=result.get("message", "验证失败")) + + return result + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}") + +async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]: + """从Jingrow平台扣除API使用费""" + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f"{settings.jingrow_api_url}/api/method/jcloud.api.account.deduct_api_usage_fee", + headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"}, + json={ + "api_key": api_key, + "api_secret": api_secret, + "api_name": api_name, + "usage_count": usage_count + } + ) as response: + if response.status != 200: + raise HTTPException(status_code=500, detail="扣费服务暂时不可用") + + result = await response.json() + if "message" in result and isinstance(result["message"], dict): + result = result["message"] + + return result + + except HTTPException: + raise + except Exception as e: + return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"} + +def get_token_from_request(request) -> str: + """从请求中获取访问令牌""" + if not request: + raise HTTPException(status_code=400, detail="无法获取请求信息") + + auth_header = request.headers.get("Authorization", "") + if not auth_header or not auth_header.startswith("token "): + raise HTTPException(status_code=401, detail="无效的Authorization头格式") + + token = auth_header[6:] + if ":" not in token: + raise HTTPException(status_code=401, detail="无效的令牌格式") + + return token + +def jingrow_api_verify_and_billing(api_name: str): + """Jingrow API 验证装饰器(带余额检查和扣费)""" + def decorator(func: Callable) -> Callable: + @wraps(func) + async def wrapper(*args, **kwargs): + try: + request = kwargs.get('request') + if not request: + raise HTTPException(status_code=400, detail="无法获取请求信息") + + token = get_token_from_request(request) + api_key, api_secret = token.split(":", 1) + + verify_result = await verify_api_credentials_and_balance(api_key, api_secret, api_name) + if not verify_result.get("success"): + raise HTTPException(status_code=401, detail=verify_result.get("message", "验证失败")) + + result = await func(*args, **kwargs) + + usage_count = 1 + try: + body_data = await request.json() + if isinstance(body_data, dict): + for key in ["items", "urls", "images", "files"]: + if key in body_data and isinstance(body_data[key], list): + usage_count = len(body_data[key]) + break + except Exception: + pass + + if isinstance(result, StreamingResponse): + original_generator = result.body_iterator + success_count = 0 + + async def wrapped_generator(): + nonlocal success_count + async for chunk in original_generator: + try: + data = json.loads(chunk) + if isinstance(data, dict) and data.get("status") == "success": + success_count += 1 + except: + pass + yield chunk + + if success_count > 0: + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count) + + return StreamingResponse( + wrapped_generator(), + media_type=result.media_type, + headers=result.headers + ) + + if isinstance(result, dict) and result.get("success") is True: + actual_usage_count = result.get("successful_count", usage_count) + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count) + return result + + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count) + return result + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"API验证过程发生异常: {str(e)}") + return wrapper + return decorator diff --git a/apps/jart_v1/__init__.py b/apps/jart_v1/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/jart_v1/api.py b/apps/jart_v1/api.py new file mode 100644 index 0000000..83548f4 --- /dev/null +++ b/apps/jart_v1/api.py @@ -0,0 +1,57 @@ +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import StreamingResponse +from service import TxtImgService +from utils import jingrow_api_verify_and_billing +from settings import settings +import json +import asyncio + +router = APIRouter(prefix=settings.router_prefix) +service = TxtImgService() + +@router.post(settings.generate_route) +@jingrow_api_verify_and_billing(api_name=settings.api_name) +async def generate_image(data: dict, request: Request): + """ + 根据文本提示生成图像 + + Args: + data: 包含文本提示和配置参数的字典 + request: FastAPI 请求对象 + + Returns: + 生成的图像内容 + """ + if "prompt" not in data: + raise HTTPException(status_code=400, detail="缺少prompt参数") + + config = data.get("config", {}) + result = await service.generate_image(data["prompt"], config) + return result + +@router.post(settings.batch_route) +@jingrow_api_verify_and_billing(api_name=settings.api_name) +async def generate_image_batch(data: dict, request: Request): + """ + 批量处理多个文本提示 + + Args: + data: 包含文本提示列表和配置参数的字典 + request: FastAPI 请求对象 + + Returns: + 流式响应,包含每个提示的处理结果 + """ + if "prompts" not in data: + raise HTTPException(status_code=400, detail="缺少prompts参数") + + config = data.get("config", {}) + + async def process_and_stream(): + async for result in service.process_batch(data["prompts"], config): + yield json.dumps(result) + "\n" + + return StreamingResponse( + process_and_stream(), + media_type="application/x-ndjson" + ) diff --git a/apps/jart_v1/app.py b/apps/jart_v1/app.py new file mode 100644 index 0000000..9574ad6 --- /dev/null +++ b/apps/jart_v1/app.py @@ -0,0 +1,21 @@ +from fastapi import FastAPI +from settings import settings +from api import router + +app = FastAPI( + title="JArt V1", + description="JArt绘画服务API", + version="1.0.0" +) + +# 注册路由 +app.include_router(router) + +if __name__ == "__main__": + import uvicorn + uvicorn.run( + "app:app", + host=settings.host, + port=settings.port, + reload=settings.debug + ) \ No newline at end of file diff --git a/apps/jart_v1/service.py b/apps/jart_v1/service.py new file mode 100644 index 0000000..a3cbb32 --- /dev/null +++ b/apps/jart_v1/service.py @@ -0,0 +1,272 @@ +import json +import base64 +import requests +import random +import websocket +import uuid +import urllib.request +import asyncio +import io +from typing import Dict, List, Generator, Optional, AsyncGenerator + +# 固定配置变量 +DEFAULT_CONFIG = { + "comfyui_server_address": "192.168.2.200:8188", + "ckpt_name": "sd3_medium_incl_clips_t5xxlfp8.safetensors", + "sampler_name": "euler", + "scheduler": "normal", + "steps": 20, + "cfg": 8, + "denoise": 1.0, + "images_per_prompt": 1, + "image_width": 1024, + "image_height": 1024, + "negative_prompt": "blur, low quality, low resolution, artifacts, text, watermark, underexposed, bad anatomy, deformed body, extra limbs, missing limbs, noisy background, cluttered background, blurry background" +} + +# 定义基础工作流 JSON 模板 +WORKFLOW_TEMPLATE = """ +{ + "3": { + "class_type": "KSampler", + "inputs": { + "cfg": %d, + "denoise": %d, + "latent_image": [ + "5", + 0 + ], + "model": [ + "4", + 0 + ], + "negative": [ + "7", + 0 + ], + "positive": [ + "6", + 0 + ], + "sampler_name": "%s", + "scheduler": "%s", + "seed": 8566257, + "steps": %d + } + }, + "4": { + "class_type": "CheckpointLoaderSimple", + "inputs": { + "ckpt_name": "%s" + } + }, + "5": { + "class_type": "EmptyLatentImage", + "inputs": { + "batch_size": 1, + "height": %d, + "width": %d + } + }, + "6": { + "class_type": "CLIPTextEncode", + "inputs": { + "clip": [ + "4", + 1 + ], + "text": "masterpiece best quality girl" + } + }, + "7": { + "class_type": "CLIPTextEncode", + "inputs": { + "clip": [ + "4", + 1 + ], + "text": "%s" + } + }, + "8": { + "class_type": "VAEDecode", + "inputs": { + "samples": [ + "3", + 0 + ], + "vae": [ + "4", + 2 + ] + } + }, + "save_image_websocket_node": { + "class_type": "SaveImageWebsocket", + "inputs": { + "images": [ + "8", + 0 + ] + } + } +} +""" + +class TxtImgService: + def __init__(self): + """初始化文本生成图像服务""" + pass + + def queue_prompt(self, prompt: Dict, comfyui_server_address: str, client_id: str) -> Dict: + """将提示词发送到 ComfyUI 服务器的队列中""" + p = {"prompt": prompt, "client_id": client_id} + data = json.dumps(p).encode('utf-8') + req = urllib.request.Request(f"http://{comfyui_server_address}/prompt", data=data) + response = json.loads(urllib.request.urlopen(req).read()) + return response + + def get_images(self, ws: websocket.WebSocket, workflow: Dict, comfyui_server_address: str, client_id: str) -> Dict: + """从 ComfyUI 获取生成的图像""" + try: + prompt_response = self.queue_prompt(workflow, comfyui_server_address, client_id) + prompt_id = prompt_response['prompt_id'] + except KeyError: + return {} + + output_images = {} + current_node = "" + while True: + out = ws.recv() + if isinstance(out, str): + message = json.loads(out) + if message['type'] == 'executing': + data = message['data'] + if data.get('prompt_id') == prompt_id: + if data['node'] is None: + break + else: + current_node = data['node'] + else: + if current_node == 'save_image_websocket_node': + images_output = output_images.get(current_node, []) + images_output.append(out[8:]) + output_images[current_node] = images_output + + return output_images + + def generate_image_sync(self, prompt: str, config: Optional[Dict] = None) -> Generator[str, None, None]: + """生成 Flux 模型的图片,流式返回 base64 编码的图片""" + cfg = DEFAULT_CONFIG.copy() + if config: + cfg.update(config) + + ws = websocket.WebSocket() + client_id = str(uuid.uuid4()) + + try: + ws.connect(f"ws://{cfg['comfyui_server_address']}/ws?clientId={client_id}") + images_count = int(cfg.get('images_per_prompt', 1)) + + for i in range(images_count): + workflow = json.loads(WORKFLOW_TEMPLATE % ( + cfg['cfg'], + cfg['denoise'], + cfg['sampler_name'], + cfg['scheduler'], + cfg['steps'], + cfg['ckpt_name'], + cfg['image_height'], + cfg['image_width'], + cfg['negative_prompt'] + )) + + workflow["6"]["inputs"]["text"] = prompt + seed = random.randint(1, 4294967295) + workflow["3"]["inputs"]["seed"] = seed + + images_dict = self.get_images(ws, workflow, cfg['comfyui_server_address'], client_id) + + for node_id, image_list in images_dict.items(): + for image_data in image_list: + base64_image = base64.b64encode(image_data).decode('utf-8') + yield base64_image + + except Exception as e: + raise e + + finally: + if ws: + ws.close() + + async def generate_image(self, prompt: str, config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]: + """异步生成图像,流式返回结果""" + try: + loop = asyncio.get_event_loop() + + def sync_generator(): + for base64_image in self.generate_image_sync(prompt, config): + yield base64_image + + generator = await loop.run_in_executor(None, sync_generator) + + for base64_image in generator: + yield { + "status": "success", + "image": f"data:image/png;base64,{base64_image}", + "message": f"成功生成图片" + } + + except Exception as e: + yield { + "status": "error", + "message": f"图像生成失败: {str(e)}" + } + + async def process_batch(self, prompts: List[str], config: Optional[Dict] = None): + """批量处理多个文本提示,流式返回结果""" + total = len(prompts) + success_count = 0 + error_count = 0 + + for i, prompt in enumerate(prompts, 1): + try: + async for result in self.generate_image(prompt, config): + if result["status"] == "success": + success_count += 1 + yield { + "index": i, + "total": total, + "original_prompt": prompt, + "status": "success", + "image_content": result["image"], + "success_count": success_count, + "error_count": error_count, + "message": result["message"] + } + else: + error_count += 1 + yield { + "index": i, + "total": total, + "original_prompt": prompt, + "status": "error", + "success_count": success_count, + "error_count": error_count, + "message": result["message"] + } + + except Exception as e: + error_count += 1 + yield { + "index": i, + "total": total, + "original_prompt": prompt, + "status": "error", + "error": str(e), + "success_count": success_count, + "error_count": error_count, + "message": f"处理失败: {str(e)}" + } + + await asyncio.sleep(0) \ No newline at end of file diff --git a/apps/jart_v1/settings.py b/apps/jart_v1/settings.py new file mode 100644 index 0000000..5154f0a --- /dev/null +++ b/apps/jart_v1/settings.py @@ -0,0 +1,35 @@ +from pydantic_settings import BaseSettings +from typing import Optional +from functools import lru_cache + +class Settings(BaseSettings): + # Japi Server 配置 + host: str = "0.0.0.0" + port: int = 8103 + debug: bool = False + + # API路由配置 + router_prefix: str = "/jart_v1" + generate_route: str = "/generate" # 生成图片的路由 + batch_route: str = "/batch" # 批量生成图片的路由 + api_name: str = "jart_v1" # 默认API名称 + + upload_url: str = "http://173.255.202.68/imgurl/upload" + + # Jingrow Jcloud API 配置 + jingrow_api_url: str = "https://cloud.jingrow.com" + jingrow_api_key: Optional[str] = None + jingrow_api_secret: Optional[str] = None + + # Stable Diffusion配置 + comfyui_server_address: str = "comfyui.jingrow.com:8188" + + class Config: + env_file = ".env" + +@lru_cache() +def get_settings() -> Settings: + return Settings() + +# 创建全局配置实例 +settings = get_settings() \ No newline at end of file diff --git a/apps/jart_v1/utils.py b/apps/jart_v1/utils.py new file mode 100644 index 0000000..8592041 --- /dev/null +++ b/apps/jart_v1/utils.py @@ -0,0 +1,146 @@ +import aiohttp +from functools import wraps +from fastapi import HTTPException +import os +from typing import Callable, Any, Dict, Optional, Tuple +from settings import settings +from fastapi.responses import StreamingResponse +import json + +async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]: + """验证API密钥和团队余额""" + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f"{settings.jingrow_api_url}/api/method/jcloud.api.account.verify_api_credentials_and_balance", + headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"}, + json={"api_key": api_key, "api_secret": api_secret, "api_name": api_name} + ) as response: + if response.status != 200: + raise HTTPException(status_code=500, detail="验证服务暂时不可用") + + result = await response.json() + if "message" in result and isinstance(result["message"], dict): + result = result["message"] + + if not result.get("success"): + raise HTTPException(status_code=401, detail=result.get("message", "验证失败")) + + return result + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}") + +async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]: + """从Jingrow平台扣除API使用费""" + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f"{settings.jingrow_api_url}/api/method/jcloud.api.account.deduct_api_usage_fee", + headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"}, + json={ + "api_key": api_key, + "api_secret": api_secret, + "api_name": api_name, + "usage_count": usage_count + } + ) as response: + if response.status != 200: + raise HTTPException(status_code=500, detail="扣费服务暂时不可用") + + result = await response.json() + if "message" in result and isinstance(result["message"], dict): + result = result["message"] + + return result + + except HTTPException: + raise + except Exception as e: + return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"} + +def get_token_from_request(request) -> str: + """从请求中获取访问令牌""" + if not request: + raise HTTPException(status_code=400, detail="无法获取请求信息") + + auth_header = request.headers.get("Authorization", "") + if not auth_header or not auth_header.startswith("token "): + raise HTTPException(status_code=401, detail="无效的Authorization头格式") + + token = auth_header[6:] + if ":" not in token: + raise HTTPException(status_code=401, detail="无效的令牌格式") + + return token + +def jingrow_api_verify_and_billing(api_name: str): + """Jingrow API 验证装饰器(带余额检查和扣费)""" + def decorator(func: Callable) -> Callable: + @wraps(func) + async def wrapper(*args, **kwargs): + try: + request = kwargs.get('request') + if not request: + raise HTTPException(status_code=400, detail="无法获取请求信息") + + token = get_token_from_request(request) + api_key, api_secret = token.split(":", 1) + + verify_result = await verify_api_credentials_and_balance(api_key, api_secret, api_name) + if not verify_result.get("success"): + raise HTTPException(status_code=401, detail=verify_result.get("message", "验证失败")) + + result = await func(*args, **kwargs) + + usage_count = 1 + try: + body_data = await request.json() + if isinstance(body_data, dict): + for key in ["items", "urls", "images", "files"]: + if key in body_data and isinstance(body_data[key], list): + usage_count = len(body_data[key]) + break + except Exception: + pass + + if isinstance(result, StreamingResponse): + original_generator = result.body_iterator + success_count = 0 + + async def wrapped_generator(): + nonlocal success_count + async for chunk in original_generator: + try: + data = json.loads(chunk) + if isinstance(data, dict) and data.get("status") == "success": + success_count += 1 + except: + pass + yield chunk + + if success_count > 0: + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count) + + return StreamingResponse( + wrapped_generator(), + media_type=result.media_type, + headers=result.headers + ) + + if isinstance(result, dict) and result.get("success") is True: + actual_usage_count = result.get("successful_count", usage_count) + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count) + return result + + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count) + return result + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"API验证过程发生异常: {str(e)}") + return wrapper + return decorator diff --git a/apps/jchat/__init__.py b/apps/jchat/__init__.py new file mode 100644 index 0000000..c0e6f31 --- /dev/null +++ b/apps/jchat/__init__.py @@ -0,0 +1 @@ +# 使jchat目录成为Python包 \ No newline at end of file diff --git a/apps/jchat/api.py b/apps/jchat/api.py new file mode 100644 index 0000000..3c4b671 --- /dev/null +++ b/apps/jchat/api.py @@ -0,0 +1,59 @@ +from fastapi import APIRouter, HTTPException, Request +from service import ChatService +from utils import jingrow_api_verify_and_billing +from settings import settings +import json +from functools import wraps + +router = APIRouter(prefix=settings.router_prefix) +service = ChatService() + +def dynamic_billing_wrapper(func): + """动态API扣费装饰器,使用模型名称作为API名称""" + @wraps(func) + async def wrapper(data: dict, request: Request): + api_name = settings.default_api_name # 使用settings中的默认API名称 + if "model" in data: + api_name = data["model"] + + dynamic_decorator = jingrow_api_verify_and_billing(api_name=api_name) + decorated_func = dynamic_decorator(func) + return await decorated_func(**{"data": data, "request": request}) + + return wrapper + +@router.post(settings.chat_route) +@dynamic_billing_wrapper +async def chat_api(data: dict, request: Request): + """ + 通用文本聊天API,支持OpenAI和豆包等模型的请求格式 + + Args: + data: 包含以下字段的字典: + - messages: 消息列表,每个消息包含 role 和 content(必需) + - model: 选择使用的模型(可选,默认为配置的默认模型) + - temperature: 温度参数(可选,默认为0.7) + - top_p: top_p参数(可选,默认为0.9) + - max_tokens: 最大生成token数(可选,默认为2048) + request: FastAPI 请求对象 + + Returns: + AI生成的回复内容 + """ + if "messages" not in data: + raise HTTPException(status_code=400, detail="缺少messages参数") + + try: + if "model" in data: + service.model = data["model"] + if "temperature" in data: + service.temperature = data["temperature"] + if "top_p" in data: + service.top_p = data["top_p"] + if "max_tokens" in data: + service.max_tokens = data["max_tokens"] + + result = await service.chat(data["messages"]) + return result + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) diff --git a/apps/jchat/app.py b/apps/jchat/app.py new file mode 100644 index 0000000..4731a74 --- /dev/null +++ b/apps/jchat/app.py @@ -0,0 +1,21 @@ +from fastapi import FastAPI +from settings import settings +from api import router + +app = FastAPI( + title="JChat Service", + description="AI聊天服务API", + version="1.0.0" +) + +# 注册路由 +app.include_router(router) + +if __name__ == "__main__": + import uvicorn + uvicorn.run( + "app:app", + host=settings.host, + port=settings.port, + reload=settings.debug + ) diff --git a/apps/jchat/service.py b/apps/jchat/service.py new file mode 100644 index 0000000..34d26d4 --- /dev/null +++ b/apps/jchat/service.py @@ -0,0 +1,223 @@ +import json +import requests +import asyncio +from typing import Dict, Optional, List, Union +from settings import settings + +# 默认模型配置 +default_model = "deepseek" # 默认使用的模型,可选值为"gpt"、"deepseek"或"doubao" +gpt_api_model = "gpt-4o" # ChatGPT模型名称 +deepseek_api_model = "deepseek-chat" # DeepSeek模型名称 +doubao_api_model = "doubao-1-5-thinking-pro-250415" # Doubao模型名称 + +# 模型映射配置 +model_mapping = { + "jingrow-chat": { + "type": "deepseek", + "model": "deepseek-chat" + }, + "jingrow-chat-lite": { + "type": "doubao", + "model": "doubao-1-5-lite-32k-250115" + }, + "jingrow-chat-think": { + "type": "doubao", + "model": "doubao-1-5-thinking-pro-250415" + }, + "jingrow-chat-vision": { + "type": "doubao", + "model": "doubao-1.5-vision-pro-250328" + } +} + +# 默认系统提示词 +default_system_message = """ +你是一个有用的AI助手,请根据用户的问题提供清晰、准确的回答。 +""" + +class ChatService: + def __init__(self, model: str = None, temperature: float = 0.7, top_p: float = 0.9, max_tokens: int = 2048): + """初始化聊天服务 + + Args: + model: 选择使用的模型 + temperature: 温度参数 + top_p: top_p参数 + max_tokens: 最大生成token数 + """ + self.model = model + self.temperature = temperature + self.top_p = top_p + self.max_tokens = max_tokens + + def _get_model_config(self, model: str) -> Dict: + """获取模型配置 + + Args: + model: 模型名称 + + Returns: + 包含模型类型和具体模型名称的字典 + """ + # 检查是否在映射表中 + if model in model_mapping: + return model_mapping[model] + + # 根据模型名称判断类型 + model_lower = model.lower() + if "deepseek" in model_lower: + return {"type": "deepseek", "model": model} + elif "doubao" in model_lower: + return {"type": "doubao", "model": model} + else: + return {"type": "gpt", "model": model} + + def _get_api_config(self, model_type: str) -> Dict: + """获取API配置 + + Args: + model_type: 模型类型(gpt/deepseek/doubao) + + Returns: + 包含API配置的字典 + """ + config = { + "gpt": { + "url": settings.chatgpt_api_url, + "key": settings.chatgpt_api_key, + "model": settings.chatgpt_api_model + }, + "deepseek": { + "url": settings.deepseek_api_url, + "key": settings.deepseek_api_key, + "model": settings.deepseek_api_model + }, + "doubao": { + "url": settings.doubao_api_url, + "key": settings.doubao_api_key, + "model": settings.doubao_api_model + } + } + return config.get(model_type, config["gpt"]) + + def _prepare_payload(self, messages: List[Dict], model_type: str, model_name: str) -> Dict: + """准备请求payload + + Args: + messages: 消息列表 + model_type: 模型类型 + model_name: 具体模型名称 + + Returns: + 请求payload + """ + api_config = self._get_api_config(model_type) + + payload = { + "model": model_name, # 使用映射后的具体模型名称 + "messages": messages, + "temperature": self.temperature, + "top_p": self.top_p, + "max_tokens": self.max_tokens + } + + return payload + + def _send_request(self, messages: List[Dict], model_type: str, model_name: str) -> Optional[Dict]: + """发送API请求 + + Args: + messages: 消息列表 + model_type: 模型类型 + model_name: 具体模型名称 + + Returns: + API响应 + """ + api_config = self._get_api_config(model_type) + payload = self._prepare_payload(messages, model_type, model_name) + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {api_config['key']}" + } + + try: + response = requests.post( + api_config["url"], + headers=headers, + json=payload, + timeout=(10, 300) + ) + + if response.status_code != 200: + return None + return response.json() + except Exception as e: + return None + + def chat_sync(self, messages: List[Dict]) -> Dict: + """同步处理聊天请求 + + Args: + messages: 消息列表,每个消息包含 role 和 content + + Returns: + 处理结果 + """ + try: + model_config = self._get_model_config(self.model or default_model) + model_type = model_config["type"] + model_name = model_config["model"] + + ai_response = self._send_request(messages, model_type, model_name) + + if ai_response is None: + return { + "status": "error", + "message": "AI服务请求失败" + } + + choices = ai_response.get("choices", []) + if not choices: + return { + "status": "error", + "message": "AI响应无效" + } + + message = choices[0].get("message", {}).get("content", "") + if not message: + return { + "status": "error", + "message": "AI响应内容为空" + } + + return { + "status": "success", + "data": message + } + + except Exception as e: + return { + "status": "error", + "message": f"处理聊天任务时发生错误: {str(e)}" + } + + async def chat(self, messages: List[Dict]) -> Dict: + """异步处理聊天请求 + + Args: + messages: 消息列表,每个消息包含 role 和 content + + Returns: + 处理结果 + """ + try: + loop = asyncio.get_event_loop() + result = await loop.run_in_executor(None, self.chat_sync, messages) + return result + except Exception as e: + return { + "status": "error", + "message": f"聊天请求失败: {str(e)}" + } diff --git a/apps/jchat/settings.py b/apps/jchat/settings.py new file mode 100644 index 0000000..34b73a8 --- /dev/null +++ b/apps/jchat/settings.py @@ -0,0 +1,50 @@ +from pydantic_settings import BaseSettings +from typing import Optional +from functools import lru_cache + +class Settings(BaseSettings): + # Japi Server 配置 + host: str = "0.0.0.0" + port: int = 8101 + debug: bool = False + + # API路由配置 + router_prefix: str = "/jchat" + chat_route: str = "/chat" + default_api_name: str = "jingrow-chat" # 默认API名称 + + upload_url: str = "http://173.255.202.68/imgurl/upload" + + # Jingrow Jcloud API 配置 + jingrow_api_url: str = "https://cloud.jingrow.com" + jingrow_api_key: Optional[str] = None + jingrow_api_secret: Optional[str] = None + + # DeepSeek配置 + deepseek_api_url: str = "https://api.deepseek.com/v1/chat/completions" + deepseek_api_key: Optional[str] = None + deepseek_api_model: str = "deepseek-chat" + + # Doubao配置 + doubao_api_url: str = "https://ark.cn-beijing.volces.com/api/v3/chat/completions" + doubao_api_key: Optional[str] = None + doubao_api_model: str = "doubao-1-5-pro-32k-250115" + + # ChatGPT配置 + chatgpt_api_url: str = "https://api.openai.com/v1/chat/completions" + chatgpt_api_key: Optional[str] = None + chatgpt_api_model: str = "gpt-4" + + # 默认服务模型配置 + translation_model: str = "Doubao" + image_to_text_model: str = "Doubao" + + class Config: + env_file = ".env" + +@lru_cache() +def get_settings() -> Settings: + return Settings() + +# 创建全局配置实例 +settings = get_settings() \ No newline at end of file diff --git a/apps/jchat/utils.py b/apps/jchat/utils.py new file mode 100644 index 0000000..8592041 --- /dev/null +++ b/apps/jchat/utils.py @@ -0,0 +1,146 @@ +import aiohttp +from functools import wraps +from fastapi import HTTPException +import os +from typing import Callable, Any, Dict, Optional, Tuple +from settings import settings +from fastapi.responses import StreamingResponse +import json + +async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]: + """验证API密钥和团队余额""" + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f"{settings.jingrow_api_url}/api/method/jcloud.api.account.verify_api_credentials_and_balance", + headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"}, + json={"api_key": api_key, "api_secret": api_secret, "api_name": api_name} + ) as response: + if response.status != 200: + raise HTTPException(status_code=500, detail="验证服务暂时不可用") + + result = await response.json() + if "message" in result and isinstance(result["message"], dict): + result = result["message"] + + if not result.get("success"): + raise HTTPException(status_code=401, detail=result.get("message", "验证失败")) + + return result + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}") + +async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]: + """从Jingrow平台扣除API使用费""" + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f"{settings.jingrow_api_url}/api/method/jcloud.api.account.deduct_api_usage_fee", + headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"}, + json={ + "api_key": api_key, + "api_secret": api_secret, + "api_name": api_name, + "usage_count": usage_count + } + ) as response: + if response.status != 200: + raise HTTPException(status_code=500, detail="扣费服务暂时不可用") + + result = await response.json() + if "message" in result and isinstance(result["message"], dict): + result = result["message"] + + return result + + except HTTPException: + raise + except Exception as e: + return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"} + +def get_token_from_request(request) -> str: + """从请求中获取访问令牌""" + if not request: + raise HTTPException(status_code=400, detail="无法获取请求信息") + + auth_header = request.headers.get("Authorization", "") + if not auth_header or not auth_header.startswith("token "): + raise HTTPException(status_code=401, detail="无效的Authorization头格式") + + token = auth_header[6:] + if ":" not in token: + raise HTTPException(status_code=401, detail="无效的令牌格式") + + return token + +def jingrow_api_verify_and_billing(api_name: str): + """Jingrow API 验证装饰器(带余额检查和扣费)""" + def decorator(func: Callable) -> Callable: + @wraps(func) + async def wrapper(*args, **kwargs): + try: + request = kwargs.get('request') + if not request: + raise HTTPException(status_code=400, detail="无法获取请求信息") + + token = get_token_from_request(request) + api_key, api_secret = token.split(":", 1) + + verify_result = await verify_api_credentials_and_balance(api_key, api_secret, api_name) + if not verify_result.get("success"): + raise HTTPException(status_code=401, detail=verify_result.get("message", "验证失败")) + + result = await func(*args, **kwargs) + + usage_count = 1 + try: + body_data = await request.json() + if isinstance(body_data, dict): + for key in ["items", "urls", "images", "files"]: + if key in body_data and isinstance(body_data[key], list): + usage_count = len(body_data[key]) + break + except Exception: + pass + + if isinstance(result, StreamingResponse): + original_generator = result.body_iterator + success_count = 0 + + async def wrapped_generator(): + nonlocal success_count + async for chunk in original_generator: + try: + data = json.loads(chunk) + if isinstance(data, dict) and data.get("status") == "success": + success_count += 1 + except: + pass + yield chunk + + if success_count > 0: + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count) + + return StreamingResponse( + wrapped_generator(), + media_type=result.media_type, + headers=result.headers + ) + + if isinstance(result, dict) and result.get("success") is True: + actual_usage_count = result.get("successful_count", usage_count) + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count) + return result + + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count) + return result + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"API验证过程发生异常: {str(e)}") + return wrapper + return decorator diff --git a/apps/jdescribe/__init__.py b/apps/jdescribe/__init__.py new file mode 100644 index 0000000..c0e6f31 --- /dev/null +++ b/apps/jdescribe/__init__.py @@ -0,0 +1 @@ +# 使jchat目录成为Python包 \ No newline at end of file diff --git a/apps/jdescribe/api.py b/apps/jdescribe/api.py new file mode 100644 index 0000000..6e487ae --- /dev/null +++ b/apps/jdescribe/api.py @@ -0,0 +1,36 @@ +from fastapi import APIRouter, HTTPException, Request +from service import ImageDescribeService +from utils import jingrow_api_verify_and_billing +from settings import settings +import json + +router = APIRouter(prefix=settings.router_prefix) +service = ImageDescribeService() + +@router.post(settings.get_route) +@jingrow_api_verify_and_billing(api_name=settings.api_name) +async def describe_image_api(data: dict, request: Request): + """ + 根据图像URL生成中英文描述 + + Args: + data: 包含以下字段的字典: + - image_url: 图片URL(必需) + - system_message: 自定义系统消息(可选) + - user_content: 自定义用户消息(可选) + request: FastAPI 请求对象 + + Returns: + 图像的中英文描述 + """ + if "image_url" not in data: + raise HTTPException(status_code=400, detail="缺少image_url参数") + + # 如果提供了自定义消息,则更新service实例的消息 + if "system_message" in data: + service.system_message = data["system_message"] + if "user_content" in data: + service.user_content = data["user_content"] + + result = await service.describe_image(data["image_url"]) + return result diff --git a/apps/jdescribe/app.py b/apps/jdescribe/app.py new file mode 100644 index 0000000..7e04efa --- /dev/null +++ b/apps/jdescribe/app.py @@ -0,0 +1,21 @@ +from fastapi import FastAPI +from settings import settings +from api import router + +app = FastAPI( + title="Jdescribe", + description="Jdescribe描述图片API", + version="1.0.0" +) + +# 注册路由 +app.include_router(router) + +if __name__ == "__main__": + import uvicorn + uvicorn.run( + "app:app", + host=settings.host, + port=settings.port, + reload=settings.debug + ) diff --git a/apps/jdescribe/service.py b/apps/jdescribe/service.py new file mode 100644 index 0000000..c6dd860 --- /dev/null +++ b/apps/jdescribe/service.py @@ -0,0 +1,193 @@ +import json +import requests +from pathlib import Path +from urllib.parse import urlparse +from PIL import Image +import io +import asyncio +import os +from typing import Dict, Optional +from settings import settings + +image_to_text_model = "Doubao" +deepseek_api_model = "deepseek-chat" +doubao_api_model = "doubao-1.5-vision-pro-250328" +chatgpt_api_model = "gpt-4o" + +default_system_message = """ +请用中英文分别描述该图片,使用结构化描述,描述的内容用于ai绘画,因此请优化内容,不要用这是开头,使之适合用作ai绘画prompts。 +输出格式为: +{ + "中文描述": "中文内容", + "英文描述": "英文内容" +} +""" + +default_user_content = "请用中英文分别生成该图片的内容描述。" + +class ImageDescribeService: + def __init__(self, system_message: str = None, user_content: str = None): + """初始化图像描述服务 + + Args: + system_message: 自定义系统提示词 + user_content: 自定义用户提示词 + """ + self.system_message = system_message or default_system_message + self.user_content = user_content or default_user_content + + def send_to_chatgpt(self, image_url: str) -> Optional[Dict]: + """向ChatGPT发送图像描述请求""" + payload = { + "model": chatgpt_api_model, + "messages": [ + { + "role": "system", + "content": self.system_message + }, + { + "role": "user", + "content": f"{self.user_content}\n\n图片链接: {image_url}" + } + ], + "temperature": 0.9, + "top_p": 0.9 + } + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {settings.chatgpt_api_key}" + } + + response = requests.post(settings.chatgpt_api_url, headers=headers, json=payload) + if response.status_code != 200: + print(f"Error: {response.status_code}, {response.text}") + return None + return response.json() + + def send_to_deepseek(self, image_url: str) -> Optional[Dict]: + """向DeepSeek发送图像描述请求""" + payload = { + "model": deepseek_api_model, + "messages": [ + { + "role": "system", + "content": self.system_message + }, + { + "role": "user", + "content": f"{self.user_content}\n\n图片链接: {image_url}" + } + ], + "temperature": 0.9, + "top_p": 0.9 + } + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {settings.deepseek_api_key}" + } + + response = requests.post(settings.deepseek_api_url, headers=headers, json=payload) + if response.status_code != 200: + print(f"Error: {response.status_code}, {response.text}") + return None + return response.json() + + def send_to_doubao(self, image_url: str) -> Optional[Dict]: + """向Doubao发送图像描述请求""" + payload = { + "model": doubao_api_model, + "messages": [ + { + "role": "system", + "content": self.system_message + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": self.user_content + }, + { + "type": "image_url", + "image_url": { + "url": image_url + } + } + ] + } + ], + "temperature": 0.9, + "top_p": 0.9 + } + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {settings.doubao_api_key}" + } + + response = requests.post(settings.doubao_api_url, headers=headers, json=payload) + if response.status_code != 200: + print(f"Error: {response.status_code}, {response.text}") + return None + return response.json() + + def describe_image_sync(self, image_url: str) -> Dict: + """同步处理图像描述请求""" + try: + # 选择合适的AI模型处理请求 + if image_to_text_model == "DeepSeek": + ai_response = self.send_to_deepseek(image_url) + elif image_to_text_model == "Doubao": + ai_response = self.send_to_doubao(image_url) + else: + ai_response = self.send_to_chatgpt(image_url) + + if ai_response is None: + return { + "status": "error", + "message": "AI服务请求失败" + } + + choices = ai_response.get("choices", []) + if not choices: + return { + "status": "error", + "message": "AI响应无效" + } + + message = choices[0].get("message", {}).get("content", "") + response_data = json.loads(message) + cn_description = response_data.get("中文描述", "") + en_description = response_data.get("英文描述", "") + + return { + "status": "success", + "data": { + "cn_description": cn_description, + "en_description": en_description + } + } + + except Exception as e: + print(f"描述任务处理失败: {str(e)}") + return { + "status": "error", + "message": f"处理描述任务时发生错误: {str(e)}" + } + + async def describe_image(self, image_url: str) -> Dict: + """异步处理图像描述请求""" + try: + loop = asyncio.get_event_loop() + result = await loop.run_in_executor(None, self.describe_image_sync, image_url) + return result + except Exception as e: + return { + "status": "error", + "message": f"图像描述失败: {str(e)}" + } + + diff --git a/apps/jdescribe/settings.py b/apps/jdescribe/settings.py new file mode 100644 index 0000000..f2bddfd --- /dev/null +++ b/apps/jdescribe/settings.py @@ -0,0 +1,50 @@ +from pydantic_settings import BaseSettings +from typing import Optional +from functools import lru_cache + +class Settings(BaseSettings): + # Japi Server 配置 + host: str = "0.0.0.0" + port: int = 8107 + debug: bool = False + + # API路由配置 + router_prefix: str = "/jdescribe" + get_route: str = "/get" + api_name: str = "jdescribe" # 默认API名称 + + upload_url: str = "http://173.255.202.68/imgurl/upload" + + # Jingrow Jcloud API 配置 + jingrow_api_url: str = "https://cloud.jingrow.com" + jingrow_api_key: Optional[str] = None + jingrow_api_secret: Optional[str] = None + + # DeepSeek配置 + deepseek_api_url: str = "https://api.deepseek.com/v1/chat/completions" + deepseek_api_key: Optional[str] = None + deepseek_api_model: str = "deepseek-chat" + + # Doubao配置 + doubao_api_url: str = "https://ark.cn-beijing.volces.com/api/v3/chat/completions" + doubao_api_key: Optional[str] = None + doubao_api_model: str = "doubao-1-5-pro-32k-250115" + + # ChatGPT配置 + chatgpt_api_url: str = "https://api.openai.com/v1/chat/completions" + chatgpt_api_key: Optional[str] = None + chatgpt_api_model: str = "gpt-4" + + # 默认服务模型配置 + translation_model: str = "Doubao" + image_to_text_model: str = "Doubao" + + class Config: + env_file = ".env" + +@lru_cache() +def get_settings() -> Settings: + return Settings() + +# 创建全局配置实例 +settings = get_settings() \ No newline at end of file diff --git a/apps/jdescribe/utils.py b/apps/jdescribe/utils.py new file mode 100644 index 0000000..8592041 --- /dev/null +++ b/apps/jdescribe/utils.py @@ -0,0 +1,146 @@ +import aiohttp +from functools import wraps +from fastapi import HTTPException +import os +from typing import Callable, Any, Dict, Optional, Tuple +from settings import settings +from fastapi.responses import StreamingResponse +import json + +async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]: + """验证API密钥和团队余额""" + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f"{settings.jingrow_api_url}/api/method/jcloud.api.account.verify_api_credentials_and_balance", + headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"}, + json={"api_key": api_key, "api_secret": api_secret, "api_name": api_name} + ) as response: + if response.status != 200: + raise HTTPException(status_code=500, detail="验证服务暂时不可用") + + result = await response.json() + if "message" in result and isinstance(result["message"], dict): + result = result["message"] + + if not result.get("success"): + raise HTTPException(status_code=401, detail=result.get("message", "验证失败")) + + return result + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}") + +async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]: + """从Jingrow平台扣除API使用费""" + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f"{settings.jingrow_api_url}/api/method/jcloud.api.account.deduct_api_usage_fee", + headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"}, + json={ + "api_key": api_key, + "api_secret": api_secret, + "api_name": api_name, + "usage_count": usage_count + } + ) as response: + if response.status != 200: + raise HTTPException(status_code=500, detail="扣费服务暂时不可用") + + result = await response.json() + if "message" in result and isinstance(result["message"], dict): + result = result["message"] + + return result + + except HTTPException: + raise + except Exception as e: + return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"} + +def get_token_from_request(request) -> str: + """从请求中获取访问令牌""" + if not request: + raise HTTPException(status_code=400, detail="无法获取请求信息") + + auth_header = request.headers.get("Authorization", "") + if not auth_header or not auth_header.startswith("token "): + raise HTTPException(status_code=401, detail="无效的Authorization头格式") + + token = auth_header[6:] + if ":" not in token: + raise HTTPException(status_code=401, detail="无效的令牌格式") + + return token + +def jingrow_api_verify_and_billing(api_name: str): + """Jingrow API 验证装饰器(带余额检查和扣费)""" + def decorator(func: Callable) -> Callable: + @wraps(func) + async def wrapper(*args, **kwargs): + try: + request = kwargs.get('request') + if not request: + raise HTTPException(status_code=400, detail="无法获取请求信息") + + token = get_token_from_request(request) + api_key, api_secret = token.split(":", 1) + + verify_result = await verify_api_credentials_and_balance(api_key, api_secret, api_name) + if not verify_result.get("success"): + raise HTTPException(status_code=401, detail=verify_result.get("message", "验证失败")) + + result = await func(*args, **kwargs) + + usage_count = 1 + try: + body_data = await request.json() + if isinstance(body_data, dict): + for key in ["items", "urls", "images", "files"]: + if key in body_data and isinstance(body_data[key], list): + usage_count = len(body_data[key]) + break + except Exception: + pass + + if isinstance(result, StreamingResponse): + original_generator = result.body_iterator + success_count = 0 + + async def wrapped_generator(): + nonlocal success_count + async for chunk in original_generator: + try: + data = json.loads(chunk) + if isinstance(data, dict) and data.get("status") == "success": + success_count += 1 + except: + pass + yield chunk + + if success_count > 0: + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count) + + return StreamingResponse( + wrapped_generator(), + media_type=result.media_type, + headers=result.headers + ) + + if isinstance(result, dict) and result.get("success") is True: + actual_usage_count = result.get("successful_count", usage_count) + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count) + return result + + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count) + return result + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"API验证过程发生异常: {str(e)}") + return wrapper + return decorator diff --git a/apps/jfile/app.py b/apps/jfile/app.py new file mode 100644 index 0000000..486c4ae --- /dev/null +++ b/apps/jfile/app.py @@ -0,0 +1,35 @@ +import uvicorn +from fastapi import FastAPI +from fastapi.staticfiles import StaticFiles +from file_cleaner import FileCleaner +from settings import settings +import asyncio + +app = FastAPI( + title="www", + description="公共静态资源访问服务", + version="1.0.0" +) + +# 挂载静态文件目录 +app.mount("/files", StaticFiles(directory="files"), name="files") + + +# 注册文件定时清理任务 +save_dir = "files" +file_prefix = "upscaled_" +retention_hours = settings.file_retention_hours +cleaner = FileCleaner(save_dir, file_prefix, retention_hours) + +@app.on_event("startup") +async def startup_event(): + asyncio.create_task(cleaner.periodic_cleanup()) + +if __name__ == "__main__": + import uvicorn + uvicorn.run( + "app:app", + host=settings.host, + port=settings.port, + reload=settings.debug + ) \ No newline at end of file diff --git a/apps/jfile/file_cleaner.py b/apps/jfile/file_cleaner.py new file mode 100644 index 0000000..1a77089 --- /dev/null +++ b/apps/jfile/file_cleaner.py @@ -0,0 +1,33 @@ +import os +import asyncio +from datetime import datetime, timedelta + +class FileCleaner: + def __init__(self, target_dir, prefix, retention_hours): + self.target_dir = target_dir + self.prefix = prefix + self.retention_hours = retention_hours + + async def periodic_cleanup(self): + while True: + try: + self.cleanup_old_files() + except Exception as e: + print(f"清理文件时出错: {str(e)}") + await asyncio.sleep(3600) + + def cleanup_old_files(self): + if not os.path.exists(self.target_dir): + return + cutoff_time = datetime.now() - timedelta(hours=self.retention_hours) + for filename in os.listdir(self.target_dir): + if not filename.startswith(self.prefix): + continue + file_path = os.path.join(self.target_dir, filename) + file_time = datetime.fromtimestamp(os.path.getctime(file_path)) + if file_time < cutoff_time: + try: + os.remove(file_path) + print(f"已删除过期文件: {filename}") + except Exception as e: + print(f"删除文件失败 {filename}: {str(e)}") \ No newline at end of file diff --git a/apps/jfile/files/upscaled_2fead90ea9.jpg b/apps/jfile/files/upscaled_2fead90ea9.jpg new file mode 100644 index 0000000..8f920b3 Binary files /dev/null and b/apps/jfile/files/upscaled_2fead90ea9.jpg differ diff --git a/apps/jfile/files/upscaled_466670e1cb.jpg b/apps/jfile/files/upscaled_466670e1cb.jpg new file mode 100644 index 0000000..900118e Binary files /dev/null and b/apps/jfile/files/upscaled_466670e1cb.jpg differ diff --git a/apps/jfile/settings.py b/apps/jfile/settings.py new file mode 100644 index 0000000..dc8dfe4 --- /dev/null +++ b/apps/jfile/settings.py @@ -0,0 +1,21 @@ +from pydantic_settings import BaseSettings +from typing import Optional + +class Settings(BaseSettings): + # Japi Server 配置 + host: str = "0.0.0.0" + port: int = 8100 + debug: bool = False + + # Japi 静态资源下载URL + download_url: str = "http://api.jingrow.com:9080/files" + + # 文件保留时间(小时) + file_retention_hours: int = 1 + + + class Config: + env_file = ".env" + +# 创建全局配置实例 +settings = Settings() \ No newline at end of file diff --git a/apps/jtranslate/__init__.py b/apps/jtranslate/__init__.py new file mode 100644 index 0000000..c0e6f31 --- /dev/null +++ b/apps/jtranslate/__init__.py @@ -0,0 +1 @@ +# 使jchat目录成为Python包 \ No newline at end of file diff --git a/apps/jtranslate/api.py b/apps/jtranslate/api.py new file mode 100644 index 0000000..5cc4a96 --- /dev/null +++ b/apps/jtranslate/api.py @@ -0,0 +1,35 @@ +from fastapi import APIRouter, HTTPException, Request +from service import TranslateService +from utils import jingrow_api_verify_and_billing +from settings import settings + +router = APIRouter(prefix=settings.router_prefix) +service = TranslateService() + +@router.post(settings.get_route) +@jingrow_api_verify_and_billing(api_name=settings.api_name) +async def translate_text_api(data: dict, request: Request): + """ + 将中文文本翻译成英文 + + Args: + data: 包含以下字段的字典: + - source_text: 源文本(必需) + - system_message: 自定义系统消息(可选) + - user_content: 自定义用户消息(可选) + request: FastAPI 请求对象 + + Returns: + 翻译后的英文文本 + """ + if "source_text" not in data: + raise HTTPException(status_code=400, detail="缺少source_text参数") + + # 如果提供了自定义消息,则更新service实例的消息 + if "system_message" in data: + service.system_message = data["system_message"] + if "user_content" in data: + service.user_content = data["user_content"] + + result = await service.translate_text(data["source_text"]) + return result diff --git a/apps/jtranslate/app.py b/apps/jtranslate/app.py new file mode 100644 index 0000000..fc0c86d --- /dev/null +++ b/apps/jtranslate/app.py @@ -0,0 +1,21 @@ +from fastapi import FastAPI +from settings import settings +from api import router + +app = FastAPI( + title="Jtranslate", + description="Jtranslate翻译API", + version="1.0.0" +) + +# 注册路由 +app.include_router(router) + +if __name__ == "__main__": + import uvicorn + uvicorn.run( + "app:app", + host=settings.host, + port=settings.port, + reload=settings.debug + ) diff --git a/apps/jtranslate/service.py b/apps/jtranslate/service.py new file mode 100644 index 0000000..8a2aa47 --- /dev/null +++ b/apps/jtranslate/service.py @@ -0,0 +1,179 @@ +import json +import requests +import os +from typing import Dict, Optional +from settings import settings + +# 模型配置 +translation_model = "Doubao" +deepseek_api_model = "deepseek-chat" +doubao_api_model = "doubao-1-5-pro-32k-250115" +chatgpt_api_model = "gpt-4o" + + +# 自定义提示词配置 +default_system_message = """ +你是一位专业的中译英翻译专家。请将提供的中文内容翻译成地道、流畅的英文,确保保留原文的风格和语境。 +翻译时要注意原文的专业术语和表达方式,使翻译结果符合英语的最佳实践。 +只需返回翻译后的英文内容,不要包含任何其他说明或注释。 +""" + +default_user_content = "请将以下中文内容翻译成英文:\n\n{source_text}" + +class TranslateService: + def __init__(self, system_message: str = None, user_content: str = None): + """初始化翻译服务 + + Args: + system_message: 自定义系统提示词 + user_content: 自定义用户提示词 + """ + self.system_message = system_message or default_system_message + self.user_content = user_content or default_user_content + + def send_to_chatgpt(self, source_text: str) -> Optional[Dict]: + """向ChatGPT发送翻译请求""" + payload = { + "model": chatgpt_api_model, + "messages": [ + { + "role": "system", + "content": self.system_message + }, + { + "role": "user", + "content": self.user_content.format(source_text=source_text) + } + ], + "temperature": 0.3, + "top_p": 0.9 + } + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {settings.chatgpt_api_key}" + } + + response = requests.post(settings.chatgpt_api_url, headers=headers, json=payload) + if response.status_code != 200: + print(f"Error: {response.status_code}, {response.text}") + return None + return response.json() + + def send_to_deepseek(self, source_text: str) -> Optional[Dict]: + """向DeepSeek发送翻译请求""" + payload = { + "model": deepseek_api_model, + "messages": [ + { + "role": "system", + "content": self.system_message + }, + { + "role": "user", + "content": self.user_content.format(source_text=source_text) + } + ], + "temperature": 0.3, + "top_p": 0.9 + } + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {settings.deepseek_api_key}" + } + + response = requests.post(settings.deepseek_api_url, headers=headers, json=payload) + if response.status_code != 200: + print(f"Error: {response.status_code}, {response.text}") + return None + return response.json() + + def send_to_doubao(self, source_text: str) -> Optional[Dict]: + """向Doubao发送翻译请求""" + payload = { + "model": doubao_api_model, + "messages": [ + { + "role": "system", + "content": self.system_message + }, + { + "role": "user", + "content": self.user_content.format(source_text=source_text) + } + ], + "temperature": 0.3, + "top_p": 0.9 + } + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {settings.doubao_api_key}" + } + + response = requests.post(settings.doubao_api_url, headers=headers, json=payload) + if response.status_code != 200: + print(f"Error: {response.status_code}, {response.text}") + return None + return response.json() + + def translate_text_sync(self, source_text: str) -> Dict: + """同步处理翻译请求""" + try: + if not source_text: + return { + "status": "error", + "message": "未提供翻译文本" + } + + # 选择合适的AI模型处理请求 + if translation_model == "DeepSeek": + ai_response = self.send_to_deepseek(source_text) + elif translation_model == "Doubao": + ai_response = self.send_to_doubao(source_text) + else: + ai_response = self.send_to_chatgpt(source_text) + + if ai_response is None: + return { + "status": "error", + "message": "AI服务请求失败" + } + + choices = ai_response.get("choices", []) + if not choices: + return { + "status": "error", + "message": "AI响应无效" + } + + english_translation = choices[0].get("message", {}).get("content", "").strip() + + return { + "status": "success", + "data": { + "english_translation": english_translation + } + } + + except Exception as e: + print(f"翻译任务处理失败: {str(e)}") + return { + "status": "error", + "message": f"处理翻译任务时发生错误: {str(e)}" + } + + async def translate_text(self, source_text: str) -> Dict: + """异步处理翻译请求""" + try: + import asyncio + loop = asyncio.get_event_loop() + result = await loop.run_in_executor(None, self.translate_text_sync, source_text) + return result + except Exception as e: + return { + "status": "error", + "message": f"翻译失败: {str(e)}" + } + diff --git a/apps/jtranslate/settings.py b/apps/jtranslate/settings.py new file mode 100644 index 0000000..19c9105 --- /dev/null +++ b/apps/jtranslate/settings.py @@ -0,0 +1,50 @@ +from pydantic_settings import BaseSettings +from typing import Optional +from functools import lru_cache + +class Settings(BaseSettings): + # Japi Server 配置 + host: str = "0.0.0.0" + port: int = 8108 + debug: bool = False + + # API路由配置 + router_prefix: str = "/jtranslate" + get_route: str = "/get" + api_name: str = "jtranslate" # 默认API名称 + + upload_url: str = "http://173.255.202.68/imgurl/upload" + + # Jingrow Jcloud API 配置 + jingrow_api_url: str = "https://cloud.jingrow.com" + jingrow_api_key: Optional[str] = None + jingrow_api_secret: Optional[str] = None + + # DeepSeek配置 + deepseek_api_url: str = "https://api.deepseek.com/v1/chat/completions" + deepseek_api_key: Optional[str] = None + deepseek_api_model: str = "deepseek-chat" + + # Doubao配置 + doubao_api_url: str = "https://ark.cn-beijing.volces.com/api/v3/chat/completions" + doubao_api_key: Optional[str] = None + doubao_api_model: str = "doubao-1-5-pro-32k-250115" + + # ChatGPT配置 + chatgpt_api_url: str = "https://api.openai.com/v1/chat/completions" + chatgpt_api_key: Optional[str] = None + chatgpt_api_model: str = "gpt-4" + + # 默认服务模型配置 + translation_model: str = "Doubao" + image_to_text_model: str = "Doubao" + + class Config: + env_file = ".env" + +@lru_cache() +def get_settings() -> Settings: + return Settings() + +# 创建全局配置实例 +settings = get_settings() \ No newline at end of file diff --git a/apps/jtranslate/utils.py b/apps/jtranslate/utils.py new file mode 100644 index 0000000..8592041 --- /dev/null +++ b/apps/jtranslate/utils.py @@ -0,0 +1,146 @@ +import aiohttp +from functools import wraps +from fastapi import HTTPException +import os +from typing import Callable, Any, Dict, Optional, Tuple +from settings import settings +from fastapi.responses import StreamingResponse +import json + +async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]: + """验证API密钥和团队余额""" + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f"{settings.jingrow_api_url}/api/method/jcloud.api.account.verify_api_credentials_and_balance", + headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"}, + json={"api_key": api_key, "api_secret": api_secret, "api_name": api_name} + ) as response: + if response.status != 200: + raise HTTPException(status_code=500, detail="验证服务暂时不可用") + + result = await response.json() + if "message" in result and isinstance(result["message"], dict): + result = result["message"] + + if not result.get("success"): + raise HTTPException(status_code=401, detail=result.get("message", "验证失败")) + + return result + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}") + +async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]: + """从Jingrow平台扣除API使用费""" + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f"{settings.jingrow_api_url}/api/method/jcloud.api.account.deduct_api_usage_fee", + headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"}, + json={ + "api_key": api_key, + "api_secret": api_secret, + "api_name": api_name, + "usage_count": usage_count + } + ) as response: + if response.status != 200: + raise HTTPException(status_code=500, detail="扣费服务暂时不可用") + + result = await response.json() + if "message" in result and isinstance(result["message"], dict): + result = result["message"] + + return result + + except HTTPException: + raise + except Exception as e: + return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"} + +def get_token_from_request(request) -> str: + """从请求中获取访问令牌""" + if not request: + raise HTTPException(status_code=400, detail="无法获取请求信息") + + auth_header = request.headers.get("Authorization", "") + if not auth_header or not auth_header.startswith("token "): + raise HTTPException(status_code=401, detail="无效的Authorization头格式") + + token = auth_header[6:] + if ":" not in token: + raise HTTPException(status_code=401, detail="无效的令牌格式") + + return token + +def jingrow_api_verify_and_billing(api_name: str): + """Jingrow API 验证装饰器(带余额检查和扣费)""" + def decorator(func: Callable) -> Callable: + @wraps(func) + async def wrapper(*args, **kwargs): + try: + request = kwargs.get('request') + if not request: + raise HTTPException(status_code=400, detail="无法获取请求信息") + + token = get_token_from_request(request) + api_key, api_secret = token.split(":", 1) + + verify_result = await verify_api_credentials_and_balance(api_key, api_secret, api_name) + if not verify_result.get("success"): + raise HTTPException(status_code=401, detail=verify_result.get("message", "验证失败")) + + result = await func(*args, **kwargs) + + usage_count = 1 + try: + body_data = await request.json() + if isinstance(body_data, dict): + for key in ["items", "urls", "images", "files"]: + if key in body_data and isinstance(body_data[key], list): + usage_count = len(body_data[key]) + break + except Exception: + pass + + if isinstance(result, StreamingResponse): + original_generator = result.body_iterator + success_count = 0 + + async def wrapped_generator(): + nonlocal success_count + async for chunk in original_generator: + try: + data = json.loads(chunk) + if isinstance(data, dict) and data.get("status") == "success": + success_count += 1 + except: + pass + yield chunk + + if success_count > 0: + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count) + + return StreamingResponse( + wrapped_generator(), + media_type=result.media_type, + headers=result.headers + ) + + if isinstance(result, dict) and result.get("success") is True: + actual_usage_count = result.get("successful_count", usage_count) + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count) + return result + + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count) + return result + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"API验证过程发生异常: {str(e)}") + return wrapper + return decorator diff --git a/apps/jupscale/__init__.py b/apps/jupscale/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/jupscale/api.py b/apps/jupscale/api.py new file mode 100644 index 0000000..279fc5f --- /dev/null +++ b/apps/jupscale/api.py @@ -0,0 +1,57 @@ +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import StreamingResponse +from service import ImageUpscaleService +from utils import jingrow_api_verify_and_billing +from settings import settings +import json +import asyncio +import logging +import time +from typing import Optional + +router = APIRouter(prefix=settings.router_prefix) +service = ImageUpscaleService() + +@router.post(settings.upscale_route) +@jingrow_api_verify_and_billing(api_name=settings.api_name) +async def upscale_image_api(data: dict, request: Request): + """ + 根据图像URL放大图像 + + Args: + data: 包含图像URL的字典 + request: FastAPI 请求对象 + + Returns: + 放大后的图片URL + """ + if "image_url" not in data: + raise HTTPException(status_code=400, detail="缺少image_url参数") + + result = await service.upscale_image(data["image_url"]) + return result + +@router.post(settings.batch_route) +@jingrow_api_verify_and_billing(api_name=settings.api_name) +async def upscale_image_batch(data: dict, request: Request): + """ + 批量处理多个图像URL + + Args: + data: 包含图像URL列表的字典 + request: FastAPI 请求对象 + + Returns: + 流式响应,包含每个图像的处理结果(图片URL) + """ + if "image_urls" not in data: + raise HTTPException(status_code=400, detail="缺少image_urls参数") + + async def process_and_stream(): + async for result in service.process_batch(data["image_urls"]): + yield json.dumps(result) + "\n" + + return StreamingResponse( + process_and_stream(), + media_type="application/x-ndjson" + ) diff --git a/apps/jupscale/app.py b/apps/jupscale/app.py new file mode 100644 index 0000000..f8c9f6e --- /dev/null +++ b/apps/jupscale/app.py @@ -0,0 +1,21 @@ +from fastapi import FastAPI +from settings import settings +from api import router + +app = FastAPI( + title="Jupscale", + description="Jupscale放大图片API", + version="1.0.0" +) + +# 注册路由 +app.include_router(router) + +if __name__ == "__main__": + import uvicorn + uvicorn.run( + "app:app", + host=settings.host, + port=settings.port, + reload=settings.debug + ) \ No newline at end of file diff --git a/apps/jupscale/service.py b/apps/jupscale/service.py new file mode 100644 index 0000000..3b08386 --- /dev/null +++ b/apps/jupscale/service.py @@ -0,0 +1,337 @@ +import json +import base64 +import requests +import websocket +import uuid +import urllib.request +import asyncio +import os +from typing import Dict, List, Generator, Optional, AsyncGenerator +from settings import settings +from PIL import Image +import io +import tempfile +import numpy as np + +# 默认配置 +default_config = { + "comfyui_server_address": settings.comfyui_server_address, + "upscale_model_name": "4xNomos2_otf_esrgan.pth", +} + +save_dir = settings.save_dir +download_url = settings.download_url + +# 定义基础工作流 JSON 模板 +workflow_template = """ +{ + "13": { + "inputs": { + "model_name": "" + }, + "class_type": "UpscaleModelLoader", + "_meta": { + "title": "Load Upscale Model" + } + }, + "14": { + "inputs": { + "upscale_model": [ + "13", + 0 + ], + "image": [ + "15", + 0 + ] + }, + "class_type": "ImageUpscaleWithModel", + "_meta": { + "title": "Upscale Image (using Model)" + } + }, + "15": { + "inputs": { + "url_or_path": "" + }, + "class_type": "LoadImageFromUrlOrPath", + "_meta": { + "title": "LoadImageFromUrlOrPath" + } + }, + "16": { + "inputs": { + "images": [ + "14", + 0 + ] + }, + "class_type": "SaveImageWebsocket", + "_meta": { + "title": "SaveImageWebsocket" + } + } +} +""" + +class ImageUpscaleService: + def __init__(self): + """初始化图像放大服务""" + pass + + def check_image_transparency(self, image_url: str) -> tuple: + """检查图像是否有透明通道,返回图像和是否透明的标志""" + try: + # 下载图片 + response = requests.get(image_url) + if response.status_code != 200: + raise Exception(f"无法下载图片: {image_url}") + + # 使用PIL打开图片 + img = Image.open(io.BytesIO(response.content)) + + # 检查图像是否有透明通道 + has_transparency = img.mode in ('RGBA', 'LA') and img.format == 'PNG' + + return img, has_transparency + except Exception as e: + raise Exception(f"图片处理失败: {str(e)}") + + def prepare_image_for_upscale(self, image_url: str) -> tuple: + """根据图像类型准备图像用于放大,返回处理后的图像URL和透明标志""" + img, has_transparency = self.check_image_transparency(image_url) + + if not has_transparency: + # 非透明图像直接使用原图 + return image_url, False, None + + # 对于透明PNG,我们需要分离RGB和Alpha通道 + rgb_image = img.convert('RGB') + alpha_channel = img.split()[-1] + + # 保存RGB图像到临时文件 + rgb_temp_file = tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) + rgb_image.save(rgb_temp_file.name, 'JPEG', quality=95) + + # 保存Alpha通道到临时文件 + alpha_temp_file = tempfile.NamedTemporaryFile(suffix='.png', delete=False) + alpha_channel.save(alpha_temp_file.name, 'PNG') + + return rgb_temp_file.name, True, alpha_temp_file.name + + def upscale_alpha_channel(self, alpha_path: str, scale_factor: int = 4) -> Image.Image: + """使用双线性插值放大Alpha通道""" + alpha_img = Image.open(alpha_path) + width, height = alpha_img.size + new_width, new_height = width * scale_factor, height * scale_factor + return alpha_img.resize((new_width, new_height), Image.BILINEAR) + + def queue_prompt(self, prompt: Dict, comfyui_server_address: str, client_id: str) -> Dict: + """将提示词发送到 ComfyUI 服务器的队列中""" + p = {"prompt": prompt, "client_id": client_id} + data = json.dumps(p).encode('utf-8') + req = urllib.request.Request(f"http://{comfyui_server_address}/prompt", data=data) + response = json.loads(urllib.request.urlopen(req).read()) + return response + + def get_images(self, ws: websocket.WebSocket, workflow: Dict, comfyui_server_address: str, client_id: str) -> Dict: + """从 ComfyUI 获取生成的图像""" + try: + prompt_response = self.queue_prompt(workflow, comfyui_server_address, client_id) + prompt_id = prompt_response['prompt_id'] + except KeyError: + return {} + + output_images = {} + current_node = "" + while True: + out = ws.recv() + if isinstance(out, str): + message = json.loads(out) + if message['type'] == 'executing': + data = message['data'] + if data.get('prompt_id') == prompt_id: + if data['node'] is None: + break + else: + current_node = data['node'] + else: + if current_node == '16': # 放大图像节点 + images_output = output_images.get(current_node, []) + images_output.append(out[8:]) + output_images[current_node] = images_output + + return output_images + + def upscale_image_sync(self, image_url: str, config: Optional[Dict] = None) -> Generator[str, None, None]: + """放大图像,保存到本地并返回图片URL""" + cfg = default_config.copy() + if config: + cfg.update(config) + ws = websocket.WebSocket() + client_id = str(uuid.uuid4()) + temp_file = None + alpha_temp_file = None + has_transparency = False + + try: + # 准备图像用于放大 + image_path, has_transparency, alpha_path = self.prepare_image_for_upscale(image_url) + if image_path != image_url: + temp_file = image_path + image_url = image_path + if has_transparency: + alpha_temp_file = alpha_path + + ws.connect(f"ws://{cfg['comfyui_server_address']}/ws?clientId={client_id}") + workflow = json.loads(workflow_template) + workflow["13"]["inputs"]["model_name"] = cfg['upscale_model_name'] + workflow["15"]["inputs"]["url_or_path"] = image_url + images_dict = self.get_images(ws, workflow, cfg['comfyui_server_address'], client_id) + os.makedirs(save_dir, exist_ok=True) + + for node_id, image_list in images_dict.items(): + for image_data in image_list: + if has_transparency: + # 处理带透明通道的图像 + # 保存放大后的RGB图像 + upscaled_rgb_path = os.path.join(save_dir, f"upscaled_rgb_{uuid.uuid4().hex[:10]}.png") + with open(upscaled_rgb_path, "wb") as f: + f.write(image_data) + + # 打开放大后的RGB图像 + upscaled_rgb = Image.open(upscaled_rgb_path) + + # 放大Alpha通道 + upscaled_alpha = self.upscale_alpha_channel(alpha_temp_file, + scale_factor=upscaled_rgb.width//Image.open(temp_file).width) + + # 确保尺寸匹配 + if upscaled_rgb.size != upscaled_alpha.size: + upscaled_alpha = upscaled_alpha.resize(upscaled_rgb.size, Image.BILINEAR) + + # 合并通道 + upscaled_rgba = upscaled_rgb.copy() + upscaled_rgba.putalpha(upscaled_alpha) + + # 保存最终的RGBA图像 + png_filename = f"upscaled_{uuid.uuid4().hex[:10]}.png" + png_file_path = os.path.join(save_dir, png_filename) + upscaled_rgba.save(png_file_path, "PNG") + + # 删除临时RGB文件 + os.remove(upscaled_rgb_path) + + # 返回PNG URL + image_url = f"{download_url}/{png_filename}" + else: + # 处理没有透明通道的图像 + # 保存为JPG以减小文件大小 + png_filename = f"upscaled_{uuid.uuid4().hex[:10]}.png" + png_file_path = os.path.join(save_dir, png_filename) + with open(png_file_path, "wb") as f: + f.write(image_data) + + # 打开图像并转换为JPG + img = Image.open(png_file_path) + jpg_filename = png_filename.replace('.png', '.jpg') + jpg_file_path = os.path.join(save_dir, jpg_filename) + img = img.convert('RGB') + img.save(jpg_file_path, 'JPEG', quality=95) + + # 删除PNG临时文件 + os.remove(png_file_path) + + # 返回JPG URL + image_url = f"{download_url}/{jpg_filename}" + + yield image_url + except Exception as e: + raise e + finally: + if ws: + ws.close() + # 清理临时文件 + if temp_file and os.path.exists(temp_file): + os.unlink(temp_file) + if alpha_temp_file and os.path.exists(alpha_temp_file): + os.unlink(alpha_temp_file) + + async def upscale_image(self, image_url: str, config: Optional[Dict] = None) -> AsyncGenerator[Dict, None]: + """异步放大图像,返回图片URL""" + try: + # 在这种情况下,我们需要手动运行同步生成器并收集结果 + urls = [] + + # 在执行器中运行同步代码 + def run_sync(): + return list(self.upscale_image_sync(image_url, config)) + + # 获取所有URL + loop = asyncio.get_event_loop() + urls = await loop.run_in_executor(None, run_sync) + + # 逐个返回结果 + for url in urls: + yield { + "status": "success", + "image_url": url, + "message": "图片已保存" + } + except Exception as e: + yield { + "status": "error", + "message": f"图像放大失败: {str(e)}" + } + + async def process_batch(self, image_urls: List[str], config: Optional[Dict] = None): + """批量处理多个图像URL,返回图片URL""" + total = len(image_urls) + success_count = 0 + error_count = 0 + for i, image_url in enumerate(image_urls, 1): + try: + # 获取图片透明度信息 + try: + _, has_transparency = self.check_image_transparency(image_url) + transparency_info = "PNG带透明通道" if has_transparency else "无透明通道" + except: + transparency_info = "未检测" + + async for result in self.upscale_image(image_url, config): + if result["status"] == "success": + success_count += 1 + yield { + "index": i, + "total": total, + "original_image_url": image_url, + "status": "success", + "image_url": result["image_url"], + "success_count": success_count, + "error_count": error_count, + "transparency": transparency_info, + "message": result["message"] + } + else: + error_count += 1 + yield { + "index": i, + "total": total, + "original_image_url": image_url, + "status": "error", + "success_count": success_count, + "error_count": error_count, + "message": result["message"] + } + except Exception as e: + error_count += 1 + yield { + "index": i, + "total": total, + "original_image_url": image_url, + "status": "error", + "success_count": success_count, + "error_count": error_count, + "message": f"处理图像时出错: {str(e)}" + } + diff --git a/apps/jupscale/settings.py b/apps/jupscale/settings.py new file mode 100644 index 0000000..b8bb01d --- /dev/null +++ b/apps/jupscale/settings.py @@ -0,0 +1,39 @@ +from pydantic_settings import BaseSettings +from typing import Optional +from functools import lru_cache + +class Settings(BaseSettings): + # Japi Server 配置 + host: str = "0.0.0.0" + port: int = 8109 + debug: bool = False + + # API路由配置 + router_prefix: str = "/jupscale" + upscale_route: str = "/upscale" # 放大图片的路由 + batch_route: str = "/batch" # 批量放大图片的路由 + api_name: str = "jupscale" # 默认API名称 + save_dir: str = "../jfile/files" + # Japi 静态资源下载URL + download_url: str = "http://api.jingrow.com:9080/files" + + # 中转图床服务上传URL + upload_url: str = "http://173.255.202.68/imgurl/upload" + + # Jingrow Jcloud API 配置 + jingrow_api_url: str = "https://cloud.jingrow.com" + jingrow_api_key: Optional[str] = None + jingrow_api_secret: Optional[str] = None + + # Stable Diffusion配置 + comfyui_server_address: str = "192.168.2.200:8188" + + class Config: + env_file = ".env" + +@lru_cache() +def get_settings() -> Settings: + return Settings() + +# 创建全局配置实例 +settings = get_settings() \ No newline at end of file diff --git a/apps/jupscale/utils.py b/apps/jupscale/utils.py new file mode 100644 index 0000000..8592041 --- /dev/null +++ b/apps/jupscale/utils.py @@ -0,0 +1,146 @@ +import aiohttp +from functools import wraps +from fastapi import HTTPException +import os +from typing import Callable, Any, Dict, Optional, Tuple +from settings import settings +from fastapi.responses import StreamingResponse +import json + +async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]: + """验证API密钥和团队余额""" + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f"{settings.jingrow_api_url}/api/method/jcloud.api.account.verify_api_credentials_and_balance", + headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"}, + json={"api_key": api_key, "api_secret": api_secret, "api_name": api_name} + ) as response: + if response.status != 200: + raise HTTPException(status_code=500, detail="验证服务暂时不可用") + + result = await response.json() + if "message" in result and isinstance(result["message"], dict): + result = result["message"] + + if not result.get("success"): + raise HTTPException(status_code=401, detail=result.get("message", "验证失败")) + + return result + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}") + +async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]: + """从Jingrow平台扣除API使用费""" + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f"{settings.jingrow_api_url}/api/method/jcloud.api.account.deduct_api_usage_fee", + headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"}, + json={ + "api_key": api_key, + "api_secret": api_secret, + "api_name": api_name, + "usage_count": usage_count + } + ) as response: + if response.status != 200: + raise HTTPException(status_code=500, detail="扣费服务暂时不可用") + + result = await response.json() + if "message" in result and isinstance(result["message"], dict): + result = result["message"] + + return result + + except HTTPException: + raise + except Exception as e: + return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"} + +def get_token_from_request(request) -> str: + """从请求中获取访问令牌""" + if not request: + raise HTTPException(status_code=400, detail="无法获取请求信息") + + auth_header = request.headers.get("Authorization", "") + if not auth_header or not auth_header.startswith("token "): + raise HTTPException(status_code=401, detail="无效的Authorization头格式") + + token = auth_header[6:] + if ":" not in token: + raise HTTPException(status_code=401, detail="无效的令牌格式") + + return token + +def jingrow_api_verify_and_billing(api_name: str): + """Jingrow API 验证装饰器(带余额检查和扣费)""" + def decorator(func: Callable) -> Callable: + @wraps(func) + async def wrapper(*args, **kwargs): + try: + request = kwargs.get('request') + if not request: + raise HTTPException(status_code=400, detail="无法获取请求信息") + + token = get_token_from_request(request) + api_key, api_secret = token.split(":", 1) + + verify_result = await verify_api_credentials_and_balance(api_key, api_secret, api_name) + if not verify_result.get("success"): + raise HTTPException(status_code=401, detail=verify_result.get("message", "验证失败")) + + result = await func(*args, **kwargs) + + usage_count = 1 + try: + body_data = await request.json() + if isinstance(body_data, dict): + for key in ["items", "urls", "images", "files"]: + if key in body_data and isinstance(body_data[key], list): + usage_count = len(body_data[key]) + break + except Exception: + pass + + if isinstance(result, StreamingResponse): + original_generator = result.body_iterator + success_count = 0 + + async def wrapped_generator(): + nonlocal success_count + async for chunk in original_generator: + try: + data = json.loads(chunk) + if isinstance(data, dict) and data.get("status") == "success": + success_count += 1 + except: + pass + yield chunk + + if success_count > 0: + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count) + + return StreamingResponse( + wrapped_generator(), + media_type=result.media_type, + headers=result.headers + ) + + if isinstance(result, dict) and result.get("success") is True: + actual_usage_count = result.get("successful_count", usage_count) + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count) + return result + + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count) + return result + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"API验证过程发生异常: {str(e)}") + return wrapper + return decorator diff --git a/apps/jvector/__init__.py b/apps/jvector/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/jvector/api.py b/apps/jvector/api.py new file mode 100644 index 0000000..3b30bc3 --- /dev/null +++ b/apps/jvector/api.py @@ -0,0 +1,53 @@ +from fastapi import APIRouter, UploadFile, File, HTTPException, Request +from fastapi.responses import StreamingResponse +from service import JvectorService +from utils import jingrow_api_verify_and_billing +from settings import settings +import json +import asyncio + +router = APIRouter(prefix=settings.router_prefix) +service = JvectorService() + + +@router.post(settings.file_route) +@jingrow_api_verify_and_billing(api_name=settings.api_name) +async def vectorize_image_file(file: UploadFile = File(...), request: Request = None): + """ + 将上传的文件转换为矢量图 + + Args: + file: 上传的图片文件 + request: FastAPI 请求对象 + + Returns: + 处理后的矢量图内容 + """ + content = await file.read() + result = await service.vectorize_from_file(content) + return result + +@router.post(settings.batch_route) +@jingrow_api_verify_and_billing(api_name=settings.api_name) +async def vectorize_image_batch(data: dict, request: Request): + """ + 批量处理多个URL图片转换为矢量图 + + Args: + data: 包含图片URL列表的字典 + request: FastAPI 请求对象 + + Returns: + 流式响应,包含每个图片的处理结果 + """ + if "urls" not in data: + raise HTTPException(status_code=400, detail="缺少urls参数") + + async def process_and_stream(): + async for result in service.process_batch(data["urls"]): + yield json.dumps(result) + "\n" + + return StreamingResponse( + process_and_stream(), + media_type="application/x-ndjson" + ) diff --git a/apps/jvector/app.py b/apps/jvector/app.py new file mode 100644 index 0000000..91986f8 --- /dev/null +++ b/apps/jvector/app.py @@ -0,0 +1,21 @@ +from fastapi import FastAPI +from settings import settings +from api import router + +app = FastAPI( + title="Jvector", + description="Jvector转矢量图API", + version="1.0.0" +) + +# 注册路由 +app.include_router(router) + +if __name__ == "__main__": + import uvicorn + uvicorn.run( + "app:app", + host=settings.host, + port=settings.port, + reload=settings.debug + ) \ No newline at end of file diff --git a/apps/jvector/service.py b/apps/jvector/service.py new file mode 100644 index 0000000..ff59d15 --- /dev/null +++ b/apps/jvector/service.py @@ -0,0 +1,228 @@ +import io +import os +import requests +import traceback +import tempfile +import base64 +from urllib.parse import urlparse +from pathlib import Path +from PIL import Image +import asyncio +from settings import settings + +class JvectorService: + def __init__(self): + """初始化矢量图转换服务""" + # 获取配置变量 + self.upload_url = settings.upload_url + self.vector_api_id = settings.vector_api_id + self.vector_api_secret = settings.vector_api_secret + self.vector_mode = settings.vector_mode + + def _get_config(self, key): + """获取配置值,从环境变量读取""" + if key == "upload_url": + return settings.upload_url + + # 其他配置项的处理方式 + config_map = {} + return config_map.get(key, "") + + def upload_image_to_intermediate_server(self, image_url): + """上传图片到中转服务器的函数""" + try: + response = requests.get(image_url, verify=False) + response.raise_for_status() + image_data = response.content + + parsed_url = urlparse(image_url) + file_name = Path(parsed_url.path).name + file_ext = Path(file_name).suffix + + # 如果图片是webp格式,转换为png格式 + if file_ext.lower() == '.webp': + image = Image.open(io.BytesIO(image_data)) + png_buffer = io.BytesIO() + image.save(png_buffer, format='PNG') + image_data = png_buffer.getvalue() + file_name = file_name.replace('.webp', '.png') + + files = {"file": (file_name, image_data)} + + upload_response = requests.post(self.upload_url, files=files, verify=False) + + if upload_response.status_code == 200: + return upload_response.json()["file_url"] + else: + error_msg = f"上传失败. 状态码: {upload_response.status_code}, {upload_response.text}" + print(error_msg) + raise Exception(error_msg) + + except Exception as e: + error_msg = f"上传图像到中间服务器失败: {str(e)}, URL: {image_url}" + print(error_msg) + traceback.print_exc() + raise Exception(error_msg) + + def convert_image_to_vector(self, image_url): + """将图片转换为矢量图的函数""" + try: + url = "https://vectorizer.ai/api/v1/vectorize" + data = { + 'image.url': image_url, + 'mode': self.vector_mode + } + auth = (self.vector_api_id, self.vector_api_secret) + response = requests.post(url, data=data, auth=auth) + response.raise_for_status() + return response.content + except Exception as e: + error_msg = f"转换图像为矢量图失败: {str(e)}, URL: {image_url}" + print(error_msg) + traceback.print_exc() + raise Exception(error_msg) + + def svg_to_base64(self, svg_content): + """将SVG内容转换为base64字符串""" + return base64.b64encode(svg_content).decode('utf-8') + + async def vectorize_image(self, image_url): + """ + 将图片转换为矢量图 + + Args: + image_url: 输入图像的URL + + Returns: + 处理后的矢量图内容 + """ + try: + # 转换为矢量图 + vector_content = self.convert_image_to_vector(image_url) + + # 转换为base64 + svg_content = self.svg_to_base64(vector_content) + + return { + "status": "success", + "svg_content": svg_content + } + + except Exception as e: + raise Exception(f"矢量图转换失败: {str(e)}") + + async def vectorize_from_file(self, file_content): + """ + 从上传的文件内容创建矢量图 + + Args: + file_content: 上传的文件内容 + + Returns: + 处理后的矢量图内容 + """ + temp_file = None + try: + # 创建临时文件 + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png') + with open(temp_file.name, 'wb') as f: + f.write(file_content) + + # 上传到中转服务器 + with open(temp_file.name, 'rb') as f: + files = {"file": ("image.png", f)} + upload_response = requests.post(self.upload_url, files=files, verify=False) + + if upload_response.status_code == 200: + intermediate_url = upload_response.json()["file_url"] + else: + raise Exception(f"上传失败. 状态码: {upload_response.status_code}") + + # 转换为矢量图 + vector_content = self.convert_image_to_vector(intermediate_url) + + # 转换为base64 + svg_content = self.svg_to_base64(vector_content) + + return { + "status": "success", + "svg_content": svg_content + } + + except Exception as e: + raise Exception(f"处理文件失败: {str(e)}") + + finally: + # 清理临时文件 + if temp_file and os.path.exists(temp_file.name): + try: + os.unlink(temp_file.name) + except: + pass + + async def process_batch(self, urls): + """ + 批量处理多个URL图像,流式返回结果 + + Args: + urls: 图片URL列表 + + Yields: + 每个图片的处理结果 + """ + total = len(urls) + success_count = 0 + error_count = 0 + + for i, url in enumerate(urls, 1): + try: + url_str = str(url) + result = await self.vectorize_image(url_str) + success_count += 1 + + # 确保返回正确的数据格式 + yield { + "index": i, + "total": total, + "original_url": url_str, + "status": "success", + "svg_content": result["svg_content"], + "success_count": success_count, + "error_count": error_count, + "message": "处理成功" + } + + except Exception as e: + error_count += 1 + yield { + "index": i, + "total": total, + "original_url": str(url), + "status": "error", + "error": str(e), + "success_count": success_count, + "error_count": error_count, + "message": f"处理失败: {str(e)}" + } + + # 让出控制权,避免阻塞 + await asyncio.sleep(0) + + def is_valid_url(self, url): + """验证URL是否有效""" + try: + result = urlparse(url) + return all([result.scheme, result.netloc]) + except: + return False + + def download_file(self, url, filename): + """下载文件到本地""" + response = requests.get(url, stream=True) + response.raise_for_status() + + with open(filename, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + + return filename diff --git a/apps/jvector/settings.py b/apps/jvector/settings.py new file mode 100644 index 0000000..02e8ac7 --- /dev/null +++ b/apps/jvector/settings.py @@ -0,0 +1,41 @@ +from pydantic_settings import BaseSettings +from typing import Optional +from functools import lru_cache + +class Settings(BaseSettings): + # Japi Server 配置 + host: str = "0.0.0.0" + port: int = 8110 + debug: bool = False + + # API路由配置 + router_prefix: str = "/jvector" + file_route: str = "/file" # 转矢量图的路由 + batch_route: str = "/batch" # 批量转矢量图的路由 + api_name: str = "jvector" # 默认API名称 + save_dir: str = "../jfile/files" + # Japi 静态资源下载URL + download_url: str = "http://api.jingrow.com:9080/files" + + # 中转图床服务上传URL + upload_url: str = "http://173.255.202.68/imgurl/upload" + + # Jingrow Jcloud API 配置 + jingrow_api_url: str = "https://cloud.jingrow.com" + jingrow_api_key: Optional[str] = None + jingrow_api_secret: Optional[str] = None + + # 矢量图转换服务配置 + vector_api_id: Optional[str] = None + vector_api_secret: Optional[str] = None + vector_mode: str = "production" # 'test' 或 'production' 或 'preview' + + class Config: + env_file = ".env" + +@lru_cache() +def get_settings() -> Settings: + return Settings() + +# 创建全局配置实例 +settings = get_settings() \ No newline at end of file diff --git a/apps/jvector/utils.py b/apps/jvector/utils.py new file mode 100644 index 0000000..8592041 --- /dev/null +++ b/apps/jvector/utils.py @@ -0,0 +1,146 @@ +import aiohttp +from functools import wraps +from fastapi import HTTPException +import os +from typing import Callable, Any, Dict, Optional, Tuple +from settings import settings +from fastapi.responses import StreamingResponse +import json + +async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]: + """验证API密钥和团队余额""" + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f"{settings.jingrow_api_url}/api/method/jcloud.api.account.verify_api_credentials_and_balance", + headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"}, + json={"api_key": api_key, "api_secret": api_secret, "api_name": api_name} + ) as response: + if response.status != 200: + raise HTTPException(status_code=500, detail="验证服务暂时不可用") + + result = await response.json() + if "message" in result and isinstance(result["message"], dict): + result = result["message"] + + if not result.get("success"): + raise HTTPException(status_code=401, detail=result.get("message", "验证失败")) + + return result + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}") + +async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]: + """从Jingrow平台扣除API使用费""" + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f"{settings.jingrow_api_url}/api/method/jcloud.api.account.deduct_api_usage_fee", + headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"}, + json={ + "api_key": api_key, + "api_secret": api_secret, + "api_name": api_name, + "usage_count": usage_count + } + ) as response: + if response.status != 200: + raise HTTPException(status_code=500, detail="扣费服务暂时不可用") + + result = await response.json() + if "message" in result and isinstance(result["message"], dict): + result = result["message"] + + return result + + except HTTPException: + raise + except Exception as e: + return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"} + +def get_token_from_request(request) -> str: + """从请求中获取访问令牌""" + if not request: + raise HTTPException(status_code=400, detail="无法获取请求信息") + + auth_header = request.headers.get("Authorization", "") + if not auth_header or not auth_header.startswith("token "): + raise HTTPException(status_code=401, detail="无效的Authorization头格式") + + token = auth_header[6:] + if ":" not in token: + raise HTTPException(status_code=401, detail="无效的令牌格式") + + return token + +def jingrow_api_verify_and_billing(api_name: str): + """Jingrow API 验证装饰器(带余额检查和扣费)""" + def decorator(func: Callable) -> Callable: + @wraps(func) + async def wrapper(*args, **kwargs): + try: + request = kwargs.get('request') + if not request: + raise HTTPException(status_code=400, detail="无法获取请求信息") + + token = get_token_from_request(request) + api_key, api_secret = token.split(":", 1) + + verify_result = await verify_api_credentials_and_balance(api_key, api_secret, api_name) + if not verify_result.get("success"): + raise HTTPException(status_code=401, detail=verify_result.get("message", "验证失败")) + + result = await func(*args, **kwargs) + + usage_count = 1 + try: + body_data = await request.json() + if isinstance(body_data, dict): + for key in ["items", "urls", "images", "files"]: + if key in body_data and isinstance(body_data[key], list): + usage_count = len(body_data[key]) + break + except Exception: + pass + + if isinstance(result, StreamingResponse): + original_generator = result.body_iterator + success_count = 0 + + async def wrapped_generator(): + nonlocal success_count + async for chunk in original_generator: + try: + data = json.loads(chunk) + if isinstance(data, dict) and data.get("status") == "success": + success_count += 1 + except: + pass + yield chunk + + if success_count > 0: + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count) + + return StreamingResponse( + wrapped_generator(), + media_type=result.media_type, + headers=result.headers + ) + + if isinstance(result, dict) and result.get("success") is True: + actual_usage_count = result.get("successful_count", usage_count) + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count) + return result + + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count) + return result + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"API验证过程发生异常: {str(e)}") + return wrapper + return decorator diff --git a/apps/ptn_to_tshirt/__init__.py b/apps/ptn_to_tshirt/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/ptn_to_tshirt/api.py b/apps/ptn_to_tshirt/api.py new file mode 100644 index 0000000..a601e78 --- /dev/null +++ b/apps/ptn_to_tshirt/api.py @@ -0,0 +1,105 @@ +from fastapi import APIRouter, UploadFile, File, Form, HTTPException, Request +from fastapi.responses import StreamingResponse, JSONResponse +from service import PtnToTshirtService +from utils import jingrow_api_verify_and_billing +from settings import settings +import json +import asyncio +from typing import List, Optional +import io + +router = APIRouter(prefix=settings.router_prefix) +service = PtnToTshirtService() + + +@router.post(settings.batch_route) +@jingrow_api_verify_and_billing(api_name=settings.api_name) +async def pattern_to_tshirt_batch(data: dict, request: Request): + """ + 批量处理多个URL花型图片 + + Args: + data: 包含花型图片URL列表和配置参数的字典 + request: FastAPI 请求对象 + + Returns: + 流式响应,包含每个图片的处理结果 + """ + if "urls" not in data: + raise HTTPException(status_code=400, detail="缺少urls参数") + + config = data.get("config", {}) + + # 支持传入T恤图片URL列表 + if "tshirt_urls" in data and isinstance(data["tshirt_urls"], list): + if not config: + config = {} + config["tshirt_urls"] = data["tshirt_urls"] + + async def process_and_stream(): + total = len(data["urls"]) + for index, url in enumerate(data["urls"], 1): + try: + result = await service.pattern_to_tshirt(url, config) + result.update({ + "index": index, + "total": total, + "original_url": url + }) + yield json.dumps(result) + "\n" + except Exception as e: + yield json.dumps({ + "status": "error", + "message": str(e), + "index": index, + "total": total, + "original_url": url + }) + "\n" + + try: + return StreamingResponse( + process_and_stream(), + media_type="application/x-ndjson" + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +@router.post(settings.file_route) +@jingrow_api_verify_and_billing(api_name=settings.api_name) +async def pattern_to_tshirt_file(file: UploadFile = File(...), config: str = Form("{}"), tshirt_urls: str = None, request: Request = None): + """ + 将上传的花型文件添加到T恤上 + + Args: + file: 上传的花型图片文件 + config: JSON格式的配置参数 + tshirt_urls: JSON格式的T恤图片URL列表 + request: FastAPI 请求对象 + + Returns: + 处理后的T恤图片内容 + """ + content = await file.read() + + # 解析配置参数 + config_dict = {} + if config: + try: + config_dict = json.loads(config) + except json.JSONDecodeError: + raise HTTPException(status_code=400, detail="配置参数格式错误") + + # 解析T恤图片URL列表 + if tshirt_urls: + try: + urls_list = json.loads(tshirt_urls) + if isinstance(urls_list, list): + config_dict["tshirt_urls"] = urls_list + except json.JSONDecodeError: + raise HTTPException(status_code=400, detail="T恤图片URL列表格式错误") + + try: + result = await service.pattern_to_tshirt_from_file(content, config_dict) + return result + except Exception as e: + raise HTTPException(status_code=500, detail=f"处理图像失败: {str(e)}") diff --git a/apps/ptn_to_tshirt/app.py b/apps/ptn_to_tshirt/app.py new file mode 100644 index 0000000..1595389 --- /dev/null +++ b/apps/ptn_to_tshirt/app.py @@ -0,0 +1,21 @@ +from fastapi import FastAPI +from settings import settings +from api import router + +app = FastAPI( + title="Pattern to Tshirt", + description="将图片中的花型添加到T恤上", + version="1.0.0" +) + +# 注册路由 +app.include_router(router) + +if __name__ == "__main__": + import uvicorn + uvicorn.run( + "app:app", + host=settings.host, + port=settings.port, + reload=settings.debug + ) \ No newline at end of file diff --git a/apps/ptn_to_tshirt/service.py b/apps/ptn_to_tshirt/service.py new file mode 100644 index 0000000..7f7a76c --- /dev/null +++ b/apps/ptn_to_tshirt/service.py @@ -0,0 +1,523 @@ +# pattern_to_tshirt.py +import sys +import os +import json +import io +import cv2 +import numpy as np +from PIL import Image, ImageFilter, ImageDraw, ImageChops +import uuid +import urllib.request +import urllib3 +import requests +from pydantic import BaseModel +from typing import Optional +import base64 +import asyncio +import warnings +import tempfile +from urllib.parse import urlparse +import torch +import time +import gc + +# 关闭不必要的警告 +warnings.filterwarnings("ignore", category=UserWarning) +warnings.filterwarnings("ignore", category=FutureWarning) + +class PtnToTshirtService: + # 默认配置 + DEFAULT_CONFIG = { + 'background_removed_marker': "_rmbg", + 'ptt_exclude_markers': ["_upscaled", "_vector", "_processed", "_tshirt", "_tryon"], + 'tshirt_marker': "_tshirt", + 'processed_pattern_marker': "_processed", + 'tshirt_image_path': 'home/tshirt', + 'tshirt_urls': [], # 新增: T恤图片URL列表 + 'alpha': 1, # 透明度,0表示全透明,1表示全不透明 + 'ptt_design_size_ratio': 0.4, # 设计图像占T恤图像的比例 + 'ptt_design_offset': [0.5, 0.45], # 设计图像在T恤图像中的相对位置 [x, y] + 'ptt_design_rotation': 0, # 设计图案旋转角度 + + 'enable_gradient_effect': True, # 是否启用渐变效果 + 'gradient_width': 512, # 渐变宽度 + 'gradient_direction': 'outward', # 渐变方向: 'outward', 'inward' + 'gradient_type': 'linear', # 渐变类型: 'linear', 'radial' + 'gradient_max_alpha': 150, # 渐变的最大透明度值,0-255 + 'gradient_start_alpha': 0, # 渐变起始处的透明度,0-255 + 'gradient_color': [255, 255, 255, 255], # 渐变颜色 + 'gradient_blur_intensity': 10, # 渐变模糊强度 + 'gradient_center': [0.5, 0.5], # 渐变中心位置,相对于设计图案的 [x, y] + 'gradient_repeat_count': 1, # 渐变重复次数 + + 'ptt_enable_texture_effect': False, # 是否启用纹理效果 + 'ptt_texture_type': 'lines', # 纹理类型: 'noise', 'lines' + 'ptt_texture_blend_mode': 'multiply', # 纹理混合模式 + + 'enable_save_processed_design': True, # 是否单独保存处理后的设计图案 + 'ptt_design_output_format': 'png', # 设计图案保存格式: 'png' 或 'tiff' + + 'ptt_enable_color_matching': True, # 是否启用颜色匹配 + 'ptt_enable_lighting_effect': False, # 是否启用光效 + 'ptt_enable_monochrome': False, + 'ptt_light_intensity': 0.5, # 光照强度 + 'ptt_light_position': [0.5, 0.3], # 光源位置 [相对位置 x, y] + 'ptt_light_radius_ratio': [0.4, 0.25], # 光源半径相对比例 [宽, 高] + 'ptt_light_angle': 45, # 光源角度(度) + 'ptt_light_blur': 91, # 光源模糊程度 + 'ptt_light_shape': 'ellipse', # 光源形状: 'ellipse', 'circle', 'rect' + } + + def __init__(self): + """初始化图案到T恤服务""" + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"使用设备: {self.device}") + + if torch.cuda.is_available(): + print(f"GPU: {torch.cuda.get_device_name(0)}") + print(f"显存总量: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f}GB") + + def overlay_image_alpha(self, img, img_overlay, pos, alpha_mask): + """在图像上叠加另一个具有透明度的图像""" + x, y = pos + y1, y2 = max(0, y), min(img.shape[0], y + img_overlay.shape[0]) + x1, x2 = max(0, x), min(img.shape[1], x + img_overlay.shape[1]) + y1o, y2o = max(0, -y), min(img_overlay.shape[0], img.shape[0] - y) + x1o, x2o = max(0, -x), min(img_overlay.shape[1], img.shape[1] - x) + + if y1 >= y2 or x1 >= x2 or y1o >= y2o or x1o >= x2o: + return + + img_crop = img[y1:y2, x1:x2] + img_overlay_crop = img_overlay[y1o:y2o, x1o:x2o] + alpha = alpha_mask[y1o:y2o, x1o:x2o, np.newaxis] + img_crop[:] = alpha * img_overlay_crop + (1 - alpha) * img_crop + + def color_transfer(self, source, target): + """颜色匹配:将源图像的颜色转换为目标图像的颜色风格""" + source = cv2.cvtColor(source, cv2.COLOR_BGR2LAB) + target = cv2.cvtColor(target, cv2.COLOR_BGR2LAB) + + src_mean, src_std = cv2.meanStdDev(source) + tgt_mean, tgt_std = cv2.meanStdDev(target) + + src_mean = src_mean.reshape(1, 1, 3) + src_std = src_std.reshape(1, 1, 3) + tgt_mean = tgt_mean.reshape(1, 1, 3) + tgt_std = tgt_std.reshape(1, 1, 3) + + result = (source - src_mean) * (tgt_std / src_std) + tgt_mean + result = np.clip(result, 0, 255) + result = result.astype(np.uint8) + + return cv2.cvtColor(result, cv2.COLOR_LAB2BGR) + + def apply_lighting_effect(self, image, light_intensity=0.5, light_position=[0.5, 0.3], + light_radius_ratio=[0.4, 0.25], light_angle=45, light_blur=91, light_shape='ellipse'): + """应用光照效果到图像""" + height, width = image.shape[:2] + light_position = (int(light_position[0] * width), int(light_position[1] * height)) + light_radius = (int(light_radius_ratio[0] * width), int(light_radius_ratio[1] * height)) + mask = np.zeros((height, width), dtype=np.uint8) + + if light_shape == 'ellipse': + cv2.ellipse(mask, light_position, light_radius, light_angle, 0, 360, 255, -1) + elif light_shape == 'circle': + cv2.circle(mask, light_position, min(light_radius), 255, -1) + elif light_shape == 'rect': + rect_top_left = (light_position[0] - light_radius[0] // 2, light_position[1] - light_radius[1] // 2) + rect_bottom_right = (light_position[0] + light_radius[0] // 2, light_position[1] + light_radius[1] // 2) + cv2.rectangle(mask, rect_top_left, rect_bottom_right, 255, -1) + + mask = cv2.GaussianBlur(mask, (light_blur, light_blur), 0) + mask = mask.astype(np.float32) / 255 + result = image.astype(np.float32) + for i in range(3): + result[:, :, i] = result[:, :, i] * (1 - light_intensity + mask * light_intensity) + return result.astype(np.uint8) + + def apply_monochrome(self, image): + """将图像转换为单色""" + gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + monochrome_image = cv2.cvtColor(gray_image, cv2.COLOR_GRAY2BGR) + return monochrome_image + + def enhance_design(self, design_image, tshirt_image, config): + """增强设计图像,应用各种效果""" + if config.get('ptt_enable_color_matching', self.DEFAULT_CONFIG['ptt_enable_color_matching']): + design_image = self.color_transfer(design_image, tshirt_image) + + if config.get('ptt_enable_lighting_effect', self.DEFAULT_CONFIG['ptt_enable_lighting_effect']): + design_image = self.apply_lighting_effect( + design_image, + config.get('ptt_light_intensity', self.DEFAULT_CONFIG['ptt_light_intensity']), + config.get('ptt_light_position', self.DEFAULT_CONFIG['ptt_light_position']), + config.get('ptt_light_radius_ratio', self.DEFAULT_CONFIG['ptt_light_radius_ratio']), + config.get('ptt_light_angle', self.DEFAULT_CONFIG['ptt_light_angle']), + config.get('ptt_light_blur', self.DEFAULT_CONFIG['ptt_light_blur']), + config.get('ptt_light_shape', self.DEFAULT_CONFIG['ptt_light_shape']) + ) + + if config.get('ptt_enable_monochrome', self.DEFAULT_CONFIG['ptt_enable_monochrome']): + design_image = self.apply_monochrome(design_image) + + return design_image + + def add_edge_gradient(self, image, gradient_width, gradient_direction, gradient_type, + gradient_max_alpha, gradient_start_alpha, gradient_color, + gradient_blur_intensity, gradient_center): + """添加边缘渐变效果""" + alpha = image.getchannel('A') + width, height = alpha.size + mask = Image.new('L', (width, height), 0) + draw = ImageDraw.Draw(mask) + + # 确保渐变宽度不超过图像尺寸的一半 + gradient_width = min(gradient_width, width // 2, height // 2) + + if gradient_type == 'linear': + if gradient_direction == 'outward': + for i in range(gradient_width): + if i >= width // 2 or i >= height // 2: + break # 避免无效的矩形坐标 + fill_value = int(gradient_start_alpha + (gradient_max_alpha - gradient_start_alpha) * (i / gradient_width)) + draw.rectangle([i, i, width - i - 1, height - i - 1], fill=fill_value) + elif gradient_direction == 'inward': + for i in range(gradient_width): + if i >= width // 2 or i >= height // 2: + break # 避免无效的矩形坐标 + fill_value = int(gradient_start_alpha + (gradient_max_alpha - gradient_start_alpha) * ((gradient_width - i) / gradient_width)) + draw.rectangle([i, i, width - i - 1, height - i - 1], fill=fill_value) + elif gradient_type == 'radial': + center_x = int(width * gradient_center[0]) + center_y = int(height * gradient_center[1]) + max_radius = min(center_x, center_y, width - center_x, height - center_y) + for i in range(gradient_width): + radius = max_radius * (i / gradient_width) + if radius <= 0: + continue + fill_value = int(gradient_start_alpha + (gradient_max_alpha - gradient_start_alpha) * (i / gradient_width)) + draw.ellipse([center_x - radius, center_y - radius, center_x + radius, center_y + radius], fill=fill_value) + + if gradient_blur_intensity < 1: + gradient_blur_intensity = 1 + mask = mask.filter(ImageFilter.GaussianBlur(gradient_blur_intensity)) + alpha = ImageChops.multiply(alpha, mask) + image.putalpha(alpha) + + if gradient_color != [255, 255, 255, 255]: + colored_mask = Image.new('RGBA', image.size, tuple(gradient_color)) + colored_mask.putalpha(mask) + image = Image.alpha_composite(image, colored_mask) + + return image + + def add_gradient_repeat(self, image, gradient_repeat_count, *args, **kwargs): + """重复应用渐变效果""" + for _ in range(max(gradient_repeat_count, 1)): # 确保至少执行一次 + image = self.add_edge_gradient(image, *args, **kwargs) + return image + + def generate_noise_texture(self, size, intensity=64): + """生成噪点纹理""" + noise = np.random.randint(0, intensity, (size, size, 4), dtype=np.uint8) + noise[..., 3] = 255 # 设置 alpha 通道为不透明 + return Image.fromarray(noise) + + def generate_line_texture(self, size, line_width=4, spacing=20, color=(0, 0, 0, 255)): + """生成线条纹理""" + texture = Image.new('RGBA', (size, size), (255, 255, 255, 0)) + draw = ImageDraw.Draw(texture) + for y in range(0, size, spacing): + draw.line([(0, y), (size, y)], fill=color, width=line_width) + for x in range(0, size, spacing): + draw.line([(x, 0), (x, size)], fill=color, width=line_width) + return texture + + def add_texture(self, image, texture_type, texture_blend_mode): + """添加纹理效果到图像""" + if texture_type == 'noise': + texture = self.generate_noise_texture(image.size[0]) + elif texture_type == 'lines': + texture = self.generate_line_texture(image.size[0]) + else: + return image + + if texture_blend_mode == 'multiply': + return ImageChops.multiply(image, texture) + elif texture_blend_mode == 'overlay': + return ImageChops.overlay(image, texture) + else: + return image + + def rotate_image_with_transparency(self, image, angle): + """旋转带有透明度的图像""" + rotated_image = image.rotate(angle, expand=True) + return rotated_image + + def save_processed_design_image(self, design_image, output_format='png'): + """保存处理后的设计图像""" + try: + img_bytes = io.BytesIO() + # 确保使用包含透明背景的BGRA格式 + design_image_pil = Image.fromarray(cv2.cvtColor(design_image, cv2.COLOR_BGRA2RGBA)).convert('RGBA') + + if output_format == 'tiff': + design_image_pil.save(img_bytes, format='TIFF', save_all=True, compression='tiff_deflate') + else: + design_image_pil.save(img_bytes, format='PNG') + + img_bytes.seek(0) + return img_bytes + except Exception as e: + print(f"保存处理后的设计图像时发生错误: {e}") + return None + + def generate_tshirt_image(self, design_image, tshirt_image, config): + """将花型图案合成到T恤图像上""" + # 合并默认配置和用户配置 + config = {**self.DEFAULT_CONFIG, **config} + + # 将设计图像从RGBA转换为BGRA(如果需要) + if isinstance(design_image, np.ndarray) and design_image.shape[2] == 4: + if design_image.dtype != np.uint8: + design_image = design_image.astype(np.uint8) + else: + # 如果输入是PIL图像,转换为OpenCV格式 + if isinstance(design_image, Image.Image): + design_image = cv2.cvtColor(np.array(design_image), cv2.COLOR_RGBA2BGRA) + else: + raise ValueError("设计图像必须是PIL Image或带Alpha通道的NumPy数组") + + # 对设计图像应用渐变效果 + if config.get('enable_gradient_effect', self.DEFAULT_CONFIG['enable_gradient_effect']): + design_image_pil = Image.fromarray(cv2.cvtColor(design_image, cv2.COLOR_BGRA2RGBA)).convert("RGBA") + design_image_pil = self.add_gradient_repeat( + design_image_pil, + config.get('gradient_repeat_count', self.DEFAULT_CONFIG['gradient_repeat_count']), + config.get('gradient_width', self.DEFAULT_CONFIG['gradient_width']), + config.get('gradient_direction', self.DEFAULT_CONFIG['gradient_direction']), + config.get('gradient_type', self.DEFAULT_CONFIG['gradient_type']), + config.get('gradient_max_alpha', self.DEFAULT_CONFIG['gradient_max_alpha']), + config.get('gradient_start_alpha', self.DEFAULT_CONFIG['gradient_start_alpha']), + config.get('gradient_color', self.DEFAULT_CONFIG['gradient_color']), + config.get('gradient_blur_intensity', self.DEFAULT_CONFIG['gradient_blur_intensity']), + config.get('gradient_center', self.DEFAULT_CONFIG['gradient_center']) + ) + design_image = cv2.cvtColor(np.array(design_image_pil), cv2.COLOR_RGBA2BGRA) + + # 应用纹理效果到设计图案 + if config.get('ptt_enable_texture_effect', self.DEFAULT_CONFIG['ptt_enable_texture_effect']): + design_image_pil = Image.fromarray(cv2.cvtColor(design_image, cv2.COLOR_BGRA2RGBA)).convert("RGBA") + design_image_pil = self.add_texture( + design_image_pil, + config.get('ptt_texture_type', self.DEFAULT_CONFIG['ptt_texture_type']), + config.get('ptt_texture_blend_mode', self.DEFAULT_CONFIG['ptt_texture_blend_mode']) + ) + design_image = cv2.cvtColor(np.array(design_image_pil), cv2.COLOR_RGBA2BGRA) + + # 进行设计图像增强处理 + design_image_enhanced = self.enhance_design(design_image[:, :, :3], tshirt_image, config) + + # 应用旋转效果到设计图案 + ptt_design_rotation = config.get('ptt_design_rotation', self.DEFAULT_CONFIG['ptt_design_rotation']) + if ptt_design_rotation != 0: + design_image_pil = Image.fromarray(cv2.cvtColor(design_image, cv2.COLOR_BGRA2RGBA)).convert("RGBA") + design_image_pil = self.rotate_image_with_transparency(design_image_pil, ptt_design_rotation) + processed_design_image_with_alpha = cv2.cvtColor(np.array(design_image_pil), cv2.COLOR_RGBA2BGRA) + else: + processed_design_image_with_alpha = cv2.merge((design_image_enhanced, design_image[:, :, 3])) + + # 保存处理后的设计图像 + processed_design_io = None + if config.get('enable_save_processed_design', self.DEFAULT_CONFIG['enable_save_processed_design']): + processed_design_io = self.save_processed_design_image( + processed_design_image_with_alpha, + config.get('ptt_design_output_format', self.DEFAULT_CONFIG['ptt_design_output_format']) + ) + + # 调整设计图像大小 + tshirt_height, tshirt_width = tshirt_image.shape[:2] + design_width = int(tshirt_width * config.get('ptt_design_size_ratio', self.DEFAULT_CONFIG['ptt_design_size_ratio'])) + aspect_ratio = processed_design_image_with_alpha.shape[0] / processed_design_image_with_alpha.shape[1] + design_height = int(design_width * aspect_ratio) + design_image_resized = cv2.resize(processed_design_image_with_alpha, (design_width, design_height)) + + # 提取Alpha通道 + alpha_channel = design_image_resized[:, :, 3] / 255.0 + + # 计算设计图像在T恤上的位置 + ptt_design_offset = config.get('ptt_design_offset', self.DEFAULT_CONFIG['ptt_design_offset']) + design_position = ( + int((tshirt_width - design_width) * ptt_design_offset[0]), + int((tshirt_height - design_height) * ptt_design_offset[1]) + ) + + # 将设计图像叠加到T恤图像上 + result_image = tshirt_image.copy() + self.overlay_image_alpha(result_image, design_image_resized[:, :, :3], design_position, alpha_channel) + + # 返回结果图像和处理后的设计图像 + return result_image, processed_design_io + + def image_to_base64(self, image, format='png'): + """将图像转换为base64字符串""" + try: + if isinstance(image, np.ndarray): + # 如果是OpenCV图像(NumPy数组),转换为PIL图像 + if image.shape[2] == 3: + # BGR转RGB + image_pil = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + else: + # BGRA转RGBA + image_pil = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)) + else: + # 已经是PIL图像 + image_pil = image + + # 保存为BytesIO对象 + buffered = io.BytesIO() + image_pil.save(buffered, format=format.upper()) + img_str = base64.b64encode(buffered.getvalue()).decode() + return img_str + except Exception as e: + print(f"将图像转换为base64时发生错误: {e}") + return None + + def download_tshirt_images(self, config): + """下载T恤图像列表""" + try: + # 首先检查是否提供了T恤图片URL列表 + tshirt_urls = config.get('tshirt_urls', self.DEFAULT_CONFIG['tshirt_urls']) + if tshirt_urls and isinstance(tshirt_urls, list) and len(tshirt_urls) > 0: + tshirt_images = [] + for url in tshirt_urls: + if self.is_valid_url(url): + tshirt_io = self.download_image(url) + if tshirt_io: + tshirt_image = cv2.imdecode(np.frombuffer(tshirt_io.getvalue(), np.uint8), cv2.IMREAD_COLOR) + if tshirt_image is not None: + tshirt_images.append(tshirt_image) + + if tshirt_images: + return tshirt_images + + # 如果没有提供URL或URL下载失败,则尝试使用本地模板 + sample_tshirt_path = os.path.join(config.get('tshirt_image_path', self.DEFAULT_CONFIG['tshirt_image_path']), 'sample_tshirt.jpg') + if os.path.exists(sample_tshirt_path): + tshirt_image = cv2.imread(sample_tshirt_path) + return [tshirt_image] + else: + # 创建一个纯白色的示例T恤图像作为最后的备选 + tshirt_image = np.ones((800, 600, 3), dtype=np.uint8) * 255 + return [tshirt_image] + except Exception as e: + print(f"下载T恤图像时发生错误: {e}") + # 创建一个纯白色的示例T恤图像作为最后的备选 + tshirt_image = np.ones((800, 600, 3), dtype=np.uint8) * 255 + return [tshirt_image] + + def is_valid_url(self, url): + """检查URL是否有效""" + try: + result = urlparse(url) + return all([result.scheme, result.netloc]) + except: + return False + + def download_image(self, url): + """下载图像""" + try: + if self.is_valid_url(url): + response = requests.get(url, verify=False, timeout=10) + if response.status_code == 200: + return io.BytesIO(response.content) + return None + except Exception as e: + print(f"下载图像失败: {str(e)}") + return None + + async def pattern_to_tshirt(self, image_url, config=None): + """将花型图案添加到T恤上(URL输入)""" + if not config: + config = {} + + try: + # 下载花型图案 + design_io = self.download_image(image_url) + if not design_io: + return {"status": "error", "message": "无法下载图像"} + + return await self.pattern_to_tshirt_from_file(design_io.getvalue(), config) + except Exception as e: + import traceback + error_trace = traceback.format_exc() + print(f"处理图像时发生错误: {str(e)}\n{error_trace}") + return {"status": "error", "message": f"处理图像失败: {str(e)}"} + + async def pattern_to_tshirt_from_file(self, file_content, config=None): + """将花型图案添加到T恤上(文件输入)""" + if not config: + config = {} + + try: + # 加载花型图案 + design_io = io.BytesIO(file_content) + design_image = Image.open(design_io).convert("RGBA") + + # 获取T恤图像列表 + tshirt_images = self.download_tshirt_images(config) + if not tshirt_images: + return {"status": "error", "message": "无法获取T恤图像模板"} + + results = [] + processed_design_base64 = None + + # 处理每个T恤图像 + for tshirt_image in tshirt_images: + try: + # 生成合成图像 + result_image, processed_design_io = self.generate_tshirt_image(design_image, tshirt_image, config) + + # 转换为base64 + result_base64 = self.image_to_base64(result_image) + + # 如果有处理后的设计图像,也转换为base64 + if processed_design_io and processed_design_base64 is None: + processed_design_image = Image.open(processed_design_io) + processed_design_base64 = self.image_to_base64(processed_design_image) + + results.append({ + "tshirt_image": result_base64 + }) + except Exception as e: + import traceback + error_trace = traceback.format_exc() + print(f"处理单个T恤图像时发生错误: {str(e)}\n{error_trace}") + # 继续处理下一个T恤图像 + + if not results: + return {"status": "error", "message": "所有T恤图像处理均失败"} + + response = { + "status": "success", + "results": results + } + + # 如果有处理后的设计图像,添加到响应中 + if processed_design_base64: + response["processed_design"] = processed_design_base64 + + return response + + except Exception as e: + import traceback + error_trace = traceback.format_exc() + print(f"处理图像时发生错误: {str(e)}\n{error_trace}") + return {"status": "error", "message": f"处理图像失败: {str(e)}"} + + def cleanup(self): + """清理资源""" + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() diff --git a/apps/ptn_to_tshirt/settings.py b/apps/ptn_to_tshirt/settings.py new file mode 100644 index 0000000..b919f48 --- /dev/null +++ b/apps/ptn_to_tshirt/settings.py @@ -0,0 +1,32 @@ +from pydantic_settings import BaseSettings +from typing import Optional +from functools import lru_cache + +class Settings(BaseSettings): + # Japi Server 配置 + host: str = "0.0.0.0" + port: int = 8111 + debug: bool = False + + # API路由配置 + router_prefix: str = "/ptn_to_tshirt" + file_route: str = "/file" + batch_route: str = "/batch" + api_name: str = "ptn_to_tshirt" + + upload_url: str = "http://173.255.202.68/imgurl/upload" + + # Jingrow Jcloud API 配置 + jingrow_api_url: str = "https://cloud.jingrow.com" + jingrow_api_key: Optional[str] = None + jingrow_api_secret: Optional[str] = None + + class Config: + env_file = ".env" + +@lru_cache() +def get_settings() -> Settings: + return Settings() + +# 创建全局配置实例 +settings = get_settings() \ No newline at end of file diff --git a/apps/ptn_to_tshirt/utils.py b/apps/ptn_to_tshirt/utils.py new file mode 100644 index 0000000..8592041 --- /dev/null +++ b/apps/ptn_to_tshirt/utils.py @@ -0,0 +1,146 @@ +import aiohttp +from functools import wraps +from fastapi import HTTPException +import os +from typing import Callable, Any, Dict, Optional, Tuple +from settings import settings +from fastapi.responses import StreamingResponse +import json + +async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]: + """验证API密钥和团队余额""" + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f"{settings.jingrow_api_url}/api/method/jcloud.api.account.verify_api_credentials_and_balance", + headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"}, + json={"api_key": api_key, "api_secret": api_secret, "api_name": api_name} + ) as response: + if response.status != 200: + raise HTTPException(status_code=500, detail="验证服务暂时不可用") + + result = await response.json() + if "message" in result and isinstance(result["message"], dict): + result = result["message"] + + if not result.get("success"): + raise HTTPException(status_code=401, detail=result.get("message", "验证失败")) + + return result + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}") + +async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]: + """从Jingrow平台扣除API使用费""" + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f"{settings.jingrow_api_url}/api/method/jcloud.api.account.deduct_api_usage_fee", + headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"}, + json={ + "api_key": api_key, + "api_secret": api_secret, + "api_name": api_name, + "usage_count": usage_count + } + ) as response: + if response.status != 200: + raise HTTPException(status_code=500, detail="扣费服务暂时不可用") + + result = await response.json() + if "message" in result and isinstance(result["message"], dict): + result = result["message"] + + return result + + except HTTPException: + raise + except Exception as e: + return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"} + +def get_token_from_request(request) -> str: + """从请求中获取访问令牌""" + if not request: + raise HTTPException(status_code=400, detail="无法获取请求信息") + + auth_header = request.headers.get("Authorization", "") + if not auth_header or not auth_header.startswith("token "): + raise HTTPException(status_code=401, detail="无效的Authorization头格式") + + token = auth_header[6:] + if ":" not in token: + raise HTTPException(status_code=401, detail="无效的令牌格式") + + return token + +def jingrow_api_verify_and_billing(api_name: str): + """Jingrow API 验证装饰器(带余额检查和扣费)""" + def decorator(func: Callable) -> Callable: + @wraps(func) + async def wrapper(*args, **kwargs): + try: + request = kwargs.get('request') + if not request: + raise HTTPException(status_code=400, detail="无法获取请求信息") + + token = get_token_from_request(request) + api_key, api_secret = token.split(":", 1) + + verify_result = await verify_api_credentials_and_balance(api_key, api_secret, api_name) + if not verify_result.get("success"): + raise HTTPException(status_code=401, detail=verify_result.get("message", "验证失败")) + + result = await func(*args, **kwargs) + + usage_count = 1 + try: + body_data = await request.json() + if isinstance(body_data, dict): + for key in ["items", "urls", "images", "files"]: + if key in body_data and isinstance(body_data[key], list): + usage_count = len(body_data[key]) + break + except Exception: + pass + + if isinstance(result, StreamingResponse): + original_generator = result.body_iterator + success_count = 0 + + async def wrapped_generator(): + nonlocal success_count + async for chunk in original_generator: + try: + data = json.loads(chunk) + if isinstance(data, dict) and data.get("status") == "success": + success_count += 1 + except: + pass + yield chunk + + if success_count > 0: + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count) + + return StreamingResponse( + wrapped_generator(), + media_type=result.media_type, + headers=result.headers + ) + + if isinstance(result, dict) and result.get("success") is True: + actual_usage_count = result.get("successful_count", usage_count) + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count) + return result + + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count) + return result + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"API验证过程发生异常: {str(e)}") + return wrapper + return decorator diff --git a/apps/rmbg/__init__.py b/apps/rmbg/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/rmbg/api.py b/apps/rmbg/api.py new file mode 100644 index 0000000..b52be95 --- /dev/null +++ b/apps/rmbg/api.py @@ -0,0 +1,52 @@ +from fastapi import APIRouter, UploadFile, File, HTTPException, Request +from fastapi.responses import StreamingResponse +from service import RmbgService +from utils import jingrow_api_verify_and_billing +from settings import settings +import json +import asyncio + +router = APIRouter(prefix=settings.router_prefix) +service = RmbgService() + +@router.post(settings.batch_route) +@jingrow_api_verify_and_billing(api_name=settings.api_name) +async def remove_background_batch(data: dict, request: Request): + """ + 批量处理多个URL图片 + + Args: + data: 包含图片URL列表的字典 + request: FastAPI 请求对象 + + Returns: + 流式响应,包含每个图片的处理结果 + """ + if "urls" not in data: + raise HTTPException(status_code=400, detail="缺少urls参数") + + async def process_and_stream(): + async for result in service.process_batch(data["urls"]): + yield json.dumps(result) + "\n" + + return StreamingResponse( + process_and_stream(), + media_type="application/x-ndjson" + ) + +@router.post(settings.file_route) +@jingrow_api_verify_and_billing(api_name=settings.api_name) +async def remove_background_file(file: UploadFile = File(...), request: Request = None): + """ + 从上传的文件移除背景 + + Args: + file: 上传的图片文件 + request: FastAPI 请求对象 + + Returns: + 处理后的图片内容 + """ + content = await file.read() + result = await service.remove_background_from_file(content) + return result diff --git a/apps/rmbg/app.py b/apps/rmbg/app.py new file mode 100644 index 0000000..4a1399d --- /dev/null +++ b/apps/rmbg/app.py @@ -0,0 +1,21 @@ +from fastapi import FastAPI +from settings import settings +from api import router + +app = FastAPI( + title="Remove Background", + description="图片去背景", + version="1.0.0" +) + +# 注册路由 +app.include_router(router) + +if __name__ == "__main__": + import uvicorn + uvicorn.run( + "app:app", + host=settings.host, + port=settings.port, + reload=settings.debug + ) \ No newline at end of file diff --git a/apps/rmbg/service.py b/apps/rmbg/service.py new file mode 100644 index 0000000..8508586 --- /dev/null +++ b/apps/rmbg/service.py @@ -0,0 +1,225 @@ +import os +import requests +import tempfile +from urllib.parse import urlparse +from PIL import Image +import torch +from torchvision import transforms +from transformers import AutoModelForImageSegmentation +import time +import warnings +import gc +import base64 +import asyncio +import io +import multiprocessing as mp +from concurrent.futures import ProcessPoolExecutor + +# 关闭不必要的警告 +warnings.filterwarnings("ignore", category=UserWarning) +warnings.filterwarnings("ignore", category=FutureWarning) + +# 设置torch精度 +torch.set_float32_matmul_precision("high") + +class RmbgService: + def __init__(self, model_path="zhengpeng7/BiRefNet"): + """初始化背景移除服务""" + self.model_path = model_path + self.model = None + self.device = None + self._load_model() + + def _load_model(self): + """加载模型""" + # 设置设备 + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + t0 = time.time() + self.model = AutoModelForImageSegmentation.from_pretrained(self.model_path, trust_remote_code=True) + self.model = self.model.to(self.device) + self.model.eval() # 设置为评估模式 + + def process_image(self, image): + """处理图像,移除背景""" + image_size = image.size + # 转换图像 + t0 = time.time() + transform_image = transforms.Compose([ + transforms.Resize((1024, 1024)), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ]) + input_images = transform_image(image).unsqueeze(0).to(self.device) + + # 推理 + t0 = time.time() + with torch.no_grad(): + preds = self.model(input_images)[-1].sigmoid().cpu() + + # 处理预测结果 + t0 = time.time() + pred = preds[0].squeeze() + pred_pil = transforms.ToPILImage()(pred) + mask = pred_pil.resize(image_size) + + # 添加透明通道 + image.putalpha(mask) + + # 清理显存 + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + + return image + + def image_to_base64(self, image): + """将PIL Image对象转换为base64字符串""" + buffered = io.BytesIO() + image.save(buffered, format="PNG") + return base64.b64encode(buffered.getvalue()).decode('utf-8') + + async def remove_background(self, image_path): + """ + 移除图像背景 + + Args: + image_path: 输入图像的路径或URL + + Returns: + 处理后的图像内容 + """ + temp_file = None + try: + # 检查是否是URL + if self.is_valid_url(image_path): + try: + # 下载图片到临时文件 + temp_file = self.download_image(image_path) + image_path = temp_file + except Exception as e: + raise Exception(f"下载图片失败: {e}") + + # 验证输入文件是否存在 + if not os.path.exists(image_path): + raise FileNotFoundError(f"输入图像不存在: {image_path}") + + # 加载并处理图像 + image = Image.open(image_path).convert("RGB") + image_no_bg = self.process_image(image) + + # 转换为base64 + image_content = self.image_to_base64(image_no_bg) + + return { + "status": "success", + "image_content": image_content + } + + finally: + # 清理临时文件 + if temp_file and os.path.exists(temp_file): + try: + os.unlink(temp_file) + except: + pass + + async def remove_background_from_file(self, file_content): + """ + 从上传的文件内容移除背景 + + Args: + file_content: 上传的文件内容 + + Returns: + 处理后的图像内容 + """ + try: + # 从文件内容创建PIL Image对象 + image = Image.open(io.BytesIO(file_content)).convert("RGB") + image_no_bg = self.process_image(image) + + # 转换为base64 + image_content = self.image_to_base64(image_no_bg) + + return { + "status": "success", + "image_content": image_content + } + + except Exception as e: + raise Exception(f"处理图片失败: {e}") + + async def process_batch(self, urls): + """ + 批量处理多个URL图像,流式返回结果 + + Args: + urls: 图片URL列表 + + Yields: + 每个图片的处理结果 + """ + total = len(urls) + success_count = 0 + error_count = 0 + + for i, url in enumerate(urls, 1): + try: + url_str = str(url) + result = await self.remove_background(url_str) + success_count += 1 + + # 确保返回正确的数据格式 + yield { + "index": i, + "total": total, + "original_url": url_str, + "status": "success", + "image_content": result["image_content"], + "success_count": success_count, + "error_count": error_count, + "message": "处理成功" + } + + except Exception as e: + error_count += 1 + yield { + "index": i, + "total": total, + "original_url": str(url), + "status": "error", + "error": str(e), + "success_count": success_count, + "error_count": error_count, + "message": f"处理失败: {str(e)}" + } + + # 让出控制权,避免阻塞 + await asyncio.sleep(0) + + def is_valid_url(self, url): + """验证URL是否有效""" + try: + result = urlparse(url) + return all([result.scheme, result.netloc]) + except: + return False + + def download_image(self, url): + """从URL下载图片到临时文件""" + response = requests.get(url, stream=True) + response.raise_for_status() + + # 创建临时文件 + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png') + with open(temp_file.name, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + return temp_file.name + + def cleanup(self): + """清理资源""" + if torch.cuda.is_available(): + torch.cuda.empty_cache() + gc.collect() + print("资源已清理") \ No newline at end of file diff --git a/apps/rmbg/settings.py b/apps/rmbg/settings.py new file mode 100644 index 0000000..3711ad5 --- /dev/null +++ b/apps/rmbg/settings.py @@ -0,0 +1,32 @@ +from pydantic_settings import BaseSettings +from typing import Optional +from functools import lru_cache + +class Settings(BaseSettings): + # Japi Server 配置 + host: str = "0.0.0.0" + port: int = 8106 + debug: bool = False + + # API路由配置 + router_prefix: str = "/rmbg" + file_route: str = "/file" + batch_route: str = "/batch" + api_name: str = "remove_background" + + upload_url: str = "http://173.255.202.68/imgurl/upload" + + # Jingrow Jcloud API 配置 + jingrow_api_url: str = "https://cloud.jingrow.com" + jingrow_api_key: Optional[str] = None + jingrow_api_secret: Optional[str] = None + + class Config: + env_file = ".env" + +@lru_cache() +def get_settings() -> Settings: + return Settings() + +# 创建全局配置实例 +settings = get_settings() \ No newline at end of file diff --git a/apps/rmbg/utils.py b/apps/rmbg/utils.py new file mode 100644 index 0000000..8592041 --- /dev/null +++ b/apps/rmbg/utils.py @@ -0,0 +1,146 @@ +import aiohttp +from functools import wraps +from fastapi import HTTPException +import os +from typing import Callable, Any, Dict, Optional, Tuple +from settings import settings +from fastapi.responses import StreamingResponse +import json + +async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]: + """验证API密钥和团队余额""" + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f"{settings.jingrow_api_url}/api/method/jcloud.api.account.verify_api_credentials_and_balance", + headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"}, + json={"api_key": api_key, "api_secret": api_secret, "api_name": api_name} + ) as response: + if response.status != 200: + raise HTTPException(status_code=500, detail="验证服务暂时不可用") + + result = await response.json() + if "message" in result and isinstance(result["message"], dict): + result = result["message"] + + if not result.get("success"): + raise HTTPException(status_code=401, detail=result.get("message", "验证失败")) + + return result + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}") + +async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]: + """从Jingrow平台扣除API使用费""" + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f"{settings.jingrow_api_url}/api/method/jcloud.api.account.deduct_api_usage_fee", + headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"}, + json={ + "api_key": api_key, + "api_secret": api_secret, + "api_name": api_name, + "usage_count": usage_count + } + ) as response: + if response.status != 200: + raise HTTPException(status_code=500, detail="扣费服务暂时不可用") + + result = await response.json() + if "message" in result and isinstance(result["message"], dict): + result = result["message"] + + return result + + except HTTPException: + raise + except Exception as e: + return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"} + +def get_token_from_request(request) -> str: + """从请求中获取访问令牌""" + if not request: + raise HTTPException(status_code=400, detail="无法获取请求信息") + + auth_header = request.headers.get("Authorization", "") + if not auth_header or not auth_header.startswith("token "): + raise HTTPException(status_code=401, detail="无效的Authorization头格式") + + token = auth_header[6:] + if ":" not in token: + raise HTTPException(status_code=401, detail="无效的令牌格式") + + return token + +def jingrow_api_verify_and_billing(api_name: str): + """Jingrow API 验证装饰器(带余额检查和扣费)""" + def decorator(func: Callable) -> Callable: + @wraps(func) + async def wrapper(*args, **kwargs): + try: + request = kwargs.get('request') + if not request: + raise HTTPException(status_code=400, detail="无法获取请求信息") + + token = get_token_from_request(request) + api_key, api_secret = token.split(":", 1) + + verify_result = await verify_api_credentials_and_balance(api_key, api_secret, api_name) + if not verify_result.get("success"): + raise HTTPException(status_code=401, detail=verify_result.get("message", "验证失败")) + + result = await func(*args, **kwargs) + + usage_count = 1 + try: + body_data = await request.json() + if isinstance(body_data, dict): + for key in ["items", "urls", "images", "files"]: + if key in body_data and isinstance(body_data[key], list): + usage_count = len(body_data[key]) + break + except Exception: + pass + + if isinstance(result, StreamingResponse): + original_generator = result.body_iterator + success_count = 0 + + async def wrapped_generator(): + nonlocal success_count + async for chunk in original_generator: + try: + data = json.loads(chunk) + if isinstance(data, dict) and data.get("status") == "success": + success_count += 1 + except: + pass + yield chunk + + if success_count > 0: + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count) + + return StreamingResponse( + wrapped_generator(), + media_type=result.media_type, + headers=result.headers + ) + + if isinstance(result, dict) and result.get("success") is True: + actual_usage_count = result.get("successful_count", usage_count) + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count) + return result + + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count) + return result + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"API验证过程发生异常: {str(e)}") + return wrapper + return decorator diff --git a/apps/tryon/__init__.py b/apps/tryon/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/tryon/api.py b/apps/tryon/api.py new file mode 100644 index 0000000..5ad659f --- /dev/null +++ b/apps/tryon/api.py @@ -0,0 +1,71 @@ +from fastapi import APIRouter, UploadFile, File, HTTPException, Request +from fastapi.responses import StreamingResponse +from service import TryonService +from utils import jingrow_api_verify_and_billing +from settings import settings +import json +from typing import List + +router = APIRouter(prefix=settings.router_prefix) +service = TryonService() + +@router.post(settings.batch_route) +@jingrow_api_verify_and_billing(api_name=settings.api_name) +async def tryon_batch(data: dict, request: Request): + if "tshirt_urls" not in data: + raise HTTPException(status_code=400, detail="缺少tshirt_urls参数") + if "model_urls" not in data or not isinstance(data["model_urls"], list): + raise HTTPException(status_code=400, detail="缺少model_urls参数或格式错误") + + tshirt_urls = data["tshirt_urls"] + if not isinstance(tshirt_urls, list): + raise HTTPException(status_code=400, detail="tshirt_urls必须是URL列表") + + combinations = [] + for model_url in data["model_urls"]: + for tshirt_url in tshirt_urls: + combinations.append(f"{tshirt_url}|{model_url}") + + data["urls"] = combinations + config = data.get("config", {}) + + async def process_and_stream(): + async for result in service.process_batch(tshirt_urls, data["model_urls"], config): + yield json.dumps(result) + "\n" + + return StreamingResponse( + process_and_stream(), + media_type="application/x-ndjson" + ) + + +@router.post(settings.file_route) +@jingrow_api_verify_and_billing(api_name=settings.api_name) +async def tryon_file( + tshirt_files: List[UploadFile] = File(...), + model_file: UploadFile = File(...), + config: str = None, + request: Request = None +): + tshirt_contents = [await file.read() for file in tshirt_files] + model_content = await model_file.read() + + if request: + request._body = json.dumps({"urls": [f"file_{i}" for i in range(len(tshirt_files))]}).encode() + + config_dict = {} + if config: + try: + config_dict = json.loads(config) + except json.JSONDecodeError: + raise HTTPException(status_code=400, detail="配置参数格式错误") + + async def process_and_stream(): + async for result in service.process_files(tshirt_contents, model_content, config_dict): + yield json.dumps(result) + "\n" + + return StreamingResponse( + process_and_stream(), + media_type="application/x-ndjson" + ) + diff --git a/apps/tryon/app.py b/apps/tryon/app.py new file mode 100644 index 0000000..19a8e7a --- /dev/null +++ b/apps/tryon/app.py @@ -0,0 +1,21 @@ +from fastapi import FastAPI +from settings import settings +from api import router + +app = FastAPI( + title="Tryon", + description="虚拟试穿", + version="1.0.0" +) + +# 注册路由 +app.include_router(router) + +if __name__ == "__main__": + import uvicorn + uvicorn.run( + "app:app", + host=settings.host, + port=settings.port, + reload=settings.debug + ) \ No newline at end of file diff --git a/apps/tryon/service.py b/apps/tryon/service.py new file mode 100644 index 0000000..0b404fa --- /dev/null +++ b/apps/tryon/service.py @@ -0,0 +1,455 @@ +import os +import json +import io +import aiohttp +from typing import List +import tempfile +from gradio_client import Client, handle_file +import asyncio +from urllib.parse import urlparse +from settings import settings + +tryon_server_url = settings.tryon_server_url + +class TryonService: + # 默认配置 + DEFAULT_CONFIG = { + 'tryon_marker': "_tryon", + 'tryon_target_marker': "tshirt", + 'tryon_models_dir': "/files/models", + 'denoise_steps': 20, + 'seed': 42, + 'is_crop': False, + 'output_format': 'png' + } + + def __init__(self): + """初始化虚拟试穿服务""" + self.client = None + + def get_gradio_client(self): + """获取或初始化Gradio客户端""" + if self.client is None: + try: + self.client = Client(tryon_server_url) + except Exception: + self.client = None + return self.client + + def _convert_config_types(self, config): + """转换配置参数类型""" + if not config: + return {} + + converted = {} + for key, value in config.items(): + if key == 'denoise_steps': + try: + converted[key] = int(value) + except (ValueError, TypeError): + converted[key] = 20 + elif key == 'seed': + try: + converted[key] = int(value) + except (ValueError, TypeError): + converted[key] = 42 + elif key == 'is_crop': + try: + converted[key] = bool(value) + except (ValueError, TypeError): + converted[key] = False + else: + converted[key] = value + return converted + + async def generate_virtual_tryon(self, tshirt_image_io: List[io.BytesIO], model_image_io: io.BytesIO, config=None): + """生成虚拟试穿结果 + + Args: + tshirt_image_io: T恤图片IO对象列表 + model_image_io: 模特图片IO对象 + config: 配置参数 + """ + if config is None: + config = {} + + # 转换配置参数类型并合并默认配置 + config = {**self.DEFAULT_CONFIG, **self._convert_config_types(config)} + + # 检查图片大小 + min_image_size = 1024 # 最小1KB + for tshirt_io in tshirt_image_io: + if len(tshirt_io.getvalue()) < min_image_size: + raise ValueError(f"T恤图片太小,可能不是有效图片,大小: {len(tshirt_io.getvalue())} 字节") + + if len(model_image_io.getvalue()) < min_image_size: + raise ValueError(f"模特图片太小,可能不是有效图片,大小: {len(model_image_io.getvalue())} 字节") + + client = self.get_gradio_client() + if client is None: + raise RuntimeError("Gradio API服务不可用,无法进行虚拟试穿") + + # 创建临时目录 + with tempfile.TemporaryDirectory() as temp_dir: + # 保存所有T-shirt图片为临时文件 + temp_tshirt_files = [] + for tshirt_io in tshirt_image_io: + with tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir=temp_dir) as temp_tshirt_file: + temp_tshirt_file.write(tshirt_io.getvalue()) + temp_tshirt_files.append(temp_tshirt_file.name) + + # 保存模特图片为临时文件 + with tempfile.NamedTemporaryFile(delete=False, suffix='.png', dir=temp_dir) as temp_model_file: + temp_model_file.write(model_image_io.getvalue()) + temp_model_file_path = temp_model_file.name + + try: + results = [] + # 对每件T恤进行虚拟试穿 + for temp_tshirt_file_path in temp_tshirt_files: + try: + # 调用API进行虚拟试穿 + result = client.predict( + dict({"background": handle_file(temp_model_file_path), "layers": [], "composite": None}), + garm_img=handle_file(temp_tshirt_file_path), + garment_des="", + is_checked=True, + is_checked_crop=config.get('is_crop', False), + denoise_steps=config.get('denoise_steps', 20), + seed=config.get('seed', 42), + api_name="/tryon" + ) + + # 处理返回结果 + if not result or not isinstance(result, tuple) or len(result) < 1: + raise RuntimeError("虚拟试穿服务返回了无效的结果格式") + + output_path = result[0] # 使用第一个图片作为结果 + if not os.path.exists(output_path): + raise RuntimeError(f"输出文件不存在: {output_path}") + + with open(output_path, 'rb') as f: + result_data = f.read() + results.append(io.BytesIO(result_data)) + + except Exception as e: + results.append(None) + + return results[0] if len(results) == 1 else results + + except Exception as e: + raise + + finally: + # 清理临时文件 + for temp_tshirt_file_path in temp_tshirt_files: + if os.path.exists(temp_tshirt_file_path): + os.remove(temp_tshirt_file_path) + if os.path.exists(temp_model_file_path): + os.remove(temp_model_file_path) + + def is_valid_url(self, url): + """检查URL是否有效""" + try: + result = urlparse(url) + return all([result.scheme, result.netloc]) + except: + return False + + async def download_image(self, url): + """下载图片""" + try: + async with aiohttp.ClientSession() as session: + async with session.get(url) as response: + if response.status == 200: + content = await response.read() + image_io = io.BytesIO(content) + + # 检查内容长度 + if len(content) < 100: # 太小可能不是有效图片 + return None + + # 检查内容类型 + content_type = response.headers.get('Content-Type', '') + if 'image' not in content_type.lower(): + # 检查文件头部魔术数字 + header = content[:12] + is_image = any([ + header.startswith(b'\x89PNG'), # PNG + header.startswith(b'\xff\xd8\xff'), # JPEG + header.startswith(b'GIF8'), # GIF + header.startswith(b'RIFF') and b'WEBP' in header # WEBP + ]) + if not is_image: + return None + + return image_io + return None + except Exception: + return None + + async def process_urls(self, tshirt_urls: List[str], model_url: str, config=None): + """ + 处理多个T恤URL,流式返回结果 + + Args: + tshirt_urls: T恤图片URL列表 + model_url: 模特图片URL + config: 配置参数 + + Yields: + 每个T恤的处理结果 + """ + total = len(tshirt_urls) + success_count = 0 + error_count = 0 + + # 下载模特图片 + model_io = await self.download_image(model_url) + if model_io is None: + yield { + "status": "error", + "message": f"无法下载模特图片: {model_url}", + "success_count": success_count, + "error_count": error_count + 1, + "total": total + } + return + + for i, tshirt_url in enumerate(tshirt_urls, 1): + try: + # 下载T恤图片 + tshirt_io = await self.download_image(tshirt_url) + if tshirt_io is None: + error_count += 1 + yield { + "index": i, + "total": total, + "tshirt_url": tshirt_url, + "status": "error", + "message": f"无法下载T恤图片: {tshirt_url}", + "success_count": success_count, + "error_count": error_count + } + continue + + # 处理图片 + result = await self.generate_virtual_tryon([tshirt_io], model_io, config) + if result is None: + error_count += 1 + yield { + "index": i, + "total": total, + "tshirt_url": tshirt_url, + "status": "error", + "message": f"处理T恤图片失败: {tshirt_url}", + "success_count": success_count, + "error_count": error_count + } + else: + success_count += 1 + result.seek(0) + base64_data = f"data:image/{config.get('output_format', 'png')};base64," + \ + json.dumps(result.read().hex()) + yield { + "index": i, + "total": total, + "tshirt_url": tshirt_url, + "status": "success", + "data": base64_data, + "success_count": success_count, + "error_count": error_count + } + + except Exception as e: + error_count += 1 + yield { + "index": i, + "total": total, + "tshirt_url": tshirt_url, + "status": "error", + "message": str(e), + "success_count": success_count, + "error_count": error_count + } + + # 让出控制权,避免阻塞 + await asyncio.sleep(0) + + async def process_files(self, tshirt_contents: List[bytes], model_content: bytes, config=None): + """ + 处理多个T恤文件,流式返回结果 + + Args: + tshirt_contents: T恤图片内容列表 + model_content: 模特图片内容 + config: 配置参数 + + Yields: + 每个T恤的处理结果 + """ + total = len(tshirt_contents) + success_count = 0 + error_count = 0 + + model_io = io.BytesIO(model_content) + + for i, content in enumerate(tshirt_contents, 1): + try: + tshirt_io = io.BytesIO(content) + result = await self.generate_virtual_tryon([tshirt_io], model_io, config) + if result is None: + error_count += 1 + yield { + "index": i, + "total": total, + "status": "error", + "message": "处理T恤图片失败", + "success_count": success_count, + "error_count": error_count + } + else: + success_count += 1 + result.seek(0) + base64_data = f"data:image/{config.get('output_format', 'png')};base64," + \ + json.dumps(result.read().hex()) + yield { + "index": i, + "total": total, + "status": "success", + "data": base64_data, + "success_count": success_count, + "error_count": error_count + } + + except Exception as e: + error_count += 1 + yield { + "index": i, + "total": total, + "status": "error", + "message": str(e), + "success_count": success_count, + "error_count": error_count + } + + # 让出控制权,避免阻塞 + await asyncio.sleep(0) + + async def process_batch(self, tshirt_urls: List[str], model_urls: List[str], config=None): + """ + 批量处理多个T恤和模特图片,流式返回结果 + + Args: + tshirt_urls: T恤图片URL列表 + model_urls: 模特图片URL列表 + config: 配置参数 + + Yields: + 每个组合的处理结果 + """ + total = len(tshirt_urls) * len(model_urls) + success_count = 0 + error_count = 0 + current_index = 0 + + for model_url in model_urls: + try: + model_io = await self.download_image(model_url) + if model_io is None: + error_count += len(tshirt_urls) + for tshirt_url in tshirt_urls: + current_index += 1 + yield { + "index": current_index, + "total": total, + "model_url": model_url, + "tshirt_url": tshirt_url, + "status": "error", + "message": f"无法下载模特图片: {model_url}", + "success_count": success_count, + "error_count": error_count + } + continue + + for tshirt_url in tshirt_urls: + current_index += 1 + try: + tshirt_io = await self.download_image(tshirt_url) + if tshirt_io is None: + error_count += 1 + yield { + "index": current_index, + "total": total, + "model_url": model_url, + "tshirt_url": tshirt_url, + "status": "error", + "message": f"无法下载T恤图片: {tshirt_url}", + "success_count": success_count, + "error_count": error_count + } + continue + + result = await self.generate_virtual_tryon([tshirt_io], model_io, config) + if result is None: + error_count += 1 + yield { + "index": current_index, + "total": total, + "model_url": model_url, + "tshirt_url": tshirt_url, + "status": "error", + "message": f"处理T恤图片失败: {tshirt_url}", + "success_count": success_count, + "error_count": error_count + } + else: + success_count += 1 + result.seek(0) + base64_data = f"data:image/{config.get('output_format', 'png')};base64," + \ + json.dumps(result.read().hex()) + yield { + "index": current_index, + "total": total, + "model_url": model_url, + "tshirt_url": tshirt_url, + "status": "success", + "data": base64_data, + "success_count": success_count, + "error_count": error_count + } + + except Exception as e: + error_count += 1 + yield { + "index": current_index, + "total": total, + "model_url": model_url, + "tshirt_url": tshirt_url, + "status": "error", + "message": str(e), + "success_count": success_count, + "error_count": error_count + } + + # 让出控制权,避免阻塞 + await asyncio.sleep(0) + + except Exception as e: + error_count += len(tshirt_urls) + for tshirt_url in tshirt_urls: + current_index += 1 + yield { + "index": current_index, + "total": total, + "model_url": model_url, + "tshirt_url": tshirt_url, + "status": "error", + "message": str(e), + "success_count": success_count, + "error_count": error_count + } + + def cleanup(self): + """清理资源""" + self.client = None diff --git a/apps/tryon/settings.py b/apps/tryon/settings.py new file mode 100644 index 0000000..3673d89 --- /dev/null +++ b/apps/tryon/settings.py @@ -0,0 +1,35 @@ +from pydantic_settings import BaseSettings +from typing import Optional +from functools import lru_cache + +class Settings(BaseSettings): + # Japi Server 配置 + host: str = "0.0.0.0" + port: int = 8112 + debug: bool = False + + # API路由配置 + router_prefix: str = "/tryon" + file_route: str = "/file" + batch_route: str = "/batch" + api_name: str = "tryon" + + upload_url: str = "http://173.255.202.68/imgurl/upload" + + # 虚拟试穿Tryon服务器URL + tryon_server_url: str = "http://192.168.2.200:7860" + + # Jingrow Jcloud API 配置 + jingrow_api_url: str = "https://cloud.jingrow.com" + jingrow_api_key: Optional[str] = None + jingrow_api_secret: Optional[str] = None + + class Config: + env_file = ".env" + +@lru_cache() +def get_settings() -> Settings: + return Settings() + +# 创建全局配置实例 +settings = get_settings() \ No newline at end of file diff --git a/apps/tryon/utils.py b/apps/tryon/utils.py new file mode 100644 index 0000000..8592041 --- /dev/null +++ b/apps/tryon/utils.py @@ -0,0 +1,146 @@ +import aiohttp +from functools import wraps +from fastapi import HTTPException +import os +from typing import Callable, Any, Dict, Optional, Tuple +from settings import settings +from fastapi.responses import StreamingResponse +import json + +async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]: + """验证API密钥和团队余额""" + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f"{settings.jingrow_api_url}/api/method/jcloud.api.account.verify_api_credentials_and_balance", + headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"}, + json={"api_key": api_key, "api_secret": api_secret, "api_name": api_name} + ) as response: + if response.status != 200: + raise HTTPException(status_code=500, detail="验证服务暂时不可用") + + result = await response.json() + if "message" in result and isinstance(result["message"], dict): + result = result["message"] + + if not result.get("success"): + raise HTTPException(status_code=401, detail=result.get("message", "验证失败")) + + return result + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}") + +async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]: + """从Jingrow平台扣除API使用费""" + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f"{settings.jingrow_api_url}/api/method/jcloud.api.account.deduct_api_usage_fee", + headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"}, + json={ + "api_key": api_key, + "api_secret": api_secret, + "api_name": api_name, + "usage_count": usage_count + } + ) as response: + if response.status != 200: + raise HTTPException(status_code=500, detail="扣费服务暂时不可用") + + result = await response.json() + if "message" in result and isinstance(result["message"], dict): + result = result["message"] + + return result + + except HTTPException: + raise + except Exception as e: + return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"} + +def get_token_from_request(request) -> str: + """从请求中获取访问令牌""" + if not request: + raise HTTPException(status_code=400, detail="无法获取请求信息") + + auth_header = request.headers.get("Authorization", "") + if not auth_header or not auth_header.startswith("token "): + raise HTTPException(status_code=401, detail="无效的Authorization头格式") + + token = auth_header[6:] + if ":" not in token: + raise HTTPException(status_code=401, detail="无效的令牌格式") + + return token + +def jingrow_api_verify_and_billing(api_name: str): + """Jingrow API 验证装饰器(带余额检查和扣费)""" + def decorator(func: Callable) -> Callable: + @wraps(func) + async def wrapper(*args, **kwargs): + try: + request = kwargs.get('request') + if not request: + raise HTTPException(status_code=400, detail="无法获取请求信息") + + token = get_token_from_request(request) + api_key, api_secret = token.split(":", 1) + + verify_result = await verify_api_credentials_and_balance(api_key, api_secret, api_name) + if not verify_result.get("success"): + raise HTTPException(status_code=401, detail=verify_result.get("message", "验证失败")) + + result = await func(*args, **kwargs) + + usage_count = 1 + try: + body_data = await request.json() + if isinstance(body_data, dict): + for key in ["items", "urls", "images", "files"]: + if key in body_data and isinstance(body_data[key], list): + usage_count = len(body_data[key]) + break + except Exception: + pass + + if isinstance(result, StreamingResponse): + original_generator = result.body_iterator + success_count = 0 + + async def wrapped_generator(): + nonlocal success_count + async for chunk in original_generator: + try: + data = json.loads(chunk) + if isinstance(data, dict) and data.get("status") == "success": + success_count += 1 + except: + pass + yield chunk + + if success_count > 0: + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count) + + return StreamingResponse( + wrapped_generator(), + media_type=result.media_type, + headers=result.headers + ) + + if isinstance(result, dict) and result.get("success") is True: + actual_usage_count = result.get("successful_count", usage_count) + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count) + return result + + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count) + return result + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"API验证过程发生异常: {str(e)}") + return wrapper + return decorator