mirror of
https://github.com/frappe/gunicorn.git
synced 2026-07-02 18:51:31 +08:00
Add write flow control and HTTP/2 streaming
- Add FlowControl class for transport-level write backpressure - Integrate flow control into HTTP/1.1 protocol to prevent memory issues with large streaming responses - Set write buffer high water mark to 64KB - Add pause_writing/resume_writing protocol callbacks - Stream HTTP/2 responses immediately instead of buffering - Add _convert_h2_headers helper for cleaner header conversion
This commit is contained in:
parent
22bdca22e1
commit
464cbbfad5
@ -77,8 +77,48 @@ _cached_date_time = 0.0
|
|||||||
# Pre-compute common chunk size prefixes to avoid repeated formatting
|
# Pre-compute common chunk size prefixes to avoid repeated formatting
|
||||||
_CHUNK_PREFIXES = {i: f"{i:x}\r\n".encode("latin-1") for i in range(16384)}
|
_CHUNK_PREFIXES = {i: f"{i:x}\r\n".encode("latin-1") for i in range(16384)}
|
||||||
|
|
||||||
# High water mark for write buffer backpressure (256KB)
|
# High water mark for write buffer backpressure (64KB)
|
||||||
_WRITE_BUFFER_HIGH_WATER = 262144
|
HIGH_WATER_LIMIT = 65536
|
||||||
|
|
||||||
|
|
||||||
|
class FlowControl:
|
||||||
|
"""Manage transport-level write flow control.
|
||||||
|
|
||||||
|
Blocks send() when transport buffer exceeds high water mark,
|
||||||
|
preventing memory issues with large streaming responses.
|
||||||
|
"""
|
||||||
|
__slots__ = ('_transport', 'read_paused', 'write_paused', '_is_writable_event')
|
||||||
|
|
||||||
|
def __init__(self, transport):
|
||||||
|
self._transport = transport
|
||||||
|
self.read_paused = False
|
||||||
|
self.write_paused = False
|
||||||
|
self._is_writable_event = asyncio.Event()
|
||||||
|
self._is_writable_event.set()
|
||||||
|
|
||||||
|
async def drain(self):
|
||||||
|
"""Wait until transport is writable."""
|
||||||
|
await self._is_writable_event.wait()
|
||||||
|
|
||||||
|
def pause_reading(self):
|
||||||
|
if not self.read_paused:
|
||||||
|
self.read_paused = True
|
||||||
|
self._transport.pause_reading()
|
||||||
|
|
||||||
|
def resume_reading(self):
|
||||||
|
if self.read_paused:
|
||||||
|
self.read_paused = False
|
||||||
|
self._transport.resume_reading()
|
||||||
|
|
||||||
|
def pause_writing(self):
|
||||||
|
if not self.write_paused:
|
||||||
|
self.write_paused = True
|
||||||
|
self._is_writable_event.clear()
|
||||||
|
|
||||||
|
def resume_writing(self):
|
||||||
|
if self.write_paused:
|
||||||
|
self.write_paused = False
|
||||||
|
self._is_writable_event.set()
|
||||||
|
|
||||||
|
|
||||||
def _get_cached_date_header():
|
def _get_cached_date_header():
|
||||||
@ -300,6 +340,9 @@ class ASGIProtocol(asyncio.Protocol):
|
|||||||
self._is_ssl = False
|
self._is_ssl = False
|
||||||
self._use_callback_parser = False
|
self._use_callback_parser = False
|
||||||
|
|
||||||
|
# Write flow control
|
||||||
|
self._flow_control = None
|
||||||
|
|
||||||
def connection_made(self, transport):
|
def connection_made(self, transport):
|
||||||
"""Called when a connection is established."""
|
"""Called when a connection is established."""
|
||||||
self.transport = transport
|
self.transport = transport
|
||||||
@ -322,6 +365,10 @@ class ASGIProtocol(asyncio.Protocol):
|
|||||||
self._data_event = asyncio.Event()
|
self._data_event = asyncio.Event()
|
||||||
self.writer = transport
|
self.writer = transport
|
||||||
|
|
||||||
|
# Setup flow control for HTTP/1.x
|
||||||
|
self._flow_control = FlowControl(transport)
|
||||||
|
transport.set_write_buffer_limits(high=HIGH_WATER_LIMIT)
|
||||||
|
|
||||||
# Check if callback parser should be used
|
# Check if callback parser should be used
|
||||||
self._use_callback_parser = self._should_use_callback_parser()
|
self._use_callback_parser = self._should_use_callback_parser()
|
||||||
|
|
||||||
@ -545,6 +592,16 @@ class ASGIProtocol(asyncio.Protocol):
|
|||||||
if self._task and not self._task.done():
|
if self._task and not self._task.done():
|
||||||
self._task.cancel()
|
self._task.cancel()
|
||||||
|
|
||||||
|
def pause_writing(self):
|
||||||
|
"""Called by transport when write buffer exceeds high water mark."""
|
||||||
|
if self._flow_control:
|
||||||
|
self._flow_control.pause_writing()
|
||||||
|
|
||||||
|
def resume_writing(self):
|
||||||
|
"""Called by transport when write buffer drains below low water mark."""
|
||||||
|
if self._flow_control:
|
||||||
|
self._flow_control.resume_writing()
|
||||||
|
|
||||||
def _safe_write(self, data):
|
def _safe_write(self, data):
|
||||||
"""Write data to transport, handling connection errors gracefully.
|
"""Write data to transport, handling connection errors gracefully.
|
||||||
|
|
||||||
@ -936,6 +993,9 @@ class ASGIProtocol(asyncio.Protocol):
|
|||||||
if body:
|
if body:
|
||||||
self._send_body(body, chunked=use_chunked)
|
self._send_body(body, chunked=use_chunked)
|
||||||
response_sent += len(body)
|
response_sent += len(body)
|
||||||
|
# Apply write backpressure for streaming responses
|
||||||
|
if self._flow_control:
|
||||||
|
await self._flow_control.drain()
|
||||||
|
|
||||||
if not more_body:
|
if not more_body:
|
||||||
if use_chunked:
|
if use_chunked:
|
||||||
@ -1349,19 +1409,33 @@ class ASGIProtocol(asyncio.Protocol):
|
|||||||
pass
|
pass
|
||||||
self._close_transport()
|
self._close_transport()
|
||||||
|
|
||||||
|
def _convert_h2_headers(self, headers):
|
||||||
|
"""Convert ASGI headers to HTTP/2 format (lowercase string names)."""
|
||||||
|
result = []
|
||||||
|
for name, value in headers:
|
||||||
|
if isinstance(name, bytes):
|
||||||
|
name = name.decode("latin-1")
|
||||||
|
if isinstance(value, bytes):
|
||||||
|
value = value.decode("latin-1")
|
||||||
|
result.append((name.lower(), value))
|
||||||
|
return result
|
||||||
|
|
||||||
async def _handle_http2_request(self, request, h2_conn, sockname, peername):
|
async def _handle_http2_request(self, request, h2_conn, sockname, peername):
|
||||||
"""Handle a single HTTP/2 request."""
|
"""Handle a single HTTP/2 request with streaming support.
|
||||||
|
|
||||||
|
Streams response body chunks immediately instead of buffering,
|
||||||
|
enabling SSE, streaming downloads, and other real-time use cases.
|
||||||
|
"""
|
||||||
stream_id = request.stream.stream_id
|
stream_id = request.stream.stream_id
|
||||||
scope = self._build_http2_scope(request, sockname, peername)
|
scope = self._build_http2_scope(request, sockname, peername)
|
||||||
|
|
||||||
response_started = False
|
response_started = False
|
||||||
response_complete = False
|
response_complete = False
|
||||||
|
headers_sent = False
|
||||||
exc_to_raise = None
|
exc_to_raise = None
|
||||||
|
|
||||||
response_status = 500
|
response_status = 500
|
||||||
response_headers = []
|
response_headers = []
|
||||||
response_body = b''
|
response_sent = 0
|
||||||
response_trailers = []
|
|
||||||
|
|
||||||
async def receive():
|
async def receive():
|
||||||
# For HTTP/2, the body is already buffered in the stream
|
# For HTTP/2, the body is already buffered in the stream
|
||||||
@ -1373,8 +1447,8 @@ class ASGIProtocol(asyncio.Protocol):
|
|||||||
}
|
}
|
||||||
|
|
||||||
async def send(message):
|
async def send(message):
|
||||||
nonlocal response_started, response_complete, exc_to_raise
|
nonlocal response_started, response_complete, headers_sent
|
||||||
nonlocal response_status, response_headers, response_body
|
nonlocal response_status, response_headers, response_sent, exc_to_raise
|
||||||
|
|
||||||
msg_type = message["type"]
|
msg_type = message["type"]
|
||||||
|
|
||||||
@ -1382,14 +1456,7 @@ class ASGIProtocol(asyncio.Protocol):
|
|||||||
# Handle informational responses (1xx) like 103 Early Hints over HTTP/2
|
# Handle informational responses (1xx) like 103 Early Hints over HTTP/2
|
||||||
info_status = message.get("status")
|
info_status = message.get("status")
|
||||||
info_headers = message.get("headers", [])
|
info_headers = message.get("headers", [])
|
||||||
# Convert headers to list of string tuples
|
headers = self._convert_h2_headers(info_headers)
|
||||||
headers = []
|
|
||||||
for name, value in info_headers:
|
|
||||||
if isinstance(name, bytes):
|
|
||||||
name = name.decode("latin-1")
|
|
||||||
if isinstance(value, bytes):
|
|
||||||
value = value.decode("latin-1")
|
|
||||||
headers.append((name, value))
|
|
||||||
await h2_conn.send_informational(stream_id, info_status, headers)
|
await h2_conn.send_informational(stream_id, info_status, headers)
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -1400,6 +1467,7 @@ class ASGIProtocol(asyncio.Protocol):
|
|||||||
response_started = True
|
response_started = True
|
||||||
response_status = message["status"]
|
response_status = message["status"]
|
||||||
response_headers = message.get("headers", [])
|
response_headers = message.get("headers", [])
|
||||||
|
# Don't send headers yet - wait for first body chunk
|
||||||
|
|
||||||
elif msg_type == "http.response.body":
|
elif msg_type == "http.response.body":
|
||||||
if not response_started:
|
if not response_started:
|
||||||
@ -1412,10 +1480,31 @@ class ASGIProtocol(asyncio.Protocol):
|
|||||||
body = message.get("body", b"")
|
body = message.get("body", b"")
|
||||||
more_body = message.get("more_body", False)
|
more_body = message.get("more_body", False)
|
||||||
|
|
||||||
|
# Send headers with first body chunk
|
||||||
|
if not headers_sent:
|
||||||
|
headers = self._convert_h2_headers(response_headers)
|
||||||
|
response_hdrs = [(':status', str(response_status))]
|
||||||
|
response_hdrs.extend(headers)
|
||||||
|
|
||||||
|
# Send headers without end_stream since we have body
|
||||||
|
stream = h2_conn.streams.get(stream_id)
|
||||||
|
if stream is None:
|
||||||
|
exc_to_raise = RuntimeError("Stream closed")
|
||||||
|
return
|
||||||
|
h2_conn.h2_conn.send_headers(stream_id, response_hdrs, end_stream=False)
|
||||||
|
stream.send_headers(response_hdrs, end_stream=False)
|
||||||
|
await h2_conn._send_pending_data()
|
||||||
|
headers_sent = True
|
||||||
|
|
||||||
|
# Stream body immediately
|
||||||
if body:
|
if body:
|
||||||
response_body += body
|
await h2_conn.send_data(stream_id, body, end_stream=not more_body)
|
||||||
|
response_sent += len(body)
|
||||||
|
|
||||||
if not more_body:
|
if not more_body:
|
||||||
|
if not body:
|
||||||
|
# Empty final chunk - send end_stream
|
||||||
|
await h2_conn.send_data(stream_id, b"", end_stream=True)
|
||||||
response_complete = True
|
response_complete = True
|
||||||
|
|
||||||
elif msg_type == "http.response.trailers":
|
elif msg_type == "http.response.trailers":
|
||||||
@ -1423,15 +1512,8 @@ class ASGIProtocol(asyncio.Protocol):
|
|||||||
exc_to_raise = RuntimeError("Cannot send trailers before body complete")
|
exc_to_raise = RuntimeError("Cannot send trailers before body complete")
|
||||||
return
|
return
|
||||||
trailer_headers = message.get("headers", [])
|
trailer_headers = message.get("headers", [])
|
||||||
# Convert to list of tuples with string values
|
trailers = self._convert_h2_headers(trailer_headers)
|
||||||
trailers = []
|
await h2_conn.send_trailers(stream_id, trailers)
|
||||||
for name, value in trailer_headers:
|
|
||||||
if isinstance(name, bytes):
|
|
||||||
name = name.decode("latin-1")
|
|
||||||
if isinstance(value, bytes):
|
|
||||||
value = value.decode("latin-1")
|
|
||||||
trailers.append((name, value))
|
|
||||||
response_trailers.extend(trailers)
|
|
||||||
|
|
||||||
# Only build environ for logging if access logging is enabled
|
# Only build environ for logging if access logging is enabled
|
||||||
access_log_enabled = self.log.access_log_enabled
|
access_log_enabled = self.log.access_log_enabled
|
||||||
@ -1444,48 +1526,26 @@ class ASGIProtocol(asyncio.Protocol):
|
|||||||
if exc_to_raise is not None:
|
if exc_to_raise is not None:
|
||||||
raise exc_to_raise
|
raise exc_to_raise
|
||||||
|
|
||||||
# Send response via HTTP/2
|
# Handle case where app didn't send any response
|
||||||
if response_started:
|
if not response_started:
|
||||||
# Convert headers to list of tuples
|
|
||||||
headers = []
|
|
||||||
for name, value in response_headers:
|
|
||||||
if isinstance(name, bytes):
|
|
||||||
name = name.decode("latin-1")
|
|
||||||
if isinstance(value, bytes):
|
|
||||||
value = value.decode("latin-1")
|
|
||||||
headers.append((name, value))
|
|
||||||
|
|
||||||
if response_trailers:
|
|
||||||
# Send headers, body, then trailers separately
|
|
||||||
response_hdrs = [(':status', str(response_status))]
|
|
||||||
for name, value in headers:
|
|
||||||
response_hdrs.append((name.lower(), str(value)))
|
|
||||||
|
|
||||||
# Send headers without ending stream
|
|
||||||
h2_conn.h2_conn.send_headers(stream_id, response_hdrs, end_stream=False)
|
|
||||||
stream = h2_conn.streams[stream_id]
|
|
||||||
stream.send_headers(response_hdrs, end_stream=False)
|
|
||||||
await h2_conn._send_pending_data()
|
|
||||||
|
|
||||||
# Send body without ending stream
|
|
||||||
if response_body:
|
|
||||||
h2_conn.h2_conn.send_data(stream_id, response_body, end_stream=False)
|
|
||||||
stream.send_data(response_body, end_stream=False)
|
|
||||||
await h2_conn._send_pending_data()
|
|
||||||
|
|
||||||
# Send trailers (ends stream)
|
|
||||||
await h2_conn.send_trailers(stream_id, response_trailers)
|
|
||||||
else:
|
|
||||||
await h2_conn.send_response(
|
|
||||||
stream_id, response_status, headers, response_body
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
await h2_conn.send_error(stream_id, 500, "Internal Server Error")
|
await h2_conn.send_error(stream_id, 500, "Internal Server Error")
|
||||||
response_status = 500
|
response_status = 500
|
||||||
|
|
||||||
|
# Handle case where headers were started but no body was sent
|
||||||
|
elif not headers_sent:
|
||||||
|
# Send headers now (empty body response)
|
||||||
|
headers = self._convert_h2_headers(response_headers)
|
||||||
|
response_hdrs = [(':status', str(response_status))]
|
||||||
|
response_hdrs.extend(headers)
|
||||||
|
stream = h2_conn.streams.get(stream_id)
|
||||||
|
if stream:
|
||||||
|
h2_conn.h2_conn.send_headers(stream_id, response_hdrs, end_stream=True)
|
||||||
|
stream.send_headers(response_hdrs, end_stream=True)
|
||||||
|
await h2_conn._send_pending_data()
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
self.log.exception("Error in ASGI application")
|
self.log.exception("Error in ASGI application")
|
||||||
if not response_started:
|
if not headers_sent:
|
||||||
await h2_conn.send_error(stream_id, 500, "Internal Server Error")
|
await h2_conn.send_error(stream_id, 500, "Internal Server Error")
|
||||||
response_status = 500
|
response_status = 500
|
||||||
finally:
|
finally:
|
||||||
@ -1495,7 +1555,7 @@ class ASGIProtocol(asyncio.Protocol):
|
|||||||
if access_log_enabled:
|
if access_log_enabled:
|
||||||
environ = self._build_http2_environ(request, sockname, peername)
|
environ = self._build_http2_environ(request, sockname, peername)
|
||||||
resp = ASGIResponseInfo(
|
resp = ASGIResponseInfo(
|
||||||
response_status, response_headers, len(response_body)
|
response_status, response_headers, response_sent
|
||||||
)
|
)
|
||||||
self.log.access(resp, request, environ, request_time)
|
self.log.access(resp, request, environ, request_time)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user