feat(sqlalchemy): format backend

This commit is contained in:
Jonathan Rohde 2024-06-24 09:57:08 +02:00
parent 320e658595
commit c134eab27a
21 changed files with 232 additions and 289 deletions

View File

@ -53,7 +53,9 @@ if "sqlite" in SQLALCHEMY_DATABASE_URL:
)
else:
engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, expire_on_commit=False)
SessionLocal = sessionmaker(
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
)
Base = declarative_base()
@ -66,4 +68,3 @@ def get_session():
except Exception as e:
db.rollback()
raise e

View File

@ -126,9 +126,7 @@ class AuthsTable:
else:
return None
def authenticate_user(
self, email: str, password: str
) -> Optional[UserModel]:
def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
log.info(f"authenticate_user: {email}")
with get_session() as db:
try:
@ -144,9 +142,7 @@ class AuthsTable:
except:
return None
def authenticate_user_by_api_key(
self, api_key: str
) -> Optional[UserModel]:
def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
log.info(f"authenticate_user_by_api_key: {api_key}")
with get_session() as db:
# if no api_key, return None
@ -159,9 +155,7 @@ class AuthsTable:
except:
return False
def authenticate_user_by_trusted_header(
self, email: str
) -> Optional[UserModel]:
def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]:
log.info(f"authenticate_user_by_trusted_header: {email}")
with get_session() as db:
try:
@ -172,12 +166,12 @@ class AuthsTable:
except:
return None
def update_user_password_by_id(
self, id: str, new_password: str
) -> bool:
def update_user_password_by_id(self, id: str, new_password: str) -> bool:
with get_session() as db:
try:
result = db.query(Auth).filter_by(id=id).update({"password": new_password})
result = (
db.query(Auth).filter_by(id=id).update({"password": new_password})
)
return True if result == 1 else False
except:
return False

View File

