mirror of https://github.com/microsoft/autogen.git
mypy fixes
This commit is contained in:
parent
1d59e51616
commit
299a2eb866
|
@ -58,9 +58,7 @@ class AsyncVectorDB(Protocol):
|
||||||
|
|
||||||
active_collection: Any = None
|
active_collection: Any = None
|
||||||
type: str = ""
|
type: str = ""
|
||||||
embedding_function: Optional[Callable[[List[str]], List[List[float]]]] = (
|
embedding_function: Optional[Callable[..., Any]] = None # embeddings = embedding_function(sentences)
|
||||||
None # embeddings = embedding_function(sentences)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def create_collection(self, collection_name: str, overwrite: bool = False, get_or_create: bool = True) -> Any:
|
async def create_collection(self, collection_name: str, overwrite: bool = False, get_or_create: bool = True) -> Any:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union, cast
|
||||||
|
|
||||||
from autogen_core.application.logging import TRACE_LOGGER_NAME
|
from autogen_core.application.logging import TRACE_LOGGER_NAME
|
||||||
from chromadb import GetResult
|
from chromadb import GetResult
|
||||||
|
@ -79,10 +79,10 @@ class ChromaVectorDB(VectorDB):
|
||||||
) from e
|
) from e
|
||||||
|
|
||||||
self.IncludeEnum = IncludeEnum
|
self.IncludeEnum = IncludeEnum
|
||||||
self.embedding_function: "EmbeddingFunction[Embeddable]" = ( # type: ignore
|
self.embedding_function: "EmbeddingFunction[Any]" = ( # type: ignore
|
||||||
SentenceTransformerEmbeddingFunction("all-MiniLM-L6-v2")
|
SentenceTransformerEmbeddingFunction("all-MiniLM-L6-v2") # type: ignore
|
||||||
if embedding_function is None
|
if embedding_function is None
|
||||||
else cast("EmbeddingFunction[Embeddable]", embedding_function)
|
else cast("EmbeddingFunction[Any]", embedding_function)
|
||||||
)
|
)
|
||||||
self.metadata = metadata
|
self.metadata = metadata
|
||||||
self.type = "chroma"
|
self.type = "chroma"
|
||||||
|
@ -337,8 +337,8 @@ class ChromaVectorDB(VectorDB):
|
||||||
collection = self.get_collection(collection_name)
|
collection = self.get_collection(collection_name)
|
||||||
|
|
||||||
results = collection.get(ids=[str(id_) for id_ in ids] if ids else None, include=include_enums)
|
results = collection.get(ids=[str(id_) for id_ in ids] if ids else None, include=include_enums)
|
||||||
results = _chroma_get_results_to_list_documents(results)
|
results_list = _chroma_get_results_to_list_documents(results)
|
||||||
return results
|
return results_list
|
||||||
|
|
||||||
|
|
||||||
class AsyncChromaVectorDB(AsyncVectorDB):
|
class AsyncChromaVectorDB(AsyncVectorDB):
|
||||||
|
@ -666,7 +666,9 @@ class AsyncChromaVectorDB(AsyncVectorDB):
|
||||||
return results_list
|
return results_list
|
||||||
|
|
||||||
|
|
||||||
def _chroma_results_to_query_results(data_dict: ChromaQueryResult, special_key: str = "distances") -> QueryResults:
|
def _chroma_results_to_query_results(
|
||||||
|
data_dict: ChromaQueryResult, special_key: Literal["distances"] = "distances"
|
||||||
|
) -> QueryResults:
|
||||||
"""Converts a ChromaDB query result into a list of lists of (Document, float) tuples.
|
"""Converts a ChromaDB query result into a list of lists of (Document, float) tuples.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -703,7 +705,7 @@ def _chroma_results_to_query_results(data_dict: ChromaQueryResult, special_key:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
result: List[List[Tuple[Document, float]]] = []
|
result: List[List[Tuple[Document, float]]] = []
|
||||||
data_special_key: Any = data_dict[special_key]
|
data_special_key: Optional[List[List[float]]] = data_dict[special_key]
|
||||||
|
|
||||||
if data_special_key is None:
|
if data_special_key is None:
|
||||||
return result
|
return result
|
||||||
|
@ -711,15 +713,15 @@ def _chroma_results_to_query_results(data_dict: ChromaQueryResult, special_key:
|
||||||
for i in range(len(data_special_key)):
|
for i in range(len(data_special_key)):
|
||||||
sub_result: List[Tuple[Document, float]] = []
|
sub_result: List[Tuple[Document, float]] = []
|
||||||
ids = data_dict["ids"][i]
|
ids = data_dict["ids"][i]
|
||||||
documents = data_dict.get("documents") or [None] * len(ids)
|
documents = data_dict.get("documents") or [None] * len(ids) # type: ignore
|
||||||
metadatas = data_dict.get("metadatas") or [None] * len(ids)
|
metadatas = data_dict.get("metadatas") or [None] * len(ids) # type: ignore
|
||||||
embeddings = data_dict.get("embeddings") or [None] * len(ids)
|
embeddings = data_dict.get("embeddings") or [None] * len(ids)
|
||||||
for j in range(len(data_special_key[i])):
|
for j in range(len(data_special_key[i])):
|
||||||
document = Document(
|
document = Document(
|
||||||
id=ids[j],
|
id=ids[j],
|
||||||
content=cast(str, documents[j]),
|
content=cast(Optional[str], documents[j]),
|
||||||
metadata=cast(Metadata, metadatas[j]),
|
metadata=cast(Optional[Metadata], metadatas[j]),
|
||||||
embedding=cast(Vector, embeddings[j]),
|
embedding=cast(Optional[Vector], embeddings[j]),
|
||||||
)
|
)
|
||||||
value = data_special_key[i][j]
|
value = data_special_key[i][j]
|
||||||
sub_result.append((document, value))
|
sub_result.append((document, value))
|
||||||
|
@ -758,8 +760,8 @@ def _chroma_get_results_to_list_documents(data_dict: GetResult) -> List[Document
|
||||||
|
|
||||||
num_items = len(data_dict["ids"])
|
num_items = len(data_dict["ids"])
|
||||||
ids = data_dict["ids"]
|
ids = data_dict["ids"]
|
||||||
documents = data_dict.get("documents") or [None] * num_items
|
documents = data_dict.get("documents") or [None] * num_items # type: ignore
|
||||||
metadatas = data_dict.get("metadatas") or [None] * num_items
|
metadatas = data_dict.get("metadatas") or [None] * num_items # type: ignore
|
||||||
embeddings = data_dict.get("embeddings") or [None] * num_items
|
embeddings = data_dict.get("embeddings") or [None] * num_items
|
||||||
|
|
||||||
for i in range(num_items):
|
for i in range(num_items):
|
||||||
|
|
|
@ -5,7 +5,7 @@ from chromadb.errors import ChromaError
|
||||||
|
|
||||||
|
|
||||||
# @pytest.mark.skipif(skip, reason="dependency is not installed")
|
# @pytest.mark.skipif(skip, reason="dependency is not installed")
|
||||||
def test_chromadb():
|
def test_chromadb() -> None:
|
||||||
# test create collection
|
# test create collection
|
||||||
db = ChromaVectorDB(path=".db")
|
db = ChromaVectorDB(path=".db")
|
||||||
collection_name = "test_collection"
|
collection_name = "test_collection"
|
||||||
|
@ -53,13 +53,13 @@ def test_chromadb():
|
||||||
# test_retrieve_docs
|
# test_retrieve_docs
|
||||||
queries = ["doc2", "doc3"]
|
queries = ["doc2", "doc3"]
|
||||||
collection_name = "test_collection"
|
collection_name = "test_collection"
|
||||||
res = db.retrieve_docs(queries, collection_name)
|
res = db.retrieve_docs(queries, collection_name) # type: ignore
|
||||||
assert [[r[0].id for r in rr] for rr in res] == [["2", "3"], ["3", "2"]]
|
assert [[r[0].id for r in rr] for rr in res] == [["2", "3"], ["3", "2"]] # type: ignore
|
||||||
res = db.retrieve_docs(queries, collection_name, distance_threshold=0.1)
|
res = db.retrieve_docs(queries, collection_name, distance_threshold=0.1) # type: ignore
|
||||||
assert [[r[0].id for r in rr] for rr in res] == [["2"], ["3"]]
|
assert [[r[0].id for r in rr] for rr in res] == [["2"], ["3"]] # type: ignore
|
||||||
|
|
||||||
# test_get_docs_by_ids
|
# test_get_docs_by_ids
|
||||||
res = db.get_docs_by_ids(["1", "2"], collection_name)
|
res = db.get_docs_by_ids(["1", "2"], collection_name) # type: ignore
|
||||||
assert [r.id for r in res] == ["2"] # "1" has been deleted
|
assert [r.id for r in res] == ["2"] # type: ignore
|
||||||
res = db.get_docs_by_ids(collection_name=collection_name)
|
res = db.get_docs_by_ids(collection_name=collection_name) # type: ignore
|
||||||
assert [r.id for r in res] == ["2", "3"]
|
assert [r.id for r in res] == ["2", "3"] # type: ignore
|
||||||
|
|
Loading…
Reference in New Issue