mirror of
https://github.com/frappe/gunicorn.git
synced 2026-01-14 11:09:11 +08:00
Capture peer name from accept
Avoid calls to getpeername by capturing the peer name returned by accept.
This commit is contained in:
parent
548d5828da
commit
3573fd38d0
@ -6,9 +6,7 @@
|
||||
import io
|
||||
import re
|
||||
import socket
|
||||
from errno import ENOTCONN
|
||||
|
||||
from gunicorn.http.unreader import SocketUnreader
|
||||
from gunicorn.http.body import ChunkedReader, LengthReader, EOFReader, Body
|
||||
from gunicorn.http.errors import (
|
||||
InvalidHeader, InvalidHeaderName, NoMoreData,
|
||||
@ -29,9 +27,10 @@ VERSION_RE = re.compile(r"HTTP/(\d+)\.(\d+)")
|
||||
|
||||
|
||||
class Message(object):
|
||||
def __init__(self, cfg, unreader):
|
||||
def __init__(self, cfg, unreader, peer_addr):
|
||||
self.cfg = cfg
|
||||
self.unreader = unreader
|
||||
self.peer_addr = peer_addr
|
||||
self.version = None
|
||||
self.headers = []
|
||||
self.trailers = []
|
||||
@ -69,16 +68,10 @@ class Message(object):
|
||||
# handle scheme headers
|
||||
scheme_header = False
|
||||
secure_scheme_headers = {}
|
||||
if '*' in cfg.forwarded_allow_ips:
|
||||
if ('*' in cfg.forwarded_allow_ips or
|
||||
not isinstance(self.peer_addr, tuple)
|
||||
or self.peer_addr[0] in cfg.forwarded_allow_ips):
|
||||
secure_scheme_headers = cfg.secure_scheme_headers
|
||||
elif isinstance(self.unreader, SocketUnreader):
|
||||
remote_addr = self.unreader.sock.getpeername()
|
||||
if self.unreader.sock.family in (socket.AF_INET, socket.AF_INET6):
|
||||
remote_host = remote_addr[0]
|
||||
if remote_host in cfg.forwarded_allow_ips:
|
||||
secure_scheme_headers = cfg.secure_scheme_headers
|
||||
elif self.unreader.sock.family == socket.AF_UNIX:
|
||||
secure_scheme_headers = cfg.secure_scheme_headers
|
||||
|
||||
# Parse headers into key/value pairs paying attention
|
||||
# to continuation lines.
|
||||
@ -169,7 +162,7 @@ class Message(object):
|
||||
|
||||
|
||||
class Request(Message):
|
||||
def __init__(self, cfg, unreader, req_number=1):
|
||||
def __init__(self, cfg, unreader, peer_addr, req_number=1):
|
||||
self.method = None
|
||||
self.uri = None
|
||||
self.path = None
|
||||
@ -184,7 +177,7 @@ class Request(Message):
|
||||
|
||||
self.req_number = req_number
|
||||
self.proxy_protocol_info = None
|
||||
super().__init__(cfg, unreader)
|
||||
super().__init__(cfg, unreader, peer_addr)
|
||||
|
||||
def get_data(self, unreader, buf, stop=False):
|
||||
data = unreader.read()
|
||||
@ -280,16 +273,10 @@ class Request(Message):
|
||||
|
||||
def proxy_protocol_access_check(self):
|
||||
# check in allow list
|
||||
if isinstance(self.unreader, SocketUnreader):
|
||||
try:
|
||||
remote_host = self.unreader.sock.getpeername()[0]
|
||||
except socket.error as e:
|
||||
if e.args[0] == ENOTCONN:
|
||||
raise ForbiddenProxyRequest("UNKNOW")
|
||||
raise
|
||||
if ("*" not in self.cfg.proxy_allow_ips and
|
||||
remote_host not in self.cfg.proxy_allow_ips):
|
||||
raise ForbiddenProxyRequest(remote_host)
|
||||
if ("*" not in self.cfg.proxy_allow_ips and
|
||||
isinstance(self.peer_addr, tuple) and
|
||||
self.peer_addr[0] not in self.cfg.proxy_allow_ips):
|
||||
raise ForbiddenProxyRequest(self.peer_addr[0])
|
||||
|
||||
def parse_proxy_protocol(self, line):
|
||||
bits = line.split()
|
||||
|
||||
@ -11,13 +11,14 @@ class Parser(object):
|
||||
|
||||
mesg_class = None
|
||||
|
||||
def __init__(self, cfg, source):
|
||||
def __init__(self, cfg, source, source_addr):
|
||||
self.cfg = cfg
|
||||
if hasattr(source, "recv"):
|
||||
self.unreader = SocketUnreader(source)
|
||||
else:
|
||||
self.unreader = IterUnreader(source)
|
||||
self.mesg = None
|
||||
self.source_addr = source_addr
|
||||
|
||||
# request counter (for keepalive connetions)
|
||||
self.req_count = 0
|
||||
@ -38,7 +39,7 @@ class Parser(object):
|
||||
|
||||
# Parse the next request
|
||||
self.req_count += 1
|
||||
self.mesg = self.mesg_class(self.cfg, self.unreader, self.req_count)
|
||||
self.mesg = self.mesg_class(self.cfg, self.unreader, self.source_addr, self.req_count)
|
||||
if not self.mesg:
|
||||
raise StopIteration()
|
||||
return self.mesg
|
||||
|
||||
@ -33,7 +33,7 @@ class AsyncWorker(base.Worker):
|
||||
def handle(self, listener, client, addr):
|
||||
req = None
|
||||
try:
|
||||
parser = http.RequestParser(self.cfg, client)
|
||||
parser = http.RequestParser(self.cfg, client, addr)
|
||||
try:
|
||||
listener_name = listener.getsockname()
|
||||
if not self.cfg.keepalive:
|
||||
|
||||
@ -53,7 +53,7 @@ class TConn(object):
|
||||
**self.cfg.ssl_options)
|
||||
|
||||
# initialize the parser
|
||||
self.parser = http.RequestParser(self.cfg, self.sock)
|
||||
self.parser = http.RequestParser(self.cfg, self.sock, self.client)
|
||||
|
||||
def set_timeout(self):
|
||||
# set the timeout
|
||||
|
||||
@ -131,7 +131,7 @@ class SyncWorker(base.Worker):
|
||||
client = ssl.wrap_socket(client, server_side=True,
|
||||
**self.cfg.ssl_options)
|
||||
|
||||
parser = http.RequestParser(self.cfg, client)
|
||||
parser = http.RequestParser(self.cfg, client, addr)
|
||||
req = next(parser)
|
||||
self.handle_request(listener, req, client, addr)
|
||||
except http.errors.NoMoreData as e:
|
||||
|
||||
@ -29,7 +29,7 @@ class request(object):
|
||||
def __call__(self, func):
|
||||
def run():
|
||||
src = data_source(self.fname)
|
||||
func(src, RequestParser(src, None))
|
||||
func(src, RequestParser(src, None, None))
|
||||
run.func_name = func.func_name
|
||||
return run
|
||||
|
||||
|
||||
@ -245,7 +245,7 @@ class request(object):
|
||||
|
||||
def check(self, cfg, sender, sizer, matcher):
|
||||
cases = self.expect[:]
|
||||
p = RequestParser(cfg, sender())
|
||||
p = RequestParser(cfg, sender(), None)
|
||||
for req in p:
|
||||
self.same(req, sizer, matcher, cases.pop(0))
|
||||
assert not cases
|
||||
@ -282,5 +282,5 @@ class badrequest(object):
|
||||
read += chunk
|
||||
|
||||
def check(self, cfg):
|
||||
p = RequestParser(cfg, self.send())
|
||||
p = RequestParser(cfg, self.send(), None)
|
||||
next(p)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user