问题:
- 限制解除后立即重试,可能再次触发速率限制
- 没有缓冲时间,容易导致连续失败

修复:
- 限制解除后,增加15分钟缓冲期
- 缓冲期内不申请证书,避免立即重试
- 优化检查间隔计算,在缓冲期结束后才重试

实现:
- should_request_cert: 检查缓冲期,缓冲期内返回 False
- run: 计算重试时间时加上15分钟缓冲期
- 实际重试时间 = 限制解除时间 + 15分钟

优势:
- 避免立即重试再次触发限制
- 给 Let's Encrypt 足够的恢复时间
- 提高证书申请成功率
335 lines
13 KiB
Python
Executable File
335 lines
13 KiB
Python
Executable File
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
APISIX 路由监听服务
|
||
监听路由创建事件,自动为域名申请 SSL 证书
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import time
|
||
import logging
|
||
import requests
|
||
import ipaddress
|
||
from typing import Set, Optional, Dict, List
|
||
from datetime import datetime
|
||
|
||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||
from ssl_manager import APISIXSSLManager, RateLimitError
|
||
|
||
# Logging setup: always log to stdout, and additionally to a log file when
# that file can be opened.
LOG_FILE_PATH = '/var/log/apisix-route-watcher.log'


def build_log_handlers(log_file: str = LOG_FILE_PATH) -> list:
    """Return the logging handlers to install on the root logger.

    A stdout handler is always included. A file handler for ``log_file`` is
    placed first when the file can be opened; if opening fails (missing
    directory, insufficient permissions, ...) a warning is written to stderr
    and only the stdout handler is returned, so importing this module never
    fails because of the log file.

    Args:
        log_file: Path of the log file to append to.

    Returns:
        List of configured ``logging.Handler`` objects.
    """
    handlers = [logging.StreamHandler(sys.stdout)]
    try:
        handlers.insert(0, logging.FileHandler(log_file))
    except OSError as exc:
        # Logging is not configured yet, so report the problem on stderr.
        sys.stderr.write(f"警告: 无法打开日志文件 {log_file} ({exc}),仅输出到标准输出\n")
    return handlers


logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=build_log_handlers(),
)
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class RouteWatcher:
    """Poll APISIX routes and request SSL certificates for uncovered domains.

    Each cycle lists the routes and SSL objects through the APISIX Admin API,
    collects the host names referenced by enabled routes, and asks the SSL
    manager for a certificate for every host that has no matching SSL object
    and is not currently held back by a rate limit.
    """

    # Extra wait (seconds) added after a rate limit expires before retrying,
    # so the first retry does not immediately trigger the limit again.
    RATE_LIMIT_BUFFER_SECONDS = 15 * 60

    def __init__(self, config_path: Optional[str] = None):
        """Create the SSL manager and the Admin API HTTP session.

        Args:
            config_path: Optional configuration path passed to the SSL manager.
        """
        self.ssl_manager = APISIXSSLManager(config_path)

        # APISIX Admin API endpoint and key come from the environment.
        self.apisix_admin_url = os.getenv('APISIX_ADMIN_URL', 'http://localhost:9180')
        # NOTE(review): the fallback admin key is hard-coded in source. It is
        # kept for backward compatibility; prefer setting APISIX_ADMIN_KEY and
        # consider removing this default.
        self.apisix_admin_key = os.getenv('APISIX_ADMIN_KEY', '8206e6e42b6b53243c52a767cc633137')

        # One session for all Admin API calls so connections are reused.
        self.session = requests.Session()
        self.session.headers.update({
            'X-API-KEY': self.apisix_admin_key,
            'Content-Type': 'application/json',
        })

        # Rate-limit bookkeeping: {domain: retry_after_timestamp}.
        self.rate_limited_domains: Dict[str, float] = {}

    def _fetch_admin_list(self, resource: str, label: str) -> Optional[list]:
        """Fetch the ``list`` member of an Admin API collection.

        Args:
            resource: Collection name appended to ``/apisix/admin/``.
            label: Text inserted into the error log messages.

        Returns:
            The list of items, or ``None`` when the request failed. ``None``
            lets callers tell "request failed" apart from "no items".
        """
        try:
            response = self.session.get(
                f"{self.apisix_admin_url}/apisix/admin/{resource}",
                timeout=10,
            )
            if response.status_code != 200:
                logger.error(f"获取{label}失败: {response.status_code}")
                return None
            items = response.json().get('list')
            # Anything that is not a list (missing, null, ...) means no items.
            return items if isinstance(items, list) else []
        except Exception as e:
            logger.error(f"获取{label}异常: {e}")
            return None

    def get_all_routes(self) -> list:
        """Return all routes, or an empty list when they cannot be fetched."""
        return self._fetch_admin_list('routes', '路由') or []

    def get_all_ssls(self) -> list:
        """Return all SSL objects, or an empty list when they cannot be fetched."""
        return self._fetch_admin_list('ssls', ' SSL 配置') or []

    def extract_domains_from_route(self, route: dict) -> Set[str]:
        """Collect the host names referenced by a route item.

        Looks at ``host``, ``hosts``, a ``uri`` that starts with a dotted
        name instead of ``/``, and string values of ``match`` entries whose
        key contains ``host``. Empty and non-string entries are ignored.

        Args:
            route: Route item as returned by the Admin API list call.

        Returns:
            Set of host names found in the route.
        """
        route_value = route.get('value') or {}
        domains: Set[str] = set()

        # Single host field.
        host = route_value.get('host')
        if host and isinstance(host, str):
            domains.add(host)

        # Host array; a bare string is accepted as well.
        hosts = route_value.get('hosts')
        if isinstance(hosts, str):
            if hosts:
                domains.add(hosts)
        elif isinstance(hosts, (list, tuple)):
            domains.update(h for h in hosts if isinstance(h, str) and h)

        # A uri that does not start with '/' may begin with a host name.
        uri = route_value.get('uri', '')
        if isinstance(uri, str) and '.' in uri and not uri.startswith('/'):
            leading = uri.split('/', 1)[0]
            if '.' in leading:
                domains.add(leading)

        # Host-like entries of the match field.
        match = route_value.get('match', {})
        if isinstance(match, dict):
            for key, value in match.items():
                if isinstance(key, str) and 'host' in key.lower() and isinstance(value, str) and value:
                    domains.add(value)

        return domains

    def extract_domains_from_ssl(self, ssl: dict) -> Set[str]:
        """Collect the SNI names of an SSL item.

        Reads both the ``snis`` array and the singular ``sni`` string. A
        string ``snis`` is treated as one name rather than being split into
        characters.

        Args:
            ssl: SSL item as returned by the Admin API list call.

        Returns:
            Set of SNI names configured on the SSL object.
        """
        value = ssl.get('value') or {}
        domains: Set[str] = set()

        sni = value.get('sni')
        if isinstance(sni, str) and sni:
            domains.add(sni)

        snis = value.get('snis')
        if isinstance(snis, str):
            if snis:
                domains.add(snis)
        elif isinstance(snis, (list, tuple)):
            domains.update(s for s in snis if isinstance(s, str) and s)

        return domains

    def _is_valid_domain(self, domain: str) -> bool:
        """Return True when ``domain`` is a host name worth a certificate.

        Rejects empty or non-string values, local names, and IPv4/IPv6
        address literals.
        """
        if not isinstance(domain, str) or not domain:
            return False

        # Local names never get a certificate.
        if domain.lower() in ('localhost', '127.0.0.1', '0.0.0.0'):
            return False

        # Address literals parse successfully; real host names raise.
        try:
            ipaddress.ip_address(domain)
        except ValueError:
            return True
        return False

    def _build_ssl_domains_set(self, ssls: list) -> Set[str]:
        """Return the union of SNI names over all given SSL items."""
        ssl_domains: Set[str] = set()
        for ssl in ssls:
            ssl_domains |= self.extract_domains_from_ssl(ssl)
        return ssl_domains

    @staticmethod
    def _is_covered_by_ssl(domain: str, existing_ssl_domains: Set[str]) -> bool:
        """Return True if ``domain`` matches an SNI exactly or via a wildcard.

        A wildcard entry ``*.example.com`` covers exactly one extra label
        (``a.example.com``), not the bare parent or deeper sub-domains.
        """
        if domain in existing_ssl_domains:
            return True
        _, _, parent = domain.partition('.')
        return bool(parent) and f"*.{parent}" in existing_ssl_domains

    def should_request_cert(self, domain: str, existing_ssl_domains: Set[str]) -> bool:
        """Decide whether a certificate should be requested for ``domain``.

        Args:
            domain: Host name to check.
            existing_ssl_domains: SNI names that already have an SSL object.

        Returns:
            False when the name is invalid, already covered, still rate
            limited, or inside the post-limit buffer; True otherwise. When
            the limit and buffer have both passed, the rate-limit record is
            removed.
        """
        if not self._is_valid_domain(domain):
            return False

        if self._is_covered_by_ssl(domain, existing_ssl_domains):
            logger.info(f"域名已有 SSL 配置: {domain}")
            return False

        retry_after = self.rate_limited_domains.get(domain)
        if retry_after is None:
            return True

        current_time = time.time()
        if current_time < retry_after:
            remaining_minutes = int((retry_after - current_time) / 60) + 1
            logger.debug(f"域名 {domain} 仍在速率限制期间,剩余约 {remaining_minutes} 分钟后重试")
            return False

        # Limit lifted: still wait for the buffer before the first retry.
        retry_at = retry_after + self.RATE_LIMIT_BUFFER_SECONDS
        if current_time < retry_at:
            remaining_minutes = int((retry_at - current_time) / 60) + 1
            logger.debug(f"域名 {domain} 限制已解除,等待缓冲期,剩余约 {remaining_minutes} 分钟后重试")
            return False

        # Limit and buffer are over: forget the record and allow a retry.
        del self.rate_limited_domains[domain]
        logger.info(f"域名 {domain} 速率限制和缓冲期已过,将重新尝试申请证书")
        return True

    def _handle_cert_request(self, primary_domain: str, additional_domains: Optional[List[str]] = None):
        """Request one certificate for a primary domain plus optional extras.

        Args:
            primary_domain: Main domain of the certificate.
            additional_domains: Further domains for the same certificate.

        Returns:
            bool: True when the SSL manager reports success. On a rate limit
            every involved domain is recorded with its retry timestamp and
            False is returned; other failures are logged and return False.
        """
        domains_list = [primary_domain] + (additional_domains or [])

        if additional_domains:
            label = f"{primary_domain} + {additional_domains}"
            logger.info(f"发现同一路由中的多个域名,合并申请证书: {label}")
        else:
            label = primary_domain
            logger.info(f"发现新域名,准备申请证书: {label}")

        try:
            issued = self.ssl_manager.request_certificate(primary_domain, additional_domains)
        except RateLimitError as e:
            retry_text = datetime.fromtimestamp(e.retry_after_timestamp).strftime('%Y-%m-%d %H:%M:%S')
            logger.warning(f"域名 {e.domain} 遇到速率限制,将在 {retry_text} 后自动重试")
            # Hold back every domain that was part of this request.
            for d in domains_list:
                self.rate_limited_domains[d] = e.retry_after_timestamp
            return False
        except Exception as e:
            logger.error(f"处理域名异常 {label}: {e}")
            return False

        if not issued:
            logger.error(f"证书申请失败: {label}")
            return False

        if additional_domains:
            logger.info(f"证书申请成功: {primary_domain} (包含 {len(domains_list)} 个域名)")
        else:
            logger.info(f"证书申请成功: {primary_domain}")
        # Success: drop any stale rate-limit records.
        for d in domains_list:
            self.rate_limited_domains.pop(d, None)
        return True

    def process_new_domains(self):
        """Run one cycle: request certificates for uncovered route domains.

        The cycle is skipped when the SSL list cannot be fetched, because an
        unknown SSL state would make every routed domain look uncovered.
        Domains of one route share one certificate; the primary domain is
        chosen deterministically (non-wildcard names first, then sorted).
        """
        routes = self._fetch_admin_list('routes', '路由')
        if not routes:
            return

        ssls = self._fetch_admin_list('ssls', ' SSL 配置')
        if ssls is None:
            logger.warning("获取 SSL 配置失败,跳过本轮检查以避免重复申请证书")
            return

        # Built once per cycle and used for every lookup below.
        existing_ssl_domains = self._build_ssl_domains_set(ssls)

        for route in routes:
            if not isinstance(route, dict):
                continue
            route_value = route.get('value') or {}
            # status == 0 marks a disabled route.
            if route_value.get('status') == 0:
                continue

            domains_to_request = sorted(
                (d for d in self.extract_domains_from_route(route)
                 if self.should_request_cert(d, existing_ssl_domains)),
                key=lambda d: (d.startswith('*'), d),
            )
            if not domains_to_request:
                continue

            primary_domain = domains_to_request[0]
            additional_domains = domains_to_request[1:] or None
            if self._handle_cert_request(primary_domain, additional_domains):
                # Later routes sharing these names must not request again.
                existing_ssl_domains.update(domains_to_request)

    def _next_check_interval(self, interval, now: Optional[float] = None):
        """Return how long to sleep before the next cycle.

        Normally ``interval``. When a rate-limited domain becomes retryable
        (limit plus buffer) sooner than that, returns the time until then
        plus one second, never less than one second.

        Args:
            interval: Regular check interval in seconds.
            now: Current timestamp; defaults to ``time.time()``.
        """
        current_time = time.time() if now is None else now
        waits = [
            retry_after + self.RATE_LIMIT_BUFFER_SECONDS - current_time
            for retry_after in self.rate_limited_domains.values()
        ]
        # Entries already past their retry time are handled by the next cycle.
        pending = [w for w in waits if w > 0]
        if not pending:
            return interval
        soonest = min(pending)
        if soonest < interval:
            return max(1, min(soonest + 1, interval))
        return interval

    def run(self, interval: int = 60):
        """Run check cycles forever until interrupted.

        Args:
            interval: Regular check interval in seconds.
        """
        logger.info(f"路由监听服务启动,检查间隔: {interval} 秒")

        while True:
            try:
                delay = interval
                try:
                    self.process_new_domains()
                    delay = self._next_check_interval(interval)
                except Exception as e:
                    logger.error(f"监听服务异常: {e}")
                # Sleeping inside the outer try lets Ctrl-C stop the service
                # cleanly, also after a failed cycle.
                time.sleep(delay)
            except KeyboardInterrupt:
                logger.info("收到停止信号,退出服务")
                break