@ -79,9 +79,7 @@ class ChatTitleIdResponse(BaseModel):
class ChatTable:
def insert_new_chat(
self, user_id: str, form_data: ChatForm
) -> Optional[ChatModel]:
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
with get_session() as db:
id = str(uuid.uuid4())
chat = ChatModel(
@ -89,7 +87,9 @@ class ChatTable:
"id": id,
"user_id": user_id,
"title": (
form_data.chat["title"] if "title" in form_data.chat else "New Chat"
form_data.chat["title"]
if "title" in form_data.chat
else "New Chat"
),
"chat": json.dumps(form_data.chat),
"created_at": int(time.time()),
@ -103,9 +103,7 @@ class ChatTable:
db.refresh(result)
return ChatModel.model_validate(result) if result else None
def update_chat_by_id(
self, id: str, chat: dict
) -> Optional[ChatModel]:
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
with get_session() as db:
try:
chat_obj = db.get(Chat, id)
@ -119,9 +117,7 @@ class ChatTable:
except Exception as e:
return None
def insert_shared_chat_by_chat_id(
self, chat_id: str
) -> Optional[ChatModel]:
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
with get_session() as db:
# Get the existing chat to share
chat = db.get(Chat, chat_id)
@ -145,14 +141,14 @@ class ChatTable:
db.refresh(shared_result)
# Update the original chat with the share_id
result = (
db.query(Chat).filter_by(id=chat_id).update({"share_id": shared_chat.id})
db.query(Chat)
.filter_by(id=chat_id)
.update({"share_id": shared_chat.id})
)
return shared_chat if (shared_result and result) else None
def update_shared_chat_by_chat_id(
self, chat_id: str
) -> Optional[ChatModel]:
def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
with get_session() as db:
try:
print("update_shared_chat_by_id")
@ -271,9 +267,7 @@ class ChatTable:
except Exception as e:
return None
def get_chat_by_id_and_user_id(
self, id: str, user_id: str
) -> Optional[ChatModel]:
def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
try:
with get_session() as db:
chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
@ -293,13 +287,13 @@ class ChatTable:
def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
with get_session() as db:
all_chats = (
db.query(Chat).filter_by(user_id=user_id).order_by(Chat.updated_at.desc())
db.query(Chat)
.filter_by(user_id=user_id)
.order_by(Chat.updated_at.desc())
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_archived_chats_by_user_id(
self, user_id: str
) -> List[ChatModel]:
def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
with get_session() as db:
all_chats = (
db.query(Chat)

View File

@ -106,7 +106,9 @@ class DocumentsTable:
def get_docs(self) -> List[DocumentModel]:
with get_session() as db:
return [DocumentModel.model_validate(doc) for doc in db.query(Document).all()]
return [
DocumentModel.model_validate(doc) for doc in db.query(Document).all()
]
def update_doc_by_name(
self, name: str, form_data: DocumentUpdateForm

View File

@ -39,6 +39,7 @@ class FileModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
####################
# Forms
####################

View File

@ -142,9 +142,9 @@ class FunctionsTable:
with get_session() as db:
return [
FunctionModel.model_validate(function)
for function in db.query(Function).filter_by(
type=type, is_active=True
).all()
for function in db.query(Function)
.filter_by(type=type, is_active=True)
.all()
]
else:
with get_session() as db:
@ -220,10 +220,12 @@ class FunctionsTable:
def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
try:
with get_session() as db:
db.query(Function).filter_by(id=id).update({
**updated,
"updated_at": int(time.time()),
})
db.query(Function).filter_by(id=id).update(
{
**updated,
"updated_at": int(time.time()),
}
)
db.commit()
return self.get_function_by_id(id)
except:
@ -232,10 +234,12 @@ class FunctionsTable:
def deactivate_all_functions(self) -> Optional[bool]:
try:
with get_session() as db:
db.query(Function).update({
"is_active": False,
"updated_at": int(time.time()),
})
db.query(Function).update(
{
"is_active": False,
"updated_at": int(time.time()),
}
)
db.commit()
return True
except:

View File

@ -153,9 +153,7 @@ class ModelsTable:
except:
return None
def update_model_by_id(
self, id: str, model: ModelForm
) -> Optional[ModelModel]:
def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
try:
# update only the fields that are present in the model
with get_session() as db:

View File

@ -83,7 +83,9 @@ class PromptsTable:
def get_prompts(self) -> List[PromptModel]:
with get_session() as db:
return [PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()]
return [
PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()
]
def update_prompt_by_command(
self, command: str, form_data: PromptForm

View File

@ -79,9 +79,7 @@ class ChatTagsResponse(BaseModel):
class TagTable:
def insert_new_tag(
self, name: str, user_id: str
) -> Optional[TagModel]:
def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
id = str(uuid.uuid4())
tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
try:
@ -201,11 +199,13 @@ class TagTable:
self, tag_name: str, user_id: str
) -> int:
with get_session() as db:
return db.query(ChatIdTag).filter_by(tag_name=tag_name, user_id=user_id).count()
return (
db.query(ChatIdTag)
.filter_by(tag_name=tag_name, user_id=user_id)
.count()
)
def delete_tag_by_tag_name_and_user_id(
self, tag_name: str, user_id: str
) -> bool:
def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool:
try:
with get_session() as db:
res = (
@ -252,9 +252,7 @@ class TagTable:
log.error(f"delete_tag: {e}")
return False
def delete_tags_by_chat_id_and_user_id(
self, chat_id: str, user_id: str
) -> bool:
def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool:
tags = self.get_tags_by_chat_id_and_user_id(chat_id, user_id)
for tag in tags:

View File

@ -165,9 +165,7 @@ class UsersTable:
except:
return None
def update_user_role_by_id(
self, id: str, role: str
) -> Optional[UserModel]:
def update_user_role_by_id(self, id: str, role: str) -> Optional[UserModel]:
with get_session() as db:
try:
db.query(User).filter_by(id=id).update({"role": role})
@ -193,12 +191,12 @@ class UsersTable:
except:
return None
def update_user_last_active_by_id(
self, id: str
) -> Optional[UserModel]:
def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
with get_session() as db:
try:
db.query(User).filter_by(id=id).update({"last_active_at": int(time.time())})
db.query(User).filter_by(id=id).update(
{"last_active_at": int(time.time())}
)
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
@ -217,9 +215,7 @@ class UsersTable:
except:
return None
def update_user_by_id(
self, id: str, updated: dict
) -> Optional[UserModel]:
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
with get_session() as db:
try:
db.query(User).filter_by(id=id).update(updated)

View File

@ -78,8 +78,7 @@ async def get_session_user(
@router.post("/update/profile", response_model=UserResponse)
async def update_profile(
form_data: UpdateProfileForm,
session_user=Depends(get_current_user)
form_data: UpdateProfileForm, session_user=Depends(get_current_user)
):
if session_user:
user = Users.update_user_by_id(
@ -101,8 +100,7 @@ async def update_profile(
@router.post("/update/password", response_model=bool)
async def update_password(
form_data: UpdatePasswordForm,
session_user=Depends(get_current_user)
form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
):
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
@ -269,9 +267,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
@router.post("/add", response_model=SigninResponse)
async def add_user(
form_data: AddUserForm, user=Depends(get_admin_user)
):
async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
if not validate_email_format(form_data.email.lower()):
raise HTTPException(
@ -316,9 +312,7 @@ async def add_user(
@router.get("/admin/details")
async def get_admin_details(
request: Request, user=Depends(get_current_user)
):
async def get_admin_details(request: Request, user=Depends(get_current_user)):
if request.app.state.config.SHOW_ADMIN_DETAILS:
admin_email = request.app.state.config.ADMIN_EMAIL
admin_name = None

View File

@ -55,9 +55,7 @@ async def get_session_user_chat_list(
@router.delete("/", response_model=bool)
async def delete_all_user_chats(
request: Request, user=Depends(get_current_user)
):
async def delete_all_user_chats(request: Request, user=Depends(get_current_user)):
if (
user.role == "user"
@ -95,9 +93,7 @@ async def get_user_chat_list_by_user_id(
@router.post("/new", response_model=Optional[ChatResponse])
async def create_new_chat(
form_data: ChatForm, user=Depends(get_current_user)
):
async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
try:
chat = Chats.insert_new_chat(user.id, form_data)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
@ -180,9 +176,7 @@ async def archive_all_chats(user=Depends(get_current_user)):
@router.get("/share/{share_id}", response_model=Optional[ChatResponse])
async def get_shared_chat_by_id(
share_id: str, user=Depends(get_current_user)
):
async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)):
if user.role == "pending":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@ -225,9 +219,7 @@ async def get_user_chat_list_by_tag_name(
)
]
chats = Chats.get_chat_list_by_chat_ids(
chat_ids, form_data.skip, form_data.limit
)
chats = Chats.get_chat_list_by_chat_ids(chat_ids, form_data.skip, form_data.limit)
if len(chats) == 0:
Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id)
@ -297,9 +289,7 @@ async def update_chat_by_id(
@router.delete("/{id}", response_model=bool)
async def delete_chat_by_id(
request: Request, id: str, user=Depends(get_current_user)
):
async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_user)):
if user.role == "admin":
result = Chats.delete_chat_by_id(id)
@ -347,9 +337,7 @@ async def clone_chat_by_id(id: str, user=Depends(get_current_user)):
@router.get("/{id}/archive", response_model=Optional[ChatResponse])
async def archive_chat_by_id(
id: str, user=Depends(get_current_user)
):
async def archive_chat_by_id(id: str, user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
chat = Chats.toggle_chat_archive_by_id(id)
@ -398,9 +386,7 @@ async def share_chat_by_id(id: str, user=Depends(get_current_user)):
@router.delete("/{id}/share", response_model=Optional[bool])
async def delete_shared_chat_by_id(
id: str, user=Depends(get_current_user)
):
async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
if not chat.share_id:
@ -423,9 +409,7 @@ async def delete_shared_chat_by_id(
@router.get("/{id}/tags", response_model=List[TagModel])
async def get_chat_tags_by_id(
id: str, user=Depends(get_current_user)
):
async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)):
tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
if tags != None:
@ -443,9 +427,7 @@ async def get_chat_tags_by_id(
@router.post("/{id}/tags", response_model=Optional[ChatIdTagModel])
async def add_chat_tag_by_id(
id: str,
form_data: ChatIdTagForm,
user=Depends(get_current_user)
id: str, form_data: ChatIdTagForm, user=Depends(get_current_user)
):
tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
@ -494,9 +476,7 @@ async def delete_chat_tag_by_id(
@router.delete("/{id}/tags/all", response_model=Optional[bool])
async def delete_all_chat_tags_by_id(
id: str, user=Depends(get_current_user)
):
async def delete_all_chat_tags_by_id(id: str, user=Depends(get_current_user)):
result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id)
if result:

View File

@ -44,9 +44,7 @@ async def get_documents(user=Depends(get_current_user)):
@router.post("/create", response_model=Optional[DocumentResponse])
async def create_new_doc(
form_data: DocumentForm, user=Depends(get_admin_user)
):
async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)):
doc = Documents.get_doc_by_name(form_data.name)
if doc == None:
doc = Documents.insert_new_doc(user.id, form_data)
@ -76,9 +74,7 @@ async def create_new_doc(
@router.get("/doc", response_model=Optional[DocumentResponse])
async def get_doc_by_name(
name: str, user=Depends(get_current_user)
):
async def get_doc_by_name(name: str, user=Depends(get_current_user)):
doc = Documents.get_doc_by_name(name)
if doc:
@ -110,12 +106,8 @@ class TagDocumentForm(BaseModel):
@router.post("/doc/tags", response_model=Optional[DocumentResponse])
async def tag_doc_by_name(
form_data: TagDocumentForm, user=Depends(get_current_user)
):
doc = Documents.update_doc_content_by_name(
form_data.name, {"tags": form_data.tags}
)
async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_user)):
doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags})
if doc:
return DocumentResponse(
@ -163,8 +155,6 @@ async def update_doc_by_name(
@router.delete("/doc/delete", response_model=bool)
async def delete_doc_by_name(
name: str, user=Depends(get_admin_user)
):
async def delete_doc_by_name(name: str, user=Depends(get_admin_user)):
result = Documents.delete_doc_by_name(name)
return result

View File

@ -50,10 +50,7 @@ router = APIRouter()
@router.post("/")
def upload_file(
file: UploadFile = File(...),
user=Depends(get_verified_user)
):
def upload_file(file: UploadFile = File(...), user=Depends(get_verified_user)):
log.info(f"file.content_type: {file.content_type}")
try:
unsanitized_filename = file.filename

View File

@ -167,9 +167,7 @@ async def delete_memory_by_user_id(user=Depends(get_verified_user)):
@router.delete("/{memory_id}", response_model=bool)
async def delete_memory_by_id(
memory_id: str, user=Depends(get_verified_user)
):
async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
if result:

View File

@ -29,9 +29,7 @@ async def get_prompts(user=Depends(get_current_user)):
@router.post("/create", response_model=Optional[PromptModel])
async def create_new_prompt(
form_data: PromptForm, user=Depends(get_admin_user)
):
async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user)):
prompt = Prompts.get_prompt_by_command(form_data.command)
if prompt == None:
prompt = Prompts.insert_new_prompt(user.id, form_data)
@ -54,9 +52,7 @@ async def create_new_prompt(
@router.get("/command/{command}", response_model=Optional[PromptModel])
async def get_prompt_by_command(
command: str, user=Depends(get_current_user)
):
async def get_prompt_by_command(command: str, user=Depends(get_current_user)):
prompt = Prompts.get_prompt_by_command(f"/{command}")
if prompt:
@ -95,8 +91,6 @@ async def update_prompt_by_command(
@router.delete("/command/{command}/delete", response_model=bool)
async def delete_prompt_by_command(
command: str, user=Depends(get_admin_user)
):
async def delete_prompt_by_command(command: str, user=Depends(get_admin_user)):
result = Prompts.delete_prompt_by_command(f"/{command}")
return result

View File

@ -180,9 +180,7 @@ async def update_toolkit_by_id(
@router.delete("/id/{id}/delete", response_model=bool)
async def delete_toolkit_by_id(
request: Request, id: str, user=Depends(get_admin_user)
):
async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin_user)):
result = Tools.delete_tool_by_id(id)
if result:

View File

@ -40,9 +40,7 @@ router = APIRouter()
@router.get("/", response_model=List[UserModel])
async def get_users(
skip: int = 0, limit: int = 50, user=Depends(get_admin_user)
):
async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)):
return Users.get_users(skip, limit)
@ -70,9 +68,7 @@ async def update_user_permissions(
@router.post("/update/role", response_model=Optional[UserModel])
async def update_user_role(
form_data: UserRoleUpdateForm, user=Depends(get_admin_user)
):
async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin_user)):
if user.id != form_data.id and form_data.id != Users.get_first_user().id:
return Users.update_user_role_by_id(form_data.id, form_data.role)
@ -89,9 +85,7 @@ async def update_user_role(
@router.get("/user/settings", response_model=Optional[UserSettings])
async def get_user_settings_by_session_user(
user=Depends(get_verified_user)
):
async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
user = Users.get_user_by_id(user.id)
if user:
return user.settings
@ -127,9 +121,7 @@ async def update_user_settings_by_session_user(
@router.get("/user/info", response_model=Optional[dict])
async def get_user_info_by_session_user(
user=Depends(get_verified_user)
):
async def get_user_info_by_session_user(user=Depends(get_verified_user)):
user = Users.get_user_by_id(user.id)
if user:
return user.info
@ -154,9 +146,7 @@ async def update_user_info_by_session_user(
if user.info is None:
user.info = {}
user = Users.update_user_by_id(
user.id, {"info": {**user.info, **form_data}}
)
user = Users.update_user_by_id(user.id, {"info": {**user.info, **form_data}})
if user:
return user.info
else:
@ -182,9 +172,7 @@ class UserResponse(BaseModel):
@router.get("/{user_id}", response_model=UserResponse)
async def get_user_by_id(
user_id: str, user=Depends(get_verified_user)
):
async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
# Check if user_id is a shared chat
# If it is, get the user_id from the chat
@ -267,9 +255,7 @@ async def update_user_by_id(
@router.delete("/{user_id}", response_model=bool)
async def delete_user_by_id(
user_id: str, user=Depends(get_admin_user)
):
async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
if user.id != user_id:
result = Auths.delete_auth_by_id(user_id)

View File

@ -175,7 +175,9 @@ https://github.com/open-webui/open-webui
def run_migrations():
env = os.environ.copy()
env["DATABASE_URL"] = DATABASE_URL
migration_task = subprocess.run(["alembic", f"-c{BACKEND_DIR}/alembic.ini", "upgrade", "head"], env=env)
migration_task = subprocess.run(
["alembic", f"-c{BACKEND_DIR}/alembic.ini", "upgrade", "head"], env=env
)
if migration_task.returncode > 0:
raise ValueError("Error running migrations")

View File

@ -5,6 +5,7 @@ Revises:
Create Date: 2024-06-24 09:09:11.636336
"""
from typing import Sequence, Union
from alembic import op
@ -13,7 +14,7 @@ import apps.webui.internal.db
# revision identifiers, used by Alembic.
revision: str = 'ba76b0bae648'
revision: str = "ba76b0bae648"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@ -21,141 +22,153 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('auth',
sa.Column('id', sa.String(), nullable=False),
sa.Column('email', sa.String(), nullable=True),
sa.Column('password', sa.String(), nullable=True),
sa.Column('active', sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint('id')
op.create_table(
"auth",
sa.Column("id", sa.String(), nullable=False),
sa.Column("email", sa.String(), nullable=True),
sa.Column("password", sa.String(), nullable=True),
sa.Column("active", sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_table('chat',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('title', sa.String(), nullable=True),
sa.Column('chat', sa.String(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('share_id', sa.String(), nullable=True),
sa.Column('archived', sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('share_id')
op.create_table(
"chat",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("title", sa.String(), nullable=True),
sa.Column("chat", sa.String(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("share_id", sa.String(), nullable=True),
sa.Column("archived", sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("share_id"),
)
op.create_table('chatidtag',
sa.Column('id', sa.String(), nullable=False),
sa.Column('tag_name', sa.String(), nullable=True),
sa.Column('chat_id', sa.String(), nullable=True),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('timestamp', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id')
op.create_table(
"chatidtag",
sa.Column("id", sa.String(), nullable=False),
sa.Column("tag_name", sa.String(), nullable=True),
sa.Column("chat_id", sa.String(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_table('document',
sa.Column('collection_name', sa.String(), nullable=False),
sa.Column('name', sa.String(), nullable=True),
sa.Column('title', sa.String(), nullable=True),
sa.Column('filename', sa.String(), nullable=True),
sa.Column('content', sa.String(), nullable=True),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('timestamp', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('collection_name'),
sa.UniqueConstraint('name')
op.create_table(
"document",
sa.Column("collection_name", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("title", sa.String(), nullable=True),
sa.Column("filename", sa.String(), nullable=True),
sa.Column("content", sa.String(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("collection_name"),
sa.UniqueConstraint("name"),
)
op.create_table('file',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('filename', sa.String(), nullable=True),
sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id')
op.create_table(
"file",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("filename", sa.String(), nullable=True),
sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_table('function',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('name', sa.Text(), nullable=True),
sa.Column('type', sa.Text(), nullable=True),
sa.Column('content', sa.Text(), nullable=True),
sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('valves', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id')
op.create_table(
"function",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("name", sa.Text(), nullable=True),
sa.Column("type", sa.Text(), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("valves", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("is_active", sa.Boolean(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_table('memory',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('content', sa.String(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id')
op.create_table(
"memory",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("content", sa.String(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_table('model',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('base_model_id', sa.String(), nullable=True),
sa.Column('name', sa.String(), nullable=True),
sa.Column('params', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id')
op.create_table(
"model",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("base_model_id", sa.String(), nullable=True),
sa.Column("name", sa.String(), nullable=True),
sa.Column("params", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_table('prompt',
sa.Column('command', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('title', sa.String(), nullable=True),
sa.Column('content', sa.String(), nullable=True),
sa.Column('timestamp', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('command')
op.create_table(
"prompt",
sa.Column("command", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("title", sa.String(), nullable=True),
sa.Column("content", sa.String(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("command"),
)
op.create_table('tag',
sa.Column('id', sa.String(), nullable=False),
sa.Column('name', sa.String(), nullable=True),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('data', sa.String(), nullable=True),
sa.PrimaryKeyConstraint('id')
op.create_table(
"tag",
sa.Column("id", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("data", sa.String(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_table('tool',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('name', sa.String(), nullable=True),
sa.Column('content', sa.String(), nullable=True),
sa.Column('specs', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('valves', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id')
op.create_table(
"tool",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("name", sa.String(), nullable=True),
sa.Column("content", sa.String(), nullable=True),
sa.Column("specs", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("valves", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
op.create_table('user',
sa.Column('id', sa.String(), nullable=False),
sa.Column('name', sa.String(), nullable=True),
sa.Column('email', sa.String(), nullable=True),
sa.Column('role', sa.String(), nullable=True),
sa.Column('profile_image_url', sa.String(), nullable=True),
sa.Column('last_active_at', sa.BigInteger(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.Column('api_key', sa.String(), nullable=True),
sa.Column('settings', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('info', apps.webui.internal.db.JSONField(), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('api_key')
op.create_table(
"user",
sa.Column("id", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("email", sa.String(), nullable=True),
sa.Column("role", sa.String(), nullable=True),
sa.Column("profile_image_url", sa.String(), nullable=True),
sa.Column("last_active_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("api_key", sa.String(), nullable=True),
sa.Column("settings", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("info", apps.webui.internal.db.JSONField(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("api_key"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('user')
op.drop_table('tool')
op.drop_table('tag')
op.drop_table('prompt')
op.drop_table('model')
op.drop_table('memory')
op.drop_table('function')
op.drop_table('file')
op.drop_table('document')
op.drop_table('chatidtag')
op.drop_table('chat')
op.drop_table('auth')
op.drop_table("user")
op.drop_table("tool")
op.drop_table("tag")
op.drop_table("prompt")
op.drop_table("model")
op.drop_table("memory")
op.drop_table("function")
op.drop_table("file")
op.drop_table("document")
op.drop_table("chatidtag")
op.drop_table("chat")
op.drop_table("auth")
# ### end Alembic commands ###

View File

@ -91,6 +91,7 @@ class AbstractPostgresTest(AbstractIntegrationTest):
while retries > 0:
try:
from config import BACKEND_DIR
db = create_engine(database_url, pool_pre_ping=True)
db = db.connect()
log.info("postgres is ready!")