mirror of https://github.com/microsoft/autogen.git
Merge b1353a681b into feef9d4d37
commit 91c0d8b265
@@ -91,6 +91,17 @@ jobs:
        image: mongodb/mongodb-atlas-local:latest
        ports:
          - 27017:27017
      couchbase:
        image: couchbase:enterprise-7.6.3
        ports:
          - "8091-8095:8091-8095"
          - "11210:11210"
          - "9102:9102"
        healthcheck: # checks couchbase server is up
          test: ["CMD", "curl", "-v", "http://localhost:8091/pools"]
          interval: 20s
          timeout: 20s
          retries: 5
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
@@ -111,6 +122,9 @@ jobs:
      - name: Install mongodb when on linux
        run: |
          pip install -e .[retrievechat-mongodb]
      - name: Install couchbase when on linux
        run: |
          pip install -e .[retrievechat-couchbase]
      - name: Install unstructured when python-version is 3.9 and on linux
        if: matrix.python-version == '3.9'
        run: |
@@ -201,7 +201,7 @@ class VectorDBFactory:
    Factory class for creating vector databases.
    """

-    PREDEFINED_VECTOR_DB = ["chroma", "pgvector", "mongodb", "qdrant"]
+    PREDEFINED_VECTOR_DB = ["chroma", "pgvector", "mongodb", "qdrant", "couchbase"]

    @staticmethod
    def create_vector_db(db_type: str, **kwargs) -> VectorDB:
@@ -231,6 +231,10 @@ class VectorDBFactory:
            from .qdrant import QdrantVectorDB

            return QdrantVectorDB(**kwargs)
        if db_type.lower() in ["couchbase", "couchbasedb", "capella"]:
            from .couchbase import CouchbaseVectorDB

            return CouchbaseVectorDB(**kwargs)
        else:
            raise ValueError(
                f"Unsupported vector database type: {db_type}. Valid types are {VectorDBFactory.PREDEFINED_VECTOR_DB}."
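
For reference, a minimal usage sketch of the new factory branch (assuming the `retrievechat-couchbase` extra is installed, a Couchbase server is reachable, and `VectorDBFactory` lives in `autogen.agentchat.contrib.vectordb.base`; the connection values below are illustrative and match the constructor defaults):

```python
from autogen.agentchat.contrib.vectordb.base import VectorDBFactory

# All kwargs are forwarded to the CouchbaseVectorDB constructor.
db = VectorDBFactory.create_vector_db(
    db_type="couchbase",
    connection_string="couchbase://localhost",
    username="Administrator",
    password="password",
    bucket_name="vector_db",
    index_name="vector_index",
)
```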
@@ -0,0 +1,400 @@
import json
import time
from datetime import timedelta
from typing import Any, Callable, Dict, Iterable, List, Literal, Mapping, Tuple, Union

import numpy as np
from couchbase import search
from couchbase.cluster import Cluster, ClusterOptions
from couchbase.collection import Collection
from couchbase.auth import PasswordAuthenticator
from couchbase.options import SearchOptions
from couchbase.management.search import SearchIndex
from couchbase.vector_search import VectorQuery, VectorSearch

from sentence_transformers import SentenceTransformer

from .base import Document, ItemID, QueryResults, VectorDB
from .utils import get_logger

logger = get_logger(__name__)

DEFAULT_BATCH_SIZE = 1000
_SAMPLE_SENTENCE = ["The weather is lovely today in paradise."]
TEXT_KEY = "content"
EMBEDDING_KEY = "embedding"
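# Documents are stored as {TEXT_KEY: str, "metadata": dict, EMBEDDING_KEY: List[float]};
# the Couchbase document key doubles as the AutoGen document id.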


class CouchbaseVectorDB(VectorDB):
    """
    A vector database implementation that uses Couchbase as the backend.
    """

    def __init__(
        self,
        connection_string: str = "couchbase://localhost",
        username: str = "Administrator",
        password: str = "password",
        bucket_name: str = "vector_db",
        embedding_function: Callable = SentenceTransformer("all-MiniLM-L6-v2").encode,  # note: the default model is instantiated at import time
        scope_name: str = "_default",
        collection_name: str = "_default",
        index_name: str = None,
    ):
        """
        Initialize the vector database.

        Args:
            connection_string (str): The Couchbase connection string to connect to. Default is 'couchbase://localhost'.
            username (str): The username for Couchbase authentication. Default is 'Administrator'.
            password (str): The password for Couchbase authentication. Default is 'password'.
            bucket_name (str): The name of the bucket. Default is 'vector_db'.
            embedding_function (Callable): The embedding function used to generate the vector representation. Default is SentenceTransformer("all-MiniLM-L6-v2").encode.
            scope_name (str): The name of the scope. Default is '_default'.
            collection_name (str): The name of the collection to create for this vector database. Default is '_default'.
            index_name (str): Index name for the vector database. Default is None.
        """
        logger.debug(
            f"CouchbaseVectorDB: {connection_string} {username} {bucket_name} {scope_name} {collection_name} {index_name}"
        )  # do not log the password
        self.embedding_function = embedding_function
        self.index_name = index_name

        # Determine the model's dimensionality by embedding a sample sentence.
        self.dimensions = self._get_embedding_size()

        try:
            auth = PasswordAuthenticator(username, password)
            cluster = Cluster(connection_string, ClusterOptions(auth))
            cluster.wait_until_ready(timedelta(seconds=5))
            self.cluster = cluster

            self.bucket = cluster.bucket(bucket_name)
            self.scope = self.bucket.scope(scope_name)
            self.collection = self.scope.collection(collection_name)
            self.active_collection = self.collection

            logger.debug("Successfully connected to Couchbase")
        except Exception as err:
            raise ConnectionError("Could not connect to Couchbase server") from err

    def search_index_exists(self, index_name: str):
        """Check if the specified index is ready."""
        try:
            search_index_mgr = self.scope.search_indexes()
            index = search_index_mgr.get_index(index_name)
            return index.is_valid()
        except Exception:
            return False

    def _get_embedding_size(self):
        return len(self.embedding_function(_SAMPLE_SENTENCE)[0])

    def create_collection(
        self,
        collection_name: str,
        overwrite: bool = False,
        get_or_create: bool = True,
    ) -> Collection:
        """
        Create a collection in the vector database and create a vector search index in the collection.

        Args:
            collection_name: str | The name of the collection.
            overwrite: bool | Whether to overwrite the collection if it exists. Default is False.
            get_or_create: bool | Whether to get or create the collection. Default is True.
        """
        if overwrite:
            self.delete_collection(collection_name)

        try:
            collection_mgr = self.bucket.collections()
            collection_mgr.create_collection(self.scope.name, collection_name)
        except Exception:
            if not get_or_create:
                raise ValueError(f"Collection {collection_name} already exists.")
            else:
                logger.debug(f"Collection {collection_name} already exists. Getting the collection.")

        collection = self.scope.collection(collection_name)
        self.create_index_if_not_exists(index_name=self.index_name, collection=collection)
        return collection

    def create_index_if_not_exists(self, index_name: str = "vector_index", collection=None) -> None:
        """
        Creates a vector search index on the specified collection in Couchbase.

        Args:
            index_name (str, optional): The name of the vector search index to create. Defaults to "vector_index".
            collection (Collection, optional): The Couchbase collection to create the index on. Defaults to None.
        """
        if not self.search_index_exists(index_name):
            self.create_vector_search_index(collection, index_name)

    def get_collection(self, collection_name: str = None) -> Collection:
        """
        Get the collection from the vector database.

        Args:
            collection_name: str | The name of the collection. Default is None. If None, return the
                current active collection.

        Returns:
            Collection | The collection object.
        """
        if collection_name is None:
            if self.active_collection is None:
                raise ValueError("No collection is specified.")
            else:
                logger.debug(
                    f"No collection is specified. Using current active collection {self.active_collection.name}."
                )
        else:
            self.active_collection = self.scope.collection(collection_name)

        return self.active_collection

    def delete_collection(self, collection_name: str) -> None:
        """
        Delete the collection from the vector database.

        Args:
            collection_name: str | The name of the collection.
        """
        try:
            collection_mgr = self.bucket.collections()
            collection_mgr.drop_collection(self.scope.name, collection_name)
        except Exception as e:
            logger.error(f"Error deleting collection: {e}")

    def create_vector_search_index(
        self,
        collection,
        index_name: Union[str, None] = "vector_index",
        similarity: Literal["l2_norm", "dot_product"] = "dot_product",
    ) -> None:
        """Create a vector search index in the collection."""
        search_index_mgr = self.scope.search_indexes()
        dims = self._get_embedding_size()
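        # FTS index definition: a "fulltext-index" scoped to this collection that maps
        # `embedding` as a vector field of `dims` dimensions and `content` as stored text,
        # while `metadata` remains dynamically indexed.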
        index_definition = {
            "type": "fulltext-index",
            "name": index_name,
            "sourceType": "couchbase",
            "sourceName": self.bucket.name,
            "planParams": {"maxPartitionsPerPIndex": 1024, "indexPartitions": 1},
            "params": {
                "doc_config": {
                    "docid_prefix_delim": "",
                    "docid_regexp": "",
                    "mode": "scope.collection.type_field",
                    "type_field": "type",
                },
                "mapping": {
                    "analysis": {},
                    "default_analyzer": "standard",
                    "default_datetime_parser": "dateTimeOptional",
                    "default_field": "_all",
                    "default_mapping": {"dynamic": True, "enabled": False},
                    "default_type": "_default",
                    "docvalues_dynamic": False,
                    "index_dynamic": True,
                    "store_dynamic": True,
                    "type_field": "_type",
                    "types": {
                        f"{self.scope.name}.{collection.name}": {
                            "dynamic": False,
                            "enabled": True,
                            "properties": {
                                "embedding": {
                                    "dynamic": False,
                                    "enabled": True,
                                    "fields": [
                                        {
                                            "dims": dims,
                                            "index": True,
                                            "name": "embedding",
                                            "similarity": similarity,
                                            "type": "vector",
                                            "vector_index_optimized_for": "recall",
                                        }
                                    ],
                                },
                                "metadata": {"dynamic": True, "enabled": True},
                                "content": {
                                    "dynamic": False,
                                    "enabled": True,
                                    "fields": [
                                        {
                                            "include_in_all": True,
                                            "index": True,
                                            "name": "content",
                                            "store": True,
                                            "type": "text",
                                        }
                                    ],
                                },
                            },
                        }
                    },
                },
                "store": {"indexType": "scorch", "segmentVersion": 16},
            },
            "sourceParams": {},
        }

        search_index_def = SearchIndex.from_json(json.dumps(index_definition))
        # Index creation can fail transiently right after collection creation, so retry with a delay.
        max_attempts = 10
        attempt = 0
        while attempt < max_attempts:
            try:
                search_index_mgr.upsert_index(search_index_def)
                break
            except Exception as e:
                logger.debug(f"Attempt {attempt + 1}/{max_attempts}: Error creating search index: {e}")
                time.sleep(3)
                attempt += 1

        if attempt == max_attempts:
            logger.error(f"Error creating search index after {max_attempts} attempts.")
            raise RuntimeError(f"Error creating search index after {max_attempts} attempts.")

        logger.info(f"Search index {index_name} created successfully.")

    def upsert_docs(
        self, docs: List[Document], collection: Collection, batch_size=DEFAULT_BATCH_SIZE, **kwargs: Any
    ) -> None:
        if docs[0].get("content") is None:
            raise ValueError("The document content is required.")
        if docs[0].get("id") is None:
            raise ValueError("The document id is required.")

        for i in range(0, len(docs), batch_size):
            batch = docs[i : i + batch_size]
            docs_to_upsert = dict()
            for doc in batch:
                doc_id = doc["id"]
                # Compute a fresh embedding even when updating an existing document; take the single
                # row of the (1, dims) result so a flat vector is stored, matching the query path.
                embedding = self.embedding_function([doc["content"]]).tolist()[0]

                doc_content = {
                    TEXT_KEY: doc["content"],
                    "metadata": doc.get("metadata", {}),
                    EMBEDDING_KEY: embedding,
                }
                docs_to_upsert[doc_id] = doc_content
            collection.upsert_multi(docs_to_upsert)

    def insert_docs(
        self,
        docs: List[Document],
        collection_name: str = None,
        upsert: bool = False,
        batch_size=DEFAULT_BATCH_SIZE,
        **kwargs,
    ) -> None:
        """Insert documents and vector embeddings into the collection of the vector database. Documents are upserted in all cases."""
        if not docs:
            logger.info("No documents to insert.")
            return

        collection = self.get_collection(collection_name)
        self.upsert_docs(docs, collection, batch_size=batch_size)

    def update_docs(
        self, docs: List[Document], collection_name: str = None, batch_size=DEFAULT_BATCH_SIZE, **kwargs: Any
    ) -> None:
        """Update documents, including their embeddings, in the collection."""
        collection = self.get_collection(collection_name)
        self.upsert_docs(docs, collection, batch_size)

    def delete_docs(self, ids: List[ItemID], collection_name: str = None, batch_size=DEFAULT_BATCH_SIZE, **kwargs):
        """Delete documents from the collection of the vector database."""
        collection = self.get_collection(collection_name)
        # Delete the documents in batches of batch_size.
        for i in range(0, len(ids), batch_size):
            batch = ids[i : i + batch_size]
            collection.remove_multi(batch)

    def get_docs_by_ids(
        self,
        ids: Union[List[ItemID], None] = None,  # Union keeps the annotation valid on Python 3.9
        collection_name: str = None,
        include: Union[List[str], None] = None,
        **kwargs,
    ) -> List[Document]:
        """Retrieve documents from the collection of the vector database based on the ids."""
        if include is None:
            include = [TEXT_KEY, "metadata", "id"]
        elif "id" not in include:
            include.append("id")

        collection = self.get_collection(collection_name)
        if ids is not None:
            docs = []
            for doc_id in ids:
                # Unwrap the GetResult into a plain dict and carry the document key along as the id.
                doc = collection.get(doc_id).content_as[dict]
                doc["id"] = doc_id
                docs.append(doc)
        else:
            # Get all documents with a couchbase query; the document key is exposed as META().id.
            fields = ", ".join(field for field in include if field != "id")
            query = f"SELECT META().id AS id, {fields} FROM {self.bucket.name}.{self.scope.name}.{collection.name}"
            result = self.cluster.query(query)
            docs = [row for row in result]

        return [{k: v for k, v in doc.items() if k in include or k == "id"} for doc in docs]

    def retrieve_docs(
        self,
        queries: List[str],
        collection_name: str = None,
        n_results: int = 10,
        distance_threshold: float = -1,
        **kwargs,
    ) -> QueryResults:
        """Retrieve documents from the collection of the vector database based on the queries.

        Note: Distance threshold is not supported in Couchbase FTS.
        """
        collection = self.get_collection(collection_name)
        results: QueryResults = []
        for query_text in queries:
            query_vector = np.array(self.embedding_function([query_text])).tolist()[0]
            query_result = self._vector_search(
                query_vector,
                n_results,
                **kwargs,
            )
            results.append(query_result)
        return results

    def _vector_search(self, embedding_vector: List[float], n_results: int = 10, **kwargs) -> List[Tuple[Dict, float]]:
        """Core vector search using Couchbase FTS."""
        search_req = search.SearchRequest.create(
            VectorSearch.from_vector_query(
                VectorQuery(
                    EMBEDDING_KEY,
                    embedding_vector,
                    n_results,
                )
            )
        )

        search_options = SearchOptions(limit=n_results, fields=["*"])
        result = self.scope.search(
            self.index_name,
            search_req,
            search_options,
        )

        docs_with_score = []

        for row in result.rows():
            doc = row.fields
            doc["id"] = row.id
            score = row.score

            docs_with_score.append((doc, score))

        return docs_with_score
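
To make the flow concrete, a minimal end-to-end sketch against this module (it assumes a running Couchbase server reachable with the constructor defaults; the names are illustrative, and the search index may need a moment to ingest documents before the query returns hits):

```python
from autogen.agentchat.contrib.vectordb.couchbase import CouchbaseVectorDB

db = CouchbaseVectorDB(bucket_name="vector_db", index_name="vector_index")
db.create_collection("demo", overwrite=True)
db.insert_docs(
    [
        {"id": "1", "content": "Dogs are lovely."},
        {"id": "2", "content": "Cats are independent."},
    ],
    collection_name="demo",
)

# QueryResults: one list of (document, score) tuples per query string.
results = db.retrieve_docs(["a friendly pet"], collection_name="demo", n_results=1)
print(results[0])
```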
setup.py
@@ -74,6 +74,7 @@ extra_require = {
    "retrievechat-pgvector": retrieve_chat_pgvector,
    "retrievechat-mongodb": [*retrieve_chat, "pymongo>=4.0.0"],
    "retrievechat-qdrant": [*retrieve_chat, "qdrant_client", "fastembed>=0.3.1"],
    "retrievechat-couchbase": [*retrieve_chat, "couchbase>=4.3.0"],
    "autobuild": ["chromadb", "sentence-transformers", "huggingface-hub", "pysqlite3"],
    "teachable": ["chromadb"],
    "lmm": ["replicate", "pillow"],
@@ -0,0 +1,153 @@
import logging
import os
import random
from time import monotonic, sleep
from typing import List

import pytest
from dotenv import load_dotenv

from autogen.agentchat.contrib.vectordb.base import Document

try:
    import sentence_transformers
    import couchbase
    from autogen.agentchat.contrib.vectordb.couchbase import CouchbaseVectorDB
except ImportError:
    logger = logging.getLogger(__name__)
    logger.warning(f"skipping {__name__}. It requires one to pip install couchbase or the extra [retrievechat-couchbase]")
    pytest.skip("Required modules not installed", allow_module_level=True)

from couchbase.cluster import Cluster, ClusterOptions
from couchbase.auth import PasswordAuthenticator

logger = logging.getLogger(__name__)

# Get the directory of the current script
script_dir = os.path.dirname(os.path.abspath(__file__))

# Construct the absolute path to the .env file
env_path = os.path.join(script_dir, ".env")

# Load the .env file
load_dotenv(env_path)
load_dotenv(".env")

COUCHBASE_HOST = os.environ.get("CB_CONN_STR", "couchbase://localhost")
COUCHBASE_USERNAME = os.environ.get("CB_USERNAME", "Administrator")
COUCHBASE_PASSWORD = os.environ.get("CB_PASSWORD", "password")
COUCHBASE_BUCKET = os.environ.get("CB_BUCKET", "autogen_test_bucket")
COUCHBASE_SCOPE = os.environ.get("CB_SCOPE", "_default")
COUCHBASE_COLLECTION = os.environ.get("CB_COLLECTION", "autogen_test_vectorstore")
COUCHBASE_INDEX = os.environ.get("CB_INDEX_NAME", "vector_index")

RETRIES = 10
DELAY = 2
TIMEOUT = 120.0


def _empty_collections_and_delete_indexes(cluster: Cluster, bucket_name, scope_name, collections=None):
    bucket = cluster.bucket(bucket_name)
    try:
        scope_manager = bucket.collections().get_all_scopes(scope_name=scope_name)
        for scope_ in scope_manager:
            all_collections = scope_.collections
            for curr_collection in all_collections:
                bucket.collections().drop_collection(scope_name, curr_collection.name)
    except Exception as e:
        logger.warning(f"Failed to drop collections: {e}")


@pytest.fixture
def db():
    print("Creating couchbase connection", COUCHBASE_HOST, COUCHBASE_USERNAME)  # avoid printing the password
    cluster = Cluster(COUCHBASE_HOST, ClusterOptions(PasswordAuthenticator(COUCHBASE_USERNAME, COUCHBASE_PASSWORD)))
    _empty_collections_and_delete_indexes(cluster, COUCHBASE_BUCKET, COUCHBASE_SCOPE)
    vectorstore = CouchbaseVectorDB(
        connection_string=COUCHBASE_HOST,
        username=COUCHBASE_USERNAME,
        password=COUCHBASE_PASSWORD,
        bucket_name=COUCHBASE_BUCKET,
        scope_name=COUCHBASE_SCOPE,
        collection_name=COUCHBASE_COLLECTION,
        index_name=COUCHBASE_INDEX,
    )
    yield vectorstore
    _empty_collections_and_delete_indexes(cluster, COUCHBASE_BUCKET, COUCHBASE_SCOPE)


_COLLECTION_NAMING_CACHE = []
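# Tracks suffixes the collection_name fixture has already handed out, so each test gets a unique collection.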


@pytest.fixture
def collection_name():
    collection_id = random.randint(0, 100)
    while collection_id in _COLLECTION_NAMING_CACHE:
        collection_id = random.randint(0, 100)
    _COLLECTION_NAMING_CACHE.append(collection_id)
    return f"{COUCHBASE_COLLECTION}_{collection_id}"


def test_couchbase(db, collection_name):
    # db = CouchbaseVectorDB(path=".db")
    with pytest.raises(Exception):
        curr_col = db.get_collection(collection_name)
        curr_col.upsert("1", {"content": "Dogs are lovely."})

    collection = db.create_collection(collection_name, overwrite=True, get_or_create=True)
    assert collection.name == collection_name
    collection.upsert("1", {"content": "Dogs are lovely."})

    # test_delete_collection
    db.delete_collection(collection_name)
    sleep(5)  # wait for the collection to be deleted
    with pytest.raises(Exception):
        curr_col = db.get_collection(collection_name)
        curr_col.upsert("1", {"content": "Dogs are lovely."})

    # test more create collection
    collection = db.create_collection(collection_name, overwrite=False, get_or_create=False)
    assert collection.name == collection_name
    pytest.raises(ValueError, db.create_collection, collection_name, overwrite=False, get_or_create=False)
    collection = db.create_collection(collection_name, overwrite=True, get_or_create=False)
    assert collection.name == collection_name
    collection = db.create_collection(collection_name, overwrite=False, get_or_create=True)
    assert collection.name == collection_name

    # test_get_collection
    collection = db.get_collection(collection_name)
    assert collection.name == collection_name

    # test_insert_docs
    docs = [{"content": "doc1", "id": "1"}, {"content": "doc2", "id": "2"}, {"content": "doc3", "id": "3"}]
    db.insert_docs(docs, collection_name, upsert=False)
    res = db.get_collection(collection_name).get_multi(["1", "2"]).results

    assert res["1"].value["content"] == "doc1"
    assert res["2"].value["content"] == "doc2"

    # test_update_docs
    docs = [{"content": "doc11", "id": "1"}, {"content": "doc2", "id": "2"}, {"content": "doc3", "id": "3"}]
    db.update_docs(docs, collection_name)
    res = db.get_collection(collection_name).get_multi(["1", "2"]).results
    assert res["1"].value["content"] == "doc11"
    assert res["2"].value["content"] == "doc2"

    # test_delete_docs
    ids = ["1"]
    db.delete_docs(ids, collection_name)
    with pytest.raises(Exception):
        res = db.get_collection(collection_name).get(ids[0])

    # test_retrieve_docs
    queries = ["doc2", "doc3"]
    res = db.retrieve_docs(queries, collection_name)
    texts = [[item[0]['content'] for item in sublist] for sublist in res]
    received_ids = [[item[0]['id'] for item in sublist] for sublist in res]

    assert texts[0] == ["doc2", "doc3"]
    assert received_ids[0] == ["2", "3"]
@@ -202,8 +202,8 @@ ragproxyagent = RetrieveUserProxyAgent(


### Customizing Vector Database

-We are using chromadb as the default vector database, you can also use mongodb, pgvectordb and qdrantdb
-by simply set `vector_db` to `mongodb`, `pgvector` and `qdrant` in `retrieve_config`, respectively.
+We use chromadb as the default vector database; you can also use mongodb, pgvectordb, qdrantdb, or couchbase
+by simply setting `vector_db` to `mongodb`, `pgvector`, `qdrant`, or `couchbase` in `retrieve_config`, respectively.

To plug in any other database, you can also extend the class `agentchat.contrib.vectordb.base`;
check out the code [here](https://github.com/microsoft/autogen/blob/main/autogen/agentchat/contrib/vectordb/base.py).
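
For illustration, selecting the new backend from `retrieve_config` would look roughly like this (a sketch: the `db_config` keys shown are the CouchbaseVectorDB constructor arguments, and the connection values are placeholders):

```python
from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent

ragproxyagent = RetrieveUserProxyAgent(
    name="ragproxyagent",
    retrieve_config={
        "task": "code",
        "docs_path": "path/to/docs",
        "vector_db": "couchbase",  # selects CouchbaseVectorDB via VectorDBFactory
        "db_config": {
            "connection_string": "couchbase://localhost",  # placeholder connection details
            "username": "Administrator",
            "password": "password",
            "bucket_name": "vector_db",
            "index_name": "vector_index",
        },
    },
)
```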