gunicorn/gunicorn/http/message.py
2012-09-27 19:14:40 +02:00

348 lines
11 KiB
Python

# -*- coding: utf-8 -*-
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
import re
import urlparse
import socket
from errno import ENOTCONN
try:
from cStringIO import StringIO
except ImportError:
from StringIO import StringIO
from gunicorn.http.unreader import SocketUnreader
from gunicorn.http.body import ChunkedReader, LengthReader, EOFReader, Body
from gunicorn.http.errors import InvalidHeader, InvalidHeaderName, NoMoreData, \
InvalidRequestLine, InvalidRequestMethod, InvalidHTTPVersion, \
LimitRequestLine, LimitRequestHeaders
from gunicorn.http.errors import InvalidProxyLine, ForbiddenProxyRequest
# Ceilings for the configurable request limits.  A configured value that is
# out of range is replaced by the matching constant below.

# Ceiling for cfg.limit_request_line (length of the request line).
MAX_REQUEST_LINE = 8190
# Ceiling for cfg.limit_request_fields (number of header fields).
MAX_HEADERS = 32768
# Ceiling for cfg.limit_request_field_size (length of one header field,
# continuation lines included).
MAX_HEADERFIELD_SIZE = 8190
class Message(object):
    """Base class for an HTTP message read from an unreader.

    The constructor normalises the header limits found in ``cfg``, calls
    ``parse()`` (which subclasses must implement), hands whatever
    ``parse()`` returns back to the unreader, and finally selects a body
    reader from the parsed headers.
    """

    # Characters that are rejected in a header field name: control
    # characters, DEL, separators, space and tab.  Written as a raw string
    # (no invalid string escapes) and compiled once for the class instead
    # of once per message.
    hdrre = re.compile(r'[\x00-\x1F\x7F()<>@,;:\[\]={} \t\\"]')

    def __init__(self, cfg, unreader):
        """Parse a message from ``unreader`` using the limits in ``cfg``.

        :param cfg: configuration object providing ``limit_request_fields``
            and ``limit_request_field_size``.
        :param unreader: source of raw data offering ``read()`` and
            ``unread()``.
        """
        self.cfg = cfg
        self.unreader = unreader
        self.version = None
        self.headers = []
        self.trailers = []
        self.body = None

        # set headers limits
        # A non-positive or too large field count falls back to MAX_HEADERS.
        self.limit_request_fields = cfg.limit_request_fields
        if (self.limit_request_fields <= 0
                or self.limit_request_fields > MAX_HEADERS):
            self.limit_request_fields = MAX_HEADERS

        # 0 is kept as-is: the size checks in parse_headers() only apply
        # when the limit is > 0.
        self.limit_request_field_size = cfg.limit_request_field_size
        if (self.limit_request_field_size < 0
                or self.limit_request_field_size > MAX_HEADERFIELD_SIZE):
            self.limit_request_field_size = MAX_HEADERFIELD_SIZE

        # set max header buffer size
        # Each field may carry a trailing "\r\n" (+2); the header section
        # ends with "\r\n\r\n" (+4).
        max_header_field_size = (self.limit_request_field_size
                                 or MAX_HEADERFIELD_SIZE)
        self.max_buffer_headers = self.limit_request_fields * \
            (max_header_field_size + 2) + 4

        unused = self.parse(self.unreader)
        self.unreader.unread(unused)
        self.set_body_reader()

    def parse(self, unreader):
        """Parse the message from ``unreader``; must be overridden.

        ``__init__`` calls this with the unreader and pushes the returned
        data back with ``unreader.unread()``.  The signature accepts the
        unreader so that an un-overridden call raises
        ``NotImplementedError`` rather than ``TypeError``.

        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def parse_headers(self, data):
        """Turn a raw header section into a list of ``(NAME, value)``.

        :param data: header lines separated by ``\\r\\n`` (without the
            terminating blank line).
        :return: list of tuples; names are upper-cased, values have
            continuation lines joined and trailing whitespace removed.
        :raises LimitRequestHeaders: too many fields, or a field longer
            than ``limit_request_field_size`` (when that limit is > 0).
        :raises InvalidHeader: a line without ``:``.
        :raises InvalidHeaderName: a name containing a forbidden character.
        """
        headers = []

        # Split lines on \r\n keeping the \r\n on each line
        lines = [line + "\r\n" for line in data.split("\r\n")]
        total = len(lines)

        # Parse headers into key/value pairs paying attention
        # to continuation lines.  Lines are walked by index: popping from
        # the front of the list would be quadratic in the number of fields.
        pos = 0
        while pos < total:
            if len(headers) >= self.limit_request_fields:
                raise LimitRequestHeaders("limit request headers fields")

            # Parse initial header name : value pair.
            curr = lines[pos]
            pos += 1
            header_length = len(curr)
            if ":" not in curr:
                raise InvalidHeader(curr.strip())
            name, value = curr.split(":", 1)
            name = name.rstrip(" \t").upper()
            if self.hdrre.search(name):
                raise InvalidHeaderName(name)

            name, value = name.strip(), [value.lstrip()]

            # Consume value continuation lines
            while pos < total and lines[pos].startswith((" ", "\t")):
                curr = lines[pos]
                pos += 1
                header_length += len(curr)
                if header_length > self.limit_request_field_size > 0:
                    raise LimitRequestHeaders("limit request headers "
                                              + "fields size")
                value.append(curr)
            value = ''.join(value).rstrip()

            if header_length > self.limit_request_field_size > 0:
                raise LimitRequestHeaders("limit request headers fields size")
            headers.append((name, value))
        return headers

    def set_body_reader(self):
        """Choose ``self.body`` from the parsed headers.

        A ``TRANSFER-ENCODING`` value equal to ``chunked`` (case
        insensitive) selects a chunked reader; otherwise a content length
        (from ``CONTENT-LENGTH``, or 8 when ``SEC-WEBSOCKET-KEY1`` is
        present) selects a length reader; otherwise the body is read
        until EOF.  When a header occurs several times the last one wins.

        :raises InvalidHeader: the content length is not a non-negative
            integer.
        """
        chunked = False
        content_length = None
        for (name, value) in self.headers:
            if name == "CONTENT-LENGTH":
                content_length = value
            elif name == "TRANSFER-ENCODING":
                chunked = value.lower() == "chunked"
            elif name == "SEC-WEBSOCKET-KEY1":
                content_length = 8

        if chunked:
            self.body = Body(ChunkedReader(self, self.unreader))
        elif content_length is not None:
            try:
                content_length = int(content_length)
            except ValueError:
                raise InvalidHeader("CONTENT-LENGTH", req=self)

            if content_length < 0:
                raise InvalidHeader("CONTENT-LENGTH", req=self)

            self.body = Body(LengthReader(self.unreader, content_length))
        else:
            self.body = Body(EOFReader(self.unreader))

    def should_close(self):
        """Tell whether the connection should be closed after this message.

        The first ``CONNECTION`` header decides when its value is
        ``close`` or ``keep-alive``; any other value, or no such header,
        falls back to the protocol version (close for HTTP/1.0 and older).
        """
        for (h, v) in self.headers:
            if h == "CONNECTION":
                v = v.lower().strip()
                if v == "close":
                    return True
                elif v == "keep-alive":
                    return False
                break
        return self.version <= (1, 0)
class Request(Message):
    """An HTTP request read from an unreader.

    Adds to ``Message`` the parsing of the request line (method, URI and
    protocol version) and, when ``cfg.proxy_protocol`` is enabled, of a
    leading PROXY protocol line on the first request of a connection.
    """

    # NOTE(review): inside the character class "$-_" is a range
    # (0x24-0x5F), not three literal characters, and match() only anchors
    # at the start of the string, so this accepts more than it seems to.
    # Kept byte-identical because tightening it changes which requests are
    # accepted -- confirm the intended method grammar before changing it.
    methre = re.compile("[A-Z0-9$-_.]{3,20}")
    # The dot is escaped: an unescaped "." matched any character, so a
    # version such as "HTTP/1x1" used to be accepted.
    versre = re.compile(r"HTTP/(\d+)\.(\d+)")

    def __init__(self, cfg, unreader, req_number=1):
        """Parse one request from ``unreader``.

        :param cfg: configuration object; ``limit_request_line``,
            ``proxy_protocol`` and ``proxy_allow_ips`` are read here in
            addition to what ``Message`` uses.
        :param unreader: source of raw data.
        :param req_number: position of this request on the connection;
            the PROXY line is only honoured when it is 1.
        """
        self.method = None
        self.uri = None
        self.path = None
        self.query = None
        self.fragment = None

        # get max request line size
        # 0 is kept as-is: read_line() only enforces a limit that is > 0.
        self.limit_request_line = cfg.limit_request_line
        if (self.limit_request_line < 0
                or self.limit_request_line >= MAX_REQUEST_LINE):
            self.limit_request_line = MAX_REQUEST_LINE

        self.req_number = req_number
        self.proxy_protocol_info = None
        super(Request, self).__init__(cfg, unreader)

    def get_data(self, unreader, buf, stop=False):
        """Read one chunk from ``unreader`` and append it to ``buf``.

        :raises StopIteration: nothing could be read and ``stop`` is true.
        :raises NoMoreData: nothing could be read and ``stop`` is false;
            carries what has been buffered so far.
        """
        data = unreader.read()
        if not data:
            if stop:
                raise StopIteration()
            raise NoMoreData(buf.getvalue())
        buf.write(data)

    def parse(self, unreader):
        """Parse the request line and the headers.

        :param unreader: source of raw data.
        :return: data read past the end of the header section, or ``""``
            when the request has no headers (in which case the data after
            the blank line is pushed back on ``unreader`` directly).
        :raises LimitRequestHeaders: the header section grows beyond
            ``max_buffer_headers``.
        """
        buf = StringIO()
        self.get_data(unreader, buf, stop=True)

        # get request line
        line, rbuf = self.read_line(unreader, buf, self.limit_request_line)

        # proxy protocol
        if self.proxy_protocol(line):
            # get next request line
            buf = StringIO()
            buf.write(rbuf)
            line, rbuf = self.read_line(unreader, buf,
                                        self.limit_request_line)

        self.parse_request_line(line)
        buf = StringIO()
        buf.write(rbuf)

        # Headers
        data = buf.getvalue()
        while True:
            idx = data.find("\r\n\r\n")
            # A blank line right after the request line: no headers at all.
            done = data[:2] == "\r\n"
            if idx >= 0 or done:
                break
            self.get_data(unreader, buf)
            data = buf.getvalue()
            if len(data) > self.max_buffer_headers:
                raise LimitRequestHeaders("max buffer headers")

        if done:
            unreader.unread(data[2:])
            return ""

        self.headers = self.parse_headers(data[:idx])
        return data[idx + 4:]

    def read_line(self, unreader, buf, limit=0):
        """Read from ``unreader`` until ``buf`` holds a full line.

        :param limit: maximum line length; only enforced when > 0.
        :return: ``(line, rest)`` where ``line`` excludes the ``\\r\\n``
            terminator and ``rest`` is what was buffered after it.
        :raises LimitRequestLine: the line is longer than ``limit``.
        """
        data = buf.getvalue()

        while True:
            idx = data.find("\r\n")
            if idx >= 0:
                # check if the request line is too large
                if idx > limit > 0:
                    raise LimitRequestLine(idx, limit)
                break

            self.get_data(unreader, buf)
            data = buf.getvalue()

            if len(data) - 2 > limit > 0:
                raise LimitRequestLine(len(data), limit)

        return (data[:idx],  # request line,
                data[idx + 2:])  # residue in the buffer, skip \r\n

    def proxy_protocol(self, line):
        """\
        Detect, check and parse proxy protocol.

        :raises: ForbiddenProxyRequest, InvalidProxyLine.
        :return: True for proxy protocol line else False
        """
        if not self.cfg.proxy_protocol:
            return False

        if self.req_number != 1:
            return False

        if not line.startswith("PROXY"):
            return False

        self.proxy_protocol_access_check()
        self.parse_proxy_protocol(line)

        return True

    def proxy_protocol_access_check(self):
        """Reject a PROXY line sent by a peer outside the allow list.

        Only performed when the unreader is a ``SocketUnreader``.

        :raises ForbiddenProxyRequest: the peer address is not in
            ``cfg.proxy_allow_ips``, or the socket is not connected.
        """
        # check in allow list
        if isinstance(self.unreader, SocketUnreader):
            try:
                remote_host = self.unreader.sock.getpeername()[0]
            except socket.error as e:
                # ``errno`` instead of ``e[0]``: exceptions are not
                # subscriptable on Python 3.
                if e.errno == ENOTCONN:
                    raise ForbiddenProxyRequest("UNKNOW")
                raise
            if remote_host not in self.cfg.proxy_allow_ips:
                raise ForbiddenProxyRequest(remote_host)

    def parse_proxy_protocol(self, line):
        """Validate a PROXY line and store it in ``proxy_protocol_info``.

        The line must split into six whitespace separated fields:
        the ``PROXY`` keyword, protocol, source address, destination
        address, source port and destination port.

        :raises InvalidProxyLine: wrong field count, unsupported protocol,
            invalid address or invalid port.
        """
        bits = line.split()

        if len(bits) != 6:
            raise InvalidProxyLine(line)

        # Extract data
        proto = bits[1]
        s_addr = bits[2]
        d_addr = bits[3]

        # Validation
        if proto not in ["TCP4", "TCP6"]:
            raise InvalidProxyLine("protocol '%s' not supported" % proto)
        if proto == "TCP4":
            try:
                socket.inet_pton(socket.AF_INET, s_addr)
                socket.inet_pton(socket.AF_INET, d_addr)
            except socket.error:
                raise InvalidProxyLine(line)
        elif proto == "TCP6":
            try:
                socket.inet_pton(socket.AF_INET6, s_addr)
                socket.inet_pton(socket.AF_INET6, d_addr)
            except socket.error:
                raise InvalidProxyLine(line)

        try:
            s_port = int(bits[4])
            d_port = int(bits[5])
        except ValueError:
            raise InvalidProxyLine("invalid port %s" % line)

        if not ((0 <= s_port <= 65535) and (0 <= d_port <= 65535)):
            raise InvalidProxyLine("invalid port %s" % line)

        # Set data
        self.proxy_protocol_info = {
            "proxy_protocol": proto,
            "client_addr": s_addr,
            "client_port": s_port,
            "proxy_addr": d_addr,
            "proxy_port": d_port
        }

    def parse_request_line(self, line):
        """Split the request line into method, URI parts and version.

        Sets ``method``, ``uri``, ``path``, ``query``, ``fragment`` and
        ``version`` (a ``(major, minor)`` tuple of ints).

        :raises InvalidRequestLine: the line does not have three parts.
        :raises InvalidRequestMethod: the method is rejected by ``methre``.
        :raises InvalidHTTPVersion: the version is rejected by ``versre``.
        """
        bits = line.split(None, 2)
        if len(bits) != 3:
            raise InvalidRequestLine(line)

        # Method
        if not self.methre.match(bits[0]):
            raise InvalidRequestMethod(bits[0])
        self.method = bits[0].upper()

        # URI
        # When the path starts with //, urlsplit treats what follows as a
        # network location rather than a path, while the RFC
        # http://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html#sec5.1.2
        # says a request URI starting with a slash is an absolute path.
        # Dropping the first slash keeps it a path.
        # fix issue #297
        if bits[1].startswith("//"):
            self.uri = bits[1][1:]
        else:
            self.uri = bits[1]

        parts = urlparse.urlsplit(self.uri)
        self.path = parts.path or ""
        self.query = parts.query or ""
        self.fragment = parts.fragment or ""

        # Version
        match = self.versre.match(bits[2])
        if match is None:
            raise InvalidHTTPVersion(bits[2])
        self.version = (int(match.group(1)), int(match.group(2)))

    def set_body_reader(self):
        """Select the body reader, never reading a request until EOF.

        When the base class falls back to an ``EOFReader`` (no chunked
        encoding and no content length), it is replaced by a zero-length
        reader.
        """
        super(Request, self).set_body_reader()
        if isinstance(self.body.reader, EOFReader):
            self.body = Body(LengthReader(self.unreader, 0))