#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ APISIX 路由监听服务 监听路由创建事件,自动为域名申请 SSL 证书 """ import os import sys import json import time import logging import requests from typing import Set, Optional import sys sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) from ssl_manager import APISIXSSLManager # 配置日志 logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', handlers=[ logging.FileHandler('/var/log/apisix-route-watcher.log'), logging.StreamHandler(sys.stdout) ] ) logger = logging.getLogger(__name__) class RouteWatcher: """路由监听器""" def __init__(self, config_path: str = None): """初始化监听器""" self.ssl_manager = APISIXSSLManager(config_path) # 从环境变量或配置获取 APISIX 配置 self.apisix_admin_url = os.getenv('APISIX_ADMIN_URL', 'http://localhost:9180') self.apisix_admin_key = os.getenv('APISIX_ADMIN_KEY', '8206e6e42b6b53243c52a767cc633137') # 已处理的域名集合 self.processed_domains: Set[str] = set() # 加载已处理的域名 self._load_processed_domains() def _get_apisix_headers(self): """获取 APISIX Admin API 请求头""" return { 'X-API-KEY': self.apisix_admin_key, 'Content-Type': 'application/json' } def _load_processed_domains(self): """加载已处理的域名列表""" state_file = '/var/lib/apisix-ssl-manager/processed_domains.json' if os.path.exists(state_file): try: with open(state_file, 'r') as f: self.processed_domains = set(json.load(f)) logger.info(f"加载已处理域名: {len(self.processed_domains)} 个") except Exception as e: logger.warning(f"加载已处理域名失败: {e}") def _save_processed_domains(self): """保存已处理的域名列表""" state_file = '/var/lib/apisix-ssl-manager/processed_domains.json' os.makedirs(os.path.dirname(state_file), exist_ok=True) try: with open(state_file, 'w') as f: json.dump(list(self.processed_domains), f) except Exception as e: logger.error(f"保存已处理域名失败: {e}") def get_all_routes(self) -> list: """获取所有路由""" try: response = requests.get( f"{self.apisix_admin_url}/apisix/admin/routes", headers=self._get_apisix_headers(), timeout=10 ) if response.status_code == 200: data = response.json() return data.get('list', []) else: logger.error(f"获取路由失败: {response.status_code}") return [] except Exception as e: logger.error(f"获取路由异常: {e}") return [] def get_all_ssls(self) -> list: """获取所有 SSL 配置""" try: response = requests.get( f"{self.apisix_admin_url}/apisix/admin/ssls", headers=self._get_apisix_headers(), timeout=10 ) if response.status_code == 200: data = response.json() return data.get('list', []) else: logger.error(f"获取 SSL 配置失败: {response.status_code}") return [] except Exception as e: logger.error(f"获取 SSL 配置异常: {e}") return [] def extract_domains_from_route(self, route: dict) -> Set[str]: """从路由中提取域名""" domains = set() route_value = route.get('value', {}) # 从 host 字段提取(单个域名,Dashboard 常用) host = route_value.get('host') if host and isinstance(host, str): domains.add(host) # 从 hosts 字段提取(域名数组) hosts = route_value.get('hosts', []) if hosts: if isinstance(hosts, list): domains.update(hosts) elif isinstance(hosts, str): domains.add(hosts) # 从 uri 字段提取(如果包含域名) uri = route_value.get('uri', '') if uri and '.' in uri and not uri.startswith('/'): # 可能是域名格式 parts = uri.split('/') if parts[0] and '.' in parts[0]: domains.add(parts[0]) # 从 match 字段提取 match = route_value.get('match', {}) if isinstance(match, dict): for key, value in match.items(): if 'host' in key.lower() and isinstance(value, str): domains.add(value) return domains def extract_domains_from_ssl(self, ssl: dict) -> Set[str]: """从 SSL 配置中提取域名""" domains = set() snis = ssl.get('value', {}).get('snis', []) if snis: domains.update(snis) return domains def should_request_cert(self, domain: str) -> bool: """判断是否需要申请证书""" # 跳过已处理的域名 if domain in self.processed_domains: return False # 跳过本地域名 if domain in ['localhost', '127.0.0.1', '0.0.0.0']: return False # 跳过 IP 地址 if domain.replace('.', '').isdigit(): return False # 检查是否已有 SSL 配置 ssls = self.get_all_ssls() for ssl in ssls: ssl_domains = self.extract_domains_from_ssl(ssl) if domain in ssl_domains: logger.info(f"域名已有 SSL 配置: {domain}") self.processed_domains.add(domain) return False return True def process_new_domains(self): """处理新域名""" routes = self.get_all_routes() new_domains = set() # 从路由中提取所有域名 for route in routes: domains = self.extract_domains_from_route(route) new_domains.update(domains) # 处理需要申请证书的域名 for domain in new_domains: if self.should_request_cert(domain): logger.info(f"发现新域名,准备申请证书: {domain}") try: if self.ssl_manager.request_certificate(domain): logger.info(f"证书申请成功: {domain}") self.processed_domains.add(domain) self._save_processed_domains() else: logger.error(f"证书申请失败: {domain}") except Exception as e: logger.error(f"处理域名异常 {domain}: {e}") def run(self, interval: int = 60): """运行监听服务""" logger.info(f"路由监听服务启动,检查间隔: {interval} 秒") while True: try: self.process_new_domains() time.sleep(interval) except KeyboardInterrupt: logger.info("收到停止信号,退出服务") break except Exception as e: logger.error(f"监听服务异常: {e}") time.sleep(interval) def main(): """主函数""" import argparse parser = argparse.ArgumentParser(description='APISIX 路由监听服务') parser.add_argument('--interval', '-i', type=int, default=60, help='检查间隔(秒),默认 60') parser.add_argument('--config', '-c', help='配置文件路径(可选,用于覆盖默认配置)') parser.add_argument('--once', action='store_true', help='只执行一次,不持续监听') args = parser.parse_args() watcher = RouteWatcher(args.config) if args.once: watcher.process_new_domains() else: watcher.run(args.interval) if __name__ == '__main__': main()