主要改进: 1. 速率限制处理: - 添加 RateLimitError 异常类,用于标识速率限制错误 - 在 ssl_manager.py 中检测 Let's Encrypt 速率限制错误 - 解析重试时间,提供详细的错误提示 - 在 route_watcher.py 中记录被限制的域名和重试时间 - 自动跳过限制期间的域名,避免持续触发限制 - 限制解除后自动恢复申请 2. 代码优化: - 修复重复导入 sys 的问题 - 修复 API 调用未使用 session 连接复用的问题 - 移除未使用的 _get_apisix_headers 方法 - 将 RateLimitError 导入移到文件顶部 优势: - 避免持续触发速率限制,形成死循环 - 自动等待限制解除,无需手动干预 - 提升代码质量和可维护性 - 充分利用 HTTP 连接复用,提升性能
305 lines
12 KiB
Python
Executable File
305 lines
12 KiB
Python
Executable File
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
APISIX 路由监听服务
|
||
监听路由创建事件,自动为域名申请 SSL 证书
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import json
|
||
import time
|
||
import logging
|
||
import requests
|
||
import ipaddress
|
||
from typing import Set, Optional, Dict
|
||
from datetime import datetime
|
||
|
||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||
from ssl_manager import APISIXSSLManager, RateLimitError
|
||
|
||
# 配置日志
|
||
logging.basicConfig(
|
||
level=logging.INFO,
|
||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||
handlers=[
|
||
logging.FileHandler('/var/log/apisix-route-watcher.log'),
|
||
logging.StreamHandler(sys.stdout)
|
||
]
|
||
)
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class RouteWatcher:
|
||
"""路由监听器"""
|
||
|
||
def __init__(self, config_path: str = None):
|
||
"""初始化监听器"""
|
||
self.ssl_manager = APISIXSSLManager(config_path)
|
||
|
||
# 从环境变量或配置获取 APISIX 配置
|
||
self.apisix_admin_url = os.getenv('APISIX_ADMIN_URL', 'http://localhost:9180')
|
||
self.apisix_admin_key = os.getenv('APISIX_ADMIN_KEY', '8206e6e42b6b53243c52a767cc633137')
|
||
|
||
# 创建 HTTP 会话,复用连接
|
||
self.session = requests.Session()
|
||
self.session.headers.update({
|
||
'X-API-KEY': self.apisix_admin_key,
|
||
'Content-Type': 'application/json'
|
||
})
|
||
|
||
# 速率限制记录:{domain: retry_after_timestamp}
|
||
self.rate_limited_domains: Dict[str, float] = {}
|
||
|
||
def _get_apisix_headers(self):
|
||
"""获取 APISIX Admin API 请求头"""
|
||
return {
|
||
'X-API-KEY': self.apisix_admin_key,
|
||
'Content-Type': 'application/json'
|
||
}
|
||
|
||
def get_all_routes(self) -> list:
|
||
"""获取所有路由"""
|
||
try:
|
||
response = self.session.get(
|
||
f"{self.apisix_admin_url}/apisix/admin/routes",
|
||
timeout=10
|
||
)
|
||
|
||
if response.status_code == 200:
|
||
data = response.json()
|
||
return data.get('list', [])
|
||
else:
|
||
logger.error(f"获取路由失败: {response.status_code}")
|
||
return []
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取路由异常: {e}")
|
||
return []
|
||
|
||
def get_all_ssls(self) -> list:
|
||
"""获取所有 SSL 配置"""
|
||
try:
|
||
response = self.session.get(
|
||
f"{self.apisix_admin_url}/apisix/admin/ssls",
|
||
timeout=10
|
||
)
|
||
|
||
if response.status_code == 200:
|
||
data = response.json()
|
||
return data.get('list', [])
|
||
else:
|
||
logger.error(f"获取 SSL 配置失败: {response.status_code}")
|
||
return []
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取 SSL 配置异常: {e}")
|
||
return []
|
||
|
||
def extract_domains_from_route(self, route: dict) -> Set[str]:
|
||
"""从路由中提取域名"""
|
||
domains = set()
|
||
route_value = route.get('value', {})
|
||
|
||
# 从 host 字段提取(单个域名,Dashboard 常用)
|
||
host = route_value.get('host')
|
||
if host and isinstance(host, str):
|
||
domains.add(host)
|
||
|
||
# 从 hosts 字段提取(域名数组)
|
||
hosts = route_value.get('hosts', [])
|
||
if hosts:
|
||
if isinstance(hosts, list):
|
||
domains.update(hosts)
|
||
elif isinstance(hosts, str):
|
||
domains.add(hosts)
|
||
|
||
# 从 uri 字段提取(如果包含域名)
|
||
uri = route_value.get('uri', '')
|
||
if uri and '.' in uri and not uri.startswith('/'):
|
||
# 可能是域名格式
|
||
parts = uri.split('/')
|
||
if parts[0] and '.' in parts[0]:
|
||
domains.add(parts[0])
|
||
|
||
# 从 match 字段提取
|
||
match = route_value.get('match', {})
|
||
if isinstance(match, dict):
|
||
for key, value in match.items():
|
||
if 'host' in key.lower() and isinstance(value, str):
|
||
domains.add(value)
|
||
|
||
return domains
|
||
|
||
def extract_domains_from_ssl(self, ssl: dict) -> Set[str]:
|
||
"""从 SSL 配置中提取域名"""
|
||
domains = set()
|
||
snis = ssl.get('value', {}).get('snis', [])
|
||
if snis:
|
||
domains.update(snis)
|
||
return domains
|
||
|
||
def _is_valid_domain(self, domain: str) -> bool:
|
||
"""检查是否为有效域名(非 IP 地址和本地域名)"""
|
||
# 跳过本地域名
|
||
if domain in ['localhost', '127.0.0.1', '0.0.0.0']:
|
||
return False
|
||
|
||
# 检查是否为 IP 地址(支持 IPv4 和 IPv6)
|
||
try:
|
||
ipaddress.ip_address(domain)
|
||
return False
|
||
except ValueError:
|
||
pass
|
||
|
||
return True
|
||
|
||
def _build_ssl_domains_set(self, ssls: list) -> Set[str]:
|
||
"""构建所有已配置 SSL 的域名集合(用于快速查找)"""
|
||
ssl_domains = set()
|
||
for ssl in ssls:
|
||
domains = self.extract_domains_from_ssl(ssl)
|
||
ssl_domains.update(domains)
|
||
return ssl_domains
|
||
|
||
def should_request_cert(self, domain: str, existing_ssl_domains: Set[str]) -> bool:
|
||
"""判断是否需要申请证书
|
||
|
||
Args:
|
||
domain: 要检查的域名
|
||
existing_ssl_domains: 已存在的 SSL 域名集合(用于快速查找)
|
||
"""
|
||
# 检查是否为有效域名
|
||
if not self._is_valid_domain(domain):
|
||
return False
|
||
|
||
# 检查是否已有 SSL 配置
|
||
if domain in existing_ssl_domains:
|
||
logger.info(f"域名已有 SSL 配置: {domain}")
|
||
return False
|
||
|
||
# 检查是否在速率限制期间
|
||
if domain in self.rate_limited_domains:
|
||
retry_after = self.rate_limited_domains[domain]
|
||
current_time = time.time()
|
||
if current_time < retry_after:
|
||
remaining_minutes = int((retry_after - current_time) / 60) + 1
|
||
logger.debug(f"域名 {domain} 仍在速率限制期间,剩余约 {remaining_minutes} 分钟后重试")
|
||
return False
|
||
else:
|
||
# 限制已解除,移除记录
|
||
del self.rate_limited_domains[domain]
|
||
logger.info(f"域名 {domain} 速率限制已解除,将重新尝试申请证书")
|
||
|
||
return True
|
||
|
||
def process_new_domains(self):
|
||
"""处理新域名"""
|
||
# 优化:只获取一次路由和 SSL 配置,避免重复 API 调用
|
||
routes = self.get_all_routes()
|
||
ssls = self.get_all_ssls()
|
||
|
||
# 构建已存在的 SSL 域名集合(用于快速查找)
|
||
existing_ssl_domains = self._build_ssl_domains_set(ssls)
|
||
|
||
# 按路由处理,同一路由的多个域名合并到一个证书
|
||
for route in routes:
|
||
route_value = route.get('value', {})
|
||
# 跳过禁用的路由(status=0 表示禁用)
|
||
if route_value.get('status') == 0:
|
||
continue
|
||
|
||
# 从路由中提取所有域名
|
||
domains = self.extract_domains_from_route(route)
|
||
if not domains:
|
||
continue
|
||
|
||
# 过滤出需要申请证书的域名(使用缓存的 SSL 域名集合)
|
||
domains_to_request = [d for d in domains if self.should_request_cert(d, existing_ssl_domains)]
|
||
|
||
if not domains_to_request:
|
||
continue
|
||
|
||
# 如果只有一个域名,单独申请
|
||
if len(domains_to_request) == 1:
|
||
domain = domains_to_request[0]
|
||
logger.info(f"发现新域名,准备申请证书: {domain}")
|
||
try:
|
||
if self.ssl_manager.request_certificate(domain):
|
||
logger.info(f"证书申请成功: {domain}")
|
||
# 申请成功,清除速率限制记录(如果存在)
|
||
self.rate_limited_domains.pop(domain, None)
|
||
else:
|
||
logger.error(f"证书申请失败: {domain}")
|
||
except Exception as e:
|
||
# 检查是否是速率限制错误
|
||
if isinstance(e, RateLimitError):
|
||
logger.warning(f"域名 {e.domain} 遇到速率限制,将在 {datetime.fromtimestamp(e.retry_after_timestamp).strftime('%Y-%m-%d %H:%M:%S')} 后自动重试")
|
||
# 记录速率限制的域名和重试时间
|
||
self.rate_limited_domains[e.domain] = e.retry_after_timestamp
|
||
else:
|
||
logger.error(f"处理域名异常 {domain}: {e}")
|
||
else:
|
||
# 多个域名,合并到一个证书申请(使用 SAN)
|
||
primary_domain = domains_to_request[0]
|
||
additional_domains = domains_to_request[1:]
|
||
total_domains = len(additional_domains) + 1
|
||
logger.info(f"发现同一路由中的多个域名,合并申请证书: {primary_domain} + {additional_domains}")
|
||
try:
|
||
if self.ssl_manager.request_certificate(primary_domain, additional_domains):
|
||
logger.info(f"证书申请成功: {primary_domain} (包含 {total_domains} 个域名)")
|
||
# 申请成功,清除速率限制记录
|
||
for d in [primary_domain] + additional_domains:
|
||
self.rate_limited_domains.pop(d, None)
|
||
else:
|
||
logger.error(f"证书申请失败: {primary_domain} + {additional_domains}")
|
||
except Exception as e:
|
||
# 检查是否是速率限制错误
|
||
if isinstance(e, RateLimitError):
|
||
logger.warning(f"域名 {e.domain} 遇到速率限制,将在 {datetime.fromtimestamp(e.retry_after_timestamp).strftime('%Y-%m-%d %H:%M:%S')} 后自动重试")
|
||
# 记录速率限制的域名和重试时间(所有相关域名)
|
||
for d in [primary_domain] + additional_domains:
|
||
self.rate_limited_domains[d] = e.retry_after_timestamp
|
||
else:
|
||
logger.error(f"处理域名异常 {primary_domain} + {additional_domains}: {e}")
|
||
|
||
def run(self, interval: int = 60):
|
||
"""运行监听服务"""
|
||
logger.info(f"路由监听服务启动,检查间隔: {interval} 秒")
|
||
|
||
while True:
|
||
try:
|
||
self.process_new_domains()
|
||
time.sleep(interval)
|
||
except KeyboardInterrupt:
|
||
logger.info("收到停止信号,退出服务")
|
||
break
|
||
except Exception as e:
|
||
logger.error(f"监听服务异常: {e}")
|
||
time.sleep(interval)
|
||
|
||
|
||
def main():
|
||
"""主函数"""
|
||
import argparse
|
||
|
||
parser = argparse.ArgumentParser(description='APISIX 路由监听服务')
|
||
parser.add_argument('--interval', '-i', type=int, default=60,
|
||
help='检查间隔(秒),默认 60')
|
||
parser.add_argument('--config', '-c', help='配置文件路径(可选,用于覆盖默认配置)')
|
||
parser.add_argument('--once', action='store_true',
|
||
help='只执行一次,不持续监听')
|
||
|
||
args = parser.parse_args()
|
||
|
||
watcher = RouteWatcher(args.config)
|
||
|
||
if args.once:
|
||
watcher.process_new_domains()
|
||
else:
|
||
watcher.run(args.interval)
|
||
|
||
|
||
if __name__ == '__main__':
|
||
main()
|