apisix/ssl_manager/ssl_manager.py

559 lines
23 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
APISIX SSL 证书自动管理脚本
功能:
1. 申请 Let's Encrypt 证书
2. 将证书上传到 APISIX
3. 自动续期管理
"""
import os
import sys
import json
import re
import time
import subprocess
import requests
import logging
from pathlib import Path
from typing import Optional, List, Dict
from datetime import datetime
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler('/var/log/apisix-ssl-manager.log'),
logging.StreamHandler(sys.stdout)
]
)
logger = logging.getLogger(__name__)
# 默认配置(可通过环境变量覆盖)
DEFAULT_CONFIG = {
'apisix_admin_url': 'http://localhost:9180',
'apisix_admin_key': '8206e6e42b6b53243c52a767cc633137',
'certbot_path': '/usr/bin/certbot',
'cert_dir': '/etc/letsencrypt/live',
'letsencrypt_email': 'admin@jingrowtools.cn',
'letsencrypt_staging': False, # 默认使用 staging 模式,生产环境改为 False
'webroot_path': '/var/www/certbot'
}
class APISIXSSLManager:
"""APISIX SSL 证书管理器"""
def __init__(self, config_path: str = None):
"""初始化管理器"""
# 从环境变量或默认配置加载
self.apisix_admin_url = os.getenv('APISIX_ADMIN_URL', DEFAULT_CONFIG['apisix_admin_url'])
self.apisix_admin_key = os.getenv('APISIX_ADMIN_KEY', DEFAULT_CONFIG['apisix_admin_key'])
self.certbot_path = os.getenv('CERTBOT_PATH', DEFAULT_CONFIG['certbot_path'])
self.cert_dir = os.getenv('CERT_DIR', DEFAULT_CONFIG['cert_dir'])
self.email = os.getenv('LETSENCRYPT_EMAIL', DEFAULT_CONFIG['letsencrypt_email'])
self.staging = os.getenv('LETSENCRYPT_STAGING', str(DEFAULT_CONFIG['letsencrypt_staging'])).lower() == 'true'
self.webroot_path = os.getenv('WEBROOT_PATH', DEFAULT_CONFIG['webroot_path'])
# 如果提供了配置文件,从文件加载(覆盖环境变量和默认值)
if config_path and os.path.exists(config_path):
self.load_config(config_path)
# 验证配置
self._validate_config()
# 创建 HTTP 会话,复用连接
self.session = requests.Session()
self.session.headers.update(self._get_apisix_headers())
def load_config(self, config_path: str):
"""从配置文件加载配置(可选,用于覆盖默认配置)"""
with open(config_path, 'r') as f:
config = json.load(f)
self.apisix_admin_url = config.get('apisix_admin_url', self.apisix_admin_url)
self.apisix_admin_key = config.get('apisix_admin_key', self.apisix_admin_key)
self.certbot_path = config.get('certbot_path', self.certbot_path)
self.cert_dir = config.get('cert_dir', self.cert_dir)
self.email = config.get('letsencrypt_email', self.email)
self.staging = config.get('letsencrypt_staging', self.staging)
self.webroot_path = config.get('webroot_path', self.webroot_path)
def _validate_config(self):
"""验证配置"""
if not self.email:
logger.warning("未设置 Let's Encrypt 邮箱,建议设置以便接收证书到期提醒")
if not os.path.exists(self.certbot_path):
raise FileNotFoundError(f"Certbot 未找到: {self.certbot_path}")
def _get_apisix_headers(self) -> Dict[str, str]:
"""获取 APISIX Admin API 请求头"""
return {
'X-API-KEY': self.apisix_admin_key,
'Content-Type': 'application/json'
}
def extract_domains_from_cert(self, cert_content: str) -> List[str]:
"""从证书内容中提取所有域名(包括 Subject Alternative Names"""
domains = []
try:
# 确保 cert_content 是字符串类型(因为 text=True
if isinstance(cert_content, bytes):
cert_str = cert_content.decode('utf-8')
else:
cert_str = cert_content
# 使用 openssl 命令提取证书信息
# 注意text=True 时input 应该是字符串
result = subprocess.run(
['openssl', 'x509', '-noout', '-text', '-in', '/dev/stdin'],
input=cert_str,
capture_output=True,
text=True,
check=True,
timeout=10
)
cert_text = result.stdout
# 提取 Subject CN
cn_match = re.search(r'Subject:.*?CN\s*=\s*([^\s,/\n]+)', cert_text)
if cn_match:
domains.append(cn_match.group(1))
# 提取 Subject Alternative Names
# 匹配 SAN 部分,包括多行情况
san_pattern = r'X509v3 Subject Alternative Name:\s*\n\s*((?:DNS:[^\n]+(?:\n\s+[^\n]+)*))'
san_section = re.search(san_pattern, cert_text, re.MULTILINE)
if san_section:
san_text = san_section.group(1)
# 匹配所有 DNS: 后面的域名(支持跨行)
dns_matches = re.findall(r'DNS:([^\s,/\n]+)', san_text)
domains.extend(dns_matches)
# 如果上面的方法没匹配到,尝试更宽松的匹配
if not domains or (san_section is None and 'Subject Alternative Name' in cert_text):
# 直接在 SAN 部分查找所有 DNS 条目
san_start = cert_text.find('X509v3 Subject Alternative Name:')
if san_start != -1:
# 找到 SAN 部分,提取接下来的几行
san_end = cert_text.find('\n\n', san_start)
if san_end == -1:
san_end = min(san_start + 500, len(cert_text)) # 最多取500字符
san_block = cert_text[san_start:san_end]
# 匹配所有 DNS: 条目
dns_matches = re.findall(r'DNS:([^\s,/\n]+)', san_block)
if dns_matches:
domains.extend(dns_matches)
# 去重并保持顺序
seen = set()
unique_domains = []
for domain in domains:
if domain and domain not in seen:
seen.add(domain)
unique_domains.append(domain)
if unique_domains:
logger.info(f"从证书中提取到域名: {unique_domains}")
else:
logger.warning("未能从证书中提取域名")
return unique_domains
except subprocess.CalledProcessError as e:
logger.error(f"提取证书域名失败: {e.stderr}")
return []
except Exception as e:
logger.error(f"提取证书域名异常: {e}")
import traceback
logger.error(f"异常堆栈: {traceback.format_exc()}")
return []
def read_cert_files(self, domain: str) -> Optional[Dict[str, str]]:
"""读取证书文件"""
domain_cert_dir = Path(self.cert_dir) / domain
cert_file = domain_cert_dir / 'fullchain.pem'
key_file = domain_cert_dir / 'privkey.pem'
if not cert_file.exists() or not key_file.exists():
logger.error(f"证书文件不存在: {domain_cert_dir}")
return None
try:
with open(cert_file, 'r') as f:
cert_content = f.read()
with open(key_file, 'r') as f:
key_content = f.read()
return {
'cert': cert_content,
'key': key_content
}
except Exception as e:
logger.error(f"读取证书文件失败: {e}")
return None
def upload_cert_to_apisix(self, domain: str, cert_content: str, key_content: str) -> bool:
"""将证书上传到 APISIX"""
# 从证书中提取所有域名(包括 SAN
cert_domains = self.extract_domains_from_cert(cert_content)
# 如果没有提取到域名,使用传入的 domain 作为后备
if not cert_domains:
logger.warning(f"无法从证书提取域名,使用传入的域名: {domain}")
cert_domains = [domain]
# 确保传入的 domain 也在列表中(如果不在的话)
if domain not in cert_domains:
cert_domains.append(domain)
# 生成 SSL ID使用主域名作为 ID用于查找和更新
ssl_id = domain.replace('.', '_').replace('*', 'wildcard')
# 构建 SSL 配置(创建时不包含 id更新时需要 id
# SNI 列表包含证书中的所有域名
ssl_config = {
"snis": cert_domains,
"cert": cert_content,
"key": key_content
}
logger.info(f"配置 SNI 域名列表: {cert_domains}")
try:
# 先检查是否已存在相同 SNI 的配置
# 方法1通过 ID 查找(如果之前创建时使用了这个 ID
check_url = f"{self.apisix_admin_url}/apisix/admin/ssls/{ssl_id}"
response = self.session.get(check_url, timeout=10)
existing_ssl_id = None
if response.status_code == 200:
existing_ssl_id = ssl_id
logger.info(f"找到现有 SSL 配置 (ID: {ssl_id})")
else:
# 方法2查询所有 SSL 配置,检查是否有相同 SNI 的配置
all_ssls_url = f"{self.apisix_admin_url}/apisix/admin/ssls"
all_response = self.session.get(all_ssls_url, timeout=10)
if all_response.status_code == 200:
all_ssls = all_response.json()
ssl_list = all_ssls.get('list', []) if isinstance(all_ssls, dict) else all_ssls
# 检查每个 SSL 配置的 SNI 是否匹配
for ssl_item in ssl_list:
ssl_value = ssl_item.get('value', {}) if isinstance(ssl_item, dict) else ssl_item
existing_snis = ssl_value.get('snis', [])
# 检查 SNI 列表是否相同(忽略顺序)
if set(existing_snis) == set(cert_domains):
# 从 value 中获取 id或从 key 字段中提取 id
existing_ssl_id = ssl_value.get('id')
if not existing_ssl_id and isinstance(ssl_item, dict):
# 如果 value 中没有 id尝试从 key 字段提取(格式:/apisix/ssls/xxx
key_str = ssl_item.get('key', '')
if key_str and isinstance(key_str, str):
existing_ssl_id = key_str.split('/')[-1]
logger.info(f"找到现有 SSL 配置SNI 匹配 (ID: {existing_ssl_id})")
break
if existing_ssl_id:
# 更新现有证书(更新时需要 id
logger.info(f"更新 APISIX SSL 配置: {domain} (ID: {existing_ssl_id})")
ssl_config["id"] = existing_ssl_id
response = self.session.put(
f"{self.apisix_admin_url}/apisix/admin/ssls/{existing_ssl_id}",
json=ssl_config,
timeout=10
)
else:
# 创建新证书POST 时不包含 id让 APISIX 自动生成)
logger.info(f"创建 APISIX SSL 配置: {domain}")
response = self.session.post(
f"{self.apisix_admin_url}/apisix/admin/ssls",
json=ssl_config,
timeout=10
)
if response.status_code in [200, 201]:
logger.info(f"证书上传成功: {domain}")
return True
else:
logger.error(f"证书上传失败: {response.status_code} - {response.text}")
return False
except Exception as e:
logger.error(f"上传证书到 APISIX 失败: {e}")
return False
def request_certificate(self, domain: str, additional_domains: List[str] = None, max_retries: int = 3) -> bool:
"""申请 Let's Encrypt 证书
Args:
domain: 主域名
additional_domains: 额外域名列表
max_retries: 最大重试次数默认3次
"""
domains = [domain]
if additional_domains:
domains.extend(additional_domains)
# 构建 certbot 命令
cmd = [
self.certbot_path,
'certonly',
'--webroot',
'--webroot-path', self.webroot_path,
'--non-interactive',
'--agree-tos',
'--email', self.email if self.email else 'admin@example.com',
'--cert-name', domain,
]
if self.staging:
cmd.append('--staging')
# 添加域名
for d in domains:
cmd.extend(['-d', d])
logger.info(f"申请证书: {domain}, 命令: {' '.join(cmd)}")
# 重试机制
for attempt in range(1, max_retries + 1):
try:
if attempt > 1:
logger.info(f"{attempt} 次尝试申请证书 (共 {max_retries} 次)...")
time.sleep(5) # 重试前等待5秒
result = subprocess.run(
cmd,
capture_output=True,
text=True,
check=False, # 不自动抛出异常,手动处理
timeout=300
)
if result.returncode == 0:
logger.info(f"证书申请成功: {domain}")
# 读取证书并上传到 APISIX
cert_data = self.read_cert_files(domain)
if cert_data:
return self.upload_cert_to_apisix(domain, cert_data['cert'], cert_data['key'])
else:
logger.error(f"无法读取证书文件: {domain}")
return False
else:
# 检查是否是网络超时错误
error_output = result.stderr or ""
is_timeout_error = "ReadTimeout" in error_output or "timed out" in error_output.lower()
if is_timeout_error and attempt < max_retries:
logger.warning(f"证书申请网络超时 (尝试 {attempt}/{max_retries}),将重试...")
continue
else:
logger.error(f"证书申请失败 (退出码: {result.returncode})")
if result.stdout:
logger.error(f"标准输出: {result.stdout}")
if result.stderr:
logger.error(f"错误输出: {result.stderr}")
# 如果是网络超时且已尝试所有次数,给出提示
if is_timeout_error:
logger.error("网络连接超时,可能的原因:")
logger.error("1. 服务器无法访问 Let's Encrypt 服务器 (acme-staging-v02.api.letsencrypt.org 或 acme-v02.api.letsencrypt.org)")
logger.error("2. 防火墙阻止了 HTTPS 连接")
logger.error("3. 网络不稳定,建议稍后重试")
logger.error("4. 可以检查网络连接: curl -I https://acme-staging-v02.api.letsencrypt.org/directory")
return False
except subprocess.TimeoutExpired:
if attempt < max_retries:
logger.warning(f"证书申请超时 (尝试 {attempt}/{max_retries}),将重试...")
continue
else:
logger.error(f"证书申请超时: {domain} (已尝试 {max_retries} 次)")
return False
except Exception as e:
if attempt < max_retries:
logger.warning(f"证书申请异常 (尝试 {attempt}/{max_retries}): {e},将重试...")
continue
else:
logger.error(f"证书申请异常: {e}")
import traceback
logger.error(f"异常堆栈: {traceback.format_exc()}")
return False
return False
def renew_certificate(self, domain: str) -> bool:
"""续期证书"""
cmd = [
self.certbot_path,
'renew',
'--cert-name', domain,
'--non-interactive',
'--webroot',
'--webroot-path', self.webroot_path,
]
if self.staging:
cmd.append('--staging')
logger.info(f"续期证书: {domain}")
try:
result = subprocess.run(
cmd,
capture_output=True,
text=True,
check=True,
timeout=300
)
if result.returncode == 0:
logger.info(f"证书续期成功: {domain}")
# 读取新证书并上传到 APISIX
cert_data = self.read_cert_files(domain)
if cert_data:
return self.upload_cert_to_apisix(domain, cert_data['cert'], cert_data['key'])
else:
logger.error(f"无法读取续期后的证书文件: {domain}")
return False
else:
logger.error(f"证书续期失败: {result.stderr}")
return False
except Exception as e:
logger.error(f"证书续期异常: {e}")
return False
def renew_all_certificates(self) -> Dict[str, bool]:
"""续期所有证书"""
results = {}
# 获取所有证书
cert_dir = Path(self.cert_dir)
if not cert_dir.exists():
logger.warning(f"证书目录不存在: {cert_dir}")
return results
# 查找所有证书目录
for domain_dir in cert_dir.iterdir():
if domain_dir.is_dir():
domain = domain_dir.name
results[domain] = self.renew_certificate(domain)
return results
def sync_cert_to_apisix(self, domain: str) -> bool:
"""同步现有证书到 APISIX不申请新证书"""
cert_data = self.read_cert_files(domain)
if cert_data:
return self.upload_cert_to_apisix(domain, cert_data['cert'], cert_data['key'])
else:
logger.error(f"无法读取证书文件: {domain}")
return False
def check_cert_expiry(self, domain: str) -> Optional[datetime]:
"""检查证书过期时间"""
cert_file = Path(self.cert_dir) / domain / 'fullchain.pem'
if not cert_file.exists():
return None
try:
result = subprocess.run(
['openssl', 'x509', '-in', str(cert_file), '-noout', '-enddate'],
capture_output=True,
text=True,
check=True
)
# 解析日期
date_str = result.stdout.strip().split('=')[1]
expiry_date = datetime.strptime(date_str, '%b %d %H:%M:%S %Y %Z')
return expiry_date
except Exception as e:
logger.error(f"检查证书过期时间失败: {e}")
return None
def main():
"""主函数"""
import argparse
parser = argparse.ArgumentParser(description='APISIX SSL 证书管理器')
parser.add_argument('action', choices=['request', 'renew', 'renew-all', 'sync', 'check'],
help='操作类型')
parser.add_argument('--domain', '-d', help='域名')
parser.add_argument('--config', '-c', help='配置文件路径(可选,用于覆盖默认配置)')
parser.add_argument('--additional-domains', '-a', nargs='+', help='额外域名')
args = parser.parse_args()
# 初始化管理器
try:
manager = APISIXSSLManager(args.config)
except Exception as e:
logger.error(f"初始化失败: {e}")
sys.exit(1)
# 执行操作
try:
if args.action == 'request':
if not args.domain:
logger.error("申请证书需要指定域名 (--domain)")
sys.exit(1)
success = manager.request_certificate(args.domain, args.additional_domains)
sys.exit(0 if success else 1)
elif args.action == 'renew':
if not args.domain:
logger.error("续期证书需要指定域名 (--domain)")
sys.exit(1)
success = manager.renew_certificate(args.domain)
sys.exit(0 if success else 1)
elif args.action == 'renew-all':
results = manager.renew_all_certificates()
failed = [d for d, s in results.items() if not s]
if failed:
logger.error(f"以下域名续期失败: {', '.join(failed)}")
sys.exit(1)
else:
logger.info("所有证书续期成功")
sys.exit(0)
elif args.action == 'sync':
if not args.domain:
logger.error("同步证书需要指定域名 (--domain)")
sys.exit(1)
success = manager.sync_cert_to_apisix(args.domain)
sys.exit(0 if success else 1)
elif args.action == 'check':
if not args.domain:
logger.error("检查证书需要指定域名 (--domain)")
sys.exit(1)
expiry = manager.check_cert_expiry(args.domain)
if expiry:
days_left = (expiry - datetime.now()).days
logger.info(f"证书过期时间: {expiry.strftime('%Y-%m-%d %H:%M:%S')}")
logger.info(f"剩余天数: {days_left}")
if days_left < 30:
logger.warning(f"证书即将过期,建议续期")
else:
logger.error("无法获取证书过期时间")
sys.exit(1)
except Exception as e:
logger.error(f"执行操作失败: {e}")
sys.exit(1)
if __name__ == '__main__':
main()