#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ APISIX 路由监听服务 监听路由创建事件,自动为域名申请 SSL 证书 """ import os import sys import json import time import logging import requests import ipaddress from typing import Set, Optional, Dict import sys sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) from ssl_manager import APISIXSSLManager # 配置日志 logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', handlers=[ logging.FileHandler('/var/log/apisix-route-watcher.log'), logging.StreamHandler(sys.stdout) ] ) logger = logging.getLogger(__name__) class RouteWatcher: """路由监听器""" def __init__(self, config_path: str = None): """初始化监听器""" self.ssl_manager = APISIXSSLManager(config_path) # 从环境变量或配置获取 APISIX 配置 self.apisix_admin_url = os.getenv('APISIX_ADMIN_URL', 'http://localhost:9180') self.apisix_admin_key = os.getenv('APISIX_ADMIN_KEY', '8206e6e42b6b53243c52a767cc633137') # 创建 HTTP 会话,复用连接 self.session = requests.Session() self.session.headers.update({ 'X-API-KEY': self.apisix_admin_key, 'Content-Type': 'application/json' }) def _get_apisix_headers(self): """获取 APISIX Admin API 请求头""" return { 'X-API-KEY': self.apisix_admin_key, 'Content-Type': 'application/json' } def get_all_routes(self) -> list: """获取所有路由""" try: response = requests.get( f"{self.apisix_admin_url}/apisix/admin/routes", headers=self._get_apisix_headers(), timeout=10 ) if response.status_code == 200: data = response.json() return data.get('list', []) else: logger.error(f"获取路由失败: {response.status_code}") return [] except Exception as e: logger.error(f"获取路由异常: {e}") return [] def get_all_ssls(self) -> list: """获取所有 SSL 配置""" try: response = requests.get( f"{self.apisix_admin_url}/apisix/admin/ssls", headers=self._get_apisix_headers(), timeout=10 ) if response.status_code == 200: data = response.json() return data.get('list', []) else: logger.error(f"获取 SSL 配置失败: {response.status_code}") return [] except Exception as e: logger.error(f"获取 SSL 配置异常: {e}") return [] def extract_domains_from_route(self, route: dict) -> Set[str]: """从路由中提取域名""" domains = set() route_value = route.get('value', {}) # 从 host 字段提取(单个域名,Dashboard 常用) host = route_value.get('host') if host and isinstance(host, str): domains.add(host) # 从 hosts 字段提取(域名数组) hosts = route_value.get('hosts', []) if hosts: if isinstance(hosts, list): domains.update(hosts) elif isinstance(hosts, str): domains.add(hosts) # 从 uri 字段提取(如果包含域名) uri = route_value.get('uri', '') if uri and '.' in uri and not uri.startswith('/'): # 可能是域名格式 parts = uri.split('/') if parts[0] and '.' in parts[0]: domains.add(parts[0]) # 从 match 字段提取 match = route_value.get('match', {}) if isinstance(match, dict): for key, value in match.items(): if 'host' in key.lower() and isinstance(value, str): domains.add(value) return domains def extract_domains_from_ssl(self, ssl: dict) -> Set[str]: """从 SSL 配置中提取域名""" domains = set() snis = ssl.get('value', {}).get('snis', []) if snis: domains.update(snis) return domains def _is_valid_domain(self, domain: str) -> bool: """检查是否为有效域名(非 IP 地址和本地域名)""" # 跳过本地域名 if domain in ['localhost', '127.0.0.1', '0.0.0.0']: return False # 检查是否为 IP 地址(支持 IPv4 和 IPv6) try: ipaddress.ip_address(domain) return False except ValueError: pass return True def _build_ssl_domains_set(self, ssls: list) -> Set[str]: """构建所有已配置 SSL 的域名集合(用于快速查找)""" ssl_domains = set() for ssl in ssls: domains = self.extract_domains_from_ssl(ssl) ssl_domains.update(domains) return ssl_domains def should_request_cert(self, domain: str, existing_ssl_domains: Set[str]) -> bool: """判断是否需要申请证书 Args: domain: 要检查的域名 existing_ssl_domains: 已存在的 SSL 域名集合(用于快速查找) """ # 检查是否为有效域名 if not self._is_valid_domain(domain): return False # 检查是否已有 SSL 配置 if domain in existing_ssl_domains: logger.info(f"域名已有 SSL 配置: {domain}") return False return True def process_new_domains(self): """处理新域名""" # 优化:只获取一次路由和 SSL 配置,避免重复 API 调用 routes = self.get_all_routes() ssls = self.get_all_ssls() # 构建已存在的 SSL 域名集合(用于快速查找) existing_ssl_domains = self._build_ssl_domains_set(ssls) # 按路由处理,同一路由的多个域名合并到一个证书 for route in routes: route_value = route.get('value', {}) # 跳过禁用的路由(status=0 表示禁用) if route_value.get('status') == 0: continue # 从路由中提取所有域名 domains = self.extract_domains_from_route(route) if not domains: continue # 过滤出需要申请证书的域名(使用缓存的 SSL 域名集合) domains_to_request = [d for d in domains if self.should_request_cert(d, existing_ssl_domains)] if not domains_to_request: continue # 如果只有一个域名,单独申请 if len(domains_to_request) == 1: domain = domains_to_request[0] logger.info(f"发现新域名,准备申请证书: {domain}") try: if self.ssl_manager.request_certificate(domain): logger.info(f"证书申请成功: {domain}") else: logger.error(f"证书申请失败: {domain}") except Exception as e: logger.error(f"处理域名异常 {domain}: {e}") else: # 多个域名,合并到一个证书申请(使用 SAN) primary_domain = domains_to_request[0] additional_domains = domains_to_request[1:] total_domains = len(additional_domains) + 1 logger.info(f"发现同一路由中的多个域名,合并申请证书: {primary_domain} + {additional_domains}") try: if self.ssl_manager.request_certificate(primary_domain, additional_domains): logger.info(f"证书申请成功: {primary_domain} (包含 {total_domains} 个域名)") else: logger.error(f"证书申请失败: {primary_domain} + {additional_domains}") except Exception as e: logger.error(f"处理域名异常 {primary_domain} + {additional_domains}: {e}") def run(self, interval: int = 60): """运行监听服务""" logger.info(f"路由监听服务启动,检查间隔: {interval} 秒") while True: try: self.process_new_domains() time.sleep(interval) except KeyboardInterrupt: logger.info("收到停止信号,退出服务") break except Exception as e: logger.error(f"监听服务异常: {e}") time.sleep(interval) def main(): """主函数""" import argparse parser = argparse.ArgumentParser(description='APISIX 路由监听服务') parser.add_argument('--interval', '-i', type=int, default=60, help='检查间隔(秒),默认 60') parser.add_argument('--config', '-c', help='配置文件路径(可选,用于覆盖默认配置)') parser.add_argument('--once', action='store_true', help='只执行一次,不持续监听') args = parser.parse_args() watcher = RouteWatcher(args.config) if args.once: watcher.process_new_domains() else: watcher.run(args.interval) if __name__ == '__main__': main()