Added parameter to ssl_context hook for constructing default context

Signed-off-by: Tero Saarni <tero.saarni@est.tech>
This commit is contained in:
Tero Saarni 2022-02-11 19:26:56 +02:00
parent 5a581c0b14
commit 362a52bd84
3 changed files with 22 additions and 25 deletions

View File

@ -215,32 +215,26 @@ def worker_int(worker):
def worker_abort(worker): def worker_abort(worker):
worker.log.info("worker received SIGABRT signal") worker.log.info("worker received SIGABRT signal")
def ssl_context(conf): def ssl_context(conf, default_ssl_context_factory):
import ssl import ssl
def set_defaults(context): # The default SSLContext returned by the factory function is initialized
context.verify_mode = conf.cert_reqs # with the TLS parameters from config, including TLS certificates and other
context.minimum_version = ssl.TLSVersion.TLSv1_3 # parameters.
if conf.ciphers: context = default_ssl_context_factory()
context.set_ciphers(conf.ciphers)
if conf.ca_certs:
context.load_verify_locations(cafile=conf.ca_certs)
# Return different server certificate depending which hostname the client # The SSLContext can be further customized, for example by enforcing
# uses. Requires Python 3.7 or later. # minimum TLS version.
context.minimum_version = ssl.TLSVersion.TLSv1_3
# Server can also return different server certificate depending which
# hostname the client uses. Requires Python 3.7 or later.
def sni_callback(socket, server_hostname, context): def sni_callback(socket, server_hostname, context):
if server_hostname == "foo.127.0.0.1.nip.io": if server_hostname == "foo.127.0.0.1.nip.io":
new_context = ssl.SSLContext() new_context = default_ssl_context_factory()
new_context.load_cert_chain(certfile="foo.pem", keyfile="foo-key.pem") new_context.load_cert_chain(certfile="foo.pem", keyfile="foo-key.pem")
set_defaults(new_context)
socket.context = new_context socket.context = new_context
context = ssl.SSLContext(conf.ssl_version)
context.sni_callback = sni_callback context.sni_callback = sni_callback
set_defaults(context)
# Load fallback certificate that will be returned when there is no match
# or client did not set TLS SNI (server_hostname == None)
context.load_cert_chain(certfile=conf.certfile, keyfile=conf.keyfile)
return context return context

View File

@ -1970,11 +1970,11 @@ class OnExit(Setting):
class NewSSLContext(Setting): class NewSSLContext(Setting):
name = "ssl_context" name = "ssl_context"
section = "Server Hooks" section = "Server Hooks"
validator = validate_callable(1) validator = validate_callable(2)
type = callable type = callable
def ssl_context(config): def ssl_context(config, default_ssl_context_factory):
return None return default_ssl_context_factory()
default = staticmethod(ssl_context) default = staticmethod(ssl_context)
desc = """\ desc = """\
@ -1983,7 +1983,10 @@ class NewSSLContext(Setting):
Allows fully customized SSL context to be used in place of the default Allows fully customized SSL context to be used in place of the default
context. context.
The callable needs to accept a single instance variable for the Config. The callable needs to accept an instance variable for the Config and
a factory function that returns default SSLContext which is initialized
with certificates, private key, cert_reqs, and ciphers according to
config and can be further customized by the callable.
The callable needs to return SSLContext object. The callable needs to return SSLContext object.
""" """

View File

@ -212,8 +212,7 @@ def close_sockets(listeners, unlink=True):
os.unlink(sock_name) os.unlink(sock_name)
def ssl_context(conf): def ssl_context(conf):
context = conf.ssl_context(conf) def default_ssl_context_factory():
if context is None:
context = ssl.SSLContext(conf.ssl_version) context = ssl.SSLContext(conf.ssl_version)
context.load_cert_chain(certfile=conf.certfile, keyfile=conf.keyfile) context.load_cert_chain(certfile=conf.certfile, keyfile=conf.keyfile)
context.verify_mode = conf.cert_reqs context.verify_mode = conf.cert_reqs
@ -221,8 +220,9 @@ def ssl_context(conf):
context.set_ciphers(conf.ciphers) context.set_ciphers(conf.ciphers)
if conf.ca_certs: if conf.ca_certs:
context.load_verify_locations(cafile=conf.ca_certs) context.load_verify_locations(cafile=conf.ca_certs)
return context
return context return conf.ssl_context(conf, default_ssl_context_factory)
def ssl_wrap_socket(sock, conf): def ssl_wrap_socket(sock, conf):
return ssl_context(conf).wrap_socket(sock, server_side=True, return ssl_context(conf).wrap_socket(sock, server_side=True,