Capture peer name from accept

Avoid calls to getpeername by capturing the peer name returned by
accept.
This commit is contained in:
Randall Leeds 2020-12-17 16:24:37 -05:00
parent 548d5828da
commit 3573fd38d0
7 changed files with 20 additions and 32 deletions

View File

@ -6,9 +6,7 @@
import io
import re
import socket
from errno import ENOTCONN
from gunicorn.http.unreader import SocketUnreader
from gunicorn.http.body import ChunkedReader, LengthReader, EOFReader, Body
from gunicorn.http.errors import (
InvalidHeader, InvalidHeaderName, NoMoreData,
@ -29,9 +27,10 @@ VERSION_RE = re.compile(r"HTTP/(\d+)\.(\d+)")
class Message(object):
def __init__(self, cfg, unreader):
def __init__(self, cfg, unreader, peer_addr):
self.cfg = cfg
self.unreader = unreader
self.peer_addr = peer_addr
self.version = None
self.headers = []
self.trailers = []
@ -69,16 +68,10 @@ class Message(object):
# handle scheme headers
scheme_header = False
secure_scheme_headers = {}
if '*' in cfg.forwarded_allow_ips:
if ('*' in cfg.forwarded_allow_ips or
not isinstance(self.peer_addr, tuple)
or self.peer_addr[0] in cfg.forwarded_allow_ips):
secure_scheme_headers = cfg.secure_scheme_headers
elif isinstance(self.unreader, SocketUnreader):
remote_addr = self.unreader.sock.getpeername()
if self.unreader.sock.family in (socket.AF_INET, socket.AF_INET6):
remote_host = remote_addr[0]
if remote_host in cfg.forwarded_allow_ips:
secure_scheme_headers = cfg.secure_scheme_headers
elif self.unreader.sock.family == socket.AF_UNIX:
secure_scheme_headers = cfg.secure_scheme_headers
# Parse headers into key/value pairs paying attention
# to continuation lines.
@ -169,7 +162,7 @@ class Message(object):
class Request(Message):
def __init__(self, cfg, unreader, req_number=1):
def __init__(self, cfg, unreader, peer_addr, req_number=1):
self.method = None
self.uri = None
self.path = None
@ -184,7 +177,7 @@ class Request(Message):
self.req_number = req_number
self.proxy_protocol_info = None
super().__init__(cfg, unreader)
super().__init__(cfg, unreader, peer_addr)
def get_data(self, unreader, buf, stop=False):
data = unreader.read()
@ -280,16 +273,10 @@ class Request(Message):
def proxy_protocol_access_check(self):
# check in allow list
if isinstance(self.unreader, SocketUnreader):
try:
remote_host = self.unreader.sock.getpeername()[0]
except socket.error as e:
if e.args[0] == ENOTCONN:
raise ForbiddenProxyRequest("UNKNOW")
raise
if ("*" not in self.cfg.proxy_allow_ips and
remote_host not in self.cfg.proxy_allow_ips):
raise ForbiddenProxyRequest(remote_host)
if ("*" not in self.cfg.proxy_allow_ips and
isinstance(self.peer_addr, tuple) and
self.peer_addr[0] not in self.cfg.proxy_allow_ips):
raise ForbiddenProxyRequest(self.peer_addr[0])
def parse_proxy_protocol(self, line):
bits = line.split()

View File

@ -11,13 +11,14 @@ class Parser(object):
mesg_class = None
def __init__(self, cfg, source):
def __init__(self, cfg, source, source_addr):
self.cfg = cfg
if hasattr(source, "recv"):
self.unreader = SocketUnreader(source)
else:
self.unreader = IterUnreader(source)
self.mesg = None
self.source_addr = source_addr
# request counter (for keepalive connetions)
self.req_count = 0
@ -38,7 +39,7 @@ class Parser(object):
# Parse the next request
self.req_count += 1
self.mesg = self.mesg_class(self.cfg, self.unreader, self.req_count)
self.mesg = self.mesg_class(self.cfg, self.unreader, self.source_addr, self.req_count)
if not self.mesg:
raise StopIteration()
return self.mesg

View File

@ -33,7 +33,7 @@ class AsyncWorker(base.Worker):
def handle(self, listener, client, addr):
req = None
try:
parser = http.RequestParser(self.cfg, client)
parser = http.RequestParser(self.cfg, client, addr)
try:
listener_name = listener.getsockname()
if not self.cfg.keepalive:

View File

@ -53,7 +53,7 @@ class TConn(object):
**self.cfg.ssl_options)
# initialize the parser
self.parser = http.RequestParser(self.cfg, self.sock)
self.parser = http.RequestParser(self.cfg, self.sock, self.client)
def set_timeout(self):
# set the timeout

View File

@ -131,7 +131,7 @@ class SyncWorker(base.Worker):
client = ssl.wrap_socket(client, server_side=True,
**self.cfg.ssl_options)
parser = http.RequestParser(self.cfg, client)
parser = http.RequestParser(self.cfg, client, addr)
req = next(parser)
self.handle_request(listener, req, client, addr)
except http.errors.NoMoreData as e:

View File

@ -29,7 +29,7 @@ class request(object):
def __call__(self, func):
def run():
src = data_source(self.fname)
func(src, RequestParser(src, None))
func(src, RequestParser(src, None, None))
run.func_name = func.func_name
return run

View File

@ -245,7 +245,7 @@ class request(object):
def check(self, cfg, sender, sizer, matcher):
cases = self.expect[:]
p = RequestParser(cfg, sender())
p = RequestParser(cfg, sender(), None)
for req in p:
self.same(req, sizer, matcher, cases.pop(0))
assert not cases
@ -282,5 +282,5 @@ class badrequest(object):
read += chunk
def check(self, cfg):
p = RequestParser(cfg, self.send())
p = RequestParser(cfg, self.send(), None)
next(p)