diff --git a/gunicorn/http/message.py b/gunicorn/http/message.py index 0c6bc053..0dda58db 100644 --- a/gunicorn/http/message.py +++ b/gunicorn/http/message.py @@ -71,11 +71,11 @@ class Message(object): secure_scheme_headers = cfg.secure_scheme_headers elif isinstance(self.unreader, SocketUnreader): remote_addr = self.unreader.sock.getpeername() - if isinstance(remote_addr, tuple): + if self.unreader.sock.family in (socket.AF_INET, socket.AF_INET6): remote_host = remote_addr[0] if remote_host in cfg.forwarded_allow_ips: secure_scheme_headers = cfg.secure_scheme_headers - elif isinstance(remote_addr, str): + elif self.unreader.sock.family == socket.AF_UNIX: secure_scheme_headers = cfg.secure_scheme_headers # Parse headers into key/value pairs paying attention diff --git a/gunicorn/sock.py b/gunicorn/sock.py index 08ede89e..f61443a1 100644 --- a/gunicorn/sock.py +++ b/gunicorn/sock.py @@ -132,7 +132,7 @@ def _sock_type(addr): sock_type = TCP6Socket else: sock_type = TCPSocket - elif isinstance(addr, str): + elif isinstance(addr, (str, bytes)): sock_type = UnixSocket else: raise TypeError("Unable to create socket from: %r" % addr) diff --git a/tests/test_sock.py b/tests/test_sock.py index 29522349..f70ae09e 100644 --- a/tests/test_sock.py +++ b/tests/test_sock.py @@ -11,6 +11,27 @@ except ImportError: from gunicorn import sock +@mock.patch('os.stat') +def test_create_sockets_unix_bytes(stat): + conf = mock.Mock(address=[b'127.0.0.1:8000']) + log = mock.Mock() + with mock.patch.object(sock.UnixSocket, '__init__', lambda *args: None): + listeners = sock.create_sockets(conf, log) + assert len(listeners) == 1 + print(type(listeners[0])) + assert isinstance(listeners[0], sock.UnixSocket) + + +@mock.patch('os.stat') +def test_create_sockets_unix_strings(stat): + conf = mock.Mock(address=['127.0.0.1:8000']) + log = mock.Mock() + with mock.patch.object(sock.UnixSocket, '__init__', lambda *args: None): + listeners = sock.create_sockets(conf, log) + assert len(listeners) == 1 + assert isinstance(listeners[0], sock.UnixSocket) + + def test_socket_close(): listener1 = mock.Mock() listener1.getsockname.return_value = ('127.0.0.1', '80')