This commit is contained in:
David Sewell 2025-03-11 14:16:08 +00:00 committed by GitHub
commit fa6d4fd531
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 175 additions and 9 deletions

View File

@ -387,6 +387,12 @@ WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
WEBSOCKET_REDIS_LOCK_TIMEOUT = os.environ.get("WEBSOCKET_REDIS_LOCK_TIMEOUT", 60)
WEBSOCKET_REDIS_CERTS = os.environ.get("WEBSOCKET_REDIS_CERTS", "")
WEBSOCKET_REDIS_USERNAME = os.environ.get("WEBSOCKET_REDIS_USERNAME", "")
WEBSOCKET_REDIS_PASSWORD = os.environ.get("WEBSOCKET_REDIS_PASSWORD", "")
AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")
if AIOHTTP_CLIENT_TIMEOUT == "":
@ -442,3 +448,9 @@ AUDIT_EXCLUDED_PATHS = os.getenv("AUDIT_EXCLUDED_PATHS", "/chats,/chat,/folders"
)
AUDIT_EXCLUDED_PATHS = [path.strip() for path in AUDIT_EXCLUDED_PATHS]
AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS]
####################################
# WEBSOCKET_REDIS_AZURE_CREDENTIALS
####################################
WEBSOCKET_REDIS_CREDENTIALS = os.environ.get("WEBSOCKET_REDIS_CREDENTIALS", "").lower()

View File

@ -13,9 +13,14 @@ from open_webui.env import (
WEBSOCKET_MANAGER,
WEBSOCKET_REDIS_URL,
WEBSOCKET_REDIS_LOCK_TIMEOUT,
WEBSOCKET_REDIS_CERTS,
WEBSOCKET_REDIS_USERNAME,
WEBSOCKET_REDIS_PASSWORD,
WEBSOCKET_REDIS_CREDENTIALS,
)
from open_webui.utils.auth import decode_token
from open_webui.socket.utils import RedisDict, RedisLock
from open_webui.utils.azure_services import AzureCredentialService
from open_webui.env import (
GLOBAL_LOG_LEVEL,
@ -27,9 +32,40 @@ logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["SOCKET"])
redis_options = {}
if WEBSOCKET_REDIS_CREDENTIALS == 'azure':
azure_credential_service = AzureCredentialService()
redis_options["password"] = azure_credential_service.get_token()
redis_options["username"] = azure_credential_service.get_username(redis_options["password"])
elif not WEBSOCKET_REDIS_PASSWORD:
redis_options["password"] = WEBSOCKET_REDIS_PASSWORD
redis_options["username"] = WEBSOCKET_REDIS_USERNAME
if WEBSOCKET_REDIS_URL.startswith("rediss") and WEBSOCKET_REDIS_CERTS:
redis_options["ssl_ca_certs"] = WEBSOCKET_REDIS_CERTS
# Retrieves and configures the Redis manager for Socket.IO with authentication and SSL if configured
def get_redis_manager():
try:
mgr = socketio.AsyncRedisManager(
WEBSOCKET_REDIS_URL,
redis_options=redis_options
)
return mgr
except ConnectionError as e:
log.exception(f"Could not connect to Redis: {e}")
raise e
def refresh_azure_credentials():
if WEBSOCKET_REDIS_CREDENTIALS == 'azure':
if azure_credential_service.is_expired():
redis_options["password"] = azure_credential_service.get_token()
if WEBSOCKET_MANAGER == "redis":
mgr = socketio.AsyncRedisManager(WEBSOCKET_REDIS_URL)
mgr = get_redis_manager()
if not mgr:
log.error("Could not connect to Redis. Exiting.")
sys.exit(1)
sio = socketio.AsyncServer(
cors_allowed_origins=[],
async_mode="asgi",
@ -47,7 +83,6 @@ else:
always_connect=True,
)
# Timeout duration in seconds
TIMEOUT_DURATION = 3
@ -55,14 +90,27 @@ TIMEOUT_DURATION = 3
if WEBSOCKET_MANAGER == "redis":
log.debug("Using Redis to manage websockets.")
SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL)
USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL)
USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL)
SESSION_POOL = RedisDict(
"open-webui:session_pool",
redis_url=WEBSOCKET_REDIS_URL,
redis_options=redis_options
)
USER_POOL = RedisDict(
"open-webui:user_pool",
redis_url=WEBSOCKET_REDIS_URL,
redis_options=redis_options
)
USAGE_POOL = RedisDict(
"open-webui:usage_pool",
redis_url=WEBSOCKET_REDIS_URL,
redis_options=redis_options
)
clean_up_lock = RedisLock(
redis_url=WEBSOCKET_REDIS_URL,
lock_name="usage_cleanup_lock",
timeout_secs=WEBSOCKET_REDIS_LOCK_TIMEOUT,
redis_options=redis_options
)
aquire_func = clean_up_lock.aquire_lock
renew_func = clean_up_lock.renew_lock
@ -85,6 +133,10 @@ async def periodic_usage_pool_cleanup():
log.error(f"Unable to renew cleanup lock. Exiting usage pool cleanup.")
raise Exception("Unable to renew usage pool cleanup lock.")
if WEBSOCKET_MANAGER == "redis":
# update redis options to refresh auth token
refresh_azure_credentials()
now = int(time.time())
send_usage = False
for model_id, connections in list(USAGE_POOL.items()):

