diff --git a/apps/midjourney/__init__.py b/apps/midjourney/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/apps/midjourney/api.py b/apps/midjourney/api.py new file mode 100644 index 0000000..8717f09 --- /dev/null +++ b/apps/midjourney/api.py @@ -0,0 +1,69 @@ +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import StreamingResponse +from service import MidjourneyService +from utils import jingrow_api_verify_and_billing +from settings import settings +import json +import asyncio +from typing import AsyncGenerator, List + +router = APIRouter(prefix=settings.router_prefix) +service = MidjourneyService() + +@router.post(settings.generate_route) +@jingrow_api_verify_and_billing(api_name=settings.api_name) +async def generate_image(data: dict, request: Request): + """ + 根据文本提示生成图像 + + Args: + data: 包含文本提示和配置参数的字典 + request: FastAPI 请求对象 + + Returns: + 生成的图像内容 + """ + if "prompt" not in data: + raise HTTPException(status_code=400, detail="缺少prompt参数") + + prompt = data["prompt"] + config = data.get("config", {}) + + async def generate() -> AsyncGenerator[str, None]: + async for result in service.generate_image(prompt, config): + yield json.dumps(result, ensure_ascii=False) + "\n" + + return StreamingResponse( + generate(), + media_type="application/x-ndjson", + headers={"X-Content-Type-Options": "nosniff"} + ) + +@router.post(settings.batch_route) +@jingrow_api_verify_and_billing(api_name=settings.api_name) +async def batch_process_images(data: dict, request: Request): + """ + 批量处理多个图像URL,将每张图片分割成4张并保存 + + Args: + data: 包含图片URLs列表的字典 + request: FastAPI 请求对象 + + Returns: + 处理结果的流式响应 + """ + if "image_urls" not in data or not isinstance(data["image_urls"], list): + raise HTTPException(status_code=400, detail="缺少有效的image_urls参数") + + image_urls: List[str] = data["image_urls"] + config = data.get("config", {}) + + async def process() -> AsyncGenerator[str, None]: + async for result in service.process_batch(image_urls, config): + yield json.dumps(result, ensure_ascii=False) + "\n" + + return StreamingResponse( + process(), + media_type="application/x-ndjson", + headers={"X-Content-Type-Options": "nosniff"} + ) \ No newline at end of file diff --git a/apps/midjourney/app.py b/apps/midjourney/app.py new file mode 100644 index 0000000..6d3f947 --- /dev/null +++ b/apps/midjourney/app.py @@ -0,0 +1,21 @@ +from fastapi import FastAPI +from settings import settings +from api import router + +app = FastAPI( + title="Midjourney", + description="Midjourney绘画服务API", + version="1.0.0" +) + +# 注册路由 +app.include_router(router) + +if __name__ == "__main__": + import uvicorn + uvicorn.run( + "app:app", + host=settings.host, + port=settings.port, + reload=settings.debug + ) \ No newline at end of file diff --git a/apps/midjourney/service.py b/apps/midjourney/service.py new file mode 100644 index 0000000..fa435e4 --- /dev/null +++ b/apps/midjourney/service.py @@ -0,0 +1,580 @@ +import json +import sys +import os +import io +import requests +import time +import random +import re +import uuid +import urllib.request +import urllib3 +import traceback +import logging +import base64 +from pathlib import Path +from urllib.parse import urlparse +from PIL import Image +import mimetypes +import asyncio +from typing import Dict, Any, List, AsyncGenerator, Optional +from settings import settings +from utils import get_new_image_url, is_valid_image_url + + +# 设置日志记录器 +logger = logging.getLogger("midjourney_service") + +# 禁用不安全请求警告 +urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + +class MidjourneyService: + def __init__(self): + """初始化MidjourneyService""" + self.API_URL = settings.midjourney_api_url + self.APPLICATION_ID = settings.midjourney_application_id + self.DATA_ID = settings.midjourney_data_id + self.DATA_VERSION = settings.midjourney_data_version + self.SESSION_ID = settings.midjourney_session_id + + # 设置代理 + self.proxies = {} + if settings.http_proxy and settings.http_proxy.strip(): + self.proxies['http'] = settings.http_proxy + if settings.https_proxy and settings.https_proxy.strip(): + self.proxies['https'] = settings.https_proxy + + # 确保保存目录存在 + os.makedirs(settings.save_dir, exist_ok=True) + + @staticmethod + def first_where(array, key, value=None): + """在数组中找到第一个匹配条件的项""" + for item in array: + if callable(key) and key(item): + return item + if isinstance(key, str) and item[key].startswith(value): + return item + return None + + async def initialize_client(self, oauth_token, channel_id): + """初始化Discord客户端会话""" + client = requests.Session() + client.headers.update({ + 'Authorization': oauth_token + }) + + # 设置代理 + if self.proxies: + client.proxies.update(self.proxies) + + try: + # 获取频道信息 + response = client.get(f'{self.API_URL}/channels/{channel_id}') + data = response.json() + guild_id = data['guild_id'] + + # 获取用户信息 + response = client.get(f'{self.API_URL}/users/@me') + data = response.json() + user_id = data['id'] + + return client, guild_id, user_id + except Exception as e: + logger.error(f"初始化Discord客户端失败: {str(e)}") + traceback.print_exc() + raise Exception(f"初始化Discord客户端失败: {str(e)}") + + async def imagine(self, client, guild_id, channel_id, prompt, seed=None): + """发送imagine请求到Discord""" + params = { + 'type': 2, + 'application_id': self.APPLICATION_ID, + 'guild_id': guild_id, + 'channel_id': channel_id, + 'session_id': self.SESSION_ID, + 'data': { + 'version': self.DATA_VERSION, + 'id': self.DATA_ID, + 'name': 'imagine', + 'type': 1, + 'options': [{ + 'type': 3, + 'name': 'prompt', + 'value': prompt + }], + 'application_command': { + 'id': self.DATA_ID, + 'application_id': self.APPLICATION_ID, + 'version': self.DATA_VERSION, + 'default_member_permissions': None, + 'type': 1, + 'nsfw': False, + 'name': 'imagine', + 'description': 'Create images with Midjourney', + 'dm_permission': True, + 'options': [{ + 'type': 3, + 'name': 'prompt', + 'description': 'The prompt to imagine', + 'required': True + }] + }, + 'attachments': [] + } + } + + try: + # 发送请求 + r = client.post(f'{self.API_URL}/interactions', json=params) + # 初始等待时间从5秒延长到30秒,给Discord足够的时间开始处理请求 + print(f"[生成] 已发送请求,等待30秒后开始轮询...") + await asyncio.sleep(30) + + # 轮询获取结果 + imagine_message = None + count = 0 + last_progress = 0 + + # 轮询直到获取完整的结果或达到最大次数 + while count < settings.max_polling_attempts: + # 轮询两种消息: + # 1. 进度消息 - 用于更新进度 + progress_msg = await self.get_progress_message(client, channel_id, prompt, seed) + if progress_msg: + content = progress_msg.get('content', '') + progress_match = re.search(r'(?:\(|()?(\d+)%(?:\)|))?', content) + if progress_match: + progress_value = int(progress_match.group(1)) + # 只有进度有变化时才发送更新 + if progress_value > last_progress: + last_progress = progress_value + logger.info(f"生成进度: {progress_value}%") + yield { + "status": "progress", + "progress": progress_value, + "message_id": progress_msg.get('id'), + "content": content + } + + # 2. 完成的消息 - 包含图像结果 + imagine_message = await self.get_imagine(client, channel_id, prompt, count, seed) + if imagine_message: + # 找到最终结果 + break + + # 没有找到结果,继续等待 + logger.info(f"轮询尝试 {count+1}/{settings.max_polling_attempts}: 继续等待") + await asyncio.sleep(settings.polling_interval) + count += 1 + + # 检查是否超过最大轮询次数 + if count >= settings.max_polling_attempts: + logger.error(f"轮询超过最大尝试次数: {settings.max_polling_attempts}") + yield {"status": "error", "message": "获取结果超时,超过最大轮询次数"} + return + + # 检查是否有有效的最终结果 + if not imagine_message: + logger.error("轮询结束但没有获取到有效结果") + yield {"status": "error", "message": "没有获取到有效结果"} + return + + # 返回最终结果 + logger.info(f"成功获取图像结果,消息ID: {imagine_message.get('id')}") + yield { + "status": "success", + "message_id": imagine_message.get('id'), + "content": imagine_message.get('content', ''), + "images": self.extract_image_urls(imagine_message) + } + + except Exception as e: + logger.error(f"发送Imagine请求失败: {str(e)}") + traceback.print_exc() + yield {"status": "error", "message": f"发送Imagine请求失败: {str(e)}"} + + async def get_imagine(self, client, channel_id, prompt, count=0, seed=None): + """获取生成图像的消息""" + try: + # 获取最近的消息 + response = client.get(f'{self.API_URL}/channels/{channel_id}/messages?limit=10') + data = response.json() + + def criteria(item): + content = item.get('content', '') + + # 检查进度信息并更新状态 + if seed is not None and f"--seed {seed}" in content: + # 匹配百分比,支持多种格式:(93%) 或 93% + progress_match = re.search(r'(?:\(|()?(\d+)%(?:\)|))?', content) + if progress_match: + progress_value = int(progress_match.group(1)) + # 记录进度信息 + logger.info(f"任务进度: {progress_value}%") + # 如果消息包含百分比,说明任务还在进行中,返回False继续轮询 + if "%" in content: + print(f"[轮询] 消息ID: {item.get('id')} 内容: {content[:100]}... 包含进度信息{progress_value}%,继续等待完成") + return False + + # 排除进行中消息,只匹配完成的消息 + if "%" in content: + return False + + # seed 匹配 + if seed is not None: + seed_pattern = f"--seed {seed}" + # 检查是否包含指定的seed + if seed_pattern not in content: + return False + else: + print(f"[轮询] 消息ID: {item.get('id')} 内容: {content[:100]}... 匹配seed并且任务已完成") + # 检查是否包含图像附件 + if 'attachments' in item and len(item.get('attachments', [])) > 0: + return True + return False # 默认不匹配 + + # 查找匹配完成状态的消息 + raw_message = self.first_where(data, criteria) + if raw_message is None: + print("[轮询] 没有找到完成的消息,继续等待") + return None + + # 打印匹配到的完整消息内容 + print(f"[轮询] 匹配到完成的消息ID: {raw_message.get('id')}") + try: + print(f"[轮询] 完整消息内容: {json.dumps(raw_message, indent=2, ensure_ascii=False)}") + except Exception as e: + print(f"[轮询] 无法打印完整消息内容: {str(e)}") + # 尝试打印关键字段 + print(f"[轮询] 消息内容: {raw_message.get('content', '无内容')}") + print(f"[轮询] 附件数量: {len(raw_message.get('attachments', []))}") + for i, attachment in enumerate(raw_message.get('attachments', [])): + print(f"[轮询] 附件 {i+1} URL: {attachment.get('url', '无URL')}") + + return raw_message + + except Exception as e: + logger.error(f"获取通道消息失败: {str(e)}") + return None + + async def get_progress_message(self, client, channel_id, prompt, seed=None): + """获取包含进度信息的消息""" + try: + # 获取最近的消息 + response = client.get(f'{self.API_URL}/channels/{channel_id}/messages?limit=10') + data = response.json() + + for item in data: + content = item.get('content', '') + # 检查是否包含seed和进度信息 + if seed is not None and f"--seed {seed}" in content and "%" in content: + progress_match = re.search(r'(?:\(|()?(\d+)%(?:\)|))?', content) + if progress_match: + return item + + return None + + except Exception as e: + logger.error(f"获取进度消息失败: {str(e)}") + return None + + def extract_image_urls(self, message): + """从消息中提取图像URL""" + image_urls = [] + if 'attachments' in message and message['attachments']: + for attachment in message['attachments']: + if 'url' in attachment: + image_urls.append(attachment['url']) + return image_urls + + async def split_image(self, image_url): + """将一张大图切割成四张子图,保存到本地并返回URLs""" + try: + # 下载图像 + response = requests.get(image_url, proxies=self.proxies if self.proxies else None) + if response.status_code != 200: + logger.error(f"下载图像失败,状态码: {response.status_code}") + return None + + image_data = response.content + + # 从二进制数据创建图像对象 + img = Image.open(io.BytesIO(image_data)) + width, height = img.size + + # 确认图像尺寸约为2048x2048 + if width < 1500 or height < 1500: + logger.error(f"图像尺寸不符合预期: {width}x{height}") + return None + + # 计算每个象限的尺寸 + half_width = width // 2 + half_height = height // 2 + + # 分割图像为四个象限 + top_left = img.crop((0, 0, half_width, half_height)) + top_right = img.crop((half_width, 0, width, half_height)) + bottom_left = img.crop((0, half_height, half_width, height)) + bottom_right = img.crop((half_width, half_height, width, height)) + + # 生成唯一的图片名称前缀 + image_id = uuid.uuid4().hex[:10] + + # 保存图片到本地并生成URLs + image_urls = [] + for i, quadrant in enumerate([top_left, top_right, bottom_left, bottom_right], 1): + # 生成文件名和保存路径 + filename = f"split_{image_id}_{i}.png" + file_path = os.path.join(settings.save_dir, filename) + + # 保存图片 + quadrant.save(file_path, format="PNG") + + # 构建图片URL + image_url = f"{settings.download_url}/{filename}" + image_urls.append(image_url) + + return image_urls + except Exception as e: + logger.error(f"分割图像失败: {str(e)}") + traceback.print_exc() + return None + + async def process_batch(self, image_urls: List[str], config: Optional[Dict] = None): + """批量处理多个图像URL""" + if not config: + config = {} + + total = len(image_urls) + success_count = 0 + error_count = 0 + + for i, image_url in enumerate(image_urls, 1): + try: + if not is_valid_image_url(image_url): + error_count += 1 + response = { + "status": "error", + "index": i, + "total": total, + "success_count": success_count, + "error_count": error_count, + "message": "无效的图片URL" + } + print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}") + yield response + continue + + split_urls = await self.split_image(image_url) + if split_urls and len(split_urls) == 4: + success_count += 1 + response = { + "status": "success", + "index": i, + "total": total, + "success_count": success_count, + "error_count": error_count, + "images": split_urls + } + print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}") + yield response + else: + error_count += 1 + response = { + "status": "error", + "index": i, + "total": total, + "success_count": success_count, + "error_count": error_count, + "message": "分割图片失败" + } + print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}") + yield response + except Exception as e: + error_count += 1 + response = { + "status": "error", + "index": i, + "total": total, + "success_count": success_count, + "error_count": error_count, + "message": str(e) + } + print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}") + yield response + + async def generate_image(self, prompt, config=None): + """生成图像并以流式方式返回结果""" + if not config: + config = {} + + # 获取必要的认证信息 + oauth_token = config.get('oauth_token', settings.midjourney_oauth_token) + channel_id = config.get('channel_id', settings.midjourney_channel_id) + + if not oauth_token or not channel_id: + response = { + "status": "error", + "message": "缺少Discord配置", + "success_count": 0, + "error_count": 1 + } + print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}") + yield response + return + + # 初始化客户端 + try: + client, guild_id, user_id = await self.initialize_client(oauth_token, channel_id) + except Exception as e: + response = { + "status": "error", + "message": str(e), + "success_count": 0, + "error_count": 1 + } + print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}") + yield response + return + + # 解析和准备提示词 + prompt = prompt.strip() + + # 获取或设置默认选项 + options = settings.midjourney_default_options.copy() + if 'options' in config: + options.update(config.get('options', {})) + + # 强制设置seed如果没有 + if "seed" not in options or not options.get('seed'): + options['seed'] = random.randint(0, 4294967295) + + # 处理选项,构建参数字符串 + seed = options.pop('seed', None) + parameter = "" + no_value_key = ['relax', 'fast', 'turbo', 'tile'] + for key, value in options.items(): + if key in no_value_key: + parameter += f" --{key}" + else: + parameter += f" --{key} {value}" + + # 添加seed + if seed: + parameter += f" --seed {seed}" + # 打印使用的seed值 + print(f"[生成] 使用的seed值: {seed}") + + # 处理参考图像 + if 'reference_images' in config and config['reference_images']: + # 确保是列表格式 + reference_images = config['reference_images'] + if not isinstance(reference_images, list): + reference_images = [reference_images] + + # 转换图片URL + image_urls = [] + for image_url in reference_images: + if is_valid_image_url(image_url): + try: + new_url = get_new_image_url(image_url) + if new_url: + image_urls.append(new_url) + except Exception as e: + logger.warning(f"转换图片URL失败: {str(e)}") + + # 添加到prompt前面 + if image_urls: + prompt = " ".join(image_urls) + " " + prompt + + # 设置图像权重 + if 'image_weight' in config and isinstance(config['image_weight'], (int, float)): + iw = max(0.1, min(config['image_weight'], 3)) # 限制在0.1到3之间 + parameter += f" --iw {iw}" + + # 添加字符引用 + if 'characters' in config and config['characters']: + char_urls = [] + characters = config['characters'] if isinstance(config['characters'], list) else [config['characters']] + + for char_url in characters: + if is_valid_image_url(char_url): + try: + new_url = get_new_image_url(char_url) + if new_url: + char_urls.append(new_url) + except Exception as e: + logger.warning(f"转换角色图片URL失败: {str(e)}") + + if char_urls: + prompt = prompt + " --cref " + " ".join(char_urls) + + # 添加风格引用 + if 'styles' in config and config['styles']: + style_urls = [] + styles = config['styles'] if isinstance(config['styles'], list) else [config['styles']] + + for style_url in styles: + if is_valid_image_url(style_url): + try: + new_url = get_new_image_url(style_url) + if new_url: + style_urls.append(new_url) + except Exception as e: + logger.warning(f"转换风格图片URL失败: {str(e)}") + + if style_urls: + prompt = prompt + " --sref " + " ".join(style_urls) + + # 添加参数 + prompt = prompt + parameter + + # 流式返回结果 + success_count = 0 + error_count = 0 + + async for result in self.imagine(client, guild_id, channel_id, prompt, seed): + if result.get("status") == "progress": + # 进度信息保持简单 + response = { + "status": "progress", + "progress": result.get("progress", 0), + "success_count": success_count, + "error_count": error_count + } + print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}") + yield response + elif result.get("status") == "success": + success_count += 1 + response = { + "status": "success", + "success_count": success_count, + "error_count": error_count + } + + # 如果需要分割图片 + if config.get("split_image", False) and result.get("images"): + try: + orig_image_url = result["images"][0] + split_urls = await self.split_image(orig_image_url) + if split_urls: + response["images"] = split_urls + except Exception as e: + error_count += 1 + logger.error(f"分割图像失败: {str(e)}") + response["error_count"] = error_count + + print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}") + yield response + else: + # 错误信息保持简单 + error_count += 1 + response = { + "status": "error", + "message": result.get("message", "未知错误"), + "success_count": success_count, + "error_count": error_count + } + print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}") + yield response \ No newline at end of file diff --git a/apps/midjourney/settings.py b/apps/midjourney/settings.py new file mode 100644 index 0000000..95322da --- /dev/null +++ b/apps/midjourney/settings.py @@ -0,0 +1,66 @@ +from pydantic_settings import BaseSettings +from typing import Optional, Dict +from functools import lru_cache + +class Settings(BaseSettings): + # Japi Server 配置 + host: str = "0.0.0.0" + port: int = 8113 + debug: bool = True + + # API路由配置 + router_prefix: str = "/midjourney" + generate_route: str = "/generate" # 生成图片的路由 + batch_route: str = "/batch" # 批量处理图片的路由 + api_name: str = "midjourney" # 默认API名称 + + upload_url: str = "http://images.jingrow.com:8080/api/v1/image" + + # 图片保存配置 + save_dir: str = "../jfile/midjourney" + # Japi 静态资源下载URL + download_url: str = "http://api.jingrow.com:9080/midjourney" + + # Jingrow Jcloud API 配置 + jingrow_api_url: str = "https://cloud.jingrow.com" + jingrow_api_key: Optional[str] = None + jingrow_api_secret: Optional[str] = None + + # Discord Midjourney配置 + midjourney_api_url: str = "https://discord.com/api/v9" + midjourney_application_id: str = "936929561302675456" + midjourney_data_id: str = "938956540159881230" + midjourney_data_version: str = "1237876415471554623" + midjourney_session_id: str = "a64ede0f3ce497d949e2f6f195c19029" + midjourney_channel_id: str = "1259838588510670941" + midjourney_oauth_token: str = "MTA4NzQ0MDY0MTU5MzcxNjc0Ng.GVDauj.6Cwr5EpXOfN9FpQU0-VfteR56XQOwLLUGYovG0" + midjourney_suffix: str = "mj" # 图片文件名的后缀 + + # 代理配置 + http_proxy: Optional[str] = "http://127.0.0.1:1080" # 默认HTTP代理 + https_proxy: Optional[str] = "http://127.0.0.1:1080" # 默认HTTPS代理 + + # Midjourney默认选项 + midjourney_default_options: Dict = { + "ar": "1:1", + "v": "6.1", + "quality": "1" + } + + # 图像设置 + add_title_to_image_name: bool = False # 是否将标题添加到图片名称中 + + # 超时设置(秒) + request_timeout: int = 30 + max_polling_attempts: int = 60 + polling_interval: int = 3 + + class Config: + env_file = ".env" + +@lru_cache() +def get_settings() -> Settings: + return Settings() + +# 创建全局配置实例 +settings = get_settings() \ No newline at end of file diff --git a/apps/midjourney/utils.py b/apps/midjourney/utils.py new file mode 100644 index 0000000..62bb8ce --- /dev/null +++ b/apps/midjourney/utils.py @@ -0,0 +1,316 @@ +import aiohttp +from functools import wraps +from fastapi import HTTPException +import os +from typing import Callable, Any, Dict, Optional, Tuple, List +from settings import settings +from fastapi.responses import StreamingResponse +import json +import requests +import io +import re +from pathlib import Path +from urllib.parse import urlparse +from PIL import Image + +async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]: + """验证API密钥和团队余额""" + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f"{settings.jingrow_api_url}/api/method/jcloud.api.account.verify_api_credentials_and_balance", + headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"}, + json={"api_key": api_key, "api_secret": api_secret, "api_name": api_name} + ) as response: + if response.status != 200: + raise HTTPException(status_code=500, detail="验证服务暂时不可用") + + result = await response.json() + if "message" in result and isinstance(result["message"], dict): + result = result["message"] + + if not result.get("success"): + raise HTTPException(status_code=401, detail=result.get("message", "验证失败")) + + return result + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}") + +async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]: + """从Jingrow平台扣除API使用费""" + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f"{settings.jingrow_api_url}/api/method/jcloud.api.account.deduct_api_usage_fee", + headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"}, + json={ + "api_key": api_key, + "api_secret": api_secret, + "api_name": api_name, + "usage_count": usage_count + } + ) as response: + if response.status != 200: + raise HTTPException(status_code=500, detail="扣费服务暂时不可用") + + result = await response.json() + if "message" in result and isinstance(result["message"], dict): + result = result["message"] + + return result + + except HTTPException: + raise + except Exception as e: + return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"} + +def get_token_from_request(request) -> str: + """从请求中获取访问令牌""" + if not request: + raise HTTPException(status_code=400, detail="无法获取请求信息") + + auth_header = request.headers.get("Authorization", "") + if not auth_header or not auth_header.startswith("token "): + raise HTTPException(status_code=401, detail="无效的Authorization头格式") + + token = auth_header[6:] + if ":" not in token: + raise HTTPException(status_code=401, detail="无效的令牌格式") + + return token + +def jingrow_api_verify_and_billing(api_name: str): + """Jingrow API 验证装饰器(带余额检查和扣费)""" + def decorator(func: Callable) -> Callable: + @wraps(func) + async def wrapper(*args, **kwargs): + try: + request = kwargs.get('request') + if not request: + raise HTTPException(status_code=400, detail="无法获取请求信息") + + token = get_token_from_request(request) + api_key, api_secret = token.split(":", 1) + + verify_result = await verify_api_credentials_and_balance(api_key, api_secret, api_name) + if not verify_result.get("success"): + raise HTTPException(status_code=401, detail=verify_result.get("message", "验证失败")) + + result = await func(*args, **kwargs) + + usage_count = 1 + try: + body_data = await request.json() + if isinstance(body_data, dict): + for key in ["items", "urls", "images", "files"]: + if key in body_data and isinstance(body_data[key], list): + usage_count = len(body_data[key]) + break + except Exception: + pass + + if isinstance(result, StreamingResponse): + original_generator = result.body_iterator + success_count = 0 + + async def wrapped_generator(): + nonlocal success_count + async for chunk in original_generator: + try: + data = json.loads(chunk) + if isinstance(data, dict) and data.get("status") == "success": + success_count += 1 + except: + pass + yield chunk + + if success_count > 0: + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count) + + return StreamingResponse( + wrapped_generator(), + media_type=result.media_type, + headers=result.headers + ) + + if isinstance(result, dict) and result.get("success") is True: + actual_usage_count = result.get("successful_count", usage_count) + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count) + return result + + await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count) + return result + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"API验证过程发生异常: {str(e)}") + return wrapper + return decorator + +def is_valid_image_url(url: str) -> bool: + """验证图片URL是否有效 + + Args: + url: 要验证的URL + + Returns: + bool: URL是否有效 + """ + if not url or not isinstance(url, str): + return False + + try: + parsed = urlparse(url) + if not parsed.scheme or not parsed.netloc: + return False + + # 检查文件扩展名 + path = parsed.path.lower() + valid_extensions = ['.jpg', '.jpeg', '.png', '.webp', '.gif'] + return any(path.endswith(ext) for ext in valid_extensions) + except: + return False + +def validate_image_file(file_path: str) -> bool: + """验证图片文件是否有效 + + Args: + file_path: 图片文件路径 + + Returns: + bool: 文件是否有效 + """ + try: + with Image.open(file_path) as img: + img.verify() + return True + except: + return False + +def get_image_size(image_url: str) -> Optional[Tuple[int, int]]: + """获取图片尺寸 + + Args: + image_url: 图片URL + + Returns: + Optional[Tuple[int, int]]: 图片尺寸(宽,高),如果获取失败则返回None + """ + try: + response = requests.get(image_url, verify=False, timeout=10) + if response.status_code != 200: + return None + + with Image.open(io.BytesIO(response.content)) as img: + return img.size + except: + return None + +def is_valid_image_size(image_url: str, min_size: int = 512) -> bool: + """验证图片尺寸是否满足最小要求 + + Args: + image_url: 图片URL + min_size: 最小尺寸要求 + + Returns: + bool: 图片尺寸是否满足要求 + """ + size = get_image_size(image_url) + if not size: + return False + width, height = size + return width >= min_size and height >= min_size + +def extract_image_urls_from_text(text: str) -> List[str]: + """从文本中提取图片URL + + Args: + text: 包含图片URL的文本 + + Returns: + List[str]: 提取到的图片URL列表 + """ + # 匹配常见的图片URL模式 + url_pattern = r'https?://[^\s<>"]+?\.(?:jpg|jpeg|png|webp|gif)(?:\?[^\s<>"]*)?' + urls = re.findall(url_pattern, text, re.IGNORECASE) + return [url for url in urls if is_valid_image_url(url)] + +def sanitize_filename(filename: str) -> str: + """清理文件名,移除非法字符 + + Args: + filename: 原始文件名 + + Returns: + str: 清理后的文件名 + """ + # 移除非法字符 + filename = re.sub(r'[<>:"/\\|?*]', '', filename) + # 限制长度 + if len(filename) > 255: + name, ext = os.path.splitext(filename) + filename = name[:255-len(ext)] + ext + return filename + +def get_new_image_url(image_url: str) -> str: + """将图片URL转换为新的存储URL + + Args: + image_url: 原始图片URL + + Returns: + str: 新的图片URL + + Raises: + HTTPException: 当图片处理失败时抛出 + """ + try: + # 使用settings中的upload_url + upload_url = settings.upload_url + if not upload_url: + raise HTTPException(status_code=500, detail="未配置上传URL") + + # 下载图片 + response = requests.get(image_url, verify=False, timeout=30) + if response.status_code != 200: + raise HTTPException(status_code=400, detail=f"无法下载图片: HTTP {response.status_code}") + image_data = response.content + + # 解析文件名和扩展名 + parsed_url = urlparse(image_url) + file_name = Path(parsed_url.path).name + file_name = sanitize_filename(file_name) + file_ext = Path(file_name).suffix.lower() + + # 如果图片是webp格式,转换为png格式 + if file_ext == '.webp': + image = Image.open(io.BytesIO(image_data)) + png_buffer = io.BytesIO() + image.save(png_buffer, format='PNG') + image_data = png_buffer.getvalue() + file_name = file_name.replace('.webp', '.png') + + # 准备文件上传 + files = {"file": (file_name, image_data)} + + # 上传图片 + upload_response = requests.post(upload_url, files=files, verify=False, timeout=30) + if upload_response.status_code != 200: + error_message = f"图片URL转换失败: 状态码 {upload_response.status_code}, 响应: {upload_response.text[:200]}" + raise HTTPException(status_code=500, detail=error_message) + + result = upload_response.json() + new_url = result.get("url") + if not new_url: + raise HTTPException(status_code=500, detail="上传成功但未返回URL") + return new_url + + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"图片URL转换异常: {str(e)}")