增加midjourney微服务

This commit is contained in:
jingrow 2025-05-19 22:34:45 +08:00
parent 4be051e459
commit 6e70ad3033
6 changed files with 1052 additions and 0 deletions

View File

69
apps/midjourney/api.py Normal file
View File

@ -0,0 +1,69 @@
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from service import MidjourneyService
from utils import jingrow_api_verify_and_billing
from settings import settings
import json
import asyncio
from typing import AsyncGenerator, List
router = APIRouter(prefix=settings.router_prefix)
service = MidjourneyService()
@router.post(settings.generate_route)
@jingrow_api_verify_and_billing(api_name=settings.api_name)
async def generate_image(data: dict, request: Request):
"""
根据文本提示生成图像
Args:
data: 包含文本提示和配置参数的字典
request: FastAPI 请求对象
Returns:
生成的图像内容
"""
if "prompt" not in data:
raise HTTPException(status_code=400, detail="缺少prompt参数")
prompt = data["prompt"]
config = data.get("config", {})
async def generate() -> AsyncGenerator[str, None]:
async for result in service.generate_image(prompt, config):
yield json.dumps(result, ensure_ascii=False) + "\n"
return StreamingResponse(
generate(),
media_type="application/x-ndjson",
headers={"X-Content-Type-Options": "nosniff"}
)
@router.post(settings.batch_route)
@jingrow_api_verify_and_billing(api_name=settings.api_name)
async def batch_process_images(data: dict, request: Request):
"""
批量处理多个图像URL将每张图片分割成4张并保存
Args:
data: 包含图片URLs列表的字典
request: FastAPI 请求对象
Returns:
处理结果的流式响应
"""
if "image_urls" not in data or not isinstance(data["image_urls"], list):
raise HTTPException(status_code=400, detail="缺少有效的image_urls参数")
image_urls: List[str] = data["image_urls"]
config = data.get("config", {})
async def process() -> AsyncGenerator[str, None]:
async for result in service.process_batch(image_urls, config):
yield json.dumps(result, ensure_ascii=False) + "\n"
return StreamingResponse(
process(),
media_type="application/x-ndjson",
headers={"X-Content-Type-Options": "nosniff"}
)

21
apps/midjourney/app.py Normal file
View File

@ -0,0 +1,21 @@
from fastapi import FastAPI
from settings import settings
from api import router
app = FastAPI(
title="Midjourney",
description="Midjourney绘画服务API",
version="1.0.0"
)
# 注册路由
app.include_router(router)
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"app:app",
host=settings.host,
port=settings.port,
reload=settings.debug
)

580
apps/midjourney/service.py Normal file
View File

