diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index 2abf65924..6d4269968 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -387,6 +387,12 @@ WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "") WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL) WEBSOCKET_REDIS_LOCK_TIMEOUT = os.environ.get("WEBSOCKET_REDIS_LOCK_TIMEOUT", 60) +WEBSOCKET_REDIS_CERTS = os.environ.get("WEBSOCKET_REDIS_CERTS", "") + +WEBSOCKET_REDIS_USERNAME = os.environ.get("WEBSOCKET_REDIS_USERNAME", "") + +WEBSOCKET_REDIS_PASSWORD = os.environ.get("WEBSOCKET_REDIS_PASSWORD", "") + AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "") if AIOHTTP_CLIENT_TIMEOUT == "": @@ -442,3 +448,9 @@ AUDIT_EXCLUDED_PATHS = os.getenv("AUDIT_EXCLUDED_PATHS", "/chats,/chat,/folders" ) AUDIT_EXCLUDED_PATHS = [path.strip() for path in AUDIT_EXCLUDED_PATHS] AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS] + +#################################### +# WEBSOCKET_REDIS_AZURE_CREDENTIALS +#################################### + +WEBSOCKET_REDIS_CREDENTIALS = os.environ.get("WEBSOCKET_REDIS_CREDENTIALS", "").lower() diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py index 8f5a9568b..58bb6ebb8 100644 --- a/backend/open_webui/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -13,9 +13,14 @@ from open_webui.env import ( WEBSOCKET_MANAGER, WEBSOCKET_REDIS_URL, WEBSOCKET_REDIS_LOCK_TIMEOUT, + WEBSOCKET_REDIS_CERTS, + WEBSOCKET_REDIS_USERNAME, + WEBSOCKET_REDIS_PASSWORD, + WEBSOCKET_REDIS_CREDENTIALS, ) from open_webui.utils.auth import decode_token from open_webui.socket.utils import RedisDict, RedisLock +from open_webui.utils.azure_services import AzureCredentialService from open_webui.env import ( GLOBAL_LOG_LEVEL, @@ -27,9 +32,40 @@ logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["SOCKET"]) +redis_options = {} + +if WEBSOCKET_REDIS_CREDENTIALS == 'azure': + azure_credential_service = AzureCredentialService() + redis_options["password"] = azure_credential_service.get_token() + redis_options["username"] = azure_credential_service.get_username(redis_options["password"]) +elif not WEBSOCKET_REDIS_PASSWORD: + redis_options["password"] = WEBSOCKET_REDIS_PASSWORD + redis_options["username"] = WEBSOCKET_REDIS_USERNAME +if WEBSOCKET_REDIS_URL.startswith("rediss") and WEBSOCKET_REDIS_CERTS: + redis_options["ssl_ca_certs"] = WEBSOCKET_REDIS_CERTS + +# Retrieves and configures the Redis manager for Socket.IO with authentication and SSL if configured +def get_redis_manager(): + try: + mgr = socketio.AsyncRedisManager( + WEBSOCKET_REDIS_URL, + redis_options=redis_options + ) + return mgr + except ConnectionError as e: + log.exception(f"Could not connect to Redis: {e}") + raise e + +def refresh_azure_credentials(): + if WEBSOCKET_REDIS_CREDENTIALS == 'azure': + if azure_credential_service.is_expired(): + redis_options["password"] = azure_credential_service.get_token() if WEBSOCKET_MANAGER == "redis": - mgr = socketio.AsyncRedisManager(WEBSOCKET_REDIS_URL) + mgr = get_redis_manager() + if not mgr: + log.error("Could not connect to Redis. Exiting.") + sys.exit(1) sio = socketio.AsyncServer( cors_allowed_origins=[], async_mode="asgi", @@ -47,7 +83,6 @@ else: always_connect=True, ) - # Timeout duration in seconds TIMEOUT_DURATION = 3 @@ -55,14 +90,27 @@ TIMEOUT_DURATION = 3 if WEBSOCKET_MANAGER == "redis": log.debug("Using Redis to manage websockets.") - SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL) - USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL) - USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL) + SESSION_POOL = RedisDict( + "open-webui:session_pool", + redis_url=WEBSOCKET_REDIS_URL, + redis_options=redis_options + ) + USER_POOL = RedisDict( + "open-webui:user_pool", + redis_url=WEBSOCKET_REDIS_URL, + redis_options=redis_options + ) + USAGE_POOL = RedisDict( + "open-webui:usage_pool", + redis_url=WEBSOCKET_REDIS_URL, + redis_options=redis_options + ) clean_up_lock = RedisLock( redis_url=WEBSOCKET_REDIS_URL, lock_name="usage_cleanup_lock", timeout_secs=WEBSOCKET_REDIS_LOCK_TIMEOUT, + redis_options=redis_options ) aquire_func = clean_up_lock.aquire_lock renew_func = clean_up_lock.renew_lock @@ -85,6 +133,10 @@ async def periodic_usage_pool_cleanup(): log.error(f"Unable to renew cleanup lock. Exiting usage pool cleanup.") raise Exception("Unable to renew usage pool cleanup lock.") + if WEBSOCKET_MANAGER == "redis": + # update redis options to refresh auth token + refresh_azure_credentials() + now = int(time.time()) send_usage = False for model_id, connections in list(USAGE_POOL.items()): diff --git a/backend/open_webui/socket/utils.py b/backend/open_webui/socket/utils.py index 46fafbb9e..26d5771d6 100644 --- a/backend/open_webui/socket/utils.py +++ b/backend/open_webui/socket/utils.py @@ -1,16 +1,106 @@ import json import redis import uuid +import logging +from open_webui.utils.azure_services import AzureCredentialService + +from open_webui.env import ( + WEBSOCKET_REDIS_CREDENTIALS, + SRC_LOG_LEVELS +) + +log = logging.getLogger(__name__) +log.setLevel(SRC_LOG_LEVELS["SOCKET"]) + + +class RedisService: + + def __init__(self, redis_url, redis_options={}): + self.redis_url = redis_url + self.client = None + self.ssl_ca_certs = redis_options.get("ssl_ca_certs", None) + self.username = redis_options.get("username", None) + self.password = redis_options.get("password", None) + self.azure_credential_service = AzureCredentialService() if WEBSOCKET_REDIS_CREDENTIALS == 'azure' else None + self.init_redis() + + def init_redis(self): + token = None + if self.azure_credential_service: + token = self.azure_credential_service.get_token() + self.username = self.azure_credential_service.get_username(token) + else: + token = self.password + try: + log.debug(f"redis_url: {self.redis_url}") + parameters = { + "url": self.redis_url, + "decode_responses": True, + "socket_timeout": 5, + } + if self.username and token: + log.debug(f"redis_username: {self.username}") + masked_password = f"{token[:3]}***{token[-3:]}" if token else None + log.debug(f"redis_password: {masked_password}") + parameters["username"] = self.username + parameters["password"] = token + if self.redis_url.startswith("rediss://"): + log.debug(f"redis_ssl_ca_certs: {self.ssl_ca_certs}") + parameters["ssl_ca_certs"] = self.ssl_ca_certs + self.client = redis.Redis.from_url(**parameters) + + if self.client.ping(): + log.info(f"Connected to Redis: {self.redis_url}") + else: + log.error(f"Failed to connect to Redis: {self.redis_url}") + raise e + + except ConnectionError as e: + log.error(f"Failed to connect to Redis: {self.redis_url} {e}") + raise e + except TimeoutError as e: + log.error(f"Timed out connecting to Redis: {self.redis_url} {e}") + raise e + except redis.AuthenticationError as e: + log.error(f"Authentication failed connecting to Redis: {self.redis_url} {e}") + raise e + except Exception as e: + log.error(f"Failed to connect to Redis: {self.redis_url} {e}") + raise e + + def get_client(self): + return self.client + +# reinitialize the redis connection if an exception occurs +def reinit_onerror(func): + def wrapper(*args, **kwargs): + # Get the instance of the class + instance = args[0] + try: + return func(*args, **kwargs) + except Exception as e: + log.error(f'{instance.__class__}.{func.__name__}: {e}') + log.warning(f"Re-authenticate and initialize Redis Cache connection") + instance.init_redis() + return func(*args, **kwargs) + return wrapper class RedisLock: - def __init__(self, redis_url, lock_name, timeout_secs): + def __init__(self, redis_url, lock_name, timeout_secs, **redis_kwargs): self.lock_name = lock_name self.lock_id = str(uuid.uuid4()) self.timeout_secs = timeout_secs self.lock_obtained = False - self.redis = redis.Redis.from_url(redis_url, decode_responses=True) + self.redis_url = redis_url + self.redis_kwargs = redis_kwargs + self.init_redis() + + def init_redis(self): + self.redis = RedisService(self.redis_url, **self.redis_kwargs).get_client() + + @reinit_onerror def aquire_lock(self): # nx=True will only set this key if it _hasn't_ already been set self.lock_obtained = self.redis.set( @@ -18,12 +108,14 @@ class RedisLock: ) return self.lock_obtained + @reinit_onerror def renew_lock(self): # xx=True will only set this key if it _has_ already been set return self.redis.set( self.lock_name, self.lock_id, xx=True, ex=self.timeout_secs ) + @reinit_onerror def release_lock(self): lock_value = self.redis.get(self.lock_name) if lock_value and lock_value == self.lock_id: @@ -31,37 +123,46 @@ class RedisLock: class RedisDict: - def __init__(self, name, redis_url): + def __init__(self, name, redis_url, **redis_kwargs): self.name = name - self.redis = redis.Redis.from_url(redis_url, decode_responses=True) + self.redis_service = RedisService(redis_url, **redis_kwargs) + self.redis = self.redis_service.get_client() + @reinit_onerror def __setitem__(self, key, value): serialized_value = json.dumps(value) self.redis.hset(self.name, key, serialized_value) + @reinit_onerror def __getitem__(self, key): value = self.redis.hget(self.name, key) if value is None: raise KeyError(key) return json.loads(value) + @reinit_onerror def __delitem__(self, key): result = self.redis.hdel(self.name, key) if result == 0: raise KeyError(key) + @reinit_onerror def __contains__(self, key): return self.redis.hexists(self.name, key) + @reinit_onerror def __len__(self): return self.redis.hlen(self.name) + @reinit_onerror def keys(self): return self.redis.hkeys(self.name) + @reinit_onerror def values(self): return [json.loads(v) for v in self.redis.hvals(self.name)] + @reinit_onerror def items(self): return [(k, json.loads(v)) for k, v in self.redis.hgetall(self.name).items()] @@ -71,6 +172,7 @@ class RedisDict: except KeyError: return default + @reinit_onerror def clear(self): self.redis.delete(self.name)