mirror of https://github.com/microsoft/autogen.git
feat: Qdrant support for the VectorDB interface (#3035)
* feat: Qdrant support * chore: pre-defined vector db * Fix issues --------- Co-authored-by: Li Jiang <bnujli@gmail.com>
This commit is contained in:
parent
24d509c1b6
commit
5b1dc3bf63
|
@ -1,6 +1,7 @@
|
|||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
from IPython import get_ipython
|
||||
|
@ -365,7 +366,11 @@ class RetrieveUserProxyAgent(UserProxyAgent):
|
|||
else:
|
||||
all_docs_ids = set()
|
||||
|
||||
chunk_ids = [hashlib.blake2b(chunk.encode("utf-8")).hexdigest()[:HASH_LENGTH] for chunk in chunks]
|
||||
chunk_ids = (
|
||||
[hashlib.blake2b(chunk.encode("utf-8")).hexdigest()[:HASH_LENGTH] for chunk in chunks]
|
||||
if not self._vector_db.type == "qdrant"
|
||||
else [str(uuid.UUID(hex=hashlib.md5(chunk.encode("utf-8")).hexdigest())) for chunk in chunks]
|
||||
)
|
||||
chunk_ids_set = set(chunk_ids)
|
||||
chunk_ids_set_idx = [chunk_ids.index(hash_value) for hash_value in chunk_ids_set]
|
||||
docs = [
|
||||
|
|
|
@ -185,7 +185,7 @@ class VectorDBFactory:
|
|||
Factory class for creating vector databases.
|
||||
"""
|
||||
|
||||
PREDEFINED_VECTOR_DB = ["chroma", "pgvector"]
|
||||
PREDEFINED_VECTOR_DB = ["chroma", "pgvector", "qdrant"]
|
||||
|
||||
@staticmethod
|
||||
def create_vector_db(db_type: str, **kwargs) -> VectorDB:
|
||||
|
@ -207,6 +207,10 @@ class VectorDBFactory:
|
|||
from .pgvectordb import PGVectorDB
|
||||
|
||||
return PGVectorDB(**kwargs)
|
||||
if db_type.lower() in ["qdrant", "qdrantdb"]:
|
||||
from .qdrant import QdrantVectorDB
|
||||
|
||||
return QdrantVectorDB(**kwargs)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported vector database type: {db_type}. Valid types are {VectorDBFactory.PREDEFINED_VECTOR_DB}."
|
||||
|
|
|
@ -0,0 +1,322 @@
|
|||
import abc
|
||||
import logging
|
||||
import os
|
||||
from typing import Callable, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from .base import Document, ItemID, QueryResults, VectorDB
|
||||
from .utils import get_logger
|
||||
|
||||
try:
|
||||
from qdrant_client import QdrantClient, models
|
||||
except ImportError:
|
||||
raise ImportError("Please install qdrant-client: `pip install qdrant-client`")
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# A single embedding vector: a sequence of floats or a sequence of ints.
Embeddings = Union[Sequence[float], Sequence[int]]


class EmbeddingFunction(abc.ABC):
    """Abstract callable that turns a batch of texts into embedding vectors.

    Concrete subclasses must override `__call__`; the class itself cannot be
    instantiated because `__call__` is declared abstract.
    """

    @abc.abstractmethod
    def __call__(self, inputs: List[str]) -> List[Embeddings]:
        """Embed `inputs`.

        Args:
            inputs: List[str] | The texts to embed.

        Returns:
            List[Embeddings] | The embedding vectors produced for `inputs`.

        Raises:
            NotImplementedError: Always, when the base implementation is invoked.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
class FastEmbedEmbeddingFunction(EmbeddingFunction):
    """Embedding function backed by FastEmbed - https://qdrant.github.io/fastembed."""

    def __init__(
        self,
        model_name: str = "BAAI/bge-small-en-v1.5",
        batch_size: int = 256,
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        parallel: Optional[int] = None,
        **kwargs,
    ):
        """Build the underlying `fastembed.TextEmbedding` model.

        Args:
            model_name (str): Name of the model to load. Defaults to `"BAAI/bge-small-en-v1.5"`.
            batch_size (int): Batch size handed to `embed()` on every call. Defaults to 256.
            cache_dir (str, optional): Model cache directory, forwarded to `TextEmbedding`.
            threads (int, optional): Thread count, forwarded to `TextEmbedding`.
            parallel (int, optional): Data-parallel setting handed to `embed()` on every call.
                Defaults to None.
            **kwargs: Additional options forwarded verbatim to `fastembed.TextEmbedding`.

        Raises:
            ValueError: If the `fastembed` package cannot be imported.
                NOTE(review): the model constructor may raise on an unknown
                `model_name` as well; that behavior is not visible here.
        """
        # Imported lazily so that fastembed is only required when this class is used.
        try:
            from fastembed import TextEmbedding
        except ImportError as import_error:
            raise ValueError(
                "The 'fastembed' package is not installed. Please install it with `pip install fastembed`",
            ) from import_error

        self._batch_size = batch_size
        self._parallel = parallel

        model_options = dict(model_name=model_name, cache_dir=cache_dir, threads=threads, **kwargs)
        self._model = TextEmbedding(**model_options)

    def __call__(self, inputs: List[str]) -> List[Embeddings]:
        """Embed `inputs` and return one plain Python list per input text."""
        vectors = []
        for raw_vector in self._model.embed(inputs, batch_size=self._batch_size, parallel=self._parallel):
            # Each item exposes `.tolist()`; convert it to a plain list.
            vectors.append(raw_vector.tolist())
        return vectors
|
||||
|
||||
|
||||
class QdrantVectorDB(VectorDB):
    """
    A vector database implementation that uses Qdrant as the backend.

    Every document is stored as one Qdrant point: the document id is the point
    id, the embedding of the content is the point vector, and the content and
    metadata live in the point payload under `content_payload_key` and
    `metadata_payload_key`.
    """

    # Number of points requested per `scroll` call when listing a whole collection.
    _SCROLL_PAGE_SIZE = 256

    def __init__(
        self,
        *,
        client=None,
        embedding_function: EmbeddingFunction = None,
        content_payload_key: str = "_content",
        metadata_payload_key: str = "_metadata",
        collection_options: Optional[dict] = None,
        **kwargs,
    ) -> None:
        """
        Initialize the vector database.

        Args:
            client: qdrant_client.QdrantClient | An instance of QdrantClient.
                Defaults to a new in-memory client.
            embedding_function: Callable | The embedding function used to generate the vector representation
                of the documents. Defaults to FastEmbedEmbeddingFunction.
            content_payload_key: str | The payload key that holds the document content.
            metadata_payload_key: str | The payload key that holds the document metadata.
            collection_options: dict | Extra keyword arguments passed to `create_collection` of the client.
                Defaults to no extra options.
            kwargs: dict | Additional keyword arguments (accepted and ignored).
        """
        self.client: QdrantClient = client if client is not None else QdrantClient(location=":memory:")
        # Only build (and therefore only require) FastEmbed when the caller did
        # not supply an embedding function of their own.
        self.embedding_function = (
            embedding_function if embedding_function is not None else FastEmbedEmbeddingFunction()
        )
        # Copy so that instances never share (or mutate) the caller's dict.
        self.collection_options = dict(collection_options) if collection_options else {}
        self.content_payload_key = content_payload_key
        self.metadata_payload_key = metadata_payload_key
        self.type = "qdrant"

    def create_collection(self, collection_name: str, overwrite: bool = False, get_or_create: bool = True) -> None:
        """
        Create a collection in the vector database.

        An existing collection is deleted first when `overwrite` is set; otherwise
        an existing collection is left untouched. The vector size is taken from
        the length of one probe embedding and the distance is cosine.

        Args:
            collection_name: str | The name of the collection.
            overwrite: bool | Whether to overwrite the collection if it exists. Default is False.
            get_or_create: bool | Accepted for interface compatibility. This implementation does not
                consult it: an existing collection is always reused unless `overwrite` is set.

        Returns:
            None
        """
        if overwrite and self.client.collection_exists(collection_name):
            self.client.delete_collection(collection_name)

        if not self.client.collection_exists(collection_name):
            # Embed a probe text only when a collection really has to be created.
            embeddings_size = len(self.embedding_function(["test"])[0])
            self.client.create_collection(
                collection_name,
                vectors_config=models.VectorParams(size=embeddings_size, distance=models.Distance.COSINE),
                **self.collection_options,
            )

    def get_collection(self, collection_name: str = None):
        """
        Get the collection from the vector database.

        Args:
            collection_name: str | The name of the collection.

        Returns:
            Any | Whatever `QdrantClient.get_collection` returns for the collection.

        Raises:
            ValueError: If `collection_name` is None.
        """
        if collection_name is None:
            raise ValueError("The collection name is required.")

        return self.client.get_collection(collection_name)

    def delete_collection(self, collection_name: str) -> None:
        """Delete the collection from the vector database.

        Args:
            collection_name: str | The name of the collection.

        Returns:
            Any | Whatever `QdrantClient.delete_collection` returns.
        """
        return self.client.delete_collection(collection_name)

    def insert_docs(self, docs: List[Document], collection_name: str = None, upsert: bool = False) -> None:
        """
        Insert documents into the collection of the vector database.

        Args:
            docs: List[Document] | A list of documents. Each document is a TypedDict `Document`.
            collection_name: str | The name of the collection. Default is None.
            upsert: bool | Whether to update the document if it exists. Default is False.
                When False and any of the ids is already stored, a warning is logged and
                nothing is written.

        Returns:
            None

        Raises:
            ValueError: If any document lacks "content" or "id".
        """
        if not docs:
            return
        if any(doc.get("content") is None for doc in docs):
            raise ValueError("The document content is required.")
        if any(doc.get("id") is None for doc in docs):
            raise ValueError("The document id is required.")

        if not upsert and not self._validate_upsert_ids(collection_name, [doc["id"] for doc in docs]):
            logger.warning("Some IDs already exist. Skipping insert")
            # Qdrant's upsert would overwrite the stored points, so really skip.
            return

        self.client.upsert(collection_name, points=self._documents_to_points(docs))

    def update_docs(self, docs: List[Document], collection_name: str = None) -> None:
        """
        Update existing documents in the collection of the vector database.

        Args:
            docs: List[Document] | A list of documents. Each needs an "id" and "content".
            collection_name: str | The name of the collection. Default is None.

        Returns:
            Whatever `QdrantClient.upsert` returns, or None when `docs` is empty.

        Raises:
            ValueError: If a document lacks "id" or "content", or if any id is not
                present in the collection.
        """
        if not docs:
            return
        if any(doc.get("id") is None for doc in docs):
            raise ValueError("The document id is required.")
        if any(doc.get("content") is None for doc in docs):
            raise ValueError("The document content is required.")
        if self._validate_update_ids(collection_name, [doc["id"] for doc in docs]):
            return self.client.upsert(collection_name, points=self._documents_to_points(docs))

        raise ValueError("Some IDs do not exist. Skipping update")

    def delete_docs(self, ids: List[ItemID], collection_name: str = None, **kwargs) -> None:
        """
        Delete documents from the collection of the vector database.

        Args:
            ids: List[ItemID] | A list of document ids. Each id is a typed `ItemID`.
            collection_name: str | The name of the collection. Default is None.
            kwargs: Dict | Additional keyword arguments (accepted and ignored).

        Returns:
            None
        """
        self.client.delete(collection_name, ids)

    def retrieve_docs(
        self,
        queries: List[str],
        collection_name: str = None,
        n_results: int = 10,
        distance_threshold: float = 0,
        **kwargs,
    ) -> QueryResults:
        """
        Retrieve documents from the collection of the vector database based on the queries.

        Args:
            queries: List[str] | A list of queries. Each query is a string.
            collection_name: str | The name of the collection. Default is None.
            n_results: int | The number of relevant documents to return. Default is 10.
            distance_threshold: float | Forwarded to Qdrant as `score_threshold`. Don't filter with it
                if < 0 (or None). Default is 0.
                NOTE(review): Qdrant's `score_threshold` is applied to the similarity score
                of the hits, not to a distance - confirm the intended semantics with callers.
            kwargs: Dict | Additional keyword arguments (accepted and ignored).

        Returns:
            QueryResults | The query results. Each query result is a list of tuples containing the
                document and the score reported by Qdrant.
        """
        if not queries:
            return []

        # A negative (or missing) threshold disables filtering, as documented.
        score_threshold = (
            distance_threshold if distance_threshold is not None and distance_threshold >= 0 else None
        )
        embeddings = self.embedding_function(queries)
        requests = [
            models.SearchRequest(
                vector=embedding,
                limit=n_results,
                score_threshold=score_threshold,
                with_payload=True,
                with_vector=False,
            )
            for embedding in embeddings
        ]

        batch_results = self.client.search_batch(collection_name, requests)
        return [self._scored_points_to_documents(results) for results in batch_results]

    def get_docs_by_ids(
        self, ids: List[ItemID] = None, collection_name: str = None, include=True, **kwargs
    ) -> List[Document]:
        """
        Retrieve documents from the collection of the vector database based on the ids.

        Args:
            ids: List[ItemID] | A list of document ids. If None, will return all the documents. Default is None.
            collection_name: str | The name of the collection. Default is None.
            include: Passed through unchanged as Qdrant's `with_payload` option. Default is True.
                Vectors are always requested.
            kwargs: dict | Additional keyword arguments (accepted and ignored).

        Returns:
            List[Document] | The results.
        """
        if ids is None:
            # `scroll` is paginated; keep following the returned offset until the
            # collection is exhausted so that really all documents are returned.
            results = []
            offset = None
            while True:
                page, offset = self.client.scroll(
                    collection_name=collection_name,
                    limit=self._SCROLL_PAGE_SIZE,
                    offset=offset,
                    with_payload=include,
                    with_vectors=True,
                )
                results.extend(page)
                if offset is None:
                    break
        else:
            results = self.client.retrieve(collection_name, ids=ids, with_payload=include, with_vectors=True)
        return [self._point_to_document(result) for result in results]

    def _point_to_document(self, point) -> Document:
        """Convert one Qdrant point/record into a `Document` dict."""
        # The payload is None when it was not requested; treat that as empty.
        payload = point.payload or {}
        return {
            "id": point.id,
            "content": payload.get(self.content_payload_key, ""),
            "metadata": payload.get(self.metadata_payload_key, {}),
            "embedding": point.vector,
        }

    def _points_to_documents(self, points) -> List[Document]:
        """Convert a list of Qdrant points/records into `Document` dicts."""
        return [self._point_to_document(point) for point in points]

    def _scored_point_to_document(self, scored_point: models.ScoredPoint) -> Tuple[Document, float]:
        """Convert one scored search hit into a `(Document, score)` tuple."""
        return self._point_to_document(scored_point), scored_point.score

    def _documents_to_points(self, documents: List[Document]):
        """Embed the document contents and build one `PointStruct` per document."""
        contents = [document["content"] for document in documents]
        embeddings = self.embedding_function(contents)
        return [
            models.PointStruct(
                id=document["id"],
                vector=embedding,
                payload={
                    self.content_payload_key: document.get("content"),
                    self.metadata_payload_key: document.get("metadata"),
                },
            )
            for document, embedding in zip(documents, embeddings)
        ]

    def _scored_points_to_documents(self, scored_points: List[models.ScoredPoint]) -> List[Tuple[Document, float]]:
        """Convert a list of scored search hits into `(Document, score)` tuples."""
        return [self._scored_point_to_document(scored_point) for scored_point in scored_points]

    def _validate_update_ids(self, collection_name: str, ids: List[str]) -> bool:
        """
        Validates all the IDs exist in the collection

        Returns True when every id was found, otherwise logs the missing ids and
        returns False.
        """
        retrieved_ids = [
            point.id for point in self.client.retrieve(collection_name, ids=ids, with_payload=False, with_vectors=False)
        ]

        if missing_ids := set(ids) - set(retrieved_ids):
            logger.warning("Missing IDs: %s. Skipping update", missing_ids)
            return False

        return True

    def _validate_upsert_ids(self, collection_name: str, ids: List[str]) -> bool:
        """
        Validate none of the IDs exist in the collection

        Returns True when no id was found, otherwise logs the already stored ids
        and returns False.
        """
        retrieved_ids = [
            point.id for point in self.client.retrieve(collection_name, ids=ids, with_payload=False, with_vectors=False)
        ]

        if existing_ids := set(ids) & set(retrieved_ids):
            logger.warning("Existing IDs: %s.", existing_ids)
            return False

        return True
|
|
@ -0,0 +1,68 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
|
||||
|
||||
try:
|
||||
import uuid
|
||||
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
from autogen.agentchat.contrib.vectordb.qdrant import QdrantVectorDB
|
||||
except ImportError:
|
||||
skip = True
|
||||
else:
|
||||
skip = False
|
||||
|
||||
|
||||
@pytest.mark.skipif(skip, reason="dependency is not installed")
def test_qdrant():
    """Walk QdrantVectorDB through its whole lifecycle against an in-memory Qdrant."""
    qdrant = QdrantClient(location=":memory:")
    vector_db = QdrantVectorDB(client=qdrant)
    name = uuid.uuid4().hex

    # Creating a collection makes it visible to the raw client.
    vector_db.create_collection(name, overwrite=True, get_or_create=True)
    assert qdrant.collection_exists(name)

    # Deleting the collection removes it again.
    vector_db.delete_collection(name)
    assert not qdrant.collection_exists(name)

    # A re-created collection reports the default FastEmbed model dimensions.
    vector_db.create_collection(name, overwrite=True, get_or_create=True)
    info = vector_db.get_collection(name)
    assert info.config.params.vectors.size == 384

    # Inserted documents can be fetched back by id.
    initial_docs = [{"content": "doc1", "id": 1}, {"content": "doc2", "id": 2}]
    vector_db.insert_docs(initial_docs, name, upsert=False)
    fetched = vector_db.get_docs_by_ids([1, 2], name)
    first, second = fetched[0], fetched[1]
    assert first["id"] == 1
    assert first["content"] == "doc1"
    assert second["id"] == 2
    assert second["content"] == "doc2"

    # Updating replaces the stored content under the same ids.
    changed_docs = [{"content": "doc11", "id": 1}, {"content": "doc22", "id": 2}]
    vector_db.update_docs(changed_docs, name)
    fetched = vector_db.get_docs_by_ids([1, 2], name)
    first, second = fetched[0], fetched[1]
    assert first["id"] == 1
    assert first["content"] == "doc11"
    assert second["id"] == 2
    assert second["content"] == "doc22"

    # Each query ranks its exact-match document first.
    hits_per_query = vector_db.retrieve_docs(["doc22", "doc11"], name)
    ranked_ids = [[hit[0]["id"] for hit in hits] for hits in hits_per_query]
    assert ranked_ids == [[2, 1], [1, 2]]

    # Deleting one document leaves exactly one point behind.
    vector_db.delete_docs([1], name)
    assert vector_db.client.count(name).count == 1


if __name__ == "__main__":
    test_qdrant()
|
Loading…
Reference in New Issue