mypy fixes

This commit is contained in:
Leonardo Pinheiro 2024-10-24 17:20:04 +10:00
parent 1d59e51616
commit 299a2eb866
3 changed files with 27 additions and 27 deletions

View File

@@ -58,9 +58,7 @@ class AsyncVectorDB(Protocol):
active_collection: Any = None
type: str = ""
embedding_function: Optional[Callable[[List[str]], List[List[float]]]] = (
None # embeddings = embedding_function(sentences)
)
embedding_function: Optional[Callable[..., Any]] = None # embeddings = embedding_function(sentences)
async def create_collection(self, collection_name: str, overwrite: bool = False, get_or_create: bool = True) -> Any:
"""

View File

@@ -1,6 +1,6 @@
import logging
import os
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union, cast
from autogen_core.application.logging import TRACE_LOGGER_NAME
from chromadb import GetResult
@@ -79,10 +79,10 @@ class ChromaVectorDB(VectorDB):
) from e
self.IncludeEnum = IncludeEnum
self.embedding_function: "EmbeddingFunction[Embeddable]" = ( # type: ignore
SentenceTransformerEmbeddingFunction("all-MiniLM-L6-v2")
self.embedding_function: "EmbeddingFunction[Any]" = ( # type: ignore
SentenceTransformerEmbeddingFunction("all-MiniLM-L6-v2") # type: ignore
if embedding_function is None
else cast("EmbeddingFunction[Embeddable]", embedding_function)
else cast("EmbeddingFunction[Any]", embedding_function)
)
self.metadata = metadata
self.type = "chroma"
@@ -337,8 +337,8 @@ class ChromaVectorDB(VectorDB):
collection = self.get_collection(collection_name)
results = collection.get(ids=[str(id_) for id_ in ids] if ids else None, include=include_enums)
results = _chroma_get_results_to_list_documents(results)
return results
results_list = _chroma_get_results_to_list_documents(results)
return results_list
class AsyncChromaVectorDB(AsyncVectorDB):
@@ -666,7 +666,9 @@ class AsyncChromaVectorDB(AsyncVectorDB):
return results_list
def _chroma_results_to_query_results(data_dict: ChromaQueryResult, special_key: str = "distances") -> QueryResults:
def _chroma_results_to_query_results(
data_dict: ChromaQueryResult, special_key: Literal["distances"] = "distances"
) -> QueryResults:
"""Converts a ChromaDB query result into a list of lists of (Document, float) tuples.
Args:
@@ -703,7 +705,7 @@ def _chroma_results_to_query_results(data_dict: ChromaQueryResult, special_key:
return []
result: List[List[Tuple[Document, float]]] = []
data_special_key: Any = data_dict[special_key]
data_special_key: Optional[List[List[float]]] = data_dict[special_key]
if data_special_key is None:
return result
@@ -711,15 +713,15 @@ def _chroma_results_to_query_results(data_dict: ChromaQueryResult, special_key:
for i in range(len(data_special_key)):
sub_result: List[Tuple[Document, float]] = []
ids = data_dict["ids"][i]
documents = data_dict.get("documents") or [None] * len(ids)
metadatas = data_dict.get("metadatas") or [None] * len(ids)
documents = data_dict.get("documents") or [None] * len(ids) # type: ignore
metadatas = data_dict.get("metadatas") or [None] * len(ids) # type: ignore
embeddings = data_dict.get("embeddings") or [None] * len(ids)
for j in range(len(data_special_key[i])):
document = Document(
id=ids[j],
content=cast(str, documents[j]),
metadata=cast(Metadata, metadatas[j]),
embedding=cast(Vector, embeddings[j]),
content=cast(Optional[str], documents[j]),
metadata=cast(Optional[Metadata], metadatas[j]),
embedding=cast(Optional[Vector], embeddings[j]),
)
value = data_special_key[i][j]
sub_result.append((document, value))
@@ -758,8 +760,8 @@ def _chroma_get_results_to_list_documents(data_dict: GetResult) -> List[Document
num_items = len(data_dict["ids"])
ids = data_dict["ids"]
documents = data_dict.get("documents") or [None] * num_items
metadatas = data_dict.get("metadatas") or [None] * num_items
documents = data_dict.get("documents") or [None] * num_items # type: ignore
metadatas = data_dict.get("metadatas") or [None] * num_items # type: ignore
embeddings = data_dict.get("embeddings") or [None] * num_items
for i in range(num_items):

View File

@@ -5,7 +5,7 @@ from chromadb.errors import ChromaError
# @pytest.mark.skipif(skip, reason="dependency is not installed")
def test_chromadb():
def test_chromadb() -> None:
# test create collection
db = ChromaVectorDB(path=".db")
collection_name = "test_collection"
@@ -53,13 +53,13 @@ def test_chromadb():
# test_retrieve_docs
queries = ["doc2", "doc3"]
collection_name = "test_collection"
res = db.retrieve_docs(queries, collection_name)
assert [[r[0].id for r in rr] for rr in res] == [["2", "3"], ["3", "2"]]
res = db.retrieve_docs(queries, collection_name, distance_threshold=0.1)
assert [[r[0].id for r in rr] for rr in res] == [["2"], ["3"]]
res = db.retrieve_docs(queries, collection_name) # type: ignore
assert [[r[0].id for r in rr] for rr in res] == [["2", "3"], ["3", "2"]] # type: ignore
res = db.retrieve_docs(queries, collection_name, distance_threshold=0.1) # type: ignore
assert [[r[0].id for r in rr] for rr in res] == [["2"], ["3"]] # type: ignore
# test_get_docs_by_ids
res = db.get_docs_by_ids(["1", "2"], collection_name)
assert [r.id for r in res] == ["2"] # "1" has been deleted
res = db.get_docs_by_ids(collection_name=collection_name)
assert [r.id for r in res] == ["2", "3"]
res = db.get_docs_by_ids(["1", "2"], collection_name) # type: ignore
assert [r.id for r in res] == ["2"] # type: ignore
res = db.get_docs_by_ids(collection_name=collection_name) # type: ignore
assert [r.id for r in res] == ["2", "3"] # type: ignore