@ -0,0 +1,580 @@
import json
import sys
import os
import io
import requests
import time
import random
import re
import uuid
import urllib.request
import urllib3
import traceback
import logging
import base64
from pathlib import Path
from urllib.parse import urlparse
from PIL import Image
import mimetypes
import asyncio
from typing import Dict, Any, List, AsyncGenerator, Optional
from settings import settings
from utils import get_new_image_url, is_valid_image_url
# 设置日志记录器
logger = logging.getLogger("midjourney_service")
# 禁用不安全请求警告
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
class MidjourneyService:
def __init__(self):
"""初始化MidjourneyService"""
self.API_URL = settings.midjourney_api_url
self.APPLICATION_ID = settings.midjourney_application_id
self.DATA_ID = settings.midjourney_data_id
self.DATA_VERSION = settings.midjourney_data_version
self.SESSION_ID = settings.midjourney_session_id
# 设置代理
self.proxies = {}
if settings.http_proxy and settings.http_proxy.strip():
self.proxies['http'] = settings.http_proxy
if settings.https_proxy and settings.https_proxy.strip():
self.proxies['https'] = settings.https_proxy
# 确保保存目录存在
os.makedirs(settings.save_dir, exist_ok=True)
@staticmethod
def first_where(array, key, value=None):
"""在数组中找到第一个匹配条件的项"""
for item in array:
if callable(key) and key(item):
return item
if isinstance(key, str) and item[key].startswith(value):
return item
return None
async def initialize_client(self, oauth_token, channel_id):
"""初始化Discord客户端会话"""
client = requests.Session()
client.headers.update({
'Authorization': oauth_token
})
# 设置代理
if self.proxies:
client.proxies.update(self.proxies)
try:
# 获取频道信息
response = client.get(f'{self.API_URL}/channels/{channel_id}')
data = response.json()
guild_id = data['guild_id']
# 获取用户信息
response = client.get(f'{self.API_URL}/users/@me')
data = response.json()
user_id = data['id']
return client, guild_id, user_id
except Exception as e:
logger.error(f"初始化Discord客户端失败: {str(e)}")
traceback.print_exc()
raise Exception(f"初始化Discord客户端失败: {str(e)}")
async def imagine(self, client, guild_id, channel_id, prompt, seed=None):
"""发送imagine请求到Discord"""
params = {
'type': 2,
'application_id': self.APPLICATION_ID,
'guild_id': guild_id,
'channel_id': channel_id,
'session_id': self.SESSION_ID,
'data': {
'version': self.DATA_VERSION,
'id': self.DATA_ID,
'name': 'imagine',
'type': 1,
'options': [{
'type': 3,
'name': 'prompt',
'value': prompt
}],
'application_command': {
'id': self.DATA_ID,
'application_id': self.APPLICATION_ID,
'version': self.DATA_VERSION,
'default_member_permissions': None,
'type': 1,
'nsfw': False,
'name': 'imagine',
'description': 'Create images with Midjourney',
'dm_permission': True,
'options': [{
'type': 3,
'name': 'prompt',
'description': 'The prompt to imagine',
'required': True
}]
},
'attachments': []
}
}
try:
# 发送请求
r = client.post(f'{self.API_URL}/interactions', json=params)
# 初始等待时间从5秒延长到30秒给Discord足够的时间开始处理请求
print(f"[生成] 已发送请求等待30秒后开始轮询...")
await asyncio.sleep(30)
# 轮询获取结果
imagine_message = None
count = 0
last_progress = 0
# 轮询直到获取完整的结果或达到最大次数
while count < settings.max_polling_attempts:
# 轮询两种消息:
# 1. 进度消息 - 用于更新进度
progress_msg = await self.get_progress_message(client, channel_id, prompt, seed)
if progress_msg:
content = progress_msg.get('content', '')
progress_match = re.search(r'(?:\(|)?(\d+)%(?:\)|)?', content)
if progress_match:
progress_value = int(progress_match.group(1))
# 只有进度有变化时才发送更新
if progress_value > last_progress:
last_progress = progress_value
logger.info(f"生成进度: {progress_value}%")
yield {
"status": "progress",
"progress": progress_value,
"message_id": progress_msg.get('id'),
"content": content
}
# 2. 完成的消息 - 包含图像结果
imagine_message = await self.get_imagine(client, channel_id, prompt, count, seed)
if imagine_message:
# 找到最终结果
break
# 没有找到结果,继续等待
logger.info(f"轮询尝试 {count+1}/{settings.max_polling_attempts}: 继续等待")
await asyncio.sleep(settings.polling_interval)
count += 1
# 检查是否超过最大轮询次数
if count >= settings.max_polling_attempts:
logger.error(f"轮询超过最大尝试次数: {settings.max_polling_attempts}")
yield {"status": "error", "message": "获取结果超时,超过最大轮询次数"}
return
# 检查是否有有效的最终结果
if not imagine_message:
logger.error("轮询结束但没有获取到有效结果")
yield {"status": "error", "message": "没有获取到有效结果"}
return
# 返回最终结果
logger.info(f"成功获取图像结果消息ID: {imagine_message.get('id')}")
yield {
"status": "success",
"message_id": imagine_message.get('id'),
"content": imagine_message.get('content', ''),
"images": self.extract_image_urls(imagine_message)
}
except Exception as e:
logger.error(f"发送Imagine请求失败: {str(e)}")
traceback.print_exc()
yield {"status": "error", "message": f"发送Imagine请求失败: {str(e)}"}
async def get_imagine(self, client, channel_id, prompt, count=0, seed=None):
"""获取生成图像的消息"""
try:
# 获取最近的消息
response = client.get(f'{self.API_URL}/channels/{channel_id}/messages?limit=10')
data = response.json()
def criteria(item):
content = item.get('content', '')
# 检查进度信息并更新状态
if seed is not None and f"--seed {seed}" in content:
# 匹配百分比,支持多种格式:(93%) 或 93%
progress_match = re.search(r'(?:\(|)?(\d+)%(?:\)|)?', content)
if progress_match:
progress_value = int(progress_match.group(1))
# 记录进度信息
logger.info(f"任务进度: {progress_value}%")
# 如果消息包含百分比说明任务还在进行中返回False继续轮询
if "%" in content:
print(f"[轮询] 消息ID: {item.get('id')} 内容: {content[:100]}... 包含进度信息{progress_value}%,继续等待完成")
return False
# 排除进行中消息,只匹配完成的消息
if "%" in content:
return False
# seed 匹配
if seed is not None:
seed_pattern = f"--seed {seed}"
# 检查是否包含指定的seed
if seed_pattern not in content:
return False
else:
print(f"[轮询] 消息ID: {item.get('id')} 内容: {content[:100]}... 匹配seed并且任务已完成")
# 检查是否包含图像附件
if 'attachments' in item and len(item.get('attachments', [])) > 0:
return True
return False # 默认不匹配
# 查找匹配完成状态的消息
raw_message = self.first_where(data, criteria)
if raw_message is None:
print("[轮询] 没有找到完成的消息,继续等待")
return None
# 打印匹配到的完整消息内容
print(f"[轮询] 匹配到完成的消息ID: {raw_message.get('id')}")
try:
print(f"[轮询] 完整消息内容: {json.dumps(raw_message, indent=2, ensure_ascii=False)}")
except Exception as e:
print(f"[轮询] 无法打印完整消息内容: {str(e)}")
# 尝试打印关键字段
print(f"[轮询] 消息内容: {raw_message.get('content', '无内容')}")
print(f"[轮询] 附件数量: {len(raw_message.get('attachments', []))}")
for i, attachment in enumerate(raw_message.get('attachments', [])):
print(f"[轮询] 附件 {i+1} URL: {attachment.get('url', '无URL')}")
return raw_message
except Exception as e:
logger.error(f"获取通道消息失败: {str(e)}")
return None
async def get_progress_message(self, client, channel_id, prompt, seed=None):
"""获取包含进度信息的消息"""
try:
# 获取最近的消息
response = client.get(f'{self.API_URL}/channels/{channel_id}/messages?limit=10')
data = response.json()
for item in data:
content = item.get('content', '')
# 检查是否包含seed和进度信息
if seed is not None and f"--seed {seed}" in content and "%" in content:
progress_match = re.search(r'(?:\(|)?(\d+)%(?:\)|)?', content)
if progress_match:
return item
return None
except Exception as e:
logger.error(f"获取进度消息失败: {str(e)}")
return None
def extract_image_urls(self, message):
"""从消息中提取图像URL"""
image_urls = []
if 'attachments' in message and message['attachments']:
for attachment in message['attachments']:
if 'url' in attachment:
image_urls.append(attachment['url'])
return image_urls
async def split_image(self, image_url):
"""将一张大图切割成四张子图保存到本地并返回URLs"""
try:
# 下载图像
response = requests.get(image_url, proxies=self.proxies if self.proxies else None)
if response.status_code != 200:
logger.error(f"下载图像失败,状态码: {response.status_code}")
return None
image_data = response.content
# 从二进制数据创建图像对象
img = Image.open(io.BytesIO(image_data))
width, height = img.size
# 确认图像尺寸约为2048x2048
if width < 1500 or height < 1500:
logger.error(f"图像尺寸不符合预期: {width}x{height}")
return None
# 计算每个象限的尺寸
half_width = width // 2
half_height = height // 2
# 分割图像为四个象限
top_left = img.crop((0, 0, half_width, half_height))
top_right = img.crop((half_width, 0, width, half_height))
bottom_left = img.crop((0, half_height, half_width, height))
bottom_right = img.crop((half_width, half_height, width, height))
# 生成唯一的图片名称前缀
image_id = uuid.uuid4().hex[:10]
# 保存图片到本地并生成URLs
image_urls = []
for i, quadrant in enumerate([top_left, top_right, bottom_left, bottom_right], 1):
# 生成文件名和保存路径
filename = f"split_{image_id}_{i}.png"
file_path = os.path.join(settings.save_dir, filename)
# 保存图片
quadrant.save(file_path, format="PNG")
# 构建图片URL
image_url = f"{settings.download_url}/{filename}"
image_urls.append(image_url)
return image_urls
except Exception as e:
logger.error(f"分割图像失败: {str(e)}")
traceback.print_exc()
return None
async def process_batch(self, image_urls: List[str], config: Optional[Dict] = None):
"""批量处理多个图像URL"""
if not config:
config = {}
total = len(image_urls)
success_count = 0
error_count = 0
for i, image_url in enumerate(image_urls, 1):
try:
if not is_valid_image_url(image_url):
error_count += 1
response = {
"status": "error",
"index": i,
"total": total,
"success_count": success_count,
"error_count": error_count,
"message": "无效的图片URL"
}
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
yield response
continue
split_urls = await self.split_image(image_url)
if split_urls and len(split_urls) == 4:
success_count += 1
response = {
"status": "success",
"index": i,
"total": total,
"success_count": success_count,
"error_count": error_count,
"images": split_urls
}
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
yield response
else:
error_count += 1
response = {
"status": "error",
"index": i,
"total": total,
"success_count": success_count,
"error_count": error_count,
"message": "分割图片失败"
}
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
yield response
except Exception as e:
error_count += 1
response = {
"status": "error",
"index": i,
"total": total,
"success_count": success_count,
"error_count": error_count,
"message": str(e)
}
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
yield response
async def generate_image(self, prompt, config=None):
"""生成图像并以流式方式返回结果"""
if not config:
config = {}
# 获取必要的认证信息
oauth_token = config.get('oauth_token', settings.midjourney_oauth_token)
channel_id = config.get('channel_id', settings.midjourney_channel_id)
if not oauth_token or not channel_id:
response = {
"status": "error",
"message": "缺少Discord配置",
"success_count": 0,
"error_count": 1
}
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
yield response
return
# 初始化客户端
try:
client, guild_id, user_id = await self.initialize_client(oauth_token, channel_id)
except Exception as e:
response = {
"status": "error",
"message": str(e),
"success_count": 0,
"error_count": 1
}
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
yield response
return
# 解析和准备提示词
prompt = prompt.strip()
# 获取或设置默认选项
options = settings.midjourney_default_options.copy()
if 'options' in config:
options.update(config.get('options', {}))
# 强制设置seed如果没有
if "seed" not in options or not options.get('seed'):
options['seed'] = random.randint(0, 4294967295)
# 处理选项,构建参数字符串
seed = options.pop('seed', None)
parameter = ""
no_value_key = ['relax', 'fast', 'turbo', 'tile']
for key, value in options.items():
if key in no_value_key:
parameter += f" --{key}"
else:
parameter += f" --{key} {value}"
# 添加seed
if seed:
parameter += f" --seed {seed}"
# 打印使用的seed值
print(f"[生成] 使用的seed值: {seed}")
# 处理参考图像
if 'reference_images' in config and config['reference_images']:
# 确保是列表格式
reference_images = config['reference_images']
if not isinstance(reference_images, list):
reference_images = [reference_images]
# 转换图片URL
image_urls = []
for image_url in reference_images:
if is_valid_image_url(image_url):
try:
new_url = get_new_image_url(image_url)
if new_url:
image_urls.append(new_url)
except Exception as e:
logger.warning(f"转换图片URL失败: {str(e)}")
# 添加到prompt前面
if image_urls:
prompt = " ".join(image_urls) + " " + prompt
# 设置图像权重
if 'image_weight' in config and isinstance(config['image_weight'], (int, float)):
iw = max(0.1, min(config['image_weight'], 3)) # 限制在0.1到3之间
parameter += f" --iw {iw}"
# 添加字符引用
if 'characters' in config and config['characters']:
char_urls = []
characters = config['characters'] if isinstance(config['characters'], list) else [config['characters']]
for char_url in characters:
if is_valid_image_url(char_url):
try:
new_url = get_new_image_url(char_url)
if new_url:
char_urls.append(new_url)
except Exception as e:
logger.warning(f"转换角色图片URL失败: {str(e)}")
if char_urls:
prompt = prompt + " --cref " + " ".join(char_urls)
# 添加风格引用
if 'styles' in config and config['styles']:
style_urls = []
styles = config['styles'] if isinstance(config['styles'], list) else [config['styles']]
for style_url in styles:
if is_valid_image_url(style_url):
try:
new_url = get_new_image_url(style_url)
if new_url:
style_urls.append(new_url)
except Exception as e:
logger.warning(f"转换风格图片URL失败: {str(e)}")
if style_urls:
prompt = prompt + " --sref " + " ".join(style_urls)
# 添加参数
prompt = prompt + parameter
# 流式返回结果
success_count = 0
error_count = 0
async for result in self.imagine(client, guild_id, channel_id, prompt, seed):
if result.get("status") == "progress":
# 进度信息保持简单
response = {
"status": "progress",
"progress": result.get("progress", 0),
"success_count": success_count,
"error_count": error_count
}
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
yield response
elif result.get("status") == "success":
success_count += 1
response = {
"status": "success",
"success_count": success_count,
"error_count": error_count
}
# 如果需要分割图片
if config.get("split_image", False) and result.get("images"):
try:
orig_image_url = result["images"][0]
split_urls = await self.split_image(orig_image_url)
if split_urls:
response["images"] = split_urls
except Exception as e:
error_count += 1
logger.error(f"分割图像失败: {str(e)}")
response["error_count"] = error_count
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
yield response
else:
# 错误信息保持简单
error_count += 1
response = {
"status": "error",
"message": result.get("message", "未知错误"),
"success_count": success_count,
"error_count": error_count
}
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
yield response

