ssl_manager增加支持额外域名

This commit is contained in:
jingrow 2026-01-01 18:28:10 +00:00
parent 88f5127d9b
commit 5bd9e95e15
4 changed files with 236 additions and 79 deletions

View File

@ -11,6 +11,8 @@ APISIX SSL 证书自动管理脚本
import os import os
import sys import sys
import json import json
import re
import time
import subprocess import subprocess
import requests import requests
import logging import logging
@ -38,7 +40,7 @@ DEFAULT_CONFIG = {
'certbot_path': '/usr/bin/certbot', 'certbot_path': '/usr/bin/certbot',
'cert_dir': '/etc/letsencrypt/live', 'cert_dir': '/etc/letsencrypt/live',
'letsencrypt_email': 'admin@jingrowtools.cn', 'letsencrypt_email': 'admin@jingrowtools.cn',
'letsencrypt_staging': False, # 默认使用 staging 模式,生产环境改为 False 'letsencrypt_staging': True, # 默认使用 staging 模式,生产环境改为 False
'webroot_path': '/var/www/certbot' 'webroot_path': '/var/www/certbot'
} }
@ -90,6 +92,83 @@ class APISIXSSLManager:
'Content-Type': 'application/json' 'Content-Type': 'application/json'
} }
def extract_domains_from_cert(self, cert_content: str) -> List[str]:
"""从证书内容中提取所有域名(包括 Subject Alternative Names"""
domains = []
try:
# 确保 cert_content 是字符串类型(因为 text=True
if isinstance(cert_content, bytes):
cert_str = cert_content.decode('utf-8')
else:
cert_str = cert_content
# 使用 openssl 命令提取证书信息
# 注意text=True 时input 应该是字符串
result = subprocess.run(
['openssl', 'x509', '-noout', '-text', '-in', '/dev/stdin'],
input=cert_str,
capture_output=True,
text=True,
check=True,
timeout=10
)
cert_text = result.stdout
# 提取 Subject CN
cn_match = re.search(r'Subject:.*?CN\s*=\s*([^\s,/\n]+)', cert_text)
if cn_match:
domains.append(cn_match.group(1))
# 提取 Subject Alternative Names
# 匹配 SAN 部分,包括多行情况
san_pattern = r'X509v3 Subject Alternative Name:\s*\n\s*((?:DNS:[^\n]+(?:\n\s+[^\n]+)*))'
san_section = re.search(san_pattern, cert_text, re.MULTILINE)
if san_section:
san_text = san_section.group(1)
# 匹配所有 DNS: 后面的域名(支持跨行)
dns_matches = re.findall(r'DNS:([^\s,/\n]+)', san_text)
domains.extend(dns_matches)
# 如果上面的方法没匹配到,尝试更宽松的匹配
if not domains or (san_section is None and 'Subject Alternative Name' in cert_text):
# 直接在 SAN 部分查找所有 DNS 条目
san_start = cert_text.find('X509v3 Subject Alternative Name:')
if san_start != -1:
# 找到 SAN 部分,提取接下来的几行
san_end = cert_text.find('\n\n', san_start)
if san_end == -1:
san_end = min(san_start + 500, len(cert_text)) # 最多取500字符
san_block = cert_text[san_start:san_end]
# 匹配所有 DNS: 条目
dns_matches = re.findall(r'DNS:([^\s,/\n]+)', san_block)
if dns_matches:
domains.extend(dns_matches)
# 去重并保持顺序
seen = set()
unique_domains = []
for domain in domains:
if domain and domain not in seen:
seen.add(domain)
unique_domains.append(domain)
if unique_domains:
logger.info(f"从证书中提取到域名: {unique_domains}")
else:
logger.warning("未能从证书中提取域名")
return unique_domains
except subprocess.CalledProcessError as e:
logger.error(f"提取证书域名失败: {e.stderr}")
return []
except Exception as e:
logger.error(f"提取证书域名异常: {e}")
import traceback
logger.error(f"异常堆栈: {traceback.format_exc()}")
return []
def read_cert_files(self, domain: str) -> Optional[Dict[str, str]]: def read_cert_files(self, domain: str) -> Optional[Dict[str, str]]:
"""读取证书文件""" """读取证书文件"""
domain_cert_dir = Path(self.cert_dir) / domain domain_cert_dir = Path(self.cert_dir) / domain
@ -117,36 +196,73 @@ class APISIXSSLManager:
def upload_cert_to_apisix(self, domain: str, cert_content: str, key_content: str) -> bool: def upload_cert_to_apisix(self, domain: str, cert_content: str, key_content: str) -> bool:
"""将证书上传到 APISIX""" """将证书上传到 APISIX"""
# 生成 SSL ID使用域名作为 ID # 从证书中提取所有域名(包括 SAN
cert_domains = self.extract_domains_from_cert(cert_content)
# 如果没有提取到域名,使用传入的 domain 作为后备
if not cert_domains:
logger.warning(f"无法从证书提取域名,使用传入的域名: {domain}")
cert_domains = [domain]
# 确保传入的 domain 也在列表中(如果不在的话)
if domain not in cert_domains:
cert_domains.append(domain)
# 生成 SSL ID使用主域名作为 ID用于查找和更新
ssl_id = domain.replace('.', '_').replace('*', 'wildcard') ssl_id = domain.replace('.', '_').replace('*', 'wildcard')
# 构建 SSL 配置(创建时不包含 id # 构建 SSL 配置(创建时不包含 id更新时需要 id
# SNI 列表包含证书中的所有域名
ssl_config = { ssl_config = {
"snis": [domain], "snis": cert_domains,
"cert": cert_content, "cert": cert_content,
"key": key_content "key": key_content
} }
# 检查是否已存在 logger.info(f"配置 SNI 域名列表: {cert_domains}")
check_url = f"{self.apisix_admin_url}/apisix/admin/ssls/{ssl_id}"
headers = self._get_apisix_headers() headers = self._get_apisix_headers()
try: try:
# 先检查是否存在 # 先检查是否已存在相同 SNI 的配置
# 方法1通过 ID 查找(如果之前创建时使用了这个 ID
check_url = f"{self.apisix_admin_url}/apisix/admin/ssls/{ssl_id}"
response = requests.get(check_url, headers=headers, timeout=10) response = requests.get(check_url, headers=headers, timeout=10)
existing_ssl_id = None
if response.status_code == 200: if response.status_code == 200:
existing_ssl_id = ssl_id
logger.info(f"找到现有 SSL 配置 (ID: {ssl_id})")
else:
# 方法2查询所有 SSL 配置,检查是否有相同 SNI 的配置
all_ssls_url = f"{self.apisix_admin_url}/apisix/admin/ssls"
all_response = requests.get(all_ssls_url, headers=headers, timeout=10)
if all_response.status_code == 200:
all_ssls = all_response.json()
ssl_list = all_ssls.get('list', []) if isinstance(all_ssls, dict) else all_ssls
# 检查每个 SSL 配置的 SNI 是否匹配
for ssl_item in ssl_list:
ssl_value = ssl_item.get('value', {}) if isinstance(ssl_item, dict) else ssl_item
existing_snis = ssl_value.get('snis', [])
# 检查 SNI 列表是否相同(忽略顺序)
if set(existing_snis) == set(cert_domains):
existing_ssl_id = ssl_item.get('id') or ssl_item.get('key', {}).get('id')
logger.info(f"找到现有 SSL 配置SNI 匹配 (ID: {existing_ssl_id})")
break
if existing_ssl_id:
# 更新现有证书(更新时需要 id # 更新现有证书(更新时需要 id
logger.info(f"更新 APISIX SSL 配置: {domain}") logger.info(f"更新 APISIX SSL 配置: {domain} (ID: {existing_ssl_id})")
ssl_config["id"] = ssl_id ssl_config["id"] = existing_ssl_id
response = requests.put( response = requests.put(
f"{self.apisix_admin_url}/apisix/admin/ssls/{ssl_id}", f"{self.apisix_admin_url}/apisix/admin/ssls/{existing_ssl_id}",
headers=headers, headers=headers,
json=ssl_config, json=ssl_config,
timeout=10 timeout=10
) )
else: else:
# 创建新证书(创建时不需要 idAPISIX 会自动生成) # 创建新证书(POST 时不包含 id让 APISIX 自动生成)
logger.info(f"创建 APISIX SSL 配置: {domain}") logger.info(f"创建 APISIX SSL 配置: {domain}")
response = requests.post( response = requests.post(
f"{self.apisix_admin_url}/apisix/admin/ssls", f"{self.apisix_admin_url}/apisix/admin/ssls",
@ -166,8 +282,14 @@ class APISIXSSLManager:
logger.error(f"上传证书到 APISIX 失败: {e}") logger.error(f"上传证书到 APISIX 失败: {e}")
return False return False
def request_certificate(self, domain: str, additional_domains: List[str] = None) -> bool: def request_certificate(self, domain: str, additional_domains: List[str] = None, max_retries: int = 3) -> bool:
"""申请 Let's Encrypt 证书""" """申请 Let's Encrypt 证书
Args:
domain: 主域名
additional_domains: 额外域名列表
max_retries: 最大重试次数默认3次
"""
domains = [domain] domains = [domain]
if additional_domains: if additional_domains:
domains.extend(additional_domains) domains.extend(additional_domains)
@ -193,34 +315,73 @@ class APISIXSSLManager:
logger.info(f"申请证书: {domain}, 命令: {' '.join(cmd)}") logger.info(f"申请证书: {domain}, 命令: {' '.join(cmd)}")
try: # 重试机制
result = subprocess.run( for attempt in range(1, max_retries + 1):
cmd, try:
capture_output=True, if attempt > 1:
text=True, logger.info(f"{attempt} 次尝试申请证书 (共 {max_retries} 次)...")
check=True, time.sleep(5) # 重试前等待5秒
timeout=300
)
if result.returncode == 0:
logger.info(f"证书申请成功: {domain}")
# 读取证书并上传到 APISIX
cert_data = self.read_cert_files(domain)
if cert_data:
return self.upload_cert_to_apisix(domain, cert_data['cert'], cert_data['key'])
else:
logger.error(f"无法读取证书文件: {domain}")
return False
else:
logger.error(f"证书申请失败: {result.stderr}")
return False
except subprocess.TimeoutExpired: result = subprocess.run(
logger.error(f"证书申请超时: {domain}") cmd,
return False capture_output=True,
except Exception as e: text=True,
logger.error(f"证书申请异常: {e}") check=False, # 不自动抛出异常,手动处理
return False timeout=300
)
if result.returncode == 0:
logger.info(f"证书申请成功: {domain}")
# 读取证书并上传到 APISIX
cert_data = self.read_cert_files(domain)
if cert_data:
return self.upload_cert_to_apisix(domain, cert_data['cert'], cert_data['key'])
else:
logger.error(f"无法读取证书文件: {domain}")
return False
else:
# 检查是否是网络超时错误
error_output = result.stderr or ""
is_timeout_error = "ReadTimeout" in error_output or "timed out" in error_output.lower()
if is_timeout_error and attempt < max_retries:
logger.warning(f"证书申请网络超时 (尝试 {attempt}/{max_retries}),将重试...")
continue
else:
logger.error(f"证书申请失败 (退出码: {result.returncode})")
if result.stdout:
logger.error(f"标准输出: {result.stdout}")
if result.stderr:
logger.error(f"错误输出: {result.stderr}")
# 如果是网络超时且已尝试所有次数,给出提示
if is_timeout_error:
logger.error("网络连接超时,可能的原因:")
logger.error("1. 服务器无法访问 Let's Encrypt 服务器 (acme-staging-v02.api.letsencrypt.org 或 acme-v02.api.letsencrypt.org)")
logger.error("2. 防火墙阻止了 HTTPS 连接")
logger.error("3. 网络不稳定,建议稍后重试")
logger.error("4. 可以检查网络连接: curl -I https://acme-staging-v02.api.letsencrypt.org/directory")
return False
except subprocess.TimeoutExpired:
if attempt < max_retries:
logger.warning(f"证书申请超时 (尝试 {attempt}/{max_retries}),将重试...")
continue
else:
logger.error(f"证书申请超时: {domain} (已尝试 {max_retries} 次)")
return False
except Exception as e:
if attempt < max_retries:
logger.warning(f"证书申请异常 (尝试 {attempt}/{max_retries}): {e},将重试...")
continue
else:
logger.error(f"证书申请异常: {e}")
import traceback
logger.error(f"异常堆栈: {traceback.format_exc()}")
return False
return False
def renew_certificate(self, domain: str) -> bool: def renew_certificate(self, domain: str) -> bool:
"""续期证书""" """续期证书"""

View File

@ -331,15 +331,21 @@ class SSLTestRunner:
print_error(f"测试验证路径异常: {e}") print_error(f"测试验证路径异常: {e}")
return False return False
def create_test_route(self, domain: str) -> bool: def create_test_route(self, domain: str, additional_domains: list = None) -> bool:
"""创建测试路由""" """创建测试路由"""
print_info(f"创建测试路由: {domain}") print_info(f"创建测试路由: {domain}")
# 构建域名列表(主域名 + 额外域名)
hosts = [domain]
if additional_domains:
hosts.extend(additional_domains)
print_info(f"路由将包含域名: {', '.join(hosts)}")
route_config = { route_config = {
"uri": "/*", "uri": "/*",
"name": domain, "name": domain,
"methods": ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"], "methods": ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"],
"host": domain, "hosts": hosts, # 使用 hosts 数组支持多个域名
"upstream": { "upstream": {
"nodes": [ "nodes": [
{ {
@ -373,44 +379,22 @@ class SSLTestRunner:
print_error(f"创建测试路由异常: {e}") print_error(f"创建测试路由异常: {e}")
return False return False
def request_certificate(self, domain: str) -> bool: def request_certificate(self, domain: str, additional_domains: list = None) -> bool:
"""申请证书""" """申请证书"""
print_info(f"申请证书: {domain} (staging={self.staging})") if additional_domains:
print_info(f"申请证书: {domain} + {additional_domains} (staging={self.staging})")
cmd = [ else:
self.ssl_manager.certbot_path, print_info(f"申请证书: {domain} (staging={self.staging})")
'certonly',
'--webroot',
'--webroot-path', self.webroot_path,
'--non-interactive',
'--agree-tos',
'--email', self.email,
'--cert-name', domain,
'-d', domain
]
if self.staging:
cmd.append('--staging')
# 使用 ssl_manager 的 request_certificate 方法,它已经支持额外域名
try: try:
print_info(f"执行命令: {' '.join(cmd)}") result = self.ssl_manager.request_certificate(domain, additional_domains)
result = subprocess.run( if result:
cmd,
capture_output=True,
text=True,
timeout=300
)
if result.returncode == 0:
print_success(f"证书申请成功: {domain}") print_success(f"证书申请成功: {domain}")
print_info(result.stdout)
return True return True
else: else:
print_error(f"证书申请失败: {result.stderr}") print_error(f"证书申请失败: {domain}")
return False return False
except subprocess.TimeoutExpired:
print_error("证书申请超时")
return False
except Exception as e: except Exception as e:
print_error(f"证书申请异常: {e}") print_error(f"证书申请异常: {e}")
return False return False
@ -504,7 +488,7 @@ class SSLTestRunner:
except: except:
pass pass
def run_full_test(self, domain: str = None, cleanup: bool = False): def run_full_test(self, domain: str = None, additional_domains: list = None, cleanup: bool = False):
"""运行完整测试""" """运行完整测试"""
if not domain: if not domain:
domain = self.test_domain domain = self.test_domain
@ -512,6 +496,8 @@ class SSLTestRunner:
print(f"\n{Colors.BOLD}{'='*60}") print(f"\n{Colors.BOLD}{'='*60}")
print(f"APISIX SSL 证书自动申请测试") print(f"APISIX SSL 证书自动申请测试")
print(f"测试域名: {domain}") print(f"测试域名: {domain}")
if additional_domains:
print(f"额外域名: {', '.join(additional_domains)}")
print(f"Staging 模式: {self.staging}") print(f"Staging 模式: {self.staging}")
print(f"{'='*60}{Colors.RESET}\n") print(f"{'='*60}{Colors.RESET}\n")
@ -521,10 +507,10 @@ class SSLTestRunner:
(3, "检查 Webroot 目录", lambda: self.check_webroot_directory()), (3, "检查 Webroot 目录", lambda: self.check_webroot_directory()),
(4, "检查/创建 Webroot 路由", lambda: self.check_webroot_route(domain)), (4, "检查/创建 Webroot 路由", lambda: self.check_webroot_route(domain)),
(5, "测试验证路径", lambda: self.test_verification_path(domain)), (5, "测试验证路径", lambda: self.test_verification_path(domain)),
(6, "创建测试路由", lambda: self.create_test_route(domain)), (6, "创建测试路由", lambda: self.create_test_route(domain, additional_domains)),
(7, "申请 SSL 证书", lambda: self.request_certificate(domain)), (7, "申请 SSL 证书", lambda: self.request_certificate(domain, additional_domains)),
(8, "同步证书到 APISIX", lambda: self.sync_certificate_to_apisix(domain)), # 注意:证书申请成功后会自动上传到 APISIX不需要单独同步步骤
(9, "验证证书信息", lambda: self.verify_certificate(domain)), (8, "验证证书信息", lambda: self.verify_certificate(domain)),
] ]
success_count = 0 success_count = 0
@ -564,6 +550,7 @@ def main():
parser = argparse.ArgumentParser(description='APISIX SSL 证书自动申请测试脚本') parser = argparse.ArgumentParser(description='APISIX SSL 证书自动申请测试脚本')
parser.add_argument('--domain', '-d', help='测试域名(不指定则自动生成)') parser.add_argument('--domain', '-d', help='测试域名(不指定则自动生成)')
parser.add_argument('--additional-domains', '-a', nargs='+', help='额外域名(如 www 子域名)')
parser.add_argument('--config', '-c', help='配置文件路径(可选,用于覆盖默认配置)') parser.add_argument('--config', '-c', help='配置文件路径(可选,用于覆盖默认配置)')
parser.add_argument('--cleanup', action='store_true', help='测试完成后清理测试数据') parser.add_argument('--cleanup', action='store_true', help='测试完成后清理测试数据')
parser.add_argument('--no-cleanup', action='store_true', help='测试完成后不清理测试数据') parser.add_argument('--no-cleanup', action='store_true', help='测试完成后不清理测试数据')
@ -579,9 +566,18 @@ def main():
print_warning(f"未指定域名,使用自动生成的测试域名: {domain}") print_warning(f"未指定域名,使用自动生成的测试域名: {domain}")
print_info("注意:此域名需要 DNS 解析到当前服务器才能申请证书") print_info("注意:此域名需要 DNS 解析到当前服务器才能申请证书")
# 处理额外域名
additional_domains = []
if args.additional_domains:
additional_domains.extend(args.additional_domains)
cleanup = args.cleanup or (not args.no_cleanup and not args.domain) cleanup = args.cleanup or (not args.no_cleanup and not args.domain)
success = runner.run_full_test(domain, cleanup=cleanup) success = runner.run_full_test(
domain,
additional_domains=additional_domains if additional_domains else None,
cleanup=cleanup
)
sys.exit(0 if success else 1) sys.exit(0 if success else 1)