# apisix/ssl_manager/route_watcher.py
#
# 251 lines
# 8.4 KiB
# Python
# Executable File
# Raw Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
APISIX 路由监听服务
监听路由创建事件,自动为域名申请 SSL 证书
"""
import os
import sys
import json
import time
import logging
import requests
from typing import Set, Optional
import sys
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from ssl_manager import APISIXSSLManager
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler('/var/log/apisix-route-watcher.log'),
logging.StreamHandler(sys.stdout)
]
)
logger = logging.getLogger(__name__)
class RouteWatcher:
    """Watch APISIX routes and request SSL certificates for newly seen hosts.

    Each polling cycle lists all routes through the APISIX Admin API,
    collects the host names they reference, and asks ``APISIXSSLManager``
    to issue a certificate for every host that is neither already covered
    by an APISIX SSL object nor recorded as processed. Processed hosts are
    persisted to ``STATE_FILE`` so a restart does not re-request them.
    """

    # JSON file holding the list of domains that need no further action.
    STATE_FILE = '/var/lib/apisix-ssl-manager/processed_domains.json'

    # Fallback Admin API key used when APISIX_ADMIN_KEY is not set.
    # NOTE(review): a real-looking key committed to source is a credential
    # leak; always supply APISIX_ADMIN_KEY and consider rotating this key.
    DEFAULT_ADMIN_KEY = '8206e6e42b6b53243c52a767cc633137'

    # Host names that can never receive a publicly trusted certificate.
    _LOCAL_HOSTS = frozenset({'localhost', '127.0.0.1', '0.0.0.0'})

    def __init__(self, config_path: Optional[str] = None):
        """Create the watcher and load previously processed domains.

        Args:
            config_path: Optional config file path, forwarded unchanged to
                ``APISIXSSLManager``.
        """
        self.ssl_manager = APISIXSSLManager(config_path)
        # Admin API endpoint and key come from the environment.
        self.apisix_admin_url = os.getenv('APISIX_ADMIN_URL', 'http://localhost:9180')
        self.apisix_admin_key = os.getenv('APISIX_ADMIN_KEY', self.DEFAULT_ADMIN_KEY)
        if 'APISIX_ADMIN_KEY' not in os.environ:
            logger.warning("未设置 APISIX_ADMIN_KEY 环境变量, 正在使用内置默认 Admin Key")
        # Domains already handled (certificate issued or SSL already present).
        self.processed_domains: Set[str] = set()
        self._load_processed_domains()

    def _get_apisix_headers(self) -> dict:
        """Return the request headers for APISIX Admin API calls."""
        return {
            'X-API-KEY': self.apisix_admin_key,
            'Content-Type': 'application/json'
        }

    def _load_processed_domains(self):
        """Load the persisted processed-domain list from ``STATE_FILE``.

        A missing file is not an error. A malformed file (unparseable JSON
        or not a JSON list) is logged as a warning and ignored.
        """
        state_file = self.STATE_FILE
        if not os.path.exists(state_file):
            return
        try:
            with open(state_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            if not isinstance(data, list):
                raise ValueError(f"expected a JSON list, got {type(data).__name__}")
            # Only string entries are meaningful domain names.
            self.processed_domains = {d for d in data if isinstance(d, str)}
            logger.info(f"加载已处理域名: {len(self.processed_domains)}")
        except Exception as e:
            logger.warning(f"加载已处理域名失败: {e}")

    def _save_processed_domains(self):
        """Persist ``processed_domains`` to ``STATE_FILE``.

        The list is written to a temporary file that is then atomically
        renamed over the state file, so a crash mid-write cannot leave a
        truncated file behind. All failures, including an uncreatable
        directory, are logged rather than raised. That way a successful
        certificate request is never reported as a processing error.
        """
        state_file = self.STATE_FILE
        tmp_file = state_file + '.tmp'
        try:
            os.makedirs(os.path.dirname(state_file), exist_ok=True)
            with open(tmp_file, 'w', encoding='utf-8') as f:
                # Sorted for a stable, diff-friendly file.
                json.dump(sorted(self.processed_domains), f)
            os.replace(tmp_file, state_file)
        except Exception as e:
            logger.error(f"保存已处理域名失败: {e}")

    def _fetch_admin_list(self, resource: str, what: str) -> Optional[list]:
        """GET ``/apisix/admin/<resource>`` and return its ``list`` items.

        Args:
            resource: Admin API collection name, e.g. ``'routes'``/``'ssls'``.
            what: Label interpolated into the log messages.

        Returns:
            The items on success, or ``None`` on any failure. This lets
            callers distinguish "no items" from "could not fetch".
        """
        try:
            response = requests.get(
                f"{self.apisix_admin_url}/apisix/admin/{resource}",
                headers=self._get_apisix_headers(),
                timeout=10
            )
            if response.status_code != 200:
                logger.error(f"获取{what}失败: {response.status_code}")
                return None
            items = response.json().get('list', [])
        except Exception as e:
            logger.error(f"获取{what}异常: {e}")
            return None
        # An empty collection may be serialised as a JSON object ({}) rather
        # than [] — treat any non-list payload as "no items".
        return items if isinstance(items, list) else []

    def get_all_routes(self) -> list:
        """Return all routes from the Admin API, or ``[]`` on failure."""
        return self._fetch_admin_list('routes', '路由') or []

    def get_all_ssls(self) -> list:
        """Return all SSL objects from the Admin API, or ``[]`` on failure."""
        # The leading space keeps the log text as "获取 SSL 配置...".
        return self._fetch_admin_list('ssls', ' SSL 配置') or []

    def extract_domains_from_route(self, route: dict) -> Set[str]:
        """Collect the host names referenced by one Admin API route item.

        Reads ``value.host`` (single string), ``value.hosts`` (list, or a
        bare string), a host-looking first segment of a ``value.uri`` that
        does not start with ``/``, and string values under ``value.match``
        whose key contains "host". Empty and non-string entries are skipped.
        """
        domains: Set[str] = set()
        route_value = route.get('value') or {}

        # `host`: single host name (commonly set by the Dashboard).
        host = route_value.get('host')
        if isinstance(host, str) and host:
            domains.add(host)

        # `hosts`: list of host names; a bare string is tolerated.
        hosts = route_value.get('hosts') or []
        if isinstance(hosts, str):
            hosts = [hosts]
        if isinstance(hosts, list):
            domains.update(h for h in hosts if isinstance(h, str) and h)

        # `uri`: only when it does not start with '/' and its first path
        # segment looks like a host name.
        uri = route_value.get('uri', '')
        if isinstance(uri, str) and '.' in uri and not uri.startswith('/'):
            first_segment = uri.split('/')[0]
            if first_segment and '.' in first_segment:
                domains.add(first_segment)

        # `match`: any non-empty string value whose key mentions "host".
        match = route_value.get('match', {})
        if isinstance(match, dict):
            for key, value in match.items():
                if 'host' in str(key).lower() and isinstance(value, str) and value:
                    domains.add(value)

        return domains

    def extract_domains_from_ssl(self, ssl: dict) -> Set[str]:
        """Return the SNI host names covered by one Admin API SSL item.

        Both the ``snis`` list and the single-value ``sni`` field are read.
        Otherwise an SSL object configured with ``sni`` would look like a
        missing certificate, and that host would be re-requested.
        """
        ssl_value = ssl.get('value') or {}
        domains: Set[str] = set()

        sni = ssl_value.get('sni')
        if isinstance(sni, str) and sni:
            domains.add(sni)

        snis = ssl_value.get('snis') or []
        if isinstance(snis, str):
            snis = [snis]
        if isinstance(snis, list):
            domains.update(s for s in snis if isinstance(s, str) and s)

        return domains

    @staticmethod
    def _is_ip_address(value: str) -> bool:
        """Return True if *value* is an IPv4/IPv6 literal (or all digits/dots)."""
        import ipaddress
        # Kept from the original check: digits-and-dots strings count as IPs.
        if value.replace('.', '').isdigit():
            return True
        try:
            # Strip brackets so "[::1]" is recognised as well.
            ipaddress.ip_address(value.strip('[]'))
        except ValueError:
            return False
        return True

    def should_request_cert(self, domain: str,
                            ssl_domains: Optional[Set[str]] = None) -> bool:
        """Decide whether a certificate should be requested for *domain*.

        Args:
            domain: Host name taken from a route.
            ssl_domains: SNI names already covered by APISIX SSL objects.
                When ``None``, SSL objects are fetched from the Admin API.
                Callers checking many domains should build the set once.

        Returns:
            False for empty/non-string names, already processed names, local
            hosts, IP literals, and names already covered by an SSL object.
            In the last case the domain is also added to
            ``processed_domains``. Otherwise True.
        """
        if not isinstance(domain, str) or not domain:
            return False
        if domain in self.processed_domains:
            return False
        if domain in self._LOCAL_HOSTS:
            return False
        if self._is_ip_address(domain):
            return False

        if ssl_domains is None:
            ssl_domains = set()
            for ssl in self.get_all_ssls():
                ssl_domains |= self.extract_domains_from_ssl(ssl)

        if domain in ssl_domains:
            logger.info(f"域名已有 SSL 配置: {domain}")
            self.processed_domains.add(domain)
            return False
        return True

    def process_new_domains(self):
        """Run one polling cycle: find new route hosts and request certs.

        Disabled routes (``status == 0``) are ignored. SSL objects are
        fetched at most once per cycle, and only when there is at least one
        unprocessed host. If that fetch fails the cycle is skipped. Treating
        the failure as "no certificates exist" would re-request certificates
        for every host.
        """
        route_domains: Set[str] = set()
        for route in self.get_all_routes():
            route_value = route.get('value') or {}
            # status == 0 marks a disabled route.
            if route_value.get('status') == 0:
                continue
            route_domains |= self.extract_domains_from_route(route)

        candidates = sorted(d for d in route_domains if d not in self.processed_domains)
        if not candidates:
            return

        ssl_list = self._fetch_admin_list('ssls', ' SSL 配置')
        if ssl_list is None:
            return
        ssl_domains: Set[str] = set()
        for ssl in ssl_list:
            ssl_domains |= self.extract_domains_from_ssl(ssl)

        processed_before = len(self.processed_domains)
        for domain in candidates:
            if not self.should_request_cert(domain, ssl_domains):
                continue
            logger.info(f"发现新域名,准备申请证书: {domain}")
            try:
                if self.ssl_manager.request_certificate(domain):
                    logger.info(f"证书申请成功: {domain}")
                    self.processed_domains.add(domain)
                    # Persist immediately so a crash cannot lose the record.
                    self._save_processed_domains()
                else:
                    logger.error(f"证书申请失败: {domain}")
            except Exception as e:
                logger.error(f"处理域名异常 {domain}: {e}")

        # should_request_cert also marks domains that already have an SSL
        # object; persist those so they survive a restart.
        if len(self.processed_domains) != processed_before:
            self._save_processed_domains()

    def run(self, interval: int = 60):
        """Poll forever, sleeping *interval* seconds between cycles.

        Exceptions from a cycle are logged and polling continues. Ctrl-C
        (KeyboardInterrupt) stops the loop cleanly at any point, including
        during a sleep.
        """
        logger.info(f"路由监听服务启动,检查间隔: {interval}")
        try:
            while True:
                try:
                    self.process_new_domains()
                except Exception as e:
                    logger.error(f"监听服务异常: {e}")
                time.sleep(interval)
        except KeyboardInterrupt:
            logger.info("收到停止信号,退出服务")
def main(argv=None):
    """Command-line entry point for the route watcher.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``
            (argparse's default), so existing invocations are unchanged.
    """
    import argparse
    parser = argparse.ArgumentParser(description='APISIX 路由监听服务')
    parser.add_argument('--interval', '-i', type=int, default=60,
                        help='检查间隔(秒),默认 60')
    parser.add_argument('--config', '-c', help='配置文件路径(可选,用于覆盖默认配置)')
    parser.add_argument('--once', action='store_true',
                        help='只执行一次,不持续监听')
    args = parser.parse_args(argv)
    # In continuous mode a non-positive interval is fatal. 0 hammers the
    # Admin API in a tight loop; a negative value makes time.sleep() raise
    # ValueError from the loop's error handler, which kills the service.
    if not args.once and args.interval <= 0:
        parser.error('--interval 必须为正整数')
    watcher = RouteWatcher(args.config)
    if args.once:
        watcher.process_new_domains()
    else:
        watcher.run(args.interval)


if __name__ == '__main__':
    main()