mirror of
https://github.com/frappe/gunicorn.git
synced 2026-01-14 11:09:11 +08:00
Capture peer name from accept
Avoid calls to getpeername by capturing the peer name returned by accept.
This commit is contained in:
parent
548d5828da
commit
3573fd38d0
@ -6,9 +6,7 @@
|
|||||||
import io
|
import io
|
||||||
import re
|
import re
|
||||||
import socket
|
import socket
|
||||||
from errno import ENOTCONN
|
|
||||||
|
|
||||||
from gunicorn.http.unreader import SocketUnreader
|
|
||||||
from gunicorn.http.body import ChunkedReader, LengthReader, EOFReader, Body
|
from gunicorn.http.body import ChunkedReader, LengthReader, EOFReader, Body
|
||||||
from gunicorn.http.errors import (
|
from gunicorn.http.errors import (
|
||||||
InvalidHeader, InvalidHeaderName, NoMoreData,
|
InvalidHeader, InvalidHeaderName, NoMoreData,
|
||||||
@ -29,9 +27,10 @@ VERSION_RE = re.compile(r"HTTP/(\d+)\.(\d+)")
|
|||||||
|
|
||||||
|
|
||||||
class Message(object):
|
class Message(object):
|
||||||
def __init__(self, cfg, unreader):
|
def __init__(self, cfg, unreader, peer_addr):
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.unreader = unreader
|
self.unreader = unreader
|
||||||
|
self.peer_addr = peer_addr
|
||||||
self.version = None
|
self.version = None
|
||||||
self.headers = []
|
self.headers = []
|
||||||
self.trailers = []
|
self.trailers = []
|
||||||
@ -69,16 +68,10 @@ class Message(object):
|
|||||||
# handle scheme headers
|
# handle scheme headers
|
||||||
scheme_header = False
|
scheme_header = False
|
||||||
secure_scheme_headers = {}
|
secure_scheme_headers = {}
|
||||||
if '*' in cfg.forwarded_allow_ips:
|
if ('*' in cfg.forwarded_allow_ips or
|
||||||
|
not isinstance(self.peer_addr, tuple)
|
||||||
|
or self.peer_addr[0] in cfg.forwarded_allow_ips):
|
||||||
secure_scheme_headers = cfg.secure_scheme_headers
|
secure_scheme_headers = cfg.secure_scheme_headers
|
||||||
elif isinstance(self.unreader, SocketUnreader):
|
|
||||||
remote_addr = self.unreader.sock.getpeername()
|
|
||||||
if self.unreader.sock.family in (socket.AF_INET, socket.AF_INET6):
|
|
||||||
remote_host = remote_addr[0]
|
|
||||||
if remote_host in cfg.forwarded_allow_ips:
|
|
||||||
secure_scheme_headers = cfg.secure_scheme_headers
|
|
||||||
elif self.unreader.sock.family == socket.AF_UNIX:
|
|
||||||
secure_scheme_headers = cfg.secure_scheme_headers
|
|
||||||
|
|
||||||
# Parse headers into key/value pairs paying attention
|
# Parse headers into key/value pairs paying attention
|
||||||
# to continuation lines.
|
# to continuation lines.
|
||||||
@ -169,7 +162,7 @@ class Message(object):
|
|||||||
|
|
||||||
|
|
||||||
class Request(Message):
|
class Request(Message):
|
||||||
def __init__(self, cfg, unreader, req_number=1):
|
def __init__(self, cfg, unreader, peer_addr, req_number=1):
|
||||||
self.method = None
|
self.method = None
|
||||||
self.uri = None
|
self.uri = None
|
||||||
self.path = None
|
self.path = None
|
||||||
@ -184,7 +177,7 @@ class Request(Message):
|
|||||||
|
|
||||||
self.req_number = req_number
|
self.req_number = req_number
|
||||||
self.proxy_protocol_info = None
|
self.proxy_protocol_info = None
|
||||||
super().__init__(cfg, unreader)
|
super().__init__(cfg, unreader, peer_addr)
|
||||||
|
|
||||||
def get_data(self, unreader, buf, stop=False):
|
def get_data(self, unreader, buf, stop=False):
|
||||||
data = unreader.read()
|
data = unreader.read()
|
||||||
@ -280,16 +273,10 @@ class Request(Message):
|
|||||||
|
|
||||||
def proxy_protocol_access_check(self):
|
def proxy_protocol_access_check(self):
|
||||||
# check in allow list
|
# check in allow list
|
||||||
if isinstance(self.unreader, SocketUnreader):
|
if ("*" not in self.cfg.proxy_allow_ips and
|
||||||
try:
|
isinstance(self.peer_addr, tuple) and
|
||||||
remote_host = self.unreader.sock.getpeername()[0]
|
self.peer_addr[0] not in self.cfg.proxy_allow_ips):
|
||||||
except socket.error as e:
|
raise ForbiddenProxyRequest(self.peer_addr[0])
|
||||||
if e.args[0] == ENOTCONN:
|
|
||||||
raise ForbiddenProxyRequest("UNKNOW")
|
|
||||||
raise
|
|
||||||
if ("*" not in self.cfg.proxy_allow_ips and
|
|
||||||
remote_host not in self.cfg.proxy_allow_ips):
|
|
||||||
raise ForbiddenProxyRequest(remote_host)
|
|
||||||
|
|
||||||
def parse_proxy_protocol(self, line):
|
def parse_proxy_protocol(self, line):
|
||||||
bits = line.split()
|
bits = line.split()
|
||||||
|
|||||||
@ -11,13 +11,14 @@ class Parser(object):
|
|||||||
|
|
||||||
mesg_class = None
|
mesg_class = None
|
||||||
|
|
||||||
def __init__(self, cfg, source):
|
def __init__(self, cfg, source, source_addr):
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
if hasattr(source, "recv"):
|
if hasattr(source, "recv"):
|
||||||
self.unreader = SocketUnreader(source)
|
self.unreader = SocketUnreader(source)
|
||||||
else:
|
else:
|
||||||
self.unreader = IterUnreader(source)
|
self.unreader = IterUnreader(source)
|
||||||
self.mesg = None
|
self.mesg = None
|
||||||
|
self.source_addr = source_addr
|
||||||
|
|
||||||
# request counter (for keepalive connetions)
|
# request counter (for keepalive connetions)
|
||||||
self.req_count = 0
|
self.req_count = 0
|
||||||
@ -38,7 +39,7 @@ class Parser(object):
|
|||||||
|
|
||||||
# Parse the next request
|
# Parse the next request
|
||||||
self.req_count += 1
|
self.req_count += 1
|
||||||
self.mesg = self.mesg_class(self.cfg, self.unreader, self.req_count)
|
self.mesg = self.mesg_class(self.cfg, self.unreader, self.source_addr, self.req_count)
|
||||||
if not self.mesg:
|
if not self.mesg:
|
||||||
raise StopIteration()
|
raise StopIteration()
|
||||||
return self.mesg
|
return self.mesg
|
||||||
|
|||||||
@ -33,7 +33,7 @@ class AsyncWorker(base.Worker):
|
|||||||
def handle(self, listener, client, addr):
|
def handle(self, listener, client, addr):
|
||||||
req = None
|
req = None
|
||||||
try:
|
try:
|
||||||
parser = http.RequestParser(self.cfg, client)
|
parser = http.RequestParser(self.cfg, client, addr)
|
||||||
try:
|
try:
|
||||||
listener_name = listener.getsockname()
|
listener_name = listener.getsockname()
|
||||||
if not self.cfg.keepalive:
|
if not self.cfg.keepalive:
|
||||||
|
|||||||
@ -53,7 +53,7 @@ class TConn(object):
|
|||||||
**self.cfg.ssl_options)
|
**self.cfg.ssl_options)
|
||||||
|
|
||||||
# initialize the parser
|
# initialize the parser
|
||||||
self.parser = http.RequestParser(self.cfg, self.sock)
|
self.parser = http.RequestParser(self.cfg, self.sock, self.client)
|
||||||
|
|
||||||
def set_timeout(self):
|
def set_timeout(self):
|
||||||
# set the timeout
|
# set the timeout
|
||||||
|
|||||||
@ -131,7 +131,7 @@ class SyncWorker(base.Worker):
|
|||||||
client = ssl.wrap_socket(client, server_side=True,
|
client = ssl.wrap_socket(client, server_side=True,
|
||||||
**self.cfg.ssl_options)
|
**self.cfg.ssl_options)
|
||||||
|
|
||||||
parser = http.RequestParser(self.cfg, client)
|
parser = http.RequestParser(self.cfg, client, addr)
|
||||||
req = next(parser)
|
req = next(parser)
|
||||||
self.handle_request(listener, req, client, addr)
|
self.handle_request(listener, req, client, addr)
|
||||||
except http.errors.NoMoreData as e:
|
except http.errors.NoMoreData as e:
|
||||||
|
|||||||
@ -29,7 +29,7 @@ class request(object):
|
|||||||
def __call__(self, func):
|
def __call__(self, func):
|
||||||
def run():
|
def run():
|
||||||
src = data_source(self.fname)
|
src = data_source(self.fname)
|
||||||
func(src, RequestParser(src, None))
|
func(src, RequestParser(src, None, None))
|
||||||
run.func_name = func.func_name
|
run.func_name = func.func_name
|
||||||
return run
|
return run
|
||||||
|
|
||||||
|
|||||||
@ -245,7 +245,7 @@ class request(object):
|
|||||||
|
|
||||||
def check(self, cfg, sender, sizer, matcher):
|
def check(self, cfg, sender, sizer, matcher):
|
||||||
cases = self.expect[:]
|
cases = self.expect[:]
|
||||||
p = RequestParser(cfg, sender())
|
p = RequestParser(cfg, sender(), None)
|
||||||
for req in p:
|
for req in p:
|
||||||
self.same(req, sizer, matcher, cases.pop(0))
|
self.same(req, sizer, matcher, cases.pop(0))
|
||||||
assert not cases
|
assert not cases
|
||||||
@ -282,5 +282,5 @@ class badrequest(object):
|
|||||||
read += chunk
|
read += chunk
|
||||||
|
|
||||||
def check(self, cfg):
|
def check(self, cfg):
|
||||||
p = RequestParser(cfg, self.send())
|
p = RequestParser(cfg, self.send(), None)
|
||||||
next(p)
|
next(p)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user