This commit is contained in:
Thomas Armstrong 2025-03-12 11:12:52 +08:00 committed by GitHub
commit 1af75f7a3d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 104 additions and 7 deletions

View File

@ -15,6 +15,7 @@ from open_webui.models.files import (
FileModelResponse,
Files,
)
from open_webui.routers.knowledge import get_knowledge, get_knowledge_list
from open_webui.routers.retrieval import ProcessFileForm, process_file
from open_webui.routers.audio import transcribe
from open_webui.storage.provider import Storage
@ -27,6 +28,43 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
router = APIRouter()
############################
# Check if the current user has access to a file through any knowledge bases the user may be in.
############################
async def check_user_has_access_to_file_via_any_knowledge_base(
    file_id: Optional[str],
    access_type: str,
    user=Depends(get_verified_user),
) -> bool:
    """Report whether ``user`` can reach a file through a knowledge base.

    The file's ``meta["collection_name"]`` is treated as the id of the
    knowledge base the file belongs to. That id is compared against the
    knowledge bases returned for ``user`` by the knowledge router.

    Args:
        file_id: Id of the file to look up.
        access_type: ``"read"`` or ``"write"``. Any other value yields no
            candidate knowledge bases, so the result is ``False``.
        user: The user whose access is being checked. The route handlers in
            this module pass it explicitly.

    Returns:
        ``True`` if one of the user's knowledge bases for ``access_type`` has
        the same id as the file's knowledge base, otherwise ``False``.

    Raises:
        HTTPException: 404 when no file exists for ``file_id``.
    """
    record = Files.get_file_by_id(file_id)
    log.debug(f"Checking if user has {access_type} access to file")
    if not record:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    # A file without metadata (or without a collection name) belongs to no
    # knowledge base, so access can never be granted through one.
    knowledge_base_id = record.meta.get("collection_name") if record.meta else None
    log.debug(f"Knowledge base associated with file: {knowledge_base_id}")

    has_access = False
    if knowledge_base_id:
        # NOTE(review): per the original author's notes, get_knowledge yields
        # the knowledge bases the user can read and get_knowledge_list the
        # ones the user can write - confirm against the knowledge router.
        if access_type == "read":
            candidates = await get_knowledge(user=user)
        elif access_type == "write":
            candidates = await get_knowledge_list(user=user)
        else:
            candidates = []

        for knowledge_base in candidates:
            if knowledge_base.id != knowledge_base_id:
                continue
            log.debug(f"User knowledge base with {access_type} access {knowledge_base.id} == File knowledge base {knowledge_base_id}")
            has_access = True
            break

    log.debug(f"Does user have {access_type} access to file: {has_access}")
    return has_access
############################
# Upload File
############################
@ -160,7 +198,15 @@ async def delete_all_files(user=Depends(get_admin_user)):
async def get_file_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
if file and (file.user_id == user.id or user.role == "admin"):
if not file:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
has_read_access: bool = await check_user_has_access_to_file_via_any_knowledge_base(id, "read", user)
if file.user_id == user.id or user.role == "admin" or has_read_access:
return file
else:
raise HTTPException(
@ -178,7 +224,15 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)):
async def get_file_data_content_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
if file and (file.user_id == user.id or user.role == "admin"):
if not file:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
has_read_access: bool = await check_user_has_access_to_file_via_any_knowledge_base(id, "read", user)
if file.user_id == user.id or user.role == "admin" or has_read_access:
return {"content": file.data.get("content", "")}
else:
raise HTTPException(
@ -202,7 +256,15 @@ async def update_file_data_content_by_id(
):
file = Files.get_file_by_id(id)
if file and (file.user_id == user.id or user.role == "admin"):
if not file:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
has_write_access: bool = await check_user_has_access_to_file_via_any_knowledge_base(id, "write", user)
if file.user_id == user.id or user.role == "admin" or has_write_access:
try:
process_file(
request,
@ -230,7 +292,16 @@ async def update_file_data_content_by_id(
@router.get("/{id}/content")
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
if file and (file.user_id == user.id or user.role == "admin"):
if not file:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
has_read_access: bool = await check_user_has_access_to_file_via_any_knowledge_base(id, "read", user)
if file.user_id == user.id or user.role == "admin" or has_read_access:
try:
file_path = Storage.get_file(file.path)
file_path = Path(file_path)
@ -282,7 +353,16 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
@router.get("/{id}/content/html")
async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
if file and (file.user_id == user.id or user.role == "admin"):
if not file:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
has_read_access: bool = await check_user_has_access_to_file_via_any_knowledge_base(id, "read", user)
if file.user_id == user.id or user.role == "admin" or has_read_access:
try:
file_path = Storage.get_file(file.path)
file_path = Path(file_path)
@ -314,7 +394,15 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user)):
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
if file and (file.user_id == user.id or user.role == "admin"):
if not file:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
has_read_access: bool = await check_user_has_access_to_file_via_any_knowledge_base(id, "read", user)
if file.user_id == user.id or user.role == "admin" or has_read_access:
file_path = file.path
# Handle Unicode filenames
@ -365,7 +453,16 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
@router.delete("/{id}")
async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
if file and (file.user_id == user.id or user.role == "admin"):
if not file:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
has_write_access: bool = await check_user_has_access_to_file_via_any_knowledge_base(id, "write", user)
if file.user_id == user.id or user.role == "admin" or has_write_access:
# We should add Chroma cleanup here
result = Files.delete_file_by_id(id)