560 lines
23 KiB
Python
Executable File
560 lines
23 KiB
Python
Executable File
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
APISIX SSL 证书自动管理脚本
|
||
功能:
|
||
1. 申请 Let's Encrypt 证书
|
||
2. 将证书上传到 APISIX
|
||
3. 自动续期管理
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import json
|
||
import re
|
||
import time
|
||
import subprocess
|
||
import requests
|
||
import logging
|
||
from pathlib import Path
|
||
from typing import Optional, List, Dict
|
||
from datetime import datetime, timedelta
|
||
import base64
|
||
|
||
# 配置日志
|
||
logging.basicConfig(
|
||
level=logging.INFO,
|
||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||
handlers=[
|
||
logging.FileHandler('/var/log/apisix-ssl-manager.log'),
|
||
logging.StreamHandler(sys.stdout)
|
||
]
|
||
)
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
# 默认配置(可通过环境变量覆盖)
|
||
DEFAULT_CONFIG = {
|
||
'apisix_admin_url': 'http://localhost:9180',
|
||
'apisix_admin_key': '8206e6e42b6b53243c52a767cc633137',
|
||
'certbot_path': '/usr/bin/certbot',
|
||
'cert_dir': '/etc/letsencrypt/live',
|
||
'letsencrypt_email': 'admin@jingrowtools.cn',
|
||
'letsencrypt_staging': True, # 默认使用 staging 模式,生产环境改为 False
|
||
'webroot_path': '/var/www/certbot'
|
||
}
|
||
|
||
|
||
class APISIXSSLManager:
|
||
"""APISIX SSL 证书管理器"""
|
||
|
||
def __init__(self, config_path: str = None):
|
||
"""初始化管理器"""
|
||
# 从环境变量或默认配置加载
|
||
self.apisix_admin_url = os.getenv('APISIX_ADMIN_URL', DEFAULT_CONFIG['apisix_admin_url'])
|
||
self.apisix_admin_key = os.getenv('APISIX_ADMIN_KEY', DEFAULT_CONFIG['apisix_admin_key'])
|
||
self.certbot_path = os.getenv('CERTBOT_PATH', DEFAULT_CONFIG['certbot_path'])
|
||
self.cert_dir = os.getenv('CERT_DIR', DEFAULT_CONFIG['cert_dir'])
|
||
self.email = os.getenv('LETSENCRYPT_EMAIL', DEFAULT_CONFIG['letsencrypt_email'])
|
||
self.staging = os.getenv('LETSENCRYPT_STAGING', str(DEFAULT_CONFIG['letsencrypt_staging'])).lower() == 'true'
|
||
self.webroot_path = os.getenv('WEBROOT_PATH', DEFAULT_CONFIG['webroot_path'])
|
||
|
||
# 如果提供了配置文件,从文件加载(覆盖环境变量和默认值)
|
||
if config_path and os.path.exists(config_path):
|
||
self.load_config(config_path)
|
||
|
||
# 验证配置
|
||
self._validate_config()
|
||
|
||
def load_config(self, config_path: str):
|
||
"""从配置文件加载配置(可选,用于覆盖默认配置)"""
|
||
with open(config_path, 'r') as f:
|
||
config = json.load(f)
|
||
self.apisix_admin_url = config.get('apisix_admin_url', self.apisix_admin_url)
|
||
self.apisix_admin_key = config.get('apisix_admin_key', self.apisix_admin_key)
|
||
self.certbot_path = config.get('certbot_path', self.certbot_path)
|
||
self.cert_dir = config.get('cert_dir', self.cert_dir)
|
||
self.email = config.get('letsencrypt_email', self.email)
|
||
self.staging = config.get('letsencrypt_staging', self.staging)
|
||
self.webroot_path = config.get('webroot_path', self.webroot_path)
|
||
|
||
def _validate_config(self):
|
||
"""验证配置"""
|
||
if not self.email:
|
||
logger.warning("未设置 Let's Encrypt 邮箱,建议设置以便接收证书到期提醒")
|
||
if not os.path.exists(self.certbot_path):
|
||
raise FileNotFoundError(f"Certbot 未找到: {self.certbot_path}")
|
||
|
||
def _get_apisix_headers(self) -> Dict[str, str]:
|
||
"""获取 APISIX Admin API 请求头"""
|
||
return {
|
||
'X-API-KEY': self.apisix_admin_key,
|
||
'Content-Type': 'application/json'
|
||
}
|
||
|
||
def extract_domains_from_cert(self, cert_content: str) -> List[str]:
|
||
"""从证书内容中提取所有域名(包括 Subject Alternative Names)"""
|
||
domains = []
|
||
try:
|
||
# 确保 cert_content 是字符串类型(因为 text=True)
|
||
if isinstance(cert_content, bytes):
|
||
cert_str = cert_content.decode('utf-8')
|
||
else:
|
||
cert_str = cert_content
|
||
|
||
# 使用 openssl 命令提取证书信息
|
||
# 注意:text=True 时,input 应该是字符串
|
||
result = subprocess.run(
|
||
['openssl', 'x509', '-noout', '-text', '-in', '/dev/stdin'],
|
||
input=cert_str,
|
||
capture_output=True,
|
||
text=True,
|
||
check=True,
|
||
timeout=10
|
||
)
|
||
|
||
cert_text = result.stdout
|
||
|
||
# 提取 Subject CN
|
||
cn_match = re.search(r'Subject:.*?CN\s*=\s*([^\s,/\n]+)', cert_text)
|
||
if cn_match:
|
||
domains.append(cn_match.group(1))
|
||
|
||
# 提取 Subject Alternative Names
|
||
# 匹配 SAN 部分,包括多行情况
|
||
san_pattern = r'X509v3 Subject Alternative Name:\s*\n\s*((?:DNS:[^\n]+(?:\n\s+[^\n]+)*))'
|
||
san_section = re.search(san_pattern, cert_text, re.MULTILINE)
|
||
if san_section:
|
||
san_text = san_section.group(1)
|
||
# 匹配所有 DNS: 后面的域名(支持跨行)
|
||
dns_matches = re.findall(r'DNS:([^\s,/\n]+)', san_text)
|
||
domains.extend(dns_matches)
|
||
|
||
# 如果上面的方法没匹配到,尝试更宽松的匹配
|
||
if not domains or (san_section is None and 'Subject Alternative Name' in cert_text):
|
||
# 直接在 SAN 部分查找所有 DNS 条目
|
||
san_start = cert_text.find('X509v3 Subject Alternative Name:')
|
||
if san_start != -1:
|
||
# 找到 SAN 部分,提取接下来的几行
|
||
san_end = cert_text.find('\n\n', san_start)
|
||
if san_end == -1:
|
||
san_end = min(san_start + 500, len(cert_text)) # 最多取500字符
|
||
san_block = cert_text[san_start:san_end]
|
||
# 匹配所有 DNS: 条目
|
||
dns_matches = re.findall(r'DNS:([^\s,/\n]+)', san_block)
|
||
if dns_matches:
|
||
domains.extend(dns_matches)
|
||
|
||
# 去重并保持顺序
|
||
seen = set()
|
||
unique_domains = []
|
||
for domain in domains:
|
||
if domain and domain not in seen:
|
||
seen.add(domain)
|
||
unique_domains.append(domain)
|
||
|
||
if unique_domains:
|
||
logger.info(f"从证书中提取到域名: {unique_domains}")
|
||
else:
|
||
logger.warning("未能从证书中提取域名")
|
||
|
||
return unique_domains
|
||
|
||
except subprocess.CalledProcessError as e:
|
||
logger.error(f"提取证书域名失败: {e.stderr}")
|
||
return []
|
||
except Exception as e:
|
||
logger.error(f"提取证书域名异常: {e}")
|
||
import traceback
|
||
logger.error(f"异常堆栈: {traceback.format_exc()}")
|
||
return []
|
||
|
||
def read_cert_files(self, domain: str) -> Optional[Dict[str, str]]:
|
||
"""读取证书文件"""
|
||
domain_cert_dir = Path(self.cert_dir) / domain
|
||
|
||
cert_file = domain_cert_dir / 'fullchain.pem'
|
||
key_file = domain_cert_dir / 'privkey.pem'
|
||
|
||
if not cert_file.exists() or not key_file.exists():
|
||
logger.error(f"证书文件不存在: {domain_cert_dir}")
|
||
return None
|
||
|
||
try:
|
||
with open(cert_file, 'r') as f:
|
||
cert_content = f.read()
|
||
with open(key_file, 'r') as f:
|
||
key_content = f.read()
|
||
|
||
return {
|
||
'cert': cert_content,
|
||
'key': key_content
|
||
}
|
||
except Exception as e:
|
||
logger.error(f"读取证书文件失败: {e}")
|
||
return None
|
||
|
||
def upload_cert_to_apisix(self, domain: str, cert_content: str, key_content: str) -> bool:
|
||
"""将证书上传到 APISIX"""
|
||
# 从证书中提取所有域名(包括 SAN)
|
||
cert_domains = self.extract_domains_from_cert(cert_content)
|
||
|
||
# 如果没有提取到域名,使用传入的 domain 作为后备
|
||
if not cert_domains:
|
||
logger.warning(f"无法从证书提取域名,使用传入的域名: {domain}")
|
||
cert_domains = [domain]
|
||
|
||
# 确保传入的 domain 也在列表中(如果不在的话)
|
||
if domain not in cert_domains:
|
||
cert_domains.append(domain)
|
||
|
||
# 生成 SSL ID(使用主域名作为 ID,用于查找和更新)
|
||
ssl_id = domain.replace('.', '_').replace('*', 'wildcard')
|
||
|
||
# 构建 SSL 配置(创建时不包含 id,更新时需要 id)
|
||
# SNI 列表包含证书中的所有域名
|
||
ssl_config = {
|
||
"snis": cert_domains,
|
||
"cert": cert_content,
|
||
"key": key_content
|
||
}
|
||
|
||
logger.info(f"配置 SNI 域名列表: {cert_domains}")
|
||
|
||
headers = self._get_apisix_headers()
|
||
|
||
try:
|
||
# 先检查是否已存在相同 SNI 的配置
|
||
# 方法1:通过 ID 查找(如果之前创建时使用了这个 ID)
|
||
check_url = f"{self.apisix_admin_url}/apisix/admin/ssls/{ssl_id}"
|
||
response = requests.get(check_url, headers=headers, timeout=10)
|
||
|
||
existing_ssl_id = None
|
||
if response.status_code == 200:
|
||
existing_ssl_id = ssl_id
|
||
logger.info(f"找到现有 SSL 配置 (ID: {ssl_id})")
|
||
else:
|
||
# 方法2:查询所有 SSL 配置,检查是否有相同 SNI 的配置
|
||
all_ssls_url = f"{self.apisix_admin_url}/apisix/admin/ssls"
|
||
all_response = requests.get(all_ssls_url, headers=headers, timeout=10)
|
||
if all_response.status_code == 200:
|
||
all_ssls = all_response.json()
|
||
ssl_list = all_ssls.get('list', []) if isinstance(all_ssls, dict) else all_ssls
|
||
|
||
# 检查每个 SSL 配置的 SNI 是否匹配
|
||
for ssl_item in ssl_list:
|
||
ssl_value = ssl_item.get('value', {}) if isinstance(ssl_item, dict) else ssl_item
|
||
existing_snis = ssl_value.get('snis', [])
|
||
# 检查 SNI 列表是否相同(忽略顺序)
|
||
if set(existing_snis) == set(cert_domains):
|
||
# 从 value 中获取 id,或从 key 字段中提取 id
|
||
existing_ssl_id = ssl_value.get('id')
|
||
if not existing_ssl_id and isinstance(ssl_item, dict):
|
||
# 如果 value 中没有 id,尝试从 key 字段提取(格式:/apisix/ssls/xxx)
|
||
key_str = ssl_item.get('key', '')
|
||
if key_str and isinstance(key_str, str):
|
||
existing_ssl_id = key_str.split('/')[-1]
|
||
logger.info(f"找到现有 SSL 配置,SNI 匹配 (ID: {existing_ssl_id})")
|
||
break
|
||
|
||
if existing_ssl_id:
|
||
# 更新现有证书(更新时需要 id)
|
||
logger.info(f"更新 APISIX SSL 配置: {domain} (ID: {existing_ssl_id})")
|
||
ssl_config["id"] = existing_ssl_id
|
||
response = requests.put(
|
||
f"{self.apisix_admin_url}/apisix/admin/ssls/{existing_ssl_id}",
|
||
headers=headers,
|
||
json=ssl_config,
|
||
timeout=10
|
||
)
|
||
else:
|
||
# 创建新证书(POST 时不包含 id,让 APISIX 自动生成)
|
||
logger.info(f"创建 APISIX SSL 配置: {domain}")
|
||
response = requests.post(
|
||
f"{self.apisix_admin_url}/apisix/admin/ssls",
|
||
headers=headers,
|
||
json=ssl_config,
|
||
timeout=10
|
||
)
|
||
|
||
if response.status_code in [200, 201]:
|
||
logger.info(f"证书上传成功: {domain}")
|
||
return True
|
||
else:
|
||
logger.error(f"证书上传失败: {response.status_code} - {response.text}")
|
||
return False
|
||
|
||
except Exception as e:
|
||
logger.error(f"上传证书到 APISIX 失败: {e}")
|
||
return False
|
||
|
||
def request_certificate(self, domain: str, additional_domains: List[str] = None, max_retries: int = 3) -> bool:
|
||
"""申请 Let's Encrypt 证书
|
||
|
||
Args:
|
||
domain: 主域名
|
||
additional_domains: 额外域名列表
|
||
max_retries: 最大重试次数(默认3次)
|
||
"""
|
||
domains = [domain]
|
||
if additional_domains:
|
||
domains.extend(additional_domains)
|
||
|
||
# 构建 certbot 命令
|
||
cmd = [
|
||
self.certbot_path,
|
||
'certonly',
|
||
'--webroot',
|
||
'--webroot-path', self.webroot_path,
|
||
'--non-interactive',
|
||
'--agree-tos',
|
||
'--email', self.email if self.email else 'admin@example.com',
|
||
'--cert-name', domain,
|
||
]
|
||
|
||
if self.staging:
|
||
cmd.append('--staging')
|
||
|
||
# 添加域名
|
||
for d in domains:
|
||
cmd.extend(['-d', d])
|
||
|
||
logger.info(f"申请证书: {domain}, 命令: {' '.join(cmd)}")
|
||
|
||
# 重试机制
|
||
for attempt in range(1, max_retries + 1):
|
||
try:
|
||
if attempt > 1:
|
||
logger.info(f"第 {attempt} 次尝试申请证书 (共 {max_retries} 次)...")
|
||
time.sleep(5) # 重试前等待5秒
|
||
|
||
result = subprocess.run(
|
||
cmd,
|
||
capture_output=True,
|
||
text=True,
|
||
check=False, # 不自动抛出异常,手动处理
|
||
timeout=300
|
||
)
|
||
|
||
if result.returncode == 0:
|
||
logger.info(f"证书申请成功: {domain}")
|
||
# 读取证书并上传到 APISIX
|
||
cert_data = self.read_cert_files(domain)
|
||
if cert_data:
|
||
return self.upload_cert_to_apisix(domain, cert_data['cert'], cert_data['key'])
|
||
else:
|
||
logger.error(f"无法读取证书文件: {domain}")
|
||
return False
|
||
else:
|
||
# 检查是否是网络超时错误
|
||
error_output = result.stderr or ""
|
||
is_timeout_error = "ReadTimeout" in error_output or "timed out" in error_output.lower()
|
||
|
||
if is_timeout_error and attempt < max_retries:
|
||
logger.warning(f"证书申请网络超时 (尝试 {attempt}/{max_retries}),将重试...")
|
||
continue
|
||
else:
|
||
logger.error(f"证书申请失败 (退出码: {result.returncode})")
|
||
if result.stdout:
|
||
logger.error(f"标准输出: {result.stdout}")
|
||
if result.stderr:
|
||
logger.error(f"错误输出: {result.stderr}")
|
||
|
||
# 如果是网络超时且已尝试所有次数,给出提示
|
||
if is_timeout_error:
|
||
logger.error("网络连接超时,可能的原因:")
|
||
logger.error("1. 服务器无法访问 Let's Encrypt 服务器 (acme-staging-v02.api.letsencrypt.org 或 acme-v02.api.letsencrypt.org)")
|
||
logger.error("2. 防火墙阻止了 HTTPS 连接")
|
||
logger.error("3. 网络不稳定,建议稍后重试")
|
||
logger.error("4. 可以检查网络连接: curl -I https://acme-staging-v02.api.letsencrypt.org/directory")
|
||
|
||
return False
|
||
|
||
except subprocess.TimeoutExpired:
|
||
if attempt < max_retries:
|
||
logger.warning(f"证书申请超时 (尝试 {attempt}/{max_retries}),将重试...")
|
||
continue
|
||
else:
|
||
logger.error(f"证书申请超时: {domain} (已尝试 {max_retries} 次)")
|
||
return False
|
||
except Exception as e:
|
||
if attempt < max_retries:
|
||
logger.warning(f"证书申请异常 (尝试 {attempt}/{max_retries}): {e},将重试...")
|
||
continue
|
||
else:
|
||
logger.error(f"证书申请异常: {e}")
|
||
import traceback
|
||
logger.error(f"异常堆栈: {traceback.format_exc()}")
|
||
return False
|
||
|
||
return False
|
||
|
||
def renew_certificate(self, domain: str) -> bool:
|
||
"""续期证书"""
|
||
cmd = [
|
||
self.certbot_path,
|
||
'renew',
|
||
'--cert-name', domain,
|
||
'--non-interactive',
|
||
'--webroot',
|
||
'--webroot-path', self.webroot_path,
|
||
]
|
||
|
||
if self.staging:
|
||
cmd.append('--staging')
|
||
|
||
logger.info(f"续期证书: {domain}")
|
||
|
||
try:
|
||
result = subprocess.run(
|
||
cmd,
|
||
capture_output=True,
|
||
text=True,
|
||
check=True,
|
||
timeout=300
|
||
)
|
||
|
||
if result.returncode == 0:
|
||
logger.info(f"证书续期成功: {domain}")
|
||
# 读取新证书并上传到 APISIX
|
||
cert_data = self.read_cert_files(domain)
|
||
if cert_data:
|
||
return self.upload_cert_to_apisix(domain, cert_data['cert'], cert_data['key'])
|
||
else:
|
||
logger.error(f"无法读取续期后的证书文件: {domain}")
|
||
return False
|
||
else:
|
||
logger.error(f"证书续期失败: {result.stderr}")
|
||
return False
|
||
|
||
except Exception as e:
|
||
logger.error(f"证书续期异常: {e}")
|
||
return False
|
||
|
||
def renew_all_certificates(self) -> Dict[str, bool]:
|
||
"""续期所有证书"""
|
||
results = {}
|
||
|
||
# 获取所有证书
|
||
cert_dir = Path(self.cert_dir)
|
||
if not cert_dir.exists():
|
||
logger.warning(f"证书目录不存在: {cert_dir}")
|
||
return results
|
||
|
||
# 查找所有证书目录
|
||
for domain_dir in cert_dir.iterdir():
|
||
if domain_dir.is_dir():
|
||
domain = domain_dir.name
|
||
results[domain] = self.renew_certificate(domain)
|
||
|
||
return results
|
||
|
||
def sync_cert_to_apisix(self, domain: str) -> bool:
|
||
"""同步现有证书到 APISIX(不申请新证书)"""
|
||
cert_data = self.read_cert_files(domain)
|
||
if cert_data:
|
||
return self.upload_cert_to_apisix(domain, cert_data['cert'], cert_data['key'])
|
||
else:
|
||
logger.error(f"无法读取证书文件: {domain}")
|
||
return False
|
||
|
||
def check_cert_expiry(self, domain: str) -> Optional[datetime]:
|
||
"""检查证书过期时间"""
|
||
cert_file = Path(self.cert_dir) / domain / 'fullchain.pem'
|
||
if not cert_file.exists():
|
||
return None
|
||
|
||
try:
|
||
result = subprocess.run(
|
||
['openssl', 'x509', '-in', str(cert_file), '-noout', '-enddate'],
|
||
capture_output=True,
|
||
text=True,
|
||
check=True
|
||
)
|
||
|
||
# 解析日期
|
||
date_str = result.stdout.strip().split('=')[1]
|
||
expiry_date = datetime.strptime(date_str, '%b %d %H:%M:%S %Y %Z')
|
||
return expiry_date
|
||
|
||
except Exception as e:
|
||
logger.error(f"检查证书过期时间失败: {e}")
|
||
return None
|
||
|
||
|
||
def main():
|
||
"""主函数"""
|
||
import argparse
|
||
|
||
parser = argparse.ArgumentParser(description='APISIX SSL 证书管理器')
|
||
parser.add_argument('action', choices=['request', 'renew', 'renew-all', 'sync', 'check'],
|
||
help='操作类型')
|
||
parser.add_argument('--domain', '-d', help='域名')
|
||
parser.add_argument('--config', '-c', help='配置文件路径(可选,用于覆盖默认配置)')
|
||
parser.add_argument('--additional-domains', '-a', nargs='+', help='额外域名')
|
||
|
||
args = parser.parse_args()
|
||
|
||
# 初始化管理器
|
||
try:
|
||
manager = APISIXSSLManager(args.config)
|
||
except Exception as e:
|
||
logger.error(f"初始化失败: {e}")
|
||
sys.exit(1)
|
||
|
||
# 执行操作
|
||
try:
|
||
if args.action == 'request':
|
||
if not args.domain:
|
||
logger.error("申请证书需要指定域名 (--domain)")
|
||
sys.exit(1)
|
||
success = manager.request_certificate(args.domain, args.additional_domains)
|
||
sys.exit(0 if success else 1)
|
||
|
||
elif args.action == 'renew':
|
||
if not args.domain:
|
||
logger.error("续期证书需要指定域名 (--domain)")
|
||
sys.exit(1)
|
||
success = manager.renew_certificate(args.domain)
|
||
sys.exit(0 if success else 1)
|
||
|
||
elif args.action == 'renew-all':
|
||
results = manager.renew_all_certificates()
|
||
failed = [d for d, s in results.items() if not s]
|
||
if failed:
|
||
logger.error(f"以下域名续期失败: {', '.join(failed)}")
|
||
sys.exit(1)
|
||
else:
|
||
logger.info("所有证书续期成功")
|
||
sys.exit(0)
|
||
|
||
elif args.action == 'sync':
|
||
if not args.domain:
|
||
logger.error("同步证书需要指定域名 (--domain)")
|
||
sys.exit(1)
|
||
success = manager.sync_cert_to_apisix(args.domain)
|
||
sys.exit(0 if success else 1)
|
||
|
||
elif args.action == 'check':
|
||
if not args.domain:
|
||
logger.error("检查证书需要指定域名 (--domain)")
|
||
sys.exit(1)
|
||
expiry = manager.check_cert_expiry(args.domain)
|
||
if expiry:
|
||
days_left = (expiry - datetime.now()).days
|
||
logger.info(f"证书过期时间: {expiry.strftime('%Y-%m-%d %H:%M:%S')}")
|
||
logger.info(f"剩余天数: {days_left} 天")
|
||
if days_left < 30:
|
||
logger.warning(f"证书即将过期,建议续期")
|
||
else:
|
||
logger.error("无法获取证书过期时间")
|
||
sys.exit(1)
|
||
|
||
except Exception as e:
|
||
logger.error(f"执行操作失败: {e}")
|
||
sys.exit(1)
|
||
|
||
|
||
if __name__ == '__main__':
|
||
main()
|