Fix multiple issues with request limits

Patch from Djoume Salvetti. It addresses the following issues in gunicorn:

* Gunicorn does not limit the size of a request header (the
* limit_request_field_size configuration parameter is not used)

* When the configured request limit is lower than its maximum value, the
* maximum value is used instead. For instance if limit_request_line is
* set to 1024, gunicorn will only limit the request line to 4096 chars
* (this issue also affects limit_request_fields)

* Request limits are not capped at their maximum authorized values. For
* instance, it is possible to set limit_request_line to 64K (this issue
* also affects limit_request_fields)

* Setting limit_request_fields and limit_request_field_size to 0 does
* not make them unlimited. The following patch allows limit_request_line
* and limit_request_field_size to be unlimited. limit_request_fields can
* no longer be unlimited (I can't imagine 32K fields not being enough,
* but I have a use case where 8K for the request line is not enough).

* Parsing errors (premature client disconnection) are not reported

* When request line limit is exceeded the configured value is reported
* instead of the effective value.
This commit is contained in:
benoitc 2012-05-24 12:13:34 +02:00
parent 124963249a
commit d79ff999ce
19 changed files with 123 additions and 50 deletions

View File

@ -449,8 +449,8 @@ class LimitRequestLine(Setting):
restriction on the length of a request-URI allowed for a request restriction on the length of a request-URI allowed for a request
on the server. A server needs this value to be large enough to on the server. A server needs this value to be large enough to
hold any of its resource names, including any information that hold any of its resource names, including any information that
might be passed in the query part of a GET request. By default might be passed in the query part of a GET request. Value is a number
this value is 4094 and can't be larger than 8190. from 0 (unlimited) to 8190.
This parameter can be used to prevent any DDOS attack. This parameter can be used to prevent any DDOS attack.
""" """
@ -466,10 +466,10 @@ class LimitRequestFields(Setting):
desc= """\ desc= """\
Limit the number of HTTP headers fields in a request. Limit the number of HTTP headers fields in a request.
Value is a number from 0 (unlimited) to 32768. This parameter is This parameter is used to limit the number of headers in a request to
used to limit the number of headers in a request to prevent DDOS prevent DDOS attack. Used with the `limit_request_field_size` it allows
attack. Used with the `limit_request_field_size` it allows more more safety. By default this value is 100 and can't be larger than
safety. 32768.
""" """
class LimitRequestFieldSize(Setting): class LimitRequestFieldSize(Setting):

View File

@ -6,7 +6,7 @@
class ParseException(Exception): class ParseException(Exception):
pass pass
class NoMoreData(ParseException, StopIteration): class NoMoreData(ParseException):
def __init__(self, buf=None): def __init__(self, buf=None):
self.buf = buf self.buf = buf
def __str__(self): def __str__(self):

View File

@ -32,17 +32,19 @@ class Message(object):
self.hdrre = re.compile("[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\\\\"]") self.hdrre = re.compile("[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\\\\"]")
# set headers limits # set headers limits
self.limit_request_fields = max(cfg.limit_request_fields, MAX_HEADERS) self.limit_request_fields = cfg.limit_request_fields
if self.limit_request_fields <= 0: if (self.limit_request_fields <= 0
or self.limit_request_fields > MAX_HEADERS):
self.limit_request_fields = MAX_HEADERS self.limit_request_fields = MAX_HEADERS
self.limit_request_field_size = max(cfg.limit_request_field_size, self.limit_request_field_size = cfg.limit_request_field_size
MAX_HEADERFIELD_SIZE) if (self.limit_request_field_size < 0
if self.limit_request_field_size <= 0: or self.limit_request_field_size > MAX_HEADERFIELD_SIZE):
self.limit_request_field_size = MAX_HEADERFIELD_SIZE self.limit_request_field_size = MAX_HEADERFIELD_SIZE
# set max header buffer size # set max header buffer size
max_header_field_size = self.limit_request_field_size or MAX_HEADERFIELD_SIZE
self.max_buffer_headers = self.limit_request_fields * \ self.max_buffer_headers = self.limit_request_fields * \
(self.limit_request_field_size + 2) + 4 (max_header_field_size + 2) + 4
unused = self.parse(self.unreader) unused = self.parse(self.unreader)
self.unreader.unread(unused) self.unreader.unread(unused)
@ -60,11 +62,12 @@ class Message(object):
# Parse headers into key/value pairs paying attention # Parse headers into key/value pairs paying attention
# to continuation lines. # to continuation lines.
while len(lines): while len(lines):
if len(headers) > self.limit_request_fields: if len(headers) >= self.limit_request_fields:
raise LimitRequestHeaders("limit request headers fields") raise LimitRequestHeaders("limit request headers fields")
# Parse initial header name : value pair. # Parse initial header name : value pair.
curr = lines.pop(0) curr = lines.pop(0)
header_length = len(curr)
if curr.find(":") < 0: if curr.find(":") < 0:
raise InvalidHeader(curr.strip()) raise InvalidHeader(curr.strip())
name, value = curr.split(":", 1) name, value = curr.split(":", 1)
@ -76,9 +79,13 @@ class Message(object):
# Consume value continuation lines # Consume value continuation lines
while len(lines) and lines[0].startswith((" ", "\t")): while len(lines) and lines[0].startswith((" ", "\t")):
value.append(lines.pop(0)) curr = lines.pop(0)
header_length += len(curr)
value.append(curr)
value = ''.join(value).rstrip() value = ''.join(value).rstrip()
if header_length > self.limit_request_field_size > 0:
raise LimitRequestHeaders("limit request headers fields size")
headers.append((name, value)) headers.append((name, value))
return headers return headers
@ -130,9 +137,9 @@ class Request(Message):
self.fragment = None self.fragment = None
# get max request line size # get max request line size
self.limit_request_line = max(cfg.limit_request_line, self.limit_request_line = cfg.limit_request_line
MAX_REQUEST_LINE) if (self.limit_request_line < 0
if self.limit_request_line <= 0: or self.limit_request_line >= MAX_REQUEST_LINE):
self.limit_request_line = MAX_REQUEST_LINE self.limit_request_line = MAX_REQUEST_LINE
super(Request, self).__init__(cfg, unreader) super(Request, self).__init__(cfg, unreader)
@ -158,8 +165,8 @@ class Request(Message):
self.get_data(unreader, buf) self.get_data(unreader, buf)
data = buf.getvalue() data = buf.getvalue()
if len(data) - 2 > self.limit_request_line: if len(data) - 2 > self.limit_request_line > 0:
raise LimitRequestLine(len(data), self.cfg.limit_request_line) raise LimitRequestLine(len(data), self.limit_request_line)
self.parse_request_line(data[:idx]) self.parse_request_line(data[:idx])
buf = StringIO() buf = StringIO()

