apisix/ssl_manager/route_watcher.py

238 lines
7.9 KiB
Python
Executable File

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
APISIX 路由监听服务
监听路由创建事件,自动为域名申请 SSL 证书
"""
import os
import sys
import json
import time
import logging
import requests
from typing import Set, Optional
import sys
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from ssl_manager import APISIXSSLManager
# Logging: always to stdout, plus a log file when it can be opened.
# The file path can be overridden with the APISIX_ROUTE_WATCHER_LOG variable.
_LOG_FILE = os.getenv('APISIX_ROUTE_WATCHER_LOG', '/var/log/apisix-route-watcher.log')
_log_handlers = [logging.StreamHandler(sys.stdout)]
_log_file_error = None
try:
    # Explicit UTF-8: the log messages are Chinese and the process locale may
    # not be able to encode them.
    _log_handlers.insert(0, logging.FileHandler(_LOG_FILE, encoding='utf-8'))
except OSError as exc:
    # e.g. missing directory or no write permission (non-root / container):
    # keep stdout-only logging instead of failing at import time.
    _log_file_error = exc
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=_log_handlers,
)
logger = logging.getLogger(__name__)
if _log_file_error is not None:
    logger.warning(f"无法打开日志文件 {_LOG_FILE}, 仅输出到标准输出: {_log_file_error}")
class RouteWatcher:
    """Poll APISIX routes and request SSL certificates for newly seen hosts.

    Each poll lists all routes through the Admin API and collects the host
    names they reference. It then calls ``APISIXSSLManager.request_certificate``
    for every host that meets all of these conditions:

    - it is not local and not an IP literal;
    - it is not already recorded as processed;
    - it is not already covered by an APISIX SSL object.

    Hosts that received a certificate, or were found to be covered, are
    persisted to ``STATE_FILE`` so restarts do not request them again.
    """

    # Persistent state: JSON list of domains that no longer need a certificate.
    STATE_FILE = '/var/lib/apisix-ssl-manager/processed_domains.json'

    # Host names that can never receive a publicly trusted certificate.
    LOCAL_HOSTS = frozenset({'localhost', '127.0.0.1', '0.0.0.0'})

    def __init__(self, config_path: Optional[str] = None):
        """Create the watcher and load previously processed domains.

        Args:
            config_path: optional config file path, forwarded unchanged to
                ``APISIXSSLManager``.
        """
        self.ssl_manager = APISIXSSLManager(config_path)
        # Admin API location and key come from the environment. A trailing
        # slash is stripped so the joined request URLs never contain "//".
        self.apisix_admin_url = os.getenv(
            'APISIX_ADMIN_URL', 'http://localhost:9180').rstrip('/')
        # NOTE(review): the fallback admin key is a secret committed to source;
        # deployments should always set APISIX_ADMIN_KEY explicitly.
        self.apisix_admin_key = os.getenv(
            'APISIX_ADMIN_KEY', '8206e6e42b6b53243c52a767cc633137')
        # Domains that already have a certificate (issued here or found in APISIX).
        self.processed_domains: Set[str] = set()
        self._load_processed_domains()

    def _get_apisix_headers(self):
        """Return the request headers required by the APISIX Admin API."""
        return {
            'X-API-KEY': self.apisix_admin_key,
            'Content-Type': 'application/json',
        }

    def _load_processed_domains(self):
        """Load the processed-domain set from ``STATE_FILE`` (best effort).

        A missing file leaves the set empty. An unreadable file, or one that
        is not a JSON list, is logged as a warning and ignored.
        """
        if not os.path.exists(self.STATE_FILE):
            return
        try:
            with open(self.STATE_FILE, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            logger.warning(f"加载已处理域名失败: {e}")
            return
        if not isinstance(data, list):
            logger.warning(f"加载已处理域名失败: 期望 JSON 列表, 实际为 {type(data).__name__}")
            return
        self.processed_domains = {d for d in data if isinstance(d, str)}
        logger.info(f"加载已处理域名: {len(self.processed_domains)}")

    def _save_processed_domains(self):
        """Persist the processed-domain set to ``STATE_FILE``.

        The JSON is written to a temporary sibling file, which then replaces
        the real file via ``os.replace``. An interrupted write therefore leaves
        the previous state file intact. Failures are logged, not raised.
        """
        tmp_path = self.STATE_FILE + '.tmp'
        try:
            os.makedirs(os.path.dirname(self.STATE_FILE), exist_ok=True)
            with open(tmp_path, 'w', encoding='utf-8') as f:
                # Sorted output keeps the file stable between saves.
                json.dump(sorted(self.processed_domains, key=str), f)
            os.replace(tmp_path, self.STATE_FILE)
        except (OSError, TypeError, ValueError) as e:
            logger.error(f"保存已处理域名失败: {e}")

    def _admin_list(self, resource: str, fail_msg: str, error_msg: str) -> Optional[list]:
        """GET ``/apisix/admin/<resource>`` and return the items under ``list``.

        Args:
            resource: Admin API collection name, e.g. ``'routes'``.
            fail_msg: log prefix used for a non-200 status.
            error_msg: log prefix used for transport/JSON errors.

        Returns:
            The list of item dicts (non-dict entries dropped), or ``None``
            when the request failed, returned a non-200 status, or the body
            was not a JSON object.
        """
        try:
            response = requests.get(
                f"{self.apisix_admin_url}/apisix/admin/{resource}",
                headers=self._get_apisix_headers(),
                timeout=10,
            )
            if response.status_code != 200:
                logger.error(f"{fail_msg}: {response.status_code}")
                return None
            data = response.json()
        except (requests.RequestException, ValueError) as e:
            logger.error(f"{error_msg}: {e}")
            return None
        if not isinstance(data, dict):
            logger.error(f"{error_msg}: 响应格式异常")
            return None
        items = data.get('list') or []
        # Defensive: treat any non-list ``list`` value as an empty collection.
        if not isinstance(items, list):
            return []
        return [item for item in items if isinstance(item, dict)]

    def get_all_routes(self) -> list:
        """Return all APISIX routes, or ``[]`` if they cannot be fetched."""
        return self._admin_list('routes', '获取路由失败', '获取路由异常') or []

    def get_all_ssls(self) -> list:
        """Return all APISIX SSL objects, or ``[]`` if they cannot be fetched."""
        return self._admin_list('ssls', '获取 SSL 配置失败', '获取 SSL 配置异常') or []

    def extract_domains_from_route(self, route: dict) -> Set[str]:
        """Collect the host names referenced by one Admin API route item.

        The following fields of ``route['value']`` are inspected:

        - ``host``: a single string;
        - ``hosts``: a list of strings;
        - ``uri``: only when it does not start with ``/`` and its first
          segment contains a dot;
        - ``match``: any string value whose key contains "host".

        Empty or non-string entries are ignored.
        """
        value = route.get('value') or {}
        if not isinstance(value, dict):
            return set()
        domains: Set[str] = set()

        host = value.get('host')
        if isinstance(host, str) and host:
            domains.add(host)

        hosts = value.get('hosts')
        if isinstance(hosts, list):
            domains.update(h for h in hosts if isinstance(h, str) and h)

        uri = value.get('uri') or ''
        # A uri that does not start with '/' may embed a host, e.g. "a.com/x".
        if isinstance(uri, str) and '.' in uri and not uri.startswith('/'):
            first_segment = uri.split('/', 1)[0]
            if first_segment and '.' in first_segment:
                domains.add(first_segment)

        match = value.get('match')
        if isinstance(match, dict):
            for key, candidate in match.items():
                if 'host' in str(key).lower() and isinstance(candidate, str) and candidate:
                    domains.add(candidate)
        return domains

    def extract_domains_from_ssl(self, ssl: dict) -> Set[str]:
        """Collect the SNIs of one Admin API SSL item (``sni`` and ``snis``)."""
        value = ssl.get('value') or {}
        if not isinstance(value, dict):
            return set()
        domains: Set[str] = set()
        sni = value.get('sni')
        if isinstance(sni, str) and sni:
            domains.add(sni)
        snis = value.get('snis')
        if isinstance(snis, list):
            domains.update(s for s in snis if isinstance(s, str) and s)
        return domains

    def _collect_ssl_domains(self, ssls: list) -> Set[str]:
        """Union of the SNIs of all given SSL items."""
        names: Set[str] = set()
        for ssl in ssls:
            names |= self.extract_domains_from_ssl(ssl)
        return names

    @staticmethod
    def _is_ip_address(host: str) -> bool:
        """True for IPv4/IPv6 literals (and all-digit dotted strings)."""
        import ipaddress
        # Keep the original digits-and-dots heuristic so values such as "123"
        # are still skipped, and add proper IPv4/IPv6 parsing.
        if host.replace('.', '').isdigit():
            return True
        try:
            ipaddress.ip_address(host.strip('[]'))
        except ValueError:
            return False
        return True

    @staticmethod
    def _is_covered(domain: str, ssl_domains: Set[str]) -> bool:
        """True if an SNI matches ``domain`` exactly or as a one-label wildcard.

        Matching is case-insensitive, because DNS names are. A ``*.example.com``
        SNI covers ``a.example.com`` but not ``a.b.example.com``.
        """
        wanted = domain.lower()
        names = {s.lower() for s in ssl_domains if isinstance(s, str)}
        if wanted in names:
            return True
        _, dot, parent = wanted.partition('.')
        return bool(dot) and f'*.{parent}' in names

    def should_request_cert(self, domain: str, ssl_domains: Optional[Set[str]] = None) -> bool:
        """Decide whether a certificate should be requested for ``domain``.

        Args:
            domain: host name taken from a route.
            ssl_domains: SNIs of all existing APISIX SSL objects. When omitted,
                they are fetched from the Admin API (one request per call).

        Returns:
            False if the host is already processed, local, an IP literal, or
            already covered by an SSL object. Covered hosts are also added to
            ``processed_domains``. Returns True otherwise.
        """
        if domain in self.processed_domains:
            return False
        if domain in self.LOCAL_HOSTS or self._is_ip_address(domain):
            return False
        if ssl_domains is None:
            ssl_domains = self._collect_ssl_domains(self.get_all_ssls())
        if self._is_covered(domain, ssl_domains):
            logger.info(f"域名已有 SSL 配置: {domain}")
            self.processed_domains.add(domain)
            return False
        return True

    def process_new_domains(self):
        """Run one poll: request certificates for route hosts that lack one.

        The SSL list is fetched once per poll. If that fetch fails, the poll
        is skipped. This prevents an Admin API failure from being mistaken
        for "no certificates exist", which would re-request every host.
        """
        route_domains: Set[str] = set()
        for route in self.get_all_routes():
            route_domains |= self.extract_domains_from_route(route)
        if not route_domains:
            return

        ssls = self._admin_list('ssls', '获取 SSL 配置失败', '获取 SSL 配置异常')
        if ssls is None:
            logger.warning("无法获取 SSL 配置, 跳过本轮检查")
            return
        ssl_domains = self._collect_ssl_domains(ssls)

        processed_before = len(self.processed_domains)
        # NOTE(review): failed requests are retried on every poll. If the CA is
        # ACME-based (e.g. Let's Encrypt), repeated validation failures are
        # rate-limited, so a per-domain backoff may be worth adding.
        for domain in sorted(route_domains):
            if not self.should_request_cert(domain, ssl_domains):
                continue
            logger.info(f"发现新域名,准备申请证书: {domain}")
            try:
                if self.ssl_manager.request_certificate(domain):
                    logger.info(f"证书申请成功: {domain}")
                    self.processed_domains.add(domain)
                    # Save immediately so an interruption cannot cause a re-issue.
                    self._save_processed_domains()
                else:
                    logger.error(f"证书申请失败: {domain}")
            except Exception as e:
                # request_certificate is opaque here; log and continue.
                logger.exception(f"处理域名异常 {domain}: {e}")
        # Also persist hosts that should_request_cert found already covered.
        if len(self.processed_domains) != processed_before:
            self._save_processed_domains()

    def run(self, interval: int = 60):
        """Poll forever, sleeping ``interval`` seconds between polls.

        Errors from a single poll are logged and the loop continues. A
        KeyboardInterrupt stops the service cleanly at any point, including
        during the sleep after an error.

        Raises:
            ValueError: if ``interval`` is not positive (0 would busy-loop
                against the Admin API).
        """
        if interval <= 0:
            raise ValueError(f"interval must be positive, got {interval}")
        logger.info(f"路由监听服务启动,检查间隔: {interval}")
        try:
            while True:
                try:
                    self.process_new_domains()
                except Exception as e:
                    logger.exception(f"监听服务异常: {e}")
                time.sleep(interval)
        except KeyboardInterrupt:
            logger.info("收到停止信号,退出服务")
def main():
    """Command-line entry point: parse arguments and start the watcher.

    With ``--once``, a single poll is run. Otherwise the watcher polls every
    ``--interval`` seconds until interrupted. ``--interval`` must be a
    positive integer; argparse rejects other values with exit status 2.
    """
    import argparse

    def positive_int(text):
        """argparse ``type``: accept only integers >= 1."""
        value = int(text)
        if value <= 0:
            # 0 would busy-loop against the Admin API; negatives crash time.sleep.
            raise argparse.ArgumentTypeError(f"检查间隔必须为正整数: {text}")
        return value

    parser = argparse.ArgumentParser(description='APISIX 路由监听服务')
    parser.add_argument('--interval', '-i', type=positive_int, default=60,
                        help='检查间隔(秒),默认 60')
    parser.add_argument('--config', '-c', help='配置文件路径(可选,用于覆盖默认配置)')
    parser.add_argument('--once', action='store_true',
                        help='只执行一次,不持续监听')
    args = parser.parse_args()

    watcher = RouteWatcher(args.config)
    if args.once:
        watcher.process_new_domains()
    else:
        watcher.run(args.interval)


if __name__ == '__main__':
    main()