feat: Qdrant support for the VectorDB interface (#3035)

* feat: Qdrant support

* chore: pre-defined vector db

* Fix issues

---------

Co-authored-by: Li Jiang <bnujli@gmail.com>
This commit is contained in:
Anush 2024-07-15 18:49:58 +05:30 committed by GitHub
parent 24d509c1b6
commit 5b1dc3bf63
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 401 additions and 2 deletions

View File

@ -1,6 +1,7 @@
import hashlib
import os
import re
import uuid
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from IPython import get_ipython
@ -365,7 +366,11 @@ class RetrieveUserProxyAgent(UserProxyAgent):
else:
all_docs_ids = set()
chunk_ids = [hashlib.blake2b(chunk.encode("utf-8")).hexdigest()[:HASH_LENGTH] for chunk in chunks]
chunk_ids = (
[hashlib.blake2b(chunk.encode("utf-8")).hexdigest()[:HASH_LENGTH] for chunk in chunks]
if not self._vector_db.type == "qdrant"
else [str(uuid.UUID(hex=hashlib.md5(chunk.encode("utf-8")).hexdigest())) for chunk in chunks]
)
chunk_ids_set = set(chunk_ids)
chunk_ids_set_idx = [chunk_ids.index(hash_value) for hash_value in chunk_ids_set]
docs = [

View File

@ -185,7 +185,7 @@ class VectorDBFactory:
Factory class for creating vector databases.
"""
PREDEFINED_VECTOR_DB = ["chroma", "pgvector"]
PREDEFINED_VECTOR_DB = ["chroma", "pgvector", "qdrant"]
@staticmethod
def create_vector_db(db_type: str, **kwargs) -> VectorDB:
@ -207,6 +207,10 @@ class VectorDBFactory:
from .pgvectordb import PGVectorDB
return PGVectorDB(**kwargs)
if db_type.lower() in ["qdrant", "qdrantdb"]:
from .qdrant import QdrantVectorDB
return QdrantVectorDB(**kwargs)
else:
raise ValueError(
f"Unsupported vector database type: {db_type}. Valid types are {VectorDBFactory.PREDEFINED_VECTOR_DB}."

View File

@ -0,0 +1,322 @@
import abc
import logging
import os
from typing import Callable, List, Optional, Sequence, Tuple, Union
from .base import Document, ItemID, QueryResults, VectorDB
from .utils import get_logger
try:
from qdrant_client import QdrantClient, models
except ImportError:
raise ImportError("Please install qdrant-client: `pip install qdrant-client`")
logger = get_logger(__name__)
Embeddings = Union[Sequence[float], Sequence[int]]
class EmbeddingFunction(abc.ABC):
@abc.abstractmethod
def __call__(self, inputs: List[str]) -> List[Embeddings]:
raise NotImplementedError
class FastEmbedEmbeddingFunction(EmbeddingFunction):
"""Embedding function implementation using FastEmbed - https://qdrant.github.io/fastembed."""
def __init__(
self,
model_name: str = "BAAI/bge-small-en-v1.5",
batch_size: int = 256,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
parallel: Optional[int] = None,
**kwargs,
):
"""Initialize fastembed.TextEmbedding.
Args:
model_name (str): The name of the model to use. Defaults to `"BAAI/bge-small-en-v1.5"`.
batch_size (int): Batch size for encoding. Higher values will use more memory, but be faster.\
Defaults to 256.
cache_dir (str, optional): The path to the model cache directory.\
Can also be set using the `FASTEMBED_CACHE_PATH` env variable.
threads (int, optional): The number of threads single onnxruntime session can use.
parallel (int, optional): If `>1`, data-parallel encoding will be used, recommended for large datasets.\
If `0`, use all available cores.\
If `None`, don't use data-parallel processing, use default onnxruntime threading.\
Defaults to None.
**kwargs: Additional options to pass to fastembed.TextEmbedding
Raises:
ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-small-en-v1.5.
"""
try:
from fastembed import TextEmbedding
except ImportError as e:
raise ValueError(
"The 'fastembed' package is not installed. Please install it with `pip install fastembed`",
) from e
self._batch_size = batch_size
self._parallel = parallel
self._model = TextEmbedding(model_name=model_name, cache_dir=cache_dir, threads=threads, **kwargs)
def __call__(self, inputs: List[str]) -> List[Embeddings]:
embeddings = self._model.embed(inputs, batch_size=self._batch_size, parallel=self._parallel)
return [embedding.tolist() for embedding in embeddings]
class QdrantVectorDB(VectorDB):
"""
A vector database implementation that uses Qdrant as the backend.
"""
def __init__(
self,
*,
client=None,
embedding_function: EmbeddingFunction = None,
content_payload_key: str = "_content",
metadata_payload_key: str = "_metadata",
collection_options: dict = {},
**kwargs,
) -> None:
"""
Initialize the vector database.
Args:
client: qdrant_client.QdrantClient | An instance of QdrantClient.
embedding_function: Callable | The embedding function used to generate the vector representation
of the documents. Defaults to FastEmbedEmbeddingFunction.
collection_options: dict | The options for creating the collection.
kwargs: dict | Additional keyword arguments.
"""
self.client: QdrantClient = client if client is not None else QdrantClient(location=":memory:")
self.embedding_function = FastEmbedEmbeddingFunction() or embedding_function
self.collection_options = collection_options
self.content_payload_key = content_payload_key
self.metadata_payload_key = metadata_payload_key
self.type = "qdrant"
def create_collection(self, collection_name: str, overwrite: bool = False, get_or_create: bool = True) -> None:
"""
Create a collection in the vector database.
Args:
collection_name: str | The name of the collection.
overwrite: bool | Whether to overwrite the collection if it exists. Default is False.
get_or_create: bool | Whether to get the collection if it exists. Default is True.
Returns:
Any | The collection object.
"""
embeddings_size = len(self.embedding_function(["test"])[0])
if self.client.collection_exists(collection_name) and overwrite:
self.client.delete_collection(collection_name)
if not self.client.collection_exists(collection_name):
self.client.create_collection(
collection_name,
vectors_config=models.VectorParams(size=embeddings_size, distance=models.Distance.COSINE),
**self.collection_options,
)
def get_collection(self, collection_name: str = None):
"""
Get the collection from the vector database.
Args:
collection_name: str | The name of the collection.
Returns:
Any | The collection object.
"""
if collection_name is None:
raise ValueError("The collection name is required.")
return self.client.get_collection(collection_name)
def delete_collection(self, collection_name: str) -> None:
"""Delete the collection from the vector database.
Args:
collection_name: str | The name of the collection.
Returns:
Any
"""
return self.client.delete_collection(collection_name)
def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False) -> None:
"""
Insert documents into the collection of the vector database.
Args:
docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
collection_name: str | The name of the collection. Default is None.
upsert: bool | Whether to update the document if it exists. Default is False.
kwargs: Dict | Additional keyword arguments.
Returns:
None
"""
if not docs:
return
if any(doc.get("content") is None for doc in docs):
raise ValueError("The document content is required.")
if any(doc.get("id") is None for doc in docs):
raise ValueError("The document id is required.")
if not upsert and not self._validate_upsert_ids(collection_name, [doc["id"] for doc in docs]):
logger.log("Some IDs already exist. Skipping insert", level=logging.WARN)
self.client.upsert(collection_name, points=self._documents_to_points(docs))
def update_docs(self, docs: List[Document], collection_name: str = None) -> None:
if not docs:
return
if any(doc.get("id") is None for doc in docs):
raise ValueError("The document id is required.")
if any(doc.get("content") is None for doc in docs):
raise ValueError("The document content is required.")
if self._validate_update_ids(collection_name, [doc["id"] for doc in docs]):
return self.client.upsert(collection_name, points=self._documents_to_points(docs))
raise ValueError("Some IDs do not exist. Skipping update")
def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs) -> None:
"""
Delete documents from the collection of the vector database.
Args:
ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`.
collection_name: str | The name of the collection. Default is None.
kwargs: Dict | Additional keyword arguments.
Returns:
None
"""
self.client.delete(collection_name, ids)
def retrieve_docs(
self,
queries: List[str],
collection_name: str = None,
n_results: int = 10,
distance_threshold: float = 0,
**kwargs,
) -> QueryResults:
"""
Retrieve documents from the collection of the vector database based on the queries.
Args:
queries: List[str] | A list of queries. Each query is a string.
collection_name: str | The name of the collection. Default is None.
n_results: int | The number of relevant documents to return. Default is 10.
distance_threshold: float | The threshold for the distance score, only distance smaller than it will be
returned. Don't filter with it if < 0. Default is 0.
kwargs: Dict | Additional keyword arguments.
Returns:
QueryResults | The query results. Each query result is a list of list of tuples containing the document and
the distance.
"""
embeddings = self.embedding_function(queries)
requests = [
models.SearchRequest(
vector=embedding,
limit=n_results,
score_threshold=distance_threshold,
with_payload=True,
with_vector=False,
)
for embedding in embeddings
]
batch_results = self.client.search_batch(collection_name, requests)
return [self._scored_points_to_documents(results) for results in batch_results]
def get_docs_by_ids(
self, ids: List[ItemID] = None, collection_name: str = None, include=True, **kwargs
) -> List[Document]:
"""
Retrieve documents from the collection of the vector database based on the ids.
Args:
ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None.
collection_name: str | The name of the collection. Default is None.
include: List[str] | The fields to include. Default is True.
If None, will include ["metadatas", "documents"], ids will always be included.
kwargs: dict | Additional keyword arguments.
Returns:
List[Document] | The results.
"""
if ids is None:
results = self.client.scroll(collection_name=collection_name, with_payload=include, with_vectors=True)[0]
else:
results = self.client.retrieve(collection_name, ids=ids, with_payload=include, with_vectors=True)
return [self._point_to_document(result) for result in results]
def _point_to_document(self, point) -> Document:
return {
"id": point.id,
"content": point.payload.get(self.content_payload_key, ""),
"metadata": point.payload.get(self.metadata_payload_key, {}),
"embedding": point.vector,
}
def _points_to_documents(self, points) -> List[Document]:
return [self._point_to_document(point) for point in points]
def _scored_point_to_document(self, scored_point: models.ScoredPoint) -> Tuple[Document, float]:
return self._point_to_document(scored_point), scored_point.score
def _documents_to_points(self, documents: List[Document]):
contents = [document["content"] for document in documents]
embeddings = self.embedding_function(contents)
points = [
models.PointStruct(
id=documents[i]["id"],
vector=embeddings[i],
payload={
self.content_payload_key: documents[i].get("content"),
self.metadata_payload_key: documents[i].get("metadata"),
},
)
for i in range(len(documents))
]
return points
def _scored_points_to_documents(self, scored_points: List[models.ScoredPoint]) -> List[Tuple[Document, float]]:
return [self._scored_point_to_document(scored_point) for scored_point in scored_points]
def _validate_update_ids(self, collection_name: str, ids: List[str]) -> bool:
"""
Validates all the IDs exist in the collection
"""
retrieved_ids = [
point.id for point in self.client.retrieve(collection_name, ids=ids, with_payload=False, with_vectors=False)
]
if missing_ids := set(ids) - set(retrieved_ids):
logger.log(f"Missing IDs: {missing_ids}. Skipping update", level=logging.WARN)
return False
return True
def _validate_upsert_ids(self, collection_name: str, ids: List[str]) -> bool:
"""
Validate none of the IDs exist in the collection
"""
retrieved_ids = [
point.id for point in self.client.retrieve(collection_name, ids=ids, with_payload=False, with_vectors=False)
]
if existing_ids := set(ids) & set(retrieved_ids):
logger.log(f"Existing IDs: {existing_ids}.", level=logging.WARN)
return False
return True

