Merge pull request #2598 from open-webui/user-settings

feat: user settings to db
This commit is contained in:
Timothy Jaeryang Baek 2024-05-26 22:48:13 -07:00 committed by GitHub
commit dd959c8f88
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 235 additions and 70 deletions

View File

@ -198,7 +198,7 @@ async def fetch_url(url, key):
def merge_models_lists(model_lists):
log.info(f"merge_models_lists {model_lists}")
log.debug(f"merge_models_lists {model_lists}")
merged_list = []
for idx, models in enumerate(model_lists):
@ -237,7 +237,7 @@ async def get_all_models():
]
responses = await asyncio.gather(*tasks)
log.info(f"get_all_models:responses() {responses}")
log.debug(f"get_all_models:responses() {responses}")
models = {
"data": merge_models_lists(
@ -254,7 +254,7 @@ async def get_all_models():
)
}
log.info(f"models: {models}")
log.debug(f"models: {models}")
app.state.MODELS = {model["id"]: model for model in models["data"]}
return models

View File

@ -0,0 +1,48 @@
"""Peewee migrations -- 002_add_local_sharing.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
# Adding fields settings to the 'user' table
migrator.add_fields("user", settings=pw.TextField(null=True))
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
# Remove the settings field
migrator.remove_fields("user", "settings")

View File

@ -1,11 +1,11 @@
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
import time
from utils.misc import get_gravatar_url
from apps.webui.internal.db import DB
from apps.webui.internal.db import DB, JSONField
from apps.webui.models.chats import Chats
####################
@ -25,11 +25,18 @@ class User(Model):
created_at = BigIntegerField()
api_key = CharField(null=True, unique=True)
settings = JSONField(null=True)
class Meta:
database = DB
class UserSettings(BaseModel):
ui: Optional[dict] = {}
model_config = ConfigDict(extra="allow")
pass
class UserModel(BaseModel):
id: str
name: str
@ -42,6 +49,7 @@ class UserModel(BaseModel):
created_at: int # timestamp in epoch
api_key: Optional[str] = None
settings: Optional[UserSettings] = None
####################

View File

@ -9,7 +9,13 @@ import time
import uuid
import logging
from apps.webui.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users
from apps.webui.models.users import (
UserModel,
UserUpdateForm,
UserRoleUpdateForm,
UserSettings,
Users,
)
from apps.webui.models.auths import Auths
from apps.webui.models.chats import Chats
@ -68,6 +74,42 @@ async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin
)
############################
# GetUserSettingsBySessionUser
############################
@router.get("/user/settings", response_model=Optional[UserSettings])
async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
user = Users.get_user_by_id(user.id)
if user:
return user.settings
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.USER_NOT_FOUND,
)
############################
# UpdateUserSettingsBySessionUser
############################
@router.post("/user/settings/update", response_model=UserSettings)
async def update_user_settings_by_session_user(
form_data: UserSettings, user=Depends(get_verified_user)
):
user = Users.update_user_by_id(user.id, {"settings": form_data.model_dump()})
if user:
return user.settings
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.USER_NOT_FOUND,
)
############################
# GetUserById
############################
@ -81,6 +123,8 @@ class UserResponse(BaseModel):
@router.get("/{user_id}", response_model=UserResponse)
async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
# Check if user_id is a shared chat
# If it is, get the user_id from the chat
if user_id.startswith("shared-"):
chat_id = user_id.replace("shared-", "")
chat = Chats.get_chat_by_id(chat_id)

View File

@ -115,6 +115,62 @@ export const getUsers = async (token: string) => {
return res ? res : [];
};
export const getUserSettings = async (token: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/users/user/settings`, {
method: 'GET',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
error = err.detail;
return null;
});
if (error) {
throw error;
}
return res;
};
export const updateUserSettings = async (token: string, settings: object) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/users/user/settings/update`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
},
body: JSON.stringify({
...settings
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
error = err.detail;
return null;
});
if (error) {
throw error;
}
return res;
};
export const getUserById = async (token: string, userId: string) => {
let error = null;

View File

@ -42,6 +42,7 @@
import type { Writable } from 'svelte/store';
import type { i18n as i18nType } from 'i18next';
import Banner from '../common/Banner.svelte';
import { getUserSettings } from '$lib/apis/users';
const i18n: Writable<i18nType> = getContext('i18n');
@ -154,10 +155,13 @@
$models.map((m) => m.id).includes(modelId) ? modelId : ''
);
let _settings = JSON.parse(localStorage.getItem('settings') ?? '{}');
settings.set({
..._settings
});
const userSettings = await getUserSettings(localStorage.token);
if (userSettings) {
settings.set(userSettings.ui);
} else {
settings.set(JSON.parse(localStorage.getItem('settings') ?? '{}'));
}
const chatInput = document.getElementById('chat-textarea');
setTimeout(() => chatInput?.focus(), 0);
@ -187,11 +191,18 @@
: convertMessagesToHistory(chatContent.messages);
title = chatContent.title;
let _settings = JSON.parse(localStorage.getItem('settings') ?? '{}');
const userSettings = await getUserSettings(localStorage.token);
if (userSettings) {
await settings.set(userSettings.ui);
} else {
await settings.set(JSON.parse(localStorage.getItem('settings') ?? '{}'));
}
await settings.set({
..._settings,
system: chatContent.system ?? _settings.system,
params: chatContent.options ?? _settings.params
...$settings,
system: chatContent.system ?? $settings.system,
params: chatContent.options ?? $settings.params
});
autoScroll = true;
await tick();

View File

@ -1,13 +1,13 @@
<script lang="ts">
import { Collapsible } from 'bits-ui';
import { setDefaultModels } from '$lib/apis/configs';
import { models, showSettings, settings, user, mobile } from '$lib/stores';
import { onMount, tick, getContext } from 'svelte';
import { toast } from 'svelte-sonner';
import Selector from './ModelSelector/Selector.svelte';
import Tooltip from '../common/Tooltip.svelte';
import { setDefaultModels } from '$lib/apis/configs';
import { updateUserSettings } from '$lib/apis/users';
const i18n = getContext('i18n');
export let selectedModels = [''];
@ -22,7 +22,9 @@
return;
}
settings.set({ ...$settings, models: selectedModels });
localStorage.setItem('settings', JSON.stringify($settings));
await updateUserSettings(localStorage.token, { ui: $settings });
if ($user.role === 'admin') {
console.log('setting default models globally');

View File

@ -1,6 +1,6 @@
<script lang="ts">
import { getAudioConfig, updateAudioConfig } from '$lib/apis/audio';
import { user } from '$lib/stores';
import { user, settings } from '$lib/stores';
import { createEventDispatcher, onMount, getContext } from 'svelte';
import { toast } from 'svelte-sonner';
const dispatch = createEventDispatcher();
@ -99,16 +99,14 @@
};
onMount(async () => {
let settings = JSON.parse(localStorage.getItem('settings') ?? '{}');
conversationMode = $settings.conversationMode ?? false;
speechAutoSend = $settings.speechAutoSend ?? false;
responseAutoPlayback = $settings.responseAutoPlayback ?? false;
conversationMode = settings.conversationMode ?? false;
speechAutoSend = settings.speechAutoSend ?? false;
responseAutoPlayback = settings.responseAutoPlayback ?? false;
STTEngine = settings?.audio?.STTEngine ?? '';
TTSEngine = settings?.audio?.TTSEngine ?? '';
speaker = settings?.audio?.speaker ?? '';
model = settings?.audio?.model ?? '';
STTEngine = $settings?.audio?.STTEngine ?? '';
TTSEngine = $settings?.audio?.TTSEngine ?? '';
speaker = $settings?.audio?.speaker ?? '';
model = $settings?.audio?.model ?? '';
if (TTSEngine === 'openai') {
getOpenAIVoices();

View File

@ -2,7 +2,7 @@
import fileSaver from 'file-saver';
const { saveAs } = fileSaver;
import { chats, user, config } from '$lib/stores';
import { chats, user, settings } from '$lib/stores';
import {
archiveAllChats,
@ -99,9 +99,7 @@
};
onMount(async () => {
let settings = JSON.parse(localStorage.getItem('settings') ?? '{}');
saveChatHistory = settings.saveChatHistory ?? true;
saveChatHistory = $settings.saveChatHistory ?? true;
});
</script>

View File

@ -4,7 +4,7 @@
import { getLanguages } from '$lib/i18n';
const dispatch = createEventDispatcher();
import { models, user, theme } from '$lib/stores';
import { models, settings, theme } from '$lib/stores';
const i18n = getContext('i18n');
@ -71,23 +71,22 @@
onMount(async () => {
selectedTheme = localStorage.theme ?? 'system';
let settings = JSON.parse(localStorage.getItem('settings') ?? '{}');
languages = await getLanguages();
notificationEnabled = settings.notificationEnabled ?? false;
system = settings.system ?? '';
notificationEnabled = $settings.notificationEnabled ?? false;
system = $settings.system ?? '';
requestFormat = settings.requestFormat ?? '';
keepAlive = settings.keepAlive ?? null;
requestFormat = $settings.requestFormat ?? '';
keepAlive = $settings.keepAlive ?? null;
params.seed = settings.seed ?? 0;
params.temperature = settings.temperature ?? '';
params.frequency_penalty = settings.frequency_penalty ?? '';
params.top_k = settings.top_k ?? '';
params.top_p = settings.top_p ?? '';
params.num_ctx = settings.num_ctx ?? '';
params = { ...params, ...settings.params };
params.stop = settings?.params?.stop ? (settings?.params?.stop ?? []).join(',') : null;
params.seed = $settings.seed ?? 0;
params.temperature = $settings.temperature ?? '';
params.frequency_penalty = $settings.frequency_penalty ?? '';
params.top_k = $settings.top_k ?? '';
params.top_p = $settings.top_p ?? '';
params.num_ctx = $settings.num_ctx ?? '';
params = { ...params, ...$settings.params };
params.stop = $settings?.params?.stop ? ($settings?.params?.stop ?? []).join(',') : null;
});
const applyTheme = (_theme: string) => {

View File

@ -104,23 +104,18 @@
promptSuggestions = $config?.default_prompt_suggestions;
}
let settings = JSON.parse(localStorage.getItem('settings') ?? '{}');
titleAutoGenerate = settings?.title?.auto ?? true;
titleAutoGenerateModel = settings?.title?.model ?? '';
titleAutoGenerateModelExternal = settings?.title?.modelExternal ?? '';
titleAutoGenerate = $settings?.title?.auto ?? true;
titleAutoGenerateModel = $settings?.title?.model ?? '';
titleAutoGenerateModelExternal = $settings?.title?.modelExternal ?? '';
titleGenerationPrompt =
settings?.title?.prompt ??
$i18n.t(
"Create a concise, 3-5 word phrase as a header for the following query, strictly adhering to the 3-5 word limit and avoiding the use of the word 'title':"
) + ' {{prompt}}';
responseAutoCopy = settings.responseAutoCopy ?? false;
showUsername = settings.showUsername ?? false;
chatBubble = settings.chatBubble ?? true;
fullScreenMode = settings.fullScreenMode ?? false;
splitLargeChunks = settings.splitLargeChunks ?? false;
chatDirection = settings.chatDirection ?? 'LTR';
$settings?.title?.prompt ??
`Create a concise, 3-5 word phrase as a header for the following query, strictly adhering to the 3-5 word limit and avoiding the use of the word 'title': {{prompt}}`;
responseAutoCopy = $settings.responseAutoCopy ?? false;
showUsername = $settings.showUsername ?? false;
chatBubble = $settings.chatBubble ?? true;
fullScreenMode = $settings.fullScreenMode ?? false;
splitLargeChunks = $settings.splitLargeChunks ?? false;
chatDirection = $settings.chatDirection ?? 'LTR';
});
</script>

View File

@ -19,8 +19,7 @@
let enableMemory = false;
onMount(async () => {
let settings = JSON.parse(localStorage.getItem('settings') ?? '{}');
enableMemory = settings?.memory ?? false;
enableMemory = $settings?.memory ?? false;
});
</script>

View File

@ -17,6 +17,7 @@
import Images from './Settings/Images.svelte';
import User from '../icons/User.svelte';
import Personalization from './Settings/Personalization.svelte';
import { updateUserSettings } from '$lib/apis/users';
const i18n = getContext('i18n');
@ -26,7 +27,9 @@
console.log(updated);
await settings.set({ ...$settings, ...updated });
await models.set(await getModels());
localStorage.setItem('settings', JSON.stringify($settings));
await updateUserSettings(localStorage.token, { ui: $settings });
};
const getModels = async () => {

View File

@ -33,6 +33,7 @@
import ArchiveBox from '../icons/ArchiveBox.svelte';
import ArchivedChatsModal from './Sidebar/ArchivedChatsModal.svelte';
import UserMenu from './Sidebar/UserMenu.svelte';
import { updateUserSettings } from '$lib/apis/users';
const BREAKPOINT = 768;
@ -184,6 +185,8 @@
const saveSettings = async (updated) => {
await settings.set({ ...$settings, ...updated });
localStorage.setItem('settings', JSON.stringify($settings));
await updateUserSettings(localStorage.token, { ui: $settings });
location.href = '/';
};

View File

@ -35,6 +35,7 @@
import ChangelogModal from '$lib/components/ChangelogModal.svelte';
import Tooltip from '$lib/components/common/Tooltip.svelte';
import { getBanners } from '$lib/apis/configs';
import { getUserSettings } from '$lib/apis/users';
const i18n = getContext('i18n');
@ -72,7 +73,13 @@
// IndexedDB Not Found
}
settings.set(JSON.parse(localStorage.getItem('settings') ?? '{}'));
const userSettings = await getUserSettings(localStorage.token);
if (userSettings) {
await settings.set(userSettings.ui);
} else {
await settings.set(JSON.parse(localStorage.getItem('settings') ?? '{}'));
}
await Promise.all([
(async () => {

View File

@ -98,12 +98,6 @@
: convertMessagesToHistory(chatContent.messages);
title = chatContent.title;
let _settings = JSON.parse(localStorage.getItem('settings') ?? '{}');
await settings.set({
..._settings,
system: chatContent.system ?? _settings.system,
options: chatContent.options ?? _settings.options
});
autoScroll = true;
await tick();