View File

@ -0,0 +1,66 @@
from pydantic_settings import BaseSettings
from typing import Optional, Dict
from functools import lru_cache
class Settings(BaseSettings):
# Japi Server 配置
host: str = "0.0.0.0"
port: int = 8113
debug: bool = True
# API路由配置
router_prefix: str = "/midjourney"
generate_route: str = "/generate" # 生成图片的路由
batch_route: str = "/batch" # 批量处理图片的路由
api_name: str = "midjourney" # 默认API名称
upload_url: str = "http://images.jingrow.com:8080/api/v1/image"
# 图片保存配置
save_dir: str = "../jfile/midjourney"
# Japi 静态资源下载URL
download_url: str = "http://api.jingrow.com:9080/midjourney"
# Jingrow Jcloud API 配置
jingrow_api_url: str = "https://cloud.jingrow.com"
jingrow_api_key: Optional[str] = None
jingrow_api_secret: Optional[str] = None
# Discord Midjourney配置
midjourney_api_url: str = "https://discord.com/api/v9"
midjourney_application_id: str = "936929561302675456"
midjourney_data_id: str = "938956540159881230"
midjourney_data_version: str = "1237876415471554623"
midjourney_session_id: str = "a64ede0f3ce497d949e2f6f195c19029"
midjourney_channel_id: str = "1259838588510670941"
midjourney_oauth_token: str = "MTA4NzQ0MDY0MTU5MzcxNjc0Ng.GVDauj.6Cwr5EpXOfN9FpQU0-VfteR56XQOwLLUGYovG0"
midjourney_suffix: str = "mj" # 图片文件名的后缀
# 代理配置
http_proxy: Optional[str] = "http://127.0.0.1:1080" # 默认HTTP代理
https_proxy: Optional[str] = "http://127.0.0.1:1080" # 默认HTTPS代理
# Midjourney默认选项
midjourney_default_options: Dict = {
"ar": "1:1",
"v": "6.1",
"quality": "1"
}
# 图像设置
add_title_to_image_name: bool = False # 是否将标题添加到图片名称中
# 超时设置(秒)
request_timeout: int = 30
max_polling_attempts: int = 60
polling_interval: int = 3
class Config:
env_file = ".env"
@lru_cache()
def get_settings() -> Settings:
return Settings()
# 创建全局配置实例
settings = get_settings()