View File

@ -37,6 +37,8 @@ class AsyncWorker(base.Worker):
if not req: if not req:
break break
self.handle_request(req, client, addr) self.handle_request(req, client, addr)
except http.errors.NoMoreData, e:
self.log.debug("Ignored premature client disconnection. %s", e)
except StopIteration, e: except StopIteration, e:
self.log.debug("Closing connection. %s", e) self.log.debug("Closing connection. %s", e)
except Exception, e: except Exception, e:

View File

@ -70,6 +70,8 @@ class SyncWorker(base.Worker):
parser = http.RequestParser(self.cfg, client) parser = http.RequestParser(self.cfg, client)
req = parser.next() req = parser.next()
self.handle_request(req, client, addr) self.handle_request(req, client, addr)
except http.errors.NoMoreData, e:
self.log.debug("Ignored premature client disconnection. %s", e)
except StopIteration, e: except StopIteration, e:
self.log.debug("Closing connection. %s", e) self.log.debug("Closing connection. %s", e)
except socket.error, e: except socket.error, e:

View File

@ -12,17 +12,21 @@ dirname = os.path.dirname(__file__)
reqdir = os.path.join(dirname, "requests", "valid") reqdir = os.path.join(dirname, "requests", "valid")
def a_case(fname): def a_case(fname):
expect = treq.load_py(os.path.splitext(fname)[0] + ".py") env = treq.load_py(os.path.splitext(fname)[0] + ".py")
expect = env['request']
cfg = env['cfg']
req = treq.request(fname, expect) req = treq.request(fname, expect)
for case in req.gen_cases(): for case in req.gen_cases(cfg):
case[0](*case[1:]) case[0](*case[1:])
def test_http_parser(): def test_http_parser():
for fname in glob.glob(os.path.join(reqdir, "*.http")): for fname in glob.glob(os.path.join(reqdir, "*.http")):
if os.getenv("GUNS_BLAZING"): if os.getenv("GUNS_BLAZING"):
expect = treq.load_py(os.path.splitext(fname)[0] + ".py") env = treq.load_py(os.path.splitext(fname)[0] + ".py")
expect = env['request']
cfg = env['cfg']
req = treq.request(fname, expect) req = treq.request(fname, expect)
for case in req.gen_cases(): for case in req.gen_cases(cfg):
yield case yield case
else: else:
yield (a_case, fname) yield (a_case, fname)

View File

@ -8,11 +8,21 @@ import treq
import glob import glob
import os import os
from nose.tools import raises
dirname = os.path.dirname(__file__) dirname = os.path.dirname(__file__)
reqdir = os.path.join(dirname, "requests", "invalid") reqdir = os.path.join(dirname, "requests", "invalid")
def test_http_parser(): def test_http_parser():
for fname in glob.glob(os.path.join(reqdir, "*.http")): for fname in glob.glob(os.path.join(reqdir, "*.http")):
expect = treq.load_py(os.path.splitext(fname)[0] + ".py") env = treq.load_py(os.path.splitext(fname)[0] + ".py")
req = treq.badrequest(fname, expect) expect = env['request']
yield (req.check,) cfg = env['cfg']
req = treq.badrequest(fname)
@raises(expect)
def check(fname):
return req.check(cfg)
yield check, fname # fname is pass so that we know which test failed

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,3 @@
GET /test HTTP/1.1\r\n
Accept: */*\r\n
\r\n

View File

@ -0,0 +1,6 @@
from gunicorn.config import Config
from gunicorn.http.errors import LimitRequestHeaders
request = LimitRequestHeaders
cfg = Config()
cfg.set('limit_request_field_size', 10)

View File

@ -0,0 +1,5 @@
GET /test HTTP/1.1\r\n
User-Agent: curl/7.18.0 (i486-pc-linux-gnu) libcurl/7.18.0 OpenSSL/0.9.8g zlib/1.2.3.3 libidn/1.1\r\n
Host: 0.0.0.0=5000\r\n
Accept: */*\r\n
\r\n