View File

@ -0,0 +1,68 @@
import os
import sys
import pytest
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
try:
import uuid
from qdrant_client import QdrantClient
from autogen.agentchat.contrib.vectordb.qdrant import QdrantVectorDB
except ImportError:
skip = True
else:
skip = False
@pytest.mark.skipif(skip, reason="dependency is not installed")
def test_qdrant():
# test create collection
client = QdrantClient(location=":memory:")
db = QdrantVectorDB(client=client)
collection_name = uuid.uuid4().hex
db.create_collection(collection_name, overwrite=True, get_or_create=True)
assert client.collection_exists(collection_name)
# test_delete_collection
db.delete_collection(collection_name)
assert not client.collection_exists(collection_name)
# test_get_collection
db.create_collection(collection_name, overwrite=True, get_or_create=True)
collection_info = db.get_collection(collection_name)
# Assert default FastEmbed model dimensions
assert collection_info.config.params.vectors.size == 384
# test_insert_docs
docs = [{"content": "doc1", "id": 1}, {"content": "doc2", "id": 2}]
db.insert_docs(docs, collection_name, upsert=False)
res = db.get_docs_by_ids([1, 2], collection_name)
assert res[0]["id"] == 1
assert res[0]["content"] == "doc1"
assert res[1]["id"] == 2
assert res[1]["content"] == "doc2"
# test_update_docs and get_docs_by_ids
docs = [{"content": "doc11", "id": 1}, {"content": "doc22", "id": 2}]
db.update_docs(docs, collection_name)
res = db.get_docs_by_ids([1, 2], collection_name)
assert res[0]["id"] == 1
assert res[0]["content"] == "doc11"
assert res[1]["id"] == 2
assert res[1]["content"] == "doc22"
# test_retrieve_docs
queries = ["doc22", "doc11"]
res = db.retrieve_docs(queries, collection_name)
assert [[r[0]["id"] for r in rr] for rr in res] == [[2, 1], [1, 2]]
# test_delete_docs
db.delete_docs([1], collection_name)
assert db.client.count(collection_name).count == 1
if __name__ == "__main__":
test_qdrant()