Capture peer name from accept

Avoid calls to getpeername by capturing the peer name returned by
accept.
This commit is contained in:
Randall Leeds 2020-12-17 16:24:37 -05:00
parent 548d5828da
commit 3573fd38d0
7 changed files with 20 additions and 32 deletions

View File

@ -6,9 +6,7 @@
import io import io
import re import re
import socket import socket
from errno import ENOTCONN
from gunicorn.http.unreader import SocketUnreader
from gunicorn.http.body import ChunkedReader, LengthReader, EOFReader, Body from gunicorn.http.body import ChunkedReader, LengthReader, EOFReader, Body
from gunicorn.http.errors import ( from gunicorn.http.errors import (
InvalidHeader, InvalidHeaderName, NoMoreData, InvalidHeader, InvalidHeaderName, NoMoreData,
@ -29,9 +27,10 @@ VERSION_RE = re.compile(r"HTTP/(\d+)\.(\d+)")
class Message(object): class Message(object):
def __init__(self, cfg, unreader): def __init__(self, cfg, unreader, peer_addr):
self.cfg = cfg self.cfg = cfg
self.unreader = unreader self.unreader = unreader
self.peer_addr = peer_addr
self.version = None self.version = None
self.headers = [] self.headers = []
self.trailers = [] self.trailers = []
@ -69,16 +68,10 @@ class Message(object):
# handle scheme headers # handle scheme headers
scheme_header = False scheme_header = False
secure_scheme_headers = {} secure_scheme_headers = {}
if '*' in cfg.forwarded_allow_ips: if ('*' in cfg.forwarded_allow_ips or
not isinstance(self.peer_addr, tuple)
or self.peer_addr[0] in cfg.forwarded_allow_ips):
secure_scheme_headers = cfg.secure_scheme_headers secure_scheme_headers = cfg.secure_scheme_headers
elif isinstance(self.unreader, SocketUnreader):
remote_addr = self.unreader.sock.getpeername()
if self.unreader.sock.family in (socket.AF_INET, socket.AF_INET6):
remote_host = remote_addr[0]
if remote_host in cfg.forwarded_allow_ips:
secure_scheme_headers = cfg.secure_scheme_headers
elif self.unreader.sock.family == socket.AF_UNIX:
secure_scheme_headers = cfg.secure_scheme_headers
# Parse headers into key/value pairs paying attention # Parse headers into key/value pairs paying attention
# to continuation lines. # to continuation lines.
@ -169,7 +162,7 @@ class Message(object):
class Request(Message): class Request(Message):
def __init__(self, cfg, unreader, req_number=1): def __init__(self, cfg, unreader, peer_addr, req_number=1):
self.method = None self.method = None
self.uri = None self.uri = None
self.path = None self.path = None
@ -184,7 +177,7 @@ class Request(Message):
self.req_number = req_number self.req_number = req_number
self.proxy_protocol_info = None self.proxy_protocol_info = None
super().__init__(cfg, unreader) super().__init__(cfg, unreader, peer_addr)
def get_data(self, unreader, buf, stop=False): def get_data(self, unreader, buf, stop=False):
data = unreader.read() data = unreader.read()
@ -280,16 +273,10 @@ class Request(Message):
def proxy_protocol_access_check(self): def proxy_protocol_access_check(self):
# check in allow list # check in allow list
if isinstance(self.unreader, SocketUnreader): if ("*" not in self.cfg.proxy_allow_ips and
try: isinstance(self.peer_addr, tuple) and
remote_host = self.unreader.sock.getpeername()[0] self.peer_addr[0] not in self.cfg.proxy_allow_ips):
except socket.error as e: raise ForbiddenProxyRequest(self.peer_addr[0])
if e.args[0] == ENOTCONN:
raise ForbiddenProxyRequest("UNKNOW")
raise
if ("*" not in self.cfg.proxy_allow_ips and
remote_host not in self.cfg.proxy_allow_ips):
raise ForbiddenProxyRequest(remote_host)
def parse_proxy_protocol(self, line): def parse_proxy_protocol(self, line):
bits = line.split() bits = line.split()

View File

@ -11,13 +11,14 @@ class Parser(object):
mesg_class = None mesg_class = None
def __init__(self, cfg, source): def __init__(self, cfg, source, source_addr):
self.cfg = cfg self.cfg = cfg
if hasattr(source, "recv"): if hasattr(source, "recv"):
self.unreader = SocketUnreader(source) self.unreader = SocketUnreader(source)
else: else:
self.unreader = IterUnreader(source) self.unreader = IterUnreader(source)
self.mesg = None self.mesg = None
self.source_addr = source_addr
# request counter (for keepalive connetions) # request counter (for keepalive connetions)
self.req_count = 0 self.req_count = 0
@ -38,7 +39,7 @@ class Parser(object):
# Parse the next request # Parse the next request
self.req_count += 1 self.req_count += 1
self.mesg = self.mesg_class(self.cfg, self.unreader, self.req_count) self.mesg = self.mesg_class(self.cfg, self.unreader, self.source_addr, self.req_count)
if not self.mesg: if not self.mesg:
raise StopIteration() raise StopIteration()
return self.mesg return self.mesg

View File

@ -33,7 +33,7 @@ class AsyncWorker(base.Worker):
def handle(self, listener, client, addr): def handle(self, listener, client, addr):
req = None req = None
try: try:
parser = http.RequestParser(self.cfg, client) parser = http.RequestParser(self.cfg, client, addr)
try: try:
listener_name = listener.getsockname() listener_name = listener.getsockname()
if not self.cfg.keepalive: if not self.cfg.keepalive:

View File

@ -53,7 +53,7 @@ class TConn(object):
**self.cfg.ssl_options) **self.cfg.ssl_options)
# initialize the parser # initialize the parser
self.parser = http.RequestParser(self.cfg, self.sock) self.parser = http.RequestParser(self.cfg, self.sock, self.client)
def set_timeout(self): def set_timeout(self):
# set the timeout # set the timeout

View File

@ -131,7 +131,7 @@ class SyncWorker(base.Worker):
client = ssl.wrap_socket(client, server_side=True, client = ssl.wrap_socket(client, server_side=True,
**self.cfg.ssl_options) **self.cfg.ssl_options)
parser = http.RequestParser(self.cfg, client) parser = http.RequestParser(self.cfg, client, addr)
req = next(parser) req = next(parser)
self.handle_request(listener, req, client, addr) self.handle_request(listener, req, client, addr)
except http.errors.NoMoreData as e: except http.errors.NoMoreData as e:

View File

@ -29,7 +29,7 @@ class request(object):
def __call__(self, func): def __call__(self, func):
def run(): def run():
src = data_source(self.fname) src = data_source(self.fname)
func(src, RequestParser(src, None)) func(src, RequestParser(src, None, None))
run.func_name = func.func_name run.func_name = func.func_name
return run return run

View File

@ -245,7 +245,7 @@ class request(object):
def check(self, cfg, sender, sizer, matcher): def check(self, cfg, sender, sizer, matcher):
cases = self.expect[:] cases = self.expect[:]
p = RequestParser(cfg, sender()) p = RequestParser(cfg, sender(), None)
for req in p: for req in p:
self.same(req, sizer, matcher, cases.pop(0)) self.same(req, sizer, matcher, cases.pop(0))
assert not cases assert not cases
@ -282,5 +282,5 @@ class badrequest(object):
read += chunk read += chunk
def check(self, cfg): def check(self, cfg):
p = RequestParser(cfg, self.send()) p = RequestParser(cfg, self.send(), None)
next(p) next(p)