mirror of https://github.com/open-webui/open-webui
Merge 4a15b61f6d
into 0e7164b4f5
This commit is contained in:
commit
fa6d4fd531
|
@ -387,6 +387,12 @@ WEBSOCKET_MANAGER = os.environ.get("WEBSOCKET_MANAGER", "")
|
|||
WEBSOCKET_REDIS_URL = os.environ.get("WEBSOCKET_REDIS_URL", REDIS_URL)
|
||||
WEBSOCKET_REDIS_LOCK_TIMEOUT = os.environ.get("WEBSOCKET_REDIS_LOCK_TIMEOUT", 60)
|
||||
|
||||
WEBSOCKET_REDIS_CERTS = os.environ.get("WEBSOCKET_REDIS_CERTS", "")
|
||||
|
||||
WEBSOCKET_REDIS_USERNAME = os.environ.get("WEBSOCKET_REDIS_USERNAME", "")
|
||||
|
||||
WEBSOCKET_REDIS_PASSWORD = os.environ.get("WEBSOCKET_REDIS_PASSWORD", "")
|
||||
|
||||
AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "")
|
||||
|
||||
if AIOHTTP_CLIENT_TIMEOUT == "":
|
||||
|
@ -442,3 +448,9 @@ AUDIT_EXCLUDED_PATHS = os.getenv("AUDIT_EXCLUDED_PATHS", "/chats,/chat,/folders"
|
|||
)
|
||||
AUDIT_EXCLUDED_PATHS = [path.strip() for path in AUDIT_EXCLUDED_PATHS]
|
||||
AUDIT_EXCLUDED_PATHS = [path.lstrip("/") for path in AUDIT_EXCLUDED_PATHS]
|
||||
|
||||
####################################
|
||||
# WEBSOCKET_REDIS_AZURE_CREDENTIALS
|
||||
####################################
|
||||
|
||||
WEBSOCKET_REDIS_CREDENTIALS = os.environ.get("WEBSOCKET_REDIS_CREDENTIALS", "").lower()
|
||||
|
|
|
@ -13,9 +13,14 @@ from open_webui.env import (
|
|||
WEBSOCKET_MANAGER,
|
||||
WEBSOCKET_REDIS_URL,
|
||||
WEBSOCKET_REDIS_LOCK_TIMEOUT,
|
||||
WEBSOCKET_REDIS_CERTS,
|
||||
WEBSOCKET_REDIS_USERNAME,
|
||||
WEBSOCKET_REDIS_PASSWORD,
|
||||
WEBSOCKET_REDIS_CREDENTIALS,
|
||||
)
|
||||
from open_webui.utils.auth import decode_token
|
||||
from open_webui.socket.utils import RedisDict, RedisLock
|
||||
from open_webui.utils.azure_services import AzureCredentialService
|
||||
|
||||
from open_webui.env import (
|
||||
GLOBAL_LOG_LEVEL,
|
||||
|
@ -27,9 +32,40 @@ logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
|||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["SOCKET"])
|
||||
|
||||
redis_options = {}
|
||||
|
||||
if WEBSOCKET_REDIS_CREDENTIALS == 'azure':
|
||||
azure_credential_service = AzureCredentialService()
|
||||
redis_options["password"] = azure_credential_service.get_token()
|
||||
redis_options["username"] = azure_credential_service.get_username(redis_options["password"])
|
||||
elif not WEBSOCKET_REDIS_PASSWORD:
|
||||
redis_options["password"] = WEBSOCKET_REDIS_PASSWORD
|
||||
redis_options["username"] = WEBSOCKET_REDIS_USERNAME
|
||||
if WEBSOCKET_REDIS_URL.startswith("rediss") and WEBSOCKET_REDIS_CERTS:
|
||||
redis_options["ssl_ca_certs"] = WEBSOCKET_REDIS_CERTS
|
||||
|
||||
# Retrieves and configures the Redis manager for Socket.IO with authentication and SSL if configured
|
||||
def get_redis_manager():
|
||||
try:
|
||||
mgr = socketio.AsyncRedisManager(
|
||||
WEBSOCKET_REDIS_URL,
|
||||
redis_options=redis_options
|
||||
)
|
||||
return mgr
|
||||
except ConnectionError as e:
|
||||
log.exception(f"Could not connect to Redis: {e}")
|
||||
raise e
|
||||
|
||||
def refresh_azure_credentials():
|
||||
if WEBSOCKET_REDIS_CREDENTIALS == 'azure':
|
||||
if azure_credential_service.is_expired():
|
||||
redis_options["password"] = azure_credential_service.get_token()
|
||||
|
||||
if WEBSOCKET_MANAGER == "redis":
|
||||
mgr = socketio.AsyncRedisManager(WEBSOCKET_REDIS_URL)
|
||||
mgr = get_redis_manager()
|
||||
if not mgr:
|
||||
log.error("Could not connect to Redis. Exiting.")
|
||||
sys.exit(1)
|
||||
sio = socketio.AsyncServer(
|
||||
cors_allowed_origins=[],
|
||||
async_mode="asgi",
|
||||
|
@ -47,7 +83,6 @@ else:
|
|||
always_connect=True,
|
||||
)
|
||||
|
||||
|
||||
# Timeout duration in seconds
|
||||
TIMEOUT_DURATION = 3
|
||||
|
||||
|
@ -55,14 +90,27 @@ TIMEOUT_DURATION = 3
|
|||
|
||||
if WEBSOCKET_MANAGER == "redis":
|
||||
log.debug("Using Redis to manage websockets.")
|
||||
SESSION_POOL = RedisDict("open-webui:session_pool", redis_url=WEBSOCKET_REDIS_URL)
|
||||
USER_POOL = RedisDict("open-webui:user_pool", redis_url=WEBSOCKET_REDIS_URL)
|
||||
USAGE_POOL = RedisDict("open-webui:usage_pool", redis_url=WEBSOCKET_REDIS_URL)
|
||||
|
||||
SESSION_POOL = RedisDict(
|
||||
"open-webui:session_pool",
|
||||
redis_url=WEBSOCKET_REDIS_URL,
|
||||
redis_options=redis_options
|
||||
)
|
||||
USER_POOL = RedisDict(
|
||||
"open-webui:user_pool",
|
||||
redis_url=WEBSOCKET_REDIS_URL,
|
||||
redis_options=redis_options
|
||||
)
|
||||
USAGE_POOL = RedisDict(
|
||||
"open-webui:usage_pool",
|
||||
redis_url=WEBSOCKET_REDIS_URL,
|
||||
redis_options=redis_options
|
||||
)
|
||||
clean_up_lock = RedisLock(
|
||||
redis_url=WEBSOCKET_REDIS_URL,
|
||||
lock_name="usage_cleanup_lock",
|
||||
timeout_secs=WEBSOCKET_REDIS_LOCK_TIMEOUT,
|
||||
redis_options=redis_options
|
||||
)
|
||||
aquire_func = clean_up_lock.aquire_lock
|
||||
renew_func = clean_up_lock.renew_lock
|
||||
|
@ -85,6 +133,10 @@ async def periodic_usage_pool_cleanup():
|
|||
log.error(f"Unable to renew cleanup lock. Exiting usage pool cleanup.")
|
||||
raise Exception("Unable to renew usage pool cleanup lock.")
|
||||
|
||||
if WEBSOCKET_MANAGER == "redis":
|
||||
# update redis options to refresh auth token
|
||||
refresh_azure_credentials()
|
||||
|
||||
now = int(time.time())
|
||||
send_usage = False
|
||||
for model_id, connections in list(USAGE_POOL.items()):
|
||||
|
|
|
@ -1,16 +1,106 @@
|
|||
import json
|
||||
import redis
|
||||
import uuid
|
||||
import logging
|
||||
|
||||
from open_webui.utils.azure_services import AzureCredentialService
|
||||
|
||||
from open_webui.env import (
|
||||
WEBSOCKET_REDIS_CREDENTIALS,
|
||||
SRC_LOG_LEVELS
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["SOCKET"])
|
||||
|
||||
|
||||
class RedisService:
|
||||
|
||||
def __init__(self, redis_url, redis_options={}):
|
||||
self.redis_url = redis_url
|
||||
self.client = None
|
||||
self.ssl_ca_certs = redis_options.get("ssl_ca_certs", None)
|
||||
self.username = redis_options.get("username", None)
|
||||
self.password = redis_options.get("password", None)
|
||||
self.azure_credential_service = AzureCredentialService() if WEBSOCKET_REDIS_CREDENTIALS == 'azure' else None
|
||||
self.init_redis()
|
||||
|
||||
def init_redis(self):
|
||||
token = None
|
||||
if self.azure_credential_service:
|
||||
token = self.azure_credential_service.get_token()
|
||||
self.username = self.azure_credential_service.get_username(token)
|
||||
else:
|
||||
token = self.password
|
||||
try:
|
||||
log.debug(f"redis_url: {self.redis_url}")
|
||||
parameters = {
|
||||
"url": self.redis_url,
|
||||
"decode_responses": True,
|
||||
"socket_timeout": 5,
|
||||
}
|
||||
if self.username and token:
|
||||
log.debug(f"redis_username: {self.username}")
|
||||
masked_password = f"{token[:3]}***{token[-3:]}" if token else None
|
||||
log.debug(f"redis_password: {masked_password}")
|
||||
parameters["username"] = self.username
|
||||
parameters["password"] = token
|
||||
if self.redis_url.startswith("rediss://"):
|
||||
log.debug(f"redis_ssl_ca_certs: {self.ssl_ca_certs}")
|
||||
parameters["ssl_ca_certs"] = self.ssl_ca_certs
|
||||
self.client = redis.Redis.from_url(**parameters)
|
||||
|
||||
if self.client.ping():
|
||||
log.info(f"Connected to Redis: {self.redis_url}")
|
||||
else:
|
||||
log.error(f"Failed to connect to Redis: {self.redis_url}")
|
||||
raise e
|
||||
|
||||
except ConnectionError as e:
|
||||
log.error(f"Failed to connect to Redis: {self.redis_url} {e}")
|
||||
raise e
|
||||
except TimeoutError as e:
|
||||
log.error(f"Timed out connecting to Redis: {self.redis_url} {e}")
|
||||
raise e
|
||||
except redis.AuthenticationError as e:
|
||||
log.error(f"Authentication failed connecting to Redis: {self.redis_url} {e}")
|
||||
raise e
|
||||
except Exception as e:
|
||||
log.error(f"Failed to connect to Redis: {self.redis_url} {e}")
|
||||
raise e
|
||||
|
||||
def get_client(self):
|
||||
return self.client
|
||||
|
||||
# reinitialize the redis connection if an exception occurs
|
||||
def reinit_onerror(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
# Get the instance of the class
|
||||
instance = args[0]
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
log.error(f'{instance.__class__}.{func.__name__}: {e}')
|
||||
log.warning(f"Re-authenticate and initialize Redis Cache connection")
|
||||
instance.init_redis()
|
||||
return func(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
class RedisLock:
|
||||
def __init__(self, redis_url, lock_name, timeout_secs):
|
||||
def __init__(self, redis_url, lock_name, timeout_secs, **redis_kwargs):
|
||||
self.lock_name = lock_name
|
||||
self.lock_id = str(uuid.uuid4())
|
||||
self.timeout_secs = timeout_secs
|
||||
self.lock_obtained = False
|
||||
self.redis = redis.Redis.from_url(redis_url, decode_responses=True)
|
||||
self.redis_url = redis_url
|
||||
self.redis_kwargs = redis_kwargs
|
||||
self.init_redis()
|
||||
|
||||
|
||||
def init_redis(self):
|
||||
self.redis = RedisService(self.redis_url, **self.redis_kwargs).get_client()
|
||||
|
||||
@reinit_onerror
|
||||
def aquire_lock(self):
|
||||
# nx=True will only set this key if it _hasn't_ already been set
|
||||
self.lock_obtained = self.redis.set(
|
||||
|
@ -18,12 +108,14 @@ class RedisLock:
|
|||
)
|
||||
return self.lock_obtained
|
||||
|
||||
@reinit_onerror
|
||||
def renew_lock(self):
|
||||
# xx=True will only set this key if it _has_ already been set
|
||||
return self.redis.set(
|
||||
self.lock_name, self.lock_id, xx=True, ex=self.timeout_secs
|
||||
)
|
||||
|
||||
@reinit_onerror
|
||||
def release_lock(self):
|
||||
lock_value = self.redis.get(self.lock_name)
|
||||
if lock_value and lock_value == self.lock_id:
|
||||
|
@ -31,37 +123,46 @@ class RedisLock:
|
|||
|
||||
|
||||
class RedisDict:
|
||||
def __init__(self, name, redis_url):
|
||||
def __init__(self, name, redis_url, **redis_kwargs):
|
||||
self.name = name
|
||||
self.redis = redis.Redis.from_url(redis_url, decode_responses=True)
|
||||
self.redis_service = RedisService(redis_url, **redis_kwargs)
|
||||
self.redis = self.redis_service.get_client()
|
||||
|
||||
@reinit_onerror
|
||||
def __setitem__(self, key, value):
|
||||
serialized_value = json.dumps(value)
|
||||
self.redis.hset(self.name, key, serialized_value)
|
||||
|
||||
@reinit_onerror
|
||||
def __getitem__(self, key):
|
||||
value = self.redis.hget(self.name, key)
|
||||
if value is None:
|
||||
raise KeyError(key)
|
||||
return json.loads(value)
|
||||
|
||||
@reinit_onerror
|
||||
def __delitem__(self, key):
|
||||
result = self.redis.hdel(self.name, key)
|
||||
if result == 0:
|
||||
raise KeyError(key)
|
||||
|
||||
@reinit_onerror
|
||||
def __contains__(self, key):
|
||||
return self.redis.hexists(self.name, key)
|
||||
|
||||
@reinit_onerror
|
||||
def __len__(self):
|
||||
return self.redis.hlen(self.name)
|
||||
|
||||
@reinit_onerror
|
||||
def keys(self):
|
||||
return self.redis.hkeys(self.name)
|
||||
|
||||
@reinit_onerror
|
||||
def values(self):
|
||||
return [json.loads(v) for v in self.redis.hvals(self.name)]
|
||||
|
||||
@reinit_onerror
|
||||
def items(self):
|
||||
return [(k, json.loads(v)) for k, v in self.redis.hgetall(self.name).items()]
|
||||
|
||||
|
@ -71,6 +172,7 @@ class RedisDict:
|
|||
except KeyError:
|
||||
return default
|
||||
|
||||
@reinit_onerror
|
||||
def clear(self):
|
||||
self.redis.delete(self.name)
|
||||
|
||||
|
|
Loading…
Reference in New Issue