View File

@ -1,16 +1,106 @@
import json
import redis
import uuid
import logging
from open_webui.utils.azure_services import AzureCredentialService
from open_webui.env import (
WEBSOCKET_REDIS_CREDENTIALS,
SRC_LOG_LEVELS
)
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["SOCKET"])
class RedisService:
def __init__(self, redis_url, redis_options={}):
self.redis_url = redis_url
self.client = None
self.ssl_ca_certs = redis_options.get("ssl_ca_certs", None)
self.username = redis_options.get("username", None)
self.password = redis_options.get("password", None)
self.azure_credential_service = AzureCredentialService() if WEBSOCKET_REDIS_CREDENTIALS == 'azure' else None
self.init_redis()
def init_redis(self):
token = None
if self.azure_credential_service:
token = self.azure_credential_service.get_token()
self.username = self.azure_credential_service.get_username(token)
else:
token = self.password
try:
log.debug(f"redis_url: {self.redis_url}")
parameters = {
"url": self.redis_url,
"decode_responses": True,
"socket_timeout": 5,
}
if self.username and token:
log.debug(f"redis_username: {self.username}")
masked_password = f"{token[:3]}***{token[-3:]}" if token else None
log.debug(f"redis_password: {masked_password}")
parameters["username"] = self.username
parameters["password"] = token
if self.redis_url.startswith("rediss://"):
log.debug(f"redis_ssl_ca_certs: {self.ssl_ca_certs}")
parameters["ssl_ca_certs"] = self.ssl_ca_certs
self.client = redis.Redis.from_url(**parameters)
if self.client.ping():
log.info(f"Connected to Redis: {self.redis_url}")
else:
log.error(f"Failed to connect to Redis: {self.redis_url}")
raise e
except ConnectionError as e:
log.error(f"Failed to connect to Redis: {self.redis_url} {e}")
raise e
except TimeoutError as e:
log.error(f"Timed out connecting to Redis: {self.redis_url} {e}")
raise e
except redis.AuthenticationError as e:
log.error(f"Authentication failed connecting to Redis: {self.redis_url} {e}")
raise e
except Exception as e:
log.error(f"Failed to connect to Redis: {self.redis_url} {e}")
raise e
def get_client(self):
return self.client
# reinitialize the redis connection if an exception occurs
def reinit_onerror(func):
def wrapper(*args, **kwargs):
# Get the instance of the class
instance = args[0]
try:
return func(*args, **kwargs)
except Exception as e:
log.error(f'{instance.__class__}.{func.__name__}: {e}')
log.warning(f"Re-authenticate and initialize Redis Cache connection")
instance.init_redis()
return func(*args, **kwargs)
return wrapper
class RedisLock:
def __init__(self, redis_url, lock_name, timeout_secs):
def __init__(self, redis_url, lock_name, timeout_secs, **redis_kwargs):
self.lock_name = lock_name
self.lock_id = str(uuid.uuid4())
self.timeout_secs = timeout_secs
self.lock_obtained = False
self.redis = redis.Redis.from_url(redis_url, decode_responses=True)
self.redis_url = redis_url
self.redis_kwargs = redis_kwargs
self.init_redis()
def init_redis(self):
self.redis = RedisService(self.redis_url, **self.redis_kwargs).get_client()
@reinit_onerror
def aquire_lock(self):
# nx=True will only set this key if it _hasn't_ already been set
self.lock_obtained = self.redis.set(
@ -18,12 +108,14 @@ class RedisLock:
)
return self.lock_obtained
@reinit_onerror
def renew_lock(self):
# xx=True will only set this key if it _has_ already been set
return self.redis.set(
self.lock_name, self.lock_id, xx=True, ex=self.timeout_secs
)
@reinit_onerror
def release_lock(self):
lock_value = self.redis.get(self.lock_name)
if lock_value and lock_value == self.lock_id:
@ -31,37 +123,46 @@ class RedisLock:
class RedisDict:
def __init__(self, name, redis_url):
def __init__(self, name, redis_url, **redis_kwargs):
self.name = name
self.redis = redis.Redis.from_url(redis_url, decode_responses=True)
self.redis_service = RedisService(redis_url, **redis_kwargs)
self.redis = self.redis_service.get_client()
@reinit_onerror
def __setitem__(self, key, value):
serialized_value = json.dumps(value)
self.redis.hset(self.name, key, serialized_value)
@reinit_onerror
def __getitem__(self, key):
value = self.redis.hget(self.name, key)
if value is None:
raise KeyError(key)
return json.loads(value)
@reinit_onerror
def __delitem__(self, key):
result = self.redis.hdel(self.name, key)
if result == 0:
raise KeyError(key)
@reinit_onerror
def __contains__(self, key):
return self.redis.hexists(self.name, key)
@reinit_onerror
def __len__(self):
return self.redis.hlen(self.name)
@reinit_onerror
def keys(self):
return self.redis.hkeys(self.name)
@reinit_onerror
def values(self):
return [json.loads(v) for v in self.redis.hvals(self.name)]
@reinit_onerror
def items(self):
return [(k, json.loads(v)) for k, v in self.redis.hgetall(self.name).items()]
@ -71,6 +172,7 @@ class RedisDict:
except KeyError:
return default
@reinit_onerror
def clear(self):
self.redis.delete(self.name)