#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ APISIX 路由监听服务 监听路由创建事件,自动为域名申请 SSL 证书 """ import os import sys import time import logging import requests import ipaddress from typing import Set, Optional, Dict, List from datetime import datetime sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) from ssl_manager import APISIXSSLManager, RateLimitError # 配置日志 logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', handlers=[ logging.FileHandler('/var/log/apisix-route-watcher.log'), logging.StreamHandler(sys.stdout) ] ) logger = logging.getLogger(__name__) class RouteWatcher: """路由监听器""" def __init__(self, config_path: str = None): """初始化监听器""" self.ssl_manager = APISIXSSLManager(config_path) # 从环境变量或配置获取 APISIX 配置 self.apisix_admin_url = os.getenv('APISIX_ADMIN_URL', 'http://localhost:9180') self.apisix_admin_key = os.getenv('APISIX_ADMIN_KEY', '8206e6e42b6b53243c52a767cc633137') # 创建 HTTP 会话,复用连接 self.session = requests.Session() self.session.headers.update({ 'X-API-KEY': self.apisix_admin_key, 'Content-Type': 'application/json' }) # 速率限制记录:{domain: retry_after_timestamp} self.rate_limited_domains: Dict[str, float] = {} def get_all_routes(self) -> list: """获取所有路由""" try: response = self.session.get( f"{self.apisix_admin_url}/apisix/admin/routes", timeout=10 ) if response.status_code == 200: data = response.json() return data.get('list', []) else: logger.error(f"获取路由失败: {response.status_code}") return [] except Exception as e: logger.error(f"获取路由异常: {e}") return [] def get_all_ssls(self) -> list: """获取所有 SSL 配置""" try: response = self.session.get( f"{self.apisix_admin_url}/apisix/admin/ssls", timeout=10 ) if response.status_code == 200: data = response.json() return data.get('list', []) else: logger.error(f"获取 SSL 配置失败: {response.status_code}") return [] except Exception as e: logger.error(f"获取 SSL 配置异常: {e}") return [] def extract_domains_from_route(self, route: dict) -> Set[str]: """从路由中提取域名""" domains = set() route_value = route.get('value', {}) # 从 host 字段提取(单个域名,Dashboard 常用) host = route_value.get('host') if host and isinstance(host, str): domains.add(host) # 从 hosts 字段提取(域名数组) hosts = route_value.get('hosts', []) if hosts: if isinstance(hosts, list): domains.update(hosts) elif isinstance(hosts, str): domains.add(hosts) # 从 uri 字段提取(如果包含域名) uri = route_value.get('uri', '') if uri and '.' in uri and not uri.startswith('/'): # 可能是域名格式 parts = uri.split('/') if parts[0] and '.' in parts[0]: domains.add(parts[0]) # 从 match 字段提取 match = route_value.get('match', {}) if isinstance(match, dict): for key, value in match.items(): if 'host' in key.lower() and isinstance(value, str): domains.add(value) return domains def extract_domains_from_ssl(self, ssl: dict) -> Set[str]: """从 SSL 配置中提取域名""" domains = set() snis = ssl.get('value', {}).get('snis', []) if snis: domains.update(snis) return domains def _is_valid_domain(self, domain: str) -> bool: """检查是否为有效域名(非 IP 地址和本地域名)""" # 跳过本地域名 if domain in ['localhost', '127.0.0.1', '0.0.0.0']: return False # 检查是否为 IP 地址(支持 IPv4 和 IPv6) try: ipaddress.ip_address(domain) return False except ValueError: pass return True def _build_ssl_domains_set(self, ssls: list) -> Set[str]: """构建所有已配置 SSL 的域名集合(用于快速查找)""" ssl_domains = set() for ssl in ssls: domains = self.extract_domains_from_ssl(ssl) ssl_domains.update(domains) return ssl_domains def should_request_cert(self, domain: str, existing_ssl_domains: Set[str]) -> bool: """判断是否需要申请证书 Args: domain: 要检查的域名 existing_ssl_domains: 已存在的 SSL 域名集合(用于快速查找) """ # 检查是否为有效域名 if not self._is_valid_domain(domain): return False # 检查是否已有 SSL 配置 if domain in existing_ssl_domains: logger.info(f"域名已有 SSL 配置: {domain}") return False # 检查是否在速率限制期间 if domain in self.rate_limited_domains: retry_after = self.rate_limited_domains[domain] current_time = time.time() if current_time < retry_after: remaining_minutes = int((retry_after - current_time) / 60) + 1 logger.debug(f"域名 {domain} 仍在速率限制期间,剩余约 {remaining_minutes} 分钟后重试") return False else: # 限制已解除,移除记录 del self.rate_limited_domains[domain] logger.info(f"域名 {domain} 速率限制已解除,将重新尝试申请证书") return True def _handle_cert_request(self, primary_domain: str, additional_domains: List[str] = None): """处理证书申请(单个或多个域名) Args: primary_domain: 主域名 additional_domains: 额外域名列表(可选) Returns: bool: 是否成功 """ domains_list = [primary_domain] + (additional_domains or []) if additional_domains: total_domains = len(additional_domains) + 1 logger.info(f"发现同一路由中的多个域名,合并申请证书: {primary_domain} + {additional_domains}") else: logger.info(f"发现新域名,准备申请证书: {primary_domain}") try: if self.ssl_manager.request_certificate(primary_domain, additional_domains): if additional_domains: logger.info(f"证书申请成功: {primary_domain} (包含 {len(additional_domains) + 1} 个域名)") else: logger.info(f"证书申请成功: {primary_domain}") # 申请成功,清除速率限制记录 for d in domains_list: self.rate_limited_domains.pop(d, None) return True else: if additional_domains: logger.error(f"证书申请失败: {primary_domain} + {additional_domains}") else: logger.error(f"证书申请失败: {primary_domain}") return False except RateLimitError as e: logger.warning(f"域名 {e.domain} 遇到速率限制,将在 {datetime.fromtimestamp(e.retry_after_timestamp).strftime('%Y-%m-%d %H:%M:%S')} 后自动重试") # 记录速率限制的域名和重试时间(所有相关域名) for d in domains_list: self.rate_limited_domains[d] = e.retry_after_timestamp return False except Exception as e: if additional_domains: logger.error(f"处理域名异常 {primary_domain} + {additional_domains}: {e}") else: logger.error(f"处理域名异常 {primary_domain}: {e}") return False def process_new_domains(self): """处理新域名""" # 优化:只获取一次路由和 SSL 配置,避免重复 API 调用 routes = self.get_all_routes() ssls = self.get_all_ssls() # 构建已存在的 SSL 域名集合(用于快速查找) existing_ssl_domains = self._build_ssl_domains_set(ssls) # 按路由处理,同一路由的多个域名合并到一个证书 for route in routes: route_value = route.get('value', {}) # 跳过禁用的路由(status=0 表示禁用) if route_value.get('status') == 0: continue # 从路由中提取所有域名 domains = self.extract_domains_from_route(route) if not domains: continue # 过滤出需要申请证书的域名(使用缓存的 SSL 域名集合) domains_to_request = [d for d in domains if self.should_request_cert(d, existing_ssl_domains)] if not domains_to_request: continue # 处理证书申请(单个或多个域名) primary_domain = domains_to_request[0] additional_domains = domains_to_request[1:] if len(domains_to_request) > 1 else None self._handle_cert_request(primary_domain, additional_domains) def run(self, interval: int = 60): """运行监听服务""" logger.info(f"路由监听服务启动,检查间隔: {interval} 秒") while True: try: self.process_new_domains() time.sleep(interval) except KeyboardInterrupt: logger.info("收到停止信号,退出服务") break except Exception as e: logger.error(f"监听服务异常: {e}") time.sleep(interval) def main(): """主函数""" import argparse parser = argparse.ArgumentParser(description='APISIX 路由监听服务') parser.add_argument('--interval', '-i', type=int, default=60, help='检查间隔(秒),默认 60') parser.add_argument('--config', '-c', help='配置文件路径(可选,用于覆盖默认配置)') parser.add_argument('--once', action='store_true', help='只执行一次,不持续监听') args = parser.parse_args() watcher = RouteWatcher(args.config) if args.once: watcher.process_new_domains() else: watcher.run(args.interval) if __name__ == '__main__': main()