mirror of
https://github.com/frappe/gunicorn.git
synced 2026-07-03 19:21:29 +08:00
feat(dirty): add stash - global shared state between workers (#3503)
* feat(dirty): add stash - global shared state between workers

Add a simple key-value store (stash) that allows dirty workers to share
state through the arbiter. Tables are stored directly in arbiter memory
for fast access and simplicity.

Features:
- Auto-create tables on first access
- Dict-like interface via stash.table()
- Pattern matching for keys (glob patterns)
- Module-level API: stash.put(), stash.get(), stash.delete(), etc.

Usage:
    from gunicorn.dirty import stash

    stash.put("sessions", "user:1", {"name": "Alice"})
    user = stash.get("sessions", "user:1")

    # Or dict-like
    sessions = stash.table("sessions")
    sessions["user:1"] = {"name": "Alice"}

New files:
- gunicorn/dirty/stash.py - Client API and StashTable class
- Protocol additions for MSG_TYPE_STASH and STASH_OP_* codes

Note: Tables are ephemeral - lost if arbiter restarts.

* test(dirty): add tests for stash protocol and encoding

Test coverage for:
- Stash message creation and encoding
- Protocol constants (MSG_TYPE_STASH, STASH_OP_*)
- Error classes (StashError, StashTableNotFoundError, StashKeyNotFoundError)
- StashTable dict-like interface
- Edge cases: unicode, complex values, special patterns

* example(dirty): add stash usage example and integration tests

- Add SessionApp to dirty_app.py demonstrating stash usage
- Add /session/* endpoints to wsgi_app.py
- Add test_stash_integration.py with Docker tests
- Update docker-compose.yml with stash-test service
- Fix: Set GUNICORN_DIRTY_SOCKET in dirty arbiter for worker access

* docs(dirty): add stash documentation
parent 236c9371d0
commit 709a6ad159
@ -558,6 +558,259 @@ def generate_view(request):
4. **Keep chunks small** - Smaller chunks provide better perceived latency
5. **Handle client disconnection** - Streams continue even if client disconnects; design accordingly

## Stash (Shared State via Message Passing)

Stash provides shared state between dirty workers, similar to Erlang's ETS (Erlang Term Storage). Workers remain fully isolated - all state access goes through message passing to the arbiter.

### Architecture

```
                   +------------------+
                   |  Dirty Arbiter   |
                   |                  |
                   |  stash_tables:   |
                   |    sessions: {}  |
                   |    cache: {}     |
                   +--------+---------+
                            |
            Unix Socket IPC (message passing)
                            |
        +-------------------+-------------------+
        |                   |                   |
  +-----v-----+       +-----v-----+       +-----v-----+
  | Worker 1  |       | Worker 2  |       | Worker 3  |
  |           |       |           |       |           |
  | (isolated)|       | (isolated)|       | (isolated)|
  +-----------+       +-----------+       +-----------+

Workers have NO shared memory.
All stash operations are IPC messages to arbiter.
```

### How It Works

1. Worker calls `stash.put("sessions", "user:1", data)`
2. Worker sends message to arbiter via Unix socket
3. Arbiter stores data in its memory (`self.stash_tables`)
4. Arbiter sends response back to worker
5. Worker receives confirmation

This is **not** shared memory - workers remain fully isolated. The arbiter acts as a centralized store that workers communicate with via message passing. This matches Erlang's model where ETS tables are owned by a process.
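
The round trip above can be modeled without any sockets. The following is a minimal sketch, not gunicorn's implementation: `ToyArbiter` and its string op names are hypothetical stand-ins for the arbiter and the `STASH_OP_*` codes, and a plain method call takes the place of the Unix socket message.

```python
class ToyArbiter:
    """Hypothetical stand-in for the arbiter: one process owns every table."""

    def __init__(self):
        self.stash_tables = {}  # table name -> dict of key/value pairs

    def handle(self, message):
        """Apply one request message and return a response message."""
        op = message["op"]
        table = message["table"]
        key = message.get("key")
        if op == "put":
            # Tables are created on first write.
            self.stash_tables.setdefault(table, {})[key] = message["value"]
            return {"id": message["id"], "result": True}
        if op == "get":
            data = self.stash_tables.get(table, {})
            if key not in data:
                return {"id": message["id"], "error": "key_not_found"}
            return {"id": message["id"], "result": data[key]}
        return {"id": message["id"], "error": f"unknown op: {op}"}


arbiter = ToyArbiter()

# "Worker 1" stores a value and "worker 2" reads it back. Neither touches
# the other's memory: both only exchange messages with the arbiter.
ack = arbiter.handle({"id": 1, "op": "put", "table": "sessions",
                      "key": "user:1", "value": {"name": "Alice"}})
reply = arbiter.handle({"id": 2, "op": "get", "table": "sessions",
                        "key": "user:1"})
missing = arbiter.handle({"id": 3, "op": "get", "table": "sessions",
                          "key": "user:2"})
```

Because every operation is applied by a single owner, workers never need locks of their own; the cost is one request/response pair per operation.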

### Basic Usage

```python
from gunicorn.dirty import stash

# Store a value (table auto-created)
# This sends a message to arbiter, which stores it
stash.put("sessions", "user:123", {"name": "Alice", "role": "admin"})

# Retrieve a value
# This sends a request to arbiter, which returns the value
user = stash.get("sessions", "user:123")

# Delete a key
stash.delete("sessions", "user:123")

# Check existence
if stash.exists("sessions", "user:123"):
    print("Session exists")

# List keys with pattern matching
keys = stash.keys("sessions", pattern="user:*")
```
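
According to the arbiter changes in this commit, key patterns are evaluated with `fnmatch.fnmatch` on the string form of each key. A small standalone sketch of those glob semantics follows; the `match_keys` helper and the sample keys are illustrative only and not part of the stash API.

```python
import fnmatch

sample_keys = ["user:1", "user:42", "admin:1", "cache:model:v1"]


def match_keys(all_keys, pattern=None):
    """Filter keys with shell-style globbing; no pattern means all keys."""
    if not pattern:
        return list(all_keys)
    return [k for k in all_keys if fnmatch.fnmatch(str(k), pattern)]


user_keys = match_keys(sample_keys, "user:*")      # "*" matches any run of characters
single_digit = match_keys(sample_keys, "user:?")   # "?" matches exactly one character
nested = match_keys(sample_keys, "cache:*")        # "*" also spans further ":" separators
everything = match_keys(sample_keys)
```

Note that `fnmatch.fnmatch` normalizes case according to the operating system, so matching is case-sensitive on POSIX systems.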

### Dict-like Interface

For more Pythonic access, use the table interface:

```python
from gunicorn.dirty import stash

# Get a table reference
sessions = stash.table("sessions")

# Dict-like operations (each is an IPC message)
sessions["user:123"] = {"name": "Alice"}
user = sessions["user:123"]
del sessions["user:123"]

# Iteration
for key in sessions:
    print(key, sessions[key])

# Length
count = len(sessions)
```
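
To make the "each is an IPC message" remark concrete, here is a sketch of how a dict-like view could sit on top of put/get/delete/keys calls. It is an assumption-laden model, not the actual `StashTable`: `FakeStash` is an in-memory stand-in that counts simulated round-trips, and `TableView` is a hypothetical wrapper built on `collections.abc.MutableMapping`.

```python
from collections.abc import MutableMapping

_MISSING = object()


class FakeStash:
    """In-memory stand-in for the stash functions; counts simulated round-trips."""

    def __init__(self):
        self._tables = {}
        self.calls = 0

    def put(self, table, key, value):
        self.calls += 1
        self._tables.setdefault(table, {})[key] = value

    def get(self, table, key, default=None):
        self.calls += 1
        return self._tables.get(table, {}).get(key, default)

    def delete(self, table, key):
        self.calls += 1
        return self._tables.get(table, {}).pop(key, _MISSING) is not _MISSING

    def keys(self, table):
        self.calls += 1
        return list(self._tables.get(table, {}))


class TableView(MutableMapping):
    """Hypothetical dict-like view: every operation delegates to the backend."""

    def __init__(self, backend, name):
        self._backend = backend
        self._name = name

    def __getitem__(self, key):
        value = self._backend.get(self._name, key, _MISSING)
        if value is _MISSING:
            raise KeyError(key)
        return value

    def __setitem__(self, key, value):
        self._backend.put(self._name, key, value)

    def __delitem__(self, key):
        if not self._backend.delete(self._name, key):
            raise KeyError(key)

    def __iter__(self):
        return iter(self._backend.keys(self._name))

    def __len__(self):
        return len(self._backend.keys(self._name))


backend = FakeStash()
sessions = TableView(backend, "sessions")
sessions["user:1"] = {"name": "Alice"}
sessions["user:2"] = {"name": "Bob"}

# Iterating and reading each value costs one keys() call plus one get() per key.
calls_before = backend.calls
pairs = [(key, sessions[key]) for key in sessions]
calls_for_loop = backend.calls - calls_before
```

In this model, looping over a table with N keys costs N + 1 round-trips, which is worth keeping in mind before iterating over large tables.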

### Table Management

```python
from gunicorn.dirty import stash

# Explicit table creation (idempotent)
stash.ensure("cache")

# Get table info
info = stash.info("sessions")
print(f"Table has {info['size']} entries")

# Clear all entries in a table
stash.clear("sessions")

# Delete entire table
stash.delete_table("sessions")

# List all tables
tables = stash.tables()
```

### Using Stash in DirtyApp

Declare tables your app uses with the `stashes` class attribute:

```python
import time

from gunicorn.dirty import DirtyApp, stash

class SessionApp(DirtyApp):
    # Tables declared here are auto-created on startup
    stashes = ["sessions", "counters"]

    def init(self):
        # Initialize counter if needed
        if not stash.exists("counters", "requests"):
            stash.put("counters", "requests", 0)

    def login(self, user_id, user_data):
        """Store session - any worker can read it via arbiter."""
        stash.put("sessions", f"user:{user_id}", {
            "data": user_data,
            "logged_in_at": time.time(),
        })
        self._increment_counter()
        return {"status": "ok"}

    def get_session(self, user_id):
        """Get session - request goes to arbiter."""
        return stash.get("sessions", f"user:{user_id}")

    def _increment_counter(self):
        """Increment global counter via arbiter."""
        current = stash.get("counters", "requests", 0)
        stash.put("counters", "requests", current + 1)

    def close(self):
        pass
```

### API Reference

| Function | Description |
|----------|-------------|
| `stash.put(table, key, value)` | Store a value (table auto-created) |
| `stash.get(table, key, default=None)` | Retrieve a value |
| `stash.delete(table, key)` | Delete a key, returns True if deleted |
| `stash.exists(table, key=None)` | Check if table/key exists |
| `stash.keys(table, pattern=None)` | List keys, optional glob pattern |
| `stash.clear(table)` | Delete all entries in table |
| `stash.info(table)` | Get table info (size, etc.) |
| `stash.ensure(table)` | Create table if not exists |
| `stash.delete_table(table)` | Delete entire table |
| `stash.tables()` | List all table names |
| `stash.table(name)` | Get dict-like interface |

### Patterns and Use Cases

**Session Storage:**
```python
# Store session on login (worker 1)
stash.put("sessions", f"user:{user_id}", session_data)

# Check session on request (may be worker 2)
session = stash.get("sessions", f"user:{user_id}")
if session is None:
    raise AuthError("Not logged in")
```

**Shared Cache:**
```python
def get_expensive_result(key):
    # Check cache first (via arbiter)
    cached = stash.get("cache", key)
    if cached is not None:
        return cached

    # Compute and cache
    result = expensive_computation()
    stash.put("cache", key, result)
    return result
```

**Global Counters:**
```python
def increment_counter(name):
    # Note: not atomic - two workers could read same value
    current = stash.get("counters", name, 0)
    stash.put("counters", name, current + 1)
    return current + 1
```
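
The "not atomic" caveat can be shown deterministically. This is a sketch that replays one unlucky interleaving by hand; the `store` dict and the `read`/`write` helpers are hypothetical stand-ins for the `counters` table and for `stash.get`/`stash.put`.

```python
store = {"requests": 0}  # stands in for the "counters" table in the arbiter


def read(name):
    return store.get(name, 0)


def write(name, value):
    store[name] = value


# Unlucky interleaving: both workers read before either one writes.
a = read("requests")        # worker A sees 0
b = read("requests")        # worker B also sees 0
write("requests", a + 1)    # A stores 1
write("requests", b + 1)    # B overwrites with 1, so A's increment is lost
lost_update_total = store["requests"]

# Same two increments, but each read-modify-write completes before the next.
store["requests"] = 0
for _ in range(2):
    write("requests", read("requests") + 1)
serialized_total = store["requests"]
```

Each individual get or put is applied by the arbiter in turn, but the pair is not, so counters maintained this way are best treated as approximate under concurrent load.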

**Feature Flags:**
```python
# Set flag (from admin endpoint)
stash.put("flags", "new_feature", True)

# Check flag (from any worker)
if stash.get("flags", "new_feature", False):
    enable_new_feature()
```

### Error Handling

```python
from gunicorn.dirty import stash
from gunicorn.dirty.stash import (
    StashError,
    StashTableNotFoundError,
    StashKeyNotFoundError,
)

try:
    info = stash.info("nonexistent")
except StashTableNotFoundError as e:
    print(f"Table not found: {e.table_name}")

# Using get() with default avoids StashKeyNotFoundError
value = stash.get("table", "key", default="fallback")
```
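
The error class names line up with the internal error codes used by the arbiter handler in this commit, which builds the type name from codes such as `table_not_found`. The sketch below reproduces only that string transformation; `stash_error_type` is a hypothetical helper name, and how the client turns the name back into an exception class is not shown in this section.

```python
def stash_error_type(code):
    """Build an error type name from a snake_case code, as the arbiter hunk does."""
    # "table_not_found" -> "Table_Not_Found" -> "TableNotFound"
    return f"Stash{code.title().replace('_', '')}Error"


table_error = stash_error_type("table_not_found")
key_error = stash_error_type("key_not_found")
```

Two details of the handler as shown may be worth knowing: a `get` on a table that does not exist is reported as `key_not_found` rather than `table_not_found`, and a result that is a dict containing an `"error"` key appears to be treated as an error response, so storing such a dict as a value is probably best avoided.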

### Best Practices

1. **Use descriptive table names** - `user_sessions`, `ml_cache`, not `data`
2. **Use key prefixes** - `user:123`, `cache:model:v1` for organization
3. **Handle missing data** - Always provide defaults or check existence
4. **Don't store large data** - Each access is an IPC round-trip
5. **Remember it's ephemeral** - Data is lost on arbiter restart

### Advantages

- **Worker isolation** - Workers remain fully isolated; no shared memory bugs
- **Simple API** - Dict-like interface, no locking required
- **Binary support** - Efficiently stores bytes (images, model weights)
- **Pattern matching** - `keys(pattern="user:*")` for querying
- **Zero setup** - Works automatically with dirty workers
- **Table-based** - Organize data into logical namespaces

### Limitations

- **No persistence** - Data lives only in arbiter memory
- **No transactions** - No atomic read-modify-write operations
- **No TTL** - Entries don't expire automatically
- **IPC overhead** - Each operation is a round-trip to the arbiter over the Unix socket
- **Single arbiter** - Not distributed across multiple machines

For persistent or distributed state, use Redis, PostgreSQL, or similar external systems.

### Flask Example

```python
|
||||
|
||||
@ -11,9 +11,11 @@ This demonstrates how to create a DirtyApp that:
3. Cleans up on shutdown (close)
"""

import os
import time
import hashlib
from gunicorn.dirty.app import DirtyApp
from gunicorn.dirty import stash


class MLApp(DirtyApp):
@ -171,3 +173,96 @@ class ComputeApp(DirtyApp):

    def close(self):
        print(f"[ComputeApp] Shutting down. Total computations: {self.computation_count}")


class SessionApp(DirtyApp):
    """
    Example dirty application demonstrating stash (shared state).

    This shows how multiple dirty workers can share state through
    the arbiter's stash tables. All workers see the same data.
    """

    # Declare stash tables used by this app (auto-created on startup)
    stashes = ["sessions", "counters"]

    def __init__(self):
        self.worker_pid = None

    def init(self):
        self.worker_pid = os.getpid()
        print(f"[SessionApp] Initialized on worker {self.worker_pid}")
        # Initialize a global counter if it doesn't exist
        if not stash.exists("counters", "requests"):
            stash.put("counters", "requests", 0)

    def __call__(self, action, *args, **kwargs):
        method = getattr(self, action, None)
        if method is None or action.startswith('_'):
            raise ValueError(f"Unknown action: {action}")
        return method(*args, **kwargs)

    def login(self, user_id, user_data):
        """Store user session in shared stash."""
        session = {
            "user_id": user_id,
            "data": user_data,
            "logged_in_at": time.time(),
            "worker_pid": self.worker_pid,
        }
        stash.put("sessions", f"user:{user_id}", session)
        self._increment_counter()
        return {"status": "ok", "session": session}

    def logout(self, user_id):
        """Remove user session."""
        key = f"user:{user_id}"
        if stash.exists("sessions", key):
            stash.delete("sessions", key)
            self._increment_counter()
            return {"status": "logged_out", "user_id": user_id}
        return {"status": "not_found", "user_id": user_id}

    def get_session(self, user_id):
        """Get user session - visible from any worker."""
        session = stash.get("sessions", f"user:{user_id}")
        self._increment_counter()
        return {
            "session": session,
            "served_by_worker": self.worker_pid,
        }

    def list_sessions(self):
        """List all active sessions."""
        keys = stash.keys("sessions", pattern="user:*")
        sessions = []
        for key in keys:
            sessions.append(stash.get("sessions", key))
        self._increment_counter()
        return {
            "sessions": sessions,
            "count": len(sessions),
            "served_by_worker": self.worker_pid,
        }

    def get_stats(self):
        """Get global request counter (shared across all workers)."""
        count = stash.get("counters", "requests", 0)
        return {
            "total_requests": count,
            "served_by_worker": self.worker_pid,
        }

    def _increment_counter(self):
        """Increment global request counter."""
        current = stash.get("counters", "requests", 0)
        stash.put("counters", "requests", current + 1)

    def clear_all(self):
        """Clear all sessions (for testing)."""
        stash.clear("sessions")
        stash.put("counters", "requests", 0)
        return {"status": "cleared"}

    def close(self):
        print(f"[SessionApp] Shutting down worker {self.worker_pid}")

@ -52,3 +52,15 @@ services:
    environment:
      - TEST_BASE_URL=http://server:8000
    command: python examples/dirty_example/test_integration.py

  # Run stash integration test against the server
  stash-test:
    build:
      context: ../..
      dockerfile: examples/dirty_example/Dockerfile
    depends_on:
      server:
        condition: service_healthy
    environment:
      - TEST_BASE_URL=http://server:8000
    command: python examples/dirty_example/test_stash_integration.py

@ -22,6 +22,7 @@ timeout = 30
dirty_apps = [
    "examples.dirty_example.dirty_app:MLApp",
    "examples.dirty_example.dirty_app:ComputeApp",
    "examples.dirty_example.dirty_app:SessionApp",
]
dirty_workers = 2
dirty_timeout = 300

226
examples/dirty_example/test_stash_integration.py
Normal file
@ -0,0 +1,226 @@
#!/usr/bin/env python3
#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.

"""
Integration tests for stash (shared state) functionality.

These tests verify that stash works correctly across multiple dirty workers,
demonstrating that state is truly shared.

Run with Docker:
    docker-compose up --build
    docker-compose exec app python test_stash_integration.py
"""

import json
import os
import sys
import urllib.request
import urllib.error

BASE_URL = os.environ.get("TEST_BASE_URL", "http://localhost:8000")


def request(path):
    """Make HTTP request and return JSON response."""
    url = f"{BASE_URL}{path}"
    try:
        with urllib.request.urlopen(url, timeout=10) as resp:
            return json.loads(resp.read().decode())
    except urllib.error.HTTPError as e:
        return {"error": str(e), "code": e.code}
    except urllib.error.URLError as e:
        return {"error": str(e)}


def test_stash_shared_state():
    """Test that stash state is shared across workers."""
    print("\n=== Test: Stash Shared State ===")

    # Clear any existing state
    result = request("/session/clear")
    print(f"Clear: {result}")

    # Login a user
    result = request("/session/login?user_id=100&name=Alice")
    print(f"Login Alice: {result}")
    assert result.get("status") == "ok", f"Login failed: {result}"
    worker1 = result.get("session", {}).get("worker_pid")
    print(f"  -> Handled by worker: {worker1}")

    # Make multiple requests to potentially hit different workers
    # and verify they all see the same session
    workers_seen = set()
    for i in range(5):
        result = request("/session/get?user_id=100")
        worker = result.get("served_by_worker")
        workers_seen.add(worker)
        session = result.get("session")
        assert session is not None, f"Session not found on request {i+1}"
        assert session.get("data", {}).get("name") == "Alice", f"Wrong session data"

    print(f"  -> Session visible from workers: {workers_seen}")
    print("PASSED: State is shared across workers")
    return True


def test_stash_counter():
    """Test that global counter increments correctly."""
    print("\n=== Test: Global Counter ===")

    # Clear state
    request("/session/clear")

    # Get initial stats
    result = request("/session/stats")
    initial = result.get("total_requests", 0)
    print(f"Initial count: {initial}")

    # Make several requests
    for i in range(5):
        request(f"/session/login?user_id={i}&name=User{i}")

    # Check counter increased
    result = request("/session/stats")
    final = result.get("total_requests", 0)
    print(f"Final count: {final}")

    # Each login increments counter by 1
    assert final >= initial + 5, f"Counter didn't increment enough: {initial} -> {final}"
    print("PASSED: Global counter works across workers")
    return True


def test_stash_list_sessions():
    """Test listing all sessions."""
    print("\n=== Test: List Sessions ===")

    # Clear and create some sessions
    request("/session/clear")
    request("/session/login?user_id=1&name=Alice")
    request("/session/login?user_id=2&name=Bob")
    request("/session/login?user_id=3&name=Charlie")

    # List all sessions
    result = request("/session/list")
    sessions = result.get("sessions", [])
    count = result.get("count", 0)

    print(f"Sessions: {count}")
    for s in sessions:
        print(f"  - user:{s.get('user_id')} = {s.get('data', {}).get('name')}")

    assert count == 3, f"Expected 3 sessions, got {count}"
    print("PASSED: List sessions works")
    return True


def test_stash_logout():
    """Test session deletion."""
    print("\n=== Test: Logout (Delete) ===")

    # Clear and create a session
    request("/session/clear")
    request("/session/login?user_id=999&name=TestUser")

    # Verify it exists
    result = request("/session/get?user_id=999")
    assert result.get("session") is not None, "Session should exist"

    # Logout
    result = request("/session/logout?user_id=999")
    print(f"Logout: {result}")
    assert result.get("status") == "logged_out", f"Logout failed: {result}"

    # Verify it's gone
    result = request("/session/get?user_id=999")
    assert result.get("session") is None, "Session should be deleted"

    print("PASSED: Logout deletes session")
    return True


def test_multiple_workers_see_updates():
    """Test that updates from one worker are visible to others."""
    print("\n=== Test: Cross-Worker Updates ===")

    request("/session/clear")

    # Create sessions and track which workers handled them
    workers = {}
    for i in range(10):
        result = request(f"/session/login?user_id={i}&name=User{i}")
        worker = result.get("session", {}).get("worker_pid")
        workers[i] = worker

    unique_workers = set(workers.values())
    print(f"Sessions created by workers: {unique_workers}")

    # Now read all sessions and verify all workers can see all data
    result = request("/session/list")
    count = result.get("count", 0)
    served_by = result.get("served_by_worker")

    print(f"List returned {count} sessions, served by worker {served_by}")
    assert count == 10, f"Expected 10 sessions, got {count}"

    print("PASSED: All workers see all updates")
    return True


def main():
    """Run all tests."""
    print("=" * 60)
    print("Stash Integration Tests")
    print("=" * 60)

    # Check server is running
    try:
        result = request("/")
        if "error" in result and "Connection refused" in str(result.get("error", "")):
            print("ERROR: Server not running. Start with: docker-compose up")
            return 1
        if not result.get("dirty_enabled"):
            print("ERROR: Dirty workers not enabled")
            return 1
        print(f"Server running, dirty workers enabled")
    except Exception as e:
        print(f"ERROR: Cannot connect to server: {e}")
        return 1

    # Run tests
    tests = [
        test_stash_shared_state,
        test_stash_counter,
        test_stash_list_sessions,
        test_stash_logout,
        test_multiple_workers_see_updates,
    ]

    passed = 0
    failed = 0

    for test in tests:
        try:
            if test():
                passed += 1
            else:
                failed += 1
        except AssertionError as e:
            print(f"FAILED: {e}")
            failed += 1
        except Exception as e:
            print(f"ERROR: {e}")
            failed += 1

    print("\n" + "=" * 60)
    print(f"Results: {passed} passed, {failed} failed")
    print("=" * 60)

    return 0 if failed == 0 else 1


if __name__ == "__main__":
    sys.exit(main())
@ -52,6 +52,11 @@ def app(environ, start_response):
                "/fibonacci?n=NUMBER": "Compute fibonacci",
                "/prime?n=NUMBER": "Check if prime",
                "/stats": "Get dirty worker stats",
                "/session/login?user_id=ID&name=NAME": "Login user (stash demo)",
                "/session/get?user_id=ID": "Get session (stash demo)",
                "/session/list": "List all sessions (stash demo)",
                "/session/logout?user_id=ID": "Logout user (stash demo)",
                "/session/stats": "Get stash stats (stash demo)",
            }
        }

@ -139,6 +144,71 @@ def app(environ, start_response):
            "http_worker_pid": os.getpid(),
        }

    # =====================================================================
    # Session endpoints (stash demo)
    # =====================================================================
    elif path == '/session/login':
        user_id = query.get('user_id', ['1'])[0]
        name = query.get('name', ['Anonymous'])[0]
        if client is None:
            result = {"error": "Dirty workers not enabled"}
        else:
            result = client.execute(
                "examples.dirty_example.dirty_app:SessionApp",
                "login",
                user_id=user_id,
                user_data={"name": name}
            )

    elif path == '/session/get':
        user_id = query.get('user_id', ['1'])[0]
        if client is None:
            result = {"error": "Dirty workers not enabled"}
        else:
            result = client.execute(
                "examples.dirty_example.dirty_app:SessionApp",
                "get_session",
                user_id=user_id
            )

    elif path == '/session/list':
        if client is None:
            result = {"error": "Dirty workers not enabled"}
        else:
            result = client.execute(
                "examples.dirty_example.dirty_app:SessionApp",
                "list_sessions"
            )

    elif path == '/session/logout':
        user_id = query.get('user_id', ['1'])[0]
        if client is None:
            result = {"error": "Dirty workers not enabled"}
        else:
            result = client.execute(
                "examples.dirty_example.dirty_app:SessionApp",
                "logout",
                user_id=user_id
            )

    elif path == '/session/stats':
        if client is None:
            result = {"error": "Dirty workers not enabled"}
        else:
            result = client.execute(
                "examples.dirty_example.dirty_app:SessionApp",
                "get_stats"
            )

    elif path == '/session/clear':
        if client is None:
            result = {"error": "Dirty workers not enabled"}
        else:
            result = client.execute(
                "examples.dirty_example.dirty_app:SessionApp",
                "clear_all"
            )

    else:
        start_response('404 Not Found', [('Content-Type', 'application/json')])
        return [json.dumps({"error": "Not found"}).encode()]

@ -38,6 +38,16 @@ from .client import (
    close_dirty_client_async,
)

# Stash (shared state between workers)
from . import stash
from .stash import (
    StashClient,
    StashTable,
    StashError,
    StashTableNotFoundError,
    StashKeyNotFoundError,
)

# Internal imports used by gunicorn core (not part of public API)
from .arbiter import DirtyArbiter

@ -58,6 +68,13 @@ __all__ = [
    "get_dirty_client_async",
    "close_dirty_client",
    "close_dirty_client_async",
    # Stash (shared state)
    "stash",
    "StashClient",
    "StashTable",
    "StashError",
    "StashTableNotFoundError",
    "StashKeyNotFoundError",
    # Internal (used by gunicorn core)
    "DirtyArbiter",
    "set_dirty_socket_path",

@ -11,6 +11,7 @@ requests from HTTP workers to available dirty workers.

import asyncio
import errno
import fnmatch
import os
import signal
import sys
@ -29,6 +30,17 @@ from .errors import (
from .protocol import (
    DirtyProtocol,
    make_error_response,
    make_response,
    STASH_OP_PUT,
    STASH_OP_GET,
    STASH_OP_DELETE,
    STASH_OP_KEYS,
    STASH_OP_CLEAR,
    STASH_OP_INFO,
    STASH_OP_ENSURE,
    STASH_OP_DELETE_TABLE,
    STASH_OP_TABLES,
    STASH_OP_EXISTS,
)
from .worker import DirtyWorker

@ -97,6 +109,10 @@ class DirtyArbiter:
        # Queue of app lists from dead workers to respawn with same apps
        self._pending_respawns = []

        # Stash (shared state) - global tables stored in arbiter
        # Maps table_name -> dict of data
        self.stash_tables = {}

        # Parse app specs on init
        self._parse_app_specs()

@ -209,6 +225,9 @@ class DirtyArbiter:
            except IOError as e:
                self.log.warning("Failed to write PID file: %s", e)

        # Set socket path env var for dirty workers (enables stash access)
        os.environ['GUNICORN_DIRTY_SOCKET'] = self.socket_path

        # Call hook
        self.cfg.on_dirty_starting(self)

@ -337,6 +356,7 @@ class DirtyArbiter:

        Routes requests to available dirty workers and returns responses.
        Supports both regular responses and streaming (chunk-based) responses.
        Also handles stash (shared state) operations.
        """
        self.log.debug("New client connection from HTTP worker")

@ -347,8 +367,14 @@ class DirtyArbiter:
                 except asyncio.IncompleteReadError:
                     break

-                # Route request to a dirty worker - pass writer for streaming
-                await self.route_request(message, writer)
+                msg_type = message.get("type")
+
+                # Handle stash operations
+                if msg_type == DirtyProtocol.MSG_TYPE_STASH:
+                    await self.handle_stash_request(message, writer)
+                else:
+                    # Route request to a dirty worker - pass writer for streaming
+                    await self.route_request(message, writer)
         except Exception as e:
             self.log.error("Client connection error: %s", e)
         finally:
@ -565,6 +591,127 @@ class DirtyArbiter:
|
||||
_reader, writer = self.worker_connections.pop(worker_pid)
|
||||
writer.close()
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Stash (shared state) operations - handled directly in arbiter
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def handle_stash_request(self, message, client_writer):
|
||||
"""
|
||||
Handle a stash operation directly in the arbiter.
|
||||
|
||||
All stash tables are stored in arbiter memory for simplicity
|
||||
and fast access.
|
||||
|
||||
Args:
|
||||
message: Stash operation message
|
||||
client_writer: StreamWriter to send response to client
|
||||
"""
|
||||
request_id = message.get("id", "unknown")
|
||||
op = message.get("op")
|
||||
table = message.get("table", "")
|
||||
key = message.get("key")
|
||||
value = message.get("value")
|
||||
pattern = message.get("pattern")
|
||||
|
||||
try:
|
||||
result = None
|
||||
|
||||
if op == STASH_OP_PUT:
|
||||
# Auto-create table if needed
|
||||
if table not in self.stash_tables:
|
||||
self.stash_tables[table] = {}
|
||||
self.stash_tables[table][key] = value
|
||||
result = True
|
||||
|
||||
elif op == STASH_OP_GET:
|
||||
if table not in self.stash_tables:
|
||||
result = {"error": "key_not_found"}
|
||||
elif key not in self.stash_tables[table]:
|
||||
result = {"error": "key_not_found"}
|
||||
else:
|
||||
result = self.stash_tables[table][key]
|
||||
|
||||
elif op == STASH_OP_DELETE:
|
||||
if table in self.stash_tables and key in self.stash_tables[table]:
|
||||
del self.stash_tables[table][key]
|
||||
result = True
|
||||
else:
|
||||
result = False
|
||||
|
||||
elif op == STASH_OP_KEYS:
|
||||
if table not in self.stash_tables:
|
||||
result = []
|
||||
else:
|
||||
all_keys = list(self.stash_tables[table].keys())
|
||||
if pattern:
|
||||
all_keys = [k for k in all_keys
|
||||
if fnmatch.fnmatch(str(k), pattern)]
|
||||
result = all_keys
|
||||
|
||||
elif op == STASH_OP_CLEAR:
|
||||
if table in self.stash_tables:
|
||||
self.stash_tables[table].clear()
|
||||
result = True
|
||||
|
||||
elif op == STASH_OP_INFO:
|
||||
if table not in self.stash_tables:
|
||||
result = {"error": "table_not_found"}
|
||||
else:
|
||||
result = {
|
||||
"size": len(self.stash_tables[table]),
|
||||
"table": table,
|
||||
}
|
||||
|
||||
elif op == STASH_OP_ENSURE:
|
||||
if table not in self.stash_tables:
|
||||
self.stash_tables[table] = {}
|
||||
result = True
|
||||
|
||||
elif op == STASH_OP_DELETE_TABLE:
|
||||
if table in self.stash_tables:
|
||||
del self.stash_tables[table]
|
||||
result = True
|
||||
else:
|
||||
result = False
|
||||
|
||||
elif op == STASH_OP_TABLES:
|
||||
result = list(self.stash_tables.keys())
|
||||
|
||||
elif op == STASH_OP_EXISTS:
|
||||
if table not in self.stash_tables:
|
||||
result = False
|
||||
elif key is None:
|
||||
result = True
|
||||
else:
|
||||
result = key in self.stash_tables[table]
|
||||
|
||||
else:
|
||||
error = DirtyError(f"Unknown stash operation: {op}")
|
||||
response = make_error_response(request_id, error)
|
||||
await DirtyProtocol.write_message_async(client_writer, response)
|
||||
return
|
||||
|
||||
# Handle error results
|
||||
if isinstance(result, dict) and "error" in result:
|
||||
error_type = result["error"]
|
||||
if error_type == "table_not_found":
|
||||
error = DirtyError(f"Table not found: {table}")
|
||||
elif error_type == "key_not_found":
|
||||
error = DirtyError(f"Key not found: {key}")
|
||||
else:
|
||||
error = DirtyError(str(result))
|
||||
error.error_type = f"Stash{error_type.title().replace('_', '')}Error"
|
||||
response = make_error_response(request_id, error)
|
||||
else:
|
||||
response = make_response(request_id, result)
|
||||
|
||||
await DirtyProtocol.write_message_async(client_writer, response)
|
||||
|
||||
except Exception as e:
|
||||
self.log.error("Stash operation error: %s", e)
|
||||
response = make_error_response(request_id, DirtyError(str(e)))
|
||||
await DirtyProtocol.write_message_async(client_writer, response)
|
||||
|
||||
async def manage_workers(self):
|
||||
"""Maintain the number of dirty workers."""
|
||||
if not self.alive:
|
||||
|
||||
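The two less obvious pieces of the arbiter-side handler above are the glob filtering for `STASH_OP_KEYS` and the way an internal error code is turned into an error type name. The following is a minimal standalone sketch of both, not the arbiter code itself; `filter_keys` and `error_class_name` are hypothetical helper names introduced only for illustration.

```python
import fnmatch


def filter_keys(table, pattern=None):
    """Return the table's keys, optionally filtered by a glob pattern."""
    all_keys = list(table.keys())
    if pattern:
        # Keys go through str() so non-string keys can still be matched.
        all_keys = [k for k in all_keys if fnmatch.fnmatch(str(k), pattern)]
    return all_keys


def error_class_name(error_code):
    """Map a code such as "key_not_found" to "StashKeyNotFoundError"."""
    return f"Stash{error_code.title().replace('_', '')}Error"


sessions = {"user:1": {"name": "Alice"}, "user:2": {"name": "Bob"}, "job:7": 1}
print(filter_keys(sessions, "user:*"))
print(error_class_name("table_not_found"))
```

Note that `fnmatch.fnmatch` normalizes case on case-insensitive platforms; `fnmatch.fnmatchcase` is the standard-library variant that always compares case-sensitively.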
@ -659,6 +659,10 @@ def set_dirty_socket_path(path):
|
||||
global _dirty_socket_path # pylint: disable=global-statement
|
||||
_dirty_socket_path = path
|
||||
|
||||
# Also set the stash socket path (uses same arbiter socket)
|
||||
from .stash import set_stash_socket_path
|
||||
set_stash_socket_path(path)
|
||||
|
||||
|
||||
def get_dirty_socket_path():
|
||||
"""Get the dirty socket path."""
|
||||
|
||||
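The hunk above forwards the dirty socket path to the stash module, and the stash module in this commit also falls back to the `GUNICORN_DIRTY_SOCKET` environment variable. A rough sketch of that lookup order follows; `set_socket_path` and `resolve_socket_path` are hypothetical names, and `RuntimeError` stands in for the `StashError` used in the real module.

```python
import os

_socket_path = None  # set explicitly during initialization, if at all


def set_socket_path(path):
    """Record the socket path chosen at startup."""
    global _socket_path
    _socket_path = path


def resolve_socket_path(environ=None):
    """Prefer the explicitly set path, then the environment variable."""
    environ = os.environ if environ is None else environ
    if _socket_path is not None:
        return _socket_path
    path = environ.get("GUNICORN_DIRTY_SOCKET")
    if path:
        return path
    raise RuntimeError("Stash socket path not configured")


# The environment variable is only consulted when nothing was set explicitly.
print(resolve_socket_path({"GUNICORN_DIRTY_SOCKET": "/tmp/example.sock"}))
set_socket_path("/run/example/dirty.sock")
print(resolve_socket_path({"GUNICORN_DIRTY_SOCKET": "/tmp/example.sock"}))
```

The socket paths shown are made-up examples.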
@ -42,6 +42,7 @@ MSG_TYPE_RESPONSE = 0x02
|
||||
MSG_TYPE_ERROR = 0x03
|
||||
MSG_TYPE_CHUNK = 0x04
|
||||
MSG_TYPE_END = 0x05
|
||||
MSG_TYPE_STASH = 0x10 # Stash operations (shared state between workers)
|
||||
|
||||
# Message type names (for backwards compatibility with old API)
|
||||
MSG_TYPE_REQUEST_STR = "request"
|
||||
@ -49,6 +50,7 @@ MSG_TYPE_RESPONSE_STR = "response"
|
||||
MSG_TYPE_ERROR_STR = "error"
|
||||
MSG_TYPE_CHUNK_STR = "chunk"
|
||||
MSG_TYPE_END_STR = "end"
|
||||
MSG_TYPE_STASH_STR = "stash"
|
||||
|
||||
# Map int types to string names
|
||||
MSG_TYPE_TO_STR = {
|
||||
@ -57,11 +59,24 @@ MSG_TYPE_TO_STR = {
|
||||
MSG_TYPE_ERROR: MSG_TYPE_ERROR_STR,
|
||||
MSG_TYPE_CHUNK: MSG_TYPE_CHUNK_STR,
|
||||
MSG_TYPE_END: MSG_TYPE_END_STR,
|
||||
MSG_TYPE_STASH: MSG_TYPE_STASH_STR,
|
||||
}
|
||||
|
||||
# Map string names to int types
|
||||
MSG_TYPE_FROM_STR = {v: k for k, v in MSG_TYPE_TO_STR.items()}
|
||||
|
||||
# Stash operation codes
|
||||
STASH_OP_PUT = 1
|
||||
STASH_OP_GET = 2
|
||||
STASH_OP_DELETE = 3
|
||||
STASH_OP_KEYS = 4
|
||||
STASH_OP_CLEAR = 5
|
||||
STASH_OP_INFO = 6
|
||||
STASH_OP_ENSURE = 7
|
||||
STASH_OP_DELETE_TABLE = 8
|
||||
STASH_OP_TABLES = 9
|
||||
STASH_OP_EXISTS = 10
|
||||
|
||||
# Header format: Magic (2) + Version (1) + Type (1) + Length (4) + RequestID (8) = 16
|
||||
HEADER_FORMAT = ">2sBBIQ"
|
||||
HEADER_SIZE = struct.calcsize(HEADER_FORMAT)
|
||||
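To make the 16-byte header layout concrete, here is a small sketch that packs and unpacks a header for a stash message using the format string above. The magic bytes and version number are placeholders chosen for this example (the real values are defined elsewhere in `protocol.py` and are not shown in this diff), and `pack_header` is a hypothetical helper, not `BinaryProtocol.encode_header`.

```python
import struct

HEADER_FORMAT = ">2sBBIQ"  # magic, version, type, payload length, request ID
HEADER_SIZE = struct.calcsize(HEADER_FORMAT)
MSG_TYPE_STASH = 0x10

# Placeholder values, assumed for illustration only.
EXAMPLE_MAGIC = b"XX"
EXAMPLE_VERSION = 1


def pack_header(msg_type, request_id, payload_length):
    """Pack a header in the field order given by the format comment."""
    return struct.pack(HEADER_FORMAT, EXAMPLE_MAGIC, EXAMPLE_VERSION,
                       msg_type, payload_length, request_id)


header = pack_header(MSG_TYPE_STASH, 12345, 42)
magic, version, msg_type, length, request_id = struct.unpack(HEADER_FORMAT, header)
print(HEADER_SIZE, hex(msg_type), length, request_id)
```

Because the format uses big-endian standard sizes (`>`), no padding is inserted and the size is exactly 2 + 1 + 1 + 4 + 8 bytes.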
@ -82,6 +97,7 @@ class BinaryProtocol:
|
||||
MSG_TYPE_ERROR = MSG_TYPE_ERROR_STR
|
||||
MSG_TYPE_CHUNK = MSG_TYPE_CHUNK_STR
|
||||
MSG_TYPE_END = MSG_TYPE_END_STR
|
||||
MSG_TYPE_STASH = MSG_TYPE_STASH_STR
|
||||
|
||||
@staticmethod
|
||||
def encode_header(msg_type: int, request_id: int, payload_length: int) -> bytes:
|
||||
@ -257,6 +273,39 @@ class BinaryProtocol:
|
||||
header = BinaryProtocol.encode_header(MSG_TYPE_END, request_id, 0)
|
||||
return header
|
||||
|
||||
@staticmethod
|
||||
def encode_stash(request_id: int, op: int, table: str,
|
||||
key=None, value=None, pattern=None) -> bytes:
|
||||
"""
|
||||
Encode a stash operation message.
|
||||
|
||||
Args:
|
||||
request_id: Unique request identifier (uint64)
|
||||
op: Stash operation code (STASH_OP_*)
|
||||
table: Table name
|
||||
key: Optional key for put/get/delete operations
|
||||
value: Optional value for put operation
|
||||
pattern: Optional pattern for keys operation
|
||||
|
||||
Returns:
|
||||
bytes: Complete message (header + payload)
|
||||
"""
|
||||
payload_dict = {
|
||||
"op": op,
|
||||
"table": table,
|
||||
}
|
||||
if key is not None:
|
||||
payload_dict["key"] = key
|
||||
if value is not None:
|
||||
payload_dict["value"] = value
|
||||
if pattern is not None:
|
||||
payload_dict["pattern"] = pattern
|
||||
|
||||
payload = TLVEncoder.encode(payload_dict)
|
||||
header = BinaryProtocol.encode_header(MSG_TYPE_STASH, request_id,
|
||||
len(payload))
|
||||
return header + payload
|
||||
|
||||
@staticmethod
|
||||
def decode_message(data: bytes) -> tuple:
|
||||
"""
|
||||
@ -524,6 +573,15 @@ class BinaryProtocol:
|
||||
)
|
||||
elif msg_type == MSG_TYPE_END:
|
||||
return BinaryProtocol.encode_end(request_id)
|
||||
elif msg_type == MSG_TYPE_STASH:
|
||||
return BinaryProtocol.encode_stash(
|
||||
request_id,
|
||||
message.get("op"),
|
||||
message.get("table", ""),
|
||||
message.get("key"),
|
||||
message.get("value"),
|
||||
message.get("pattern")
|
||||
)
|
||||
else:
|
||||
raise DirtyProtocolError(f"Unhandled message type: {msg_type}")
|
||||
|
||||
@ -642,3 +700,34 @@ def make_end_message(request_id) -> dict:
|
||||
"type": DirtyProtocol.MSG_TYPE_END,
|
||||
"id": request_id,
|
||||
}
|
||||
|
||||
|
||||
def make_stash_message(request_id, op: int, table: str,
|
||||
key=None, value=None, pattern=None) -> dict:
|
||||
"""
|
||||
Build a stash operation message dict.
|
||||
|
||||
Args:
|
||||
request_id: Unique request identifier (int or str)
|
||||
op: Stash operation code (STASH_OP_*)
|
||||
table: Table name
|
||||
key: Optional key for put/get/delete operations
|
||||
value: Optional value for put operation
|
||||
pattern: Optional pattern for keys operation
|
||||
|
||||
Returns:
|
||||
dict: Stash message dict
|
||||
"""
|
||||
msg = {
|
||||
"type": DirtyProtocol.MSG_TYPE_STASH,
|
||||
"id": request_id,
|
||||
"op": op,
|
||||
"table": table,
|
||||
}
|
||||
if key is not None:
|
||||
msg["key"] = key
|
||||
if value is not None:
|
||||
msg["value"] = value
|
||||
if pattern is not None:
|
||||
msg["pattern"] = pattern
|
||||
return msg
|
||||
|
||||
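Both `encode_stash` and `make_stash_message` drop any of `key`, `value`, and `pattern` that is `None`. The sketch below mirrors that behaviour in a standalone function (`build_stash_message` is a hypothetical stand-in, and the `"stash"` type string is taken from `MSG_TYPE_STASH_STR`) to show what the resulting dicts look like.

```python
STASH_OP_PUT = 1
STASH_OP_KEYS = 4


def build_stash_message(request_id, op, table, key=None, value=None, pattern=None):
    """Build a stash message dict, omitting optional fields that are None."""
    msg = {"type": "stash", "id": request_id, "op": op, "table": table}
    for name, field in (("key", key), ("value", value), ("pattern", pattern)):
        if field is not None:
            msg[name] = field
    return msg


put_msg = build_stash_message(1, STASH_OP_PUT, "sessions",
                              key="user:1", value={"name": "Alice"})
keys_msg = build_stash_message(2, STASH_OP_KEYS, "sessions", pattern="user:*")
none_msg = build_stash_message(3, STASH_OP_PUT, "sessions", key="user:2", value=None)
print(sorted(put_msg), sorted(keys_msg), sorted(none_msg))
```

One consequence of this design is that an explicit `value=None` and an absent value look identical in the message, so the receiver cannot tell them apart from the message alone.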
503
gunicorn/dirty/stash.py
Normal file
@ -0,0 +1,503 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""
|
||||
Stash - Global Shared State for Dirty Workers
|
||||
|
||||
Provides simple key-value tables stored in the arbiter process.
|
||||
All workers can read and write to the same tables.
|
||||
|
||||
Usage::
|
||||
|
||||
from gunicorn.dirty import stash
|
||||
|
||||
# Basic operations - table is auto-created on first access
|
||||
stash.put("sessions", "user:1", {"name": "Alice", "role": "admin"})
|
||||
user = stash.get("sessions", "user:1")
|
||||
stash.delete("sessions", "user:1")
|
||||
|
||||
# Dict-like interface
|
||||
sessions = stash.table("sessions")
|
||||
sessions["user:1"] = {"name": "Alice"}
|
||||
user = sessions["user:1"]
|
||||
del sessions["user:1"]
|
||||
|
||||
# Query operations
|
||||
keys = stash.keys("sessions")
|
||||
keys = stash.keys("sessions", pattern="user:*")
|
||||
|
||||
# Table management
|
||||
stash.ensure("cache") # Explicit creation (idempotent)
|
||||
stash.clear("sessions") # Delete all entries
|
||||
stash.delete_table("sessions") # Delete the table itself
|
||||
tables = stash.tables() # List all tables
|
||||
|
||||
Declarative usage in DirtyApp::
|
||||
|
||||
class MyApp(DirtyApp):
|
||||
stashes = ["sessions", "cache"] # Auto-created on arbiter start
|
||||
|
||||
def __call__(self, action, *args, **kwargs):
|
||||
# Tables are ready to use
|
||||
stash.put("sessions", "key", "value")
|
||||
|
||||
Note: Tables are stored in the arbiter process and are ephemeral.
|
||||
If the arbiter restarts, all data is lost.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import uuid
|
||||
|
||||
from .errors import DirtyError
|
||||
from .protocol import (
|
||||
DirtyProtocol,
|
||||
STASH_OP_PUT,
|
||||
STASH_OP_GET,
|
||||
STASH_OP_DELETE,
|
||||
STASH_OP_KEYS,
|
||||
STASH_OP_CLEAR,
|
||||
STASH_OP_INFO,
|
||||
STASH_OP_ENSURE,
|
||||
STASH_OP_DELETE_TABLE,
|
||||
STASH_OP_TABLES,
|
||||
STASH_OP_EXISTS,
|
||||
make_stash_message,
|
||||
)
|
||||
|
||||
|
||||
class StashError(DirtyError):
|
||||
"""Base exception for stash operations."""
|
||||
|
||||
|
||||
class StashTableNotFoundError(StashError):
|
||||
"""Raised when a table does not exist."""
|
||||
|
||||
def __init__(self, table_name):
|
||||
self.table_name = table_name
|
||||
super().__init__(f"Stash table not found: {table_name}")
|
||||
|
||||
|
||||
class StashKeyNotFoundError(StashError):
|
||||
"""Raised when a key does not exist in a table."""
|
||||
|
||||
def __init__(self, table_name, key):
|
||||
self.table_name = table_name
|
||||
self.key = key
|
||||
super().__init__(f"Key not found in {table_name}: {key}")
|
||||
|
||||
|
||||
class StashClient:
|
||||
"""
|
||||
Client for stash operations.
|
||||
|
||||
Communicates with the arbiter which stores all tables in memory.
|
||||
"""
|
||||
|
||||
def __init__(self, socket_path, timeout=30.0):
|
||||
"""
|
||||
Initialize the stash client.
|
||||
|
||||
Args:
|
||||
socket_path: Path to the dirty arbiter's Unix socket
|
||||
timeout: Default timeout for operations in seconds
|
||||
"""
|
||||
self.socket_path = socket_path
|
||||
self.timeout = timeout
|
||||
self._sock = None
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def _get_request_id(self):
|
||||
"""Generate a unique request ID."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
def _connect(self):
|
||||
"""Establish connection to arbiter."""
|
||||
import socket
|
||||
if self._sock is not None:
|
||||
return
|
||||
|
||||
try:
|
||||
self._sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
self._sock.settimeout(self.timeout)
|
||||
self._sock.connect(self.socket_path)
|
||||
except (socket.error, OSError) as e:
|
||||
self._sock = None
|
||||
raise StashError(f"Failed to connect to arbiter: {e}") from e
|
||||
|
||||
def _close(self):
|
||||
"""Close the connection."""
|
||||
if self._sock is not None:
|
||||
try:
|
||||
self._sock.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._sock = None
|
||||
|
||||
def _execute(self, op, table, key=None, value=None, pattern=None):
|
||||
"""
|
||||
Execute a stash operation.
|
||||
|
||||
Args:
|
||||
op: Operation code (STASH_OP_*)
|
||||
table: Table name
|
||||
key: Optional key
|
||||
value: Optional value
|
||||
pattern: Optional pattern for keys operation
|
||||
|
||||
Returns:
|
||||
Result from the operation
|
||||
"""
|
||||
with self._lock:
|
||||
if self._sock is None:
|
||||
self._connect()
|
||||
|
||||
request_id = self._get_request_id()
|
||||
message = make_stash_message(
|
||||
request_id, op, table,
|
||||
key=key, value=value, pattern=pattern
|
||||
)
|
||||
|
||||
try:
|
||||
DirtyProtocol.write_message(self._sock, message)
|
||||
response = DirtyProtocol.read_message(self._sock)
|
||||
|
||||
msg_type = response.get("type")
|
||||
if msg_type == DirtyProtocol.MSG_TYPE_RESPONSE:
|
||||
return response.get("result")
|
||||
elif msg_type == DirtyProtocol.MSG_TYPE_ERROR:
|
||||
error_info = response.get("error", {})
|
||||
error_type = error_info.get("error_type", "StashError")
|
||||
error_msg = error_info.get("message", "Unknown error")
|
||||
|
||||
if error_type == "StashTableNotFoundError":
|
||||
raise StashTableNotFoundError(table)
|
||||
if error_type == "StashKeyNotFoundError":
|
||||
raise StashKeyNotFoundError(table, key)
|
||||
raise StashError(error_msg)
|
||||
else:
|
||||
raise StashError(f"Unexpected response type: {msg_type}")
|
||||
|
||||
except Exception as e:
|
||||
self._close()
|
||||
if isinstance(e, StashError):
|
||||
raise
|
||||
raise StashError(f"Stash operation failed: {e}") from e
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Public API
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def put(self, table, key, value):
|
||||
"""
|
||||
Store a value in a table.
|
||||
|
||||
The table is automatically created if it doesn't exist.
|
||||
|
||||
Args:
|
||||
table: Table name
|
||||
key: Key to store under
|
||||
value: Value to store (must be serializable)
|
||||
"""
|
||||
self._execute(STASH_OP_PUT, table, key=key, value=value)
|
||||
|
||||
def get(self, table, key, default=None):
|
||||
"""
|
||||
Retrieve a value from a table.
|
||||
|
||||
Args:
|
||||
table: Table name
|
||||
key: Key to retrieve
|
||||
default: Default value if key not found
|
||||
|
||||
Returns:
|
||||
The stored value, or default if not found
|
||||
"""
|
||||
try:
|
||||
return self._execute(STASH_OP_GET, table, key=key)
|
||||
except StashKeyNotFoundError:
|
||||
return default
|
||||
|
||||
def delete(self, table, key):
|
||||
"""
|
||||
Delete a key from a table.
|
||||
|
||||
Args:
|
||||
table: Table name
|
||||
key: Key to delete
|
||||
|
||||
Returns:
|
||||
True if key was deleted, False if it didn't exist
|
||||
"""
|
||||
return self._execute(STASH_OP_DELETE, table, key=key)
|
||||
|
||||
def keys(self, table, pattern=None):
|
||||
"""
|
||||
Get all keys in a table, optionally filtered by pattern.
|
||||
|
||||
Args:
|
||||
table: Table name
|
||||
pattern: Optional glob pattern (e.g., "user:*")
|
||||
|
||||
Returns:
|
||||
List of keys
|
||||
"""
|
||||
return self._execute(STASH_OP_KEYS, table, pattern=pattern)
|
||||
|
||||
def clear(self, table):
|
||||
"""
|
||||
Delete all entries in a table.
|
||||
|
||||
Args:
|
||||
table: Table name
|
||||
"""
|
||||
self._execute(STASH_OP_CLEAR, table)
|
||||
|
||||
def info(self, table):
|
||||
"""
|
||||
Get information about a table.
|
||||
|
||||
Args:
|
||||
table: Table name
|
||||
|
||||
Returns:
|
||||
Dict with table info: "size" (number of entries) and "table" (table name)
|
||||
"""
|
||||
return self._execute(STASH_OP_INFO, table)
|
||||
|
||||
def ensure(self, table):
|
||||
"""
|
||||
Ensure a table exists (create if not exists).
|
||||
|
||||
This is idempotent - calling it multiple times is safe.
|
||||
|
||||
Args:
|
||||
table: Table name
|
||||
"""
|
||||
self._execute(STASH_OP_ENSURE, table)
|
||||
|
||||
def exists(self, table, key=None):
|
||||
"""
|
||||
Check if a table or key exists.
|
||||
|
||||
Args:
|
||||
table: Table name
|
||||
key: Optional key to check within the table
|
||||
|
||||
Returns:
|
||||
True if exists, False otherwise
|
||||
"""
|
||||
return self._execute(STASH_OP_EXISTS, table, key=key)
|
||||
|
||||
def delete_table(self, table):
|
||||
"""
|
||||
Delete an entire table.
|
||||
|
||||
Args:
|
||||
table: Table name
|
||||
"""
|
||||
self._execute(STASH_OP_DELETE_TABLE, table)
|
||||
|
||||
def tables(self):
|
||||
"""
|
||||
List all tables.
|
||||
|
||||
Returns:
|
||||
List of table names
|
||||
"""
|
||||
return self._execute(STASH_OP_TABLES, "")
|
||||
|
||||
def table(self, name):
|
||||
"""
|
||||
Get a dict-like interface to a table.
|
||||
|
||||
Args:
|
||||
name: Table name
|
||||
|
||||
Returns:
|
||||
StashTable instance
|
||||
"""
|
||||
return StashTable(self, name)
|
||||
|
||||
def close(self):
|
||||
"""Close the client connection."""
|
||||
with self._lock:
|
||||
self._close()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
|
||||
class StashTable:
|
||||
"""
|
||||
Dict-like interface to a stash table.
|
||||
|
||||
Example::
|
||||
|
||||
sessions = stash.table("sessions")
|
||||
sessions["user:1"] = {"name": "Alice"}
|
||||
user = sessions["user:1"]
|
||||
del sessions["user:1"]
|
||||
|
||||
# Iteration
|
||||
for key in sessions:
|
||||
print(key, sessions[key])
|
||||
"""
|
||||
|
||||
def __init__(self, client, name):
|
||||
self._client = client
|
||||
self._name = name
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
"""Table name."""
|
||||
return self._name
|
||||
|
||||
def __getitem__(self, key):
|
||||
result = self._client.get(self._name, key)
|
||||
if result is None:
|
||||
# Check if key actually exists with None value
|
||||
if not self._client.exists(self._name, key):
|
||||
raise KeyError(key)
|
||||
return result
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self._client.put(self._name, key, value)
|
||||
|
||||
def __delitem__(self, key):
|
||||
if not self._client.delete(self._name, key):
|
||||
raise KeyError(key)
|
||||
|
||||
def __contains__(self, key):
|
||||
return self._client.exists(self._name, key)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._client.keys(self._name))
|
||||
|
||||
def __len__(self):
|
||||
info = self._client.info(self._name)
|
||||
return info.get("size", 0)
|
||||
|
||||
def get(self, key, default=None):
|
||||
"""Get value with default."""
|
||||
return self._client.get(self._name, key, default)
|
||||
|
||||
def keys(self, pattern=None):
|
||||
"""Get all keys, optionally filtered by pattern."""
|
||||
return self._client.keys(self._name, pattern=pattern)
|
||||
|
||||
def clear(self):
|
||||
"""Delete all entries."""
|
||||
self._client.clear(self._name)
|
||||
|
||||
def items(self):
|
||||
"""Iterate over (key, value) pairs."""
|
||||
for key in self._client.keys(self._name):
|
||||
yield key, self._client.get(self._name, key)
|
||||
|
||||
def values(self):
|
||||
"""Iterate over values."""
|
||||
for key in self._client.keys(self._name):
|
||||
yield self._client.get(self._name, key)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Global stash instance (module-level API)
|
||||
# =============================================================================
|
||||
|
||||
# Thread-local storage for stash clients
|
||||
_thread_local = threading.local()
|
||||
|
||||
# Global socket path
|
||||
_stash_socket_path = None
|
||||
|
||||
|
||||
def set_stash_socket_path(path):
|
||||
"""Set the global stash socket path (called during initialization)."""
|
||||
global _stash_socket_path # pylint: disable=global-statement
|
||||
_stash_socket_path = path
|
||||
|
||||
|
||||
def get_stash_socket_path():
|
||||
"""Get the stash socket path."""
|
||||
import os
|
||||
if _stash_socket_path is None:
|
||||
# Check environment variable
|
||||
path = os.environ.get('GUNICORN_DIRTY_SOCKET')
|
||||
if path:
|
||||
return path
|
||||
raise StashError(
|
||||
"Stash socket path not configured. "
|
||||
"Make sure dirty_workers > 0 and dirty_apps are configured."
|
||||
)
|
||||
return _stash_socket_path
|
||||
|
||||
|
||||
def _get_client():
|
||||
"""Get or create a thread-local stash client."""
|
||||
client = getattr(_thread_local, 'stash_client', None)
|
||||
if client is None:
|
||||
socket_path = get_stash_socket_path()
|
||||
client = StashClient(socket_path)
|
||||
_thread_local.stash_client = client
|
||||
return client
|
||||
|
||||
|
||||
# Module-level functions that use the thread-local client
|
||||
|
||||
def put(table, key, value):
|
||||
"""Store a value in a table."""
|
||||
_get_client().put(table, key, value)
|
||||
|
||||
|
||||
def get(table, key, default=None):
|
||||
"""Retrieve a value from a table."""
|
||||
return _get_client().get(table, key, default)
|
||||
|
||||
|
||||
def delete(table, key):
|
||||
"""Delete a key from a table."""
|
||||
return _get_client().delete(table, key)
|
||||
|
||||
|
||||
def keys(table, pattern=None):
|
||||
"""Get all keys in a table, optionally filtered by a glob pattern."""
|
||||
return _get_client().keys(table, pattern)
|
||||
|
||||
|
||||
def clear(table):
|
||||
"""Delete all entries in a table."""
|
||||
_get_client().clear(table)
|
||||
|
||||
|
||||
def info(table):
|
||||
"""Get information about a table."""
|
||||
return _get_client().info(table)
|
||||
|
||||
|
||||
def ensure(table):
|
||||
"""Ensure a table exists."""
|
||||
_get_client().ensure(table)
|
||||
|
||||
|
||||
def exists(table, key=None):
|
||||
"""Check if a table or key exists."""
|
||||
return _get_client().exists(table, key)
|
||||
|
||||
|
||||
def delete_table(table):
|
||||
"""Delete an entire table."""
|
||||
_get_client().delete_table(table)
|
||||
|
||||
|
||||
def tables():
|
||||
"""List all tables."""
|
||||
return _get_client().tables()
|
||||
|
||||
|
||||
def table(name):
|
||||
"""Get a dict-like interface to a table."""
|
||||
return _get_client().table(name)
|
||||
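Since the real client needs a running dirty arbiter, application code that uses the stash API can be awkward to unit test. One possible approach, sketched here under the assumption that only a subset of the method names from `StashClient` is needed, is an in-memory stand-in. `InMemoryStash` is a hypothetical class and is not part of this commit.

```python
import fnmatch


class InMemoryStash:
    """Hypothetical test double mirroring some StashClient method names."""

    def __init__(self):
        self._tables = {}

    def put(self, table, key, value):
        # Tables are created on first write, like the real stash.
        self._tables.setdefault(table, {})[key] = value

    def get(self, table, key, default=None):
        return self._tables.get(table, {}).get(key, default)

    def delete(self, table, key):
        entries = self._tables.get(table, {})
        if key in entries:
            del entries[key]
            return True
        return False

    def keys(self, table, pattern=None):
        all_keys = list(self._tables.get(table, {}))
        if pattern:
            all_keys = [k for k in all_keys if fnmatch.fnmatch(str(k), pattern)]
        return all_keys

    def exists(self, table, key=None):
        if table not in self._tables:
            return False
        return True if key is None else key in self._tables[table]


stash = InMemoryStash()
stash.put("sessions", "user:1", {"name": "Alice"})
stash.put("sessions", "job:9", "pending")
print(stash.get("sessions", "user:1"))
print(stash.keys("sessions", pattern="user:*"))
print(stash.delete("sessions", "missing"))
```

The stand-in keeps everything in one process, so it says nothing about socket errors or timeouts; those paths still need an integration test against a real arbiter.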
@ -168,10 +168,10 @@ class DirtyWorker:
|
||||
self.load_apps()
|
||||
|
||||
# Call hook
|
||||
self.pid = os.getpid()
|
||||
self.cfg.dirty_worker_init(self)
|
||||
|
||||
# Enter main run loop
|
||||
self.pid = os.getpid()
|
||||
self.booted = True
|
||||
self.run()
|
||||
|
||||
|
||||
206
tests/test_dirty_stash.py
Normal file
@ -0,0 +1,206 @@
|
||||
#
|
||||
# This file is part of gunicorn released under the MIT license.
|
||||
# See the NOTICE for more information.
|
||||
|
||||
"""Tests for dirty stash (shared state) functionality."""
|
||||
|
||||
import pytest
|
||||
|
||||
from gunicorn.dirty.stash import (
|
||||
StashClient,
|
||||
StashTable,
|
||||
StashError,
|
||||
StashTableNotFoundError,
|
||||
StashKeyNotFoundError,
|
||||
)
|
||||
from gunicorn.dirty.protocol import (
|
||||
BinaryProtocol,
|
||||
DirtyProtocol,
|
||||
MSG_TYPE_STASH,
|
||||
STASH_OP_PUT,
|
||||
STASH_OP_GET,
|
||||
STASH_OP_DELETE,
|
||||
STASH_OP_KEYS,
|
||||
STASH_OP_CLEAR,
|
||||
STASH_OP_INFO,
|
||||
STASH_OP_ENSURE,
|
||||
STASH_OP_DELETE_TABLE,
|
||||
STASH_OP_TABLES,
|
||||
STASH_OP_EXISTS,
|
||||
make_stash_message,
|
||||
)
|
||||
|
||||
|
||||
class TestStashProtocol:
|
||||
"""Test stash protocol encoding."""
|
||||
|
||||
def test_make_stash_message_basic(self):
|
||||
"""Test basic stash message creation."""
|
||||
msg = make_stash_message(123, STASH_OP_PUT, "test_table")
|
||||
assert msg["type"] == "stash"
|
||||
assert msg["id"] == 123
|
||||
assert msg["op"] == STASH_OP_PUT
|
||||
assert msg["table"] == "test_table"
|
||||
|
||||
def test_make_stash_message_with_key_value(self):
|
||||
"""Test stash message with key and value."""
|
||||
msg = make_stash_message(
|
||||
456, STASH_OP_PUT, "sessions",
|
||||
key="user:1", value={"name": "Alice"}
|
||||
)
|
||||
assert msg["key"] == "user:1"
|
||||
assert msg["value"] == {"name": "Alice"}
|
||||
|
||||
def test_make_stash_message_with_pattern(self):
|
||||
"""Test stash message with pattern."""
|
||||
msg = make_stash_message(
|
||||
789, STASH_OP_KEYS, "sessions",
|
||||
pattern="user:*"
|
||||
)
|
||||
assert msg["pattern"] == "user:*"
|
||||
|
||||
def test_encode_stash_message(self):
|
||||
"""Test binary encoding of stash message."""
|
||||
msg = make_stash_message(
|
||||
123, STASH_OP_PUT, "test",
|
||||
key="k", value="v"
|
||||
)
|
||||
encoded = BinaryProtocol._encode_from_dict(msg)
|
||||
assert isinstance(encoded, bytes)
|
||||
assert len(encoded) > 16 # Header + payload
|
||||
|
||||
def test_stash_message_roundtrip(self):
|
||||
"""Test encode/decode roundtrip for stash message."""
|
||||
original = make_stash_message(
|
||||
12345, STASH_OP_GET, "cache",
|
||||
key="my_key"
|
||||
)
|
||||
encoded = BinaryProtocol._encode_from_dict(original)
|
||||
msg_type, request_id, payload = BinaryProtocol.decode_message(encoded)
|
||||
|
||||
assert msg_type == "stash"
|
||||
assert payload["op"] == STASH_OP_GET
|
||||
assert payload["table"] == "cache"
|
||||
assert payload["key"] == "my_key"
|
||||
|
||||
def test_stash_operations_have_unique_codes(self):
|
||||
"""Test that all stash operations have unique codes."""
|
||||
ops = [
|
||||
STASH_OP_PUT,
|
||||
STASH_OP_GET,
|
||||
STASH_OP_DELETE,
|
||||
STASH_OP_KEYS,
|
||||
STASH_OP_CLEAR,
|
||||
STASH_OP_INFO,
|
||||
STASH_OP_ENSURE,
|
||||
STASH_OP_DELETE_TABLE,
|
||||
STASH_OP_TABLES,
|
||||
STASH_OP_EXISTS,
|
||||
]
|
||||
assert len(ops) == len(set(ops))
|
||||
|
||||
|
||||
class TestStashTable:
|
||||
"""Test StashTable dict-like interface."""
|
||||
|
||||
def test_stash_table_name(self):
|
||||
"""Test StashTable name property."""
|
||||
# Create a mock client
|
||||
class MockClient:
|
||||
pass
|
||||
|
||||
table = StashTable(MockClient(), "test_table")
|
||||
assert table.name == "test_table"
|
||||
|
||||
|
||||
class TestStashErrors:
|
||||
"""Test stash error classes."""
|
||||
|
||||
def test_stash_error_base(self):
|
||||
"""Test base StashError."""
|
||||
error = StashError("test error")
|
||||
assert str(error) == "test error"
|
||||
assert isinstance(error, Exception)
|
||||
|
||||
def test_stash_table_not_found_error(self):
|
||||
"""Test StashTableNotFoundError."""
|
||||
error = StashTableNotFoundError("my_table")
|
||||
assert error.table_name == "my_table"
|
||||
assert "my_table" in str(error)
|
||||
|
||||
def test_stash_key_not_found_error(self):
|
||||
"""Test StashKeyNotFoundError."""
|
||||
error = StashKeyNotFoundError("my_table", "my_key")
|
||||
assert error.table_name == "my_table"
|
||||
assert error.key == "my_key"
|
||||
assert "my_key" in str(error)
|
||||
|
||||
|
||||
class TestStashProtocolConstants:
|
||||
"""Test protocol constants for stash."""
|
||||
|
||||
def test_msg_type_stash_exists(self):
|
||||
"""Test MSG_TYPE_STASH constant has the expected value."""
|
||||
assert MSG_TYPE_STASH == 0x10
|
||||
|
||||
def test_dirty_protocol_exports_stash_type(self):
|
||||
"""Test DirtyProtocol exports stash type."""
|
||||
assert DirtyProtocol.MSG_TYPE_STASH == "stash"
|
||||
|
||||
def test_stash_op_codes(self):
|
||||
"""Test stash operation codes are integers."""
|
||||
assert isinstance(STASH_OP_PUT, int)
|
||||
assert isinstance(STASH_OP_GET, int)
|
||||
assert isinstance(STASH_OP_DELETE, int)
|
||||
assert isinstance(STASH_OP_KEYS, int)
|
||||
assert isinstance(STASH_OP_CLEAR, int)
|
||||
assert isinstance(STASH_OP_INFO, int)
|
||||
assert isinstance(STASH_OP_ENSURE, int)
|
||||
assert isinstance(STASH_OP_DELETE_TABLE, int)
|
||||
assert isinstance(STASH_OP_TABLES, int)
|
||||
assert isinstance(STASH_OP_EXISTS, int)
|
||||
|
||||
|
||||
class TestStashEncodingEdgeCases:
|
||||
"""Test edge cases in stash encoding."""
|
||||
|
||||
def test_encode_empty_table_name(self):
|
||||
"""Test encoding with empty table name."""
|
||||
msg = make_stash_message(1, STASH_OP_TABLES, "")
|
||||
encoded = BinaryProtocol._encode_from_dict(msg)
|
||||
assert isinstance(encoded, bytes)
|
||||
|
||||
def test_encode_unicode_table_name(self):
|
||||
"""Test encoding with unicode table name."""
|
||||
msg = make_stash_message(1, STASH_OP_PUT, "テスト", key="k", value="v")
|
||||
encoded = BinaryProtocol._encode_from_dict(msg)
|
||||
_, _, payload = BinaryProtocol.decode_message(encoded)
|
||||
assert payload["table"] == "テスト"
|
||||
|
||||
def test_encode_complex_value(self):
|
||||
"""Test encoding with complex nested value."""
|
||||
value = {
|
||||
"name": "test",
|
||||
"count": 42,
|
||||
"nested": {"a": [1, 2, 3]},
|
||||
"data": b"binary data",
|
||||
}
|
||||
msg = make_stash_message(1, STASH_OP_PUT, "test", key="k", value=value)
|
||||
encoded = BinaryProtocol._encode_from_dict(msg)
|
||||
_, _, payload = BinaryProtocol.decode_message(encoded)
|
||||
assert payload["value"] == value
|
||||
|
||||
def test_encode_none_key(self):
|
||||
"""Test that a None key is omitted from the message (table-level ops)."""
|
||||
msg = make_stash_message(1, STASH_OP_TABLES, "")
|
||||
assert "key" not in msg
|
||||
|
||||
def test_encode_special_characters_in_pattern(self):
|
||||
"""Test encoding with special characters in pattern."""
|
||||
msg = make_stash_message(
|
||||
1, STASH_OP_KEYS, "test",
|
||||
pattern="user:*:session:?"
|
||||
)
|
||||
encoded = BinaryProtocol._encode_from_dict(msg)
|
||||
_, _, payload = BinaryProtocol.decode_message(encoded)
|
||||
assert payload["pattern"] == "user:*:session:?"
|
||||