From b43dc6d398c7ec58d411e97770d5e8b05c4a7e76 Mon Sep 17 00:00:00 2001 From: Benoit Chesneau Date: Thu, 22 Jan 2026 09:14:19 +0100 Subject: [PATCH] gthread: Improve reliability and fix edge cases This commit addresses three issues with the gthread worker: 1. Request body handling on keepalive - Add finish_body() method to Parser to discard unread body bytes - Call it before returning connections to the poller - Prevents socket appearing readable due to leftover body Fixes #3301 2. Timeout reliability with monotonic clock - Replace time.time() with time.monotonic() in set_timeout() - Replace time.time() with time.monotonic() in murder_keepalived() - Prevents timeout issues caused by NTP adjustments 3. SSL error handling - Move conn.init() from enqueue_req() to handle() - SSL handshake now runs in worker thread, not main thread - ENOTCONN errors during ssl_wrap_socket are caught per-connection - Prevents entire worker crashes on SSL handshake failures Also adds comprehensive unit tests for the gthread worker. Closes #3303 Closes #3308 --- gunicorn/http/parser.py | 17 +- gunicorn/workers/gthread.py | 19 +- tests/test_gthread.py | 415 ++++++++++++++++++++++++++++++++++++ 3 files changed, 442 insertions(+), 9 deletions(-) create mode 100644 tests/test_gthread.py diff --git a/gunicorn/http/parser.py b/gunicorn/http/parser.py index 88da17ab..05ee6ca6 100644 --- a/gunicorn/http/parser.py +++ b/gunicorn/http/parser.py @@ -25,16 +25,25 @@ class Parser: def __iter__(self): return self + def finish_body(self): + """Discard any unread body of the current message. + + This should be called before returning a keepalive connection to + the poller to ensure the socket doesn't appear readable due to + leftover body bytes. + """ + if self.mesg: + data = self.mesg.body.read(8192) + while data: + data = self.mesg.body.read(8192) + def __next__(self): # Stop if HTTP dictates a stop. if self.mesg and self.mesg.should_close(): raise StopIteration() # Discard any unread body of the previous message - if self.mesg: - data = self.mesg.body.read(8192) - while data: - data = self.mesg.body.read(8192) + self.finish_body() # Parse the next request self.req_count += 1 diff --git a/gunicorn/workers/gthread.py b/gunicorn/workers/gthread.py index 7a23228c..f3938ef7 100644 --- a/gunicorn/workers/gthread.py +++ b/gunicorn/workers/gthread.py @@ -46,6 +46,9 @@ class TConn: self.sock.setblocking(False) def init(self): + # Guard against double initialization + if self.initialized: + return self.initialized = True self.sock.setblocking(True) @@ -58,8 +61,8 @@ class TConn: self.parser = http.RequestParser(self.cfg, self.sock, self.client) def set_timeout(self): - # set the timeout - self.timeout = time.time() + self.cfg.keepalive + # Use monotonic clock for reliability (time.time() can jump due to NTP) + self.timeout = time.monotonic() + self.cfg.keepalive def close(self): util.close(self.sock) @@ -111,8 +114,8 @@ class ThreadWorker(base.Worker): fs.add_done_callback(self.finish_request) def enqueue_req(self, conn): - conn.init() - # submit the connection to a worker + # submit the connection to a worker thread + # (conn.init() is called in handle() to avoid SSL errors in main thread) fs = self.tpool.submit(self.handle, conn) self._wrap_future(fs, conn) @@ -149,7 +152,7 @@ class ThreadWorker(base.Worker): self.enqueue_req(conn) def murder_keepalived(self): - now = time.time() + now = time.monotonic() while True: with self._lock: try: @@ -273,6 +276,9 @@ class ThreadWorker(base.Worker): keepalive = False req = None try: + # Initialize connection in worker thread to handle SSL errors gracefully + # (ENOTCONN from ssl_wrap_socket would crash main thread otherwise) + conn.init() req = next(conn.parser) if not req: return (False, conn) @@ -280,6 +286,9 @@ class ThreadWorker(base.Worker): # handle the request keepalive = self.handle_request(req, conn) if keepalive: + # Discard any unread request body before keepalive + # to prevent socket appearing readable due to leftover bytes + conn.parser.finish_body() return (keepalive, conn) except http.errors.NoMoreData as e: self.log.debug("Ignored premature client disconnection. %s", e) diff --git a/tests/test_gthread.py b/tests/test_gthread.py new file mode 100644 index 00000000..1cc4bb39 --- /dev/null +++ b/tests/test_gthread.py @@ -0,0 +1,415 @@ +# +# This file is part of gunicorn released under the MIT license. +# See the NOTICE for more information. + +"""Tests for the gthread worker.""" + +import errno +import os +import queue +import selectors +import socket +import threading +import time +from collections import deque +from concurrent import futures +from functools import partial +from unittest import mock + +import pytest + +from gunicorn import http +from gunicorn.config import Config +from gunicorn.workers import gthread + + +class FakeSocket: + """Mock socket for testing.""" + + def __init__(self, data=b''): + self.data = data + self.closed = False + self.blocking = True + self._fileno = id(self) % 65536 + + def fileno(self): + return self._fileno + + def setblocking(self, blocking): + self.blocking = blocking + + def recv(self, size): + if self.closed: + raise OSError(errno.EBADF, "Bad file descriptor") + result = self.data[:size] + self.data = self.data[size:] + return result + + def send(self, data): + if self.closed: + raise OSError(errno.EPIPE, "Broken pipe") + return len(data) + + def close(self): + self.closed = True + + def getsockname(self): + return ('127.0.0.1', 8000) + + def getpeername(self): + return ('127.0.0.1', 12345) + + +class TestTConn: + """Tests for TConn connection wrapper.""" + + def test_tconn_init(self): + """Test TConn initialization.""" + cfg = Config() + sock = FakeSocket() + client = ('127.0.0.1', 12345) + server = ('127.0.0.1', 8000) + + conn = gthread.TConn(cfg, sock, client, server) + + assert conn.cfg is cfg + assert conn.sock is sock + assert conn.client == client + assert conn.server == server + assert conn.timeout is None + assert conn.parser is None + assert conn.initialized is False + + def test_tconn_init_sets_blocking_false(self): + """Test that TConn sets socket to non-blocking initially.""" + cfg = Config() + sock = FakeSocket() + sock.setblocking(True) + + conn = gthread.TConn(cfg, sock, ('127.0.0.1', 12345), ('127.0.0.1', 8000)) + + # TConn sets socket to non-blocking in __init__ + assert sock.blocking is False + + def test_tconn_init_method_sets_blocking_true(self): + """Test that conn.init() sets socket back to blocking.""" + cfg = Config() + sock = FakeSocket() + + conn = gthread.TConn(cfg, sock, ('127.0.0.1', 12345), ('127.0.0.1', 8000)) + conn.init() + + assert sock.blocking is True + assert conn.initialized is True + assert conn.parser is not None + + def test_tconn_set_timeout(self): + """Test timeout setting using monotonic clock.""" + cfg = Config() + cfg.set('keepalive', 5) + sock = FakeSocket() + + conn = gthread.TConn(cfg, sock, ('127.0.0.1', 12345), ('127.0.0.1', 8000)) + before = time.monotonic() + conn.set_timeout() + after = time.monotonic() + + assert conn.timeout is not None + assert before + 5 <= conn.timeout <= after + 5 + + def test_tconn_close(self): + """Test connection closing.""" + cfg = Config() + sock = FakeSocket() + + conn = gthread.TConn(cfg, sock, ('127.0.0.1', 12345), ('127.0.0.1', 8000)) + conn.close() + + assert sock.closed is True + + +class TestThreadWorker: + """Tests for ThreadWorker.""" + + def create_worker(self, cfg=None): + """Create a worker instance for testing.""" + if cfg is None: + cfg = Config() + cfg.set('workers', 1) + cfg.set('threads', 4) + cfg.set('worker_connections', 1000) + cfg.set('keepalive', 2) + + # Mock the required attributes + worker = gthread.ThreadWorker( + age=1, + ppid=os.getpid(), + sockets=[], + app=mock.Mock(), + timeout=30, + cfg=cfg, + log=mock.Mock(), + ) + return worker + + def test_worker_init(self): + """Test worker initialization.""" + worker = self.create_worker() + + assert worker.worker_connections == 1000 + assert worker.max_keepalived == 1000 - 4 # connections - threads + assert worker.tpool is None + assert worker.poller is None + assert worker._lock is None + assert worker.nr_conns == 0 + + def test_worker_check_config_warning(self): + """Test that check_config warns when keepalive impossible.""" + cfg = Config() + cfg.set('worker_connections', 4) + cfg.set('threads', 4) + cfg.set('keepalive', 2) + log = mock.Mock() + + gthread.ThreadWorker.check_config(cfg, log) + + log.warning.assert_called() + + def test_worker_check_config_no_warning(self): + """Test that check_config doesn't warn with valid config.""" + cfg = Config() + cfg.set('worker_connections', 100) + cfg.set('threads', 4) + cfg.set('keepalive', 2) + log = mock.Mock() + + gthread.ThreadWorker.check_config(cfg, log) + + log.warning.assert_not_called() + + def test_worker_init_process(self): + """Test worker process initialization.""" + worker = self.create_worker() + worker.tmp = mock.Mock() + worker.log = mock.Mock() + + # Mock super().init_process() to avoid full initialization + with mock.patch.object(gthread.base.Worker, 'init_process'): + worker.init_process() + + assert worker.tpool is not None + assert worker.poller is not None + assert worker._lock is not None + + # Cleanup + worker.tpool.shutdown(wait=False) + worker.poller.close() + + def test_worker_get_thread_pool(self): + """Test thread pool creation.""" + worker = self.create_worker() + + pool = worker.get_thread_pool() + + assert isinstance(pool, futures.ThreadPoolExecutor) + pool.shutdown(wait=False) + + def test_worker_murder_keepalived(self): + """Test that expired keepalive connections are cleaned up.""" + worker = self.create_worker() + worker.poller = selectors.DefaultSelector() + worker._lock = threading.RLock() + + # Create an expired connection (using monotonic to match implementation) + cfg = Config() + sock = FakeSocket() + conn = gthread.TConn(cfg, sock, ('127.0.0.1', 12345), ('127.0.0.1', 8000)) + conn.timeout = time.monotonic() - 10 # Expired 10 seconds ago + + worker._keep.append(conn) + worker.nr_conns = 1 + + # Register with poller (so it can be unregistered) + try: + # Can't register FakeSocket with real selector, mock it + with mock.patch.object(worker.poller, 'unregister'): + worker.murder_keepalived() + except (OSError, ValueError): + pass # Expected with fake socket + + # Connection should have been removed + assert len(worker._keep) == 0 + assert sock.closed is True + + worker.poller.close() + + def test_worker_is_parent_alive(self): + """Test parent process check.""" + worker = self.create_worker() + + # With correct ppid + worker.ppid = os.getppid() + assert worker.is_parent_alive() is True + + # With wrong ppid + worker.ppid = -1 + assert worker.is_parent_alive() is False + + +class TestFinishRequest: + """Tests for finish_request handling.""" + + def create_worker(self): + """Create a worker for testing.""" + cfg = Config() + cfg.set('workers', 1) + cfg.set('threads', 4) + cfg.set('worker_connections', 1000) + + worker = gthread.ThreadWorker( + age=1, + ppid=os.getpid(), + sockets=[], + app=mock.Mock(), + timeout=30, + cfg=cfg, + log=mock.Mock(), + ) + worker._lock = threading.RLock() + worker.poller = mock.Mock() + worker.alive = True + return worker + + def test_finish_request_cancelled(self): + """Test handling of cancelled future.""" + worker = self.create_worker() + worker.nr_conns = 1 + + conn = mock.Mock() + fs = mock.Mock() + fs.cancelled.return_value = True + fs.conn = conn + + worker.finish_request(fs) + + assert worker.nr_conns == 0 + conn.close.assert_called_once() + + def test_finish_request_keepalive(self): + """Test handling of keepalive response.""" + worker = self.create_worker() + worker.nr_conns = 1 + + conn = mock.Mock() + conn.sock = mock.Mock() + fs = mock.Mock() + fs.cancelled.return_value = False + fs.result.return_value = (True, conn) # keepalive=True + fs.conn = conn + + worker.finish_request(fs) + + assert worker.nr_conns == 1 # Connection kept + assert conn in worker._keep + conn.set_timeout.assert_called_once() + worker.poller.register.assert_called_once() + + def test_finish_request_close(self): + """Test handling of non-keepalive response.""" + worker = self.create_worker() + worker.nr_conns = 1 + + conn = mock.Mock() + fs = mock.Mock() + fs.cancelled.return_value = False + fs.result.return_value = (False, conn) # keepalive=False + fs.conn = conn + + worker.finish_request(fs) + + assert worker.nr_conns == 0 + conn.close.assert_called_once() + + def test_finish_request_exception(self): + """Test handling of exception in request.""" + worker = self.create_worker() + worker.nr_conns = 1 + + conn = mock.Mock() + fs = mock.Mock() + fs.cancelled.return_value = False + fs.result.side_effect = Exception("Test error") + fs.conn = conn + + worker.finish_request(fs) + + assert worker.nr_conns == 0 + conn.close.assert_called_once() + + +class TestAccept: + """Tests for connection acceptance.""" + + def create_worker(self): + """Create a worker for testing.""" + cfg = Config() + cfg.set('workers', 1) + cfg.set('threads', 4) + cfg.set('worker_connections', 1000) + + worker = gthread.ThreadWorker( + age=1, + ppid=os.getpid(), + sockets=[], + app=mock.Mock(), + timeout=30, + cfg=cfg, + log=mock.Mock(), + ) + worker._lock = threading.RLock() + worker.poller = mock.Mock() + return worker + + def test_accept_success(self): + """Test successful connection acceptance.""" + worker = self.create_worker() + worker.nr_conns = 0 + + client_sock = FakeSocket() + client_addr = ('127.0.0.1', 12345) + listener = mock.Mock() + listener.accept.return_value = (client_sock, client_addr) + server = ('127.0.0.1', 8000) + + worker.accept(server, listener) + + assert worker.nr_conns == 1 + worker.poller.register.assert_called_once() + + def test_accept_eagain(self): + """Test handling of EAGAIN during accept.""" + worker = self.create_worker() + worker.nr_conns = 0 + + listener = mock.Mock() + listener.accept.side_effect = OSError(errno.EAGAIN, "Try again") + server = ('127.0.0.1', 8000) + + # Should not raise + worker.accept(server, listener) + + assert worker.nr_conns == 0 + + def test_accept_econnaborted(self): + """Test handling of ECONNABORTED during accept.""" + worker = self.create_worker() + worker.nr_conns = 0 + + listener = mock.Mock() + listener.accept.side_effect = OSError(errno.ECONNABORTED, "Connection aborted") + server = ('127.0.0.1', 8000) + + # Should not raise + worker.accept(server, listener) + + assert worker.nr_conns == 0