Add write flow control and HTTP/2 streaming

- Add FlowControl class for transport-level write backpressure
- Integrate flow control into HTTP/1.1 protocol to prevent memory
  issues with large streaming responses
- Set write buffer high water mark to 64KB
- Add pause_writing/resume_writing protocol callbacks
- Stream HTTP/2 responses immediately instead of buffering
- Add _convert_h2_headers helper for cleaner header conversion
This commit is contained in:
Benoit Chesneau 2026-03-21 23:50:06 +01:00
parent 22bdca22e1
commit 464cbbfad5

View File

@ -77,8 +77,48 @@ _cached_date_time = 0.0
# Pre-compute common chunk size prefixes to avoid repeated formatting
_CHUNK_PREFIXES = {i: f"{i:x}\r\n".encode("latin-1") for i in range(16384)}
# High water mark for write buffer backpressure (256KB)
_WRITE_BUFFER_HIGH_WATER = 262144
# High water mark for write buffer backpressure (64KB)
HIGH_WATER_LIMIT = 65536
class FlowControl:
"""Manage transport-level write flow control.
Blocks send() when transport buffer exceeds high water mark,
preventing memory issues with large streaming responses.
"""
__slots__ = ('_transport', 'read_paused', 'write_paused', '_is_writable_event')
def __init__(self, transport):
self._transport = transport
self.read_paused = False
self.write_paused = False
self._is_writable_event = asyncio.Event()
self._is_writable_event.set()
async def drain(self):
"""Wait until transport is writable."""
await self._is_writable_event.wait()
def pause_reading(self):
if not self.read_paused:
self.read_paused = True
self._transport.pause_reading()
def resume_reading(self):
if self.read_paused:
self.read_paused = False
self._transport.resume_reading()
def pause_writing(self):
if not self.write_paused:
self.write_paused = True
self._is_writable_event.clear()
def resume_writing(self):
if self.write_paused:
self.write_paused = False
self._is_writable_event.set()
def _get_cached_date_header():
@ -300,6 +340,9 @@ class ASGIProtocol(asyncio.Protocol):
self._is_ssl = False
self._use_callback_parser = False
# Write flow control
self._flow_control = None
def connection_made(self, transport):
"""Called when a connection is established."""
self.transport = transport
@ -322,6 +365,10 @@ class ASGIProtocol(asyncio.Protocol):
self._data_event = asyncio.Event()
self.writer = transport
# Setup flow control for HTTP/1.x
self._flow_control = FlowControl(transport)
transport.set_write_buffer_limits(high=HIGH_WATER_LIMIT)
# Check if callback parser should be used
self._use_callback_parser = self._should_use_callback_parser()
@ -545,6 +592,16 @@ class ASGIProtocol(asyncio.Protocol):
if self._task and not self._task.done():
self._task.cancel()
def pause_writing(self):
"""Called by transport when write buffer exceeds high water mark."""
if self._flow_control:
self._flow_control.pause_writing()
def resume_writing(self):
"""Called by transport when write buffer drains below low water mark."""
if self._flow_control:
self._flow_control.resume_writing()
def _safe_write(self, data):
"""Write data to transport, handling connection errors gracefully.
@ -936,6 +993,9 @@ class ASGIProtocol(asyncio.Protocol):
if body:
self._send_body(body, chunked=use_chunked)
response_sent += len(body)
# Apply write backpressure for streaming responses
if self._flow_control:
await self._flow_control.drain()
if not more_body:
if use_chunked:
@ -1349,19 +1409,33 @@ class ASGIProtocol(asyncio.Protocol):
pass
self._close_transport()
def _convert_h2_headers(self, headers):
"""Convert ASGI headers to HTTP/2 format (lowercase string names)."""
result = []
for name, value in headers:
if isinstance(name, bytes):
name = name.decode("latin-1")
if isinstance(value, bytes):
value = value.decode("latin-1")
result.append((name.lower(), value))
return result
async def _handle_http2_request(self, request, h2_conn, sockname, peername):
"""Handle a single HTTP/2 request."""
"""Handle a single HTTP/2 request with streaming support.
Streams response body chunks immediately instead of buffering,
enabling SSE, streaming downloads, and other real-time use cases.
"""
stream_id = request.stream.stream_id
scope = self._build_http2_scope(request, sockname, peername)
response_started = False
response_complete = False
headers_sent = False
exc_to_raise = None
response_status = 500
response_headers = []
response_body = b''
response_trailers = []
response_sent = 0
async def receive():
# For HTTP/2, the body is already buffered in the stream
@ -1373,8 +1447,8 @@ class ASGIProtocol(asyncio.Protocol):
}
async def send(message):
nonlocal response_started, response_complete, exc_to_raise
nonlocal response_status, response_headers, response_body
nonlocal response_started, response_complete, headers_sent
nonlocal response_status, response_headers, response_sent, exc_to_raise
msg_type = message["type"]
@ -1382,14 +1456,7 @@ class ASGIProtocol(asyncio.Protocol):
# Handle informational responses (1xx) like 103 Early Hints over HTTP/2
info_status = message.get("status")
info_headers = message.get("headers", [])
# Convert headers to list of string tuples
headers = []
for name, value in info_headers:
if isinstance(name, bytes):
name = name.decode("latin-1")
if isinstance(value, bytes):
value = value.decode("latin-1")
headers.append((name, value))
headers = self._convert_h2_headers(info_headers)
await h2_conn.send_informational(stream_id, info_status, headers)
return
@ -1400,6 +1467,7 @@ class ASGIProtocol(asyncio.Protocol):
response_started = True
response_status = message["status"]
response_headers = message.get("headers", [])
# Don't send headers yet - wait for first body chunk
elif msg_type == "http.response.body":
if not response_started:
@ -1412,10 +1480,31 @@ class ASGIProtocol(asyncio.Protocol):
body = message.get("body", b"")
more_body = message.get("more_body", False)
# Send headers with first body chunk
if not headers_sent:
headers = self._convert_h2_headers(response_headers)
response_hdrs = [(':status', str(response_status))]
response_hdrs.extend(headers)
# Send headers without end_stream since we have body
stream = h2_conn.streams.get(stream_id)
if stream is None:
exc_to_raise = RuntimeError("Stream closed")
return
h2_conn.h2_conn.send_headers(stream_id, response_hdrs, end_stream=False)
stream.send_headers(response_hdrs, end_stream=False)
await h2_conn._send_pending_data()
headers_sent = True
# Stream body immediately
if body:
response_body += body
await h2_conn.send_data(stream_id, body, end_stream=not more_body)
response_sent += len(body)
if not more_body:
if not body:
# Empty final chunk - send end_stream
await h2_conn.send_data(stream_id, b"", end_stream=True)
response_complete = True
elif msg_type == "http.response.trailers":
@ -1423,15 +1512,8 @@ class ASGIProtocol(asyncio.Protocol):
exc_to_raise = RuntimeError("Cannot send trailers before body complete")
return
trailer_headers = message.get("headers", [])
# Convert to list of tuples with string values
trailers = []
for name, value in trailer_headers:
if isinstance(name, bytes):
name = name.decode("latin-1")
if isinstance(value, bytes):
value = value.decode("latin-1")
trailers.append((name, value))
response_trailers.extend(trailers)
trailers = self._convert_h2_headers(trailer_headers)
await h2_conn.send_trailers(stream_id, trailers)
# Only build environ for logging if access logging is enabled
access_log_enabled = self.log.access_log_enabled
@ -1444,48 +1526,26 @@ class ASGIProtocol(asyncio.Protocol):
if exc_to_raise is not None:
raise exc_to_raise
# Send response via HTTP/2
if response_started:
# Convert headers to list of tuples
headers = []
for name, value in response_headers:
if isinstance(name, bytes):
name = name.decode("latin-1")
if isinstance(value, bytes):
value = value.decode("latin-1")
headers.append((name, value))
if response_trailers:
# Send headers, body, then trailers separately
response_hdrs = [(':status', str(response_status))]
for name, value in headers:
response_hdrs.append((name.lower(), str(value)))
# Send headers without ending stream
h2_conn.h2_conn.send_headers(stream_id, response_hdrs, end_stream=False)
stream = h2_conn.streams[stream_id]
stream.send_headers(response_hdrs, end_stream=False)
await h2_conn._send_pending_data()
# Send body without ending stream
if response_body:
h2_conn.h2_conn.send_data(stream_id, response_body, end_stream=False)
stream.send_data(response_body, end_stream=False)
await h2_conn._send_pending_data()
# Send trailers (ends stream)
await h2_conn.send_trailers(stream_id, response_trailers)
else:
await h2_conn.send_response(
stream_id, response_status, headers, response_body
)
else:
# Handle case where app didn't send any response
if not response_started:
await h2_conn.send_error(stream_id, 500, "Internal Server Error")
response_status = 500
# Handle case where headers were started but no body was sent
elif not headers_sent:
# Send headers now (empty body response)
headers = self._convert_h2_headers(response_headers)
response_hdrs = [(':status', str(response_status))]
response_hdrs.extend(headers)
stream = h2_conn.streams.get(stream_id)
if stream:
h2_conn.h2_conn.send_headers(stream_id, response_hdrs, end_stream=True)
stream.send_headers(response_hdrs, end_stream=True)
await h2_conn._send_pending_data()
except Exception:
self.log.exception("Error in ASGI application")
if not response_started:
if not headers_sent:
await h2_conn.send_error(stream_id, 500, "Internal Server Error")
response_status = 500
finally:
@ -1495,7 +1555,7 @@ class ASGIProtocol(asyncio.Protocol):
if access_log_enabled:
environ = self._build_http2_environ(request, sockname, peername)
resp = ASGIResponseInfo(
response_status, response_headers, len(response_body)
response_status, response_headers, response_sent
)
self.log.access(resp, request, environ, request_time)
else: