diff --git a/gunicorn/asgi/protocol.py b/gunicorn/asgi/protocol.py index 29d8a384..75f2a550 100644 --- a/gunicorn/asgi/protocol.py +++ b/gunicorn/asgi/protocol.py @@ -77,8 +77,48 @@ _cached_date_time = 0.0 # Pre-compute common chunk size prefixes to avoid repeated formatting _CHUNK_PREFIXES = {i: f"{i:x}\r\n".encode("latin-1") for i in range(16384)} -# High water mark for write buffer backpressure (256KB) -_WRITE_BUFFER_HIGH_WATER = 262144 +# High water mark for write buffer backpressure (64KB) +HIGH_WATER_LIMIT = 65536 + + +class FlowControl: + """Manage transport-level write flow control. + + Blocks send() when transport buffer exceeds high water mark, + preventing memory issues with large streaming responses. + """ + __slots__ = ('_transport', 'read_paused', 'write_paused', '_is_writable_event') + + def __init__(self, transport): + self._transport = transport + self.read_paused = False + self.write_paused = False + self._is_writable_event = asyncio.Event() + self._is_writable_event.set() + + async def drain(self): + """Wait until transport is writable.""" + await self._is_writable_event.wait() + + def pause_reading(self): + if not self.read_paused: + self.read_paused = True + self._transport.pause_reading() + + def resume_reading(self): + if self.read_paused: + self.read_paused = False + self._transport.resume_reading() + + def pause_writing(self): + if not self.write_paused: + self.write_paused = True + self._is_writable_event.clear() + + def resume_writing(self): + if self.write_paused: + self.write_paused = False + self._is_writable_event.set() def _get_cached_date_header(): @@ -300,6 +340,9 @@ class ASGIProtocol(asyncio.Protocol): self._is_ssl = False self._use_callback_parser = False + # Write flow control + self._flow_control = None + def connection_made(self, transport): """Called when a connection is established.""" self.transport = transport @@ -322,6 +365,10 @@ class ASGIProtocol(asyncio.Protocol): self._data_event = asyncio.Event() self.writer = transport + # Setup flow control for HTTP/1.x + self._flow_control = FlowControl(transport) + transport.set_write_buffer_limits(high=HIGH_WATER_LIMIT) + # Check if callback parser should be used self._use_callback_parser = self._should_use_callback_parser() @@ -545,6 +592,16 @@ class ASGIProtocol(asyncio.Protocol): if self._task and not self._task.done(): self._task.cancel() + def pause_writing(self): + """Called by transport when write buffer exceeds high water mark.""" + if self._flow_control: + self._flow_control.pause_writing() + + def resume_writing(self): + """Called by transport when write buffer drains below low water mark.""" + if self._flow_control: + self._flow_control.resume_writing() + def _safe_write(self, data): """Write data to transport, handling connection errors gracefully. @@ -936,6 +993,9 @@ class ASGIProtocol(asyncio.Protocol): if body: self._send_body(body, chunked=use_chunked) response_sent += len(body) + # Apply write backpressure for streaming responses + if self._flow_control: + await self._flow_control.drain() if not more_body: if use_chunked: @@ -1349,19 +1409,33 @@ class ASGIProtocol(asyncio.Protocol): pass self._close_transport() + def _convert_h2_headers(self, headers): + """Convert ASGI headers to HTTP/2 format (lowercase string names).""" + result = [] + for name, value in headers: + if isinstance(name, bytes): + name = name.decode("latin-1") + if isinstance(value, bytes): + value = value.decode("latin-1") + result.append((name.lower(), value)) + return result + async def _handle_http2_request(self, request, h2_conn, sockname, peername): - """Handle a single HTTP/2 request.""" + """Handle a single HTTP/2 request with streaming support. + + Streams response body chunks immediately instead of buffering, + enabling SSE, streaming downloads, and other real-time use cases. + """ stream_id = request.stream.stream_id scope = self._build_http2_scope(request, sockname, peername) response_started = False response_complete = False + headers_sent = False exc_to_raise = None - response_status = 500 response_headers = [] - response_body = b'' - response_trailers = [] + response_sent = 0 async def receive(): # For HTTP/2, the body is already buffered in the stream @@ -1373,8 +1447,8 @@ class ASGIProtocol(asyncio.Protocol): } async def send(message): - nonlocal response_started, response_complete, exc_to_raise - nonlocal response_status, response_headers, response_body + nonlocal response_started, response_complete, headers_sent + nonlocal response_status, response_headers, response_sent, exc_to_raise msg_type = message["type"] @@ -1382,14 +1456,7 @@ class ASGIProtocol(asyncio.Protocol): # Handle informational responses (1xx) like 103 Early Hints over HTTP/2 info_status = message.get("status") info_headers = message.get("headers", []) - # Convert headers to list of string tuples - headers = [] - for name, value in info_headers: - if isinstance(name, bytes): - name = name.decode("latin-1") - if isinstance(value, bytes): - value = value.decode("latin-1") - headers.append((name, value)) + headers = self._convert_h2_headers(info_headers) await h2_conn.send_informational(stream_id, info_status, headers) return @@ -1400,6 +1467,7 @@ class ASGIProtocol(asyncio.Protocol): response_started = True response_status = message["status"] response_headers = message.get("headers", []) + # Don't send headers yet - wait for first body chunk elif msg_type == "http.response.body": if not response_started: @@ -1412,10 +1480,31 @@ class ASGIProtocol(asyncio.Protocol): body = message.get("body", b"") more_body = message.get("more_body", False) + # Send headers with first body chunk + if not headers_sent: + headers = self._convert_h2_headers(response_headers) + response_hdrs = [(':status', str(response_status))] + response_hdrs.extend(headers) + + # Send headers without end_stream since we have body + stream = h2_conn.streams.get(stream_id) + if stream is None: + exc_to_raise = RuntimeError("Stream closed") + return + h2_conn.h2_conn.send_headers(stream_id, response_hdrs, end_stream=False) + stream.send_headers(response_hdrs, end_stream=False) + await h2_conn._send_pending_data() + headers_sent = True + + # Stream body immediately if body: - response_body += body + await h2_conn.send_data(stream_id, body, end_stream=not more_body) + response_sent += len(body) if not more_body: + if not body: + # Empty final chunk - send end_stream + await h2_conn.send_data(stream_id, b"", end_stream=True) response_complete = True elif msg_type == "http.response.trailers": @@ -1423,15 +1512,8 @@ class ASGIProtocol(asyncio.Protocol): exc_to_raise = RuntimeError("Cannot send trailers before body complete") return trailer_headers = message.get("headers", []) - # Convert to list of tuples with string values - trailers = [] - for name, value in trailer_headers: - if isinstance(name, bytes): - name = name.decode("latin-1") - if isinstance(value, bytes): - value = value.decode("latin-1") - trailers.append((name, value)) - response_trailers.extend(trailers) + trailers = self._convert_h2_headers(trailer_headers) + await h2_conn.send_trailers(stream_id, trailers) # Only build environ for logging if access logging is enabled access_log_enabled = self.log.access_log_enabled @@ -1444,48 +1526,26 @@ class ASGIProtocol(asyncio.Protocol): if exc_to_raise is not None: raise exc_to_raise - # Send response via HTTP/2 - if response_started: - # Convert headers to list of tuples - headers = [] - for name, value in response_headers: - if isinstance(name, bytes): - name = name.decode("latin-1") - if isinstance(value, bytes): - value = value.decode("latin-1") - headers.append((name, value)) - - if response_trailers: - # Send headers, body, then trailers separately - response_hdrs = [(':status', str(response_status))] - for name, value in headers: - response_hdrs.append((name.lower(), str(value))) - - # Send headers without ending stream - h2_conn.h2_conn.send_headers(stream_id, response_hdrs, end_stream=False) - stream = h2_conn.streams[stream_id] - stream.send_headers(response_hdrs, end_stream=False) - await h2_conn._send_pending_data() - - # Send body without ending stream - if response_body: - h2_conn.h2_conn.send_data(stream_id, response_body, end_stream=False) - stream.send_data(response_body, end_stream=False) - await h2_conn._send_pending_data() - - # Send trailers (ends stream) - await h2_conn.send_trailers(stream_id, response_trailers) - else: - await h2_conn.send_response( - stream_id, response_status, headers, response_body - ) - else: + # Handle case where app didn't send any response + if not response_started: await h2_conn.send_error(stream_id, 500, "Internal Server Error") response_status = 500 + # Handle case where headers were started but no body was sent + elif not headers_sent: + # Send headers now (empty body response) + headers = self._convert_h2_headers(response_headers) + response_hdrs = [(':status', str(response_status))] + response_hdrs.extend(headers) + stream = h2_conn.streams.get(stream_id) + if stream: + h2_conn.h2_conn.send_headers(stream_id, response_hdrs, end_stream=True) + stream.send_headers(response_hdrs, end_stream=True) + await h2_conn._send_pending_data() + except Exception: self.log.exception("Error in ASGI application") - if not response_started: + if not headers_sent: await h2_conn.send_error(stream_id, 500, "Internal Server Error") response_status = 500 finally: @@ -1495,7 +1555,7 @@ class ASGIProtocol(asyncio.Protocol): if access_log_enabled: environ = self._build_http2_environ(request, sockname, peername) resp = ASGIResponseInfo( - response_status, response_headers, len(response_body) + response_status, response_headers, response_sent ) self.log.access(resp, request, environ, request_time) else: