238 lines
7.9 KiB
Python
Executable File
238 lines
7.9 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
APISIX 路由监听服务
|
|
监听路由创建事件,自动为域名申请 SSL 证书
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import json
|
|
import time
|
|
import logging
|
|
import requests
|
|
from typing import Set, Optional
|
|
import sys
|
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
|
from ssl_manager import APISIXSSLManager
|
|
|
|
# 配置日志
|
|
logging.basicConfig(
|
|
level=logging.INFO,
|
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
|
handlers=[
|
|
logging.FileHandler('/var/log/apisix-route-watcher.log'),
|
|
logging.StreamHandler(sys.stdout)
|
|
]
|
|
)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RouteWatcher:
|
|
"""路由监听器"""
|
|
|
|
def __init__(self, config_path: str = None):
|
|
"""初始化监听器"""
|
|
self.ssl_manager = APISIXSSLManager(config_path)
|
|
|
|
# 从环境变量或配置获取 APISIX 配置
|
|
self.apisix_admin_url = os.getenv('APISIX_ADMIN_URL', 'http://localhost:9180')
|
|
self.apisix_admin_key = os.getenv('APISIX_ADMIN_KEY', '8206e6e42b6b53243c52a767cc633137')
|
|
|
|
# 已处理的域名集合
|
|
self.processed_domains: Set[str] = set()
|
|
|
|
# 加载已处理的域名
|
|
self._load_processed_domains()
|
|
|
|
def _get_apisix_headers(self):
|
|
"""获取 APISIX Admin API 请求头"""
|
|
return {
|
|
'X-API-KEY': self.apisix_admin_key,
|
|
'Content-Type': 'application/json'
|
|
}
|
|
|
|
def _load_processed_domains(self):
|
|
"""加载已处理的域名列表"""
|
|
state_file = '/var/lib/apisix-ssl-manager/processed_domains.json'
|
|
if os.path.exists(state_file):
|
|
try:
|
|
with open(state_file, 'r') as f:
|
|
self.processed_domains = set(json.load(f))
|
|
logger.info(f"加载已处理域名: {len(self.processed_domains)} 个")
|
|
except Exception as e:
|
|
logger.warning(f"加载已处理域名失败: {e}")
|
|
|
|
def _save_processed_domains(self):
|
|
"""保存已处理的域名列表"""
|
|
state_file = '/var/lib/apisix-ssl-manager/processed_domains.json'
|
|
os.makedirs(os.path.dirname(state_file), exist_ok=True)
|
|
try:
|
|
with open(state_file, 'w') as f:
|
|
json.dump(list(self.processed_domains), f)
|
|
except Exception as e:
|
|
logger.error(f"保存已处理域名失败: {e}")
|
|
|
|
def get_all_routes(self) -> list:
|
|
"""获取所有路由"""
|
|
try:
|
|
response = requests.get(
|
|
f"{self.apisix_admin_url}/apisix/admin/routes",
|
|
headers=self._get_apisix_headers(),
|
|
timeout=10
|
|
)
|
|
|
|
if response.status_code == 200:
|
|
data = response.json()
|
|
return data.get('list', [])
|
|
else:
|
|
logger.error(f"获取路由失败: {response.status_code}")
|
|
return []
|
|
|
|
except Exception as e:
|
|
logger.error(f"获取路由异常: {e}")
|
|
return []
|
|
|
|
def get_all_ssls(self) -> list:
|
|
"""获取所有 SSL 配置"""
|
|
try:
|
|
response = requests.get(
|
|
f"{self.apisix_admin_url}/apisix/admin/ssls",
|
|
headers=self._get_apisix_headers(),
|
|
timeout=10
|
|
)
|
|
|
|
if response.status_code == 200:
|
|
data = response.json()
|
|
return data.get('list', [])
|
|
else:
|
|
logger.error(f"获取 SSL 配置失败: {response.status_code}")
|
|
return []
|
|
|
|
except Exception as e:
|
|
logger.error(f"获取 SSL 配置异常: {e}")
|
|
return []
|
|
|
|
def extract_domains_from_route(self, route: dict) -> Set[str]:
|
|
"""从路由中提取域名"""
|
|
domains = set()
|
|
|
|
# 从 hosts 字段提取
|
|
hosts = route.get('value', {}).get('hosts', [])
|
|
if hosts:
|
|
domains.update(hosts)
|
|
|
|
# 从 uri 字段提取(如果包含域名)
|
|
uri = route.get('value', {}).get('uri', '')
|
|
if uri and '.' in uri and not uri.startswith('/'):
|
|
# 可能是域名格式
|
|
parts = uri.split('/')
|
|
if parts[0] and '.' in parts[0]:
|
|
domains.add(parts[0])
|
|
|
|
# 从 match 字段提取
|
|
match = route.get('value', {}).get('match', {})
|
|
if isinstance(match, dict):
|
|
for key, value in match.items():
|
|
if 'host' in key.lower() and isinstance(value, str):
|
|
domains.add(value)
|
|
|
|
return domains
|
|
|
|
def extract_domains_from_ssl(self, ssl: dict) -> Set[str]:
|
|
"""从 SSL 配置中提取域名"""
|
|
domains = set()
|
|
snis = ssl.get('value', {}).get('snis', [])
|
|
if snis:
|
|
domains.update(snis)
|
|
return domains
|
|
|
|
def should_request_cert(self, domain: str) -> bool:
|
|
"""判断是否需要申请证书"""
|
|
# 跳过已处理的域名
|
|
if domain in self.processed_domains:
|
|
return False
|
|
|
|
# 跳过本地域名
|
|
if domain in ['localhost', '127.0.0.1', '0.0.0.0']:
|
|
return False
|
|
|
|
# 跳过 IP 地址
|
|
if domain.replace('.', '').isdigit():
|
|
return False
|
|
|
|
# 检查是否已有 SSL 配置
|
|
ssls = self.get_all_ssls()
|
|
for ssl in ssls:
|
|
ssl_domains = self.extract_domains_from_ssl(ssl)
|
|
if domain in ssl_domains:
|
|
logger.info(f"域名已有 SSL 配置: {domain}")
|
|
self.processed_domains.add(domain)
|
|
return False
|
|
|
|
return True
|
|
|
|
def process_new_domains(self):
|
|
"""处理新域名"""
|
|
routes = self.get_all_routes()
|
|
new_domains = set()
|
|
|
|
# 从路由中提取所有域名
|
|
for route in routes:
|
|
domains = self.extract_domains_from_route(route)
|
|
new_domains.update(domains)
|
|
|
|
# 处理需要申请证书的域名
|
|
for domain in new_domains:
|
|
if self.should_request_cert(domain):
|
|
logger.info(f"发现新域名,准备申请证书: {domain}")
|
|
try:
|
|
if self.ssl_manager.request_certificate(domain):
|
|
logger.info(f"证书申请成功: {domain}")
|
|
self.processed_domains.add(domain)
|
|
self._save_processed_domains()
|
|
else:
|
|
logger.error(f"证书申请失败: {domain}")
|
|
except Exception as e:
|
|
logger.error(f"处理域名异常 {domain}: {e}")
|
|
|
|
def run(self, interval: int = 60):
|
|
"""运行监听服务"""
|
|
logger.info(f"路由监听服务启动,检查间隔: {interval} 秒")
|
|
|
|
while True:
|
|
try:
|
|
self.process_new_domains()
|
|
time.sleep(interval)
|
|
except KeyboardInterrupt:
|
|
logger.info("收到停止信号,退出服务")
|
|
break
|
|
except Exception as e:
|
|
logger.error(f"监听服务异常: {e}")
|
|
time.sleep(interval)
|
|
|
|
|
|
def main():
|
|
"""主函数"""
|
|
import argparse
|
|
|
|
parser = argparse.ArgumentParser(description='APISIX 路由监听服务')
|
|
parser.add_argument('--interval', '-i', type=int, default=60,
|
|
help='检查间隔(秒),默认 60')
|
|
parser.add_argument('--config', '-c', help='配置文件路径(可选,用于覆盖默认配置)')
|
|
parser.add_argument('--once', action='store_true',
|
|
help='只执行一次,不持续监听')
|
|
|
|
args = parser.parse_args()
|
|
|
|
watcher = RouteWatcher(args.config)
|
|
|
|
if args.once:
|
|
watcher.process_new_domains()
|
|
else:
|
|
watcher.run(args.interval)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
main()
|