mypy fixes

This commit is contained in:
Leonardo Pinheiro 2024-10-24 17:20:04 +10:00
parent 1d59e51616
commit 299a2eb866
3 changed files with 27 additions and 27 deletions

View File

@@ -58,9 +58,7 @@ class AsyncVectorDB(Protocol):
active_collection: Any = None active_collection: Any = None
type: str = "" type: str = ""
embedding_function: Optional[Callable[[List[str]], List[List[float]]]] = ( embedding_function: Optional[Callable[..., Any]] = None # embeddings = embedding_function(sentences)
None # embeddings = embedding_function(sentences)
)
async def create_collection(self, collection_name: str, overwrite: bool = False, get_or_create: bool = True) -> Any: async def create_collection(self, collection_name: str, overwrite: bool = False, get_or_create: bool = True) -> Any:
""" """

View File

@@ -1,6 +1,6 @@
import logging import logging
import os import os
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union, cast
from autogen_core.application.logging import TRACE_LOGGER_NAME from autogen_core.application.logging import TRACE_LOGGER_NAME
from chromadb import GetResult from chromadb import GetResult
@@ -79,10 +79,10 @@ class ChromaVectorDB(VectorDB):
) from e ) from e
self.IncludeEnum = IncludeEnum self.IncludeEnum = IncludeEnum
self.embedding_function: "EmbeddingFunction[Embeddable]" = ( # type: ignore self.embedding_function: "EmbeddingFunction[Any]" = ( # type: ignore
SentenceTransformerEmbeddingFunction("all-MiniLM-L6-v2") SentenceTransformerEmbeddingFunction("all-MiniLM-L6-v2") # type: ignore
if embedding_function is None if embedding_function is None
else cast("EmbeddingFunction[Embeddable]", embedding_function) else cast("EmbeddingFunction[Any]", embedding_function)
) )
self.metadata = metadata self.metadata = metadata
self.type = "chroma" self.type = "chroma"
@@ -337,8 +337,8 @@ class ChromaVectorDB(VectorDB):
collection = self.get_collection(collection_name) collection = self.get_collection(collection_name)
results = collection.get(ids=[str(id_) for id_ in ids] if ids else None, include=include_enums) results = collection.get(ids=[str(id_) for id_ in ids] if ids else None, include=include_enums)
results = _chroma_get_results_to_list_documents(results) results_list = _chroma_get_results_to_list_documents(results)
return results return results_list
class AsyncChromaVectorDB(AsyncVectorDB): class AsyncChromaVectorDB(AsyncVectorDB):
@@ -666,7 +666,9 @@ class AsyncChromaVectorDB(AsyncVectorDB):
return results_list return results_list
def _chroma_results_to_query_results(data_dict: ChromaQueryResult, special_key: str = "distances") -> QueryResults: def _chroma_results_to_query_results(
data_dict: ChromaQueryResult, special_key: Literal["distances"] = "distances"
) -> QueryResults:
"""Converts a ChromaDB query result into a list of lists of (Document, float) tuples. """Converts a ChromaDB query result into a list of lists of (Document, float) tuples.
Args: Args:
@@ -703,7 +705,7 @@ def _chroma_results_to_query_results(data_dict: ChromaQueryResult, special_key:
return [] return []
result: List[List[Tuple[Document, float]]] = [] result: List[List[Tuple[Document, float]]] = []
data_special_key: Any = data_dict[special_key] data_special_key: Optional[List[List[float]]] = data_dict[special_key]
if data_special_key is None: if data_special_key is None:
return result return result
@@ -711,15 +713,15 @@ def _chroma_results_to_query_results(data_dict: ChromaQueryResult, special_key:
for i in range(len(data_special_key)): for i in range(len(data_special_key)):
sub_result: List[Tuple[Document, float]] = [] sub_result: List[Tuple[Document, float]] = []
ids = data_dict["ids"][i] ids = data_dict["ids"][i]
documents = data_dict.get("documents") or [None] * len(ids) documents = data_dict.get("documents") or [None] * len(ids) # type: ignore
metadatas = data_dict.get("metadatas") or [None] * len(ids) metadatas = data_dict.get("metadatas") or [None] * len(ids) # type: ignore
embeddings = data_dict.get("embeddings") or [None] * len(ids) embeddings = data_dict.get("embeddings") or [None] * len(ids)
for j in range(len(data_special_key[i])): for j in range(len(data_special_key[i])):
document = Document( document = Document(
id=ids[j], id=ids[j],
content=cast(str, documents[j]), content=cast(Optional[str], documents[j]),
metadata=cast(Metadata, metadatas[j]), metadata=cast(Optional[Metadata], metadatas[j]),
embedding=cast(Vector, embeddings[j]), embedding=cast(Optional[Vector], embeddings[j]),
) )
value = data_special_key[i][j] value = data_special_key[i][j]
sub_result.append((document, value)) sub_result.append((document, value))
@@ -758,8 +760,8 @@ def _chroma_get_results_to_list_documents(data_dict: GetResult) -> List[Document
num_items = len(data_dict["ids"]) num_items = len(data_dict["ids"])
ids = data_dict["ids"] ids = data_dict["ids"]
documents = data_dict.get("documents") or [None] * num_items documents = data_dict.get("documents") or [None] * num_items # type: ignore
metadatas = data_dict.get("metadatas") or [None] * num_items metadatas = data_dict.get("metadatas") or [None] * num_items # type: ignore
embeddings = data_dict.get("embeddings") or [None] * num_items embeddings = data_dict.get("embeddings") or [None] * num_items
for i in range(num_items): for i in range(num_items):

View File

@@ -5,7 +5,7 @@ from chromadb.errors import ChromaError
# @pytest.mark.skipif(skip, reason="dependency is not installed") # @pytest.mark.skipif(skip, reason="dependency is not installed")
def test_chromadb(): def test_chromadb() -> None:
# test create collection # test create collection
db = ChromaVectorDB(path=".db") db = ChromaVectorDB(path=".db")
collection_name = "test_collection" collection_name = "test_collection"
@@ -53,13 +53,13 @@ def test_chromadb():
# test_retrieve_docs # test_retrieve_docs
queries = ["doc2", "doc3"] queries = ["doc2", "doc3"]
collection_name = "test_collection" collection_name = "test_collection"
res = db.retrieve_docs(queries, collection_name) res = db.retrieve_docs(queries, collection_name) # type: ignore
assert [[r[0].id for r in rr] for rr in res] == [["2", "3"], ["3", "2"]] assert [[r[0].id for r in rr] for rr in res] == [["2", "3"], ["3", "2"]] # type: ignore
res = db.retrieve_docs(queries, collection_name, distance_threshold=0.1) res = db.retrieve_docs(queries, collection_name, distance_threshold=0.1) # type: ignore
assert [[r[0].id for r in rr] for rr in res] == [["2"], ["3"]] assert [[r[0].id for r in rr] for rr in res] == [["2"], ["3"]] # type: ignore
# test_get_docs_by_ids # test_get_docs_by_ids
res = db.get_docs_by_ids(["1", "2"], collection_name) res = db.get_docs_by_ids(["1", "2"], collection_name) # type: ignore
assert [r.id for r in res] == ["2"] # "1" has been deleted assert [r.id for r in res] == ["2"] # type: ignore
res = db.get_docs_by_ids(collection_name=collection_name) res = db.get_docs_by_ids(collection_name=collection_name) # type: ignore
assert [r.id for r in res] == ["2", "3"] assert [r.id for r in res] == ["2", "3"] # type: ignore