View File

@ -0,0 +1,6 @@
from gunicorn.config import Config
from gunicorn.http.errors import LimitRequestHeaders
request = LimitRequestHeaders
cfg = Config()
cfg.set('limit_request_fields', 2)

View File

@ -0,0 +1,5 @@
GET /test HTTP/1.1\r\n
User-Agent: curl/7.18.0 (i486-pc-linux-gnu) libcurl/7.18.0 OpenSSL/0.9.8g zlib/1.2.3.3 libidn/1.1\r\n
Host: 0.0.0.0=5000\r\n
Accept: */*\r\n
\r\n

View File

@ -0,0 +1,6 @@
from gunicorn.config import Config
from gunicorn.http.errors import LimitRequestHeaders
request = LimitRequestHeaders
cfg = Config()
cfg.set('limit_request_field_size', 98)

View File

@ -0,0 +1,4 @@
GET /test HTTP/1.1\r\n
Accept:\r\n
*/*\r\n
\r\n

View File

@ -0,0 +1,6 @@
from gunicorn.config import Config
from gunicorn.http.errors import LimitRequestHeaders
request = LimitRequestHeaders
cfg = Config()
cfg.set('limit_request_field_size', 14)

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -41,8 +41,9 @@ def uri(data):
def load_py(fname): def load_py(fname):
config = globals().copy() config = globals().copy()
config["uri"] = uri config["uri"] = uri
config["cfg"] = Config()
execfile(fname, config) execfile(fname, config)
return config["request"] return config
class request(object): class request(object):
def __init__(self, fname, expect): def __init__(self, fname, expect):
@ -198,7 +199,7 @@ class request(object):
# Construct a series of test cases from the permutations of # Construct a series of test cases from the permutations of
# send, size, and match functions. # send, size, and match functions.
def gen_cases(self): def gen_cases(self, cfg):
def get_funs(p): def get_funs(p):
return [v for k, v in inspect.getmembers(self) if k.startswith(p)] return [v for k, v in inspect.getmembers(self) if k.startswith(p)]
senders = get_funs("send_") senders = get_funs("send_")
@ -217,15 +218,15 @@ class request(object):
szn = sz.func_name[5:] szn = sz.func_name[5:]
snn = sn.func_name[5:] snn = sn.func_name[5:]
def test_req(sn, sz, mt): def test_req(sn, sz, mt):
self.check(sn, sz, mt) self.check(cfg, sn, sz, mt)
desc = "%s: MT: %s SZ: %s SN: %s" % (self.name, mtn, szn, snn) desc = "%s: MT: %s SZ: %s SN: %s" % (self.name, mtn, szn, snn)
test_req.description = desc test_req.description = desc
ret.append((test_req, sn, sz, mt)) ret.append((test_req, sn, sz, mt))
return ret return ret
def check(self, sender, sizer, matcher): def check(self, cfg, sender, sizer, matcher):
cases = self.expect[:] cases = self.expect[:]
p = RequestParser(Config(), sender()) p = RequestParser(cfg, sender())
for req in p: for req in p:
self.same(req, sizer, matcher, cases.pop(0)) self.same(req, sizer, matcher, cases.pop(0))
t.eq(len(cases), 0) t.eq(len(cases), 0)
@ -242,14 +243,10 @@ class request(object):
t.eq(req.trailers, exp.get("trailers", [])) t.eq(req.trailers, exp.get("trailers", []))
class badrequest(object): class badrequest(object):
def __init__(self, fname, expect): def __init__(self, fname):
self.fname = fname self.fname = fname
self.name = os.path.basename(fname) self.name = os.path.basename(fname)
self.expect = expect
if not isinstance(self.expect, list):
self.expect = [self.expect]
with open(self.fname) as handle: with open(self.fname) as handle:
self.data = handle.read() self.data = handle.read()
self.data = self.data.replace("\n", "").replace("\\r\\n", "\r\n") self.data = self.data.replace("\n", "").replace("\\r\\n", "\r\n")
@ -263,16 +260,7 @@ class badrequest(object):
yield self.data[read:read+chunk] yield self.data[read:read+chunk]
read += chunk read += chunk
def check(self): def check(self, cfg):
cases = self.expect[:] p = RequestParser(cfg, self.send())
p = RequestParser(Config(), self.send())
try:
[req for req in p] [req for req in p]
except Exception, inst:
exp = cases.pop(0)
if not issubclass(exp, Exception):
raise TypeError("Test case is not an exception calss: %s" % exp)
t.istype(inst, exp)
return