|
||
|
||
|
||
def main(argv=None):
    """Parse command-line options and start the route watcher.

    Args:
        argv: Optional list of command-line arguments (without the program
            name). When None, argparse reads ``sys.argv[1:]``.

    With ``--once`` a single check cycle is executed; otherwise the watcher
    loops using ``--interval`` seconds between cycles.
    """
    import argparse

    def positive_int(text):
        """Convert ``text`` to an int >= 1 or raise an argparse type error."""
        try:
            value = int(text)
        except ValueError:
            raise argparse.ArgumentTypeError(f"检查间隔必须是正整数: {text!r}") from None
        if value < 1:
            # Zero would busy-loop and a negative value makes time.sleep fail.
            raise argparse.ArgumentTypeError(f"检查间隔必须是正整数: {text!r}")
        return value

    parser = argparse.ArgumentParser(description='APISIX 路由监听服务')
    parser.add_argument('--interval', '-i', type=positive_int, default=60,
                        help='检查间隔(秒),默认 60')
    parser.add_argument('--config', '-c', help='配置文件路径(可选,用于覆盖默认配置)')
    parser.add_argument('--once', action='store_true',
                        help='只执行一次,不持续监听')

    args = parser.parse_args(argv)

    watcher = RouteWatcher(args.config)

    if args.once:
        watcher.process_new_domains()
    else:
        watcher.run(args.interval)
|
||
|
||
|
||
# Start the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|