mirror of
https://github.com/frappe/gunicorn.git
synced 2026-07-03 11:11:30 +08:00
gthread: Improve reliability and fix edge cases
This commit addresses three issues with the gthread worker:

1. Request body handling on keepalive
   - Add finish_body() method to Parser to discard unread body bytes
   - Call it before returning connections to the poller
   - Prevents socket appearing readable due to leftover body
   Fixes #3301

2. Timeout reliability with monotonic clock
   - Replace time.time() with time.monotonic() in set_timeout()
   - Replace time.time() with time.monotonic() in murder_keepalived()
   - Prevents timeout issues caused by NTP adjustments

3. SSL error handling
   - Move conn.init() from enqueue_req() to handle()
   - SSL handshake now runs in worker thread, not main thread
   - ENOTCONN errors during ssl_wrap_socket are caught per-connection
   - Prevents entire worker crashes on SSL handshake failures

Also adds comprehensive unit tests for the gthread worker.

Closes #3303
Closes #3308
This commit is contained in:
parent
1dc4ce9d59
commit
b43dc6d398
@ -25,16 +25,25 @@ class Parser:
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def finish_body(self):
    """Drain whatever remains of the current message's body.

    Intended to run before a keepalive connection goes back to the
    poller, so that leftover body bytes do not make the socket look
    readable. Does nothing when there is no current message.
    """
    if not self.mesg:
        return

    # Keep reading fixed-size chunks until the body reports no more data.
    while self.mesg.body.read(8192):
        pass
|
||||
|
||||
def __next__(self):
|
||||
# Stop if HTTP dictates a stop.
|
||||
if self.mesg and self.mesg.should_close():
|
||||
raise StopIteration()
|
||||
|
||||
# Discard any unread body of the previous message
|
||||
if self.mesg:
|
||||
data = self.mesg.body.read(8192)
|
||||
while data:
|
||||
data = self.mesg.body.read(8192)
|
||||
self.finish_body()
|
||||
|
||||
# Parse the next request
|
||||
self.req_count += 1
|
||||
|
||||
@ -46,6 +46,9 @@ class TConn:
|
||||
self.sock.setblocking(False)
|
||||
|
||||
def init(self):
|
||||
# Guard against double initialization
|
||||
if self.initialized:
|
||||
return
|
||||
self.initialized = True
|
||||
self.sock.setblocking(True)
|
||||
|
||||
@ -58,8 +61,8 @@ class TConn:
|
||||
self.parser = http.RequestParser(self.cfg, self.sock, self.client)
|
||||
|
||||
def set_timeout(self):
    """Arm the keepalive deadline for this connection.

    Sets ``self.timeout`` to the current monotonic-clock reading plus
    ``self.cfg.keepalive``. The value is only meaningful when compared
    against other ``time.monotonic()`` readings.
    """
    # Use monotonic clock for reliability (time.time() can jump due to NTP)
    self.timeout = time.monotonic() + self.cfg.keepalive
|
||||
|
||||
def close(self):
|
||||
util.close(self.sock)
|
||||
@ -111,8 +114,8 @@ class ThreadWorker(base.Worker):
|
||||
fs.add_done_callback(self.finish_request)
|
||||
|
||||
def enqueue_req(self, conn):
    """Submit a connection to the thread pool for handling.

    The connection is intentionally NOT initialized here: conn.init()
    is left to handle(), which runs in a worker thread, so that SSL
    errors do not surface in the main thread.
    """
    # submit the connection to a worker thread
    # (conn.init() is called in handle() to avoid SSL errors in main thread)
    fs = self.tpool.submit(self.handle, conn)
    self._wrap_future(fs, conn)
