mirror of
https://github.com/frappe/gunicorn.git
synced 2026-07-01 18:21:30 +08:00
Add write flow control and HTTP/2 streaming
- Add FlowControl class for transport-level write backpressure - Integrate flow control into HTTP/1.1 protocol to prevent memory issues with large streaming responses - Set write buffer high water mark to 64KB - Add pause_writing/resume_writing protocol callbacks - Stream HTTP/2 responses immediately instead of buffering - Add _convert_h2_headers helper for cleaner header conversion
This commit is contained in:
parent
22bdca22e1
commit
464cbbfad5
@ -77,8 +77,48 @@ _cached_date_time = 0.0
|
||||
# Pre-compute common chunk size prefixes to avoid repeated formatting
|
||||
_CHUNK_PREFIXES = {i: f"{i:x}\r\n".encode("latin-1") for i in range(16384)}
|
||||
|
||||
# High water mark for write buffer backpressure (256KB)
|
||||
_WRITE_BUFFER_HIGH_WATER = 262144
|
||||
# High water mark for write buffer backpressure (64KB)
|
||||
HIGH_WATER_LIMIT = 65536
|
||||
|
||||
|
||||
class FlowControl:
|
||||
"""Manage transport-level write flow control.
|
||||
|
||||
Blocks send() when transport buffer exceeds high water mark,
|
||||
preventing memory issues with large streaming responses.
|
||||
"""
|
||||
__slots__ = ('_transport', 'read_paused', 'write_paused', '_is_writable_event')
|
||||
|
||||
def __init__(self, transport):
|
||||
self._transport = transport
|
||||
self.read_paused = False
|
||||
self.write_paused = False
|
||||
self._is_writable_event = asyncio.Event()
|
||||
self._is_writable_event.set()
|
||||
|
||||
async def drain(self):
|
||||
"""Wait until transport is writable."""
|
||||
await self._is_writable_event.wait()
|
||||
|
||||
def pause_reading(self):
|
||||
if not self.read_paused:
|
||||
self.read_paused = True
|
||||
self._transport.pause_reading()
|
||||
|
||||
def resume_reading(self):
|
||||
if self.read_paused:
|
||||
self.read_paused = False
|
||||
self._transport.resume_reading()
|
||||
|
||||
def pause_writing(self):
|
||||
if not self.write_paused:
|
||||
self.write_paused = True
|
||||
self._is_writable_event.clear()
|
||||
|
||||
def resume_writing(self):
|
||||
if self.write_paused:
|
||||
self.write_paused = False
|
||||
self._is_writable_event.set()
|
||||
|
||||
|
||||
def _get_cached_date_header():
|
||||
@ -300,6 +340,9 @@ class ASGIProtocol(asyncio.Protocol):
|
||||
self._is_ssl = False
|
||||
self._use_callback_parser = False
|
||||
|
||||
# Write flow control
|
||||
self._flow_control = None
|
||||
|
||||
def connection_made(self, transport):
|
||||
"""Called when a connection is established."""
|
||||
self.transport = transport
|
||||
@ -322,6 +365,10 @@ class ASGIProtocol(asyncio.Protocol):
|
||||
self._data_event = asyncio.Event()
|
||||
self.writer = transport
|
||||
|
||||
# Setup flow control for HTTP/1.x
|
||||
self._flow_control = FlowControl(transport)
|
||||
transport.set_write_buffer_limits(high=HIGH_WATER_LIMIT)
|
||||
|
||||
# Check if callback parser should be used
|
||||
self._use_callback_parser = self._should_use_callback_parser()
|
||||
|
||||
@ -545,6 +592,16 @@ class ASGIProtocol(asyncio.Protocol):
|
||||
if self._task and not self._task.done():
|
||||
self._task.cancel()
|
||||
|
||||
def pause_writing(self):
|
||||
"""Called by transport when write buffer exceeds high water mark."""
|
||||
if self._flow_control:
|
||||
self._flow_control.pause_writing()
|
||||
|
||||
def resume_writing(self):
|
||||
"""Called by transport when write buffer drains below low water mark."""
|
||||
if self._flow_control:
|
||||
self._flow_control.resume_writing()
|
||||
|
||||
def _safe_write(self, data):
|
||||
"""Write data to transport, handling connection errors gracefully.
|
||||
|
||||
@ -936,6 +993,9 @@ class ASGIProtocol(asyncio.Protocol):
|
||||
if body:
|
||||
self._send_body(body, chunked=use_chunked)
|
||||
response_sent += len(body)
|
||||
# Apply write backpressure for streaming responses
|
||||
if self._flow_control:
|
||||
await self._flow_control.drain()
|
||||
|
||||
if not more_body:
|
||||
if use_chunked:
|
||||
@ -1349,19 +1409,33 @@ class ASGIProtocol(asyncio.Protocol):
|
||||
pass
|
||||
self._close_transport()
|
||||
|
||||
def _convert_h2_headers(self, headers):
|
||||
"""Convert ASGI headers to HTTP/2 format (lowercase string names)."""
|
||||
result = []
|
||||
for name, value in headers:
|
||||
if isinstance(name, bytes):
|
||||
name = name.decode("latin-1")
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode("latin-1")
|
||||
result.append((name.lower(), value))
|
||||
return result
|
||||
|
||||
async def _handle_http2_request(self, request, h2_conn, sockname, peername):
|
||||
"""Handle a single HTTP/2 request."""
|
||||
"""Handle a single HTTP/2 request with streaming support.
|
||||
|
||||
Streams response body chunks immediately instead of buffering,
|
||||
enabling SSE, streaming downloads, and other real-time use cases.
|
||||
"""
|
||||
stream_id = request.stream.stream_id
|
||||
scope = self._build_http2_scope(request, sockname, peername)
|
||||
|
||||
response_started = False
|
||||
response_complete = False
|
||||
headers_sent = False
|
||||
exc_to_raise = None
|
||||
|
||||
response_status = 500
|
||||
response_headers = []
|
||||
response_body = b''
|
||||
response_trailers = []
|
||||
response_sent = 0
|
||||
|
||||
async def receive():
|
||||
# For HTTP/2, the body is already buffered in the stream
|
||||
@ -1373,8 +1447,8 @@ class ASGIProtocol(asyncio.Protocol):
|
||||
}
|
||||
|
||||
async def send(message):
|
||||
nonlocal response_started, response_complete, exc_to_raise
|
||||
nonlocal response_status, response_headers, response_body
|
||||
nonlocal response_started, response_complete, headers_sent
|
||||
nonlocal response_status, response_headers, response_sent, exc_to_raise
|
||||
|
||||
msg_type = message["type"]
|
||||
|
||||
@ -1382,14 +1456,7 @@ class ASGIProtocol(asyncio.Protocol):
|
||||
# Handle informational responses (1xx) like 103 Early Hints over HTTP/2
|
||||
info_status = message.get("status")
|
||||
info_headers = message.get("headers", [])
|
||||
# Convert headers to list of string tuples
|
||||
headers = []
|
||||
for name, value in info_headers:
|
||||
if isinstance(name, bytes):
|
||||
name = name.decode("latin-1")
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode("latin-1")
|
||||
headers.append((name, value))
|
||||
headers = self._convert_h2_headers(info_headers)
|
||||
await h2_conn.send_informational(stream_id, info_status, headers)
|
||||
return
|
||||
|
||||
@ -1400,6 +1467,7 @@ class ASGIProtocol(asyncio.Protocol):
|
||||
response_started = True
|
||||
response_status = message["status"]
|
||||
response_headers = message.get("headers", [])
|
||||
# Don't send headers yet - wait for first body chunk
|
||||
|
||||
elif msg_type == "http.response.body":
|
||||
if not response_started:
|
||||
@ -1412,10 +1480,31 @@ class ASGIProtocol(asyncio.Protocol):
|
||||
body = message.get("body", b"")
|
||||
more_body = message.get("more_body", False)
|
||||
|
||||
# Send headers with first body chunk
|
||||
if not headers_sent:
|
||||
headers = self._convert_h2_headers(response_headers)
|
||||
response_hdrs = [(':status', str(response_status))]
|
||||
response_hdrs.extend(headers)
|
||||
|
||||
# Send headers without end_stream since we have body
|
||||
stream = h2_conn.streams.get(stream_id)
|
||||
if stream is None:
|
||||
exc_to_raise = RuntimeError("Stream closed")
|
||||
return
|
||||
h2_conn.h2_conn.send_headers(stream_id, response_hdrs, end_stream=False)
|
||||
stream.send_headers(response_hdrs, end_stream=False)
|
||||
await h2_conn._send_pending_data()
|
||||
headers_sent = True
|
||||
|
||||
# Stream body immediately
|
||||
if body:
|
||||
response_body += body
|
||||
await h2_conn.send_data(stream_id, body, end_stream=not more_body)
|
||||
response_sent += len(body)
|
||||
|
||||
if not more_body:
|
||||
if not body:
|
||||
# Empty final chunk - send end_stream
|
||||
await h2_conn.send_data(stream_id, b"", end_stream=True)
|
||||
response_complete = True
|
||||
|
||||
elif msg_type == "http.response.trailers":
|
||||
@ -1423,15 +1512,8 @@ class ASGIProtocol(asyncio.Protocol):
|
||||
exc_to_raise = RuntimeError("Cannot send trailers before body complete")
|
||||
return
|
||||
trailer_headers = message.get("headers", [])
|
||||
# Convert to list of tuples with string values
|
||||
trailers = []
|
||||
for name, value in trailer_headers:
|
||||
if isinstance(name, bytes):
|
||||
name = name.decode("latin-1")
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode("latin-1")
|
||||
trailers.append((name, value))
|
||||
response_trailers.extend(trailers)
|
||||
trailers = self._convert_h2_headers(trailer_headers)
|
||||
await h2_conn.send_trailers(stream_id, trailers)
|
||||
|
||||
# Only build environ for logging if access logging is enabled
|
||||
access_log_enabled = self.log.access_log_enabled
|
||||
@ -1444,48 +1526,26 @@ class ASGIProtocol(asyncio.Protocol):
|
||||
if exc_to_raise is not None:
|
||||
raise exc_to_raise
|
||||
|
||||
# Send response via HTTP/2
|
||||
if response_started:
|
||||
# Convert headers to list of tuples
|
||||
headers = []
|
||||
for name, value in response_headers:
|
||||
if isinstance(name, bytes):
|
||||
name = name.decode("latin-1")
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode("latin-1")
|
||||
headers.append((name, value))
|
||||
|
||||
if response_trailers:
|
||||
# Send headers, body, then trailers separately
|
||||
response_hdrs = [(':status', str(response_status))]
|
||||
for name, value in headers:
|
||||
response_hdrs.append((name.lower(), str(value)))
|
||||
|
||||
# Send headers without ending stream
|
||||
h2_conn.h2_conn.send_headers(stream_id, response_hdrs, end_stream=False)
|
||||
stream = h2_conn.streams[stream_id]
|
||||
stream.send_headers(response_hdrs, end_stream=False)
|
||||
await h2_conn._send_pending_data()
|
||||
|
||||
# Send body without ending stream
|
||||
if response_body:
|
||||
h2_conn.h2_conn.send_data(stream_id, response_body, end_stream=False)
|
||||
stream.send_data(response_body, end_stream=False)
|
||||
await h2_conn._send_pending_data()
|
||||
|
||||
# Send trailers (ends stream)
|
||||
await h2_conn.send_trailers(stream_id, response_trailers)
|
||||
else:
|
||||
await h2_conn.send_response(
|
||||
stream_id, response_status, headers, response_body
|
||||
)
|
||||
else:
|
||||
# Handle case where app didn't send any response
|
||||
if not response_started:
|
||||
await h2_conn.send_error(stream_id, 500, "Internal Server Error")
|
||||
response_status = 500
|
||||
|
||||
# Handle case where headers were started but no body was sent
|
||||
elif not headers_sent:
|
||||
# Send headers now (empty body response)
|
||||
headers = self._convert_h2_headers(response_headers)
|
||||
response_hdrs = [(':status', str(response_status))]
|
||||
response_hdrs.extend(headers)
|
||||
stream = h2_conn.streams.get(stream_id)
|
||||
if stream:
|
||||
h2_conn.h2_conn.send_headers(stream_id, response_hdrs, end_stream=True)
|
||||
stream.send_headers(response_hdrs, end_stream=True)
|
||||
await h2_conn._send_pending_data()
|
||||
|
||||
except Exception:
|
||||
self.log.exception("Error in ASGI application")
|
||||
if not response_started:
|
||||
if not headers_sent:
|
||||
await h2_conn.send_error(stream_id, 500, "Internal Server Error")
|
||||
response_status = 500
|
||||
finally:
|
||||
@ -1495,7 +1555,7 @@ class ASGIProtocol(asyncio.Protocol):
|
||||
if access_log_enabled:
|
||||
environ = self._build_http2_environ(request, sockname, peername)
|
||||
resp = ASGIResponseInfo(
|
||||
response_status, response_headers, len(response_body)
|
||||
response_status, response_headers, response_sent
|
||||
)
|
||||
self.log.access(resp, request, environ, request_time)
|
||||
else:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user