Benoit Chesneau 709a6ad159
feat(dirty): add stash - global shared state between workers (#3503)
* feat(dirty): add stash - global shared state between workers

Add a simple key-value store (stash) that allows dirty workers to share
state through the arbiter. Tables are stored directly in arbiter memory
for fast access and simplicity.

Features:
- Auto-create tables on first access
- Dict-like interface via stash.table()
- Pattern matching for keys (glob patterns)
- Module-level API: stash.put(), stash.get(), stash.delete(), etc.

Usage:
    from gunicorn.dirty import stash

    stash.put("sessions", "user:1", {"name": "Alice"})
    user = stash.get("sessions", "user:1")

    # Or dict-like
    sessions = stash.table("sessions")
    sessions["user:1"] = {"name": "Alice"}

New files:
- gunicorn/dirty/stash.py - Client API and StashTable class
- Protocol additions for MSG_TYPE_STASH and STASH_OP_* codes

Note: Tables are ephemeral - lost if arbiter restarts.

* test(dirty): add tests for stash protocol and encoding

Test coverage for:
- Stash message creation and encoding
- Protocol constants (MSG_TYPE_STASH, STASH_OP_*)
- Error classes (StashError, StashTableNotFoundError, StashKeyNotFoundError)
- StashTable dict-like interface
- Edge cases: unicode, complex values, special patterns

* example(dirty): add stash usage example and integration tests

- Add SessionApp to dirty_app.py demonstrating stash usage
- Add /session/* endpoints to wsgi_app.py
- Add test_stash_integration.py with Docker tests
- Update docker-compose.yml with stash-test service
- Fix: Set GUNICORN_DIRTY_SOCKET in dirty arbiter for worker access

* docs(dirty): add stash documentation
2026-02-12 21:45:49 +01:00

269 lines
8.5 KiB
Python

#
# This file is part of gunicorn released under the MIT license.
# See the NOTICE for more information.
"""
Example Dirty Application - Simulates ML Model Loading and Inference
This demonstrates how to create a DirtyApp that:
1. Loads "models" at startup (init)
2. Handles requests from HTTP workers (__call__)
3. Cleans up on shutdown (close)
"""
import hashlib
import math
import os
import time

from gunicorn.dirty.app import DirtyApp
from gunicorn.dirty import stash
class MLApp(DirtyApp):
    """
    Example dirty application that simulates ML model operations.

    In a real application, this would load actual ML models like:
    - PyTorch models
    - TensorFlow models
    - Scikit-learn models
    - LLM models (Hugging Face, etc.)

    Here each "model" is a plain metadata dict. Public (non-underscore)
    methods are exposed as actions through ``__call__``.
    """

    def __init__(self):
        # name -> fake model metadata dict (built by _load_model)
        self.models = {}
        # Number of times _load_model has run, including the default model.
        self.load_count = 0
        # Number of inference() calls on a loaded model.
        self.inference_count = 0

    def init(self):
        """Called once when dirty worker starts."""
        # `os` is already imported at module level; no need for __import__.
        print(f"[MLApp] Initializing... (pid: {os.getpid()})")
        # Simulate loading a default model (takes time)
        self._load_model("default")
        print(f"[MLApp] Initialization complete. Models loaded: {list(self.models.keys())}")

    def __call__(self, action, *args, **kwargs):
        """Dispatch ``action`` to the public method of the same name.

        Raises:
            ValueError: if ``action`` starts with an underscore, does not
                exist, or names a non-callable attribute (e.g. ``models``).
        """
        method = getattr(self, action, None)
        # callable() also rejects data attributes such as `models` or
        # `load_count`, which would otherwise fail later with a TypeError.
        if action.startswith('_') or not callable(method):
            raise ValueError(f"Unknown action: {action}")
        return method(*args, **kwargs)

    def _load_model(self, name):
        """Simulate loading a model (expensive operation).

        Stores the fake model under ``name`` (overwriting any existing
        entry), bumps ``load_count`` and returns the model dict.
        """
        print(f"[MLApp] Loading model '{name}'...")
        # Simulate model loading time
        time.sleep(0.5)
        # Create a fake "model" object
        self.models[name] = {
            "name": name,
            "loaded_at": time.time(),
            "version": "1.0.0",
            "parameters": 1_000_000,  # Simulated parameter count
        }
        self.load_count += 1
        print(f"[MLApp] Model '{name}' loaded successfully")
        return self.models[name]

    def load_model(self, name):
        """Load a model into memory (called from HTTP workers).

        Returns a dict whose ``status`` is ``"already_loaded"`` when the
        model is present (no reload happens), otherwise ``"loaded"``.
        """
        if name in self.models:
            return {"status": "already_loaded", "model": self.models[name]}
        model = self._load_model(name)
        return {"status": "loaded", "model": model}

    def list_models(self):
        """List all loaded models together with load/inference counters."""
        return {
            "models": list(self.models.keys()),
            "count": len(self.models),
            "total_loads": self.load_count,
            "total_inferences": self.inference_count,
        }

    def inference(self, model_name, input_data):
        """Run (simulated) inference on a loaded model.

        The "prediction" is derived from a short MD5 digest of
        ``str(input_data)``. ``confidence`` and ``inference_time_ms`` are
        fixed placeholder values.

        Raises:
            ValueError: if ``model_name`` has not been loaded.
        """
        if model_name not in self.models:
            raise ValueError(f"Model not loaded: {model_name}")
        self.inference_count += 1
        # Simulate inference (compute a hash as a "prediction")
        time.sleep(0.1)  # Simulate computation time
        return {
            "model": model_name,
            "input_hash": hashlib.md5(str(input_data).encode()).hexdigest()[:8],
            "prediction": f"result_{self.inference_count}",
            "confidence": 0.95,
            "inference_time_ms": 100,
        }

    def unload_model(self, name):
        """Unload a model from memory.

        Returns ``{"status": "not_found", ...}`` when ``name`` is not loaded.
        """
        if name not in self.models:
            return {"status": "not_found", "name": name}
        del self.models[name]
        return {"status": "unloaded", "name": name}

    def close(self):
        """Cleanup on shutdown."""
        print(f"[MLApp] Shutting down. Total inferences: {self.inference_count}")
        self.models.clear()
class ComputeApp(DirtyApp):
    """
    Example dirty application for CPU-intensive computations.

    This demonstrates operations that would block HTTP workers
    but are fine in dirty workers. fibonacci() and prime_check() each
    increment ``computation_count`` and echo it back as ``computation_id``.
    """

    def __init__(self):
        # Incremented once per fibonacci()/prime_check() call.
        self.computation_count = 0

    def init(self):
        """Called once when the dirty worker starts."""
        # `os` is already imported at module level; no need for __import__.
        print(f"[ComputeApp] Initialized (pid: {os.getpid()})")

    def __call__(self, action, *args, **kwargs):
        """Dispatch ``action`` to the public method of the same name.

        Raises:
            ValueError: if ``action`` starts with an underscore, does not
                exist, or names a non-callable attribute.
        """
        method = getattr(self, action, None)
        # callable() rejects data attributes such as `computation_count`.
        if action.startswith('_') or not callable(method):
            raise ValueError(f"Unknown action: {action}")
        return method(*args, **kwargs)

    def fibonacci(self, n):
        """Compute the n-th Fibonacci number iteratively.

        For ``n <= 1`` the result is ``n`` itself, so negative inputs are
        echoed back rather than rejected.
        """
        self.computation_count += 1
        if n <= 1:
            return {"n": n, "result": n, "computation_id": self.computation_count}
        a, b = 0, 1
        for _ in range(2, n + 1):
            a, b = b, a + b
        return {"n": n, "result": b, "computation_id": self.computation_count}

    def prime_check(self, n):
        """Check if a number is prime by odd trial division (slow for large n).

        The divisor bound comes from ``math.isqrt`` instead of
        ``int(n ** 0.5)``. The float square root can round below the true
        root for large ``n``, which misses a factor and reports a composite
        as prime. It also raises OverflowError once ``n`` exceeds the float
        range.
        """
        self.computation_count += 1
        if n < 2:
            is_prime = False
        elif n == 2:
            is_prime = True
        elif n % 2 == 0:
            is_prime = False
        else:
            is_prime = True
            # int() keeps float inputs working exactly as before:
            # floor(sqrt(floor(x))) == floor(sqrt(x)) for x >= 0.
            limit = math.isqrt(int(n))
            for i in range(3, limit + 1, 2):
                if n % i == 0:
                    is_prime = False
                    break
        return {"n": n, "is_prime": is_prime, "computation_id": self.computation_count}

    def stats(self):
        """Get computation statistics."""
        return {"total_computations": self.computation_count}

    def close(self):
        """Log the total computation count on shutdown."""
        print(f"[ComputeApp] Shutting down. Total computations: {self.computation_count}")
class SessionApp(DirtyApp):
    """
    Example dirty application demonstrating stash (shared state).

    This shows how multiple dirty workers can share state through
    the arbiter's stash tables. All workers see the same data.

    NOTE(review): the stash calls used here (get/put/exists/delete) are
    separate round-trips. Sequences such as exists()+put() or get()+put()
    are therefore not atomic across workers; see _increment_counter.
    """

    # Declare stash tables used by this app (auto-created on startup)
    stashes = ["sessions", "counters"]

    def __init__(self):
        # Filled in by init(); echoed in responses to show which worker
        # served a request.
        self.worker_pid = None

    def init(self):
        """Record this worker's pid and seed the shared request counter."""
        self.worker_pid = os.getpid()
        print(f"[SessionApp] Initialized on worker {self.worker_pid}")
        # Initialize a global counter if it doesn't exist.
        # Readers below fall back to 0 anyway, so a concurrent seed from
        # another worker is harmless.
        if not stash.exists("counters", "requests"):
            stash.put("counters", "requests", 0)

    def __call__(self, action, *args, **kwargs):
        """Dispatch ``action`` to the public method of the same name.

        Raises:
            ValueError: if ``action`` starts with an underscore, does not
                exist, or names a non-callable attribute (e.g. ``stashes``).
        """
        method = getattr(self, action, None)
        # callable() rejects data attributes such as the `stashes` list.
        if action.startswith('_') or not callable(method):
            raise ValueError(f"Unknown action: {action}")
        return method(*args, **kwargs)

    def login(self, user_id, user_data):
        """Store user session in shared stash under key ``user:<user_id>``."""
        session = {
            "user_id": user_id,
            "data": user_data,
            "logged_in_at": time.time(),
            "worker_pid": self.worker_pid,
        }
        stash.put("sessions", f"user:{user_id}", session)
        self._increment_counter()
        return {"status": "ok", "session": session}

    def logout(self, user_id):
        """Remove user session; report ``not_found`` if none exists."""
        key = f"user:{user_id}"
        # NOTE(review): another worker may delete the key between exists()
        # and delete(); how stash.delete handles a missing key is not
        # visible here.
        if stash.exists("sessions", key):
            stash.delete("sessions", key)
            self._increment_counter()
            return {"status": "logged_out", "user_id": user_id}
        return {"status": "not_found", "user_id": user_id}

    def get_session(self, user_id):
        """Get user session - visible from any worker."""
        session = stash.get("sessions", f"user:{user_id}")
        self._increment_counter()
        return {
            "session": session,
            "served_by_worker": self.worker_pid,
        }

    def list_sessions(self):
        """List all active sessions (keys matching ``user:*``)."""
        sessions = []
        for key in stash.keys("sessions", pattern="user:*"):
            # The session may have been removed by another worker between
            # keys() and get(). Fetch with a None default and skip it, so a
            # vanished key neither fails the call nor adds a bogus entry.
            session = stash.get("sessions", key, None)
            if session is not None:
                sessions.append(session)
        self._increment_counter()
        return {
            "sessions": sessions,
            "count": len(sessions),
            "served_by_worker": self.worker_pid,
        }

    def get_stats(self):
        """Get global request counter (shared across all workers)."""
        count = stash.get("counters", "requests", 0)
        return {
            "total_requests": count,
            "served_by_worker": self.worker_pid,
        }

    def _increment_counter(self):
        """Increment global request counter.

        NOTE(review): this is a non-atomic read-modify-write. Two workers
        incrementing concurrently can both read the same value, losing one
        update. That is acceptable for this demo counter, but not for
        anything that must be exact.
        """
        current = stash.get("counters", "requests", 0)
        stash.put("counters", "requests", current + 1)

    def clear_all(self):
        """Clear all sessions and reset the counter (for testing)."""
        stash.clear("sessions")
        stash.put("counters", "requests", 0)
        return {"status": "cleared"}

    def close(self):
        """Log shutdown of this worker."""
        print(f"[SessionApp] Shutting down worker {self.worker_pid}")