feat(dirty): add stash - global shared state between workers (#3503)

* feat(dirty): add stash - global shared state between workers

Add a simple key-value store (stash) that allows dirty workers to share
state through the arbiter. Tables are stored directly in arbiter memory
for fast access and simplicity.

Features:
- Auto-create tables on first access
- Dict-like interface via stash.table()
- Pattern matching for keys (glob patterns)
- Module-level API: stash.put(), stash.get(), stash.delete(), etc.

Usage:
    from gunicorn.dirty import stash

    stash.put("sessions", "user:1", {"name": "Alice"})
    user = stash.get("sessions", "user:1")

    # Or dict-like
    sessions = stash.table("sessions")
    sessions["user:1"] = {"name": "Alice"}

New files:
- gunicorn/dirty/stash.py - Client API and StashTable class
- Protocol additions for MSG_TYPE_STASH and STASH_OP_* codes

Note: Tables are ephemeral - lost if arbiter restarts.

* test(dirty): add tests for stash protocol and encoding

Test coverage for:
- Stash message creation and encoding
- Protocol constants (MSG_TYPE_STASH, STASH_OP_*)
- Error classes (StashError, StashTableNotFoundError, StashKeyNotFoundError)
- StashTable dict-like interface
- Edge cases: unicode, complex values, special patterns

* example(dirty): add stash usage example and integration tests

- Add SessionApp to dirty_app.py demonstrating stash usage
- Add /session/* endpoints to wsgi_app.py
- Add test_stash_integration.py with Docker tests
- Update docker-compose.yml with stash-test service
- Fix: Set GUNICORN_DIRTY_SOCKET in dirty arbiter for worker access

* docs(dirty): add stash documentation
Benoit Chesneau 2026-02-12 21:45:49 +01:00 committed by GitHub
parent 236c9371d0
commit 709a6ad159
GPG Key ID: B5690EEEBB952194
13 changed files with 1626 additions and 3 deletions


@ -558,6 +558,259 @@ def generate_view(request):
4. **Keep chunks small** - Smaller chunks provide better perceived latency
5. **Handle client disconnection** - Streams continue even if client disconnects; design accordingly
## Stash (Shared State via Message Passing)
Stash provides shared state between dirty workers, similar to Erlang's ETS (Erlang Term Storage). Workers remain fully isolated - all state access goes through message passing to the arbiter.
### Architecture
```
                    +------------------+
                    |  Dirty Arbiter   |
                    |                  |
                    |  stash_tables:   |
                    |    sessions: {}  |
                    |    cache: {}     |
                    +--------+---------+
                             |
             Unix Socket IPC (message passing)
                             |
         +-------------------+-------------------+
         |                   |                   |
   +-----v-----+       +-----v-----+       +-----v-----+
   | Worker 1  |       | Worker 2  |       | Worker 3  |
   |           |       |           |       |           |
   | (isolated)|       | (isolated)|       | (isolated)|
   +-----------+       +-----------+       +-----------+

   Workers have NO shared memory.
   All stash operations are IPC messages to arbiter.
```
### How It Works
1. Worker calls `stash.put("sessions", "user:1", data)`
2. Worker sends message to arbiter via Unix socket
3. Arbiter stores data in its memory (`self.stash_tables`)
4. Arbiter sends response back to worker
5. Worker receives confirmation
This is **not** shared memory - workers remain fully isolated. The arbiter acts as a centralized store that workers communicate with via message passing. This matches Erlang's model where ETS tables are owned by a process.
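The steps above can be sketched without sockets as one handler function applied to an in-memory dict. This is a minimal illustration, not gunicorn's implementation: `handle` and the reply shape are assumptions made for the example, while the message fields and the op codes 1 and 2 mirror `STASH_OP_PUT` and `STASH_OP_GET` from this commit.

```python
# Hypothetical stand-ins for illustration only.
OP_PUT, OP_GET = 1, 2


def handle(tables, message):
    """Arbiter side: apply one stash message to the in-memory tables."""
    name, op = message["table"], message["op"]
    if op == OP_PUT:
        # Tables are created on first write.
        tables.setdefault(name, {})[message["key"]] = message["value"]
        return {"id": message["id"], "result": True}
    if op == OP_GET:
        # Reads never create a table; a missing key yields None here.
        return {"id": message["id"], "result": tables.get(name, {}).get(message["key"])}
    return {"id": message["id"], "error": f"unknown op {op}"}


# The "arbiter" dict is the only place where data lives.
arbiter_tables = {}

# Worker 1 stores a value; worker 2 reads it back through the same handler.
ack = handle(arbiter_tables, {"id": 1, "op": OP_PUT, "table": "sessions",
                              "key": "user:1", "value": {"name": "Alice"}})
reply = handle(arbiter_tables, {"id": 2, "op": OP_GET, "table": "sessions",
                                "key": "user:1"})
print(ack["result"], reply["result"])
```

In the real system each `handle` call corresponds to one request/response pair over the arbiter's Unix socket, which is why every stash operation costs a round-trip.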
### Basic Usage
```python
from gunicorn.dirty import stash
# Store a value (table auto-created)
# This sends a message to arbiter, which stores it
stash.put("sessions", "user:123", {"name": "Alice", "role": "admin"})
# Retrieve a value
# This sends a request to arbiter, which returns the value
user = stash.get("sessions", "user:123")
# Delete a key
stash.delete("sessions", "user:123")
# Check existence
if stash.exists("sessions", "user:123"):
    print("Session exists")
# List keys with pattern matching
keys = stash.keys("sessions", pattern="user:*")
```
### Dict-like Interface
For more Pythonic access, use the table interface:
```python
from gunicorn.dirty import stash
# Get a table reference
sessions = stash.table("sessions")
# Dict-like operations (each is an IPC message)
sessions["user:123"] = {"name": "Alice"}
user = sessions["user:123"]
del sessions["user:123"]
# Iteration
for key in sessions:
    print(key, sessions[key])
# Length
count = len(sessions)
```
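Because each dict-like operation is an IPC message, loops over a table cost more than they look. The sketch below counts simulated round-trips with a hypothetical `CountingTable` class; it assumes that iteration issues a single "keys" request and that each item lookup is one "get" request, which may differ from how `StashTable` actually behaves.

```python
class CountingTable:
    """Hypothetical dict-like view that counts simulated IPC round-trips."""

    def __init__(self):
        self._data = {}
        self.round_trips = 0

    def __setitem__(self, key, value):
        self.round_trips += 1  # one "put" request
        self._data[key] = value

    def __getitem__(self, key):
        self.round_trips += 1  # one "get" request
        return self._data[key]

    def __iter__(self):
        self.round_trips += 1  # assumed: one "keys" request per iteration
        return iter(list(self._data))


sessions = CountingTable()
for i in range(3):
    sessions[f"user:{i}"] = {"n": i}

sessions.round_trips = 0  # measure only the read loop below
for key in sessions:
    _ = sessions[key]

print(sessions.round_trips)  # one keys request plus three gets
```

Under these assumptions, reading N entries costs N + 1 round-trips, so storing related fields together under one key tends to be cheaper than looping over many small keys.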
### Table Management
```python
from gunicorn.dirty import stash
# Explicit table creation (idempotent)
stash.ensure("cache")
# Get table info
info = stash.info("sessions")
print(f"Table has {info['size']} entries")
# Clear all entries in a table
stash.clear("sessions")
# Delete entire table
stash.delete_table("sessions")
# List all tables
tables = stash.tables()
```
### Using Stash in DirtyApp
Declare tables your app uses with the `stashes` class attribute:
```python
import time

from gunicorn.dirty import DirtyApp, stash


class SessionApp(DirtyApp):
    # Tables declared here are auto-created on startup
    stashes = ["sessions", "counters"]

    def init(self):
        # Initialize counter if needed
        if not stash.exists("counters", "requests"):
            stash.put("counters", "requests", 0)

    def login(self, user_id, user_data):
        """Store session - any worker can read it via arbiter."""
        stash.put("sessions", f"user:{user_id}", {
            "data": user_data,
            "logged_in_at": time.time(),
        })
        self._increment_counter()
        return {"status": "ok"}

    def get_session(self, user_id):
        """Get session - request goes to arbiter."""
        return stash.get("sessions", f"user:{user_id}")

    def _increment_counter(self):
        """Increment global counter via arbiter (read-modify-write, not atomic)."""
        current = stash.get("counters", "requests", 0)
        stash.put("counters", "requests", current + 1)

    def close(self):
        pass
```
### API Reference
| Function | Description |
|----------|-------------|
| `stash.put(table, key, value)` | Store a value (table auto-created) |
| `stash.get(table, key, default=None)` | Retrieve a value, or `default` if the key is missing |
| `stash.delete(table, key)` | Delete a key, returns True if deleted |
| `stash.exists(table, key=None)` | Check whether a key exists, or the table itself when `key` is omitted |
| `stash.keys(table, pattern=None)` | List keys, optionally filtered by a glob pattern (`fnmatch` syntax) |
| `stash.clear(table)` | Delete all entries in table |
| `stash.info(table)` | Get table info (size, etc.) |
| `stash.ensure(table)` | Create table if not exists |
| `stash.delete_table(table)` | Delete entire table |
| `stash.tables()` | List all table names |
| `stash.table(name)` | Get dict-like interface |
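The `pattern` argument of `stash.keys()` is matched in the arbiter with `fnmatch.fnmatch(str(key), pattern)`, so the usual shell-style wildcards apply. The sketch below runs the same filter on a local list; the key names and the `match` helper are made up for illustration.

```python
import fnmatch

# Example keys; in practice these would live in an arbiter-side table.
keys = ["user:1", "user:22", "admin:1", "cache:model:v1"]


def match(pattern):
    """Apply the same filter expression the arbiter uses for STASH_OP_KEYS."""
    return [k for k in keys if fnmatch.fnmatch(str(k), pattern)]


print(match("user:*"))   # '*' matches any run of characters
print(match("user:?"))   # '?' matches exactly one character
print(match("*:1"))      # wildcards may appear anywhere in the pattern
print(match("[ac]*"))    # '[...]' matches one character from the set
```

Note that `fnmatch.fnmatch` normalizes case according to the operating system, so matching is case-sensitive on POSIX systems but not on Windows.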
### Patterns and Use Cases
**Session Storage:**
```python
# Store session on login (worker 1)
stash.put("sessions", f"user:{user_id}", session_data)
# Check session on request (may be worker 2)
session = stash.get("sessions", f"user:{user_id}")
if session is None:
    raise AuthError("Not logged in")
```
**Shared Cache:**
```python
def get_expensive_result(key):
    # Check cache first (via arbiter)
    cached = stash.get("cache", key)
    if cached is not None:
        return cached
    # Compute and cache
    result = expensive_computation(key)
    stash.put("cache", key, result)
    return result
```
**Global Counters:**
```python
def increment_counter(name):
    # Note: not atomic - two workers could read same value
    current = stash.get("counters", name, 0)
    stash.put("counters", name, current + 1)
    return current + 1
```
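The lost-update risk in the counter pattern can be reproduced deterministically by interleaving the two steps by hand. In this sketch a plain dict stands in for the arbiter-side table, and the two "workers" are just variable names.

```python
counters = {"requests": 0}  # stands in for the arbiter-side "counters" table

# Both workers read before either one writes back.
seen_by_worker_1 = counters["requests"]
seen_by_worker_2 = counters["requests"]

# Each worker writes its own view plus one; the second write overwrites the first.
counters["requests"] = seen_by_worker_1 + 1
counters["requests"] = seen_by_worker_2 + 1

print(counters["requests"])  # 1, although two increments were attempted
```

Counters maintained this way are therefore approximate under concurrency; exact counts need an external system with atomic increments.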
**Feature Flags:**
```python
# Set flag (from admin endpoint)
stash.put("flags", "new_feature", True)
# Check flag (from any worker)
if stash.get("flags", "new_feature", False):
    enable_new_feature()
```
### Error Handling
```python
from gunicorn.dirty import stash
from gunicorn.dirty.stash import (
    StashError,
    StashTableNotFoundError,
    StashKeyNotFoundError,
)

try:
    info = stash.info("nonexistent")
except StashTableNotFoundError as e:
    print(f"Table not found: {e.table_name}")

# Using get() with a default avoids StashKeyNotFoundError
value = stash.get("table", "key", default="fallback")
```
### Best Practices
1. **Use descriptive table names** - `user_sessions`, `ml_cache`, not `data`
2. **Use key prefixes** - `user:123`, `cache:model:v1` for organization
3. **Handle missing data** - Always provide defaults or check existence
4. **Don't store large data** - Every access serializes the value and sends it over the arbiter socket
5. **Remember it's ephemeral** - Data is lost on arbiter restart
### Advantages
- **Worker isolation** - Workers remain fully isolated; no shared memory bugs
- **Simple API** - Dict-like interface with no locks to manage (read-modify-write sequences are not atomic; see Limitations)
- **Binary support** - Efficiently stores bytes (images, model weights)
- **Pattern matching** - `keys(pattern="user:*")` for querying
- **Zero setup** - Works automatically with dirty workers
- **Table-based** - Organize data into logical namespaces
### Limitations
- **No persistence** - Data lives only in arbiter memory
- **No transactions** - No atomic read-modify-write operations
- **No TTL** - Entries don't expire automatically
- **IPC overhead** - Each operation is a Unix socket round-trip to the arbiter
- **Single arbiter** - Not distributed across multiple machines
For persistent or distributed state, use Redis, PostgreSQL, or similar external systems.
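Since entries never expire on their own, an application that wants expiry has to emulate it. One possible approach, sketched here with a plain dict standing in for a stash table, is to store an expiry timestamp next to each value and evict lazily on read. The helpers `put_with_ttl` and `get_fresh` are hypothetical and not part of the stash API.

```python
import time


def put_with_ttl(table, key, value, ttl, now=None):
    """Store value together with the time at which it stops being valid."""
    now = time.time() if now is None else now
    table[key] = {"value": value, "expires_at": now + ttl}


def get_fresh(table, key, default=None, now=None):
    """Return the value if present and not expired, else default."""
    now = time.time() if now is None else now
    entry = table.get(key)
    if entry is None:
        return default
    if entry["expires_at"] <= now:
        table.pop(key, None)  # lazy eviction on read
        return default
    return entry["value"]


cache = {}  # plain dict standing in for a stash table
put_with_ttl(cache, "model:v1", "weights", ttl=60, now=1000.0)
print(get_fresh(cache, "model:v1", now=1030.0))  # still fresh
print(get_fresh(cache, "model:v1", now=1061.0))  # expired, returns default
```

With the real stash the same idea would go through `stash.put`, `stash.get`, and `stash.delete`. Because those calls are not atomic together, two workers may both see an expired entry and both recompute it.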
### Flask Example
```python


@ -11,9 +11,11 @@ This demonstrates how to create a DirtyApp that:
3. Cleans up on shutdown (close)
"""
import os
import time
import hashlib
from gunicorn.dirty.app import DirtyApp
from gunicorn.dirty import stash
class MLApp(DirtyApp):
@ -171,3 +173,96 @@ class ComputeApp(DirtyApp):
def close(self):
print(f"[ComputeApp] Shutting down. Total computations: {self.computation_count}")
class SessionApp(DirtyApp):
"""
Example dirty application demonstrating stash (shared state).
This shows how multiple dirty workers can share state through
the arbiter's stash tables. All workers see the same data.
"""
# Declare stash tables used by this app (auto-created on startup)
stashes = ["sessions", "counters"]
def __init__(self):
self.worker_pid = None
def init(self):
self.worker_pid = os.getpid()
print(f"[SessionApp] Initialized on worker {self.worker_pid}")
# Initialize a global counter if it doesn't exist
if not stash.exists("counters", "requests"):
stash.put("counters", "requests", 0)
def __call__(self, action, *args, **kwargs):
method = getattr(self, action, None)
if method is None or action.startswith('_'):
raise ValueError(f"Unknown action: {action}")
return method(*args, **kwargs)
def login(self, user_id, user_data):
"""Store user session in shared stash."""
session = {
"user_id": user_id,
"data": user_data,
"logged_in_at": time.time(),
"worker_pid": self.worker_pid,
}
stash.put("sessions", f"user:{user_id}", session)
self._increment_counter()
return {"status": "ok", "session": session}
def logout(self, user_id):
"""Remove user session."""
key = f"user:{user_id}"
if stash.exists("sessions", key):
stash.delete("sessions", key)
self._increment_counter()
return {"status": "logged_out", "user_id": user_id}
return {"status": "not_found", "user_id": user_id}
def get_session(self, user_id):
"""Get user session - visible from any worker."""
session = stash.get("sessions", f"user:{user_id}")
self._increment_counter()
return {
"session": session,
"served_by_worker": self.worker_pid,
}
def list_sessions(self):
"""List all active sessions."""
keys = stash.keys("sessions", pattern="user:*")
sessions = []
for key in keys:
sessions.append(stash.get("sessions", key))
self._increment_counter()
return {
"sessions": sessions,
"count": len(sessions),
"served_by_worker": self.worker_pid,
}
def get_stats(self):
"""Get global request counter (shared across all workers)."""
count = stash.get("counters", "requests", 0)
return {
"total_requests": count,
"served_by_worker": self.worker_pid,
}
def _increment_counter(self):
"""Increment global request counter."""
current = stash.get("counters", "requests", 0)
stash.put("counters", "requests", current + 1)
def clear_all(self):
"""Clear all sessions (for testing)."""
stash.clear("sessions")
stash.put("counters", "requests", 0)
return {"status": "cleared"}
def close(self):
print(f"[SessionApp] Shutting down worker {self.worker_pid}")


@ -52,3 +52,15 @@ services:
environment:
- TEST_BASE_URL=http://server:8000
command: python examples/dirty_example/test_integration.py
# Run stash integration test against the server
stash-test:
build:
context: ../..
dockerfile: examples/dirty_example/Dockerfile
depends_on:
server:
condition: service_healthy
environment:
- TEST_BASE_URL=http://server:8000
command: python examples/dirty_example/test_stash_integration.py


@ -22,6 +22,7 @@ timeout = 30
dirty_apps = [
"examples.dirty_example.dirty_app:MLApp",
"examples.dirty_example.dirty_app:ComputeApp",
"examples.dirty_example.dirty_app:SessionApp",
]
dirty_workers = 2
dirty_timeout = 300


@ -0,0 +1,226 @@
#!/usr/bin/env python3
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
Integration tests for stash (shared state) functionality.
These tests verify that stash works correctly across multiple dirty workers,
demonstrating that state is truly shared.
Run with Docker:
docker-compose up --build
docker-compose run --rm stash-test
"""
import json
import os
import sys
import urllib.request
import urllib.error
BASE_URL = os.environ.get("TEST_BASE_URL", "http://localhost:8000")
def request(path):
"""Make HTTP request and return JSON response."""
url = f"{BASE_URL}{path}"
try:
with urllib.request.urlopen(url, timeout=10) as resp:
return json.loads(resp.read().decode())
except urllib.error.HTTPError as e:
return {"error": str(e), "code": e.code}
except urllib.error.URLError as e:
return {"error": str(e)}
def test_stash_shared_state():
"""Test that stash state is shared across workers."""
print("\n=== Test: Stash Shared State ===")
# Clear any existing state
result = request("/session/clear")
print(f"Clear: {result}")
# Login a user
result = request("/session/login?user_id=100&name=Alice")
print(f"Login Alice: {result}")
assert result.get("status") == "ok", f"Login failed: {result}"
worker1 = result.get("session", {}).get("worker_pid")
print(f" -> Handled by worker: {worker1}")
# Make multiple requests to potentially hit different workers
# and verify they all see the same session
workers_seen = set()
for i in range(5):
result = request("/session/get?user_id=100")
worker = result.get("served_by_worker")
workers_seen.add(worker)
session = result.get("session")
assert session is not None, f"Session not found on request {i+1}"
assert session.get("data", {}).get("name") == "Alice", "Wrong session data"
print(f" -> Session visible from workers: {workers_seen}")
print("PASSED: State is shared across workers")
return True
def test_stash_counter():
"""Test that global counter increments correctly."""
print("\n=== Test: Global Counter ===")
# Clear state
request("/session/clear")
# Get initial stats
result = request("/session/stats")
initial = result.get("total_requests", 0)
print(f"Initial count: {initial}")
# Make several requests
for i in range(5):
request(f"/session/login?user_id={i}&name=User{i}")
# Check counter increased
result = request("/session/stats")
final = result.get("total_requests", 0)
print(f"Final count: {final}")
# Each login increments counter by 1
assert final >= initial + 5, f"Counter didn't increment enough: {initial} -> {final}"
print("PASSED: Global counter works across workers")
return True
def test_stash_list_sessions():
"""Test listing all sessions."""
print("\n=== Test: List Sessions ===")
# Clear and create some sessions
request("/session/clear")
request("/session/login?user_id=1&name=Alice")
request("/session/login?user_id=2&name=Bob")
request("/session/login?user_id=3&name=Charlie")
# List all sessions
result = request("/session/list")
sessions = result.get("sessions", [])
count = result.get("count", 0)
print(f"Sessions: {count}")
for s in sessions:
print(f" - user:{s.get('user_id')} = {s.get('data', {}).get('name')}")
assert count == 3, f"Expected 3 sessions, got {count}"
print("PASSED: List sessions works")
return True
def test_stash_logout():
"""Test session deletion."""
print("\n=== Test: Logout (Delete) ===")
# Clear and create a session
request("/session/clear")
request("/session/login?user_id=999&name=TestUser")
# Verify it exists
result = request("/session/get?user_id=999")
assert result.get("session") is not None, "Session should exist"
# Logout
result = request("/session/logout?user_id=999")
print(f"Logout: {result}")
assert result.get("status") == "logged_out", f"Logout failed: {result}"
# Verify it's gone
result = request("/session/get?user_id=999")
assert result.get("session") is None, "Session should be deleted"
print("PASSED: Logout deletes session")
return True
def test_multiple_workers_see_updates():
"""Test that updates from one worker are visible to others."""
print("\n=== Test: Cross-Worker Updates ===")
request("/session/clear")
# Create sessions and track which workers handled them
workers = {}
for i in range(10):
result = request(f"/session/login?user_id={i}&name=User{i}")
worker = result.get("session", {}).get("worker_pid")
workers[i] = worker
unique_workers = set(workers.values())
print(f"Sessions created by workers: {unique_workers}")
# Now read all sessions and verify all workers can see all data
result = request("/session/list")
count = result.get("count", 0)
served_by = result.get("served_by_worker")
print(f"List returned {count} sessions, served by worker {served_by}")
assert count == 10, f"Expected 10 sessions, got {count}"
print("PASSED: All workers see all updates")
return True
def main():
"""Run all tests."""
print("=" * 60)
print("Stash Integration Tests")
print("=" * 60)
# Check server is running
try:
result = request("/")
if "error" in result and "Connection refused" in str(result.get("error", "")):
print("ERROR: Server not running. Start with: docker-compose up")
return 1
if not result.get("dirty_enabled"):
print("ERROR: Dirty workers not enabled")
return 1
print("Server running, dirty workers enabled")
except Exception as e:
print(f"ERROR: Cannot connect to server: {e}")
return 1
# Run tests
tests = [
test_stash_shared_state,
test_stash_counter,
test_stash_list_sessions,
test_stash_logout,
test_multiple_workers_see_updates,
]
passed = 0
failed = 0
for test in tests:
try:
if test():
passed += 1
else:
failed += 1
except AssertionError as e:
print(f"FAILED: {e}")
failed += 1
except Exception as e:
print(f"ERROR: {e}")
failed += 1
print("\n" + "=" * 60)
print(f"Results: {passed} passed, {failed} failed")
print("=" * 60)
return 0 if failed == 0 else 1
if __name__ == "__main__":
sys.exit(main())


@ -52,6 +52,11 @@ def app(environ, start_response):
"/fibonacci?n=NUMBER": "Compute fibonacci",
"/prime?n=NUMBER": "Check if prime",
"/stats": "Get dirty worker stats",
"/session/login?user_id=ID&name=NAME": "Login user (stash demo)",
"/session/get?user_id=ID": "Get session (stash demo)",
"/session/list": "List all sessions (stash demo)",
"/session/logout?user_id=ID": "Logout user (stash demo)",
"/session/stats": "Get stash stats (stash demo)",
}
}
@ -139,6 +144,71 @@ def app(environ, start_response):
"http_worker_pid": os.getpid(),
}
# =====================================================================
# Session endpoints (stash demo)
# =====================================================================
elif path == '/session/login':
user_id = query.get('user_id', ['1'])[0]
name = query.get('name', ['Anonymous'])[0]
if client is None:
result = {"error": "Dirty workers not enabled"}
else:
result = client.execute(
"examples.dirty_example.dirty_app:SessionApp",
"login",
user_id=user_id,
user_data={"name": name}
)
elif path == '/session/get':
user_id = query.get('user_id', ['1'])[0]
if client is None:
result = {"error": "Dirty workers not enabled"}
else:
result = client.execute(
"examples.dirty_example.dirty_app:SessionApp",
"get_session",
user_id=user_id
)
elif path == '/session/list':
if client is None:
result = {"error": "Dirty workers not enabled"}
else:
result = client.execute(
"examples.dirty_example.dirty_app:SessionApp",
"list_sessions"
)
elif path == '/session/logout':
user_id = query.get('user_id', ['1'])[0]
if client is None:
result = {"error": "Dirty workers not enabled"}
else:
result = client.execute(
"examples.dirty_example.dirty_app:SessionApp",
"logout",
user_id=user_id
)
elif path == '/session/stats':
if client is None:
result = {"error": "Dirty workers not enabled"}
else:
result = client.execute(
"examples.dirty_example.dirty_app:SessionApp",
"get_stats"
)
elif path == '/session/clear':
if client is None:
result = {"error": "Dirty workers not enabled"}
else:
result = client.execute(
"examples.dirty_example.dirty_app:SessionApp",
"clear_all"
)
else:
start_response('404 Not Found', [('Content-Type', 'application/json')])
return [json.dumps({"error": "Not found"}).encode()]


@ -38,6 +38,16 @@ from .client import (
close_dirty_client_async,
)
# Stash (shared state between workers)
from . import stash
from .stash import (
StashClient,
StashTable,
StashError,
StashTableNotFoundError,
StashKeyNotFoundError,
)
# Internal imports used by gunicorn core (not part of public API)
from .arbiter import DirtyArbiter
@ -58,6 +68,13 @@ __all__ = [
"get_dirty_client_async",
"close_dirty_client",
"close_dirty_client_async",
# Stash (shared state)
"stash",
"StashClient",
"StashTable",
"StashError",
"StashTableNotFoundError",
"StashKeyNotFoundError",
# Internal (used by gunicorn core)
"DirtyArbiter",
"set_dirty_socket_path",


@ -11,6 +11,7 @@ requests from HTTP workers to available dirty workers.
import asyncio
import errno
import fnmatch
import os
import signal
import sys
@ -29,6 +30,17 @@ from .errors import (
from .protocol import (
DirtyProtocol,
make_error_response,
make_response,
STASH_OP_PUT,
STASH_OP_GET,
STASH_OP_DELETE,
STASH_OP_KEYS,
STASH_OP_CLEAR,
STASH_OP_INFO,
STASH_OP_ENSURE,
STASH_OP_DELETE_TABLE,
STASH_OP_TABLES,
STASH_OP_EXISTS,
)
from .worker import DirtyWorker
@ -97,6 +109,10 @@ class DirtyArbiter:
# Queue of app lists from dead workers to respawn with same apps
self._pending_respawns = []
# Stash (shared state) - global tables stored in arbiter
# Maps table_name -> dict of data
self.stash_tables = {}
# Parse app specs on init
self._parse_app_specs()
@ -209,6 +225,9 @@ class DirtyArbiter:
except IOError as e:
self.log.warning("Failed to write PID file: %s", e)
# Set socket path env var for dirty workers (enables stash access)
os.environ['GUNICORN_DIRTY_SOCKET'] = self.socket_path
# Call hook
self.cfg.on_dirty_starting(self)
@ -337,6 +356,7 @@ class DirtyArbiter:
Routes requests to available dirty workers and returns responses.
Supports both regular responses and streaming (chunk-based) responses.
Also handles stash (shared state) operations.
"""
self.log.debug("New client connection from HTTP worker")
@ -347,8 +367,14 @@ class DirtyArbiter:
except asyncio.IncompleteReadError:
break
# Route request to a dirty worker - pass writer for streaming
await self.route_request(message, writer)
msg_type = message.get("type")
# Handle stash operations
if msg_type == DirtyProtocol.MSG_TYPE_STASH:
await self.handle_stash_request(message, writer)
else:
# Route request to a dirty worker - pass writer for streaming
await self.route_request(message, writer)
except Exception as e:
self.log.error("Client connection error: %s", e)
finally:
@ -565,6 +591,127 @@ class DirtyArbiter:
_reader, writer = self.worker_connections.pop(worker_pid)
writer.close()
# -------------------------------------------------------------------------
# Stash (shared state) operations - handled directly in arbiter
# -------------------------------------------------------------------------
async def handle_stash_request(self, message, client_writer):
"""
Handle a stash operation directly in the arbiter.
All stash tables are stored in arbiter memory for simplicity
and fast access.
Args:
message: Stash operation message
client_writer: StreamWriter to send response to client
"""
request_id = message.get("id", "unknown")
op = message.get("op")
table = message.get("table", "")
key = message.get("key")
value = message.get("value")
pattern = message.get("pattern")
try:
result = None
if op == STASH_OP_PUT:
# Auto-create table if needed
if table not in self.stash_tables:
self.stash_tables[table] = {}
self.stash_tables[table][key] = value
result = True
elif op == STASH_OP_GET:
if table not in self.stash_tables:
result = {"error": "key_not_found"}
elif key not in self.stash_tables[table]:
result = {"error": "key_not_found"}
else:
result = self.stash_tables[table][key]
elif op == STASH_OP_DELETE:
if table in self.stash_tables and key in self.stash_tables[table]:
del self.stash_tables[table][key]
result = True
else:
result = False
elif op == STASH_OP_KEYS:
if table not in self.stash_tables:
result = []
else:
all_keys = list(self.stash_tables[table].keys())
if pattern:
all_keys = [k for k in all_keys
if fnmatch.fnmatch(str(k), pattern)]
result = all_keys
elif op == STASH_OP_CLEAR:
if table in self.stash_tables:
self.stash_tables[table].clear()
result = True
elif op == STASH_OP_INFO:
if table not in self.stash_tables:
result = {"error": "table_not_found"}
else:
result = {
"size": len(self.stash_tables[table]),
"table": table,
}
elif op == STASH_OP_ENSURE:
if table not in self.stash_tables:
self.stash_tables[table] = {}
result = True
elif op == STASH_OP_DELETE_TABLE:
if table in self.stash_tables:
del self.stash_tables[table]
result = True
else:
result = False
elif op == STASH_OP_TABLES:
result = list(self.stash_tables.keys())
elif op == STASH_OP_EXISTS:
if table not in self.stash_tables:
result = False
elif key is None:
result = True
else:
result = key in self.stash_tables[table]
else:
error = DirtyError(f"Unknown stash operation: {op}")
response = make_error_response(request_id, error)
await DirtyProtocol.write_message_async(client_writer, response)
return
# Handle error results
if isinstance(result, dict) and "error" in result:
error_type = result["error"]
if error_type == "table_not_found":
error = DirtyError(f"Table not found: {table}")
elif error_type == "key_not_found":
error = DirtyError(f"Key not found: {key}")
else:
error = DirtyError(str(result))
error.error_type = f"Stash{error_type.title().replace('_', '')}Error"
response = make_error_response(request_id, error)
else:
response = make_response(request_id, result)
await DirtyProtocol.write_message_async(client_writer, response)
except Exception as e:
self.log.error("Stash operation error: %s", e)
response = make_error_response(request_id, DirtyError(str(e)))
await DirtyProtocol.write_message_async(client_writer, response)
async def manage_workers(self):
"""Maintain the number of dirty workers."""
if not self.alive:
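Review note on `handle_stash_request`: errors are signalled in-band as a dict with an `"error"` key, and the error type name is derived from that code with a string transformation. The sketch below copies both expressions into standalone helpers (the helper names are made up) to show the derived names. It also shows that, reading the handler as written, a stored value that happens to be a dict containing an `"error"` key would pass the same check and appear to be reported as an error on `get`.

```python
def error_type_name(code):
    """Same string transformation the handler applies to the error code."""
    return f"Stash{code.title().replace('_', '')}Error"


def looks_like_error(result):
    """Same check the handler applies to every operation result."""
    return isinstance(result, dict) and "error" in result


print(error_type_name("key_not_found"))
print(error_type_name("table_not_found"))

# A legitimate user value that collides with the in-band error convention.
stored_value = {"error": "timeout", "attempts": 3}
print(looks_like_error(stored_value))
print(looks_like_error({"name": "Alice"}))
```

A separate sentinel object or an explicit status field in the handler would avoid the collision; until then, callers may want to avoid a top-level `"error"` key in stored dicts.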


@ -659,6 +659,10 @@ def set_dirty_socket_path(path):
global _dirty_socket_path # pylint: disable=global-statement
_dirty_socket_path = path
# Also set the stash socket path (uses same arbiter socket)
from .stash import set_stash_socket_path
set_stash_socket_path(path)
def get_dirty_socket_path():
"""Get the dirty socket path."""


@ -42,6 +42,7 @@ MSG_TYPE_RESPONSE = 0x02
MSG_TYPE_ERROR = 0x03
MSG_TYPE_CHUNK = 0x04
MSG_TYPE_END = 0x05
MSG_TYPE_STASH = 0x10 # Stash operations (shared state between workers)
# Message type names (for backwards compatibility with old API)
MSG_TYPE_REQUEST_STR = "request"
@ -49,6 +50,7 @@ MSG_TYPE_RESPONSE_STR = "response"
MSG_TYPE_ERROR_STR = "error"
MSG_TYPE_CHUNK_STR = "chunk"
MSG_TYPE_END_STR = "end"
MSG_TYPE_STASH_STR = "stash"
# Map int types to string names
MSG_TYPE_TO_STR = {
@ -57,11 +59,24 @@ MSG_TYPE_TO_STR = {
MSG_TYPE_ERROR: MSG_TYPE_ERROR_STR,
MSG_TYPE_CHUNK: MSG_TYPE_CHUNK_STR,
MSG_TYPE_END: MSG_TYPE_END_STR,
MSG_TYPE_STASH: MSG_TYPE_STASH_STR,
}
# Map string names to int types
MSG_TYPE_FROM_STR = {v: k for k, v in MSG_TYPE_TO_STR.items()}
# Stash operation codes
STASH_OP_PUT = 1
STASH_OP_GET = 2
STASH_OP_DELETE = 3
STASH_OP_KEYS = 4
STASH_OP_CLEAR = 5
STASH_OP_INFO = 6
STASH_OP_ENSURE = 7
STASH_OP_DELETE_TABLE = 8
STASH_OP_TABLES = 9
STASH_OP_EXISTS = 10
# Header format: Magic (2) + Version (1) + Type (1) + Length (4) + RequestID (8) = 16
HEADER_FORMAT = ">2sBBIQ"
HEADER_SIZE = struct.calcsize(HEADER_FORMAT)
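Review note on the header layout: the 16-byte total stated in the comment can be checked directly with `struct`. In this sketch the magic bytes `b"XX"` and the version `1` are placeholders rather than gunicorn's real values; only the format string and `MSG_TYPE_STASH = 0x10` are taken from this diff.

```python
import struct

HEADER_FORMAT = ">2sBBIQ"  # magic (2) + version (1) + type (1) + length (4) + request id (8)
MSG_TYPE_STASH = 0x10

size = struct.calcsize(HEADER_FORMAT)

# Placeholder magic and version; payload length 0 and request id 42 are arbitrary.
header = struct.pack(HEADER_FORMAT, b"XX", 1, MSG_TYPE_STASH, 0, 42)
magic, version, msg_type, length, request_id = struct.unpack(HEADER_FORMAT, header)

print(size, len(header), hex(msg_type), request_id)
```

The `>` prefix selects big-endian byte order with no alignment padding, which is why the field sizes add up to exactly 16 bytes.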
@ -82,6 +97,7 @@ class BinaryProtocol:
MSG_TYPE_ERROR = MSG_TYPE_ERROR_STR
MSG_TYPE_CHUNK = MSG_TYPE_CHUNK_STR
MSG_TYPE_END = MSG_TYPE_END_STR
MSG_TYPE_STASH = MSG_TYPE_STASH_STR
@staticmethod
def encode_header(msg_type: int, request_id: int, payload_length: int) -> bytes:
@ -257,6 +273,39 @@ class BinaryProtocol:
header = BinaryProtocol.encode_header(MSG_TYPE_END, request_id, 0)
return header
@staticmethod
def encode_stash(request_id: int, op: int, table: str,
key=None, value=None, pattern=None) -> bytes:
"""
Encode a stash operation message.
Args:
request_id: Unique request identifier (uint64)
op: Stash operation code (STASH_OP_*)
table: Table name
key: Optional key for put/get/delete operations
value: Optional value for put operation
pattern: Optional pattern for keys operation
Returns:
bytes: Complete message (header + payload)
"""
payload_dict = {
"op": op,
"table": table,
}
if key is not None:
payload_dict["key"] = key
if value is not None:
payload_dict["value"] = value
if pattern is not None:
payload_dict["pattern"] = pattern
payload = TLVEncoder.encode(payload_dict)
header = BinaryProtocol.encode_header(MSG_TYPE_STASH, request_id,
len(payload))
return header + payload
@staticmethod
def decode_message(data: bytes) -> tuple:
"""
@ -524,6 +573,15 @@ class BinaryProtocol:
)
elif msg_type == MSG_TYPE_END:
return BinaryProtocol.encode_end(request_id)
elif msg_type == MSG_TYPE_STASH:
return BinaryProtocol.encode_stash(
request_id,
message.get("op"),
message.get("table", ""),
message.get("key"),
message.get("value"),
message.get("pattern")
)
else:
raise DirtyProtocolError(f"Unhandled message type: {msg_type}")
@ -642,3 +700,34 @@ def make_end_message(request_id) -> dict:
"type": DirtyProtocol.MSG_TYPE_END,
"id": request_id,
}
def make_stash_message(request_id, op: int, table: str,
key=None, value=None, pattern=None) -> dict:
"""
Build a stash operation message dict.
Args:
request_id: Unique request identifier (int or str)
op: Stash operation code (STASH_OP_*)
table: Table name
key: Optional key for put/get/delete operations
value: Optional value for put operation
pattern: Optional pattern for keys operation
Returns:
dict: Stash message dict
"""
msg = {
"type": DirtyProtocol.MSG_TYPE_STASH,
"id": request_id,
"op": op,
"table": table,
}
if key is not None:
msg["key"] = key
if value is not None:
msg["value"] = value
if pattern is not None:
msg["pattern"] = pattern
return msg

gunicorn/dirty/stash.py Normal file

@ -0,0 +1,503 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.

"""
Stash - Global Shared State for Dirty Workers

Provides simple key-value tables stored in the arbiter process.
All workers can read and write to the same tables.

Usage::

    from gunicorn.dirty import stash

    # Basic operations - table is auto-created on first access
    stash.put("sessions", "user:1", {"name": "Alice", "role": "admin"})
    user = stash.get("sessions", "user:1")
    stash.delete("sessions", "user:1")

    # Dict-like interface
    sessions = stash.table("sessions")
    sessions["user:1"] = {"name": "Alice"}
    user = sessions["user:1"]
    del sessions["user:1"]

    # Query operations
    keys = stash.keys("sessions")
    keys = stash.keys("sessions", pattern="user:*")

    # Table management
    stash.ensure("cache")           # Explicit creation (idempotent)
    stash.clear("sessions")         # Delete all entries
    stash.delete_table("sessions")  # Delete the table itself
    tables = stash.tables()         # List all tables

Declarative usage in DirtyApp::

    class MyApp(DirtyApp):
        stashes = ["sessions", "cache"]  # Auto-created on arbiter start

        def __call__(self, action, *args, **kwargs):
            # Tables are ready to use
            stash.put("sessions", "key", "value")

Note: Tables are stored in the arbiter process and are ephemeral.
If the arbiter restarts, all data is lost.
"""

import threading
import uuid

from .errors import DirtyError
from .protocol import (
    DirtyProtocol,
    STASH_OP_PUT,
    STASH_OP_GET,
    STASH_OP_DELETE,
    STASH_OP_KEYS,
    STASH_OP_CLEAR,
    STASH_OP_INFO,
    STASH_OP_ENSURE,
    STASH_OP_DELETE_TABLE,
    STASH_OP_TABLES,
    STASH_OP_EXISTS,
    make_stash_message,
)


class StashError(DirtyError):
    """Base exception for stash operations."""


class StashTableNotFoundError(StashError):
    """Raised when a table does not exist."""

    def __init__(self, table_name):
        self.table_name = table_name
        super().__init__(f"Stash table not found: {table_name}")


class StashKeyNotFoundError(StashError):
    """Raised when a key does not exist in a table."""

    def __init__(self, table_name, key):
        self.table_name = table_name
        self.key = key
        super().__init__(f"Key not found in {table_name}: {key}")


class StashClient:
    """
    Client for stash operations.

    Communicates with the arbiter which stores all tables in memory.
    """

    def __init__(self, socket_path, timeout=30.0):
        """
        Initialize the stash client.

        Args:
            socket_path: Path to the dirty arbiter's Unix socket
            timeout: Default timeout for operations in seconds
        """
        self.socket_path = socket_path
        self.timeout = timeout
        self._sock = None
        self._lock = threading.Lock()

    def _get_request_id(self):
        """Generate a unique request ID."""
        return str(uuid.uuid4())

    def _connect(self):
        """Establish connection to arbiter."""
        import socket

        if self._sock is not None:
            return

        try:
            self._sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            self._sock.settimeout(self.timeout)
            self._sock.connect(self.socket_path)
        except (socket.error, OSError) as e:
            self._sock = None
            raise StashError(f"Failed to connect to arbiter: {e}") from e

    def _close(self):
        """Close the connection."""
        if self._sock is not None:
            try:
                self._sock.close()
            except Exception:
                pass
            self._sock = None

    def _execute(self, op, table, key=None, value=None, pattern=None):
        """
        Execute a stash operation.

        Args:
            op: Operation code (STASH_OP_*)
            table: Table name
            key: Optional key
            value: Optional value
            pattern: Optional pattern for keys operation

        Returns:
            Result from the operation
        """
        with self._lock:
            if self._sock is None:
                self._connect()

            request_id = self._get_request_id()
            message = make_stash_message(
                request_id, op, table,
                key=key, value=value, pattern=pattern
            )

            try:
                DirtyProtocol.write_message(self._sock, message)
                response = DirtyProtocol.read_message(self._sock)

                msg_type = response.get("type")
                if msg_type == DirtyProtocol.MSG_TYPE_RESPONSE:
                    return response.get("result")
                elif msg_type == DirtyProtocol.MSG_TYPE_ERROR:
                    error_info = response.get("error", {})
                    error_type = error_info.get("error_type", "StashError")
                    error_msg = error_info.get("message", "Unknown error")

                    if error_type == "StashTableNotFoundError":
                        raise StashTableNotFoundError(table)
                    if error_type == "StashKeyNotFoundError":
                        raise StashKeyNotFoundError(table, key)
                    raise StashError(error_msg)
                else:
                    raise StashError(f"Unexpected response type: {msg_type}")
            except Exception as e:
                # Any failure, including an error response from the arbiter,
                # drops the connection; the next call reconnects.
                self._close()
                if isinstance(e, StashError):
                    raise
                raise StashError(f"Stash operation failed: {e}") from e

    # -------------------------------------------------------------------------
    # Public API
    # -------------------------------------------------------------------------

    def put(self, table, key, value):
        """
        Store a value in a table.

        The table is automatically created if it doesn't exist.

        Args:
            table: Table name
            key: Key to store under
            value: Value to store (must be serializable)
        """
        self._execute(STASH_OP_PUT, table, key=key, value=value)

    def get(self, table, key, default=None):
        """
        Retrieve a value from a table.

        Args:
            table: Table name
            key: Key to retrieve
            default: Default value if key not found

        Returns:
            The stored value, or default if not found
        """
        try:
            return self._execute(STASH_OP_GET, table, key=key)
        except StashKeyNotFoundError:
            return default

    def delete(self, table, key):
        """
        Delete a key from a table.

        Args:
            table: Table name
            key: Key to delete

        Returns:
            True if key was deleted, False if it didn't exist
        """
        return self._execute(STASH_OP_DELETE, table, key=key)

    def keys(self, table, pattern=None):
        """
        Get all keys in a table, optionally filtered by pattern.

        Args:
            table: Table name
            pattern: Optional glob pattern (e.g., "user:*")

        Returns:
            List of keys
        """
        return self._execute(STASH_OP_KEYS, table, pattern=pattern)

    def clear(self, table):
        """
        Delete all entries in a table.

        Args:
            table: Table name
        """
        self._execute(STASH_OP_CLEAR, table)

    def info(self, table):
        """
        Get information about a table.

        Args:
            table: Table name

        Returns:
            Dict with table info (e.g. ``size``, which
            ``StashTable.__len__`` reads)
        """
        return self._execute(STASH_OP_INFO, table)

    def ensure(self, table):
        """
        Ensure a table exists, creating it if necessary.

        This is idempotent - calling it multiple times is safe.

        Args:
            table: Table name
        """
        self._execute(STASH_OP_ENSURE, table)

    def exists(self, table, key=None):
        """
        Check if a table or key exists.

        Args:
            table: Table name
            key: Optional key to check within the table

        Returns:
            True if exists, False otherwise
        """
        return self._execute(STASH_OP_EXISTS, table, key=key)

    def delete_table(self, table):
        """
        Delete an entire table.

        Args:
            table: Table name
        """
        self._execute(STASH_OP_DELETE_TABLE, table)

    def tables(self):
        """
        List all tables.

        Returns:
            List of table names
        """
        return self._execute(STASH_OP_TABLES, "")

    def table(self, name):
        """
        Get a dict-like interface to a table.

        Args:
            name: Table name

        Returns:
            StashTable instance
        """
        return StashTable(self, name)

    def close(self):
        """Close the client connection."""
        with self._lock:
            self._close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()


class StashTable:
    """
    Dict-like interface to a stash table.

    Example::

        sessions = stash.table("sessions")
        sessions["user:1"] = {"name": "Alice"}
        user = sessions["user:1"]
        del sessions["user:1"]

        # Iteration
        for key in sessions:
            print(key, sessions[key])
    """

    def __init__(self, client, name):
        self._client = client
        self._name = name

    @property
    def name(self):
        """Table name."""
        return self._name

    def __getitem__(self, key):
        result = self._client.get(self._name, key)
        if result is None:
            # Check if key actually exists with None value
            if not self._client.exists(self._name, key):
                raise KeyError(key)
        return result

    def __setitem__(self, key, value):
        self._client.put(self._name, key, value)

    def __delitem__(self, key):
        if not self._client.delete(self._name, key):
            raise KeyError(key)

    def __contains__(self, key):
        return self._client.exists(self._name, key)

    def __iter__(self):
        return iter(self._client.keys(self._name))

    def __len__(self):
        info = self._client.info(self._name)
        return info.get("size", 0)

    def get(self, key, default=None):
        """Get value with default."""
        return self._client.get(self._name, key, default)

    def keys(self, pattern=None):
        """Get all keys, optionally filtered by pattern."""
        return self._client.keys(self._name, pattern=pattern)

    def clear(self):
        """Delete all entries."""
        self._client.clear(self._name)

    def items(self):
        """Iterate over (key, value) pairs."""
        for key in self._client.keys(self._name):
            yield key, self._client.get(self._name, key)

    def values(self):
        """Iterate over values."""
        for key in self._client.keys(self._name):
            yield self._client.get(self._name, key)


# =============================================================================
# Global stash instance (module-level API)
# =============================================================================

# Thread-local storage for stash clients
_thread_local = threading.local()

# Global socket path
_stash_socket_path = None


def set_stash_socket_path(path):
    """Set the global stash socket path (called during initialization)."""
    global _stash_socket_path  # pylint: disable=global-statement
    _stash_socket_path = path


def get_stash_socket_path():
    """Get the stash socket path, falling back to GUNICORN_DIRTY_SOCKET."""
    import os

    if _stash_socket_path is None:
        # Check environment variable
        path = os.environ.get('GUNICORN_DIRTY_SOCKET')
        if path:
            return path
        raise StashError(
            "Stash socket path not configured. "
            "Make sure dirty_workers > 0 and dirty_apps are configured."
        )
    return _stash_socket_path


def _get_client():
    """Get or create a thread-local stash client."""
    client = getattr(_thread_local, 'stash_client', None)
    if client is None:
        socket_path = get_stash_socket_path()
        client = StashClient(socket_path)
        _thread_local.stash_client = client
    return client


# Module-level functions that use the thread-local client

def put(table, key, value):
    """Store a value in a table."""
    _get_client().put(table, key, value)


def get(table, key, default=None):
    """Retrieve a value from a table."""
    return _get_client().get(table, key, default)


def delete(table, key):
    """Delete a key from a table."""
    return _get_client().delete(table, key)


def keys(table, pattern=None):
    """Get all keys in a table, optionally filtered by a glob pattern."""
    return _get_client().keys(table, pattern)


def clear(table):
    """Delete all entries in a table."""
    _get_client().clear(table)


def info(table):
    """Get information about a table."""
    return _get_client().info(table)


def ensure(table):
    """Ensure a table exists."""
    _get_client().ensure(table)


def exists(table, key=None):
    """Check if a table or key exists."""
    return _get_client().exists(table, key)


def delete_table(table):
    """Delete an entire table."""
    _get_client().delete_table(table)


def tables():
    """List all tables."""
    return _get_client().tables()


def table(name):
    """Get a dict-like interface to a table."""
    return _get_client().table(name)
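
The module-level functions above all go through a per-thread client. The following standalone sketch shows that pattern in isolation, assuming a dummy `FakeClient` class in place of `StashClient` (no socket is opened) and a helper named `get_client` that mirrors `_get_client`; both names are made up for the example.

```python
import threading

_thread_local = threading.local()


class FakeClient:
    """Hypothetical stand-in for StashClient; holds no socket."""


def get_client():
    """Return this thread's client, creating it on first use."""
    client = getattr(_thread_local, "stash_client", None)
    if client is None:
        client = FakeClient()
        _thread_local.stash_client = client
    return client


def collect(results, index):
    # Two calls on the same thread should return the same object.
    first = get_client()
    second = get_client()
    results[index] = (first, first is second)


results = [None, None]
threads = [threading.Thread(target=collect, args=(results, i)) for i in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()

same_within_thread = all(reused for _, reused in results)
distinct_across_threads = results[0][0] is not results[1][0]
print(same_within_thread, distinct_across_threads)  # True True
```

With the real `StashClient`, this means each thread that touches the stash would hold its own connection to the arbiter socket, and the per-client lock only serializes calls made through an explicitly shared `StashClient` instance.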


@ -168,10 +168,10 @@ class DirtyWorker:
        self.load_apps()

        # Call hook
        self.pid = os.getpid()
        self.cfg.dirty_worker_init(self)

        # Enter main run loop
        self.pid = os.getpid()
        self.booted = True
        self.run()

206
tests/test_dirty_stash.py Normal file

@ -0,0 +1,206 @@
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.

"""Tests for dirty stash (shared state) functionality."""

import pytest

from gunicorn.dirty.stash import (
    StashClient,
    StashTable,
    StashError,
    StashTableNotFoundError,
    StashKeyNotFoundError,
)
from gunicorn.dirty.protocol import (
    BinaryProtocol,
    DirtyProtocol,
    MSG_TYPE_STASH,
    STASH_OP_PUT,
    STASH_OP_GET,
    STASH_OP_DELETE,
    STASH_OP_KEYS,
    STASH_OP_CLEAR,
    STASH_OP_INFO,
    STASH_OP_ENSURE,
    STASH_OP_DELETE_TABLE,
    STASH_OP_TABLES,
    STASH_OP_EXISTS,
    make_stash_message,
)


class TestStashProtocol:
    """Test stash protocol encoding."""

    def test_make_stash_message_basic(self):
        """Test basic stash message creation."""
        msg = make_stash_message(123, STASH_OP_PUT, "test_table")
        assert msg["type"] == "stash"
        assert msg["id"] == 123
        assert msg["op"] == STASH_OP_PUT
        assert msg["table"] == "test_table"

    def test_make_stash_message_with_key_value(self):
        """Test stash message with key and value."""
        msg = make_stash_message(
            456, STASH_OP_PUT, "sessions",
            key="user:1", value={"name": "Alice"}
        )
        assert msg["key"] == "user:1"
        assert msg["value"] == {"name": "Alice"}

    def test_make_stash_message_with_pattern(self):
        """Test stash message with pattern."""
        msg = make_stash_message(
            789, STASH_OP_KEYS, "sessions",
            pattern="user:*"
        )
        assert msg["pattern"] == "user:*"

    def test_encode_stash_message(self):
        """Test binary encoding of stash message."""
        msg = make_stash_message(
            123, STASH_OP_PUT, "test",
            key="k", value="v"
        )
        encoded = BinaryProtocol._encode_from_dict(msg)
        assert isinstance(encoded, bytes)
        assert len(encoded) > 16  # Header + payload

    def test_stash_message_roundtrip(self):
        """Test encode/decode roundtrip for stash message."""
        original = make_stash_message(
            12345, STASH_OP_GET, "cache",
            key="my_key"
        )
        encoded = BinaryProtocol._encode_from_dict(original)
        msg_type, request_id, payload = BinaryProtocol.decode_message(encoded)

        assert msg_type == "stash"
        assert payload["op"] == STASH_OP_GET
        assert payload["table"] == "cache"
        assert payload["key"] == "my_key"

    def test_stash_operations_have_unique_codes(self):
        """Test that all stash operations have unique codes."""
        ops = [
            STASH_OP_PUT,
            STASH_OP_GET,
            STASH_OP_DELETE,
            STASH_OP_KEYS,
            STASH_OP_CLEAR,
            STASH_OP_INFO,
            STASH_OP_ENSURE,
            STASH_OP_DELETE_TABLE,
            STASH_OP_TABLES,
            STASH_OP_EXISTS,
        ]
        assert len(ops) == len(set(ops))


class TestStashTable:
    """Test StashTable dict-like interface."""

    def test_stash_table_name(self):
        """Test StashTable name property."""
        # Create a mock client
        class MockClient:
            pass

        table = StashTable(MockClient(), "test_table")
        assert table.name == "test_table"


class TestStashErrors:
    """Test stash error classes."""

    def test_stash_error_base(self):
        """Test base StashError."""
        error = StashError("test error")
        assert str(error) == "test error"
        assert isinstance(error, Exception)

    def test_stash_table_not_found_error(self):
        """Test StashTableNotFoundError."""
        error = StashTableNotFoundError("my_table")
        assert error.table_name == "my_table"
        assert "my_table" in str(error)

    def test_stash_key_not_found_error(self):
        """Test StashKeyNotFoundError."""
        error = StashKeyNotFoundError("my_table", "my_key")
        assert error.table_name == "my_table"
        assert error.key == "my_key"
        assert "my_key" in str(error)


class TestStashProtocolConstants:
    """Test protocol constants for stash."""

    def test_msg_type_stash_exists(self):
        """Test MSG_TYPE_STASH has the expected binary type code."""
        assert MSG_TYPE_STASH == 0x10

    def test_dirty_protocol_exports_stash_type(self):
        """Test DirtyProtocol exports the stash type as a string."""
        assert DirtyProtocol.MSG_TYPE_STASH == "stash"

    def test_stash_op_codes(self):
        """Test stash operation codes are integers."""
        assert isinstance(STASH_OP_PUT, int)
        assert isinstance(STASH_OP_GET, int)
        assert isinstance(STASH_OP_DELETE, int)
        assert isinstance(STASH_OP_KEYS, int)
        assert isinstance(STASH_OP_CLEAR, int)
        assert isinstance(STASH_OP_INFO, int)
        assert isinstance(STASH_OP_ENSURE, int)
        assert isinstance(STASH_OP_DELETE_TABLE, int)
        assert isinstance(STASH_OP_TABLES, int)
        assert isinstance(STASH_OP_EXISTS, int)


class TestStashEncodingEdgeCases:
    """Test edge cases in stash encoding."""

    def test_encode_empty_table_name(self):
        """Test encoding with empty table name."""
        msg = make_stash_message(1, STASH_OP_TABLES, "")
        encoded = BinaryProtocol._encode_from_dict(msg)
        assert isinstance(encoded, bytes)

    def test_encode_unicode_table_name(self):
        """Test encoding with unicode table name."""
        msg = make_stash_message(1, STASH_OP_PUT, "テスト", key="k", value="v")
        encoded = BinaryProtocol._encode_from_dict(msg)
        _, _, payload = BinaryProtocol.decode_message(encoded)
        assert payload["table"] == "テスト"

    def test_encode_complex_value(self):
        """Test encoding with complex nested value."""
        value = {
            "name": "test",
            "count": 42,
            "nested": {"a": [1, 2, 3]},
            "data": b"binary data",
        }
        msg = make_stash_message(1, STASH_OP_PUT, "test", key="k", value=value)
        encoded = BinaryProtocol._encode_from_dict(msg)
        _, _, payload = BinaryProtocol.decode_message(encoded)
        assert payload["value"] == value

    def test_encode_none_key(self):
        """Test that a None key is omitted from the message (table-level ops)."""
        msg = make_stash_message(1, STASH_OP_TABLES, "")
        assert "key" not in msg

    def test_encode_special_characters_in_pattern(self):
        """Test encoding with special characters in pattern."""
        msg = make_stash_message(
            1, STASH_OP_KEYS, "test",
            pattern="user:*:session:?"
        )
        encoded = BinaryProtocol._encode_from_dict(msg)
        _, _, payload = BinaryProtocol.decode_message(encoded)
        assert payload["pattern"] == "user:*:session:?"
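
`TestStashTable` above only covers the `name` property. As a possible follow-up, the dict-like behaviour could be exercised without an arbiter by using an in-memory fake client. The sketch below is self-contained and makes two assumptions: `InMemoryClient` is a hypothetical stand-in that is not part of this commit, and `StashTable` here is a trimmed copy of the class from `gunicorn/dirty/stash.py` so the example runs without gunicorn installed.

```python
class InMemoryClient:
    """Hypothetical in-memory stand-in for StashClient (no arbiter, no socket)."""

    def __init__(self):
        self._tables = {}

    def put(self, table, key, value):
        self._tables.setdefault(table, {})[key] = value

    def get(self, table, key, default=None):
        return self._tables.get(table, {}).get(key, default)

    def delete(self, table, key):
        entries = self._tables.get(table, {})
        if key in entries:
            del entries[key]
            return True
        return False

    def exists(self, table, key=None):
        if key is None:
            return table in self._tables
        return key in self._tables.get(table, {})


class StashTable:
    """Trimmed copy of the dict-like wrapper added in this commit."""

    def __init__(self, client, name):
        self._client = client
        self._name = name

    def __getitem__(self, key):
        result = self._client.get(self._name, key)
        # A stored None is a valid value, so confirm the key is really absent.
        if result is None and not self._client.exists(self._name, key):
            raise KeyError(key)
        return result

    def __setitem__(self, key, value):
        self._client.put(self._name, key, value)

    def __delitem__(self, key):
        if not self._client.delete(self._name, key):
            raise KeyError(key)

    def __contains__(self, key):
        return self._client.exists(self._name, key)


def raises_key_error(func):
    """Return True if calling func() raises KeyError."""
    try:
        func()
    except KeyError:
        return True
    return False


def test_none_value_is_not_a_missing_key():
    table = StashTable(InMemoryClient(), "sessions")
    table["user:1"] = None
    assert table["user:1"] is None
    assert "user:1" in table


def test_missing_key_raises_key_error():
    table = StashTable(InMemoryClient(), "sessions")
    assert raises_key_error(lambda: table["nope"])


def test_delete_missing_key_raises_key_error():
    table = StashTable(InMemoryClient(), "sessions")
    table["user:1"] = {"name": "Alice"}
    del table["user:1"]
    assert "user:1" not in table
    assert raises_key_error(lambda: table.__delitem__("user:1"))


test_none_value_is_not_a_missing_key()
test_missing_key_raises_key_error()
test_delete_missing_key_raises_key_error()
```

The first test is the interesting one: because `get` returns `None` both for a missing key and for a stored `None`, `__getitem__` needs the extra `exists` call to tell the two apart.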