主要优化: - 性能优化:只调用一次 get_all_ssls() API,在内存中构建域名集合进行快速查找 - 之前:N 个域名 = N 次 API 调用 - 现在:N 个域名 = 1 次 API 调用 - 性能提升:从 O(N×M) 降低到 O(N+M) - HTTP 连接复用:使用 requests.Session() 复用连接,减少连接开销 - 代码重构: - 提取 _fetch_apisix_data() 公共方法,减少重复代码 - 提取 _is_valid_domain() 方法,改进 IP 地址检测(支持 IPv4/IPv6) - 提取 _build_ssl_domains_set() 方法,构建 SSL 域名集合 - IP 地址检测改进:使用 ipaddress 模块,更准确地检测 IPv4 和 IPv6 这些优化显著提升了服务性能,特别是在处理大量路由和域名时。
272 lines
9.5 KiB
Python
Executable File
272 lines
9.5 KiB
Python
Executable File
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
APISIX 路由监听服务
|
||
监听路由创建事件,自动为域名申请 SSL 证书
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import json
|
||
import time
|
||
import logging
|
||
import requests
|
||
import ipaddress
|
||
from typing import Set, Optional, Dict
|
||
import sys
|
||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||
from ssl_manager import APISIXSSLManager
|
||
|
||
# 配置日志
|
||
logging.basicConfig(
|
||
level=logging.INFO,
|
||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||
handlers=[
|
||
logging.FileHandler('/var/log/apisix-route-watcher.log'),
|
||
logging.StreamHandler(sys.stdout)
|
||
]
|
||
)
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class RouteWatcher:
|
||
"""路由监听器"""
|
||
|
||
def __init__(self, config_path: str = None):
|
||
"""初始化监听器"""
|
||
self.ssl_manager = APISIXSSLManager(config_path)
|
||
|
||
# 从环境变量或配置获取 APISIX 配置
|
||
self.apisix_admin_url = os.getenv('APISIX_ADMIN_URL', 'http://localhost:9180')
|
||
self.apisix_admin_key = os.getenv('APISIX_ADMIN_KEY', '8206e6e42b6b53243c52a767cc633137')
|
||
|
||
# 创建 HTTP 会话,复用连接
|
||
self.session = requests.Session()
|
||
self.session.headers.update({
|
||
'X-API-KEY': self.apisix_admin_key,
|
||
'Content-Type': 'application/json'
|
||
})
|
||
|
||
def _get_apisix_headers(self):
|
||
"""获取 APISIX Admin API 请求头"""
|
||
return {
|
||
'X-API-KEY': self.apisix_admin_key,
|
||
'Content-Type': 'application/json'
|
||
}
|
||
|
||
def get_all_routes(self) -> list:
|
||
"""获取所有路由"""
|
||
try:
|
||
response = requests.get(
|
||
f"{self.apisix_admin_url}/apisix/admin/routes",
|
||
headers=self._get_apisix_headers(),
|
||
timeout=10
|
||
)
|
||
|
||
if response.status_code == 200:
|
||
data = response.json()
|
||
return data.get('list', [])
|
||
else:
|
||
logger.error(f"获取路由失败: {response.status_code}")
|
||
return []
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取路由异常: {e}")
|
||
return []
|
||
|
||
def get_all_ssls(self) -> list:
|
||
"""获取所有 SSL 配置"""
|
||
try:
|
||
response = requests.get(
|
||
f"{self.apisix_admin_url}/apisix/admin/ssls",
|
||
headers=self._get_apisix_headers(),
|
||
timeout=10
|
||
)
|
||
|
||
if response.status_code == 200:
|
||
data = response.json()
|
||
return data.get('list', [])
|
||
else:
|
||
logger.error(f"获取 SSL 配置失败: {response.status_code}")
|
||
return []
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取 SSL 配置异常: {e}")
|
||
return []
|
||
|
||
def extract_domains_from_route(self, route: dict) -> Set[str]:
|
||
"""从路由中提取域名"""
|
||
domains = set()
|
||
route_value = route.get('value', {})
|
||
|
||
# 从 host 字段提取(单个域名,Dashboard 常用)
|
||
host = route_value.get('host')
|
||
if host and isinstance(host, str):
|
||
domains.add(host)
|
||
|
||
# 从 hosts 字段提取(域名数组)
|
||
hosts = route_value.get('hosts', [])
|
||
if hosts:
|
||
if isinstance(hosts, list):
|
||
domains.update(hosts)
|
||
elif isinstance(hosts, str):
|
||
domains.add(hosts)
|
||
|
||
# 从 uri 字段提取(如果包含域名)
|
||
uri = route_value.get('uri', '')
|
||
if uri and '.' in uri and not uri.startswith('/'):
|
||
# 可能是域名格式
|
||
parts = uri.split('/')
|
||
if parts[0] and '.' in parts[0]:
|
||
domains.add(parts[0])
|
||
|
||
# 从 match 字段提取
|
||
match = route_value.get('match', {})
|
||
if isinstance(match, dict):
|
||
for key, value in match.items():
|
||
if 'host' in key.lower() and isinstance(value, str):
|
||
domains.add(value)
|
||
|
||
return domains
|
||
|
||
def extract_domains_from_ssl(self, ssl: dict) -> Set[str]:
|
||
"""从 SSL 配置中提取域名"""
|
||
domains = set()
|
||
snis = ssl.get('value', {}).get('snis', [])
|
||
if snis:
|
||
domains.update(snis)
|
||
return domains
|
||
|
||
def _is_valid_domain(self, domain: str) -> bool:
|
||
"""检查是否为有效域名(非 IP 地址和本地域名)"""
|
||
# 跳过本地域名
|
||
if domain in ['localhost', '127.0.0.1', '0.0.0.0']:
|
||
return False
|
||
|
||
# 检查是否为 IP 地址(支持 IPv4 和 IPv6)
|
||
try:
|
||
ipaddress.ip_address(domain)
|
||
return False
|
||
except ValueError:
|
||
pass
|
||
|
||
return True
|
||
|
||
def _build_ssl_domains_set(self, ssls: list) -> Set[str]:
|
||
"""构建所有已配置 SSL 的域名集合(用于快速查找)"""
|
||
ssl_domains = set()
|
||
for ssl in ssls:
|
||
domains = self.extract_domains_from_ssl(ssl)
|
||
ssl_domains.update(domains)
|
||
return ssl_domains
|
||
|
||
def should_request_cert(self, domain: str, existing_ssl_domains: Set[str]) -> bool:
|
||
"""判断是否需要申请证书
|
||
|
||
Args:
|
||
domain: 要检查的域名
|
||
existing_ssl_domains: 已存在的 SSL 域名集合(用于快速查找)
|
||
"""
|
||
# 检查是否为有效域名
|
||
if not self._is_valid_domain(domain):
|
||
return False
|
||
|
||
# 检查是否已有 SSL 配置
|
||
if domain in existing_ssl_domains:
|
||
logger.info(f"域名已有 SSL 配置: {domain}")
|
||
return False
|
||
|
||
return True
|
||
|
||
def process_new_domains(self):
|
||
"""处理新域名"""
|
||
# 优化:只获取一次路由和 SSL 配置,避免重复 API 调用
|
||
routes = self.get_all_routes()
|
||
ssls = self.get_all_ssls()
|
||
|
||
# 构建已存在的 SSL 域名集合(用于快速查找)
|
||
existing_ssl_domains = self._build_ssl_domains_set(ssls)
|
||
|
||
# 按路由处理,同一路由的多个域名合并到一个证书
|
||
for route in routes:
|
||
route_value = route.get('value', {})
|
||
# 跳过禁用的路由(status=0 表示禁用)
|
||
if route_value.get('status') == 0:
|
||
continue
|
||
|
||
# 从路由中提取所有域名
|
||
domains = self.extract_domains_from_route(route)
|
||
if not domains:
|
||
continue
|
||
|
||
# 过滤出需要申请证书的域名(使用缓存的 SSL 域名集合)
|
||
domains_to_request = [d for d in domains if self.should_request_cert(d, existing_ssl_domains)]
|
||
|
||
if not domains_to_request:
|
||
continue
|
||
|
||
# 如果只有一个域名,单独申请
|
||
if len(domains_to_request) == 1:
|
||
domain = domains_to_request[0]
|
||
logger.info(f"发现新域名,准备申请证书: {domain}")
|
||
try:
|
||
if self.ssl_manager.request_certificate(domain):
|
||
logger.info(f"证书申请成功: {domain}")
|
||
else:
|
||
logger.error(f"证书申请失败: {domain}")
|
||
except Exception as e:
|
||
logger.error(f"处理域名异常 {domain}: {e}")
|
||
else:
|
||
# 多个域名,合并到一个证书申请(使用 SAN)
|
||
primary_domain = domains_to_request[0]
|
||
additional_domains = domains_to_request[1:]
|
||
total_domains = len(additional_domains) + 1
|
||
logger.info(f"发现同一路由中的多个域名,合并申请证书: {primary_domain} + {additional_domains}")
|
||
try:
|
||
if self.ssl_manager.request_certificate(primary_domain, additional_domains):
|
||
logger.info(f"证书申请成功: {primary_domain} (包含 {total_domains} 个域名)")
|
||
else:
|
||
logger.error(f"证书申请失败: {primary_domain} + {additional_domains}")
|
||
except Exception as e:
|
||
logger.error(f"处理域名异常 {primary_domain} + {additional_domains}: {e}")
|
||
|
||
def run(self, interval: int = 60):
|
||
"""运行监听服务"""
|
||
logger.info(f"路由监听服务启动,检查间隔: {interval} 秒")
|
||
|
||
while True:
|
||
try:
|
||
self.process_new_domains()
|
||
time.sleep(interval)
|
||
except KeyboardInterrupt:
|
||
logger.info("收到停止信号,退出服务")
|
||
break
|
||
except Exception as e:
|
||
logger.error(f"监听服务异常: {e}")
|
||
time.sleep(interval)
|
||
|
||
|
||
def main():
|
||
"""主函数"""
|
||
import argparse
|
||
|
||
parser = argparse.ArgumentParser(description='APISIX 路由监听服务')
|
||
parser.add_argument('--interval', '-i', type=int, default=60,
|
||
help='检查间隔(秒),默认 60')
|
||
parser.add_argument('--config', '-c', help='配置文件路径(可选,用于覆盖默认配置)')
|
||
parser.add_argument('--once', action='store_true',
|
||
help='只执行一次,不持续监听')
|
||
|
||
args = parser.parse_args()
|
||
|
||
watcher = RouteWatcher(args.config)
|
||
|
||
if args.once:
|
||
watcher.process_new_domains()
|
||
else:
|
||
watcher.run(args.interval)
|
||
|
||
|
||
if __name__ == '__main__':
|
||
main()
|