|
||||
|
||||
@ -149,7 +152,7 @@ class ThreadWorker(base.Worker):
|
||||
self.enqueue_req(conn)
|
||||
|
||||
def murder_keepalived(self):
|
||||
now = time.time()
|
||||
now = time.monotonic()
|
||||
while True:
|
||||
with self._lock:
|
||||
try:
|
||||
@ -273,6 +276,9 @@ class ThreadWorker(base.Worker):
|
||||
keepalive = False
|
||||
req = None
|
||||
try:
|
||||
# Initialize connection in worker thread to handle SSL errors gracefully
|
||||
# (ENOTCONN from ssl_wrap_socket would crash main thread otherwise)
|
||||
conn.init()
|
||||
req = next(conn.parser)
|
||||
if not req:
|
||||
return (False, conn)
|
||||
@ -280,6 +286,9 @@ class ThreadWorker(base.Worker):
|
||||
# handle the request
|
||||
keepalive = self.handle_request(req, conn)
|
||||
if keepalive:
|
||||
# Discard any unread request body before keepalive
|
||||
# to prevent socket appearing readable due to leftover bytes
|
||||
conn.parser.finish_body()
|
||||
return (keepalive, conn)
|
||||
except http.errors.NoMoreData as e:
|
||||
self.log.debug("Ignored premature client disconnection. %s", e)
|
||||
|
||||
415
tests/test_gthread.py
Normal file
415
tests/test_gthread.py
Normal file
@ -0,0 +1,415 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""Tests for the gthread worker."""
|
||||
|
||||
import errno
|
||||
import os
|
||||
import queue
|
||||
import selectors
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from concurrent import futures
|
||||
from functools import partial
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from gunicorn import http
|
||||
from gunicorn.config import Config
|
||||
from gunicorn.workers import gthread
|
||||
|
||||
|
||||
class FakeSocket:
    """In-memory stand-in for a socket, used by the tests.

    ``recv`` hands out the bytes supplied at construction time and
    ``send`` reports every byte as written. After ``close`` both raise
    ``OSError``.
    """

    def __init__(self, data=b''):
        self.data = data
        self.closed = False
        self.blocking = True
        # Pseudo descriptor derived from the object's identity.
        self._fileno = id(self) % 65536

    def fileno(self):
        """Return the pseudo file descriptor picked at construction."""
        return self._fileno

    def setblocking(self, blocking):
        """Remember the requested blocking mode."""
        self.blocking = blocking

    def recv(self, size):
        """Consume and return up to *size* bytes of the buffered data.

        Raises OSError(EBADF) once the socket has been closed.
        """
        if self.closed:
            raise OSError(errno.EBADF, "Bad file descriptor")
        chunk, self.data = self.data[:size], self.data[size:]
        return chunk

    def send(self, data):
        """Report *data* as fully sent and return its length.

        Raises OSError(EPIPE) once the socket has been closed.
        """
        if not self.closed:
            return len(data)
        raise OSError(errno.EPIPE, "Broken pipe")

    def close(self):
        """Mark the socket as closed."""
        self.closed = True

    def getsockname(self):
        """Return the fixed local address."""
        return ('127.0.0.1', 8000)

    def getpeername(self):
        """Return the fixed peer address."""
        return ('127.0.0.1', 12345)
|
||||
|
||||
|
||||
class TestTConn:
    """Checks for the TConn connection wrapper."""

    def test_tconn_init(self):
        """A fresh TConn stores its arguments and starts uninitialized."""
        config = Config()
        fake_sock = FakeSocket()
        client_addr = ('127.0.0.1', 12345)
        server_addr = ('127.0.0.1', 8000)

        tconn = gthread.TConn(config, fake_sock, client_addr, server_addr)

        # Constructor arguments are kept as-is.
        assert tconn.cfg is config
        assert tconn.sock is fake_sock
        assert tconn.client == client_addr
        assert tconn.server == server_addr
        # Nothing else has been set up yet.
        assert tconn.timeout is None
        assert tconn.parser is None
        assert tconn.initialized is False

    def test_tconn_init_sets_blocking_false(self):
        """Constructing a TConn switches the socket to non-blocking."""
        config = Config()
        fake_sock = FakeSocket()
        fake_sock.setblocking(True)

        tconn = gthread.TConn(
            config, fake_sock, ('127.0.0.1', 12345), ('127.0.0.1', 8000)
        )

        # TConn sets socket to non-blocking in __init__
        assert fake_sock.blocking is False

    def test_tconn_init_method_sets_blocking_true(self):
        """conn.init() restores blocking mode and creates the parser."""
        config = Config()
        fake_sock = FakeSocket()

        tconn = gthread.TConn(
            config, fake_sock, ('127.0.0.1', 12345), ('127.0.0.1', 8000)
        )
        tconn.init()

        assert fake_sock.blocking is True
        assert tconn.initialized is True
        assert tconn.parser is not None

    def test_tconn_set_timeout(self):
        """set_timeout() lands keepalive seconds ahead on the monotonic clock."""
        config = Config()
        config.set('keepalive', 5)
        fake_sock = FakeSocket()

        tconn = gthread.TConn(
            config, fake_sock, ('127.0.0.1', 12345), ('127.0.0.1', 8000)
        )
        lower = time.monotonic()
        tconn.set_timeout()
        upper = time.monotonic()

        assert tconn.timeout is not None
        assert lower + 5 <= tconn.timeout <= upper + 5

    def test_tconn_close(self):
        """close() closes the wrapped socket."""
        config = Config()
        fake_sock = FakeSocket()

        tconn = gthread.TConn(
            config, fake_sock, ('127.0.0.1', 12345), ('127.0.0.1', 8000)
        )
        tconn.close()

        assert fake_sock.closed is True
|
||||
|
||||
|
||||
class TestThreadWorker:
    """Unit tests covering ThreadWorker itself."""

    def create_worker(self, cfg=None):
        """Build a ThreadWorker with a mocked app and log.

        When *cfg* is omitted, a default Config is created with 1 worker,
        4 threads, 1000 worker connections and a keepalive of 2. A
        caller-supplied *cfg* is used untouched.
        """
        if cfg is None:
            cfg = Config()
            defaults = (
                ('workers', 1),
                ('threads', 4),
                ('worker_connections', 1000),
                ('keepalive', 2),
            )
            for option, value in defaults:
                cfg.set(option, value)

        # Mock the required attributes
        return gthread.ThreadWorker(
            age=1,
            ppid=os.getpid(),
            sockets=[],
            app=mock.Mock(),
            timeout=30,
            cfg=cfg,
            log=mock.Mock(),
        )

    def test_worker_init(self):
        """A new worker has its limits set and no runtime resources yet."""
        worker = self.create_worker()

        assert worker.worker_connections == 1000
        assert worker.max_keepalived == 1000 - 4  # connections - threads
        assert worker.tpool is None
        assert worker.poller is None
        assert worker._lock is None
        assert worker.nr_conns == 0

    def test_worker_check_config_warning(self):
        """check_config warns when worker_connections equals threads."""
        config = Config()
        config.set('worker_connections', 4)
        config.set('threads', 4)
        config.set('keepalive', 2)
        logger = mock.Mock()

        gthread.ThreadWorker.check_config(config, logger)

        logger.warning.assert_called()

    def test_worker_check_config_no_warning(self):
        """check_config stays quiet when connections exceed threads."""
        config = Config()
        config.set('worker_connections', 100)
        config.set('threads', 4)
        config.set('keepalive', 2)
        logger = mock.Mock()

        gthread.ThreadWorker.check_config(config, logger)

        logger.warning.assert_not_called()

    def test_worker_init_process(self):
        """init_process creates the thread pool, poller and lock."""
        worker = self.create_worker()
        worker.tmp = mock.Mock()
        worker.log = mock.Mock()

        # Mock super().init_process() to avoid full initialization
        with mock.patch.object(gthread.base.Worker, 'init_process'):
            worker.init_process()

        assert worker.tpool is not None
        assert worker.poller is not None
        assert worker._lock is not None

        # Cleanup
        worker.tpool.shutdown(wait=False)
        worker.poller.close()

    def test_worker_get_thread_pool(self):
        """get_thread_pool returns a ThreadPoolExecutor."""
        worker = self.create_worker()

        executor = worker.get_thread_pool()

        assert isinstance(executor, futures.ThreadPoolExecutor)
        executor.shutdown(wait=False)

    def test_worker_murder_keepalived(self):
        """An expired keepalive connection is dropped and closed."""
        worker = self.create_worker()
        worker.poller = selectors.DefaultSelector()
        worker._lock = threading.RLock()

        # Create an expired connection (using monotonic to match implementation)
        fake_sock = FakeSocket()
        stale = gthread.TConn(
            Config(), fake_sock, ('127.0.0.1', 12345), ('127.0.0.1', 8000)
        )
        stale.timeout = time.monotonic() - 10  # Expired 10 seconds ago

        worker._keep.append(stale)
        worker.nr_conns = 1

        # Register with poller (so it can be unregistered)
        try:
            # Can't register FakeSocket with real selector, mock it
            with mock.patch.object(worker.poller, 'unregister'):
                worker.murder_keepalived()
        except (OSError, ValueError):
            pass  # Expected with fake socket

        # Connection should have been removed
        assert len(worker._keep) == 0
        assert fake_sock.closed is True

        worker.poller.close()

    def test_worker_is_parent_alive(self):
        """is_parent_alive reflects whether ppid matches the real parent."""
        worker = self.create_worker()

        # With correct ppid
        worker.ppid = os.getppid()
        assert worker.is_parent_alive() is True

        # With wrong ppid
        worker.ppid = -1
        assert worker.is_parent_alive() is False
|
||||
|
||||
|
||||
class TestFinishRequest:
    """Checks for how finish_request disposes of a completed future."""

    def create_worker(self):
        """Build a live ThreadWorker with a real lock and a mocked poller."""
        config = Config()
        for option, value in (
            ('workers', 1),
            ('threads', 4),
            ('worker_connections', 1000),
        ):
            config.set(option, value)

        worker = gthread.ThreadWorker(
            age=1,
            ppid=os.getpid(),
            sockets=[],
            app=mock.Mock(),
            timeout=30,
            cfg=config,
            log=mock.Mock(),
        )
        worker._lock = threading.RLock()
        worker.poller = mock.Mock()
        worker.alive = True
        return worker

    def test_finish_request_cancelled(self):
        """A cancelled future releases and closes its connection."""
        worker = self.create_worker()
        worker.nr_conns = 1

        connection = mock.Mock()
        future = mock.Mock()
        future.cancelled.return_value = True
        future.conn = connection

        worker.finish_request(future)

        assert worker.nr_conns == 0
        connection.close.assert_called_once()

    def test_finish_request_keepalive(self):
        """A keepalive result parks the connection and re-registers it."""
        worker = self.create_worker()
        worker.nr_conns = 1

        connection = mock.Mock()
        connection.sock = mock.Mock()
        future = mock.Mock()
        future.cancelled.return_value = False
        future.result.return_value = (True, connection)  # keepalive=True
        future.conn = connection

        worker.finish_request(future)

        assert worker.nr_conns == 1  # Connection kept
        assert connection in worker._keep
        connection.set_timeout.assert_called_once()
        worker.poller.register.assert_called_once()

    def test_finish_request_close(self):
        """A non-keepalive result releases and closes the connection."""
        worker = self.create_worker()
        worker.nr_conns = 1

        connection = mock.Mock()
        future = mock.Mock()
        future.cancelled.return_value = False
        future.result.return_value = (False, connection)  # keepalive=False
        future.conn = connection

        worker.finish_request(future)

        assert worker.nr_conns == 0
        connection.close.assert_called_once()

    def test_finish_request_exception(self):
        """A future whose result raises still gets its connection closed."""
        worker = self.create_worker()
        worker.nr_conns = 1

        connection = mock.Mock()
        future = mock.Mock()
        future.cancelled.return_value = False
        future.result.side_effect = Exception("Test error")
        future.conn = connection

        worker.finish_request(future)

        assert worker.nr_conns == 0
        connection.close.assert_called_once()
|
||||
|
||||
|
||||
class TestAccept:
    """Checks for ThreadWorker.accept()."""

    def create_worker(self):
        """Build a ThreadWorker with a real lock and a mocked poller."""
        config = Config()
        for option, value in (
            ('workers', 1),
            ('threads', 4),
            ('worker_connections', 1000),
        ):
            config.set(option, value)

        worker = gthread.ThreadWorker(
            age=1,
            ppid=os.getpid(),
            sockets=[],
            app=mock.Mock(),
            timeout=30,
            cfg=config,
            log=mock.Mock(),
        )
        worker._lock = threading.RLock()
        worker.poller = mock.Mock()
        return worker

    def test_accept_success(self):
        """An accepted client is counted and registered with the poller."""
        worker = self.create_worker()
        worker.nr_conns = 0

        accepted = (FakeSocket(), ('127.0.0.1', 12345))
        listener = mock.Mock()
        listener.accept.return_value = accepted
        server = ('127.0.0.1', 8000)

        worker.accept(server, listener)

        assert worker.nr_conns == 1
        worker.poller.register.assert_called_once()

    def test_accept_eagain(self):
        """EAGAIN from accept() is swallowed and nothing is counted."""
        worker = self.create_worker()
        worker.nr_conns = 0

        listener = mock.Mock()
        listener.accept.side_effect = OSError(errno.EAGAIN, "Try again")
        server = ('127.0.0.1', 8000)

        # Should not raise
        worker.accept(server, listener)

        assert worker.nr_conns == 0

    def test_accept_econnaborted(self):
        """ECONNABORTED from accept() is swallowed and nothing is counted."""
        worker = self.create_worker()
        worker.nr_conns = 0

        listener = mock.Mock()
        listener.accept.side_effect = OSError(
            errno.ECONNABORTED, "Connection aborted"
        )
        server = ('127.0.0.1', 8000)

        # Should not raise
        worker.accept(server, listener)

        assert worker.nr_conns == 0
|
||||
Loading…
x
Reference in New Issue
Block a user