mirror of https://github.com/open-webui/open-webui
feat(sqlalchemy): format backend
This commit is contained in:
parent
320e658595
commit
c134eab27a
|
@ -53,7 +53,9 @@ if "sqlite" in SQLALCHEMY_DATABASE_URL:
|
|||
)
|
||||
else:
|
||||
engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, expire_on_commit=False)
|
||||
SessionLocal = sessionmaker(
|
||||
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
|
||||
)
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
|
@ -66,4 +68,3 @@ def get_session():
|
|||
except Exception as e:
|
||||
db.rollback()
|
||||
raise e
|
||||
|
||||
|
|
|
@ -126,9 +126,7 @@ class AuthsTable:
|
|||
else:
|
||||
return None
|
||||
|
||||
def authenticate_user(
|
||||
self, email: str, password: str
|
||||
) -> Optional[UserModel]:
|
||||
def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
|
||||
log.info(f"authenticate_user: {email}")
|
||||
with get_session() as db:
|
||||
try:
|
||||
|
@ -144,9 +142,7 @@ class AuthsTable:
|
|||
except:
|
||||
return None
|
||||
|
||||
def authenticate_user_by_api_key(
|
||||
self, api_key: str
|
||||
) -> Optional[UserModel]:
|
||||
def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
|
||||
log.info(f"authenticate_user_by_api_key: {api_key}")
|
||||
with get_session() as db:
|
||||
# if no api_key, return None
|
||||
|
@ -159,9 +155,7 @@ class AuthsTable:
|
|||
except:
|
||||
return False
|
||||
|
||||
def authenticate_user_by_trusted_header(
|
||||
self, email: str
|
||||
) -> Optional[UserModel]:
|
||||
def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]:
|
||||
log.info(f"authenticate_user_by_trusted_header: {email}")
|
||||
with get_session() as db:
|
||||
try:
|
||||
|
@ -172,12 +166,12 @@ class AuthsTable:
|
|||
except:
|
||||
return None
|
||||
|
||||
def update_user_password_by_id(
|
||||
self, id: str, new_password: str
|
||||
) -> bool:
|
||||
def update_user_password_by_id(self, id: str, new_password: str) -> bool:
|
||||
with get_session() as db:
|
||||
try:
|
||||
result = db.query(Auth).filter_by(id=id).update({"password": new_password})
|
||||
result = (
|
||||
db.query(Auth).filter_by(id=id).update({"password": new_password})
|
||||
)
|
||||
return True if result == 1 else False
|
||||
except:
|
||||
return False
|
||||
|
|
|
@ -79,9 +79,7 @@ class ChatTitleIdResponse(BaseModel):
|
|||
|
||||
class ChatTable:
|
||||
|
||||
def insert_new_chat(
|
||||
self, user_id: str, form_data: ChatForm
|
||||
) -> Optional[ChatModel]:
|
||||
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
|
||||
with get_session() as db:
|
||||
id = str(uuid.uuid4())
|
||||
chat = ChatModel(
|
||||
|
@ -89,7 +87,9 @@ class ChatTable:
|
|||
"id": id,
|
||||
"user_id": user_id,
|
||||
"title": (
|
||||
form_data.chat["title"] if "title" in form_data.chat else "New Chat"
|
||||
form_data.chat["title"]
|
||||
if "title" in form_data.chat
|
||||
else "New Chat"
|
||||
),
|
||||
"chat": json.dumps(form_data.chat),
|
||||
"created_at": int(time.time()),
|
||||
|
@ -103,9 +103,7 @@ class ChatTable:
|
|||
db.refresh(result)
|
||||
return ChatModel.model_validate(result) if result else None
|
||||
|
||||
def update_chat_by_id(
|
||||
self, id: str, chat: dict
|
||||
) -> Optional[ChatModel]:
|
||||
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
|
||||
with get_session() as db:
|
||||
try:
|
||||
chat_obj = db.get(Chat, id)
|
||||
|
@ -119,9 +117,7 @@ class ChatTable:
|
|||
except Exception as e:
|
||||
return None
|
||||
|
||||
def insert_shared_chat_by_chat_id(
|
||||
self, chat_id: str
|
||||
) -> Optional[ChatModel]:
|
||||
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
|
||||
with get_session() as db:
|
||||
# Get the existing chat to share
|
||||
chat = db.get(Chat, chat_id)
|
||||
|
@ -145,14 +141,14 @@ class ChatTable:
|
|||
db.refresh(shared_result)
|
||||
# Update the original chat with the share_id
|
||||
result = (
|
||||
db.query(Chat).filter_by(id=chat_id).update({"share_id": shared_chat.id})
|
||||
db.query(Chat)
|
||||
.filter_by(id=chat_id)
|
||||
.update({"share_id": shared_chat.id})
|
||||
)
|
||||
|
||||
return shared_chat if (shared_result and result) else None
|
||||
|
||||
def update_shared_chat_by_chat_id(
|
||||
self, chat_id: str
|
||||
) -> Optional[ChatModel]:
|
||||
def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
|
||||
with get_session() as db:
|
||||
try:
|
||||
print("update_shared_chat_by_id")
|
||||
|
@ -271,9 +267,7 @@ class ChatTable:
|
|||
except Exception as e:
|
||||
return None
|
||||
|
||||
def get_chat_by_id_and_user_id(
|
||||
self, id: str, user_id: str
|
||||
) -> Optional[ChatModel]:
|
||||
def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
|
||||
try:
|
||||
with get_session() as db:
|
||||
chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
|
||||
|
@ -293,13 +287,13 @@ class ChatTable:
|
|||
def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
|
||||
with get_session() as db:
|
||||
all_chats = (
|
||||
db.query(Chat).filter_by(user_id=user_id).order_by(Chat.updated_at.desc())
|
||||
db.query(Chat)
|
||||
.filter_by(user_id=user_id)
|
||||
.order_by(Chat.updated_at.desc())
|
||||
)
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
def get_archived_chats_by_user_id(
|
||||
self, user_id: str
|
||||
) -> List[ChatModel]:
|
||||
def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
|
||||
with get_session() as db:
|
||||
all_chats = (
|
||||
db.query(Chat)
|
||||
|
|
|
@ -106,7 +106,9 @@ class DocumentsTable:
|
|||
|
||||
def get_docs(self) -> List[DocumentModel]:
|
||||
with get_session() as db:
|
||||
return [DocumentModel.model_validate(doc) for doc in db.query(Document).all()]
|
||||
return [
|
||||
DocumentModel.model_validate(doc) for doc in db.query(Document).all()
|
||||
]
|
||||
|
||||
def update_doc_by_name(
|
||||
self, name: str, form_data: DocumentUpdateForm
|
||||
|
|
|
@ -39,6 +39,7 @@ class FileModel(BaseModel):
|
|||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
|
|
@ -142,9 +142,9 @@ class FunctionsTable:
|
|||
with get_session() as db:
|
||||
return [
|
||||
FunctionModel.model_validate(function)
|
||||
for function in db.query(Function).filter_by(
|
||||
type=type, is_active=True
|
||||
).all()
|
||||
for function in db.query(Function)
|
||||
.filter_by(type=type, is_active=True)
|
||||
.all()
|
||||
]
|
||||
else:
|
||||
with get_session() as db:
|
||||
|
@ -220,10 +220,12 @@ class FunctionsTable:
|
|||
def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
|
||||
try:
|
||||
with get_session() as db:
|
||||
db.query(Function).filter_by(id=id).update({
|
||||
**updated,
|
||||
"updated_at": int(time.time()),
|
||||
})
|
||||
db.query(Function).filter_by(id=id).update(
|
||||
{
|
||||
**updated,
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
db.commit()
|
||||
return self.get_function_by_id(id)
|
||||
except:
|
||||
|
@ -232,10 +234,12 @@ class FunctionsTable:
|
|||
def deactivate_all_functions(self) -> Optional[bool]:
|
||||
try:
|
||||
with get_session() as db:
|
||||
db.query(Function).update({
|
||||
"is_active": False,
|
||||
"updated_at": int(time.time()),
|
||||
})
|
||||
db.query(Function).update(
|
||||
{
|
||||
"is_active": False,
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
db.commit()
|
||||
return True
|
||||
except:
|
||||
|
|
|
@ -153,9 +153,7 @@ class ModelsTable:
|
|||
except:
|
||||
return None
|
||||
|
||||
def update_model_by_id(
|
||||
self, id: str, model: ModelForm
|
||||
) -> Optional[ModelModel]:
|
||||
def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
|
||||
try:
|
||||
# update only the fields that are present in the model
|
||||
with get_session() as db:
|
||||
|
|
|
@ -83,7 +83,9 @@ class PromptsTable:
|
|||
|
||||
def get_prompts(self) -> List[PromptModel]:
|
||||
with get_session() as db:
|
||||
return [PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()]
|
||||
return [
|
||||
PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()
|
||||
]
|
||||
|
||||
def update_prompt_by_command(
|
||||
self, command: str, form_data: PromptForm
|
||||
|
|
|
@ -79,9 +79,7 @@ class ChatTagsResponse(BaseModel):
|
|||
|
||||
class TagTable:
|
||||
|
||||
def insert_new_tag(
|
||||
self, name: str, user_id: str
|
||||
) -> Optional[TagModel]:
|
||||
def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
|
||||
id = str(uuid.uuid4())
|
||||
tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
|
||||
try:
|
||||
|
@ -201,11 +199,13 @@ class TagTable:
|
|||
self, tag_name: str, user_id: str
|
||||
) -> int:
|
||||
with get_session() as db:
|
||||
return db.query(ChatIdTag).filter_by(tag_name=tag_name, user_id=user_id).count()
|
||||
return (
|
||||
db.query(ChatIdTag)
|
||||
.filter_by(tag_name=tag_name, user_id=user_id)
|
||||
.count()
|
||||
)
|
||||
|
||||
def delete_tag_by_tag_name_and_user_id(
|
||||
self, tag_name: str, user_id: str
|
||||
) -> bool:
|
||||
def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool:
|
||||
try:
|
||||
with get_session() as db:
|
||||
res = (
|
||||
|
@ -252,9 +252,7 @@ class TagTable:
|
|||
log.error(f"delete_tag: {e}")
|
||||
return False
|
||||
|
||||
def delete_tags_by_chat_id_and_user_id(
|
||||
self, chat_id: str, user_id: str
|
||||
) -> bool:
|
||||
def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool:
|
||||
tags = self.get_tags_by_chat_id_and_user_id(chat_id, user_id)
|
||||
|
||||
for tag in tags:
|
||||
|
|
|
@ -165,9 +165,7 @@ class UsersTable:
|
|||
except:
|
||||
return None
|
||||
|
||||
def update_user_role_by_id(
|
||||
self, id: str, role: str
|
||||
) -> Optional[UserModel]:
|
||||
def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
|
||||
with get_session() as db:
|
||||
try:
|
||||
db.query(User).filter_by(id=id).update({"role": role})
|
||||
|
@ -193,12 +191,12 @@ class UsersTable:
|
|||
except:
|
||||
return None
|
||||
|
||||
def update_user_last_active_by_id(
|
||||
self, id: str
|
||||
) -> Optional[UserModel]:
|
||||
def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
|
||||
with get_session() as db:
|
||||
try:
|
||||
db.query(User).filter_by(id=id).update({"last_active_at": int(time.time())})
|
||||
db.query(User).filter_by(id=id).update(
|
||||
{"last_active_at": int(time.time())}
|
||||
)
|
||||
|
||||
user = db.query(User).filter_by(id=id).first()
|
||||
return UserModel.model_validate(user)
|
||||
|
@ -217,9 +215,7 @@ class UsersTable:
|
|||
except:
|
||||
return None
|
||||
|
||||
def update_user_by_id(
|
||||
self, id: str, updated: dict
|
||||
) -> Optional[UserModel]:
|
||||
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
|
||||
with get_session() as db:
|
||||
try:
|
||||
db.query(User).filter_by(id=id).update(updated)
|
||||
|
|
|
@ -78,8 +78,7 @@ async def get_session_user(
|
|||
|
||||
@router.post("/update/profile", response_model=UserResponse)
|
||||
async def update_profile(
|
||||
form_data: UpdateProfileForm,
|
||||
session_user=Depends(get_current_user)
|
||||
form_data: UpdateProfileForm, session_user=Depends(get_current_user)
|
||||
):
|
||||
if session_user:
|
||||
user = Users.update_user_by_id(
|
||||
|
@ -101,8 +100,7 @@ async def update_profile(
|
|||
|
||||
@router.post("/update/password", response_model=bool)
|
||||
async def update_password(
|
||||
form_data: UpdatePasswordForm,
|
||||
session_user=Depends(get_current_user)
|
||||
form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
|
||||
):
|
||||
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
|
||||
|
@ -269,9 +267,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
|||
|
||||
|
||||
@router.post("/add", response_model=SigninResponse)
|
||||
async def add_user(
|
||||
form_data: AddUserForm, user=Depends(get_admin_user)
|
||||
):
|
||||
async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
|
||||
|
||||
if not validate_email_format(form_data.email.lower()):
|
||||
raise HTTPException(
|
||||
|
@ -316,9 +312,7 @@ async def add_user(
|
|||
|
||||
|
||||
@router.get("/admin/details")
|
||||
async def get_admin_details(
|
||||
request: Request, user=Depends(get_current_user)
|
||||
):
|
||||
async def get_admin_details(request: Request, user=Depends(get_current_user)):
|
||||
if request.app.state.config.SHOW_ADMIN_DETAILS:
|
||||
admin_email = request.app.state.config.ADMIN_EMAIL
|
||||
admin_name = None
|
||||
|
|
|
@ -55,9 +55,7 @@ async def get_session_user_chat_list(
|
|||
|
||||
|
||||
@router.delete("/", response_model=bool)
|
||||
async def delete_all_user_chats(
|
||||
request: Request, user=Depends(get_current_user)
|
||||
):
|
||||
async def delete_all_user_chats(request: Request, user=Depends(get_current_user)):
|
||||
|
||||
if (
|
||||
user.role == "user"
|
||||
|
@ -95,9 +93,7 @@ async def get_user_chat_list_by_user_id(
|
|||
|
||||
|
||||
@router.post("/new", response_model=Optional[ChatResponse])
|
||||
async def create_new_chat(
|
||||
form_data: ChatForm, user=Depends(get_current_user)
|
||||
):
|
||||
async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
|
||||
try:
|
||||
chat = Chats.insert_new_chat(user.id, form_data)
|
||||
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
|
||||
|
@ -180,9 +176,7 @@ async def archive_all_chats(user=Depends(get_current_user)):
|
|||
|
||||
|
||||
@router.get("/share/{share_id}", response_model=Optional[ChatResponse])
|
||||
async def get_shared_chat_by_id(
|
||||
share_id: str, user=Depends(get_current_user)
|
||||
):
|
||||
async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)):
|
||||
if user.role == "pending":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
||||
|
@ -225,9 +219,7 @@ async def get_user_chat_list_by_tag_name(
|
|||
)
|
||||
]
|
||||
|
||||
chats = Chats.get_chat_list_by_chat_ids(
|
||||
chat_ids, form_data.skip, form_data.limit
|
||||
)
|
||||
chats = Chats.get_chat_list_by_chat_ids(chat_ids, form_data.skip, form_data.limit)
|
||||
|
||||
if len(chats) == 0:
|
||||
Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id)
|
||||
|
@ -297,9 +289,7 @@ async def update_chat_by_id(
|
|||
|
||||
|
||||
@router.delete("/{id}", response_model=bool)
|
||||
async def delete_chat_by_id(
|
||||
request: Request, id: str, user=Depends(get_current_user)
|
||||
):
|
||||
async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_user)):
|
||||
|
||||
if user.role == "admin":
|
||||
result = Chats.delete_chat_by_id(id)
|
||||
|
@ -347,9 +337,7 @@ async def clone_chat_by_id(id: str, user=Depends(get_current_user)):
|
|||
|
||||
|
||||
@router.get("/{id}/archive", response_model=Optional[ChatResponse])
|
||||
async def archive_chat_by_id(
|
||||
id: str, user=Depends(get_current_user)
|
||||
):
|
||||
async def archive_chat_by_id(id: str, user=Depends(get_current_user)):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
chat = Chats.toggle_chat_archive_by_id(id)
|
||||
|
@ -398,9 +386,7 @@ async def share_chat_by_id(id: str, user=Depends(get_current_user)):
|
|||
|
||||
|
||||
@router.delete("/{id}/share", response_model=Optional[bool])
|
||||
async def delete_shared_chat_by_id(
|
||||
id: str, user=Depends(get_current_user)
|
||||
):
|
||||
async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)):
|
||||
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
|
||||
if chat:
|
||||
if not chat.share_id:
|
||||
|
@ -423,9 +409,7 @@ async def delete_shared_chat_by_id(
|
|||
|
||||
|
||||
@router.get("/{id}/tags", response_model=List[TagModel])
|
||||
async def get_chat_tags_by_id(
|
||||
id: str, user=Depends(get_current_user)
|
||||
):
|
||||
async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)):
|
||||
tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
|
||||
|
||||
if tags != None:
|
||||
|
@ -443,9 +427,7 @@ async def get_chat_tags_by_id(
|
|||
|
||||
@router.post("/{id}/tags", response_model=Optional[ChatIdTagModel])
|
||||
async def add_chat_tag_by_id(
|
||||
id: str,
|
||||
form_data: ChatIdTagForm,
|
||||
user=Depends(get_current_user)
|
||||
id: str, form_data: ChatIdTagForm, user=Depends(get_current_user)
|
||||
):
|
||||
tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
|
||||
|
||||
|
@ -494,9 +476,7 @@ async def delete_chat_tag_by_id(
|
|||
|
||||
|
||||
@router.delete("/{id}/tags/all", response_model=Optional[bool])
|
||||
async def delete_all_chat_tags_by_id(
|
||||
id: str, user=Depends(get_current_user)
|
||||
):
|
||||
async def delete_all_chat_tags_by_id(id: str, user=Depends(get_current_user)):
|
||||
result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id)
|
||||
|
||||
if result:
|
||||
|
|
|
@ -44,9 +44,7 @@ async def get_documents(user=Depends(get_current_user)):
|
|||
|
||||
|
||||
@router.post("/create", response_model=Optional[DocumentResponse])
|
||||
async def create_new_doc(
|
||||
form_data: DocumentForm, user=Depends(get_admin_user)
|
||||
):
|
||||
async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)):
|
||||
doc = Documents.get_doc_by_name(form_data.name)
|
||||
if doc == None:
|
||||
doc = Documents.insert_new_doc(user.id, form_data)
|
||||
|
@ -76,9 +74,7 @@ async def create_new_doc(
|
|||
|
||||
|
||||
@router.get("/doc", response_model=Optional[DocumentResponse])
|
||||
async def get_doc_by_name(
|
||||
name: str, user=Depends(get_current_user)
|
||||
):
|
||||
async def get_doc_by_name(name: str, user=Depends(get_current_user)):
|
||||
doc = Documents.get_doc_by_name(name)
|
||||
|
||||
if doc:
|
||||
|
@ -110,12 +106,8 @@ class TagDocumentForm(BaseModel):
|
|||
|
||||
|
||||
@router.post("/doc/tags", response_model=Optional[DocumentResponse])
|
||||
async def tag_doc_by_name(
|
||||
form_data: TagDocumentForm, user=Depends(get_current_user)
|
||||
):
|
||||
doc = Documents.update_doc_content_by_name(
|
||||
form_data.name, {"tags": form_data.tags}
|
||||
)
|
||||
async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_user)):
|
||||
doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags})
|
||||
|
||||
if doc:
|
||||
return DocumentResponse(
|
||||
|
@ -163,8 +155,6 @@ async def update_doc_by_name(
|
|||
|
||||
|
||||
@router.delete("/doc/delete", response_model=bool)
|
||||
async def delete_doc_by_name(
|
||||
name: str, user=Depends(get_admin_user)
|
||||
):
|
||||
async def delete_doc_by_name(name: str, user=Depends(get_admin_user)):
|
||||
result = Documents.delete_doc_by_name(name)
|
||||
return result
|
||||
|
|
|
@ -50,10 +50,7 @@ router = APIRouter()
|
|||
|
||||
|
||||
@router.post("/")
|
||||
def upload_file(
|
||||
file: UploadFile = File(...),
|
||||
user=Depends(get_verified_user)
|
||||
):
|
||||
def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
|
||||
log.info(f"file.content_type: {file.content_type}")
|
||||
try:
|
||||
unsanitized_filename = file.filename
|
||||
|
|
|
@ -167,9 +167,7 @@ async def delete_memory_by_user_id(user=Depends(get_verified_user)):
|
|||
|
||||
|
||||
@router.delete("/{memory_id}", response_model=bool)
|
||||
async def delete_memory_by_id(
|
||||
memory_id: str, user=Depends(get_verified_user)
|
||||
):
|
||||
async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
|
||||
result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
|
||||
|
||||
if result:
|
||||
|
|
|
@ -29,9 +29,7 @@ async def get_prompts(user=Depends(get_current_user)):
|
|||
|
||||
|
||||
@router.post("/create", response_model=Optional[PromptModel])
|
||||
async def create_new_prompt(
|
||||
form_data: PromptForm, user=Depends(get_admin_user)
|
||||
):
|
||||
async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user)):
|
||||
prompt = Prompts.get_prompt_by_command(form_data.command)
|
||||
if prompt == None:
|
||||
prompt = Prompts.insert_new_prompt(user.id, form_data)
|
||||
|
@ -54,9 +52,7 @@ async def create_new_prompt(
|
|||
|
||||
|
||||
@router.get("/command/{command}", response_model=Optional[PromptModel])
|
||||
async def get_prompt_by_command(
|
||||
command: str, user=Depends(get_current_user)
|
||||
):
|
||||
async def get_prompt_by_command(command: str, user=Depends(get_current_user)):
|
||||
prompt = Prompts.get_prompt_by_command(f"/{command}")
|
||||
|
||||
if prompt:
|
||||
|
@ -95,8 +91,6 @@ async def update_prompt_by_command(
|
|||
|
||||
|
||||
@router.delete("/command/{command}/delete", response_model=bool)
|
||||
async def delete_prompt_by_command(
|
||||
command: str, user=Depends(get_admin_user)
|
||||
):
|
||||
async def delete_prompt_by_command(command: str, user=Depends(get_admin_user)):
|
||||
result = Prompts.delete_prompt_by_command(f"/{command}")
|
||||
return result
|
||||
|
|
|
@ -180,9 +180,7 @@ async def update_toolkit_by_id(
|
|||
|
||||
|
||||
@router.delete("/id/{id}/delete", response_model=bool)
|
||||
async def delete_toolkit_by_id(
|
||||
request: Request, id: str, user=Depends(get_admin_user)
|
||||
):
|
||||
async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin_user)):
|
||||
result = Tools.delete_tool_by_id(id)
|
||||
|
||||
if result:
|
||||
|
|
|
@ -40,9 +40,7 @@ router = APIRouter()
|
|||
|
||||
|
||||
@router.get("/", response_model=List[UserModel])
|
||||
async def get_users(
|
||||
skip: int = 0, limit: int = 50, user=Depends(get_admin_user)
|
||||
):
|
||||
async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)):
|
||||
return Users.get_users(skip, limit)
|
||||
|
||||
|
||||
|
@ -70,9 +68,7 @@ async def update_user_permissions(
|
|||
|
||||
|
||||
@router.post("/update/role", response_model=Optional[UserModel])
|
||||
async def update_user_role(
|
||||
form_data: UserRoleUpdateForm, user=Depends(get_admin_user)
|
||||
):
|
||||
async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin_user)):
|
||||
|
||||
if user.id != form_data.id and form_data.id != Users.get_first_user().id:
|
||||
return Users.update_user_role_by_id(form_data.id, form_data.role)
|
||||
|
@ -89,9 +85,7 @@ async def update_user_role(
|
|||
|
||||
|
||||
@router.get("/user/settings", response_model=Optional[UserSettings])
|
||||
async def get_user_settings_by_session_user(
|
||||
user=Depends(get_verified_user)
|
||||
):
|
||||
async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
|
||||
user = Users.get_user_by_id(user.id)
|
||||
if user:
|
||||
return user.settings
|
||||
|
@ -127,9 +121,7 @@ async def update_user_settings_by_session_user(
|
|||
|
||||
|
||||
@router.get("/user/info", response_model=Optional[dict])
|
||||
async def get_user_info_by_session_user(
|
||||
user=Depends(get_verified_user)
|
||||
):
|
||||
async def get_user_info_by_session_user(user=Depends(get_verified_user)):
|
||||
user = Users.get_user_by_id(user.id)
|
||||
if user:
|
||||
return user.info
|
||||
|
@ -154,9 +146,7 @@ async def update_user_info_by_session_user(
|
|||
if user.info is None:
|
||||
user.info = {}
|
||||
|
||||
user = Users.update_user_by_id(
|
||||
user.id, {"info": {**user.info, **form_data}}
|
||||
)
|
||||
user = Users.update_user_by_id(user.id, {"info": {**user.info, **form_data}})
|
||||
if user:
|
||||
return user.info
|
||||
else:
|
||||
|
@ -182,9 +172,7 @@ class UserResponse(BaseModel):
|
|||
|
||||
|
||||
@router.get("/{user_id}", response_model=UserResponse)
|
||||
async def get_user_by_id(
|
||||
user_id: str, user=Depends(get_verified_user)
|
||||
):
|
||||
async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
|
||||
|
||||
# Check if user_id is a shared chat
|
||||
# If it is, get the user_id from the chat
|
||||
|
@ -267,9 +255,7 @@ async def update_user_by_id(
|
|||
|
||||
|
||||
@router.delete("/{user_id}", response_model=bool)
|
||||
async def delete_user_by_id(
|
||||
user_id: str, user=Depends(get_admin_user)
|
||||
):
|
||||
async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
|
||||
if user.id != user_id:
|
||||
result = Auths.delete_auth_by_id(user_id)
|
||||
|
||||
|
|
|
@ -175,7 +175,9 @@ https://github.com/open-webui/open-webui
|
|||
def run_migrations():
|
||||
env = os.environ.copy()
|
||||
env["DATABASE_URL"] = DATABASE_URL
|
||||
migration_task = subprocess.run(["alembic", f"-c{BACKEND_DIR}/alembic.ini", "upgrade", "head"], env=env)
|
||||
migration_task = subprocess.run(
|
||||
["alembic", f"-c{BACKEND_DIR}/alembic.ini", "upgrade", "head"], env=env
|
||||
)
|
||||
if migration_task.returncode > 0:
|
||||
raise ValueError("Error running migrations")
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ Revises:
|
|||
Create Date: 2024-06-24 09:09:11.636336
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
@ -13,7 +14,7 @@ import apps.webui.internal.db
|
|||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'ba76b0bae648'
|
||||
revision: str = "ba76b0bae648"
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
@ -21,141 +22,153 @@ depends_on: Union[str, Sequence[str], None] = None
|
|||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('auth',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('email', sa.String(), nullable=True),
|
||||
sa.Column('password', sa.String(), nullable=True),
|
||||
sa.Column('active', sa.Boolean(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
op.create_table(
|
||||
"auth",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("email", sa.String(), nullable=True),
|
||||
sa.Column("password", sa.String(), nullable=True),
|
||||
sa.Column("active", sa.Boolean(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table('chat',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('user_id', sa.String(), nullable=True),
|
||||
sa.Column('title', sa.String(), nullable=True),
|
||||
sa.Column('chat', sa.String(), nullable=True),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('share_id', sa.String(), nullable=True),
|
||||
sa.Column('archived', sa.Boolean(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('share_id')
|
||||
op.create_table(
|
||||
"chat",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("title", sa.String(), nullable=True),
|
||||
sa.Column("chat", sa.String(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("share_id", sa.String(), nullable=True),
|
||||
sa.Column("archived", sa.Boolean(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("share_id"),
|
||||
)
|
||||
op.create_table('chatidtag',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('tag_name', sa.String(), nullable=True),
|
||||
sa.Column('chat_id', sa.String(), nullable=True),
|
||||
sa.Column('user_id', sa.String(), nullable=True),
|
||||
sa.Column('timestamp', sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
op.create_table(
|
||||
"chatidtag",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("tag_name", sa.String(), nullable=True),
|
||||
sa.Column("chat_id", sa.String(), nullable=True),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("timestamp", sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table('document',
|
||||
sa.Column('collection_name', sa.String(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=True),
|
||||
sa.Column('title', sa.String(), nullable=True),
|
||||
sa.Column('filename', sa.String(), nullable=True),
|
||||
sa.Column('content', sa.String(), nullable=True),
|
||||
sa.Column('user_id', sa.String(), nullable=True),
|
||||
sa.Column('timestamp', sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('collection_name'),
|
||||
sa.UniqueConstraint('name')
|
||||
op.create_table(
|
||||
"document",
|
||||
sa.Column("collection_name", sa.String(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=True),
|
||||
sa.Column("title", sa.String(), nullable=True),
|
||||
sa.Column("filename", sa.String(), nullable=True),
|
||||
sa.Column("content", sa.String(), nullable=True),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("timestamp", sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("collection_name"),
|
||||
sa.UniqueConstraint("name"),
|
||||
)
|
||||
op.create_table('file',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('user_id', sa.String(), nullable=True),
|
||||
sa.Column('filename', sa.String(), nullable=True),
|
||||
sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
op.create_table(
|
||||
"file",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("filename", sa.String(), nullable=True),
|
||||
sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table('function',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('user_id', sa.String(), nullable=True),
|
||||
sa.Column('name', sa.Text(), nullable=True),
|
||||
sa.Column('type', sa.Text(), nullable=True),
|
||||
sa.Column('content', sa.Text(), nullable=True),
|
||||
sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
|
||||
sa.Column('valves', apps.webui.internal.db.JSONField(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=True),
|
||||
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
op.create_table(
|
||||
"function",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("name", sa.Text(), nullable=True),
|
||||
sa.Column("type", sa.Text(), nullable=True),
|
||||
sa.Column("content", sa.Text(), nullable=True),
|
||||
sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
|
||||
sa.Column("valves", apps.webui.internal.db.JSONField(), nullable=True),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table('memory',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('user_id', sa.String(), nullable=True),
|
||||
sa.Column('content', sa.String(), nullable=True),
|
||||
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
op.create_table(
|
||||
"memory",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("content", sa.String(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table('model',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('user_id', sa.String(), nullable=True),
|
||||
sa.Column('base_model_id', sa.String(), nullable=True),
|
||||
sa.Column('name', sa.String(), nullable=True),
|
||||
sa.Column('params', apps.webui.internal.db.JSONField(), nullable=True),
|
||||
sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
|
||||
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
op.create_table(
|
||||
"model",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("base_model_id", sa.String(), nullable=True),
|
||||
sa.Column("name", sa.String(), nullable=True),
|
||||
sa.Column("params", apps.webui.internal.db.JSONField(), nullable=True),
|
||||
sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table('prompt',
|
||||
sa.Column('command', sa.String(), nullable=False),
|
||||
sa.Column('user_id', sa.String(), nullable=True),
|
||||
sa.Column('title', sa.String(), nullable=True),
|
||||
sa.Column('content', sa.String(), nullable=True),
|
||||
sa.Column('timestamp', sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('command')
|
||||
op.create_table(
|
||||
"prompt",
|
||||
sa.Column("command", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("title", sa.String(), nullable=True),
|
||||
sa.Column("content", sa.String(), nullable=True),
|
||||
sa.Column("timestamp", sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("command"),
|
||||
)
|
||||
op.create_table('tag',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=True),
|
||||
sa.Column('user_id', sa.String(), nullable=True),
|
||||
sa.Column('data', sa.String(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
op.create_table(
|
||||
"tag",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=True),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("data", sa.String(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table('tool',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('user_id', sa.String(), nullable=True),
|
||||
sa.Column('name', sa.String(), nullable=True),
|
||||
sa.Column('content', sa.String(), nullable=True),
|
||||
sa.Column('specs', apps.webui.internal.db.JSONField(), nullable=True),
|
||||
sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
|
||||
sa.Column('valves', apps.webui.internal.db.JSONField(), nullable=True),
|
||||
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
op.create_table(
|
||||
"tool",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("name", sa.String(), nullable=True),
|
||||
sa.Column("content", sa.String(), nullable=True),
|
||||
sa.Column("specs", apps.webui.internal.db.JSONField(), nullable=True),
|
||||
sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
|
||||
sa.Column("valves", apps.webui.internal.db.JSONField(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table('user',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=True),
|
||||
sa.Column('email', sa.String(), nullable=True),
|
||||
sa.Column('role', sa.String(), nullable=True),
|
||||
sa.Column('profile_image_url', sa.String(), nullable=True),
|
||||
sa.Column('last_active_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('api_key', sa.String(), nullable=True),
|
||||
sa.Column('settings', apps.webui.internal.db.JSONField(), nullable=True),
|
||||
sa.Column('info', apps.webui.internal.db.JSONField(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('api_key')
|
||||
op.create_table(
|
||||
"user",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=True),
|
||||
sa.Column("email", sa.String(), nullable=True),
|
||||
sa.Column("role", sa.String(), nullable=True),
|
||||
sa.Column("profile_image_url", sa.String(), nullable=True),
|
||||
sa.Column("last_active_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("api_key", sa.String(), nullable=True),
|
||||
sa.Column("settings", apps.webui.internal.db.JSONField(), nullable=True),
|
||||
sa.Column("info", apps.webui.internal.db.JSONField(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("api_key"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('user')
|
||||
op.drop_table('tool')
|
||||
op.drop_table('tag')
|
||||
op.drop_table('prompt')
|
||||
op.drop_table('model')
|
||||
op.drop_table('memory')
|
||||
op.drop_table('function')
|
||||
op.drop_table('file')
|
||||
op.drop_table('document')
|
||||
op.drop_table('chatidtag')
|
||||
op.drop_table('chat')
|
||||
op.drop_table('auth')
|
||||
op.drop_table("user")
|
||||
op.drop_table("tool")
|
||||
op.drop_table("tag")
|
||||
op.drop_table("prompt")
|
||||
op.drop_table("model")
|
||||
op.drop_table("memory")
|
||||
op.drop_table("function")
|
||||
op.drop_table("file")
|
||||
op.drop_table("document")
|
||||
op.drop_table("chatidtag")
|
||||
op.drop_table("chat")
|
||||
op.drop_table("auth")
|
||||
# ### end Alembic commands ###
|
||||
|
|
|
@ -91,6 +91,7 @@ class AbstractPostgresTest(AbstractIntegrationTest):
|
|||
while retries > 0:
|
||||
try:
|
||||
from config import BACKEND_DIR
|
||||
|
||||
db = create_engine(database_url, pool_pre_ping=True)
|
||||
db = db.connect()
|
||||
log.info("postgres is ready!")
|
||||
|
|
Loading…
Reference in New Issue