Leonardo Pinheiro 2024-11-04 22:12:52 +00:00 committed by GitHub
commit 134fa25114
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 2259 additions and 10 deletions

View File

@@ -27,6 +27,7 @@ langchain = ["langchain_core~= 0.3.3"]
azure = ["azure-core", "azure-identity"]
docker = ["docker~=7.0"]
openai = ["openai>=1.3"]
chromadb = ["chromadb~=0.5.15", "sentence-transformers"]
[tool.hatch.build.targets.wheel]
packages = ["src/autogen_ext"]
@@ -56,4 +57,4 @@ test = "pytest -n auto"
[tool.mypy]
[[tool.mypy.overrides]]
module = "docker.*"
ignore_missing_imports = true
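
Because chromadb is an optional extra, the new storage module imports it lazily and raises a clear error when the extra is missing. A minimal sketch of that guard, paraphrasing the import block that appears in _chromadb.py further down:

# Sketch of the lazy-import guard used by the ChromaDB-backed classes below.
try:
    import chromadb  # available only when installed as autogen-ext[chromadb]
except ImportError as e:
    raise RuntimeError(
        "Missing dependencies for ChromaVectorDB. Please ensure the autogen-ext "
        "package was installed with the 'chromadb' extra."
    ) from e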

View File

@@ -0,0 +1,4 @@
from ._chromadb import AsyncChromaVectorDB, ChromaVectorDB
from ._factory import VectorDBFactory
__all__ = ["ChromaVectorDB", "AsyncChromaVectorDB", "VectorDBFactory"]

View File

@@ -0,0 +1,378 @@
from typing import (
Any,
Callable,
List,
Mapping,
Optional,
Protocol,
Sequence,
Tuple,
Union,
runtime_checkable,
)
from pydantic import BaseModel
Metadata = Union[Mapping[str, Any], None]
Vector = Union[Sequence[float], Sequence[int]]
ItemID = Union[str, int]
class Document(BaseModel):
"""Define Document according to autogen 0.4 specifications."""
id: ItemID
content: Optional[str] = None
metadata: Optional[Metadata] = None
embedding: Optional[Vector] = None
model_config = {"arbitrary_types_allowed": True}
"""QueryResults is the response from the vector database for a query/queries.
A query is a list containing one string while queries is a list containing multiple strings.
The response is a list of query results, each query result is a list of tuples containing the document and the distance.
"""
QueryResults = List[List[Tuple[Document, float]]]
@runtime_checkable
class AsyncVectorDB(Protocol):
"""
Abstract class for async vector database. A vector database is responsible for storing and retrieving documents.
Attributes:
active_collection: Any | The active collection in the vector database. Caching it makes get_collection faster. Default is None.
type: str | The type of the vector database, chroma, pgvector, etc. Default is "".
Methods:
create_collection: Callable[[str, bool, bool], Awaitable[Any]] | Create a collection in the vector database.
get_collection: Callable[[str], Awaitable[Any]] | Get the collection from the vector database.
delete_collection: Callable[[str], Awaitable[Any]] | Delete the collection from the vector database.
insert_docs: Callable[[List[Document], str, bool], Awaitable[None]] | Insert documents into the collection of the vector database.
update_docs: Callable[[List[Document], str], Awaitable[None]] | Update documents in the collection of the vector database.
delete_docs: Callable[[List[ItemID], str], Awaitable[None]] | Delete documents from the collection of the vector database.
retrieve_docs: Callable[[List[str], str, int, float], Awaitable[QueryResults]] | Retrieve documents from the collection of the vector database based on the queries.
get_docs_by_ids: Callable[[List[ItemID], str], Awaitable[List[Document]]] | Retrieve documents from the collection of the vector database based on the ids.
"""
active_collection: Any = None
type: str = ""
embedding_function: Optional[Callable[..., Any]] = None # embeddings = embedding_function(sentences)
async def create_collection(
self,
collection_name: str,
overwrite: bool = False,
get_or_create: bool = True,
**kwargs: Any,
) -> Any:
"""
Create a collection in the vector database.
Case 1. if the collection does not exist, create the collection.
Case 2. the collection exists, if overwrite is True, it will overwrite the collection.
Case 3. the collection exists and overwrite is False: if get_or_create is True, it will get the collection,
otherwise it raises a ValueError.
Args:
collection_name: str | The name of the collection.
overwrite: bool | Whether to overwrite the collection if it exists. Default is False.
get_or_create: bool | Whether to get the collection if it exists. Default is True.
kwargs: Dict[str, Any] | Additional keyword arguments for collection creation (e.g. schema).
Returns:
Any | The collection object.
"""
...
async def get_collection(self, collection_name: Optional[str] = None) -> Any:
"""
Get the collection from the vector database.
Args:
collection_name: Optional[str] | The name of the collection. Default is None.
If None, return the current active collection.
Returns:
Any | The collection object.
"""
...
async def delete_collection(self, collection_name: str) -> Any:
"""
Delete the collection from the vector database.
Args:
collection_name: str | The name of the collection.
Returns:
Any
"""
...
async def insert_docs(
self,
docs: List[Document],
collection_name: Optional[str] = None,
upsert: bool = False,
**kwargs: Any,
) -> None:
"""
Insert documents into the collection of the vector database.
Args:
docs: List[Document] | A list of documents. Each document is a Pydantic Document model.
collection_name: Optional[str] | The name of the collection. Default is None.
upsert: bool | Whether to update the document if it exists. Default is False.
kwargs: Dict[str, Any] | Additional keyword arguments.
Returns:
None
"""
...
async def update_docs(self, docs: List[Document], collection_name: Optional[str] = None, **kwargs: Any) -> None:
"""
Update documents in the collection of the vector database.
Args:
docs: List[Document] | A list of documents.
collection_name: Optional[str] | The name of the collection. Default is None.
kwargs: Dict[str, Any] | Additional keyword arguments.
Returns:
None
"""
...
async def delete_docs(self, ids: List[ItemID], collection_name: Optional[str] = None, **kwargs: Any) -> None:
"""
Delete documents from the collection of the vector database.
Args:
ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`.
collection_name: Optional[str] | The name of the collection. Default is None.
kwargs: Dict[str, Any] | Additional keyword arguments.
Returns:
None
"""
...
async def retrieve_docs(
self,
queries: List[str],
collection_name: Optional[str] = None,
n_results: int = 10,
distance_threshold: float = -1,
**kwargs: Any,
) -> QueryResults:
"""
Retrieve documents from the collection of the vector database based on the queries.
Args:
queries: List[str] | A list of queries. Each query is a string.
collection_name: Optional[str] | The name of the collection. Default is None.
n_results: int | The number of relevant documents to return. Default is 10.
distance_threshold: float | The threshold for the distance score; only results with a distance smaller than it
will be returned. No filtering is applied if it is negative. Default is -1.
kwargs: Dict[str, Any] | Additional keyword arguments.
Returns:
QueryResults | The query results. Each query result is a list of list of tuples containing the document and
the distance.
"""
...
async def get_docs_by_ids(
self,
ids: Optional[List[ItemID]] = None,
collection_name: Optional[str] = None,
include: Optional[List[str]] = None,
**kwargs: Any,
) -> List[Document]:
"""
Retrieve documents from the collection of the vector database based on the ids.
Args:
ids: Optional[List[ItemID]] | A list of document ids. If None, will return all the documents. Default is None.
collection_name: Optional[str] | The name of the collection. Default is None.
include: Optional[List[str]] | The fields to include. Default is None.
If None, ["metadatas", "documents"] will be included; ids are always included. This may differ
depending on the implementation.
kwargs: Dict[str, Any] | Additional keyword arguments.
Returns:
List[Document] | The results.
"""
...
@runtime_checkable
class VectorDB(Protocol):
"""
Abstract class for synchronous vector database. A vector database is responsible for storing and retrieving documents.
For async support, use AsyncVectorDB instead.
Attributes:
active_collection: Any | The active collection in the vector database. Caching it makes get_collection faster. Default is None.
type: str | The type of the vector database, chroma, pgvector, etc. Default is "".
Methods:
create_collection: Callable[[str, bool, bool], Any] | Create a collection in the vector database.
get_collection: Callable[[str], Any] | Get the collection from the vector database.
delete_collection: Callable[[str], Any] | Delete the collection from the vector database.
insert_docs: Callable[[List[Document], str, bool], None] | Insert documents into the collection of the vector database.
update_docs: Callable[[List[Document], str], None] | Update documents in the collection of the vector database.
delete_docs: Callable[[List[ItemID], str], None] | Delete documents from the collection of the vector database.
retrieve_docs: Callable[[List[str], str, int, float], QueryResults] | Retrieve documents from the collection of the vector database based on the queries.
get_docs_by_ids: Callable[[List[ItemID], str], List[Document]] | Retrieve documents from the collection of the vector database based on the ids.
"""
active_collection: Any = None
type: str = ""
embedding_function: Optional[Callable[[List[str]], List[List[float]]]] = (
None # embeddings = embedding_function(sentences)
)
def create_collection(
self, collection_name: str, overwrite: bool = False, get_or_create: bool = True, **kwargs: Any
) -> Any:
"""
Create a collection in the vector database.
Case 1. if the collection does not exist, create the collection.
Case 2. the collection exists, if overwrite is True, it will overwrite the collection.
Case 3. the collection exists and overwrite is False: if get_or_create is True, it will get the collection,
otherwise it raises a ValueError.
Args:
collection_name: str | The name of the collection.
overwrite: bool | Whether to overwrite the collection if it exists. Default is False.
get_or_create: bool | Whether to get the collection if it exists. Default is True.
Returns:
Any | The collection object.
"""
...
def get_collection(self, collection_name: Optional[str] = None) -> Any:
"""
Get the collection from the vector database.
Args:
collection_name: Optional[str] | The name of the collection. Default is None.
If None, return the current active collection.
Returns:
Any | The collection object.
"""
...
def delete_collection(self, collection_name: str) -> Any:
"""
Delete the collection from the vector database.
Args:
collection_name: str | The name of the collection.
Returns:
Any
"""
...
def insert_docs(
self,
docs: List[Document],
collection_name: Optional[str] = None,
upsert: bool = False,
**kwargs: Any,
) -> None:
"""
Insert documents into the collection of the vector database.
Args:
docs: List[Document] | A list of documents. Each document is a Pydantic Document model.
collection_name: Optional[str] | The name of the collection. Default is None.
upsert: bool | Whether to update the document if it exists. Default is False.
kwargs: Dict[str, Any] | Additional keyword arguments.
Returns:
None
"""
...
def update_docs(self, docs: List[Document], collection_name: Optional[str] = None, **kwargs: Any) -> None:
"""
Update documents in the collection of the vector database.
Args:
docs: List[Document] | A list of documents.
collection_name: Optional[str] | The name of the collection. Default is None.
kwargs: Dict[str, Any] | Additional keyword arguments.
Returns:
None
"""
...
def delete_docs(self, ids: List[ItemID], collection_name: Optional[str] = None, **kwargs: Any) -> None:
"""
Delete documents from the collection of the vector database.
Args:
ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`.
collection_name: Optional[str] | The name of the collection. Default is None.
kwargs: Dict[str, Any] | Additional keyword arguments.
Returns:
None
"""
...
def retrieve_docs(
self,
queries: List[str],
collection_name: Optional[str] = None,
n_results: int = 10,
distance_threshold: float = -1,
**kwargs: Any,
) -> QueryResults:
"""
Retrieve documents from the collection of the vector database based on the queries.
Args:
queries: List[str] | A list of queries. Each query is a string.
collection_name: Optional[str] | The name of the collection. Default is None.
n_results: int | The number of relevant documents to return. Default is 10.
distance_threshold: float | The threshold for the distance score; only results with a distance smaller than it
will be returned. No filtering is applied if it is negative. Default is -1.
kwargs: Dict[str, Any] | Additional keyword arguments.
Returns:
QueryResults | The query results. Each query result is a list of list of tuples containing the document and
the distance.
"""
...
def get_docs_by_ids(
self,
ids: Optional[List[ItemID]] = None,
collection_name: Optional[str] = None,
include: Optional[List[str]] = None,
**kwargs: Any,
) -> List[Document]:
"""
Retrieve documents from the collection of the vector database based on the ids.
Args:
ids: Optional[List[ItemID]] | A list of document ids. If None, will return all the documents. Default is None.
collection_name: Optional[str] | The name of the collection. Default is None.
include: Optional[List[str]] | The fields to include. Default is None.
If None, will include ["metadatas", "documents"], ids will always be included. This may differ
depending on the implementation.
kwargs: Dict[str, Any] | Additional keyword arguments.
Returns:
List[Document] | The results.
"""
...
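
Before the ChromaDB implementation, a brief illustrative sketch (not part of the diff) of how these base types compose. It assumes the internal module path autogen_ext.storage._base, the same path the tests later in this change import Document from:

# Illustrative only: Document and the QueryResults alias from the protocol module.
from autogen_ext.storage._base import Document, QueryResults

# A Document needs only an id; content, metadata and embedding are optional.
doc = Document(id="1", content="hello world", metadata={"source": "example"})

# QueryResults is a plain type alias: one inner list per query, and each entry
# is a (Document, distance) tuple as returned by retrieve_docs.
results: QueryResults = [[(doc, 0.12)]]
print(results[0][0][0].content, results[0][0][1])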

View File

@@ -0,0 +1,784 @@
import logging
import os
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union, cast
from autogen_core.application.logging import TRACE_LOGGER_NAME
from chromadb import GetResult
from chromadb import QueryResult as ChromaQueryResult
if TYPE_CHECKING:
from chromadb.api import AsyncClientAPI, ClientAPI
from chromadb.api.models.Collection import Collection
from chromadb.api.types import Embeddable, EmbeddingFunction
from chromadb.config import Settings
from ._base import AsyncVectorDB, Document, ItemID, Metadata, QueryResults, Vector, VectorDB
CHROMADB_MAX_BATCH_SIZE = int(os.environ.get("CHROMADB_MAX_BATCH_SIZE", 40000))
logger = logging.getLogger(f"{TRACE_LOGGER_NAME}.{__name__}")
class ChromaVectorDB(VectorDB):
"""
A vector database that uses ChromaDB as the backend.
.. note::
This class requires the :code:`chromadb` extra for the :code:`autogen-ext` package.
"""
ChromaError = Exception # Default to Exception if chromadb is not installed
def __init__(
self,
*,
client: Optional["ClientAPI"] = None,
path: Optional[str] = None,
embedding_function: Optional[
Union[Callable[[List[str]], List[List[float]]], "EmbeddingFunction[Embeddable]"]
] = None,
metadata: Optional[Dict[str, Any]] = None,
client_type: str = "persistent",
host: str = "localhost",
port: int = 8000,
**kwargs: Any,
) -> None:
"""
Initialize the vector database.
Args:
client: chromadb.Client | The client object of the vector database. Default is None.
If provided, it will use the client object directly and ignore other arguments.
path: Optional[str] | The path to the vector database. Required if client_type is 'persistent'.
embedding_function: Optional[Union[Callable, EmbeddingFunction]] | The embedding function used to generate the vector representation
of the documents. Default is None, in which case SentenceTransformerEmbeddingFunction("all-MiniLM-L6-v2") will be used.
metadata: dict | The metadata of the vector database. Default is None.
client_type: str | The type of client to use. Can be 'persistent' or 'http'. Default is 'persistent'.
host: str | The host of the HTTP server. Default is 'localhost'.
port: int | The port of the HTTP server. Default is 8000.
kwargs: dict | Additional keyword arguments.
Returns:
None
"""
try:
import chromadb
if chromadb.__version__ < "0.5.0":
raise ImportError("Please upgrade chromadb to version 0.5.0 or later.")
from chromadb.api.types import IncludeEnum
from chromadb.errors import ChromaError
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
SentenceTransformerEmbeddingFunction,
)
ChromaVectorDB.ChromaError = ChromaError # Set the class attribute
except ImportError as e:
raise RuntimeError(
"Missing dependencies for ChromaVectorDB. Please ensure the autogen-ext package was installed with the 'chromadb' extra."
) from e
self.IncludeEnum = IncludeEnum
self.embedding_function: "EmbeddingFunction[Any]" = ( # type: ignore
SentenceTransformerEmbeddingFunction("all-MiniLM-L6-v2") # type: ignore
if embedding_function is None
else cast("EmbeddingFunction[Any]", embedding_function)
)
self.metadata = metadata
self.type = "chroma"
if client is not None:
self.client: "ClientAPI" = client
else:
if client_type == "persistent":
if path is None:
raise ValueError("Persistent client requires a 'path' to save the database.")
self.client = chromadb.PersistentClient(path=path, **kwargs)
elif client_type == "http":
self.client = chromadb.HttpClient(host=host, port=port, **kwargs)
else:
raise ValueError(f"Invalid client_type: {client_type}")
self.active_collection: Optional["Collection"] = None
def create_collection(
self, collection_name: str, overwrite: bool = False, get_or_create: bool = True, **kwargs: Any
) -> "Collection":
"""Create a collection in the vector database.
Case 1. if the collection does not exist, create the collection.
Case 2. the collection exists, if overwrite is True, it will overwrite the collection.
Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection, otherwise it raises a ValueError.
Args:
collection_name: str | The name of the collection.
overwrite: bool | Whether to overwrite the collection if it exists. Default is False.
get_or_create: bool | Whether to get the collection if it exists. Default is True.
Returns:
Collection | The collection object.
"""
try:
if self.active_collection and self.active_collection.name == collection_name:
collection = self.active_collection
else:
collection = self.client.get_collection(
name=collection_name, embedding_function=self.embedding_function
)
except (ValueError, ChromaVectorDB.ChromaError):
collection = None
if collection is None:
return self.client.create_collection(
name=collection_name,
embedding_function=kwargs.pop("embedding_function", self.embedding_function),
metadata=kwargs.pop("metadata", self.metadata),
data_loader=kwargs.pop("data_loader", None),
)
elif overwrite:
self.client.delete_collection(name=collection_name)
return self.client.create_collection(
name=collection_name,
embedding_function=kwargs.pop("embedding_function", self.embedding_function),
metadata=kwargs.pop("metadata", self.metadata),
data_loader=kwargs.pop("data_loader", None),
)
elif get_or_create:
return collection
else:
raise ValueError(f"Collection {collection_name} already exists.")
def get_collection(self, collection_name: Optional[str] = None) -> "Collection":
"""Get the collection from the vector database.
Args:
collection_name: Optional[str] | The name of the collection. Default is None.
If None, return the current active collection.
Returns:
Collection | The collection object.
"""
if collection_name is None:
if self.active_collection is None:
raise ValueError("No collection is specified.")
else:
logger.info(
f"No collection is specified. Using current active collection {self.active_collection.name}."
)
else:
if not (self.active_collection and self.active_collection.name == collection_name):
self.active_collection = self.client.get_collection(
name=collection_name, embedding_function=self.embedding_function
)
return self.active_collection
def delete_collection(self, collection_name: str) -> None:
"""
Delete the collection from the vector database.
Args:
collection_name: str | The name of the collection.
Returns:
None
"""
self.client.delete_collection(name=collection_name)
if self.active_collection and self.active_collection.name == collection_name:
self.active_collection = None
def _batch_insert(
self,
collection: "Collection",
embeddings: Optional[List[Any]] = None,
ids: Optional[List[str]] = None,
metadatas: Optional[List[Dict[str, Any]]] = None,
documents: Optional[List[str]] = None,
upsert: bool = False,
) -> None:
batch_size = CHROMADB_MAX_BATCH_SIZE
for i in range(0, len(ids or []), batch_size):
end_idx = i + batch_size
collection_kwargs = {
"documents": documents[i:end_idx] if documents else None,
"ids": ids[i:end_idx] if ids else None,
"metadatas": metadatas[i:end_idx] if metadatas else None,
"embeddings": embeddings[i:end_idx] if embeddings else None,
}
if upsert:
collection.upsert(**collection_kwargs) # type: ignore
else:
collection.add(**collection_kwargs) # type: ignore
def insert_docs(
self,
docs: List[Document],
collection_name: Optional[str] = None,
upsert: bool = False,
**kwargs: Any,
) -> None:
"""
Insert documents into the collection of the vector database.
Args:
docs: List[Document] | A list of documents. Each document is a Pydantic Document model.
collection_name: Optional[str] | The name of the collection. Default is None.
upsert: bool | Whether to update the document if it exists. Default is False.
kwargs: Dict[str, Any] | Additional keyword arguments.
Returns:
None
"""
if not docs:
return
if docs[0].content is None and docs[0].embedding is None:
raise ValueError("Either document content or embedding is required.")
documents = [doc.content for doc in docs] if docs[0].content else None
ids = [str(doc.id) for doc in docs]
collection = self.get_collection(collection_name)
embeddings = [doc.embedding for doc in docs] if docs[0].embedding else None
if not embeddings and not documents:
raise ValueError("Either documents or embeddings must be provided.")
metadatas = [doc.metadata for doc in docs] if docs[0].metadata else None
self._batch_insert(
collection,
embeddings=embeddings,
ids=ids,
metadatas=metadatas, # type: ignore
documents=documents, # type: ignore
upsert=upsert,
)
def update_docs(self, docs: List[Document], collection_name: Optional[str] = None, **kwargs: Any) -> None:
"""
Update documents in the collection of the vector database.
Args:
docs: List[Document] | A list of documents.
collection_name: Optional[str] | The name of the collection. Default is None.
kwargs: Dict[str, Any] | Additional keyword arguments.
Returns:
None
"""
self.insert_docs(docs, collection_name=collection_name, upsert=True, **kwargs)
def delete_docs(self, ids: Sequence[ItemID], collection_name: Optional[str] = None, **kwargs: Any) -> None:
"""
Delete documents from the collection of the vector database.
Args:
ids: Sequence[ItemID] | A list of document ids. Each id is a typed `ItemID`.
collection_name: Optional[str] | The name of the collection. Default is None.
kwargs: Dict[str, Any] | Additional keyword arguments.
Returns:
None
"""
collection = self.get_collection(collection_name)
collection.delete(ids=[str(id_) for id_ in ids] if ids else None)
def retrieve_docs(
self,
queries: List[str],
collection_name: Optional[str] = None,
n_results: int = 10,
distance_threshold: float = -1,
**kwargs: Any,
) -> QueryResults:
"""
Retrieve documents from the collection of the vector database based on the queries.
Args:
queries: List[str] | A list of queries. Each query is a string.
collection_name: Optional[str] | The name of the collection. Default is None.
n_results: int | The number of relevant documents to return. Default is 10.
distance_threshold: float | The threshold for the distance score; only results with a distance smaller than it
will be returned. No filtering is applied if it is negative. Default is -1.
kwargs: Dict[str, Any] | Additional keyword arguments.
Returns:
QueryResults | The query results. Each query result is a list of list of tuples containing the document and
the distance.
"""
collection = self.get_collection(collection_name)
if isinstance(queries, str):
queries = [queries]
results = collection.query(
query_texts=queries,
n_results=n_results,
)
results_list = _chroma_results_to_query_results(results)
results_filtered = _filter_results_by_distance(results_list, distance_threshold)
return results_filtered
def get_docs_by_ids(
self,
ids: Optional[Sequence[ItemID]] = None,
collection_name: Optional[str] = None,
include: Optional[List[str]] = None,
**kwargs: Any,
) -> List[Document]:
"""
Retrieve documents from the collection of the vector database based on the ids.
Args:
ids: Optional[Sequence[ItemID]] | A list of document ids. If None, will return all the documents. Default is None.
collection_name: Optional[str] | The name of the collection. Default is None.
include: Optional[List[str]] | The fields to include, given as strings that are converted to IncludeEnum values. Default is None.
If None, will include [IncludeEnum.metadatas, IncludeEnum.documents]. IDs are always included.
kwargs: Dict[str, Any] | Additional keyword arguments.
Returns:
List[Document] | The results.
"""
if include is not None:
include_enums = [self.IncludeEnum(item) for item in include]
else:
include_enums = [self.IncludeEnum.metadatas, self.IncludeEnum.documents]
collection = self.get_collection(collection_name)
results = collection.get(ids=[str(id_) for id_ in ids] if ids else None, include=include_enums)
results_list = _chroma_get_results_to_list_documents(results)
return results_list
class AsyncChromaVectorDB(AsyncVectorDB):
"""
An asynchronous vector database that uses ChromaDB as the backend.
.. note::
This class requires the :code:`chromadb` extra for the :code:`autogen-ext` package.
"""
ChromaError = Exception # Default to Exception if chromadb is not installed
def __init__(
self,
*,
client: Optional["AsyncClientAPI"] = None,
embedding_function: Optional[
Union[Callable[[List[str]], List[List[float]]], "EmbeddingFunction[Embeddable]"]
] = None,
host: str = "localhost",
port: int = 8000,
ssl: bool = False,
headers: Optional[Dict[str, str]] = None,
settings: Optional["Settings"] = None,
tenant: str = "default_tenant",
database: str = "default_database",
**kwargs: Any,
) -> None:
"""
Initialize the async vector database.
Args:
client: chromadb.AsyncClientAPI | The client object of the vector database. Default is None.
If provided, it will use the client object directly and ignore other arguments.
embedding_function: Callable | The embedding function used to generate the vector representation
of the documents. Default is None, in which case SentenceTransformerEmbeddingFunction("all-MiniLM-L6-v2") will be used.
host: str | The host of the HTTP server. Default is 'localhost'.
port: int | The port of the HTTP server. Default is 8000.
ssl: bool | Whether to use SSL to connect to the Chroma server. Defaults to False.
headers: Optional[Dict[str, str]] | A dictionary of headers to send to the Chroma server. Defaults to None.
settings: Optional[Settings] | A Settings object used to configure communication with the Chroma server.
tenant: str | The tenant to use for this client. Defaults to "default_tenant".
database: str | The database to use for this client. Defaults to "default_database".
kwargs: dict | Additional keyword arguments.
Returns:
None
"""
try:
import chromadb
if chromadb.__version__ < "0.5.0":
raise ImportError("Please upgrade chromadb to version 0.5.0 or later.")
from chromadb.api.types import IncludeEnum
from chromadb.errors import ChromaError
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
SentenceTransformerEmbeddingFunction,
)
AsyncChromaVectorDB.ChromaError = ChromaError # Set the class attribute
except ImportError as e:
raise RuntimeError(
"Missing dependencies for AsyncChromaVectorDB. Please ensure the autogen-ext package was installed with the 'chromadb' extra."
) from e
self.IncludeEnum = IncludeEnum
self.embedding_function: "EmbeddingFunction[Embeddable]" = ( # type: ignore
cast(
"EmbeddingFunction[Embeddable]",
embedding_function or SentenceTransformerEmbeddingFunction("all-MiniLM-L6-v2"),
)
)
self.type = "chroma"
if client is not None:
self.client: "AsyncClientAPI" = client
else:
self.client = chromadb.AsyncHttpClient( # type: ignore
host=host,
port=port,
ssl=ssl,
headers=headers,
settings=settings,
tenant=tenant,
database=database,
**kwargs,
)
self.active_collection: Optional[Any] = None
async def create_collection(
self, collection_name: str, overwrite: bool = False, get_or_create: bool = True, **kwargs: Any
) -> Any:
"""Create a collection in the vector database.
Case 1. if the collection does not exist, create the collection.
Case 2. the collection exists, if overwrite is True, it will overwrite the collection.
Case 3. the collection exists and overwrite is False, if get_or_create is True, it will get the collection, otherwise it raises a ValueError.
Args:
collection_name: str | The name of the collection.
overwrite: bool | Whether to overwrite the collection if it exists. Default is False.
get_or_create: bool | Whether to get the collection if it exists. Default is True.
Returns:
Any | The collection object.
"""
try:
if self.active_collection and self.active_collection.name == collection_name:
collection = self.active_collection
else:
collection = await self.client.get_collection(
name=collection_name, embedding_function=self.embedding_function
)
except (ValueError, AsyncChromaVectorDB.ChromaError):
collection = None
if collection is None:
return await self.client.create_collection(
name=collection_name,
embedding_function=kwargs.pop("embedding_function", self.embedding_function),
metadata=kwargs.pop("metadata", {}),
data_loader=kwargs.pop("data_loader", None),
)
elif overwrite:
await self.client.delete_collection(name=collection_name)
return await self.client.create_collection(
name=collection_name,
embedding_function=kwargs.pop("embedding_function", self.embedding_function),
metadata=kwargs.pop("metadata", {}),
data_loader=kwargs.pop("data_loader", None),
)
elif get_or_create:
return collection
else:
raise ValueError(f"Collection {collection_name} already exists.")
async def get_collection(self, collection_name: Optional[str] = None) -> Any:
"""
Get the collection from the vector database.
Args:
collection_name: Optional[str] | The name of the collection. Default is None.
If None, return the current active collection.
Returns:
Any | The collection object.
"""
if collection_name is None:
if self.active_collection is None:
raise ValueError("No collection is specified.")
else:
logger.info(
f"No collection is specified. Using current active collection {self.active_collection.name}."
)
else:
if not (self.active_collection and self.active_collection.name == collection_name):
self.active_collection = await self.client.get_collection(
name=collection_name, embedding_function=self.embedding_function
)
return self.active_collection
async def delete_collection(self, collection_name: str) -> Any:
"""
Delete the collection from the vector database.
Args:
collection_name: str | The name of the collection.
Returns:
Any
"""
await self.client.delete_collection(name=collection_name)
if self.active_collection and self.active_collection.name == collection_name:
self.active_collection = None
async def _batch_insert(
self,
collection: Any,
embeddings: Optional[List[Any]] = None,
ids: Optional[List[str]] = None,
metadatas: Optional[List[Dict[str, Any]]] = None,
documents: Optional[List[str]] = None,
upsert: bool = False,
) -> None:
batch_size = CHROMADB_MAX_BATCH_SIZE
for i in range(0, len(ids or []), batch_size):
end_idx = i + batch_size
collection_kwargs = {
"documents": documents[i:end_idx] if documents else None,
"ids": ids[i:end_idx] if ids else None,
"metadatas": metadatas[i:end_idx] if metadatas else None,
"embeddings": embeddings[i:end_idx] if embeddings else None,
}
if upsert:
await collection.upsert(**collection_kwargs)
else:
await collection.add(**collection_kwargs)
async def insert_docs(
self,
docs: List[Document],
collection_name: Optional[str] = None,
upsert: bool = False,
**kwargs: Any,
) -> None:
"""
Insert documents into the collection of the vector database.
Args:
docs: List[Document] | A list of documents. Each document is a Pydantic Document model.
collection_name: Optional[str] | The name of the collection. Default is None.
upsert: bool | Whether to update the document if it exists. Default is False.
kwargs: Dict[str, Any] | Additional keyword arguments.
Returns:
None
"""
if not docs:
return
if docs[0].content is None and docs[0].embedding is None:
raise ValueError("Either document content or embedding is required.")
documents = [doc.content for doc in docs] if docs[0].content else None
ids = [str(doc.id) for doc in docs]
collection = await self.get_collection(collection_name)
embeddings = [doc.embedding for doc in docs] if docs[0].embedding else None
if not embeddings and not documents:
raise ValueError("Either documents or embeddings must be provided.")
metadatas = [doc.metadata for doc in docs] if docs[0].metadata else None
await self._batch_insert(
collection,
embeddings=embeddings,
ids=ids,
metadatas=metadatas, # type: ignore
documents=documents, # type: ignore
upsert=upsert,
)
async def update_docs(self, docs: List[Document], collection_name: Optional[str] = None, **kwargs: Any) -> None:
"""
Update documents in the collection of the vector database.
Args:
docs: List[Document] | A list of documents.
collection_name: Optional[str] | The name of the collection. Default is None.
kwargs: Dict[str, Any] | Additional keyword arguments.
Returns:
None
"""
await self.insert_docs(docs, collection_name=collection_name, upsert=True, **kwargs)
async def delete_docs(self, ids: Sequence[ItemID], collection_name: Optional[str] = None, **kwargs: Any) -> None:
"""
Delete documents from the collection of the vector database.
Args:
ids: Sequence[ItemID] | A list of document ids. Each id is a typed `ItemID`.
collection_name: Optional[str] | The name of the collection. Default is None.
kwargs: Dict[str, Any] | Additional keyword arguments.
Returns:
None
"""
collection = await self.get_collection(collection_name)
await collection.delete(ids=ids)
async def retrieve_docs(
self,
queries: List[str],
collection_name: Optional[str] = None,
n_results: int = 10,
distance_threshold: float = -1,
**kwargs: Any,
) -> QueryResults:
"""
Retrieve documents from the collection of the vector database based on the queries.
Args:
queries: List[str] | A list of queries. Each query is a string.
collection_name: Optional[str] | The name of the collection. Default is None.
n_results: int | The number of relevant documents to return. Default is 10.
distance_threshold: float | The threshold for the distance score; only results with a distance smaller than it
will be returned. No filtering is applied if it is negative. Default is -1.
kwargs: Dict[str, Any] | Additional keyword arguments.
Returns:
QueryResults | The query results. Each query result is a list of list of tuples containing the document and
the distance.
"""
collection = await self.get_collection(collection_name)
if isinstance(queries, str):
queries = [queries]
results = await collection.query(
query_texts=queries,
n_results=n_results,
)
results_list = _chroma_results_to_query_results(results)
results_filtered = _filter_results_by_distance(results_list, distance_threshold)
return results_filtered
async def get_docs_by_ids(
self,
ids: Optional[Sequence[ItemID]] = None,
collection_name: Optional[str] = None,
include: Optional[List[str]] = None,
**kwargs: Any,
) -> List[Document]:
"""
Retrieve documents from the collection of the vector database based on the ids.
Args:
ids: Optional[Sequence[ItemID]] | A list of document ids. If None, will return all the documents. Default is None.
collection_name: Optional[str] | The name of the collection. Default is None.
include: Optional[List[str]] | The fields to include, given as strings that are converted to IncludeEnum values. Default is None.
If None, will include [IncludeEnum.metadatas, IncludeEnum.documents]. IDs are always included.
kwargs: Dict[str, Any] | Additional keyword arguments.
Returns:
List[Document] | The results.
"""
collection = await self.get_collection(collection_name)
if include is not None:
include_enums = [self.IncludeEnum(item) for item in include]
else:
include_enums = [self.IncludeEnum.metadatas, self.IncludeEnum.documents]
results: GetResult = await collection.get(ids=ids, include=include_enums)
results_list = _chroma_get_results_to_list_documents(results)
return results_list
def _chroma_results_to_query_results(
data_dict: ChromaQueryResult, special_key: Literal["distances"] = "distances"
) -> QueryResults:
"""Converts a ChromaDB query result into a list of lists of (Document, float) tuples.
Args:
data_dict: A dictionary containing the results of a ChromaDB query.
special_key: The key in the dictionary containing the float values for each tuple (default is "distances").
Returns:
A list of lists, where each sublist corresponds to a query and contains tuples of (Document, float).
Example:
data_dict = {
'ids': [['1', '2'], ['3', '4']],
'documents': [['doc1', 'doc2'], ['doc3', 'doc4']],
'metadatas': [[{'meta': 'data1'}, {'meta': 'data2'}], [{'meta': 'data3'}, {'meta': 'data4'}]],
'distances': [[0.1, 0.2], [0.3, 0.4]],
}
results = _chroma_results_to_query_results(data_dict)
# results will be:
# [
# [
# (Document(id='1', content='doc1', metadata={'meta': 'data1'}, embedding=None), 0.1),
# (Document(id='2', content='doc2', metadata={'meta': 'data2'}, embedding=None), 0.2),
# ],
# [
# (Document(id='3', content='doc3', metadata={'meta': 'data3'}, embedding=None), 0.3),
# (Document(id='4', content='doc4', metadata={'meta': 'data4'}, embedding=None), 0.4),
# ],
# ]
"""
if not data_dict or special_key not in data_dict or not data_dict.get(special_key):
return []
result: List[List[Tuple[Document, float]]] = []
data_special_key: Optional[List[List[float]]] = data_dict[special_key]
if data_special_key is None:
return result
for i in range(len(data_special_key)):
sub_result: List[Tuple[Document, float]] = []
ids_i = data_dict["ids"][i]
documents_list = data_dict.get("documents")
documents_i = documents_list[i] if documents_list else [None] * len(ids_i) # type: ignore
metadatas_list = data_dict.get("metadatas")
metadatas_i = metadatas_list[i] if metadatas_list else [None] * len(ids_i) # type: ignore
embeddings_list = data_dict.get("embeddings")
embeddings_i = embeddings_list[i] if embeddings_list else [None] * len(ids_i)
for j in range(len(data_special_key[i])):
document = Document(
id=ids_i[j],
content=documents_i[j],
metadata=cast(Optional[Metadata], metadatas_i[j]),
embedding=cast(Optional[Vector], embeddings_i[j]),
)
value = data_special_key[i][j]
sub_result.append((document, value))
result.append(sub_result)
return result
def _filter_results_by_distance(results: QueryResults, distance_threshold: float = -1) -> QueryResults:
"""Filters results based on a distance threshold.
Args:
results: QueryResults | The query results. List[List[Tuple[Document, float]]]
distance_threshold: The maximum distance allowed for results.
Returns:
QueryResults | The filtered results, containing only entries with a distance smaller than the threshold.
"""
if distance_threshold > 0:
results = [[(key, value) for key, value in data if value < distance_threshold] for data in results]
return results
def _chroma_get_results_to_list_documents(data_dict: GetResult) -> List[Document]:
"""Converts a dictionary with list values to a list of Document.
Args:
data_dict: A dictionary where keys map to lists or None.
Returns:
List[Document] | The list of Document.
"""
results: List[Document] = []
num_items = len(data_dict["ids"])
ids = data_dict["ids"]
documents = data_dict.get("documents") or [None] * num_items # type: ignore
metadatas = data_dict.get("metadatas") or [None] * num_items # type: ignore
embeddings = data_dict.get("embeddings") or [None] * num_items
for i in range(num_items):
results.append(
Document(
id=ids[i],
content=documents[i],
metadata=metadatas[i],
embedding=cast(Vector, embeddings[i]),
)
)
return results
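
To make the flow above concrete, here is a hedged usage sketch of the synchronous client, mirroring the test fixtures below. It assumes autogen-ext is installed with the chromadb extra and that the default all-MiniLM-L6-v2 sentence-transformers model can be downloaded:

# Usage sketch, not part of the diff.
from autogen_ext.storage import ChromaVectorDB
from autogen_ext.storage._base import Document

# Persistent client: a path is required and the database is stored on disk.
db = ChromaVectorDB(path="./chroma_db", client_type="persistent")
db.create_collection("demo", overwrite=True, get_or_create=True)

db.insert_docs(
    [
        Document(id="1", content="autogen supports pluggable vector stores"),
        Document(id="2", content="chromadb is one supported backend"),
    ],
    collection_name="demo",
)

# retrieve_docs returns one list per query; each entry is a (Document, distance) tuple.
for doc, distance in db.retrieve_docs(["vector stores"], "demo", n_results=2)[0]:
    print(doc.id, round(distance, 3), doc.content)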

View File

@@ -0,0 +1,33 @@
from typing import Any, Literal
from ._base import VectorDB
class VectorDBFactory:
"""
Factory class for creating vector databases.
"""
PREDEFINED_VECTOR_DB = ["chromadb"]
@staticmethod
def create_vector_db(db_type: Literal["chromadb"], **kwargs: Any) -> VectorDB:
"""
Create a vector database.
Args:
db_type: Literal["chroma", "chromadb"] | The type of the vector database.
kwargs: Dict | The keyword arguments for initializing the vector database.
Returns:
VectorDB | The vector database.
"""
if db_type.lower() == "chromadb":
from ._chromadb import ChromaVectorDB
return ChromaVectorDB(**kwargs) # type: ignore
else:
raise ValueError(
f"Unsupported vector database type: {db_type}. Valid types are {VectorDBFactory.PREDEFINED_VECTOR_DB}."
)
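
A short sketch of the factory in use; the keyword arguments are forwarded unchanged to ChromaVectorDB.__init__, so the same persistent-client arguments as above apply:

from autogen_ext.storage import VectorDBFactory

vector_db = VectorDBFactory.create_vector_db("chromadb", path="./chroma_db")
collection = vector_db.create_collection("demo", get_or_create=True)
print(vector_db.type, collection.name)  # -> chroma demo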

View File

@@ -0,0 +1,141 @@
from pathlib import Path
from typing import Generator
import pytest
from autogen_ext.storage import ChromaVectorDB
from autogen_ext.storage._base import Document
from chromadb import Collection
from chromadb.config import Settings
from chromadb.errors import ChromaError
# Fixture for the synchronous database instance with function-level scope
@pytest.fixture(scope="function")
def db(tmp_path: Path) -> Generator[ChromaVectorDB, None, None]:
db_path = tmp_path / "test_db"
db_instance = ChromaVectorDB(path=str(db_path), settings=Settings(allow_reset=True))
yield db_instance
# Teardown code
db_instance.client.reset()
# Fixture for unique collection names per test
@pytest.fixture(scope="function")
def collection_name(request: pytest.FixtureRequest) -> str:
return f"test_collection_{request.node.name}" # type: ignore
# Fixture to create and delete the collection around each test
@pytest.fixture(scope="function")
def collection(db: ChromaVectorDB, collection_name: str) -> Generator[Collection, None, None]:
collection = db.create_collection(collection_name, overwrite=True, get_or_create=True)
yield collection
db.delete_collection(collection_name)
def test_create_collection(db: ChromaVectorDB, collection_name: str) -> None:
collection = db.create_collection(collection_name, overwrite=True, get_or_create=True)
assert collection.name == collection_name
def test_delete_collection(db: ChromaVectorDB, collection_name: str) -> None:
# Create the collection first
db.create_collection(collection_name, overwrite=True, get_or_create=True)
db.delete_collection(collection_name)
with pytest.raises((ValueError, ChromaError)):
db.get_collection(collection_name)
def test_more_create_collection(db: ChromaVectorDB, collection_name: str) -> None:
# Ensure the collection is deleted at the start
try:
db.delete_collection(collection_name)
except (ValueError, ChromaError):
pass
collection = db.create_collection(collection_name, overwrite=False, get_or_create=False)
assert collection.name == collection_name
with pytest.raises((ValueError, ChromaError)):
db.create_collection(collection_name, overwrite=False, get_or_create=False)
collection = db.create_collection(collection_name, overwrite=True, get_or_create=False)
assert collection.name == collection_name
collection = db.create_collection(collection_name, overwrite=False, get_or_create=True)
assert collection.name == collection_name
def test_get_collection(db: ChromaVectorDB, collection_name: str, collection: Collection) -> None:
retrieved_collection = db.get_collection(collection_name)
assert retrieved_collection.name == collection_name
def test_insert_docs(db: ChromaVectorDB, collection_name: str, collection: Collection) -> None:
docs = [
Document(content="doc1", id="1"),
Document(content="doc2", id="2"),
Document(content="doc3", id="3"),
]
db.insert_docs(docs, collection_name, upsert=False)
res = db.get_collection(collection_name).get(["1", "2"])
assert res["documents"] == ["doc1", "doc2"]
def test_update_docs(db: ChromaVectorDB, collection_name: str, collection: Collection) -> None:
# Insert initial docs
initial_docs = [
Document(content="doc1", id="1"),
Document(content="doc2", id="2"),
]
db.insert_docs(initial_docs, collection_name, upsert=False)
# Now update
updated_docs = [
Document(content="doc11", id="1"),
Document(content="doc2", id="2"),
Document(content="doc3", id="3"),
]
db.update_docs(updated_docs, collection_name)
res = db.get_collection(collection_name).get(["1", "2"])
assert res["documents"] == ["doc11", "doc2"]
def test_delete_docs(db: ChromaVectorDB, collection_name: str, collection: Collection) -> None:
# Insert initial docs
initial_docs = [
Document(content="doc1", id="1"),
Document(content="doc2", id="2"),
]
db.insert_docs(initial_docs, collection_name, upsert=False)
ids = ["1"]
db.delete_docs(ids, collection_name)
res = db.get_collection(collection_name).get(ids)
assert res["documents"] == []
def test_retrieve_docs(db: ChromaVectorDB, collection_name: str, collection: Collection) -> None:
# Insert initial docs
initial_docs = [
Document(content="doc2", id="2"),
Document(content="doc3", id="3"),
]
db.insert_docs(initial_docs, collection_name, upsert=False)
queries = ["doc2", "doc3"]
res = db.retrieve_docs(queries, collection_name)
assert [[r[0].id for r in rr] for rr in res] == [["2", "3"], ["3", "2"]]
res = db.retrieve_docs(queries, collection_name, distance_threshold=0.1)
assert [[r[0].id for r in rr] for rr in res] == [["2"], ["3"]]
def test_get_docs_by_ids(db: ChromaVectorDB, collection_name: str, collection: Collection) -> None:
# Insert initial docs
initial_docs = [
Document(content="doc2", id="2"),
Document(content="doc3", id="3"),
]
db.insert_docs(initial_docs, collection_name, upsert=False)
res = db.get_docs_by_ids(["1", "2"], collection_name)
assert [r.id for r in res] == ["2"]
res = db.get_docs_by_ids(collection_name=collection_name)
assert [r.id for r in res] == ["2", "3"]
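
The synchronous tests above exercise the persistent client. The asynchronous class only wraps chromadb's HTTP client, so a sketch of it needs a running Chroma server; the example below is hedged on a server listening at localhost:8000 and on chromadb.AsyncHttpClient being the awaitable constructor shipped with chromadb 0.5.x:

import asyncio

import chromadb

from autogen_ext.storage import AsyncChromaVectorDB
from autogen_ext.storage._base import Document


async def main() -> None:
    # Build the async Chroma client explicitly and hand it to the wrapper;
    # when a client is passed, the other connection arguments are ignored.
    client = await chromadb.AsyncHttpClient(host="localhost", port=8000)
    db = AsyncChromaVectorDB(client=client)
    await db.create_collection("demo_async", overwrite=True)
    await db.insert_docs([Document(id="1", content="async chroma demo")], "demo_async")
    results = await db.retrieve_docs(["chroma"], "demo_async", n_results=1)
    print(results[0][0][0].id)


asyncio.run(main())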

File diff suppressed because it is too large