316
apps/midjourney/utils.py Normal file
View File

@ -0,0 +1,316 @@
import aiohttp
from functools import wraps
from fastapi import HTTPException
import os
from typing import Callable, Any, Dict, Optional, Tuple, List
from settings import settings
from fastapi.responses import StreamingResponse
import json
import requests
import io
import re
from pathlib import Path
from urllib.parse import urlparse
from PIL import Image
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
"""验证API密钥和团队余额"""
try:
async with aiohttp.ClientSession() as session:
async with session.post(
f"{settings.jingrow_api_url}/api/method/jcloud.api.account.verify_api_credentials_and_balance",
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
json={"api_key": api_key, "api_secret": api_secret, "api_name": api_name}
) as response:
if response.status != 200:
raise HTTPException(status_code=500, detail="验证服务暂时不可用")
result = await response.json()
if "message" in result and isinstance(result["message"], dict):
result = result["message"]
if not result.get("success"):
raise HTTPException(status_code=401, detail=result.get("message", "验证失败"))
return result
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
"""从Jingrow平台扣除API使用费"""
try:
async with aiohttp.ClientSession() as session:
async with session.post(
f"{settings.jingrow_api_url}/api/method/jcloud.api.account.deduct_api_usage_fee",
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
json={
"api_key": api_key,
"api_secret": api_secret,
"api_name": api_name,
"usage_count": usage_count
}
) as response:
if response.status != 200:
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
result = await response.json()
if "message" in result and isinstance(result["message"], dict):
result = result["message"]
return result
except HTTPException:
raise
except Exception as e:
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
def get_token_from_request(request) -> str:
"""从请求中获取访问令牌"""
if not request:
raise HTTPException(status_code=400, detail="无法获取请求信息")
auth_header = request.headers.get("Authorization", "")
if not auth_header or not auth_header.startswith("token "):
raise HTTPException(status_code=401, detail="无效的Authorization头格式")
token = auth_header[6:]
if ":" not in token:
raise HTTPException(status_code=401, detail="无效的令牌格式")
return token
def jingrow_api_verify_and_billing(api_name: str):
"""Jingrow API 验证装饰器(带余额检查和扣费)"""
def decorator(func: Callable) -> Callable:
@wraps(func)
async def wrapper(*args, **kwargs):
try:
request = kwargs.get('request')
if not request:
raise HTTPException(status_code=400, detail="无法获取请求信息")
token = get_token_from_request(request)
api_key, api_secret = token.split(":", 1)
verify_result = await verify_api_credentials_and_balance(api_key, api_secret, api_name)
if not verify_result.get("success"):
raise HTTPException(status_code=401, detail=verify_result.get("message", "验证失败"))
result = await func(*args, **kwargs)
usage_count = 1
try:
body_data = await request.json()
if isinstance(body_data, dict):
for key in ["items", "urls", "images", "files"]:
if key in body_data and isinstance(body_data[key], list):
usage_count = len(body_data[key])
break
except Exception:
pass
if isinstance(result, StreamingResponse):
original_generator = result.body_iterator
success_count = 0
async def wrapped_generator():
nonlocal success_count
async for chunk in original_generator:
try:
data = json.loads(chunk)
if isinstance(data, dict) and data.get("status") == "success":
success_count += 1
except:
pass
yield chunk
if success_count > 0:
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
return StreamingResponse(
wrapped_generator(),
media_type=result.media_type,
headers=result.headers
)
if isinstance(result, dict) and result.get("success") is True:
actual_usage_count = result.get("successful_count", usage_count)
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
return result
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
return result
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"API验证过程发生异常: {str(e)}")
return wrapper
return decorator
def is_valid_image_url(url: str) -> bool:
"""验证图片URL是否有效
Args:
url: 要验证的URL
Returns:
bool: URL是否有效
"""
if not url or not isinstance(url, str):
return False
try:
parsed = urlparse(url)
if not parsed.scheme or not parsed.netloc:
return False
# 检查文件扩展名
path = parsed.path.lower()
valid_extensions = ['.jpg', '.jpeg', '.png', '.webp', '.gif']
return any(path.endswith(ext) for ext in valid_extensions)
except:
return False
def validate_image_file(file_path: str) -> bool:
"""验证图片文件是否有效
Args:
file_path: 图片文件路径
Returns:
bool: 文件是否有效
"""
try:
with Image.open(file_path) as img:
img.verify()
return True
except:
return False
def get_image_size(image_url: str) -> Optional[Tuple[int, int]]:
"""获取图片尺寸
Args:
image_url: 图片URL
Returns:
Optional[Tuple[int, int]]: 图片尺寸(,)如果获取失败则返回None
"""
try:
response = requests.get(image_url, verify=False, timeout=10)
if response.status_code != 200:
return None
with Image.open(io.BytesIO(response.content)) as img:
return img.size
except:
return None
def is_valid_image_size(image_url: str, min_size: int = 512) -> bool:
"""验证图片尺寸是否满足最小要求
Args:
image_url: 图片URL
min_size: 最小尺寸要求
Returns:
bool: 图片尺寸是否满足要求
"""
size = get_image_size(image_url)
if not size:
return False
width, height = size
return width >= min_size and height >= min_size
def extract_image_urls_from_text(text: str) -> List[str]:
"""从文本中提取图片URL
Args:
text: 包含图片URL的文本
Returns:
List[str]: 提取到的图片URL列表
"""
# 匹配常见的图片URL模式
url_pattern = r'https?://[^\s<>"]+?\.(?:jpg|jpeg|png|webp|gif)(?:\?[^\s<>"]*)?'
urls = re.findall(url_pattern, text, re.IGNORECASE)
return [url for url in urls if is_valid_image_url(url)]
def sanitize_filename(filename: str) -> str:
"""清理文件名,移除非法字符
Args:
filename: 原始文件名
Returns:
str: 清理后的文件名
"""
# 移除非法字符
filename = re.sub(r'[<>:"/\\|?*]', '', filename)
# 限制长度
if len(filename) > 255:
name, ext = os.path.splitext(filename)
filename = name[:255-len(ext)] + ext
return filename
def get_new_image_url(image_url: str) -> str:
"""将图片URL转换为新的存储URL
Args:
image_url: 原始图片URL
Returns:
str: 新的图片URL
Raises:
HTTPException: 当图片处理失败时抛出
"""
try:
# 使用settings中的upload_url
upload_url = settings.upload_url
if not upload_url:
raise HTTPException(status_code=500, detail="未配置上传URL")
# 下载图片
response = requests.get(image_url, verify=False, timeout=30)
if response.status_code != 200:
raise HTTPException(status_code=400, detail=f"无法下载图片: HTTP {response.status_code}")
image_data = response.content
# 解析文件名和扩展名
parsed_url = urlparse(image_url)
file_name = Path(parsed_url.path).name
file_name = sanitize_filename(file_name)
file_ext = Path(file_name).suffix.lower()
# 如果图片是webp格式转换为png格式
if file_ext == '.webp':
image = Image.open(io.BytesIO(image_data))
png_buffer = io.BytesIO()
image.save(png_buffer, format='PNG')
image_data = png_buffer.getvalue()
file_name = file_name.replace('.webp', '.png')
# 准备文件上传
files = {"file": (file_name, image_data)}
# 上传图片
upload_response = requests.post(upload_url, files=files, verify=False, timeout=30)
if upload_response.status_code != 200:
error_message = f"图片URL转换失败: 状态码 {upload_response.status_code}, 响应: {upload_response.text[:200]}"
raise HTTPException(status_code=500, detail=error_message)
result = upload_response.json()
new_url = result.get("url")
if not new_url:
raise HTTPException(status_code=500, detail="上传成功但未返回URL")
return new_url
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"图片URL转换异常: {str(e)}")