diff --git a/ssl_manager/__pycache__/ssl_manager.cpython-313.pyc b/ssl_manager/__pycache__/ssl_manager.cpython-313.pyc index 1a09d1e..c8800d3 100644 Binary files a/ssl_manager/__pycache__/ssl_manager.cpython-313.pyc and b/ssl_manager/__pycache__/ssl_manager.cpython-313.pyc differ diff --git a/ssl_manager/__pycache__/test_ssl_auto.cpython-313.pyc b/ssl_manager/__pycache__/test_ssl_auto.cpython-313.pyc deleted file mode 100644 index 8553bf0..0000000 Binary files a/ssl_manager/__pycache__/test_ssl_auto.cpython-313.pyc and /dev/null differ diff --git a/ssl_manager/ssl_manager.py b/ssl_manager/ssl_manager.py index 754399d..12feeb0 100755 --- a/ssl_manager/ssl_manager.py +++ b/ssl_manager/ssl_manager.py @@ -11,6 +11,8 @@ APISIX SSL 证书自动管理脚本 import os import sys import json +import re +import time import subprocess import requests import logging @@ -38,7 +40,7 @@ DEFAULT_CONFIG = { 'certbot_path': '/usr/bin/certbot', 'cert_dir': '/etc/letsencrypt/live', 'letsencrypt_email': 'admin@jingrowtools.cn', - 'letsencrypt_staging': False, # 默认使用 staging 模式,生产环境改为 False + 'letsencrypt_staging': True, # 默认使用 staging 模式,生产环境改为 False 'webroot_path': '/var/www/certbot' } @@ -90,6 +92,83 @@ class APISIXSSLManager: 'Content-Type': 'application/json' } + def extract_domains_from_cert(self, cert_content: str) -> List[str]: + """从证书内容中提取所有域名(包括 Subject Alternative Names)""" + domains = [] + try: + # 确保 cert_content 是字符串类型(因为 text=True) + if isinstance(cert_content, bytes): + cert_str = cert_content.decode('utf-8') + else: + cert_str = cert_content + + # 使用 openssl 命令提取证书信息 + # 注意:text=True 时,input 应该是字符串 + result = subprocess.run( + ['openssl', 'x509', '-noout', '-text', '-in', '/dev/stdin'], + input=cert_str, + capture_output=True, + text=True, + check=True, + timeout=10 + ) + + cert_text = result.stdout + + # 提取 Subject CN + cn_match = re.search(r'Subject:.*?CN\s*=\s*([^\s,/\n]+)', cert_text) + if cn_match: + domains.append(cn_match.group(1)) + + # 提取 Subject Alternative Names + # 匹配 SAN 部分,包括多行情况 + san_pattern = r'X509v3 Subject Alternative Name:\s*\n\s*((?:DNS:[^\n]+(?:\n\s+[^\n]+)*))' + san_section = re.search(san_pattern, cert_text, re.MULTILINE) + if san_section: + san_text = san_section.group(1) + # 匹配所有 DNS: 后面的域名(支持跨行) + dns_matches = re.findall(r'DNS:([^\s,/\n]+)', san_text) + domains.extend(dns_matches) + + # 如果上面的方法没匹配到,尝试更宽松的匹配 + if not domains or (san_section is None and 'Subject Alternative Name' in cert_text): + # 直接在 SAN 部分查找所有 DNS 条目 + san_start = cert_text.find('X509v3 Subject Alternative Name:') + if san_start != -1: + # 找到 SAN 部分,提取接下来的几行 + san_end = cert_text.find('\n\n', san_start) + if san_end == -1: + san_end = min(san_start + 500, len(cert_text)) # 最多取500字符 + san_block = cert_text[san_start:san_end] + # 匹配所有 DNS: 条目 + dns_matches = re.findall(r'DNS:([^\s,/\n]+)', san_block) + if dns_matches: + domains.extend(dns_matches) + + # 去重并保持顺序 + seen = set() + unique_domains = [] + for domain in domains: + if domain and domain not in seen: + seen.add(domain) + unique_domains.append(domain) + + if unique_domains: + logger.info(f"从证书中提取到域名: {unique_domains}") + else: + logger.warning("未能从证书中提取域名") + + return unique_domains + + except subprocess.CalledProcessError as e: + logger.error(f"提取证书域名失败: {e.stderr}") + return [] + except Exception as e: + logger.error(f"提取证书域名异常: {e}") + import traceback + logger.error(f"异常堆栈: {traceback.format_exc()}") + return [] + def read_cert_files(self, domain: str) -> Optional[Dict[str, str]]: """读取证书文件""" domain_cert_dir = Path(self.cert_dir) / domain @@ -117,36 +196,73 @@ class APISIXSSLManager: def upload_cert_to_apisix(self, domain: str, cert_content: str, key_content: str) -> bool: """将证书上传到 APISIX""" - # 生成 SSL ID(使用域名作为 ID) + # 从证书中提取所有域名(包括 SAN) + cert_domains = self.extract_domains_from_cert(cert_content) + + # 如果没有提取到域名,使用传入的 domain 作为后备 + if not cert_domains: + logger.warning(f"无法从证书提取域名,使用传入的域名: {domain}") + cert_domains = [domain] + + # 确保传入的 domain 也在列表中(如果不在的话) + if domain not in cert_domains: + cert_domains.append(domain) + + # 生成 SSL ID(使用主域名作为 ID,用于查找和更新) ssl_id = domain.replace('.', '_').replace('*', 'wildcard') - # 构建 SSL 配置(创建时不包含 id) + # 构建 SSL 配置(创建时不包含 id,更新时需要 id) + # SNI 列表包含证书中的所有域名 ssl_config = { - "snis": [domain], + "snis": cert_domains, "cert": cert_content, "key": key_content } - # 检查是否已存在 - check_url = f"{self.apisix_admin_url}/apisix/admin/ssls/{ssl_id}" + logger.info(f"配置 SNI 域名列表: {cert_domains}") + headers = self._get_apisix_headers() try: - # 先检查是否存在 + # 先检查是否已存在相同 SNI 的配置 + # 方法1:通过 ID 查找(如果之前创建时使用了这个 ID) + check_url = f"{self.apisix_admin_url}/apisix/admin/ssls/{ssl_id}" response = requests.get(check_url, headers=headers, timeout=10) + existing_ssl_id = None if response.status_code == 200: + existing_ssl_id = ssl_id + logger.info(f"找到现有 SSL 配置 (ID: {ssl_id})") + else: + # 方法2:查询所有 SSL 配置,检查是否有相同 SNI 的配置 + all_ssls_url = f"{self.apisix_admin_url}/apisix/admin/ssls" + all_response = requests.get(all_ssls_url, headers=headers, timeout=10) + if all_response.status_code == 200: + all_ssls = all_response.json() + ssl_list = all_ssls.get('list', []) if isinstance(all_ssls, dict) else all_ssls + + # 检查每个 SSL 配置的 SNI 是否匹配 + for ssl_item in ssl_list: + ssl_value = ssl_item.get('value', {}) if isinstance(ssl_item, dict) else ssl_item + existing_snis = ssl_value.get('snis', []) + # 检查 SNI 列表是否相同(忽略顺序) + if set(existing_snis) == set(cert_domains): + existing_ssl_id = ssl_item.get('id') or ssl_item.get('key', {}).get('id') + logger.info(f"找到现有 SSL 配置,SNI 匹配 (ID: {existing_ssl_id})") + break + + if existing_ssl_id: # 更新现有证书(更新时需要 id) - logger.info(f"更新 APISIX SSL 配置: {domain}") - ssl_config["id"] = ssl_id + logger.info(f"更新 APISIX SSL 配置: {domain} (ID: {existing_ssl_id})") + ssl_config["id"] = existing_ssl_id response = requests.put( - f"{self.apisix_admin_url}/apisix/admin/ssls/{ssl_id}", + f"{self.apisix_admin_url}/apisix/admin/ssls/{existing_ssl_id}", headers=headers, json=ssl_config, timeout=10 ) else: - # 创建新证书(创建时不需要 id,APISIX 会自动生成) + # 创建新证书(POST 时不包含 id,让 APISIX 自动生成) logger.info(f"创建 APISIX SSL 配置: {domain}") response = requests.post( f"{self.apisix_admin_url}/apisix/admin/ssls", @@ -166,8 +282,14 @@ class APISIXSSLManager: logger.error(f"上传证书到 APISIX 失败: {e}") return False - def request_certificate(self, domain: str, additional_domains: List[str] = None) -> bool: - """申请 Let's Encrypt 证书""" + def request_certificate(self, domain: str, additional_domains: List[str] = None, max_retries: int = 3) -> bool: + """申请 Let's Encrypt 证书 + + Args: + domain: 主域名 + additional_domains: 额外域名列表 + max_retries: 最大重试次数(默认3次) + """ domains = [domain] if additional_domains: domains.extend(additional_domains) @@ -193,34 +315,73 @@ class APISIXSSLManager: logger.info(f"申请证书: {domain}, 命令: {' '.join(cmd)}") - try: - result = subprocess.run( - cmd, - capture_output=True, - text=True, - check=True, - timeout=300 - ) - - if result.returncode == 0: - logger.info(f"证书申请成功: {domain}") - # 读取证书并上传到 APISIX - cert_data = self.read_cert_files(domain) - if cert_data: - return self.upload_cert_to_apisix(domain, cert_data['cert'], cert_data['key']) - else: - logger.error(f"无法读取证书文件: {domain}") - return False - else: - logger.error(f"证书申请失败: {result.stderr}") - return False + # 重试机制 + for attempt in range(1, max_retries + 1): + try: + if attempt > 1: + logger.info(f"第 {attempt} 次尝试申请证书 (共 {max_retries} 次)...") + time.sleep(5) # 重试前等待5秒 - except subprocess.TimeoutExpired: - logger.error(f"证书申请超时: {domain}") - return False - except Exception as e: - logger.error(f"证书申请异常: {e}") - return False + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=False, # 不自动抛出异常,手动处理 + timeout=300 + ) + + if result.returncode == 0: + logger.info(f"证书申请成功: {domain}") + # 读取证书并上传到 APISIX + cert_data = self.read_cert_files(domain) + if cert_data: + return self.upload_cert_to_apisix(domain, cert_data['cert'], cert_data['key']) + else: + logger.error(f"无法读取证书文件: {domain}") + return False + else: + # 检查是否是网络超时错误 + error_output = result.stderr or "" + is_timeout_error = "ReadTimeout" in error_output or "timed out" in error_output.lower() + + if is_timeout_error and attempt < max_retries: + logger.warning(f"证书申请网络超时 (尝试 {attempt}/{max_retries}),将重试...") + continue + else: + logger.error(f"证书申请失败 (退出码: {result.returncode})") + if result.stdout: + logger.error(f"标准输出: {result.stdout}") + if result.stderr: + logger.error(f"错误输出: {result.stderr}") + + # 如果是网络超时且已尝试所有次数,给出提示 + if is_timeout_error: + logger.error("网络连接超时,可能的原因:") + logger.error("1. 服务器无法访问 Let's Encrypt 服务器 (acme-staging-v02.api.letsencrypt.org 或 acme-v02.api.letsencrypt.org)") + logger.error("2. 防火墙阻止了 HTTPS 连接") + logger.error("3. 网络不稳定,建议稍后重试") + logger.error("4. 可以检查网络连接: curl -I https://acme-staging-v02.api.letsencrypt.org/directory") + + return False + + except subprocess.TimeoutExpired: + if attempt < max_retries: + logger.warning(f"证书申请超时 (尝试 {attempt}/{max_retries}),将重试...") + continue + else: + logger.error(f"证书申请超时: {domain} (已尝试 {max_retries} 次)") + return False + except Exception as e: + if attempt < max_retries: + logger.warning(f"证书申请异常 (尝试 {attempt}/{max_retries}): {e},将重试...") + continue + else: + logger.error(f"证书申请异常: {e}") + import traceback + logger.error(f"异常堆栈: {traceback.format_exc()}") + return False + + return False def renew_certificate(self, domain: str) -> bool: """续期证书""" diff --git a/ssl_manager/test_ssl_auto.py b/ssl_manager/test_ssl_auto.py index 17a2b7c..072754a 100755 --- a/ssl_manager/test_ssl_auto.py +++ b/ssl_manager/test_ssl_auto.py @@ -331,15 +331,21 @@ class SSLTestRunner: print_error(f"测试验证路径异常: {e}") return False - def create_test_route(self, domain: str) -> bool: + def create_test_route(self, domain: str, additional_domains: list = None) -> bool: """创建测试路由""" print_info(f"创建测试路由: {domain}") + # 构建域名列表(主域名 + 额外域名) + hosts = [domain] + if additional_domains: + hosts.extend(additional_domains) + print_info(f"路由将包含域名: {', '.join(hosts)}") + route_config = { "uri": "/*", "name": domain, "methods": ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"], - "host": domain, + "hosts": hosts, # 使用 hosts 数组支持多个域名 "upstream": { "nodes": [ { @@ -373,44 +379,22 @@ class SSLTestRunner: print_error(f"创建测试路由异常: {e}") return False - def request_certificate(self, domain: str) -> bool: + def request_certificate(self, domain: str, additional_domains: list = None) -> bool: """申请证书""" - print_info(f"申请证书: {domain} (staging={self.staging})") - - cmd = [ - self.ssl_manager.certbot_path, - 'certonly', - '--webroot', - '--webroot-path', self.webroot_path, - '--non-interactive', - '--agree-tos', - '--email', self.email, - '--cert-name', domain, - '-d', domain - ] - - if self.staging: - cmd.append('--staging') + if additional_domains: + print_info(f"申请证书: {domain} + {additional_domains} (staging={self.staging})") + else: + print_info(f"申请证书: {domain} (staging={self.staging})") + # 使用 ssl_manager 的 request_certificate 方法,它已经支持额外域名 try: - print_info(f"执行命令: {' '.join(cmd)}") - result = subprocess.run( - cmd, - capture_output=True, - text=True, - timeout=300 - ) - - if result.returncode == 0: + result = self.ssl_manager.request_certificate(domain, additional_domains) + if result: print_success(f"证书申请成功: {domain}") - print_info(result.stdout) return True else: - print_error(f"证书申请失败: {result.stderr}") + print_error(f"证书申请失败: {domain}") return False - except subprocess.TimeoutExpired: - print_error("证书申请超时") - return False except Exception as e: print_error(f"证书申请异常: {e}") return False @@ -504,7 +488,7 @@ class SSLTestRunner: except: pass - def run_full_test(self, domain: str = None, cleanup: bool = False): + def run_full_test(self, domain: str = None, additional_domains: list = None, cleanup: bool = False): """运行完整测试""" if not domain: domain = self.test_domain @@ -512,6 +496,8 @@ class SSLTestRunner: print(f"\n{Colors.BOLD}{'='*60}") print(f"APISIX SSL 证书自动申请测试") print(f"测试域名: {domain}") + if additional_domains: + print(f"额外域名: {', '.join(additional_domains)}") print(f"Staging 模式: {self.staging}") print(f"{'='*60}{Colors.RESET}\n") @@ -521,10 +507,10 @@ class SSLTestRunner: (3, "检查 Webroot 目录", lambda: self.check_webroot_directory()), (4, "检查/创建 Webroot 路由", lambda: self.check_webroot_route(domain)), (5, "测试验证路径", lambda: self.test_verification_path(domain)), - (6, "创建测试路由", lambda: self.create_test_route(domain)), - (7, "申请 SSL 证书", lambda: self.request_certificate(domain)), - (8, "同步证书到 APISIX", lambda: self.sync_certificate_to_apisix(domain)), - (9, "验证证书信息", lambda: self.verify_certificate(domain)), + (6, "创建测试路由", lambda: self.create_test_route(domain, additional_domains)), + (7, "申请 SSL 证书", lambda: self.request_certificate(domain, additional_domains)), + # 注意:证书申请成功后会自动上传到 APISIX,不需要单独同步步骤 + (8, "验证证书信息", lambda: self.verify_certificate(domain)), ] success_count = 0 @@ -564,6 +550,7 @@ def main(): parser = argparse.ArgumentParser(description='APISIX SSL 证书自动申请测试脚本') parser.add_argument('--domain', '-d', help='测试域名(不指定则自动生成)') + parser.add_argument('--additional-domains', '-a', nargs='+', help='额外域名(如 www 子域名)') parser.add_argument('--config', '-c', help='配置文件路径(可选,用于覆盖默认配置)') parser.add_argument('--cleanup', action='store_true', help='测试完成后清理测试数据') parser.add_argument('--no-cleanup', action='store_true', help='测试完成后不清理测试数据') @@ -579,9 +566,18 @@ def main(): print_warning(f"未指定域名,使用自动生成的测试域名: {domain}") print_info("注意:此域名需要 DNS 解析到当前服务器才能申请证书") + # 处理额外域名 + additional_domains = [] + if args.additional_domains: + additional_domains.extend(args.additional_domains) + cleanup = args.cleanup or (not args.no_cleanup and not args.domain) - success = runner.run_full_test(domain, cleanup=cleanup) + success = runner.run_full_test( + domain, + additional_domains=additional_domains if additional_domains else None, + cleanup=cleanup + ) sys.exit(0 if success else 1)