mirror of https://github.com/microsoft/autogen.git
mypy fixes
This commit is contained in:
parent
1d59e51616
commit
299a2eb866
|
@ -58,9 +58,7 @@ class AsyncVectorDB(Protocol):
|
|||
|
||||
active_collection: Any = None
|
||||
type: str = ""
|
||||
embedding_function: Optional[Callable[[List[str]], List[List[float]]]] = (
|
||||
None # embeddings = embedding_function(sentences)
|
||||
)
|
||||
embedding_function: Optional[Callable[..., Any]] = None # embeddings = embedding_function(sentences)
|
||||
|
||||
async def create_collection(self, collection_name: str, overwrite: bool = False, get_or_create: bool = True) -> Any:
|
||||
"""
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union, cast
|
||||
|
||||
from autogen_core.application.logging import TRACE_LOGGER_NAME
|
||||
from chromadb import GetResult
|
||||
|
@ -79,10 +79,10 @@ class ChromaVectorDB(VectorDB):
|
|||
) from e
|
||||
|
||||
self.IncludeEnum = IncludeEnum
|
||||
self.embedding_function: "EmbeddingFunction[Embeddable]" = ( # type: ignore
|
||||
SentenceTransformerEmbeddingFunction("all-MiniLM-L6-v2")
|
||||
self.embedding_function: "EmbeddingFunction[Any]" = ( # type: ignore
|
||||
SentenceTransformerEmbeddingFunction("all-MiniLM-L6-v2") # type: ignore
|
||||
if embedding_function is None
|
||||
else cast("EmbeddingFunction[Embeddable]", embedding_function)
|
||||
else cast("EmbeddingFunction[Any]", embedding_function)
|
||||
)
|
||||
self.metadata = metadata
|
||||
self.type = "chroma"
|
||||
|
@ -337,8 +337,8 @@ class ChromaVectorDB(VectorDB):
|
|||
collection = self.get_collection(collection_name)
|
||||
|
||||
results = collection.get(ids=[str(id_) for id_ in ids] if ids else None, include=include_enums)
|
||||
results = _chroma_get_results_to_list_documents(results)
|
||||
return results
|
||||
results_list = _chroma_get_results_to_list_documents(results)
|
||||
return results_list
|
||||
|
||||
|
||||
class AsyncChromaVectorDB(AsyncVectorDB):
|
||||
|
@ -666,7 +666,9 @@ class AsyncChromaVectorDB(AsyncVectorDB):
|
|||
return results_list
|
||||
|
||||
|
||||
def _chroma_results_to_query_results(data_dict: ChromaQueryResult, special_key: str = "distances") -> QueryResults:
|
||||
def _chroma_results_to_query_results(
|
||||
data_dict: ChromaQueryResult, special_key: Literal["distances"] = "distances"
|
||||
) -> QueryResults:
|
||||
"""Converts a ChromaDB query result into a list of lists of (Document, float) tuples.
|
||||
|
||||
Args:
|
||||
|
@ -703,7 +705,7 @@ def _chroma_results_to_query_results(data_dict: ChromaQueryResult, special_key:
|
|||
return []
|
||||
|
||||
result: List[List[Tuple[Document, float]]] = []
|
||||
data_special_key: Any = data_dict[special_key]
|
||||
data_special_key: Optional[List[List[float]]] = data_dict[special_key]
|
||||
|
||||
if data_special_key is None:
|
||||
return result
|
||||
|
@ -711,15 +713,15 @@ def _chroma_results_to_query_results(data_dict: ChromaQueryResult, special_key:
|
|||
for i in range(len(data_special_key)):
|
||||
sub_result: List[Tuple[Document, float]] = []
|
||||
ids = data_dict["ids"][i]
|
||||
documents = data_dict.get("documents") or [None] * len(ids)
|
||||
metadatas = data_dict.get("metadatas") or [None] * len(ids)
|
||||
documents = data_dict.get("documents") or [None] * len(ids) # type: ignore
|
||||
metadatas = data_dict.get("metadatas") or [None] * len(ids) # type: ignore
|
||||
embeddings = data_dict.get("embeddings") or [None] * len(ids)
|
||||
for j in range(len(data_special_key[i])):
|
||||
document = Document(
|
||||
id=ids[j],
|
||||
content=cast(str, documents[j]),
|
||||
metadata=cast(Metadata, metadatas[j]),
|
||||
embedding=cast(Vector, embeddings[j]),
|
||||
content=cast(Optional[str], documents[j]),
|
||||
metadata=cast(Optional[Metadata], metadatas[j]),
|
||||
embedding=cast(Optional[Vector], embeddings[j]),
|
||||
)
|
||||
value = data_special_key[i][j]
|
||||
sub_result.append((document, value))
|
||||
|
@ -758,8 +760,8 @@ def _chroma_get_results_to_list_documents(data_dict: GetResult) -> List[Document
|
|||
|
||||
num_items = len(data_dict["ids"])
|
||||
ids = data_dict["ids"]
|
||||
documents = data_dict.get("documents") or [None] * num_items
|
||||
metadatas = data_dict.get("metadatas") or [None] * num_items
|
||||
documents = data_dict.get("documents") or [None] * num_items # type: ignore
|
||||
metadatas = data_dict.get("metadatas") or [None] * num_items # type: ignore
|
||||
embeddings = data_dict.get("embeddings") or [None] * num_items
|
||||
|
||||
for i in range(num_items):
|
||||
|
|
|
@ -5,7 +5,7 @@ from chromadb.errors import ChromaError
|
|||
|
||||
|
||||
# @pytest.mark.skipif(skip, reason="dependency is not installed")
|
||||
def test_chromadb():
|
||||
def test_chromadb() -> None:
|
||||
# test create collection
|
||||
db = ChromaVectorDB(path=".db")
|
||||
collection_name = "test_collection"
|
||||
|
@ -53,13 +53,13 @@ def test_chromadb():
|
|||
# test_retrieve_docs
|
||||
queries = ["doc2", "doc3"]
|
||||
collection_name = "test_collection"
|
||||
res = db.retrieve_docs(queries, collection_name)
|
||||
assert [[r[0].id for r in rr] for rr in res] == [["2", "3"], ["3", "2"]]
|
||||
res = db.retrieve_docs(queries, collection_name, distance_threshold=0.1)
|
||||
assert [[r[0].id for r in rr] for rr in res] == [["2"], ["3"]]
|
||||
res = db.retrieve_docs(queries, collection_name) # type: ignore
|
||||
assert [[r[0].id for r in rr] for rr in res] == [["2", "3"], ["3", "2"]] # type: ignore
|
||||
res = db.retrieve_docs(queries, collection_name, distance_threshold=0.1) # type: ignore
|
||||
assert [[r[0].id for r in rr] for rr in res] == [["2"], ["3"]] # type: ignore
|
||||
|
||||
# test_get_docs_by_ids
|
||||
res = db.get_docs_by_ids(["1", "2"], collection_name)
|
||||
assert [r.id for r in res] == ["2"] # "1" has been deleted
|
||||
res = db.get_docs_by_ids(collection_name=collection_name)
|
||||
assert [r.id for r in res] == ["2", "3"]
|
||||
res = db.get_docs_by_ids(["1", "2"], collection_name) # type: ignore
|
||||
assert [r.id for r in res] == ["2"] # type: ignore
|
||||
res = db.get_docs_by_ids(collection_name=collection_name) # type: ignore
|
||||
assert [r.id for r in res] == ["2", "3"] # type: ignore
|
||||
|
|
Loading…
Reference in New Issue