增加midjourney微服务
This commit is contained in:
parent
4be051e459
commit
6e70ad3033
0
apps/midjourney/__init__.py
Normal file
0
apps/midjourney/__init__.py
Normal file
69
apps/midjourney/api.py
Normal file
69
apps/midjourney/api.py
Normal file
@ -0,0 +1,69 @@
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from service import MidjourneyService
|
||||
from utils import jingrow_api_verify_and_billing
|
||||
from settings import settings
|
||||
import json
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, List
|
||||
|
||||
router = APIRouter(prefix=settings.router_prefix)
|
||||
service = MidjourneyService()
|
||||
|
||||
@router.post(settings.generate_route)
|
||||
@jingrow_api_verify_and_billing(api_name=settings.api_name)
|
||||
async def generate_image(data: dict, request: Request):
|
||||
"""
|
||||
根据文本提示生成图像
|
||||
|
||||
Args:
|
||||
data: 包含文本提示和配置参数的字典
|
||||
request: FastAPI 请求对象
|
||||
|
||||
Returns:
|
||||
生成的图像内容
|
||||
"""
|
||||
if "prompt" not in data:
|
||||
raise HTTPException(status_code=400, detail="缺少prompt参数")
|
||||
|
||||
prompt = data["prompt"]
|
||||
config = data.get("config", {})
|
||||
|
||||
async def generate() -> AsyncGenerator[str, None]:
|
||||
async for result in service.generate_image(prompt, config):
|
||||
yield json.dumps(result, ensure_ascii=False) + "\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="application/x-ndjson",
|
||||
headers={"X-Content-Type-Options": "nosniff"}
|
||||
)
|
||||
|
||||
@router.post(settings.batch_route)
|
||||
@jingrow_api_verify_and_billing(api_name=settings.api_name)
|
||||
async def batch_process_images(data: dict, request: Request):
|
||||
"""
|
||||
批量处理多个图像URL,将每张图片分割成4张并保存
|
||||
|
||||
Args:
|
||||
data: 包含图片URLs列表的字典
|
||||
request: FastAPI 请求对象
|
||||
|
||||
Returns:
|
||||
处理结果的流式响应
|
||||
"""
|
||||
if "image_urls" not in data or not isinstance(data["image_urls"], list):
|
||||
raise HTTPException(status_code=400, detail="缺少有效的image_urls参数")
|
||||
|
||||
image_urls: List[str] = data["image_urls"]
|
||||
config = data.get("config", {})
|
||||
|
||||
async def process() -> AsyncGenerator[str, None]:
|
||||
async for result in service.process_batch(image_urls, config):
|
||||
yield json.dumps(result, ensure_ascii=False) + "\n"
|
||||
|
||||
return StreamingResponse(
|
||||
process(),
|
||||
media_type="application/x-ndjson",
|
||||
headers={"X-Content-Type-Options": "nosniff"}
|
||||
)
|
||||
21
apps/midjourney/app.py
Normal file
21
apps/midjourney/app.py
Normal file
@ -0,0 +1,21 @@
|
||||
from fastapi import FastAPI
|
||||
from settings import settings
|
||||
from api import router
|
||||
|
||||
app = FastAPI(
|
||||
title="Midjourney",
|
||||
description="Midjourney绘画服务API",
|
||||
version="1.0.0"
|
||||
)
|
||||
|
||||
# 注册路由
|
||||
app.include_router(router)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(
|
||||
"app:app",
|
||||
host=settings.host,
|
||||
port=settings.port,
|
||||
reload=settings.debug
|
||||
)
|
||||
580
apps/midjourney/service.py
Normal file
580
apps/midjourney/service.py
Normal file
@ -0,0 +1,580 @@
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
import io
|
||||
import requests
|
||||
import time
|
||||
import random
|
||||
import re
|
||||
import uuid
|
||||
import urllib.request
|
||||
import urllib3
|
||||
import traceback
|
||||
import logging
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
from PIL import Image
|
||||
import mimetypes
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, AsyncGenerator, Optional
|
||||
from settings import settings
|
||||
from utils import get_new_image_url, is_valid_image_url
|
||||
|
||||
|
||||
# 设置日志记录器
|
||||
logger = logging.getLogger("midjourney_service")
|
||||
|
||||
# 禁用不安全请求警告
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
class MidjourneyService:
|
||||
def __init__(self):
|
||||
"""初始化MidjourneyService"""
|
||||
self.API_URL = settings.midjourney_api_url
|
||||
self.APPLICATION_ID = settings.midjourney_application_id
|
||||
self.DATA_ID = settings.midjourney_data_id
|
||||
self.DATA_VERSION = settings.midjourney_data_version
|
||||
self.SESSION_ID = settings.midjourney_session_id
|
||||
|
||||
# 设置代理
|
||||
self.proxies = {}
|
||||
if settings.http_proxy and settings.http_proxy.strip():
|
||||
self.proxies['http'] = settings.http_proxy
|
||||
if settings.https_proxy and settings.https_proxy.strip():
|
||||
self.proxies['https'] = settings.https_proxy
|
||||
|
||||
# 确保保存目录存在
|
||||
os.makedirs(settings.save_dir, exist_ok=True)
|
||||
|
||||
@staticmethod
|
||||
def first_where(array, key, value=None):
|
||||
"""在数组中找到第一个匹配条件的项"""
|
||||
for item in array:
|
||||
if callable(key) and key(item):
|
||||
return item
|
||||
if isinstance(key, str) and item[key].startswith(value):
|
||||
return item
|
||||
return None
|
||||
|
||||
async def initialize_client(self, oauth_token, channel_id):
|
||||
"""初始化Discord客户端会话"""
|
||||
client = requests.Session()
|
||||
client.headers.update({
|
||||
'Authorization': oauth_token
|
||||
})
|
||||
|
||||
# 设置代理
|
||||
if self.proxies:
|
||||
client.proxies.update(self.proxies)
|
||||
|
||||
try:
|
||||
# 获取频道信息
|
||||
response = client.get(f'{self.API_URL}/channels/{channel_id}')
|
||||
data = response.json()
|
||||
guild_id = data['guild_id']
|
||||
|
||||
# 获取用户信息
|
||||
response = client.get(f'{self.API_URL}/users/@me')
|
||||
data = response.json()
|
||||
user_id = data['id']
|
||||
|
||||
return client, guild_id, user_id
|
||||
except Exception as e:
|
||||
logger.error(f"初始化Discord客户端失败: {str(e)}")
|
||||
traceback.print_exc()
|
||||
raise Exception(f"初始化Discord客户端失败: {str(e)}")
|
||||
|
||||
async def imagine(self, client, guild_id, channel_id, prompt, seed=None):
|
||||
"""发送imagine请求到Discord"""
|
||||
params = {
|
||||
'type': 2,
|
||||
'application_id': self.APPLICATION_ID,
|
||||
'guild_id': guild_id,
|
||||
'channel_id': channel_id,
|
||||
'session_id': self.SESSION_ID,
|
||||
'data': {
|
||||
'version': self.DATA_VERSION,
|
||||
'id': self.DATA_ID,
|
||||
'name': 'imagine',
|
||||
'type': 1,
|
||||
'options': [{
|
||||
'type': 3,
|
||||
'name': 'prompt',
|
||||
'value': prompt
|
||||
}],
|
||||
'application_command': {
|
||||
'id': self.DATA_ID,
|
||||
'application_id': self.APPLICATION_ID,
|
||||
'version': self.DATA_VERSION,
|
||||
'default_member_permissions': None,
|
||||
'type': 1,
|
||||
'nsfw': False,
|
||||
'name': 'imagine',
|
||||
'description': 'Create images with Midjourney',
|
||||
'dm_permission': True,
|
||||
'options': [{
|
||||
'type': 3,
|
||||
'name': 'prompt',
|
||||
'description': 'The prompt to imagine',
|
||||
'required': True
|
||||
}]
|
||||
},
|
||||
'attachments': []
|
||||
}
|
||||
}
|
||||
|
||||
try:
|
||||
# 发送请求
|
||||
r = client.post(f'{self.API_URL}/interactions', json=params)
|
||||
# 初始等待时间从5秒延长到30秒,给Discord足够的时间开始处理请求
|
||||
print(f"[生成] 已发送请求,等待30秒后开始轮询...")
|
||||
await asyncio.sleep(30)
|
||||
|
||||
# 轮询获取结果
|
||||
imagine_message = None
|
||||
count = 0
|
||||
last_progress = 0
|
||||
|
||||
# 轮询直到获取完整的结果或达到最大次数
|
||||
while count < settings.max_polling_attempts:
|
||||
# 轮询两种消息:
|
||||
# 1. 进度消息 - 用于更新进度
|
||||
progress_msg = await self.get_progress_message(client, channel_id, prompt, seed)
|
||||
if progress_msg:
|
||||
content = progress_msg.get('content', '')
|
||||
progress_match = re.search(r'(?:\(|()?(\d+)%(?:\)|))?', content)
|
||||
if progress_match:
|
||||
progress_value = int(progress_match.group(1))
|
||||
# 只有进度有变化时才发送更新
|
||||
if progress_value > last_progress:
|
||||
last_progress = progress_value
|
||||
logger.info(f"生成进度: {progress_value}%")
|
||||
yield {
|
||||
"status": "progress",
|
||||
"progress": progress_value,
|
||||
"message_id": progress_msg.get('id'),
|
||||
"content": content
|
||||
}
|
||||
|
||||
# 2. 完成的消息 - 包含图像结果
|
||||
imagine_message = await self.get_imagine(client, channel_id, prompt, count, seed)
|
||||
if imagine_message:
|
||||
# 找到最终结果
|
||||
break
|
||||
|
||||
# 没有找到结果,继续等待
|
||||
logger.info(f"轮询尝试 {count+1}/{settings.max_polling_attempts}: 继续等待")
|
||||
await asyncio.sleep(settings.polling_interval)
|
||||
count += 1
|
||||
|
||||
# 检查是否超过最大轮询次数
|
||||
if count >= settings.max_polling_attempts:
|
||||
logger.error(f"轮询超过最大尝试次数: {settings.max_polling_attempts}")
|
||||
yield {"status": "error", "message": "获取结果超时,超过最大轮询次数"}
|
||||
return
|
||||
|
||||
# 检查是否有有效的最终结果
|
||||
if not imagine_message:
|
||||
logger.error("轮询结束但没有获取到有效结果")
|
||||
yield {"status": "error", "message": "没有获取到有效结果"}
|
||||
return
|
||||
|
||||
# 返回最终结果
|
||||
logger.info(f"成功获取图像结果,消息ID: {imagine_message.get('id')}")
|
||||
yield {
|
||||
"status": "success",
|
||||
"message_id": imagine_message.get('id'),
|
||||
"content": imagine_message.get('content', ''),
|
||||
"images": self.extract_image_urls(imagine_message)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送Imagine请求失败: {str(e)}")
|
||||
traceback.print_exc()
|
||||
yield {"status": "error", "message": f"发送Imagine请求失败: {str(e)}"}
|
||||
|
||||
async def get_imagine(self, client, channel_id, prompt, count=0, seed=None):
|
||||
"""获取生成图像的消息"""
|
||||
try:
|
||||
# 获取最近的消息
|
||||
response = client.get(f'{self.API_URL}/channels/{channel_id}/messages?limit=10')
|
||||
data = response.json()
|
||||
|
||||
def criteria(item):
|
||||
content = item.get('content', '')
|
||||
|
||||
# 检查进度信息并更新状态
|
||||
if seed is not None and f"--seed {seed}" in content:
|
||||
# 匹配百分比,支持多种格式:(93%) 或 93%
|
||||
progress_match = re.search(r'(?:\(|()?(\d+)%(?:\)|))?', content)
|
||||
if progress_match:
|
||||
progress_value = int(progress_match.group(1))
|
||||
# 记录进度信息
|
||||
logger.info(f"任务进度: {progress_value}%")
|
||||
# 如果消息包含百分比,说明任务还在进行中,返回False继续轮询
|
||||
if "%" in content:
|
||||
print(f"[轮询] 消息ID: {item.get('id')} 内容: {content[:100]}... 包含进度信息{progress_value}%,继续等待完成")
|
||||
return False
|
||||
|
||||
# 排除进行中消息,只匹配完成的消息
|
||||
if "%" in content:
|
||||
return False
|
||||
|
||||
# seed 匹配
|
||||
if seed is not None:
|
||||
seed_pattern = f"--seed {seed}"
|
||||
# 检查是否包含指定的seed
|
||||
if seed_pattern not in content:
|
||||
return False
|
||||
else:
|
||||
print(f"[轮询] 消息ID: {item.get('id')} 内容: {content[:100]}... 匹配seed并且任务已完成")
|
||||
# 检查是否包含图像附件
|
||||
if 'attachments' in item and len(item.get('attachments', [])) > 0:
|
||||
return True
|
||||
return False # 默认不匹配
|
||||
|
||||
# 查找匹配完成状态的消息
|
||||
raw_message = self.first_where(data, criteria)
|
||||
if raw_message is None:
|
||||
print("[轮询] 没有找到完成的消息,继续等待")
|
||||
return None
|
||||
|
||||
# 打印匹配到的完整消息内容
|
||||
print(f"[轮询] 匹配到完成的消息ID: {raw_message.get('id')}")
|
||||
try:
|
||||
print(f"[轮询] 完整消息内容: {json.dumps(raw_message, indent=2, ensure_ascii=False)}")
|
||||
except Exception as e:
|
||||
print(f"[轮询] 无法打印完整消息内容: {str(e)}")
|
||||
# 尝试打印关键字段
|
||||
print(f"[轮询] 消息内容: {raw_message.get('content', '无内容')}")
|
||||
print(f"[轮询] 附件数量: {len(raw_message.get('attachments', []))}")
|
||||
for i, attachment in enumerate(raw_message.get('attachments', [])):
|
||||
print(f"[轮询] 附件 {i+1} URL: {attachment.get('url', '无URL')}")
|
||||
|
||||
return raw_message
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取通道消息失败: {str(e)}")
|
||||
return None
|
||||
|
||||
async def get_progress_message(self, client, channel_id, prompt, seed=None):
|
||||
"""获取包含进度信息的消息"""
|
||||
try:
|
||||
# 获取最近的消息
|
||||
response = client.get(f'{self.API_URL}/channels/{channel_id}/messages?limit=10')
|
||||
data = response.json()
|
||||
|
||||
for item in data:
|
||||
content = item.get('content', '')
|
||||
# 检查是否包含seed和进度信息
|
||||
if seed is not None and f"--seed {seed}" in content and "%" in content:
|
||||
progress_match = re.search(r'(?:\(|()?(\d+)%(?:\)|))?', content)
|
||||
if progress_match:
|
||||
return item
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取进度消息失败: {str(e)}")
|
||||
return None
|
||||
|
||||
def extract_image_urls(self, message):
|
||||
"""从消息中提取图像URL"""
|
||||
image_urls = []
|
||||
if 'attachments' in message and message['attachments']:
|
||||
for attachment in message['attachments']:
|
||||
if 'url' in attachment:
|
||||
image_urls.append(attachment['url'])
|
||||
return image_urls
|
||||
|
||||
async def split_image(self, image_url):
|
||||
"""将一张大图切割成四张子图,保存到本地并返回URLs"""
|
||||
try:
|
||||
# 下载图像
|
||||
response = requests.get(image_url, proxies=self.proxies if self.proxies else None)
|
||||
if response.status_code != 200:
|
||||
logger.error(f"下载图像失败,状态码: {response.status_code}")
|
||||
return None
|
||||
|
||||
image_data = response.content
|
||||
|
||||
# 从二进制数据创建图像对象
|
||||
img = Image.open(io.BytesIO(image_data))
|
||||
width, height = img.size
|
||||
|
||||
# 确认图像尺寸约为2048x2048
|
||||
if width < 1500 or height < 1500:
|
||||
logger.error(f"图像尺寸不符合预期: {width}x{height}")
|
||||
return None
|
||||
|
||||
# 计算每个象限的尺寸
|
||||
half_width = width // 2
|
||||
half_height = height // 2
|
||||
|
||||
# 分割图像为四个象限
|
||||
top_left = img.crop((0, 0, half_width, half_height))
|
||||
top_right = img.crop((half_width, 0, width, half_height))
|
||||
bottom_left = img.crop((0, half_height, half_width, height))
|
||||
bottom_right = img.crop((half_width, half_height, width, height))
|
||||
|
||||
# 生成唯一的图片名称前缀
|
||||
image_id = uuid.uuid4().hex[:10]
|
||||
|
||||
# 保存图片到本地并生成URLs
|
||||
image_urls = []
|
||||
for i, quadrant in enumerate([top_left, top_right, bottom_left, bottom_right], 1):
|
||||
# 生成文件名和保存路径
|
||||
filename = f"split_{image_id}_{i}.png"
|
||||
file_path = os.path.join(settings.save_dir, filename)
|
||||
|
||||
# 保存图片
|
||||
quadrant.save(file_path, format="PNG")
|
||||
|
||||
# 构建图片URL
|
||||
image_url = f"{settings.download_url}/{filename}"
|
||||
image_urls.append(image_url)
|
||||
|
||||
return image_urls
|
||||
except Exception as e:
|
||||
logger.error(f"分割图像失败: {str(e)}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
async def process_batch(self, image_urls: List[str], config: Optional[Dict] = None):
|
||||
"""批量处理多个图像URL"""
|
||||
if not config:
|
||||
config = {}
|
||||
|
||||
total = len(image_urls)
|
||||
success_count = 0
|
||||
error_count = 0
|
||||
|
||||
for i, image_url in enumerate(image_urls, 1):
|
||||
try:
|
||||
if not is_valid_image_url(image_url):
|
||||
error_count += 1
|
||||
response = {
|
||||
"status": "error",
|
||||
"index": i,
|
||||
"total": total,
|
||||
"success_count": success_count,
|
||||
"error_count": error_count,
|
||||
"message": "无效的图片URL"
|
||||
}
|
||||
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
||||
yield response
|
||||
continue
|
||||
|
||||
split_urls = await self.split_image(image_url)
|
||||
if split_urls and len(split_urls) == 4:
|
||||
success_count += 1
|
||||
response = {
|
||||
"status": "success",
|
||||
"index": i,
|
||||
"total": total,
|
||||
"success_count": success_count,
|
||||
"error_count": error_count,
|
||||
"images": split_urls
|
||||
}
|
||||
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
||||
yield response
|
||||
else:
|
||||
error_count += 1
|
||||
response = {
|
||||
"status": "error",
|
||||
"index": i,
|
||||
"total": total,
|
||||
"success_count": success_count,
|
||||
"error_count": error_count,
|
||||
"message": "分割图片失败"
|
||||
}
|
||||
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
||||
yield response
|
||||
except Exception as e:
|
||||
error_count += 1
|
||||
response = {
|
||||
"status": "error",
|
||||
"index": i,
|
||||
"total": total,
|
||||
"success_count": success_count,
|
||||
"error_count": error_count,
|
||||
"message": str(e)
|
||||
}
|
||||
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
||||
yield response
|
||||
|
||||
async def generate_image(self, prompt, config=None):
|
||||
"""生成图像并以流式方式返回结果"""
|
||||
if not config:
|
||||
config = {}
|
||||
|
||||
# 获取必要的认证信息
|
||||
oauth_token = config.get('oauth_token', settings.midjourney_oauth_token)
|
||||
channel_id = config.get('channel_id', settings.midjourney_channel_id)
|
||||
|
||||
if not oauth_token or not channel_id:
|
||||
response = {
|
||||
"status": "error",
|
||||
"message": "缺少Discord配置",
|
||||
"success_count": 0,
|
||||
"error_count": 1
|
||||
}
|
||||
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
||||
yield response
|
||||
return
|
||||
|
||||
# 初始化客户端
|
||||
try:
|
||||
client, guild_id, user_id = await self.initialize_client(oauth_token, channel_id)
|
||||
except Exception as e:
|
||||
response = {
|
||||
"status": "error",
|
||||
"message": str(e),
|
||||
"success_count": 0,
|
||||
"error_count": 1
|
||||
}
|
||||
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
||||
yield response
|
||||
return
|
||||
|
||||
# 解析和准备提示词
|
||||
prompt = prompt.strip()
|
||||
|
||||
# 获取或设置默认选项
|
||||
options = settings.midjourney_default_options.copy()
|
||||
if 'options' in config:
|
||||
options.update(config.get('options', {}))
|
||||
|
||||
# 强制设置seed如果没有
|
||||
if "seed" not in options or not options.get('seed'):
|
||||
options['seed'] = random.randint(0, 4294967295)
|
||||
|
||||
# 处理选项,构建参数字符串
|
||||
seed = options.pop('seed', None)
|
||||
parameter = ""
|
||||
no_value_key = ['relax', 'fast', 'turbo', 'tile']
|
||||
for key, value in options.items():
|
||||
if key in no_value_key:
|
||||
parameter += f" --{key}"
|
||||
else:
|
||||
parameter += f" --{key} {value}"
|
||||
|
||||
# 添加seed
|
||||
if seed:
|
||||
parameter += f" --seed {seed}"
|
||||
# 打印使用的seed值
|
||||
print(f"[生成] 使用的seed值: {seed}")
|
||||
|
||||
# 处理参考图像
|
||||
if 'reference_images' in config and config['reference_images']:
|
||||
# 确保是列表格式
|
||||
reference_images = config['reference_images']
|
||||
if not isinstance(reference_images, list):
|
||||
reference_images = [reference_images]
|
||||
|
||||
# 转换图片URL
|
||||
image_urls = []
|
||||
for image_url in reference_images:
|
||||
if is_valid_image_url(image_url):
|
||||
try:
|
||||
new_url = get_new_image_url(image_url)
|
||||
if new_url:
|
||||
image_urls.append(new_url)
|
||||
except Exception as e:
|
||||
logger.warning(f"转换图片URL失败: {str(e)}")
|
||||
|
||||
# 添加到prompt前面
|
||||
if image_urls:
|
||||
prompt = " ".join(image_urls) + " " + prompt
|
||||
|
||||
# 设置图像权重
|
||||
if 'image_weight' in config and isinstance(config['image_weight'], (int, float)):
|
||||
iw = max(0.1, min(config['image_weight'], 3)) # 限制在0.1到3之间
|
||||
parameter += f" --iw {iw}"
|
||||
|
||||
# 添加字符引用
|
||||
if 'characters' in config and config['characters']:
|
||||
char_urls = []
|
||||
characters = config['characters'] if isinstance(config['characters'], list) else [config['characters']]
|
||||
|
||||
for char_url in characters:
|
||||
if is_valid_image_url(char_url):
|
||||
try:
|
||||
new_url = get_new_image_url(char_url)
|
||||
if new_url:
|
||||
char_urls.append(new_url)
|
||||
except Exception as e:
|
||||
logger.warning(f"转换角色图片URL失败: {str(e)}")
|
||||
|
||||
if char_urls:
|
||||
prompt = prompt + " --cref " + " ".join(char_urls)
|
||||
|
||||
# 添加风格引用
|
||||
if 'styles' in config and config['styles']:
|
||||
style_urls = []
|
||||
styles = config['styles'] if isinstance(config['styles'], list) else [config['styles']]
|
||||
|
||||
for style_url in styles:
|
||||
if is_valid_image_url(style_url):
|
||||
try:
|
||||
new_url = get_new_image_url(style_url)
|
||||
if new_url:
|
||||
style_urls.append(new_url)
|
||||
except Exception as e:
|
||||
logger.warning(f"转换风格图片URL失败: {str(e)}")
|
||||
|
||||
if style_urls:
|
||||
prompt = prompt + " --sref " + " ".join(style_urls)
|
||||
|
||||
# 添加参数
|
||||
prompt = prompt + parameter
|
||||
|
||||
# 流式返回结果
|
||||
success_count = 0
|
||||
error_count = 0
|
||||
|
||||
async for result in self.imagine(client, guild_id, channel_id, prompt, seed):
|
||||
if result.get("status") == "progress":
|
||||
# 进度信息保持简单
|
||||
response = {
|
||||
"status": "progress",
|
||||
"progress": result.get("progress", 0),
|
||||
"success_count": success_count,
|
||||
"error_count": error_count
|
||||
}
|
||||
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
||||
yield response
|
||||
elif result.get("status") == "success":
|
||||
success_count += 1
|
||||
response = {
|
||||
"status": "success",
|
||||
"success_count": success_count,
|
||||
"error_count": error_count
|
||||
}
|
||||
|
||||
# 如果需要分割图片
|
||||
if config.get("split_image", False) and result.get("images"):
|
||||
try:
|
||||
orig_image_url = result["images"][0]
|
||||
split_urls = await self.split_image(orig_image_url)
|
||||
if split_urls:
|
||||
response["images"] = split_urls
|
||||
except Exception as e:
|
||||
error_count += 1
|
||||
logger.error(f"分割图像失败: {str(e)}")
|
||||
response["error_count"] = error_count
|
||||
|
||||
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
||||
yield response
|
||||
else:
|
||||
# 错误信息保持简单
|
||||
error_count += 1
|
||||
response = {
|
||||
"status": "error",
|
||||
"message": result.get("message", "未知错误"),
|
||||
"success_count": success_count,
|
||||
"error_count": error_count
|
||||
}
|
||||
print(f"[服务端响应] {json.dumps(response, ensure_ascii=False)}")
|
||||
yield response
|
||||
66
apps/midjourney/settings.py
Normal file
66
apps/midjourney/settings.py
Normal file
@ -0,0 +1,66 @@
|
||||
from pydantic_settings import BaseSettings
|
||||
from typing import Optional, Dict
|
||||
from functools import lru_cache
|
||||
|
||||
class Settings(BaseSettings):
|
||||
# Japi Server 配置
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8113
|
||||
debug: bool = True
|
||||
|
||||
# API路由配置
|
||||
router_prefix: str = "/midjourney"
|
||||
generate_route: str = "/generate" # 生成图片的路由
|
||||
batch_route: str = "/batch" # 批量处理图片的路由
|
||||
api_name: str = "midjourney" # 默认API名称
|
||||
|
||||
upload_url: str = "http://images.jingrow.com:8080/api/v1/image"
|
||||
|
||||
# 图片保存配置
|
||||
save_dir: str = "../jfile/midjourney"
|
||||
# Japi 静态资源下载URL
|
||||
download_url: str = "http://api.jingrow.com:9080/midjourney"
|
||||
|
||||
# Jingrow Jcloud API 配置
|
||||
jingrow_api_url: str = "https://cloud.jingrow.com"
|
||||
jingrow_api_key: Optional[str] = None
|
||||
jingrow_api_secret: Optional[str] = None
|
||||
|
||||
# Discord Midjourney配置
|
||||
midjourney_api_url: str = "https://discord.com/api/v9"
|
||||
midjourney_application_id: str = "936929561302675456"
|
||||
midjourney_data_id: str = "938956540159881230"
|
||||
midjourney_data_version: str = "1237876415471554623"
|
||||
midjourney_session_id: str = "a64ede0f3ce497d949e2f6f195c19029"
|
||||
midjourney_channel_id: str = "1259838588510670941"
|
||||
midjourney_oauth_token: str = "MTA4NzQ0MDY0MTU5MzcxNjc0Ng.GVDauj.6Cwr5EpXOfN9FpQU0-VfteR56XQOwLLUGYovG0"
|
||||
midjourney_suffix: str = "mj" # 图片文件名的后缀
|
||||
|
||||
# 代理配置
|
||||
http_proxy: Optional[str] = "http://127.0.0.1:1080" # 默认HTTP代理
|
||||
https_proxy: Optional[str] = "http://127.0.0.1:1080" # 默认HTTPS代理
|
||||
|
||||
# Midjourney默认选项
|
||||
midjourney_default_options: Dict = {
|
||||
"ar": "1:1",
|
||||
"v": "6.1",
|
||||
"quality": "1"
|
||||
}
|
||||
|
||||
# 图像设置
|
||||
add_title_to_image_name: bool = False # 是否将标题添加到图片名称中
|
||||
|
||||
# 超时设置(秒)
|
||||
request_timeout: int = 30
|
||||
max_polling_attempts: int = 60
|
||||
polling_interval: int = 3
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
||||
@lru_cache()
|
||||
def get_settings() -> Settings:
|
||||
return Settings()
|
||||
|
||||
# 创建全局配置实例
|
||||
settings = get_settings()
|
||||
316
apps/midjourney/utils.py
Normal file
316
apps/midjourney/utils.py
Normal file
@ -0,0 +1,316 @@
|
||||
import aiohttp
|
||||
from functools import wraps
|
||||
from fastapi import HTTPException
|
||||
import os
|
||||
from typing import Callable, Any, Dict, Optional, Tuple, List
|
||||
from settings import settings
|
||||
from fastapi.responses import StreamingResponse
|
||||
import json
|
||||
import requests
|
||||
import io
|
||||
import re
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
from PIL import Image
|
||||
|
||||
async def verify_api_credentials_and_balance(api_key: str, api_secret: str, api_name: str) -> Dict[str, Any]:
|
||||
"""验证API密钥和团队余额"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{settings.jingrow_api_url}/api/method/jcloud.api.account.verify_api_credentials_and_balance",
|
||||
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||
json={"api_key": api_key, "api_secret": api_secret, "api_name": api_name}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise HTTPException(status_code=500, detail="验证服务暂时不可用")
|
||||
|
||||
result = await response.json()
|
||||
if "message" in result and isinstance(result["message"], dict):
|
||||
result = result["message"]
|
||||
|
||||
if not result.get("success"):
|
||||
raise HTTPException(status_code=401, detail=result.get("message", "验证失败"))
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"验证服务暂时不可用: {str(e)}")
|
||||
|
||||
async def deduct_jingrow_api_usage_fee(api_key: str, api_secret: str, api_name: str, usage_count: int = 1) -> Dict[str, Any]:
|
||||
"""从Jingrow平台扣除API使用费"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{settings.jingrow_api_url}/api/method/jcloud.api.account.deduct_api_usage_fee",
|
||||
headers={"Authorization": f"token {settings.jingrow_api_key}:{settings.jingrow_api_secret}"},
|
||||
json={
|
||||
"api_key": api_key,
|
||||
"api_secret": api_secret,
|
||||
"api_name": api_name,
|
||||
"usage_count": usage_count
|
||||
}
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
raise HTTPException(status_code=500, detail="扣费服务暂时不可用")
|
||||
|
||||
result = await response.json()
|
||||
if "message" in result and isinstance(result["message"], dict):
|
||||
result = result["message"]
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
return {"success": False, "message": f"扣费服务暂时不可用: {str(e)}"}
|
||||
|
||||
def get_token_from_request(request) -> str:
|
||||
"""从请求中获取访问令牌"""
|
||||
if not request:
|
||||
raise HTTPException(status_code=400, detail="无法获取请求信息")
|
||||
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header or not auth_header.startswith("token "):
|
||||
raise HTTPException(status_code=401, detail="无效的Authorization头格式")
|
||||
|
||||
token = auth_header[6:]
|
||||
if ":" not in token:
|
||||
raise HTTPException(status_code=401, detail="无效的令牌格式")
|
||||
|
||||
return token
|
||||
|
||||
def jingrow_api_verify_and_billing(api_name: str):
|
||||
"""Jingrow API 验证装饰器(带余额检查和扣费)"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
try:
|
||||
request = kwargs.get('request')
|
||||
if not request:
|
||||
raise HTTPException(status_code=400, detail="无法获取请求信息")
|
||||
|
||||
token = get_token_from_request(request)
|
||||
api_key, api_secret = token.split(":", 1)
|
||||
|
||||
verify_result = await verify_api_credentials_and_balance(api_key, api_secret, api_name)
|
||||
if not verify_result.get("success"):
|
||||
raise HTTPException(status_code=401, detail=verify_result.get("message", "验证失败"))
|
||||
|
||||
result = await func(*args, **kwargs)
|
||||
|
||||
usage_count = 1
|
||||
try:
|
||||
body_data = await request.json()
|
||||
if isinstance(body_data, dict):
|
||||
for key in ["items", "urls", "images", "files"]:
|
||||
if key in body_data and isinstance(body_data[key], list):
|
||||
usage_count = len(body_data[key])
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if isinstance(result, StreamingResponse):
|
||||
original_generator = result.body_iterator
|
||||
success_count = 0
|
||||
|
||||
async def wrapped_generator():
|
||||
nonlocal success_count
|
||||
async for chunk in original_generator:
|
||||
try:
|
||||
data = json.loads(chunk)
|
||||
if isinstance(data, dict) and data.get("status") == "success":
|
||||
success_count += 1
|
||||
except:
|
||||
pass
|
||||
yield chunk
|
||||
|
||||
if success_count > 0:
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, success_count)
|
||||
|
||||
return StreamingResponse(
|
||||
wrapped_generator(),
|
||||
media_type=result.media_type,
|
||||
headers=result.headers
|
||||
)
|
||||
|
||||
if isinstance(result, dict) and result.get("success") is True:
|
||||
actual_usage_count = result.get("successful_count", usage_count)
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, actual_usage_count)
|
||||
return result
|
||||
|
||||
await deduct_jingrow_api_usage_fee(api_key, api_secret, api_name, usage_count)
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"API验证过程发生异常: {str(e)}")
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
def is_valid_image_url(url: str) -> bool:
|
||||
"""验证图片URL是否有效
|
||||
|
||||
Args:
|
||||
url: 要验证的URL
|
||||
|
||||
Returns:
|
||||
bool: URL是否有效
|
||||
"""
|
||||
if not url or not isinstance(url, str):
|
||||
return False
|
||||
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
if not parsed.scheme or not parsed.netloc:
|
||||
return False
|
||||
|
||||
# 检查文件扩展名
|
||||
path = parsed.path.lower()
|
||||
valid_extensions = ['.jpg', '.jpeg', '.png', '.webp', '.gif']
|
||||
return any(path.endswith(ext) for ext in valid_extensions)
|
||||
except:
|
||||
return False
|
||||
|
||||
def validate_image_file(file_path: str) -> bool:
|
||||
"""验证图片文件是否有效
|
||||
|
||||
Args:
|
||||
file_path: 图片文件路径
|
||||
|
||||
Returns:
|
||||
bool: 文件是否有效
|
||||
"""
|
||||
try:
|
||||
with Image.open(file_path) as img:
|
||||
img.verify()
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
def get_image_size(image_url: str) -> Optional[Tuple[int, int]]:
|
||||
"""获取图片尺寸
|
||||
|
||||
Args:
|
||||
image_url: 图片URL
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[int, int]]: 图片尺寸(宽,高),如果获取失败则返回None
|
||||
"""
|
||||
try:
|
||||
response = requests.get(image_url, verify=False, timeout=10)
|
||||
if response.status_code != 200:
|
||||
return None
|
||||
|
||||
with Image.open(io.BytesIO(response.content)) as img:
|
||||
return img.size
|
||||
except:
|
||||
return None
|
||||
|
||||
def is_valid_image_size(image_url: str, min_size: int = 512) -> bool:
|
||||
"""验证图片尺寸是否满足最小要求
|
||||
|
||||
Args:
|
||||
image_url: 图片URL
|
||||
min_size: 最小尺寸要求
|
||||
|
||||
Returns:
|
||||
bool: 图片尺寸是否满足要求
|
||||
"""
|
||||
size = get_image_size(image_url)
|
||||
if not size:
|
||||
return False
|
||||
width, height = size
|
||||
return width >= min_size and height >= min_size
|
||||
|
||||
def extract_image_urls_from_text(text: str) -> List[str]:
|
||||
"""从文本中提取图片URL
|
||||
|
||||
Args:
|
||||
text: 包含图片URL的文本
|
||||
|
||||
Returns:
|
||||
List[str]: 提取到的图片URL列表
|
||||
"""
|
||||
# 匹配常见的图片URL模式
|
||||
url_pattern = r'https?://[^\s<>"]+?\.(?:jpg|jpeg|png|webp|gif)(?:\?[^\s<>"]*)?'
|
||||
urls = re.findall(url_pattern, text, re.IGNORECASE)
|
||||
return [url for url in urls if is_valid_image_url(url)]
|
||||
|
||||
def sanitize_filename(filename: str) -> str:
|
||||
"""清理文件名,移除非法字符
|
||||
|
||||
Args:
|
||||
filename: 原始文件名
|
||||
|
||||
Returns:
|
||||
str: 清理后的文件名
|
||||
"""
|
||||
# 移除非法字符
|
||||
filename = re.sub(r'[<>:"/\\|?*]', '', filename)
|
||||
# 限制长度
|
||||
if len(filename) > 255:
|
||||
name, ext = os.path.splitext(filename)
|
||||
filename = name[:255-len(ext)] + ext
|
||||
return filename
|
||||
|
||||
def get_new_image_url(image_url: str) -> str:
|
||||
"""将图片URL转换为新的存储URL
|
||||
|
||||
Args:
|
||||
image_url: 原始图片URL
|
||||
|
||||
Returns:
|
||||
str: 新的图片URL
|
||||
|
||||
Raises:
|
||||
HTTPException: 当图片处理失败时抛出
|
||||
"""
|
||||
try:
|
||||
# 使用settings中的upload_url
|
||||
upload_url = settings.upload_url
|
||||
if not upload_url:
|
||||
raise HTTPException(status_code=500, detail="未配置上传URL")
|
||||
|
||||
# 下载图片
|
||||
response = requests.get(image_url, verify=False, timeout=30)
|
||||
if response.status_code != 200:
|
||||
raise HTTPException(status_code=400, detail=f"无法下载图片: HTTP {response.status_code}")
|
||||
image_data = response.content
|
||||
|
||||
# 解析文件名和扩展名
|
||||
parsed_url = urlparse(image_url)
|
||||
file_name = Path(parsed_url.path).name
|
||||
file_name = sanitize_filename(file_name)
|
||||
file_ext = Path(file_name).suffix.lower()
|
||||
|
||||
# 如果图片是webp格式,转换为png格式
|
||||
if file_ext == '.webp':
|
||||
image = Image.open(io.BytesIO(image_data))
|
||||
png_buffer = io.BytesIO()
|
||||
image.save(png_buffer, format='PNG')
|
||||
image_data = png_buffer.getvalue()
|
||||
file_name = file_name.replace('.webp', '.png')
|
||||
|
||||
# 准备文件上传
|
||||
files = {"file": (file_name, image_data)}
|
||||
|
||||
# 上传图片
|
||||
upload_response = requests.post(upload_url, files=files, verify=False, timeout=30)
|
||||
if upload_response.status_code != 200:
|
||||
error_message = f"图片URL转换失败: 状态码 {upload_response.status_code}, 响应: {upload_response.text[:200]}"
|
||||
raise HTTPException(status_code=500, detail=error_message)
|
||||
|
||||
result = upload_response.json()
|
||||
new_url = result.get("url")
|
||||
if not new_url:
|
||||
raise HTTPException(status_code=500, detail="上传成功但未返回URL")
|
||||
return new_url
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"图片URL转换异常: {str(e)}")
|
||||
Loading…
x
Reference in New Issue
Block a user